├── .coveragerc ├── .github └── workflows │ └── workflow.yml ├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── pyproject.toml ├── setup.cfg ├── setup.py ├── tests ├── test_benchmark.py ├── test_cpu_mem_limit.py ├── test_explicit_toma.py ├── test_simple_toma.py ├── test_stacktrace.py └── test_toma.py └── toma ├── __init__.py ├── batchsize_cache.py ├── cpu_memory.py ├── stacktrace.py └── torch_cuda_memory.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source=toma/* 3 | omit= 4 | */tests/* 5 | setup.py 6 | -------------------------------------------------------------------------------- /.github/workflows/workflow.yml: -------------------------------------------------------------------------------- 1 | name: API workflow 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | name: Test python API 9 | steps: 10 | - uses: actions/checkout@v4 11 | - name: Set up Python 12 | # This is the version of the action for setting up Python, not the Python version. 13 | uses: actions/setup-python@v5 14 | with: 15 | # Semantic version range syntax or exact version of a Python version 16 | python-version: '3.10' 17 | # Optional - x64 or x86 architecture, defaults to x64 18 | architecture: 'x64' 19 | # You can test your matrix by printing the current Python version 20 | - name: Display Python version 21 | run: python -c "import sys; print(sys.version)" 22 | - name: Install requirements 23 | run: pip install -e ".[dev,test]" 24 | - name: Run tests and collect coverage 25 | run: pytest --cov . 26 | - name: Upload coverage reports to Codecov 27 | run: | 28 | # Replace `linux` below with the appropriate OS 29 | # Options are `alpine`, `linux`, `macos`, `windows` 30 | curl -Os https://cli.codecov.io/latest/linux/codecov 31 | chmod +x codecov 32 | ./codecov --verbose upload-process --fail-on-error -t ${{ secrets.CODECOV_TOKEN }} -n 'service'-${{ github.run_id }} -F service -f coverage-service.xml 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ 2 | /build/ 3 | /dist/ 4 | .benchmarks/ 5 | *.egg-info -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | matrix: 3 | include: 4 | - python: 3.7 5 | dist: xenial 6 | sudo: true 7 | install: 8 | - pip install -q -e .[dev,test] 9 | script: 10 | - python setup.py test 11 | 12 | # Push the results back to codecov 13 | after_success: 14 | - codecov -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Andreas @blackhc Kirsch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the 
Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.py 2 | include LICENSE 3 | recursive-include toma *.py 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Torch Memory-adaptive Algorithms (TOMA) 2 | 3 | [![Build Status](https://www.travis-ci.com/BlackHC/toma.svg?branch=master)](https://www.travis-ci.com/BlackHC/toma) [![codecov](https://codecov.io/gh/BlackHC/toma/branch/master/graph/badge.svg)](https://codecov.io/gh/BlackHC/toma) [![PyPI](https://img.shields.io/badge/PyPI-toma-blue.svg)](https://pypi.python.org/pypi/toma/) 4 | 5 | A collection of helpers to make it easier to write code that adapts to the available (CUDA) memory. 6 | Specifically, it retries code that fails due to OOM (out-of-memory) conditions and lowers batchsizes automatically. 7 | 8 | To avoid failing over repeatedly, a simple cache is implemented that memorizes the last successful batchsize for a given call and the available free memory. 9 | 10 | ## Installation 11 | 12 | To install using pip, use: 13 | 14 | ``` 15 | pip install toma 16 | ``` 17 | 18 | To run the tests, use: 19 | 20 | ``` 21 | python setup.py test 22 | ``` 23 | 24 | ## Example 25 | 26 | ```python 27 | from toma import toma 28 | 29 | @toma.batch(initial_batchsize=512) 30 | def run_inference(batchsize, model, dataset): 31 | # ... 32 | 33 | run_inference(model, dataset) 34 | ``` 35 | 36 | This will try to execute `run_inference` with batchsize=512; the decorator injects the current batchsize as the first argument, so the caller does not pass it. If a memory error is thrown, the batchsize is halved until the call succeeds. 37 | 38 | **Note:** 39 | This batch size can be different from the effective batch size used for gradient updates: gradients can still be accumulated over several toma batches by only calling `optimizer.step()` every so often. 40 | 41 | To make it easier to loop over ranges, there are also `toma.range` and `toma.chunked`: 42 | 43 | ```python 44 | @toma.chunked(initial_step=512) 45 | def compute_result(out: torch.Tensor, start: int, end: int): 46 | # ... 47 | 48 | result = torch.empty((8192, ...)) 49 | compute_result(result) 50 | ``` 51 | 52 | This will chunk `result` along its first dimension and pass the chunks to `compute_result` one by one. 53 | Again, if a chunk fails due to OOM, the step size is halved, and so on. 54 | Compared to `toma.batch`, this allows the step size to be reduced while looping over the chunks, without redoing chunks that have already succeeded. 55 | This can save computation. 56 | 57 | ```python 58 | @toma.range(initial_step=32) 59 | def reduce_data(start: int, end: int, out: torch.Tensor, dataA: torch.Tensor, dataB: torch.Tensor): 60 | # ... 61 | 62 | reduce_data(0, 1024, result, dataA, dataB) 63 | ``` 64 | 65 | `toma.range` iterates over `range(start, end, step)` with `step=initial_step`. If it fails due to OOM, it lowers the step size and continues.
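Under the hood, all of these helpers boil down to a retry loop that halves the batchsize (or step size) whenever an OOM-like `RuntimeError` is raised and then tries again. The following is a rough sketch of that idea, simplified from `toma.simple.batch` (the real implementation additionally frees CUDA memory between attempts via `gc_cuda()`, and the decorators also cache successful batchsizes; the function name and the string-based OOM check below are stand-ins for illustration, not toma's actual API):

```python
def run_with_adaptive_batchsize(func, initial_batchsize, *args, **kwargs):
    """Sketch: call func(batchsize, ...) and halve the batchsize on OOM until it fits."""
    batchsize = initial_batchsize
    while True:
        try:
            return func(batchsize, *args, **kwargs)
        except RuntimeError as exception:
            # Simplified OOM detection; toma checks for specific CUDA/cuDNN/CPU allocator messages.
            if batchsize > 1 and "out of memory" in str(exception).lower():
                batchsize //= 2  # retry the whole call with half the batchsize
            else:
                raise
```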
66 | 67 | ### `toma.execute` 68 | 69 | To make it easier to execute a block of code right away without first extracting it into a function and then calling it, we also provide `toma.execute.batch`, `toma.execute.range` and `toma.execute.chunked`. These are somewhat unorthodox: they call the function that is passed to them immediately. (Mainly because Python has no support for anonymous functions beyond lambda expressions.) 70 | 71 | ```python 72 | def function(): 73 | # ... other code 74 | 75 | @toma.execute.chunked(batched_data, initial_step=128) 76 | def compute(chunk, start, end): 77 | # ... 78 | ``` 79 | 80 | ## Cache 81 | 82 | There are three cache types available at the moment. 83 | They can be changed by either setting `toma.DEFAULT_CACHE_TYPE` or by passing `cache_type` to the calls. 84 | 85 | For example: 86 | ```python 87 | @toma.batch(initial_batchsize=512, cache_type=toma.GlobalBatchsizeCache) 88 | ``` 89 | or 90 | ```python 91 | toma.explicit.batch(..., toma_cache_type=toma.GlobalBatchsizeCache) 92 | ``` 93 | 94 | ### `StacktraceMemoryBatchsizeCache`: Stacktrace & Available Memory (*the default*) 95 | 96 | This memorizes the successful batchsizes for a given call trace and the memory available at that point. 97 | For most machine learning code, this is sufficient to remember the right batchsize without having to look at the actual arguments or understand more of their semantics. 98 | 99 | The implicit assumption is that, after a few iterations, GPU and CPU memory usage reach a stable state. 100 | 101 | To limit the CPU memory of the process (in gigabytes), toma provides: 102 | ```python 103 | import toma.cpu_memory 104 | 105 | toma.cpu_memory.set_cpu_memory_limit(8) 106 | ``` 107 | This can also be useful to avoid accidental swap thrashing. 108 | 109 | ### `GlobalBatchsizeCache`: Global per Function 110 | 111 | This reuses the last successful batchsize independently of where the call happened. 112 | 113 | ### `NoBatchsizeCache`: No Caching 114 | 115 | Always starts with the suggested batchsize and falls back to smaller batchsizes if necessary. 116 | 117 | ## Benchmark/Overhead 118 | 119 | There is some overhead involved, so toma should only be used for operations that are themselves time- or memory-consuming.
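The table below is `pytest-benchmark` output from `tests/test_benchmark.py`. With the test extras installed, a command along these lines should reproduce it (exact options may vary with your setup):

```
pip install -e ".[dev,test]"
pytest tests/test_benchmark.py --benchmark-only
```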
120 | 121 | ```text 122 | ---------------------------------------------------------------------------------- benchmark: 5 tests ---------------------------------------------------------------------------------- 123 | Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations 124 | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- 125 | test_native 2.1455 (1.0) 3.7733 (1.0) 2.3037 (1.0) 0.1103 (1.0) 2.2935 (1.0) 0.1302 (1.0) 81;5 434.0822 (1.0) 448 1 126 | test_simple 17.4657 (8.14) 27.0049 (7.16) 21.0453 (9.14) 2.6233 (23.79) 20.4881 (8.93) 3.4384 (26.42) 13;0 47.5165 (0.11) 39 1 127 | test_toma_no_cache 31.4380 (14.65) 40.8567 (10.83) 33.2749 (14.44) 2.2530 (20.43) 32.2698 (14.07) 2.8210 (21.67) 4;1 30.0527 (0.07) 25 1 128 | test_explicit 33.0759 (15.42) 52.1866 (13.83) 39.6956 (17.23) 6.9620 (63.14) 38.4929 (16.78) 11.2344 (86.31) 4;0 25.1917 (0.06) 20 1 129 | test_toma 36.9633 (17.23) 57.0220 (15.11) 43.5201 (18.89) 6.7318 (61.05) 41.6034 (18.14) 7.2173 (55.45) 2;2 22.9779 (0.05) 13 1 130 | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- 131 | ``` 132 | 133 | ## Thanks 134 | 135 | Thanks to [@y0ast](https://github.com/y0ast) for feedback and discussion. 136 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py36', 'py37', 'py38'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | /( 7 | \.eggs 8 | | \.git 9 | | \.hg 10 | | \.mypy_cache 11 | | \.tox 12 | | \.venv 13 | | _build 14 | | buck-out 15 | | build 16 | | dist 17 | )/ 18 | ''' 19 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | # addopts = --cov toma 6 | 7 | [pylama:pycodestyle] 8 | max_line_length = 120 9 | 10 | [pylama:pylint] 11 | max_line_length = 120 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Always prefer setuptools over distutils 2 | from setuptools import setup 3 | 4 | # To use a consistent encoding 5 | from codecs import open 6 | from os import path 7 | 8 | here = path.abspath(path.dirname(__file__)) 9 | 10 | # Get the long description from the README file 11 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 12 | long_description = f.read() 13 | 14 | setup( 15 | name="toma", 16 | # Versions should comply with PEP440. For a discussion on single-sourcing 17 | # the version across setup.py and the project code, see 18 | # https://packaging.python.org/en/latest/single_source_version.html 19 | version="1.1.0", 20 | description="Write algorithms in PyTorch that adapt to the available (CUDA) memory", 21 | # Fix windows newlines. 22 | long_description=long_description.replace("\r\n", "\n"), 23 | long_description_content_type="text/markdown", 24 | 25 | # The project's main homepage. 
26 | url="https://github.com/blackhc/toma", 27 | # Author details 28 | author="Andreas @blackhc Kirsch", 29 | author_email="blackhc+toma@gmail.com", 30 | # Choose your license 31 | license="MIT", 32 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 33 | classifiers=[ 34 | # How mature is this project? Common values are 35 | # 3 - Alpha 36 | # 4 - Beta 37 | # 5 - Production/Stable 38 | "Development Status :: 5 - Production/Stable", 39 | # Indicate who your project is intended for 40 | "Intended Audience :: Developers", 41 | "Intended Audience :: Science/Research", 42 | "Topic :: Software Development :: Libraries :: Python Modules", 43 | # Pick your license as you wish (should match "license" above) 44 | "License :: OSI Approved :: MIT License", 45 | "Programming Language :: Python :: 3.7", 46 | ], 47 | # What does your project relate to? 48 | keywords="tools pytorch", 49 | # You can just specify the packages manually here if your project is 50 | # simple. Or you can use find_packages(). 51 | packages=["toma"], 52 | # List run-time dependencies here. These will be installed by pip when 53 | # your project is installed. For an analysis of "install_requires" vs pip's 54 | # requirements files see: 55 | # https://packaging.python.org/en/latest/requirements.html 56 | install_requires=["torch", "psutil"], 57 | # List additional groups of dependencies here (e.g. development 58 | # dependencies). You can install these using the following syntax, 59 | # for example: 60 | # $ pip install -e .[dev,test] 61 | extras_require={ 62 | "dev": ["check-manifest"], 63 | "test": ["coverage", "codecov", "pytest", "pytest-benchmark", "pytest-cov", "pytest-forked"], 64 | }, 65 | setup_requires=["setuptools>=60.7.0", "pytest-runner"], 66 | ) 67 | -------------------------------------------------------------------------------- /tests/test_benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest_benchmark 3 | 4 | # Preload this import 5 | import resource 6 | 7 | from toma import simple, toma, explicit, NoBatchsizeCache 8 | 9 | 10 | def test_toma(benchmark): 11 | benchmark.extra_info["Debug Mode"] = __debug__ 12 | 13 | @toma.chunked(initial_step=32) 14 | def func(tensor, start, end): 15 | tensor[:] = 1.0 16 | 17 | tensor = torch.zeros((128, 256, 256)) 18 | benchmark(func, tensor) 19 | 20 | 21 | def test_toma_no_cache(benchmark): 22 | benchmark.extra_info["Debug Mode"] = __debug__ 23 | 24 | @toma.chunked(initial_step=32, cache_type=NoBatchsizeCache) 25 | def func(tensor, start, end): 26 | tensor[:] = 1.0 27 | 28 | tensor = torch.zeros((128, 256, 256)) 29 | benchmark(func, tensor) 30 | 31 | 32 | def test_explicit(benchmark): 33 | benchmark.extra_info["Debug Mode"] = __debug__ 34 | 35 | def func(tensor, start, end): 36 | tensor[:] = 1.0 37 | 38 | tensor = torch.zeros((128, 256, 256)) 39 | benchmark(explicit.chunked, func, tensor, 32) 40 | 41 | 42 | def test_simple(benchmark): 43 | benchmark.extra_info["Debug Mode"] = __debug__ 44 | 45 | def func(tensor, start, end): 46 | tensor[:] = 1.0 47 | 48 | tensor = torch.zeros((128, 256, 256)) 49 | benchmark(simple.chunked, func, tensor, 32) 50 | 51 | 52 | def test_native(benchmark): 53 | benchmark.extra_info["Debug Mode"] = __debug__ 54 | 55 | def func(tensor, start, end): 56 | tensor[:] = 1.0 57 | 58 | def native(func, tensor, batch): 59 | end = tensor.shape[0] 60 | current = 0 61 | while current < end: 62 | current_end = min(current + batch, end) 63 | func(tensor.narrow(0, current, 
current_end - current), current, current_end) 64 | current = current_end 65 | 66 | tensor = torch.zeros((128, 256, 256)) 67 | benchmark(native, func, tensor, 32) 68 | -------------------------------------------------------------------------------- /tests/test_cpu_mem_limit.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from toma import cpu_memory 5 | 6 | 7 | @pytest.mark.forked 8 | def test_cpu_mem_limit(): 9 | # 128 MB (128/4M float32) 10 | tensor = torch.empty((128, 1024, 1024 // 4), dtype=torch.float32) 11 | tensor.resize_(1) 12 | 13 | cpu_memory.set_cpu_memory_limit(0.25) 14 | 15 | # 512 MB (128/4M float32) 16 | with pytest.raises(RuntimeError): 17 | torch.empty((512, 1024, 1024 // 4), dtype=torch.float32) 18 | -------------------------------------------------------------------------------- /tests/test_explicit_toma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from toma import explicit 5 | from toma.batchsize_cache import NoBatchsizeCache, GlobalBatchsizeCache, StacktraceMemoryBatchsizeCache 6 | 7 | 8 | def raise_fake_oom(): 9 | raise RuntimeError("CUDA out of memory.") 10 | 11 | 12 | def test_fake_explicit_batch_raise(): 13 | def f(batchsize): 14 | raise_fake_oom() 15 | 16 | with pytest.raises(RuntimeError): 17 | explicit.batch(f, 64) 18 | 19 | 20 | def test_fake_explicit_range_raise(): 21 | def f(start, end): 22 | raise_fake_oom() 23 | 24 | with pytest.raises(RuntimeError): 25 | explicit.range(f, 0, 64, 64) 26 | 27 | 28 | def test_fake_explicit_batch_none(): 29 | batchsizes = [] 30 | 31 | def f(batchsize): 32 | nonlocal batchsizes 33 | batchsizes.append(batchsize) 34 | 35 | if batchsize != 16: 36 | raise_fake_oom() 37 | 38 | for _i in range(2): 39 | explicit.batch(f, 64, toma_cache_type=NoBatchsizeCache) 40 | 41 | assert batchsizes == [64, 32, 16, 64, 32, 16] 42 | 43 | 44 | def test_fake_explicit_batch_global(): 45 | batchsizes = [] 46 | 47 | def f(batchsize): 48 | nonlocal batchsizes 49 | batchsizes.append(batchsize) 50 | 51 | if batchsize != 16: 52 | raise_fake_oom() 53 | 54 | for _i in range(3): 55 | explicit.batch(f, 64, toma_cache_type=GlobalBatchsizeCache) 56 | 57 | explicit.batch(f, 64, toma_cache_type=GlobalBatchsizeCache) 58 | 59 | assert batchsizes == [64, 32, 16, 16, 16, 16] 60 | 61 | 62 | def test_fake_explicit_batch_sm(): 63 | batchsizes = [] 64 | 65 | def f(batchsize): 66 | nonlocal batchsizes 67 | batchsizes.append(batchsize) 68 | 69 | if batchsize != 16: 70 | raise_fake_oom() 71 | 72 | for _i in range(3): 73 | explicit.batch(f, 64, toma_cache_type=StacktraceMemoryBatchsizeCache) 74 | 75 | explicit.batch(f, 64, toma_cache_type=StacktraceMemoryBatchsizeCache) 76 | 77 | assert batchsizes == [64, 32, 16, 16, 16, 64, 32, 16] 78 | 79 | 80 | def test_fake_explicit_batch_mix(): 81 | batchsizes = [] 82 | 83 | def f(batchsize): 84 | nonlocal batchsizes 85 | batchsizes.append(batchsize) 86 | 87 | if batchsize != 16: 88 | raise_fake_oom() 89 | 90 | for _ in range(3): 91 | explicit.batch(f, 64, toma_cache_type=GlobalBatchsizeCache) 92 | 93 | explicit.batch(f, 64, toma_cache_type=GlobalBatchsizeCache) 94 | 95 | for _ in range(2): 96 | explicit.batch(f, 64, toma_cache_type=StacktraceMemoryBatchsizeCache) 97 | 98 | assert batchsizes == [64, 32, 16, 16, 16, 16, 64, 32, 16, 16] 99 | 100 | 101 | def test_fake_explicit_range_none(): 102 | batchsizes = [] 103 | 104 | def f(start, end): 105 | batchsize = end - start 106 | 107 | 
nonlocal batchsizes 108 | batchsizes.append(batchsize) 109 | 110 | remaining = 128 - end 111 | 112 | if batchsize > 16 and batchsize > remaining: 113 | raise_fake_oom() 114 | 115 | for _ in range(2): 116 | explicit.range(f, 0, 128, 64, toma_cache_type=NoBatchsizeCache) 117 | 118 | assert batchsizes == [64, 64, 32, 32, 16, 16] * 2 119 | 120 | 121 | def test_fake_explicit_range_global(): 122 | batchsizes = [] 123 | 124 | def f(start, end): 125 | batchsize = end - start 126 | 127 | nonlocal batchsizes 128 | batchsizes.append(batchsize) 129 | 130 | remaining = 128 - end 131 | 132 | if batchsize > 16 and batchsize > remaining: 133 | raise_fake_oom() 134 | 135 | for _ in range(2): 136 | explicit.range(f, 0, 128, 64, toma_cache_type=GlobalBatchsizeCache) 137 | 138 | explicit.range(f, 0, 128, 64, toma_cache_type=GlobalBatchsizeCache) 139 | 140 | assert batchsizes == [64, 64, 32, 32, 16, 16] + [16] * 8 * 2 141 | 142 | 143 | def test_fake_explicit_range_sm(): 144 | batchsizes = [] 145 | 146 | def f(start, end): 147 | batchsize = end - start 148 | 149 | nonlocal batchsizes 150 | batchsizes.append(batchsize) 151 | 152 | remaining = 128 - end 153 | 154 | if batchsize > 16 and batchsize > remaining: 155 | raise_fake_oom() 156 | 157 | for _ in range(2): 158 | explicit.range(f, 0, 128, 64, toma_cache_type=StacktraceMemoryBatchsizeCache) 159 | 160 | explicit.range(f, 0, 128, 64, toma_cache_type=StacktraceMemoryBatchsizeCache) 161 | 162 | assert batchsizes == [64, 64, 32, 32, 16, 16] + [16] * 8 + [64, 64, 32, 32, 16, 16] 163 | 164 | 165 | def test_fake_explicit_range_sm(): 166 | batchsizes = [] 167 | 168 | def f(start, end): 169 | batchsize = end - start 170 | 171 | nonlocal batchsizes 172 | batchsizes.append(batchsize) 173 | 174 | remaining = 128 - end 175 | 176 | if batchsize > 16 and batchsize > remaining: 177 | raise_fake_oom() 178 | 179 | for _ in range(2): 180 | explicit.range(f, 0, 128, 64, toma_cache_type=GlobalBatchsizeCache) 181 | 182 | explicit.range(f, 0, 128, 64, toma_cache_type=GlobalBatchsizeCache) 183 | 184 | for _ in range(2): 185 | explicit.range(f, 0, 128, 64, toma_cache_type=StacktraceMemoryBatchsizeCache) 186 | 187 | assert batchsizes == ([64, 64, 32, 32, 16, 16] + [16] * 8 * 2 + [64, 64, 32, 32, 16, 16] + [16] * 8) 188 | 189 | 190 | def test_explicit_chunked(): 191 | def func(tensor, start, end): 192 | tensor[:] = 1.0 193 | 194 | tensor = torch.zeros((128, 4, 4)) 195 | explicit.chunked(func, tensor, 32) 196 | assert torch.allclose(tensor, torch.tensor(1.0)) 197 | -------------------------------------------------------------------------------- /tests/test_simple_toma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from toma import simple 4 | 5 | 6 | def raise_fake_oom(): 7 | raise RuntimeError("CUDA out of memory.") 8 | 9 | 10 | def test_fake_simple_batch(): 11 | hit_16 = False 12 | 13 | def f(batch_size): 14 | if batch_size != 16: 15 | raise_fake_oom() 16 | if batch_size == 16: 17 | nonlocal hit_16 18 | hit_16 = True 19 | 20 | assert batch_size >= 16 21 | 22 | simple.batch(f, 64) 23 | 24 | assert hit_16 25 | 26 | 27 | def test_fake_simple_range(): 28 | hit_16 = False 29 | 30 | def f(start, end): 31 | batch_size = end - start 32 | 33 | if batch_size != 16: 34 | raise_fake_oom() 35 | if batch_size == 16: 36 | nonlocal hit_16 37 | hit_16 = True 38 | 39 | assert batch_size >= 16 40 | 41 | simple.range(f, 0, 128, 64) 42 | 43 | assert hit_16 44 | 45 | 46 | def test_fake_simple_chunked(): 47 | hit_16 = False 48 | 49 | def f(tensor, 
start, end): 50 | batch_size = end - start 51 | 52 | if batch_size != 16: 53 | raise_fake_oom() 54 | if batch_size == 16: 55 | nonlocal hit_16 56 | hit_16 = True 57 | 58 | tensor[:] = 1 59 | 60 | tensor = torch.zeros(128, dtype=torch.float) 61 | simple.chunked(f, tensor, 64) 62 | assert torch.allclose(tensor, torch.tensor(1.0)) 63 | assert hit_16 64 | 65 | 66 | def test_simple_batch(): 67 | import torch 68 | 69 | if not torch.cuda.is_available(): 70 | print("CUDA not available") 71 | return 72 | 73 | failed = False 74 | succeeded = False 75 | 76 | def f(batch_size): 77 | # 2**20*2*7*2**3 * batch_size = batch_size GB 78 | try: 79 | torch.empty((batch_size, 1024, 1024, 128), dtype=torch.double, device="cuda") 80 | except: 81 | nonlocal failed 82 | failed = True 83 | 84 | nonlocal succeeded 85 | succeeded = True 86 | 87 | simple.batch(f, 64) 88 | assert failed 89 | assert succeeded 90 | 91 | 92 | def test_simple_range(): 93 | import torch 94 | 95 | if not torch.cuda.is_available(): 96 | print("CUDA not available") 97 | return 98 | 99 | failed = False 100 | succeeded = False 101 | 102 | def f(start, end): 103 | # 2**20*2*7*2**3 * batch_size = batch_size GB 104 | try: 105 | torch.empty((end - start, 1024, 1024, 128), dtype=torch.double, device="cuda") 106 | except: 107 | nonlocal failed 108 | failed = True 109 | 110 | nonlocal succeeded 111 | succeeded = True 112 | 113 | simple.range(f, 0, 128, 64) 114 | assert failed 115 | assert succeeded 116 | -------------------------------------------------------------------------------- /tests/test_stacktrace.py: -------------------------------------------------------------------------------- 1 | from toma import stacktrace 2 | 3 | 4 | def get_stacktrace(): 5 | return stacktrace.get_simple_traceback() 6 | 7 | 8 | def outer_func(): 9 | return get_stacktrace() 10 | 11 | 12 | def test_get_simple_traceback(): 13 | stacktrace1 = outer_func() 14 | stacktrace2 = outer_func() 15 | 16 | assert hash(stacktrace1) != hash(stacktrace2) 17 | assert stacktrace1 != stacktrace2 18 | 19 | stacktraces = [] 20 | for i in range(2): 21 | stacktraces.append(outer_func()) 22 | 23 | assert stacktraces[0] == stacktraces[1] 24 | assert hash(stacktraces[0]) == hash(stacktraces[1]) 25 | 26 | 27 | def test_watermark(): 28 | stacktrace1 = outer_func() 29 | 30 | assert len(stacktrace1) > 1 31 | 32 | with stacktrace.watermark(): 33 | stacktrace2 = outer_func() 34 | 35 | assert len(stacktrace2) == 3 36 | 37 | stacktrace3 = outer_func() 38 | assert len(stacktrace1) == len(stacktrace3) 39 | 40 | 41 | def test_set_watermark(): 42 | @stacktrace.set_watermark 43 | def outer_func2(): 44 | return outer_func() 45 | 46 | assert len(outer_func2()) == 3 47 | -------------------------------------------------------------------------------- /tests/test_toma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from toma import toma, explicit, batchsize_cache as tbc 3 | 4 | 5 | def raise_fake_oom(): 6 | raise RuntimeError("CUDA out of memory.") 7 | 8 | 9 | def test_fake_batch_none(): 10 | batchsizes = [] 11 | 12 | @toma.batch(initial_batchsize=64, cache_type=tbc.NoBatchsizeCache) 13 | def f(batchsize): 14 | nonlocal batchsizes 15 | batchsizes.append(batchsize) 16 | 17 | if batchsize != 16: 18 | raise_fake_oom() 19 | 20 | for _ in range(2): 21 | f() 22 | 23 | assert batchsizes == [64, 32, 16, 64, 32, 16] 24 | 25 | 26 | def test_fake_batch_global(): 27 | batchsizes = [] 28 | 29 | @toma.batch(initial_batchsize=64, cache_type=tbc.GlobalBatchsizeCache) 30 | 
def f(batchsize): 31 | nonlocal batchsizes 32 | batchsizes.append(batchsize) 33 | 34 | if batchsize != 16: 35 | raise_fake_oom() 36 | 37 | for _ in range(2): 38 | f() 39 | 40 | assert batchsizes == [64, 32, 16, 16] 41 | 42 | 43 | def test_fake_range_none(): 44 | batchsizes = [] 45 | 46 | @toma.range(initial_step=64, cache_type=tbc.NoBatchsizeCache) 47 | def f(start, end): 48 | batchsize = end - start 49 | 50 | nonlocal batchsizes 51 | batchsizes.append(batchsize) 52 | 53 | remaining = 128 - end 54 | 55 | if batchsize > 16 and batchsize > remaining: 56 | raise_fake_oom() 57 | 58 | for _ in range(2): 59 | f(0, 128) 60 | 61 | assert batchsizes == [64, 64, 32, 32, 16, 16] * 2 62 | 63 | 64 | def test_fake_range_global(): 65 | batchsizes = [] 66 | 67 | @toma.range(initial_step=64, cache_type=tbc.GlobalBatchsizeCache) 68 | def f(start, end): 69 | batchsize = end - start 70 | 71 | nonlocal batchsizes 72 | batchsizes.append(batchsize) 73 | 74 | remaining = 128 - end 75 | 76 | if batchsize > 16 and batchsize > remaining: 77 | raise_fake_oom() 78 | 79 | for _ in range(2): 80 | f(0, 128) 81 | 82 | assert batchsizes == [64, 64, 32, 32, 16, 16] + [16] * 8 83 | 84 | 85 | def test_fake_batch_none_execute(): 86 | batchsizes = [] 87 | 88 | @toma.batch(initial_batchsize=64, cache_type=tbc.NoBatchsizeCache) 89 | def f(batchsize): 90 | nonlocal batchsizes 91 | batchsizes.append(batchsize) 92 | 93 | if batchsize != 16: 94 | raise_fake_oom() 95 | 96 | for _ in range(2): 97 | f() 98 | 99 | assert batchsizes == [64, 32, 16, 64, 32, 16] 100 | 101 | 102 | def test_fake_batch_global_execute(): 103 | batchsizes = [] 104 | 105 | for _ in range(2): 106 | 107 | @toma.execute.batch(initial_batchsize=64, cache_type=tbc.GlobalBatchsizeCache) 108 | def f(batchsize): 109 | nonlocal batchsizes 110 | batchsizes.append(batchsize) 111 | 112 | if batchsize != 16: 113 | raise_fake_oom() 114 | 115 | assert batchsizes == [64, 32, 16, 16] 116 | 117 | 118 | def test_fake_range_none_execute(): 119 | batchsizes = [] 120 | 121 | for _ in range(2): 122 | 123 | @toma.execute.range(0, 128, initial_step=64, cache_type=tbc.NoBatchsizeCache) 124 | def f(start, end): 125 | batchsize = end - start 126 | 127 | nonlocal batchsizes 128 | batchsizes.append(batchsize) 129 | 130 | remaining = 128 - end 131 | 132 | if batchsize > 16 and batchsize > remaining: 133 | raise_fake_oom() 134 | 135 | assert batchsizes == [64, 64, 32, 32, 16, 16] * 2 136 | 137 | 138 | def test_fake_range_global_execute(): 139 | batchsizes = [] 140 | 141 | for _ in range(2): 142 | 143 | @toma.execute.range(0, 128, initial_step=64, cache_type=tbc.GlobalBatchsizeCache) 144 | def f(start, end): 145 | batchsize = end - start 146 | 147 | nonlocal batchsizes 148 | batchsizes.append(batchsize) 149 | 150 | remaining = 128 - end 151 | 152 | if batchsize > 16 and batchsize > remaining: 153 | raise_fake_oom() 154 | 155 | assert batchsizes == [64, 64, 32, 32, 16, 16] + [16] * 8 156 | 157 | 158 | def test_chunked(): 159 | @toma.chunked(initial_step=32) 160 | def func(tensor, start, end): 161 | tensor[:] = 1.0 162 | 163 | tensor = torch.zeros((128, 4, 4)) 164 | func(tensor) 165 | assert torch.allclose(tensor, torch.tensor(1.0)) 166 | -------------------------------------------------------------------------------- /toma/__init__.py: -------------------------------------------------------------------------------- 1 | """Torch Memory-Adaptive Algorithms. 2 | 3 | Helpers to allow for OOM conditions and dynamic adaptation of "internal" 4 | batchsizes. (Without affecting the computational ones.) 
5 | """ 6 | import functools 7 | from typing import Type, Optional 8 | 9 | import torch 10 | 11 | import toma.stacktrace as tst 12 | from toma.batchsize_cache import StacktraceMemoryBatchsizeCache, NoBatchsizeCache, GlobalBatchsizeCache 13 | from toma.cpu_memory import is_out_of_cpu_memory 14 | from toma.torch_cuda_memory import is_cuda_out_of_memory, is_cudnn_snafu, gc_cuda 15 | 16 | DEFAULT_CACHE_TYPE = StacktraceMemoryBatchsizeCache 17 | 18 | 19 | class simple: 20 | """ 21 | Straight-forward wrappers (can be copy-pasted and hacked easily). 22 | """ 23 | 24 | @staticmethod 25 | def batch(func, initial_batchsize: int, *args, **kwargs): 26 | gc_cuda() 27 | 28 | batchsize = initial_batchsize 29 | while True: 30 | try: 31 | return func(batchsize, *args, **kwargs) 32 | except RuntimeError as exception: 33 | if batchsize > 1 and should_reduce_batch_size(exception): 34 | batchsize //= 2 35 | gc_cuda() 36 | else: 37 | raise 38 | 39 | @staticmethod 40 | def range(func, start: int, end: int, initial_step: int, *args, **kwargs): 41 | gc_cuda() 42 | 43 | stepsize = initial_step 44 | current = start 45 | while current < end: 46 | try: 47 | func(current, min(current + stepsize, end), *args, **kwargs) 48 | current += stepsize 49 | except RuntimeError as exception: 50 | if stepsize > 1 and should_reduce_batch_size(exception): 51 | stepsize //= 2 52 | gc_cuda() 53 | else: 54 | raise 55 | 56 | @staticmethod 57 | def chunked(func, tensor, initial_step: int, dimension: int = 0): 58 | def body(start, end): 59 | return func(tensor.narrow(dim=dimension, start=start, length=end - start), start, end) 60 | 61 | return simple.range(body, 0, tensor.shape[dimension], initial_step) 62 | 63 | 64 | class toma: 65 | """ 66 | Decorators that make it easy to wrap functions. 67 | """ 68 | 69 | @staticmethod 70 | def batch(func=None, *, initial_batchsize=None, cache_type=DEFAULT_CACHE_TYPE, context=None): 71 | if func is None: 72 | return functools.partial( 73 | toma.batch, initial_batchsize=initial_batchsize, cache_type=cache_type, context=None 74 | ) 75 | 76 | @functools.wraps(func) 77 | def wrapped(*args, toma_initial_batchsize=None, toma_context=None, **kwargs): 78 | _initial_batchsize = toma_initial_batchsize or initial_batchsize 79 | _context = toma_context or context 80 | return explicit.batch( 81 | func, _initial_batchsize, *args, toma_cache_type=cache_type, toma_context=_context, **kwargs 82 | ) 83 | 84 | wrapped.__doc__ = f""" 85 | Wrapped in toma.batch: 86 | 87 | Additional keyargs: 88 | toma_initial_batchsize: initial step size to use. 89 | 90 | {wrapped.__doc__} 91 | """ 92 | 93 | return wrapped 94 | 95 | @staticmethod 96 | def range(func=None, *, initial_step: Optional[int] = None, cache_type=DEFAULT_CACHE_TYPE, context=None): 97 | if func is None: 98 | return functools.partial(toma.range, initial_step=initial_step, cache_type=cache_type, context=context) 99 | 100 | @functools.wraps(func) 101 | def wrapped(start: int, end: int, *args, toma_initial_step: Optional[int] = None, toma_context=None, **kwargs): 102 | _initial_step = toma_initial_step or initial_step 103 | _context = toma_context or context 104 | 105 | return explicit.range( 106 | func, start, end, _initial_step, *args, toma_context=_context, toma_cache_type=cache_type, **kwargs 107 | ) 108 | 109 | wrapped.__doc__ = f""" 110 | Wrapped in toma.range: 111 | 112 | Additional keyargs: 113 | toma_initial_step: initial step size to use. 
114 | 115 | {wrapped.__doc__} 116 | """ 117 | 118 | return wrapped 119 | 120 | @staticmethod 121 | def chunked( 122 | func=None, 123 | *, 124 | initial_step: Optional[int] = None, 125 | dimension: Optional[int] = None, 126 | cache_type: Type = DEFAULT_CACHE_TYPE, 127 | context=None, 128 | ): 129 | dimension = dimension or 0 130 | if func is None: 131 | return functools.partial( 132 | toma.chunked, initial_step=initial_step, dimension=dimension, cache_type=cache_type 133 | ) 134 | 135 | @functools.wraps(func) 136 | def wrapped( 137 | tensor: torch.Tensor, 138 | *args, 139 | toma_initial_step: Optional[int] = None, 140 | toma_dimension: Optional[int] = None, 141 | toma_context=None, 142 | **kwargs, 143 | ): 144 | _initial_step = toma_initial_step or initial_step 145 | _dimension = toma_dimension or dimension 146 | _context = toma_context or context 147 | 148 | explicit.chunked( 149 | func, 150 | tensor, 151 | _initial_step, 152 | *args, 153 | toma_dimension=toma_dimension, 154 | toma_cache_type=cache_type, 155 | toma_context=_context, 156 | **kwargs, 157 | ) 158 | 159 | wrapped.__doc__ = f""" 160 | Wrapped in toma.chunked: 161 | 162 | Additional keyargs: 163 | toma_initial_step: initial step size to use 164 | toma_dimension: dimension of the tensor to chunk along 165 | 166 | {wrapped.__doc__} 167 | """ 168 | 169 | return wrapped 170 | 171 | class execute: 172 | @staticmethod 173 | def batch(initial_batchsize, cache_type=DEFAULT_CACHE_TYPE, context=None): 174 | context = context or tst.get_simple_traceback(1) 175 | 176 | def execute_batch(func): 177 | return explicit.batch(func, initial_batchsize, toma_cache_type=cache_type, toma_context=context) 178 | 179 | return execute_batch 180 | 181 | @staticmethod 182 | def range(start, end, initial_step, cache_type=DEFAULT_CACHE_TYPE, context=None): 183 | context = context or tst.get_simple_traceback(1) 184 | 185 | def execute_range(func): 186 | return explicit.range(func, start, end, initial_step, toma_cache_type=cache_type, toma_context=context) 187 | 188 | return execute_range 189 | 190 | @staticmethod 191 | def chunked( 192 | tensor: torch.Tensor, 193 | initial_step: Optional[int] = None, 194 | dimension: Optional[int] = None, 195 | cache_type: Type = DEFAULT_CACHE_TYPE, 196 | context=None, 197 | ): 198 | context = context or tst.get_simple_traceback(1) 199 | 200 | def execute_chunked(func): 201 | return explicit.chunked( 202 | func, 203 | tensor, 204 | initial_step, 205 | toma_dimension=dimension, 206 | toma_cache_type=cache_type, 207 | toma_context=context, 208 | ) 209 | 210 | return execute_chunked 211 | 212 | 213 | CONTEXT_CACHE_SIZE = 2 ** 14 214 | 215 | 216 | @functools.lru_cache(CONTEXT_CACHE_SIZE) 217 | def get_cache_for_context(batchsize_cache_type, context): 218 | return batchsize_cache_type() 219 | 220 | 221 | class explicit: 222 | """ 223 | Explicit calls that can use different cache types to memorize settings. 
224 | """ 225 | 226 | @staticmethod 227 | def batch( 228 | func, initial_batchsize: int, *args, toma_context=None, toma_cache_type: Type = DEFAULT_CACHE_TYPE, **kwargs 229 | ): 230 | gc_cuda() 231 | 232 | cache = get_cache_for_context(toma_cache_type, toma_context or func) 233 | 234 | batchsize = cache.get_batchsize(initial_batchsize) 235 | 236 | while True: 237 | try: 238 | value = batchsize.get() 239 | result = func(value, *args, **kwargs) 240 | gc_cuda() 241 | return result 242 | except RuntimeError as exception: 243 | if value > 1 and should_reduce_batch_size(exception): 244 | batchsize.decrease_batchsize() 245 | gc_cuda() 246 | else: 247 | raise 248 | 249 | @staticmethod 250 | def range( 251 | func, 252 | start: int, 253 | end: int, 254 | initial_step: int, 255 | *args, 256 | toma_context=None, 257 | toma_cache_type: Type = DEFAULT_CACHE_TYPE, 258 | **kwargs, 259 | ): 260 | gc_cuda() 261 | 262 | cache = get_cache_for_context(toma_cache_type, toma_context or func) 263 | 264 | batchsize = cache.get_batchsize(initial_step) 265 | 266 | current = start 267 | while current < end: 268 | try: 269 | func(current, min(current + batchsize.get(), end), *args, **kwargs) 270 | current += batchsize.get() 271 | gc_cuda() 272 | except RuntimeError as exception: 273 | if batchsize.get() > 1 and should_reduce_batch_size(exception): 274 | batchsize.decrease_batchsize() 275 | gc_cuda() 276 | else: 277 | raise 278 | 279 | @staticmethod 280 | def chunked( 281 | func, 282 | tensor: torch.Tensor, 283 | initial_step: int, 284 | *args, 285 | toma_dimension: int = None, 286 | toma_context=None, 287 | toma_cache_type: Type = DEFAULT_CACHE_TYPE, 288 | **kwargs, 289 | ): 290 | toma_dimension = toma_dimension or 0 291 | 292 | def body(start: int, end: int): 293 | return func(tensor.narrow(dim=toma_dimension, start=start, length=end - start), start, end, *args, **kwargs) 294 | 295 | explicit.range( 296 | body, 297 | 0, 298 | tensor.shape[toma_dimension], 299 | initial_step, 300 | *args, 301 | toma_context=toma_context or func, 302 | toma_cache_type=toma_cache_type, 303 | ) 304 | 305 | 306 | def should_reduce_batch_size(exception): 307 | return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception) 308 | -------------------------------------------------------------------------------- /toma/batchsize_cache.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | import toma.cpu_memory 6 | from toma import stacktrace as tst, torch_cuda_memory as tcm 7 | import weakref 8 | 9 | 10 | @dataclass 11 | class Batchsize: 12 | value: Optional[int] = None 13 | 14 | def set_initial_batchsize(self, initial_batchsize: int): 15 | if not self.value: 16 | self.value = initial_batchsize 17 | 18 | def get(self) -> int: 19 | return self.value 20 | 21 | def decrease_batchsize(self): 22 | self.value //= 2 23 | assert self.value > 0 24 | 25 | 26 | class BatchsizeCache: 27 | all_instances = weakref.WeakValueDictionary() 28 | 29 | def __init__(self): 30 | stacktrace = tst.get_simple_traceback(2) 31 | BatchsizeCache.all_instances[stacktrace] = self 32 | 33 | def get_batchsize(self, initial_batchsize: int) -> Batchsize: 34 | raise NotImplementedError() 35 | 36 | 37 | @dataclass 38 | class NoBatchsizeCache(BatchsizeCache): 39 | def get_batchsize(self, initial_batchsize: int) -> Batchsize: 40 | return Batchsize(initial_batchsize) 41 | 42 | 43 | @dataclass 44 | class 
GlobalBatchsizeCache(BatchsizeCache): 45 | batchsize: Optional[Batchsize] = None 46 | 47 | def get_batchsize(self, initial_batchsize: int) -> Batchsize: 48 | if not self.batchsize: 49 | self.batchsize = Batchsize(initial_batchsize) 50 | return self.batchsize 51 | 52 | 53 | class StacktraceMemoryBatchsizeCache(BatchsizeCache): 54 | LRU_CACHE_SIZE: int = 2 ** 16 55 | MEMORY_GRANULARITY: int = 2 ** 28 56 | TRACK_RAM: bool = True 57 | 58 | initial_batchsize: Optional[int] 59 | 60 | def __init__(self, lru_cache_size=None): 61 | super().__init__() 62 | 63 | self.initial_batchsize = None 64 | 65 | @functools.lru_cache(lru_cache_size or StacktraceMemoryBatchsizeCache.LRU_CACHE_SIZE) 66 | def get_batchsize_from_cache(stacktrace, cpu_available_memory, gpu_available_memory): 67 | return Batchsize(self.initial_batchsize) 68 | 69 | self.get_batchsize_from_cache = get_batchsize_from_cache 70 | 71 | def get_batchsize(self, initial_batchsize: int): 72 | stacktrace = tst.get_simple_traceback(2) 73 | 74 | if self.TRACK_RAM: 75 | cpu_available_memory = int(toma.cpu_memory.get_available_cpu_memory() // self.MEMORY_GRANULARITY) 76 | else: 77 | cpu_available_memory = -1 78 | 79 | gpu_available_memory = int(tcm.get_cuda_assumed_available_memory() // self.MEMORY_GRANULARITY) 80 | 81 | batchsize = self.get_batchsize_from_cache(stacktrace, cpu_available_memory, gpu_available_memory) 82 | batchsize.set_initial_batchsize(initial_batchsize) 83 | return batchsize 84 | -------------------------------------------------------------------------------- /toma/cpu_memory.py: -------------------------------------------------------------------------------- 1 | import psutil 2 | 3 | 4 | def get_available_cpu_memory(): 5 | this_process = psutil.Process() 6 | available_memory = psutil.virtual_memory().available 7 | 8 | try: 9 | import resource 10 | 11 | soft_mem_limit, hard_mem_limit = resource.getrlimit(resource.RLIMIT_AS) 12 | if hard_mem_limit != resource.RLIM_INFINITY: 13 | used_memory = this_process.memory_info().vms 14 | available_memory = min(hard_mem_limit - used_memory, available_memory) 15 | except ImportError: 16 | pass 17 | 18 | return available_memory 19 | 20 | 21 | def set_cpu_memory_limit(num_gigabytes): 22 | try: 23 | import resource 24 | 25 | num_bytes = int(num_gigabytes * 2 ** 30) 26 | _, hard_limit = resource.getrlimit(resource.RLIMIT_AS) 27 | if hard_limit != resource.RLIM_INFINITY: 28 | hard_limit = min(num_bytes, hard_limit) 29 | else: 30 | hard_limit = num_bytes 31 | resource.setrlimit(resource.RLIMIT_AS, (hard_limit, hard_limit)) 32 | except ImportError: 33 | pass 34 | 35 | 36 | def is_out_of_cpu_memory(exception): 37 | return ( 38 | isinstance(exception, RuntimeError) 39 | and len(exception.args) == 1 40 | and "DefaultCPUAllocator: can't allocate memory" in exception.args[0] 41 | ) 42 | -------------------------------------------------------------------------------- /toma/stacktrace.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import inspect 3 | from contextlib import contextmanager 4 | 5 | __watermark = 0 6 | 7 | 8 | def _constant_code_context(code_context): 9 | if not code_context: 10 | return None 11 | if len(code_context) == 1: 12 | return code_context[0] 13 | return tuple(code_context) 14 | 15 | 16 | def get_simple_traceback(ignore_top=0): 17 | """Get a simple trackback that can be hashed and won't create reference 18 | cyles.""" 19 | stack = inspect.stack(context=1)[ignore_top + 1 : -__watermark - 1] 20 | simple_traceback = tuple( 21 | 
(fi.filename, fi.lineno, fi.function, _constant_code_context(fi.code_context), fi.index) for fi in stack 22 | ) 23 | return simple_traceback 24 | 25 | 26 | @contextmanager 27 | def watermark(): 28 | global __watermark 29 | old_watermark = __watermark 30 | 31 | # Remove the entries for `watermark` and 32 | # `contextmanager.__enter__` and for the with block. 33 | # Remove another one to keep the caller. 34 | __watermark = len(inspect.stack(context=0)) - 4 35 | 36 | try: 37 | yield 38 | finally: 39 | __watermark = old_watermark 40 | 41 | 42 | def set_watermark(func): 43 | @functools.wraps(func) 44 | def watermark_wrapper(*args, **kwargs): 45 | global __watermark 46 | old_watermark = __watermark 47 | 48 | # Remove the entries for `watermark` and 49 | # `contextmanager.__enter__`. 50 | # Dump frames for this wrapper. 51 | __watermark = len(inspect.stack(context=0)) - 1 52 | 53 | try: 54 | return func(*args, **kwargs) 55 | finally: 56 | __watermark = old_watermark 57 | 58 | return watermark_wrapper 59 | -------------------------------------------------------------------------------- /toma/torch_cuda_memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper to free Torch cuda memory and determine when a Torch exception might be 3 | because of OOM conditions. 4 | """ 5 | from __future__ import print_function 6 | import torch 7 | import gc 8 | 9 | 10 | def gc_cuda(): 11 | """Gargage collect Torch (CUDA) memory.""" 12 | gc.collect() 13 | if torch.cuda.is_available(): 14 | torch.cuda.empty_cache() 15 | 16 | 17 | def get_cuda_total_memory(): 18 | if torch.cuda.is_available(): 19 | return torch.cuda.get_device_properties(0).total_memory 20 | return 0 21 | 22 | 23 | def get_cuda_assumed_available_memory(): 24 | if torch.cuda.is_available(): 25 | return get_cuda_total_memory() - torch.cuda.memory_reserved() 26 | return 0 27 | 28 | 29 | def get_cuda_available_memory(): 30 | # Always allow for 1 GB overhead. 31 | if torch.cuda.is_available(): 32 | return get_cuda_assumed_available_memory() - get_cuda_blocked_memory() 33 | return 0 34 | 35 | 36 | def get_cuda_blocked_memory(): 37 | if not torch.cuda.is_available(): 38 | return 0 39 | 40 | available_memory = get_cuda_assumed_available_memory() 41 | current_block = available_memory - 2 ** 28 # 256 MB steps 42 | while True: 43 | try: 44 | block = torch.empty((current_block,), dtype=torch.uint8, device="cuda") 45 | break 46 | except RuntimeError as exception: 47 | if is_cuda_out_of_memory(exception): 48 | current_block -= 2 ** 30 49 | if current_block <= 0: 50 | return available_memory 51 | else: 52 | raise 53 | block = None 54 | gc_cuda() 55 | return available_memory - current_block 56 | 57 | 58 | def is_cuda_out_of_memory(exception): 59 | return ( 60 | isinstance(exception, RuntimeError) and len(exception.args) == 1 and "CUDA out of memory." in exception.args[0] 61 | ) 62 | 63 | 64 | def is_cudnn_snafu(exception): 65 | # For/because of https://github.com/pytorch/pytorch/issues/4107 66 | return ( 67 | isinstance(exception, RuntimeError) 68 | and len(exception.args) == 1 69 | and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." 
in exception.args[0] 70 | ) 71 | 72 | 73 | def cuda_meminfo(): 74 | if not torch.cuda.is_available(): 75 | return 76 | 77 | print( 78 | "Total:", torch.cuda.memory_allocated() / 2 ** 30, " GB Cached: ", torch.cuda.memory_reserved() / 2 ** 30, "GB" 79 | ) 80 | print( 81 | "Max Total:", 82 | torch.cuda.max_memory_allocated() / 2 ** 30, 83 | " GB Max Cached: ", 84 | torch.cuda.max_memory_reserved() / 2 ** 30, 85 | "GB", 86 | ) 87 | --------------------------------------------------------------------------------