├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── examples.ipynb
├── setup.py
├── test
│   └── test.py
└── torchviz
    ├── __init__.py
    └── dot.py

/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
name: CI for torchviz

on:
  push:
    branches:
      - master
  pull_request:
    branches:
      - master

jobs:
  test:
    name: Test on Python ${{ matrix.python-version }}
    runs-on: ubuntu-latest

    strategy:
      matrix:
        python-version: [3.8, 3.9, '3.10', 3.11, 3.12]

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y graphviz

      - name: Install Python dependencies
        run: |
          python -m pip install --upgrade pip
          pip install torch --index-url https://download.pytorch.org/whl/cpu
          pip install .
          pip install pytest

      - name: Run tests
        run: |
          pytest test/test.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
#### joe made this: http://goel.io/joe

#####=== Python ===#####

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

#####=== JetBrains ===#####
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Rider-specific rules
*.sln.iml

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Sergey Zagoruyko

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
.PHONY: dist
dist:
	python setup.py sdist

.PHONY: upload
upload: dist
	twine upload dist/*

.PHONY: clean
clean:
	@rm -rf build dist *.egg-info

.PHONY: test
test:
	PYTHONPATH=. python test/test.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
PyTorchViz
=======

A small package to create visualizations of PyTorch execution graphs and traces.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/szagoruyko/pytorchviz/blob/master/examples.ipynb)

## Installation

Install graphviz, e.g.:

```
brew install graphviz
```

On Debian/Ubuntu, `sudo apt-get install graphviz` works as well (this is what the CI workflow uses).

Install the package itself:

```
pip install torchviz
```


## Usage
Example usage of `make_dot`:
```
import torch
from torch import nn
from torchviz import make_dot

model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))

x = torch.randn(1, 8)
y = model(x)

make_dot(y.mean(), params=dict(model.named_parameters()))
```
![image](https://user-images.githubusercontent.com/13428986/110844921-ff3f7500-8277-11eb-912e-3ba03623fdf5.png)

Set `show_attrs=True` and `show_saved=True` to see what autograd saves for the backward pass. (Note that this is only available for PyTorch >= 1.9.)
```
model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))

x = torch.randn(1, 8)
y = model(x)

make_dot(y.mean(), params=dict(model.named_parameters()), show_attrs=True, show_saved=True)
```
![image](https://user-images.githubusercontent.com/13428986/110845186-4ded0f00-8278-11eb-88d2-cc33413bb261.png)
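
`make_dot` returns a regular `graphviz.Digraph`, which displays inline in a Jupyter notebook (see `examples.ipynb`). To write the graph to disk from a plain script you can use graphviz's own API; a minimal sketch, where the output name `model_graph` is just an example:

```
dot = make_dot(y.mean(), params=dict(model.named_parameters()))
dot.format = 'png'         # or any other format graphviz supports, e.g. 'svg' or 'pdf'
dot.render('model_graph')  # writes the DOT source to 'model_graph' and the image to 'model_graph.png'
```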

## Acknowledgements

The script was moved from [functional-zoo](https://github.com/szagoruyko/functional-zoo), where it was created with the help of Adam Paszke, Soumith Chintala and Anton Osokin, and it uses bits from [tensorboard-pytorch](https://github.com/lanpa/tensorboard-pytorch).
Other contributors are [@willprice](https://github.com/willprice), [@soulitzer](https://github.com/soulitzer), and [@albanD](https://github.com/albanD).

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import os
import shutil
import sys
from setuptools import setup, find_packages

VERSION = '0.0.2'

long_description = ""

setup_info = dict(
    # Metadata
    name='torchviz',
    version=VERSION,
    author='Sergey Zagoruyko',
    author_email='sergey.zagoruyko@enpc.fr',
    url='https://github.com/pytorch/pytorchviz',
    description='A small package to create visualizations of PyTorch execution graphs',
    long_description=long_description,
    license='MIT',

    # Package info
    packages=find_packages(exclude=('test',)),

    zip_safe=True,

    install_requires=[
        'torch',
        'graphviz'
    ]
)

setup(**setup_info)

--------------------------------------------------------------------------------
/test/test.py:
--------------------------------------------------------------------------------
import unittest
import torch
from torch import nn
from torchviz import make_dot, make_dot_from_trace


def make_mlp_and_input():
    model = nn.Sequential()
    model.add_module('W0', nn.Linear(8, 16))
    model.add_module('tanh', nn.Tanh())
    model.add_module('W1', nn.Linear(16, 1))
    x = torch.randn(1, 8)
    return model, x


class TestTorchviz(unittest.TestCase):

    def test_mlp_make_dot(self):
        model, x = make_mlp_and_input()
        y = model(x)
        dot = make_dot(y.mean(), params=dict(model.named_parameters()))

    def test_double_backprop_make_dot(self):
        model, x = make_mlp_and_input()
        x.requires_grad = True

        def double_backprop(inputs, net):
            y = net(inputs).mean()
            grad, = torch.autograd.grad(y, inputs, create_graph=True, retain_graph=True)
            return grad.pow(2).mean() + y

        dot = make_dot(double_backprop(x, model), params=dict(list(model.named_parameters()) + [('x', x)]))

    def test_lstm_make_dot(self):
        lstm_cell = nn.LSTMCell(128, 128)
        x = torch.randn(1, 128)
        dot = make_dot(lstm_cell(x), params=dict(list(lstm_cell.named_parameters())))


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
/torchviz/__init__.py:
--------------------------------------------------------------------------------
from .dot import make_dot, make_dot_from_trace

--------------------------------------------------------------------------------
/torchviz/dot.py:
--------------------------------------------------------------------------------
from collections import namedtuple
from distutils.version import LooseVersion
from graphviz import Digraph
import torch
from torch.autograd import Variable
import warnings

Node = namedtuple('Node', ('name', 'inputs', 'attr', 'op'))

# Saved attrs for grad_fn (incl. saved variables) begin with `._saved_*`
SAVED_PREFIX = "_saved_"

def get_fn_name(fn, show_attrs, max_attr_chars):
    name = str(type(fn).__name__)
    if not show_attrs:
        return name
    attrs = dict()
    for attr in dir(fn):
        if not attr.startswith(SAVED_PREFIX):
            continue
        val = getattr(fn, attr)
        attr = attr[len(SAVED_PREFIX):]
        if torch.is_tensor(val):
            attrs[attr] = "[saved tensor]"
        elif isinstance(val, tuple) and any(torch.is_tensor(t) for t in val):
            attrs[attr] = "[saved tensors]"
        else:
            attrs[attr] = str(val)
    if not attrs:
        return name
    max_attr_chars = max(max_attr_chars, 3)
    col1width = max(len(k) for k in attrs.keys())
    col2width = min(max(len(str(v)) for v in attrs.values()), max_attr_chars)
    sep = "-" * max(col1width + col2width + 2, len(name))
    attrstr = '%-' + str(col1width) + 's: %' + str(col2width) + 's'
    truncate = lambda s: s[:col2width - 3] + "..." if len(s) > col2width else s
    params = '\n'.join(attrstr % (k, truncate(str(v))) for (k, v) in attrs.items())
    return name + '\n' + sep + '\n' + params


def make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50):
    """ Produces Graphviz representation of PyTorch autograd graph.

    If a node represents a backward function, it is gray. Otherwise, the node
    represents a tensor and is either blue, orange, or green:
     - Blue: reachable leaf tensors that require grad (tensors whose `.grad`
       fields will be populated during `.backward()`)
     - Orange: saved tensors of custom autograd functions as well as those
       saved by built-in backward nodes
     - Green: tensors passed in as outputs
     - Dark green: if any output is a view, we represent its base tensor with
       a dark green node.

    Args:
        var: output tensor
        params: dict of (name, tensor) to add names to nodes that require grad
        show_attrs: whether to display non-tensor attributes of backward nodes
            (Requires PyTorch version >= 1.9)
        show_saved: whether to display saved tensor nodes that are not saved by
            custom autograd functions. Saved tensor nodes for custom functions,
            if present, are always displayed. (Requires PyTorch version >= 1.9)
        max_attr_chars: if show_attrs is `True`, sets max number of characters
            to display for any given attribute.
    """
    if LooseVersion(torch.__version__) < LooseVersion("1.9") and \
            (show_attrs or show_saved):
        warnings.warn(
            "make_dot: showing grad_fn attributes and saved variables"
            " requires PyTorch version >= 1.9. (This does NOT apply to"
            " saved tensors saved by custom autograd functions.)")

    if params is not None:
        assert all(isinstance(p, Variable) for p in params.values())
        param_map = {id(v): k for k, v in params.items()}
    else:
        param_map = {}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='10',
                     ranksep='0.1',
                     height='0.2',
                     fontname='monospace')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def size_to_str(size):
        return '(' + (', ').join(['%d' % v for v in size]) + ')'

    def get_var_name(var, name=None):
        if not name:
            name = param_map[id(var)] if id(var) in param_map else ''
        return '%s\n %s' % (name, size_to_str(var.size()))

    def add_nodes(fn):
        assert not torch.is_tensor(fn)
        if fn in seen:
            return
        seen.add(fn)

        if show_saved:
            for attr in dir(fn):
                if not attr.startswith(SAVED_PREFIX):
                    continue
                val = getattr(fn, attr)
                seen.add(val)
                attr = attr[len(SAVED_PREFIX):]
                if torch.is_tensor(val):
                    dot.edge(str(id(fn)), str(id(val)), dir="none")
                    dot.node(str(id(val)), get_var_name(val, attr), fillcolor='orange')
                if isinstance(val, tuple):
                    for i, t in enumerate(val):
                        if torch.is_tensor(t):
                            name = attr + '[%s]' % str(i)
                            dot.edge(str(id(fn)), str(id(t)), dir="none")
                            dot.node(str(id(t)), get_var_name(t, name), fillcolor='orange')

        if hasattr(fn, 'variable'):
            # if grad_accumulator, add the node for `.variable`
            var = fn.variable
            seen.add(var)
            dot.node(str(id(var)), get_var_name(var), fillcolor='lightblue')
            dot.edge(str(id(var)), str(id(fn)))

        # add the node for this grad_fn
        dot.node(str(id(fn)), get_fn_name(fn, show_attrs, max_attr_chars))

        # recurse
        if hasattr(fn, 'next_functions'):
            for u in fn.next_functions:
                if u[0] is not None:
                    dot.edge(str(id(u[0])), str(id(fn)))
                    add_nodes(u[0])

        # note: this used to show .saved_tensors in pytorch0.2, but stopped
        # working as it was moved to ATen and Variable-Tensor merged
        # also note that this still works for custom autograd functions
        if hasattr(fn, 'saved_tensors'):
            for t in fn.saved_tensors:
                seen.add(t)
                dot.edge(str(id(t)), str(id(fn)), dir="none")
                dot.node(str(id(t)), get_var_name(t), fillcolor='orange')


    def add_base_tensor(var, color='darkolivegreen1'):
        if var in seen:
            return
        seen.add(var)
        dot.node(str(id(var)), get_var_name(var), fillcolor=color)
        if var.grad_fn:
            add_nodes(var.grad_fn)
            dot.edge(str(id(var.grad_fn)), str(id(var)))
        if var._is_view():
            add_base_tensor(var._base, color='darkolivegreen3')
            dot.edge(str(id(var._base)), str(id(var)), style="dotted")


    # handle multiple outputs
    if isinstance(var, tuple):
        for v in var:
            add_base_tensor(v)
    else:
        add_base_tensor(var)

    resize_graph(dot)

    return dot


def make_dot_from_trace(trace):
    """ This functionality is now available in pytorch core at
    https://pytorch.org/docs/stable/tensorboard.html
    """
    # from tensorboardX
    raise NotImplementedError("This function has been moved to pytorch core and "
                              "can be found here: https://pytorch.org/docs/stable/tensorboard.html")

def resize_graph(dot, size_per_element=0.15, min_size=12):
    """Resize the graph according to how much content it contains.

    Modify the graph in place.
    """
    # Get the approximate number of nodes and edges
    num_rows = len(dot.body)
    content_size = num_rows * size_per_element
    size = max(min_size, content_size)
    size_str = str(size) + "," + str(size)
    dot.graph_attr.update(size=size_str)

--------------------------------------------------------------------------------