├── requirements.txt ├── doc ├── usage_to_frame.png ├── UsagePopTorch.ipynb └── Usage.ipynb ├── requirements-dev.txt ├── tensor_tracker ├── __init__.py └── core.py ├── .gitignore ├── setup.py ├── setup.cfg ├── .github └── workflows │ └── ci.yaml ├── LICENSE ├── README.md ├── tests └── test_tensor_tracker.py └── dev /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0 2 | -------------------------------------------------------------------------------- /doc/usage_to_frame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/pytorch-tensor-tracker/HEAD/doc/usage_to_frame.png -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | black 3 | flake8 4 | isort 5 | jupyter 6 | mypy 7 | pandas 8 | pandas-stubs 9 | pdoc3 10 | pylint 11 | pytest 12 | pytest-cov 13 | -------------------------------------------------------------------------------- /tensor_tracker/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | 3 | from .core import * # NOQA: F401 F403 4 | from .core import __all__, __doc__ # NOQA: F401 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .coverage 3 | .devcontainer.json 4 | .devcontainer.dockerfile 5 | .dockerignore 6 | .venv 7 | .vscode 8 | 9 | /build 10 | /dist 11 | /doc/tensor_tracker 12 | /local 13 | 14 | *.egg-info/ 15 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | from pathlib import Path 4 | 5 | import setuptools 6 | 7 | setuptools.setup(  # NOTE(review): requirements.txt below is relative to the CWD 8 | name="tensor-tracker", 9 | version="0.1", 10 | install_requires=Path("requirements.txt").read_text().rstrip("\n").split("\n"), 11 | ) 12 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [options] 2 | packages = 3 | tensor_tracker 4 | 5 | [mypy] 6 | pretty = true 7 | show_error_codes = true 8 | strict = true 9 | check_untyped_defs = true 10 | 11 | [mypy-setuptools.*] 12 | ignore_missing_imports = True 13 | 14 | [flake8] 15 | # See https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html 16 | max-line-length = 88 17 | extend-ignore = E203 18 | 19 | [tool:pytest] 20 | addopts = --no-cov-on-fail 21 | 22 | [coverage:report] 23 | # fail_under = 100 24 | skip_covered = true 25 | show_missing = true 26 | exclude_lines = 27 | pragma: no cover 28 | raise NotImplementedError 29 | assert False 30 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 |
name: CI 2 | 3 | on: 4 | push: { branches: [ "main" ] } 5 | pull_request: 6 | workflow_dispatch: 7 | 8 | concurrency: 9 | # Run everything on main, most-recent on PR builds 10 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | ci: 15 | runs-on: ubuntu-latest 16 | container: pytorch/pytorch 17 | timeout-minutes: 10 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: Install dependencies 21 | run: pip install -r requirements-dev.txt 22 | - name: Run CI 23 | run: ./dev ci 24 | - name: Publish documentation 25 | if: ${{github.ref == 'refs/heads/main'}} 26 | uses: Cecilapp/GitHub-Pages-deploy@v3 27 | env: { GITHUB_TOKEN: "${{ github.token }}" } 28 | with: 29 | build_dir: doc/tensor_tracker 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Graphcore Ltd. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensor tracker 2 | 3 | [API documentation](https://graphcore-research.github.io/pytorch-tensor-tracker/) | [Example](doc/Usage.ipynb) 4 | 5 | Flexibly track outputs and grad-outputs of `torch.nn.Module`. 6 | 7 | **Installation:** 8 | 9 | ```bash 10 | pip install git+https://github.com/graphcore-research/pytorch-tensor-tracker 11 | ``` 12 | 13 | **Usage:** 14 | 15 | Use `tensor_tracker.track(module)` as a context manager to start capturing tensors from within your module's forward and backward passes: 16 | 17 | ```python 18 | import tensor_tracker 19 | 20 | with tensor_tracker.track(module) as tracker: 21 | module(inputs).backward() 22 | 23 | print(tracker) # => Tracker(stashes=8, tracking=0) 24 | ``` 25 | 26 | Now `Tracker` is filled with stashes, containing copies of fwd/bwd tensors at (sub)module outputs. (Note, this can consume a lot of memory.) 27 | 28 | It behaves like a list of `Stash` objects, with their attached `value`, usually a tensor or tuple of tensors. We can also use `to_frame()` to get a Pandas table of summary statistics: 29 | 30 | ```python 31 | print(list(tracker)) 32 | # => [Stash(name="0.linear", type=nn.Linear, grad=False, value=tensor(...)), 33 | # ...] 34 | 35 | display(tracker.to_frame()) 36 | ``` 37 | 38 | tensor tracker to_frame output 39 | 40 | See the [documentation](https://graphcore-research.github.io/pytorch-tensor-tracker/) for more info, or for a more practical example, see our demo of [visualising transformer activations & gradients using UMAP](doc/Example.ipynb). 
To use on IPU with PopTorch, please see [Usage (PopTorch)](doc/UsagePopTorch.ipynb). 41 | 42 | 43 | ## License 44 | 45 | Copyright (c) 2023 Graphcore Ltd. Licensed under the MIT License ([LICENSE](LICENSE)). 46 | 47 | Our dependencies are (see [requirements.txt](requirements.txt)): 48 | 49 | | Component | About | License | 50 | | --- | --- | --- | 51 | | torch | Machine learning framework | BSD 3-Clause | 52 | 53 | We also use additional Python dependencies for development/testing (see [requirements-dev.txt](requirements-dev.txt)). 54 | -------------------------------------------------------------------------------- /doc/UsagePopTorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Copyright (c) 2023 Graphcore Ltd. All rights reserved.\n", 8 | "\n", 9 | "# Usage example (PopTorch)\n", 10 | "\n", 11 | "Create a toy model to track:" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 21 | "from torch import nn, Tensor\n", 22 | "\n", 23 | "class Model(nn.Module):\n", 24 | " def __init__(self):\n", 25 | " super().__init__()\n", 26 | " self.embed = nn.Embedding(10, 4)\n", 27 | " self.project = nn.Linear(4, 4)\n", 28 | " self.unembed = nn.Linear(4, 10)\n", 29 | "\n", 30 | " def forward(self, tokens: Tensor) -> Tensor:\n", 31 | " logits = self.unembed(self.project(self.embed(tokens)))\n", 32 | " return nn.functional.cross_entropy(logits, tokens)\n", 33 | "\n", 34 | "torch.manual_seed(100)\n", 35 | "module = Model()\n", 36 | "inputs = torch.randint(0, 10, (3,))" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "**PopTorch:**\n", 44 | "\n", 45 | "A few modifications to work with PopTorch:\n", 46 | " - Any tracking should be contained within `forward()`.\n", 47 | " - We shouldn't call 
`tensor.cpu()`, as this is implicit on returned tensors.\n", 48 | " - We don't have access to the backward pass." 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 8, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stderr", 58 | "output_type": "stream", 59 | "text": [ 60 | "[13:48:29.771] [poptorch:cpp] [warning] [DISPATCHER] Type coerced from Long to Int for tensor id 138\n", 61 | "Graph compilation: 100%|██████████| 100/100 [00:04<00:00]\n" 62 | ] 63 | }, 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "[Stash(name='embed', type=, grad=False, value=tensor([[ 0.4520, -0.1066, 1.1028, -1.1578],\n", 68 | " [-0.4866, -0.1484, -1.6819, 0.7740],\n", 69 | " [-1.0324, 0.2063, -0.7983, 0.4695]])),\n", 70 | " Stash(name='project', type=, grad=False, value=tensor([[ 1.2474, 0.4518, 0.2115, -0.6991],\n", 71 | " [-0.3698, -0.1035, -0.2358, -0.3482],\n", 72 | " [ 0.2165, 0.2673, -0.1278, -0.1348]])),\n", 73 | " Stash(name='unembed', type=, grad=False, value=tensor([[-0.2676, 0.0945, 0.4727, 0.0716, -0.1146, 0.2311, 0.4380, -0.1172,\n", 74 | " 0.6078, -0.0632],\n", 75 | " [ 0.2343, -0.0936, 0.1143, -0.0777, 0.0148, -0.0783, 0.2015, 0.1975,\n", 76 | " 0.2441, -0.3956],\n", 77 | " [ 0.1521, -0.0814, 0.2678, 0.0481, 0.1128, -0.0149, 0.3953, 0.2135,\n", 78 | " 0.3824, -0.2818]]))]" 79 | ] 80 | }, 81 | "metadata": {}, 82 | "output_type": "display_data" 83 | } 84 | ], 85 | "source": [ 86 | "from typing import Dict\n", 87 | "import poptorch\n", 88 | "import tensor_tracker\n", 89 | "\n", 90 | "class TrackingModel(Model):\n", 91 | " def forward(self, inputs: Tensor) -> Dict[str, Tensor]:\n", 92 | " with tensor_tracker.track(self, stash_value=lambda t: t) as tracker:\n", 93 | " loss = super().forward(inputs)\n", 94 | " return loss, [t.__dict__ for t in tracker]\n", 95 | "\n", 96 | "loss, tracked = poptorch.inferenceModel(TrackingModel())(inputs)\n", 97 | "tracked = [tensor_tracker.Stash(**d) for d in tracked]\n", 98 | "display(tracked)\n", 99 | "# => 
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from dataclasses import dataclass

