├── docs
│   ├── .nojekyll
│   ├── assets
│   ├── README.md
│   ├── CNAME
│   ├── _coverpage.md
│   └── index.html
├── tests
│   ├── __init__.py
│   ├── test_models.py
│   ├── common.py
│   ├── test_prepasses.py
│   ├── test_layers.py
│   └── test_lazy.py
├── assets
│   ├── koila.png
│   └── koila.svg
├── CITATION.cff
├── src
│   └── koila
│       ├── errors.py
│       ├── __init__.py
│       ├── constants.py
│       ├── eager.py
│       ├── gpus.py
│       ├── interfaces.py
│       ├── shapes.py
│       ├── prepasses.py
│       └── lazy.py
├── .github
│   └── workflows
│       ├── typecheck.yaml
│       ├── release.yaml
│       ├── unittest.yaml
│       ├── format.yaml
│       └── build.yaml
├── LICENSE.md
├── pyproject.toml
├── .gitignore
├── CODE_OF_CONDUCT.md
└── README.md
/docs/.nojekyll:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/assets:
--------------------------------------------------------------------------------
1 | ../assets/
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | ../README.md
--------------------------------------------------------------------------------
/docs/CNAME:
--------------------------------------------------------------------------------
1 | koila.rentruewang.com
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
--------------------------------------------------------------------------------
/assets/koila.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rentruewang/koila/HEAD/assets/koila.png
--------------------------------------------------------------------------------
/docs/_coverpage.md:
--------------------------------------------------------------------------------
1 | # Koila
2 |
3 | No more `CUDA error: out of memory`.
4 |
5 | [Github](https://github.com/rentruewang/koila)
6 | [Website](https://koila.rentruewang.com)
7 | [Why Archive](#archive)
8 | [Successor: Aioway](https://github.com/rentruewang/aioway)
9 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Wang"
5 | given-names: "RenChu"
6 | title: "Koila"
7 | version: 0.1.1
8 | date-released: 2021-11-23
9 | url: "https://github.com/rentruewang/koila"
10 |
--------------------------------------------------------------------------------
/src/koila/errors.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | from typing import NoReturn
4 |
5 |
6 | class UnsupportedError(RuntimeError):
7 | "Sorry, this function is currently not supported."
8 |
9 | @classmethod
10 | def raise_error(cls, *args, **kwargs) -> NoReturn:
11 | del args
12 | del kwargs
13 | raise cls
14 |
--------------------------------------------------------------------------------
/src/koila/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | from . import constants, gpus
4 | from .eager import EagerTensor
5 | from .errors import UnsupportedError
6 | from .interfaces import (
7 | BatchedPair,
8 | BatchInfo,
9 | Runnable,
10 | RunnableTensor,
11 | TensorMixin,
12 | run,
13 | )
14 | from .lazy import Evaluation, LazyFunction, LazyTensor, lazy
15 | from .prepasses import CallBack, MetaData, PrePass, PrePassFunc
16 |
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.github/workflows/typecheck.yaml:
--------------------------------------------------------------------------------
1 | name: Type Checking
2 | on: [push]
3 | jobs:
4 | type-check:
5 | name: 👨‍⚕️ Type Checking
6 | runs-on: ubuntu-latest
7 | steps:
8 | - name: 🔔 Check out
9 | uses: actions/checkout@v3
10 |
11 | - name: 🏗️ python
12 | uses: actions/setup-python@v4
13 | with:
14 | python-version: "3.10"
15 |
16 | - name: ⬇️ Python PDM
17 | uses: pdm-project/setup-pdm@v4
18 |
19 | - name: ⬇️ Python Dependencies
20 | run: pdm install -G:all
21 |
22 | - name: 🚂 Activate environment
23 | run: echo "$(pdm venv --path in-project)/bin" >> $GITHUB_PATH
24 |
25 | - name: 🏃 mypy
26 | run: mypy . --disable-error-code=import-untyped --disable-error-code=import-not-found
27 |
--------------------------------------------------------------------------------
/src/koila/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | from typing import Dict
4 |
5 | import torch
6 | from torch import dtype
7 |
8 | UNITS: Dict[str, int] = {
9 | "b": 1,
10 | "kb": 10**3,
11 | "kib": 2**10,
12 | "mb": 10**6,
13 | "mib": 2**20,
14 | "gb": 10**9,
15 | "gib": 2**30,
16 | "tb": 10**4,
17 | "tib": 2**40,
18 | }
19 |
20 | MEMORY_BYTES: Dict[dtype, int] = {
21 | torch.bool: 1,
22 | torch.uint8: 1,
23 | torch.int8: 1,
24 | torch.short: 2,
25 | torch.int16: 2,
26 | torch.int: 4,
27 | torch.int32: 4,
28 | torch.long: 8,
29 | torch.int64: 8,
30 | torch.half: 2,
31 | torch.float16: 2,
32 | torch.float: 4,
33 | torch.float32: 4,
34 | torch.double: 8,
35 | torch.float64: 8,
36 | }
37 |
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 | push:
5 | tags:
6 | - v*
7 |
8 | jobs:
9 | pypi-publish:
10 | name: ⬆️ Upload release to PyPI
11 | runs-on: ubuntu-latest
12 | permissions:
13 | contents: read
14 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
15 |
16 | steps:
17 | - name: 🔔 Check out
18 | uses: actions/checkout@v3
19 |
20 | - name: 🏗️ python
21 | uses: actions/setup-python@v4
22 | with:
23 | python-version: "3.13"
24 |
25 | - name: ⬇️ Python PDM
26 | uses: pdm-project/setup-pdm@v4
27 | with:
28 | cache: true
29 |
30 | - name: ⬇️ Python Dependencies
31 | run: pdm sync -G:all
32 |
33 | - name: 📰 Publish to PyPI
34 | run: pdm publish
35 |
--------------------------------------------------------------------------------
/.github/workflows/unittest.yaml:
--------------------------------------------------------------------------------
1 | name: Unit Testing
2 | on: [push]
3 | jobs:
4 | unit-test:
5 | name: 🧪 Unit Testing
6 | runs-on: ubuntu-latest
7 | steps:
8 | - name: 🔔 Check out
9 | uses: actions/checkout@v3
10 |
11 | - name: 🏗️ python
12 | uses: actions/setup-python@v4
13 | with:
14 | python-version: "3.10"
15 |
16 | - name: ⬇️ Python PDM
17 | uses: pdm-project/setup-pdm@v4
18 |
19 | - name: ⬇️ Python Dependencies
20 | run: pdm install -G:all
21 |
22 | - name: 🚂 Activate environment
23 | run: echo "$(pdm venv --path in-project)/bin" >> $GITHUB_PATH
24 |
25 | - name: 🏃 pytest
26 | run: pytest -xv
27 |
28 | # - name: 🏃 pytest
29 | # run: coverage run -m pytest -v
30 |
31 | # - name: 📊 coverage
32 | # run: coverage report -m
33 |
--------------------------------------------------------------------------------
/.github/workflows/format.yaml:
--------------------------------------------------------------------------------
1 | name: Formatting
2 | on: [push]
3 | jobs:
4 | format-all:
5 | name: 📀 Formatting
6 | runs-on: ubuntu-latest
7 | steps:
8 | - name: 🔔 Check out
9 | uses: actions/checkout@v3
10 |
11 | - name: 🏗️ python
12 | uses: actions/setup-python@v4
13 | with:
14 | python-version: "3.10"
15 |
16 | - name: ⬇️ Python PDM
17 | uses: pdm-project/setup-pdm@v4
18 |
19 | - name: ⬇️ Python Dependencies
20 | run: pdm install -G:all
21 |
22 | - name: 🚂 Activate environment
23 | run: echo "$(pdm venv --path in-project)/bin" >> $GITHUB_PATH
24 |
25 | - name: 🏃 autoflake, isort, black
26 | run: |
27 | autoflake -cr $(find -iname "*.py" ! -path '*/.venv/*' ! -name __init__.py) --remove-all-unused-imports
28 | isort --profile black --check .
29 | black --check .
30 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021-2025 RenChu Wang
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "koila"
3 | description = "Prevent PyTorch's `CUDA error out of memory` in a few lines of code"
4 | authors = [
5 | {name = "RenChu Wang", email = "patrick1031wang@gmail.com"},
6 | ]
7 | dependencies = [
8 | "numpy>=1.26.3",
9 | "scipy>=1.11.4",
10 | "torch>=2.1.2",
11 | "black>=24.4.2",
12 | ]
13 | requires-python = ">=3.10"
14 | readme = "README.md"
15 | license = {text = "Apache-2.0"}
16 | dynamic = ["version"]
17 |
18 | [build-system]
19 | requires = ["setuptools", "wheel", "setuptools-scm"]
20 | build-backend = "setuptools.build_meta"
21 |
22 | [tool.setuptools_scm]
23 |
24 | [tool.pdm]
25 | distribution = true
26 |
27 | [tool.pdm.dev-dependencies]
28 | test = [
29 | "coverage>=7.4.0",
30 | "pytest>=7.4.4",
31 | "pytest-cov>=4.1.0",
32 | "pytest-xdist>=3.5.0",
33 | ]
34 | format = [
35 | "autoflake>=2.2.1",
36 | "black>=23.12.1",
37 | "isort>=5.13.2",
38 | ]
39 | website = [
40 | "jupyter>=1.1.1",
41 | "jupyter-book>=1.0.3",
42 | "myst-parser>=2.0.0",
43 | ]
44 | type = [
45 | "mypy>=1.8.0",
46 | ]
47 |
--------------------------------------------------------------------------------
/.github/workflows/build.yaml:
--------------------------------------------------------------------------------
1 | name: Build Pages
2 | on: [push]
3 | jobs:
4 | build-and-deploy:
5 | name: 📃 Website Build
6 | runs-on: ubuntu-latest
7 | steps:
8 | - name: 🔔 Check out
9 | uses: actions/checkout@v3
10 |
11 | - name: 🏗️ python
12 | uses: actions/setup-python@v4
13 | with:
14 | python-version: "3.10"
15 |
16 | - name: ⬇️ Python PDM
17 | uses: pdm-project/setup-pdm@v4
18 | with:
19 | python-version: "3.10"
20 |
21 | - name: ⬇️ Python Dependencies
22 | run: pdm install -G:all
23 |
24 | - name: 🚂 Activate environment
25 | run: echo "$(pdm venv --path in-project)/bin" >> $GITHUB_PATH
26 |
27 | - name: 🚧 Jupyter build
28 | run: jupyter book build docs
29 |
30 | - name: 📰 Publish docs
31 | uses: JamesIves/github-pages-deploy-action@v4
32 | with:
33 | branch: gh-pages
34 | folder: ./docs/_build/html
35 | git-config-name: "github-actions[bot]"
36 | git-config-email: "github-actions[bot]@users.noreply.github.com"
37 | commit-message: 🎉 Book deployed
38 |
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import Flatten, Linear, Module, ReLU, Sequential
6 |
7 | from koila import BatchInfo, LazyTensor
8 |
9 | from . import common
10 |
11 |
12 | def test_torch_tutorial() -> None:
13 | "Testing the model taken from pytorch's tutorial."
14 |
15 | class NeuralNetwork(Module):
16 | def __init__(self):
17 | super(NeuralNetwork, self).__init__()
18 | self.flatten = Flatten()
19 | self.linear_relu_stack = Sequential(
20 | Linear(28 * 28, 512),
21 | ReLU(),
22 | Linear(512, 512),
23 | ReLU(),
24 | Linear(512, 10),
25 | )
26 |
27 | def forward(self, x):
28 | x = self.flatten(x)
29 | logits = self.linear_relu_stack(x)
30 | return logits
31 |
32 | input = torch.randn(9, 28, 28)
33 | nn = NeuralNetwork()
34 |
35 | output = nn(input)
36 | assert output.shape == (9, 10)
37 | assert isinstance(output, Tensor)
38 | assert not isinstance(output, LazyTensor)
39 |
40 | lazy_input = LazyTensor(input, batch=0)
41 | assert lazy_input.batch() == BatchInfo(0, 9)
42 | nn = NeuralNetwork()
43 |
44 | lazy_output = nn(lazy_input)
45 | assert lazy_output.shape == (9, 10)
46 | assert not isinstance(lazy_output, Tensor)
47 | assert isinstance(lazy_output, LazyTensor)
48 |
49 | assert lazy_input.run((3, 6)).size() == (3, 28, 28)
50 | common.assert_isclose(lazy_input.run((3, 6)), input[3:6])
51 | tbout = lazy_output.run((3, 6))
52 | assert tbout.shape == (3, 10)
53 | assert isinstance(tbout, Tensor)
54 | assert not isinstance(tbout, LazyTensor)
55 | common.assert_isclose(tbout, nn(input[3:6]))
56 |
57 | assert lazy_output.batch() == BatchInfo(0, 9)
58 |
--------------------------------------------------------------------------------
/src/koila/eager.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | from __future__ import annotations
4 |
5 | import logging
6 | from typing import Any, Callable, Dict, Sequence, Tuple, Type
7 |
8 | from rich.logging import RichHandler
9 | from torch import Tensor
10 | from torch import device as Device
11 | from torch import dtype as DType
12 |
13 | from .interfaces import BatchInfo, RunnableTensor, TensorLike
14 |
15 | LOGGER = logging.getLogger(__name__)
16 | LOGGER.addHandler(RichHandler())
17 | LOGGER.setLevel(logging.DEBUG)
18 |
19 | # So, it seems that torch's Tensor base class utilizes metaclass
20 | # to pretend to be a parent of LongTensor, FloatTensor etc.
21 | # Perhaps I'll be using the same paradigm.
22 |
23 |
24 | class EagerTensor(RunnableTensor):
25 | def __init__(self, data: Tensor) -> None:
26 | self.data = data
27 |
28 | def __getattr__(self, name: str) -> Any:
29 | return getattr(self.data, name)
30 |
31 | def batch(self) -> BatchInfo | None:
32 | raise NotImplementedError
33 |
34 | def run(self, partial: Tuple[int, int] | None = None) -> Tensor:
35 | del partial
36 | return self.data
37 |
38 | def visit(self, nodes: Dict[int, TensorLike]) -> None:
39 | raise NotImplementedError
40 |
41 | def device(self) -> str | Device:
42 | raise NotImplementedError
43 |
44 | def dtype(self) -> DType:
45 | raise NotImplementedError
46 |
47 | def size(self) -> Tuple[int, ...]:
48 | return self.data.size()
49 |
50 | @classmethod
51 | def __torch_function__(
52 | cls,
53 | func: Callable[..., Tensor],
54 | types: Tuple[Type[Any], ...],
55 | args: Sequence[TensorLike] = (),
56 | kwargs: Dict[str, TensorLike] | None = None,
57 | ) -> TensorLike:
58 | if kwargs is None:
59 | kwargs = {}
60 |
61 | if not all(issubclass(typ, (Tensor, EagerTensor)) for typ in types):
62 | return NotImplemented
63 |
64 | return EagerTensor(func(*args, **kwargs))
65 |
--------------------------------------------------------------------------------
/tests/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | from __future__ import annotations
4 |
5 | import dataclasses as dcls
6 | import math
7 | import typing
8 | from dataclasses import dataclass
9 | from typing import Any, Callable, Dict, Sequence
10 |
11 | import numpy as np
12 | import torch
13 | from numpy import ndarray
14 | from torch import Tensor
15 | from torch.types import Number
16 |
17 |
18 | @dataclass(init=False)
19 | class ArgsKwargs:
20 | def __init__(self, *args: Any, **kwargs: Any) -> None:
21 | self.args = args
22 | self.kwargs = kwargs
23 |
24 | args: Sequence[Any] = dcls.field(default_factory=tuple)
25 | kwargs: Dict[str, Any] = dcls.field(default_factory=dict)
26 |
27 |
28 | @dataclass(init=False)
29 | class Caller:
30 | func: Callable[..., Any]
31 | arguments: Sequence[ArgsKwargs] = dcls.field(default_factory=list)
32 |
33 | def __init__(
34 | self,
35 | func: Callable[..., Any],
36 | arguments: Sequence[ArgsKwargs | Sequence[Any] | Dict[str, Any]],
37 | ) -> None:
38 | self.func = func
39 | self.arguments = []
40 |
41 | for argument in arguments:
42 | if isinstance(argument, Sequence):
43 | argument = ArgsKwargs(*argument)
44 |
45 | if isinstance(argument, dict):
46 | assert all(isinstance(key, str) for key in argument.keys())
47 | argument = ArgsKwargs(**argument)
48 |
49 | self.arguments.append(argument)
50 |
51 | def call(self) -> None:
52 | for argument in self.arguments:
53 | self.func(*argument.args, **argument.kwargs)
54 |
55 |
56 | def call(
57 | func: Callable[..., Any],
58 | arguments: Sequence[ArgsKwargs | Sequence[Any] | Dict[str, Any]],
59 | ) -> None:
60 | Caller(func, arguments=arguments).call()
61 |
62 |
63 | def assert_equal(
64 | input: Tensor | ndarray | Number, other: Tensor | ndarray | Number
65 | ) -> None:
66 | if isinstance(input, ndarray) or isinstance(other, ndarray):
67 | assert np.all(input == other), input != other
68 | return
69 |
70 | if isinstance(input, Tensor) or isinstance(other, Tensor):
71 | assert typing.cast(Tensor, input == other).all(), input != other
72 | return
73 |
74 | assert input == other, [input, other]
75 |
76 |
77 | def assert_isclose(
78 | input: Tensor | ndarray | Number, other: Tensor | ndarray | Number
79 | ) -> None:
80 | if isinstance(input, ndarray) or isinstance(other, ndarray):
81 | assert np.allclose(input, other, atol=1e-5), [input, other]
82 | return
83 |
84 | if isinstance(input, Tensor) and isinstance(other, Tensor):
85 | assert torch.allclose(input, other, atol=1e-5), [input, other]
86 | return
87 |
88 | assert math.isclose(input, other, abs_tol=1e-5), [input, other]
89 |
90 |
91 | def is_notimplemented(func: Callable[[], Any]) -> bool:
92 | try:
93 | func()
94 | return False
95 | except Exception:
96 | return True
97 |
--------------------------------------------------------------------------------
/src/koila/gpus.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | from __future__ import annotations
4 |
5 | import math
6 | from typing import Generator
7 |
8 | from pynvml.smi import nvidia_smi
9 | from torch import cuda
10 |
11 | from . import constants
12 | from .interfaces import BatchedPair
13 |
14 | NVSMI = None
15 |
16 |
17 | def nvidia_free_memory() -> int:
18 | """
19 | Calls nvidia's nvml library and queries available GPU memory.
20 | Currently the function only works with 1 GPU.
21 |
22 | Returns
23 | -------
24 |
25 | Free GPU memory in terms of bytes.
26 | """
27 |
28 | global NVSMI
29 | if NVSMI is None:
30 | NVSMI = nvidia_smi.getInstance()
31 |
32 | assert NVSMI is not None
33 | query = NVSMI.DeviceQuery("memory.free")
34 |
35 | # Only works on one GPU as of now.
36 | gpu = query["gpu"][0]["fb_memory_usage"]
37 |
38 | unit = constants.UNITS[gpu["unit"].lower()]
39 | free = gpu["free"]
40 |
41 | return free * unit
42 |
43 |
44 | def torch_free_memory() -> int:
45 | """
46 | Calls torch's memory statistics to calculate the amount of GPU memory unused.
47 | Currently the function only works with 1 GPU.
48 |
49 | Returns
50 | -------
51 |
52 | Reserved GPU memory in terms of bytes.
53 | """
54 |
55 | if not cuda.is_available():
56 | return 0
57 |
58 | # Only works on one GPU as of now.
59 |
60 | reserved_memory = cuda.memory_reserved(0)
61 | active_memory = cuda.memory_allocated(0)
62 | unused_memory = reserved_memory - active_memory
63 | return unused_memory
64 |
65 |
66 | def free_memory() -> int | None:
67 | """
68 | The amount of free GPU memory that can be used.
69 |
70 | Returns
71 | -------
72 |
73 | Unused GPU memory, or None if no GPUs are available.
74 | """
75 |
76 | if cuda.is_available():
77 | return nvidia_free_memory() + torch_free_memory()
78 | else:
79 | return None
80 |
81 |
82 | def maximum_batch(memory: BatchedPair, total_memory: int | None = None) -> int | None:
83 | # batch * x + no_batch = unused_memory
84 | if total_memory is None:
85 | total_memory = free_memory()
86 |
87 | if total_memory is None:
88 | return None
89 |
90 | return (total_memory - memory.no_batch) // memory.batch
91 |
92 |
93 | def split_batch(
94 | memory: BatchedPair, current_batch: int, total_memory: int | None = None
95 | ) -> Generator[int, None, None]:
96 | max_batch = maximum_batch(memory, total_memory)
97 |
98 | if max_batch is None:
99 | yield current_batch
100 | return
101 |
102 | batch_size = 2 ** (math.floor(math.log2(max_batch)))
103 | (times, current_batch) = divmod(current_batch, batch_size)
104 |
105 | for _ in range(times):
106 | yield batch_size
107 |
108 | while current_batch > 0:
109 | batch_size >>= 1
110 | if current_batch >= batch_size:
111 | current_batch -= batch_size
112 | yield batch_size
113 | assert current_batch < batch_size, [current_batch, batch_size]
114 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # General gitignore
2 | .DS_Store
3 | .vscode/
4 |
5 | # Python gitignore
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/#use-with-ide
115 | .pdm.toml
116 | .pdm-python
117 |
118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
119 | __pypackages__/
120 |
121 | # Celery stuff
122 | celerybeat-schedule
123 | celerybeat.pid
124 |
125 | # SageMath parsed files
126 | *.sage.py
127 |
128 | # Environments
129 | .env
130 | .venv
131 | env/
132 | venv/
133 | ENV/
134 | env.bak/
135 | venv.bak/
136 |
137 | # Spyder project settings
138 | .spyderproject
139 | .spyproject
140 |
141 | # Rope project settings
142 | .ropeproject
143 |
144 | # mkdocs documentation
145 | /site
146 |
147 | # mypy
148 | .mypy_cache/
149 | .dmypy.json
150 | dmypy.json
151 |
152 | # Pyre type checker
153 | .pyre/
154 |
155 | # pytype static type analyzer
156 | .pytype/
157 |
158 | # Cython debug symbols
159 | cython_debug/
160 |
161 | # PyCharm
162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
164 | # and can be added to the global gitignore or merged into this file. For a more nuclear
165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
166 | #.idea/
167 |
168 |
169 | _build/
170 |
--------------------------------------------------------------------------------
/src/koila/interfaces.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | from __future__ import annotations
4 |
5 | import functools
6 | import operator
7 | from abc import abstractmethod
8 | from typing import (
9 | Any,
10 | Callable,
11 | Dict,
12 | NamedTuple,
13 | Protocol,
14 | Tuple,
15 | TypeVar,
16 | Union,
17 | overload,
18 | runtime_checkable,
19 | )
20 |
21 | from torch import Tensor
22 | from torch import device as Device
23 | from torch import dtype as DType
24 |
25 | from . import constants
26 |
27 | E = TypeVar("E")
28 | T = TypeVar("T", covariant=True)
29 | V = TypeVar("V", contravariant=True)
30 |
31 |
32 | @runtime_checkable
33 | class Runnable(Protocol[T]):
34 | @abstractmethod
35 | def run(self) -> T: ...
36 |
37 |
38 | @runtime_checkable
39 | class TensorMixin(Protocol):
40 | @overload
41 | @abstractmethod
42 | def size(self) -> Tuple[int, ...]: ...
43 |
44 | @overload
45 | @abstractmethod
46 | def size(self, dim: int) -> int: ...
47 |
48 | @abstractmethod
49 | def size(self, dim: int | None = None) -> int | Tuple[int, ...]: ...
50 |
51 | def numel(self) -> int:
52 | return functools.reduce(operator.mul, self.size(), 1)
53 |
54 | def dim(self) -> int:
55 | return len(self.size())
56 |
57 | @abstractmethod
58 | def dtype(self) -> DType: ...
59 |
60 | @abstractmethod
61 | def device(self) -> str | Device: ...
62 |
63 |
64 | class BatchedPair(NamedTuple):
65 | batch: int
66 | no_batch: int
67 |
68 |
69 | class BatchInfo(NamedTuple):
70 | index: int
71 | value: int
72 |
73 | def map(self, func: Callable[[int], int]) -> BatchInfo:
74 | index = func(self.index)
75 | return BatchInfo(index, self.value)
76 |
77 |
78 | @runtime_checkable
79 | class RunnableTensor(Runnable[Tensor], TensorMixin, Protocol):
80 | @abstractmethod
81 | def batch(self) -> BatchInfo | None: ...
82 |
83 | @abstractmethod
84 | def run(self, partial: Tuple[int, int] | None = None) -> Tensor: ...
85 |
86 | @abstractmethod
87 | def visit(self, nodes: Dict[int, TensorLike]) -> None: ...
88 |
89 | def buffer(self) -> Dict[int, TensorLike]:
90 | nodes = {}
91 | self.visit(nodes)
92 | return nodes
93 |
94 | def buffer_numel(self) -> BatchedPair:
95 | buffer = self.buffer().values()
96 | return BatchedPair(
97 | sum(t.numel() for t in buffer if bat(t) is not None),
98 | sum(t.numel() for t in buffer if bat(t) is None),
99 | )
100 |
101 | def buffer_memory(self) -> BatchedPair:
102 | buffer = self.buffer().values()
103 | return BatchedPair(
104 | sum(mem(t) for t in buffer if bat(t) is not None),
105 | sum(mem(t) for t in buffer if bat(t) is None),
106 | )
107 |
108 | def memory(self) -> int:
109 | return mem(self)
110 |
111 |
112 | def dtyp(tensor: TensorLike) -> DType:
113 | if isinstance(tensor, Tensor):
114 | return tensor.dtype
115 |
116 | return tensor.dtype()
117 |
118 |
119 | def dev(tensor: TensorLike) -> str | Device:
120 | if isinstance(tensor, Tensor):
121 | return tensor.device
122 |
123 | return tensor.device()
124 |
125 |
126 | def mem(tensor: TensorLike) -> int:
127 | dt = dtyp(tensor)
128 | numel = tensor.numel()
129 |
130 | if (batch := bat(tensor)) is not None:
131 | numel //= batch.value
132 |
133 | return constants.MEMORY_BYTES[dt] * numel
134 |
135 |
136 | def bat(tensor: TensorLike) -> BatchInfo | None:
137 | if isinstance(tensor, RunnableTensor):
138 | return tensor.batch()
139 | return None
140 |
141 |
142 | TensorLike = Union[Tensor, RunnableTensor]
143 |
144 |
145 | @overload
146 | def run(val: RunnableTensor, partial: Tuple[int, int] | None = None) -> Tensor: ...
147 |
148 |
149 | @overload
150 | def run(val: Runnable[E], partial: Tuple[int, int] | None = None) -> E: ...
151 |
152 |
153 | @overload
154 | def run(val: E, partial: Tuple[int, int] | None = None) -> E: ...
155 |
156 |
157 | def run(val: Any, partial: Tuple[int, int] | None = None) -> Any:
158 | if isinstance(val, RunnableTensor):
159 | return val.run(partial)
160 |
161 | if isinstance(val, Runnable):
162 | return val.run()
163 |
164 | return val
165 |
--------------------------------------------------------------------------------
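A quick sanity check of `run`'s dispatch above: `RunnableTensor`s are evaluated, while plain tensors and ordinary values pass through unchanged.

```python
import torch

from koila.interfaces import run

t = torch.ones(2, 2)
assert run(t) is t    # a plain Tensor is not Runnable, so it is returned as-is
assert run(42) == 42  # arbitrary non-runnable values also pass through
```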
/tests/test_prepasses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | import torch
4 |
5 | from koila import PrePassFunc, prepasses
6 |
7 | from . import common
8 |
9 |
10 | def test_compatibility() -> None:
11 | assert isinstance(prepasses.identity, PrePassFunc)
12 | assert isinstance(prepasses.symmetric, PrePassFunc)
13 | assert isinstance(prepasses.reduce_dims, PrePassFunc)
14 | assert isinstance(prepasses.permute, PrePassFunc)
15 | assert isinstance(prepasses.tranpose, PrePassFunc)
16 | assert isinstance(prepasses.view, PrePassFunc)
17 | assert isinstance(prepasses.reshape, PrePassFunc)
18 | assert isinstance(prepasses.flatten, PrePassFunc)
19 | assert isinstance(prepasses.matmul, PrePassFunc)
20 | assert isinstance(prepasses.linear, PrePassFunc)
21 | assert isinstance(prepasses.cat, PrePassFunc)
22 | assert isinstance(prepasses.pad, PrePassFunc)
23 | assert isinstance(prepasses.conv, PrePassFunc)
24 | assert isinstance(prepasses.conv_transpose, PrePassFunc)
25 | assert isinstance(prepasses.pool, PrePassFunc)
26 | assert isinstance(prepasses.maxpool, PrePassFunc)
27 | assert isinstance(prepasses.avgpool, PrePassFunc)
28 |
29 |
30 | def test_identity() -> None:
31 | common.call(
32 | common.assert_equal,
33 | [
34 | [prepasses.identity(torch.randn(1, 2, 3, 4, 5)), (1, 2, 3, 4, 5)],
35 | [prepasses.identity(torch.randn(4, 2, 5)), (4, 2, 5)],
36 | [prepasses.identity(torch.randn(17, 1, 4)), (17, 1, 4)],
37 | ],
38 | )
39 |
40 |
41 | def test_symmetric() -> None:
42 | common.call(
43 | common.assert_equal,
44 | [
45 | [prepasses.symmetric(torch.randn(2, 4, 5), torch.randn(())), (2, 4, 5)],
46 | [
47 | prepasses.symmetric(torch.randn(2, 4, 5), torch.randn(2, 4, 5)),
48 | (2, 4, 5),
49 | ],
50 | [
51 | prepasses.symmetric(torch.randn(2, 1, 5), torch.randn(2, 4, 5)),
52 | (2, 4, 5),
53 | ],
54 | [
55 | prepasses.symmetric(torch.randn(2, 1, 5), torch.randn(2, 4, 1)),
56 | (2, 4, 5),
57 | ],
58 | ],
59 | )
60 |
61 |
62 | def test_reduce_dims() -> None:
63 | common.call(
64 | common.assert_equal,
65 | [
66 | [prepasses.reduce_dims(torch.randn(1, 2, 3, 4, 5), 1), (1, 3, 4, 5)],
67 | [prepasses.reduce_dims(torch.randn(1, 2, 3, 4, 5), (2, 3)), (1, 2, 5)],
68 | [
69 | prepasses.reduce_dims(torch.randn(5, 2, 3, 4), (2, 3), keepdim=True),
70 | (5, 2, 1, 1),
71 | ],
72 | ],
73 | )
74 |
75 |
76 | def test_scalar() -> None:
77 | common.call(
78 | common.assert_equal,
79 | [
80 | [prepasses.reduce_dims(torch.randn(5, 5, 2)), ()],
81 | [prepasses.reduce_dims(torch.randn(7, 8)), ()],
82 | ],
83 | )
84 |
85 |
86 | def test_matmul() -> None:
87 | common.call(
88 | common.assert_equal,
89 | [
90 | [prepasses.matmul(torch.randn(8), torch.randn(8)), ()],
91 | [prepasses.matmul(torch.randn(8, 3), torch.randn(3)), (8,)],
92 | [prepasses.matmul(torch.randn(8), torch.randn(8, 3)), (3,)],
93 | [prepasses.matmul(torch.randn(4, 5), torch.randn(5, 3)), (4, 3)],
94 | [prepasses.matmul(torch.randn(9, 4, 5), torch.randn(9, 5, 3)), (9, 4, 3)],
95 | [prepasses.matmul(torch.randn(9, 4, 5), torch.randn(1, 5, 3)), (9, 4, 3)],
96 | [
97 | prepasses.matmul(torch.randn(9, 7, 4, 5), torch.randn(1, 5, 3)),
98 | (9, 7, 4, 3),
99 | ],
100 | ],
101 | )
102 |
103 |
104 | def test_transpose() -> None:
105 | common.call(
106 | common.assert_equal,
107 | [[prepasses.tranpose(torch.randn(3, 4, 5), 1, 2), (3, 5, 4)]],
108 | )
109 |
110 |
111 | def test_linear() -> None:
112 | common.call(
113 | common.assert_equal,
114 | [
115 | [
116 | prepasses.linear(
117 | torch.randn(7, 11, 13),
118 | weight=torch.randn(17, 13),
119 | bias=torch.randn(17),
120 | ),
121 | (7, 11, 17),
122 | ]
123 | ],
124 | )
125 |
126 |
127 | def test_cat() -> None:
128 | common.call(
129 | common.assert_equal,
130 | [
131 | [prepasses.cat([torch.randn(2, 3, 5), torch.randn(3, 3, 5)]), (5, 3, 5)],
132 | [
133 | prepasses.cat([torch.randn(2, 3, 5), torch.randn(2, 4, 5)], dim=1),
134 | (2, 7, 5),
135 | ],
136 | ],
137 | )
138 |
139 |
140 | def test_loss() -> None:
141 | common.call(
142 | common.assert_equal,
143 | [
144 | [prepasses.loss(torch.randn(2, 4, 5), torch.randn(2, 4, 5)), ()],
145 | [
146 | prepasses.loss(
147 | torch.randn(2, 4, 5), torch.randn(2, 4, 5), reduction="sum"
148 | ),
149 | (),
150 | ],
151 | [
152 | prepasses.loss(
153 | torch.randn(2, 4, 5), torch.randn(2, 4, 5), reduction="none"
154 | ),
155 | (2, 4, 5),
156 | ],
157 | ],
158 | )
159 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, social-economic status,
9 | nationality, personal appearance, race, caste, color, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | - Demonstrating empathy and kindness toward other people
21 | - Being respectful of differing opinions, viewpoints, and experiences
22 | - Giving and gracefully accepting constructive feedback
23 | - Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | - Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | - The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | - Trolling, insulting or derogatory comments, and personal or political attacks
33 | - Public or private harassment
34 | - Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | - Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | [INSERT CONTACT METHOD].
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0].
120 |
121 | Community Impact Guidelines were inspired by
122 | [Mozilla's code of conduct enforcement ladder][mozilla coc].
123 |
124 | For answers to common questions about this code of conduct, see the FAQ at
125 | [https://www.contributor-covenant.org/faq][faq]. Translations are available
126 | at [https://www.contributor-covenant.org/translations][translations].
127 |
128 | [homepage]: https://www.contributor-covenant.org
129 | [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html
130 | [mozilla coc]: https://github.com/mozilla/diversity
131 | [faq]: https://www.contributor-covenant.org/faq
132 | [translations]: https://www.contributor-covenant.org/translations
133 |
--------------------------------------------------------------------------------
/src/koila/shapes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | from __future__ import annotations
4 |
5 | import functools
6 | import logging
7 | import operator
8 | from typing import Set, Tuple
9 |
10 | from rich.logging import RichHandler
11 |
12 | LOGGER = logging.getLogger(__name__)
13 | LOGGER.addHandler(RichHandler())
14 |
15 |
16 | def compatible_dim(input: int, other: int, broadcast: bool = True) -> bool:
17 | if broadcast:
18 | return input == 1 or other == 1 or input == other
19 | else:
20 | return input == other
21 |
22 |
23 | def prepends(
24 | input: Tuple[int, ...], other: Tuple[int, ...], value: int
25 | ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
26 | LOGGER.debug("Prepending %s and %s.", input, other)
27 |
28 | prepended = (value,) * abs(len(input) - len(other))
29 | if len(input) >= len(other):
30 | other = prepended + other
31 | else:
32 | input = prepended + input
33 | assert len(input) == len(other)
34 | return (input, other)
35 |
36 |
37 | def coerce(
38 | input: Tuple[int, ...],
39 | other: Tuple[int, ...],
40 | broadcast: bool = True,
41 | scalars: bool = True,
42 | ) -> Tuple[int, ...] | None:
43 | LOGGER.debug(
44 | "Coercing %s and %s. Broadcasting: %s. Allow scalars: %s.",
45 | input,
46 | other,
47 | broadcast,
48 | scalars,
49 | )
50 |
51 | if scalars:
52 | if len(input) == 0:
53 | return other
54 |
55 | if len(other) == 0:
56 | return input
57 |
58 | if not broadcast:
59 | if (shape := input) == other:
60 | return shape
61 | else:
62 | return None
63 |
64 | (input, other) = prepends(input, other, 1)
65 |
66 | shape = []
67 | for a, b in zip(input, other):
68 | if a <= 0 or b <= 0:
69 | raise ValueError
70 |
71 | if compatible_dim(a, b):
72 | shape.append(max(a, b))
73 | else:
74 | return None
75 |
76 | return tuple(shape)
77 |
78 |
79 | def permute(input: Tuple[int, ...], *dims: int) -> Tuple[int, ...]:
80 | LOGGER.debug("%s, %s", input, dims)
81 |
82 | if not len(input) == len(dims):
83 | raise TypeError
84 |
85 | if sorted(dims) != list(range(len(input))):
86 | raise ValueError
87 |
88 | if not len(set(dims)) == len(input):
89 | raise ValueError
90 |
91 | dims_order_pair = sorted(enumerate(dims), key=lambda pair: pair[1])
92 | scattered_dims = [pair[0] for pair in dims_order_pair]
93 | paired = sorted(zip(scattered_dims, input))
94 | reordered_dim = [pair[1] for pair in paired]
95 | return tuple(reordered_dim)
96 |
97 |
98 | def reshape(input: Tuple[int, ...], *shape: int) -> Tuple[int, ...]:
99 | LOGGER.debug("%s, %s", input, shape)
100 |
101 | if not functools.reduce(operator.mul, input) == functools.reduce(
102 | operator.mul, shape
103 | ):
104 | raise ValueError
105 | return shape
106 |
107 |
108 | def view(input: Tuple[int, ...], *shape: int) -> Tuple[int, ...]:
109 | LOGGER.debug("%s, %s", input, shape)
110 |
111 | special_values = [x for x in shape if x < 0]
112 |
113 | if len(special_values) > 1:
114 | raise ValueError
115 |
116 | if set(special_values) | {-1} != {-1}:
117 | raise ValueError
118 |
119 | special = -(
120 | functools.reduce(operator.mul, input) // functools.reduce(operator.mul, shape)
121 | )
122 | new_shape = []
123 | for s in shape:
124 | if s > 0:
125 | new_shape.append(s)
126 | else:
127 | new_shape.append(special)
128 |
129 | return reshape(input, *new_shape)
130 |
131 |
132 | def tranpose(input: Tuple[int, ...], dim0: int, dim1: int) -> Tuple[int, ...]:
133 | LOGGER.debug("%s, %d, %d", input, dim0, dim1)
134 |
135 | if len(input) < 2:
136 | raise ValueError
137 |
138 | shapes = list(input)
139 | (shapes[dim0], shapes[dim1]) = (shapes[dim1], shapes[dim0])
140 | return tuple(shapes)
141 |
142 |
143 | def matmul(input: Tuple[int, ...], other: Tuple[int, ...]) -> Tuple[int, ...]:
144 | LOGGER.debug("%s, %s", input, other)
145 |
146 | if len(input) == 0 or len(other) == 0:
147 | raise ValueError(
148 | "Both arguments to matmul need to be at least 1D."
149 | " "
150 | f"Got {len(input)}D and {len(other)}D."
151 | )
152 |
153 | if len(input) == len(other) == 1:
154 | if input[0] != other[0]:
155 | raise ValueError
156 |
157 | return ()
158 |
159 | if len(input) == len(other) == 2:
160 | if input[1] != other[0]:
161 | raise ValueError
162 |
163 | return (input[0], other[1])
164 |
165 | if len(input) == 1 and len(other) == 2:
166 | if input[0] != other[0]:
167 | raise ValueError
168 |
169 | return (other[1],)
170 |
171 | if len(input) == 2 and len(other) == 1:
172 | if input[1] != other[0]:
173 | raise ValueError
174 |
175 | return (input[0],)
176 |
177 | (input, other) = prepends(input, other, 1)
178 |
179 | shapes = []
180 | for dimi, dimo in zip(input[:-2], other[:-2]):
181 | if not compatible_dim(dimi, dimo):
182 | raise ValueError
183 | shapes.append(max(dimi, dimo))
184 |
185 | if input[-1] != other[-2]:
186 | raise ValueError
187 |
188 | shapes.extend([input[-2], other[-1]])
189 |
190 | return tuple(shapes)
191 |
192 |
193 | def reduce_dims(
194 | input: Tuple[int, ...],
195 | dim: int | Tuple[int, ...] | None = None,
196 | keepdim: bool = False,
197 | ) -> Tuple[Tuple[int, ...], Set[int]]:
198 | LOGGER.debug("%s, %s", input, dim)
199 |
200 | shapes = []
201 |
202 | if dim is None:
203 | dimensions = set(range(len(input)))
204 | elif isinstance(dim, int):
205 | dimensions = {dim}
206 | else:
207 | dimensions = set(dim)
208 |
209 | for idx, dimsize in enumerate(input):
210 | if idx not in dimensions:
211 | shapes.append(dimsize)
212 | continue
213 |
214 | if keepdim:
215 | shapes.append(1)
216 |
217 | if keepdim:
218 | assert len(shapes) == len(input)
219 |
220 | return (tuple(shapes), dimensions)
221 |
--------------------------------------------------------------------------------
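Two quick examples of the shape rules above (mirroring the checks in `tests/test_prepasses.py`):

```python
from koila.shapes import matmul, view

assert matmul((9, 4, 5), (1, 5, 3)) == (9, 4, 3)  # batch dimensions broadcast
assert view((2, 3, 4), 6, -1) == (6, 4)           # -1 is inferred from numel
```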
/assets/koila.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🫙 Archive (successor: aioway)
2 |
3 | **Koila is now a subcomponent of [aioway][aioway], check it out!**
4 |
5 | Project `koila` has 3 main components:
6 |
7 | #### 1. Metadata tracking for `torch.Tensor`s.
8 |
9 | But `PyTorch` now officially has [`FakeTensor`](https://docs.pytorch.org/docs/stable/torch.compiler_fake_tensor.html) (`koila` predates it). It has great compatibility and supports all torch operators, something `koila` was never able to do.
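A minimal sketch of that kind of metadata-only tracking with `FakeTensor` (illustrative; the exact import path may vary across `torch` versions):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under FakeTensorMode, factory calls allocate no real storage;
# only metadata (shape, dtype, device) is propagated.
with FakeTensorMode():
    x = torch.empty(8, 28, 28)
    y = x.flatten(1) @ torch.empty(28 * 28, 10)
    print(y.shape)  # torch.Size([8, 10])
```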
10 |
11 | #### 2. Decouple some symbolic info (batch size) to run a reduced graph
12 |
13 | `Koila` only tracks symbolic info partially, on the batch dimension. I now have something much better: [`aioway`][aioway], a compiler / interpreter for deep learning that handles all of this info without the burden of `torch` compatibility.
14 |
15 | [`Aioway`][aioway] also has something that `koila` does not: layer tracking. (How do you go one level up and work on layers when your API simply takes `Tensor` inputs?)
16 |
17 | At a day job, I once split a package into two (core / non-core) and didn't have a good experience (quite a pain, actually) because of the release cycle and IDE support (autocomplete, go-to-definition, etc.), so using two tightly coupled packages simultaneously is out of the question.
18 |
19 | As I plan to rewrite this tracking part in C++, and `koila` is not going to have a native API, I don't want to cross the repository boundary; this part will instead be rewritten in [`aioway`][aioway].
20 |
21 | #### 3. Run with gradient accumulation to prevent OOM.
22 |
23 | `Koila` requests batches iteratively, adhering to the `torch` API and working around some layers that conflict with gradient accumulation.
24 |
25 | However, `torch` has too many operators, and much of this can be achieved by other means, such as a `try` / `except` binary search over the batch size, which is super simple, low-maintenance, and not that slow. A sketch of that alternative follows.
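A minimal sketch of that alternative (not part of `koila`; `step` is a hypothetical forward/backward callback, and `torch.cuda.OutOfMemoryError` assumes a reasonably recent `torch`):

```python
import torch

def fit_batch(step, batch: torch.Tensor) -> int:
    """Binary-search downward: halve the chunk size on OOM until it fits."""
    size = batch.shape[0]
    while size > 0:
        try:
            for chunk in batch.split(size):
                step(chunk)  # run forward/backward on one chunk
            return size
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            size //= 2
    raise RuntimeError("even a batch of 1 does not fit in memory")
```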
26 |
27 | I intend to keep `koila` as it is: a proof of concept that I did for fun.
28 |
29 | [aioway]: https://github.com/rentruewang/aioway
30 |
31 | # 🐨 Koila
32 |
33 | > Koila solves the `CUDA error: out of memory` error painlessly.
34 | > Fix it with just one line of code, and forget it.
35 |
36 | [](https://github.com/rentruewang/koila/actions/workflows/unittest.yaml)
37 | [](https://github.com/rentruewang/koila/actions/workflows/typecheck.yaml)
38 | [](https://github.com/rentruewang/koila/actions/workflows/format.yaml)
39 | [](https://opensource.org/licenses/MIT)
40 | [](https://twitter.com/intent/tweet?text=Never%20worry%20about%20out%20of%20memory%20errors%20again&url=https://github.com/rentruewang/koila&hashtags=pytorch,outofmemory)
41 |
42 | 
43 |
44 | ## 🚨 Warning
45 |
46 | **Main branch is a complete re-structure of the project (currently mostly empty, as I haven't had time to complete it). To see working code, check out the `v0.1.1` tag for a proof of concept (which doesn't fully support all operations and is not suited for production). To use it, download [release v0.1.1 here](https://github.com/rentruewang/koila/releases/tag/v0.1.1).**
47 |
48 | ## 🚀 Features
49 |
50 | - 🙅 Prevents `CUDA error: out of memory` with a single line of code.
51 |
52 | - 🧮 Without touching the main logic.
53 |
54 | - ⚗️ Automatically accumulates gradients when batch sizes are too large.
55 |
56 | - 🦥 Lazily evaluates PyTorch code to save computing power.
57 |
58 | - ✂️ Automatically splits along the batch dimension into more GPU-friendly numbers (powers of 2) to speed up execution.
59 |
60 | - 🤏 Minimal API (wrapping all inputs will be enough).
61 |
62 | ## 🤔 Why Koila?
63 |
64 | Ever encountered `RuntimeError: CUDA error: out of memory`?
65 | We all love `PyTorch` because of its speed, efficiency, and transparency, but that also means it doesn't do extra things for you, such as preventing a very common error that has bothered many users since [2017](https://github.com/pytorch/pytorch/issues/958#issuecomment-285090162).
66 |
67 | This library aims to prevent that by being a light-weight wrapper over native `PyTorch`. When a tensor is wrapped, the library **automatically computes the amount of remaining GPU memory and uses the right batch size**, saving everyone from having to manually fine-tune the batch size whenever a model is used.
68 |
69 | The library also automatically matches the batch size to your GPU. Did you know that using bigger batches doesn't always speed up processing? That is handled automatically in this library too.
70 |
71 | Because `Koila` code is `PyTorch` code, and it runs `PyTorch` under the hood, you can use both together without worrying about compatibility.
72 |
73 | Oh, and all that in 1 line of code! 😊
74 |
75 | ## ⬇️ Installation
76 |
77 | `Koila` is available on [PyPI](https://pypi.org/project/koila/). To install, run the following command.
78 |
79 | ```bash
80 | pip install koila
81 | ```
82 |
83 | ## 🏃 Getting started
84 |
85 | The usage is dead simple. For example, suppose you have the following `PyTorch` code (copied from `PyTorch`'s [tutorial](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html)).
86 |
87 | Define the input, label, and model:
88 |
89 | ```python
90 | # A batch of MNIST images
91 | input = torch.randn(8, 28, 28)
92 |
93 | # A batch of labels
94 | label = torch.randint(0, 10, [8])
95 |
96 | class NeuralNetwork(Module):
97 | def __init__(self):
98 | super(NeuralNetwork, self).__init__()
99 | self.flatten = Flatten()
100 | self.linear_relu_stack = Sequential(
101 | Linear(28 * 28, 512),
102 | ReLU(),
103 | Linear(512, 512),
104 | ReLU(),
105 | Linear(512, 10),
106 | )
107 |
108 | def forward(self, x):
109 | x = self.flatten(x)
110 | logits = self.linear_relu_stack(x)
111 | return logits
112 | ```
113 |
114 | Define the loss function, calculate output and losses.
115 |
116 | ```python
117 | loss_fn = CrossEntropyLoss()
118 | nn = NeuralNetwork()
119 | # Calculate losses
120 | out = nn(input)
121 | loss = loss_fn(out, label)
122 |
123 | # Backward pass
124 | nn.zero_grad()
125 | loss.backward()
126 | ```
127 |
128 | Ok. How to adapt the code to use `Koila`'s features?
129 |
130 | You add this line of code (as of v0.1.1):
131 |
132 | ```python
133 | # Wrap the input tensor and label tensor.
134 | # If a batch argument is provided, that dimension of the tensor will be treated as the batch.
135 | # In this case, the first dimension (dim=0) is used as the batch dimension.
136 | (input, label) = lazy(input, label, batch=0)
137 | ```
138 |
139 | Done. You will not run out of memory again.
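For finer-grained control, the test suite (`tests/test_models.py`) also evaluates a `LazyTensor` explicitly, over just a slice of the batch, via `run()`:

```python
from koila import LazyTensor

lazy_input = LazyTensor(input, batch=0)  # dim 0 is the batch dimension
lazy_output = nn(lazy_input)             # shapes are tracked; nothing is computed yet
chunk = lazy_output.run((3, 6))          # evaluate only batch rows 3..6
assert chunk.shape == (3, 10)            # a plain torch.Tensor
```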
140 |
141 | ## 🏋️ How does it work under the hood?
142 |
143 | `CUDA error: out of memory` generally happens in the forward pass, because temporary variables need to be kept in memory.
144 |
145 | `Koila` is a thin wrapper around `PyTorch`. It is inspired by TensorFlow's static/lazy evaluation. By building the graph first and running the model only when necessary, the model has access to all the information it needs to determine how many resources are really required to compute it.
146 |
147 | In terms of memory usage, only **the shapes of temporary variables are required to calculate their memory usage in the model**. For example, `+` takes two tensors of equal size and outputs a tensor of the same size, and `log` takes one tensor and outputs another with the same shape. Broadcasting makes things a little more complicated, but the general idea is the same: by tracking all these shapes, one can easily tell how much memory a forward pass uses, and select the optimal batch size accordingly.
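To make that bookkeeping concrete, here is a minimal sketch of shape-only memory accounting (illustrative, not `Koila`'s internal API; the per-dtype sizes mirror `koila.constants.MEMORY_BYTES`):

```python
import torch

BYTES_PER_ELEMENT = {torch.float32: 4, torch.float16: 2, torch.int64: 8}

def tensor_bytes(shape: tuple, dtype: torch.dtype = torch.float32) -> int:
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * BYTES_PER_ELEMENT[dtype]

# `+` broadcasts to the larger shape; `log` preserves its input's shape.
out_shape = torch.broadcast_shapes((512, 1024), (512, 1))  # -> (512, 1024)
print(tensor_bytes(out_shape))  # 2097152 bytes, i.e. 2 MiB of float32
```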
148 |
149 | ## 🐌 It sounds slow. Is it?
150 |
151 | **NO**. Indeed, calculating shapes and computing the size and memory usage sound like a lot of work. However, keep in mind that even a gigantic model like GPT-3, which has 96 layers, has only a few hundred nodes in its computing graph. Because `Koila`'s algorithms run in linear time, any modern computer will be able to handle a graph like this instantly.
152 |
153 | Most of the compute time is spent on evaluating individual tensors and transferring tensors across devices. And bear in mind that those checks happen in vanilla `PyTorch` anyway. So no, not slow at all.
154 |
155 | ## 🔊 How to pronounce koila?
156 |
157 | This project was originally named _koala_, after the laziest species in the world, because this project is about lazy evaluation of tensors. However, as that name is taken on [PyPI](https://pypi.org/project/koala/), I had no choice but to use another name. `Koila` is a word I made up, pronounced similarly to _voila_ (a French word), so it sounds like _koala_.
158 |
159 | ## ⭐ Give me a star!
160 |
161 | If you like what you see, please consider giving this a star (★)!
162 |
163 | ## 🏗️ Why did I build this, despite similar libraries?
164 |
165 | Why did I go through the trouble and build this project, despite a lot of similar libraries on the internet?
166 |
167 | ### 🔎 Batch size search
168 |
169 | Batch size search is not new. In fact, the mighty popular [Lightning](https://lightning.ai/) has it.
170 |
171 | Lightning's batch size search is deeply integrated into its own ecosystem. You have to use its `DataLoader`, subclass its models, and train your models accordingly. While refactoring supervised learning tasks to use Lightning is relatively easy, it's really painful to do the same with a reinforcement learning code base, where interacting with the environment is a must.
172 |
173 | In comparison, because `Koila` is a super lightweight PyTorch wrapper, it works when PyTorch works, thus providing maximum flexibility and minimal changes to existing code.
174 |
175 | However, note that if you're writing new code, Lightning is recommended, as it enforces a better pattern of code style, which benefits modularity in the long run.
176 |
177 | ### ♏ Symbolic pre-passing
178 |
179 | Likewise, passing an empty tensor through the network to build a computational graph (AKA a **static graph**) isn't a new idea; it is thoroughly explored in the popular [TensorFlow](https://www.tensorflow.org/) library and in [KeOps](https://www.kernel-operations.io/), a similar `PyTorch` wrapper library. These libraries suffer from the fact that debugging programs written in them is unnecessarily complicated. For example, `TensorFlow` was known for its ease of deployment but pain in development, to the point that many users switched to `PyTorch`. During debugging, people like to see what's _inside_ a variable, to check whether it contains an incorrect value. However, because static graphs only define relations, the values are never computed, which makes debugging difficult.
180 |
181 | `Koila` solves that by evaluating eagerly whenever a tensor is converted to a string, an integer, or any other Python value. This enables seamless debugging while maintaining the ability to perform memory management that simply isn't available to a more straightforward `PyTorch` program, which dynamically allocates and frees memory on the fly.
182 |
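A small sketch of that behavior:

```python
import torch
from koila import lazy

a = lazy(torch.tensor([1.0, 2.0]))
b = a + a        # nothing is computed yet; only the output shape is tracked
print(b.shape)   # (2,) -- known without running anything
print(b)         # converting to a string forces evaluation
```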
183 | ## 📝 Todos
184 |
185 | - 😌 Simplify internal workings even further (especially the interaction between `Tensor`s and `LazyTensor`s).
186 | - 🧩 Provide an extensible API for users to write custom functions.
187 | - 🍪 Work with multiple GPUs.
188 |
189 | ## 🚧 Caution
190 |
191 | The code works in many cases, but it's still a work in progress. This is not (yet) a fully `PyTorch`-compatible library due to limited time. Avoid using it in production environments!
192 |
193 | ## 🥰 Contributing and Using
194 |
195 | Openness and inclusiveness are taken very seriously. The code is available under the [Apache License](./LICENSE.md). Please follow the [Code of Conduct](./CODE_OF_CONDUCT.md).
196 |
--------------------------------------------------------------------------------
/tests/test_layers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import (
6 | AvgPool1d,
7 | AvgPool2d,
8 | AvgPool3d,
9 | BatchNorm1d,
10 | BatchNorm2d,
11 | BatchNorm3d,
12 | Conv1d,
13 | Conv2d,
14 | Conv3d,
15 | ConvTranspose1d,
16 | ConvTranspose2d,
17 | ConvTranspose3d,
18 | Dropout,
19 | Embedding,
20 | LayerNorm,
21 | LeakyReLU,
22 | Linear,
23 | MaxPool1d,
24 | MaxPool2d,
25 | MaxPool3d,
26 | ReLU,
27 | Sigmoid,
28 | Softmax,
29 | )
30 |
31 | import koila
32 | from koila import LazyTensor
33 |
34 | from . import common
35 |
36 |
37 | def test_linear_layer() -> None:
38 | arr = torch.randn(7, 11, 13)
39 | la = koila.lazy(arr)
40 | layer = Linear(13, 17)
41 |
42 | out = layer(arr)
43 | assert out.shape == (7, 11, 17)
44 | assert not isinstance(out, LazyTensor)
45 | assert isinstance(out, Tensor)
46 |
47 | assert isinstance(la, LazyTensor)
48 | lo = layer(la)
49 | assert lo.shape == (7, 11, 17)
50 | assert not isinstance(lo, Tensor)
51 | assert isinstance(lo, LazyTensor)
52 | common.assert_isclose(lo.run(), out)
53 |
54 |
55 | def test_batchnorm_layers() -> None:
56 | # 1D
57 | arr = torch.randn(3, 5, 7)
58 | la = koila.lazy(arr)
59 | layer = BatchNorm1d(5)
60 |
61 | out = layer(arr)
62 | assert out.shape == (3, 5, 7)
63 | assert not isinstance(out, LazyTensor)
64 | assert isinstance(out, Tensor)
65 |
66 | assert isinstance(la, LazyTensor)
67 | lo = layer(la)
68 | assert lo.shape == (3, 5, 7)
69 | assert not isinstance(lo, Tensor)
70 | assert isinstance(lo, LazyTensor)
71 | common.assert_isclose(lo.run(), out)
72 |
73 | # 2D
74 | arr = torch.randn(3, 5, 7, 11)
75 | la = koila.lazy(arr)
76 | layer = BatchNorm2d(5)
77 |
78 | out = layer(arr)
79 | assert out.shape == (3, 5, 7, 11)
80 | assert not isinstance(out, LazyTensor)
81 | assert isinstance(out, Tensor)
82 |
83 | assert isinstance(la, LazyTensor)
84 | lo = layer(la)
85 | assert lo.shape == (3, 5, 7, 11)
86 | assert not isinstance(lo, Tensor)
87 | assert isinstance(lo, LazyTensor)
88 | common.assert_isclose(lo.run(), out)
89 |
90 | # 3D
91 | arr = torch.randn(3, 5, 7, 11, 13)
92 | la = koila.lazy(arr)
93 | layer = BatchNorm3d(5)
94 |
95 | out = layer(arr)
96 | assert out.shape == (3, 5, 7, 11, 13)
97 | assert not isinstance(out, LazyTensor)
98 | assert isinstance(out, Tensor)
99 |
100 | assert isinstance(la, LazyTensor)
101 | lo = layer(la)
102 | assert lo.shape == (3, 5, 7, 11, 13)
103 | assert not isinstance(lo, Tensor)
104 | assert isinstance(lo, LazyTensor)
105 | common.assert_isclose(lo.run(), out)
106 |
107 |
108 | def test_layernorm_layers() -> None:
109 | # 1D
110 | arr = torch.randn(3, 5, 7)
111 | la = koila.lazy(arr)
112 | layer = LayerNorm([5, 7])
113 |
114 | out = layer(arr)
115 | assert out.shape == (3, 5, 7)
116 | assert not isinstance(out, LazyTensor)
117 | assert isinstance(out, Tensor)
118 |
119 | assert isinstance(la, LazyTensor)
120 | lo = layer(la)
121 | assert lo.shape == (3, 5, 7)
122 | assert not isinstance(lo, Tensor)
123 | assert isinstance(lo, LazyTensor)
124 | common.assert_isclose(lo.run(), out)
125 |
126 |
127 | def test_dropout_layer() -> None:
128 | arr = torch.randn(7, 11)
129 | la = koila.lazy(arr)
130 | layer = Dropout(p=0.5)
131 |
132 | out = layer(arr)
133 | assert out.shape == (7, 11)
134 | assert not isinstance(out, LazyTensor)
135 | assert isinstance(out, Tensor)
136 |
137 | assert isinstance(la, LazyTensor)
138 | lo = layer(la)
139 | assert lo.shape == (7, 11)
140 | assert not isinstance(lo, Tensor)
141 |     assert isinstance(lo, LazyTensor)  # Dropout is stochastic in training mode, so values are not compared.
142 |
143 |
144 | def test_relu_layer() -> None:
145 | arr = torch.randn(7, 11)
146 | la = koila.lazy(arr)
147 | layer = ReLU()
148 |
149 | out = layer(arr)
150 | assert out.shape == (7, 11)
151 | assert not isinstance(out, LazyTensor)
152 | assert isinstance(out, Tensor)
153 |
154 | assert isinstance(la, LazyTensor)
155 | lo = layer(la)
156 | assert lo.shape == (7, 11)
157 | assert not isinstance(lo, Tensor)
158 | assert isinstance(lo, LazyTensor)
159 | common.assert_isclose(lo.run(), out)
160 |
161 |
162 | def test_leaky_relu_layer() -> None:
163 | arr = torch.randn(7, 11)
164 | la = koila.lazy(arr)
165 | layer = LeakyReLU(negative_slope=0.3)
166 |
167 | out = layer(arr)
168 | assert out.shape == (7, 11)
169 | assert not isinstance(out, LazyTensor)
170 | assert isinstance(out, Tensor)
171 |
172 | assert isinstance(la, LazyTensor)
173 | lo = layer(la)
174 | assert lo.shape == (7, 11)
175 | assert not isinstance(lo, Tensor)
176 | assert isinstance(lo, LazyTensor)
177 | common.assert_isclose(lo.run(), out)
178 |
179 |
180 | def test_sigmoid_layer() -> None:
181 | arr = torch.randn(7, 11)
182 | la = koila.lazy(arr)
183 | layer = Sigmoid()
184 |
185 | out = layer(arr)
186 | assert out.shape == (7, 11)
187 | assert not isinstance(out, LazyTensor)
188 | assert isinstance(out, Tensor)
189 |
190 | assert isinstance(la, LazyTensor)
191 | lo = layer(la)
192 | assert lo.shape == (7, 11)
193 | assert not isinstance(lo, Tensor)
194 | assert isinstance(lo, LazyTensor)
195 | common.assert_isclose(lo.run(), out)
196 |
197 |
198 | def test_softmax_layer() -> None:
199 | arr = torch.randn(7, 11)
200 | la = koila.lazy(arr)
201 | layer = Softmax(dim=-1)
202 |
203 | out = layer(arr)
204 | assert out.shape == (7, 11)
205 | assert not isinstance(out, LazyTensor)
206 | assert isinstance(out, Tensor)
207 |
208 | assert isinstance(la, LazyTensor)
209 | lo = layer(la)
210 | assert lo.shape == (7, 11)
211 | assert not isinstance(lo, Tensor)
212 | assert isinstance(lo, LazyTensor)
213 | common.assert_isclose(lo.run(), out)
214 |
215 |
216 | def test_conv_layer() -> None:
217 | # 1D
218 | arr = torch.randn(7, 11, 13)
219 | la = koila.lazy(arr)
220 | layer = Conv1d(11, 17, kernel_size=3, stride=2)
221 |
222 | out = layer(arr)
223 | assert not isinstance(out, LazyTensor)
224 | assert isinstance(out, Tensor)
225 |
226 | assert isinstance(la, LazyTensor)
227 | lo = layer(la)
228 | assert not isinstance(lo, Tensor)
229 | assert isinstance(lo, LazyTensor)
230 | assert lo.shape == out.shape
231 | common.assert_isclose(lo.run(), out)
232 |
233 | # 2D
234 | arr = torch.randn(7, 11, 13, 14)
235 | la = koila.lazy(arr)
236 | layer = Conv2d(11, 17, kernel_size=3, stride=2)
237 |
238 | out = layer(arr)
239 | assert not isinstance(out, LazyTensor)
240 | assert isinstance(out, Tensor)
241 |
242 | assert isinstance(la, LazyTensor)
243 | lo = layer(la)
244 | assert not isinstance(lo, Tensor)
245 | assert isinstance(lo, LazyTensor)
246 | assert lo.shape == out.shape
247 | common.assert_isclose(lo.run(), out)
248 |
249 | # 3D
250 | arr = torch.randn(7, 11, 13, 14, 15)
251 | la = koila.lazy(arr)
252 | layer = Conv3d(11, 17, kernel_size=3, stride=2)
253 |
254 | out = layer(arr)
255 | assert not isinstance(out, LazyTensor)
256 | assert isinstance(out, Tensor)
257 |
258 | assert isinstance(la, LazyTensor)
259 | lo = layer(la)
260 | assert not isinstance(lo, Tensor)
261 | assert isinstance(lo, LazyTensor)
262 | assert lo.shape == out.shape
263 | common.assert_isclose(lo.run(), out)
264 |
265 |
266 | def test_convtranspose_layer() -> None:
267 | # 1D
268 | arr = torch.randn(7, 11, 13)
269 | la = koila.lazy(arr)
270 | layer = ConvTranspose1d(11, 17, kernel_size=3, stride=2)
271 |
272 | out = layer(arr)
273 | assert not isinstance(out, LazyTensor)
274 | assert isinstance(out, Tensor)
275 |
276 | assert isinstance(la, LazyTensor)
277 | lo = layer(la)
278 | assert not isinstance(lo, Tensor)
279 | assert isinstance(lo, LazyTensor)
280 | assert lo.shape == out.shape
281 | common.assert_isclose(lo.run(), out)
282 |
283 | # 2D
284 | arr = torch.randn(7, 11, 13, 14)
285 | la = koila.lazy(arr)
286 | layer = ConvTranspose2d(11, 17, kernel_size=3, stride=2)
287 |
288 | out = layer(arr)
289 | assert not isinstance(out, LazyTensor)
290 | assert isinstance(out, Tensor)
291 |
292 | assert isinstance(la, LazyTensor)
293 | lo = layer(la)
294 | assert not isinstance(lo, Tensor)
295 | assert isinstance(lo, LazyTensor)
296 | assert lo.shape == out.shape
297 | common.assert_isclose(lo.run(), out)
298 |
299 | # 3D
300 | arr = torch.randn(7, 11, 13, 14, 15)
301 | la = koila.lazy(arr)
302 | layer = ConvTranspose3d(11, 17, kernel_size=3, stride=2)
303 |
304 | out = layer(arr)
305 | assert not isinstance(out, LazyTensor)
306 | assert isinstance(out, Tensor)
307 |
308 | assert isinstance(la, LazyTensor)
309 | lo = layer(la)
310 | assert not isinstance(lo, Tensor)
311 | assert isinstance(lo, LazyTensor)
312 | assert lo.shape == out.shape
313 | common.assert_isclose(lo.run(), out)
314 |
315 |
316 | def test_maxpool_layer() -> None:
317 | # 1D
318 | arr = torch.randn(7, 11, 13)
319 | la = koila.lazy(arr)
320 | layer = MaxPool1d(kernel_size=3, stride=2)
321 |
322 | out = layer(arr)
323 | assert not isinstance(out, LazyTensor)
324 | assert isinstance(out, Tensor)
325 |
326 | assert isinstance(la, LazyTensor)
327 | lo = layer(la)
328 | assert not isinstance(lo, Tensor)
329 | assert isinstance(lo, LazyTensor)
330 | assert lo.shape == out.shape
331 | common.assert_isclose(lo.run(), out)
332 |
333 | # 2D
334 | arr = torch.randn(7, 11, 13, 14)
335 | la = koila.lazy(arr)
336 | layer = MaxPool2d(kernel_size=3, stride=2)
337 |
338 | out = layer(arr)
339 | assert not isinstance(out, LazyTensor)
340 | assert isinstance(out, Tensor)
341 |
342 | assert isinstance(la, LazyTensor)
343 | lo = layer(la)
344 | assert not isinstance(lo, Tensor)
345 | assert isinstance(lo, LazyTensor)
346 | assert lo.shape == out.shape
347 | common.assert_isclose(lo.run(), out)
348 |
349 | # 3D
350 | arr = torch.randn(7, 11, 13, 14, 15)
351 | la = koila.lazy(arr)
352 | layer = MaxPool3d(kernel_size=3, stride=2)
353 |
354 | out = layer(arr)
355 | assert not isinstance(out, LazyTensor)
356 | assert isinstance(out, Tensor)
357 |
358 | assert isinstance(la, LazyTensor)
359 | lo = layer(la)
360 | assert not isinstance(lo, Tensor)
361 | assert isinstance(lo, LazyTensor)
362 | assert lo.shape == out.shape
363 | common.assert_isclose(lo.run(), out)
364 |
365 |
366 | def test_avgpool_layer() -> None:
367 | # 1D
368 | arr = torch.randn(7, 11, 13)
369 | la = koila.lazy(arr)
370 | layer = AvgPool1d(kernel_size=3, stride=2)
371 |
372 | out = layer(arr)
373 | assert not isinstance(out, LazyTensor)
374 | assert isinstance(out, Tensor)
375 |
376 | assert isinstance(la, LazyTensor)
377 | lo = layer(la)
378 | assert not isinstance(lo, Tensor)
379 | assert isinstance(lo, LazyTensor)
380 | assert lo.shape == out.shape
381 | common.assert_isclose(lo.run(), out)
382 |
383 | # 2D
384 | arr = torch.randn(7, 11, 13, 14)
385 | la = koila.lazy(arr)
386 | layer = AvgPool2d(kernel_size=3, stride=2)
387 |
388 | out = layer(arr)
389 | assert not isinstance(out, LazyTensor)
390 | assert isinstance(out, Tensor)
391 |
392 | assert isinstance(la, LazyTensor)
393 | lo = layer(la)
394 | assert not isinstance(lo, Tensor)
395 | assert isinstance(lo, LazyTensor)
396 | assert lo.shape == out.shape
397 | common.assert_isclose(lo.run(), out)
398 |
399 | # 3D
400 | arr = torch.randn(7, 11, 13, 14, 15)
401 | la = koila.lazy(arr)
402 | layer = AvgPool3d(kernel_size=3, stride=2)
403 |
404 | out = layer(arr)
405 | assert not isinstance(out, LazyTensor)
406 | assert isinstance(out, Tensor)
407 |
408 | assert isinstance(la, LazyTensor)
409 | lo = layer(la)
410 | assert not isinstance(lo, Tensor)
411 | assert isinstance(lo, LazyTensor)
412 | assert lo.shape == out.shape
413 | common.assert_isclose(lo.run(), out)
414 |
415 |
416 | def test_embedding_layer() -> None:
417 | arr = torch.randint(0, 11, [5])
418 | la = koila.lazy(arr)
419 | layer = Embedding(num_embeddings=11, embedding_dim=13)
420 |
421 | out = layer(arr)
422 | assert out.shape == (5, 13)
423 | assert not isinstance(out, LazyTensor)
424 | assert isinstance(out, Tensor)
425 |
426 | assert isinstance(la, LazyTensor)
427 | lo = layer(la)
428 | assert lo.shape == (5, 13)
429 | assert not isinstance(lo, Tensor)
430 | assert isinstance(lo, LazyTensor)
431 | common.assert_isclose(lo.run(), out)
432 |
--------------------------------------------------------------------------------
/src/koila/prepasses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | from __future__ import annotations
4 |
5 | import functools
6 | import logging
7 | import math
8 | import operator
9 | from abc import abstractmethod
10 | from dataclasses import dataclass
11 | from typing import (
12 | Any,
13 | List,
14 | Literal,
15 | Protocol,
16 | Sequence,
17 | Tuple,
18 | overload,
19 | runtime_checkable,
20 | )
21 |
22 | from rich.logging import RichHandler
23 | from torch import device as Device
24 | from torch import dtype as DType
25 | from torch.functional import Tensor
26 |
27 | from . import constants, interfaces, shapes
28 | from .errors import UnsupportedError
29 | from .interfaces import BatchInfo, TensorLike
30 |
31 | LOGGER = logging.getLogger(__name__)
32 | LOGGER.addHandler(RichHandler())
33 |
34 |
35 | class CallBack(Protocol):
36 | @abstractmethod
37 | def __call__(self, *args: Any, **kwargs: Any) -> Reducer: ...
38 |
39 |
40 | class Reducer(Protocol):
41 | @abstractmethod
42 | def __call__(self, result: Tensor, /) -> Tensor: ...
43 |
44 |
45 | @dataclass(frozen=True)
46 | class MetaData:
47 | dtype: DType
48 | device: str | Device
49 | batch: BatchInfo | None
50 | reducer: CallBack | None
51 |
52 |
53 | @dataclass
54 | class PrePass:
55 | shape: Tuple[int, ...]
56 | metadata: MetaData
57 |
58 | def __init__(self, shape: Sequence[int], metadata: MetaData) -> None:
59 | self.shape = tuple(shape)
60 | self.metadata = metadata
61 |
62 | def __iter__(self):
63 | return iter(self.shape)
64 |
65 | @overload
66 | def __getitem__(self, index: int) -> int: ...
67 |
68 | @overload
69 | def __getitem__(self, index: slice) -> Tuple[int, ...]: ...
70 |
71 | def __getitem__(self, index: int | slice) -> int | Tuple[int, ...]:
72 | return self.shape[index]
73 |
74 | def __eq__(self, other: Any) -> bool:
75 | if isinstance(other, PrePass):
76 |             return self.shape == other.shape  # comparing `self == other` here would recurse forever
77 |
78 | if isinstance(other, Tuple):
79 | return self.shape == other
80 |
81 | return False
82 |
83 | def dtype(self) -> DType:
84 | return self.metadata.dtype
85 |
86 | def device(self) -> str | Device:
87 | return self.metadata.device
88 |
89 | def batch(self) -> BatchInfo | None:
90 | return self.metadata.batch
91 |
92 | def reducer(self) -> CallBack | None:
93 | return self.metadata.reducer
94 |
95 |
96 | @runtime_checkable
97 | class PrePassFunc(Protocol):
98 | @abstractmethod
99 | def __call__(self, *args: Any, **kwargs: Any) -> PrePass: ...
100 |
101 |
102 | def mute_unused_args(*args: Any, **kwargs: Any) -> None:
103 | del args
104 | del kwargs
105 |
106 |
107 | def trivial(input: Tensor, *args: Any, **kwargs: Any) -> Reducer:
108 | mute_unused_args(input, *args, **kwargs)
109 | return lambda result: result
110 |
111 |
112 | def same(
113 | tensors: Sequence[TensorLike], batch: BatchInfo | None, reducer: CallBack | None
114 | ) -> MetaData:
115 | assert len(tensors) > 0
116 | dtypes = [interfaces.dtyp(t) for t in tensors]
117 |
118 | max_dtype = max(dtypes, key=lambda typ: constants.MEMORY_BYTES[typ])
119 |
120 | devices = [str(interfaces.dev(t)) for t in tensors]
121 |
122 | if len(set(devices)) != 1:
123 | raise ValueError(f"Expected tensors to be on the same device, got {devices}.")
124 |
125 | return MetaData(max_dtype, devices[0], batch, reducer)
126 |
127 |
128 | def identity(input: TensorLike, /, *args: Any, **kwargs: Any) -> PrePass:
129 | mute_unused_args(*args, **kwargs)
130 |
131 | return PrePass(input.size(), same([input], interfaces.bat(input), trivial))
132 |
133 |
134 | def symmetric(
135 | input: TensorLike, other: TensorLike, /, *args: Any, **kwargs: Any
136 | ) -> PrePass:
137 | mute_unused_args(*args, **kwargs)
138 |
139 | shape = shapes.coerce(input.size(), other.size(), broadcast=True, scalars=True)
140 |
141 | if shape is None:
142 | raise ValueError
143 |
144 | batch = None
145 | if (b := interfaces.bat(input)) == interfaces.bat(other):
146 | batch = b
147 |
148 | return PrePass(shape, same([input, other], batch, trivial))
149 |
150 |
151 | def reduce_dims(
152 | input: TensorLike,
153 | /,
154 | dim: int | Tuple[int, ...] | None = None,
155 | keepdim: bool = False,
156 | *args: Any,
157 | **kwargs: Any,
158 | ) -> PrePass:
159 | mute_unused_args(*args, **kwargs)
160 |
161 | (shape, dimensions) = shapes.reduce_dims(input.size(), dim, keepdim)
162 |
163 | if interfaces.bat(input) in dimensions:
164 | batch = None
165 | reducer = None
166 | else:
167 | batch = interfaces.bat(input)
168 | reducer = trivial
169 |
170 | return PrePass(shape, same([input], batch, reducer))
171 |
172 |
173 | def scalars(input: TensorLike, /, *args: Any, **kwargs: Any) -> PrePass:
174 | mute_unused_args(*args, **kwargs)
175 |
176 | return reduce_dims(input, tuple(range(input.dim())))
177 |
178 |
179 | def mean(
180 | input: TensorLike,
181 | /,
182 | dim: int | Tuple[int, ...] | None = None,
183 | keepdim: bool = False,
184 | *args: Any,
185 | **kwargs: Any,
186 | ) -> PrePass:
187 | mute_unused_args(*args, **kwargs)
188 |
189 | (shape, dimensions) = shapes.reduce_dims(input.size(), dim, keepdim)
190 |
191 | if (b := interfaces.bat(input)) in dimensions:
192 | batch = None
193 |
194 | def mean_callback(input: Tensor, *args: Any, **kwargs: Any) -> Reducer:
195 | def reducer(result: Tensor) -> Tensor:
196 | return result * input.size(b) / shape[b]
197 |
198 | return reducer
199 |
200 | reducer = mean_callback
201 | else:
202 | batch = interfaces.bat(input)
203 | reducer = trivial
204 | return PrePass(shape, same([input], batch, reducer))
205 |
206 |
207 | def permute(input: TensorLike, /, *dims: int, **kwargs: Any) -> PrePass:
208 | mute_unused_args(**kwargs)
209 |
210 | mapping = dict(enumerate(dims))
211 |
212 | batch = None
213 | if (b := interfaces.bat(input)) is not None:
214 | batch = b.map(lambda x: mapping[x])
215 |
216 | return PrePass(shapes.permute(input.size(), *dims), same([input], batch, trivial))
217 |
218 |
219 | def reshape(input: TensorLike, /, *shape: int, **kwargs: Any) -> PrePass:
220 | mute_unused_args(**kwargs)
221 |
222 | shape = shapes.reshape(input.size(), *shape)
223 |
224 | batch = None
225 | if (b := interfaces.bat(input)) is not None:
226 | if b in shape:
227 | batch = b.map(shape.index)
228 |
229 | return PrePass(shape, same([input], batch, trivial))
230 |
231 |
232 | def view(input: TensorLike, /, *shape: int, **kwargs: Any) -> PrePass:
233 | mute_unused_args(**kwargs)
234 |
235 | shape = shapes.view(input.size(), *shape)
236 |
237 | batch = None
238 | if (b := interfaces.bat(input)) is not None:
239 | if b in shape:
240 | batch = b.map(shape.index)
241 |
242 | return PrePass(shape, same([input], batch, trivial))
243 |
244 |
245 | def flatten(
246 | input: TensorLike,
247 | /,
248 | start_dim: int = 0,
249 | end_dim: int = -1,
250 | *args: Any,
251 | **kwargs: Any,
252 | ) -> PrePass:
253 | LOGGER.debug("%s, %s, %s", input.size(), start_dim, end_dim)
254 |
255 | mute_unused_args(*args, **kwargs)
256 |
257 | start_dim %= input.dim()
258 | end_dim %= input.dim()
259 |
260 | sizes = input.size()
261 |
262 | shape = (
263 | *sizes[:start_dim],
264 | functools.reduce(operator.mul, sizes[start_dim : end_dim + 1]),
265 | *sizes[end_dim + 1 :],
266 | )
267 |
268 | batch = None
269 | if (b := interfaces.bat(input)) is not None:
270 | if not (start_dim <= b.index <= end_dim):
271 | batch = b
272 |
273 | return PrePass(shape, same([input], batch, trivial))
274 |
275 |
276 | def tranpose(
277 | input: TensorLike, dim0: int, dim1: int, /, *args: Any, **kwargs: Any
278 | ) -> PrePass:
279 | mute_unused_args(*args, **kwargs)
280 |
281 | batch = None
282 | if (b := interfaces.bat(input)) is not None:
283 | batch = b.map(lambda x: {dim0: dim1, dim1: dim0}[x])
284 |
285 | return PrePass(
286 | shapes.tranpose(input.size(), dim0, dim1), same([input], batch, trivial)
287 | )
288 |
289 |
290 | def select(
291 | input: TensorLike,
292 | dim: int | ... | None,
293 | index: int | Tensor,
294 | /,
295 | *args: Any,
296 | **kwargs: Any,
297 | ) -> PrePass:
298 | mute_unused_args(*args, **kwargs)
299 |
300 | shape = input.size()
301 |
302 | if dim is ...:
303 | dim = -1
304 |
305 | if dim is None:
306 | dim = 0
307 | shape = (1,) + shape
308 |
309 | if not -len(shape) <= dim < len(shape):
310 | raise IndexError
311 |
312 | dim %= len(shape)
313 | assert isinstance(dim, int)
314 |
315 | if isinstance(index, Tensor):
316 | sliced_idx = (len(index),)
317 | else:
318 | sliced_idx = ()
319 |
320 | batch = None
321 | if (b := interfaces.bat(input)) != dim:
322 | batch = b
323 |
324 | return PrePass(
325 | shape[:dim] + sliced_idx + shape[dim + 1 :],
326 | same([input], batch, trivial),
327 | )
328 |
329 |
330 | def embedding(
331 | input: TensorLike, weight: TensorLike, /, *args: Any, **kwargs: Any
332 | ) -> PrePass:
333 | mute_unused_args(*args, **kwargs)
334 |
335 | shape = input.size()
336 | return PrePass(
337 | (*shape, weight.size(-1)),
338 | same([input], interfaces.bat(input), trivial),
339 | )
340 |
341 |
342 | def matmul(
343 | input: TensorLike, other: TensorLike, /, *args: Any, **kwargs: Any
344 | ) -> PrePass:
345 | mute_unused_args(*args, **kwargs)
346 |
347 | if (batch := interfaces.bat(input)) != interfaces.bat(other):
348 | raise UnsupportedError
349 |
350 | return PrePass(
351 | shapes.matmul(input.size(), other.size()),
352 | same([input, other], interfaces.bat(input), trivial),
353 | )
354 |
355 |
356 | def loss(
357 | input: TensorLike,
358 | target: TensorLike,
359 | /,
360 | reduction: Literal["none", "mean", "sum"] = "mean",
361 | *args: Any,
362 | **kwargs: Any,
363 | ) -> PrePass:
364 | mute_unused_args(*args, **kwargs)
365 |
366 | # Currently only supports tensors of the same batch size.
367 | if (batch := interfaces.bat(input)) != interfaces.bat(target):
368 | raise UnsupportedError
369 |
370 | output_shape = {
371 | "none": input.size(),
372 | "mean": (),
373 | "sum": (),
374 | }[reduction]
375 |
376 | reducer = {"none": trivial, "mean": trivial, "sum": trivial}[reduction]
377 |
378 | return PrePass(output_shape, same([input, target], batch, reducer))
379 |
380 |
381 | def linear(
382 | input: TensorLike,
383 | weight: TensorLike,
384 | bias: TensorLike | None = None,
385 | *args: Any,
386 | **kwargs: Any,
387 | ) -> PrePass:
388 | mute_unused_args(*args, **kwargs)
389 |
390 | result = shapes.matmul(input.size(), shapes.tranpose(weight.size(), -1, -2))
391 |
392 | if bias is not None:
393 | result = shapes.coerce(result, bias.size())
394 |
395 | if result is None:
396 | raise ValueError
397 |
398 | return PrePass(result, same([input, weight], interfaces.bat(input), trivial))
399 |
400 |
401 | def cat(
402 | tensors: Sequence[TensorLike], dim: int = 0, *args: Any, **kwargs: Any
403 | ) -> PrePass:
404 | mute_unused_args(*args, **kwargs)
405 |
406 | if len(tensors) == 0:
407 | raise ValueError("Expected a sequence of tensors. Got empty sequence.")
408 |
409 |     sizes = [t.size() for t in tensors]  # avoid shadowing the `shapes` module
410 |     no_dim = [t[:dim] + t[dim + 1 :] for t in sizes]
411 |
412 | result_size = no_dim[0]
413 | for size in no_dim[1:]:
414 | if result_size != size:
415 | raise ValueError(
416 | f"Dimension should be equal outside dim {dim}. Got {shapes}."
417 | )
418 |
419 | if len(set(interfaces.bat(t) for t in tensors)) != 1:
420 | raise UnsupportedError
421 |
422 | batch = None
423 | if (b := interfaces.bat(tensors[0])) != dim:
424 | batch = b
425 |
426 |     concat_size = sum(t[dim] for t in sizes)
427 | return PrePass(
428 | [*result_size[:dim], concat_size, *result_size[dim:]],
429 | same(tensors, batch, trivial),
430 | )
431 |
432 |
433 | def pad(input: TensorLike, pad: List[int], *args: Any, **kwargs: Any) -> PrePass:
434 | mute_unused_args(*args, **kwargs)
435 |
436 |     sizes = input.size()  # avoid shadowing the `shapes` module
437 | 
438 |     if len(pad) % 2 == 1:
439 |         raise ValueError(f"Length of pad must be divisible by 2. Got {len(pad)}.")
440 | 
441 |     if len(pad) > (maxlen := len(sizes) * 2):
442 |         raise ValueError(
443 |             f"Padding is way too long. Got {pad}, but {maxlen} is the maximum number of entries allowed."
444 |         )
445 | 
446 |     pad = (2 * len(sizes) - len(pad)) * [0] + list(reversed(pad))
447 | 
448 |     pad0 = pad[0::2]
449 |     pad1 = pad[1::2]
450 | 
451 |     assert len(pad0) == len(pad1) == len(sizes), [pad0, pad1, sizes]
452 | 
453 |     return PrePass(
454 |         [s + p0 + p1 for (s, p0, p1) in zip(sizes, pad0, pad1)],
455 | same([input], interfaces.bat(input), trivial),
456 | )
457 |
458 |
459 | def _int_to_tuple(value: int | Tuple[int, ...], length: int) -> Tuple[int, ...]:
460 | if isinstance(value, int):
461 | return (value,) * length
462 |
463 | assert isinstance(value, Tuple)
464 | assert len(value) == length
465 | return value
466 |
467 |
468 | def conv(
469 | input: TensorLike,
470 | weight: TensorLike,
471 | bias: TensorLike | None = None,
472 | stride: int | Tuple[int, ...] = 1,
473 | padding: int | Tuple[int, ...] | str = "valid",
474 | dilation: int | Tuple[int, ...] = 1,
475 | groups: int = 1,
476 | *args: Any,
477 | **kwargs: Any,
478 | ) -> PrePass:
479 | mute_unused_args(groups, *args, **kwargs)
480 |
481 | (batch, chan, *dims) = input.size()
482 | (out_chan, in_chan, *kernels) = weight.size()
483 |
484 | assert chan == in_chan
485 |
486 | if bias is not None:
487 | assert shapes.coerce(bias.size(), (out_chan,)) is not None
488 |
489 | if isinstance(padding, str):
490 | raise UnsupportedError
491 |
492 | stride = _int_to_tuple(stride, len(dims))
493 | padding = _int_to_tuple(padding, len(dims))
494 | dilation = _int_to_tuple(dilation, len(dims))
495 |
496 | assert len(dims) == len(kernels) == len(stride) == len(padding) == len(dilation)
497 |
498 | out_dims = [
499 | math.floor((dim + 2 * pad - dil * (ker - 1) - 1) / st + 1)
500 | for (dim, pad, dil, ker, st) in zip(dims, padding, dilation, kernels, stride)
501 | ]
502 |
503 | return PrePass(
504 | (batch, out_chan, *out_dims),
505 | same([input, weight], interfaces.bat(input), trivial),
506 | )
507 |
508 |
509 | def conv_transpose(
510 | input: TensorLike,
511 | weight: TensorLike,
512 | bias: TensorLike | None = None,
513 | stride: int | Tuple[int, ...] = 1,
514 | padding: int | Tuple[int, ...] = 0,
515 | output_padding: int | Tuple[int, ...] = 0,
516 | groups: int = 1,
517 | dilation: int | Tuple[int, ...] = 1,
518 | *args: Any,
519 | **kwargs: Any,
520 | ) -> PrePass:
521 | mute_unused_args(groups, *args, **kwargs)
522 |
523 | (batch, chan, *dims) = input.size()
524 | (in_chan, out_chan, *kernels) = weight.size()
525 |
526 | assert chan == in_chan
527 |
528 | if bias is not None:
529 | assert shapes.coerce(bias.size(), (out_chan,)) is not None
530 |
531 | stride = _int_to_tuple(stride, len(dims))
532 | padding = _int_to_tuple(padding, len(dims))
533 | output_padding = _int_to_tuple(output_padding, len(dims))
534 | dilation = _int_to_tuple(dilation, len(dims))
535 |
536 | assert len(dims) == len(kernels) == len(stride) == len(padding) == len(dilation)
537 |
538 | out_dims = [
539 | (dim - 1) * st - 2 * pad + dil * (ker - 1) + opad + 1
540 | for (dim, st, pad, dil, ker, opad) in zip(
541 | dims, stride, padding, dilation, kernels, output_padding
542 | )
543 | ]
544 |
545 | return PrePass(
546 | (batch, out_chan, *out_dims),
547 | same([input, weight], interfaces.bat(input), trivial),
548 | )
549 |
550 |
551 | def pool(
552 | input: TensorLike,
553 | *,
554 | kernel_size: int | Tuple[int, ...],
555 | stride: int | Tuple[int, ...] = (),
556 | padding: int | Tuple[int, ...] = 0,
557 | dilation: int | Tuple[int, ...] = 1,
558 | ceil_mode: bool = False,
559 | ) -> PrePass:
560 | (batch, chan, *dims) = input.size()
561 |
562 | kernel_size = _int_to_tuple(kernel_size, len(dims))
563 | stride = _int_to_tuple(stride, len(dims))
564 | padding = _int_to_tuple(padding, len(dims))
565 | dilation = _int_to_tuple(dilation, len(dims))
566 |
567 | rounding = math.ceil if ceil_mode else math.floor
568 | out_dims = [
569 | rounding((dim + 2 * pad - dil * (ker - 1) - 1) / st + 1)
570 | for (dim, pad, dil, ker, st) in zip(
571 | dims, padding, dilation, kernel_size, stride
572 | )
573 | ]
574 |
575 | return PrePass(
576 | (batch, chan, *out_dims), same([input], interfaces.bat(input), trivial)
577 | )
578 |
579 |
580 | def maxpool(
581 | input: TensorLike,
582 | kernel_size: int | Tuple[int, ...],
583 | stride: int | Tuple[int, ...] = (),
584 | padding: int | Tuple[int, ...] = 0,
585 | dilation: int | Tuple[int, ...] = 1,
586 | ceil_mode: bool = False,
587 | return_indices: bool = False,
588 | *args: Any,
589 | **kwargs: Any,
590 | ) -> PrePass:
591 | mute_unused_args(*args, **kwargs)
592 |
593 | if return_indices:
594 | raise UnsupportedError
595 |
596 | return pool(
597 | input,
598 | kernel_size=kernel_size,
599 | stride=stride,
600 | padding=padding,
601 | dilation=dilation,
602 | ceil_mode=ceil_mode,
603 | )
604 |
605 |
606 | def avgpool(
607 | input: TensorLike,
608 | kernel_size: int | Tuple[int, ...],
609 | stride: int | Tuple[int, ...] = (),
610 | padding: int | Tuple[int, ...] = 0,
611 | ceil_mode: bool = False,
612 | *args: Any,
613 | **kwargs: Any,
614 | ) -> PrePass:
615 | mute_unused_args(*args, **kwargs)
616 |
617 | return pool(
618 | input,
619 | kernel_size=kernel_size,
620 | stride=stride,
621 | padding=padding,
622 | ceil_mode=ceil_mode,
623 | )
624 |
--------------------------------------------------------------------------------
/src/koila/lazy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | from __future__ import annotations
4 |
5 | import builtins
6 | import dataclasses as dcls
7 | import functools
8 | import logging
9 | from dataclasses import dataclass
10 | from functools import wraps
11 | from typing import (
12 | Any,
13 | Callable,
14 | Dict,
15 | Generic,
16 | List,
17 | NamedTuple,
18 | NoReturn,
19 | Sequence,
20 | Tuple,
21 | Type,
22 | TypeVar,
23 | final,
24 | overload,
25 | )
26 |
27 | import torch
28 | from rich.logging import RichHandler
29 | from torch import Tensor, cuda
30 | from torch import device as Device
31 | from torch import dtype as DType
32 |
33 | from . import gpus, interfaces, prepasses
34 | from .errors import UnsupportedError
35 | from .interfaces import BatchInfo, RunnableTensor, TensorLike
36 | from .prepasses import PrePass, PrePassFunc
37 |
38 | T = TypeVar("T")
39 | V = TypeVar("V", contravariant=True)
40 |
41 | LOGGER = logging.getLogger(__name__)
42 | LOGGER.addHandler(RichHandler())
43 |
44 |
45 | @dataclass(frozen=True)
46 | class LazyFunction(Generic[V]):
47 | func: Callable[..., Tensor]
48 | prepass_func: PrePassFunc
49 |
50 | def __call__(self, *args: Any, **kwargs: Any) -> LazyTensor:
51 | lazy_args = tuple(lazy(arg) for arg in args)
52 | lazy_kwargs = dict((k, lazy(v)) for (k, v) in kwargs.items())
53 | prepass = self.prepass_func(*args, **kwargs)
54 | return LazyTensor(Evaluation(self.func, prepass, *lazy_args, **lazy_kwargs))
55 |
56 |     def __get__(self, obj: V, objtype: Type[V]) -> Callable[..., LazyTensor]:
57 |         # When accessed on the class itself, `obj` is None; return the descriptor.
58 |         if obj is None:
59 |             return self
60 |         assert isinstance(obj, objtype), [type(obj), objtype]
61 |         return functools.partial(self, obj)
62 |
63 |
64 | @final
65 | @dataclass(init=False)
66 | class Evaluation(RunnableTensor):
67 | func: Callable[..., Tensor]
68 | prepass: PrePass
69 | args: Tuple[LazyTensor | Tensor | int | float | bool, ...] = dcls.field(
70 | default_factory=tuple
71 | )
72 | kwargs: Dict[str, LazyTensor | Tensor | int | float | bool] = dcls.field(
73 | default_factory=dict
74 | )
75 |
76 | def __init__(
77 | self,
78 | func: Callable[..., Tensor],
79 | prepass: PrePass,
80 | *args: LazyTensor | Tensor | int | float | bool,
81 | **kwargs: LazyTensor | Tensor | int | float | bool,
82 | ) -> None:
83 | self.func = func
84 | self.prepass = prepass
85 | self.args = args
86 | self.kwargs = kwargs
87 |
88 | def __hash__(self) -> int:
89 | # Evaluations are unique.
90 | return id(self)
91 |
92 | def run(self, partial: Tuple[int, int] | None = None) -> Tensor:
93 | real_args = [interfaces.run(arg, partial) for arg in self.args]
94 | real_kwargs = {k: interfaces.run(v, partial) for (k, v) in self.kwargs.items()}
95 |
96 | result = self.func(*real_args, **real_kwargs)
97 |
98 | # Checks the shape only when pre-passing.
99 | # If partial is supplemented, it means the tensors are really evaluated
100 | if partial is None:
101 | assert self.prepass.shape == result.shape, [self.prepass, result.shape]
102 | elif (reducer := self.prepass.reducer()) is None:
103 | raise UnsupportedError("Cannot safely parallelize.")
104 | else:
105 | LOGGER.debug(
106 | "Evaluation taking batch: (%s, %s), low=%s, high=%s",
107 | self.size(),
108 | self.batch(),
109 | partial[0],
110 | partial[1],
111 | )
112 |             callback = reducer(*self.args, **self.kwargs)
113 | result = callback(result)
114 |
115 | return result
116 |
117 | def visit(self, nodes: Dict[int, TensorLike]) -> None:
118 | if hash(self) in nodes.keys():
119 | return
120 |
121 | for arg in self.args:
122 | if isinstance(arg, Tensor):
123 | nodes[hash(arg)] = arg
124 | elif isinstance(arg, RunnableTensor):
125 | arg.visit(nodes)
126 |
127 | for val in self.kwargs.values():
128 | if isinstance(val, Tensor):
129 | nodes[hash(val)] = val
130 | elif isinstance(val, RunnableTensor):
131 | val.visit(nodes)
132 |
133 | assert hash(self) not in nodes.keys()
134 | nodes[hash(self)] = self
135 |
136 | def size(self, dim: int | None = None) -> int | Tuple[int, ...]:
137 | shape = self.prepass.shape
138 | if dim is not None:
139 | return shape[dim]
140 | else:
141 | return shape
142 |
143 | def dtype(self) -> DType:
144 | return self.prepass.dtype()
145 |
146 | def device(self) -> str | Device:
147 | return self.prepass.device()
148 |
149 | def batch(self) -> BatchInfo | None:
150 | return self.prepass.batch()
151 |
152 |
153 | @final
154 | @dataclass(init=False, repr=False)
155 | class LazyTensor(RunnableTensor):
156 | _data: TensorLike
157 | _batch: BatchInfo | None = None
158 |
159 | def __init__(self, data: TensorLike, batch: int | None = None) -> None:
160 | if isinstance(data, LazyTensor):
161 | self._data = data._data
162 | self._batch = data._batch
163 | elif isinstance(data, Evaluation):
164 | self._data = data
165 | self._batch = data.batch()
166 | else:
167 | self._data = data
168 | if batch is None:
169 | self._batch = None
170 | else:
171 | self._batch = BatchInfo(batch, data.size(batch))
172 |
173 | LOGGER.debug("Creating LazyTensor. %s, %s", type(self._data), self._batch)
174 |
175 | # Implementations
176 |
177 | def run(self, partial: Tuple[int, int] | None = None) -> Tensor:
178 | data = self._data
179 | if isinstance(data, Tensor):
180 | if partial is None or self._batch is None:
181 | return data
182 | else:
183 | (low, high) = partial
184 | return data.index_select(
185 | self._batch.index,
186 | torch.tensor(list(range(low, high)), device=data.device),
187 | )
188 | else:
189 | return data.run(partial)
190 |
191 | def visit(self, nodes: Dict[int, TensorLike]) -> None:
192 | data = self._data
193 |
194 | if hash(self) in nodes.keys():
195 | return
196 |
197 | if isinstance(data, Evaluation):
198 | data.visit(nodes)
199 | else:
200 | nodes[hash(self)] = self
201 |
202 | assert hash(self) in nodes.keys()
203 |
204 | @overload
205 | def size(self) -> Tuple[int, ...]: ...
206 |
207 | @overload
208 | def size(self, dim: int) -> int: ...
209 |
210 | def size(self, dim: int | None = None) -> int | Tuple[int, ...]:
211 | data = self._data
212 |
213 | if dim is None:
214 | return data.size()
215 |
216 | return data.size(dim)
217 |
218 | def dtype(self) -> DType:
219 | dt = interfaces.dtyp(self._data)
220 | return dt
221 |
222 | def device(self) -> str | Device:
223 | return interfaces.dev(self._data)
224 |
225 | def batch(self) -> BatchInfo | None:
226 | return self._batch
227 |
228 | # Magic methods
229 |
230 | def __str__(self) -> str:
231 | return f"LazyTensor {self.run()}"
232 |
233 | def __bool__(self) -> bool:
234 | return bool(self.item())
235 |
236 | def __int__(self) -> int:
237 | return int(self.item())
238 |
239 | def __float__(self) -> float:
240 | return float(self.item())
241 |
242 | def __invert__(self) -> bool:
243 | return not bool(self)
244 |
245 | def __pos__(self) -> TensorLike:
246 | return lazy_forward(Tensor.__pos__, prepasses.identity, self)
247 |
248 | def __neg__(self) -> TensorLike:
249 | return lazy_forward(Tensor.__neg__, prepasses.identity, self)
250 |
251 | def __add__(self, other: TensorLike) -> TensorLike:
252 | return lazy_forward(Tensor.__add__, prepasses.symmetric, self, other)
253 |
254 | def __radd__(self, other: TensorLike) -> TensorLike:
255 | return lazy_forward(Tensor.__add__, prepasses.symmetric, other, self)
256 |
257 | def __sub__(self, other: TensorLike) -> TensorLike:
258 | return lazy_forward(Tensor.__sub__, prepasses.symmetric, self, other)
259 |
260 | def __rsub__(self, other: TensorLike) -> TensorLike:
261 | return lazy_forward(Tensor.__sub__, prepasses.symmetric, other, self)
262 |
263 | def __mul__(self, other: TensorLike) -> TensorLike:
264 | return lazy_forward(Tensor.__mul__, prepasses.symmetric, self, other)
265 |
266 | def __rmul__(self, other: TensorLike) -> TensorLike:
267 | return lazy_forward(Tensor.__mul__, prepasses.symmetric, other, self)
268 |
269 | def __truediv__(self, other: TensorLike) -> TensorLike:
270 | return lazy_forward(Tensor.__truediv__, prepasses.symmetric, self, other)
271 |
272 | def __rtruediv__(self, other: TensorLike) -> TensorLike:
273 | return lazy_forward(Tensor.__truediv__, prepasses.symmetric, other, self)
274 |
275 | def __floordiv__(self, other: TensorLike) -> NoReturn:
276 | del other
277 | raise UnsupportedError
278 |
279 | def __rfloordiv__(self, other: TensorLike) -> NoReturn:
280 | del other
281 | raise UnsupportedError
282 |
283 | def __pow__(self, other: TensorLike) -> TensorLike:
284 | return lazy_forward(Tensor.__pow__, prepasses.symmetric, self, other)
285 |
286 | def __rpow__(self, other: TensorLike) -> TensorLike:
287 | return lazy_forward(Tensor.__pow__, prepasses.symmetric, other, self)
288 |
289 | def __mod__(self, other: TensorLike) -> TensorLike:
290 | return lazy_forward(Tensor.__mod__, prepasses.symmetric, self, other)
291 |
292 | def __rmod__(self, other: TensorLike) -> TensorLike:
293 | return lazy_forward(Tensor.__mod__, prepasses.symmetric, other, self)
294 |
295 | def __divmod__(self, other: TensorLike) -> NoReturn:
296 | del other
297 | raise UnsupportedError
298 |
299 | def __rdivmod__(self, other: TensorLike) -> NoReturn:
300 | del other
301 | raise UnsupportedError
302 |
303 | def __abs__(self) -> TensorLike:
304 | return lazy_forward(Tensor.__abs__, prepasses.identity, self)
305 |
306 | def __hash__(self) -> int:
307 | # LazyTensors are not unique. They are defined by their data.
308 | return id(self._data)
309 |
310 | def __matmul__(self, other: TensorLike) -> TensorLike:
311 | return lazy_forward(Tensor.__matmul__, prepasses.matmul, self, other)
312 |
313 | def __rmatmul__(self, other: TensorLike) -> TensorLike:
314 | return lazy_forward(Tensor.__matmul__, prepasses.matmul, other, self)
315 |
316 | def __eq__(self, other: TensorLike) -> TensorLike:
317 | return lazy_forward(Tensor.__eq__, prepasses.symmetric, self, other)
318 |
319 | def __ne__(self, other: TensorLike) -> TensorLike:
320 | return lazy_forward(Tensor.__ne__, prepasses.symmetric, self, other)
321 |
322 | def __gt__(self, other: TensorLike) -> TensorLike:
323 | return lazy_forward(Tensor.__gt__, prepasses.symmetric, self, other)
324 |
325 | def __ge__(self, other: TensorLike) -> TensorLike:
326 | return lazy_forward(Tensor.__ge__, prepasses.symmetric, self, other)
327 |
328 | def __lt__(self, other: TensorLike) -> TensorLike:
329 | return lazy_forward(Tensor.__lt__, prepasses.symmetric, self, other)
330 |
331 | def __le__(self, other: TensorLike) -> TensorLike:
332 | return lazy_forward(Tensor.__le__, prepasses.symmetric, self, other)
333 |
334 | def __len__(self) -> int:
335 | return self.size(0)
336 |
337 | def __getitem__(
338 | self, index: int | slice | Tensor | List[Any] | Tuple[Any] | None
339 | ) -> Tensor:
340 | if isinstance(self._data, RunnableTensor):
341 | data = self._data.run()
342 | else:
343 | data = self._data
344 | return data[index]
345 |
346 | def __setitem__(
347 | self,
348 | index: int | slice | Tensor | List[Any] | Tuple[Any] | None,
349 | value: Tensor,
350 | ) -> None:
351 | if isinstance(self._data, RunnableTensor):
352 | raise UnsupportedError
353 |
354 | self._data[index] = value
355 |
356 | def __getattr__(self, name: str) -> Callable[..., Any]:
357 | LOGGER.debug(
358 | f"__getattr__ called for {name}. Automatically resolving function."
359 | )
360 |
361 | method = getattr(Tensor, name)
362 | wrapper = functools.wraps(method)
363 |
364 | if (custom_impl := CUSTOM_OPS.lookup_method(name)) is not None:
365 | LOGGER.debug("A custom method definition is found.")
366 | partial = functools.partial(custom_impl, self)
367 | elif (shape_impl := SHAPE_OPS.lookup_method(name)) is not None:
368 | LOGGER.debug("A custom shape method is found. Lazy evaluation.")
369 | partial = functools.partial(lazy_forward, method, shape_impl, self)
370 | else:
371 | LOGGER.debug("No custom methods found. Evaluating eagerly.")
372 | partial = functools.partial(method, interfaces.run(self))
373 |
374 | return wrapper(partial)
375 |
376 | @classmethod
377 | def __torch_function__(
378 | cls,
379 | func: Callable[..., Tensor],
380 | types: Tuple[Type[Any], ...],
381 | args: Sequence[TensorLike] = (),
382 | kwargs: Dict[str, TensorLike] | None = None,
383 | ) -> TensorLike:
384 | if kwargs is None:
385 | kwargs = {}
386 |
387 | if not builtins.all(
388 | issubclass(typ, (LazyTensor, Tensor, int, float, bool)) for typ in types
389 | ):
390 | return NotImplemented
391 |
392 | name = func.__name__
393 |
394 | if (custom_impl := CUSTOM_OPS.lookup_function(name)) is not None:
395 | LOGGER.debug("A custom function definition is found.")
396 | return custom_impl(*args, **kwargs)
397 | elif (shape_impl := SHAPE_OPS.lookup_function(name)) is not None:
398 | LOGGER.debug("A custom shape function is found. Lazy evaluation.")
399 | return lazy_forward(func, shape_impl, *args, **kwargs)
400 | else:
401 | LOGGER.debug("No custom method found. Evaluating eagerly.")
402 | args = [interfaces.run(arg) for arg in args]
403 | kwargs = {k: interfaces.run(v) for (k, v) in kwargs.items()}
404 | return func(*args, **kwargs)
405 |
406 | @property
407 | @wraps(Tensor.size)
408 | def shape(self) -> Tuple[int, ...]:
409 | return self.size()
410 |
411 | @property
412 | @wraps(Tensor.dim)
413 | def ndim(self) -> int:
414 | return self.dim()
415 |
416 | @property
417 | @wraps(Tensor.t)
418 | def T(self) -> TensorLike:
419 | return self.t()
420 |
421 | def torch(self) -> Tensor:
422 | return self.run()
423 |
424 | def backward(self) -> None:
425 | if self._batch is None or not cuda.is_available():
426 | LOGGER.debug(
427 | "Unable to parallelize across batches."
428 | " "
429 | "Running backward with native pytorch."
430 | )
431 | self.run().backward()
432 | else:
433 | total = 0
434 | LOGGER.debug("Able to parallelize across batches. Hooray!")
435 | for mini_batch_size in gpus.split_batch(
436 | self.buffer_memory(), self._batch.value
437 | ):
438 | LOGGER.debug("Using mini batch size: %d.", mini_batch_size)
439 | mini_batch = self.run((total, total + mini_batch_size))
440 | total += mini_batch_size
441 | mini_batch.backward()
442 |
443 |
444 | @overload
445 | def lazy(val: Tensor | LazyTensor, batch: int | None = None) -> LazyTensor: ...
446 |
447 |
448 | @overload
449 | def lazy(
450 | *val: Tensor | LazyTensor, batch: int | None = None
451 | ) -> Tuple[LazyTensor, ...]: ...
452 |
453 |
454 | @overload
455 | def lazy(val: int) -> int: ...
456 |
457 |
458 | @overload
459 | def lazy(*val: int) -> Tuple[int, ...]: ...
460 |
461 |
462 | @overload
463 | def lazy(val: float) -> float: ...
464 |
465 |
466 | @overload
467 | def lazy(*val: float) -> Tuple[float, ...]: ...
468 |
469 |
470 | @overload
471 | def lazy(val: bool) -> bool: ...
472 |
473 |
474 | @overload
475 | def lazy(*val: bool) -> Tuple[bool, ...]: ...
476 |
477 |
478 | def lazy(*values: Any, batch: int | None = None) -> Any:
479 | results = []
480 | for val in values:
481 | LOGGER.debug("lazy %s, %s", type(val), interfaces.bat(val))
482 |
483 | if isinstance(val, Tensor):
484 | val = LazyTensor(val, batch)
485 |
486 | results.append(val)
487 |
488 | if len(results) == 1:
489 | return results[0]
490 |
491 | return tuple(results)
492 |
493 |
494 | def lazy_forward(
495 | func: Callable[..., Any], shape_func: PrePassFunc, *args: Any, **kwargs: Any
496 | ) -> TensorLike:
497 | if torch.is_grad_enabled():
498 | out = LazyFunction(func, shape_func)(*args, **kwargs)
499 | LOGGER.debug("lazy forward %s, %s", out.size(), out.batch())
500 | return out
501 | else:
502 | run_args = [interfaces.run(arg) for arg in args]
503 | run_kwargs = {k: interfaces.run(v) for (k, v) in kwargs.items()}
504 | out = func(*run_args, **run_kwargs)
505 | LOGGER.debug("eager forward (%s, %s) -> %s", run_args, run_kwargs, out)
506 | return out
507 |
508 |
509 | # Functions that require special handling.
510 |
511 |
512 | class _ValIdx(NamedTuple):
513 | values: TensorLike
514 | indices: TensorLike
515 |
516 |
517 | @overload
518 | def _min(input: TensorLike) -> TensorLike: ...
519 |
520 |
521 | @overload
522 | def _min(input: TensorLike, dim: int, keepdim: bool = False) -> _ValIdx: ...
523 |
524 |
525 | @overload
526 | def _min(input: TensorLike, other: TensorLike) -> TensorLike: ...
527 |
528 |
529 | @wraps(torch.min)
530 | def _min(input: TensorLike, *args: Any, **kwargs: Any) -> TensorLike | _ValIdx:
531 | if len(args) == len(kwargs) == 0:
532 | return lazy_forward(torch.min, prepasses.reduce_dims, input)
533 |
534 | if (
535 | len(args) == 1
536 | and isinstance((other := args[0]), (Tensor, LazyTensor))
537 | or len(kwargs) == 1
538 |         and (other := kwargs.get("other", None)) is not None
539 | ):
540 | return lazy_forward(torch.minimum, prepasses.symmetric, input, other)
541 |
542 | return _ValIdx(
543 | lazy_forward(torch.amin, prepasses.reduce_dims, input, *args, **kwargs),
544 | lazy_forward(torch.argmin, prepasses.reduce_dims, input, *args, **kwargs),
545 | )
546 |
547 |
548 | @overload
549 | def _max(input: TensorLike) -> TensorLike: ...
550 |
551 |
552 | @overload
553 | def _max(input: TensorLike, dim: int, keepdim: bool = False) -> _ValIdx: ...
554 |
555 |
556 | @overload
557 | def _max(input: TensorLike, other: TensorLike) -> TensorLike: ...
558 |
559 |
560 | @wraps(torch.max)
561 | def _max(input: TensorLike, *args: Any, **kwargs: Any) -> TensorLike | _ValIdx:
562 | if len(args) == len(kwargs) == 0:
563 | return lazy_forward(torch.max, prepasses.reduce_dims, input)
564 |
565 | if (
566 | len(args) == 1
567 | and isinstance((other := args[0]), (Tensor, LazyTensor))
568 | or len(kwargs) == 1
569 |         and (other := kwargs.get("other", None)) is not None
570 | ):
571 | return lazy_forward(torch.maximum, prepasses.symmetric, input, other)
572 |
573 | return _ValIdx(
574 | lazy_forward(torch.amax, prepasses.reduce_dims, input, *args, **kwargs),
575 | lazy_forward(torch.argmax, prepasses.reduce_dims, input, *args, **kwargs),
576 | )
577 |
578 |
579 | def _permute_function_shape(
580 | input: TensorLike, dims: int | Tuple[int, ...], *args: Any, **kwargs: Any
581 | ) -> PrePass:
582 | prepasses.mute_unused_args(*args, **kwargs)
583 |
584 | if isinstance(dims, int):
585 | dims = (dims,)
586 |
587 | return prepasses.permute(input, *dims)
588 |
589 |
590 | def _reshape_function_shape(
591 | input: TensorLike, dims: Tuple[int, ...], *args: Any, **kwargs: Any
592 | ) -> PrePass:
593 | prepasses.mute_unused_args(*args, **kwargs)
594 |
595 | return prepasses.reshape(input, *dims)
596 |
597 |
598 | def _t_shape(input: TensorLike, *args: Any, **kwargs: Any) -> PrePass:
599 | prepasses.mute_unused_args(*args, **kwargs)
600 |
601 | return prepasses.tranpose(input, 0, 1)
602 |
603 |
604 | @dataclass
605 | class MethodFunction(Generic[T]):
606 | method: Dict[str, T]
607 | function: Dict[str, T]
608 |
609 | @staticmethod
610 | def _search(key: str, *dbs: Dict[str, T]) -> T | None:
611 | for db in dbs:
612 | if (value := db.get(key)) is not None:
613 | return value
614 | return None
615 |
616 | def lookup(self, key: str, *dbs: Dict[str, T]) -> T | None:
617 | if (result := self._search(key, *dbs)) is not None:
618 | return result
619 |
620 |         if key.startswith("_"):
621 |             fallback = key.strip("_")  # e.g. "_min" -> "min", "__add__" -> "add"
622 |             return self._search(fallback, *dbs)
623 | return None
624 |
625 | def lookup_method(self, key: str) -> T | None:
626 | return self.lookup(key, self.method, self.function)
627 |
628 | def lookup_function(self, key: str) -> T | None:
629 | return self.lookup(key, self.function)
630 |
631 |
632 | CUSTOM_OPS = MethodFunction[Callable](
633 | method={},
634 | function={
635 | "min": _min,
636 | "max": _max,
637 | },
638 | )
639 |
640 | PARTIAL_OPS = MethodFunction[Callable](method={}, function={"sum": lambda x: x})
641 |
642 | SHAPE_OPS = MethodFunction[PrePassFunc](
643 | method={"permute": prepasses.permute, "view": prepasses.view},
644 | function={
645 | "positive": prepasses.identity,
646 | "negative": prepasses.identity,
647 | "neg": prepasses.identity,
648 | "add": prepasses.symmetric,
649 | "sub": prepasses.symmetric,
650 | "subtract": prepasses.symmetric,
651 | "mul": prepasses.symmetric,
652 | "multiply": prepasses.symmetric,
653 | "div": prepasses.symmetric,
654 | "divide": prepasses.symmetric,
655 | "true_divide": prepasses.symmetric,
656 | "floor": prepasses.identity,
657 | "fmod": prepasses.symmetric,
658 | "remainder": prepasses.symmetric,
659 | "frac": prepasses.identity,
660 | "pow": prepasses.symmetric,
661 | "exp": prepasses.identity,
662 | "exp2": prepasses.identity,
663 | "log": prepasses.identity,
664 | "log2": prepasses.identity,
665 | "log10": prepasses.identity,
666 | "log1p": prepasses.identity,
667 | "abs": prepasses.identity,
668 | "matmul": prepasses.matmul,
669 | "bmm": prepasses.matmul,
670 | "mm": prepasses.matmul,
671 | "mv": prepasses.matmul,
672 | "dot": prepasses.matmul,
673 | "eq": prepasses.symmetric,
674 | "equal": prepasses.symmetric,
675 | "ne": prepasses.symmetric,
676 | "not_equal": prepasses.symmetric,
677 | "gt": prepasses.symmetric,
678 | "greater": prepasses.symmetric,
679 | "ge": prepasses.symmetric,
680 | "greater_equal": prepasses.symmetric,
681 | "lt": prepasses.symmetric,
682 | "less": prepasses.symmetric,
683 | "le": prepasses.symmetric,
684 | "less_equal": prepasses.symmetric,
685 | "mean": prepasses.mean,
686 | "sum": prepasses.reduce_dims,
687 | "std": prepasses.reduce_dims,
688 | "minimum": prepasses.symmetric,
689 | "maximum": prepasses.symmetric,
690 | "amin": prepasses.reduce_dims,
691 | "amax": prepasses.reduce_dims,
692 | "argmin": prepasses.reduce_dims,
693 | "argmax": prepasses.reduce_dims,
694 | "isclose": prepasses.symmetric,
695 | "cat": prepasses.cat,
696 | "t": _t_shape,
697 | "permute": _permute_function_shape,
698 | "reshape": _reshape_function_shape,
699 | "flatten": prepasses.flatten,
700 | "transpose": prepasses.tranpose,
701 | "select": prepasses.select,
702 | "index_select": prepasses.select,
703 | "sin": prepasses.identity,
704 | "cos": prepasses.identity,
705 | "tan": prepasses.identity,
706 | "asin": prepasses.identity,
707 | "acos": prepasses.identity,
708 | "atan": prepasses.identity,
709 | "sinh": prepasses.identity,
710 | "cosh": prepasses.identity,
711 | "tanh": prepasses.identity,
712 | "asinh": prepasses.identity,
713 | "acosh": prepasses.identity,
714 | "atanh": prepasses.identity,
715 | "sigmoid": prepasses.identity,
716 | "hardsigmoid": prepasses.identity,
717 | "softmax": prepasses.identity,
718 | "relu": prepasses.identity,
719 | "relu6": prepasses.identity,
720 | "leaky_relu": prepasses.identity,
721 | "l1_loss": prepasses.loss,
722 | "smooth_l1_loss": prepasses.loss,
723 | "mse_loss": prepasses.loss,
724 | "cross_entropy": prepasses.loss,
725 | "binary_cross_entropy": prepasses.loss,
726 | "binary_cross_entropy_with_logits": prepasses.loss,
727 | "elu": prepasses.identity,
728 | "gelu": prepasses.identity,
729 | "dropout": prepasses.identity,
730 | "batch_norm": prepasses.identity,
731 | "layer_norm": prepasses.identity,
732 | "linear": prepasses.linear,
733 | "embedding": prepasses.embedding,
734 | "pad": prepasses.pad,
735 | "conv1d": prepasses.conv,
736 | "conv2d": prepasses.conv,
737 | "conv3d": prepasses.conv,
738 | "conv_transpose1d": prepasses.conv_transpose,
739 | "conv_transpose2d": prepasses.conv_transpose,
740 | "conv_transpose3d": prepasses.conv_transpose,
741 | "max_pool1d": prepasses.maxpool,
742 | "max_pool2d": prepasses.maxpool,
743 | "max_pool3d": prepasses.maxpool,
744 | "avg_pool1d": prepasses.avgpool,
745 | "avg_pool2d": prepasses.avgpool,
746 | "avg_pool3d": prepasses.avgpool,
747 | # Functions that will not be implemented.
748 | "__floordiv__": UnsupportedError.raise_error,
749 | },
750 | )
751 |
--------------------------------------------------------------------------------
/tests/test_lazy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) RenChu Wang - All Rights Reserved
2 |
3 | import math
4 | import typing
5 |
6 | import numpy as np
7 | import torch
8 | from torch import Tensor
9 | from torch.nn import functional as F
10 |
11 | import koila
12 | from koila import Evaluation, LazyTensor, Runnable, RunnableTensor
13 |
14 | from . import common
15 |
16 |
17 | def test_lazytensor_is_runnable() -> None:
18 | assert issubclass(Evaluation, Runnable)
19 | assert issubclass(Evaluation, RunnableTensor)
20 | assert issubclass(LazyTensor, Runnable)
21 | assert issubclass(LazyTensor, RunnableTensor)
22 |
23 |
24 | def test_positive_op() -> None:
25 | common.call(
26 | lambda a, c: common.assert_isclose((+a).item(), c),
27 | [[LazyTensor(torch.tensor(-11)), -11]],
28 | )
29 |
30 |
31 | def test_positive_method() -> None:
32 | common.call(
33 | lambda a, c: common.assert_isclose(a.positive().item(), c),
34 | [[LazyTensor(torch.tensor(4)), 4]],
35 | )
36 |
37 |
38 | def test_positive_function() -> None:
39 | common.call(
40 | lambda a, c: common.assert_isclose(torch.positive(a).item(), c),
41 | [[LazyTensor(torch.tensor(-8)), -8]],
42 | )
43 |
44 |
45 | def test_negative_op() -> None:
46 | common.call(
47 | lambda a, c: common.assert_isclose((-a).item(), c),
48 | [[LazyTensor(torch.tensor(-13)), 13]],
49 | )
50 |
51 |
52 | def test_negative_method() -> None:
53 | common.call(
54 | lambda a, c: common.assert_isclose(a.neg().item(), c),
55 | [[LazyTensor(torch.tensor(2)), -2]],
56 | )
57 |
58 |
59 | def test_negative_function() -> None:
60 | common.call(
61 | lambda a, c: common.assert_equal(torch.neg(a).item(), c),
62 | [[LazyTensor(torch.tensor(-5)), 5]],
63 | )
64 |
65 |
66 | def test_eq_ne_op() -> None:
67 | arr = torch.randint(0, 2, [2, 3, 4])
68 | brr = torch.randint(0, 2, [2, 3, 4])
69 | la = typing.cast(Tensor, LazyTensor(arr))
70 | lb = typing.cast(Tensor, LazyTensor(brr))
71 | common.call(
72 | lambda a, c: common.assert_equal(koila.run(a), c),
73 | [[la == lb, arr == brr], [la != lb, arr != brr]],
74 | )
75 |
76 |
77 | def test_cmp_op() -> None:
78 | arr = torch.randint(0, 5, [2, 3, 4])
79 | brr = torch.randint(0, 5, [2, 3, 4])
80 | la = typing.cast(Tensor, LazyTensor(arr))
81 | lb = typing.cast(Tensor, LazyTensor(brr))
82 | common.call(
83 | lambda a, c: common.assert_equal(koila.run(a), c),
84 | [
85 | [la < lb, arr < brr],
86 | [la <= lb, arr <= brr],
87 | [la > lb, arr > brr],
88 | [la >= lb, arr >= brr],
89 | ],
90 | )
91 |
92 |
93 | def test_add_op() -> None:
94 | common.call(
95 | lambda a, b, c: common.assert_isclose((a + b).item(), c),
96 | [
97 | [LazyTensor(torch.tensor(1)), LazyTensor(torch.tensor(2)), 1 + 2],
98 | [torch.tensor(1), LazyTensor(torch.tensor(2)), 1 + 2],
99 | [LazyTensor(torch.tensor(1)), torch.tensor(2), 1 + 2],
100 | ],
101 | )
102 |
103 |
104 | def test_add_method() -> None:
105 | common.call(
106 | lambda a, b, c: common.assert_isclose(a.add(b).item(), c),
107 | [
108 | [LazyTensor(torch.tensor(4)), LazyTensor(torch.tensor(3)), 4 + 3],
109 | [torch.tensor(4), LazyTensor(torch.tensor(3)), 4 + 3],
110 | [LazyTensor(torch.tensor(4)), torch.tensor(3), 4 + 3],
111 | ],
112 | )
113 |
114 |
115 | def test_add_function() -> None:
116 | common.call(
117 | lambda a, b, c: common.assert_isclose(torch.add(a, b).item(), c),
118 | [
119 | [LazyTensor(torch.tensor(8)), LazyTensor(torch.tensor(4)), 8 + 4],
120 | [torch.tensor(8), LazyTensor(torch.tensor(4)), 8 + 4],
121 | [LazyTensor(torch.tensor(8)), torch.tensor(4), 8 + 4],
122 | ],
123 | )
124 |
125 |
126 | def test_sub_op() -> None:
127 | common.call(
128 | lambda a, b, c: common.assert_isclose((a - b).item(), c),
129 | [
130 | [LazyTensor(torch.tensor(1)), LazyTensor(torch.tensor(2)), 1 - 2],
131 | [torch.tensor(1), LazyTensor(torch.tensor(2)), 1 - 2],
132 | [LazyTensor(torch.tensor(1)), torch.tensor(2), 1 - 2],
133 | ],
134 | )
135 |
136 |
137 | def test_sub_method() -> None:
138 | common.call(
139 | lambda a, b, c: common.assert_isclose(a.sub(b).item(), c),
140 | [
141 | [LazyTensor(torch.tensor(4)), LazyTensor(torch.tensor(3)), 4 - 3],
142 | [torch.tensor(4), LazyTensor(torch.tensor(3)), 4 - 3],
143 | [LazyTensor(torch.tensor(4)), torch.tensor(3), 4 - 3],
144 | ],
145 | )
146 |
147 |
148 | def test_sub_function() -> None:
149 | common.call(
150 | lambda a, b, c: common.assert_isclose(torch.sub(a, b).item(), c),
151 | [
152 | [LazyTensor(torch.tensor(8)), LazyTensor(torch.tensor(4)), 8 - 4],
153 | [torch.tensor(8), LazyTensor(torch.tensor(4)), 8 - 4],
154 | [LazyTensor(torch.tensor(8)), torch.tensor(4), 8 - 4],
155 | ],
156 | )
157 |
158 |
159 | def test_mul_op() -> None:
160 | common.call(
161 | lambda a, b, c: common.assert_isclose((a * b).item(), c),
162 | [
163 | [LazyTensor(torch.tensor(0.5)), LazyTensor(torch.tensor(2)), 0.5 * 2],
164 | [torch.tensor(0.5), LazyTensor(torch.tensor(2)), 0.5 * 2],
165 | [LazyTensor(torch.tensor(0.5)), torch.tensor(2), 0.5 * 2],
166 | ],
167 | )
168 |
169 |
170 | def test_mul_method() -> None:
171 | common.call(
172 | lambda a, b, c: common.assert_isclose(a.mul(b).item(), c),
173 | [
174 | [LazyTensor(torch.tensor(4)), LazyTensor(torch.tensor(3)), 12],
175 | [torch.tensor(4), LazyTensor(torch.tensor(3)), 12],
176 | [LazyTensor(torch.tensor(4)), torch.tensor(3), 12],
177 | ],
178 | )
179 |
180 |
181 | def test_mul_function() -> None:
182 | common.call(
183 | lambda a, b, c: common.assert_isclose(torch.mul(a, b).item(), c),
184 | [
185 | [LazyTensor(torch.tensor(8)), LazyTensor(torch.tensor(4)), 32],
186 | [torch.tensor(8), LazyTensor(torch.tensor(4)), 32],
187 | [LazyTensor(torch.tensor(8)), torch.tensor(4), 32],
188 | ],
189 | )
190 |
191 |
192 | def test_floordiv_op() -> None:
193 | common.call(
194 | common.is_notimplemented,
195 | [
196 | [lambda: LazyTensor(torch.tensor(1)) // LazyTensor(torch.tensor(2))],
197 | [lambda: torch.tensor(1) // LazyTensor(torch.tensor(2))],
198 | [lambda: LazyTensor(torch.tensor(1)) // torch.tensor(2)],
199 | ],
200 | )
201 |
202 |
203 | def test_floordiv_method() -> None:
204 | common.call(
205 | lambda a, b, c: common.assert_isclose(
206 | a.div(b, rounding_mode="trunc").item(), c
207 | ),
208 | [
209 | [LazyTensor(torch.tensor(4)), LazyTensor(torch.tensor(3)), 4 // 3],
210 | [torch.tensor(4), LazyTensor(torch.tensor(3)), 4 // 3],
211 | [LazyTensor(torch.tensor(4)), torch.tensor(3), 4 // 3],
212 | ],
213 | )
214 |
215 |
216 | def test_floordiv_function() -> None:
217 | common.call(
218 | lambda a, b, c: common.assert_isclose(
219 | torch.div(a, b, rounding_mode="trunc").item(), c
220 | ),
221 | [
222 | [LazyTensor(torch.tensor(9)), LazyTensor(torch.tensor(4)), 9 // 4],
223 | [torch.tensor(9), LazyTensor(torch.tensor(4)), 9 // 4],
224 | [LazyTensor(torch.tensor(9)), torch.tensor(4), 9 // 4],
225 | ],
226 | )
227 |
228 |
229 | def test_truediv_op() -> None:
230 | common.call(
231 | lambda a, b, c: common.assert_isclose((a / b).item(), c),
232 | [
233 | [LazyTensor(torch.tensor(1)), LazyTensor(torch.tensor(2)), 1 / 2],
234 | [torch.tensor(1), LazyTensor(torch.tensor(2)), 1 / 2],
235 | [LazyTensor(torch.tensor(1)), torch.tensor(2), 1 / 2],
236 | ],
237 | )
238 |
239 |
240 | def test_truediv_method() -> None:
241 | common.call(
242 | lambda a, b, c: common.assert_isclose(a.div(b).item(), c),
243 | [
244 | [LazyTensor(torch.tensor(4)), LazyTensor(torch.tensor(3)), 4 / 3],
245 | [torch.tensor(4), LazyTensor(torch.tensor(3)), 4 / 3],
246 | [LazyTensor(torch.tensor(4)), torch.tensor(3), 4 / 3],
247 | ],
248 | )
249 |
250 |
251 | def test_truediv_function() -> None:
252 | common.call(
253 | lambda a, b, c: common.assert_isclose(torch.div(a, b).item(), c),
254 | [
255 | [LazyTensor(torch.tensor(9)), LazyTensor(torch.tensor(4)), 9 / 4],
256 | [torch.tensor(9), LazyTensor(torch.tensor(4)), 9 / 4],
257 | [LazyTensor(torch.tensor(9)), torch.tensor(4), 9 / 4],
258 | ],
259 | )
260 |
261 |
262 | def test_pow_op() -> None:
263 | common.call(
264 | lambda a, b, c: common.assert_isclose((a**b).item(), c),
265 | [
266 | [LazyTensor(torch.tensor(1.5)), LazyTensor(torch.tensor(2)), 1.5**2],
267 | [torch.tensor(1.5), LazyTensor(torch.tensor(2)), 1.5**2],
268 | [LazyTensor(torch.tensor(1.5)), torch.tensor(2), 1.5**2],
269 | ],
270 | )
271 |
272 |
273 | def test_pow_method() -> None:
274 | common.call(
275 | lambda a, b, c: common.assert_isclose(a.pow(b).item(), c),
276 | [
277 | [LazyTensor(torch.tensor(4)), LazyTensor(torch.tensor(3)), 4**3],
278 | [torch.tensor(4), LazyTensor(torch.tensor(3)), 4**3],
279 | [LazyTensor(torch.tensor(4)), torch.tensor(3), 4**3],
280 | ],
281 | )
282 |
283 |
284 | def test_pow_function() -> None:
285 | common.call(
286 | lambda a, b, c: common.assert_isclose(torch.pow(a, b).item(), c),
287 | [
288 | [LazyTensor(torch.tensor(9.0)), LazyTensor(torch.tensor(-2)), 9.0**-2],
289 | [torch.tensor(9.0), LazyTensor(torch.tensor(-2)), 9.0**-2],
290 | [LazyTensor(torch.tensor(9.0)), torch.tensor(-2), 9.0**-2],
291 | ],
292 | )
293 |
294 |
295 | def test_remainder_op() -> None:
296 | common.call(
297 | lambda a, b, c: common.assert_isclose((a % b).item(), c),
298 | [
299 | [LazyTensor(torch.tensor(3.3)), LazyTensor(torch.tensor(1.9)), 3.3 % 1.9],
300 | [torch.tensor(3.3), LazyTensor(torch.tensor(1.9)), 3.3 % 1.9],
301 | [LazyTensor(torch.tensor(3.3)), torch.tensor(1.9), 3.3 % 1.9],
302 | ],
303 | )
304 |
305 |
306 | def test_remainder_method() -> None:
307 | common.call(
308 | lambda a, b, c: common.assert_isclose(a.remainder(b).item(), c),
309 | [
310 | [LazyTensor(torch.tensor(99)), LazyTensor(torch.tensor(7)), 99 % 7],
311 | [torch.tensor(99), LazyTensor(torch.tensor(7)), 99 % 7],
312 | [LazyTensor(torch.tensor(99)), torch.tensor(7), 99 % 7],
313 | ],
314 | )
315 |
316 |
317 | def test_remainder_function() -> None:
318 | common.call(
319 | lambda a, b, c: common.assert_isclose(torch.remainder(a, b).item(), c),
320 | [
321 | [LazyTensor(torch.tensor(25)), LazyTensor(torch.tensor(7.8)), 25 % 7.8],
322 | [torch.tensor(25), LazyTensor(torch.tensor(7.8)), 25 % 7.8],
323 | [LazyTensor(torch.tensor(25)), torch.tensor(7.8), 25 % 7.8],
324 | ],
325 | )
326 |
327 |
328 | def test_matmul_op() -> None:
329 | arr = torch.randn(2, 10, 11)
330 |
331 | common.call(
332 | lambda a, b, c: common.assert_isclose(koila.run(a @ b), c),
333 | [
334 | [LazyTensor(arr[0]), LazyTensor(arr[1].T), arr[0] @ arr[1].T],
335 | [arr[0], LazyTensor(arr[1].T), arr[0] @ arr[1].T],
336 | [LazyTensor(arr[0]), arr[1].T, arr[0] @ arr[1].T],
337 | ],
338 | )
339 |
340 |
341 | def test_matmul_method() -> None:
342 | arr = torch.randn(2, 10, 11)
343 |
344 | common.call(
345 | lambda a, b, c: common.assert_isclose(koila.run(a.matmul(b)), c),
346 | [
347 | [LazyTensor(arr[0]), LazyTensor(arr[1].T), arr[0] @ arr[1].T],
348 | [arr[0], LazyTensor(arr[1].T), arr[0] @ arr[1].T],
349 | [LazyTensor(arr[0]), arr[1].T, arr[0] @ arr[1].T],
350 | ],
351 | )
352 |
353 |
354 | def test_matmul_function() -> None:
355 | arr = torch.randn(2, 10, 11)
356 |
357 | common.call(
358 | lambda a, b, c: common.assert_isclose(koila.run(torch.matmul(a, b)), c),
359 | [
360 | [LazyTensor(arr[0]), LazyTensor(arr[1].T), arr[0] @ arr[1].T],
361 | [arr[0], LazyTensor(arr[1].T), arr[0] @ arr[1].T],
362 | [LazyTensor(arr[0]), arr[1].T, arr[0] @ arr[1].T],
363 | ],
364 | )
365 |
366 |
367 | def test_identity() -> None:
368 | tensor = torch.tensor(13.5)
369 |
370 | assert LazyTensor(tensor).run() == 13.5
371 | assert LazyTensor(tensor).item() == 13.5
372 | assert int(LazyTensor(tensor)) == 13
373 | assert float(LazyTensor(tensor)) == 13.5
374 | assert bool(LazyTensor(tensor))
375 |
376 | tensor = torch.tensor(-17.5)
377 | assert LazyTensor(tensor).run() == -17.5
378 | assert LazyTensor(tensor).item() == -17.5
379 | assert int(LazyTensor(tensor)) == -17
380 | assert float(LazyTensor(tensor)) == -17.5
381 | assert bool(LazyTensor(tensor))
382 |
383 | tensor = torch.tensor(0)
384 | assert not LazyTensor(tensor).run()
385 | assert not LazyTensor(tensor).item()
386 | assert not int(LazyTensor(tensor))
387 | assert not float(LazyTensor(tensor))
388 | assert not bool(LazyTensor(tensor))
389 |
390 |
391 | def test_frac_method() -> None:
392 | common.call(
393 | lambda a, c: common.assert_isclose(a.frac().item(), c),
394 | [
395 | [LazyTensor(torch.tensor(13.22)), 0.22],
396 | [LazyTensor(torch.tensor(55.0)), 0],
397 | [LazyTensor(torch.tensor(-55.55)), -0.55],
398 | ],
399 | )
400 |
401 |
402 | def test_frac_function() -> None:
403 | common.call(
404 | lambda a, c: common.assert_isclose(torch.frac(a).item(), c),
405 | [
406 | [LazyTensor(torch.tensor(25.25)), 0.25],
407 | [LazyTensor(torch.tensor(11.0)), 0],
408 | [LazyTensor(torch.tensor(-25.33)), -0.33],
409 | ],
410 | )
411 |
412 |
413 | def test_exp_method() -> None:
414 | common.call(
415 | lambda a, c: common.assert_isclose(a.exp().item(), c),
416 | [
417 | [LazyTensor(torch.tensor(1.23)), math.e**1.23],
418 | [LazyTensor(torch.tensor(0)), 1],
419 | [LazyTensor(torch.tensor(1)), math.e],
420 | ],
421 | )
422 |
423 |
424 | def test_exp_function() -> None:
425 | common.call(
426 | lambda a, c: common.assert_isclose(torch.exp(a).item(), c),
427 | [
428 | [LazyTensor(torch.tensor(0.41)), math.e**0.41],
429 | [LazyTensor(torch.tensor(0)), 1],
430 | [LazyTensor(torch.tensor(1)), math.e],
431 | ],
432 | )
433 |
434 |
435 | def test_exp2_method() -> None:
436 | common.call(
437 | lambda a, c: common.assert_isclose(a.exp2().item(), c),
438 | [
439 | [LazyTensor(torch.tensor(10)), 2**10],
440 | [LazyTensor(torch.tensor(0)), 1],
441 | [LazyTensor(torch.tensor(1)), 2],
442 | ],
443 | )
444 |
445 |
446 | def test_exp2_function() -> None:
447 | common.call(
448 | lambda a, c: common.assert_isclose(torch.exp2(a).item(), c),
449 | [
450 | [LazyTensor(torch.tensor(-5)), 2**-5],
451 | [LazyTensor(torch.tensor(0)), 1],
452 | [LazyTensor(torch.tensor(1)), 2],
453 | ],
454 | )
455 |
456 |
457 | def test_log_method() -> None:
458 | common.call(
459 | lambda a, c: common.assert_isclose(a.log().item(), c),
460 | [
461 | [LazyTensor(torch.tensor(13)), math.log(13)],
462 | [LazyTensor(torch.tensor(1)), 0],
463 | [LazyTensor(torch.tensor(math.e)), 1],
464 | ],
465 | )
466 |
467 |
468 | def test_log_function() -> None:
469 | common.call(
470 | lambda a, c: common.assert_isclose(torch.log(a).item(), c),
471 | [
472 | [LazyTensor(torch.tensor(5)), math.log(5)],
473 | [LazyTensor(torch.tensor(1)), 0],
474 | [LazyTensor(torch.tensor(math.e)), 1],
475 | ],
476 | )
477 |
478 |
479 | def test_log2_method() -> None:
480 | common.call(
481 | lambda a, c: common.assert_isclose(a.log2().item(), c),
482 | [
483 | [LazyTensor(torch.tensor(442)), math.log2(442)],
484 | [LazyTensor(torch.tensor(1)), 0],
485 | [LazyTensor(torch.tensor(2)), 1],
486 | ],
487 | )
488 |
489 |
490 | def test_log2_function() -> None:
491 | common.call(
492 | lambda a, c: common.assert_isclose(torch.log2(a).item(), c),
493 | [
494 | [LazyTensor(torch.tensor(81)), math.log2(81)],
495 | [LazyTensor(torch.tensor(1)), 0],
496 | [LazyTensor(torch.tensor(2)), 1],
497 | ],
498 | )
499 |
500 |
501 | def test_log10_method() -> None:
502 | common.call(
503 | lambda a, c: common.assert_isclose(a.log10().item(), c),
504 | [
505 | [LazyTensor(torch.tensor(132)), math.log10(132)],
506 | [LazyTensor(torch.tensor(1)), 0],
507 | [LazyTensor(torch.tensor(10)), 1],
508 | ],
509 | )
510 |
511 |
512 | def test_log10_function() -> None:
513 | common.call(
514 | lambda a, c: common.assert_isclose(torch.log10(a).item(), c),
515 | [
516 | [LazyTensor(torch.tensor(979)), math.log10(979)],
517 | [LazyTensor(torch.tensor(1)), 0],
518 | [LazyTensor(torch.tensor(10)), 1],
519 | ],
520 | )
521 |
522 |
523 | def test_log1p_method() -> None:
524 | common.call(
525 | lambda a, c: common.assert_isclose(a.log1p().item(), c),
526 | [[LazyTensor(torch.tensor(1.5)), math.log1p(1.5)]],
527 | )
528 |
529 |
530 | def test_log1p_function() -> None:
531 | common.call(
532 | lambda a, c: common.assert_isclose(torch.log1p(a).item(), c),
533 | [[LazyTensor(torch.tensor(2.7)), math.log1p(2.7)]],
534 | )
535 |
536 |
537 | def test_abs_op() -> None:
538 | common.call(
539 | lambda a, c: common.assert_isclose(abs(a).item(), c),
540 | [
541 | [LazyTensor(torch.tensor(-7.122)), abs(-7.122)],
542 | [LazyTensor(torch.tensor(4.002)), abs(4.002)],
543 | ],
544 | )
545 |
546 |
547 | def test_abs_method() -> None:
548 | common.call(
549 | lambda a, c: common.assert_isclose(a.abs().item(), c),
550 | [
551 | [LazyTensor(torch.tensor(-1.5)), abs(-1.5)],
552 | [LazyTensor(torch.tensor(3.7)), abs(3.7)],
553 | ],
554 | )
555 |
556 |
557 | def test_abs_function() -> None:
558 | common.call(
559 | lambda a, c: common.assert_isclose(torch.abs(a).item(), c),
560 | [
561 | [LazyTensor(torch.tensor(0.001)), abs(0.001)],
562 | [LazyTensor(torch.tensor(-24)), abs(-24)],
563 | ],
564 | )
565 |
566 |
567 | def test_min_method() -> None:
568 | arr = torch.randn(6, 7, 8)
569 |
570 | common.call(
571 | lambda a, c: common.assert_isclose(koila.run(a), c),
572 | [
573 | [LazyTensor(arr).min(), arr.min()],
574 | [LazyTensor(arr).min(1)[0], arr.min(1)[0]],
575 | [LazyTensor(arr).min(1)[1], arr.min(1)[1]],
576 | ],
577 | )
578 |
579 |
580 | def test_min_function() -> None:
581 | arr = torch.randn(6, 7, 8)
582 | brr = torch.randn(1, 7, 8)
583 | la = typing.cast(Tensor, LazyTensor(arr))
584 | lb = typing.cast(Tensor, LazyTensor(brr))
585 |
586 | common.call(
587 | lambda a, c: common.assert_isclose(koila.run(a), c),
588 | [
589 | [torch.min(la), torch.min(arr)],
590 | [torch.min(la, 2)[0], torch.min(arr, 2)[0]],
591 | [
592 | torch.min(la, 1, keepdim=True).indices,
593 | torch.min(arr, 1, keepdim=True).indices,
594 | ],
595 | [torch.min(la, lb), torch.min(arr, brr)],
596 | ],
597 | )
598 |
599 |
600 | def test_max_method() -> None:
601 | arr = torch.randn(6, 7, 8)
602 |
603 | common.call(
604 | lambda a, c: common.assert_isclose(koila.run(a), c),
605 | [
606 | [LazyTensor(arr).max(), arr.max()],
607 | [LazyTensor(arr).max(1)[0], arr.max(1)[0]],
608 | [LazyTensor(arr).max(1)[1], arr.max(1)[1]],
609 | ],
610 | )
611 |
612 |
613 | def test_max_function() -> None:
614 | arr = torch.randn(6, 7, 8)
615 | brr = torch.randn(1, 7, 8)
616 | la = typing.cast(Tensor, LazyTensor(arr))
617 | lb = typing.cast(Tensor, LazyTensor(brr))
618 |
619 | common.call(
620 | lambda a, c: common.assert_isclose(koila.run(a), c),
621 | [
622 | [torch.max(la), torch.max(arr)],
623 | [torch.max(la, 2)[0], torch.max(arr, 2)[0]],
624 | [
625 | torch.max(la, 1, keepdim=True).indices,
626 | torch.max(arr, 1, keepdim=True).indices,
627 | ],
628 | [torch.max(la, lb), torch.max(arr, brr)],
629 | ],
630 | )
631 |
632 |
633 | def test_size_shape_method() -> None:
634 | arr = torch.randn(11, 13)
635 | la = LazyTensor(arr)
636 | assert la.size() == la.shape == (11, 13)
637 | assert la.size(0) == 11
638 | assert la.size(1) == 13
639 |
640 |
641 | def test_t_method() -> None:
642 | arr = torch.randn(11, 13)
643 | la = LazyTensor(arr)
644 | assert la.T.size() == la.t().size() == (13, 11)
645 |
646 |
647 | def test_t_function() -> None:
648 | arr = torch.randn(11, 13)
649 | la = typing.cast(Tensor, LazyTensor(arr))
650 | assert torch.t(la).shape == (13, 11)
651 |
652 |
653 | def test_dim_method() -> None:
654 |     arr = torch.randn(11, 13)
655 |     assert LazyTensor(arr).dim() == arr.ndim == 2
656 |     arr = torch.randn(1, 2, 3, 4, 5)
657 |     assert LazyTensor(arr).dim() == arr.ndim == 5
658 |
659 |
660 | def test_permute_method() -> None:
661 | arr = torch.randn(2, 3, 4, 5, 6)
662 | la = LazyTensor(arr)
663 | assert la.permute(3, 4, 1, 2, 0).shape == (5, 6, 3, 4, 2)
664 | assert la.permute(0, 1, 4, 3, 2).shape == (2, 3, 6, 5, 4)
665 |
666 |
667 | def test_permute_function() -> None:
668 | arr = torch.randn(2, 3, 4, 5, 6)
669 | la = typing.cast(Tensor, LazyTensor(arr))
670 | assert torch.permute(la, (3, 4, 1, 2, 0)).shape == (5, 6, 3, 4, 2)
671 | assert torch.permute(la, (0, 1, 4, 3, 2)).shape == (2, 3, 6, 5, 4)
672 |
673 |
674 | def test_transpose_method() -> None:
675 | arr = torch.randn(2, 3, 4, 5, 6)
676 | la = LazyTensor(arr)
677 | assert la.transpose(3, 4).shape == (2, 3, 4, 6, 5)
678 | assert la.transpose(0, 1).shape == (3, 2, 4, 5, 6)
679 | assert la.transpose(0, 3).shape == (5, 3, 4, 2, 6)
680 |
681 |
682 | def test_select_method() -> None:
683 | arr = torch.randn(3, 4, 5)
684 | sel = arr.select(1, 2)
685 | assert isinstance(sel, Tensor)
686 | assert not isinstance(sel, LazyTensor)
687 |
688 | la = LazyTensor(arr)
689 | lsel = la.select(1, 2)
690 |
691 | assert not isinstance(lsel, Tensor)
692 | assert isinstance(lsel, LazyTensor)
693 | assert sel.size() == lsel.size() == (3, 5)
694 | common.assert_isclose(lsel.run(), sel)
695 |
696 |
697 | def test_select_function() -> None:
698 | arr = torch.randn(3, 4, 5)
699 | sel = torch.select(arr, 1, 2)
700 | assert isinstance(sel, Tensor)
701 | assert not isinstance(sel, LazyTensor)
702 |
703 | la = typing.cast(Tensor, LazyTensor(arr))
704 | lsel = torch.select(la, 1, 2)
705 |
706 | assert not isinstance(lsel, Tensor)
707 | assert isinstance(lsel, LazyTensor)
708 | assert sel.size() == lsel.size() == (3, 5)
709 | common.assert_isclose(lsel.run(), sel)
710 |
711 |
712 | def test_index_select_method() -> None:
713 | arr = torch.randn(3, 4, 5)
714 | idx = torch.tensor([1, 2, 3])
715 | sel = arr.index_select(1, idx)
716 | assert isinstance(sel, Tensor)
717 | assert not isinstance(sel, LazyTensor)
718 |
719 | la = LazyTensor(arr)
720 | lsel = la.index_select(1, idx)
721 |
722 | assert not isinstance(lsel, Tensor)
723 | assert isinstance(lsel, LazyTensor)
724 | assert sel.size() == lsel.size() == (3, 3, 5)
725 | common.assert_isclose(lsel.run(), sel)
726 |
727 |
728 | def test_index_select_function() -> None:
729 | arr = torch.randn(3, 4, 5)
730 | idx = torch.tensor([1, 2, 3])
731 | sel = torch.index_select(arr, 1, idx)
732 | assert isinstance(sel, Tensor)
733 | assert not isinstance(sel, LazyTensor)
734 |
735 | la = typing.cast(Tensor, LazyTensor(arr))
736 | lsel = torch.index_select(la, 1, idx)
737 |
738 | assert not isinstance(lsel, Tensor)
739 | assert isinstance(lsel, LazyTensor)
740 | assert sel.size() == lsel.size() == (3, 3, 5)
741 | common.assert_isclose(lsel.run(), sel)
742 |
743 |
744 | def test_numel_method() -> None:
745 | arr = torch.randn(2, 3, 4, 5, 6)
746 | la = typing.cast(Tensor, LazyTensor(arr))
747 | assert la.numel() == 2 * 3 * 4 * 5 * 6
748 |
749 | arr = torch.randn(15, 19)
750 | la = typing.cast(Tensor, LazyTensor(arr))
751 | assert la.numel() == 15 * 19
752 |
753 |
754 | def test_numel_function() -> None:
755 | arr = torch.randn(2, 3, 4, 5, 6)
756 | la = typing.cast(Tensor, LazyTensor(arr))
757 | assert torch.numel(la) == 2 * 3 * 4 * 5 * 6
758 |
759 | arr = torch.randn(15, 19)
760 | la = typing.cast(Tensor, LazyTensor(arr))
761 | assert torch.numel(la) == 15 * 19
762 |
763 |
764 | def test_sigmoid_method() -> None:
765 | arr = torch.randn(4, 5, 6)
766 | common.call(
767 | lambda a, c: common.assert_isclose(koila.run(a), c),
768 | [[LazyTensor(arr).sigmoid(), torch.sigmoid(arr)]],
769 | )
770 |
771 |
772 | def test_sigmoid_function() -> None:
773 | arr = torch.randn(4, 5, 6)
774 |     la = typing.cast(Tensor, LazyTensor(arr))  # wrap lazily; casting arr alone would compare eager sigmoid with itself
775 | common.call(
776 | lambda a, c: common.assert_isclose(koila.run(a), c),
777 | [[torch.sigmoid(la), torch.sigmoid(arr)]],
778 | )
779 |
780 |
781 | def test_sin_method() -> None:
782 | common.call(
783 | lambda a, c: common.assert_isclose(a.sin().item(), c),
784 | [
785 | [LazyTensor(torch.tensor(0)), 0],
786 | [LazyTensor(torch.tensor(math.pi)), 0],
787 | [LazyTensor(torch.tensor(math.pi / 2)), 1],
788 | [LazyTensor(torch.tensor(3 * math.pi / 2)), -1],
789 | [LazyTensor(torch.tensor(42.0)), math.sin(42)],
790 | [LazyTensor(torch.tensor(-75.0)), math.sin(-75)],
791 | ],
792 | )
793 |
794 |
795 | def test_sin_function() -> None:
796 | common.call(
797 | lambda a, c: common.assert_isclose(torch.sin(a).item(), c),
798 | [
799 | [LazyTensor(torch.tensor(0)), 0],
800 | [LazyTensor(torch.tensor(math.pi)), 0],
801 | [LazyTensor(torch.tensor(math.pi / 2)), 1],
802 | [LazyTensor(torch.tensor(3 * math.pi / 2)), -1],
803 | [LazyTensor(torch.tensor(42.0)), math.sin(42)],
804 | [LazyTensor(torch.tensor(-75.0)), math.sin(-75)],
805 | ],
806 | )
807 |
808 |
809 | def test_cos_method() -> None:
810 | common.call(
811 | lambda a, c: common.assert_isclose(a.cos().item(), c),
812 | [
813 | [LazyTensor(torch.tensor(0)), 1],
814 | [LazyTensor(torch.tensor(math.pi)), -1],
815 | [LazyTensor(torch.tensor(math.pi / 2)), 0],
816 | [LazyTensor(torch.tensor(3 * math.pi / 2)), 0],
817 | [LazyTensor(torch.tensor(27.0)), math.cos(27)],
818 | [LazyTensor(torch.tensor(-14.0)), math.cos(-14)],
819 | ],
820 | )
821 |
822 |
823 | def test_cos_function() -> None:
824 | common.call(
825 | lambda a, c: common.assert_isclose(torch.cos(a).item(), c),
826 | [
827 | [LazyTensor(torch.tensor(0)), 1],
828 | [LazyTensor(torch.tensor(math.pi)), -1],
829 | [LazyTensor(torch.tensor(math.pi / 2)), 0],
830 | [LazyTensor(torch.tensor(3 * math.pi / 2)), 0],
831 | [LazyTensor(torch.tensor(27.0)), math.cos(27)],
832 | [LazyTensor(torch.tensor(-14.0)), math.cos(-14)],
833 | ],
834 | )
835 |
836 |
837 | def test_tan_method() -> None:
838 | common.call(
839 | lambda a, c: common.assert_isclose(a.tan().item(), c),
840 | [
841 | [LazyTensor(torch.tensor(0)), 0],
842 | [LazyTensor(torch.tensor(math.pi)), 0],
843 | [LazyTensor(torch.tensor(99.0)), math.tan(99)],
844 | [LazyTensor(torch.tensor(-4.0)), math.tan(-4)],
845 | ],
846 | )
847 |
848 |
849 | def test_tan_function() -> None:
850 | common.call(
851 | lambda a, c: common.assert_isclose(torch.tan(a).item(), c),
852 | [
853 | [LazyTensor(torch.tensor(0)), 0],
854 | [LazyTensor(torch.tensor(math.pi)), 0],
855 | [LazyTensor(torch.tensor(99.0)), math.tan(99)],
856 | [LazyTensor(torch.tensor(-4.0)), math.tan(-4)],
857 | ],
858 | )
859 |
860 |
861 | def test_asin_method() -> None:
862 | common.call(
863 | lambda a, c: common.assert_isclose(a.asin().item(), c),
864 | [
865 | [LazyTensor(torch.tensor(n)), math.asin(n)]
866 | for n in np.linspace(-1, 1).tolist()
867 | ],
868 | )
869 |
870 |
871 | def test_asin_function() -> None:
872 | common.call(
873 | lambda a, c: common.assert_isclose(torch.asin(a).item(), c),
874 | [
875 | [LazyTensor(torch.tensor(n)), math.asin(n)]
876 | for n in np.linspace(-1, 1).tolist()
877 | ],
878 | )
879 |
880 |
881 | def test_acos_method() -> None:
882 | common.call(
883 | lambda a, c: common.assert_isclose(a.acos().item(), c),
884 | [
885 | [LazyTensor(torch.tensor(n)), math.acos(n)]
886 | for n in np.linspace(-1, 1).tolist()
887 | ],
888 | )
889 |
890 |
891 | def test_acos_function() -> None:
892 | common.call(
893 | lambda a, c: common.assert_isclose(torch.acos(a).item(), c),
894 | [
895 | [LazyTensor(torch.tensor(n)), math.acos(n)]
896 | for n in np.linspace(-1, 1).tolist()
897 | ],
898 | )
899 |
900 |
901 | def test_atan_method() -> None:
902 | common.call(
903 | lambda a, c: common.assert_isclose(a.atan().item(), c),
904 | [
905 | [LazyTensor(torch.tensor(99.0)), math.atan(99)],
906 | [LazyTensor(torch.tensor(-4.0)), math.atan(-4)],
907 | [LazyTensor(torch.tensor(-6.0)), math.atan(-6)],
908 | [LazyTensor(torch.tensor(242.0)), math.atan(242)],
909 | ],
910 | )
911 |
912 |
913 | def test_atan_function() -> None:
914 | common.call(
915 | lambda a, c: common.assert_isclose(torch.atan(a).item(), c),
916 | [
917 | [LazyTensor(torch.tensor(99.0)), math.atan(99)],
918 | [LazyTensor(torch.tensor(-4.0)), math.atan(-4)],
919 | [LazyTensor(torch.tensor(-6.0)), math.atan(-6)],
920 | [LazyTensor(torch.tensor(242.0)), math.atan(242)],
921 | ],
922 | )
923 |
924 |
925 | def test_sinh_method() -> None:
926 | common.call(
927 | lambda a, c: common.assert_isclose(a.sinh().item(), c),
928 | [
929 | [LazyTensor(torch.tensor(n)), math.sinh(n)]
930 | for n in np.linspace(-1, 1).tolist()
931 | ],
932 | )
933 |
934 |
935 | def test_sinh_function() -> None:
936 | common.call(
937 | lambda a, c: common.assert_isclose(torch.sinh(a).item(), c),
938 | [
939 | [LazyTensor(torch.tensor(n)), math.sinh(n)]
940 | for n in np.linspace(-1, 1).tolist()
941 | ],
942 | )
943 |
944 |
945 | def test_cosh_method() -> None:
946 | common.call(
947 | lambda a, c: common.assert_isclose(a.cosh().item(), c),
948 | [
949 | [LazyTensor(torch.tensor(n)), math.cosh(n)]
950 | for n in np.linspace(-1, 1).tolist()
951 | ],
952 | )
953 |
954 |
955 | def test_cosh_function() -> None:
956 | common.call(
957 | lambda a, c: common.assert_isclose(torch.cosh(a).item(), c),
958 | [
959 | [LazyTensor(torch.tensor(n)), math.cosh(n)]
960 | for n in np.linspace(-1, 1).tolist()
961 | ],
962 | )
963 |
964 |
965 | def test_tanh_method() -> None:
966 | common.call(
967 | lambda a, c: common.assert_isclose(a.tanh().item(), c),
968 | [[LazyTensor(torch.tensor(n)), math.tanh(n)] for n in np.linspace(-10, 10)],
969 | )
970 |
971 |
972 | def test_tanh_function() -> None:
973 | common.call(
974 | lambda a, c: common.assert_isclose(torch.tanh(a).item(), c),
975 | [[LazyTensor(torch.tensor(n)), math.tanh(n)] for n in np.linspace(-10, 10)],
976 | )
977 |
978 |
979 | def test_asinh_method() -> None:
980 | common.call(
981 | lambda a, c: common.assert_isclose(a.asinh().item(), c),
982 | [
983 | [LazyTensor(torch.tensor(199.0)), math.asinh(199)],
984 | [LazyTensor(torch.tensor(-241.0)), math.asinh(-241)],
985 | [LazyTensor(torch.tensor(-9.0)), math.asinh(-9)],
986 | [LazyTensor(torch.tensor(0.0)), math.asinh(0)],
987 | ],
988 | )
989 |
990 |
991 | def test_asinh_function() -> None:
992 | common.call(
993 | lambda a, c: common.assert_isclose(torch.asinh(a).item(), c),
994 | [
995 | [LazyTensor(torch.tensor(199.0)), math.asinh(199)],
996 | [LazyTensor(torch.tensor(-241.0)), math.asinh(-241)],
997 | [LazyTensor(torch.tensor(-9.0)), math.asinh(-9)],
998 | [LazyTensor(torch.tensor(0.0)), math.asinh(0)],
999 | ],
1000 | )
1001 |
1002 |
1003 | def test_acosh_method() -> None:
1004 | common.call(
1005 | lambda a, c: common.assert_isclose(a.acosh().item(), c),
1006 | [
1007 | [LazyTensor(torch.tensor(14.0)), math.acosh(14)],
1008 | [LazyTensor(torch.tensor(2.0)), math.acosh(2)],
1009 | [LazyTensor(torch.tensor(1.0)), math.acosh(1)],
1010 | [LazyTensor(torch.tensor(65.0)), math.acosh(65)],
1011 | ],
1012 | )
1013 |
1014 |
1015 | def test_acosh_function() -> None:
1016 | common.call(
1017 | lambda a, c: common.assert_isclose(torch.acosh(a).item(), c),
1018 | [
1019 | [LazyTensor(torch.tensor(14.0)), math.acosh(14)],
1020 | [LazyTensor(torch.tensor(2.0)), math.acosh(2)],
1021 | [LazyTensor(torch.tensor(1.0)), math.acosh(1)],
1022 | [LazyTensor(torch.tensor(65.0)), math.acosh(65)],
1023 | ],
1024 | )
1025 |
1026 |
1027 | def test_atanh_method() -> None:
1028 | common.call(
1029 | lambda a, c: common.assert_isclose(a.atanh().item(), c),
1030 | [
1031 | [LazyTensor(torch.tensor(n)), math.atanh(n)]
1032 | for n in np.linspace(-0.99, 0.99, endpoint=False).tolist()
1033 | ],
1034 | )
1035 |
1036 |
1037 | def test_atanh_function() -> None:
1038 | common.call(
1039 | lambda a, c: common.assert_isclose(torch.atanh(a).item(), c),
1040 | [
1041 | [LazyTensor(torch.tensor(n)), math.atanh(n)]
1042 | for n in np.linspace(-0.99, 0.99, endpoint=False).tolist()
1043 | ],
1044 | )
1045 |
1046 |
1047 | def test_run_method() -> None:
1048 | random = torch.randn(3, 4, 5, 6)
1049 | common.call(
1050 | lambda a, b: common.assert_isclose(a.run(), b), [[LazyTensor(random), random]]
1051 | )
1052 |
1053 |
1054 | def test_torch_method() -> None:
1055 | random = torch.randn(3, 4, 5, 6)
1056 | common.call(
1057 | lambda a, b: common.assert_isclose(a.torch(), b), [[LazyTensor(random), random]]
1058 | )
1059 |
1060 |
1061 | def test_numpy_method() -> None:
1062 | random = torch.randn(3, 4, 5, 6)
1063 | common.call(
1064 | lambda a, b: common.assert_isclose(a.numpy(), b.numpy()),
1065 | [[LazyTensor(random), random]],
1066 | )
1067 |
1068 |
1069 | def test_pad_function() -> None:
1070 | tensor = torch.randn(3, 4, 5, 6)
1071 | padded = F.pad(tensor, (2, 3, 0, 1), mode="reflect")
1072 | assert isinstance(padded, Tensor)
1073 | assert not isinstance(padded, LazyTensor)
1074 |
1075 | la = typing.cast(Tensor, LazyTensor(tensor))
1076 | lazy_padded = F.pad(la, (2, 3, 0, 1), mode="reflect")
1077 | assert not isinstance(lazy_padded, Tensor)
1078 | assert isinstance(lazy_padded, LazyTensor)
1079 | assert padded.shape == lazy_padded.shape
1080 |
1081 | common.assert_isclose(lazy_padded.run(), padded)
1082 |
1083 |
1084 | def test_buffer_sizes() -> None:
1085 | a = torch.randn(4, 5, 6)
1086 |
1087 | la = LazyTensor(a)
1088 | assert a.numel() == la.numel() == la.buffer_numel()[1]
1089 |
1090 | b = torch.randn(4, 5, 1)
1091 | lb = LazyTensor(b)
1092 | assert b.numel() == lb.numel() == lb.buffer_numel()[1]
1093 |
1094 | lc = typing.cast(LazyTensor, la + lb)
1095 | assert lc.numel() == la.numel() == 6 * lb.numel()
1096 | assert lc.buffer_numel()[1] == la.numel() + lb.numel() + lc.numel()
1097 |
1098 | d = torch.randn(4, 5, 6)
1099 |     ld = typing.cast(LazyTensor, d)  # typing.cast is a runtime no-op: ld is still a plain Tensor operand
1100 |
1101 | le = typing.cast(LazyTensor, lc * ld)
1102 | assert d.numel() == ld.numel() == le.numel()
1103 | assert le.buffer_numel()[1] == sum(map(LazyTensor.numel, {la, lb, lc, ld, le}))
1104 |
1105 | lf = le.sum()
1106 | assert lf.buffer_numel()[1] == sum(map(LazyTensor.numel, {la, lb, lc, ld, le, lf}))
1107 |
1108 | lg = typing.cast(LazyTensor, lc + le)
1109 | assert lg.buffer_numel()[1] == sum(map(LazyTensor.numel, {la, lb, lc, ld, le, lg}))
1110 |
1111 |     assert lg.buffer_memory()[1] == lg.buffer_numel()[1] * 4  # float32 tensors: 4 bytes per element
1112 |
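1113 |
1114 | # A minimal end-to-end sketch (an editorial addition that assumes only the
1115 | # APIs already exercised above): chained arithmetic on LazyTensors should
1116 | # stay lazy until explicitly run, and then agree with eager evaluation.
1117 | def test_lazy_chain_sketch() -> None:
1118 |     a = LazyTensor(torch.tensor(2.0))
1119 |     b = LazyTensor(torch.tensor(3.0))
1120 |     c = (a + b) * a  # builds a graph; nothing is evaluated yet
1121 |     assert isinstance(c, LazyTensor)
1122 |     assert not isinstance(c, Tensor)
1123 |     common.assert_isclose(koila.run(c), torch.tensor(10.0))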
--------------------------------------------------------------------------------