├── .flake8 ├── .github ├── black_linter.py └── workflows │ ├── base.yml │ └── lint.yml ├── .gitignore ├── .lintrunner.toml ├── Model_mem_usage.png ├── README.md ├── autograd_monkeypatch.py ├── base_tensor.py ├── bug_zoo.py ├── complex_tensor.py ├── cuda_sanitizer.py ├── custom_parameter.py ├── data_parallel_tensor.py ├── deferred_init.py ├── dispatch_mem_profiler.py ├── dynamic_shapes.ipynb ├── dynamic_shapes.py ├── dynamic_strides.ipynb ├── dynamic_strides.py ├── empty_tensor.py ├── enhanced_error_mode.py ├── failures └── grad_several_ways.py ├── flat_view_tensor.py ├── format.sh ├── functorch_test.py ├── inner_autograd_tensor.py ├── logging_mode.py ├── max_mem_tracker.py ├── memory_debugging_tensor.ipynb ├── memory_debugging_tensor.py ├── nan_detect.py ├── negative_tensor.py ├── nested_forward_ad.py ├── new_device.py ├── numerical_consistency_mode.py ├── progressive_lowering_tensor.py ├── py_dispatcher.py ├── python_meta_tensor.py ├── quantization_transform.py ├── quantized_tensor.py ├── requirements.txt ├── run_test.py ├── simple_functorch.ipynb ├── simple_functorch.py ├── sparse_output.py ├── torchdynamo_dynamic_inference.ipynb ├── torchdynamo_dynamic_inference.py ├── tracer_tensor.py ├── tracing_guards.ipynb ├── tracing_guards.py ├── trivial_tensors.py ├── uint4_tensor.py ├── use_cpu_for_rng.py ├── utils.py └── verifier_tensor.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = B,C,E,F,P,T4,W,B9 3 | max-line-length = 120 4 | # C408 ignored because we like the dict keyword argument syntax 5 | # E501 is not flexible enough, we're using B950 instead 6 | ignore = 7 | E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, 8 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying 9 | # to line this up with executable bit 10 | EXE001, 11 | # these ignores are from flake8-bugbear; please fix! 12 | B007,B008, 13 | # these ignores are from flake8-comprehensions; please fix! 
14 | C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 15 | optional-ascii-coding = True 16 | exclude = ./.git,failures 17 | 18 | -------------------------------------------------------------------------------- /.github/black_linter.py: -------------------------------------------------------------------------------- 1 | # Taken from pytorch/pytorch/tools/linter/adapters/black_linter.py 2 | 3 | import argparse 4 | import concurrent.futures 5 | import json 6 | import logging 7 | import os 8 | import subprocess 9 | import sys 10 | import time 11 | from enum import Enum 12 | from typing import Any, List, NamedTuple, Optional, BinaryIO 13 | 14 | 15 | IS_WINDOWS: bool = os.name == "nt" 16 | 17 | 18 | def eprint(*args: Any, **kwargs: Any) -> None: 19 | print(*args, file=sys.stderr, flush=True, **kwargs) 20 | 21 | 22 | class LintSeverity(str, Enum): 23 | ERROR = "error" 24 | WARNING = "warning" 25 | ADVICE = "advice" 26 | DISABLED = "disabled" 27 | 28 | 29 | class LintMessage(NamedTuple): 30 | path: Optional[str] 31 | line: Optional[int] 32 | char: Optional[int] 33 | code: str 34 | severity: LintSeverity 35 | name: str 36 | original: Optional[str] 37 | replacement: Optional[str] 38 | description: Optional[str] 39 | 40 | 41 | def as_posix(name: str) -> str: 42 | return name.replace("\\", "/") if IS_WINDOWS else name 43 | 44 | 45 | def _run_command( 46 | args: List[str], 47 | *, 48 | stdin: BinaryIO, 49 | timeout: int, 50 | ) -> "subprocess.CompletedProcess[bytes]": 51 | logging.debug("$ %s", " ".join(args)) 52 | start_time = time.monotonic() 53 | try: 54 | return subprocess.run( 55 | args, 56 | stdin=stdin, 57 | stdout=subprocess.PIPE, 58 | stderr=subprocess.PIPE, 59 | shell=IS_WINDOWS, # So batch scripts are found. 60 | timeout=timeout, 61 | check=True, 62 | ) 63 | finally: 64 | end_time = time.monotonic() 65 | logging.debug("took %dms", (end_time - start_time) * 1000) 66 | 67 | 68 | def run_command( 69 | args: List[str], 70 | *, 71 | stdin: BinaryIO, 72 | retries: int, 73 | timeout: int, 74 | ) -> "subprocess.CompletedProcess[bytes]": 75 | remaining_retries = retries 76 | while True: 77 | try: 78 | return _run_command(args, stdin=stdin, timeout=timeout) 79 | except subprocess.TimeoutExpired as err: 80 | if remaining_retries == 0: 81 | raise err 82 | remaining_retries -= 1 83 | logging.warning( 84 | "(%s/%s) Retrying because command failed with: %r", 85 | retries - remaining_retries, 86 | retries, 87 | err, 88 | ) 89 | time.sleep(1) 90 | 91 | 92 | def check_file( 93 | filename: str, 94 | retries: int, 95 | timeout: int, 96 | ) -> List[LintMessage]: 97 | try: 98 | with open(filename, "rb") as f: 99 | original = f.read() 100 | with open(filename, "rb") as f: 101 | proc = run_command( 102 | [sys.executable, "-mblack", "--stdin-filename", filename, "-"], 103 | stdin=f, 104 | retries=retries, 105 | timeout=timeout, 106 | ) 107 | except subprocess.TimeoutExpired: 108 | return [ 109 | LintMessage( 110 | path=filename, 111 | line=None, 112 | char=None, 113 | code="BLACK", 114 | severity=LintSeverity.ERROR, 115 | name="timeout", 116 | original=None, 117 | replacement=None, 118 | description=( 119 | "black timed out while trying to process a file. 
" 120 | "Please report an issue in pytorch/pytorch with the " 121 | "label 'module: lint'" 122 | ), 123 | ) 124 | ] 125 | except (OSError, subprocess.CalledProcessError) as err: 126 | return [ 127 | LintMessage( 128 | path=filename, 129 | line=None, 130 | char=None, 131 | code="BLACK", 132 | severity=LintSeverity.ADVICE, 133 | name="command-failed", 134 | original=None, 135 | replacement=None, 136 | description=( 137 | f"Failed due to {err.__class__.__name__}:\n{err}" 138 | if not isinstance(err, subprocess.CalledProcessError) 139 | else ( 140 | "COMMAND (exit code {returncode})\n" 141 | "{command}\n\n" 142 | "STDERR\n{stderr}\n\n" 143 | "STDOUT\n{stdout}" 144 | ).format( 145 | returncode=err.returncode, 146 | command=" ".join(as_posix(x) for x in err.cmd), 147 | stderr=err.stderr.decode("utf-8").strip() or "(empty)", 148 | stdout=err.stdout.decode("utf-8").strip() or "(empty)", 149 | ) 150 | ), 151 | ) 152 | ] 153 | 154 | replacement = proc.stdout 155 | if original == replacement: 156 | return [] 157 | 158 | return [ 159 | LintMessage( 160 | path=filename, 161 | line=None, 162 | char=None, 163 | code="BLACK", 164 | severity=LintSeverity.WARNING, 165 | name="format", 166 | original=original.decode("utf-8"), 167 | replacement=replacement.decode("utf-8"), 168 | description="Run `lintrunner -a` to apply this patch.", 169 | ) 170 | ] 171 | 172 | 173 | def main() -> None: 174 | parser = argparse.ArgumentParser( 175 | description="Format files with black.", 176 | fromfile_prefix_chars="@", 177 | ) 178 | parser.add_argument( 179 | "--retries", 180 | default=3, 181 | type=int, 182 | help="times to retry timed out black", 183 | ) 184 | parser.add_argument( 185 | "--timeout", 186 | default=90, 187 | type=int, 188 | help="seconds to wait for black", 189 | ) 190 | parser.add_argument( 191 | "--verbose", 192 | action="store_true", 193 | help="verbose logging", 194 | ) 195 | parser.add_argument( 196 | "filenames", 197 | nargs="+", 198 | help="paths to lint", 199 | ) 200 | args = parser.parse_args() 201 | 202 | logging.basicConfig( 203 | format="<%(threadName)s:%(levelname)s> %(message)s", 204 | level=logging.NOTSET 205 | if args.verbose 206 | else logging.DEBUG 207 | if len(args.filenames) < 1000 208 | else logging.INFO, 209 | stream=sys.stderr, 210 | ) 211 | 212 | with concurrent.futures.ThreadPoolExecutor( 213 | max_workers=os.cpu_count(), 214 | thread_name_prefix="Thread", 215 | ) as executor: 216 | futures = { 217 | executor.submit(check_file, x, args.retries, args.timeout): x 218 | for x in args.filenames 219 | } 220 | for future in concurrent.futures.as_completed(futures): 221 | try: 222 | for lint_message in future.result(): 223 | print(json.dumps(lint_message._asdict()), flush=True) 224 | except Exception: 225 | logging.critical('Failed at "%s".', futures[future]) 226 | raise 227 | 228 | 229 | if __name__ == "__main__": 230 | main() 231 | -------------------------------------------------------------------------------- /.github/workflows/base.yml: -------------------------------------------------------------------------------- 1 | name: Base tests 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | linux-test: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout Repo 14 | uses: actions/checkout@v3 15 | with: 16 | ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} 17 | 18 | - name: Setup miniconda 19 | uses: conda-incubator/setup-miniconda@v2 20 | with: 21 | auto-update-conda: true 22 | python-version: 3.8 23 
| activate-environment: build 24 | miniconda-version: 4.7.12 25 | 26 | - name: Install PyTorch 27 | run: | 28 | python3 -mpip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu 29 | python3 -c "import torch;print('All Ok!')" 30 | 31 | - name: Install Dependencies 32 | run: | 33 | python3 -mpip install -r requirements.txt 34 | 35 | - name: Run test 36 | run: | 37 | python run_test.py 38 | 39 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | linux-test: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout Repo 14 | uses: actions/checkout@v3 15 | with: 16 | ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} 17 | 18 | - name: Setup miniconda 19 | uses: conda-incubator/setup-miniconda@v2 20 | with: 21 | auto-update-conda: true 22 | python-version: 3.8 23 | activate-environment: build 24 | miniconda-version: 4.7.12 25 | 26 | - name: Install Dependencies 27 | run: | 28 | python3 -mpip install lintrunner 29 | 30 | - name: Run lint init 31 | run: | 32 | lintrunner init 33 | 34 | - name: Run lint 35 | run: | 36 | lintrunner 37 | 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .gdb_history 2 | *.bak 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 
98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # pytype static type analyzer 145 | .pytype/ 146 | 147 | # Cython debug symbols 148 | cython_debug/ 149 | 150 | # PyCharm 151 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 152 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 153 | # and can be added to the global gitignore or merged into this file. For a more nuclear 154 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 155 | #.idea/ 156 | 157 | # Direnv 158 | .envrc 159 | -------------------------------------------------------------------------------- /.lintrunner.toml: -------------------------------------------------------------------------------- 1 | [[linter]] 2 | code = 'BLACK' 3 | include_patterns = ['**/*.py'] 4 | command = [ 5 | 'python3', 6 | '.github/black_linter.py', 7 | '--', 8 | '@{{PATHSFILE}}' 9 | ] 10 | init_command = [ 11 | 'python3', 12 | '-mpip', 13 | 'install', 14 | 'black==22.3.0', 15 | '{{DRYRUN}}', # Dry run means crash here 16 | ] 17 | is_formatter = true 18 | -------------------------------------------------------------------------------- /Model_mem_usage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albanD/subclass_zoo/ec47458346c2a1cfcd5e676926a4bbc6709ff62e/Model_mem_usage.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # subclass zoo 2 | 3 | This repository contains a number of examples of Tensor subclasses in PyTorch, 4 | specifically using `__torch_dispatch__` to integrate deeply into PyTorch's 5 | existing subsystems (there's also some use of modes as well). We're still 6 | working out good APIs for working with Tensor subclasses, and this repository 7 | is here to tell you about what we've figured out so far! To run these 8 | examples, you will want a recent nightly of PyTorch. 9 | 10 | Here's what's in the repo so far: 11 | 12 | - `inner_autograd_tensor.py` shows how to override autograd from 13 | `__torch_dispatch__`, by deferring autograd to the inner tensor on a 14 | subclass. 
15 | - `negative_tensor.py` is a reimplementation of negative tensor views as 16 | implemented in PyTorch core (https://github.com/pytorch/pytorch/pull/56058) 17 | - `python_meta_tensor.py` is a demonstration of how to extend an existing 18 | tensor (meta tensor) with some extra behavior (in this case, implementations 19 | of meta functions for operations that don't support it natively) 20 | - `sparse_output.py` 21 | - `tracer_tensor.py` 22 | - `trivial_tensors.py` is a comparison for two ways how to "wrap" tensors, 23 | one using inheritance (is-a) and one using composition (has-a) (so called 24 | wrapper tensors) 25 | - `verifier_tensor.py` 26 | 27 | There are also some utility files: 28 | 29 | - `base_tensor.py` contains a common superclass that most of our tensors 30 | inherit from, that fixes up some problems with directly inheriting from 31 | torch.Tensor. We intend to upstream these changes so that this superclass 32 | is not necessary. 33 | - `utils.py` contains some handy utility functions that we found ourselves 34 | repeatedly using in our implementations. 35 | 36 | We're still working on the APIs in questions, so sometimes there will be bugs. 37 | `bug_zoo.py` contains repros for known bugs we're tracking in PyTorch proper. 38 | 39 | TODO 40 | 41 | - CUDA sanitizer in Python (hard cuz no event hooks) 42 | - Sparse gradients / outputs per Christian (using modes; gradients hard cuz 43 | need torch function mode) 44 | - SSD tensor 45 | - Reimplement functionalization tensor 46 | - Nested tensor 47 | - Custom allocator mode (albanD) 48 | - Lazy tensor 49 | - Immutable tensor 50 | - Various ways of writing FX passes https://gist.github.com/1c640ea30fd7451b08e90e34461459c1 51 | 52 | ## Work plan 53 | 54 | * TODO: merge BaseTensor into Tensor 55 | * TODO: torch function disable https://github.com/pytorch/pytorch/pull/73942 56 | * Get rid of `fill_defaults` 57 | 58 | * Compositionality 59 | * TODO: suppress elem in init 60 | 61 | ## Developer notes 62 | 63 | * This repo is formatted with ufmt and autoflakes. Use `./format.sh` to 64 | reformat all the files in this repository. 65 | -------------------------------------------------------------------------------- /autograd_monkeypatch.py: -------------------------------------------------------------------------------- 1 | from torch.overrides import TorchFunctionMode 2 | import torch.nn.functional 3 | 4 | class BuggyDropout(torch.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, x, p=0.5): 7 | print("forward") 8 | return x 9 | 10 | @staticmethod 11 | def backward(ctx, grad_output): 12 | print("backward") 13 | return grad_output, None 14 | 15 | 16 | class AutogradMonkeypatch(TorchFunctionMode): 17 | def __torch_function__(self, func, types, args=(), kwargs=None): 18 | if not kwargs: 19 | kwargs = {} 20 | if func is torch.nn.functional.dropout: 21 | return BuggyDropout.apply(*args) 22 | return func(*args, **kwargs) 23 | 24 | with AutogradMonkeypatch(): 25 | torch.nn.functional.dropout(torch.randn(4, requires_grad=True)).sum().backward() 26 | -------------------------------------------------------------------------------- /base_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # All of the tensor examples in this zoo inherit from BaseTensor. Ideally, 4 | # however, they would inherit directly from Tensor. This is just our staging 5 | # ground for applying behavior that hasn't yet made it into core but that 6 | # we would like to apply by default. 
7 | class BaseTensor(torch.Tensor): 8 | # See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary 9 | # to ensure that super().__new__ can cooperate with each other 10 | @staticmethod 11 | def __new__(cls, elem, *, requires_grad=None): 12 | if requires_grad is None: 13 | return super().__new__(cls, elem) 14 | else: 15 | return cls._make_subclass(cls, elem, requires_grad) 16 | 17 | # To ensure constructors can cooperate with one another, must accept and 18 | # ignore element tensor (TODO: is this right???) 19 | def __init__(self, elem): 20 | super().__init__() 21 | 22 | # If __torch_dispatch__ is defined (which it will be for all our examples) 23 | # the default torch function implementation (which preserves subclasses) 24 | # typically must be disabled 25 | __torch_function__ = torch._C._disabled_torch_function_impl 26 | -------------------------------------------------------------------------------- /bug_zoo.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from base_tensor import BaseTensor 6 | from trivial_tensors import TrivialTensorViaInheritance 7 | from torch.testing._internal.common_utils import run_tests, TestCase 8 | 9 | 10 | class BugZoo(TestCase): 11 | @unittest.expectedFailure 12 | def test_binary_ops_swallow_errors(self): 13 | class BuggyTensor(BaseTensor): 14 | @classmethod 15 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 16 | raise TypeError("foobar") 17 | 18 | x = BuggyTensor(torch.tensor(1.0)) 19 | self.assertRaisesRegex(TypeError, "foobar", lambda: x + x) 20 | 21 | @unittest.skip 22 | def test_super_dispatch_segfault(self): 23 | class SuperDispatchSegfaultTensor(BaseTensor): 24 | @classmethod 25 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 26 | return super().__torch_dispatch__(func, types, list(args), kwargs) 27 | 28 | SuperDispatchSegfaultTensor(torch.tensor(1.0)).neg() 29 | 30 | # Fixed! 31 | def test_trivial_inplace(self): 32 | x = TrivialTensorViaInheritance(torch.tensor(1.0)) 33 | y = x * torch.tensor(1.0, requires_grad=True) 34 | y.relu_() 35 | 36 | # Fixed! 
37 | def test_grad_fn(self): 38 | class TestTensor(BaseTensor): 39 | @classmethod 40 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 41 | if func is torch.ops.aten.add.Tensor and "alpha" in kwargs: 42 | # decompose it 43 | r = torch.add(args[0], args[1] * kwargs["alpha"]) 44 | self.assertIsNone(r.grad_fn) 45 | return r 46 | return super().__torch_dispatch__(func, types, args, kwargs) 47 | 48 | x = TestTensor(torch.tensor(1.0)).requires_grad_() 49 | y = TestTensor(torch.tensor(2.0)).requires_grad_() 50 | torch.add(x, y, alpha=2) 51 | 52 | 53 | if __name__ == "__main__": 54 | run_tests() 55 | -------------------------------------------------------------------------------- /complex_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ComplexTensor(torch.Tensor): 5 | def __new__(cls, re, im): 6 | assert ( 7 | re.device == im.device 8 | and re.layout == im.layout 9 | and re.requires_grad == im.requires_grad 10 | and re.dtype == im.dtype 11 | ) 12 | res = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 13 | cls, 14 | size=re.size(), 15 | strides=re.stride(), # todo: contiguous only 16 | storage_offset=0, 17 | dtype=torch.complex64, # todo: real to complex dtype 18 | layout=re.layout, 19 | device=re.device, 20 | requires_grad=False, # todo: autograd support 21 | ) 22 | res.re = re 23 | res.im = im 24 | return res 25 | 26 | def __torch_dispatch__(self, func, types, args=(), kwargs=None): 27 | if func is torch.ops.aten.mm.default: 28 | assert not kwargs 29 | x, y = args 30 | re = x.re * y.re - x.im * y.im 31 | im = x.re * y.im + x.im * y.re 32 | return ComplexTensor(re, im) 33 | raise NotImplementedError(f"todo {func}") 34 | 35 | def __tensor_flatten__(self): 36 | return ["re", "im"], None 37 | 38 | @staticmethod 39 | def __tensor_unflatten__(inner_tensors, meta): 40 | assert meta is None 41 | re, im = inner_tensors["re"], inner_tensors["im"] 42 | return ComplexTensor(re, im) 43 | 44 | def __repr__(self): 45 | return f"ComplexTensor(real={self.re}, imag={self.im})" 46 | 47 | 48 | if __name__ == "__main__": 49 | 50 | @torch.compile() 51 | def f(x, y): 52 | return x @ y 53 | 54 | x = ComplexTensor(torch.tensor([[1]]), torch.tensor([[2]])) 55 | y = ComplexTensor(torch.tensor([[3]]), torch.tensor([[4]])) 56 | 57 | print(f(x, y)) # (1 + 2i) * (3 + 4i) = -5 + 10i 58 | -------------------------------------------------------------------------------- /cuda_sanitizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils._python_dispatch import TorchDispatchMode 3 | from torch.utils._pytree import tree_map 4 | 5 | # TODO: dedupe from torch._subclasses.fake_tensor 6 | def contains_tensor_types(type): 7 | tensor_type = torch._C.TensorType.get() 8 | return type.isSubtypeOf(tensor_type) or any( 9 | contains_tensor_types(e) for e in type.containedTypes() 10 | ) 11 | 12 | 13 | class CUDASanitizer(TorchDispatchMode): 14 | def __torch_dispatch__(self, func, types, args=(), kwargs=None): 15 | if not kwargs: 16 | kwargs = {} 17 | 18 | # TODO: short circuit dispatch if no CUDA involved 19 | inputs = set() 20 | outputs = set() 21 | 22 | # TODO: a variant of tree map that also gives you the arg 23 | # schema would be pretty handy 24 | schema = func._schema 25 | for i, arg in enumerate(schema.arguments): 26 | if i < len(args): 27 | argument = args[i] 28 | else: 29 | if arg.name not in kwargs: 30 | continue 31 | argument = kwargs[arg.name] 32 | if 
not contains_tensor_types(arg.type): 33 | continue 34 | mut_arg = False 35 | if arg.alias_info: 36 | if arg.alias_info.is_write: 37 | mut_arg = True 38 | if isinstance(argument, torch.Tensor): 39 | if mut_arg: 40 | outputs.add(argument.storage()) 41 | else: 42 | inputs.add(argument.storage()) 43 | else: 44 | raise NotImplementedError("todo tensor list") 45 | 46 | r = func(*args, **kwargs) 47 | 48 | def add_output(t): 49 | if isinstance(t, torch.Tensor): 50 | outputs.add(t.storage()) 51 | 52 | tree_map(add_output, r) 53 | 54 | def render(storage): 55 | stream = torch.cuda.current_stream(storage.device) 56 | return f"ptr {storage.data_ptr():#08x} on stream {stream.cuda_stream:#08x}" 57 | 58 | readonly_str = " ".join(map(render, inputs - outputs)) 59 | readwrite_str = " ".join(map(render, outputs)) 60 | 61 | print(f"launch_kernel inputs {readonly_str} outputs {readwrite_str} # {schema}") 62 | return r 63 | 64 | 65 | with CUDASanitizer.push(): 66 | s1 = torch.cuda.Stream() 67 | s2 = torch.cuda.Stream() 68 | 69 | with torch.cuda.stream(s1): 70 | t = torch.ones((100,), device="cuda:0", requires_grad=True) 71 | 72 | with torch.cuda.stream(s2): 73 | s = t.sum() 74 | s.backward() 75 | -------------------------------------------------------------------------------- /custom_parameter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils._pytree import tree_map, tree_map_only 4 | from torch.utils._python_dispatch import ( 5 | _get_current_dispatch_mode_stack, 6 | is_traceable_wrapper_subclass, 7 | TorchDispatchMode, 8 | ) 9 | 10 | # Subclasses are not very compositional: there is no one true way to 11 | # combine two distinct subclasses into a single one combining both 12 | # of their functionalities. 13 | # 14 | # This file shows a recipe for how to combine a custom parameter subclass 15 | # with a traditional tensor subclass, from Natalia Gimelshein. 16 | 17 | 18 | # First, the custom parameter subclass is just a subclass of nn.Parameter 19 | # that does NOT make use of the __torch_dispatch__ mechanism. Typical 20 | # use cases are to annotate parameters with extra methods and data describing 21 | # information about a Parameter that isn't supported on the base Parameter 22 | # (e.g., sharding). Other than that it doesn't integrate with PyTorch 23 | # in any nontrivial way (if it did, we wouldn't be able to combine it.) 24 | class MyParameter(nn.Parameter): 25 | # This is added to make things work, come back here later 26 | def __new__(cls, data): 27 | if isinstance(data, ModeTensor): 28 | return ModeParameter(data.elem, data.mode) 29 | return super().__new__(cls, data) 30 | 31 | def custom_fn(self): 32 | print("Some custom function") 33 | 34 | 35 | # This is the tensor subclass we want to support. We've written it in the 36 | # same style as FakeTensor, which also supports a FakeTensorMode which can 37 | # be used to automatically cause plain tensors to be transformed into 38 | # ModeTensors. In this particular implementation, you can only work with 39 | # ModeTensor inside the mode, but it's also possible to add a 40 | # __torch_dispatch__ implementation that automatically installs the mode 41 | # when a ModeTensor is used without an active mode. 42 | # 43 | # This subclass is written in wrapper tensor style, so elem is probably 44 | # some real tensor.
45 | class ModeTensor(torch.Tensor): 46 | def __new__(cls, elem, mode): 47 | res = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 48 | cls, 49 | size=elem.size(), 50 | strides=elem.stride(), 51 | storage_offset=elem.storage_offset(), 52 | dtype=elem.dtype, 53 | layout=elem.layout, 54 | device=elem.device, 55 | requires_grad=elem.requires_grad, 56 | ) 57 | 58 | res.elem = elem 59 | res.mode = mode 60 | return res 61 | 62 | @classmethod 63 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 64 | raise NotImplementedError("Shouldn't be here") 65 | 66 | # The mode is pretty trivial, just wrapping/unwrapping. 67 | class Mode(TorchDispatchMode): 68 | def __torch_dispatch__(self, func, types, args=(), kwargs=None): 69 | def unwrap(e): 70 | if isinstance(e, ModeTensor): 71 | return e.elem 72 | else: 73 | return e 74 | 75 | def wrap(t): 76 | if isinstance(t, torch.Tensor): 77 | return ModeTensor(t, self) 78 | else: 79 | return t 80 | 81 | return wrap(func(*tuple(unwrap(a) for a in args), **kwargs)) 82 | 83 | # So, the key to making this all work, is: 84 | # 85 | # 1. You need to make another class that multiply inherits from ModeTensor 86 | # and MyParameter. Order matters as you want to preferentially 87 | # use ModeTensor to handle methods. 88 | # 89 | # 2. You need to update __new__ on MyParameter to redirect to this class 90 | # (above) when you get a ModeTensor as argument, so that 91 | # Parameter(mode_tensor) works. 92 | # 93 | # If your ModeTensor has non-trivial extra data, you have to send all of 94 | # that data to the ModeParameter constructor 95 | class ModeParameter(ModeTensor, MyParameter): 96 | pass 97 | 98 | 99 | # See it in action: 100 | class MyModule(nn.Module): 101 | def __init__(self): 102 | super().__init__() 103 | self.my_param = MyParameter(torch.randn(3, 4)) 104 | 105 | # This works without mode tensor 106 | mod = MyModule() 107 | mod.my_param.custom_fn() 108 | 109 | # Now you get a mode tensor 110 | with Mode(): 111 | mod = MyModule() 112 | print(type(mod.my_param)) 113 | mod.my_param.custom_fn() 114 | -------------------------------------------------------------------------------- /data_parallel_tensor.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | from typing import Any, List, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch._C import NoneType, device 8 | from torch._utils import _get_all_device_indices 9 | from torch.cuda import comm 10 | from torch.utils._pytree import tree_map 11 | 12 | # NOTE: We need to set this because when we lift the module parameters to DataParallelTensors (DPT) using mod._apply, 13 | # we not not want to do an in place copy of the new parameter value, we want to overwrite it. 14 | # The DPT is a list of tensors and hence an in-place copy between the old and new values of the parameter are incompatible. 
15 | torch.__future__.set_overwrite_module_params_on_conversion(True) 16 | import concurrent.futures as futures 17 | 18 | torch.manual_seed(0) 19 | aten = torch.ops.aten 20 | NUM_DEVICES = 8 21 | PARALLEL_DISPATCH = False 22 | ALL_REDUCE = True 23 | 24 | 25 | class DPTensorType(Enum): 26 | # This tensor will be replicated across all the devices 27 | replicated = auto() 28 | # This tensor will be sharded along the first/batch dimension across 29 | # the devices, NOTE: only equal chunk sizes are supported 30 | distributed_batch = auto() 31 | # This is a list of tensors, each of which rests on different devices 32 | distributed = auto() 33 | 34 | 35 | class DataParallelTensor(torch.Tensor): 36 | # This class is a tensor subclass that stores a list of tensors with the aim 37 | # DataParallelTensors(DPT) are categorized in three ways 38 | # 1) replicated: When a single tensor is supplied, it is replicated across 39 | # all the devices by using broadcast 40 | # 2) distributed: DPT can also be initialized by supplying a list/tuple of tensors 41 | # if the elements rest on different devices, they will just be wrapped in DPT 42 | # else the elements are scattered to different devices 43 | # 3) distributed batch: This type of DPT tensor is created by sharding the input tensor across 44 | # a specified sharding dimension (default: 0). Currently only equal chunk sizes are supported. 45 | 46 | elem: List[torch.Tensor] 47 | 48 | if torch.cuda.is_available(): 49 | # device_ids: List[int] = _get_all_device_indices() 50 | device_ids = [i for i in range(NUM_DEVICES)] 51 | if PARALLEL_DISPATCH: 52 | num_threads: int = len(device_ids) 53 | threadpool: futures.ThreadPoolExecutor = futures.ThreadPoolExecutor( 54 | max_workers=num_threads 55 | ) 56 | __slots__ = ["elem"] 57 | 58 | @staticmethod 59 | def __new__( 60 | cls, 61 | elem: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], 62 | func: Optional[Any] = None, 63 | dpt_type: DPTensorType = DPTensorType.replicated, 64 | batch_dim: Optional[int] = 0, 65 | ): 66 | 67 | if dpt_type == DPTensorType.replicated: 68 | # NOTE: If the input is None, we return None 69 | if elem is None: 70 | return None 71 | assert isinstance(elem, torch.Tensor) 72 | # NOTE: For handling meta tensors, if the device of an input tensor is meta, 73 | # we just return the first element in such a list/tuple 74 | if elem.device == device("meta"): 75 | return elem 76 | 77 | with torch.no_grad(): 78 | dp_tensor: List[torch.Tensor] = comm.broadcast( 79 | elem, devices=cls.device_ids 80 | ) 81 | 82 | elif dpt_type == DPTensorType.distributed: 83 | assert isinstance(elem, list) or isinstance(elem, tuple) 84 | # We check if the first elemnt of the list/tuple is a tensor 85 | if isinstance(elem[0], torch.Tensor): 86 | # Make a check to see if all elements are of type tensor 87 | assert all(isinstance(e, torch.Tensor) for e in elem) 88 | requires_scatter: bool = False 89 | with torch.no_grad(): 90 | for t, d_id in zip(elem, cls.device_ids): 91 | if t.device == device("meta"): 92 | # NOTE: For handling meta tensors, if the device of any tensor in such a list/tuple is meta, 93 | # we just return the first element in such a list/tuple. This usually happens for factory functions, 94 | # like torch.ones or torch.zeros generated either during forward or backward mode autodiff. 
95 | # we cannot check the equality of elemts in here since they do not exist physically 96 | # we just check that all of them should be meta tensors 97 | if all(e.device == torch.device("meta") for e in elem): 98 | return elem[0] 99 | else: 100 | raise TypeError( 101 | f"Device error in {func}: Not all tensors are meta." 102 | ) 103 | if t.device != device(d_id): 104 | requires_scatter = True 105 | break 106 | 107 | if requires_scatter: 108 | # We first stack all the tensors in the list/tuple along dimension 0, to get a single tensor 109 | # We then scatter the tensor along the 0th dimension to different devices 110 | # The scatter function returns a list of tensors with a redundant 0th dimension for each element 111 | # We squeeze out the redundant dimension from each of these elements to finally get a list of tensors 112 | # each residing on a list of devices 113 | stacked_t: torch.Tensor = torch.stack(elem, dim=0) 114 | scattered_t: Tuple[torch.Tensor] = comm.scatter( 115 | stacked_t, devices=cls.device_ids, dim=0 116 | ) 117 | dp_tensor: List[torch.Tensor] = [ 118 | torch.squeeze(t, dim=0) for t in scattered_t 119 | ] 120 | else: 121 | dp_tensor: List[torch.Tensor] = elem 122 | else: 123 | # Elements of the list/tuple are non-tensors. 124 | # NOTE: If the list contains non-tensor types then we return a single value only if all of them have identical value. 125 | if all(v == elem[0] for v in elem): 126 | return elem[0] 127 | else: 128 | raise ValueError( 129 | f"Operation {func} retuns non-identical non-tensor values for some elemnts of DPT" 130 | ) 131 | 132 | elif dpt_type == DPTensorType.distributed_batch: 133 | # NOTE: This requires the batch dimension to be divisible by the number of devices. 134 | assert isinstance(elem, torch.Tensor) 135 | 136 | with torch.no_grad(): 137 | scattered_t: Tuple[torch.Tensor] = comm.scatter( 138 | elem, devices=cls.device_ids, dim=batch_dim 139 | ) 140 | dp_tensor: List[torch.Tensor] = list(scattered_t) 141 | 142 | meta_t: torch.Tensor = ( 143 | elem if dpt_type == DPTensorType.replicated else dp_tensor[0] 144 | ) 145 | 146 | r = torch.Tensor._make_wrapper_subclass( 147 | cls, 148 | meta_t.size(), 149 | strides=meta_t.stride(), 150 | storage_offset=meta_t.storage_offset(), 151 | device=meta_t.device, # This is the device of of either input tensor or first tensor of a list 152 | dtype=meta_t.dtype, 153 | layout=meta_t.layout, 154 | requires_grad=meta_t.requires_grad, 155 | ) 156 | r.elem = dp_tensor 157 | return r 158 | 159 | def __repr__(self): 160 | if self.grad_fn: 161 | return f"DataParallelTensor({self.elem}, grad_fn={self.grad_fn})" 162 | return f"DataParallelTensor({self.elem})" 163 | 164 | @classmethod 165 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 166 | def wrap(e): 167 | if isinstance(e, DataParallelTensor): 168 | return e 169 | elif isinstance(e, torch.Tensor): 170 | return DataParallelTensor(e, func, DPTensorType.replicated) 171 | else: 172 | return e 173 | 174 | # All the args and kwargs are checked and any leaf tensors are wrapped as replicated DPTs 175 | args = tree_map(wrap, args) 176 | kwargs = tree_map(wrap, kwargs) 177 | 178 | def unwrap_with_position(pos): 179 | def get_element(e): 180 | return e.elem[pos] if isinstance(e, DataParallelTensor) else e 181 | 182 | return get_element 183 | 184 | # Call the function for each of the DPT elements by unwarpping them and corresponding args and kwargs, 185 | # into element tensors so that the operation is performed on all the elements residing on the same device 186 | 
if PARALLEL_DISPATCH: 187 | future_res: List[futures.Future] = [] 188 | for pos in range(cls.num_threads): 189 | future_res.append( 190 | cls.threadpool.submit( 191 | func, 192 | *tree_map(unwrap_with_position(pos), args), 193 | **tree_map(unwrap_with_position(pos), kwargs), 194 | ) 195 | ) 196 | outs = [future_res[i].result() for i in range(cls.num_threads)] 197 | else: 198 | outs = [] 199 | for pos in range(len(cls.device_ids)): 200 | outs.append( 201 | func( 202 | *tree_map(unwrap_with_position(pos), args), 203 | **tree_map(unwrap_with_position(pos), kwargs), 204 | ) 205 | ) 206 | 207 | # The ouput will always be a list since we are creating it 208 | # The list can contain tensors, bools, list of tensors or tuples of tensors or None 209 | # In case of tensors we just wrap them in DPT 210 | # In case of list/tuple of tensors, the corresponding elements across list/tuple are warpped 211 | # into a DPT and a list/tuple is returned respectively 212 | 213 | def out_wrap(e, func): 214 | 215 | assert isinstance(e, list) 216 | 217 | if isinstance(e[0], torch.Tensor): 218 | return DataParallelTensor(outs, func, DPTensorType.distributed) 219 | elif isinstance(e[0], list): 220 | return list( 221 | DataParallelTensor(list(t), func, DPTensorType.distributed) 222 | for t in zip(*e) 223 | ) 224 | elif isinstance(e[0], tuple): 225 | return tuple( 226 | DataParallelTensor(list(t), func, DPTensorType.distributed) 227 | for t in zip(*e) 228 | ) 229 | else: 230 | # NOTE: If the list contains non-tensor types then we return a single value only if all of them have identical value. 231 | if all(v == e[0] for v in e): 232 | return e[0] 233 | else: 234 | raise ValueError( 235 | f"Operation {func} retuns non-identical non-tensor values for some elemnts of DPT" 236 | ) 237 | 238 | outs = out_wrap(outs, func) 239 | return outs 240 | 241 | def all_reduce_grad( 242 | self, 243 | r_device: Optional[int] = torch.cuda.current_device() 244 | if torch.cuda.is_available() 245 | else 0, 246 | ): 247 | with torch.no_grad(): 248 | reduced_tensor: torch.Tensor = comm.reduce_add(self.elem, r_device) 249 | b_tensor: List[torch.Tensor] = comm.broadcast(reduced_tensor, out=self.elem) 250 | self.elem = b_tensor 251 | return reduced_tensor 252 | 253 | 254 | def make_data_parallel_module(mod: torch.nn.Module): 255 | # This function converts the parameters of a nn.Module to replicated DataParallelTensors 256 | # the else part is important for buffers of the module 257 | def wrapper(t): 258 | if isinstance(t, torch.nn.Parameter): 259 | return DataParallelTensor(t.data, None, DPTensorType.replicated) 260 | else: 261 | assert type(t) in (torch.Tensor, NoneType, bool) 262 | return DataParallelTensor(t, None, DPTensorType.replicated) 263 | 264 | mod._apply(wrapper) 265 | 266 | 267 | if __name__ == "__main__": 268 | 269 | if torch.cuda.is_available(): 270 | print("Devices: ", [i for i in range(NUM_DEVICES)]) 271 | else: 272 | print("GPU not found. Need GPUs to run examples. 
Exiting...") 273 | exit() 274 | 275 | try: 276 | from functools import partial 277 | 278 | from functorch import hessian, jacfwd, jacrev, vjp, vmap 279 | 280 | D = 16 281 | x: torch.Tensor = torch.randn(D, device="cuda") 282 | dpt_x = DataParallelTensor(x, None, DPTensorType.replicated) 283 | 284 | def predict(weight, bias, x): 285 | return F.linear(x, weight, bias).tanh() 286 | 287 | weight = torch.randn(D, D, device="cuda") 288 | bias = torch.randn(D, device="cuda") 289 | 290 | # Computing Jacobian using vmap and vjp and jacrev 291 | clone_x = dpt_x.clone().requires_grad_() 292 | unit_vectors = torch.eye(D).cuda() 293 | 294 | _, vjp_fn = vjp(partial(predict, weight, bias), clone_x) 295 | (ft_jacobian,) = vmap(vjp_fn)(unit_vectors) 296 | 297 | clone_x = dpt_x.clone().requires_grad_() 298 | jacobian_rev = jacrev(predict, argnums=2)(weight, bias, clone_x) 299 | 300 | print(torch.allclose(ft_jacobian, jacobian_rev)) 301 | 302 | # Computing Hessian using composition of jacrev and jacfwd vs hessian api 303 | clone_x = dpt_x.clone().requires_grad_() 304 | hess_api = hessian(predict, argnums=2)(weight, bias, clone_x) 305 | hess_fwdrev = jacfwd(jacrev(predict, argnums=2), argnums=2)( 306 | weight, bias, clone_x 307 | ) 308 | print(torch.allclose(hess_api, hess_fwdrev)) 309 | except ImportError: 310 | print("Skipping functorch example, package missing.") 311 | 312 | try: 313 | # Example with a torchvision model 314 | import torchvision.models as models 315 | 316 | batch_size = 256 317 | test_tensor: torch.Tensor = torch.randn( 318 | batch_size * NUM_DEVICES, 3, 224, 224, device="cuda" 319 | ) 320 | dp_tensor = DataParallelTensor( 321 | test_tensor, None, DPTensorType.distributed_batch 322 | ) 323 | model = models.resnet50().cuda() 324 | make_data_parallel_module(model) 325 | # Warmp up iteration 326 | out = model(dp_tensor) 327 | loss = out.sum() 328 | loss.backward() 329 | start_event = torch.cuda.Event(enable_timing=True) 330 | end_event = torch.cuda.Event(enable_timing=True) 331 | start_event.record() 332 | for i in range(1): 333 | out = model(dp_tensor) 334 | loss = out.sum() 335 | loss.backward() 336 | if ALL_REDUCE: 337 | for p in model.parameters(): 338 | p.grad.all_reduce_grad() 339 | # p = p - 0.5 * p.grad 340 | end_event.record() 341 | torch.cuda.synchronize() 342 | print("Timing for 1 iteration (ms) DPT: ", start_event.elapsed_time(end_event)) 343 | 344 | test_tensor: torch.Tensor = torch.randn(batch_size, 3, 224, 224, device="cuda") 345 | model = models.resnet50().cuda() 346 | # Warmp up iteration 347 | out = model(test_tensor) 348 | loss = out.sum() 349 | loss.backward() 350 | start_event.record() 351 | for i in range(NUM_DEVICES): 352 | out = model(test_tensor) 353 | loss = out.sum() 354 | loss.backward() 355 | 356 | # for p in model.parameters(): 357 | # p = p - 0.5 * p.grad 358 | 359 | end_event.record() 360 | torch.cuda.synchronize() 361 | print( 362 | "Timing for " + str(NUM_DEVICES) + " iterations(ms): ", 363 | start_event.elapsed_time(end_event), 364 | ) 365 | except ImportError: 366 | print("Running custom model since torchvision package is absent.") 367 | 368 | # Custom Model Example 369 | class MyModel(torch.nn.Module): 370 | def __init__(self): 371 | super(MyModel, self).__init__() 372 | 373 | self.conv1 = nn.Conv2d(3, 6, 5) 374 | self.pool = nn.MaxPool2d(2, 2) 375 | self.conv2 = nn.Conv2d(6, 16, 5) 376 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 377 | self.fc2 = nn.Linear(120, 84) 378 | self.fc3 = nn.Linear(84, 10) 379 | 380 | def forward(self, x): 381 | x = 
self.pool(F.relu(self.conv1(x))) 382 | x = self.pool(F.relu(self.conv2(x))) 383 | x = torch.flatten(x, 1) # flatten all dimensions except batch 384 | x = F.relu(self.fc1(x)) 385 | x = F.relu(self.fc2(x)) 386 | x = self.fc3(x) 387 | return x 388 | 389 | mod: torch.nn.Module = MyModel().cuda() 390 | inp: torch.Tensor = torch.randn(512, 3, 32, 32, device="cuda") 391 | dpt_inp = DataParallelTensor(inp, None, DPTensorType.distributed_batch) 392 | make_data_parallel_module(mod) 393 | out = mod(dpt_inp) 394 | loss = out.sum() 395 | loss.backward() 396 | 397 | for p in mod.parameters(): 398 | p.grad.all_reduce_grad() 399 | p = p - 0.5 * p.grad 400 | 401 | # Custom Function Example 402 | test_tensor = torch.randn(8, 5, device="cuda", requires_grad=True) 403 | dp_tensor = DataParallelTensor(test_tensor, None, DPTensorType.distributed_batch) 404 | 405 | def custom_func(x): 406 | return x.cos().cos().sum() 407 | 408 | res_tensor = custom_func(dp_tensor) 409 | print(res_tensor) 410 | res_tensor.backward() 411 | print(dp_tensor.grad) 412 | -------------------------------------------------------------------------------- /deferred_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._subclasses.fake_tensor import FakeTensorMode 3 | from torch.fx.experimental.proxy_tensor import PythonKeyTracer, ProxyTorchDispatchMode 4 | from torch.fx import Graph, GraphModule 5 | 6 | # Limitations: 7 | # - initialization cannot refer to external tensors 8 | # - parameters are these weird ProxyTensors, should have a custom class for 9 | # these placeholders 10 | # - DCE is likely not sound, needs to be implemented more carefully by 11 | # understanding aliasing relationships 12 | # - only top level module is rematerialized 13 | # - we lose parameter-ness and requires_grad-ness 14 | # - no version counter safety to guard against input mutation 15 | 16 | def deferred_init(f, *args, **kwargs): 17 | fx_tracer = PythonKeyTracer() 18 | fx_tracer.graph = Graph(fx_tracer) 19 | fx_tracer.root = torch.nn.Module() 20 | fx_tracer.tensor_attrs = {} 21 | fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) 22 | proxy_mode = ProxyTorchDispatchMode(fx_tracer, tracing_mode="real") 23 | with fake_tensor_mode, proxy_mode: 24 | r = f(*args, **kwargs) 25 | r._deferred = fx_tracer 26 | return r 27 | 28 | def materialize_module(m): 29 | # TODO: handle children 30 | 31 | outputs = [] 32 | 33 | def mark_for_materialize(tensors): 34 | for k, t in tensors.items(): 35 | if t is None: 36 | continue 37 | outputs.append(t.proxy.node) 38 | 39 | mark_for_materialize(m._parameters) 40 | mark_for_materialize(m._buffers) 41 | 42 | m._deferred.graph.output(outputs) 43 | m._deferred.graph.eliminate_dead_code() # hmmm 44 | recomp = GraphModule(m._deferred.root, m._deferred.graph) 45 | results = recomp() 46 | results_iter = iter(results) 47 | 48 | def replace_results(tensors): 49 | for k, t in tensors.items(): 50 | if t is None: 51 | continue 52 | tensors[k] = next(results_iter) 53 | 54 | replace_results(m._parameters) 55 | replace_results(m._buffers) 56 | 57 | del m._deferred 58 | 59 | 60 | m = deferred_init(torch.nn.Linear, 3, 5) 61 | print(m.weight) 62 | materialize_module(m) 63 | print(m.weight) 64 | -------------------------------------------------------------------------------- /dispatch_mem_profiler.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict 3 | 4 | import 
matplotlib.pyplot as plt 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.utils._python_dispatch import TorchDispatchMode 9 | 10 | aten = torch.ops.aten 11 | 12 | MB = 1024 * 1024.0 13 | 14 | operator_names: Dict[str, int] = defaultdict(int) 15 | mem_usage: Dict[str, float] = defaultdict(float) 16 | markers: Dict[str, int] = defaultdict(int) 17 | series: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) 18 | 19 | 20 | def normalize_tuple(x): 21 | if not isinstance(x, tuple): 22 | return (x,) 23 | return x 24 | 25 | 26 | def reduce_to_scalar_loss(inp): 27 | return inp.sum() 28 | 29 | 30 | class MemoryProfileDispatchMode(TorchDispatchMode): 31 | def __init__(self, verbose=False): 32 | self.verbose: bool = verbose 33 | 34 | def __torch_dispatch__(self, func, types, args=..., kwargs=None): 35 | rs = func(*args, **kwargs) 36 | if func == torch.ops.aten.detach.default: 37 | return rs 38 | mem: float = torch.cuda.memory_allocated() / MB 39 | func_name: str = func.__name__ + "_" + str(operator_names[func.__name__]) 40 | operator_names[func.__name__] = operator_names[func.__name__] + 1 41 | mem_usage[func_name] = mem 42 | if self.verbose: 43 | print("Mem Usage (" + func_name + "): ", mem) 44 | return rs 45 | 46 | 47 | def clear_state(): 48 | operator_names.clear() 49 | mem_usage.clear() 50 | 51 | 52 | def add_series(series_name): 53 | global mem_usage 54 | fin_usage = torch.cuda.memory_allocated() / MB 55 | mem_usage["fin_usage"] = fin_usage 56 | series[series_name] = mem_usage 57 | mem_usage = defaultdict(float) 58 | 59 | 60 | def save_graph(filename: str): 61 | for series_name, mem_usage in series.items(): 62 | y = mem_usage.values() 63 | min_val = min(y) 64 | max_val = max(y) 65 | x = [i for i in range(len(y))] 66 | plt.plot(x, y, label=series_name) 67 | plt.xlabel("# Operator Calls") 68 | plt.ylabel("Allocated Memory (MB)") 69 | plt.title(filename) 70 | for marker_name, marker in markers.items(): 71 | plt.plot([marker, marker], [min_val, max_val], "k-", lw=2, label=marker_name) 72 | plt.legend() 73 | print("Saving Graph") 74 | plt.savefig(filename) 75 | plt.close() 76 | markers.clear() 77 | series.clear() 78 | 79 | 80 | def add_marker(marker_name): 81 | k = len(series.keys()) 82 | last_val_num = len(mem_usage.values()) 83 | markers[marker_name + str(k)] = last_val_num 84 | 85 | 86 | def mem_profile_model(mod: torch.nn.Module, inp: torch.Tensor): 87 | 88 | with MemoryProfileDispatchMode(True): 89 | pred = mod(inp) 90 | loss = reduce_to_scalar_loss(pred) 91 | loss.backward() 92 | mod.zero_grad(True) 93 | torch.cuda.synchronize() 94 | clear_state() 95 | pred = mod(inp) 96 | loss = reduce_to_scalar_loss(pred) 97 | add_marker("fw_bw_boundary") 98 | loss.backward() 99 | 100 | 101 | if __name__ == "__main__": 102 | try: 103 | import torchvision.models as models 104 | from functorch.compile import aot_module 105 | from functorch.compile import min_cut_rematerialization_partition 106 | from functorch.compile import nop 107 | from functorch.compile import print_compile 108 | 109 | mod: torch.nn.Module = models.resnet18().cuda() 110 | inp: torch.Tensor = torch.randn(32, 3, 224, 224, device="cuda") 111 | mem_profile_model(mod, inp) 112 | add_series("eager_mode") 113 | mod3 = aot_module(mod, nop, partition_fn=min_cut_rematerialization_partition) 114 | mem_profile_model(mod3, inp) 115 | add_series("aot_autograd_min_cut") 116 | save_graph("Resnet_mem_usage") 117 | clear_state() 118 | with MemoryProfileDispatchMode(True): 119 | mod3 = aot_module( 
120 | mod, nop, partition_fn=min_cut_rematerialization_partition 121 | ) 122 | mod3(inp).sum().backward() 123 | add_series("aot_autograd_mem_usage") 124 | save_graph("autograd_mem_usage") 125 | except ImportError: 126 | 127 | class MyModel(torch.nn.Module): 128 | def __init__(self): 129 | super(MyModel, self).__init__() 130 | 131 | self.conv1 = nn.Conv2d(3, 6, 5) 132 | self.pool = nn.MaxPool2d(2, 2) 133 | self.conv2 = nn.Conv2d(6, 16, 5) 134 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 135 | self.fc2 = nn.Linear(120, 84) 136 | self.fc3 = nn.Linear(84, 10) 137 | 138 | def forward(self, x): 139 | x = self.pool(F.relu(self.conv1(x))) 140 | x = self.pool(F.relu(self.conv2(x))) 141 | x = torch.flatten(x, 1) # flatten all dimensions except batch 142 | x = F.relu(self.fc1(x)) 143 | x = F.relu(self.fc2(x)) 144 | x = self.fc3(x) 145 | return x 146 | 147 | mod: torch.nn.Module = MyModel().cuda() 148 | inp: torch.Tensor = torch.randn(512, 3, 32, 32, device="cuda") 149 | mem_profile_model(mod, inp) 150 | add_series("eager_mode") 151 | save_graph("Model_mem_usage") 152 | -------------------------------------------------------------------------------- /dynamic_strides.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # formats: ipynb,py:light 5 | # text_representation: 6 | # extension: .py 7 | # format_name: light 8 | # format_version: '1.5' 9 | # jupytext_version: 1.13.7 10 | # kernelspec: 11 | # display_name: Python 3 (ipykernel) 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | import sympy as sp 17 | from typing import List 18 | from dataclasses import dataclass 19 | 20 | # In this notebook, we explore how to model contiguity and strides in a 21 | # universe where we support dynamic shapes. We don't care about dynamic 22 | # strides/contiguity per se (we'd be OK with specializing on the input 23 | # being contiguous, channels-last, etc), but strides and contiguity 24 | # are *derived* from shapes, so if you have dynamic shapes, you 25 | # also end up with dynamic strides and contiguity. 26 | # 27 | # Let's take a concrete look at this phenomenon in the simplest possible 28 | # context: a contiguous tensor. Here is the C++ code which implements 29 | # computation of contiguous strides for a tensor: 30 | 31 | """ 32 | // From c10/util/strides.h 33 | // Computes the contiguous strides of a tensor, given its sizes. 34 | static inline std::vector contiguous_strides( 35 | const IntArrayRef sizes) { 36 | using Int = IntArrayRef::value_type; 37 | const Int dims = static_cast(sizes.size()); 38 | 39 | std::vector strides; 40 | 41 | if (dims > 0) { 42 | strides.assign(dims, 0); 43 | // Start by populating the last dimension: its strides is always 1. 44 | strides[dims - 1] = 1; 45 | for (auto i = dims - 2; i >= 0; --i) { 46 | // Strides can't be 0 even if sizes are 0. 47 | strides[i] = strides[i + 1] * std::max(sizes[i + 1], Int{1}); 48 | } 49 | } 50 | 51 | return strides; 52 | } 53 | """ 54 | 55 | # And a port to Python: 56 | 57 | 58 | def contiguous_strides(sizes: List[int]): 59 | dims = len(sizes) 60 | strides = [] 61 | if dims > 0: 62 | strides = [0] * dims 63 | strides[dims - 1] = 1 64 | for i in range(dims - 2, -1, -1): 65 | strides[i] = strides[i + 1] * sp.Max(sizes[i + 1], 1) 66 | return strides 67 | 68 | 69 | print(contiguous_strides([2, 3, 5])) 70 | 71 | # Let's look at the symbolic output of this function. 
When only the batch 72 | # dimension is dynamic, things are pretty simple: 73 | 74 | x = sp.symbols("x") 75 | print(contiguous_strides([x, 3, 5])) 76 | 77 | # However, if an inner dimension is dynamic, the dynamic shape variable 78 | # shows up in the stride calculation 79 | 80 | print(contiguous_strides([2, x, 5])) 81 | 82 | # The set of strides returned by contiguous is guaranteed to be 83 | # contiguous, but the inverse is not true: there are some degrees of 84 | # freedom in the definition of strides when sizes are one or zero. 85 | # Here is our definition of "when something is contiguous" (not accounting 86 | # for overflow): 87 | 88 | """ 89 | // In c10/core/TensorImpl.h 90 | inline bool is_empty() const { 91 | return numel() == 0; 92 | } 93 | 94 | // In c10/core/TensorImpl.cpp 95 | bool TensorImpl::compute_contiguous() const { 96 | bool is_contiguous = true; 97 | if (is_empty()) 98 | return is_contiguous; 99 | int64_t z = 1; 100 | for (int64_t d = dim() - 1; d >= 0; d--) { 101 | const auto size_d = sizes_and_strides_.size_at_unchecked(d); 102 | if (size_d != 1) { 103 | if (sizes_and_strides_.stride_at_unchecked(d) == z) { 104 | z *= size_d; 105 | } else { 106 | is_contiguous = false; 107 | break; 108 | } 109 | } 110 | } 111 | return is_contiguous; 112 | } 113 | """ 114 | 115 | # In Python (note that we will use the suffix branchy to refer 116 | # to code which branches on the concrete value of sizes/strides): 117 | 118 | 119 | def compute_numel(sizes: List[int]): 120 | numel = 1 121 | for s in sizes: 122 | numel *= s 123 | return numel 124 | 125 | 126 | def compute_contiguous_branchy(sizes: List[int], strides: List[int]): 127 | is_contiguous = True 128 | if compute_numel(sizes) == 0: 129 | return is_contiguous 130 | z = 1 131 | for d in range(len(sizes) - 1, -1, -1): 132 | if sizes[d] != 1: 133 | if strides[d] == z: 134 | z *= sizes[d] 135 | else: 136 | is_contiguous = False 137 | break 138 | return is_contiguous 139 | 140 | 141 | # When a dimension has size 1, we are indifferent to the stride at that 142 | # dimension: 143 | 144 | print(contiguous_strides([3, 1, 5])) 145 | 146 | print(compute_contiguous_branchy([3, 1, 5], [5, 5, 1])) 147 | print(compute_contiguous_branchy([3, 1, 5], [5, 999999, 1])) 148 | 149 | # When a tensor contains zero elements, we are indifferent to all the 150 | # strides 151 | 152 | print(contiguous_strides([3, 0, 5])) 153 | 154 | print(compute_contiguous_branchy([3, 0, 5], [5, 5, 1])) 155 | print(compute_contiguous_branchy([3, 0, 5], [123456, 999999, 424242])) 156 | 157 | # Can we compute_contiguous symbolically? Unfortunately, the "branchy" 158 | # implementation, as written above cannot be run directly on SymPy 159 | # integers, as in several points in the code we condition on the 160 | # concrete values of various comparisons on integers. Fortunately, 161 | # we can introduce a SymInt/SymBool abstraction (as done in previous 162 | # notebooks) to provide concrete values and record guards expressing 163 | # what is required to be true for the computation to be correct. 
164 | 165 | # + 166 | 167 | GUARDS = [] 168 | 169 | 170 | def is_constant(e): 171 | if hasattr(e, "is_constant"): 172 | return e.is_constant() 173 | elif e is sp.true or e is sp.false: 174 | return True 175 | else: 176 | return False 177 | 178 | 179 | class SymObject: 180 | def __post_init__(self): 181 | if self.expr is None: 182 | self.expr = sp.sympify(self.val) 183 | elif not isinstance(self.expr, sp.Expr): 184 | self.expr = sp.sympify(self.expr) 185 | 186 | 187 | @dataclass 188 | class SymBool(SymObject): 189 | val: bool 190 | expr: sp.Expr = None 191 | guarded: bool = False 192 | 193 | def __bool__(self): 194 | if not self.guarded: 195 | self.guarded = True 196 | if not is_constant(self.expr): 197 | if self.val: 198 | GUARDS.append(self.expr) 199 | else: 200 | GUARDS.append(sp.Not(self.expr)) 201 | return self.val 202 | 203 | 204 | def logical_and(self: bool, other: bool): 205 | if isinstance(self, SymBool) and isinstance(other, SymBool): 206 | return SymBool(self.val and other.val, sp.And(self.expr, other.expr)) 207 | return sp.And(self, other) 208 | 209 | 210 | def logical_or(self: bool, other: bool): 211 | if isinstance(self, SymBool) and isinstance(other, SymBool): 212 | return SymBool(self.val or other.val, sp.Or(self.expr, other.expr)) 213 | return sp.Or(self, other) 214 | 215 | 216 | @dataclass 217 | class SymInt(SymObject): 218 | val: int 219 | expr: sp.Expr = None 220 | guarded: bool = False 221 | 222 | def __int__(self): 223 | if not self.guarded: 224 | self.guarded = True 225 | if not is_constant(self.expr): 226 | GUARDS.append(self.Eq(self.expr, self.val).simplify()) 227 | return self.val 228 | 229 | def __eq__(self, other): 230 | if not isinstance(other, SymInt): 231 | other = SymInt(other) 232 | return SymBool(self.val == other.val, sp.Eq(self.expr, other.expr)) 233 | 234 | def __ne__(self, other): 235 | if not isinstance(other, SymInt): 236 | other = SymInt(other) 237 | return SymBool(self.val != other.val, sp.Ne(self.expr, other.expr)) 238 | 239 | def __mul__(self, other): 240 | if not isinstance(other, SymInt): 241 | other = SymInt(other) 242 | return SymInt(self.val * other.val, sp.Mul(self.expr, other.expr)) 243 | 244 | def __rmul__(self, other): 245 | if not isinstance(other, SymInt): 246 | other = SymInt(other) 247 | return SymInt(self.val * other.val, sp.Mul(self.expr, other.expr)) 248 | 249 | 250 | def I(val, expr=None): 251 | return SymInt(val, expr) 252 | 253 | 254 | # - 255 | 256 | # Let's run our example. Under the guards model, we must provide 257 | # concrete values for every symbolic integer, so we can resolve 258 | # conditionals. 259 | 260 | x1, x2, x3, y1, y2, y3 = sp.symbols("x1 x2 x3 y1 y2 y3") 261 | 262 | GUARDS.clear() 263 | print( 264 | compute_contiguous_branchy( 265 | [I(3, x1), I(1, x2), I(5, x3)], [I(5, y1), I(99999, y2), I(1, y3)] 266 | ) 267 | ) 268 | 269 | # We see that this tensor is contiguous... 270 | 271 | print(GUARDS) 272 | 273 | # ...subject to these conditions. These conditions say which particular 274 | # path through the loop we took: we require the sizes to be nonzero, 275 | # there are number of size one equalities/disequalities, and the 276 | # equality requirement between y1 and x3 is the "true" contiguity 277 | # requirement. 278 | 279 | # If we are willing to rewrite the definition of compute contiguous, we 280 | # can eliminate the branches, giving a symbolic expression with no 281 | # guards. 
282 | 283 | 284 | def compute_contiguous(sizes, strides): 285 | is_contiguous = True 286 | z = 1 287 | for d in range(len(sizes) - 1, -1, -1): 288 | is_contiguous = logical_and( 289 | is_contiguous, logical_or(sp.Eq(sizes[d], 1), sp.Eq(strides[d], z)) 290 | ) 291 | z *= sizes[d] 292 | return logical_or(sp.Eq(compute_numel(sizes), 0), is_contiguous) 293 | 294 | 295 | # TODO: prove these two implementations are equivalent, somehow 296 | 297 | # We can see that no matter the choice of the stride for a size one 298 | # dimension, the result is always contiguous: 299 | 300 | print(compute_contiguous([3, 1, 5], [5, x, 1])) 301 | 302 | # And we can see the unflattened contiguity requirement for a completely 303 | # general size/stride tensor. 304 | 305 | print(compute_contiguous([x1, x2, x3], [y1, y2, y3])) 306 | 307 | # There's other stuff too: 308 | # 309 | # - We are not "just" compute_contiguous; we also have have variations 310 | # of this for every memory layout we support. So the same exercise 311 | # needs to apply everywhere. 312 | # 313 | # - We also have non_overlapping_and_dense which which involves a sort 314 | # which is very annoying. 315 | 316 | # In conclusion: 317 | # 318 | # - We have an explicit choice whether or not to branch inside 319 | # implementations of code that may be traced. More trace friendly 320 | # code is not as good for eager execution (because you can't do 321 | # things like short circuit). 322 | # 323 | # - If we store SymInt inside TensorImpl, we need to make a call about 324 | # how we represent the contiguity bits inside Tensor. These bits 325 | # are literally a single bit, so we cannot store a symbolic boolean 326 | # in them. It seems the easiest fix is to ensure the 327 | # is_contiguous() is virtualized (it is), and then internally run 328 | # (and cache) the symbolic formula done here. 
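# As a rough sketch of that last bullet (the names here are invented for
# this notebook; this is not the actual TensorImpl code), a symbolic "meta"
# object could compute the contiguity formula lazily and cache it, with
# callers that need a concrete bool guarding on the cached expression:


class SymbolicMeta:
    def __init__(self, sizes, strides):
        self.sizes = sizes
        self.strides = strides
        self._contiguous = None  # cached symbolic contiguity formula

    def is_contiguous(self):
        # compute the symbolic formula once; repeated queries reuse it
        if self._contiguous is None:
            self._contiguous = compute_contiguous(self.sizes, self.strides)
        return self._contiguous


print(SymbolicMeta([x1, x2, x3], [y1, y2, y3]).is_contiguous())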
329 | -------------------------------------------------------------------------------- /empty_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from base_tensor import BaseTensor 3 | from torch.testing._internal.common_utils import run_tests, TestCase 4 | from torch.utils._pytree import tree_map 5 | 6 | from utils import no_dispatch 7 | 8 | 9 | class EmptyTensor(BaseTensor): 10 | @staticmethod 11 | def __new__(cls, elem): 12 | return torch.Tensor._make_wrapper_subclass( 13 | cls, 14 | elem.size(), 15 | strides=elem.stride(), 16 | storage_offset=elem.storage_offset(), 17 | dtype=elem.dtype, 18 | layout=elem.layout, 19 | requires_grad=elem.requires_grad, 20 | device=elem.device, 21 | ) 22 | 23 | def __init__(self, elem): 24 | pass 25 | 26 | def __repr__(self): 27 | # TODO: this is wrong 28 | return f"EmptyTensor({self.size()})" 29 | 30 | @classmethod 31 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 32 | def inflate(t): 33 | if isinstance(t, cls): 34 | with no_dispatch(): 35 | return torch.ones_like(t, device=t.device) 36 | else: 37 | return t 38 | 39 | def deflate(t): 40 | if isinstance(t, torch.Tensor) and not isinstance(t, cls): 41 | return EmptyTensor(t) 42 | else: 43 | return t 44 | 45 | return tree_map( 46 | deflate, 47 | super().__torch_dispatch__( 48 | func, types, tree_map(inflate, args), tree_map(inflate, kwargs) 49 | ), 50 | ) 51 | 52 | 53 | class EmptyTensorTest(TestCase): 54 | def test_basic(self): 55 | x = EmptyTensor(torch.randn(4)) 56 | y = EmptyTensor(torch.randn(4)) 57 | r = x + y 58 | self.assertEqual(r.shape, (4,)) 59 | 60 | 61 | if __name__ == "__main__": 62 | run_tests() 63 | -------------------------------------------------------------------------------- /enhanced_error_mode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils._python_dispatch import TorchDispatchMode 3 | from torch.utils._pytree import tree_map 4 | import itertools 5 | 6 | # cribbed from https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py 7 | 8 | class Lit: 9 | def __init__(self, s): 10 | self.s = s 11 | 12 | def __repr__(self): 13 | return self.s 14 | 15 | def fmt(t: object) -> str: 16 | if isinstance(t, torch.Tensor): 17 | return Lit(f"torch.tensor(..., size={tuple(t.shape)}, dtype={t.dtype}, device='{t.device}')") 18 | else: 19 | return t 20 | 21 | class EnhancedErrorMode(TorchDispatchMode): 22 | def __torch_dispatch__(self, func, types, args, kwargs): 23 | try: 24 | return func(*args, **kwargs) 25 | except Exception as ex: 26 | fmt_args = ", ".join( 27 | itertools.chain( 28 | (repr(tree_map(fmt, a)) for a in args), 29 | (f"{k}={tree_map(fmt, v)}" for k, v in kwargs.items()), 30 | ) 31 | ) 32 | msg = f"...when running {func}({fmt_args})" 33 | # https://stackoverflow.com/questions/17677680/how-can-i-add-context-to-an-exception-in-python 34 | msg = f'{ex.args[0]}\n{msg}' if ex.args else msg 35 | ex.args = (msg,) + ex.args[1:] 36 | raise 37 | 38 | if __name__ == "__main__": 39 | with EnhancedErrorMode(): 40 | torch.matmul(torch.randn(3), torch.randn(4, 5)) 41 | -------------------------------------------------------------------------------- /failures/grad_several_ways.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, List, NamedTuple, Optional 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | """ 7 | This is a remix of Zachary DeVito's Simple Autograd 8 | 
https://colab.research.google.com/drive/1VpeE6UvEPRz9HmsHh1KS0XxXjYu533EC?usp=sharing 9 | to illustrate some concepts on a few different ways you can 10 | write autograd. To start, we replicate some data structures 11 | as is from the original colab. 12 | """ 13 | 14 | 15 | class TapeEntry(NamedTuple): 16 | # names of the inputs to the original computation 17 | inputs: List[str] 18 | # names of the outputs of the original computation 19 | outputs: List[str] 20 | # apply chain rule 21 | propagate: Callable[List[Tensor], List[Tensor]] 22 | 23 | 24 | _name = 0 25 | 26 | 27 | def fresh_name() -> str: 28 | """create a new unique name for a variable: v0, v1, v2""" 29 | global _name 30 | r = f"v{_name}" 31 | _name += 1 32 | return r 33 | 34 | 35 | """ 36 | We won't make use of a Variable wrapper class; instead we are going to store 37 | our autograd tape with the autograd context itself. For debuggability 38 | purposes, however, we still need a way to identify variables by a human 39 | readable name, and we'll do this by annotating them with a t_name attribute. 40 | """ 41 | 42 | 43 | def variable(t: Tensor, name: str = None): 44 | if not hasattr(t, "t_name"): 45 | t.t_name = name or fresh_name() 46 | return t 47 | 48 | 49 | """ 50 | In this file, I wanted to demonstrate a few different ways of implementing 51 | autograd in the context of a dispatcher-style system, and to do that, I 52 | needed to organize my operator calls in a different way than you might be 53 | used to. Instead of having a Variable wrapper class and defining methods 54 | on it, I will instead rely on a series of "dispatch layer" mixins which define 55 | a series of operations at a particular level in the dispatcher, and then 56 | I will mix them together (using Python multiple inheritance) to create a 57 | fully-featured applications. Do not worry if you are not familiar with 58 | mixin-style programming, we will explain it as we go along. 59 | 60 | To start with, we will implement a backend dispatch layer Torch. This just 61 | forwards on the operator calls to our underlying library PyTorch. You could 62 | also imagine replacing this with a Numpy backend or even a pure Python variant 63 | (although this file is not currently setup to do so.) Each operator is a 64 | method on the dispatch layer mixin, for reasons that will become apparent 65 | shortly. 66 | """ 67 | 68 | 69 | class Torch: 70 | def mul(self, lhs, rhs): 71 | return torch.mul(lhs, rhs) 72 | 73 | def add(self, lhs, rhs): 74 | return torch.add(lhs, rhs) 75 | 76 | def sum(self, input): 77 | return torch.sum(input) 78 | 79 | def expand(self, input, sizes): 80 | return input.expand(*sizes) 81 | 82 | 83 | """ 84 | Similar to the Torch mixin, the Autograd mixin will also define four 85 | methods (mul, add, sum, expand). Intuitively, the idea is that we will 86 | first call into Autograd, autograd will do some stuff, and then 87 | delegate to Torch to do the actual compute. 
88 | """ 89 | 90 | 91 | class Delegate: 92 | def __init__(self, cls, obj): 93 | self._delegate_cls = cls 94 | self._delegate_obj = obj 95 | 96 | def __getattr__(self, name): 97 | x = getattr(self._delegate_cls, name) 98 | if hasattr(x, "__get__"): 99 | return x.__get__(self._delegate_obj) 100 | return x 101 | 102 | 103 | def gen_autograd(suffix="", *, backward_super: bool = False): 104 | class Autograd: 105 | def __init__(self): 106 | super().__init__() 107 | Autograd.set_gradient_tape(self, []) 108 | 109 | def get_gradient_tape(self): 110 | return getattr(self, f"gradient_tape_{Autograd.__name__}") 111 | 112 | def set_gradient_tape(self, x): 113 | return setattr(self, f"gradient_tape_{Autograd.__name__}", x) 114 | 115 | gradient_tape = property(get_gradient_tape, set_gradient_tape) 116 | 117 | def backward_dispatch(self, cb): 118 | if backward_super: 119 | return variable(cb(super())) 120 | else: 121 | return cb(Delegate(Autograd, self)) 122 | # return cb(self) 123 | 124 | def mul(self, lhs, rhs): 125 | if isinstance(rhs, float) and rhs == 1.0: 126 | # peephole optimization 127 | return lhs 128 | 129 | # define forward 130 | r = variable(super().mul(lhs, rhs)) 131 | print(f"{Autograd.__name__} {r.t_name} = {lhs.t_name} * {rhs.t_name}") 132 | 133 | # record what the inputs and outputs of the op were 134 | inputs = [lhs.t_name, rhs.t_name] 135 | outputs = [r.t_name] 136 | 137 | # define backprop 138 | def propagate(dL_doutputs: List[Tensor]): 139 | (dL_dr,) = dL_doutputs 140 | 141 | dr_dlhs = rhs # partial derivative of r = lhs*rhs 142 | dr_drhs = lhs # partial derivative of r = lhs*rhs 143 | 144 | # chain rule propagation from outputs to inputs of multiply 145 | # self or Autograd??? 146 | dL_dlhs = Autograd.backward_dispatch( 147 | self, lambda s: s.mul(dL_dr, dr_dlhs) 148 | ) 149 | dL_drhs = Autograd.backward_dispatch( 150 | self, lambda s: s.mul(dL_dr, dr_drhs) 151 | ) 152 | dL_dinputs = [dL_dlhs, dL_drhs] 153 | return dL_dinputs 154 | 155 | # finally, we record the compute we did on the tape 156 | Autograd.get_gradient_tape(self).append( 157 | TapeEntry(inputs=inputs, outputs=outputs, propagate=propagate) 158 | ) 159 | return r 160 | 161 | def add(self, lhs, rhs): 162 | # Add follows a similar pattern to Mul, but it doesn't end up 163 | # capturing any variables. 
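            # (Both partial derivatives of add are the constant 1.0, so the
            # propagate closure below does not need to close over lhs or rhs.)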
164 | r = variable(super().add(lhs, rhs)) 165 | print(f"{Autograd.__name__} {r.t_name} = {lhs.t_name} + {rhs.t_name}") 166 | 167 | def propagate(dL_doutputs: List[Tensor]): 168 | (dL_dr,) = dL_doutputs 169 | dr_dlhs = 1.0 170 | dr_drhs = 1.0 171 | dL_dlhs = Autograd.backward_dispatch( 172 | self, lambda s: s.mul(dL_dr, dr_dlhs) 173 | ) 174 | dL_drhs = Autograd.backward_dispatch( 175 | self, lambda s: s.mul(dL_dr, dr_drhs) 176 | ) 177 | return [dL_dlhs, dL_drhs] 178 | 179 | Autograd.get_gradient_tape(self).append( 180 | TapeEntry( 181 | inputs=[lhs.t_name, rhs.t_name], 182 | outputs=[r.t_name], 183 | propagate=propagate, 184 | ) 185 | ) 186 | return r 187 | 188 | def sum(self, input: Tensor, name: Optional[str] = None): 189 | r = variable(super().sum(input), name=name) 190 | print(f"{Autograd.__name__} {r.t_name} = {input.t_name}.sum()") 191 | 192 | def propagate(dL_doutputs: List[Tensor]): 193 | (dL_dr,) = dL_doutputs 194 | size = input.size() 195 | return [ 196 | Autograd.backward_dispatch(self, lambda s: s.expand(dL_dr, size)) 197 | ] 198 | 199 | Autograd.get_gradient_tape(self).append( 200 | TapeEntry( 201 | inputs=[input.t_name], outputs=[r.t_name], propagate=propagate 202 | ) 203 | ) 204 | return r 205 | 206 | def expand(self, input: Tensor, sizes: List[int]): 207 | assert input.dim() == 0 # only works for scalars 208 | r = variable(super().expand(input, sizes)) 209 | print(f"{Autograd.__name__} {r.t_name} = {input.t_name}.expand({sizes})") 210 | 211 | def propagate(dL_doutputs: List[Tensor]): 212 | (dL_dr,) = dL_doutputs 213 | return [Autograd.backward_dispatch(self, lambda s: s.sum(dL_dr))] 214 | 215 | Autograd.get_gradient_tape(self).append( 216 | TapeEntry( 217 | inputs=[input.t_name], outputs=[r.t_name], propagate=propagate 218 | ) 219 | ) 220 | return r 221 | 222 | def reset_tape(self): 223 | Autograd.get_gradient_tape(self).clear() 224 | 225 | def grad(self, L, desired_results: List[Tensor]) -> List[Tensor]: 226 | # this map holds dL/dX for all values X 227 | dL_d: Dict[str, Tensor] = {} 228 | # It starts by initializing the 'seed' dL/dL, which is 1 229 | # TODO: indirect this via the backend 230 | dL_d[L.t_name] = variable(torch.ones(())) 231 | print(f"{Autograd.__name__} d{L.t_name} ------------------------") 232 | 233 | # look up dL_dentries. If a variable is never used to compute the loss, 234 | # we consider its gradient None, see the note below about zeros for more information. 235 | def gather_grad(entries: List[str]): 236 | return [dL_d[entry] if entry in dL_d else None for entry in entries] 237 | 238 | # propagate the gradient information backward 239 | for entry in reversed(Autograd.get_gradient_tape(self)): 240 | dL_doutputs = gather_grad(entry.outputs) 241 | if all(dL_doutput is None for dL_doutput in dL_doutputs): 242 | # optimize for the case where some gradient pathways are zero. See 243 | # The note below for more details. 244 | continue 245 | 246 | # perform chain rule propagation specific to each compute 247 | dL_dinputs = entry.propagate(dL_doutputs) 248 | 249 | # Accululate the gradient produced for each input. 250 | # Each use of a variable produces some gradient dL_dinput for that 251 | # use. The multivariate chain rule tells us it is safe to sum 252 | # all the contributions together. 
253 | for input, dL_dinput in zip(entry.inputs, dL_dinputs): 254 | if input not in dL_d: 255 | dL_d[input] = dL_dinput 256 | else: 257 | dL_d[input] = Autograd.backward_dispatch( 258 | self, lambda s: s.add(dL_d[input], dL_dinput) 259 | ) 260 | 261 | # print some information to understand the values of each intermediate 262 | for name, value in dL_d.items(): 263 | print(f"{Autograd.__name__} d{L.t_name}_d{name} = {value.t_name}") 264 | print(f"------------------------") 265 | 266 | return gather_grad(desired.t_name for desired in desired_results) 267 | 268 | Autograd.__name__ = f"Autograd{suffix}" 269 | return Autograd 270 | 271 | 272 | Autograd = gen_autograd() 273 | 274 | # sum is used to turn our matrices into a single scalar to get a loss. 275 | # expand is the backward of sum, so it is added to make sure our Variable 276 | # is closed under differentiation. Both have rules similar to mul above. 277 | 278 | # TODO: indirect this via the backend 279 | torch.manual_seed(0) 280 | a, b = variable(torch.rand(4)), variable(torch.rand(4)) 281 | 282 | 283 | class Example1(Autograd, Torch): 284 | def simple(self, a, b): 285 | t = self.add(a, b) 286 | return self.mul(t, b) 287 | 288 | def main(self): 289 | loss = self.simple(a, b) 290 | da, db = self.grad(loss, [a, b]) 291 | print("da", da) 292 | print("db", db) 293 | 294 | 295 | Example1().main() 296 | 297 | 298 | class Example2Direct(Autograd, Torch): 299 | def simple(self, a, b): 300 | t = self.add(a, b) 301 | return self.mul(t, b) 302 | 303 | def run_gradients(self): 304 | # our first loss 305 | L0 = self.sum(self.simple(a, b), name="L0") 306 | 307 | # compute derivatives of our inputs 308 | dL0_da, dL0_db = self.grad(L0, [a, b]) 309 | 310 | # now lets compute the L2 norm of our derivatives 311 | L1 = self.sum( 312 | self.add(self.mul(dL0_da, dL0_da), self.mul(dL0_db, dL0_db)), name="L1" 313 | ) 314 | 315 | # and take the gradient of that. 316 | # notice there are two losses involved. 317 | dL1_da, dL1_db = self.grad(L1, [a, b]) 318 | return dL1_da, dL1_db 319 | 320 | def main(self): 321 | da, db = self.run_gradients() 322 | print("da", da) 323 | print("db", db) 324 | 325 | 326 | Example2Direct().main() 327 | 328 | Autograd1 = gen_autograd("1", backward_super=True) 329 | Autograd2 = gen_autograd("2", backward_super=True) 330 | 331 | 332 | class Example2Indirect(Autograd2, Autograd1, Torch): 333 | def simple(self, cls, a, b): 334 | t = cls.add(self, a, b) 335 | return cls.mul(self, t, b) 336 | 337 | def run_gradients(self): 338 | 339 | # Imagine grad(grad(...)) 340 | # we first allocate variables for the outer grad (Autograd1) 341 | # then they get wrapped in variables again for inner grad (Autograd2) 342 | 343 | L0 = Autograd2.sum(self, self.simple(Autograd2, a, b), name="L0") 344 | dL0_da, dL0_db = Autograd2.grad(self, L0, [a, b]) 345 | 346 | # Now we can "throw out" the tape for Autograd2 347 | Autograd2.reset_tape(self) 348 | 349 | # now lets compute the L2 norm of our derivatives, in Autograd1 350 | L1 = Autograd1.sum( 351 | self, 352 | Autograd1.add( 353 | self, 354 | Autograd1.mul(self, dL0_da, dL0_da), 355 | Autograd1.mul(self, dL0_db, dL0_db), 356 | ), 357 | name="L1", 358 | ) 359 | 360 | # and take the gradient of that. 361 | # notice there are two losses involved. 
362 | dL1_da, dL1_db = Autograd1.grad(self, L1, [a, b]) 363 | return dL1_da, dL1_db 364 | 365 | def main(self): 366 | da, db = self.run_gradients() 367 | print("da", da) 368 | print("db", db) 369 | 370 | 371 | Example2Indirect().main() 372 | 373 | # Autograd2, Batched, Autograd1 374 | # Autograd2 will record a tape with batched tensors 375 | # Autograd1 will record a tape with raw tensors 376 | # Autograd1 tape is not usable for Autograd2 (vmap(grad(...))) case, as 377 | # losses are expected to come out batched but you'll get out raw tensors 378 | # Autograd2 tape could work, but the type is "wrong" and you need 379 | # to unwrap them 380 | # 381 | # Does Batched have to actually wrap? If everything was virtualized, 382 | # not actually; batched layer can simply "reinterpret" size query 383 | # appropriately 384 | # 385 | # New idea: don't wrap tensors, instead wrap "dispatcher levels" in Python. 386 | # One object instead of many. 387 | # ~ how to reuse the old objects? "Phantom" tensors organized by the 388 | # instance (need weakrefs...) 389 | # ~ this is the old "levels subsume everything" idea (dropped because 390 | # we wanted nonlexical--but nonlexical just implies determining the 391 | # dispatch list from some other context) 392 | # 393 | # Wrapping is very natural for users. Still OK: single wrapper tensor, 394 | # solve by internal dispatch mechanism. Don't use super for compositional. 395 | # 396 | # Is there still a place for OO/multiple inheritance/MRO/super cooperative? 397 | # Traditional OO/mixin style programming, with self footgun (extra 398 | # expressivity typically not what you want). Mixin can be converted into 399 | # level but there isn't really any reason to do it. Natively interoperates 400 | # with traditional object dispatch. 401 | # 402 | # Literally the point was 403 | # - get the __torch_function__ subclass ordering to work in our favor 404 | # - noticed super() is a thing and want to use that as basis for dispatch 405 | # 406 | # - why not wrapper: because of metadata swapping 407 | # 408 | # https://fuhm.net/super-harmful/ 409 | -------------------------------------------------------------------------------- /flat_view_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils._pytree import tree_map 3 | from torch import nn 4 | from torch.nn import Parameter 5 | from typing import List 6 | 7 | # TODO: support tensor methods 8 | # TODO: use wrapper subclass so we don't leak the original parameter storage 9 | class IndirectParameter(torch.Tensor): 10 | def __new__(cls, indir, grad_indir): 11 | elem = indir() 12 | return cls._make_subclass(cls, elem, True) 13 | 14 | def __init__(self, indir, grad_indir): 15 | self.indir = indir 16 | self.grad_indir = grad_indir 17 | self._is_param = True 18 | 19 | @classmethod 20 | def __torch_function__(cls, func, types, args=(), kwargs=None): 21 | if kwargs is None: 22 | kwargs = {} 23 | 24 | # TODO: you could imagine some sort of caching mechanism, but 25 | # then you'd also need to design an invalidation mechanism 26 | def resolve_indir(t): 27 | if isinstance(t, IndirectParameter): 28 | return t.indir() 29 | else: 30 | return t 31 | 32 | return func(*tree_map(resolve_indir, args), **tree_map(resolve_indir, kwargs)) 33 | 34 | @property 35 | def grad(self): 36 | return self.grad_indir() 37 | 38 | # TODO: need to handle checkpointing(?) 
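
# A tiny standalone illustration (the _demo_* names are made up here and are
# separate from the FlattenParamsWrapper below): because the two callables are
# re-invoked on every use, rebinding the underlying storage is picked up
# transparently by existing IndirectParameter objects.
_demo_storage = {"w": torch.ones(3)}
_demo_p = IndirectParameter(
    lambda: _demo_storage["w"],
    lambda: _demo_storage["w"].grad,
)
print(_demo_p * 2)  # resolves _demo_storage["w"] at call time -> tensor([2., 2., 2.])
_demo_storage["w"] = torch.zeros(3)
print(_demo_p * 2)  # the same parameter now sees the new storage -> tensor([0., 0., 0.])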
39 | 40 | 41 | # model: 42 | # - as far as autograd is concerned, flat parameter is the only leaf 43 | # - as far as optimizer is concerned, real parameters are the only 44 | # parameters 45 | 46 | 47 | class FlattenParamsWrapper(nn.Module): 48 | def __init__(self, module, param_buckets: List[List[Parameter]]): 49 | super().__init__() 50 | self._module = module 51 | # TODO: shift the parameter level 52 | # find where the parameters live in the modules, install default 53 | # mapping 54 | shared_param_memo = {} 55 | self._underlying = {} 56 | self._transform = {} 57 | for submodule_name, submodule in module.named_modules(): 58 | for param_name, param in submodule.named_parameters(recurse=False): 59 | assert param not in shared_param_memo, "NYI" 60 | shared_param_memo[param] = (submodule, submodule_name, param_name) 61 | k = (submodule_name, param_name) 62 | self._underlying[k] = param 63 | self._transform[k] = lambda t: t 64 | for param, memo in shared_param_memo.items(): 65 | submodule, submodule_name, param_name = memo 66 | 67 | def mk_indirect_parameter(k): 68 | return IndirectParameter( 69 | lambda: self._transform[k](self._underlying[k]), 70 | lambda: self._transform[k](self._underlying[k].grad), 71 | ) 72 | new_p = mk_indirect_parameter((submodule_name, param_name)) 73 | 74 | delattr(submodule, param_name) 75 | # TODO: make this look like a parameter 76 | setattr(submodule, param_name, new_p) 77 | # go through the buckets and update the mapping into the flat 78 | # parameters 79 | # TODO: shared params are not handled. the aliasing should be detected 80 | # and the params coalesced into one location in the flat parameter 81 | # TODO: copying into a preallocated cat buffer save reallocation 82 | # TODO: this doesn't preserve memory format of the input parameters 83 | # TODO: check dtypes match 84 | for i, params in enumerate(param_buckets): 85 | flat_param = torch.cat([ 86 | p.detach().clone(memory_format=torch.contiguous_format).view(-1) 87 | for p in params 88 | ], dim=0) 89 | flat_param.requires_grad = True 90 | self.register_buffer(f"flat_param{i}", flat_param) 91 | offset = 0 92 | for p in params: 93 | submodule, submodule_name, param_name = shared_param_memo[p] 94 | k = (submodule_name, param_name) 95 | self._underlying[k] = flat_param 96 | 97 | def mk_transform(offset, numel, shape): 98 | def transform(t): 99 | if t is None: 100 | return t 101 | return t[offset:offset + numel].view(shape) 102 | return transform 103 | 104 | self._transform[k] = mk_transform(offset, p.numel(), p.shape) 105 | offset += p.numel() 106 | 107 | def forward(self, *args, **kwargs): 108 | return self._module(*args, **kwargs) 109 | 110 | model = nn.Sequential( 111 | nn.Linear(8, 4, bias=False), 112 | nn.Linear(4, 2, bias=False), 113 | ) 114 | 115 | B = 10 116 | input = torch.randn(B, 8) 117 | 118 | print(model(input)) 119 | 120 | model = FlattenParamsWrapper(model, [[model[0].weight, model[1].weight]]) 121 | print(model.flat_param0) 122 | print(type(model._module[0].weight)) 123 | print(model(input)) 124 | print(list(model.named_parameters())) 125 | 126 | loss = model(input).sum() 127 | loss.backward() 128 | print(model._module[0].weight.grad) 129 | print(model.flat_param0.grad) 130 | -------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -ex 3 | ufmt format -- *.py 4 | autoflake --remove-all-unused-imports --in-place -- *.py 5 | 
-------------------------------------------------------------------------------- /functorch_test.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import functools 4 | 5 | import torch 6 | from base_tensor import BaseTensor 7 | from torch import Tensor 8 | from torch.testing._internal.common_utils import run_tests, TestCase 9 | from torch.utils._pytree import tree_map 10 | from torch.overrides import enable_reentrant_dispatch 11 | 12 | from utils import no_dispatch 13 | 14 | # TODO: batched tensor (metadata doesn't match, so this needs more APIs) 15 | 16 | LEVEL = 0 17 | 18 | 19 | @contextlib.contextmanager 20 | def new_level(): 21 | global LEVEL 22 | LEVEL += 1 23 | try: 24 | yield LEVEL 25 | finally: 26 | LEVEL -= 1 27 | 28 | 29 | def unwrap(t, level): 30 | if isinstance(t, WrapperTensor) and t.level == level: 31 | return t.elem 32 | else: 33 | return t 34 | 35 | 36 | class WrapperTensor(BaseTensor): 37 | @staticmethod 38 | def __new__(cls, elem, level): 39 | # This is probably wrong for batched tensor, for autograd 40 | # it's good because make_subclass internally detaches. 41 | # no_dispatch is required to prevent detach form going to subclass. 42 | with no_dispatch(): 43 | return cls._make_subclass(cls, elem) 44 | 45 | def __repr__(self): 46 | return f"WrapperTensor{self.level}({super().__repr__()}, {repr(self.elem)})" 47 | 48 | def __init__(self, elem, level): 49 | self.elem = elem 50 | self.level = level 51 | 52 | @classmethod 53 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 54 | max_level = -1 55 | 56 | def find_level(t): 57 | nonlocal max_level 58 | if isinstance(t, cls): 59 | max_level = max(max_level, t.level) 60 | 61 | # TODO: don't use tree_map 62 | tree_map(find_level, args) 63 | tree_map(find_level, kwargs) 64 | 65 | def matches_level(t): 66 | return isinstance(t, cls) and t.level == max_level 67 | 68 | def unwrap(t): 69 | if matches_level(t): 70 | return t.elem 71 | else: 72 | return t 73 | 74 | def wrap(t): 75 | if isinstance(t, Tensor) and not matches_level(t): 76 | return cls(t, max_level) 77 | else: 78 | return t 79 | 80 | with enable_reentrant_dispatch(): 81 | return tree_map( 82 | wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) 83 | ) 84 | 85 | 86 | def grad_and_value(func): 87 | @functools.wraps(func) 88 | def wrapper(input): 89 | with new_level() as level: 90 | assert isinstance(input, torch.Tensor) 91 | input = WrapperTensor(input, level) 92 | input.requires_grad_() 93 | output = func(input) 94 | (grad_input,) = torch.autograd.grad( 95 | output, input, create_graph=True, allow_unused=True 96 | ) 97 | return unwrap(grad_input, level), unwrap(output, level) 98 | 99 | return wrapper 100 | 101 | 102 | def grad(func): 103 | @functools.wraps(func) 104 | def wrapper(input): 105 | grad_input, _ = grad_and_value(func)(input) 106 | return grad_input 107 | 108 | return wrapper 109 | 110 | 111 | class FunctorchTest(TestCase): 112 | def test_basic(self): 113 | x = torch.randn([]) 114 | result = grad(torch.sin)(x) 115 | self.assertEqual(result, torch.cos(x)) 116 | 117 | def test_grad_of_grad(self): 118 | x = torch.randn([]) 119 | result = grad(grad(lambda x: x**3))(x) 120 | self.assertEqual(result, 6 * x) 121 | 122 | 123 | if __name__ == "__main__": 124 | run_tests() 125 | -------------------------------------------------------------------------------- /inner_autograd_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import torch.nn.functional 3 | 4 | from base_tensor import BaseTensor 5 | from torch.testing._internal.common_utils import run_tests, TestCase 6 | from torch.utils._pytree import tree_map 7 | from utils import fill_defaults 8 | from torch.overrides import enable_reentrant_dispatch 9 | 10 | # This file describes how to use wrapper tensors (ala TrivialTensorViaComposition) 11 | # to override autograd from __torch_dispatch__. Ordinarily, 12 | # __torch_dispatch__ runs after autograd, so you have no way of overriding 13 | # the autograd behavior (since it will be handled after you return). However, 14 | # if we put the autograd tensor *inside* a wrapper tensor (which doesn't 15 | # itself require gradients), we get a chance to interpose (in __torch_dispatch__) 16 | # before you handle gradients on the inner element. 17 | # 18 | # Note that you can also use __torch_function__ instead to implement this 19 | # functionality, so this is mostly a question of whether or not you want to 20 | # target the public Python API, or the internal ATen operators API 21 | # (torch.ops.aten). 22 | 23 | 24 | class InnerAutogradTensor(BaseTensor): 25 | @staticmethod 26 | def __new__(cls, elem, *, requires_grad=None): 27 | # Outer tensor's autograd is now disconnected from the inner 28 | # tensors autograd... 29 | return super().__new__(cls, elem, requires_grad=False) 30 | 31 | def __init__(self, elem): 32 | # ... but note that we save the inner tensor, so we can still 33 | # do autograd on operations on the inside! 34 | self.elem = elem 35 | 36 | @classmethod 37 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 38 | def unwrap(t): 39 | if isinstance(t, cls): 40 | return t.elem 41 | elif isinstance(t, torch.Tensor) and t.requires_grad: 42 | # If any other argument at this level does require gradients 43 | # it will not interact with our inner Tensor and thus this 44 | # should fail. 45 | raise RuntimeError("Bad mixup of autograd level") 46 | else: 47 | return t 48 | 49 | def wrap(t): 50 | # Micro-optimization: not necessary to rewrap if the output tensor 51 | # doesn't require gradients 52 | if ( 53 | isinstance(t, torch.Tensor) 54 | and not isinstance(t, cls) 55 | and t.requires_grad 56 | ): 57 | return cls(t) 58 | else: 59 | return t 60 | 61 | with enable_reentrant_dispatch(): 62 | # Override gradient behavior 63 | if func == torch.ops.aten.embedding.default: 64 | args = fill_defaults(args, 5, [-1, False, False]) 65 | weight, indices, padding_idx, scale_grad_by_freq, _sparse = map( 66 | unwrap, args 67 | ) 68 | assert not kwargs 69 | # Force sparse gradients. We could have also done this by 70 | # defining a custom autograd function. 71 | return cls(func(weight, indices, padding_idx, scale_grad_by_freq, True)) 72 | 73 | return tree_map( 74 | wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) 75 | ) 76 | 77 | 78 | class InnerAutogradTensorTest(TestCase): 79 | def test_basic(self): 80 | x = torch.randn(1, requires_grad=True) 81 | y = InnerAutogradTensor(x) 82 | self.assertFalse(y.requires_grad) 83 | self.assertTrue(y.elem.requires_grad) 84 | z = InnerAutogradTensor(x) 85 | # Although y and z do not require grad, we are still able 86 | # to differentiate 87 | r = y + z 88 | # Note we have to extract out the inner tensor (which requires_grad) 89 | # to actually differentiate 90 | r.sum().elem.backward() 91 | self.assertEqual(x.grad, torch.tensor([2.0])) # two uses! 
92 | 93 | def test_embedding(self): 94 | input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]) 95 | weights = torch.rand(10, 3, requires_grad=True) 96 | embedding_matrix = InnerAutogradTensor(weights) 97 | r = torch.nn.functional.embedding(input, embedding_matrix) 98 | r.sum().elem.backward() 99 | # Gradient is sparse even though we didn't ask for it in embedding! 100 | self.assertTrue(weights.grad.is_sparse) 101 | 102 | def test_mixing(self): 103 | # Mixing behavior is confusing. Let's take a look 104 | w1 = torch.randn(1, requires_grad=True) 105 | w2 = torch.randn(1, requires_grad=True) 106 | 107 | # Autograd doesn't "unwrap" variables, they still remember if they 108 | # requires_grad; and in fact, inside __torch_dispatch__ it is willing 109 | # to mix gradients between multiple levels. The current class does 110 | # catch most of these though when it is looking at the different 111 | # arguments 112 | with self.assertRaisesRegex(RuntimeError, "Bad mixup of autograd level"): 113 | x = InnerAutogradTensor(w1) + w2 114 | 115 | 116 | if __name__ == "__main__": 117 | run_tests() 118 | -------------------------------------------------------------------------------- /logging_mode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils._python_dispatch import TorchDispatchMode 3 | from torch.utils._pytree import tree_map 4 | from torch.utils.weak import WeakIdRef 5 | from torch.testing._internal.common_utils import run_tests, TestCase 6 | import torch.overrides 7 | 8 | import weakref 9 | from functools import partial 10 | import itertools 11 | 12 | 13 | 14 | dtype_abbrs = { 15 | torch.bfloat16: "bf16", 16 | torch.float64: "f64", 17 | torch.float32: "f32", 18 | torch.float16: "f16", 19 | torch.complex32: "c32", 20 | torch.complex64: "c64", 21 | torch.complex128: "c128", 22 | torch.int8: "i8", 23 | torch.int16: "i16", 24 | torch.int32: "i32", 25 | torch.int64: "i64", 26 | torch.bool: "b8", 27 | torch.uint8: "u8", 28 | } 29 | 30 | 31 | class Lit: 32 | def __init__(self, s): 33 | self.s = s 34 | 35 | def __repr__(self): 36 | return self.s 37 | 38 | 39 | class LoggingMode(TorchDispatchMode): 40 | next_id: int 41 | 42 | def __init__(self, with_type: bool = True, collect_logs=False): 43 | self.memo = {} 44 | self.next_id = 0 45 | self.with_type = with_type 46 | self.collect_logs = collect_logs 47 | self.logs = [] 48 | 49 | def _shortid(self, t: torch.Tensor) -> int: 50 | o = WeakIdRef(t) 51 | weak_self = weakref.ref(self) 52 | 53 | def del_memo(): 54 | self = weak_self() 55 | if self is None: 56 | return 57 | self.memo.pop(o, None) 58 | 59 | weakref.finalize(t, del_memo) 60 | if o not in self.memo: 61 | self.memo[o] = self.next_id 62 | self.next_id += 1 63 | return self.memo[o] 64 | 65 | def _fmt(self, a: object, with_type: bool = False) -> str: 66 | if isinstance(a, torch.Tensor): 67 | maybe_type = "" 68 | if with_type and self.with_type: 69 | maybe_type = f": {dtype_abbrs[a.dtype]}[{', '.join(map(str, a.shape))}]" 70 | return Lit(f"${self._shortid(a)}{maybe_type}") 71 | else: 72 | return a 73 | 74 | def str_logs(self): 75 | return '\n'.join(self.logs) 76 | 77 | def __torch_dispatch__(self, func, types, args=(), kwargs=None): 78 | if kwargs is None: 79 | kwargs = {} 80 | rs = func(*args, **kwargs) 81 | fmt_args = ", ".join( 82 | itertools.chain( 83 | (repr(tree_map(self._fmt, a)) for a in args), 84 | (f"{k}={tree_map(self._fmt, v)}" for k, v in kwargs.items()), 85 | ) 86 | ) 87 | fmt_rets = repr(tree_map(partial(self._fmt, 
with_type=True), rs)) 88 | log_msg = f"{fmt_rets} = {torch.overrides.resolve_name(func)}({fmt_args})" 89 | if self.collect_logs: 90 | self.logs.append(log_msg) 91 | else: 92 | print(log_msg) 93 | return rs 94 | 95 | 96 | with LoggingMode(): 97 | torch.nn.functional.dropout(torch.randn(3), 0.5) 98 | 99 | 100 | class TracerTensorTest(TestCase): 101 | def test_basic(self): 102 | with LoggingMode(collect_logs=True) as mode: 103 | x = torch.randn(2, 3, requires_grad=True) 104 | y = torch.randn(3, 4) 105 | with torch.autocast('cpu'): 106 | r = x @ y 107 | r.sum().backward() 108 | self.assertExpectedInline( 109 | mode.str_logs(), 110 | """\ 111 | $0: f32[2, 3] = aten.randn.default([2, 3], device=cpu, pin_memory=False) 112 | $1: f32[3, 4] = aten.randn.default([3, 4], device=cpu, pin_memory=False) 113 | $2: bf16[3, 4] = aten._to_copy.default($1, dtype=torch.bfloat16) 114 | $3: bf16[2, 3] = aten._to_copy.default($0, dtype=torch.bfloat16) 115 | $4: bf16[2, 4] = aten.mm.default($3, $2) 116 | $5: bf16[] = aten.sum.default($4) 117 | $6: bf16[] = aten.ones_like.default($5, dtype=torch.bfloat16, layout=torch.strided, device=cpu, pin_memory=False, memory_format=torch.preserve_format) 118 | $7: bf16[2, 4] = aten.expand.default($6, [2, 4]) 119 | $8: bf16[4, 3] = aten.t.default($2) 120 | $9: bf16[2, 3] = aten.mm.default($7, $8) 121 | $10: f32[2, 3] = aten._to_copy.default($9, dtype=torch.float32, layout=torch.strided, device=cpu) 122 | $11: f32[2, 3] = aten.detach.default($10) 123 | $12: f32[2, 3] = aten.detach.default($11)""", # noqa 124 | ) 125 | 126 | if __name__ == "__main__": 127 | run_tests() 128 | -------------------------------------------------------------------------------- /max_mem_tracker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils._pytree import tree_map_only 3 | from torch.utils._python_dispatch import TorchDispatchMode 4 | from torch._subclasses.fake_tensor import FakeTensorMode 5 | from torch.utils.weak import WeakIdKeyDictionary 6 | import weakref 7 | import math 8 | 9 | # Track all the memory being used by Tensors. 10 | # Only max is tracked but others can be added. 
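# How the pieces below fit together: MEMORY_USE maps each tracked storage to a
# weakref to itself; update_stats() walks the live storages, rounds every
# allocation up to a multiple of PYTORCH_MIN_ALLOCATE, and bumps MEMORY_MAX;
# the weakref callback re-runs the stats when a storage is freed.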
11 | MEMORY_USE = WeakIdKeyDictionary() 12 | MEMORY_MAX = 0 13 | # Minimum allocation size 14 | PYTORCH_MIN_ALLOCATE = 2**9 15 | 16 | def update_stats(): 17 | global MEMORY_MAX 18 | curr_use = 0 19 | for k, v in MEMORY_USE.items(): 20 | curr_use += math.ceil(k.size() * k.element_size()/PYTORCH_MIN_ALLOCATE) * PYTORCH_MIN_ALLOCATE 21 | 22 | if MEMORY_MAX < curr_use: 23 | MEMORY_MAX = curr_use 24 | 25 | # Should be called on every Tensor created 26 | def track(t:torch.Tensor): 27 | def cb(_): 28 | update_stats() 29 | st = t.untyped_storage() 30 | wt = weakref.ref(st, cb) 31 | MEMORY_USE[st] = wt 32 | update_stats() 33 | 34 | # Use this Mode to call track on every Tensor being created by functions 35 | class MemoryTrackingMode(TorchDispatchMode): 36 | def __torch_dispatch__(self, func, types, args, kwargs=None): 37 | res = func(*args, **kwargs or {}) 38 | 39 | tree_map_only(torch.Tensor, track, res) 40 | return res 41 | 42 | 43 | if __name__ == "__main__": 44 | # Use FakeTensorMode to run the code without any actual data 45 | with FakeTensorMode(), MemoryTrackingMode(): 46 | def f(a): 47 | b = a * 10 48 | d = b + 3 49 | return d 50 | 51 | a = torch.rand(100) 52 | f(a) 53 | f(a) 54 | print(f"Just f: {MEMORY_MAX}") 55 | c = f(a) 56 | c = f(a) 57 | print(f"f with return: {MEMORY_MAX}") 58 | 59 | 60 | -------------------------------------------------------------------------------- /memory_debugging_tensor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Subclass memory debugging 3 | 4 | Automatically generated by Colaboratory. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/1vxJkMT1kTUpd9RoqEzf0i825fvKNQAZl 8 | """ 9 | 10 | import torch 11 | # Some mode APIs are changing on master while this runs on colab 12 | # print(torch.__version__) 13 | # if not torch.__version__.split("+")[0] == "1.12.1": 14 | # raise RuntimeError("This notebook is for pytorch 1.12") 15 | 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torchvision 19 | 20 | from collections import defaultdict 21 | from typing import Dict 22 | from contextlib import contextmanager 23 | 24 | from torch.utils._python_dispatch import TorchDispatchMode 25 | 26 | # We do want to use non-full hooks here as they behavior similar to 27 | # backward pre-hooks 28 | import warnings 29 | warnings.filterwarnings("ignore", "Using a non-full backward") 30 | 31 | MB = 1024 * 1024.0 32 | 33 | # Globals used to save state 34 | operator_names: Dict[str, int] = defaultdict(int) 35 | mem_usage: Dict[str, float] = defaultdict(float) 36 | max_mem_usage: Dict[str, float] = defaultdict(float) 37 | markers: Dict[str, int] = defaultdict(int) 38 | cur_module: str = "" 39 | op_id: int = 0 40 | 41 | def clear_state(): 42 | operator_names.clear() 43 | mem_usage.clear() 44 | max_mem_usage.clear() 45 | markers.clear() 46 | 47 | 48 | # To add markers in the final print 49 | def add_marker(marker_name): 50 | marker_val = len(mem_usage.values()) 51 | markers[marker_name] = marker_val 52 | 53 | def record_fn(fn_name): 54 | global op_id 55 | mem: float = torch.cuda.memory_allocated() / MB 56 | mem_usage[op_id] = (fn_name, mem) 57 | max_mem: float = torch.cuda.max_memory_allocated() / MB 58 | max_mem_usage[op_id] = (fn_name, max_mem) 59 | torch.cuda.reset_peak_memory_stats() 60 | op_id += 1 61 | 62 | # Mode that records all allocations 63 | class MemoryProfileDispatchMode(TorchDispatchMode): 64 | def __torch_dispatch__(self, func, types, args=..., 
kwargs=None): 65 | rs = func(*args, **kwargs) 66 | global cur_module 67 | if func == torch.ops.aten.detach.default: 68 | return rs 69 | func_name: str = cur_module + '.' + func.__name__ + "_" + str(operator_names[func.__name__]) 70 | operator_names[func.__name__] = operator_names[func.__name__] + 1 71 | record_fn(func_name) 72 | 73 | return rs 74 | 75 | # Functions to print and draw the graph 76 | def show_graph(): 77 | import matplotlib.pyplot as plt 78 | 79 | y = [gb for (name, gb) in mem_usage.values()] 80 | min_val = min(y) 81 | max_val = max(y) 82 | x = [i for i in range(len(y))] 83 | fig = plt.figure(figsize=(16,8)) 84 | plt.plot(x, list(y), label="memory") 85 | plt.xlabel("# Operator Calls") 86 | plt.ylabel("Allocated Memory (MB)") 87 | # plt.title(filename) 88 | for marker_name, marker in markers.items(): 89 | if marker_name == "fw_bw_boundary": 90 | plt.plot([marker, marker], [min_val, max_val], "r", lw=2, label=marker_name) 91 | else: 92 | plt.plot([marker, marker], [min_val, max_val], "k-", lw=2, label=marker_name) 93 | plt.legend() 94 | 95 | def print_top_mem_op(top: int = 50): 96 | global op_id 97 | op_diff: Dict[str, float] = defaultdict(float) 98 | op, pre_mem = mem_usage[0] 99 | for i in range(1, op_id): 100 | op, mem = mem_usage[i] 101 | op_diff[op] = mem - pre_mem 102 | pre_mem = mem 103 | 104 | print("------------------------------------------------") 105 | print(f"Top {top} ops that generates memory are:") 106 | for k, v in sorted(op_diff.items(), key=lambda item: item[1], reverse=True)[:top]: 107 | print(f"{k}: {v}MB") 108 | print("------------------------------------------------") 109 | 110 | 111 | # Module level printing and logging to make the Mode's output better 112 | def mem_profile_model(mod: torch.nn.Module, *args): 113 | with torch.utils._python_dispatch.push_torch_dispatch_mode(MemoryProfileDispatchMode): 114 | torch.cuda.reset_peak_memory_stats() 115 | mod.zero_grad(True) 116 | clear_state() 117 | record_fn("Start") 118 | loss = mod(*args) 119 | add_marker("fw_bw_boundary") 120 | if isinstance(loss, dict): 121 | loss = loss['out'] 122 | loss.sum().backward() 123 | add_marker("bw_zero_boundary") 124 | mod.zero_grad(set_to_none=True) 125 | record_fn("Finished") 126 | 127 | def fwd_wrapped(name): 128 | def fwd_debug_hook(module, input) -> None: 129 | global cur_module 130 | cur_module = f"{name}.forward" 131 | return fwd_debug_hook 132 | 133 | def bwd_wrapped(name): 134 | def bwd_debug_hook(module, input, out) -> None: 135 | global cur_module 136 | cur_module = f"{name}.backward" 137 | return bwd_debug_hook 138 | 139 | # this context manager attached/detecheds hooks for debugging 140 | @contextmanager 141 | def debug_model(model): 142 | global op_id, cur_module 143 | hooks = [] 144 | cur_module = 'forward' 145 | op_id = 0 146 | for name, module in model.named_modules(): 147 | hooks.append(module.register_forward_pre_hook(fwd_wrapped(name))) 148 | hooks.append(module.register_backward_hook(bwd_wrapped(name))) 149 | try: 150 | yield model 151 | finally: 152 | for hook in hooks: 153 | hook.remove() 154 | 155 | def run_one_model(net, input): 156 | net.cuda() 157 | input = input.cuda() 158 | 159 | with debug_model(net) as m: 160 | # mem_profile_model(m, input1, input2) 161 | mem_profile_model(m, input) 162 | print_top_mem_op(20) 163 | 164 | show_graph() 165 | 166 | import torchvision 167 | run_one_model(torchvision.models.resnet34(), torch.rand(32, 3, 224, 224, device="cuda")) 168 | 169 | import torchvision 170 | 
run_one_model(torchvision.models.mobilenet_v3_large(), torch.rand(32, 3, 224, 224, device="cuda")) 171 | 172 | import torchvision 173 | run_one_model(torchvision.models.segmentation.fcn_resnet50(), torch.rand(32, 3, 224, 224, device="cuda")) 174 | 175 | import torchvision 176 | run_one_model(torchvision.models.vision_transformer.vit_b_32(), torch.rand(32, 3, 224, 224, device="cuda")) -------------------------------------------------------------------------------- /nan_detect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils._python_dispatch import TorchDispatchMode 3 | from torch.utils._pytree import tree_flatten 4 | 5 | class NanDetect(TorchDispatchMode): 6 | def __torch_dispatch__(self, func, types, args, kwargs=None): 7 | kwargs = kwargs or {} 8 | res = func(*args, **kwargs) 9 | flat_res, _ = tree_flatten(res) 10 | 11 | for t in flat_res: 12 | if not torch.is_tensor(t): 13 | continue 14 | try: 15 | if (t != t).any(): 16 | raise RuntimeError( 17 | f"Function {func}(*{args}, **{kwargs}) " "returned a NaN" 18 | ) 19 | except NotImplementedError: 20 | pass 21 | return res 22 | 23 | a = torch.tensor([0.,]) 24 | print(a.div(a)) 25 | 26 | # This will raise 27 | # RuntimeError: Function aten.div.Tensor(*(tensor([0.]), tensor([0.])), **{}) returned a NaN 28 | with NanDetect(): 29 | print(a.div(a)) 30 | 31 | -------------------------------------------------------------------------------- /negative_tensor.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.autograd import Function 6 | from torch.testing._internal.common_utils import run_tests, TestCase 7 | from torch.utils._pytree import tree_map 8 | 9 | from utils import no_dispatch 10 | 11 | 12 | # This is a reimplementation of "negative tensor views" as currently 13 | # implemented in PyTorch core. This lets you represent a negation 14 | # without actually materializing the negated value, so it can be fused 15 | # with further operations. See also the PR that added this to PyTorch: 16 | # https://github.com/pytorch/pytorch/pull/56058 17 | class NegativeTensor(Tensor): 18 | @staticmethod 19 | def __new__(cls, elem): 20 | # At the moment, this class is not compositional, so we assert 21 | # that the tensor we're wrapping is exactly a Tensor 22 | assert type(elem) is Tensor 23 | 24 | # Note [Passing requires_grad=true tensors to subclasses] 25 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 26 | # Calling _make_subclass directly in an autograd context is 27 | # never the right thing to do, as this will detach you from 28 | # the autograd graph. You must create an autograd function 29 | # representing the "constructor" (NegativeView, in this case) 30 | # and call that instead. This assert helps prevent direct usage 31 | # (which is bad!) 32 | assert not elem.requires_grad or not torch.is_grad_enabled() 33 | 34 | # There is something very subtle going on here. In particular, 35 | # suppose that elem is a view. Does all of the view metadata 36 | # (sizes, strides, storages) get propagated correctly? Yes! 37 | # Internally, the way _make_subclass works is it creates an 38 | # alias (using Tensor.alias) of the original tensor, which 39 | # means we replicate storage/strides, but with the Python object 40 | # as an instance of your subclass. 
In other words, 41 | # _make_subclass is the "easy" case of metadata propagation, 42 | # because anything that alias() propagates, you will get in 43 | # your subclass. It is _make_wrapper_subclass which is 44 | # problematic... 45 | # 46 | # TODO: We need to think about how we want to turn this into 47 | # official API. I am thinking that something that does the 48 | # assert above and this call could be made into a utility function 49 | # that is in the public API 50 | return Tensor._make_subclass(cls, elem) 51 | 52 | def __repr__(self): 53 | with no_dispatch(): 54 | return repr(self.neg()) 55 | 56 | def physical_repr(self): 57 | with no_dispatch(): 58 | return f"negative_view({super().__repr__()})" 59 | 60 | # Without this, the default __torch_function__ implementation will 61 | # attempt to wrap the returned tensor for any operation in a NegativeView 62 | # (wrong wrong wrong) 63 | __torch_function__ = torch._C._disabled_torch_function_impl 64 | 65 | @classmethod 66 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 67 | # TODO: inplace and out 68 | 69 | # This implements fallback behavior, where we materialize the 70 | # negative view into a normal tensor, and then do the operation on 71 | # normal tensors. Because we eliminate all negative views before 72 | # performing our operation, no_dispatch() is not necessary here. 73 | def unwrap(t): 74 | if isinstance(t, cls): 75 | with no_dispatch(): 76 | return t.neg() 77 | else: 78 | return t 79 | 80 | return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) 81 | 82 | 83 | # A differentiable function that takes a negative view on a function. Use 84 | # this to construct NegativeTensors. 85 | class NegativeView(Function): 86 | @staticmethod 87 | def forward(ctx, input): 88 | # Exact type matches as NegativeTensor is not compositional 89 | if type(input) is NegativeTensor: 90 | # If we are passed a NegativeTensor, we can simply alias it as 91 | # a normal tensor and return it. 
92 | # TODO: this should be in standard library 93 | return torch.Tensor._make_subclass(torch.Tensor, input) 94 | elif type(input) is Tensor: 95 | return NegativeTensor(input) 96 | else: 97 | raise AssertionError("negative tensors are not yet compositional") 98 | 99 | @staticmethod 100 | def backward(ctx, grad): 101 | return negative_view(grad) 102 | 103 | 104 | negative_view = NegativeView.apply 105 | 106 | 107 | class NegativeTensorTest(TestCase): 108 | def test_construction(self): 109 | # NegativeTensor is semantically equivalent to negating the tensor 110 | self.assertEqual(NegativeTensor(torch.tensor(1)), torch.tensor(-1)) 111 | self.assertEqual(negative_view(torch.tensor(1)), torch.tensor(-1)) 112 | 113 | # The direct constructor is not valid in autograd contexts; you must 114 | # use negative_view 115 | self.assertRaises( 116 | Exception, lambda: NegativeTensor(torch.empty(1, requires_grad=True)) 117 | ) 118 | self.assertRaises( 119 | Exception, lambda: NegativeTensor(torch.empty(1, requires_grad=True).sum()) 120 | ) 121 | negative_view(torch.empty(1, requires_grad=True)) 122 | negative_view(torch.empty(1, requires_grad=True).sum()) 123 | 124 | # The tensor is aliases with its original 125 | x = torch.tensor(1.0) 126 | y = negative_view(x) 127 | self.assertEqual(y, torch.tensor(-1.0)) 128 | x.add_(1) 129 | self.assertEqual(y, torch.tensor(-2.0)) 130 | 131 | def test_repr(self): 132 | x = negative_view(torch.tensor(1)) 133 | 134 | # I decided to make the normal repr print "as if" it were a normal 135 | # tensor 136 | self.assertExpectedInline(repr(x), """tensor(-1)""") 137 | 138 | # physical_repr tells you if something funny is going on 139 | self.assertExpectedInline( 140 | x.physical_repr(), 141 | """\ 142 | negative_view(NegativeTensor(1))""", 143 | ) 144 | 145 | def test_functional(self): 146 | self.assertEqual(negative_view(torch.tensor(1)) + 1, torch.tensor(0)) 147 | 148 | def test_backward(self): 149 | base = torch.tensor(-1.0, requires_grad=True) 150 | x = negative_view(base) 151 | x.sum().backward() 152 | self.assertEqual(base.grad, torch.tensor(-1.0)) 153 | 154 | def test_negative_view_of_view(self): 155 | base = torch.zeros(2, 2) 156 | view = base[0] 157 | neg_view = negative_view(view) 158 | self.assertEqual(neg_view, torch.zeros(2)) 159 | base[0, 0].add_(1) 160 | base[0, 1].add_(2) 161 | base[1, 0].add_(3) 162 | base[1, 1].add_(4) 163 | self.assertEqual(neg_view, torch.tensor([-1.0, -2.0])) 164 | 165 | # autograd custom functions with views don't work 166 | # tracked in https://github.com/pytorch/pytorch/issues/73604 167 | @unittest.expectedFailure 168 | def test_view_backward(self): 169 | base = torch.tensor(1.0, requires_grad=True) 170 | z = base * 1 171 | x = negative_view(z) 172 | z.mul_(-1) 173 | # Uncomment this line, which manually recomputes the view, to make this 174 | # test pass while master is broken 175 | # x = negative_view(z) 176 | x.sum().backward() 177 | self.assertEqual(base.grad, torch.tensor(1.0)) 178 | 179 | @unittest.expectedFailure 180 | def test_non_subclass_view_backward(self): 181 | class Alias(Function): 182 | @staticmethod 183 | def forward(ctx, input): 184 | return input[:] 185 | 186 | @staticmethod 187 | def backward(ctx, grad): 188 | return grad 189 | 190 | base = torch.tensor([1.0], requires_grad=True) 191 | z = base * 1 192 | x = Alias.apply(z) 193 | z.mul_(-1) 194 | x.sum().backward() 195 | self.assertEqual(base.grad, torch.tensor([-1.0])) 196 | 197 | 198 | if __name__ == "__main__": 199 | run_tests() 200 | 
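# A small usage aside (illustrative only, mirroring what test_construction
# checks above): the view shares storage with its base, so updates to the base
# are reflected in the logical value, and the negation itself is only computed
# when an operation actually consumes the view in __torch_dispatch__.
_base = torch.tensor([1.0, 2.0])
_nv = negative_view(_base)
print(_nv.physical_repr())  # physically stores [1., 2.]
print(_nv + 0)              # materializes [-1., -2.] only at this point
_base.mul_(3)
print(_nv + 0)              # now [-3., -6.]; the base was never copied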
-------------------------------------------------------------------------------- /nested_forward_ad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import forward_ad as fwAD 3 | from torch import Tensor 4 | from torch.utils._pytree import tree_map 5 | from torch.testing._internal.common_utils import ( 6 | TestCase, 7 | run_tests, 8 | disable_gc, 9 | parametrize, 10 | instantiate_parametrized_tests, 11 | ) 12 | from torch.overrides import enable_reentrant_dispatch 13 | 14 | import functools 15 | import contextlib 16 | 17 | from utils import no_dispatch 18 | from base_tensor import BaseTensor 19 | 20 | # WARNING: 21 | # This class requires https://github.com/pytorch/pytorch/pull/73925 (that was reverted) 22 | # to properly work with forward AD implementation 23 | # If you get an error about "Expected this function to only be reached in inference mode" 24 | # then you don't have that patch! 25 | 26 | # This class wraps a pytorch dual Tensor and associates a level to it. 27 | # This allows to do multi-level forward AD even though pytorch only 28 | # support one level. 29 | class ForwardADTensor(BaseTensor): 30 | @staticmethod 31 | def __new__(cls, dual_t, *, level, ignore_no_grad=False): 32 | # Use this to check if the plain object has a forward grad or not while ignoring 33 | # all of the torch_dispatch handling 34 | with no_dispatch(): 35 | primal, tangent = fwAD.unpack_dual(dual_t) 36 | # Ensure we actually have a dual Tensor 37 | assert ( 38 | ignore_no_grad or tangent is not None 39 | ), "ForwardADTensor can only wrap Tensors with forward gradients" 40 | # Ensure that nesting is happening in the right order 41 | if isinstance(dual_t, cls): 42 | assert dual_t.level < level, "Level ordering is wrong!" 43 | res = super().__new__(cls, primal) 44 | return res 45 | 46 | def __repr__(self): 47 | # Use no_dispatch here to get a plain representation of this Tensor without any of the 48 | # torch_dispatch handling 49 | with no_dispatch(): 50 | self_repr = super().__repr__() 51 | return f"ForwardADTensor{self.level}({self_repr}, {self.elem!r})" 52 | 53 | def __init__(self, dual_t, *, level, ignore_no_grad=False): 54 | self.elem = dual_t 55 | self.level = level 56 | 57 | @classmethod 58 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 59 | # Detach is a special case here. 60 | # This special case is for the code from autograd that uses shallow_copy_and_detach 61 | # (which is rerouted to detach in torch_dispatch) and user code that calls detach 62 | # In this case, we want to only detach *one* level of forward grad. Since forward grad 63 | # is already handled before getting here, we just want to convert detach into alias before 64 | # applying it to the underlying Tensor. 65 | # We also need a special case to force wrapping even though there aren't any forward grad (yet) 66 | # as the forward grad will be associated to the result in the dispatcher on the return from this 67 | # call. 
68 | ignore_no_grad = False 69 | if func is torch.ops.aten.detach.default: 70 | ignore_no_grad = True 71 | func = torch.ops.aten.alias.default 72 | 73 | max_level = -1 74 | 75 | def find_level(t): 76 | nonlocal max_level 77 | if isinstance(t, cls): 78 | max_level = max(max_level, t.level) 79 | 80 | # TODO: don't use tree_map 81 | tree_map(find_level, args) 82 | tree_map(find_level, kwargs) 83 | 84 | def matches_level(t): 85 | return isinstance(t, cls) and t.level == max_level 86 | 87 | def unwrap(t): 88 | # All the Tensors at this level must be unpacked so that the new call into the 89 | # dispatcher will handle this level of forward AD 90 | if matches_level(t): 91 | return t.elem 92 | else: 93 | # If we get a forward AD Tensor here, its level have been handled in the dispatcher 94 | # call that lead to this torch dispatch. So now we want to just consider it as a 95 | # constant for level during the next call into the dispatcher. 96 | if ( 97 | isinstance(t, torch.Tensor) 98 | and fwAD.unpack_dual(t).tangent is not None 99 | ): 100 | return fwAD.unpack_dual(t).primal 101 | return t 102 | 103 | def wrap(t): 104 | if isinstance(t, Tensor) and not matches_level(t): 105 | # Only wrap Tensors that have a tangent 106 | # or are about to get one (when calling detach) 107 | tp, td = fwAD.unpack_dual(t) 108 | if td is not None or ignore_no_grad: 109 | return cls(t, level=max_level, ignore_no_grad=ignore_no_grad) 110 | return t 111 | 112 | with enable_reentrant_dispatch(): 113 | return tree_map( 114 | wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) 115 | ) 116 | 117 | 118 | class NestedForwardADTest(TestCase): 119 | def test_basic(self): 120 | # We could have a better UX for making sure forward AD is enabled. 121 | # For simplicity here, we just keep it enabled for all the test 122 | with fwAD.dual_level(): 123 | t_p = torch.rand(2) 124 | t_t = torch.rand(2) 125 | t = ForwardADTensor(fwAD.make_dual(t_p, t_t), level=0) 126 | out = t * 2 127 | out_p, out_t = fwAD.unpack_dual(out.elem) 128 | self.assertEqual(out_p, 2 * t_p) 129 | self.assertEqual(out_t, 2 * t_t) 130 | 131 | def test_nested(self): 132 | with fwAD.dual_level(): 133 | t_p = torch.rand(2) 134 | t_t = torch.rand(2) 135 | t = ForwardADTensor(fwAD.make_dual(t_p, t_t), level=1) 136 | 137 | t2_t = torch.rand(2) 138 | # There is only one order of nesting that makes sense! 139 | with self.assertRaisesRegex(AssertionError, "Level ordering is wrong!"): 140 | t2 = ForwardADTensor(fwAD.make_dual(t, t2_t), level=0) 141 | 142 | # Note that both gradients are on the primal. So we do *not* compute 143 | # higher order derivatives here! 
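            # (note: both t_t and t2_t here are tangents attached to the same primal value t_p;
            # to get second derivatives, the level-1 dual itself must be used as the tangent of
            # the level-2 dual, which is exactly what the second half of this test does below)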
144 | t2 = ForwardADTensor(fwAD.make_dual(t, t2_t), level=2) 145 | 146 | # Make sure that t2 has all the right metadata 147 | self.assertIsInstance(t2, ForwardADTensor) 148 | self.assertEqual(t2.level, 2) 149 | self.assertEqual(t2, t_p) 150 | self.assertIsNone(fwAD.unpack_dual(t2).tangent) 151 | elem = t2.elem 152 | self.assertIsInstance(elem, ForwardADTensor) 153 | self.assertEqual(elem.level, 1) 154 | self.assertEqual(elem, t_p) 155 | self.assertEqual(fwAD.unpack_dual(elem).tangent, t2_t) 156 | elem = elem.elem 157 | self.assertNotIsInstance(elem, ForwardADTensor) 158 | self.assertEqual(elem, t_p) 159 | self.assertEqual(fwAD.unpack_dual(elem).tangent, t_t) 160 | 161 | # Simple op that doesn't take extra arguments 162 | out = t2.exp() 163 | 164 | # Make sure that ops of t2 compute both levels of autograd independently 165 | self.assertIsInstance(out, ForwardADTensor) 166 | self.assertEqual(out.level, 2) 167 | self.assertEqual(out, t_p.exp()) 168 | self.assertIsNone(fwAD.unpack_dual(out).tangent) 169 | elem = out.elem 170 | self.assertIsInstance(elem, ForwardADTensor) 171 | self.assertEqual(elem.level, 1) 172 | self.assertEqual(elem, t_p.exp()) 173 | self.assertEqual(fwAD.unpack_dual(elem).tangent, t2_t * t_p.exp()) 174 | elem = elem.elem 175 | self.assertNotIsInstance(elem, ForwardADTensor) 176 | self.assertEqual(elem, t_p.exp()) 177 | self.assertEqual(fwAD.unpack_dual(elem).tangent, t_t * t_p.exp()) 178 | 179 | # Computing higher order derivatives now! 180 | t = ForwardADTensor(fwAD.make_dual(t_t, t2_t), level=1) 181 | t2 = ForwardADTensor(fwAD.make_dual(t_p, t), level=2) 182 | 183 | # Make sure that t2 has all the right metadata 184 | self.assertIsInstance(t2, ForwardADTensor) 185 | self.assertEqual(t2.level, 2) 186 | self.assertEqual(t2, t_p) 187 | self.assertIsNone(fwAD.unpack_dual(t2).tangent) 188 | elem = t2.elem 189 | self.assertNotIsInstance(elem, ForwardADTensor) 190 | self.assertEqual(elem, t_p) 191 | self.assertEqual(fwAD.unpack_dual(elem).tangent, t_t) 192 | elem = fwAD.unpack_dual(elem).tangent 193 | self.assertIsInstance(elem, ForwardADTensor) 194 | self.assertEqual(elem.level, 1) 195 | self.assertEqual(elem, t_t) 196 | self.assertIsNone(fwAD.unpack_dual(elem).tangent) 197 | elem = elem.elem 198 | self.assertNotIsInstance(elem, ForwardADTensor) 199 | self.assertEqual(elem, t_t) 200 | self.assertEqual(fwAD.unpack_dual(elem).tangent, t2_t) 201 | 202 | # An op with different first and second derivative 203 | out = t2.pow(2) 204 | 205 | # Make sure that ops of t2 computes higher order derivatives 206 | self.assertIsInstance(out, ForwardADTensor) 207 | self.assertEqual(out.level, 2) 208 | self.assertEqual(out, t_p.pow(2)) 209 | self.assertIsNone(fwAD.unpack_dual(out).tangent) 210 | elem = out.elem 211 | self.assertNotIsInstance(elem, ForwardADTensor) 212 | self.assertEqual(elem, t_p.pow(2)) 213 | self.assertEqual(fwAD.unpack_dual(elem).tangent, t_t * 2 * t_p) 214 | elem = fwAD.unpack_dual(elem).tangent 215 | self.assertIsInstance(elem, ForwardADTensor) 216 | self.assertEqual(elem.level, 1) 217 | self.assertEqual(elem, t_t * 2 * t_p) 218 | self.assertIsNone(fwAD.unpack_dual(elem).tangent) 219 | elem = elem.elem 220 | self.assertNotIsInstance(elem, ForwardADTensor) 221 | self.assertEqual(elem, t_t * 2 * t_p) 222 | self.assertEqual(fwAD.unpack_dual(elem).tangent, t2_t * 2 * t_p) 223 | 224 | def test_no_confusion(self): 225 | # This test ensure that we don't do "perturbation confusion" 226 | # meaning that gradients at each levels are indeed computed independently 227 | # and don't 
interact with each other 228 | with fwAD.dual_level(): 229 | t_p = torch.rand(2) 230 | t_t = torch.rand(2) 231 | t = ForwardADTensor(fwAD.make_dual(t_p, t_t), level=0) 232 | t2_p = torch.rand(2) 233 | t2_t = torch.rand(2) 234 | t2 = ForwardADTensor(fwAD.make_dual(t2_p, t2_t), level=1) 235 | 236 | mixed_out = t * t2 237 | 238 | mixed_out_lvl1_p, mixed_out_lvl1_t = fwAD.unpack_dual(mixed_out.elem) 239 | mixed_out_lvl0_p, mixed_out_lvl0_t = fwAD.unpack_dual(mixed_out.elem.elem) 240 | self.assertEqual(mixed_out_lvl1_p, t_p * t2_p) 241 | self.assertEqual(mixed_out_lvl1_t, t2_t * t_p) 242 | self.assertEqual(mixed_out_lvl0_p, t_p * t2_p) 243 | self.assertEqual(mixed_out_lvl0_t, t_t * t2_p) 244 | 245 | 246 | if __name__ == "__main__": 247 | run_tests() 248 | -------------------------------------------------------------------------------- /new_device.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.overrides import TorchFunctionMode 3 | from torch.utils._pytree import tree_map 4 | 5 | import numpy as np 6 | 7 | aten = torch.ops.aten 8 | 9 | # 1. A Tensor that stores custom raw_data and implement functions for it 10 | class MyDeviceTensor(torch.Tensor): 11 | IMPLEMENTATIONS = {} 12 | 13 | @staticmethod 14 | def __new__(cls, size, dtype, raw_data=None, requires_grad=False): 15 | # Use a meta Tensor here to be used as the wrapper 16 | return torch.Tensor._make_subclass( 17 | cls, 18 | torch.empty(size, dtype=dtype, device="meta"), 19 | require_grad=requires_grad, 20 | ) 21 | 22 | def __init__(self, size, dtype, raw_data=None, requires_grad=False): 23 | # Store any provided user raw_data 24 | self.raw_data = raw_data 25 | 26 | def __repr__(self): 27 | st = super().__repr__() 28 | st = st.replace("device='meta'", "device='my_device'") 29 | # Print the content the best way possible 30 | new_content = "[" + ", ".join(str(el) for el in self.raw_data) + "]" 31 | st = st.replace("...", new_content) 32 | return st 33 | 34 | @classmethod 35 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 36 | if func in cls.IMPLEMENTATIONS: 37 | try: 38 | 39 | def super_fn(*args, **kwargs): 40 | return super(cls, MyDeviceTensor).__torch_dispatch__( 41 | func, types, args, kwargs 42 | ) 43 | 44 | return cls.IMPLEMENTATIONS[func](super_fn, *args, **kwargs or {}) 45 | except Exception as e: 46 | print(e) 47 | raise e 48 | raise RuntimeError( 49 | f"No implementation for 'my_device' for {func}, {args}, {kwargs}" 50 | ) 51 | 52 | 53 | # Convenient wrapper to register functions 54 | def implements(func): 55 | def _inner_fn(impl): 56 | MyDeviceTensor.IMPLEMENTATIONS[func] = impl 57 | return impl 58 | 59 | return _inner_fn 60 | 61 | 62 | # Add some ops 63 | @implements(aten.add.Tensor) 64 | def add(super_fn, t1, t2): 65 | # You can do whatever you want with the raw data here 66 | # In particular, this can call any c++ code as needed. 
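    # (in this demo raw_data happens to be a NumPy array, so the add below is a plain NumPy add)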
67 | out = t1.raw_data + t2.raw_data 68 | return MyDeviceTensor(t1.size(), t1.dtype, raw_data=out) 69 | 70 | 71 | @implements(aten.mul.Tensor) 72 | def mul(super_fn, t1, t2): 73 | # If unsure what should be the result's properties, you can 74 | # use the super_fn (can be useful for type promotion) 75 | meta_out = super_fn(t1, t2) 76 | 77 | out = t1.raw_data * t2.raw_data 78 | return MyDeviceTensor(meta_out.size(), meta_out.dtype, raw_data=out) 79 | 80 | 81 | # Add some trivial ops that need impl 82 | @implements(aten.detach.default) 83 | @implements(aten.alias.default) 84 | def detach(super_fn, self): 85 | return super_fn(self) 86 | 87 | 88 | # 2. A mode that allows us to override factory functions 89 | # This needs to be a torch function mode before the arg parser creates a device 90 | # based on the passed string, so we need to change it before reaching the arg parser 91 | class MyDeviceMode(TorchFunctionMode): 92 | IMPLEMENTATIONS = {} 93 | 94 | def __torch_function__(self, func, types, args=(), kwargs=None): 95 | def super_fn(*args, **kwargs): 96 | # Disable torch_function by hand because we don't want the wrapping behavior of 97 | # the super() impl 98 | with torch._C.DisableTorchFunction(): 99 | return func(*args, **kwargs) 100 | 101 | if func in self.IMPLEMENTATIONS: 102 | try: 103 | return self.IMPLEMENTATIONS[func](super_fn, *args, **kwargs or {}) 104 | except Exception as e: 105 | print(e) 106 | raise e 107 | # This is just a no-op for all the non-factory functions: 108 | return super_fn(*args, **kwargs or {}) 109 | 110 | 111 | # Convenient wrapper to register functions 112 | def implements_factory(func): 113 | def _inner_fn(impl): 114 | MyDeviceMode.IMPLEMENTATIONS[func] = impl 115 | return impl 116 | 117 | return _inner_fn 118 | 119 | 120 | def enable_my_device(): 121 | # Globally enable the mode 122 | holder = MyDeviceMode() 123 | holder.__enter__() 124 | 125 | 126 | # And some factory functions 127 | # By hand 128 | @implements_factory(torch.Tensor.to) 129 | def to(super_fn, self, device): 130 | # Note that we only implement a subset of .to() here 131 | if device == "my_device": 132 | return MyDeviceTensor(self.size(), self.dtype, self.numpy()) 133 | elif isinstance(self, MyDeviceTensor): 134 | return torch.from_numpy(self.raw_data).to(device) 135 | else: 136 | return super_fn(self, device) 137 | 138 | 139 | # Have a nicer way to add many factories 140 | def get_factory_wrapper(func): 141 | def inner(super_fn, size, **kwargs): 142 | if str(kwargs.get("device", None)) != "my_device": 143 | return super_fn(size, **kwargs) 144 | 145 | return MyDeviceTensor(size, kwargs.get("dtype", torch.float32), func(size)) 146 | 147 | return inner 148 | 149 | 150 | implements_factory(torch.rand)(get_factory_wrapper(np.random.rand)) 151 | implements_factory(torch.arange)(get_factory_wrapper(np.arange)) 152 | implements_factory(torch.empty)(get_factory_wrapper(np.empty)) 153 | 154 | 155 | if __name__ == "__main__": 156 | enable_my_device() 157 | 158 | # 3. 
Show what it does in practice 159 | size = (2, 2) 160 | t1 = MyDeviceTensor(size, torch.float32, np.ones(size)) 161 | t2 = MyDeviceTensor(size, torch.float32, np.arange(size[0] * size[1]).reshape(size)) 162 | print("Inputs:") 163 | print(t1) 164 | print(t2) 165 | 166 | out = torch.add(t1, t2) 167 | print("torch.add(t1, t2):") 168 | print(out) 169 | 170 | out = t1 * t2 171 | print("t1 * t2:") 172 | print(out) 173 | 174 | # Factory functions 175 | t1 = torch.empty(4, device="my_device") 176 | print("Empty Tensor (un-initialized memory!):") 177 | print(t1) 178 | 179 | t1 = torch.rand(4, device="my_device") 180 | print("Random Tensor:") 181 | print(t1) 182 | 183 | t1 = torch.arange(4, device="my_device") 184 | print("Arange Tensor:") 185 | print(t1) 186 | 187 | t1 = torch.rand(5) 188 | print("Cpu Tensor:") 189 | print(t1) 190 | print("t2 = t1.to('my_device'):") 191 | t2 = t1.to("my_device") 192 | print(t2) 193 | print("t2.to('cpu'):") 194 | print(t2.to("cpu")) 195 | -------------------------------------------------------------------------------- /numerical_consistency_mode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils._python_dispatch import TorchDispatchMode 3 | from torch.utils._pytree import tree_map_only 4 | 5 | 6 | # The goal of this mode is to check that two devices are consistent 7 | # in what they compute. 8 | # We do NOT run the two models in parallel; we only branch at the op 9 | # level to make sure the two branches don't slowly diverge. 10 | 11 | def as_tuple(o): 12 | if isinstance(o, tuple): 13 | return o 14 | else: 15 | return (o,) 16 | 17 | class ConsistentWithCPUMode(TorchDispatchMode): 18 | def __torch_dispatch__(self, func, types, args, kwargs): 19 | orig_out = func(*args, **kwargs) 20 | 21 | # Run the same thing on CPU 22 | # and convert original outputs to CPU 23 | cpu_args, cpu_kwargs, orig_cpu_out = tree_map_only(torch.Tensor, 24 | lambda x: x.cpu(), 25 | (args, kwargs, orig_out)) 26 | cpu_out = func(*cpu_args, **cpu_kwargs) 27 | 28 | # Make sure the output is close enough! 29 | for orig, cpu in zip(as_tuple(orig_cpu_out), as_tuple(cpu_out)): 30 | torch.testing.assert_close(orig, cpu) 31 | 32 | 33 | return orig_out 34 | 35 | t = torch.rand(100, device="cuda") 36 | 37 | # This should work just fine! 38 | with ConsistentWithCPUMode(): 39 | t2 = t + 2 40 | t3 = t2.norm() 41 | t4 = t2 / t3 42 | 43 | 44 | # Let's break some cuda impl! 45 | def my_new_norm_is_actually_a_mean(t): 46 | return t.mean() 47 | 48 | aten = torch.library.Library("aten", "IMPL") 49 | aten.impl("linalg_vector_norm", my_new_norm_is_actually_a_mean, "CUDA") 50 | 51 | # We should see that the impl is not correct anymore!
52 | with ConsistentWithCPUMode(): 53 | t2 = t + 2 54 | try: 55 | t3 = t2.norm() 56 | except AssertionError as e: 57 | print("Norm evaluation failed as expected:") 58 | print(e) 59 | else: 60 | raise AssertionError("Error was not raised!") 61 | 62 | -------------------------------------------------------------------------------- /progressive_lowering_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import no_dispatch 3 | from torch.utils._pytree import tree_map 4 | from torch.testing._internal.common_utils import run_tests, TestCase 5 | 6 | CALLED = [] 7 | 8 | 9 | class ProgressiveLoweringTensor(torch.Tensor): 10 | @classmethod 11 | def wrap(cls, t): 12 | if isinstance(t, torch.Tensor) and not isinstance(t, cls): 13 | return cls(t) 14 | else: 15 | return t 16 | 17 | @classmethod 18 | def __torch_function__(cls, func, types, args=(), kwargs=None): 19 | if kwargs is None: 20 | kwargs = {} 21 | if func is torch.Tensor.relu: 22 | CALLED.append(func) 23 | with torch._C.DisableTorchFunction(): 24 | with no_dispatch(): 25 | return tree_map(cls.wrap, func(*args, **kwargs)) 26 | else: 27 | with torch._C.DisableTorchFunction(): 28 | return func(*args, **kwargs) 29 | 30 | @classmethod 31 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 32 | CALLED.append(func) 33 | with no_dispatch(): 34 | return tree_map(cls.wrap, func(*args, **kwargs)) 35 | 36 | 37 | class ProgressiveLoweringTensorTest(TestCase): 38 | def test_basic(self): 39 | CALLED.clear() 40 | x = ProgressiveLoweringTensor(torch.randn(2)) 41 | x.add(2).relu() 42 | # add call is low level aten op; relu call is high level torch 43 | # op 44 | self.assertEqual(CALLED, [torch.ops.aten.add.Tensor, torch.Tensor.relu]) 45 | 46 | 47 | if __name__ == "__main__": 48 | run_tests() 49 | -------------------------------------------------------------------------------- /py_dispatcher.py: -------------------------------------------------------------------------------- 1 | from torch._dispatch.python import enable_python_dispatcher, no_python_dispatcher 2 | import torch 3 | 4 | @torch.ops.aten.sub.Tensor.py_impl(torch._C.DispatchKey.CPU) 5 | def my_sub(x, y): 6 | print("Hello") 7 | # This private API permits dispatcher to return to Python dispatcher if 8 | # there are internal dispatches. 
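    # (the commented-out call below shows that API; the code actually used here instead
    # disables the Python dispatcher and redispatches to the regular C++ kernel)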
9 | # return torch.ops.aten.sub.Tensor._op_dk(torch._C.DispatchKey.CPU, x, y) 10 | with no_python_dispatcher(): 11 | return torch.ops.aten.sub.Tensor(x, y) 12 | 13 | x = torch.tensor(2) 14 | with enable_python_dispatcher(): 15 | print(torch.sub(x, x)) 16 | 17 | # Hack to apply it globally 18 | ctx = enable_python_dispatcher() 19 | ctx.__enter__() 20 | 21 | print(torch.sub(x, x)) 22 | -------------------------------------------------------------------------------- /python_meta_tensor.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is obsolete; check out https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py for the modern way to do this 3 | """ 4 | -------------------------------------------------------------------------------- /quantization_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils._pytree import _register_pytree_node 3 | from torch.fx.experimental.proxy_tensor import make_fx 4 | import torch.fx 5 | from torch.fx import subgraph_rewriter 6 | from torch.overrides import wrap_torch_function 7 | 8 | # We use a QTensor class to conveniently pass around both int8 tensor 9 | # as well as scale and zero point necessary to quantize/dequantize them 10 | # when we do graph transformations. 11 | 12 | # TODO: Name this something more specific like PerTensorAffineInt8QuantizedTensor 13 | class QTensor: 14 | tensor: torch.Tensor 15 | # NB: you could represent these as scalar tensors if you need to 16 | # trace through them 17 | scale: float 18 | zero_point: int 19 | 20 | def __init__(self, tensor, scale, zero_point): 21 | self.tensor = tensor 22 | self.scale = scale 23 | self.zero_point = zero_point 24 | 25 | # NB: wrap_torch_function so that this "factory" function can be 26 | # symbolically traced as is. This is not strictly necessary but 27 | # makes constructing patterns more convenient. 28 | @staticmethod 29 | @wrap_torch_function(lambda t, x, y: (t, )) 30 | def quantize(t: torch.Tensor, scale: float, zero_point: int): 31 | i8_min = torch.iinfo(torch.int8).min 32 | i8_max = torch.iinfo(torch.int8).max 33 | # This formula is probably not quite right, fix it as necessary 34 | return QTensor( 35 | torch.clamp(torch.round(t / scale).to(torch.int64) + zero_point, i8_min, i8_max).to(torch.int8), 36 | scale, 37 | zero_point 38 | ) 39 | 40 | def dequantize(self): 41 | return (self.tensor.to(torch.int64) - self.zero_point) * self.scale 42 | 43 | # We register it as a pytree node, as in the final graph we want QTensor 44 | # to be eliminated completely (aka QTensor is an entirely out of core concept) 45 | # TODO: We probably could have made QTensor a named tuple and then wouldn't 46 | # need explicit flatten/unflatten 47 | 48 | def _qtensor_flatten(q): 49 | return [q.tensor, q.scale, q.zero_point], None 50 | 51 | def _qtensor_unflatten(values, context): 52 | return QTensor(*values) 53 | 54 | _register_pytree_node(QTensor, _qtensor_flatten, _qtensor_unflatten) 55 | 56 | # Let's take a simple model that runs linear twice 57 | 58 | def f(inp, linear_weight): 59 | r = torch.nn.functional.linear(inp, linear_weight) 60 | return torch.nn.functional.linear(r, linear_weight) 61 | 62 | # We use the pattern matching API to look for occurrences of linear. 63 | 64 | # We use make_fx to generate the sequence of ATen ops that correspond to a 65 | # linear call. 
Note that this pattern is only valid if there aren't any 66 | # conditions on, e.g., the shapes of the input tensor. In general you 67 | # may need a pattern for every permutation of how a composite operator may 68 | # lower; you can get all of them by running through a sufficiently large 69 | # number of example inputs. 70 | # TODO: use symbolic shapes here; this would give you a series of guards 71 | # that would tell you what input sizes the pattern is valid for. 72 | linear_pattern = make_fx(lambda i, w: torch.nn.functional.linear(i, w))(torch.randn(0, 0), torch.randn(0, 0)) 73 | 74 | # In reality we would first insert observers, and then actually 75 | # insert quantize/dequantize nodes. In this PoC, I skip observers 76 | # and go straight to quantize/dequantize, and make up random crap for 77 | # the observed quantities. 78 | def linear_replace_fn(i, w): 79 | fp_i = i.dequantize() 80 | fp_w = w.dequantize() 81 | fp_r = torch.nn.functional.linear(fp_i, fp_w) 82 | # TODO: get the scale and zero_point from observer 83 | return QTensor.quantize(fp_r, 5.0, 0) 84 | linear_replace = torch.fx.symbolic_trace(linear_replace_fn) 85 | 86 | # We first trace out the ATen OP IR of the original model 87 | inp = torch.randn(3, 4) 88 | weight = torch.randn(4, 4) 89 | g = make_fx(f)(inp, weight) 90 | print(g) 91 | 92 | # Now, we replace occurrences of linear with quantize/dequantize 93 | subgraph_rewriter.replace_pattern(g, linear_pattern, linear_replace) 94 | print(g) 95 | 96 | # Finally, we retrace the model to get lowered operations in terms 97 | # of only pure PyTorch operations. 98 | # TODO: use an interpreter here to preserve stack traces 99 | g2 = make_fx(g)(QTensor(inp, 5.0, 0), QTensor(weight, 5.0, 0)) 100 | print(g2) 101 | -------------------------------------------------------------------------------- /quantized_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import _pytree as pytree 3 | from torch.autograd import Function 4 | 5 | # This example shows how to implement a QuantizedTensor and how to get it to interact with 6 | # autograd smoothly. 7 | # This is ONLY the QuantizedTensor, which does a few things: 8 | # - Holds only the low precision data 9 | # - Routes the implementation to the right custom kernel when available 10 | # - Performs type promotion to use the fallback path when a custom kernel is not available 11 | # - "Pretends" to be a full precision floating point Tensor to the outside world 12 | 13 | class Quantizer(Function): 14 | @staticmethod 15 | def forward(ctx, base): 16 | # Just to do the quantization 17 | out_data = base.to(torch.int8) 18 | return QuantizedTensor(out_data, base.dtype) 19 | 20 | @staticmethod 21 | def backward(ctx, gO): 22 | # Assume we always do gradient computation in full precision 23 | return gO 24 | 25 | # Small util, should exist somewhere else? 26 | def compare_dtype(d1, d2): 27 | if d1.is_floating_point: 28 | return d1 29 | elif d2.is_floating_point: 30 | return d2 31 | else: 32 | assert False, "NYI" 33 | 34 | class QuantizedTensor(torch.Tensor): 35 | @staticmethod 36 | def __new__(cls, data, dtype, requires_grad=False): 37 | # This constructor can ONLY create leaf Tensors wrt autograd. 38 | # Use QuantizedTensor.from_tensor(t) to get a non-leaf Tensor wrt autograd.
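        # _make_wrapper_subclass gives us a Tensor with the public (full precision) dtype and size
        # but no storage of its own; the actual int8 payload lives in self._data, set in __init__.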
39 | return torch.Tensor._make_wrapper_subclass(cls, data.size(), dtype=dtype, requires_grad=requires_grad) 40 | 41 | def __init__(self, data, dtype, requires_grad=False): 42 | self._data = data 43 | 44 | __torch_function__ = torch._C._disabled_torch_function_impl 45 | 46 | def __repr__(self): # Zero out missing values for printing 47 | autograd_info = f", grad_fn={self.grad_fn}" if self.grad_fn else f", requires_grad=True" if self.requires_grad else "" 48 | return f"QuantizedTensor({self._data}, public_dtype={self.dtype}{autograd_info})" 49 | 50 | @classmethod 51 | def from_tensor(cls, base): 52 | # This is a differentiable function!! 53 | return Quantizer.apply(base) 54 | 55 | @classmethod 56 | def __torch_dispatch__(cls, func, types, args, kwargs=None): 57 | # Basic implementation that will need refinement based on what should be upcasted or not 58 | # similar to amp. 59 | # For now, just do the compute in the highest precision of any input and requantize 60 | # like the first one. While ignoring all non-floating point dtypes. 61 | base_qt_tensor = None 62 | for a in args: 63 | if isinstance(a, QuantizedTensor): 64 | base_qt_tensor = a 65 | break 66 | assert base_qt_tensor is not None 67 | inp_dtype = base_qt_tensor._data.dtype 68 | out_public_dtype = base_qt_tensor.dtype 69 | # Unpack QuantizedTensor 70 | args, kwargs = pytree.tree_map_only(QuantizedTensor, lambda x: x._data, (args, kwargs or {})) 71 | # Get highest dtype 72 | highest_type = inp_dtype 73 | def check_type(t): 74 | nonlocal highest_type 75 | if t.dtype.is_floating_point and compare_dtype(t.dtype, highest_type): 76 | highest_type = t.dtype 77 | pytree.tree_map_only(torch.Tensor, check_type, (args, kwargs)) 78 | # Promote everything to the right dtype 79 | args, kwargs = pytree.tree_map_only(torch.Tensor, lambda x: x.to(highest_type) if x.dtype.is_floating_point else x, (args, kwargs)) 80 | # Run the original function with the new dtype 81 | # This can also be a custom kernel if you need 82 | raw_out = func(*args, **kwargs) 83 | # Rewrap everything back 84 | # Since we're below autograd, we don't need to use from_tensor 85 | def repack(t): 86 | if t.dtype is highest_type: 87 | if highest_type.is_floating_point: 88 | # Requantize back to input dtype if we computed in float 89 | return QuantizedTensor(t.to(inp_dtype), out_public_dtype) 90 | else: 91 | # Otherwise keep it as-is 92 | return QuantizedTensor(t, out_public_dtype) 93 | # Just a hack for sum that has higher precision result, shouldn't happen if you have 94 | # custom kernels 95 | elif func is torch.ops.aten.sum.default and t.dtype is torch.int64: 96 | return QuantizedTensor(t, out_public_dtype) 97 | else: 98 | return t 99 | out = pytree.tree_map_only(torch.Tensor, repack, raw_out) 100 | return out 101 | 102 | 103 | inp = torch.randint(0, 100, (2,), dtype=torch.float, requires_grad=True) 104 | qt = QuantizedTensor.from_tensor(inp) 105 | print("Input 1") 106 | print(qt) 107 | 108 | (qt * 3).sum().backward(retain_graph=True) 109 | print("Raw input 1's grad") 110 | print(inp.grad) 111 | 112 | qt2 = QuantizedTensor.from_tensor(torch.randint(0, 100, (2,), dtype=torch.float)).requires_grad_() 113 | print("Input 2") 114 | print(qt2) 115 | 116 | (qt2 * qt).sum().backward() 117 | print("Input 2's grad") 118 | print(qt2.grad) 119 | print("Raw input 1's grad") 120 | print(inp.grad) 121 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # These 
requirements are needed to be able to run all the tests 2 | # Anything added here will automatically be installed in CI 3 | pytest 4 | numpy 5 | sympy 6 | expecttest 7 | matplotlib 8 | unittest-xml-reporting -------------------------------------------------------------------------------- /run_test.py: -------------------------------------------------------------------------------- 1 | # Files that need to be run manually 2 | files_to_run = [ 3 | "autograd_monkeypatch", 4 | "cuda_sanitizer", 5 | "data_parallel_tensor", 6 | "deferred_init", 7 | "dynamic_shapes", 8 | "dynamic_strides", 9 | # "enhanced_error_mode", # Actually raises an error 10 | "flat_view_tensor", 11 | "new_device", 12 | "py_dispatcher", 13 | "memory_debugging_tensor", 14 | "quantization_transform", 15 | "quantized_tensor", 16 | "simple_functorch", 17 | "torchdynamo_dynamic_inference", 18 | "tracing_guards", 19 | "use_cpu_for_rng", 20 | ] 21 | cuda_only_files = { 22 | "cuda_sanitizer", 23 | "memory_debugging_tensor", 24 | } 25 | 26 | # Files with actual tests 27 | import torch 28 | from torch.testing._internal.common_utils import run_tests 29 | from bug_zoo import BugZoo 30 | from empty_tensor import EmptyTensorTest 31 | from functorch_test import FunctorchTest 32 | from inner_autograd_tensor import InnerAutogradTensorTest 33 | from logging_mode import TracerTensorTest 34 | from negative_tensor import NegativeTensorTest 35 | # from nested_forward_ad import NestedForwardADTest 36 | from progressive_lowering_tensor import ProgressiveLoweringTensorTest 37 | from sparse_output import SparseOutputTest 38 | from tracer_tensor import TracerTensorTest 39 | from trivial_tensors import TrivialTensorTest 40 | from verifier_tensor import VerifierTensorTest 41 | 42 | if __name__ == "__main__": 43 | import os 44 | for file in files_to_run: 45 | print(f"Running {file}:") 46 | if (not torch.cuda.is_available()) and file in cuda_only_files: 47 | print("Skipped as cuda is not available") 48 | continue 49 | ret = os.system(f"python {file}.py 1> /dev/null 2>/dev/null") 50 | if ret != 0: 51 | print("Failure:") 52 | ret = os.system(f"python {file}.py") 53 | exit(-1) 54 | else: 55 | print("All good!") 56 | 57 | run_tests() 58 | -------------------------------------------------------------------------------- /sparse_output.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.testing._internal.common_utils import run_tests, TestCase 3 | from torch.utils._python_dispatch import TorchDispatchMode 4 | 5 | """ 6 | From Christian: 7 | 8 | One category is select operators. The most stereotypical example is 9 | nn.Embedding (and historically it was the reason we introduced sparsity). Part 10 | of sparse gradient support is also preventing further spread of the 11 | "sparse_grad" kwarg (e.g. gather 12 | (https://pytorch.org/docs/master/generated/torch.gather.html#torch.gather)) 13 | and getting rid of torch.sparse.sum (sometimes sparse grad sometimes not 14 | https://pytorch.org/docs/master/generated/torch.sparse.sum.html#torch.sparse.sum 15 | ) or torch.sparse.mm. 16 | 17 | The other category are binary ops. This is also where the output layout choice 18 | comes from. 19 | 20 | I wrote up an issue overview that categories things 21 | https://docs.google.com/document/d/12wN0RPFoavSxIYtvtRTD5cv0fN1FlRhOkaOAFYCfxEI/edit# 22 | - checkout the section under "mul". There's also 23 | https://github.com/pytorch/pytorch/issues/8853 . 
24 | """ 25 | 26 | 27 | class SparseOutputMode(TorchDispatchMode): 28 | def __torch_dispatch__(self, func, types, args=(), kwargs=None): 29 | if func == torch.ops.aten.mul: 30 | # TODO: this algorithm is probably not what you actually want to do 31 | # run the multiply 32 | r = func(*args, **kwargs) 33 | # sparsify it 34 | return r.to_sparse() 35 | 36 | return func(*args, **kwargs) 37 | 38 | 39 | def sparse_output(func, *args, **kwargs): 40 | with SparseOutputMode(): 41 | return func(*args, **kwargs) 42 | 43 | 44 | class SparseOutputTest(TestCase): 45 | def test_mul(self): 46 | x = torch.randn(3, requires_grad=True) 47 | y = torch.randn(3, requires_grad=True) 48 | r = sparse_output(torch.mul, torch.diag(x), torch.diag(y)) 49 | self.assertEqual( 50 | r, 51 | torch.sparse_coo_tensor( 52 | torch.tensor([[0, 1, 2], [0, 1, 2]], dtype=torch.long), x * y 53 | ), 54 | ) 55 | # This doesn't work yet because this results in a sparse-dense 56 | # multiply which is not supported 57 | # r.values().sum().backward() 58 | 59 | 60 | if __name__ == "__main__": 61 | run_tests() 62 | -------------------------------------------------------------------------------- /torchdynamo_dynamic_inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "63da03a1", 7 | "metadata": { 8 | "lines_to_end_of_cell_marker": 0, 9 | "lines_to_next_cell": 1 10 | }, 11 | "outputs": [ 12 | { 13 | "data": { 14 | "text/plain": [ 15 | "" 16 | ] 17 | }, 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "output_type": "execute_result" 21 | } 22 | ], 23 | "source": [ 24 | "import functools\n", 25 | "import itertools\n", 26 | "import traceback\n", 27 | "from dataclasses import dataclass, field\n", 28 | "from enum import Enum\n", 29 | "from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union\n", 30 | "\n", 31 | "import torch\n", 32 | "\n", 33 | "torch.manual_seed(0)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "12e2ab68", 39 | "metadata": {}, 40 | "source": [ 41 | "This notebook explains how Jason Ansel's proposal for very simple\n", 42 | "dynamic shapes in TorchDynamo works in\n", 43 | "https://github.com/facebookresearch/torchdynamo/issues/38\n", 44 | "\n", 45 | "The general model for torchdynamo graphs is that they consist of a\n", 46 | "set of guards plus a trace. The guards say whether or not the trace\n", 47 | "is valid; if it is not, torchdynamo must redo its analysis and\n", 48 | "recompile the graph in question.\n", 49 | "\n", 50 | "In this simplified model, we will model torchdynamo graphs as a\n", 51 | "dead simple AST (in reality, you need a graph representation to\n", 52 | "specify ordering of operations, sharing and multiple outputs, but\n", 53 | "they are not relevant to this example so I've dumped them.)\n", 54 | "\n", 55 | "First, we define various operations on the graph. add and mul\n", 56 | "do what you expect: they perform a (broadcasting) PyTorch add and\n", 57 | "mul. `dynamic_param` and `static_param` both represent inputs\n", 58 | "to the graph. The distinction is that `dynamic_param` inputs\n", 59 | "correspond to inputs which are fully dynamic: their shapes can\n", 60 | "vary from execution to execution of the graph. 
`static_param`\n", 61 | "inputs, on the other hand, are required to be some specific size.\n" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 2, 67 | "id": "4d57a5aa", 68 | "metadata": { 69 | "lines_to_end_of_cell_marker": 0, 70 | "lines_to_next_cell": 1 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "@dataclass(frozen=True)\n", 75 | "class Op:\n", 76 | " name: str\n", 77 | "\n", 78 | " def __str__(self):\n", 79 | " return self.name\n", 80 | "\n", 81 | "\n", 82 | "v_dynamic_param = Op(\"v_dynamic_param\")\n", 83 | "v_static_param = Op(\"v_static_param\")\n", 84 | "v_add = Op(\"v_add\")\n", 85 | "v_mul = Op(\"v_mul\")" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "24264adb", 91 | "metadata": {}, 92 | "source": [ 93 | "We can stitch these operations together in an AST of expressions\n", 94 | "of operators applied to some other expressions (and possibly some\n", 95 | "other, static metadata)." 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 3, 101 | "id": "47ccaa14", 102 | "metadata": { 103 | "lines_to_end_of_cell_marker": 0, 104 | "lines_to_next_cell": 1 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "\n", 109 | "@dataclass(eq=False)\n", 110 | "class Node:\n", 111 | " op: Op\n", 112 | " inputs: List[\"Node\"] = field(default_factory=list)\n", 113 | " params: Dict[str, Any] = field(default_factory=dict)\n", 114 | "\n", 115 | " def __repr__(self):\n", 116 | " inputs_str = \", \".join(repr(i) for i in self.inputs)\n", 117 | " params_str = \"\"\n", 118 | " if self.inputs and self.params:\n", 119 | " params_str += \", \"\n", 120 | " params_str += \", \".join(\n", 121 | " f\"{k}={v}\"\n", 122 | " for k, v in self.params.items()\n", 123 | " if k != \"size\" and self.op is v_dynamic_param\n", 124 | " )\n", 125 | " return f\"{self.op}({inputs_str}{params_str})\"\n" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "id": "dfcd54fb", 131 | "metadata": {}, 132 | "source": [ 133 | "And then we can write an interpreter for these inputs. Notice that\n", 134 | "we fetch parameters from an environment that's passed into the\n", 135 | "interpreter; if the parameter is dynamic we pass it in directly,\n", 136 | "but if it's static, we first check that the size of the parameter\n", 137 | "is consistent with the saved size." 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 4, 143 | "id": "924b75a6", 144 | "metadata": { 145 | "lines_to_end_of_cell_marker": 0, 146 | "lines_to_next_cell": 1 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "\n", 151 | "INTERP_RULES = {}\n", 152 | "INTERP_RULES[v_add] = lambda x, y: x + y\n", 153 | "INTERP_RULES[v_mul] = lambda x, y: x * y\n", 154 | "\n", 155 | "\n", 156 | "def interp_node(n: Node, env: Dict[Node, torch.Tensor]):\n", 157 | " if n.op is v_dynamic_param:\n", 158 | " return env[n.params['name']]\n", 159 | " elif n.op is v_static_param:\n", 160 | " r = env[n.params['name']]\n", 161 | " assert (\n", 162 | " r.shape == n.params[\"size\"]\n", 163 | " ), f\"static shape mismatch: {r.shape} and {n.params['size']}\"\n", 164 | " return r\n", 165 | " args = [interp_node(i, env) for i in n.inputs]\n", 166 | " return INTERP_RULES[n.op](*args, **n.params)\n" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "id": "14a70302", 172 | "metadata": {}, 173 | "source": [ 174 | "In actual torchdynamo, we can construct our IR directly via\n", 175 | "bytecode analysis. 
But this isn't really necessary for our\n", 176 | "example here; we can use an ordinary tracer to construct the IR as\n", 177 | "well. Our tracer is very simple." 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 5, 183 | "id": "036777d0", 184 | "metadata": { 185 | "lines_to_next_cell": 1 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "@dataclass\n", 190 | "class Variable:\n", 191 | " tensor: torch.Tensor\n", 192 | " node: Node\n", 193 | "\n", 194 | " # This will be implemented later\n", 195 | " def size(self):\n", 196 | " return variable_size(self)\n", 197 | "\n", 198 | " @staticmethod\n", 199 | " def param(tensor: torch.Tensor, name: str):\n", 200 | " # Save the observed shape, but by default dynamic_param won't\n", 201 | " # check it!\n", 202 | " return Variable(tensor, Node(v_dynamic_param, [], {\"name\": name, \"size\": tensor.shape}))\n", 203 | "\n", 204 | " def __mul__(self, rhs: \"Variable\") -> \"Variable\":\n", 205 | " r_tensor = self.tensor * rhs.tensor\n", 206 | " r_node = Node(v_mul, [self.node, rhs.node])\n", 207 | " return Variable(r_tensor, r_node)\n", 208 | "\n", 209 | " def __add__(self, rhs: \"Variable\") -> \"Variable\":\n", 210 | " r_tensor = self.tensor + rhs.tensor\n", 211 | " r_node = Node(v_add, [self.node, rhs.node])\n", 212 | " return Variable(r_tensor, r_node)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "id": "8d05d056", 218 | "metadata": {}, 219 | "source": [ 220 | "With this, we can run a simple example, print out the IR for it,\n", 221 | "and then rerun it. By default, we treat the inputs as dynamics,\n", 222 | "so we are allowed to rerun the IR even though the input sizes have\n", 223 | "changed (because there is nothing shape specific in the IR.)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 6, 229 | "id": "42443acd", 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "a = Variable.param(torch.randn(4), \"a\")\n", 234 | "b = Variable.param(torch.randn(4), \"b\")\n", 235 | "r = a * b" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 7, 241 | "id": "6b22da5f", 242 | "metadata": {}, 243 | "outputs": [ 244 | { 245 | "name": "stdout", 246 | "output_type": "stream", 247 | "text": [ 248 | "v_mul(v_dynamic_param(name=a), v_dynamic_param(name=b))\n" 249 | ] 250 | } 251 | ], 252 | "source": [ 253 | "print(r.node)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 8, 259 | "id": "c8fb3092", 260 | "metadata": { 261 | "lines_to_next_cell": 1 262 | }, 263 | "outputs": [ 264 | { 265 | "name": "stdout", 266 | "output_type": "stream", 267 | "text": [ 268 | "tensor([-0.7916, -0.4439, -0.6567, 0.2004, -0.9429])\n" 269 | ] 270 | } 271 | ], 272 | "source": [ 273 | "print(interp_node(r.node, {\"a\": torch.randn(5), \"b\": torch.randn(1)}))" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "id": "2b0b306f", 279 | "metadata": {}, 280 | "source": [ 281 | "Now, the problem is what happens if a user wants to vary the\n", 282 | "behavior of their computation based on the size of their input?\n", 283 | "Then our trace is no longer valid in this situation!\n", 284 | "\n", 285 | "torchdynamo deals with this situation by looking for explicit uses\n", 286 | "of sizes. 
If there is an explicit use of a size, it goes ahead\n", 287 | "and conservatively marks all of the parameters which could have\n", 288 | "contributed to the size of this tensor as static, indicating that\n", 289 | "the trace is now only valid for those specific sizes." 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 9, 295 | "id": "c32d30d6", 296 | "metadata": { 297 | "lines_to_end_of_cell_marker": 0, 298 | "lines_to_next_cell": 1 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "\n", 303 | "def input_sources(node):\n", 304 | " r = set()\n", 305 | " for i in node.inputs:\n", 306 | " r |= input_sources(i)\n", 307 | " if node.op is v_dynamic_param:\n", 308 | " r.add(node)\n", 309 | " return r\n", 310 | "\n", 311 | "def variable_size(self):\n", 312 | " for i in input_sources(self.node):\n", 313 | " # change it from dynamic to static. (the parameter\n", 314 | " # already saved the size, we don't need to recover it)\n", 315 | " i.op = v_static_param\n", 316 | " return self.tensor.size()\n" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "id": "e151cbc8", 322 | "metadata": {}, 323 | "source": [ 324 | "Now if we have dependent control flow on an input, we will\n", 325 | "appropriately fail if you pass in mismatching sizes." 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 10, 331 | "id": "7a5b2a4f", 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "\n", 336 | "a = Variable.param(torch.randn(4), \"a\")\n", 337 | "b = Variable.param(torch.randn(4), \"b\")\n", 338 | "if a.size()[0] == 4:\n", 339 | " r = a + b\n", 340 | "else:\n", 341 | " r = a * b\n" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": 11, 347 | "id": "ca4017fa", 348 | "metadata": {}, 349 | "outputs": [ 350 | { 351 | "name": "stdout", 352 | "output_type": "stream", 353 | "text": [ 354 | "v_add(v_static_param(), v_dynamic_param(name=b))\n" 355 | ] 356 | } 357 | ], 358 | "source": [ 359 | "print(r.node)" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 12, 365 | "id": "22c28e13", 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "name": "stdout", 370 | "output_type": "stream", 371 | "text": [ 372 | "tensor([-0.3506, -0.0163, 0.1710, 0.5453])\n" 373 | ] 374 | } 375 | ], 376 | "source": [ 377 | "print(interp_node(r.node, {\"a\": torch.randn(4), \"b\": torch.randn(4)}))" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 13, 383 | "id": "cfa558aa", 384 | "metadata": {}, 385 | "outputs": [ 386 | { 387 | "name": "stderr", 388 | "output_type": "stream", 389 | "text": [ 390 | "Traceback (most recent call last):\n", 391 | " File \"/var/folders/11/bcmcs8d57q7dxbysb4w_h1ym0000gn/T/ipykernel_65823/2739262870.py\", line 2, in \n", 392 | " print(interp_node(r.node, {\"a\": torch.randn(5), \"b\": torch.randn(1)}))\n", 393 | " File \"/var/folders/11/bcmcs8d57q7dxbysb4w_h1ym0000gn/T/ipykernel_65823/4116253730.py\", line 15, in interp_node\n", 394 | " args = [interp_node(i, env) for i in n.inputs]\n", 395 | " File \"/var/folders/11/bcmcs8d57q7dxbysb4w_h1ym0000gn/T/ipykernel_65823/4116253730.py\", line 15, in \n", 396 | " args = [interp_node(i, env) for i in n.inputs]\n", 397 | " File \"/var/folders/11/bcmcs8d57q7dxbysb4w_h1ym0000gn/T/ipykernel_65823/4116253730.py\", line 11, in interp_node\n", 398 | " assert (\n", 399 | "AssertionError: static shape mismatch: torch.Size([5]) and torch.Size([4])\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "try:\n", 405 | " 
print(interp_node(r.node, {\"a\": torch.randn(5), \"b\": torch.randn(1)}))\n", 406 | "except Exception:\n", 407 | " traceback.print_exc()" 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "id": "cd961281", 413 | "metadata": {}, 414 | "source": [ 415 | "It will still work even if the shape check is done an intermediate\n", 416 | "computation (in this case, both a and b are marked as dynamic)." 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 14, 422 | "id": "ee8a24c1", 423 | "metadata": {}, 424 | "outputs": [], 425 | "source": [ 426 | "\n", 427 | "a = Variable.param(torch.randn(1), \"a\")\n", 428 | "b = Variable.param(torch.randn(1), \"b\")\n", 429 | "c = a + b\n", 430 | "if c.size()[0] == 1:\n", 431 | " r = a + c\n", 432 | "else:\n", 433 | " r = a * c" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 15, 439 | "id": "17f83008", 440 | "metadata": {}, 441 | "outputs": [ 442 | { 443 | "name": "stdout", 444 | "output_type": "stream", 445 | "text": [ 446 | "v_add(v_static_param(), v_add(v_static_param(), v_static_param()))\n" 447 | ] 448 | } 449 | ], 450 | "source": [ 451 | "print(r.node)" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 16, 457 | "id": "3eaf8e72", 458 | "metadata": {}, 459 | "outputs": [ 460 | { 461 | "name": "stderr", 462 | "output_type": "stream", 463 | "text": [ 464 | "Traceback (most recent call last):\n", 465 | " File \"/var/folders/11/bcmcs8d57q7dxbysb4w_h1ym0000gn/T/ipykernel_65823/81332487.py\", line 2, in \n", 466 | " print(interp_node(r.node, {\"a\": torch.randn(1), \"b\": torch.randn(5)}))\n", 467 | " File \"/var/folders/11/bcmcs8d57q7dxbysb4w_h1ym0000gn/T/ipykernel_65823/4116253730.py\", line 15, in interp_node\n", 468 | " args = [interp_node(i, env) for i in n.inputs]\n", 469 | " File \"/var/folders/11/bcmcs8d57q7dxbysb4w_h1ym0000gn/T/ipykernel_65823/4116253730.py\", line 15, in \n", 470 | " args = [interp_node(i, env) for i in n.inputs]\n", 471 | " File \"/var/folders/11/bcmcs8d57q7dxbysb4w_h1ym0000gn/T/ipykernel_65823/4116253730.py\", line 15, in interp_node\n", 472 | " args = [interp_node(i, env) for i in n.inputs]\n", 473 | " File \"/var/folders/11/bcmcs8d57q7dxbysb4w_h1ym0000gn/T/ipykernel_65823/4116253730.py\", line 15, in \n", 474 | " args = [interp_node(i, env) for i in n.inputs]\n", 475 | " File \"/var/folders/11/bcmcs8d57q7dxbysb4w_h1ym0000gn/T/ipykernel_65823/4116253730.py\", line 11, in interp_node\n", 476 | " assert (\n", 477 | "AssertionError: static shape mismatch: torch.Size([5]) and torch.Size([1])\n" 478 | ] 479 | } 480 | ], 481 | "source": [ 482 | "try:\n", 483 | " print(interp_node(r.node, {\"a\": torch.randn(1), \"b\": torch.randn(5)}))\n", 484 | "except Exception:\n", 485 | " traceback.print_exc()" 486 | ] 487 | }, 488 | { 489 | "cell_type": "markdown", 490 | "id": "989c9887", 491 | "metadata": {}, 492 | "source": [ 493 | "This analysis is VERY conservative. Although there are some easy\n", 494 | "improvements you can apply, you are limited in the precision you can\n", 495 | "have without having shape formulas for operators that can propagate\n", 496 | "dynamic shapes. With shape formulas, you can track exact dependencies\n", 497 | "on a size-by-size basis; if you matrix multiply two tensors C = A @ B,\n", 498 | "a use of C.shape[0] will only add a guard for A.shape[0], and a use of\n", 499 | "C.shape[1] will only add a guard for B.shape[1]. 
The analysis here\n", 500 | "will just make both A and B static, and we cannot do any better\n", 501 | "without more knowledge of formulas. This suggests that an important\n", 502 | "workstream to improve precision is to get dynamic-aware shape formulas\n", 503 | "in Python for as many operators as possible." 504 | ] 505 | } 506 | ], 507 | "metadata": { 508 | "jupytext": { 509 | "formats": "ipynb,py:light" 510 | }, 511 | "kernelspec": { 512 | "display_name": "Python 3 (ipykernel)", 513 | "language": "python", 514 | "name": "python3" 515 | }, 516 | "language_info": { 517 | "codemirror_mode": { 518 | "name": "ipython", 519 | "version": 3 520 | }, 521 | "file_extension": ".py", 522 | "mimetype": "text/x-python", 523 | "name": "python", 524 | "nbconvert_exporter": "python", 525 | "pygments_lexer": "ipython3", 526 | "version": "3.8.11" 527 | } 528 | }, 529 | "nbformat": 4, 530 | "nbformat_minor": 5 531 | } 532 | -------------------------------------------------------------------------------- /torchdynamo_dynamic_inference.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # formats: ipynb,py:light 5 | # text_representation: 6 | # extension: .py 7 | # format_name: light 8 | # format_version: '1.5' 9 | # jupytext_version: 1.13.7 10 | # kernelspec: 11 | # display_name: Python 3 (ipykernel) 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # + 17 | import traceback 18 | from dataclasses import dataclass, field 19 | from typing import Any, Dict, List 20 | 21 | import torch 22 | 23 | torch.manual_seed(0) 24 | # - 25 | 26 | # This notebook explains how Jason Ansel's proposal for very simple 27 | # dynamic shapes in TorchDynamo works in 28 | # https://github.com/facebookresearch/torchdynamo/issues/38 29 | # 30 | # The general model for torchdynamo graphs is that they consist of a 31 | # set of guards plus a trace. The guards say whether or not the trace 32 | # is valid; if it is not, torchdynamo must redo its analysis and 33 | # recompile the graph in question. 34 | # 35 | # In this simplified model, we will model torchdynamo graphs as a 36 | # dead simple AST (in reality, you need a graph representation to 37 | # specify ordering of operations, sharing and multiple outputs, but 38 | # they are not relevant to this example so I've dumped them.) 39 | # 40 | # First, we define various operations on the graph. add and mul 41 | # do what you expect: they perform a (broadcasting) PyTorch add and 42 | # mul. `dynamic_param` and `static_param` both represent inputs 43 | # to the graph. The distinction is that `dynamic_param` inputs 44 | # correspond to inputs which are fully dynamic: their shapes can 45 | # vary from execution to execution of the graph. `static_param` 46 | # inputs, on the other hand, are required to be some specific size. 47 | # 48 | 49 | # + 50 | @dataclass(frozen=True) 51 | class Op: 52 | name: str 53 | 54 | def __str__(self): 55 | return self.name 56 | 57 | 58 | v_dynamic_param = Op("v_dynamic_param") 59 | v_static_param = Op("v_static_param") 60 | v_add = Op("v_add") 61 | v_mul = Op("v_mul") 62 | # - 63 | 64 | # We can stitch these operations together in an AST of expressions 65 | # of operators applied to some other expressions (and possibly some 66 | # other, static metadata). 
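# For example, tracing a * b produces
# v_mul(v_dynamic_param(name=a), v_dynamic_param(name=b)), which is what the
# print(r.node) calls below display.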
67 | 68 | # + 69 | 70 | 71 | @dataclass(eq=False) 72 | class Node: 73 | op: Op 74 | inputs: List["Node"] = field(default_factory=list) 75 | params: Dict[str, Any] = field(default_factory=dict) 76 | 77 | def __repr__(self): 78 | inputs_str = ", ".join(repr(i) for i in self.inputs) 79 | params_str = "" 80 | if self.inputs and self.params: 81 | params_str += ", " 82 | params_str += ", ".join( 83 | f"{k}={v}" 84 | for k, v in self.params.items() 85 | if k != "size" and self.op is v_dynamic_param 86 | ) 87 | return f"{self.op}({inputs_str}{params_str})" 88 | 89 | 90 | # - 91 | 92 | # And then we can write an interpreter for these inputs. Notice that 93 | # we fetch parameters from an environment that's passed into the 94 | # interpreter; if the parameter is dynamic we pass it in directly, 95 | # but if it's static, we first check that the size of the parameter 96 | # is consistent with the saved size. 97 | 98 | # + 99 | 100 | INTERP_RULES = {} 101 | INTERP_RULES[v_add] = lambda x, y: x + y 102 | INTERP_RULES[v_mul] = lambda x, y: x * y 103 | 104 | 105 | def interp_node(n: Node, env: Dict[Node, torch.Tensor]): 106 | if n.op is v_dynamic_param: 107 | return env[n.params["name"]] 108 | elif n.op is v_static_param: 109 | r = env[n.params["name"]] 110 | assert ( 111 | r.shape == n.params["size"] 112 | ), f"static shape mismatch: {r.shape} and {n.params['size']}" 113 | return r 114 | args = [interp_node(i, env) for i in n.inputs] 115 | return INTERP_RULES[n.op](*args, **n.params) 116 | 117 | 118 | # - 119 | 120 | # In actual torchdynamo, we can construct our IR directly via 121 | # bytecode analysis. But this isn't really necessary for our 122 | # example here; we can use an ordinary tracer to construct the IR as 123 | # well. Our tracer is very simple. 124 | 125 | 126 | @dataclass 127 | class Variable: 128 | tensor: torch.Tensor 129 | node: Node 130 | 131 | # This will be implemented later 132 | def size(self): 133 | return variable_size(self) 134 | 135 | @staticmethod 136 | def param(tensor: torch.Tensor, name: str): 137 | # Save the observed shape, but by default dynamic_param won't 138 | # check it! 139 | return Variable( 140 | tensor, Node(v_dynamic_param, [], {"name": name, "size": tensor.shape}) 141 | ) 142 | 143 | def __mul__(self, rhs: "Variable") -> "Variable": 144 | r_tensor = self.tensor * rhs.tensor 145 | r_node = Node(v_mul, [self.node, rhs.node]) 146 | return Variable(r_tensor, r_node) 147 | 148 | def __add__(self, rhs: "Variable") -> "Variable": 149 | r_tensor = self.tensor + rhs.tensor 150 | r_node = Node(v_add, [self.node, rhs.node]) 151 | return Variable(r_tensor, r_node) 152 | 153 | 154 | # With this, we can run a simple example, print out the IR for it, 155 | # and then rerun it. By default, we treat the inputs as dynamics, 156 | # so we are allowed to rerun the IR even though the input sizes have 157 | # changed (because there is nothing shape specific in the IR.) 158 | 159 | a = Variable.param(torch.randn(4), "a") 160 | b = Variable.param(torch.randn(4), "b") 161 | r = a * b 162 | 163 | print(r.node) 164 | 165 | print(interp_node(r.node, {"a": torch.randn(5), "b": torch.randn(1)})) 166 | 167 | # Now, the problem is what happens if a user wants to vary the 168 | # behavior of their computation based on the size of their input? 169 | # Then our trace is no longer valid in this situation! 170 | # 171 | # torchdynamo deals with this situation by looking for explicit uses 172 | # of sizes. 
If there is an explicit use of a size, it goes ahead 173 | # and conservatively marks all of the parameters which could have 174 | # contributed to the size of this tensor as static, indicating that 175 | # the trace is now only valid for those specific sizes. 176 | 177 | # + 178 | 179 | 180 | def input_sources(node): 181 | r = set() 182 | for i in node.inputs: 183 | r |= input_sources(i) 184 | if node.op is v_dynamic_param: 185 | r.add(node) 186 | return r 187 | 188 | 189 | def variable_size(self): 190 | for i in input_sources(self.node): 191 | # change it from dynamic to static. (the parameter 192 | # already saved the size, we don't need to recover it) 193 | i.op = v_static_param 194 | return self.tensor.size() 195 | 196 | 197 | # - 198 | 199 | # Now if we have dependent control flow on an input, we will 200 | # appropriately fail if you pass in mismatching sizes. 201 | 202 | # + 203 | 204 | a = Variable.param(torch.randn(4), "a") 205 | b = Variable.param(torch.randn(4), "b") 206 | if a.size()[0] == 4: 207 | r = a + b 208 | else: 209 | r = a * b 210 | 211 | # - 212 | 213 | print(r.node) 214 | 215 | print(interp_node(r.node, {"a": torch.randn(4), "b": torch.randn(4)})) 216 | 217 | try: 218 | print(interp_node(r.node, {"a": torch.randn(5), "b": torch.randn(1)})) 219 | except Exception: 220 | traceback.print_exc() 221 | 222 | # It will still work even if the shape check is done on an intermediate 223 | # computation (in this case, both a and b are marked as static). 224 | 225 | # + 226 | 227 | a = Variable.param(torch.randn(1), "a") 228 | b = Variable.param(torch.randn(1), "b") 229 | c = a + b 230 | if c.size()[0] == 1: 231 | r = a + c 232 | else: 233 | r = a * c 234 | # - 235 | 236 | print(r.node) 237 | 238 | try: 239 | print(interp_node(r.node, {"a": torch.randn(1), "b": torch.randn(5)})) 240 | except Exception: 241 | traceback.print_exc() 242 | 243 | # This analysis is VERY conservative. Although there are some easy 244 | # improvements you can apply, you are limited in the precision you can 245 | # have without having shape formulas for operators that can propagate 246 | # dynamic shapes. With shape formulas, you can track exact dependencies 247 | # on a size-by-size basis; if you matrix multiply two tensors C = A @ B, 248 | # a use of C.shape[0] will only add a guard for A.shape[0], and a use of 249 | # C.shape[1] will only add a guard for B.shape[1]. The analysis here 250 | # will just make both A and B static, and we cannot do any better 251 | # without more knowledge of formulas. This suggests that an important 252 | # workstream to improve precision is to get dynamic-aware shape formulas 253 | # in Python for as many operators as possible. 254 | -------------------------------------------------------------------------------- /tracer_tensor.py: -------------------------------------------------------------------------------- 1 | from types import FunctionType 2 | 3 | import torch 4 | from base_tensor import BaseTensor 5 | from torch import Tensor 6 | from torch.fx import Graph, GraphModule, Tracer 7 | from torch.fx.passes.shape_prop import _extract_tensor_metadata 8 | from torch.testing._internal.common_utils import run_tests, TestCase 9 | from torch.utils._pytree import tree_map 10 | 11 | from utils import no_dispatch 12 | 13 | """ 14 | TracerTensor is a tensor that traces ATen operations that are performed on it 15 | and writes the resulting trace to FX IR.
We extracted this tracing 16 | implementation from Horace He's implementation for AOTAutograd 17 | (https://github.com/pytorch/functorch/blob/main/functorch/_src/python_key.py) 18 | to make it easier for you to see how it is put together. The basic 19 | implementation concept is simple: we run all tensor operations as normal, but 20 | on the side, we also duplicate the operations on FX Proxy objects, which are 21 | then responsible for writing in the results into FX IR. The top level tracing 22 | function dispatch_trace is a modified version of FX's `symbolic_trace` 23 | function: we always take a tuple of concrete Tensor inputs, and we generate 24 | placeholder proxies for all of them and attach them to TracerTensors which we 25 | actually feed into the model. 26 | 27 | Tracing with __torch_dispatch__ has some properties which are worth being 28 | aware of: 29 | 30 | - It is able to trace through autograd and other PyTorch subsystems (as they 31 | are all desugared into lower level calls by the time you get to 32 | `__torch_dispatch__`. Composite operations (CompositeImplicitAutograd) 33 | will be desugared by the time you get to trace. 34 | - It produces FX IR with `torch.ops.aten` nodes (e.g., you will get 35 | `torch.ops.aten.add.Tensor`, not merely `torch.add`) 36 | - Unlike FX, it is not able to trace non-Tensor symbolic values (e.g., 37 | sizes); these are all specialized to particular ints by the time 38 | `__torch_dispatch__` is called. Nick Korovaiko is working on removing this 39 | limitation. 40 | - In fact, you can think of it as a pure Python implementation of 41 | torch.jit.trace, except that it outputs FX IR rather than TorchScript IR. 42 | """ 43 | 44 | 45 | class TracerTensor(BaseTensor): 46 | # We support autograd-ing through the TracerTensor (which you 47 | # really can think of as a good old fashioned tensor that also 48 | # takes a proxy along for the ride). If you need to terminate 49 | # the autograd early, use torch.autograd.grad with explicit 50 | # inputs. 51 | @staticmethod 52 | def __new__(cls, elem, proxy): 53 | return super().__new__(cls, elem) 54 | 55 | def __init__(self, elem, proxy): 56 | # elem does not need to be recorded, because TracerTensor *is a* elem 57 | self.proxy = proxy 58 | # Since the proxy is associated with a concrete Tensor object, we also 59 | # know exactly what its tensor metadata should be, so populate it 60 | proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(self) 61 | 62 | @classmethod 63 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 64 | def unwrap_proxy(t): 65 | if isinstance(t, cls): 66 | return t.proxy 67 | else: 68 | return t 69 | 70 | def wrap(t, p): 71 | if isinstance(t, Tensor) and not isinstance(t, cls): 72 | return cls(t, p) 73 | else: 74 | assert t == p 75 | return t 76 | 77 | # Run the original computation 78 | r = super().__torch_dispatch__(func, types, args, kwargs) 79 | 80 | # Run the computation on FX proxies to record it into graph 81 | r_proxy = func(*tree_map(unwrap_proxy, args), **tree_map(unwrap_proxy, kwargs)) 82 | 83 | # NB: we cannot zip r and r_proxy, or rely on r_proxy knowing its 84 | # structure, because r_proxy as implemented in FX typically is a proxy 85 | # that will generate IR for accessing subfields as you invoke. So 86 | # r has to "drive" the deconstruction. 
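# (Concretely: when r is a tuple or list, the r_proxy[i] indexing below is
# what emits the corresponding getitem nodes into the FX graph.)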
87 | # NB: this assumes there aren't any recursive return structs, which 88 | # is generally a safe bet in the current codegen 89 | if isinstance(r, list): 90 | return [wrap(t, r_proxy[i]) for i, t in enumerate(r)] 91 | elif isinstance(r, tuple): 92 | return tuple(wrap(t, r_proxy[i]) for i, t in enumerate(r)) 93 | else: 94 | return wrap(r, r_proxy) 95 | 96 | 97 | class DispatchTracer(Tracer): 98 | # Our implementation here divergences a bit from Horace's. This version 99 | # modeled off of Trace.trace but we don't need to monkeypatch anything 100 | # because we will rely on __torch_dispatch__ to handle interposition. 101 | # Unlike standard FX, we don't want to trace leaf modules, we want to get 102 | # a graph of entirely torch.ops.aten operations 103 | # 104 | # Unlike FX, the semantics for concrete_args is a little different. 105 | # Typically, if you FX trace a function, you leave concrete_args None 106 | # (because you want most of the arguments to be symbolic). When we 107 | # dispatch trace a function, we want the arguments to be concrete because 108 | # they are going to advertise as honest to goodness tensors (if you want 109 | # to avoid actually doing the compute while tracing, you should pass in 110 | # meta tensors). 111 | def trace(self, root, concrete_args): 112 | # TODO: add torch.nn.Module support (??) 113 | assert not isinstance(root, torch.nn.Module) 114 | self.root = torch.nn.Module() 115 | fn = root 116 | 117 | tracer_cls = getattr(self, "__class__", None) 118 | self.graph = Graph(tracer_cls=tracer_cls) 119 | # Don't support module, so tensor_attrs is always empty 120 | self.tensor_attrs = {} 121 | assert isinstance(fn, FunctionType) 122 | 123 | # Reimplementation of create_args_for_root, but this is pretty 124 | # different as we always expect concrete arguments to be provided 125 | # and we still generate placeholders for each of them. 
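# Each concrete input below is wrapped in a TracerTensor carrying a fresh
# placeholder proxy (named arg_1, arg_2, ...).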
126 | cnt = 0 127 | 128 | def replace_tracer(arg): 129 | nonlocal cnt 130 | cnt += 1 131 | # TODO: add back argument name sniffing 132 | return TracerTensor( 133 | arg, self.create_proxy("placeholder", f"arg_{str(cnt)}", (), {}) 134 | ) 135 | 136 | # TODO: generalize to tree_map (but this will make verifier_tensor 137 | # harder to implement) 138 | args = [replace_tracer(a) for a in concrete_args] 139 | 140 | result = fn(*args) 141 | 142 | self.create_node( 143 | "output", 144 | "output", 145 | (self.create_arg(result.proxy),), 146 | {}, 147 | type_expr=fn.__annotations__.get("return", None), 148 | ) 149 | 150 | self.submodule_paths = None 151 | 152 | # Unlike regular Tracer.trace, we also return the result as it 153 | # contains useful data (the result of your computation) 154 | # TODO: better idiom for this 155 | with no_dispatch(): 156 | unwrapped_result = result.view(result.shape) 157 | return unwrapped_result, self.graph 158 | 159 | 160 | def dispatch_trace(root, concrete_args): 161 | tracer = DispatchTracer() 162 | result, graph = tracer.trace(root, concrete_args) 163 | name = root.__name__ 164 | return result, GraphModule(tracer.root, graph, name) 165 | 166 | 167 | class TracerTensorTest(TestCase): 168 | def test_basic(self): 169 | r, g = dispatch_trace(lambda x, y: x + y, (torch.ones(2), torch.ones(2))) 170 | self.assertNotIsInstance(r, TracerTensor) 171 | self.assertEqual(r, torch.tensor([2.0, 2.0])) 172 | self.assertExpectedInline( 173 | str(g.graph), 174 | """\ 175 | graph(): 176 | %arg_1 : [num_users=1] = placeholder[target=arg_1] 177 | %arg_2 : [num_users=1] = placeholder[target=arg_2] 178 | %add_tensor : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg_1, %arg_2), kwargs = {}) 179 | return add_tensor""", 180 | ) 181 | 182 | def test_constant(self): 183 | x = torch.ones(2) 184 | _, g = dispatch_trace(lambda y: x + y, (torch.ones(2),)) 185 | self.assertExpectedInline( 186 | str(g.graph), 187 | """\ 188 | graph(): 189 | %arg_1 : [num_users=1] = placeholder[target=arg_1] 190 | %_tensor_constant0 : [num_users=1] = get_attr[target=_tensor_constant0] 191 | %add_tensor : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_tensor_constant0, %arg_1), kwargs = {}) 192 | return add_tensor""", 193 | ) 194 | 195 | 196 | if __name__ == "__main__": 197 | run_tests() 198 | -------------------------------------------------------------------------------- /trivial_tensors.py: -------------------------------------------------------------------------------- 1 | import weakref 2 | 3 | import torch 4 | 5 | from base_tensor import BaseTensor 6 | from torch import Tensor 7 | from torch.testing._internal.common_utils import ( 8 | disable_gc, 9 | instantiate_parametrized_tests, 10 | parametrize, 11 | run_tests, 12 | TestCase, 13 | ) 14 | from torch.utils._pytree import tree_map 15 | 16 | # In a lot of use cases for tensor subclasses, there is a concept 17 | # of an "inner" tensor, which is a normal, non-subclassed tensor 18 | # that after you do your stuff you can redispatch to. This file gives recipes 19 | # for a number of trivial tensors; tensors which look and behave exactly like 20 | # their inner tensors, and propagate themselves through all invocations. As 21 | # it turns out, there are a number of different ways to do the same thing. 
22 | # However, the main axis of variation is this: 23 | # 24 | # Do you actually store the inner tensor (composition / has-a 25 | # relationship) or do you make what is effectively a super call 26 | # (inheritance / is-a relationship) 27 | # 28 | # These options have different tradeoffs which are discussed in more 29 | # detail below. Hopefully this file will give you an idea about some of the 30 | # tools in your toolbox. 31 | # 32 | # WARNING: These tensors inherit from BaseTensor, which is a local 33 | # compatibility shim that tracks changes to Tensor that we intend to make but 34 | # haven't made it to core. If you want to use these classes you will need to 35 | # include the extra bits from BaseTensor. 36 | # 37 | # TODO: Channeling Alban, we probably want to take some of these exemplars and 38 | # turn them into part of the official public API, so end users don't have to 39 | # copy paste them into their own functions. 40 | # 41 | # TODO: Redo these examples with compositionality in mind. 42 | 43 | 44 | class TrivialTensorViaInheritance(BaseTensor): 45 | """ 46 | TrivialTensorViaInheritance extends tensor behavior using inheritance ("is 47 | a"). These implementations are very straightforward and we recommend 48 | using them if it works for your use case. To get the base behavior, 49 | you use standard object-oriented idiom of super(). 50 | 51 | Benefits and downsides of this representation: 52 | 53 | + Efficient representation (only one tensor). 54 | + Do not have to worry about synchronizing metadata between the inner 55 | and outer tensor. 56 | = Requires multiple inheritance to do composition. This *does* 57 | work, but it is a bit mind-bending, you have to deal with the 58 | diamond inheritance problem, and traditionally you only get a fixed 59 | set of composition (rather than dynamic, as in functorch) unless 60 | you're willing to generate classes on the fly. 61 | - Doesn't work if you need to run internal PyTorch subsystems 62 | (e.g., autograd) multiple times. 63 | - Doesn't work if the internal tensor has a different shape 64 | than the outer tensor. 65 | - Doesn't work if you need multiple internal tensors. 66 | """ 67 | 68 | @classmethod 69 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 70 | def wrap(t): 71 | # When could the returned tensor already be our subclass? 72 | # The most common situation is when an input tensor 73 | # is returned as an output tensor, e.g., inplace or out 74 | # implementations. 75 | if isinstance(t, Tensor) and not isinstance(t, cls): 76 | return cls(t) 77 | else: 78 | return t 79 | 80 | return tree_map(wrap, super().__torch_dispatch__(func, types, args, kwargs)) 81 | 82 | 83 | class TrivialTensorViaComposition(BaseTensor): 84 | """ 85 | TrivialTensorViaComposition extends tesor behavior using composition ("has 86 | a"). You can see that unlike inheritance, we save the original tensor in 87 | a field in the tensor. These are often referred to as "wrapper tensors", 88 | as you are wrapping the original tensor. 89 | 90 | The nuance of wrapper tensors is that the external wrapper tensor is still 91 | required to have all of the metadata that the inner tensor has; this 92 | includes stride and storage! In this example, we assume the inner and 93 | outer metadata is exactly synchronized... so in fact the wrapper tensor is 94 | literally just a TrivialTensorViaInheritance (in particular, the outer 95 | wrapper points to the same storage as the inner wrapped tensor). 
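(The test_no_copy test at the bottom of this file checks exactly this:
the inner and outer tensors report the same data_ptr().)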
The only 96 | difference is that we've also recorded the original tensor as an element 97 | on the class as well. 98 | 99 | Benefits and downsides of this representation: 100 | 101 | + Many people find perform operations in the inner layer more 102 | intuitive (just unwrap the tensor) 103 | + In principle, is compositional with other tensor subclasses; in 104 | practice, compositionality in this way is hard to understand 105 | without more structure (e.g., functorch) 106 | + Allows you to use PyTorch's subsystems (e.g., autograd) multiple 107 | times (e.g., as done in functorch) 108 | + Metadata between the inside and outside can diverge (not shown in 109 | this example, TODO: add to zoo) 110 | - Representation requires two tensors (inner and outer); sometimes 111 | this is unnecessary 112 | - You must synchronize the metadata for the two tensors. Historically 113 | we had a number of incomplete/incorrect implementations of this; 114 | this file shows how to correctly (and easily). 115 | """ 116 | 117 | def __init__(self, elem): 118 | super().__init__(elem) 119 | self.elem = elem 120 | 121 | @classmethod 122 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 123 | def unwrap(t): 124 | if isinstance(t, cls): 125 | return t.elem 126 | else: 127 | return t 128 | 129 | def wrap(t): 130 | if isinstance(t, Tensor) and not isinstance(t, cls): 131 | return cls(t) 132 | else: 133 | return t 134 | 135 | return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) 136 | 137 | 138 | parametrize_trivial = parametrize( 139 | "TrivialTensor", 140 | [ 141 | TrivialTensorViaInheritance, 142 | TrivialTensorViaComposition, 143 | ], 144 | name_fn=lambda x: x.__name__, 145 | ) 146 | 147 | 148 | # We run our tests on both formulations of trivial tensors to show that 149 | # in the trivial case, they are exactly equivalent 150 | class TrivialTensorTest(TestCase): 151 | @parametrize_trivial 152 | def test_no_cycle(self, TrivialTensor): 153 | fins = [] 154 | with disable_gc(): 155 | r = TrivialTensor(torch.empty(2)) 156 | w = weakref.ref(r, lambda _: fins.append(1)) 157 | self.assertEqual(fins, []) 158 | del r 159 | self.assertEqual(fins, [1]) 160 | del w 161 | 162 | @parametrize_trivial 163 | def test_no_copy(self, TrivialTensor): 164 | inner = torch.empty(2) 165 | outer = TrivialTensor(inner) 166 | self.assertEqual(inner.data_ptr(), outer.data_ptr()) 167 | 168 | @parametrize_trivial 169 | def test_basic(self, TrivialTensor): 170 | # NB: this is not so basic, this executes a shit ton of 171 | # ops, including inplace ops 172 | self.assertEqual( 173 | (TrivialTensor(torch.tensor(1.0)) + TrivialTensor(torch.tensor(2.0))), 174 | TrivialTensor(torch.tensor(3.0)), 175 | ) 176 | 177 | 178 | instantiate_parametrized_tests(TrivialTensorTest) 179 | 180 | 181 | if __name__ == "__main__": 182 | run_tests() 183 | 184 | 185 | # Random commentary 186 | # Although this sounds trivial, it is nontrivial, both in terms 187 | # of behavior as well as implementation. 188 | # 189 | # - Behaviorally, trivial wrapper tensors are complicated because 190 | # they allow you to layer preexisting tensor features multiple 191 | # times (ala functorch) in a way that is impossible in normal 192 | # tensors. This is because there are two tensors involved: 193 | # the outer wrapper tensor, as well as the inner tensor. 
194 | # 195 | # - Implementation, trivial wrapper tensors are complicated because 196 | # the outer wrapper tensor must faithfully replicate all of the 197 | # properties (including storage and strides) of the inner tensor. 198 | # This is not so easy to do, and most existing wrapper tensor 199 | # implementations in the wild do not do this correctly, and 200 | # subsequently fail asserts in PyTorch autograd when running 201 | # PyTorch with DEBUG. 202 | # 203 | # This tensor could have been implemented in terms of Alban's 204 | # WrapperTensor, but I wanted to keep all of the implementation 205 | # in one place for easier modification, because as you will see, 206 | # doing this completely correctly is quite involved. 207 | # 208 | # We have an interesting problem for the constructor. What if you 209 | # pass in a view to the TrivialWrapperTensor? Do we accurately 210 | # represent the storage in this situation. If we accurately represent it, 211 | # then what if you call TrivialWrapperTensor on that view again; there 212 | # is no way to recover the new meta storage you had previously allocated. 213 | # If we don't accurately represent it, we're at risk of failing 214 | # autograd tests (but maybe this is OK if you don't expect to 215 | # autograd across the boundary). 216 | # 217 | # How to autograd through the constructor of TrivialWrapperTensor? 218 | # 219 | # Current idea: 220 | # - constructor is OK, even for views, but we'll construct a fresh 221 | # storage on entry each time. use_count 1 on storage is safest 222 | # but if you wrap the same tensor multiple times they are 223 | # disconnected 224 | # 225 | # Another idea for storage is to point to the SAME storage as the 226 | # tensor we're wrapping 227 | -------------------------------------------------------------------------------- /uint4_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch._prims_common as utils 3 | 4 | def down_size(size): 5 | assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" 6 | return (*size[:-1], size[-1] // 2) 7 | 8 | def up_size(size): 9 | return (*size[:-1], size[-1] * 2) 10 | 11 | def fill_defaults(args, n, defaults_tail): 12 | """ 13 | __torch_dispatch__ doesn't guarantee the number of arguments you are 14 | passed (e.g., defaulted arguments are not passed); but usually it is 15 | convenient to pad out the arguments list with defaults. This function 16 | helps you do that. 
17 | 18 | Args: 19 | args: the list of positional arguments passed to __torch_dispatch__ 20 | n: the number of arguments you are expecting to get 21 | defaults_tail: default values for the arguments, starting from the 22 | end of the list 23 | 24 | Example: 25 | 26 | >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) 27 | [1, 2, 3, 4, 5] 28 | >>> fill_defaults([1, 2, 3], 5, [None, None, None]) 29 | [1, 2, 3, None, None]] 30 | """ 31 | if n - len(defaults_tail) > len(args): 32 | raise RuntimeError("not enough defaults to fill arguments") 33 | r = list(args) 34 | for i in range(len(args), n): 35 | r.append(defaults_tail[i - n + len(defaults_tail)]) 36 | return r 37 | 38 | # from 39 | # https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233 40 | def unpack_uint4(quantized_data) -> torch.Tensor: 41 | """Get the original weight from the normalized float weight format""" 42 | # since we are using uint8 we will decode 2 entries per byte 43 | # Shift elements down 4 and select out the bottom 4 bits 44 | first_elements = (quantized_data >> 4).to(torch.uint8) 45 | second_elements = (quantized_data & 0b1111).to(torch.uint8) 46 | 47 | return torch.stack([first_elements, second_elements], dim=-1) 48 | 49 | class UInt4Tensor(torch.Tensor): 50 | @staticmethod 51 | def __new__(cls, elem): 52 | # TODO: uint64 here is wrong, need a real dtype. Don't try to(int64) 53 | # weird shit will happen 54 | assert elem.dtype is torch.uint8 55 | return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.int64) 56 | 57 | def __init__(self, elem): 58 | self.elem = elem 59 | 60 | def tolist(self): 61 | return self.to(torch.uint8).tolist() 62 | 63 | @classmethod 64 | def __torch_dispatch__(cls, func, types, args, kwargs=None): 65 | if func is torch.ops.aten.view.default: 66 | self, size = args 67 | size = utils.infer_size(size, self.numel()) 68 | assert not kwargs 69 | # WARNING: views not preserved 70 | return UInt4Tensor(self.elem.reshape(down_size(size))) 71 | elif func is torch.ops.aten._to_copy.default: 72 | self, = args 73 | if kwargs == {'dtype': torch.uint8}: 74 | return unpack_uint4(self.elem).view(self.shape) # no wrap 75 | else: 76 | raise NotImplementedError(f"_to_copy {kwargs}") 77 | elif func is torch.ops.aten.unbind.int: 78 | # This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to 79 | # create four tensors containing one element each. But we can't 80 | # do this with uint4 because such a tensor's size is not divisible 81 | # by bytes. 
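# (each single-element result would occupy only half of a packed uint8 byte)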
What I am going to do instead is promote to uint8 82 | # when this happens 83 | self, dim = fill_defaults(args, 2, [0]) 84 | if dim != self.dim() - 1: 85 | raise NotImplementedError(f"unbind dim={dim}") 86 | else: 87 | # We're unbinding the last dimension, need to promote 88 | return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind(dim) 89 | elif func is torch.ops.aten.select.int: 90 | self, dim, index = args 91 | if dim != self.dim() - 1: 92 | return UInt4Tensor(torch.ops.aten.select.int(self.elem, dim, index)) 93 | else: 94 | raise NotImplementedError(f"select dim={dim}") 95 | elif func is torch.ops.aten.slice.Tensor: 96 | self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) 97 | if dim == self.dim() - 1: 98 | # hard case 99 | if step != 1: 100 | raise NotImplementedError(f"slice step={step}") 101 | assert start % 2 == 0, start 102 | assert end >= self.shape[dim] or end % 2 == 0, end 103 | return UInt4Tensor(torch.ops.aten.slice.Tensor(self.elem, dim, start // 2, end // 2, 1)) 104 | else: 105 | # easy case 106 | return UInt4Tensor(torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step)) 107 | raise NotImplementedError(f"{func}") 108 | 109 | __torch_function__ = torch._C._disabled_torch_function_impl 110 | 111 | 112 | x = UInt4Tensor(torch.tensor([ 113 | [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], 114 | [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], 115 | [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], 116 | ], dtype=torch.uint8)) 117 | 118 | print(x.shape) # (3, 8) 119 | print(x.to(torch.uint8)) 120 | print(x) 121 | 122 | print(x[0:1, :]) 123 | print(x[:, 2:6]) 124 | -------------------------------------------------------------------------------- /use_cpu_for_rng.py: -------------------------------------------------------------------------------- 1 | from torch._dispatch.python import enable_python_dispatcher, no_python_dispatcher 2 | import torch 3 | 4 | # TODO: See https://github.com/pytorch/pytorch/issues/88109 for why 5 | # you have to use BackendSelect here and CUDA doesn't work 6 | @torch.ops.aten.randn.default.py_impl(torch._C.DispatchKey.BackendSelect) 7 | def randn(size, device=None, **kwargs): 8 | with no_python_dispatcher(): 9 | r = torch.ops.aten.randn.default(size, device='cpu', **kwargs) 10 | return r.to(device) 11 | 12 | # TODO: do the rest of the random functions 13 | 14 | # Hack to apply it globally 15 | ctx = enable_python_dispatcher() 16 | ctx.__enter__() 17 | 18 | if torch.cuda.is_available(): 19 | torch.manual_seed(0) 20 | x = torch.randn(10, device='cpu') 21 | torch.manual_seed(0) 22 | y = torch.ops.aten.randn.default([10], device='cuda') 23 | torch.testing.assert_close(x, y.cpu()) 24 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import Any 3 | 4 | import torch 5 | from torch.utils._pytree import PyTree, tree_flatten, tree_unflatten 6 | 7 | # Dumping ground for utilities that should eventual make their way into 8 | # PyTorch proper 9 | 10 | 11 | @contextlib.contextmanager 12 | def no_dispatch(): 13 | guard = torch._C._DisableTorchDispatch() 14 | try: 15 | yield 16 | finally: 17 | del guard 18 | 19 | 20 | def tree_map2(fn: Any, pytree1: PyTree, pytree2: PyTree) -> PyTree: 21 | flat_args1, spec1 = tree_flatten(pytree1) 22 | flat_args2, spec2 = tree_flatten(pytree2) 23 | assert spec1 == spec2 24 | return tree_unflatten([fn(i, j) for i, j in 
zip(flat_args1, flat_args2)], spec1) 25 | 26 | 27 | # IDK if this is actually useful or not 28 | def unmake_subclass(tensor): 29 | with no_dispatch(): 30 | return torch.Tensor._make_subclass(torch.Tensor, tensor) 31 | 32 | 33 | def fill_defaults(args, n, defaults_tail): 34 | """ 35 | __torch_dispatch__ doesn't guarantee the number of arguments you are 36 | passed (e.g., defaulted arguments are not passed); but usually it is 37 | convenient to pad out the arguments list with defaults. This function 38 | helps you do that. 39 | 40 | Args: 41 | args: the list of positional arguments passed to __torch_dispatch__ 42 | n: the number of arguments you are expecting to get 43 | defaults_tail: default values for the arguments, starting from the 44 | end of the list 45 | 46 | Example: 47 | 48 | >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) 49 | [1, 2, 3, 4, 5] 50 | >>> fill_defaults([1, 2, 3], 5, [None, None, None]) 51 | [1, 2, 3, None, None]] 52 | """ 53 | if n - len(defaults_tail) > len(args): 54 | raise RuntimeError("not enough defaults to fill arguments") 55 | r = list(args) 56 | for i in range(len(args), n): 57 | r.append(defaults_tail[i - n + len(defaults_tail)]) 58 | return r 59 | -------------------------------------------------------------------------------- /verifier_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from base_tensor import BaseTensor 3 | from torch.fx import Interpreter, Node 4 | from torch.testing._internal.common_utils import run_tests, TestCase 5 | 6 | from tracer_tensor import dispatch_trace 7 | 8 | # https://github.com/albanD/subclass_zoo/blob/33d7afe63c2a336e01eaf3e81fba085a68e3955f/bug_zoo.py#L18-L24 9 | 10 | # how to do speculate and validate 11 | # - need a function under trace (dispatch_trace) 12 | # - first time run with normal TracerTensor 13 | # - second time run with VerifierTensor 14 | # recovery is not necessary 15 | 16 | 17 | class Verifier: 18 | def __init__(self, interpreter, node): 19 | self.node = node 20 | # We aren't actually going to run the interpreter, it's just 21 | # here for fetch_attr 22 | self.interpreter = interpreter 23 | # TODO: IDK if there's a better way to do this 24 | self.constant_map = {} 25 | 26 | def advance(self): 27 | node = self.node 28 | self.node = node.next 29 | 30 | # Whenever constant nodes show up, FX will give these get_attr nodes. 31 | # When we're verifying torch dispatch calls this is not relevant, 32 | # but we do need to know about these so that we can appropriately 33 | # check if the user is reusing the correct constants. 
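# The loop below skips past get_attr nodes, recording the constant tensor
# each one resolves to (via interpreter.fetch_attr) so that constant_node()
# can look it up later.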
34 | while node.op == "get_attr": 35 | self.constant_map[self.interpreter.fetch_attr(node.target)] = node 36 | node = self.node 37 | self.node = node.next 38 | 39 | return node 40 | 41 | def constant_node(self, t): 42 | return self.constant_map[t] 43 | 44 | 45 | VERIFIER = None 46 | 47 | 48 | class VerifierTensor(BaseTensor): 49 | @staticmethod 50 | def __new__(cls, elem, node): 51 | return super().__new__(cls, elem) 52 | 53 | def __init__(self, elem, node): 54 | self.node = node 55 | 56 | @classmethod 57 | def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 58 | # Verify that this is correct 59 | node = VERIFIER.advance() 60 | assert node.op == "call_function", node.op 61 | assert node.target == func 62 | 63 | def translate(n, v): 64 | if isinstance(n, Node): 65 | if isinstance(v, VerifierTensor): 66 | assert n is v.node 67 | return v 68 | else: 69 | assert n is VERIFIER.constant_node(v) 70 | # Need to translate constants to meta so that 71 | # we satisfy device checks 72 | return v.to("meta") 73 | else: 74 | assert n == v 75 | return v 76 | 77 | meta_args = [] 78 | meta_kwargs = {} 79 | for i, n in enumerate(node.args): 80 | meta_args.append(translate(n, args[i])) 81 | for k, n in node.kwargs.items(): 82 | meta_kwargs[k] = translate(n, kwargs[k]) 83 | assert len(node.kwargs) == len(kwargs) 84 | 85 | r = super().__torch_dispatch__(func, types, tuple(meta_args), meta_kwargs) 86 | 87 | # For the multi-outputs need to advance verifier past the indexing 88 | # nodes 89 | if isinstance(r, list): 90 | raise NotImplementedError 91 | elif isinstance(r, tuple): 92 | raise NotImplementedError 93 | else: 94 | return VerifierTensor(r, node) 95 | 96 | 97 | class SpeculatingJit: 98 | def __init__(self, root): 99 | self.root = root 100 | self.graph = None 101 | self.interpreter = None 102 | 103 | def transform(self, graph): 104 | return graph 105 | 106 | def __call__(self, *args): 107 | if self.graph is None: 108 | r, self.graph = dispatch_trace(self.root, args) 109 | self.interpreter = Interpreter(self.transform(self.graph)) 110 | return r 111 | else: 112 | # assume the placeholder nodes are first 113 | # TODO: there is a problem with the verifier design here which 114 | # is that it is not possible to free constants that are captured 115 | # by the graph, which might be important for memory usage 116 | # if FX transformation did weight transformation. 
I think what 117 | # you want to do is stub out the tensors with meta "shadows" 118 | # that have a correspondence to getattr nodes but it is a little 119 | # fiddly to implement 120 | global VERIFIER 121 | VERIFIER = Verifier( 122 | Interpreter(self.graph), next(iter(self.graph.graph.nodes)) 123 | ) 124 | i = 0 125 | verifier_args = [] 126 | for a in args: 127 | n = VERIFIER.advance() 128 | assert n.op == "placeholder" 129 | verifier_args.append(VerifierTensor(a.to("meta"), n)) 130 | r = self.interpreter.run(*args) 131 | verifier_r = self.root(*verifier_args) 132 | VERIFIER = None 133 | assert r.shape == verifier_r.shape 134 | assert r.dtype == verifier_r.dtype 135 | return r 136 | 137 | 138 | class VerifierTensorTest(TestCase): 139 | def test_basic(self): 140 | def root(x, y): 141 | # TODO: x + y is annoying to debug because the exception gets 142 | # swallowed 143 | return torch.add(x, y) 144 | 145 | f = SpeculatingJit(root) 146 | r = f(torch.zeros(2), torch.zeros(2)) 147 | self.assertEqual(r, torch.zeros(2)) 148 | r2 = f(torch.ones(2), torch.zeros(2)) 149 | self.assertEqual(r2, torch.ones(2)) 150 | 151 | def test_constant(self): 152 | x = torch.zeros(2) 153 | 154 | def root(y): 155 | return torch.add(x, y) 156 | 157 | f = SpeculatingJit(root) 158 | r = f(torch.zeros(2)) 159 | self.assertEqual(r, torch.zeros(2)) 160 | r2 = f(torch.ones(2)) 161 | self.assertEqual(r2, torch.ones(2)) 162 | 163 | def test_validation_failure(self): 164 | i = 0 165 | 166 | def root(x, y): 167 | nonlocal i 168 | i += 1 169 | if i == 1: 170 | return torch.add(x, y) 171 | else: 172 | return torch.mul(x, y) 173 | 174 | f = SpeculatingJit(root) 175 | r = f(torch.zeros(2), torch.zeros(2)) 176 | self.assertEqual(r, torch.zeros(2)) 177 | self.assertRaises(AssertionError, lambda: f(torch.ones(2), torch.zeros(2))) 178 | 179 | 180 | if __name__ == "__main__": 181 | run_tests() 182 | --------------------------------------------------------------------------------