import pandas as pd
import pytest
import torch
from torch import Tensor, nn

import tensor_tracker

pytestmark = pytest.mark.filterwarnings("ignore:.+backward hooks:UserWarning")


class Mul3(nn.Module):
    """Toy module: multiplies its input by 3."""

    def forward(self, x: Tensor) -> Tensor:
        return x * 3


@dataclass
class Output:
    """Non-tensor (dataclass) module output, to exercise `rmap_tensor`."""

    thing: Tensor
    status: str


class EgModule(nn.Module):
    """Toy model: sigmoid(3 * x), wrapped in an `Output` dataclass."""

    def __init__(self) -> None:
        super().__init__()
        self.m3 = Mul3()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: Tensor) -> Output:
        return Output(thing=self.sigmoid(self.m3(x=x)), status="ok")


def test_basic() -> None:
    module = EgModule()
    with tensor_tracker.track(module) as tracker:
        assert str(tracker) == "Tracker(stashes=0, tracking=6)"
        x = torch.full((8,), 0.7, requires_grad=True)
        out = module(x)
        grad = torch.full_like(out.thing, 1000)
        out.thing.backward(grad)

    # Tracked stashes
    sigmoid_grad = 1000 * torch.sigmoid(3 * x) * (1 - torch.sigmoid(3 * x))
    expected_stash = [
        tensor_tracker.Stash("m3", Mul3, False, 3 * x),
        tensor_tracker.Stash("sigmoid", nn.Sigmoid, False, torch.sigmoid(3 * x)),
        tensor_tracker.Stash("", EgModule, False, Output(torch.sigmoid(3 * x), "ok")),
        # tensor_tracker.Stash("", EgModule, True, grad),  # can't track dataclass grad
        tensor_tracker.Stash("sigmoid", nn.Sigmoid, True, grad),
        tensor_tracker.Stash("m3", Mul3, True, sigmoid_grad),
    ]
    assert len(tracker) == len(expected_stash)
    assert str(tracker) == f"Tracker(stashes={len(expected_stash)}, tracking=0)"
    for i, expected in enumerate(expected_stash):
        assert tracker[i].name == expected.name
        assert tracker[i].type == expected.type
        assert tracker[i].grad == expected.grad
        if isinstance(expected.value, Output):
            assert isinstance(tracker[i].value, Output)
            assert torch.equal(tracker[i].value.thing, expected.value.thing)
            assert tracker[i].value.status == expected.value.status
        else:
            assert torch.equal(tracker[i].first_value, expected.value), expected

    # Pandas output (previously a bare comparison that was never asserted)
    df = tracker.to_frame()
    assert isinstance(df, pd.DataFrame)
    assert len(df) == len(expected_stash)
    assert list(df.columns) == ["name", "type", "grad", "std"]
    first = df.iloc[0]
    assert first["name"] == "m3"
    # Derive the module path so this holds under any pytest import mode
    assert first["type"] == f"{Mul3.__module__}.{Mul3.__name__}"
    assert not first["grad"]
    # x is constant, so m3's output is constant and its std is zero
    assert first["std"] == pytest.approx(0.0, abs=1e-6)

    # Reset
    tracker.clear()
    assert len(tracker) == 0


