├── .github
│   └── workflows
│       ├── pypi-deploy.yml
│       └── python-test.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── examples
│   └── example_resnet18.py
├── requirements.txt
├── setup.py
├── tests
│   ├── test_checkpoint_readonly.py
│   └── test_synthetic.py
└── zipslicer
    ├── __init__.py
    ├── custom_load.py
    └── weights_only_unpickler.py

/.github/workflows/pypi-deploy.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # Then it will deploy the package to PyPI
3 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
4 | 
5 | name: Deploy Python package to PyPI
6 | 
7 | on:
8 |   release:
9 |     types:
10 |       - published
11 | 
12 | jobs:
13 |   tests:
14 |     uses: "./.github/workflows/python-test.yml"
15 |   publish:
16 |     name: publish
17 |     needs: [tests] # require tests to pass before deploy runs
18 |     runs-on: ubuntu-latest
19 |     steps:
20 |       - uses: actions/checkout@main
21 |       - name: Publish Python Package
22 |         uses: mariamrf/py-package-publish-action@v1.1.0
23 |         with:
24 |           python_version: '3.10.0'
25 |         env:
26 |           TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }}
27 |           TWINE_USERNAME: __token__
28 | 
--------------------------------------------------------------------------------
/.github/workflows/python-test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 | 
4 | name: Lint and Test Python package
5 | 
6 | on:
7 |   workflow_dispatch:
8 |   workflow_call: # allow this workflow to be called from other workflows
9 |   push:
10 |     branches: [ "main" ]
11 |     paths:
12 |       - 'zipslicer/**'
13 |       - 'tests/**'
14 |       - 'examples/**'
15 |   pull_request:
16 |     branches: [ "main" ]
17 | 
18 | jobs:
19 |   tests:
20 |     runs-on: ubuntu-latest
21 |     strategy:
22 |       fail-fast: false
23 |       matrix:
24 |         python-version: ["3.10", "3.11"]
25 |         torch-version: ["1.11.0", "1.12.0", "stable"]
26 |         exclude: # there is no easily available older torch build for newer python
27 |           - python-version: "3.11"
28 |             torch-version: "1.11.0"
29 |           - python-version: "3.11"
30 |             torch-version: "1.12.0"
31 | 
32 |     steps:
33 |       - uses: actions/checkout@v3
34 |       - name: Set up Python ${{ matrix.python-version }}
35 |         uses: actions/setup-python@v4
36 |         with:
37 |           python-version: ${{ matrix.python-version }}
38 |       - name: Install testing and linting tools
39 |         run: |
40 |           python -m pip install --upgrade pip
41 |           python -m pip install flake8 pytest
42 |       - name: Lint with flake8
43 |         run: |
44 |           # stop the build if there are Python syntax errors or undefined names
45 |           flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
46 |           # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
47 |           flake8 --ignore 'E501 W503 E203 E402 W293' . 
--count --exit-zero --max-complexity=40 --max-line-length=127 --statistics 48 | - name: Install dependencies, testing with cpu-only pytorch to save CI time, torch==${{matrix.torch-version}} 49 | run: | 50 | if [[ "stable" == ${{matrix.torch-version}} ]]; then 51 | python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu 52 | else 53 | python -m pip install 'torch==${{matrix.torch-version}}+cpu' --extra-index-url https://download.pytorch.org/whl/cpu 54 | fi 55 | # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 56 | - name: Test with pytest 57 | run: | 58 | PYTHONPATH="." pytest -o log_cli=true --capture=tee-sys -p no:asyncio . 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | *~ 3 | *.pth 4 | __pycache__ 5 | .pytest_cache 6 | *.py[cod] 7 | *$py.class 8 | *.backup 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | All contributions by Kirill Gadjello: 2 | Copyright (c) 2023- Kirill Gadjello. 3 | 4 | From PyTorch: 5 | 6 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 7 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 8 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 9 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 10 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 11 | Copyright (c) 2011-2013 NYU (Clement Farabet) 12 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 13 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 14 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 15 | 16 | From Caffe2: 17 | 18 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 19 | 20 | All contributions by Facebook: 21 | Copyright (c) 2016 Facebook Inc. 22 | 23 | All contributions by Google: 24 | Copyright (c) 2015 Google Inc. 25 | All rights reserved. 26 | 27 | All contributions by Yangqing Jia: 28 | Copyright (c) 2015 Yangqing Jia 29 | All rights reserved. 30 | 31 | All contributions by Kakao Brain: 32 | Copyright 2019-2020 Kakao Brain 33 | 34 | All contributions by Cruise LLC: 35 | Copyright (c) 2022 Cruise LLC. 36 | All rights reserved. 37 | 38 | All contributions from Caffe: 39 | Copyright(c) 2013, 2014, 2015, the respective contributors 40 | All rights reserved. 41 | 42 | All other contributions: 43 | Copyright(c) 2015, 2016 the respective contributors 44 | All rights reserved. 45 | 46 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 47 | copyright over their contributions to Caffe2. The project versioning records 48 | all such contribution and copyright details. If a contributor wants to further 49 | mark their specific copyright on a particular contribution, they should 50 | indicate their copyright solely in the commit message of the change when it is 51 | committed. 52 | 53 | All rights reserved. 54 | 55 | Redistribution and use in source and binary forms, with or without 56 | modification, are permitted provided that the following conditions are met: 57 | 58 | 1. Redistributions of source code must retain the above copyright 59 | notice, this list of conditions and the following disclaimer. 60 | 61 | 2. Redistributions in binary form must reproduce the above copyright 62 | notice, this list of conditions and the following disclaimer in the 63 | documentation and/or other materials provided with the distribution. 64 | 65 | 3. 
Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 66 | and IDIAP Research Institute nor the names of its contributors may be 67 | used to endorse or promote products derived from this software without 68 | specific prior written permission. 69 | 70 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 71 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 72 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 73 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 74 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 75 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 76 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 77 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 78 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 79 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 80 | POSSIBILITY OF SUCH DAMAGE. 81 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.txt 2 | recursive-include examples *.py 3 | recursive-include tests *.py 4 | recursive-include zipslicer *.py 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # *ZIPSLICER* 📁✂️ 2 | [![Lint and Test Python package](https://github.com/kir-gadjello/zipslicer/actions/workflows/python-test.yml/badge.svg)](https://github.com/kir-gadjello/zipslicer/actions/workflows/python-test.yml) 3 | [![Published to PyPI](https://github.com/kir-gadjello/zipslicer/actions/workflows/pypi-deploy.yml/badge.svg)](https://github.com/kir-gadjello/zipslicer/actions/workflows/pypi-deploy.yml) 4 | 5 | A library for incremental loading of large PyTorch checkpoints
6 | [Read a blogpost introduction by yours truly](https://kir-gadjello.github.io/zipslicer)
7 | 
8 | ## Synopsis
9 | ```python
10 | import torch
11 | import zipslicer
12 | 
13 | # Could be a private custom recurrent sentient transformer
14 | # instead of a garden variety resnet
15 | my_complicated_network = torch.hub.load(
16 |     "pytorch/vision:v0.10.0", "resnet18", pretrained=True
17 | )
18 | s_dict = my_complicated_network.state_dict()
19 | torch.save(s_dict, "my_network_checkpoint_v123.pth")
20 | del my_complicated_network
21 | 
22 | # Later, on a smaller unrelated machine, you load a "LazyStateDict",
23 | # which is just like a regular state dict, except it loads tensors only when it has to
24 | lazy_s_dict = zipslicer.load("my_network_checkpoint_v123.pth")
25 | layer3_tensors = {}
26 | for k in lazy_s_dict.keys():
27 |     if k.startswith("layer3"):
28 |         layer3_tensors[k] = lazy_s_dict[k]
29 | # Now you have layer3's tensors and you can analyze them without breaking your RAM.
30 | # Or you can instantiate the layers' classes in sequence and compute the whole
31 | # network's output for a given input by threading the activations through them.
32 | # But we will just print the tensors instead:
33 | print(layer3_tensors)
34 | ```
35 | 
36 | Run this example and the unit tests:
37 | 
38 | `python examples/example_resnet18.py`
39 | 
40 | `pytest -o log_cli=true --capture=tee-sys -p no:asyncio`
41 | 
42 | Test your checkpoint for compatibility:
43 | 
44 | `python tests/test_checkpoint_readonly.py your_magnificent_checkpoint.pth`
45 | 
46 | If it's all green, it will work.
47 | 
48 | ## Prerequisites
49 | * Supported Python and torch versions: `python-3.10 + torch-(1.11,1.12,stable)` and `python-3.11 + torch:stable`
50 | * Generally, `zipslicer` should work with a modern enough install of PyTorch - use the [included safe test](https://github.com/kir-gadjello/zipslicer/blob/main/tests/test_checkpoint_readonly.py) to check `zipslicer` for compatibility with your PyTorch build and your checkpoint. This is a pure Python library, so the specific CPU architecture shouldn't matter.
51 | * A checkpoint produced by saving your model's `state_dict` via vanilla `torch.save(...)` - the default settings should suffice, as Torch doesn't use ZIP compression.
52 | * An application that can take advantage of an incrementally loaded checkpoint - i.e., if your app just loads all `state_dict.items()` in a loop right away, it doesn't make much sense to use this library. Make sure your code reads `state_dict.keys()` (and `state_dict.get_meta(k)` if necessary) and uses these intelligently to work on a subset of `state_dict[k]` tensors at a time. For general inspiration you might read [this (HF)](https://huggingface.co/docs/transformers/v4.26.0/en/main_classes/model#transformers.modeling_utils.load_sharded_checkpoint) and [this (arxiv)](https://arxiv.org/abs/2104.07857). With some additional engineering it should be possible to run Large Language Models like [BLOOM-176B](https://huggingface.co/bigscience/bloom) or [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl) on a single mid-range GPU at home - if you are willing to wait for a night's worth of time. In the large-batch regime this might even make some practical sense, for example to process a set of documents into embeddings.
53 | 
54 | ## Install
55 | 
56 | Generally, copying the `zipslicer/zipslicer` directory into your project's source tree is enough.
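For instance, once the vendored copy is importable from your working directory, you can open a checkpoint and inspect tensor shapes without pulling any weights into RAM. A minimal sketch (assuming the `my_network_checkpoint_v123.pth` file from the synopsis above exists):

```python
import zipslicer  # the vendored copy

lazy_s_dict = zipslicer.load("my_network_checkpoint_v123.pth")
for k in lazy_s_dict.keys():
    # get_meta() reads only shape/dtype metadata, no tensor data;
    # it returns None for free-form `._extra_state` entries
    meta = lazy_s_dict.get_meta(k)
    if meta is not None:
        print(k, tuple(meta.shape), meta.dtype)
```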
57 | 
58 | If you are a fan of official ceremony-driven install processes for executable modules of dubious provenance, you will soon be able to install this boutique software module via pip: `pip install zipslicer`
59 | 
60 | ## Notes
61 | * This library is only for reading PyTorch tensors from checkpoints. We leave writing for future work.
62 | * Writing to a loaded `state_dict` is frowned upon, but it *will* work - just avoid doing it while iterating over the keys and expecting them to reflect the update.
63 | * Perhaps more importantly, **general-purpose pickles are not supported** - the design of this library doesn't allow you to load whole neural network class instances. Usually this isn't necessary, and [the official PyTorch documentation recommends using `state_dict` for model serialization](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). We support `state_dict`s.
64 | * Some rare tensor types (e.g., PyTorch quantized tensors - not to be confused with integer tensors, which work fine) are not yet supported. If this bothers you, share your experience in the issues.
65 | * We say "Hi" to the [HF `safetensors` project](https://github.com/huggingface/safetensors), but note that in comparison to theirs, our approach doesn't require checkpoint conversion, which takes significant time and storage. In fact, both approaches could be complementary, as you will have to load tensors from the PyTorch checkpoint somehow to convert it to `safetensors` - and the default loading mechanism is constrained by available RAM.
66 | 
67 | ## Prospective features we are considering
68 | If you are interested in some of these features, consider creating an issue:
69 | * Effective loading of tensor slices - to implement tensor parallelism in sharded deployments
70 | * Accessing the source checkpoint over a network
71 | * Writing to a checkpoint in-place
72 | * Incremental conversion to other checkpoint formats
73 | 
--------------------------------------------------------------------------------
/examples/example_resnet18.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | try:
4 |     import zipslicer
5 | except Exception:
6 |     import sys
7 | 
8 |     sys.path.append("./zipslicer")
9 |     import zipslicer
10 | 
11 | 
12 | # Could be a private custom recurrent sentient transformer
13 | # instead of a garden variety resnet
14 | my_complicated_network = torch.hub.load(
15 |     "pytorch/vision:v0.10.0", "resnet18", pretrained=True
16 | )
17 | s_dict = my_complicated_network.state_dict()
18 | torch.save(s_dict, "my_network_checkpoint_v123.pth")
19 | del my_complicated_network
20 | 
21 | # Later, on a smaller unrelated machine, you load a "LazyStateDict",
22 | # which is just like a regular state dict, except it loads tensors only when it has to
23 | lazy_s_dict = zipslicer.load("my_network_checkpoint_v123.pth")
24 | layer3_tensors = {}
25 | for k in lazy_s_dict.keys():
26 |     if k.startswith("layer3"):
27 |         layer3_tensors[k] = lazy_s_dict[k]
28 | # Now you have layer3's tensors and you can analyze them without breaking your RAM.
29 | # Or you can instantiate the layers' classes in sequence and compute the whole network's output.
30 | # But we will just print the tensors:
31 | print(layer3_tensors)
32 | 
33 | # Cleanup after our experiment
34 | import os
35 | os.unlink("my_network_checkpoint_v123.pth")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch >= 1.11
2 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | __version__ = "0.8.1"
4 | 
5 | python_min_version = (3, 8, 0)
6 | version_range_max = 12
7 | 
8 | with open("README.md", "r") as fh:
9 |     long_description = fh.read()
10 | 
11 | setup(
12 |     name="zipslicer",
13 |     version=__version__,
14 |     description="A library for efficient incremental access to tensors stored in PyTorch checkpoints",
15 |     packages=["zipslicer"],
16 |     install_requires=["torch >= 1.10.0"],
17 |     extras_require={
18 |         "dev": [
19 |             "pytest >= 3.10",
20 |         ]
21 |     },
22 |     # PyPI package information from pytorch
23 |     classifiers=[
24 |         "Development Status :: 5 - Production/Stable",
25 |         "Intended Audience :: Developers",
26 |         "Intended Audience :: Education",
27 |         "Intended Audience :: Science/Research",
28 |         "License :: OSI Approved :: BSD License",
29 |         "Topic :: Scientific/Engineering",
30 |         "Topic :: Scientific/Engineering :: Mathematics",
31 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
32 |         "Topic :: Software Development",
33 |         "Topic :: Software Development :: Libraries",
34 |         "Topic :: Software Development :: Libraries :: Python Modules",
35 |         "Programming Language :: Python :: 3",
36 |     ] + [
37 |         "Programming Language :: Python :: 3.{}".format(i)
38 |         for i in range(python_min_version[1], version_range_max)
39 |     ],
40 |     license="BSD-3",
41 |     keywords="pytorch, machine learning",
42 |     python_requires=">=3",
43 |     long_description=long_description,
44 |     long_description_content_type="text/markdown",
45 |     author="Kirill Gadjello",
46 |     author_email="kirill.gadjello@protonmail.com",
47 |     url="https://github.com/kir-gadjello/zipslicer",
48 | )
--------------------------------------------------------------------------------
/tests/test_checkpoint_readonly.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023- Kirill Gadjello. 
2 | # See LICENSE for details (basically it uses part of the PyTorch source code and is licensed under the same conditions)
3 | 
4 | # Run it like this (from the zipslicer repo root directory):
5 | # python ./tests/test_checkpoint_readonly.py 'path_to_your_checkpoint.pth'
6 | 
7 | import os
8 | import sys
9 | import time
10 | import torch
11 | import random
12 | 
13 | sys.path.append("./zipslicer")
14 | 
15 | import zipslicer
16 | 
17 | cgreen = "\033[92m"
18 | cyellow = "\033[93m"
19 | creset = "\033[0m"
20 | ok_green = f"{cgreen}[OK]{creset}"
21 | 
22 | 
23 | def __test_incremental_load(ckpt=None, seed=1337):
24 |     random.seed(int(os.environ.get("ZIPSLICER_TEST_SEED", seed)))
25 | 
26 |     print_note = False
27 |     if ckpt is None:
28 |         if len(sys.argv) <= 1:
29 |             print(
30 |                 "Usage:\n\tpython ./tests/test_checkpoint_readonly.py 'path_to_your_checkpoint.pth'"
31 |             )
32 |             sys.exit(-1)
33 |         ckpt = sys.argv[1]
34 |         print_note = True
35 | 
36 |     assert os.path.isfile(ckpt)
37 |     if print_note:
38 |         print(f'Using "{cyellow}{ckpt}{creset}" in {cgreen}readonly{creset} mode')
39 |         print("=" * (os.get_terminal_size().columns))
40 |         print(
41 |             "Note: this test loads two copies of the checkpoint, one using standard torch.load and the other using zipslicer. You need enough CPU RAM to fit both, otherwise you risk making your machine unresponsive with massive swapping."
42 |         )
43 |         print("=" * (os.get_terminal_size().columns))
44 | 
45 |     sdict = torch.load(ckpt, map_location="cpu")
46 |     skeys = sdict.keys()
47 |     lazy_sdict = zipslicer.load(
48 |         ckpt, map_location="cpu", debug=os.environ.get("ZIPSLICER_DEBUG") == "1"
49 |     )
50 |     lazy_keys = lazy_sdict.keys()
51 | 
52 |     print("Checking basic key correspondence")
53 |     for k in skeys:
54 |         assert k in lazy_keys
55 | 
56 |     for k in lazy_keys:
57 |         assert k in skeys
58 |     print(f"{ok_green}: {len(skeys)} keys total")
59 | 
60 |     print("Checking tensor metadata correspondence")
61 |     for k, v in sdict.items():
62 |         meta = lazy_sdict.get_meta(k)
63 |         if k.endswith("._extra_state") and not isinstance(v, torch.Tensor):
64 |             assert meta is None
65 |             continue
66 | 
67 |         assert meta.shape == v.shape
68 |         assert meta.size() == v.size()
69 |         assert meta.dtype == v.dtype
70 |     print(f"{ok_green}: {len(skeys)} keys total")
71 | 
72 |     test_keys = list(skeys)
73 | 
74 |     if os.environ.get("ZIPSLICER_TEST_SUBSET"):
75 |         ratio = float(os.environ.get("ZIPSLICER_TEST_SUBSET"))
76 |         random.shuffle(test_keys)
77 |         N = int(len(test_keys) * ratio)
78 |         test_keys = test_keys[:N]
79 |         print(f"Using randomized key subset of length {N} for testing")
80 | 
81 |     N = len(test_keys)
82 |     for i, k in enumerate(test_keys):
83 |         print(f"[{i+1}/{N}] Checking key: {k}", end=" ")
84 |         t0 = time.time_ns()
85 |         T = sdict[k]
86 |         LT = lazy_sdict[k]
87 | 
88 |         if k.endswith("._extra_state") and not isinstance(T, torch.Tensor):
89 |             assert T == LT
90 |         else:
91 |             assert T.dtype == LT.dtype
92 |             assert T.shape == LT.shape
93 |             assert torch.allclose(T, LT)
94 | 
95 |         dt = time.time_ns() - t0
96 |         print(f"{ok_green} in {round(dt/1e6, 2)}ms")
97 | 
98 |     del sdict
99 |     del lazy_sdict
100 | 
101 | 
102 | if __name__ == "__main__":
103 |     __test_incremental_load()
104 |     print(f"{ok_green} All tests passed successfully")
--------------------------------------------------------------------------------
/tests/test_synthetic.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023- Kirill Gadjello.
2 | # See LICENSE for details (basically it uses part of the PyTorch source code and is licensed under the same conditions)
3 | 
4 | import os
5 | import sys
6 | import torch
7 | import random
8 | from test_checkpoint_readonly import __test_incremental_load
9 | 
10 | sys.path.append("./zipslicer")
11 | 
12 | import zipslicer
13 | 
14 | seed = int(os.environ.get("ZIPSLICER_TEST_SEED", "1337"))
15 | 
16 | 
17 | def test_basic():
18 |     FNAME = "test_basic.pth"
19 |     torch.manual_seed(seed)
20 | 
21 |     sdict = dict(
22 |         a=torch.randn(10, 20, 3, dtype=torch.float32),
23 |         longer_name=torch.randn(10, 20, 3, dtype=torch.bfloat16),
24 |     )
25 | 
26 |     torch.save(sdict, FNAME)
27 |     __test_incremental_load(ckpt=FNAME)
28 |     os.unlink(FNAME)
29 | 
30 | 
31 | def test_various_dtypes():
32 |     FNAME = "test_various_dtypes.pth"
33 |     torch.manual_seed(seed)
34 |     random.seed(seed)
35 | 
36 |     sdict = dict()
37 |     for dtype in zipslicer.dtype_sizes.keys():
38 |         key = ".".join(str(random.randint(0, 2**16)) for _ in range(6))
39 |         # TODO: quantized tensor support
40 |         if "q" not in str(dtype):
41 |             t = (
42 |                 torch.randn(
43 |                     random.randint(1, 16),
44 |                     random.randint(1, 16),
45 |                     random.randint(1, 16),
46 |                     dtype=torch.float32,
47 |                 )
48 |                 * 200.0
49 |             ).to(dtype)
50 | 
51 |             sdict[key] = t
52 | 
53 |     torch.save(sdict, FNAME)
54 |     __test_incremental_load(ckpt=FNAME)
55 |     os.unlink(FNAME)
56 | 
57 | 
58 | def test_nn_sdict():
59 |     FNAME = "test_nn_sdict.pth"
60 |     torch.manual_seed(seed)
61 | 
62 |     network = torch.nn.ModuleList(
63 |         [torch.nn.Linear(1000, 2000), torch.nn.Linear(2000, 2000)]
64 |     )
65 | 
66 |     sdict = network.state_dict()
67 | 
68 |     torch.save(sdict, FNAME)
69 |     __test_incremental_load(ckpt=FNAME)
70 |     os.unlink(FNAME)
71 | 
72 | 
73 | def test_nn_sdict_w_extra_state():
74 |     FNAME = "test_nn_sdict_w_extra_state.pth"
75 |     torch.manual_seed(seed)
76 | 
77 |     class CustomLinear(torch.nn.Linear):
78 |         def get_extra_state(self):
79 |             return dict(
80 |                 a=random.randint(1, 2**64 - 1), b="this is extra state", c=[1, 2, 3]
81 |             )
82 | 
83 |     network = torch.nn.ModuleList([CustomLinear(1000, 2000), CustomLinear(2000, 2000)])
84 | 
85 |     sdict = network.state_dict()
86 | 
87 |     torch.save(sdict, FNAME)
88 |     __test_incremental_load(ckpt=FNAME)
89 |     os.unlink(FNAME)
90 | 
91 | 
92 | def test_nn_pickle_raises():
93 |     FNAME = "test_nn_pickle.pth"
94 |     torch.manual_seed(seed)
95 | 
96 |     network = torch.nn.ModuleList(
97 |         [torch.nn.Linear(1000, 2000), torch.nn.Linear(2000, 2000)]
98 |     )
99 | 
100 |     torch.save(network, FNAME)
101 | 
102 |     raised = False
103 |     try:
104 |         zipslicer.load(FNAME)
105 |     except Exception as e:
106 |         raised = True
107 |         assert (
108 |             "Error at zipslicer.load bootstrap stage, your torch pickle checkpoint is likely too complex for the lightweight loader to interpret. Make sure your network was saved as a state_dict, instead of general-purpose network pickle"
109 |             in str(e)
110 |         )
111 |     assert raised, "Expected zipslicer.load to raise on a general-purpose network pickle"
112 | 
113 |     os.unlink(FNAME)
--------------------------------------------------------------------------------
/zipslicer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023- Kirill Gadjello.
2 | # See LICENSE for details (basically it uses parts of the PyTorch source code and is licensed under the same conditions)
3 | 
4 | import os
5 | from functools import reduce
6 | import types
7 | from collections import OrderedDict
8 | import pickle
9 | import zipfile
10 | import struct
11 | 
12 | import torch
13 | 
14 | from . import custom_load
15 | from . 
import weights_only_unpickler 16 | 17 | 18 | def create_storage(buf): 19 | if hasattr(torch, "UntypedStorage"): 20 | return torch.UntypedStorage.from_buffer(buf, dtype=torch.uint8) 21 | elif hasattr(torch, "_UntypedStorage"): 22 | # fallback for older torch versions which don't have UntypedStorage 23 | return torch._UntypedStorage.from_buffer(buf, dtype=torch.uint8) 24 | else: 25 | # fallback for older torch versions which don't have _UntypedStorage 26 | return torch.ByteStorage.from_buffer(buf) 27 | 28 | 29 | def get_typed_storage(backing_storage, dtype): 30 | # TODO: Upstream pytorch might change the semantics here eventually 31 | if hasattr(torch, "TypedStorage"): 32 | return torch.storage.TypedStorage(wrap_storage=backing_storage, dtype=dtype) 33 | else: 34 | legacy_storage_types = { 35 | torch.bfloat16: torch.BFloat16Storage, 36 | torch.half: torch.HalfStorage, 37 | torch.float: torch.FloatStorage, 38 | torch.double: torch.DoubleStorage, 39 | torch.int8: torch.CharStorage, 40 | torch.uint8: torch.ByteStorage, 41 | torch.short: torch.ShortStorage, 42 | torch.int: torch.IntStorage, 43 | torch.long: torch.LongStorage, 44 | torch.bool: torch.BoolStorage, 45 | } 46 | if dtype not in legacy_storage_types: 47 | raise Exception( 48 | "Failed to create Torch Storage object, your Torch version is probably too old" 49 | ) 50 | return legacy_storage_types[dtype](wrap_storage=backing_storage) 51 | 52 | 53 | # ZIP "local file header" structure, magic number, size, and indices 54 | # (section V.A in the format document) 55 | structFileHeader = "<4s2B4HL2L2H" 56 | stringFileHeader = b"PK\003\004" 57 | _FH_SIGNATURE = 0 58 | _FH_EXTRA_FIELD_LENGTH = 11 59 | _FH_FILENAME_LENGTH = 10 60 | _FH_GENERAL_PURPOSE_FLAG_BITS = 3 61 | 62 | sizeFileHeader = struct.calcsize(structFileHeader) 63 | 64 | 65 | def skip_header(zef_file, zinfo): 66 | zef_file.seek(zinfo.header_offset) 67 | fheader = zef_file.read(sizeFileHeader) 68 | if len(fheader) != sizeFileHeader: 69 | raise Exception("Truncated file header") 70 | fheader = struct.unpack(structFileHeader, fheader) 71 | if fheader[_FH_SIGNATURE] != stringFileHeader: 72 | raise Exception("Bad magic number for file header") 73 | 74 | fname = zef_file.read(fheader[_FH_FILENAME_LENGTH]) 75 | if fheader[_FH_EXTRA_FIELD_LENGTH]: 76 | zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH]) 77 | 78 | if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & 0x800: 79 | # UTF-8 filename 80 | fname_str = fname.decode("utf-8") 81 | else: 82 | fname_str = fname.decode("cp437") 83 | 84 | if fname_str != zinfo.orig_filename: 85 | raise Exception( 86 | "File name in directory %r and header %r differ." 
            % (zinfo.orig_filename, fname)
88 |         )
89 | 
90 |     return True
91 | 
92 | 
93 | dtype_sizes = {
94 |     torch.float64: 8,
95 |     torch.float32: 4,
96 |     torch.bfloat16: 2,
97 |     torch.float16: 2,
98 |     torch.int64: 8,
99 |     torch.int32: 4,
100 |     torch.int16: 2,
101 |     torch.uint8: 1,
102 |     torch.int8: 1,
103 |     torch.bool: 1,
104 |     # torch.quint8: 1, # TODO
105 |     # torch.qint8: 1,
106 | }
107 | 
108 | dtype_by_name = {
109 |     "torch.float64": torch.float64,
110 |     "torch.float32": torch.float32,
111 |     "torch.bfloat16": torch.bfloat16,
112 |     "torch.float16": torch.float16,
113 |     "torch.int64": torch.int64,
114 |     "torch.int32": torch.int32,
115 |     "torch.int16": torch.int16,
116 |     "torch.uint8": torch.uint8,
117 |     "torch.int8": torch.int8,
118 |     "torch.bool": torch.bool,
119 |     # "torch.quint8": torch.quint8,
120 |     # "torch.qint8": torch.qint8,
121 | }
122 | 
123 | 
124 | def load_tensor_partial(
125 |     zipfile,
126 |     fh,
127 |     offset_index,
128 |     dtype,
129 |     numel,
130 |     key,
131 |     location,
132 |     offset,
133 |     use_uncompressed=True,
134 | ):
135 |     name = f"data/{key}"
136 |     znames = list(filter(lambda x: x.endswith(name), zipfile.namelist()))
137 |     assert len(znames) == 1
138 |     zname = znames[0]
139 | 
140 |     dsize = dtype_sizes[dtype]
141 |     bbuffer = None
142 | 
143 |     if use_uncompressed:
144 |         try:
145 |             zero_offset = None
146 | 
147 |             # fast path
148 |             if offset_index.get(zname) is not None:
149 |                 zero_offset = offset_index.get(zname)
150 |                 data_offset = zero_offset + offset
151 |                 fh.seek(data_offset)
152 |                 bbuffer = fh.read(dsize * numel)
153 |             else:
154 |                 info = zipfile.getinfo(zname)
155 |                 is_uncompressed = (
156 |                     info.compress_size == info.file_size
157 |                 ) and info.compress_type == 0
158 | 
159 |                 if is_uncompressed:
160 |                     fh.seek(info.header_offset)
161 |                     success = skip_header(fh, info)
162 | 
163 |                     if success:
164 |                         zero_offset = fh.tell()
165 |                         offset_index[zname] = zero_offset
166 |                         data_offset = zero_offset + offset
167 |                         fh.seek(data_offset)
168 |                         bbuffer = fh.read(dsize * numel)
169 | 
170 |             assert len(bbuffer) == dsize * numel
171 |         except Exception as e:
172 |             print(f"[ZIPSLICER]: Exception during attempt at fast seek: {e}")
173 | 
174 |     # fallback uses python-native zipfile seek which becomes slow for large checkpoints
175 |     if bbuffer is None:
176 |         print("[ZIPSLICER]: fast torch storage seek failed, executing fallback")
177 |         with zipfile.open(zname, "r") as zf:
178 |             zf.seek(offset)
179 |             bbuffer = zf.read(dsize * numel)
180 | 
181 |     storage = create_storage(bbuffer)
182 |     return get_typed_storage(storage, dtype)
183 | 
184 | 
185 | # This class is meant to behave in a functionally similar manner to
186 | # the conventional PyTorch state_dict objects (instances of Python's OrderedDict).
187 | # It should be noted that, as the docs say in https://peps.python.org/pep-0372/,
188 | # "subclassing dict is a non-trivial task", and this implementation doesn't try
189 | # to be perfect. But it should work for PyTorch checkpoint access, see tests.
190 | # Note that creation of new keys is not yet supported.
191 | # Caching isn't meant to be enabled just yet, but the tcache member is used
192 | # to store the k-v pairs some apps may assign to the state_dict object.
193 | # Currently, this class is meant to be created inside the zipslicer.load method
194 | class LazyStateDict(OrderedDict):
195 |     """
196 |     A Lazy state_dict produced by zipslicer https://github.com/kir-gadjello/zipslicer
197 |     This object should be used to access PyTorch checkpoints in an incremental way.
198 |     If caching is disabled (the default), this object doesn't force the tensors and
199 |     Torch Storage objects it creates to stay in RAM - you can delete them safely.
200 |     Updating and deleting the values in-place (without saving) is supported.
201 |     Creation of new keys isn't supported just yet.
202 | 
203 |     Special methods: there is a 'get_meta' method for accessing torch tensor shapes
204 |     for the available keys without loading the whole tensor from disk.
205 |     """
206 | 
207 |     def __init__(
208 |         self,
209 |         tensors=None,
210 |         extras=None,  # TODO: we might have to handle extra data
211 |         untie_weights=False,  # TODO
212 |         map_location="cpu",
213 |         zipfile=None,
214 |         fh=None,
215 |         debug=False,
216 |         dtype=None,
217 |         cache_tensors=False,
218 |         *args,
219 |         **kwargs,
220 |     ):
221 |         super().__init__(*args, **kwargs)
222 | 
223 |         self.__lazy = True
224 |         self.tensors = tensors
225 |         self.zipfile = zipfile
226 |         self.fh = fh
227 |         self.offset_index = {}
228 |         self.map_location = map_location
229 |         self.untie_weights = untie_weights
230 |         self.cache_tensors = cache_tensors
231 |         self.debug = debug
232 |         self.dtype = dtype
233 |         self.tcache = {}
234 | 
235 |         for k in self.keys():
236 |             self.validate_tensor_ref(k)
237 | 
238 |     def __del__(self):
239 |         if self.zipfile is not None:
240 |             if hasattr(self.zipfile, "close"):
241 |                 self.zipfile.close()
242 |             del self.zipfile
243 |         if self.fh is not None:
244 |             if hasattr(self.fh, "close"):
245 |                 self.fh.close()
246 |             del self.fh
247 | 
248 |     def __len__(self):
249 |         return len(self.tensors.keys())
250 | 
251 |     def __setitem__(self, key, value):
252 |         # Not supporting adding new keys for now
253 |         if key not in self.tensors:
254 |             raise KeyError(key)
255 | 
256 |         self.tcache[key] = value
257 | 
258 |     def __delitem__(self, key):
259 |         if key not in self.tcache and key not in self.tensors:
260 |             raise KeyError(key)
261 | 
262 |         if key in self.tcache:
263 |             del self.tcache[key]
264 |         if key in self.tensors:
265 |             del self.tensors[key]
266 | 
267 |     def __getitem__(self, k):
268 |         if k in self.tcache:
269 |             return self.tcache[k]
270 |         elif k in self.tensors:
271 |             ret = self.reform_tensor(k)
272 |             if self.cache_tensors:
273 |                 self.tcache[k] = ret
274 |             return ret
275 |         else:
276 |             raise KeyError(k)
277 | 
278 |     def __contains__(self, k):
279 |         return k in self.tcache.keys() or k in self.tensors.keys()
280 | 
281 |     def keys(self):
282 |         # Fast path: cache is disabled and no keys were assigned
283 |         if len(self.tcache) == 0:
284 |             return self.tensors.keys()
285 |         else:
286 |             # A naive implementation, but it should suffice
287 |             return list(set(self.tcache.keys()) | set(self.tensors.keys()))
288 | 
289 |     def values(self):
290 |         raise Exception(
291 |             "LazyStateDict isn't meant for loading all values at once due to RAM constraints, reconsider your usage pattern"
292 |         )
293 | 
294 |     def items(self):
295 |         for k in self.keys():
296 |             yield (k, self.__getitem__(k))
297 | 
298 |     def get_meta(self, k):
299 |         if k not in self.tensors:
300 |             raise KeyError(k)
301 | 
302 |         if k.endswith("._extra_state"):
303 |             return None
304 | 
305 |         dtype = self.tensors[k]["args"][0]["dtype"]
306 |         dtype = dtype_by_name[dtype]
307 |         size = torch.Size(self.tensors[k]["args"][2])
308 | 
309 |         return types.SimpleNamespace(
310 |             shape=size,
311 |             size=lambda: size,
312 |             dtype=dtype,
313 |         )
314 | 
315 |     def validate_tensor_ref(self, k):
316 |         if k not in self.tensors:
317 |             raise KeyError(k)
318 | 
319 |         # Allow free-form ._extra_state for now: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_extra_state
320 |         if k.endswith("._extra_state"):
321 |             assert isinstance(self.tensors[k], dict)
322 |             return True
323 | 
324 |         ref = self.tensors[k]
325 |         assert ref["type"] == "stub_obj"
326 |         assert ref["fn"] == "_rebuild_tensor_v2"
327 |         storage_args = ref["args"][0]
328 | 
329 |         dtype = None
330 |         try:
331 |             dtype = dtype_by_name[storage_args["dtype"]]
332 |         except Exception as e:
333 |             print("Couldn't load tensor:", e)
334 |             return None
335 | 
336 |         assert isinstance(dtype, torch.dtype)
337 | 
338 |         return True
339 | 
340 |     def reform_tensor(self, k):
341 |         if k not in self.tensors:
342 |             raise KeyError(k)
343 | 
344 |         if k.endswith("._extra_state"):
345 |             return self.tensors[k]
346 | 
347 |         ref = self.tensors[k]
348 |         storage_args = ref["args"][0]
349 |         rebuild_tensor_args = ref["args"][1:]
350 |         dtype = dtype_by_name[storage_args["dtype"]]
351 | 
352 |         (
353 |             storage_offset,
354 |             size,
355 |             stride,
356 |             requires_grad,
357 |             backward_hooks,
358 |         ) = rebuild_tensor_args
359 | 
360 |         assert dtype in dtype_sizes
361 | 
362 |         # TODO: stride correctness checks
363 | 
364 |         dsize = dtype_sizes[dtype]
365 |         numel = (
366 |             reduce(lambda x, y: x * y, size) if size is not None and len(size) else 1
367 |         )
368 | 
369 |         storage = load_tensor_partial(
370 |             self.zipfile,
371 |             self.fh,
372 |             self.offset_index,
373 |             dtype=dtype,
374 |             numel=numel,
375 |             key=storage_args["key"],
376 |             location=self.map_location,
377 |             offset=storage_offset * dsize,
378 |         )
379 | 
380 |         ret = torch._utils._rebuild_tensor_v2(
381 |             storage, 0, size, stride, requires_grad, backward_hooks
382 |         )
383 | 
384 |         if self.dtype is not None and ret.dtype != self.dtype:
385 |             ret = ret.to(self.dtype)
386 | 
387 |         return ret
388 | 
389 | 
390 | def load(
391 |     ckpt,
392 |     map_location="cpu",
393 |     dtype=None,
394 |     cache_tensors=False,
395 |     debug=False,
396 | ):
397 |     """
398 |     Should behave similarly to torch.load, but operates incrementally,
399 |     loading the accessed tensors on the fly.
400 |     ckpt should be a valid file path for now;
401 |     cache_tensors should be False.
402 |     """
403 |     assert map_location == "cpu"
404 |     assert os.path.isfile(ckpt)
405 | 
406 |     tensors_meta = None
407 |     with custom_load._open_zipfile_reader(open(ckpt, "rb")) as zf:
408 |         try:
409 |             tensors_meta = custom_load._custom_load(
410 |                 zf, torch.device(map_location), weights_only_unpickler
411 |             )
412 |         except Exception as e:
413 |             raise pickle.UnpicklingError(
414 |                 f"Error at zipslicer.load bootstrap stage, your torch pickle checkpoint is likely too complex for the lightweight loader to interpret. Make sure your network was saved as a state_dict, instead of general-purpose network pickle. 
Exception was: {e}" 415 | ) 416 | 417 | zipfile_h = zipfile.ZipFile(ckpt, "r", allowZip64=True) 418 | 419 | return LazyStateDict( 420 | tensors=tensors_meta, 421 | fh=open(ckpt, "rb"), 422 | zipfile=zipfile_h, 423 | map_location=map_location, 424 | debug=debug, 425 | dtype=dtype, 426 | cache_tensors=cache_tensors, 427 | ) 428 | -------------------------------------------------------------------------------- /zipslicer/custom_load.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023- Kirill Gadjello with most input coming from Pytorch contributors 2 | # See LICENSE for details (basically it uses parts of PyTorch sourcecode and is licensed under the same conditions) 3 | 4 | import difflib 5 | import os 6 | 7 | import io 8 | import shutil 9 | import struct 10 | import sys 11 | import torch 12 | import tarfile 13 | import pathlib 14 | 15 | import tempfile 16 | import warnings 17 | 18 | from contextlib import closing, contextmanager 19 | 20 | from torch._sources import get_source_lines_and_file 21 | from torch.types import Storage 22 | from torch.storage import _get_dtype_from_pickle_storage_type 23 | 24 | from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Union, IO, Type 25 | from typing_extensions import TypeAlias 26 | 27 | TORCH_UNTYPED_STORAGE_CLS = ( 28 | torch.UntypedStorage if hasattr(torch, "UntypedStorage") else torch.ByteStorage 29 | ) 30 | 31 | DEFAULT_PROTOCOL = 2 32 | 33 | LONG_SIZE = struct.Struct("=l").size 34 | INT_SIZE = struct.Struct("=i").size 35 | SHORT_SIZE = struct.Struct("=h").size 36 | 37 | MAGIC_NUMBER = 0x1950A86A20F9469CFC6C 38 | PROTOCOL_VERSION = 1001 39 | STORAGE_KEY_SEPARATOR = "," 40 | 41 | FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]] 42 | MAP_LOCATION: TypeAlias = Optional[ 43 | Union[ 44 | Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str] 45 | ] 46 | ] 47 | 48 | 49 | _string_classes = (str, bytes) 50 | 51 | 52 | def default_restore_location(storage, location): 53 | for _, _, fn in _package_registry: 54 | result = fn(storage, location) 55 | if result is not None: 56 | return result 57 | raise RuntimeError( 58 | "don't know how to restore data location of " 59 | + torch.typename(storage) 60 | + " (tagged with " 61 | + location 62 | + ")" 63 | ) 64 | 65 | 66 | def _check_seekable(f) -> bool: 67 | def raise_err_msg(patterns, e): 68 | for p in patterns: 69 | if p in str(e): 70 | msg = ( 71 | str(e) 72 | + ". You can only torch.load from a file that is seekable." 73 | + " Please pre-load the data into a buffer like io.BytesIO and" 74 | + " try to load from it instead." 
75 | ) 76 | raise type(e)(msg) 77 | raise e 78 | 79 | try: 80 | f.seek(f.tell()) 81 | return True 82 | except (io.UnsupportedOperation, AttributeError) as e: 83 | raise_err_msg(["seek", "tell"], e) 84 | return False 85 | 86 | 87 | class SourceChangeWarning(Warning): 88 | pass 89 | 90 | 91 | @contextmanager 92 | def mkdtemp(): 93 | path = tempfile.mkdtemp() 94 | yield path 95 | shutil.rmtree(path) 96 | 97 | 98 | _package_registry = [] 99 | 100 | 101 | def _import_dotted_name(name): 102 | components = name.split(".") 103 | obj = __import__(components[0]) 104 | for component in components[1:]: 105 | obj = getattr(obj, component) 106 | return obj 107 | 108 | 109 | def normalize_storage_type(storage_type): 110 | return getattr(torch, storage_type.__name__) 111 | 112 | 113 | def storage_to_tensor_type(storage): 114 | storage_type = type(storage) 115 | module = _import_dotted_name(storage_type.__module__) 116 | return getattr(module, storage_type.__name__.replace("Storage", "Tensor")) 117 | 118 | 119 | def _is_path(name_or_buffer): 120 | return isinstance(name_or_buffer, str) or isinstance(name_or_buffer, pathlib.Path) 121 | 122 | 123 | def _is_zipfile(f) -> bool: 124 | # This is a stricter implementation than zipfile.is_zipfile(). 125 | # zipfile.is_zipfile() is True if the magic number appears anywhere in the 126 | # binary. Since we expect the files here to be generated by torch.save or 127 | # torch.jit.save, it's safe to only check the start bytes and avoid 128 | # collisions and assume the zip has only 1 file. 129 | # See bugs.python.org/issue28494. 130 | 131 | # Read the first 4 bytes of the file 132 | read_bytes = [] 133 | start = f.tell() 134 | 135 | byte = f.read(1) 136 | while byte != b"": 137 | read_bytes.append(byte) 138 | if len(read_bytes) == 4: 139 | break 140 | byte = f.read(1) 141 | f.seek(start) 142 | 143 | local_header_magic_number = [b"P", b"K", b"\x03", b"\x04"] 144 | return read_bytes == local_header_magic_number 145 | 146 | 147 | class _opener(object): 148 | def __init__(self, file_like): 149 | self.file_like = file_like 150 | 151 | def __enter__(self): 152 | return self.file_like 153 | 154 | def __exit__(self, *args): 155 | pass 156 | 157 | 158 | class _open_file(_opener): 159 | def __init__(self, name, mode): 160 | super(_open_file, self).__init__(open(name, mode)) 161 | 162 | def __exit__(self, *args): 163 | self.file_like.close() 164 | 165 | 166 | class _open_buffer_reader(_opener): 167 | def __init__(self, buffer): 168 | super(_open_buffer_reader, self).__init__(buffer) 169 | _check_seekable(buffer) 170 | 171 | 172 | class _open_buffer_writer(_opener): 173 | def __exit__(self, *args): 174 | self.file_like.flush() 175 | 176 | 177 | def _open_file_like(name_or_buffer, mode): 178 | if _is_path(name_or_buffer): 179 | return _open_file(name_or_buffer, mode) 180 | else: 181 | if "w" in mode: 182 | return _open_buffer_writer(name_or_buffer) 183 | elif "r" in mode: 184 | return _open_buffer_reader(name_or_buffer) 185 | else: 186 | raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}") 187 | 188 | 189 | class _open_zipfile_reader(_opener): 190 | def __init__(self, name_or_buffer) -> None: 191 | super(_open_zipfile_reader, self).__init__( 192 | torch._C.PyTorchFileReader(name_or_buffer) 193 | ) 194 | 195 | 196 | class _open_zipfile_writer_file(_opener): 197 | def __init__(self, name) -> None: 198 | super(_open_zipfile_writer_file, self).__init__( 199 | torch._C.PyTorchFileWriter(str(name)) 200 | ) 201 | 202 | def __exit__(self, *args) -> None: 203 | 
self.file_like.write_end_of_file()
204 | 
205 | 
206 | class _open_zipfile_writer_buffer(_opener):
207 |     def __init__(self, buffer) -> None:
208 |         self.buffer = buffer
209 |         super(_open_zipfile_writer_buffer, self).__init__(
210 |             torch._C.PyTorchFileWriter(buffer)
211 |         )
212 | 
213 |     def __exit__(self, *args) -> None:
214 |         self.file_like.write_end_of_file()
215 |         self.buffer.flush()
216 | 
217 | 
218 | def _open_zipfile_writer(name_or_buffer):
219 |     container: Type[_opener]
220 |     if _is_path(name_or_buffer):
221 |         container = _open_zipfile_writer_file
222 |     else:
223 |         container = _open_zipfile_writer_buffer
224 |     return container(name_or_buffer)
225 | 
226 | 
227 | def _is_compressed_file(f) -> bool:
228 |     compress_modules = ["gzip"]
229 |     try:
230 |         return f.__module__ in compress_modules
231 |     except AttributeError:
232 |         return False
233 | 
234 | 
235 | def _should_read_directly(f):
236 |     """
237 |     Checks if f is a file that should be read directly. It should be read
238 |     directly if it is backed by a real file (has a fileno) and is not a
239 |     compressed file (e.g. gzip)
240 |     """
241 |     if _is_compressed_file(f):
242 |         return False
243 |     try:
244 |         return f.fileno() >= 0
245 |     except io.UnsupportedOperation:
246 |         return False
247 |     except AttributeError:
248 |         return False
249 | 
250 | 
251 | def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
252 |     # When using encoding='bytes' in Py3, some **internal** keys stored as
253 |     # strings in Py2 are loaded as bytes. This function decodes them with
254 |     # ascii encoding, one that Py3 uses by default.
255 |     #
256 |     # NOTE: This should only be used on internal keys (e.g., `typename` and
257 |     # `location` in `persistent_load` below)!
258 |     if isinstance(bytes_str, bytes):
259 |         return bytes_str.decode("ascii")
260 |     return bytes_str
261 | 
262 | 
263 | def _get_restore_location(map_location):
264 |     if map_location is None:
265 |         restore_location = default_restore_location
266 |     elif isinstance(map_location, dict):
267 | 
268 |         def restore_location(storage, location):
269 |             location = map_location.get(location, location)
270 |             return default_restore_location(storage, location)
271 | 
272 |     elif isinstance(map_location, _string_classes):
273 | 
274 |         def restore_location(storage, location):
275 |             return default_restore_location(storage, map_location)
276 | 
277 |     elif isinstance(map_location, torch.device):
278 | 
279 |         def restore_location(storage, location):
280 |             return default_restore_location(storage, str(map_location))
281 | 
282 |     else:
283 | 
284 |         def restore_location(storage, location):
285 |             result = map_location(storage, location)
286 |             if result is None:
287 |                 result = default_restore_location(storage, location)
288 |             return result
289 | 
290 |     return restore_location
291 | 
292 | 
293 | class StorageType:
294 |     def __init__(self, name):
295 |         self.dtype = _get_dtype_from_pickle_storage_type(name)
296 | 
297 |     def __str__(self):
298 |         return f"StorageType(dtype={self.dtype})"
299 | 
300 | 
301 | def _load(
302 |     zip_file, map_location, pickle_module, pickle_file="data.pkl", **pickle_load_args
303 | ):
304 |     restore_location = _get_restore_location(map_location)
305 | 
306 |     loaded_storages = {}
307 | 
308 |     def load_tensor(dtype, numel, key, location):
309 |         name = f"data/{key}"
310 | 
311 |         storage = (
312 |             zip_file.get_storage_from_record(name, numel, TORCH_UNTYPED_STORAGE_CLS)
313 |             .storage()
314 |             .untyped()
315 |         )
316 |         # TODO: Once we decide to break serialization FC, we can
317 |         # stop wrapping with TypedStorage
318 | 
loaded_storages[key] = torch.storage.TypedStorage( 319 | wrap_storage=restore_location(storage, location), dtype=dtype 320 | ) 321 | 322 | def persistent_load(saved_id): 323 | assert isinstance(saved_id, tuple) 324 | typename = _maybe_decode_ascii(saved_id[0]) 325 | data = saved_id[1:] 326 | 327 | assert ( 328 | typename == "storage" 329 | ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" 330 | storage_type, key, location, numel = data 331 | if storage_type is TORCH_UNTYPED_STORAGE_CLS: 332 | dtype = torch.uint8 333 | else: 334 | dtype = storage_type.dtype 335 | 336 | if key not in loaded_storages: 337 | nbytes = numel * torch._utils._element_size(dtype) 338 | load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) 339 | 340 | return loaded_storages[key] 341 | 342 | load_module_mapping: Dict[str, str] = { 343 | # See https://github.com/pytorch/pytorch/pull/51633 344 | "torch.tensor": "torch._tensor" 345 | } 346 | 347 | # Need to subclass Unpickler instead of directly monkey-patching the find_class method 348 | # because it's marked readonly in pickle. 349 | # The type: ignore is because mypy can't statically determine the type of this class. 350 | class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] 351 | # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732 352 | # Lets us override the imports that pickle uses when unpickling an object. 353 | # This is useful for maintaining BC if we change a module path that tensor instantiation relies on. 354 | def find_class(self, mod_name, name): 355 | if type(name) is str and "Storage" in name: 356 | try: 357 | return StorageType(name) 358 | except KeyError: 359 | pass 360 | mod_name = load_module_mapping.get(mod_name, mod_name) 361 | return super().find_class(mod_name, name) 362 | 363 | # Load the data (which may in turn use `persistent_load` to load tensors) 364 | data_file = io.BytesIO(zip_file.get_record(pickle_file)) 365 | 366 | unpickler = UnpicklerWrapper(data_file, **pickle_load_args) 367 | unpickler.persistent_load = persistent_load 368 | result = unpickler.load() 369 | 370 | torch._utils._validate_loaded_sparse_tensors() 371 | 372 | return result 373 | 374 | 375 | def _legacy_load(f, map_location, pickle_module, **pickle_load_args): 376 | deserialized_objects: Dict[int, Any] = {} 377 | 378 | restore_location = _get_restore_location(map_location) 379 | 380 | class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] 381 | def find_class(self, mod_name, name): 382 | if type(name) is str and "Storage" in name: 383 | try: 384 | return StorageType(name) 385 | except KeyError: 386 | pass 387 | return super().find_class(mod_name, name) 388 | 389 | def _check_container_source(container_type, source_file, original_source): 390 | try: 391 | current_source = "".join(get_source_lines_and_file(container_type)[0]) 392 | except Exception: # saving the source is optional, so we can ignore any errors 393 | warnings.warn( 394 | "Couldn't retrieve source code for container of " 395 | "type " + container_type.__name__ + ". It won't be checked " 396 | "for correctness upon loading." 
397 | ) 398 | return 399 | if original_source != current_source: 400 | if container_type.dump_patches: 401 | file_name = container_type.__name__ + ".patch" 402 | diff = difflib.unified_diff( 403 | current_source.split("\n"), 404 | original_source.split("\n"), 405 | source_file, 406 | source_file, 407 | lineterm="", 408 | ) 409 | lines = "\n".join(diff) 410 | try: 411 | with open(file_name, "a+") as f: 412 | file_size = f.seek(0, 2) 413 | f.seek(0) 414 | if file_size == 0: 415 | f.write(lines) 416 | elif file_size != len(lines) or f.read() != lines: 417 | raise IOError 418 | msg = ( 419 | "Saved a reverse patch to " + file_name + ". " 420 | "Run `patch -p0 < " + file_name + "` to revert your " 421 | "changes." 422 | ) 423 | except IOError: 424 | msg = ( 425 | "Tried to save a patch, but couldn't create a " 426 | "writable file " + file_name + ". Make sure it " 427 | "doesn't exist and your working directory is " 428 | "writable." 429 | ) 430 | else: 431 | msg = ( 432 | "you can retrieve the original source code by " 433 | "accessing the object's source attribute or set " 434 | "`torch.nn.Module.dump_patches = True` and use the " 435 | "patch tool to revert the changes." 436 | ) 437 | msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}" 438 | warnings.warn(msg, SourceChangeWarning) 439 | 440 | def legacy_load(f): 441 | deserialized_objects: Dict[int, Any] = {} 442 | 443 | def persistent_load(saved_id): 444 | if isinstance(saved_id, tuple): 445 | # Ignore containers that don't have any sources saved 446 | if all(saved_id[1:]): 447 | _check_container_source(*saved_id) 448 | return saved_id[0] 449 | return deserialized_objects[int(saved_id)] 450 | 451 | with closing( 452 | tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT) 453 | ) as tar, mkdtemp() as tmpdir: 454 | 455 | tar.extract("storages", path=tmpdir) 456 | with open(os.path.join(tmpdir, "storages"), "rb", 0) as f: 457 | num_storages = pickle_module.load(f, **pickle_load_args) 458 | for i in range(num_storages): 459 | args = pickle_module.load(f, **pickle_load_args) 460 | key, location, storage_type = args 461 | dtype = storage_type.dtype 462 | obj = cast(Storage, TORCH_UNTYPED_STORAGE_CLS)._new_with_file( 463 | f, torch._utils._element_size(dtype) 464 | ) 465 | obj = restore_location(obj, location) 466 | # TODO: Once we decide to break serialization FC, we can 467 | # stop wrapping with TypedStorage 468 | deserialized_objects[key] = torch.storage.TypedStorage( 469 | wrap_storage=obj, dtype=dtype 470 | ) 471 | 472 | storage_views = pickle_module.load(f, **pickle_load_args) 473 | for target_cdata, root_cdata, offset, numel in storage_views: 474 | root = deserialized_objects[root_cdata] 475 | element_size = torch._utils._element_size(root.dtype) 476 | offset_bytes = offset * element_size 477 | # TODO: Once we decide to break serialization FC, we can 478 | # stop wrapping with TypedStorage 479 | deserialized_objects[target_cdata] = torch.storage.TypedStorage( 480 | wrap_storage=root._storage[ 481 | offset_bytes : offset_bytes + numel * element_size 482 | ], 483 | dtype=root.dtype, 484 | ) 485 | 486 | tar.extract("tensors", path=tmpdir) 487 | with open(os.path.join(tmpdir, "tensors"), "rb", 0) as f: 488 | num_tensors = pickle_module.load(f, **pickle_load_args) 489 | for _ in range(num_tensors): 490 | args = pickle_module.load(f, **pickle_load_args) 491 | key, storage_id, original_tensor_type = args 492 | storage = deserialized_objects[storage_id] 493 | (ndim,) = struct.unpack(" maxsize: 268 | raise 
RuntimeError("String is too long")
269 |                 strval = str(read(strlen), "utf-8", "surrogatepass")
270 |                 self.append(strval)
271 |             elif key[0] == SHORT_BINSTRING[0]:
272 |                 strlen = read(1)[0]
273 |                 strdata = read(strlen)
274 |                 if self.encoding != "bytes":
275 |                     strdata = strdata.decode(self.encoding, "strict")
276 |                 self.append(strdata)
277 |             elif key[0] == BINPERSID[0]:
278 |                 pid = self.stack.pop()
279 |                 # Only allow persistent load of storage
280 |                 if type(pid) is not tuple and type(pid) is not int:
281 |                     raise RuntimeError(
282 |                         f"persistent_load id must be tuple or int, but got {type(pid)}"
283 |                     )
284 |                 if (
285 |                     type(pid) is tuple
286 |                     and len(pid) > 0
287 |                     and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
288 |                 ):
289 |                     raise RuntimeError(
290 |                         f"Only persistent_load of storage is allowed, but got {pid[0]}"
291 |                     )
292 |                 self.append(self.persistent_load(pid))
293 |             elif key[0] in [BINGET[0], LONG_BINGET[0]]:
294 |                 idx = (read(1) if key[0] == BINGET[0] else unpack("