The response has been limited to 50k tokens of the smallest files in the repo. You can remove this limitation by removing the max tokens filter.
├── .github
    └── workflows
    │   └── ci.yml
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── examples.ipynb
├── setup.py
├── test
    └── test.py
└── torchviz
    ├── __init__.py
    └── dot.py


/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
 1 | name: CI for torchviz
 2 | 
 3 | on:
 4 |   push:
 5 |     branches:
 6 |       - master
 7 |   pull_request:
 8 |     branches:
 9 |       - master
10 | 
11 | jobs:
12 |   test:
13 |     name: Test on Python ${{ matrix.python-version }}
14 |     runs-on: ubuntu-latest
15 | 
16 |     strategy:
17 |       matrix:
18 |         python-version: [3.8, 3.9, '3.10', 3.11, 3.12]
19 | 
20 |     steps:
21 |       - name: Checkout code
22 |         uses: actions/checkout@v3
23 | 
24 |       - name: Set up Python
25 |         uses: actions/setup-python@v4
26 |         with:
27 |           python-version: ${{ matrix.python-version }}
28 | 
29 |       - name: Install system dependencies
30 |         run: |
31 |           sudo apt-get update
32 |           sudo apt-get install -y graphviz
33 | 
34 |       - name: Install Python dependencies
35 |         run: |
36 |           python -m pip install --upgrade pip
37 |           pip install torch --index-url https://download.pytorch.org/whl/cpu
38 |           pip install .
39 |           pip install pytest
40 | 
41 |       - name: Run tests
42 |         run: |
43 |           pytest test/test.py
44 | 


--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
  1 | #### joe made this: http://goel.io/joe
  2 | 
  3 | #####=== Python ===#####
  4 | 
  5 | # Byte-compiled / optimized / DLL files
  6 | __pycache__/
  7 | *.py[cod]
  8 | *$py.class
  9 | 
 10 | # C extensions
 11 | *.so
 12 | 
 13 | # Distribution / packaging
 14 | .Python
 15 | build/
 16 | develop-eggs/
 17 | dist/
 18 | downloads/
 19 | eggs/
 20 | .eggs/
 21 | lib/
 22 | lib64/
 23 | parts/
 24 | sdist/
 25 | var/
 26 | wheels/
 27 | *.egg-info/
 28 | .installed.cfg
 29 | *.egg
 30 | MANIFEST
 31 | 
 32 | # PyInstaller
 33 | #  Usually these files are written by a python script from a template
 34 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 35 | *.manifest
 36 | *.spec
 37 | 
 38 | # Installer logs
 39 | pip-log.txt
 40 | pip-delete-this-directory.txt
 41 | 
 42 | # Unit test / coverage reports
 43 | htmlcov/
 44 | .tox/
 45 | .coverage
 46 | .coverage.*
 47 | .cache
 48 | nosetests.xml
 49 | coverage.xml
 50 | *.cover
 51 | .hypothesis/
 52 | .pytest_cache/
 53 | 
 54 | # Translations
 55 | *.mo
 56 | *.pot
 57 | 
 58 | # Django stuff:
 59 | *.log
 60 | local_settings.py
 61 | db.sqlite3
 62 | 
 63 | # Flask stuff:
 64 | instance/
 65 | .webassets-cache
 66 | 
 67 | # Scrapy stuff:
 68 | .scrapy
 69 | 
 70 | # Sphinx documentation
 71 | docs/_build/
 72 | 
 73 | # PyBuilder
 74 | target/
 75 | 
 76 | # Jupyter Notebook
 77 | .ipynb_checkpoints
 78 | 
 79 | # pyenv
 80 | .python-version
 81 | 
 82 | # celery beat schedule file
 83 | celerybeat-schedule
 84 | 
 85 | # SageMath parsed files
 86 | *.sage.py
 87 | 
 88 | # Environments
 89 | .env
 90 | .venv
 91 | env/
 92 | venv/
 93 | ENV/
 94 | env.bak/
 95 | venv.bak/
 96 | 
 97 | # Spyder project settings
 98 | .spyderproject
 99 | .spyproject