def test_custom_stash_value() -> None:
    torch.manual_seed(100)
    module = EgModule()
    with tensor_tracker.track(
        module, stash_value=lambda t: t.std().cpu().detach()
    ) as tracker:
        x = torch.randn(1000)
        module(x)

    assert tracker[0].name == "m3"
    assert tracker[0].value == (3 * x).std()
    assert all(
        s.first_value.ndim == 0 for s in tracker if isinstance(s.first_value, Tensor)
    )


def test_custom_stash() -> None:
    def custom_stash(event: tensor_tracker.Event) -> tensor_tracker.Stash:
        # Default stash, plus copies of the forward inputs as extra attributes
        stash = tensor_tracker.Stash(
            event.name,
            event.type,
            event.grad,
            tensor_tracker.rmap_tensor(event.value, tensor_tracker.default_stash_value),
        )
        args = tensor_tracker.rmap_tensor(
            event.args, tensor_tracker.default_stash_value
        )
        setattr(stash, "args", args)
        kwargs = tensor_tracker.rmap_tensor(
            event.kwargs, tensor_tracker.default_stash_value
        )
        setattr(stash, "kwargs", kwargs)
        return stash

    module = EgModule()
    with tensor_tracker.track(module, stash=custom_stash) as tracker:
        x = torch.ones(1000)
        module(x)

    assert tracker[0].name == "m3"
    assert torch.equal(tracker[0].value, 3 * x)
    assert torch.equal(tracker[0].kwargs["x"], x)  # type:ignore[attr-defined]
    assert torch.equal(tracker[1].args[0], 3 * x)  # type:ignore[attr-defined]

    # Cannot specify both stash=? and stash_value=?
    with pytest.raises(ValueError), tensor_tracker.track(
        module, stash=custom_stash, stash_value=lambda t: t
    ):
        pass
