├── .cruft.json ├── .github ├── CONTRIBUTING.md ├── codecov.yml └── workflows │ ├── cruft.yml │ └── tests.yml ├── .gitignore ├── .readthedocs.yml ├── LICENSE ├── README.md ├── docs └── source │ ├── conf.py │ ├── index.rst │ ├── installation.rst │ └── usage.rst ├── pyproject.toml ├── src └── torch_max_mem │ ├── __init__.py │ ├── api.py │ ├── py.typed │ └── version.py ├── tests ├── __init__.py └── test_decorator.py └── tox.ini /.cruft.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "https://github.com/cthoyt/cookiecutter-snekpack", 3 | "commit": "9bad1ea3f001d8630c8a6e54fffc7d8d0f3e7b42", 4 | "checkout": null, 5 | "context": { 6 | "cookiecutter": { 7 | "package_name": "torch_max_mem", 8 | "package_name_stylized": "Torch Max Mem", 9 | "short_description": "Maximize memory utilization with PyTorch.", 10 | "author_name": "Max Berrendorf", 11 | "author_github": "mberr", 12 | "author_email": "max.berrendorf@gmail.com", 13 | "github_organization_name": "mberr", 14 | "github_repository_name": "torch-max-mem", 15 | "command_line_interface": false, 16 | "gitlab": false, 17 | "runner": "tox", 18 | "__not_charlie": "true", 19 | "__runner": "tox -e", 20 | "__runner_uv": "--with tox-uv tox -e", 21 | "__runner_pip": "tox tox-uv", 22 | "__runner_install_uv": "uv tool install tox --with tox-uv", 23 | "__runner_install_pip": "python3 -m pip install tox tox-uv", 24 | "__runner_tests": "py", 25 | "__gh_slug": "mberr/torch-max-mem", 26 | "_template": "https://github.com/cthoyt/cookiecutter-snekpack", 27 | "_commit": "9bad1ea3f001d8630c8a6e54fffc7d8d0f3e7b42" 28 | } 29 | }, 30 | "directory": null 31 | } 32 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions to this repository are welcomed and encouraged. 
4 | 5 | ## Code Contribution 6 | 7 | This project uses the [GitHub Flow](https://guides.github.com/introduction/flow) 8 | model for code contributions. Follow these steps: 9 | 10 | 1. [Create a fork](https://help.github.com/articles/fork-a-repo) of the upstream 11 | repository at [`mberr/torch-max-mem`](https://github.com/mberr/torch-max-mem) 12 | on your GitHub account (or in one of your organizations) 13 | 2. [Clone your fork](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository) 14 | with `git clone https://github.com//torch-max-mem.git` 15 | 3. Make and commit changes to your fork with `git commit` 16 | 4. Push changes to your fork with `git push` 17 | 5. Repeat steps 3 and 4 as needed 18 | 6. Submit a pull request back to the upstream repository 19 | 20 | ### Merge Model 21 | 22 | This repository uses 23 | [squash merges](https://docs.github.com/en/github/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/about-pull-request-merges#squash-and-merge-your-pull-request-commits) 24 | to group all related commits in a given pull request into a single commit upon 25 | acceptance and merge into the main branch. This has several benefits: 26 | 27 | 1. Keeps the commit history on the main branch focused on high-level narrative 28 | 2. Enables people to make lots of small commits without worrying about muddying 29 | up the commit history 30 | 3. Commits correspond 1-to-1 with pull requests 31 | 32 | ### Code Style 33 | 34 | This project uses `tox` for running code quality checks. Start by installing it 35 | with `pip install tox tox-uv`. 36 | 37 | This project encourages the use of optional static typing. It uses 38 | [`mypy`](http://mypy-lang.org/) as a type checker. You can check if your code 39 | passes `mypy` with `tox -e mypy`. 40 | 41 | This project uses [`ruff`](https://docs.astral.sh/ruff/) to automatically 42 | enforce a consistent code style. 
You can apply `ruff format` and other 43 | pre-configured formatters with `tox -e format`. 44 | 45 | This project uses [`ruff`](https://docs.astral.sh/ruff/) and several plugins for 46 | additional checks of documentation style, security issues, good variable 47 | nomenclature, and more (see `pyproject.toml` for a list of Ruff plugins). You 48 | can check if your code passes `ruff check` with `tox -e lint`. 49 | 50 | Each of these checks are run on each commit using GitHub Actions as a continuous 51 | integration service. Passing all of them is required for accepting a 52 | contribution. If you're unsure how to address the feedback from one of these 53 | tools, please say so either in the description of your pull request or in a 54 | comment, and we will help you. 55 | 56 | ### Logging 57 | 58 | Python's builtin `print()` should not be used (except when writing to files), 59 | it's checked by the 60 | [`flake8-print` (T20)](https://docs.astral.sh/ruff/rules/#flake8-print-t20) 61 | plugin to `ruff`. If you're in a command line setting or `main()` function for a 62 | module, you can use `click.echo()`. Otherwise, you can use the builtin `logging` 63 | module by adding `logger = logging.getLogger(__name__)` below the imports at the 64 | top of your file. 65 | 66 | ### Documentation 67 | 68 | All public functions (i.e., not starting with an underscore `_`) must be 69 | documented using the 70 | [sphinx documentation format](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format). 71 | The [`darglint2`](https://github.com/akaihola/darglint2) tool reports on 72 | functions that are not fully documented. 73 | 74 | This project uses [`sphinx`](https://www.sphinx-doc.org) to automatically build 75 | documentation into a narrative structure. You can check that the documentation 76 | builds properly in an isolated environment with `tox -e docs-test` and actually 77 | build it locally with `tox -e docs`. 
78 | 79 | ### Testing 80 | 81 | Functions in this repository should be unit tested. These can either be written 82 | using the `unittest` framework in the `tests/` directory or as embedded 83 | doctests. You can check that the unit tests pass with `tox -e py` and that the 84 | doctests pass with `tox -e doctests`. These tests are required to pass for 85 | accepting a contribution. 86 | 87 | ### Syncing your fork 88 | 89 | If other code is updated before your contribution gets merged, you might need to 90 | resolve conflicts against the main branch. After cloning, you should add the 91 | upstream repository with 92 | 93 | ```shell 94 | $ git remote add mberr https://github.com/mberr/torch-max-mem.git 95 | ``` 96 | 97 | Then, you can merge upstream code into your branch. You can also use the GitHub 98 | UI to do this by following 99 | [this tutorial](https://docs.github.com/en/github/collaborating-with-pull-requests/working-with-forks/syncing-a-fork). 100 | 101 | ### Python Version Compatibility 102 | 103 | This project aims to support all versions of Python that have not passed their 104 | end-of-life dates. After end-of-life, the version will be removed from the Trove 105 | qualifiers in the `pyproject.toml` and from the GitHub Actions testing 106 | configuration. 107 | 108 | See https://endoflife.date/python for a timeline of Python release and 109 | end-of-life dates. 110 | 111 | ## Acknowledgements 112 | 113 | These code contribution guidelines are derived from the 114 | [cthoyt/cookiecutter-snekpack](https://github.com/cthoyt/cookiecutter-snekpack) 115 | Python package template. They're free to reuse and modify as long as they're 116 | properly acknowledged. 
117 | -------------------------------------------------------------------------------- /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | # see https://docs.codecov.com/v4.6/docs/codecov-yaml 2 | ignore: 3 | - "src/torch_max_mem/__main__.py" 4 | - "src/torch_max_mem/cli.py" 5 | -------------------------------------------------------------------------------- /.github/workflows/cruft.yml: -------------------------------------------------------------------------------- 1 | # from https://cruft.github.io/cruft/#automating-updates-with-github-actions 2 | 3 | name: Update repository with Cruft 4 | 5 | permissions: { } 6 | 7 | on: 8 | workflow_dispatch: 9 | schedule: 10 | - cron: "0 2 * * 1" # Every Monday at 2am 11 | 12 | jobs: 13 | update: 14 | permissions: 15 | contents: write 16 | pull-requests: write 17 | runs-on: ubuntu-latest 18 | strategy: 19 | fail-fast: true 20 | matrix: 21 | include: 22 | - add-paths: . 23 | body: Use this to merge the changes to this repository. 24 | branch: cruft/update 25 | commit-message: "chore: accept new Cruft update" 26 | title: New updates detected with Cruft 27 | - add-paths: .cruft.json 28 | body: Use this to reject the changes in this repository. 29 | branch: cruft/reject 30 | commit-message: "chore: reject new Cruft update" 31 | title: Reject new updates detected with Cruft 32 | steps: 33 | - uses: actions/checkout@v3 34 | 35 | - uses: actions/setup-python@v4 36 | with: 37 | python-version: "3.10" 38 | 39 | - name: Install Cruft 40 | run: pip3 install cruft 41 | 42 | - name: Check if update is available 43 | continue-on-error: false 44 | id: check 45 | run: | 46 | CHANGES=0 47 | if [ -f .cruft.json ]; then 48 | if ! 
cruft check; then 49 | CHANGES=1 50 | fi 51 | else 52 | echo "No .cruft.json file" 53 | fi 54 | 55 | echo "has_changes=$CHANGES" >> "$GITHUB_OUTPUT" 56 | 57 | - name: Run update if available 58 | if: steps.check.outputs.has_changes == '1' 59 | run: | 60 | git config --global user.email "you@example.com" 61 | git config --global user.name "GitHub" 62 | 63 | cruft update --skip-apply-ask --refresh-private-variables 64 | git restore --staged . 65 | 66 | - name: Create pull request 67 | if: steps.check.outputs.has_changes == '1' 68 | uses: peter-evans/create-pull-request@v4 69 | with: 70 | token: ${{ secrets.GITHUB_TOKEN }} 71 | add-paths: ${{ matrix.add-paths }} 72 | commit-message: ${{ matrix.commit-message }} 73 | branch: ${{ matrix.branch }} 74 | delete-branch: true 75 | branch-suffix: timestamp 76 | title: ${{ matrix.title }} 77 | body: | 78 | This is an autogenerated PR. ${{ matrix.body }} 79 | 80 | [Cruft](https://cruft.github.io/cruft/) has detected updates from the Cookiecutter repository. 81 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | # This file configures the continuous integration (CI) system on GitHub. 2 | # Introductory materials can be found here: https://docs.github.com/en/actions/learn-github-actions/understanding-github-actions. 
3 | # Documentation for editing this file can be found here: https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions 4 | 5 | name: Tests 6 | 7 | # by default, give the GITHUB_TOKEN no permissions 8 | # See https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/controlling-permissions-for-github_token 9 | permissions: { } 10 | 11 | on: 12 | push: 13 | branches: [ main ] 14 | pull_request: 15 | branches: [ main ] 16 | 17 | jobs: 18 | lint: 19 | name: Code Quality 20 | permissions: 21 | # give only read-only access to the contents of the repository 22 | # this is the only permission this job requires, so keep it to the least privilege 23 | # i.e., not to issues, discussions, actions, etc. 24 | contents: read 25 | runs-on: ubuntu-latest 26 | strategy: 27 | matrix: 28 | python-version: [ "3.13", "3.9" ] 29 | tox-command: [ "lint", "pyroma", "mypy" ] 30 | steps: 31 | - uses: actions/checkout@v4 32 | - name: "Install uv" 33 | uses: "astral-sh/setup-uv@v3" 34 | with: 35 | enable-cache: true 36 | cache-dependency-glob: "pyproject.toml" 37 | - name: "Run command" 38 | run: | 39 | uvx -p ${{ matrix.python-version }} --with tox-uv tox -e ${{ matrix.tox-command }} 40 | 41 | docs: 42 | name: Documentation 43 | permissions: 44 | contents: read 45 | runs-on: ubuntu-latest 46 | strategy: 47 | matrix: 48 | # We only test documentation on the latest version 49 | # sphinx 8.0 / sphinx-rtd-theme 3.0 discontinued Python 3.9 support 50 | # a year early, which prompted re-thinking about this. 
51 | python-version: [ "3.13" ] 52 | steps: 53 | - uses: actions/checkout@v4 54 | - name: "Install uv" 55 | uses: "astral-sh/setup-uv@v3" 56 | with: 57 | enable-cache: true 58 | cache-dependency-glob: "pyproject.toml" 59 | - name: Install dependencies 60 | run: | 61 | sudo apt-get install graphviz 62 | - name: Lint documentation 63 | run: uvx -p ${{ matrix.python-version }} --with tox-uv tox -e docs-lint 64 | - name: Check docstring coverage 65 | run: uvx -p ${{ matrix.python-version }} --with tox-uv tox -e docstr-coverage 66 | - name: Check documentation build with Sphinx 67 | run: uvx -p ${{ matrix.python-version }} --with tox-uv tox -e docs-test 68 | - name: Lint markdown 69 | run: uvx -p ${{ matrix.python-version }} --with tox-uv tox -e lint-markdown 70 | 71 | tests: 72 | name: Tests 73 | permissions: 74 | contents: read 75 | runs-on: ${{ matrix.os }} 76 | strategy: 77 | matrix: 78 | os: [ ubuntu-latest ] 79 | python-version: [ "3.13", "3.9" ] 80 | steps: 81 | - uses: actions/checkout@v4 82 | - name: "Install uv" 83 | uses: "astral-sh/setup-uv@v3" 84 | with: 85 | enable-cache: true 86 | cache-dependency-glob: "pyproject.toml" 87 | - name: Test with pytest and generate coverage file 88 | run: 89 | uvx -p ${{ matrix.python-version }} --with tox-uv tox -e py 90 | - name: Run doctests 91 | run: 92 | uvx -p ${{ matrix.python-version }} --with tox-uv tox -e doctests 93 | - name: Upload coverage report to codecov 94 | uses: codecov/codecov-action@v4 95 | if: success() 96 | with: 97 | file: coverage.xml 98 | 99 | concurrency: 100 | group: ${{ github.workflow }}-${{ github.ref }} 101 | cancel-in-progress: true 102 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/macos,linux,pycharm,python,windows 2 | # Edit at 
https://www.toptal.com/developers/gitignore?templates=macos,linux,pycharm,python,windows 3 | 4 | ### Linux ### 5 | *~ 6 | 7 | # temporary files which can be created if a process still has a handle open of a deleted file 8 | .fuse_hidden* 9 | 10 | # KDE directory preferences 11 | .directory 12 | 13 | # Linux trash folder which might appear on any partition or disk 14 | .Trash-* 15 | 16 | # .nfs files are created when an open file is removed but is still being accessed 17 | .nfs* 18 | 19 | ### macOS ### 20 | # General 21 | .DS_Store 22 | .AppleDouble 23 | .LSOverride 24 | 25 | # Icon must end with two \r 26 | Icon 27 | 28 | 29 | # Thumbnails 30 | ._* 31 | 32 | # Files that might appear in the root of a volume 33 | .DocumentRevisions-V100 34 | .fseventsd 35 | .Spotlight-V100 36 | .TemporaryItems 37 | .Trashes 38 | .VolumeIcon.icns 39 | .com.apple.timemachine.donotpresent 40 | 41 | # Directories potentially created on remote AFP share 42 | .AppleDB 43 | .AppleDesktop 44 | Network Trash Folder 45 | Temporary Items 46 | .apdisk 47 | 48 | ### PyCharm ### 49 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 50 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 51 | 52 | # User-specific stuff 53 | .idea/**/workspace.xml 54 | .idea/**/tasks.xml 55 | .idea/**/usage.statistics.xml 56 | .idea/**/dictionaries 57 | .idea/**/shelf 58 | 59 | # Generated files 60 | .idea/**/contentModel.xml 61 | 62 | # Sensitive or high-churn files 63 | .idea/**/dataSources/ 64 | .idea/**/dataSources.ids 65 | .idea/**/dataSources.local.xml 66 | .idea/**/sqlDataSources.xml 67 | .idea/**/dynamic.xml 68 | .idea/**/uiDesigner.xml 69 | .idea/**/dbnavigator.xml 70 | 71 | # Gradle 72 | .idea/**/gradle.xml 73 | .idea/**/libraries 74 | 75 | # Gradle and Maven with auto-import 76 | # When using Gradle or Maven with auto-import, you should exclude module files, 77 | # since they will be recreated, and may cause 
churn. Uncomment if using 78 | # auto-import. 79 | # .idea/artifacts 80 | # .idea/compiler.xml 81 | # .idea/jarRepositories.xml 82 | # .idea/modules.xml 83 | # .idea/*.iml 84 | # .idea/modules 85 | # *.iml 86 | # *.ipr 87 | 88 | # CMake 89 | cmake-build-*/ 90 | 91 | # Mongo Explorer plugin 92 | .idea/**/mongoSettings.xml 93 | 94 | # File-based project format 95 | *.iws 96 | 97 | # IntelliJ 98 | out/ 99 | 100 | # mpeltonen/sbt-idea plugin 101 | .idea_modules/ 102 | 103 | # JIRA plugin 104 | atlassian-ide-plugin.xml 105 | 106 | # Cursive Clojure plugin 107 | .idea/replstate.xml 108 | 109 | # Crashlytics plugin (for Android Studio and IntelliJ) 110 | com_crashlytics_export_strings.xml 111 | crashlytics.properties 112 | crashlytics-build.properties 113 | fabric.properties 114 | 115 | # Editor-based Rest Client 116 | .idea/httpRequests 117 | 118 | # Android studio 3.1+ serialized cache file 119 | .idea/caches/build_file_checksums.ser 120 | 121 | ### PyCharm Patch ### 122 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 123 | 124 | # *.iml 125 | # modules.xml 126 | # .idea/misc.xml 127 | # *.ipr 128 | 129 | # Sonarlint plugin 130 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 131 | .idea/**/sonarlint/ 132 | 133 | # SonarQube Plugin 134 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 135 | .idea/**/sonarIssues.xml 136 | 137 | # Markdown Navigator plugin 138 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 139 | .idea/**/markdown-navigator.xml 140 | .idea/**/markdown-navigator-enh.xml 141 | .idea/**/markdown-navigator/ 142 | 143 | # Cache file creation bug 144 | # See https://youtrack.jetbrains.com/issue/JBR-2257 145 | .idea/$CACHE_FILE$ 146 | 147 | # CodeStream plugin 148 | # https://plugins.jetbrains.com/plugin/12206-codestream 149 | .idea/codestream.xml 150 | 151 | ### Python ### 152 | # Byte-compiled / optimized / DLL files 153 | __pycache__/ 154 | *.py[cod] 155 | 
*$py.class 156 | 157 | # C extensions 158 | *.so 159 | 160 | # Distribution / packaging 161 | .Python 162 | build/ 163 | develop-eggs/ 164 | dist/ 165 | downloads/ 166 | eggs/ 167 | .eggs/ 168 | lib/ 169 | lib64/ 170 | parts/ 171 | sdist/ 172 | var/ 173 | wheels/ 174 | pip-wheel-metadata/ 175 | share/python-wheels/ 176 | *.egg-info/ 177 | .installed.cfg 178 | *.egg 179 | MANIFEST 180 | 181 | # PyInstaller 182 | # Usually these files are written by a python script from a template 183 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 184 | *.manifest 185 | *.spec 186 | 187 | # Installer logs 188 | pip-log.txt 189 | pip-delete-this-directory.txt 190 | 191 | # Unit test / coverage reports 192 | htmlcov/ 193 | .tox/ 194 | .nox/ 195 | .coverage 196 | .coverage.* 197 | .cache 198 | nosetests.xml 199 | coverage.xml 200 | *.cover 201 | *.py,cover 202 | .hypothesis/ 203 | .pytest_cache/ 204 | pytestdebug.log 205 | 206 | # Translations 207 | *.mo 208 | *.pot 209 | 210 | # Django stuff: 211 | *.log 212 | local_settings.py 213 | db.sqlite3 214 | db.sqlite3-journal 215 | 216 | # Flask stuff: 217 | instance/ 218 | .webassets-cache 219 | 220 | # Scrapy stuff: 221 | .scrapy 222 | 223 | # Sphinx documentation 224 | docs/_build/ 225 | doc/_build/ 226 | 227 | # PyBuilder 228 | target/ 229 | 230 | # Jupyter Notebook 231 | .ipynb_checkpoints 232 | 233 | # IPython 234 | profile_default/ 235 | ipython_config.py 236 | 237 | # pyenv 238 | .python-version 239 | 240 | # pipenv 241 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 242 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 243 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 244 | # install all needed dependencies. 245 | #Pipfile.lock 246 | 247 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 248 | __pypackages__/ 249 | 250 | # Celery stuff 251 | celerybeat-schedule 252 | celerybeat.pid 253 | 254 | # SageMath parsed files 255 | *.sage.py 256 | 257 | # Environments 258 | .env 259 | .venv 260 | env/ 261 | venv/ 262 | ENV/ 263 | env.bak/ 264 | venv.bak/ 265 | pythonenv* 266 | 267 | # Spyder project settings 268 | .spyderproject 269 | .spyproject 270 | 271 | # Rope project settings 272 | .ropeproject 273 | 274 | # mkdocs documentation 275 | /site 276 | 277 | # mypy 278 | .mypy_cache/ 279 | .dmypy.json 280 | dmypy.json 281 | 282 | # Pyre type checker 283 | .pyre/ 284 | 285 | # pytype static type analyzer 286 | .pytype/ 287 | 288 | # profiling data 289 | .prof 290 | 291 | ### Windows ### 292 | # Windows thumbnail cache files 293 | Thumbs.db 294 | Thumbs.db:encryptable 295 | ehthumbs.db 296 | ehthumbs_vista.db 297 | 298 | # Dump file 299 | *.stackdump 300 | 301 | # Folder config file 302 | [Dd]esktop.ini 303 | 304 | # Recycle Bin used on file shares 305 | $RECYCLE.BIN/ 306 | 307 | # Windows Installer files 308 | *.cab 309 | *.msi 310 | *.msix 311 | *.msm 312 | *.msp 313 | 314 | # Windows shortcuts 315 | *.lnk 316 | 317 | # End of https://www.toptal.com/developers/gitignore/api/macos,linux,pycharm,python,windows 318 | 319 | scratch/ 320 | 321 | .vscode 322 | .idea -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | version: 2 6 | 7 | sphinx: 8 | # Path to your Sphinx configuration file, required as of 9 | # https://about.readthedocs.com/blog/2024/12/deprecate-config-files-without-sphinx-or-mkdocs-config/ 10 | configuration: docs/source/conf.py 11 | 12 | # Set the version of Python and other tools you might need 13 | build: 14 | os: ubuntu-22.04 15 | 
apt_packages: 16 | - graphviz 17 | tools: 18 | python: "3.12" 19 | 20 | # adapted from uv recipe at https://docs.readthedocs.io/en/stable/build-customization.html#install-dependencies-with-uv 21 | # and comment at https://github.com/readthedocs/readthedocs.org/issues/11289#issuecomment-2103832834 22 | commands: 23 | - asdf plugin add uv 24 | - asdf install uv latest 25 | - asdf global uv latest 26 | - uv venv $READTHEDOCS_VIRTUALENV_PATH 27 | - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH uv pip install --group docs . 28 | - python -m sphinx -T -b html -d docs/_build/doctrees -D language=en docs/source $READTHEDOCS_OUTPUT/html 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Max Berrendorf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 6 | 7 |

8 | torch-max-mem 9 |

10 | 11 |

12 | 13 | Tests 14 | 15 | PyPI 16 | 17 | PyPI - Python Version 18 | 19 | PyPI - License 20 | 21 | Documentation Status 22 | 23 | Codecov status 24 | 25 | Cookiecutter template from @cthoyt 26 | 27 | Ruff 28 | 29 | Contributor Covenant 30 | 34 |

35 | 36 | This package provides decorators for memory utilization maximization with 37 | PyTorch and CUDA by starting with a maximum parameter size and applying 38 | successive halving until no more out-of-memory exception occurs. 39 | 40 | ## 💪 Getting Started 41 | 42 | Assume you have a function for batched computation of nearest neighbors using 43 | brute-force distance calculation. 44 | 45 | ```python 46 | import torch 47 | 48 | def knn(x, y, batch_size, k: int = 3): 49 | return torch.cat( 50 | [ 51 | torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices 52 | for start in range(0, x.shape[0], batch_size) 53 | ], 54 | dim=0, 55 | ) 56 | ``` 57 | 58 | With `torch_max_mem` you can decorate this function to reduce the batch size 59 | until no more out-of-memory error occurs. 60 | 61 | ```python 62 | import torch 63 | from torch_max_mem import maximize_memory_utilization 64 | 65 | 66 | @maximize_memory_utilization() 67 | def knn(x, y, batch_size, k: int = 3): 68 | return torch.cat( 69 | [ 70 | torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices 71 | for start in range(0, x.shape[0], batch_size) 72 | ], 73 | dim=0, 74 | ) 75 | ``` 76 | 77 | In the code, you can now always pass the largest sensible batch size, e.g., 78 | 79 | ```python 80 | x = torch.rand(100, 100, device="cuda") 81 | y = torch.rand(200, 100, device="cuda") 82 | knn(x, y, batch_size=x.shape[0]) 83 | ``` 84 | 85 | ## 🚀 Installation 86 | 87 | The most recent release can be installed from 88 | [PyPI](https://pypi.org/project/torch_max_mem/) with uv: 89 | 90 | ```console 91 | uv pip install torch_max_mem 92 | ``` 93 | 94 | or with pip: 95 | 96 | ```console 97 | python3 -m pip install torch_max_mem 98 | ``` 99 | 100 | The most recent code and data can be installed directly from GitHub with uv: 101 | 102 | ```console 103 | uv pip install git+https://github.com/mberr/torch-max-mem.git 104 | ``` 105 | 106 | or with pip: 107 | 108 | ```console 
109 | python3 -m pip install git+https://github.com/mberr/torch-max-mem.git 110 | ``` 111 | 112 | ## 👐 Contributing 113 | 114 | Contributions, whether filing an issue, making a pull request, or forking, are 115 | appreciated. See 116 | [CONTRIBUTING.md](https://github.com/mberr/torch-max-mem/blob/main/.github/CONTRIBUTING.md) 117 | for more information on getting involved. 118 | 119 | ## 👋 Attribution 120 | 121 | Parts of the logic have been developed with 122 | [Laurent Vermue](https://github.com/lvermue) for 123 | [PyKEEN](https://github.com/pykeen/pykeen). 124 | 125 | ### ⚖️ License 126 | 127 | The code in this package is licensed under the MIT License. 128 | 129 | ### 🍪 Cookiecutter 130 | 131 | This package was created with 132 | [@audreyfeldroy](https://github.com/audreyfeldroy)'s 133 | [cookiecutter](https://github.com/cookiecutter/cookiecutter) package using 134 | [@cthoyt](https://github.com/cthoyt)'s 135 | [cookiecutter-snekpack](https://github.com/cthoyt/cookiecutter-snekpack) 136 | template. 137 | 138 | ## 🛠️ For Developers 139 | 140 |
141 | See developer instructions 142 | 143 | The final section of the README is for if you want to get involved by making a 144 | code contribution. 145 | 146 | ### Development Installation 147 | 148 | To install in development mode, use the following: 149 | 150 | ```console 151 | git clone https://github.com/mberr/torch-max-mem.git 152 | cd torch-max-mem 153 | uv pip install -e . 154 | ``` 155 | 156 | Alternatively, install using pip: 157 | 158 | ```console 159 | python3 -m pip install -e . 160 | ``` 161 | 162 | ### Updating Package Boilerplate 163 | 164 | This project uses `cruft` to keep boilerplate (i.e., configuration, contribution 165 | guidelines, documentation configuration) up-to-date with the upstream 166 | cookiecutter package. Install cruft with either `uv tool install cruft` or 167 | `python3 -m pip install cruft` then run: 168 | 169 | ```console 170 | cruft update 171 | ``` 172 | 173 | More info on Cruft's update command is available 174 | [here](https://github.com/cruft/cruft?tab=readme-ov-file#updating-a-project). 175 | 176 | ### 🥼 Testing 177 | 178 | After cloning the repository and installing `tox` with 179 | `uv tool install tox --with tox-uv` or `python3 -m pip install tox tox-uv`, the 180 | unit tests in the `tests/` folder can be run reproducibly with: 181 | 182 | ```console 183 | tox -e py 184 | ``` 185 | 186 | Additionally, these tests are automatically re-run with each commit in a 187 | [GitHub Action](https://github.com/mberr/torch-max-mem/actions?query=workflow%3ATests). 188 | 189 | ### 📖 Building the Documentation 190 | 191 | The documentation can be built locally using the following: 192 | 193 | ```console 194 | git clone https://github.com/mberr/torch-max-mem.git 195 | cd torch-max-mem 196 | tox -e docs 197 | open docs/build/html/index.html 198 | ``` 199 | 200 | The documentation automatically installs the package as well as the `docs` extra 201 | specified in the [`pyproject.toml`](pyproject.toml). 
`sphinx` plugins like 202 | `texext` can be added there. Additionally, they need to be added to the 203 | `extensions` list in [`docs/source/conf.py`](docs/source/conf.py). 204 | 205 | The documentation can be deployed to [ReadTheDocs](https://readthedocs.io) using 206 | [this guide](https://docs.readthedocs.io/en/stable/intro/import-guide.html). The 207 | [`.readthedocs.yml`](.readthedocs.yml) YAML file contains all the configuration 208 | you'll need. You can also set up continuous integration on GitHub to check not 209 | only that Sphinx can build the documentation in an isolated environment (i.e., 210 | with `tox -e docs-test`) but also that 211 | [ReadTheDocs can build it too](https://docs.readthedocs.io/en/stable/pull-requests.html). 212 | 213 | #### Configuring ReadTheDocs 214 | 215 | 1. Log in to ReadTheDocs with your GitHub account to install the integration at 216 | https://readthedocs.org/accounts/login/?next=/dashboard/ 217 | 2. Import your project by navigating to https://readthedocs.org/dashboard/import 218 | then clicking the plus icon next to your repository 219 | 3. You can rename the repository on the next screen using a more stylized name 220 | (i.e., with spaces and capital letters) 221 | 4. Click next, and you're good to go! 222 | 223 | ### 📦 Making a Release 224 | 225 | #### Configuring Zenodo 226 | 227 | [Zenodo](https://zenodo.org) is a long-term archival system that assigns a DOI 228 | to each release of your package. 229 | 230 | 1. Log in to Zenodo via GitHub with this link: 231 | https://zenodo.org/oauth/login/github/?next=%2F. This brings you to a page 232 | that lists all of your organizations and asks you to approve installing the 233 | Zenodo app on GitHub. Click "grant" next to any organizations you want to 234 | enable the integration for, then click the big green "approve" button. This 235 | step only needs to be done once. 236 | 2. 
Navigate to https://zenodo.org/account/settings/github/, which lists all of 237 | your GitHub repositories (both in your username and any organizations you 238 | enabled). Click the on/off toggle for any relevant repositories. When you 239 | make a new repository, you'll have to come back to this 240 | 241 | After these steps, you're ready to go! After you make "release" on GitHub (steps 242 | for this are below), you can navigate to 243 | https://zenodo.org/account/settings/github/repository/mberr/torch-max-mem to see 244 | the DOI for the release and link to the Zenodo record for it. 245 | 246 | #### Registering with the Python Package Index (PyPI) 247 | 248 | You only have to do the following steps once. 249 | 250 | 1. Register for an account on the 251 | [Python Package Index (PyPI)](https://pypi.org/account/register) 252 | 2. Navigate to https://pypi.org/manage/account and make sure you have verified 253 | your email address. A verification email might not have been sent by default, 254 | so you might have to click the "options" dropdown next to your address to get 255 | to the "re-send verification email" button 256 | 3. 2-Factor authentication is required for PyPI since the end of 2023 (see this 257 | [blog post from PyPI](https://blog.pypi.org/posts/2023-05-25-securing-pypi-with-2fa/)). 258 | This means you have to first issue account recovery codes, then set up 259 | 2-factor authentication 260 | 4. Issue an API token from https://pypi.org/manage/account/token 261 | 262 | #### Configuring your machine's connection to PyPI 263 | 264 | You have to do the following steps once per machine. 265 | 266 | ```console 267 | uv tool install keyring 268 | keyring set https://upload.pypi.org/legacy/ __token__ 269 | keyring set https://test.pypi.org/legacy/ __token__ 270 | ``` 271 | 272 | Note that this deprecates previous workflows using `.pypirc`. 
273 | 274 | #### Uploading to PyPI 275 | 276 | After installing the package in development mode and installing `tox` with 277 | `uv tool install tox --with tox-uv` or `python3 -m pip install tox tox-uv`, run 278 | the following from the console: 279 | 280 | ```console 281 | tox -e finish 282 | ``` 283 | 284 | This script does the following: 285 | 286 | 1. Uses [bump-my-version](https://github.com/callowayproject/bump-my-version) to 287 | switch the version number in the `pyproject.toml`, `CITATION.cff`, 288 | `src/torch_max_mem/version.py`, and 289 | [`docs/source/conf.py`](docs/source/conf.py) to not have the `-dev` suffix 290 | 2. Packages the code in both a tar archive and a wheel using 291 | [`uv build`](https://docs.astral.sh/uv/guides/publish/#building-your-package) 292 | 3. Uploads to PyPI using 293 | [`uv publish`](https://docs.astral.sh/uv/guides/publish/#publishing-your-package). 294 | 4. Push to GitHub. You'll need to make a release going with the commit where the 295 | version was bumped. 296 | 5. Bump the version to the next patch. If you made big changes and want to bump 297 | the version by minor, you can use `tox -e bumpversion -- minor` after. 298 | 299 | #### Releasing on GitHub 300 | 301 | 1. Navigate to https://github.com/mberr/torch-max-mem/releases/new to draft a 302 | new release 303 | 2. Click the "Choose a Tag" dropdown and select the tag corresponding to the 304 | release you just made 305 | 3. Click the "Generate Release Notes" button to get a quick outline of recent 306 | changes. Modify the title and description as you see fit 307 | 4. Click the big green "Publish Release" button 308 | 309 | This will trigger Zenodo to assign a DOI to your release as well. 310 | 311 |
312 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | """Configuration file for the Sphinx documentation builder. 2 | 3 | This file does only contain a selection of the most common options. For a full list see 4 | the documentation: http://www.sphinx-doc.org/en/master/config 5 | 6 | -- Path setup -------------------------------------------------------------- 7 | 8 | If extensions (or modules to document with autodoc) are in another directory, add these 9 | directories to ``sys.path`` here. If the directory is relative to the documentation 10 | root, use ``os.path.abspath`` to make it absolute, like shown here. 11 | """ 12 | 13 | import os 14 | import re 15 | import sys 16 | from datetime import date 17 | 18 | sys.path.insert(0, os.path.abspath("../../src")) 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = "torch_max_mem" 23 | copyright = f"{date.today().year}, Max Berrendorf" 24 | author = "Max Berrendorf" 25 | 26 | # The full version, including alpha/beta/rc tags. 27 | release = "0.1.5-dev" 28 | 29 | # The short X.Y version. 
30 | parsed_version = re.match( 31 | r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<release>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(?P<build>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?", 32 | release, 33 | ) 34 | version = parsed_version.expand(r"\g<major>.\g<minor>.\g<patch>") 35 | 36 | if parsed_version.group("release"): 37 | tags.add("prerelease")  # noqa:F821 38 | 39 | 40 | # See https://about.readthedocs.com/blog/2024/07/addons-by-default/ 41 | # Define the canonical URL if you are using a custom domain on Read the Docs 42 | html_baseurl = os.environ.get("READTHEDOCS_CANONICAL_URL", "") 43 | 44 | # See https://about.readthedocs.com/blog/2024/07/addons-by-default/ 45 | # Tell Jinja2 templates the build is running on Read the Docs 46 | if os.environ.get("READTHEDOCS", "") == "True": 47 | if "html_context" not in globals(): 48 | html_context = {} 49 | html_context["READTHEDOCS"] = True 50 | 51 | 52 | # -- General configuration --------------------------------------------------- 53 | 54 | # If your documentation needs a minimal Sphinx version, state it here. 55 | # 56 | # needs_sphinx = '1.0' 57 | 58 | # If true, the current module name will be prepended to all description 59 | # unit titles (such as .. function::). 60 | add_module_names = False 61 | 62 | # A list of prefixes that are ignored when creating the module index. (new in Sphinx 0.6) 63 | modindex_common_prefix = ["torch_max_mem."] 64 | 65 | # Add any Sphinx extension module names here, as strings. They can be 66 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 67 | # ones.
68 | extensions = [ 69 | "sphinx.ext.autosummary", 70 | "sphinx.ext.autodoc", 71 | "sphinx.ext.coverage", 72 | "sphinx.ext.intersphinx", 73 | "sphinx.ext.todo", 74 | "sphinx.ext.mathjax", 75 | "sphinx.ext.viewcode", 76 | "sphinx_automodapi.automodapi", 77 | "sphinx_automodapi.smart_resolver", 78 | # 'texext', 79 | ] 80 | 81 | 82 | extensions.append("sphinx_click.ext") 83 | 84 | 85 | # generate autosummary pages 86 | autosummary_generate = True 87 | 88 | # Add any paths that contain templates here, relative to this directory. 89 | templates_path = ["_templates"] 90 | 91 | # The suffix(es) of source filenames. 92 | # You can specify multiple suffix as a list of string: 93 | # 94 | # source_suffix = ['.rst', '.md'] 95 | source_suffix = { 96 | ".rst": "restructuredtext", 97 | } 98 | 99 | # The master toctree document. 100 | master_doc = "index" 101 | 102 | # The language for content autogenerated by Sphinx. Refer to documentation 103 | # for a list of supported languages. 104 | # 105 | # This is also used if you do content translation via gettext catalogs. 106 | # Usually you set "language" from the command line for these cases. 107 | language = "en" 108 | 109 | # List of patterns, relative to source directory, that match files and 110 | # directories to ignore when looking for source files. 111 | # This pattern also affects html_static_path and html_extra_path. 112 | exclude_patterns = [] 113 | 114 | # The name of the Pygments (syntax highlighting) style to use. 115 | pygments_style = "sphinx" 116 | 117 | # -- Options for HTML output ------------------------------------------------- 118 | 119 | # The theme to use for HTML and HTML Help pages. See the documentation for 120 | # a list of builtin themes. 121 | # 122 | html_theme = "sphinx_rtd_theme" 123 | 124 | # Theme options are theme-specific and customize the look and feel of a theme 125 | # further. For a list of options available for each theme, see the 126 | # documentation. 
127 | # 128 | # html_theme_options = {} 129 | 130 | # Add any paths that contain custom static files (such as style sheets) here, 131 | # relative to this directory. They are copied after the builtin static files, 132 | # so a file named "default.css" will overwrite the builtin "default.css". 133 | # html_static_path = ['_static'] 134 | 135 | # Custom sidebar templates, must be a dictionary that maps document names 136 | # to template names. 137 | # 138 | # The default sidebars (for documents that don't match any pattern) are 139 | # defined by theme itself. Builtin themes are using these templates by 140 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 141 | # 'searchbox.html']``. 142 | # 143 | # html_sidebars = {} 144 | 145 | # The name of an image file (relative to this directory) to place at the top 146 | # of the sidebar. 147 | # 148 | if os.path.exists("logo.png"): 149 | html_logo = "logo.png" 150 | 151 | # -- Options for HTMLHelp output --------------------------------------------- 152 | 153 | # Output file base name for HTML help builder. 154 | htmlhelp_basename = "torch_max_mem_doc" 155 | 156 | # -- Options for LaTeX output ------------------------------------------------ 157 | 158 | # latex_elements = { 159 | # The paper size ('letterpaper' or 'a4paper'). 160 | # 161 | # 'papersize': 'letterpaper', 162 | # 163 | # The font size ('10pt', '11pt' or '12pt'). 164 | # 165 | # 'pointsize': '10pt', 166 | # 167 | # Additional stuff for the LaTeX preamble. 168 | # 169 | # 'preamble': '', 170 | # 171 | # Latex figure (float) alignment 172 | # 173 | # 'figure_align': 'htbp', 174 | # } 175 | 176 | # Grouping the document tree into LaTeX files. List of tuples 177 | # (source start file, target name, title, 178 | # author, documentclass [howto, manual, or own class]). 
179 | # latex_documents = [ 180 | # ( 181 | # master_doc, 182 | # 'torch_max_mem.tex', 183 | # 'torch-max-mem Documentation', 184 | # author, 185 | # 'manual', 186 | # ), 187 | # ] 188 | 189 | # -- Options for manual page output ------------------------------------------ 190 | 191 | # One entry per manual page. List of tuples 192 | # (source start file, name, description, authors, manual section). 193 | man_pages = [ 194 | ( 195 | master_doc, 196 | "torch_max_mem", 197 | "torch-max-mem Documentation", 198 | [author], 199 | 1, 200 | ), 201 | ] 202 | 203 | # -- Options for Texinfo output ---------------------------------------------- 204 | 205 | # Grouping the document tree into Texinfo files. List of tuples 206 | # (source start file, target name, title, author, 207 | # dir menu entry, description, category) 208 | texinfo_documents = [ 209 | ( 210 | master_doc, 211 | "torch_max_mem", 212 | "torch-max-mem Documentation", 213 | author, 214 | "Max Berrendorf", 215 | "Maximize memory utilization with PyTorch.", 216 | "Miscellaneous", 217 | ), 218 | ] 219 | 220 | # -- Options for Epub output ------------------------------------------------- 221 | 222 | # Bibliographic Dublin Core info. 223 | # epub_title = project 224 | 225 | # The unique identifier of the text. This can be a ISBN number 226 | # or the project homepage. 227 | # 228 | # epub_identifier = '' 229 | 230 | # A unique identification for the text. 231 | # 232 | # epub_uid = '' 233 | 234 | # A list of files that should not be packed into the epub file. 235 | # epub_exclude_files = ['search.html'] 236 | 237 | # -- Extension configuration ------------------------------------------------- 238 | 239 | # -- Options for intersphinx extension --------------------------------------- 240 | 241 | # Example configuration for intersphinx: refer to the Python standard library. 
# Note: don't add trailing slashes, since sphinx adds "/objects.inv" to the end 243 | intersphinx_mapping = { 244 | "python": ("https://docs.python.org/3", None), 245 | "torch": ("https://pytorch.org/docs/stable", None), 246 | } 247 | 248 | autoclass_content = "both" 249 | 250 | # Don't sort alphabetically, explained at: 251 | # https://stackoverflow.com/questions/37209921/python-how-not-to-sort-sphinx-output-in-alphabetical-order 252 | autodoc_member_order = "bysource" 253 | 254 | todo_include_todos = True 255 | todo_emit_warnings = True 256 | 257 | # Output SVG inheritance diagrams 258 | graphviz_output_format = "svg" 259 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | torch-max-mem |release| Documentation 2 | ===================================== 3 | 4 | ``torch-max-mem`` is a package to enable automatic memory optimization / batch size selection via simple function 5 | decorators. 6 | 7 | Table of Contents 8 | ----------------- 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Getting Started 12 | :name: start 13 | 14 | installation 15 | usage 16 | 17 | Indices and Tables 18 | ------------------ 19 | * :ref:`genindex` 20 | * :ref:`modindex` 21 | * :ref:`search` 22 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | The most recent release can be installed from 4 | `PyPI <https://pypi.org/project/torch_max_mem>`_ with: 5 | 6 | .. code-block:: shell 7 | 8 | $ pip install torch_max_mem 9 | 10 | The most recent code and data can be installed directly from GitHub with: 11 | 12 | .. code-block:: shell 13 | 14 | $ pip install git+https://github.com/mberr/torch-max-mem.git 15 | 16 | To install in development mode, use the following: 17 | 18 | ..
code-block:: shell 19 | 20 | $ git clone git+https://github.com/mberr/torch-max-mem.git 21 | $ cd torch-max-mem 22 | $ pip install -e . 23 | -------------------------------------------------------------------------------- /docs/source/usage.rst: -------------------------------------------------------------------------------- 1 | Usage 2 | ===== 3 | .. automodule:: torch_max_mem.api 4 | :members: 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["uv_build>=0.6.6,<0.7"] 3 | build-backend = "uv_build" 4 | 5 | [project] 6 | name = "torch_max_mem" 7 | version = "0.1.5-dev" 8 | description = "Maximize memory utilization with PyTorch." 9 | readme = "README.md" 10 | authors = [{ name = "Max Berrendorf", email = "max.berrendorf@gmail.com" }] 11 | maintainers = [{ name = "Max Berrendorf", email = "max.berrendorf@gmail.com" }] 12 | 13 | # See https://packaging.python.org/en/latest/guides/writing-pyproject-toml/#classifiers 14 | # Search tags using the controlled vocabulary at https://pypi.org/classifiers 15 | classifiers = [ 16 | "Development Status :: 4 - Beta", 17 | "Environment :: Console", 18 | "Intended Audience :: Developers", 19 | "License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | "Framework :: Pytest", 22 | "Framework :: tox", 23 | "Framework :: Sphinx", 24 | "Programming Language :: Python", 25 | "Programming Language :: Python :: 3.9", 26 | "Programming Language :: Python :: 3.10", 27 | "Programming Language :: Python :: 3.11", 28 | "Programming Language :: Python :: 3.12", 29 | "Programming Language :: Python :: 3.13", 30 | "Programming Language :: Python :: 3 :: Only", 31 | # TODO add your topics from the Trove controlled vocabulary (see https://pypi.org/classifiers) 32 | ] 33 | keywords = [ 34 | "snekpack", # please keep this keyword to credit the 
cookiecutter-snekpack template 35 | "cookiecutter", 36 | "torch", 37 | ] 38 | 39 | # License Information. 40 | # See PEP-639 at https://peps.python.org/pep-0639/#add-license-files-key 41 | license-files = ["LICENSE"] 42 | 43 | requires-python = ">=3.9" 44 | dependencies = ["torch>=2.0", "typing_extensions"] 45 | 46 | # see https://peps.python.org/pep-0735/ and https://docs.astral.sh/uv/concepts/dependencies/#dependency-groups 47 | [dependency-groups] 48 | tests = ["pytest", "coverage[toml]"] 49 | docs = [ 50 | "sphinx>=8", 51 | "sphinx-rtd-theme>=3.0", 52 | "sphinx-click", 53 | "sphinx_automodapi", 54 | # Include if your project uses Pydantic: 55 | # "autodoc_pydantic", 56 | # To include LaTeX comments easily in your docs. 57 | # If you uncomment this, don't forget to do the same in docs/conf.py 58 | # texext 59 | ] 60 | lint = ["ruff"] 61 | typing = [ 62 | { include-group = "tests" }, 63 | "mypy", 64 | # You will probably have to add additional type stubs here, especially if you're using tox-uv 65 | ] 66 | docs-lint = [{ include-group = "docs" }, "doc8"] 67 | format-docs = [{ include-group = "docs" }, "docstrfmt"] 68 | doctests = ["xdoctest", "pygments"] 69 | pyroma = ["pyroma", "pygments"] 70 | # follow https://github.com/astral-sh/uv/issues/6298 for switching to a uv-based version bump workflow 71 | bump = ["bump-my-version"] 72 | build = ["uv", "uv-build"] 73 | release = [{ include-group = "build" }, "uv", "keyring"] 74 | 75 | # see https://packaging.python.org/en/latest/guides/writing-pyproject-toml/#dependencies-optional-dependencies 76 | # [project.optional-dependencies] 77 | 78 | [project.urls] 79 | Homepage = "https://github.com/mberr/torch-max-mem" 80 | Download = "https://github.com/mberr/torch-max-mem/releases" 81 | "Bug Tracker" = "https://github.com/mberr/torch-max-mem/issues" 82 | "Source Code" = "https://github.com/mberr/torch-max-mem" 83 | 84 | 85 | [tool.cruft] 86 | skip = ["**/__init__.py", "tests/*"] 87 | 88 | # MyPy, see 
https://mypy.readthedocs.io/en/stable/config_file.html 89 | [tool.mypy] 90 | 91 | # Doc8, see https://doc8.readthedocs.io/en/stable/readme.html#ini-file-usage 92 | [tool.doc8] 93 | max-line-length = 120 94 | 95 | # Pytest, see https://docs.pytest.org/en/stable/reference/customize.html#pyproject-toml 96 | [tool.pytest.ini_options] 97 | markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] 98 | 99 | # Coverage, see https://coverage.readthedocs.io/en/latest/config.html 100 | [tool.coverage.run] 101 | branch = true 102 | source = ["torch_max_mem"] 103 | omit = ["tests/*", "docs/*", "src/torch_max_mem/version.py"] 104 | 105 | [tool.coverage.paths] 106 | source = ["src/torch_max_mem", ".tox/*/lib/python*/site-packages/torch_max_mem"] 107 | 108 | [tool.coverage.report] 109 | show_missing = true 110 | exclude_lines = [ 111 | "pragma: no cover", 112 | "raise NotImplementedError", 113 | "if __name__ == \"__main__\":", 114 | "if TYPE_CHECKING:", 115 | "def __str__", 116 | "def __repr__", 117 | ] 118 | 119 | [tool.ruff] 120 | line-length = 120 121 | extend-include = ["*.ipynb"] 122 | 123 | [tool.ruff.lint] 124 | # See https://docs.astral.sh/ruff/rules 125 | extend-select = [ 126 | "F", # pyflakes 127 | "E", # pycodestyle errors 128 | "W", # pycodestyle warnings 129 | "C90", # mccabe 130 | "I", # isort 131 | "UP", # pyupgrade 132 | "D", # pydocstyle 133 | "DOC", # pydoclint 134 | "B", # bugbear 135 | "S", # bandit 136 | "T20", # print 137 | "N", # pep8 naming 138 | "ERA", # eradicate commented out code 139 | "NPY", # numpy checks 140 | "RUF", # ruff rules 141 | "C4", # comprehensions 142 | ] 143 | ignore = [ 144 | "D105", # Missing docstring in magic method 145 | "E203", # Black conflicts with the following 146 | ] 147 | 148 | # See https://docs.astral.sh/ruff/settings/#per-file-ignores 149 | [tool.ruff.lint.per-file-ignores] 150 | # Ignore security issues in the version.py, which are inconsistent 151 | "src/torch_max_mem/version.py" = ["S603", "S607"] 152 | 
# Ignore commented out code in Sphinx configuration file 153 | "docs/source/conf.py" = ["ERA001"] 154 | # Prints are okay in notebooks 155 | "notebooks/**/*.ipynb" = ["T201"] 156 | # Ignore asserts in tests (with pytest) 157 | "tests/**/*.py" = ["S101"] 158 | 159 | [tool.ruff.lint.pydocstyle] 160 | convention = "pep257" 161 | 162 | [tool.ruff.lint.isort] 163 | relative-imports-order = "closest-to-furthest" 164 | known-third-party = ["tqdm"] 165 | known-first-party = ["torch_max_mem", "tests"] 166 | 167 | [tool.ruff.format] 168 | # see https://docs.astral.sh/ruff/settings/#format_docstring-code-format 169 | docstring-code-format = true 170 | 171 | [tool.bumpversion] 172 | current_version = "0.1.5-dev" 173 | parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)(?:-(?P<release>[0-9A-Za-z-]+(?:\\.[0-9A-Za-z-]+)*))?(?:\\+(?P<build>[0-9A-Za-z-]+(?:\\.[0-9A-Za-z-]+)*))?" 174 | serialize = [ 175 | "{major}.{minor}.{patch}-{release}+{build}", 176 | "{major}.{minor}.{patch}+{build}", 177 | "{major}.{minor}.{patch}-{release}", 178 | "{major}.{minor}.{patch}", 179 | ] 180 | commit = true 181 | tag = false 182 | 183 | [tool.bumpversion.parts.release] 184 | optional_value = "production" 185 | first_value = "dev" 186 | values = ["dev", "production"] 187 | 188 | [[tool.bumpversion.files]] 189 | filename = "pyproject.toml" 190 | search = "version = \"{current_version}\"" 191 | replace = "version = \"{new_version}\"" 192 | 193 | [[tool.bumpversion.files]] 194 | filename = "docs/source/conf.py" 195 | search = "release = \"{current_version}\"" 196 | replace = "release = \"{new_version}\"" 197 | 198 | [[tool.bumpversion.files]] 199 | filename = "src/torch_max_mem/version.py" 200 | search = "VERSION = \"{current_version}\"" 201 | replace = "VERSION = \"{new_version}\"" 202 | -------------------------------------------------------------------------------- /src/torch_max_mem/__init__.py: -------------------------------------------------------------------------------- 1 | """Maximize memory utilization with PyTorch."""
2 | 3 | from .api import MemoryUtilizationMaximizer, maximize_memory_utilization 4 | 5 | __all__ = [ 6 | "MemoryUtilizationMaximizer", 7 | "maximize_memory_utilization", 8 | ] 9 | -------------------------------------------------------------------------------- /src/torch_max_mem/api.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the public API. 3 | 4 | Assume you have a function for batched computation of nearest neighbors using brute-force distance calculation. 5 | 6 | .. code-block:: python 7 | 8 | import torch 9 | 10 | 11 | def knn(x, y, batch_size, k: int = 3): 12 | return torch.cat( 13 | [ 14 | torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices 15 | for start in range(0, x.shape[0], batch_size) 16 | ], 17 | dim=0, 18 | ) 19 | 20 | Using :func:`maximize_memory_utilization` you can decorate this function to reduce the batch size until no more 21 | out-of-memory error occurs. 22 | 23 | .. code-block:: python 24 | 25 | import torch 26 | from torch_max_mem import maximize_memory_utilization 27 | 28 | 29 | @maximize_memory_utilization() 30 | def knn(x, y, batch_size, k: int = 3): 31 | return torch.cat( 32 | [ 33 | torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices 34 | for start in range(0, x.shape[0], batch_size) 35 | ], 36 | dim=0, 37 | ) 38 | 39 | 40 | In the code, you can now always pass the largest sensible batch size, e.g., 41 | 42 | .. code-block:: python 43 | 44 | x = torch.rand(100, 100, device="cuda") 45 | y = torch.rand(200, 100, device="cuda") 46 | knn(x, y, batch_size=x.shape[0]) 47 | """ 48 | 49 | # cf. 
https://gist.github.com/mberr/c37a8068b38cabc98228db2cbe358043 50 | from __future__ import annotations 51 | 52 | import functools 53 | import inspect 54 | import itertools 55 | import logging 56 | from collections.abc import Collection, Iterable, Mapping, MutableMapping, Sequence 57 | from typing import ( 58 | Any, 59 | Callable, 60 | TypeVar, 61 | ) 62 | 63 | import torch 64 | from typing_extensions import ParamSpec 65 | 66 | logger = logging.getLogger(__name__) 67 | 68 | __all__ = [ 69 | "maximize_memory_utilization", 70 | ] 71 | 72 | R = TypeVar("R") 73 | P = ParamSpec("P") 74 | 75 | 76 | def upgrade_to_sequence( 77 | parameter_name: str | Sequence[str], q: int | Sequence[int] 78 | ) -> tuple[tuple[str, ...], tuple[int, ...]]: 79 | """ 80 | Ensure that both, parameter names and q values, are provided as a sequence. 81 | 82 | Besides upgrading both to a tuple, it will also broadcast q if necessary. 83 | 84 | :param parameter_name: 85 | the parameter name, or a sequence thereof 86 | :param q: 87 | the q value, or a sequence thereof 88 | 89 | :return: 90 | a tuple of parameter names and a sequence of q values of same length 91 | 92 | :raises ValueError: 93 | when the (inferred) length of q and parameter_name do not match 94 | """ 95 | # normalize parameter name 96 | parameter_names = (parameter_name,) if isinstance(parameter_name, str) else tuple(parameter_name) 97 | q = (q,) if isinstance(q, int) else tuple(q) 98 | q = q * len(parameter_names) if len(q) == 1 else q 99 | if len(q) != len(parameter_names): 100 | raise ValueError(f"length of {q=} does not match length of {parameter_names=}") 101 | return parameter_names, q 102 | 103 | 104 | def determine_default_max_value( 105 | func: Callable[..., Any], parameter_name: str, signature: inspect.Signature 106 | ) -> int | None: 107 | """ 108 | Determine the default maximum value based on the signature. 
109 | 110 | :param func: 111 | the function; only used for nice error messages 112 | :param parameter_name: 113 | the name of the parameter 114 | :param signature: 115 | the signature of the function 116 | 117 | :return: 118 | the default value as an integer, if any is given. 119 | 120 | :raises ValueError: 121 | when the function does not have a parameter of the given name 122 | """ 123 | if parameter_name not in signature.parameters: 124 | raise ValueError(f"{func} does not have a parameter {parameter_name}.") 125 | _parameter = signature.parameters[parameter_name] 126 | if _parameter.annotation != inspect.Parameter.empty and _parameter.annotation not in ( 127 | int, 128 | "int", 129 | ): 130 | logger.warning( 131 | f"Memory utilization maximization is written for integer parameters, but the " 132 | f"{parameter_name} is annotated as {_parameter.annotation}; casting to int", 133 | ) 134 | if _parameter.default != inspect.Parameter.empty: 135 | return int(_parameter.default) 136 | return None 137 | 138 | 139 | def determine_max_value( 140 | bound_arguments: inspect.BoundArguments, 141 | parameter_name: str, 142 | default_max_value: int | Callable[P, int] | None, 143 | *args: P.args, 144 | **kwargs: P.kwargs, 145 | ) -> int: 146 | """ 147 | Either use the provided value, or the default maximum value. 
148 | 149 | :param bound_arguments: 150 | the bound arguments of the function 151 | :param args: 152 | the positional parameters of the function: necessary when the default max value is a callable 153 | :param kwargs: 154 | the keyword parameters of the function: necessary when the default max value is a callable 155 | :param parameter_name: 156 | the parameter name 157 | :param default_max_value: 158 | the default max value, or a callable to determine one 159 | 160 | :return: 161 | the maximum value 162 | 163 | :raises ValueError: 164 | when the given value to the parameter is None 165 | """ 166 | max_value = bound_arguments.arguments.get(parameter_name) 167 | if isinstance(max_value, int): 168 | return max_value 169 | if max_value is not None: 170 | raise ValueError(f"{parameter_name}={max_value!r} is neither integer nor None.") 171 | if default_max_value is None: 172 | raise ValueError("Neither value nor default value found") 173 | if isinstance(default_max_value, int): 174 | return default_max_value 175 | return default_max_value(*args, **kwargs) 176 | 177 | 178 | # cf. https://github.com/pykeen/pykeen/pull/279 179 | ADDITIONAL_OOM_ERROR_INFIXES = { 180 | # An error that occurs because the input in CUDA is too big. 181 | # cf. https://discuss.pytorch.org/t/cudnn-status-not-supported-this-error-may-appear-if-you-passed-in-a-non-contiguous-input/ # noqa: E501 182 | "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input.", 183 | # The torch < 2.0 way of OOM errors 184 | "CUDA out of memory.", 185 | # cf. https://github.com/pytorch/pytorch/issues/51871 186 | "nonzero is not supported for tensors with more than INT_MAX elements", 187 | # cf. 
https://discuss.pytorch.org/t/runtime-error-invalid-buffer-size-when-calculating-cosine-similarity/152088 188 | "Invalid buffer size: ", 189 | # MPS OOM error 190 | "MPS backend out of memory", 191 | # CPU OOM error 192 | "DefaultCPUAllocator: not enough memory:", 193 | } 194 | 195 | 196 | def iter_tensor_devices(*args: Any, **kwargs: Any) -> Iterable[torch.device]: 197 | """Iterate over tensors' devices (may contain duplicates).""" 198 | for obj in itertools.chain(args, kwargs.values()): 199 | if isinstance(obj, torch.Tensor): 200 | yield obj.device 201 | 202 | 203 | def create_tensor_checker( 204 | safe_devices: Collection[str] | None = None, 205 | ) -> Callable[P, None]: 206 | """ 207 | Create a function that warns when tensors are on any device that is not considered safe. 208 | 209 | :param safe_devices: 210 | these devices are considered safe, i.e., the program will receive meaningful exceptions to handle out of memory 211 | (OOM) issues. For example for CPU, OOM errors may trigger the operating system's OOM killer to directly 212 | terminate the process without any catchable exceptions. Defaults to ``{"cuda"}``. 213 | 214 | :return: 215 | a function that checks its parameters for tensors and emits a warning if any is on a non-safe device. 
216 | """ 217 | if safe_devices is None: 218 | safe_devices = {"cuda"} 219 | safe_devices_set = frozenset(safe_devices) 220 | logger.debug( 221 | f"Will warn about running memory utilization maximization on tensors on devices other than {safe_devices_set}", 222 | ) 223 | 224 | def check_tensors(*args: P.args, **kwargs: P.kwargs) -> None: 225 | """Check whether any tensor argument is on a dangerous device.""" 226 | device_types = {device.type for device in iter_tensor_devices(*args, **kwargs)} 227 | 228 | if not safe_devices_set.issuperset(device_types): 229 | logger.warning( 230 | f"Encountered tensors on {device_types=} while only {sorted(safe_devices_set)} are considered safe for " 231 | f"automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, " 232 | f"too).", 233 | ) 234 | 235 | return check_tensors 236 | 237 | 238 | def floor_to_nearest_multiple_of(x: int, q: int) -> int: 239 | """ 240 | Try to ensure that x is a multiple of q. 241 | 242 | :param x: 243 | the input value 244 | :param q: 245 | the desired base factor 246 | 247 | :return: 248 | x if x is smaller than q, otherwise, the largest multiple of q that is smaller than x 249 | """ 250 | if x <= q: 251 | return x 252 | # note: the brackets are for readability only 253 | return (x // q) * q 254 | 255 | 256 | def is_oom_error(error: BaseException) -> bool: 257 | """ 258 | Return whether the given exception is an out-of-memory (like) exception. 
259 | 260 | :param error: 261 | the error 262 | 263 | :return: 264 | whether it should be handled like an out-of-memory exception 265 | """ 266 | if isinstance(error, torch.cuda.OutOfMemoryError): 267 | return True 268 | if not isinstance(error, RuntimeError): 269 | return False 270 | message = str(error) 271 | return any(infix in message for infix in ADDITIONAL_OOM_ERROR_INFIXES) 272 | 273 | 274 | def maximize_memory_utilization_decorator( 275 | parameter_name: str | Sequence[str] = "batch_size", 276 | q: int | Sequence[int] = 32, 277 | safe_devices: Collection[str] | None = None, 278 | ) -> Callable[[Callable[P, R]], Callable[P, tuple[R, tuple[int, ...]]]]: 279 | """ 280 | Create decorators to create methods for memory utilization maximization. 281 | 282 | :param parameter_name: 283 | The parameter name. 284 | :param q: 285 | Prefer multiples of q as size. 286 | :param safe_devices: 287 | These devices are considered safe to run maximization on, cf. :meth:`create_tensor_checker`. 288 | 289 | :return: 290 | A decorator for functions. 291 | """ 292 | maybe_warn: Callable[..., None] = create_tensor_checker(safe_devices=safe_devices) 293 | parameter_names, qs = upgrade_to_sequence(parameter_name, q) 294 | 295 | def decorator_maximize_memory_utilization( 296 | func: Callable[P, R], 297 | ) -> Callable[P, tuple[R, tuple[int, ...]]]: 298 | """ 299 | Decorate a function to maximize memory utilization. 300 | 301 | :param func: 302 | The function to decorate. 303 | 304 | :return: 305 | The decorated function. 
306 | """ 307 | # Input validation, and extraction of default maximum values 308 | signature = inspect.signature(func) 309 | default_max_values = { 310 | name: determine_default_max_value(func=func, parameter_name=name, signature=signature) 311 | for name in parameter_names 312 | } 313 | 314 | @functools.wraps(func) 315 | def wrapper_maximize_memory_utilization(*args: P.args, **kwargs: P.kwargs) -> tuple[R, tuple[int, ...]]: 316 | """ 317 | Wrap a function to maximize memory utilization by successive halving. 318 | 319 | :param args: 320 | The positional arguments. 321 | :param kwargs: 322 | The key-word based arguments. 323 | 324 | :return: 325 | A tuple (result, max_value). 326 | 327 | :raises MemoryError: 328 | if the execution did not even succeed with the smallest parameter value 329 | :raises RuntimeError: 330 | if a runtime error which is unrelated to known OOM errors occurred 331 | """ 332 | maybe_warn(*args, **kwargs) 333 | bound_arguments = signature.bind(*args, **kwargs) 334 | bound_arguments.apply_defaults() 335 | # determine actual max values 336 | max_values = [ 337 | determine_max_value( 338 | bound_arguments, 339 | name, 340 | default_max_value, 341 | *args, 342 | **kwargs, 343 | ) 344 | for name, default_max_value in default_max_values.items() 345 | ] 346 | i = 0 347 | 348 | # store the last error, so we can have a nice traceback for further inspection 349 | last_error: BaseException | None = None 350 | 351 | while i < len(max_values): 352 | while max_values[i] > 0: 353 | p_kwargs = dict(zip(parameter_names, max_values)) 354 | # note: changes to arguments apply to both, .args and .kwargs 355 | bound_arguments.arguments.update(p_kwargs) 356 | try: 357 | return func(*bound_arguments.args, **bound_arguments.kwargs), tuple(max_values) 358 | except (torch.cuda.OutOfMemoryError, RuntimeError) as error: 359 | # raise errors unrelated to out-of-memory 360 | if not is_oom_error(error): 361 | raise error 362 | 363 | # clear cache 364 | if 
torch.cuda.is_available(): 365 | torch.cuda.empty_cache() 366 | # https://pytorch.org/docs/stable/notes/mps.html 367 | if torch.backends.mps.is_available(): 368 | # there is no torch.mps.is_available() 369 | torch.mps.empty_cache() 370 | 371 | # reduce parameter 372 | logger.info(f"Execution failed with {p_kwargs=}") 373 | max_values[i] = floor_to_nearest_multiple_of(x=max_values[i] // 2, q=qs[i]) 374 | 375 | # update last error 376 | last_error = error 377 | # we lowered the current parameter to 1, but still see memory issues; continue with the next in line... 378 | max_values[i] = 1 379 | i += 1 380 | # log memory summary for each CUDA device before raising memory error 381 | for device in {d for d in iter_tensor_devices(*args, **kwargs) if d.type == "cuda"}: 382 | logger.debug(f"Memory summary for {device=}:\n{torch.cuda.memory_summary(device=device)}") 383 | raise MemoryError(f"Execution did not even succeed with {parameter_names} all equal to 1.") from last_error 384 | 385 | return wrapper_maximize_memory_utilization 386 | 387 | return decorator_maximize_memory_utilization 388 | 389 | 390 | class KeyHasher: 391 | """A hasher based on (a subset of) keys.""" 392 | 393 | @staticmethod 394 | def normalize_keys(keys: Collection[str] | str | None) -> Collection[str]: 395 | """ 396 | Normalize keys to be a collection of strings. 397 | 398 | :param keys: 399 | the keys 400 | 401 | :return: 402 | - if keys is None, the empty list 403 | - if keys is a string, a singleton list 404 | - else the keys 405 | """ 406 | if keys is None: 407 | return [] 408 | if isinstance(keys, str): 409 | return [keys] 410 | return keys 411 | 412 | def __init__(self, keys: Collection[str] | str | None) -> None: 413 | """ 414 | Initialize the hasher. 
415 | 416 | :param keys: 417 | the keys whose associated values should be used for hashing 418 | """ 419 | self.keys = self.normalize_keys(keys) 420 | 421 | def __call__(self, kwargs: Mapping[str, Any]) -> int: 422 | """ 423 | Calculate the hash based on the values associated with the selected keys. 424 | 425 | :param kwargs: 426 | the key-value dictionary 427 | 428 | :return: 429 | the hash of the tuple of values associated with the stored keys. 430 | """ 431 | return hash(tuple(kwargs.get(key, None) for key in self.keys)) 432 | 433 | 434 | class MemoryUtilizationMaximizer: 435 | """Stateful memory utilization maximizer.""" 436 | 437 | def __init__( 438 | self, 439 | parameter_name: str | Sequence[str] = "batch_size", 440 | q: int | Sequence[int] = 32, 441 | safe_devices: Collection[str] | None = None, 442 | hasher: Callable[[Mapping[str, Any]], int] | None = None, 443 | keys: Collection[str] | str | None = None, 444 | ) -> None: 445 | """ 446 | Initialize the stateful maximizer. 447 | 448 | :param parameter_name: 449 | The parameter name. 450 | :param q: 451 | Prefer multiples of q as size. 452 | :param safe_devices: 453 | These devices are considered safe to run maximization on, cf. :meth:`create_tensor_checker`. 454 | :param hasher: 455 | a hashing function for separate parameter values depending on hash value; if None, use the same for all 456 | :param keys: 457 | the keys to use for creating a hasher. Only used if hasher is None. 
458 | """ 459 | self.parameter_names, self.qs = upgrade_to_sequence(parameter_name=parameter_name, q=q) 460 | self.safe_devices = safe_devices 461 | self.parameter_value: MutableMapping[int, tuple[int, ...]] = {} 462 | if hasher is None: 463 | keys = KeyHasher.normalize_keys(keys) 464 | intersection = set(keys).intersection(self.parameter_names) 465 | if intersection: 466 | logger.warning( 467 | f"{intersection=} are contained in the hashing keys *and* the parameter names; " 468 | f"likely you want to remove {intersection} from hashing keys.", 469 | ) 470 | hasher = KeyHasher(keys=keys) 471 | self.hasher = hasher 472 | 473 | def __call__(self, func: Callable[P, R]) -> Callable[P, R]: 474 | """Wrap the function.""" 475 | wrapped = maximize_memory_utilization_decorator( 476 | parameter_name=self.parameter_names, 477 | q=self.qs, 478 | safe_devices=self.safe_devices, 479 | )(func) 480 | signature = inspect.signature(func) 481 | 482 | @functools.wraps(wrapped) 483 | def inner(*args: P.args, **kwargs: P.kwargs) -> R: 484 | """Evaluate function with the stored parameter size.""" 485 | h = self.hasher(kwargs) 486 | if h in self.parameter_value: 487 | values = self.parameter_value[h] 488 | else: 489 | bound = signature.bind(*args, **kwargs) 490 | bound.apply_defaults() 491 | # todo: default logic? 
492 | values = tuple(bound.arguments[name] for name in self.parameter_names) 493 | kwargs.update(zip(self.parameter_names, values)) 494 | result, self.parameter_value[h] = wrapped(*args, **kwargs) 495 | return result 496 | 497 | return inner 498 | 499 | 500 | # alias 501 | maximize_memory_utilization = MemoryUtilizationMaximizer 502 | -------------------------------------------------------------------------------- /src/torch_max_mem/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mberr/torch-max-mem/fb9d91e5e5e9d566fe83410798fea8521098afb8/src/torch_max_mem/py.typed -------------------------------------------------------------------------------- /src/torch_max_mem/version.py: -------------------------------------------------------------------------------- 1 | """Version information for :mod:`torch_max_mem`. 2 | 3 | Run with ``python -m torch_max_mem.version`` 4 | """ 5 | 6 | import os 7 | from subprocess import CalledProcessError, check_output 8 | 9 | __all__ = [ 10 | "VERSION", 11 | "get_git_hash", 12 | "get_version", 13 | ] 14 | 15 | VERSION = "0.1.5-dev" 16 | 17 | 18 | def get_git_hash() -> str: 19 | """Get the :mod:`torch_max_mem` git hash.""" 20 | with open(os.devnull, "w") as devnull: 21 | try: 22 | ret = check_output( 23 | ["git", "rev-parse", "HEAD"], 24 | cwd=os.path.dirname(__file__), 25 | stderr=devnull, 26 | ) 27 | except CalledProcessError: 28 | return "UNHASHED" 29 | else: 30 | return ret.strip().decode("utf-8")[:8] 31 | 32 | 33 | def get_version(with_git_hash: bool = False) -> str: 34 | """Get the :mod:`torch_max_mem` version string, including a git hash.""" 35 | return f"{VERSION}-{get_git_hash()}" if with_git_hash else VERSION 36 | 37 | 38 | if __name__ == "__main__": 39 | print(get_version(with_git_hash=True)) # noqa:T201 40 | -------------------------------------------------------------------------------- /tests/__init__.py: 
-------------------------------------------------------------------------------- 1 | """Tests for :mod:`torch_max_mem`.""" 2 | -------------------------------------------------------------------------------- /tests/test_decorator.py: -------------------------------------------------------------------------------- 1 | """Tests.""" 2 | 3 | import unittest 4 | from typing import Any, Optional 5 | 6 | import pytest 7 | import torch 8 | 9 | from torch_max_mem import maximize_memory_utilization 10 | from torch_max_mem.api import floor_to_nearest_multiple_of, is_oom_error, maximize_memory_utilization_decorator 11 | 12 | 13 | def knn(x: torch.Tensor, y: torch.Tensor, batch_size: int, k: int = 3) -> torch.Tensor: 14 | """Compute k-nearst neigbors via batched brute-force distance calculation.""" 15 | return torch.cat( 16 | [ 17 | torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices 18 | for start in range(0, x.shape[0], batch_size) 19 | ], 20 | dim=0, 21 | ) 22 | 23 | 24 | wrapped_knn = maximize_memory_utilization_decorator(parameter_name="batch_size")(knn) 25 | wrapped_knn_stateful = maximize_memory_utilization()(knn) 26 | 27 | 28 | class TestDecorator(unittest.TestCase): 29 | """Test the decorator.""" 30 | 31 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 32 | 33 | @property 34 | def rng(self) -> torch.Generator: 35 | """Return the random number generator.""" 36 | return torch.Generator(device=self.device).manual_seed(42) 37 | 38 | def test_knn(self) -> None: 39 | """Test consistent results between original and wrapped method.""" 40 | x = torch.rand(100, 100, device=self.device, generator=self.rng) 41 | y = torch.rand(200, 100, device=self.device, generator=self.rng) 42 | for batch_size in [1, 10, x.shape[0]]: 43 | reference = knn(x, y, batch_size) 44 | optimized = wrapped_knn(x, y, batch_size=x.shape[0])[0] 45 | assert reference.shape == optimized.shape 46 | assert torch.allclose(reference, optimized) 
47 | 48 | def test_knn_stateful(self) -> None: 49 | """Test consistent results between original and wrapped method for stateful wrapper.""" 50 | x = torch.rand(100, 100, device=self.device, generator=self.rng) 51 | y = torch.rand(200, 100, device=self.device, generator=self.rng) 52 | for batch_size in [1, 10, x.shape[0]]: 53 | reference = knn(x, y, batch_size) 54 | optimized = wrapped_knn_stateful(x, y, batch_size=x.shape[0]) 55 | assert reference.shape == optimized.shape 56 | assert torch.allclose(reference, optimized) 57 | 58 | 59 | def test_parameter_types() -> None: 60 | """Test decoration for various parameter types.""" 61 | 62 | @maximize_memory_utilization() 63 | def positional_or_keyword_only_func(a: Any, batch_size: int) -> None: 64 | """Evaluate a function where batch_size is a positional or keyword parameter.""" 65 | 66 | @maximize_memory_utilization() 67 | def keyword_only_func(*a: Any, batch_size: int) -> None: 68 | """Evaluate a function where batch_size is a keyword-only parameter.""" 69 | 70 | 71 | @pytest.mark.parametrize("keys", [None, ("a",), ("a", "b", "c")]) 72 | def test_key_hasher(keys: Optional[tuple[str, ...]]) -> None: 73 | """Test ad-hoc hasher.""" 74 | 75 | def func(a: Any, b: Any, c: Any, batch_size: int) -> None: 76 | """Test function.""" 77 | pass 78 | 79 | wrapped = maximize_memory_utilization(keys=keys)(func) 80 | wrapped(a=1, b=3, c=7, batch_size=2) 81 | 82 | 83 | def test_default_no_arg() -> None: 84 | """Test decoration's interaction with default parameters.""" 85 | 86 | @maximize_memory_utilization() 87 | def func(batch_size: int = 7) -> None: 88 | """Test function.""" 89 | 90 | # call with no arg 91 | func() 92 | 93 | 94 | def test_optimization() -> None: 95 | """Test optimization.""" 96 | 97 | @maximize_memory_utilization() 98 | def func(batch_size: int = 8) -> int: 99 | """Test function.""" 100 | if batch_size > 2: 101 | raise torch.cuda.OutOfMemoryError 102 | return batch_size 103 | 104 | assert func() == 2 105 | 106 | 107 | 
def test_optimization_multi_level() -> None: 108 | """Test optimization with multiple levels.""" 109 | 110 | @maximize_memory_utilization(parameter_name=("batch_size", "slice_size")) 111 | def func(batch_size: int = 8, slice_size: int = 16) -> tuple[int, int]: 112 | """Test function.""" 113 | if batch_size > 1 or slice_size > 8: 114 | raise torch.cuda.OutOfMemoryError 115 | return batch_size, slice_size 116 | 117 | assert func() == (1, 8) 118 | 119 | 120 | @pytest.mark.parametrize(("x", "q"), [(15, 4), (3, 4)]) 121 | def test_floor_to_nearest_multiple_of(x: int, q: int) -> None: 122 | """Test floor_to_nearest_multiple_of.""" 123 | r = floor_to_nearest_multiple_of(x=x, q=q) 124 | # check type 125 | assert isinstance(r, int) 126 | # check flooring 127 | assert r <= x 128 | # check multiple of q if possible 129 | assert r < q or (r % q == 0) 130 | # check maximality 131 | assert r + q > x 132 | 133 | 134 | @pytest.mark.parametrize( 135 | ("error", "exp"), 136 | [ 137 | # base cases 138 | (NameError(), False), 139 | # CUDA 140 | (torch.cuda.OutOfMemoryError(), True), 141 | # MPS 142 | # cf. https://github.com/mberr/torch-max-mem/issues/14 143 | (RuntimeError("Invalid buffer size: 74.51 GB"), True), 144 | ( 145 | RuntimeError( 146 | "MPS backend out of memory (MPS allocated: 119.30 MB, other allocations: 43.18 GB, max allowed: " 147 | "36.27 GB). Tried to allocate 4.76 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 " 148 | "to disable upper limit for memory allocations (may cause system failure).", 149 | ), 150 | True, 151 | ), 152 | # cf. 
https://github.com/mberr/torch-max-mem/pull/15 153 | (RuntimeError("selected index k out of range"), False), 154 | ], 155 | ) 156 | def test_oom_error_detection(error: BaseException, exp: bool) -> None: 157 | """Test OOM error detection.""" 158 | assert is_oom_error(error) is exp 159 | 160 | 161 | @pytest.mark.slow 162 | def test_large_on_mps() -> None: 163 | """Test memory optimization on a large input.""" 164 | import torch.backends.mps 165 | 166 | if not torch.backends.mps.is_available(): 167 | pytest.skip("Cannot run on CPU") 168 | # note: this test currently cannot run on GHA, cf. 169 | # - https://discuss.pytorch.org/t/mps-back-end-out-of-memory-on-github-action/189773 170 | # - https://github.com/mberr/torch-max-mem/actions/runs/7820367693/job/21334908894 171 | pytest.skip( 172 | "temporarily disabled, cf. https://discuss.pytorch.org/t/mps-back-end-out-of-memory-on-github-action/189773" 173 | ) 174 | 175 | x = torch.rand(100000, 100, device="mps") 176 | y = torch.rand(200000, 100, device="mps") 177 | _result, (batch_size,) = wrapped_knn(x, y, batch_size=x.shape[0]) 178 | assert batch_size > 0 179 | 180 | 181 | @pytest.mark.slow 182 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support.") 183 | def test_large_on_cuda() -> None: 184 | """Test memory optimization on a large input.""" 185 | x = torch.rand(32_000, 100, device="cuda") 186 | y = torch.rand(200_000, 100, device="cuda") 187 | _result, (batch_size,) = wrapped_knn(x, y, batch_size=x.shape[0]) 188 | assert batch_size < x.shape[0], "test example was too small" 189 | assert batch_size > 0 190 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Tox (http://tox.testrun.org/) is a tool for running tests 2 | # in multiple virtualenvs. This configuration file will run the 3 | # test suite on all supported python versions. 
To use it, "pip install tox" 4 | # and then run "tox" from this directory. 5 | 6 | [tox] 7 | # To use a PEP 517 build-backend you are required to configure tox to use an isolated_build: 8 | # https://tox.readthedocs.io/en/latest/example/package.html 9 | isolated_build = True 10 | 11 | # These environments are run in order if you just use `tox`: 12 | envlist = 13 | # always keep coverage-clean first 14 | coverage-clean 15 | # code formatters 16 | format 17 | # format-docs 18 | # Code quality assessment 19 | pyroma 20 | lint 21 | lint-markdown 22 | mypy 23 | # Documentation quality assurance 24 | doc8 25 | docstr-coverage 26 | docs-test 27 | # the actual tests 28 | py 29 | doctests 30 | # always keep coverage-report last 31 | coverage-report 32 | 33 | [testenv] 34 | description = Run unit and integration tests. 35 | # Runs on the "tests" directory by default, or passes the positional 36 | # arguments from `tox -e py ... 37 | commands = 38 | coverage run -p -m pytest --durations=20 {posargs:tests} 39 | coverage combine 40 | coverage xml 41 | # See the [dependency-groups] entry in pyproject.toml for "tests" 42 | dependency_groups = 43 | tests 44 | 45 | [testenv:coverage-clean] 46 | description = Remove testing coverage artifacts. 47 | deps = coverage[toml] 48 | skip_install = true 49 | commands = coverage erase 50 | 51 | [testenv:doctests] 52 | description = Test that documentation examples run properly. 53 | commands = 54 | # note that the package name is required for discovery 55 | xdoctest -m src/torch_max_mem 56 | dependency_groups = 57 | doctests 58 | 59 | [testenv:treon] 60 | description = Test that notebooks can run to completion 61 | commands = 62 | treon notebooks/ 63 | deps = 64 | treon 65 | 66 | [testenv:format] 67 | description = Format the code in a deterministic way using ruff. 
Note that ruff check should come before ruff format when using --fix (ref: https://github.com/astral-sh/ruff-pre-commit/blob/main/README.md) 68 | dependency_groups = 69 | lint 70 | skip_install = true 71 | commands = 72 | ruff check --fix 73 | ruff format 74 | 75 | [testenv:format-docs] 76 | description = Run documentation linters. 77 | # note that this doesn't work with sphinx-click 78 | # or any other extension that adds extra directives 79 | # See the [dependency-groups] entry in pyproject.toml for "rstfmt" 80 | dependency_groups = 81 | format-docs 82 | skip_install = true 83 | commands = 84 | docstrfmt src/ tests/ docs/ --no-docstring-trailing-line 85 | 86 | [testenv:format-markdown] 87 | description = Run markdown formatter. 88 | skip_install = true 89 | allowlist_externals = 90 | npx 91 | commands = 92 | npx --yes prettier --write --prose-wrap always "**/*.md" 93 | 94 | [testenv:lint] 95 | description = Check code quality using ruff and other tools. 96 | skip_install = true 97 | dependency_groups = 98 | lint 99 | commands = 100 | ruff check 101 | ruff format --check 102 | 103 | [testenv:lint-markdown] 104 | description = Check markdown is properly formatted. 105 | # inspired by https://github.com/astral-sh/uv/blob/98523e2014e9a5c69706623344026d76296e178f/.github/workflows/ci.yml#L67C1-L70C61 106 | skip_install = true 107 | allowlist_externals = 108 | npx 109 | commands = 110 | npx --yes prettier --check --prose-wrap always "**/*.md" 111 | 112 | [testenv:pyroma] 113 | dependency_groups = 114 | pyroma 115 | skip_install = true 116 | commands = pyroma --min=10 . 117 | description = Run the pyroma tool to check the package friendliness of the project. 118 | 119 | [testenv:mypy] 120 | description = Run the mypy tool to check static typing on the project. Installs the package to make sure all type stubs get recognized. 
121 | dependency_groups = 122 | typing 123 | commands = mypy --ignore-missing-imports --strict src/ tests/ 124 | 125 | [testenv:docs-lint] 126 | skip_install = true 127 | dependency_groups = 128 | docs-lint 129 | commands = 130 | doc8 docs/source/ 131 | description = Run the doc8 tool to check the style of the RST files in the project docs. 132 | 133 | [testenv:docstr-coverage] 134 | description = Run the docstr-coverage tool to check documentation coverage. 135 | skip_install = true 136 | deps = 137 | docstr-coverage 138 | commands = 139 | docstr-coverage src/ tests/ --skip-private --skip-magic 140 | 141 | [testenv:docs] 142 | description = Build the documentation locally, allowing warnings. 143 | dependency_groups = 144 | # See the [dependency-groups] entry in pyproject.toml for "docs" 145 | docs 146 | # You might need to add additional extras if your documentation covers it 147 | commands = 148 | python -m sphinx -b html -d docs/build/doctrees docs/source docs/build/html 149 | 150 | [testenv:docs-test] 151 | description = Test building the documentation in an isolated environment. Warnings are considered as errors via -W. 
152 | changedir = docs 153 | dependency_groups = 154 | {[testenv:docs]dependency_groups} 155 | commands = 156 | mkdir -p {envtmpdir} 157 | cp -r source {envtmpdir}/source 158 | python -m sphinx -W -b html -d {envtmpdir}/build/doctrees {envtmpdir}/source {envtmpdir}/build/html 159 | # python -m sphinx -W -b coverage -d {envtmpdir}/build/doctrees {envtmpdir}/source {envtmpdir}/build/coverage 160 | # cat {envtmpdir}/build/coverage/c.txt 161 | # cat {envtmpdir}/build/coverage/python.txt 162 | allowlist_externals = 163 | cp 164 | cat 165 | mkdir 166 | 167 | [testenv:coverage-report] 168 | deps = coverage[toml] 169 | skip_install = true 170 | commands = 171 | coverage report 172 | 173 | #################### 174 | # Deployment tools # 175 | #################### 176 | 177 | [testenv:bumpversion] 178 | description = Bump the version number 179 | commands = bump-my-version bump {posargs} 180 | skip_install = true 181 | passenv = HOME 182 | dependency_groups = 183 | bump 184 | 185 | [testenv:bumpversion-release] 186 | description = Remove the -dev tag from the version 187 | commands = bump-my-version bump release --tag 188 | skip_install = true 189 | passenv = HOME 190 | dependency_groups = 191 | bump 192 | 193 | [testenv:build] 194 | skip_install = true 195 | dependency_groups = 196 | build 197 | commands = 198 | uv build --sdist --wheel --no-build-isolation 199 | 200 | ############ 201 | # Releases # 202 | ############ 203 | 204 | # In order to make a release to PyPI, you'll need to take the following steps: 205 | # 206 | # 1. Navigate to https://pypi.org/account/register/ to register for Test PyPI 207 | # 2. Navigate to https://pypi.org/manage/account/ and request to re-send a verification email. 208 | # This is not sent by default, and is required to set up 2-Factor Authentication. 209 | # 3. Get account recovery codes 210 | # 4. Set up 2-Factor Authentication 211 | # 5. Get an API token from https://pypi.org/manage/account/token/ 212 | # 6. 
Install keyring with `uv tool install keyring` 213 | # 7. Add your token to keyring with `keyring set https://upload.pypi.org/legacy/ __token__` 214 | 215 | [testenv:release] 216 | description = Release the code to PyPI so users can pip install it, using credentials from keyring 217 | skip_install = true 218 | dependency_groups = 219 | release 220 | commands = 221 | {[testenv:build]commands} 222 | uv publish --username __token__ --keyring-provider subprocess --publish-url https://upload.pypi.org/legacy/ 223 | 224 | [testenv:release-via-env] 225 | description = Release the code to PyPI so users can pip install it, using credentials from the environment. 226 | skip_install = true 227 | dependency_groups = 228 | {[testenv:build]dependency_groups} 229 | release 230 | commands = 231 | {[testenv:build]commands} 232 | uv publish --publish-url https://upload.pypi.org/legacy/ 233 | passenv = 234 | UV_PUBLISH_USERNAME 235 | UV_PUBLISH_PASSWORD 236 | 237 | [testenv:finish] 238 | description = 239 | Run a workflow that removes -dev from the version, creates a tagged release on GitHub, 240 | creates a release on PyPI, and bumps the version again. 241 | skip_install = true 242 | passenv = 243 | HOME 244 | dependency_groups = 245 | bump 246 | release 247 | commands = 248 | {[testenv:bumpversion-release]commands} 249 | {[testenv:release]commands} 250 | git push --tags 251 | bump-my-version bump patch 252 | git push 253 | allowlist_externals = 254 | git 255 | 256 | ################# 257 | # Test Releases # 258 | ################# 259 | 260 | # In order to test making a release to Test PyPI, you'll need to take the following steps: 261 | # 262 | # 1. Navigate to https://test.pypi.org/account/register/ to register for Test PyPI 263 | # 2. Navigate to https://test.pypi.org/manage/account/ and request to re-send a verification email. 264 | # This is not sent by default, and is required to set up 2-Factor Authentication. 265 | # 3. Get account recovery codes 266 | # 4. 
Set up 2-Factor Authentication 267 | # 5. Get an API token from https://test.pypi.org/manage/account/token/ 268 | # 6. Install keyring with `uv tool install keyring` 269 | # 7. Add your token to keyring with `keyring set https://test.pypi.org/legacy/ __token__` 270 | 271 | [testenv:testrelease] 272 | description = Release the code to the test PyPI site 273 | skip_install = true 274 | dependency_groups = 275 | release 276 | commands = 277 | {[testenv:build]commands} 278 | uv publish --username __token__ --keyring-provider subprocess --publish-url https://test.pypi.org/legacy/ 279 | 280 | [testenv:testfinish] 281 | description = 282 | Run a workflow that removes -dev from the version, creates a tagged release on GitHub, 283 | creates a release on Test PyPI, and bumps the version again. 284 | skip_install = true 285 | passenv = 286 | HOME 287 | dependency_groups = 288 | {[testenv:testrelease]dependency_groups} 289 | bump 290 | release 291 | commands = 292 | {[testenv:bumpversion-release]commands} 293 | {[testenv:testrelease]commands} 294 | git push --tags 295 | bump-my-version bump patch 296 | git push 297 | allowlist_externals = 298 | git 299 | --------------------------------------------------------------------------------