100 | 
101 | # Rope project settings
102 | .ropeproject
103 | 
104 | # mkdocs documentation
105 | /site
106 | 
107 | # mypy
108 | .mypy_cache/
109 | 
110 | #####=== JetBrains ===#####
111 | # User-specific stuff
112 | .idea/**/workspace.xml
113 | .idea/**/tasks.xml
114 | .idea/**/usage.statistics.xml
115 | .idea/**/dictionaries
116 | .idea/**/shelf
117 | 
118 | # Sensitive or high-churn files
119 | .idea/**/dataSources/
120 | .idea/**/dataSources.ids
121 | .idea/**/dataSources.local.xml
122 | .idea/**/sqlDataSources.xml
123 | .idea/**/dynamic.xml
124 | .idea/**/uiDesigner.xml
125 | .idea/**/dbnavigator.xml
126 | 
127 | # Gradle
128 | .idea/**/gradle.xml
129 | .idea/**/libraries
130 | 
131 | # CMake
132 | cmake-build-*/
133 | 
134 | # Mongo Explorer plugin
135 | .idea/**/mongoSettings.xml
136 | 
137 | # File-based project format
138 | *.iws
139 | 
140 | # IntelliJ
141 | out/
142 | 
143 | # mpeltonen/sbt-idea plugin
144 | .idea_modules/
145 | 
146 | # JIRA plugin
147 | atlassian-ide-plugin.xml
148 | 
149 | # Cursive Clojure plugin
150 | .idea/replstate.xml
151 | 
152 | # Crashlytics plugin (for Android Studio and IntelliJ)
153 | com_crashlytics_export_strings.xml
154 | crashlytics.properties
155 | crashlytics-build.properties
156 | fabric.properties
157 | 
158 | # Editor-based Rest Client
159 | .idea/httpRequests
160 | 
161 | # Rider-specific rules
162 | *.sln.iml
163 | 
164 | 


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2019 Sergey Zagoruyko
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 