3 | 4 | """Dev task launcher.""" 5 | 6 | import argparse 7 | import datetime 8 | import os 9 | import subprocess 10 | import sys 11 | from pathlib import Path 12 | from typing import Any, Callable, Iterable, List, Optional, TypeVar 13 | 14 | # Utilities 15 | 16 | 17 | def run(command: Iterable[Any], gdb: bool = False) -> None: 18 | """Run a command, terminating on failure.""" 19 | cmd = [str(arg) for arg in command if arg is not None] 20 | if gdb: 21 | cmd = ["gdb", "-ex", "catch throw", "-ex", "run", "--args"] + cmd 22 | print("$ " + " ".join(cmd), file=sys.stderr) 23 | environ = os.environ.copy() 24 | environ["PYTHONPATH"] = f"{os.getcwd()}:{environ.get('PYTHONPATH', '')}" 25 | exit_code = subprocess.call(cmd, env=environ) 26 | if exit_code: 27 | sys.exit(exit_code) 28 | 29 | 30 | T = TypeVar("T") 31 | 32 | 33 | def cli(*args: Any, **kwargs: Any) -> Callable[[T], T]: 34 | """Declare a CLI command / arguments for that command.""" 35 | 36 | def wrap(func: T) -> T: 37 | if not hasattr(func, "cli_args"): 38 | setattr(func, "cli_args", []) 39 | if args or kwargs: 40 | getattr(func, "cli_args").append((args, kwargs)) 41 | return func 42 | 43 | return wrap 44 | 45 | 46 | # Commands 47 | 48 | PYTHON_ROOTS = ["tensor_tracker", "tests", "doc", "dev", "setup.py"] 49 | 50 | 51 | @cli("-s", "--no-capture", action="store_false", dest="capture") 52 | @cli("-k", "--filter") 53 | def tests(capture: bool, filter: Optional[str]) -> None: 54 | """run Python tests""" 55 | run( 56 | [ 57 | "python3", 58 | "-m", 59 | "pytest", 60 | "tests", 61 | None if filter else "--cov=tensor_tracker", 62 | *(["-k", filter] if filter else []), 63 | None if capture else "-s", 64 | ], 65 | ) 66 | 67 | 68 | @cli() 69 | def lint() -> None: 70 | """run static analysis""" 71 | run(["python3", "-m", "flake8", *PYTHON_ROOTS]) 72 | run(["python3", "-m", "mypy", *(r for r in PYTHON_ROOTS if r != "doc")]) 73 | 74 | 75 | @cli("--check", action="store_true") 76 | def format(check: bool) -> None: 77 | 
"""autoformat all sources""" 78 | run(["python3", "-m", "black", "--check" if check else None, *PYTHON_ROOTS]) 79 | run(["python3", "-m", "isort", "--check" if check else None, *PYTHON_ROOTS]) 80 | 81 | 82 | @cli() 83 | def copyright() -> None: 84 | """check for Graphcore copyright headers on relevant files""" 85 | command = ( 86 | "find " + " ".join(PYTHON_ROOTS) + " -type f -not -name *.pyc -not -name *.png" 87 | " | xargs grep -L 'Copyright (c) 202. Graphcore Ltd[.] All rights reserved[.]'" 88 | ) 89 | print(f"$ {command}", file=sys.stderr) 90 | # Note: grep exit codes are not consistent between versions, so we don't use 91 | # check=True 92 | output = ( 93 | subprocess.run( 94 | command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT 95 | ) 96 | .stdout.decode() 97 | .strip() 98 | ) 99 | if output: 100 | print( 101 | "Error - failed copyright header check in:\n " 102 | + output.replace("\n", "\n "), 103 | file=sys.stderr, 104 | ) 105 | print("Template(s):") 106 | comment_prefixes = { 107 | {".cpp": "//"}.get(Path(f).suffix, "#") for f in output.split("\n") 108 | } 109 | for prefix in comment_prefixes: 110 | print( 111 | f"{prefix} Copyright (c) {datetime.datetime.now().year}" 112 | " Graphcore Ltd. 
All rights reserved.", 113 | file=sys.stderr, 114 | ) 115 | sys.exit(1) 116 | 117 | 118 | @cli() 119 | def doc() -> None: 120 | """generate API documentation""" 121 | subprocess.call(["rm", "-r", "doc/tensor_tracker"]) 122 | run( 123 | [ 124 | "python3", 125 | "-m", 126 | "pdoc", 127 | "--html", 128 | "--output-dir", 129 | "doc", 130 | "tensor_tracker", 131 | ] 132 | ) 133 | for notebook in ["Example", "Usage"]: 134 | run( 135 | [ 136 | "jupyter", 137 | "nbconvert", 138 | "--to", 139 | "html", 140 | f"doc/{notebook}.ipynb", 141 | "--output-dir", 142 | "doc/tensor_tracker", 143 | "--output", 144 | f"{notebook.lower()}.html", 145 | ] 146 | ) 147 | 148 | 149 | @cli("--skip", nargs="*", default=[], help="commands to skip") 150 | def ci(skip: List[str] = []) -> None: 151 | """run all continuous integration tests & checks""" 152 | if "tests" not in skip: 153 | tests(capture=True, filter=None) 154 | if "lint" not in skip: 155 | lint() 156 | if "format" not in skip: 157 | format(check=True) 158 | if "copyright" not in skip: 159 | copyright() 160 | if "doc" not in skip: 161 | doc() 162 | 163 | 164 | # Script 165 | 166 | 167 | def _main() -> None: 168 | parser = argparse.ArgumentParser(description=__doc__) 169 | parser.set_defaults(action=ci) 170 | 171 | subs = parser.add_subparsers() 172 | for key, value in globals().items(): 173 | if hasattr(value, "cli_args"): 174 | sub = subs.add_parser(key.replace("_", "-"), help=value.__doc__) 175 | for args, kwargs in value.cli_args: 176 | sub.add_argument(*args, **kwargs) 177 | sub.set_defaults(action=value) 178 | 179 | cli_args = vars(parser.parse_args()) 180 | action = cli_args.pop("action") 181 | action(**cli_args) 182 | 183 | 184 | if __name__ == "__main__": 185 | _main() 186 | -------------------------------------------------------------------------------- /tensor_tracker/core.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
"""Utility for tracking activations and gradients at `nn.Module` outputs.

