├── images
├── dtypes.png
├── experiments.graffle
├── sample-1.svg
├── sample-2.svg
├── ast.svg
└── mm.svg
├── talks
├── tensor-sensor.pdf
├── tensor-sensor.pptx
└── tensor-sensor-old-msgs.pptx
├── testing
├── ones.py
├── testexc.py
├── test_invalid_stat.py
├── test_nested.py
├── test3.py
├── test2.py
├── test_dtype.py
├── test_tensorflow.py
├── play_cmaps.py
├── test_incr_eval.py
├── str_size_matplotlib.py
├── test_tree_eval.py
├── test_jax.py
├── viz_testing.py
└── test_parser.py
├── contributing.md
├── .github
└── workflows
│ └── test.yml
├── LICENSE
├── tsensor
├── version.py
├── __init__.py
├── ast.py
├── parsing.py
├── analysis.py
└── viz.py
├── setup.py
├── .gitignore
└── README.md
/images/dtypes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/parrt/tensor-sensor/HEAD/images/dtypes.png
--------------------------------------------------------------------------------
/talks/tensor-sensor.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/parrt/tensor-sensor/HEAD/talks/tensor-sensor.pdf
--------------------------------------------------------------------------------
/talks/tensor-sensor.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/parrt/tensor-sensor/HEAD/talks/tensor-sensor.pptx
--------------------------------------------------------------------------------
/images/experiments.graffle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/parrt/tensor-sensor/HEAD/images/experiments.graffle
--------------------------------------------------------------------------------
/talks/tensor-sensor-old-msgs.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/parrt/tensor-sensor/HEAD/talks/tensor-sensor-old-msgs.pptx
--------------------------------------------------------------------------------
/testing/ones.py:
--------------------------------------------------------------------------------
1 | # Regression test for https://github.com/parrt/tensor-sensor/issues/16
2 | # Should not throw exception or error, just show equation.
3 | import numpy as np
4 | import tsensor
5 |
# Report which tsensor build is running this regression script.
print(tsensor.__version__)
# Regression for issue 16: per the header comment this should just show the
# equation without raising. tsensor.explain() displays the statement inside the
# with-block, so keep the assignment below textually unchanged.
with tsensor.explain():
    a = np.ones(3)
9 |
--------------------------------------------------------------------------------
/testing/testexc.py:
--------------------------------------------------------------------------------
# Experiment: enrich an exception's args with context at each level of nested
# handlers, re-raising the same exception object each time so the original
# traceback is kept.
foo = None

try:
    try:
        state = "bar"
        foo.append(state)  # foo is None, so this raises AttributeError
    except Exception as err:
        err.args = ("Appending '" + state + "' failed",) + tuple(err.args)
        raise
    print(foo[0])  # never reached: the inner handler always re-raises (would raise too)
except Exception as err:
    err.message = "foo"
    err.args = ("print(foo) failed: " + str(foo),) + tuple(err.args)
    raise
--------------------------------------------------------------------------------
/contributing.md:
--------------------------------------------------------------------------------
1 | Thank you for considering a contribution to TensorSensor. I welcome updates to the examples, README, or the source code itself.
2 |
3 | Before you spend a lot of time creating a pull request (PR) for a new feature or a biggish change to the software, please contact me by email at parrt@cs.usfca.edu to find out if it's something I'm interested in. One of the hardest things to do, as I've learned over the last 30 years of pushing out open-source software, is to maintain project focus and conceptual integrity.
4 |
--------------------------------------------------------------------------------
/testing/test_invalid_stat.py:
--------------------------------------------------------------------------------
1 | import tsensor
2 | import numpy as np
3 |
def f():
    # Currently can't handle double assign
    # The test below expects clarify() to pass numpy's matmul message through
    # without an added "Cause:" line for this chained assignment; keep the
    # statement exactly as written so the test keeps exercising that case.
    a = b = np.ones(1) @ np.ones(2)

def A():
    # Run f() under tsensor.clarify() so the matmul error raised in f goes
    # through tsensor's exception handling before reaching the test.
    with tsensor.clarify():
        f()
11 |
12 |
def test_nested():
    """clarify() leaves numpy's matmul message without a "Cause:" line for a double assign."""
    try:
        A()
    except BaseException as exc:
        msg = exc.args[0]
    else:
        msg = ""

    expected = ("matmul: Input operand 1 has a mismatch in its core dimension 0, "
                "with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 1)")
    assert msg == expected
22 |
--------------------------------------------------------------------------------
/testing/test_nested.py:
--------------------------------------------------------------------------------
1 | # Test for https://github.com/parrt/tensor-sensor/issues/18
2 | # Nested clarify's and all catch exception
3 |
4 | import tsensor
5 | import numpy as np
6 |
def f():
    # Deliberately invalid: a (1,) @ (2,) matmul. The expected message in
    # test_nested quotes this statement's operand text ("np.ones(1)",
    # "np.ones(2)") verbatim, so keep the line unchanged.
    np.ones(1) @ np.ones(2)

def A():
    # Inner clarify() wrapping the failing call.
    with tsensor.clarify():
        f()

def B():
    # Outer clarify() around A(): issue 18 is about nested clarify blocks that
    # all catch the same exception.
    with tsensor.clarify():
        A()
17 |
def test_nested():
    """Nested clarify() blocks yield numpy's message plus a single "Cause:" line."""
    try:
        B()
    except BaseException as exc:
        msg = exc.args[0]
    else:
        msg = ""

    numpy_msg = ("matmul: Input operand 1 has a mismatch in its core dimension 0, "
                 "with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 1)")
    cause = "Cause: @ on tensor operand np.ones(1) w/shape (1,) and operand np.ones(2) w/shape (2,)"
    assert msg == numpy_msg + "\n" + cause
28 |
--------------------------------------------------------------------------------
/testing/test3.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import tsensor
4 | import torch
5 | import sys
6 |
# Render tsensor's AST view of an expression. The frame passed to astviz is this
# module's frame; presumably astviz looks the names W, b, h, z up there (per the
# "eval" note below) -- the tensors are sized/typed to give it varied operands.
# (Earlier small W/b/x/h tensors were overwritten before use and have been removed.)
W = torch.rand(size=(2000, 2000), dtype=torch.float64)
b = torch.rand(size=(2000, 1), dtype=torch.float64)
h = torch.zeros(size=(1_000_000,), dtype=int)
x = torch.rand(size=(2000, 1))  # not in the expression; kept so RNG draws for z are unchanged
z = torch.rand(size=(2000, 1), dtype=torch.complex64)
g = tsensor.astviz("b = W@b + (h+3).dot(h) + z",
                   sys._getframe())  # eval, highlight vectors
g.view()
29 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
name: Test TensorSensor

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
        # Quoted so versions such as '3.10' are not parsed as the float 3.1.
        python-version: ['3.9']

    steps:
      - uses: actions/checkout@v2
      - name: Set up python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          # Use the matrix entry; a hard-coded value would ignore the matrix.
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pyparsing==2.4.7
          pip install torch
          python setup.py install

      - name: Test with pytest
        # A single pytest invocation: on Windows the default shell (pwsh) only
        # reports the exit code of the script's last command, so separate pytest
        # lines would let failures in the earlier files pass silently.
        run: |
          pip install pytest
          pytest testing/test_incr_eval.py testing/test_invalid_stat.py testing/test_parser.py testing/test_tree_eval.py
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Terence Parr
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/testing/test2.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import tsensor
import torch

# Save figures to the current user's Desktop rather than one hard-coded home
# directory (on the original author's machine this is the same /Users/parrt/Desktop).
out_dir = os.path.expanduser("~/Desktop")

# Only x, a, b, c, d are used below; earlier W/b/x/h definitions were overwritten
# or unused and have been removed.
x = torch.tensor([4, 5], dtype=torch.int32).reshape(2, 1)

a = torch.rand(size=(2, 20), dtype=torch.float64)
b = torch.rand(size=(2, 20), dtype=torch.float32)
c = torch.rand(size=(2, 20, 200), dtype=torch.complex64)
d = torch.rand(size=(2, 20, 200, 5), dtype=torch.float16)

# The statements inside each with-block are what explain() visualizes; keep them as-is.
with tsensor.explain(savefig=os.path.join(out_dir, "t2.pdf")):
    a + b + x + c[:,:,0] + d[:,:,0,0]

with tsensor.explain(savefig=os.path.join(out_dir, "t3.pdf")):
    c

with tsensor.explain(savefig=os.path.join(out_dir, "t4.pdf")):
    d
35 |
--------------------------------------------------------------------------------
/tsensor/version.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
# Package version string; setup.py exec()s this file to obtain it.
__version__ = "1.0"
25 |
--------------------------------------------------------------------------------
/testing/test_dtype.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jax.numpy as jnp
3 | import tensorflow as tf
4 | import torch
5 | import pytest
6 |
7 | import tsensor as ts
8 |
9 |
@pytest.mark.parametrize(
    "value,expected",
    [
        # Numpy
        (np.random.randint(1, 10, size=(10, 2, 5)), "int64"),
        (np.random.randint(1, 10, size=(10, 2, 5), dtype="int8"), "int8"),
        (np.random.normal(size=(5, 1)).astype(np.float32), "float32"),
        # Was a verbatim duplicate of the float32 case; cover numpy's default float instead.
        (np.random.normal(size=(5, 1)), "float64"),
        (np.array([('Rex', 9, 81.0), ('Fido', 3, 27.0)], dtype=[('name', 'U10'), ('age', 'i4'), ('weight', 'f4')]),
         "str320,int32,float32"),
        # Jax
        (jnp.array([[1, 2], [3, 4]]), "int32"),
        (jnp.array([[1, 2], [3, 4]], dtype="int8"), "int8"),
        # Tensorflow
        (tf.constant([[1, 2], [3, 4]]), "int32"),
        (tf.constant([[1, 2], [3, 4]], dtype="int64"), "int64"),
        # Pytorch
        (torch.tensor([[1, 2], [3, 4]]), "int64"),
        (torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "int32"),
    ],
)
def test_dtypes(value, expected):
    """tsensor.analysis._dtype returns the expected dtype label for each framework's tensor."""
    assert ts.analysis._dtype(value) == expected
33 |
--------------------------------------------------------------------------------
/tsensor/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
# To fix an OpenMP runtime link issue. Set this before importing any tsensor
# submodule so the variable is already in the environment if one of those
# imports loads an OpenMP runtime.
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = "True"

# These classes/functions are the primary user interface so import them directly
import tsensor.ast
import tsensor.parsing
import tsensor.viz
import tsensor.analysis
from tsensor.analysis import explain, clarify, eval
from tsensor.parsing import parse
from tsensor.viz import pyviz, astviz
from .version import __version__


# Names exported by ``from tsensor import *`` (the submodules only).
__all__ = ["ast", "parsing", "viz", "analysis", "version"]
41 |
--------------------------------------------------------------------------------
/testing/test_tensorflow.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | import tsensor
25 | import tensorflow as tf
26 |
# Module-level operands: b reshapes to (2, 1) and x to (3, 1), so b + x cannot broadcast.
W = tf.constant([[1, 2], [3, 4]])
b = tf.reshape(tf.constant([[9, 10]]), (2, 1))
x = tf.reshape(tf.constant([[8, 5, 7]]), (3, 1))

def test_addition():
    """clarify() appends a "Cause:" line naming b and x to TensorFlow's shape error."""
    try:
        with tsensor.clarify():
            q = b + x + 3
    except tf.errors.InvalidArgumentError as err:
        msg = err.message
    else:
        msg = ""

    tf_msg = "Incompatible shapes: [2,1] vs. [3,1] [Op:AddV2]"
    cause = "Cause: + on tensor operand b w/shape (2, 1) and operand x w/shape (3, 1)"
    assert msg == tf_msg + "\n" + cause
--------------------------------------------------------------------------------
/testing/play_cmaps.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import matplotlib.pyplot as plt
4 | import matplotlib.colors as mc
5 |
# Bit widths to illustrate; nhues shades are generated per base color, one per width.
# (An earlier list also included 4; it was immediately overwritten, so it is dropped.)
bits = [8, 16, 32, 64, 128]
nhues = len(bits)

blueish = '#3B75AF'
greenish = '#519E3E'

orangeish = '#FDDB7D'
limeish = '#A8E1B0'  # earlier candidate, previously assigned then overwritten: '#C1E1C5'
yellowish = '#FFFFAD'

print(mc.hex2color(limeish))

type_colors = {'float': limeish, 'int': blueish, 'complex': orangeish}
21 |
# Derived from https://stackoverflow.com/questions/47222585/matplotlib-generic-colormap-from-tab10

def categorical_cmap(color, nsc):
    """Return a ListedColormap of ``nsc`` shades of the hex color ``color``.

    All shades keep the base hue. Saturation goes linearly from the base value
    down to 0.25, and value (brightness) goes linearly from the base value up
    to 1, so later shades are progressively lighter tints.
    """
    base_hsv = mc.rgb_to_hsv(mc.hex2color(color))
    shades = np.tile(base_hsv, nsc).reshape(nsc, 3)
    shades[:, 1] = np.linspace(base_hsv[1], 0.25, nsc)
    shades[:, 2] = np.linspace(base_hsv[2], 1, nsc)
    # hsv_to_rgb already yields an (nsc, 3) float array; no separate buffer or copy needed.
    return mc.ListedColormap(mc.hsv_to_rgb(shades))
37 |
import os

# One row of large dots per base color, each dot a successive shade of that color.
plt.figure(figsize=(3, 3))
for row, base_color in enumerate([blueish, limeish, yellowish], start=1):
    c1 = categorical_cmap(base_color, nhues)
    plt.scatter(np.arange(nhues), [row] * nhues, c=np.arange(nhues), s=1080,
                cmap=c1, linewidths=.5, edgecolors='grey')

plt.margins(y=3)
plt.xticks([])
plt.yticks([0, 1, 2], ["(5, 4)", "(2, 5)", "(4, 3)"])
plt.ylim(0, 4)
plt.axis('off')

# Save to the current user's Desktop instead of one hard-coded home directory.
plt.savefig(os.path.join(os.path.expanduser("~"), "Desktop", "colors.pdf"))
plt.show()
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
import os

from setuptools import setup

tensorflow_requires = ['tensorflow']
torch_requires = ['torch']
jax_requires = ['jax', 'jaxlib']
all_requires = tensorflow_requires + torch_requires + jax_requires

# Read __version__ by executing tsensor/version.py rather than importing tsensor,
# whose __init__ imports all the submodules. Resolve the path relative to this
# file so the build does not depend on the current working directory, and close
# the file handle promptly.
_here = os.path.dirname(os.path.abspath(__file__))
_version_ns = {}
with open(os.path.join(_here, 'tsensor', 'version.py'), encoding='utf-8') as _version_file:
    exec(_version_file.read(), _version_ns)
__version__ = _version_ns['__version__']