--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
 1 | .PHONY: dist
 2 | dist:
 3 | 	python setup.py sdist
 4 | 
 5 | .PHONY: upload
 6 | upload: dist
 7 | 	twine upload dist/*
 8 | 
 9 | .PHONY: clean
10 | clean:
11 | 	@rm -rf build dist *.egg-info
12 | 
13 | .PHONY: test
14 | test:
15 | 	PYTHONPATH=. python test/test.py
16 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | PyTorchViz
 2 | =======
 3 | 
 4 | A small package to create visualizations of PyTorch execution graphs. (Visualizing traces has moved to PyTorch core; see https://pytorch.org/docs/stable/tensorboard.html.)
 5 | 
 6 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/szagoruyko/pytorchviz/blob/master/examples.ipynb)
 7 | 
 8 | ## Installation
 9 | 
10 | Install graphviz, e.g.:
11 | 
12 | ```
13 | brew install graphviz
14 | ```
15 | 
16 | Install the package itself:
17 | 
18 | ```
19 | pip install torchviz
20 | ```
21 | 
22 | 
23 | ## Usage
24 | Example usage of `make_dot`:
25 | ```
26 | model = nn.Sequential()
27 | model.add_module('W0', nn.Linear(8, 16))
28 | model.add_module('tanh', nn.Tanh())
29 | model.add_module('W1', nn.Linear(16, 1))
30 | 
31 | x = torch.randn(1, 8)
32 | y = model(x)
33 | 
34 | make_dot(y.mean(), params=dict(model.named_parameters()))
35 | ```
36 | ![image](https://user-images.githubusercontent.com/13428986/110844921-ff3f7500-8277-11eb-912e-3ba03623fdf5.png)
37 | 
38 | Set `show_attrs=True` and `show_saved=True` to see what autograd saves for the backward pass. (Note that this is only available for PyTorch >= 1.9.)
39 | ```
40 | model = nn.Sequential()
41 | model.add_module('W0', nn.Linear(8, 16))
42 | model.add_module('tanh', nn.Tanh())
43 | model.add_module('W1', nn.Linear(16, 1))
44 | 
45 | x = torch.randn(1, 8)
46 | y = model(x)
47 | 
48 | make_dot(y.mean(), params=dict(model.named_parameters()), show_attrs=True, show_saved=True)
49 | ```
50 | ![image](https://user-images.githubusercontent.com/13428986/110845186-4ded0f00-8278-11eb-88d2-cc33413bb261.png)
51 | 
52 | ## Acknowledgements
53 | 
54 | The script was moved from [functional-zoo](https://github.com/szagoruyko/functional-zoo), where it was created with the help of Adam Paszke, Soumith Chintala, and Anton Osokin; it also uses bits from [tensorboard-pytorch](https://github.com/lanpa/tensorboard-pytorch).
55 | Other contributors are [@willprice](https://github.com/willprice), [@soulitzer](https://github.com/soulitzer), [@albanD](https://github.com/albanD).
56 | 


--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | import os
 3 | import shutil
 4 | import sys
 5 | from setuptools import setup, find_packages
 6 | 
 7 | VERSION = '0.0.2'
 8 | 
 9 | long_description = ""
10 | 
11 | setup_info = dict(
12 |     # Metadata
13 |     name='torchviz',
14 |     version=VERSION,
15 |     author='Sergey Zagoruyko',
16 |     author_email='sergey.zagoruyko@enpc.fr',
17 |     url='https://github.com/pytorch/pytorchviz',
18 |     description='A small package to create visualizations of PyTorch execution graphs',
19 |     long_description=long_description,
20 |     license='BSD',
21 | 
22 |     # Package info
23 |     packages=find_packages(exclude=('test',)),
24 | 
25 |     zip_safe=True,
26 | 
27 |     install_requires=[
28 |         'torch',
29 |         'graphviz'
30 |     ]
31 | )
32 | 
33 | setup(**setup_info)
34 | 


--------------------------------------------------------------------------------
/test/test.py:
--------------------------------------------------------------------------------
 1 | import unittest
 2 | import torch
 3 | from torch import nn
 4 | from torchviz import make_dot, make_dot_from_trace
 5 | 
 6 | 
 7 | def make_mlp_and_input():
 8 |     model = nn.Sequential()
 9 |     model.add_module('W0', nn.Linear(8, 16))
10 |     model.add_module('tanh', nn.Tanh())
11 |     model.add_module('W1', nn.Linear(16, 1))
12 |     x = torch.randn(1, 8)
13 |     return model, x
14 | 
15 | 
class TestTorchviz(unittest.TestCase):
    """Smoke tests: build graphs for small models and check their contents."""

    def test_mlp_make_dot(self):
        model, x = make_mlp_and_input()
        y = model(x)
        dot = make_dot(y.mean(), params=dict(model.named_parameters()))
        # Every named parameter should appear as a labelled leaf node.
        for name in ('W0.weight', 'W0.bias', 'W1.weight', 'W1.bias'):
            self.assertIn(name, dot.source)

    def test_mlp_make_dot_show_attrs_and_saved(self):
        model, x = make_mlp_and_input()
        y = model(x)
        dot = make_dot(y.mean(), params=dict(model.named_parameters()),
                       show_attrs=True, show_saved=True)
        self.assertIn('W0.weight', dot.source)
        # Saved tensors of built-in backward nodes are drawn in orange.
        self.assertIn('orange', dot.source)

    def test_double_backprop_make_dot(self):
        model, x = make_mlp_and_input()
        x.requires_grad = True

        def double_backprop(inputs, net):
            # Use the `inputs` argument (previously the closure `x` was used
            # and `inputs` was silently ignored).
            y = net(inputs).mean()
            grad, = torch.autograd.grad(y, inputs, create_graph=True, retain_graph=True)
            return grad.pow(2).mean() + y

        params = dict(list(model.named_parameters()) + [('x', x)])
        dot = make_dot(double_backprop(x, model), params=params)
        self.assertIn('W0.weight', dot.source)

    def test_lstm_make_dot(self):
        lstm_cell = nn.LSTMCell(128, 128)
        x = torch.randn(1, 128)
        # LSTMCell returns an (h, c) tuple, exercising the multi-output path.
        dot = make_dot(lstm_cell(x), params=dict(lstm_cell.named_parameters()))
        self.assertIn('weight_ih', dot.source)

    def test_make_dot_from_trace_raises(self):
        with self.assertRaises(NotImplementedError):
            make_dot_from_trace(None)


if __name__ == '__main__':
    unittest.main()
42 | 


--------------------------------------------------------------------------------
/torchviz/__init__.py:
--------------------------------------------------------------------------------
1 | from .dot import make_dot, make_dot_from_trace
2 | 


--------------------------------------------------------------------------------
/torchviz/dot.py:
--------------------------------------------------------------------------------
  1 | from collections import namedtuple
  2 | from distutils.version import LooseVersion
  3 | from graphviz import Digraph
  4 | import torch
  5 | from torch.autograd import Variable
  6 | import warnings
  7 | 
  8 | Node = namedtuple('Node', ('name', 'inputs', 'attr', 'op'))
  9 | 
 10 | # Saved attrs for grad_fn (incl. saved variables) begin with `._saved_*`
 11 | SAVED_PREFIX = "_saved_"
 12 | 
 13 | def get_fn_name(fn, show_attrs, max_attr_chars):
 14 |     name = str(type(fn).__name__)
 15 |     if not show_attrs:
 16 |         return name
 17 |     attrs = dict()
 18 |     for attr in dir(fn):
 19 |         if not attr.startswith(SAVED_PREFIX):
 20 |             continue
 21 |         val = getattr(fn, attr)
 22 |         attr = attr[len(SAVED_PREFIX):]
 23 |         if torch.is_tensor(val):
 24 |             attrs[attr] = "[saved tensor]"
 25 |         elif isinstance(val, tuple) and any(torch.is_tensor(t) for t in val):
 26 |             attrs[attr] = "[saved tensors]"
 27 |         else:
 28 |             attrs[attr] = str(val)
 29 |     if not attrs:
 30 |         return name
 31 |     max_attr_chars = max(max_attr_chars, 3)
 32 |     col1width = max(len(k) for k in attrs.keys())
 33 |     col2width = min(max(len(str(v)) for v in attrs.values()), max_attr_chars)
 34 |     sep = "-" * max(col1width + col2width + 2, len(name))
 35 |     attrstr = '%-' + str(col1width) + 's: %' + str(col2width)+ 's'
 36 |     truncate = lambda s: s[:col2width - 3] + "..." if len(s) > col2width else s
 37 |     params = '\n'.join(attrstr % (k, truncate(str(v))) for (k, v) in attrs.items())
 38 |     return name + '\n' + sep + '\n' + params
 39 | 
 40 | 
 41 | def make_dot(var, params=None, show_attrs=False, show_saved=False, max_attr_chars=50):
 42 |     """ Produces Graphviz representation of PyTorch autograd graph.
 43 | 
 44 |     If a node represents a backward function, it is gray. Otherwise, the node
 45 |     represents a tensor and is either blue, orange, or green:
 46 |      - Blue: reachable leaf tensors that requires grad (tensors whose `.grad`
 47 |          fields will be populated during `.backward()`)
 48 |      - Orange: saved tensors of custom autograd functions as well as those
 49 |          saved by built-in backward nodes
 50 |      - Green: tensor passed in as outputs
 51 |      - Dark green: if any output is a view, we represent its base tensor with
 52 |          a dark green node.
 53 | 
 54 |     Args:
 55 |         var: output tensor
 56 |         params: dict of (name, tensor) to add names to node that requires grad
 57 |         show_attrs: whether to display non-tensor attributes of backward nodes
 58 |             (Requires PyTorch version >= 1.9)
 59 |         show_saved: whether to display saved tensor nodes that are not by custom
 60 |             autograd functions. Saved tensor nodes for custom functions, if
 61 |             present, are always displayed. (Requires PyTorch version >= 1.9)
 62 |         max_attr_chars: if show_attrs is `True`, sets max number of characters
 63 |             to display for any given attribute.
 64 |     """
 65 |     if LooseVersion(torch.__version__) < LooseVersion("1.9") and \
 66 |         (show_attrs or show_saved):
 67 |         warnings.warn(
 68 |             "make_dot: showing grad_fn attributes and saved variables"
 69 |             " requires PyTorch version >= 1.9. (This does NOT apply to"
 70 |             " saved tensors saved by custom autograd functions.)")
 71 | 
 72 |     if params is not None:
 73 |         assert all(isinstance(p, Variable) for p in params.values())
 74 |         param_map = {id(v): k for k, v in params.items()}
 75 |     else:
 76 |         param_map = {}
 77 | 
 78 |     node_attr = dict(style='filled',
 79 |                      shape='box',
 80 |                      align='left',
 81 |                      fontsize='10',
 82 |                      ranksep='0.1',
 83 |                      height='0.2',
 84 |                      fontname='monospace')
 85 |     dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
 86 |     seen = set()
 87 | 
 88 |     def size_to_str(size):
 89 |         return '(' + (', ').join(['%d' % v for v in size]) + ')'
 90 | 
 91 |     def get_var_name(var, name=None):
 92 |         if not name:
 93 |             name = param_map[id(var)] if id(var) in param_map else ''
 94 |         return '%s\n %s' % (name, size_to_str(var.size()))
 95 | 
 96 |     def add_nodes(fn):
 97 |         assert not torch.is_tensor(fn)
 98 |         if fn in seen:
 99 |             return
100 |         seen.add(fn)
101 | 
102 |         if show_saved:
103 |             for attr in dir(fn):
104 |                 if not attr.startswith(SAVED_PREFIX):
105 |                     continue
106 |                 val = getattr(fn, attr)
107 |                 seen.add(val)
108 |                 attr = attr[len(SAVED_PREFIX):]
109 |                 if torch.is_tensor(val):
110 |                     dot.edge(str(id(fn)), str(id(val)), dir="none")
111 |                     dot.node(str(id(val)), get_var_name(val, attr), fillcolor='orange')
112 |                 if isinstance(val, tuple):
113 |                     for i, t in enumerate(val):
114 |                         if torch.is_tensor(t):
115 |                             name = attr + '[%s]' % str(i)
116 |                             dot.edge(str(id(fn)), str(id(t)), dir="none")
117 |                             dot.node(str(id(t)), get_var_name(t, name), fillcolor='orange')
118 | 
119 |         if hasattr(fn, 'variable'):
120 |             # if grad_accumulator, add the node for `.variable`
121 |             var = fn.variable
122 |             seen.add(var)
123 |             dot.node(str(id(var)), get_var_name(var), fillcolor='lightblue')
124 |             dot.edge(str(id(var)), str(id(fn)))
125 | 
126 |         # add the node for this grad_fn
127 |         dot.node(str(id(fn)), get_fn_name(fn, show_attrs, max_attr_chars))
128 | 
129 |         # recurse
130 |         if hasattr(fn, 'next_functions'):
131 |             for u in fn.next_functions:
132 |                 if u[0] is not None:
133 |                     dot.edge(str(id(u[0])), str(id(fn)))
134 |                     add_nodes(u[0])
135 | 
136 |         # note: this used to show .saved_tensors in pytorch0.2, but stopped
137 |         # working* as it was moved to ATen and Variable-Tensor merged
138 |         # also note that this still works for custom autograd functions
139 |         if hasattr(fn, 'saved_tensors'):
140 |             for t in fn.saved_tensors:
141 |                 seen.add(t)
142 |                 dot.edge(str(id(t)), str(id(fn)), dir="none")
143 |                 dot.node(str(id(t)), get_var_name(t), fillcolor='orange')
144 | 
145 | 
146 |     def add_base_tensor(var, color='darkolivegreen1'):
147 |         if var in seen:
148 |             return
149 |         seen.add(var)
150 |         dot.node(str(id(var)), get_var_name(var), fillcolor=color)
151 |         if (var.grad_fn):
152 |             add_nodes(var.grad_fn)
153 |             dot.edge(str(id(var.grad_fn)), str(id(var)))
154 |         if var._is_view():
155 |             add_base_tensor(var._base, color='darkolivegreen3')
156 |             dot.edge(str(id(var._base)), str(id(var)), style="dotted")
157 | 
158 | 
159 |     # handle multiple outputs
160 |     if isinstance(var, tuple):
161 |         for v in var:
162 |             add_base_tensor(v)
163 |     else:
164 |         add_base_tensor(var)
165 | 
166 |     resize_graph(dot)
167 | 
168 |     return dot
169 | 
170 | 
def make_dot_from_trace(trace):
    """Removed: trace visualization now lives in PyTorch core.

    Kept only so existing imports keep working. Use the TensorBoard
    integration instead: https://pytorch.org/docs/stable/tensorboard.html

    Args:
        trace: ignored.

    Raises:
        NotImplementedError: always.
    """
    # Originally derived from tensorboardX.
    raise NotImplementedError("This function has been moved to pytorch core and "
                              "can be found here: https://pytorch.org/docs/stable/tensorboard.html")
178 | 
179 | 
def resize_graph(dot, size_per_element=0.15, min_size=12):
    """Scale the graph's ``size`` attribute with how much it contains.

    The number of statements in ``dot.body`` (approximately nodes plus
    edges) times ``size_per_element`` gives the side length, which is never
    smaller than ``min_size``. The graph is updated in place.
    """
    estimated = len(dot.body) * size_per_element
    side = estimated if estimated > min_size else min_size
    dot.graph_attr.update(size="%s,%s" % (side, side))
191 | 


--------------------------------------------------------------------------------