├── .cargo └── config.toml ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── config.yml │ └── feature-request.yml ├── actions │ ├── cache-rust-build │ │ └── action.yml │ ├── set-build-profile │ │ └── action.yml │ └── windows-codesign │ │ └── action.yml └── workflows │ ├── ci.yml │ ├── git-xet-release.yml │ ├── hf-xet-tests.yml │ ├── pre-release-testing.yml │ └── release.yml ├── .gitignore ├── .vscode └── settings.json ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md ├── cas_client ├── Cargo.toml ├── README.md └── src │ ├── constants.rs │ ├── download_utils.rs │ ├── error.rs │ ├── exports.rs │ ├── http_client.rs │ ├── interface.rs │ ├── lib.rs │ ├── local_client.rs │ ├── output_provider.rs │ ├── remote_client.rs │ ├── retry_wrapper.rs │ └── upload_progress_stream.rs ├── cas_object ├── Cargo.toml ├── benches │ ├── bg_split_regroup_bench.rs │ └── compression_bench.rs └── src │ ├── byte_grouping │ ├── bg4.rs │ ├── bg4_prediction.rs │ ├── bg4_prediction_benchmark.rs │ ├── compression_stats │ │ ├── collect_compression_stats.rs │ │ └── compression_prediction_tests.py │ └── mod.rs │ ├── cas_chunk_format.rs │ ├── cas_chunk_format │ └── deserialize_async.rs │ ├── cas_object_format.rs │ ├── compression_scheme.rs │ ├── constants.rs │ ├── error.rs │ └── lib.rs ├── cas_types ├── Cargo.toml ├── README.md └── src │ ├── error.rs │ ├── key.rs │ └── lib.rs ├── chunk_cache ├── .gitignore ├── Cargo.toml └── src │ ├── bin │ └── analysis.rs │ ├── cache_manager.rs │ ├── disk.rs │ ├── disk │ ├── cache_file_header.rs │ ├── cache_item.rs │ └── test_utils.rs │ ├── error.rs │ └── lib.rs ├── chunk_cache_bench ├── Cargo.lock ├── Cargo.toml ├── benches │ ├── cache_bench.rs │ └── results.md └── src │ ├── bin │ └── cache_resilience_test.rs │ ├── lib.rs │ ├── sccache.rs │ └── solid_cache.rs ├── data ├── Cargo.toml ├── README.md ├── examples │ ├── chunk │ │ └── main.rs │ ├── hash │ │ └── main.rs │ └── xorb-check │ │ └── main.rs ├── src │ ├── bin │ │ ├── example.rs │ │ └── xtool.rs │ ├── configurations.rs │ ├── constants.rs │ ├── data_client.rs │ ├── deduplication_interface.rs │ ├── errors.rs │ ├── file_cleaner.rs │ ├── file_downloader.rs │ ├── file_upload_session.rs │ ├── lib.rs │ ├── migration_tool │ │ ├── hub_client_token_refresher.rs │ │ ├── migrate.rs │ │ └── mod.rs │ ├── prometheus_metrics.rs │ ├── remote_client_interface.rs │ ├── sha256.rs │ ├── shard_interface.rs │ ├── test_utils.rs │ └── xet_file.rs └── tests │ ├── integration_tests.rs │ ├── integration_tests │ ├── initialize.sh │ └── test_basic_clean_smudge.sh │ ├── test_clean_smudge.rs │ └── test_session_resume.rs ├── deduplication ├── Cargo.toml ├── README.md └── src │ ├── chunk.rs │ ├── chunking.rs │ ├── constants.rs │ ├── data_aggregator.rs │ ├── dedup_metrics.rs │ ├── defrag_prevention.rs │ ├── file_deduplication.rs │ ├── interface.rs │ ├── lib.rs │ ├── parallel chunking.lyx │ ├── parallel chunking.pdf │ └── raw_xorb_data.rs ├── error_printer ├── Cargo.toml ├── src │ └── lib.rs └── tests │ ├── test_error.rs │ └── test_option.rs ├── file_utils ├── Cargo.toml └── src │ ├── file_metadata.rs │ ├── lib.rs │ ├── privilege_context.rs │ └── safe_file_creator.rs ├── git_xet ├── Cargo.toml ├── README.md ├── entitlements.xml ├── install.sh ├── src │ ├── app.rs │ ├── app │ │ ├── install.rs │ │ ├── uninstall.rs │ │ └── xet_agent.rs │ ├── auth.rs │ ├── auth │ │ ├── git.rs │ │ └── ssh.rs │ ├── bin │ │ └── main.rs │ ├── constants.rs │ ├── errors.rs │ ├── git_process_wrapping.rs │ ├── git_repo.rs │ ├── git_url.rs │ ├── 
lfs_agent_protocol.rs │ ├── lfs_agent_protocol │ │ ├── agent_state.rs │ │ ├── errors.rs │ │ ├── progress_updater.rs │ │ └── protocol_spec.rs │ ├── lib.rs │ ├── test_utils │ │ ├── gitaskpass.sh │ │ ├── mod.rs │ │ ├── temp.rs │ │ └── test_repo.rs │ └── token_refresher.rs └── windows_installer │ ├── .gitignore │ ├── Package.wxs │ └── sign_metadata.json ├── hf-xet-diag-linux.sh ├── hf-xet-diag-macos.sh ├── hf-xet-diag-windows.sh ├── hf_xet ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md ├── pyproject.toml ├── python │ └── .gitkeep └── src │ ├── lib.rs │ ├── logging.rs │ ├── profiling.rs │ ├── progress_update.rs │ ├── runtime.rs │ └── token_refresh.rs ├── hf_xet_thin_wasm ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── README.md ├── build_wasm.sh └── src │ └── lib.rs ├── hf_xet_wasm ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── README.md ├── build_wasm.sh ├── examples │ ├── commit.js │ ├── index.html │ ├── simple.rs │ └── xet_meta.js ├── rust-toolchain.toml ├── src │ ├── auth.rs │ ├── blob_reader.rs │ ├── configurations.rs │ ├── errors.rs │ ├── lib.rs │ ├── session.rs │ ├── sha256.rs │ ├── wasm_deduplication_interface.rs │ ├── wasm_file_cleaner.rs │ ├── wasm_file_upload_session.rs │ ├── wasm_timer.rs │ └── xorb_uploader.rs └── webdriver.json ├── hub_client ├── Cargo.toml └── src │ ├── auth.rs │ ├── auth │ ├── basics.rs │ └── interface.rs │ ├── client.rs │ ├── errors.rs │ ├── lib.rs │ └── types.rs ├── markdownlint.toml ├── mdb_shard ├── Cargo.toml ├── README.md └── src │ ├── cas_structs.rs │ ├── chunk_verification.rs │ ├── constants.rs │ ├── error.rs │ ├── file_structs.rs │ ├── interpolation_search.rs │ ├── lib.rs │ ├── session_directory.rs │ ├── set_operations.rs │ ├── shard_benchmark.rs │ ├── shard_file.rs │ ├── shard_file_handle.rs │ ├── shard_file_manager.rs │ ├── shard_file_reconstructor.rs │ ├── shard_format.rs │ ├── shard_in_memory.rs │ ├── streaming_shard.rs │ └── utils.rs ├── merklehash ├── Cargo.toml ├── README.md └── src │ ├── aggregated_hashes.rs │ ├── data_hash.rs │ └── lib.rs ├── openapi ├── .gitignore ├── Makefile └── cas.openapi.yaml ├── progress_tracking ├── Cargo.toml └── src │ ├── aggregator.rs │ ├── item_tracking.rs │ ├── lib.rs │ ├── no_op_tracker.rs │ ├── progress_info.rs │ ├── upload_tracking.rs │ └── verification_wrapper.rs ├── pyproject.toml ├── rustfmt.toml ├── utils ├── Cargo.toml ├── README.md └── src │ ├── async_iterator.rs │ ├── async_read.rs │ ├── auth.rs │ ├── byte_size.rs │ ├── constant_declarations.rs │ ├── errors.rs │ ├── file_paths.rs │ ├── lib.rs │ ├── limited_joinset.rs │ ├── output_bytes.rs │ ├── rw_task_lock.rs │ ├── serialization_utils.rs │ └── singleflight.rs └── xet_runtime ├── Cargo.toml └── src ├── errors.rs ├── exports.rs ├── file_handle_limits.rs ├── global_semaphores.rs ├── lib.rs ├── runtime.rs ├── sync_primatives.rs └── utils.rs /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [target.wasm32-unknown-unknown] 2 | # the following is necessary when compiling rand & getrandom for wasm 3 | # https://github.com/rust-random/getrandom/blob/master/README.md#webassembly-support 4 | rustflags = ['--cfg', 'getrandom_backend="wasm_js"'] 5 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Report a bug on xet-core 3 | labels: ["bug"] 4 | body: 5 | - type: markdown 6 | 
attributes: 7 | value: | 8 | Thanks for taking the time to fill out this bug report! 9 | - type: textarea 10 | id: bug-description 11 | attributes: 12 | label: Describe the bug 13 | description: A clear and concise description of what the bug is and details about your machine (e.g., network capacity, file system, disk type, etc.). 14 | placeholder: Bug description 15 | validations: 16 | required: true 17 | - type: textarea 18 | id: reproduction 19 | attributes: 20 | label: Reproduction 21 | description: Please provide a minimal reproducible example that we can copy/paste to reproduce the issue. 22 | placeholder: Reproduction 23 | - type: textarea 24 | id: logs 25 | attributes: 26 | label: Logs 27 | description: "Please include any printed warnings or errors if you can." 28 | render: shell 29 | - type: textarea 30 | id: system-info 31 | attributes: 32 | label: System info 33 | description: | 34 | Please dump your environment info by running the following commands and copy-paste the results here: 35 | ```txt 36 | huggingface-cli env 37 | env | grep HF_XET 38 | ``` 39 | 40 | If you are working in a notebook, please run it in a code cell: 41 | ```py 42 | import os 43 | from huggingface_hub import dump_environment_info 44 | 45 | # Dump environment info to the console 46 | dump_environment_info() 47 | 48 | # Dump HF_XET environment variables 49 | for key, value in os.environ.items(): 50 | if key.startswith("HF_XET"): 51 | print(f"{key}={value}") 52 | ``` 53 | render: shell 54 | placeholder: | 55 | - huggingface_hub version: 0.11.0.dev0 56 | - Platform: Linux-5.15.0-52-generic-x86_64-with-glibc2.35 57 | - Python version: 3.10.6 58 | ... 59 | validations: 60 | required: true -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Website Related 4 | url: https://github.com/huggingface/hub-docs/issues 5 | about: Feature requests and bug reports related to the website 6 | - name: Forum 7 | url: https://discuss.huggingface.co/ 8 | about: General usage questions and community discussions -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Propose an enhancement. 3 | labels: ["enhancement"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thank you for this feature request! We appreciate your input and will review it as soon as possible. 9 | - type: textarea 10 | id: current-limitation 11 | attributes: 12 | label: Current Limitation 13 | description: A clear and concise description of the current limitation you would like to see addressed. 14 | placeholder: Current limitation description 15 | validations: 16 | required: true 17 | - type: textarea 18 | id: feature-description 19 | attributes: 20 | label: Feature Description 21 | description: A clear and concise description of the feature you would like to see and how it would address your current issues. 22 | placeholder: Feature description 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: additional-context 27 | attributes: 28 | label: Additional Context 29 | description: Any additional context or information that may be relevant to the feature request. 
30 | placeholder: Additional context 31 | validations: 32 | required: false 33 | -------------------------------------------------------------------------------- /.github/actions/cache-rust-build/action.yml: -------------------------------------------------------------------------------- 1 | name: 'Cache Rust Build' 2 | description: 'Cache Rust dependencies and build artifacts' 3 | runs: 4 | using: "composite" 5 | steps: 6 | - name: Cache 7 | uses: actions/cache@v4 8 | with: 9 | path: | 10 | ~/.cargo/registry 11 | ~/.cargo/git 12 | target 13 | hf_xet/target 14 | hf_xet_wasm/target 15 | hf_xet_thin_wasm/target 16 | key: ${{ runner.os }}-${{ runner.arch }}-cargo-${{ hashFiles('**/Cargo.lock') }} -------------------------------------------------------------------------------- /.github/actions/set-build-profile/action.yml: -------------------------------------------------------------------------------- 1 | name: Set Build Profile 2 | description: Set build profile to include debug checks for dev, alpha, and beta tags. 3 | runs: 4 | using: "composite" 5 | steps: 6 | - shell: bash 7 | run: | 8 | TAG=${{ inputs.tag }} 9 | LOWERTAG=$(echo "$TAG" | tr '[:upper:]' '[:lower:]') 10 | if [[ "$LOWERTAG" == *dev* || "$LOWERTAG" == *alpha* || "$LOWERTAG" == *beta* ]]; then 11 | BUILD_PROFILE=opt-test 12 | IS_RELEASE=false 13 | else 14 | BUILD_PROFILE=release-dbgsymbols 15 | IS_RELEASE=true 16 | fi 17 | echo "BUILD_PROFILE=$BUILD_PROFILE" >> $GITHUB_ENV 18 | echo "IS_RELEASE=$IS_RELEASE" >> $GITHUB_ENV 19 | inputs: 20 | tag: 21 | required: true 22 | type: string 23 | -------------------------------------------------------------------------------- /.github/actions/windows-codesign/action.yml: -------------------------------------------------------------------------------- 1 | name: Codesign with Microsoft Trusted Signing 2 | description: Sign Windows files with Microsoft Trusted Signing Service 3 | runs: 4 | using: "composite" 5 | steps: 6 | - uses: azure/trusted-signing-action@v0 7 | with: 8 | azure-tenant-id: ${{ inputs.azure_tenant_id }} 9 | azure-client-id: ${{ inputs.azure_client_id }} 10 | azure-client-secret: ${{ inputs.azure_client_secret }} 11 | endpoint: https://eus.codesigning.azure.net/ 12 | trusted-signing-account-name: tsa-huggingface-apps 13 | certificate-profile-name: git-xet-windows 14 | files: ${{ inputs.file }} 15 | file-digest: SHA256 16 | timestamp-rfc3161: http://timestamp.acs.microsoft.com 17 | timestamp-digest: SHA256 18 | exclude-environment-credential: false 19 | exclude-workload-identity-credential: true 20 | exclude-managed-identity-credential: true 21 | exclude-shared-token-cache-credential: true 22 | exclude-visual-studio-credential: true 23 | exclude-visual-studio-code-credential: true 24 | exclude-azure-cli-credential: true 25 | exclude-azure-powershell-credential: true 26 | exclude-azure-developer-cli-credential: true 27 | exclude-interactive-browser-credential: true 28 | inputs: 29 | file: 30 | required: true 31 | type: string 32 | azure_tenant_id: 33 | required: true 34 | type: string 35 | azure_client_id: 36 | required: true 37 | type: string 38 | azure_client_secret: 39 | required: true 40 | type: string -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: xet-core CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: 
true 12 | 13 | jobs: 14 | fmt: 15 | name: Rustfmt 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v4 19 | - uses: dtolnay/rust-toolchain@stable 20 | with: 21 | toolchain: nightly 22 | components: rustfmt 23 | - name: Format 24 | run: | 25 | cargo fmt --manifest-path ./Cargo.toml --all -- --check 26 | cargo fmt --manifest-path ./hf_xet/Cargo.toml --all -- --check 27 | 28 | build_and_test-linux: 29 | runs-on: ubuntu-latest 30 | steps: 31 | - name: Checkout repository 32 | uses: actions/checkout@v4 33 | - name: Install Rust 1.89 34 | uses: dtolnay/rust-toolchain@1.89.0 35 | with: 36 | components: clippy 37 | - uses: ./.github/actions/cache-rust-build 38 | - name: Lint 39 | run: | 40 | cargo clippy -r --verbose -- -D warnings # elevates warnings to errors 41 | cargo clippy -r --verbose --manifest-path hf_xet/Cargo.toml -- -D warnings # elevates warnings to errors 42 | - name: Set up Git LFS 43 | run: | 44 | curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash 45 | sudo apt-get install git-lfs 46 | git lfs install 47 | - name: Build and Test 48 | run: | 49 | cargo test --verbose --no-fail-fast --features "strict" 50 | - name: Check Cargo.lock has no uncommitted changes 51 | run: | 52 | # the build and test steps would update Cargo.lock if it is out of date 53 | test -z "$(git status --porcelain Cargo.lock)" || (echo "Cargo.lock has uncommitted changes!" && exit 1) 54 | build_and_test-win: 55 | runs-on: windows-latest 56 | steps: 57 | - name: Checkout repository 58 | uses: actions/checkout@v4 59 | - name: Install Rust 1.89 60 | uses: dtolnay/rust-toolchain@1.89.0 61 | with: 62 | components: clippy 63 | - uses: ./.github/actions/cache-rust-build 64 | - name: Build and Test 65 | run: | 66 | cargo test --verbose --no-fail-fast --features "strict" 67 | build_and_test-macos: 68 | runs-on: macos-latest 69 | steps: 70 | - name: Checkout repository 71 | uses: actions/checkout@v4 72 | - name: Install Rust 1.89 73 | uses: dtolnay/rust-toolchain@1.89.0 74 | with: 75 | components: clippy 76 | - name: Set up Git LFS 77 | run: | 78 | brew install git-lfs 79 | git lfs install 80 | - uses: ./.github/actions/cache-rust-build 81 | - name: Build and Test 82 | run: | 83 | cargo test --verbose --no-fail-fast --features "strict" 84 | build_and_test-wasm: 85 | name: Build WASM 86 | runs-on: ubuntu-latest 87 | steps: 88 | - name: Checkout repository 89 | uses: actions/checkout@v4 90 | - name: Install Rust nightly 91 | uses: dtolnay/rust-toolchain@nightly 92 | with: 93 | targets: wasm32-unknown-unknown 94 | components: rust-src 95 | - uses: ./.github/actions/cache-rust-build 96 | - name: Install wasm-bindgen-cli and wasm-pack 97 | run: | 98 | cargo install --version 0.2.100 wasm-bindgen-cli 99 | cargo install --version 0.13.1 wasm-pack 100 | - name: Build hf_xet_thin_wasm 101 | working-directory: hf_xet_thin_wasm 102 | run: | 103 | ./build_wasm.sh 104 | - name: Build hf_xet_wasm 105 | working-directory: hf_xet_wasm 106 | run: | 107 | ./build_wasm.sh 108 | -------------------------------------------------------------------------------- /.github/workflows/hf-xet-tests.yml: -------------------------------------------------------------------------------- 1 | name: Test huggingface_hub xet tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - '*' 9 | pull_request: 10 | workflow_dispatch: 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | hub-python-tests: 17 | runs-on: ubuntu-latest 18 | steps: 19 | # check out xet-core 20 | - uses: 
actions/checkout@v4 21 | # check out huggingface_hub 22 | - uses: actions/checkout@v4 23 | with: 24 | repository: huggingface/huggingface_hub 25 | path: huggingface_hub 26 | - uses: actions/setup-python@v5 27 | with: 28 | python-version: '3.10' 29 | - name: Create venv 30 | run: python3 -m venv .venv 31 | - name: Build wheel 32 | uses: PyO3/maturin-action@v1 33 | with: 34 | command: develop 35 | sccache: 'true' 36 | working-directory: hf_xet 37 | - name: Install huggingface_hub dependencies 38 | run: | 39 | source .venv/bin/activate 40 | python3 -m pip install -e 'huggingface_hub[testing]' 41 | - name: Run huggingface_hub xet tests 42 | run: | 43 | source .venv/bin/activate 44 | pytest huggingface_hub/tests/test_xet_*.py 45 | - name: Check Cargo.lock has no uncommitted changes 46 | run: | 47 | # the Build wheel step would update hf_xet/Cargo.lock if it is out of date 48 | test -z "$(git status --porcelain hf_xet/Cargo.lock)" || (echo "hf_xet/Cargo.lock has uncommitted changes!" && exit 1) -------------------------------------------------------------------------------- /.github/workflows/pre-release-testing.yml: -------------------------------------------------------------------------------- 1 | name: hf-xet prerelease testing 2 | # This workflow is triggered when a new pre-release build is triggered by the release workflow. 3 | 4 | on: 5 | push: 6 | tags: 7 | - v*rc* 8 | workflow_dispatch: 9 | inputs: 10 | tag: 11 | description: "Tag to test (e.g., v1.0.3-rc2)" 12 | required: true 13 | jobs: 14 | trigger_rc_testing: 15 | runs-on: ubuntu-latest 16 | 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | target-repo: ["huggingface_hub"] 21 | 22 | steps: 23 | - name: Determine PyPI version from tag 24 | id: get-version 25 | run: | 26 | if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then 27 | TAG=${{ inputs.tag }} 28 | else 29 | TAG=${GITHUB_REF#refs/tags/} 30 | fi 31 | SEM_VERSION=${TAG#v} 32 | TRIMMED_VERSION=${SEM_VERSION//-} 33 | echo "VERSION=${TRIMMED_VERSION}" >> $GITHUB_OUTPUT 34 | echo "BRANCH_NAME=ci_test_hf_xet_${TRIMMED_VERSION}_release" >> $GITHUB_OUTPUT 35 | 36 | - name: Checkout target repo 37 | uses: actions/checkout@v4 38 | with: 39 | repository: huggingface/${{ matrix.target-repo }} 40 | path: ${{ matrix.target-repo }} 41 | token: ${{ secrets.TOKEN_HUGGINGFACE_HUB_AUTO_BY_XET }} 42 | 43 | - name: Configure Git 44 | run: | 45 | cd ${{ matrix.target-repo }} 46 | git config user.name "Hugging Face Bot (Xet RC Testing)" 47 | git config user.email "bot+xet@huggingface.co" 48 | 49 | - name: Wait for prerelease to be out on PyPI 50 | run: | 51 | VERSION=${{ steps.get-version.outputs.VERSION }} 52 | echo "Waiting for hf_xet==${VERSION} to be available on PyPI" 53 | while ! 
pip install hf_xet==${VERSION}; do 54 | echo "hf_xet==${VERSION} not available yet, retrying in 15s" 55 | sleep 15 56 | done 57 | 58 | - name: Create test branch and update dependencies 59 | id: create-pr 60 | run: | 61 | cd ${{ matrix.target-repo }} 62 | VERSION=${{ steps.get-version.outputs.VERSION }} 63 | BRANCH_NAME=${{ steps.get-version.outputs.BRANCH_NAME }} 64 | 65 | # Create and checkout new branch 66 | git checkout -b $BRANCH_NAME 67 | 68 | # Update hf_xet dependency to use the fixed rc version 69 | sed -i -E "s/hf(_|-)xet(>|<|=)=(([0-9]+\.[0-9]+\.[0-9]+)|([0-9]+\.[0-9]+\.[0-9]+,<[0-9]+\.[0-9]+\.[0-9]+))/hf_xet==${VERSION}/g" setup.py 70 | git add setup.py 71 | 72 | # Any line with `uv pip install` in the `.github/` folder must be updated with the `--prerelease=allow` flag 73 | find .github/workflows/ -type f -exec sed -i 's/uv pip install /uv pip install --prerelease=allow /g' {} + 74 | git add .github/workflows/ 75 | 76 | # Commit and push changes 77 | git --no-pager diff --staged 78 | git commit -m "Test hfh ${VERSION}" 79 | git push --set-upstream origin $BRANCH_NAME 80 | 81 | - name: Print URLs for manual check 82 | run: | 83 | VERSION=${{ steps.get-version.outputs.VERSION }} 84 | BRANCH_NAME=${{ steps.get-version.outputs.BRANCH_NAME }} 85 | echo "https://github.com/huggingface/${{ matrix.target-repo }}/actions" 86 | echo "https://github.com/huggingface/${{matrix.target-repo}}/tree/refs/heads/${BRANCH_NAME}" 87 | echo "https://github.com/huggingface/${{ matrix.target-repo }}/compare/main...${BRANCH_NAME}" 88 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.idea 2 | # Generated by Cargo 3 | # will have compiled files and executables 4 | debug/ 5 | **/target/ 6 | 7 | # These are backup files generated by rustfmt 8 | **/*.rs.bk 9 | 10 | # MSVC Windows builds of rustc generate these, which store debugging information 11 | *.pdb 12 | 13 | # Mac OS Trash 14 | .DS_Store 15 | 16 | # VS Code configs 17 | .vscode/* 18 | !.vscode/settings.json 19 | venv 20 | **/*.env 21 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[rust]": { 3 | "editor.defaultFormatter": "rust-lang.rust-analyzer" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | resolver = "2" 3 | 4 | members = [ 5 | "cas_client", 6 | "cas_object", 7 | "cas_types", 8 | "chunk_cache", 9 | "data", 10 | "deduplication", 11 | "error_printer", 12 | "file_utils", 13 | "git_xet", 14 | "hub_client", 15 | "mdb_shard", 16 | "merklehash", 17 | "progress_tracking", 18 | "utils", 19 | "xet_runtime", 20 | ] 21 | 22 | exclude = ["chunk_cache_bench", "hf_xet", "hf_xet_wasm", "hf_xet_thin_wasm"] 23 | 24 | [profile.release] 25 | opt-level = 3 26 | lto = true 27 | debug = 1 28 | 29 | [profile.opt-test] 30 | inherits = "dev" 31 | opt-level = 3 32 | debug = 1 33 | 34 | [workspace.dependencies] 35 | anyhow = "1" 36 | async-trait = "0.1" 37 | base64 = "0.22" 38 | bincode = "1.3" 39 | bitflags = { version = "2.9", features = ["serde"] } 40 | blake3 = "1.5" 41 | bytes = "1.8" 42 | chrono = "0.4" 43 | clap = { version = "4", features = ["derive"] } 44 | colored = "2" 45 | countio = { 
version = "0.2", features = ["futures"] } 46 | crc32fast = "1.4" 47 | csv = "1" 48 | ctor = "0.4" 49 | derivative = "2.2" 50 | dirs = "5.0" 51 | duration-str = "0.17" 52 | futures = "0.3" 53 | futures-util = "0.3" 54 | gearhash = "0.1" 55 | getrandom = "0.3" 56 | git-url-parse = "0.4" 57 | git2 = "0.20" 58 | half = "2.4" 59 | heapify = "0.2" 60 | heed = "0.11" 61 | http = "1" 62 | hyper = "1.7" 63 | hyper-util = "0.1" 64 | itertools = "0.14" 65 | jsonwebtoken = "9.3" 66 | lazy_static = "1.5" 67 | libc = "0.2" 68 | lz4_flex = "0.11" 69 | mockall = "0.13" 70 | more-asserts = "0.3" 71 | once_cell = "1.20" 72 | oneshot = "0.1" 73 | openssh = "0.11" 74 | pin-project = "1" 75 | prometheus = "0.14" 76 | rand = "0.9" 77 | rand_chacha = "0.9" 78 | rayon = "1.5" 79 | regex = "1" 80 | reqwest = { version = "0.12", features = [ 81 | "json", 82 | "stream", 83 | "system-proxy", 84 | "socks" 85 | ], default-features = false } 86 | reqwest-middleware = "0.4" 87 | reqwest-retry = "0.7" 88 | rust-netrc = "0.1" 89 | rustc-hash = "1.1" 90 | safe-transmute = "0.11" 91 | serde = { version = "1", features = ["derive"] } 92 | serde_json = "1" 93 | serde_repr = "0.1" 94 | sha2 = "0.10" 95 | shellexpand = "3.1" 96 | static_assertions = "1.1" 97 | tempfile = "3.20" 98 | thiserror = "2.0" 99 | tokio = { version = "1.47" } 100 | tokio-retry = "0.3" 101 | tokio-util = { version = "0.7" } 102 | tower-service = "0.3" 103 | tracing = "0.1" 104 | ulid = "1.2" 105 | url = "2.5" 106 | urlencoding = "2.1" 107 | uuid = "1" 108 | walkdir = "2" 109 | web-time = "1.1" 110 | whoami = "1" 111 | 112 | # windows 113 | winapi = { version = "0.3", features = [ 114 | "winerror", 115 | "winnt", 116 | "handleapi", 117 | "processthreadsapi", 118 | "securitybaseapi", 119 | ] } 120 | 121 | # dev-deps 122 | criterion = { version = "0.5", features = ["html_reports"] } 123 | httpmock = "0.7" 124 | serial_test = "3" 125 | tempdir = "0.3" 126 | tracing-test = { version = "0.2", features = ["no-env-filter"] } 127 | wiremock = "0.6" 128 | -------------------------------------------------------------------------------- /cas_client/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "cas_client" 3 | version = "0.14.5" 4 | edition = "2024" 5 | 6 | 7 | [dependencies] 8 | cas_object = { path = "../cas_object" } 9 | cas_types = { path = "../cas_types" } 10 | chunk_cache = { path = "../chunk_cache" } 11 | deduplication = { path = "../deduplication" } 12 | error_printer = { path = "../error_printer" } 13 | file_utils = { path = "../file_utils" } 14 | mdb_shard = { path = "../mdb_shard" } 15 | merklehash = { path = "../merklehash" } 16 | progress_tracking = { path = "../progress_tracking" } 17 | utils = { path = "../utils" } 18 | xet_runtime = { path = "../xet_runtime" } 19 | 20 | anyhow = { workspace = true } 21 | async-trait = { workspace = true } 22 | bytes = { workspace = true } 23 | derivative = { workspace = true } 24 | futures = { workspace = true } 25 | http = { workspace = true } 26 | hyper-util = { workspace = true } 27 | lazy_static = { workspace = true } 28 | more-asserts = { workspace = true } 29 | reqwest = { workspace = true } 30 | reqwest-middleware = { workspace = true } 31 | reqwest-retry = { workspace = true } 32 | serde = { workspace = true } 33 | serde_json = { workspace = true } 34 | tempfile = { workspace = true } 35 | thiserror = { workspace = true } 36 | tokio = { workspace = true } 37 | tokio-retry = { workspace = true } 38 | tower-service = { workspace = true } 39 | 
tracing = { workspace = true } 40 | url = { workspace = true } 41 | 42 | # Reqwest -- uses different flags with different features. 43 | [features] 44 | strict = [] 45 | default = ["rustls-tls"] 46 | 47 | # Three options for compilation here. 48 | rustls-tls = [ 49 | "reqwest/rustls-tls", 50 | "reqwest/rustls-tls-webpki-roots", 51 | "reqwest/rustls-tls-native-roots", 52 | ] 53 | 54 | # rustls-tls uses the rustls package, which embeds all of the ssl stuff in a rust package. This is the 55 | # most portable option, but also may not respect local network configurations. Use this if the native-ssl options don't work. 56 | # Uses native tls in the reqwest package; this uses the native-tls package to wrap openssl, which is a more robust and portable 57 | # way of ensuring that tls just works. 58 | native-tls = ["reqwest/native-tls"] 59 | 60 | # This uses the above, but statically compiles in openssl, which makes the result more portable at the expense of 61 | # library size. 62 | native-tls-vendored = ["reqwest/native-tls-vendored"] 63 | 64 | 65 | [target.'cfg(not(target_family = "wasm"))'.dependencies] 66 | heed = { workspace = true } 67 | hyper = { workspace = true } 68 | 69 | [dev-dependencies] 70 | httpmock = { workspace = true } 71 | rand = { workspace = true } 72 | tracing-test = { workspace = true } 73 | wiremock = { workspace = true } 74 | -------------------------------------------------------------------------------- /cas_client/README.md: -------------------------------------------------------------------------------- 1 | # CAS client 2 | 3 | This package is responsible for handling all communication with the CAS services. 4 | 5 | ## Layout 6 | 7 | Check out the traits published by this crate to understand how it is intended to be used. 8 | These are stored in [interface.rs](../src/interface.rs). 9 | 10 | ### Main impl of Client trait 11 | 12 | - [remote_client.rs](../src/remote_client.rs): This is the main impl of Client, and is responsible for communicating with a remote CAS. 13 | - [local_client.rs](../src/local_client.rs): This is an impl of Client for local filesystem usage. It is only used for testing. 14 | 15 | ### Caching 16 | 17 | Caching happens locally using the [chunk_cache](../chunk_cache) crate, specifically using the ChunkCache. 18 | When RemoteClient is provided with a ChunkCache, it will use it on download calls (the ReconstructionClient trait `get_file`). 19 | 20 | ## Overall CAS Communication Design 21 | 22 | ### Authentication 23 | 24 | Authentication is done using AuthMiddleware, which sets an Authorization Header and refreshes it periodically with CAS. 25 | See [http_client.rs](../src/http_client.rs). 26 | 27 | ### Retry 28 | 29 | HTTP operations are retried using a RetryPolicy defined in [http_client.rs](../src/http_client.rs). 30 | This is implemented as Middleware for the reqwest HTTP clients. 31 | 32 | ### Operations 33 | 34 | CAS offers a set of services used by the client to upload and download user files. 35 | These files are stored using two different types of storage objects, Xorbs and Shards. 36 | Xorbs contain chunks and Shards contain mappings of files to Xorb chunks. 37 | 38 | ### Logging / Tracing 39 | 40 | Logging & Tracing is done through the tracing crate, with `info!`, `warn!`, and `debug!` macros widely used in the code. 41 | 42 | ### Progress tracking 43 | 44 | To enable progress updates, pass `Some(Arc<SingleItemProgressUpdater>)` to `Client::get_file(...)` when downloading and `Some(Arc<CompletionTracker>)` to `Client::upload_xorb(...)` when uploading. 
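For example, a single-file download with progress reporting might look like the following sketch (it assumes you already hold something implementing `Client`, e.g. a configured `RemoteClient`, and an already-constructed `SingleItemProgressUpdater`):

```rust
use std::path::PathBuf;
use std::sync::Arc;

use cas_client::{CasClientError, Client, FileProvider, OutputProvider};
use merklehash::MerkleHash;
use progress_tracking::item_tracking::SingleItemProgressUpdater;

/// Sketch: reconstruct the file identified by `hash` into `dest`,
/// reporting byte-level progress through `updater`.
async fn download_one(
    client: &dyn Client,
    hash: &MerkleHash,
    dest: PathBuf,
    updater: Arc<SingleItemProgressUpdater>,
) -> Result<u64, CasClientError> {
    // Write the reconstructed file directly to disk.
    let output = OutputProvider::File(FileProvider::new(dest));

    // `None` for the byte range fetches the whole file; the return value
    // is the number of bytes written.
    client.get_file(hash, None, &output, Some(updater)).await
}
```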
45 | -------------------------------------------------------------------------------- /cas_client/src/constants.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | utils::configurable_constants! { 4 | 5 | /// Retry at most this many times before permanently failing. 6 | ref CLIENT_RETRY_MAX_ATTEMPTS : usize = 5; 7 | 8 | /// On errors that can be retried, delay for this amount of time 9 | /// before retrying. 10 | ref CLIENT_RETRY_BASE_DELAY : Duration = Duration::from_millis(3000); 11 | 12 | /// After this much time has passed since the first attempt, 13 | /// no more retries are attempted. 14 | ref CLIENT_RETRY_MAX_DURATION: Duration = Duration::from_secs(6 * 60); 15 | 16 | /// Cleanup idle connections that are unused for this amount of time. 17 | ref CLIENT_IDLE_CONNECTION_TIMEOUT: Duration = Duration::from_secs(60); 18 | 19 | /// Keep no more than this number of idle connections in the connection pool. 20 | ref CLIENT_MAX_IDLE_CONNECTIONS: usize = 16; 21 | } 22 | -------------------------------------------------------------------------------- /cas_client/src/exports.rs: -------------------------------------------------------------------------------- 1 | // Re-export this with the current configurations 2 | pub use reqwest; 3 | pub use reqwest_middleware::ClientWithMiddleware; 4 | -------------------------------------------------------------------------------- /cas_client/src/interface.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::sync::Arc; 3 | 4 | use bytes::Bytes; 5 | use cas_object::SerializedCasObject; 6 | use cas_types::FileRange; 7 | use mdb_shard::file_structs::MDBFileInfo; 8 | use merklehash::MerkleHash; 9 | use progress_tracking::item_tracking::SingleItemProgressUpdater; 10 | use progress_tracking::upload_tracking::CompletionTracker; 11 | 12 | #[cfg(not(target_family = "wasm"))] 13 | use crate::OutputProvider; 14 | use crate::error::Result; 15 | 16 | /// A Client to the Shard service. The shard service 17 | /// provides for 18 | /// 1. upload shard to the shard service 19 | /// 2. querying of file->reconstruction information 20 | /// 3. querying of chunk->shard information 21 | #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] 22 | #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] 23 | pub trait Client { 24 | /// Get an entire file by file hash with an optional bytes range. 25 | /// 26 | /// The http_client passed in is a non-authenticated client. This is used to directly communicate 27 | /// with the backing store (S3) to retrieve xorbs. 28 | #[cfg(not(target_family = "wasm"))] 29 | async fn get_file( 30 | &self, 31 | hash: &MerkleHash, 32 | byte_range: Option<FileRange>, 33 | output_provider: &OutputProvider, 34 | progress_updater: Option<Arc<SingleItemProgressUpdater>>, 35 | ) -> Result<u64>; 36 | 37 | #[cfg(not(target_family = "wasm"))] 38 | async fn batch_get_file(&self, files: HashMap<MerkleHash, &OutputProvider>) -> Result<u64> { 39 | let mut n_bytes = 0; 40 | // Provide the basic naive implementation as a default. 41 | for (h, w) in files { 42 | n_bytes += self.get_file(&h, None, w, None).await?; 43 | } 44 | Ok(n_bytes) 45 | } 46 | 47 | async fn get_file_reconstruction_info( 48 | &self, 49 | file_hash: &MerkleHash, 50 | ) -> Result<Option<(MDBFileInfo, Option<MerkleHash>)>>; 51 | 52 | async fn query_for_global_dedup_shard(&self, prefix: &str, chunk_hash: &MerkleHash) -> Result<Option<Bytes>>; 53 | 54 | /// Upload a new shard. 
55 | async fn upload_shard(&self, shard_data: Bytes) -> Result<bool>; 56 | 57 | /// Upload a new xorb. 58 | async fn upload_xorb( 59 | &self, 60 | prefix: &str, 61 | serialized_cas_object: SerializedCasObject, 62 | upload_tracker: Option<Arc<CompletionTracker>>, 63 | ) -> Result<u64>; 64 | 65 | /// Indicates if the serialized cas object should have a written footer. 66 | /// This should only be true for testing with LocalClient. 67 | fn use_xorb_footer(&self) -> bool; 68 | 69 | /// Indicates if the serialized shard should have a written footer. 70 | /// This should only be true for testing with LocalClient. 71 | fn use_shard_footer(&self) -> bool; 72 | } 73 | -------------------------------------------------------------------------------- /cas_client/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | 3 | pub use chunk_cache::{CHUNK_CACHE_SIZE_BYTES, CacheConfig}; 4 | pub use http_client::{Api, ResponseErrorLogger, RetryConfig, build_auth_http_client, build_http_client}; 5 | pub use interface::Client; 6 | #[cfg(not(target_family = "wasm"))] 7 | pub use local_client::LocalClient; 8 | #[cfg(not(target_family = "wasm"))] 9 | pub use output_provider::{FileProvider, OutputProvider}; 10 | pub use remote_client::RemoteClient; 11 | 12 | pub use crate::error::CasClientError; 13 | 14 | mod constants; 15 | #[cfg(not(target_family = "wasm"))] 16 | mod download_utils; 17 | mod error; 18 | pub mod exports; 19 | mod http_client; 20 | mod interface; 21 | #[cfg(not(target_family = "wasm"))] 22 | mod local_client; 23 | #[cfg(not(target_family = "wasm"))] 24 | mod output_provider; 25 | pub mod remote_client; 26 | mod retry_wrapper; 27 | mod upload_progress_stream; 28 | -------------------------------------------------------------------------------- /cas_client/src/output_provider.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Cursor, Seek, SeekFrom, Write}; 2 | use std::path::PathBuf; 3 | use std::sync::{Arc, Mutex}; 4 | 5 | use crate::error::Result; 6 | 7 | /// Enum of different output formats to write reconstructed files. 8 | #[derive(Debug, Clone)] 9 | pub enum OutputProvider { 10 | File(FileProvider), 11 | #[cfg(test)] 12 | Buffer(BufferProvider), 13 | } 14 | 15 | impl OutputProvider { 16 | /// Create a new writer to start writing at the indicated start location. 
17 | pub(crate) fn get_writer_at(&self, start: u64) -> Result<Box<dyn Write + Send>> { 18 | match self { 19 | OutputProvider::File(fp) => fp.get_writer_at(start), 20 | #[cfg(test)] 21 | OutputProvider::Buffer(bp) => bp.get_writer_at(start), 22 | } 23 | } 24 | } 25 | 26 | /// Provides new Writers to a file located at a particular location 27 | #[derive(Debug, Clone)] 28 | pub struct FileProvider { 29 | filename: PathBuf, 30 | } 31 | 32 | impl FileProvider { 33 | pub fn new(filename: PathBuf) -> Self { 34 | Self { filename } 35 | } 36 | 37 | fn get_writer_at(&self, start: u64) -> Result<Box<dyn Write + Send>> { 38 | let mut file = std::fs::OpenOptions::new() 39 | .write(true) 40 | .truncate(false) 41 | .create(true) 42 | .open(&self.filename)?; 43 | file.seek(SeekFrom::Start(start))?; 44 | Ok(Box::new(file)) 45 | } 46 | } 47 | 48 | #[derive(Debug, Default, Clone)] 49 | pub struct BufferProvider { 50 | pub buf: ThreadSafeBuffer, 51 | } 52 | 53 | impl BufferProvider { 54 | pub fn get_writer_at(&self, start: u64) -> crate::error::Result<Box<dyn Write + Send>> { 55 | let mut buffer = self.buf.clone(); 56 | buffer.idx = start; 57 | Ok(Box::new(buffer)) 58 | } 59 | } 60 | 61 | #[derive(Debug, Default, Clone)] 62 | /// Thread-safe in-memory buffer that implements the [Write](Write) trait at some position 63 | /// within an underlying buffer and allows access to the inner buffer. 64 | /// Clones share the same underlying buffer, so writers created at different 65 | /// offsets all write into the same in-memory data. 66 | pub struct ThreadSafeBuffer { 67 | idx: u64, 68 | inner: Arc<Mutex<Cursor<Vec<u8>>>>, 69 | } 70 | impl ThreadSafeBuffer { 71 | pub fn value(&self) -> Vec<u8> { 72 | self.inner.lock().unwrap().get_ref().clone() 73 | } 74 | } 75 | 76 | impl std::io::Write for ThreadSafeBuffer { 77 | fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { 78 | let mut guard = self.inner.lock().map_err(|e| std::io::Error::other(format!("{e}")))?; 79 | guard.set_position(self.idx); 80 | let num_written = guard.write(buf)?; 81 | self.idx = guard.position(); 82 | Ok(num_written) 83 | } 84 | 85 | fn flush(&mut self) -> std::io::Result<()> { 86 | Ok(()) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /cas_object/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "cas_object" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [[bench]] 7 | name = "compression_bench" 8 | harness = false 9 | bench = true 10 | 11 | [[bench]] 12 | name = "bg_split_regroup_bench" 13 | harness = false 14 | bench = true 15 | 16 | [dependencies] 17 | deduplication = { path = "../deduplication" } 18 | error_printer = { path = "../error_printer" } 19 | mdb_shard = { path = "../mdb_shard" } 20 | merklehash = { path = "../merklehash" } 21 | utils = { path = "../utils" } 22 | 23 | anyhow = { workspace = true } 24 | blake3 = { workspace = true } 25 | bytes = { workspace = true } 26 | clap = { workspace = true } 27 | countio = { workspace = true } 28 | csv = { workspace = true } 29 | futures = { workspace = true } 30 | half = { workspace = true } 31 | lz4_flex = { workspace = true } 32 | more-asserts = { workspace = true } 33 | rand = { workspace = true } 34 | serde = { workspace = true } 35 | thiserror = { workspace = true } 36 | tokio = { workspace = true, features = ["time", "rt", "macros", "io-util"] } 37 | tokio-util = { workspace = true, features = ["io"] } 38 | tracing = { workspace = true } 39 | 40 | [target.'cfg(not(target_family = "wasm"))'.dependencies] 41 | tokio = { workspace = true, features = [ 42 | "time", 43 | "rt", 44 | 
"macros", 45 | "io-util", 46 | "rt-multi-thread", 47 | ] } 48 | 49 | [[bin]] 50 | path = "src/byte_grouping/compression_stats/collect_compression_stats.rs" 51 | name = "collect_compression_stats" 52 | 53 | [[bin]] 54 | path = "src/byte_grouping/bg4_prediction_benchmark.rs" 55 | name = "bg4_prediction_benchmark" 56 | -------------------------------------------------------------------------------- /cas_object/src/byte_grouping/bg4_prediction_benchmark.rs: -------------------------------------------------------------------------------- 1 | use std::time::Instant; 2 | 3 | use cas_object::byte_grouping::BG4Predictor; 4 | use rand::prelude::*; 5 | 6 | fn main() { 7 | const SIZE_MB: usize = 100; 8 | const SIZE: usize = SIZE_MB * 1024 * 1024; 9 | 10 | let mut rng = StdRng::seed_from_u64(12345); 11 | let mut data = vec![0u8; SIZE]; 12 | rng.fill_bytes(&mut data); 13 | 14 | let offset = 0; 15 | 16 | let mut ref_pred = BG4Predictor::default(); 17 | let start = Instant::now(); 18 | ref_pred.add_data_reference(offset, &data); 19 | let duration = start.elapsed().as_secs_f64(); 20 | println!("Reference: {:.2} MB/s", SIZE_MB as f64 / duration); 21 | 22 | let mut ref_v1 = BG4Predictor::default(); 23 | let start = Instant::now(); 24 | ref_v1.add_data_v1(offset, &data); 25 | let duration = start.elapsed().as_secs_f64(); 26 | println!("V1: {:.2} MB/s", SIZE_MB as f64 / duration); 27 | 28 | let mut ref_swar = BG4Predictor::default(); 29 | let start = Instant::now(); 30 | ref_swar.add_data_swar(offset, &data); 31 | let duration = start.elapsed().as_secs_f64(); 32 | println!("SWAR: {:.2} MB/s", SIZE_MB as f64 / duration); 33 | 34 | let mut new_pred = BG4Predictor::default(); 35 | let start = Instant::now(); 36 | new_pred.add_data(offset, &data); 37 | let duration = start.elapsed().as_secs_f64(); 38 | println!("Optimized: {:.2} MB/s", SIZE_MB as f64 / duration); 39 | 40 | assert_eq!(ref_pred.histograms(), new_pred.histograms()); 41 | } 42 | -------------------------------------------------------------------------------- /cas_object/src/byte_grouping/compression_stats/compression_prediction_tests.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | from sklearn.linear_model import LogisticRegression 4 | from sklearn.model_selection import train_test_split 5 | from sklearn.metrics import classification_report, accuracy_score, confusion_matrix 6 | 7 | # TO USE: 8 | # 9 | # python compression_prediction_tests.py compression_stats.csv 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser(description="Train a logistic regression model to predict when BG4 is >10% better.") 13 | parser.add_argument("csv_file", type=str, help="Path to the CSV file with block analysis results.") 14 | args = parser.parse_args() 15 | 16 | # 1. Load data 17 | df = pd.read_csv(args.csv_file) 18 | 19 | # 2. Create the target: is size_scheme_2 at least 10% smaller than size_scheme_1? 20 | # Equivalently, is size_scheme_2 < 0.9 * size_scheme_1? 21 | # We'll define a 'big_improvement' column: 1 if BG4 is >10% better, else 0. 
22 | df["improvement"] = ( 23 | df["size_scheme_2"] < df["size_scheme_1"] 24 | ).astype(int) 25 | 26 | df["big_improvement"] = ( 27 | df["size_scheme_2"] < 0.95 * df["size_scheme_1"] 28 | ).astype(int) 29 | 30 | df["possible_improvement"] = ( 31 | df["size_scheme_2"] < 1.05 * df["size_scheme_1"] 32 | ).astype(int) 33 | 34 | slice_cols = ["slice_0_entropy", "slice_1_entropy", "slice_2_entropy", "slice_3_entropy"] 35 | 36 | df["min_slice_entropy"] = df[slice_cols].min(axis=1) - df["full_entropy"] 37 | df["max_slice_entropy"] = df[slice_cols].max(axis=1) - df["full_entropy"] 38 | 39 | slice_cols = ["slice_0_kl", "slice_1_kl", "slice_2_kl", "slice_3_kl"] 40 | 41 | df["max_kl"] = df[slice_cols].max(axis=1) 42 | 43 | # 3. Select features to use. This one is simply to use the maximum kl divergence as the 44 | # sole feature. 45 | features = [ 46 | "max_kl", 47 | ] 48 | 49 | X = df[features] 50 | y = df["improvement"] 51 | 52 | 53 | # 4. Train Logistic Regression 54 | clf = LogisticRegression(max_iter=1000, random_state=42) 55 | clf.fit(X, y) 56 | 57 | # 5. Evaluate. Here, because it is a super simple model, just use the whole data to evaluate it instead 58 | # of bothering with a train/test split. 59 | y_pred = clf.predict(X) 60 | 61 | # This rule is what the learned regression above uses. 62 | # y_pred = (X["max_kl"] > 0.02).astype(int) 63 | 64 | accuracy = accuracy_score(y, y_pred) 65 | print("Accuracy on full data:", accuracy) 66 | print(classification_report(y, y_pred)) 67 | 68 | print("\nAccuracy on big improvement:", accuracy_score(df["big_improvement"], y_pred)) 69 | print(classification_report(df["big_improvement"], y_pred)) 70 | 71 | cm = confusion_matrix(df["big_improvement"], y_pred) 72 | tn, fp, fn, tp = cm.ravel() 73 | 74 | print(f"Incorrectly predicted {fn} / {fn + tp} of cases where bg4_lz4_size < 0.95 * lz4_size") 75 | 76 | 77 | print("\nAccuracy on possible improvement:", accuracy_score(df["possible_improvement"], y_pred)) 78 | print(classification_report(df["possible_improvement"], y_pred)) 79 | 80 | cm = confusion_matrix(df["possible_improvement"], y_pred) 81 | tn, fp, fn, tp = cm.ravel() 82 | 83 | print(f"Incorrectly predicted {fp} / {fp + tn} of cases where bg4_lz4_size > 1.05 * lz4_size") 84 | 85 | print("Coefficients:", clf.coef_) 86 | print("intercept:", clf.intercept_) 87 | 88 | 89 | if __name__ == "__main__": 90 | main() 91 | 92 | -------------------------------------------------------------------------------- /cas_object/src/byte_grouping/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod bg4; 2 | mod bg4_prediction; 3 | 4 | pub use bg4_prediction::BG4Predictor; 5 | -------------------------------------------------------------------------------- /cas_object/src/constants.rs: -------------------------------------------------------------------------------- 1 | use utils::configurable_constants; 2 | 3 | configurable_constants! { 4 | /// How often should we retest the compression scheme? 5 | /// Determining the optimal compression scheme takes time, but 6 | /// it also minimizes the storage costs of the data. 7 | /// 8 | /// If set to zero, it's set once per file block per xorb. 
9 | ref CAS_OBJECT_COMPRESSION_SCHEME_RETEST_INTERVAL : usize = 32; 10 | 11 | /// Target 1024 chunks per CAS block 12 | ref IDEAL_CAS_BLOCK_SIZE: usize = release_fixed(64 * 1024 * 1024); 13 | } 14 | -------------------------------------------------------------------------------- /cas_object/src/error.rs: -------------------------------------------------------------------------------- 1 | use std::convert::Infallible; 2 | 3 | use thiserror::Error; 4 | use tracing::warn; 5 | 6 | #[non_exhaustive] 7 | #[derive(Error, Debug)] 8 | pub enum CasObjectError { 9 | #[error("Invalid Range Read")] 10 | InvalidRange, 11 | 12 | #[error("Invalid Arguments")] 13 | InvalidArguments, 14 | 15 | #[error("Format Error: {0}")] 16 | FormatError(anyhow::Error), 17 | 18 | #[error("Hash Mismatch")] 19 | HashMismatch, 20 | 21 | #[error("Internal IO Error: {0}")] 22 | InternalIOError(#[from] std::io::Error), 23 | 24 | #[error("Other Internal Error: {0}")] 25 | InternalError(anyhow::Error), 26 | 27 | #[error("(De)Compression Error: {0}")] 28 | CompressionError(#[from] lz4_flex::frame::Error), 29 | 30 | #[error("Internal Hash Parsing Error")] 31 | HashParsingError(#[from] Infallible), 32 | 33 | #[error("ChunkHeaderParseErrorFooterIdent")] 34 | ChunkHeaderParseErrorFooterIdent, 35 | } 36 | 37 | // Define our own result type here (this seems to be the standard). 38 | pub type Result = std::result::Result; 39 | 40 | impl PartialEq for CasObjectError { 41 | fn eq(&self, other: &CasObjectError) -> bool { 42 | std::mem::discriminant(self) == std::mem::discriminant(other) 43 | } 44 | } 45 | 46 | /// Helper trait to swallow CAS object format errors. Used in object 47 | /// validation to reject the object instead of propagating errors. 48 | pub trait Validate { 49 | fn ok_for_format_error(self) -> Result>; 50 | } 51 | 52 | impl Validate for Result { 53 | fn ok_for_format_error(self) -> Result> { 54 | match self { 55 | Ok(v) => Ok(Some(v)), 56 | Err(CasObjectError::FormatError(e)) => { 57 | warn!("XORB Validation: {e}"); 58 | Ok(None) 59 | }, 60 | Err(e) => Err(e), 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /cas_object/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod byte_grouping; 2 | mod cas_chunk_format; 3 | mod cas_object_format; 4 | mod compression_scheme; 5 | pub mod constants; 6 | pub mod error; 7 | 8 | pub use cas_chunk_format::*; 9 | pub use cas_object_format::*; 10 | pub use compression_scheme::*; 11 | -------------------------------------------------------------------------------- /cas_types/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "cas_types" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | merklehash = { path = "../merklehash" } 8 | 9 | serde = { workspace = true } 10 | serde_repr = { workspace = true } 11 | thiserror = { workspace = true } 12 | -------------------------------------------------------------------------------- /cas_types/README.md: -------------------------------------------------------------------------------- 1 | # cas_types 2 | 3 | This crate provides type definitions for common types and formats used throughout the system. 
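For instance, the `Key` type pairs a prefix with a `MerkleHash` and round-trips through its `prefix/hash` string form (a minimal sketch; it assumes `Key` is re-exported at the crate root, and the all-zero hash is purely illustrative):

```rust
use cas_types::Key;

fn main() {
    // A Key is rendered as "<prefix>/<64-hex-char hash>".
    let s = "default/0000000000000000000000000000000000000000000000000000000000000000";

    // FromStr splits on the last '/' and parses the hex hash.
    let key: Key = s.parse().expect("valid key string");
    assert_eq!(key.prefix, "default");

    // Display re-serializes to the same "prefix/hash" form.
    assert_eq!(key.to_string(), s);
}
```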
4 | 5 | ## Features 6 | 7 | - Core type definitions for content addressing and data structures 8 | - Hex format serialization support for `HexMerkleHash` and `HexKey` types 9 | - Consistent type interfaces across the codebase 10 | 11 | ## Usage 12 | 13 | Add this crate as a dependency to access standardized type definitions and serialization formats. 14 | 15 | ```toml 16 | [dependencies] 17 | cas_types = { path = "../cas_types" } 18 | ``` 19 | -------------------------------------------------------------------------------- /cas_types/src/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[non_exhaustive] 4 | #[derive(Error, Debug)] 5 | pub enum CasTypesError { 6 | #[error("Invalid key: {0}")] 7 | InvalidKey(String), 8 | } 9 | -------------------------------------------------------------------------------- /cas_types/src/key.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Display, Formatter}; 2 | use std::str::FromStr; 3 | 4 | use merklehash::MerkleHash; 5 | use merklehash::data_hash::hex; 6 | use serde::{Deserialize, Serialize}; 7 | 8 | use crate::error::CasTypesError; 9 | 10 | /// A Key indicates a prefixed merkle hash for some data stored in the CAS DB. 11 | #[derive(Debug, PartialEq, Default, Serialize, Deserialize, Ord, PartialOrd, Eq, Hash, Clone)] 12 | pub struct Key { 13 | pub prefix: String, 14 | pub hash: MerkleHash, 15 | } 16 | 17 | impl Display for Key { 18 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 19 | write!(f, "{}/{:x}", self.prefix, self.hash) 20 | } 21 | } 22 | 23 | impl FromStr for Key { 24 | type Err = CasTypesError; 25 | 26 | fn from_str(s: &str) -> Result { 27 | let parts = s.rsplit_once('/'); 28 | let Some((prefix, hash)) = parts else { 29 | return Err(CasTypesError::InvalidKey(s.to_owned())); 30 | }; 31 | 32 | let hash = MerkleHash::from_hex(hash).map_err(|_| CasTypesError::InvalidKey(s.to_owned()))?; 33 | 34 | Ok(Key { 35 | prefix: prefix.to_owned(), 36 | hash, 37 | }) 38 | } 39 | } 40 | 41 | #[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq, Hash)] 42 | pub struct HexMerkleHash(#[serde(with = "hex::serde")] pub MerkleHash); 43 | 44 | impl Display for HexMerkleHash { 45 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 46 | write!(f, "{}", self.0.hex()) 47 | } 48 | } 49 | 50 | impl From for HexMerkleHash { 51 | fn from(value: MerkleHash) -> Self { 52 | HexMerkleHash(value) 53 | } 54 | } 55 | 56 | impl From for MerkleHash { 57 | fn from(value: HexMerkleHash) -> Self { 58 | value.0 59 | } 60 | } 61 | 62 | impl From<&HexMerkleHash> for MerkleHash { 63 | fn from(value: &HexMerkleHash) -> Self { 64 | value.0 65 | } 66 | } 67 | 68 | impl From<&MerkleHash> for HexMerkleHash { 69 | fn from(value: &MerkleHash) -> Self { 70 | HexMerkleHash(*value) 71 | } 72 | } 73 | 74 | #[derive(Debug, Clone, Serialize, Deserialize, Default, Hash, PartialEq, Eq)] 75 | pub struct HexKey { 76 | pub prefix: String, 77 | #[serde(with = "hex::serde")] 78 | pub hash: MerkleHash, 79 | } 80 | 81 | impl From for Key { 82 | fn from(HexKey { prefix, hash }: HexKey) -> Self { 83 | Key { prefix, hash } 84 | } 85 | } 86 | 87 | impl From for HexKey { 88 | fn from(Key { prefix, hash }: Key) -> Self { 89 | HexKey { prefix, hash } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /chunk_cache/.gitignore: 
-------------------------------------------------------------------------------- 1 | flamegraph.svg 2 | -------------------------------------------------------------------------------- /chunk_cache/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "chunk_cache" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | cas_types = { path = "../cas_types" } 8 | error_printer = { path = "../error_printer" } 9 | file_utils = { path = "../file_utils" } 10 | merklehash = { path = "../merklehash" } 11 | utils = { path = "../utils" } 12 | 13 | async-trait = { workspace = true } 14 | base64 = { workspace = true } 15 | clap = { workspace = true, optional = true } 16 | crc32fast = { workspace = true } 17 | mockall = { workspace = true } 18 | once_cell = { workspace = true } 19 | rand = { workspace = true } 20 | thiserror = { workspace = true } 21 | tokio = { workspace = true } 22 | tracing = { workspace = true } 23 | 24 | [dev-dependencies] 25 | tempdir = { workspace = true } 26 | tokio = { workspace = true, features = ["rt-multi-thread"] } 27 | 28 | [[bin]] 29 | name = "cache_analysis" 30 | path = "./src/bin/analysis.rs" 31 | required-features = ["analysis"] 32 | 33 | [features] 34 | analysis = ["dep:clap"] 35 | -------------------------------------------------------------------------------- /chunk_cache/src/bin/analysis.rs: -------------------------------------------------------------------------------- 1 | use std::path::PathBuf; 2 | use std::u64; 3 | 4 | use chunk_cache::{CacheConfig, DiskCache}; 5 | use clap::Parser; 6 | 7 | #[derive(Debug, Parser)] 8 | struct CacheAnalysisArgs { 9 | #[clap(long, short, default_value = "./xet/cache")] 10 | root: PathBuf, 11 | } 12 | 13 | /// Usage: ./cache_analysis --root "path to cache root" 14 | /// prints out the state of the cache 15 | fn main() { 16 | let args = CacheAnalysisArgs::parse(); 17 | print_main(args.root); 18 | } 19 | 20 | fn print_main(root: PathBuf) { 21 | let cache = DiskCache::initialize(&CacheConfig { 22 | cache_directory: root, 23 | cache_size: u64::MAX, 24 | }) 25 | .unwrap(); 26 | cache.print(); 27 | } 28 | -------------------------------------------------------------------------------- /chunk_cache/src/cache_manager.rs: -------------------------------------------------------------------------------- 1 | use std::cell::RefCell; 2 | use std::collections::HashMap; 3 | use std::path::PathBuf; 4 | use std::sync::{Arc, Mutex, Weak}; 5 | 6 | use once_cell::sync::Lazy; 7 | 8 | use crate::error::ChunkCacheError; 9 | use crate::{CacheConfig, ChunkCache, DiskCache}; 10 | 11 | // single instance of CACHE_MANAGER not exposed to outside users that 12 | // dedupes cache instances based on configurations 13 | static CACHE_MANAGER: Lazy = Lazy::new(CacheManager::new); 14 | 15 | /// get_cache attempts to return a cache given the provided config parameter 16 | pub fn get_cache(config: &CacheConfig) -> Result, ChunkCacheError> { 17 | CACHE_MANAGER.get(config) 18 | } 19 | 20 | struct CacheManager { 21 | vals: Mutex>>>, 22 | } 23 | 24 | impl CacheManager { 25 | fn new() -> Self { 26 | Self { 27 | vals: Mutex::new(HashMap::new()), 28 | } 29 | } 30 | 31 | /// get takes a CacheConfig and checks if there exists a valid `DiskCache` with a matching 32 | /// cache_directory then it will return an Arc to that `DiskCache` instance. 
If it doesn't exist 33 | /// or the `DiskCache` instance has been deallocated (CacheManager only holds a weak pointer) 34 | /// then it creates a new instance based on the provided config. 35 | fn get(&self, config: &CacheConfig) -> Result, ChunkCacheError> { 36 | let mut vals = self.vals.lock()?; 37 | if let Some(v) = vals.get_mut(&config.cache_directory) { 38 | let weak = v.borrow().clone(); 39 | // if upgrade from Weak to Arc is successful, returns the upgraded pointer 40 | if let Some(value) = weak.upgrade() { 41 | return Ok(value); 42 | } 43 | // since upgrading failed, creates a new DiskCache, replaces the weak pointer with a 44 | // weak pointer to the new instance and then returns the Arc to the new cache instance 45 | let result: Arc = Arc::new(DiskCache::initialize(config)?); 46 | v.replace(Arc::downgrade(&result)); 47 | Ok(result) 48 | } else { 49 | // create a new Cache and insert weak pointer to managed map 50 | let result: Arc = Arc::new(DiskCache::initialize(config)?); 51 | vals.insert(config.cache_directory.clone(), RefCell::new(Arc::downgrade(&result))); 52 | Ok(result) 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /chunk_cache/src/disk/cache_file_header.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Read, Seek, Write}; 2 | use std::mem::size_of; 3 | 4 | use utils::serialization_utils::{read_u32, write_u32, write_u32s}; 5 | 6 | use crate::error::ChunkCacheError; 7 | 8 | /// Header for every cache file, it is simple to deserialize and serialize 9 | /// All numbers are unsigned 32 bit little endian integers 10 | /// 11 | /// format: 12 | /// (chunk_byte_indices length n) 13 | /// ( 14 | /// chunk_byte_indices[0] 15 | /// chunk_byte_indices[1] 16 | /// chunk_byte_indices[2] 17 | /// ... 
-------------------------------------------------------------------------------- /chunk_cache/src/disk/cache_file_header.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Read, Seek, Write}; 2 | use std::mem::size_of; 3 | 4 | use utils::serialization_utils::{read_u32, write_u32, write_u32s}; 5 | 6 | use crate::error::ChunkCacheError; 7 | 8 | /// Header for every cache file; it is simple to serialize and deserialize. 9 | /// All numbers are unsigned 32 bit little endian integers. 10 | /// 11 | /// format: 12 | /// (chunk_byte_indices length n) 13 | /// ( 14 | /// chunk_byte_indices[0] 15 | /// chunk_byte_indices[1] 16 | /// chunk_byte_indices[2] 17 | /// ... 18 | /// chunk_byte_indices[n - 1] 19 | /// ) 20 | pub struct CacheFileHeader { 21 | pub chunk_byte_indices: Vec<u32>, 22 | } 23 | 24 | impl CacheFileHeader { 25 | pub fn new<T: Into<Vec<u32>>>(chunk_byte_indices: T) -> Self { 26 | let chunk_byte_indices = chunk_byte_indices.into(); 27 | Self { chunk_byte_indices } 28 | } 29 | 30 | pub fn header_len(&self) -> usize { 31 | (self.chunk_byte_indices.len() + 1) * size_of::<u32>() 32 | } 33 | 34 | pub fn deserialize<R: Read + Seek>(reader: &mut R) -> Result<Self, ChunkCacheError> { 35 | reader.seek(std::io::SeekFrom::Start(0))?; 36 | let chunk_byte_indices_len = read_u32(reader)?; 37 | let mut chunk_byte_indices: Vec<u32> = Vec::with_capacity(chunk_byte_indices_len as usize); 38 | for i in 0..chunk_byte_indices_len { 39 | let idx = read_u32(reader)?; 40 | if i == 0 && idx != 0 { 41 | return Err(ChunkCacheError::parse("first byte index isn't 0")); 42 | } else if !chunk_byte_indices.is_empty() && chunk_byte_indices.last().unwrap() >= &idx { 43 | return Err(ChunkCacheError::parse("chunk byte indices are not strictly increasing")); 44 | } 45 | chunk_byte_indices.push(idx); 46 | } 47 | 48 | Ok(Self::new(chunk_byte_indices)) 49 | } 50 | 51 | pub fn serialize<W: Write>(&self, writer: &mut W) -> Result<(), std::io::Error> { 52 | write_u32(writer, self.chunk_byte_indices.len() as u32)?; 53 | write_u32s(writer, &self.chunk_byte_indices)?; 54 | Ok(()) 55 | } 56 | } 57 |
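// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original file): a serialize /
// deserialize round trip over an in-memory cursor, exercising the
// [n, idx_0 .. idx_{n-1}] little-endian layout documented above.
#[cfg(test)]
mod cache_file_header_roundtrip_sketch {
    use std::io::Cursor;

    use super::*;

    #[test]
    fn roundtrip() {
        let header = CacheFileHeader::new(vec![0u32, 2000, 4000, 6000]);
        let mut buf = Cursor::new(Vec::new());
        header.serialize(&mut buf).unwrap();
        // The serialized form is the u32 length prefix plus one u32 per index.
        assert_eq!(buf.get_ref().len(), header.header_len());
        let parsed = CacheFileHeader::deserialize(&mut buf).unwrap();
        assert_eq!(parsed.chunk_byte_indices, header.chunk_byte_indices);
    }
}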
impl_parse_error_from_error { 49 | ($error_type:ty) => { 50 | impl From<$error_type> for ChunkCacheError { 51 | fn from(value: $error_type) -> Self { 52 | ChunkCacheError::parse(value) 53 | } 54 | } 55 | }; 56 | } 57 | 58 | impl_parse_error_from_error!(TryFromSliceError); 59 | impl_parse_error_from_error!(DecodeError); 60 | impl_parse_error_from_error!(DataHashBytesParseError); 61 | impl_parse_error_from_error!(Utf8Error); 62 | -------------------------------------------------------------------------------- /chunk_cache/src/lib.rs: -------------------------------------------------------------------------------- 1 | mod cache_manager; 2 | mod disk; 3 | pub mod error; 4 | 5 | use std::path::PathBuf; 6 | 7 | use async_trait::async_trait; 8 | pub use cache_manager::get_cache; 9 | use cas_types::{ChunkRange, Key}; 10 | pub use disk::DiskCache; 11 | pub use disk::test_utils::*; 12 | use error::ChunkCacheError; 13 | use mockall::automock; 14 | 15 | pub use crate::disk::DEFAULT_CHUNK_CACHE_CAPACITY; 16 | 17 | utils::configurable_constants! { 18 | ref CHUNK_CACHE_SIZE_BYTES: u64 = DEFAULT_CHUNK_CACHE_CAPACITY; 19 | } 20 | 21 | /// Return dto for cache gets 22 | /// offsets has 1 more than then number of chunks in the specified range 23 | /// suppose the range is for chunks [2, 5) then offsets may look like: 24 | /// [0, 2000, 4000, 6000] where chunk 2 is made of bytes [0, 2000) 25 | /// chunk 3 [2000, 4000) and chunk 4 is [4000, 6000). 26 | /// It is guaranteed that the first number in offsets is 0 and the last number is data.len() 27 | #[derive(Debug)] 28 | pub struct CacheRange { 29 | pub offsets: Vec, 30 | pub data: Vec, 31 | pub range: ChunkRange, 32 | } 33 | 34 | /// ChunkCache is a trait for storing and fetching Xorb ranges. 35 | /// implementors are expected to return bytes for a key and a given chunk range 36 | /// (no compression or further deserialization should be required) 37 | /// Range inputs use chunk indices in an end exclusive way i.e. [start, end) 38 | /// 39 | /// implementors are allowed to evict data, a get after a put is not required to 40 | /// be a cache hit. 41 | #[automock] 42 | #[async_trait] 43 | pub trait ChunkCache: Sync + Send { 44 | /// get should return an Ok() variant if significant error occurred, check the error 45 | /// variant for issues with IO or parsing contents etc. 46 | /// 47 | /// if get returns an Ok(None) then there was no error, but there was a cache miss 48 | /// otherwise returns an Ok(Some(data)) where data matches exactly the bytes for 49 | /// the requested key and the requested chunk index range for that key 50 | /// 51 | /// Given implementors are expected to be able to evict members there's no guarantee 52 | /// that a previously put range will be a cache hit 53 | /// 54 | /// key is required to be a valid CAS Key 55 | /// range is intended to be an index range within the xorb with constraint 56 | /// 0 <= range.start < range.end <= num_chunks_in_xorb(key) 57 | async fn get(&self, key: &Key, range: &ChunkRange) -> Result, ChunkCacheError>; 58 | 59 | /// put should return Ok(()) if the put succeeded with no error, check the error 60 | /// variant for issues with validating the input, cache state, IO, etc. 61 | /// 62 | /// put expects that chunk_byte_indices.len() is range.end - range.start + 1 63 | /// with 1 entry for each start byte index for [range.start, range.end] 64 | /// the first entry must be 0 (start of first chunk in the data) 65 | /// the last entry must be data.len() i.e. 
-------------------------------------------------------------------------------- /chunk_cache_bench/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "chunk_cache_bench" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | cas_types = { path = "../cas_types" } 8 | chunk_cache = { path = "../chunk_cache" } 9 | 10 | async-trait = "0.1" 11 | base64 = "0.22" 12 | clap = { version = "4", features = ["derive"] } 13 | r2d2 = "0.8.10" 14 | r2d2_postgres = "0.18.1" 15 | sccache = "0.8" 16 | tempdir = "0.3" 17 | tokio = { version = "1.44", features = ["full"] } 18 | 19 | [[bench]] 20 | name = "cache_bench" 21 | harness = false 22 | bench = true 23 | 24 | # To run: ./cache_resilience_test parent 25 | [[bin]] 26 | name = "cache_resilience_test" 27 | 28 | [dev-dependencies] 29 | criterion = { version = "0.4", features = ["async_tokio"] } 30 | rand = "0.8" 31 | -------------------------------------------------------------------------------- /chunk_cache_bench/src/lib.rs: -------------------------------------------------------------------------------- 1 | use std::path::PathBuf; 2 | 3 | use chunk_cache::error::ChunkCacheError; 4 | use chunk_cache::{CacheConfig, DiskCache}; 5 | 6 | pub mod sccache; 7 | pub mod solid_cache; 8 | 9 | /// only used for benchmark code 10 | pub trait ChunkCacheExt: chunk_cache::ChunkCache + Sized + Clone { 11 | fn _initialize(cache_root: PathBuf, capacity: u64) -> Result<Self, ChunkCacheError>; 12 | fn name() -> &'static str; 13 | } 14 | 15 | impl ChunkCacheExt for chunk_cache::DiskCache { 16 | fn _initialize(cache_root: PathBuf, capacity: u64) -> Result<Self, ChunkCacheError> { 17 | let config = CacheConfig { 18 | cache_directory: cache_root, 19 | cache_size: capacity, 20 | }; 21 | DiskCache::initialize(&config) 22 | } 23 | 24 | fn name() -> &'static str { 25 | "disk" 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /chunk_cache_bench/src/sccache.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::{OsStr, OsString}; 2 | use std::io::{Read, Write}; 3 | use std::os::unix::ffi::OsStringExt; 4 | use std::path::PathBuf; 5 | use std::sync::{Arc, Mutex}; 6 | 7 | use base64::Engine; 8 | use cas_types::{ChunkRange, Key}; 9 | use chunk_cache::error::ChunkCacheError; 10 | use chunk_cache::{CacheRange, ChunkCache}; 11 | use sccache::lru_disk_cache::LruDiskCache; 12 | 13 | use crate::ChunkCacheExt; 14 | 15 | #[derive(Clone)] 16 | pub struct SCCache { 17 | cache: Arc<Mutex<LruDiskCache>>, 18 | } 19 | 20 | impl ChunkCacheExt for SCCache { 21 | fn _initialize(cache_root: PathBuf, capacity: u64) -> Result<Self, ChunkCacheError> { 22 | let cache = LruDiskCache::new(cache_root, capacity).map_err(ChunkCacheError::general)?; 23 | 24 | Ok(Self { 25 | cache: 
Arc::new(Mutex::new(cache)), 26 | }) 27 | } 28 | 29 | fn name() -> &'static str { 30 | "sccache" 31 | } 32 | } 33 | 34 | #[async_trait::async_trait] 35 | impl ChunkCache for SCCache { 36 | async fn get( 37 | &self, 38 | key: &cas_types::Key, 39 | range: &cas_types::ChunkRange, 40 | ) -> Result<Option<CacheRange>, ChunkCacheError> { 41 | let cache_key = CacheKey::new(key, range)?; 42 | let mut file = if let Ok(file) = self.cache.lock()?.get(&cache_key) { 43 | file 44 | } else { 45 | return Ok(None); 46 | }; 47 | 48 | let mut res = Vec::new(); 49 | file.read_to_end(&mut res)?; 50 | Ok(Some(CacheRange { 51 | offsets: (range.start..=range.end).collect::<Vec<u32>>().into(), 52 | data: res.into(), 53 | range: range.clone(), 54 | })) 55 | } 56 | 57 | async fn put( 58 | &self, 59 | key: &cas_types::Key, 60 | range: &cas_types::ChunkRange, 61 | _chunk_byte_indices: &[u32], 62 | data: &[u8], 63 | ) -> Result<(), ChunkCacheError> { 64 | let mut cache = self.cache.lock()?; 65 | 66 | let cache_key = CacheKey::new(key, range)?; 67 | if cache.get(&cache_key).is_ok() { 68 | return Ok(()); 69 | } 70 | 71 | cache.insert_bytes(cache_key, data).map_err(ChunkCacheError::general)?; 72 | Ok(()) 73 | } 74 | } 75 | 76 | #[derive(Debug)] 77 | struct CacheKey(OsString); 78 | 79 | impl CacheKey { 80 | fn new(key: &Key, range: &ChunkRange) -> Result<Self, ChunkCacheError> { 81 | let mut buf = Vec::new(); 82 | buf.write_all(key.hash.as_bytes())?; 83 | buf.write_all(key.prefix.as_bytes())?; 84 | buf.write_all(format!("{}_{}", range.start, range.end).as_bytes())?; 85 | let result = base64::engine::general_purpose::URL_SAFE.encode(buf).as_bytes().to_vec(); 86 | Ok(CacheKey(OsString::from_vec(result))) 87 | } 88 | } 89 | 90 | impl AsRef<OsStr> for CacheKey { 91 | fn as_ref(&self) -> &OsStr { 92 | &self.0 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /data/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "data" 3 | version = "0.14.5" 4 | edition = "2024" 5 | 6 | [lib] 7 | doctest = false 8 | 9 | [[bin]] 10 | name = "x" 11 | path = "src/bin/example.rs" 12 | 13 | [[bin]] 14 | name = "xtool" 15 | path = "src/bin/xtool.rs" 16 | 17 | [[example]] 18 | name = "chunk" 19 | path = "examples/chunk/main.rs" 20 | 21 | [[example]] 22 | name = "hash" 23 | path = "examples/hash/main.rs" 24 | 25 | [[example]] 26 | name = "xorb-check" 27 | path = "examples/xorb-check/main.rs" 28 | 29 | [dependencies] 30 | cas_client = { path = "../cas_client" } 31 | cas_object = { path = "../cas_object" } 32 | cas_types = { path = "../cas_types" } 33 | deduplication = { path = "../deduplication" } 34 | error_printer = { path = "../error_printer" } 35 | hub_client = { path = "../hub_client" } 36 | mdb_shard = { path = "../mdb_shard" } 37 | merklehash = { path = "../merklehash" } 38 | progress_tracking = { path = "../progress_tracking" } 39 | utils = { path = "../utils" } 40 | xet_runtime = { path = "../xet_runtime" } 41 | 42 | anyhow = { workspace = true } 43 | async-trait = { workspace = true } 44 | bytes = { workspace = true } 45 | chrono = { workspace = true } 46 | clap = { workspace = true } 47 | dirs = { workspace = true } 48 | jsonwebtoken = { workspace = true } 49 | lazy_static = { workspace = true } 50 | more-asserts = { workspace = true } 51 | prometheus = { workspace = true } 52 | rand = { workspace = true } 53 | rand_chacha = { workspace = true } 54 | regex = { workspace = true } 55 | serde = { workspace = true } 56 | serde_json = { workspace = true } 57 | tempfile = { workspace = 
true } 58 | thiserror = { workspace = true } 59 | tokio = { workspace = true, features = ["rt-multi-thread", "rt"] } 60 | tracing = { workspace = true } 61 | ulid = { workspace = true } 62 | walkdir = { workspace = true } 63 | 64 | # Windows doesn't support assembly for compilation 65 | [target.'cfg(not(target_os = "windows"))'.dependencies] 66 | sha2 = { workspace = true, features = ["asm"] } 67 | 68 | [target.'cfg(target_os = "windows")'.dependencies] 69 | sha2 = { workspace = true } 70 | 71 | [dev-dependencies] 72 | serial_test = { workspace = true } 73 | tracing-test = { workspace = true } 74 | ctor = { workspace = true } 75 | 76 | [features] 77 | strict = [] 78 | expensive_tests = [] 79 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # data 2 | 3 | A high-level data translation layer for Xet's content-addressable storage (CAS). This crate handles: 4 | 5 | - **Cleaning (uploading)** regular files into deduplicated CAS objects (xorbs + shards) and producing lightweight pointers (`XetFileInfo`). 6 | - **Smudging (downloading)** pointer metadata back into materialized files. 7 | 8 | ## Core APIs 9 | 10 | - **High-level async functions** in `data::data_client`: 11 | - `upload_async(file_paths, endpoint, token_info, token_refresher, progress_updater) -> Vec` 12 | - `download_async(files: Vec<(XetFileInfo, String)>, endpoint, token_info, token_refresher, progress_updaters) -> Vec` 13 | 14 | - **Sessions and primitives** (re-exported at the crate root): 15 | - `FileUploadSession` – multi-file, deduplicated upload session. Handles chunking, xorb/shard production, and finalization. 16 | - `FileDownloader` – smudges files from CAS given a `MerkleHash`/`XetFileInfo`. 17 | - `XetFileInfo` – compact pointer describing a file by its hash and size. 18 | 19 | Both high-level functions create sensible defaults (cache paths, progress aggregation, endpoint separation) via `data_client::default_config` and enforce bounded concurrency. 20 | 21 | ## How `hf_xet` uses this crate 22 | 23 | The `hf_xet` Python extension exposes thin wrappers around these async functions and types. In `hf_xet/src/lib.rs`: 24 | 25 | - `upload_files(...)` calls `data::data_client::upload_async`. 26 | - `upload_bytes(...)` calls `data::data_client::upload_bytes_async`. 27 | - `download_files(...)` calls `data::data_client::download_async`. 28 | -------------------------------------------------------------------------------- /data/examples/chunk/main.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::{BufWriter, Read, Write}; 3 | use std::path::PathBuf; 4 | 5 | use clap::Parser; 6 | use deduplication::Chunker; 7 | use deduplication::constants::TARGET_CHUNK_SIZE; 8 | 9 | #[derive(Debug, Parser)] 10 | #[command( 11 | version, 12 | about, 13 | long_about = "Example of using the chunker. Splits the input file or stdin into chunks and writes to stdout or the specified file the chunk hash in string format and the chunk size on a new line for each chunk in order in the file" 14 | )] 15 | struct ChunkArgs { 16 | /// Input file or uses stdin if not specified. 
17 | #[arg(short, long)] 18 | input: Option<PathBuf>, 19 | /// Output file or uses stdout if not specified, where to write the chunk information 20 | #[arg(short, long)] 21 | output: Option<PathBuf>, 22 | } 23 | 24 | fn main() { 25 | let args = ChunkArgs::parse(); 26 | 27 | // setup content reader 28 | let mut input: Box<dyn Read> = if let Some(file_path) = args.input { 29 | Box::new(File::open(file_path).unwrap()) 30 | } else { 31 | Box::new(std::io::stdin()) 32 | }; 33 | 34 | // set up writer to output chunks information 35 | let mut output: Box<dyn Write> = if let Some(save) = args.output { 36 | Box::new(BufWriter::new(File::create(save).unwrap())) 37 | } else { 38 | Box::new(std::io::stdout()) 39 | }; 40 | 41 | let mut chunker = Chunker::new(*TARGET_CHUNK_SIZE); 42 | 43 | // read input in up to 8 MB sections and pass through chunker 44 | const INGESTION_BLOCK_SIZE: usize = 8 * 1024 * 1024; // 8 MiB 45 | let mut buf = vec![0u8; INGESTION_BLOCK_SIZE]; 46 | loop { 47 | let num_read = input.read(&mut buf).unwrap(); 48 | if num_read == 0 { 49 | break; 50 | } 51 | let chunks = chunker.next_block(&buf[..num_read], false); 52 | for chunk in chunks { 53 | output 54 | .write_all(format!("{} {}\n", chunk.hash, chunk.data.len()).as_bytes()) 55 | .unwrap(); 56 | } 57 | } 58 | if let Some(chunk) = chunker.finish() { 59 | output 60 | .write_all(format!("{} {}\n", chunk.hash, chunk.data.len()).as_bytes()) 61 | .unwrap(); 62 | } 63 | output.flush().unwrap(); 64 | } 65 | -------------------------------------------------------------------------------- /data/examples/xorb-check/main.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::{BufReader, BufWriter, Read, Write}; 3 | use std::path::PathBuf; 4 | 5 | use clap::Parser; 6 | use merklehash::{MerkleHash, compute_data_hash, xorb_hash}; 7 | use utils::output_bytes; 8 | 9 | #[derive(Debug, Parser)] 10 | struct XorbCheckArgs { 11 | /// Input file or uses stdin if not specified. 
Expects xorb format object (with no footer) 12 | #[arg(short, long)] 13 | input: Option<PathBuf>, 14 | /// Specific hash to check that the xorb hash is equal to, optional, can use --hash-from-path to parse a hash from 15 | /// the input file path or ignore the check altogether to just compute the xorb hash 16 | #[arg(short, long)] 17 | hash: Option<String>, 18 | /// If true, tries to parse a hash from the first 64 characters of the file name in the path of the input file 19 | #[arg(long, conflicts_with = "hash")] 20 | hash_from_path: bool, 21 | /// Output file or uses stdout if not specified, where to write the chunk information 22 | #[arg(short, long)] 23 | output_chunks: Option<PathBuf>, 24 | /// If true, write the chunk information to stdout, if not set and output_chunks is not set, will not output the 25 | /// chunk information 26 | #[arg(long, conflicts_with = "output_chunks")] 27 | output_chunks_stdout: bool, 28 | } 29 | 30 | fn main() { 31 | let args = XorbCheckArgs::parse(); 32 | 33 | if args.hash_from_path && args.input.is_none() { 34 | panic!("--hash-from-path requires --input to be set"); 35 | } 36 | 37 | let mut provided_hash = None; 38 | if let Some(hash_str) = args.hash { 39 | provided_hash = Some(MerkleHash::from_hex(&hash_str).unwrap()) 40 | } else if args.hash_from_path { 41 | let mut path_hash = args.input.clone().unwrap().file_name().unwrap().to_str().unwrap().to_string(); 42 | path_hash.truncate(64); 43 | provided_hash = Some(MerkleHash::from_hex(&path_hash).unwrap()) 44 | } 45 | 46 | let mut input: Box<dyn Read> = match args.input { 47 | Some(path) => Box::new(BufReader::new(File::open(path).unwrap())), 48 | None => Box::new(std::io::stdin()), 49 | }; 50 | 51 | let (data, boundaries) = match cas_object::deserialize_chunks(&mut input) { 52 | Ok(chunks) => chunks, 53 | Err(e) => panic!("failed to deserialize xorb: {e}"), 54 | }; 55 | 56 | eprintln!( 57 | "Successfully deserialized xorb with {} chunks totalling {} Bytes ({})!", 58 | boundaries.len() - 1, 59 | data.len(), 60 | output_bytes(data.len() as u64) 61 | ); 62 | 63 | let mut chunk_hashes = Vec::with_capacity(boundaries.len() - 1); 64 | for (chunk_start, next_chunk_start) in boundaries.iter().take(boundaries.len() - 1).zip(boundaries.iter().skip(1)) { 65 | let chunk = &data[(*chunk_start as usize)..(*next_chunk_start as usize)]; 66 | let chunk_hash = compute_data_hash(chunk); 67 | chunk_hashes.push((chunk_hash, (next_chunk_start - chunk_start) as u64)); 68 | } 69 | 70 | let computed_xorb_hash = xorb_hash(&chunk_hashes); 71 | 72 | eprintln!("computed xorb hash: {computed_xorb_hash}"); 73 | 74 | if let Some(provided_hash) = provided_hash { 75 | if computed_xorb_hash != provided_hash { 76 | eprintln!("provided hash does not match computed hash!"); 77 | } else { 78 | eprintln!("provided hash matches computed hash!"); 79 | } 80 | } 81 | 82 | let mut chunks_writer: BufWriter<Box<dyn Write>> = 83 | BufWriter::new(match (args.output_chunks_stdout, args.output_chunks) { 84 | (true, _) => Box::new(std::io::stdout()), 85 | (false, Some(path)) => Box::new(File::create(path).unwrap()), 86 | (false, None) => { 87 | return; 88 | }, 89 | }); 90 | 91 | for (hash, size) in chunk_hashes { 92 | chunks_writer.write_all(format!("{hash} {size}\n").as_bytes()).unwrap(); 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /data/src/constants.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | utils::configurable_constants! 
{ 4 | 5 | // Approximately 4 MB min spacing between global dedup queries. Calculated by 4MB / TARGET_CHUNK_SIZE 6 | ref MIN_SPACING_BETWEEN_GLOBAL_DEDUP_QUERIES: usize = 256; 7 | 8 | /// scheme for a local filesystem based CAS server 9 | ref LOCAL_CAS_SCHEME: String = "local://".to_owned(); 10 | 11 | /// The current version 12 | ref CURRENT_VERSION: String = release_fixed( env!("CARGO_PKG_VERSION").to_owned()); 13 | 14 | /// The expiration time of a local shard when first placed in the local shard cache. Currently 15 | /// set to 3 weeks. 16 | ref MDB_SHARD_LOCAL_CACHE_EXPIRATION: Duration = Duration::from_secs(3 * 7 * 24 * 3600); 17 | 18 | /// The maximum number of simultaneous xorb upload streams. 19 | /// can be overwritten by environment variable "HF_XET_MAX_CONCURRENT_UPLOADS". 20 | /// The default value changes from 8 to 100 when "High Performance Mode" is enabled 21 | ref MAX_CONCURRENT_UPLOADS: usize = GlobalConfigMode::HighPerformanceOption { 22 | standard: 8, 23 | high_performance: 100, 24 | }; 25 | 26 | /// The maximum number of files to ingest at once on the upload path 27 | ref MAX_CONCURRENT_FILE_INGESTION: usize = GlobalConfigMode::HighPerformanceOption { 28 | standard: 8, 29 | high_performance: 100, 30 | }; 31 | 32 | /// The maximum number of files to download at one time. 33 | ref MAX_CONCURRENT_DOWNLOADS : usize = GlobalConfigMode::HighPerformanceOption { 34 | standard: 8, 35 | high_performance: 100, 36 | }; 37 | 38 | /// The maximum block size from a file to process at once. 39 | ref INGESTION_BLOCK_SIZE : usize = 8 * 1024 * 1024; 40 | 41 | /// How often to send updates on file progress, in milliseconds. Disables batching 42 | /// if set to 0. 43 | ref PROGRESS_UPDATE_INTERVAL : Duration = Duration::from_millis(200); 44 | 45 | /// How large of a time window to use for aggregating the progress speed results. 46 | ref PROGRESS_UPDATE_SPEED_SAMPLING_WINDOW: Duration = Duration::from_millis(10 * 1000); 47 | 48 | 49 | /// How often do we flush new xorb data to disk on a long running upload session? 50 | ref SESSION_XORB_METADATA_FLUSH_INTERVAL : Duration = Duration::from_secs(20); 51 | 52 | /// Force a flush of the xorb metadata every this many xorbs, if more are created 53 | /// in this time window. 
54 | ref SESSION_XORB_METADATA_FLUSH_MAX_COUNT : usize = 64; 55 | 56 | 57 | } 58 | -------------------------------------------------------------------------------- /data/src/deduplication_interface.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use async_trait::async_trait; 4 | use deduplication::{DeduplicationDataInterface, RawXorbData}; 5 | use mdb_shard::file_structs::FileDataSequenceEntry; 6 | use merklehash::MerkleHash; 7 | use progress_tracking::upload_tracking::{CompletionTrackerFileId, FileXorbDependency}; 8 | use tokio::task::JoinSet; 9 | use tracing::Instrument; 10 | 11 | use crate::configurations::GlobalDedupPolicy; 12 | use crate::errors::Result; 13 | use crate::file_upload_session::FileUploadSession; 14 | 15 | pub struct UploadSessionDataManager { 16 | file_id: CompletionTrackerFileId, 17 | session: Arc<FileUploadSession>, 18 | active_global_dedup_queries: JoinSet<Result<bool>>, 19 | } 20 | 21 | impl UploadSessionDataManager { 22 | pub fn new(session: Arc<FileUploadSession>, file_id: CompletionTrackerFileId) -> Self { 23 | Self { 24 | file_id, 25 | session, 26 | active_global_dedup_queries: Default::default(), 27 | } 28 | } 29 | 30 | fn global_dedup_queries_enabled(&self) -> bool { 31 | matches!(self.session.config.shard_config.global_dedup_policy, GlobalDedupPolicy::Always) 32 | } 33 | } 34 | 35 | #[async_trait] 36 | impl DeduplicationDataInterface for UploadSessionDataManager { 37 | type ErrorType = crate::errors::DataProcessingError; 38 | 39 | /// Query for possible dedup matches against the chunks already known to this session. 40 | async fn chunk_hash_dedup_query( 41 | &self, 42 | query_hashes: &[MerkleHash], 43 | ) -> Result<Option<(usize, FileDataSequenceEntry)>> { 44 | Ok(self.session.shard_interface.chunk_hash_dedup_query(query_hashes).await?) 45 | } 46 | 47 | /// Registers a new query for more information about the 48 | /// global deduplication. This is expected to run in the background. 49 | async fn register_global_dedup_query(&mut self, chunk_hash: MerkleHash) -> Result<()> { 50 | if !self.global_dedup_queries_enabled() { 51 | return Ok(()); 52 | } 53 | 54 | // Now, query for a global dedup shard in the background to make sure that all the rest of this 55 | // can continue. 56 | let session: Arc<FileUploadSession> = self.session.clone(); 57 | 58 | self.active_global_dedup_queries.spawn( 59 | async move { 60 | session.shard_interface.query_dedup_shard_by_chunk(&chunk_hash).await?; 61 | 62 | Ok(true) 63 | } 64 | .instrument(tracing::info_span!("UploadSessionDataManager::dedup_task")), 65 | ); 66 | 67 | Ok(()) 68 | } 69 | 70 | /// Waits for all the current queries to complete, then returns true if there is 71 | /// new deduplication information available. 72 | async fn complete_global_dedup_queries(&mut self) -> Result<bool> { 73 | if !self.global_dedup_queries_enabled() { 74 | return Ok(false); 75 | } 76 | 77 | let mut any_result = false; 78 | while let Some(result) = self.active_global_dedup_queries.join_next().await { 79 | any_result |= result??; 80 | } 81 | Ok(any_result) 82 | } 83 | 84 | /// Registers a Xorb of new data that has no deduplication references. 85 | async fn register_new_xorb(&mut self, xorb: RawXorbData) -> Result<()> { 86 | // Begin the process for upload. 87 | self.session.register_new_xorb(xorb, &[]).await?; 88 | 89 | Ok(()) 90 | } 91 | 92 | /// Periodically registers xorb dependencies; used for progress tracking. 
93 | async fn register_xorb_dependencies(&mut self, dependencies: &[FileXorbDependency]) { 94 | self.session.register_xorb_dependencies(dependencies).await; 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /data/src/errors.rs: -------------------------------------------------------------------------------- 1 | use std::string::FromUtf8Error; 2 | use std::sync::mpsc::RecvError; 3 | 4 | use cas_client::CasClientError; 5 | use cas_object::error::CasObjectError; 6 | use mdb_shard::error::MDBShardError; 7 | use thiserror::Error; 8 | use tokio::sync::AcquireError; 9 | use tracing::error; 10 | use utils::errors::{AuthError, SingleflightError}; 11 | use xet_runtime::utils::ParutilsError; 12 | 13 | #[derive(Error, Debug)] 14 | pub enum DataProcessingError { 15 | #[error("File query policy configuration error: {0}")] 16 | FileQueryPolicyError(String), 17 | 18 | #[error("CAS configuration error: {0}")] 19 | CASConfigError(String), 20 | 21 | #[error("Shard configuration error: {0}")] 22 | ShardConfigError(String), 23 | 24 | #[error("Cache configuration error: {0}")] 25 | CacheConfigError(String), 26 | 27 | #[error("Deduplication configuration error: {0}")] 28 | DedupConfigError(String), 29 | 30 | #[error("Clean task error: {0}")] 31 | CleanTaskError(String), 32 | 33 | #[error("Upload task error: {0}")] 34 | UploadTaskError(String), 35 | 36 | #[error("Internal error : {0}")] 37 | InternalError(String), 38 | 39 | #[error("Synchronization error: {0}")] 40 | SyncError(String), 41 | 42 | #[error("Channel error: {0}")] 43 | ChannelRecvError(#[from] RecvError), 44 | 45 | #[error("MerkleDB Shard error: {0}")] 46 | MDBShardError(#[from] MDBShardError), 47 | 48 | #[error("CAS service error : {0}")] 49 | CasClientError(#[from] CasClientError), 50 | 51 | #[error("Xorb Serialization error : {0}")] 52 | XorbSerializationError(#[from] CasObjectError), 53 | 54 | #[error("Subtask scheduling error: {0}")] 55 | JoinError(#[from] tokio::task::JoinError), 56 | 57 | #[error("Non-small file not cleaned: {0}")] 58 | FileNotCleanedError(#[from] FromUtf8Error), 59 | 60 | #[error("I/O error: {0}")] 61 | IOError(#[from] std::io::Error), 62 | 63 | #[error("Hash not found")] 64 | HashNotFound, 65 | 66 | #[error("Parameter error: {0}")] 67 | ParameterError(String), 68 | 69 | #[error("Unable to parse string as hex hash value")] 70 | HashStringParsingFailure(#[from] merklehash::DataHashHexParseError), 71 | 72 | #[error("Deprecated feature: {0}")] 73 | DeprecatedError(String), 74 | 75 | #[error("AuthError: {0}")] 76 | AuthError(#[from] AuthError), 77 | 78 | #[error("Permit Acquisition Error: {0}")] 79 | PermitAcquisitionError(#[from] AcquireError), 80 | } 81 | 82 | pub type Result = std::result::Result; 83 | 84 | // Specific implementation for this one so that we can extract the internal error when appropriate 85 | impl From> for DataProcessingError { 86 | fn from(value: SingleflightError) -> Self { 87 | let msg = format!("{value:?}"); 88 | error!("{msg}"); 89 | match value { 90 | SingleflightError::InternalError(e) => e, 91 | _ => DataProcessingError::InternalError(format!("SingleflightError: {msg}")), 92 | } 93 | } 94 | } 95 | 96 | impl From> for DataProcessingError { 97 | fn from(value: ParutilsError) -> Self { 98 | match value { 99 | ParutilsError::Join(e) => DataProcessingError::JoinError(e), 100 | ParutilsError::Acquire(e) => DataProcessingError::PermitAcquisitionError(e), 101 | ParutilsError::Task(e) => e, 102 | e => DataProcessingError::InternalError(e.to_string()), 103 | } 
104 | } 105 | } -------------------------------------------------------------------------------- /data/src/file_downloader.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Cow; 2 | use std::sync::Arc; 3 | 4 | use cas_client::{Client, OutputProvider}; 5 | use cas_types::FileRange; 6 | use merklehash::MerkleHash; 7 | use progress_tracking::item_tracking::ItemProgressUpdater; 8 | use tracing::instrument; 9 | use ulid::Ulid; 10 | 11 | use crate::configurations::TranslatorConfig; 12 | use crate::errors::*; 13 | use crate::prometheus_metrics; 14 | use crate::remote_client_interface::create_remote_client; 15 | 16 | /// Manages the download of files based on a hash or pointer file. 17 | /// 18 | /// This struct handles the smudge (download) operations: given a file hash, it 19 | /// fetches the reconstruction metadata and the xorb ranges needed to 20 | /// materialize the file contents into the provided output. 21 | pub struct FileDownloader { 22 | /* ----- Configurations ----- */ 23 | config: Arc<TranslatorConfig>, 24 | client: Arc<dyn Client + Send + Sync>, 25 | } 26 | 27 | /// Smudge operations 28 | impl FileDownloader { 29 | pub async fn new(config: Arc<TranslatorConfig>) -> Result<Self> { 30 | let session_id = config 31 | .session_id 32 | .as_ref() 33 | .map(Cow::Borrowed) 34 | .unwrap_or_else(|| Cow::Owned(Ulid::new().to_string())); 35 | let client = create_remote_client(&config, &session_id, false)?; 36 | 37 | Ok(Self { config, client }) 38 | } 39 | 40 | #[instrument(skip_all, name = "FileDownloader::smudge_file_from_hash", fields(hash=file_id.hex()))] 41 | pub async fn smudge_file_from_hash( 42 | &self, 43 | file_id: &MerkleHash, 44 | file_name: Arc<str>, 45 | output: &OutputProvider, 46 | range: Option<FileRange>, 47 | progress_updater: Option<Arc<ItemProgressUpdater>>, 48 | ) -> Result<u64> { 49 | let file_progress_tracker = progress_updater.map(|p| ItemProgressUpdater::item_tracker(&p, file_name, None)); 50 | 51 | // Currently, this works by always directly querying the remote server. 52 | let n_bytes = self.client.get_file(file_id, range, output, file_progress_tracker).await?; 53 | 54 | prometheus_metrics::FILTER_BYTES_SMUDGED.inc_by(n_bytes); 55 | 56 | Ok(n_bytes) 57 | } 58 | }
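// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original file): downloading a
// single file by hash. The TranslatorConfig, OutputProvider, file hash, and
// expected size values are assumed to come from the surrounding application.
//
//     let downloader = FileDownloader::new(config).await?;
//     let n_bytes = downloader
//         .smudge_file_from_hash(&file_hash, Arc::from("model.bin"), &output, None, None)
//         .await?;
//     assert_eq!(n_bytes, expected_file_size);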
-------------------------------------------------------------------------------- /data/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | pub mod configurations; 3 | pub mod constants; 4 | pub mod data_client; 5 | mod deduplication_interface; 6 | pub mod errors; 7 | mod file_cleaner; 8 | mod file_downloader; 9 | mod file_upload_session; 10 | pub mod migration_tool; 11 | mod prometheus_metrics; 12 | mod remote_client_interface; 13 | mod sha256; 14 | mod shard_interface; 15 | mod xet_file; 16 | 17 | pub use cas_client::CacheConfig; 18 | // Reexport this one for now 19 | pub use deduplication::RawXorbData; 20 | pub use file_downloader::FileDownloader; 21 | pub use file_upload_session::FileUploadSession; 22 | pub use xet_file::XetFileInfo; 23 | 24 | #[cfg(debug_assertions)] 25 | pub mod test_utils; 26 | -------------------------------------------------------------------------------- /data/src/migration_tool/hub_client_token_refresher.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use hub_client::{HubClient, Operation}; 4 | use utils::auth::{TokenInfo, TokenRefresher}; 5 | use utils::errors::AuthError; 6 | 7 | pub struct HubClientTokenRefresher { 8 | pub operation: Operation, 9 | pub client: Arc<HubClient>, 10 | } 11 | 12 | #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] 13 | #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] 14 | impl TokenRefresher for HubClientTokenRefresher { 15 | async fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> { 16 | let jwt_info = self 17 | .client 18 | .get_cas_jwt(self.operation) 19 | .await 20 | .map_err(AuthError::token_refresh_failure)?; 21 | 22 | Ok((jwt_info.access_token, jwt_info.exp)) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /data/src/migration_tool/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod hub_client_token_refresher; 2 | pub mod migrate; 3 | -------------------------------------------------------------------------------- /data/src/prometheus_metrics.rs: -------------------------------------------------------------------------------- 1 | use lazy_static::lazy_static; 2 | use prometheus::{IntCounter, register_int_counter}; 3 | 4 | // Some of the common tracking things 5 | lazy_static! 
{ 6 | pub static ref FILTER_CAS_BYTES_PRODUCED: IntCounter = 7 | register_int_counter!("filter_process_cas_bytes_produced", "Number of CAS bytes produced during cleaning") 8 | .unwrap(); 9 | pub static ref FILTER_BYTES_CLEANED: IntCounter = 10 | register_int_counter!("filter_process_bytes_cleaned", "Number of bytes cleaned").unwrap(); 11 | pub static ref FILTER_BYTES_SMUDGED: IntCounter = 12 | register_int_counter!("filter_process_bytes_smudged", "Number of bytes smudged").unwrap(); 13 | } 14 | -------------------------------------------------------------------------------- /data/src/remote_client_interface.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | pub use cas_client::Client; 4 | use cas_client::RemoteClient; 5 | 6 | use crate::configurations::*; 7 | use crate::errors::Result; 8 | 9 | pub(crate) fn create_remote_client( 10 | config: &TranslatorConfig, 11 | session_id: &str, 12 | dry_run: bool, 13 | ) -> Result> { 14 | let cas_storage_config = &config.data_config; 15 | 16 | match cas_storage_config.endpoint { 17 | Endpoint::Server(ref endpoint) => Ok(Arc::new(RemoteClient::new( 18 | endpoint, 19 | &cas_storage_config.auth, 20 | &Some(cas_storage_config.cache_config.clone()), 21 | Some(config.shard_config.cache_directory.clone()), 22 | session_id, 23 | dry_run, 24 | ))), 25 | Endpoint::FileSystem(ref path) => { 26 | #[cfg(not(target_family = "wasm"))] 27 | { 28 | Ok(Arc::new(cas_client::LocalClient::new(path)?)) 29 | } 30 | #[cfg(target_family = "wasm")] 31 | unimplemented!("Local file system access is not supported in WASM builds") 32 | }, 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /data/src/sha256.rs: -------------------------------------------------------------------------------- 1 | use merklehash::MerkleHash; 2 | use sha2::{Digest, Sha256}; 3 | use tokio::task::{JoinError, JoinHandle}; 4 | 5 | /// Helper struct to generate a sha256 hash as a MerkleHash. 6 | #[derive(Debug)] 7 | pub struct ShaGenerator { 8 | hasher: Option>>, 9 | } 10 | 11 | impl ShaGenerator { 12 | pub fn new() -> Self { 13 | Self { hasher: None } 14 | } 15 | 16 | /// Complete the last block, then hand off the new chunks to the new hasher. 17 | pub async fn update(&mut self, new_data: impl AsRef<[u8]> + Send + Sync + 'static) -> Result<(), JoinError> { 18 | let mut hasher = match self.hasher.take() { 19 | Some(jh) => jh.await??, 20 | None => Sha256::default(), 21 | }; 22 | 23 | // The previous task returns the hasher; we consume that and pass it on. 24 | // Use the compute background thread for this process. 25 | self.hasher = Some(tokio::task::spawn_blocking(move || { 26 | hasher.update(&new_data); 27 | 28 | Ok(hasher) 29 | })); 30 | 31 | Ok(()) 32 | } 33 | 34 | /// Generates a sha256 from the current state of the variant. 35 | pub async fn finalize(mut self) -> Result { 36 | let current_state = self.hasher.take(); 37 | 38 | let hasher = match current_state { 39 | Some(jh) => jh.await??, 40 | None => return Ok(MerkleHash::default()), 41 | }; 42 | 43 | let sha256 = hasher.finalize(); 44 | let hex_str = format!("{sha256:x}"); 45 | Ok(MerkleHash::from_hex(&hex_str).expect("Converting sha256 to merklehash.")) 46 | } 47 | } 48 | 49 | #[cfg(test)] 50 | mod sha_tests { 51 | use rand::{Rng, rng}; 52 | 53 | use super::*; 54 | 55 | const TEST_DATA: &str = "some data"; 56 | 57 | // use `echo -n "..." 
| sha256sum` with the `TEST_DATA` contents to get the sha to compare against 58 | const TEST_SHA: &str = "1307990e6ba5ca145eb35e99182a9bec46531bc54ddf656a602c780fa0240dee"; 59 | 60 | #[tokio::test] 61 | async fn test_sha_generation_builder() { 62 | let mut sha_generator = ShaGenerator::new(); 63 | sha_generator.update(TEST_DATA.as_bytes()).await.unwrap(); 64 | let hash = sha_generator.finalize().await.unwrap(); 65 | 66 | assert_eq!(TEST_SHA.to_string(), hash.hex()); 67 | } 68 | 69 | #[tokio::test] 70 | async fn test_sha_generation_build_multiple_chunks() { 71 | let mut sha_generator = ShaGenerator::new(); 72 | let td = TEST_DATA.as_bytes(); 73 | sha_generator.update(&td[0..4]).await.unwrap(); 74 | sha_generator.update(&td[4..td.len()]).await.unwrap(); 75 | let hash = sha_generator.finalize().await.unwrap(); 76 | 77 | assert_eq!(TEST_SHA.to_string(), hash.hex()); 78 | } 79 | 80 | #[tokio::test] 81 | async fn test_sha_multiple_updates() { 82 | // Test multiple versions. 83 | 84 | // Generate 4096 bytes of random data 85 | let mut rand_data = [0u8; 4096]; 86 | rng().fill(&mut rand_data[..]); 87 | 88 | let mut sha_generator = ShaGenerator::new(); 89 | 90 | // Add in random chunks. 91 | let mut pos = 0; 92 | while pos < rand_data.len() { 93 | let l = rng().random_range(0..32); 94 | let next_pos = (pos + l).min(rand_data.len()); 95 | sha_generator.update(rand_data[pos..next_pos].to_vec()).await.unwrap(); 96 | pos = next_pos; 97 | } 98 | 99 | let out_hash = sha_generator.finalize().await.unwrap(); 100 | 101 | let ref_hash = format!("{:x}", Sha256::digest(rand_data)); 102 | 103 | assert_eq!(out_hash.hex(), ref_hash); 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /data/src/xet_file.rs: -------------------------------------------------------------------------------- 1 | use error_printer::ErrorPrinter; 2 | use merklehash::{DataHashHexParseError, MerkleHash}; 3 | use serde::{Deserialize, Serialize}; 4 | 5 | /// A struct that wraps the Xet file information. 6 | #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] 7 | pub struct XetFileInfo { 8 | /// The Merkle hash of the file 9 | hash: String, 10 | 11 | /// The size of the file 12 | file_size: u64, 13 | } 14 | 15 | impl XetFileInfo { 16 | /// Creates a new `XetFileInfo` instance. 17 | /// 18 | /// # Arguments 19 | /// 20 | /// * `hash` - The Xet hash of the file. This is a Merkle hash string. 21 | /// * `file_size` - The size of the file. 22 | pub fn new(hash: String, file_size: u64) -> Self { 23 | Self { hash, file_size } 24 | } 25 | 26 | /// Returns the Merkle hash of the file. 27 | pub fn hash(&self) -> &str { 28 | &self.hash 29 | } 30 | 31 | /// Returns the parsed merkle hash of the file. 32 | pub fn merkle_hash(&self) -> std::result::Result<MerkleHash, DataHashHexParseError> { 33 | MerkleHash::from_hex(&self.hash).log_error("Error parsing hash value for file info") 34 | } 35 | 36 | /// Returns the size of the file. 37 | pub fn file_size(&self) -> u64 { 38 | self.file_size 39 | } 40 | 41 | pub fn as_pointer_file(&self) -> std::result::Result<String, serde_json::Error> { 42 | serde_json::to_string(self) 43 | } 44 | } 45 |
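// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original file): round-tripping
// a pointer through its JSON form; the hash literal is an arbitrary 64-hex
// example value.
#[cfg(test)]
mod pointer_roundtrip_sketch {
    use super::*;

    #[test]
    fn roundtrip() {
        let info = XetFileInfo::new("aa".repeat(32), 4_621_684);
        let pointer = info.as_pointer_file().unwrap();
        let parsed: XetFileInfo = serde_json::from_str(&pointer).unwrap();
        assert_eq!(parsed, info);
        assert_eq!(parsed.merkle_hash().unwrap(), info.merkle_hash().unwrap());
    }
}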
"$SCRIPT_DIR/initialize.sh" 7 | 8 | # Test small binary file clean & smudge 9 | create_data_file small.dat 1452 10 | 11 | x clean -d small.pft small.dat 12 | assert_is_pointer_file small.pft 13 | assert_pointer_file_size small.pft 1452 14 | 15 | x smudge -f small.pft small.dat.2 16 | assert_files_equal small.dat small.dat.2 17 | 18 | # Test big binary file clean & smudge 19 | create_data_file large.dat 4621684 # 4.6 MB 20 | 21 | x clean -d large.pft large.dat 22 | assert_is_pointer_file large.pft 23 | assert_pointer_file_size large.pft 4621684 24 | 25 | x smudge -f large.pft large.dat.2 26 | assert_files_equal large.dat large.dat.2 27 | 28 | # Test small text file clean 29 | create_text_file small.txt key1 100 1 30 | 31 | x clean -d small.pft small.txt 32 | assert_is_pointer_file small.pft 33 | x smudge -f small.pft small.txt.2 34 | assert_files_equal small.txt small.txt.2 35 | -------------------------------------------------------------------------------- /deduplication/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "deduplication" 3 | version = "0.14.5" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | mdb_shard = { path = "../mdb_shard" } 8 | merklehash = { path = "../merklehash" } 9 | progress_tracking = { path = "../progress_tracking" } 10 | utils = { path = "../utils" } 11 | 12 | async-trait = { workspace = true } 13 | bytes = { workspace = true } 14 | gearhash = { workspace = true } 15 | more-asserts = { workspace = true } 16 | 17 | [dev-dependencies] 18 | rand = { workspace = true } 19 | -------------------------------------------------------------------------------- /deduplication/README.md: -------------------------------------------------------------------------------- 1 | # Deduplication crate 2 | 3 | This package contains components and functionality to create chunks from raw data and attempt to deduplicate chunks locally and globally. 4 | 5 | ## Notable exports 6 | 7 | - `Chunker`: Stateful chunk boundary detector and chunk producer. 8 | - `DataAggregator`: Builder that accumulates chunks and file-info into uploadable units (xorbs and shards). 9 | - `FileDeduper`: Orchestrator for per-file deduplication over a provided data interface. 10 | - `DeduplicationDataInterface`: Trait defining the data-access/upload interface required by the deduper. 11 | - `RawXorbData`: Container for a single upload unit (xorb) with its metadata. 12 | - `constants`: Public constants used to configure chunking and xorb limits. 13 | -------------------------------------------------------------------------------- /deduplication/src/chunk.rs: -------------------------------------------------------------------------------- 1 | use bytes::Bytes; 2 | use merklehash::{MerkleHash, compute_data_hash}; 3 | 4 | #[derive(Debug, Clone, PartialEq)] 5 | pub struct Chunk { 6 | pub hash: MerkleHash, 7 | pub data: Bytes, 8 | } 9 | 10 | impl Chunk { 11 | pub fn new(data: Bytes) -> Self { 12 | Chunk { 13 | hash: compute_data_hash(&data), 14 | data, 15 | } 16 | } 17 | } 18 | 19 | // Implement &[u8] dereferencing for the Chunk 20 | impl AsRef<[u8]> for Chunk { 21 | fn as_ref(&self) -> &[u8] { 22 | self.data.as_ref() 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /deduplication/src/constants.rs: -------------------------------------------------------------------------------- 1 | utils::configurable_constants! 
-------------------------------------------------------------------------------- /deduplication/src/constants.rs: -------------------------------------------------------------------------------- 1 | utils::configurable_constants! { 2 | 3 | /// This will target 1024 chunks per Xorb / CAS block 4 | ref TARGET_CHUNK_SIZE: usize = release_fixed(64 * 1024); 5 | 6 | /// TARGET_CHUNK_SIZE / MINIMUM_CHUNK_DIVISOR is the smallest chunk size. 7 | /// Note that this is not a threshold but a recommendation; 8 | /// smaller chunks can be produced if the size of a file is smaller than this number. 9 | ref MINIMUM_CHUNK_DIVISOR: usize = release_fixed(8); 10 | 11 | /// TARGET_CHUNK_SIZE * MAXIMUM_CHUNK_MULTIPLIER is the largest chunk size. 12 | /// Note that this is a limit. 13 | ref MAXIMUM_CHUNK_MULTIPLIER: usize = release_fixed(2); 14 | 15 | /// The maximum number of bytes to go in a single xorb. 16 | ref MAX_XORB_BYTES: usize = release_fixed(64 * 1024 * 1024); 17 | 18 | /// The maximum number of chunks to go in a single xorb. 19 | /// Chunks are targeted at 64K, for ~1024 chunks per xorb, but 20 | /// can be much higher when there are a lot of small files. 21 | ref MAX_XORB_CHUNKS: usize = 8 * 1024; 22 | } 23 | 24 | lazy_static! { 25 | /// The maximum chunk size, calculated from the configurable constants above 26 | pub static ref MAX_CHUNK_SIZE: usize = (*TARGET_CHUNK_SIZE) * (*MAXIMUM_CHUNK_MULTIPLIER); 27 | } 28 |
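// ---------------------------------------------------------------------------
// Editor's note (not part of the original file): with the release-fixed
// defaults above, the derived bounds work out to:
//   minimum chunk size = TARGET_CHUNK_SIZE / MINIMUM_CHUNK_DIVISOR = 64 KiB / 8 = 8 KiB
//   maximum chunk size = TARGET_CHUNK_SIZE * MAXIMUM_CHUNK_MULTIPLIER = 64 KiB * 2 = 128 KiB
//   xorb capacity      = 64 MiB of chunk data or 8192 chunks, whichever limit is hit first.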
To use the deduplication code, 11 | /// define a struct that implements these methods. This struct must be given by value to the FileDeduper 12 | /// struct on creation. 13 | /// 14 | /// The two primary methods are chunk_hash_dedup_query, which determines whether and how a chunk can be deduped, 15 | /// and register_new_xorb, which is called intermittently when a new block of data is available for upload. 16 | /// 17 | /// The global dedup query functions are optional but needed if global dedup is to be enabled. 18 | #[cfg_attr(not(target_family = "wasm"), async_trait)] 19 | #[cfg_attr(target_family = "wasm", async_trait(?Send))] 20 | pub trait DeduplicationDataInterface: Send + Sync + 'static { 21 | /// The error type used for the interface 22 | type ErrorType; 23 | 24 | /// Query for possible shards that 25 | async fn chunk_hash_dedup_query( 26 | &self, 27 | query_hashes: &[MerkleHash], 28 | ) -> std::result::Result, Self::ErrorType>; 29 | 30 | /// Registers a new query for more information about the 31 | /// global deduplication. This is expected to run in the background. Simply return Ok(()) to 32 | /// disable global dedup queries. 33 | async fn register_global_dedup_query(&mut self, _chunk_hash: MerkleHash) -> Result<(), Self::ErrorType>; 34 | 35 | /// Waits for all the current queries to complete, then returns true if there is 36 | /// new deduplication information available. 37 | async fn complete_global_dedup_queries(&mut self) -> Result; 38 | 39 | /// Registers a Xorb of new data that has no deduplication references. 40 | async fn register_new_xorb(&mut self, xorb: RawXorbData) -> Result<(), Self::ErrorType>; 41 | 42 | /// Register a set of xorb dependencies; this is called periodically during the dedup 43 | /// process with a list of (xorb hash, n_bytes). As the final bit may get 44 | /// returned as a partial xorb without a hash yet, it is not gauranteed that the 45 | /// sum of the n_bytes across all the dependencies will equal the size of the file. 
-------------------------------------------------------------------------------- /deduplication/src/lib.rs: -------------------------------------------------------------------------------- 1 | mod chunk; 2 | mod chunking; 3 | pub mod constants; 4 | mod data_aggregator; 5 | mod dedup_metrics; 6 | mod defrag_prevention; 7 | mod file_deduplication; 8 | mod interface; 9 | mod raw_xorb_data; 10 | 11 | pub use chunk::Chunk; 12 | pub use chunking::{Chunker, find_partitions}; 13 | pub use data_aggregator::DataAggregator; 14 | pub use dedup_metrics::DeduplicationMetrics; 15 | pub use file_deduplication::FileDeduper; 16 | pub use interface::DeduplicationDataInterface; 17 | pub use raw_xorb_data::{RawXorbData, test_utils}; 18 | -------------------------------------------------------------------------------- /deduplication/src/parallel chunking.pdf: https://raw.githubusercontent.com/huggingface/xet-core/2eec20baf15b61a3787bc542a6ab25dec61318a3/deduplication/src/parallel chunking.pdf -------------------------------------------------------------------------------- /deduplication/src/raw_xorb_data.rs: -------------------------------------------------------------------------------- 1 | use mdb_shard::cas_structs::{CASChunkSequenceEntry, CASChunkSequenceHeader, MDBCASInfo}; 2 | use merklehash::{MerkleHash, xorb_hash}; 3 | use more_asserts::*; 4 | 5 | use crate::Chunk; 6 | use crate::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS}; 7 | 8 | /// This struct is the data needed to cut a xorb. 9 | #[derive(Default, Debug, Clone)] 10 | pub struct RawXorbData { 11 | /// The data for the xorb info. 12 | pub data: Vec<bytes::Bytes>, 13 | 14 | /// The cas info associated with the current xorb. 15 | pub cas_info: MDBCASInfo, 16 | 17 | /// The indices where a new file starts, to be used for the compression heuristic. 18 | pub file_boundaries: Vec<usize>, 19 | } 20 | 21 | impl RawXorbData { 22 | // Construct from raw chunks, copying the chunk data and metadata out of the raw chunks. 23 | pub fn from_chunks(chunks: &[Chunk], file_boundaries: Vec<usize>) -> Self { 24 | debug_assert_le!(chunks.len(), *MAX_XORB_CHUNKS); 25 | 26 | let mut data = Vec::with_capacity(chunks.len()); 27 | let mut chunk_seq_entries = Vec::with_capacity(chunks.len()); 28 | 29 | // Build the sequences. 30 | let mut pos = 0; 31 | for c in chunks { 32 | chunk_seq_entries.push(CASChunkSequenceEntry::new(c.hash, c.data.len(), pos)); 33 | data.push(c.data.clone()); 34 | pos += c.data.len(); 35 | } 36 | let num_bytes = pos; 37 | 38 | debug_assert_le!(num_bytes, *MAX_XORB_BYTES); 39 | 40 | let hash_and_len: Vec<_> = chunks.iter().map(|c| (c.hash, c.data.len() as u64)).collect(); 41 | let cas_hash = xorb_hash(&hash_and_len); 42 | 43 | // Build the MDBCASInfo struct. 
44 | let metadata = CASChunkSequenceHeader::new(cas_hash, chunks.len(), num_bytes); 45 | 46 | let cas_info = MDBCASInfo { 47 | metadata, 48 | chunks: chunk_seq_entries, 49 | }; 50 | 51 | RawXorbData { 52 | data, 53 | cas_info, 54 | file_boundaries, 55 | } 56 | } 57 | 58 | pub fn hash(&self) -> MerkleHash { 59 | self.cas_info.metadata.cas_hash 60 | } 61 | 62 | pub fn num_bytes(&self) -> usize { 63 | let n = self.cas_info.metadata.num_bytes_in_cas as usize; 64 | 65 | debug_assert_eq!(n, self.data.iter().map(|c| c.len()).sum::()); 66 | 67 | n 68 | } 69 | } 70 | 71 | pub mod test_utils { 72 | use super::RawXorbData; 73 | 74 | pub fn raw_xorb_to_vec(xorb: &RawXorbData) -> Vec { 75 | let mut new_vec = Vec::with_capacity(xorb.num_bytes()); 76 | 77 | for ch in xorb.data.iter() { 78 | new_vec.extend_from_slice(ch); 79 | } 80 | 81 | new_vec 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /error_printer/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "error_printer" 3 | version = "0.14.5" 4 | edition = "2024" 5 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 6 | 7 | [dependencies] 8 | tracing = { workspace = true } 9 | 10 | [dev-dependencies] 11 | tracing-test = { workspace = true } 12 | -------------------------------------------------------------------------------- /error_printer/tests/test_error.rs: -------------------------------------------------------------------------------- 1 | use error_printer::ErrorPrinter; 2 | use tracing_test::traced_test; 3 | 4 | #[test] 5 | #[traced_test] 6 | fn test_error() { 7 | let err: Result<(), &str> = Err("some error"); 8 | // Important: the line number of this log statement 9 | // is important as we check the log output to make sure the file/line are correct. 10 | let line_num = 11; 11 | assert!(err.log_error("error test").is_err()); 12 | 13 | check_logs(logs_contain, "ERROR", line_num) 14 | } 15 | 16 | #[test] 17 | #[traced_test] 18 | fn test_warn() { 19 | let err: Result<(), &str> = Err("some error"); 20 | // Important: the line number of this log statement 21 | // is important as we check the log output to make sure the file/line are correct. 22 | let line_num = 23; 23 | assert!(err.warn_error("warn test").is_err()); 24 | 25 | check_logs(logs_contain, "WARN", line_num) 26 | } 27 | 28 | #[test] 29 | #[traced_test] 30 | fn test_debug() { 31 | let err: Result<(), &str> = Err("some error"); 32 | // Important: the line number of this log statement 33 | // is important as we check the log output to make sure the file/line are correct. 34 | let line_num = 35; 35 | assert!(err.debug_error("debug test").is_err()); 36 | 37 | check_logs(logs_contain, "DEBUG", line_num) 38 | } 39 | 40 | #[test] 41 | #[traced_test] 42 | fn test_info() { 43 | let err: Result<(), &str> = Err("some error"); 44 | // Important: the line number of this log statement 45 | // is important as we check the log output to make sure the file/line are correct. 
46 |     let line_num = 47;
47 |     assert!(err.info_error("info test").is_err());
48 | 
49 |     check_logs(logs_contain, "INFO", line_num)
50 | }
51 | 
52 | #[test]
53 | fn test_ok() {
54 |     let i = 2642;
55 |     let res: Result<i32, &str> = Ok(i);
56 |     assert_eq!(i, res.log_error("was err").unwrap());
57 |     assert_eq!(i, res.warn_error("was err").unwrap());
58 |     assert_eq!(i, res.debug_error("was err").unwrap());
59 |     assert_eq!(i, res.info_error("was err").unwrap());
60 | }
61 | 
62 | fn check_logs<F: Fn(&str) -> bool>(logs_contain: F, log_level: &str, line_num: i32) {
63 |     assert!(logs_contain(log_level));
64 |     assert!(logs_contain("test, error: \"some error\""));
65 |     #[cfg(not(windows))]
66 |     let expected_line = format!("{}:{}", file!(), line_num);
67 |     #[cfg(windows)]
68 |     let expected_line = format!("{}:{}", file!(), line_num).replace("\\", "\\\\");
69 | 
70 |     assert!(logs_contain(&expected_line));
71 | }
72 | 
--------------------------------------------------------------------------------
/error_printer/tests/test_option.rs:
--------------------------------------------------------------------------------
 1 | use error_printer::OptionPrinter;
 2 | use tracing_test::traced_test;
 3 | 
 4 | #[test]
 5 | #[traced_test]
 6 | fn test_error() {
 7 |     let opt: Option<()> = None;
 8 |     // Important: the line number of this log statement
 9 |     // is important as we check the log output to make sure the file/line are correct.
10 |     let line_num = 11;
11 |     assert!(opt.error_none("error test: opt is None").is_none());
12 | 
13 |     check_logs(logs_contain, "ERROR", line_num)
14 | }
15 | 
16 | #[test]
17 | #[traced_test]
18 | fn test_warn() {
19 |     let opt: Option<()> = None;
20 |     // Important: the line number of this log statement
21 |     // is important as we check the log output to make sure the file/line are correct.
22 |     let line_num = 23;
23 |     assert!(opt.warn_none("warn test: opt is None").is_none());
24 | 
25 |     check_logs(logs_contain, "WARN", line_num)
26 | }
27 | 
28 | #[test]
29 | #[traced_test]
30 | fn test_debug() {
31 |     let opt: Option<()> = None;
32 |     // Important: the line number of this log statement
33 |     // is important as we check the log output to make sure the file/line are correct.
34 |     let line_num = 35;
35 |     assert!(opt.debug_none("debug test: opt is None").is_none());
36 | 
37 |     check_logs(logs_contain, "DEBUG", line_num)
38 | }
39 | 
40 | #[test]
41 | #[traced_test]
42 | fn test_info() {
43 |     let opt: Option<()> = None;
44 |     // Important: the line number of this log statement
45 |     // is important as we check the log output to make sure the file/line are correct.
46 |     let line_num = 47;
47 |     assert!(opt.info_none("info test: opt is None").is_none());
48 | 
49 |     check_logs(logs_contain, "INFO", line_num)
50 | }
51 | 
52 | #[test]
53 | fn test_some() {
54 |     let i = 2642;
55 |     let opt = Some(i);
56 |     assert_eq!(i, opt.error_none("was none").unwrap());
57 |     assert_eq!(i, opt.warn_none("was none").unwrap());
58 |     assert_eq!(i, opt.debug_none("was none").unwrap());
59 |     assert_eq!(i, opt.info_none("was none").unwrap());
60 | }
61 | 
62 | fn check_logs<F: Fn(&str) -> bool>(logs_contain: F, log_level: &str, line_num: i32) {
63 |     assert!(logs_contain(log_level));
64 |     assert!(logs_contain("opt is None"));
65 |     #[cfg(not(windows))]
66 |     let expected_line = format!("{}:{}", file!(), line_num);
67 |     #[cfg(windows)]
68 |     let expected_line = format!("{}:{}", file!(), line_num).replace("\\", "\\\\");
69 |     assert!(logs_contain(&expected_line));
70 | }
71 | 
--------------------------------------------------------------------------------
/file_utils/Cargo.toml:
--------------------------------------------------------------------------------
 1 | [package]
 2 | name = "file_utils"
 3 | version = "0.14.2"
 4 | edition = "2024"
 5 | 
 6 | [dependencies]
 7 | lazy_static = { workspace = true }
 8 | libc = { workspace = true }
 9 | rand = { workspace = true }
10 | tracing = { workspace = true }
11 | 
12 | [target.'cfg(windows)'.dependencies]
13 | winapi = { workspace = true }
14 | 
15 | [target.'cfg(unix)'.dependencies]
16 | whoami = { workspace = true }
17 | colored = { workspace = true }
18 | 
19 | [dev-dependencies]
20 | anyhow = { workspace = true }
21 | tempfile = { workspace = true }
22 | 
--------------------------------------------------------------------------------
/file_utils/src/lib.rs:
--------------------------------------------------------------------------------
 1 | mod file_metadata;
 2 | mod privilege_context;
 3 | mod safe_file_creator;
 4 | 
 5 | pub use privilege_context::{PrivilegedExecutionContext, create_dir_all, create_file};
 6 | pub use safe_file_creator::SafeFileCreator;
 7 | 
--------------------------------------------------------------------------------
/git_xet/Cargo.toml:
--------------------------------------------------------------------------------
 1 | [package]
 2 | name = "git_xet"
 3 | version = "0.1.0"
 4 | edition = "2024"
 5 | 
 6 | [[bin]]
 7 | name = "git-xet"
 8 | path = "src/bin/main.rs"
 9 | 
10 | [dependencies]
11 | cas_client = { path = "../cas_client" }
12 | data = { path = "../data" }
13 | progress_tracking = { path = "../progress_tracking" }
14 | utils = { path = "../utils" }
15 | hub_client = { path = "../hub_client" }
16 | 
17 | anyhow = { workspace = true }
18 | async-trait = { workspace = true }
19 | chrono = { workspace = true }
20 | clap = { workspace = true }
21 | derivative = { workspace = true }
22 | git-url-parse = { workspace = true }
23 | git2 = { workspace = true }
24 | reqwest = { workspace = true }
25 | reqwest-middleware = { workspace = true }
26 | rust-netrc = { workspace = true }
27 | serde = { workspace = true }
28 | serde_json = { workspace = true }
29 | tempfile = { workspace = true }
30 | thiserror = { workspace = true }
31 | tokio = { workspace = true }
32 | 
33 | [target.'cfg(unix)'.dependencies]
34 | openssh = { workspace = true }
35 | 
36 | [dev-dependencies]
37 | serial_test = { workspace = true }
--------------------------------------------------------------------------------
/git_xet/README.md:
--------------------------------------------------------------------------------
 1 | Git-Xet is a Git LFS custom transfer agent that implements upload and download of files using the Xet protocol. Install `git-xet`, follow your regular workflow to `git lfs track ...` & `git add ...` & `git commit ...` & `git push`, and your files are uploaded to Hugging Face repos using the Xet protocol. Enjoy the dedupe!
 2 | 
 3 | ## Installation
 4 | ### Prerequisite
 5 | Make sure you have [git](https://git-scm.com/downloads) and [git-lfs](https://git-lfs.com/) installed and configured correctly.
 6 | ### macOS or Linux (amd64 or aarch64)
 7 | To install using Homebrew:
 8 | ```
 9 | brew tap huggingface/tap
10 | brew install git-xet
11 | ```
12 | Or, using an installation script, run the following in your terminal (requires `curl` and `unzip`):
13 | ```
14 | curl --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/huggingface/xet-core/refs/heads/main/git_xet/install.sh | sh
15 | ```
16 | To verify the installation, run:
17 | ```
18 | git-xet --version
19 | ```
20 | 
21 | ### Windows (amd64)
22 | Using an installer:
23 | - Download `git-xet-windows-installer-x86_64.zip` ([available here](https://github.com/huggingface/xet-core/releases/download/git-xet-v0.1.0/git-xet-windows-installer-x86_64.zip)) and unzip.
24 | - Run the `msi` installer file and follow the prompts.
25 | 
26 | Manual installation:
27 | - Download `git-xet-windows-x86_64.zip` ([available here](https://github.com/huggingface/xet-core/releases/download/git-xet-v0.1.0/git-xet-windows-x86_64.zip)) and unzip.
28 | - Place the extracted `git-xet.exe` under a `PATH` directory.
29 | - Run `git-xet install` in a terminal.
30 | 
31 | To verify the installation, run:
32 | ```
33 | git-xet --version
34 | ```
35 | 
36 | ## Uninstall
37 | ### macOS or Linux
38 | Using Homebrew:
39 | ```
40 | git-xet uninstall
41 | brew uninstall git-xet
42 | ```
43 | If you used the installation script (for macOS or Linux), run the following in your terminal:
44 | ```
45 | git-xet uninstall
46 | sudo rm $(which git-xet)
47 | ```
48 | ### Windows
49 | If you used the installer:
50 | - Navigate to Settings -> Apps -> Installed apps
51 | - Find "Git-Xet".
52 | - Select the "Uninstall" option available in the context menu.
53 | 
54 | If you manually installed:
55 | - Run `git-xet uninstall` in a terminal.
56 | - Delete the `git-xet.exe` file from the location where it was originally placed.
57 | 
58 | ## How It Works
59 | Git-Xet works by registering itself with Git LFS as a custom transfer agent named "xet". On `git push`, `git fetch` or `git pull`, `git-lfs` negotiates with the remote server to determine the transfer agent to use. During this process, `git-lfs` sends the server all locally registered agent names in its Batch API request, and the server replies with exactly one agent name in the response. Should "xet" be picked, `git-lfs` delegates the upload or download operation to `git-xet` through a sequential, line-delimited protocol.
60 | 
61 | For more details, see the Git LFS [Batch API](https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md) and [Custom Transfer Agent](https://github.com/git-lfs/git-lfs/blob/main/docs/custom-transfers.md) documentation.
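For illustration, the exchange between `git-lfs` and a custom agent looks roughly like the sketch below. It is hand-written against the linked Git LFS documentation rather than taken from git-xet's sources, assumes a `serde_json` dependency, and elides the actual transfer work behind a comment:

```rust
// Sketch of an LFS custom transfer agent loop (illustrative, not git-xet's code).
use std::io::{self, BufRead, Write};

use serde_json::{Value, json};

fn main() -> io::Result<()> {
    let stdin = io::stdin();
    let mut stdout = io::stdout();
    for line in stdin.lock().lines() {
        let msg: Value = serde_json::from_str(&line?).map_err(io::Error::other)?;
        match msg["event"].as_str() {
            // git-lfs opens with {"event": "init", "operation": "upload", ...};
            // the agent acknowledges readiness with an empty JSON object.
            Some("init") => writeln!(stdout, "{{}}")?,
            // Each transfer request carries an oid and size (plus a source path for uploads).
            Some("upload") | Some("download") => {
                let oid = msg["oid"].as_str().unwrap_or_default().to_owned();
                // ... perform the transfer using the Xet protocol here ...
                // (a download completion would also report a "path" to the resulting file)
                writeln!(stdout, "{}", json!({ "event": "complete", "oid": oid }))?;
            },
            Some("terminate") => break,
            _ => {},
        }
        stdout.flush()?;
    }
    Ok(())
}
```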
62 | 
--------------------------------------------------------------------------------
/git_xet/entitlements.xml:
--------------------------------------------------------------------------------
 1 | <?xml version="1.0" encoding="UTF-8"?>
 2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
 3 | <plist version="1.0">
 4 | <dict>
 5 |     <key>com.apple.security.cs.allow-unsigned-executable-memory</key>
 6 |     <true/>
 7 |     <key>com.apple.security.cs.disable-library-validation</key>
 8 |     <true/>
 9 | </dict>
10 | </plist>
--------------------------------------------------------------------------------
/git_xet/install.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | # This script detects the OS and architecture to download the correct binary,
 4 | # unzips it, and moves it to the user's local bin directory.
 5 | 
 6 | # --- Configuration ---
 7 | URL_LINUX_AMD64="https://github.com/huggingface/xet-core/releases/download/git-xet-v0.1.0/git-xet-linux-x86_64.zip"
 8 | URL_LINUX_ARM64="https://github.com/huggingface/xet-core/releases/download/git-xet-v0.1.0/git-xet-linux-aarch64.zip"
 9 | URL_MACOS_AMD64="https://github.com/huggingface/xet-core/releases/download/git-xet-v0.1.0/git-xet-macos-x86_64.zip"
10 | URL_MACOS_ARM64="https://github.com/huggingface/xet-core/releases/download/git-xet-v0.1.0/git-xet-macos-aarch64.zip"
11 | 
12 | # The name of the binary inside the zip file.
13 | BINARY_NAME="git-xet"
14 | 
15 | # The destination for the binary.
16 | INSTALL_DIR="/usr/local/bin"
17 | 
18 | # --- Main Script ---
19 | 
20 | # Function to handle errors and exit
21 | handle_error() {
22 |     echo "Error: $1" >&2
23 |     exit 1
24 | }
25 | 
26 | # Get OS and architecture
27 | OS="$(uname -s)"
28 | ARCH="$(uname -m)"
29 | 
30 | echo "Detected OS: ${OS}"
31 | echo "Detected Architecture: ${ARCH}"
32 | 
33 | # Determine which URL to use
34 | DOWNLOAD_URL=""
35 | case "${OS}" in
36 |     Linux)
37 |         case "${ARCH}" in
38 |             x86_64)
39 |                 DOWNLOAD_URL="${URL_LINUX_AMD64}"
40 |                 ;;
41 |             aarch64 | arm64)
42 |                 DOWNLOAD_URL="${URL_LINUX_ARM64}"
43 |                 ;;
44 |         esac
45 |         ;;
46 |     Darwin) # macOS
47 |         case "${ARCH}" in
48 |             x86_64)
49 |                 DOWNLOAD_URL="${URL_MACOS_AMD64}"
50 |                 ;;
51 |             arm64)
52 |                 DOWNLOAD_URL="${URL_MACOS_ARM64}"
53 |                 ;;
54 |         esac
55 |         ;;
56 | esac
57 | 
58 | # Check if a download URL was found
59 | if [ -z "${DOWNLOAD_URL}" ]; then
60 |     handle_error "Unsupported OS/Architecture combination: ${OS}/${ARCH}"
61 | fi
62 | 
63 | # Create a temporary directory for the download and extraction
64 | TMP_DIR=$(mktemp -d)
65 | 
66 | # Function to clean up the temporary directory on exit
67 | cleanup() {
68 |     echo "Cleaning up..."
69 |     rm -rf "${TMP_DIR}"
70 | }
71 | trap cleanup EXIT
72 | 
73 | # Change to the temporary directory
74 | cd "${TMP_DIR}" || handle_error "Could not change to temporary directory."
75 | 
76 | # Download the file
77 | echo "Downloading from: ${DOWNLOAD_URL}..."
78 | curl -sSL -o binary.zip "${DOWNLOAD_URL}" || handle_error "Download failed."
79 | 
80 | # Unzip the file
81 | echo "Unzipping..."
82 | unzip -q binary.zip || handle_error "Unzipping failed. Make sure 'unzip' is installed."
83 | 
84 | # Check if the binary exists
85 | if [ ! -f "${BINARY_NAME}" ]; then
86 |     handle_error "Binary '${BINARY_NAME}' not found in the zip file."
87 | fi
88 | 
89 | # Make the binary executable
90 | echo "Setting executable permissions..."
91 | chmod +x "${BINARY_NAME}"
92 | 
93 | # Move the binary to the install directory
94 | echo "Installing binary to ${INSTALL_DIR}..."
95 | # Use sudo if the user doesn't have write permissions to the directory
96 | if [ -w "${INSTALL_DIR}" ]; then
97 |     mv "${BINARY_NAME}" "${INSTALL_DIR}/" || handle_error "Failed to move binary."
98 | else
99 |     echo "This script requires sudo permissions to install to ${INSTALL_DIR}."
100 |     sudo mv "${BINARY_NAME}" "${INSTALL_DIR}/" || handle_error "Failed to move binary with sudo."
101 | fi
102 | 
103 | git-xet install --concurrency 3
104 | 
105 | echo "Installation complete!"
106 | 
--------------------------------------------------------------------------------
/git_xet/src/bin/main.rs:
--------------------------------------------------------------------------------
 1 | use anyhow::Result;
 2 | use clap::Parser;
 3 | use git_xet::app::XetAgentApp;
 4 | 
 5 | #[tokio::main]
 6 | async fn main() -> Result<()> {
 7 |     let app = XetAgentApp::parse();
 8 | 
 9 |     app.run().await?;
10 | 
11 |     Ok(())
12 | }
13 | 
--------------------------------------------------------------------------------
/git_xet/src/constants.rs:
--------------------------------------------------------------------------------
 1 | // Naming convention
 2 | pub const GIT_EXECUTABLE: &str = "git";
 3 | pub const GIT_LFS_CUSTOM_TRANSFER_AGENT_NAME: &str = "xet";
 4 | pub const GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM: &str = "git-xet";
 5 | 
 6 | // The current version of the executable
 7 | pub const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION");
 8 | 
 9 | // Moon-landing Xet service headers
10 | pub const XET_CAS_URL: &str = "X-Xet-Cas-Url";
11 | pub const XET_ACCESS_TOKEN_HEADER: &str = "X-Xet-Access-Token";
12 | pub const XET_TOKEN_EXPIRATION_HEADER: &str = "X-Xet-Token-Expiration";
13 | pub const XET_SESSION_ID: &str = "X-Xet-Session-Id";
14 | 
15 | // Environment variable names
16 | pub const HF_TOKEN_ENV: &str = "HF_TOKEN";
17 | pub const HF_ENDPOINT_ENV: &str = "HF_ENDPOINT";
18 | 
--------------------------------------------------------------------------------
/git_xet/src/errors.rs:
--------------------------------------------------------------------------------
 1 | use std::fmt::Display;
 2 | use std::path::PathBuf;
 3 | 
 4 | use cas_client::CasClientError;
 5 | use data::errors::DataProcessingError;
 6 | use thiserror::Error;
 7 | 
 8 | use crate::lfs_agent_protocol::GitLFSProtocolError;
 9 | 
10 | #[derive(Error, Debug)]
11 | pub enum GitXetError {
12 |     #[error("Git command failed: {reason}, {source:?}")]
13 |     GitCommandFailed {
14 |         reason: String,
15 |         source: Option<std::io::Error>,
16 |     },
17 | 
18 |     #[error("Failed to find Git repo at {path:?}, internal error {source}")]
19 |     NoGitRepo { path: PathBuf, source: git2::Error },
20 | 
21 |     #[error("Internal Git error: {0}")]
22 |     GitError(#[from] git2::Error),
23 | 
24 |     #[error("Invalid Git config: {0}")]
25 |     InvalidGitConfig(String),
26 | 
27 |     #[error("Invalid Git URL: {0}")]
28 |     InvalidGitUrl(#[from] git_url_parse::GitUrlParseError),
29 | 
30 |     #[error("Invalid LFS protocol: {0}")]
31 |     InvalidGitLFSProtocol(#[from] GitLFSProtocolError),
32 | 
33 |     #[error("Operation not supported: {0}")]
34 |     NotSupported(String),
35 | 
36 |     #[error("I/O error: {0}")]
37 |     IO(#[from] std::io::Error),
38 | 
39 |     #[error("Internal error: {0}")]
40 |     Internal(String),
41 | 
42 |     #[error("Transfer agent error: {0}")]
43 |     TransferAgent(#[from] DataProcessingError),
44 | 
45 |     #[error("Hub client error: {0}")]
46 |     HubClient(#[from] hub_client::HubClientError),
47 | }
48 | 
49 | pub type Result<T> = std::result::Result<T, GitXetError>;
50 | 
51 | impl GitXetError {
52 |     pub(crate) fn git_cmd_failed(e: impl Display, source: Option<std::io::Error>) -> GitXetError {
53 |         GitXetError::GitCommandFailed {
54 |             reason: e.to_string(),
55 |             source,
56 |         }
57 |     }
58 | 
59 |     pub(crate) fn not_supported(e: impl Display) -> GitXetError {
60 |         GitXetError::NotSupported(e.to_string())
61 |     }
62 | 
63 |     pub(crate) fn config_error(e: impl Display) -> GitXetError {
64 |         GitXetError::InvalidGitConfig(e.to_string())
65 |     }
66 | 
67 |     pub(crate) fn internal(e: impl Display) -> GitXetError {
68 |         GitXetError::Internal(e.to_string())
69 |     }
70 | }
71 | 
72 | impl From<CasClientError> for GitXetError {
73 |     fn from(value: CasClientError) -> Self {
74 |         Self::from(DataProcessingError::CasClientError(value))
75 |     }
76 | }
77 | 
--------------------------------------------------------------------------------
/git_xet/src/lfs_agent_protocol/errors.rs:
--------------------------------------------------------------------------------
 1 | use std::fmt::Display;
 2 | 
 3 | use thiserror::Error;
 4 | 
 5 | #[derive(Error, Debug)]
 6 | pub enum GitLFSProtocolError {
 7 |     #[error("Bad custom transfer protocol syntax: {0}")]
 8 |     Syntax(String),
 9 | 
10 |     #[error("Invalid custom transfer protocol argument: {0}")]
11 |     Argument(String),
12 | 
13 |     #[error("Invalid custom transfer agent state: {0}")]
14 |     State(String),
15 | 
16 |     #[error("Serde to/from Json failed: {0:?}")]
17 |     SerdeJson(#[from] serde_json::Error),
18 | 
19 |     #[error("I/O error: {0}")]
20 |     IO(#[from] std::io::Error),
21 | }
22 | 
23 | pub(super) type Result<T> = std::result::Result<T, GitLFSProtocolError>;
24 | 
25 | impl GitLFSProtocolError {
26 |     pub(crate) fn bad_syntax(e: impl Display) -> GitLFSProtocolError {
27 |         GitLFSProtocolError::Syntax(e.to_string())
28 |     }
29 | 
30 |     pub(crate) fn bad_argument(e: impl Display) -> GitLFSProtocolError {
31 |         GitLFSProtocolError::Argument(e.to_string())
32 |     }
33 | 
34 |     pub(crate) fn bad_state(e: impl Display) -> GitLFSProtocolError {
35 |         GitLFSProtocolError::State(e.to_string())
36 |     }
37 | }
38 | 
--------------------------------------------------------------------------------
/git_xet/src/lib.rs:
--------------------------------------------------------------------------------
 1 | pub mod app;
 2 | mod auth;
 3 | mod constants;
 4 | mod errors;
 5 | mod git_process_wrapping;
 6 | mod git_repo;
 7 | mod git_url;
 8 | mod lfs_agent_protocol;
 9 | mod test_utils;
10 | mod token_refresher;
11 | 
--------------------------------------------------------------------------------
/git_xet/src/test_utils/gitaskpass.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | set -ex
 3 | # This is intentionally left blank to simulate a user entering nothing when
 4 | # prompted for credentials by the git credential helper.
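As an aside on the `GitXetError` helpers in `errors.rs` above: because each constructor accepts any `impl Display`, call sites inside the crate can forward foreign errors without `format!` boilerplate. A minimal sketch — the function below is hypothetical, not part of the crate:

```rust
use crate::errors::{GitXetError, Result};

// Hypothetical call site: ParseIntError implements Display, so the helper
// constructor can be passed straight to map_err.
fn parse_concurrency(raw: &str) -> Result<usize> {
    raw.parse::<usize>().map_err(GitXetError::config_error)
}
```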
--------------------------------------------------------------------------------
/git_xet/src/test_utils/mod.rs:
--------------------------------------------------------------------------------
 1 | mod temp;
 2 | mod test_repo;
 3 | 
 4 | #[cfg(test)]
 5 | pub use temp::TempHome;
 6 | #[cfg(test)]
 7 | pub use test_repo::TestRepo;
 8 | 
--------------------------------------------------------------------------------
/git_xet/src/test_utils/temp.rs:
--------------------------------------------------------------------------------
 1 | #![cfg(test)]
 2 | use std::path::Path;
 3 | use std::rc::Rc;
 4 | 
 5 | use anyhow::Result;
 6 | use tempfile::{TempDir, tempdir};
 7 | use utils::EnvVarGuard;
 8 | 
 9 | use crate::git_process_wrapping::run_git_captured;
10 | 
11 | // A test utility to create a temporary HOME environment; this sets up a clean environment for
12 | // git operations, which depend heavily on a global config file.
13 | // This is meant to be used only in single-threaded tests as `Rc` is `!Send` and `!Sync`.
14 | #[derive(Clone)]
15 | pub struct TempHome {
16 |     pub _env_guard: Rc<EnvVarGuard>,
17 |     pub dir: Rc<TempDir>,
18 | }
19 | 
20 | impl TempHome {
21 |     pub fn new() -> Result<Self> {
22 |         let home = tempdir()?;
23 |         let home_dir = home.path();
24 |         let abs_home_dir = std::path::absolute(home_dir)?;
25 |         let home_env_guard = EnvVarGuard::set("HOME", abs_home_dir.as_os_str());
26 | 
27 |         Ok(Self {
28 |             _env_guard: home_env_guard.into(),
29 |             dir: home.into(),
30 |         })
31 |     }
32 | 
33 |     pub fn with_default_git_config(self) -> Result<Self> {
34 |         run_git_captured(self.dir.path(), "config", &["--global", "user.name", "test"])?;
35 |         run_git_captured(self.dir.path(), "config", &["--global", "user.email", "test@hf.co"])?;
36 | 
37 |         #[cfg(target_os = "macos")]
38 |         let _ = run_git_captured(self.dir.path(), "config", &["--global", "--unset", "credential.helper"]);
39 | 
40 |         Ok(self)
41 |     }
42 | 
43 |     pub fn path(&self) -> &Path {
44 |         self.dir.path()
45 |     }
46 | }
47 | 
--------------------------------------------------------------------------------
/git_xet/src/token_refresher.rs:
--------------------------------------------------------------------------------
 1 | use std::sync::Arc;
 2 | 
 3 | use async_trait::async_trait;
 4 | use cas_client::{Api, RetryConfig, build_http_client};
 5 | use hub_client::{CasJWTInfo, CredentialHelper, Operation};
 6 | use reqwest::header;
 7 | use reqwest_middleware::ClientWithMiddleware;
 8 | use utils::auth::{TokenInfo, TokenRefresher};
 9 | use utils::errors::AuthError;
10 | 
11 | use crate::auth::get_credential;
12 | use crate::constants::GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM;
13 | use crate::errors::Result;
14 | use crate::git_repo::GitRepo;
15 | use crate::git_url::GitUrl;
16 | 
17 | pub struct DirectRefreshRouteTokenRefresher {
18 |     refresh_route: String,
19 |     client: ClientWithMiddleware,
20 |     cred_helper: Arc<dyn CredentialHelper>,
21 | }
22 | 
23 | impl DirectRefreshRouteTokenRefresher {
24 |     pub fn new(
25 |         repo: &GitRepo,
26 |         remote_url: Option<GitUrl>,
27 |         refresh_route: &str,
28 |         operation: Operation,
29 |         session_id: &str,
30 |     ) -> Result<Self> {
31 |         let remote_url = match remote_url {
32 |             Some(r) => r,
33 |             None => repo.remote_url()?,
34 |         };
35 | 
36 |         let cred_helper = get_credential(repo, &remote_url, operation)?;
37 | 
38 |         Ok(Self {
39 |             refresh_route: refresh_route.to_owned(),
40 |             client: build_http_client(RetryConfig::default(), session_id)?,
41 |             cred_helper,
42 |         })
43 |     }
44 | }
45 | 
46 | #[async_trait]
47 | impl TokenRefresher for DirectRefreshRouteTokenRefresher {
48 |     async fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> {
49 |         let req = self
50 |             .client
51 |             .get(&self.refresh_route)
52 |             .with_extension(Api("xet-token"))
53 |             .header(header::USER_AGENT, GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM);
54 |         let req = self
55 |             .cred_helper
56 |             .fill_credential(req)
57 |             .await
58 |             .map_err(AuthError::token_refresh_failure)?;
59 |         let response = req.send().await.map_err(AuthError::token_refresh_failure)?;
60 | 
61 |         let jwt_info: CasJWTInfo = response.json().await.map_err(AuthError::token_refresh_failure)?;
62 | 
63 |         Ok((jwt_info.access_token, jwt_info.exp))
64 |     }
65 | }
66 | 
--------------------------------------------------------------------------------
/git_xet/windows_installer/.gitignore:
--------------------------------------------------------------------------------
 1 | bin
 2 | *.exe
 3 | 
--------------------------------------------------------------------------------
/git_xet/windows_installer/sign_metadata.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "Endpoint": "https://eus.codesigning.azure.net/",
 3 |   "CodeSigningAccountName": "tsa-huggingface-apps",
 4 |   "CertificateProfileName": "git-xet-windows",
 5 |   "ExcludeCredentials": [
 6 |     "ManagedIdentityCredential",
 7 |     "WorkloadIdentityCredential",
 8 |     "SharedTokenCacheCredential",
 9 |     "VisualStudioCredential",
10 |     "VisualStudioCodeCredential",
11 |     "AzureCliCredential",
12 |     "AzurePowerShellCredential",
13 |     "AzureDeveloperCliCredential",
14 |     "InteractiveBrowserCredential"
15 |   ]
16 | }
--------------------------------------------------------------------------------
/hf_xet/.gitignore:
--------------------------------------------------------------------------------
 1 | venv/
 2 | .venv/
 3 | 
--------------------------------------------------------------------------------
/hf_xet/Cargo.toml:
--------------------------------------------------------------------------------
 1 | [package]
 2 | name = "hf_xet"
 3 | version = "1.1.10"
 4 | edition = "2024"
 5 | license = "Apache-2.0"
 6 | 
 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 8 | [lib]
 9 | name = "hf_xet"
10 | crate-type = ["cdylib"]
11 | 
12 | [dependencies]
13 | cas_client = { path = "../cas_client" }
14 | data = { path = "../data" }
15 | error_printer = { path = "../error_printer" }
16 | progress_tracking = { path = "../progress_tracking" }
17 | utils = { path = "../utils" }
18 | xet_runtime = { path = "../xet_runtime" }
19 | 
20 | async-trait = "0.1"
21 | bipbuffer = "0.1"
22 | chrono = "0.4"
23 | itertools = "0.14"
24 | lazy_static = "1.5"
25 | pprof = { version = "0.14", features = [
26 |     "flamegraph",
27 |     "prost",
28 |     "protobuf-codec",
29 | ], optional = true }
30 | pyo3 = { version = "0.26", features = [
31 |     "extension-module",
32 |     "abi3-py37",
33 |     "auto-initialize",
34 | ] }
35 | rand = "0.9.2"
36 | serde = { version = "1", features = ["derive"] }
37 | serde_json = "1"
38 | tracing = "0.1"
39 | tracing-subscriber = { version = "0.3", features = [
40 |     "json",
41 |     "tracing-log",
42 |     "env-filter",
43 |     "registry",
44 | ] }
45 | tracing-appender = "0.2"
46 | 
47 | console-subscriber = { version = "0.4.1", optional = true }
48 | 
49 | # Unix-specific dependencies
50 | [target.'cfg(unix)'.dependencies]
51 | signal-hook = "0.3"
52 | 
53 | # Windows-specific dependencies
54 | [target.'cfg(windows)'.dependencies]
55 | ctrlc = "3.4"
56 | 
57 | [features]
58 | native-tls = ["cas_client/native-tls-vendored"]
59 | native-tls-vendored = ["cas_client/native-tls-vendored"]
60 | profiling = ["pprof"]
61 | tokio-console = ["dep:console-subscriber"]
["dep:console-subscriber"] 62 | 63 | [profile.release] 64 | split-debuginfo = "packed" 65 | opt-level = "s" 66 | lto = true 67 | codegen-units = 1 68 | 69 | # on manylinux and macos maturin + split-debuginfo doesn't output debug symbols for .so objects, 70 | # so we need a different profile to build. For mac, the below settings will output a .dSYM 71 | # file. For Linux, we are stripping them manually using binutils. 72 | [profile.release-dbgsymbols] 73 | inherits = "release" 74 | debug = true 75 | split-debuginfo = "none" 76 | 77 | 78 | [profile.opt-test] 79 | inherits = "dev" 80 | debug = true 81 | opt-level = 3 82 | -------------------------------------------------------------------------------- /hf_xet/LICENSE: -------------------------------------------------------------------------------- 1 | ../LICENSE -------------------------------------------------------------------------------- /hf_xet/README.md: -------------------------------------------------------------------------------- 1 | 16 |

17 | License 18 | GitHub release 19 | Contributor Covenant 20 |

21 | 22 |

23 |

🤗 hf-xet - xet client tech, used in huggingface_hub

24 |

 5 | 
 6 | ## Welcome
 7 | 
 8 | `hf-xet` enables `huggingface_hub` to utilize xet storage for uploading and downloading to HF Hub. Xet storage provides chunk-based deduplication, efficient storage/retrieval with local disk caching, and backwards compatibility with Git LFS. This library is not meant to be used directly, and is instead intended to be used from [huggingface_hub](https://pypi.org/project/huggingface-hub).
 9 | 
10 | ## Key features
11 | 
12 | ♻ **chunk-based deduplication implementation**: avoid transferring and storing chunks that are shared across binary files (models, datasets, etc).
13 | 
14 | 🤗 **Python bindings**: bindings for the [huggingface_hub](https://github.com/huggingface/huggingface_hub/) package.
15 | 
16 | ↔ **network communications**: concurrent communication to HF Hub Xet backend services (CAS).
17 | 
18 | 🔖 **local disk caching**: chunk-based cache that sits alongside the existing [huggingface_hub disk cache](https://huggingface.co/docs/huggingface_hub/guides/manage-cache).
19 | 
20 | ## Installation
21 | 
22 | Install the `hf_xet` package with [pip](https://pypi.org/project/hf-xet/):
23 | 
24 | ```bash
25 | pip install hf_xet
26 | ```
27 | 
28 | ## Quick Start
29 | 
30 | `hf_xet` is not intended to be run independently, as it is expected to be used from `huggingface_hub`, so to get started with `huggingface_hub` check out the documentation [here](https://hf.co/docs/huggingface_hub).
31 | 
32 | ## Contributions (feature requests, bugs, etc.) are encouraged & appreciated 💙💚💛💜🧡❤️
33 | 
34 | Please join us in making hf-xet better. We value everyone's contributions. Code is not the only way to help. Answering questions, helping each other, improving documentation, and filing issues all help immensely. If you are interested in contributing (please do!), check out the [contribution guide](https://github.com/huggingface/xet-core/blob/main/CONTRIBUTING.md) for this repository.
--------------------------------------------------------------------------------
/hf_xet/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [build-system]
 2 | requires = ["maturin>=1.7,<2.0"]
 3 | build-backend = "maturin"
 4 | 
 5 | [project]
 6 | name = "hf-xet"
 7 | requires-python = ">=3.8"
 8 | description = "Fast transfer of large files with the Hugging Face Hub."
 9 | author = "Hugging Face, Inc."
10 | author_email = "julien@huggingface.co"
11 | maintainers = [
12 |     { name = "Rajat Arya", email = "rajat@rajatarya.com" },
13 |     { name = "Jared Sulzdorf", email = "j.sulzdorf@gmail.com" },
14 |     { name = "Di Xiao", email = "di@huggingface.co" },
15 |     { name = "Assaf Vayner", email = "assaf@huggingface.co" },
16 |     { name = "Hoyt Koepke", email = "hoytak@gmail.com" },
17 | ]
18 | license = "Apache-2.0"
19 | license-file = "LICENSE"
20 | readme = "README.md"
21 | classifiers = [
22 |     "License :: OSI Approved :: Apache Software License",
23 |     "Programming Language :: Rust",
24 |     "Programming Language :: Python :: Implementation :: CPython",
25 |     "Programming Language :: Python :: Implementation :: PyPy",
26 |     "Programming Language :: Python :: 3",
27 |     "Programming Language :: Python :: 3 :: Only",
28 |     "Programming Language :: Python :: 3.8",
29 |     "Programming Language :: Python :: 3.9",
30 |     "Programming Language :: Python :: 3.10",
31 |     "Programming Language :: Python :: 3.11",
32 |     "Programming Language :: Python :: 3.12",
33 |     "Programming Language :: Python :: 3.13",
34 |     "Topic :: Scientific/Engineering :: Artificial Intelligence",
35 | ]
36 | dynamic = ["version"]
37 | 
38 | [project.optional-dependencies]
39 | tests = [
40 |     "pytest",
41 | ]
42 | 
43 | [project.urls]
44 | Homepage = "https://github.com/huggingface/xet-core"
45 | Documentation = "https://huggingface.co/docs/hub/en/storage-backends#using-xet-storage"
46 | Issues = "https://github.com/huggingface/xet-core/issues"
47 | Repository = "https://github.com/huggingface/xet-core.git"
48 | 
49 | [tool.maturin]
50 | python-source = "python"
51 | features = ["pyo3/extension-module"]
52 | 
--------------------------------------------------------------------------------
/hf_xet/python/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/xet-core/2eec20baf15b61a3787bc542a6ab25dec61318a3/hf_xet/python/.gitkeep
--------------------------------------------------------------------------------
/hf_xet/src/token_refresh.rs:
--------------------------------------------------------------------------------
 1 | use std::fmt::{Debug, Formatter};
 2 | 
 3 | use pyo3::exceptions::PyTypeError;
 4 | use pyo3::prelude::PyAnyMethods;
 5 | use pyo3::{Py, PyAny, PyErr, PyResult, Python};
 6 | use tracing::error;
 7 | use utils::auth::{TokenInfo, TokenRefresher};
 8 | use utils::errors::AuthError;
 9 | 
10 | /// A wrapper struct of a python function to refresh the CAS auth token.
11 | /// Since tokens are generated by hub, we want to be able to refresh the
12 | /// token using the hub client, which is only available in python.
13 | pub struct WrappedTokenRefresher {
14 |     /// The function responsible for refreshing a token.
15 |     /// Expects no inputs and returns a (str, u64) representing the new token
16 |     /// and the unixtime (in seconds) of expiration, raising an exception
17 |     /// if there is an issue.
18 |     py_func: Py<PyAny>,
19 |     name: String,
20 | }
21 | 
22 | impl Debug for WrappedTokenRefresher {
23 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
24 |         write!(f, "WrappedTokenRefresher({})", self.name)
25 |     }
26 | }
27 | 
28 | impl WrappedTokenRefresher {
29 |     pub fn from_func(py_func: Py<PyAny>) -> PyResult<Self> {
30 |         let name = Self::validate_callable(&py_func)?;
31 |         Ok(Self { py_func, name })
32 |     }
33 | 
34 |     /// Validate that the inputted python object is callable
35 |     fn validate_callable(py_func: &Py<PyAny>) -> Result<String, PyErr> {
36 |         Python::attach(|py| {
37 |             let f = py_func.bind(py);
38 |             let name = f
39 |                 .repr()
40 |                 .and_then(|repr| repr.extract::<String>())
41 |                 .unwrap_or("unknown".to_string());
42 |             if !f.is_callable() {
43 |                 error!("TokenRefresher func: {name} is not callable");
44 |                 return Err(PyTypeError::new_err(format!("refresh func: {name} is not callable")));
45 |             }
46 |             Ok(name)
47 |         })
48 |     }
49 | }
50 | 
51 | #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
52 | #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
53 | impl TokenRefresher for WrappedTokenRefresher {
54 |     async fn refresh(&self) -> Result<TokenInfo, AuthError> {
55 |         Python::attach(|py| {
56 |             let f = self.py_func.bind(py);
57 |             if !f.is_callable() {
58 |                 return Err(AuthError::RefreshFunctionNotCallable(self.name.clone()));
59 |             }
60 |             let result = f
61 |                 .call0()
62 |                 .map_err(|e| AuthError::TokenRefreshFailure(format!("Error refreshing token: {e:?}")))?;
63 |             result.extract::<(String, u64)>().map_err(|e| {
64 |                 AuthError::TokenRefreshFailure(format!("refresh function didn't return a (String, u64) tuple: {e:?}"))
65 |             })
66 |         })
67 |     }
68 | }
69 | 
--------------------------------------------------------------------------------
/hf_xet_thin_wasm/.gitignore:
--------------------------------------------------------------------------------
 1 | pkg
 2 | target
--------------------------------------------------------------------------------
/hf_xet_thin_wasm/Cargo.toml:
--------------------------------------------------------------------------------
 1 | [package]
 2 | name = "hf_xet_thin_wasm"
 3 | version = "0.1.0"
 4 | edition = "2024"
 5 | 
 6 | [lib]
 7 | crate-type = ["cdylib", "rlib"]
 8 | 
 9 | [dependencies]
10 | deduplication = { path = "../deduplication" }
11 | mdb_shard = { path = "../mdb_shard" }
12 | merklehash = { path = "../merklehash" }
13 | 
14 | serde = { version = "1.0.219", features = ["derive"] }
15 | serde-wasm-bindgen = "0.6.5"
16 | wasm-bindgen = "=0.2.100"
17 | 
--------------------------------------------------------------------------------
/hf_xet_thin_wasm/README.md:
--------------------------------------------------------------------------------
 1 | # hf_xet_thin_wasm
 2 | 
 3 | Exports limited functionality from xet-core in a WebAssembly-compatible way, primarily for use by [huggingface.js](https://github.com/huggingface/huggingface.js).
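For orientation, exports from this crate follow the usual wasm-bindgen shape. The sketch below is illustrative only: the function name is made up, and the `compute_data_hash`/`hex` calls are assumptions about merklehash's API rather than this crate's actual bindings.

```rust
use wasm_bindgen::prelude::*;

// Hypothetical export shape; see src/lib.rs for the real bindings.
#[wasm_bindgen]
pub fn chunk_hash_hex(data: &[u8]) -> String {
    // Assumed: merklehash exposes a chunk-level hash plus a hex formatter.
    merklehash::compute_data_hash(data).hex()
}
```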
 4 | 
 5 | Exports:
 6 | 
 7 | - Xorb hash computation
 8 | - File hash computation
 9 | - Verification range hash computation
10 | - Chunker struct/class
11 |   - Generate chunk boundaries
12 |   - Compute chunk hashes
13 | 
--------------------------------------------------------------------------------
/hf_xet_thin_wasm/build_wasm.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/sh
 2 | 
 3 | set -ex
 4 | 
 5 | # wasm-pack produces smaller wasm binaries by default by using a tool called wasm-opt;
 6 | # we can tweak the optimizations via Cargo.toml configuration.
 7 | # See: https://rustwasm.github.io/docs/wasm-pack/cargo-toml-configuration.html
 8 | 
 9 | # valid values here: web, nodejs, bundler, no-modules, deno; default: web
10 | JS_TARGET="${JS_TARGET:-web}"
11 | 
12 | wasm-pack build --release --target "$JS_TARGET"
13 | 
14 | # Adapted version of the hf_xet_wasm build for this package (no need for special features).
15 | # The commented-out commands below are essentially the same steps that `wasm-pack` runs,
16 | # minus the optimization, which we can add explicitly.
17 | 
18 | #RUSTFLAGS='--cfg getrandom_backend="wasm_js"' \
19 | #    cargo +nightly build --target wasm32-unknown-unknown --release -Z build-std=std,panic_abort
20 | 
21 | #RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-bindgen \
22 | #    target/wasm32-unknown-unknown/release/hf_xet_thin_wasm.wasm \
23 | #    --out-dir pkg \
24 | #    --typescript \
25 | #    --target web
26 | 
--------------------------------------------------------------------------------
/hf_xet_wasm/.gitignore:
--------------------------------------------------------------------------------
 1 | target
 2 | examples/target
 3 | 
--------------------------------------------------------------------------------
/hf_xet_wasm/Cargo.toml:
--------------------------------------------------------------------------------
 1 | [package]
 2 | name = "hf_xet_wasm"
 3 | version = "0.0.1"
 4 | edition = "2024"
 5 | 
 6 | [lib]
 7 | crate-type = ["cdylib", "rlib"]
 8 | 
 9 | [dependencies]
10 | cas_client = { path = "../cas_client" }
11 | cas_object = { path = "../cas_object" }
12 | cas_types = { path = "../cas_types" }
13 | deduplication = { path = "../deduplication" }
14 | mdb_shard = { path = "../mdb_shard" }
15 | merklehash = { path = "../merklehash" }
16 | progress_tracking = { path = "../progress_tracking" }
17 | utils = { path = "../utils" }
18 | 
19 | anyhow = "1"
20 | async-channel = "2.3.1"
21 | async-trait = "0.1.88"
22 | blake3 = "1.7.0"
23 | bytes = "1.10.1"
24 | console_error_panic_hook = "0.1.7"
25 | console_log = { version = "1.0.0", features = ["color"] }
26 | env_logger = "0.11.5"
27 | futures = "0.3.31"
28 | futures-io = "0.3.31"
29 | getrandom = { version = "0.3", features = ["wasm_js"] }
30 | js-sys = "0.3.72"
31 | log = "0.4.22"
32 | serde = { version = "1.0.217", features = ["derive"] }
33 | serde_json = "1.0.140"
34 | serde-wasm-bindgen = "0.6.5"
35 | sha2 = { version = "0.10.8", features = ["asm"] }
36 | thiserror = "2.0"
37 | tokio = { version = "1.44", features = ["sync", "rt"] }
38 | tokio_with_wasm = { version = "0.8.2", features = ["rt"] }
39 | tokio-stream = "0.1.17"
40 | uuid = { version = "1", features = ["v4", "js"] }
41 | wasm_thread = "0.3"
42 | wasm-bindgen = "=0.2.100"
43 | wasm-bindgen-futures = "0.4.50"
44 | web-sys = { version = "0.3.72", features = [
45 |     "File",
46 |     "ReadableStream",
47 |     "ReadableStreamDefaultReader",
48 |     "Blob",
49 |     "DedicatedWorkerGlobalScope",
50 |     "MessageEvent",
51 |     "Url",
52 |     "Worker",
53 |     "WorkerType",
54 |     "WorkerOptions",
"WorkerGlobalScope", 56 | "Window", 57 | "Navigator", 58 | "WorkerNavigator", 59 | "Headers", 60 | "Request", 61 | "RequestInit", 62 | "RequestMode", 63 | "Response", 64 | ] } 65 | 66 | [package.metadata.docs.rs] 67 | targets = ["wasm32-unknown-unknown"] 68 | 69 | [dev-dependencies] 70 | wasm-bindgen-test = "0.3.50" 71 | -------------------------------------------------------------------------------- /hf_xet_wasm/README.md: -------------------------------------------------------------------------------- 1 | # hf_xet_wasm: xet-core for WebAssembly 2 | 3 | This crate enables functionality to use the xet upload protocol from the browser with the use of a wasm based binary replicating the functionality of the `hf_xet` python library. 4 | Functionality included but not limited to chunking, global deduplication, xorb formation, xorb upload, shard formation, shard upload. 5 | 6 | Download functionality is not currently supported. 7 | 8 | hf_xet_wasm has: chunking, global deduplication, xorb formation, xorb upload, shard formation, shard upload 9 | 10 | hf_xet_wasm is missing: complete download support (xorbs, shards, chunk caching) 11 | 12 | ## Critical Differences and Changes 13 | 14 | In order to compile xet-core to wasm there are numerous changes: 15 | 16 | - A version of the data crate that does not assume the presence of any tokio threads 17 | - there is not yet such a thing as "multiple threads" in WebAssembly (at the time of writing) 18 | - Additionally only a specific feature set of tokio is supported in WASM, we only use those traits: ["sync", "rt", "macros", "time", "io-util"] 19 | - To support multithreading we use web workers (wasm_thread dependency) 20 | - Any components that use `async_trait` are required to change the `async_trait` proc_macro usage to not dictate `Send`'ness 21 | - any use of `#[async_trait::async_trait]` becomes: 22 | - ```rust 23 | #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] 24 | #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] 25 | pub trait Blah {} 26 | ``` 27 | - this is required as the output from the `async_trait` macro is not compatible to be `Send` when compiled to WASM 28 | - (pattern adopted from from reqwest_middleware) 29 | - Moves any operations that utilise or rely on the file system to in memory, primarily shard formation and storage 30 | - We choose not to use on the file system interface provided to browser based applications 31 | - Remove custom dns resolver to HTTP requests 32 | - HTTP requests in the browser are limited fetch calls made by reqwest. 33 | - custom dns is not allowed, only HTTP 34 | 35 | ## Build Instructions 36 | 37 | - Install nightly toolchain and dependencies: 38 | ```bash 39 | rustup toolchain install nightly 40 | rustup component add rust-src --toolchain nightly 41 | cargo install --version 0.2.100 wasm-bindgen-cli 42 | ``` 43 | - Build with `./build_wasm.sh` (bash) 44 | 45 | ## Run Instructions 46 | 47 | The runnable example is composed of a set of files in the examples directory. 48 | 49 | First fill up the four `[FILL_ME]` fields in examples/index.html with a desired testing target. 50 | 51 | Then serve the web directory using a local http server, for example, https://crates.io/crates/sfz. 52 | 53 | - Install sfz: 54 | ```bash 55 | cargo install sfz 56 | ``` 57 | 58 | - Serve the web 59 | ```bash 60 | sfz --coi -r examples 61 | ``` 62 | 63 | - Observe in browser 64 | In browser, go to URL http://127.0.0.1:5000, hit F12 and check the output 65 | under the "Console" tab. 
65 | 
66 | ## Authentication in hf_xet_wasm
67 | 
68 | Like hf_xet, it is the caller's responsibility to set up authentication with the CAS server by getting a token from the huggingface hub.
69 | The caller is also required to provide a method to get a fresh/refreshed token from the hub in the event of token expiration.
70 | 
71 | In hf_xet_wasm it must be supplied to the XetSession using a user-defined set of interfaces.
72 | 
73 | ```typescript
74 | class TokenInfo {
75 |     token(): string {
76 |     }
77 |     exp(): bigint {
78 |         return this.exp;
79 |     }
80 | }
81 | 
82 | class TokenRefresher {
83 |     async refreshToken(): TokenInfo {
84 |     }
85 | }
86 | 
87 | const xetSession = new XetSession(endpoint, tokenInfo, tokenRefresher);
88 | ```
--------------------------------------------------------------------------------
/hf_xet_wasm/build_wasm.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/sh
 2 | 
 3 | set -ex
 4 | 
 5 | # A couple of steps are necessary to get this build working which makes it slightly
 6 | # nonstandard compared to most other builds.
 7 | #
 8 | # * First, the Rust standard library needs to be recompiled with atomics
 9 | #   enabled. To do that we use Cargo's unstable `-Zbuild-std` feature.
10 | #
11 | # * Next we need to compile everything with the `atomics` and `bulk-memory`
12 | #   features enabled, ensuring that LLVM will generate atomic instructions,
13 | #   shared memory, passive segments, etc.
14 | 
15 | RUSTFLAGS='-C target-feature=+atomics,+bulk-memory,+mutable-globals --cfg getrandom_backend="wasm_js"' \
16 |     cargo +nightly build --example simple --target wasm32-unknown-unknown --release -Z build-std=std,panic_abort
17 | 
18 | RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-bindgen \
19 |     target/wasm32-unknown-unknown/release/examples/simple.wasm \
20 |     --out-dir ./examples/target/ \
21 |     --typescript \
22 |     --target web
23 | 
--------------------------------------------------------------------------------
/hf_xet_wasm/examples/commit.js:
--------------------------------------------------------------------------------
 1 | /**
 2 |  * Commits a file to a Hugging Face dataset.
 3 |  *
 4 |  * @param {string} hf_endpoint - The HF Hub endpoint.
 5 |  * @param {string} file_name The name of the file to commit.
 6 |  * @param {string} sha256 The SHA256 hash of the file (as a string).
 7 |  * @param {number} file_size The size of the file in bytes.
 8 |  * @param {string} repo_type The type of the repo, i.e. "dataset" or "model" or "space"
 9 |  * @param {string} repo_id The id of the repo, specified as a namespace and a repo name separated by a '/'
10 |  * @param {string} revision The revision to make the commit on top of
11 |  * @param {string} hf_token The HF token for auth
12 |  * @returns {Promise} A promise that resolves if the commit is successful, or rejects if an error occurs.
13 |  */
14 | async function commit(hf_endpoint, file_name, sha256, file_size, repo_type, repo_id, revision, hf_token) {
15 |     const obj1 = {
16 |         key: "header",
17 |         value: {
18 |             summary: `Upload ${file_name} with hf_xet_wasm`,
19 |             description: ""
20 |         }
21 |     };
22 | 
23 |     const obj2 = {
24 |         key: "lfsFile",
25 |         value: {
26 |             path: file_name,
27 |             algo: "sha256",
28 |             oid: sha256,
29 |             size: file_size
30 |         }
31 |     };
32 | 
33 |     // Serialize to JSON string and concatenate with a newline.
34 |     const body = `${JSON.stringify(obj1)}\n${JSON.stringify(obj2)}`;
35 | 
36 |     const url = `${hf_endpoint}/api/${repo_type}s/${repo_id}/commit/${revision}`;
37 | 
38 |     try {
39 |         const response = await fetch(url, {
40 |             method: 'POST',
41 |             headers: {
42 |                 'Authorization': `Bearer ${hf_token}`,
43 |                 'Content-Type': 'application/x-ndjson'
44 |             },
45 |             body: body
46 |         });
47 | 
48 |         // Check for HTTP errors.
49 |         // `response.ok` is true for 2xx status codes.
50 |         // This is the equivalent of `error_for_status()`.
51 |         if (!response.ok) {
52 |             const errorText = await response.text();
53 |             throw new Error(`HTTP error! Status: ${response.status}, Body: ${errorText}`);
54 |         }
55 | 
56 |         // Get the response text.
57 |         const responseText = await response.text();
58 | 
59 |         return responseText;
60 |     } catch (error) {
61 |         // Handle any network errors or errors thrown by the fetch API.
62 |         // This is the equivalent of `anyhow::Result<()>`.
63 |         console.error("Commit failed:", error);
64 |         throw error; // Re-throw the error so the caller can handle it
65 |     }
66 | }
--------------------------------------------------------------------------------
/hf_xet_wasm/examples/xet_meta.js:
--------------------------------------------------------------------------------
 1 | function xetMetadataOrNone(jsonData) {
 2 |     /**
 3 |      * Extract XET metadata from the HTTP body or return null if not found.
 4 |      *
 5 |      * @param {Object} jsonData - HTTP body in JSON to extract the XET metadata from.
 6 |      * @returns {XetMetadata|null} The extracted metadata or null if missing.
 7 |      */
 8 | 
 9 |     const xetEndpoint = jsonData.casUrl;
10 |     const accessToken = jsonData.accessToken;
11 |     const expiration = jsonData.exp;
12 | 
13 |     if (xetEndpoint == undefined || accessToken == undefined || expiration == undefined) {
14 |         return null;
15 |     }
16 | 
17 |     const expirationUnixEpoch = parseInt(expiration, 10);
18 |     if (isNaN(expirationUnixEpoch)) {
19 |         return null;
20 |     }
21 | 
22 |     return {
23 |         endpoint: xetEndpoint,
24 |         accessToken: accessToken,
25 |         expirationUnixEpoch: expirationUnixEpoch,
26 |     };
27 | }
28 | 
29 | async function fetchXetMetadataFromRepoInfo({
30 |     hfEndpoint,
31 |     tokenType,
32 |     repoId,
33 |     repoType,
34 |     headers,
35 |     params = null
36 | }) {
37 |     /**
38 |      * Uses the repo info to request a XET access token from Hub.
39 |      *
40 |      * @param {string} hfEndpoint - The HF Hub endpoint.
41 |      * @param {string} tokenType - Type of the token to request: "read" or "write".
42 |      * @param {string} repoId - A namespace (user or an organization) and a repo name separated by a `/`.
43 |      * @param {string} repoType - Type of the repo to upload to: "model", "dataset", or "space".
44 |      * @param {Object} headers - Headers to use for the request, including authorization headers and user agent.
45 |      * @param {Object|null} params - Additional parameters to pass with the request.
46 |      * @returns {Promise} The metadata needed to make the request to the XET storage service.
47 |      * @throws {Error} If the Hub API returned an error or the response is improperly formatted.
48 |      */
49 | 
50 |     const url = `${hfEndpoint}/api/${repoType}s/${repoId}/xet-${tokenType}-token/main`;
51 |     console.log(`${url}`);
52 | 
53 |     return fetchXetMetadataWithUrl(url, headers, params);
54 | }
55 | 
56 | async function fetchXetMetadataWithUrl(url, headers, params = null) {
57 |     /**
58 |      * Requests the XET access token from the supplied URL.
59 |      *
60 |      * @param {string} url - The access token endpoint URL.
61 |      * @param {Object} headers - Headers to use for the request, including authorization headers and user agent.
62 |      * @param {Object|null} params - Additional parameters to pass with the request.
63 |      * @returns {Promise} The metadata needed to make the request to the XET storage service.
64 |      * @throws {Error} If the Hub API returned an error or the response is improperly formatted.
65 |      */
66 | 
67 |     const response = await fetch(url, {
68 |         method: "GET",
69 |         headers: headers,
70 |     });
71 | 
72 |     const jsonData = await response.json();
73 | 
74 |     if (!response.ok) {
75 |         console.log("response not ok");
76 |         throw new Error(`HTTP error! Status: ${response.status}`);
77 |     }
78 | 
79 |     const metadata = xetMetadataOrNone(jsonData);
80 |     if (!metadata) {
81 |         throw new Error("XET headers have not been correctly set by the server.");
82 |     }
83 | 
84 |     return metadata;
85 | }
--------------------------------------------------------------------------------
/hf_xet_wasm/rust-toolchain.toml:
--------------------------------------------------------------------------------
 1 | # Before upgrading check that everything is available on all tier1 targets here:
 2 | # https://rust-lang.github.io/rustup-components-history
 3 | [toolchain]
 4 | channel = "nightly"
 5 | components = ["rust-src", "rustfmt"]
 6 | targets = ["wasm32-unknown-unknown"]
 7 | 
--------------------------------------------------------------------------------
/hf_xet_wasm/src/auth.rs:
--------------------------------------------------------------------------------
 1 | use std::fmt::{Debug, Formatter};
 2 | use std::sync::Arc;
 3 | 
 4 | use tokio::sync::Mutex;
 5 | use wasm_bindgen::prelude::*;
 6 | use wasm_bindgen::JsValue;
 7 | 
 8 | /// The internal token refreshing mechanism expects to be passed TokenInfo and TokenRefresher
 9 | /// javascript objects that implement the following interface.
10 | ///
11 | /// To pass these constructs into the wasm program, use the XetSession export.
12 | ///
13 | /// ```typescript
14 | /// interface TokenInfo {
15 | ///     token(): string {}
16 | ///     exp(): number {}
17 | /// }
18 | ///
19 | /// interface TokenRefresher {
20 | ///     async refreshToken(): TokenInfo {}
21 | /// }
22 | /// ```
23 | 
24 | #[wasm_bindgen]
25 | extern "C" {
26 |     pub type TokenInfo;
27 |     #[wasm_bindgen(method, getter)]
28 |     pub fn token(this: &TokenInfo) -> String;
29 |     #[wasm_bindgen(method, getter)]
30 |     pub fn exp(this: &TokenInfo) -> f64;
31 | 
32 |     pub type TokenRefresher;
33 |     #[wasm_bindgen(method, catch, js_name = "refreshToken")]
34 |     pub async fn refresh_token(this: &TokenRefresher) -> Result<TokenInfo, JsValue>;
35 | }
36 | 
37 | impl From<TokenInfo> for utils::auth::TokenInfo {
38 |     fn from(value: TokenInfo) -> Self {
39 |         (value.token(), value.exp() as u64)
40 |     }
41 | }
42 | 
43 | impl Debug for TokenRefresher {
44 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
45 |         write!(f, "TokenRefresher")
46 |     }
47 | }
48 | 
49 | #[derive(Debug, Clone)]
50 | pub(crate) struct WrappedTokenRefresher(Arc<Mutex<TokenRefresher>>);
51 | 
52 | // TODO: revise the safety of this!
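   | // These impls are only tenable because the wasm build runs on a single-threaded
   | // event loop: the wrapped JS `TokenRefresher` is never actually touched from two
   | // threads at once, even though `JsValue`-backed types are not truly thread-safe.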
53 | unsafe impl Send for WrappedTokenRefresher {}
54 | unsafe impl Sync for WrappedTokenRefresher {}
55 | 
56 | impl From<TokenRefresher> for WrappedTokenRefresher {
57 |     fn from(value: TokenRefresher) -> Self {
58 |         #[allow(clippy::arc_with_non_send_sync)]
59 |         WrappedTokenRefresher(Arc::new(Mutex::new(value)))
60 |     }
61 | }
62 | 
63 | #[async_trait::async_trait(?Send)]
64 | impl utils::auth::TokenRefresher for WrappedTokenRefresher {
65 |     async fn refresh(&self) -> Result<utils::auth::TokenInfo, utils::errors::AuthError> {
66 |         self.0
67 |             .lock()
68 |             .await
69 |             .refresh_token()
70 |             .await
71 |             .map(utils::auth::TokenInfo::from)
72 |             .map_err(|e| utils::errors::AuthError::token_refresh_failure(format!("{e:?}")))
73 |     }
74 | }
75 | 
--------------------------------------------------------------------------------
/hf_xet_wasm/src/configurations.rs:
--------------------------------------------------------------------------------
 1 | use cas_object::CompressionScheme;
 2 | use utils::auth::AuthConfig;
 3 | 
 4 | // configurations for hf_xet_wasm components, generally less complicated than hf_xet/data crate configurations
 5 | 
 6 | #[derive(Debug)]
 7 | pub struct DataConfig {
 8 |     pub endpoint: String,
 9 |     pub compression: Option<CompressionScheme>,
10 |     pub auth: Option<AuthConfig>,
11 |     pub prefix: String,
12 | }
13 | 
14 | #[derive(Debug)]
15 | pub struct ShardConfig {
16 |     pub prefix: String,
17 | }
18 | 
19 | #[derive(Debug)]
20 | pub struct TranslatorConfig {
21 |     pub data_config: DataConfig,
22 |     pub shard_config: ShardConfig,
23 |     pub session_id: String,
24 | }
25 | 
--------------------------------------------------------------------------------
/hf_xet_wasm/src/errors.rs:
--------------------------------------------------------------------------------
 1 | use std::fmt::Debug;
 2 | 
 3 | use cas_client::CasClientError;
 4 | use cas_object::error::CasObjectError;
 5 | use mdb_shard::error::MDBShardError;
 6 | use merklehash::DataHashHexParseError;
 7 | use thiserror::Error;
 8 | 
 9 | #[non_exhaustive]
10 | #[derive(Error, Debug)]
11 | pub enum DataProcessingError {
12 |     #[error("Internal error: {0}")]
13 |     InternalError(String),
14 | 
15 |     #[error("Bad hash: {0}")]
16 |     BadHash(#[from] DataHashHexParseError),
17 | 
18 |     #[error("CAS service error : {0}")]
19 |     CasClientError(#[from] CasClientError),
20 | 
21 |     #[error("Xorb Serialization error : {0}")]
22 |     XorbSerializationError(#[from] CasObjectError),
23 | 
24 |     #[error("MerkleDB Shard error: {0}")]
25 |     MDBShardError(#[from] MDBShardError),
26 | }
27 | 
28 | impl DataProcessingError {
29 |     pub fn internal<T: Debug>(value: T) -> Self {
30 |         DataProcessingError::InternalError(format!("{value:?}"))
31 |     }
32 | }
33 | 
34 | pub type Result<T> = std::result::Result<T, DataProcessingError>;
35 | 
--------------------------------------------------------------------------------
/hf_xet_wasm/src/lib.rs:
--------------------------------------------------------------------------------
 1 | #[cfg(not(target_family = "wasm"))]
 2 | compile_error!("This crate is only meant to be used on the WebAssembly target");
 3 | 
 4 | mod auth;
 5 | pub mod blob_reader;
 6 | pub mod configurations;
 7 | mod errors;
 8 | mod session;
 9 | mod sha256;
10 | mod wasm_deduplication_interface;
11 | mod wasm_file_cleaner;
12 | pub mod wasm_file_upload_session;
13 | pub mod wasm_timer;
14 | mod xorb_uploader;
15 | 
16 | pub use session::XetSession;
17 | 
18 | // sample test
19 | #[cfg(test)]
20 | mod tests {
21 |     use wasm_bindgen_test::{wasm_bindgen_test, wasm_bindgen_test_configure};
22 | 
23 |     wasm_bindgen_test_configure!(run_in_browser);
24 | 
25 |     #[test]
26 |     #[wasm_bindgen_test]
27 |     fn simple_test() {
28 |         assert_eq!(1, 1);
29 |     }
30 | }
31 | 
--------------------------------------------------------------------------------
/hf_xet_wasm/src/sha256.rs:
--------------------------------------------------------------------------------
 1 | use deduplication::Chunk;
 2 | use merklehash::MerkleHash;
 3 | use sha2::{Digest, Sha256};
 4 | 
 5 | use super::errors::*;
 6 | 
 7 | // Utilities to generate sha256 hashes in webassembly in rust.
 8 | // The `Value` variant wraps an already-provided result hash behind the same interface as a real
 9 | // hashing operation, so callers can treat a known hash uniformly while no work is done.
10 | pub enum ShaGeneration {
11 |     Value(MerkleHash),
12 |     Action(ShaGenerator),
13 | }
14 | 
15 | impl ShaGeneration {
16 |     pub fn new(hash: Option<MerkleHash>) -> Self {
17 |         match hash {
18 |             Some(h) => Self::Value(h),
19 |             None => Self::Action(ShaGenerator::new()),
20 |         }
21 |     }
22 | 
23 |     pub fn update(&mut self, new_chunks: &[Chunk]) {
24 |         match self {
25 |             ShaGeneration::Value(_) => {},
26 |             ShaGeneration::Action(sha_generator) => sha_generator.update(new_chunks),
27 |         }
28 |     }
29 | 
30 |     pub fn update_with_bytes(&mut self, new_bytes: &[u8]) {
31 |         match self {
32 |             ShaGeneration::Value(_) => {},
33 |             ShaGeneration::Action(sha_generator) => sha_generator.update_with_bytes(new_bytes),
34 |         }
35 |     }
36 | 
37 |     pub fn finalize(self) -> Result<MerkleHash> {
38 |         match self {
39 |             ShaGeneration::Value(hash) => Ok(hash),
40 |             ShaGeneration::Action(sha_generator) => sha_generator.finalize(),
41 |         }
42 |     }
43 | }
44 | 
45 | // Struct to generate a sha256 progressively by calling `update` or the `update_with_bytes` variation,
46 | // yielding the final hash when calling `finalize()`.
47 | pub struct ShaGenerator {
48 |     hasher: Sha256,
49 | }
50 | 
51 | impl ShaGenerator {
52 |     pub fn new() -> Self {
53 |         Self {
54 |             hasher: Sha256::default(),
55 |         }
56 |     }
57 | 
58 |     pub fn update(&mut self, new_chunks: &[Chunk]) {
59 |         for chunk in new_chunks.iter() {
60 |             self.hasher.update(&chunk.data);
61 |         }
62 |     }
63 | 
64 |     pub fn update_with_bytes(&mut self, new_bytes: &[u8]) {
65 |         self.hasher.update(new_bytes);
66 |     }
67 | 
68 |     pub fn finalize(self) -> Result<MerkleHash> {
69 |         let sha256 = self.hasher.finalize();
70 |         let hex_str = format!("{sha256:x}");
71 |         Ok(MerkleHash::from_hex(&hex_str)?)
72 |     }
73 | }
74 | 
--------------------------------------------------------------------------------
/hf_xet_wasm/src/wasm_timer.rs:
--------------------------------------------------------------------------------
 1 | use web_sys::console;
 2 | 
 3 | static PROFILING: bool = false;
 4 | 
 5 | // When dropped, the elapsed time from the timer's creation to the drop
 6 | // is logged to the browser console.
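   | // Typical usage is to bind the timer to a local so the guard lives for the whole scope:
   | //     let _timer = ConsoleTimer::new("upload xorb");
   | // With PROFILING set to false (as above), `new` returns an inert timer with an empty
   | // label; `new_enforce_report` logs unconditionally.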
-------------------------------------------------------------------------------- /hf_xet_wasm/src/wasm_timer.rs: -------------------------------------------------------------------------------- 1 | use web_sys::console; 2 | 3 | static PROFILING: bool = false; 4 | 5 | // When dropped, the elapsed time since the timer was created is logged 6 | // to the browser console. 7 | pub struct ConsoleTimer { 8 | name: String, 9 | } 10 | 11 | impl ConsoleTimer { 12 | pub fn new(name: impl AsRef<str>) -> ConsoleTimer { 13 | if PROFILING { 14 | Self::new_enforce_report(name) 15 | } else { 16 | ConsoleTimer { name: "".to_owned() } 17 | } 18 | } 19 | 20 | pub fn new_enforce_report(name: impl AsRef<str>) -> ConsoleTimer { 21 | let name = name.as_ref(); 22 | console::time_with_label(name); 23 | ConsoleTimer { name: name.to_owned() } 24 | } 25 | } 26 | 27 | impl Drop for ConsoleTimer { 28 | fn drop(&mut self) { 29 | if !self.name.is_empty() { 30 | console::time_end_with_label(&self.name); 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /hf_xet_wasm/src/xorb_uploader.rs: -------------------------------------------------------------------------------- 1 | use std::result::Result as stdResult; 2 | use std::sync::Arc; 3 | 4 | use async_trait::async_trait; 5 | use cas_client::{CasClientError, Client}; 6 | use cas_object::SerializedCasObject; 7 | use tokio::sync::Semaphore; 8 | use tokio_with_wasm::alias as wasmtokio; 9 | 10 | use crate::errors::*; 11 | use crate::wasm_timer::ConsoleTimer; 12 | 13 | #[cfg_attr(not(target_family = "wasm"), async_trait)] 14 | #[cfg_attr(target_family = "wasm", async_trait(?Send))] 15 | pub trait XorbUploader { 16 | async fn upload_xorb(&mut self, input: SerializedCasObject) -> Result<()>; 17 | async fn finalize(&mut self) -> Result<()>; 18 | } 19 | 20 | pub struct XorbUploaderLocalSequential { 21 | client: Arc<dyn Client>, 22 | cas_prefix: String, 23 | } 24 | 25 | impl XorbUploaderLocalSequential { 26 | pub fn new(client: Arc<dyn Client>, cas_prefix: &str, _upload_concurrency: usize) -> Self { 27 | Self { 28 | client, 29 | cas_prefix: cas_prefix.to_owned(), 30 | } 31 | } 32 | } 33 | 34 | #[cfg_attr(not(target_family = "wasm"), async_trait)] 35 | #[cfg_attr(target_family = "wasm", async_trait(?Send))] 36 | impl XorbUploader for XorbUploaderLocalSequential { 37 | async fn upload_xorb(&mut self, input: SerializedCasObject) -> Result<()> { 38 | let _ = self.client.upload_xorb(&self.cas_prefix, input, None).await?; 39 | Ok(()) 40 | } 41 | 42 | async fn finalize(&mut self) -> Result<()> { 43 | Ok(()) 44 | } 45 | } 46 | 47 | pub struct XorbUploaderSpawnParallel { 48 | client: Arc<dyn Client>, 49 | cas_prefix: String, 50 | semaphore: Arc<Semaphore>, 51 | tasks: wasmtokio::task::JoinSet<stdResult<u64, CasClientError>>, 52 | } 53 | 54 | impl XorbUploaderSpawnParallel { 55 | pub fn new(client: Arc<dyn Client>, cas_prefix: &str, upload_concurrency: usize) -> Self { 56 | Self { 57 | client, 58 | cas_prefix: cas_prefix.to_owned(), 59 | semaphore: Arc::new(Semaphore::new(upload_concurrency)), 60 | tasks: wasmtokio::task::JoinSet::new(), 61 | } 62 | } 63 | } 64 | 65 | #[cfg_attr(not(target_family = "wasm"), async_trait)] 66 | #[cfg_attr(target_family = "wasm", async_trait(?Send))] 67 | impl XorbUploader for XorbUploaderSpawnParallel { 68 | async fn upload_xorb(&mut self, input: SerializedCasObject) -> Result<()> { 69 | while let Some(ret) = self.tasks.try_join_next() { 70 | ret.map_err(DataProcessingError::internal)??; 71 | } 72 | 73 | let client = self.client.clone(); 74 | let cas_prefix = self.cas_prefix.clone(); 75 | let permit = self 76 | .semaphore 77 | .clone() 78 | .acquire_owned() 79 | .await 80 | .map_err(DataProcessingError::internal)?; 81 | self.tasks.spawn(async move { 82 | let _timer = ConsoleTimer::new(format!("upload xorb {}", input.hash)); 83 | let ret = client.upload_xorb(&cas_prefix, input, None).await; 84 | drop(permit); 85 | ret 86 | }); 87 | 88 | Ok(()) 89 | } 90 | 91 | async fn
finalize(&mut self) -> Result<()> { 92 | while let Some(ret) = self.tasks.join_next().await { 93 | ret.map_err(DataProcessingError::internal)??; 94 | } 95 | 96 | Ok(()) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /hf_xet_wasm/webdriver.json: -------------------------------------------------------------------------------- 1 | { 2 | "moz:firefoxOptions": { 3 | "args": [] 4 | }, 5 | "goog:chromeOptions": { 6 | "args": [ 7 | ] 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /hub_client/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "hub_client" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | cas_client = { path = "../cas_client" } 8 | 9 | anyhow = { workspace = true } 10 | async-trait = { workspace = true } 11 | http = { workspace = true } 12 | reqwest = { workspace = true } 13 | reqwest-middleware = { workspace = true } 14 | serde = { workspace = true } 15 | thiserror = { workspace = true } 16 | urlencoding = { workspace = true } 17 | 18 | [dev-dependencies] 19 | serde_json = { workspace = true } 20 | tokio = { workspace = true } -------------------------------------------------------------------------------- /hub_client/src/auth.rs: -------------------------------------------------------------------------------- 1 | mod basics; 2 | mod interface; 3 | 4 | pub use basics::{BearerCredentialHelper, NoopCredentialHelper}; 5 | pub use interface::CredentialHelper; 6 | -------------------------------------------------------------------------------- /hub_client/src/auth/basics.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use anyhow::Result; 4 | use async_trait::async_trait; 5 | use reqwest_middleware::RequestBuilder; 6 | 7 | use super::CredentialHelper; 8 | 9 | pub struct NoopCredentialHelper {} 10 | 11 | impl NoopCredentialHelper { 12 | #[allow(clippy::new_ret_no_self)] 13 | pub fn new() -> Arc { 14 | Arc::new(Self {}) 15 | } 16 | } 17 | 18 | #[async_trait] 19 | impl CredentialHelper for NoopCredentialHelper { 20 | async fn fill_credential(&self, req: RequestBuilder) -> Result { 21 | Ok(req) 22 | } 23 | 24 | fn whoami(&self) -> &str { 25 | "noop" 26 | } 27 | } 28 | 29 | pub struct BearerCredentialHelper { 30 | pub hf_token: String, 31 | 32 | _whoami: &'static str, 33 | } 34 | 35 | impl BearerCredentialHelper { 36 | #[allow(clippy::new_ret_no_self)] 37 | pub fn new(hf_token: String, whoami: &'static str) -> Arc { 38 | Arc::new(Self { 39 | hf_token, 40 | _whoami: whoami, 41 | }) 42 | } 43 | } 44 | 45 | #[async_trait] 46 | impl CredentialHelper for BearerCredentialHelper { 47 | async fn fill_credential(&self, req: RequestBuilder) -> Result { 48 | Ok(req.bearer_auth(&self.hf_token)) 49 | } 50 | 51 | fn whoami(&self) -> &str { 52 | self._whoami 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /hub_client/src/auth/interface.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use async_trait::async_trait; 3 | use reqwest_middleware::RequestBuilder; 4 | 5 | #[async_trait] 6 | pub trait CredentialHelper: Send + Sync { 7 | async fn fill_credential(&self, req: RequestBuilder) -> Result; 8 | // Used in tests to identify the source of the credential. 
9 | fn whoami(&self) -> &str; 10 | } 11 | -------------------------------------------------------------------------------- /hub_client/src/errors.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug)] 4 | pub enum HubClientError { 5 | #[error("Cas client error: {0}")] 6 | CasClient(#[from] cas_client::CasClientError), 7 | 8 | #[error("Reqwest error: {0}")] 9 | Reqwest(#[from] reqwest::Error), 10 | 11 | #[error("Reqwest middleware error: {0}")] 12 | ReqwestMiddleware(#[from] reqwest_middleware::Error), 13 | 14 | #[error("Credential helper error: {0}")] 15 | CredentialHelper(anyhow::Error), 16 | 17 | #[error("Invalid repo type: {0}")] 18 | InvalidRepoType(String), 19 | } 20 | 21 | pub type Result = std::result::Result; 22 | 23 | impl HubClientError { 24 | pub fn credential_helper_error(e: impl std::error::Error + Send + Sync + 'static) -> HubClientError { 25 | HubClientError::CredentialHelper(e.into()) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /hub_client/src/lib.rs: -------------------------------------------------------------------------------- 1 | mod auth; 2 | mod client; 3 | mod errors; 4 | mod types; 5 | 6 | pub use auth::{BearerCredentialHelper, CredentialHelper, NoopCredentialHelper}; 7 | pub use client::{HubClient, Operation}; 8 | pub use errors::{HubClientError, Result}; 9 | pub use types::{CasJWTInfo, HFRepoType, RepoInfo}; 10 | -------------------------------------------------------------------------------- /hub_client/src/types.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Display; 2 | use std::str::FromStr; 3 | 4 | use serde::Deserialize; 5 | 6 | use crate::errors::{HubClientError, Result}; 7 | 8 | /// This defines the response format from the Huggingface Hub Xet CAS access token API. 9 | #[derive(Deserialize, Debug)] 10 | #[serde(rename_all = "camelCase")] 11 | pub struct CasJWTInfo { 12 | pub cas_url: String, // CAS server endpoint base URL 13 | pub exp: u64, // access token expiry since UNIX_EPOCH 14 | pub access_token: String, 15 | } 16 | 17 | // This defines the exact three types of repos served on HF Hub. 
18 | #[derive(Debug, PartialEq)] 19 | pub enum HFRepoType { 20 | Model, 21 | Dataset, 22 | Space, 23 | } 24 | 25 | impl FromStr for HFRepoType { 26 | type Err = HubClientError; 27 | 28 | fn from_str(s: &str) -> std::result::Result { 29 | match s.to_lowercase().as_str() { 30 | "" => Ok(HFRepoType::Model), // when repo type is omitted from the URL the default type is "model" 31 | "model" | "models" => Ok(HFRepoType::Model), 32 | "dataset" | "datasets" => Ok(HFRepoType::Dataset), 33 | "space" | "spaces" => Ok(HFRepoType::Space), 34 | t => Err(HubClientError::InvalidRepoType(t.to_owned())), 35 | } 36 | } 37 | } 38 | 39 | impl HFRepoType { 40 | pub fn as_str(&self) -> &str { 41 | match self { 42 | HFRepoType::Model => "model", 43 | HFRepoType::Dataset => "dataset", 44 | HFRepoType::Space => "space", 45 | } 46 | } 47 | } 48 | 49 | impl Display for HFRepoType { 50 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 51 | f.write_str(self.as_str()) 52 | } 53 | } 54 | 55 | #[derive(Debug, PartialEq)] 56 | pub struct RepoInfo { 57 | // The type of a repo, one of "model | dataset | space" 58 | pub repo_type: HFRepoType, 59 | // The full name of a repo, formatted as "owner/name" 60 | pub full_name: String, 61 | } 62 | 63 | impl RepoInfo { 64 | pub fn try_from(repo_type: &str, repo_id: &str) -> Result { 65 | Ok(Self { 66 | repo_type: repo_type.parse()?, 67 | full_name: repo_id.into(), 68 | }) 69 | } 70 | } 71 | 72 | impl Display for RepoInfo { 73 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 74 | write!(f, "{}/{}", self.repo_type, self.full_name) 75 | } 76 | } 77 | 78 | #[cfg(test)] 79 | mod tests { 80 | use anyhow::Result; 81 | 82 | use crate::types::CasJWTInfo; 83 | 84 | #[test] 85 | fn test_cas_jwt_response_deser() -> Result<()> { 86 | let bytes = r#"{"casUrl":"https://cas-server.xethub.hf.co","exp":1756489133,"accessToken":"ey...jQ"}"#; 87 | 88 | let info: CasJWTInfo = serde_json::from_slice(bytes.as_bytes())?; 89 | 90 | assert_eq!(info.cas_url, "https://cas-server.xethub.hf.co"); 91 | assert_eq!(info.exp, 1756489133); 92 | assert_eq!(info.access_token, "ey...jQ"); 93 | 94 | Ok(()) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /markdownlint.toml: -------------------------------------------------------------------------------- 1 | default = true 2 | MD013.line_length = 300 3 | -------------------------------------------------------------------------------- /mdb_shard/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mdb_shard" 3 | version = "0.14.5" 4 | edition = "2024" 5 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 6 | 7 | [dependencies] 8 | merklehash = { path = "../merklehash" } 9 | utils = { path = "../utils" } 10 | 11 | anyhow = { workspace = true } 12 | async-trait = { workspace = true } 13 | blake3 = { workspace = true } 14 | bytes = { workspace = true } 15 | clap = { workspace = true } 16 | futures = { workspace = true } 17 | futures-util = { workspace = true } 18 | heapify = { workspace = true } 19 | itertools = { workspace = true } 20 | lazy_static = { workspace = true } 21 | rand = { workspace = true, features = ["small_rng"] } 22 | regex = { workspace = true } 23 | serde = { workspace = true } 24 | static_assertions = { workspace = true } 25 | tempfile = { workspace = true } 26 | thiserror = { workspace = true } 27 | tokio = { workspace = true } 28 | tracing = { workspace 
= true } 29 | 30 | [target.'cfg(target_family = "wasm")'.dependencies] 31 | uuid = { workspace = true, features = ["v4", "js"] } 32 | [target.'cfg(not(target_family = "wasm"))'.dependencies] 33 | uuid = { workspace = true, features = ["v4"] } 34 | tokio = { workspace = true, features = ["rt-multi-thread"] } 35 | 36 | [[bin]] 37 | name = "shard_benchmark" 38 | path = "src/shard_benchmark.rs" 39 | -------------------------------------------------------------------------------- /mdb_shard/README.md: -------------------------------------------------------------------------------- 1 | # mdb_shard 2 | 3 | > MDB -> Merkle Database 4 | 5 | The mdb_shard crate exposes multiple interfaces for working with shards. 6 | 7 | In particular, this includes the shard file format, which is used both in API payloads and internally within xet-core 8 | to manage and store state during and between processes while deduplicating and uploading data. 9 | 10 | ## Serialization and Deserialization Interfaces 11 | 12 | The mdb_shard crate provides multiple interfaces for serializing and deserializing shard data, organized by their purpose and usage patterns. 13 | These interfaces allow you to work with shard data at different levels of abstraction, from low-level binary serialization to high-level streaming processing. 14 | 15 | ### Core Shard Format Interfaces 16 | 17 | [**`src/shard_format.rs`**](src/shard_format.rs) 18 | 19 | These interfaces handle the core shard file format and metadata: 20 | 21 | - **`MDBShardInfo::load_from_reader()`** - Loads complete shard metadata (header + footer) from a reader 22 | - **`MDBShardInfo::serialize_from()`** - Serializes an in-memory shard to binary format 23 | 24 | ### Streaming and Processing Interfaces 25 | 26 | [**`src/streaming_shard.rs`**](src/streaming_shard.rs) 27 | 28 | - **`MDBMinimalShard::from_reader()`** - Creates a minimal shard representation for lightweight operations from a reader 29 | - **`MDBMinimalShard::from_reader_async()`** - Creates a minimal shard representation for lightweight operations, from an async reader 30 | 31 | ### File Handle Interfaces 32 | 33 | [**`src/shard_file_handle.rs`**](src/shard_file_handle.rs) 34 | 35 | - **`MDBShardFile::load_from_file()`** - Loads shard from a file path with caching 36 | -------------------------------------------------------------------------------- /mdb_shard/src/chunk_verification.rs: -------------------------------------------------------------------------------- 1 | use merklehash::MerkleHash; 2 | 3 | /// The hash key used for generating chunk range hash for shard verification 4 | pub const VERIFICATION_KEY: [u8; 32] = [ 5 | 127, 24, 87, 214, 206, 86, 237, 102, 18, 127, 249, 19, 231, 165, 195, 243, 164, 205, 38, 213, 181, 219, 73, 230, 6 | 65, 36, 152, 127, 40, 251, 148, 195, 7 | ]; 8 | 9 | pub fn range_hash_from_chunks(chunks: &[MerkleHash]) -> MerkleHash { 10 | let combined: Vec<u8> = chunks.iter().flat_map(|hash| hash.as_bytes().to_vec()).collect(); 11 | 12 | // Now apply the keyed hash (HMAC-style) to the concatenated hashes and return. 13 | let range_hash = blake3::keyed_hash(&VERIFICATION_KEY, combined.as_slice()); 14 | 15 | MerkleHash::from(range_hash.as_bytes()) 16 | } 17 |
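A small sketch of how the range hash above might be used during shard verification; `verify_range` is illustrative and not part of the crate:

```rust
use mdb_shard::chunk_verification::range_hash_from_chunks;
use merklehash::MerkleHash;

// Recompute the keyed range hash over a chunk range and compare it
// against the value recorded in the shard.
fn verify_range(chunks: &[MerkleHash], expected: &MerkleHash) -> bool {
    range_hash_from_chunks(chunks) == *expected
}
```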
-------------------------------------------------------------------------------- /mdb_shard/src/constants.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use utils::ByteSize; 4 | 5 | utils::configurable_constants! { 6 | 7 | /// The target shard size. 8 | ref MDB_SHARD_TARGET_SIZE: u64 = 64 * 1024 * 1024; 9 | 10 | /// Maximum shard size; small shards are aggregated until they are at most this. 11 | ref MDB_SHARD_MAX_TARGET_SIZE: u64 = 64 * 1024 * 1024; 12 | 13 | /// The global dedup chunk modulus; a chunk is considered global dedup 14 | /// eligible if the hash modulus this value is zero. 15 | ref MDB_SHARD_GLOBAL_DEDUP_CHUNK_MODULUS: u64 = release_fixed(1024); 16 | 17 | /// The (soft) maximum size in bytes of the shard cache. Default is 16 GB. 18 | /// 19 | /// As a rough calculation, a cache of size X will allow for dedup against data 20 | /// of size 1000 * X. The default would allow a 16 TB repo to be deduped effectively. 21 | /// 22 | /// Note the cache is pruned to below this value at the beginning of a session, 23 | /// but during a single session new shards may be added such that this limit is exceeded. 24 | ref SHARD_CACHE_SIZE_LIMIT: ByteSize = ByteSize::from("16gb"); 25 | 26 | /// The amount of time a shard must be past its expiration before it's deleted. 27 | /// By default set to 7 days. 28 | ref MDB_SHARD_EXPIRATION_BUFFER: Duration = Duration::from_secs(7 * 24 * 3600); 29 | 30 | /// The maximum size of the chunk index table that's stored in memory. After this, 31 | /// no new chunks are loaded for deduplication. 32 | ref CHUNK_INDEX_TABLE_MAX_SIZE: usize = 64 * 1024 * 1024; 33 | } 34 | 35 | // How the MDB_SHARD_GLOBAL_DEDUP_CHUNK_MODULUS is used. 36 | pub fn hash_is_global_dedup_eligible(h: &merklehash::MerkleHash) -> bool { 37 | (*h) % *MDB_SHARD_GLOBAL_DEDUP_CHUNK_MODULUS == 0 38 | } 39 | -------------------------------------------------------------------------------- /mdb_shard/src/error.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | 3 | use merklehash::MerkleHash; 4 | use thiserror::Error; 5 | use utils::RwTaskLockError; 6 | 7 | #[non_exhaustive] 8 | #[derive(Error, Debug)] 9 | pub enum MDBShardError { 10 | #[error("File I/O error")] 11 | IOError(#[from] io::Error), 12 | 13 | #[error("Too many collisions when searching for truncated hash: {0}")] 14 | TruncatedHashCollisionError(u64), 15 | 16 | #[error("Shard version error: {0}")] 17 | ShardVersionError(String), 18 | 19 | #[error("Bad file name format: {0}")] 20 | BadFilename(String), 21 | 22 | #[error("Other Internal Error: {0}")] 23 | InternalError(anyhow::Error), 24 | 25 | #[error("Shard not found")] 26 | ShardNotFound(MerkleHash), 27 | 28 | #[error("File not found")] 29 | FileNotFound(MerkleHash), 30 | 31 | #[error("Query failed: {0}")] 32 | QueryFailed(String), 33 | 34 | #[error("Smudge query policy Error: {0}")] 35 | SmudgeQueryPolicyError(String), 36 | 37 | #[error("Runtime Error (task scheduler): {0}")] 38 | TaskRuntimeError(#[from] RwTaskLockError), 39 | 40 | #[error("Runtime Error (task scheduler): {0}")] 41 | TaskJoinError(#[from] tokio::task::JoinError), 42 | 43 | #[error("InvalidShard {0}")] 44 | InvalidShard(String), 45 | 46 | #[error("Error: {0}")] 47 | Other(String), 48 | } 49 | 50 | // Define our own result type here (this seems to be the standard).
51 | pub type Result<T> = std::result::Result<T, MDBShardError>; 52 | 53 | // For error checking 54 | impl PartialEq for MDBShardError { 55 | fn eq(&self, other: &MDBShardError) -> bool { 56 | match (self, other) { 57 | (MDBShardError::IOError(e1), MDBShardError::IOError(e2)) => e1.kind() == e2.kind(), 58 | _ => false, 59 | } 60 | } 61 | } 62 | 63 | impl MDBShardError { 64 | pub fn other(inner: impl ToString) -> Self { 65 | Self::Other(inner.to_string()) 66 | } 67 | 68 | pub fn invalid_shard(inner: impl ToString) -> Self { 69 | Self::InvalidShard(inner.to_string()) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /mdb_shard/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod cas_structs; 2 | pub mod chunk_verification; 3 | pub mod constants; 4 | pub mod error; 5 | pub mod file_structs; 6 | pub mod interpolation_search; 7 | pub mod session_directory; 8 | pub mod set_operations; 9 | pub mod shard_file_handle; 10 | pub mod shard_file_manager; 11 | pub mod shard_file_reconstructor; 12 | pub mod shard_format; 13 | pub mod shard_in_memory; 14 | pub mod utils; 15 | 16 | pub use constants::{MDB_SHARD_TARGET_SIZE, hash_is_global_dedup_eligible}; 17 | pub use shard_file_handle::MDBShardFile; 18 | pub use shard_file_manager::ShardFileManager; 19 | pub use shard_format::{MDBShardFileFooter, MDBShardFileHeader, MDBShardInfo}; 20 | 21 | // Temporary to transition dependent code to new location 22 | pub mod shard_file; 23 | 24 | pub mod streaming_shard; 25 | -------------------------------------------------------------------------------- /mdb_shard/src/shard_file.rs: -------------------------------------------------------------------------------- 1 | // Temporary to transition dependent code to new location 2 | pub use crate::shard_format::*; 3 | -------------------------------------------------------------------------------- /mdb_shard/src/shard_file_reconstructor.rs: -------------------------------------------------------------------------------- 1 | use merklehash::MerkleHash; 2 | 3 | use crate::file_structs::MDBFileInfo; 4 | 5 | #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] 6 | #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] 7 | pub trait FileReconstructor<E> { 8 | /// Returns a pair of (file reconstruction information, maybe shard ID) 9 | /// Err(_) if an error occurred 10 | /// Ok(None) if the file is not found. 11 | async fn get_file_reconstruction_info( 12 | &self, 13 | file_hash: &MerkleHash, 14 | ) -> Result<Option<(MDBFileInfo, Option<MerkleHash>)>, E>; 15 | } 16 |
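A hedged sketch of a trivial `FileReconstructor` implementation, e.g. for tests. It assumes a non-wasm build (plain `async_trait`) and that `MDBFileInfo` implements `Clone`; `MapReconstructor` is illustrative and not part of the crate:

```rust
use std::collections::HashMap;

use mdb_shard::error::MDBShardError;
use mdb_shard::file_structs::MDBFileInfo;
use mdb_shard::shard_file_reconstructor::FileReconstructor;
use merklehash::MerkleHash;

// Hypothetical in-memory lookup keyed by file hash.
struct MapReconstructor {
    files: HashMap<MerkleHash, MDBFileInfo>,
}

#[async_trait::async_trait]
impl FileReconstructor<MDBShardError> for MapReconstructor {
    async fn get_file_reconstruction_info(
        &self,
        file_hash: &MerkleHash,
    ) -> Result<Option<(MDBFileInfo, Option<MerkleHash>)>, MDBShardError> {
        // There is no backing shard file, so the shard-ID half of the pair is always None.
        Ok(self.files.get(file_hash).cloned().map(|info| (info, None)))
    }
}
```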
-------------------------------------------------------------------------------- /mdb_shard/src/utils.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::OsStr; 2 | use std::ops::Deref; 3 | use std::path::Path; 4 | use std::time::Duration; 5 | 6 | use lazy_static::lazy_static; 7 | use merklehash::MerkleHash; 8 | use regex::Regex; 9 | use uuid::Uuid; 10 | 11 | lazy_static! { 12 | static ref MERKLE_DB_FILE_PATTERN: Regex = Regex::new(r"^(?P<hash>[0-9a-fA-F]{64})\.mdb$").unwrap(); 13 | } 14 | 15 | /// Parses a shard filename. If the filename matches the shard filename pattern, 16 | /// then Some(hash) is returned, where hash is the CAS hash of the merkledb file. 17 | /// If the filename does not match, None is returned. 18 | #[inline] 19 | pub fn parse_shard_filename<P: AsRef<Path>>(path: P) -> Option<MerkleHash> { 20 | let path: &Path = path.as_ref(); 21 | let filename = path.file_name()?; 22 | 23 | let filename = filename.to_str().unwrap_or_default(); 24 | 25 | MERKLE_DB_FILE_PATTERN 26 | .captures(filename) 27 | .map(|capture| MerkleHash::from_hex(capture.name("hash").unwrap().as_str()).unwrap()) 28 | } 29 | 30 | #[inline] 31 | pub fn truncate_hash(hash: &MerkleHash) -> u64 { 32 | hash.deref()[0] 33 | } 34 | 35 | pub fn shard_file_name(hash: &MerkleHash) -> String { 36 | format!("{}.mdb", hash.hex()) 37 | } 38 | 39 | pub fn temp_shard_file_name() -> String { 40 | let uuid = Uuid::new_v4(); 41 | format!(".{uuid}.mdb_temp") 42 | } 43 | 44 | pub fn is_temp_shard_file(p: &Path) -> bool { 45 | p.file_name() 46 | .unwrap_or_else(|| OsStr::new("")) 47 | .to_str() 48 | .unwrap_or("") 49 | .ends_with("mdb_temp") 50 | } 51 | 52 | pub fn shard_expiry_time(shard_valid_for: Duration) -> u64 { 53 | use std::ops::Add; 54 | 55 | std::time::SystemTime::now() 56 | .add(shard_valid_for) 57 | .duration_since(std::time::UNIX_EPOCH) 58 | .unwrap_or_default() 59 | .as_secs() 60 | } 61 | 62 | #[cfg(test)] 63 | mod tests { 64 | use super::*; 65 | use crate::shard_format::test_routines::rng_hash; 66 | 67 | #[test] 68 | fn test_regex() { 69 | let mh = rng_hash(0); 70 | 71 | assert!(parse_shard_filename(format!("/Users/me/temp/{}.mdb", mh.hex())).is_some()); 72 | assert!(parse_shard_filename(format!("{}.mdb", mh.hex())).is_some()); 73 | assert!(parse_shard_filename(format!("other_{}.mdb", mh.hex())).is_none()); 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /merklehash/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "merklehash" 3 | version = "0.14.5" 4 | edition = "2024" 5 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 6 | 7 | [dependencies] 8 | base64 = { workspace = true } 9 | blake3 = { workspace = true } 10 | rand = { workspace = true, features = ["small_rng"] } 11 | safe-transmute = { workspace = true } 12 | serde = { workspace = true } 13 | 14 | [target.'cfg(not(target_family = "wasm"))'.dependencies] 15 | heed = { workspace = true } 16 | 17 | [target.'cfg(target_family = "wasm")'.dependencies] 18 | getrandom = { workspace = true, features = ["wasm_js"] } 19 | 20 | [features] 21 | strict = [] 22 | -------------------------------------------------------------------------------- /merklehash/README.md: -------------------------------------------------------------------------------- 1 | # merklehash 2 | 3 | The `merklehash` crate exports a `MerkleHash` type that represents a 32-byte hash used throughout all of the xet-core 4 | components. 5 | 6 | `merklehash` also exports some hashing functions, e.g. `file_hash` and `xorb_hash`, to compute `MerkleHash`es. 7 | 8 | The `MerkleHash` is internally represented as 4 `u64` (`[u64; 4]`). 9 | 10 | The `HexMerkleHash` is also exported and is intended to be used to provide a `serde::Serialize` implementation for a 11 | `MerkleHash` using the string hexadecimal representation of the hash. 12 |
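A quick sketch of the hex round-trip implied by the README above; the input bytes are arbitrary:

```rust
use merklehash::{compute_data_hash, MerkleHash};

fn hex_round_trip() {
    // `hex()` yields the 64-character form that `from_hex()` parses back.
    let h: MerkleHash = compute_data_hash(b"hello world");
    let hex = h.hex();
    assert_eq!(MerkleHash::from_hex(&hex).unwrap(), h);
}
```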
-------------------------------------------------------------------------------- /merklehash/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! The merklehash module provides common and convenient operations 2 | //! around the [DataHash] (aliased to [MerkleHash]). 3 | //! 4 | //! The [MerkleHash] is internally a 256-bit value stored as 4 u64 and is 5 | //! a node in a MerkleTree. A MerkleTree is a hierarchical datastructure 6 | //! where the leaves are hashes of data (for instance, blocks in a file). 7 | //! Then the hash of each non-leaf node is derived from the hashes of its child 8 | //! nodes. 9 | //! 10 | //! A default constructor is provided to make the hash of 0s. 11 | //! ```ignore 12 | //! // creates a default hash value of all 0s 13 | //! let hash = MerkleHash::default(); 14 | //! ``` 15 | //! 16 | //! Two hash functions are provided to compute a MerkleHash from a slice of 17 | //! bytes. The first is [compute_data_hash] which should be used when computing 18 | //! a hash from any user-provided sequence of bytes (i.e. the leaf nodes) 19 | //! ```ignore 20 | //! // compute from a byte slice of &[u8] 21 | //! let string = "hello world"; 22 | //! let hash = compute_data_hash(string.as_bytes()); 23 | //! ``` 24 | //! 25 | //! The second is [compute_internal_node_hash], which should be used when computing 26 | //! the hash of interior nodes. Note that this method also just accepts a slice 27 | //! of `&[u8]` and it is up to the caller to format the string appropriately. 28 | //! For instance: the string could be simply the child hashes printed out 29 | //! consecutively. 30 | //! 31 | //! The reason why this method does not simply take an array of Hashes, and 32 | //! instead requires the caller to format the input as a string, is to allow the 33 | //! user to add additional information to the string being hashed (beyond just 34 | //! the hashes themselves). I.e. the string being hashed could be a concatenation 35 | //! of "hashes of children + children metadata". 36 | //! ```ignore 37 | //! let hash = compute_internal_node_hash(slice.as_bytes()); 38 | //! ``` 39 | //! 40 | //! The two hash functions [compute_data_hash] and [compute_internal_node_hash] 41 | //! are keyed differently such that the same inputs will produce different outputs. 42 | //! And in particular, it should be difficult to find a collision where 43 | //! a `compute_data_hash(a) == compute_internal_node_hash(b)` 44 | 45 | #![cfg_attr(feature = "strict", deny(warnings))] 46 | 47 | pub mod data_hash; 48 | pub use data_hash::*; 49 | pub type MerkleHash = DataHash; 50 | 51 | mod aggregated_hashes; 52 | 53 | pub use aggregated_hashes::{file_hash, file_hash_with_salt, xorb_hash}; 54 | -------------------------------------------------------------------------------- /openapi/.gitignore: -------------------------------------------------------------------------------- 1 | openapitools.json 2 | generated/ 3 |
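Following on from the module docs above, a short sketch of the domain separation between the two hash functions:

```rust
use merklehash::{compute_data_hash, compute_internal_node_hash};

fn keyed_separation() {
    // The two functions are keyed differently, so identical input bytes
    // produce distinct hashes: leaf and interior hashes stay disjoint.
    let bytes = b"hello world";
    assert_ne!(compute_data_hash(bytes), compute_internal_node_hash(bytes));
}
```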
-------------------------------------------------------------------------------- /openapi/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all check-cli rust typescript python java golang clean outdir 2 | 3 | CLI ?= openapi-generator-cli 4 | SPEC := $(CURDIR)/cas.openapi.yaml 5 | OUT_ROOT := $(CURDIR)/generated 6 | 7 | .DEFAULT_GOAL := all 8 | 9 | all: rust typescript python java golang 10 | 11 | define ensure_cli 12 | @bash -c 'set -euo pipefail; \ 13 | if command -v $(CLI) >/dev/null 2>&1; then \ 14 | echo "$(CLI) found: '$$(command -v $(CLI))'"; \ 15 | exit 0; \ 16 | fi; \ 17 | echo "$(CLI) not found; installing via npm..." >&2; \ 18 | if ! command -v npm >/dev/null 2>&1; then \ 19 | echo "npm is not installed. Please install Node.js/npm and re-run." >&2; \ 20 | exit 1; \ 21 | fi; \ 22 | npm install @openapitools/openapi-generator-cli -g; \ 23 | export PATH="$$PATH:$$(npm bin -g)"; \ 24 | if ! command -v $(CLI) >/dev/null 2>&1; then \ 25 | echo "$(CLI) still not found after npm installation. Ensure npm global bin is in PATH." >&2; \ 26 | exit 1; \ 27 | fi; \ 28 | echo "$(CLI) installed: '$$(command -v $(CLI))'";' 29 | endef 30 | 31 | check-cli: 32 | $(ensure_cli) 33 | 34 | outdir: 35 | mkdir -p "$(OUT_ROOT)" 36 | 37 | define gen 38 | @echo "Generating $(2) client -> $(OUT_ROOT)/$(2)" 39 | @rm -rf "$(OUT_ROOT)/$(2)" 40 | $(CLI) generate -i "$(SPEC)" -g "$(1)" -o "$(OUT_ROOT)/$(2)" $(3) 41 | endef 42 | 43 | rust: check-cli outdir 44 | $(call gen,rust,rust,--additional-properties=packageName=xet_cas_client,packageVersion=0.1.0,library=reqwest,preferUnsignedInt=true) 45 | 46 | typescript: check-cli outdir 47 | $(call gen,typescript-fetch,typescript,--additional-properties=npmName=@xet/cas-client,npmVersion=0.1.0,typescriptThreePlus=true) 48 | 49 | python: check-cli outdir 50 | $(call gen,python,python,--additional-properties=packageName=xet_cas_client,projectName=xet_cas_client,packageVersion=0.1.0) 51 | 52 | java: check-cli outdir 53 | $(call gen,java,java,--additional-properties=artifactId=xet-cas-client,groupId=ai.huggingface.xet,artifactVersion=0.1.0,library=webclient,datetimeLibrary=java8) 54 | 55 | golang: check-cli outdir 56 | $(call gen,go,golang,--additional-properties=packageName=casclient,enumClassPrefix=true,isGoSubmodule=false,withGoCodegenComment=true) 57 | 58 | clean: 59 | rm -rf "$(OUT_ROOT)" 60 | 61 | -------------------------------------------------------------------------------- /progress_tracking/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "progress_tracking" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | merklehash = { path = "../merklehash" } 8 | utils = { path = "../utils" } 9 | 10 | async-trait = { workspace = true } 11 | more-asserts = { workspace = true } 12 | tokio = { workspace = true, features = ["test-util", "time"] } 13 | 14 | [dev-dependencies] 15 | tokio = { workspace = true, features = ["test-util"] } 16 | -------------------------------------------------------------------------------- /progress_tracking/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod aggregator; 2 | pub mod item_tracking; 3 | mod no_op_tracker; 4 | mod progress_info; 5 | pub mod upload_tracking; 6 | pub mod verification_wrapper; 7 | 8 | use async_trait::async_trait; 9 | pub use no_op_tracker::NoOpProgressUpdater; 10 | pub use progress_info::{ItemProgressUpdate, ProgressUpdate}; 11 | 12 | /// The trait for a progress updater that reports per-item progress completion. 13 | #[async_trait] 14 | pub trait TrackingProgressUpdater: Send + Sync { 15 | /// Register a set of updates as a `ProgressUpdate`, which 16 | /// contains the per-item names and progress information.
17 | async fn register_updates(&self, updates: ProgressUpdate); 18 | 19 | /// Flush any updates out, if needed 20 | async fn flush(&self) {} 21 | } 22 | -------------------------------------------------------------------------------- /progress_tracking/src/no_op_tracker.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use crate::{ProgressUpdate, TrackingProgressUpdater}; 4 | 5 | #[derive(Debug, Default)] 6 | pub struct NoOpProgressUpdater; 7 | 8 | impl NoOpProgressUpdater { 9 | pub fn new() -> Arc { 10 | Arc::new(Self {}) 11 | } 12 | } 13 | 14 | #[async_trait::async_trait] 15 | impl TrackingProgressUpdater for NoOpProgressUpdater { 16 | async fn register_updates(&self, _updates: ProgressUpdate) {} 17 | } 18 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.7,<2.0"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "hf-xet" 7 | requires-python = ">=3.8" 8 | description = "Fast transfer of large files with the Hugging Face Hub." 9 | license = { text = "Apache-2.0", file = "LICENSE" } 10 | classifiers = [ 11 | "Programming Language :: Rust", 12 | "Programming Language :: Python :: Implementation :: CPython", 13 | "Programming Language :: Python :: Implementation :: PyPy", 14 | ] 15 | dynamic = ["version"] 16 | 17 | [project.optional-dependencies] 18 | tests = [ 19 | "pytest", 20 | ] 21 | 22 | [project.urls] 23 | Homepage = "https://github.com/huggingface/xet-core" 24 | Documentation = "https://huggingface.co/docs/hub/en/storage-backends#using-xet-storage" 25 | Issues = "https://github.com/huggingface/xet-core/issues" 26 | Repository = "https://github.com/huggingface/xet-core.git" 27 | 28 | [tool.maturin] 29 | python-source = "hf_xet/python" 30 | features = ["pyo3/extension-module"] 31 | manifest-path = "hf_xet/Cargo.toml" 32 | -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | format_code_in_doc_comments = true 2 | wrap_comments = true 3 | reorder_imports = true 4 | unstable_features = true 5 | group_imports = "StdExternalCrate" 6 | imports_granularity = "Module" 7 | imports_layout = "Mixed" 8 | match_block_trailing_comma = true 9 | comment_width = 120 10 | max_width = 120 11 | fn_call_width = 100 12 | chain_width = 80 13 | use_field_init_shorthand = true 14 | -------------------------------------------------------------------------------- /utils/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "utils" 3 | version = "0.14.5" 4 | edition = "2024" 5 | 6 | [lib] 7 | name = "utils" 8 | path = "src/lib.rs" 9 | 10 | [dependencies] 11 | error_printer = { path = "../error_printer" } 12 | merklehash = { path = "../merklehash" } 13 | 14 | async-trait = { workspace = true } 15 | bytes = { workspace = true } 16 | ctor = { workspace = true } 17 | derivative = { workspace = true } 18 | duration-str = { workspace = true } 19 | futures = { workspace = true } 20 | lazy_static = { workspace = true } 21 | pin-project = { workspace = true } 22 | shellexpand = { workspace = true } 23 | thiserror = { workspace = true } 24 | tokio = { workspace = true, features = ["time", "rt", "macros", "sync"] } 25 | tracing = { workspace = true } 26 | 27 | 
[target.'cfg(not(target_family = "wasm"))'.dev-dependencies] 28 | tempfile = { workspace = true } 29 | xet_runtime = { path = "../xet_runtime" } 30 | 31 | [target.'cfg(target_family = "wasm")'.dependencies] 32 | web-time = { workspace = true } 33 | 34 | [dev-dependencies] 35 | serial_test = { workspace = true } 36 | futures-util = { workspace = true } 37 | 38 | [features] 39 | strict = [] 40 | -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # Proto 2 | Directory where gproto files will be created 3 | 4 | # Operational helpers 5 | - Logs, metrics and traces 6 | - Configuration 7 | - Access to AWS services (e.g. S3) 8 | 9 | # Examples 10 | ## Identify which cas_server owns a particular key 11 | ``` 12 | cargo run --example infra -- --server-name cas-lb.xetbeta.com:5000 --key bar 13 | Host: 35.89.208.89 14 | Load Stats: SystemStatus { timestamp: "2022-07-06T19:15:00Z", cpu_utilization: 0.3416666833712037 } 15 | Host: 54.245.178.249 16 | Load Stats: SystemStatus { timestamp: "2022-07-06T19:15:00Z", cpu_utilization: 0.2943333333333333 } 17 | Key bar gets hashed to server "54.245.178.249" 18 | ``` 19 | -------------------------------------------------------------------------------- /utils/src/async_iterator.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | 3 | #[async_trait] 4 | pub trait AsyncIterator: Send + Sync { 5 | type Item: Send + Sync; 6 | 7 | /// The traditional next method for iterators, with a Result and Error 8 | /// type. Returns None when everything is done. 9 | async fn next(&mut self) -> Result, E>; 10 | } 11 | 12 | #[async_trait] 13 | pub trait BatchedAsyncIterator: AsyncIterator { 14 | /// Return a block of items. If the stream is done, then an empty vector is returned; 15 | /// otherwise, at least one item is returned. 16 | /// 17 | /// If given, max_num dictates the maximum number of items to return. If None, then all 18 | /// available items are returned. 19 | async fn next_batch(&mut self, max_num: Option) -> Result, E>; 20 | 21 | /// Returns the number of items remaining in the stream 22 | /// if known, and None otherwise. Returns Some(0) if 23 | /// there are no items remaining. 
24 | fn items_remaining(&self) -> Option; 25 | } 26 | 27 | #[async_trait] 28 | impl AsyncIterator for Vec { 29 | type Item = Vec; 30 | 31 | async fn next(&mut self) -> Result, E> { 32 | if self.is_empty() { 33 | Ok(None) 34 | } else { 35 | Ok(Some(std::mem::take(self))) 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /utils/src/errors.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Debug, Error)] 4 | #[non_exhaustive] 5 | pub enum KeyError { 6 | #[error("Key parsing failure: {0}")] 7 | UnparsableKey(String), 8 | } 9 | 10 | #[derive(Debug, Error)] 11 | #[non_exhaustive] 12 | pub enum SingleflightError 13 | where 14 | E: Send + std::fmt::Debug + Sync, 15 | { 16 | #[error("BUG: singleflight waiter was notified before result was updated")] 17 | NoResult, 18 | 19 | #[error("BUG: call was removed before singleflight owner could update it")] 20 | CallMissing, 21 | 22 | #[error("BUG: call didn't create a Notifier for the initial task")] 23 | NoNotifierCreated, 24 | 25 | #[error(transparent)] 26 | InternalError(#[from] E), 27 | 28 | #[error("Real call failed: {0}")] 29 | WaiterInternalError(String), 30 | 31 | #[error("JoinError inside singleflight owner task: {0}")] 32 | JoinError(String), 33 | 34 | #[error("Owner task panicked")] 35 | OwnerPanicked, 36 | 37 | #[error("Poisoned Group lock")] 38 | GroupLockPoisoned, 39 | } 40 | 41 | #[derive(Debug, Error)] 42 | #[non_exhaustive] 43 | pub enum AuthError { 44 | #[error("Refresh function: {0} is not callable")] 45 | RefreshFunctionNotCallable(String), 46 | 47 | #[error("Token refresh failed: {0}")] 48 | TokenRefreshFailure(String), 49 | } 50 | 51 | impl AuthError { 52 | pub fn token_refresh_failure(err: impl ToString) -> Self { 53 | Self::TokenRefreshFailure(err.to_string()) 54 | } 55 | } 56 | 57 | impl Clone for SingleflightError { 58 | fn clone(&self) -> Self { 59 | match self { 60 | SingleflightError::NoResult => SingleflightError::NoResult, 61 | SingleflightError::CallMissing => SingleflightError::CallMissing, 62 | SingleflightError::NoNotifierCreated => SingleflightError::NoNotifierCreated, 63 | SingleflightError::InternalError(e) => SingleflightError::WaiterInternalError(format!("{e:?}")), 64 | SingleflightError::WaiterInternalError(s) => SingleflightError::WaiterInternalError(s.clone()), 65 | SingleflightError::JoinError(e) => SingleflightError::JoinError(e.clone()), 66 | SingleflightError::OwnerPanicked => SingleflightError::OwnerPanicked, 67 | SingleflightError::GroupLockPoisoned => SingleflightError::GroupLockPoisoned, 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /utils/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(feature = "strict", deny(warnings))] 2 | 3 | pub mod async_iterator; 4 | pub mod async_read; 5 | pub mod auth; 6 | pub mod constant_declarations; 7 | pub mod errors; 8 | #[cfg(not(target_family = "wasm"))] 9 | pub mod limited_joinset; 10 | mod output_bytes; 11 | pub mod serialization_utils; 12 | #[cfg(not(target_family = "wasm"))] 13 | pub mod singleflight; 14 | 15 | pub use output_bytes::output_bytes; 16 | 17 | pub mod rw_task_lock; 18 | pub use rw_task_lock::{RwTaskLock, RwTaskLockError, RwTaskLockReadGuard}; 19 | 20 | #[cfg(not(target_family = "wasm"))] 21 | mod file_paths; 22 | 23 | #[cfg(not(target_family = "wasm"))] 24 | pub use file_paths::{CwdGuard, EnvVarGuard, 
normalized_path_from_user_string}; 25 | 26 | pub mod byte_size; 27 | pub use byte_size::ByteSize; 28 | -------------------------------------------------------------------------------- /utils/src/limited_joinset.rs: -------------------------------------------------------------------------------- 1 | use std::future::Future; 2 | use std::sync::Arc; 3 | use std::task::{Context, Poll}; 4 | 5 | use tokio::sync::Semaphore; 6 | use tokio::task::{AbortHandle, JoinError, JoinSet as TokioJoinSet}; 7 | 8 | pub struct LimitedJoinSet<T> { 9 | inner: TokioJoinSet<T>, 10 | semaphore: Arc<Semaphore>, 11 | } 12 | 13 | impl<T: Send + 'static> LimitedJoinSet<T> { 14 | pub fn new(max_concurrent: usize) -> Self { 15 | Self { 16 | inner: TokioJoinSet::new(), 17 | semaphore: Arc::new(Semaphore::new(max_concurrent)), 18 | } 19 | } 20 | 21 | pub fn spawn<F>(&mut self, task: F) -> AbortHandle 22 | where 23 | F: Future<Output = T>, 24 | F: Send + 'static, 25 | T: Send, 26 | { 27 | let semaphore = self.semaphore.clone(); 28 | self.inner.spawn(async move { 29 | let _permit = semaphore.acquire().await; 30 | task.await 31 | }) 32 | } 33 | 34 | pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> { 35 | self.inner.try_join_next() 36 | } 37 | 38 | pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> { 39 | self.inner.join_next().await 40 | } 41 | 42 | pub async fn join_all(self) -> Vec<T> { 43 | self.inner.join_all().await 44 | } 45 | 46 | pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, JoinError>>> { 47 | self.inner.poll_join_next(cx) 48 | } 49 | 50 | pub fn len(&self) -> usize { 51 | self.inner.len() 52 | } 53 | 54 | pub fn is_empty(&self) -> bool { 55 | self.inner.is_empty() 56 | } 57 | } 58 | 59 | #[cfg(test)] 60 | mod tests { 61 | use std::time::Duration; 62 | 63 | use super::*; 64 | 65 | #[tokio::test] 66 | async fn test_joinset() { 67 | let mut join_set = LimitedJoinSet::new(3); 68 | 69 | for i in 0..4 { 70 | join_set.spawn(async move { 71 | tokio::time::sleep(Duration::from_millis(10 - i)).await; 72 | i 73 | }); 74 | } 75 | 76 | let mut outs = Vec::new(); 77 | while let Some(Ok(value)) = join_set.join_next().await { 78 | outs.push(value); 79 | } 80 | 81 | // expect that the task returning 3 was spawned after at least 1 other task finished 82 | assert_eq!(outs.len(), 4); 83 | for (i, out) in outs.into_iter().enumerate() { 84 | if out == 3 { 85 | assert!(i > 0); 86 | } 87 | } 88 | } 89 | } 90 |
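A usage sketch for `LimitedJoinSet`, complementing the test above; `fetch` stands in for real async work:

```rust
use utils::limited_joinset::LimitedJoinSet;

async fn fetch(i: usize) -> usize {
    // Placeholder for an actual I/O-bound task.
    i * 2
}

async fn run_bounded() -> Vec<usize> {
    // At most 8 of the 100 spawned tasks hold a permit at any one time.
    let mut tasks = LimitedJoinSet::new(8);
    for i in 0..100 {
        tasks.spawn(fetch(i));
    }
    tasks.join_all().await
}
```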
-------------------------------------------------------------------------------- /utils/src/output_bytes.rs: -------------------------------------------------------------------------------- 1 | /// Convert a u64 byte count into an output string, choosing the nearest binary prefix. 2 | /// 3 | /// # Arguments 4 | /// * `v` - the size in bytes 5 | pub fn output_bytes(v: u64) -> String { 6 | let map = vec![ 7 | (1_099_511_627_776, "TiB"), 8 | (1_073_741_824, "GiB"), 9 | (1_048_576, "MiB"), 10 | (1024, "KiB"), 11 | ]; 12 | 13 | if v == 0 { 14 | return "0 bytes".to_string(); 15 | } 16 | 17 | for (div, s) in map { 18 | let curr = v as f64 / div as f64; 19 | if v / div > 0 { 20 | return if v % div == 0 { 21 | format!("{} {}", v / div, s) 22 | } else { 23 | format!("{curr:.2} {s}") 24 | }; 25 | } 26 | } 27 | 28 | format!("{v} bytes") 29 | } 30 | 31 | #[cfg(test)] 32 | mod tests { 33 | use super::*; 34 | 35 | #[test] 36 | fn test_size_conversion() { 37 | assert_eq!("500 bytes", output_bytes(500)); 38 | assert_eq!("999 bytes", output_bytes(999)); 39 | assert_eq!("1 KiB", output_bytes(1024)); 40 | assert_eq!("1.00 KiB", output_bytes(1025)); 41 | assert_eq!("999.99 KiB", output_bytes(1_023_989)); 42 | assert_eq!("1 MiB", output_bytes(1_048_576)); 43 | assert_eq!("1.00 MiB", output_bytes(1048577)); 44 | assert_eq!("999.99 MiB", output_bytes(1_048_565_514)); 45 | assert_eq!("1 GiB", output_bytes(1_073_741_824)); 46 | assert_eq!("1.00 GiB", output_bytes(1_073_741_825)); 47 | assert_eq!("999.99 GiB", output_bytes(1_073_731_086_581)); 48 | assert_eq!("1 TiB", output_bytes(1_099_511_627_776)); 49 | assert_eq!("1.00 TiB", output_bytes(1_099_511_627_777)); 50 | assert_eq!("1234.57 TiB", output_bytes(1_357_424_070_303_416)); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /xet_runtime/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "xet_runtime" 3 | version = "0.1.0" 4 | edition = "2024" 5 | 6 | [dependencies] 7 | utils = { path = "../utils" } 8 | error_printer = { path = "../error_printer" } 9 | 10 | oneshot = { workspace = true } 11 | reqwest = { workspace = true } 12 | thiserror = { workspace = true } 13 | tokio = { workspace = true, features = ["time", "rt", "macros", "io-util"] } 14 | tracing = { workspace = true } 15 | 16 | [target.'cfg(not(target_family = "wasm"))'.dependencies] 17 | tokio = { workspace = true, features = ["rt-multi-thread"] } 18 | 19 | [target.'cfg(target_os = "macos")'.dependencies] 20 | libc = { workspace = true} 21 | -------------------------------------------------------------------------------- /xet_runtime/src/errors.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | /// Defines the error type for spawning external threads. 4 | #[derive(Debug, Error)] 5 | #[non_exhaustive] 6 | pub enum MultithreadedRuntimeError { 7 | #[error("Error Initializing Multithreaded Runtime: {0:?}")] 8 | RuntimeInitializationError(std::io::Error), 9 | 10 | #[error("Task Panic: {0:?}.")] 11 | TaskPanic(String), 12 | 13 | #[error("Task cancelled; possible runtime shutdown in progress ({0}).")] 14 | TaskCanceled(String), 15 | 16 | #[error("Unknown task runtime error: {0}")] 17 | Other(String), 18 | } 19 | 20 | impl From<tokio::task::JoinError> for MultithreadedRuntimeError { 21 | fn from(err: tokio::task::JoinError) -> Self { 22 | if err.is_panic() { 23 | // The task panic'd. Pass this exception on. 24 | tracing::error!("Panic reported on xet worker task: {err:?}"); 25 | MultithreadedRuntimeError::TaskPanic(format!("{err:?}")) 26 | } else if err.is_cancelled() { 27 | // Likely caused by the runtime shutting down (e.g. with a keyboard CTRL-C).
28 | MultithreadedRuntimeError::TaskCanceled(format!("{err}")) 29 | } else { 30 | MultithreadedRuntimeError::Other(format!("task join error: {err}")) 31 | } 32 | } 33 | } 34 | 35 | // Define our own result type here (this seems to be the standard). 36 | pub type Result = std::result::Result; 37 | -------------------------------------------------------------------------------- /xet_runtime/src/exports.rs: -------------------------------------------------------------------------------- 1 | // Re-exports for dependent libraries like hf_xet to use to consolidate 2 | // Cargo.toml specifications. 3 | pub use tokio; 4 | -------------------------------------------------------------------------------- /xet_runtime/src/file_handle_limits.rs: -------------------------------------------------------------------------------- 1 | #[cfg(target_os = "macos")] 2 | pub fn raise_nofile_soft_to_hard() { 3 | use tracing::info; 4 | 5 | unsafe { 6 | use libc; 7 | 8 | let mut lim = libc::rlimit { 9 | rlim_cur: 0, 10 | rlim_max: 0, 11 | }; 12 | if libc::getrlimit(libc::RLIMIT_NOFILE, &mut lim) != 0 { 13 | info!("Failed to get RLIMIT_NOFILE: {:?}", std::io::Error::last_os_error()); 14 | return; 15 | } 16 | 17 | if lim.rlim_cur < lim.rlim_max { 18 | let new_lim = libc::rlimit { 19 | rlim_cur: lim.rlim_max, 20 | rlim_max: lim.rlim_max, 21 | }; 22 | if libc::setrlimit(libc::RLIMIT_NOFILE, &new_lim) != 0 { 23 | info!( 24 | "Failed to set RLIMIT_NOFILE soft limit from {} to {}: {:?}", 25 | lim.rlim_cur, 26 | lim.rlim_max, 27 | std::io::Error::last_os_error() 28 | ); 29 | return; 30 | } 31 | info!("Increased RLIMIT_NOFILE soft limit from {} to {}", lim.rlim_cur, lim.rlim_max); 32 | } else { 33 | info!("RLIMIT_NOFILE soft limit already at hard limit: {}", lim.rlim_cur); 34 | } 35 | } 36 | } 37 | 38 | #[cfg(not(target_os = "macos"))] 39 | pub fn raise_nofile_soft_to_hard() {} 40 | -------------------------------------------------------------------------------- /xet_runtime/src/global_semaphores.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::sync::{Arc, Mutex}; 3 | 4 | use tokio::sync::Semaphore; 5 | 6 | /// Identifies a process-wide semaphore and its initial size. 7 | /// 8 | /// Create one using global_semaphore_handle!() and pass it to 9 | /// `ThreadPool::current().global_semaphore(handle)` to obtain a global semaphore 10 | /// `Arc` 11 | /// 12 | /// The `initial_value` is applied only when the semaphore is first created for 13 | /// this handle; later lookups return the existing semaphore. 14 | /// 15 | /// # Typical usage 16 | /// 17 | /// ```ignore 18 | /// use lazy_static::lazy_static; 19 | /// 20 | /// lazy_static! { 21 | /// static ref UPLOAD_LIMITER: GlobalSemaphoreHandle = global_semaphore_handle!(32); 22 | /// } 23 | /// 24 | /// // Acquire a permit 25 | /// let permit = ThreadPool::current() 26 | /// .global_semaphore(*UPLOAD_LIMITER) 27 | /// .acquire_owned() 28 | /// .await?; 29 | /// ``` 30 | #[derive(Copy, Clone)] 31 | pub struct GlobalSemaphoreHandle { 32 | pub handle: &'static str, 33 | pub initial_value: usize, 34 | } 35 | 36 | impl AsRef for GlobalSemaphoreHandle { 37 | fn as_ref(&self) -> &GlobalSemaphoreHandle { 38 | self 39 | } 40 | } 41 | 42 | /// Creates a GlobalSemaphoreHandle instance with a compile-time unique handle and 43 | /// initial value. 44 | /// 45 | /// # Example usage 46 | /// 47 | /// ```ignore 48 | /// use lazy_static::lazy_static; 49 | /// 50 | /// lazy_static! 
{ 51 | /// static ref UPLOAD_LIMITER: GlobalSemaphoreHandle = global_semaphore_handle!(32); 52 | /// } 53 | /// 54 | /// // Acquire a permit 55 | /// let permit = ThreadPool::current() 56 | /// .global_semaphore(*UPLOAD_LIMITER) 57 | /// .acquire_owned() 58 | /// .await?; 59 | /// ``` 60 | #[macro_export] 61 | macro_rules! global_semaphore_handle { 62 | // Expression form: returns a GlobalSemaphoreHandle 63 | ($perm:expr) => {{ 64 | // A compile-time unique &'static str using module, file, line, and column. 65 | const __HANDLE: &str = concat!(module_path!(), "::", file!(), ":", line!(), ":", column!()); 66 | 67 | $crate::GlobalSemaphoreHandle { 68 | handle: __HANDLE, 69 | initial_value: ($perm).into(), 70 | } 71 | }}; 72 | } 73 | 74 | #[derive(Debug, Default)] 75 | pub(crate) struct GlobalSemaphoreLookup { 76 | lookup: Mutex>>, 77 | } 78 | 79 | impl GlobalSemaphoreLookup { 80 | pub(crate) fn get(&self, handle: impl Into) -> Arc { 81 | let handle = handle.into(); 82 | 83 | let mut sl = self.lookup.lock().expect("Recursive lock; bug"); 84 | 85 | sl.entry(handle.handle) 86 | .or_insert_with(|| Arc::new(Semaphore::new(handle.initial_value))) 87 | .clone() 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /xet_runtime/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod errors; 2 | pub mod exports; 3 | 4 | pub mod runtime; 5 | 6 | pub use runtime::XetRuntime; 7 | pub mod sync_primatives; 8 | pub use sync_primatives::{SyncJoinHandle, spawn_os_thread}; 9 | 10 | #[macro_use] 11 | mod global_semaphores; 12 | pub mod utils; 13 | 14 | pub use global_semaphores::GlobalSemaphoreHandle; 15 | 16 | pub mod file_handle_limits; 17 | --------------------------------------------------------------------------------
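One property of the semaphore registry worth spelling out: `global_semaphore_handle!` keys each semaphore by its call site, so distinct invocations, even with the same initial value, resolve to independent semaphores. A sketch:

```rust
fn distinct_handles() {
    // A different line/column at each invocation makes the handles unique.
    let a = xet_runtime::global_semaphore_handle!(4usize);
    let b = xet_runtime::global_semaphore_handle!(4usize);
    assert_ne!(a.handle, b.handle);
}
```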