├── .github └── workflows │ ├── build.yaml │ ├── format.yaml │ ├── release.yaml │ ├── typecheck.yaml │ └── unittest.yaml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE.md ├── README.md ├── assets ├── koila.png └── koila.svg ├── docs ├── CODE_OF_CONDUCT.md ├── LICENSE.md ├── _config.yml ├── _toc.yml ├── assets └── index.md ├── pyproject.toml ├── src └── koila │ ├── __init__.py │ ├── constants.py │ ├── eager.py │ ├── errors.py │ ├── gpus.py │ ├── interfaces.py │ ├── lazy.py │ ├── prepasses.py │ └── shapes.py └── tests ├── __init__.py ├── common.py ├── test_layers.py ├── test_lazy.py ├── test_models.py └── test_prepasses.py /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Build Pages 2 | on: [push] 3 | jobs: 4 | build-and-deploy: 5 | name: 📃 Website Build 6 | runs-on: ubuntu-latest 7 | steps: 8 | - name: 🔔 Check out 9 | uses: actions/checkout@v3 10 | 11 | - name: 🏗️ python 12 | uses: actions/setup-python@v4 13 | with: 14 | python-version: "3.10" 15 | 16 | - name: ⬇️ Python PDM 17 | uses: pdm-project/setup-pdm@v4 18 | with: 19 | python-version: "3.10" 20 | 21 | - name: ⬇️ Python Dependencies 22 | run: pdm install -G:all 23 | 24 | - name: 🚂 Activate environment 25 | run: echo "$(pdm venv --path in-project)/bin" >> $GITHUB_PATH 26 | 27 | - name: 🚧 Jupyter build 28 | run: jupyter book build docs 29 | 30 | - name: 📰 Publish docs 31 | uses: JamesIves/github-pages-deploy-action@v4 32 | with: 33 | branch: gh-pages 34 | folder: ./docs/_build/html 35 | git-config-name: "github-actions[bot]" 36 | git-config-email: "github-actions[bot]@users.noreply.github.com" 37 | commit-message: 🎉 Book deployed 38 | -------------------------------------------------------------------------------- /.github/workflows/format.yaml: -------------------------------------------------------------------------------- 1 | name: Formatting 2 | on: [push] 3 | jobs: 4 | format-all: 5 | name: 📀 Formatting 6 | runs-on: ubuntu-latest 7 | steps: 8 | - name: 🔔 Check out 9 | uses: actions/checkout@v3 10 | 11 | - name: 🏗️ python 12 | uses: actions/setup-python@v4 13 | with: 14 | python-version: "3.10" 15 | 16 | - name: ⬇️ Python PDM 17 | uses: pdm-project/setup-pdm@v4 18 | 19 | - name: ⬇️ Python Dependencies 20 | run: pdm install -G:all 21 | 22 | - name: 🚂 Activate environment 23 | run: echo "$(pdm venv --path in-project)/bin" >> $GITHUB_PATH 24 | 25 | - name: 🏃 autoflake, isort, black 26 | run: | 27 | autoflake -cr $(find -iname "*.py" ! -path '*/.venv/*' ! -name __init__.py) --remove-all-unused-imports 28 | isort --profile black --check . 29 | black --check . 
30 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | jobs: 9 | pypi-publish: 10 | name: ⬆️ Upload release to PyPI 11 | runs-on: ubuntu-latest 12 | permissions: 13 | contents: read 14 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 15 | 16 | steps: 17 | - name: 🔔 Check out 18 | uses: actions/checkout@v3 19 | 20 | - name: 🏗️ python 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: "3.13" 24 | 25 | - name: ⬇️ Python PDM 26 | uses: pdm-project/setup-pdm@v4 27 | with: 28 | cache: true 29 | 30 | - name: ⬇️ Python Dependencies 31 | run: pdm sync -G:all 32 | 33 | - name: 📰 Publish to PyPI 34 | run: pdm publish 35 | -------------------------------------------------------------------------------- /.github/workflows/typecheck.yaml: -------------------------------------------------------------------------------- 1 | name: Type Checking 2 | on: [push] 3 | jobs: 4 | type-check: 5 | name: 👨‍⚕️ Type Checking 6 | runs-on: ubuntu-latest 7 | steps: 8 | - name: 🔔 Check out 9 | uses: actions/checkout@v3 10 | 11 | - name: 🏗️ python 12 | uses: actions/setup-python@v4 13 | with: 14 | python-version: "3.10" 15 | 16 | - name: ⬇️ Python PDM 17 | uses: pdm-project/setup-pdm@v4 18 | 19 | - name: ⬇️ Python Dependencies 20 | run: pdm install -G:all 21 | 22 | - name: 🚂 Activate environment 23 | run: echo "$(pdm venv --path in-project)/bin" >> $GITHUB_PATH 24 | 25 | - name: 🏃 mypy 26 | run: mypy . --disable-error-code=import-untyped --disable-error-code=import-not-found 27 | -------------------------------------------------------------------------------- /.github/workflows/unittest.yaml: -------------------------------------------------------------------------------- 1 | name: Unit Testing 2 | on: [push] 3 | jobs: 4 | unit-test: 5 | name: 🧪 Unit Testing 6 | runs-on: ubuntu-latest 7 | steps: 8 | - name: 🔔 Check out 9 | uses: actions/checkout@v3 10 | 11 | - name: 🏗️ python 12 | uses: actions/setup-python@v4 13 | with: 14 | python-version: "3.10" 15 | 16 | - name: ⬇️ Python PDM 17 | uses: pdm-project/setup-pdm@v4 18 | 19 | - name: ⬇️ Python Dependencies 20 | run: pdm install -G:all 21 | 22 | - name: 🚂 Activate environment 23 | run: echo "$(pdm venv --path in-project)/bin" >> $GITHUB_PATH 24 | 25 | - name: 🏃 pytest 26 | run: pytest -xv 27 | 28 | # - name: 🏃 pytest 29 | # run: coverage run -m pytest -v 30 | 31 | # - name: 📊 coverage 32 | # run: coverage report -m 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # General gitignore 2 | .DS_Store 3 | .vscode/ 4 | 5 | # Python gitignore 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | .pdm-python 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
166 | #.idea/ 167 | 168 | 169 | _build/ 170 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | - Demonstrating empathy and kindness toward other people 21 | - Being respectful of differing opinions, viewpoints, and experiences 22 | - Giving and gracefully accepting constructive feedback 23 | - Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | - Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | - The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | - Trolling, insulting or derogatory comments, and personal or political attacks 33 | - Public or private harassment 34 | - Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | - Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | [INSERT CONTACT METHOD]. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident.
68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][mozilla coc]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][faq]. Translations are available 126 | at [https://www.contributor-covenant.org/translations][translations]. 
127 | 128 | [homepage]: https://www.contributor-covenant.org 129 | [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html 130 | [mozilla coc]: https://github.com/mozilla/diversity 131 | [faq]: https://www.contributor-covenant.org/faq 132 | [translations]: https://www.contributor-covenant.org/translations 133 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024 RenChu Wang 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🐨 Koila 2 | 3 | > Koila solves the `CUDA error: out of memory` error painlessly. 4 | > Fix it with just one line of code, and forget it. 5 | 6 | [![Unit Testing](https://github.com/rentruewang/koila/actions/workflows/unittest.yaml/badge.svg)](https://github.com/rentruewang/koila/actions/workflows/unittest.yaml) 7 | [![Type Checking](https://github.com/rentruewang/koila/actions/workflows/typecheck.yaml/badge.svg)](https://github.com/rentruewang/koila/actions/workflows/typecheck.yaml) 8 | [![Formatting](https://github.com/rentruewang/koila/actions/workflows/format.yaml/badge.svg)](https://github.com/rentruewang/koila/actions/workflows/format.yaml) 9 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 10 | [![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/intent/tweet?text=Never%20worry%20about%20out%20of%20memory%20errors%20again&url=https://github.com/rentruewang/koila&hashtags=pytorch,outofmemory) 11 | 12 | ![Koila](./assets/koila.png) 13 | 14 | ## 🚨 Warning 15 | 16 | **Main branch is a complete re-structure of the project (and is currently mostly empty because I have not had enough time to complete it). To see working code, check out the `v0.1.1` tag for a proof of concept (which doesn't have full support for all operations and is not suited for production). To use it, download [release v0.1.1 here](https://github.com/rentruewang/koila/releases/tag/v0.1.1).** 17 | 18 | ## 🚀 Features 19 | 20 | - 🙅 Prevents `CUDA error: out of memory` errors with a single line of code. 21 | 22 | - ⚗️ Automatically accumulates gradients when batch sizes are too large. 23 | 24 | - 🦥 Lazily evaluates PyTorch code to save computing power. 25 | 26 | - ✂️ Automatically splits along the batch dimension into more GPU-friendly sizes (powers of 2) to speed up execution (see the sketch below). 27 | 28 | - 🤏 Minimal API (wrapping the inputs is enough).
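To illustrate the power-of-2 splitting strategy, here is a rough standalone sketch of the decomposition performed by `split_batch` in `src/koila/gpus.py` (shown later in this listing); `split_into_powers_of_two` is an illustrative name, not part of Koila's API:

```python
import math


def split_into_powers_of_two(batch: int, max_batch: int) -> list[int]:
    """Split `batch` into power-of-2 chunks, none larger than `max_batch`."""
    # Largest power of 2 that still fits.
    chunk = 2 ** math.floor(math.log2(max_batch))
    (times, rest) = divmod(batch, chunk)
    sizes = [chunk] * times
    # Cover the remainder with progressively smaller powers of 2.
    while rest > 0:
        chunk //= 2
        if rest >= chunk:
            rest -= chunk
            sizes.append(chunk)
    return sizes


print(split_into_powers_of_two(100, 48))  # [32, 32, 32, 4]
```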
29 | 30 | ## 🤔 Why Koila? 31 | 32 | Ever encountered `RuntimeError: CUDA error: out of memory`? 33 | We all love `PyTorch` because of its speed, efficiency, and transparency, but that also means it doesn't do extra things, like preventing a very common error that has been bothering many users since [2017](https://github.com/pytorch/pytorch/issues/958#issuecomment-285090162). 34 | 35 | This library aims to prevent that by being a lightweight wrapper over native `PyTorch`. When a tensor is wrapped, the library **automatically computes the amount of remaining GPU memory and uses the right batch size**, saving everyone from having to manually fine-tune the batch size whenever a model is used. 36 | 37 | Also, the library automatically matches the batch size to your GPU. Did you know that using bigger batches doesn't always speed up processing? That is handled automatically in this library too. 38 | 39 | Because `Koila` code is `PyTorch` code, as it runs `PyTorch` under the hood, you can use both together without worrying about compatibility. 40 | 41 | Oh, and all that in 1 line of code! 😊 42 | 43 | ## ⬇️ Installation 44 | 45 | `Koila` is available on [PyPI](https://pypi.org/project/koila/). To install, run the following command. 46 | 47 | ```bash 48 | pip install koila 49 | ``` 50 | 51 | ## 🏃 Getting started 52 | 53 | The usage is dead simple. For example, say you have the following `PyTorch` code (copied from `PyTorch`'s [tutorial](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html)). 54 | 55 | Define the input, label, and model: 56 | 57 | ```python 58 | # A batch of MNIST images 59 | input = torch.randn(8, 28, 28) 60 | 61 | # A batch of labels 62 | label = torch.randint(0, 10, [8]) 63 | 64 | class NeuralNetwork(Module): 65 | def __init__(self): 66 | super(NeuralNetwork, self).__init__() 67 | self.flatten = Flatten() 68 | self.linear_relu_stack = Sequential( 69 | Linear(28 * 28, 512), 70 | ReLU(), 71 | Linear(512, 512), 72 | ReLU(), 73 | Linear(512, 10), 74 | ) 75 | 76 | def forward(self, x): 77 | x = self.flatten(x) 78 | logits = self.linear_relu_stack(x) 79 | return logits 80 | ``` 81 | 82 | Define the loss function, then calculate the output and losses. 83 | 84 | ```python 85 | loss_fn = CrossEntropyLoss() 86 | nn = NeuralNetwork() 87 | # Calculate losses 88 | out = nn(input) 89 | loss = loss_fn(out, label) 90 | 91 | # Backward pass 92 | nn.zero_grad() 93 | loss.backward() 94 | ``` 95 | 96 | OK. How do you adapt the code to use `Koila`'s features? 97 | 98 | You add this line of code (as of v0.1.1): 99 | 100 | ```python 101 | # Wrap the input tensor and label tensor. 102 | # If a batch argument is provided, that dimension of the tensor is treated as the batch. 103 | # In this case, the first dimension (dim=0) is used as the batch dimension. 104 | (input, label) = lazy(input, label, batch=0) 105 | ``` 106 | 107 | Done. You will not run out of memory again.
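Putting the pieces together, here is a minimal end-to-end sketch (assuming the v0.1.1 API; `NeuralNetwork` is the class defined in the snippet above, with the `torch.nn` names imported):

```python
import torch
from torch.nn import CrossEntropyLoss
from koila import lazy

nn = NeuralNetwork()
loss_fn = CrossEntropyLoss()

input = torch.randn(8, 28, 28)
label = torch.randint(0, 10, [8])

# The one extra line: treat dim 0 as the batch dimension.
(input, label) = lazy(input, label, batch=0)

out = nn(input)
loss = loss_fn(out, label)

nn.zero_grad()
loss.backward()  # Koila runs this in mini batches that fit in GPU memory.
```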
108 | 109 | ## 🏋️ How does it work under the hood? 110 | 111 | `CUDA error: out of memory` generally happens in the forward pass, because temporary variables need to be kept in memory. 112 | 113 | `Koila` is a thin wrapper around `PyTorch`. It is inspired by TensorFlow's static/lazy evaluation. By building the graph first and running the model only when necessary, the library has access to all the information it needs to determine how much memory the computation really requires. 114 | 115 | In terms of memory usage, only **the shapes of temporary variables are required to calculate their memory usage in the model**. For example, `+` takes in two tensors with equal sizes, and outputs a tensor with a size equal to the input size, and `log` takes in one tensor, and outputs another tensor with the same shape. Broadcasting makes it a little more complicated than that, but the general idea is the same. By tracking all these shapes, one can easily tell how much memory is used in a forward pass, and select the optimal batch size accordingly.
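As a back-of-the-envelope illustration of shape-only accounting (a simplification for exposition, not `Koila`'s exact bookkeeping; compare `MEMORY_BYTES` in `src/koila/constants.py`):

```python
import torch


def tensor_bytes(shape: tuple[int, ...], dtype: torch.dtype = torch.float32) -> int:
    """Memory a tensor of this shape and dtype would occupy, from metadata alone."""
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * torch.empty((), dtype=dtype).element_size()


# An (8, 512) float32 activation costs 8 * 512 * 4 bytes = 16 KiB,
# known without ever materializing the tensor.
print(tensor_bytes((8, 512)))  # 16384
```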
116 | 117 | ## 🐌 It sounds slow. Is it? 118 | 119 | **NO**. Indeed, calculating shapes and computing sizes and memory usage sounds like a lot of work. However, keep in mind that even a gigantic model like GPT-3, which has 96 layers, has only a few hundred nodes in its computing graph. Because `Koila`'s algorithms run in linear time, any modern computer will be able to handle a graph like this instantly. 120 | 121 | Most of the computing is spent on computing individual tensors and transferring tensors across devices. And bear in mind that those checks happen in vanilla `PyTorch` anyway. So no, not slow at all. 122 | 123 | ## 🔊 How to pronounce koila? 124 | 125 | This project was originally named _koala_, after the laziest species in the world, because this project is about lazy evaluation of tensors. However, as that name is taken on [PyPI](https://pypi.org/project/koala/), I had no choice but to use another name. `Koila` is a word made up by me, pronounced similarly to _voila_ (a French word), so it sounds like _koala_. 126 | 127 | ## ⭐ Give me a star! 128 | 129 | If you like what you see, please consider giving this a star (★)! 130 | 131 | ## 🏗️ Why did I build this, despite similar libraries? 132 | 133 | Why did I go through the trouble of building this project, despite a lot of similar libraries on the internet? 134 | 135 | ### 🔎 Batch size search 136 | 137 | Batch size search is not new. In fact, the mighty popular [Lightning](https://lightning.ai/) has it. 138 | 139 | Lightning's batch size search is deeply integrated into its own ecosystem. You have to use its `DataLoader`, subclass from its models, and train your models accordingly. While refactoring supervised learning tasks to use Lightning is relatively easy, it's really painful to do the same with a reinforcement learning code base, where interacting with the environment is a must. 140 | 141 | In comparison, because `Koila` is a super lightweight PyTorch wrapper, it works whenever PyTorch works, thus providing maximum flexibility and requiring minimal changes to existing code. 142 | 143 | However, note that if you're writing new code, Lightning is recommended, as it enforces a better pattern of code style, which benefits modularity in the long run. 144 | 145 | ### ♏ Symbolic pre-passing 146 | 147 | Likewise, passing an empty tensor through a network to build a computational graph (AKA a **static graph**) isn't a new idea; it is thoroughly explored in the popular [TensorFlow](https://www.tensorflow.org/) library and in [KeOps](https://www.kernel-operations.io/), a similar `PyTorch` wrapper library. These libraries suffer from the fact that debugging programs written in them is unnecessarily complicated. For example, `TensorFlow` was known for being easy to deploy but painful to develop with, to the point that users switched to `PyTorch`. During debugging, people like to see what's _inside_ a variable, to check whether it holds an incorrect value. However, because static graphs only define relations, the values are never computed, which makes debugging difficult. 148 | 149 | `Koila` solves that by evaluating eagerly whenever a tensor is converted to a string, an integer, or any other Python value. This enables seamless debugging while retaining the ability to perform memory management that simply isn't available to a more straightforward `PyTorch` program, which allocates and frees memory on the fly (when needed). 150 | 151 | ## 📝 Todos 152 | 153 | - 😌 Simplify internal workings even further (especially the interaction between `Tensor`s and `LazyTensor`s). 154 | - 🧩 Provide an extensible API for users to write custom functions. 155 | - 🍪 Work with multiple GPUs. 156 | 157 | ## 🚧 Caution 158 | 159 | The code works in many cases, but it's still a work in progress. This is not (yet) a fully `PyTorch`-compatible library due to limited time. Avoid using it in production environments! 160 | 161 | ## 🥰 Contributing and Using 162 | 163 | Openness and inclusiveness are taken very seriously. The code is available under the [MIT License](./LICENSE.md). Please follow the [Code of Conduct](./CODE_OF_CONDUCT.md). 164 | -------------------------------------------------------------------------------- /assets/koila.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rentruewang/koila/4eeb42e971538142ce065ab627d191a6d2547573/assets/koila.png -------------------------------------------------------------------------------- /assets/koila.svg: -------------------------------------------------------------------------------- (SVG markup omitted from this dump: the Koila logo in vector form) -------------------------------------------------------------------------------- /docs/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ```{include} ../CODE_OF_CONDUCT.md 2 | ``` 3 | -------------------------------------------------------------------------------- /docs/LICENSE.md: -------------------------------------------------------------------------------- 1 | # License 2 | 3 | ```{include} ../LICENSE.md 4 | ``` 5 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | title: Koila 2 | author: RenChu Wang 3 | copyright: RenChu Wang, 2024 4 | logo: assets/koila.png 5 | 6 | exclude_patterns: [_build] 7 | only_build_toc_files: true 8 | 9 | repository: 10 | url: https://github.com/rentruewang/koila 11 | 12 | html: 13 | use_repository_button: true 14 | 15 | execute: 16 | execute_notebooks: force 17 | -------------------------------------------------------------------------------- /docs/_toc.yml: -------------------------------------------------------------------------------- 1 | format: jb-article 2 | root: index 3 | sections: 4 | - file: LICENSE 5 | - file: CODE_OF_CONDUCT 6 | -------------------------------------------------------------------------------- /docs/assets: -------------------------------------------------------------------------------- 1 | ../assets/
-------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ```{include} ../README.md 2 | ``` 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "koila" 3 | description = "Prevent PyTorch's `CUDA error out of memory` in a few lines of code" 4 | authors = [ 5 | {name = "RenChu Wang", email = "patrick1031wang@gmail.com"}, 6 | ] 7 | dependencies = [ 8 | "numpy>=1.26.3", 9 | "scipy>=1.11.4", 10 | "torch>=2.1.2", 11 | "black>=24.4.2", 12 | ] 13 | requires-python = ">=3.10" 14 | readme = "README.md" 15 | license = {text = "Apache-2.0"} 16 | dynamic = ["version"] 17 | 18 | [build-system] 19 | requires = ["setuptools", "wheel", "setuptools-scm"] 20 | build-backend = "setuptools.build_meta" 21 | 22 | [tool.setuptools_scm] 23 | 24 | [tool.pdm] 25 | distribution = true 26 | 27 | [tool.pdm.dev-dependencies] 28 | test = [ 29 | "coverage>=7.4.0", 30 | "pytest>=7.4.4", 31 | "pytest-cov>=4.1.0", 32 | "pytest-xdist>=3.5.0", 33 | ] 34 | format = [ 35 | "autoflake>=2.2.1", 36 | "black>=23.12.1", 37 | "isort>=5.13.2", 38 | ] 39 | website = [ 40 | "jupyter>=1.1.1", 41 | "jupyter-book>=1.0.3", 42 | "myst-parser>=2.0.0", 43 | ] 44 | type = [ 45 | "mypy>=1.8.0", 46 | ] 47 | -------------------------------------------------------------------------------- /src/koila/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | from . import constants, gpus 4 | from .eager import EagerTensor 5 | from .errors import UnsupportedError 6 | from .interfaces import ( 7 | BatchedPair, 8 | BatchInfo, 9 | Runnable, 10 | RunnableTensor, 11 | TensorMixin, 12 | run, 13 | ) 14 | from .lazy import Evaluation, LazyFunction, LazyTensor, lazy 15 | from .prepasses import CallBack, MetaData, PrePass, PrePassFunc 16 | -------------------------------------------------------------------------------- /src/koila/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | from typing import Dict 4 | 5 | import torch 6 | from torch import dtype 7 | 8 | UNITS: Dict[str, int] = { 9 | "b": 1, 10 | "kb": 10**3, 11 | "kib": 2**10, 12 | "mb": 10**6, 13 | "mib": 2**20, 14 | "gb": 10**9, 15 | "gib": 2**30, 16 | "tb": 10**12, 17 | "tib": 2**40, 18 | } 19 | 20 | MEMORY_BYTES: Dict[dtype, int] = { 21 | torch.bool: 1, 22 | torch.uint8: 1, 23 | torch.int8: 1, 24 | torch.short: 2, 25 | torch.int16: 2, 26 | torch.int: 4, 27 | torch.int32: 4, 28 | torch.long: 8, 29 | torch.int64: 8, 30 | torch.half: 2, 31 | torch.float16: 2, 32 | torch.float: 4, 33 | torch.float32: 4, 34 | torch.double: 8, 35 | torch.float64: 8, 36 | } 37 | -------------------------------------------------------------------------------- /src/koila/eager.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | from __future__ import annotations 4 | 5 | import logging 6 | from typing import Any, Callable, Dict, Sequence, Tuple, Type 7 | 8 | from rich.logging import RichHandler 9 | from torch import Tensor 10 | from torch import device as Device 11 | from torch import dtype as DType 12 | 13 | from .interfaces import BatchInfo, RunnableTensor,
TensorLike 14 | 15 | LOGGER = logging.getLogger(__name__) 16 | LOGGER.addHandler(RichHandler()) 17 | LOGGER.setLevel(logging.DEBUG) 18 | 19 | # So, it seems that torch's Tensor base class utilizes metaclass 20 | # to pretend to be a parent of LongTensor, FloatTensor etc. 21 | # Perhaps I'll be using the same paradigm. 22 | 23 | 24 | class EagerTensor(RunnableTensor): 25 | def __init__(self, data: Tensor) -> None: 26 | self.data = data 27 | 28 | def __getattr__(self, name: str) -> Any: 29 | return getattr(self.data, name) 30 | 31 | def batch(self) -> BatchInfo | None: 32 | raise NotImplementedError 33 | 34 | def run(self, partial: Tuple[int, int] | None = None) -> Tensor: 35 | del partial 36 | return self.data 37 | 38 | def visit(self, nodes: Dict[int, TensorLike]) -> None: 39 | raise NotImplementedError 40 | 41 | def device(self) -> str | Device: 42 | raise NotImplementedError 43 | 44 | def dtype(self) -> DType: 45 | raise NotImplementedError 46 | 47 | def size(self) -> Tuple[int, ...]: 48 | return self.data.size() 49 | 50 | @classmethod 51 | def __torch_function__( 52 | cls, 53 | func: Callable[..., Tensor], 54 | types: Tuple[Type[Any], ...], 55 | args: Sequence[TensorLike] = (), 56 | kwargs: Dict[str, TensorLike] | None = None, 57 | ) -> TensorLike: 58 | if kwargs is None: 59 | kwargs = {} 60 | 61 | if not all(issubclass(typ, (Tensor, EagerTensor)) for typ in types): 62 | return NotImplemented 63 | 64 | return EagerTensor(func(*args, **kwargs)) 65 | -------------------------------------------------------------------------------- /src/koila/errors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | from typing import NoReturn 4 | 5 | 6 | class UnsupportedError(RuntimeError): 7 | "Sorry, this function is currently not supported." 8 | 9 | @classmethod 10 | def raise_error(cls, *args, **kwargs) -> NoReturn: 11 | del args 12 | del kwargs 13 | raise cls 14 | -------------------------------------------------------------------------------- /src/koila/gpus.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | from __future__ import annotations 4 | 5 | import math 6 | from typing import Generator 7 | 8 | from pynvml.smi import nvidia_smi 9 | from torch import cuda 10 | 11 | from . import constants 12 | from .interfaces import BatchedPair 13 | 14 | NVSMI = None 15 | 16 | 17 | def nvidia_free_memory() -> int: 18 | """ 19 | Calls nvidia's nvml library and queries available GPU memory. 20 | Currently the function only works with 1 GPU. 21 | 22 | Returns 23 | ------- 24 | 25 | Free GPU memory in terms of bytes. 26 | """ 27 | 28 | global NVSMI 29 | if NVSMI is None: 30 | NVSMI = nvidia_smi.getInstance() 31 | 32 | assert NVSMI is not None 33 | query = NVSMI.DeviceQuery("memory.free") 34 | 35 | # Only works on one GPU as of now. 36 | gpu = query["gpu"][0]["fb_memory_usage"] 37 | 38 | unit = constants.UNITS[gpu["unit"].lower()] 39 | free = gpu["free"] 40 | 41 | return free * unit 42 | 43 | 44 | def torch_free_memory() -> int: 45 | """ 46 | Calls torch's memory statistics to calculate the amount of GPU memory unused. 47 | Currently the function only works with 1 GPU. 48 | 49 | Returns 50 | ------- 51 | 52 | Reserved GPU memory in terms of bytes. 53 | """ 54 | 55 | if not cuda.is_available(): 56 | return 0 57 | 58 | # Only works on one GPU as of now. 
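# `cuda.memory_reserved` counts bytes held by PyTorch's caching allocator,
# while `cuda.memory_allocated` counts bytes occupied by live tensors.
# The difference is cached memory that can be reused without asking the
# driver for more.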
59 | 60 | reserved_memory = cuda.memory_reserved(0) 61 | active_memory = cuda.memory_allocated(0) 62 | unused_memory = reserved_memory - active_memory 63 | return unused_memory 64 | 65 | 66 | def free_memory() -> int | None: 67 | """ 68 | The amount of free GPU memory that can be used. 69 | 70 | Returns 71 | ------- 72 | 73 | Unused GPU memory, or None if no GPUs are available. 74 | """ 75 | 76 | if cuda.is_available(): 77 | return nvidia_free_memory() + torch_free_memory() 78 | else: 79 | return None 80 | 81 | 82 | def maximum_batch(memory: BatchedPair, total_memory: int | None = None) -> int | None: 83 | # Solve for x (the largest batch size) in: batch * x + no_batch = total unused memory. 84 | if total_memory is None: 85 | total_memory = free_memory() 86 | 87 | if total_memory is None: 88 | return None 89 | 90 | return (total_memory - memory.no_batch) // memory.batch 91 | 92 | 93 | def split_batch( 94 | memory: BatchedPair, current_batch: int, total_memory: int | None = None 95 | ) -> Generator[int, None, None]: 96 | max_batch = maximum_batch(memory, total_memory) 97 | 98 | if max_batch is None: 99 | yield current_batch 100 | return 101 | # Start from the largest power of 2 that fits, then halve to cover the remainder. 102 | batch_size = 2 ** (math.floor(math.log2(max_batch))) 103 | (times, current_batch) = divmod(current_batch, batch_size) 104 | 105 | for _ in range(times): 106 | yield batch_size 107 | 108 | while current_batch > 0: 109 | batch_size >>= 1 110 | if current_batch >= batch_size: 111 | current_batch -= batch_size 112 | yield batch_size 113 | assert current_batch < batch_size, [current_batch, batch_size] 114 | -------------------------------------------------------------------------------- /src/koila/interfaces.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | from __future__ import annotations 4 | 5 | import functools 6 | import operator 7 | from abc import abstractmethod 8 | from typing import ( 9 | Any, 10 | Callable, 11 | Dict, 12 | NamedTuple, 13 | Protocol, 14 | Tuple, 15 | TypeVar, 16 | Union, 17 | overload, 18 | runtime_checkable, 19 | ) 20 | 21 | from torch import Tensor 22 | from torch import device as Device 23 | from torch import dtype as DType 24 | 25 | from . import constants 26 | 27 | E = TypeVar("E") 28 | T = TypeVar("T", covariant=True) 29 | V = TypeVar("V", contravariant=True) 30 | 31 | 32 | @runtime_checkable 33 | class Runnable(Protocol[T]): 34 | @abstractmethod 35 | def run(self) -> T: ... 36 | 37 | 38 | @runtime_checkable 39 | class TensorMixin(Protocol): 40 | @overload 41 | @abstractmethod 42 | def size(self) -> Tuple[int, ...]: ... 43 | 44 | @overload 45 | @abstractmethod 46 | def size(self, dim: int) -> int: ... 47 | 48 | @abstractmethod 49 | def size(self, dim: int | None = None) -> int | Tuple[int, ...]: ... 50 | 51 | def numel(self) -> int: 52 | return functools.reduce(operator.mul, self.size(), 1) 53 | 54 | def dim(self) -> int: 55 | return len(self.size()) 56 | 57 | @abstractmethod 58 | def dtype(self) -> DType: ... 59 | 60 | @abstractmethod 61 | def device(self) -> str | Device: ... 62 | 63 | 64 | class BatchedPair(NamedTuple): 65 | batch: int 66 | no_batch: int 67 | 68 | 69 | class BatchInfo(NamedTuple): 70 | index: int 71 | value: int 72 | 73 | def map(self, func: Callable[[int], int]) -> BatchInfo: 74 | index = func(self.index) 75 | return BatchInfo(index, self.value) 76 | 77 | 78 | @runtime_checkable 79 | class RunnableTensor(Runnable[Tensor], TensorMixin, Protocol): 80 | @abstractmethod 81 | def batch(self) -> BatchInfo | None: ...
82 | 83 | @abstractmethod 84 | def run(self, partial: Tuple[int, int] | None = None) -> Tensor: ... 85 | 86 | @abstractmethod 87 | def visit(self, nodes: Dict[int, TensorLike]) -> None: ... 88 | 89 | def buffer(self) -> Dict[int, TensorLike]: 90 | nodes = {} 91 | self.visit(nodes) 92 | return nodes 93 | 94 | def buffer_numel(self) -> BatchedPair: 95 | buffer = self.buffer().values() 96 | return BatchedPair( 97 | sum(t.numel() for t in buffer if bat(t) is not None), 98 | sum(t.numel() for t in buffer if bat(t) is None), 99 | ) 100 | 101 | def buffer_memory(self) -> BatchedPair: 102 | buffer = self.buffer().values() 103 | return BatchedPair( 104 | sum(mem(t) for t in buffer if bat(t) is not None), 105 | sum(mem(t) for t in buffer if bat(t) is None), 106 | ) 107 | 108 | def memory(self) -> int: 109 | return mem(self) 110 | 111 | 112 | def dtyp(tensor: TensorLike) -> DType: 113 | if isinstance(tensor, Tensor): 114 | return tensor.dtype 115 | 116 | return tensor.dtype() 117 | 118 | 119 | def dev(tensor: TensorLike) -> str | Device: 120 | if isinstance(tensor, Tensor): 121 | return tensor.device 122 | 123 | return tensor.device() 124 | 125 | 126 | def mem(tensor: TensorLike) -> int: 127 | dt = dtyp(tensor) 128 | numel = tensor.numel() 129 | 130 | if (batch := bat(tensor)) is not None: 131 | numel //= batch.value 132 | 133 | return constants.MEMORY_BYTES[dt] * numel 134 | 135 | 136 | def bat(tensor: TensorLike) -> BatchInfo | None: 137 | if isinstance(tensor, RunnableTensor): 138 | return tensor.batch() 139 | return None 140 | 141 | 142 | TensorLike = Union[Tensor, RunnableTensor] 143 | 144 | 145 | @overload 146 | def run(val: RunnableTensor, partial: Tuple[int, int] | None = None) -> Tensor: ... 147 | 148 | 149 | @overload 150 | def run(val: Runnable[E], partial: Tuple[int, int] | None = None) -> E: ... 151 | 152 | 153 | @overload 154 | def run(val: E, partial: Tuple[int, int] | None = None) -> E: ... 155 | 156 | 157 | def run(val: Any, partial: Tuple[int, int] | None = None) -> Any: 158 | if isinstance(val, RunnableTensor): 159 | return val.run(partial) 160 | 161 | if isinstance(val, Runnable): 162 | return val.run() 163 | 164 | return val 165 | -------------------------------------------------------------------------------- /src/koila/lazy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | from __future__ import annotations 4 | 5 | import builtins 6 | import dataclasses as dcls 7 | import functools 8 | import logging 9 | from dataclasses import dataclass 10 | from functools import wraps 11 | from typing import ( 12 | Any, 13 | Callable, 14 | Dict, 15 | Generic, 16 | List, 17 | NamedTuple, 18 | NoReturn, 19 | Sequence, 20 | Tuple, 21 | Type, 22 | TypeVar, 23 | final, 24 | overload, 25 | ) 26 | 27 | import torch 28 | from rich.logging import RichHandler 29 | from torch import Tensor, cuda 30 | from torch import device as Device 31 | from torch import dtype as DType 32 | 33 | from . 
import gpus, interfaces, prepasses 34 | from .errors import UnsupportedError 35 | from .interfaces import BatchInfo, RunnableTensor, TensorLike 36 | from .prepasses import PrePass, PrePassFunc 37 | 38 | T = TypeVar("T") 39 | V = TypeVar("V", contravariant=True) 40 | 41 | LOGGER = logging.getLogger(__name__) 42 | LOGGER.addHandler(RichHandler()) 43 | 44 | 45 | @dataclass(frozen=True) 46 | class LazyFunction(Generic[V]): 47 | func: Callable[..., Tensor] 48 | prepass_func: PrePassFunc 49 | 50 | def __call__(self, *args: Any, **kwargs: Any) -> LazyTensor: 51 | lazy_args = tuple(lazy(arg) for arg in args) 52 | lazy_kwargs = dict((k, lazy(v)) for (k, v) in kwargs.items()) 53 | prepass = self.prepass_func(*args, **kwargs) 54 | return LazyTensor(Evaluation(self.func, prepass, *lazy_args, **lazy_kwargs)) 55 | 56 | def __get__(self, obj: V, objtype: Type[V]) -> Callable[..., LazyTensor]: 57 | assert isinstance(obj, objtype), [type(obj), objtype] 58 | if obj is None: 59 | return self 60 | else: 61 | return functools.partial(self, obj) 62 | 63 | 64 | @final 65 | @dataclass(init=False) 66 | class Evaluation(RunnableTensor): 67 | func: Callable[..., Tensor] 68 | prepass: PrePass 69 | args: Tuple[LazyTensor | Tensor | int | float | bool, ...] = dcls.field( 70 | default_factory=tuple 71 | ) 72 | kwargs: Dict[str, LazyTensor | Tensor | int | float | bool] = dcls.field( 73 | default_factory=dict 74 | ) 75 | 76 | def __init__( 77 | self, 78 | func: Callable[..., Tensor], 79 | prepass: PrePass, 80 | *args: LazyTensor | Tensor | int | float | bool, 81 | **kwargs: LazyTensor | Tensor | int | float | bool, 82 | ) -> None: 83 | self.func = func 84 | self.prepass = prepass 85 | self.args = args 86 | self.kwargs = kwargs 87 | 88 | def __hash__(self) -> int: 89 | # Evaluations are unique. 90 | return id(self) 91 | 92 | def run(self, partial: Tuple[int, int] | None = None) -> Tensor: 93 | real_args = [interfaces.run(arg, partial) for arg in self.args] 94 | real_kwargs = {k: interfaces.run(v, partial) for (k, v) in self.kwargs.items()} 95 | 96 | result = self.func(*real_args, **real_kwargs) 97 | 98 | # Checks the shape only when pre-passing. 
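# The prepass computed this shape symbolically, without running the graph;
# comparing it against the real output validates the shape rules.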
99 | # If partial is supplemented, it means the tensors are really evaluated 100 | if partial is None: 101 | assert self.prepass.shape == result.shape, [self.prepass, result.shape] 102 | elif (reducer := self.prepass.reducer()) is None: 103 | raise UnsupportedError("Cannot safely parallelize.") 104 | else: 105 | LOGGER.debug( 106 | "Evaluation taking batch: (%s, %s), low=%s, high=%s", 107 | self.size(), 108 | self.batch(), 109 | partial[0], 110 | partial[1], 111 | ) 112 | callback = reducer(input, *self.args, **self.kwargs) 113 | result = callback(result) 114 | 115 | return result 116 | 117 | def visit(self, nodes: Dict[int, TensorLike]) -> None: 118 | if hash(self) in nodes.keys(): 119 | return 120 | 121 | for arg in self.args: 122 | if isinstance(arg, Tensor): 123 | nodes[hash(arg)] = arg 124 | elif isinstance(arg, RunnableTensor): 125 | arg.visit(nodes) 126 | 127 | for val in self.kwargs.values(): 128 | if isinstance(val, Tensor): 129 | nodes[hash(val)] = val 130 | elif isinstance(val, RunnableTensor): 131 | val.visit(nodes) 132 | 133 | assert hash(self) not in nodes.keys() 134 | nodes[hash(self)] = self 135 | 136 | def size(self, dim: int | None = None) -> int | Tuple[int, ...]: 137 | shape = self.prepass.shape 138 | if dim is not None: 139 | return shape[dim] 140 | else: 141 | return shape 142 | 143 | def dtype(self) -> DType: 144 | return self.prepass.dtype() 145 | 146 | def device(self) -> str | Device: 147 | return self.prepass.device() 148 | 149 | def batch(self) -> BatchInfo | None: 150 | return self.prepass.batch() 151 | 152 | 153 | @final 154 | @dataclass(init=False, repr=False) 155 | class LazyTensor(RunnableTensor): 156 | _data: TensorLike 157 | _batch: BatchInfo | None = None 158 | 159 | def __init__(self, data: TensorLike, batch: int | None = None) -> None: 160 | if isinstance(data, LazyTensor): 161 | self._data = data._data 162 | self._batch = data._batch 163 | elif isinstance(data, Evaluation): 164 | self._data = data 165 | self._batch = data.batch() 166 | else: 167 | self._data = data 168 | if batch is None: 169 | self._batch = None 170 | else: 171 | self._batch = BatchInfo(batch, data.size(batch)) 172 | 173 | LOGGER.debug("Creating LazyTensor. %s, %s", type(self._data), self._batch) 174 | 175 | # Implementations 176 | 177 | def run(self, partial: Tuple[int, int] | None = None) -> Tensor: 178 | data = self._data 179 | if isinstance(data, Tensor): 180 | if partial is None or self._batch is None: 181 | return data 182 | else: 183 | (low, high) = partial 184 | return data.index_select( 185 | self._batch.index, 186 | torch.tensor(list(range(low, high)), device=data.device), 187 | ) 188 | else: 189 | return data.run(partial) 190 | 191 | def visit(self, nodes: Dict[int, TensorLike]) -> None: 192 | data = self._data 193 | 194 | if hash(self) in nodes.keys(): 195 | return 196 | 197 | if isinstance(data, Evaluation): 198 | data.visit(nodes) 199 | else: 200 | nodes[hash(self)] = self 201 | 202 | assert hash(self) in nodes.keys() 203 | 204 | @overload 205 | def size(self) -> Tuple[int, ...]: ... 206 | 207 | @overload 208 | def size(self, dim: int) -> int: ... 
209 | 210 | def size(self, dim: int | None = None) -> int | Tuple[int, ...]: 211 | data = self._data 212 | 213 | if dim is None: 214 | return data.size() 215 | 216 | return data.size(dim) 217 | 218 | def dtype(self) -> DType: 219 | dt = interfaces.dtyp(self._data) 220 | return dt 221 | 222 | def device(self) -> str | Device: 223 | return interfaces.dev(self._data) 224 | 225 | def batch(self) -> BatchInfo | None: 226 | return self._batch 227 | 228 | # Magic methods 229 | 230 | def __str__(self) -> str: 231 | return f"LazyTensor {self.run()}" 232 | 233 | def __bool__(self) -> bool: 234 | return bool(self.item()) 235 | 236 | def __int__(self) -> int: 237 | return int(self.item()) 238 | 239 | def __float__(self) -> float: 240 | return float(self.item()) 241 | 242 | def __invert__(self) -> bool: 243 | return not bool(self) 244 | 245 | def __pos__(self) -> TensorLike: 246 | return lazy_forward(Tensor.__pos__, prepasses.identity, self) 247 | 248 | def __neg__(self) -> TensorLike: 249 | return lazy_forward(Tensor.__neg__, prepasses.identity, self) 250 | 251 | def __add__(self, other: TensorLike) -> TensorLike: 252 | return lazy_forward(Tensor.__add__, prepasses.symmetric, self, other) 253 | 254 | def __radd__(self, other: TensorLike) -> TensorLike: 255 | return lazy_forward(Tensor.__add__, prepasses.symmetric, other, self) 256 | 257 | def __sub__(self, other: TensorLike) -> TensorLike: 258 | return lazy_forward(Tensor.__sub__, prepasses.symmetric, self, other) 259 | 260 | def __rsub__(self, other: TensorLike) -> TensorLike: 261 | return lazy_forward(Tensor.__sub__, prepasses.symmetric, other, self) 262 | 263 | def __mul__(self, other: TensorLike) -> TensorLike: 264 | return lazy_forward(Tensor.__mul__, prepasses.symmetric, self, other) 265 | 266 | def __rmul__(self, other: TensorLike) -> TensorLike: 267 | return lazy_forward(Tensor.__mul__, prepasses.symmetric, other, self) 268 | 269 | def __truediv__(self, other: TensorLike) -> TensorLike: 270 | return lazy_forward(Tensor.__truediv__, prepasses.symmetric, self, other) 271 | 272 | def __rtruediv__(self, other: TensorLike) -> TensorLike: 273 | return lazy_forward(Tensor.__truediv__, prepasses.symmetric, other, self) 274 | 275 | def __floordiv__(self, other: TensorLike) -> NoReturn: 276 | del other 277 | raise UnsupportedError 278 | 279 | def __rfloordiv__(self, other: TensorLike) -> NoReturn: 280 | del other 281 | raise UnsupportedError 282 | 283 | def __pow__(self, other: TensorLike) -> TensorLike: 284 | return lazy_forward(Tensor.__pow__, prepasses.symmetric, self, other) 285 | 286 | def __rpow__(self, other: TensorLike) -> TensorLike: 287 | return lazy_forward(Tensor.__pow__, prepasses.symmetric, other, self) 288 | 289 | def __mod__(self, other: TensorLike) -> TensorLike: 290 | return lazy_forward(Tensor.__mod__, prepasses.symmetric, self, other) 291 | 292 | def __rmod__(self, other: TensorLike) -> TensorLike: 293 | return lazy_forward(Tensor.__mod__, prepasses.symmetric, other, self) 294 | 295 | def __divmod__(self, other: TensorLike) -> NoReturn: 296 | del other 297 | raise UnsupportedError 298 | 299 | def __rdivmod__(self, other: TensorLike) -> NoReturn: 300 | del other 301 | raise UnsupportedError 302 | 303 | def __abs__(self) -> TensorLike: 304 | return lazy_forward(Tensor.__abs__, prepasses.identity, self) 305 | 306 | def __hash__(self) -> int: 307 | # LazyTensors are not unique. They are defined by their data. 
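# Hashing by the identity of the wrapped data means two LazyTensors sharing
# the same underlying tensor hash alike, so `visit` does not double-count
# shared inputs when collecting the buffer.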
308 | return id(self._data) 309 | 310 | def __matmul__(self, other: TensorLike) -> TensorLike: 311 | return lazy_forward(Tensor.__matmul__, prepasses.matmul, self, other) 312 | 313 | def __rmatmul__(self, other: TensorLike) -> TensorLike: 314 | return lazy_forward(Tensor.__matmul__, prepasses.matmul, other, self) 315 | 316 | def __eq__(self, other: TensorLike) -> TensorLike: 317 | return lazy_forward(Tensor.__eq__, prepasses.symmetric, self, other) 318 | 319 | def __ne__(self, other: TensorLike) -> TensorLike: 320 | return lazy_forward(Tensor.__ne__, prepasses.symmetric, self, other) 321 | 322 | def __gt__(self, other: TensorLike) -> TensorLike: 323 | return lazy_forward(Tensor.__gt__, prepasses.symmetric, self, other) 324 | 325 | def __ge__(self, other: TensorLike) -> TensorLike: 326 | return lazy_forward(Tensor.__ge__, prepasses.symmetric, self, other) 327 | 328 | def __lt__(self, other: TensorLike) -> TensorLike: 329 | return lazy_forward(Tensor.__lt__, prepasses.symmetric, self, other) 330 | 331 | def __le__(self, other: TensorLike) -> TensorLike: 332 | return lazy_forward(Tensor.__le__, prepasses.symmetric, self, other) 333 | 334 | def __len__(self) -> int: 335 | return self.size(0) 336 | 337 | def __getitem__( 338 | self, index: int | slice | Tensor | List[Any] | Tuple[Any] | None 339 | ) -> Tensor: 340 | if isinstance(self._data, RunnableTensor): 341 | data = self._data.run() 342 | else: 343 | data = self._data 344 | return data[index] 345 | 346 | def __setitem__( 347 | self, 348 | index: int | slice | Tensor | List[Any] | Tuple[Any] | None, 349 | value: Tensor, 350 | ) -> None: 351 | if isinstance(self._data, RunnableTensor): 352 | raise UnsupportedError 353 | 354 | self._data[index] = value 355 | 356 | def __getattr__(self, name: str) -> Callable[..., Any]: 357 | LOGGER.debug( 358 | f"__getattr__ called for {name}. Automatically resolving function." 359 | ) 360 | 361 | method = getattr(Tensor, name) 362 | wrapper = functools.wraps(method) 363 | 364 | if (custom_impl := CUSTOM_OPS.lookup_method(name)) is not None: 365 | LOGGER.debug("A custom method definition is found.") 366 | partial = functools.partial(custom_impl, self) 367 | elif (shape_impl := SHAPE_OPS.lookup_method(name)) is not None: 368 | LOGGER.debug("A custom shape method is found. Lazy evaluation.") 369 | partial = functools.partial(lazy_forward, method, shape_impl, self) 370 | else: 371 | LOGGER.debug("No custom methods found. Evaluating eagerly.") 372 | partial = functools.partial(method, interfaces.run(self)) 373 | 374 | return wrapper(partial) 375 | 376 | @classmethod 377 | def __torch_function__( 378 | cls, 379 | func: Callable[..., Tensor], 380 | types: Tuple[Type[Any], ...], 381 | args: Sequence[TensorLike] = (), 382 | kwargs: Dict[str, TensorLike] | None = None, 383 | ) -> TensorLike: 384 | if kwargs is None: 385 | kwargs = {} 386 | 387 | if not builtins.all( 388 | issubclass(typ, (LazyTensor, Tensor, int, float, bool)) for typ in types 389 | ): 390 | return NotImplemented 391 | 392 | name = func.__name__ 393 | 394 | if (custom_impl := CUSTOM_OPS.lookup_function(name)) is not None: 395 | LOGGER.debug("A custom function definition is found.") 396 | return custom_impl(*args, **kwargs) 397 | elif (shape_impl := SHAPE_OPS.lookup_function(name)) is not None: 398 | LOGGER.debug("A custom shape function is found. Lazy evaluation.") 399 | return lazy_forward(func, shape_impl, *args, **kwargs) 400 | else: 401 | LOGGER.debug("No custom method found. 
Evaluating eagerly.") 402 | args = [interfaces.run(arg) for arg in args] 403 | kwargs = {k: interfaces.run(v) for (k, v) in kwargs.items()} 404 | return func(*args, **kwargs) 405 | 406 | @property 407 | @wraps(Tensor.size) 408 | def shape(self) -> Tuple[int, ...]: 409 | return self.size() 410 | 411 | @property 412 | @wraps(Tensor.dim) 413 | def ndim(self) -> int: 414 | return self.dim() 415 | 416 | @property 417 | @wraps(Tensor.t) 418 | def T(self) -> TensorLike: 419 | return self.t() 420 | 421 | def torch(self) -> Tensor: 422 | return self.run() 423 | 424 | def backward(self) -> None: 425 | if self._batch is None or not cuda.is_available(): 426 | LOGGER.debug( 427 | "Unable to parallelize across batches." 428 | " " 429 | "Running backward with native pytorch." 430 | ) 431 | self.run().backward() 432 | else: 433 | total = 0 434 | LOGGER.debug("Able to parallelize across batches. Hooray!") 435 | for mini_batch_size in gpus.split_batch( 436 | self.buffer_memory(), self._batch.value 437 | ): 438 | LOGGER.debug("Using mini batch size: %d.", mini_batch_size) 439 | mini_batch = self.run((total, total + mini_batch_size)) 440 | total += mini_batch_size 441 | mini_batch.backward() 442 | 443 | 444 | @overload 445 | def lazy(val: Tensor | LazyTensor, batch: int | None = None) -> LazyTensor: ... 446 | 447 | 448 | @overload 449 | def lazy( 450 | *val: Tensor | LazyTensor, batch: int | None = None 451 | ) -> Tuple[LazyTensor, ...]: ... 452 | 453 | 454 | @overload 455 | def lazy(val: int) -> int: ... 456 | 457 | 458 | @overload 459 | def lazy(*val: int) -> Tuple[int, ...]: ... 460 | 461 | 462 | @overload 463 | def lazy(val: float) -> float: ... 464 | 465 | 466 | @overload 467 | def lazy(*val: float) -> Tuple[float, ...]: ... 468 | 469 | 470 | @overload 471 | def lazy(val: bool) -> bool: ... 472 | 473 | 474 | @overload 475 | def lazy(*val: bool) -> Tuple[bool, ...]: ... 476 | 477 | 478 | def lazy(*values: Any, batch: int | None = None) -> Any: 479 | results = [] 480 | for val in values: 481 | LOGGER.debug("lazy %s, %s", type(val), interfaces.bat(val)) 482 | 483 | if isinstance(val, Tensor): 484 | val = LazyTensor(val, batch) 485 | 486 | results.append(val) 487 | 488 | if len(results) == 1: 489 | return results[0] 490 | 491 | return tuple(results) 492 | 493 | 494 | def lazy_forward( 495 | func: Callable[..., Any], shape_func: PrePassFunc, *args: Any, **kwargs: Any 496 | ) -> TensorLike: 497 | if torch.is_grad_enabled(): 498 | out = LazyFunction(func, shape_func)(*args, **kwargs) 499 | LOGGER.debug("lazy forward %s, %s", out.size(), out.batch()) 500 | return out 501 | else: 502 | run_args = [interfaces.run(arg) for arg in args] 503 | run_kwargs = {k: interfaces.run(v) for (k, v) in kwargs.items()} 504 | out = func(*run_args, **run_kwargs) 505 | LOGGER.debug("eager forward (%s, %s) -> %s", run_args, run_kwargs, out) 506 | return out 507 | 508 | 509 | # Functions that require special handling. 510 | 511 | 512 | class _ValIdx(NamedTuple): 513 | values: TensorLike 514 | indices: TensorLike 515 | 516 | 517 | @overload 518 | def _min(input: TensorLike) -> TensorLike: ... 519 | 520 | 521 | @overload 522 | def _min(input: TensorLike, dim: int, keepdim: bool = False) -> _ValIdx: ... 523 | 524 | 525 | @overload 526 | def _min(input: TensorLike, other: TensorLike) -> TensorLike: ... 
527 | 
528 | 
529 | @wraps(torch.min)
530 | def _min(input: TensorLike, *args: Any, **kwargs: Any) -> TensorLike | _ValIdx:
531 |     if len(args) == len(kwargs) == 0:
532 |         return lazy_forward(torch.min, prepasses.reduce_dims, input)
533 | 
534 |     if (
535 |         len(args) == 1
536 |         and isinstance((other := args[0]), (Tensor, LazyTensor))
537 |         or len(kwargs) == 1
538 |         and (other := kwargs.get("other", None)) is not None  # parenthesized walrus: bind the value, then test it
539 |     ):
540 |         return lazy_forward(torch.minimum, prepasses.symmetric, input, other)
541 | 
542 |     return _ValIdx(
543 |         lazy_forward(torch.amin, prepasses.reduce_dims, input, *args, **kwargs),
544 |         lazy_forward(torch.argmin, prepasses.reduce_dims, input, *args, **kwargs),
545 |     )
546 | 
547 | 
548 | @overload
549 | def _max(input: TensorLike) -> TensorLike: ...
550 | 
551 | 
552 | @overload
553 | def _max(input: TensorLike, dim: int, keepdim: bool = False) -> _ValIdx: ...
554 | 
555 | 
556 | @overload
557 | def _max(input: TensorLike, other: TensorLike) -> TensorLike: ...
558 | 
559 | 
560 | @wraps(torch.max)
561 | def _max(input: TensorLike, *args: Any, **kwargs: Any) -> TensorLike | _ValIdx:
562 |     if len(args) == len(kwargs) == 0:
563 |         return lazy_forward(torch.max, prepasses.reduce_dims, input)
564 | 
565 |     if (
566 |         len(args) == 1
567 |         and isinstance((other := args[0]), (Tensor, LazyTensor))
568 |         or len(kwargs) == 1
569 |         and (other := kwargs.get("other", None)) is not None  # same parenthesization fix as in _min
570 |     ):
571 |         return lazy_forward(torch.maximum, prepasses.symmetric, input, other)
572 | 
573 |     return _ValIdx(
574 |         lazy_forward(torch.amax, prepasses.reduce_dims, input, *args, **kwargs),
575 |         lazy_forward(torch.argmax, prepasses.reduce_dims, input, *args, **kwargs),
576 |     )
577 | 
578 | 
579 | def _permute_function_shape(
580 |     input: TensorLike, dims: int | Tuple[int, ...], *args: Any, **kwargs: Any
581 | ) -> PrePass:
582 |     prepasses.mute_unused_args(*args, **kwargs)
583 | 
584 |     if isinstance(dims, int):
585 |         dims = (dims,)
586 | 
587 |     return prepasses.permute(input, *dims)
588 | 
589 | 
590 | def _reshape_function_shape(
591 |     input: TensorLike, dims: Tuple[int, ...], *args: Any, **kwargs: Any
592 | ) -> PrePass:
593 |     prepasses.mute_unused_args(*args, **kwargs)
594 | 
595 |     return prepasses.reshape(input, *dims)
596 | 
597 | 
598 | def _t_shape(input: TensorLike, *args: Any, **kwargs: Any) -> PrePass:
599 |     prepasses.mute_unused_args(*args, **kwargs)
600 | 
601 |     return prepasses.tranpose(input, 0, 1)
602 | 
603 | 
604 | @dataclass
605 | class MethodFunction(Generic[T]):
606 |     method: Dict[str, T]
607 |     function: Dict[str, T]
608 | 
609 |     @staticmethod
610 |     def _search(key: str, *dbs: Dict[str, T]) -> T | None:
611 |         for db in dbs:
612 |             if (value := db.get(key)) is not None:
613 |                 return value
614 |         return None
615 | 
616 |     def lookup(self, key: str, *dbs: Dict[str, T]) -> T | None:
617 |         if (result := self._search(key, *dbs)) is not None:
618 |             return result
619 | 
620 |         if key.startswith("_"):
621 |             fallback = key.lstrip("_")
622 |             return self._search(fallback, *dbs)
623 |         return None
624 | 
625 |     def lookup_method(self, key: str) -> T | None:
626 |         return self.lookup(key, self.method, self.function)
627 | 
628 |     def lookup_function(self, key: str) -> T | None:
629 |         return self.lookup(key, self.function)
630 | 
631 | 
632 | CUSTOM_OPS = MethodFunction[Callable](
633 |     method={},
634 |     function={
635 |         "min": _min,
636 |         "max": _max,
637 |     },
638 | )
639 | 
640 | PARTIAL_OPS = MethodFunction[Callable](method={}, function={"sum": lambda x: x})
641 | 
642 | SHAPE_OPS = MethodFunction[PrePassFunc](
643 | 
method={"permute": prepasses.permute, "view": prepasses.view}, 644 | function={ 645 | "positive": prepasses.identity, 646 | "negative": prepasses.identity, 647 | "neg": prepasses.identity, 648 | "add": prepasses.symmetric, 649 | "sub": prepasses.symmetric, 650 | "subtract": prepasses.symmetric, 651 | "mul": prepasses.symmetric, 652 | "multiply": prepasses.symmetric, 653 | "div": prepasses.symmetric, 654 | "divide": prepasses.symmetric, 655 | "true_divide": prepasses.symmetric, 656 | "floor": prepasses.identity, 657 | "fmod": prepasses.symmetric, 658 | "remainder": prepasses.symmetric, 659 | "frac": prepasses.identity, 660 | "pow": prepasses.symmetric, 661 | "exp": prepasses.identity, 662 | "exp2": prepasses.identity, 663 | "log": prepasses.identity, 664 | "log2": prepasses.identity, 665 | "log10": prepasses.identity, 666 | "log1p": prepasses.identity, 667 | "abs": prepasses.identity, 668 | "matmul": prepasses.matmul, 669 | "bmm": prepasses.matmul, 670 | "mm": prepasses.matmul, 671 | "mv": prepasses.matmul, 672 | "dot": prepasses.matmul, 673 | "eq": prepasses.symmetric, 674 | "equal": prepasses.symmetric, 675 | "ne": prepasses.symmetric, 676 | "not_equal": prepasses.symmetric, 677 | "gt": prepasses.symmetric, 678 | "greater": prepasses.symmetric, 679 | "ge": prepasses.symmetric, 680 | "greater_equal": prepasses.symmetric, 681 | "lt": prepasses.symmetric, 682 | "less": prepasses.symmetric, 683 | "le": prepasses.symmetric, 684 | "less_equal": prepasses.symmetric, 685 | "mean": prepasses.mean, 686 | "sum": prepasses.reduce_dims, 687 | "std": prepasses.reduce_dims, 688 | "minimum": prepasses.symmetric, 689 | "maximum": prepasses.symmetric, 690 | "amin": prepasses.reduce_dims, 691 | "amax": prepasses.reduce_dims, 692 | "argmin": prepasses.reduce_dims, 693 | "argmax": prepasses.reduce_dims, 694 | "isclose": prepasses.symmetric, 695 | "cat": prepasses.cat, 696 | "t": _t_shape, 697 | "permute": _permute_function_shape, 698 | "reshape": _reshape_function_shape, 699 | "flatten": prepasses.flatten, 700 | "transpose": prepasses.tranpose, 701 | "select": prepasses.select, 702 | "index_select": prepasses.select, 703 | "sin": prepasses.identity, 704 | "cos": prepasses.identity, 705 | "tan": prepasses.identity, 706 | "asin": prepasses.identity, 707 | "acos": prepasses.identity, 708 | "atan": prepasses.identity, 709 | "sinh": prepasses.identity, 710 | "cosh": prepasses.identity, 711 | "tanh": prepasses.identity, 712 | "asinh": prepasses.identity, 713 | "acosh": prepasses.identity, 714 | "atanh": prepasses.identity, 715 | "sigmoid": prepasses.identity, 716 | "hardsigmoid": prepasses.identity, 717 | "softmax": prepasses.identity, 718 | "relu": prepasses.identity, 719 | "relu6": prepasses.identity, 720 | "leaky_relu": prepasses.identity, 721 | "l1_loss": prepasses.loss, 722 | "smooth_l1_loss": prepasses.loss, 723 | "mse_loss": prepasses.loss, 724 | "cross_entropy": prepasses.loss, 725 | "binary_cross_entropy": prepasses.loss, 726 | "binary_cross_entropy_with_logits": prepasses.loss, 727 | "elu": prepasses.identity, 728 | "gelu": prepasses.identity, 729 | "dropout": prepasses.identity, 730 | "batch_norm": prepasses.identity, 731 | "layer_norm": prepasses.identity, 732 | "linear": prepasses.linear, 733 | "embedding": prepasses.embedding, 734 | "pad": prepasses.pad, 735 | "conv1d": prepasses.conv, 736 | "conv2d": prepasses.conv, 737 | "conv3d": prepasses.conv, 738 | "conv_transpose1d": prepasses.conv_transpose, 739 | "conv_transpose2d": prepasses.conv_transpose, 740 | "conv_transpose3d": prepasses.conv_transpose, 
741 | "max_pool1d": prepasses.maxpool, 742 | "max_pool2d": prepasses.maxpool, 743 | "max_pool3d": prepasses.maxpool, 744 | "avg_pool1d": prepasses.avgpool, 745 | "avg_pool2d": prepasses.avgpool, 746 | "avg_pool3d": prepasses.avgpool, 747 | # Functions that will not be implemented. 748 | "__floordiv__": UnsupportedError.raise_error, 749 | }, 750 | ) 751 | -------------------------------------------------------------------------------- /src/koila/prepasses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | from __future__ import annotations 4 | 5 | import functools 6 | import logging 7 | import math 8 | import operator 9 | from abc import abstractmethod 10 | from dataclasses import dataclass 11 | from typing import ( 12 | Any, 13 | List, 14 | Literal, 15 | Protocol, 16 | Sequence, 17 | Tuple, 18 | overload, 19 | runtime_checkable, 20 | ) 21 | 22 | from rich.logging import RichHandler 23 | from torch import device as Device 24 | from torch import dtype as DType 25 | from torch.functional import Tensor 26 | 27 | from . import constants, interfaces, shapes 28 | from .errors import UnsupportedError 29 | from .interfaces import BatchInfo, TensorLike 30 | 31 | LOGGER = logging.getLogger(__name__) 32 | LOGGER.addHandler(RichHandler()) 33 | 34 | 35 | class CallBack(Protocol): 36 | @abstractmethod 37 | def __call__(self, *args: Any, **kwargs: Any) -> Reducer: ... 38 | 39 | 40 | class Reducer(Protocol): 41 | @abstractmethod 42 | def __call__(self, result: Tensor, /) -> Tensor: ... 43 | 44 | 45 | @dataclass(frozen=True) 46 | class MetaData: 47 | dtype: DType 48 | device: str | Device 49 | batch: BatchInfo | None 50 | reducer: CallBack | None 51 | 52 | 53 | @dataclass 54 | class PrePass: 55 | shape: Tuple[int, ...] 56 | metadata: MetaData 57 | 58 | def __init__(self, shape: Sequence[int], metadata: MetaData) -> None: 59 | self.shape = tuple(shape) 60 | self.metadata = metadata 61 | 62 | def __iter__(self): 63 | return iter(self.shape) 64 | 65 | @overload 66 | def __getitem__(self, index: int) -> int: ... 67 | 68 | @overload 69 | def __getitem__(self, index: slice) -> Tuple[int, ...]: ... 70 | 71 | def __getitem__(self, index: int | slice) -> int | Tuple[int, ...]: 72 | return self.shape[index] 73 | 74 | def __eq__(self, other: Any) -> bool: 75 | if isinstance(other, PrePass): 76 | return self == other 77 | 78 | if isinstance(other, Tuple): 79 | return self.shape == other 80 | 81 | return False 82 | 83 | def dtype(self) -> DType: 84 | return self.metadata.dtype 85 | 86 | def device(self) -> str | Device: 87 | return self.metadata.device 88 | 89 | def batch(self) -> BatchInfo | None: 90 | return self.metadata.batch 91 | 92 | def reducer(self) -> CallBack | None: 93 | return self.metadata.reducer 94 | 95 | 96 | @runtime_checkable 97 | class PrePassFunc(Protocol): 98 | @abstractmethod 99 | def __call__(self, *args: Any, **kwargs: Any) -> PrePass: ... 
100 | 101 | 102 | def mute_unused_args(*args: Any, **kwargs: Any) -> None: 103 | del args 104 | del kwargs 105 | 106 | 107 | def trivial(input: Tensor, *args: Any, **kwargs: Any) -> Reducer: 108 | mute_unused_args(input, *args, **kwargs) 109 | return lambda result: result 110 | 111 | 112 | def same( 113 | tensors: Sequence[TensorLike], batch: BatchInfo | None, reducer: CallBack | None 114 | ) -> MetaData: 115 | assert len(tensors) > 0 116 | dtypes = [interfaces.dtyp(t) for t in tensors] 117 | 118 | max_dtype = max(dtypes, key=lambda typ: constants.MEMORY_BYTES[typ]) 119 | 120 | devices = [str(interfaces.dev(t)) for t in tensors] 121 | 122 | if len(set(devices)) != 1: 123 | raise ValueError(f"Expected tensors to be on the same device, got {devices}.") 124 | 125 | return MetaData(max_dtype, devices[0], batch, reducer) 126 | 127 | 128 | def identity(input: TensorLike, /, *args: Any, **kwargs: Any) -> PrePass: 129 | mute_unused_args(*args, **kwargs) 130 | 131 | return PrePass(input.size(), same([input], interfaces.bat(input), trivial)) 132 | 133 | 134 | def symmetric( 135 | input: TensorLike, other: TensorLike, /, *args: Any, **kwargs: Any 136 | ) -> PrePass: 137 | mute_unused_args(*args, **kwargs) 138 | 139 | shape = shapes.coerce(input.size(), other.size(), broadcast=True, scalars=True) 140 | 141 | if shape is None: 142 | raise ValueError 143 | 144 | batch = None 145 | if (b := interfaces.bat(input)) == interfaces.bat(other): 146 | batch = b 147 | 148 | return PrePass(shape, same([input, other], batch, trivial)) 149 | 150 | 151 | def reduce_dims( 152 | input: TensorLike, 153 | /, 154 | dim: int | Tuple[int, ...] | None = None, 155 | keepdim: bool = False, 156 | *args: Any, 157 | **kwargs: Any, 158 | ) -> PrePass: 159 | mute_unused_args(*args, **kwargs) 160 | 161 | (shape, dimensions) = shapes.reduce_dims(input.size(), dim, keepdim) 162 | 163 | if interfaces.bat(input) in dimensions: 164 | batch = None 165 | reducer = None 166 | else: 167 | batch = interfaces.bat(input) 168 | reducer = trivial 169 | 170 | return PrePass(shape, same([input], batch, reducer)) 171 | 172 | 173 | def scalars(input: TensorLike, /, *args: Any, **kwargs: Any) -> PrePass: 174 | mute_unused_args(*args, **kwargs) 175 | 176 | return reduce_dims(input, tuple(range(input.dim()))) 177 | 178 | 179 | def mean( 180 | input: TensorLike, 181 | /, 182 | dim: int | Tuple[int, ...] 
| None = None, 183 | keepdim: bool = False, 184 | *args: Any, 185 | **kwargs: Any, 186 | ) -> PrePass: 187 | mute_unused_args(*args, **kwargs) 188 | 189 | (shape, dimensions) = shapes.reduce_dims(input.size(), dim, keepdim) 190 | 191 | if (b := interfaces.bat(input)) in dimensions: 192 | batch = None 193 | 194 | def mean_callback(input: Tensor, *args: Any, **kwargs: Any) -> Reducer: 195 | def reducer(result: Tensor) -> Tensor: 196 | return result * input.size(b) / shape[b] 197 | 198 | return reducer 199 | 200 | reducer = mean_callback 201 | else: 202 | batch = interfaces.bat(input) 203 | reducer = trivial 204 | return PrePass(shape, same([input], batch, reducer)) 205 | 206 | 207 | def permute(input: TensorLike, /, *dims: int, **kwargs: Any) -> PrePass: 208 | mute_unused_args(**kwargs) 209 | 210 | mapping = dict(enumerate(dims)) 211 | 212 | batch = None 213 | if (b := interfaces.bat(input)) is not None: 214 | batch = b.map(lambda x: mapping[x]) 215 | 216 | return PrePass(shapes.permute(input.size(), *dims), same([input], batch, trivial)) 217 | 218 | 219 | def reshape(input: TensorLike, /, *shape: int, **kwargs: Any) -> PrePass: 220 | mute_unused_args(**kwargs) 221 | 222 | shape = shapes.reshape(input.size(), *shape) 223 | 224 | batch = None 225 | if (b := interfaces.bat(input)) is not None: 226 | if b in shape: 227 | batch = b.map(shape.index) 228 | 229 | return PrePass(shape, same([input], batch, trivial)) 230 | 231 | 232 | def view(input: TensorLike, /, *shape: int, **kwargs: Any) -> PrePass: 233 | mute_unused_args(**kwargs) 234 | 235 | shape = shapes.view(input.size(), *shape) 236 | 237 | batch = None 238 | if (b := interfaces.bat(input)) is not None: 239 | if b in shape: 240 | batch = b.map(shape.index) 241 | 242 | return PrePass(shape, same([input], batch, trivial)) 243 | 244 | 245 | def flatten( 246 | input: TensorLike, 247 | /, 248 | start_dim: int = 0, 249 | end_dim: int = -1, 250 | *args: Any, 251 | **kwargs: Any, 252 | ) -> PrePass: 253 | LOGGER.debug("%s, %s, %s", input.size(), start_dim, end_dim) 254 | 255 | mute_unused_args(*args, **kwargs) 256 | 257 | start_dim %= input.dim() 258 | end_dim %= input.dim() 259 | 260 | sizes = input.size() 261 | 262 | shape = ( 263 | *sizes[:start_dim], 264 | functools.reduce(operator.mul, sizes[start_dim : end_dim + 1]), 265 | *sizes[end_dim + 1 :], 266 | ) 267 | 268 | batch = None 269 | if (b := interfaces.bat(input)) is not None: 270 | if not (start_dim <= b.index <= end_dim): 271 | batch = b 272 | 273 | return PrePass(shape, same([input], batch, trivial)) 274 | 275 | 276 | def tranpose( 277 | input: TensorLike, dim0: int, dim1: int, /, *args: Any, **kwargs: Any 278 | ) -> PrePass: 279 | mute_unused_args(*args, **kwargs) 280 | 281 | batch = None 282 | if (b := interfaces.bat(input)) is not None: 283 | batch = b.map(lambda x: {dim0: dim1, dim1: dim0}[x]) 284 | 285 | return PrePass( 286 | shapes.tranpose(input.size(), dim0, dim1), same([input], batch, trivial) 287 | ) 288 | 289 | 290 | def select( 291 | input: TensorLike, 292 | dim: int | ... 
| None, 293 | index: int | Tensor, 294 | /, 295 | *args: Any, 296 | **kwargs: Any, 297 | ) -> PrePass: 298 | mute_unused_args(*args, **kwargs) 299 | 300 | shape = input.size() 301 | 302 | if dim is ...: 303 | dim = -1 304 | 305 | if dim is None: 306 | dim = 0 307 | shape = (1,) + shape 308 | 309 | if not -len(shape) <= dim < len(shape): 310 | raise IndexError 311 | 312 | dim %= len(shape) 313 | assert isinstance(dim, int) 314 | 315 | if isinstance(index, Tensor): 316 | sliced_idx = (len(index),) 317 | else: 318 | sliced_idx = () 319 | 320 | batch = None 321 | if (b := interfaces.bat(input)) != dim: 322 | batch = b 323 | 324 | return PrePass( 325 | shape[:dim] + sliced_idx + shape[dim + 1 :], 326 | same([input], batch, trivial), 327 | ) 328 | 329 | 330 | def embedding( 331 | input: TensorLike, weight: TensorLike, /, *args: Any, **kwargs: Any 332 | ) -> PrePass: 333 | mute_unused_args(*args, **kwargs) 334 | 335 | shape = input.size() 336 | return PrePass( 337 | (*shape, weight.size(-1)), 338 | same([input], interfaces.bat(input), trivial), 339 | ) 340 | 341 | 342 | def matmul( 343 | input: TensorLike, other: TensorLike, /, *args: Any, **kwargs: Any 344 | ) -> PrePass: 345 | mute_unused_args(*args, **kwargs) 346 | 347 | if (batch := interfaces.bat(input)) != interfaces.bat(other): 348 | raise UnsupportedError 349 | 350 | return PrePass( 351 | shapes.matmul(input.size(), other.size()), 352 | same([input, other], interfaces.bat(input), trivial), 353 | ) 354 | 355 | 356 | def loss( 357 | input: TensorLike, 358 | target: TensorLike, 359 | /, 360 | reduction: Literal["none", "mean", "sum"] = "mean", 361 | *args: Any, 362 | **kwargs: Any, 363 | ) -> PrePass: 364 | mute_unused_args(*args, **kwargs) 365 | 366 | # Currently only supports tensors of the same batch size. 367 | if (batch := interfaces.bat(input)) != interfaces.bat(target): 368 | raise UnsupportedError 369 | 370 | output_shape = { 371 | "none": input.size(), 372 | "mean": (), 373 | "sum": (), 374 | }[reduction] 375 | 376 | reducer = {"none": trivial, "mean": trivial, "sum": trivial}[reduction] 377 | 378 | return PrePass(output_shape, same([input, target], batch, reducer)) 379 | 380 | 381 | def linear( 382 | input: TensorLike, 383 | weight: TensorLike, 384 | bias: TensorLike | None = None, 385 | *args: Any, 386 | **kwargs: Any, 387 | ) -> PrePass: 388 | mute_unused_args(*args, **kwargs) 389 | 390 | result = shapes.matmul(input.size(), shapes.tranpose(weight.size(), -1, -2)) 391 | 392 | if bias is not None: 393 | result = shapes.coerce(result, bias.size()) 394 | 395 | if result is None: 396 | raise ValueError 397 | 398 | return PrePass(result, same([input, weight], interfaces.bat(input), trivial)) 399 | 400 | 401 | def cat( 402 | tensors: Sequence[TensorLike], dim: int = 0, *args: Any, **kwargs: Any 403 | ) -> PrePass: 404 | mute_unused_args(*args, **kwargs) 405 | 406 | if len(tensors) == 0: 407 | raise ValueError("Expected a sequence of tensors. Got empty sequence.") 408 | 409 | shapes = [t.size() for t in tensors] 410 | no_dim = [t[:dim] + t[dim + 1 :] for t in shapes] 411 | 412 | result_size = no_dim[0] 413 | for size in no_dim[1:]: 414 | if result_size != size: 415 | raise ValueError( 416 | f"Dimension should be equal outside dim {dim}. Got {shapes}." 
417 | ) 418 | 419 | if len(set(interfaces.bat(t) for t in tensors)) != 1: 420 | raise UnsupportedError 421 | 422 | batch = None 423 | if (b := interfaces.bat(tensors[0])) != dim: 424 | batch = b 425 | 426 | concat_size = sum(t[dim] for t in shapes) 427 | return PrePass( 428 | [*result_size[:dim], concat_size, *result_size[dim:]], 429 | same(tensors, batch, trivial), 430 | ) 431 | 432 | 433 | def pad(input: TensorLike, pad: List[int], *args: Any, **kwargs: Any) -> PrePass: 434 | mute_unused_args(*args, **kwargs) 435 | 436 | shapes = input.size() 437 | 438 | if len(pad) % 2 == 1: 439 | raise ValueError(f"Length of pad must be divisible by 2. Got {len(pad)}.") 440 | 441 | if len(pad) > (maxlen := len(shapes) * 2): 442 | raise ValueError( 443 | f"Padding is way too long. Got {pad}, but {maxlen} is the maximum dimensions allowed." 444 | ) 445 | 446 | pad = (2 * len(shapes) - len(pad)) * [0] + list(reversed(pad)) 447 | 448 | pad0 = pad[0::2] 449 | pad1 = pad[1::2] 450 | 451 | assert len(pad0) == len(pad1) == len(shapes), [pad0, pad1, shapes] 452 | 453 | return PrePass( 454 | [s + p0 + p1 for (s, p0, p1) in zip(shapes, pad0, pad1)], 455 | same([input], interfaces.bat(input), trivial), 456 | ) 457 | 458 | 459 | def _int_to_tuple(value: int | Tuple[int, ...], length: int) -> Tuple[int, ...]: 460 | if isinstance(value, int): 461 | return (value,) * length 462 | 463 | assert isinstance(value, Tuple) 464 | assert len(value) == length 465 | return value 466 | 467 | 468 | def conv( 469 | input: TensorLike, 470 | weight: TensorLike, 471 | bias: TensorLike | None = None, 472 | stride: int | Tuple[int, ...] = 1, 473 | padding: int | Tuple[int, ...] | str = "valid", 474 | dilation: int | Tuple[int, ...] = 1, 475 | groups: int = 1, 476 | *args: Any, 477 | **kwargs: Any, 478 | ) -> PrePass: 479 | mute_unused_args(groups, *args, **kwargs) 480 | 481 | (batch, chan, *dims) = input.size() 482 | (out_chan, in_chan, *kernels) = weight.size() 483 | 484 | assert chan == in_chan 485 | 486 | if bias is not None: 487 | assert shapes.coerce(bias.size(), (out_chan,)) is not None 488 | 489 | if isinstance(padding, str): 490 | raise UnsupportedError 491 | 492 | stride = _int_to_tuple(stride, len(dims)) 493 | padding = _int_to_tuple(padding, len(dims)) 494 | dilation = _int_to_tuple(dilation, len(dims)) 495 | 496 | assert len(dims) == len(kernels) == len(stride) == len(padding) == len(dilation) 497 | 498 | out_dims = [ 499 | math.floor((dim + 2 * pad - dil * (ker - 1) - 1) / st + 1) 500 | for (dim, pad, dil, ker, st) in zip(dims, padding, dilation, kernels, stride) 501 | ] 502 | 503 | return PrePass( 504 | (batch, out_chan, *out_dims), 505 | same([input, weight], interfaces.bat(input), trivial), 506 | ) 507 | 508 | 509 | def conv_transpose( 510 | input: TensorLike, 511 | weight: TensorLike, 512 | bias: TensorLike | None = None, 513 | stride: int | Tuple[int, ...] = 1, 514 | padding: int | Tuple[int, ...] = 0, 515 | output_padding: int | Tuple[int, ...] = 0, 516 | groups: int = 1, 517 | dilation: int | Tuple[int, ...] 
= 1,
518 |     *args: Any,
519 |     **kwargs: Any,
520 | ) -> PrePass:
521 |     mute_unused_args(groups, *args, **kwargs)
522 | 
523 |     (batch, chan, *dims) = input.size()
524 |     (in_chan, out_chan, *kernels) = weight.size()
525 | 
526 |     assert chan == in_chan
527 | 
528 |     if bias is not None:
529 |         assert shapes.coerce(bias.size(), (out_chan,)) is not None
530 | 
531 |     stride = _int_to_tuple(stride, len(dims))
532 |     padding = _int_to_tuple(padding, len(dims))
533 |     output_padding = _int_to_tuple(output_padding, len(dims))
534 |     dilation = _int_to_tuple(dilation, len(dims))
535 | 
536 |     assert len(dims) == len(kernels) == len(stride) == len(padding) == len(dilation)
537 | 
538 |     out_dims = [
539 |         (dim - 1) * st - 2 * pad + dil * (ker - 1) + opad + 1
540 |         for (dim, st, pad, dil, ker, opad) in zip(
541 |             dims, stride, padding, dilation, kernels, output_padding
542 |         )
543 |     ]
544 | 
545 |     return PrePass(
546 |         (batch, out_chan, *out_dims),
547 |         same([input, weight], interfaces.bat(input), trivial),
548 |     )
549 | 
550 | 
551 | def pool(
552 |     input: TensorLike,
553 |     *,
554 |     kernel_size: int | Tuple[int, ...],
555 |     stride: int | Tuple[int, ...] = (),
556 |     padding: int | Tuple[int, ...] = 0,
557 |     dilation: int | Tuple[int, ...] = 1,
558 |     ceil_mode: bool = False,
559 | ) -> PrePass:
560 |     (batch, chan, *dims) = input.size()
561 | 
562 |     kernel_size = _int_to_tuple(kernel_size, len(dims))
563 |     stride = _int_to_tuple(stride or kernel_size, len(dims))  # an empty stride falls back to kernel_size, matching torch's default
564 |     padding = _int_to_tuple(padding, len(dims))
565 |     dilation = _int_to_tuple(dilation, len(dims))
566 | 
567 |     rounding = math.ceil if ceil_mode else math.floor
568 |     out_dims = [
569 |         rounding((dim + 2 * pad - dil * (ker - 1) - 1) / st + 1)
570 |         for (dim, pad, dil, ker, st) in zip(
571 |             dims, padding, dilation, kernel_size, stride
572 |         )
573 |     ]
574 | 
575 |     return PrePass(
576 |         (batch, chan, *out_dims), same([input], interfaces.bat(input), trivial)
577 |     )
578 | 
579 | 
580 | def maxpool(
581 |     input: TensorLike,
582 |     kernel_size: int | Tuple[int, ...],
583 |     stride: int | Tuple[int, ...] = (),
584 |     padding: int | Tuple[int, ...] = 0,
585 |     dilation: int | Tuple[int, ...] = 1,
586 |     ceil_mode: bool = False,
587 |     return_indices: bool = False,
588 |     *args: Any,
589 |     **kwargs: Any,
590 | ) -> PrePass:
591 |     mute_unused_args(*args, **kwargs)
592 | 
593 |     if return_indices:
594 |         raise UnsupportedError
595 | 
596 |     return pool(
597 |         input,
598 |         kernel_size=kernel_size,
599 |         stride=stride,
600 |         padding=padding,
601 |         dilation=dilation,
602 |         ceil_mode=ceil_mode,
603 |     )
604 | 
605 | 
606 | def avgpool(
607 |     input: TensorLike,
608 |     kernel_size: int | Tuple[int, ...],
609 |     stride: int | Tuple[int, ...] = (),
610 |     padding: int | Tuple[int, ...] 
= 0, 611 | ceil_mode: bool = False, 612 | *args: Any, 613 | **kwargs: Any, 614 | ) -> PrePass: 615 | mute_unused_args(*args, **kwargs) 616 | 617 | return pool( 618 | input, 619 | kernel_size=kernel_size, 620 | stride=stride, 621 | padding=padding, 622 | ceil_mode=ceil_mode, 623 | ) 624 | -------------------------------------------------------------------------------- /src/koila/shapes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | from __future__ import annotations 4 | 5 | import functools 6 | import logging 7 | import operator 8 | from typing import Set, Tuple 9 | 10 | from rich.logging import RichHandler 11 | 12 | LOGGER = logging.getLogger(__name__) 13 | LOGGER.addHandler(RichHandler()) 14 | 15 | 16 | def compatible_dim(input: int, other: int, broadcast: bool = True) -> bool: 17 | if broadcast: 18 | return input == 1 or other == 1 or input == other 19 | else: 20 | return input == other 21 | 22 | 23 | def prepends( 24 | input: Tuple[int, ...], other: Tuple[int, ...], value: int 25 | ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: 26 | LOGGER.debug("Prepending %s and %s.", input, other) 27 | 28 | prepended = (value,) * abs(len(input) - len(other)) 29 | if len(input) >= len(other): 30 | other = prepended + other 31 | else: 32 | input = prepended + input 33 | assert len(input) == len(other) 34 | return (input, other) 35 | 36 | 37 | def coerce( 38 | input: Tuple[int, ...], 39 | other: Tuple[int, ...], 40 | broadcast: bool = True, 41 | scalars: bool = True, 42 | ) -> Tuple[int, ...] | None: 43 | LOGGER.debug( 44 | "Coercing %s and %s. Broadcasting: %s. Allow scalars: %s.", 45 | input, 46 | other, 47 | broadcast, 48 | scalars, 49 | ) 50 | 51 | if scalars: 52 | if len(input) == 0: 53 | return other 54 | 55 | if len(other) == 0: 56 | return input 57 | 58 | if not broadcast: 59 | if (shape := input) == other: 60 | return shape 61 | else: 62 | return None 63 | 64 | (input, other) = prepends(input, other, 1) 65 | 66 | shape = [] 67 | for a, b in zip(input, other): 68 | if a <= 0 or b <= 0: 69 | raise ValueError 70 | 71 | if compatible_dim(a, b): 72 | shape.append(max(a, b)) 73 | else: 74 | return None 75 | 76 | return tuple(shape) 77 | 78 | 79 | def permute(input: Tuple[int, ...], *dims: int) -> Tuple[int, ...]: 80 | LOGGER.debug("%s, %s", input, dims) 81 | 82 | if not len(input) == len(dims): 83 | raise TypeError 84 | 85 | if sorted(dims) != list(range(len(input))): 86 | raise ValueError 87 | 88 | if not len(set(dims)) == len(input): 89 | raise ValueError 90 | 91 | dims_order_pair = sorted(enumerate(dims), key=lambda pair: pair[1]) 92 | scattered_dims = [pair[0] for pair in dims_order_pair] 93 | paired = sorted(zip(scattered_dims, input)) 94 | reordered_dim = [pair[1] for pair in paired] 95 | return tuple(reordered_dim) 96 | 97 | 98 | def reshape(input: Tuple[int, ...], *shape: int) -> Tuple[int, ...]: 99 | LOGGER.debug("%s, %s", input, shape) 100 | 101 | if not functools.reduce(operator.mul, input) == functools.reduce( 102 | operator.mul, shape 103 | ): 104 | raise ValueError 105 | return shape 106 | 107 | 108 | def view(input: Tuple[int, ...], *shape: int) -> Tuple[int, ...]: 109 | LOGGER.debug("%s, %s", input, shape) 110 | 111 | special_values = [x for x in shape if x < 0] 112 | 113 | if len(special_values) > 1: 114 | raise ValueError 115 | 116 | if set(special_values) | {-1} != {-1}: 117 | raise ValueError 118 | 119 | special = -( 120 | functools.reduce(operator.mul, input) // 
functools.reduce(operator.mul, shape) 121 | ) 122 | new_shape = [] 123 | for s in shape: 124 | if s > 0: 125 | new_shape.append(s) 126 | else: 127 | new_shape.append(special) 128 | 129 | return reshape(input, *new_shape) 130 | 131 | 132 | def tranpose(input: Tuple[int, ...], dim0: int, dim1: int) -> Tuple[int, ...]: 133 | LOGGER.debug("%s, %d, %d", input, dim0, dim1) 134 | 135 | if len(input) < 2: 136 | raise ValueError 137 | 138 | shapes = list(input) 139 | (shapes[dim0], shapes[dim1]) = (shapes[dim1], shapes[dim0]) 140 | return tuple(shapes) 141 | 142 | 143 | def matmul(input: Tuple[int, ...], other: Tuple[int, ...]) -> Tuple[int, ...]: 144 | LOGGER.debug("%s, %s", input, other) 145 | 146 | if len(input) == 0 or len(other) == 0: 147 | raise ValueError( 148 | "Both arguments to matmul need to be at least 1D." 149 | " " 150 | f"Got {len(input)}D and {len(other)}D." 151 | ) 152 | 153 | if len(input) == len(other) == 1: 154 | if input[0] != other[0]: 155 | raise ValueError 156 | 157 | return () 158 | 159 | if len(input) == len(other) == 2: 160 | if input[1] != other[0]: 161 | raise ValueError 162 | 163 | return (input[0], other[1]) 164 | 165 | if len(input) == 1 and len(other) == 2: 166 | if input[0] != other[0]: 167 | raise ValueError 168 | 169 | return (other[1],) 170 | 171 | if len(input) == 2 and len(other) == 1: 172 | if input[1] != other[0]: 173 | raise ValueError 174 | 175 | return (input[0],) 176 | 177 | (input, other) = prepends(input, other, 1) 178 | 179 | shapes = [] 180 | for dimi, dimo in zip(input[:-2], other[:-2]): 181 | if not compatible_dim(dimi, dimo): 182 | raise ValueError 183 | shapes.append(max(dimi, dimo)) 184 | 185 | if input[-1] != other[-2]: 186 | raise ValueError 187 | 188 | shapes.extend([input[-2], other[-1]]) 189 | 190 | return tuple(shapes) 191 | 192 | 193 | def reduce_dims( 194 | input: Tuple[int, ...], 195 | dim: int | Tuple[int, ...] 
| None = None, 196 | keepdim: bool = False, 197 | ) -> Tuple[Tuple[int, ...], Set[int]]: 198 | LOGGER.debug("%s, %s", input, dim) 199 | 200 | shapes = [] 201 | 202 | if dim is None: 203 | dimensions = set(range(len(input))) 204 | elif isinstance(dim, int): 205 | dimensions = {dim} 206 | else: 207 | dimensions = set(dim) 208 | 209 | for idx, dimsize in enumerate(input): 210 | if idx not in dimensions: 211 | shapes.append(dimsize) 212 | continue 213 | 214 | if keepdim: 215 | shapes.append(1) 216 | 217 | if keepdim: 218 | assert len(shapes) == len(input) 219 | 220 | return (tuple(shapes), dimensions) 221 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | from __future__ import annotations 4 | 5 | import dataclasses as dcls 6 | import math 7 | import typing 8 | from dataclasses import dataclass 9 | from typing import Any, Callable, Dict, Sequence 10 | 11 | import numpy as np 12 | import torch 13 | from numpy import ndarray 14 | from torch import Tensor 15 | from torch.types import Number 16 | 17 | 18 | @dataclass(init=False) 19 | class ArgsKwargs: 20 | def __init__(self, *args: Any, **kwargs: Any) -> None: 21 | self.args = args 22 | self.kwargs = kwargs 23 | 24 | args: Sequence[Any] = dcls.field(default_factory=tuple) 25 | kwargs: Dict[str, Any] = dcls.field(default_factory=dict) 26 | 27 | 28 | @dataclass(init=False) 29 | class Caller: 30 | func: Callable[..., Any] 31 | arguments: Sequence[ArgsKwargs] = dcls.field(default_factory=list) 32 | 33 | def __init__( 34 | self, 35 | func: Callable[..., Any], 36 | arguments: Sequence[ArgsKwargs | Sequence[Any] | Dict[str, Any]], 37 | ) -> None: 38 | self.func = func 39 | self.arguments = [] 40 | 41 | for argument in arguments: 42 | if isinstance(argument, Sequence): 43 | argument = ArgsKwargs(*argument) 44 | 45 | if isinstance(argument, dict): 46 | assert all(isinstance(key, str) for key in argument.keys()) 47 | argument = ArgsKwargs(**argument) 48 | 49 | self.arguments.append(argument) 50 | 51 | def call(self) -> None: 52 | for argument in self.arguments: 53 | self.func(*argument.args, **argument.kwargs) 54 | 55 | 56 | def call( 57 | func: Callable[..., Any], 58 | arguments: Sequence[ArgsKwargs | Sequence[Any] | Dict[str, Any]], 59 | ) -> None: 60 | Caller(func, arguments=arguments).call() 61 | 62 | 63 | def assert_equal( 64 | input: Tensor | ndarray | Number, other: Tensor | ndarray | Number 65 | ) -> None: 66 | if isinstance(input, ndarray) or isinstance(other, ndarray): 67 | assert np.all(input == other), input != other 68 | return 69 | 70 | if isinstance(input, Tensor) or isinstance(other, Tensor): 71 | assert typing.cast(Tensor, input == other).all(), input != other 72 | return 73 | 74 | assert input == other, [input, other] 75 | 76 | 77 | def assert_isclose( 78 | input: Tensor | ndarray | Number, other: Tensor | ndarray | Number 79 | ) -> None: 80 | if isinstance(input, ndarray) or isinstance(other, ndarray): 81 | assert np.allclose(input, other, atol=1e-5), [input, other] 82 | return 83 | 84 | if isinstance(input, Tensor) and isinstance(other, Tensor): 85 | assert torch.allclose(input, other, atol=1e-5), [input, other] 
86 | return 87 | 88 | assert math.isclose(input, other, abs_tol=1e-5), [input, other] 89 | 90 | 91 | def is_notimplemented(func: Callable[[], Any]) -> bool: 92 | try: 93 | func() 94 | return False 95 | except: 96 | return True 97 | -------------------------------------------------------------------------------- /tests/test_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import ( 6 | AvgPool1d, 7 | AvgPool2d, 8 | AvgPool3d, 9 | BatchNorm1d, 10 | BatchNorm2d, 11 | BatchNorm3d, 12 | Conv1d, 13 | Conv2d, 14 | Conv3d, 15 | ConvTranspose1d, 16 | ConvTranspose2d, 17 | ConvTranspose3d, 18 | Dropout, 19 | Embedding, 20 | LayerNorm, 21 | LeakyReLU, 22 | Linear, 23 | MaxPool1d, 24 | MaxPool2d, 25 | MaxPool3d, 26 | ReLU, 27 | Sigmoid, 28 | Softmax, 29 | ) 30 | 31 | import koila 32 | from koila import LazyTensor 33 | 34 | from . import common 35 | 36 | 37 | def test_linear_layer() -> None: 38 | arr = torch.randn(7, 11, 13) 39 | la = koila.lazy(arr) 40 | layer = Linear(13, 17) 41 | 42 | out = layer(arr) 43 | assert out.shape == (7, 11, 17) 44 | assert not isinstance(out, LazyTensor) 45 | assert isinstance(out, Tensor) 46 | 47 | assert isinstance(la, LazyTensor) 48 | lo = layer(la) 49 | assert lo.shape == (7, 11, 17) 50 | assert not isinstance(lo, Tensor) 51 | assert isinstance(lo, LazyTensor) 52 | common.assert_isclose(lo.run(), out) 53 | 54 | 55 | def test_batchnorm_layers() -> None: 56 | # 1D 57 | arr = torch.randn(3, 5, 7) 58 | la = koila.lazy(arr) 59 | layer = BatchNorm1d(5) 60 | 61 | out = layer(arr) 62 | assert out.shape == (3, 5, 7) 63 | assert not isinstance(out, LazyTensor) 64 | assert isinstance(out, Tensor) 65 | 66 | assert isinstance(la, LazyTensor) 67 | lo = layer(la) 68 | assert lo.shape == (3, 5, 7) 69 | assert not isinstance(lo, Tensor) 70 | assert isinstance(lo, LazyTensor) 71 | common.assert_isclose(lo.run(), out) 72 | 73 | # 2D 74 | arr = torch.randn(3, 5, 7, 11) 75 | la = koila.lazy(arr) 76 | layer = BatchNorm2d(5) 77 | 78 | out = layer(arr) 79 | assert out.shape == (3, 5, 7, 11) 80 | assert not isinstance(out, LazyTensor) 81 | assert isinstance(out, Tensor) 82 | 83 | assert isinstance(la, LazyTensor) 84 | lo = layer(la) 85 | assert lo.shape == (3, 5, 7, 11) 86 | assert not isinstance(lo, Tensor) 87 | assert isinstance(lo, LazyTensor) 88 | common.assert_isclose(lo.run(), out) 89 | 90 | # 3D 91 | arr = torch.randn(3, 5, 7, 11, 13) 92 | la = koila.lazy(arr) 93 | layer = BatchNorm3d(5) 94 | 95 | out = layer(arr) 96 | assert out.shape == (3, 5, 7, 11, 13) 97 | assert not isinstance(out, LazyTensor) 98 | assert isinstance(out, Tensor) 99 | 100 | assert isinstance(la, LazyTensor) 101 | lo = layer(la) 102 | assert lo.shape == (3, 5, 7, 11, 13) 103 | assert not isinstance(lo, Tensor) 104 | assert isinstance(lo, LazyTensor) 105 | common.assert_isclose(lo.run(), out) 106 | 107 | 108 | def test_layernorm_layers() -> None: 109 | # 1D 110 | arr = torch.randn(3, 5, 7) 111 | la = koila.lazy(arr) 112 | layer = LayerNorm([5, 7]) 113 | 114 | out = layer(arr) 115 | assert out.shape == (3, 5, 7) 116 | assert not isinstance(out, LazyTensor) 117 | assert isinstance(out, Tensor) 118 | 119 | assert isinstance(la, LazyTensor) 120 | lo = layer(la) 121 | assert lo.shape == (3, 5, 7) 122 | assert not isinstance(lo, Tensor) 123 | assert isinstance(lo, LazyTensor) 124 | common.assert_isclose(lo.run(), out) 125 | 126 | 127 | def test_dropout_layer() 
-> None: 128 | arr = torch.randn(7, 11) 129 | la = koila.lazy(arr) 130 | layer = Dropout(p=0.5) 131 | 132 | out = layer(arr) 133 | assert out.shape == (7, 11) 134 | assert not isinstance(out, LazyTensor) 135 | assert isinstance(out, Tensor) 136 | 137 | assert isinstance(la, LazyTensor) 138 | lo = layer(la) 139 | assert lo.shape == (7, 11) 140 | assert not isinstance(lo, Tensor) 141 | assert isinstance(lo, LazyTensor) 142 | 143 | 144 | def test_relu_layer() -> None: 145 | arr = torch.randn(7, 11) 146 | la = koila.lazy(arr) 147 | layer = ReLU() 148 | 149 | out = layer(arr) 150 | assert out.shape == (7, 11) 151 | assert not isinstance(out, LazyTensor) 152 | assert isinstance(out, Tensor) 153 | 154 | assert isinstance(la, LazyTensor) 155 | lo = layer(la) 156 | assert lo.shape == (7, 11) 157 | assert not isinstance(lo, Tensor) 158 | assert isinstance(lo, LazyTensor) 159 | common.assert_isclose(lo.run(), out) 160 | 161 | 162 | def test_leaky_relu_layer() -> None: 163 | arr = torch.randn(7, 11) 164 | la = koila.lazy(arr) 165 | layer = LeakyReLU(negative_slope=0.3) 166 | 167 | out = layer(arr) 168 | assert out.shape == (7, 11) 169 | assert not isinstance(out, LazyTensor) 170 | assert isinstance(out, Tensor) 171 | 172 | assert isinstance(la, LazyTensor) 173 | lo = layer(la) 174 | assert lo.shape == (7, 11) 175 | assert not isinstance(lo, Tensor) 176 | assert isinstance(lo, LazyTensor) 177 | common.assert_isclose(lo.run(), out) 178 | 179 | 180 | def test_sigmoid_layer() -> None: 181 | arr = torch.randn(7, 11) 182 | la = koila.lazy(arr) 183 | layer = Sigmoid() 184 | 185 | out = layer(arr) 186 | assert out.shape == (7, 11) 187 | assert not isinstance(out, LazyTensor) 188 | assert isinstance(out, Tensor) 189 | 190 | assert isinstance(la, LazyTensor) 191 | lo = layer(la) 192 | assert lo.shape == (7, 11) 193 | assert not isinstance(lo, Tensor) 194 | assert isinstance(lo, LazyTensor) 195 | common.assert_isclose(lo.run(), out) 196 | 197 | 198 | def test_softmax_layer() -> None: 199 | arr = torch.randn(7, 11) 200 | la = koila.lazy(arr) 201 | layer = Softmax(dim=-1) 202 | 203 | out = layer(arr) 204 | assert out.shape == (7, 11) 205 | assert not isinstance(out, LazyTensor) 206 | assert isinstance(out, Tensor) 207 | 208 | assert isinstance(la, LazyTensor) 209 | lo = layer(la) 210 | assert lo.shape == (7, 11) 211 | assert not isinstance(lo, Tensor) 212 | assert isinstance(lo, LazyTensor) 213 | common.assert_isclose(lo.run(), out) 214 | 215 | 216 | def test_conv_layer() -> None: 217 | # 1D 218 | arr = torch.randn(7, 11, 13) 219 | la = koila.lazy(arr) 220 | layer = Conv1d(11, 17, kernel_size=3, stride=2) 221 | 222 | out = layer(arr) 223 | assert not isinstance(out, LazyTensor) 224 | assert isinstance(out, Tensor) 225 | 226 | assert isinstance(la, LazyTensor) 227 | lo = layer(la) 228 | assert not isinstance(lo, Tensor) 229 | assert isinstance(lo, LazyTensor) 230 | assert lo.shape == out.shape 231 | common.assert_isclose(lo.run(), out) 232 | 233 | # 2D 234 | arr = torch.randn(7, 11, 13, 14) 235 | la = koila.lazy(arr) 236 | layer = Conv2d(11, 17, kernel_size=3, stride=2) 237 | 238 | out = layer(arr) 239 | assert not isinstance(out, LazyTensor) 240 | assert isinstance(out, Tensor) 241 | 242 | assert isinstance(la, LazyTensor) 243 | lo = layer(la) 244 | assert not isinstance(lo, Tensor) 245 | assert isinstance(lo, LazyTensor) 246 | assert lo.shape == out.shape 247 | common.assert_isclose(lo.run(), out) 248 | 249 | # 3D 250 | arr = torch.randn(7, 11, 13, 14, 15) 251 | la = koila.lazy(arr) 252 | layer = Conv3d(11, 17, 
kernel_size=3, stride=2) 253 | 254 | out = layer(arr) 255 | assert not isinstance(out, LazyTensor) 256 | assert isinstance(out, Tensor) 257 | 258 | assert isinstance(la, LazyTensor) 259 | lo = layer(la) 260 | assert not isinstance(lo, Tensor) 261 | assert isinstance(lo, LazyTensor) 262 | assert lo.shape == out.shape 263 | common.assert_isclose(lo.run(), out) 264 | 265 | 266 | def test_convtranspose_layer() -> None: 267 | # 1D 268 | arr = torch.randn(7, 11, 13) 269 | la = koila.lazy(arr) 270 | layer = ConvTranspose1d(11, 17, kernel_size=3, stride=2) 271 | 272 | out = layer(arr) 273 | assert not isinstance(out, LazyTensor) 274 | assert isinstance(out, Tensor) 275 | 276 | assert isinstance(la, LazyTensor) 277 | lo = layer(la) 278 | assert not isinstance(lo, Tensor) 279 | assert isinstance(lo, LazyTensor) 280 | assert lo.shape == out.shape 281 | common.assert_isclose(lo.run(), out) 282 | 283 | # 2D 284 | arr = torch.randn(7, 11, 13, 14) 285 | la = koila.lazy(arr) 286 | layer = ConvTranspose2d(11, 17, kernel_size=3, stride=2) 287 | 288 | out = layer(arr) 289 | assert not isinstance(out, LazyTensor) 290 | assert isinstance(out, Tensor) 291 | 292 | assert isinstance(la, LazyTensor) 293 | lo = layer(la) 294 | assert not isinstance(lo, Tensor) 295 | assert isinstance(lo, LazyTensor) 296 | assert lo.shape == out.shape 297 | common.assert_isclose(lo.run(), out) 298 | 299 | # 3D 300 | arr = torch.randn(7, 11, 13, 14, 15) 301 | la = koila.lazy(arr) 302 | layer = ConvTranspose3d(11, 17, kernel_size=3, stride=2) 303 | 304 | out = layer(arr) 305 | assert not isinstance(out, LazyTensor) 306 | assert isinstance(out, Tensor) 307 | 308 | assert isinstance(la, LazyTensor) 309 | lo = layer(la) 310 | assert not isinstance(lo, Tensor) 311 | assert isinstance(lo, LazyTensor) 312 | assert lo.shape == out.shape 313 | common.assert_isclose(lo.run(), out) 314 | 315 | 316 | def test_maxpool_layer() -> None: 317 | # 1D 318 | arr = torch.randn(7, 11, 13) 319 | la = koila.lazy(arr) 320 | layer = MaxPool1d(kernel_size=3, stride=2) 321 | 322 | out = layer(arr) 323 | assert not isinstance(out, LazyTensor) 324 | assert isinstance(out, Tensor) 325 | 326 | assert isinstance(la, LazyTensor) 327 | lo = layer(la) 328 | assert not isinstance(lo, Tensor) 329 | assert isinstance(lo, LazyTensor) 330 | assert lo.shape == out.shape 331 | common.assert_isclose(lo.run(), out) 332 | 333 | # 2D 334 | arr = torch.randn(7, 11, 13, 14) 335 | la = koila.lazy(arr) 336 | layer = MaxPool2d(kernel_size=3, stride=2) 337 | 338 | out = layer(arr) 339 | assert not isinstance(out, LazyTensor) 340 | assert isinstance(out, Tensor) 341 | 342 | assert isinstance(la, LazyTensor) 343 | lo = layer(la) 344 | assert not isinstance(lo, Tensor) 345 | assert isinstance(lo, LazyTensor) 346 | assert lo.shape == out.shape 347 | common.assert_isclose(lo.run(), out) 348 | 349 | # 3D 350 | arr = torch.randn(7, 11, 13, 14, 15) 351 | la = koila.lazy(arr) 352 | layer = MaxPool3d(kernel_size=3, stride=2) 353 | 354 | out = layer(arr) 355 | assert not isinstance(out, LazyTensor) 356 | assert isinstance(out, Tensor) 357 | 358 | assert isinstance(la, LazyTensor) 359 | lo = layer(la) 360 | assert not isinstance(lo, Tensor) 361 | assert isinstance(lo, LazyTensor) 362 | assert lo.shape == out.shape 363 | common.assert_isclose(lo.run(), out) 364 | 365 | 366 | def test_avgpool_layer() -> None: 367 | # 1D 368 | arr = torch.randn(7, 11, 13) 369 | la = koila.lazy(arr) 370 | layer = AvgPool1d(kernel_size=3, stride=2) 371 | 372 | out = layer(arr) 373 | assert not isinstance(out, 
LazyTensor) 374 | assert isinstance(out, Tensor) 375 | 376 | assert isinstance(la, LazyTensor) 377 | lo = layer(la) 378 | assert not isinstance(lo, Tensor) 379 | assert isinstance(lo, LazyTensor) 380 | assert lo.shape == out.shape 381 | common.assert_isclose(lo.run(), out) 382 | 383 | # 2D 384 | arr = torch.randn(7, 11, 13, 14) 385 | la = koila.lazy(arr) 386 | layer = AvgPool2d(kernel_size=3, stride=2) 387 | 388 | out = layer(arr) 389 | assert not isinstance(out, LazyTensor) 390 | assert isinstance(out, Tensor) 391 | 392 | assert isinstance(la, LazyTensor) 393 | lo = layer(la) 394 | assert not isinstance(lo, Tensor) 395 | assert isinstance(lo, LazyTensor) 396 | assert lo.shape == out.shape 397 | common.assert_isclose(lo.run(), out) 398 | 399 | # 3D 400 | arr = torch.randn(7, 11, 13, 14, 15) 401 | la = koila.lazy(arr) 402 | layer = AvgPool3d(kernel_size=3, stride=2) 403 | 404 | out = layer(arr) 405 | assert not isinstance(out, LazyTensor) 406 | assert isinstance(out, Tensor) 407 | 408 | assert isinstance(la, LazyTensor) 409 | lo = layer(la) 410 | assert not isinstance(lo, Tensor) 411 | assert isinstance(lo, LazyTensor) 412 | assert lo.shape == out.shape 413 | common.assert_isclose(lo.run(), out) 414 | 415 | 416 | def test_embedding_layer() -> None: 417 | arr = torch.randint(0, 11, [5]) 418 | la = koila.lazy(arr) 419 | layer = Embedding(num_embeddings=11, embedding_dim=13) 420 | 421 | out = layer(arr) 422 | assert out.shape == (5, 13) 423 | assert not isinstance(out, LazyTensor) 424 | assert isinstance(out, Tensor) 425 | 426 | assert isinstance(la, LazyTensor) 427 | lo = layer(la) 428 | assert lo.shape == (5, 13) 429 | assert not isinstance(lo, Tensor) 430 | assert isinstance(lo, LazyTensor) 431 | common.assert_isclose(lo.run(), out) 432 | -------------------------------------------------------------------------------- /tests/test_lazy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | import math 4 | import typing 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | from torch.nn import functional as F 10 | 11 | import koila 12 | from koila import Evaluation, LazyTensor, Runnable, RunnableTensor 13 | 14 | from . 
import common 15 | 16 | 17 | def test_lazytensor_is_runnable() -> None: 18 | assert issubclass(Evaluation, Runnable) 19 | assert issubclass(Evaluation, RunnableTensor) 20 | assert issubclass(LazyTensor, Runnable) 21 | assert issubclass(LazyTensor, RunnableTensor) 22 | 23 | 24 | def test_positive_op() -> None: 25 | common.call( 26 | lambda a, c: common.assert_isclose((+a).item(), c), 27 | [[LazyTensor(torch.tensor(-11)), -11]], 28 | ) 29 | 30 | 31 | def test_positive_method() -> None: 32 | common.call( 33 | lambda a, c: common.assert_isclose(a.positive().item(), c), 34 | [[LazyTensor(torch.tensor(4)), 4]], 35 | ) 36 | 37 | 38 | def test_positive_function() -> None: 39 | common.call( 40 | lambda a, c: common.assert_isclose(torch.positive(a).item(), c), 41 | [[LazyTensor(torch.tensor(-8)), -8]], 42 | ) 43 | 44 | 45 | def test_negative_op() -> None: 46 | common.call( 47 | lambda a, c: common.assert_isclose((-a).item(), c), 48 | [[LazyTensor(torch.tensor(-13)), 13]], 49 | ) 50 | 51 | 52 | def test_negative_method() -> None: 53 | common.call( 54 | lambda a, c: common.assert_isclose(a.neg().item(), c), 55 | [[LazyTensor(torch.tensor(2)), -2]], 56 | ) 57 | 58 | 59 | def test_negative_function() -> None: 60 | common.call( 61 | lambda a, c: common.assert_equal(torch.neg(a).item(), c), 62 | [[LazyTensor(torch.tensor(-5)), 5]], 63 | ) 64 | 65 | 66 | def test_eq_ne_op() -> None: 67 | arr = torch.randint(0, 2, [2, 3, 4]) 68 | brr = torch.randint(0, 2, [2, 3, 4]) 69 | la = typing.cast(Tensor, LazyTensor(arr)) 70 | lb = typing.cast(Tensor, LazyTensor(brr)) 71 | common.call( 72 | lambda a, c: common.assert_equal(koila.run(a), c), 73 | [[la == lb, arr == brr], [la != lb, arr != brr]], 74 | ) 75 | 76 | 77 | def test_cmp_op() -> None: 78 | arr = torch.randint(0, 5, [2, 3, 4]) 79 | brr = torch.randint(0, 5, [2, 3, 4]) 80 | la = typing.cast(Tensor, LazyTensor(arr)) 81 | lb = typing.cast(Tensor, LazyTensor(brr)) 82 | common.call( 83 | lambda a, c: common.assert_equal(koila.run(a), c), 84 | [ 85 | [la < lb, arr < brr], 86 | [la <= lb, arr <= brr], 87 | [la > lb, arr > brr], 88 | [la >= lb, arr >= brr], 89 | ], 90 | ) 91 | 92 | 93 | def test_add_op() -> None: 94 | common.call( 95 | lambda a, b, c: common.assert_isclose((a + b).item(), c), 96 | [ 97 | [LazyTensor(torch.tensor(1)), LazyTensor(torch.tensor(2)), 1 + 2], 98 | [torch.tensor(1), LazyTensor(torch.tensor(2)), 1 + 2], 99 | [LazyTensor(torch.tensor(1)), torch.tensor(2), 1 + 2], 100 | ], 101 | ) 102 | 103 | 104 | def test_add_method() -> None: 105 | common.call( 106 | lambda a, b, c: common.assert_isclose(a.add(b).item(), c), 107 | [ 108 | [LazyTensor(torch.tensor(4)), LazyTensor(torch.tensor(3)), 4 + 3], 109 | [torch.tensor(4), LazyTensor(torch.tensor(3)), 4 + 3], 110 | [LazyTensor(torch.tensor(4)), torch.tensor(3), 4 + 3], 111 | ], 112 | ) 113 | 114 | 115 | def test_add_function() -> None: 116 | common.call( 117 | lambda a, b, c: common.assert_isclose(torch.add(a, b).item(), c), 118 | [ 119 | [LazyTensor(torch.tensor(8)), LazyTensor(torch.tensor(4)), 8 + 4], 120 | [torch.tensor(8), LazyTensor(torch.tensor(4)), 8 + 4], 121 | [LazyTensor(torch.tensor(8)), torch.tensor(4), 8 + 4], 122 | ], 123 | ) 124 | 125 | 126 | def test_sub_op() -> None: 127 | common.call( 128 | lambda a, b, c: common.assert_isclose((a - b).item(), c), 129 | [ 130 | [LazyTensor(torch.tensor(1)), LazyTensor(torch.tensor(2)), 1 - 2], 131 | [torch.tensor(1), LazyTensor(torch.tensor(2)), 1 - 2], 132 | [LazyTensor(torch.tensor(1)), torch.tensor(2), 1 - 2], 133 | ], 134 | ) 135 | 136 | 137 | def 
test_sub_method() -> None: 138 | common.call( 139 | lambda a, b, c: common.assert_isclose(a.sub(b).item(), c), 140 | [ 141 | [LazyTensor(torch.tensor(4)), LazyTensor(torch.tensor(3)), 4 - 3], 142 | [torch.tensor(4), LazyTensor(torch.tensor(3)), 4 - 3], 143 | [LazyTensor(torch.tensor(4)), torch.tensor(3), 4 - 3], 144 | ], 145 | ) 146 | 147 | 148 | def test_sub_function() -> None: 149 | common.call( 150 | lambda a, b, c: common.assert_isclose(torch.sub(a, b).item(), c), 151 | [ 152 | [LazyTensor(torch.tensor(8)), LazyTensor(torch.tensor(4)), 8 - 4], 153 | [torch.tensor(8), LazyTensor(torch.tensor(4)), 8 - 4], 154 | [LazyTensor(torch.tensor(8)), torch.tensor(4), 8 - 4], 155 | ], 156 | ) 157 | 158 | 159 | def test_mul_op() -> None: 160 | common.call( 161 | lambda a, b, c: common.assert_isclose((a * b).item(), c), 162 | [ 163 | [LazyTensor(torch.tensor(0.5)), LazyTensor(torch.tensor(2)), 0.5 * 2], 164 | [torch.tensor(0.5), LazyTensor(torch.tensor(2)), 0.5 * 2], 165 | [LazyTensor(torch.tensor(0.5)), torch.tensor(2), 0.5 * 2], 166 | ], 167 | ) 168 | 169 | 170 | def test_mul_method() -> None: 171 | common.call( 172 | lambda a, b, c: common.assert_isclose(a.mul(b).item(), c), 173 | [ 174 | [LazyTensor(torch.tensor(4)), LazyTensor(torch.tensor(3)), 12], 175 | [torch.tensor(4), LazyTensor(torch.tensor(3)), 12], 176 | [LazyTensor(torch.tensor(4)), torch.tensor(3), 12], 177 | ], 178 | ) 179 | 180 | 181 | def test_mul_function() -> None: 182 | common.call( 183 | lambda a, b, c: common.assert_isclose(torch.mul(a, b).item(), c), 184 | [ 185 | [LazyTensor(torch.tensor(8)), LazyTensor(torch.tensor(4)), 32], 186 | [torch.tensor(8), LazyTensor(torch.tensor(4)), 32], 187 | [LazyTensor(torch.tensor(8)), torch.tensor(4), 32], 188 | ], 189 | ) 190 | 191 | 192 | def test_floordiv_op() -> None: 193 | common.call( 194 | common.is_notimplemented, 195 | [ 196 | [lambda: LazyTensor(torch.tensor(1)) // LazyTensor(torch.tensor(2))], 197 | [lambda: torch.tensor(1) // LazyTensor(torch.tensor(2))], 198 | [lambda: LazyTensor(torch.tensor(1)) // torch.tensor(2)], 199 | ], 200 | ) 201 | 202 | 203 | def test_floordiv_method() -> None: 204 | common.call( 205 | lambda a, b, c: common.assert_isclose( 206 | a.div(b, rounding_mode="trunc").item(), c 207 | ), 208 | [ 209 | [LazyTensor(torch.tensor(4)), LazyTensor(torch.tensor(3)), 4 // 3], 210 | [torch.tensor(4), LazyTensor(torch.tensor(3)), 4 // 3], 211 | [LazyTensor(torch.tensor(4)), torch.tensor(3), 4 // 3], 212 | ], 213 | ) 214 | 215 | 216 | def test_floordiv_function() -> None: 217 | common.call( 218 | lambda a, b, c: common.assert_isclose( 219 | torch.div(a, b, rounding_mode="trunc").item(), c 220 | ), 221 | [ 222 | [LazyTensor(torch.tensor(9)), LazyTensor(torch.tensor(4)), 9 // 4], 223 | [torch.tensor(9), LazyTensor(torch.tensor(4)), 9 // 4], 224 | [LazyTensor(torch.tensor(9)), torch.tensor(4), 9 // 4], 225 | ], 226 | ) 227 | 228 | 229 | def test_truediv_op() -> None: 230 | common.call( 231 | lambda a, b, c: common.assert_isclose((a / b).item(), c), 232 | [ 233 | [LazyTensor(torch.tensor(1)), LazyTensor(torch.tensor(2)), 1 / 2], 234 | [torch.tensor(1), LazyTensor(torch.tensor(2)), 1 / 2], 235 | [LazyTensor(torch.tensor(1)), torch.tensor(2), 1 / 2], 236 | ], 237 | ) 238 | 239 | 240 | def test_truediv_method() -> None: 241 | common.call( 242 | lambda a, b, c: common.assert_isclose(a.div(b).item(), c), 243 | [ 244 | [LazyTensor(torch.tensor(4)), LazyTensor(torch.tensor(3)), 4 / 3], 245 | [torch.tensor(4), LazyTensor(torch.tensor(3)), 4 / 3], 246 | [LazyTensor(torch.tensor(4)), 
torch.tensor(3), 4 / 3], 247 | ], 248 | ) 249 | 250 | 251 | def test_truediv_function() -> None: 252 | common.call( 253 | lambda a, b, c: common.assert_isclose(torch.div(a, b).item(), c), 254 | [ 255 | [LazyTensor(torch.tensor(9)), LazyTensor(torch.tensor(4)), 9 / 4], 256 | [torch.tensor(9), LazyTensor(torch.tensor(4)), 9 / 4], 257 | [LazyTensor(torch.tensor(9)), torch.tensor(4), 9 / 4], 258 | ], 259 | ) 260 | 261 | 262 | def test_pow_op() -> None: 263 | common.call( 264 | lambda a, b, c: common.assert_isclose((a**b).item(), c), 265 | [ 266 | [LazyTensor(torch.tensor(1.5)), LazyTensor(torch.tensor(2)), 1.5**2], 267 | [torch.tensor(1.5), LazyTensor(torch.tensor(2)), 1.5**2], 268 | [LazyTensor(torch.tensor(1.5)), torch.tensor(2), 1.5**2], 269 | ], 270 | ) 271 | 272 | 273 | def test_pow_method() -> None: 274 | common.call( 275 | lambda a, b, c: common.assert_isclose(a.pow(b).item(), c), 276 | [ 277 | [LazyTensor(torch.tensor(4)), LazyTensor(torch.tensor(3)), 4**3], 278 | [torch.tensor(4), LazyTensor(torch.tensor(3)), 4**3], 279 | [LazyTensor(torch.tensor(4)), torch.tensor(3), 4**3], 280 | ], 281 | ) 282 | 283 | 284 | def test_pow_function() -> None: 285 | common.call( 286 | lambda a, b, c: common.assert_isclose(torch.pow(a, b).item(), c), 287 | [ 288 | [LazyTensor(torch.tensor(9.0)), LazyTensor(torch.tensor(-2)), 9.0**-2], 289 | [torch.tensor(9.0), LazyTensor(torch.tensor(-2)), 9.0**-2], 290 | [LazyTensor(torch.tensor(9.0)), torch.tensor(-2), 9.0**-2], 291 | ], 292 | ) 293 | 294 | 295 | def test_remainder_op() -> None: 296 | common.call( 297 | lambda a, b, c: common.assert_isclose((a % b).item(), c), 298 | [ 299 | [LazyTensor(torch.tensor(3.3)), LazyTensor(torch.tensor(1.9)), 3.3 % 1.9], 300 | [torch.tensor(3.3), LazyTensor(torch.tensor(1.9)), 3.3 % 1.9], 301 | [LazyTensor(torch.tensor(3.3)), torch.tensor(1.9), 3.3 % 1.9], 302 | ], 303 | ) 304 | 305 | 306 | def test_remainder_method() -> None: 307 | common.call( 308 | lambda a, b, c: common.assert_isclose(a.remainder(b).item(), c), 309 | [ 310 | [LazyTensor(torch.tensor(99)), LazyTensor(torch.tensor(7)), 99 % 7], 311 | [torch.tensor(99), LazyTensor(torch.tensor(7)), 99 % 7], 312 | [LazyTensor(torch.tensor(99)), torch.tensor(7), 99 % 7], 313 | ], 314 | ) 315 | 316 | 317 | def test_remainder_function() -> None: 318 | common.call( 319 | lambda a, b, c: common.assert_isclose(torch.remainder(a, b).item(), c), 320 | [ 321 | [LazyTensor(torch.tensor(25)), LazyTensor(torch.tensor(7.8)), 25 % 7.8], 322 | [torch.tensor(25), LazyTensor(torch.tensor(7.8)), 25 % 7.8], 323 | [LazyTensor(torch.tensor(25)), torch.tensor(7.8), 25 % 7.8], 324 | ], 325 | ) 326 | 327 | 328 | def test_matmul_op() -> None: 329 | arr = torch.randn(2, 10, 11) 330 | 331 | common.call( 332 | lambda a, b, c: common.assert_isclose(koila.run(a @ b), c), 333 | [ 334 | [LazyTensor(arr[0]), LazyTensor(arr[1].T), arr[0] @ arr[1].T], 335 | [arr[0], LazyTensor(arr[1].T), arr[0] @ arr[1].T], 336 | [LazyTensor(arr[0]), arr[1].T, arr[0] @ arr[1].T], 337 | ], 338 | ) 339 | 340 | 341 | def test_matmul_method() -> None: 342 | arr = torch.randn(2, 10, 11) 343 | 344 | common.call( 345 | lambda a, b, c: common.assert_isclose(koila.run(a.matmul(b)), c), 346 | [ 347 | [LazyTensor(arr[0]), LazyTensor(arr[1].T), arr[0] @ arr[1].T], 348 | [arr[0], LazyTensor(arr[1].T), arr[0] @ arr[1].T], 349 | [LazyTensor(arr[0]), arr[1].T, arr[0] @ arr[1].T], 350 | ], 351 | ) 352 | 353 | 354 | def test_matmul_function() -> None: 355 | arr = torch.randn(2, 10, 11) 356 | 357 | common.call( 358 | lambda a, b, c: 
common.assert_isclose(koila.run(torch.matmul(a, b)), c), 359 | [ 360 | [LazyTensor(arr[0]), LazyTensor(arr[1].T), arr[0] @ arr[1].T], 361 | [arr[0], LazyTensor(arr[1].T), arr[0] @ arr[1].T], 362 | [LazyTensor(arr[0]), arr[1].T, arr[0] @ arr[1].T], 363 | ], 364 | ) 365 | 366 | 367 | def test_identity() -> None: 368 | tensor = torch.tensor(13.5) 369 | 370 | assert LazyTensor(tensor).run() == 13.5 371 | assert LazyTensor(tensor).item() == 13.5 372 | assert int(LazyTensor(tensor)) == 13 373 | assert float(LazyTensor(tensor)) == 13.5 374 | assert bool(LazyTensor(tensor)) 375 | 376 | tensor = torch.tensor(-17.5) 377 | assert LazyTensor(tensor).run() == -17.5 378 | assert LazyTensor(tensor).item() == -17.5 379 | assert int(LazyTensor(tensor)) == -17 380 | assert float(LazyTensor(tensor)) == -17.5 381 | assert bool(LazyTensor(tensor)) 382 | 383 | tensor = torch.tensor(0) 384 | assert not LazyTensor(tensor).run() 385 | assert not LazyTensor(tensor).item() 386 | assert not int(LazyTensor(tensor)) 387 | assert not float(LazyTensor(tensor)) 388 | assert not bool(LazyTensor(tensor)) 389 | 390 | 391 | def test_frac_method() -> None: 392 | common.call( 393 | lambda a, c: common.assert_isclose(a.frac().item(), c), 394 | [ 395 | [LazyTensor(torch.tensor(13.22)), 0.22], 396 | [LazyTensor(torch.tensor(55.0)), 0], 397 | [LazyTensor(torch.tensor(-55.55)), -0.55], 398 | ], 399 | ) 400 | 401 | 402 | def test_frac_function() -> None: 403 | common.call( 404 | lambda a, c: common.assert_isclose(torch.frac(a).item(), c), 405 | [ 406 | [LazyTensor(torch.tensor(25.25)), 0.25], 407 | [LazyTensor(torch.tensor(11.0)), 0], 408 | [LazyTensor(torch.tensor(-25.33)), -0.33], 409 | ], 410 | ) 411 | 412 | 413 | def test_exp_method() -> None: 414 | common.call( 415 | lambda a, c: common.assert_isclose(a.exp().item(), c), 416 | [ 417 | [LazyTensor(torch.tensor(1.23)), math.e**1.23], 418 | [LazyTensor(torch.tensor(0)), 1], 419 | [LazyTensor(torch.tensor(1)), math.e], 420 | ], 421 | ) 422 | 423 | 424 | def test_exp_function() -> None: 425 | common.call( 426 | lambda a, c: common.assert_isclose(torch.exp(a).item(), c), 427 | [ 428 | [LazyTensor(torch.tensor(0.41)), math.e**0.41], 429 | [LazyTensor(torch.tensor(0)), 1], 430 | [LazyTensor(torch.tensor(1)), math.e], 431 | ], 432 | ) 433 | 434 | 435 | def test_exp2_method() -> None: 436 | common.call( 437 | lambda a, c: common.assert_isclose(a.exp2().item(), c), 438 | [ 439 | [LazyTensor(torch.tensor(10)), 2**10], 440 | [LazyTensor(torch.tensor(0)), 1], 441 | [LazyTensor(torch.tensor(1)), 2], 442 | ], 443 | ) 444 | 445 | 446 | def test_exp2_function() -> None: 447 | common.call( 448 | lambda a, c: common.assert_isclose(torch.exp2(a).item(), c), 449 | [ 450 | [LazyTensor(torch.tensor(-5)), 2**-5], 451 | [LazyTensor(torch.tensor(0)), 1], 452 | [LazyTensor(torch.tensor(1)), 2], 453 | ], 454 | ) 455 | 456 | 457 | def test_log_method() -> None: 458 | common.call( 459 | lambda a, c: common.assert_isclose(a.log().item(), c), 460 | [ 461 | [LazyTensor(torch.tensor(13)), math.log(13)], 462 | [LazyTensor(torch.tensor(1)), 0], 463 | [LazyTensor(torch.tensor(math.e)), 1], 464 | ], 465 | ) 466 | 467 | 468 | def test_log_function() -> None: 469 | common.call( 470 | lambda a, c: common.assert_isclose(torch.log(a).item(), c), 471 | [ 472 | [LazyTensor(torch.tensor(5)), math.log(5)], 473 | [LazyTensor(torch.tensor(1)), 0], 474 | [LazyTensor(torch.tensor(math.e)), 1], 475 | ], 476 | ) 477 | 478 | 479 | def test_log2_method() -> None: 480 | common.call( 481 | lambda a, c: 
common.assert_isclose(a.log2().item(), c), 482 | [ 483 | [LazyTensor(torch.tensor(442)), math.log2(442)], 484 | [LazyTensor(torch.tensor(1)), 0], 485 | [LazyTensor(torch.tensor(2)), 1], 486 | ], 487 | ) 488 | 489 | 490 | def test_log2_function() -> None: 491 | common.call( 492 | lambda a, c: common.assert_isclose(torch.log2(a).item(), c), 493 | [ 494 | [LazyTensor(torch.tensor(81)), math.log2(81)], 495 | [LazyTensor(torch.tensor(1)), 0], 496 | [LazyTensor(torch.tensor(2)), 1], 497 | ], 498 | ) 499 | 500 | 501 | def test_log10_method() -> None: 502 | common.call( 503 | lambda a, c: common.assert_isclose(a.log10().item(), c), 504 | [ 505 | [LazyTensor(torch.tensor(132)), math.log10(132)], 506 | [LazyTensor(torch.tensor(1)), 0], 507 | [LazyTensor(torch.tensor(10)), 1], 508 | ], 509 | ) 510 | 511 | 512 | def test_log10_function() -> None: 513 | common.call( 514 | lambda a, c: common.assert_isclose(torch.log10(a).item(), c), 515 | [ 516 | [LazyTensor(torch.tensor(979)), math.log10(979)], 517 | [LazyTensor(torch.tensor(1)), 0], 518 | [LazyTensor(torch.tensor(10)), 1], 519 | ], 520 | ) 521 | 522 | 523 | def test_log1p_method() -> None: 524 | common.call( 525 | lambda a, c: common.assert_isclose(a.log1p().item(), c), 526 | [[LazyTensor(torch.tensor(1.5)), math.log1p(1.5)]], 527 | ) 528 | 529 | 530 | def test_log1p_function() -> None: 531 | common.call( 532 | lambda a, c: common.assert_isclose(torch.log1p(a).item(), c), 533 | [[LazyTensor(torch.tensor(2.7)), math.log1p(2.7)]], 534 | ) 535 | 536 | 537 | def test_abs_op() -> None: 538 | common.call( 539 | lambda a, c: common.assert_isclose(abs(a).item(), c), 540 | [ 541 | [LazyTensor(torch.tensor(-7.122)), abs(-7.122)], 542 | [LazyTensor(torch.tensor(4.002)), abs(4.002)], 543 | ], 544 | ) 545 | 546 | 547 | def test_abs_method() -> None: 548 | common.call( 549 | lambda a, c: common.assert_isclose(a.abs().item(), c), 550 | [ 551 | [LazyTensor(torch.tensor(-1.5)), abs(-1.5)], 552 | [LazyTensor(torch.tensor(3.7)), abs(3.7)], 553 | ], 554 | ) 555 | 556 | 557 | def test_abs_function() -> None: 558 | common.call( 559 | lambda a, c: common.assert_isclose(torch.abs(a).item(), c), 560 | [ 561 | [LazyTensor(torch.tensor(0.001)), abs(0.001)], 562 | [LazyTensor(torch.tensor(-24)), abs(-24)], 563 | ], 564 | ) 565 | 566 | 567 | def test_min_method() -> None: 568 | arr = torch.randn(6, 7, 8) 569 | 570 | common.call( 571 | lambda a, c: common.assert_isclose(koila.run(a), c), 572 | [ 573 | [LazyTensor(arr).min(), arr.min()], 574 | [LazyTensor(arr).min(1)[0], arr.min(1)[0]], 575 | [LazyTensor(arr).min(1)[1], arr.min(1)[1]], 576 | ], 577 | ) 578 | 579 | 580 | def test_min_function() -> None: 581 | arr = torch.randn(6, 7, 8) 582 | brr = torch.randn(1, 7, 8) 583 | la = typing.cast(Tensor, LazyTensor(arr)) 584 | lb = typing.cast(Tensor, LazyTensor(brr)) 585 | 586 | common.call( 587 | lambda a, c: common.assert_isclose(koila.run(a), c), 588 | [ 589 | [torch.min(la), torch.min(arr)], 590 | [torch.min(la, 2)[0], torch.min(arr, 2)[0]], 591 | [ 592 | torch.min(la, 1, keepdim=True).indices, 593 | torch.min(arr, 1, keepdim=True).indices, 594 | ], 595 | [torch.min(la, lb), torch.min(arr, brr)], 596 | ], 597 | ) 598 | 599 | 600 | def test_max_method() -> None: 601 | arr = torch.randn(6, 7, 8) 602 | 603 | common.call( 604 | lambda a, c: common.assert_isclose(koila.run(a), c), 605 | [ 606 | [LazyTensor(arr).max(), arr.max()], 607 | [LazyTensor(arr).max(1)[0], arr.max(1)[0]], 608 | [LazyTensor(arr).max(1)[1], arr.max(1)[1]], 609 | ], 610 | ) 611 | 612 | 613 | def test_max_function() 
-> None: 614 | arr = torch.randn(6, 7, 8) 615 | brr = torch.randn(1, 7, 8) 616 | la = typing.cast(Tensor, LazyTensor(arr)) 617 | lb = typing.cast(Tensor, LazyTensor(brr)) 618 | 619 | common.call( 620 | lambda a, c: common.assert_isclose(koila.run(a), c), 621 | [ 622 | [torch.max(la), torch.max(arr)], 623 | [torch.max(la, 2)[0], torch.max(arr, 2)[0]], 624 | [ 625 | torch.max(la, 1, keepdim=True).indices, 626 | torch.max(arr, 1, keepdim=True).indices, 627 | ], 628 | [torch.max(la, lb), torch.max(arr, brr)], 629 | ], 630 | ) 631 | 632 | 633 | def test_size_shape_method() -> None: 634 | arr = torch.randn(11, 13) 635 | la = LazyTensor(arr) 636 | assert la.size() == la.shape == (11, 13) 637 | assert la.size(0) == 11 638 | assert la.size(1) == 13 639 | 640 | 641 | def test_t_method() -> None: 642 | arr = torch.randn(11, 13) 643 | la = LazyTensor(arr) 644 | assert la.T.size() == la.t().size() == (13, 11) 645 | 646 | 647 | def test_t_function() -> None: 648 | arr = torch.randn(11, 13) 649 | la = typing.cast(Tensor, LazyTensor(arr)) 650 | assert torch.t(la).shape == (13, 11) 651 | 652 | 653 | def test_dim_method() -> None: 654 | arr = torch.randn(11, 13) 655 | assert LazyTensor(arr).dim() == arr.ndim == arr.dim() == 2 656 | arr = torch.randn(1, 2, 3, 4, 5) 657 | assert LazyTensor(arr).dim() == arr.dim() == 5 658 | 659 | 660 | def test_permute_method() -> None: 661 | arr = torch.randn(2, 3, 4, 5, 6) 662 | la = LazyTensor(arr) 663 | assert la.permute(3, 4, 1, 2, 0).shape == (5, 6, 3, 4, 2) 664 | assert la.permute(0, 1, 4, 3, 2).shape == (2, 3, 6, 5, 4) 665 | 666 | 667 | def test_permute_function() -> None: 668 | arr = torch.randn(2, 3, 4, 5, 6) 669 | la = typing.cast(Tensor, LazyTensor(arr)) 670 | assert torch.permute(la, (3, 4, 1, 2, 0)).shape == (5, 6, 3, 4, 2) 671 | assert torch.permute(la, (0, 1, 4, 3, 2)).shape == (2, 3, 6, 5, 4) 672 | 673 | 674 | def test_transpose_method() -> None: 675 | arr = torch.randn(2, 3, 4, 5, 6) 676 | la = LazyTensor(arr) 677 | assert la.transpose(3, 4).shape == (2, 3, 4, 6, 5) 678 | assert la.transpose(0, 1).shape == (3, 2, 4, 5, 6) 679 | assert la.transpose(0, 3).shape == (5, 3, 4, 2, 6) 680 | 681 | 682 | def test_select_method() -> None: 683 | arr = torch.randn(3, 4, 5) 684 | sel = arr.select(1, 2) 685 | assert isinstance(sel, Tensor) 686 | assert not isinstance(sel, LazyTensor) 687 | 688 | la = LazyTensor(arr) 689 | lsel = la.select(1, 2) 690 | 691 | assert not isinstance(lsel, Tensor) 692 | assert isinstance(lsel, LazyTensor) 693 | assert sel.size() == lsel.size() == (3, 5) 694 | common.assert_isclose(lsel.run(), sel) 695 | 696 | 697 | def test_select_function() -> None: 698 | arr = torch.randn(3, 4, 5) 699 | sel = torch.select(arr, 1, 2) 700 | assert isinstance(sel, Tensor) 701 | assert not isinstance(sel, LazyTensor) 702 | 703 | la = typing.cast(Tensor, LazyTensor(arr)) 704 | lsel = torch.select(la, 1, 2) 705 | 706 | assert not isinstance(lsel, Tensor) 707 | assert isinstance(lsel, LazyTensor) 708 | assert sel.size() == lsel.size() == (3, 5) 709 | common.assert_isclose(lsel.run(), sel) 710 | 711 | 712 | def test_index_select_method() -> None: 713 | arr = torch.randn(3, 4, 5) 714 | idx = torch.tensor([1, 2, 3]) 715 | sel = arr.index_select(1, idx) 716 | assert isinstance(sel, Tensor) 717 | assert not isinstance(sel, LazyTensor) 718 | 719 | la = LazyTensor(arr) 720 | lsel = la.index_select(1, idx) 721 | 722 | assert not isinstance(lsel, Tensor) 723 | assert isinstance(lsel, LazyTensor) 724 | assert sel.size() == lsel.size() == (3, 3, 5) 725 | common.assert_isclose(lsel.run(), sel) 726 | 727 |
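# A standalone distillation of the contract the select/index_select tests
# above exercise (a sketch, not part of the original suite): an op on a
# LazyTensor returns another LazyTensor whose shape is inferred eagerly,
# while the values are only computed once run() is called.
def test_index_select_is_lazy_sketch() -> None:
    arr = torch.randn(3, 4, 5)
    idx = torch.tensor([1, 2, 3])
    lsel = LazyTensor(arr).index_select(1, idx)  # no kernel has run yet
    assert lsel.size() == (3, 3, 5)  # shape is already known
    common.assert_isclose(lsel.run(), arr.index_select(1, idx))  # values on demand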
728 | def test_index_select_function() -> None: 729 | arr = torch.randn(3, 4, 5) 730 | idx = torch.tensor([1, 2, 3]) 731 | sel = torch.index_select(arr, 1, idx) 732 | assert isinstance(sel, Tensor) 733 | assert not isinstance(sel, LazyTensor) 734 | 735 | la = typing.cast(Tensor, LazyTensor(arr)) 736 | lsel = torch.index_select(la, 1, idx) 737 | 738 | assert not isinstance(lsel, Tensor) 739 | assert isinstance(lsel, LazyTensor) 740 | assert sel.size() == lsel.size() == (3, 3, 5) 741 | common.assert_isclose(lsel.run(), sel) 742 | 743 | 744 | def test_numel_method() -> None: 745 | arr = torch.randn(2, 3, 4, 5, 6) 746 | la = typing.cast(Tensor, LazyTensor(arr)) 747 | assert la.numel() == 2 * 3 * 4 * 5 * 6 748 | 749 | arr = torch.randn(15, 19) 750 | la = typing.cast(Tensor, LazyTensor(arr)) 751 | assert la.numel() == 15 * 19 752 | 753 | 754 | def test_numel_function() -> None: 755 | arr = torch.randn(2, 3, 4, 5, 6) 756 | la = typing.cast(Tensor, LazyTensor(arr)) 757 | assert torch.numel(la) == 2 * 3 * 4 * 5 * 6 758 | 759 | arr = torch.randn(15, 19) 760 | la = typing.cast(Tensor, LazyTensor(arr)) 761 | assert torch.numel(la) == 15 * 19 762 | 763 | 764 | def test_sigmoid_method() -> None: 765 | arr = torch.randn(4, 5, 6) 766 | common.call( 767 | lambda a, c: common.assert_isclose(koila.run(a), c), 768 | [[LazyTensor(arr).sigmoid(), torch.sigmoid(arr)]], 769 | ) 770 | 771 | 772 | def test_sigmoid_function() -> None: 773 | arr = torch.randn(4, 5, 6) 774 | la = typing.cast(Tensor, LazyTensor(arr)) 775 | common.call( 776 | lambda a, c: common.assert_isclose(koila.run(a), c), 777 | [[torch.sigmoid(la), torch.sigmoid(arr)]], 778 | ) 779 | 780 | 781 | def test_sin_method() -> None: 782 | common.call( 783 | lambda a, c: common.assert_isclose(a.sin().item(), c), 784 | [ 785 | [LazyTensor(torch.tensor(0)), 0], 786 | [LazyTensor(torch.tensor(math.pi)), 0], 787 | [LazyTensor(torch.tensor(math.pi / 2)), 1], 788 | [LazyTensor(torch.tensor(3 * math.pi / 2)), -1], 789 | [LazyTensor(torch.tensor(42.0)), math.sin(42)], 790 | [LazyTensor(torch.tensor(-75.0)), math.sin(-75)], 791 | ], 792 | ) 793 | 794 | 795 | def test_sin_function() -> None: 796 | common.call( 797 | lambda a, c: common.assert_isclose(torch.sin(a).item(), c), 798 | [ 799 | [LazyTensor(torch.tensor(0)), 0], 800 | [LazyTensor(torch.tensor(math.pi)), 0], 801 | [LazyTensor(torch.tensor(math.pi / 2)), 1], 802 | [LazyTensor(torch.tensor(3 * math.pi / 2)), -1], 803 | [LazyTensor(torch.tensor(42.0)), math.sin(42)], 804 | [LazyTensor(torch.tensor(-75.0)), math.sin(-75)], 805 | ], 806 | ) 807 | 808 | 809 | def test_cos_method() -> None: 810 | common.call( 811 | lambda a, c: common.assert_isclose(a.cos().item(), c), 812 | [ 813 | [LazyTensor(torch.tensor(0)), 1], 814 | [LazyTensor(torch.tensor(math.pi)), -1], 815 | [LazyTensor(torch.tensor(math.pi / 2)), 0], 816 | [LazyTensor(torch.tensor(3 * math.pi / 2)), 0], 817 | [LazyTensor(torch.tensor(27.0)), math.cos(27)], 818 | [LazyTensor(torch.tensor(-14.0)), math.cos(-14)], 819 | ], 820 | ) 821 | 822 | 823 | def test_cos_function() -> None: 824 | common.call( 825 | lambda a, c: common.assert_isclose(torch.cos(a).item(), c), 826 | [ 827 | [LazyTensor(torch.tensor(0)), 1], 828 | [LazyTensor(torch.tensor(math.pi)), -1], 829 | [LazyTensor(torch.tensor(math.pi / 2)), 0], 830 | [LazyTensor(torch.tensor(3 * math.pi / 2)), 0], 831 | [LazyTensor(torch.tensor(27.0)), math.cos(27)], 832 | [LazyTensor(torch.tensor(-14.0)), math.cos(-14)], 833 | ], 834 | ) 835 | 836 | 837 | def test_tan_method() -> None: 838 | common.call( 839 | lambda a, c:
common.assert_isclose(a.tan().item(), c), 840 | [ 841 | [LazyTensor(torch.tensor(0)), 0], 842 | [LazyTensor(torch.tensor(math.pi)), 0], 843 | [LazyTensor(torch.tensor(99.0)), math.tan(99)], 844 | [LazyTensor(torch.tensor(-4.0)), math.tan(-4)], 845 | ], 846 | ) 847 | 848 | 849 | def test_tan_function() -> None: 850 | common.call( 851 | lambda a, c: common.assert_isclose(torch.tan(a).item(), c), 852 | [ 853 | [LazyTensor(torch.tensor(0)), 0], 854 | [LazyTensor(torch.tensor(math.pi)), 0], 855 | [LazyTensor(torch.tensor(99.0)), math.tan(99)], 856 | [LazyTensor(torch.tensor(-4.0)), math.tan(-4)], 857 | ], 858 | ) 859 | 860 | 861 | def test_asin_method() -> None: 862 | common.call( 863 | lambda a, c: common.assert_isclose(a.asin().item(), c), 864 | [ 865 | [LazyTensor(torch.tensor(n)), math.asin(n)] 866 | for n in np.linspace(-1, 1).tolist() 867 | ], 868 | ) 869 | 870 | 871 | def test_asin_function() -> None: 872 | common.call( 873 | lambda a, c: common.assert_isclose(torch.asin(a).item(), c), 874 | [ 875 | [LazyTensor(torch.tensor(n)), math.asin(n)] 876 | for n in np.linspace(-1, 1).tolist() 877 | ], 878 | ) 879 | 880 | 881 | def test_acos_method() -> None: 882 | common.call( 883 | lambda a, c: common.assert_isclose(a.acos().item(), c), 884 | [ 885 | [LazyTensor(torch.tensor(n)), math.acos(n)] 886 | for n in np.linspace(-1, 1).tolist() 887 | ], 888 | ) 889 | 890 | 891 | def test_acos_function() -> None: 892 | common.call( 893 | lambda a, c: common.assert_isclose(torch.acos(a).item(), c), 894 | [ 895 | [LazyTensor(torch.tensor(n)), math.acos(n)] 896 | for n in np.linspace(-1, 1).tolist() 897 | ], 898 | ) 899 | 900 | 901 | def test_atan_method() -> None: 902 | common.call( 903 | lambda a, c: common.assert_isclose(a.atan().item(), c), 904 | [ 905 | [LazyTensor(torch.tensor(99.0)), math.atan(99)], 906 | [LazyTensor(torch.tensor(-4.0)), math.atan(-4)], 907 | [LazyTensor(torch.tensor(-6.0)), math.atan(-6)], 908 | [LazyTensor(torch.tensor(242.0)), math.atan(242)], 909 | ], 910 | ) 911 | 912 | 913 | def test_atan_function() -> None: 914 | common.call( 915 | lambda a, c: common.assert_isclose(torch.atan(a).item(), c), 916 | [ 917 | [LazyTensor(torch.tensor(99.0)), math.atan(99)], 918 | [LazyTensor(torch.tensor(-4.0)), math.atan(-4)], 919 | [LazyTensor(torch.tensor(-6.0)), math.atan(-6)], 920 | [LazyTensor(torch.tensor(242.0)), math.atan(242)], 921 | ], 922 | ) 923 | 924 | 925 | def test_sinh_method() -> None: 926 | common.call( 927 | lambda a, c: common.assert_isclose(a.sinh().item(), c), 928 | [ 929 | [LazyTensor(torch.tensor(n)), math.sinh(n)] 930 | for n in np.linspace(-1, 1).tolist() 931 | ], 932 | ) 933 | 934 | 935 | def test_sinh_function() -> None: 936 | common.call( 937 | lambda a, c: common.assert_isclose(torch.sinh(a).item(), c), 938 | [ 939 | [LazyTensor(torch.tensor(n)), math.sinh(n)] 940 | for n in np.linspace(-1, 1).tolist() 941 | ], 942 | ) 943 | 944 | 945 | def test_cosh_method() -> None: 946 | common.call( 947 | lambda a, c: common.assert_isclose(a.cosh().item(), c), 948 | [ 949 | [LazyTensor(torch.tensor(n)), math.cosh(n)] 950 | for n in np.linspace(-1, 1).tolist() 951 | ], 952 | ) 953 | 954 | 955 | def test_cosh_function() -> None: 956 | common.call( 957 | lambda a, c: common.assert_isclose(torch.cosh(a).item(), c), 958 | [ 959 | [LazyTensor(torch.tensor(n)), math.cosh(n)] 960 | for n in np.linspace(-1, 1).tolist() 961 | ], 962 | ) 963 | 964 | 965 | def test_tanh_method() -> None: 966 | common.call( 967 | lambda a, c: common.assert_isclose(a.tanh().item(), c), 968 | 
[[LazyTensor(torch.tensor(n)), math.tanh(n)] for n in np.linspace(-10, 10)], 969 | ) 970 | 971 | 972 | def test_tanh_function() -> None: 973 | common.call( 974 | lambda a, c: common.assert_isclose(torch.tanh(a).item(), c), 975 | [[LazyTensor(torch.tensor(n)), math.tanh(n)] for n in np.linspace(-10, 10)], 976 | ) 977 | 978 | 979 | def test_asinh_method() -> None: 980 | common.call( 981 | lambda a, c: common.assert_isclose(a.asinh().item(), c), 982 | [ 983 | [LazyTensor(torch.tensor(199.0)), math.asinh(199)], 984 | [LazyTensor(torch.tensor(-241.0)), math.asinh(-241)], 985 | [LazyTensor(torch.tensor(-9.0)), math.asinh(-9)], 986 | [LazyTensor(torch.tensor(0.0)), math.asinh(0)], 987 | ], 988 | ) 989 | 990 | 991 | def test_asinh_function() -> None: 992 | common.call( 993 | lambda a, c: common.assert_isclose(torch.asinh(a).item(), c), 994 | [ 995 | [LazyTensor(torch.tensor(199.0)), math.asinh(199)], 996 | [LazyTensor(torch.tensor(-241.0)), math.asinh(-241)], 997 | [LazyTensor(torch.tensor(-9.0)), math.asinh(-9)], 998 | [LazyTensor(torch.tensor(0.0)), math.asinh(0)], 999 | ], 1000 | ) 1001 | 1002 | 1003 | def test_acosh_method() -> None: 1004 | common.call( 1005 | lambda a, c: common.assert_isclose(a.acosh().item(), c), 1006 | [ 1007 | [LazyTensor(torch.tensor(14.0)), math.acosh(14)], 1008 | [LazyTensor(torch.tensor(2.0)), math.acosh(2)], 1009 | [LazyTensor(torch.tensor(1.0)), math.acosh(1)], 1010 | [LazyTensor(torch.tensor(65.0)), math.acosh(65)], 1011 | ], 1012 | ) 1013 | 1014 | 1015 | def test_acosh_function() -> None: 1016 | common.call( 1017 | lambda a, c: common.assert_isclose(torch.acosh(a).item(), c), 1018 | [ 1019 | [LazyTensor(torch.tensor(14.0)), math.acosh(14)], 1020 | [LazyTensor(torch.tensor(2.0)), math.acosh(2)], 1021 | [LazyTensor(torch.tensor(1.0)), math.acosh(1)], 1022 | [LazyTensor(torch.tensor(65.0)), math.acosh(65)], 1023 | ], 1024 | ) 1025 | 1026 | 1027 | def test_atanh_method() -> None: 1028 | common.call( 1029 | lambda a, c: common.assert_isclose(a.atanh().item(), c), 1030 | [ 1031 | [LazyTensor(torch.tensor(n)), math.atanh(n)] 1032 | for n in np.linspace(-0.99, 0.99, endpoint=False).tolist() 1033 | ], 1034 | ) 1035 | 1036 | 1037 | def test_atanh_function() -> None: 1038 | common.call( 1039 | lambda a, c: common.assert_isclose(torch.atanh(a).item(), c), 1040 | [ 1041 | [LazyTensor(torch.tensor(n)), math.atanh(n)] 1042 | for n in np.linspace(-0.99, 0.99, endpoint=False).tolist() 1043 | ], 1044 | ) 1045 | 1046 | 1047 | def test_run_method() -> None: 1048 | random = torch.randn(3, 4, 5, 6) 1049 | common.call( 1050 | lambda a, b: common.assert_isclose(a.run(), b), [[LazyTensor(random), random]] 1051 | ) 1052 | 1053 | 1054 | def test_torch_method() -> None: 1055 | random = torch.randn(3, 4, 5, 6) 1056 | common.call( 1057 | lambda a, b: common.assert_isclose(a.torch(), b), [[LazyTensor(random), random]] 1058 | ) 1059 | 1060 | 1061 | def test_numpy_method() -> None: 1062 | random = torch.randn(3, 4, 5, 6) 1063 | common.call( 1064 | lambda a, b: common.assert_isclose(a.numpy(), b.numpy()), 1065 | [[LazyTensor(random), random]], 1066 | ) 1067 | 1068 | 1069 | def test_pad_function() -> None: 1070 | tensor = torch.randn(3, 4, 5, 6) 1071 | padded = F.pad(tensor, (2, 3, 0, 1), mode="reflect") 1072 | assert isinstance(padded, Tensor) 1073 | assert not isinstance(padded, LazyTensor) 1074 | 1075 | la = typing.cast(Tensor, LazyTensor(tensor)) 1076 | lazy_padded = F.pad(la, (2, 3, 0, 1), mode="reflect") 1077 | assert not isinstance(lazy_padded, Tensor) 1078 | assert isinstance(lazy_padded, 
LazyTensor) 1079 | assert padded.shape == lazy_padded.shape 1080 | 1081 | common.assert_isclose(lazy_padded.run(), padded) 1082 | 1083 | 1084 | def test_buffer_sizes() -> None: 1085 | a = torch.randn(4, 5, 6) 1086 | 1087 | la = LazyTensor(a) 1088 | assert a.numel() == la.numel() == la.buffer_numel()[1] 1089 | 1090 | b = torch.randn(4, 5, 1) 1091 | lb = LazyTensor(b) 1092 | assert b.numel() == lb.numel() == lb.buffer_numel()[1] 1093 | 1094 | lc = typing.cast(LazyTensor, la + lb) 1095 | assert lc.numel() == la.numel() == 6 * lb.numel() 1096 | assert lc.buffer_numel()[1] == la.numel() + lb.numel() + lc.numel() 1097 | 1098 | d = torch.randn(4, 5, 6) 1099 | ld = typing.cast(LazyTensor, d) 1100 | 1101 | le = typing.cast(LazyTensor, lc * ld) 1102 | assert d.numel() == ld.numel() == le.numel() 1103 | assert le.buffer_numel()[1] == sum(map(LazyTensor.numel, {la, lb, lc, ld, le})) 1104 | 1105 | lf = le.sum() 1106 | assert lf.buffer_numel()[1] == sum(map(LazyTensor.numel, {la, lb, lc, ld, le, lf})) 1107 | 1108 | lg = typing.cast(LazyTensor, lc + le) 1109 | assert lg.buffer_numel()[1] == sum(map(LazyTensor.numel, {la, lb, lc, ld, le, lg})) 1110 | 1111 | assert lg.buffer_memory()[1] == lg.buffer_numel()[1] * 4 1112 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Flatten, Linear, Module, ReLU, Sequential 6 | 7 | from koila import BatchInfo, LazyTensor 8 | 9 | from . import common 10 | 11 | 12 | def test_torch_tutorial() -> None: 13 | "Testing the model taken from PyTorch's tutorial." 14 | 15 | class NeuralNetwork(Module): 16 | def __init__(self): 17 | super(NeuralNetwork, self).__init__() 18 | self.flatten = Flatten() 19 | self.linear_relu_stack = Sequential( 20 | Linear(28 * 28, 512), 21 | ReLU(), 22 | Linear(512, 512), 23 | ReLU(), 24 | Linear(512, 10), 25 | ) 26 | 27 | def forward(self, x): 28 | x = self.flatten(x) 29 | logits = self.linear_relu_stack(x) 30 | return logits 31 | 32 | input = torch.randn(9, 28, 28) 33 | nn = NeuralNetwork() 34 | 35 | output = nn(input) 36 | assert output.shape == (9, 10) 37 | assert isinstance(output, Tensor) 38 | assert not isinstance(output, LazyTensor) 39 | 40 | lazy_input = LazyTensor(input, batch=0) 41 | assert lazy_input.batch() == BatchInfo(0, 9) 42 | nn = NeuralNetwork() 43 | 44 | lazy_output = nn(lazy_input) 45 | assert lazy_output.shape == (9, 10) 46 | assert not isinstance(lazy_output, Tensor) 47 | assert isinstance(lazy_output, LazyTensor) 48 | 49 | assert lazy_input.run((3, 6)).size() == (3, 28, 28) 50 | common.assert_isclose(lazy_input.run((3, 6)), input[3:6]) 51 | tbout = lazy_output.run((3, 6)) 52 | assert tbout.shape == (3, 10) 53 | assert isinstance(tbout, Tensor) 54 | assert not isinstance(tbout, LazyTensor) 55 | common.assert_isclose(tbout, nn(input[3:6])) 56 | 57 | assert lazy_output.batch() == BatchInfo(0, 9) 58 | --------------------------------------------------------------------------------
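test_torch_tutorial above is the end-to-end check of koila's headline behavior: tagging the input with batch=0 keeps the whole forward pass lazy, and run((start, stop)) materializes only that slice of the batch. A condensed sketch of the same usage, relying only on the public API the test exercises (the single Linear layer is a hypothetical stand-in for the tutorial network):

```python
import torch
from torch.nn import Linear

from koila import LazyTensor

net = Linear(28 * 28, 10)  # hypothetical toy model
inputs = torch.randn(9, 28 * 28)

lazy_out = net(LazyTensor(inputs, batch=0))  # dim 0 marked as the batch axis
assert lazy_out.shape == (9, 10)             # shape known before any compute

chunk = lazy_out.run((3, 6))                 # evaluate only rows 3..5
assert chunk.shape == (3, 10)
assert torch.allclose(chunk, net(inputs[3:6]))
```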
/tests/test_prepasses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) RenChu Wang - All Rights Reserved 2 | 3 | import torch 4 | 5 | from koila import PrePassFunc, prepasses 6 | 7 | from . import common 8 | 9 | 10 | def test_compatibility() -> None: 11 | assert isinstance(prepasses.identity, PrePassFunc) 12 | assert isinstance(prepasses.symmetric, PrePassFunc) 13 | assert isinstance(prepasses.reduce_dims, PrePassFunc) 14 | assert isinstance(prepasses.permute, PrePassFunc) 15 | assert isinstance(prepasses.tranpose, PrePassFunc) 16 | assert isinstance(prepasses.view, PrePassFunc) 17 | assert isinstance(prepasses.reshape, PrePassFunc) 18 | assert isinstance(prepasses.flatten, PrePassFunc) 19 | assert isinstance(prepasses.matmul, PrePassFunc) 20 | assert isinstance(prepasses.linear, PrePassFunc) 21 | assert isinstance(prepasses.cat, PrePassFunc) 22 | assert isinstance(prepasses.pad, PrePassFunc) 23 | assert isinstance(prepasses.conv, PrePassFunc) 24 | assert isinstance(prepasses.conv_transpose, PrePassFunc) 25 | assert isinstance(prepasses.pool, PrePassFunc) 26 | assert isinstance(prepasses.maxpool, PrePassFunc) 27 | assert isinstance(prepasses.avgpool, PrePassFunc) 28 | 29 | 30 | def test_identity() -> None: 31 | common.call( 32 | common.assert_equal, 33 | [ 34 | [prepasses.identity(torch.randn(1, 2, 3, 4, 5)), (1, 2, 3, 4, 5)], 35 | [prepasses.identity(torch.randn(4, 2, 5)), (4, 2, 5)], 36 | [prepasses.identity(torch.randn(17, 1, 4)), (17, 1, 4)], 37 | ], 38 | ) 39 | 40 | 41 | def test_symmetric() -> None: 42 | common.call( 43 | common.assert_equal, 44 | [ 45 | [prepasses.symmetric(torch.randn(2, 4, 5), torch.randn(())), (2, 4, 5)], 46 | [ 47 | prepasses.symmetric(torch.randn(2, 4, 5), torch.randn(2, 4, 5)), 48 | (2, 4, 5), 49 | ], 50 | [ 51 | prepasses.symmetric(torch.randn(2, 1, 5), torch.randn(2, 4, 5)), 52 | (2, 4, 5), 53 | ], 54 | [ 55 | prepasses.symmetric(torch.randn(2, 1, 5), torch.randn(2, 4, 1)), 56 | (2, 4, 5), 57 | ], 58 | ], 59 | ) 60 | 61 | 62 | def test_reduce_dims() -> None: 63 | common.call( 64 | common.assert_equal, 65 | [ 66 | [prepasses.reduce_dims(torch.randn(1, 2, 3, 4, 5), 1), (1, 3, 4, 5)], 67 | [prepasses.reduce_dims(torch.randn(1, 2, 3, 4, 5), (2, 3)), (1, 2, 5)], 68 | [ 69 | prepasses.reduce_dims(torch.randn(5, 2, 3, 4), (2, 3), keepdim=True), 70 | (5, 2, 1, 1), 71 | ], 72 | ], 73 | ) 74 | 75 | 76 | def test_scalar() -> None: 77 | common.call( 78 | common.assert_equal, 79 | [ 80 | [prepasses.reduce_dims(torch.randn(5, 5, 2)), ()], 81 | [prepasses.reduce_dims(torch.randn(7, 8)), ()], 82 | ], 83 | ) 84 | 85 | 86 | def test_matmul() -> None: 87 | common.call( 88 | common.assert_equal, 89 | [ 90 | [prepasses.matmul(torch.randn(8), torch.randn(8)), ()], 91 | [prepasses.matmul(torch.randn(8, 3), torch.randn(3)), (8,)], 92 | [prepasses.matmul(torch.randn(8), torch.randn(8, 3)), (3,)], 93 | [prepasses.matmul(torch.randn(4, 5), torch.randn(5, 3)), (4, 3)], 94 | [prepasses.matmul(torch.randn(9, 4, 5), torch.randn(9, 5, 3)), (9, 4, 3)], 95 | [prepasses.matmul(torch.randn(9, 4, 5), torch.randn(1, 5, 3)), (9, 4, 3)], 96 | [ 97 | prepasses.matmul(torch.randn(9, 7, 4, 5), torch.randn(1, 5, 3)), 98 | (9, 7, 4, 3), 99 | ], 100 | ], 101 | ) 102 | 103 | 104 | def test_transpose() -> None: 105 | common.call( 106 | common.assert_equal, 107 | [[prepasses.tranpose(torch.randn(3, 4, 5), 1, 2), (3, 5, 4)]], 108 | ) 109 | 110 | 111 | def test_linear() -> None: 112 | common.call( 113 | common.assert_equal, 114 | [ 115 | [ 116 | prepasses.linear( 117 | torch.randn(7, 11, 13), 118 | weight=torch.randn(17, 13), 119 | bias=torch.randn(17), 120 | ), 121 | (7, 11, 17), 122 | ] 123 | ], 124 | ) 125 | 126 | 127 | def test_cat() -> None: 128 | common.call( 129 | common.assert_equal,
130 | [ 131 | [prepasses.cat([torch.randn(2, 3, 5), torch.randn(3, 3, 5)]), (5, 3, 5)], 132 | [ 133 | prepasses.cat([torch.randn(2, 3, 5), torch.randn(2, 4, 5)], dim=1), 134 | (2, 7, 5), 135 | ], 136 | ], 137 | ) 138 | 139 | 140 | def test_loss() -> None: 141 | common.call( 142 | common.assert_equal, 143 | [ 144 | [prepasses.loss(torch.randn(2, 4, 5), torch.randn(2, 4, 5)), ()], 145 | [ 146 | prepasses.loss( 147 | torch.randn(2, 4, 5), torch.randn(2, 4, 5), reduction="sum" 148 | ), 149 | (), 150 | ], 151 | [ 152 | prepasses.loss( 153 | torch.randn(2, 4, 5), torch.randn(2, 4, 5), reduction="none" 154 | ), 155 | (2, 4, 5), 156 | ], 157 | ], 158 | ) 159 | --------------------------------------------------------------------------------
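The pre-pass tests above all come down to static shape inference: each pre-pass predicts an operation's output shape from the input shapes alone, which is what lets the lazy graph budget memory before any kernel runs. For the symmetric (element-wise, broadcasting) case this is ordinary broadcast arithmetic; a sketch, assuming, as test_symmetric does, that a pre-pass result compares equal to a plain shape tuple:

```python
import torch

from koila import prepasses

a, b = torch.randn(2, 1, 5), torch.randn(2, 4, 1)

# Trailing dimensions pair up, and a 1 broadcasts against anything, so
# (2, 1, 5) combined with (2, 4, 1) yields (2, 4, 5) -- the same rule
# PyTorch applies when it actually executes the op.
assert prepasses.symmetric(a, b) == (2, 4, 5)
assert tuple(torch.broadcast_shapes(a.shape, b.shape)) == (2, 4, 5)
```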