├── .gitignore ├── .travis.yml ├── LICENSE ├── Makefile ├── README.md ├── states.wsd └── v3 ├── client.go ├── client_test.go ├── cmd └── grab │ ├── .gitignore │ ├── Makefile │ └── main.go ├── doc.go ├── error.go ├── example_client_test.go ├── example_request_test.go ├── go.mod ├── grab.go ├── grab_test.go ├── pkg ├── bps │ ├── bps.go │ ├── sma.go │ └── sma_test.go ├── grabtest │ ├── assert.go │ ├── handler.go │ ├── handler_option.go │ ├── handler_test.go │ └── util.go └── grabui │ ├── console_client.go │ └── grabui.go ├── rate_limiter.go ├── rate_limiter_test.go ├── request.go ├── response.go ├── response_test.go ├── transfer.go ├── util.go └── util_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore IDE project files 2 | *.iml 3 | .idea/ 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - tip 5 | - 1.17.x 6 | - 1.16.x 7 | - 1.15.x 8 | - 1.14.x 9 | 10 | script: make check 11 | 12 | env: 13 | - GOARCH=amd64 14 | - GOARCH=386 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Ryan Armstrong. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, 4 | are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | 3. 
Neither the name of the copyright holder nor the names of its contributors 14 | may be used to endorse or promote products derived from this software without 15 | specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 24 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | GO = go 2 | GOGET = $(GO) get -u 3 | 4 | all: check 5 | 6 | check: 7 | cd v3 && $(GO) test -v -cover -race ./... 8 | cd v3/cmd/grab && $(MAKE) -B all 9 | 10 | install: 11 | cd v3/cmd/grab && $(MAKE) install 12 | 13 | clean: 14 | cd v3 && $(GO) clean -x ./... 
15 | rm -rvf ./.test* 16 | 17 | .PHONY: all check install clean 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # grab 2 | 3 | [![GoDoc](https://godoc.org/github.com/cavaliercoder/grab?status.svg)](https://godoc.org/github.com/cavaliercoder/grab) [![Build Status](https://travis-ci.org/cavaliercoder/grab.svg?branch=master)](https://travis-ci.org/cavaliercoder/grab) [![Go Report Card](https://goreportcard.com/badge/github.com/cavaliercoder/grab)](https://goreportcard.com/report/github.com/cavaliercoder/grab) 4 | 5 | *Downloading the internet, one goroutine at a time!* 6 | 7 | $ go get github.com/cavaliergopher/grab/v3 8 | 9 | Grab is a Go package for downloading files from the internet with the following 10 | rad features: 11 | 12 | * Monitor download progress concurrently 13 | * Auto-resume incomplete downloads 14 | * Guess filename from content header or URL path 15 | * Safely cancel downloads using context.Context 16 | * Validate downloads using checksums 17 | * Download batches of files concurrently 18 | * Apply rate limiters 19 | 20 | Requires Go v1.7+ 21 | 22 | ## Example 23 | 24 | The following example downloads a PDF copy of the free eBook, "An Introduction 25 | to Programming in Go" into the current working directory. 26 | 27 | ```go 28 | resp, err := grab.Get(".", "http://www.golang-book.com/public/pdf/gobook.pdf") 29 | if err != nil { 30 | log.Fatal(err) 31 | } 32 | 33 | fmt.Println("Download saved to", resp.Filename) 34 | ``` 35 | 36 | The following, more complete example allows for more granular control and 37 | periodically prints the download progress until it is complete. 38 | 39 | The second time you run the example, it will auto-resume the previous download 40 | and exit sooner. 
41 | 42 | ```go 43 | package main 44 | 45 | import ( 46 | "fmt" 47 | "os" 48 | "time" 49 | 50 | "github.com/cavaliergopher/grab/v3" 51 | ) 52 | 53 | func main() { 54 | // create client 55 | client := grab.NewClient() 56 | req, _ := grab.NewRequest(".", "http://www.golang-book.com/public/pdf/gobook.pdf") 57 | 58 | // start download 59 | fmt.Printf("Downloading %v...\n", req.URL()) 60 | resp := client.Do(req) 61 | fmt.Printf(" %v\n", resp.HTTPResponse.Status) 62 | 63 | // start UI loop 64 | t := time.NewTicker(500 * time.Millisecond) 65 | defer t.Stop() 66 | 67 | Loop: 68 | for { 69 | select { 70 | case <-t.C: 71 | fmt.Printf(" transferred %v / %v bytes (%.2f%%)\n", 72 | resp.BytesComplete(), 73 | resp.Size(), 74 | 100*resp.Progress()) 75 | 76 | case <-resp.Done: 77 | // download is complete 78 | break Loop 79 | } 80 | } 81 | 82 | // check for errors 83 | if err := resp.Err(); err != nil { 84 | fmt.Fprintf(os.Stderr, "Download failed: %v\n", err) 85 | os.Exit(1) 86 | } 87 | 88 | fmt.Printf("Download saved to ./%v \n", resp.Filename) 89 | 90 | // Output: 91 | // Downloading http://www.golang-book.com/public/pdf/gobook.pdf... 92 | // 200 OK 93 | // transferred 42970 / 2893557 bytes (1.49%) 94 | // transferred 1207474 / 2893557 bytes (41.73%) 95 | // transferred 2758210 / 2893557 bytes (95.32%) 96 | // Download saved to ./gobook.pdf 97 | } 98 | ``` 99 | 100 | ## Design trade-offs 101 | 102 | The primary use case for Grab is to concurrently download thousands of large 103 | files from remote file repositories where the remote files are immutable. 104 | Examples include operating system package repositories or ISO libraries. 105 | 106 | Grab aims to provide robust, sane defaults. These are usually determined using 107 | the HTTP specifications, or by mimicking the behavior of common web clients like 108 | cURL, wget and common web browsers. 109 | 110 | Grab aims to be stateless. 
The only state that exists is the remote files you 111 | wish to download and the local copy which may be completed, partially completed 112 | or not yet created. The advantage to this is that the local file system is not 113 | cluttered unnecessarily with additional state files (like a `.crdownload` file). 114 | The disadvantage of this approach is that grab must make assumptions about the 115 | local and remote state; specifically, that they have not been modified by 116 | another program. 117 | 118 | If the local or remote file is modified outside of grab, and you download the 119 | file again with resuming enabled, the local file will likely become corrupted. 120 | In this case, you might consider making remote files immutable, or disabling 121 | resume. 122 | 123 | Grab aims to enable best-in-class functionality for more complex features 124 | through extensible interfaces, rather than reimplementation. For example, 125 | you can provide your own Hash algorithm to compute file checksums, or your 126 | own rate limiter implementation (with all the associated trade-offs) to rate 127 | limit downloads. 
128 | -------------------------------------------------------------------------------- /states.wsd: -------------------------------------------------------------------------------- 1 | @startuml 2 | title Grab transfer state 3 | 4 | legend 5 | | # | Meaning | 6 | | D | Destination path known | 7 | | S | File size known | 8 | | O | Server options known (Accept-Ranges) | 9 | | R | Resume supported (Accept-Ranges) | 10 | | Z | Local file empty or missing | 11 | | P | Local file partially complete | 12 | endlegend 13 | 14 | [*] --> Empty 15 | [*] --> D 16 | [*] --> S 17 | [*] --> DS 18 | 19 | Empty : Filename: "" 20 | Empty : Size: 0 21 | Empty --> O : HEAD: Method not allowed 22 | Empty --> DSO : HEAD: Range not supported 23 | Empty --> DSOR : HEAD: Range supported 24 | 25 | DS : Filename: "foo.bar" 26 | DS : Size: > 0 27 | DS --> DSZ : checkExisting(): File missing 28 | DS --> DSP : checkExisting(): File partial 29 | DS --> [*] : checkExisting(): File complete 30 | DS --> ERROR 31 | 32 | S : Filename: "" 33 | S : Size: > 0 34 | S --> SO : HEAD: Method not allowed 35 | S --> DSO : HEAD: Range not supported 36 | S --> DSOR : HEAD: Range supported 37 | 38 | D : Filename: "foo.bar" 39 | D : Size: 0 40 | D --> DO : HEAD: Method not allowed 41 | D --> DSO : HEAD: Range not supported 42 | D --> DSOR : HEAD: Range supported 43 | 44 | 45 | O : Filename: "" 46 | O : Size: 0 47 | O : CanResume: false 48 | O --> DSO : GET 200 49 | O --> ERROR 50 | 51 | SO : Filename: "" 52 | SO : Size: > 0 53 | SO : CanResume: false 54 | SO --> DSO : GET: 200 55 | SO --> ERROR 56 | 57 | DO : Filename: "foo.bar" 58 | DO : Size: 0 59 | DO : CanResume: false 60 | DO --> DSO : GET 200 61 | DO --> ERROR 62 | 63 | DSZ : Filename: "foo.bar" 64 | DSZ : Size: > 0 65 | DSZ : File: empty 66 | DSZ --> DSORZ : HEAD: Range supported 67 | DSZ --> DSOZ : HEAD 405 or Range unsupported 68 | 69 | DSP : Filename: "foo.bar" 70 | DSP : Size: > 0 71 | DSP : File: partial 72 | DSP --> DSORP : HEAD: Range supported 73 | 
DSP --> DSOZ : HEAD: 405 or Range unsupported 74 | 75 | DSO : Filename: "foo.bar" 76 | DSO : Size: > 0 77 | DSO : CanResume: false 78 | DSO --> DSOZ : checkExisting(): File partial|missing 79 | DSO --> [*] : checkExisting(): File complete 80 | 81 | DSOR : Filename: "foo.bar" 82 | DSOR : Size: > 0 83 | DSOR : CanResume: true 84 | DSOR --> DSORP : CheckLocal: File partial 85 | DSOR --> DSORZ : CheckLocal: File missing 86 | 87 | DSORP : Filename: "foo.bar" 88 | DSORP : Size: > 0 89 | DSORP : CanResume: true 90 | DSORP : File: partial 91 | DSORP --> Transferring 92 | 93 | DSORZ : Filename: "foo.bar" 94 | DSORZ : Size: > 0 95 | DSORZ : CanResume: true 96 | DSORZ : File: empty 97 | DSORZ --> Transferring 98 | 99 | DSOZ : Filename: "foo.bar" 100 | DSOZ : Size: > 0 101 | DSOZ : CanResume: false 102 | DSOZ : File: empty 103 | DSOZ --> Transferring 104 | 105 | Transferring --> [*] 106 | Transferring --> ERROR 107 | 108 | ERROR : Something went wrong 109 | ERROR --> [*] 110 | 111 | @enduml -------------------------------------------------------------------------------- /v3/client.go: -------------------------------------------------------------------------------- 1 | package grab 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "os" 10 | "path/filepath" 11 | "sync" 12 | "sync/atomic" 13 | "time" 14 | ) 15 | 16 | // HTTPClient provides an interface allowing us to perform HTTP requests. 17 | type HTTPClient interface { 18 | Do(req *http.Request) (*http.Response, error) 19 | } 20 | 21 | // truncater is a private interface allowing different response 22 | // Writers to be truncated 23 | type truncater interface { 24 | Truncate(size int64) error 25 | } 26 | 27 | // A Client is a file download client. 28 | // 29 | // Clients are safe for concurrent use by multiple goroutines. 30 | type Client struct { 31 | // HTTPClient specifies the http.Client which will be used for communicating 32 | // with the remote server during the file transfer. 
33 | HTTPClient HTTPClient 34 | 35 | // UserAgent specifies the User-Agent string which will be set in the 36 | // headers of all requests made by this client. 37 | // 38 | // The user agent string may be overridden in the headers of each request. 39 | UserAgent string 40 | 41 | // BufferSize specifies the size in bytes of the buffer that is used for 42 | // transferring all requested files. Larger buffers may result in faster 43 | // throughput but will use more memory and result in less frequent updates 44 | // to the transfer progress statistics. The BufferSize of each request can 45 | // be overridden on each Request object. Default: 32KB. 46 | BufferSize int 47 | } 48 | 49 | // NewClient returns a new file download Client, using default configuration. 50 | func NewClient() *Client { 51 | return &Client{ 52 | UserAgent: "grab", 53 | HTTPClient: &http.Client{ 54 | Transport: &http.Transport{ 55 | Proxy: http.ProxyFromEnvironment, 56 | }, 57 | }, 58 | } 59 | } 60 | 61 | // DefaultClient is the default client and is used by all Get convenience 62 | // functions. 63 | var DefaultClient = NewClient() 64 | 65 | // Do sends a file transfer request and returns a file transfer response, 66 | // following policy (e.g. redirects, cookies, auth) as configured on the 67 | // client's HTTPClient. 68 | // 69 | // Like http.Get, Do blocks while the transfer is initiated, but returns as soon 70 | // as the transfer has started transferring in a background goroutine, or if it 71 | // failed early. 72 | // 73 | // An error is returned via Response.Err if caused by client policy (such as 74 | // CheckRedirect), or if there was an HTTP protocol or IO error. Response.Err 75 | // will block the caller until the transfer is completed, successfully or 76 | // otherwise. 
77 | func (c *Client) Do(req *Request) *Response { 78 | // cancel will be called on all code-paths via closeResponse 79 | ctx, cancel := context.WithCancel(req.Context()) 80 | req = req.WithContext(ctx) 81 | resp := &Response{ 82 | Request: req, 83 | Start: time.Now(), 84 | Done: make(chan struct{}, 0), 85 | Filename: req.Filename, 86 | ctx: ctx, 87 | cancel: cancel, 88 | bufferSize: req.BufferSize, 89 | } 90 | if resp.bufferSize == 0 { 91 | // default to Client.BufferSize 92 | resp.bufferSize = c.BufferSize 93 | } 94 | 95 | // Run state-machine while caller is blocked to initialize the file transfer. 96 | // Must never transition to the copyFile state - this happens next in another 97 | // goroutine. 98 | c.run(resp, c.statFileInfo) 99 | 100 | // Run copyFile in a new goroutine. copyFile will no-op if the transfer is 101 | // already complete or failed. 102 | go c.run(resp, c.copyFile) 103 | return resp 104 | } 105 | 106 | // DoChannel executes all requests sent through the given Request channel, one 107 | // at a time, until it is closed by another goroutine. The caller is blocked 108 | // until the Request channel is closed and all transfers have completed. All 109 | // responses are sent through the given Response channel as soon as they are 110 | // received from the remote servers and can be used to track the progress of 111 | // each download. 112 | // 113 | // Slow Response receivers will cause a worker to block and therefore delay the 114 | // start of the transfer for an already initiated connection - potentially 115 | // causing a server timeout. It is the caller's responsibility to ensure a 116 | // sufficient buffer size is used for the Response channel to prevent this. 117 | // 118 | // If an error occurs during any of the file transfers it will be accessible via 119 | // the associated Response.Err function. 
120 | func (c *Client) DoChannel(reqch <-chan *Request, respch chan<- *Response) { 121 | // TODO: enable cancelling of batch jobs 122 | for req := range reqch { 123 | resp := c.Do(req) 124 | respch <- resp 125 | <-resp.Done 126 | } 127 | } 128 | 129 | // DoBatch executes all the given requests using the given number of concurrent 130 | // workers. Control is passed back to the caller as soon as the workers are 131 | // initiated. 132 | // 133 | // If the requested number of workers is less than one, a worker will be created 134 | // for every request. I.e. all requests will be executed concurrently. 135 | // 136 | // If an error occurs during any of the file transfers it will be accessible via 137 | // call to the associated Response.Err. 138 | // 139 | // The returned Response channel is closed only after all of the given Requests 140 | // have completed, successfully or otherwise. 141 | func (c *Client) DoBatch(workers int, requests ...*Request) <-chan *Response { 142 | if workers < 1 { 143 | workers = len(requests) 144 | } 145 | reqch := make(chan *Request, len(requests)) 146 | respch := make(chan *Response, len(requests)) 147 | wg := sync.WaitGroup{} 148 | for i := 0; i < workers; i++ { 149 | wg.Add(1) 150 | go func() { 151 | c.DoChannel(reqch, respch) 152 | wg.Done() 153 | }() 154 | } 155 | 156 | // queue requests 157 | go func() { 158 | for _, req := range requests { 159 | reqch <- req 160 | } 161 | close(reqch) 162 | wg.Wait() 163 | close(respch) 164 | }() 165 | return respch 166 | } 167 | 168 | // An stateFunc is an action that mutates the state of a Response and returns 169 | // the next stateFunc to be called. 170 | type stateFunc func(*Response) stateFunc 171 | 172 | // run calls the given stateFunc function and all subsequent returned stateFuncs 173 | // until a stateFunc returns nil or the Response.ctx is canceled. Each stateFunc 174 | // should mutate the state of the given Response until it has completed 175 | // downloading or failed. 
176 | func (c *Client) run(resp *Response, f stateFunc) { 177 | for { 178 | select { 179 | case <-resp.ctx.Done(): 180 | if resp.IsComplete() { 181 | return 182 | } 183 | resp.err = resp.ctx.Err() 184 | f = c.closeResponse 185 | 186 | default: 187 | // keep working 188 | } 189 | if f = f(resp); f == nil { 190 | return 191 | } 192 | } 193 | } 194 | 195 | // statFileInfo retrieves FileInfo for any local file matching 196 | // Response.Filename. 197 | // 198 | // If the file does not exist, is a directory, or its name is unknown the next 199 | // stateFunc is headRequest. 200 | // 201 | // If the file exists, Response.fi is set and the next stateFunc is 202 | // validateLocal. 203 | // 204 | // If an error occurs, the next stateFunc is closeResponse. 205 | func (c *Client) statFileInfo(resp *Response) stateFunc { 206 | if resp.Request.NoStore || resp.Filename == "" { 207 | return c.headRequest 208 | } 209 | fi, err := os.Stat(resp.Filename) 210 | if err != nil { 211 | if os.IsNotExist(err) { 212 | return c.headRequest 213 | } 214 | resp.err = err 215 | return c.closeResponse 216 | } 217 | if fi.IsDir() { 218 | resp.Filename = "" 219 | return c.headRequest 220 | } 221 | resp.fi = fi 222 | return c.validateLocal 223 | } 224 | 225 | // validateLocal compares a local copy of the downloaded file to the remote 226 | // file. 227 | // 228 | // An error is returned if the local file is larger than the remote file, or 229 | // Request.SkipExisting is true. 230 | // 231 | // If the existing file matches the length of the remote file, the next 232 | // stateFunc is checksumFile. 233 | // 234 | // If the local file is smaller than the remote file and the remote server is 235 | // known to support ranged requests, the next stateFunc is getRequest. 
236 | func (c *Client) validateLocal(resp *Response) stateFunc { 237 | if resp.Request.SkipExisting { 238 | resp.err = ErrFileExists 239 | return c.closeResponse 240 | } 241 | 242 | // determine target file size 243 | expectedSize := resp.Request.Size 244 | if expectedSize == 0 && resp.HTTPResponse != nil { 245 | expectedSize = resp.HTTPResponse.ContentLength 246 | } 247 | 248 | if expectedSize == 0 { 249 | // size is either actually 0 or unknown 250 | // if unknown, we ask the remote server 251 | // if known to be 0, we proceed with a GET 252 | return c.headRequest 253 | } 254 | 255 | if expectedSize == resp.fi.Size() { 256 | // local file matches remote file size - wrap it up 257 | resp.DidResume = true 258 | resp.bytesResumed = resp.fi.Size() 259 | return c.checksumFile 260 | } 261 | 262 | if resp.Request.NoResume { 263 | // local file should be overwritten 264 | return c.getRequest 265 | } 266 | 267 | if expectedSize >= 0 && expectedSize < resp.fi.Size() { 268 | // remote size is known, is smaller than local size and we want to resume 269 | resp.err = ErrBadLength 270 | return c.closeResponse 271 | } 272 | 273 | if resp.CanResume { 274 | // set resume range on GET request 275 | resp.Request.HTTPRequest.Header.Set( 276 | "Range", 277 | fmt.Sprintf("bytes=%d-", resp.fi.Size())) 278 | resp.DidResume = true 279 | resp.bytesResumed = resp.fi.Size() 280 | return c.getRequest 281 | } 282 | return c.headRequest 283 | } 284 | 285 | func (c *Client) checksumFile(resp *Response) stateFunc { 286 | if resp.Request.hash == nil { 287 | return c.closeResponse 288 | } 289 | if resp.Filename == "" { 290 | panic("grab: developer error: filename not set") 291 | } 292 | if resp.Size() < 0 { 293 | panic("grab: developer error: size unknown") 294 | } 295 | req := resp.Request 296 | 297 | // compute checksum 298 | var sum []byte 299 | sum, resp.err = resp.checksumUnsafe() 300 | if resp.err != nil { 301 | return c.closeResponse 302 | } 303 | 304 | // compare checksum 305 | if 
!bytes.Equal(sum, req.checksum) { 306 | resp.err = ErrBadChecksum 307 | if !resp.Request.NoStore && req.deleteOnError { 308 | if err := os.Remove(resp.Filename); err != nil { 309 | // err should be os.PathError and include file path 310 | resp.err = fmt.Errorf( 311 | "cannot remove downloaded file with checksum mismatch: %v", 312 | err) 313 | } 314 | } 315 | } 316 | return c.closeResponse 317 | } 318 | 319 | // doHTTPRequest sends a HTTP Request and returns the response 320 | func (c *Client) doHTTPRequest(req *http.Request) (*http.Response, error) { 321 | if c.UserAgent != "" && req.Header.Get("User-Agent") == "" { 322 | req.Header.Set("User-Agent", c.UserAgent) 323 | } 324 | return c.HTTPClient.Do(req) 325 | } 326 | 327 | func (c *Client) headRequest(resp *Response) stateFunc { 328 | if resp.optionsKnown { 329 | return c.getRequest 330 | } 331 | resp.optionsKnown = true 332 | 333 | if resp.Request.NoResume { 334 | return c.getRequest 335 | } 336 | 337 | if resp.Filename != "" && resp.fi == nil { 338 | // destination path is already known and does not exist 339 | return c.getRequest 340 | } 341 | 342 | hreq := new(http.Request) 343 | *hreq = *resp.Request.HTTPRequest 344 | hreq.Method = "HEAD" 345 | 346 | resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq) 347 | if resp.err != nil { 348 | return c.closeResponse 349 | } 350 | resp.HTTPResponse.Body.Close() 351 | 352 | if resp.HTTPResponse.StatusCode != http.StatusOK { 353 | return c.getRequest 354 | } 355 | 356 | // In case of redirects during HEAD, record the final URL and use it 357 | // instead of the original URL when sending future requests. 358 | // This way we avoid sending potentially unsupported requests to 359 | // the original URL, e.g. "Range", since it was the final URL 360 | // that advertised its support. 
361 | resp.Request.HTTPRequest.URL = resp.HTTPResponse.Request.URL 362 | resp.Request.HTTPRequest.Host = resp.HTTPResponse.Request.Host 363 | 364 | return c.readResponse 365 | } 366 | 367 | func (c *Client) getRequest(resp *Response) stateFunc { 368 | resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest) 369 | if resp.err != nil { 370 | return c.closeResponse 371 | } 372 | 373 | // TODO: check Content-Range 374 | 375 | // check status code 376 | if !resp.Request.IgnoreBadStatusCodes { 377 | if resp.HTTPResponse.StatusCode < 200 || resp.HTTPResponse.StatusCode > 299 { 378 | resp.err = StatusCodeError(resp.HTTPResponse.StatusCode) 379 | return c.closeResponse 380 | } 381 | } 382 | 383 | return c.readResponse 384 | } 385 | 386 | func (c *Client) readResponse(resp *Response) stateFunc { 387 | if resp.HTTPResponse == nil { 388 | panic("grab: developer error: Response.HTTPResponse is nil") 389 | } 390 | 391 | // check expected size 392 | resp.sizeUnsafe = resp.HTTPResponse.ContentLength 393 | if resp.sizeUnsafe >= 0 { 394 | // remote size is known 395 | resp.sizeUnsafe += resp.bytesResumed 396 | if resp.Request.Size > 0 && resp.Request.Size != resp.sizeUnsafe { 397 | resp.err = ErrBadLength 398 | return c.closeResponse 399 | } 400 | } 401 | 402 | // check filename 403 | if resp.Filename == "" { 404 | filename, err := guessFilename(resp.HTTPResponse) 405 | if err != nil { 406 | resp.err = err 407 | return c.closeResponse 408 | } 409 | // Request.Filename will be empty or a directory 410 | resp.Filename = filepath.Join(resp.Request.Filename, filename) 411 | } 412 | 413 | if !resp.Request.NoStore && resp.requestMethod() == "HEAD" { 414 | if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" { 415 | resp.CanResume = true 416 | } 417 | return c.statFileInfo 418 | } 419 | return c.openWriter 420 | } 421 | 422 | // openWriter opens the destination file for writing and seeks to the location 423 | // from whence the file transfer will resume. 
424 | // 425 | // Requires that Response.Filename and resp.DidResume are already be set. 426 | func (c *Client) openWriter(resp *Response) stateFunc { 427 | if !resp.Request.NoStore && !resp.Request.NoCreateDirectories { 428 | resp.err = mkdirp(resp.Filename) 429 | if resp.err != nil { 430 | return c.closeResponse 431 | } 432 | } 433 | 434 | if resp.Request.NoStore { 435 | resp.writer = &resp.storeBuffer 436 | } else { 437 | // compute write flags 438 | flag := os.O_CREATE | os.O_WRONLY 439 | if resp.fi != nil { 440 | if resp.DidResume { 441 | flag = os.O_APPEND | os.O_WRONLY 442 | } else { 443 | // truncate later in copyFile, if not cancelled 444 | // by BeforeCopy hook 445 | flag = os.O_WRONLY 446 | } 447 | } 448 | 449 | // open file 450 | f, err := os.OpenFile(resp.Filename, flag, 0666) 451 | if err != nil { 452 | resp.err = err 453 | return c.closeResponse 454 | } 455 | resp.writer = f 456 | 457 | // seek to start or end 458 | whence := os.SEEK_SET 459 | if resp.bytesResumed > 0 { 460 | whence = os.SEEK_END 461 | } 462 | _, resp.err = f.Seek(0, whence) 463 | if resp.err != nil { 464 | return c.closeResponse 465 | } 466 | } 467 | 468 | // init transfer 469 | if resp.bufferSize < 1 { 470 | resp.bufferSize = 32 * 1024 471 | } 472 | b := make([]byte, resp.bufferSize) 473 | resp.transfer = newTransfer( 474 | resp.Request.Context(), 475 | resp.Request.RateLimiter, 476 | resp.writer, 477 | resp.HTTPResponse.Body, 478 | b) 479 | 480 | // next step is copyFile, but this will be called later in another goroutine 481 | return nil 482 | } 483 | 484 | // copy transfers content for a HTTP connection established via Client.do() 485 | func (c *Client) copyFile(resp *Response) stateFunc { 486 | if resp.IsComplete() { 487 | return nil 488 | } 489 | 490 | // run BeforeCopy hook 491 | if f := resp.Request.BeforeCopy; f != nil { 492 | resp.err = f(resp) 493 | if resp.err != nil { 494 | return c.closeResponse 495 | } 496 | } 497 | 498 | var bytesCopied int64 499 | if resp.transfer 
== nil { 500 | panic("grab: developer error: Response.transfer is nil") 501 | } 502 | 503 | // We waited to truncate the file in openWriter() to make sure 504 | // the BeforeCopy didn't cancel the copy. If this was an existing 505 | // file that is not going to be resumed, truncate the contents. 506 | if t, ok := resp.writer.(truncater); ok && resp.fi != nil && !resp.DidResume { 507 | t.Truncate(0) 508 | } 509 | 510 | bytesCopied, resp.err = resp.transfer.copy() 511 | if resp.err != nil { 512 | return c.closeResponse 513 | } 514 | closeWriter(resp) 515 | 516 | // set file timestamp 517 | if !resp.Request.NoStore && !resp.Request.IgnoreRemoteTime { 518 | resp.err = setLastModified(resp.HTTPResponse, resp.Filename) 519 | if resp.err != nil { 520 | return c.closeResponse 521 | } 522 | } 523 | 524 | // update transfer size if previously unknown 525 | if resp.Size() < 0 { 526 | discoveredSize := resp.bytesResumed + bytesCopied 527 | atomic.StoreInt64(&resp.sizeUnsafe, discoveredSize) 528 | if resp.Request.Size > 0 && resp.Request.Size != discoveredSize { 529 | resp.err = ErrBadLength 530 | return c.closeResponse 531 | } 532 | } 533 | 534 | // run AfterCopy hook 535 | if f := resp.Request.AfterCopy; f != nil { 536 | resp.err = f(resp) 537 | if resp.err != nil { 538 | return c.closeResponse 539 | } 540 | } 541 | 542 | return c.checksumFile 543 | } 544 | 545 | func closeWriter(resp *Response) { 546 | if closer, ok := resp.writer.(io.Closer); ok { 547 | closer.Close() 548 | } 549 | resp.writer = nil 550 | } 551 | 552 | // close finalizes the Response 553 | func (c *Client) closeResponse(resp *Response) stateFunc { 554 | if resp.IsComplete() { 555 | panic("grab: developer error: response already closed") 556 | } 557 | 558 | resp.fi = nil 559 | closeWriter(resp) 560 | resp.closeResponseBody() 561 | 562 | resp.End = time.Now() 563 | close(resp.Done) 564 | if resp.cancel != nil { 565 | resp.cancel() 566 | } 567 | 568 | return nil 569 | } 570 | 
-------------------------------------------------------------------------------- /v3/client_test.go: -------------------------------------------------------------------------------- 1 | package grab 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "crypto/md5" 7 | "crypto/sha1" 8 | "crypto/sha256" 9 | "crypto/sha512" 10 | "errors" 11 | "fmt" 12 | "hash" 13 | "io/ioutil" 14 | "math/rand" 15 | "net/http" 16 | "os" 17 | "path/filepath" 18 | "strings" 19 | "testing" 20 | "time" 21 | 22 | "github.com/cavaliergopher/grab/v3/pkg/grabtest" 23 | ) 24 | 25 | // TestFilenameResolutions tests that the destination filename for Requests can 26 | // be determined correctly, using an explicitly requested path, 27 | // Content-Disposition headers or a URL path - with or without an existing 28 | // target directory. 29 | func TestFilenameResolution(t *testing.T) { 30 | tests := []struct { 31 | Name string 32 | Filename string 33 | URL string 34 | AttachmentFilename string 35 | Expect string 36 | }{ 37 | {"Using Request.Filename", ".testWithFilename", "/url-filename", "header-filename", ".testWithFilename"}, 38 | {"Using Content-Disposition Header", "", "/url-filename", ".testWithHeaderFilename", ".testWithHeaderFilename"}, 39 | {"Using Content-Disposition Header with target directory", ".test", "/url-filename", "header-filename", ".test/header-filename"}, 40 | {"Using URL Path", "", "/.testWithURLFilename?params-filename", "", ".testWithURLFilename"}, 41 | {"Using URL Path with target directory", ".test", "/url-filename?garbage", "", ".test/url-filename"}, 42 | {"Failure", "", "", "", ""}, 43 | } 44 | 45 | err := os.Mkdir(".test", 0777) 46 | if err != nil { 47 | panic(err) 48 | } 49 | defer os.RemoveAll(".test") 50 | 51 | for _, test := range tests { 52 | t.Run(test.Name, func(t *testing.T) { 53 | opts := []grabtest.HandlerOption{} 54 | if test.AttachmentFilename != "" { 55 | opts = append(opts, grabtest.AttachmentFilename(test.AttachmentFilename)) 56 | } 57 | 
grabtest.WithTestServer(t, func(url string) { 58 | req := mustNewRequest(test.Filename, url+test.URL) 59 | resp := DefaultClient.Do(req) 60 | defer os.Remove(resp.Filename) 61 | if err := resp.Err(); err != nil { 62 | if test.Expect != "" || err != ErrNoFilename { 63 | panic(err) 64 | } 65 | } else { 66 | if test.Expect == "" { 67 | t.Errorf("expected: %v, got: %v", ErrNoFilename, err) 68 | } 69 | } 70 | if resp.Filename != test.Expect { 71 | t.Errorf("Filename mismatch. Expected '%s', got '%s'.", test.Expect, resp.Filename) 72 | } 73 | testComplete(t, resp) 74 | }, opts...) 75 | }) 76 | } 77 | } 78 | 79 | // TestChecksums checks that checksum validation behaves as expected for valid 80 | // and corrupted downloads. 81 | func TestChecksums(t *testing.T) { 82 | tests := []struct { 83 | size int 84 | hash hash.Hash 85 | sum string 86 | match bool 87 | }{ 88 | {128, md5.New(), "37eff01866ba3f538421b30b7cbefcac", true}, 89 | {128, md5.New(), "37eff01866ba3f538421b30b7cbefcad", false}, 90 | {1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855b", true}, 91 | {1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855c", false}, 92 | {1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef372", true}, 93 | {1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef373", false}, 94 | {128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d535", true}, 95 | {128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d536", false}, 96 | {1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b77", true}, 97 | {1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b78", false}, 98 | {1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923be", true}, 99 | {1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923bf", false}, 100 | {128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be5", true}, 101 | {128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be4", false}, 102 | {1024, sha256.New(), 
"785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c9", true}, 103 | {1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c8", false}, 104 | {1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83", true}, 105 | {1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c82", false}, 106 | {128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f7", true}, 107 | {128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f8", false}, 108 | {1024, sha512.New(), "37f652be867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566c", true}, 109 | {1024, sha512.New(), "37f652bf867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566d", false}, 110 | {1048576, sha512.New(), "ac1d097b4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", true}, 111 | {1048576, sha512.New(), "ac1d097c4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", false}, 112 | } 113 | 114 | for _, test := range tests { 115 | var expect error 116 | comparison := "Match" 117 | if !test.match { 118 | comparison = "Mismatch" 119 | expect = ErrBadChecksum 120 | } 121 | 122 | t.Run(fmt.Sprintf("With%s%s", comparison, test.sum[:8]), func(t *testing.T) { 123 | filename := fmt.Sprintf(".testChecksum-%s-%s", comparison, test.sum[:8]) 124 | defer os.Remove(filename) 125 | 126 | grabtest.WithTestServer(t, func(url string) { 127 | req := mustNewRequest(filename, url) 128 | req.SetChecksum(test.hash, grabtest.MustHexDecodeString(test.sum), true) 129 | 130 | resp := 
DefaultClient.Do(req) 131 | err := resp.Err() 132 | if err != expect { 133 | t.Errorf("expected error: %v, got: %v", expect, err) 134 | } 135 | 136 | // ensure mismatch file was deleted 137 | if !test.match { 138 | if _, err := os.Stat(filename); err == nil { 139 | t.Errorf("checksum failure not cleaned up: %s", filename) 140 | } else if !os.IsNotExist(err) { 141 | panic(err) 142 | } 143 | } 144 | 145 | testComplete(t, resp) 146 | }, grabtest.ContentLength(test.size)) 147 | }) 148 | } 149 | } 150 | 151 | // TestContentLength ensures that ErrBadLength is returned if a server response 152 | // does not match the requested length. 153 | func TestContentLength(t *testing.T) { 154 | size := int64(32768) 155 | testCases := []struct { 156 | Name string 157 | NoHead bool 158 | Size int64 159 | Expect int64 160 | Match bool 161 | }{ 162 | {"Good size in HEAD request", false, size, size, true}, 163 | {"Good size in GET request", true, size, size, true}, 164 | {"Bad size in HEAD request", false, size - 1, size, false}, 165 | {"Bad size in GET request", true, size - 1, size, false}, 166 | } 167 | 168 | for _, test := range testCases { 169 | t.Run(test.Name, func(t *testing.T) { 170 | opts := []grabtest.HandlerOption{ 171 | grabtest.ContentLength(int(test.Size)), 172 | } 173 | if test.NoHead { 174 | opts = append(opts, grabtest.MethodWhitelist("GET")) 175 | } 176 | 177 | grabtest.WithTestServer(t, func(url string) { 178 | req := mustNewRequest(".testSize-mismatch-head", url) 179 | req.Size = size 180 | resp := DefaultClient.Do(req) 181 | defer os.Remove(resp.Filename) 182 | err := resp.Err() 183 | if test.Match { 184 | if err == ErrBadLength { 185 | t.Errorf("error: %v", err) 186 | } else if err != nil { 187 | panic(err) 188 | } else if resp.Size() != size { 189 | t.Errorf("expected %v bytes, got %v bytes", size, resp.Size()) 190 | } 191 | } else { 192 | if err == nil { 193 | t.Errorf("expected: %v, got %v", ErrBadLength, err) 194 | } else if err != ErrBadLength { 195 | 
panic(err) 196 | } 197 | } 198 | testComplete(t, resp) 199 | }, opts...) 200 | }) 201 | } 202 | } 203 | 204 | // TestAutoResume tests segmented downloading of a large file. 205 | func TestAutoResume(t *testing.T) { 206 | segs := 8 207 | size := 1048576 208 | sum := grabtest.DefaultHandlerSHA256ChecksumBytes //grab/v3test.MustHexDecodeString("fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83") 209 | filename := ".testAutoResume" 210 | 211 | defer os.Remove(filename) 212 | 213 | for i := 0; i < segs; i++ { 214 | segsize := (i + 1) * (size / segs) 215 | t.Run(fmt.Sprintf("With%vBytes", segsize), func(t *testing.T) { 216 | grabtest.WithTestServer(t, func(url string) { 217 | req := mustNewRequest(filename, url) 218 | if i == segs-1 { 219 | req.SetChecksum(sha256.New(), sum, false) 220 | } 221 | resp := mustDo(req) 222 | if i > 0 && !resp.DidResume { 223 | t.Errorf("expected Response.DidResume to be true") 224 | } 225 | testComplete(t, resp) 226 | }, 227 | grabtest.ContentLength(segsize), 228 | ) 229 | }) 230 | } 231 | 232 | t.Run("WithFailure", func(t *testing.T) { 233 | grabtest.WithTestServer(t, func(url string) { 234 | // request smaller segment 235 | req := mustNewRequest(filename, url) 236 | resp := DefaultClient.Do(req) 237 | if err := resp.Err(); err != ErrBadLength { 238 | t.Errorf("expected ErrBadLength for smaller request, got: %v", err) 239 | } 240 | }, 241 | grabtest.ContentLength(size-128), 242 | ) 243 | }) 244 | 245 | t.Run("WithNoResume", func(t *testing.T) { 246 | grabtest.WithTestServer(t, func(url string) { 247 | req := mustNewRequest(filename, url) 248 | req.NoResume = true 249 | resp := mustDo(req) 250 | if resp.DidResume { 251 | t.Errorf("expected Response.DidResume to be false") 252 | } 253 | testComplete(t, resp) 254 | }, 255 | grabtest.ContentLength(size+128), 256 | ) 257 | }) 258 | 259 | t.Run("WithNoResumeAndTruncate", func(t *testing.T) { 260 | size := size - 128 261 | grabtest.WithTestServer(t, func(url string) { 262 | req := 
mustNewRequest(filename, url) 263 | req.NoResume = true 264 | resp := mustDo(req) 265 | if resp.DidResume { 266 | t.Errorf("expected Response.DidResume to be false") 267 | } 268 | if v := resp.BytesComplete(); v != int64(size) { 269 | t.Errorf("expected Response.BytesComplete: %d, got: %d", size, v) 270 | } 271 | testComplete(t, resp) 272 | }, 273 | grabtest.ContentLength(size), 274 | ) 275 | }) 276 | 277 | t.Run("WithNoContentLengthHeader", func(t *testing.T) { 278 | grabtest.WithTestServer(t, func(url string) { 279 | req := mustNewRequest(filename, url) 280 | req.SetChecksum(sha256.New(), sum, false) 281 | resp := mustDo(req) 282 | if !resp.DidResume { 283 | t.Errorf("expected Response.DidResume to be true") 284 | } 285 | if actual := resp.Size(); actual != int64(size) { 286 | t.Errorf("expected Response.Size: %d, got: %d", size, actual) 287 | } 288 | testComplete(t, resp) 289 | }, 290 | grabtest.ContentLength(size), 291 | grabtest.HeaderBlacklist("Content-Length"), 292 | ) 293 | }) 294 | 295 | t.Run("WithNoContentLengthHeaderAndChecksumFailure", func(t *testing.T) { 296 | // ref: https://github.com/cavaliergopher/grab/v3/pull/27 297 | size := size * 2 298 | grabtest.WithTestServer(t, func(url string) { 299 | req := mustNewRequest(filename, url) 300 | req.SetChecksum(sha256.New(), sum, false) 301 | resp := DefaultClient.Do(req) 302 | if err := resp.Err(); err != ErrBadChecksum { 303 | t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err) 304 | } 305 | if !resp.DidResume { 306 | t.Errorf("expected Response.DidResume to be true") 307 | } 308 | if actual := resp.BytesComplete(); actual != int64(size) { 309 | t.Errorf("expected Response.BytesComplete: %d, got: %d", size, actual) 310 | } 311 | if actual := resp.Size(); actual != int64(size) { 312 | t.Errorf("expected Response.Size: %d, got: %d", size, actual) 313 | } 314 | testComplete(t, resp) 315 | }, 316 | grabtest.ContentLength(size), 317 | grabtest.HeaderBlacklist("Content-Length"), 318 | ) 319 | }) 320 | 
// TODO: test when existing file is corrupted 321 | } 322 | 323 | func TestSkipExisting(t *testing.T) { 324 | filename := ".testSkipExisting" 325 | defer os.Remove(filename) 326 | 327 | // download a file 328 | grabtest.WithTestServer(t, func(url string) { 329 | resp := mustDo(mustNewRequest(filename, url)) 330 | testComplete(t, resp) 331 | }) 332 | 333 | // redownload 334 | grabtest.WithTestServer(t, func(url string) { 335 | resp := mustDo(mustNewRequest(filename, url)) 336 | testComplete(t, resp) 337 | 338 | // ensure download was resumed 339 | if !resp.DidResume { 340 | t.Fatalf("Expected download to skip existing file, but it did not") 341 | } 342 | 343 | // ensure all bytes were resumed 344 | if resp.Size() == 0 || resp.Size() != resp.bytesResumed { 345 | t.Fatalf("Expected to skip %d bytes in redownload; got %d", resp.Size(), resp.bytesResumed) 346 | } 347 | }) 348 | 349 | // ensure checksum is performed on pre-existing file 350 | grabtest.WithTestServer(t, func(url string) { 351 | req := mustNewRequest(filename, url) 352 | req.SetChecksum(sha256.New(), []byte{0x01, 0x02, 0x03, 0x04}, true) 353 | resp := DefaultClient.Do(req) 354 | if err := resp.Err(); err != ErrBadChecksum { 355 | t.Fatalf("Expected checksum error, got: %v", err) 356 | } 357 | }) 358 | } 359 | 360 | // TestBatch executes multiple requests simultaneously and validates the 361 | // responses. 
362 | func TestBatch(t *testing.T) { 363 | tests := 32 364 | size := 32768 365 | sum := grabtest.MustHexDecodeString("e11360251d1173650cdcd20f111d8f1ca2e412f572e8b36a4dc067121c1799b8") 366 | 367 | // test with 4 workers and with one per request 368 | grabtest.WithTestServer(t, func(url string) { 369 | for _, workerCount := range []int{4, 0} { 370 | // create requests 371 | reqs := make([]*Request, tests) 372 | for i := 0; i < len(reqs); i++ { 373 | filename := fmt.Sprintf(".testBatch.%d", i+1) 374 | reqs[i] = mustNewRequest(filename, url+fmt.Sprintf("/request_%d?", i+1)) 375 | reqs[i].Label = fmt.Sprintf("Test %d", i+1) 376 | reqs[i].SetChecksum(sha256.New(), sum, false) 377 | } 378 | 379 | // batch run 380 | responses := DefaultClient.DoBatch(workerCount, reqs...) 381 | 382 | // listen for responses 383 | Loop: 384 | for i := 0; i < len(reqs); { 385 | select { 386 | case resp := <-responses: 387 | if resp == nil { 388 | break Loop 389 | } 390 | testComplete(t, resp) 391 | if err := resp.Err(); err != nil { 392 | t.Errorf("%s: %v", resp.Filename, err) 393 | } 394 | 395 | // remove test file 396 | if resp.IsComplete() { 397 | os.Remove(resp.Filename) // ignore errors 398 | } 399 | i++ 400 | } 401 | } 402 | } 403 | }, 404 | grabtest.ContentLength(size), 405 | ) 406 | } 407 | 408 | // TestCancelContext tests that a batch of requests can be cancel using a 409 | // context.Context cancellation. Requests are cancelled in multiple states: 410 | // in-progress and unstarted. 411 | func TestCancelContext(t *testing.T) { 412 | fileSize := 134217728 413 | tests := 256 414 | client := NewClient() 415 | ctx, cancel := context.WithCancel(context.Background()) 416 | defer cancel() 417 | 418 | grabtest.WithTestServer(t, func(url string) { 419 | reqs := make([]*Request, tests) 420 | for i := 0; i < tests; i++ { 421 | req := mustNewRequest("", fmt.Sprintf("%s/.testCancelContext%d", url, i)) 422 | reqs[i] = req.WithContext(ctx) 423 | } 424 | 425 | respch := client.DoBatch(8, reqs...) 
426 | time.Sleep(time.Millisecond * 500) 427 | cancel() 428 | for resp := range respch { 429 | defer os.Remove(resp.Filename) 430 | 431 | // err should be context.Canceled or http.errRequestCanceled 432 | if resp.Err() == nil || !strings.Contains(resp.Err().Error(), "canceled") { 433 | t.Errorf("expected '%v', got '%v'", context.Canceled, resp.Err()) 434 | } 435 | if resp.BytesComplete() >= int64(fileSize) { 436 | t.Errorf("expected Response.BytesComplete: < %d, got: %d", fileSize, resp.BytesComplete()) 437 | } 438 | } 439 | }, 440 | grabtest.ContentLength(fileSize), 441 | ) 442 | } 443 | 444 | // TestCancelHangingResponse tests that a never ending request is terminated 445 | // when the response is cancelled. 446 | func TestCancelHangingResponse(t *testing.T) { 447 | fileSize := 10 448 | client := NewClient() 449 | 450 | grabtest.WithTestServer(t, func(url string) { 451 | req := mustNewRequest("", fmt.Sprintf("%s/.testCancelHangingResponse", url)) 452 | 453 | resp := client.Do(req) 454 | defer os.Remove(resp.Filename) 455 | 456 | // Wait for some bytes to be transferred 457 | for resp.BytesComplete() == 0 { 458 | time.Sleep(50 * time.Millisecond) 459 | } 460 | 461 | done := make(chan error) 462 | go func() { 463 | done <- resp.Cancel() 464 | }() 465 | 466 | select { 467 | case err := <-done: 468 | if err != context.Canceled { 469 | t.Errorf("Expected context.Canceled error, go: %v", err) 470 | } 471 | case <-time.After(time.Second): 472 | t.Fatal("response was not cancelled within 1s") 473 | } 474 | if resp.BytesComplete() == int64(fileSize) { 475 | t.Error("download was not supposed to be complete") 476 | } 477 | }, 478 | grabtest.RateLimiter(1), 479 | grabtest.ContentLength(fileSize), 480 | ) 481 | } 482 | 483 | // TestNestedDirectory tests that missing subdirectories are created. 
484 | func TestNestedDirectory(t *testing.T) { 485 | dir := "./.testNested/one/two/three" 486 | filename := ".testNestedFile" 487 | expect := dir + "/" + filename 488 | 489 | t.Run("Create", func(t *testing.T) { 490 | grabtest.WithTestServer(t, func(url string) { 491 | resp := mustDo(mustNewRequest(expect, url+"/"+filename)) 492 | defer os.RemoveAll("./.testNested/") 493 | if resp.Filename != expect { 494 | t.Errorf("expected nested Request.Filename to be %v, got %v", expect, resp.Filename) 495 | } 496 | }) 497 | }) 498 | 499 | t.Run("No create", func(t *testing.T) { 500 | grabtest.WithTestServer(t, func(url string) { 501 | req := mustNewRequest(expect, url+"/"+filename) 502 | req.NoCreateDirectories = true 503 | resp := DefaultClient.Do(req) 504 | err := resp.Err() 505 | if !os.IsNotExist(err) { 506 | t.Errorf("expected: %v, got: %v", os.ErrNotExist, err) 507 | } 508 | }) 509 | }) 510 | } 511 | 512 | // TestRemoteTime tests that the timestamp of the downloaded file can be set 513 | // according to the timestamp of the remote file. 
514 | func TestRemoteTime(t *testing.T) { 515 | filename := "./.testRemoteTime" 516 | defer os.Remove(filename) 517 | 518 | // random time between epoch and now 519 | expect := time.Unix(rand.Int63n(time.Now().Unix()), 0) 520 | grabtest.WithTestServer(t, func(url string) { 521 | resp := mustDo(mustNewRequest(filename, url)) 522 | fi, err := os.Stat(resp.Filename) 523 | if err != nil { 524 | panic(err) 525 | } 526 | actual := fi.ModTime() 527 | if !actual.Equal(expect) { 528 | t.Errorf("expected %v, got %v", expect, actual) 529 | } 530 | }, 531 | grabtest.LastModified(expect), 532 | ) 533 | } 534 | 535 | func TestResponseCode(t *testing.T) { 536 | filename := "./.testResponseCode" 537 | 538 | t.Run("With404", func(t *testing.T) { 539 | defer os.Remove(filename) 540 | grabtest.WithTestServer(t, func(url string) { 541 | req := mustNewRequest(filename, url) 542 | resp := DefaultClient.Do(req) 543 | expect := StatusCodeError(http.StatusNotFound) 544 | err := resp.Err() 545 | if err != expect { 546 | t.Errorf("expected %v, got '%v'", expect, err) 547 | } 548 | if !IsStatusCodeError(err) { 549 | t.Errorf("expected IsStatusCodeError to return true for %T: %v", err, err) 550 | } 551 | }, 552 | grabtest.StatusCodeStatic(http.StatusNotFound), 553 | ) 554 | }) 555 | 556 | t.Run("WithIgnoreNon2XX", func(t *testing.T) { 557 | defer os.Remove(filename) 558 | grabtest.WithTestServer(t, func(url string) { 559 | req := mustNewRequest(filename, url) 560 | req.IgnoreBadStatusCodes = true 561 | resp := DefaultClient.Do(req) 562 | if err := resp.Err(); err != nil { 563 | t.Errorf("expected nil, got '%v'", err) 564 | } 565 | }, 566 | grabtest.StatusCodeStatic(http.StatusNotFound), 567 | ) 568 | }) 569 | } 570 | 571 | func TestBeforeCopyHook(t *testing.T) { 572 | filename := "./.testBeforeCopy" 573 | t.Run("Noop", func(t *testing.T) { 574 | defer os.RemoveAll(filename) 575 | grabtest.WithTestServer(t, func(url string) { 576 | called := false 577 | req := mustNewRequest(filename, url) 578 
| req.BeforeCopy = func(resp *Response) error { 579 | called = true 580 | if resp.IsComplete() { 581 | t.Error("Response object passed to BeforeCopy hook has already been closed") 582 | } 583 | if resp.Progress() != 0 { 584 | t.Error("Download progress already > 0 when BeforeCopy hook was called") 585 | } 586 | if resp.Duration() == 0 { 587 | t.Error("Duration was zero when BeforeCopy was called") 588 | } 589 | if resp.BytesComplete() != 0 { 590 | t.Error("BytesComplete already > 0 when BeforeCopy hook was called") 591 | } 592 | return nil 593 | } 594 | resp := DefaultClient.Do(req) 595 | if err := resp.Err(); err != nil { 596 | t.Errorf("unexpected error using BeforeCopy hook: %v", err) 597 | } 598 | testComplete(t, resp) 599 | if !called { 600 | t.Error("BeforeCopy hook was never called") 601 | } 602 | }) 603 | }) 604 | 605 | t.Run("WithError", func(t *testing.T) { 606 | defer os.RemoveAll(filename) 607 | grabtest.WithTestServer(t, func(url string) { 608 | testError := errors.New("test") 609 | req := mustNewRequest(filename, url) 610 | req.BeforeCopy = func(resp *Response) error { 611 | return testError 612 | } 613 | resp := DefaultClient.Do(req) 614 | if err := resp.Err(); err != testError { 615 | t.Errorf("expected error '%v', got '%v'", testError, err) 616 | } 617 | if resp.BytesComplete() != 0 { 618 | t.Errorf("expected 0 bytes completed for canceled BeforeCopy hook, got %d", 619 | resp.BytesComplete()) 620 | } 621 | testComplete(t, resp) 622 | }) 623 | }) 624 | 625 | // Assert that an existing local file will not be truncated prior to the 626 | // BeforeCopy hook has a chance to cancel the request 627 | t.Run("NoTruncate", func(t *testing.T) { 628 | tfile, err := ioutil.TempFile("", "grab_client_test.*.file") 629 | if err != nil { 630 | t.Fatal(err) 631 | } 632 | defer os.Remove(tfile.Name()) 633 | 634 | const size = 128 635 | _, err = tfile.Write(bytes.Repeat([]byte("x"), size)) 636 | if err != nil { 637 | t.Fatal(err) 638 | } 639 | 640 | 
grabtest.WithTestServer(t, func(url string) { 641 | called := false 642 | req := mustNewRequest(tfile.Name(), url) 643 | req.NoResume = true 644 | req.BeforeCopy = func(resp *Response) error { 645 | called = true 646 | fi, err := tfile.Stat() 647 | if err != nil { 648 | t.Errorf("failed to stat temp file: %v", err) 649 | return nil 650 | } 651 | if fi.Size() != size { 652 | t.Errorf("expected existing file size of %d bytes "+ 653 | "prior to BeforeCopy hook, got %d", size, fi.Size()) 654 | } 655 | return nil 656 | } 657 | resp := DefaultClient.Do(req) 658 | if err := resp.Err(); err != nil { 659 | t.Errorf("unexpected error using BeforeCopy hook: %v", err) 660 | } 661 | testComplete(t, resp) 662 | if !called { 663 | t.Error("BeforeCopy hook was never called") 664 | } 665 | }) 666 | }) 667 | } 668 | 669 | func TestAfterCopyHook(t *testing.T) { 670 | filename := "./.testAfterCopy" 671 | t.Run("Noop", func(t *testing.T) { 672 | defer os.RemoveAll(filename) 673 | grabtest.WithTestServer(t, func(url string) { 674 | called := false 675 | req := mustNewRequest(filename, url) 676 | req.AfterCopy = func(resp *Response) error { 677 | called = true 678 | if resp.IsComplete() { 679 | t.Error("Response object passed to AfterCopy hook has already been closed") 680 | } 681 | if resp.Progress() <= 0 { 682 | t.Error("Download progress was 0 when AfterCopy hook was called") 683 | } 684 | if resp.Duration() == 0 { 685 | t.Error("Duration was zero when AfterCopy was called") 686 | } 687 | if resp.BytesComplete() <= 0 { 688 | t.Error("BytesComplete was 0 when AfterCopy hook was called") 689 | } 690 | return nil 691 | } 692 | resp := DefaultClient.Do(req) 693 | if err := resp.Err(); err != nil { 694 | t.Errorf("unexpected error using AfterCopy hook: %v", err) 695 | } 696 | testComplete(t, resp) 697 | if !called { 698 | t.Error("AfterCopy hook was never called") 699 | } 700 | }) 701 | }) 702 | 703 | t.Run("WithError", func(t *testing.T) { 704 | defer os.RemoveAll(filename) 705 | 
grabtest.WithTestServer(t, func(url string) { 706 | testError := errors.New("test") 707 | req := mustNewRequest(filename, url) 708 | req.AfterCopy = func(resp *Response) error { 709 | return testError 710 | } 711 | resp := DefaultClient.Do(req) 712 | if err := resp.Err(); err != testError { 713 | t.Errorf("expected error '%v', got '%v'", testError, err) 714 | } 715 | if resp.BytesComplete() <= 0 { 716 | t.Errorf("ByteCompleted was %d after AfterCopy hook was called", 717 | resp.BytesComplete()) 718 | } 719 | testComplete(t, resp) 720 | }) 721 | }) 722 | } 723 | 724 | func TestIssue37(t *testing.T) { 725 | // ref: https://github.com/cavaliergopher/grab/v3/issues/37 726 | filename := "./.testIssue37" 727 | largeSize := int64(2097152) 728 | smallSize := int64(1048576) 729 | defer os.RemoveAll(filename) 730 | 731 | // download large file 732 | grabtest.WithTestServer(t, func(url string) { 733 | resp := mustDo(mustNewRequest(filename, url)) 734 | if resp.Size() != largeSize { 735 | t.Errorf("expected response size: %d, got: %d", largeSize, resp.Size()) 736 | } 737 | }, grabtest.ContentLength(int(largeSize))) 738 | 739 | // download new, smaller version of same file 740 | grabtest.WithTestServer(t, func(url string) { 741 | req := mustNewRequest(filename, url) 742 | req.NoResume = true 743 | resp := mustDo(req) 744 | if resp.Size() != smallSize { 745 | t.Errorf("expected response size: %d, got: %d", smallSize, resp.Size()) 746 | } 747 | 748 | // local file should have truncated and not resumed 749 | if resp.DidResume { 750 | t.Errorf("expected download to truncate, resumed instead") 751 | } 752 | }, grabtest.ContentLength(int(smallSize))) 753 | 754 | fi, err := os.Stat(filename) 755 | if err != nil { 756 | t.Fatal(err) 757 | } 758 | if fi.Size() != int64(smallSize) { 759 | t.Errorf("expected file size %d, got %d", smallSize, fi.Size()) 760 | } 761 | } 762 | 763 | // TestHeadBadStatus validates that HEAD requests that return non-200 can be 764 | // ignored and succeed if 
the GET requests succeeeds. 765 | // 766 | // Fixes: https://github.com/cavaliergopher/grab/v3/issues/43 767 | func TestHeadBadStatus(t *testing.T) { 768 | expect := http.StatusOK 769 | filename := ".testIssue43" 770 | 771 | statusFunc := func(r *http.Request) int { 772 | if r.Method == "HEAD" { 773 | return http.StatusForbidden 774 | } 775 | return http.StatusOK 776 | } 777 | 778 | grabtest.WithTestServer(t, func(url string) { 779 | testURL := fmt.Sprintf("%s/%s", url, filename) 780 | resp := mustDo(mustNewRequest("", testURL)) 781 | if resp.HTTPResponse.StatusCode != expect { 782 | t.Errorf( 783 | "expected status code: %d, got:% d", 784 | expect, 785 | resp.HTTPResponse.StatusCode) 786 | } 787 | }, 788 | grabtest.StatusCode(statusFunc), 789 | ) 790 | } 791 | 792 | // TestMissingContentLength ensures that the Response.Size is correct for 793 | // transfers where the remote server does not send a Content-Length header. 794 | // 795 | // TestAutoResume also covers cases with checksum validation. 796 | // 797 | // Kudos to Setnička Jiří for identifying and raising 798 | // a solution to this issue. Ref: https://github.com/cavaliergopher/grab/v3/pull/27 799 | func TestMissingContentLength(t *testing.T) { 800 | // expectSize must be sufficiently large that DefaultClient.Do won't prefetch 801 | // the entire body and compute ContentLength before returning a Response. 
802 | expectSize := 1048576 803 | opts := []grabtest.HandlerOption{ 804 | grabtest.ContentLength(expectSize), 805 | grabtest.HeaderBlacklist("Content-Length"), 806 | grabtest.TimeToFirstByte(time.Millisecond * 100), // delay for initial read 807 | } 808 | grabtest.WithTestServer(t, func(url string) { 809 | req := mustNewRequest(".testMissingContentLength", url) 810 | req.SetChecksum( 811 | md5.New(), 812 | grabtest.DefaultHandlerMD5ChecksumBytes, 813 | false) 814 | resp := DefaultClient.Do(req) 815 | 816 | // ensure remote server is not sending content-length header 817 | if v := resp.HTTPResponse.Header.Get("Content-Length"); v != "" { 818 | panic(fmt.Sprintf("http header content length must be empty, got: %s", v)) 819 | } 820 | if v := resp.HTTPResponse.ContentLength; v != -1 { 821 | panic(fmt.Sprintf("http response content length must be -1, got: %d", v)) 822 | } 823 | 824 | // before completion, response size should be -1 825 | if resp.Size() != -1 { 826 | t.Errorf("expected response size: -1, got: %d", resp.Size()) 827 | } 828 | 829 | // block for completion 830 | if err := resp.Err(); err != nil { 831 | panic(err) 832 | } 833 | 834 | // on completion, response size should be actual transfer size 835 | if resp.Size() != int64(expectSize) { 836 | t.Errorf("expected response size: %d, got: %d", expectSize, resp.Size()) 837 | } 838 | }, opts...) 
839 | } 840 | 841 | func TestNoStore(t *testing.T) { 842 | filename := ".testSubdir/testNoStore" 843 | t.Run("DefaultCase", func(t *testing.T) { 844 | grabtest.WithTestServer(t, func(url string) { 845 | req := mustNewRequest(filename, url) 846 | req.NoStore = true 847 | req.SetChecksum(md5.New(), grabtest.DefaultHandlerMD5ChecksumBytes, true) 848 | resp := mustDo(req) 849 | 850 | // ensure Response.Bytes is correct and can be reread 851 | b, err := resp.Bytes() 852 | if err != nil { 853 | panic(err) 854 | } 855 | grabtest.AssertSHA256Sum( 856 | t, 857 | grabtest.DefaultHandlerSHA256ChecksumBytes, 858 | bytes.NewReader(b), 859 | ) 860 | 861 | // ensure Response.Open stream is correct and can be reread 862 | r, err := resp.Open() 863 | if err != nil { 864 | panic(err) 865 | } 866 | defer r.Close() 867 | grabtest.AssertSHA256Sum( 868 | t, 869 | grabtest.DefaultHandlerSHA256ChecksumBytes, 870 | r, 871 | ) 872 | 873 | // Response.Filename should still be set 874 | if resp.Filename != filename { 875 | t.Errorf("expected Response.Filename: %s, got: %s", filename, resp.Filename) 876 | } 877 | 878 | // ensure no files were written 879 | paths := []string{ 880 | filename, 881 | filepath.Base(filename), 882 | filepath.Dir(filename), 883 | resp.Filename, 884 | filepath.Base(resp.Filename), 885 | filepath.Dir(resp.Filename), 886 | } 887 | for _, path := range paths { 888 | _, err := os.Stat(path) 889 | if !os.IsNotExist(err) { 890 | t.Errorf( 891 | "expect error: %v, got: %v, for path: %s", 892 | os.ErrNotExist, 893 | err, 894 | path) 895 | } 896 | } 897 | }) 898 | }) 899 | 900 | t.Run("ChecksumValidation", func(t *testing.T) { 901 | grabtest.WithTestServer(t, func(url string) { 902 | req := mustNewRequest("", url) 903 | req.NoStore = true 904 | req.SetChecksum( 905 | md5.New(), 906 | grabtest.MustHexDecodeString("deadbeefcafebabe"), 907 | true) 908 | resp := DefaultClient.Do(req) 909 | if err := resp.Err(); err != ErrBadChecksum { 910 | t.Errorf("expected error: %v, got: %v", 
ErrBadChecksum, err) 911 | } 912 | }) 913 | }) 914 | } 915 | -------------------------------------------------------------------------------- /v3/cmd/grab/.gitignore: -------------------------------------------------------------------------------- 1 | grab 2 | -------------------------------------------------------------------------------- /v3/cmd/grab/Makefile: -------------------------------------------------------------------------------- 1 | SOURCES = main.go 2 | 3 | all : grab 4 | 5 | grab: $(SOURCES) 6 | go build -o grab $(SOURCES) 7 | 8 | clean: 9 | go clean -x 10 | rm -vf grab 11 | 12 | check: 13 | go test -v . 14 | 15 | install: 16 | go install -v . 17 | 18 | .PHONY: all clean check install 19 | -------------------------------------------------------------------------------- /v3/cmd/grab/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/cavaliergopher/grab/v3/pkg/grabui" 9 | ) 10 | 11 | func main() { 12 | // validate command args 13 | if len(os.Args) < 2 { 14 | fmt.Fprintf(os.Stderr, "usage: %s url...\n", os.Args[0]) 15 | os.Exit(1) 16 | } 17 | urls := os.Args[1:] 18 | 19 | // download files 20 | respch, err := grabui.GetBatch(context.Background(), 0, ".", urls...) 21 | if err != nil { 22 | fmt.Fprint(os.Stderr, err) 23 | os.Exit(1) 24 | } 25 | 26 | // return the number of failed downloads as exit code 27 | failed := 0 28 | for resp := range respch { 29 | if resp.Err() != nil { 30 | failed++ 31 | } 32 | } 33 | os.Exit(failed) 34 | } 35 | -------------------------------------------------------------------------------- /v3/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package grab provides a HTTP download manager implementation. 3 | 4 | Get is the most simple way to download a file: 5 | 6 | resp, err := grab.Get("/tmp", "http://example.com/example.zip") 7 | // ... 
8 | 9 | Get will download the given URL and save it to the given destination directory. 10 | The destination filename will be determined automatically by grab using 11 | Content-Disposition headers returned by the remote server, or by inspecting the 12 | requested URL path. 13 | 14 | An empty destination string or "." means the transfer will be stored in the 15 | current working directory. 16 | 17 | If a destination file already exists, grab will assume it is a complete or 18 | partially complete download of the requested file. If the remote server supports 19 | resuming interrupted downloads, grab will resume downloading from the end of the 20 | partial file. If the server does not support resumed downloads, the file will be 21 | retransferred in its entirety. If the file is already complete, grab will return 22 | successfully. 23 | 24 | For control over the HTTP client, destination path, auto-resume, checksum 25 | validation and other settings, create a Client: 26 | 27 | client := grab.NewClient() 28 | client.HTTPClient.Transport.DisableCompression = true 29 | 30 | req, err := grab.NewRequest("/tmp", "http://example.com/example.zip") 31 | // ... 32 | req.NoResume = true 33 | req.HTTPRequest.Header.Set("Authorization", "Basic YWxhZGRpbjpvcGVuc2VzYW1l") 34 | 35 | resp := client.Do(req) 36 | // ... 37 | 38 | You can monitor the progress of downloads while they are transferring: 39 | 40 | client := grab.NewClient() 41 | req, err := grab.NewRequest("", "http://example.com/example.zip") 42 | // ... 43 | resp := client.Do(req) 44 | 45 | t := time.NewTicker(time.Second) 46 | defer t.Stop() 47 | 48 | for { 49 | select { 50 | case <-t.C: 51 | fmt.Printf("%.02f%% complete\n", resp.Progress()) 52 | 53 | case <-resp.Done: 54 | if err := resp.Err(); err != nil { 55 | // ... 56 | } 57 | 58 | // ... 
// StatusCodeError indicates that the server response had a status code that
// was not in the 200-299 range (after following any redirects).
type StatusCodeError int

