├── .github
│   └── workflows
│       └── rust.yml
├── .gitignore
├── .pre-commit-config.yaml
├── Cargo.toml
├── README.md
├── examples
│   ├── download.rs
│   ├── iced
│   │   ├── .gitignore
│   │   ├── Cargo.toml
│   │   └── src
│   │       └── main.rs
│   └── socks
│       ├── Cargo.toml
│       ├── README.md
│       └── src
│           └── main.rs
├── flake.lock
├── flake.nix
└── src
    ├── api
    │   ├── mod.rs
    │   ├── sync.rs
    │   └── tokio.rs
    └── lib.rs

/.github/workflows/rust.yml:
--------------------------------------------------------------------------------
name: Rust

on:
  push:
    branches:
      - main
  pull_request:

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest, windows-latest, macOS-latest]

    steps:
      - uses: actions/checkout@v3

      - name: Install Rust Stable
        uses: actions-rs/toolchain@v1
        with:
          toolchain: stable
          components: rustfmt, clippy
          override: true

      - uses: Swatinem/rust-cache@v2

      - name: Install cargo audit
        run: cargo install cargo-audit

      - name: Build
        run: cargo build --all-targets --verbose

      - name: Lint with Clippy
        run: cargo clippy --all-targets --all-features --tests --examples -- -D warnings

      - name: Run Tests
        run: cargo test --all-features --verbose

      - name: Run Tests (no ssl)
        run: cargo test --no-default-features --verbose

      - name: Run Tests (ssl cross)
        run: >
          cargo test --no-default-features --features ureq,native-tls &&
          cargo test --no-default-features --features ureq,rustls-tls &&
          cargo test --no-default-features --features tokio,native-tls &&
          cargo test --no-default-features --features tokio,rustls-tls

      - name: Run Audit
        run: cargo audit -D warnings
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
/target
Cargo.lock
/examples/socks/target/
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
  - repo: https://github.com/Narsil/pre-commit-rust
    rev: 2eed6366172ef2a5186e8785ec0e67243d7d73d0
    hooks:
      - id: fmt
        name: "Rust (fmt)"
      - id: clippy
        name: "Rust (clippy)"
        args:
          [
            "--tests",
            "--examples",
            "--",
            "-Dwarnings",
          ]
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "hf-hub"
version = "0.4.2"
edition = "2021"
homepage = "https://github.com/huggingface/hf-hub"
license = "Apache-2.0"
documentation = "https://docs.rs/hf-hub"
repository = "https://github.com/huggingface/hf-hub"
readme = "README.md"
keywords = ["huggingface", "hf", "hub", "machine-learning"]
description = """
This crate aims to ease the interaction with [huggingface](https://huggingface.co/).
It aims to be compatible with the [huggingface_hub](https://github.com/huggingface/huggingface_hub/) python package, but only implements a smaller subset of functions.
"""


# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
futures = { version = "0.3.28", optional = true }
dirs = "5.0.1"
http = { version = "1.0.0", optional = true }
indicatif = { version = "0.17.5", optional = true }
log = "0.4.19"
num_cpus = { version = "1.15.0", optional = true }
rand = { version = "0.8.5", optional = true }
reqwest = { version = "0.12.2", optional = true, default-features = false, features = [
    "json",
    "stream",
] }
serde = { version = "1", features = ["derive"], optional = true }
serde_json = { version = "1", optional = true }
thiserror = { version = "2", optional = true }
tokio = { version = "1.29.1", optional = true, features = ["fs", "macros"] }
ureq = { version = "2.8.0", optional = true, features = [
    "json",
    "socks-proxy",
] }
native-tls = { version = "0.2.12", optional = true }

[target.'cfg(windows)'.dependencies.windows-sys]
version = "0.59"
features = ["Win32_Foundation", "Win32_Storage_FileSystem", "Win32_System_IO"]
optional = true

[target.'cfg(unix)'.dependencies.libc]
version = "0.2"
optional = true

[features]
default = ["default-tls", "tokio", "ureq"]
# These features are only relevant when used with the `tokio` feature, but this might change in the future.
default-tls = ["native-tls"]
native-tls = ["dep:reqwest", "reqwest?/default", "dep:native-tls", "dep:ureq", "ureq?/native-tls"]
rustls-tls = ["reqwest?/rustls-tls"]
tokio = [
    "dep:futures",
    "dep:indicatif",
    "dep:num_cpus",
    "dep:rand",
    "dep:reqwest",
    "reqwest/charset",
    "reqwest/http2",
    "reqwest/macos-system-configuration",
    "dep:serde",
    "dep:serde_json",
    "dep:thiserror",
    "dep:tokio",
    "tokio/rt-multi-thread",
    "dep:libc",
    "dep:windows-sys",
]
ureq = [
    "dep:http",
    "dep:indicatif",
    "dep:rand",
    "dep:serde",
    "dep:serde_json",
    "dep:thiserror",
    "dep:ureq",
    "dep:libc",
    "dep:windows-sys",
]

[dev-dependencies]
hex-literal = "0.4.1"
sha2 = "0.10"
tokio-test = "0.4.2"
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This crate aims to emulate and be compatible with the
[huggingface_hub](https://github.com/huggingface/huggingface_hub/) python package.

Compatible means the Api should reuse the same files, skipping downloads if
they are already present, and whenever this crate downloads or modifies the cache
it should stay consistent with [huggingface_hub](https://github.com/huggingface/huggingface_hub/)
(a cache-probing sketch follows below).

At this time only a limited subset of the functionality is present; the goal is to add new
features over time. We are currently treating this as an internal/external tool, meaning
we are currently modifying everything at will for our internal needs. This will eventually
stabilize as it matures to accommodate most of our needs.

If you're interested in using this, you're welcome to do it, but be warned about the potentially changing ground.

If you want to contribute, you are more than welcome.

However, new features or feature requests might be declined for lack of maintenance
time. We're focusing on what we currently need internally. Hopefully that subset is already interesting
to more users.


# How to use

Add the dependency

```bash
cargo add hf-hub # --features tokio
```
The `tokio` feature enables an async (and potentially faster) API.

Use the crate:

```rust
use hf_hub::api::sync::Api;

let api = Api::new().unwrap();

let repo = api.model("bert-base-uncased".to_string());
let _filename = repo.get("config.json").unwrap();

// _filename is now the local location of the config.json file within the hf cache
```

# SSL/TLS

By default this library uses `native-tls` (OpenSSL) as the TLS implementation for the `tokio` backend.

If you want control over the TLS backend you can remove the default features and only add the backend you intend to use.

```bash
cargo add hf-hub --no-default-features --features ureq,rustls-tls
cargo add hf-hub --no-default-features --features ureq,native-tls
cargo add hf-hub --no-default-features --features tokio,rustls-tls
cargo add hf-hub --no-default-features --features tokio,native-tls
```


When using the [`ureq`](https://github.com/algesten/ureq) feature, its default TLS backend [rustls](https://github.com/rustls/rustls) is used, unless the `native-tls` feature is enabled.

When using [`tokio`](https://github.com/tokio-rs/tokio), by default `default-tls` will be enabled, which means OpenSSL. If you want/need to use rustls, disable the default features and use `rustls-tls` in conjunction with `tokio`.
--------------------------------------------------------------------------------
/examples/download.rs:
--------------------------------------------------------------------------------
#[cfg(not(feature = "ureq"))]
#[cfg(not(feature = "tokio"))]
fn main() {}

#[cfg(feature = "ureq")]
#[cfg(not(feature = "tokio"))]
fn main() {
    let api = hf_hub::api::sync::Api::new().unwrap();

    let _filename = api
        .model("meta-llama/Llama-2-7b-hf".to_string())
        .get("model-00001-of-00002.safetensors")
        .unwrap();
}

#[cfg(feature = "tokio")]
#[tokio::main]
async fn main() {
    let api = hf_hub::api::tokio::Api::new().unwrap();

    let _filename = api
        .model("meta-llama/Llama-2-7b-hf".to_string())
        .get("model-00001-of-00002.safetensors")
        .await
        .unwrap();
}
--------------------------------------------------------------------------------
/examples/iced/.gitignore:
--------------------------------------------------------------------------------
/target
--------------------------------------------------------------------------------
/examples/iced/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "iced_hf_hub"
version = "0.1.0"
edition = "2021"

[dependencies]
iced = { version = "0.13.1", features = ["tokio"] }
hf-hub = { path = "../../", default-features = false, features = ["tokio", "rustls-tls"] }
--------------------------------------------------------------------------------
/examples/iced/src/main.rs:
--------------------------------------------------------------------------------
use hf_hub::api::tokio::{Api, ApiError};
use iced::futures::{SinkExt, Stream};
use iced::stream::try_channel;
use iced::task;
use iced::widget::{button, center, column, progress_bar, text, Column};

use iced::{Center, Element, Right, Task};

#[derive(Debug, Clone)]
pub enum Progress {
    Downloading { current: usize, total: usize },
    Finished,
}

#[derive(Debug, Clone)]
pub enum Error {
    Api(String),
}

impl From<ApiError> for Error {
    fn from(value: ApiError) -> Self {
        Self::Api(value.to_string())
    }
}

pub fn main() -> iced::Result {
    iced::application("Download Progress - Iced", Example::update, Example::view).run()
}

#[derive(Debug)]
struct Example {
    downloads: Vec<Download>,
    last_id: usize,
}

#[derive(Clone)]
struct Prog {
    output: iced::futures::channel::mpsc::Sender<Progress>,
    total: usize,
}

impl hf_hub::api::tokio::Progress for Prog {
    async fn update(&mut self, size: usize) {
        let _ = self
            .output
            .send(Progress::Downloading {
                current: size,
                total: self.total,
            })
            .await;
    }
    async fn finish(&mut self) {
        let _ = self.output.send(Progress::Finished).await;
    }

    async fn init(&mut self, size: usize, _filename: &str) {
        println!("Initiating {size}");
        let _ = self
            .output
            .send(Progress::Downloading {
                current: 0,
                total: size,
            })
            .await;
        self.total = size;
    }
}

pub fn download(
    repo: String,
    filename: impl AsRef<str>,
) -> impl Stream<Item = Result<Progress, Error>> {
    try_channel(1, move |output| async move {
        let prog = Prog { output, total: 0 };

        let api = Api::new().unwrap().model(repo);
        api.download_with_progress(filename.as_ref(), prog).await?;

        Ok(())
    })
}

#[derive(Debug, Clone)]
pub enum Message {
    Add,
    Download(usize),
    DownloadProgressed(usize, Result<Progress, Error>),
}

impl Example {
    fn new() -> Self {
        Self {
            downloads: vec![Download::new(0)],
            last_id: 0,
        }
    }

    fn update(&mut self, message: Message) -> Task<Message> {
        match message {
            Message::Add => {
                self.last_id += 1;

                self.downloads.push(Download::new(self.last_id));

                Task::none()
            }
            Message::Download(index) => {
                let Some(download) = self.downloads.get_mut(index) else {
                    return Task::none();
                };

                let task = download.start();

                task.map(move |progress| Message::DownloadProgressed(index, progress))
            }
            Message::DownloadProgressed(id, progress) => {
                if let Some(download) = self.downloads.iter_mut().find(|download| download.id == id)
                {
                    download.progress(progress);
                }

                Task::none()
            }
        }
    }

    fn view(&self) -> Element<Message> {
        let downloads = Column::with_children(self.downloads.iter().map(Download::view))
            .push(
                button("Add another download")
                    .on_press(Message::Add)
                    .padding(10),
            )
            .spacing(20)
            .align_x(Right);

        center(downloads).padding(20).into()
    }
}

impl Default for Example {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug)]
struct Download {
    id: usize,
    state: State,
}

#[derive(Debug)]
enum State {
    Idle,
    Downloading { progress: f32, _task: task::Handle },
    Finished,
    Errored,
}

impl Download {
    pub fn new(id: usize) -> Self {
        Download {
            id,
            state: State::Idle,
        }
    }

    pub fn start(&mut self) -> Task<Result<Progress, Error>> {
        match self.state {
            State::Idle { .. } | State::Finished { .. } | State::Errored { .. } => {
                let (task, handle) = Task::stream(download(
                    "mattshumer/Reflection-Llama-3.1-70B".to_string(),
                    "model-00001-of-00162.safetensors",
                ))
                .abortable();

                self.state = State::Downloading {
                    progress: 0.0,
                    _task: handle.abort_on_drop(),
                };

                task
            }
            State::Downloading { .. } => Task::none(),
        }
    }

    pub fn progress(&mut self, new_progress: Result<Progress, Error>) {
        if let State::Downloading { progress, .. } = &mut self.state {
            match new_progress {
                Ok(Progress::Downloading { current, total }) => {
                    println!("Status {progress} - {current}");
                    let new_progress = current as f32 / total as f32 * 100.0;
                    println!("New progress {current} {new_progress}");
                    *progress += new_progress;
                }
                Ok(Progress::Finished) => {
                    self.state = State::Finished;
                }
                Err(_error) => {
                    self.state = State::Errored;
                }
            }
        }
    }

    pub fn view(&self) -> Element<Message> {
        let current_progress = match &self.state {
            State::Idle { .. } => 0.0,
            State::Downloading { progress, .. } => *progress,
            State::Finished { .. } => 100.0,
            State::Errored { .. } => 0.0,
        };

        let progress_bar = progress_bar(0.0..=100.0, current_progress);

        let control: Element<_> = match &self.state {
            State::Idle => button("Start the download!")
                .on_press(Message::Download(self.id))
                .into(),
            State::Finished => column!["Download finished!", button("Start again")]
                .spacing(10)
                .align_x(Center)
                .into(),
            State::Downloading { .. } => text!("Downloading... {current_progress:.2}%").into(),
            State::Errored => column![
                "Something went wrong :(",
                button("Try again").on_press(Message::Download(self.id)),
            ]
            .spacing(10)
            .align_x(Center)
            .into(),
        };

        Column::new()
            .spacing(10)
            .padding(10)
            .align_x(Center)
            .push(progress_bar)
            .push(control)
            .into()
    }
}
--------------------------------------------------------------------------------
/examples/socks/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "socks"
version = "0.1.0"
edition = "2021"

[dependencies]
hf-hub = { version = "0.4.0", path = "../.." }
# Adding the `socks` feature automatically adds it into
# the reqwest built by hf-hub, therefore enabling socks proxying.
reqwest = { version = "0.12.9", features = ["socks"] }
tokio = { version = "1.42.0", features = ["macros"] }
--------------------------------------------------------------------------------
/examples/socks/README.md:
--------------------------------------------------------------------------------
Example showcasing socks routing.
Users simply need to add `reqwest` with the proper `socks` feature enabled in order to enable it in `hf-hub`.


This is due to [feature unification](https://doc.rust-lang.org/cargo/reference/resolver.html#features), sketched below.
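
As a concrete illustration, here is a hypothetical downstream crate's manifest (versions are placeholders, not part of this example) showing how unification plays out:

```toml
# Hypothetical downstream manifest (illustration only).
[dependencies]
hf-hub = "0.4"
# Cargo unifies features across the whole dependency graph, so requesting
# `socks` here also enables it on the `reqwest` instance that hf-hub
# builds and uses internally.
reqwest = { version = "0.12", features = ["socks"] }
```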

--------------------------------------------------------------------------------
/examples/socks/src/main.rs:
--------------------------------------------------------------------------------
#[tokio::main]
async fn main() {
    let _proxy = std::env::var("HTTPS_PROXY").expect("This example expects an HTTPS_PROXY environment variable to be defined to test that the routing happens correctly. Start a socks server and point HTTPS_PROXY to that server to see the routing in action.");

    let api = hf_hub::api::tokio::ApiBuilder::new()
        .with_progress(true)
        .build()
        .unwrap();

    let _filename = api
        .model("meta-llama/Llama-2-7b-hf".to_string())
        .get("model-00001-of-00002.safetensors")
        .await
        .unwrap();
}
--------------------------------------------------------------------------------
/flake.lock:
--------------------------------------------------------------------------------
{
  "nodes": {
    "nixpkgs": {
      "locked": {
        "lastModified": 1734649271,
        "narHash": "sha256-4EVBRhOjMDuGtMaofAIqzJbg4Ql7Ai0PSeuVZTHjyKQ=",
        "owner": "NixOS",
        "repo": "nixpkgs",
        "rev": "d70bd19e0a38ad4790d3913bf08fcbfc9eeca507",
        "type": "github"
      },
      "original": {
        "owner": "NixOS",
        "ref": "nixos-unstable",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "nixpkgs": "nixpkgs"
      }
    }
  },
  "root": "root",
  "version": 7
}
--------------------------------------------------------------------------------
/flake.nix:
--------------------------------------------------------------------------------
{
  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
  };

  outputs =
    { nixpkgs, ... }:
    let
      forAllSystems = nixpkgs.lib.genAttrs [
        "aarch64-linux"
        "x86_64-linux"
        "aarch64-darwin"
      ];
    in
    {
      devShells = forAllSystems (
        system:
        let
          pkgs = nixpkgs.legacyPackages.${system};
        in
        {
          default = pkgs.mkShell {
            buildInputs = with pkgs; [
              rustup
              pkg-config
              openssl
            ];
          };

        }
      );
    };
}
--------------------------------------------------------------------------------
/src/api/mod.rs:
--------------------------------------------------------------------------------
use std::{collections::VecDeque, time::Duration};

use indicatif::{style::ProgressTracker, HumanBytes, ProgressBar, ProgressStyle};
use serde::Deserialize;

/// The asynchronous version of the API
#[cfg(feature = "tokio")]
pub mod tokio;

/// The synchronous version of the API
#[cfg(feature = "ureq")]
pub mod sync;

const HF_ENDPOINT: &str = "HF_ENDPOINT";

/// This trait is used by users of the lib
/// to implement custom behavior during file downloads
pub trait Progress {
    /// Called at the start of the download.
    /// The size is the total size in bytes of the file.
    fn init(&mut self, size: usize, filename: &str);
    /// This function is called whenever `size` bytes have been
    /// downloaded in the temporary file
    fn update(&mut self, size: usize);
    /// This is called at the end of the download
    fn finish(&mut self);
}

impl Progress for () {
    fn init(&mut self, _size: usize, _filename: &str) {}
    fn update(&mut self, _size: usize) {}
    fn finish(&mut self) {}
}

impl Progress for ProgressBar {
    fn init(&mut self, size: usize, filename: &str) {
        self.set_length(size as u64);
        self.set_style(
            ProgressStyle::with_template(
                "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec_smoothed} ({eta})",
            ).unwrap().with_key("bytes_per_sec_smoothed", MovingAvgRate::default()),
        );
        let maxlength = 30;
        let message = if filename.len() > maxlength {
            format!("..{}", &filename[filename.len() - maxlength..])
        } else {
            filename.to_string()
        };
        self.set_message(message);
    }

    fn update(&mut self, size: usize) {
        self.inc(size as u64)
    }

    fn finish(&mut self) {
        ProgressBar::finish(self);
    }
}

/// Siblings are simplified file descriptions of remote files on the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct Siblings {
    /// The path within the repo.
    pub rfilename: String,
}

/// The description of the repo given by the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct RepoInfo {
    /// See [`Siblings`]
    pub siblings: Vec<Siblings>,

    /// The commit sha of the repo.
    pub sha: String,
}

#[derive(Clone, Default)]
struct MovingAvgRate {
    samples: VecDeque<(std::time::Instant, u64)>,
}

impl ProgressTracker for MovingAvgRate {
    fn clone_box(&self) -> Box<dyn ProgressTracker> {
        Box::new(self.clone())
    }

    fn tick(&mut self, state: &indicatif::ProgressState, now: std::time::Instant) {
        // sample at most every 20ms
        if self
            .samples
            .back()
            .is_none_or(|(prev, _)| (now - *prev) > Duration::from_millis(20))
        {
            self.samples.push_back((now, state.pos()));
        }

        while let Some(first) = self.samples.front() {
            if now - first.0 > Duration::from_secs(1) {
                self.samples.pop_front();
            } else {
                break;
            }
        }
    }

    fn reset(&mut self, _state: &indicatif::ProgressState, _now: std::time::Instant) {
        self.samples = Default::default();
    }

    fn write(&self, _state: &indicatif::ProgressState, w: &mut dyn std::fmt::Write) {
        match (self.samples.front(), self.samples.back()) {
            (Some((t0, p0)), Some((t1, p1))) if self.samples.len() > 1 => {
                let elapsed_ms = (*t1 - *t0).as_millis();
                let rate = ((p1 - p0) as f64 * 1000f64 / elapsed_ms as f64) as u64;
                write!(w, "{}/s", HumanBytes(rate)).unwrap()
            }
            _ => write!(w, "-").unwrap(),
        }
    }
}
--------------------------------------------------------------------------------
/src/api/sync.rs:
--------------------------------------------------------------------------------
use super::{RepoInfo, HF_ENDPOINT};
use crate::api::sync::ApiError::InvalidHeader;
use crate::api::Progress;
use crate::{Cache, Repo, RepoType};
use http::{StatusCode, Uri};
use indicatif::ProgressBar;
use rand::Rng;
use std::collections::HashMap;
use std::io::Read;
use std::io::Seek;
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
use std::str::FromStr;
use thiserror::Error;
use ureq::{Agent, AgentBuilder, Request};

/// Current version (used in user-agent)
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Current name (used in user-agent)
const NAME: &str = env!("CARGO_PKG_NAME");

const RANGE: &str = "Range";
const CONTENT_RANGE: &str = "Content-Range";
const LOCATION: &str = "Location";
const USER_AGENT: &str = "User-Agent";
const AUTHORIZATION: &str = "Authorization";

type HeaderMap = HashMap<&'static str, String>;
type HeaderName = &'static str;

/// Specific name for the sync part of the resumable file
const EXTENSION: &str = "part";

struct Wrapper<'a, P: Progress, R: Read> {
    progress: &'a mut P,
    inner: R,
}

fn wrap_read<P: Progress, R: Read>(inner: R, progress: &mut P) -> Wrapper<'_, P, R> {
    Wrapper { inner, progress }
}

impl<P: Progress, R: Read> Read for Wrapper<'_, P, R> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let read = self.inner.read(buf)?;
        self.progress.update(read);
        Ok(read)
    }
}

/// Simple wrapper over [`ureq::Agent`] to include default headers
#[derive(Clone, Debug)]
pub struct HeaderAgent {
    agent: Agent,
    headers: HeaderMap,
}

impl HeaderAgent {
    fn new(agent: Agent, headers: HeaderMap) -> Self {
        Self { agent, headers }
    }

    fn get(&self, url: &str) -> ureq::Request {
        let mut request = self.agent.get(url);
        for (header, value) in &self.headers {
            request = request.set(header, value);
        }
        request
    }
}

struct Handle {
    file: std::fs::File,
}

impl Drop for Handle {
    fn drop(&mut self) {
        unlock(&self.file);
    }
}

fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
    path.set_extension("lock");

    let file = std::fs::File::create(path.clone())?;
    let mut res = lock(&file);
    for _ in 0..5 {
        if res == 0 {
            break;
        }
        std::thread::sleep(std::time::Duration::from_secs(1));
        res = lock(&file);
    }
    if res != 0 {
        Err(ApiError::LockAcquisition(path))
    } else {
        Ok(Handle { file })
    }
}

#[cfg(target_family = "unix")]
mod unix {
    use std::os::fd::AsRawFd;

    pub(crate) fn lock(file: &std::fs::File) -> i32 {
        unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) }
    }
    pub(crate) fn unlock(file: &std::fs::File) -> i32 {
        unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) }
    }
}
#[cfg(target_family = "unix")]
use unix::{lock, unlock};

#[cfg(target_family = "windows")]
mod windows {
    use std::os::windows::io::AsRawHandle;
    use windows_sys::Win32::Foundation::HANDLE;
    use windows_sys::Win32::Storage::FileSystem::{
        LockFileEx, UnlockFile, LOCKFILE_EXCLUSIVE_LOCK, LOCKFILE_FAIL_IMMEDIATELY,
    };

    pub(crate) fn lock(file: &std::fs::File) -> i32 {
        unsafe {
            let mut overlapped = std::mem::zeroed();
            let flags = LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY;
            let res = LockFileEx(
                file.as_raw_handle() as HANDLE,
                flags,
                0,
                !0,
                !0,
                &mut overlapped,
            );
            1 - res
        }
    }
    pub(crate) fn unlock(file: &std::fs::File) -> i32 {
        unsafe { UnlockFile(file.as_raw_handle() as HANDLE, 0, 0, !0, !0) }
    }
}
#[cfg(target_family = "windows")]
use windows::{lock, unlock};

#[cfg(not(any(target_family = "unix", target_family = "windows")))]
mod other {
    pub(crate) fn lock(file: &std::fs::File) -> i32 {
        0
    }
    pub(crate) fn unlock(file: &std::fs::File) -> i32 {
        0
    }
}
#[cfg(not(any(target_family = "unix", target_family = "windows")))]
use other::{lock, unlock};

#[derive(Debug, Error)]
/// All errors the API can throw
pub enum ApiError {
    /// The Api expects certain headers to be present in the results to derive some information
    #[error("Header {0} is missing")]
    MissingHeader(HeaderName),

    /// The header exists, but the value does not conform to what the Api expects.
    #[error("Header {0} is invalid")]
    InvalidHeader(HeaderName),

    // /// The value cannot be used as a header during request header construction
    // #[error("Invalid header value {0}")]
    // InvalidHeaderValue(#[from] InvalidHeaderValue),

    // /// The header value is not valid utf-8
    // #[error("header value is not a string")]
    // ToStr(#[from] ToStrError),
    /// Error in the request
    #[error("request error: {0}")]
    RequestError(#[from] Box<ureq::Error>),

    /// Error parsing some range value
    #[error("Cannot parse int")]
    ParseIntError(#[from] ParseIntError),

    /// I/O Error
    #[error("I/O error {0}")]
    IoError(#[from] std::io::Error),

    /// We tried to download a chunk too many times
    #[error("Too many retries: {0}")]
    TooManyRetries(Box<ApiError>),

    /// Native tls error
    #[error("Native tls: {0}")]
    #[cfg(feature = "native-tls")]
    Native(#[from] native_tls::Error),

    /// The part file is corrupted
    #[error("Invalid part file - corrupted file")]
    InvalidResume,

    /// We failed to acquire the lock for file `f`, meaning
    /// someone else is writing/downloading said file
    #[error("Lock acquisition failed: {0}")]
    LockAcquisition(PathBuf),
}

/// Helper to create [`Api`] with all the options.
#[derive(Debug)]
pub struct ApiBuilder {
    endpoint: String,
    cache: Cache,
    token: Option<String>,
    max_retries: usize,
    progress: bool,
    user_agent: Vec<(String, String)>,
}

impl Default for ApiBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl ApiBuilder {
    /// Default api builder
    /// ```
    /// use hf_hub::api::sync::ApiBuilder;
    /// let api = ApiBuilder::new().build().unwrap();
    /// ```
    pub fn new() -> Self {
        let cache = Cache::default();
        Self::from_cache(cache)
    }

    /// Creates API with values potentially from environment variables.
    /// HF_HOME decides the location of the cache folder,
    /// HF_ENDPOINT modifies the URL for the huggingface location
    /// to download files from.
    /// ```
    /// use hf_hub::api::sync::ApiBuilder;
    /// let api = ApiBuilder::from_env().build().unwrap();
    /// ```
    pub fn from_env() -> Self {
        let cache = Cache::from_env();
        let mut builder = Self::from_cache(cache);
        if let Ok(endpoint) = std::env::var(HF_ENDPOINT) {
            builder = builder.with_endpoint(endpoint);
        }
        builder
    }

    /// From a given cache
    /// ```
    /// use hf_hub::{api::sync::ApiBuilder, Cache};
    /// let path = std::path::PathBuf::from("/tmp");
    /// let cache = Cache::new(path);
    /// let api = ApiBuilder::from_cache(cache).build().unwrap();
    /// ```
    pub fn from_cache(cache: Cache) -> Self {
        let token = cache.token();

        let max_retries = 0;
        let progress = true;

        let endpoint = "https://huggingface.co".to_string();

        let user_agent = vec![
            ("unknown".to_string(), "None".to_string()),
            (NAME.to_string(), VERSION.to_string()),
            ("rust".to_string(), "unknown".to_string()),
        ];

        Self {
            endpoint,
            cache,
            token,
            max_retries,
            progress,
            user_agent,
        }
    }

    /// Whether to show a progressbar
    pub fn with_progress(mut self, progress: bool) -> Self {
        self.progress = progress;
        self
    }

    /// Changes the endpoint of the API. Default is `https://huggingface.co`.
    pub fn with_endpoint(mut self, endpoint: String) -> Self {
        self.endpoint = endpoint;
        self
    }

    /// Changes the location of the cache directory. Default is `~/.cache/huggingface/`.
    pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
        self.cache = Cache::new(cache_dir);
        self
    }

    /// Sets the token to be used in the API
    pub fn with_token(mut self, token: Option<String>) -> Self {
        self.token = token;
        self
    }

    /// Sets the number of times the API will retry to download a file
    pub fn with_retries(mut self, max_retries: usize) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Adds custom fields to the user-agent header
    pub fn with_user_agent(mut self, key: &str, value: &str) -> Self {
        self.user_agent.push((key.to_string(), value.to_string()));
        self
    }

    fn build_headers(&self) -> HeaderMap {
        let mut headers = HeaderMap::new();
        let user_agent = self
            .user_agent
            .iter()
            .map(|(key, value)| format!("{key}/{value}"))
            .collect::<Vec<_>>()
            .join("; ");
        headers.insert(USER_AGENT, user_agent.to_string());
        if let Some(token) = &self.token {
            headers.insert(AUTHORIZATION, format!("Bearer {token}"));
        }
        headers
    }

    /// Consumes the builder and builds the final [`Api`]
    pub fn build(self) -> Result<Api, ApiError> {
        let headers = self.build_headers();

        let builder = builder()?;
        let agent = builder.build();
        let client = HeaderAgent::new(agent, headers.clone());

        let no_redirect_agent = ureq::builder()
            .try_proxy_from_env(true)
            .redirects(0)
            .build();
        let no_redirect_client = HeaderAgent::new(no_redirect_agent, headers);

        Ok(Api {
            endpoint: self.endpoint,
            cache: self.cache,
            client,
            no_redirect_client,
            max_retries: self.max_retries,
            progress: self.progress,
        })
    }
}

#[derive(Debug)]
struct Metadata {
    commit_hash: String,
    etag: String,
    size: usize,
}

/// The actual Api used to interact with the hub.
/// Use any repo with [`Api::repo`]
#[derive(Clone, Debug)]
pub struct Api {
    endpoint: String,
    cache: Cache,
    client: HeaderAgent,
    no_redirect_client: HeaderAgent,
    max_retries: usize,
    progress: bool,
}

fn make_relative(src: &Path, dst: &Path) -> PathBuf {
    let path = src;
    let base = dst;

    assert_eq!(
        path.is_absolute(),
        base.is_absolute(),
        "This function is made to look at absolute paths only"
    );
    let mut ita = path.components();
    let mut itb = base.components();

    loop {
        match (ita.next(), itb.next()) {
            (Some(a), Some(b)) if a == b => (),
            (some_a, _) => {
                // Ignoring b, because 1 component is the filename
                // for which we don't need to go back up for relative
                // filename to work.
                let mut new_path = PathBuf::new();
                for _ in itb {
                    new_path.push(Component::ParentDir);
                }
                if let Some(a) = some_a {
                    new_path.push(a);
                    for comp in ita {
                        new_path.push(comp);
                    }
                }
                return new_path;
            }
        }
    }
}

fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
    if dst.exists() {
        return Ok(());
    }

    let rel_src = make_relative(src, dst);
    #[cfg(target_os = "windows")]
    {
        if std::os::windows::fs::symlink_file(rel_src, dst).is_err() {
            std::fs::rename(src, dst)?;
        }
    }

    #[cfg(target_family = "unix")]
    std::os::unix::fs::symlink(rel_src, dst)?;

    Ok(())
}

fn jitter() -> usize {
    rand::thread_rng().gen_range(0..=500)
}

fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
    (base_wait_time + n.pow(2) + jitter()).min(max)
}

impl Api {
    /// Creates a default Api, for Api options See [`ApiBuilder`]
    pub fn new() -> Result<Api, ApiError> {
        ApiBuilder::new().build()
    }

    /// Get the underlying api client
    /// Allows for lower level access
    pub fn client(&self) -> &HeaderAgent {
        &self.client
    }

    fn metadata(&self, url: &str) -> Result<Metadata, ApiError> {
        let mut response = self
            .no_redirect_client
            .get(url)
            .set(RANGE, "bytes=0-0")
            .call()
            .map_err(Box::new)?;

        // Closure to check if status code is a redirection
        let should_redirect = |status_code: u16| {
            matches!(
                StatusCode::from_u16(status_code).unwrap(),
                StatusCode::MOVED_PERMANENTLY
                    | StatusCode::FOUND
                    | StatusCode::SEE_OTHER
                    | StatusCode::TEMPORARY_REDIRECT
                    | StatusCode::PERMANENT_REDIRECT
            )
        };

        // Follow redirects until `host.is_some()` i.e. only follow relative redirects
        // See: https://github.com/huggingface/huggingface_hub/blob/9c6af39cdce45b570f0b7f8fad2b311c96019804/src/huggingface_hub/file_download.py#L411
        let response = loop {
            // Check if redirect
            if should_redirect(response.status()) {
                // Get redirect location
                if let Some(location) = response.header("Location") {
                    // Parse location
                    let uri = Uri::from_str(location).map_err(|_| InvalidHeader("location"))?;

                    // Check if relative i.e. host is none
                    if uri.host().is_none() {
                        // Merge relative path with url
                        let mut parts = Uri::from_str(url).unwrap().into_parts();
                        parts.path_and_query = uri.into_parts().path_and_query;
                        // Final uri
                        let redirect_uri = Uri::from_parts(parts).unwrap();

                        // Follow redirect
                        response = self
                            .no_redirect_client
                            .get(&redirect_uri.to_string())
                            .set(RANGE, "bytes=0-0")
                            .call()
                            .map_err(Box::new)?;
                        continue;
                    }
                };
            }
            break response;
        };

        // let headers = response.headers();
        let header_commit = "x-repo-commit";
        let header_linked_etag = "x-linked-etag";
        let header_etag = "etag";

        let etag = match response.header(header_linked_etag) {
            Some(etag) => etag,
            None => response
                .header(header_etag)
                .ok_or(ApiError::MissingHeader(header_etag))?,
        };
        // Cleaning extra quotes
        let etag = etag.to_string().replace('"', "");
        let commit_hash = response
            .header(header_commit)
            .ok_or(ApiError::MissingHeader(header_commit))?
            .to_string();

        // The response was most likely redirected to S3, which will
        // know about the size of the file
        let status = response.status();
        let is_redirection = (300..400).contains(&status);
        let response = if is_redirection {
            self.client
                .get(response.header(LOCATION).unwrap())
                .set(RANGE, "bytes=0-0")
                .call()
                .map_err(Box::new)?
        } else {
            response
        };
        let content_range = response
            .header(CONTENT_RANGE)
            .ok_or(ApiError::MissingHeader(CONTENT_RANGE))?;

        let size = content_range
            .split('/')
            .last()
            .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))?
            .parse()?;
        Ok(Metadata {
            commit_hash,
            etag,
            size,
        })
    }

    fn download_tempfile<P: Progress>(
        &self,
        url: &str,
        size: usize,
        mut progress: P,
        tmp_path: PathBuf,
        filename: &str,
    ) -> Result<PathBuf, ApiError> {
        progress.init(size, filename);
        let filepath = tmp_path;

        // Create the file and set everything properly

        let mut file = match std::fs::OpenOptions::new().append(true).open(&filepath) {
            Ok(f) => f,
            Err(_) => std::fs::File::create(&filepath)?,
        };

        // In case of resume.
        let start = file.metadata()?.len();
        if start > size as u64 {
            return Err(ApiError::InvalidResume);
        }

        let mut res = self.download_from(url, start, size, &mut file, filename, &mut progress);
        if self.max_retries > 0 {
            let mut i = 0;
            while let Err(dlerr) = res {
                let wait_time = exponential_backoff(300, i, 10_000);
                std::thread::sleep(std::time::Duration::from_millis(wait_time as u64));

                let current = file.stream_position()?;
                res = self.download_from(url, current, size, &mut file, filename, &mut progress);
                i += 1;
                if i > self.max_retries {
                    return Err(ApiError::TooManyRetries(dlerr.into()));
                }
            }
        }
        res?;
        Ok(filepath)
    }

    fn download_from<P>(
        &self,
        url: &str,
        current: u64,
        size: usize,
        file: &mut std::fs::File,
        filename: &str,
        progress: &mut P,
    ) -> Result<(), ApiError>
    where
        P: Progress,
    {
        let range = format!("bytes={current}-");
        let response = self
            .client
            .get(url)
            .set(RANGE, &range)
            .call()
            .map_err(Box::new)?;
        let reader = response.into_reader();
        progress.init(size, filename);
        progress.update(current as usize);
        let mut reader = Box::new(wrap_read(reader, progress));
        std::io::copy(&mut reader, file)?;
        progress.finish();
        Ok(())
    }

    /// Creates a new handle [`ApiRepo`] which contains operations
    /// on a particular [`Repo`]
    pub fn repo(&self, repo: Repo) -> ApiRepo {
        ApiRepo::new(self.clone(), repo)
    }

    /// Simple wrapper over
    /// ```
    /// # use hf_hub::{api::sync::Api, Repo, RepoType};
    /// # let model_id = "gpt2".to_string();
    /// let api = Api::new().unwrap();
    /// let api = api.repo(Repo::new(model_id, RepoType::Model));
    /// ```
    pub fn model(&self, model_id: String) -> ApiRepo {
        self.repo(Repo::new(model_id, RepoType::Model))
    }

    /// Simple wrapper over
    /// ```
    /// # use hf_hub::{api::sync::Api, Repo, RepoType};
    /// # let model_id = "gpt2".to_string();
    /// let api = Api::new().unwrap();
    /// let api = api.repo(Repo::new(model_id, RepoType::Dataset));
    /// ```
    pub fn dataset(&self, model_id: String) -> ApiRepo {
        self.repo(Repo::new(model_id, RepoType::Dataset))
    }

    /// Simple wrapper over
    /// ```
    /// # use hf_hub::{api::sync::Api, Repo, RepoType};
    /// # let model_id = "gpt2".to_string();
    /// let api = Api::new().unwrap();
    /// let api = api.repo(Repo::new(model_id, RepoType::Space));
    /// ```
    pub fn space(&self, model_id: String) -> ApiRepo {
        self.repo(Repo::new(model_id, RepoType::Space))
    }
}

/// Shorthand for accessing things within a particular repo
/// You can inspect repos with [`ApiRepo::info`]
/// or download files with [`ApiRepo::download`]
#[derive(Debug)]
pub struct ApiRepo {
    api: Api,
    repo: Repo,
}

impl ApiRepo {
    fn new(api: Api, repo: Repo) -> Self {
        Self { api, repo }
    }
}

#[cfg(feature = "native-tls")]
fn builder() -> Result<AgentBuilder, ApiError> {
    Ok(ureq::builder()
        .try_proxy_from_env(true)
        .tls_connector(std::sync::Arc::new(native_tls::TlsConnector::new()?)))
}

#[cfg(not(feature = "native-tls"))]
fn builder() -> Result<AgentBuilder, ApiError> {
    Ok(ureq::builder().try_proxy_from_env(true))
}

impl ApiRepo {
    /// Get the fully qualified URL of the remote filename
    /// ```
    /// # use hf_hub::api::sync::Api;
    /// let api = Api::new().unwrap();
    /// let url = api.model("gpt2".to_string()).url("model.safetensors");
    /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
    /// ```
    pub fn url(&self, filename: &str) -> String {
        let endpoint = &self.api.endpoint;
        let revision = &self.repo.url_revision();
        let repo_id = self.repo.url();
        format!("{endpoint}/{repo_id}/resolve/{revision}/{filename}")
    }

    /// This will attempt to fetch the file locally first, then [`Api.download`]
    /// if the file is not present.
    /// ```no_run
    /// use hf_hub::{api::sync::Api};
    /// let api = Api::new().unwrap();
    /// let local_filename = api.model("gpt2".to_string()).get("model.safetensors").unwrap();
    /// ```
    pub fn get(&self, filename: &str) -> Result<PathBuf, ApiError> {
        if let Some(path) = self.api.cache.repo(self.repo.clone()).get(filename) {
            Ok(path)
        } else {
            self.download(filename)
        }
    }

    /// This function is used to download a file with a custom progress function.
    /// It uses the [`Progress`] trait and can be used in more complex use
    /// cases like downloading and showing progress in a UI.
    /// ```no_run
    /// # use hf_hub::api::{sync::Api, Progress};
    /// struct MyProgress{
    ///     current: usize,
    ///     total: usize
    /// }
    ///
    /// impl Progress for MyProgress{
    ///     fn init(&mut self, size: usize, _filename: &str){
    ///         self.total = size;
    ///         self.current = 0;
    ///     }
    ///
    ///     fn update(&mut self, size: usize){
    ///         self.current += size;
    ///         println!("{}/{}", self.current, self.total)
    ///     }
    ///
    ///     fn finish(&mut self){
    ///         println!("Done !");
    ///     }
    /// }
    /// let api = Api::new().unwrap();
    /// let progress = MyProgress{current: 0, total: 0};
    /// let local_filename = api.model("gpt2".to_string()).download_with_progress("model.safetensors", progress).unwrap();
    /// ```
    pub fn download_with_progress<P: Progress>(
        &self,
        filename: &str,
        progress: P,
    ) -> Result<PathBuf, ApiError> {
        let url = self.url(filename);
        let metadata = self.api.metadata(&url)?;

        let blob_path = self
            .api
            .cache
            .repo(self.repo.clone())
            .blob_path(&metadata.etag);
        std::fs::create_dir_all(blob_path.parent().unwrap())?;

        let lock = lock_file(blob_path.clone()).unwrap();
        let mut tmp_path = blob_path.clone();
        tmp_path.set_extension(EXTENSION);
        let tmp_filename =
            self.api
                .download_tempfile(&url, metadata.size, progress, tmp_path, filename)?;

        std::fs::rename(tmp_filename, &blob_path)?;
        drop(lock);

        let mut pointer_path = self
            .api
            .cache
            .repo(self.repo.clone())
            .pointer_path(&metadata.commit_hash);
        pointer_path.push(filename);
        std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();

        symlink_or_rename(&blob_path, &pointer_path)?;
        self.api
            .cache
            .repo(self.repo.clone())
            .create_ref(&metadata.commit_hash)?;

        assert!(pointer_path.exists());

        Ok(pointer_path)
    }

    /// Downloads a remote file (if not already present) into the cache directory
    /// to be used locally.
    /// This function requires internet access to verify whether a new version of the file
    /// exists, even if a file is already on disk at the location.
    /// ```no_run
    /// # use hf_hub::api::sync::Api;
    /// let api = Api::new().unwrap();
    /// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").unwrap();
    /// ```
    pub fn download(&self, filename: &str) -> Result<PathBuf, ApiError> {
        if self.api.progress {
            self.download_with_progress(filename, ProgressBar::new(0))
        } else {
            self.download_with_progress(filename, ())
        }
    }

    /// Get information about the Repo
    /// ```
    /// use hf_hub::{api::sync::Api};
    /// let api = Api::new().unwrap();
    /// api.model("gpt2".to_string()).info();
    /// ```
    pub fn info(&self) -> Result<RepoInfo, ApiError> {
        Ok(self.info_request().call().map_err(Box::new)?.into_json()?)
    }

    /// Get the raw [`ureq::Request`] with the url and method already set
    /// ```
    /// # use hf_hub::api::sync::Api;
    /// let api = Api::new().unwrap();
    /// api.model("gpt2".to_owned())
    ///     .info_request()
    ///     .query("blobs", "true")
    ///     .call();
    /// ```
    pub fn info_request(&self) -> Request {
        let url = format!("{}/api/{}", self.api.endpoint, self.repo.api_url());
        self.api.client.get(&url)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::Siblings;
    use crate::assert_no_diff;
    use hex_literal::hex;
    use rand::{distributions::Alphanumeric, Rng};
    use serde_json::{json, Value};
    use sha2::{Digest, Sha256};
    use std::io::{Seek, SeekFrom, Write};
    use std::time::Duration;

    struct TempDir {
        path: PathBuf,
    }

    impl TempDir {
        pub fn new() -> Self {
            let s: String = rand::thread_rng()
                .sample_iter(&Alphanumeric)
                .take(7)
                .map(char::from)
                .collect();
            let mut path = std::env::temp_dir();
            path.push(s);
            std::fs::create_dir(&path).unwrap();
            Self { path }
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            std::fs::remove_dir_all(&self.path).unwrap()
        }
    }

    #[test]
    fn simple() {
        let tmp = TempDir::new();
        let api = ApiBuilder::new()
            .with_progress(false)
            .with_cache_dir(tmp.path.clone())
            .build()
            .unwrap();

        let model_id = "julien-c/dummy-unknown".to_string();
        let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
        assert!(downloaded_path.exists());
        let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
        assert_eq!(
            val[..],
            hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
        );

        // Make sure the file is now seeable without connection
        let cache_path = api
            .cache
            .repo(Repo::new(model_id, RepoType::Model))
            .get("config.json")
            .unwrap();
        assert_eq!(cache_path, downloaded_path);
    }

    #[test]
    fn resume() {
        let tmp = TempDir::new();
        let api = ApiBuilder::new()
            .with_progress(false)
            .with_cache_dir(tmp.path.clone())
            .build()
            .unwrap();

        let model_id = "julien-c/dummy-unknown".to_string();
        let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
        assert!(downloaded_path.exists());
        let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
        assert_eq!(
            val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") 911 | ); 912 | 913 | let blob = std::fs::canonicalize(&downloaded_path).unwrap(); 914 | let file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap(); 915 | let size = file.metadata().unwrap().len(); 916 | let truncate: f32 = rand::random(); 917 | let new_size = (size as f32 * truncate) as u64; 918 | file.set_len(new_size).unwrap(); 919 | let mut blob_part = blob.clone(); 920 | blob_part.set_extension("part"); 921 | std::fs::rename(blob, &blob_part).unwrap(); 922 | std::fs::remove_file(&downloaded_path).unwrap(); 923 | let content = std::fs::read(&*blob_part).unwrap(); 924 | assert_eq!(content.len() as u64, new_size); 925 | let val = Sha256::digest(content); 926 | // We modified the sha. 927 | assert!( 928 | val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") 929 | ); 930 | let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap(); 931 | let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap()); 932 | assert_eq!(downloaded_path, new_downloaded_path); 933 | assert_eq!( 934 | val[..], 935 | hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") 936 | ); 937 | 938 | // Here we prove the previous part was correctly resuming by purposefully corrupting the 939 | // file. 940 | let blob = std::fs::canonicalize(&downloaded_path).unwrap(); 941 | let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap(); 942 | let size = file.metadata().unwrap().len(); 943 | // Not random for consistent sha corruption 944 | let truncate: f32 = 0.5; 945 | let new_size = (size as f32 * truncate) as u64; 946 | // Truncating 947 | file.set_len(new_size).unwrap(); 948 | // Corrupting by changing a single byte. 949 | file.seek(SeekFrom::Start(new_size - 1)).unwrap(); 950 | file.write_all(&[0]).unwrap(); 951 | 952 | let mut blob_part = blob.clone(); 953 | blob_part.set_extension("part"); 954 | std::fs::rename(blob, &blob_part).unwrap(); 955 | std::fs::remove_file(&downloaded_path).unwrap(); 956 | let content = std::fs::read(&*blob_part).unwrap(); 957 | assert_eq!(content.len() as u64, new_size); 958 | let val = Sha256::digest(content); 959 | // We modified the sha. 960 | assert!( 961 | val[..] 
!= hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") 962 | ); 963 | let new_downloaded_path = api.model(model_id.clone()).download("config.json").unwrap(); 964 | let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap()); 965 | assert_eq!(downloaded_path, new_downloaded_path); 966 | println!("{new_downloaded_path:?}"); 967 | println!("Corrupted {val:#x}"); 968 | assert_eq!( 969 | val[..], 970 | // Corrupted sha 971 | hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5") 972 | ); 973 | } 974 | 975 | #[test] 976 | fn locking() { 977 | use std::sync::{Arc, Mutex}; 978 | let tmp = Arc::new(Mutex::new(TempDir::new())); 979 | 980 | let mut handles = vec![]; 981 | for _ in 0..5 { 982 | let tmp2 = tmp.clone(); 983 | let f = std::thread::spawn(move || { 984 | // 0..256ms sleep to randomize potential clashes 985 | std::thread::sleep(Duration::from_millis(rand::random::().into())); 986 | let api = ApiBuilder::new() 987 | .with_progress(false) 988 | .with_cache_dir(tmp2.lock().unwrap().path.clone()) 989 | .build() 990 | .unwrap(); 991 | 992 | let model_id = "julien-c/dummy-unknown".to_string(); 993 | api.model(model_id.clone()).download("config.json").unwrap() 994 | }); 995 | handles.push(f); 996 | } 997 | while let Some(handle) = handles.pop() { 998 | let downloaded_path = handle.join().unwrap(); 999 | assert!(downloaded_path.exists()); 1000 | let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); 1001 | assert_eq!( 1002 | val[..], 1003 | hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") 1004 | ); 1005 | } 1006 | } 1007 | 1008 | #[test] 1009 | fn simple_with_retries() { 1010 | let tmp = TempDir::new(); 1011 | let api = ApiBuilder::new() 1012 | .with_progress(false) 1013 | .with_cache_dir(tmp.path.clone()) 1014 | .with_retries(3) 1015 | .build() 1016 | .unwrap(); 1017 | 1018 | let model_id = "julien-c/dummy-unknown".to_string(); 1019 | let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap(); 1020 | assert!(downloaded_path.exists()); 1021 | let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); 1022 | assert_eq!( 1023 | val[..], 1024 | hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") 1025 | ); 1026 | 1027 | // Make sure the file is now seeable without connection 1028 | let cache_path = api 1029 | .cache 1030 | .repo(Repo::new(model_id, RepoType::Model)) 1031 | .get("config.json") 1032 | .unwrap(); 1033 | assert_eq!(cache_path, downloaded_path); 1034 | } 1035 | 1036 | #[test] 1037 | fn dataset() { 1038 | let tmp = TempDir::new(); 1039 | let api = ApiBuilder::new() 1040 | .with_progress(false) 1041 | .with_cache_dir(tmp.path.clone()) 1042 | .build() 1043 | .unwrap(); 1044 | let repo = Repo::with_revision( 1045 | "wikitext".to_string(), 1046 | RepoType::Dataset, 1047 | "refs/convert/parquet".to_string(), 1048 | ); 1049 | let downloaded_path = api 1050 | .repo(repo) 1051 | .download("wikitext-103-v1/test/0000.parquet") 1052 | .unwrap(); 1053 | assert!(downloaded_path.exists()); 1054 | let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); 1055 | assert_eq!( 1056 | val[..], 1057 | hex!("ABDFC9F83B1103B502924072460D4C92F277C9B49C313CEF3E48CFCF7428E125") 1058 | ); 1059 | } 1060 | 1061 | #[test] 1062 | fn models() { 1063 | let tmp = TempDir::new(); 1064 | let api = ApiBuilder::new() 1065 | .with_progress(false) 1066 | .with_cache_dir(tmp.path.clone()) 1067 | .build() 1068 | .unwrap(); 1069 | let repo = Repo::with_revision( 1070 | 
"BAAI/bGe-reRanker-Base".to_string(), 1071 | RepoType::Model, 1072 | "refs/pr/5".to_string(), 1073 | ); 1074 | let downloaded_path = api.repo(repo).download("tokenizer.json").unwrap(); 1075 | assert!(downloaded_path.exists()); 1076 | let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); 1077 | assert_eq!( 1078 | val[..], 1079 | hex!("9EB652AC4E40CC093272BBBE0F55D521CF67570060227109B5CDC20945A4489E") 1080 | ); 1081 | } 1082 | 1083 | #[test] 1084 | fn info() { 1085 | let tmp = TempDir::new(); 1086 | let api = ApiBuilder::new() 1087 | .with_progress(false) 1088 | .with_cache_dir(tmp.path.clone()) 1089 | .build() 1090 | .unwrap(); 1091 | let repo = Repo::with_revision( 1092 | "wikitext".to_string(), 1093 | RepoType::Dataset, 1094 | "refs/convert/parquet".to_string(), 1095 | ); 1096 | let model_info = api.repo(repo).info().unwrap(); 1097 | assert_eq!( 1098 | model_info, 1099 | RepoInfo { 1100 | siblings: vec![ 1101 | Siblings { 1102 | rfilename: ".gitattributes".to_string() 1103 | }, 1104 | Siblings { 1105 | rfilename: "wikitext-103-raw-v1/test/0000.parquet".to_string() 1106 | }, 1107 | Siblings { 1108 | rfilename: "wikitext-103-raw-v1/train/0000.parquet".to_string() 1109 | }, 1110 | Siblings { 1111 | rfilename: "wikitext-103-raw-v1/train/0001.parquet".to_string() 1112 | }, 1113 | Siblings { 1114 | rfilename: "wikitext-103-raw-v1/validation/0000.parquet".to_string() 1115 | }, 1116 | Siblings { 1117 | rfilename: "wikitext-103-v1/test/0000.parquet".to_string() 1118 | }, 1119 | Siblings { 1120 | rfilename: "wikitext-103-v1/train/0000.parquet".to_string() 1121 | }, 1122 | Siblings { 1123 | rfilename: "wikitext-103-v1/train/0001.parquet".to_string() 1124 | }, 1125 | Siblings { 1126 | rfilename: "wikitext-103-v1/validation/0000.parquet".to_string() 1127 | }, 1128 | Siblings { 1129 | rfilename: "wikitext-2-raw-v1/test/0000.parquet".to_string() 1130 | }, 1131 | Siblings { 1132 | rfilename: "wikitext-2-raw-v1/train/0000.parquet".to_string() 1133 | }, 1134 | Siblings { 1135 | rfilename: "wikitext-2-raw-v1/validation/0000.parquet".to_string() 1136 | }, 1137 | Siblings { 1138 | rfilename: "wikitext-2-v1/test/0000.parquet".to_string() 1139 | }, 1140 | Siblings { 1141 | rfilename: "wikitext-2-v1/train/0000.parquet".to_string() 1142 | }, 1143 | Siblings { 1144 | rfilename: "wikitext-2-v1/validation/0000.parquet".to_string() 1145 | } 1146 | ], 1147 | sha: "3f68cd45302c7b4b532d933e71d9e6e54b1c7d5e".to_string() 1148 | } 1149 | ); 1150 | } 1151 | 1152 | #[test] 1153 | fn detailed_info() { 1154 | let tmp = TempDir::new(); 1155 | let api = ApiBuilder::new() 1156 | .with_progress(false) 1157 | .with_token(None) 1158 | .with_cache_dir(tmp.path.clone()) 1159 | .build() 1160 | .unwrap(); 1161 | let repo = Repo::with_revision( 1162 | "mcpotato/42-eicar-street".to_string(), 1163 | RepoType::Model, 1164 | "8b3861f6931c4026b0cd22b38dbc09e7668983ac".to_string(), 1165 | ); 1166 | let blobs_info: Value = api 1167 | .repo(repo) 1168 | .info_request() 1169 | .query("blobs", "true") 1170 | .call() 1171 | .unwrap() 1172 | .into_json() 1173 | .unwrap(); 1174 | assert_no_diff!( 1175 | blobs_info, 1176 | json!({ 1177 | "_id": "621ffdc136468d709f17ddb4", 1178 | "author": "mcpotato", 1179 | "createdAt": "2022-03-02T23:29:05.000Z", 1180 | "disabled": false, 1181 | "downloads": 0, 1182 | "gated": false, 1183 | "id": "mcpotato/42-eicar-street", 1184 | "lastModified": "2022-11-30T19:54:16.000Z", 1185 | "likes": 2, 1186 | "modelId": "mcpotato/42-eicar-street", 1187 | "private": false, 1188 | "sha": 
"8b3861f6931c4026b0cd22b38dbc09e7668983ac", 1189 | "siblings": [ 1190 | { 1191 | "blobId": "6d34772f5ca361021038b404fb913ec8dc0b1a5a", 1192 | "rfilename": ".gitattributes", 1193 | "size": 1175 1194 | }, 1195 | { 1196 | "blobId": "be98037f7c542112c15a1d2fc7e2a2427e42cb50", 1197 | "rfilename": "build_pickles.py", 1198 | "size": 304 1199 | }, 1200 | { 1201 | "blobId": "8acd02161fff53f9df9597e377e22b04bc34feff", 1202 | "rfilename": "danger.dat", 1203 | "size": 66 1204 | }, 1205 | { 1206 | "blobId": "86b812515e075a1ae216e1239e615a1d9e0b316e", 1207 | "rfilename": "eicar_test_file", 1208 | "size": 70 1209 | }, 1210 | { 1211 | "blobId": "86b812515e075a1ae216e1239e615a1d9e0b316e", 1212 | "rfilename": "eicar_test_file_bis", 1213 | "size":70 1214 | }, 1215 | { 1216 | "blobId": "cd1c6d8bde5006076655711a49feae66f07d707e", 1217 | "lfs": { 1218 | "pointerSize": 127, 1219 | "sha256": "f9343d7d7ec5c3d8bcced056c438fc9f1d3819e9ca3d42418a40857050e10e20", 1220 | "size": 22 1221 | }, 1222 | "rfilename": "pytorch_model.bin", 1223 | "size": 22 1224 | }, 1225 | { 1226 | "blobId": "8ab39654695136173fee29cba0193f679dfbd652", 1227 | "rfilename": "supposedly_safe.pkl", 1228 | "size": 31 1229 | } 1230 | ], 1231 | "spaces": [], 1232 | "tags": ["pytorch", "region:us"], 1233 | "usedStorage": 22 1234 | }) 1235 | ); 1236 | } 1237 | 1238 | #[test] 1239 | fn endpoint() { 1240 | let api = ApiBuilder::new().build().unwrap(); 1241 | assert_eq!(api.endpoint, "https://huggingface.co".to_string()); 1242 | let fake_endpoint = "https://fake_endpoint.com".to_string(); 1243 | let api = ApiBuilder::new() 1244 | .with_endpoint(fake_endpoint.clone()) 1245 | .build() 1246 | .unwrap(); 1247 | assert_eq!(api.endpoint, fake_endpoint); 1248 | } 1249 | 1250 | #[test] 1251 | fn headers_with_token() { 1252 | let api = ApiBuilder::new() 1253 | .with_token(Some("token".to_string())) 1254 | .build() 1255 | .unwrap(); 1256 | let headers = api.client.headers; 1257 | assert_eq!( 1258 | headers.get("Authorization"), 1259 | Some(&"Bearer token".to_string()) 1260 | ); 1261 | } 1262 | 1263 | #[test] 1264 | fn headers_default() { 1265 | let api = ApiBuilder::new().build().unwrap(); 1266 | let headers = api.client.headers; 1267 | assert_eq!( 1268 | headers.get(USER_AGENT), 1269 | Some(&"unknown/None; hf-hub/0.4.2; rust/unknown".to_string()) 1270 | ); 1271 | } 1272 | 1273 | #[test] 1274 | fn headers_custom() { 1275 | let api = ApiBuilder::new() 1276 | .with_user_agent("origin", "custom") 1277 | .build() 1278 | .unwrap(); 1279 | let headers = api.client.headers; 1280 | assert_eq!( 1281 | headers.get(USER_AGENT), 1282 | Some(&"unknown/None; hf-hub/0.4.2; rust/unknown; origin/custom".to_string()) 1283 | ); 1284 | } 1285 | 1286 | // #[test] 1287 | // fn real() { 1288 | // let api = Api::new().unwrap(); 1289 | // let repo = api.model("bert-base-uncased".to_string()); 1290 | // let weights = repo.get("model.safetensors").unwrap(); 1291 | // let val = Sha256::digest(std::fs::read(&*weights).unwrap()); 1292 | // assert_eq!( 1293 | // val[..], 1294 | // hex!("68d45e234eb4a928074dfd868cead0219ab85354cc53d20e772753c6bb9169d3") 1295 | // ); 1296 | // } 1297 | } 1298 | -------------------------------------------------------------------------------- /src/api/tokio.rs: -------------------------------------------------------------------------------- 1 | use super::Progress as SyncProgress; 2 | use super::{RepoInfo, HF_ENDPOINT}; 3 | use crate::{Cache, Repo, RepoType}; 4 | use futures::stream::FuturesUnordered; 5 | use futures::StreamExt; 6 | use indicatif::ProgressBar; 7 | 
use rand::Rng;
8 | use reqwest::{
9 | header::{
10 | HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION,
11 | CONTENT_RANGE, LOCATION, RANGE, USER_AGENT,
12 | },
13 | redirect::Policy,
14 | Client, Error as ReqwestError, RequestBuilder,
15 | };
16 | use std::cmp::Reverse;
17 | use std::collections::BinaryHeap;
18 | use std::num::ParseIntError;
19 | use std::path::{Component, Path, PathBuf};
20 | use std::sync::Arc;
21 | use thiserror::Error;
22 | use tokio::io::AsyncReadExt;
23 | use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
24 | use tokio::sync::{AcquireError, Semaphore, TryAcquireError};
25 | use tokio::task::JoinError;
26 |
27 | /// Current version (used in user-agent)
28 | const VERSION: &str = env!("CARGO_PKG_VERSION");
29 | /// Current name (used in user-agent)
30 | const NAME: &str = env!("CARGO_PKG_NAME");
31 |
32 | const EXTENSION: &str = "sync.part";
33 |
34 | /// This trait is used by users of the library
35 | /// to implement custom behavior during file downloads.
36 | pub trait Progress {
37 | /// Called at the start of the download.
38 | /// The size is the total size in bytes of the file.
39 | fn init(&mut self, size: usize, filename: &str)
40 | -> impl std::future::Future<Output = ()> + Send;
41 | /// This function is called whenever `size` bytes have been
42 | /// downloaded in the temporary file
43 | fn update(&mut self, size: usize) -> impl std::future::Future<Output = ()> + Send;
44 | /// This is called at the end of the download
45 | fn finish(&mut self) -> impl std::future::Future<Output = ()> + Send;
46 | }
47 |
48 | impl Progress for ProgressBar {
49 | async fn init(&mut self, size: usize, filename: &str) {
50 | <ProgressBar as SyncProgress>::init(self, size, filename);
51 | }
52 | async fn finish(&mut self) {
53 | <ProgressBar as SyncProgress>::finish(self);
54 | }
55 | async fn update(&mut self, size: usize) {
56 | <ProgressBar as SyncProgress>::update(self, size);
57 | }
58 | }
59 |
60 | impl Progress for () {
61 | async fn init(&mut self, _size: usize, _filename: &str) {}
62 | async fn finish(&mut self) {}
63 | async fn update(&mut self, _size: usize) {}
64 | }
65 |
66 | struct Handle {
67 | file: tokio::fs::File,
68 | }
69 |
70 | impl Drop for Handle {
71 | fn drop(&mut self) {
72 | unlock(&self.file);
73 | }
74 | }
75 |
76 | async fn lock_file(mut path: PathBuf) -> Result<Handle, ApiError> {
77 | path.set_extension("lock");
78 |
79 | let file = tokio::fs::File::create(path.clone()).await?;
80 | let mut res = lock(&file);
81 | for _ in 0..5 {
82 | if res == 0 {
83 | break;
84 | }
85 | tokio::time::sleep(std::time::Duration::from_secs(1)).await;
86 | res = lock(&file);
87 | }
88 | if res != 0 {
89 | Err(ApiError::LockAcquisition(path))
90 | } else {
91 | Ok(Handle { file })
92 | }
93 | }
94 |
95 | #[cfg(target_family = "unix")]
96 | mod unix {
97 | use std::os::fd::AsRawFd;
98 |
99 | pub(crate) fn lock(file: &tokio::fs::File) -> i32 {
100 | unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) }
101 | }
102 | pub(crate) fn unlock(file: &tokio::fs::File) -> i32 {
103 | unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) }
104 | }
105 | }
106 | #[cfg(target_family = "unix")]
107 | use unix::{lock, unlock};
108 |
109 | #[cfg(target_family = "windows")]
110 | mod windows {
111 | use std::os::windows::io::AsRawHandle;
112 | use windows_sys::Win32::Foundation::HANDLE;
113 | use windows_sys::Win32::Storage::FileSystem::{
114 | LockFileEx, UnlockFile, LOCKFILE_EXCLUSIVE_LOCK, LOCKFILE_FAIL_IMMEDIATELY,
115 | };
116 |
117 | pub(crate) fn lock(file: &tokio::fs::File) -> i32 {
118 | unsafe {
119 | let mut overlapped = std::mem::zeroed();
120 | let flags = LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY;
121 | let res = LockFileEx(
122 | file.as_raw_handle() as HANDLE,
123 | flags,
124 | 0,
125 | !0,
126 | !0,
127 | &mut overlapped,
128 | );
129 | 1 - res
130 | }
131 | }
132 | pub(crate) fn unlock(file: &tokio::fs::File) -> i32 {
133 | unsafe { UnlockFile(file.as_raw_handle() as HANDLE, 0, 0, !0, !0) }
134 | }
135 | }
136 | #[cfg(target_family = "windows")]
137 | use windows::{lock, unlock};
138 |
139 | #[cfg(not(any(target_family = "unix", target_family = "windows")))]
140 | mod other {
141 | pub(crate) fn lock(file: &tokio::fs::File) -> i32 {
142 | 0
143 | }
144 | pub(crate) fn unlock(file: &tokio::fs::File) -> i32 {
145 | 0
146 | }
147 | }
148 | #[cfg(not(any(target_family = "unix", target_family = "windows")))]
149 | use other::{lock, unlock};
150 |
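// A sketch of how these locking primitives are used end to end (see
// `download_with_progress` below): `lock_file` creates "<blob>.lock" next to
// the target blob, takes an exclusive non-blocking lock on it, retrying once
// per second for up to five seconds, and the lock is released when the
// returned `Handle` is dropped.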
151 | #[derive(Debug, Error)]
152 | /// All errors the API can throw
153 | pub enum ApiError {
154 | /// Api expects a certain header to be present in the results to derive some information
155 | #[error("Header {0} is missing")]
156 | MissingHeader(HeaderName),
157 |
158 | /// The header exists, but the value does not conform to what the Api expects.
159 | #[error("Header {0} is invalid")]
160 | InvalidHeader(HeaderName),
161 |
162 | /// The value cannot be used as a header during request header construction
163 | #[error("Invalid header value {0}")]
164 | InvalidHeaderValue(#[from] InvalidHeaderValue),
165 |
166 | /// The header value is not valid utf-8
167 | #[error("header value is not a string")]
168 | ToStr(#[from] ToStrError),
169 |
170 | /// Error in the request
171 | #[error("request error: {0}")]
172 | RequestError(#[from] ReqwestError),
173 |
174 | /// Error parsing some range value
175 | #[error("Cannot parse int")]
176 | ParseIntError(#[from] ParseIntError),
177 |
178 | /// I/O Error
179 | #[error("I/O error {0}")]
180 | IoError(#[from] std::io::Error),
181 |
182 | /// We tried to download a chunk too many times
183 | #[error("Too many retries: {0}")]
184 | TooManyRetries(Box<ApiError>),
185 |
186 | /// Semaphore cannot be acquired
187 | #[error("Try acquire: {0}")]
188 | TryAcquireError(#[from] TryAcquireError),
189 |
190 | /// Semaphore cannot be acquired
191 | #[error("Acquire: {0}")]
192 | AcquireError(#[from] AcquireError),
193 | // /// Invalid response
194 | // #[error("Invalid Response: {0:?}")]
195 | // InvalidResponse(Response),
196 | /// Join failed
197 | #[error("Join: {0}")]
198 | Join(#[from] JoinError),
199 |
200 | /// We failed to acquire lock for file `f`, meaning
201 | /// someone else is writing/downloading said file
202 | #[error("Lock acquisition failed: {0}")]
203 | LockAcquisition(PathBuf),
204 | }
205 |
206 | /// Helper to create [`Api`] with all the options.
207 | #[derive(Debug)]
208 | pub struct ApiBuilder {
209 | endpoint: String,
210 | cache: Cache,
211 | token: Option<String>,
212 | max_files: usize,
213 | chunk_size: Option<usize>,
214 | parallel_failures: usize,
215 | max_retries: usize,
216 | progress: bool,
217 | user_agent: Vec<(String, String)>,
218 | }
219 |
220 | impl Default for ApiBuilder {
221 | fn default() -> Self {
222 | Self::new()
223 | }
224 | }
225 |
226 | impl ApiBuilder {
227 | /// Default api builder
228 | /// ```
229 | /// use hf_hub::api::tokio::ApiBuilder;
230 | /// let api = ApiBuilder::new().build().unwrap();
231 | /// ```
232 | pub fn new() -> Self {
233 | let cache = Cache::default();
234 | Self::from_cache(cache)
235 | }
236 |
237 | /// Creates API with values potentially from environment variables.
238 | /// HF_HOME decides the location of the cache folder.
239 | /// HF_ENDPOINT modifies the URL for the huggingface location
240 | /// to download files from.
241 | /// ```
242 | /// use hf_hub::api::tokio::ApiBuilder;
243 | /// let api = ApiBuilder::from_env().build().unwrap();
244 | /// ```
245 | pub fn from_env() -> Self {
246 | let cache = Cache::from_env();
247 | let mut builder = Self::from_cache(cache);
248 | if let Ok(endpoint) = std::env::var(HF_ENDPOINT) {
249 | builder = builder.with_endpoint(endpoint);
250 | }
251 | builder
252 | }
253 |
254 | /// High CPU download
255 | ///
256 | /// This may cause issues on regular desktops as it will saturate
257 | /// CPUs by multiplexing the downloads.
258 | /// However, on high-CPU cloud machines, this may help
259 | /// saturate the bandwidth (>500MB/s) better.
260 | /// ```
261 | /// use hf_hub::api::tokio::ApiBuilder;
262 | /// let api = ApiBuilder::new().high().build().unwrap();
263 | /// ```
264 | pub fn high(self) -> Self {
265 | self.with_max_files(num_cpus::get())
266 | .with_chunk_size(Some(10_000_000))
267 | }
268 |
269 | /// From a given cache
270 | /// ```
271 | /// use hf_hub::{api::tokio::ApiBuilder, Cache};
272 | /// let path = std::path::PathBuf::from("/tmp");
273 | /// let cache = Cache::new(path);
274 | /// let api = ApiBuilder::from_cache(cache).build().unwrap();
275 | /// ```
276 | pub fn from_cache(cache: Cache) -> Self {
277 | let token = cache.token();
278 |
279 | let progress = true;
280 |
281 | let user_agent = vec![
282 | ("unknown".to_string(), "None".to_string()),
283 | (NAME.to_string(), VERSION.to_string()),
284 | ("rust".to_string(), "unknown".to_string()),
285 | ];
286 |
287 | Self {
288 | endpoint: "https://huggingface.co".to_string(),
289 | cache,
290 | token,
291 | max_files: 1,
292 | // We need to have some chunk size for things to be able to resume.
293 | chunk_size: Some(10_000_000),
294 | parallel_failures: 0,
295 | max_retries: 0,
296 | progress,
297 | user_agent,
298 | }
299 | }
300 |
301 | /// Whether to show a progressbar
302 | pub fn with_progress(mut self, progress: bool) -> Self {
303 | self.progress = progress;
304 | self
305 | }
306 |
307 | /// Changes the endpoint of the API. Default is `https://huggingface.co`.
308 | pub fn with_endpoint(mut self, endpoint: String) -> Self {
309 | self.endpoint = endpoint;
310 | self
311 | }
312 |
313 | /// Changes the location of the cache directory. Default is `~/.cache/huggingface/`.
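/// A minimal usage sketch (the directory below is a made-up example path):
/// ```
/// use hf_hub::api::tokio::ApiBuilder;
/// let api = ApiBuilder::new()
///     .with_cache_dir(std::path::PathBuf::from("/tmp/my-hf-cache"))
///     .build()
///     .unwrap();
/// ```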
314 | pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
315 | self.cache = Cache::new(cache_dir);
316 | self
317 | }
318 |
319 | /// Sets the token to be used in the API
320 | pub fn with_token(mut self, token: Option<String>) -> Self {
321 | self.token = token;
322 | self
323 | }
324 |
325 | /// Sets the number of open files
326 | pub fn with_max_files(mut self, max_files: usize) -> Self {
327 | self.max_files = max_files;
328 | self
329 | }
330 |
331 | /// Sets the size of each chunk
332 | pub fn with_chunk_size(mut self, chunk_size: Option<usize>) -> Self {
333 | self.chunk_size = chunk_size;
334 | self
335 | }
336 |
337 | /// Adds custom fields to the user-agent header
338 | pub fn with_user_agent(mut self, key: &str, value: &str) -> Self {
339 | self.user_agent.push((key.to_string(), value.to_string()));
340 | self
341 | }
342 |
343 | fn build_headers(&self) -> Result<HeaderMap, ApiError> {
344 | let mut headers = HeaderMap::new();
345 | let user_agent = self
346 | .user_agent
347 | .iter()
348 | .map(|(key, value)| format!("{key}/{value}"))
349 | .collect::<Vec<String>>()
350 | .join("; ");
351 | headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
352 | if let Some(token) = &self.token {
353 | headers.insert(
354 | AUTHORIZATION,
355 | HeaderValue::from_str(&format!("Bearer {token}"))?,
356 | );
357 | }
358 | Ok(headers)
359 | }
360 |
361 | /// Consumes the builder and builds the final [`Api`]
362 | pub fn build(self) -> Result<Api, ApiError> {
363 | let headers = self.build_headers()?;
364 | let client = Client::builder().default_headers(headers.clone()).build()?;
365 |
366 | // Policy: only follow relative redirects
367 | // See: https://github.com/huggingface/huggingface_hub/blob/9c6af39cdce45b570f0b7f8fad2b311c96019804/src/huggingface_hub/file_download.py#L411
368 | let relative_redirect_policy = Policy::custom(|attempt| {
369 | // Follow redirects up to a maximum of 10.
370 | if attempt.previous().len() > 10 {
371 | return attempt.error("too many redirects");
372 | }
373 |
374 | if let Some(last) = attempt.previous().last() {
375 | // If the url is not relative
376 | if last.make_relative(attempt.url()).is_none() {
377 | return attempt.stop();
378 | }
379 | }
380 |
381 | // Follow redirect
382 | attempt.follow()
383 | });
384 |
385 | let relative_redirect_client = Client::builder()
386 | .redirect(relative_redirect_policy)
387 | .default_headers(headers)
388 | .build()?;
389 | Ok(Api {
390 | endpoint: self.endpoint,
391 | cache: self.cache,
392 | client,
393 | relative_redirect_client,
394 | max_files: self.max_files,
395 | chunk_size: self.chunk_size,
396 | parallel_failures: self.parallel_failures,
397 | max_retries: self.max_retries,
398 | progress: self.progress,
399 | })
400 | }
401 | }
402 |
403 | #[derive(Debug)]
404 | struct Metadata {
405 | commit_hash: String,
406 | etag: String,
407 | size: usize,
408 | }
409 |
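// How this metadata maps onto the on-disk cache (a sketch of the layout this
// crate shares with `huggingface_hub`; path segments are illustrative):
//
//   <cache>/<repo-folder>/blobs/<etag>                        the file contents
//   <cache>/<repo-folder>/snapshots/<commit_hash>/<filename>  symlink into blobs/
//   <cache>/<repo-folder>/refs/<revision>                     records <commit_hash>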
410 | /// The actual Api used to interact with the hub.
411 | /// Use any repo with [`Api::repo`]
412 | #[derive(Clone, Debug)]
413 | pub struct Api {
414 | endpoint: String,
415 | cache: Cache,
416 | client: Client,
417 | relative_redirect_client: Client,
418 | max_files: usize,
419 | chunk_size: Option<usize>,
420 | parallel_failures: usize,
421 | max_retries: usize,
422 | progress: bool,
423 | }
424 |
425 | fn make_relative(src: &Path, dst: &Path) -> PathBuf {
426 | let path = src;
427 | let base = dst;
428 |
429 | assert_eq!(
430 | path.is_absolute(),
431 | base.is_absolute(),
432 | "This function is made to look at absolute paths only"
433 | );
434 | let mut ita = path.components();
435 | let mut itb = base.components();
436 |
437 | loop {
438 | match (ita.next(), itb.next()) {
439 | (Some(a), Some(b)) if a == b => (),
440 | (some_a, _) => {
441 | // Ignoring b, because 1 component is the filename
442 | // for which we don't need to go back up for relative
443 | // filename to work.
444 | let mut new_path = PathBuf::new();
445 | for _ in itb {
446 | new_path.push(Component::ParentDir);
447 | }
448 | if let Some(a) = some_a {
449 | new_path.push(a);
450 | for comp in ita {
451 | new_path.push(comp);
452 | }
453 | }
454 | return new_path;
455 | }
456 | }
457 | }
458 | }
459 |
460 | fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
461 | if dst.exists() {
462 | return Ok(());
463 | }
464 |
465 | let rel_src = make_relative(src, dst);
466 | #[cfg(target_os = "windows")]
467 | {
468 | if std::os::windows::fs::symlink_file(rel_src, dst).is_err() {
469 | std::fs::rename(src, dst)?;
470 | }
471 | }
472 |
473 | #[cfg(target_family = "unix")]
474 | std::os::unix::fs::symlink(rel_src, dst)?;
475 |
476 | Ok(())
477 | }
478 |
479 | fn jitter() -> usize {
480 | rand::thread_rng().gen_range(0..=500)
481 | }
482 |
483 | fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
484 | (base_wait_time + n.pow(2) + jitter()).min(max)
485 | }
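// A worked example of the backoff above with the values download_tempfile
// passes in (base_wait_time = 300, max = 10_000, jitter drawn from 0..=500):
// attempt 0 waits 300-800ms, attempt 10 waits 400-900ms, attempt 50 waits
// 2800-3300ms, and from roughly attempt 99 the wait is clamped to the 10s
// ceiling.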
486 |
487 | impl Api {
488 | /// Creates a default Api, for Api options see [`ApiBuilder`]
489 | pub fn new() -> Result<Self, ApiError> {
490 | ApiBuilder::new().build()
491 | }
492 |
493 | /// Get the underlying api client.
494 | /// Allows for lower-level access.
495 | pub fn client(&self) -> &Client {
496 | &self.client
497 | }
498 |
499 | async fn metadata(&self, url: &str) -> Result<Metadata, ApiError> {
500 | let response = self
501 | .relative_redirect_client
502 | .get(url)
503 | .header(RANGE, "bytes=0-0")
504 | .send()
505 | .await?;
506 | let response = response.error_for_status()?;
507 | let headers = response.headers();
508 | let header_commit = HeaderName::from_static("x-repo-commit");
509 | let header_linked_etag = HeaderName::from_static("x-linked-etag");
510 | let header_etag = HeaderName::from_static("etag");
511 |
512 | let etag = match headers.get(&header_linked_etag) {
513 | Some(etag) => etag,
514 | None => headers
515 | .get(&header_etag)
516 | .ok_or(ApiError::MissingHeader(header_etag))?,
517 | };
518 | // Cleaning extra quotes
519 | let etag = etag.to_str()?.to_string().replace('"', "");
520 | let commit_hash = headers
521 | .get(&header_commit)
522 | .ok_or(ApiError::MissingHeader(header_commit))?
523 | .to_str()?
524 | .to_string();
525 |
526 | // The response was most likely redirected to S3, which will
527 | // know about the size of the file
528 | let response = if response.status().is_redirection() {
529 | self.client
530 | .get(headers.get(LOCATION).unwrap().to_str()?.to_string())
531 | .header(RANGE, "bytes=0-0")
532 | .send()
533 | .await?
534 | } else {
535 | response
536 | };
537 | let headers = response.headers();
538 | let content_range = headers
539 | .get(CONTENT_RANGE)
540 | .ok_or(ApiError::MissingHeader(CONTENT_RANGE))?
541 | .to_str()?;
542 |
543 | let size = content_range
544 | .split('/')
545 | .last()
546 | .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))?
547 | .parse()?;
548 | Ok(Metadata {
549 | commit_hash,
550 | etag,
551 | size,
552 | })
553 | }
554 |
555 | /// Creates a new handle [`ApiRepo`] which contains operations
556 | /// on a particular [`Repo`]
557 | pub fn repo(&self, repo: Repo) -> ApiRepo {
558 | ApiRepo::new(self.clone(), repo)
559 | }
560 |
561 | /// Simple wrapper over
562 | /// ```
563 | /// # use hf_hub::{api::tokio::Api, Repo, RepoType};
564 | /// # let model_id = "gpt2".to_string();
565 | /// let api = Api::new().unwrap();
566 | /// let api = api.repo(Repo::new(model_id, RepoType::Model));
567 | /// ```
568 | pub fn model(&self, model_id: String) -> ApiRepo {
569 | self.repo(Repo::new(model_id, RepoType::Model))
570 | }
571 |
572 | /// Simple wrapper over
573 | /// ```
574 | /// # use hf_hub::{api::tokio::Api, Repo, RepoType};
575 | /// # let model_id = "gpt2".to_string();
576 | /// let api = Api::new().unwrap();
577 | /// let api = api.repo(Repo::new(model_id, RepoType::Dataset));
578 | /// ```
579 | pub fn dataset(&self, model_id: String) -> ApiRepo {
580 | self.repo(Repo::new(model_id, RepoType::Dataset))
581 | }
582 |
583 | /// Simple wrapper over
584 | /// ```
585 | /// # use hf_hub::{api::tokio::Api, Repo, RepoType};
586 | /// # let model_id = "gpt2".to_string();
587 | /// let api = Api::new().unwrap();
588 | /// let api = api.repo(Repo::new(model_id, RepoType::Space));
589 | /// ```
590 | pub fn space(&self, model_id: String) -> ApiRepo {
591 | self.repo(Repo::new(model_id, RepoType::Space))
592 | }
593 | }
594 |
595 | /// Shorthand for accessing things within a particular repo
596 | /// You can inspect repos with [`ApiRepo::info`]
597 | /// or download files with [`ApiRepo::download`]
598 | #[derive(Debug)]
599 | pub struct ApiRepo {
600 | api: Api,
601 | repo: Repo,
602 | }
603 |
604 | impl ApiRepo {
605 | fn new(api: Api, repo: Repo) -> Self {
606 | Self { api, repo }
607 | }
608 | }
609 |
610 | impl ApiRepo {
611 | /// Get the fully qualified URL of the remote filename
612 | /// ```
613 | /// # use hf_hub::api::tokio::Api;
614 | /// let api = Api::new().unwrap();
615 | /// let url = api.model("gpt2".to_string()).url("model.safetensors");
616 | /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
617 | /// ```
618 | pub fn url(&self, filename: &str) -> String {
619 | let endpoint = &self.api.endpoint;
620 | let revision = &self.repo.url_revision();
621 | let repo_id = self.repo.url();
622 | format!("{endpoint}/{repo_id}/resolve/{revision}/{filename}")
623 | }
624 |
625 | async fn download_tempfile<P: Progress + Clone + Send + Sync + 'static>(
626 | &self,
627 | url: &str,
628 | length: usize,
629 | filename: PathBuf,
630 | mut progressbar: P,
631 | ) -> Result<PathBuf, ApiError> {
632 | let semaphore = Arc::new(Semaphore::new(self.api.max_files));
633 | let parallel_failures_semaphore = Arc::new(Semaphore::new(self.api.parallel_failures));
634 |
635 | // Create the file and set everything properly
636 | const N_BYTES: usize = size_of::<u64>();
637 |
638 | let start = match tokio::fs::OpenOptions::new()
639 | .read(true)
640 | .open(&filename)
641 | .await
642 | {
643 | Ok(mut f) => {
644 | let len = f.metadata().await?.len();
645 | if len == (length + N_BYTES) as u64 {
646 | f.seek(SeekFrom::Start(length as
u64)).await?; 647 | let mut buf = [0u8; N_BYTES]; 648 | let n = f.read(buf.as_mut_slice()).await?; 649 | if n == N_BYTES { 650 | let committed = u64::from_le_bytes(buf); 651 | committed as usize 652 | } else { 653 | 0 654 | } 655 | } else { 656 | 0 657 | } 658 | } 659 | Err(_err) => { 660 | tokio::fs::File::create(&filename) 661 | .await? 662 | .set_len((length + N_BYTES) as u64) 663 | .await?; 664 | 0 665 | } 666 | }; 667 | progressbar.update(start).await; 668 | 669 | let chunk_size = self.api.chunk_size.unwrap_or(length); 670 | let n_chunks = length / chunk_size; 671 | let mut handles = Vec::with_capacity(n_chunks); 672 | for start in (start..length).step_by(chunk_size) { 673 | let url = url.to_string(); 674 | let filename = filename.clone(); 675 | let client = self.api.client.clone(); 676 | 677 | let stop = std::cmp::min(start + chunk_size - 1, length); 678 | let permit = semaphore.clone(); 679 | let parallel_failures = self.api.parallel_failures; 680 | let max_retries = self.api.max_retries; 681 | let parallel_failures_semaphore = parallel_failures_semaphore.clone(); 682 | let progress = progressbar.clone(); 683 | handles.push(tokio::spawn(async move { 684 | let permit = permit.acquire_owned().await?; 685 | let mut chunk = 686 | Self::download_chunk(&client, &url, &filename, start, stop, progress.clone()) 687 | .await; 688 | let mut i = 0; 689 | if parallel_failures > 0 { 690 | while let Err(dlerr) = chunk { 691 | let parallel_failure_permit = 692 | parallel_failures_semaphore.clone().try_acquire_owned()?; 693 | 694 | let wait_time = exponential_backoff(300, i, 10_000); 695 | tokio::time::sleep(tokio::time::Duration::from_millis(wait_time as u64)) 696 | .await; 697 | 698 | chunk = Self::download_chunk( 699 | &client, 700 | &url, 701 | &filename, 702 | start, 703 | stop, 704 | progress.clone(), 705 | ) 706 | .await; 707 | i += 1; 708 | if i > max_retries { 709 | return Err(ApiError::TooManyRetries(dlerr.into())); 710 | } 711 | drop(parallel_failure_permit); 712 | } 713 | } 714 | drop(permit); 715 | chunk 716 | })); 717 | } 718 | 719 | let mut futures: FuturesUnordered<_> = handles.into_iter().collect(); 720 | let mut temporaries = BinaryHeap::new(); 721 | let mut committed: u64 = start as u64; 722 | while let Some(chunk) = futures.next().await { 723 | let chunk = chunk?; 724 | let (start, stop) = chunk?; 725 | temporaries.push(Reverse((start, stop))); 726 | 727 | let mut modified = false; 728 | while let Some(Reverse((min, max))) = temporaries.pop() { 729 | if min as u64 == committed { 730 | committed = max as u64 + 1; 731 | modified = true; 732 | } else { 733 | temporaries.push(Reverse((min, max))); 734 | break; 735 | } 736 | } 737 | if modified { 738 | let mut file = tokio::fs::OpenOptions::new() 739 | .write(true) 740 | .open(&filename) 741 | .await?; 742 | file.seek(SeekFrom::Start(length as u64)).await?; 743 | file.write_all(&committed.to_le_bytes()).await?; 744 | file.flush().await?; 745 | } 746 | } 747 | let mut f = tokio::fs::OpenOptions::new() 748 | .write(true) 749 | .open(&filename) 750 | .await?; 751 | f.set_len(length as u64).await?; 752 | // XXX Extremely important and not obvious. 753 | // Tokio::fs doesn't guarantee data is written at the end of `.await` 754 | // boundaries. Even though we await the `set_len` it may not have been 755 | // committed to disk, leading to invalid rename. 
// Forcing a flush forces the data (here the truncation) to be committed to disk
757 | f.flush().await?;
758 |
759 | progressbar.finish().await;
760 | Ok(filename)
761 | }
762 |
763 | async fn download_chunk<P>(

764 | client: &reqwest::Client,
765 | url: &str,
766 | filename: &PathBuf,
767 | start: usize,
768 | stop: usize,
769 | mut progress: P,
770 | ) -> Result<(usize, usize), ApiError>
771 | where
772 | P: Progress,
773 | {
774 | // Process each socket concurrently.
775 | let range = format!("bytes={start}-{stop}");
776 | let response = client
777 | .get(url)
778 | .header(RANGE, range)
779 | .send()
780 | .await?
781 | .error_for_status()?;
782 | let mut byte_stream = response.bytes_stream();
783 | let mut buf: Vec<u8> = Vec::with_capacity(stop - start);
784 | while let Some(next) = byte_stream.next().await {
785 | let next = next?;
786 | buf.extend(&next);
787 | progress.update(next.len()).await;
788 | }
789 | let mut file = tokio::fs::OpenOptions::new()
790 | .write(true)
791 | .open(filename)
792 | .await?;
793 | file.seek(SeekFrom::Start(start as u64)).await?;
794 | file.write_all(&buf).await?;
795 | file.flush().await?;
796 | Ok((start, stop))
797 | }
798 |
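// Note on the range semantics above: `start` and `stop` are byte offsets used
// verbatim in the HTTP Range header; the chunk is buffered fully in memory,
// then written at offset `start` of the shared temporary file.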
799 | /// This will attempt to fetch the file locally first, then [`ApiRepo::download`]
800 | /// if the file is not present.
801 | /// ```no_run
802 | /// # use hf_hub::api::tokio::Api;
803 | /// # tokio_test::block_on(async {
804 | /// let api = Api::new().unwrap();
805 | /// let local_filename = api.model("gpt2".to_string()).get("model.safetensors").await.unwrap();
806 | /// # })
/// ```
807 | pub async fn get(&self, filename: &str) -> Result<PathBuf, ApiError> {
808 | if let Some(path) = self.api.cache.repo(self.repo.clone()).get(filename) {
809 | Ok(path)
810 | } else {
811 | self.download(filename).await
812 | }
813 | }
814 |
815 | /// Downloads a remote file (if not already present) into the cache directory
816 | /// to be used locally.
817 | /// This function requires internet access to verify if new versions of the file
818 | /// exist, even if a file is already on disk at that location.
819 | /// ```no_run
820 | /// # use hf_hub::api::tokio::Api;
821 | /// # tokio_test::block_on(async {
822 | /// let api = Api::new().unwrap();
823 | /// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").await.unwrap();
824 | /// # })
825 | /// ```
826 | pub async fn download(&self, filename: &str) -> Result<PathBuf, ApiError> {
827 | if self.api.progress {
828 | self.download_with_progress(filename, ProgressBar::new(0))
829 | .await
830 | } else {
831 | self.download_with_progress(filename, ()).await
832 | }
833 | }
834 |
835 | /// This function is used to download a file with a custom progress function.
836 | /// It uses the [`Progress`] trait and can be used in more complex use
837 | /// cases like downloading while showing progress in a UI.
838 | /// ```no_run
839 | /// use hf_hub::api::tokio::{Api, Progress};
840 | ///
841 | /// #[derive(Clone)]
842 | /// struct MyProgress{
843 | /// current: usize,
844 | /// total: usize
845 | /// }
846 | ///
847 | /// impl Progress for MyProgress{
848 | /// async fn init(&mut self, size: usize, _filename: &str){
849 | /// self.total = size;
850 | /// self.current = 0;
851 | /// }
852 | ///
853 | /// async fn update(&mut self, size: usize){
854 | /// self.current += size;
855 | /// println!("{}/{}", self.current, self.total)
856 | /// }
857 | ///
858 | /// async fn finish(&mut self){
859 | /// println!("Done !");
860 | /// }
861 | /// }
862 | /// # tokio_test::block_on(async {
863 | /// let api = Api::new().unwrap();
864 | /// let progress = MyProgress{ current: 0, total : 0};
865 | /// let local_filename = api.model("gpt2".to_string()).download_with_progress("model.safetensors", progress).await.unwrap();
866 | /// # })
867 | /// ```
868 | pub async fn download_with_progress<P: Progress + Clone + Send + Sync + 'static>(
869 | &self,
870 | filename: &str,
871 | mut progress: P,
872 | ) -> Result<PathBuf, ApiError> {
873 | let url = self.url(filename);
874 | let metadata = self.api.metadata(&url).await?;
875 | let cache = self.api.cache.repo(self.repo.clone());
876 |
877 | let blob_path = cache.blob_path(&metadata.etag);
878 | std::fs::create_dir_all(blob_path.parent().unwrap())?;
879 |
880 | let lock = lock_file(blob_path.clone()).await?;
881 | progress.init(metadata.size, filename).await;
882 | let mut tmp_path = blob_path.clone();
883 | tmp_path.set_extension(EXTENSION);
884 | let tmp_filename = self
885 | .download_tempfile(&url, metadata.size, tmp_path, progress)
886 | .await?;
887 |
888 | tokio::fs::rename(&tmp_filename, &blob_path).await?;
889 | drop(lock);
890 |
891 | let mut pointer_path = cache.pointer_path(&metadata.commit_hash);
892 | pointer_path.push(filename);
893 | std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
894 |
895 | symlink_or_rename(&blob_path, &pointer_path)?;
896 | cache.create_ref(&metadata.commit_hash)?;
897 |
898 | Ok(pointer_path)
899 | }
900 |
901 | /// Get information about the Repo
902 | /// ```
903 | /// # use hf_hub::api::tokio::Api;
904 | /// # tokio_test::block_on(async {
905 | /// let api = Api::new().unwrap();
906 | /// api.model("gpt2".to_string()).info();
907 | /// # })
908 | /// ```
909 | pub async fn info(&self) -> Result<RepoInfo, ApiError> {
910 | Ok(self.info_request().send().await?.json().await?)
911 | }
912 |
913 | /// Get the raw [`reqwest::RequestBuilder`] with the url and method already set
914 | /// ```
915 | /// # use hf_hub::api::tokio::Api;
916 | /// # tokio_test::block_on(async {
917 | /// let api = Api::new().unwrap();
918 | /// api.model("gpt2".to_owned())
919 | /// .info_request()
920 | /// .query(&[("blobs", "true")])
921 | /// .send()
922 | /// .await;
923 | /// # })
924 | /// ```
925 | pub fn info_request(&self) -> RequestBuilder {
926 | let url = format!("{}/api/{}", self.api.endpoint, self.repo.api_url());
927 | self.api.client.get(url)
928 | }
929 | }
930 |
931 | #[cfg(test)]
932 | mod tests {
933 | use super::*;
934 | use crate::api::Siblings;
935 | use crate::assert_no_diff;
936 | use hex_literal::hex;
937 | use rand::distributions::Alphanumeric;
938 | use serde_json::{json, Value};
939 | use sha2::{Digest, Sha256};
940 | use std::io::{Seek, Write};
941 | use std::time::Duration;
942 |
943 | struct TempDir {
944 | path: PathBuf,
945 | }
946 |
947 | impl TempDir {
948 | pub fn new() -> Self {
949 | let s: String = rand::thread_rng()
950 | .sample_iter(&Alphanumeric)
951 | .take(7)
952 | .map(char::from)
953 | .collect();
954 | let mut path = std::env::temp_dir();
955 | path.push(s);
956 | std::fs::create_dir(&path).unwrap();
957 | Self { path }
958 | }
959 | }
960 |
961 | impl Drop for TempDir {
962 | fn drop(&mut self) {
963 | std::fs::remove_dir_all(&self.path).unwrap();
964 | }
965 | }
966 |
967 | #[tokio::test]
968 | async fn simple() {
969 | let tmp = TempDir::new();
970 | let api = ApiBuilder::new()
971 | .with_progress(false)
972 | .with_cache_dir(tmp.path.clone())
973 | .build()
974 | .unwrap();
975 | let model_id = "julien-c/dummy-unknown".to_string();
976 | let repo = Repo::new(model_id.clone(), RepoType::Model);
977 | let downloaded_path = api.model(model_id).download("config.json").await.unwrap();
978 | assert!(downloaded_path.exists());
979 | let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
980 | assert_eq!(
981 | val[..],
982 | hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
983 | );
984 |
985 | // Make sure the file is now accessible without connection
986 | let cache_path = api.cache.repo(repo.clone()).get("config.json").unwrap();
987 | assert_eq!(cache_path, downloaded_path);
988 | }
989 |
990 | #[tokio::test]
991 | async fn locking() {
992 | use std::sync::Arc;
993 | use tokio::sync::Mutex;
994 | use tokio::task::JoinSet;
995 | let tmp = Arc::new(Mutex::new(TempDir::new()));
996 |
997 | let mut handles = JoinSet::new();
998 | for _ in 0..5 {
999 | let tmp2 = tmp.clone();
1000 | handles.spawn(async move {
1001 | let api = ApiBuilder::new()
1002 | .with_progress(false)
1003 | .with_cache_dir(tmp2.lock().await.path.clone())
1004 | .build()
1005 | .unwrap();
1006 |
1007 | // 0..256ms sleep to randomize potential clashes
1008 | let millis: u64 = rand::random::<u8>().into();
1009 | tokio::time::sleep(Duration::from_millis(millis)).await;
1010 | let model_id = "julien-c/dummy-unknown".to_string();
1011 | api.model(model_id.clone())
1012 | .download("config.json")
1013 | .await
1014 | .unwrap()
1015 | });
1016 | }
1017 | while let Some(handle) = handles.join_next().await {
1018 | let downloaded_path = handle.unwrap();
1019 | assert!(downloaded_path.exists());
1020 | let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
1021 | assert_eq!(
1022 | val[..],
1023 | hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
1024 | );
1025 | }
1026 | }
1027 |
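// The resume tests below exercise the on-disk layout of a partial download
// (`<blob>.sync.part`, written by download_tempfile): `length` content bytes
// followed by an 8-byte little-endian footer recording how many contiguous
// bytes have been committed, i.e. file size = length + size_of::<u64>().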
1028 | #[tokio::test]
1029 | async fn resume() {
1030 | let tmp = TempDir::new();
1031 | let api = ApiBuilder::new()
1032 | .with_progress(false)
1033 | .with_cache_dir(tmp.path.clone())
1034 | .build()
1035 | .unwrap();
1036 | let model_id = "julien-c/dummy-unknown".to_string();
1037 | let downloaded_path = api
1038 | .model(model_id.clone())
1039 | .download("config.json")
1040 | .await
1041 | .unwrap();
1042 | assert!(downloaded_path.exists());
1043 | let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
1044 | assert_eq!(
1045 | val[..],
1046 | hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
1047 | );
1048 |
1049 | // This actually sets the file to a trashed version of the part file, full redownload will
1050 | // ensue
1051 | let blob = std::fs::canonicalize(&downloaded_path).unwrap();
1052 | let file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
1053 | let size = file.metadata().unwrap().len();
1054 | let truncate: f32 = rand::random();
1055 | let new_size = (size as f32 * truncate) as u64;
1056 | file.set_len(new_size).unwrap();
1057 | let mut blob_part = blob.clone();
1058 | blob_part.set_extension("sync.part");
1059 | std::fs::rename(blob, &blob_part).unwrap();
1060 | std::fs::remove_file(&downloaded_path).unwrap();
1061 | let content = std::fs::read(&*blob_part).unwrap();
1062 | assert_eq!(content.len() as u64, new_size);
1063 | let val = Sha256::digest(content);
1064 | // We modified the sha.
1065 | assert!(
1066 | val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
1067 | );
1068 | let new_downloaded_path = api
1069 | .model(model_id.clone())
1070 | .download("config.json")
1071 | .await
1072 | .unwrap();
1073 | let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap());
1074 | assert_eq!(downloaded_path, new_downloaded_path);
1075 | assert_eq!(
1076 | val[..],
1077 | hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
1078 | );
1079 |
1080 | // Now this is a valid partial download file
1081 | let blob = std::fs::canonicalize(&downloaded_path).unwrap();
1082 | let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap();
1083 | let size = file.metadata().unwrap().len();
1084 | let truncate: f32 = rand::random();
1085 | let new_size = (size as f32 * truncate) as u64;
1086 | // Truncating
1087 | file.set_len(new_size).unwrap();
1088 | let total_size = size + size_of::<u64>() as u64;
1089 | file.set_len(total_size).unwrap();
1090 | file.seek(SeekFrom::Start(size)).unwrap();
1091 | file.write_all(&new_size.to_le_bytes()).unwrap();
1092 |
1093 | let mut blob_part = blob.clone();
1094 | blob_part.set_extension("sync.part");
1095 | std::fs::rename(blob, &blob_part).unwrap();
1096 | std::fs::remove_file(&downloaded_path).unwrap();
1097 | let content = std::fs::read(&*blob_part).unwrap();
1098 | assert_eq!(content.len() as u64, total_size);
1099 | let val = Sha256::digest(content);
1100 | // We modified the sha.
!= hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") 1103 | ); 1104 | let new_downloaded_path = api 1105 | .model(model_id.clone()) 1106 | .download("config.json") 1107 | .await 1108 | .unwrap(); 1109 | let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap()); 1110 | assert_eq!(downloaded_path, new_downloaded_path); 1111 | assert_eq!( 1112 | val[..], 1113 | hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") 1114 | ); 1115 | 1116 | // Here we prove the previous part was correctly resuming by purposefully corrupting the 1117 | // file. 1118 | let blob = std::fs::canonicalize(&downloaded_path).unwrap(); 1119 | let mut file = std::fs::OpenOptions::new().write(true).open(&blob).unwrap(); 1120 | let size = file.metadata().unwrap().len(); 1121 | // Not random for consistent sha corruption 1122 | let truncate: f32 = 0.5; 1123 | let new_size = (size as f32 * truncate) as u64; 1124 | // Truncating 1125 | file.set_len(new_size).unwrap(); 1126 | let total_size = size + size_of::() as u64; 1127 | file.set_len(total_size).unwrap(); 1128 | file.seek(SeekFrom::Start(size)).unwrap(); 1129 | file.write_all(&new_size.to_le_bytes()).unwrap(); 1130 | 1131 | // Corrupting by changing a single byte. 1132 | file.seek(SeekFrom::Start(new_size - 1)).unwrap(); 1133 | file.write_all(&[0]).unwrap(); 1134 | 1135 | let mut blob_part = blob.clone(); 1136 | blob_part.set_extension("sync.part"); 1137 | std::fs::rename(blob, &blob_part).unwrap(); 1138 | std::fs::remove_file(&downloaded_path).unwrap(); 1139 | let content = std::fs::read(&*blob_part).unwrap(); 1140 | assert_eq!(content.len() as u64, total_size); 1141 | let val = Sha256::digest(content); 1142 | // We modified the sha. 1143 | assert!( 1144 | val[..] != hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32") 1145 | ); 1146 | let new_downloaded_path = api 1147 | .model(model_id.clone()) 1148 | .download("config.json") 1149 | .await 1150 | .unwrap(); 1151 | let val = Sha256::digest(std::fs::read(&*new_downloaded_path).unwrap()); 1152 | assert_eq!(downloaded_path, new_downloaded_path); 1153 | assert_eq!( 1154 | val[..], 1155 | // Corrupted sha 1156 | hex!("32b83c94ee55a8d43d68b03a859975f6789d647342ddeb2326fcd5e0127035b5") 1157 | ); 1158 | } 1159 | 1160 | #[tokio::test] 1161 | async fn revision() { 1162 | let tmp = TempDir::new(); 1163 | let api = ApiBuilder::new() 1164 | .with_progress(false) 1165 | .with_cache_dir(tmp.path.clone()) 1166 | .build() 1167 | .unwrap(); 1168 | let model_id = "BAAI/bge-base-en".to_string(); 1169 | let repo = Repo::with_revision(model_id.clone(), RepoType::Model, "refs/pr/2".to_string()); 1170 | let downloaded_path = api 1171 | .repo(repo.clone()) 1172 | .download("tokenizer.json") 1173 | .await 1174 | .unwrap(); 1175 | assert!(downloaded_path.exists()); 1176 | let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); 1177 | assert_eq!( 1178 | val[..], 1179 | hex!("d241a60d5e8f04cc1b2b3e9ef7a4921b27bf526d9f6050ab90f9267a1f9e5c66") 1180 | ); 1181 | 1182 | // Make sure the file is now seeable without connection 1183 | let cache_path = api.cache.repo(repo).get("tokenizer.json").unwrap(); 1184 | assert_eq!(cache_path, downloaded_path); 1185 | } 1186 | 1187 | #[tokio::test] 1188 | async fn dataset() { 1189 | let tmp = TempDir::new(); 1190 | let api = ApiBuilder::new() 1191 | .with_progress(false) 1192 | .with_cache_dir(tmp.path.clone()) 1193 | .build() 1194 | .unwrap(); 1195 | let repo = Repo::with_revision( 1196 | "wikitext".to_string(), 1197 | 
RepoType::Dataset, 1198 | "refs/convert/parquet".to_string(), 1199 | ); 1200 | let downloaded_path = api 1201 | .repo(repo) 1202 | .download("wikitext-103-v1/test/0000.parquet") 1203 | .await 1204 | .unwrap(); 1205 | assert!(downloaded_path.exists()); 1206 | let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); 1207 | assert_eq!( 1208 | val[..], 1209 | hex!("ABDFC9F83B1103B502924072460D4C92F277C9B49C313CEF3E48CFCF7428E125") 1210 | ); 1211 | } 1212 | 1213 | #[tokio::test] 1214 | async fn models() { 1215 | let tmp = TempDir::new(); 1216 | let api = ApiBuilder::new() 1217 | .with_progress(false) 1218 | .with_cache_dir(tmp.path.clone()) 1219 | .build() 1220 | .unwrap(); 1221 | let repo = Repo::with_revision( 1222 | "BAAI/bGe-reRanker-Base".to_string(), 1223 | RepoType::Model, 1224 | "refs/pr/5".to_string(), 1225 | ); 1226 | let downloaded_path = api.repo(repo).download("tokenizer.json").await.unwrap(); 1227 | assert!(downloaded_path.exists()); 1228 | let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); 1229 | assert_eq!( 1230 | val[..], 1231 | hex!("9EB652AC4E40CC093272BBBE0F55D521CF67570060227109B5CDC20945A4489E") 1232 | ); 1233 | } 1234 | 1235 | #[tokio::test] 1236 | async fn info() { 1237 | let tmp = TempDir::new(); 1238 | let api = ApiBuilder::new() 1239 | .with_progress(false) 1240 | .with_cache_dir(tmp.path.clone()) 1241 | .build() 1242 | .unwrap(); 1243 | let repo = Repo::with_revision( 1244 | "wikitext".to_string(), 1245 | RepoType::Dataset, 1246 | "refs/convert/parquet".to_string(), 1247 | ); 1248 | let model_info = api.repo(repo).info().await.unwrap(); 1249 | assert_eq!( 1250 | model_info, 1251 | RepoInfo { 1252 | siblings: vec![ 1253 | Siblings { 1254 | rfilename: ".gitattributes".to_string() 1255 | }, 1256 | Siblings { 1257 | rfilename: "wikitext-103-raw-v1/test/0000.parquet".to_string() 1258 | }, 1259 | Siblings { 1260 | rfilename: "wikitext-103-raw-v1/train/0000.parquet".to_string() 1261 | }, 1262 | Siblings { 1263 | rfilename: "wikitext-103-raw-v1/train/0001.parquet".to_string() 1264 | }, 1265 | Siblings { 1266 | rfilename: "wikitext-103-raw-v1/validation/0000.parquet".to_string() 1267 | }, 1268 | Siblings { 1269 | rfilename: "wikitext-103-v1/test/0000.parquet".to_string() 1270 | }, 1271 | Siblings { 1272 | rfilename: "wikitext-103-v1/train/0000.parquet".to_string() 1273 | }, 1274 | Siblings { 1275 | rfilename: "wikitext-103-v1/train/0001.parquet".to_string() 1276 | }, 1277 | Siblings { 1278 | rfilename: "wikitext-103-v1/validation/0000.parquet".to_string() 1279 | }, 1280 | Siblings { 1281 | rfilename: "wikitext-2-raw-v1/test/0000.parquet".to_string() 1282 | }, 1283 | Siblings { 1284 | rfilename: "wikitext-2-raw-v1/train/0000.parquet".to_string() 1285 | }, 1286 | Siblings { 1287 | rfilename: "wikitext-2-raw-v1/validation/0000.parquet".to_string() 1288 | }, 1289 | Siblings { 1290 | rfilename: "wikitext-2-v1/test/0000.parquet".to_string() 1291 | }, 1292 | Siblings { 1293 | rfilename: "wikitext-2-v1/train/0000.parquet".to_string() 1294 | }, 1295 | Siblings { 1296 | rfilename: "wikitext-2-v1/validation/0000.parquet".to_string() 1297 | } 1298 | ], 1299 | sha: "3f68cd45302c7b4b532d933e71d9e6e54b1c7d5e".to_string() 1300 | } 1301 | ); 1302 | } 1303 | 1304 | #[tokio::test] 1305 | async fn info_request() { 1306 | let tmp = TempDir::new(); 1307 | let api = ApiBuilder::new() 1308 | .with_token(None) 1309 | .with_progress(false) 1310 | .with_cache_dir(tmp.path.clone()) 1311 | .build() 1312 | .unwrap(); 1313 | let repo = Repo::with_revision( 1314 | 
"mcpotato/42-eicar-street".to_string(), 1315 | RepoType::Model, 1316 | "8b3861f6931c4026b0cd22b38dbc09e7668983ac".to_string(), 1317 | ); 1318 | let blobs_info: Value = api 1319 | .repo(repo) 1320 | .info_request() 1321 | .query(&[("blobs", "true")]) 1322 | .send() 1323 | .await 1324 | .unwrap() 1325 | .json() 1326 | .await 1327 | .unwrap(); 1328 | assert_no_diff!( 1329 | blobs_info, 1330 | json!({ 1331 | "_id": "621ffdc136468d709f17ddb4", 1332 | "author": "mcpotato", 1333 | "createdAt": "2022-03-02T23:29:05.000Z", 1334 | "disabled": false, 1335 | "downloads": 0, 1336 | "gated": false, 1337 | "id": "mcpotato/42-eicar-street", 1338 | "lastModified": "2022-11-30T19:54:16.000Z", 1339 | "likes": 2, 1340 | "modelId": "mcpotato/42-eicar-street", 1341 | "private": false, 1342 | "sha": "8b3861f6931c4026b0cd22b38dbc09e7668983ac", 1343 | "siblings": [ 1344 | { 1345 | "blobId": "6d34772f5ca361021038b404fb913ec8dc0b1a5a", 1346 | "rfilename": ".gitattributes", 1347 | "size": 1175 1348 | }, 1349 | { 1350 | "blobId": "be98037f7c542112c15a1d2fc7e2a2427e42cb50", 1351 | "rfilename": "build_pickles.py", 1352 | "size": 304 1353 | }, 1354 | { 1355 | "blobId": "8acd02161fff53f9df9597e377e22b04bc34feff", 1356 | "rfilename": "danger.dat", 1357 | "size": 66 1358 | }, 1359 | { 1360 | "blobId": "86b812515e075a1ae216e1239e615a1d9e0b316e", 1361 | "rfilename": "eicar_test_file", 1362 | "size": 70 1363 | }, 1364 | { 1365 | "blobId": "86b812515e075a1ae216e1239e615a1d9e0b316e", 1366 | "rfilename": "eicar_test_file_bis", 1367 | "size":70 1368 | }, 1369 | { 1370 | "blobId": "cd1c6d8bde5006076655711a49feae66f07d707e", 1371 | "lfs": { 1372 | "pointerSize": 127, 1373 | "sha256": "f9343d7d7ec5c3d8bcced056c438fc9f1d3819e9ca3d42418a40857050e10e20", 1374 | "size": 22 1375 | }, 1376 | "rfilename": "pytorch_model.bin", 1377 | "size": 22 1378 | }, 1379 | { 1380 | "blobId": "8ab39654695136173fee29cba0193f679dfbd652", 1381 | "rfilename": "supposedly_safe.pkl", 1382 | "size": 31 1383 | } 1384 | ], 1385 | "spaces": [], 1386 | "tags": ["pytorch", "region:us"], 1387 | "usedStorage": 22 1388 | }) 1389 | ); 1390 | } 1391 | 1392 | #[test] 1393 | fn headers_default() { 1394 | let headers = ApiBuilder::new().build_headers().unwrap(); 1395 | assert_eq!( 1396 | headers.get(USER_AGENT), 1397 | Some( 1398 | &"unknown/None; hf-hub/0.4.2; rust/unknown" 1399 | .try_into() 1400 | .unwrap() 1401 | ) 1402 | ); 1403 | } 1404 | 1405 | #[test] 1406 | fn headers_custom() { 1407 | let headers = ApiBuilder::new() 1408 | .with_user_agent("origin", "custom") 1409 | .build_headers() 1410 | .unwrap(); 1411 | assert_eq!( 1412 | headers.get(USER_AGENT), 1413 | Some( 1414 | &"unknown/None; hf-hub/0.4.2; rust/unknown; origin/custom" 1415 | .try_into() 1416 | .unwrap() 1417 | ) 1418 | ); 1419 | } 1420 | 1421 | // #[tokio::test] 1422 | // async fn real() { 1423 | // let api = Api::new().unwrap(); 1424 | // let repo = api.model("bert-base-uncased".to_string()); 1425 | // let weights = repo.get("model.safetensors").await.unwrap(); 1426 | // let val = Sha256::digest(std::fs::read(&*weights).unwrap()); 1427 | // println!("Digest {val:#x}"); 1428 | // assert_eq!( 1429 | // val[..], 1430 | // hex!("68d45e234eb4a928074dfd868cead0219ab85354cc53d20e772753c6bb9169d3") 1431 | // ); 1432 | // } 1433 | } 1434 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![deny(missing_docs)] 2 | #![cfg_attr(feature="ureq", doc = 
include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md")))]
3 | #![cfg_attr(
4 | not(feature = "ureq"),
5 | doc = "Documentation is meant to be compiled with default features (at least ureq)"
6 | )]
7 | use std::io::Write;
8 | use std::path::PathBuf;
9 |
10 | /// The actual Api to interact with the hub.
11 | #[cfg(any(feature = "tokio", feature = "ureq"))]
12 | pub mod api;
13 |
14 | const HF_HOME: &str = "HF_HOME";
15 |
16 | /// The type of repo to interact with
17 | #[derive(Debug, Clone, Copy)]
18 | pub enum RepoType {
19 | /// This is a model, usually it consists of weight files and some configuration
20 | /// files
21 | Model,
22 | /// This is a dataset, usually contains data within parquet files
23 | Dataset,
24 | /// This is a space, usually a demo showcasing a given model or dataset
25 | Space,
26 | }
27 |
28 | /// A local struct used to fetch information from the cache folder.
29 | #[derive(Clone, Debug)]
30 | pub struct Cache {
31 | path: PathBuf,
32 | }
33 |
34 | impl Cache {
35 | /// Creates a new cache object location
36 | pub fn new(path: PathBuf) -> Self {
37 | Self { path }
38 | }
39 |
40 | /// Creates cache from environment variable HF_HOME (if defined), otherwise
41 | /// defaults to [`home_dir`]/.cache/huggingface/
42 | pub fn from_env() -> Self {
43 | match std::env::var(HF_HOME) {
44 | Ok(home) => {
45 | let mut path: PathBuf = home.into();
46 | path.push("hub");
47 | Self::new(path)
48 | }
49 | Err(_) => Self::default(),
50 | }
51 | }
52 |
53 | /// Returns the location of the cache directory
54 | pub fn path(&self) -> &PathBuf {
55 | &self.path
56 | }
57 |
58 | /// Returns the location of the token file
59 | pub fn token_path(&self) -> PathBuf {
60 | let mut path = self.path.clone();
61 | // Remove `"hub"`
62 | path.pop();
63 | path.push("token");
64 | path
65 | }
66 |
67 | /// Returns the token value if it exists in the cache
68 | /// Use `huggingface-cli login` to set it up.
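/// A small sketch of reading it through the cache (no network involved):
/// ```
/// use hf_hub::Cache;
/// let cache = Cache::from_env();
/// if let Some(token) = cache.token() {
///     println!("found a token of {} characters", token.len());
/// }
/// ```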
69 | pub fn token(&self) -> Option<String> {
70 | let token_filename = self.token_path();
71 | if token_filename.exists() {
72 | log::info!("Using token file found {token_filename:?}");
73 | }
74 | match std::fs::read_to_string(token_filename) {
75 | Ok(token_content) => {
76 | let token_content = token_content.trim();
77 | if token_content.is_empty() {
78 | None
79 | } else {
80 | Some(token_content.to_string())
81 | }
82 | }
83 | Err(_) => None,
84 | }
85 | }
86 |
87 | /// Creates a new handle [`CacheRepo`] which contains operations
88 | /// on a particular [`Repo`]
89 | pub fn repo(&self, repo: Repo) -> CacheRepo {
90 | CacheRepo::new(self.clone(), repo)
91 | }
92 |
93 | /// Simple wrapper over
94 | /// ```
95 | /// # use hf_hub::{Cache, Repo, RepoType};
96 | /// # let model_id = "gpt2".to_string();
97 | /// let cache = Cache::new("/tmp/".into());
98 | /// let cache = cache.repo(Repo::new(model_id, RepoType::Model));
99 | /// ```
100 | pub fn model(&self, model_id: String) -> CacheRepo {
101 | self.repo(Repo::new(model_id, RepoType::Model))
102 | }
103 |
104 | /// Simple wrapper over
105 | /// ```
106 | /// # use hf_hub::{Cache, Repo, RepoType};
107 | /// # let model_id = "gpt2".to_string();
108 | /// let cache = Cache::new("/tmp/".into());
109 | /// let cache = cache.repo(Repo::new(model_id, RepoType::Dataset));
110 | /// ```
111 | pub fn dataset(&self, model_id: String) -> CacheRepo {
112 | self.repo(Repo::new(model_id, RepoType::Dataset))
113 | }
114 |
115 | /// Simple wrapper over
116 | /// ```
117 | /// # use hf_hub::{Cache, Repo, RepoType};
118 | /// # let model_id = "gpt2".to_string();
119 | /// let cache = Cache::new("/tmp/".into());
120 | /// let cache = cache.repo(Repo::new(model_id, RepoType::Space));
121 | /// ```
122 | pub fn space(&self, model_id: String) -> CacheRepo {
123 | self.repo(Repo::new(model_id, RepoType::Space))
124 | }
125 | }
126 |
127 | /// Shorthand for accessing things within a particular repo
128 | #[derive(Debug)]
129 | pub struct CacheRepo {
130 | cache: Cache,
131 | repo: Repo,
132 | }
133 |
134 | impl CacheRepo {
135 | fn new(cache: Cache, repo: Repo) -> Self {
136 | Self { cache, repo }
137 | }
138 |
139 | /// This will get the location of the file within the cache for the remote
140 | /// `filename`. Will return `None` if the file is not already present in the cache.
141 | pub fn get(&self, filename: &str) -> Option<PathBuf> {
142 | let commit_path = self.ref_path();
143 | let commit_hash = std::fs::read_to_string(commit_path).ok()?;
144 | let mut pointer_path = self.pointer_path(&commit_hash);
145 | pointer_path.push(filename);
146 | if pointer_path.exists() {
147 | Some(pointer_path)
148 | } else {
149 | None
150 | }
151 | }
152 |
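// An illustrative walk-through of `get` above (all values made up): after
// "config.json" was downloaded at commit abc123 on revision "main",
// get("config.json") reads refs/main to recover "abc123" and returns
// snapshots/abc123/config.json if that pointer file exists.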
153 | fn path(&self) -> PathBuf {
154 | let mut ref_path = self.cache.path.clone();
155 | ref_path.push(self.repo.folder_name());
156 | ref_path
157 | }
158 |
159 | fn ref_path(&self) -> PathBuf {
160 | let mut ref_path = self.path();
161 | ref_path.push("refs");
162 | ref_path.push(self.repo.revision());
163 | ref_path
164 | }
165 |
166 | /// Creates a reference in the cache directory that points branches to the correct
167 | /// commits within the blobs.
168 | pub fn create_ref(&self, commit_hash: &str) -> Result<(), std::io::Error> {
169 | let ref_path = self.ref_path();
170 | // Needs to be done like this because revision might contain `/` creating subfolders here.
171 | std::fs::create_dir_all(ref_path.parent().unwrap())?;
172 | let mut file = std::fs::OpenOptions::new()
173 | .write(true)
174 | .create(true)
175 | .truncate(true)
176 | .open(&ref_path)?;
177 | file.write_all(commit_hash.trim().as_bytes())?;
178 | Ok(())
179 | }
180 |
181 | #[cfg(any(feature = "tokio", feature = "ureq"))]
182 | pub(crate) fn blob_path(&self, etag: &str) -> PathBuf {
183 | let mut blob_path = self.path();
184 | blob_path.push("blobs");
185 | blob_path.push(etag);
186 | blob_path
187 | }
188 |
189 | pub(crate) fn pointer_path(&self, commit_hash: &str) -> PathBuf {
190 | let mut pointer_path = self.path();
191 | pointer_path.push("snapshots");
192 | pointer_path.push(commit_hash);
193 | pointer_path
194 | }
195 | }
196 |
197 | impl Default for Cache {
198 | fn default() -> Self {
199 | let mut path = dirs::home_dir().expect("Cache directory cannot be found");
200 | path.push(".cache");
201 | path.push("huggingface");
202 | path.push("hub");
203 | Self::new(path)
204 | }
205 | }
206 |
207 | /// The representation of a repo on the hub.
208 | #[derive(Clone, Debug)]
209 | pub struct Repo {
210 | repo_id: String,
211 | repo_type: RepoType,
212 | revision: String,
213 | }
214 |
215 | impl Repo {
216 | /// Repo with the default branch ("main").
217 | pub fn new(repo_id: String, repo_type: RepoType) -> Self {
218 | Self::with_revision(repo_id, repo_type, "main".to_string())
219 | }
220 |
221 | /// Fully qualified Repo
222 | pub fn with_revision(repo_id: String, repo_type: RepoType, revision: String) -> Self {
223 | Self {
224 | repo_id,
225 | repo_type,
226 | revision,
227 | }
228 | }
229 |
230 | /// Shortcut for [`Repo::new`] with [`RepoType::Model`]
231 | pub fn model(repo_id: String) -> Self {
232 | Self::new(repo_id, RepoType::Model)
233 | }
234 |
235 | /// Shortcut for [`Repo::new`] with [`RepoType::Dataset`]
236 | pub fn dataset(repo_id: String) -> Self {
237 | Self::new(repo_id, RepoType::Dataset)
238 | }
239 |
240 | /// Shortcut for [`Repo::new`] with [`RepoType::Space`]
241 | pub fn space(repo_id: String) -> Self {
242 | Self::new(repo_id, RepoType::Space)
243 | }
244 |
245 | /// The normalized folder name of the repo within the cache directory
246 | pub fn folder_name(&self) -> String {
247 | let prefix = match self.repo_type {
248 | RepoType::Model => "models",
249 | RepoType::Dataset => "datasets",
250 | RepoType::Space => "spaces",
251 | };
252 | format!("{prefix}--{}", self.repo_id).replace('/', "--")
253 | }
254 |
255 | /// The revision
256 | pub fn revision(&self) -> &str {
257 | &self.revision
258 | }
259 |
260 | /// The actual URL part of the repo
261 | #[cfg(any(feature = "tokio", feature = "ureq"))]
262 | pub fn url(&self) -> String {
263 | match self.repo_type {
264 | RepoType::Model => self.repo_id.to_string(),
265 | RepoType::Dataset => {
266 | format!("datasets/{}", self.repo_id)
267 | }
268 | RepoType::Space => {
269 | format!("spaces/{}", self.repo_id)
270 | }
271 | }
272 | }
273 |
274 | /// Revision needs to be url escaped before being used in a URL
275 | #[cfg(any(feature = "tokio", feature = "ureq"))]
276 | pub fn url_revision(&self) -> String {
277 | self.revision.replace('/', "%2F")
278 | }
279 |
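// Worked examples of the naming scheme above (values follow directly from
// folder_name and url_revision):
//   Repo::model("gpt2".to_string()).folder_name()             == "models--gpt2"
//   Repo::model("BAAI/bge-base-en".to_string()).folder_name() == "models--BAAI--bge-base-en"
//   Repo::with_revision("gpt2".to_string(), RepoType::Model,
//       "refs/pr/2".to_string()).url_revision()               == "refs%2Fpr%2F2"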
format!("{prefix}/{}/revision/{}", self.repo_id, self.url_revision()) 289 | } 290 | } 291 | 292 | #[cfg(test)] 293 | mod tests { 294 | use super::*; 295 | 296 | /// Internal macro used to show cleaners errors 297 | /// on the payloads received from the hub. 298 | #[macro_export] 299 | macro_rules! assert_no_diff { 300 | ($left: expr, $right: expr) => { 301 | let left = serde_json::to_string_pretty(&$left).unwrap(); 302 | let right = serde_json::to_string_pretty(&$right).unwrap(); 303 | if left != right { 304 | use rand::Rng; 305 | use std::io::Write; 306 | use std::process::Command; 307 | let rand_string: String = rand::thread_rng() 308 | .sample_iter(&rand::distributions::Alphanumeric) 309 | .take(6) 310 | .map(char::from) 311 | .collect(); 312 | let left_filename = format!("/tmp/left-{rand_string}.txt"); 313 | let mut file = std::fs::File::create(&left_filename).unwrap(); 314 | file.write_all(left.as_bytes()).unwrap(); 315 | let right_filename = format!("/tmp/right-{rand_string}.txt"); 316 | let mut file = std::fs::File::create(&right_filename).unwrap(); 317 | file.write_all(right.as_bytes()).unwrap(); 318 | let output = Command::new("diff") 319 | // Reverse order seems to be more appropriate for how we set up the tests. 320 | .args(["-U5", &right_filename, &left_filename]) 321 | .output() 322 | .expect("Failed to diff") 323 | .stdout; 324 | let diff = String::from_utf8(output).expect("Invalid utf-8 diff output"); 325 | // eprintln!("assertion `left == right` failed\n{diff}"); 326 | assert!(false, "{diff}") 327 | }; 328 | }; 329 | } 330 | 331 | #[test] 332 | #[cfg(not(target_os = "windows"))] 333 | fn token_path() { 334 | let cache = Cache::from_env(); 335 | let token_path = cache.token_path().to_str().unwrap().to_string(); 336 | if let Ok(hf_home) = std::env::var(HF_HOME) { 337 | assert_eq!(token_path, format!("{hf_home}/token")); 338 | } else { 339 | let n = "huggingface/token".len(); 340 | assert_eq!(&token_path[token_path.len() - n..], "huggingface/token"); 341 | } 342 | } 343 | 344 | #[test] 345 | #[cfg(target_os = "windows")] 346 | fn token_path() { 347 | let cache = Cache::from_env(); 348 | let token_path = cache.token_path().to_str().unwrap().to_string(); 349 | if let Ok(hf_home) = std::env::var(HF_HOME) { 350 | assert_eq!(token_path, format!("{hf_home}\\token")); 351 | } else { 352 | let n = "huggingface/token".len(); 353 | assert_eq!(&token_path[token_path.len() - n..], "huggingface\\token"); 354 | } 355 | } 356 | } 357 | --------------------------------------------------------------------------------