Use `track` to start tracking a module & submodules. Then use the original module
as usual. Your `Tracker` will be filled with a list of `Stash`es, containing
copies of fwd/bwd tensors at (sub)module outputs. (Beware, this can consume
a lot of memory.)

Usage ([notebook](usage.html)):

```
with tensor_tracker.track(model) as tracker:
    model(inputs).backward()

print(list(tracker))
# => [Stash(name="0.linear", type=nn.Linear, grad=False, value=tensor(...)),
#     ...]

display(tracker.to_frame())  # requires 'pandas'
```

Advanced usage:

 - Filter modules based on name:
   `track(include="<regex>", exclude="<regex>")`

 - Pre-transform tracked tensors to save memory:
   `track(stash_value=lambda t: t.std().detach().cpu())`

 - Customise tracked state:
   `track(stash=lambda event: ...)`

 - Manually register/unregister hooks:
   `tracker = Tracker(); tracker.register(...); tracker.unregister()`

See also: [example of
visualising transformer activations & gradients using UMAP](example.html).
"""

import dataclasses
import re
from dataclasses import dataclass
from functools import partial
from types import TracebackType
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Pattern,
    Tuple,
    Type,
    Union,
)

import torch.utils.hooks
from torch import Tensor, nn


@dataclass
class Event:
    """A raw hook event, passed to a `StashFn` to build a `Stash`."""

    # Module name given at registration (from `named_modules()` via `track`)
    name: str
    # Class of the module that fired the hook
    type: Type[nn.Module]
    # False for forward-pass outputs, True for backward-pass grad outputs
    grad: bool
    # Module output (forward) or grad_output (backward)
    value: Any
    # Positional inputs to forward(); empty for backward events
    args: Tuple[Any, ...]
    # Keyword inputs to forward(); empty for backward events
    kwargs: Dict[str, Any]


@dataclass
class Stash:
    """A value captured from one module call in the forward or backward pass."""

    name: str
    type: Type[nn.Module]
    grad: bool
    value: Any  # output(s) or grad_output(s)

    @property
    def first_value(self) -> Any:
        """`value`, descending into the first element of nested tuples/lists.

        Empty tuples/lists are returned as-is.
        """

        def _value(v: Any) -> Any:
            if isinstance(v, (tuple, list)) and len(v) >= 1:
                return _value(v[0])
            return v

        return _value(self.value)


StashFn = Callable[[Event], Stash]
StashValueFn = Callable[[Tensor], Any]


def rmap_tensor(value: Any, fn: Callable[[Tensor], Any]) -> Any:
    """Recursively apply `fn` to every tensor within `value`.

    Descends into tuples (including namedtuples), lists, dicts (both keys and
    values) and dataclass instances, rebuilding each container; any other value
    is returned unchanged.
    """
    if isinstance(value, tuple) and hasattr(type(value), "_make"):
        # namedtuple: its constructor takes fields positionally, not an iterable
        return type(value)._make(rmap_tensor(a, fn) for a in value)
    if isinstance(value, (tuple, list)):
        return type(value)(rmap_tensor(a, fn) for a in value)
    if isinstance(value, dict):
        return {rmap_tensor(k, fn): rmap_tensor(a, fn) for k, a in value.items()}
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        # `is_dataclass` is also true for dataclass *classes*, which are left
        # alone. Only init=True fields are passed, as the constructor requires.
        return type(value)(
            **{
                f.name: rmap_tensor(getattr(value, f.name), fn)
                for f in dataclasses.fields(value)
                if f.init
            }
        )
    if isinstance(value, Tensor):
        return fn(value)
    return value


def default_stash_value(tensor: Tensor) -> Tensor:
    """Return a detached CPU copy of `tensor`."""
    return tensor.detach().cpu().clone()


def default_stash(event: Event, stash_value: StashValueFn) -> Stash:
    """Build a `Stash` from `event`, mapping `stash_value` over its tensors."""
    return Stash(
        event.name, event.type, event.grad, rmap_tensor(event.value, stash_value)
    )


def get_stash_fn(
    stash_value: Optional[StashValueFn] = None, stash: Optional[StashFn] = None
) -> StashFn:
    """Resolve the `StashFn` to use from at most one of `stash_value`/`stash`.

    Raises:
        ValueError: if both are provided.
    """
    if stash_value is not None and stash is not None:
        raise ValueError("Cannot provide StashValueFn and StashFn to get_stash_fn()")
    if stash is not None:
        return stash
    return partial(default_stash, stash_value=stash_value or default_stash_value)


NamePattern = Union[None, Pattern[str], str]


class Tracker:
    """Collects a `Stash` from each hooked module call.

    Behaves like a read-only list of `Stash`. When used as a context manager,
    all hooks are removed on exit (stashes are kept).
    """

    def __init__(self, stash: Optional[StashFn] = None) -> None:
        self.stashes: List[Stash] = []
        self._handles: List[torch.utils.hooks.RemovableHandle] = []
        # Default matches `track()` called without `stash`/`stash_value`
        self._stash = stash if stash is not None else get_stash_fn()

    # Registration/tracking

    def __enter__(self) -> "Tracker":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        self.unregister()

    def clear(self) -> None:
        """Discard all collected stashes (hooks stay registered)."""
        self.stashes.clear()

    def register(self, module: nn.Module, name: str = "", grad: bool = True) -> None:
        """Hook `module` itself (not its children), labelling stashes `name`.

        Adds a forward hook (receiving kwargs) and, if `grad`, a full backward
        pre-hook.
        """
        self._handles.append(
            module.register_forward_hook(
                partial(self._forward_hook, name=name), with_kwargs=True
            )
        )
        if grad:
            self._handles.append(
                module.register_full_backward_pre_hook(
                    partial(self._backward_hook, name=name)
                )
            )

    def register_all(
        self,
        module: nn.Module,
        grad: bool = True,
        include: NamePattern = None,
        exclude: NamePattern = None,
    ) -> None:
        """Hook `module` and its submodules, filtered by name.

        A (sub)module is hooked if its `named_modules()` name matches `include`
        (regex search; everything if None) and does not match `exclude`.
        """
        include = re.compile(include) if isinstance(include, str) else include
        exclude = re.compile(exclude) if isinstance(exclude, str) else exclude
        for name, child in module.named_modules():
            if ((not include) or include.search(name)) and not (
                exclude and exclude.search(name)
            ):
                self.register(child, name, grad=grad)

    def unregister(self) -> None:
        """Remove every hook added by this tracker; stashes are kept."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def _forward_hook(
        self,
        module: nn.Module,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        output: Any,
        *,
        name: str,
    ) -> None:
        self.stashes.append(
            self._stash(Event(name, type(module), False, output, args, kwargs))
        )

    def _backward_hook(self, module: nn.Module, grad_output: Any, *, name: str) -> None:
        # Returns None, so grad_output is passed through unmodified
        self.stashes.append(
            self._stash(Event(name, type(module), True, grad_output, (), {}))
        )

    # Read results

    def __str__(self) -> str:
        return f"Tracker(stashes={len(self)}, tracking={len(self._handles)})"

    def __iter__(self) -> Iterator[Stash]:
        return iter(self.stashes)

    def __getitem__(self, index: int) -> Stash:
        return self.stashes[index]

    def __len__(self) -> int:
        return len(self.stashes)

    def to_frame(
        self,
        stat: Callable[[Tensor], Tensor] = torch.std,
        stat_name: Optional[str] = None,
    ) -> "pandas.DataFrame":  # type:ignore[name-defined] # NOQA: F821
        """Summarise stashes as a table (requires pandas).

        One row per stash: its attributes except `value`, with `type` as a
        qualified class name, plus a column (named `stat_name`, else
        `stat.__name__`) holding `stat(first_value).item()`, or None when
        `first_value` is not a tensor.
        """
        import pandas

        column_name = (
            getattr(stat, "__name__", "value") if stat_name is None else stat_name
        )

        def to_item(stash: Stash) -> Dict[str, Any]:
            d = stash.__dict__.copy()
            d.pop("value")
            v = stash.first_value
            d[column_name] = stat(v).item() if isinstance(v, Tensor) else None
            d["type"] = f"{stash.type.__module__}.{stash.type.__name__}"
            return d

        return pandas.DataFrame([to_item(s) for s in self.stashes])


