├── .envrc
├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── ci.yaml
│       └── cleancache.yml
├── .gitignore
├── .goreleaser.yaml
├── .tool-versions
├── CONTRIBUTING.md
├── LICENSE
├── Makefile
├── README.md
├── cmd
│   ├── cmd.go
│   ├── multifile
│   │   ├── manifest.go
│   │   ├── manifest_test.go
│   │   └── multifile.go
│   ├── root
│   │   └── root.go
│   └── version
│       └── version.go
├── go.mod
├── go.sum
├── main.go
├── pkg
│   ├── cli
│   │   ├── common.go
│   │   ├── common_test.go
│   │   └── pid.go
│   ├── client
│   │   ├── client.go
│   │   └── client_test.go
│   ├── config
│   │   ├── config.go
│   │   ├── config_test.go
│   │   └── optnames.go
│   ├── consistent
│   │   ├── consistent.go
│   │   └── consistent_test.go
│   ├── consumer
│   │   ├── consumer.go
│   │   ├── consumer_test.go
│   │   ├── null.go
│   │   ├── null_test.go
│   │   ├── tar_extractor.go
│   │   ├── tar_extractor_test.go
│   │   ├── write_file.go
│   │   └── write_file_test.go
│   ├── download
│   │   ├── buffer.go
│   │   ├── buffer_slow_test.go
│   │   ├── buffer_test.go
│   │   ├── buffer_unit_test.go
│   │   ├── common.go
│   │   ├── common_test.go
│   │   ├── consistent_hashing.go
│   │   ├── consistent_hashing_test.go
│   │   ├── options.go
│   │   ├── reader_promise.go
│   │   ├── reader_promise_test.go
│   │   ├── strategy.go
│   │   └── work_queue.go
│   ├── extract
│   │   ├── compression.go
│   │   ├── compression_test.go
│   │   ├── tar.go
│   │   └── tar_test.go
│   ├── logging
│   │   └── log.go
│   ├── pget.go
│   ├── pget_test.go
│   └── version
│       ├── info.go
│       └── info_test.go
├── script
│   ├── format
│   ├── lint
│   └── test
└── tools.go
/.envrc: -------------------------------------------------------------------------------- 1 | use asdf 2 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "gomod" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | - package-ecosystem: "github-actions" 8 | directory: "/" 9 | schedule: 10 | interval: "weekly" 11 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - "*" 9 | pull_request: 10 | workflow_dispatch: 11 | 12 | # Ensure only one workflow instance runs at a time. For branches other than the 13 | # default branch, cancel the pending jobs in the group. For the default branch, 14 | # queue them up. This avoids cancelling jobs that are in the middle of deploying 15 | # to production. 
16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.ref }} 18 | cancel-in-progress: ${{ github.ref != format('refs/heads/{0}', github.event.repository.default_branch) }} 19 | 20 | jobs: 21 | test: 22 | name: "Test" 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@master 26 | - uses: actions/setup-go@v5 27 | with: 28 | go-version-file: go.mod 29 | cache: true 30 | - run: "make test" 31 | name: Run test 32 | 33 | goreleaser_config: 34 | name: Test Goreleaser Config 35 | runs-on: ubuntu-latest 36 | steps: 37 | - uses: actions/checkout@v4 38 | name: "Checkout" 39 | with: 40 | fetch-depth: 0 41 | - uses: actions/setup-go@v5 42 | name: "Set up Go" 43 | with: 44 | go-version-file: go.mod 45 | - uses: goreleaser/goreleaser-action@v6 46 | with: 47 | args: check 48 | 49 | lint: 50 | name: "Lint" 51 | if: ${{ github.event_name == 'pull_request' }} 52 | runs-on: ubuntu-latest 53 | steps: 54 | - uses: actions/checkout@v4 55 | - uses: actions/setup-go@v5 56 | with: 57 | go-version-file: go.mod 58 | - run: go mod download 59 | - name: Lint 60 | run: script/lint 61 | - name: Formatting 62 | run: CHECKONLY=1 script/format 63 | 64 | build: 65 | runs-on: ubuntu-latest 66 | steps: 67 | - uses: actions/checkout@v4 68 | name: "Checkout" 69 | with: 70 | fetch-depth: 0 71 | - uses: actions/setup-go@v5 72 | name: "Set up Go" 73 | with: 74 | go-version-file: go.mod 75 | - uses: goreleaser/goreleaser-action@v6 76 | name: "Build Snapshot" 77 | with: 78 | version: latest 79 | args: build --clean --snapshot 80 | 81 | release: 82 | runs-on: ubuntu-latest 83 | steps: 84 | - uses: actions/checkout@v4 85 | name: "Checkout" 86 | with: 87 | fetch-depth: 0 88 | - uses: actions/setup-go@v5 89 | name: "Set up Go" 90 | with: 91 | go-version-file: go.mod 92 | - uses: goreleaser/goreleaser-action@v6 93 | name: "Release" 94 | if: startsWith(github.ref, 'refs/tags/') 95 | with: 96 | version: latest 97 | args: release --clean 98 | env: 99 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 100 | -------------------------------------------------------------------------------- /.github/workflows/cleancache.yml: -------------------------------------------------------------------------------- 1 | name: Clean up branch caches 2 | on: 3 | pull_request: 4 | types: 5 | - closed 6 | 7 | jobs: 8 | cleanup: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | 13 | - name: Remove branch caches 14 | run: | 15 | gh extension install actions/gh-actions-cache 16 | 17 | REPO=${{ github.repository }} 18 | BRANCH="refs/pull/${{ github.event.pull_request.number }}/merge" 19 | 20 | echo "Fetching list of cache keys" 21 | cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH | cut -f 1) 22 | 23 | ## Setting this to not fail the workflow while deleting cache keys. 24 | echo "Deleting caches..." 25 | for cacheKey in $cacheKeysForPR; do 26 | gh actions-cache delete $cacheKey -R $REPO -B $BRANCH --confirm || : 27 | done 28 | echo "Done" 29 | env: 30 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Allowlisting gitignore prevents us 2 | # from adding various unwanted local files, such as generated 3 | # files, developer configurations or IDE-specific files etc. 
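# Note on the mechanism below: git does not descend into ignored directories,
# so the trailing `!*/` rule is what re-allows directory traversal; without it,
# the `!` patterns for files inside subdirectories would never match.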
4 | 5 | 6 | # Ignore everything 7 | * 8 | 9 | # Track these files 10 | !/.gitignore 11 | 12 | !*.go 13 | !go.sum 14 | !go.mod 15 | 16 | !README.md 17 | !LICENSE 18 | !Makefile 19 | !CONTRIBUTING.md 20 | !script/ 21 | 22 | !.envrc 23 | !.goreleaser.yaml 24 | !.tool-versions 25 | 26 | # persist the track directives for files in subdirectories 27 | !*/ 28 | -------------------------------------------------------------------------------- /.goreleaser.yaml: -------------------------------------------------------------------------------- 1 | project_name: pget 2 | version: 2 3 | before: 4 | hooks: 5 | - go mod tidy 6 | builds: 7 | - binary: pget 8 | id: pget 9 | env: 10 | - CGO_ENABLED=0 11 | goos: 12 | - darwin 13 | - linux 14 | goarch: 15 | - amd64 16 | - arm64 17 | main: ./main.go 18 | ldflags: 19 | - "-s -w -X github.com/replicate/pget/pkg/version.Version={{.Version}} -X github.com/replicate/pget/pkg/version.CommitHash={{.ShortCommit}} -X github.com/replicate/pget/pkg/version.BuildTime={{.Date}} -X github.com/replicate/pget/pkg/version.Prerelease={{.Prerelease}} -X github.com/replicate/pget/pkg/version.OS={{.Os}} -X github.com/replicate/pget/pkg/version.Arch={{if eq .Arch \"amd64\"}}x86_64{{else if eq .Arch \"386\"}}i386{{else}}{{.Arch}}{{end}} -X github.com/replicate/pget/pkg/version.Snapshot={{.IsSnapshot}} -X github.com/replicate/pget/pkg/version.Branch={{.Branch}}" 20 | archives: 21 | - formats: [ 'binary' ] 22 | name_template: >- 23 | {{ .ProjectName }}_{{ title .Os }}_ 24 | {{- if eq .Arch "amd64" }}x86_64 25 | {{- else if eq .Arch "386" }}i386 26 | {{- else }}{{ .Arch }}{{end -}} 27 | checksum: 28 | name_template: "checksums.txt" 29 | snapshot: 30 | version_template: "{{ incminor .Version }}-devbuild" 31 | universal_binaries: 32 | - replace: false 33 | changelog: 34 | sort: asc 35 | filters: 36 | exclude: 37 | - "^docs:" 38 | - "^test:" 39 | release: 40 | # If set to auto, will mark the release as not ready for production 41 | # in case there is an indicator for this in the tag e.g. v1.0.0-alpha 42 | # If set to true, will mark the release as not ready for production. 43 | # Default is false. 44 | prerelease: auto 45 | -------------------------------------------------------------------------------- /.tool-versions: -------------------------------------------------------------------------------- 1 | golang 1.24.1 2 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guide 2 | 3 | Thanks for your interest in contributing to PGet! We welcome contributions of all kinds, including bug reports, feature requests, documentation improvements, and code contributions. 4 | 5 | ## Running tests 6 | 7 | To run the entire test suite: 8 | 9 | ```sh 10 | make test 11 | ``` 12 | 13 | ## Publishing a release 14 | 15 | This project has a [GitHub Actions workflow](https://github.com/replicate/pget/blob/63220e619c6111a11952e40793ff4efed76a050e/.github/workflows/ci.yaml#L81:L81) that uses [goreleaser](https://goreleaser.com/quick-start/#quick-start) to facilitate the process of publishing new releases. The release process is triggered by manually creating and pushing a new git tag. 
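Before tagging a release, you can sanity-check the goreleaser configuration locally; CI runs the same check in the `goreleaser_config` job (this assumes goreleaser is installed, e.g. via `make install-goreleaser`):

```console
goreleaser check
```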
16 | 17 | To publish a new release, run the following in your local checkout of pget: 18 | 19 | ```console 20 | git checkout main 21 | git fetch --all --tags 22 | git tag v0.0.11 23 | git push --tags 24 | ``` 25 | 26 | While not required, it is recommended to publish a signed tag using `git tag -s v0.0.11`, for example. Pre-release tags can be created by appending a `-` and a string that conforms to goreleaser's concept of a semver pre-release (e.g. `-beta10`). 27 | 28 | Then visit [github.com/replicate/pget/actions](https://github.com/replicate/pget/actions) to monitor the release process. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2022, Replicate, Inc. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /bin/bash 2 | 3 | DESTDIR ?= 4 | PREFIX = /usr/local 5 | BINDIR = $(PREFIX)/bin 6 | 7 | INSTALL := install -m 0755 8 | INSTALL_PROGRAM := $(INSTALL) 9 | 10 | CHECKSUM_CMD := shasum -a 256 11 | CHECKSUM_FILE := sha256sum.txt 12 | 13 | GO := go 14 | GOOS := $(shell $(GO) env GOOS) 15 | GOARCH := $(shell $(GO) env GOARCH) 16 | GOENV := $(shell $(GO) env GOPATH) 17 | GORELEASER := $(GOENV)/bin/goreleaser 18 | 19 | SINGLE_TARGET=--single-target 20 | 21 | default: all 22 | 23 | .PHONY: all 24 | all: clean build 25 | 26 | .PHONY: install 27 | install: build 28 | $(INSTALL_PROGRAM) -d $(DESTDIR)$(BINDIR) 29 | $(INSTALL_PROGRAM) pget $(DESTDIR)$(BINDIR)/pget 30 | 31 | .PHONY: uninstall 32 | uninstall: 33 | rm -f $(DESTDIR)$(BINDIR)/pget 34 | 35 | .PHONY: clean 36 | clean: 37 | $(GO) clean 38 | rm -rf dist 39 | rm -f pget 40 | 41 | 42 | .PHONY: test-all 43 | test-all: test lint 44 | 45 | .PHONY: test 46 | test: 47 | script/test $(ARGS) 48 | 49 | .PHONY: lint 50 | lint: CHECKONLY=1 51 | lint: format 52 | script/lint 53 | 54 | .PHONY: format 55 | format: CHECKONLY=1 56 | format: 57 | CHECKONLY=$(CHECKONLY) script/format 58 | 59 | .PHONY: tidy 60 | tidy: 61 | go mod tidy 62 | 63 | .PHONY: install-goreleaser 64 | install-goreleaser: 65 | $(GO) install github.com/goreleaser/goreleaser/v2@latest 66 | 67 | 68 | .PHONY: build 69 | build: pget 70 | 71 | .PHONY: build-all 72 | build-all: SINGLE_TARGET:= 73 | build-all: clean pget 74 | 75 | pget: install-goreleaser 76 | $(GORELEASER) build --snapshot --clean $(SINGLE_TARGET) -o ./pget 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PGet - Parallel File Downloader & Extractor 2 | 3 | PGet is a high performance, concurrent file downloader built in Go. It is designed to speed up and optimize file downloads from cloud storage services such as Amazon S3 and Google Cloud Storage. 4 | 5 | The primary advantage of PGet is its ability to download files in parallel using multiple threads. By dividing the file into chunks and downloading multiple chunks simultaneously, PGet significantly reduces the total download time for large files. 6 | 7 | If the downloaded file is a tar archive, PGet can automatically extract the contents of the archive in memory, thus removing the need for an additional extraction step. 8 | 9 | The efficiency of PGet's tar extraction lies in its approach to handling data. Instead of writing the downloaded tar file to disk and then reading it back into memory for extraction, PGet conducts the extraction directly from the in-memory download buffer. This method avoids unnecessary memory copies and disk I/O, leading to an increase in performance, especially when dealing with large tar files. This makes PGet not just a parallel downloader, but also an efficient file extractor, providing a streamlined solution for fetching and unpacking files. 10 | 11 | > [!NOTE] 12 | > This project is not related to [Code-Hex/pget](https://github.com/Code-Hex/pget). The two projects share the same name and similar goals, but are completely different codebases with different capabilities. 
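For illustration, here is a minimal sketch (in Go) of the chunked, range-request download technique described above. It is a toy under simplifying assumptions (no retries, no cache-aware strategies, whole chunks buffered in memory) and is not PGet's actual implementation, which lives in `pkg/download`:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"os"
	"sync"
)

// downloadChunked fetches url into dest using n parallel HTTP range requests.
// It assumes the server reports Content-Length and honors Range headers.
func downloadChunked(url, dest string, n int) error {
	resp, err := http.Head(url)
	if err != nil {
		return err
	}
	resp.Body.Close()
	size := resp.ContentLength
	if size <= 0 {
		return fmt.Errorf("unknown content length for %s", url)
	}

	out, err := os.Create(dest)
	if err != nil {
		return err
	}
	defer out.Close()

	chunk := (size + int64(n) - 1) / int64(n) // ceil(size / n)
	var wg sync.WaitGroup
	errs := make(chan error, n)
	for start := int64(0); start < size; start += chunk {
		end := start + chunk - 1
		if end >= size {
			end = size - 1
		}
		wg.Add(1)
		go func(start, end int64) {
			defer wg.Done()
			req, err := http.NewRequest(http.MethodGet, url, nil)
			if err != nil {
				errs <- err
				return
			}
			req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				errs <- err
				return
			}
			defer resp.Body.Close()
			buf, err := io.ReadAll(resp.Body)
			if err != nil {
				errs <- err
				return
			}
			// WriteAt lets every goroutine place its chunk independently.
			if _, err := out.WriteAt(buf, start); err != nil {
				errs <- err
			}
		}(start, end)
	}
	wg.Wait()
	close(errs)
	return <-errs // first error, or nil if the channel is empty
}

func main() {
	if len(os.Args) != 3 {
		fmt.Fprintln(os.Stderr, "usage: chunked <url> <dest>")
		os.Exit(1)
	}
	if err := downloadChunked(os.Args[1], os.Args[2], 4); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}
```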
13 | 14 | ## Install 15 | 16 | You can download and install the latest release of PGet directly from GitHub by running the following commands in a terminal: 17 | 18 | ```console 19 | sudo curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)" 20 | sudo chmod +x /usr/local/bin/pget 21 | ``` 22 | 23 | If you're using macOS, you can install PGet with Homebrew: 24 | 25 | ```console 26 | brew tap replicate/tap 27 | brew install replicate/tap/pget 28 | ``` 29 | 30 | Or you can build from source and install it with these commands 31 | (requires Go 1.24 or later): 32 | 33 | ```console 34 | make 35 | sudo make install 36 | ``` 37 | 38 | This builds a static binary that can work inside containers. 39 | 40 | ## Usage 41 | 42 | ### Default Mode 43 | pget <url> <dest> [-c concurrency] [-x] 44 | 45 | #### Parameters 46 | 47 | - `<url>`: The URL of the file to download. 48 | - `<dest>`: The destination where the downloaded file will be stored. 49 | - -c concurrency: The number of concurrent downloads. Default is 4 times the number of cores. 50 | - -x: Extract the tar file after download. If not set, the downloaded file will be saved as is. 51 | 52 | #### Default-Mode Command-Line Options 53 | - `-x`, `--extract` 54 | - Extract archive after download 55 | - Type: `bool` 56 | - Default: `false` 57 | 58 | #### Example 59 | 60 | pget https://storage.googleapis.com/replicant-misc/sd15.tar ./sd15 -x 61 | 62 | This command will download Stable Diffusion 1.5 weights to the path ./sd15 with high concurrency. After the file is downloaded, it will be automatically extracted. 63 | 64 | ### Multi-File Mode 65 | pget multifile <manifest-file> 66 | 67 | #### Parameters 68 | - `<manifest-file>`: A path to a manifest file containing (newline-delimited) pairs of URLs and local destination file paths. The use of `-` allows for reading from STDIN. 69 | 70 | #### Examples 71 | 72 | Read the manifest file from a path on disk: 73 | 74 | pget multifile /path/to/manifest.txt 75 | 76 | Read the manifest file from STDIN: 77 | 78 | pget multifile - < manifest.txt 79 | 80 | Pipe to multifile from another command: 81 | 82 | cat manifest.txt | pget multifile - 83 | 84 | An example `manifest.txt` file might look like this: 85 | 86 | ```txt 87 | https://example.com/image1.jpg /local/path/to/image1.jpg 88 | https://example.com/document.pdf /local/path/to/document.pdf 89 | https://example.com/music.mp3 /local/path/to/music.mp3 90 | ``` 91 | 92 | #### Multi-file specific options 93 | - `--max-concurrent-files` 94 | - Maximum number of files to download concurrently 95 | - Default: `40` 96 | - Type: `Integer` 97 | - `--max-conn-per-host` 98 | - Maximum number of (global) concurrent connections per host 99 | - Default: `40` 100 | - Type: `Integer` 101 | 102 | ### Global Command-Line Options 103 | - `--concurrency` 104 | - Maximum number of chunks to download in parallel for a given file 105 | - Type: `Integer` 106 | - Default: `4 * runtime.NumCPU()` 107 | - `--connect-timeout` 108 | - Timeout for establishing a connection, format is `<duration>`, e.g. 10s 109 | - Type: `Duration` 110 | - Default: `5s` 111 | - `-f`, `--force` 112 | - Force download, overwriting existing file 113 | - Type: `bool` 114 | - Default: `false` 115 | - `--log-level` 116 | - Log level (debug, info, warn, error) 117 | - Type: `string` 118 | - Default: `info` 119 | - `-m`, `--chunk-size string` 120 | - Chunk size (in bytes) to use when downloading a file (e.g. 10M) 121 | - Type: `string` 122 | - Default: `125M` 123 | - `--resolve` 124 | - Resolve hostnames to specific IPs, can be specified multiple times, format `<hostname>:<port>:<ip>` (e.g. example.com:443:127.0.0.1) 125 | - Type: `string` 126 | - `-r`, `--retries` 127 | - Number of retries when attempting to retrieve a file 128 | - Type: `Integer` 129 | - Default: `5` 130 | - `-v`, `--verbose` 131 | - Verbose mode (equivalent to `--log-level debug`) 132 | - Type: `bool` 133 | - Default: `false` 134 | 135 | #### Deprecated 136 | - `--max-chunks` (deprecated, use `--concurrency` instead) 137 | - Maximum number of chunks for downloading a given file 138 | - Type: `Integer` 139 | - Default: `4 * runtime.NumCPU()` 140 | - `-m`, `--minimum-chunk-size string` (deprecated, use `--chunk-size` instead) 141 | - Minimum chunk size (in bytes) to use when downloading a file (e.g. 10M) 142 | - Type: `string` 143 | - Default: `16M` 144 | 145 | ## Error Handling 146 | 147 | PGet includes some error handling: 148 | 149 | 1. If the download of any chunk fails, it will automatically retry up to 5 times before giving up. 150 | 2. If the downloaded file size does not match the expected size, it will also retry the download. 151 | 152 | ## Future Improvements 153 | 154 | - as chunks are downloaded, start either writing to disk or extracting 155 | - can we check the content hash of the file in the background? 156 | - support for zip files? 157 | -------------------------------------------------------------------------------- /cmd/cmd.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | 6 | "github.com/replicate/pget/cmd/multifile" 7 | "github.com/replicate/pget/cmd/root" 8 | "github.com/replicate/pget/cmd/version" 9 | ) 10 | 11 | func GetRootCommand() *cobra.Command { 12 | rootCMD := root.GetCommand() 13 | rootCMD.AddCommand(multifile.GetCommand()) 14 | rootCMD.AddCommand(version.VersionCMD) 15 | return rootCMD 16 | } 17 | -------------------------------------------------------------------------------- /cmd/multifile/manifest.go: -------------------------------------------------------------------------------- 1 | package multifile 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "io/fs" 9 | netUrl "net/url" 10 | "os" 11 | "strings" 12 | 13 | "github.com/spf13/viper" 14 | 15 | pget "github.com/replicate/pget/pkg" 16 | "github.com/replicate/pget/pkg/cli" 17 | "github.com/replicate/pget/pkg/config" 18 | "github.com/replicate/pget/pkg/logging" 19 | ) 20 | 21 | // A manifest is a file consisting of pairs of URLs and paths: 22 | // 23 | // http://example.com/foo/bar.txt foo/bar.txt 24 | // http://example.com/foo/bar/baz.txt foo/bar/baz.txt 25 | // 26 | // A manifest may contain blank lines. 27 | // The pairs are separated by arbitrary whitespace. 28 | // 29 | // When we parse a manifest, we group by URL base (ie scheme://hostname) so that 
31 | 32 | var errDupeURLDestCombo = errors.New("duplicate destination with different URLs") 33 | 34 | func manifestFile(manifestPath string) (*os.File, error) { 35 | if manifestPath == "-" { 36 | return os.Stdin, nil 37 | } 38 | if _, err := os.Stat(manifestPath); errors.Is(err, fs.ErrNotExist) { 39 | return nil, fmt.Errorf("manifest file %s does not exist", manifestPath) 40 | } 41 | file, err := os.Open(manifestPath) 42 | if err != nil { 43 | return nil, fmt.Errorf("error opening manifest file %s: %w", manifestPath, err) 44 | } 45 | return file, err 46 | } 47 | 48 | func parseLine(line string) (url, dest string, err error) { 49 | fields := strings.Fields(line) 50 | if len(fields) != 2 { 51 | return "", "", fmt.Errorf("error parsing manifest invalid line format `%s`", line) 52 | } 53 | return fields[0], fields[1], nil 54 | } 55 | 56 | func checkSeenDestinations(destinations map[string]string, dest string, url string) error { 57 | if seenURL, ok := destinations[dest]; ok { 58 | if seenURL != url { 59 | return fmt.Errorf("duplicate destination %s with different urls: %s and %s", dest, seenURL, url) 60 | } else { 61 | return errDupeURLDestCombo 62 | } 63 | } 64 | return nil 65 | } 66 | 67 | func parseManifest(file io.Reader) (pget.Manifest, error) { 68 | logger := logging.GetLogger() 69 | seenDestinations := make(map[string]string) 70 | manifest := make(pget.Manifest, 0) 71 | 72 | scanner := bufio.NewScanner(file) 73 | 74 | for scanner.Scan() { 75 | line := strings.TrimSpace(scanner.Text()) 76 | if line == "" { 77 | continue 78 | } 79 | url, dest, err := parseLine(line) 80 | if err != nil { 81 | return nil, err 82 | } 83 | 84 | if _, err := netUrl.Parse(url); err != nil { 85 | return nil, err 86 | 87 | } 88 | 89 | // THIS IS A BODGE - FIX ME MOVE THESE THINGS TO PGET 90 | // and make the consumer responsible for knowing if this 91 | // is allowed/not allowed/etc 92 | consumer := viper.GetString(config.OptOutputConsumer) 93 | if consumer != config.ConsumerNull { 94 | err = checkSeenDestinations(seenDestinations, dest, url) 95 | if err != nil { 96 | if errors.Is(err, errDupeURLDestCombo) { 97 | logger.Warn(). 98 | Str("url", url). 99 | Str("destination", dest). 
100 | Msg("Parse Manifest: Skip Duplicate URL/Destination") 101 | continue 102 | } 103 | return nil, err 104 | } 105 | seenDestinations[dest] = url 106 | 107 | err = cli.EnsureDestinationNotExist(dest) 108 | if err != nil { 109 | return nil, err 110 | } 111 | } 112 | manifest = manifest.AddEntry(url, dest) 113 | } 114 | 115 | return manifest, nil 116 | } 117 | -------------------------------------------------------------------------------- /cmd/multifile/manifest_test.go: -------------------------------------------------------------------------------- 1 | package multifile 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | // validManifest is a valid manifest file with additional empty lines 13 | const validManifest = ` 14 | https://example.com/file1.txt /tmp/file1.txt 15 | https://example.com/file2.txt /tmp/file2.txt 16 | 17 | https://example.com/file3.txt /tmp/file3.txt` 18 | 19 | const invalidManifest = `https://example.com/file1.txt` 20 | 21 | func TestParseLine(t *testing.T) { 22 | validLine := "https://example.com/file1.txt /tmp/file1.txt" 23 | validLineTabs := "https://example.com/file1.txt\t/tmp/file1.txt" 24 | validLineMultipleSpace := "https://example.com/file1.txt /tmp/file1.txt" 25 | invalidLine := "https://example.com/file1.txt" 26 | 27 | urlString, dest, err := parseLine(validLine) 28 | assert.Equal(t, "https://example.com/file1.txt", urlString) 29 | assert.Equal(t, "/tmp/file1.txt", dest) 30 | assert.NoError(t, err) 31 | urlString, dest, err = parseLine(validLineTabs) 32 | assert.Equal(t, "https://example.com/file1.txt", urlString) 33 | assert.Equal(t, "/tmp/file1.txt", dest) 34 | assert.NoError(t, err) 35 | urlString, dest, err = parseLine(validLineMultipleSpace) 36 | assert.Equal(t, "https://example.com/file1.txt", urlString) 37 | assert.Equal(t, "/tmp/file1.txt", dest) 38 | assert.NoError(t, err) 39 | 40 | _, _, err = parseLine(invalidLine) 41 | assert.Error(t, err) 42 | } 43 | 44 | func TestCheckSeenDestinations(t *testing.T) { 45 | seenDestinations := map[string]string{ 46 | "/tmp/file1.txt": "https://example.com/file1.txt", 47 | } 48 | 49 | // a different destination is fine 50 | err := checkSeenDestinations(seenDestinations, "/tmp/file2.txt", "https://example.com/file2.txt") 51 | require.NoError(t, err) 52 | 53 | // the same destination with a different URL is not fine 54 | err = checkSeenDestinations(seenDestinations, "/tmp/file1.txt", "https://example.com/file2.txt") 55 | assert.Error(t, err) 56 | 57 | // the same destination with the same URL is fine, we raise a specific error to detect and skip 58 | err = checkSeenDestinations(seenDestinations, "/tmp/file1.txt", "https://example.com/file1.txt") 59 | assert.ErrorIs(t, err, errDupeURLDestCombo) 60 | } 61 | 62 | func TestParseManifest(t *testing.T) { 63 | parsedManifest, err := parseManifest(strings.NewReader(validManifest)) 64 | assert.NoError(t, err) 65 | assert.Len(t, parsedManifest, 3) 66 | 67 | parsedManifest, err = parseManifest(strings.NewReader(invalidManifest)) 68 | assert.Error(t, err) 69 | assert.Len(t, parsedManifest, 0) 70 | } 71 | 72 | func TestManifestFile(t *testing.T) { 73 | tempFile, _ := os.CreateTemp("", "manifest") 74 | defer func() { 75 | tempFile.Close() 76 | os.Remove(tempFile.Name()) 77 | }() 78 | 79 | file1, err := manifestFile("-") 80 | assert.NoError(t, err) 81 | assert.Equal(t, os.Stdin, file1) 82 | 83 | file2, err := manifestFile(tempFile.Name()) 84 | assert.NoError(t, err) 85 | 
assert.Equal(t, tempFile.Name(), file2.Name()) 86 | 87 | _, err = manifestFile("/does/not/exist") 88 | assert.Error(t, err) 89 | } 90 | -------------------------------------------------------------------------------- /cmd/multifile/multifile.go: -------------------------------------------------------------------------------- 1 | package multifile 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "time" 8 | 9 | "github.com/dustin/go-humanize" 10 | "github.com/spf13/cobra" 11 | "github.com/spf13/viper" 12 | 13 | pget "github.com/replicate/pget/pkg" 14 | "github.com/replicate/pget/pkg/cli" 15 | "github.com/replicate/pget/pkg/client" 16 | "github.com/replicate/pget/pkg/config" 17 | "github.com/replicate/pget/pkg/download" 18 | "github.com/replicate/pget/pkg/logging" 19 | ) 20 | 21 | const longDesc = ` 22 | 'multifile' mode for pget takes a manifest file as input (can use '-' for stdin) and downloads all files listed in the manifest. 23 | 24 | The manifest is expected to be in the format of a newline-separated list of pairs of URLs and destination paths, separated by a space. 25 | e.g. 26 | https://example.com/file1.txt /tmp/file1.txt 27 | 28 | 'multifile' will download files in parallel, limited per host by the '--max-conn-per-host' limit and 29 | limited overall by the '--max-concurrent-files' limit. 30 | ` 31 | 32 | const multifileExamples = ` 33 | pget multifile manifest.txt 34 | 35 | pget multifile - < manifest.txt 36 | 37 | cat multifile.txt | pget multifile - 38 | ` 39 | 40 | // test seam 41 | type Getter interface { 42 | DownloadFile(ctx context.Context, url string, dest string) (int64, time.Duration, error) 43 | } 44 | 45 | func GetCommand() *cobra.Command { 46 | cmd := &cobra.Command{ 47 | Use: "multifile [flags] <manifest-file>", 48 | Short: "download files from a manifest file in parallel", 49 | Long: longDesc, 50 | Args: cobra.ExactArgs(1), 51 | PreRunE: multifilePreRunE, 52 | RunE: runMultifileCMD, 53 | Example: multifileExamples, 54 | } 55 | 56 | err := viper.BindPFlags(cmd.PersistentFlags()) 57 | if err != nil { 58 | fmt.Println(err) 59 | os.Exit(1) 60 | } 61 | cmd.SetUsageTemplate(cli.UsageTemplate) 62 | return cmd 63 | } 64 | 65 | func multifilePreRunE(cmd *cobra.Command, args []string) error { 66 | if viper.GetBool(config.OptExtract) { 67 | return fmt.Errorf("cannot use --extract with multifile mode") 68 | } 69 | if viper.GetString(config.OptOutputConsumer) == config.ConsumerTarExtractor { 70 | return fmt.Errorf("cannot use --output-consumer tar-extractor with multifile mode") 71 | } 72 | return nil 73 | } 74 | 75 | func runMultifileCMD(cmd *cobra.Command, args []string) error { 76 | cmd.SilenceUsage = true 77 | manifestPath := args[0] 78 | file, err := manifestFile(manifestPath) 79 | if err != nil { 80 | return err 81 | } 82 | defer file.Close() 83 | manifest, err := parseManifest(file) 84 | if err != nil { 85 | return fmt.Errorf("error processing manifest file %s: %w", manifestPath, err) 86 | } 87 | 88 | return multifileExecute(cmd.Context(), manifest) 89 | } 90 | 91 | func maxConcurrentFiles() int { 92 | maxConcurrentFiles := viper.GetInt(config.OptMaxConcurrentFiles) 93 | if maxConcurrentFiles == 0 { 94 | maxConcurrentFiles = 20 95 | } 96 | return maxConcurrentFiles 97 | } 98 | 99 | func multifileExecute(ctx context.Context, manifest pget.Manifest) error { 100 | chunkSize, err := humanize.ParseBytes(viper.GetString(config.OptChunkSize)) 101 | if err != nil { 102 | return err 103 | } 104 | 105 | // Get the resolution overrides 106 | 
resolveOverrides, err := config.ResolveOverridesToMap(viper.GetStringSlice(config.OptResolve)) 107 | if err != nil { 108 | return fmt.Errorf("error parsing resolve overrides: %w", err) 109 | } 110 | 111 | clientOpts := client.Options{ 112 | MaxRetries: viper.GetInt(config.OptRetries), 113 | TransportOpts: client.TransportOptions{ 114 | ForceHTTP2: viper.GetBool(config.OptForceHTTP2), 115 | ConnectTimeout: viper.GetDuration(config.OptConnTimeout), 116 | MaxConnPerHost: viper.GetInt(config.OptMaxConnPerHost), 117 | ResolveOverrides: resolveOverrides, 118 | }, 119 | } 120 | downloadOpts := download.Options{ 121 | MaxConcurrency: viper.GetInt(config.OptConcurrency), 122 | ChunkSize: int64(chunkSize), 123 | Client: clientOpts, 124 | } 125 | pgetOpts := pget.Options{ 126 | MaxConcurrentFiles: maxConcurrentFiles(), 127 | MetricsEndpoint: viper.GetString(config.OptMetricsEndpoint), 128 | } 129 | 130 | consumer, err := config.GetConsumer() 131 | if err != nil { 132 | return fmt.Errorf("error getting consumer: %w", err) 133 | } 134 | 135 | getter := pget.Getter{ 136 | Consumer: consumer, 137 | Options: pgetOpts, 138 | } 139 | 140 | // TODO DRY this 141 | if srvName := config.GetCacheSRV(); srvName != "" { 142 | downloadOpts.SliceSize = 500 * humanize.MiByte 143 | downloadOpts.CacheableURIPrefixes = config.CacheableURIPrefixes() 144 | downloadOpts.ForceCachePrefixRewrite = viper.GetBool(config.OptForceCachePrefixRewrite) 145 | if downloadOpts.CacheHosts, err = cli.LookupCacheHosts(srvName); err != nil { 146 | return err 147 | } 148 | getter.Downloader, err = download.GetConsistentHashingMode(downloadOpts) 149 | if err != nil { 150 | return err 151 | } 152 | } else if cacheHostname := config.CacheServiceHostname(); cacheHostname != "" { 153 | downloadOpts.CacheHosts = []string{cacheHostname} 154 | downloadOpts.CacheableURIPrefixes = config.CacheableURIPrefixes() 155 | downloadOpts.ForceCachePrefixRewrite = viper.GetBool(config.OptForceCachePrefixRewrite) 156 | } 157 | 158 | if getter.Downloader == nil { 159 | getter.Downloader = download.GetBufferMode(downloadOpts) 160 | } 161 | 162 | totalFileSize, elapsedTime, err := getter.DownloadFiles(ctx, manifest) 163 | if err != nil { 164 | return err 165 | } 166 | 167 | throughput := float64(totalFileSize) / elapsedTime.Seconds() 168 | logger := logging.GetLogger() 169 | logger.Info(). 170 | Int("file_count", len(manifest)). 171 | Str("total_bytes_downloaded", humanize.Bytes(uint64(totalFileSize))). 172 | Str("bytes_per_second", fmt.Sprintf("%s/s", humanize.Bytes(uint64(throughput)))). 173 | Str("elapsed_time", fmt.Sprintf("%.3fs", elapsedTime.Seconds())). 174 | Msg("Metrics") 175 | 176 | return nil 177 | } 178 | -------------------------------------------------------------------------------- /cmd/root/root.go: -------------------------------------------------------------------------------- 1 | package root 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "runtime" 8 | "time" 9 | 10 | "github.com/dustin/go-humanize" 11 | "github.com/rs/zerolog/log" 12 | "github.com/spf13/cobra" 13 | "github.com/spf13/viper" 14 | 15 | "github.com/replicate/pget/cmd/version" 16 | pget "github.com/replicate/pget/pkg" 17 | "github.com/replicate/pget/pkg/cli" 18 | "github.com/replicate/pget/pkg/client" 19 | "github.com/replicate/pget/pkg/config" 20 | "github.com/replicate/pget/pkg/download" 21 | "github.com/replicate/pget/pkg/logging" 22 | ) 23 | 24 | const rootLongDesc = ` 25 | pget 26 | 27 | PGet is a high performance, concurrent file downloader built in Go. 
It is designed to speed up and optimize file 28 | downloads from cloud storage services such as Amazon S3 and Google Cloud Storage. 29 | 30 | The primary advantage of PGet is its ability to download files in parallel using multiple threads. By dividing the file 31 | into chunks and downloading multiple chunks simultaneously, PGet significantly reduces the total download time for large 32 | files. 33 | 34 | If the downloaded file is a tar archive, PGet can automatically extract the contents of the archive in memory, thus 35 | removing the need for an additional extraction step. 36 | 37 | The efficiency of PGet's tar extraction lies in its approach to handling data. Instead of writing the downloaded tar 38 | file to disk and then reading it back into memory for extraction, PGet conducts the extraction directly from the 39 | in-memory download buffer. This method avoids unnecessary memory copies and disk I/O, leading to an increase in 40 | performance, especially when dealing with large tar files. This makes PGet not just a parallel downloader, but also an 41 | efficient file extractor, providing a streamlined solution for fetching and unpacking files. 42 | ` 43 | 44 | var concurrency int 45 | var pidFile *cli.PIDFile 46 | var chunkSize string 47 | 48 | const chunkSizeDefault = "125M" 49 | 50 | func GetCommand() *cobra.Command { 51 | cmd := &cobra.Command{ 52 | Use: "pget [flags] <url> <dest>", 53 | Short: "pget", 54 | Long: rootLongDesc, 55 | PersistentPreRunE: rootPersistentPreRunEFunc, 56 | PersistentPostRunE: rootPersistentPostRunEFunc, 57 | RunE: runRootCMD, 58 | Args: validateArgs, 59 | Example: ` pget https://example.com/file.tar ./target-dir`, 60 | } 61 | cmd.Flags().BoolP(config.OptExtract, "x", false, "Extract archive after download") 62 | cmd.SetUsageTemplate(cli.UsageTemplate) 63 | config.ViperInit() 64 | if err := persistentFlags(cmd); err != nil { 65 | fmt.Println(err) 66 | os.Exit(1) 67 | } 68 | err := viper.BindPFlags(cmd.PersistentFlags()) 69 | if err != nil { 70 | fmt.Println(err) 71 | os.Exit(1) 72 | } 73 | 74 | err = viper.BindPFlags(cmd.Flags()) 75 | if err != nil { 76 | fmt.Println(err) 77 | os.Exit(1) 78 | } 79 | return cmd 80 | } 81 | 82 | // defaultPidFilePath returns the default path for the PID file. Notably modern OS X variants 83 | // have permissions difficulties in /var/run etc. 
84 | func defaultPidFilePath() string { 85 | // Prefer XDG_RUNTIME_DIR when it is set; on OS X fall back to the user's home directory; 86 | // otherwise, use /run 87 | path := "/run/pget.pid" 88 | if xdgPath, ok := os.LookupEnv("XDG_RUNTIME_DIR"); ok { 89 | path = xdgPath + "/pget.pid" 90 | } else if runtime.GOOS == "darwin" { 91 | path = os.Getenv("HOME") + "/.pget.pid" 92 | } 93 | return path 94 | } 95 | 96 | func pidFlock(pidFilePath string) error { 97 | pid, err := cli.NewPIDFile(pidFilePath) 98 | if err != nil { 99 | return err 100 | } 101 | err = pid.Acquire() 102 | if err != nil { 103 | return err 104 | } 105 | pidFile = pid 106 | return nil 107 | } 108 | 109 | func rootPersistentPreRunEFunc(cmd *cobra.Command, args []string) error { 110 | logger := logging.GetLogger() 111 | if err := config.PersistentStartupProcessFlags(); err != nil { 112 | return err 113 | } 114 | if cmd.CalledAs() != version.VersionCMDName { 115 | if err := pidFlock(viper.GetString(config.OptPIDFile)); err != nil { 116 | return err 117 | } 118 | } 119 | 120 | // Handle chunk size flags (deprecation and overwriting where needed) 121 | // 122 | // Expected Behavior for chunk size flags: 123 | // * If either cli option is set, use that value 124 | // * If both are set, emit an error 125 | // * If neither are set, use ENV values 126 | // ** If PGET_CHUNK_SIZE is set, use that value 127 | // ** If PGET_CHUNK_SIZE is not set, use PGET_MINIMUM_CHUNK_SIZE if set 128 | // NOTE: PGET_MINIMUM_CHUNK_SIZE value is just set over the key for PGET_CHUNK_SIZE 129 | // Warning message will be emitted 130 | // ** If both PGET_CHUNK_SIZE and PGET_MINIMUM_CHUNK_SIZE are set, use PGET_CHUNK_SIZE 131 | // Warning message will be emitted if they differ 132 | // * If neither are set, use the default value 133 | 134 | changedMin := cmd.PersistentFlags().Changed(config.OptMinimumChunkSize) 135 | changedChunk := cmd.PersistentFlags().Changed(config.OptChunkSize) 136 | if changedMin && changedChunk { 137 | return fmt.Errorf("--minimum-chunk-size and --chunk-size cannot be used at the same time, use --chunk-size instead") 138 | } 139 | minChunkSizeEnv := viper.GetString(config.OptMinimumChunkSize) 140 | chunkSizeEnv := viper.GetString(config.OptChunkSize) 141 | if minChunkSizeEnv != chunkSizeDefault && minChunkSizeEnv != chunkSizeEnv { 142 | if chunkSizeEnv == chunkSizeDefault { 143 | logger.Warn().Msg("Using PGET_MINIMUM_CHUNK_SIZE is deprecated, use PGET_CHUNK_SIZE instead") 144 | viper.Set(config.OptChunkSize, minChunkSizeEnv) 145 | } else { 146 | logger.Warn().Msg("Both PGET_MINIMUM_CHUNK_SIZE and PGET_CHUNK_SIZE are set, using PGET_CHUNK_SIZE") 147 | } 148 | } 149 | 150 | if viper.GetBool(config.OptExtract) { 151 | // TODO: decide what to do when --output is set *and* --extract is set 152 | log.Debug().Msg("Tar Extract Enabled") 153 | viper.Set(config.OptOutputConsumer, config.ConsumerTarExtractor) 154 | } 155 | 156 | return nil 157 | } 158 | 159 | func rootPersistentPostRunEFunc(cmd *cobra.Command, args []string) error { 160 | if pidFile != nil { 161 | return pidFile.Release() 162 | } 163 | return nil 164 | } 165 | 166 | func persistentFlags(cmd *cobra.Command) error { 167 | // Persistent Flags (applies to all commands/subcommands) 168 | cmd.PersistentFlags().IntVarP(&concurrency, config.OptConcurrency, "c", runtime.GOMAXPROCS(0)*4, "Maximum number of concurrent downloads/maximum number of chunks for a given file") 169 | cmd.PersistentFlags().IntVar(&concurrency, config.OptMaxChunks, runtime.GOMAXPROCS(0)*4, "Maximum number of chunks for a given file") 170 | 
cmd.PersistentFlags().Duration(config.OptConnTimeout, 5*time.Second, "Timeout for establishing a connection, format is <duration>, e.g. 10s") 171 | cmd.PersistentFlags().StringVarP(&chunkSize, config.OptChunkSize, "m", chunkSizeDefault, "Chunk size (in bytes) to use when downloading a file (e.g. 10M)") 172 | cmd.PersistentFlags().StringVar(&chunkSize, config.OptMinimumChunkSize, chunkSizeDefault, "Minimum chunk size (in bytes) to use when downloading a file (e.g. 10M)") 173 | cmd.PersistentFlags().BoolP(config.OptForce, "f", false, "Force download, overwriting existing file") 174 | cmd.PersistentFlags().StringSlice(config.OptResolve, []string{}, "Resolve hostnames to specific IPs") 175 | cmd.PersistentFlags().IntP(config.OptRetries, "r", 5, "Number of retries when attempting to retrieve a file") 176 | cmd.PersistentFlags().BoolP(config.OptVerbose, "v", false, "Verbose mode (equivalent to --log-level debug)") 177 | cmd.PersistentFlags().String(config.OptLoggingLevel, "info", "Log level (debug, info, warn, error)") 178 | cmd.PersistentFlags().Bool(config.OptForceHTTP2, false, "Force HTTP/2") 179 | cmd.PersistentFlags().Int(config.OptMaxConnPerHost, 40, "Maximum number of (global) concurrent connections per host") 180 | cmd.PersistentFlags().StringP(config.OptOutputConsumer, "o", "file", "Output Consumer (file, tar, null)") 181 | cmd.PersistentFlags().String(config.OptPIDFile, defaultPidFilePath(), "PID file path") 182 | 183 | if err := hideAndDeprecateFlags(cmd); err != nil { 184 | return err 185 | } 186 | 187 | return nil 188 | } 189 | 190 | func hideAndDeprecateFlags(cmd *cobra.Command) error { 191 | // Hide flags from help, these are intended to be used for testing/internal benchmarking/debugging only 192 | if err := config.HideFlags(cmd, config.OptForceHTTP2, config.OptMaxConnPerHost, config.OptOutputConsumer); err != nil { 193 | return err 194 | } 195 | 196 | // DeprecatedFlag flags 197 | err := config.DeprecateFlags(cmd, 198 | config.DeprecatedFlag{Flag: config.OptMaxChunks, Msg: fmt.Sprintf("use --%s instead", config.OptConcurrency)}, 199 | config.DeprecatedFlag{Flag: config.OptMinimumChunkSize, Msg: fmt.Sprintf("use --%s instead", config.OptChunkSize)}, 200 | ) 201 | if err != nil { 202 | return err 203 | } 204 | return nil 205 | 206 | } 207 | 208 | func runRootCMD(cmd *cobra.Command, args []string) error { 209 | // After we run through the PreRun functions we want to silence usage from being printed 210 | // on all errors 211 | cmd.SilenceUsage = true 212 | 213 | var url, dest string 214 | url = args[0] 215 | if len(args) > 1 { 216 | dest = args[1] 217 | } 218 | 219 | log.Info().Str("url", url). 220 | Str("dest", dest). 221 | Str("chunk_size", viper.GetString(config.OptChunkSize)). 222 | Msg("Initiating") 223 | 224 | // OMG BODGE FIX THIS 225 | consumer := viper.GetString(config.OptOutputConsumer) 226 | if consumer != config.ConsumerNull { 227 | if err := cli.EnsureDestinationNotExist(dest); err != nil { 228 | return err 229 | } 230 | } 231 | if err := rootExecute(cmd.Context(), url, dest); err != nil { 232 | return err 233 | } 234 | 235 | return nil 236 | } 237 | 238 | // rootExecute is the main function of the program and encapsulates the general logic; 239 | // it returns any/all errors to the caller. 
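// Downloader selection, as wired below: if a cache SRV name is configured, the
// consistent-hashing downloader is used across the looked-up cache hosts; if a
// single cache service hostname is configured, it becomes the sole cache host;
// otherwise the plain buffer-mode downloader fetches directly from the origin.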
240 | func rootExecute(ctx context.Context, urlString, dest string) error { 241 | chunkSize, err := humanize.ParseBytes(viper.GetString(config.OptChunkSize)) 242 | if err != nil { 243 | return fmt.Errorf("error parsing chunk size: %w", err) 244 | } 245 | 246 | resolveOverrides, err := config.ResolveOverridesToMap(viper.GetStringSlice(config.OptResolve)) 247 | if err != nil { 248 | return fmt.Errorf("error parsing resolve overrides: %w", err) 249 | } 250 | clientOpts := client.Options{ 251 | MaxRetries: viper.GetInt(config.OptRetries), 252 | TransportOpts: client.TransportOptions{ 253 | ForceHTTP2: viper.GetBool(config.OptForceHTTP2), 254 | ConnectTimeout: viper.GetDuration(config.OptConnTimeout), 255 | MaxConnPerHost: viper.GetInt(config.OptMaxConnPerHost), 256 | ResolveOverrides: resolveOverrides, 257 | }, 258 | } 259 | 260 | downloadOpts := download.Options{ 261 | MaxConcurrency: viper.GetInt(config.OptConcurrency), 262 | ChunkSize: int64(chunkSize), 263 | Client: clientOpts, 264 | } 265 | 266 | consumer, err := config.GetConsumer() 267 | if err != nil { 268 | return err 269 | } 270 | 271 | pgetOpts := pget.Options{ 272 | MetricsEndpoint: viper.GetString(config.OptMetricsEndpoint), 273 | } 274 | 275 | getter := pget.Getter{ 276 | Consumer: consumer, 277 | Options: pgetOpts, 278 | } 279 | 280 | // TODO DRY this 281 | if srvName := config.GetCacheSRV(); srvName != "" { 282 | downloadOpts.SliceSize = 500 * humanize.MiByte 283 | downloadOpts.CacheableURIPrefixes = config.CacheableURIPrefixes() 284 | downloadOpts.CacheUsePathProxy = viper.GetBool(config.OptCacheUsePathProxy) 285 | downloadOpts.ForceCachePrefixRewrite = viper.GetBool(config.OptForceCachePrefixRewrite) 286 | if downloadOpts.CacheHosts, err = cli.LookupCacheHosts(srvName); err != nil { 287 | return err 288 | } 289 | getter.Downloader, err = download.GetConsistentHashingMode(downloadOpts) 290 | if err != nil { 291 | return err 292 | } 293 | } else if cacheHostname := config.CacheServiceHostname(); cacheHostname != "" { 294 | downloadOpts.CacheHosts = []string{cacheHostname} 295 | downloadOpts.CacheableURIPrefixes = config.CacheableURIPrefixes() 296 | downloadOpts.CacheUsePathProxy = viper.GetBool(config.OptCacheUsePathProxy) 297 | downloadOpts.ForceCachePrefixRewrite = viper.GetBool(config.OptForceCachePrefixRewrite) 298 | } 299 | 300 | if getter.Downloader == nil { 301 | getter.Downloader = download.GetBufferMode(downloadOpts) 302 | } 303 | 304 | _, _, err = getter.DownloadFile(ctx, urlString, dest) 305 | return err 306 | } 307 | 308 | func validateArgs(cmd *cobra.Command, args []string) error { 309 | if viper.GetString(config.OptOutputConsumer) == config.ConsumerNull { 310 | return cobra.RangeArgs(1, 2)(cmd, args) 311 | } 312 | return cobra.ExactArgs(2)(cmd, args) 313 | } 314 | -------------------------------------------------------------------------------- /cmd/version/version.go: -------------------------------------------------------------------------------- 1 | package version 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/spf13/cobra" 7 | 8 | "github.com/replicate/pget/pkg/version" 9 | ) 10 | 11 | const VersionCMDName = "version" 12 | 13 | var VersionCMD = &cobra.Command{ 14 | Use: VersionCMDName, 15 | Short: "print version and build information", 16 | Long: "Print the version information", 17 | Run: func(cmd *cobra.Command, args []string) { 18 | fmt.Printf("pget Version %s - Build Time %s\n", version.GetVersion(), version.BuildTime) 19 | }, 20 | } 21 | 
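As context for the `version` command above: the `-X` ldflags in `.goreleaser.yaml` stamp build-time values into package-level variables in `pkg/version` (the `info.go` file itself is not part of this listing). A minimal sketch of that pattern, with hypothetical fallback logic:

```go
// Sketch only: the real implementation lives in pkg/version/info.go.
package version

// These variables are overwritten at link time, e.g.:
//   go build -ldflags "-X github.com/replicate/pget/pkg/version.Version=v1.2.3"
var (
	Version    string
	CommitHash string
	BuildTime  string
	Prerelease string
)

// GetVersion returns the linked-in version, falling back to a development
// label when no version was stamped (hypothetical fallback for illustration).
func GetVersion() string {
	if Version == "" {
		return "development-" + CommitHash
	}
	return Version
}
```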
-------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/replicate/pget 2 | 3 | go 1.24 4 | 5 | toolchain go1.24.1 6 | 7 | require ( 8 | github.com/dgryski/go-jump v0.0.0-20211018200510-ba001c3ffce0 9 | github.com/dustin/go-humanize v1.0.1 10 | github.com/golangci/golangci-lint v1.63.4 11 | github.com/hashicorp/go-retryablehttp v0.7.7 12 | github.com/jarcoal/httpmock v1.3.1 13 | github.com/mitchellh/hashstructure/v2 v2.0.2 14 | github.com/pierrec/lz4 v2.6.1+incompatible 15 | github.com/rs/zerolog v1.33.0 16 | github.com/spf13/cobra v1.8.1 17 | github.com/spf13/viper v1.19.0 18 | github.com/stretchr/testify v1.10.0 19 | github.com/ulikunitz/xz v0.5.12 20 | golang.org/x/sync v0.13.0 21 | golang.org/x/tools v0.32.0 22 | gotest.tools/gotestsum v1.12.0 23 | ) 24 | 25 | require ( 26 | 4d63.com/gocheckcompilerdirectives v1.2.1 // indirect 27 | 4d63.com/gochecknoglobals v0.2.1 // indirect 28 | github.com/4meepo/tagalign v1.4.1 // indirect 29 | github.com/Abirdcfly/dupword v0.1.3 // indirect 30 | github.com/Antonboom/errname v1.0.0 // indirect 31 | github.com/Antonboom/nilnil v1.0.1 // indirect 32 | github.com/Antonboom/testifylint v1.5.2 // indirect 33 | github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c // indirect 34 | github.com/Crocmagnon/fatcontext v0.5.3 // indirect 35 | github.com/Djarvur/go-err113 v0.0.0-20210108212216-aea10b59be24 // indirect 36 | github.com/GaijinEntertainment/go-exhaustruct/v3 v3.3.0 // indirect 37 | github.com/Masterminds/semver/v3 v3.3.0 // indirect 38 | github.com/OpenPeeDeeP/depguard/v2 v2.2.0 // indirect 39 | github.com/alecthomas/go-check-sumtype v0.3.1 // indirect 40 | github.com/alexkohler/nakedret/v2 v2.0.5 // indirect 41 | github.com/alexkohler/prealloc v1.0.0 // indirect 42 | github.com/alingse/asasalint v0.0.11 // indirect 43 | github.com/alingse/nilnesserr v0.1.1 // indirect 44 | github.com/ashanbrown/forbidigo v1.6.0 // indirect 45 | github.com/ashanbrown/makezero v1.2.0 // indirect 46 | github.com/beorn7/perks v1.0.1 // indirect 47 | github.com/bitfield/gotestdox v0.2.2 // indirect 48 | github.com/bkielbasa/cyclop v1.2.3 // indirect 49 | github.com/blizzy78/varnamelen v0.8.0 // indirect 50 | github.com/bombsimon/wsl/v4 v4.5.0 // indirect 51 | github.com/breml/bidichk v0.3.2 // indirect 52 | github.com/breml/errchkjson v0.4.0 // indirect 53 | github.com/butuzov/ireturn v0.3.1 // indirect 54 | github.com/butuzov/mirror v1.3.0 // indirect 55 | github.com/catenacyber/perfsprint v0.7.1 // indirect 56 | github.com/ccojocar/zxcvbn-go v1.0.2 // indirect 57 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 58 | github.com/charithe/durationcheck v0.0.10 // indirect 59 | github.com/chavacava/garif v0.1.0 // indirect 60 | github.com/ckaznocha/intrange v0.3.0 // indirect 61 | github.com/curioswitch/go-reassign v0.3.0 // indirect 62 | github.com/daixiang0/gci v0.13.5 // indirect 63 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 64 | github.com/denis-tingaikin/go-header v0.5.0 // indirect 65 | github.com/dnephin/pflag v1.0.7 // indirect 66 | github.com/ettle/strcase v0.2.0 // indirect 67 | github.com/fatih/color v1.18.0 // indirect 68 | github.com/fatih/structtag v1.2.0 // indirect 69 | github.com/firefart/nonamedreturns v1.0.5 // indirect 70 | github.com/fsnotify/fsnotify v1.7.0 // indirect 71 | github.com/fzipp/gocyclo v0.6.0 // indirect 72 | github.com/ghostiam/protogetter v0.3.8 // 
indirect 73 | github.com/go-critic/go-critic v0.11.5 // indirect 74 | github.com/go-toolsmith/astcast v1.1.0 // indirect 75 | github.com/go-toolsmith/astcopy v1.1.0 // indirect 76 | github.com/go-toolsmith/astequal v1.2.0 // indirect 77 | github.com/go-toolsmith/astfmt v1.1.0 // indirect 78 | github.com/go-toolsmith/astp v1.1.0 // indirect 79 | github.com/go-toolsmith/strparse v1.1.0 // indirect 80 | github.com/go-toolsmith/typep v1.1.0 // indirect 81 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect 82 | github.com/go-xmlfmt/xmlfmt v1.1.3 // indirect 83 | github.com/gobwas/glob v0.2.3 // indirect 84 | github.com/gofrs/flock v0.12.1 // indirect 85 | github.com/golang/protobuf v1.5.3 // indirect 86 | github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a // indirect 87 | github.com/golangci/go-printf-func-name v0.1.0 // indirect 88 | github.com/golangci/gofmt v0.0.0-20241223200906-057b0627d9b9 // indirect 89 | github.com/golangci/misspell v0.6.0 // indirect 90 | github.com/golangci/plugin-module-register v0.1.1 // indirect 91 | github.com/golangci/revgrep v0.5.3 // indirect 92 | github.com/golangci/unconvert v0.0.0-20240309020433-c5143eacb3ed // indirect 93 | github.com/google/go-cmp v0.6.0 // indirect 94 | github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect 95 | github.com/gordonklaus/ineffassign v0.1.0 // indirect 96 | github.com/gostaticanalysis/analysisutil v0.7.1 // indirect 97 | github.com/gostaticanalysis/comment v1.4.2 // indirect 98 | github.com/gostaticanalysis/forcetypeassert v0.1.0 // indirect 99 | github.com/gostaticanalysis/nilerr v0.1.1 // indirect 100 | github.com/hashicorp/go-cleanhttp v0.5.2 // indirect 101 | github.com/hashicorp/go-immutable-radix/v2 v2.1.0 // indirect 102 | github.com/hashicorp/go-version v1.7.0 // indirect 103 | github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect 104 | github.com/hashicorp/hcl v1.0.0 // indirect 105 | github.com/hexops/gotextdiff v1.0.3 // indirect 106 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 107 | github.com/jgautheron/goconst v1.7.1 // indirect 108 | github.com/jingyugao/rowserrcheck v1.1.1 // indirect 109 | github.com/jjti/go-spancheck v0.6.4 // indirect 110 | github.com/julz/importas v0.2.0 // indirect 111 | github.com/karamaru-alpha/copyloopvar v1.1.0 // indirect 112 | github.com/kisielk/errcheck v1.8.0 // indirect 113 | github.com/kkHAIKE/contextcheck v1.1.5 // indirect 114 | github.com/kulti/thelper v0.6.3 // indirect 115 | github.com/kunwardeep/paralleltest v1.0.10 // indirect 116 | github.com/kyoh86/exportloopref v0.1.11 // indirect 117 | github.com/lasiar/canonicalheader v1.1.2 // indirect 118 | github.com/ldez/exptostd v0.3.1 // indirect 119 | github.com/ldez/gomoddirectives v0.6.0 // indirect 120 | github.com/ldez/grignotin v0.7.0 // indirect 121 | github.com/ldez/tagliatelle v0.7.1 // indirect 122 | github.com/ldez/usetesting v0.4.2 // indirect 123 | github.com/leonklingele/grouper v1.1.2 // indirect 124 | github.com/macabu/inamedparam v0.1.3 // indirect 125 | github.com/magiconair/properties v1.8.7 // indirect 126 | github.com/maratori/testableexamples v1.0.0 // indirect 127 | github.com/maratori/testpackage v1.1.1 // indirect 128 | github.com/matoous/godox v0.0.0-20230222163458-006bad1f9d26 // indirect 129 | github.com/mattn/go-colorable v0.1.13 // indirect 130 | github.com/mattn/go-isatty v0.0.20 // indirect 131 | github.com/mattn/go-runewidth v0.0.16 // indirect 132 | github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect 133 | github.com/mgechev/revive 
v1.5.1 // indirect 134 | github.com/mitchellh/go-homedir v1.1.0 // indirect 135 | github.com/mitchellh/mapstructure v1.5.0 // indirect 136 | github.com/moricho/tparallel v0.3.2 // indirect 137 | github.com/nakabonne/nestif v0.3.1 // indirect 138 | github.com/nishanths/exhaustive v0.12.0 // indirect 139 | github.com/nishanths/predeclared v0.2.2 // indirect 140 | github.com/nunnatsa/ginkgolinter v0.18.4 // indirect 141 | github.com/olekukonko/tablewriter v0.0.5 // indirect 142 | github.com/pelletier/go-toml/v2 v2.2.3 // indirect 143 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 144 | github.com/polyfloyd/go-errorlint v1.7.0 // indirect 145 | github.com/prometheus/client_golang v1.15.1 // indirect 146 | github.com/prometheus/client_model v0.4.0 // indirect 147 | github.com/prometheus/common v0.42.0 // indirect 148 | github.com/prometheus/procfs v0.9.0 // indirect 149 | github.com/quasilyte/go-ruleguard v0.4.3-0.20240823090925-0fe6f58b47b1 // indirect 150 | github.com/quasilyte/go-ruleguard/dsl v0.3.22 // indirect 151 | github.com/quasilyte/gogrep v0.5.0 // indirect 152 | github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 // indirect 153 | github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 // indirect 154 | github.com/raeperd/recvcheck v0.2.0 // indirect 155 | github.com/rivo/uniseg v0.4.7 // indirect 156 | github.com/rogpeppe/go-internal v1.13.1 // indirect 157 | github.com/ryancurrah/gomodguard v1.3.5 // indirect 158 | github.com/ryanrolds/sqlclosecheck v0.5.1 // indirect 159 | github.com/sagikazarmark/locafero v0.4.0 // indirect 160 | github.com/sagikazarmark/slog-shim v0.1.0 // indirect 161 | github.com/sanposhiho/wastedassign/v2 v2.1.0 // indirect 162 | github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 // indirect 163 | github.com/sashamelentyev/interfacebloat v1.1.0 // indirect 164 | github.com/sashamelentyev/usestdlibvars v1.28.0 // indirect 165 | github.com/securego/gosec/v2 v2.21.4 // indirect 166 | github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c // indirect 167 | github.com/sirupsen/logrus v1.9.3 // indirect 168 | github.com/sivchari/containedctx v1.0.3 // indirect 169 | github.com/sivchari/tenv v1.12.1 // indirect 170 | github.com/sonatard/noctx v0.1.0 // indirect 171 | github.com/sourcegraph/conc v0.3.0 // indirect 172 | github.com/sourcegraph/go-diff v0.7.0 // indirect 173 | github.com/spf13/afero v1.11.0 // indirect 174 | github.com/spf13/cast v1.6.0 // indirect 175 | github.com/spf13/pflag v1.0.5 // indirect 176 | github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect 177 | github.com/stbenjam/no-sprintf-host-port v0.2.0 // indirect 178 | github.com/stretchr/objx v0.5.2 // indirect 179 | github.com/subosito/gotenv v1.6.0 // indirect 180 | github.com/tdakkota/asciicheck v0.3.0 // indirect 181 | github.com/tetafro/godot v1.4.20 // indirect 182 | github.com/timakin/bodyclose v0.0.0-20241017074812-ed6a65f985e3 // indirect 183 | github.com/timonwong/loggercheck v0.10.1 // indirect 184 | github.com/tomarrell/wrapcheck/v2 v2.10.0 // indirect 185 | github.com/tommy-muehle/go-mnd/v2 v2.5.1 // indirect 186 | github.com/ultraware/funlen v0.2.0 // indirect 187 | github.com/ultraware/whitespace v0.2.0 // indirect 188 | github.com/uudashr/gocognit v1.2.0 // indirect 189 | github.com/uudashr/iface v1.3.0 // indirect 190 | github.com/xen0n/gosmopolitan v1.2.2 // indirect 191 | github.com/yagipy/maintidx v1.0.0 // indirect 192 | github.com/yeya24/promlinter v0.3.0 // indirect 193 | github.com/ykadowak/zerologlint v0.1.5 // 
indirect 194 | gitlab.com/bosi/decorder v0.4.2 // indirect 195 | go-simpler.org/musttag v0.13.0 // indirect 196 | go-simpler.org/sloglint v0.7.2 // indirect 197 | go.uber.org/atomic v1.11.0 // indirect 198 | go.uber.org/automaxprocs v1.6.0 // indirect 199 | go.uber.org/goleak v1.2.1 // indirect 200 | go.uber.org/multierr v1.11.0 // indirect 201 | go.uber.org/zap v1.24.0 // indirect 202 | golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect 203 | golang.org/x/exp/typeparams v0.0.0-20241108190413-2d47ceb2692f // indirect 204 | golang.org/x/mod v0.24.0 // indirect 205 | golang.org/x/sys v0.32.0 // indirect 206 | golang.org/x/term v0.18.0 // indirect 207 | golang.org/x/text v0.20.0 // indirect 208 | google.golang.org/protobuf v1.34.2 // indirect 209 | gopkg.in/ini.v1 v1.67.0 // indirect 210 | gopkg.in/yaml.v2 v2.4.0 // indirect 211 | gopkg.in/yaml.v3 v3.0.1 // indirect 212 | honnef.co/go/tools v0.5.1 // indirect 213 | mvdan.cc/gofumpt v0.7.0 // indirect 214 | mvdan.cc/unparam v0.0.0-20240528143540-8a5130ca722f // indirect 215 | ) 216 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/replicate/pget/cmd" 7 | "github.com/replicate/pget/pkg/logging" 8 | ) 9 | 10 | func main() { 11 | logging.SetupLogger() 12 | rootCMD := cmd.GetRootCommand() 13 | 14 | if err := rootCMD.Execute(); err != nil { 15 | os.Exit(1) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /pkg/cli/common.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io/fs" 7 | "net" 8 | "os" 9 | "regexp" 10 | "strconv" 11 | "strings" 12 | 13 | "github.com/spf13/viper" 14 | 15 | "github.com/replicate/pget/pkg/config" 16 | "github.com/replicate/pget/pkg/logging" 17 | ) 18 | 19 | const UsageTemplate = ` 20 | Usage:{{if .Runnable}} 21 | {{if .HasAvailableFlags}}{{appendIfNotPresent .UseLine "[flags]"}}{{else}}{{.UseLine}}{{end}}{{end}}{{if .HasAvailableSubCommands}} 22 | {{.CommandPath}} [command]{{end}}{{if gt .Aliases 0}} 23 | 24 | Aliases: 25 | {{.NameAndAliases}}{{end}}{{if .HasExample}} 26 | 27 | Examples: 28 | {{.Example}}{{end}}{{if .HasAvailableSubCommands}} 29 | 30 | Available Commands:{{range .Commands}}{{if .IsAvailableCommand}} 31 | {{rpad .Name .NamePadding }} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableLocalFlags}} 32 | 33 | Flags: 34 | {{.LocalFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasAvailableInheritedFlags}} 35 | 36 | Global Flags: 37 | {{.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasHelpSubCommands}} 38 | 39 | Additional help topics:{{range .Commands}}{{if .IsAdditionalHelpTopicCommand}} 40 | {{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableSubCommands}} 41 | 42 | Use "{{.CommandPath}} [command] --help" for more information about a command.{{end}} 43 | ` 44 | 45 | func EnsureDestinationNotExist(dest string) error { 46 | _, err := os.Stat(dest) 47 | if !viper.GetBool(config.OptForce) && !errors.Is(err, fs.ErrNotExist) { 48 | return fmt.Errorf("destination %s already exists", dest) 49 | } 50 | return nil 51 | } 52 | 53 | func LookupCacheHosts(srvName string) ([]string, error) { 54 | _, srvs, err := net.LookupSRV("http", "tcp", srvName) 55 | if err != nil { 56 | return nil, err 57 | } 58 | return 
orderCacheHosts(srvs) 59 | } 60 | 61 | var hostnameIndexRegexp = regexp.MustCompile(`^[a-z0-9-]*-([0-9]+)[.]`) 62 | 63 | func orderCacheHosts(srvs []*net.SRV) ([]string, error) { 64 | // loop through to find highest index 65 | logger := logging.GetLogger() 66 | highestIndex := 0 67 | for _, srv := range srvs { 68 | cacheIndex, err := cacheIndexFor(srv.Target) 69 | logger.Debug().Int("cache_index", cacheIndex).Str("target", srv.Target).Msg("orderCacheHosts") 70 | if err != nil { 71 | return nil, err 72 | } 73 | if cacheIndex > highestIndex { 74 | highestIndex = cacheIndex 75 | } 76 | } 77 | logger.Debug().Int("highest_index", highestIndex).Msg("orderCacheHosts") 78 | output := make([]string, highestIndex+1) 79 | for _, srv := range srvs { 80 | cacheIndex, err := cacheIndexFor(srv.Target) 81 | if err != nil { 82 | return nil, err 83 | } 84 | hostname := strings.TrimSuffix(srv.Target, ".") 85 | if srv.Port != 80 { 86 | hostname = fmt.Sprintf("%s:%d", hostname, srv.Port) 87 | } 88 | logger.Debug().Str("hostname", hostname).Int("cache_index", cacheIndex).Msg("orderCacheHosts") 89 | output[cacheIndex] = hostname 90 | } 91 | logger.Debug().Strs("output", output).Msg("orderCacheHosts") 92 | return output, nil 93 | } 94 | 95 | func cacheIndexFor(hostname string) (int, error) { 96 | matches := hostnameIndexRegexp.FindStringSubmatch(hostname) 97 | if matches == nil { 98 | return -1, fmt.Errorf("couldn't parse hostname %s", hostname) 99 | } 100 | return strconv.Atoi(matches[1]) 101 | } 102 | -------------------------------------------------------------------------------- /pkg/cli/common_test.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "net" 5 | "os" 6 | "testing" 7 | 8 | "github.com/spf13/viper" 9 | "github.com/stretchr/testify/assert" 10 | 11 | "github.com/replicate/pget/pkg/config" 12 | ) 13 | 14 | func TestEnsureDestinationNotExist(t *testing.T) { 15 | defer viper.Reset() 16 | f, err := os.CreateTemp("", "EnsureDestinationNotExist-test-file") 17 | if err != nil { 18 | t.Fatal(err) 19 | } 20 | defer os.Remove(f.Name()) 21 | 22 | testCases := []struct { 23 | name string 24 | fileName string 25 | force bool 26 | err bool 27 | }{ 28 | {"force true, file exists", f.Name(), true, false}, 29 | {"force false, file exists", f.Name(), false, true}, 30 | {"force true, file does not exist", "unknownFile", true, false}, 31 | {"force false, file does not exist", "unknownFile", false, false}, 32 | } 33 | 34 | for _, tc := range testCases { 35 | t.Run(tc.name, func(t *testing.T) { 36 | viper.Set(config.OptForce, tc.force) 37 | err := EnsureDestinationNotExist(tc.fileName) 38 | assert.Equal(t, tc.err, err != nil) 39 | }) 40 | } 41 | } 42 | 43 | type tc struct { 44 | srvs []*net.SRV 45 | expectedOutput []string 46 | } 47 | 48 | var testCases = []tc{ 49 | { // basic functionality 50 | srvs: []*net.SRV{{Target: "cache-0.cache-service.cache-namespace.svc.cluster.local.", Port: 80}}, 51 | expectedOutput: []string{"cache-0.cache-service.cache-namespace.svc.cluster.local"}, 52 | }, 53 | { // append port number if nonstandard 54 | srvs: []*net.SRV{{Target: "cache-0.cache-service.cache-namespace.svc.cluster.local.", Port: 8080}}, 55 | expectedOutput: []string{"cache-0.cache-service.cache-namespace.svc.cluster.local:8080"}, 56 | }, 57 | { // multiple cache hosts 58 | srvs: []*net.SRV{ 59 | {Target: "cache-0.cache-service.cache-namespace.svc.cluster.local.", Port: 80}, 60 | {Target: 
"cache-1.cache-service.cache-namespace.svc.cluster.local.", Port: 80}, 61 | }, 62 | expectedOutput: []string{ 63 | "cache-0.cache-service.cache-namespace.svc.cluster.local", 64 | "cache-1.cache-service.cache-namespace.svc.cluster.local", 65 | }, 66 | }, 67 | { // canonical ordering 68 | srvs: []*net.SRV{ 69 | {Target: "cache-1.cache-service.cache-namespace.svc.cluster.local.", Port: 80}, 70 | {Target: "cache-0.cache-service.cache-namespace.svc.cluster.local.", Port: 80}, 71 | }, 72 | expectedOutput: []string{ 73 | "cache-0.cache-service.cache-namespace.svc.cluster.local", 74 | "cache-1.cache-service.cache-namespace.svc.cluster.local", 75 | }, 76 | }, 77 | { // ensure missing hosts are represented 78 | srvs: []*net.SRV{ 79 | {Target: "cache-0.cache-service.cache-namespace.svc.cluster.local.", Port: 80}, 80 | {Target: "cache-2.cache-service.cache-namespace.svc.cluster.local.", Port: 80}, 81 | }, 82 | expectedOutput: []string{ 83 | "cache-0.cache-service.cache-namespace.svc.cluster.local", 84 | "", 85 | "cache-2.cache-service.cache-namespace.svc.cluster.local", 86 | }, 87 | }, 88 | } 89 | 90 | func TestOrderCacheHosts(t *testing.T) { 91 | for _, testCase := range testCases { 92 | cacheHosts, err := orderCacheHosts(testCase.srvs) 93 | assert.NoError(t, err) 94 | assert.Equal(t, testCase.expectedOutput, cacheHosts) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /pkg/cli/pid.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | package cli 4 | 5 | import ( 6 | "fmt" 7 | "os" 8 | "syscall" 9 | 10 | "github.com/replicate/pget/pkg/logging" 11 | ) 12 | 13 | type PIDFile struct { 14 | file *os.File 15 | fd int 16 | } 17 | 18 | func NewPIDFile(path string) (*PIDFile, error) { 19 | file, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0644) 20 | if err != nil { 21 | return nil, err 22 | } 23 | return &PIDFile{file: file, fd: int(file.Fd())}, nil 24 | } 25 | 26 | func (p *PIDFile) Acquire() error { 27 | logger := logging.GetLogger() 28 | funcs := []func() error{ 29 | func() error { 30 | logger.Debug().Str("blocking_lock_acquire", "false").Msg("Waiting on Lock") 31 | err := syscall.Flock(p.fd, syscall.LOCK_EX|syscall.LOCK_NB) 32 | if err != nil { 33 | logger.Warn(). 34 | Err(err). 35 | Str("warn_message", "Another pget process may be running, use 'pget multifile' to download multiple files in parallel"). 
36 | Msg("Waiting on Lock") 37 | logger.Debug().Str("blocking_lock_acquire", "true").Msg("Waiting on Lock") 38 | err = syscall.Flock(p.fd, syscall.LOCK_EX) 39 | } 40 | return err 41 | }, 42 | p.writePID, 43 | p.file.Sync, 44 | } 45 | return p.executeFuncs(funcs) 46 | } 47 | 48 | func (p *PIDFile) Release() error { 49 | funcs := []func() error{ 50 | func() error { return syscall.Flock(p.fd, syscall.LOCK_UN) }, 51 | p.file.Close, 52 | } 53 | return p.executeFuncs(funcs) 54 | } 55 | 56 | func (p *PIDFile) writePID() error { 57 | pid := os.Getpid() 58 | _, err := p.file.WriteString(fmt.Sprintf("%d", pid)) 59 | return err 60 | } 61 | 62 | func (p *PIDFile) executeFuncs(funcs []func() error) error { 63 | for _, fn := range funcs { 64 | if err := fn(); err != nil { 65 | return err 66 | } 67 | } 68 | return nil 69 | } 70 | -------------------------------------------------------------------------------- /pkg/client/client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "net/http" 9 | "strconv" 10 | "time" 11 | 12 | "github.com/spf13/viper" 13 | 14 | "github.com/hashicorp/go-retryablehttp" 15 | 16 | "github.com/replicate/pget/pkg/config" 17 | "github.com/replicate/pget/pkg/logging" 18 | "github.com/replicate/pget/pkg/version" 19 | ) 20 | 21 | const ( 22 | // These are bounds for the retryablehttp client, not absolute values; 23 | // see retryablehttp.LinearJitterBackoff for more details 24 | retryMinWait = 850 * time.Millisecond 25 | retryMaxWait = 1250 * time.Millisecond 26 | ) 27 | 28 | var ErrStrategyFallback = errors.New("fallback to next strategy") 29 | 30 | type HTTPClient interface { 31 | Do(req *http.Request) (*http.Response, error) 32 | } 33 | 34 | // PGetHTTPClient is a wrapper around http.Client that allows for limiting the number of concurrent connections per host 35 | // utilizing a client pool. If the OptMaxConnPerHost option is not set, the client pool will not be used. 36 | type PGetHTTPClient struct { 37 | *http.Client 38 | headers map[string]string 39 | } 40 | 41 | func (c *PGetHTTPClient) Do(req *http.Request) (*http.Response, error) { 42 | req.Header.Set("User-Agent", fmt.Sprintf("pget/%s", version.GetVersion())) 43 | for k, v := range c.headers { 44 | req.Header.Set(k, v) 45 | } 46 | return c.Client.Do(req) 47 | } 48 | 49 | type Options struct { 50 | MaxRetries int 51 | Transport http.RoundTripper 52 | TransportOpts TransportOptions 53 | } 54 | 55 | type TransportOptions struct { 56 | ForceHTTP2 bool 57 | ResolveOverrides map[string]string 58 | MaxConnPerHost int 59 | ConnectTimeout time.Duration 60 | } 61 | 62 | // NewHTTPClient factory function returns a new http.Client with the appropriate settings and can limit the number of clients 63 | // per host if the OptMaxConnPerHost option is set. 
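//
// A minimal usage sketch (illustrative only; the URL is a placeholder and the
// Options fields are the ones defined above):
//
//	client := NewHTTPClient(Options{
//		MaxRetries: 3,
//		TransportOpts: TransportOptions{ConnectTimeout: 5 * time.Second},
//	})
//	req, _ := http.NewRequest(http.MethodGet, "https://example.com/file", nil)
//	resp, err := client.Do(req)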
64 | func NewHTTPClient(opts Options) HTTPClient { 65 | 66 | transport := opts.Transport 67 | 68 | if transport == nil { 69 | topts := opts.TransportOpts 70 | dialer := &transportDialer{ 71 | DNSOverrideMap: topts.ResolveOverrides, 72 | Dialer: &net.Dialer{ 73 | Timeout: topts.ConnectTimeout, 74 | KeepAlive: 30 * time.Second, 75 | }, 76 | } 77 | 78 | disableKeepAlives := topts.ForceHTTP2 79 | transport = &http.Transport{ 80 | Proxy: http.ProxyFromEnvironment, 81 | DialContext: dialer.DialContext, 82 | ForceAttemptHTTP2: topts.ForceHTTP2, 83 | MaxIdleConns: 100, 84 | IdleConnTimeout: 90 * time.Second, 85 | TLSHandshakeTimeout: 5 * time.Second, 86 | ExpectContinueTimeout: 1 * time.Second, 87 | DisableKeepAlives: disableKeepAlives, 88 | MaxConnsPerHost: topts.MaxConnPerHost, 89 | MaxIdleConnsPerHost: topts.MaxConnPerHost, 90 | } 91 | } 92 | 93 | retryClient := &retryablehttp.Client{ 94 | HTTPClient: &http.Client{ 95 | Transport: transport, 96 | CheckRedirect: checkRedirectFunc, 97 | }, 98 | Logger: nil, 99 | RetryWaitMin: retryMinWait, 100 | RetryWaitMax: retryMaxWait, 101 | RetryMax: opts.MaxRetries, 102 | CheckRetry: RetryPolicy, 103 | Backoff: linearJitterRetryAfterBackoff, 104 | } 105 | 106 | client := retryClient.StandardClient() 107 | return &PGetHTTPClient{Client: client, headers: viper.GetStringMapString(config.OptHeaders)} 108 | } 109 | 110 | // RetryPolicy wraps retryablehttp.DefaultRetryPolicy and includes additional logic: 111 | // - checks for specific errors that indicate a fall-back to the next download strategy 112 | // - checks for http.StatusBadGateway and http.StatusServiceUnavailable which also indicate a fall-back 113 | func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { 114 | // do not retry on context.Canceled or context.DeadlineExceeded, this is a fast-fail even though 115 | // the retryablehttp.ErrorPropagatedRetryPolicy will also return false for these errors. We can avoid 116 | // extra processing logic in these cases every time 117 | if ctx.Err() != nil { 118 | return false, ctx.Err() 119 | } 120 | 121 | // While type assertions are not ideal, alternatives are limited to adding custom data in the request 122 | // or in the context. The context clearly isolates this data. 123 | consistentHashing, ok := ctx.Value(config.ConsistentHashingStrategyKey).(bool) 124 | if ok && consistentHashing { 125 | if fallbackError(err) { 126 | return false, ErrStrategyFallback 127 | } 128 | if err == nil && (resp.StatusCode == http.StatusBadGateway || resp.StatusCode == http.StatusServiceUnavailable) { 129 | return false, ErrStrategyFallback 130 | } 131 | } 132 | 133 | // Wrap the standard retry policy 134 | return retryablehttp.DefaultRetryPolicy(ctx, resp, err) 135 | } 136 | 137 | // fallbackError returns true if the error indicates we should fall back to the next strategy. 138 | // Fallback errors are non-retryable errors that indicate fundamental problems with the cache server 139 | // or with the network path to it. These errors include connection timeouts, connection refused, dns 140 | // lookup errors, etc. 
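//
// For example (illustrative): a refused connection surfaces as
// &net.OpError{Op: "dial"} and triggers a fallback, while an error on an
// already-established transfer such as &net.OpError{Op: "write"} does not.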
141 | func fallbackError(err error) bool { 142 | if err == nil { 143 | return false 144 | } 145 | var netErr net.Error 146 | ok := errors.As(err, &netErr) 147 | if ok && netErr.Timeout() { 148 | return true 149 | } 150 | 151 | var opErr *net.OpError 152 | if errors.As(err, &opErr) { 153 | if opErr.Op == "dial" || opErr.Op == "read" { 154 | return true 155 | } 156 | } 157 | 158 | var dnsErr *net.DNSError 159 | if errors.As(err, &dnsErr) { 160 | return dnsErr.IsTimeout || dnsErr.IsNotFound 161 | } 162 | if errors.Is(err, net.ErrClosed) { 163 | return true 164 | } 165 | 166 | return false 167 | } 168 | 169 | // linearJitterRetryAfterBackoff wraps retryablehttp.LinearJitterBackoff and additionally honors Retry-After response headers 170 | func linearJitterRetryAfterBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { 171 | var retryAfter time.Duration 172 | 173 | if shouldApplyRetryAfter(resp) { 174 | retryAfter = evaluateRetryAfter(resp) 175 | } 176 | 177 | if retryAfter > 0 { 178 | // If the Retry-After header is set, treat this as attempt 0 to get just the jitter 179 | jitter := max - min 180 | return retryablehttp.LinearJitterBackoff(retryAfter, retryAfter+jitter, 0, resp) 181 | } 182 | 183 | return retryAfter + retryablehttp.LinearJitterBackoff(min, max, attemptNum, resp) 184 | } 185 | 186 | func evaluateRetryAfter(resp *http.Response) time.Duration { 187 | retryAfter := resp.Header.Get("Retry-After") 188 | if retryAfter == "" { 189 | return 0 190 | } 191 | 192 | duration, err := strconv.ParseInt(retryAfter, 10, 64) 193 | if err != nil { 194 | return 0 195 | } 196 | 197 | return time.Second * time.Duration(duration) 198 | } 199 | 200 | func shouldApplyRetryAfter(resp *http.Response) bool { 201 | return resp != nil && resp.StatusCode == http.StatusTooManyRequests 202 | } 203 | 204 | // checkRedirectFunc is used as http.Client.CheckRedirect and logs each redirect 205 | func checkRedirectFunc(req *http.Request, via []*http.Request) error { 206 | logger := logging.GetLogger() 207 | 208 | logger.Trace(). 209 | Str("redirect_url", req.URL.String()). 210 | Str("url", via[0].URL.String()). 211 | Int("status", req.Response.StatusCode). 
212 | Msg("Redirect") 213 | return nil 214 | } 215 | 216 | type transportDialer struct { 217 | DNSOverrideMap map[string]string 218 | Dialer *net.Dialer 219 | } 220 | 221 | func (d *transportDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { 222 | logger := logging.GetLogger() 223 | if addrOverride := d.DNSOverrideMap[addr]; addrOverride != "" { 224 | logger.Debug().Str("addr", addr).Str("override", addrOverride).Msg("DNS Override") 225 | addr = addrOverride 226 | } 227 | return d.Dialer.DialContext(ctx, network, addr) 228 | } 229 | -------------------------------------------------------------------------------- /pkg/client/client_test.go: -------------------------------------------------------------------------------- 1 | package client_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | "net/url" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | 13 | "github.com/replicate/pget/pkg/client" 14 | "github.com/replicate/pget/pkg/config" 15 | ) 16 | 17 | func TestRetryPolicy(t *testing.T) { 18 | bgCtx := context.Background() 19 | chCtx := context.WithValue(bgCtx, config.ConsistentHashingStrategyKey, true) 20 | errContext, cancel := context.WithCancel(bgCtx) 21 | cancel() 22 | 23 | urlError := &url.Error{Err: fmt.Errorf("stopped after 15 redirects"), URL: "http://example.com"} 24 | 25 | tc := []struct { 26 | name string 27 | ctx context.Context 28 | resp *http.Response 29 | err error 30 | expectedResult bool 31 | expectedError error 32 | }{ 33 | { 34 | name: "context error", 35 | ctx: errContext, 36 | resp: &http.Response{}, 37 | err: context.Canceled, 38 | expectedResult: false, 39 | expectedError: context.Canceled, 40 | }, 41 | { 42 | name: "net.OpErr: dial", 43 | ctx: chCtx, 44 | resp: &http.Response{}, 45 | err: &net.OpError{Op: "dial"}, 46 | expectedResult: false, 47 | expectedError: client.ErrStrategyFallback, 48 | }, 49 | { 50 | name: "net.OpErr: read", 51 | ctx: chCtx, 52 | resp: &http.Response{}, 53 | err: &net.OpError{Op: "read"}, 54 | expectedResult: false, 55 | expectedError: client.ErrStrategyFallback, 56 | }, 57 | { 58 | name: "net.OpErr: write", 59 | ctx: chCtx, 60 | resp: &http.Response{}, 61 | err: &net.OpError{Op: "write"}, 62 | expectedResult: true, 63 | }, 64 | { 65 | name: "net.DNSErr: Timeout", 66 | ctx: chCtx, 67 | resp: &http.Response{}, 68 | err: &net.DNSError{IsTimeout: true}, 69 | expectedResult: false, 70 | expectedError: client.ErrStrategyFallback, 71 | }, 72 | { 73 | name: "net.DNSErr: IsTemporary", 74 | ctx: chCtx, 75 | resp: &http.Response{}, 76 | err: &net.DNSError{IsTemporary: true}, 77 | expectedResult: true, 78 | }, 79 | { 80 | name: "net.DNSErr: IsNotFound", 81 | ctx: chCtx, 82 | resp: &http.Response{}, 83 | err: &net.DNSError{IsNotFound: true}, 84 | expectedResult: false, 85 | expectedError: client.ErrStrategyFallback, 86 | }, 87 | { 88 | name: "net.ErrClosed", 89 | ctx: chCtx, 90 | resp: &http.Response{}, 91 | err: net.ErrClosed, 92 | expectedResult: false, 93 | expectedError: client.ErrStrategyFallback, 94 | }, 95 | { 96 | name: "Unrecoverable error", 97 | ctx: chCtx, 98 | resp: &http.Response{}, 99 | err: urlError, 100 | expectedResult: false, 101 | }, 102 | { 103 | name: "Status Bad Gateway", 104 | ctx: chCtx, 105 | resp: &http.Response{StatusCode: http.StatusBadGateway}, 106 | expectedResult: false, 107 | expectedError: client.ErrStrategyFallback, 108 | }, 109 | { 110 | name: "Status OK", 111 | ctx: chCtx, 112 | resp: &http.Response{StatusCode: http.StatusOK}, 113 | err: urlError, 
114 | expectedResult: false, 115 | }, 116 | { 117 | name: "Status Service Unavailable", 118 | ctx: chCtx, 119 | resp: &http.Response{StatusCode: http.StatusServiceUnavailable}, 120 | expectedResult: false, 121 | expectedError: client.ErrStrategyFallback, 122 | }, 123 | { 124 | name: "Recoverable Error", 125 | ctx: chCtx, 126 | resp: &http.Response{StatusCode: http.StatusOK}, 127 | err: fmt.Errorf("some error"), 128 | expectedResult: true, 129 | }, 130 | { 131 | name: "Too Many Requests", 132 | ctx: chCtx, 133 | resp: &http.Response{StatusCode: http.StatusTooManyRequests}, 134 | expectedResult: true, 135 | }, 136 | { 137 | name: "Bad Gateway - no consistent-hash-context", 138 | ctx: bgCtx, 139 | resp: &http.Response{StatusCode: http.StatusBadGateway}, 140 | expectedResult: true, 141 | }, 142 | { 143 | name: "net.OpErr: Dial - no consistent-hash-context", 144 | ctx: bgCtx, 145 | resp: &http.Response{}, 146 | expectedResult: true, 147 | err: &net.OpError{Op: "dial"}, 148 | }, 149 | } 150 | 151 | for _, tc := range tc { 152 | t.Run(tc.name, func(t *testing.T) { 153 | actualResult, actualError := client.RetryPolicy(tc.ctx, tc.resp, tc.err) 154 | assert.Equal(t, tc.expectedResult, actualResult) 155 | if tc.expectedError != nil { 156 | assert.Equal(t, tc.expectedError, actualError) 157 | } else { 158 | assert.NoError(t, actualError) 159 | } 160 | }) 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /pkg/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "net/url" 7 | "strings" 8 | 9 | "github.com/rs/zerolog" 10 | "github.com/spf13/cobra" 11 | "github.com/spf13/viper" 12 | 13 | "github.com/replicate/pget/pkg/consumer" 14 | "github.com/replicate/pget/pkg/logging" 15 | ) 16 | 17 | const viperEnvPrefix = "PGET" 18 | 19 | const ( 20 | ConsumerFile = "file" 21 | ConsumerTarExtractor = "tar-extractor" 22 | ConsumerNull = "null" 23 | ) 24 | 25 | var ( 26 | DefaultCacheURIPrefixes = []string{"https://weights.replicate.delivery"} 27 | ) 28 | 29 | type ConsistentHashingStrategy struct{} 30 | 31 | var ConsistentHashingStrategyKey ConsistentHashingStrategy 32 | 33 | type DeprecatedFlag struct { 34 | Flag string 35 | Msg string 36 | } 37 | 38 | func PersistentStartupProcessFlags() error { 39 | if viper.GetBool(OptVerbose) { 40 | viper.Set(OptLoggingLevel, "debug") 41 | } 42 | setLogLevel(viper.GetString(OptLoggingLevel)) 43 | return nil 44 | } 45 | 46 | func HideFlags(cmd *cobra.Command, flags ...string) error { 47 | for _, flag := range flags { 48 | f := cmd.Flag(flag) 49 | if f == nil { 50 | return fmt.Errorf("flag %s does not exist", flag) 51 | } 52 | // Try hiding a non-persistent flag, if it doesn't exist, try hiding a persistent flag of the same name 53 | // this is similar to how cobra implements the .Flag() lookup 54 | err := cmd.Flags().MarkHidden(flag) 55 | if err != nil { 56 | // We shouldn't be able to get an error here because we check f := cmd.Flag(flag) which does the 57 | // check across both persistent and non-persistent flags 58 | _ = cmd.PersistentFlags().MarkHidden(flag) 59 | } 60 | } 61 | return nil 62 | } 63 | 64 | func DeprecateFlags(cmd *cobra.Command, deprecations ...DeprecatedFlag) error { 65 | for _, config := range deprecations { 66 | f := cmd.Flag(config.Flag) 67 | if f == nil { 68 | return fmt.Errorf("flag %s does not exist", config.Flag) 69 | } 70 | err := cmd.Flags().MarkDeprecated(config.Flag, config.Msg) 71 | if err != 
nil { 72 | err := cmd.PersistentFlags().MarkDeprecated(config.Flag, config.Msg) 73 | if err != nil { 74 | return fmt.Errorf("failed to mark flag as deprecated: %w", err) 75 | } 76 | } 77 | } 78 | return nil 79 | } 80 | 81 | func ViperInit() { 82 | viper.SetEnvPrefix(viperEnvPrefix) 83 | viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) 84 | viper.AutomaticEnv() 85 | } 86 | 87 | func setLogLevel(logLevel string) { 88 | // Set log-level 89 | switch logLevel { 90 | case "debug": 91 | zerolog.SetGlobalLevel(zerolog.DebugLevel) 92 | case "info": 93 | zerolog.SetGlobalLevel(zerolog.InfoLevel) 94 | case "warn": 95 | zerolog.SetGlobalLevel(zerolog.WarnLevel) 96 | case "error": 97 | zerolog.SetGlobalLevel(zerolog.ErrorLevel) 98 | default: 99 | zerolog.SetGlobalLevel(zerolog.InfoLevel) 100 | } 101 | } 102 | 103 | func ResolveOverridesToMap(resolveOverrides []string) (map[string]string, error) { 104 | logger := logging.GetLogger() 105 | resolveOverrideMap := make(map[string]string) 106 | 107 | if len(resolveOverrides) == 0 { 108 | return nil, nil 109 | } 110 | 111 | for _, resolveHost := range resolveOverrides { 112 | split := strings.SplitN(resolveHost, ":", 3) 113 | if len(split) != 3 { 114 | return nil, fmt.Errorf("invalid resolve host format, expected <host>:<port>:<address>, got: %s", resolveHost) 115 | } 116 | host, port, addr := split[0], split[1], split[2] 117 | if net.ParseIP(host) != nil { 118 | return nil, fmt.Errorf("invalid hostname specified, looks like an IP address: %s", host) 119 | } 120 | hostPort := net.JoinHostPort(host, port) 121 | if override, ok := resolveOverrideMap[hostPort]; ok { 122 | if override == net.JoinHostPort(addr, port) { 123 | // duplicate entry, ignore 124 | continue 125 | } 126 | return nil, fmt.Errorf("duplicate host:port specified: %s", hostPort) 127 | } 128 | if net.ParseIP(addr) == nil { 129 | return nil, fmt.Errorf("invalid IP address: %s", addr) 130 | } 131 | resolveOverrideMap[hostPort] = net.JoinHostPort(addr, port) 132 | } 133 | if logger.GetLevel() == zerolog.DebugLevel { 134 | logger := logging.GetLogger() 135 | 136 | for key, elem := range resolveOverrideMap { 137 | logger.Debug().Str("host_port", key).Str("resolve_target", elem).Msg("Config") 138 | } 139 | } 140 | return resolveOverrideMap, nil 141 | } 142 | 143 | // GetConsumer returns the consumer specified by the user on the command line 144 | // or an error if the consumer is invalid. Note that this function explicitly 145 | // calls viper.GetString(OptOutputConsumer) internally. 146 | func GetConsumer() (consumer.Consumer, error) { 147 | consumerName := viper.GetString(OptOutputConsumer) 148 | enableOverwrite := viper.GetBool(OptForce) 149 | switch consumerName { 150 | case ConsumerFile: 151 | return &consumer.FileWriter{Overwrite: enableOverwrite}, nil 152 | case ConsumerTarExtractor: 153 | return &consumer.TarExtractor{Overwrite: enableOverwrite}, nil 154 | case ConsumerNull: 155 | return &consumer.NullWriter{}, nil 156 | default: 157 | return nil, fmt.Errorf("invalid consumer specified: %s", consumerName) 158 | } 159 | } 160 | 161 | // GetCacheSRV returns the SRV name of the cache to use, if set. 
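//
// For example (illustrative, mirroring config_test.go and assuming the
// standard PGET_ env mapping set up in ViperInit): with
// PGET_HOST_IP=192.0.2.37 and
// PGET_CACHE_NODES_SRV_NAME_BY_HOST_CIDR='{"192.0.2.0/24":"cache.srv.name.example"}'
// this returns "cache.srv.name.example".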
162 | func GetCacheSRV() string { 163 | if srv := viper.GetString(OptCacheNodesSRVName); srv != "" { 164 | return srv 165 | } 166 | hostIP := net.ParseIP(viper.GetString(OptHostIP)) 167 | srvNamesByCIDR := viper.GetStringMapString(OptCacheNodesSRVNameByHostCIDR) 168 | if hostIP == nil { 169 | // nothing configured, return zero value with no error 170 | return "" 171 | } 172 | for cidr, cidrSRV := range srvNamesByCIDR { 173 | _, net, err := net.ParseCIDR(cidr) 174 | if err != nil { 175 | continue 176 | } 177 | if net.Contains(hostIP) { 178 | return cidrSRV 179 | } 180 | } 181 | return "" 182 | } 183 | 184 | // CacheableURIPrefixes returns a map of cache URI prefixes to send through consistent hash, if set. 185 | // ENV is `PGET_CACHE_URI_PREFIXES`, and the 186 | // format is `https://example.com/prefix1 https://example.com/prefix2 https://example.com/ [...]` 187 | func CacheableURIPrefixes() map[string][]*url.URL { 188 | logger := logging.GetLogger() 189 | result := make(map[string][]*url.URL) 190 | 191 | URIs := viper.GetStringSlice(OptCacheURIPrefixes) 192 | if len(URIs) == 0 { 193 | URIs = DefaultCacheURIPrefixes 194 | } 195 | 196 | for _, uri := range URIs { 197 | parsed, err := url.Parse(uri) 198 | if err != nil || parsed.Host == "" || parsed.Scheme == "" { 199 | logger.Error(). 200 | Err(err). 201 | Str("uri", uri). 202 | Str("requirements", "requires at minimum scheme and host"). 203 | Msg("Cacheable URI Prefixes") 204 | continue 205 | } 206 | result[parsed.Host] = append(result[parsed.Host], parsed) 207 | } 208 | return result 209 | } 210 | 211 | func CacheServiceHostname() string { 212 | logger := logging.GetLogger() 213 | target := viper.GetString(OptCacheServiceHostname) 214 | parsed, err := url.Parse(target) 215 | if err != nil { 216 | logger.Error(). 217 | Err(err). 218 | Str("target", target). 219 | Bool("enabled", false). 220 | Msg("Cache Service") 221 | return "" 222 | } 223 | logger.Info(). 224 | Str("target", parsed.Host). 225 | Str("scheme", parsed.Scheme). 226 | Bool("enabled", true). 
227 | Msg("Cache Service") 228 | return target 229 | } 230 | -------------------------------------------------------------------------------- /pkg/config/config_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "net/url" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/rs/zerolog" 9 | "github.com/spf13/viper" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestSetLogLevel(t *testing.T) { 15 | testCases := []struct { 16 | name string 17 | logLevel string 18 | }{ 19 | {"debug", "debug"}, 20 | {"info", "info"}, 21 | {"warn", "warn"}, 22 | {"error", "error"}, 23 | {"unknown", "info"}, 24 | } 25 | 26 | for _, tc := range testCases { 27 | t.Run(tc.name, func(t *testing.T) { 28 | setLogLevel(tc.logLevel) 29 | assert.Equal(t, tc.logLevel, zerolog.GlobalLevel().String()) 30 | }) 31 | } 32 | } 33 | 34 | func TestResolveOverrides(t *testing.T) { 35 | testCases := []struct { 36 | name string 37 | resolve []string 38 | expected map[string]string 39 | err bool 40 | }{ 41 | {"empty", []string{}, nil, false}, 42 | {"single", []string{"example.com:80:127.0.0.1"}, map[string]string{"example.com:80": "127.0.0.1:80"}, false}, 43 | {"multiple", []string{"example.com:80:127.0.0.1", "example.com:443:127.0.0.1"}, map[string]string{"example.com:80": "127.0.0.1:80", "example.com:443": "127.0.0.1:443"}, false}, 44 | {"invalid ip", []string{"example.com:80:InvalidIPAddr"}, nil, true}, 45 | {"duplicate host different target", []string{"example.com:80:127.0.0.1", "example.com:80:127.0.0.2"}, nil, true}, 46 | {"duplicate host same target", []string{"example.com:80:127.0.0.1", "example.com:80:127.0.0.1"}, map[string]string{"example.com:80": "127.0.0.1:80"}, false}, 47 | {"invalid format", []string{"example.com:80"}, nil, true}, 48 | {"invalid hostname format, is IP Addr", []string{"127.0.0.1:443:127.0.0.2"}, nil, true}, 49 | } 50 | 51 | for _, tc := range testCases { 52 | t.Run(tc.name, func(t *testing.T) { 53 | resolveOverrides, err := ResolveOverridesToMap(tc.resolve) 54 | assert.Equal(t, tc.err, err != nil) 55 | assert.Equal(t, tc.expected, resolveOverrides) 56 | }) 57 | } 58 | } 59 | 60 | func helperUrlParse(t *testing.T, uris ...string) []*url.URL { 61 | t.Helper() 62 | var urls []*url.URL 63 | for _, uri := range uris { 64 | u, err := url.Parse(uri) 65 | require.NoError(t, err) 66 | urls = append(urls, u) 67 | } 68 | return urls 69 | } 70 | 71 | func TestCacheableURIPrefixes(t *testing.T) { 72 | defer func() { 73 | viper.Reset() 74 | }() 75 | testCases := []struct { 76 | name string 77 | prefixes []string 78 | expected map[string][]*url.URL 79 | }{ 80 | { 81 | name: "default", 82 | expected: map[string][]*url.URL{ 83 | "weights.replicate.delivery": helperUrlParse(t, "https://weights.replicate.delivery"), 84 | }, 85 | }, 86 | { 87 | name: "single", 88 | prefixes: []string{"http://example.com"}, 89 | expected: map[string][]*url.URL{ 90 | "example.com": helperUrlParse(t, "http://example.com"), 91 | }, 92 | }, 93 | { 94 | name: "multiple", prefixes: []string{"http://example.com", "http://example.org"}, 95 | expected: map[string][]*url.URL{ 96 | "example.com": helperUrlParse(t, "http://example.com"), 97 | "example.org": helperUrlParse(t, "http://example.org"), 98 | }, 99 | }, 100 | { 101 | name: "multiple same domain merged", 102 | prefixes: []string{"http://example.com/path", "http://example.com/other"}, 103 | expected: map[string][]*url.URL{ 104 | "example.com": helperUrlParse(t, 
"http://example.com/path", "http://example.com/other"), 105 | }, 106 | }, 107 | { 108 | name: "invalid ignored", 109 | prefixes: []string{"http://example.com", "http://example.org", "invalid"}, 110 | expected: map[string][]*url.URL{ 111 | "example.com": helperUrlParse(t, "http://example.com"), 112 | "example.org": helperUrlParse(t, "http://example.org"), 113 | }, 114 | }, 115 | { 116 | name: "single with path", 117 | prefixes: []string{"http://example.com/path"}, 118 | expected: map[string][]*url.URL{ 119 | "example.com": helperUrlParse(t, "http://example.com/path"), 120 | }, 121 | }, 122 | { 123 | name: "multiple with path", 124 | prefixes: []string{"http://example.com/path", "http://example.org/path"}, 125 | expected: map[string][]*url.URL{ 126 | "example.com": helperUrlParse(t, "http://example.com/path"), 127 | "example.org": helperUrlParse(t, "http://example.org/path"), 128 | }, 129 | }, 130 | } 131 | for _, tc := range testCases { 132 | t.Run(tc.name, func(t *testing.T) { 133 | viper.Set(OptCacheURIPrefixes, strings.Join(tc.prefixes, " ")) 134 | actual := CacheableURIPrefixes() 135 | assert.Equal(t, tc.expected, actual) 136 | viper.Reset() 137 | }) 138 | } 139 | } 140 | 141 | func TestGetCacheSRV(t *testing.T) { 142 | defer func() { 143 | viper.Reset() 144 | }() 145 | testCases := []struct { 146 | name string 147 | srvName string 148 | hostIP string 149 | srvNameByHostIP string 150 | expected string 151 | }{ 152 | {"empty", "", "", ``, ""}, 153 | {"provided", "cache.srv.name.example", "", ``, "cache.srv.name.example"}, 154 | {"looked up", "", "192.0.2.37", `{"192.0.2.0/24":"cache.srv.name.example"}`, "cache.srv.name.example"}, 155 | {"both provided", "direct", "192.0.2.37", `{"192.0.2.0/24":"from-map"}`, "direct"}, 156 | {"chooses correct value from map", 157 | "", 158 | "192.0.2.37", 159 | `{ 160 | "192.0.2.0/27": "cache-1", 161 | "192.0.2.32/27": "cache-2" 162 | }`, 163 | "cache-2"}, 164 | {"missing from map", "", "192.0.2.37", `{"192.0.2.0/30":"cache.srv.name.example"}`, ""}, 165 | {"hostIP but no map", "", "192.0.2.37", ``, ""}, 166 | {"invalid map", "", "192.0.2.37", `{`, ""}, 167 | {"invalid CIDR", "", "192.0.2.37", `{"500.0.2.0/0":"cache.srv.name.example"}`, ""}, 168 | {"valid + invalid CIDRs", 169 | "", 170 | "192.0.2.37", 171 | `{ 172 | "192.0.2.0/24": "cache-valid", 173 | "500.0.2.0/30": "cache-invalid" 174 | }`, 175 | "cache-valid"}, 176 | } 177 | for _, tc := range testCases { 178 | t.Run(tc.name, func(t *testing.T) { 179 | viper.Set(OptCacheNodesSRVName, tc.srvName) 180 | viper.Set(OptHostIP, tc.hostIP) 181 | viper.Set(OptCacheNodesSRVNameByHostCIDR, tc.srvNameByHostIP) 182 | actual := GetCacheSRV() 183 | assert.Equal(t, tc.expected, actual) 184 | viper.Reset() 185 | }) 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /pkg/config/optnames.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | const ( 4 | // these options are a massive hack. 
They're only available via 5 | // envvar, not command line 6 | OptCacheNodesSRVNameByHostCIDR = "cache-nodes-srv-name-by-host-cidr" 7 | OptCacheNodesSRVName = "cache-nodes-srv-name" 8 | OptCacheServiceHostname = "cache-service-hostname" 9 | OptCacheURIPrefixes = "cache-uri-prefixes" 10 | OptCacheUsePathProxy = "cache-use-path-proxy" 11 | OptForceCachePrefixRewrite = "force-cache-prefix-rewrite" 12 | OptHostIP = "host-ip" 13 | OptMetricsEndpoint = "metrics-endpoint" 14 | OptHeaders = "headers" 15 | 16 | // Normal options with CLI arguments 17 | OptConcurrency = "concurrency" 18 | OptConnTimeout = "connect-timeout" 19 | OptChunkSize = "chunk-size" 20 | OptExtract = "extract" 21 | OptForce = "force" 22 | OptForceHTTP2 = "force-http2" 23 | OptLoggingLevel = "log-level" 24 | OptMaxChunks = "max-chunks" 25 | OptMaxConnPerHost = "max-conn-per-host" 26 | OptMaxConcurrentFiles = "max-concurrent-files" 27 | OptMinimumChunkSize = "minimum-chunk-size" 28 | OptOutputConsumer = "output" 29 | OptPIDFile = "pid-file" 30 | OptResolve = "resolve" 31 | OptRetries = "retries" 32 | OptVerbose = "verbose" 33 | ) 34 | -------------------------------------------------------------------------------- /pkg/consistent/consistent.go: -------------------------------------------------------------------------------- 1 | // Package consistent implements consistent hashing for cache nodes. 2 | package consistent 3 | 4 | import ( 5 | "fmt" 6 | "slices" 7 | 8 | "github.com/dgryski/go-jump" 9 | "github.com/mitchellh/hashstructure/v2" 10 | ) 11 | 12 | type cacheKey struct { 13 | Key any 14 | Attempt int 15 | } 16 | 17 | // HashBucket returns a bucket from [0,buckets). If you want to implement a 18 | // retry, you can pass previousBuckets, which indicates buckets that must be 19 | // avoided in the output. HashBucket will modify the previousBuckets slice by 20 | // sorting it. 21 | func HashBucket(key any, buckets int, previousBuckets ...int) (int, error) { 22 | if len(previousBuckets) >= buckets { 23 | return -1, fmt.Errorf("no more buckets left: %d buckets available but %d already attempted", buckets, len(previousBuckets)) 24 | } 25 | // we set IgnoreZeroValue so that we can add fields to the hash key 26 | // later without breaking things. 27 | // note that it's not safe to share a HashOptions so we create a fresh one each time. 28 | hashopts := &hashstructure.HashOptions{IgnoreZeroValue: true} 29 | hash, err := hashstructure.Hash(cacheKey{Key: key, Attempt: len(previousBuckets)}, hashstructure.FormatV2, hashopts) 30 | if err != nil { 31 | return -1, fmt.Errorf("error calculating hash of key: %w", err) 32 | } 33 | 34 | // jump is an implementation of Google's Jump Consistent Hash. 35 | // 36 | // See http://arxiv.org/abs/1406.2294 for details. 
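// Hash into a space shrunk by the number of previous attempts, then shift
// the result up past each previously-used index (previousBuckets is sorted
// first) so the returned bucket never collides with an earlier attempt.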
37 | bucket := int(jump.Hash(hash, buckets-len(previousBuckets))) 38 | slices.Sort(previousBuckets) 39 | for _, prev := range previousBuckets { 40 | if bucket >= prev { 41 | bucket++ 42 | } 43 | } 44 | return bucket, nil 45 | } 46 | -------------------------------------------------------------------------------- /pkg/consistent/consistent_test.go: -------------------------------------------------------------------------------- 1 | package consistent_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "github.com/stretchr/testify/require" 8 | 9 | "github.com/replicate/pget/pkg/consistent" 10 | ) 11 | 12 | func TestHashingDoesNotChangeWhenZeroValueFieldsAreAdded(t *testing.T) { 13 | a, err := consistent.HashBucket(struct{}{}, 1024) 14 | require.NoError(t, err) 15 | b, err := consistent.HashBucket(struct{ I int }{}, 1024) 16 | require.NoError(t, err) 17 | 18 | assert.Equal(t, a, b) 19 | } 20 | 21 | func TestRetriesScatterBuckets(t *testing.T) { 22 | // This test is tricky! We want an example of hash keys which map to the 23 | // same bucket, but after one retry map to different buckets. 24 | // 25 | // These two keys happen to have this property for 10 buckets: 26 | strA := "abcdefg" 27 | strB := "1234567" 28 | a, err := consistent.HashBucket(strA, 10) 29 | require.NoError(t, err) 30 | b, err := consistent.HashBucket(strB, 10) 31 | require.NoError(t, err) 32 | 33 | // strA and strB map to the same bucket 34 | require.Equal(t, a, b) 35 | 36 | aRetry, err := consistent.HashBucket(strA, 10, a) 37 | require.NoError(t, err) 38 | bRetry, err := consistent.HashBucket(strB, 10, b) 39 | require.NoError(t, err) 40 | 41 | // but after retry they map to different buckets 42 | assert.NotEqual(t, aRetry, bRetry) 43 | } 44 | 45 | func FuzzRetriesMustNotRepeatIndices(f *testing.F) { 46 | f.Add("test.replicate.delivery", 5) 47 | f.Add("test.replicate.delivery", 0) 48 | f.Fuzz(func(t *testing.T, key string, excessBuckets int) { 49 | if excessBuckets < 0 { 50 | t.Skip("invalid value") 51 | } 52 | attempts := 20 53 | buckets := attempts + excessBuckets 54 | if buckets < 0 { 55 | t.Skip("integer overflow") 56 | } 57 | previous := []int{} 58 | for i := 0; i < attempts; i++ { 59 | next, err := consistent.HashBucket(key, buckets, previous...) 
60 | require.NoError(t, err) 61 | 62 | // we must be in range 63 | assert.Less(t, next, buckets) 64 | assert.GreaterOrEqual(t, next, 0) 65 | 66 | // we shouldn't repeat any previous value 67 | assert.NotContains(t, previous, next) 68 | 69 | previous = append(previous, next) 70 | } 71 | }) 72 | } 73 | -------------------------------------------------------------------------------- /pkg/consumer/consumer.go: -------------------------------------------------------------------------------- 1 | package consumer 2 | 3 | import "io" 4 | 5 | type Consumer interface { 6 | Consume(reader io.Reader, destPath string, expectedBytes int64) error 7 | } 8 | -------------------------------------------------------------------------------- /pkg/consumer/consumer_test.go: -------------------------------------------------------------------------------- 1 | package consumer_test 2 | 3 | import ( 4 | "math/rand" 5 | ) 6 | 7 | const ( 8 | kB int64 = 1024 9 | ) 10 | 11 | // generateTestContent generates a byte slice of the given size, filled with random bytes 12 | func generateTestContent(size int64) []byte { 13 | content := make([]byte, size) 14 | // Generate random bytes and write them to the content slice 15 | for i := range content { 16 | content[i] = byte(rand.Intn(256)) 17 | } 18 | return content 19 | 20 | } 21 | -------------------------------------------------------------------------------- /pkg/consumer/null.go: -------------------------------------------------------------------------------- 1 | package consumer 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | ) 7 | 8 | type NullWriter struct{} 9 | 10 | var _ Consumer = &NullWriter{} 11 | 12 | func (NullWriter) Consume(reader io.Reader, destPath string, expectedBytes int64) error { 13 | // io.Discard is explicitly designed to always succeed, ignore errors. 
14 | bytesRead, _ := io.Copy(io.Discard, reader) 15 | if bytesRead != expectedBytes { 16 | return fmt.Errorf("expected %d bytes, read %d", expectedBytes, bytesRead) 17 | } 18 | return nil 19 | } 20 | -------------------------------------------------------------------------------- /pkg/consumer/null_test.go: -------------------------------------------------------------------------------- 1 | package consumer_test 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | 9 | "github.com/replicate/pget/pkg/consumer" 10 | ) 11 | 12 | func TestNullWriter_Consume(t *testing.T) { 13 | r := require.New(t) 14 | buf := generateTestContent(kB) 15 | reader := bytes.NewReader(buf) 16 | 17 | nullConsumer := consumer.NullWriter{} 18 | r.NoError(nullConsumer.Consume(reader, "", kB)) 19 | 20 | _, _ = reader.Seek(0, 0) 21 | r.Error(nullConsumer.Consume(reader, "", kB-100)) 22 | } 23 | -------------------------------------------------------------------------------- /pkg/consumer/tar_extractor.go: -------------------------------------------------------------------------------- 1 | package consumer 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/replicate/pget/pkg/extract" 9 | ) 10 | 11 | type TarExtractor struct { 12 | Overwrite bool 13 | } 14 | 15 | var _ Consumer = &TarExtractor{} 16 | 17 | var _ io.Reader = &byteTrackingReader{} 18 | 19 | type byteTrackingReader struct { 20 | bytesRead int64 21 | r io.Reader 22 | } 23 | 24 | func (b *byteTrackingReader) Read(p []byte) (n int, err error) { 25 | n, err = b.r.Read(p) 26 | b.bytesRead += int64(n) 27 | return 28 | } 29 | 30 | func (f *TarExtractor) Consume(reader io.Reader, destPath string, expectedBytes int64) error { 31 | btReader := &byteTrackingReader{r: reader} 32 | err := extract.TarFile(bufio.NewReader(btReader), destPath, f.Overwrite) 33 | if err != nil { 34 | return fmt.Errorf("error extracting file: %w", err) 35 | } 36 | if btReader.bytesRead != expectedBytes { 37 | return fmt.Errorf("expected %d bytes, read %d from archive", expectedBytes, btReader.bytesRead) 38 | } 39 | return nil 40 | } 41 | -------------------------------------------------------------------------------- /pkg/consumer/tar_extractor_test.go: -------------------------------------------------------------------------------- 1 | package consumer_test 2 | 3 | import ( 4 | "archive/tar" 5 | "bytes" 6 | "io" 7 | "os" 8 | "path" 9 | "testing" 10 | "time" 11 | 12 | "github.com/stretchr/testify/require" 13 | 14 | "github.com/replicate/pget/pkg/consumer" 15 | ) 16 | 17 | const ( 18 | file1Content = "This is the content of file1." 19 | file2Content = "This is the content of file2." 20 | file1Path = "file1.txt" 21 | file2Path = "file2.txt" 22 | fileSymLinkPath = "link_to_file1.txt" 23 | fileHardLinkPath = "subdir/hard_link_to_file2.txt" 24 | ) 25 | 26 | func createTarFileBytesBuffer() ([]byte, error) { 27 | // Create an in-memory representation of a tar file dynamically. 
This will be used to test the TarExtractor 28 | 29 | var buf bytes.Buffer 30 | tw := tar.NewWriter(&buf) 31 | 32 | // Create first file 33 | content1 := []byte(file1Content) 34 | hdr := &tar.Header{ 35 | Name: file1Path, 36 | Mode: 0600, 37 | Size: int64(len(content1)), 38 | ModTime: time.Now(), 39 | } 40 | if err := tw.WriteHeader(hdr); err != nil { 41 | return nil, err 42 | } 43 | if _, err := tw.Write(content1); err != nil { 44 | return nil, err 45 | } 46 | 47 | // Create second file 48 | content2 := []byte(file2Content) 49 | hdr = &tar.Header{ 50 | Name: file2Path, 51 | Mode: 0600, 52 | Size: int64(len(content2)), 53 | ModTime: time.Now(), 54 | } 55 | if err := tw.WriteHeader(hdr); err != nil { 56 | return nil, err 57 | } 58 | if _, err := tw.Write(content2); err != nil { 59 | return nil, err 60 | } 61 | 62 | // Create a symlink to file1 63 | hdr = &tar.Header{ 64 | Name: fileSymLinkPath, 65 | Mode: 0777, 66 | Size: 0, 67 | Linkname: file1Path, 68 | Typeflag: tar.TypeSymlink, 69 | ModTime: time.Now(), 70 | } 71 | if err := tw.WriteHeader(hdr); err != nil { 72 | return nil, err 73 | } 74 | 75 | // Create a subdirectory or path for the hardlink 76 | hdr = &tar.Header{ 77 | Name: "subdir/", 78 | Mode: 0755, 79 | Typeflag: tar.TypeDir, 80 | ModTime: time.Now(), 81 | } 82 | if err := tw.WriteHeader(hdr); err != nil { 83 | return nil, err 84 | } 85 | 86 | // Create a hardlink to file2 in the subdirectory 87 | hdr = &tar.Header{ 88 | Name: fileHardLinkPath, 89 | Mode: 0600, 90 | Size: 0, 91 | Linkname: file2Path, 92 | Typeflag: tar.TypeLink, 93 | ModTime: time.Now(), 94 | } 95 | if err := tw.WriteHeader(hdr); err != nil { 96 | return nil, err 97 | } 98 | 99 | // Close the tar writer to flush the data 100 | if err := tw.Close(); err != nil { 101 | return nil, err 102 | } 103 | 104 | return buf.Bytes(), nil 105 | } 106 | 107 | func TestTarExtractor_Consume(t *testing.T) { 108 | r := require.New(t) 109 | 110 | tarFileBytes, err := createTarFileBytesBuffer() 111 | r.NoError(err) 112 | 113 | // Create a reader from the tar file bytes 114 | reader := io.MultiReader(bytes.NewReader(tarFileBytes), bytes.NewReader(make([]byte, 1024))) 115 | 116 | // Create a temporary directory to extract the tar file 117 | tmpDir, err := os.MkdirTemp("", "tarExtractorTest-") 118 | r.NoError(err) 119 | 120 | t.Cleanup(func() { os.RemoveAll(tmpDir) }) 121 | 122 | tarConsumer := consumer.TarExtractor{} 123 | targetDir := path.Join(tmpDir, "extract") 124 | r.NoError(tarConsumer.Consume(reader, targetDir, int64(len(tarFileBytes)+1024))) 125 | 126 | // Check if the extraction was successful 127 | checkTarExtraction(t, targetDir) 128 | 129 | // Test with incorrect expectedBytes 130 | reader = io.MultiReader(bytes.NewReader(tarFileBytes), bytes.NewReader(make([]byte, 1024))) 131 | targetDir = path.Join(tmpDir, "extract-fail") 132 | r.Error(tarConsumer.Consume(reader, targetDir, int64(len(tarFileBytes)+1024-1))) 133 | } 134 | 135 | func checkTarExtraction(t *testing.T, targetDir string) { 136 | r := require.New(t) 137 | 138 | // Verify that file1.txt is correctly extracted 139 | fqFile1Path := path.Join(targetDir, file1Path) 140 | content, err := os.ReadFile(fqFile1Path) 141 | r.NoError(err) 142 | r.Equal(file1Content, string(content)) 143 | 144 | // Verify that file2.txt is correctly extracted 145 | fqFile2Path := path.Join(targetDir, file2Path) 146 | content, err = os.ReadFile(fqFile2Path) 147 | r.NoError(err) 148 | r.Equal(file2Content, string(content)) 149 | 150 | // Verify that link_to_file1.txt is a symlink pointing to 
file1.txt 151 | linkToFile1Path := path.Join(targetDir, fileSymLinkPath) 152 | linkTarget, err := os.Readlink(linkToFile1Path) 153 | r.NoError(err) 154 | r.Equal(file1Path, linkTarget) 155 | // os.Readlink succeeding above already confirms the entry is a symlink 156 | 157 | // Verify that subdir/hard_link_to_file2.txt is a hard link to file2.txt 158 | hardLinkToFile2Path := path.Join(targetDir, fileHardLinkPath) 159 | hardLinkStat, err := os.Stat(hardLinkToFile2Path) 160 | r.NoError(err) 161 | file2Stat, err := os.Stat(fqFile2Path) 162 | r.NoError(err) 163 | 164 | if !os.SameFile(hardLinkStat, file2Stat) { 165 | t.Errorf("hard link does not match file2.txt") 166 | } 167 | } 168 | -------------------------------------------------------------------------------- /pkg/consumer/write_file.go: -------------------------------------------------------------------------------- 1 | package consumer 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "path/filepath" 8 | ) 9 | 10 | type FileWriter struct { 11 | Overwrite bool 12 | } 13 | 14 | var _ Consumer = &FileWriter{} 15 | 16 | func (f *FileWriter) Consume(reader io.Reader, destPath string, expectedBytes int64) error { 17 | openFlags := os.O_WRONLY | os.O_CREATE 18 | targetDir := filepath.Dir(destPath) 19 | if err := os.MkdirAll(targetDir, 0755); err != nil { 20 | return fmt.Errorf("error creating directory: %w", err) 21 | } 22 | if f.Overwrite { 23 | openFlags |= os.O_TRUNC 24 | } 25 | out, err := os.OpenFile(destPath, openFlags, 0644) 26 | if err != nil { 27 | return fmt.Errorf("error opening file: %w", err) 28 | } 29 | defer out.Close() 30 | 31 | written, err := io.Copy(out, reader) 32 | if err != nil { 33 | return fmt.Errorf("error writing file: %w", err) 34 | } 35 | 36 | if written != expectedBytes { 37 | return fmt.Errorf("expected %d bytes, wrote %d", expectedBytes, written) 38 | } 39 | return nil 40 | } 41 | -------------------------------------------------------------------------------- /pkg/consumer/write_file_test.go: -------------------------------------------------------------------------------- 1 | package consumer_test 2 | 3 | import ( 4 | "bytes" 5 | "os" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/replicate/pget/pkg/consumer" 11 | ) 12 | 13 | func TestFileWriter_Consume(t *testing.T) { 14 | r := require.New(t) 15 | 16 | buf := generateTestContent(kB) 17 | reader := bytes.NewReader(buf) 18 | 19 | writeFileConsumer := consumer.FileWriter{} 20 | tmpFile, _ := os.CreateTemp("", "fileWriterTest-") 21 | 22 | t.Cleanup(func() { 23 | tmpFile.Close() 24 | os.Remove(tmpFile.Name()) 25 | }) 26 | 27 | r.NoError(writeFileConsumer.Consume(reader, tmpFile.Name(), kB)) 28 | 29 | // Check the file content is correct 30 | fileContent, _ := os.ReadFile(tmpFile.Name()) 31 | r.Equal(buf, fileContent) 32 | 33 | _, _ = reader.Seek(0, 0) 34 | r.Error(writeFileConsumer.Consume(reader, "", kB-100)) 35 | 36 | // test overwrite 37 | // first, clobber the file with different content 38 | f, err := os.OpenFile(tmpFile.Name(), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0755) 39 | r.NoError(err) 40 | _, _ = f.Write([]byte("different content")) 41 | f.Close() 42 | 43 | // consume the reader 44 | _, _ = reader.Seek(0, 0) 45 | writeFileConsumer.Overwrite = true 46 | r.NoError(writeFileConsumer.Consume(reader, tmpFile.Name(), kB)) 47 | 48 | // check the file content is correct 49 | fileContent, _ = os.ReadFile(tmpFile.Name()) 50 | r.Equal(buf, fileContent) 51 | } 52 | -------------------------------------------------------------------------------- /pkg/download/buffer.go:
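Before the implementation, a minimal sketch of how the consumer and download packages compose (the URL and destination path are hypothetical; the constructors and methods are the ones defined in this repository): BufferMode.Fetch returns a streaming reader plus the total file size, and FileWriter.Consume writes the stream and verifies the byte count.

```go
package main

import (
	"context"
	"log"

	"github.com/replicate/pget/pkg/client"
	"github.com/replicate/pget/pkg/consumer"
	"github.com/replicate/pget/pkg/download"
)

func main() {
	mode := download.GetBufferMode(download.Options{Client: client.Options{}})

	// Fetch returns a streaming io.Reader backed by concurrent chunk
	// downloads, along with the total file size.
	reader, size, err := mode.Fetch(context.Background(), "https://example.com/model.bin")
	if err != nil {
		log.Fatal(err)
	}

	// FileWriter fails unless exactly `size` bytes are written.
	fw := &consumer.FileWriter{Overwrite: true}
	if err := fw.Consume(reader, "/tmp/model.bin", size); err != nil {
		log.Fatal(err)
	}
}
```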
-------------------------------------------------------------------------------- 1 | package download 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "net/url" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/rs/zerolog" 13 | 14 | "github.com/replicate/pget/pkg/client" 15 | "github.com/replicate/pget/pkg/logging" 16 | ) 17 | 18 | type BufferMode struct { 19 | Client client.HTTPClient 20 | Options 21 | 22 | queue *priorityWorkQueue 23 | } 24 | 25 | func GetBufferMode(opts Options) *BufferMode { 26 | client := client.NewHTTPClient(opts.Client) 27 | m := &BufferMode{ 28 | Client: client, 29 | Options: opts, 30 | } 31 | m.queue = newWorkQueue(opts.maxConcurrency(), m.chunkSize()) 32 | m.queue.start() 33 | return m 34 | } 35 | 36 | func (m *BufferMode) chunkSize() int64 { 37 | minChunkSize := m.ChunkSize 38 | if minChunkSize == 0 { 39 | return defaultChunkSize 40 | } 41 | return minChunkSize 42 | } 43 | 44 | func (m *BufferMode) getFileSizeFromResponse(resp *http.Response) (int64, error) { 45 | // If the response is a 200 OK, we need to parse the file size assuming the whole 46 | // file was returned. If it isn't, we will assume this was a 206 Partial Content 47 | // and parse the file size from the content range header. We wouldn't be in this 48 | // function if the response was not between 200 and 300, so this feels like a 49 | // reasonable assumption 50 | if resp.StatusCode == http.StatusOK { 51 | return m.getFileSizeFromContentLength(resp.Header.Get("Content-Length")) 52 | } 53 | return m.getFileSizeFromContentRange(resp.Header.Get("Content-Range")) 54 | } 55 | 56 | func (m *BufferMode) getFileSizeFromContentLength(contentLength string) (int64, error) { 57 | size, err := strconv.ParseInt(contentLength, 10, 64) 58 | if err != nil { 59 | return 0, err 60 | } 61 | 62 | return size, nil 63 | } 64 | 65 | func (m *BufferMode) getFileSizeFromContentRange(contentRange string) (int64, error) { 66 | groups := contentRangeRegexp.FindStringSubmatch(contentRange) 67 | if groups == nil { 68 | return -1, fmt.Errorf("couldn't parse Content-Range: %s", contentRange) 69 | } 70 | return strconv.ParseInt(groups[1], 10, 64) 71 | } 72 | 73 | type firstReqResult struct { 74 | fileSize int64 75 | trueURL string 76 | err error 77 | } 78 | 79 | func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, error) { 80 | logger := logging.GetLogger() 81 | 82 | firstChunk := newReaderPromise() 83 | 84 | firstReqResultCh := make(chan firstReqResult) 85 | m.queue.submitLow(func(buf []byte) { 86 | defer close(firstReqResultCh) 87 | 88 | if m.CacheHosts != nil { 89 | url = m.rewriteUrlForCache(url) 90 | } 91 | 92 | firstChunkResp, err := m.DoRequest(ctx, 0, m.chunkSize()-1, url) 93 | if err != nil { 94 | firstReqResultCh <- firstReqResult{err: err} 95 | return 96 | } 97 | 98 | defer firstChunkResp.Body.Close() 99 | 100 | trueURL := firstChunkResp.Request.URL.String() 101 | if trueURL != url { 102 | logger.Info().Str("url", url).Str("redirect_url", trueURL).Msg("Redirect") 103 | } 104 | 105 | fileSize, err := m.getFileSizeFromResponse(firstChunkResp) 106 | if err != nil { 107 | firstReqResultCh <- firstReqResult{err: err} 108 | return 109 | } 110 | firstReqResultCh <- firstReqResult{fileSize: fileSize, trueURL: trueURL} 111 | 112 | contentLength := firstChunkResp.ContentLength 113 | n, err := io.ReadFull(firstChunkResp.Body, buf[0:contentLength]) 114 | if err == io.ErrUnexpectedEOF { 115 | logger.Warn(). 116 | Int("connection_interrupted_at_byte", n). 
117 | Msg("Resuming Chunk Download") 118 | n, err = resumeDownload(firstChunkResp.Request, buf[n:contentLength], m.Client, int64(n)) 119 | } 120 | firstChunk.Deliver(buf[0:n], err) 121 | }) 122 | 123 | firstReqResult, ok := <-firstReqResultCh 124 | if !ok { 125 | panic("logic error in BufferMode: first request didn't return any output") 126 | } 127 | 128 | if firstReqResult.err != nil { 129 | return nil, -1, firstReqResult.err 130 | } 131 | 132 | fileSize := firstReqResult.fileSize 133 | trueURL := firstReqResult.trueURL 134 | 135 | if fileSize <= m.chunkSize() { 136 | // we only need a single chunk: just download it and finish 137 | return firstChunk, fileSize, nil 138 | } 139 | 140 | remainingBytes := fileSize - m.chunkSize() 141 | // integer divide rounding up 142 | numChunks := int((remainingBytes-1)/m.chunkSize() + 1) 143 | 144 | chunks := make([]io.Reader, numChunks+1) 145 | chunks[0] = firstChunk 146 | 147 | startOffset := m.chunkSize() 148 | 149 | logger.Debug().Str("url", url). 150 | Int64("size", fileSize). 151 | Int("connections", numChunks). 152 | Int64("chunkSize", m.chunkSize()). 153 | Msg("Downloading") 154 | 155 | for i := 0; i < numChunks; i++ { 156 | chunk := newReaderPromise() 157 | chunks[i+1] = chunk 158 | } 159 | go func(chunks []io.Reader) { 160 | for i, reader := range chunks { 161 | chunk := reader.(*readerPromise) 162 | m.queue.submitHigh(func(buf []byte) { 163 | start := startOffset + m.chunkSize()*int64(i) 164 | end := start + m.chunkSize() - 1 165 | 166 | if i == numChunks-1 { 167 | end = fileSize - 1 168 | } 169 | logger.Debug().Str("url", url). 170 | Int64("size", fileSize). 171 | Int("chunk", i). 172 | Msg("Downloading chunk") 173 | 174 | resp, err := m.DoRequest(ctx, start, end, trueURL) 175 | if err != nil { 176 | chunk.Deliver(nil, err) 177 | return 178 | } 179 | defer resp.Body.Close() 180 | 181 | contentLength := resp.ContentLength 182 | n, err := io.ReadFull(resp.Body, buf[0:contentLength]) 183 | if err == io.ErrUnexpectedEOF { 184 | logger.Warn(). 185 | Int("connection_interrupted_at_byte", n). 186 | Msg("Resuming Chunk Download") 187 | n, err = resumeDownload(resp.Request, buf[n:contentLength], m.Client, int64(n)) 188 | } 189 | chunk.Deliver(buf[0:n], err) 190 | }) 191 | } 192 | }(chunks[1:]) 193 | 194 | return io.MultiReader(chunks...), fileSize, nil 195 | } 196 | 197 | func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) { 198 | req, err := http.NewRequestWithContext(ctx, "GET", trueURL, nil) 199 | if err != nil { 200 | return nil, fmt.Errorf("failed to download %s: %w", trueURL, err) 201 | } 202 | req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) 203 | resp, err := m.Client.Do(req) 204 | if err != nil { 205 | return nil, fmt.Errorf("error executing request for %s: %w", req.URL.String(), err) 206 | } 207 | if resp.StatusCode == 0 || resp.StatusCode < 200 || resp.StatusCode >= 300 { 208 | return nil, fmt.Errorf("%w %s: %s", ErrUnexpectedHTTPStatus, req.URL.String(), resp.Status) 209 | } 210 | 211 | return resp, nil 212 | } 213 | 214 | func (m *BufferMode) rewriteUrlForCache(urlString string) string { 215 | logger := logging.GetLogger() 216 | parsed, err := url.Parse(urlString) 217 | if m.CacheHosts == nil || len(m.CacheHosts) != 1 { 218 | logger.Error(). 219 | Str("url", urlString). 220 | Bool("enabled", false). 221 | Str("disabled_reason", fmt.Sprintf("expected exactly 1 cache host, received %d", len(m.CacheHosts))). 
222 | Msg("Cache URL Rewrite") 223 | return urlString 224 | } 225 | if err != nil { 226 | logger.Error(). 227 | Err(err). 228 | Str("url", urlString). 229 | Bool("enabled", false). 230 | Str("disabled_reason", "failed to parse URL"). 231 | Msg("Cache URL Rewrite") 232 | return urlString 233 | } 234 | if m.ForceCachePrefixRewrite { 235 | // Forcefully rewrite the URL prefix 236 | return m.rewritePrefix(m.CacheHosts[0], urlString, parsed, logger) 237 | } else { 238 | if prefixes, ok := m.CacheableURIPrefixes[parsed.Host]; ok { 239 | for _, pfx := range prefixes { 240 | if pfx.Path == "/" || strings.HasPrefix(parsed.Path, pfx.Path) { 241 | // Found a matching prefix, rewrite the URL prefix 242 | return m.rewritePrefix(m.CacheHosts[0], urlString, parsed, logger) 243 | } 244 | } 245 | } 246 | } 247 | 248 | // If we got here, we weren't forcefully rewriting the cache prefix and we didn't 249 | // find any matching prefixes, so we just return the original URL 250 | logger.Debug(). 251 | Str("url", urlString). 252 | Bool("enabled", false). 253 | Str("disabled_reason", "no matching prefix"). 254 | Str("host", parsed.Host). 255 | Msg("Cache URL Rewrite") 256 | return urlString 257 | } 258 | 259 | func (m *BufferMode) rewritePrefix(cacheHost, urlString string, parsed *url.URL, logger zerolog.Logger) string { 260 | newUrl := cacheHost 261 | var err error 262 | if m.CacheUsePathProxy { 263 | newUrl, err = url.JoinPath(newUrl, parsed.Host) 264 | if err != nil { 265 | logger.Error(). 266 | Err(err). 267 | Str("url", urlString). 268 | Bool("enabled", false). 269 | Str("disabled_reason", "failed to join cache URL to host"). 270 | Msg("Cache URL Rewrite") 271 | return urlString 272 | } 273 | logger.Debug(). 274 | Bool("path_based_proxy", true). 275 | Str("host_prefix", parsed.Host). 276 | Str("intermediate_target_url", newUrl). 277 | Str("url", urlString). 278 | Msg("Cache URL Rewrite") 279 | } 280 | newUrl, err = url.JoinPath(newUrl, parsed.Path) 281 | if err != nil { 282 | logger.Error(). 283 | Err(err). 284 | Str("url", urlString). 285 | Bool("enabled", false). 286 | Str("disabled_reason", "failed to join host URL to path"). 287 | Msg("Cache URL Rewrite") 288 | return urlString 289 | } 290 | logger.Info(). 291 | Str("url", urlString). 292 | Str("target_url", newUrl). 293 | Bool("enabled", true).
294 | Msg("Cache URL Rewrite") 295 | return newUrl 296 | } 297 | -------------------------------------------------------------------------------- /pkg/download/buffer_slow_test.go: -------------------------------------------------------------------------------- 1 | //go:build slow 2 | // +build slow 3 | 4 | package download_test 5 | 6 | import ( 7 | "github.com/dustin/go-humanize" 8 | 9 | "testing" 10 | ) 11 | 12 | func BenchmarkDownload10G(b *testing.B) { 13 | benchmarkDownloadSingleFile(defaultOpts, 10*humanize.GiByte, b) 14 | } 15 | func BenchmarkDownload10GH2(b *testing.B) { 16 | benchmarkDownloadSingleFile(http2Opts, 10*humanize.GiByte, b) 17 | } 18 | 19 | func BenchmarkDownloadDollyTensors(b *testing.B) { 20 | benchmarkDownloadURL(defaultOpts, "https://weights.replicate.delivery/default/dolly-v2-12b-fp16.tensors", b) 21 | } 22 | -------------------------------------------------------------------------------- /pkg/download/buffer_test.go: -------------------------------------------------------------------------------- 1 | package download_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/rs/zerolog" 8 | "github.com/stretchr/testify/assert" 9 | 10 | "github.com/replicate/pget/pkg/client" 11 | "github.com/replicate/pget/pkg/download" 12 | ) 13 | 14 | func init() { 15 | zerolog.SetGlobalLevel(zerolog.WarnLevel) 16 | } 17 | 18 | var defaultOpts = download.Options{Client: client.Options{}} 19 | var http2Opts = download.Options{Client: client.Options{TransportOpts: client.TransportOptions{ForceHTTP2: true}}} 20 | 21 | func benchmarkDownloadURL(opts download.Options, url string, b *testing.B) { 22 | bufferMode := download.GetBufferMode(opts) 23 | 24 | for n := 0; n < b.N; n++ { 25 | ctx, cancel := context.WithCancel(context.Background()) 26 | defer cancel() 27 | 28 | _, _, err := bufferMode.Fetch(ctx, url) 29 | assert.NoError(b, err) 30 | } 31 | } 32 | 33 | func BenchmarkDownloadBertH1(b *testing.B) { 34 | benchmarkDownloadURL(defaultOpts, "https://weights.replicate.delivery/default/bert-base-uncased-hf-cache.tar", b) 35 | } 36 | func BenchmarkDownloadBertH2(b *testing.B) { 37 | benchmarkDownloadURL(http2Opts, "https://weights.replicate.delivery/default/bert-base-uncased-hf-cache.tar", b) 38 | } 39 | func BenchmarkDownloadLlama7bChatH1(b *testing.B) { 40 | benchmarkDownloadURL(defaultOpts, "https://weights.replicate.delivery/default/Llama-2-7b-Chat-GPTQ/gptq_model-4bit-32g.safetensors", b) 41 | } 42 | func BenchmarkDownloadLlama7bChatH2(b *testing.B) { 43 | benchmarkDownloadURL(http2Opts, "https://weights.replicate.delivery/default/Llama-2-7b-Chat-GPTQ/gptq_model-4bit-32g.safetensors", b) 44 | } 45 | -------------------------------------------------------------------------------- /pkg/download/buffer_unit_test.go: -------------------------------------------------------------------------------- 1 | package download 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "math/rand" 8 | "net/http" 9 | "net/http/httptest" 10 | "net/url" 11 | "strings" 12 | "testing" 13 | "testing/fstest" 14 | 15 | "github.com/dustin/go-humanize" 16 | "github.com/jarcoal/httpmock" 17 | "github.com/rs/zerolog" 18 | "github.com/stretchr/testify/assert" 19 | "github.com/stretchr/testify/require" 20 | 21 | "github.com/replicate/pget/pkg/client" 22 | ) 23 | 24 | func init() { 25 | zerolog.SetGlobalLevel(zerolog.WarnLevel) 26 | } 27 | 28 | const testFilePath = "test.txt" 29 | 30 | // generateTestContent returns a byte slice of the given size filled with random bytes 31 | func generateTestContent(size int64) []byte {
32 | content := make([]byte, size) 33 | // Generate random bytes and write them to the content slice 34 | for i := range content { 35 | content[i] = byte(rand.Intn(256)) 36 | } 37 | return content 38 | 39 | } 40 | 41 | // newTestServer creates a new http server that serves the given content 42 | func newTestServer(t *testing.T, content []byte) *httptest.Server { 43 | testFileSystem := fstest.MapFS{testFilePath: {Data: content}} 44 | server := httptest.NewServer(http.FileServer(http.FS(testFileSystem))) 45 | return server 46 | } 47 | 48 | func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) { 49 | contentSize := int64(humanize.KiByte) 50 | content := generateTestContent(contentSize) 51 | server := newTestServer(t, content) 52 | defer server.Close() 53 | opts := Options{ 54 | Client: client.Options{}, 55 | } 56 | // Ensure that the math generally works out as such for this test case where chunkSize is < 0.5* contentSize 57 | // (contentSize - chunkSize) / chunkSize < maxChunks 58 | // This ensures that we're always testing the case where the number of chunks exceeds the maxChunks 59 | // Additional cases added to validate various cases where the final chunk is less than chunkSize 60 | tc := []struct { 61 | name string 62 | maxConcurrency int 63 | chunkSize int64 64 | }{ 65 | // In these first cases we will never have more than 2 chunks as the chunkSize is greater than 0.5*contentSize 66 | { 67 | name: "chunkSize greater than contentSize", 68 | chunkSize: contentSize + 1, 69 | maxConcurrency: 1, 70 | }, 71 | { 72 | name: "chunkSize equal to contentSize", 73 | chunkSize: contentSize, 74 | maxConcurrency: 1, 75 | }, 76 | { 77 | name: "chunkSize less than contentSize", 78 | chunkSize: contentSize - 1, 79 | maxConcurrency: 2, 80 | }, 81 | { 82 | name: "chunkSize is 3/4 contentSize", 83 | chunkSize: int64(float64(contentSize) * 0.75), 84 | maxConcurrency: 2, 85 | }, 86 | { 87 | // This is an exceptional case where we only need a single additional chunk beyond the default "get content size" 88 | // request. 89 | name: "chunkSize is 1/2 contentSize", 90 | chunkSize: int64(float64(contentSize) * 0.5), 91 | maxConcurrency: 2, 92 | }, 93 | // These test cases cover a few scenarios of downloading where the maxChunks will force a re-calculation of 94 | // the chunkSize to ensure that we don't exceed the maxChunks. 
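// (Worked numbers for the first case below: contentSize is 1 KiB = 1024 bytes, so a
// chunkSize of 256 leaves 1024-256 = 768 bytes after the first chunk, i.e. 3 more
// chunks, which is more than the maxConcurrency of 2.)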
95 | { 96 | // remainder will result in 3 chunks, max-chunks is 2 97 | name: "chunkSize is 1/4 contentSize", 98 | chunkSize: int64(float64(contentSize) * 0.25), 99 | maxConcurrency: 2, 100 | }, 101 | { 102 | // humanize.KiByte = 1024, remainder will result in 1024/10 = 102 chunks, concurrency is set to 25 103 | // resulting in a chunkSize of 41 104 | name: "many chunks, low maxConcurrency", 105 | chunkSize: 10, 106 | maxConcurrency: 25, 107 | }, 108 | } 109 | 110 | for _, tc := range tc { 111 | t.Run(tc.name, func(t *testing.T) { 112 | opts.MaxConcurrency = tc.maxConcurrency 113 | opts.ChunkSize = tc.chunkSize 114 | bufferMode := GetBufferMode(opts) 115 | path, _ := url.JoinPath(server.URL, testFilePath) 116 | download, size, err := bufferMode.Fetch(context.Background(), path) 117 | require.NoError(t, err) 118 | data, err := io.ReadAll(download) 119 | assert.NoError(t, err) 120 | assert.Equal(t, contentSize, size) 121 | assert.Equal(t, len(content), len(data)) 122 | assert.Equal(t, content, data) 123 | }) 124 | } 125 | } 126 | 127 | func TestReaderReturnsErrorWhenRequestFails(t *testing.T) { 128 | mockTransport := httpmock.NewMockTransport() 129 | opts := Options{ 130 | Client: client.Options{Transport: mockTransport}, 131 | ChunkSize: 2, 132 | } 133 | expectedErr := fmt.Errorf("Expected error in chunk 3") 134 | mockTransport.RegisterResponder("GET", "http://test.example/hello.txt", 135 | func(req *http.Request) (*http.Response, error) { 136 | rangeHeader := req.Header.Get("Range") 137 | var body string 138 | switch rangeHeader { 139 | case "bytes=0-1": 140 | body = "he" 141 | case "bytes=2-3": 142 | body = "ll" 143 | case "bytes=4-5": 144 | body = "o " 145 | case "bytes=6-7": 146 | return nil, expectedErr 147 | default: 148 | return nil, fmt.Errorf("shouldn't see this error") 149 | } 150 | resp := httpmock.NewStringResponse(http.StatusPartialContent, body) 151 | resp.Request = req 152 | resp.Header.Add("Content-Range", strings.Replace(rangeHeader, "=", " ", 1)+"/8") 153 | resp.ContentLength = 2 154 | resp.Header.Add("Content-Length", "2") 155 | return resp, nil 156 | }) 157 | bufferMode := GetBufferMode(opts) 158 | download, _, err := bufferMode.Fetch(context.Background(), "http://test.example/hello.txt") 159 | // No error here, because the first chunk was fetched successfully 160 | require.NoError(t, err) 161 | // the expected error should surface on read 162 | _, err = io.ReadAll(download) 163 | assert.ErrorIs(t, err, expectedErr) 164 | } 165 | 166 | func TestReaderHandlesFullFile(t *testing.T) { 167 | mockTransport := httpmock.NewMockTransport() 168 | opts := Options{ 169 | Client: client.Options{Transport: mockTransport}, 170 | ChunkSize: 6, 171 | } 172 | mockTransport.RegisterResponder("GET", "http://test.example/hello.txt", 173 | func(req *http.Request) (*http.Response, error) { 174 | rangeHeader := req.Header.Get("Range") 175 | var body string 176 | switch rangeHeader { 177 | case "bytes=0-5": 178 | body = "hello " 179 | default: 180 | return nil, fmt.Errorf("shouldn't see this error") 181 | } 182 | resp := httpmock.NewStringResponse(http.StatusOK, body) 183 | resp.Request = req 184 | resp.ContentLength = 6 185 | resp.Header.Add("Content-Length", "6") 186 | return resp, nil 187 | }) 188 | bufferMode := GetBufferMode(opts) 189 | download, _, err := bufferMode.Fetch(context.Background(), "http://test.example/hello.txt") 190 | // No error here, because the first chunk was fetched successfully 191 | require.NoError(t, err) 192 | // reading should succeed and return the whole file 193 | out,
err := io.ReadAll(download) 194 | assert.Equal(t, "hello ", string(out)) 195 | assert.NoError(t, err) 196 | } 197 | -------------------------------------------------------------------------------- /pkg/download/common.go: -------------------------------------------------------------------------------- 1 | package download 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "regexp" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/dustin/go-humanize" 13 | 14 | "github.com/replicate/pget/pkg/client" 15 | "github.com/replicate/pget/pkg/logging" 16 | ) 17 | 18 | const defaultChunkSize = 125 * humanize.MiByte 19 | 20 | var ( 21 | contentRangeRegexp = regexp.MustCompile(`^bytes .*/([0-9]+)$`) 22 | 23 | errMalformedRangeHeader = errors.New("malformed range header") 24 | errMissingRangeHeader = errors.New("missing range header") 25 | errInvalidContentRange = errors.New("invalid content range") 26 | ) 27 | 28 | func resumeDownload(req *http.Request, buffer []byte, client client.HTTPClient, bytesReceived int64) (int, error) { 29 | var startByte int 30 | logger := logging.GetLogger() 31 | 32 | var resumeCount = 1 33 | var initialBytesReceived = bytesReceived 34 | var totalBytesReceived = bytesReceived 35 | 36 | for { 37 | var n int 38 | if err := updateRangeRequestHeader(req, bytesReceived); err != nil { 39 | return int(totalBytesReceived), err 40 | } 41 | 42 | resp, err := client.Do(req) 43 | if err != nil { 44 | return int(totalBytesReceived), err 45 | } 46 | defer resp.Body.Close() 47 | if resp.StatusCode != http.StatusPartialContent { 48 | return int(totalBytesReceived), fmt.Errorf("expected status code %d, got %d", http.StatusPartialContent, resp.StatusCode) 49 | } 50 | n, err = io.ReadFull(resp.Body, buffer[startByte:]) 51 | totalBytesReceived += int64(n) 52 | if err == io.ErrUnexpectedEOF { 53 | bytesReceived = int64(n) 54 | startByte += n 55 | resumeCount++ 56 | logger.Warn(). 57 | Int("connection_interrupted_at_byte", n). 58 | Int("resume_count", resumeCount). 59 | Int64("total_bytes_received", initialBytesReceived+int64(startByte)). 
60 | Msg("Resuming Chunk Download") 61 | continue 62 | } 63 | return int(totalBytesReceived), err 64 | 65 | } 66 | } 67 | 68 | func updateRangeRequestHeader(req *http.Request, receivedBytes int64) error { 69 | rangeHeader := req.Header.Get("Range") 70 | if rangeHeader == "" { 71 | return errMissingRangeHeader 72 | } 73 | 74 | // Expected format: "bytes=start-end" 75 | if !strings.HasPrefix(rangeHeader, "bytes=") { 76 | return fmt.Errorf("%w: %s", errMalformedRangeHeader, rangeHeader) 77 | } 78 | 79 | rangeValues := strings.TrimPrefix(rangeHeader, "bytes=") 80 | parts := strings.Split(rangeValues, "-") 81 | if len(parts) != 2 { 82 | return fmt.Errorf("%w: %s", errMalformedRangeHeader, rangeHeader) 83 | } 84 | 85 | start, err := strconv.ParseInt(parts[0], 10, 64) 86 | if err != nil { 87 | return fmt.Errorf("%w: %s", errMalformedRangeHeader, rangeHeader) 88 | } 89 | 90 | end, err := strconv.ParseInt(parts[1], 10, 64) 91 | if err != nil { 92 | return fmt.Errorf("%w: %s", errMalformedRangeHeader, rangeHeader) 93 | } 94 | 95 | start = start + receivedBytes 96 | newRangeHeader := fmt.Sprintf("bytes=%d-%d", start, end) 97 | 98 | if start > end { 99 | return fmt.Errorf("%w: %s", errInvalidContentRange, newRangeHeader) 100 | } 101 | 102 | req.Header.Set("Range", newRangeHeader) 103 | 104 | return nil 105 | } 106 | -------------------------------------------------------------------------------- /pkg/download/common_test.go: -------------------------------------------------------------------------------- 1 | package download 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "sync/atomic" 10 | "testing" 11 | "time" 12 | 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | type mockHTTPClient struct { 18 | doFunc func(req *http.Request) (*http.Response, error) 19 | callCount atomic.Int32 20 | } 21 | 22 | func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { 23 | m.callCount.Add(1) 24 | return m.doFunc(req) 25 | } 26 | 27 | func TestResumeDownload(t *testing.T) { 28 | tests := []struct { 29 | name string 30 | serverContent string 31 | bytesReceived int64 32 | initialRange string 33 | expectedError error 34 | expectedOutput []byte 35 | expectedCalls int32 36 | }{ 37 | { 38 | name: "successful download", 39 | serverContent: "Hello, world!", 40 | bytesReceived: 0, 41 | initialRange: "bytes=0-12", 42 | expectedError: nil, 43 | expectedOutput: []byte("Hello, world!"), 44 | expectedCalls: 1, 45 | }, 46 | { 47 | name: "partial download", 48 | serverContent: "Hello, world!", 49 | bytesReceived: 3, 50 | initialRange: "bytes=7-12", 51 | expectedError: nil, 52 | expectedOutput: []byte("world!"), 53 | expectedCalls: 1, 54 | }, 55 | { 56 | name: "network error", 57 | serverContent: "Hello, world!", 58 | bytesReceived: 0, 59 | initialRange: "bytes=0-12", 60 | expectedError: errors.New("network error"), 61 | expectedOutput: nil, 62 | expectedCalls: 1, 63 | }, 64 | { 65 | name: "multi-pass download", 66 | serverContent: "12345678901234567890", 67 | bytesReceived: 3, 68 | initialRange: "bytes=10-19", 69 | expectedError: nil, 70 | expectedOutput: []byte("0123456789"), 71 | expectedCalls: 2, 72 | }, 73 | } 74 | 75 | for _, tt := range tests { 76 | t.Run(tt.name, func(t *testing.T) { 77 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 78 | http.ServeContent(w, r, "", time.Time{}, bytes.NewReader([]byte(tt.serverContent))) 79 | })) 80 | defer server.Close() 81 | 82 | 
req, err := http.NewRequest("GET", server.URL, nil) 83 | assert.NoError(t, err) 84 | 85 | // Set the initial Range header from the test case 86 | req.Header.Set("Range", tt.initialRange) 87 | 88 | buffer := make([]byte, len(tt.expectedOutput)) 89 | copy(buffer, tt.expectedOutput[:tt.bytesReceived]) 90 | mockClient := &mockHTTPClient{ 91 | doFunc: func(req *http.Request) (*http.Response, error) { 92 | if tt.name == "network error" { 93 | return nil, errors.New("network error") 94 | } 95 | if tt.name == "multi-pass download" { 96 | switch req.Header.Get("Range") { 97 | case "bytes=15-19": 98 | return &http.Response{ 99 | StatusCode: http.StatusPartialContent, 100 | Body: io.NopCloser(bytes.NewReader([]byte("56789"))), 101 | Header: http.Header{"Content-Range": []string{"bytes 15-20/21"}}, 102 | }, nil 103 | case "bytes=13-19": 104 | return &http.Response{ 105 | StatusCode: http.StatusPartialContent, 106 | Body: io.NopCloser(bytes.NewReader([]byte("34"))), 107 | Header: http.Header{"Content-Range": []string{"bytes 13-20/21"}}, 108 | }, nil 109 | } 110 | } 111 | return http.DefaultClient.Do(req) 112 | }, 113 | } 114 | 115 | totalBytesReceived, err := resumeDownload(req, buffer[tt.bytesReceived:], mockClient, tt.bytesReceived) 116 | if tt.expectedError != nil { 117 | assert.Error(t, err) 118 | assert.Equal(t, tt.expectedError.Error(), err.Error()) 119 | } else { 120 | assert.NoError(t, err) 121 | assert.Equal(t, len(tt.expectedOutput), totalBytesReceived) 122 | assert.Equal(t, tt.expectedOutput, buffer[:len(tt.expectedOutput)]) 123 | } 124 | assert.Equal(t, tt.expectedCalls, mockClient.callCount.Load(), "Unexpected number of HTTP client calls") 125 | }) 126 | } 127 | } 128 | 129 | func TestUpdateRangeRequestHeader(t *testing.T) { 130 | tests := []struct { 131 | name string 132 | initialRange string 133 | receivedBytes int64 134 | expectedRange string 135 | expectedError error 136 | }{ 137 | { 138 | name: "valid range header", 139 | initialRange: "bytes=0-10", 140 | receivedBytes: 5, 141 | expectedRange: "bytes=5-10", 142 | expectedError: nil, 143 | }, 144 | { 145 | name: "non-zero initial range", 146 | initialRange: "bytes=7-12", 147 | receivedBytes: 3, 148 | expectedRange: "bytes=10-12", 149 | expectedError: nil, 150 | }, 151 | { 152 | name: "missing range header", 153 | initialRange: "", 154 | receivedBytes: 5, 155 | expectedRange: "", 156 | expectedError: errMissingRangeHeader, 157 | }, 158 | { 159 | name: "malformed range header", 160 | initialRange: "bytes=malformed", 161 | receivedBytes: 5, 162 | expectedRange: "", 163 | expectedError: errMalformedRangeHeader, 164 | }, 165 | { 166 | name: "receivedBytes exceeds range", 167 | initialRange: "bytes=0-10", 168 | receivedBytes: 15, 169 | expectedRange: "", 170 | expectedError: errInvalidContentRange, 171 | }, 172 | } 173 | 174 | for _, tt := range tests { 175 | t.Run(tt.name, func(t *testing.T) { 176 | req, err := http.NewRequest("GET", "http://example.com", nil) 177 | assert.NoError(t, err) 178 | req.Header.Set("Range", tt.initialRange) 179 | 180 | err = updateRangeRequestHeader(req, tt.receivedBytes) 181 | if tt.expectedError != nil { 182 | require.Error(t, err) 183 | assert.ErrorIs(t, err, tt.expectedError) 184 | } else { 185 | assert.NoError(t, err) 186 | assert.Equal(t, tt.expectedRange, req.Header.Get("Range")) 187 | } 188 | }) 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /pkg/download/consistent_hashing.go: 
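A minimal usage sketch before the implementation (cache host addresses, URL, and slice size are hypothetical; the option fields and constructor are the ones defined below and exercised in the tests): downloads for URLs under a cacheable prefix are split into SliceSize slices, each consistently hashed to one of the configured cache hosts.

```go
package main

import (
	"context"
	"io"
	"log"
	"net/url"

	"github.com/replicate/pget/pkg/client"
	"github.com/replicate/pget/pkg/download"
)

func main() {
	prefix, err := url.Parse("http://weights.example.com/")
	if err != nil {
		log.Fatal(err)
	}

	strategy, err := download.GetConsistentHashingMode(download.Options{
		Client:     client.Options{},
		CacheHosts: []string{"cache-0:8080", "cache-1:8080", "cache-2:8080"},
		// Only URLs under these prefixes are routed through the cache hosts;
		// everything else uses the fallback (BufferMode) strategy.
		CacheableURIPrefixes: map[string][]*url.URL{prefix.Host: {prefix}},
		SliceSize:            64 * 1024 * 1024, // hash the file to hosts in 64 MiB slices
	})
	if err != nil {
		log.Fatal(err) // a SliceSize of 0 is rejected
	}

	reader, size, err := strategy.Fetch(context.Background(), "http://weights.example.com/model.bin")
	if err != nil {
		log.Fatal(err)
	}
	n, err := io.Copy(io.Discard, reader) // consume the stream
	if err != nil || n != size {
		log.Fatalf("read %d of %d bytes: %v", n, size, err)
	}
}
```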
-------------------------------------------------------------------------------- 1 | package download 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "net/url" 10 | "strconv" 11 | "strings" 12 | 13 | "github.com/replicate/pget/pkg/client" 14 | "github.com/replicate/pget/pkg/config" 15 | "github.com/replicate/pget/pkg/consistent" 16 | "github.com/replicate/pget/pkg/logging" 17 | ) 18 | 19 | type ConsistentHashingMode struct { 20 | Client client.HTTPClient 21 | Options 22 | // TODO: allow this to be configured and not just "BufferMode" 23 | FallbackStrategy Strategy 24 | 25 | queue *priorityWorkQueue 26 | } 27 | 28 | type CacheKey struct { 29 | URL *url.URL `hash:"string"` 30 | Slice int64 31 | } 32 | 33 | func GetConsistentHashingMode(opts Options) (*ConsistentHashingMode, error) { 34 | if opts.SliceSize == 0 { 35 | return nil, fmt.Errorf("must specify slice size in consistent hashing mode") 36 | } 37 | client := client.NewHTTPClient(opts.Client) 38 | 39 | fallbackStrategy := &BufferMode{ 40 | Client: client, 41 | // Do not pass cache-related options to the fallback strategy 42 | Options: Options{ 43 | Client: opts.Client, 44 | ChunkSize: opts.ChunkSize, 45 | MaxConcurrency: opts.MaxConcurrency, 46 | }, 47 | } 48 | 49 | m := &ConsistentHashingMode{ 50 | Client: client, 51 | Options: opts, 52 | FallbackStrategy: fallbackStrategy, 53 | } 54 | m.queue = newWorkQueue(opts.maxConcurrency(), m.chunkSize()) 55 | m.queue.start() 56 | fallbackStrategy.queue = m.queue 57 | return m, nil 58 | } 59 | 60 | func (m *ConsistentHashingMode) chunkSize() int64 { 61 | chunkSize := m.ChunkSize 62 | if chunkSize == 0 { 63 | chunkSize = defaultChunkSize 64 | } 65 | if chunkSize > m.SliceSize { 66 | chunkSize = m.SliceSize 67 | } 68 | return chunkSize 69 | } 70 | 71 | func (m *ConsistentHashingMode) getFileSizeFromResponse(resp *http.Response) (int64, error) { 72 | // If the response is a 200 OK, we need to parse the file size assuming the whole 73 | // file was returned. If it isn't, we will assume this was a 206 Partial Content 74 | // and parse the file size from the content range header. 
We wouldn't be in this 75 | // function if the response was not between 200 and 300, so this feels like a 76 | // reasonable assumption 77 | if resp.StatusCode == http.StatusOK { 78 | return m.getFileSizeFromContentLength(resp.Header.Get("Content-Length")) 79 | } 80 | return m.getFileSizeFromContentRange(resp.Header.Get("Content-Range")) 81 | } 82 | 83 | func (m *ConsistentHashingMode) getFileSizeFromContentLength(contentLength string) (int64, error) { 84 | size, err := strconv.ParseInt(contentLength, 10, 64) 85 | if err != nil { 86 | return 0, err 87 | } 88 | 89 | return size, nil 90 | } 91 | 92 | func (m *ConsistentHashingMode) getFileSizeFromContentRange(contentRange string) (int64, error) { 93 | groups := contentRangeRegexp.FindStringSubmatch(contentRange) 94 | if groups == nil { 95 | return -1, fmt.Errorf("couldn't parse Content-Range: %s", contentRange) 96 | } 97 | return strconv.ParseInt(groups[1], 10, 64) 98 | } 99 | 100 | func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io.Reader, int64, error) { 101 | logger := logging.GetLogger() 102 | 103 | parsed, err := url.Parse(urlString) 104 | if err != nil { 105 | return nil, -1, err 106 | } 107 | shouldContinue := false 108 | if prefixes, ok := m.CacheableURIPrefixes[parsed.Host]; ok { 109 | for _, pfx := range prefixes { 110 | if pfx.Path == "/" || strings.HasPrefix(parsed.Path, pfx.Path) { 111 | shouldContinue = true 112 | break 113 | } 114 | } 115 | } 116 | // Use our fallback mode if we're not downloading from a consistent-hashing enabled domain 117 | if !shouldContinue { 118 | logger.Debug(). 119 | Str("url", urlString). 120 | Str("reason", fmt.Sprintf("consistent hashing not enabled for %s", parsed.Host)). 121 | Msg("fallback strategy") 122 | return m.FallbackStrategy.Fetch(ctx, urlString) 123 | } 124 | 125 | firstChunk := newReaderPromise() 126 | firstReqResultCh := make(chan firstReqResult) 127 | m.queue.submitLow(func(buf []byte) { 128 | defer close(firstReqResultCh) 129 | firstChunkResp, err := m.DoRequest(ctx, 0, m.chunkSize()-1, urlString) 130 | if err != nil { 131 | firstReqResultCh <- firstReqResult{err: err} 132 | return 133 | } 134 | defer firstChunkResp.Body.Close() 135 | 136 | fileSize, err := m.getFileSizeFromResponse(firstChunkResp) 137 | if err != nil { 138 | firstReqResultCh <- firstReqResult{err: err} 139 | return 140 | } 141 | firstReqResultCh <- firstReqResult{fileSize: fileSize} 142 | 143 | contentLength := firstChunkResp.ContentLength 144 | n, err := io.ReadFull(firstChunkResp.Body, buf[0:contentLength]) 145 | if err == io.ErrUnexpectedEOF { 146 | logger.Warn(). 147 | Int("connection_interrupted_at_byte", n). 148 | Msg("Resuming Chunk Download") 149 | n, err = resumeDownload(firstChunkResp.Request, buf[n:contentLength], m.Client, int64(n)) 150 | } 151 | firstChunk.Deliver(buf[0:n], err) 152 | }) 153 | firstReqResult, ok := <-firstReqResultCh 154 | if !ok { 155 | panic("logic error in ConsistentHashingMode: first request didn't return any output") 156 | } 157 | if firstReqResult.err != nil { 158 | // If the error indicates a problem with the cache server, networking, etc., 159 | // fall back. In this case the whole file is downloaded via the fallback 160 | // strategy. 161 | if errors.Is(firstReqResult.err, client.ErrStrategyFallback) { 162 | // TODO(morgan): we should indicate the fallback strategy we're using in the logs 163 | logger.Info(). 164 | Str("url", urlString). 165 | Str("type", "file"). 166 | Err(firstReqResult.err).
167 | Msg("consistent hash fallback") 168 | return m.FallbackStrategy.Fetch(ctx, urlString) 169 | } 170 | return nil, -1, firstReqResult.err 171 | } 172 | fileSize := firstReqResult.fileSize 173 | 174 | if fileSize <= m.chunkSize() { 175 | // we only need a single chunk: just download it and finish 176 | return firstChunk, fileSize, nil 177 | } 178 | 179 | totalSlices := fileSize / m.SliceSize 180 | if fileSize%m.SliceSize != 0 { 181 | totalSlices++ 182 | } 183 | 184 | readers := make([]io.Reader, 0) 185 | slices := make([][]*readerPromise, totalSlices) 186 | logger.Debug().Str("url", urlString). 187 | Int64("size", fileSize). 188 | Int("concurrency", m.maxConcurrency()). 189 | Msg("Downloading") 190 | 191 | for slice := 0; slice < int(totalSlices); slice++ { 192 | sliceSize := m.SliceSize 193 | if slice == int(totalSlices)-1 { 194 | sliceSize = (fileSize-1)%m.SliceSize + 1 195 | } 196 | // integer divide rounding up 197 | numChunks := int(((sliceSize - 1) / m.chunkSize()) + 1) 198 | chunks := make([]*readerPromise, numChunks) 199 | for i := 0; i < numChunks; i++ { 200 | var chunk *readerPromise 201 | if slice == 0 && i == 0 { 202 | chunk = firstChunk 203 | } else { 204 | chunk = newReaderPromise() 205 | } 206 | chunks[i] = chunk 207 | readers = append(readers, chunk) 208 | } 209 | slices[slice] = chunks 210 | } 211 | go m.downloadRemainingChunks(ctx, urlString, slices) 212 | return io.MultiReader(readers...), fileSize, nil 213 | } 214 | 215 | func (m *ConsistentHashingMode) downloadRemainingChunks(ctx context.Context, urlString string, slices [][]*readerPromise) { 216 | logger := logging.GetLogger() 217 | for slice, sliceChunks := range slices { 218 | sliceStart := m.SliceSize * int64(slice) 219 | sliceEnd := m.SliceSize*int64(slice+1) - 1 220 | for i, chunk := range sliceChunks { 221 | if slice == 0 && i == 0 { 222 | // this is the first chunk, already handled above 223 | continue 224 | } 225 | m.queue.submitHigh(func(buf []byte) { 226 | chunkStart := sliceStart + int64(i)*m.chunkSize() 227 | chunkEnd := chunkStart + m.chunkSize() - 1 228 | if chunkEnd > sliceEnd { 229 | chunkEnd = sliceEnd 230 | } 231 | 232 | logger.Debug().Int64("start", chunkStart).Int64("end", chunkEnd).Msg("starting request") 233 | resp, err := m.DoRequest(ctx, chunkStart, chunkEnd, urlString) 234 | if err != nil { 235 | // If the error indicates a problem with the cache server, networking, etc., 236 | // fall back for this chunk only: unlike the file-level fallback in Fetch, 237 | // just the affected chunk is fetched via the fallback strategy. 238 | if errors.Is(err, client.ErrStrategyFallback) { 239 | // TODO(morgan): we should indicate the fallback strategy we're using in the logs 240 | logger.Info(). 241 | Str("url", urlString). 242 | Str("type", "chunk"). 243 | Err(err). 244 | Msg("consistent hash fallback") 245 | resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString) 246 | } 247 | if err != nil { 248 | chunk.Deliver(nil, err) 249 | return 250 | } 251 | } 252 | defer resp.Body.Close() 253 | contentLength := resp.ContentLength 254 | n, err := io.ReadFull(resp.Body, buf[0:contentLength]) 255 | if err == io.ErrUnexpectedEOF { 256 | logger.Warn(). 257 | Int("connection_interrupted_at_byte", n).
258 | Msg("Resuming Chunk Download") 259 | n, err = resumeDownload(resp.Request, buf[n:contentLength], m.Client, int64(n)) 260 | } 261 | chunk.Deliver(buf[0:n], err) 262 | }) 263 | } 264 | } 265 | } 266 | 267 | func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) { 268 | chContext := context.WithValue(ctx, config.ConsistentHashingStrategyKey, true) 269 | req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil) 270 | if err != nil { 271 | return nil, fmt.Errorf("failed to download %s: %w", urlString, err) 272 | } 273 | resp, cachePodIndex, err := m.doRequestToCacheHost(req, urlString, start, end) 274 | if err != nil { 275 | if errors.Is(err, client.ErrStrategyFallback) { 276 | origErr := err 277 | req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil) 278 | if err != nil { 279 | return nil, fmt.Errorf("failed to download %s: %w", urlString, err) 280 | } 281 | resp, _, err = m.doRequestToCacheHost(req, urlString, start, end, cachePodIndex) 282 | if err != nil { 283 | // return origErr so that we can use our regular fallback strategy 284 | return nil, origErr 285 | } 286 | } else { 287 | return nil, fmt.Errorf("error executing request for %s: %w", req.URL.String(), err) 288 | } 289 | } 290 | if resp.StatusCode == 0 || resp.StatusCode < 200 || resp.StatusCode >= 300 { 291 | return nil, fmt.Errorf("%w %s: %s", ErrUnexpectedHTTPStatus, req.URL.String(), resp.Status) 292 | } 293 | 294 | return resp, nil 295 | } 296 | 297 | func (m *ConsistentHashingMode) doRequestToCacheHost(req *http.Request, urlString string, start int64, end int64, previousPodIndexes ...int) (*http.Response, int, error) { 298 | logger := logging.GetLogger() 299 | cachePodIndex, err := m.rewriteRequestToCacheHost(req, start, end, previousPodIndexes...) 300 | if err != nil { 301 | return nil, cachePodIndex, err 302 | } 303 | req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) 304 | 305 | logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("request") 306 | 307 | resp, err := m.Client.Do(req) 308 | return resp, cachePodIndex, err 309 | } 310 | 311 | func (m *ConsistentHashingMode) rewriteRequestToCacheHost(req *http.Request, start int64, end int64, previousPodIndexes ...int) (int, error) { 312 | logger := logging.GetLogger() 313 | if start/m.SliceSize != end/m.SliceSize { 314 | return 0, fmt.Errorf("internal error: can't make a range request across a slice boundary: %d-%d straddles a slice boundary (slice size is %d)", start, end, m.SliceSize) 315 | } 316 | slice := start / m.SliceSize 317 | 318 | key := CacheKey{URL: req.URL, Slice: slice} 319 | 320 | cachePodIndex, err := consistent.HashBucket(key, len(m.CacheHosts), previousPodIndexes...) 321 | if err != nil { 322 | return -1, err 323 | } 324 | if m.CacheUsePathProxy { 325 | // prepend the hostname to the start of the path. The consistent-hash nodes will use this to determine the proxy 326 | newPath, err := url.JoinPath(strings.ToLower(req.URL.Host), req.URL.Path) 327 | if err != nil { 328 | return -1, err 329 | } 330 | // Ensure we have a leading slash; things get weird (especially in testing) if we do not. 331 | req.URL.Path = fmt.Sprintf("/%s", newPath) 332 | } 333 | cacheHost := m.CacheHosts[cachePodIndex] 334 | if cacheHost == "" { 335 | // this can happen if an SRV record is missing due to a not-ready pod 336 | logger.Debug().
337 | Str("cache_key", fmt.Sprintf("%+v", key)). 338 | Int64("start", start). 339 | Int64("end", end). 340 | Int64("slice_size", m.SliceSize). 341 | Int("bucket", cachePodIndex). 342 | Ints("previous_pod_indexes", previousPodIndexes). 343 | Msg("cache host for bucket not ready, falling back") 344 | return cachePodIndex, client.ErrStrategyFallback 345 | } 346 | logger.Debug(). 347 | Str("cache_key", fmt.Sprintf("%+v", key)). 348 | Int64("start", start). 349 | Int64("end", end). 350 | Int64("slice_size", m.SliceSize). 351 | Int("bucket", cachePodIndex). 352 | Ints("previous_pod_indexes", previousPodIndexes). 353 | Msg("consistent hashing") 354 | req.URL.Scheme = "http" 355 | req.URL.Host = cacheHost 356 | 357 | return cachePodIndex, nil 358 | } 359 | -------------------------------------------------------------------------------- /pkg/download/consistent_hashing_test.go: -------------------------------------------------------------------------------- 1 | package download_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "net/url" 10 | "regexp" 11 | "strconv" 12 | "strings" 13 | "sync" 14 | "testing" 15 | "testing/fstest" 16 | 17 | "github.com/jarcoal/httpmock" 18 | "github.com/stretchr/testify/assert" 19 | "github.com/stretchr/testify/require" 20 | 21 | "github.com/replicate/pget/pkg/client" 22 | "github.com/replicate/pget/pkg/download" 23 | ) 24 | 25 | var testFSes = []fstest.MapFS{ 26 | {"hello.txt": {Data: []byte("0000000000000000")}}, 27 | {"hello.txt": {Data: []byte("1111111111111111")}}, 28 | {"hello.txt": {Data: []byte("2222222222222222")}}, 29 | {"hello.txt": {Data: []byte("3333333333333333")}}, 30 | {"hello.txt": {Data: []byte("4444444444444444")}}, 31 | {"hello.txt": {Data: []byte("5555555555555555")}}, 32 | {"hello.txt": {Data: []byte("6666666666666666")}}, 33 | {"hello.txt": {Data: []byte("7777777777777777")}}, 34 | } 35 | 36 | type chTestCase struct { 37 | name string 38 | concurrency int 39 | sliceSize int64 40 | chunkSize int64 41 | numCacheHosts int 42 | expectedOutput string 43 | } 44 | 45 | // rangeResponder is an httpmock.Responder that implements enough of HTTP range 46 | // requests for our purposes. 
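// As a worked example (derived from the code below): given the body "hello world",
// a request with "Range: bytes=2-5" produces a 206 response whose body is body[2:6],
// i.e. "llo ", with Content-Range "bytes 2-5/11" and Content-Length 4.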
47 | func rangeResponder(status int, body string) httpmock.Responder { 48 | rangeHeaderRegexp := regexp.MustCompile("^bytes=([0-9]+)-([0-9]+)$") 49 | return func(req *http.Request) (*http.Response, error) { 50 | rangeHeader := req.Header.Get("Range") 51 | if rangeHeader == "" { 52 | return httpmock.NewStringResponse(status, body), nil 53 | } 54 | rangePair := rangeHeaderRegexp.FindStringSubmatch(rangeHeader) 55 | if rangePair == nil { 56 | return httpmock.NewStringResponse(http.StatusBadRequest, "bad range header"), nil 57 | } 58 | from, err := strconv.Atoi(rangePair[1]) 59 | if err != nil { 60 | return httpmock.NewStringResponse(http.StatusBadRequest, "bad range header"), nil 61 | } 62 | to, err := strconv.Atoi(rangePair[2]) 63 | if err != nil { 64 | return httpmock.NewStringResponse(http.StatusBadRequest, "bad range header"), nil 65 | } 66 | // HTTP range header indexes are inclusive; we increment `to` so we have 67 | // inclusive from, exclusive to for use with slice ranges 68 | to++ 69 | 70 | if from < 0 || from > to || from > len(body) || to < 0 { 71 | return httpmock.NewStringResponse(http.StatusRequestedRangeNotSatisfiable, "range unsatisfiable"), nil 72 | } 73 | if to > len(body) { 74 | to = len(body) 75 | } 76 | 77 | resp := httpmock.NewStringResponse(http.StatusPartialContent, body[from:to]) 78 | resp.Request = req 79 | resp.Header.Add("Content-Range", fmt.Sprintf("bytes %d-%d/%d", from, to-1, len(body))) 80 | resp.ContentLength = int64(to - from) 81 | resp.Header.Add("Content-Length", fmt.Sprint(resp.ContentLength)) 82 | return resp, nil 83 | } 84 | } 85 | 86 | // fakeCacheHosts creates an *httpmock.MockTransport with preregistered 87 | // responses to each of numberOfHosts distinct hostnames for the path 88 | // /hello.txt. The response will be bodyLength copies of a single character 89 | // corresponding to the base-36 index of the cache host, starting 0-9, then a-z. 
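// For example, fakeCacheHosts(3, 4) registers cache-host-0, cache-host-1 and
// cache-host-2, serving "0000", "1111" and "2222" respectively.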
90 | func fakeCacheHosts(numberOfHosts int, bodyLength int) (hostnames []string, transport *httpmock.MockTransport) { 91 | if numberOfHosts > 36 { 92 | panic("can't have more than 36 fake cache hosts, would overflow the base-36 body") 93 | } 94 | hostnames = make([]string, numberOfHosts) 95 | mockTransport := httpmock.NewMockTransport() 96 | 97 | for i := range hostnames { 98 | hostnames[i] = fmt.Sprintf("cache-host-%d", i) 99 | mockTransport.RegisterResponder("GET", fmt.Sprintf("http://%s/hello.txt", hostnames[i]), 100 | rangeResponder(200, strings.Repeat(strconv.FormatInt(int64(i), 36), bodyLength))) 101 | } 102 | return hostnames, mockTransport 103 | } 104 | 105 | var chTestCases = []chTestCase{ 106 | { // pre-computed demo that only some slices change as we add a new cache host 107 | name: "1 host", 108 | concurrency: 8, 109 | sliceSize: 3, 110 | numCacheHosts: 1, 111 | chunkSize: 1, 112 | expectedOutput: "0000000000000000", 113 | }, 114 | { 115 | name: "2 hosts", 116 | concurrency: 8, 117 | sliceSize: 3, 118 | numCacheHosts: 2, 119 | chunkSize: 1, 120 | expectedOutput: "1111110000000000", 121 | }, 122 | { 123 | name: "3 hosts", 124 | concurrency: 8, 125 | sliceSize: 3, 126 | numCacheHosts: 3, 127 | chunkSize: 1, 128 | expectedOutput: "2221110000002222", 129 | }, 130 | { 131 | name: "4 hosts", 132 | concurrency: 8, 133 | sliceSize: 3, 134 | numCacheHosts: 4, 135 | chunkSize: 1, 136 | expectedOutput: "3331113333332222", 137 | }, 138 | { 139 | name: "5 hosts", 140 | concurrency: 8, 141 | sliceSize: 3, 142 | numCacheHosts: 5, 143 | chunkSize: 1, 144 | expectedOutput: "3334443333332224", 145 | }, 146 | { 147 | name: "6 hosts", 148 | concurrency: 8, 149 | sliceSize: 3, 150 | numCacheHosts: 6, 151 | chunkSize: 1, 152 | expectedOutput: "3334443333335554", 153 | }, 154 | { 155 | name: "7 hosts", 156 | concurrency: 8, 157 | sliceSize: 3, 158 | numCacheHosts: 7, 159 | chunkSize: 1, 160 | expectedOutput: "3334446666665556", 161 | }, 162 | { 163 | name: "8 hosts", 164 | concurrency: 8, 165 | sliceSize: 3, 166 | numCacheHosts: 8, 167 | chunkSize: 1, 168 | expectedOutput: "3334446666667776", 169 | }, 170 | { 171 | name: "test when fileSize % sliceSize == 0", 172 | concurrency: 8, 173 | sliceSize: 4, 174 | numCacheHosts: 8, 175 | chunkSize: 1, 176 | expectedOutput: "3333444466666666", 177 | }, 178 | { 179 | name: "when chunkSize == sliceSize", 180 | concurrency: 8, 181 | sliceSize: 3, 182 | numCacheHosts: 8, 183 | chunkSize: 3, 184 | expectedOutput: "3334446666667776", 185 | }, 186 | { 187 | name: "test when concurrency > file size", 188 | concurrency: 24, 189 | sliceSize: 3, 190 | numCacheHosts: 8, 191 | chunkSize: 3, 192 | expectedOutput: "3334446666667776", 193 | }, 194 | { 195 | name: "test when concurrency < number of slices", 196 | concurrency: 3, 197 | sliceSize: 3, 198 | numCacheHosts: 8, 199 | chunkSize: 3, 200 | expectedOutput: "3334446666667776", 201 | }, 202 | { 203 | name: "test when chunkSize == file size", 204 | concurrency: 4, 205 | sliceSize: 16, 206 | numCacheHosts: 8, 207 | chunkSize: 16, 208 | expectedOutput: "3333333333333333", 209 | }, 210 | { 211 | name: "test when chunkSize slightly below file size", 212 | concurrency: 4, 213 | sliceSize: 16, 214 | numCacheHosts: 8, 215 | chunkSize: 15, 216 | expectedOutput: "3333333333333333", 217 | }, 218 | { 219 | name: "test when chunkSize > file size", 220 | concurrency: 4, 221 | sliceSize: 24, 222 | numCacheHosts: 8, 223 | chunkSize: 24, 224 | expectedOutput: "3333333333333333", 225 | }, 226 | { 227 | name: "if chunkSize > sliceSize, 
sliceSize overrides it", 228 | concurrency: 8, 229 | sliceSize: 3, 230 | numCacheHosts: 8, 231 | chunkSize: 24, 232 | expectedOutput: "3334446666667776", 233 | }, 234 | } 235 | 236 | func makeCacheableURIPrefixes(uris ...string) map[string][]*url.URL { 237 | m := make(map[string][]*url.URL) 238 | for _, uri := range uris { 239 | parsed, err := url.Parse(uri) 240 | if err != nil { 241 | panic(err) 242 | } 243 | m[parsed.Host] = append(m[parsed.Host], parsed) 244 | } 245 | return m 246 | } 247 | 248 | func TestConsistentHashing(t *testing.T) { 249 | hostnames, mockTransport := fakeCacheHosts(8, 16) 250 | 251 | for _, tc := range chTestCases { 252 | t.Run(tc.name, func(t *testing.T) { 253 | opts := download.Options{ 254 | Client: client.Options{Transport: mockTransport}, 255 | MaxConcurrency: tc.concurrency, 256 | ChunkSize: tc.chunkSize, 257 | CacheHosts: hostnames[0:tc.numCacheHosts], 258 | CacheableURIPrefixes: makeCacheableURIPrefixes("http://test.replicate.com"), 259 | SliceSize: tc.sliceSize, 260 | } 261 | 262 | ctx, cancel := context.WithCancel(context.Background()) 263 | defer cancel() 264 | 265 | strategy, err := download.GetConsistentHashingMode(opts) 266 | require.NoError(t, err) 267 | 268 | assert.Equal(t, tc.numCacheHosts, len(strategy.Options.CacheHosts)) 269 | reader, _, err := strategy.Fetch(ctx, "http://test.replicate.com/hello.txt") 270 | require.NoError(t, err) 271 | bytes, err := io.ReadAll(reader) 272 | require.NoError(t, err) 273 | 274 | assert.Equal(t, tc.expectedOutput, string(bytes)) 275 | }) 276 | } 277 | } 278 | 279 | func validatePathPrefixMiddleware(t *testing.T, next http.Handler, hostname string) http.Handler { 280 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 281 | assert.Equal(t, hostname, r.Host) 282 | hostPfx := fmt.Sprintf("/%s", hostname) 283 | assert.True(t, strings.HasPrefix(r.URL.Path, hostPfx)) 284 | r.URL.Path = strings.TrimPrefix(r.URL.Path, hostPfx) 285 | next.ServeHTTP(w, r) 286 | }) 287 | } 288 | 289 | func TestConsistentHashingPathBased(t *testing.T) { 290 | var hostname = "test.replicate.com" 291 | hostnames := make([]string, len(testFSes)) 292 | for i, fs := range testFSes { 293 | validatePathPrefixAndStrip := validatePathPrefixMiddleware(t, http.FileServer(http.FS(fs)), hostname) 294 | ts := httptest.NewServer(validatePathPrefixAndStrip) 295 | defer ts.Close() 296 | url, err := url.Parse(ts.URL) 297 | require.NoError(t, err) 298 | hostnames[i] = url.Host 299 | } 300 | 301 | for _, tc := range chTestCases { 302 | t.Run(tc.name, func(t *testing.T) { 303 | opts := download.Options{ 304 | Client: client.Options{}, 305 | MaxConcurrency: tc.concurrency, 306 | ChunkSize: tc.chunkSize, 307 | CacheHosts: hostnames[0:tc.numCacheHosts], 308 | CacheableURIPrefixes: makeCacheableURIPrefixes(fmt.Sprintf("http://%s", hostname)), 309 | CacheUsePathProxy: true, 310 | SliceSize: tc.sliceSize, 311 | } 312 | 313 | ctx, cancel := context.WithCancel(context.Background()) 314 | defer cancel() 315 | 316 | strategy, err := download.GetConsistentHashingMode(opts) 317 | require.NoError(t, err) 318 | 319 | assert.Equal(t, tc.numCacheHosts, len(strategy.Options.CacheHosts)) 320 | reader, _, err := strategy.Fetch(ctx, fmt.Sprintf("http://%s/hello.txt", hostname)) 321 | require.NoError(t, err) 322 | bytes, err := io.ReadAll(reader) 323 | require.NoError(t, err) 324 | 325 | assert.Equal(t, tc.expectedOutput, string(bytes)) 326 | }) 327 | } 328 | } 329 | 330 | func TestConsistentHashRetries(t *testing.T) { 331 | hostnames, mockTransport := 
fakeCacheHosts(8, 16) 332 | // deliberately "break" one cache host 333 | hostnames[0] = "broken-host" 334 | mockTransport.RegisterResponder("GET", "http://broken-host/hello.txt", httpmock.NewStringResponder(503, "fake broken host")) 335 | 336 | opts := download.Options{ 337 | Client: client.Options{Transport: mockTransport}, 338 | MaxConcurrency: 8, 339 | ChunkSize: 1, 340 | CacheHosts: hostnames, 341 | CacheableURIPrefixes: makeCacheableURIPrefixes("http://fake.replicate.delivery"), 342 | SliceSize: 1, 343 | } 344 | 345 | ctx, cancel := context.WithCancel(context.Background()) 346 | defer cancel() 347 | 348 | strategy, err := download.GetConsistentHashingMode(opts) 349 | require.NoError(t, err) 350 | 351 | reader, _, err := strategy.Fetch(ctx, "http://fake.replicate.delivery/hello.txt") 352 | require.NoError(t, err) 353 | bytes, err := io.ReadAll(reader) 354 | require.NoError(t, err) 355 | 356 | // with a functional hostnames[0], we'd see 0344760706165500, but instead we 357 | // should fall back to this. Note that each 0 value has been changed to a 358 | // different index; we don't want every request that previously hit 0 to hit 359 | // the same new host. 360 | assert.Equal(t, "3344761726165516", string(bytes)) 361 | } 362 | 363 | func TestConsistentHashRetriesMissingHostname(t *testing.T) { 364 | hostnames, mockTransport := fakeCacheHosts(8, 16) 365 | 366 | // we deliberately "break" this cache host to make it as if its SRV record was missing 367 | hostnames[0] = "" 368 | 369 | opts := download.Options{ 370 | Client: client.Options{ 371 | Transport: mockTransport, 372 | }, 373 | MaxConcurrency: 8, 374 | ChunkSize: 1, 375 | CacheHosts: hostnames, 376 | CacheableURIPrefixes: makeCacheableURIPrefixes("http://fake.replicate.delivery"), 377 | SliceSize: 1, 378 | } 379 | 380 | ctx, cancel := context.WithCancel(context.Background()) 381 | defer cancel() 382 | 383 | strategy, err := download.GetConsistentHashingMode(opts) 384 | require.NoError(t, err) 385 | 386 | reader, _, err := strategy.Fetch(ctx, "http://fake.replicate.delivery/hello.txt") 387 | require.NoError(t, err) 388 | bytes, err := io.ReadAll(reader) 389 | require.NoError(t, err) 390 | 391 | // with a functional hostnames[0], we'd see 0344760706165500, but instead we 392 | // should fall back to this. Note that each 0 value has been changed to a 393 | // different index; we don't want every request that previously hit 0 to hit 394 | // the same new host. 
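// (The re-hash happens in rewriteRequestToCacheHost: on a fallback error, DoRequest
// retries with the failed bucket's index passed as previousPodIndexes, so each
// slice is re-assigned independently rather than all moving to one host.)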
395 | assert.Equal(t, "3344761726165516", string(bytes)) 396 | } 397 | 398 | // with only two hosts, we should *always* fall back to the other host 399 | func TestConsistentHashRetriesTwoHosts(t *testing.T) { 400 | hostnames, mockTransport := fakeCacheHosts(2, 16) 401 | // deliberately "break" one cache host 402 | hostnames[1] = "broken-host" 403 | mockTransport.RegisterResponder("GET", "http://broken-host/hello.txt", httpmock.NewStringResponder(503, "fake broken host")) 404 | 405 | opts := download.Options{ 406 | Client: client.Options{Transport: mockTransport}, 407 | MaxConcurrency: 8, 408 | ChunkSize: 1, 409 | CacheHosts: hostnames, 410 | CacheableURIPrefixes: makeCacheableURIPrefixes("http://testing.replicate.delivery"), 411 | SliceSize: 1, 412 | } 413 | 414 | ctx, cancel := context.WithCancel(context.Background()) 415 | defer cancel() 416 | 417 | strategy, err := download.GetConsistentHashingMode(opts) 418 | require.NoError(t, err) 419 | 420 | reader, _, err := strategy.Fetch(ctx, "http://testing.replicate.delivery/hello.txt") 421 | require.NoError(t, err) 422 | bytes, err := io.ReadAll(reader) 423 | require.NoError(t, err) 424 | 425 | assert.Equal(t, "0000000000000000", string(bytes)) 426 | } 427 | 428 | func TestConsistentHashingHasFallback(t *testing.T) { 429 | mockTransport := httpmock.NewMockTransport() 430 | mockTransport.RegisterResponder("GET", "http://fake.replicate.delivery/hello.txt", rangeResponder(200, "0000000000000000")) 431 | 432 | opts := download.Options{ 433 | Client: client.Options{Transport: mockTransport}, 434 | MaxConcurrency: 8, 435 | ChunkSize: 2, 436 | CacheHosts: []string{""}, // simulate a single unavailable cache host 437 | CacheableURIPrefixes: makeCacheableURIPrefixes("http://fake.replicate.delivery"), 438 | SliceSize: 3, 439 | } 440 | 441 | ctx, cancel := context.WithCancel(context.Background()) 442 | defer cancel() 443 | 444 | strategy, err := download.GetConsistentHashingMode(opts) 445 | require.NoError(t, err) 446 | 447 | reader, _, err := strategy.Fetch(ctx, "http://fake.replicate.delivery/hello.txt") 448 | require.NoError(t, err) 449 | bytes, err := io.ReadAll(reader) 450 | require.NoError(t, err) 451 | 452 | assert.Equal(t, "0000000000000000", string(bytes)) 453 | } 454 | 455 | func TestConsistentHashingHandlesFullFile(t *testing.T) { 456 | mockTransport := httpmock.NewMockTransport() 457 | mockTransport.RegisterResponder("GET", "http://fake.replicate.delivery/hello.txt", func(req *http.Request) (*http.Response, error) { 458 | resp := httpmock.NewStringResponse(http.StatusOK, "000000") 459 | resp.Request = req 460 | resp.ContentLength = int64(6) 461 | resp.Header.Add("Content-Length", fmt.Sprint(resp.ContentLength)) 462 | return resp, nil 463 | }) 464 | 465 | opts := download.Options{ 466 | Client: client.Options{Transport: mockTransport}, 467 | MaxConcurrency: 8, 468 | ChunkSize: 6, 469 | SliceSize: 6, 470 | } 471 | 472 | ctx, cancel := context.WithCancel(context.Background()) 473 | defer cancel() 474 | 475 | strategy, err := download.GetConsistentHashingMode(opts) 476 | require.NoError(t, err) 477 | 478 | reader, _, err := strategy.Fetch(ctx, "http://fake.replicate.delivery/hello.txt") 479 | require.NoError(t, err) 480 | bytes, err := io.ReadAll(reader) 481 | require.NoError(t, err) 482 | 483 | assert.Equal(t, "000000", string(bytes)) 484 | } 485 | 486 | type fallbackFailingHandler struct { 487 | responseStatus int 488 | responseFunc func(w http.ResponseWriter, r *http.Request) 489 | } 490 | 491 | func (h fallbackFailingHandler) ServeHTTP(w 
http.ResponseWriter, r *http.Request) { 492 | if h.responseFunc != nil { 493 | h.responseFunc(w, r) 494 | } else { 495 | w.WriteHeader(h.responseStatus) 496 | } 497 | } 498 | 499 | type testStrategy struct { 500 | fetchCalledCount int 501 | doRequestCalledCount int 502 | mut sync.Mutex 503 | } 504 | 505 | func (s *testStrategy) Fetch(ctx context.Context, url string) (io.Reader, int64, error) { 506 | s.fetchCalledCount++ 507 | return io.NopCloser(strings.NewReader("00")), -1, nil 508 | } 509 | 510 | func (s *testStrategy) DoRequest(ctx context.Context, start, end int64, url string) (*http.Response, error) { 511 | s.mut.Lock() 512 | s.doRequestCalledCount++ 513 | s.mut.Unlock() 514 | req, err := http.NewRequest(http.MethodGet, url, nil) 515 | if err != nil { 516 | return nil, err 517 | } 518 | resp := &http.Response{ 519 | Request: req, 520 | Body: io.NopCloser(strings.NewReader("00")), 521 | } 522 | return resp, nil 523 | } 524 | 525 | func TestConsistentHashingFileFallback(t *testing.T) { 526 | tc := []struct { 527 | name string 528 | responseStatus int 529 | failureFunc func(w http.ResponseWriter, r *http.Request) 530 | fetchCalledCount int 531 | doRequestCalledCount int 532 | expectedError error 533 | }{ 534 | { 535 | name: "BadGateway", 536 | responseStatus: http.StatusBadGateway, 537 | fetchCalledCount: 1, 538 | doRequestCalledCount: 0, 539 | }, 540 | // "NotFound" should not trigger fall-back 541 | { 542 | name: "NotFound", 543 | responseStatus: http.StatusNotFound, 544 | fetchCalledCount: 0, 545 | doRequestCalledCount: 0, 546 | expectedError: download.ErrUnexpectedHTTPStatus, 547 | }, 548 | } 549 | 550 | for _, tc := range tc { 551 | t.Run(tc.name, func(t *testing.T) { 552 | server := httptest.NewServer(fallbackFailingHandler{responseStatus: tc.responseStatus, responseFunc: tc.failureFunc}) 553 | defer server.Close() 554 | 555 | url, _ := url.Parse(server.URL) 556 | opts := download.Options{ 557 | Client: client.Options{}, 558 | MaxConcurrency: 8, 559 | ChunkSize: 2, 560 | CacheHosts: []string{url.Host}, 561 | CacheableURIPrefixes: makeCacheableURIPrefixes("http://fake.replicate.delivery"), 562 | SliceSize: 3, 563 | } 564 | 565 | ctx, cancel := context.WithCancel(context.Background()) 566 | defer cancel() 567 | 568 | strategy, err := download.GetConsistentHashingMode(opts) 569 | require.NoError(t, err) 570 | 571 | fallbackStrategy := &testStrategy{} 572 | strategy.FallbackStrategy = fallbackStrategy 573 | 574 | urlString := "http://fake.replicate.delivery/hello.txt" 575 | _, _, err = strategy.Fetch(ctx, urlString) 576 | if tc.expectedError != nil { 577 | assert.ErrorIs(t, err, tc.expectedError) 578 | } 579 | assert.Equal(t, tc.fetchCalledCount, fallbackStrategy.fetchCalledCount) 580 | assert.Equal(t, tc.doRequestCalledCount, fallbackStrategy.doRequestCalledCount) 581 | }) 582 | } 583 | } 584 | 585 | func TestConsistentHashingChunkFallback(t *testing.T) { 586 | handlerFunc := func(w http.ResponseWriter, r *http.Request) { 587 | if r.Header.Get("Range") != "bytes=0-2" { 588 | w.WriteHeader(http.StatusBadGateway) 589 | } else { 590 | w.Header().Set("Content-Range", "bytes 0-2/4") 591 | w.WriteHeader(http.StatusPartialContent) 592 | _, _ = w.Write([]byte("000")) 593 | } 594 | } 595 | 596 | tc := []struct { 597 | name string 598 | responseStatus int 599 | handlerFunc func(w http.ResponseWriter, r *http.Request) 600 | fetchCalledCount int 601 | doRequestCalledCount int 602 | expectedError error 603 | }{ 604 | { 605 | name: "fail-on-second-chunk", 606 | handlerFunc: handlerFunc, 607 | 
fetchCalledCount: 0, 608 | doRequestCalledCount: 1, 609 | }, 610 | } 611 | 612 | for _, tc := range tc { 613 | t.Run(tc.name, func(t *testing.T) { 614 | server := httptest.NewServer(fallbackFailingHandler{responseStatus: tc.responseStatus, responseFunc: tc.handlerFunc}) 615 | defer server.Close() 616 | 617 | url, _ := url.Parse(server.URL) 618 | opts := download.Options{ 619 | Client: client.Options{}, 620 | MaxConcurrency: 8, 621 | ChunkSize: 3, 622 | CacheHosts: []string{url.Host}, 623 | CacheableURIPrefixes: makeCacheableURIPrefixes("http://fake.replicate.delivery"), 624 | SliceSize: 3, 625 | } 626 | 627 | ctx, cancel := context.WithCancel(context.Background()) 628 | defer cancel() 629 | 630 | strategy, err := download.GetConsistentHashingMode(opts) 631 | require.NoError(t, err) 632 | 633 | fallbackStrategy := &testStrategy{} 634 | strategy.FallbackStrategy = fallbackStrategy 635 | 636 | urlString := "http://fake.replicate.delivery/hello.txt" 637 | out, _, err := strategy.Fetch(ctx, urlString) 638 | assert.ErrorIs(t, err, tc.expectedError) 639 | if err == nil { 640 | // eagerly read the whole output reader to force all the 641 | // requests to be completed 642 | _, _ = io.Copy(io.Discard, out) 643 | } 644 | assert.Equal(t, tc.fetchCalledCount, fallbackStrategy.fetchCalledCount) 645 | assert.Equal(t, tc.doRequestCalledCount, fallbackStrategy.doRequestCalledCount) 646 | }) 647 | } 648 | } 649 | -------------------------------------------------------------------------------- /pkg/download/options.go: -------------------------------------------------------------------------------- 1 | package download 2 | 3 | import ( 4 | "net/url" 5 | "runtime" 6 | 7 | "github.com/replicate/pget/pkg/client" 8 | ) 9 | 10 | type Options struct { 11 | // Maximum number of chunks to download concurrently. If set to zero, 12 | // runtime.NumCPU()*4 will be used. 13 | MaxConcurrency int 14 | 15 | // SliceSize is the number of bytes per slice in nginx. 16 | // See https://nginx.org/en/docs/http/ngx_http_slice_module.html 17 | SliceSize int64 18 | 19 | // Number of bytes per chunk. If set to zero, 125 MiB will be used. 20 | ChunkSize int64 21 | 22 | Client client.Options 23 | 24 | // CacheableURIPrefixes is an allowlist of domains+path-prefixes which may 25 | // be routed via a pull-through cache 26 | CacheableURIPrefixes map[string][]*url.URL 27 | 28 | // CacheUsePathProxy indicates whether to use the path proxy mechanism rather than the host-based mechanism. 29 | // The default is the host-based mechanism; the path proxy mechanism is used when this flag is set to true 30 | // and involves prepending the host to the path of the request to the cache. In both cases the Host header is 31 | // sent to the cache. 32 | CacheUsePathProxy bool 33 | 34 | // CacheHosts is a slice of hostnames to use as pull-through caches. 35 | // The ordering is significant and will be used with the consistent 36 | // hashing algorithm. The slice may contain empty entries which 37 | // correspond to a cache host which is currently unavailable. 38 | CacheHosts []string 39 | 40 | // ForceCachePrefixRewrite will forcefully rewrite the prefix for all 41 | // pget requests to the first item in the CacheHosts list. This ignores 42 | // anything in the CacheableURIPrefixes and rewrites all requests.
43 | ForceCachePrefixRewrite bool 44 | } 45 | 46 | func (o *Options) maxConcurrency() int { 47 | maxChunks := o.MaxConcurrency 48 | if maxChunks == 0 { 49 | return runtime.NumCPU() * 4 50 | } 51 | return maxChunks 52 | } 53 | -------------------------------------------------------------------------------- /pkg/download/reader_promise.go: -------------------------------------------------------------------------------- 1 | package download 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | ) 7 | 8 | // A readerPromise represents an io.Reader whose implementation is not yet 9 | // available but will be in the future. Read() will block until Deliver() is 10 | // called. 11 | // 12 | // The intended use is: a consumer goroutine calls Read(), which blocks until 13 | // data is ready. A producer then calls Deliver(), which blocks 14 | // until the consumer has read the provided data or error. 15 | type readerPromise struct { 16 | // ready channel is closed when we're ready to read 17 | ready chan struct{} 18 | // finished channel is closed when we're done reading 19 | finished chan struct{} 20 | buf []byte 21 | // if reader is non-nil, buf is always the underlying buffer for the reader 22 | reader *bytes.Reader 23 | err error 24 | } 25 | 26 | var _ io.Reader = &readerPromise{} 27 | 28 | func newReaderPromise() *readerPromise { 29 | return &readerPromise{ 30 | ready: make(chan struct{}), 31 | finished: make(chan struct{}), 32 | } 33 | } 34 | 35 | // Read implements io.Reader. It will block until the full body is available for 36 | // reading. Once the underlying buffer is fully read, the producer is 37 | // unblocked and subsequent reads return io.EOF. 38 | func (b *readerPromise) Read(buf []byte) (int, error) { 39 | <-b.ready 40 | if b.err != nil { 41 | return 0, b.err 42 | } 43 | n, err := b.reader.Read(buf) 44 | // If we've read all the data, 45 | if err == io.EOF && b.buf != nil { 46 | // unblock the producer 47 | close(b.finished) 48 | b.buf = nil 49 | b.err = io.EOF 50 | } 51 | return n, err 52 | } 53 | 54 | func (b *readerPromise) Deliver(buf []byte, err error) { 55 | if buf == nil { 56 | buf = []byte{} 57 | } 58 | b.buf = buf 59 | b.err = err 60 | b.reader = bytes.NewReader(buf) 61 | close(b.ready) 62 | <-b.finished 63 | } 64 | -------------------------------------------------------------------------------- /pkg/download/reader_promise_test.go: -------------------------------------------------------------------------------- 1 | package download 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "sync" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestReaderPromiseParallel(t *testing.T) { 14 | p := newReaderPromise() 15 | wg := new(sync.WaitGroup) 16 | wg.Add(1) 17 | go func() { 18 | defer wg.Done() 19 | p.Deliver([]byte("foobar"), nil) 20 | }() 21 | buf, err := io.ReadAll(p) 22 | assert.NoError(t, err) 23 | assert.Equal(t, "foobar", string(buf)) 24 | wg.Wait() 25 | } 26 | 27 | func TestReaderPromiseReadsWholeChunk(t *testing.T) { 28 | chunkSize := int64(1024 * 1024) 29 | p := newReaderPromise() 30 | data := bytes.Repeat([]byte("x"), int(chunkSize)) 31 | go p.Deliver(data, nil) 32 | buf := make([]byte, chunkSize) 33 | // We should only require a single Read() call because all the data should 34 | // be buffered 35 | n, err := p.Read(buf) 36 | assert.NoError(t, err) 37 | assert.Equal(t, data, buf) 38 | assert.Equal(t, int(chunkSize), n) 39 | } 40 | 41 | func TestReaderPromiseDeliverErrPassesErrorsToConsumer(t *testing.T) { 42 | p := newReaderPromise() 43 | 44 | expectedErr := fmt.Errorf("oh no") 45 |
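// Deliver blocks until the consumer has read what was delivered, so it must
// run in its own goroutine.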
46 | go p.Deliver(nil, expectedErr) 47 | buf := make([]byte, 10) 48 | n, err := p.Read(buf) 49 | assert.ErrorIs(t, expectedErr, err) 50 | assert.Equal(t, 0, n) 51 | } 52 | 53 | func TestReaderPromiseSubsequentReadsReturnEOF(t *testing.T) { 54 | p := newReaderPromise() 55 | go p.Deliver([]byte("foobar"), nil) 56 | buf, err := io.ReadAll(p) 57 | assert.NoError(t, err) 58 | assert.Equal(t, "foobar", string(buf)) 59 | 60 | n, err := p.Read(buf) 61 | assert.Equal(t, 0, n) 62 | assert.ErrorIs(t, err, io.EOF) 63 | } 64 | -------------------------------------------------------------------------------- /pkg/download/strategy.go: -------------------------------------------------------------------------------- 1 | package download 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "net/http" 8 | ) 9 | 10 | var ErrUnexpectedHTTPStatus = errors.New("unexpected http status") 11 | 12 | type Strategy interface { 13 | // Fetch retrieves the content from a given URL and returns it as an io.Reader along with the file size. 14 | // If an error occurs during the process, it returns nil for the reader, 0 for the fileSize, and the error itself. 15 | // This is the primary method that should be called to initiate a download of a file. 16 | Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, err error) 17 | 18 | // DoRequest sends an HTTP GET request with a specified range of bytes to the given URL using the provided context. 19 | // It returns the HTTP response and any error encountered during the request. It is intended that Fetch calls DoRequest 20 | // and that each chunk is downloaded with a call to DoRequest. DoRequest is exposed so that consistent-hashing can 21 | // utilize any strategy as a fall-back for chunk downloading. 22 | // 23 | // If the request fails to download or execute, an error is returned. 24 | // 25 | // The start and end parameters specify the byte range to request. 26 | // The trueURL parameter is the actual URL after any redirects. 27 | DoRequest(ctx context.Context, start, end int64, url string) (*http.Response, error) 28 | } 29 | -------------------------------------------------------------------------------- /pkg/download/work_queue.go: -------------------------------------------------------------------------------- 1 | package download 2 | 3 | // priorityWorkQueue takes work items and executes them, with n parallel 4 | // workers. It allows for a simple high/low priority split between work. We 5 | // use this to prefer finishing existing downloads over starting new downloads. 6 | // 7 | // work items are provided with a fixed-size buffer. 
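// Each worker allocates its buffer once at startup and reuses it for every
// work item it runs, so buffers are recycled rather than allocated per chunk.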
8 | type priorityWorkQueue struct { 9 | concurrency int 10 | lowPriority chan work 11 | highPriority chan work 12 | bufSize int64 13 | } 14 | 15 | type work func([]byte) 16 | 17 | func newWorkQueue(concurrency int, bufSize int64) *priorityWorkQueue { 18 | return &priorityWorkQueue{ 19 | concurrency: concurrency, 20 | lowPriority: make(chan work), 21 | highPriority: make(chan work), 22 | bufSize: bufSize, 23 | } 24 | } 25 | 26 | func (q *priorityWorkQueue) submitLow(w work) { 27 | q.lowPriority <- w 28 | } 29 | 30 | func (q *priorityWorkQueue) submitHigh(w work) { 31 | q.highPriority <- w 32 | } 33 | 34 | func (q *priorityWorkQueue) start() { 35 | for i := 0; i < q.concurrency; i++ { 36 | go q.run(make([]byte, q.bufSize)) 37 | } 38 | } 39 | 40 | func (q *priorityWorkQueue) run(buf []byte) { 41 | for { 42 | // read items off the high priority queue until it's empty 43 | select { 44 | case item := <-q.highPriority: 45 | item(buf) 46 | default: 47 | select { // read one item from either queue, then go round the loop again 48 | case item := <-q.highPriority: 49 | item(buf) 50 | case item := <-q.lowPriority: 51 | item(buf) 52 | } 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /pkg/extract/compression.go: -------------------------------------------------------------------------------- 1 | package extract 2 | 3 | import ( 4 | "bytes" 5 | "compress/bzip2" 6 | "compress/gzip" 7 | "compress/lzw" 8 | "io" 9 | 10 | "github.com/pierrec/lz4" 11 | "github.com/ulikunitz/xz" 12 | 13 | "github.com/replicate/pget/pkg/logging" 14 | ) 15 | 16 | const ( 17 | peekSize = 8 18 | ) 19 | 20 | var ( 21 | gzipMagic = []byte{0x1F, 0x8B} 22 | bzipMagic = []byte{0x42, 0x5A} 23 | xzMagic = []byte{0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00} 24 | lzwMagic = []byte{0x1F, 0x9D} 25 | lz4Magic = []byte{0x18, 0x4D, 0x22, 0x04} 26 | ) 27 | 28 | var _ decompressor = gzipDecompressor{} 29 | var _ decompressor = bzip2Decompressor{} 30 | var _ decompressor = xzDecompressor{} 31 | var _ decompressor = lzwDecompressor{} 32 | var _ decompressor = lz4Decompressor{} 33 | 34 | // decompressor represents different compression formats. 35 | type decompressor interface { 36 | decompress(r io.Reader) (io.Reader, error) 37 | } 38 | 39 | // detectFormat returns the appropriate extractor according to the magic number. 40 | func detectFormat(input []byte) decompressor { 41 | log := logging.GetLogger() 42 | inputSize := len(input) 43 | 44 | if inputSize < 2 { 45 | return nil 46 | } 47 | // pad to 8 bytes 48 | if inputSize < 8 { 49 | input = append(input, make([]byte, peekSize-inputSize)...) 50 | } 51 | 52 | switch true { 53 | case bytes.HasPrefix(input, gzipMagic): 54 | log.Debug(). 55 | Str("type", "gzip"). 56 | Msg("Compression Format") 57 | return gzipDecompressor{} 58 | case bytes.HasPrefix(input, bzipMagic): 59 | log.Debug(). 60 | Str("type", "bzip2"). 61 | Msg("Compression Format") 62 | return bzip2Decompressor{} 63 | case bytes.HasPrefix(input, lzwMagic): 64 | compressionByte := input[2] 65 | // litWidth is guaranteed to be at least 9 per specification, the high order 3 bits of byte[2] are the litWidth 66 | // the low order 5 bits are only used by non-unix implementations, we are going to ignore them. 67 | litWidth := int(compressionByte>>5) + 9 68 | log.Debug(). 69 | Str("type", "lzw"). 70 | Int("litWidth", litWidth). 71 | Msg("Compression Format") 72 | return lzwDecompressor{ 73 | order: lzw.MSB, 74 | litWidth: litWidth, 75 | } 76 | case bytes.HasPrefix(input, lz4Magic): 77 | log.Debug(). 
78 | Str("type", "lz4"). 79 | Msg("Compression Format") 80 | return lz4Decompressor{} 81 | case bytes.HasPrefix(input, xzMagic): 82 | log.Debug(). 83 | Str("type", "xz"). 84 | Msg("Compression Format") 85 | return xzDecompressor{} 86 | default: 87 | log.Debug(). 88 | Str("type", "none"). 89 | Msg("Compression Format") 90 | return nil 91 | } 92 | 93 | } 94 | 95 | type gzipDecompressor struct{} 96 | 97 | func (d gzipDecompressor) decompress(r io.Reader) (io.Reader, error) { 98 | return gzip.NewReader(r) 99 | } 100 | 101 | type bzip2Decompressor struct{} 102 | 103 | func (d bzip2Decompressor) decompress(r io.Reader) (io.Reader, error) { 104 | return bzip2.NewReader(r), nil 105 | } 106 | 107 | type xzDecompressor struct{} 108 | 109 | func (d xzDecompressor) decompress(r io.Reader) (io.Reader, error) { 110 | return xz.NewReader(r) 111 | } 112 | 113 | type lzwDecompressor struct { 114 | litWidth int 115 | order lzw.Order 116 | } 117 | 118 | func (d lzwDecompressor) decompress(r io.Reader) (io.Reader, error) { 119 | return lzw.NewReader(r, d.order, d.litWidth), nil 120 | } 121 | 122 | type lz4Decompressor struct{} 123 | 124 | func (d lz4Decompressor) decompress(r io.Reader) (io.Reader, error) { 125 | return lz4.NewReader(r), nil 126 | } 127 | -------------------------------------------------------------------------------- /pkg/extract/compression_test.go: -------------------------------------------------------------------------------- 1 | package extract 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestDetectFormat(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | input []byte 14 | expectType string 15 | }{ 16 | { 17 | name: "GZIP", 18 | input: []byte{0x1f, 0x8b}, 19 | expectType: "extract.gzipDecompressor", 20 | }, 21 | { 22 | name: "BZIP2", 23 | input: []byte{0x42, 0x5a}, 24 | expectType: "extract.bzip2Decompressor", 25 | }, 26 | { 27 | name: "XZ", 28 | input: []byte{0xfd, 0x37, 0x7a, 0x58, 0x5a, 0x00}, 29 | expectType: "extract.xzDecompressor", 30 | }, 31 | { 32 | name: "Less than 2 bytes", 33 | input: []byte{0x1f}, 34 | expectType: "", 35 | }, 36 | { 37 | name: "UNKNOWN", 38 | input: []byte{0xde, 0xad}, 39 | expectType: "", 40 | }, 41 | } 42 | 43 | for _, tt := range tests { 44 | t.Run(tt.name, func(t *testing.T) { 45 | result := detectFormat(tt.input) 46 | assert.Equal(t, tt.expectType, stringFromInterface(result)) 47 | }) 48 | } 49 | } 50 | 51 | func stringFromInterface(i interface{}) string { 52 | if i == nil { 53 | return "" 54 | } 55 | return fmt.Sprintf("%T", i) 56 | } 57 | -------------------------------------------------------------------------------- /pkg/extract/tar.go: -------------------------------------------------------------------------------- 1 | package extract 2 | 3 | import ( 4 | "archive/tar" 5 | "bufio" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "os" 10 | "path/filepath" 11 | "strings" 12 | "time" 13 | 14 | "github.com/replicate/pget/pkg/logging" 15 | ) 16 | 17 | var ErrZipSlip = errors.New("archive (tar) file contains file outside of target directory") 18 | var ErrEmptyHeaderName = errors.New("tar file contains entry with empty name") 19 | 20 | type link struct { 21 | linkType byte 22 | oldName string 23 | newName string 24 | } 25 | 26 | func TarFile(r *bufio.Reader, destDir string, overwrite bool) error { 27 | var links []*link 28 | var reader io.Reader = r 29 | 30 | log := logging.GetLogger() 31 | 32 | startTime := time.Now() 33 | peekData, err := r.Peek(peekSize) 34 | if err != nil { 35 | return 
fmt.Errorf("error reading peek data: %w", err) 36 | } 37 | if decompressor := detectFormat(peekData); decompressor != nil { 38 | reader, err = decompressor.decompress(reader) 39 | if err != nil { 40 | return fmt.Errorf("error creating decompressed stream: %w", err) 41 | } 42 | log.Info(). 43 | Str("decompressor", fmt.Sprintf("%T", decompressor)). 44 | Msg("Tar Compression Detected: Compression can significantly slowdown pget (e.g. for model weights)") 45 | } 46 | tarReader := tar.NewReader(reader) 47 | logger := logging.GetLogger() 48 | 49 | logger.Debug(). 50 | Str("extractor", "tar"). 51 | Str("status", "starting"). 52 | Msg("Extract") 53 | for { 54 | header, err := tarReader.Next() 55 | if err == io.EOF { 56 | break 57 | } 58 | if err != nil { 59 | return err 60 | } 61 | 62 | target := filepath.Join(destDir, header.Name) 63 | targetDir := filepath.Dir(target) 64 | if err := os.MkdirAll(targetDir, 0755); err != nil { 65 | return err 66 | } 67 | 68 | if err := guardAgainstZipSlip(header, destDir); err != nil { 69 | return err 70 | } 71 | 72 | switch header.Typeflag { 73 | case tar.TypeXGlobalHeader: 74 | // This is a global pax header, which we can skip as it's mostly handled by the underlying implementation 75 | // NOTE: the global header is not persisted across subsequent calls to Next() and therefore could indicate 76 | // that we are processing a tar file in an unintended manner. This is a limitation of archive/tar. 77 | continue 78 | case tar.TypeDir: 79 | logger.Debug(). 80 | Str("target", target). 81 | Str("perms", fmt.Sprintf("%o", header.Mode)). 82 | Msg("Tar: Directory") 83 | if err := os.MkdirAll(target, cleanFileMode(os.FileMode(header.Mode))); err != nil { 84 | return err 85 | } 86 | case tar.TypeReg: 87 | openFlags := os.O_CREATE | os.O_WRONLY 88 | if overwrite { 89 | openFlags |= os.O_TRUNC 90 | } 91 | logger.Debug(). 92 | Str("target", target). 93 | Str("perms", fmt.Sprintf("%o", header.Mode)). 94 | Msg("Tar: File") 95 | targetFile, err := os.OpenFile(target, openFlags, cleanFileMode(os.FileMode(header.Mode))) 96 | if err != nil { 97 | return err 98 | } 99 | if _, err := io.Copy(targetFile, tarReader); err != nil { 100 | targetFile.Close() 101 | return err 102 | } 103 | if err := targetFile.Close(); err != nil { 104 | return fmt.Errorf("error closing file %s: %w", target, err) 105 | } 106 | case tar.TypeSymlink, tar.TypeLink: 107 | // Defer creation of 108 | logger.Debug().Str("link_type", string(header.Typeflag)). 109 | Str("old_name", header.Linkname). 110 | Str("new_name", target). 111 | Msg("Tar: (Defer) Link") 112 | links = append(links, &link{linkType: header.Typeflag, oldName: header.Linkname, newName: target}) 113 | default: 114 | return fmt.Errorf("unsupported file type for %s, typeflag %s", header.Name, string(header.Typeflag)) 115 | } 116 | } 117 | 118 | if err := createLinks(links, destDir, overwrite); err != nil { 119 | return fmt.Errorf("error creating links: %w", err) 120 | } 121 | 122 | // Read the rest of the bytes from the archive and verify they are all null bytes 123 | // This is for validation that the byte count is correct 124 | padding, err := io.ReadAll(r) 125 | if err != nil { 126 | return fmt.Errorf("error reading padding bytes: %w", err) 127 | } 128 | for _, b := range padding { 129 | if b != 0x00 { 130 | return fmt.Errorf("unexpected non-null byte in padding: %x", b) 131 | } 132 | } 133 | 134 | elapsed := time.Since(startTime).Seconds() 135 | logger.Debug(). 136 | Str("extractor", "tar"). 137 | Float64("elapsed_time", elapsed). 
138 | Str("status", "complete"). 139 | Msg("Extract") 140 | return nil 141 | } 142 | 143 | func createLinks(links []*link, destDir string, overwrite bool) error { 144 | logger := logging.GetLogger() 145 | for _, link := range links { 146 | targetDir := filepath.Dir(link.newName) 147 | if err := os.MkdirAll(targetDir, 0755); err != nil { 148 | return err 149 | } 150 | switch link.linkType { 151 | case tar.TypeLink: 152 | oldPath := filepath.Join(destDir, link.oldName) 153 | logger.Debug(). 154 | Str("old_path", oldPath). 155 | Str("new_path", link.newName). 156 | Msg("Tar: creating hard link") 157 | if err := createHardLink(oldPath, link.newName, overwrite); err != nil { 158 | return fmt.Errorf("error creating hard link from %s to %s: %w", oldPath, link.newName, err) 159 | } 160 | case tar.TypeSymlink: 161 | logger.Debug(). 162 | Str("old_path", link.oldName). 163 | Str("new_path", link.newName). 164 | Msg("Tar: creating symlink") 165 | if err := createSymlink(link.oldName, link.newName, overwrite); err != nil { 166 | return fmt.Errorf("error creating symlink from %s to %s: %w", link.oldName, link.newName, err) 167 | } 168 | default: 169 | return fmt.Errorf("unsupported link type %s", string(link.linkType)) 170 | } 171 | } 172 | return nil 173 | } 174 | 175 | func createHardLink(oldName, newName string, overwrite bool) error { 176 | if overwrite { 177 | err := os.Remove(newName) 178 | if err != nil && !os.IsNotExist(err) { 179 | return fmt.Errorf("error removing existing file: %w", err) 180 | } 181 | } 182 | return os.Link(oldName, newName) 183 | } 184 | 185 | func createSymlink(oldName, newName string, overwrite bool) error { 186 | if overwrite { 187 | err := os.Remove(newName) 188 | if err != nil && !os.IsNotExist(err) { 189 | return fmt.Errorf("error removing existing symlink/file: %w", err) 190 | } 191 | } 192 | return os.Symlink(oldName, newName) 193 | } 194 | 195 | func guardAgainstZipSlip(header *tar.Header, destDir string) error { 196 | if header.Name == "" { 197 | return ErrEmptyHeaderName 198 | } 199 | target, err := filepath.Abs(filepath.Join(destDir, header.Name)) 200 | if err != nil { 201 | return fmt.Errorf("error getting absolute path of destDir %s: %w", header.Name, err) 202 | } 203 | destAbs, err := filepath.Abs(destDir) 204 | if err != nil { 205 | return fmt.Errorf("error getting absolute path of %s: %w", destDir, err) 206 | } 207 | if !strings.HasPrefix(target, destAbs) { 208 | return fmt.Errorf("%w: `%s` outside of `%s`", ErrZipSlip, target, destAbs) 209 | } 210 | return nil 211 | } 212 | 213 | func cleanFileMode(mode os.FileMode) os.FileMode { 214 | mask := os.ModeSticky | os.ModeSetuid | os.ModeSetgid 215 | return mode &^ mask 216 | } 217 | -------------------------------------------------------------------------------- /pkg/extract/tar_test.go: -------------------------------------------------------------------------------- 1 | package extract 2 | 3 | import ( 4 | "archive/tar" 5 | "os" 6 | "path/filepath" 7 | "syscall" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestCreateLinks(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | links []*link 17 | expectedError bool 18 | overwrite bool 19 | createFileToOverwrite bool 20 | }{ 21 | { 22 | name: "EmptyLink", 23 | links: []*link{}, 24 | }, 25 | { 26 | name: "ValidHardLink", 27 | links: []*link{{tar.TypeLink, "", "testLinkHard"}}, 28 | }, 29 | { 30 | name: "ValidSymlink", 31 | links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, 32 | }, 33 | { 34 | name: "InvalidLinkType", 
35 | links: []*link{{'!', "", "x"}}, 36 | expectedError: true, 37 | }, 38 | { 39 | name: "ValidMultipleLinks", 40 | links: []*link{ 41 | {tar.TypeLink, "", "testLinkHard"}, 42 | {tar.TypeSymlink, "", "testLinkSym"}, 43 | }, 44 | }, 45 | { 46 | name: "HardLink_OverwriteEnabled_File Exists", 47 | links: []*link{{tar.TypeLink, "", "testLinkHard"}}, 48 | overwrite: true, 49 | createFileToOverwrite: true, 50 | }, 51 | { 52 | name: "HardLink_OverwriteDisabled_FileExists", 53 | links: []*link{{tar.TypeLink, "", "testLinkHard"}}, 54 | createFileToOverwrite: true, 55 | expectedError: true, 56 | }, 57 | { 58 | name: "HardLink_OverwriteEnabled_FileDoesNotExist", 59 | links: []*link{{tar.TypeLink, "", "testLinkHard"}}, 60 | overwrite: true, 61 | }, 62 | { 63 | name: "SymLink_OverwriteEnabled_FileExists", 64 | links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, 65 | overwrite: true, 66 | createFileToOverwrite: true, 67 | }, 68 | { 69 | name: "SymLink_OverwriteDisabled_FileExists", 70 | links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, 71 | createFileToOverwrite: true, 72 | expectedError: true, 73 | }, 74 | { 75 | name: "SymLink_OverwriteEnabled_FileDoesNotExist", 76 | links: []*link{{tar.TypeSymlink, "", "testLinkSym"}}, 77 | overwrite: true, 78 | }, 79 | } 80 | 81 | for _, tt := range tests { 82 | t.Run(tt.name, func(t *testing.T) { 83 | // Setup 84 | destDir, err := os.MkdirTemp("./", tt.name) 85 | if err != nil { 86 | t.Fatal(err) 87 | } 88 | // Cleanup 89 | defer os.RemoveAll(destDir) 90 | 91 | // For hardlink and symlink, create dummy files 92 | for _, link := range tt.links { 93 | if link.linkType == tar.TypeLink || link.linkType == tar.TypeSymlink { 94 | testFile, err := os.CreateTemp(destDir, "test-") 95 | if tt.createFileToOverwrite { 96 | _, err = os.Create(filepath.Join(destDir, link.newName)) 97 | } 98 | if err != nil { 99 | t.Fatalf("Test failed, could not create test file: %v", err) 100 | } 101 | _ = testFile.Close() 102 | link.oldName = filepath.Base(testFile.Name()) 103 | link.newName = filepath.Join(destDir, link.newName) 104 | } 105 | } 106 | 107 | err = createLinks(tt.links, destDir, tt.overwrite) 108 | 109 | // Validation 110 | if tt.expectedError { 111 | assert.Error(t, err) 112 | } else { 113 | assert.NoError(t, err) 114 | 115 | for _, link := range tt.links { 116 | oldPath := filepath.Join(destDir, link.oldName) 117 | if link.linkType == tar.TypeSymlink { 118 | assertSymlinkTarget(t, oldPath, link.newName) 119 | } else if link.linkType == tar.TypeLink { 120 | assertHardLinkTarget(t, oldPath, link.newName) 121 | } else { 122 | t.Fatal("Invalid link type") 123 | } 124 | } 125 | } 126 | 127 | }) 128 | } 129 | } 130 | 131 | func assertHardLinkTarget(t *testing.T, oldName, newName string) { 132 | fileStat, err := os.Stat(oldName) 133 | if !assert.NoError(t, err) { 134 | t.Fatal("Test failed, could not stat test-created file", err) 135 | } 136 | linkStat, err := os.Lstat(newName) 137 | if !assert.NoError(t, err) { 138 | t.Fatalf("Test failed, could not stat link %s: %v", newName, err) 139 | } 140 | targetStat, err := os.Stat(newName) 141 | if !assert.NoError(t, err) { 142 | t.Fatalf("Test failed, could not stat link %s: %v", newName, err) 143 | } 144 | assert.True(t, linkStat.Mode()&os.ModeSymlink == 0) 145 | assert.Equal(t, fileStat.Sys().(*syscall.Stat_t).Ino, targetStat.Sys().(*syscall.Stat_t).Ino) 146 | } 147 | 148 | func assertSymlinkTarget(t *testing.T, oldName, newName string) { 149 | fileStat, err := os.Stat(oldName) 150 | if !assert.NoError(t, err) { 151 | t.Fatal("Test 
failed, could not stat test-created file", err) 152 | } 153 | linkStat, err := os.Lstat(newName) 154 | if !assert.NoError(t, err) { 155 | t.Fatalf("Test failed, could not stat link %s: %v", newName, err) 156 | } 157 | assert.True(t, linkStat.Mode()&os.ModeSymlink != 0) 158 | // os.Stat follows symlinks 159 | realTarget, err := os.Stat(newName) 160 | if !assert.NoError(t, err) { 161 | t.Fatalf("Test failed, could not stat link %s: %v", newName, err) 162 | } 163 | assert.Equal(t, fileStat.Sys().(*syscall.Stat_t).Ino, 164 | realTarget.Sys().(*syscall.Stat_t).Ino) 165 | } 166 | 167 | func TestGuardAgainstZipSlip(t *testing.T) { 168 | tests := []struct { 169 | description string 170 | header *tar.Header 171 | destDir string 172 | expectedError error 173 | }{ 174 | { 175 | description: "valid file path within directory", 176 | header: &tar.Header{ 177 | Name: "valid_file", 178 | }, 179 | destDir: "/tmp/valid_dir", 180 | expectedError: nil, 181 | }, 182 | { 183 | description: "file path outside directory", 184 | header: &tar.Header{ 185 | Name: "../invalid_file", 186 | }, 187 | destDir: "/tmp/valid_dir", 188 | expectedError: ErrZipSlip, 189 | }, 190 | { 191 | description: "directory traversal with invalid file", 192 | header: &tar.Header{ 193 | Name: "./../../tmp/invalid_dir/invalid_file", 194 | }, 195 | destDir: "/tmp/valid_dir", 196 | expectedError: ErrZipSlip, 197 | }, 198 | { 199 | description: "Empty header name", 200 | header: &tar.Header{ 201 | Name: "", 202 | }, 203 | destDir: "/tmp", 204 | expectedError: ErrEmptyHeaderName, 205 | }, 206 | { 207 | description: "relative destDir path, valid file", 208 | header: &tar.Header{ 209 | Name: "bar.txt", 210 | }, 211 | destDir: "foo", 212 | expectedError: nil, 213 | }, 214 | { 215 | description: "relative path, invalid file", 216 | header: &tar.Header{ 217 | Name: "../../bar.txt", 218 | }, 219 | destDir: "foo", 220 | expectedError: ErrZipSlip, 221 | }, 222 | } 223 | 224 | for _, test := range tests { 225 | t.Run(test.description, func(t *testing.T) { 226 | err := guardAgainstZipSlip(test.header, test.destDir) 227 | assert.ErrorIs(t, err, test.expectedError) 228 | }) 229 | } 230 | } 231 | func TestCleanFileMode(t *testing.T) { 232 | testCases := []struct { 233 | name string 234 | input os.FileMode 235 | expected os.FileMode 236 | }{ 237 | { 238 | name: "TestWithoutStickyBit", 239 | input: 0755, 240 | expected: 0755, 241 | }, 242 | { 243 | name: "TestWithStickyBit", 244 | input: os.ModeSticky | 0755, 245 | expected: 0755, 246 | }, 247 | { 248 | name: "TestWithoutSetuidBit", 249 | input: 0600, 250 | expected: 0600, 251 | }, 252 | { 253 | name: "TestWithSetuidBit", 254 | input: os.ModeSetuid | 0600, 255 | expected: 0600, 256 | }, 257 | { 258 | name: "TestWithoutSetgidBit", 259 | input: 0777, 260 | expected: 0777, 261 | }, 262 | { 263 | name: "TestWithSetgidBit", 264 | input: os.ModeSetgid | 0777, 265 | expected: 0777, 266 | }, 267 | { 268 | name: "TestWithAllBits", 269 | input: os.ModeSticky | os.ModeSetuid | os.ModeSetgid | 0777, 270 | expected: 0777, 271 | }, 272 | } 273 | 274 | for _, tc := range testCases { 275 | t.Run(tc.name, func(t *testing.T) { 276 | result := cleanFileMode(tc.input) 277 | if result != tc.expected { 278 | t.Errorf("cleanFileMode() = %v, want %v", result, tc.expected) 279 | } 280 | }) 281 | } 282 | } 283 | -------------------------------------------------------------------------------- /pkg/logging/log.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | import ( 4 | 
"fmt" 5 | "os" 6 | "strings" 7 | "time" 8 | 9 | "github.com/rs/zerolog" 10 | "github.com/rs/zerolog/log" 11 | ) 12 | 13 | func SetupLogger() { 14 | // TODO: Make color configurable? Disabled so we don't have to deal with ANSI escape codes in our logoutput 15 | output := zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339, NoColor: true} 16 | output.FormatLevel = func(i interface{}) string { 17 | return strings.ToUpper(fmt.Sprintf("| %-6s|", i)) 18 | } 19 | output.FormatMessage = func(i interface{}) string { 20 | return fmt.Sprintf("[ %s ]", i) 21 | } 22 | log.Logger = zerolog.New(output).With().Timestamp().Logger() 23 | } 24 | 25 | func GetLogger() zerolog.Logger { 26 | return log.Logger 27 | } 28 | -------------------------------------------------------------------------------- /pkg/pget.go: -------------------------------------------------------------------------------- 1 | package pget 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "net/http" 9 | "sync/atomic" 10 | "time" 11 | 12 | "github.com/spf13/viper" 13 | 14 | "github.com/replicate/pget/pkg/config" 15 | "github.com/replicate/pget/pkg/version" 16 | 17 | "github.com/dustin/go-humanize" 18 | "golang.org/x/sync/errgroup" 19 | 20 | "github.com/replicate/pget/pkg/consumer" 21 | "github.com/replicate/pget/pkg/download" 22 | "github.com/replicate/pget/pkg/logging" 23 | ) 24 | 25 | type MetricsPayload struct { 26 | Source string `json:"source,omitempty"` 27 | Type string `json:"type,omitempty"` 28 | Data map[string]any `json:"data,omitempty"` 29 | } 30 | 31 | type Getter struct { 32 | Downloader download.Strategy 33 | Consumer consumer.Consumer 34 | Options Options 35 | } 36 | 37 | type Options struct { 38 | MaxConcurrentFiles int 39 | MetricsEndpoint string 40 | } 41 | 42 | type ManifestEntry struct { 43 | URL string 44 | Dest string 45 | } 46 | 47 | // A Manifest is a slice of ManifestEntry, with a helper method to add entries 48 | type Manifest []ManifestEntry 49 | 50 | func (m Manifest) AddEntry(url string, destination string) Manifest { 51 | return append(m, ManifestEntry{URL: url, Dest: destination}) 52 | } 53 | 54 | func (g *Getter) DownloadFile(ctx context.Context, url string, dest string) (int64, time.Duration, error) { 55 | if g.Consumer == nil { 56 | g.Consumer = &consumer.FileWriter{} 57 | } 58 | 59 | logger := logging.GetLogger() 60 | downloadStartTime := time.Now() 61 | buffer, fileSize, err := g.Downloader.Fetch(ctx, url) 62 | if err != nil { 63 | g.sendMetrics(url, fileSize, 0, err) 64 | return fileSize, 0, err 65 | } 66 | // downloadElapsed := time.Since(downloadStartTime) 67 | // writeStartTime := time.Now() 68 | 69 | err = g.Consumer.Consume(buffer, dest, fileSize) 70 | if err != nil { 71 | g.sendMetrics(url, fileSize, 0, err) 72 | return fileSize, 0, fmt.Errorf("error writing file: %w", err) 73 | } 74 | 75 | // writeElapsed := time.Since(writeStartTime) 76 | totalElapsed := time.Since(downloadStartTime) 77 | 78 | g.sendMetrics(url, fileSize, (float64(fileSize) / totalElapsed.Seconds()), nil) 79 | 80 | size := humanize.Bytes(uint64(fileSize)) 81 | // downloadThroughput := humanize.Bytes(uint64(float64(fileSize) / downloadElapsed.Seconds())) 82 | // writeThroughput := humanize.Bytes(uint64(float64(fileSize) / writeElapsed.Seconds())) 83 | logger.Info(). 84 | Str("dest", dest). 85 | Str("url", url). 86 | Str("size", size). 87 | // Str("download_throughput", fmt.Sprintf("%s/s", downloadThroughput)). 88 | // Str("download_elapsed", fmt.Sprintf("%.3fs", downloadElapsed.Seconds())). 
89 | // Str("write_throughput", fmt.Sprintf("%s/s", writeThroughput)). 90 | // Str("write_elapsed", fmt.Sprintf("%.3fs", writeElapsed.Seconds())). 91 | Str("total_elapsed", fmt.Sprintf("%.3fs", totalElapsed.Seconds())). 92 | Msg("Complete") 93 | 94 | return fileSize, totalElapsed, nil 95 | } 96 | 97 | func (g *Getter) DownloadFiles(ctx context.Context, manifest Manifest) (int64, time.Duration, error) { 98 | if g.Consumer == nil { 99 | g.Consumer = &consumer.FileWriter{} 100 | } 101 | 102 | errGroup, ctx := errgroup.WithContext(ctx) 103 | 104 | if g.Options.MaxConcurrentFiles != 0 { 105 | errGroup.SetLimit(g.Options.MaxConcurrentFiles) 106 | } 107 | 108 | totalSize := new(atomic.Int64) 109 | multifileDownloadStart := time.Now() 110 | 111 | err := g.downloadFilesFromManifest(ctx, errGroup, manifest, totalSize) 112 | if err != nil { 113 | return 0, 0, fmt.Errorf("error initiating download of files from manifest: %w", err) 114 | } 115 | err = errGroup.Wait() 116 | if err != nil { 117 | return 0, 0, fmt.Errorf("error downloading files: %w", err) 118 | } 119 | elapsedTime := time.Since(multifileDownloadStart) 120 | return totalSize.Load(), elapsedTime, nil 121 | } 122 | 123 | func (g *Getter) downloadFilesFromManifest(ctx context.Context, eg *errgroup.Group, entries []ManifestEntry, totalSize *atomic.Int64) error { 124 | logger := logging.GetLogger() 125 | 126 | for _, entry := range entries { 127 | // Avoid the `entry` loop variable being captured by the 128 | // goroutine by creating new variables 129 | url, dest := entry.URL, entry.Dest 130 | logger.Debug().Str("url", url).Str("dest", dest).Msg("Queueing Download") 131 | 132 | eg.Go(func() error { 133 | return g.downloadAndMeasure(ctx, url, dest, totalSize) 134 | }) 135 | } 136 | return nil 137 | } 138 | 139 | func (g *Getter) downloadAndMeasure(ctx context.Context, url, dest string, totalSize *atomic.Int64) error { 140 | fileSize, _, err := g.DownloadFile(ctx, url, dest) 141 | if err != nil { 142 | return err 143 | } 144 | totalSize.Add(fileSize) 145 | return nil 146 | } 147 | 148 | func (g *Getter) sendMetrics(url string, size int64, throughput float64, err error) { 149 | logger := logging.GetLogger() 150 | endpoint := viper.GetString(config.OptMetricsEndpoint) 151 | if endpoint == "" { 152 | return 153 | } 154 | 155 | data := map[string]any{"url": url, "size": size, "version": version.GetVersion()} 156 | if err != nil { 157 | data["error"] = err.Error() 158 | } else { 159 | data["bytes_per_second"] = throughput 160 | } 161 | 162 | payload := MetricsPayload{ 163 | Source: "pget", 164 | Type: "download", 165 | Data: data, 166 | } 167 | body, err := json.Marshal(payload) 168 | if err != nil { 169 | logger.Debug().Err(err).Any("payload", payload).Msg("Error marshalling metrics") 170 | return 171 | } 172 | // Ignore error and response 173 | resp, err := http.DefaultClient.Post(endpoint, "application/json", bytes.NewBuffer(body)) 174 | if err != nil { 175 | logger.Debug().Err(err).Str("endpoint", endpoint).Msg("Error sending metrics") 176 | return 177 | } 178 | if resp.StatusCode != http.StatusOK { 179 | logger.Debug().Int("status_code", resp.StatusCode).Str("endpoint", endpoint).Msg("Error sending metrics") 180 | } 181 | } 182 | -------------------------------------------------------------------------------- /pkg/pget_test.go: -------------------------------------------------------------------------------- 1 | package pget_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "math/rand" 8 | "net/http" 9 | "net/http/httptest" 10 | "os" 11 
| "os/exec" 12 | "path/filepath" 13 | "testing" 14 | "testing/fstest" 15 | "testing/iotest" 16 | 17 | "github.com/dustin/go-humanize" 18 | "github.com/rs/zerolog" 19 | "github.com/stretchr/testify/assert" 20 | "github.com/stretchr/testify/require" 21 | 22 | pget "github.com/replicate/pget/pkg" 23 | "github.com/replicate/pget/pkg/client" 24 | "github.com/replicate/pget/pkg/download" 25 | ) 26 | 27 | var testFS = fstest.MapFS{ 28 | "hello.txt": {Data: []byte("hello, world!")}, 29 | } 30 | 31 | func init() { 32 | zerolog.SetGlobalLevel(zerolog.WarnLevel) 33 | } 34 | 35 | var defaultOpts = download.Options{Client: client.Options{}} 36 | var http2Opts = download.Options{Client: client.Options{TransportOpts: client.TransportOptions{ForceHTTP2: true}}} 37 | 38 | func makeGetter(opts download.Options) *pget.Getter { 39 | return &pget.Getter{ 40 | Downloader: download.GetBufferMode(opts), 41 | } 42 | } 43 | 44 | func tempFilename() string { 45 | // get a temp filename that doesn't already exist by creating 46 | // a temp file and immediately deleting it 47 | dest, _ := os.CreateTemp("", "pget-buffer-test") 48 | os.Remove(dest.Name()) 49 | return dest.Name() 50 | } 51 | 52 | // writeRandomFile creates a sparse file with the given size and 53 | // writes some random bytes somewhere in it. This is much faster than 54 | // filling the whole file with random bytes would be, but it also 55 | // gives us some confidence that the range requests are being 56 | // reassembled correctly. 57 | func writeRandomFile(t require.TestingT, path string, size int64) { 58 | file, err := os.Create(path) 59 | require.NoError(t, err) 60 | defer file.Close() 61 | 62 | rnd := rand.New(rand.NewSource(99)) 63 | 64 | // under 1 MiB, just fill the whole file with random data 65 | if size < 1*humanize.MiByte { 66 | _, err = io.CopyN(file, rnd, size) 67 | require.NoError(t, err) 68 | return 69 | } 70 | 71 | // set the file size 72 | err = file.Truncate(size) 73 | require.NoError(t, err) 74 | 75 | // write some random data to the start 76 | _, err = io.CopyN(file, rnd, 1*humanize.KiByte) 77 | require.NoError(t, err) 78 | 79 | // and somewhere else in the file 80 | _, err = file.Seek(rnd.Int63()%(size-1*humanize.KiByte), io.SeekStart) 81 | require.NoError(t, err) 82 | _, err = io.CopyN(file, rnd, 1*humanize.KiByte) 83 | require.NoError(t, err) 84 | } 85 | 86 | func assertFileHasContent(t *testing.T, expectedContent []byte, path string) { 87 | contentFile, err := os.Open(path) 88 | require.NoError(t, err) 89 | defer contentFile.Close() 90 | 91 | assert.NoError(t, iotest.TestReader(contentFile, expectedContent)) 92 | } 93 | 94 | func TestDownloadSmallFile(t *testing.T) { 95 | ts := httptest.NewServer(http.FileServer(http.FS(testFS))) 96 | defer ts.Close() 97 | 98 | dest := tempFilename() 99 | defer os.Remove(dest) 100 | 101 | getter := makeGetter(defaultOpts) 102 | 103 | _, _, err := getter.DownloadFile(context.Background(), ts.URL+"/hello.txt", dest) 104 | assert.NoError(t, err) 105 | 106 | assertFileHasContent(t, testFS["hello.txt"].Data, dest) 107 | } 108 | 109 | func testDownloadSingleFile(opts download.Options, size int64, t *testing.T) { 110 | dir, err := os.MkdirTemp("", "pget-buffer-test") 111 | require.NoError(t, err) 112 | defer os.RemoveAll(dir) 113 | 114 | srcFilename := filepath.Join(dir, "random-bytes") 115 | 116 | writeRandomFile(t, srcFilename, size) 117 | 118 | ts := httptest.NewServer(http.FileServer(http.Dir(dir))) 119 | defer ts.Close() 120 | 121 | getter := makeGetter(opts) 122 | 123 | dest := tempFilename() 124 
| defer os.Remove(dest) 125 | 126 | actualSize, _, err := getter.DownloadFile(context.Background(), ts.URL+"/random-bytes", dest) 127 | assert.NoError(t, err) 128 | 129 | assert.Equal(t, size, actualSize) 130 | 131 | cmd := exec.Command("diff", "-q", srcFilename, dest) 132 | err = cmd.Run() 133 | assert.NoError(t, err, "source file and dest file should be identical") 134 | } 135 | 136 | func TestDownloadSmallFileWith200(t *testing.T) { 137 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 138 | w.WriteHeader(http.StatusOK) 139 | _, err := w.Write([]byte("{\"message\": \"Tweet! Tweet!\"}")) 140 | assert.NoError(t, err) 141 | })) 142 | defer ts.Close() 143 | 144 | dest := tempFilename() 145 | defer os.Remove(dest) 146 | 147 | getter := makeGetter(defaultOpts) 148 | 149 | _, _, err := getter.DownloadFile(context.Background(), ts.URL+"/hello.txt", dest) 150 | assert.NoError(t, err) 151 | 152 | assertFileHasContent(t, []byte("{\"message\": \"Tweet! Tweet!\"}"), dest) 153 | } 154 | 155 | func TestDownload10MH1(t *testing.T) { testDownloadSingleFile(defaultOpts, 10*humanize.MiByte, t) } 156 | func TestDownload100MH1(t *testing.T) { testDownloadSingleFile(defaultOpts, 100*humanize.MiByte, t) } 157 | func TestDownload10MH2(t *testing.T) { testDownloadSingleFile(http2Opts, 10*humanize.MiByte, t) } 158 | func TestDownload100MH2(t *testing.T) { testDownloadSingleFile(http2Opts, 100*humanize.MiByte, t) } 159 | 160 | func testDownloadMultipleFiles(opts download.Options, sizes []int64, t *testing.T) { 161 | inputDir, err := os.MkdirTemp("", "pget-buffer-test-in") 162 | require.NoError(t, err) 163 | defer os.RemoveAll(inputDir) 164 | outputDir, err := os.MkdirTemp("", "pget-buffer-test-out") 165 | require.NoError(t, err) 166 | defer os.RemoveAll(outputDir) 167 | 168 | srcFilenames := make([]string, len(sizes)) 169 | var expectedTotalSize int64 170 | for i, size := range sizes { 171 | srcFilenames[i] = fmt.Sprintf("random-bytes.%d", i) 172 | 173 | writeRandomFile(t, filepath.Join(inputDir, srcFilenames[i]), size) 174 | expectedTotalSize += size 175 | } 176 | 177 | ts := httptest.NewServer(http.FileServer(http.Dir(inputDir))) 178 | defer ts.Close() 179 | 180 | manifest := make(pget.Manifest, 0) 181 | 182 | for _, srcFilename := range srcFilenames { 183 | manifest = manifest.AddEntry(ts.URL+"/"+srcFilename, filepath.Join(outputDir, srcFilename)) 184 | require.NoError(t, err) 185 | } 186 | 187 | getter := makeGetter(opts) 188 | 189 | actualTotalSize, _, err := getter.DownloadFiles(context.Background(), manifest) 190 | assert.NoError(t, err) 191 | 192 | assert.Equal(t, expectedTotalSize, actualTotalSize) 193 | 194 | cmd := exec.Command("diff", "-q", inputDir, outputDir) 195 | err = cmd.Run() 196 | assert.NoError(t, err, "source file and dest file should be identical") 197 | } 198 | 199 | func TestDownloadFiveFiles(t *testing.T) { 200 | testDownloadMultipleFiles(defaultOpts, []int64{ 201 | 10 * humanize.KiByte, 202 | 20 * humanize.KiByte, 203 | 30 * humanize.KiByte, 204 | 40 * humanize.KiByte, 205 | 50 * humanize.KiByte, 206 | }, t) 207 | } 208 | 209 | func TestDownloadFive10MFiles(t *testing.T) { 210 | testDownloadMultipleFiles(defaultOpts, []int64{ 211 | 10 * humanize.MiByte, 212 | 10 * humanize.MiByte, 213 | 10 * humanize.MiByte, 214 | 10 * humanize.MiByte, 215 | 10 * humanize.MiByte, 216 | }, t) 217 | } 218 | 219 | func TestManifest_AddEntry(t *testing.T) { 220 | entries := make(pget.Manifest, 0) 221 | 222 | entries = entries.AddEntry("https://example.com/file1.txt", 
"/tmp/file1.txt") 223 | assert.Len(t, entries, 1) 224 | entries = entries.AddEntry("https://example.org/file2.txt", "/tmp/file2.txt") 225 | assert.Len(t, entries, 2) 226 | 227 | assert.Equal(t, "https://example.com/file1.txt", entries[0].URL) 228 | assert.Equal(t, "/tmp/file1.txt", entries[0].Dest) 229 | assert.Equal(t, "https://example.org/file2.txt", entries[1].URL) 230 | assert.Equal(t, "/tmp/file2.txt", entries[1].Dest) 231 | 232 | } 233 | -------------------------------------------------------------------------------- /pkg/version/info.go: -------------------------------------------------------------------------------- 1 | package version 2 | 3 | import "fmt" 4 | 5 | const ( 6 | snapshotString = "snapshot" 7 | ) 8 | 9 | var ( 10 | // Version Build Time Injected information 11 | Version string 12 | CommitHash string 13 | BuildTime string 14 | Prerelease string 15 | Snapshot string 16 | OS string 17 | Arch string 18 | Branch string 19 | ) 20 | 21 | // GetVersion returns the version information in a human consumable way. This is intended to be used 22 | // when the user requests the version information or in the case of the User-Agent. 23 | func GetVersion() string { 24 | return makeVersionString(Version, CommitHash, BuildTime, Prerelease, Snapshot, OS, Arch, Branch) 25 | } 26 | 27 | func makeVersionString(version, commitHash, buildtime, prerelease, snapshot, os, arch, branch string) (versionString string) { 28 | versionString = fmt.Sprintf("%s(%s)", version, commitHash) 29 | if prerelease != "" { 30 | versionString = fmt.Sprintf("%s-%s", versionString, prerelease) 31 | } else if snapshot == "true" { 32 | versionString = fmt.Sprintf("%s-%s", versionString, snapshotString) 33 | } 34 | 35 | if branch != "" && branch != "main" && branch != "HEAD" { 36 | versionString = fmt.Sprintf("%s[%s]", versionString, branch) 37 | } 38 | 39 | if os != "" && arch != "" { 40 | versionString = fmt.Sprintf("%s/%s-%s", versionString, os, arch) 41 | } else if os != "" { 42 | versionString = fmt.Sprintf("%s/%s", versionString, os) 43 | } 44 | 45 | return versionString 46 | } 47 | -------------------------------------------------------------------------------- /pkg/version/info_test.go: -------------------------------------------------------------------------------- 1 | package version 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func Test_makeVersionString(t *testing.T) { 8 | type args struct { 9 | version string 10 | commitHash string 11 | buildtime string 12 | prerelease string 13 | snapshot string 14 | os string 15 | arch string 16 | branch string 17 | } 18 | tests := []struct { 19 | name string 20 | args args 21 | expected string 22 | }{ 23 | { 24 | name: "Typical Development", 25 | args: args{ 26 | version: "1.0.0", 27 | commitHash: "abc123", 28 | os: "darwin", 29 | arch: "amd64", 30 | branch: "Branch1", 31 | }, 32 | expected: "1.0.0(abc123)[Branch1]/darwin-amd64", 33 | }, 34 | { 35 | name: "With prerelease and snapshot", 36 | args: args{ 37 | version: "1.0.0", 38 | commitHash: "abc123", 39 | prerelease: "alpha", 40 | snapshot: "20221130", 41 | os: "darwin", 42 | arch: "amd64", 43 | branch: "Branch1", 44 | }, 45 | expected: "1.0.0(abc123)-alpha[Branch1]/darwin-amd64", 46 | }, 47 | { 48 | name: "No os or arch", 49 | args: args{ 50 | version: "1.0.0", 51 | commitHash: "abc123", 52 | branch: "Branch1", 53 | }, 54 | expected: "1.0.0(abc123)[Branch1]", 55 | }, 56 | { 57 | name: "Branch Main", 58 | args: args{ 59 | version: "1.0.0", 60 | commitHash: "abc123", 61 | os: "darwin", 62 | arch: "amd64", 63 | branch: 
"main", 64 | }, 65 | expected: "1.0.0(abc123)/darwin-amd64", 66 | }, 67 | { 68 | name: "Branch HEAD", 69 | args: args{ 70 | version: "1.0.0", 71 | commitHash: "abc123", 72 | os: "darwin", 73 | arch: "amd64", 74 | branch: "HEAD", 75 | }, 76 | expected: "1.0.0(abc123)/darwin-amd64", 77 | }, 78 | } 79 | for _, tt := range tests { 80 | t.Run(tt.name, func(t *testing.T) { 81 | if got := makeVersionString(tt.args.version, tt.args.commitHash, tt.args.buildtime, tt.args.prerelease, tt.args.snapshot, tt.args.os, tt.args.arch, tt.args.branch); got != tt.expected { 82 | t.Errorf("makeVersionString() = %v, expected %v", got, tt.expected) 83 | } 84 | }) 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /script/format: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -eu 4 | 5 | : "${CHECKONLY:=}" 6 | : "${WORKDIR:=$PWD}" 7 | 8 | LOCAL=$(go list -m) 9 | 10 | if [ -n "$CHECKONLY" ]; then 11 | OUTPUT=$(go run -C "$WORKDIR" golang.org/x/tools/cmd/goimports -d -local "$LOCAL" .) 12 | printf "%s" "$OUTPUT" 13 | 14 | if [ -n "$OUTPUT" ]; then 15 | exit 1 16 | fi 17 | exit 18 | fi 19 | 20 | exec go run golang.org/x/tools/cmd/goimports -d -w -local "$LOCAL" . 21 | -------------------------------------------------------------------------------- /script/lint: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -eu 4 | 5 | : "${GITHUB_ACTIONS:=}" 6 | : "${WORKDIR:=$PWD}" 7 | 8 | cd "$(dirname "$0")" 9 | cd .. 10 | 11 | if [ "$GITHUB_ACTIONS" = "true" ]; then 12 | set -- "$@" --out-format=github-actions 13 | fi 14 | 15 | exec go run -C "$WORKDIR" github.com/golangci/golangci-lint/cmd/golangci-lint run ./... "$@" 16 | -------------------------------------------------------------------------------- /script/test: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -eu 4 | 5 | : "${GITHUB_ACTIONS:=}" 6 | : "${WORKDIR:=$PWD}" 7 | 8 | cd "$(dirname "$0")" 9 | cd .. 10 | 11 | exec go run -C "$WORKDIR" gotest.tools/gotestsum "$@" -- -timeout 1200s -parallel 5 ./... 12 | -------------------------------------------------------------------------------- /tools.go: -------------------------------------------------------------------------------- 1 | //go:build tools 2 | // +build tools 3 | 4 | // https://github.com/go-modules-by-example/index/blob/master/010_tools/README.md 5 | 6 | package tools 7 | 8 | import ( 9 | _ "github.com/golangci/golangci-lint/cmd/golangci-lint" 10 | _ "golang.org/x/tools/cmd/goimports" 11 | _ "gotest.tools/gotestsum" 12 | ) 13 | --------------------------------------------------------------------------------