├── .github
│   └── workflows
│       ├── pypi-deploy.yml
│       └── python-test.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── examples
│   └── example_resnet18.py
├── requirements.txt
├── setup.py
├── tests
│   ├── test_checkpoint_readonly.py
│   └── test_synthetic.py
└── zipslicer
    ├── __init__.py
    ├── custom_load.py
    └── weights_only_unpickler.py
/.github/workflows/pypi-deploy.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # Then it will deploy the package to PyPI
3 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
4 |
5 | name: Deploy Python package to PyPI
6 |
7 | on:
8 | release:
9 | types:
10 | - published
11 |
12 | jobs:
13 | tests:
14 | uses: "./.github/workflows/python-test.yml"
15 | publish:
16 | name: publish
17 | needs: [tests] # require tests to pass before deploy runs
18 | runs-on: ubuntu-latest
19 | steps:
20 | - uses: actions/checkout@main
21 | - name: Publish Python Package
22 | uses: mariamrf/py-package-publish-action@v1.1.0
23 | with:
24 | python_version: '3.10.0'
25 | env:
26 | TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }}
27 | TWINE_USERNAME: __token__
28 |
--------------------------------------------------------------------------------
/.github/workflows/python-test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Lint and Test Python package
5 |
6 | on:
7 | workflow_dispatch:
8 | workflow_call: # allow this workflow to be called from other workflows
9 | push:
10 | branches: [ "main" ]
11 | paths:
12 | - 'zipslicer/**'
13 | - 'tests/**'
14 | - 'examples/**'
15 | pull_request:
16 | branches: [ "main" ]
17 |
18 | jobs:
19 | tests:
20 | runs-on: ubuntu-latest
21 | strategy:
22 | fail-fast: false
23 | matrix:
24 | python-version: ["3.10", "3.11"]
25 | torch-version: ["1.11.0", "1.12.0", "stable"]
26 | exclude: # there is no easily available older torch build for newer python
27 | - python-version: "3.11"
28 | torch-version: "1.11.0"
29 | - python-version: "3.11"
30 | torch-version: "1.12.0"
31 |
32 | steps:
33 | - uses: actions/checkout@v3
34 | - name: Set up Python ${{ matrix.python-version }}
35 | uses: actions/setup-python@v4
36 | with:
37 | python-version: ${{ matrix.python-version }}
38 | - name: Install testing and linting tools
39 | run: |
40 | python -m pip install --upgrade pip
41 | python -m pip install flake8 pytest
42 | - name: Lint with flake8
43 | run: |
44 | # stop the build if there are Python syntax errors or undefined names
45 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
46 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
47 | flake8 --ignore 'E501 W503 E203 E402 W293' . --count --exit-zero --max-complexity=40 --max-line-length=127 --statistics
48 | - name: Install dependencies, testing with cpu-only pytorch to save CI time, torch==${{matrix.torch-version}}
49 | run: |
50 | if [[ "stable" == ${{matrix.torch-version}} ]]; then
51 | python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
52 | else
53 | python -m pip install 'torch==${{matrix.torch-version}}+cpu' --extra-index-url https://download.pytorch.org/whl/cpu
54 | fi
55 | # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
56 | - name: Test with pytest
57 | run: |
58 | PYTHONPATH="." pytest -o log_cli=true --capture=tee-sys -p no:asyncio .
59 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | *~
3 | *.pth
4 | __pycache__
5 | .pytest_cache
6 | *.py[cod]
7 | *$py.class
8 | *.backup
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | All contributions by Kirill Gadjello:
2 | Copyright (c) 2023- Kirill Gadjello.
3 |
4 | From PyTorch:
5 |
6 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
7 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
8 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
9 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
10 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
11 | Copyright (c) 2011-2013 NYU (Clement Farabet)
12 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
13 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
14 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
15 |
16 | From Caffe2:
17 |
18 | Copyright (c) 2016-present, Facebook Inc. All rights reserved.
19 |
20 | All contributions by Facebook:
21 | Copyright (c) 2016 Facebook Inc.
22 |
23 | All contributions by Google:
24 | Copyright (c) 2015 Google Inc.
25 | All rights reserved.
26 |
27 | All contributions by Yangqing Jia:
28 | Copyright (c) 2015 Yangqing Jia
29 | All rights reserved.
30 |
31 | All contributions by Kakao Brain:
32 | Copyright 2019-2020 Kakao Brain
33 |
34 | All contributions by Cruise LLC:
35 | Copyright (c) 2022 Cruise LLC.
36 | All rights reserved.
37 |
38 | All contributions from Caffe:
39 | Copyright(c) 2013, 2014, 2015, the respective contributors
40 | All rights reserved.
41 |
42 | All other contributions:
43 | Copyright(c) 2015, 2016 the respective contributors
44 | All rights reserved.
45 |
46 | Caffe2 uses a copyright model similar to Caffe: each contributor holds
47 | copyright over their contributions to Caffe2. The project versioning records
48 | all such contribution and copyright details. If a contributor wants to further
49 | mark their specific copyright on a particular contribution, they should
50 | indicate their copyright solely in the commit message of the change when it is
51 | committed.
52 |
53 | All rights reserved.
54 |
55 | Redistribution and use in source and binary forms, with or without
56 | modification, are permitted provided that the following conditions are met:
57 |
58 | 1. Redistributions of source code must retain the above copyright
59 | notice, this list of conditions and the following disclaimer.
60 |
61 | 2. Redistributions in binary form must reproduce the above copyright
62 | notice, this list of conditions and the following disclaimer in the
63 | documentation and/or other materials provided with the distribution.
64 |
65 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
66 | and IDIAP Research Institute nor the names of its contributors may be
67 | used to endorse or promote products derived from this software without
68 | specific prior written permission.
69 |
70 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
71 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
72 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
73 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
74 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
75 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
76 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
77 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
78 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
79 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
80 | POSSIBILITY OF SUCH DAMAGE.
81 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include *.txt
2 | recursive-include examples *.py
3 | recursive-include tests *.py
4 | recursive-include zipslicer *.py
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # *ZIPSLICER* 📁✂️
2 | [](https://github.com/kir-gadjello/zipslicer/actions/workflows/python-test.yml)
3 | [](https://github.com/kir-gadjello/zipslicer/actions/workflows/pypi-deploy.yml)
4 |
5 | A library for incremental loading of large PyTorch checkpoints.
6 | [Read a blogpost introduction by yours truly](https://kir-gadjello.github.io/zipslicer)
7 |
8 | ## Synopsis
9 | ```python
10 | import torch
11 | import zipslicer
12 |
13 | # Could be a private custom recurrent sentient transformer
14 | # instead of a garden variety resnet
15 | my_complicated_network = torch.hub.load(
16 | "pytorch/vision:v0.10.0", "resnet18", pretrained=True
17 | )
18 | s_dict = my_complicated_network.state_dict()
19 | torch.save(s_dict, "my_network_checkpoint_v123.pth")
20 | del my_complicated_network
21 |
22 | # Later, on a smaller unrelated machine you load a "LazyStateDict"
23 | # which is just like a regular state dict, but it loads tensors only when it has to
24 | lazy_s_dict = zipslicer.load("my_network_checkpoint_v123.pth")
25 | layer3_tensors = {}
26 | for k in lazy_s_dict.keys():
27 | if k.startswith("layer3"):
28 | layer3_tensors[k] = lazy_s_dict[k]
29 | # Now you have layer3's tensors and you can analyze them without breaking your RAM.
30 | # Or you can instantiate the layers' classes in sequence and compute the whole
31 | # network's output for a given input by threading the activations through them.
32 | # But we will just print the tensors instead:
33 | print(layer3_tensors)
34 | ```
35 |
36 | Run this example and unit-tests:
37 |
38 | `python examples/example_resnet18.py`
39 |
40 | `pytest -o log_cli=true --capture=tee-sys -p no:asyncio`
41 |
42 | Test your checkpoint for compatibility:
43 |
44 | `python tests/test_checkpoint_readonly.py your_magnificent_checkpoint.pth`
45 |
46 | If it's all green, it will work.
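   |
   | For large checkpoints you can test a random subset of keys - the sampling ratio is read from the `ZIPSLICER_TEST_SUBSET` environment variable (see `tests/test_checkpoint_readonly.py`):
   |
   | `ZIPSLICER_TEST_SUBSET=0.1 python tests/test_checkpoint_readonly.py your_magnificent_checkpoint.pth`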
47 |
48 | ## Prerequisites
49 | * Supported Python and Torch versions: `python-3.10 + torch-(1.11, 1.12, stable)`, `python-3.11 + torch-stable`
50 | * Generally, `zipslicer` should work with a modern enough install of PyTorch - use the [included safe test](https://github.com/kir-gadjello/zipslicer/blob/main/tests/test_checkpoint_readonly.py) to check the compatibility of `zipslicer` with your PyTorch and your checkpoint. This is a pure Python library, so the specific CPU architecture shouldn't matter.
51 | * A checkpoint produced by saving your model's `state_dict` via vanilla `torch.save(...)` - default settings should suffice, as Torch doesn't use ZIP compression.
52 | * An application that can take advantage of an incrementally-loaded checkpoint - i.e., if your app just loads all `state_dict.items()` in a loop right away, it doesn't make much sense to use this library. Make sure your code reads `state_dict.keys()` (and `state_dict.get_meta(k)` if necessary) and uses these intelligently to work on a subset of `state_dict[k]` tensors at a time - see the sketch after this list. For general inspiration you might read [this (HF)](https://huggingface.co/docs/transformers/v4.26.0/en/main_classes/model#transformers.modeling_utils.load_sharded_checkpoint) and [this (arxiv)](https://arxiv.org/abs/2104.07857). With some additional engineering it should be possible to run Large Language Models like [BLOOM-176B](https://huggingface.co/bigscience/bloom) or [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl) on a single mid-range GPU at home - if you are willing to wait for a night's worth of time. In the large-batch regime this might even make some practical sense, for example to process a set of documents into embeddings.
53 |
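   | A minimal sketch of the intended access pattern (assuming the checkpoint from the synopsis above; the `layer3.` prefix is just an example):
   |
   | ```python
   | import zipslicer
   |
   | lazy_s_dict = zipslicer.load("my_network_checkpoint_v123.pth")
   |
   | # Inspect shapes/dtypes first - get_meta() reads only metadata, not tensor data
   | # (it returns None for non-tensor "._extra_state" entries)
   | for k in lazy_s_dict.keys():
   |     meta = lazy_s_dict.get_meta(k)
   |     if meta is not None:
   |         print(k, tuple(meta.shape), meta.dtype)
   |
   | # Then materialize only the tensors you need, one group at a time
   | for k in lazy_s_dict.keys():
   |     if k.startswith("layer3."):
   |         t = lazy_s_dict[k]  # tensor data is read from disk here
   |         print(k, t.float().abs().mean())  # use t, then let it go out of scope
   | ```
   |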
54 | ## Install
55 |
56 | Generally, copying the `zipslicer/zipslicer` directory into your project's source tree is enough.
57 |
58 | If you are a fan of official ceremony-driven install processes for executable modules of dubious provenance, soon there will be a possibility of installing this boutique software module via pip: `pip install zipslicer`
59 |
60 | ## Notes
61 | * This library is only for reading PyTorch tensors from checkpoints. We leave writing for future work.
62 | * Writing to a loaded `state_dict` is frowned upon, but it *will* work - though for now you should avoid doing this while iterating over the keys and expecting the key set to reflect the update.
63 | * Perhaps more importantly, **general-purpose pickles are not supported** - the design of this library doesn't allow you to load whole neural network class instances. Usually this isn't necessary, and the [official PyTorch documentation recommends using `state_dict` for model serialization](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). We support `state_dict`s.
64 | * Some rare tensor types (e.g., PyTorch quantized tensors - not to be confused with integer tensors, which work fine) are not yet supported. If this bothers you, share your experience in issues.
65 | * We say "Hi" to the [HF `safetensors` project](https://github.com/huggingface/safetensors), but note that, in comparison to theirs, our approach doesn't require checkpoint conversion, which takes significant time and storage. In fact, the two approaches could be complementary: you have to load tensors from the PyTorch checkpoint somehow to convert it to `safetensors`, and the default loading mechanism is constrained by available RAM.
66 |
67 | ## Prospective features we are considering
68 | If you are interested in some of these features, consider creating an issue:
69 | * Effective loading of tensor slices - to implement tensor parallelism in sharded deployments
70 | * Accessing the source checkpoint over a network
71 | * Writing to a checkpoint in-place
72 | * Incremental conversion to other checkpoint formats
73 |
--------------------------------------------------------------------------------
/examples/example_resnet18.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | try:
4 | import zipslicer
5 | except Exception:
6 | import sys
7 |
8 | sys.path.append("./zipslicer")
9 | import zipslicer
10 |
11 |
12 | # Could be a private custom recurrent sentient transformer
13 | # instead of a garden variety resnet
14 | my_complicated_network = torch.hub.load(
15 | "pytorch/vision:v0.10.0", "resnet18", pretrained=True
16 | )
17 | s_dict = my_complicated_network.state_dict()
18 | torch.save(s_dict, "my_network_checkpoint_v123.pth")
19 | del my_complicated_network
20 |
21 | # Later, on a smaller unrelated machine you load a "LazyStateDict"
22 | # which is just like a regular state dict, but it loads tensors only when it has to
23 | lazy_s_dict = zipslicer.load("my_network_checkpoint_v123.pth")
24 | layer3_tensors = {}
25 | for k in lazy_s_dict.keys():
26 | if k.startswith("layer3"):
27 | layer3_tensors[k] = lazy_s_dict[k]
28 | # Now you have layer3's tensors and you can analyze them without breaking your RAM.
29 | # Or you can load the layers' classes in sequence and compute the whole network's output.
30 | # But we will just print the tensors:
31 | print(layer3_tensors)
32 |
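   | # For illustration: tally the total element count from metadata alone -
   | # get_meta() returns shape/dtype stubs without reading any tensor data
   | # (and returns None for non-tensor "._extra_state" entries).
   | total_elems = 0
   | for k in lazy_s_dict.keys():
   |     meta = lazy_s_dict.get_meta(k)
   |     if meta is not None:
   |         total_elems += meta.shape.numel()
   | print(f"Total elements described by checkpoint metadata: {total_elems}")
   |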
33 | # Cleanup after our experiment
34 | import os
35 | os.unlink("my_network_checkpoint_v123.pth")
36 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch >= 1.11
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | __version__ = "0.8.1"
4 |
5 | python_min_version = (3, 8, 0)
6 | version_range_max = 12
7 |
8 | with open("README.md", "r") as fh:
9 | long_description = fh.read()
10 |
11 | setup(
12 | name="zipslicer",
13 | version=__version__,
14 | description="A library for efficient incremental access to tensors stored in PyTorch checkpoints",
15 | packages=["zipslicer"],
16 | install_requires=["torch >= 1.10.0"],
17 | extras_require={
18 | "dev": [
19 | "pytest >= 3.10",
20 | ]
21 | },
22 | # PyPI package information from pytorch
23 | classifiers=[
24 | "Development Status :: 5 - Production/Stable",
25 | "Intended Audience :: Developers",
26 | "Intended Audience :: Education",
27 | "Intended Audience :: Science/Research",
28 | "License :: OSI Approved :: BSD License",
29 | "Topic :: Scientific/Engineering",
30 | "Topic :: Scientific/Engineering :: Mathematics",
31 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
32 | "Topic :: Software Development",
33 | "Topic :: Software Development :: Libraries",
34 | "Topic :: Software Development :: Libraries :: Python Modules",
35 | "Programming Language :: Python :: 3",
36 | ] + [
37 | "Programming Language :: Python :: 3.{}".format(i)
38 | for i in range(python_min_version[1], version_range_max)
39 | ],
40 | license="BSD-3",
41 | keywords="pytorch, machine learning",
42 | python_requires=">=3",
43 | long_description=long_description,
44 | long_description_content_type="text/markdown",
45 | author="Kirill Gadjello",
46 | author_email="kirill.gadjello@protonmail.com",
47 | url="https://github.com/kir-gadjello/zipslicer",
48 | )
49 |
--------------------------------------------------------------------------------
/tests/test_checkpoint_readonly.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023- Kirill Gadjello.
2 | # See LICENSE for details (basically it uses part of PyTorch source code and is licensed under the same conditions)
3 |
4 | # Run it like this (from zipslicer repo root directory):
5 | # python ./tests/test_checkpoint_readonly.py 'path_to_your_checkpoint.pth'
6 |
7 | import os
8 | import sys
9 | import time
10 | import torch
11 | import random
12 |
13 | sys.path.append("./zipslicer")
14 |
15 | import zipslicer
16 |
17 | cgreen = "\033[92m"
18 | cyellow = "\033[93m"
19 | creset = "\033[0m"
20 | ok_green = f"{cgreen}[OK]{creset}"
21 |
22 |
23 | def __test_incremental_load(ckpt=None, seed=1337):
24 | random.seed(int(os.environ.get("ZIPSLICER_TEST_SEED", seed)))
25 |
26 | print_note = False
27 | if ckpt is None:
28 | if len(sys.argv) <= 1:
29 | print(
30 | "Usage:\n\tpython ./tests/test_checkpoint_readonly.py 'path_to_your_checkpoint.pth'"
31 | )
32 | sys.exit(-1)
33 | ckpt = sys.argv[1]
34 | print_note = True
35 |
36 | assert os.path.isfile(ckpt)
37 | if print_note:
38 | print(f'Using "{cyellow}{ckpt}{creset}" in {cgreen}readonly{creset} mode')
39 | print("=" * (os.get_terminal_size().columns))
40 | print(
41 | "Note: this test loads two copies of the checkpoint, one using standard torch.load and the other using zipslicer. You need enough CPU RAM to fit both, or you risk unresponsive behavior and massive swapping from your machine."
42 | )
43 | print("=" * (os.get_terminal_size().columns))
44 |
45 | sdict = torch.load(ckpt, map_location="cpu")
46 | skeys = sdict.keys()
47 | lazy_sdict = zipslicer.load(
48 | ckpt, map_location="cpu", debug=os.environ.get("ZIPSLICER_DEBUG") == "1"
49 | )
50 | lazy_keys = lazy_sdict.keys()
51 |
52 | print("Checking basic key correspondence")
53 | for k in skeys:
54 | assert k in lazy_keys
55 |
56 | for k in lazy_keys:
57 | assert k in skeys
58 | print(f"{ok_green}: {len(skeys)} keys total")
59 |
60 | print("Checking tensor metadata correspondence")
61 | for k, v in sdict.items():
62 | meta = lazy_sdict.get_meta(k)
63 | if k.endswith("._extra_state") and not isinstance(v, torch.Tensor):
64 | assert meta is None
65 | continue
66 |
67 | assert meta.shape == v.shape
68 | assert meta.size() == v.size()
69 | assert meta.dtype == v.dtype
70 | print(f"{ok_green}: {len(skeys)} keys total")
71 |
72 | test_keys = list(skeys)
73 |
74 | if os.environ.get("ZIPSLICER_TEST_SUBSET"):
75 | ratio = float(os.environ.get("ZIPSLICER_TEST_SUBSET"))
76 | random.shuffle(test_keys)
77 | N = int(len(test_keys) * ratio)
78 | test_keys = test_keys[:N]
79 | print(f"Using randomized key subset of length {N} for testing")
80 |
81 | N = len(test_keys)
82 | for i, k in enumerate(test_keys):
83 | print(f"[{i+1}/{N}] Checking key: {k}", end=" ")
84 | t0 = time.time_ns()
85 | T = sdict[k]
86 | LT = lazy_sdict[k]
87 |
88 | if k.endswith("._extra_state") and not isinstance(T, torch.Tensor):
89 | assert T == LT
90 | else:
91 | assert T.dtype == LT.dtype
92 | assert T.shape == LT.shape
93 | assert torch.allclose(T, LT)
94 |
95 | dt = time.time_ns() - t0
96 | print(f"{ok_green} in {round(dt/1e6, 2)}ms")
97 |
98 | del sdict
99 | del lazy_sdict
100 |
101 |
102 | if __name__ == "__main__":
103 | __test_incremental_load()
104 | print(f"{ok_green} All tests passed successfully")
105 |
--------------------------------------------------------------------------------
/tests/test_synthetic.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023- Kirill Gadjello.
2 | # See LICENSE for details (basically it uses part of PyTorch source code and is licensed under the same conditions)
3 |
4 | import os
5 | import sys
6 | import torch
7 | import random
8 | from test_checkpoint_readonly import __test_incremental_load
9 |
10 | sys.path.append("./zipslicer")
11 |
12 | import zipslicer
13 |
14 | seed = int(os.environ.get("ZIPSLICER_TEST_SEED", "1337"))
15 |
16 |
17 | def test_basic():
18 | FNAME = "test_basic.pth"
19 | torch.manual_seed(seed)
20 |
21 | sdict = dict(
22 | a=torch.randn(10, 20, 3, dtype=torch.float32),
23 | longer_name=torch.randn(10, 20, 3, dtype=torch.bfloat16),
24 | )
25 |
26 | torch.save(sdict, FNAME)
27 | __test_incremental_load(ckpt=FNAME)
28 | os.unlink(FNAME)
29 |
30 |
31 | def test_various_dtypes():
32 | FNAME = "test_various_dtypes.pth"
33 | torch.manual_seed(seed)
34 | random.seed(seed)
35 |
36 | sdict = dict()
37 | for dtype in zipslicer.dtype_sizes.keys():
38 | key = ".".join(str(random.randint(0, 2**16)) for _ in range(6))
39 | # TODO: quantized tensor support
40 | if "q" not in str(dtype):
41 | t = (
42 | torch.randn(
43 | random.randint(1, 16),
44 | random.randint(1, 16),
45 | random.randint(1, 16),
46 | dtype=torch.float32,
47 | )
48 | * 200.0
49 | ).to(dtype)
50 |
51 | sdict[key] = t
52 |
53 | torch.save(sdict, FNAME)
54 | __test_incremental_load(ckpt=FNAME)
55 | os.unlink(FNAME)
56 |
57 |
58 | def test_nn_sdict():
59 | FNAME = "test_nn_sdict.pth"
60 | torch.manual_seed(seed)
61 |
62 | network = torch.nn.ModuleList(
63 | [torch.nn.Linear(1000, 2000), torch.nn.Linear(2000, 2000)]
64 | )
65 |
66 | sdict = network.state_dict()
67 |
68 | torch.save(sdict, FNAME)
69 | __test_incremental_load(ckpt=FNAME)
70 | os.unlink(FNAME)
71 |
72 |
73 | def test_nn_sdict_w_extra_state():
74 | FNAME = "test_nn_sdict_w_extra_state.pth"
75 | torch.manual_seed(seed)
76 |
77 | class CustomLinear(torch.nn.Linear):
78 | def get_extra_state(self):
79 | return dict(
80 | a=random.randint(1, 2**64 - 1), b="this is extra state", c=[1, 2, 3]
81 | )
82 |
83 | network = torch.nn.ModuleList([CustomLinear(1000, 2000), CustomLinear(2000, 2000)])
84 |
85 | sdict = network.state_dict()
86 |
87 | torch.save(sdict, FNAME)
88 | __test_incremental_load(ckpt=FNAME)
89 | os.unlink(FNAME)
90 |
91 |
92 | def test_nn_pickle_raises():
93 | FNAME = "test_nn_pickle.pth"
94 | torch.manual_seed(seed)
95 |
96 | network = torch.nn.ModuleList(
97 | [torch.nn.Linear(1000, 2000), torch.nn.Linear(2000, 2000)]
98 | )
99 |
100 | torch.save(network, FNAME)
101 |
102 | try:
103 | zipslicer.load(FNAME)
104 | except Exception as e:
105 | assert (
106 | "Error at zipslicer.load bootstrap stage, your torch pickle checkpoint is likely too complex for the lightweight loader to interpret. Make sure your network was saved as a state_dict, instead of general-purpose network pickle"
107 | in str(e)
108 | )
109 |
110 | os.unlink(FNAME)
111 |
--------------------------------------------------------------------------------
/zipslicer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023- Kirill Gadjello.
2 | # See LICENSE for details (basically it uses parts of PyTorch source code and is licensed under the same conditions)
3 |
4 | import os
5 | from functools import reduce
6 | import types
7 | from collections import OrderedDict
8 | import pickle
9 | import zipfile
10 | import struct
11 |
12 | import torch
13 |
14 | from . import custom_load
15 | from . import weights_only_unpickler
16 |
17 |
18 | def create_storage(buf):
19 | if hasattr(torch, "UntypedStorage"):
20 | return torch.UntypedStorage.from_buffer(buf, dtype=torch.uint8)
21 | elif hasattr(torch, "_UntypedStorage"):
22 | # fallback for older torch versions which don't have UntypedStorage
23 | return torch._UntypedStorage.from_buffer(buf, dtype=torch.uint8)
24 | else:
25 | # fallback for older torch versions which don't have _UntypedStorage
26 | return torch.ByteStorage.from_buffer(buf)
27 |
28 |
29 | def get_typed_storage(backing_storage, dtype):
30 | # TODO: Upstream pytorch might change the semantics here eventually
31 | if hasattr(torch, "TypedStorage"):
32 | return torch.storage.TypedStorage(wrap_storage=backing_storage, dtype=dtype)
33 | else:
34 | legacy_storage_types = {
35 | torch.bfloat16: torch.BFloat16Storage,
36 | torch.half: torch.HalfStorage,
37 | torch.float: torch.FloatStorage,
38 | torch.double: torch.DoubleStorage,
39 | torch.int8: torch.CharStorage,
40 | torch.uint8: torch.ByteStorage,
41 | torch.short: torch.ShortStorage,
42 | torch.int: torch.IntStorage,
43 | torch.long: torch.LongStorage,
44 | torch.bool: torch.BoolStorage,
45 | }
46 | if dtype not in legacy_storage_types:
47 | raise Exception(
48 | "Failed to create Torch Storage object, your Torch version is probably too old"
49 | )
50 | return legacy_storage_types[dtype](wrap_storage=backing_storage)
51 |
52 |
53 | # ZIP "local file header" structure, magic number, size, and indices
54 | # (section V.A in the format document)
55 | structFileHeader = "<4s2B4HL2L2H"
56 | stringFileHeader = b"PK\003\004"
57 | _FH_SIGNATURE = 0
58 | _FH_EXTRA_FIELD_LENGTH = 11
59 | _FH_FILENAME_LENGTH = 10
60 | _FH_GENERAL_PURPOSE_FLAG_BITS = 3
61 |
62 | sizeFileHeader = struct.calcsize(structFileHeader)
63 |
64 |
65 | def skip_header(zef_file, zinfo):
66 | zef_file.seek(zinfo.header_offset)
67 | fheader = zef_file.read(sizeFileHeader)
68 | if len(fheader) != sizeFileHeader:
69 | raise Exception("Truncated file header")
70 | fheader = struct.unpack(structFileHeader, fheader)
71 | if fheader[_FH_SIGNATURE] != stringFileHeader:
72 | raise Exception("Bad magic number for file header")
73 |
74 | fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
75 | if fheader[_FH_EXTRA_FIELD_LENGTH]:
76 | zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
77 |
78 | if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & 0x800:
79 | # UTF-8 filename
80 | fname_str = fname.decode("utf-8")
81 | else:
82 | fname_str = fname.decode("cp437")
83 |
84 | if fname_str != zinfo.orig_filename:
85 | raise Exception(
86 | "File name in directory %r and header %r differ."
87 | % (zinfo.orig_filename, fname)
88 | )
89 |
90 | return True
91 |
92 |
93 | dtype_sizes = {
94 | torch.float64: 8,
95 | torch.float32: 4,
96 | torch.bfloat16: 2,
97 | torch.float16: 2,
98 | torch.int64: 8,
99 | torch.int32: 4,
100 | torch.int16: 2,
101 | torch.uint8: 1,
102 | torch.int8: 1,
103 | torch.bool: 1,
104 | # torch.quint8: 1, # TODO
105 | # torch.qint8: 1,
106 | }
107 |
108 | dtype_by_name = {
109 | "torch.float64": torch.float64,
110 | "torch.float32": torch.float32,
111 | "torch.bfloat16": torch.bfloat16,
112 | "torch.float16": torch.float16,
113 | "torch.int64": torch.int64,
114 | "torch.int32": torch.int32,
115 | "torch.int16": torch.int16,
116 | "torch.uint8": torch.uint8,
117 | "torch.int8": torch.int8,
118 | "torch.bool": torch.bool,
119 | # "torch.quint8": torch.quint8,
120 | # "torch.qint8": torch.qint8,
121 | }
122 |
123 |
124 | def load_tensor_partial(
125 | zipfile,
126 | fh,
127 | offset_index,
128 | dtype,
129 | numel,
130 | key,
131 | location,
132 | offset,
133 | use_uncompressed=True,
134 | ):
135 | name = f"data/{key}"
136 | znames = list(filter(lambda x: x.endswith(name), zipfile.namelist()))
137 | assert len(znames) == 1
138 | zname = znames[0]
139 |
140 | dsize = dtype_sizes[dtype]
141 | bbuffer = None
142 |
143 | if use_uncompressed:
144 | try:
145 | zero_offset = None
146 |
147 | # fast path
148 | if offset_index.get(zname) is not None:
149 | zero_offset = offset_index.get(zname)
150 | data_offset = zero_offset + offset
151 | fh.seek(data_offset)
152 | bbuffer = fh.read(dsize * numel)
153 | else:
154 | info = zipfile.getinfo(zname)
155 | is_uncompressed = (
156 | info.compress_size == info.file_size
157 | ) and info.compress_type == 0
158 |
159 | if is_uncompressed:
160 | fh.seek(info.header_offset)
161 | success = skip_header(fh, info)
162 |
163 | if success:
164 | zero_offset = fh.tell()
165 | offset_index[zname] = zero_offset
166 | data_offset = zero_offset + offset
167 | fh.seek(data_offset)
168 | bbuffer = fh.read(dsize * numel)
169 |
170 | assert len(bbuffer) == dsize * numel
171 | except Exception as e:
172 | print(f"[ZIPSLICER]: Exception during attempt at fast seek: {e}")
173 |
174 | # fallback uses python-native zipfile seek which becomes slow for large checkpoints
175 | if bbuffer is None:
176 | print("[ZIPSLICER]: fast torch storage seek failed, executing fallback")
177 | with zipfile.open(zname, "r") as zf:
178 | zf.seek(offset)
179 | bbuffer = zf.read(dsize * numel)
180 |
181 | storage = create_storage(bbuffer)
182 | return get_typed_storage(storage, dtype)
183 |
184 |
185 | # This class is meant to behave in a functionally similar manner to
186 | # the conventional PyTorch state_dict objects (instances of Python's OrderedDict).
187 | # It should be noted that as the docs say in https://peps.python.org/pep-0372/
188 | # "subclassing dict is a non-trivial task", and this implementation doesn't try
189 | # to be perfect. But it should work for pytorch checkpoint access, see tests.
190 | # Note that creation of new keys is not yet supported.
191 | # Caching isn't meant to be enabled just yet, but the tcache member is used
192 | # to store k-v pairs that some apps may assign to the state_dict object.
193 | # Currently, this class is meant to be created inside zipslicer.load method
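   | #
   | # A usage sketch (hypothetical key name), assuming `lazy` came from zipslicer.load:
   | #
   | #   meta = lazy.get_meta("encoder.0.weight")  # shape/dtype only, no tensor data
   | #   w = lazy["encoder.0.weight"]              # tensor is materialized on access
   | #   del w                                     # safe: nothing is cached by default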
194 | class LazyStateDict(OrderedDict):
195 | """
196 | A lazy state_dict produced by zipslicer: https://github.com/kir-gadjello/zipslicer
197 | This object should be used to access PyTorch checkpoints in an incremental way.
198 | If caching is disabled (the default), this object doesn't force the tensors and
199 | Torch Storage objects it creates to stay in RAM - you can delete them safely.
200 | Updating and deleting values in-place (without saving) is supported.
201 | Creation of new keys isn't supported just yet.
202 |
203 | Special methods: there is a 'get_meta' method for accessing torch tensor shapes
204 | for the available keys without loading the whole tensor from disk.
205 | """
206 |
207 | def __init__(
208 | self,
209 | tensors=None,
210 | extras=None, # TODO: we might have to handle extra data
211 | untie_weights=False, # TODO
212 | map_location="cpu",
213 | zipfile=None,
214 | fh=None,
215 | debug=False,
216 | dtype=None,
217 | cache_tensors=False,
218 | *args,
219 | **kwargs,
220 | ):
221 | super().__init__(*args, **kwargs)
222 |
223 | self.__lazy = True
224 | self.tensors = tensors
225 | self.zipfile = zipfile
226 | self.fh = fh
227 | self.offset_index = {}
228 | self.map_location = map_location
229 | self.untie_weights = untie_weights
230 | self.cache_tensors = cache_tensors
231 | self.debug = debug
232 | self.dtype = dtype
233 | self.tcache = {}
234 |
235 | for k in self.keys():
236 | self.validate_tensor_ref(k)
237 |
238 | def __del__(self):
239 | if self.zipfile is not None:
240 | if hasattr(self.zipfile, "close"):
241 | self.zipfile.close()
242 | del self.zipfile
243 | if self.fh is not None:
244 | if hasattr(self.fh, "close"):
245 | self.fh.close()
246 | del self.fh
247 |
248 | def __len__(self):
249 | return len(self.tensors.keys())
250 |
251 | def __setitem__(self, key, value):
252 | # Not supporting adding new keys for now
253 | if key not in self.tensors:
254 | raise KeyError(key)
255 |
256 | self.tcache[key] = value
257 |
258 | def __delitem__(self, key):
259 | if key not in self.tcache and key not in self.tensors:
260 | raise KeyError(key)
261 |
262 | if key in self.tcache:
263 | del self.tcache[key]
264 | if key in self.tensors:
265 | del self.tensors[key]
266 |
267 | def __getitem__(self, k):
268 | if k in self.tcache:
269 | return self.tcache[k]
270 | elif k in self.tensors:
271 | ret = self.reform_tensor(k)
272 | if self.cache_tensors:
273 | self.tcache[k] = ret
274 | return ret
275 | else:
276 | raise KeyError(k)
277 |
278 | def __contains__(self, k):
279 | return k in self.tcache.keys() or k in self.tensors.keys()
280 |
281 | def keys(self):
282 | # Fast path, cache is disabled and no keys were assigned
283 | if len(self.tcache) == 0:
284 | return self.tensors.keys()
285 | else:
286 | # A naive implementation, but should suffice
287 | return list(set(self.tcache.keys()) | set(self.tensors.keys()))
288 |
289 | def values(self):
290 | raise Exception(
291 | "LazyStateDict isn't meant for loading all values at once due to RAM constraints, reconsider your usage pattern"
292 | )
293 |
294 | def items(self):
295 | for k in self.keys():
296 | yield (k, self.__getitem__(k))
297 |
298 | def get_meta(self, k):
299 | if k not in self.tensors:
300 | raise KeyError(k)
301 |
302 | if k.endswith("._extra_state"):
303 | return None
304 |
305 | dtype = self.tensors[k]["args"][0]["dtype"]
306 | dtype = dtype_by_name[dtype]
307 | size = torch.Size(self.tensors[k]["args"][2])
308 |
309 | return types.SimpleNamespace(
310 | shape=size,
311 | size=lambda: size,
312 | dtype=dtype,
313 | )
314 |
315 | def validate_tensor_ref(self, k):
316 | if k not in self.tensors:
317 | raise KeyError(k)
318 |
319 | # Allow free-form ._extra_state for now: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_extra_state
320 | if k.endswith("._extra_state"):
321 | assert isinstance(self.tensors[k], dict)
322 | return True
323 |
324 | ref = self.tensors[k]
325 | assert ref["type"] == "stub_obj"
326 | assert ref["fn"] == "_rebuild_tensor_v2"
327 | storage_args = ref["args"][0]
328 |
329 | dtype = None
330 | try:
331 | dtype = dtype_by_name[storage_args["dtype"]]
332 | except Exception as e:
333 | print("Couldn't load tensor:", e)
334 | return None
335 |
336 | assert isinstance(dtype, torch.dtype)
337 |
338 | return True
339 |
340 | def reform_tensor(self, k):
341 | if k not in self.tensors:
342 | raise KeyError(k)
343 |
344 | if k.endswith("._extra_state"):
345 | return self.tensors[k]
346 |
347 | ref = self.tensors[k]
348 | storage_args = ref["args"][0]
349 | rebuild_tensor_args = ref["args"][1:]
350 | dtype = dtype_by_name[storage_args["dtype"]]
351 |
352 | (
353 | storage_offset,
354 | size,
355 | stride,
356 | requires_grad,
357 | backward_hooks,
358 | ) = rebuild_tensor_args
359 |
360 | assert dtype in dtype_sizes
361 |
362 | # TODO stride correctness checks
363 |
364 | dsize = dtype_sizes[dtype]
365 | numel = (
366 | reduce(lambda x, y: x * y, size) if size is not None and len(size) else 1
367 | )
368 |
369 | storage = load_tensor_partial(
370 | self.zipfile,
371 | self.fh,
372 | self.offset_index,
373 | dtype=dtype,
374 | numel=numel,
375 | key=storage_args["key"],
376 | location=self.map_location,
377 | offset=storage_offset * dsize,
378 | )
379 |
380 | ret = torch._utils._rebuild_tensor_v2(
381 | storage, 0, size, stride, requires_grad, backward_hooks
382 | )
383 |
384 | if self.dtype is not None and ret.dtype != self.dtype:
385 | ret = ret.to(self.dtype)
386 |
387 | return ret
388 |
389 |
390 | def load(
391 | ckpt,
392 | map_location="cpu",
393 | dtype=None,
394 | cache_tensors=False,
395 | debug=False,
396 | ):
397 | """
398 | Should behave similarly to torch.load, but operates incrementally.
399 | Loads accessed tensors on the fly.
400 | ckpt should be a valid file path for now
401 | cache_tensors should be False
402 | """
403 | assert map_location == "cpu"
404 | assert os.path.isfile(ckpt)
405 |
406 | tensors_meta = None
407 | with custom_load._open_zipfile_reader(open(ckpt, "rb")) as zf:
408 | try:
409 | tensors_meta = custom_load._custom_load(
410 | zf, torch.device(map_location), weights_only_unpickler
411 | )
412 | except Exception as e:
413 | raise pickle.UnpicklingError(
414 | f"Error at zipslicer.load bootstrap stage, your torch pickle checkpoint is likely too complex for the lightweight loader to interpret. Make sure your network was saved as a state_dict, instead of general-purpose network pickle. Exception was: {e}"
415 | )
416 |
417 | zipfile_h = zipfile.ZipFile(ckpt, "r", allowZip64=True)
418 |
419 | return LazyStateDict(
420 | tensors=tensors_meta,
421 | fh=open(ckpt, "rb"),
422 | zipfile=zipfile_h,
423 | map_location=map_location,
424 | debug=debug,
425 | dtype=dtype,
426 | cache_tensors=cache_tensors,
427 | )
428 |
--------------------------------------------------------------------------------
/zipslicer/custom_load.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023- Kirill Gadjello, with most of the code coming from PyTorch contributors
2 | # See LICENSE for details (basically it uses parts of PyTorch source code and is licensed under the same conditions)
3 |
4 | import difflib
5 | import os
6 |
7 | import io
8 | import shutil
9 | import struct
10 | import sys
11 | import torch
12 | import tarfile
13 | import pathlib
14 |
15 | import tempfile
16 | import warnings
17 |
18 | from contextlib import closing, contextmanager
19 |
20 | from torch._sources import get_source_lines_and_file
21 | from torch.types import Storage
22 | from torch.storage import _get_dtype_from_pickle_storage_type
23 |
24 | from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Union, IO, Type
25 | from typing_extensions import TypeAlias
26 |
27 | TORCH_UNTYPED_STORAGE_CLS = (
28 | torch.UntypedStorage if hasattr(torch, "UntypedStorage") else torch.ByteStorage
29 | )
30 |
31 | DEFAULT_PROTOCOL = 2
32 |
33 | LONG_SIZE = struct.Struct("=l").size
34 | INT_SIZE = struct.Struct("=i").size
35 | SHORT_SIZE = struct.Struct("=h").size
36 |
37 | MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
38 | PROTOCOL_VERSION = 1001
39 | STORAGE_KEY_SEPARATOR = ","
40 |
41 | FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
42 | MAP_LOCATION: TypeAlias = Optional[
43 | Union[
44 | Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]
45 | ]
46 | ]
47 |
48 |
49 | _string_classes = (str, bytes)
50 |
51 |
52 | def default_restore_location(storage, location):
53 | for _, _, fn in _package_registry:
54 | result = fn(storage, location)
55 | if result is not None:
56 | return result
57 | raise RuntimeError(
58 | "don't know how to restore data location of "
59 | + torch.typename(storage)
60 | + " (tagged with "
61 | + location
62 | + ")"
63 | )
64 |
65 |
66 | def _check_seekable(f) -> bool:
67 | def raise_err_msg(patterns, e):
68 | for p in patterns:
69 | if p in str(e):
70 | msg = (
71 | str(e)
72 | + ". You can only torch.load from a file that is seekable."
73 | + " Please pre-load the data into a buffer like io.BytesIO and"
74 | + " try to load from it instead."
75 | )
76 | raise type(e)(msg)
77 | raise e
78 |
79 | try:
80 | f.seek(f.tell())
81 | return True
82 | except (io.UnsupportedOperation, AttributeError) as e:
83 | raise_err_msg(["seek", "tell"], e)
84 | return False
85 |
86 |
87 | class SourceChangeWarning(Warning):
88 | pass
89 |
90 |
91 | @contextmanager
92 | def mkdtemp():
93 | path = tempfile.mkdtemp()
94 | yield path
95 | shutil.rmtree(path)
96 |
97 |
98 | _package_registry = []
99 |
100 |
101 | def _import_dotted_name(name):
102 | components = name.split(".")
103 | obj = __import__(components[0])
104 | for component in components[1:]:
105 | obj = getattr(obj, component)
106 | return obj
107 |
108 |
109 | def normalize_storage_type(storage_type):
110 | return getattr(torch, storage_type.__name__)
111 |
112 |
113 | def storage_to_tensor_type(storage):
114 | storage_type = type(storage)
115 | module = _import_dotted_name(storage_type.__module__)
116 | return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))
117 |
118 |
119 | def _is_path(name_or_buffer):
120 | return isinstance(name_or_buffer, str) or isinstance(name_or_buffer, pathlib.Path)
121 |
122 |
123 | def _is_zipfile(f) -> bool:
124 | # This is a stricter implementation than zipfile.is_zipfile().
125 | # zipfile.is_zipfile() is True if the magic number appears anywhere in the
126 | # binary. Since we expect the files here to be generated by torch.save or
127 | # torch.jit.save, it's safe to only check the start bytes and avoid
128 | # collisions and assume the zip has only 1 file.
129 | # See bugs.python.org/issue28494.
130 |
131 | # Read the first 4 bytes of the file
132 | read_bytes = []
133 | start = f.tell()
134 |
135 | byte = f.read(1)
136 | while byte != b"":
137 | read_bytes.append(byte)
138 | if len(read_bytes) == 4:
139 | break
140 | byte = f.read(1)
141 | f.seek(start)
142 |
143 | local_header_magic_number = [b"P", b"K", b"\x03", b"\x04"]
144 | return read_bytes == local_header_magic_number
145 |
146 |
147 | class _opener(object):
148 | def __init__(self, file_like):
149 | self.file_like = file_like
150 |
151 | def __enter__(self):
152 | return self.file_like
153 |
154 | def __exit__(self, *args):
155 | pass
156 |
157 |
158 | class _open_file(_opener):
159 | def __init__(self, name, mode):
160 | super(_open_file, self).__init__(open(name, mode))
161 |
162 | def __exit__(self, *args):
163 | self.file_like.close()
164 |
165 |
166 | class _open_buffer_reader(_opener):
167 | def __init__(self, buffer):
168 | super(_open_buffer_reader, self).__init__(buffer)
169 | _check_seekable(buffer)
170 |
171 |
172 | class _open_buffer_writer(_opener):
173 | def __exit__(self, *args):
174 | self.file_like.flush()
175 |
176 |
177 | def _open_file_like(name_or_buffer, mode):
178 | if _is_path(name_or_buffer):
179 | return _open_file(name_or_buffer, mode)
180 | else:
181 | if "w" in mode:
182 | return _open_buffer_writer(name_or_buffer)
183 | elif "r" in mode:
184 | return _open_buffer_reader(name_or_buffer)
185 | else:
186 | raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
187 |
188 |
189 | class _open_zipfile_reader(_opener):
190 | def __init__(self, name_or_buffer) -> None:
191 | super(_open_zipfile_reader, self).__init__(
192 | torch._C.PyTorchFileReader(name_or_buffer)
193 | )
194 |
195 |
196 | class _open_zipfile_writer_file(_opener):
197 | def __init__(self, name) -> None:
198 | super(_open_zipfile_writer_file, self).__init__(
199 | torch._C.PyTorchFileWriter(str(name))
200 | )
201 |
202 | def __exit__(self, *args) -> None:
203 | self.file_like.write_end_of_file()
204 |
205 |
206 | class _open_zipfile_writer_buffer(_opener):
207 | def __init__(self, buffer) -> None:
208 | self.buffer = buffer
209 | super(_open_zipfile_writer_buffer, self).__init__(
210 | torch._C.PyTorchFileWriter(buffer)
211 | )
212 |
213 | def __exit__(self, *args) -> None:
214 | self.file_like.write_end_of_file()
215 | self.buffer.flush()
216 |
217 |
218 | def _open_zipfile_writer(name_or_buffer):
219 | container: Type[_opener]
220 | if _is_path(name_or_buffer):
221 | container = _open_zipfile_writer_file
222 | else:
223 | container = _open_zipfile_writer_buffer
224 | return container(name_or_buffer)
225 |
226 |
227 | def _is_compressed_file(f) -> bool:
228 | compress_modules = ["gzip"]
229 | try:
230 | return f.__module__ in compress_modules
231 | except AttributeError:
232 | return False
233 |
234 |
235 | def _should_read_directly(f):
236 | """
237 | Checks if f is a file that should be read directly. It should be read
238 | directly if it is backed by a real file (has a fileno) and is not
239 | a compressed file (e.g. gzip)
240 | """
241 | if _is_compressed_file(f):
242 | return False
243 | try:
244 | return f.fileno() >= 0
245 | except io.UnsupportedOperation:
246 | return False
247 | except AttributeError:
248 | return False
249 |
250 |
251 | def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
252 | # When using encoding='bytes' in Py3, some **internal** keys stored as
253 | # strings in Py2 are loaded as bytes. This function decodes them with
254 | # ascii encoding, one that Py3 uses by default.
255 | #
256 | # NOTE: This should only be used on internal keys (e.g., `typename` and
257 | # `location` in `persistent_load` below!
258 | if isinstance(bytes_str, bytes):
259 | return bytes_str.decode("ascii")
260 | return bytes_str
261 |
262 |
263 | def _get_restore_location(map_location):
264 | if map_location is None:
265 | restore_location = default_restore_location
266 | elif isinstance(map_location, dict):
267 |
268 | def restore_location(storage, location):
269 | location = map_location.get(location, location)
270 | return default_restore_location(storage, location)
271 |
272 | elif isinstance(map_location, _string_classes):
273 |
274 | def restore_location(storage, location):
275 | return default_restore_location(storage, map_location)
276 |
277 | elif isinstance(map_location, torch.device):
278 |
279 | def restore_location(storage, location):
280 | return default_restore_location(storage, str(map_location))
281 |
282 | else:
283 |
284 | def restore_location(storage, location):
285 | result = map_location(storage, location)
286 | if result is None:
287 | result = default_restore_location(storage, location)
288 | return result
289 |
290 | return restore_location
291 |
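   | # For reference, map_location accepts several forms (mirroring torch.load);
   | # the calls below are illustrative only:
   | #
   | #   _get_restore_location(None)                  # default restore locations
   | #   _get_restore_location({"cuda:0": "cpu"})     # remap locations via a dict
   | #   _get_restore_location("cpu")                 # force a device string
   | #   _get_restore_location(torch.device("cpu"))   # force a torch.device
   | #   _get_restore_location(lambda s, loc: s)      # custom callable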
292 |
293 | class StorageType:
294 | def __init__(self, name):
295 | self.dtype = _get_dtype_from_pickle_storage_type(name)
296 |
297 | def __str__(self):
298 | return f"StorageType(dtype={self.dtype})"
299 |
300 |
301 | def _load(
302 | zip_file, map_location, pickle_module, pickle_file="data.pkl", **pickle_load_args
303 | ):
304 | restore_location = _get_restore_location(map_location)
305 |
306 | loaded_storages = {}
307 |
308 | def load_tensor(dtype, numel, key, location):
309 | name = f"data/{key}"
310 |
311 | storage = (
312 | zip_file.get_storage_from_record(name, numel, TORCH_UNTYPED_STORAGE_CLS)
313 | .storage()
314 | .untyped()
315 | )
316 | # TODO: Once we decide to break serialization FC, we can
317 | # stop wrapping with TypedStorage
318 | loaded_storages[key] = torch.storage.TypedStorage(
319 | wrap_storage=restore_location(storage, location), dtype=dtype
320 | )
321 |
322 | def persistent_load(saved_id):
323 | assert isinstance(saved_id, tuple)
324 | typename = _maybe_decode_ascii(saved_id[0])
325 | data = saved_id[1:]
326 |
327 | assert (
328 | typename == "storage"
329 | ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
330 | storage_type, key, location, numel = data
331 | if storage_type is TORCH_UNTYPED_STORAGE_CLS:
332 | dtype = torch.uint8
333 | else:
334 | dtype = storage_type.dtype
335 |
336 | if key not in loaded_storages:
337 | nbytes = numel * torch._utils._element_size(dtype)
338 | load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
339 |
340 | return loaded_storages[key]
341 |
342 | load_module_mapping: Dict[str, str] = {
343 | # See https://github.com/pytorch/pytorch/pull/51633
344 | "torch.tensor": "torch._tensor"
345 | }
346 |
347 | # Need to subclass Unpickler instead of directly monkey-patching the find_class method
348 | # because it's marked readonly in pickle.
349 | # The type: ignore is because mypy can't statically determine the type of this class.
350 | class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
351 | # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
352 | # Lets us override the imports that pickle uses when unpickling an object.
353 | # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
354 | def find_class(self, mod_name, name):
355 | if type(name) is str and "Storage" in name:
356 | try:
357 | return StorageType(name)
358 | except KeyError:
359 | pass
360 | mod_name = load_module_mapping.get(mod_name, mod_name)
361 | return super().find_class(mod_name, name)
362 |
363 | # Load the data (which may in turn use `persistent_load` to load tensors)
364 | data_file = io.BytesIO(zip_file.get_record(pickle_file))
365 |
366 | unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
367 | unpickler.persistent_load = persistent_load
368 | result = unpickler.load()
369 |
370 | torch._utils._validate_loaded_sparse_tensors()
371 |
372 | return result
373 |
374 |
375 | def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
376 | deserialized_objects: Dict[int, Any] = {}
377 |
378 | restore_location = _get_restore_location(map_location)
379 |
380 | class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
381 | def find_class(self, mod_name, name):
382 | if type(name) is str and "Storage" in name:
383 | try:
384 | return StorageType(name)
385 | except KeyError:
386 | pass
387 | return super().find_class(mod_name, name)
388 |
389 | def _check_container_source(container_type, source_file, original_source):
390 | try:
391 | current_source = "".join(get_source_lines_and_file(container_type)[0])
392 | except Exception: # saving the source is optional, so we can ignore any errors
393 | warnings.warn(
394 | "Couldn't retrieve source code for container of "
395 | "type " + container_type.__name__ + ". It won't be checked "
396 | "for correctness upon loading."
397 | )
398 | return
399 | if original_source != current_source:
400 | if container_type.dump_patches:
401 | file_name = container_type.__name__ + ".patch"
402 | diff = difflib.unified_diff(
403 | current_source.split("\n"),
404 | original_source.split("\n"),
405 | source_file,
406 | source_file,
407 | lineterm="",
408 | )
409 | lines = "\n".join(diff)
410 | try:
411 | with open(file_name, "a+") as f:
412 | file_size = f.seek(0, 2)
413 | f.seek(0)
414 | if file_size == 0:
415 | f.write(lines)
416 | elif file_size != len(lines) or f.read() != lines:
417 | raise IOError
418 | msg = (
419 | "Saved a reverse patch to " + file_name + ". "
420 | "Run `patch -p0 < " + file_name + "` to revert your "
421 | "changes."
422 | )
423 | except IOError:
424 | msg = (
425 | "Tried to save a patch, but couldn't create a "
426 | "writable file " + file_name + ". Make sure it "
427 | "doesn't exist and your working directory is "
428 | "writable."
429 | )
430 | else:
431 | msg = (
432 | "you can retrieve the original source code by "
433 | "accessing the object's source attribute or set "
434 | "`torch.nn.Module.dump_patches = True` and use the "
435 | "patch tool to revert the changes."
436 | )
437 | msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
438 | warnings.warn(msg, SourceChangeWarning)
439 |
440 | def legacy_load(f):
441 | deserialized_objects: Dict[int, Any] = {}
442 |
443 | def persistent_load(saved_id):
444 | if isinstance(saved_id, tuple):
445 | # Ignore containers that don't have any sources saved
446 | if all(saved_id[1:]):
447 | _check_container_source(*saved_id)
448 | return saved_id[0]
449 | return deserialized_objects[int(saved_id)]
450 |
451 | with closing(
452 | tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
453 | ) as tar, mkdtemp() as tmpdir:
454 |
455 | tar.extract("storages", path=tmpdir)
456 | with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
457 | num_storages = pickle_module.load(f, **pickle_load_args)
458 | for i in range(num_storages):
459 | args = pickle_module.load(f, **pickle_load_args)
460 | key, location, storage_type = args
461 | dtype = storage_type.dtype
462 | obj = cast(Storage, TORCH_UNTYPED_STORAGE_CLS)._new_with_file(
463 | f, torch._utils._element_size(dtype)
464 | )
465 | obj = restore_location(obj, location)
466 | # TODO: Once we decide to break serialization FC, we can
467 | # stop wrapping with TypedStorage
468 | deserialized_objects[key] = torch.storage.TypedStorage(
469 | wrap_storage=obj, dtype=dtype
470 | )
471 |
472 | storage_views = pickle_module.load(f, **pickle_load_args)
473 | for target_cdata, root_cdata, offset, numel in storage_views:
474 | root = deserialized_objects[root_cdata]
475 | element_size = torch._utils._element_size(root.dtype)
476 | offset_bytes = offset * element_size
477 | # TODO: Once we decide to break serialization FC, we can
478 | # stop wrapping with TypedStorage
479 | deserialized_objects[target_cdata] = torch.storage.TypedStorage(
480 | wrap_storage=root._storage[
481 | offset_bytes : offset_bytes + numel * element_size
482 | ],
483 | dtype=root.dtype,
484 | )
485 |
486 | tar.extract("tensors", path=tmpdir)
487 | with open(os.path.join(tmpdir, "tensors"), "rb", 0) as f:
488 | num_tensors = pickle_module.load(f, **pickle_load_args)
489 | for _ in range(num_tensors):
490 | args = pickle_module.load(f, **pickle_load_args)
491 | key, storage_id, original_tensor_type = args
492 | storage = deserialized_objects[storage_id]
493 | (ndim,) = struct.unpack("<i", f.read(4))
--------------------------------------------------------------------------------
/zipslicer/weights_only_unpickler.py:
--------------------------------------------------------------------------------
267 | if strlen > maxsize:
268 | raise RuntimeError("String is too long")
269 | strval = str(read(strlen), "utf-8", "surrogatepass")
270 | self.append(strval)
271 | elif key[0] == SHORT_BINSTRING[0]:
272 | strlen = read(1)[0]
273 | strdata = read(strlen)
274 | if self.encoding != "bytes":
275 | strdata = strdata.decode(self.encoding, "strict")
276 | self.append(strdata)
277 | elif key[0] == BINPERSID[0]:
278 | pid = self.stack.pop()
279 | # Only allow persistent load of storage
280 | if type(pid) is not tuple and type(pid) is not int:
281 | raise RuntimeError(
282 | f"persistent_load id must be tuple or int, but got {type(pid)}"
283 | )
284 | if (
285 | type(pid) is tuple
286 | and len(pid) > 0
287 | and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
288 | ):
289 | raise RuntimeError(
290 | f"Only persistent_load of storage is allowed, but got {pid[0]}"
291 | )
292 | self.append(self.persistent_load(pid))
293 | elif key[0] in [BINGET[0], LONG_BINGET[0]]:
294 | idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]