├── .github └── workflows │ ├── benchmark.yml │ ├── build.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── ATTIC ├── README.md ├── test_protobuf.py └── test_tensorizer.py ├── CHANGELOG.md ├── CMakeLists.txt ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── encryption.md └── subprocess-serialization.md ├── examples ├── __init__.py ├── benchmark │ ├── .images │ │ ├── gpt-j.png │ │ └── opt-30b.png │ ├── Dockerfile │ ├── GraphResults.ipynb │ ├── README.md │ ├── benchmark-job.yaml │ ├── benchmark-pvc.yaml │ ├── deserialize_benchmark.py │ ├── jupyter-lab-service.yaml │ ├── save-models-job.yaml │ └── save_models.py ├── benchmark_buffer_size │ ├── Dockerfile │ ├── README.md │ ├── benchmark.py │ ├── benchmark.yaml │ ├── lighttpd.conf │ ├── redis-server.yaml │ └── visualizations │ │ ├── requirements.txt │ │ └── visualizations.py ├── deserialize-simple.py ├── deserialize.py ├── encrypt_existing.py ├── encryption.py ├── hf_serialization.py ├── requirements.txt └── serialize.py ├── package-lock.json ├── package.json ├── proto ├── requirements.txt └── tensors.proto ├── pyproject.toml ├── requirements.txt ├── tensorizer ├── _NumpyTensor.py ├── __init__.py ├── _crypt │ ├── __init__.py │ ├── __main__.py │ ├── _cgroup_cpu_count.py │ ├── _encryption.py │ └── _exceptions.py ├── _crypt_info.py ├── _internal_utils.py ├── _linear_partition.py ├── _syscalls.py ├── _tensor_path.py ├── _version.py ├── _wide_pipes.py ├── protobuf.py ├── serialization.py ├── stream_io.py ├── tensors.proto ├── tensors_pb2.py └── utils.py ├── tensors ├── LICENSE ├── __init__.py ├── go.mod ├── go.sum ├── tensors.pb.go ├── tensors_pb.d.ts ├── tensors_pb.js └── tensors_pb2.py └── tests ├── __init__.py ├── requirements.txt ├── test_serialization.py ├── test_stream_io.py └── test_syscalls.py /.github/workflows/benchmark.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | 3 | env: 4 | REGISTRY: ghcr.io 5 | IMAGE_NAME: ${{ github.repository }} 6 | 7 | 8 | jobs: 9 | build: 10 | name: Build & push docker image 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: 'Login to GitHub Container Registry' 16 | uses: docker/login-action@v1 17 | with: 18 | registry: ${{env.REGISTRY}} 19 | username: ${{github.actor}} 20 | password: ${{secrets.GITHUB_TOKEN}} 21 | 22 | - name: Extract metadata (tags, labels) for Docker 23 | id: meta 24 | uses: docker/metadata-action@v5.0.0 25 | with: 26 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 27 | tags: 28 | type=sha,enable=true,priority=100,prefix=benchmark-,suffix=,format=short 29 | 30 | - name: Build and push Docker image 31 | uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4 32 | with: 33 | context: . 
34 | file: examples/benchmark_buffer_size/Dockerfile 35 | push: true 36 | tags: ${{ steps.meta.outputs.tags }} 37 | labels: ${{ steps.meta.outputs.labels }} 38 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build_dist: 9 | name: Build distribution on Ubuntu 10 | runs-on: ubuntu-22.04 11 | steps: 12 | - uses: actions/checkout@v4 13 | 14 | - uses: actions/setup-python@v5 15 | with: 16 | python-version: '3.10' 17 | 18 | - name: Install build dependencies 19 | run: python -m pip install --no-cache-dir -U setuptools build 20 | 21 | - name: Build distribution 22 | run: python -m build 23 | 24 | - uses: actions/upload-artifact@v3 25 | with: 26 | path: ./dist/* 27 | 28 | upload_dist: 29 | needs: [ build_dist ] 30 | name: Upload distribution to PyPI 31 | runs-on: ubuntu-22.04 32 | if: github.event_name == 'release' && github.event.action == 'published' 33 | steps: 34 | - uses: actions/download-artifact@v3 35 | with: 36 | name: artifact 37 | path: dist 38 | 39 | - uses: pypa/gh-action-pypi-publish@release/v1.8.1 40 | with: 41 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | run_tests: 7 | name: Run tests 8 | runs-on: ubuntu-22.04 9 | steps: 10 | - uses: actions/checkout@v4 11 | 12 | - uses: actions/setup-python@v5 13 | with: 14 | python-version: '3.10' 15 | cache: 'pip' 16 | 17 | - name: Install dependencies 18 | run: > 19 | python -m pip install -e . && 20 | python -m pip install -r tests/requirements.txt 21 | 22 | - name: Install Redis 23 | run: sudo apt-get install -y redis-server 24 | 25 | - name: Install libsodium 26 | run: sudo apt-get install -y libsodium23 27 | 28 | - name: Run tests 29 | run: python -m unittest discover tests/ --verbose 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # CMake ignores 2 | CMakeLists.txt.user 3 | CMakeCache.txt 4 | CMakeFiles 5 | CMakeScripts 6 | Testing 7 | Makefile 8 | cmake_install.cmake 9 | install_manifest.txt 10 | compile_commands.json 11 | CTestTestfile.cmake 12 | _deps 13 | 14 | # Project specific 15 | node_modules 16 | pybuild 17 | 18 | # Python 19 | *.pyc 20 | *.egg-info 21 | build/ 22 | dist/ 23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: ^.*_pb2\.py|ATTIC/.*$ 2 | repos: 3 | - repo: https://github.com/psf/black 4 | rev: 24.3.0 5 | hooks: 6 | - id: black 7 | language_version: python3.10 8 | - repo: https://github.com/pycqa/isort 9 | rev: 5.13.2 10 | hooks: 11 | - id: isort 12 | args: ["--filter-files"] 13 | -------------------------------------------------------------------------------- /ATTIC/README.md: -------------------------------------------------------------------------------- 1 | # Attic (Unused Code) 2 | 3 | This directory contains tests and examples for older versions of `tensorizer` 4 | that have not yet been brought up-to-date. 
5 | 6 | They do not work with the current version of `tensorizer`, 7 | but when they do work again, they should be moved out of the attic. 8 | -------------------------------------------------------------------------------- /ATTIC/test_protobuf.py: -------------------------------------------------------------------------------- 1 | import io 2 | import serializer 3 | import unittest 4 | import torch 5 | from torch import Tensor 6 | 7 | 8 | class TestSerializer(unittest.TestCase): 9 | def test_serializer_int(self): 10 | t = Tensor([1, 2, 3]) 11 | t_serialized = serializer.serialize_tensor(t) 12 | t_deserialized = serializer.deserialize_tensor(t_serialized) 13 | self.assertTrue(torch.equal(t, t_deserialized)) 14 | 15 | def test_serializer_fp16(self): 16 | t = Tensor([1.0, 2.0, 3.0]).half() 17 | t_serialized = serializer.serialize_tensor(t) 18 | t_deserialized = serializer.deserialize_tensor(t_serialized) 19 | self.assertTrue(torch.equal(t, t_deserialized)) 20 | 21 | def test_serializer_multi_dim(self): 22 | t = Tensor([[1, 2, 3], [4, 5, 6]]) 23 | t_serialized = serializer.serialize_tensor(t) 24 | t_deserialized = serializer.deserialize_tensor(t_serialized) 25 | self.assertTrue(torch.equal(t, t_deserialized)) 26 | 27 | def test_serializer_model(self): 28 | class TestModel(torch.nn.Module): 29 | def __init__(self): 30 | super().__init__() 31 | self.weight = torch.nn.Parameter(torch.rand(2, 3)) 32 | 33 | model = TestModel() 34 | model2 = TestModel() 35 | f = io.BytesIO() 36 | serializer.serialize_model(model, f) 37 | f.seek(0) 38 | serializer.deserialize_model(model2, f) 39 | self.assertTrue(torch.equal(model.weight, model2.weight)) 40 | -------------------------------------------------------------------------------- /ATTIC/test_tensorizer.py: -------------------------------------------------------------------------------- 1 | from tensorizer import tensorizer 2 | 3 | import unittest 4 | import torch 5 | 6 | class TestTensorizer(unittest.TestCase): 7 | def test_tensorizer_gptj(self): 8 | from transformers import GPTJConfig, GPTJForCausalLM 9 | # instantiate dummy model config 10 | config = GPTJConfig( 11 | n_positions=128, 12 | n_embd=16, 13 | n_layer=2, 14 | n_head=2 15 | ) 16 | model = GPTJForCausalLM(config=config) 17 | tensorizer.serialize_model(model, config, 'test') 18 | model2 = tensorizer.load_model('test', GPTJForCausalLM, GPTJConfig, None, 'float16') 19 | # compare models 20 | for name, param in model.named_parameters(): 21 | param2 = model2.state_dict()[name] 22 | self.assertTrue(torch.allclose(param, param2, atol=1e-3)) 23 | 24 | def test_tensorizer_vae(self): 25 | from diffusers import AutoencoderKL 26 | 27 | # instantiate dummy VAE 28 | config = { 29 | "in_channels": 1, 30 | "out_channels": 1, 31 | "block_out_channels": (64,), 32 | "latent_channels": 2, 33 | "norm_num_groups": 2, 34 | "sample_size": 2, 35 | } 36 | 37 | model = AutoencoderKL(**config) 38 | tensorizer.serialize_model(model, config, 'test') 39 | model2 = tensorizer.load_model('test', AutoencoderKL, None, None, 'float16') 40 | # compare models 41 | for name, param in model.named_parameters(): 42 | param2 = model2.state_dict()[name] 43 | self.assertTrue(torch.allclose(param, param2, atol=1e-3)) 44 | 45 | def tearDown(self): 46 | import shutil 47 | shutil.rmtree('test') 48 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | project(interfaces) 3 | 4 
| # Automatically determine our project namespace. 5 | find_package(Git) 6 | execute_process( 7 | COMMAND ${GIT_EXECUTABLE} config --get remote.origin.url 8 | OUTPUT_VARIABLE REMOTE_ORIGIN 9 | OUTPUT_STRIP_TRAILING_WHITESPACE) 10 | 11 | string(REPLACE "git@github.com:" "" REPO_PATH "${REMOTE_ORIGIN}") 12 | string(REPLACE "https://github.com/" "" REPO_PATH "${REPO_PATH}") 13 | 14 | string(REPLACE ".git" "" PROJECT_REF "${REPO_PATH}") 15 | string(TOLOWER ${PROJECT_REF} PROJECT_REF) 16 | 17 | 18 | find_package(PkgConfig REQUIRED) 19 | find_package(Protobuf REQUIRED) 20 | set(PROTO_PATH "${PROJECT_SOURCE_DIR}/proto") 21 | set(GENERATED_PROTOBUF_PATH "${PROJECT_SOURCE_DIR}/tensors") 22 | file(MAKE_DIRECTORY ${GENERATED_PROTOBUF_PATH}) 23 | 24 | ## Python target support 25 | find_package(Python3 REQUIRED COMPONENTS Interpreter) 26 | 27 | set(PYBUILD_PATH "${PROJECT_BINARY_DIR}/pybuild") 28 | execute_process(COMMAND python3 -m venv ${PYBUILD_PATH} 29 | RESULT_VARIABLE EXIT_CODE 30 | OUTPUT_QUIET) 31 | if (NOT ${EXIT_CODE} EQUAL 0) 32 | message(FATAL_ERROR 33 | "Could not create python3 env at ${PYBUILD_PATH}") 34 | endif() 35 | 36 | execute_process(COMMAND ${PYBUILD_PATH}/bin/pip3 show grpcio-tools grpcio protobuf 37 | RESULT_VARIABLE EXIT_CODE 38 | OUTPUT_QUIET) 39 | if (NOT ${EXIT_CODE} EQUAL 0) 40 | execute_process(COMMAND ${PYBUILD_PATH}/bin/pip3 install -r ${PROJECT_SOURCE_DIR}/proto/requirements.txt 41 | RESULT_VARIABLE EXIT_CODE) 42 | if (NOT ${EXIT_CODE} EQUAL 0) 43 | message(FATAL_ERROR 44 | "Could not install python3 requirements at ${PYBUILD_PATH}") 45 | endif() 46 | endif() 47 | 48 | set(python_exec "${PYBUILD_PATH}/bin/python3") 49 | set(python_args "-m" "grpc_tools.protoc") 50 | set(python_plugin "") 51 | set(python_output "--python_out=") 52 | set(python_output_dir "${PROJECT_SOURCE_DIR}") 53 | file(MAKE_DIRECTORY "${python_output_dir}") 54 | file(WRITE "${PROJECT_SOURCE_DIR}/tensors/__init__.py") 55 | set(python_exts "_pb2.py") 56 | 57 | ## Golang target support 58 | execute_Process(COMMAND go version 59 | RESULT_VARIABLE EXIT_CODE) 60 | if (NOT ${EXIT_CODE} EQUAL 0) 61 | message(FATAL_ERROR 62 | "You need to have a `golang` environment installed with an appropriately set GOROOT.") 63 | endif() 64 | 65 | execute_process(COMMAND go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26 66 | RESULT_VARIABLE EXIT_CODE) 67 | if (NOT ${EXIT_CODE} EQUAL 0) 68 | message(FATAL_ERROR 69 | "Error ensuring that `protoc-gen-go` is installed.") 70 | endif() 71 | 72 | set(golang_plugin "") 73 | set(golang_output "--go_out=paths=source_relative:") 74 | set(golang_output_dir "${PROJECT_SOURCE_DIR}") 75 | file(MAKE_DIRECTORY "${golang_output_dir}/tensors") 76 | set(golang_exts ".go") 77 | 78 | # Javascript / Typescript target support 79 | execute_process(COMMAND npm --version 80 | RESULT_VARIABLE EXIT_CODE OUTPUT_QUIET) 81 | if (NOT ${EXIT_CODE} EQUAL 0) 82 | message(FATAL_ERROR 83 | "npm is not installed. 
Please ensure that it is installed by using your favorite package manager.") 84 | endif() 85 | 86 | execute_process(COMMAND npm install 87 | RESULT_VARIABLE EXIT_CODE) 88 | if (NOT ${EXIT_CODE} EQUAL 0) 89 | message(FATAL_ERROR 90 | "npm install failed!") 91 | endif() 92 | set(NODE_BIN_DIRECTORY "${PROJECT_SOURCE_DIR}/node_modules/.bin") 93 | 94 | set(javascript_exec "protoc") 95 | set(javascript_plugin "--plugin=protoc-gen-ts=${NODE_BIN_DIRECTORY}/protoc-gen-ts") 96 | set(javascript_args "") 97 | set(javascript_output "--js_out=import_style=commonjs,binary:") 98 | set(javascript_output_dir "${PROJECT_SOURCE_DIR}") 99 | file(MAKE_DIRECTORY "${javascript_output_dir}") 100 | set(javascript_exts "_pb.js") 101 | 102 | set(typescript_exec "protoc") 103 | set(typescript_plugin "--plugin=protoc-gen-ts=${NODE_BIN_DIRECTORY}/protoc-gen-ts") 104 | set(typescript_args "") 105 | set(typescript_output "--ts_out=") 106 | set(typescript_output_dir "${PROJECT_SOURCE_DIR}") 107 | file(MAKE_DIRECTORY "${typescript_output_dir}") 108 | set(typescript_exts "_pb.d.ts") 109 | 110 | ## Protobuf and GRPC stub building macros 111 | macro (_add_pb_file _src TYP VAR) 112 | message("Will generate stub ${VAR} for ${_src}") 113 | list(APPEND SRC_${VAR} ${_src}) 114 | endmacro() 115 | 116 | macro (add_protobufs) 117 | foreach (_src ${ARGN}) 118 | _add_pb_file(${_src} PROTO Protobufs) 119 | endforeach() 120 | endmacro() 121 | 122 | macro(_generate_interface LANG INTERFACE_FILE) 123 | get_filename_component(_PROTOBUF_DIR "${INTERFACE_FILE}" DIRECTORY) 124 | get_filename_component(_PROTOBUF_SHORT "${INTERFACE_FILE}" NAME_WE) 125 | file(MAKE_DIRECTORY "${${_lang}_output_dir}") 126 | file(MAKE_DIRECTORY "${_PROTOBUF_DIR}") 127 | set(_PROTOBUF_NAME "${_PROTOBUF_DIR}/${_PROTOBUF_SHORT}") 128 | set(OUTPUT_FILES) 129 | set(CMD_EXEC) 130 | foreach(_ext ${${LANG}_exts}) 131 | set(OUTPUT_FILE_NAME "${${LANG}_output_dir}/${_PROTOBUF_SHORT}${_ext}") 132 | list(APPEND GENERATED_PROTOBUF_FILES_${LANG} "${OUTPUT_FILE_NAME}") 133 | list(APPEND OUTPUT_FILES "${OUTPUT_FILE_NAME}") 134 | message("${INTERFACE_FILE} => ${OUTPUT_FILE_NAME}") 135 | endforeach() 136 | if(DEFINED ${LANG}_exec) 137 | set(CMD_EXEC ${${LANG}_exec}) 138 | else() 139 | set(CMD_EXEC "${PROTOBUF_PROTOC_EXECUTABLE}") 140 | endif() 141 | add_custom_command( 142 | OUTPUT ${OUTPUT_FILES} 143 | 144 | COMMAND "mkdir" 145 | ARGS "-p" 146 | ARGS "${${LANG}_output_dir}/${_PROTOBUF_SHORT}" 147 | 148 | COMMAND ${CMD_EXEC} 149 | ARGS ${${LANG}_args} 150 | ARGS "--proto_path=${PROTO_PATH}" 151 | ARGS "--experimental_allow_proto3_optional" 152 | ARGS ${${LANG}_plugin} 153 | ARGS "${${LANG}_output}${${LANG}_output_dir}/${_PROTOBUF_SHORT}" 154 | ARGS "${INTERFACE_FILE}") 155 | endmacro() 156 | 157 | macro(generate_interfaces) 158 | foreach(_lang ${TARGET_LANGUAGES}) 159 | foreach(_src ${SRC_Interfaces} ${SRC_Protobufs}) 160 | _generate_interface("${_lang}" ${_src}) 161 | endforeach() 162 | foreach(_src ${SRC_Interfaces}) 163 | if(DEFINED ${_lang}_grpc_output) 164 | _generate_interface(${_lang}_grpc ${_src}) 165 | endif() 166 | endforeach() 167 | endforeach() 168 | endmacro() 169 | 170 | macro(add_target_languages) 171 | foreach(_lang ${ARGN}) 172 | message("Will generate stubs for ${_lang}") 173 | #file(MAKE_DIRECTORY "${GENERATED_PROTOBUF_PATH}/${_lang}") 174 | file(MAKE_DIRECTORY "${${_lang}_output_dir}") 175 | list(APPEND TARGET_LANGUAGES ${_lang}) 176 | endforeach() 177 | endmacro() 178 | 179 | set(RESOURCES) 180 | macro(add_resource) 181 | foreach(_res ${ARGN}) 182 | list(APPEND 
RESOURCES "${CMAKE_CURRENT_BINARY_DIR}/${_res}") 183 | add_custom_command( 184 | OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${_res}" 185 | COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/${_res}" 186 | "${CMAKE_CURRENT_BINARY_DIR}/${_res}" 187 | DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/${_res}") 188 | endforeach() 189 | endmacro() 190 | 191 | # Set our build targets. 192 | add_target_languages( 193 | python 194 | golang 195 | javascript 196 | typescript 197 | ) 198 | 199 | # Generate base protobufs 200 | add_protobufs(${CMAKE_SOURCE_DIR}/proto/tensors.proto) 201 | generate_interfaces() 202 | 203 | add_custom_command( 204 | OUTPUT "tensors/go.mod" 205 | WORKING_DIRECTORY "${GENERATED_PROTOBUF_PATH}" 206 | COMMAND rm -f go.mod 207 | COMMAND go mod init github.com/${PROJECT_REF}/tensors 208 | COMMAND go mod tidy 209 | DEPENDS ${GENERATED_PROTOBUF_FILES_golang_grpc}) 210 | 211 | add_custom_target( 212 | generated ALL 213 | DEPENDS 214 | ${GENERATED_PROTOBUF_FILES_python} 215 | ${GENERATED_PROTOBUF_FILES_golang} 216 | ${GENERATED_PROTOBUF_FILES_javascript} 217 | ${GENERATED_PROTOBUF_FILES_typescript} 218 | ${PROJECT_SOURCE_DIR}/tensors/go.mod 219 | ) 220 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 CoreWeave 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft tensors 2 | graft proto 3 | include CMakeLists.txt 4 | exclude tensorizer/_crypt/__main__.py -------------------------------------------------------------------------------- /docs/encryption.md: -------------------------------------------------------------------------------- 1 | # Tensor Encryption 2 | 3 | `tensorizer` supports fast tensor weight encryption and decryption during 4 | serialization and deserialization, respectively. 5 | 6 | > [!NOTE] 7 | > 8 | > To use `tensorizer` encryption, a recent version of `libsodium` must be 9 | > installed. Install `libsodium` with `apt-get install libsodium23` 10 | > on Ubuntu or Debian, or follow 11 | > [the instructions in `libsodium`'s documentation](https://doc.libsodium.org/installation) 12 | > for other platforms. 
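If it is unclear whether `libsodium` is already installed, a quick way to check
from Python is to ask the dynamic loader whether a `libsodium` shared object is
visible. This is only a minimal sketch using the standard library, not part of
`tensorizer`'s API:

```py
import ctypes.util

# find_library returns None when no libsodium shared object can be located
# (e.g. the one provided by the libsodium23 package on Ubuntu/Debian).
if ctypes.util.find_library("sodium") is None:
    raise RuntimeError(
        "libsodium not found; install it before using tensorizer encryption"
    )
```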
13 |
14 | ## Encryption Algorithm
15 |
16 | Tensor encryption splits weights into up-to-2 MiB chunks encrypted independently
17 | with XSalsa20-Poly1305 symmetric authenticated encryption,
18 | stored separately from their MACs.
19 | These chunks can be encrypted or decrypted and authenticated independently of
20 | one another in the style of a block cipher,
21 | which allows decryption to be parallelized and overlapped with streaming.
22 |
23 | All encryption and decryption is done in-place, as encrypted payloads are equal
24 | in length to their plaintexts in this scheme (since the MACs are stored
25 | separately).
26 | This allows for high-speed processing, since memory allocations can be avoided.
27 |
28 | ## Scope / Security
29 |
30 | Only tensor weights are encrypted, using 256-bit keys (see
31 | [Choosing a Key Derivation Algorithm](#choosing-a-key-derivation-algorithm)),
32 | and weight chunks are independently authenticated.
33 | This is meant to prevent tensor weights from being read by a
34 | third party, plus provide a small amount of authentication to confirm that,
35 | for example, a matching passphrase was used for encryption and decryption;
36 | guarantees beyond that are outside the scope of `tensorizer`'s encryption as
37 | currently available.
38 |
39 | > [!WARNING]
40 | >
41 | > This does not include encryption for anything except tensor weights.
42 | > Metadata such as a tensor's name, dtype, shape, size, and non-keyed hashes
43 | > are neither encrypted nor authenticated.
44 |
45 | > [!WARNING]
46 | >
47 | > This level of encryption does not provide message authentication for metadata
48 | > and does not protect against reordering or truncation of chunks.
49 |
50 | > [!NOTE]
51 | >
52 | > Unencrypted/unauthenticated tensor data is rejected during deserialization
53 | > if decryption is requested, and vice versa.
54 |
55 | ## Usage
56 |
57 | A full usage example is given in
58 | [`examples/encryption.py`](/examples/encryption.py).
59 |
60 | The class docstrings of `EncryptionParams` and `DecryptionParams` include
61 | usage outlines like the ones below, as well as additional usage information.
62 | Most IDEs support automatically displaying this information while coding.
63 |
64 | ### Encrypting and Decrypting Existing Models with a CLI Tool
65 |
66 | Existing models that have already been serialized by tensorizer can have
67 | encryption added to them or removed from them using the example
68 | [`examples/encrypt_existing.py`](/examples/encrypt_existing.py)
69 | command line utility. Download the script, then run
70 | `python encrypt_existing.py -h` to see usage.
71 |
72 | The source code in [`examples/encrypt_existing.py`](/examples/encrypt_existing.py)
73 | also serves as a usage example for the various encryption methods.
74 | 75 | Examples: 76 | 77 | ```bash 78 | # Global help and subcommand help 79 | python encrypt_existing.py --help 80 | python encrypt_existing.py add pwhash --help 81 | 82 | # Encrypt using a random binary key (outputs generated key to --keyfile) 83 | python encrypt_existing.py add random --keyfile encrypted.tensors.key \ 84 | --infile original.tensors --outfile encrypted.tensors 85 | 86 | # Encrypt using a pre-existing binary key (reads key from --keyfile) 87 | python encrypt_existing.py add exact --keyfile encrypted.tensors.key \ 88 | --infile original.tensors --outfile encrypted.tensors 89 | 90 | # Encrypt using Argon2id key derivation (reads string to turn into a key from --keyfile) 91 | python encrypt_existing.py add pwhash --keyfile encrypted.tensors.key \ 92 | --opslimit MODERATE --memlimit MODERATE \ 93 | --infile original.tensors --outfile encrypted.tensors 94 | 95 | # Decrypt using a binary key (reads key from --keyfile) 96 | python encrypt_existing.py remove exact --keyfile encrypted.tensors.key \ 97 | --infile encrypted.tensors --outfile decrypted.tensors 98 | 99 | # Decrypt using Argon2id key derivation (reads string to turn into a key from --keyfile) 100 | python encrypt_existing.py remove pwhash --keyfile encrypted.tensors.key \ 101 | --infile encrypted.tensors --outfile decrypted.tensors 102 | ``` 103 | 104 | ### Using `EncryptionParams.random()` 105 | 106 | This is the preferred method of tensor encryption and decryption. 107 | Use this unless you have a good reason to do otherwise. 108 | 109 | ```py 110 | from tensorizer import ( 111 | EncryptionParams, DecryptionParams, TensorDeserializer, TensorSerializer 112 | ) 113 | 114 | # Serialize and encrypt a model: 115 | encryption_params = EncryptionParams.random() 116 | 117 | serializer = TensorSerializer("model.tensors", encryption=encryption_params) 118 | serializer.write_module(...) # or write_state_dict(), etc. 119 | serializer.close() 120 | 121 | # Save the randomly-generated encryption key somewhere 122 | with open("tensor.key", "wb") as key_file: 123 | key_file.write(encryption_params.key) 124 | 125 | 126 | # Then decrypt it again: 127 | 128 | # Load the randomly-generated key from where it was saved 129 | with open("tensor.key", "rb") as key_file: 130 | key: bytes = key_file.read() 131 | 132 | decryption_params = DecryptionParams.from_key(key) 133 | 134 | deserializer = TensorDeserializer("model.tensors", encryption=decryption_params) 135 | deserializer.load_into_module(...) 136 | deserializer.close() 137 | ``` 138 | 139 | ### Using `EncryptionParams.from_string()` with an environment variable 140 | 141 | If an encryption key must be provided as a pre-existing string, 142 | this method of encryption will allow the use of that string, 143 | and does not require saving a key generated at the time of encryption. 144 | 145 | > [!WARNING] 146 | > 147 | > Make sure a secure input string is used to create a key. 148 | > `EncryptionParams.from_string()` accepts parameters to tune its algorithm 149 | > to make searching the input string via brute-force checking less viable, 150 | > but nothing can protect against weak enough input strings, 151 | > like your birthdate, or a common password. 
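If a strong passphrase does not already exist, one way to create one (a sketch
using only the Python standard library, not a `tensorizer` requirement) is to
generate a high-entropy string with the `secrets` module and store it wherever
your secrets normally live, such as the environment variable read below:

```py
import secrets

# Roughly 256 bits of entropy; keep a copy somewhere durable, since the same
# string is needed again at decryption time.
strong_source: str = secrets.token_urlsafe(32)
```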
152 | 153 | ```py 154 | from tensorizer import ( 155 | EncryptionParams, DecryptionParams, TensorDeserializer, TensorSerializer 156 | ) 157 | 158 | source: str = os.getenv("SUPER_SECRET_STRONG_PASSWORD") 159 | 160 | # Serialize and encrypt a model: 161 | encryption_params = EncryptionParams.from_string(source) 162 | serializer = TensorSerializer("model.tensors", encryption=encryption_params) 163 | serializer.write_module(...) # or write_state_dict(), etc. 164 | serializer.close() 165 | 166 | # Then decrypt it again: 167 | decryption_params = DecryptionParams.from_string(source) 168 | deserializer = TensorDeserializer("model.tensors", encryption=decryption_params) 169 | deserializer.load_into_module(...) 170 | deserializer.close() 171 | ``` 172 | 173 | ### Choosing a Key Derivation Algorithm 174 | 175 | The classes `EncryptionParams` and `DecryptionParams` allow a choice of 176 | key derivation method. Two methods are implemented: 177 | 178 | 1. Random key generation 179 | 1. Chosen by constructing an `EncryptionParams` object through calling 180 | `EncryptionParams.random()` 181 | 2. Not technically key derivation 182 | 3. Uses a completely random 32-byte sequence with no associated passphrase 183 | 4. Highly secure against being guessed 184 | 5. You must save the randomly generated key 185 | 2. [Argon2id](https://datatracker.ietf.org/doc/html/rfc9106) key derivation 186 | 1. Chosen by constructing an `EncryptionParams` object through calling 187 | `EncryptionParams.from_string(source)` 188 | 2. Transmutes an arbitrary-length `str` or `bytes` source string into a 189 | binary encryption key 190 | 3. Implements adjustable security against brute-force cracking 191 | via its `opslimit` and `memlimit` parameters 192 | 4. Internally uses 193 | [`libsodium`'s `pwhash` function with the algorithm `crypto_pwhash_ALG_ARGON2ID13`](https://libsodium.gitbook.io/doc/password_hashing/default_phf#key-derivation) 194 | 195 | An `EncryptionParams` object is passed to a `TensorSerializer` using its 196 | `encryption=...` keyword-only parameter during initialization. 197 | 198 | #### `EncryptionParams.from_string()` details (Argon2id) 199 | 200 | `EncryptionParams.from_string()` uses the Argon2 (Argon2id, RFC 9106) 201 | password hashing algorithm to create a key from an input string. 202 | 203 | The key has resistance against brute-force attacks that attempt 204 | to guess the input string, achieved by making each attempt 205 | expensive to compute, both in CPU time and RAM usage. 206 | 207 | The computational difficulty can be increased or decreased 208 | via the `opslimit` and `memlimit` parameters. 209 | Higher computational difficulty gives more security 210 | for weak input strings, but may impact performance. 211 | The default setting is a "moderate" profile taken from `libsodium`. 212 | 213 | Presets (as well as minimum values) are available through the 214 | `EncryptionParams.OpsLimit` and `EncryptionParams.MemLimit` enums. 
215 | 216 | Rough estimates of performance impact (on a 3.20 GHz processor): 217 | 218 | ```py 219 | from tensorizer import EncryptionParams 220 | 221 | OpsLimit = EncryptionParams.OpsLimit 222 | MemLimit = EncryptionParams.MemLimit 223 | s = "X" * 40 224 | 225 | EncryptionParams.from_string( # Takes about 0.05 ms, 8 KiB RAM 226 | s, opslimit=OpsLimit.MIN, memlimit=MemLimit.MIN 227 | ) 228 | EncryptionParams.from_string( # Takes about 90 ms, 64 MiB RAM 229 | s, opslimit=OpsLimit.INTERACTIVE, memlimit=MemLimit.INTERACTIVE 230 | ) 231 | EncryptionParams.from_string( # Takes about 500 ms, 256 MiB RAM 232 | s, opslimit=OpsLimit.MODERATE, memlimit=MemLimit.MODERATE 233 | # Default: equivalent to opslimit=None, memlimit=None 234 | ) 235 | EncryptionParams.from_string( # Takes about 3.0 seconds, 1 GiB RAM 236 | s, opslimit=OpsLimit.SENSITIVE, memlimit=MemLimit.SENSITIVE 237 | ) 238 | ``` 239 | 240 | Timing may be different on different hardware. 241 | These do not reflect the exact times an attacker may require for each guess. 242 | 243 | ##### Performance tuning 244 | 245 | If possible, use `EncryptionParams.random()` instead of 246 | `EncryptionParams.from_string()`, and save the generated key 247 | to use for decryption. 248 | 249 | If that is not possible, save the binary key generated during 250 | `EncryptionParams.from_string()` (from the `.key` attribute), 251 | and use that key for decryption (via `DecryptionParams.from_key()`) 252 | to remove the cost of re-computing the key at deserialization time. 253 | 254 | If that is not possible, use a strong input string. 255 | For input strings that are already very strong and high-entropy, 256 | where brute-force attacks on the input string are no more likely 257 | to succeed than brute-force attacks on a 256-bit key itself, 258 | (e.g. very long, randomly generated strings), 259 | `opslimit` and `memlimit` may be tuned down to minimize 260 | their performance impact. 261 | 262 | If that is not possible, test different values of `opslimit` 263 | and `memlimit` to determine an acceptable tradeoff between 264 | performance and security for your use case. 265 | 266 | See also: 267 | - [`libsodium` documentation for `pwhash`](https://libsodium.gitbook.io/doc/password_hashing/default_phf#key-derivation), 268 | the Argon2id implementation used in `EncryptionParams.from_string()` 269 | - [RFC 9106](https://datatracker.ietf.org/doc/html/rfc9106) 270 | for details on Argon2 and Argon2id 271 | 272 | #### Using the right key derivation algorithm for decryption 273 | 274 | Specifying whether to decrypt with a passphrase or key is done with 275 | a `DecryptionParams` object. 276 | A `DecryptionParams` object is passed to a `TensorDeserializer` using its 277 | `encryption=...` keyword-only parameter during initialization. 278 | 279 | When passphrase-based key derivation is used during encryption, 280 | *key derivation metadata* recording the algorithm used is stored 281 | in the tensorized file. Since the file keeps track of the algorithm, 282 | any `from_string()`-based encryption can be decrypted the same way: 283 | 284 | ```py 285 | source: str = ... 286 | decryption_params = DecryptionParams.from_string(source) 287 | deserializer = TensorDeserializer(..., encryption=decryption_params) 288 | ``` 289 | 290 | Additionally, *any* encryption, 291 | whether a passphrase was used during encryption or not, 292 | can be decrypted if you know its exact binary key: 293 | 294 | ```py 295 | key: bytes = ... 
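# `key` is the exact 32-byte binary key produced at encryption time
# (e.g. read back from a keyfile); from_key() performs no key derivation.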
296 | decryption_params = DecryptionParams.from_key(key)
297 | deserializer = TensorDeserializer(..., encryption=decryption_params)
298 | ```
299 |
300 | This is the only way to decrypt a file that was encrypted using
301 | `EncryptionParams.random()`, since it has no associated passphrase.
302 |
303 | To retrieve a binary key from an `EncryptionParams` object, access its `key`
304 | attribute:
305 |
306 | ```py
307 | encryption_params = EncryptionParams.random()
308 | # Or
309 | encryption_params = EncryptionParams.from_string(...)
310 |
311 | key: bytes = encryption_params.key
312 | ```
313 |
314 | ## Speed
315 |
316 | The throughput of `tensorizer`'s encryption algorithm reaches 31 GiB/s on
317 | ~26 cores, most likely limited by RAM or CPU cache speed.
318 | Since it can overlap with downloads, the time overhead of decryption is
319 | very small, with ratios of data processed to latency incurred in the terabit range
320 | observed on test machines.
321 |
322 | Speed of key derivation is configurable, if used.
323 | See [Choosing a Key Derivation Algorithm](#choosing-a-key-derivation-algorithm).
324 |
325 | ## Compatibility
326 |
327 | Tensors serialized with encryption are stored using Tensorizer data format v3,
328 | which can be read with `tensorizer>=2.7.0`.
329 |
--------------------------------------------------------------------------------
/docs/subprocess-serialization.md:
--------------------------------------------------------------------------------
1 | # Tensorizer serialization via subprocess
2 |
3 | If you're using Tensorizer serialization to write checkpoints during training,
4 | you may want to run the serialization concurrently from your training code so
5 | that you can execute your next training step as quickly as possible. And
6 | because of the Python GIL, it's better to do this in a separate process so that the
7 | serialization doesn't compete for the GIL with your training code.
8 |
9 | Keep in mind that this is a way to achieve _concurrency_, not instant
10 | snapshotting. The tensors you are checkpointing still need to be kept in memory,
11 | unmodified, for the duration of the serialization process. (Though you may
12 | choose to copy them out of CUDA memory into CPU memory. These tradeoffs are
13 | discussed below.)
14 |
15 | Also refer to [PyTorch Multiprocessing best
16 | practices](https://pytorch.org/docs/stable/notes/multiprocessing.html) for more
17 | details about using PyTorch across processes.
18 |
19 |
20 | ## Warning about fork() and threads
21 | Be aware that Python `os.fork()` is often not a viable option, as it can cause deadlocks if you have multiple threads. Python 3.12 and above
22 | will [issue a deprecation warning](https://github.com/python/cpython/pull/100229) if you attempt this. Some third-party packages that rely on sockets or file descriptors may also not behave correctly when a process unexpectedly forks.
23 |
24 | A subprocess (fork + exec) is generally safer, but you do not inherently get
25 | shared memory with the calling process. `multiprocessing` has two ways to create
26 | a child process: `spawn` or `forkserver`. `spawn` should always be safe.
27 | `forkserver` can be faster, but its safety depends on the behavior of modules at
28 | import time. See
29 | https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
30 | for more details.
31 |
32 | ## If starting from CUDA
33 | Presuming your tensors are in CUDA memory, there are a couple of different options.
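Both options below use `torch.multiprocessing` with the `spawn` start method.
As a hedged aside (not something `tensorizer` requires), a dedicated
multiprocessing context can be used instead of `mp.set_start_method()`, which
keeps the choice of start method local to the checkpointing code path rather
than changing the process-wide default:

```python
import torch.multiprocessing as mp

# A context object scopes the "spawn" setting to processes created from it.
ctx = mp.get_context("spawn")
# Then create workers from the context, e.g.:
# p = ctx.Process(target=do_serialize, args=(dest, model))
# where do_serialize, dest, and model are as in the examples below.
```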
34 |
35 | ### Option 1: Communicate the CUDA tensors directly
36 | CUDA tensors can be "shared" with a subprocess very efficiently, since it's only communicating a pointer to device memory.
37 |
38 | Basically, pass the CUDA tensors to a subprocess that does the serialization, either as `multiprocessing.Process` arguments (as in the example below) or over a `multiprocessing.Queue`. Ensure that the CUDA tensors remain **unmodified** in device memory until the serialization process finishes.
39 |
40 | ```python
41 | import torch
42 | from tensorizer import TensorSerializer
43 | from transformers import AutoModelForCausalLM
44 | import torch.multiprocessing as mp
45 |
46 | def do_serialize(uri: str, model: torch.nn.Module):
47 |     serializer = TensorSerializer(uri)
48 |     serializer.write_module(model)
49 |     serializer.close()
50 |
51 | def my_gpu_model() -> torch.nn.Module:
52 |     model_ref = "EleutherAI/gpt-j-6B"
53 |     model = AutoModelForCausalLM.from_pretrained(
54 |         model_ref,
55 |         revision="float16",
56 |         torch_dtype=torch.float16,
57 |         low_cpu_mem_usage=True,
58 |     )
59 |     model.to('cuda')
60 |     return model
61 |
62 | def main():
63 |     dest = "gpt-j-6B.tensors"
64 |     model = my_gpu_model()
65 |
66 |     mp.set_start_method('spawn')
67 |     p = mp.Process(target=do_serialize, args=(dest, model))
68 |     p.start()
69 |
70 |     # main process is now free to do other stuff but `model` must remain in CUDA
71 |     # memory until the `p` subprocess finishes
72 |
73 |     p.join()
74 |
75 |
76 | if __name__ == '__main__':
77 |     main()
78 | ```
79 |
80 | ### Option 2: Snapshot CUDA tensors to CPU memory in subprocess before serializing
81 |
82 | Once the tensors are in CPU memory, they no longer need to occupy CUDA memory. But the tensors
83 | will now need to occupy CPU memory until they are fully serialized.
84 |
85 | Do this by calling `model.to("cpu")` in the subprocess, immediately after the model has been sent to it.
86 |
87 | If you like, you can also use some sort of IPC object to communicate back to the
88 | host process when the snapshotting has finished, so you know when the CUDA memory
89 | can be released. The code below uses a `Queue`.
90 |
91 | ```python
92 | import torch
93 | from tensorizer import TensorSerializer
94 | from transformers import AutoModelForCausalLM
95 | import torch.multiprocessing as mp
96 |
97 | def do_serialize(uri: str, model: torch.nn.Module, snapshot_done: mp.Queue):
98 |     model = model.to('cpu')  # Snapshot now
99 |     snapshot_done.put(True)
100 |
101 |     serializer = TensorSerializer(uri)
102 |     serializer.write_module(model)
103 |     serializer.close()
104 |
105 | def my_gpu_model() -> torch.nn.Module:
106 |     model_ref = "EleutherAI/gpt-j-6B"
107 |     model = AutoModelForCausalLM.from_pretrained(
108 |         model_ref,
109 |         revision="float16",
110 |         torch_dtype=torch.float16,
111 |         low_cpu_mem_usage=True,
112 |     )
113 |     model.to('cuda')
114 |     return model
115 |
116 | def main():
117 |     dest = "gpt-j-6B.tensors"
118 |     model = my_gpu_model()
119 |
120 |     mp.set_start_method('spawn')
121 |     snapshot_done = mp.Queue()
122 |     p = mp.Process(target=do_serialize, args=(dest, model, snapshot_done))
123 |     p.start()
124 |
125 |     # main process is now free to do other stuff
126 |     # but `model` must remain in CUDA memory
127 |
128 |     snapshot_done.get()
129 |     # Subprocess copied model into CPU memory. Free to release the CUDA-based model
130 |     del model
131 |
132 |     # ... do other stuff ...
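    # is_alive() is a non-blocking way to check on the serialization subprocess;
    # the join() further below blocks until it has finished.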
133 |
134 |     if not p.is_alive():
135 |         print('Serialization finished.')
136 |
137 |     p.join()
138 |
139 |
140 | if __name__ == '__main__':
141 |     main()
142 | ```
143 |
144 | ## If starting from CPU memory
145 |
146 | Tensors in CPU memory need to be moved to shared memory to be communicated with a subprocess. PyTorch `multiprocessing` will do this automatically, but be aware
147 | that a memcpy occurs. You'll also need additional "surge" CPU memory for the duration of the copy of each tensor. PyTorch copies tensors serially, so you need additional memory equal to the size of your largest tensor. This is only used during the memcpy itself. The original non-shared memory is immediately freed thereafter (unless it is also in use by some other tensor).
148 |
149 | Depending on how you are constructing your CPU tensor, you may be able to preemptively call `tensor.share_memory()`, thus saving a memcpy when
150 | passing to the subprocess.
151 |
152 | > [!WARNING]
153 | >
154 | > The main process should avoid modifying tensors while they are being serialized from shared memory, to avoid corrupting the written file. If serializing *with encryption* from shared memory, tensors should additionally not be read again until serialization has finished, as encryption temporarily modifies tensors in-place.
155 | >
156 | > If concurrent modification or access is necessary, move the tensors out of shared memory and into a copy in the subprocess before serialization. This can be done in the same style shown for snapshotting CUDA tensors in a previous example.
157 |
158 | ```python
159 | import torch
160 | from tensorizer import TensorSerializer
161 | from transformers import AutoModelForCausalLM
162 | import torch.multiprocessing as mp
163 |
164 | def do_serialize(uri: str, model: torch.nn.Module):
165 |     serializer = TensorSerializer(uri)
166 |     serializer.write_module(model)
167 |     serializer.close()
168 |
169 | def my_cpu_model() -> torch.nn.Module:
170 |     model_ref = "EleutherAI/gpt-j-6B"
171 |     model = AutoModelForCausalLM.from_pretrained(
172 |         model_ref,
173 |         revision="float16",
174 |         torch_dtype=torch.float16,
175 |         low_cpu_mem_usage=True,
176 |     )
177 |     return model
178 |
179 | def main():
180 |     dest = "gpt-j-6B.tensors"
181 |     model = my_cpu_model()
182 |
183 |     mp.set_start_method('spawn')
184 |
185 |     # this will execute model.share_memory()
186 |     p = mp.Process(target=do_serialize, args=(dest, model))
187 |
188 |     p.start()
189 |
190 |     # main process is now free to do other stuff
191 |     # but `model` must remain in CPU memory until the `p` subprocess finishes
192 |
193 |     p.join()
194 |
195 |
196 | if __name__ == '__main__':
197 |     main()
198 | ```
199 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coreweave/tensorizer/9241bc82e7b9fdc3f92aa38ad04efd54f0054525/examples/__init__.py
--------------------------------------------------------------------------------
/examples/benchmark/.images/gpt-j.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coreweave/tensorizer/9241bc82e7b9fdc3f92aa38ad04efd54f0054525/examples/benchmark/.images/gpt-j.png
--------------------------------------------------------------------------------
/examples/benchmark/.images/opt-30b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coreweave/tensorizer/9241bc82e7b9fdc3f92aa38ad04efd54f0054525/examples/benchmark/.images/opt-30b.png
--------------------------------------------------------------------------------
/examples/benchmark/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ghcr.io/coreweave/ml-containers/torch:bb02bee-nccl-cuda12.0.1-nccl2.17.1-1-torch2.0.0-vision0.15.1-audio2.0.1
2 |
3 | RUN pip install transformers tensorizer==1.1.0 accelerate safetensors==0.3.1
4 |
5 | RUN mkdir /app
6 | ADD deserialize_benchmark.py /app/deserialize_benchmark.py
7 | ADD save_models.py /app/save_models.py
8 |
--------------------------------------------------------------------------------
/examples/benchmark/README.md:
--------------------------------------------------------------------------------
1 | # Tensorizer Benchmarking
2 |
3 | The files in this directory contain everything needed to benchmark both the
4 | serialization and deserialization processes.
5 |
6 | The benchmarks will be run using tensorizer,
7 | [Safetensors](https://huggingface.co/docs/safetensors/index),
8 | and [HuggingFace Transformers](https://huggingface.co/docs/transformers/index).
9 |
10 | As configured, the benchmark job will deserialize
11 | [OPT-30B](https://huggingface.co/facebook/opt-30b) 100 times using each library.
12 | This can be reconfigured by changing the `MODEL_ID` and `MODEL_PATH`
13 | environment variables described under [Serialization](#serialization)
14 | and [Deserialization](#deserialization) below.
15 |
16 | ![gpt-j.png](.images/gpt-j.png)
17 | ![opt-30b.png](.images/opt-30b.png)
18 |
19 | ## Running the Benchmarks
20 |
21 | ### Setup
22 |
23 | #### Docker Image
24 |
25 | Both benchmarks use the same Docker image, defined by the `Dockerfile`.
26 |
27 | There is a publicly available version of this Docker image already referenced in the
28 | benchmark job manifests, but if you make any changes you will need to rebuild
29 | and push a new version of the container image containing your changes.
30 |
31 | #### PVC
32 |
33 | To keep variables as consistent as possible between the different packages,
34 | the models will be saved to and loaded from a single NVMe-backed PVC
35 | (as opposed to tensorizer using CoreWeave's accelerated Object Storage).
36 |
37 | The PVC used in the benchmark jobs is defined in `benchmark-pvc.yaml`.
38 | To create it, run the following:
39 |
40 | ```bash
41 | kubectl apply -f benchmark-pvc.yaml
42 | ```
43 |
44 | ### Serialization
45 |
46 | The serialization benchmark saves the model in all three frameworks in a
47 | single run. The script used is `save_models.py`, and the job is defined in
48 | `save-models-job.yaml`.
49 |
50 | The serialization benchmark job has a number of parameters that can be edited
51 | in the job manifest via environment variables.
52 | - `MODEL_ID`: HuggingFace model ID that will be saved and serialized
53 | - `NUM_TRIALS`: How many trials to run for each library in a single pod
54 | - `MODEL_PATH`: Where the model files will be saved
55 |
56 | ### Deserialization
57 |
58 | The deserialization benchmark loads the model onto the GPU from the
59 | serialized checkpoint files previously saved. The script used is
60 | `deserialize_benchmark.py`, and the job is defined in `benchmark-job.yaml`.
61 |
62 | The deserialization benchmark job has a number of parameters that can be
63 | edited in the job manifest via environment variables.
64 | - `MODEL_PATH`: Path to the serialized model files 65 | - `MODEL_ID`: HuggingFace model ID, used to load the tokenizer for inference 66 | - `NUM_TRIALS`: How many trials to run for each library in a single pod 67 | - `RES_PATH`: Where to save the result file 68 | - `SKIP_INFERENCE`: Set to skip the inference test after loading the model 69 | - `SKIP_HF`: Skip loading the model using HuggingFace Transformers 70 | - `SKIP_ST`: Skip loading the model using Safetensors 71 | - `SKIP_TZR`: Skip loading the model using tensorizer 72 | 73 | The benchmark job is broken up into 3 separate jobs, one for each library. 74 | Each job has a target of `100` completions and a parallelism of `1`. This 75 | means that 100 pods will be spawned for each library, each loading the model 76 | once. 77 | -------------------------------------------------------------------------------- /examples/benchmark/benchmark-job.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: batch/v1 2 | kind: Job 3 | metadata: 4 | name: tensorizer-benchmark-tzr 5 | spec: 6 | parallelism: 1 7 | completions: 100 8 | template: 9 | spec: 10 | containers: 11 | - name: benchmark 12 | image: navarrepratt/tensorizer-benchmark:1.2.9 13 | imagePullPolicy: IfNotPresent 14 | command: [ "python", "/app/deserialize_benchmark.py" ] 15 | resources: 16 | requests: 17 | cpu: "8" 18 | memory: 64Gi 19 | nvidia.com/gpu: "1" 20 | limits: 21 | cpu: "8" 22 | memory: 64Gi 23 | nvidia.com/gpu: "1" 24 | env: 25 | - name: PYTHONUNBUFFERED 26 | value: "1" 27 | - name: MODEL_PATH 28 | value: "/mnt/tensorizer/models/opt-30b/fp16" 29 | - name: MODEL_ID 30 | value: "facebook/opt-30b" 31 | - name: NUM_TRIALS 32 | value: "1" 33 | - name: RES_PATH 34 | value: "/mnt/tensorizer/opt-results" 35 | # - name: SKIP_INFERENCE 36 | # value: "1" 37 | - name: SKIP_HF 38 | value: "1" 39 | - name: SKIP_ST 40 | value: "1" 41 | volumeMounts: 42 | - name: tensorizer-benchmark 43 | mountPath: /mnt/tensorizer 44 | volumes: 45 | - name: tensorizer-benchmark 46 | persistentVolumeClaim: 47 | claimName: tensorizer-benchmark-nvme 48 | affinity: 49 | nodeAffinity: 50 | requiredDuringSchedulingIgnoredDuringExecution: 51 | nodeSelectorTerms: 52 | - matchExpressions: 53 | - key: topology.kubernetes.io/region 54 | operator: In 55 | values: 56 | - LAS1 57 | - key: gpu.nvidia.com/model 58 | operator: In 59 | values: 60 | - A100_NVLINK_80GB 61 | restartPolicy: Never 62 | backoffLimit: 1 63 | --- 64 | apiVersion: batch/v1 65 | kind: Job 66 | metadata: 67 | name: tensorizer-benchmark-st 68 | spec: 69 | parallelism: 1 70 | completions: 100 71 | template: 72 | spec: 73 | containers: 74 | - name: benchmark 75 | image: navarrepratt/tensorizer-benchmark:1.2.8 76 | imagePullPolicy: IfNotPresent 77 | command: [ "python", "/app/deserialize_benchmark.py" ] 78 | resources: 79 | requests: 80 | cpu: "8" 81 | memory: 128Gi 82 | nvidia.com/gpu: "1" 83 | limits: 84 | cpu: "8" 85 | memory: 128Gi 86 | nvidia.com/gpu: "1" 87 | env: 88 | - name: PYTHONUNBUFFERED 89 | value: "1" 90 | - name: MODEL_PATH 91 | value: "/mnt/tensorizer/models/opt-30b/fp16" 92 | - name: MODEL_ID 93 | value: "facebook/opt-30b" 94 | - name: NUM_TRIALS 95 | value: "1" 96 | - name: RES_PATH 97 | value: "/mnt/tensorizer/opt-results" 98 | # - name: SKIP_INFERENCE 99 | # value: "1" 100 | - name: SKIP_HF 101 | value: "1" 102 | - name: SKIP_TZR 103 | value: "1" 104 | volumeMounts: 105 | - name: tensorizer-benchmark 106 | mountPath: /mnt/tensorizer 107 | volumes: 108 | - name: tensorizer-benchmark 109 
| persistentVolumeClaim: 110 | claimName: tensorizer-benchmark-nvme 111 | affinity: 112 | nodeAffinity: 113 | requiredDuringSchedulingIgnoredDuringExecution: 114 | nodeSelectorTerms: 115 | - matchExpressions: 116 | - key: topology.kubernetes.io/region 117 | operator: In 118 | values: 119 | - LAS1 120 | - key: gpu.nvidia.com/model 121 | operator: In 122 | values: 123 | - A100_NVLINK_80GB 124 | restartPolicy: Never 125 | backoffLimit: 1 126 | --- 127 | apiVersion: batch/v1 128 | kind: Job 129 | metadata: 130 | name: tensorizer-benchmark-hf 131 | spec: 132 | parallelism: 1 133 | completions: 100 134 | template: 135 | spec: 136 | containers: 137 | - name: benchmark 138 | image: navarrepratt/tensorizer-benchmark:1.2.8 139 | imagePullPolicy: IfNotPresent 140 | command: [ "python", "/app/deserialize_benchmark.py" ] 141 | resources: 142 | requests: 143 | cpu: "8" 144 | memory: 128Gi 145 | nvidia.com/gpu: "1" 146 | limits: 147 | cpu: "8" 148 | memory: 128Gi 149 | nvidia.com/gpu: "1" 150 | env: 151 | - name: PYTHONUNBUFFERED 152 | value: "1" 153 | - name: MODEL_PATH 154 | value: "/mnt/tensorizer/models/opt-30b/fp16" 155 | - name: MODEL_ID 156 | value: "facebook/opt-30b" 157 | - name: NUM_TRIALS 158 | value: "1" 159 | - name: RES_PATH 160 | value: "/mnt/tensorizer/opt-results" 161 | # - name: SKIP_INFERENCE 162 | # value: "1" 163 | - name: SKIP_TZR 164 | value: "1" 165 | - name: SKIP_ST 166 | value: "1" 167 | volumeMounts: 168 | - name: tensorizer-benchmark 169 | mountPath: /mnt/tensorizer 170 | volumes: 171 | - name: tensorizer-benchmark 172 | persistentVolumeClaim: 173 | claimName: tensorizer-benchmark-nvme 174 | affinity: 175 | nodeAffinity: 176 | requiredDuringSchedulingIgnoredDuringExecution: 177 | nodeSelectorTerms: 178 | - matchExpressions: 179 | - key: topology.kubernetes.io/region 180 | operator: In 181 | values: 182 | - LAS1 183 | - key: gpu.nvidia.com/model 184 | operator: In 185 | values: 186 | - A100_NVLINK_80GB 187 | restartPolicy: Never 188 | backoffLimit: 1 189 | -------------------------------------------------------------------------------- /examples/benchmark/benchmark-pvc.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolumeClaim 3 | metadata: 4 | name: tensorizer-benchmark-nvme 5 | spec: 6 | storageClassName: shared-nvme-las1 7 | accessModes: 8 | - ReadWriteMany 9 | resources: 10 | requests: 11 | storage: 500Gi -------------------------------------------------------------------------------- /examples/benchmark/deserialize_benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from tensorizer import TensorDeserializer 10 | from tensorizer.utils import convert_bytes, get_mem_usage 11 | 12 | # disable missing keys and unexpected key warnings 13 | os.environ["TRANSFORMERS_VERBOSITY"] = "error" 14 | 15 | # Improve safetensors performance 16 | os.environ["SAFETENSORS_FAST_GPU"] = "1" 17 | 18 | from accelerate import init_empty_weights 19 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 20 | 21 | 22 | def convert_to_bool(val: str) -> bool: 23 | return val.strip().lower() not in ("", "0", "no", "f", "false") 24 | 25 | 26 | MODEL_PATH = Path(os.environ.get("MODEL_PATH", f"./models")) 27 | TZR_PATH = MODEL_PATH / "model.tensors" 28 | HF_PATH = MODEL_PATH / "hf" 29 | RES_PATH = Path(os.environ.get("RES_PATH", 
"./results")) 30 | NUM_TRIALS = int(os.environ.get("NUM_TRIALS", 1)) 31 | SKIP_HF = convert_to_bool(os.environ.get("SKIP_HF", "")) 32 | SKIP_TZR = convert_to_bool(os.environ.get("SKIP_TZR", "")) 33 | SKIP_ST = convert_to_bool(os.environ.get("SKIP_ST", "")) 34 | SKIP_INFERENCE = convert_to_bool(os.environ.get("SKIP_INFERENCE", "")) 35 | CURL_PATH = shutil.which("curl") 36 | 37 | 38 | RES_PATH.mkdir(parents=True, exist_ok=True) 39 | config = AutoConfig.from_pretrained(HF_PATH) 40 | 41 | DEVICE = torch.device("cuda:0") 42 | 43 | 44 | def run_inference(model): 45 | if SKIP_INFERENCE: 46 | return 47 | # Tokenize and generate 48 | model.eval() 49 | tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH / "hf") 50 | input_ids = tokenizer.encode( 51 | "We the people, of the United States of America", return_tensors="pt" 52 | ).to(DEVICE) 53 | 54 | torch.manual_seed(100) 55 | with torch.no_grad(): 56 | output = model.generate(input_ids, max_new_tokens=50, do_sample=True) 57 | 58 | print(f"Output: {tokenizer.decode(output[0], skip_special_tokens=True)}") 59 | 60 | 61 | def hf_load() -> float: 62 | print("~" * 25) 63 | 64 | before_mem = get_mem_usage() 65 | 66 | start = time.time() 67 | model = AutoModelForCausalLM.from_pretrained( 68 | HF_PATH, 69 | # revision="float16", 70 | torch_dtype=torch.float16, 71 | low_cpu_mem_usage=True, 72 | config=config, 73 | use_safetensors=False, 74 | device_map="auto", 75 | ) 76 | duration = time.time() - start 77 | 78 | after_mem = get_mem_usage() 79 | print(f"Loaded huggingface model in {duration:0.2f}s") 80 | print(f"Memory usage before: {before_mem}") 81 | print(f"Memory usage after: {after_mem}") 82 | run_inference(model) 83 | 84 | return duration 85 | 86 | 87 | def tzr_load() -> float: 88 | print("~" * 25) 89 | before_mem = get_mem_usage() 90 | 91 | start = time.time() 92 | # This ensures that the model is not initialized. 93 | with init_empty_weights(): 94 | model = AutoModelForCausalLM.from_config(config) 95 | deserializer = TensorDeserializer(TZR_PATH, plaid_mode=True, device=DEVICE) 96 | deserializer.load_into_module(model) 97 | end = time.time() 98 | 99 | # Brag about how fast we are. 
100 | total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) 101 | duration = end - start 102 | per_second = convert_bytes(deserializer.total_tensor_bytes / duration) 103 | after_mem = get_mem_usage() 104 | deserializer.close() 105 | print(f"Deserialized {total_bytes_str} in {duration:0.2f}s, {per_second}/s") 106 | print(f"Memory usage before: {before_mem}") 107 | print(f"Memory usage after: {after_mem}") 108 | run_inference(model) 109 | 110 | return duration 111 | 112 | 113 | def st_load() -> float: 114 | print("~" * 25) 115 | 116 | before_mem = get_mem_usage() 117 | 118 | start = time.time() 119 | model = AutoModelForCausalLM.from_pretrained( 120 | HF_PATH, 121 | low_cpu_mem_usage=True, 122 | config=config, 123 | use_safetensors=True, 124 | device_map="auto", 125 | ) 126 | end = time.time() 127 | 128 | after_mem = get_mem_usage() 129 | duration = end - start 130 | 131 | print(f"Deserialized safetensors in {duration:0.2f}s") 132 | print(f"Memory usage before: {before_mem}") 133 | print(f"Memory usage after: {after_mem}") 134 | 135 | run_inference(model) 136 | 137 | return duration 138 | 139 | 140 | if not SKIP_TZR: 141 | print("\nRunning Tensorizer...") 142 | tzr_times = [tzr_load() for _ in range(NUM_TRIALS)] 143 | print( 144 | "Average tensorizer deserialization:", sum(tzr_times) / len(tzr_times) 145 | ) 146 | with open(RES_PATH / f"tzr_times_{time.time()}.npy", "wb") as f: 147 | np.save(f, np.array(tzr_times)) 148 | 149 | if not SKIP_HF: 150 | print("\nRunning Huggingface...") 151 | hf_times = [hf_load() for _ in range(NUM_TRIALS)] 152 | print("Average huggingface load:", sum(hf_times) / len(hf_times)) 153 | with open(RES_PATH / f"hf_times_{time.time()}.npy", "wb") as f: 154 | np.save(f, np.array(hf_times)) 155 | 156 | if not SKIP_ST: 157 | print("\nRunning Safetensors...") 158 | st_times = [st_load() for _ in range(NUM_TRIALS)] 159 | print("Average safetensors load:", sum(st_times) / len(st_times)) 160 | with open(RES_PATH / f"st_times_{time.time()}.npy", "wb") as f: 161 | np.save(f, np.array(st_times)) 162 | -------------------------------------------------------------------------------- /examples/benchmark/jupyter-lab-service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: jupyter 5 | spec: 6 | type: ClusterIP 7 | clusterIP: None 8 | ports: 9 | - name: notebook 10 | port: 8888 11 | protocol: TCP 12 | selector: 13 | app.kubernetes.io/name: jupyter 14 | --- 15 | apiVersion: apps/v1 16 | kind: Deployment 17 | metadata: 18 | name: jupyter 19 | spec: 20 | strategy: 21 | type: Recreate 22 | replicas: 1 23 | selector: 24 | matchLabels: 25 | app.kubernetes.io/name: jupyter 26 | template: 27 | metadata: 28 | labels: 29 | app.kubernetes.io/name: jupyter 30 | spec: 31 | containers: 32 | - name: jupyter 33 | image: tverous/pytorch-notebook:latest 34 | command: 35 | - "jupyter" 36 | - "lab" 37 | - "--ip" 38 | - "0.0.0.0" 39 | - "--no-browser" 40 | - "--allow-root" 41 | - "--notebook-dir" 42 | - "/mnt/pvc" 43 | - "--LabApp.token=''" 44 | 45 | securityContext: 46 | runAsUser: 0 47 | 48 | ports: 49 | - name: notebook 50 | containerPort: 8888 51 | protocol: TCP 52 | 53 | readinessProbe: 54 | tcpSocket: 55 | port: notebook 56 | initialDelaySeconds: 5 57 | periodSeconds: 10 58 | livenessProbe: 59 | httpGet: 60 | path: / 61 | port: notebook 62 | initialDelaySeconds: 15 63 | periodSeconds: 15 64 | failureThreshold: 3 65 | timeoutSeconds: 10 66 | 67 | volumeMounts: 68 | - name: 
tensorizer-benchmark-nvme 69 | mountPath: /mnt/pvc 70 | 71 | env: 72 | - name: WANDB_API_KEY 73 | valueFrom: 74 | secretKeyRef: 75 | name: wandb-token-secret 76 | key: token 77 | 78 | resources: 79 | requests: 80 | cpu: "4" 81 | memory: 16Gi 82 | limits: 83 | cpu: "4" 84 | memory: 16Gi 85 | affinity: 86 | nodeAffinity: 87 | requiredDuringSchedulingIgnoredDuringExecution: 88 | nodeSelectorTerms: 89 | - matchExpressions: 90 | - key: topology.kubernetes.io/region 91 | operator: In 92 | values: 93 | - "LAS1" 94 | volumes: 95 | - name: tensorizer-benchmark-nvme 96 | persistentVolumeClaim: 97 | claimName: tensorizer-benchmark-nvme 98 | restartPolicy: Always -------------------------------------------------------------------------------- /examples/benchmark/save-models-job.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: batch/v1 2 | kind: Job 3 | metadata: 4 | name: save-models 5 | spec: 6 | template: 7 | spec: 8 | containers: 9 | - name: benchmark 10 | image: navarrepratt/tensorizer-benchmark:1.2.7 11 | imagePullPolicy: IfNotPresent 12 | command: [ "python", "/app/save_models.py" ] 13 | resources: 14 | requests: 15 | cpu: "16" 16 | memory: 128Gi 17 | limits: 18 | cpu: "16" 19 | memory: 128Gi 20 | env: 21 | - name: PYTHONUNBUFFERED 22 | value: "1" 23 | - name: MODEL_PATH 24 | value: "/mnt/tensorizer/models/opt-30b" 25 | - name: MODEL_ID 26 | value: "facebook/opt-30b" 27 | - name: NUM_TRIALS 28 | value: "1" 29 | volumeMounts: 30 | - name: tensorizer-benchmark 31 | mountPath: /mnt/tensorizer 32 | volumes: 33 | - name: tensorizer-benchmark 34 | persistentVolumeClaim: 35 | claimName: tensorizer-benchmark-nvme 36 | affinity: 37 | nodeAffinity: 38 | requiredDuringSchedulingIgnoredDuringExecution: 39 | nodeSelectorTerms: 40 | - matchExpressions: 41 | - key: topology.kubernetes.io/region 42 | operator: In 43 | values: 44 | - LAS1 45 | restartPolicy: Never 46 | backoffLimit: 2 47 | -------------------------------------------------------------------------------- /examples/benchmark/save_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from collections import defaultdict 4 | from pathlib import Path 5 | from typing import Dict, List, Optional 6 | 7 | import torch 8 | from safetensors.torch import save_file 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | 11 | from tensorizer import TensorSerializer 12 | 13 | MODEL_ID = os.environ.get("MODEL_ID", "EleutherAI/gpt-neo-125M") 14 | MODEL_PATH = Path(os.environ.get("MODEL_PATH", "./models")) 15 | USE_FP16 = os.environ.get("USE_FP16", "").strip().lower() not in ( 16 | "", 17 | "0", 18 | "no", 19 | "f", 20 | "false", 21 | ) 22 | NUM_TRIALS = int(os.environ.get("NUM_TRIALS", 1)) 23 | 24 | if USE_FP16: 25 | MODEL_PATH /= "fp16" 26 | # extra_args = {"revision": "float16", "torch_dtype": torch.float16} 27 | dtype = torch.float16 28 | else: 29 | # extra_args = {} 30 | dtype = torch.float32 31 | 32 | model = AutoModelForCausalLM.from_pretrained( 33 | MODEL_ID, 34 | low_cpu_mem_usage=True, 35 | ).to(dtype) 36 | 37 | AutoTokenizer.from_pretrained(MODEL_ID).save_pretrained(MODEL_PATH / "hf") 38 | 39 | 40 | def shared_pointers(tensors) -> List: 41 | """Find tensors that share the same data.""" 42 | 43 | ptrs = defaultdict(list) 44 | for k, v in tensors.items(): 45 | ptrs[v.data_ptr()].append(k) 46 | shared = [] 47 | for ptr, names in ptrs.items(): 48 | if len(names) > 1: 49 | shared.append(names) 50 | return shared 51 | 52 | 
53 | def convert_shared_tensors( 54 | pt_filename: Optional[str] = None, state_dict=None 55 | ) -> Dict: 56 | """ 57 | Clone data shared between tensors. 58 | 59 | If no state_dict is given, then it will be read from the model file saved 60 | at pt_filename. 61 | 62 | Shared models are only supported by passing in the state dict. 63 | """ 64 | 65 | if state_dict is None: 66 | loaded = torch.load(pt_filename, map_location="cpu") 67 | state_dict = loaded["state_dict"] 68 | 69 | shared = shared_pointers(state_dict) 70 | for shared_weights in shared: 71 | for name in shared_weights[1:]: 72 | state_dict[name] = state_dict[name].clone() 73 | 74 | # For tensors to be contiguous 75 | state_dict = {k: v.contiguous() for k, v in state_dict.items()} 76 | 77 | return state_dict 78 | 79 | 80 | def save_hf() -> float: 81 | dest = MODEL_PATH / "hf" 82 | dest.mkdir(parents=True, exist_ok=True) 83 | 84 | start = time.time() 85 | model.save_pretrained(MODEL_PATH / "hf") 86 | end = time.time() 87 | 88 | print(f"Huggingface saved the model in {end - start:0.2f}s") 89 | return end - start 90 | 91 | 92 | def save_tzr() -> float: 93 | start = time.time() 94 | serializer = TensorSerializer(MODEL_PATH / "model.tensors") 95 | serializer.write_module(model) 96 | serializer.close() 97 | end = time.time() 98 | 99 | print( 100 | f"Serialized {serializer.total_tensor_bytes} bytes in" 101 | f" {end - start:0.2f}s" 102 | ) 103 | return end - start 104 | 105 | 106 | def save_st() -> float: 107 | sf_filename = MODEL_PATH / "hf" / "model.safetensors" 108 | 109 | start = time.time() 110 | 111 | state_dict = convert_shared_tensors(state_dict=model.state_dict()) 112 | save_file(state_dict, sf_filename, metadata={"format": "pt"}) 113 | 114 | end = time.time() 115 | 116 | print(f"Saved the safetensors file in {end - start:0.2f}s") 117 | return end - start 118 | 119 | 120 | # Huggingface save 121 | hf_times = [save_hf() for _ in range(NUM_TRIALS)] 122 | print("Average huggingface save:", sum(hf_times) / len(hf_times)) 123 | print("~" * 25) 124 | 125 | # Tensorizer save 126 | tzr_times = [save_tzr() for _ in range(NUM_TRIALS)] 127 | print("Average tensorizer serialization:", sum(tzr_times) / len(tzr_times)) 128 | print("~" * 25) 129 | 130 | # Safetensors save 131 | st_times = [save_st() for _ in range(NUM_TRIALS)] 132 | print("Average safetensors serialization:", sum(st_times) / len(st_times)) 133 | -------------------------------------------------------------------------------- /examples/benchmark_buffer_size/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ghcr.io/coreweave/ml-containers/torch:9faf4d7-base-cuda12.0.1-torch2.0.1-vision0.15.2-audio2.0.2 2 | 3 | RUN apt-get -qq update && \ 4 | apt-get install -y redis-server lighttpd && \ 5 | apt-get clean 6 | RUN mkdir /app 7 | WORKDIR /app 8 | COPY tensorizer /app/tensorizer 9 | COPY requirements.txt /app/tensorizer 10 | RUN pip3 install -r /app/tensorizer/requirements.txt 11 | ADD examples/benchmark_buffer_size/lighttpd.conf /app/lighttpd.conf 12 | ADD examples/benchmark_buffer_size/benchmark.py /app/benchmark.py 13 | ENTRYPOINT ["/bin/bash", "-c", "redis-server --daemonize yes >/dev/null & lighttpd -f /app/lighttpd.conf & python /app/benchmark.py" ] 14 | -------------------------------------------------------------------------------- /examples/benchmark_buffer_size/README.md: -------------------------------------------------------------------------------- 1 | Buffer Size Benchmarking Framework 2 | 
================================== 3 | This package contains a benchmarking framework for testing the performance of 4 | `tensorizer` with different transport layers and buffer sizes. Currently, the 5 | script tests the following: 6 | 7 | * `redis` transport layer using raw TCP socket (`RedisStreamFile`) 8 | * `https` transport layer using `curl` external process (`CURLStreamFile`) 9 | * `http` transport layer using `curl` external process (`CURLStreamFile`) 10 | * `s3` transport layer which computes authentication and uses `curl` external 11 | process (`CURLStreamFile`) 12 | 13 | It iterates through different buffer sizes, over the range given by `--begin` and 14 | `--end` in powers of 2. For each buffer size, it runs the benchmark for all the 15 | transport layers. 16 | 17 | The `buffer_size` has different semantics depending on the transport layer. For 18 | Redis, it's the TCP socket buffer size. For `https`, `http`, and `s3`, it's the 19 | Python buffer size to the `curl` external process. 20 | 21 | By default, the `redis` tests are targeted to `localhost` on port `6379`. The 22 | pod definition automatically starts a Redis server on the same pod. We load the 23 | model into the Redis server from the `tensorized` S3 bucket. 24 | 25 | Running the Benchmark 26 | --------------------- 27 | You should be able to run the benchmark by invoking `kubectl apply -f benchmark.yaml` 28 | from this directory. This will start a Kubernetes Job that runs the benchmark across 29 | 10 pods. You can change the number of pods by changing the `parallelism` field in 30 | `benchmark.yaml`. 31 | 32 | To look at the benchmark results, you can run `kubectl logs --tail=-1 -l job-name==tensorizer-benchmark-read-size` 33 | which will collect the logs from all the pods and print them out. You can also 34 | look at the logs for individual pods by running `kubectl logs <pod-name>`. 35 | 36 | Parameterizing the Benchmark 37 | ---------------------------- 38 | If you want to test against an external Redis server, you can uncomment the 39 | `--redis` line. We provide a Helm chart in `redis-server.yaml` to deploy a 40 | Redis server in your namespace. You can install it by running `helm install 41 | redis-server redis-server.yaml`. 42 | 43 | If you want to test against a different model in the `tensorized` bucket, 44 | you can provide the `--model` flag. Please note that models larger than 2.7-3B 45 | parameters require the container's GPU spec to be increased to a card with more 46 | than 8 GB of memory. 47 | 48 | Depending on where you deploy your application, you may want to change the 49 | region affinity under `- key: topology.kubernetes.io/region`. This will 50 | ensure that the pods are scheduled in the same region as your application.
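51 | 
52 | Running Locally
53 | ---------------
54 | For quick experiments outside Kubernetes, `benchmark.py` can also be invoked
55 | directly, provided a local Redis server and lighttpd instance are already
56 | running (the container image's entrypoint starts both). The sketch below is
57 | only illustrative: the flag names come from this README and `benchmark.yaml`,
58 | but the example values, and the assumption that `--begin`/`--end` take buffer
59 | sizes in bytes swept in powers of 2, are not authoritative; see `benchmark.py`
60 | for the exact semantics.
61 | 
62 | ```bash
63 | # Hypothetical local invocation; all values are illustrative only.
64 | python benchmark.py --json \
65 |     --begin 262144 --end 8388608 \
66 |     --model EleutherAI/gpt-neo-125M
67 | ```
68 | 
69 | To point the Redis test at an external server instead of the local one, append
70 | `--redis redis://redis-master:6379` (the same flag that is commented out in
71 | `benchmark.yaml`).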
-------------------------------------------------------------------------------- /examples/benchmark_buffer_size/benchmark.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: batch/v1 2 | kind: Job 3 | metadata: 4 | name: tensorizer-benchmark-read-size 5 | spec: 6 | parallelism: 10 7 | completions: 100 8 | template: 9 | spec: 10 | affinity: 11 | nodeAffinity: 12 | requiredDuringSchedulingIgnoredDuringExecution: 13 | nodeSelectorTerms: 14 | - matchExpressions: 15 | - key: ethernet.coreweave.cloud/speed 16 | operator: In 17 | values: 18 | - 10G 19 | - 40G 20 | - 100G 21 | - key: topology.kubernetes.io/region 22 | operator: In 23 | values: 24 | - LGA1 25 | - LAS1 26 | - ORD1 27 | containers: 28 | - name: benchmark 29 | image: ghcr.io/coreweave/tensorizer:benchmark-b7459e6 30 | imagePullPolicy: IfNotPresent 31 | command: [ "/bin/bash" ] 32 | args: [ "-c", "redis-server --daemonize yes >/dev/null & lighttpd -f /app/lighttpd.conf & python /app/benchmark.py --json" ] 33 | #--redis redis://redis-master:6379" ] 34 | resources: 35 | requests: 36 | cpu: "8" 37 | memory: 64Gi 38 | nvidia.com/gpu: "1" 39 | limits: 40 | cpu: "8" 41 | memory: 64Gi 42 | nvidia.com/gpu: "1" 43 | env: 44 | - name: PYTHONUNBUFFERED 45 | value: "1" 46 | - name: K8S_NODE_NAME 47 | valueFrom: 48 | fieldRef: 49 | fieldPath: spec.nodeName 50 | - name: K8S_POD_NAME 51 | valueFrom: 52 | fieldRef: 53 | fieldPath: metadata.name 54 | #- name: K8S_POD_REGION 55 | # valueFrom: 56 | # fieldRef: 57 | # fieldPath: metadata.labels.topology.kubernetes.io/region 58 | #- name: K8S_LINK_SPEED 59 | # valueFrom: 60 | # fieldRef: 61 | # fieldPath: metadata.labels.ethernet.coreweave.cloud/speed 62 | topologySpreadConstraints: 63 | - maxSkew: 1 64 | topologyKey: kubernetes.io/hostname 65 | whenUnsatisfiable: DoNotSchedule 66 | labelSelector: 67 | matchLabels: 68 | job-name: tensorizer-benchmark-read-size 69 | - maxSkew: 1 70 | topologyKey: topology.kubernetes.io/region 71 | whenUnsatisfiable: ScheduleAnyway 72 | labelSelector: 73 | matchLabels: 74 | job-name: tensorizer-benchmark-read-size 75 | - maxSkew: 1 76 | topologyKey: kubernetes.io/instance-type 77 | whenUnsatisfiable: ScheduleAnyway 78 | labelSelector: 79 | matchLabels: 80 | job-name: tensorizer-benchmark-read-size 81 | restartPolicy: OnFailure 82 | -------------------------------------------------------------------------------- /examples/benchmark_buffer_size/lighttpd.conf: -------------------------------------------------------------------------------- 1 | server.document-root = "/app" 2 | server.port = 3000 3 | -------------------------------------------------------------------------------- /examples/benchmark_buffer_size/redis-server.yaml: -------------------------------------------------------------------------------- 1 | # Master pod spec 2 | master: &sharedConfig 3 | # Only use nodes from ORD1 4 | affinity: 5 | nodeAffinity: 6 | requiredDuringSchedulingIgnoredDuringExecution: 7 | nodeSelectorTerms: 8 | - matchExpressions: 9 | - key: topology.kubernetes.io/region 10 | operator: In 11 | values: 12 | - ORD1 13 | - key: ethernet.coreweave.cloud/speed 14 | operator: In 15 | values: 16 | - 40G 17 | # Set limits 18 | resources: 19 | limits: 20 | cpu: "4" 21 | memory: 49Gi 22 | # Persistent volume claim 23 | persistence: 24 | storageClass: block-hdd-ord1 25 | 26 | # Replica pod spec 27 | replica: *sharedConfig 28 | 29 | # Disable replication for now 30 | # TODO: Move to Redis Sentinel 31 | architecture: standalone 32 | 33 | # Set redis config 
34 | commonConfiguration: | 35 | maxmemory 32768mb 36 | maxmemory-policy allkeys-lru 37 | appendonly no 38 | save "" 39 | 40 | # Disable password auth 41 | auth: 42 | enabled: false 43 | -------------------------------------------------------------------------------- /examples/benchmark_buffer_size/visualizations/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.8.0 2 | numpy>=1.24.2 3 | pandas>=2.1.1 4 | seaborn>=0.13.0 5 | -------------------------------------------------------------------------------- /examples/deserialize-simple.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 5 | 6 | from tensorizer import TensorDeserializer 7 | from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor 8 | 9 | model_ref = "EleutherAI/gpt-j-6B" 10 | # To run this at home, swap this with the line below for a smaller example: 11 | # model_ref = "EleutherAI/gpt-neo-125M" 12 | model_name = model_ref.split("/")[-1] 13 | # Change this to your S3 bucket. 14 | s3_bucket = "bucket" 15 | s3_uri = f"s3://{s3_bucket}/{model_name}.tensors" 16 | 17 | config = AutoConfig.from_pretrained(model_ref) 18 | 19 | # This ensures that the model is not initialized. 20 | with no_init_or_tensor(): 21 | model = AutoModelForCausalLM.from_config(config) 22 | 23 | before_mem = get_mem_usage() 24 | 25 | # Lazy load the tensors from S3 into the model. 26 | start = time.time() 27 | deserializer = TensorDeserializer(s3_uri) 28 | deserializer.load_into_module(model) 29 | end = time.time() 30 | 31 | # Brag about how fast we are. 32 | total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) 33 | duration = end - start 34 | per_second = convert_bytes(deserializer.total_tensor_bytes / duration) 35 | after_mem = get_mem_usage() 36 | deserializer.close() 37 | print(f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s") 38 | print(f"Memory usage before: {before_mem}") 39 | print(f"Memory usage after: {after_mem}") 40 | 41 | # Tokenize and generate 42 | model.eval() 43 | tokenizer = AutoTokenizer.from_pretrained(model_ref) 44 | eos = tokenizer.eos_token_id 45 | input_ids = tokenizer.encode( 46 | "¡Hola! Encantado de conocerte. 
hoy voy a", return_tensors="pt" 47 | ).to("cuda") 48 | 49 | with torch.no_grad(): 50 | output = model.generate( 51 | input_ids, max_new_tokens=50, do_sample=True, pad_token_id=eos 52 | ) 53 | 54 | print(f"Output: {tokenizer.decode(output[0], skip_special_tokens=True)}") 55 | -------------------------------------------------------------------------------- /examples/deserialize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import torch 6 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 7 | 8 | import tensorizer.serialization 9 | from tensorizer import DecryptionParams, TensorDeserializer 10 | from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor 11 | 12 | parser = argparse.ArgumentParser("deserialize") 13 | parser.add_argument("--source", required=True, help="local path or URL") 14 | parser.add_argument("--model-ref", default="EleutherAI/gpt-j-6B") 15 | parser.add_argument("--no-plaid", action="store_true") 16 | parser.add_argument("--lazy-load", action="store_true") 17 | parser.add_argument("--verify-hash", action="store_true") 18 | parser.add_argument("--encryption", action="store_true") 19 | parser.add_argument("--viztracer", action="store_true") 20 | parser.add_argument("--num-readers", type=int, default=None) 21 | 22 | args = parser.parse_args() 23 | 24 | model_ref = args.model_ref 25 | # To run this at home, swap this with the line below for a smaller example: 26 | # model_ref = "EleutherAI/gpt-neo-125M" 27 | model_name = model_ref.split("/")[-1] 28 | 29 | tracer = None 30 | if args.viztracer: 31 | import viztracer 32 | 33 | tracer = viztracer.VizTracer(pid_suffix=True) 34 | 35 | decryption_params = None 36 | if args.encryption: 37 | decryption_params = DecryptionParams.from_string( 38 | os.getenv("SUPER_SECRET_STRONG_PASSWORD") 39 | ) 40 | 41 | config = AutoConfig.from_pretrained(model_ref) 42 | 43 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 44 | # This ensures that the pretrained model weights are not initialized, 45 | # and non-persistent buffers (generated at runtime) are on the correct device. 46 | with torch.device(device), no_init_or_tensor(): 47 | model = AutoModelForCausalLM.from_config(config) 48 | 49 | print(f"Deserializing to {device}:") 50 | before_mem = get_mem_usage() 51 | 52 | 53 | # Lazy load the tensors from S3 into the model. 54 | if tracer is not None: 55 | tracer.start() 56 | start = time.perf_counter() 57 | deserializer = TensorDeserializer( 58 | args.source, 59 | device=device, 60 | plaid_mode=not args.no_plaid, 61 | lazy_load=args.lazy_load, 62 | encryption=decryption_params, 63 | num_readers=args.num_readers, 64 | verify_hash=args.verify_hash, 65 | ) 66 | deserializer.load_into_module(model) 67 | end = time.perf_counter() 68 | if tracer is not None: 69 | tracer.stop() 70 | tracer.save() 71 | after_mem = get_mem_usage() 72 | 73 | # Brag about how fast we are. 
74 | total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) 75 | duration = end - start 76 | per_second = convert_bytes(deserializer.total_tensor_bytes / duration) 77 | deserializer.close() 78 | print(f"Deserialized {total_bytes_str} in {duration:0.2f}s, {per_second}/s") 79 | print(f"Memory usage before: {before_mem}") 80 | print(f"Memory usage after: {after_mem}") 81 | 82 | # Tokenize and generate 83 | model.eval() 84 | tokenizer = AutoTokenizer.from_pretrained(model_ref) 85 | eos = tokenizer.eos_token_id 86 | input_ids = tokenizer.encode( 87 | "¡Hola! Encantado de conocerte. hoy voy a", return_tensors="pt" 88 | ).to(device) 89 | 90 | with torch.no_grad(): 91 | output = model.generate( 92 | input_ids, max_new_tokens=50, do_sample=True, pad_token_id=eos 93 | ) 94 | 95 | print(f"Output: {tokenizer.decode(output[0], skip_special_tokens=True)}") 96 | 97 | if tensorizer.serialization._enable_perf_stats: 98 | perf_stats = tensorizer.serialization._get_perf_stats() 99 | to_device_bytes = perf_stats["tensor_to_device_bytes"] 100 | to_device_secs = perf_stats["tensor_to_device_secs"] 101 | to_device_speed = to_device_bytes / to_device_secs 102 | readinto_bytes = perf_stats["file_readinto_bytes"] 103 | readinto_secs = perf_stats["file_readinto_secs"] 104 | readinto_speed = readinto_bytes / readinto_secs 105 | 106 | print( 107 | f"to CUDA stats: {to_device_bytes} bytes in" 108 | f" {to_device_secs:.3f}s, {convert_bytes(to_device_speed, False)}/s" 109 | ) 110 | print( 111 | f"readinto stats: {readinto_bytes} bytes in" 112 | f" {readinto_secs:.3f}s, {convert_bytes(readinto_speed, False)}/s" 113 | ) 114 | -------------------------------------------------------------------------------- /examples/encrypt_existing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import binascii 3 | from contextlib import ExitStack 4 | 5 | 6 | def parse_args(argv=None) -> argparse.Namespace: 7 | parser = argparse.ArgumentParser( 8 | description=( 9 | "add or remove encryption from an already-tensorized file. See" 10 | " docs/encryption.md for an explanation." 
11 | ) 12 | ) 13 | 14 | subparsers = parser.add_subparsers( 15 | help="whether to add or remove encryption" 16 | ) 17 | 18 | LIMITS_METAVAR = "" 19 | 20 | add_parser = subparsers.add_parser( 21 | "add", description="add encryption to an already-tensorized file" 22 | ) 23 | add_subparsers = add_parser.add_subparsers( 24 | help="key derivation / generation method" 25 | ) 26 | 27 | add_pwhash_parser = add_subparsers.add_parser( 28 | "pwhash", 29 | description=( 30 | "encrypt using a key generated with Argon2id key derivation" 31 | ), 32 | ) 33 | add_pwhash_parser.set_defaults(func=add_pwhash) 34 | add_pwhash_parser.add_argument( 35 | "--keyfile", 36 | type=argparse.FileType("rb"), 37 | required=True, 38 | help=( 39 | "file holding data to process into an encryption key using Argon2id" 40 | ), 41 | ) 42 | add_pwhash_parser.add_argument( 43 | "--no-strip-trailing-newlines", 44 | dest="strip_trailing_newlines", 45 | action="store_false", 46 | default=True, 47 | help="don't strip trailing newlines from the key file", 48 | ) 49 | 50 | add_pwhash_parser.add_argument( 51 | "--opslimit", 52 | metavar=LIMITS_METAVAR, 53 | type=str, 54 | required=True, 55 | help=( 56 | "Argon2id opslimit (CPU time difficulty; param from libsodium's" 57 | " pwhash function)" 58 | ), 59 | ) 60 | add_pwhash_parser.add_argument( 61 | "--memlimit", 62 | metavar=LIMITS_METAVAR, 63 | type=str, 64 | required=True, 65 | help=( 66 | "Argon2id memlimit (RAM difficulty; param from libsodium's pwhash" 67 | " function)" 68 | ), 69 | ) 70 | add_pwhash_parser.add_argument( 71 | "--salt", 72 | type=str, 73 | help=( 74 | "hex representation of a custom 16-byte cryptographic salt to use" 75 | " (randomly generated otherwise)" 76 | ), 77 | ) 78 | 79 | add_exact_key_parser = add_subparsers.add_parser( 80 | "exact", 81 | description="encrypt using an exact 32-byte binary key, unmodified", 82 | ) 83 | add_exact_key_parser.set_defaults(func=add_exact_key) 84 | add_exact_key_parser.add_argument( 85 | "--keyfile", 86 | type=argparse.FileType("rb"), 87 | required=True, 88 | help=( 89 | "file holding exactly 32 bytes of binary data to use verbatim as an" 90 | " encryption key" 91 | ), 92 | ) 93 | 94 | add_random_key_parser = add_subparsers.add_parser( 95 | "random", 96 | description=( 97 | "encrypt using a random 32-byte binary key, and write the key to a" 98 | " file" 99 | ), 100 | ) 101 | add_random_key_parser.set_defaults(func=add_random_key) 102 | add_random_key_parser.add_argument( 103 | "--keyfile", 104 | type=argparse.FileType("wb"), 105 | required=True, 106 | help=( 107 | "file to write 32 bytes of randomly-generated binary data used as" 108 | " an encryption key" 109 | ), 110 | ) 111 | 112 | remove_parser = subparsers.add_parser( 113 | "remove", 114 | description="remove encryption from an already-tensorized file", 115 | ) 116 | remove_subparsers = remove_parser.add_subparsers( 117 | help="key derivation / generation method" 118 | ) 119 | 120 | remove_pwhash_parser = remove_subparsers.add_parser( 121 | "pwhash", 122 | description=( 123 | "decrypt using a key generated with Argon2id key derivation" 124 | ), 125 | ) 126 | remove_pwhash_parser.set_defaults(func=remove_pwhash) 127 | remove_pwhash_parser.add_argument( 128 | "--keyfile", 129 | type=argparse.FileType("rb"), 130 | required=True, 131 | help=( 132 | "file holding data to process into an encryption key using Argon2id" 133 | ), 134 | ) 135 | remove_pwhash_parser.add_argument( 136 | "--no-strip-trailing-newlines", 137 | dest="strip_trailing_newlines", 138 | action="store_false", 139 
| default=True, 140 | help="don't strip trailing newlines from the key file", 141 | ) 142 | 143 | remove_exact_key_parser = remove_subparsers.add_parser( 144 | "exact", 145 | description="decrypt using an exact 32-byte binary key, unmodified", 146 | ) 147 | remove_exact_key_parser.set_defaults(func=remove_exact_key) 148 | remove_exact_key_parser.add_argument( 149 | "--keyfile", 150 | type=argparse.FileType("rb"), 151 | required=True, 152 | help=( 153 | "file holding exactly 32 bytes of binary data to use verbatim as an" 154 | " encryption key" 155 | ), 156 | ) 157 | 158 | for subparser in ( 159 | add_pwhash_parser, 160 | add_exact_key_parser, 161 | add_random_key_parser, 162 | remove_pwhash_parser, 163 | remove_exact_key_parser, 164 | ): 165 | subparser.add_argument( 166 | "--infile", type=str, required=True, help="source file to convert" 167 | ) 168 | subparser.add_argument( 169 | "--outfile", 170 | type=str, 171 | required=True, 172 | help="where to write the resulting converted file", 173 | ) 174 | subparser.add_argument( 175 | "-q", 176 | "--quiet", 177 | action="store_true", 178 | default=False, 179 | help="show less output", 180 | ) 181 | 182 | args = parser.parse_args(argv) 183 | 184 | if args.infile == args.outfile: 185 | parser.error("--infile and --outfile can't be the same") 186 | 187 | if args.func != add_random_key: 188 | try: 189 | args.key = args.keyfile.read() 190 | args.keyfile.close() 191 | except OSError: 192 | parser.error("Provided --keyfile path could not be read") 193 | else: 194 | args.key = None 195 | 196 | exact_key_length = 32 197 | 198 | if args.func in (add_exact_key, remove_exact_key): 199 | if len(args.key) != exact_key_length: 200 | parser.error( 201 | "Invalid key length:" 202 | f" got {len(args.key)} bytes, expected {exact_key_length} bytes" 203 | ) 204 | elif ( 205 | args.func in (add_pwhash, remove_pwhash) 206 | and args.strip_trailing_newlines 207 | ): 208 | args.key = args.key.rstrip(b"\r\n") 209 | 210 | salt_length = 16 211 | 212 | if args.func == add_pwhash: 213 | if args.salt is not None: 214 | if len(args.salt) != salt_length * 2: 215 | parser.error( 216 | f"Invalid --salt length (should be {salt_length} bytes =" 217 | f" {salt_length * 2} hex characters)" 218 | ) 219 | try: 220 | args.salt = binascii.unhexlify(args.salt) 221 | assert len(args.salt) == salt_length 222 | except binascii.Error: 223 | parser.error("Invalid hexadecimal string provided for --salt") 224 | 225 | limit_options = ("SENSITIVE", "MODERATE", "INTERACTIVE", "MIN") 226 | args.opslimit = args.opslimit.upper() 227 | args.memlimit = args.memlimit.upper() 228 | try: 229 | int(args.opslimit) 230 | except ValueError: 231 | if args.opslimit not in limit_options: 232 | parser.error( 233 | "Invalid --opslimit, expected one of " 234 | + ", ".join(limit_options) 235 | + ", or an integer" 236 | ) 237 | try: 238 | int(args.memlimit) 239 | except ValueError: 240 | if args.memlimit not in limit_options: 241 | parser.error( 242 | "Invalid --memlimit, expected one of " 243 | + ", ".join(limit_options) 244 | + ", or an integer" 245 | ) 246 | 247 | return args 248 | 249 | 250 | def get_limit(value, enumeration) -> int: 251 | try: 252 | return int(value) 253 | except ValueError: 254 | value = getattr(enumeration, value, None) 255 | if value is not None: 256 | return value 257 | else: 258 | raise ValueError( 259 | f"Unrecognized limit: {value}, available:" 260 | f" {', '.join(v.name for v in enumeration)}" 261 | ) 262 | 263 | 264 | def add_pwhash(args: argparse.Namespace): 265 | from tensorizer 
import EncryptionParams 266 | 267 | opslimit = get_limit(args.opslimit, EncryptionParams.OpsLimit) 268 | memlimit = get_limit(args.memlimit, EncryptionParams.MemLimit) 269 | salt = args.salt 270 | key: bytes = args.key 271 | encryption_params = EncryptionParams.from_string( 272 | source=key, opslimit=opslimit, memlimit=memlimit, salt=salt 273 | ) 274 | add_encryption(encryption_params, args.infile, args.outfile, not args.quiet) 275 | print("Salt:", binascii.hexlify(encryption_params.salt).decode("ascii")) 276 | 277 | 278 | def add_exact_key(args: argparse.Namespace): 279 | from tensorizer import EncryptionParams 280 | 281 | encryption_params = EncryptionParams(key=args.key) 282 | add_encryption(encryption_params, args.infile, args.outfile, not args.quiet) 283 | 284 | 285 | def add_random_key(args: argparse.Namespace): 286 | from tensorizer import EncryptionParams 287 | 288 | encryption_params = EncryptionParams.random() 289 | args.keyfile.write(encryption_params.key) 290 | args.keyfile.close() 291 | add_encryption(encryption_params, args.infile, args.outfile, not args.quiet) 292 | 293 | 294 | def add_encryption( 295 | encryption_params, in_file: str, out_file: str, show_progress: bool = True 296 | ): 297 | from tensorizer import TensorDeserializer, TensorSerializer, TensorType 298 | 299 | with ExitStack() as cleanup: 300 | serializer = TensorSerializer(out_file, encryption=encryption_params) 301 | cleanup.callback(serializer.close) 302 | deserializer = TensorDeserializer( 303 | in_file, device="cpu", lazy_load=True, verify_hash=True 304 | ) 305 | cleanup.enter_context(deserializer) 306 | count: int = len(deserializer.keys()) 307 | i = 1 308 | for ( 309 | module_idx, 310 | tensor_type, 311 | name, 312 | tensor, 313 | ) in deserializer.read_tensors(): 314 | if show_progress: 315 | print(f"({i} / {count}) Encrypting {name}") 316 | i += 1 317 | tensor_type = TensorType(tensor_type) 318 | serializer.write_tensor(module_idx, name, tensor_type, tensor) 319 | # Release memory 320 | tensor.set_() 321 | del tensor 322 | 323 | 324 | def remove_pwhash(args: argparse.Namespace): 325 | from tensorizer import DecryptionParams 326 | 327 | decryption_params = DecryptionParams.from_string(args.key) 328 | remove_encryption( 329 | decryption_params, args.infile, args.outfile, not args.quiet 330 | ) 331 | 332 | 333 | def remove_exact_key(args: argparse.Namespace): 334 | from tensorizer import DecryptionParams 335 | 336 | decryption_params = DecryptionParams.from_key(args.key) 337 | remove_encryption( 338 | decryption_params, args.infile, args.outfile, not args.quiet 339 | ) 340 | 341 | 342 | def remove_encryption( 343 | decryption_params, in_file: str, out_file: str, show_progress: bool = True 344 | ): 345 | from tensorizer import TensorDeserializer, TensorSerializer, TensorType 346 | 347 | with ExitStack() as cleanup: 348 | serializer = TensorSerializer(out_file) 349 | cleanup.callback(serializer.close) 350 | deserializer = TensorDeserializer( 351 | in_file, 352 | device="cpu", 353 | lazy_load=True, 354 | verify_hash=True, 355 | encryption=decryption_params, 356 | ) 357 | cleanup.enter_context(deserializer) 358 | count: int = len(deserializer.keys()) 359 | i = 1 360 | for ( 361 | module_idx, 362 | tensor_type, 363 | name, 364 | tensor, 365 | ) in deserializer.read_tensors(): 366 | if show_progress: 367 | print(f"({i} / {count}) Decrypting {name}") 368 | i += 1 369 | tensor_type = TensorType(tensor_type) 370 | serializer.write_tensor(module_idx, name, tensor_type, tensor) 371 | # Release memory 372 | 
tensor.set_() 373 | del tensor 374 | 375 | 376 | def main(argv=None): 377 | args = parse_args(argv) 378 | args.func(args) 379 | if not args.quiet: 380 | print("Done") 381 | 382 | 383 | if __name__ == "__main__": 384 | main() 385 | -------------------------------------------------------------------------------- /examples/encryption.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import time 4 | 5 | import torch 6 | from transformers import AutoConfig, AutoModelForCausalLM 7 | 8 | from tensorizer import ( 9 | DecryptionParams, 10 | EncryptionParams, 11 | TensorDeserializer, 12 | TensorSerializer, 13 | ) 14 | from tensorizer.utils import no_init_or_tensor 15 | 16 | model_ref = "EleutherAI/gpt-neo-2.7B" 17 | 18 | 19 | def original_model(ref) -> torch.nn.Module: 20 | return AutoModelForCausalLM.from_pretrained(ref) 21 | 22 | 23 | def empty_model(ref) -> torch.nn.Module: 24 | config = AutoConfig.from_pretrained(ref) 25 | with no_init_or_tensor(): 26 | return AutoModelForCausalLM.from_config(config) 27 | 28 | 29 | # Set a strong string or bytes passphrase here 30 | source: str = os.getenv("SUPER_SECRET_STRONG_PASSWORD", "") or input( 31 | "Source string to create an encryption key: " 32 | ) 33 | 34 | fd, path = tempfile.mkstemp(prefix="encrypted-tensors") 35 | 36 | try: 37 | # Encrypt a model during serialization 38 | encryption_params = EncryptionParams.from_string(source) 39 | 40 | model = original_model(model_ref) 41 | serialization_start = time.monotonic() 42 | 43 | serializer = TensorSerializer(path, encryption=encryption_params) 44 | serializer.write_module(model) 45 | serializer.close() 46 | 47 | serialization_end = time.monotonic() 48 | del model 49 | 50 | # Then decrypt it again during deserialization 51 | decryption_params = DecryptionParams.from_string(source) 52 | 53 | model = empty_model(model_ref) 54 | deserialization_start = time.monotonic() 55 | 56 | deserializer = TensorDeserializer(path, encryption=decryption_params) 57 | deserializer.load_into_module(model) 58 | deserializer.close() 59 | 60 | deserialization_end = time.monotonic() 61 | del model 62 | finally: 63 | os.close(fd) 64 | os.unlink(path) 65 | 66 | 67 | def print_speed(prefix, start, end, size): 68 | mebibyte = 1 << 20 69 | gibibyte = 1 << 30 70 | duration = end - start 71 | rate = size / duration 72 | print( 73 | f"{prefix} {size / gibibyte:.2f} GiB model in {duration:.2f} seconds," 74 | f" {rate / mebibyte:.2f} MiB/s" 75 | ) 76 | 77 | 78 | print_speed( 79 | "Serialized and encrypted", 80 | serialization_start, 81 | serialization_end, 82 | serializer.total_tensor_bytes, 83 | ) 84 | 85 | print_speed( 86 | "Deserialized encrypted", 87 | deserialization_start, 88 | deserialization_end, 89 | deserializer.total_tensor_bytes, 90 | ) 91 | -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers~=4.37.2 2 | diffusers==0.14.0 3 | scipy~=1.10.0 4 | accelerate==0.20.3 # Needed to serialize Mistral-7B 5 | pillow>=10.2.0 # not directly required, pinned to avoid a vulnerability 6 | setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability 7 | wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability -------------------------------------------------------------------------------- /examples/serialize.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from tensorizer import TensorSerializer 3 | from transformers import AutoModelForCausalLM 4 | 5 | model_ref = "EleutherAI/gpt-j-6B" 6 | # For less intensive requirements, swap above with the line below: 7 | # model_ref = "EleutherAI/gpt-neo-125M" 8 | model_name = model_ref.split("/")[-1] 9 | # Change this to your S3 bucket. 10 | s3_bucket = "bucket" 11 | s3_uri = f"s3://{s3_bucket}/{model_name}.tensors" 12 | 13 | model = AutoModelForCausalLM.from_pretrained( 14 | model_ref, 15 | revision="float16", 16 | torch_dtype=torch.float16, 17 | low_cpu_mem_usage=True, 18 | ) 19 | 20 | serializer = TensorSerializer(s3_uri) 21 | serializer.write_module(model) 22 | serializer.close() 23 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tensorizer", 3 | "version": "0.1.0", 4 | "description": "", 5 | "main": "index.js", 6 | "scripts": {}, 7 | "keywords": [], 8 | "author": "CoreWeave", 9 | "license": "MIT", 10 | "dependencies": { 11 | "grpc-tools": "^1.11.2", 12 | "ts-protoc-gen": "^0.15.0" 13 | } 14 | } -------------------------------------------------------------------------------- /proto/requirements.txt: -------------------------------------------------------------------------------- 1 | grpcio==1.48.1 2 | grpcio-tools==1.48.1 3 | protobuf==3.19.5 -------------------------------------------------------------------------------- /proto/tensors.proto: -------------------------------------------------------------------------------- 1 | syntax = 'proto3'; 2 | package tensors; 3 | option go_package = "github.com/coreweave/tensorizer/tensors"; 4 | 5 | enum Dtype { 6 | DT_INVALID = 0; 7 | DT_FLOAT32 = 1; 8 | DT_FLOAT64 = 2; 9 | DT_FLOAT16 = 3; 10 | DT_BFLOAT16 = 4; 11 | DT_COMPLEX32 = 5; 12 | DT_COMPLEX64 = 6; 13 | DT_COMPLEX128 = 7; 14 | DT_UINT8 = 8; 15 | DT_INT8 = 9; 16 | DT_INT16 = 10; 17 | DT_INT32 = 11; 18 | DT_INT64 = 12; 19 | DT_BOOL = 13; 20 | DT_QUINT8 = 14; 21 | DT_QINT8 = 15; 22 | DT_QINT32 = 16; 23 | DT_QUINT4_2 = 17; 24 | } 25 | 26 | enum AttributeType { 27 | AT_PARAMETER = 0; 28 | AT_BUFFER = 1; 29 | } 30 | 31 | message Tensor { 32 | Dtype dtype = 1; 33 | repeated int64 shape = 2; 34 | bytes data = 3; 35 | optional AttributeType attr_type = 4; 36 | } 37 | 38 | message Attribute { 39 | string name = 1; 40 | oneof value { 41 | Module module = 3; 42 | Tensor tensor = 4; 43 | string string = 5; 44 | int64 int64 = 6; 45 | float float = 7; 46 | bool bool = 8; 47 | } 48 | } 49 | 50 | message Module { 51 | string name = 1; 52 | repeated string names = 2; 53 | repeated Attribute attributes = 3; 54 | } 55 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "tensorizer" 3 | dynamic = ["version"] 4 | license = { text = "MIT License" } 5 | keywords = ["tensorizer", "machine learning", "serialization", "tensor", "pytorch"] 6 | authors = [ 7 | { name="CoreWeave" } 8 | ] 9 | description = "A tool for fast PyTorch module, model, and tensor serialization + deserialization." 
10 | readme = "README.md" 11 | requires-python = ">=3.8" 12 | dependencies = [ 13 | "torch>=1.9.0", 14 | "numpy>=1.19.5", 15 | "protobuf>=3.19.5", 16 | "psutil>=5.9.4", 17 | "boto3>=1.26.0", 18 | "redis>=4.5.5", 19 | "hiredis>=2.2.0", 20 | "libnacl>=2.1.0" 21 | ] 22 | classifiers = [ 23 | "Programming Language :: Python :: 3", 24 | "License :: OSI Approved :: MIT License", 25 | "Development Status :: 5 - Production/Stable", 26 | "Operating System :: OS Independent", 27 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 28 | "Topic :: System :: Distributed Computing", 29 | "Topic :: Internet", 30 | "Intended Audience :: Developers" 31 | ] 32 | 33 | [project.urls] 34 | "Homepage" = "https://github.com/coreweave/tensorizer" 35 | "Repository" = "https://github.com/coreweave/tensorizer" 36 | "Changelog" = "https://github.com/coreweave/tensorizer/blob/main/CHANGELOG.md" 37 | 38 | [build-system] 39 | requires = ["setuptools"] 40 | build-backend = "setuptools.build_meta" 41 | 42 | [tool.setuptools] 43 | packages = ["tensorizer", "tensorizer._crypt"] 44 | 45 | [tool.setuptools.package-data] 46 | tensorizer = ["tensors.proto"] 47 | 48 | [tool.setuptools.dynamic.version] 49 | attr = "tensorizer._version.__version__" 50 | 51 | [tool.black] 52 | line-length = 80 53 | target-version = ["py38", "py39", "py310", "py311", "py312"] 54 | preview = true 55 | force-exclude = ''' 56 | ( 57 | ^/ATTIC 58 | | ^/tensors 59 | | .*_pb2.py 60 | ) 61 | ''' 62 | 63 | [tool.isort] 64 | profile = "black" 65 | line_length = 80 66 | src_paths = ["tensorizer/", "tests/", "examples/"] 67 | extend_skip_glob = [ 68 | "ATTIC/*", 69 | "*_pb2.py", 70 | "tensors/*", 71 | "examples/serialize.py", 72 | "examples/deserialize.py" 73 | ] 74 | use_parentheses = true 75 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | numpy>=1.22.2 3 | protobuf>=3.19.5 4 | psutil>=5.9.4 5 | boto3>=1.26.0 6 | redis>=4.5.5 7 | hiredis 8 | libnacl>=2.1.0 9 | setuptools>=65.5.1 # not directly required, pinned by Snyk to avoid a vulnerability 10 | wheel>=0.38.0 # not directly required, pinned by Snyk to avoid a vulnerability 11 | -------------------------------------------------------------------------------- /tensorizer/_NumpyTensor.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Optional, Sequence, Union 2 | 3 | import numpy 4 | import torch 5 | 6 | __all__ = ["_NumpyTensor"] 7 | 8 | 9 | _INTERMEDIATE_MAPPING = { 10 | 1: torch.int8, 11 | 2: torch.int16, 12 | 4: torch.int32, 13 | 8: torch.int64, 14 | } 15 | 16 | # Listing of types from a static copy of: 17 | # tuple( 18 | # dict.fromkeys( 19 | # str(t) 20 | # for t in vars(torch).values() 21 | # if isinstance(t, torch.dtype) 22 | # ) 23 | # ) 24 | _ALL_TYPES = { 25 | f"torch.{t}": v 26 | for t in ( 27 | "uint8", 28 | "int8", 29 | "int16", 30 | "int32", 31 | "int64", 32 | "float16", 33 | "float32", 34 | "float64", 35 | "complex32", 36 | "complex64", 37 | "complex128", 38 | "bool", 39 | "qint8", 40 | "quint8", 41 | "qint32", 42 | "bfloat16", 43 | "quint4x2", 44 | "quint2x4", 45 | ) 46 | if isinstance(v := getattr(torch, t, None), torch.dtype) 47 | } 48 | 49 | # torch types with no numpy equivalents 50 | # i.e. 
the only ones that need to be opaque 51 | # Uses a comprehension to filter out any dtypes 52 | # that don't exist in older torch versions 53 | _ASYMMETRIC_TYPES = { 54 | _ALL_TYPES[t] 55 | for t in { 56 | "torch.bfloat16", 57 | "torch.quint8", 58 | "torch.qint8", 59 | "torch.qint32", 60 | "torch.quint4x2", 61 | "torch.quint2x4", 62 | "torch.complex32", 63 | } 64 | & _ALL_TYPES.keys() 65 | } 66 | 67 | # These types aren't supported yet because they require supplemental 68 | # quantization parameters to deserialize correctly 69 | _UNSUPPORTED_TYPES = { 70 | _ALL_TYPES[t] 71 | for t in { 72 | "torch.quint8", 73 | "torch.qint8", 74 | "torch.qint32", 75 | "torch.quint4x2", 76 | "torch.quint2x4", 77 | } 78 | & _ALL_TYPES.keys() 79 | } 80 | 81 | _DECODE_MAPPING = { 82 | k: v for k, v in _ALL_TYPES.items() if v not in _UNSUPPORTED_TYPES 83 | } 84 | 85 | 86 | class _NumpyTensor(NamedTuple): 87 | data: numpy.ndarray 88 | numpy_dtype: str 89 | torch_dtype: Optional[str] 90 | 91 | @classmethod 92 | def from_buffer( 93 | cls, 94 | numpy_dtype: str, 95 | torch_dtype: Optional[str], 96 | shape_list: Sequence, 97 | buffer: memoryview, 98 | offset: int = 0, 99 | ) -> "_NumpyTensor": 100 | """ 101 | Decodes a raw byte buffer into a `_NumpyTensor` given its numpy dtype, 102 | its torch dtype, and its shape. 103 | 104 | Args: 105 | numpy_dtype: The encoded numpy dtype of the buffer. 106 | torch_dtype: The encoded torch dtype of the buffer. 107 | shape_list: The dimensions of the array represented by the buffer. 108 | buffer: The raw byte buffer containing encoded array data, 109 | as a memoryview. 110 | offset: An optional offset into the buffer to start from. 111 | 112 | Returns: 113 | A `_NumpyTensor` object that can have `.to_tensor()` called on it 114 | to receive a torch.Tensor. 115 | """ 116 | data = numpy.ndarray.__new__( 117 | numpy.memmap, 118 | shape_list, 119 | dtype=cls._decoder_dtype(numpy_dtype), 120 | buffer=buffer, 121 | offset=offset, 122 | ) 123 | return cls(data=data, numpy_dtype=numpy_dtype, torch_dtype=torch_dtype) 124 | 125 | @classmethod 126 | def from_tensor( 127 | cls, tensor: Union[torch.Tensor, torch.nn.Module] 128 | ) -> "_NumpyTensor": 129 | """ 130 | Converts a torch tensor into a `_NumpyTensor`. 131 | May use an opaque dtype for the numpy array stored in 132 | the `data` field if the tensor's torch dtype has no numpy equivalent. 133 | See also: `_NumpyTensor.is_opaque`. 134 | 135 | Args: 136 | tensor: A torch tensor to convert to a `_NumpyTensor`. 137 | 138 | Returns: 139 | A `_NumpyTensor` with a `data` field holding a numpy array, 140 | and `numpy_dtype` and torch_dtype` fields suitable for 141 | record-keeping for serialization and deserialization. 142 | """ 143 | if tensor.dtype in _UNSUPPORTED_TYPES: 144 | raise NotImplementedError( 145 | f"Serialization for {tensor.dtype} is not implemented." 
146 | ) 147 | torch_dtype = str(tensor.dtype) 148 | tensor = tensor.cpu().detach() 149 | 150 | if not cls._is_asymmetric(tensor.dtype): 151 | try: 152 | arr = tensor.numpy() 153 | numpy_dtype = arr.dtype.str 154 | return cls( 155 | data=arr, numpy_dtype=numpy_dtype, torch_dtype=torch_dtype 156 | ) 157 | except TypeError: 158 | # Not a known asymmetric type, but torch can't convert it 159 | # so fall back to storing it as opaque data 160 | pass 161 | 162 | # Replace the dtype with some variety of int and mark as opaque data 163 | size = tensor.element_size() 164 | arr = tensor.view(cls._intermediate_type(size)).numpy() 165 | numpy_dtype = arr.dtype.str.replace("i", "V") 166 | return cls(data=arr, numpy_dtype=numpy_dtype, torch_dtype=torch_dtype) 167 | 168 | @classmethod 169 | def from_array(cls, arr: numpy.ndarray) -> "_NumpyTensor": 170 | """ 171 | Converts a numpy array into a `_NumpyTensor`. 172 | This leaves the data as-is, but finds correct values for 173 | `numpy_dtype` and `torch_dtype`. 174 | 175 | Args: 176 | arr: A numpy array to convert to a `_NumpyTensor`. 177 | 178 | Returns: 179 | A `_NumpyTensor` with `arr` as its `data` field, 180 | and `numpy_dtype` and torch_dtype` fields suitable for 181 | record-keeping for serialization and deserialization. 182 | """ 183 | try: 184 | test_array = numpy.empty((), dtype=arr.dtype) 185 | torch_dtype = torch.from_numpy(test_array).dtype 186 | except TypeError as e: 187 | # If something were serialized with this type, 188 | # it wouldn't be able to be deserialized later. 189 | raise TypeError( 190 | f"Cannot serialize an array with dtype {arr.dtype.name}" 191 | " as a _NumpyTensor." 192 | ) from e 193 | return cls(data=arr, numpy_dtype=arr.dtype.str, torch_dtype=torch_dtype) 194 | 195 | def to_tensor(self) -> torch.Tensor: 196 | """ 197 | Converts a `_NumpyTensor` to a ``torch.Tensor`` and reifies any opaque 198 | data into the correct torch dtype. 199 | 200 | Returns: 201 | A ``torch.Tensor`` referring to the same data as the `data` field, 202 | with a correct torch dtype. 203 | """ 204 | if not self.is_opaque: 205 | return torch.from_numpy(self.data) 206 | 207 | if not self.torch_dtype: 208 | raise ValueError( 209 | "Tried to decode a tensor stored as opaque data, but no" 210 | " torch dtype was specified" 211 | ) 212 | tensor_view = torch.from_numpy(self.data) 213 | return tensor_view.view(self._decode_torch_dtype()) 214 | 215 | @property 216 | def is_opaque(self): 217 | """ 218 | Whether the ``self.data`` numpy array is opaque, 219 | i.e. stored as generic data without a meaningful dtype. 220 | 221 | Returns: 222 | True if ``self.data`` is uninterpretable without conversion 223 | to a tensor via `self.to_tensor()`, False otherwise. 224 | """ 225 | return self._is_opaque(self.numpy_dtype) 226 | 227 | @staticmethod 228 | def _intermediate_type(size: int) -> torch.dtype: 229 | """ 230 | Find a dtype to masquerade as that torch can convert to a numpy array. 231 | 232 | Args: 233 | size: The size of the dtype, in bytes. 234 | 235 | Returns: 236 | A ``torch.dtype`` for a tensor that torch can convert 237 | to a numpy array via ``tensor.numpy()``. 
238 | """ 239 | try: 240 | return _INTERMEDIATE_MAPPING[size] 241 | except KeyError as e: 242 | raise ValueError( 243 | "Cannot create a numpy array with opaque elements of size" 244 | f" {size} bytes" 245 | ) from e 246 | 247 | @classmethod 248 | def _is_opaque(cls, numpy_dtype: str) -> bool: 249 | """ 250 | A check to see if the dtype needs to be swapped while decoding, 251 | based on whether the encoded dtype is in the opaque format 252 | used by this class. 253 | 254 | Args: 255 | numpy_dtype: The numpy dtype, as encoded in a tensorized file. 256 | 257 | Returns: 258 | True if the encoded dtype is opaque, False otherwise. 259 | """ 260 | return numpy.dtype(numpy_dtype).type == numpy.void 261 | 262 | @classmethod 263 | def _is_asymmetric(cls, torch_dtype: torch.dtype) -> bool: 264 | """ 265 | A check to see if the dtype needs to be swapped while encoding, 266 | based on whether numpy has a corresponding dtype or not. 267 | This check is hardcoded, not dynamic, but up to date as of torch 2.0. 268 | 269 | Args: 270 | dtype: The torch dtype to check 271 | 272 | Returns: 273 | True if a class is known not to have a corresponding numpy dtype, 274 | False otherwise. 275 | """ 276 | return torch_dtype in _ASYMMETRIC_TYPES 277 | 278 | @classmethod 279 | def _decoder_dtype(cls, numpy_dtype: str): 280 | """ 281 | Converts an opaque storage numpy dtype generated by this class 282 | into one that numpy can properly decode. 283 | 284 | NB: Even though a dtype like ``numpy.dtype(" torch.dtype: 301 | """ 302 | Parses the `self.torch_dtype` field. 303 | 304 | Returns: An instance of ``torch.dtype`` corresponding to the string 305 | stored in `self.torch_dtype`. 306 | 307 | Raises: 308 | ValueError: If `self.torch_dtype` is not set, is not in the form 309 | "torch.", cannot be found in torch, or refers to 310 | something other than a ``torch.dtype``. 311 | TypeError: If `self.torch_dtype` is not a string. 312 | """ 313 | # Quick route, table lookup for common types 314 | dtype = _DECODE_MAPPING.get(self.torch_dtype) 315 | if dtype is not None: 316 | return dtype 317 | 318 | # Long route using getattr(), any other type 319 | if not self.torch_dtype: 320 | raise ValueError("Cannot decode an empty dtype.") 321 | if not isinstance(self.torch_dtype, str): 322 | raise TypeError("torch_dtype must be a string.") 323 | module, *dtype_name = self.torch_dtype.split(".", 1) 324 | 325 | # Ensure that it's actually "torch.something" 326 | if module != "torch" or len(dtype_name) != 1: 327 | raise ValueError(f"Invalid torch_dtype: {self.torch_dtype}") 328 | 329 | try: 330 | dtype = getattr(torch, dtype_name[0]) 331 | # Ensure that it's a real dtype 332 | if not isinstance(dtype, torch.dtype): 333 | raise TypeError( 334 | "Provided torch_dtype is not an instance of torch.dtype" 335 | f" (type: {type(dtype).__name__})" 336 | ) 337 | except (AttributeError, TypeError) as e: 338 | raise ValueError(f"Invalid torch_dtype: {self.torch_dtype}") from e 339 | 340 | return dtype 341 | -------------------------------------------------------------------------------- /tensorizer/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import serialization, stream_io, utils 2 | from ._version import __version__ 3 | from .serialization import * 4 | 5 | __all__ = [ 6 | *serialization.__all__, 7 | "stream_io", 8 | "utils", 9 | "protobuf", 10 | "tensors_pb2", 11 | ] 12 | -------------------------------------------------------------------------------- /tensorizer/_crypt/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Internal cryptographic library for tensorizer. 3 | These functions are only meant to be used by tensorizer itself, 4 | and are not guaranteed to have a stable interface across versions. 5 | """ 6 | 7 | __all__ = ( 8 | "available", 9 | "ChunkedEncryption", 10 | "Const", 11 | "CryptographyError", 12 | "PWHash", 13 | "random_bytes", 14 | ) 15 | 16 | from ._exceptions import CryptographyError 17 | 18 | try: 19 | from ._encryption import ( 20 | ChunkedEncryption, 21 | Const, 22 | PWHash, 23 | SequentialEncryption, 24 | random_bytes, 25 | ) 26 | 27 | available: bool = True 28 | 29 | 30 | except (OSError, AttributeError): 31 | available: bool = False 32 | 33 | def __getattr__(name): 34 | if name in __all__: 35 | raise RuntimeError( 36 | "Encryption module was not initialized," 37 | " make sure a recent version of libsodium is installed" 38 | ) 39 | raise AttributeError(f"module {__name__!r} has no attribute {name!r}") 40 | -------------------------------------------------------------------------------- /tensorizer/_crypt/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import contextlib 3 | import json 4 | import mmap 5 | import pathlib 6 | import time 7 | from hashlib import sha256 8 | from typing import Optional, Union 9 | 10 | from ._cgroup_cpu_count import effective_cpu_count 11 | from ._encryption import ( 12 | AsymmetricParams, 13 | ChunkedEncryption, 14 | SequentialEncryption, 15 | SymmetricParams, 16 | as_ucstr, 17 | crypto_box_detached, 18 | crypto_box_open_detached, 19 | crypto_secretbox_detached, 20 | crypto_secretbox_open_detached, 21 | ) 22 | 23 | cpu_count: int = effective_cpu_count() 24 | 25 | parser = argparse.ArgumentParser( 26 | description="internal command for testing encryption performance" 27 | ) 28 | parser.add_argument( 29 | "--generate", 30 | metavar="NUMBYTES", 31 | type=int, 32 | nargs="?", 33 | default=None, 34 | help="generate a message to test with this size (default: 256 MiB)", 35 | ) 36 | parser.add_argument( 37 | "-f", 38 | "--file", 39 | type=pathlib.Path, 40 | default=None, 41 | help=( 42 | "encrypt or decrypt a file inplace" 43 | " (requires --crypt-info as well as one of --encrypt or --decrypt)" 44 | ), 45 | ) 46 | parser.set_defaults(encrypt=None) 47 | parser.add_argument( 48 | "--encrypt", action="store_true", help="encrypt a file (requires --file)" 49 | ) 50 | parser.add_argument( 51 | "--decrypt", 52 | action="store_false", 53 | dest="encrypt", 54 | help="decrypt a file (requires --file)", 55 | ) 56 | parser.add_argument( 57 | "-i", 58 | "--crypt-info", 59 | type=pathlib.Path, 60 | default=None, 61 | help="path to store and retrieve key, nonces, and MACs (requires --file)", 62 | ) 63 | parser.add_argument( 64 | "--chunk-size", 65 | type=int, 66 | default=2 << 20, 67 | help="chunk size for parallel encryption", 68 | ) 69 | parser.add_argument( 70 | "-d", 71 | "--delayed", 72 | action="store_true", 73 | help="enable delayed verification", 74 | ) 75 | parser.add_argument( 76 | "-t", 77 | "--threads", 78 | type=int, 79 | default=cpu_count, 80 | help=( 81 | 
"maximum number of threads for parallel encryption" 82 | f" (default: {cpu_count})" 83 | ), 84 | ) 85 | args = parser.parse_args() 86 | 87 | if args.file is not None: 88 | if args.generate is not None: 89 | parser.error("Cannot specify both --file and --generate") 90 | if args.encrypt is None: 91 | parser.error("Must specify either --encrypt or --decrypt with --file") 92 | if args.crypt_info is None: 93 | parser.error("Must specify --crypt-info with --file") 94 | if not args.encrypt: 95 | if not args.file.is_file(): 96 | parser.error("--file path is not a file and cannot be decrypted") 97 | if not args.crypt_info.is_file(): 98 | parser.error("--crypt-info path is not a file") 99 | else: 100 | if (args.crypt_info, args.encrypt) != (None, None): 101 | parser.error( 102 | "--crypt-info, --encrypt, and --decrypt are only valid" 103 | " when used with --file" 104 | ) 105 | if args.generate is None: 106 | args.generate = 256 << 20 107 | 108 | 109 | class Timer(contextlib.AbstractContextManager): 110 | def __init__(self): 111 | self.start: Optional[float] = None 112 | self.end: Optional[float] = None 113 | self.elapsed: Optional[float] = None 114 | 115 | def __enter__(self): 116 | self.start = time.monotonic() 117 | return super().__enter__() 118 | 119 | def __exit__(self, __exc_type, __exc_value, __traceback): 120 | self.end = time.monotonic() 121 | self.elapsed = self.end - self.start 122 | 123 | def rate(self, count) -> Union[float, str]: 124 | return "\u221e" if self.elapsed == 0 else count / self.elapsed 125 | 126 | 127 | def num_chunks(total_size: int, chunk_size: int) -> int: 128 | return (total_size // chunk_size) + (total_size % chunk_size != 0) 129 | 130 | 131 | def preview(text, length: int, as_hex: bool = False): 132 | with memoryview(text) as mv: 133 | truncated = mv[:length].hex() if as_hex else str(bytes(mv[:length])) 134 | return truncated + ("..." 
if len(mv) > length else "") 135 | 136 | 137 | def sequential_test( 138 | buffer: Union[bytearray, memoryview], asymmetric: bool = False 139 | ): 140 | if asymmetric: 141 | params = AsymmetricParams.random() 142 | key_views = {"pk": as_ucstr(params.pk), "sk": as_ucstr(params.sk)} 143 | encrypt = crypto_box_detached 144 | decrypt = crypto_box_open_detached 145 | else: 146 | params = SymmetricParams.random() 147 | key_views = {"k": as_ucstr(params.k)} 148 | encrypt = crypto_secretbox_detached 149 | decrypt = crypto_secretbox_open_detached 150 | buffer_view = as_ucstr(buffer) 151 | mac_view = as_ucstr(params.mac) 152 | nonce_view = as_ucstr(params.nonce) 153 | m_len = len(buffer) 154 | 155 | timer = Timer() 156 | mebibyte = 1 << 20 157 | with timer: 158 | encrypt( 159 | c=buffer_view, 160 | mac=mac_view, 161 | m=buffer_view, 162 | mlen=m_len, 163 | n=nonce_view, 164 | **key_views, 165 | ) 166 | 167 | print( 168 | f"Encrypted {m_len} bytes in {timer.elapsed:.6f} seconds," 169 | f" {timer.rate(m_len / mebibyte)} MiB/s", 170 | f"Contents: {preview(buffer, 32, True)}", 171 | f"MAC: {params.mac.hex()}", 172 | sep="\n", 173 | ) 174 | 175 | with timer: 176 | decrypt( 177 | m=buffer_view, 178 | c=buffer_view, 179 | mac=mac_view, 180 | clen=m_len, 181 | n=nonce_view, 182 | **key_views, 183 | ) 184 | 185 | print( 186 | f"Decrypted {m_len} bytes in {timer.elapsed:.6f} seconds," 187 | f" {timer.rate(m_len / mebibyte)} MiB/s", 188 | preview(buffer, 64), 189 | sep="\n", 190 | ) 191 | 192 | 193 | def symmetric_sequential_test( 194 | buffer: Union[bytearray, memoryview], key: Optional[bytes] = None 195 | ): 196 | if key is None: 197 | key = SymmetricParams.random().k 198 | timer = Timer() 199 | mebibyte = 1 << 20 200 | 201 | crypto = SequentialEncryption(key, buffer) 202 | with timer: 203 | crypto.encrypt() 204 | 205 | print( 206 | f"Encrypted {len(buffer)} bytes" 207 | f" in {timer.elapsed:.6f} seconds," 208 | f" {timer.rate(len(buffer) / mebibyte)} MiB/s", 209 | f"Contents: {preview(buffer, 32, True)}", 210 | f"MACs: {preview(crypto.mac, 32, True)}", 211 | sep="\n", 212 | ) 213 | 214 | with timer: 215 | crypto.decrypt() 216 | 217 | print( 218 | f"Decrypted {len(buffer)} bytes" 219 | f" in {timer.elapsed:.6f} seconds," 220 | f" {timer.rate(len(buffer) / mebibyte)} MiB/s", 221 | preview(buffer, 64), 222 | sep="\n", 223 | ) 224 | 225 | 226 | def parallel_test( 227 | buffer: Union[bytearray, memoryview], 228 | chunk_size: int, 229 | automatic_verification: bool, 230 | num_threads: int, 231 | key: Optional[bytes] = None, 232 | ): 233 | if key is None: 234 | key = SymmetricParams.random().k 235 | 236 | timer = Timer() 237 | mebibyte = 1 << 20 238 | 239 | with ChunkedEncryption( 240 | key, 241 | buffer, 242 | chunk_size, 243 | num_threads=num_threads, 244 | automatic_verification=automatic_verification, 245 | ) as crypto: 246 | with timer: 247 | crypto.encrypt_all(wait=True, timeout=None) 248 | 249 | macs = crypto.concatenated_macs() 250 | 251 | print( 252 | f"Encrypted {len(buffer)} bytes" 253 | f" in {timer.elapsed:.6f} seconds," 254 | f" {timer.rate(len(buffer) / mebibyte)} MiB/s", 255 | f"Contents: {preview(buffer, 32, True)}", 256 | f"MACs: {preview(macs, 32, True)}", 257 | sep="\n", 258 | ) 259 | 260 | with timer: 261 | crypto.decrypt_all(wait=True, timeout=None) 262 | 263 | print( 264 | f"Decrypted {len(buffer)} bytes" 265 | f" in {timer.elapsed:.6f} seconds," 266 | f" {timer.rate(len(buffer) / mebibyte)} MiB/s", 267 | preview(buffer, 64), 268 | sep="\n", 269 | ) 270 | 271 | 272 | def 
parallel_transform_file( 273 | encrypt: bool, 274 | path: pathlib.Path, 275 | crypt_info_path: pathlib.Path, 276 | chunk_size: int, 277 | automatic_verification: bool, 278 | num_threads: int, 279 | ): 280 | with contextlib.ExitStack() as context: 281 | file = context.enter_context(path.open("rb+")) 282 | buffer = context.enter_context(mmap.mmap(file.fileno(), 0)) 283 | context.callback(buffer.flush) 284 | num_bytes = len(buffer) 285 | context = context.pop_all() 286 | if encrypt: 287 | params = SymmetricParams.random() 288 | key = params.k 289 | nonces = None 290 | macs = None 291 | else: 292 | with crypt_info_path.open("rb") as crypt_info_file: 293 | json_params = json.load(crypt_info_file) 294 | key = bytes.fromhex(json_params["key"]) 295 | nonces = tuple(map(bytes.fromhex, json_params["nonces"])) 296 | macs = tuple(map(bytes.fromhex, json_params["macs"])) 297 | 298 | timer = Timer() 299 | mebibyte = 1 << 20 300 | 301 | with context, ChunkedEncryption( 302 | key, 303 | buffer, 304 | chunk_size, 305 | nonces=nonces, 306 | macs=macs, 307 | num_threads=num_threads, 308 | automatic_verification=automatic_verification, 309 | ) as crypto: 310 | try: 311 | if encrypt: 312 | with timer: 313 | crypto.encrypt_all(wait=True, timeout=None) 314 | else: 315 | with timer: 316 | crypto.decrypt_all(wait=True, timeout=None) 317 | macs = crypto.macs 318 | del crypto 319 | finally: 320 | import gc 321 | 322 | gc.collect() 323 | 324 | if encrypt: 325 | with crypt_info_path.open("w") as crypt_info_file: 326 | json_params = { 327 | "key": key.hex(), 328 | "nonces": [n.hex() for n in nonces], 329 | "macs": [m.hex() for m in macs], 330 | } 331 | json.dump(json_params, crypt_info_file) 332 | 333 | print( 334 | f"{'Encrypted' if encrypt else 'Decrypted'}" 335 | f" {num_bytes} bytes in {timer.elapsed:.6f} seconds" 336 | f" {timer.rate(num_bytes / mebibyte)} MiB/s" 337 | ) 338 | 339 | 340 | def run_tests( 341 | message: bytearray, 342 | chunk_size: int, 343 | threads: int, 344 | automatic_verification: bool, 345 | ): 346 | original_hash = sha256(message).digest() 347 | 348 | print("Asymmetric") 349 | sequential_test(message, asymmetric=True) 350 | 351 | assert sha256(message).digest() == original_hash 352 | 353 | print("\nSymmetric") 354 | sequential_test(message, asymmetric=False) 355 | 356 | assert sha256(message).digest() == original_hash 357 | 358 | key: bytes = SymmetricParams.random().k 359 | 360 | print("\nSymmetric (OO)") 361 | symmetric_sequential_test(message, key) 362 | 363 | assert sha256(message).digest() == original_hash 364 | 365 | print("\nParallel") 366 | parallel_test( 367 | message, 368 | chunk_size=chunk_size, 369 | automatic_verification=automatic_verification, 370 | num_threads=threads, 371 | key=key, 372 | ) 373 | 374 | assert sha256(message).digest() == original_hash 375 | 376 | 377 | if args.generate is not None: 378 | message = bytearray(b"Hello, World!") 379 | message_size = args.generate 380 | message *= num_chunks(message_size, len(message)) 381 | message[message_size:] = b"" 382 | assert len(message) == message_size 383 | run_tests(message, args.chunk_size, args.threads, not args.delayed) 384 | else: 385 | parallel_transform_file( 386 | args.encrypt, 387 | args.file, 388 | args.crypt_info, 389 | chunk_size=args.chunk_size, 390 | automatic_verification=not args.delayed, 391 | num_threads=args.threads, 392 | ) 393 | -------------------------------------------------------------------------------- /tensorizer/_crypt/_cgroup_cpu_count.py: 
-------------------------------------------------------------------------------- 1 | import enum 2 | import os 3 | import pathlib 4 | import sys 5 | from fractions import Fraction 6 | from functools import lru_cache 7 | from typing import Optional, Union 8 | 9 | __all__ = ("effective_cpu_count", "RoundingMode") 10 | 11 | 12 | class RoundingMode(enum.Enum): 13 | UP = 1 14 | DOWN = 2 15 | HALF_EVEN = 3 16 | 17 | 18 | if sys.platform == "linux": 19 | 20 | @lru_cache(maxsize=None) 21 | def _cpu_quota() -> Optional[Fraction]: 22 | cgroup = pathlib.Path("/sys/fs/cgroup") 23 | cgroup_v1 = cgroup / "cpu,cpuacct" 24 | cgroup_v2 = cgroup / "user.slice" / "cpu.max" 25 | try: 26 | if not cgroup.is_dir(): 27 | return None 28 | elif cgroup_v1.is_dir(): 29 | quota, period = ( 30 | (cgroup_v1 / p).read_text() 31 | for p in ("cpu.cfs_quota_us", "cpu.cfs_period_us") 32 | ) 33 | elif cgroup_v2.is_file(): 34 | quota, period = cgroup_v2.read_text().split() 35 | else: 36 | raise OSError() 37 | 38 | if quota == "max": 39 | return None 40 | 41 | q, p = map(int, (quota, period)) 42 | if q > 0 and p > 0: 43 | return Fraction(q, p) 44 | else: 45 | raise ValueError() 46 | except (OSError, ValueError): 47 | return None 48 | 49 | else: 50 | 51 | def _cpu_quota() -> None: 52 | return None 53 | 54 | 55 | def effective_cpu_count( 56 | rounding: Optional[RoundingMode] = RoundingMode.UP, 57 | ) -> Union[int, Fraction]: 58 | if not isinstance(rounding, (RoundingMode, type(None))): 59 | raise TypeError("Invalid type for rounding mode") 60 | quota: Optional[Fraction] = _cpu_quota() 61 | if quota is None: 62 | return os.cpu_count() 63 | if rounding is None: 64 | return quota 65 | else: 66 | if rounding is RoundingMode.UP: 67 | return quota.__ceil__() 68 | elif rounding is RoundingMode.DOWN: 69 | return quota.__floor__() 70 | elif rounding is RoundingMode.HALF_EVEN: 71 | return round(quota) 72 | else: 73 | raise ValueError("Unknown rounding mode") 74 | -------------------------------------------------------------------------------- /tensorizer/_crypt/_exceptions.py: -------------------------------------------------------------------------------- 1 | class CryptographyError(Exception): 2 | pass 3 | -------------------------------------------------------------------------------- /tensorizer/_crypt_info.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import dataclasses 3 | import io 4 | import struct 5 | import typing 6 | import weakref 7 | from functools import partial 8 | from typing import ClassVar, List, Optional, Sequence, Union 9 | 10 | from tensorizer._internal_utils import _unpack_memoryview_from, _variable_read 11 | 12 | 13 | class CryptInfoChunk(abc.ABC): 14 | _chunk_types: ClassVar[ 15 | typing.MutableMapping[int, typing.Type["CryptInfoChunk"]] 16 | ] = weakref.WeakValueDictionary() 17 | chunk_type: ClassVar[int] 18 | _length_segment: ClassVar[struct.Struct] = struct.Struct(" "CryptInfoChunk": 26 | chunk_type = CryptInfoChunk._chunk_type_segment.unpack_from( 27 | buffer, offset 28 | )[0] 29 | return CryptInfoChunk._chunk_types[chunk_type].unpack_from( 30 | buffer, offset + CryptInfoChunk._chunk_type_segment.size 31 | ) 32 | 33 | @abc.abstractmethod 34 | def pack_into(self, buffer, offset: int = 0) -> int: 35 | CryptInfoChunk._chunk_type_segment.pack_into( 36 | buffer, offset, self.chunk_type 37 | ) 38 | return offset + CryptInfoChunk._chunk_type_segment.size 39 | 40 | def pack(self) -> bytes: 41 | buffer = io.BytesIO(bytes(self.size)) 42 | 
self.pack_into(buffer.getbuffer(), 0) 43 | return buffer.getvalue() 44 | 45 | def sized_pack(self) -> bytes: 46 | buffer = io.BytesIO(bytes(self.sized_size)) 47 | self.sized_pack_into(buffer.getbuffer(), 0) 48 | return buffer.getvalue() 49 | 50 | def sized_pack_into(self, buffer, offset: int = 0) -> int: 51 | length_offset = offset 52 | offset += CryptInfoChunk._length_segment.size 53 | ret = self.pack_into(buffer, offset) 54 | CryptInfoChunk._length_segment.pack_into( 55 | buffer, length_offset, ret - offset 56 | ) 57 | return ret 58 | 59 | @property 60 | def sized_size(self) -> int: 61 | return self.size + CryptInfoChunk._length_segment.size 62 | 63 | @property 64 | @abc.abstractmethod 65 | def size(self) -> int: 66 | return CryptInfoChunk._chunk_type_segment.size 67 | 68 | # noinspection PyMethodOverriding 69 | def __init_subclass__( 70 | cls, /, *, chunk_type: Optional[int] = None, **kwargs 71 | ): 72 | super().__init_subclass__(**kwargs) 73 | if chunk_type is not None: 74 | cls.chunk_type = chunk_type 75 | CryptInfoChunk._chunk_types[chunk_type] = cls 76 | 77 | 78 | class KeyDerivationChunk(CryptInfoChunk, abc.ABC, chunk_type=1): 79 | _derivation_methods: ClassVar[ 80 | typing.MutableMapping[int, typing.Type["KeyDerivationChunk"]] 81 | ] = weakref.WeakValueDictionary() 82 | derivation_method: ClassVar[int] 83 | _derivation_method_segment: ClassVar[struct.Struct] = struct.Struct(" "KeyDerivationChunk": 99 | derivation_method = ( 100 | KeyDerivationChunk._derivation_method_segment.unpack_from( 101 | buffer, offset 102 | )[0] 103 | ) 104 | return KeyDerivationChunk._derivation_methods[ 105 | derivation_method 106 | ].unpack_from( 107 | buffer, offset + KeyDerivationChunk._derivation_method_segment.size 108 | ) 109 | 110 | @abc.abstractmethod 111 | def pack_into(self, buffer, offset: int = 0) -> int: 112 | offset = super().pack_into(buffer, offset) 113 | KeyDerivationChunk._derivation_method_segment.pack_into( 114 | buffer, offset, self.derivation_method 115 | ) 116 | return offset + KeyDerivationChunk._derivation_method_segment.size 117 | 118 | @property 119 | @abc.abstractmethod 120 | def size(self) -> int: 121 | return KeyDerivationChunk._derivation_method_segment.size + super().size 122 | 123 | 124 | @dataclasses.dataclass(frozen=True) 125 | class PWHashKeyDerivationChunk(KeyDerivationChunk, derivation_method=1): 126 | opslimit: int 127 | memlimit: int 128 | alg: int 129 | salt: Union[bytes, bytearray, memoryview] 130 | 131 | __slots__ = ("opslimit", "memlimit", "alg", "salt") 132 | 133 | _algorithm_segment: ClassVar[struct.Struct] = struct.Struct( 134 | "<" # Little-endian 135 | "Q" # Opslimit (unsigned long long) 136 | "Q" # Memlimit (size_t) 137 | "i" # Algorithm identifier (int) 138 | ) 139 | _salt_segment_template: ClassVar[str] = ( 140 | " "PWHashKeyDerivationChunk": 152 | opslimit, memlimit, alg = cls._algorithm_segment.unpack_from( 153 | buffer, offset 154 | ) 155 | offset += cls._algorithm_segment.size 156 | salt = cls.read_salt(buffer, offset)[0] 157 | return cls(opslimit=opslimit, memlimit=memlimit, alg=alg, salt=salt) 158 | 159 | def pack_into(self, buffer, offset: int = 0) -> int: 160 | offset = super().pack_into(buffer, offset) 161 | self._algorithm_segment.pack_into( 162 | buffer, offset, self.opslimit, self.memlimit, self.alg 163 | ) 164 | offset += self._algorithm_segment.size 165 | salt_segment = self._salt_segment 166 | salt_segment.pack_into(buffer, offset, len(self.salt), self.salt) 167 | offset += salt_segment.size 168 | return offset 169 | 170 | @property 171 | 
def size(self) -> int: 172 | return self._algorithm_segment.size + 2 + len(self.salt) + super().size 173 | 174 | 175 | @dataclasses.dataclass 176 | class XSalsa20ParallelChunk(CryptInfoChunk, chunk_type=2): 177 | chunk_size: int 178 | nonce: Union[bytes, bytearray, memoryview] 179 | num_macs: int = dataclasses.field(init=False) 180 | macs: Sequence[Union[bytes, bytearray, memoryview]] 181 | 182 | __slots__ = ("chunk_size", "nonce", "macs", "__dict__") 183 | 184 | NONCE_BYTES: ClassVar[int] = 24 185 | MAC_BYTES: ClassVar[int] = 16 186 | CHUNK_QUANTUM: ClassVar[int] = 64 187 | MINIMUM_CHUNK_SIZE: ClassVar[int] = 1024 188 | 189 | _header_segment: ClassVar[struct.Struct] = struct.Struct( 190 | "<" # Little-endian 191 | "Q" # Chunk size 192 | f"{NONCE_BYTES:d}s" # Initial nonce 193 | "Q" # Number of MACs 194 | ) 195 | 196 | _mac_segment: ClassVar[struct.Struct] = struct.Struct(f"<{MAC_BYTES:d}s") 197 | 198 | def __post_init__(self): 199 | if len(self.nonce) != self.NONCE_BYTES: 200 | raise ValueError("Invalid nonce size") 201 | if not ( 202 | isinstance(self.chunk_size, int) 203 | and (self.chunk_size % self.CHUNK_QUANTUM == 0) 204 | and self.chunk_size >= self.MINIMUM_CHUNK_SIZE 205 | ): 206 | raise ValueError("Invalid chunk size") 207 | self.num_macs = len(self.macs) 208 | for mac in self.macs: 209 | if len(mac) != self.MAC_BYTES: 210 | raise ValueError("Invalid MAC size") 211 | 212 | @classmethod 213 | def unpack_from(cls, buffer, offset: int = 0) -> "XSalsa20ParallelChunk": 214 | chunk_size, nonce, num_macs = ( 215 | XSalsa20ParallelChunk._header_segment.unpack_from(buffer, offset) 216 | ) 217 | offset += XSalsa20ParallelChunk._header_segment.size 218 | macs = [] 219 | for i in range(num_macs): 220 | macs.append( 221 | _unpack_memoryview_from( 222 | XSalsa20ParallelChunk._mac_segment.size, buffer, offset 223 | ) 224 | ) 225 | offset += XSalsa20ParallelChunk._mac_segment.size 226 | return cls(chunk_size, nonce, macs) 227 | 228 | def pack_into(self, buffer, offset: int = 0) -> int: 229 | offset = super().pack_into(buffer, offset) 230 | XSalsa20ParallelChunk._header_segment.pack_into( 231 | buffer, offset, self.chunk_size, self.nonce, self.num_macs 232 | ) 233 | offset += XSalsa20ParallelChunk._header_segment.size 234 | for mac in self.macs: 235 | XSalsa20ParallelChunk._mac_segment.pack_into(buffer, offset, mac) 236 | del mac 237 | offset += XSalsa20ParallelChunk._mac_segment.size 238 | return offset 239 | 240 | @property 241 | def size(self) -> int: 242 | return ( 243 | XSalsa20ParallelChunk._header_segment.size 244 | + XSalsa20ParallelChunk._mac_segment.size * self.num_macs 245 | + super().size 246 | ) 247 | 248 | 249 | @dataclasses.dataclass 250 | class XSalsa20SequentialChunk(CryptInfoChunk, chunk_type=3): 251 | nonce: Union[bytes, bytearray, memoryview] 252 | mac: Union[bytes, bytearray, memoryview] 253 | 254 | __slots__ = ("nonce", "mac") 255 | 256 | NONCE_BYTES: ClassVar[int] = 24 257 | MAC_BYTES: ClassVar[int] = 16 258 | 259 | _contents_segment: ClassVar[struct.Struct] = struct.Struct( 260 | "<" # Little-endian 261 | f"{NONCE_BYTES:d}s" # Nonce 262 | f"{MAC_BYTES:d}s" # MAC 263 | ) 264 | 265 | def __post_init__(self): 266 | if len(self.nonce) != self.NONCE_BYTES: 267 | raise ValueError("Invalid nonce size") 268 | if len(self.mac) != self.MAC_BYTES: 269 | raise ValueError("Invalid MAC size") 270 | 271 | @classmethod 272 | def unpack_from(cls, buffer, offset: int = 0) -> "XSalsa20SequentialChunk": 273 | nonce, mac = XSalsa20SequentialChunk._contents_segment.unpack_from( 274 | buffer, 
offset 275 | ) 276 | return cls(nonce, mac) 277 | 278 | def pack_into(self, buffer, offset: int = 0) -> int: 279 | offset = super().pack_into(buffer, offset) 280 | XSalsa20SequentialChunk._contents_segment.pack_into( 281 | buffer, offset, self.nonce, self.mac 282 | ) 283 | return offset + XSalsa20SequentialChunk._contents_segment.size 284 | 285 | @property 286 | def size(self) -> int: 287 | return XSalsa20SequentialChunk._contents_segment.size + super().size 288 | 289 | 290 | @dataclasses.dataclass 291 | class CryptInfo: 292 | num_chunks: int = dataclasses.field(init=False) 293 | chunks: Sequence[CryptInfoChunk] = () 294 | 295 | _length_segment: ClassVar[struct.Struct] = struct.Struct( 296 | " int: 310 | return self._length_segment.size + self.size 311 | 312 | @property 313 | def size(self) -> int: 314 | return self._count_segment.size + sum(c.sized_size for c in self.chunks) 315 | 316 | def find_chunks( 317 | self, 318 | typ: Union[ 319 | typing.Type[CryptInfoChunk], 320 | typing.Tuple[typing.Type[CryptInfoChunk], ...], 321 | ], 322 | ) -> Sequence[CryptInfoChunk]: 323 | return tuple(c for c in self.chunks if isinstance(c, typ)) 324 | 325 | def pack_into(self, buffer, offset: int = 0) -> int: 326 | CryptInfo._count_segment.pack_into(buffer, offset, self.num_chunks) 327 | offset += CryptInfo._count_segment.size 328 | for chunk in self.chunks: 329 | offset = chunk.sized_pack_into(buffer, offset) 330 | return offset 331 | 332 | def sized_pack_into(self, buffer, offset: int = 0) -> int: 333 | length_offset = offset 334 | offset += CryptInfo._length_segment.size 335 | ret = self.pack_into(buffer, offset) 336 | CryptInfo._length_segment.pack_into(buffer, length_offset, ret - offset) 337 | return ret 338 | 339 | @classmethod 340 | def unpack_from(cls, buffer, offset: int = 0) -> "CryptInfo": 341 | num_chunks: int = CryptInfo._count_segment.unpack_from(buffer, offset)[ 342 | 0 343 | ] 344 | offset += CryptInfo._count_segment.size 345 | if num_chunks < 0: 346 | raise ValueError( 347 | "Invalid CryptInfo chunk count, cannot be negative" 348 | ) 349 | chunks: List[CryptInfoChunk] = [] 350 | with memoryview(buffer) as mv: 351 | for i in range(num_chunks): 352 | chunk_size: int = CryptInfo._chunk_length_segment.unpack_from( 353 | buffer, offset 354 | )[0] 355 | offset += CryptInfo._chunk_length_segment.size 356 | chunk_end: int = offset + chunk_size 357 | with mv[offset:chunk_end] as chunk_mv: 358 | # Blocks out-of-bounds accesses 359 | chunks.append(CryptInfoChunk.unpack_from(chunk_mv)) 360 | offset = chunk_end 361 | return cls(chunks) 362 | -------------------------------------------------------------------------------- /tensorizer/_internal_utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import struct 3 | import typing 4 | from typing import Tuple, Union 5 | 6 | _Buffer = Union[bytes, bytearray, memoryview] # type: typing.TypeAlias 7 | 8 | 9 | @dataclasses.dataclass(init=False) 10 | class Chunked: 11 | __slots__ = ("count", "total_size", "chunk_size", "remainder") 12 | count: int 13 | total_size: int 14 | chunk_size: int 15 | remainder: int 16 | 17 | def __init__(self, total_size: int, chunk_size: int): 18 | self.total_size = total_size 19 | self.chunk_size = chunk_size 20 | self.remainder = total_size % chunk_size 21 | self.count = total_size // chunk_size + (self.remainder != 0) 22 | 23 | 24 | def _variable_read( 25 | data: bytes, offset: int = 0, length_fmt: str = "B", data_fmt: str = "s" 26 | ) -> Tuple[Union[memoryview, 
Tuple], int]: 27 | """ 28 | Reads a variable-length field preceded by a length from a buffer. 29 | 30 | Returns: 31 | A tuple of the data read, and the offset in the buffer 32 | following the end of the field. 33 | """ 34 | assert length_fmt in ("B", "H", "I", "Q") 35 | if length_fmt == "B": 36 | length: int = data[offset] 37 | offset += 1 38 | else: 39 | length_struct = struct.Struct("<" + length_fmt) 40 | length: int = length_struct.unpack_from(data, offset)[0] 41 | offset += length_struct.size 42 | if data_fmt == "s": 43 | # When the data is read as bytes, just return a memoryview 44 | end = offset + length 45 | return _unpack_memoryview_from(length, data, offset), end 46 | else: 47 | data_struct = struct.Struct(f"<{length:d}{data_fmt}") 48 | data = data_struct.unpack_from(data, offset) 49 | offset += data_struct.size 50 | return data, offset 51 | 52 | 53 | def _unpack_memoryview_from( 54 | length: int, buffer: _Buffer, offset: int 55 | ) -> memoryview: 56 | # Grabbing a memoryview with bounds checking. 57 | # Bounds checking is normally provided by the struct module, 58 | # but it can't return memoryviews. 59 | with memoryview(buffer) as mv: 60 | end = offset + length 61 | view = mv[offset:end] 62 | if len(view) < length: 63 | view.release() 64 | mv.release() 65 | # Simulate a struct.error message for consistency 66 | raise struct.error( 67 | "unpack_from requires a buffer of at least" 68 | f" {length:d} bytes for unpacking {length:d} bytes at offset" 69 | f" {offset:d} (actual buffer size is {len(buffer):d})" 70 | ) 71 | return view 72 | -------------------------------------------------------------------------------- /tensorizer/_linear_partition.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import sys 3 | from typing import Iterable, List, Sequence, Tuple 4 | 5 | __all__ = ("partition",) 6 | 7 | 8 | def partition( 9 | weights: Sequence[int], 10 | partitions: int, 11 | performance_threshold: int = 100, 12 | ) -> Iterable[slice]: 13 | """ 14 | Partitions a sequence of weights into slices with balanced sums, 15 | without changing the ordering of elements. 16 | Balancing minimizes the largest sum of any resulting slice. 17 | Args: 18 | weights: Element weights to balance. 19 | partitions: The maximum number of slices to return. 20 | May return fewer if there are too few weights. 21 | performance_threshold: Limit on the estimated time that would 22 | be required to calculate an optimal partitioning scheme. 23 | Not an exact measurement, but similar to milliseconds, 24 | with an additional fuzzy bound on memory usage thrown in. 25 | If this threshold is passed, an asymptotically faster, 26 | low-memory greedy approximation is used instead. 27 | 28 | Returns: 29 | An iterable of ``slice`` objects denoting ranges of the original 30 | sequence that belong to each partition. 
31 | 32 | Examples: 33 | Balancing a list of integers:: 34 | 35 | ints = [2, 8, 10, 5, 5] 36 | split = [ints[s] for s in partition(ints, 3)] 37 | assert split == [[2, 8], [10], [5, 5]] 38 | 39 | Balancing a list of strings by length:: 40 | 41 | strings = ["abc", "ABC", "12345", "XYZ", "654321"] 42 | weights = [len(s) for s in strings] 43 | split = [strings[s] for s in partition(weights, 4)] 44 | assert split == [["abc", "ABC"], ["12345"], ["XYZ"], ["654321"]] 45 | """ 46 | n: int = len(weights) 47 | partitions = min(n, partitions) 48 | # The strange formula came from a least-squares fit over testing data; 49 | # it is likely an overestimate, but approximates the time in milliseconds 50 | # to run linear_partition 51 | too_intensive: bool = ( 52 | round( 53 | max( 54 | (partitions * n) / 5000, 55 | n**1.6565 * partitions**0.617 * 1.75e-4, 56 | ) 57 | ) 58 | > performance_threshold 59 | ) 60 | if partitions <= 2 or too_intensive: 61 | return greedy_linear_partition(weights, partitions) 62 | else: 63 | return linear_partition(weights, partitions) 64 | 65 | 66 | def linear_partition( 67 | weights: Sequence[int], partitions: int 68 | ) -> Iterable[slice]: 69 | n: int = len(weights) 70 | partitions = min(partitions, n) 71 | if partitions <= 1: 72 | return (slice(0, n),) 73 | for w in weights: 74 | if w < 0: 75 | raise ValueError("All weights must be non-negative") 76 | prefix_sums: Tuple[int, ...] = tuple( 77 | itertools.accumulate(weights, initial=0) 78 | ) 79 | 80 | inf: int = prefix_sums[-1] + 1 81 | sentinel = (inf, inf) 82 | memo = [sentinel] * ((n + 1) * partitions) 83 | # Key function: end * partitions + preceding_parts 84 | null = (0, 0) 85 | for i in range(partitions): 86 | # When end = 0 87 | memo[i] = null 88 | del null 89 | for i in range(n + 1): 90 | # When preceding_parts = 0 91 | memo[i * partitions] = (0, prefix_sums[i]) 92 | 93 | def find_start(end: int, preceding_parts: int) -> Tuple[int, int]: 94 | key = end * partitions + preceding_parts 95 | cache_hit = memo[key] 96 | if cache_hit is not sentinel: 97 | return cache_hit 98 | 99 | best_weight = inf 100 | best_start = -1 101 | end_sum = prefix_sums[end] 102 | 103 | # Optimization: In general, iterating in reverse will find the best one 104 | # faster for "uniformly shuffled" datasets, since the best segment split 105 | # will likely be closer to (end / preceding_parts) in length. 106 | # There are two more important observations: 107 | # 1. current_weight is monotonically increasing while moving 108 | # right-to-left, because the current segment is expanding 109 | # 2. earlier_weight is monotonically decreasing while moving 110 | # right-to-left, because the preceding segment is shrinking 111 | # 112 | # This means there are two regions that can be completely skipped: 113 | # 1. current_weight is too big (> best_weight) 114 | # - This means the start is too far left 115 | # 2. earlier_weight is too big (> best_weight) 116 | # - This means the start is too far right 117 | # Since best_weight is monotonically decreasing, regions skipped like 118 | # this never need to be revisited. On the other hand, each time 119 | # best_weight updates, more may be eligible to be skipped. 120 | # The sooner good candidates are found for best_weight, 121 | # the sooner the search space will be narrowed. 122 | # This leads to the following strategy: 123 | # 1. Keep track of the left and right boundaries of the search space 124 | # 2. While the rightmost element is a new best, shrink the right by 1 125 | # 3. 
When the rightmost element is not a new best: 126 | # a) Shrink the right by 1, then 127 | # b) Jump half of the remaining search space towards the left 128 | # c) If current_weight is now too big, update the left, 129 | # then jump half of the new search space back towards the right 130 | # d) If earlier_weight is still too big, update the right to the 131 | # jumped-to position, and then go back to step 3 132 | # e) If earlier_weight is not too big, update best_weight, 133 | # jump back to the right, and go back to step 2 134 | # This attempts to find the valid region for current_weight 135 | # and earlier_weight as quickly as possible via a dynamic variant of 136 | # binary search, and then linearly scans through the possibilities. 137 | 138 | left = 0 # first impossible element due to current_weight 139 | # Note: start = 0 is reserved for when preceding_parts = 0 anyway, 140 | # which is always handled by the cache, so setting left = 0 is safe. 141 | right = end - 1 # last possible element due to earlier_weight 142 | start = right 143 | while True: 144 | current_weight = end_sum - prefix_sums[start] 145 | if current_weight < best_weight: 146 | earlier_weight = find_start(start, preceding_parts - 1)[1] 147 | if earlier_weight < best_weight: 148 | best_weight = ( 149 | current_weight 150 | if current_weight > earlier_weight 151 | else earlier_weight 152 | ) 153 | best_start = start 154 | # If this was already the rightmost one, narrow the search 155 | right -= start == right 156 | if right <= left: 157 | break 158 | # Reset to the right end, in case the best was skipped 159 | start = right 160 | else: 161 | # Nothing right of this matters 162 | # Try skipping forward a bit 163 | right = start - 1 164 | dist = right - left 165 | if dist <= 0: 166 | break 167 | else: 168 | start = right - (dist >> 1) 169 | else: 170 | # Overshot, nothing left of this matters 171 | left = start 172 | dist = right - left 173 | if dist <= 0: 174 | break 175 | elif dist == 1: 176 | start = right 177 | else: 178 | start = left + (dist >> 1) 179 | 180 | result = (best_start, best_weight) 181 | memo[key] = result 182 | return result 183 | 184 | if partitions > 900: 185 | old_recursion_limit = sys.getrecursionlimit() 186 | # Loosen the recursion limit by up to about 6500 if needed. 187 | # Since too-high limits can cause the interpreter to crash, 188 | # anything beyond this point is handled on a purely algorithmic level. 189 | sys.setrecursionlimit(old_recursion_limit + min(partitions, 6500) + 10) 190 | else: 191 | old_recursion_limit = None 192 | try: 193 | for parts_before in range(6500, partitions - 1, 6500): 194 | # Limit the stack depth by pre-populating the cache 195 | # for extremely high numbers of partitions (> 6500). 196 | # Despite caching, this can easily take more time than 197 | # the main call, because it computes several extra values 198 | # that would have normally been skipped. 
199 | parts_after: int = partitions - parts_before - 1 200 | # n - parts_after is the closest point to the end that could 201 | # feasibly have parts_after parts after it, skipping impossible 202 | # scenarios like end=n, parts_before=0 203 | for i in range(0, n - parts_after + 1): 204 | find_start(i, parts_before) 205 | 206 | i = n 207 | seq = [n] 208 | for k in range(1, partitions): 209 | i = find_start(i, partitions - k)[0] 210 | if i == 0: 211 | break 212 | seq.append(i) 213 | seq.append(0) 214 | seq.reverse() 215 | finally: 216 | if old_recursion_limit is not None: 217 | sys.setrecursionlimit(old_recursion_limit) 218 | memo.clear() 219 | return tuple(slice(a, b) for a, b in zip(seq, seq[1:])) 220 | 221 | 222 | def greedy_linear_partition( 223 | weights: Sequence[int], partitions: int 224 | ) -> Iterable[slice]: 225 | # Greedy approximation for the linear partitioning problem, adapted from: 226 | # https://www.werkema.com/2021/11/01/an-efficient-solution-to-linear-partitioning/ 227 | # Time complexity: O(len(weights)) 228 | # Space complexity: O(partitions) 229 | # Could have O(1) space if changed to be a generator 230 | partitions = min(len(weights), partitions) 231 | if partitions <= 1: 232 | return (slice(0, len(weights)),) 233 | 234 | # This implementation scales weights by 2 * partitions to avoid fractions 235 | target_size: int = sum(weights) * 2 236 | current_size: int = 0 237 | 238 | groups: List[slice] = [] 239 | start: int = 0 240 | 241 | for end, weight in enumerate(weights): 242 | scaled_weight: int = weight * partitions 243 | current_size += scaled_weight * 2 244 | if current_size > target_size + scaled_weight: 245 | groups.append(slice(start, end)) 246 | start = end 247 | if len(groups) == partitions - 1: 248 | break 249 | current_size -= target_size 250 | 251 | groups.append(slice(start, len(weights))) 252 | return groups 253 | -------------------------------------------------------------------------------- /tensorizer/_syscalls.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import errno 3 | import mmap 4 | 5 | __all__ = ( 6 | "has_fallocate", 7 | "try_fallocate", 8 | "prefault", 9 | ) 10 | 11 | 12 | try: 13 | _libc = ctypes.CDLL(None) 14 | except TypeError: 15 | _libc = ctypes.pythonapi 16 | 17 | _IN: int = 1 18 | 19 | 20 | def _errcheck(result, func, args) -> None: 21 | del args 22 | if result == -1: 23 | err: int = ctypes.get_errno() 24 | str_err: str = errno.errorcode.get(err, "Unknown error") 25 | raise OSError(err, str_err) 26 | elif result == 0: 27 | return None 28 | else: 29 | raise OSError("Unknown return code") 30 | 31 | 32 | def _get_fallocate(): 33 | from ctypes import CFUNCTYPE, c_int, c_longlong 34 | 35 | prototype = CFUNCTYPE( 36 | c_int, 37 | c_int, 38 | c_int, 39 | c_longlong, 40 | c_longlong, 41 | use_errno=True, 42 | ) 43 | paramflags = ( 44 | (_IN, "fd"), 45 | (_IN, "mode"), 46 | (_IN, "offset"), 47 | (_IN, "len"), 48 | ) 49 | 50 | try: 51 | _func = prototype(("fallocate", _libc), paramflags) 52 | except AttributeError: 53 | return None 54 | _func.errcheck = _errcheck 55 | 56 | return _func 57 | 58 | 59 | _fallocate = _get_fallocate() 60 | del _get_fallocate 61 | 62 | 63 | def has_fallocate() -> bool: 64 | """ 65 | Checks if the Linux ``fallocate(2)`` syscall is available. 66 | Returns: ``True`` if ``fallocate(2)`` is available, ``False`` otherwise. 
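    Example:
        A minimal sketch (the ``model.tensors`` file name and 1 GiB size are
        made up for illustration): preallocate space with ``try_fallocate``
        and fall back to a plain resize when ``fallocate(2)`` is unavailable
        or unsupported by the filesystem::

            import os

            fd = os.open("model.tensors", os.O_RDWR | os.O_CREAT, 0o644)
            try:
                if not (has_fallocate() and try_fallocate(fd, 0, 1 << 30)):
                    # No usable fallocate(2); ftruncate only sets the size,
                    # leaving a sparse file instead of allocated blocks.
                    os.ftruncate(fd, 1 << 30)
            finally:
                os.close(fd)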
67 | """ 68 | return _fallocate is not None 69 | 70 | 71 | def try_fallocate( 72 | fd: int, offset: int, length: int, suppress_all_errors: bool = False 73 | ) -> bool: 74 | """ 75 | Calls ``fallocate(2)`` on the given file descriptor `fd` if available, 76 | ignoring some errors if unsuccessful. 77 | 78 | Args: 79 | fd: File descriptor on which to call ``fallocate(2)``. 80 | offset: Starting position of the byte range to allocate. 81 | length: Number of bytes to allocate. 82 | suppress_all_errors: If True, ignore all errors from unsuccessful calls. 83 | Otherwise, only ignores ``EOPNOTSUPP``. 84 | 85 | Returns: ``True`` if fallocate ran successfully, ``False`` otherwise. 86 | Raises: 87 | OSError: If `suppress_all_errors` is ``False`` and the call failed 88 | due to an error other than ``EOPNOTSUPP``. 89 | """ 90 | if _fallocate is None: 91 | return False 92 | try: 93 | _fallocate(fd=fd, mode=0, offset=offset, len=length) 94 | return True 95 | except OSError as e: 96 | if suppress_all_errors or e.errno == errno.EOPNOTSUPP: 97 | return False 98 | else: 99 | raise 100 | 101 | 102 | def _get_madvise(): 103 | from ctypes import CFUNCTYPE, c_int, c_size_t, c_void_p 104 | 105 | prototype = CFUNCTYPE( 106 | c_int, 107 | c_void_p, 108 | c_size_t, 109 | c_int, 110 | use_errno=True, 111 | ) 112 | paramflags = ( 113 | (_IN, "addr"), 114 | (_IN, "length"), 115 | (_IN, "advice"), 116 | ) 117 | 118 | try: 119 | _func = prototype(("madvise", _libc), paramflags) 120 | except AttributeError: 121 | return None 122 | _func.errcheck = _errcheck 123 | 124 | return _func 125 | 126 | 127 | _madvise = _get_madvise() 128 | del _get_madvise 129 | 130 | _madv_populate_write: int = 23 131 | 132 | 133 | def _can_prefault_with_madvise() -> bool: 134 | if _madvise is None or _libc is ctypes.pythonapi: 135 | # If _libc is ctypes.pythonapi then the call would hold the GIL 136 | return False 137 | n: int = mmap.PAGESIZE 138 | private: int = getattr(mmap, "MAP_PRIVATE", 0) 139 | flags = {} if private == 0 else {"flags": private} 140 | with mmap.mmap(-1, n, **flags) as m: 141 | try: 142 | # MADV_POPULATE_WRITE is only available on Linux 5.14 and up 143 | _madvise( 144 | ctypes.byref((ctypes.c_ubyte * n).from_buffer(m)), 145 | n, 146 | _madv_populate_write, 147 | ) 148 | except OSError: 149 | return False 150 | else: 151 | return True 152 | 153 | 154 | if _can_prefault_with_madvise(): 155 | 156 | def prefault(address, length: int): 157 | _madvise(address, length, _madv_populate_write) 158 | 159 | else: 160 | 161 | def prefault(address, length: int): 162 | ctypes.memset(address, 0x00, length) 163 | 164 | 165 | del _can_prefault_with_madvise 166 | -------------------------------------------------------------------------------- /tensorizer/_tensor_path.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import json 3 | import operator 4 | import types 5 | import typing 6 | from typing import Callable, Dict, Iterable, Iterator, List, Tuple, Union 7 | 8 | # Tensor paths are made up of strings (for mapping keys) 9 | # and integers (for array indices) 10 | _TensorPathComponent: "typing.TypeAlias" = Union[str, int] 11 | 12 | 13 | class _TensorPath(tuple): 14 | def serialized_(self) -> bytes: 15 | if self.is_str_: 16 | return self[0].encode("utf-8") 17 | else: 18 | # application/json-seq format 19 | return b"\x1e" + json.dumps( 20 | self, indent=None, ensure_ascii=True, separators=(",", ":") 21 | ).encode("ascii") 22 | 23 | @property 24 | def is_str_(self) -> bool: 25 | return 
len(self) == 1 and isinstance(self[0], str) 26 | 27 | def normalize_(self) -> Union[tuple, str]: 28 | return self[0] if self.is_str_ else tuple(self) 29 | 30 | def __str__(self) -> str: 31 | return str(self.normalize_()) 32 | 33 | def append_(self, other: Union[str, int]) -> "_TensorPath": 34 | if not isinstance(other, (str, int)): 35 | raise TypeError(f"Invalid key type: {other.__class__.__name__!r}") 36 | else: 37 | return self.__class__(self + (other,)) 38 | 39 | def validate_(self) -> None: 40 | if not self: 41 | raise ValueError("Invalid empty tensor path") 42 | for i in self: 43 | if not isinstance(i, (str, int)): 44 | raise TypeError( 45 | "Invalid tensor path component type:" 46 | f" {i.__class__.__name__!r}" 47 | ) 48 | if isinstance(i, int) and i < 0: 49 | raise ValueError( 50 | f"Invalid negative integer tensor path component: {i}" 51 | ) 52 | 53 | @classmethod 54 | def wrap_(cls, value: Union["_TensorPath", tuple, str]) -> "_TensorPath": 55 | if isinstance(value, cls): 56 | return value 57 | elif isinstance(value, tuple): 58 | return cls(value) 59 | else: 60 | return cls((value,)) 61 | 62 | @staticmethod 63 | def _invalid_hook(*_args, **_kwargs): 64 | raise TypeError("Invalid deserialized type") 65 | 66 | @classmethod 67 | def deserialize_(cls, serialized: typing.ByteString) -> "_TensorPath": 68 | if not isinstance(serialized, (bytes, bytearray, memoryview)): 69 | raise TypeError( 70 | "Invalid tensor path: expected byte string," 71 | f" got {serialized.__class__.__name__!r}" 72 | ) 73 | if not serialized: 74 | ret = cls() 75 | ret.validate_() 76 | return ret 77 | is_mv: bool = isinstance(serialized, memoryview) 78 | first_byte: int = serialized[0] 79 | if first_byte == 0x1E: 80 | if ( 81 | len(serialized) < 3 82 | or serialized[1] != 0x5B # "[" 83 | or serialized[-1] != 0x5D # "]" 84 | ): 85 | # Require the form [...] 86 | raise ValueError("Invalid tensor path: non-array json-seq") 87 | if 0x0A in serialized or 0x0D in serialized: 88 | raise ValueError("Illegal newline in json-seq") 89 | try: 90 | deserialized: List[Union[str, int]] = json.loads( 91 | serialized[1:].tobytes() if is_mv else serialized[1:], 92 | object_hook=cls._invalid_hook, 93 | parse_float=cls._invalid_hook, 94 | parse_constant=cls._invalid_hook, 95 | ) 96 | except RecursionError as e: 97 | raise ValueError( 98 | "Cannot deserialize tensor path due to excessive nesting" 99 | ) from e 100 | if not isinstance(deserialized, list): 101 | raise TypeError( 102 | "Invalid deserialized type:" 103 | " expected array as top level object" 104 | ) 105 | ret = cls(deserialized) 106 | ret.validate_() 107 | return ret 108 | else: 109 | if is_mv: 110 | serialized = serialized.tobytes() 111 | return cls((serialized.decode("utf-8"),)) 112 | 113 | 114 | @dataclasses.dataclass 115 | class _TensorPathRegistry: 116 | """ 117 | Tracks tensor paths used so far, throwing an error on prefix conflicts, 118 | and building a prefix tree of layers of the nested structure. 
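    Example:
        A minimal sketch of the intended behavior (the path components below
        are made up for illustration)::

            registry = _TensorPathRegistry()
            registry.register_path(_TensorPath(("decoder", 0, "weight")))
            registry.register_path(_TensorPath(("decoder", 1, "weight")))
            # Registering ("decoder", 0) now raises ValueError, because that
            # path is a prefix of an already-registered leaf; re-registering
            # an existing leaf raises ValueError as well.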
119 | """ 120 | 121 | __slots__ = "_registered_paths" 122 | _registered_paths: dict 123 | 124 | def __init__(self): 125 | self._registered_paths = {} 126 | 127 | def _check_compatible_types(self, path: _TensorPath) -> None: 128 | branch = self._registered_paths 129 | for depth, component in enumerate(path): 130 | if not branch: 131 | break 132 | existing_type = type(next(iter(branch))) 133 | current_type = type(component) 134 | if existing_type is not current_type: 135 | prefix: tuple = path[: depth + 1] 136 | raise ValueError( 137 | "Conflicting tensor paths:" 138 | f" {path.normalize_()} has a different key type" 139 | f" ({current_type.__name__!r}) than existing keys at the" 140 | f" prefix {prefix} ({existing_type.__name__!r})" 141 | ) 142 | if component not in branch: 143 | break 144 | branch = branch[component] 145 | 146 | def register_path(self, path: Union[_TensorPath, str]) -> None: 147 | branch: dict = self._registered_paths 148 | if isinstance(path, str): 149 | path = _TensorPath((path,)) 150 | if not isinstance(path, _TensorPath): 151 | raise TypeError( 152 | f"Invalid tensor path type: {path.__class__.__name__!r}" 153 | ) 154 | if not path: 155 | raise ValueError("Invalid empty tensor path") 156 | self._check_compatible_types(path) 157 | for component in path[:-1]: 158 | branch = branch.setdefault(component, {}) 159 | if not isinstance(branch, dict): 160 | raise ValueError(f"Conflicting tensor paths: {path}, {branch}") 161 | component = path[-1] 162 | if component in branch: 163 | if isinstance(branch[component], dict): 164 | raise ValueError( 165 | f"Conflicting tensor paths: {path.normalize_()} is both" 166 | " a leaf and a prefix of another path" 167 | ) 168 | else: 169 | raise ValueError( 170 | "Conflicting tensor paths:" 171 | f" {path.normalize_()} is used multiple times" 172 | ) 173 | branch[component] = path 174 | 175 | def filter(self, leaf_filter: Callable[[_TensorPath], bool]): 176 | layers = [(self._registered_paths, iter(tuple(self._registered_paths)))] 177 | while layers: 178 | layer, layer_keys = layers[-1] 179 | for k in layer_keys: 180 | v = layer[k] 181 | if isinstance(v, _TensorPath): 182 | # If this is a leaf, check if it needs to be pruned 183 | if not leaf_filter(v): 184 | del layer[k] 185 | else: 186 | # Otherwise, recurse 187 | layers.append((v, iter(tuple(v)))) 188 | break 189 | else: 190 | layers.pop() 191 | 192 | def dict(self) -> dict: 193 | return self._registered_paths 194 | 195 | 196 | def key_value_iterator(obj: Union[typing.Sequence, typing.Mapping]): 197 | if isinstance(obj, typing.Mapping): 198 | for k in obj.keys(): 199 | if not isinstance(k, str): 200 | raise TypeError( 201 | "Invalid key type for state_dict: expected str, got" 202 | f" {k.__class__.__name__!r}" 203 | ) 204 | return iter(obj.items()) 205 | elif isinstance(obj, typing.Sequence): 206 | return enumerate(obj) 207 | else: 208 | raise TypeError( 209 | "Cannot serialize type as part of a state_dict:" 210 | f" {obj.__class__.__name__!r}" 211 | ) 212 | 213 | 214 | _LeafType = typing.TypeVar("_LeafType") 215 | 216 | 217 | def flatten_structure( 218 | leaf_type: typing.Type[_LeafType], 219 | obj: Union[List, typing.Mapping], 220 | prefix: _TensorPath = _TensorPath(), 221 | ) -> Iterable[Tuple[_TensorPath, _LeafType]]: 222 | iters: List[Tuple[_TensorPath, Iterator]] = [ 223 | (prefix, key_value_iterator(obj)) 224 | ] 225 | while iters: 226 | pre, it = iters[-1] 227 | for name, item in it: 228 | path: _TensorPath = pre.append_(name) 229 | if isinstance(item, leaf_type): 230 | yield 
path, item 231 | else: 232 | iters.append((path, key_value_iterator(item))) 233 | break 234 | else: 235 | iters.pop() 236 | 237 | 238 | def restructure( 239 | flat: Dict[_TensorPath, _LeafType], use_dict_proxies: bool = False 240 | ) -> Union[dict, list, types.MappingProxyType]: 241 | for path in flat.keys(): 242 | if len(path) < 1: 243 | raise ValueError("Invalid empty tensor path key") 244 | 245 | # Start reconstructing everything as nested dictionaries 246 | base = {} 247 | for path, tensor in flat.items(): 248 | branch = base 249 | for component in path[:-1]: 250 | branch = branch.setdefault(component, {}) 251 | if not isinstance(branch, dict): 252 | # Key path conflicts should be caught at the metadata 253 | # parsing step, so this is just an extra sanity check 254 | raise RuntimeError(f"Key path conflict for key {path}") 255 | component = path[-1] 256 | if component in branch: 257 | raise RuntimeError(f"Key path conflict for key {path}") 258 | branch[component] = tensor 259 | 260 | # Assign a type to each layer separately 261 | def re_type_layer( 262 | untyped_layer: dict, 263 | ) -> Union[dict, list, types.MappingProxyType]: 264 | if use_dict_proxies: 265 | return types.MappingProxyType(untyped_layer) 266 | is_list = False 267 | for key in untyped_layer: 268 | if isinstance(key, int): 269 | is_list = True 270 | if key < 0: 271 | raise ValueError( 272 | "Illegal negative integer tensor path component" 273 | ) 274 | elif is_list: 275 | raise ValueError( 276 | "Invalid tensor path keys:" 277 | " mixes dict and list on same layer" 278 | ) 279 | if is_list: 280 | # Lists are always ordered by the value of their key indices, 281 | # rather than the order in the file. 282 | list_layer = list(untyped_layer.items()) 283 | list_layer.sort(key=operator.itemgetter(0)) 284 | return [v for _, v in list_layer] 285 | else: 286 | return untyped_layer 287 | 288 | # Track recursive state with a stack. 289 | # Iterators track progress through the keys of each layer. 290 | # Direct iterators over dictionaries (rather than just keys) 291 | # may not be stable while actively mutating the dictionary 292 | # mid-iteration, so the iterators are over a stable copy 293 | # of the keys instead. 
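    # As a concrete illustration (the key names here are made up): with
    # use_dict_proxies=False, an input of
    #   {_TensorPath(("layers", 0)): t0, _TensorPath(("layers", 1)): t1}
    # is rebuilt by the loop below into {"layers": [t0, t1]}, because integer
    # keys on a layer become a list ordered by index, while string keys
    # remain a dict.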
294 | base_iter = iter(tuple(base)) 295 | layers = [(None, None, base, base_iter)] 296 | while layers: 297 | last_layer, last_key, layer, key_iterator = layers[-1] 298 | for k in key_iterator: 299 | next_layer = layer[k] 300 | if isinstance(next_layer, dict): 301 | # Recurse 302 | next_iter = iter(tuple(next_layer)) 303 | layers.append((layer, k, next_layer, next_iter)) 304 | break 305 | else: 306 | # Update the key in the parent (or base) with the corrected type 307 | re_typed = re_type_layer(layer) 308 | if last_layer is None: 309 | base = re_typed 310 | else: 311 | last_layer[last_key] = re_typed 312 | layers.pop() 313 | return base 314 | -------------------------------------------------------------------------------- /tensorizer/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.9.3" 2 | -------------------------------------------------------------------------------- /tensorizer/_wide_pipes.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import functools 3 | import logging 4 | import subprocess 5 | import sys 6 | import threading 7 | 8 | # ============================================================================= 9 | # From `pipe(7)` manpage: 10 | # 11 | # Pipe capacity 12 | # A pipe has a limited capacity. If the pipe is full, then a write(2) will 13 | # block or fail, depending on whether the O_NONBLOCK flag is set (see below). 14 | # Different implementations have different limits for the pipe capacity. 15 | # 16 | # Applications should not rely on a particular capacity: an application should 17 | # be designed so that a reading process consumes data as soon as it is 18 | # available, so that a writing process does not remain blocked. 19 | # 20 | # In Linux versions before 2.6.11, the capacity of a pipe was the same as the 21 | # system page size (e.g., 4096 bytes on i386). Since Linux 2.6.11, the pipe 22 | # capacity is 16 pages (i.e., 65,536 bytes in a system with a page size of 23 | # 4096 bytes). Since Linux 2.6.35, the default pipe capacity is 16 pages, but 24 | # the capacity can be queried and set using the fcntl(2) F_GETPIPE_SZ and 25 | # F_SETPIPE_SZ operations. See fcntl(2) for more information. 26 | # 27 | # ============================================================================= 28 | # From `fcntl(2)` manpage: 29 | # 30 | # Changing the capacity of a pipe 31 | # 32 | # F_SETPIPE_SZ (int; since Linux 2.6.35) 33 | # Change the capacity of the pipe referred to by fd to be at least arg bytes. 34 | # An unprivileged process can adjust the pipe capacity to any value between the 35 | # system page size and the limit defined in /proc/sys/fs/pipe−max−size 36 | # (see proc(5)). Attempts to set the pipe capacity below the page size are 37 | # silently rounded up to the page size. Attempts by an unprivileged process to 38 | # set the pipe capacity above the limit in /proc/sys/fs/pipe−max−size yield the 39 | # error EPERM; a privileged process (CAP_SYS_RESOURCE) can override the limit. 40 | # 41 | # When allocating the buffer for the pipe, the kernel may use a capacity larger 42 | # than arg, if that is convenient for the implementation. (In the current 43 | # implementation, the allocation is the next higher power-of-two page-size 44 | # multiple of the requested size.) The actual capacity (in bytes) that is set 45 | # is returned as the function result. 
46 | # 47 | # Attempting to set the pipe capacity smaller than the amount of buffer space 48 | # currently used to store data produces the error EBUSY. 49 | # 50 | # Note that because of the way the pages of the pipe buffer are employed when 51 | # data is written to the pipe, the number of bytes that can be written may be 52 | # less than the nominal size, depending on the size of the writes. 53 | # 54 | # F_GETPIPE_SZ (void; since Linux 2.6.35) 55 | # Return (as the function result) the capacity of the pipe referred to by fd. 56 | # 57 | # ============================================================================= 58 | # Constant for `F_SETPIPE_SZ`, as Python's `fcntl` module doesn't have this 59 | # defined until Python 3.10. 60 | F_SETPIPE_SZ = 1031 61 | 62 | _logger = logging.getLogger(__name__) 63 | 64 | __all__ = ["get_max_pipe_size", "widen_pipe", "widen_new_pipes"] 65 | 66 | # No-op default implementations 67 | widen_new_pipes = contextlib.nullcontext 68 | 69 | 70 | def widen_pipe(_fileno, _max_size=None): 71 | pass 72 | 73 | 74 | @functools.lru_cache(maxsize=None) 75 | def get_max_pipe_size(): 76 | pipe_buf_sz = 1024 * 1024 77 | if sys.platform != "win32": 78 | # Read our max-fd-size, fall back to 1mb if invalid. 79 | try: 80 | with open("/proc/sys/fs/pipe-max-size", "r") as pipe_file: 81 | pipe_buf_sz = int(pipe_file.read()) 82 | except IOError as e: 83 | _logger.warning( 84 | f"Could not read /proc/sys/fs/pipe-max-size: {e.strerror}" 85 | ) 86 | else: 87 | # Windows has no maximum pipe size, 88 | # so 256 MiB is chosen completely arbitrarily. 89 | pipe_buf_sz = 256 * 1024 * 1024 90 | _logger.debug(f"pipe-max-size: {pipe_buf_sz}") 91 | return pipe_buf_sz 92 | 93 | 94 | if sys.platform != "win32" and sys.platform != "darwin": 95 | # Linux uses fcntl to resize an existing pipe. 96 | import fcntl 97 | 98 | def widen_pipe(fileno, max_size=None): 99 | pipe_buf_sz = get_max_pipe_size() 100 | if max_size is not None and max_size < pipe_buf_sz: 101 | pipe_buf_sz = max_size 102 | try: 103 | fcntl.fcntl(fileno, F_SETPIPE_SZ, pipe_buf_sz) 104 | except OSError as e: 105 | _logger.warning( 106 | f"Couldn't fcntl F_SETPIPE_SZ to {pipe_buf_sz}: {e.strerror}" 107 | ) 108 | 109 | elif sys.platform == "win32": 110 | # Windows cannot change the size of a pipe after creation, 111 | # but it can set one's size during creation, so a context manager 112 | # is used to temporarily modify the creation of all pipes. 113 | _winapi = getattr(subprocess, "_winapi", None) 114 | if _winapi is not None and hasattr(_winapi, "CreatePipe"): 115 | 116 | class _LocalPipeSize(threading.local): 117 | pipe_size = 0 118 | 119 | _local = _LocalPipeSize() 120 | _original_create_pipe = _winapi.CreatePipe 121 | _pipe_routine_swap_mutex = threading.Lock() 122 | _pipe_widening_threads = 0 123 | 124 | def _create_wide_pipe(pipe_attrs, size): 125 | # The subprocess module creates new anonymous pipes on Windows as: 126 | # _winapi.CreatePipe(None, 0) 127 | # Where the first argument is ignored, 128 | # and the second is the pipe size (0 = default, usually 1 page). 129 | # To change this without reimplementing all of subprocess.Popen, 130 | # _winapi.CreatePipe itself is wrapped to override a size of 0 131 | # with a default of our choosing. 132 | # 133 | # This function is thread-safe in the sense that other threads 134 | # creating pipes while this is active will end up with 135 | # unchanged results due to the override value being thread-local. 
136 | return _original_create_pipe( 137 | pipe_attrs, _local.pipe_size if size == 0 else size 138 | ) 139 | 140 | @contextlib.contextmanager 141 | def widen_new_pipes(max_size=None): 142 | global _pipe_widening_threads 143 | # Thread-safe but not re-entrant. 144 | # Thread safety in this function only matters if multiple threads 145 | # try to separately invoke this context manager at the same time, 146 | # which would only happen if multiple CURLStreamFiles were being 147 | # opened in the same process at the same time. It is less important 148 | # than _create_wide_pipe being thread-safe. 149 | _local.pipe_size = get_max_pipe_size() 150 | if max_size is not None and max_size < _local.pipe_size: 151 | _local.pipe_size = max_size 152 | with _pipe_routine_swap_mutex: 153 | _winapi.CreatePipe = _create_wide_pipe 154 | _pipe_widening_threads += 1 155 | try: 156 | yield 157 | finally: 158 | with _pipe_routine_swap_mutex: 159 | _pipe_widening_threads -= 1 160 | if _pipe_widening_threads == 0: 161 | _winapi.CreatePipe = _original_create_pipe 162 | del _local.pipe_size 163 | 164 | else: 165 | _logger.warning("Couldn't increase default pipe size.") 166 | -------------------------------------------------------------------------------- /tensorizer/protobuf.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import BinaryIO 3 | from typing import OrderedDict as OrderedDictType 4 | from typing import Tuple, Union 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | 10 | import tensorizer.tensors_pb2 as tensors_pb 11 | from tensorizer.tensors_pb2 import Tensor as TensorPb 12 | 13 | DtypePbs = { 14 | torch.float32: tensors_pb.DT_FLOAT32, 15 | torch.float64: tensors_pb.DT_FLOAT64, 16 | torch.float16: tensors_pb.DT_FLOAT16, 17 | torch.bfloat16: tensors_pb.DT_BFLOAT16, 18 | torch.complex32: tensors_pb.DT_COMPLEX32, 19 | torch.complex64: tensors_pb.DT_COMPLEX64, 20 | torch.complex128: tensors_pb.DT_COMPLEX128, 21 | torch.uint8: tensors_pb.DT_UINT8, 22 | torch.int8: tensors_pb.DT_INT8, 23 | torch.int16: tensors_pb.DT_INT16, 24 | torch.int32: tensors_pb.DT_INT32, 25 | torch.int64: tensors_pb.DT_INT64, 26 | torch.bool: tensors_pb.DT_BOOL, 27 | torch.quint8: tensors_pb.DT_QUINT8, 28 | torch.qint8: tensors_pb.DT_QINT8, 29 | torch.qint32: tensors_pb.DT_QINT32, 30 | torch.quint4x2: tensors_pb.DT_QUINT4_2, 31 | } 32 | 33 | PbDtypes = { 34 | tensors_pb.DT_FLOAT32: torch.float32, 35 | tensors_pb.DT_FLOAT64: torch.float64, 36 | tensors_pb.DT_FLOAT16: torch.float16, 37 | tensors_pb.DT_BFLOAT16: torch.bfloat16, 38 | tensors_pb.DT_COMPLEX32: torch.complex32, 39 | tensors_pb.DT_COMPLEX64: torch.complex64, 40 | tensors_pb.DT_COMPLEX128: torch.complex128, 41 | tensors_pb.DT_UINT8: torch.uint8, 42 | tensors_pb.DT_INT8: torch.int8, 43 | tensors_pb.DT_INT16: torch.int16, 44 | tensors_pb.DT_INT32: torch.int32, 45 | tensors_pb.DT_INT64: torch.int64, 46 | tensors_pb.DT_BOOL: torch.bool, 47 | tensors_pb.DT_QUINT8: torch.quint8, 48 | tensors_pb.DT_QINT8: torch.qint8, 49 | tensors_pb.DT_QINT32: torch.qint32, 50 | tensors_pb.DT_QUINT4_2: torch.quint4x2, 51 | } 52 | 53 | PbNpyDtypes = { 54 | tensors_pb.DT_FLOAT32: np.float32, 55 | tensors_pb.DT_FLOAT64: np.float64, 56 | tensors_pb.DT_FLOAT16: np.float16, 57 | tensors_pb.DT_BFLOAT16: np.float16, 58 | tensors_pb.DT_COMPLEX32: np.complex64, 59 | tensors_pb.DT_COMPLEX64: np.complex64, 60 | tensors_pb.DT_COMPLEX128: np.complex128, 61 | tensors_pb.DT_UINT8: np.uint8, 62 | 
tensors_pb.DT_INT8: np.int8, 63 | tensors_pb.DT_INT16: np.int16, 64 | tensors_pb.DT_INT32: np.int32, 65 | tensors_pb.DT_INT64: np.int64, 66 | tensors_pb.DT_BOOL: bool, 67 | tensors_pb.DT_QUINT8: np.uint8, 68 | tensors_pb.DT_QINT8: np.int8, 69 | tensors_pb.DT_QINT32: np.int32, 70 | tensors_pb.DT_QUINT4_2: np.uint8, 71 | } 72 | 73 | 74 | def serialize_tensor( 75 | t: Tensor, attribute: tensors_pb.AttributeType = None 76 | ) -> tensors_pb.Tensor: 77 | assert isinstance(t, Tensor) 78 | assert attribute is None or attribute in [ 79 | tensors_pb.AT_PARAMETER, 80 | tensors_pb.AT_BUFFER, 81 | ] 82 | 83 | extra_opts = {} 84 | if attribute is not None: 85 | extra_opts = {"attr_type": attribute} 86 | 87 | return tensors_pb.Tensor( 88 | dtype=DtypePbs[t.dtype], 89 | shape=t.shape, 90 | data=t.cpu().detach().numpy().tobytes(), 91 | **extra_opts, 92 | ) 93 | 94 | 95 | def deserialize_tensor( 96 | t: tensors_pb.Tensor, 97 | ) -> Union[Tensor, Tuple[Tensor, "tensors_pb.AttributeType"]]: 98 | mv = bytearray(t.data) 99 | tensor = torch.as_tensor( 100 | np.ndarray.__new__( 101 | np.memmap, t.shape, dtype=PbNpyDtypes[t.dtype], buffer=mv, offset=0 102 | ) 103 | ) 104 | if t.HasField("attr_type"): 105 | return tensor, t.attr_type 106 | else: 107 | return tensor 108 | 109 | 110 | def serialize_model(model: torch.nn.Module, file_stream: BinaryIO) -> None: 111 | modules = list() 112 | for module_name, module in model.named_modules(): 113 | print(module_name) 114 | attributes = list() 115 | for name, param in module.named_parameters(recurse=False): 116 | v = param.cpu().detach() 117 | param_attr = tensors_pb.Attribute( 118 | name=name, tensor=serialize_tensor(v, tensors_pb.AT_PARAMETER) 119 | ) 120 | attributes.append(param_attr) 121 | for name, buffer in module.named_buffers(recurse=False): 122 | v = buffer.cpu().detach() 123 | buffer_attr = tensors_pb.Attribute( 124 | name=name, tensor=serialize_tensor(v, tensors_pb.AT_BUFFER) 125 | ) 126 | attributes.append(buffer_attr) 127 | module_attr = tensors_pb.Attribute( 128 | name=module_name, module=tensors_pb.Module(attributes=attributes) 129 | ) 130 | modules.append(module_attr) 131 | model_proto = tensors_pb.Module( # models are just modules as attributes 132 | name="", 133 | attributes=modules, 134 | ) 135 | file_stream.write(model_proto.SerializeToString()) 136 | 137 | 138 | def deserialize_model(model: torch.nn.Module, file_stream: BinaryIO) -> None: 139 | model_proto = tensors_pb.Module() 140 | model_proto.ParseFromString(file_stream.read()) 141 | 142 | modules: OrderedDictType[str, torch.nn.Module] = OrderedDict() 143 | for name, module in model.named_modules(): 144 | modules[name] = module 145 | 146 | for module_attr in model_proto.attributes: 147 | module = modules[module_attr.name] 148 | for attr in module_attr.module.attributes: 149 | if attr.tensor.HasField("attr_type"): 150 | if attr.tensor.attr_type == tensors_pb.AT_PARAMETER: 151 | module._parameters[attr.name] = deserialize_tensor( 152 | attr.tensor 153 | )[0] 154 | elif attr.tensor.attr_type == tensors_pb.AT_BUFFER: 155 | module._buffers[attr.name] = deserialize_tensor( 156 | attr.tensor 157 | )[0] 158 | else: 159 | raise ValueError("Unknown attribute type") 160 | -------------------------------------------------------------------------------- /tensorizer/tensors.proto: -------------------------------------------------------------------------------- 1 | ../proto/tensors.proto -------------------------------------------------------------------------------- /tensorizer/tensors_pb2.py: 
-------------------------------------------------------------------------------- 1 | ../tensors/tensors_pb2.py -------------------------------------------------------------------------------- /tensorizer/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import contextvars 3 | import threading 4 | from typing import ( 5 | Callable, 6 | ContextManager, 7 | NamedTuple, 8 | Optional, 9 | TypeVar, 10 | Union, 11 | ) 12 | 13 | import torch 14 | 15 | try: 16 | import resource 17 | except ImportError: 18 | resource = None 19 | 20 | import psutil 21 | 22 | try: 23 | import pynvml 24 | 25 | if not hasattr(pynvml, "nvml"): 26 | pynvml = None 27 | else: 28 | try: 29 | pynvml.nvmlInit() 30 | except pynvml.nvml.NVMLError_LibraryNotFound: 31 | pynvml = None 32 | except ImportError: 33 | pynvml = None 34 | 35 | __all__ = [ 36 | "convert_bytes", 37 | "get_device", 38 | "GlobalGPUMemoryUsage", 39 | "TorchGPUMemoryUsage", 40 | "CPUMemoryUsage", 41 | "MemoryUsage", 42 | "get_mem_usage", 43 | "get_gpu_name", 44 | "no_init_or_tensor", 45 | ] 46 | 47 | 48 | # Silly function to convert to human bytes 49 | def convert_bytes(num, decimal=True) -> str: 50 | """ 51 | Convert bytes to MB, GB, etc. 52 | 53 | Args: 54 | num: Quantity of bytes to format. 55 | decimal: Whether to use decimal or binary units 56 | (e.g. KB = 1000 bytes vs. KiB = 1024 bytes). 57 | 58 | Returns: 59 | A string in the format `` `` (e.g. ``123.4 MB``). 60 | """ 61 | if decimal: 62 | step_unit = 1000.0 63 | units = ("bytes", "KB", "MB", "GB", "TB", "PB") 64 | else: 65 | step_unit = 1024.0 66 | units = ("bytes", "KiB", "MiB", "GiB", "TiB", "PiB") 67 | 68 | for unit in units[:-1]: 69 | if num < step_unit: 70 | break 71 | num /= step_unit 72 | else: 73 | unit = units[-1] 74 | return "%3.1f %s" % (num, unit) 75 | 76 | 77 | def get_device() -> torch.device: 78 | return torch.device("cuda" if torch.cuda.is_available() else "cpu") 79 | 80 | 81 | class GlobalGPUMemoryUsage(NamedTuple): 82 | """Total memory usage statistics across all processes for a single GPU.""" 83 | 84 | total: int 85 | free: int 86 | used: int 87 | 88 | @classmethod 89 | def now(cls, device=None) -> Optional["GlobalGPUMemoryUsage"]: 90 | """ 91 | Capture a snapshot of the current total memory usage on a single GPU. 92 | 93 | Args: 94 | device: The GPU to gather memory statistics for. 95 | If None, the current GPU is used, 96 | as determined by ``torch.cuda.current_device()``. 97 | 98 | Returns: 99 | A tuple of (`total`, `free`, `used`) VRAM in bytes, if possible. 100 | 101 | None if neither PyTorch >=1.10 nor pynvml is available. 
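        Example:
            A minimal sketch (GPU index 0 is assumed to exist)::

                usage = GlobalGPUMemoryUsage.now(0)
                if usage is not None:
                    print(usage)  # e.g. "GPU: (U: 512MiB F: 23,552MiB T: 24,064MiB)"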
102 | """ 103 | if not torch.cuda.is_available(): 104 | return None 105 | 106 | # torch.cuda.mem_get_info() was introduced in PyTorch 1.10 107 | mem_get_info = getattr(torch.cuda, "mem_get_info", None) 108 | if mem_get_info is not None: 109 | free, total = mem_get_info(device) 110 | return cls(total, free, total - free) 111 | elif pynvml is not None: 112 | # Normalize the device to an int 113 | if isinstance(device, (int, str, bytes)): 114 | device = torch.device(device) 115 | if isinstance(device, torch.device): 116 | if device.type == "cpu": 117 | return None 118 | else: 119 | device = device.index 120 | if device is None: 121 | device = torch.cuda.current_device() 122 | nvml_device = pynvml.nvmlDeviceGetHandleByIndex(device) 123 | gpu_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_device) 124 | return cls(gpu_info.total, gpu_info.free, gpu_info.used) 125 | else: 126 | return None 127 | 128 | def __str__(self): 129 | return "GPU: (U: {:,}MiB F: {:,}MiB T: {:,}MiB)".format( 130 | self.used >> 20, self.free >> 20, self.total >> 20 131 | ) 132 | 133 | 134 | class TorchGPUMemoryUsage(NamedTuple): 135 | """Memory usage statistics for PyTorch on a single GPU.""" 136 | 137 | reserved: int 138 | reserved_max: int 139 | used: int 140 | used_max: int 141 | 142 | @classmethod 143 | def now(cls, device=None) -> Optional["TorchGPUMemoryUsage"]: 144 | """ 145 | Capture a snapshot of the current total memory usage on a single GPU. 146 | 147 | Args: 148 | device: The GPU to gather memory statistics for. 149 | If None, the current GPU is used, 150 | as determined by ``torch.cuda.current_device()``. 151 | 152 | Returns: 153 | A tuple of (`reserved`, `reserved_max`, `used`, `used_max`) 154 | memory statistics for PyTorch in bytes, if possible. 155 | 156 | None if CUDA isn't available. 157 | """ 158 | if torch.cuda.is_available(): 159 | stats = torch.cuda.memory.memory_stats(device) 160 | return cls( 161 | stats.get("reserved_bytes.all.current", 0), 162 | stats.get("reserved_bytes.all.peak", 0), 163 | stats.get("allocated_bytes.all.current", 0), 164 | stats.get("allocated_bytes.all.peak", 0), 165 | ) 166 | else: 167 | return None 168 | 169 | def __str__(self): 170 | return "TORCH: (R: {:,}MiB/{:,}MiB, A: {:,}MiB/{:,}MiB)".format( 171 | self.reserved >> 20, 172 | self.reserved_max >> 20, 173 | self.used >> 20, 174 | self.used_max >> 20, 175 | ) 176 | 177 | 178 | class CPUMemoryUsage(NamedTuple): 179 | """Memory usage statistics for CPU RAM.""" 180 | 181 | maxrss: int 182 | free: int 183 | 184 | @classmethod 185 | def now(cls) -> "CPUMemoryUsage": 186 | """ 187 | Capture a snapshot of the current CPU RAM usage. 188 | 189 | Returns: 190 | A tuple of (`maxrss`, `free`) RAM statistics in bytes, 191 | where maxrss is the max resident set size of the current process. 192 | 193 | On Unix, the system call ``getrusage(2)`` is used 194 | to measure the maxrss, so the granularity is 1024 bytes, 195 | but the unit is still bytes. 
196 | """ 197 | if resource is not None: 198 | maxrss = ( 199 | resource.getrusage(resource.RUSAGE_SELF).ru_maxrss 200 | + resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss 201 | ) << 10 202 | else: 203 | process = psutil.Process() 204 | maxrss = process.memory_info().rss + sum( 205 | p.memory_info().rss for p in process.children(True) 206 | ) 207 | vmem = psutil.virtual_memory() 208 | return cls(maxrss, vmem.free) 209 | 210 | def __str__(self): 211 | return "CPU: (maxrss: {:,}MiB F: {:,}MiB)".format( 212 | self.maxrss >> 20, self.free >> 20 213 | ) 214 | 215 | 216 | class MemoryUsage(NamedTuple): 217 | """ 218 | Combined statistics for CPU, total GPU, and PyTorch memory usage. 219 | 220 | Gathers `CPUMemoryUsage`, `GlobalGPUMemoryUsage`, and `TorchGPUMemoryUsage` 221 | together in one tuple. 222 | """ 223 | 224 | cpu: CPUMemoryUsage 225 | gpu: Optional[GlobalGPUMemoryUsage] 226 | torch: Optional[TorchGPUMemoryUsage] 227 | 228 | @classmethod 229 | def now(cls, device=None): 230 | """ 231 | Capture a snapshot of CPU, total GPU, and PyTorch memory usage. 232 | If GPU memory usage statistics are not available, 233 | the `gpu` and `torch` fields of the resulting tuple are None. 234 | 235 | Args: 236 | device: The GPU to gather both total and PyTorch-specific 237 | memory statistics for. If None, the current GPU is used, 238 | as determined by ``torch.cuda.current_device()``. 239 | 240 | Returns: 241 | A tuple of (`cpu`, `gpu`, `torch`) memory statistics. 242 | If GPU memory usage statistics are not available, 243 | the `gpu` and `torch` fields are None. 244 | 245 | See the respective classes, `CPUMemoryUsage`, 246 | `GlobalGPUMemoryUsage`, and `TorchGPUMemoryUsage`, 247 | for more information on each component. 248 | """ 249 | gpu_info = torch_info = None 250 | try: 251 | gpu_info = GlobalGPUMemoryUsage.now(device) 252 | torch_info = TorchGPUMemoryUsage.now(device) 253 | except AssertionError: 254 | pass 255 | return cls(CPUMemoryUsage.now(), gpu_info, torch_info) 256 | 257 | def __str__(self): 258 | return " ".join(str(item) for item in self if item) 259 | 260 | 261 | def get_mem_usage() -> str: 262 | """ 263 | Captures and formats memory usage statistics for the CPU, GPU, and PyTorch. 264 | 265 | Equivalent to ``str(MemoryUsage.now())``. 266 | 267 | Returns: 268 | A formatted string summarizing memory usage 269 | across the CPU, GPU, and PyTorch. 270 | """ 271 | return str(MemoryUsage.now()) 272 | 273 | 274 | def get_gpu_name() -> str: 275 | if torch.cuda.is_available(): 276 | return torch.cuda.get_device_name() 277 | return "N/A" 278 | 279 | 280 | Model = TypeVar("Model") 281 | 282 | 283 | def no_init_or_tensor( 284 | loading_code: Optional[Callable[..., Model]] = None, 285 | ) -> Union[Model, ContextManager]: 286 | """ 287 | Suppress the initialization of weights while loading a model. 288 | 289 | Can either directly be passed a callable containing model-loading code, 290 | which will be evaluated with weight initialization suppressed, 291 | or used as a context manager around arbitrary model-loading code. 292 | 293 | Args: 294 | loading_code: Either a callable to evaluate 295 | with model weight initialization suppressed, 296 | or None (the default) to use as a context manager. 297 | 298 | Returns: 299 | The return value of `loading_code`, if `loading_code` is callable. 300 | 301 | Otherwise, if `loading_code` is None, returns a context manager 302 | to be used in a `with`-statement. 
303 | 304 | Examples: 305 | As a context manager:: 306 | 307 | from transformers import AutoConfig, AutoModelForCausalLM 308 | config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B") 309 | with no_init_or_tensor(): 310 | model = AutoModelForCausalLM.from_config(config) 311 | 312 | Or, directly passing a callable:: 313 | 314 | from transformers import AutoConfig, AutoModelForCausalLM 315 | config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B") 316 | model = no_init_or_tensor(lambda: AutoModelForCausalLM.from_config(config)) 317 | """ 318 | if loading_code is None: 319 | return _NoInitOrTensorImpl.context_manager() 320 | elif callable(loading_code): 321 | with _NoInitOrTensorImpl.context_manager(): 322 | return loading_code() 323 | else: 324 | raise TypeError( 325 | "no_init_or_tensor() expected a callable to evaluate," 326 | " or None if being used as a context manager;" 327 | f' got an object of type "{type(loading_code).__name__}" instead.' 328 | ) 329 | 330 | 331 | class _NoInitOrTensorImpl: 332 | # Implementation of the thread-safe, async-safe, re-entrant context manager 333 | # version of no_init_or_tensor(). 334 | # This class essentially acts as a namespace. 335 | # It is not instantiable, because modifications to torch functions 336 | # inherently affect the global scope, and thus there is no worthwhile data 337 | # to store in the class instance scope. 338 | _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm) 339 | _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES) 340 | _ORIGINAL_EMPTY = torch.empty 341 | 342 | is_active = contextvars.ContextVar( 343 | "_NoInitOrTensorImpl.is_active", default=False 344 | ) 345 | _count_active: int = 0 346 | _count_active_lock = threading.Lock() 347 | 348 | @classmethod 349 | @contextlib.contextmanager 350 | def context_manager(cls): 351 | if cls.is_active.get(): 352 | yield 353 | return 354 | 355 | with cls._count_active_lock: 356 | cls._count_active += 1 357 | if cls._count_active == 1: 358 | for mod in cls._MODULES: 359 | mod.reset_parameters = cls._disable(mod.reset_parameters) 360 | # When torch.empty is called, make it map to meta device by replacing 361 | # the device in kwargs.
362 | torch.empty = cls._meta_empty 363 | reset_token = cls.is_active.set(True) 364 | 365 | try: 366 | yield 367 | finally: 368 | cls.is_active.reset(reset_token) 369 | with cls._count_active_lock: 370 | cls._count_active -= 1 371 | if cls._count_active == 0: 372 | torch.empty = cls._ORIGINAL_EMPTY 373 | for mod, original in cls._MODULE_ORIGINALS: 374 | mod.reset_parameters = original 375 | 376 | @staticmethod 377 | def _disable(func): 378 | def wrapper(*args, **kwargs): 379 | # Behaves as normal except in an active context 380 | if not _NoInitOrTensorImpl.is_active.get(): 381 | return func(*args, **kwargs) 382 | 383 | return wrapper 384 | 385 | @staticmethod 386 | def _meta_empty(*args, **kwargs): 387 | # Behaves as torch.empty except in an active context 388 | if _NoInitOrTensorImpl.is_active.get(): 389 | kwargs["device"] = "meta" 390 | return _NoInitOrTensorImpl._ORIGINAL_EMPTY(*args, **kwargs) 391 | 392 | __init__ = None 393 | -------------------------------------------------------------------------------- /tensors/LICENSE: -------------------------------------------------------------------------------- 1 | ../LICENSE -------------------------------------------------------------------------------- /tensors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coreweave/tensorizer/9241bc82e7b9fdc3f92aa38ad04efd54f0054525/tensors/__init__.py -------------------------------------------------------------------------------- /tensors/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/coreweave/tensorizer/tensors 2 | 3 | go 1.18 4 | 5 | require google.golang.org/protobuf v1.28.1 6 | -------------------------------------------------------------------------------- /tensors/go.sum: -------------------------------------------------------------------------------- 1 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 2 | github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= 3 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 4 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= 5 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 6 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 7 | google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= 8 | google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= 9 | -------------------------------------------------------------------------------- /tensors/tensors_pb.d.ts: -------------------------------------------------------------------------------- 1 | // package: tensors 2 | // file: tensors.proto 3 | 4 | import * as jspb from "google-protobuf"; 5 | 6 | export class Tensor extends jspb.Message { 7 | getDtype(): DtypeMap[keyof DtypeMap]; 8 | setDtype(value: DtypeMap[keyof DtypeMap]): void; 9 | 10 | clearShapeList(): void; 11 | getShapeList(): Array; 12 | setShapeList(value: Array): void; 13 | addShape(value: number, index?: number): number; 14 | 15 | getData(): Uint8Array | string; 16 | getData_asU8(): Uint8Array; 17 | getData_asB64(): string; 18 | setData(value: Uint8Array | string): void; 19 | 20 | hasAttrType(): boolean; 21 | clearAttrType(): void; 22 | getAttrType(): AttributeTypeMap[keyof 
AttributeTypeMap]; 23 | setAttrType(value: AttributeTypeMap[keyof AttributeTypeMap]): void; 24 | 25 | serializeBinary(): Uint8Array; 26 | toObject(includeInstance?: boolean): Tensor.AsObject; 27 | static toObject(includeInstance: boolean, msg: Tensor): Tensor.AsObject; 28 | static extensions: {[key: number]: jspb.ExtensionFieldInfo}; 29 | static extensionsBinary: {[key: number]: jspb.ExtensionFieldBinaryInfo}; 30 | static serializeBinaryToWriter(message: Tensor, writer: jspb.BinaryWriter): void; 31 | static deserializeBinary(bytes: Uint8Array): Tensor; 32 | static deserializeBinaryFromReader(message: Tensor, reader: jspb.BinaryReader): Tensor; 33 | } 34 | 35 | export namespace Tensor { 36 | export type AsObject = { 37 | dtype: DtypeMap[keyof DtypeMap], 38 | shapeList: Array, 39 | data: Uint8Array | string, 40 | attrType: AttributeTypeMap[keyof AttributeTypeMap], 41 | } 42 | } 43 | 44 | export class Attribute extends jspb.Message { 45 | getName(): string; 46 | setName(value: string): void; 47 | 48 | hasModule(): boolean; 49 | clearModule(): void; 50 | getModule(): Module | undefined; 51 | setModule(value?: Module): void; 52 | 53 | hasTensor(): boolean; 54 | clearTensor(): void; 55 | getTensor(): Tensor | undefined; 56 | setTensor(value?: Tensor): void; 57 | 58 | hasString(): boolean; 59 | clearString(): void; 60 | getString(): string; 61 | setString(value: string): void; 62 | 63 | hasInt64(): boolean; 64 | clearInt64(): void; 65 | getInt64(): number; 66 | setInt64(value: number): void; 67 | 68 | hasFloat(): boolean; 69 | clearFloat(): void; 70 | getFloat(): number; 71 | setFloat(value: number): void; 72 | 73 | hasBool(): boolean; 74 | clearBool(): void; 75 | getBool(): boolean; 76 | setBool(value: boolean): void; 77 | 78 | getValueCase(): Attribute.ValueCase; 79 | serializeBinary(): Uint8Array; 80 | toObject(includeInstance?: boolean): Attribute.AsObject; 81 | static toObject(includeInstance: boolean, msg: Attribute): Attribute.AsObject; 82 | static extensions: {[key: number]: jspb.ExtensionFieldInfo}; 83 | static extensionsBinary: {[key: number]: jspb.ExtensionFieldBinaryInfo}; 84 | static serializeBinaryToWriter(message: Attribute, writer: jspb.BinaryWriter): void; 85 | static deserializeBinary(bytes: Uint8Array): Attribute; 86 | static deserializeBinaryFromReader(message: Attribute, reader: jspb.BinaryReader): Attribute; 87 | } 88 | 89 | export namespace Attribute { 90 | export type AsObject = { 91 | name: string, 92 | module?: Module.AsObject, 93 | tensor?: Tensor.AsObject, 94 | string: string, 95 | int64: number, 96 | pb_float: number, 97 | bool: boolean, 98 | } 99 | 100 | export enum ValueCase { 101 | VALUE_NOT_SET = 0, 102 | MODULE = 3, 103 | TENSOR = 4, 104 | STRING = 5, 105 | INT64 = 6, 106 | FLOAT = 7, 107 | BOOL = 8, 108 | } 109 | } 110 | 111 | export class Module extends jspb.Message { 112 | getName(): string; 113 | setName(value: string): void; 114 | 115 | clearNamesList(): void; 116 | getNamesList(): Array; 117 | setNamesList(value: Array): void; 118 | addNames(value: string, index?: number): string; 119 | 120 | clearAttributesList(): void; 121 | getAttributesList(): Array; 122 | setAttributesList(value: Array): void; 123 | addAttributes(value?: Attribute, index?: number): Attribute; 124 | 125 | serializeBinary(): Uint8Array; 126 | toObject(includeInstance?: boolean): Module.AsObject; 127 | static toObject(includeInstance: boolean, msg: Module): Module.AsObject; 128 | static extensions: {[key: number]: jspb.ExtensionFieldInfo}; 129 | static extensionsBinary: {[key: number]: 
jspb.ExtensionFieldBinaryInfo}; 130 | static serializeBinaryToWriter(message: Module, writer: jspb.BinaryWriter): void; 131 | static deserializeBinary(bytes: Uint8Array): Module; 132 | static deserializeBinaryFromReader(message: Module, reader: jspb.BinaryReader): Module; 133 | } 134 | 135 | export namespace Module { 136 | export type AsObject = { 137 | name: string, 138 | namesList: Array, 139 | attributesList: Array, 140 | } 141 | } 142 | 143 | export interface DtypeMap { 144 | DT_INVALID: 0; 145 | DT_FLOAT32: 1; 146 | DT_FLOAT64: 2; 147 | DT_FLOAT16: 3; 148 | DT_BFLOAT16: 4; 149 | DT_COMPLEX32: 5; 150 | DT_COMPLEX64: 6; 151 | DT_COMPLEX128: 7; 152 | DT_UINT8: 8; 153 | DT_INT8: 9; 154 | DT_INT16: 10; 155 | DT_INT32: 11; 156 | DT_INT64: 12; 157 | DT_BOOL: 13; 158 | DT_QUINT8: 14; 159 | DT_QINT8: 15; 160 | DT_QINT32: 16; 161 | DT_QUINT4_2: 17; 162 | } 163 | 164 | export const Dtype: DtypeMap; 165 | 166 | export interface AttributeTypeMap { 167 | AT_PARAMETER: 0; 168 | AT_BUFFER: 1; 169 | } 170 | 171 | export const AttributeType: AttributeTypeMap; 172 | 173 | -------------------------------------------------------------------------------- /tensors/tensors_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: tensors.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf.internal import enum_type_wrapper 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import descriptor_pool as _descriptor_pool 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rtensors.proto\x12\x07tensors\"\x82\x01\n\x06Tensor\x12\x1d\n\x05\x64type\x18\x01 \x01(\x0e\x32\x0e.tensors.Dtype\x12\r\n\x05shape\x18\x02 \x03(\x03\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12.\n\tattr_type\x18\x04 \x01(\x0e\x32\x16.tensors.AttributeTypeH\x00\x88\x01\x01\x42\x0c\n\n_attr_type\"\xac\x01\n\tAttribute\x12\x0c\n\x04name\x18\x01 \x01(\t\x12!\n\x06module\x18\x03 \x01(\x0b\x32\x0f.tensors.ModuleH\x00\x12!\n\x06tensor\x18\x04 \x01(\x0b\x32\x0f.tensors.TensorH\x00\x12\x10\n\x06string\x18\x05 \x01(\tH\x00\x12\x0f\n\x05int64\x18\x06 \x01(\x03H\x00\x12\x0f\n\x05\x66loat\x18\x07 \x01(\x02H\x00\x12\x0e\n\x04\x62ool\x18\x08 \x01(\x08H\x00\x42\x07\n\x05value\"M\n\x06Module\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05names\x18\x02 \x03(\t\x12&\n\nattributes\x18\x03 \x03(\x0b\x32\x12.tensors.Attribute*\x9e\x02\n\x05\x44type\x12\x0e\n\nDT_INVALID\x10\x00\x12\x0e\n\nDT_FLOAT32\x10\x01\x12\x0e\n\nDT_FLOAT64\x10\x02\x12\x0e\n\nDT_FLOAT16\x10\x03\x12\x0f\n\x0b\x44T_BFLOAT16\x10\x04\x12\x10\n\x0c\x44T_COMPLEX32\x10\x05\x12\x10\n\x0c\x44T_COMPLEX64\x10\x06\x12\x11\n\rDT_COMPLEX128\x10\x07\x12\x0c\n\x08\x44T_UINT8\x10\x08\x12\x0b\n\x07\x44T_INT8\x10\t\x12\x0c\n\x08\x44T_INT16\x10\n\x12\x0c\n\x08\x44T_INT32\x10\x0b\x12\x0c\n\x08\x44T_INT64\x10\x0c\x12\x0b\n\x07\x44T_BOOL\x10\r\x12\r\n\tDT_QUINT8\x10\x0e\x12\x0c\n\x08\x44T_QINT8\x10\x0f\x12\r\n\tDT_QINT32\x10\x10\x12\x0f\n\x0b\x44T_QUINT4_2\x10\x11*0\n\rAttributeType\x12\x10\n\x0c\x41T_PARAMETER\x10\x00\x12\r\n\tAT_BUFFER\x10\x01\x42)Z\'github.com/coreweave/tensorizer/tensorsb\x06proto3') 19 | 20 | _DTYPE = 
DESCRIPTOR.enum_types_by_name['Dtype'] 21 | Dtype = enum_type_wrapper.EnumTypeWrapper(_DTYPE) 22 | _ATTRIBUTETYPE = DESCRIPTOR.enum_types_by_name['AttributeType'] 23 | AttributeType = enum_type_wrapper.EnumTypeWrapper(_ATTRIBUTETYPE) 24 | DT_INVALID = 0 25 | DT_FLOAT32 = 1 26 | DT_FLOAT64 = 2 27 | DT_FLOAT16 = 3 28 | DT_BFLOAT16 = 4 29 | DT_COMPLEX32 = 5 30 | DT_COMPLEX64 = 6 31 | DT_COMPLEX128 = 7 32 | DT_UINT8 = 8 33 | DT_INT8 = 9 34 | DT_INT16 = 10 35 | DT_INT32 = 11 36 | DT_INT64 = 12 37 | DT_BOOL = 13 38 | DT_QUINT8 = 14 39 | DT_QINT8 = 15 40 | DT_QINT32 = 16 41 | DT_QUINT4_2 = 17 42 | AT_PARAMETER = 0 43 | AT_BUFFER = 1 44 | 45 | 46 | _TENSOR = DESCRIPTOR.message_types_by_name['Tensor'] 47 | _ATTRIBUTE = DESCRIPTOR.message_types_by_name['Attribute'] 48 | _MODULE = DESCRIPTOR.message_types_by_name['Module'] 49 | Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), { 50 | 'DESCRIPTOR' : _TENSOR, 51 | '__module__' : 'tensors_pb2' 52 | # @@protoc_insertion_point(class_scope:tensors.Tensor) 53 | }) 54 | _sym_db.RegisterMessage(Tensor) 55 | 56 | Attribute = _reflection.GeneratedProtocolMessageType('Attribute', (_message.Message,), { 57 | 'DESCRIPTOR' : _ATTRIBUTE, 58 | '__module__' : 'tensors_pb2' 59 | # @@protoc_insertion_point(class_scope:tensors.Attribute) 60 | }) 61 | _sym_db.RegisterMessage(Attribute) 62 | 63 | Module = _reflection.GeneratedProtocolMessageType('Module', (_message.Message,), { 64 | 'DESCRIPTOR' : _MODULE, 65 | '__module__' : 'tensors_pb2' 66 | # @@protoc_insertion_point(class_scope:tensors.Module) 67 | }) 68 | _sym_db.RegisterMessage(Module) 69 | 70 | if _descriptor._USE_C_DESCRIPTORS == False: 71 | 72 | DESCRIPTOR._options = None 73 | DESCRIPTOR._serialized_options = b'Z\'github.com/coreweave/tensorizer/tensors' 74 | _DTYPE._serialized_start=414 75 | _DTYPE._serialized_end=700 76 | _ATTRIBUTETYPE._serialized_start=702 77 | _ATTRIBUTETYPE._serialized_end=750 78 | _TENSOR._serialized_start=27 79 | _TENSOR._serialized_end=157 80 | _ATTRIBUTE._serialized_start=160 81 | _ATTRIBUTE._serialized_end=332 82 | _MODULE._serialized_start=334 83 | _MODULE._serialized_end=411 84 | # @@protoc_insertion_point(module_scope) 85 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coreweave/tensorizer/9241bc82e7b9fdc3f92aa38ad04efd54f0054525/tests/__init__.py -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.27.1 2 | moto[s3,server]>=4.1.4,<5.0.0 3 | redis>=5.0.0 4 | hiredis>=2.2.0 5 | -------------------------------------------------------------------------------- /tests/test_syscalls.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import unittest 4 | 5 | import tensorizer._syscalls as syscalls 6 | 7 | 8 | class TestSyscalls(unittest.TestCase): 9 | def test_fallocate(self): 10 | from tensorizer._syscalls import try_fallocate 11 | 12 | has_fallocate: bool = syscalls.has_fallocate() 13 | 14 | with tempfile.NamedTemporaryFile(mode="wb+") as file: 15 | fd: int = file.fileno() 16 | with self.subTest("Regular fallocate"): 17 | self.assertEqual( 18 | try_fallocate(fd=fd, offset=50, length=1000), 19 | has_fallocate, 20 | ) 21 | try: 22 | self.assertEqual( 23 | 
os.stat(fd).st_size, 1050 if has_fallocate else 0 24 | ) 25 | finally: 26 | os.ftruncate(fd, 0) 27 | if not has_fallocate: 28 | # The rest of the tests check for errors, which cannot be raised 29 | # if the fallocate syscall is not actually available. 30 | return 31 | with self.subTest( 32 | "Invalid fallocate invocation, errors suppressed" 33 | ): 34 | self.assertFalse( 35 | try_fallocate( 36 | fd=fd, offset=-1, length=0, suppress_all_errors=True 37 | ) 38 | ) 39 | with self.subTest( 40 | "Invalid fallocate invocation (bad offset and length)" 41 | ), self.assertRaises(OSError): 42 | try_fallocate(fd=fd, offset=-1, length=0) 43 | self.assertEqual(os.stat(fd).st_size, 0) 44 | with self.subTest( 45 | "Invalid fallocate invocation (bad file descriptor)" 46 | ), self.assertRaises(OSError): 47 | try: 48 | r_fd, w_fd = os.pipe() 49 | try_fallocate(fd=w_fd, offset=0, length=1000) 50 | finally: 51 | os.close(r_fd) 52 | os.close(w_fd) 53 | 54 | @unittest.skipUnless( 55 | hasattr(os, "pwrite"), "pwrite must be available to test pwrite" 56 | ) 57 | def test_out_of_order_pwrite(self): 58 | def _filler(length: int) -> bytes: 59 | mul, rem = divmod(length, 10) 60 | return (b"0123456789" * (mul + (rem != 0)))[:length] 61 | 62 | with tempfile.TemporaryFile("wb+", buffering=0) as file: 63 | fd: int = file.fileno() 64 | 65 | def pwrite(buffer: bytes, offset: int) -> None: 66 | self.assertEqual(os.pwrite(fd, buffer, offset), len(buffer)) 67 | 68 | discontiguous_offset: int = (10 << 10) + 5 69 | end_contents: bytes = _filler(10) 70 | expected_size: int = discontiguous_offset + len(end_contents) 71 | # This should fill the file with zeroes up to discontiguous_offset 72 | pwrite(end_contents, discontiguous_offset) 73 | self.assertEqual(os.stat(fd).st_size, expected_size) 74 | start_contents: bytes = _filler(5 << 10) 75 | # This should overwrite the existing zeroes, 76 | # and not change the length of the file 77 | pwrite(start_contents, 0) 78 | self.assertEqual(os.stat(fd).st_size, expected_size) 79 | total_written: int = len(start_contents) + len(end_contents) 80 | # The expected end result is start_contents, 81 | # a gap of zeroes up to discontiguous_offset, and then end_contents 82 | expected_contents: bytes = ( 83 | start_contents 84 | + bytes(expected_size - total_written) 85 | + end_contents 86 | ) 87 | self.assertEqual(file.read(), expected_contents) 88 | --------------------------------------------------------------------------------
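
The two syscall tests above exercise a preallocate-then-positioned-write pattern: reserve the file's full extent with fallocate(2) where the platform supports it, then fill it with os.pwrite() calls in any order. A minimal sketch of that pattern using the same helpers follows; the file name and sizes are illustrative only.

import os

from tensorizer._syscalls import has_fallocate, try_fallocate

fd = os.open("example.bin", os.O_RDWR | os.O_CREAT, 0o644)
try:
    size = 1 << 20  # final file size, known up front
    if has_fallocate():
        # Reserve the whole extent first so later positioned writes into
        # this range cannot fail midway for lack of disk space.
        try_fallocate(fd=fd, offset=0, length=size, suppress_all_errors=True)
    # Positioned writes may then land in any order, as in the pwrite test above.
    os.pwrite(fd, b"tail", size - 4)
    os.pwrite(fd, b"head", 0)
finally:
    os.close(fd)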
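
The serialize_model()/deserialize_model() helpers dumped earlier write a model's parameters and buffers as a single tensors_pb.Module message and read them back by module and attribute name. A small round-trip sketch is below; the import path tensorizer.protobuf is an assumption, so adjust it to wherever those helpers are defined in your checkout.

import io

import torch

from tensorizer.protobuf import deserialize_model, serialize_model  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.LayerNorm(8))

stream = io.BytesIO()
serialize_model(model, stream)  # parameters/buffers -> one tensors_pb.Module
stream.seek(0)

# The target model must expose the same module names as the serialized one.
restored = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.LayerNorm(8))
deserialize_model(restored, stream)

for (name, p), (_, q) in zip(model.named_parameters(), restored.named_parameters()):
    assert torch.equal(p, q), name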
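
The helpers in utils.py compose directly; a short sketch follows (the printed values in the comments are approximate and machine-dependent).

import torch

from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor

print(convert_bytes(123_456_789))         # "123.5 MB" (decimal units)
print(convert_bytes(123_456_789, False))  # "117.7 MiB" (binary units)

# Inside the context, torch.empty is redirected to the meta device and
# reset_parameters is skipped for Linear, Embedding, and LayerNorm, so
# building a module skeleton allocates and initializes no real weights.
with no_init_or_tensor():
    skeleton = torch.nn.Linear(1024, 1024)

print(get_mem_usage())  # e.g. "CPU: (maxrss: 512MiB F: 12,345MiB) GPU: (...) TORCH: (...)"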