setup(
    name='tensor-sensor',
    version=__version__,
    url='https://github.com/parrt/tensor-sensor',
    license='MIT',
    py_modules=['tsensor.parsing', 'tsensor.ast', 'tsensor.analysis', 'tsensor.viz', 'tsensor.version'],
    author='Terence Parr',
    author_email='parrt@cs.usfca.edu',
    python_requires='>=3.6',
    install_requires=['graphviz>=0.14.1', 'numpy', 'IPython', 'matplotlib'],
    extras_require={'all': all_requires,
                    'torch': torch_requires,
                    'tensorflow': tensorflow_requires,
                    'jax': jax_requires
                    },
    description='The goal of this library is to generate more helpful exception messages for numpy/pytorch tensor algebra expressions.',
    classifiers=['License :: OSI Approved :: MIT License',
                 'Intended Audience :: Developers']
)
51 |
--------------------------------------------------------------------------------
/testing/test_incr_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | from tsensor.ast import IncrEvalTrap
25 | from tsensor.parsing import PyExprParser
26 | import sys
27 | import numpy as np
28 | import torch
29 |
def check(s,expected):
    """Parse ``s``, evaluate it in the caller's frame, and assert which subexpression trapped.

    ``expected`` is the string form of the IncrEvalTrap's offending expression,
    or None when evaluation is expected to complete without a trap.
    """
    caller = sys._getframe(1)
    tree = PyExprParser(s).parse()
    try:
        tree.eval(caller)
    except IncrEvalTrap as trap:
        offending = str(trap.offending_expr)
    else:
        offending = None
    assert offending == expected
41 |
42 |
def test_missing_var():
    # check() evaluates in this frame: a and c are defined here, b and z are not.
    a, c = 3, 5
    check("a+b+c", "b")
    check("z+b+c", "z")

def test_matrix_mult():
    # (2, 2) @ (1, 3): inner dimensions differ, so W@b is the failing subexpression.
    W = torch.tensor([[1, 2], [3, 4]])
    b = torch.tensor([[1, 2, 3]])
    check("W@b+torch.abs(b)", "W@b")

def test_bad_arg():
    # The whole call is reported when torch.abs is given a str.
    check("torch.abs('foo')", "torch.abs('foo')")

def test_parens():
    a, b, c = 3, 4, 5
    check("(a+b)/0", "(a+b)/0")

def test_array_literal():
    a = torch.tensor([[1, 2, 3], [4, 5, 6]])
    b = torch.tensor([[1, 2, 3]])
    a + b  # (2, 3) + (1, 3) broadcasts fine; only b@2 should trap below
    check("a + b@2", "b@2")

def test_array_literal2():
    a = torch.tensor([[1, 2, 3], [4, 5, 6]])
    b = torch.tensor([[1, 2, 3]])
    a + b  # the addition alone succeeds; the @2 applied to it should trap
    check("(a+b)@2", "(a+b)@2")
74 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/testing/str_size_matplotlib.py:
--------------------------------------------------------------------------------
1 | import matplotlib as mpl
2 | import matplotlib.patches as patches
3 | from matplotlib import pyplot as plt
4 |
def textdim_bbox_patch(s, fontname='Consolas', fontsize=11):
    """Return (width, height) of the zero-padding bbox patch drawn around ``s``.

    This used to be named ``textdim`` too, but the later ``textdim`` in this
    script redefined that name before anything called this version, leaving it
    unreachable; it now has its own name. Writes /tmp/font.pdf as a side effect.
    """
    fig, ax = plt.subplots(1, 1)
    t = ax.text(0, 0, s, bbox={'lw': 0, 'pad': 0}, fontname=fontname, fontsize=fontsize)
    # Saving renders the figure; presumably needed so the bbox patch gets its size.
    plt.savefig("/tmp/font.pdf", pad_inches=0, dpi=200)
    print(t)
    plt.close(fig)
    bb = t.get_bbox_patch()
    print(bb)
    return bb.get_width(), bb.get_height()
16 |
17 | # print(textdim("@"))
18 | #exit()
19 |
20 |
# From: https://stackoverflow.com/questions/22667224/matplotlib-get-text-bounding-box-independent-of-backend
def find_renderer(fig):
    """Return a renderer for ``fig`` whatever backend is active."""
    if hasattr(fig.canvas, "get_renderer"):
        # Backends such as TkAgg expose the renderer directly.
        return fig.canvas.get_renderer()

    # Otherwise print the figure to a throwaway in-memory PDF and reuse the
    # renderer cached during that print (trick from backend_bases.print_figure()).
    # NOTE(review): Figure._cachedRenderer is private API and may be absent in
    # newer Matplotlib releases -- confirm against the installed version.
    import io
    fig.canvas.print_pdf(io.BytesIO())
    return fig._cachedRenderer
37 |
38 |
def textdim(s, fontname='Consolas', fontsize=11):
    """Return the (width, height) of ``s``'s window extent when drawn in the given font."""
    fig, ax = plt.subplots(1, 1)
    label = ax.text(0, 0, s, fontname=fontname, fontsize=fontsize, transform=None)
    extent = label.get_window_extent(find_renderer(fig))
    width, height = extent.width, extent.height
    print(s, width, height)
    plt.close()
    return width, height
50 |
51 | # print(textdim("test of foo", fontsize=11))
52 | # print(textdim("test of foO", fontsize=11))
53 | # print(textdim("W @ b + x * 3 + h.dot(h)", fontsize=12))
54 |
# Alternative sample string tried earlier: 'W@ b + x *3 + h.dot(h)'
code = 'W@ b.f(x,y)'

fontname = "Serif"
fontsize = 16

fig, ax = plt.subplots(1, 1, figsize=(4, 1))

# Place the string one character at a time, advancing x by each glyph's measured width.
x = 0
for ch in code:
    ax.text(x, 10, ch, fontname=fontname, fontsize=fontsize, transform=None)
    w, h = textdim(ch, fontname=fontname, fontsize=fontsize)
    x += w

ax.set_xlim(0, x)
ax.axis('off')

plt.tight_layout()
plt.savefig("/tmp/t.pdf", bbox_inches='tight', pad_inches=0, dpi=200)
90 |
--------------------------------------------------------------------------------
/testing/test_tree_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | from tsensor.parsing import *
25 | import tsensor.ast
26 | import sys
27 | import numpy as np
28 |
29 |
def check(s,expected):
    """Parse ``s``, evaluate it in the caller's frame, and compare str() forms with ``expected``."""
    tree = PyExprParser(s).parse()
    result = tree.eval(sys._getframe(1))
    assert str(result) == str(expected)
37 |
38 |
# check() evaluates each expression in the calling test's frame, so the local
# names below must match the names used in the expression strings.

def test_int():
    check("34", 34)

def test_assign():
    check("a = 34", 34)

def test_var():
    a = 34
    check("a", 34)

def test_member_var():
    class A:
        def __init__(self):
            self.a = 34

    x = A()
    check("x.a", 34)

def test_member_func():
    class A:
        def f(self, a):
            return a + 4

    x = A()
    check("x.f(30)", 34)

def test_index():
    a = [1, 2, 3]
    check("a[2]", 3)

def test_add():
    a, b, c = 3, 4, 5
    check("a+b+c", 12)

def test_add_mul():
    a, b, c = 3, 4, 5
    check("a+b*c", 23)

def test_pow():
    a, b = 3, 4
    check("a**b", 3 ** 4)

def test_pow2():
    a, b, c = 3, 4, 5
    # ** is right-associative: a**((b+1)**c) == 3**(5**5)
    check("a**(b+1)**c", 3 ** (5 ** 5))

def test_parens():
    a, b, c = 3, 4, 5
    check("(a+b)*c", 35)

def test_list_literal():
    a = [[1, 2, 3], [4, 5, 6]]
    check("a", "[[1, 2, 3], [4, 5, 6]]")
99 |
100 |
# check() compares str() forms. The expectations are therefore built as real
# arrays, not hand-typed printouts: numpy right-aligns elements to a common
# width ("[[ 2  4  6]"), and hand-typed strings easily get that padding wrong.

def test_np_literal():
    """Scalar times a numpy array."""
    a = np.array([[1, 2, 3], [4, 5, 6]])
    check("a*2", np.array([[2, 4, 6], [8, 10, 12]]))


def test_np_add():
    """Adding a numpy array to itself."""
    a = np.array([[1, 2, 3], [4, 5, 6]])
    check("a+a", np.array([[2, 4, 6], [8, 10, 12]]))


def test_np_add2():
    """Chained numpy additions."""
    a = np.array([[1, 2, 3], [4, 5, 6]])
    check("a+a+a", np.array([[3, 6, 9], [12, 15, 18]]))
114 |
--------------------------------------------------------------------------------
/testing/test_jax.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | import jax.numpy as jnp
25 | import numpy as np
26 | import tsensor
27 |
def test_dot():
    """clarify() appends operand shapes to jax's dot shape-mismatch TypeError."""
    size = 5000
    x = np.random.normal(size=(size, size)).astype(np.float32)
    y = np.random.normal(size=(5, 1)).astype(np.float32)

    msg = ""
    try:
        # tsensor re-parses this source line, so it must stay on one line verbatim.
        with tsensor.clarify():
            z = jnp.dot(x, y).block_until_ready()
    except TypeError as e:
        msg = e.args[0]

    expected = ("Incompatible shapes for dot: got (5000, 5000) and (5, 1).\n"
                "Cause: jnp.dot(x, y) tensor arg x w/shape (5000, 5000), arg y w/shape (5, 1)")
    assert msg == expected
43 |
44 |
def test_scalar_arg():
    """Only tensor-valued args are listed; the string arg gets no shape."""
    size = 5000
    x = np.random.normal(size=(size, size)).astype(np.float32)

    msg = ""
    try:
        # Source text of this line appears in the message; keep it verbatim.
        with tsensor.clarify():
            z = jnp.dot(x, "foo")
    except TypeError as e:
        msg = e.args[0]

    expected = ('data type \'foo\' not understood\n'
                'Cause: jnp.dot(x, "foo") tensor arg x w/shape (5000, 5000)')
    assert msg == expected
59 |
60 |
def test_mmul():
    """clarify() names the @ operator and both operand shapes."""
    W = jnp.array([[1, 2], [3, 4]])
    b = jnp.array([9, 10, 11])

    msg = ""
    try:
        # Operand names W and b show up in the expected message.
        with tsensor.clarify():
            y = W @ b
    except TypeError as e:
        msg = e.args[0]

    expected = ("dot_general requires contracting dimensions to have the same shape, got [2] and [3].\n"
                "Cause: @ on tensor operand W w/shape (2, 2) and operand b w/shape (3,)")
    assert msg == expected
75 |
76 |
def test_fft():
    "Test a library function that doesn't have a shape related message in the exception."
    x = np.exp(2j * np.pi * np.arange(8) / 8)
    msg = ""
    try:
        with tsensor.clarify():
            y = jnp.fft.fft(x, norm="something weird")
    # Catch Exception, not BaseException, so KeyboardInterrupt/SystemExit
    # still abort the test run instead of being recorded as a message.
    except Exception as e:
        msg = e.args[0]

    expected = ('jax.numpy.fft.fft only supports norm=None, got something weird\n'
                'Cause: jnp.fft.fft(x, norm="something weird") tensor arg x w/shape (8,)')
    assert msg == expected
91 |
--------------------------------------------------------------------------------
/testing/viz_testing.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import numpy as np
4 | import graphviz
5 | import tempfile
6 | import matplotlib.patches as patches
7 | import matplotlib.pyplot as plt
8 | import matplotlib.font_manager as fm
9 |
10 |
11 | # print('\n'.join(str(f) for f in fm.fontManager.ttflist))
12 | import tsensor
13 | # from tsensor.viz import pyviz, astviz
14 |
15 | import torch
16 | import tsensor
17 |
n = 200  # number of instances
d = 764  # number of instance features
nhidden = 256

Whh = torch.eye(nhidden, nhidden)  # Identity matrix
Uxh = torch.randn(nhidden, d)
bh = torch.zeros(nhidden, 1)
h = torch.randn(nhidden, 1)  # fake previous hidden state h
# r = torch.randn(nhidden, 1)  # fake this computation
r = torch.randn(nhidden, 3)  # fake this computation
X = torch.rand(n, d)  # fake input

# Following code raises an exception: with r of shape (nhidden, 3),
# Whh @ (r*h) is (nhidden, 3) while Uxh @ X.T is (nhidden, n), so the
# addition cannot broadcast.
# The figure goes under the platform temp dir; a hard-coded path into one
# user's home directory does not exist on other machines.
with tsensor.explain(savefig=f"{tempfile.gettempdir()}/toomany.png") as e:
    h = torch.tanh(Whh @ (r*h) + Uxh @ X.T + bh)  # state vector update equation

# sys.exit() works even when the `site` module (which provides exit()) isn't loaded.
sys.exit()
35 |
def foo():
    """Render tsensor's AST view of a sample tensor expression and open it."""
    # W = torch.rand(size=(2000, 2000))
    W = torch.rand(size=(2000, 2000, 10, 8))
    b = torch.rand(size=(2000, 1))
    h = torch.rand(size=(1_000_000,))
    # The original assigned frame = sys._getframe() and then immediately
    # overwrote it with None; only the None value was ever used.
    # NOTE(review): presumably passing sys._getframe() instead lets astviz
    # evaluate W, b, h for shapes -- confirm against tsensor.astviz.
    frame = None
    g = tsensor.astviz("b = W[:,:,0,0]@b + (h+3).dot(h) + torch.abs(torch.tensor(34))",
                       frame)
    g.view()
49 |
50 | #foo()
51 |
class Linear:
    """Dense layer computing W @ x + b with random weights and zero bias."""

    def __init__(self, d, n_neurons):
        # One row of weights per neuron, one column per input feature.
        self.W = torch.randn(n_neurons, d)
        # Column-vector bias of shape (n_neurons, 1).
        self.b = torch.zeros(n_neurons, 1)

    def __call__(self, x):
        projected = self.W @ x
        return projected + self.b
58 |
59 |
# instances, features per instance, neurons in this layer
n, d, n_neurons = 200, 764, 100
63 | # L = Linear(d,n_neurons)
64 | #
65 | # import tensorflow as tf
66 | # X = tf.random.uniform((n,d))
67 | # with tsensor.clarify(hush_errors=False):
68 | # Y = L(X)
69 |
70 | # g = tsensor.pyviz("Y = L(X)", hush_errors=False)
71 | # g.show()
72 |
class GRU:
    """Scratch holder of random tensors used by the visualization experiments."""

    def __init__(self):
        self.W = torch.rand(size=(2, 20, 2000, 10))
        self.b = torch.rand(size=(20, 1))
        # self.x = torch.tensor([4, 5]).reshape(2, 1)
        self.h = torch.rand(size=(1_000_000,))
        self.a = 3
        # Echo W's shape and the shape of a slice of it for eyeballing.
        for t in (self.W, self.W[:, :, 1]):
            print(t.shape)

    def get(self):
        """Return a fixed 2x2 integer tensor."""
        return torch.tensor([[1, 2], [3, 4]])
85 |
86 | # W = torch.tensor([[1, 2], [3, 4]])
# NOTE: the script exits earlier in this file, so this section only runs once
# that early exit is removed.
b = torch.rand(size=(2000, 1))
h = torch.rand(size=(1_000_000, 2))
x = torch.rand(size=(1_000_000, 2))
a = 3

# foo = torch.rand(size=(2000,))
# torch.relu(foo)

g = GRU()

# with tsensor.clarify():
#     tf.constant([1,2]) @ tf.constant([1,3])