// Error implements the error interface. The message contains the numeric
// status code followed by the standard text for that code as reported by
// http.StatusText.
func (err StatusCodeError) Error() string {
	return fmt.Sprintf("server returned %d %s", err, http.StatusText(int(err)))
}

// IsStatusCodeError returns true if the given error is a StatusCodeError or
// wraps one anywhere in its chain (for example via fmt.Errorf with %w).
// It returns false for a nil error.
func IsStatusCodeError(err error) bool {
	// errors.As walks the Unwrap chain, so callers that add context to the
	// error are still recognised; a bare type assertion would miss them.
	var sce StatusCodeError
	return errors.As(err, &sce)
}
27 | func ExampleClient_DoChannel() { 28 | // create a request and a buffered response channel 29 | reqch := make(chan *Request) 30 | respch := make(chan *Response, 10) 31 | 32 | // start 4 workers 33 | client := NewClient() 34 | wg := sync.WaitGroup{} 35 | for i := 0; i < 4; i++ { 36 | wg.Add(1) 37 | go func() { 38 | client.DoChannel(reqch, respch) 39 | wg.Done() 40 | }() 41 | } 42 | 43 | go func() { 44 | // send requests 45 | for i := 0; i < 10; i++ { 46 | url := fmt.Sprintf("http://example.com/example%d.zip", i+1) 47 | req, err := NewRequest("/tmp", url) 48 | if err != nil { 49 | panic(err) 50 | } 51 | reqch <- req 52 | } 53 | close(reqch) 54 | 55 | // wait for workers to finish 56 | wg.Wait() 57 | close(respch) 58 | }() 59 | 60 | // check each response 61 | for resp := range respch { 62 | // block until complete 63 | if err := resp.Err(); err != nil { 64 | panic(err) 65 | } 66 | 67 | fmt.Printf("Downloaded %s to %s\n", resp.Request.URL(), resp.Filename) 68 | } 69 | } 70 | 71 | func ExampleClient_DoBatch() { 72 | // create multiple download requests 73 | reqs := make([]*Request, 0) 74 | for i := 0; i < 10; i++ { 75 | url := fmt.Sprintf("http://example.com/example%d.zip", i+1) 76 | req, err := NewRequest("/tmp", url) 77 | if err != nil { 78 | panic(err) 79 | } 80 | reqs = append(reqs, req) 81 | } 82 | 83 | // start downloads with 4 workers 84 | client := NewClient() 85 | respch := client.DoBatch(4, reqs...) 
86 | 87 | // check each response 88 | for resp := range respch { 89 | if err := resp.Err(); err != nil { 90 | panic(err) 91 | } 92 | 93 | fmt.Printf("Downloaded %s to %s\n", resp.Request.URL(), resp.Filename) 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /v3/example_request_test.go: -------------------------------------------------------------------------------- 1 | package grab 2 | 3 | import ( 4 | "context" 5 | "crypto/sha256" 6 | "encoding/hex" 7 | "fmt" 8 | "time" 9 | ) 10 | 11 | func ExampleRequest_WithContext() { 12 | // create context with a 100ms timeout 13 | ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 14 | defer cancel() 15 | 16 | // create download request with context 17 | req, err := NewRequest("", "http://example.com/example.zip") 18 | if err != nil { 19 | panic(err) 20 | } 21 | req = req.WithContext(ctx) 22 | 23 | // send download request 24 | resp := DefaultClient.Do(req) 25 | if err := resp.Err(); err != nil { 26 | fmt.Println("error: request cancelled") 27 | } 28 | 29 | // Output: 30 | // error: request cancelled 31 | } 32 | 33 | func ExampleRequest_SetChecksum() { 34 | // create download request 35 | req, err := NewRequest("", "http://example.com/example.zip") 36 | if err != nil { 37 | panic(err) 38 | } 39 | 40 | // set request checksum 41 | sum, err := hex.DecodeString("33daf4c03f86120fdfdc66bddf6bfff4661c7ca11c5da473e537f4d69b470e57") 42 | if err != nil { 43 | panic(err) 44 | } 45 | req.SetChecksum(sha256.New(), sum, true) 46 | 47 | // download and validate file 48 | resp := DefaultClient.Do(req) 49 | if err := resp.Err(); err != nil { 50 | panic(err) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /v3/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/cavaliergopher/grab/v3 2 | 3 | go 1.14 4 | 
-------------------------------------------------------------------------------- /v3/grab.go: -------------------------------------------------------------------------------- 1 | package grab 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | ) 7 | 8 | // Get sends a HTTP request and downloads the content of the requested URL to 9 | // the given destination file path. The caller is blocked until the download is 10 | // completed, successfully or otherwise. 11 | // 12 | // An error is returned if caused by client policy (such as CheckRedirect), or 13 | // if there was an HTTP protocol or IO error. 14 | // 15 | // For non-blocking calls or control over HTTP client headers, redirect policy, 16 | // and other settings, create a Client instead. 17 | func Get(dst, urlStr string) (*Response, error) { 18 | req, err := NewRequest(dst, urlStr) 19 | if err != nil { 20 | return nil, err 21 | } 22 | 23 | resp := DefaultClient.Do(req) 24 | return resp, resp.Err() 25 | } 26 | 27 | // GetBatch sends multiple HTTP requests and downloads the content of the 28 | // requested URLs to the given destination directory using the given number of 29 | // concurrent worker goroutines. 30 | // 31 | // The Response for each requested URL is sent through the returned Response 32 | // channel, as soon as a worker receives a response from the remote server. The 33 | // Response can then be used to track the progress of the download while it is 34 | // in progress. 35 | // 36 | // The returned Response channel will be closed by Grab, only once all downloads 37 | // have completed or failed. 38 | // 39 | // If an error occurs during any download, it will be available via call to the 40 | // associated Response.Err. 41 | // 42 | // For control over HTTP client headers, redirect policy, and other settings, 43 | // create a Client instead. 
44 | func GetBatch(workers int, dst string, urlStrs ...string) (<-chan *Response, error) { 45 | fi, err := os.Stat(dst) 46 | if err != nil { 47 | return nil, err 48 | } 49 | if !fi.IsDir() { 50 | return nil, fmt.Errorf("destination is not a directory") 51 | } 52 | 53 | reqs := make([]*Request, len(urlStrs)) 54 | for i := 0; i < len(urlStrs); i++ { 55 | req, err := NewRequest(dst, urlStrs[i]) 56 | if err != nil { 57 | return nil, err 58 | } 59 | reqs[i] = req 60 | } 61 | 62 | ch := DefaultClient.DoBatch(workers, reqs...) 63 | return ch, nil 64 | } 65 | -------------------------------------------------------------------------------- /v3/grab_test.go: -------------------------------------------------------------------------------- 1 | package grab 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "log" 7 | "os" 8 | "testing" 9 | 10 | "github.com/cavaliergopher/grab/v3/pkg/grabtest" 11 | ) 12 | 13 | func TestMain(m *testing.M) { 14 | os.Exit(func() int { 15 | // chdir to temp so test files downloaded to pwd are isolated and cleaned up 16 | cwd, err := os.Getwd() 17 | if err != nil { 18 | panic(err) 19 | } 20 | tmpDir, err := ioutil.TempDir("", "grab-") 21 | if err != nil { 22 | panic(err) 23 | } 24 | if err := os.Chdir(tmpDir); err != nil { 25 | panic(err) 26 | } 27 | defer func() { 28 | os.Chdir(cwd) 29 | if err := os.RemoveAll(tmpDir); err != nil { 30 | panic(err) 31 | } 32 | }() 33 | return m.Run() 34 | }()) 35 | } 36 | 37 | // TestGet tests grab.Get 38 | func TestGet(t *testing.T) { 39 | filename := ".testGet" 40 | defer os.Remove(filename) 41 | grabtest.WithTestServer(t, func(url string) { 42 | resp, err := Get(filename, url) 43 | if err != nil { 44 | t.Fatalf("error in Get(): %v", err) 45 | } 46 | testComplete(t, resp) 47 | }) 48 | } 49 | 50 | func ExampleGet() { 51 | // download a file to /tmp 52 | resp, err := Get("/tmp", "http://example.com/example.zip") 53 | if err != nil { 54 | log.Fatal(err) 55 | } 56 | 57 | fmt.Println("Download saved to", resp.Filename) 58 | 
} 59 | 60 | func mustNewRequest(dst, urlStr string) *Request { 61 | req, err := NewRequest(dst, urlStr) 62 | if err != nil { 63 | panic(err) 64 | } 65 | return req 66 | } 67 | 68 | func mustDo(req *Request) *Response { 69 | resp := DefaultClient.Do(req) 70 | if err := resp.Err(); err != nil { 71 | panic(err) 72 | } 73 | return resp 74 | } 75 | -------------------------------------------------------------------------------- /v3/pkg/bps/bps.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package bps provides gauges for calculating the Bytes Per Second transfer rate 3 | of data streams. 4 | */ 5 | package bps 6 | 7 | import ( 8 | "context" 9 | "time" 10 | ) 11 | 12 | // Gauge is the common interface for all BPS gauges in this package. Given a 13 | // set of samples over time, each gauge type can be used to measure the Bytes 14 | // Per Second transfer rate of a data stream. 15 | // 16 | // All samples must monotonically increase in timestamp and value. Each sample 17 | // should represent the total number of bytes sent in a stream, rather than 18 | // accounting for the number sent since the last sample. 19 | // 20 | // To ensure a gauge can report progress as quickly as possible, take an initial 21 | // sample when your stream first starts. 22 | // 23 | // All gauge implementations are safe for concurrent use. 24 | type Gauge interface { 25 | // Sample adds a new sample of the progress of the monitored stream. 26 | Sample(t time.Time, n int64) 27 | 28 | // BPS returns the calculated Bytes Per Second rate of the monitored stream. 29 | BPS() float64 30 | } 31 | 32 | // SampleFunc is used by Watch to take periodic samples of a monitored stream. 33 | type SampleFunc func() (n int64) 34 | 35 | // Watch will periodically call the given SampleFunc to sample the progress of 36 | // a monitored stream and update the given gauge. 
SampleFunc should return the 37 | // total number of bytes transferred by the stream since it started. 38 | // 39 | // Watch is a blocking call and should typically be called in a new goroutine. 40 | // To prevent the goroutine from leaking, make sure to cancel the given context 41 | // once the stream is completed or canceled. 42 | func Watch(ctx context.Context, g Gauge, f SampleFunc, interval time.Duration) { 43 | g.Sample(time.Now(), f()) 44 | t := time.NewTicker(interval) 45 | defer t.Stop() 46 | for { 47 | select { 48 | case <-ctx.Done(): 49 | return 50 | case now := <-t.C: 51 | g.Sample(now, f()) 52 | } 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /v3/pkg/bps/sma.go: -------------------------------------------------------------------------------- 1 | package bps 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | // NewSMA returns a gauge that uses a Simple Moving Average with the given 9 | // number of samples to measure the bytes per second of a byte stream. 10 | // 11 | // BPS is computed using the timestamp of the most recent and oldest sample in 12 | // the sample buffer. When a new sample is added, the oldest sample is dropped 13 | // if the sample count exceeds maxSamples. 14 | // 15 | // The gauge does not account for any latency in arrival time of new samples or 16 | // the desired window size. Any variance in the arrival of samples will result 17 | // in a BPS measurement that is correct for the submitted samples, but over a 18 | // varying time window. 19 | // 20 | // maxSamples should be equal to 1 + (window size / sampling interval) where 21 | // window size is the number of seconds over which the moving average is 22 | // smoothed and sampling interval is the number of seconds between each sample. 23 | // 24 | // For example, if you want a five second window, sampling once per second, 25 | // maxSamples should be 1 + 5/1 = 6. 
// sma is a Simple Moving Average gauge. Samples are kept in two parallel ring
// buffers of maxSamples entries: one for the sampled values and one for the
// timestamps at which they were taken.
type sma struct {
	mu          sync.Mutex  // held by Sample and BPS for every field access
	index       uint64      // ring buffer slot the next sample is written to
	maxSamples  uint64      // capacity of the ring buffers
	sampleCount uint64      // number of stored samples, capped at maxSamples
	samples     []int64     // sampled values
	timestamps  []time.Time // timestamp of each sampled value
}

// Sample stores a new sample taken at time t with value n, overwriting the
// oldest sample once the ring buffer is full.
func (c *sma) Sample(t time.Time, n int64) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.timestamps[c.index] = t
	c.samples[c.index] = n
	c.index = (c.index + 1) % c.maxSamples

	// prevent integer overflow in sampleCount. Values greater or equal to
	// maxSamples have the same semantic meaning.
	c.sampleCount++
	if c.sampleCount > c.maxSamples {
		c.sampleCount = c.maxSamples
	}
}

// BPS returns the rate computed from the oldest and newest samples currently
// in the buffer: the difference in value divided by the difference in time,
// in seconds. It returns 0 until at least two samples have been taken, and
// also when the oldest and newest samples do not span a positive amount of
// time, since no meaningful rate can be derived in that case.
func (c *sma) BPS() float64 {
	c.mu.Lock()
	defer c.mu.Unlock()

	// we need two samples to start
	if c.sampleCount < 2 {
		return 0
	}

	// First sample is always the oldest until ring buffer first overflows
	oldest := c.index
	if c.sampleCount < c.maxSamples {
		oldest = 0
	}

	newest := (c.index + c.maxSamples - 1) % c.maxSamples
	seconds := c.timestamps[newest].Sub(c.timestamps[oldest]).Seconds()
	if seconds <= 0 {
		// avoid dividing by zero (which would yield ±Inf or NaN) when samples
		// share a timestamp or arrive out of order
		return 0
	}
	bytes := float64(c.samples[newest] - c.samples[oldest])
	return bytes / seconds
}
// AssertHTTPResponseStatusCode reports a test error on t if the status code of
// resp differs from expect. It returns true if the status code matched.
func AssertHTTPResponseStatusCode(t *testing.T, resp *http.Response, expect int) (ok bool) {
	if got := resp.StatusCode; got != expect {
		t.Errorf("expected status code: %d, got: %d", expect, got)
		return false
	}
	return true
}
// AssertHTTPResponseBodyLength reads the body of resp to EOF, closes it, and
// reports a test error on t if the number of bytes read differs from n. It
// returns true if the body length matched and false otherwise.
//
// It panics if the body cannot be read or closed. Because the body is
// consumed and closed, resp.Body cannot be read again afterwards.
func AssertHTTPResponseBodyLength(t *testing.T, resp *http.Response, n int64) (ok bool) {
	defer func() {
		if err := resp.Body.Close(); err != nil {
			panic(err)
		}
	}()
	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	if int64(len(b)) != n {
		t.Errorf("expected body length: %d, got: %d", n, len(b))
		return false
	}
	// report success explicitly; the named result defaults to false
	return true
}
| ) 102 | } 103 | return 104 | } 105 | -------------------------------------------------------------------------------- /v3/pkg/grabtest/handler.go: -------------------------------------------------------------------------------- 1 | package grabtest 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | var ( 13 | DefaultHandlerContentLength = 1 << 20 14 | DefaultHandlerMD5Checksum = "c35cc7d8d91728a0cb052831bc4ef372" 15 | DefaultHandlerMD5ChecksumBytes = MustHexDecodeString(DefaultHandlerMD5Checksum) 16 | DefaultHandlerSHA256Checksum = "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83" 17 | DefaultHandlerSHA256ChecksumBytes = MustHexDecodeString(DefaultHandlerSHA256Checksum) 18 | ) 19 | 20 | type StatusCodeFunc func(req *http.Request) int 21 | 22 | type handler struct { 23 | statusCodeFunc StatusCodeFunc 24 | methodWhitelist []string 25 | headerBlacklist []string 26 | contentLength int 27 | acceptRanges bool 28 | attachmentFilename string 29 | lastModified time.Time 30 | ttfb time.Duration 31 | rateLimiter *time.Ticker 32 | } 33 | 34 | func NewHandler(options ...HandlerOption) (http.Handler, error) { 35 | h := &handler{ 36 | statusCodeFunc: func(req *http.Request) int { return http.StatusOK }, 37 | methodWhitelist: []string{"GET", "HEAD"}, 38 | contentLength: DefaultHandlerContentLength, 39 | acceptRanges: true, 40 | } 41 | for _, option := range options { 42 | if err := option(h); err != nil { 43 | return nil, err 44 | } 45 | } 46 | return h, nil 47 | } 48 | 49 | func WithTestServer(t *testing.T, f func(url string), options ...HandlerOption) { 50 | h, err := NewHandler(options...) 
51 | if err != nil { 52 | t.Fatalf("unable to create test server handler: %v", err) 53 | return 54 | } 55 | s := httptest.NewServer(h) 56 | defer func() { 57 | h.(*handler).close() 58 | s.Close() 59 | }() 60 | f(s.URL) 61 | } 62 | 63 | func (h *handler) close() { 64 | if h.rateLimiter != nil { 65 | h.rateLimiter.Stop() 66 | } 67 | } 68 | 69 | func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 70 | // delay response 71 | if h.ttfb > 0 { 72 | time.Sleep(h.ttfb) 73 | } 74 | 75 | // validate request method 76 | allowed := false 77 | for _, m := range h.methodWhitelist { 78 | if r.Method == m { 79 | allowed = true 80 | break 81 | } 82 | } 83 | if !allowed { 84 | httpError(w, http.StatusMethodNotAllowed) 85 | return 86 | } 87 | 88 | // set server options 89 | if h.acceptRanges { 90 | w.Header().Set("Accept-Ranges", "bytes") 91 | } 92 | 93 | // set attachment filename 94 | if h.attachmentFilename != "" { 95 | w.Header().Set( 96 | "Content-Disposition", 97 | fmt.Sprintf("attachment;filename=\"%s\"", h.attachmentFilename), 98 | ) 99 | } 100 | 101 | // set last modified timestamp 102 | lastMod := time.Now() 103 | if !h.lastModified.IsZero() { 104 | lastMod = h.lastModified 105 | } 106 | w.Header().Set("Last-Modified", lastMod.Format(http.TimeFormat)) 107 | 108 | // set content-length 109 | offset := 0 110 | if h.acceptRanges { 111 | if reqRange := r.Header.Get("Range"); reqRange != "" { 112 | if _, err := fmt.Sscanf(reqRange, "bytes=%d-", &offset); err != nil { 113 | httpError(w, http.StatusBadRequest) 114 | return 115 | } 116 | if offset >= h.contentLength { 117 | httpError(w, http.StatusRequestedRangeNotSatisfiable) 118 | return 119 | } 120 | } 121 | } 122 | w.Header().Set("Content-Length", fmt.Sprintf("%d", h.contentLength-offset)) 123 | 124 | // apply header blacklist 125 | for _, key := range h.headerBlacklist { 126 | w.Header().Del(key) 127 | } 128 | 129 | // send header and status code 130 | w.WriteHeader(h.statusCodeFunc(r)) 131 | 132 | // send 
body 133 | if r.Method == "GET" { 134 | // use buffered io to reduce overhead on the reader 135 | bw := bufio.NewWriterSize(w, 4096) 136 | for i := offset; !isRequestClosed(r) && i < h.contentLength; i++ { 137 | bw.Write([]byte{byte(i)}) 138 | if h.rateLimiter != nil { 139 | bw.Flush() 140 | w.(http.Flusher).Flush() // force the server to send the data to the client 141 | select { 142 | case <-h.rateLimiter.C: 143 | case <-r.Context().Done(): 144 | } 145 | } 146 | } 147 | if !isRequestClosed(r) { 148 | bw.Flush() 149 | } 150 | } 151 | } 152 | 153 | // isRequestClosed returns true if the client request has been canceled. 154 | func isRequestClosed(r *http.Request) bool { 155 | return r.Context().Err() != nil 156 | } 157 | 158 | func httpError(w http.ResponseWriter, code int) { 159 | http.Error(w, http.StatusText(code), code) 160 | } 161 | -------------------------------------------------------------------------------- /v3/pkg/grabtest/handler_option.go: -------------------------------------------------------------------------------- 1 | package grabtest 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "time" 7 | ) 8 | 9 | type HandlerOption func(*handler) error 10 | 11 | func StatusCodeStatic(code int) HandlerOption { 12 | return func(h *handler) error { 13 | return StatusCode(func(req *http.Request) int { 14 | return code 15 | })(h) 16 | } 17 | } 18 | 19 | func StatusCode(f StatusCodeFunc) HandlerOption { 20 | return func(h *handler) error { 21 | if f == nil { 22 | return errors.New("status code function cannot be nil") 23 | } 24 | h.statusCodeFunc = f 25 | return nil 26 | } 27 | } 28 | 29 | func MethodWhitelist(methods ...string) HandlerOption { 30 | return func(h *handler) error { 31 | h.methodWhitelist = methods 32 | return nil 33 | } 34 | } 35 | 36 | func HeaderBlacklist(headers ...string) HandlerOption { 37 | return func(h *handler) error { 38 | h.headerBlacklist = headers 39 | return nil 40 | } 41 | } 42 | 43 | func ContentLength(n int) HandlerOption { 44 | 
return func(h *handler) error { 45 | if n < 0 { 46 | return errors.New("content length must be zero or greater") 47 | } 48 | h.contentLength = n 49 | return nil 50 | } 51 | } 52 | 53 | func AcceptRanges(enabled bool) HandlerOption { 54 | return func(h *handler) error { 55 | h.acceptRanges = enabled 56 | return nil 57 | } 58 | } 59 | 60 | func LastModified(t time.Time) HandlerOption { 61 | return func(h *handler) error { 62 | h.lastModified = t.UTC() 63 | return nil 64 | } 65 | } 66 | 67 | func TimeToFirstByte(d time.Duration) HandlerOption { 68 | return func(h *handler) error { 69 | if d < 1 { 70 | return errors.New("time to first byte must be greater than zero") 71 | } 72 | h.ttfb = d 73 | return nil 74 | } 75 | } 76 | 77 | func RateLimiter(bps int) HandlerOption { 78 | return func(h *handler) error { 79 | if bps < 1 { 80 | return errors.New("bytes per second must be greater than zero") 81 | } 82 | h.rateLimiter = time.NewTicker(time.Second / time.Duration(bps)) 83 | return nil 84 | } 85 | } 86 | 87 | func AttachmentFilename(filename string) HandlerOption { 88 | return func(h *handler) error { 89 | h.attachmentFilename = filename 90 | return nil 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /v3/pkg/grabtest/handler_test.go: -------------------------------------------------------------------------------- 1 | package grabtest 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "net/http" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestHandlerDefaults(t *testing.T) { 12 | WithTestServer(t, func(url string) { 13 | resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil)) 14 | AssertHTTPResponseStatusCode(t, resp, http.StatusOK) 15 | AssertHTTPResponseContentLength(t, resp, 1048576) 16 | AssertHTTPResponseHeader(t, resp, "Accept-Ranges", "bytes") 17 | }) 18 | } 19 | 20 | func TestHandlerMethodWhitelist(t *testing.T) { 21 | tests := []struct { 22 | Whitelist []string 23 | Method string 24 | ExpectStatusCode int 25 | }{ 26 
| {[]string{"GET", "HEAD"}, "GET", http.StatusOK}, 27 | {[]string{"GET", "HEAD"}, "HEAD", http.StatusOK}, 28 | {[]string{"GET"}, "HEAD", http.StatusMethodNotAllowed}, 29 | {[]string{"HEAD"}, "GET", http.StatusMethodNotAllowed}, 30 | } 31 | 32 | for _, test := range tests { 33 | WithTestServer(t, func(url string) { 34 | resp := MustHTTPDoWithClose(MustHTTPNewRequest(test.Method, url, nil)) 35 | AssertHTTPResponseStatusCode(t, resp, test.ExpectStatusCode) 36 | }, MethodWhitelist(test.Whitelist...)) 37 | } 38 | } 39 | 40 | func TestHandlerHeaderBlacklist(t *testing.T) { 41 | contentLength := 4096 42 | WithTestServer(t, func(url string) { 43 | resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil)) 44 | defer resp.Body.Close() 45 | if resp.ContentLength != -1 { 46 | t.Errorf("expected Response.ContentLength: -1, got: %d", resp.ContentLength) 47 | } 48 | AssertHTTPResponseHeader(t, resp, "Content-Length", "") 49 | AssertHTTPResponseBodyLength(t, resp, int64(contentLength)) 50 | }, 51 | ContentLength(contentLength), 52 | HeaderBlacklist("Content-Length"), 53 | ) 54 | } 55 | 56 | func TestHandlerStatusCodeFuncs(t *testing.T) { 57 | expect := 418 // I'm a teapot 58 | WithTestServer(t, func(url string) { 59 | resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil)) 60 | defer resp.Body.Close() 61 | AssertHTTPResponseStatusCode(t, resp, expect) 62 | }, 63 | StatusCode(func(req *http.Request) int { return expect }), 64 | ) 65 | } 66 | 67 | func TestHandlerContentLength(t *testing.T) { 68 | tests := []struct { 69 | Method string 70 | ContentLength int 71 | ExpectHeaderLen int64 72 | ExpectBodyLen int 73 | }{ 74 | {"GET", 321, 321, 321}, 75 | {"HEAD", 321, 321, 0}, 76 | {"GET", 0, 0, 0}, 77 | {"HEAD", 0, 0, 0}, 78 | } 79 | 80 | for _, test := range tests { 81 | WithTestServer(t, func(url string) { 82 | resp := MustHTTPDo(MustHTTPNewRequest(test.Method, url, nil)) 83 | defer resp.Body.Close() 84 | 85 | AssertHTTPResponseHeader(t, resp, "Content-Length", "%d", 
test.ExpectHeaderLen) 86 | 87 | b, err := ioutil.ReadAll(resp.Body) 88 | if err != nil { 89 | panic(err) 90 | } 91 | if len(b) != test.ExpectBodyLen { 92 | t.Errorf( 93 | "expected body length: %v, got: %v, in: %v", 94 | test.ExpectBodyLen, 95 | len(b), 96 | test, 97 | ) 98 | } 99 | }, 100 | ContentLength(test.ContentLength), 101 | ) 102 | } 103 | } 104 | 105 | func TestHandlerAcceptRanges(t *testing.T) { 106 | header := "Accept-Ranges" 107 | n := 128 108 | t.Run("Enabled", func(t *testing.T) { 109 | WithTestServer(t, func(url string) { 110 | req := MustHTTPNewRequest("GET", url, nil) 111 | req.Header.Set("Range", fmt.Sprintf("bytes=%d-", n/2)) 112 | resp := MustHTTPDo(req) 113 | AssertHTTPResponseHeader(t, resp, header, "bytes") 114 | AssertHTTPResponseContentLength(t, resp, int64(n/2)) 115 | }, 116 | ContentLength(n), 117 | ) 118 | }) 119 | 120 | t.Run("Disabled", func(t *testing.T) { 121 | WithTestServer(t, func(url string) { 122 | req := MustHTTPNewRequest("GET", url, nil) 123 | req.Header.Set("Range", fmt.Sprintf("bytes=%d-", n/2)) 124 | resp := MustHTTPDo(req) 125 | AssertHTTPResponseHeader(t, resp, header, "") 126 | AssertHTTPResponseContentLength(t, resp, int64(n)) 127 | }, 128 | AcceptRanges(false), 129 | ContentLength(n), 130 | ) 131 | }) 132 | } 133 | 134 | func TestHandlerAttachmentFilename(t *testing.T) { 135 | filename := "foo.pdf" 136 | WithTestServer(t, func(url string) { 137 | resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil)) 138 | AssertHTTPResponseHeader(t, resp, "Content-Disposition", `attachment;filename="%s"`, filename) 139 | }, 140 | AttachmentFilename(filename), 141 | ) 142 | } 143 | 144 | func TestHandlerLastModified(t *testing.T) { 145 | WithTestServer(t, func(url string) { 146 | resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil)) 147 | AssertHTTPResponseHeader(t, resp, "Last-Modified", "Thu, 29 Nov 1973 21:33:09 GMT") 148 | }, 149 | LastModified(time.Unix(123456789, 0)), 150 | ) 151 | } 152 | 
// MustHexDecodeString returns the bytes represented by the hexadecimal string
// s. It panics if s cannot be decoded.
func MustHexDecodeString(s string) (b []byte) {
	decoded, err := hex.DecodeString(s)
	if err != nil {
		panic(err)
	}
	return decoded
}
48 | t := time.NewTicker(200 * time.Millisecond) 49 | defer t.Stop() 50 | 51 | Loop: 52 | for { 53 | select { 54 | case <-ctx.Done(): 55 | break Loop 56 | 57 | case resp := <-respch: 58 | if resp != nil { 59 | // a new response has been received and has started downloading 60 | c.responses = append(c.responses, resp) 61 | pump <- resp // send to caller 62 | } else { 63 | // channel is closed - all downloads are complete 64 | break Loop 65 | } 66 | 67 | case <-t.C: 68 | // update UI on clock tick 69 | c.refresh() 70 | } 71 | } 72 | 73 | c.refresh() 74 | close(pump) 75 | 76 | fmt.Printf( 77 | "Finished %d successful, %d failed, %d incomplete.\n", 78 | c.succeeded, 79 | c.failed, 80 | c.inProgress) 81 | }() 82 | return pump 83 | } 84 | 85 | // refresh prints the progress of all downloads to the terminal 86 | func (c *ConsoleClient) refresh() { 87 | // clear lines for incomplete downloads 88 | if c.inProgress > 0 { 89 | fmt.Printf("\033[%dA\033[K", c.inProgress) 90 | } 91 | 92 | // print newly completed downloads 93 | for i, resp := range c.responses { 94 | if resp != nil && resp.IsComplete() { 95 | if resp.Err() != nil { 96 | c.failed++ 97 | fmt.Fprintf(os.Stderr, "Error downloading %s: %v\n", 98 | resp.Request.URL(), 99 | resp.Err()) 100 | } else { 101 | c.succeeded++ 102 | fmt.Printf("Finished %s %s / %s (%d%%)\n", 103 | resp.Filename, 104 | byteString(resp.BytesComplete()), 105 | byteString(resp.Size()), 106 | int(100*resp.Progress())) 107 | } 108 | c.responses[i] = nil 109 | } 110 | } 111 | 112 | // print progress for incomplete downloads 113 | c.inProgress = 0 114 | for _, resp := range c.responses { 115 | if resp != nil { 116 | fmt.Printf("Downloading %s %s / %s (%d%%) - %s ETA: %s \033[K\n", 117 | resp.Filename, 118 | byteString(resp.BytesComplete()), 119 | byteString(resp.Size()), 120 | int(100*resp.Progress()), 121 | bpsString(resp.BytesPerSecond()), 122 | etaString(resp.ETA())) 123 | c.inProgress++ 124 | } 125 | } 126 | } 127 | 128 | func bpsString(n 
// etaString formats the time remaining until eta for display, truncated to
// whole seconds.
//
// A zero eta yields "unknown": Response.ETA returns the zero time when it has
// no transfer rate to base an estimate on, and reporting that as "<1s" would
// wrongly suggest a stalled download is about to finish. Any other eta that is
// less than one second away (including one in the past) yields "<1s".
func etaString(eta time.Time) string {
	if eta.IsZero() {
		return "unknown"
	}
	d := time.Until(eta)
	if d < time.Second {
		return "<1s"
	}
	// truncate to 1s resolution
	return d.Truncate(time.Second).String()
}
third-party rate 6 | // limiters that may be used to limit download transfer speeds. 7 | // 8 | // A recommended token bucket implementation can be found at 9 | // https://godoc.org/golang.org/x/time/rate#Limiter. 10 | type RateLimiter interface { 11 | WaitN(ctx context.Context, n int) (err error) 12 | } 13 | -------------------------------------------------------------------------------- /v3/rate_limiter_test.go: -------------------------------------------------------------------------------- 1 | package grab 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "os" 7 | "testing" 8 | "time" 9 | 10 | "github.com/cavaliergopher/grab/v3/pkg/grabtest" 11 | ) 12 | 13 | // testRateLimiter is a naive rate limiter that limits throughput to r tokens 14 | // per second. The total number of tokens issued is tracked as n. 15 | type testRateLimiter struct { 16 | r, n int 17 | } 18 | 19 | func NewLimiter(r int) RateLimiter { 20 | return &testRateLimiter{r: r} 21 | } 22 | 23 | func (c *testRateLimiter) WaitN(ctx context.Context, n int) (err error) { 24 | c.n += n 25 | time.Sleep( 26 | time.Duration(1.00 / float64(c.r) * float64(n) * float64(time.Second))) 27 | return 28 | } 29 | 30 | func TestRateLimiter(t *testing.T) { 31 | // download a 128 byte file, 8 bytes at a time, with a naive 512bps limiter 32 | // should take > 250ms 33 | filesize := 128 34 | filename := ".testRateLimiter" 35 | defer os.Remove(filename) 36 | 37 | grabtest.WithTestServer(t, func(url string) { 38 | // limit to 512bps 39 | lim := &testRateLimiter{r: 512} 40 | req := mustNewRequest(filename, url) 41 | 42 | // ensure multiple trips to the rate limiter by downloading 8 bytes at a time 43 | req.BufferSize = 8 44 | req.RateLimiter = lim 45 | 46 | resp := mustDo(req) 47 | testComplete(t, resp) 48 | if lim.n != filesize { 49 | t.Errorf("expected %d bytes to pass through limiter, got %d", filesize, lim.n) 50 | } 51 | if resp.Duration().Seconds() < 0.25 { 52 | // BUG: this test can pass if the transfer was slow for 
// A Request represents an HTTP file transfer request to be sent by a Client.
type Request struct {
	// Label is an arbitrary string which may be used to label a Request with a
	// user friendly name.
	Label string

	// Tag is an arbitrary interface which may be used to relate a Request to
	// other data.
	Tag interface{}

	// HTTPRequest specifies the http.Request to be sent to the remote server to
	// initiate a file transfer. It includes request configuration such as URL,
	// protocol version, HTTP method, request headers and authentication.
	HTTPRequest *http.Request

	// Filename specifies the path where the file transfer will be stored in
	// local storage. If Filename is empty or a directory, the true Filename will
	// be resolved using Content-Disposition headers or the request URL.
	//
	// An empty string means the transfer will be stored in the current working
	// directory.
	Filename string

	// SkipExisting specifies that ErrFileExists should be returned if the
	// destination path already exists. The existing file will not be checked for
	// completeness.
	SkipExisting bool

	// NoResume specifies that a partially completed download will be restarted
	// without attempting to resume any existing file. If the download is already
	// completed in full, it will not be restarted.
	NoResume bool

	// NoStore specifies that grab should not write to the local file system.
	// Instead, the download will be stored in memory and accessible only via
	// Response.Open or Response.Bytes.
	NoStore bool

	// NoCreateDirectories specifies that any missing directories in the given
	// Filename path should not be created automatically, if they do not already
	// exist.
	NoCreateDirectories bool

	// IgnoreBadStatusCodes specifies that grab should accept any status code in
	// the response from the remote server. Otherwise, grab expects the response
	// status code to be within the 2XX range (after following redirects).
	IgnoreBadStatusCodes bool

	// IgnoreRemoteTime specifies that grab should not attempt to set the
	// timestamp of the local file to match the remote file.
	IgnoreRemoteTime bool

	// Size specifies the expected size of the file transfer if known. If the
	// server response size does not match, the transfer is cancelled and
	// ErrBadLength is returned.
	Size int64

	// BufferSize specifies the size in bytes of the buffer that is used for
	// transferring the requested file. Larger buffers may result in faster
	// throughput but will use more memory and result in less frequent updates
	// to the transfer progress statistics. If a RateLimiter is configured,
	// BufferSize should be much lower than the rate limit. Default: 32KB.
	BufferSize int

	// RateLimiter allows the transfer rate of a download to be limited. The given
	// Request.BufferSize determines how frequently the RateLimiter will be
	// polled.
	RateLimiter RateLimiter

	// BeforeCopy is a user provided callback that is called immediately before
	// a request starts downloading. If BeforeCopy returns an error, the request
	// is cancelled and the same error is returned on the Response object.
	BeforeCopy Hook

	// AfterCopy is a user provided callback that is called immediately after a
	// request has finished downloading, before checksum validation and closure.
	// This hook is only called if the transfer was successful. If AfterCopy
	// returns an error, the request is canceled and the same error is returned on
	// the Response object.
	AfterCopy Hook

	// hash, checksum and deleteOnError configure checksum validation of the
	// downloaded file - set via SetChecksum.
	hash          hash.Hash
	checksum      []byte
	deleteOnError bool

	// ctx is the Context used for cancellation and timeout - set via
	// WithContext. It may be nil, in which case Context returns the background
	// context.
	ctx context.Context
}
116 | } 117 | req, err := http.NewRequest("GET", urlStr, nil) 118 | if err != nil { 119 | return nil, err 120 | } 121 | return &Request{ 122 | HTTPRequest: req, 123 | Filename: dst, 124 | }, nil 125 | } 126 | 127 | // Context returns the request's context. To change the context, use 128 | // WithContext. 129 | // 130 | // The returned context is always non-nil; it defaults to the background 131 | // context. 132 | // 133 | // The context controls cancelation. 134 | func (r *Request) Context() context.Context { 135 | if r.ctx != nil { 136 | return r.ctx 137 | } 138 | 139 | return context.Background() 140 | } 141 | 142 | // WithContext returns a shallow copy of r with its context changed 143 | // to ctx. The provided ctx must be non-nil. 144 | func (r *Request) WithContext(ctx context.Context) *Request { 145 | if ctx == nil { 146 | panic("nil context") 147 | } 148 | r2 := new(Request) 149 | *r2 = *r 150 | r2.ctx = ctx 151 | r2.HTTPRequest = r2.HTTPRequest.WithContext(ctx) 152 | return r2 153 | } 154 | 155 | // URL returns the URL to be downloaded. 156 | func (r *Request) URL() *url.URL { 157 | return r.HTTPRequest.URL 158 | } 159 | 160 | // SetChecksum sets the desired hashing algorithm and checksum value to validate 161 | // a downloaded file. Once the download is complete, the given hashing algorithm 162 | // will be used to compute the actual checksum of the downloaded file. If the 163 | // checksums do not match, an error will be returned by the associated 164 | // Response.Err method. 165 | // 166 | // If deleteOnError is true, the downloaded file will be deleted automatically 167 | // if it fails checksum validation. 168 | // 169 | // To prevent corruption of the computed checksum, the given hash must not be 170 | // used by any other request or goroutines. 171 | // 172 | // To disable checksum validation, call SetChecksum with a nil hash. 
// Response represents the response to a completed or in-progress download
// request.
//
// A response may be returned as soon as an HTTP response is received from a
// remote server, but before the body content has started transferring.
//
// All Response method calls are thread-safe.
type Response struct {
	// The Request that was submitted to obtain this Response.
	Request *Request

	// HTTPResponse represents the HTTP response received from an HTTP request.
	//
	// The response Body should not be used as it will be consumed and closed by
	// grab.
	HTTPResponse *http.Response

	// Filename specifies the path where the file transfer is stored in local
	// storage.
	Filename string

	// sizeUnsafe holds the total expected size of the file transfer. It is read
	// atomically by the Size method, which callers should use instead of
	// accessing this field directly.
	sizeUnsafe int64

	// Start specifies the time at which the file transfer started.
	Start time.Time

	// End specifies the time at which the file transfer completed.
	//
	// This is the zero time until the transfer has completed.
	End time.Time

	// CanResume specifies that the remote server advertised that it can resume
	// previous downloads, as the 'Accept-Ranges: bytes' header is set.
	CanResume bool

	// DidResume specifies that the file transfer resumed a previously incomplete
	// transfer.
	DidResume bool

	// Done is closed once the transfer is finalized, either successfully or with
	// errors. Errors are available via Response.Err.
	Done chan struct{}

	// ctx is a Context that controls cancelation of an in-progress transfer.
	ctx context.Context

	// cancel is a cancel func that can be used to cancel the context of this
	// Response.
	cancel context.CancelFunc

	// fi is the FileInfo for the destination file if it already existed before
	// transfer started.
	fi os.FileInfo

	// optionsKnown indicates that a HEAD request has been completed and the
	// capabilities of the remote server are known.
	optionsKnown bool

	// writer is the file handle used to write the downloaded file to local
	// storage.
	writer io.Writer

	// storeBuffer receives the contents of the transfer if Request.NoStore is
	// enabled.
	storeBuffer bytes.Buffer

	// bytesResumed specifies the number of bytes which were already
	// transferred before this transfer began. BytesComplete adds it to the
	// number of bytes copied by the current transfer.
	bytesResumed int64

	// transfer is responsible for copying data from the remote server to a local
	// file, tracking progress and allowing for cancelation.
	transfer *transfer

	// bufferSize specifies the size in bytes of the transfer buffer.
	bufferSize int

	// err contains any error that may have occurred during the file transfer.
	// This should not be read until IsComplete returns true; the Err method
	// waits for Done before returning it.
	err error
}
Cancel blocks until the transfer is closed and returns any 110 | // error - typically context.Canceled. 111 | func (c *Response) Cancel() error { 112 | c.cancel() 113 | return c.Err() 114 | } 115 | 116 | // Wait blocks until the download is completed. 117 | func (c *Response) Wait() { 118 | <-c.Done 119 | } 120 | 121 | // Err blocks the calling goroutine until the underlying file transfer is 122 | // completed and returns any error that may have occurred. If the download is 123 | // already completed, Err returns immediately. 124 | func (c *Response) Err() error { 125 | <-c.Done 126 | return c.err 127 | } 128 | 129 | // Size returns the size of the file transfer. If the remote server does not 130 | // specify the total size and the transfer is incomplete, the return value is 131 | // -1. 132 | func (c *Response) Size() int64 { 133 | return atomic.LoadInt64(&c.sizeUnsafe) 134 | } 135 | 136 | // BytesComplete returns the total number of bytes which have been copied to 137 | // the destination, including any bytes that were resumed from a previous 138 | // download. 139 | func (c *Response) BytesComplete() int64 { 140 | return c.bytesResumed + c.transfer.N() 141 | } 142 | 143 | // BytesPerSecond returns the number of bytes per second transferred using a 144 | // simple moving average of the last five seconds. If the download is already 145 | // complete, the average bytes/sec for the life of the download is returned. 146 | func (c *Response) BytesPerSecond() float64 { 147 | if c.IsComplete() { 148 | return float64(c.transfer.N()) / c.Duration().Seconds() 149 | } 150 | return c.transfer.BPS() 151 | } 152 | 153 | // Progress returns the ratio of total bytes that have been downloaded. Multiply 154 | // the returned value by 100 to return the percentage completed. 
155 | func (c *Response) Progress() float64 { 156 | size := c.Size() 157 | if size <= 0 { 158 | return 0 159 | } 160 | return float64(c.BytesComplete()) / float64(size) 161 | } 162 | 163 | // Duration returns the duration of a file transfer. If the transfer is in 164 | // process, the duration will be between now and the start of the transfer. If 165 | // the transfer is complete, the duration will be between the start and end of 166 | // the completed transfer process. 167 | func (c *Response) Duration() time.Duration { 168 | if c.IsComplete() { 169 | return c.End.Sub(c.Start) 170 | } 171 | 172 | return time.Now().Sub(c.Start) 173 | } 174 | 175 | // ETA returns the estimated time at which the the download will complete, given 176 | // the current BytesPerSecond. If the transfer has already completed, the actual 177 | // end time will be returned. 178 | func (c *Response) ETA() time.Time { 179 | if c.IsComplete() { 180 | return c.End 181 | } 182 | bt := c.BytesComplete() 183 | bps := c.transfer.BPS() 184 | if bps == 0 { 185 | return time.Time{} 186 | } 187 | secs := float64(c.Size()-bt) / bps 188 | return time.Now().Add(time.Duration(secs) * time.Second) 189 | } 190 | 191 | // Open blocks the calling goroutine until the underlying file transfer is 192 | // completed and then opens the transferred file for reading. If Request.NoStore 193 | // was enabled, the reader will read from memory. 194 | // 195 | // If an error occurred during the transfer, it will be returned. 196 | // 197 | // It is the callers responsibility to close the returned file handle. 
198 | func (c *Response) Open() (io.ReadCloser, error) { 199 | if err := c.Err(); err != nil { 200 | return nil, err 201 | } 202 | return c.openUnsafe() 203 | } 204 | 205 | func (c *Response) openUnsafe() (io.ReadCloser, error) { 206 | if c.Request.NoStore { 207 | return ioutil.NopCloser(bytes.NewReader(c.storeBuffer.Bytes())), nil 208 | } 209 | return os.Open(c.Filename) 210 | } 211 | 212 | // Bytes blocks the calling goroutine until the underlying file transfer is 213 | // completed and then reads all bytes from the completed tranafer. If 214 | // Request.NoStore was enabled, the bytes will be read from memory. 215 | // 216 | // If an error occurred during the transfer, it will be returned. 217 | func (c *Response) Bytes() ([]byte, error) { 218 | if err := c.Err(); err != nil { 219 | return nil, err 220 | } 221 | if c.Request.NoStore { 222 | return c.storeBuffer.Bytes(), nil 223 | } 224 | f, err := c.Open() 225 | if err != nil { 226 | return nil, err 227 | } 228 | defer f.Close() 229 | return ioutil.ReadAll(f) 230 | } 231 | 232 | func (c *Response) requestMethod() string { 233 | if c == nil || c.HTTPResponse == nil || c.HTTPResponse.Request == nil { 234 | return "" 235 | } 236 | return c.HTTPResponse.Request.Method 237 | } 238 | 239 | func (c *Response) checksumUnsafe() ([]byte, error) { 240 | f, err := c.openUnsafe() 241 | if err != nil { 242 | return nil, err 243 | } 244 | defer f.Close() 245 | t := newTransfer(c.Request.Context(), nil, c.Request.hash, f, nil) 246 | if _, err = t.copy(); err != nil { 247 | return nil, err 248 | } 249 | sum := c.Request.hash.Sum(nil) 250 | return sum, nil 251 | } 252 | 253 | func (c *Response) closeResponseBody() error { 254 | if c.HTTPResponse == nil || c.HTTPResponse.Body == nil { 255 | return nil 256 | } 257 | return c.HTTPResponse.Body.Close() 258 | } 259 | -------------------------------------------------------------------------------- /v3/response_test.go: 
-------------------------------------------------------------------------------- 1 | package grab 2 | 3 | import ( 4 | "bytes" 5 | "os" 6 | "testing" 7 | "time" 8 | 9 | "github.com/cavaliergopher/grab/v3/pkg/grabtest" 10 | ) 11 | 12 | // testComplete validates that a completed Response has all the desired fields. 13 | func testComplete(t *testing.T, resp *Response) { 14 | <-resp.Done 15 | if !resp.IsComplete() { 16 | t.Errorf("Response.IsComplete returned false") 17 | } 18 | 19 | if resp.Start.IsZero() { 20 | t.Errorf("Response.Start is zero") 21 | } 22 | 23 | if resp.End.IsZero() { 24 | t.Error("Response.End is zero") 25 | } 26 | 27 | if eta := resp.ETA(); eta != resp.End { 28 | t.Errorf("Response.ETA is not equal to Response.End: %v", eta) 29 | } 30 | 31 | // the following fields should only be set if no error occurred 32 | if resp.Err() == nil { 33 | if resp.Filename == "" { 34 | t.Errorf("Response.Filename is empty") 35 | } 36 | 37 | if resp.Size() == 0 { 38 | t.Error("Response.Size is zero") 39 | } 40 | 41 | if p := resp.Progress(); p != 1.00 { 42 | t.Errorf("Response.Progress returned %v (%v/%v bytes), expected 1", p, resp.BytesComplete(), resp.Size()) 43 | } 44 | } 45 | } 46 | 47 | // TestResponseProgress tests the functions which indicate the progress of an 48 | // in-process file transfer. 
// TestResponseProgress tests the functions which indicate the progress of an
// in-process file transfer. It inspects the Response once immediately after
// the request is issued and once after the transfer has finished.
func TestResponseProgress(t *testing.T) {
	filename := ".testResponseProgress"
	defer os.Remove(filename)

	// NOTE(review): sleep is handed to grabtest.TimeToFirstByte, which
	// presumably delays the server's first byte so that the checks made right
	// after Do run before any data arrives - confirm against grabtest.
	sleep := 300 * time.Millisecond
	size := 1024 * 8 // bytes

	grabtest.WithTestServer(t, func(url string) {
		// request a slow transfer
		req := mustNewRequest(filename, url)
		resp := DefaultClient.Do(req)

		// the transfer must not be reported as complete this early
		if resp.IsComplete() {
			t.Errorf("Transfer should not have started")
		}

		// no bytes should have been counted yet
		if p := resp.Progress(); p != 0 {
			t.Errorf("Transfer should not have started yet but progress is %v", p)
		}

		// wait for transfer to complete
		<-resp.Done

		// make sure transfer is complete
		if p := resp.Progress(); p != 1 {
			t.Errorf("Transfer is complete but progress is %v", p)
		}

		// every byte served must be accounted for
		if s := resp.BytesComplete(); s != int64(size) {
			t.Errorf("Expected to transfer %v bytes, got %v", size, s)
		}
	},
		grabtest.TimeToFirstByte(sleep),
		grabtest.ContentLength(size),
	)
}
-------------------------------------------------------------------------------- 1 | package grab 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "sync/atomic" 7 | "time" 8 | 9 | "github.com/cavaliergopher/grab/v3/pkg/bps" 10 | ) 11 | 12 | type transfer struct { 13 | n int64 // must be 64bit aligned on 386 14 | ctx context.Context 15 | gauge bps.Gauge 16 | lim RateLimiter 17 | w io.Writer 18 | r io.Reader 19 | b []byte 20 | } 21 | 22 | func newTransfer(ctx context.Context, lim RateLimiter, dst io.Writer, src io.Reader, buf []byte) *transfer { 23 | return &transfer{ 24 | ctx: ctx, 25 | gauge: bps.NewSMA(6), // five second moving average sampling every second 26 | lim: lim, 27 | w: dst, 28 | r: src, 29 | b: buf, 30 | } 31 | } 32 | 33 | // copy behaves similarly to io.CopyBuffer except that it checks for cancelation 34 | // of the given context.Context, reports progress in a thread-safe manner and 35 | // tracks the transfer rate. 36 | func (c *transfer) copy() (written int64, err error) { 37 | // maintain a bps gauge in another goroutine 38 | ctx, cancel := context.WithCancel(c.ctx) 39 | defer cancel() 40 | go bps.Watch(ctx, c.gauge, c.N, time.Second) 41 | 42 | // start the transfer 43 | if c.b == nil { 44 | c.b = make([]byte, 32*1024) 45 | } 46 | for { 47 | select { 48 | case <-c.ctx.Done(): 49 | err = c.ctx.Err() 50 | return 51 | default: 52 | // keep working 53 | } 54 | nr, er := c.r.Read(c.b) 55 | if nr > 0 { 56 | nw, ew := c.w.Write(c.b[0:nr]) 57 | if nw > 0 { 58 | written += int64(nw) 59 | atomic.StoreInt64(&c.n, written) 60 | } 61 | if ew != nil { 62 | err = ew 63 | break 64 | } 65 | if nr != nw { 66 | err = io.ErrShortWrite 67 | break 68 | } 69 | // wait for rate limiter 70 | if c.lim != nil { 71 | err = c.lim.WaitN(c.ctx, nr) 72 | if err != nil { 73 | return 74 | } 75 | } 76 | } 77 | if er != nil { 78 | if er != io.EOF { 79 | err = er 80 | } 81 | break 82 | } 83 | } 84 | return written, err 85 | } 86 | 87 | // N returns the number of bytes transferred. 
// mkdirp creates all missing parent directories for the destination file path.
//
// It returns an error if the parent directory cannot be inspected or created,
// or if the parent path exists but is not a directory. The latter used to
// panic, but it can be triggered by an ordinary destination path whose parent
// is a regular file, so it is reported as an error instead. Underlying errors
// are wrapped so callers can examine them with errors.Is and errors.As.
func mkdirp(path string) error {
	dir := filepath.Dir(path)
	fi, err := os.Stat(dir)
	switch {
	case err == nil:
		if !fi.IsDir() {
			return fmt.Errorf("destination path is not directory: %s", dir)
		}
		return nil
	case !os.IsNotExist(err):
		return fmt.Errorf("error checking destination directory: %w", err)
	}

	// the parent directory does not exist yet
	if err := os.MkdirAll(dir, 0777); err != nil {
		return fmt.Errorf("error creating destination directory: %w", err)
	}
	return nil
}
If none can be 47 | // determined ErrNoFilename is returned. 48 | // 49 | // TODO: NoStore operations should not require a filename 50 | func guessFilename(resp *http.Response) (string, error) { 51 | filename := resp.Request.URL.Path 52 | if cd := resp.Header.Get("Content-Disposition"); cd != "" { 53 | if _, params, err := mime.ParseMediaType(cd); err == nil { 54 | if val, ok := params["filename"]; ok { 55 | filename = val 56 | } // else filename directive is missing.. fallback to URL.Path 57 | } 58 | } 59 | 60 | // sanitize 61 | if filename == "" || strings.HasSuffix(filename, "/") || strings.Contains(filename, "\x00") { 62 | return "", ErrNoFilename 63 | } 64 | 65 | filename = filepath.Base(path.Clean("/" + filename)) 66 | if filename == "" || filename == "." || filename == "/" { 67 | return "", ErrNoFilename 68 | } 69 | 70 | return filename, nil 71 | } 72 | -------------------------------------------------------------------------------- /v3/util_test.go: -------------------------------------------------------------------------------- 1 | package grab 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/url" 7 | "testing" 8 | ) 9 | 10 | func TestURLFilenames(t *testing.T) { 11 | t.Run("Valid", func(t *testing.T) { 12 | expect := "filename" 13 | testCases := []string{ 14 | "http://test.com/filename", 15 | "http://test.com/path/filename", 16 | "http://test.com/deep/path/filename", 17 | "http://test.com/filename?with=args", 18 | "http://test.com/filename#with-fragment", 19 | "http://test.com/filename?with=args&and#with-fragment", 20 | } 21 | 22 | for _, tc := range testCases { 23 | req, _ := http.NewRequest("GET", tc, nil) 24 | resp := &http.Response{ 25 | Request: req, 26 | } 27 | actual, err := guessFilename(resp) 28 | if err != nil { 29 | t.Errorf("%v", err) 30 | } 31 | 32 | if actual != expect { 33 | t.Errorf("expected '%v', got '%v'", expect, actual) 34 | } 35 | } 36 | }) 37 | 38 | t.Run("Invalid", func(t *testing.T) { 39 | testCases := []string{ 40 | 
"http://test.com", 41 | "http://test.com/", 42 | "http://test.com/filename/", 43 | "http://test.com/filename/?with=args", 44 | "http://test.com/filename/#with-fragment", 45 | "http://test.com/filename\x00", 46 | } 47 | 48 | for _, tc := range testCases { 49 | t.Run(tc, func(t *testing.T) { 50 | req, err := http.NewRequest("GET", tc, nil) 51 | if err != nil { 52 | if tc == "http://test.com/filename\x00" { 53 | // Since go1.12, urls with invalid control character return an error 54 | // See https://github.com/golang/go/commit/829c5df58694b3345cb5ea41206783c8ccf5c3ca 55 | t.Skip() 56 | } 57 | } 58 | resp := &http.Response{ 59 | Request: req, 60 | } 61 | 62 | _, err = guessFilename(resp) 63 | if err != ErrNoFilename { 64 | t.Errorf("expected '%v', got '%v'", ErrNoFilename, err) 65 | } 66 | }) 67 | } 68 | }) 69 | } 70 | 71 | func TestHeaderFilenames(t *testing.T) { 72 | u, _ := url.ParseRequestURI("http://test.com/badfilename") 73 | resp := &http.Response{ 74 | Request: &http.Request{ 75 | URL: u, 76 | }, 77 | Header: http.Header{}, 78 | } 79 | 80 | setFilename := func(resp *http.Response, filename string) { 81 | resp.Header.Set("Content-Disposition", fmt.Sprintf("attachment;filename=\"%s\"", filename)) 82 | } 83 | 84 | t.Run("Valid", func(t *testing.T) { 85 | expect := "filename" 86 | testCases := []string{ 87 | "filename", 88 | "path/filename", 89 | "/path/filename", 90 | "../../filename", 91 | "/path/../../filename", 92 | "/../../././///filename", 93 | } 94 | 95 | for _, tc := range testCases { 96 | setFilename(resp, tc) 97 | actual, err := guessFilename(resp) 98 | if err != nil { 99 | t.Errorf("error (%v): %v", tc, err) 100 | } 101 | 102 | if actual != expect { 103 | t.Errorf("expected '%v' (%v), got '%v'", expect, tc, actual) 104 | } 105 | } 106 | }) 107 | 108 | t.Run("Invalid", func(t *testing.T) { 109 | testCases := []string{ 110 | "", 111 | "/", 112 | ".", 113 | "/.", 114 | "/./", 115 | "..", 116 | "../", 117 | "/../", 118 | "/path/", 119 | "../path/", 120 | 
"filename\x00", 121 | "filename/", 122 | "filename//", 123 | "filename/..", 124 | } 125 | 126 | for _, tc := range testCases { 127 | setFilename(resp, tc) 128 | if actual, err := guessFilename(resp); err != ErrNoFilename { 129 | t.Errorf("expected: %v (%v), got: %v (%v)", ErrNoFilename, tc, err, actual) 130 | } 131 | } 132 | }) 133 | } 134 | 135 | func TestHeaderWithMissingDirective(t *testing.T) { 136 | u, _ := url.ParseRequestURI("http://test.com/filename") 137 | resp := &http.Response{ 138 | Request: &http.Request{ 139 | URL: u, 140 | }, 141 | Header: http.Header{}, 142 | } 143 | 144 | setHeader := func(resp *http.Response, value string) { 145 | resp.Header.Set("Content-Disposition", value) 146 | } 147 | 148 | t.Run("Valid", func(t *testing.T) { 149 | expect := "filename" 150 | testCases := []string{ 151 | "inline", 152 | "attachment", 153 | } 154 | 155 | for _, tc := range testCases { 156 | setHeader(resp, tc) 157 | actual, err := guessFilename(resp) 158 | if err != nil { 159 | t.Errorf("error (%v): %v", tc, err) 160 | } 161 | 162 | if actual != expect { 163 | t.Errorf("expected '%v' (%v), got '%v'", expect, tc, actual) 164 | } 165 | } 166 | }) 167 | } 168 | --------------------------------------------------------------------------------