def track(
    module: nn.Module,
    grad: bool = True,
    include: NamePattern = None,
    exclude: NamePattern = None,
    stash_value: Optional[StashValueFn] = None,
    stash: Optional[StashFn] = None,
) -> Tracker:
    # Docstring is the module docstring (assigned below)
    tracker = Tracker(get_stash_fn(stash_value=stash_value, stash=stash))
    tracker.register_all(module, grad=grad, include=include, exclude=exclude)
    return tracker


track.__doc__ = __doc__

__all__ = [
    "Event",
    "Stash",
    "StashFn",
    "StashValueFn",
    "rmap_tensor",
    "default_stash_value",
    "default_stash",
    "get_stash_fn",
    "Tracker",
    "track",
]
All rights reserved.\n", 8 | "\n", 9 | "# Usage example\n", 10 | "\n", 11 | "Create a toy model to track:" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 21 | "from torch import nn, Tensor\n", 22 | "\n", 23 | "class Model(nn.Module):\n", 24 | " def __init__(self):\n", 25 | " super().__init__()\n", 26 | " self.embed = nn.Embedding(10, 4)\n", 27 | " self.project = nn.Linear(4, 4)\n", 28 | " self.unembed = nn.Linear(4, 10)\n", 29 | "\n", 30 | " def forward(self, tokens: Tensor) -> Tensor:\n", 31 | " logits = self.unembed(self.project(self.embed(tokens)))\n", 32 | " return nn.functional.cross_entropy(logits, tokens)\n", 33 | "\n", 34 | "torch.manual_seed(100)\n", 35 | "module = Model()\n", 36 | "inputs = torch.randint(0, 10, (3,))" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "Use `tensor_tracker` to capture forward pass activations and backward pass gradients from our toy model. By default, the tracker saves full tensors, as a list of `tensor_tracker.Stash` objects." 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "Tracker(stashes=8, tracking=0)\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "import tensor_tracker\n", 61 | "\n", 62 | "with tensor_tracker.track(module) as tracker:\n", 63 | " module(inputs).backward()\n", 64 | "\n", 65 | "print(tracker)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "Note that calls are only tracked within the `with` context. Then, the tracker behaves like a list of `Stash` objects, with attached `name`, `value` etc." 
73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "data": { 82 | "text/plain": [ 83 | "[Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=False, value=tensor([[ 0.4698, 1.2426, 0.5403, -1.1454],\n", 84 | " [-0.8425, -0.6475, -0.2189, -1.1326],\n", 85 | " [ 0.1268, 1.3564, 0.5632, -0.1039]])),\n", 86 | " Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.6237, -0.1652, 0.3782, -0.8841],\n", 87 | " [-0.9278, -0.2848, -0.8688, -0.4719],\n", 88 | " [-0.3449, 0.3643, 0.3935, -0.6302]])),\n", 89 | " Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=False, value=tensor([[-0.2458, 1.0003, -0.8231, -0.1405, -0.2964, 0.5837, 0.2889, 0.2059,\n", 90 | " -0.6114, -0.5916],\n", 91 | " [-0.6345, 1.0882, -0.4304, -0.2196, -0.0426, 0.9428, 0.2051, 0.5897,\n", 92 | " -0.2217, -0.9132],\n", 93 | " [-0.0822, 0.9985, -0.7097, -0.3139, -0.4805, 0.6878, 0.2560, 0.3254,\n", 94 | " -0.4447, -0.3332]])),\n", 95 | " Stash(name='', type=<class '__main__.Model'>, grad=False, value=tensor(2.5663)),\n", 96 | " Stash(name='', type=<class '__main__.Model'>, grad=True, value=(tensor(1.),)),\n", 97 | " Stash(name='unembed', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[ 0.0237, 0.0824, -0.3200, 0.0263, 0.0225, 0.0543, 0.0404, 0.0372,\n", 98 | " 0.0164, 0.0168],\n", 99 | " [ 0.0139, 0.0779, 0.0171, 0.0211, 0.0251, 0.0673, 0.0322, -0.2860,\n", 100 | " 0.0210, 0.0105],\n", 101 | " [-0.3066, 0.0787, 0.0143, 0.0212, 0.0179, 0.0577, 0.0374, 0.0401,\n", 102 | " 0.0186, 0.0208]]),)),\n", 103 | " Stash(name='project', type=<class 'torch.nn.modules.linear.Linear'>, grad=True, value=(tensor([[-0.1755, 0.1306, 0.0443, -0.1823],\n", 104 | " [ 0.1202, -0.0728, 0.0066, -0.0839],\n", 105 | " [-0.1863, 0.0470, -0.1055, -0.0353]]),)),\n", 106 | " Stash(name='embed', type=<class 'torch.nn.modules.sparse.Embedding'>, grad=True, value=(tensor([[-0.0108, 0.1086, -0.1304, -0.0370],\n", 107 | " [ 0.0534, -0.0029, 0.0078, -0.0074],\n", 108 | " [-0.0829, 0.0152, -0.1170, -0.0625]]),))]" 109 | ] 110 | }, 111 | "metadata": {}, 112 | "output_type": "display_data" 113 | } 114 | ], 115 | "source": [ 116 |
"display(list(tracker))\n", 117 | "# => [Stash(name=\"embed\", type=nn.Embedding, grad=False, value=tensor(...)),\n", 118 | "# ...]" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "As a higher-level API, `to_frame` computes summary statistics, defaulting to `torch.std`." 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 4, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "data": { 135 | "text/html": [ 136 | "
\n", 137 | "\n", 150 | "\n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | "
nametypegradstd
0embedtorch.nn.modules.sparse.EmbeddingFalse0.853265
1projecttorch.nn.modules.linear.LinearFalse0.494231
2unembedtorch.nn.modules.linear.LinearFalse0.581503
3__main__.ModelFalseNaN
4__main__.ModelTrueNaN
5unembedtorch.nn.modules.linear.LinearTrue0.105266
6projecttorch.nn.modules.linear.LinearTrue0.112392
7embedtorch.nn.modules.sparse.EmbeddingTrue0.068816
\n", 219 | "
" 220 | ], 221 | "text/plain": [ 222 | " name type grad std\n", 223 | "0 embed torch.nn.modules.sparse.Embedding False 0.853265\n", 224 | "1 project torch.nn.modules.linear.Linear False 0.494231\n", 225 | "2 unembed torch.nn.modules.linear.Linear False 0.581503\n", 226 | "3 __main__.Model False NaN\n", 227 | "4 __main__.Model True NaN\n", 228 | "5 unembed torch.nn.modules.linear.Linear True 0.105266\n", 229 | "6 project torch.nn.modules.linear.Linear True 0.112392\n", 230 | "7 embed torch.nn.modules.sparse.Embedding True 0.068816" 231 | ] 232 | }, 233 | "metadata": {}, 234 | "output_type": "display_data" 235 | } 236 | ], 237 | "source": [ 238 | "display(tracker.to_frame())" 239 | ] 240 | } 241 | ], 242 | "metadata": { 243 | "kernelspec": { 244 | "display_name": "Python 3", 245 | "language": "python", 246 | "name": "python3" 247 | }, 248 | "language_info": { 249 | "codemirror_mode": { 250 | "name": "ipython", 251 | "version": 3 252 | }, 253 | "file_extension": ".py", 254 | "mimetype": "text/x-python", 255 | "name": "python", 256 | "nbconvert_exporter": "python", 257 | "pygments_lexer": "ipython3", 258 | "version": "3.10.12" 259 | }, 260 | "orig_nbformat": 4 261 | }, 262 | "nbformat": 4, 263 | "nbformat_minor": 2 264 | } 265 | --------------------------------------------------------------------------------