code = "b = g.W[0,:,:,1]@b+torch.zeros(200,1)+(h+3).dot(h)"
# code = "torch.relu(foo)"
# code = "np.dot(b,b)"
# code = "b.T"
g = tsensor.pyviz(code, fontname='Courier New', fontsize=16, dimfontsize=9,
                  char_sep_scale=1.8, hush_errors=False)
plt.tight_layout()
# Use the platform temp dir; a hard-coded /tmp does not exist on Windows.
plt.savefig(f"{tempfile.gettempdir()}/t.svg", dpi=200, bbox_inches='tight', pad_inches=0)
109 |
110 | # W = torch.tensor([[1, 2], [3, 4]])
111 | # x = torch.tensor([4, 5]).reshape(2, 1)
112 | # with tsensor.explain():
113 | # b = torch.rand(size=(2000,))
114 | # torch.relu(b)
115 |
116 |
117 | # g = GRU()
118 | #
119 | # g1 = tsensor.astviz("b = g.W@b + torch.eye(3,3)")
120 | # g1.view()
121 | # g1 = tsensor.pyviz("b = g.W@b")
122 | # g1.view()
123 | # g2 = tsensor.astviz("b = g.W@b + g.h.dot(g.h) + torch.abs(torch.tensor(34))")
124 |
--------------------------------------------------------------------------------
/images/sample-1.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
6 |
7 |
83 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tensor Sensor
2 |
3 | See article [Clarifying exceptions and visualizing tensor operations in deep learning code](https://explained.ai/tensor-sensor/index.html) and [TensorSensor implementation slides](https://github.com/parrt/tensor-sensor/raw/master/talks/tensor-sensor.pdf) (PDF).
4 |
5 | (*As of September 2021, M1 Macs hit illegal-instruction errors in many of the tensor libraries installed via Anaconda, so you should expect TensorSensor to work only on Intel-based Macs at the moment. PyTorch appears to work.*)
6 |
7 | One of the biggest challenges when writing code to implement deep learning networks, particularly for us newbies, is getting all of the tensor (matrix and vector) dimensions to line up properly. It's really easy to lose track of tensor dimensionality in complicated expressions involving multiple tensors and tensor operations. Even when just feeding data into predefined [Tensorflow](https://www.tensorflow.org/) network layers, we still need to get the dimensions right. When you ask for improper computations, you're going to run into some less-than-helpful exception messages.
8 |
9 | To help myself and other programmers debug tensor code, I built this library. TensorSensor clarifies exceptions by augmenting messages and visualizing Python code to indicate the shape of tensor variables (see figure to the right for a teaser). It works with [Tensorflow](https://www.tensorflow.org/), [PyTorch](https://pytorch.org/), [JAX](https://github.com/google/jax), and [Numpy](https://numpy.org/), as well as higher-level libraries like [Keras](https://keras.io/) and [fastai](https://www.fast.ai/).
10 |
11 | *TensorSensor is currently at 1.0 (December 2021)*.
12 |
13 | ## Visualizations
14 |
15 | For more, see [examples.ipynb at colab](https://colab.research.google.com/github/parrt/tensor-sensor/blob/master/testing/examples.ipynb). (GitHub's notebook rendering does not show the images for some reason: [examples.ipynb](testing/examples.ipynb).)
16 |
17 | ```python
18 | import numpy as np
19 |
20 | n = 200 # number of instances
21 | d = 764 # number of instance features
22 | n_neurons = 100 # how many neurons in this layer?
23 |
24 | W = np.random.rand(d,n_neurons)
25 | b = np.random.rand(n_neurons,1)
26 | X = np.random.rand(n,d)
27 | with tsensor.clarify() as c:
28 | Y = W @ X.T + b
29 | ```
30 |
31 | Displays this in a jupyter notebook or separate window:
32 |
33 |
34 |
35 | Instead of the following default exception message:
36 |
37 | ```
38 | ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100)
39 | ```
40 |
41 | TensorSensor augments the message with more information about which operator caused the problem and includes the shape of the operands:
42 |
43 | ```
44 | Cause: @ on tensor operand W w/shape (764, 100) and operand X.T w/shape (764, 200)
45 | ```
46 |
47 | You can also get the full computation graph for an expression that includes all of the sub-expression shapes.
48 |
49 | ```python
50 | W = torch.rand(size=(2000,2000), dtype=torch.float64)
51 | b = torch.rand(size=(2000,1), dtype=torch.float64)
52 | h = torch.zeros(size=(1_000_000,), dtype=int)
53 | x = torch.rand(size=(2000,1))
54 | z = torch.rand(size=(2000,1), dtype=torch.complex64)
55 |
56 | tsensor.astviz("b = W@b + (h+3).dot(h) + z", sys._getframe())
57 | ```
58 |
59 | yields the following abstract syntax tree with shapes:
60 |
61 |
62 |
63 | ## Install
64 |
65 | ```
66 | pip install tensor-sensor # This will only install the library for you
67 | pip install tensor-sensor[torch] # install pytorch related dependency
68 | pip install tensor-sensor[tensorflow] # install tensorflow related dependency
69 | pip install tensor-sensor[jax] # install jax, jaxlib
70 | pip install tensor-sensor[all] # install tensorflow, pytorch, jax
71 | ```
72 |
73 | which gives you the module `tsensor`. I developed and tested with the following versions:
74 |
75 | ```
76 | $ pip list | grep -i flow
77 | tensorflow 2.5.0
78 | tensorflow-estimator 2.5.0
79 | $ pip list | grep -i numpy
80 | numpy 1.19.5
81 | numpydoc 1.1.0
82 | $ pip list | grep -i torch
83 | torch 1.10.0
84 | torchvision 0.10.0
85 | $ pip list | grep -i jax
86 | jax 0.2.20
87 | jaxlib 0.1.71
88 | ```
89 |
90 | ### Graphviz for tsensor.astviz()
91 |
92 | To display abstract syntax trees (ASTs) with `tsensor.astviz(...)`, you need the `dot` executable from Graphviz, not just the Python library.
93 |
94 | On **Mac**, run this before or after installing tensor-sensor:
95 |
96 | ```
97 | brew install graphviz
98 | ```
99 |
100 | On **Windows**, you apparently need:
101 |
102 | ```
103 | conda install python-graphviz # Do this first; gets the dot executable and the Python lib
104 | pip install tensor-sensor # Or one of the other installs
105 | ```
106 |
107 |
108 | ## Limitations
109 |
110 | I rely on parsing lines that are assignments or expressions only, so the clarify and explain routines do not handle functions whose body is on the same line as the `def`, like:
111 |
112 | ```
113 | def bar(): b + x * 3
114 | ```
115 |
116 | Instead, use
117 |
118 | ```
119 | def bar():
120 | b + x * 3
121 | ```
122 |
123 | Watch out for side effects! I don't perform assignments, but any functions with side effects that you call will run again while I re-evaluate statements.
124 |
125 | TensorSensor can't handle statements split across lines with `\` continuations.
126 |
127 | With Python's `threading` package, don't call clarify() from multiple threads. The `multiprocessing` package should be fine.
128 |
129 | Also note: I've built my own parser that handles just the assignments and expressions TensorSensor supports.
130 |
131 | ## Deploy (parrt's use)
132 |
133 | ```bash
134 | $ python setup.py sdist upload
135 | ```
136 |
137 | Or download and install locally
138 |
139 | ```bash
140 | $ cd ~/github/tensor-sensor
141 | $ pip install .
142 | ```
143 |
144 | ### TODO
145 |
146 | * Can I call pyviz in the debugger?
147 |
--------------------------------------------------------------------------------
/testing/test_parser.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | from tsensor.parsing import *
25 | import re
26 |
def check(s, expected_repr, expect_str=None):
    """
    Parse `s` and assert both the tree's source text and its structure.

    str(tree) must equal `s`, or `expect_str` when one is given. repr(tree)
    must equal `expected_repr`. All whitespace is removed from both sides of
    each comparison, so the layout of expected strings does not matter.

    NOTE(review): the expected reprs in this file show `op=` with nothing
    after it, while Token.__repr__ (tsensor/parsing.py) renders tokens as
    `<TYPE:value,start:stop>`. Confirm those expected strings are complete.
    """
    def squash(text):
        return re.sub(r"\s+", "", text)

    t = PyExprParser(s, hush_errors=False).parse()

    # Normalize expect_str exactly like the source text. Previously it was
    # compared raw, so an expect_str containing whitespace could never match.
    expected_str = squash(s) if expect_str is None else squash(expect_str)
    assert squash(str(t)) == expected_str

    assert squash(repr(t)) == squash(expected_repr)
44 |
45 |
def test_assign():
    """Plain assignment of a literal."""
    check(s="a = 3", expected_repr="Assign(op=,lhs=a,rhs=3)")


def test_index():
    """Indexing with a slice and names."""
    check(s="a[:,i,j]", expected_repr="Index(arr=a, index=[:, i, j])")


def test_index2():
    """Assignment whose right-hand side is a full slice."""
    check(s="z = a[:]", expected_repr="Assign(op=,lhs=z,rhs=Index(arr=a,index=[:]))")


def test_index3():
    """Indexing applied to a member access."""
    check(s="g.W[:,:,1]", expected_repr="Index(arr=Member(op=,obj=g,member=W),index=[:,:,1])")
59 |
def test_literal_list():
    """Nested list literal."""
    check(s="[[1, 2], [3, 4]]",
          expected_repr="ListLiteral(elems=[ListLiteral(elems=[1, 2]), ListLiteral(elems=[3, 4])])")


def test_literal_array():
    """Call whose single argument is a nested list literal."""
    check(s="np.array([[1, 2], [3, 4]])",
          expected_repr="""
Call(func=Member(op=,obj=np,member=array),
args=[ListLiteral(elems=[ListLiteral(elems=[1,2]),ListLiteral(elems=[3,4])])])
""")
71 |
72 |
def test_method():
    """Assignment from a module-function call."""
    check(s="h = torch.tanh(h)",
          expected_repr="Assign(op=,lhs=h,rhs=Call(func=Member(op=,obj=torch,member=tanh),args=[h]))")


def test_method2():
    """Module-function call with two arguments."""
    check(s="np.dot(b,b)",
          expected_repr="Call(func=Member(op=,obj=np,member=dot),args=[b,b])")


def test_method3():
    """Assignment from a call on a bare name."""
    check(s="y_pred = model(X)",
          expected_repr="Assign(op=,lhs=y_pred,rhs=Call(func=model,args=[X]))")
86 |
87 |
def test_field():
    """Single attribute access."""
    check(s="a.b", expected_repr="Member(op=,obj=a,member=b)")


def test_member_func():
    """Zero-argument method call."""
    check(s="a.f()", expected_repr="Call(func=Member(op=,obj=a,member=f),args=[])")


def test_field2():
    """Chained attribute access nests to the left."""
    check(s="a.b.c", expected_repr="Member(op=,obj=Member(op=,obj=a,member=b),member=c)")


def test_field_and_func():
    """Attribute access on the result of a call."""
    check(s="a.f().c",
          expected_repr="Member(op=,obj=Call(func=Member(op=,obj=a,member=f),args=[]),member=c)")
102 |
103 |
def test_parens():
    """Parentheses group the addition under the multiplication."""
    check(s="(a+b)*c", expected_repr="BinaryOp(op=,lhs=BinaryOp(op=,lhs=a,rhs=b),rhs=c)")


def test_expr_in_arg_with_parens():
    """Parenthesized arithmetic as a call argument."""
    check(s="h = torch.tanh( (1-z)*h + z*h_ )",
          expected_repr="Assign(op=,lhs=h,rhs=Call(func=Member(op=,obj=torch,member=tanh),args=[BinaryOp(op=,lhs=BinaryOp(op=,lhs=BinaryOp(op=,lhs=1,rhs=z),rhs=h),rhs=BinaryOp(op=,lhs=z,rhs=h_))]))")
111 |
112 |
def test_1tuple():
    """One-element tuple needs its trailing comma."""
    check(s="(3,)", expected_repr="TupleLiteral(elems=[3])")


def test_2tuple():
    """Two-element tuple."""
    check(s="(3,4)", expected_repr="TupleLiteral(elems=[3,4])")


def test_2tuple_with_trailing_comma():
    """A trailing comma adds no element."""
    check(s="(3,4,)", expected_repr="TupleLiteral(elems=[3,4])")


def test_field_array():
    """Subscript on a member access."""
    check(s="a.b[34]", expected_repr="Index(arr=Member(op=,obj=a,member=b),index=[34])")


def test_field_array_func():
    """Method call on a subscripted member."""
    check(s="a.b[34].f()",
          expected_repr="Call(func=Member(op=,obj=Index(arr=Member(op=,obj=a,member=b),index=[34]),member=f),args=[])")
131 |
132 |
def test_arith():
    """Multiplication binds tighter than addition; parens group (1-z)."""
    check(s="(1-z)*h + z*h_",
          expected_repr="""BinaryOp(op=,
lhs=BinaryOp(op=,
lhs=BinaryOp(op=,
lhs=1,
rhs=z),
rhs=h),
rhs=BinaryOp(op=,lhs=z,rhs=h_))""")


def test_pow():
    """Simple power expression."""
    check(s="a**2",
          expected_repr="""BinaryOp(op=,lhs=a,rhs=2)""")


def test_chained_pow():
    """Power is right-associative: the nested op is on the right."""
    check(s="a**b**c",
          expected_repr="""BinaryOp(op=,lhs=a,rhs=BinaryOp(op=,lhs=b,rhs=c))""")
152 |
153 |
def test_chained_op():
    """Addition is left-associative: the nested op is on the left."""
    check(s="a + b + c",
          expected_repr="""BinaryOp(op=,
lhs=BinaryOp(op=, lhs=a, rhs=b),
rhs=c)""")


def test_matrix_arith():
    """Mix of @ and + with a member-access operand."""
    check(s="self.Whz@h + Uxz@x + bz",
          expected_repr="""
BinaryOp(op=,
lhs=BinaryOp(op=,
lhs=BinaryOp(op=,lhs=Member(op=,obj=self,member=Whz),rhs=h),
rhs=BinaryOp(op=,lhs=Uxz,rhs=x)),
rhs=bz)
""")


def test_kwarg():
    """A keyword argument is represented as an Assign inside args."""
    check(s="torch.relu(torch.rand(size=(2000,)))",
          expected_repr="""
Call(func=Member(op=,obj=torch,member=relu),
args=[Call(func=Member(op=,obj=torch,member=rand),
args=[Assign(op=,lhs=size,rhs=TupleLiteral(elems=[2000]))])])""")
--------------------------------------------------------------------------------
/images/sample-2.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
6 |
7 |
132 |
--------------------------------------------------------------------------------
/tsensor/ast.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | import tsensor
25 |
26 | # Parse tree definitions
27 | # I found package ast in python3 lib after I built this. whoops. No biggie.
28 | # This tree structure is easier to visit for my purposes here. Also lets me
29 | # control the kinds of statements I process.
30 |
class ParseTreeNode:
    """
    Base class for nodes of the expression tree.

    A node's text is the slice of parser.code spanned by its start and stop
    tokens. eval() evaluates that text in a frame and caches the result in
    self.value. Subclasses override eval/optokens/kids/clarify as needed.
    """

    def __init__(self, parser):
        self.parser = parser  # creator of this node; its .code holds the source text
        self.value = None     # result of the most recent eval()
        self.start = None     # first token of this node's text
        self.stop = None      # last token of this node's text

    def eval(self, frame):
        """
        Evaluate this node's source text with frame's globals and locals,
        store the result in self.value, and return it. Any exception is
        re-raised as IncrEvalTrap(self), chained to the original, so callers
        know which subtree failed.
        """
        try:
            self.value = eval(str(self), frame.f_globals, frame.f_locals)
        except BaseException as e:
            raise IncrEvalTrap(self) from e
        return self.value

    @property
    def optokens(self):
        """Tokens representing this node's operation; None for the base class."""
        return None

    @property
    def kids(self):
        """Child nodes; none for the base class."""
        return []

    def clarify(self):
        """Explanatory message for a failure at this node; None by default."""
        return None

    def __str__(self):
        # Slice the original source using the tokens' character offsets.
        return self.parser.code[self.start.cstart_idx:self.stop.cstop_idx]

    def __repr__(self):
        hidden = {'start', 'stop', 'lbrack', 'lparen', 'parser'}
        parts = []
        for name, field in self.__dict__.items():
            if name in hidden:
                continue
            if name == 'value' and field is None:
                continue  # show value only after a successful eval
            parts.append(f"{name}={field!r}")
        return f"{self.__class__.__name__}({','.join(parts)})"
73 |
class Assign(ParseTreeNode):
    """Assignment node: `lhs op rhs`."""

    def __init__(self, parser, op, lhs, rhs, start, stop):
        super().__init__(parser)
        self.op = op
        self.lhs = lhs
        self.rhs = rhs
        self.start = start
        self.stop = stop

    def eval(self, frame):
        # Evaluate only the right-hand side. Evaluating the whole statement
        # would perform the assignment in the user's frame.
        result = self.rhs.eval(frame)
        self.value = result
        self.lhs.value = result
        return result

    @property
    def optokens(self):
        return [self.op]

    @property
    def kids(self):
        return [self.lhs, self.rhs]
90 |
91 |
class Call(ParseTreeNode):
    """Call node: `func(args...)`."""

    def __init__(self, parser, func, lparen, args, start, stop):
        super().__init__(parser)
        self.func, self.lparen, self.args = func, lparen, args
        self.start, self.stop = start, stop

    def eval(self, frame):
        # Evaluate the callee and every argument first. A failing child raises
        # its own IncrEvalTrap before the call itself is attempted.
        for sub in [self.func, *self.args]:
            sub.eval(frame)
        return super().eval(frame)

    def clarify(self):
        shaped = []
        for arg in self.args:
            shape = tsensor.analysis._shape(arg.value)
            if shape:
                shaped.append(f"arg {arg} w/shape {shape}")
        if not shaped:
            return f"Cause: {self}"
        return f"Cause: {self} tensor {', '.join(shaped)}"

    @property
    def optokens(self):
        # Name the callee token when the function expression is simple
        # (obj.f(...) or f(...)). Otherwise only the parens mark the call.
        callee = None
        if isinstance(self.func, Member):
            callee = self.func.member
        elif isinstance(self.func, Atom):
            callee = self.func
        if callee:
            return [callee.token, self.lparen, self.stop]
        return [self.lparen, self.stop]

    @property
    def kids(self):
        return [self.func, *self.args]
126 |
127 |
class Return(ParseTreeNode):
    """`return` statement node; result is the list of returned expressions."""

    def __init__(self, parser, result, start, stop):
        super().__init__(parser)
        self.result = result
        self.start, self.stop = start, stop

    def eval(self, frame):
        # A single returned expression yields its value, not a 1-element list.
        values = [expr.eval(frame) for expr in self.result]
        self.value = values[0] if len(values) == 1 else values
        return self.value

    @property
    def optokens(self):
        return [self.start]

    @property
    def kids(self):
        return self.result
144 |
145 |
class Index(ParseTreeNode):
    """Subscript node: `arr[index, ...]`."""

    def __init__(self, parser, arr, lbrack, index, start, stop):
        super().__init__(parser)
        self.arr, self.lbrack, self.index = arr, lbrack, index
        self.start, self.stop = start, stop

    def eval(self, frame):
        # Evaluate the array expression, then each index expression, then the whole subscript.
        for sub in [self.arr, *self.index]:
            sub.eval(frame)
        return super().eval(frame)

    @property
    def optokens(self):
        return [self.lbrack, self.stop]

    @property
    def kids(self):
        return [self.arr, *self.index]
164 |
165 |
class Member(ParseTreeNode):
    """Attribute access node: `obj.member`."""

    def __init__(self, parser, op, obj, member, start, stop):
        super().__init__(parser)
        self.op, self.obj, self.member = op, obj, member  # op is the DOT token
        self.start, self.stop = start, stop

    def eval(self, frame):
        # The member is only a name looked up on obj, so it is not evaluated alone.
        self.obj.eval(frame)
        return super().eval(frame)

    @property
    def optokens(self):
        return [self.op]

    @property
    def kids(self):
        return [self.obj, self.member]
183 |
184 |
class BinaryOp(ParseTreeNode):
    """Binary operator node: `lhs op rhs`."""

    def __init__(self, parser, op, lhs, rhs, start, stop):
        super().__init__(parser)
        self.op, self.lhs, self.rhs = op, lhs, rhs
        self.start, self.stop = start, stop

    def eval(self, frame):
        # Evaluate both operands first so a failing operand is reported
        # before this operation is attempted.
        self.lhs.eval(frame)
        self.rhs.eval(frame)
        return super().eval(frame)

    def clarify(self):
        """
        Describe this operation and the shapes of its tensor operands.

        With no shaped operand, return "Cause: <expr>" (as Call.clarify does)
        rather than the dangling "Cause: <op> on tensor ".
        """
        opnd_msgs = []
        for opnd in (self.lhs, self.rhs):
            shape = tsensor.analysis._shape(opnd.value)
            if shape:
                opnd_msgs.append(f"operand {opnd} w/shape {shape}")
        if not opnd_msgs:
            return f"Cause: {self}"
        return f"Cause: {self.op} on tensor " + ' and '.join(opnd_msgs)

    @property
    def optokens(self):
        return [self.op]

    @property
    def kids(self):
        return [self.lhs, self.rhs]
209 |
210 |
class UnaryOp(ParseTreeNode):
    """Prefix operator node: `op opnd`."""

    def __init__(self, parser, op, opnd, start, stop):
        super().__init__(parser)
        self.op, self.opnd = op, opnd
        self.start, self.stop = start, stop

    def eval(self, frame):
        self.opnd.eval(frame)  # operand first, then the whole expression
        return super().eval(frame)

    @property
    def optokens(self):
        return [self.op]

    @property
    def kids(self):
        return [self.opnd]
226 |
227 |
class ListLiteral(ParseTreeNode):
    """List display node: `[e1, e2, ...]`."""

    def __init__(self, parser, elems, start, stop):
        super().__init__(parser)
        self.elems = elems
        self.start, self.stop = start, stop

    def eval(self, frame):
        # Evaluate elements one by one so a failing element is trapped alone.
        for elem in self.elems:
            elem.eval(frame)
        return super().eval(frame)

    @property
    def kids(self):
        return self.elems
240 |
241 |
class TupleLiteral(ParseTreeNode):
    """Tuple display node: `(e1, e2, ...)`."""

    def __init__(self, parser, elems, start, stop):
        super().__init__(parser)
        self.elems = elems
        self.start, self.stop = start, stop

    def eval(self, frame):
        # Evaluate elements one by one so a failing element is trapped alone.
        for elem in self.elems:
            elem.eval(frame)
        return super().eval(frame)

    @property
    def kids(self):
        return self.elems
254 |
255 |
class SubExpr(ParseTreeNode):
    """Parenthesized expression; the parens are kept so display preserves precedence."""

    def __init__(self, parser, e, start, stop):
        super().__init__(parser)
        self.e = e
        self.start, self.stop = start, stop

    def eval(self, frame):
        # The value is just the inner expression's value; no second evaluation.
        inner = self.e.eval(frame)
        self.value = inner
        return inner

    @property
    def optokens(self):
        return [self.start, self.stop]

    @property
    def kids(self):
        return [self.e]
271 |
272 |
class Atom(ParseTreeNode):
    """Leaf node wrapping a single token."""

    def __init__(self, parser, token):
        super().__init__(parser)
        self.token = token
        self.start = self.stop = token

    def eval(self, frame):
        if self.token.type != tsensor.parsing.COLON:
            return super().eval(frame)
        # A bare slice colon has no value of its own; return a placeholder
        # (self.value is left untouched).
        return ':'

    def __repr__(self):
        return self.token.value
285 |
286 |
def postorder(t):
    """Return the nodes of tree t, children before parents; [] when t is None."""
    collected = []
    _postorder(t, collected)
    return collected


def _postorder(t, nodes):
    """Append the subtree rooted at t to nodes in post-order."""
    if t is not None:
        for child in t.kids:
            _postorder(child, nodes)
        nodes.append(t)
299 |
300 |
def leaves(t):
    """Return the childless nodes of tree t, left to right; [] when t is None."""
    collected = []
    _leaves(t, collected)
    return collected


def _leaves(t, nodes):
    """Append every childless node under t (inclusive) to nodes."""
    if t is None:
        return
    children = t.kids
    if not children:
        nodes.append(t)
    else:
        for child in children:
            _leaves(child, nodes)
315 |
316 |
def walk(t, pre=lambda x: None, post=lambda x: None):
    """
    Depth-first traversal of tree t: pre(node) runs before a node's children
    are visited and post(node) after. A None tree is ignored.
    """
    def visit(node):
        if node is None:
            return
        pre(node)
        for child in node.kids:
            visit(child)
        post(node)

    visit(t)
324 |
325 |
class IncrEvalTrap(BaseException):
    """
    Raised while re-evaluating a Python line that threw an exception, to record
    which subexpression failed. It subclasses BaseException, so
    `except Exception` handlers do not intercept it.
    """

    def __init__(self, offending_expr):
        super().__init__(offending_expr)
        self.offending_expr = offending_expr  # tree node whose evaluation raised
333 |
--------------------------------------------------------------------------------
/tsensor/parsing.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | from io import BytesIO
25 | import token
26 | import keyword
27 | import tsensor.ast
28 | from tokenize import (
29 | tokenize,
30 | NUMBER,
31 | STRING,
32 | NAME,
33 | OP,
34 | ENDMARKER,
35 | LPAR,
36 | LSQB,
37 | RPAR,
38 | RSQB,
39 | COMMA,
40 | COLON,
41 | PLUS,
42 | MINUS,
43 | STAR,
44 | SLASH,
45 | AT,
46 | PERCENT,
47 | TILDE,
48 | DOT,
49 | NOTEQUAL,
50 | PERCENTEQUAL,
51 | AMPEREQUAL,
52 | DOUBLESTAREQUAL,
53 | STAREQUAL,
54 | PLUSEQUAL,
55 | MINEQUAL,
56 | DOUBLESLASHEQUAL,
57 | SLASHEQUAL,
58 | LEFTSHIFTEQUAL,
59 | LESSEQUAL,
60 | EQUAL,
61 | EQEQUAL,
62 | GREATEREQUAL,
63 | RIGHTSHIFTEQUAL,
64 | ATEQUAL,
65 | CIRCUMFLEXEQUAL,
66 | VBAREQUAL,
67 | DOUBLESTAR,
68 | )
69 |
70 |
# Operator classes used by the recursive-descent parser below. Members are the
# tokenizer's exact_type values, which is what Token.type holds.

# Binary operators at additive precedence: +  -
ADDOP = {MINUS, PLUS}

# Binary operators at multiplicative precedence: *  /  @  %   (// is not included)
MULOP = {AT, PERCENT, SLASH, STAR}

# Operators that assignment_or_return_or_expr() treats as the operator of an
# assignment-like statement `lhs OP rhs` and turns into an Assign node.
# NOTE(review): besides plain and augmented assignment, this set also holds
# the comparisons == != <= >= (but not < or >), so `a == b` is parsed as an
# Assign node. Confirm that this is intentional.
ASSIGNOP = {
    # plain assignment
    EQUAL,
    # augmented assignment
    PLUSEQUAL, MINEQUAL, STAREQUAL, SLASHEQUAL, DOUBLESLASHEQUAL,
    PERCENTEQUAL, DOUBLESTAREQUAL, ATEQUAL,
    AMPEREQUAL, VBAREQUAL, CIRCUMFLEXEQUAL,
    LEFTSHIFTEQUAL, RIGHTSHIFTEQUAL,
    # comparison operators (see note above)
    EQEQUAL, NOTEQUAL, LESSEQUAL, GREATEREQUAL,
}

# Prefix unary operators: ~   (unary + and - are not handled)
UNARYOP = {TILDE}
94 |
class Token:
    """
    Lightweight token record holding the fields of Python's TokenInfo that the
    parser needs, plus the token's position among the tokens we kept.
    """
    def __init__(self, type, value,
                 index,       # position in the filtered token list
                 cstart_idx,  # column of the first character
                 cstop_idx,   # column one past the last char, so text[cstart_idx:cstop_idx] works
                 line):       # line number as reported by the tokenizer
        self.type = type
        self.value = value
        self.index = index
        self.cstart_idx = cstart_idx
        self.cstop_idx = cstop_idx
        self.line = line

    def __repr__(self):
        kind = token.tok_name[self.type]
        return f"<{kind}:{self.value},{self.cstart_idx}:{self.cstop_idx}>"

    def __str__(self):
        return self.value
108 |
109 |
def mytokenize(s):
    """
    Use Python's tokenizer to lex s and collect my own Token objects.

    Only NUMBER, STRING, NAME, OP, and ENDMARKER tokens are kept; all other
    tokens the tokenizer emits are dropped. Each Token records the tokenizer's
    exact_type, so operators get specific types such as PLUS or LPAR rather
    than the generic OP. Token.index numbers the kept tokens 0, 1, 2, ...

    The result always ends with an ENDMARKER token whose value is "".

    If the tokenizer raises TokenError because s stops in the middle of a
    statement, the tokens lexed so far are kept and a synthetic ENDMARKER is
    appended. A typical case is s being only the first line of a call split
    across lines, such as "y = f(a,". The truncated statement then surfaces
    as a parser SyntaxError, which PyExprParser can hush, instead of a
    TokenError escaping from PyExprParser's constructor.
    """
    from tokenize import TokenError  # not among the names imported at the top of this module
    keep = {NUMBER, STRING, NAME, OP, ENDMARKER}
    tokens = []
    try:
        for tok in tokenize(BytesIO(s.encode('utf-8')).readline):
            if tok.type not in keep:
                continue  # e.g. encoding marker, NEWLINE, comments
            tokens.append(Token(tok.exact_type, tok.string, len(tokens),
                                tok.start[1],   # char start column
                                tok.end[1],     # one past end column
                                tok.start[0]))  # line number
    except TokenError:
        # Input ended mid-statement (unclosed bracket, trailing backslash,
        # unterminated triple-quoted string, ...). Terminate the stream here
        # and let the parser report the truncation.
        col = tokens[-1].cstop_idx if tokens else 0
        lineno = tokens[-1].line if tokens else 1
        tokens.append(Token(ENDMARKER, "", len(tokens), col, col, lineno))
    # The stream ends with ENDMARKER; make sure its text is empty.
    tokens[-1].value = ""
    return tokens
130 |
131 |
class PyExprParser:
    """
    A recursive-descent parser for a subset of Python expressions and assignments.
    There is a built-in parser, but I only want to process Python code this library
    can handle and I also want my own kind of abstract syntax tree. Consequently,
    it's easier if I just parse the code I care about and ignore everything else.
    Building this parser was certainly no great burden.

    Grammar, from loosest to tightest binding. As in Python, unary ~ binds less
    tightly than ** on its left and more tightly on its right, so ~x**y is
    ~(x**y) and 2**~x is 2**(~x):

        statement : 'return' exprlist
                  | expression (ASSIGNOP expression)?
        expression: addexpr
        addexpr   : multexpr (ADDOP multexpr)*
        multexpr  : unaryexpr (MULOP unaryexpr)*
        unaryexpr : UNARYOP unaryexpr | powexpr
        powexpr   : postexpr ('**' unaryexpr)?
        postexpr  : atom ('(' arglist? ')' | '[' exprlist ']' | '.' NAME)*
        atom      : '(' exprlist ','? ')' | '[' exprlist ','? ']'
                  | NUMBER | NAME | STRING | ':'
        arglist   : arg (',' arg)* ','?
        arg       : NAME '=' expression | expression
    """
    def __init__(self, code:str, hush_errors=True):
        self.code = code
        self.hush_errors = hush_errors  # True: parse() returns None rather than raising SyntaxError
        self.tokens = mytokenize(code)
        self.t = 0  # current lookahead: index into self.tokens of LT(1)

    def parse(self):
        """
        Parse self.code as an assignment, return statement, or expression and
        return the root of its tree. Returns None if the statement starts with
        a keyword other than 'return' (if, for, def, ...), or, when hush_errors
        is set, if parsing raised SyntaxError.
        """
        root = None
        first = self.tokens[0].value
        if first=='return' or not keyword.iskeyword(first):
            if self.hush_errors:
                try:
                    root = self.assignment_or_return_or_expr()
                    self.match(ENDMARKER)
                except SyntaxError:
                    root = None
            else:
                root = self.assignment_or_return_or_expr()
                self.match(ENDMARKER)
        return root

    def assignment_or_return_or_expr(self):
        "statement : 'return' exprlist | expression (ASSIGNOP expression)?"
        start = self.LT(1)
        if self.LA(1)==NAME and self.LT(1).value=='return':
            self.match(NAME)
            r = self.exprlist()
            stop = self.LT(-1)
            return tsensor.ast.Return(self,r,start,stop)
        lhs = self.expression()
        if self.LA(1) in ASSIGNOP:
            eq = self.LT(1)
            self.t += 1
            rhs = self.expression()
            stop = self.LT(-1)
            return tsensor.ast.Assign(self,eq,lhs,rhs,start,stop)
        return lhs

    def expression(self):
        return self.addexpr()

    def addexpr(self):
        "Left-associative chain of + and - operators."
        start = self.LT(1)
        root = self.multexpr()
        while self.LA(1) in ADDOP:
            op = self.LT(1)
            self.t += 1
            b = self.multexpr()
            stop = self.LT(-1)
            root = tsensor.ast.BinaryOp(self, op, root, b, start, stop)
        return root

    def multexpr(self):
        "Left-associative chain of *, /, @, and % operators over unary expressions."
        start = self.LT(1)
        root = self.unaryexpr()
        while self.LA(1) in MULOP:
            op = self.LT(1)
            self.t += 1
            b = self.unaryexpr()
            stop = self.LT(-1)
            root = tsensor.ast.BinaryOp(self, op, root, b, start, stop)
        return root

    def unaryexpr(self):
        "Prefix unary operator applied to a unary expression, else a power expression."
        if self.LA(1) in UNARYOP:
            start = op = self.LT(1)
            self.t += 1
            e = self.unaryexpr()
            stop = self.LT(-1)
            return tsensor.ast.UnaryOp(self, op, e, start, stop)
        return self.powexpr()

    def powexpr(self):
        "Postfix expression optionally raised to a power: postexpr ('**' unaryexpr)?"
        if not (self.isatom() or self.isgroup()):
            self.error(f"missing unary expr at: {self.LT(1)}")
        start = self.LT(1)
        root = self.postexpr()
        if self.LA(1)==DOUBLESTAR:
            op = self.match(DOUBLESTAR)
            # The right operand may carry its own unary operator (2 ** ~x).
            # Recursing through unaryexpr() also makes ** right-associative.
            r = self.unaryexpr()
            stop = self.LT(-1)
            root = tsensor.ast.BinaryOp(self, op, root, r, start, stop)
        return root

    def postexpr(self):
        "An atom followed by any number of calls f(...), indexes a[...], and member accesses a.b."
        start = self.LT(1)
        root = self.atom()
        while self.LA(1) in {LPAR, LSQB, DOT}:
            if self.LA(1)==LPAR:
                lp = self.LT(1)
                self.match(LPAR)
                el = []
                if self.LA(1) != RPAR:
                    el = self.arglist()
                self.match(RPAR)
                stop = self.LT(-1)
                root = tsensor.ast.Call(self, root, lp, el, start, stop)
            if self.LA(1)==LSQB:
                lb = self.LT(1)
                self.match(LSQB)
                el = self.exprlist()
                self.match(RSQB)
                stop = self.LT(-1)
                root = tsensor.ast.Index(self, root, lb, el, start, stop)
            if self.LA(1)==DOT:
                op = self.match(DOT)
                m = self.match(NAME)
                m = tsensor.ast.Atom(self, m)
                stop = self.LT(-1)
                root = tsensor.ast.Member(self, op, root, m, start, stop)
        return root

    def atom(self):
        if self.LA(1) == LPAR:
            return self.subexpr()
        elif self.LA(1) == LSQB:
            return self.listatom()
        elif self.LA(1) in {NUMBER, NAME, STRING, COLON}:
            atom = self.LT(1)
            self.t += 1
            return tsensor.ast.Atom(self, atom)
        else:
            self.error("unknown or missing atom:"+str(self.LT(1)))

    def exprlist(self):
        """
        One or more comma-separated expressions. Stops before a trailing comma
        that is followed by ')' or ']' so that the caller can consume it, as
        in (3,4,) or [3,4,].
        """
        elist = []
        e = self.expression()
        elist.append(e)
        while self.LA(1)==COMMA and self.LA(2) not in {RPAR, RSQB}:
            self.match(COMMA)
            e = self.expression()
            elist.append(e)
        return elist

    def arglist(self):
        "Comma-separated call arguments, allowing a trailing comma as in f(a, b,)."
        elist = [self._arg_or_expr()]
        while self.LA(1)==COMMA:
            self.match(COMMA)
            if self.LA(1)==RPAR:  # trailing comma; postexpr() matches the ')'
                break
            elist.append(self._arg_or_expr())
        return elist

    def _arg_or_expr(self):
        "One call argument: keyword argument NAME=expression or a plain expression."
        if self.LA(1)==NAME and self.LA(2)==EQUAL:
            return self.arg()
        return self.expression()

    def arg(self):
        "Keyword argument NAME=expression, represented as an Assign node."
        start = self.LT(1)
        kwarg = self.match(NAME)
        eq = self.match(EQUAL)
        e = self.expression()
        kwarg = tsensor.ast.Atom(self, kwarg)
        stop = self.LT(-1)
        return tsensor.ast.Assign(self, eq, kwarg, e, start, stop)

    def subexpr(self):
        start = self.match(LPAR)
        e = self.exprlist() # could be a tuple like (3,4) or even (3,4,)
        istuple = len(e)>1
        if self.LA(1)==COMMA:
            self.match(COMMA)
            istuple = True
        stop = self.match(RPAR)
        if istuple:
            return tsensor.ast.TupleLiteral(self, e, start, stop)
        subexpr = e[0]
        # Parentheses just alter the precedence and don't actually indicate an operator
        # so we just pass the sub expression through (if not a tuple)
        return subexpr

    def listatom(self):
        "List literal [e1, e2, ...], allowing a trailing comma as in [3,4,]."
        start = self.LT(1)
        self.match(LSQB)
        e = self.exprlist()
        if self.LA(1)==COMMA:  # trailing comma left by exprlist()
            self.match(COMMA)
        self.match(RSQB)
        stop = self.LT(-1)
        return tsensor.ast.ListLiteral(self, e, start, stop)

    def isatom(self):
        return self.LA(1) in {NUMBER, NAME, STRING, COLON}

    def isgroup(self):
        return self.LA(1)==LPAR or self.LA(1)==LSQB

    def LA(self, i):
        "Token type of lookahead token i (see LT)."
        return self.LT(i).type

    def LT(self, i):
        """
        Lookahead token i: LT(1) is the current token, LT(-1) the previous one,
        LT(0) is None. Looking past the end yields the final ENDMARKER.
        """
        if i==0:
            return None
        if i<0:
            return self.tokens[self.t + i] # -1 should give prev token
        ahead = self.t + i - 1
        if ahead >= len(self.tokens):
            return self.tokens[-1] # return last (end marker)
        return self.tokens[ahead]

    def match(self, type):
        "Consume and return the current token, raising SyntaxError if it is not of the given type."
        if self.LA(1)!=type:
            self.error(f"mismatch token {self.LT(1)}, looking for {token.tok_name[type]}")
        tok = self.LT(1)
        self.t += 1
        return tok

    def error(self, msg):
        raise SyntaxError(msg)
354 |
355 |
def parse(statement:str, hush_errors=True):
    """
    Parse statement and return a (root, tokens) pair: the abstract syntax tree
    and the list of Token objects lexed from the statement.

    root is None if the statement starts with a keyword other than 'return'.
    Parsing errors from invalid code, or from code that this parser cannot
    handle, are ignored (root is None) unless hush_errors is False, in which
    case the SyntaxError propagates.
    """
    p = PyExprParser(statement, hush_errors=hush_errors)
    return p.parse(), p.tokens
363 |
--------------------------------------------------------------------------------
/images/ast.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
6 |
7 |
275 |
--------------------------------------------------------------------------------
/images/mm.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
5 |
760 |
--------------------------------------------------------------------------------
/tsensor/analysis.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | import os
25 | import sys
26 | import traceback
27 | import inspect
28 | import hashlib
29 | from pathlib import Path
30 |
31 | import matplotlib.pyplot as plt
32 |
33 | import tsensor
34 |
35 |
class clarify:
    # Prevent nested clarify() calls from processing exceptions.
    # See https://github.com/parrt/tensor-sensor/issues/18
    # This counter is a class variable shared by every thread, so it can
    # miscount if several `threading` threads call clarify() at once.
    # Multiprocessing forks new processes, each with its own copy, so that is
    # not a problem. Incremented in __enter__, decremented in __exit__.
    nesting = 0

    def __init__(self,
                 fontname=('Consolas', 'DejaVu Sans Mono'), fontsize=13,
                 dimfontname='Arial', dimfontsize=9, char_sep_scale=1.8, fontcolor='#444443',
                 underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
                 show:(None,'viz')='viz',
                 hush_errors=True,
                 dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        """
        Augment tensor-related exceptions generated from numpy, pytorch, and tensorflow.
        Also display a visual representation of the offending Python line that
        shows the shape of tensors referenced by the code. All you have to do is wrap
        the outermost level of your code and clarify() will activate upon exception.

        Visualizations pop up in a separate window unless running from a notebook,
        in which case the visualization appears as part of the cell execution output.

        There is no runtime overhead associated with clarify() unless an exception occurs.

        The offending code is executed a second time, to identify which sub expressions
        are to blame. This implies that code with side effects could conceivably cause
        a problem, but since an exception has been generated, results are suspicious
        anyway.

        clarify() blocks may be nested; only the outermost one processes the exception.

        Example:

        import numpy as np
        import tsensor

        b = np.array([9, 10]).reshape(2, 1)
        with tsensor.clarify():
            np.dot(b,b) # tensor code or call to a function with tensor code

        See examples.ipynb for more examples.

        :param fontname: The name of the font used to display Python code
        :param fontsize: The font size used to display Python code; default is 13.
                         Also use this to increase the size of the generated figure;
                         larger font size means larger image.
        :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes
        :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes
        :param char_sep_scale: It is notoriously difficult to discover how wide and tall
                               text is when plotted in matplotlib. In fact there's probably
                               no hope to discover this information accurately in all cases.
                               Certainly, I gave up after spending huge effort. We have a
                               situation here where the font should be constant width, so
                               we can just use a simple scalar times the font size to get
                               a reasonable approximation to the width and height of a
                               character box; the default of 1.8 seems to work reasonably
                               well for a wide range of fonts, but you might have to tweak it
                               when you change the font size.
        :param fontcolor: The color of the Python code.
        :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey
        :param ignored_color: The de-highlighted color for deemphasizing code not involved in an erroneous sub expression
        :param error_op_color: The color to use for characters associated with the erroneous operator
        :param show: Show visualization upon tensor error if show='viz'; with None the
                     exception message is still augmented but nothing is displayed.
        :param hush_errors: Normally, error messages from true syntax errors but also
                            unhandled code caught by my parser are ignored. Turn this off
                            to see what the error messages are coming from my parser.
        :param dtype_colors: map from dtype w/o precision like 'int' to color
        :param dtype_precisions: list of bit precisions to colorize, such as [32,64,128]
        :param dtype_alpha_range: all tensors of the same type are drawn to the same color,
                                  and the alpha channel is used to show precision; the
                                  smaller the bit size, the lower the alpha channel. You
                                  can play with the range to get better visual dynamic range
                                  depending on how many precisions you want to display.
        """
        self.show = show
        self.fontname = fontname
        self.fontsize = fontsize
        self.dimfontname = dimfontname
        self.dimfontsize = dimfontsize
        self.char_sep_scale = char_sep_scale
        self.fontcolor = fontcolor
        self.underline_color = underline_color
        self.ignored_color = ignored_color
        self.error_op_color = error_op_color
        self.hush_errors = hush_errors
        self.dtype_colors = dtype_colors
        self.dtype_precisions = dtype_precisions
        self.dtype_alpha_range = dtype_alpha_range

    def __enter__(self):
        # Frame containing the `with` statement; recorded but not otherwise
        # used by this class.
        self.frame = sys._getframe().f_back
        clarify.nesting += 1
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        clarify.nesting -= 1
        if clarify.nesting>0:
            return  # an enclosing clarify() will process the exception
        if exc_type is None:
            return
        exc_frame, lib_entry_frame = tensor_lib_entry_frame(exc_traceback)
        if lib_entry_frame is not None or is_interesting_exception(exc_value):
            module, name, filename, line, code = info(exc_frame)
            if code is not None:
                self.view = tsensor.viz.pyviz(code, exc_frame,
                                              self.fontname, self.fontsize, self.dimfontname,
                                              self.dimfontsize,
                                              self.char_sep_scale, self.fontcolor,
                                              self.underline_color, self.ignored_color,
                                              self.error_op_color,
                                              hush_errors=self.hush_errors,
                                              dtype_colors=self.dtype_colors,
                                              dtype_precisions=self.dtype_precisions,
                                              dtype_alpha_range=self.dtype_alpha_range)
                if self.view is not None: # Ignore if we can't process code causing exception (I use a subparser)
                    if self.show=='viz':
                        self.view.show()
                    augment_exception(exc_value, self.view.offending_expr)
        # Returning None lets the (possibly augmented) exception propagate.
162 |
163 |
class explain:
    def __init__(self,
                 fontname=('Consolas', 'DejaVu Sans Mono'), fontsize=13,
                 dimfontname='Arial', dimfontsize=9, char_sep_scale=1.8, fontcolor='#444443',
                 underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
                 savefig=None, hush_errors=True,
                 dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        """
        As the Python virtual machine executes lines of code, generate a
        visualization for tensor-related expressions using numpy, pytorch,
        and tensorflow. The shape of tensors referenced by the code are displayed.

        Visualizations pop up in a separate window unless running from a notebook,
        in which case the visualization appears as part of the cell execution output.

        There is heavy runtime overhead associated with explain() as every line
        is executed twice: once by explain() and then another time by the interpreter
        as part of normal execution.

        Expressions with side effects can easily generate incorrect results. Due to
        this and the overhead, you should limit the use of this to code you're trying
        to debug.  Assignments are not evaluated by explain so code `x = ...` causes
        an assignment to x just once, during normal execution. This explainer
        knows the value of x and will display it but does not assign to it.

        Upon exception, execution will stop as usual but, like clarify(), explain()
        will augment the exception to indicate the offending sub expression. Further,
        the visualization will deemphasize code not associated with the offending
        sub expression. The sizes of relevant tensor values are still visualized.

        Example:

        import numpy as np
        import tsensor

        b = np.array([9, 10]).reshape(2, 1)
        with tsensor.explain():
            b + b # tensor code or call to a function with tensor code

        See examples.ipynb for more examples.

        :param fontname: The name of the font used to display Python code
        :param fontsize: The font size used to display Python code; default is 13.
                         Also use this to increase the size of the generated figure;
                         larger font size means larger image.
        :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes
        :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes
        :param char_sep_scale: It is notoriously difficult to discover how wide and tall
                               text is when plotted in matplotlib. In fact there's probably
                               no hope to discover this information accurately in all cases.
                               Certainly, I gave up after spending huge effort. We have a
                               situation here where the font should be constant width, so
                               we can just use a simple scalar times the font size to get
                               a reasonable approximation to the width and height of a
                               character box; the default of 1.8 seems to work reasonably
                               well for a wide range of fonts, but you might have to tweak it
                               when you change the font size.
        :param fontcolor: The color of the Python code.
        :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey
        :param ignored_color: The de-highlighted color for deemphasizing code not involved in an erroneous sub expression
        :param error_op_color: The color to use for characters associated with the erroneous operator
        :param hush_errors: Normally, error messages from true syntax errors but also
                            unhandled code caught by my parser are ignored. Turn this off
                            to see what the error messages are coming from my parser.
        :param savefig: A string indicating where to save the visualization; don't save
                        a file if None. Each visualized statement n is saved as
                        "{stem}-{n}{suffix}" next to that path.
        :param dtype_colors: map from dtype w/o precision like 'int' to color
        :param dtype_precisions: list of bit precisions to colorize, such as [32,64,128]
        :param dtype_alpha_range: all tensors of the same type are drawn to the same color,
                                  and the alpha channel is used to show precision; the
                                  smaller the bit size, the lower the alpha channel. You
                                  can play with the range to get better visual dynamic range
                                  depending on how many precisions you want to display.
        """
        self.savefig = savefig
        self.fontname = fontname
        self.fontsize = fontsize
        self.dimfontname = dimfontname
        self.dimfontsize = dimfontsize
        self.char_sep_scale = char_sep_scale
        self.fontcolor = fontcolor
        self.underline_color = underline_color
        self.ignored_color = ignored_color
        self.error_op_color = error_op_color
        self.hush_errors = hush_errors
        self.dtype_colors = dtype_colors
        self.dtype_precisions = dtype_precisions
        self.dtype_alpha_range = dtype_alpha_range

    def __enter__(self):
        self.tracer = ExplainTensorTracer(self)
        sys.settrace(self.tracer.listener)
        frame = sys._getframe()
        prev = frame.f_back # get block wrapped in "with"
        # settrace() only affects frames entered later, so also install the
        # listener on the already-running frame that holds the `with` block.
        prev.f_trace = self.tracer.listener
        return self.tracer

    def __exit__(self, exc_type, exc_value, exc_traceback):
        sys.settrace(None)
        # At this point we have already tried to visualize the statement
        # If there was no error, the visualization will look normal
        # but a matrix operation error will show the erroneous operator highlighted.
        # That was artificial execution of the code. Now the VM has executed
        # the statement for real and has found the same exception. Make sure to
        # augment the message with causal information.
        if exc_type is None:
            return
        exc_frame, lib_entry_frame = tensor_lib_entry_frame(exc_traceback)
        if lib_entry_frame is None and not is_interesting_exception(exc_value):
            return
        module, name, filename, line, code = info(exc_frame)
        if code is None:
            return
        # We've already displayed picture so just augment message
        root, _ = tsensor.parsing.parse(code)
        if root is None: # Could be syntax error in statement or code I can't handle
            return
        offending_expr = None
        try:
            root.eval(exc_frame)
        except tsensor.ast.IncrEvalTrap as e:
            offending_expr = e.offending_expr
        # If re-evaluation did not trap, we cannot blame any subexpression;
        # leave the original exception message alone.
        if offending_expr is not None:
            augment_exception(exc_value, offending_expr)
290 |
291 |
class ExplainTensorTracer:
    """
    Trace listener installed by explain(): for each distinct statement executed
    in the traced frame that the tensor parser accepts, build a visualization
    and either show it or save it to a file.
    """
    def __init__(self, explainer):
        self.explainer = explainer  # the explain() object holding display options
        self.exceptions = set()     # not read or written elsewhere in this class
        self.linecount = 0          # statements visualized so far; numbers saved files
        self.views = []             # visualizations created, in execution order
        # Hashes (see hash()) of statement text already visited, so each
        # distinct statement text is visualized at most once.
        self.done = set()

    def listener(self, frame, event, arg):
        if event!='line':
            # It seems that we are getting CALL events even for calls in foo() from:
            #   with tsensor.explain(): foo()
            # Must be that we already have a listener and, though we returned None here,
            # somehow the original listener is still getting events. Strange but oh well.
            # We must ignore these.
            return None
        module = frame.f_globals.get('__name__')
        info = inspect.getframeinfo(frame)
        self.line_listener(module, info.function, info.filename, info.lineno, info, frame)
        # By returning none, we prevent explain()'ing from descending into
        # invoked functions. In principle, we could allow a certain amount
        # of tracing but I'm not sure that would be super useful.
        return None

    def line_listener(self, module, name, filename, line, info, frame):
        "Visualize the statement described by info unless it was already handled."
        if not info.code_context:
            return  # inspect could not retrieve the source line
        code = info.code_context[0].strip()
        if code.startswith("sys.settrace(None)"):
            return

        # Don't generate a statement visualization more than once
        h = self.hash(code)
        if h in self.done:
            return
        self.done.add(h)

        p = tsensor.parsing.PyExprParser(code)
        t = p.parse()
        if t is not None:
            self.linecount += 1
            self.viz_statement(code, frame)

    def viz_statement(self, code, frame):
        """
        Build the visualization of code evaluated in frame, then save it (if the
        explainer has a savefig path) or show it. Returns the view, or None if
        pyviz produced no view.
        """
        ex = self.explainer
        view = tsensor.viz.pyviz(code, frame,
                                 ex.fontname, ex.fontsize,
                                 ex.dimfontname,
                                 ex.dimfontsize,
                                 ex.char_sep_scale, ex.fontcolor,
                                 ex.underline_color, ex.ignored_color,
                                 ex.error_op_color,
                                 hush_errors=ex.hush_errors,
                                 dtype_colors=ex.dtype_colors,
                                 dtype_precisions=ex.dtype_precisions,
                                 dtype_alpha_range=ex.dtype_alpha_range)
        if view is None:  # pyviz could not process this statement; nothing to draw
            return None
        self.views.append(view)
        if ex.savefig is not None:
            file_path = Path(ex.savefig)
            file_path = file_path.parent / f"{file_path.stem}-{self.linecount}{file_path.suffix}"
            view.savefig(file_path)
            view.filename = file_path
            plt.close()
        else:
            view.show()
        return view

    @staticmethod
    def hash(statement):
        """
        We want to avoid generating a visualization more than once.
        For now, assume that the code for a statement is the unique identifier.
        """
        return hashlib.md5(statement.encode('utf-8')).hexdigest()
375 |
376 |
def eval(statement:str, frame=None) -> (tsensor.ast.ParseTreeNode, object):
    """
    Parse statement and return an ast in the context of execution frame or, if None,
    the invoking function's frame. Set the value field of all ast nodes.
    Overall result is in root.value.
    :param statement: A string representing the line of Python code to visualize within an execution frame.
    :param frame: The execution frame in which to evaluate the statement. If None,
                  use the execution frame of the invoking function
    :return An abstract parse tree representing the statement; nodes are
            ParseTreeNode subclasses. Returned as the pair (root, root.value).
    :raises SyntaxError: if the statement cannot be parsed by tsensor's parser
                         (including statements that start with a keyword other
                         than 'return').
    """
    if frame is None: # use frame of caller
        frame = sys._getframe().f_back
    # Don't hush parse errors: a parse failure would otherwise surface only as
    # an opaque AttributeError on None below.
    root = tsensor.parsing.PyExprParser(statement, hush_errors=False).parse()
    if root is None:
        raise SyntaxError(f"tsensor cannot parse statement: {statement}")
    root.eval(frame)
    return root, root.value
394 |
395 |
def augment_exception(exc_value, subexpr):
    """
    Append the explanation from subexpr.clarify() to the message of exc_value,
    modifying the exception object in place.

    :param exc_value: the exception being propagated
    :param subexpr: tree node blamed for the exception, or None if unknown.
                    Nothing is changed when subexpr is None or its clarify()
                    returns None.
    """
    if subexpr is None:
        return
    explanation = subexpr.clarify()
    if explanation is None:
        return  # nothing to add
    # Reuse exception but overwrite the message
    if hasattr(exc_value, "_message"):
        # Exceptions with a private _message (presumably tensorflow's) expose
        # the text through .message.
        exc_value._message = exc_value.message + "\n" + explanation
    elif exc_value.args:
        # str(): args[0] need not be a string, e.g. KeyError(3)
        exc_value.args = [f"{exc_value.args[0]}\n{explanation}"]
    else:
        exc_value.args = [explanation]
406 |
407 |
def is_interesting_exception(e):
    """
    Guess whether exception e is a tensor-related error worth explaining.
    True if its class is defined in a tensorflow module, or if its message
    contains one of several tensor-related sentinel phrases.

    The message is e.args[0]; if there are no args, e.message is used when
    present (else ""). Non-string messages are converted with str().
    """
    if e.__class__.__module__.startswith("tensorflow"):
        return True
    sentinels = {'matmul', 'THTensorMath', 'tensor', 'tensors', 'dimension',
                 'not aligned', 'size mismatch', 'shape', 'shapes', 'matrix',
                 'call to _th_addmm'}
    if e.args:
        msg = e.args[0]
    else:
        msg = getattr(e, "message", "")
    msg = str(msg)  # args[0] need not be a string, e.g. KeyError(3)
    return any(s in msg for s in sentinels)
420 |
421 |
def tensor_lib_entry_frame(exc_traceback):
    """
    Find where user code handed control to a tensor library (numpy, torch,
    tensorflow, jax), so we do not trace into library internals and can point
    at the user's line that asked for the invalid operation.

    A frame belongs to a library when its code filename contains
    "site-packages/{package}" or "dist-packages/{package}" (using the
    platform's path separator), or "<__array_function__" (numpy seems to not
    have a real filename there).

    Walks the traceback from the outermost entry inward and returns
    (last-user-frame, first-tensor-lib-frame) at the first library frame, or
    (innermost-frame, None) when no library frame is found.

    Note: Sometimes operators yield exceptions and no tensor lib entry frame,
    e.g. np.ones(1) @ np.ones(2).
    """
    packages = ('numpy', 'torch', 'tensorflow', 'jax')
    markers = [os.path.join(root, pkg)
               for root in ('site-packages', 'dist-packages')
               for pkg in packages]
    markers.append('<__array_function__')

    last_user = exc_traceback
    entry = exc_traceback
    while entry is not None:
        source_file = entry.tb_frame.f_code.co_filename
        if any(marker in source_file for marker in markers):
            return last_user.tb_frame, entry.tb_frame
        last_user, entry = entry, entry.tb_next
    return last_user.tb_frame, None
453 |
454 |
def info(frame):
    """
    Describe where frame is currently executing.

    :return: tuple (module, name, filename, line, code):
             module   - the frame's global __name__, or None if absent
             name     - name of the function running in the frame
             filename - source file name of that code
             line     - current line number
             code     - stripped source text of the current line, or None if
                        inspect cannot retrieve the source
    """
    # Frame objects have no __name__ attribute of their own; the module name
    # lives in the frame's globals.
    module = frame.f_globals.get('__name__')
    frame_info = inspect.getframeinfo(frame)
    if frame_info.code_context:
        code = frame_info.code_context[0].strip()
    else:
        code = None
    return module, frame_info.function, frame_info.filename, frame_info.lineno, code
468 |
469 |
def smallest_matrix_subexpr(t):
    """
    Collect the nodes to visualize: the smallest subexpressions of tree t
    that evaluate to a non-scalar (tensor) value.

    Those are the deepest subtrees whose value is a tensor while no node
    below them has a tensor value. Lacking parent pointers, we cannot start
    at the leaves and walk upward. Instead, the recursive worker reports to
    its caller whether a subtree contains any tensor value, and records the
    qualifying nodes in a list along the way.

    :return: list of qualifying nodes, in the order the recursion found them
    """
    found = []
    _smallest_matrix_subexpr(t, found)
    return found
483 |
484 |
def _smallest_matrix_subexpr(t, nodes) -> bool:
    """
    Recursive helper for smallest_matrix_subexpr().

    Visits the subtree rooted at t depth-first. It appends to ``nodes`` every node
    that evaluates to a tensor while none of its descendants do. An ``x.T``
    member access on plain atoms is recorded as a unit without visiting its
    children.

    :return: True if t or any node beneath it evaluates to a tensor.
    """
    if t is None:
        return False  # prevent buggy code from causing us to fail

    member_cls, atom_cls = tsensor.ast.Member, tsensor.ast.Atom
    if (isinstance(t, member_cls) and isinstance(t.obj, atom_cls)
            and isinstance(t.member, atom_cls) and str(t.member) == 'T'):
        nodes.append(t)
        return True

    self_is_tensor = istensor(t.value)
    if len(t.kids) == 0:  # leaf node
        if self_is_tensor:
            nodes.append(t)
        return self_is_tensor

    # Visit every child (no short-circuit: each call may append nodes).
    child_results = [_smallest_matrix_subexpr(kid, nodes) for kid in t.kids]
    tensor_below = any(child_results)
    # Once a tensor is found below, that fact propagates all the way to the
    # root, so only the deepest tensor-valued nodes are recorded.
    if self_is_tensor and not tensor_below:
        nodes.append(t)
    return self_is_tensor or tensor_below
507 |
508 |
def istensor(x):
    """Return True when _shape() reports a shape for x, i.e., x is treated as a tensor."""
    shape = _shape(x)
    return shape is not None
511 |
512 |
513 | def _dtype(v) -> str:
514 | if hasattr(v, "dtype"):
515 | dtype = v.dtype
516 | elif "dtype" in v.__class__.__name__:
517 | dtype = v
518 | else:
519 | return None
520 |
521 | if dtype.__class__.__module__ == "torch":
522 | # ugly but works
523 | return str(dtype).replace("torch.", "")
524 | if hasattr(dtype, "names") and dtype.names is not None and hasattr(dtype, "fields"):
525 | # structured dtype: https://numpy.org/devdocs/user/basics.rec.html
526 | return ",".join([_dtype(val) for val, _ in dtype.fields.values()])
527 | return dtype.name
528 |
529 |
530 | def _shape(v):
531 | # do we have a shape and it answers len()? Should get stuff right.
532 | if hasattr(v, "shape") and hasattr(v.shape, "__len__"):
533 | if v.shape.__class__.__module__ == "torch" and v.shape.__class__.__name__ == "Size":
534 | if len(v.shape)==0:
535 | return None
536 | return list(v.shape)
537 | return v.shape
538 | return None
539 |
--------------------------------------------------------------------------------
/tsensor/viz.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | import sys
25 | import os
26 | from pathlib import Path
27 | import tempfile
28 | import graphviz
29 | import graphviz.backend
30 | import token
31 | import matplotlib.patches as patches
32 | import matplotlib.pyplot as plt
33 | import matplotlib.colors as mc
34 | from IPython.display import display, SVG
35 | from IPython import get_ipython
36 |
37 | import numpy as np
38 | import tsensor
39 | import tsensor.ast
40 | import tsensor.analysis
41 | import tsensor.parsing
42 |
43 |
class DTypeColorInfo:
    """
    Track the colors for various types, the transparency range, and bit precisions.
    By default, green indicates floating-point, blue indicates integer, and orange
    indicates complex numbers. The more saturated the color (lower transparency),
    the higher the precision. A dtype falls back to the 'other' color when its
    family is not in the color map or when its bit precision is missing or not
    listed.
    """
    orangeish = '#FDD66C'
    limeish = '#A8E1B0'
    blueish = '#7FA4D3'
    grey = '#EFEFF0'
    default_dtype_colors = {'float': limeish, 'int': blueish, 'complex': orangeish, 'other': grey}
    default_dtype_precisions = [32, 64, 128]  # hard to see diff if we use [4, 8, 16, 32, 64, 128]
    default_dtype_alpha_range = (0.5, 1.0)  # use (0.1, 1.0) if more precision values

    def __init__(self, dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        """
        :param dtype_colors: dict mapping a dtype family without precision (e.g.,
                             'float') to a matplotlib color name or hex RGB
                             string. Key 'other' is the fallback color; grey
                             is used if it is absent.
        :param dtype_precisions: list of bit precisions to shade, e.g. [32, 64, 128];
                                 earlier entries get lower alpha.
        :param dtype_alpha_range: (min_alpha, max_alpha) spread across the precisions.
        :raises TypeError: if dtype_colors is not a dict of color strings.
        """
        if dtype_colors is None:
            dtype_colors = DTypeColorInfo.default_dtype_colors
        if dtype_precisions is None:
            dtype_precisions = DTypeColorInfo.default_dtype_precisions
        if dtype_alpha_range is None:
            dtype_alpha_range = DTypeColorInfo.default_dtype_alpha_range

        if not isinstance(dtype_colors, dict) or (len(dtype_colors) > 0 and
                                                  not isinstance(list(dtype_colors.values())[0], str)):
            raise TypeError(
                "dtype_colors should be a dict mapping type name to color name or color hex RGB values."
            )

        self.dtype_colors, self.dtype_precisions, self.dtype_alpha_range = \
            dtype_colors, dtype_precisions, dtype_alpha_range

    def color(self, dtype):
        """
        Get the color for a dtype string such as 'float32' or 'int64'.

        :return: [r, g, b, alpha] with alpha chosen by the dtype's bit precision,
                 or the raw 'other' color value when the dtype family or
                 precision isn't tracked.
        """
        other = self.dtype_colors.get('other', DTypeColorInfo.grey)
        # Split the trailing bit count: 'float32' -> ('float', '32'), 'bool' -> ('bool', '')
        dtype_name = dtype.rstrip('0123456789')
        precision_digits = dtype[len(dtype_name):]
        if dtype_name not in self.dtype_colors or not precision_digits:
            return other
        dtype_precision = int(precision_digits)
        if dtype_precision not in self.dtype_precisions:
            return other

        # to_rgb accepts both hex strings and named colors and returns float RGB.
        # (mc.cnames maps names to hex *strings*, which list() would split into chars.)
        rgb = mc.to_rgb(self.dtype_colors[dtype_name])
        precision_idx = self.dtype_precisions.index(dtype_precision)
        alphas = np.linspace(*self.dtype_alpha_range, len(self.dtype_precisions))
        return list(rgb) + [alphas[precision_idx]]
92 |
93 |
class PyVizView:
    """
    An object that collects relevant information about viewing Python code
    with visual annotations. pyviz() creates one, fills in its layout fields,
    and draws into a matplotlib figure whose number is kept in ``fignumber``.
    svg()/savefig()/show() then render that figure.
    """
    def __init__(self, statement, fontname, fontsize, dimfontname, dimfontsize,
                 char_sep_scale, dpi,
                 dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        self.statement = statement
        self.fontsize = fontsize
        self.fontname = fontname
        self.dimfontsize = dimfontsize
        self.dimfontname = dimfontname
        self.char_sep_scale = char_sep_scale
        self.dpi = dpi
        self.dtype_color_info = DTypeColorInfo(dtype_colors, dtype_precisions, dtype_alpha_range)
        self._dtype_encountered = set()  # which types, like 'int64', did we find in one plot?
        self.wchar = self.char_sep_scale * self.fontsize
        self.wchar_small = self.char_sep_scale * (self.fontsize - 2)  # for typenames
        self.hchar = self.char_sep_scale * self.fontsize
        self.dim_ypadding = 5
        self.dim_xpadding = 0
        self.linewidth = .7
        self.leftedge = 25
        self.bottomedge = 3
        self.filename = None  # file holding the most recently saved rendering
        self.matrix_size_scaler = 3.5  # How wide or tall as scaled fontsize is matrix?
        self.vector_size_scaler = 3.2 / 4  # How wide or tall as scaled fontsize is vector for skinny part?
        self.shift3D = 6
        self.cause = None  # Did an exception occur during evaluation?
        self.offending_expr = None
        self.fignumber = None

    @staticmethod
    def _split_dtype_precision(s):
        """Split the final integer part from a string: 'float64' -> ('float', '64')."""
        head = s.rstrip('0123456789')
        tail = s[len(head):]
        return head, tail

    @staticmethod
    def _new_temp_filename(suffix):
        """
        Create an empty temporary file and return its name. Unlike the deprecated
        tempfile.mktemp(), mkstemp() creates the file itself, so the name can't
        be claimed by someone else between choosing and writing it.
        """
        fd, name = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        return name

    def set_locations(self, maxh):
        """
        This function finishes setting up necessary parameters about text
        and graphics locations for the plot. We don't know how to set these
        values until we know what the max height of the drawing will be. We don't know
        what that height is until after we've parsed and so on, which requires that
        we collect and store information in this view object before computing maxh.
        That is why this is a separate function not part of the constructor.
        """
        line2text = self.hchar / 1.7
        box2line = line2text * 2.6
        self.texty = self.bottomedge + maxh + box2line + line2text
        self.liney = self.bottomedge + maxh + box2line
        self.box_topy = self.bottomedge + maxh
        self.maxy = self.texty + 1.4 * self.fontsize

    def _repr_svg_(self):
        "Show an SVG rendition in a notebook"
        return self.svg()

    def svg(self):
        """
        Render as SVG and return the SVG text. The first call saves the figure to
        a temporary .svg file whose name is stored in self.filename; later calls
        reuse that file. Returns None if the view was previously saved in a
        non-SVG format.
        """
        if self.filename is None:  # have we saved before? (i.e., is it cached?)
            self.savefig(self._new_temp_filename('.svg'))
        elif not self.filename.endswith(".svg"):
            return None
        with open(self.filename, encoding='UTF-8') as f:
            return f.read()

    def savefig(self, filename):
        """
        Save viz in format according to file extension.

        While the matplotlib figure is still open it is rendered to filename.
        Once it has been closed, the previously saved file is copied instead.
        Raw bytes can't change format, so the previous extension is kept if
        the new one differs.
        """
        if plt.fignum_exists(self.fignumber):
            # Fetch our figure by number so we don't save whichever figure
            # pyplot considers "current" (the user may have made another one).
            self.filename = filename  # Remember the file so we can pull it back
            plt.figure(self.fignumber).savefig(filename, dpi=self.dpi,
                                               bbox_inches='tight', pad_inches=0)
        elif filename != self.filename:
            stem, ext = os.path.splitext(filename)
            prev_ext = os.path.splitext(self.filename)[1]
            if ext != prev_ext:
                print(f"File extension {ext} differs from previous {prev_ext}; uses previous.")
                # make sure we don't copy raw bits under an inconsistent file extension
                filename = stem + prev_ext
            with open(self.filename, 'rb') as src:
                img = src.read()
            with open(filename, 'wb') as dst:
                dst.write(img)
            self.filename = filename  # overwrite the filename with new name

    def show(self):
        "Display an SVG in a notebook or pop up a window if not in notebook"
        if get_ipython() is None:
            self.savefig(self._new_temp_filename('.svg'))  # savefig records self.filename
            plt.show()
        else:
            display(SVG(self.svg()))
        plt.close(self.fignumber)  # close our figure, not merely the current one

    def boxsize(self, v):
        """
        How wide and tall should we draw the box representing a vector or matrix.
        Returns None when v has no shape.
        """
        sh = tsensor.analysis._shape(v)
        ty = tsensor.analysis._dtype(v)
        if sh is None:
            return None
        if len(sh) == 1:
            return self.vector_size(sh, ty)
        return self.matrix_size(sh, ty)

    def matrix_size(self, sh, ty):
        """
        How wide and tall should we draw the box representing a matrix.
        A dimension of size 1 is drawn skinny, like a vector.
        """
        if len(sh) == 1 and sh[0] == 1:
            return self.vector_size(sh, ty)

        if len(sh) > 1 and sh[0] == 1 and sh[1] == 1:
            # A special case where we have a 1x1 matrix extending into the screen.
            # Make the 1x1 part a little bit wider than a vector so it's more readable
            w, h = 2 * self.vector_size_scaler * self.wchar, 2 * self.vector_size_scaler * self.wchar
        elif len(sh) > 1 and sh[1] == 1:
            w, h = self.vector_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar
        elif len(sh) > 1 and sh[0] == 1:
            w, h = self.matrix_size_scaler * self.wchar, self.vector_size_scaler * self.wchar
        else:
            w, h = self.matrix_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar
        return w, h

    def vector_size(self, sh, ty):
        """
        How wide and tall is a vector? It's not a function of vector length; instead
        we make a row vector with same width as a matrix but height of just one char.
        For consistency with matrix_size(), I pass in shape, though it's ignored.
        """
        return self.matrix_size_scaler * self.wchar, self.vector_size_scaler * self.wchar

    def _add_box(self, ax, x, y, w, h, color):
        """Add one filled, grey-edged rectangle with lower-left corner (x, y)."""
        ax.add_patch(patches.Rectangle(xy=(x, y), width=w, height=h,
                                       linewidth=self.linewidth, facecolor=color,
                                       edgecolor='grey', fill=True))

    def draw(self, ax, sub):
        """Draw the box for subexpression sub; sub.leftx/rightx must already be set."""
        sh = tsensor.analysis._shape(sub.value)
        if sh is None or len(sh) == 0:
            return  # scalars (e.g., 0-dim arrays) have no box to draw
        ty = tsensor.analysis._dtype(sub.value)
        self._dtype_encountered.add(ty)
        if len(sh) == 1:
            self.draw_vector(ax, sub, sh, ty)
        else:
            self.draw_matrix(ax, sub, sh, ty)

    def draw_vector(self, ax, sub, sh, ty: str):
        """Draw a 1D tensor as a flat rectangle with its length above and its dtype below."""
        mid = (sub.leftx + sub.rightx) / 2
        w, h = self.vector_size(sh, ty)
        color = self.dtype_color_info.color(ty)
        self._add_box(ax, mid - w / 2, self.box_topy - h, w, h, color)

        # Vector length above the rectangle
        ax.text(mid, self.box_topy + self.dim_ypadding, self.nabbrev(sh[0]),
                horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize)
        # Type info at the bottom of everything (raw strings: '\m' is not a Python escape)
        ax.text(mid, self.box_topy - self.hchar, r'<${\mathit{' + ty + r'}}$>',
                verticalalignment='top', horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize - 2)

    def draw_matrix(self, ax, sub, sh, ty):
        """
        Draw a tensor with 2+ dimensions as a box. Dim 0 is labeled to the left
        and dim 1 above. For 3D+ tensors a shadow box is added behind, with
        dim 2 to the right and any further dims below. The dtype goes at the
        very bottom.
        """
        mid = (sub.leftx + sub.rightx) / 2
        w, h = self.matrix_size(sh, ty)
        box_left = mid - w / 2
        color = self.dtype_color_info.color(ty)

        if len(sh) > 2:
            # Offset "back" box suggests depth for 3D+ tensors
            self._add_box(ax, box_left + self.shift3D, self.box_topy - h + self.shift3D, w, h, color)
        self._add_box(ax, box_left, self.box_topy - h, w, h, color)

        # Dim 0: rotated text along the left side of the box
        ax.text(box_left, self.box_topy - h / 2, self.nabbrev(sh[0]),
                verticalalignment='center', horizontalalignment='right',
                fontname=self.dimfontname, fontsize=self.dimfontsize, rotation=90)

        # Dim 1: above the box; for 3D+ shift it to sit above the back box
        textx = mid
        texty = self.box_topy + self.dim_ypadding
        if len(sh) > 2:
            texty += self.dim_ypadding
            textx += self.shift3D
        ax.text(textx, texty, self.nabbrev(sh[1]), horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize)

        if len(sh) > 2:
            # Dim 2: at the right edge, angled to suggest depth
            ax.text(box_left + w, self.box_topy - h / 2, self.nabbrev(sh[2]),
                    verticalalignment='center', horizontalalignment='center',
                    fontname=self.dimfontname, fontsize=self.dimfontsize,
                    rotation=45)

        bottom_text_line = self.box_topy - h - self.dim_ypadding
        if len(sh) > 3:
            # Remaining dims listed below the box
            remaining = r"$\cdots\mathsf{x}$" + r"$\mathsf{x}$".join(
                [self.nabbrev(sh[i]) for i in range(3, len(sh))])
            ax.text(mid, bottom_text_line, remaining,
                    verticalalignment='top', horizontalalignment='center',
                    fontname=self.dimfontname, fontsize=self.dimfontsize)
            bottom_text_line -= self.hchar + self.dim_ypadding

        # Type info at the bottom of everything
        ax.text(mid, bottom_text_line, r'<${\mathit{' + ty + r'}}$>',
                verticalalignment='top', horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize - 2)

    @staticmethod
    def nabbrev(n: int) -> str:
        """
        Abbreviate a dimension size: 3_000_000 -> '3m', 2000 -> '2k', 1234 -> '1234'.
        Zero (or any other falsy size) is shown as-is rather than as '0m'.
        """
        if n:
            if n % 1_000_000 == 0:
                return str(n // 1_000_000) + 'm'
            if n % 1_000 == 0:
                return str(n // 1_000) + 'k'
        return str(n)
335 |
336 |
def pyviz(statement: str, frame=None,
          fontname='Consolas', fontsize=13,
          dimfontname='Arial', dimfontsize=9, char_sep_scale=1.8, fontcolor='#444443',
          underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
          ax=None, dpi=200, hush_errors=True,
          dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None) -> PyVizView:
    """
    Parse and evaluate the Python code in the statement string passed in using
    the indicated execution frame. The execution frame of the invoking function
    is used if frame is None.

    The visualization finds the smallest subexpressions that evaluate to
    tensors, then underlines them and shows a box or rectangle representing
    the tensor dimensions. Tensors with one dimension, shape (n,), are drawn as
    flat rectangles; tensors with two or more dimensions are drawn as boxes.
    Fill color reflects the dtype. By default, green means floating-point,
    blue integer and orange complex, with higher bit precision drawn more
    opaque.

    Upon tensor-related execution error, the offending subexpression is
    highlighted (by de-highlighting the other code) and the operator is shown
    using error_op_color.

    To adjust the size of the generated visualization to be smaller or bigger,
    decrease or increase the font size.

    :param statement: A string representing the line of Python code to visualize within an execution frame.
    :param frame: The execution frame in which to evaluate the statement. If None,
                  use the execution frame of the invoking function
    :param fontname: The name of the font used to display Python code
    :param fontsize: The font size used to display Python code; default is 13.
                     Also use this to increase the size of the generated figure;
                     larger font size means larger image.
    :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes
    :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes
    :param char_sep_scale: It is notoriously difficult to discover how wide and tall
                           text is when plotted in matplotlib. In fact there's probably
                           no hope to discover this information accurately in all cases.
                           Certainly, I gave up after spending huge effort. We have a
                           situation here where the font should be constant width, so
                           we can just use a simple scalar times the font size to get
                           a reasonable approximation of the width and height of a
                           character box; the default of 1.8 seems to work reasonably
                           well for a wide range of fonts, but you might have to tweak it
                           when you change the font size.
    :param fontcolor: The color of the Python code.
    :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey
    :param ignored_color: The de-highlighted color for de-emphasizing code not involved in an erroneous subexpression
    :param error_op_color: The color to use for characters associated with the erroneous operator
    :param ax: If not None, this is the matplotlib drawing region in which to draw the visualization
    :param dpi: This library tries to generate SVG files, which are vector graphics not
                2D arrays of pixels like PNG files. However, it needs to know how to
                compute the exact figure size to remove padding around the visualization.
                Matplotlib uses inches for its figure size and so we must convert
                from pixels or data units to inches, which means we have to know what the
                dots per inch, dpi, is for the image.
    :param hush_errors: Normally, error messages from true syntax errors but also
                        unhandled code caught by my parser are ignored. Turn this off
                        to see what the error messages are coming from my parser.
    :param dtype_colors: map from dtype w/o precision like 'int' to color
    :param dtype_precisions: list of bit precisions to colorize, such as [32,64,128]
    :param dtype_alpha_range: all tensors of the same type are drawn to the same color,
                              and the alpha channel is used to show precision; the
                              smaller the bit size, the lower the alpha channel. You
                              can play with the range to get better visual dynamic range
                              depending on how many precisions you want to display.
    :return: Returns a PyVizView holding info about the visualization; from a notebook
             an SVG image will appear. Returns None upon parsing error in statement.
    """
    view = PyVizView(statement, fontname, fontsize, dimfontname, dimfontsize, char_sep_scale, dpi,
                     dtype_colors, dtype_precisions, dtype_alpha_range)

    if frame is None:  # use frame of caller if not passed in
        frame = sys._getframe().f_back
    root, tokens = tsensor.parsing.parse(statement, hush_errors=hush_errors)
    if root is None:
        print(f"Can't parse {statement}; root is None")
        # likely syntax error in statement or code I can't handle
        return None
    root_to_viz = root
    try:
        root.eval(frame)
    except tsensor.ast.IncrEvalTrap as e:
        root_to_viz = e.offending_expr
        view.offending_expr = e.offending_expr
        view.cause = e.__cause__
        # Don't raise the exception; keep going to visualize code and erroneous
        # subexpressions. If this function is invoked from clarify() or explain(),
        # the statement will be executed and will fail again during normal execution;
        # an exception will be thrown at that time. Then explain/clarify
        # will update the error message
    subexprs = tsensor.analysis.smallest_matrix_subexpr(root_to_viz)
    if ax is None:
        fig, ax = plt.subplots(1, 1, dpi=dpi)
    else:
        fig = ax.figure
    view.fignumber = fig.number  # track this so that we can determine if the figure has been closed

    ax.axis("off")

    # First, we need to figure out how wide the visualization components are
    # for each subexpression. If these are wider than the subexpression text,
    # then we need to leave space around the subexpression text
    lpad = np.zeros((len(statement),))  # pad for characters
    rpad = np.zeros((len(statement),))
    maxh = 0
    for sub in subexprs:
        w, h = view.boxsize(sub.value)
        # update width to include horizontal room for type text like int32
        ty = tsensor.analysis._dtype(sub.value)
        w_typename = len(ty) * view.wchar_small
        w = max(w, w_typename)
        maxh = max(h, maxh)
        start, stop = sub.start.cstart_idx, sub.stop.cstop_idx
        nexpr = stop - start
        # Adjacent spaces are room the box may spill into without padding.
        if start > 0 and statement[start - 1] == ' ':  # if char to left is space
            nexpr += 1
        if stop < len(statement) and statement[stop] == ' ':  # if char to right is space
            nexpr += 1
        if w > view.wchar * nexpr:
            lpad[start] += (w - view.wchar) / 2
            rpad[stop - 1] += (w - view.wchar) / 2

    # Now we know how to place all the elements, since we know what the maximum height is
    view.set_locations(maxh)

    # Find each character's position based upon width of a character and any padding
    charx = np.empty((len(statement),))
    x = view.leftedge
    for i in range(len(statement)):
        x += lpad[i]
        charx[i] = x
        x += view.wchar
        x += rpad[i]

    # Draw text for statement or expression
    if view.offending_expr is not None:  # highlight erroneous subexpr
        highlight = np.full(shape=(len(statement),), fill_value=False, dtype=bool)
        for tok in tokens[root_to_viz.start.index:root_to_viz.stop.index + 1]:
            highlight[tok.cstart_idx:tok.cstop_idx] = True
        errors = np.full(shape=(len(statement),), fill_value=False, dtype=bool)
        for tok in root_to_viz.optokens:
            errors[tok.cstart_idx:tok.cstop_idx] = True
        for i, c in enumerate(statement):
            color = ignored_color
            if highlight[i]:
                color = fontcolor
            if errors[i]:  # override color if operator token
                color = error_op_color
            ax.text(charx[i], view.texty, c, color=color, fontname=fontname, fontsize=fontsize)
    else:
        for i, c in enumerate(statement):
            ax.text(charx[i], view.texty, c, color=fontcolor, fontname=fontname, fontsize=fontsize)

    # Compute the left and right edges of subexpressions (alter nodes with info)
    for sub in subexprs:
        sub.leftx = charx[sub.start.cstart_idx]
        sub.rightx = charx[sub.stop.cstop_idx - 1] + view.wchar

    # Draw grey underlines and draw matrices
    pad = view.wchar * 0.1
    for sub in subexprs:
        a, b = sub.leftx, sub.rightx
        ax.plot([a - pad, b + pad], [view.liney, view.liney], '-', linewidth=.5, c=underline_color)
        view.draw(ax, sub)

    fig_width = charx[-1] + view.wchar + rpad[-1]
    fig_width_inches = fig_width / dpi
    fig_height_inches = view.maxy / dpi
    fig.set_size_inches(fig_width_inches, fig_height_inches)

    ax.set_xlim(0, fig_width)
    ax.set_ylim(0, view.maxy)

    return view
510 |
511 |
512 | # ---------------- SHOW AST STUFF ---------------------------
513 |
class QuietGraphvizWrapper(graphviz.Source):
    """
    A graphviz.Source that renders to SVG in notebooks with dot's diagnostics
    silenced, and that can save itself in any format dot supports.
    """
    def __init__(self, dotsrc):
        super().__init__(source=dotsrc)

    def _repr_svg_(self):
        # quiet=True suppresses dot's stderr output in the notebook
        return self.pipe(format='svg', quiet=True).decode(self._encoding)

    def savefig(self, filename):
        """
        Save the graph to filename. The output format comes from the file
        extension (".svg" -> svg, ".pdf" -> pdf, ...). The DOT source is first
        written to the same directory, named after the file's stem.
        """
        path = Path(filename)
        path.parent.mkdir(parents=True, exist_ok=True)

        dotfilename = self.save(directory=path.parent.as_posix(), filename=path.stem)
        fmt = path.suffix[1:]  # ".svg" -> "svg" etc...
        cmd = ["dot", f"-T{fmt}", "-o", filename, dotfilename]
        # Newer graphviz packages provide graphviz.backend.execute.run_check;
        # older ones (<= 0.17) provide graphviz.backend.run. Detect the API
        # rather than comparing version strings, which compare lexicographically
        # (e.g. '0.8' > '0.17').
        execute = getattr(graphviz.backend, 'execute', None)
        if execute is not None and hasattr(execute, 'run_check'):
            execute.run_check(cmd, capture_output=True, check=True, quiet=False)
        else:
            graphviz.backend.run(cmd, capture_output=True, check=True, quiet=False)
533 |
534 |
def astviz(statement:str, frame='current',
           dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None) -> graphviz.Source:
    """
    Show the abstract syntax tree (AST) of the Python code in statement.

    By default (frame='current') the code is evaluated in the context of the
    code that called astviz. Pass another frame to evaluate there, or
    frame=None to skip evaluation and only display the AST.

    The result is a QuietGraphvizWrapper. It renders as SVG in a notebook and
    offers `savefig()` to write a file whose format follows the file extension.
    """
    dot_source = astviz_dot(statement, frame,
                            dtype_colors, dtype_precisions, dtype_alpha_range)
    return QuietGraphvizWrapper(dot_source)
552 |
553 |
def astviz_dot(statement:str, frame='current',
               dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None) -> str:
    """
    Return Graphviz DOT source that draws the AST of the Python code in statement.

    Leaves are the statement's tokens, laid out left to right on one rank.
    Internal nodes are operators, labeled with their operator text. When an
    operator evaluates to a tensor, its label also shows the shape and
    <dtype>, and the node is filled with the dtype's color.

    :param frame: 'current' evaluates in the caller's frame (skipping astviz()
                  when called through it); None skips evaluation; otherwise
                  the frame in which to evaluate the statement.
    """
    import html  # escape text placed inside DOT HTML-like labels (<, >, &)

    fontname = "Consolas"
    fontsize = 12
    dimfontsize = 9
    textcolor = "#444443"
    spread = 0

    def internal_label(node):
        """HTML-like label for an operator node: op text, plus shape and <dtype> for tensors."""
        sh = tsensor.analysis._shape(node.value)
        text = html.escape(''.join(str(t) for t in node.optokens), quote=False)
        if sh is None:
            return f'<font face="{fontname}" point-size="{fontsize}">{text}</font>'

        ty = html.escape(str(tsensor.analysis._dtype(node.value)), quote=False)
        sz = 'x'.join(PyVizView.nabbrev(sh[i]) for i in range(len(sh)))
        # A raw "<int64>" would be parsed as an (invalid) HTML tag, hence &lt; &gt;
        return (f'<font face="{fontname}" color="{textcolor}" point-size="{fontsize}">{text}</font><br/>'
                f'<font face="Arial" color="{textcolor}" point-size="{dimfontsize}">{sz}</font><br/>'
                f'<font face="Arial" color="{textcolor}" point-size="{dimfontsize}">&lt;{ty}&gt;</font>')

    dtype_color_info = DTypeColorInfo(dtype_colors, dtype_precisions, dtype_alpha_range)

    root, tokens = tsensor.parsing.parse(statement)

    if frame == 'current':  # use frame of caller if nothing passed in
        frame = sys._getframe().f_back
        if frame.f_code.co_name == 'astviz':  # called through astviz(); skip that wrapper
            frame = frame.f_back

    if frame is not None:  # if the passed in None, then don't do the evaluation
        root.eval(frame)

    nodes = tsensor.ast.postorder(root)
    atoms = tsensor.ast.leaves(root)
    atomsS = set(atoms)
    ops = [nd for nd in nodes if nd not in atomsS]  # keep order

    out = ["""digraph G {
    margin=0;
    nodesep=.01;
    ranksep=.3;
    rankdir=BT;
    ordering=out; # keep order of leaves
"""]

    # Gen leaf nodes (one per real token; ENDMARKER has no text)
    leaf_tokens = [t for t in tokens if t.type != token.ENDMARKER]
    for t in leaf_tokens:
        nodetext = html.escape(str(t.value), quote=False)
        if nodetext == ']':
            nodetext = '&zwnj;]'  # &zwnj; is 0-width nonjoiner. ']' by itself is bad for DOT
        label = f'<font face="{fontname}" color="{textcolor}" point-size="{fontsize}">{nodetext}</font>'
        _spread = spread
        if t.type == token.DOT:
            _spread = .1
        elif t.type == token.EQUAL:
            _spread = .25
        elif t.type in tsensor.parsing.ADDOP:
            _spread = .4
        elif t.type in tsensor.parsing.MULOP:
            _spread = .2
        out.append(f'leaf{id(t)} [shape=box penwidth=0 margin=.001 width={_spread} label=<{label}>]\n')

    # Make sure leaves are on same level
    out.append('{ rank=same; ' + ' '.join(f'leaf{id(t)}' for t in leaf_tokens) + '\n}\n')

    # Make sure leaves are left to right by linking consecutive leaves invisibly
    for left, right in zip(leaf_tokens, leaf_tokens[1:]):
        out.append(f'leaf{id(left)} -> leaf{id(right)} [style=invis];\n')

    # Draw internal ops nodes
    for nd in ops:
        label = internal_label(nd)
        sh = tsensor.analysis._shape(nd.value)
        if sh is None:
            color = ""
        else:
            ty = tsensor.analysis._dtype(nd.value)
            color = mc.rgb2hex(dtype_color_info.color(ty), keep_alpha=True)
            color = f'fillcolor="{color}" style=filled'
        out.append(f'node{id(nd)} [shape=box {color} penwidth=0 margin=0 width=.25 height=.2 label=<{label}>]\n')

    # Link internal nodes to other nodes or leaves
    for nd in nodes:
        for sub in nd.kids:
            if sub in atomsS:
                out.append(f'node{id(nd)} -> leaf{id(sub.token)} [dir=back, penwidth="0.5", color="#6B6B6B", arrowsize=.3];\n')
            else:
                out.append(f'node{id(nd)} -> node{id(sub)} [dir=back, penwidth="0.5", color="#6B6B6B", arrowsize=.3];\n')

    out.append("}\n")
    return ''.join(out)
653 |
--------------------------------------------------------------------------------