├── images ├── dtypes.png ├── experiments.graffle ├── sample-1.svg ├── sample-2.svg ├── ast.svg └── mm.svg ├── talks ├── tensor-sensor.pdf ├── tensor-sensor.pptx └── tensor-sensor-old-msgs.pptx ├── testing ├── ones.py ├── testexc.py ├── test_invalid_stat.py ├── test_nested.py ├── test3.py ├── test2.py ├── test_dtype.py ├── test_tensorflow.py ├── play_cmaps.py ├── test_incr_eval.py ├── str_size_matplotlib.py ├── test_tree_eval.py ├── test_jax.py ├── viz_testing.py └── test_parser.py ├── contributing.md ├── .github └── workflows │ └── test.yml ├── LICENSE ├── tsensor ├── version.py ├── __init__.py ├── ast.py ├── parsing.py ├── analysis.py └── viz.py ├── setup.py ├── .gitignore └── README.md /images/dtypes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/parrt/tensor-sensor/HEAD/images/dtypes.png -------------------------------------------------------------------------------- /talks/tensor-sensor.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/parrt/tensor-sensor/HEAD/talks/tensor-sensor.pdf -------------------------------------------------------------------------------- /talks/tensor-sensor.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/parrt/tensor-sensor/HEAD/talks/tensor-sensor.pptx -------------------------------------------------------------------------------- /images/experiments.graffle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/parrt/tensor-sensor/HEAD/images/experiments.graffle -------------------------------------------------------------------------------- /talks/tensor-sensor-old-msgs.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/parrt/tensor-sensor/HEAD/talks/tensor-sensor-old-msgs.pptx 
-------------------------------------------------------------------------------- /testing/ones.py: -------------------------------------------------------------------------------- 1 | # Regression test for https://github.com/parrt/tensor-sensor/issues/16 2 | # Should not throw exception or error, just show equation. 3 | import numpy as np 4 | import tsensor 5 | 6 | print(tsensor.__version__) 7 | with tsensor.explain(): 8 | a = np.ones(3) 9 | -------------------------------------------------------------------------------- /testing/testexc.py: -------------------------------------------------------------------------------- 1 | foo = None 2 | 3 | try: 4 | try: 5 | state = "bar" 6 | foo.append(state) 7 | 8 | except Exception as e: 9 | e.args = ("Appending '"+state+"' failed", *e.args) 10 | raise 11 | 12 | print(foo[0]) # would raise too 13 | 14 | except Exception as e: 15 | e.message = "foo" 16 | e.args = ("print(foo) failed: " + str(foo), *e.args) 17 | raise -------------------------------------------------------------------------------- /contributing.md: -------------------------------------------------------------------------------- 1 | Thank you for considering a contribution to TensorSensor. I welcome updates to the examples, README, or the source code itself. 2 | 3 | Before you spend a lot of time creating a pull request (PR) for a new feature or a biggish change to the software, please contact me by email parrt@cs.usfca.edu to find out if it's something I'm interested in. One of the hardest things to do, as I've learned over the last 30 years pushing out open-source software, is to maintain project focus and conceptual integrity. 
4 | -------------------------------------------------------------------------------- /testing/test_invalid_stat.py: -------------------------------------------------------------------------------- 1 | import tsensor 2 | import numpy as np 3 | 4 | def f(): 5 | # Currently can't handle double assign 6 | a = b = np.ones(1) @ np.ones(2) 7 | 8 | def A(): 9 | with tsensor.clarify(): 10 | f() 11 | 12 | 13 | def test_nested(): 14 | msg = "" 15 | try: 16 | A() 17 | except BaseException as e: 18 | msg = e.args[0] 19 | 20 | expected = "matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 1)" 21 | assert msg==expected 22 | -------------------------------------------------------------------------------- /testing/test_nested.py: -------------------------------------------------------------------------------- 1 | # Test for https://github.com/parrt/tensor-sensor/issues/18 2 | # Nested clarify's and all catch exception 3 | 4 | import tsensor 5 | import numpy as np 6 | 7 | def f(): 8 | np.ones(1) @ np.ones(2) 9 | 10 | def A(): 11 | with tsensor.clarify(): 12 | f() 13 | 14 | def B(): 15 | with tsensor.clarify(): 16 | A() 17 | 18 | def test_nested(): 19 | msg = "" 20 | try: 21 | B() 22 | except BaseException as e: 23 | msg = e.args[0] 24 | 25 | expected = "matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) 
(size 2 is different from 1)\n"+\ 26 | "Cause: @ on tensor operand np.ones(1) w/shape (1,) and operand np.ones(2) w/shape (2,)" 27 | assert msg==expected 28 | -------------------------------------------------------------------------------- /testing/test3.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import tsensor 4 | import torch 5 | import sys 6 | 7 | W = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) 8 | b = torch.tensor([9, 10]).reshape(2, 1) 9 | x = torch.tensor([4, 5], dtype=torch.int32).reshape(2, 1) 10 | h = torch.tensor([1,2]) 11 | 12 | # fig, ax = plt.subplots(1,1) 13 | # # view = tsensor.pyviz("b + x", ax=ax, legend=True) 14 | # # view.savefig("/Users/parrt/Desktop/foo.pdf") 15 | # plt.show() 16 | 17 | W = torch.rand(size=(2000,2000), dtype=torch.float64) 18 | b = torch.rand(size=(2000,1), dtype=torch.float64) 19 | h = torch.zeros(size=(1_000_000,), dtype=int) 20 | x = torch.rand(size=(2000,1)) 21 | z = torch.rand(size=(2000,1), dtype=torch.complex64) 22 | g = tsensor.astviz("b = W@b + (h+3).dot(h) + z", 23 | sys._getframe()) # eval, highlight vectors 24 | g.view() 25 | 26 | # with tsensor.explain(): 27 | # b + x 28 | 29 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test TensorSensor 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest, macos-latest, windows-latest] 15 | python-version: [3.9] 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up python 3 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: '3.9' 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install 
pyparsing==2.4.7 27 | pip install torch 28 | python setup.py install 29 | 30 | - name: Test with pytest 31 | run: | 32 | pip install pytest 33 | pytest testing/test_incr_eval.py 34 | pytest testing/test_invalid_stat.py 35 | pytest testing/test_parser.py 36 | pytest testing/test_tree_eval.py 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Terence Parr 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /testing/test2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tsensor 3 | import torch 4 | 5 | W = np.array([[1, 2], [3, 4]]) 6 | b = np.array([9, 10]).reshape(2, 1) 7 | x = np.array([4, 5]).reshape(2, 1) 8 | h = np.array([1, 2]) 9 | # with tsensor.explain(savefig="/Users/parrt/Desktop/foo.pdf"): 10 | # W @ np.dot(b,b) + np.eye(2,2)@x 11 | 12 | 13 | W = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) 14 | b = torch.tensor([9, 10]).reshape(2, 1) 15 | x = torch.tensor([4, 5], dtype=torch.int32).reshape(2, 1) 16 | h = torch.tensor([1,2]) 17 | 18 | a = torch.rand(size=(2, 20), dtype=torch.float64) 19 | b = torch.rand(size=(2, 20), dtype=torch.float32) 20 | c = torch.rand(size=(2,20,200), dtype=torch.complex64) 21 | d = torch.rand(size=(2,20,200,5), dtype=torch.float16) 22 | 23 | 24 | with tsensor.explain(savefig="/Users/parrt/Desktop/t2.pdf"): 25 | a + b + x + c[:,:,0] + d[:,:,0,0] 26 | 27 | with tsensor.explain(savefig="/Users/parrt/Desktop/t3.pdf"): 28 | c 29 | 30 | with tsensor.explain(savefig="/Users/parrt/Desktop/t4.pdf"): 31 | d 32 | 33 | # with tsensor.explain(legend=True, savefig="/Users/parrt/Desktop/t.pdf") as e: 34 | # W @ torch.dot(b, b) + torch.eye(2, 2) @ x 35 | -------------------------------------------------------------------------------- /tsensor/version.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software 
is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | __version__ = '1.0' 25 | -------------------------------------------------------------------------------- /testing/test_dtype.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax.numpy as jnp 3 | import tensorflow as tf 4 | import torch 5 | import pytest 6 | 7 | import tsensor as ts 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "value,expected", 12 | [ 13 | # Numpy 14 | (np.random.randint(1, 10, size=(10, 2, 5)), "int64"), 15 | (np.random.randint(1, 10, size=(10, 2, 5), dtype="int8"), "int8"), 16 | (np.random.normal(size=(5, 1)).astype(np.float32), "float32"), 17 | (np.random.normal(size=(5, 1)).astype(np.float32), "float32"), 18 | (np.array([('Rex', 9, 81.0), ('Fido', 3, 27.0)], dtype=[('name', 'U10'), ('age', 'i4'), ('weight', 'f4')]), 19 | "str320,int32,float32"), 20 | # Jax 21 | (jnp.array([[1, 2], [3, 4]]), "int32"), 22 | (jnp.array([[1, 2], [3, 4]], dtype="int8"), "int8"), 23 | # Tensorflow 24 | (tf.constant([[1, 2], [3, 4]]), "int32"), 25 | (tf.constant([[1, 2], [3, 4]], dtype="int64"), "int64"), 26 | # Pytorch 27 | (torch.tensor([[1, 2], [3, 4]]), "int64"), 28 | (torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "int32"), 29 | ], 30 | ) 31 | def test_dtypes(value, expected): 32 | 
assert ts.analysis._dtype(value) == expected 33 | -------------------------------------------------------------------------------- /tsensor/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | # These classes/functions are the primary user interface so import them directly 25 | import tsensor.ast 26 | import tsensor.parsing 27 | import tsensor.viz 28 | import tsensor.analysis 29 | from tsensor.analysis import explain, clarify, eval 30 | from tsensor.parsing import parse 31 | from tsensor.viz import pyviz, astviz 32 | from .version import __version__ 33 | 34 | 35 | __all__ = ["ast", "parsing", "viz", "analysis", "version"] 36 | 37 | # To fix an OpenMP runtime link issue. 
38 | import os 39 | 40 | os.environ['KMP_DUPLICATE_LIB_OK'] = "True" 41 | -------------------------------------------------------------------------------- /testing/test_tensorflow.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | import tsensor 25 | import tensorflow as tf 26 | 27 | W = tf.constant([[1, 2], [3, 4]]) 28 | b = tf.reshape(tf.constant([[9, 10]]), (2, 1)) 29 | x = tf.reshape(tf.constant([[8, 5, 7]]), (3, 1)) 30 | 31 | def test_addition(): 32 | msg = "" 33 | try: 34 | with tsensor.clarify(): 35 | q = b + x + 3 36 | except tf.errors.InvalidArgumentError as iae: 37 | msg = iae.message 38 | 39 | expected = "Incompatible shapes: [2,1] vs. 
[3,1] [Op:AddV2]\n"+\ 40 | "Cause: + on tensor operand b w/shape (2, 1) and operand x w/shape (3, 1)" 41 | assert msg==expected -------------------------------------------------------------------------------- /testing/play_cmaps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import matplotlib.pyplot as plt 4 | import matplotlib.colors as mc 5 | 6 | bits = [4,8,16,32,64,128] 7 | bits = [8,16,32,64,128] 8 | nhues = len(bits) 9 | 10 | blueish = '#3B75AF' 11 | greenish = '#519E3E' 12 | 13 | orangeish = '#FDDB7D' 14 | limeish = '#C1E1C5' 15 | limeish = '#A8E1B0' 16 | yellowish = '#FFFFAD' 17 | 18 | print(mc.hex2color(limeish)) 19 | 20 | type_colors = {'float':limeish, 'int':blueish, 'complex':orangeish} 21 | 22 | # Derived from https://stackoverflow.com/questions/47222585/matplotlib-generic-colormap-from-tab10 23 | 24 | def categorical_cmap(color, nsc): 25 | # ccolors = plt.get_cmap(cmap)(np.arange(nc, dtype=int)) 26 | # print(ccolors[0:4]) 27 | cols = np.zeros((nsc, 3)) 28 | # chsv = mc.rgb_to_hsv(c[:3]) 29 | chsv = mc.rgb_to_hsv(mc.hex2color(color)) 30 | arhsv = np.tile(chsv,nsc).reshape(nsc,3) 31 | arhsv[:,1] = np.linspace(chsv[1],0.25,nsc) 32 | arhsv[:,2] = np.linspace(chsv[2],1,nsc) 33 | rgb = mc.hsv_to_rgb(arhsv) 34 | cols[0:nsc,:] = rgb 35 | cmap = mc.ListedColormap(cols) 36 | return cmap 37 | 38 | plt.figure(figsize=(3,3)) 39 | c1 = categorical_cmap(blueish,nhues) 40 | plt.scatter(np.arange(nhues),[1]*nhues, c=np.arange(nhues), s=1080, cmap=c1, linewidths=.5, edgecolors='grey') 41 | c1 = categorical_cmap(limeish,nhues) 42 | plt.scatter(np.arange(nhues),[2]*nhues, c=np.arange(nhues), s=1080, cmap=c1, linewidths=.5, edgecolors='grey') 43 | c1 = categorical_cmap(yellowish,nhues) 44 | plt.scatter(np.arange(nhues),[3]*nhues, c=np.arange(nhues), s=1080, cmap=c1, linewidths=.5, edgecolors='grey') 45 | 46 | plt.margins(y=3) 47 | plt.xticks([]) 48 | plt.yticks([0,1,2],["(5, 4)", "(2, 5)", "(4, 3)"]) 49 
| plt.ylim(0, 4) 50 | plt.axis('off') 51 | 52 | plt.savefig("/Users/parrt/Desktop/colors.pdf") 53 | plt.show() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | from setuptools import setup 25 | 26 | tensorflow_requires = ['tensorflow'] 27 | torch_requires = ['torch'] 28 | jax_requires = ['jax', 'jaxlib'] 29 | all_requires = tensorflow_requires + torch_requires + jax_requires 30 | 31 | exec(open('tsensor/version.py').read()) 32 | setup( 33 | name='tensor-sensor', 34 | version=__version__, 35 | url='https://github.com/parrt/tensor-sensor', 36 | license='MIT', 37 | py_modules=['tsensor.parsing', 'tsensor.ast', 'tsensor.analysis', 'tsensor.viz', 'tsensor.version'], 38 | author='Terence Parr', 39 | author_email='parrt@cs.usfca.edu', 40 | python_requires='>=3.6', 41 | install_requires=['graphviz>=0.14.1','numpy','IPython', 'matplotlib'], 42 | extras_require = {'all': all_requires, 43 | 'torch': torch_requires, 44 | 'tensorflow': tensorflow_requires, 45 | 'jax': jax_requires 46 | }, 47 | description='The goal of this library is to generate more helpful exception messages for numpy/pytorch tensor algebra expressions.', 48 | classifiers=['License :: OSI Approved :: MIT License', 49 | 'Intended Audience :: Developers'] 50 | ) 51 | -------------------------------------------------------------------------------- /testing/test_incr_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | from tsensor.ast import IncrEvalTrap 25 | from tsensor.parsing import PyExprParser 26 | import sys 27 | import numpy as np 28 | import torch 29 | 30 | def check(s,expected): 31 | frame = sys._getframe() 32 | caller = frame.f_back 33 | p = PyExprParser(s) 34 | t = p.parse() 35 | bad_subexpr = None 36 | try: 37 | t.eval(caller) 38 | except IncrEvalTrap as exc: 39 | bad_subexpr = str(exc.offending_expr) 40 | assert bad_subexpr==expected 41 | 42 | 43 | def test_missing_var(): 44 | a = 3 45 | c = 5 46 | check("a+b+c", "b") 47 | check("z+b+c", "z") 48 | 49 | def test_matrix_mult(): 50 | W = torch.tensor([[1, 2], [3, 4]]) 51 | b = torch.tensor([[1,2,3]]) 52 | check("W@b+torch.abs(b)", "W@b") 53 | 54 | def test_bad_arg(): 55 | check("torch.abs('foo')", "torch.abs('foo')") 56 | 57 | def test_parens(): 58 | a = 3 59 | b = 4 60 | c = 5 61 | check("(a+b)/0", "(a+b)/0") 62 | 63 | def test_array_literal(): 64 | a = torch.tensor([[1,2,3],[4,5,6]]) 65 | b = torch.tensor([[1,2,3]]) 66 | a+b 67 | check("a + b@2", """b@2""") 68 | 69 | def test_array_literal2(): 70 | a = torch.tensor([[1,2,3],[4,5,6]]) 71 | b = torch.tensor([[1,2,3]]) 72 | a+b 73 | check("(a+b)@2", """(a+b)@2""") 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 
9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /testing/str_size_matplotlib.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | import matplotlib.patches as patches 3 | from matplotlib import pyplot as plt 4 | 5 | def textdim(s, fontname='Consolas', fontsize=11): 6 | fig, ax = plt.subplots(1,1) 7 | t = ax.text(0, 0, s, bbox={'lw':0, 'pad':0}, fontname=fontname, fontsize=fontsize) 8 | # plt.savefig(tempfile.mktemp(".pdf")) 9 | plt.savefig("/tmp/font.pdf", pad_inches=0, dpi=200) 10 | print(t) 11 | plt.close() 12 | bb = t.get_bbox_patch() 13 | print(bb) 14 | w, h = bb.get_width(), bb.get_height() 15 | return w, h 16 | 17 | # print(textdim("@")) 18 | #exit() 19 | 20 | 21 | # From: https://stackoverflow.com/questions/22667224/matplotlib-get-text-bounding-box-independent-of-backend 22 | def find_renderer(fig): 23 | if hasattr(fig.canvas, "get_renderer"): 24 | #Some backends, such as TkAgg, have the get_renderer method, which 25 | #makes this easy. 26 | renderer = fig.canvas.get_renderer() 27 | else: 28 | #Other backends do not have the get_renderer method, so we have a work 29 | #around to find the renderer. Print the figure to a temporary file 30 | #object, and then grab the renderer that was used. 31 | #(I stole this trick from the matplotlib backend_bases.py 32 | #print_figure() method.) 
33 | import io 34 | fig.canvas.print_pdf(io.BytesIO()) 35 | renderer = fig._cachedRenderer 36 | return(renderer) 37 | 38 | 39 | def textdim(s, fontname='Consolas', fontsize=11): 40 | fig, ax = plt.subplots(1, 1) 41 | t = ax.text(0, 0, s, fontname=fontname, fontsize=fontsize, transform=None) 42 | bb = t.get_window_extent(find_renderer(fig)) 43 | print(s, bb.width, bb.height) 44 | 45 | # t = mpl.textpath.TextPath(xy=(0, 0), s=s, size=fontsize, prop=fontname) 46 | # bb = t.get_extents() 47 | # print(s, "new", bb) 48 | plt.close() 49 | return bb.width, bb.height 50 | 51 | # print(textdim("test of foo", fontsize=11)) 52 | # print(textdim("test of foO", fontsize=11)) 53 | # print(textdim("W @ b + x * 3 + h.dot(h)", fontsize=12)) 54 | 55 | code = 'W@ b + x *3 + h.dot(h)' 56 | code = 'W@ b.f(x,y)'# + x *3 + h.dot(h)' 57 | 58 | fig, ax = plt.subplots(1,1,figsize=(4,1)) 59 | 60 | fontname = "Serif" 61 | fontsize = 16 62 | 63 | # for c in code: 64 | # t = ax.text(0,0,c) 65 | # bbox1 = t.get_window_extent(find_renderer(fig)) 66 | # # print(c, '->', bbox1.width, bbox1.height) 67 | # print(c, '->', textdim(c, fontname=fontname, fontsize=fontsize)) 68 | # rect1 = patches.Rectangle((0,0), bbox1.width, bbox1.height, \ 69 | # color = [0,0,0], fill = False) 70 | # fig.patches.append(rect1) 71 | 72 | x = 0 73 | for c in code: 74 | # print(f"plot {c} at {x},{0}") 75 | ax.text(x, 10, c, fontname=fontname, fontsize=fontsize, transform=None) 76 | w, h = textdim(c, fontname=fontname, fontsize=fontsize) 77 | # print(w,h,'->',x) 78 | x = x + w 79 | 80 | ax.set_xlim(0,x) 81 | #ax.set_ylim(0,10) 82 | 83 | ax.axis('off') 84 | 85 | #plt.show() 86 | 87 | plt.tight_layout() 88 | 89 | plt.savefig("/tmp/t.pdf", bbox_inches='tight', pad_inches=0, dpi=200) 90 | -------------------------------------------------------------------------------- /testing/test_tree_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 
Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | from tsensor.parsing import * 25 | import tsensor.ast 26 | import sys 27 | import numpy as np 28 | 29 | 30 | def check(s,expected): 31 | frame = sys._getframe() 32 | caller = frame.f_back 33 | p = PyExprParser(s) 34 | t = p.parse() 35 | result = t.eval(caller) 36 | assert str(result)==str(expected) 37 | 38 | 39 | def test_int(): 40 | check("34", 34) 41 | 42 | def test_assign(): 43 | check("a = 34", 34) 44 | 45 | def test_var(): 46 | a = 34 47 | check("a", 34) 48 | 49 | def test_member_var(): 50 | class A: 51 | def __init__(self): 52 | self.a = 34 53 | x = A() 54 | check("x.a", 34) 55 | 56 | def test_member_func(): 57 | class A: 58 | def f(self, a): 59 | return a+4 60 | x = A() 61 | check("x.f(30)", 34) 62 | 63 | def test_index(): 64 | a = [1,2,3] 65 | check("a[2]", 3) 66 | 67 | def test_add(): 68 | a = 3 69 | b = 4 70 | c = 5 71 | check("a+b+c", 12) 72 | 73 | def test_add_mul(): 74 | a = 3 75 | b = 4 76 | c = 5 77 | check("a+b*c", 23) 78 | 79 | def test_pow(): 80 | a = 3 81 | b = 4 82 | check("a**b", 3**4) 83 | 84 | def test_pow2(): 85 | a = 3 86 | b = 4 87 | c = 5 88 | check("a**(b+1)**c", 3**5**5) 89 | 90 | def test_parens(): 91 | a = 3 92 | b = 4 93 | c = 5 94 | check("(a+b)*c", 35) 95 | 96 | def test_list_literal(): 97 | a = [[1,2,3],[4,5,6]] 98 | check("a", """[[1, 2, 3], [4, 5, 6]]""") 99 | 100 | 101 | def test_np_literal(): 102 | a = np.array([[1,2,3],[4,5,6]]) 103 | check("a*2", """[[ 2 4 6]\n [ 8 10 12]]""") 104 | 105 | 106 | def test_np_add(): 107 | a = np.array([[1,2,3],[4,5,6]]) 108 | check("a+a", """[[ 2 4 6]\n [ 8 10 12]]""") 109 | 110 | 111 | def test_np_add2(): 112 | a = np.array([[1,2,3],[4,5,6]]) 113 | check("a+a+a", """[[ 3 6 9]\n [12 15 18]]""") 114 | -------------------------------------------------------------------------------- /testing/test_jax.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 Terence Parr 5 | 6 | Permission is hereby granted, 
free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | import jax.numpy as jnp 25 | import numpy as np 26 | import tsensor 27 | 28 | def test_dot(): 29 | size = 5000 30 | x = np.random.normal(size=(size, size)).astype(np.float32) 31 | y = np.random.normal(size=(5, 1)).astype(np.float32) 32 | 33 | msg = "" 34 | try: 35 | with tsensor.clarify(): 36 | z = jnp.dot(x, y).block_until_ready() 37 | except TypeError as e: 38 | msg = e.args[0] 39 | 40 | expected = "Incompatible shapes for dot: got (5000, 5000) and (5, 1).\n"+\ 41 | "Cause: jnp.dot(x, y) tensor arg x w/shape (5000, 5000), arg y w/shape (5, 1)" 42 | assert msg==expected 43 | 44 | 45 | def test_scalar_arg(): 46 | size = 5000 47 | x = np.random.normal(size=(size, size)).astype(np.float32) 48 | 49 | msg = "" 50 | try: 51 | with tsensor.clarify(): 52 | z = jnp.dot(x, "foo") 53 | except TypeError as e: 54 | msg = e.args[0] 55 | 56 | expected = 'data type \'foo\' not understood\n'+\ 57 | 'Cause: jnp.dot(x, "foo") tensor arg x w/shape (5000, 5000)' 58 | assert msg==expected 59 | 60 | 61 | def test_mmul(): 62 | W = jnp.array([[1, 2], [3, 4]]) 63 | b = jnp.array([9, 10, 11]) 64 | 65 | msg = "" 66 | try: 67 | with tsensor.clarify(): 68 | y = W @ b 69 | except TypeError as e: 70 | msg = e.args[0] 71 | 72 | expected = "dot_general requires contracting dimensions to have the same shape, got [2] and [3].\n"+\ 73 | "Cause: @ on tensor operand W w/shape (2, 2) and operand b w/shape (3,)" 74 | assert msg==expected 75 | 76 | 77 | def test_fft(): 78 | "Test a library function that doesn't have a shape related message in the exception." 
79 | x = np.exp(2j * np.pi * np.arange(8) / 8) 80 | msg = "" 81 | try: 82 | with tsensor.clarify(): 83 | y = jnp.fft.fft(x, norm="something weird") 84 | except BaseException as e: 85 | msg = e.args[0] 86 | print(msg) 87 | 88 | expected = 'jax.numpy.fft.fft only supports norm=None, got something weird\n'+\ 89 | 'Cause: jnp.fft.fft(x, norm="something weird") tensor arg x w/shape (8,)' 90 | assert msg==expected 91 | -------------------------------------------------------------------------------- /testing/viz_testing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | import graphviz 5 | import tempfile 6 | import matplotlib.patches as patches 7 | import matplotlib.pyplot as plt 8 | import matplotlib.font_manager as fm 9 | 10 | 11 | # print('\n'.join(str(f) for f in fm.fontManager.ttflist)) 12 | import tsensor 13 | # from tsensor.viz import pyviz, astviz 14 | 15 | import torch 16 | import tsensor 17 | 18 | n = 200 # number of instances 19 | d = 764 # number of instance features 20 | nhidden = 256 21 | 22 | Whh = torch.eye(nhidden, nhidden) # Identity matrix 23 | Uxh = torch.randn(nhidden, d) 24 | bh = torch.zeros(nhidden, 1) 25 | h = torch.randn(nhidden, 1) # fake previous hidden state h 26 | # r = torch.randn(nhidden, 1) # fake this computation 27 | r = torch.randn(nhidden, 3) # fake this computation 28 | X = torch.rand(n,d) # fake input 29 | 30 | # Following code raises an exception 31 | with tsensor.explain(savefig="/Users/parrt/Desktop/toomany.png") as e: 32 | h = torch.tanh(Whh @ (r*h) + Uxh @ X.T + bh) # state vector update equation 33 | 34 | exit() 35 | 36 | def foo(): 37 | # W = torch.rand(size=(2000, 2000)) 38 | W = torch.rand(size=(2000, 2000, 10, 8)) 39 | b = torch.rand(size=(2000, 1)) 40 | h = torch.rand(size=(1_000_000,)) 41 | x = torch.rand(size=(2000, 1)) 42 | # g = tsensor.astviz("b = W@b + (h+3).dot(h) + torch.abs(torch.tensor(34))", 43 | # sys._getframe()) 44 | 
frame = sys._getframe() 45 | frame = None 46 | g = tsensor.astviz("b = W[:,:,0,0]@b + (h+3).dot(h) + torch.abs(torch.tensor(34))", 47 | frame) 48 | g.view() 49 | 50 | #foo() 51 | 52 | class Linear: 53 | def __init__(self, d, n_neurons): 54 | self.W = torch.randn(n_neurons, d) 55 | self.b = torch.zeros(n_neurons, 1) 56 | def __call__(self, x): 57 | return self.W@x + self.b 58 | 59 | 60 | n = 200 # number of instances 61 | d = 764 # number of instance features 62 | n_neurons = 100 # how many neurons in this layer? 63 | # L = Linear(d,n_neurons) 64 | # 65 | # import tensorflow as tf 66 | # X = tf.random.uniform((n,d)) 67 | # with tsensor.clarify(hush_errors=False): 68 | # Y = L(X) 69 | 70 | # g = tsensor.pyviz("Y = L(X)", hush_errors=False) 71 | # g.show() 72 | 73 | class GRU: 74 | def __init__(self): 75 | self.W = torch.rand(size=(2,20,2000,10)) 76 | self.b = torch.rand(size=(20,1)) 77 | # self.x = torch.tensor([4, 5]).reshape(2, 1) 78 | self.h = torch.rand(size=(1_000_000,)) 79 | self.a = 3 80 | print(self.W.shape) 81 | print(self.W[:, :, 1].shape) 82 | 83 | def get(self): 84 | return torch.tensor([[1, 2], [3, 4]]) 85 | 86 | # W = torch.tensor([[1, 2], [3, 4]]) 87 | b = torch.rand(size=(2000,1)) 88 | h = torch.rand(size=(1_000_000,2)) 89 | x = torch.rand(size=(1_000_000,2)) 90 | a = 3 91 | 92 | # foo = torch.rand(size=(2000,)) 93 | # torch.relu(foo) 94 | 95 | g = GRU() 96 | 97 | # with tsensor.clarify(): 98 | # tf.constant([1,2]) @ tf.constant([1,3]) 99 | 100 | 101 | code = "b = g.W[0,:,:,1]@b+torch.zeros(200,1)+(h+3).dot(h)" 102 | # code = "torch.relu(foo)" 103 | # code = "np.dot(b,b)" 104 | # code = "b.T" 105 | g = tsensor.pyviz(code, fontname='Courier New', fontsize=16, dimfontsize=9, 106 | char_sep_scale=1.8, hush_errors=False) 107 | plt.tight_layout() 108 | plt.savefig("/tmp/t.svg", dpi=200, bbox_inches='tight', pad_inches=0) 109 | 110 | # W = torch.tensor([[1, 2], [3, 4]]) 111 | # x = torch.tensor([4, 5]).reshape(2, 1) 112 | # with tsensor.explain(): 113 | # b 
= torch.rand(size=(2000,)) 114 | # torch.relu(b) 115 | 116 | 117 | # g = GRU() 118 | # 119 | # g1 = tsensor.astviz("b = g.W@b + torch.eye(3,3)") 120 | # g1.view() 121 | # g1 = tsensor.pyviz("b = g.W@b") 122 | # g1.view() 123 | # g2 = tsensor.astviz("b = g.W@b + g.h.dot(g.h) + torch.abs(torch.tensor(34))") 124 | -------------------------------------------------------------------------------- /images/sample-1.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | leaf140728778297408 15 | 16 |     17 | 2x1 18 |     19 | 20 | 21 |     22 | a 23 |     24 | 25 | 26 | 27 | leaf140728778297504 28 | 29 | = 30 | 31 | 32 | 33 | 34 | leaf140728778299904 35 | 36 | torch 37 | 38 | 39 | 40 | 41 | leaf140728778300000 42 | 43 | . 44 | 45 | 46 | 47 | 48 | leaf140728778300048 49 | 50 | relu 51 | 52 | 53 | 54 | 55 | leaf140728778300144 56 | 57 | ( 58 | 59 | 60 | 61 | 62 | leaf140728778300192 63 | 64 |     65 | 2x1 66 |     67 | 68 | 69 |     70 | x 71 |     72 | 73 | 74 | 75 | 76 | leaf140728778300288 77 | 78 | ) 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensor Sensor 2 | 3 | See article [Clarifying exceptions and visualizing tensor operations in deep learning code](https://explained.ai/tensor-sensor/index.html) and [TensorSensor implementation slides](https://github.com/parrt/tensor-sensor/raw/master/talks/tensor-sensor.pdf) (PDF). 4 | 5 | (*As of September 2021, M1 macs experience illegal instructions in many of the tensor libraries installed via Anaconda, so you should expect TensorSensor to work only on Intel-based Macs at the moment. 
PyTorch appears to work.*) 6 | 7 | One of the biggest challenges when writing code to implement deep learning networks, particularly for us newbies, is getting all of the tensor (matrix and vector) dimensions to line up properly. It's really easy to lose track of tensor dimensionality in complicated expressions involving multiple tensors and tensor operations. Even when just feeding data into predefined [Tensorflow](https://www.tensorflow.org/) network layers, we still need to get the dimensions right. When you ask for improper computations, you're going to run into some less than helpful exception messages. 8 | 9 | To help myself and other programmers debug tensor code, I built this library. TensorSensor clarifies exceptions by augmenting messages and visualizing Python code to indicate the shape of tensor variables (see figure to the right for a teaser). It works with [Tensorflow](https://www.tensorflow.org/), [PyTorch](https://pytorch.org/), [JAX](https://github.com/google/jax), and [Numpy](https://numpy.org/), as well as higher-level libraries like [Keras](https://keras.io/) and [fastai](https://www.fast.ai/). 10 | 11 | *TensorSensor is currently at 1.0 (December 2021)*. 12 | 13 | ## Visualizations 14 | 15 | For more, see [examples.ipynb at colab](https://colab.research.google.com/github/parrt/tensor-sensor/blob/master/testing/examples.ipynb). (The github rendering does not show images for some reason: [examples.ipynb](testing/examples.ipynb).) 16 | 17 | ```python 18 | import numpy as np 19 | 20 | n = 200 # number of instances 21 | d = 764 # number of instance features 22 | n_neurons = 100 # how many neurons in this layer? 
23 | 24 | W = np.random.rand(d,n_neurons) 25 | b = np.random.rand(n_neurons,1) 26 | X = np.random.rand(n,d) 27 | with tsensor.clarify() as c: 28 | Y = W @ X.T + b 29 | ``` 30 | 31 | Displays this in a jupyter notebook or separate window: 32 | 33 | 34 | 35 | Instead of the following default exception message: 36 | 37 | ``` 38 | ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100) 39 | ``` 40 | 41 | TensorSensor augments the message with more information about which operator caused the problem and includes the shape of the operands: 42 | 43 | ``` 44 | Cause: @ on tensor operand W w/shape (764, 100) and operand X.T w/shape (764, 200) 45 | ``` 46 | 47 | You can also get the full computation graph for an expression that includes all of the sub-expression shapes. 48 | 49 | ```python 50 | W = torch.rand(size=(2000,2000), dtype=torch.float64) 51 | b = torch.rand(size=(2000,1), dtype=torch.float64) 52 | h = torch.zeros(size=(1_000_000,), dtype=int) 53 | x = torch.rand(size=(2000,1)) 54 | z = torch.rand(size=(2000,1), dtype=torch.complex64) 55 | 56 | tsensor.astviz("b = W@b + (h+3).dot(h) + z", sys._getframe()) 57 | ``` 58 | 59 | yields the following abstract syntax tree with shapes: 60 | 61 | 62 | 63 | ## Install 64 | 65 | ``` 66 | pip install tensor-sensor # This will only install the library for you 67 | pip install tensor-sensor[torch] # install pytorch related dependency 68 | pip install tensor-sensor[tensorflow] # install tensorflow related dependency 69 | pip install tensor-sensor[jax] # install jax, jaxlib 70 | pip install tensor-sensor[all] # install tensorflow, pytorch, jax 71 | ``` 72 | 73 | which gives you module `tsensor`. 
I developed and tested with the following versions: 74 | 75 | ``` 76 | $ pip list | grep -i flow 77 | tensorflow 2.5.0 78 | tensorflow-estimator 2.5.0 79 | $ pip list | grep -i numpy 80 | numpy 1.19.5 81 | numpydoc 1.1.0 82 | $ pip list | grep -i torch 83 | torch 1.10.0 84 | torchvision 0.10.0 85 | $ pip list | grep -i jax 86 | jax 0.2.20 87 | jaxlib 0.1.71 88 | ``` 89 | 90 | ### Graphviz for tsensor.astviz() 91 | 92 | For displaying abstract syntax trees (ASTs) with `tsensor.astviz(...)`, you need the `dot` executable from graphviz, not just the python library. 93 | 94 | On **Mac**, do this before or after tensor-sensor install: 95 | 96 | ``` 97 | brew install graphviz 98 | ``` 99 | 100 | On **Windows**, apparently you need 101 | 102 | ``` 103 | conda install python-graphviz # Do this first; gets the dot executable and py lib 104 | pip install tensor-sensor # Or one of the other installs 105 | ``` 106 | 107 | 108 | ## Limitations 109 | 110 | I rely on parsing lines that are assignments or expressions only so the clarify and explain routines do not handle methods expressed like: 111 | 112 | ``` 113 | def bar(): b + x * 3 114 | ``` 115 | 116 | Instead, use 117 | 118 | ``` 119 | def bar(): 120 | b + x * 3 121 | ``` 122 | 123 | Watch out for side effects! I don't do assignments, but any functions you call with side effects will be done while I reevaluate statements. 124 | 125 | Can't handle `\` continuations. 126 | 127 | With Python `threading` package, don't use multiple threads calling clarify(). `multiprocessing` package should be fine. 128 | 129 | Also note: I've built my own parser to handle just the assignments / expressions tsensor can handle. 130 | 131 | ## Deploy (parrt's use) 132 | 133 | ```bash 134 | $ python setup.py sdist upload 135 | ``` 136 | 137 | Or download and install locally 138 | 139 | ```bash 140 | $ cd ~/github/tensor-sensor 141 | $ pip install . 142 | ``` 143 | 144 | ### TODO 145 | 146 | * Can I call pyviz in the debugger?
147 | -------------------------------------------------------------------------------- /testing/test_parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | from tsensor.parsing import * 25 | import re 26 | 27 | def check(s, expected_repr, expect_str=None): 28 | p = PyExprParser(s, hush_errors=False) 29 | t = p.parse() 30 | 31 | s = re.sub(r"\s+", "", s) 32 | result_str = str(t) 33 | result_str = re.sub(r"\s+", "", result_str) 34 | if expect_str: 35 | s = expect_str 36 | assert result_str==s 37 | 38 | result_repr = repr(t) 39 | result_repr = re.sub(r"\s+", "", result_repr) 40 | expected_repr = re.sub(r"\s+", "", expected_repr) 41 | # print("result", result_repr) 42 | # print("expected", expected) 43 | assert result_repr == expected_repr 44 | 45 | 46 | def test_assign(): 47 | check("a = 3", "Assign(op=,lhs=a,rhs=3)") 48 | 49 | 50 | def test_index(): 51 | check("a[:,i,j]", "Index(arr=a, index=[:, i, j])") 52 | 53 | 54 | def test_index2(): 55 | check("z = a[:]", "Assign(op=,lhs=z,rhs=Index(arr=a,index=[:]))") 56 | 57 | def test_index3(): 58 | check("g.W[:,:,1]", "Index(arr=Member(op=,obj=g,member=W),index=[:,:,1])") 59 | 60 | def test_literal_list(): 61 | check("[[1, 2], [3, 4]]", 62 | "ListLiteral(elems=[ListLiteral(elems=[1, 2]), ListLiteral(elems=[3, 4])])") 63 | 64 | 65 | def test_literal_array(): 66 | check("np.array([[1, 2], [3, 4]])", 67 | """ 68 | Call(func=Member(op=,obj=np,member=array), 69 | args=[ListLiteral(elems=[ListLiteral(elems=[1,2]),ListLiteral(elems=[3,4])])]) 70 | """) 71 | 72 | 73 | def test_method(): 74 | check("h = torch.tanh(h)", 75 | "Assign(op=,lhs=h,rhs=Call(func=Member(op=,obj=torch,member=tanh),args=[h]))") 76 | 77 | 78 | def test_method2(): 79 | check("np.dot(b,b)", 80 | "Call(func=Member(op=,obj=np,member=dot),args=[b,b])") 81 | 82 | 83 | def test_method3(): 84 | check("y_pred = model(X)", 85 | "Assign(op=,lhs=y_pred,rhs=Call(func=model,args=[X]))") 86 | 87 | 88 | def test_field(): 89 | check("a.b", "Member(op=,obj=a,member=b)") 90 | 91 | 92 | def test_member_func(): 93 | check("a.f()", "Call(func=Member(op=,obj=a,member=f),args=[])") 94 | 95 | 96 | def test_field2(): 97 | 
check("a.b.c", "Member(op=,obj=Member(op=,obj=a,member=b),member=c)") 98 | 99 | 100 | def test_field_and_func(): 101 | check("a.f().c", "Member(op=,obj=Call(func=Member(op=,obj=a,member=f),args=[]),member=c)") 102 | 103 | 104 | def test_parens(): 105 | check("(a+b)*c", "BinaryOp(op=,lhs=BinaryOp(op=,lhs=a,rhs=b),rhs=c)") 106 | 107 | 108 | def test_expr_in_arg_with_parens(): 109 | check("h = torch.tanh( (1-z)*h + z*h_ )", 110 | "Assign(op=,lhs=h,rhs=Call(func=Member(op=,obj=torch,member=tanh),args=[BinaryOp(op=,lhs=BinaryOp(op=,lhs=BinaryOp(op=,lhs=1,rhs=z),rhs=h),rhs=BinaryOp(op=,lhs=z,rhs=h_))]))") 111 | 112 | 113 | def test_1tuple(): 114 | check("(3,)", "TupleLiteral(elems=[3])") 115 | 116 | 117 | def test_2tuple(): 118 | check("(3,4)", "TupleLiteral(elems=[3,4])") 119 | 120 | 121 | def test_2tuple_with_trailing_comma(): 122 | check("(3,4,)", "TupleLiteral(elems=[3,4])") 123 | 124 | 125 | def test_field_array(): 126 | check("a.b[34]", "Index(arr=Member(op=,obj=a,member=b),index=[34])") 127 | 128 | 129 | def test_field_array_func(): 130 | check("a.b[34].f()", "Call(func=Member(op=,obj=Index(arr=Member(op=,obj=a,member=b),index=[34]),member=f),args=[])") 131 | 132 | 133 | def test_arith(): 134 | check("(1-z)*h + z*h_", 135 | """BinaryOp(op=, 136 | lhs=BinaryOp(op=, 137 | lhs=BinaryOp(op=, 138 | lhs=1, 139 | rhs=z), 140 | rhs=h), 141 | rhs=BinaryOp(op=,lhs=z,rhs=h_))""") 142 | 143 | 144 | def test_pow(): 145 | check("a**2", 146 | """BinaryOp(op=,lhs=a,rhs=2)""") 147 | 148 | 149 | def test_chained_pow(): 150 | check("a**b**c", 151 | """BinaryOp(op=,lhs=a,rhs=BinaryOp(op=,lhs=b,rhs=c))""") 152 | 153 | 154 | def test_chained_op(): 155 | check("a + b + c", 156 | """BinaryOp(op=, 157 | lhs=BinaryOp(op=, lhs=a, rhs=b), 158 | rhs=c)""") 159 | 160 | 161 | def test_matrix_arith(): 162 | check("self.Whz@h + Uxz@x + bz", 163 | """ 164 | BinaryOp(op=, 165 | lhs=BinaryOp(op=, 166 | lhs=BinaryOp(op=,lhs=Member(op=,obj=self,member=Whz),rhs=h), 167 | 
rhs=BinaryOp(op=,lhs=Uxz,rhs=x)), 168 | rhs=bz) 169 | """) 170 | 171 | def test_kwarg(): 172 | check("torch.relu(torch.rand(size=(2000,)))", 173 | """ 174 | Call(func=Member(op=,obj=torch,member=relu), 175 | args=[Call(func=Member(op=,obj=torch,member=rand), 176 | args=[Assign(op=,lhs=size,rhs=TupleLiteral(elems=[2000]))])])""") -------------------------------------------------------------------------------- /images/sample-2.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | leaf140728778300816 15 | 16 |     17 | 2x1 18 |     19 | 20 | 21 |     22 | b 23 |     24 | 25 | 26 | 27 | leaf140728778299856 28 | 29 | = 30 | 31 | 32 | 33 | 34 | leaf140728778299376 35 | 36 |     37 | 2x2 38 |     39 | 40 | 41 |     42 | W 43 |     44 | 45 | 46 | 47 | 48 | leaf140728778300384 49 | 50 | @ 51 | 52 | 53 | 54 | 55 | leaf140728778300960 56 | 57 |     58 | 2x1 59 |     60 | 61 | 62 |     63 | b 64 |     65 | 66 | 67 | 68 | 69 | leaf140728778300912 70 | 71 | + 72 | 73 | 74 | 75 | 76 | leaf140728778301008 77 | 78 |     79 | 2 80 |     81 | 82 | 83 |     84 | h 85 |     86 | 87 | 88 | 89 | 90 | leaf140728778301104 91 | 92 | . 
93 | 94 | 95 | 96 | 97 | leaf140728778301296 98 | 99 | dot 100 | 101 | 102 | 103 | 104 | leaf140728778301392 105 | 106 | ( 107 | 108 | 109 | 110 | 111 | leaf140728778298512 112 | 113 |     114 | 2 115 |     116 | 117 | 118 |     119 | h 120 |     121 | 122 | 123 | 124 | 125 | leaf140728640009264 126 | 127 | ) 128 | 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /tsensor/ast.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | import tsensor 25 | 26 | # Parse tree definitions 27 | # I found package ast in python3 lib after I built this. whoops. No biggie. 28 | # This tree structure is easier to visit for my purposes here. Also lets me 29 | # control the kinds of statements I process. 
class ParseTreeNode:
    """
    Base class of every node in the tsensor parse tree.

    A node remembers the parser that created it (so the original source text
    can be recovered from token character offsets), the start/stop tokens that
    delimit it, and the value computed for it by the most recent eval().
    Subclasses override eval() so that operands are evaluated first; that way
    the innermost failing subexpression is the one reported via IncrEvalTrap.
    """
    def __init__(self, parser):
        self.parser = parser # which parser object created this node;
                             # useful for getting access to the code string from a token
        self.value = None    # used during evaluation
        self.start = None    # start token
        self.stop = None     # end token

    def eval(self, frame):
        """
        Evaluate the expression represented by this (sub)tree in context of frame.
        Trap any exception raised while evaluating and remember which operation
        in this tree caused it: the exception is re-raised as IncrEvalTrap(self),
        chained to the original exception.

        The source text of this node (str(self)) is evaluated with the frame's
        f_globals and f_locals. The result is stored in self.value and returned.
        """
        try:
            self.value = eval(str(self), frame.f_globals, frame.f_locals)
        except BaseException as e:
            # BaseException (not Exception) so that *anything* raised by the
            # user's expression is attributed to this node.
            raise IncrEvalTrap(self) from e
        return self.value

    @property
    def optokens(self): # the associated token if atom or representative token if operation
        return None

    @property
    def kids(self):
        """Child nodes in left-to-right order; the base node has none."""
        return []

    def clarify(self):
        """Return a 'Cause: ...' explanation string, or None if this node has nothing to add."""
        return None

    def __str__(self):
        # Extract text from the original code string using token character indexes
        return self.parser.code[self.start.cstart_idx:self.stop.cstop_idx]

    def __repr__(self):
        """
        Return ClassName(field=repr,...) built from the instance attributes in
        definition order. Bookkeeping attributes (parser, start/stop and bracket
        tokens) are omitted, and 'value' is shown only once it has been computed.
        """
        hidden = {'start', 'stop', 'lbrack', 'lparen', 'parser'}
        fields = {k: v for k, v in self.__dict__.items() if k not in hidden}
        args = [
            name + '=' + repr(fields[name])
            for name in fields
            if name != 'value' or fields['value'] is not None
        ]
        return f"{self.__class__.__name__}({','.join(args)})"


class Assign(ParseTreeNode):
    """Assignment-like node: lhs <op> rhs."""
    def __init__(self, parser, op, lhs, rhs, start, stop):
        super().__init__(parser)
        self.op, self.lhs, self.rhs = op, lhs, rhs
        self.start, self.stop = start, stop

    def eval(self, frame):
        """Evaluate only the right-hand side and record its value on this node and on lhs."""
        self.value = self.rhs.eval(frame)
        # Don't eval this node as it causes side effect of making actual assignment to lhs
        self.lhs.value = self.value
        return self.value

    @property
    def optokens(self):
        return [self.op]

    @property
    def kids(self):
        return [self.lhs, self.rhs]


class Call(ParseTreeNode):
    """Function or method call: func(args)."""
    def __init__(self, parser, func, lparen, args, start, stop):
        super().__init__(parser)
        self.func = func
        self.lparen = lparen
        self.args = args
        self.start, self.stop = start, stop

    def eval(self, frame):
        """Evaluate the function expression, then each argument, then the call itself."""
        self.func.eval(frame)
        for a in self.args:
            a.eval(frame)
        return super().eval(frame)

    def clarify(self):
        """
        Describe the call and the shape of every argument that has one.
        Relies on the argument values computed by a previous eval().
        """
        arg_msgs = []
        for a in self.args:
            ashape = tsensor.analysis._shape(a.value)
            if ashape:
                arg_msgs.append(f"arg {a} w/shape {ashape}")
        if not arg_msgs:
            return f"Cause: {self}"
        return f"Cause: {self} tensor " + ', '.join(arg_msgs)

    @property
    def optokens(self):
        f = None # assume complicated like a[i](args) with weird func expr
        if isinstance(self.func, Member):
            f = self.func.member
        elif isinstance(self.func, Atom):
            f = self.func
        if f is not None:
            return [f.token, self.lparen, self.stop]
        return [self.lparen, self.stop]

    @property
    def kids(self):
        return [self.func] + self.args


class Return(ParseTreeNode):
    """Return statement; result is the list of returned expressions."""
    def __init__(self, parser, result, start, stop):
        super().__init__(parser)
        self.result = result
        self.start, self.stop = start, stop

    def eval(self, frame):
        """Evaluate every returned expression; a single result is unwrapped from the list."""
        self.value = [a.eval(frame) for a in self.result]
        if len(self.value) == 1:
            self.value = self.value[0]
        return self.value

    @property
    def optokens(self):
        return [self.start]

    @property
    def kids(self):
        return self.result


class Index(ParseTreeNode):
    """Subscript expression: arr[index, ...]."""
    def __init__(self, parser, arr, lbrack, index, start, stop):
        super().__init__(parser)
        self.arr = arr
        self.lbrack = lbrack
        self.index = index
        self.start, self.stop = start, stop

    def eval(self, frame):
        """Evaluate the indexed expression, then each index element, then the subscript itself."""
        self.arr.eval(frame)
        for i in self.index:
            i.eval(frame)
        return super().eval(frame)

    @property
    def optokens(self):
        return [self.lbrack, self.stop]

    @property
    def kids(self):
        return [self.arr] + self.index


class Member(ParseTreeNode):
    """Attribute access: obj.member."""
    def __init__(self, parser, op, obj, member, start, stop):
        super().__init__(parser)
        self.op = op # always DOT
        self.obj = obj
        self.member = member
        self.start, self.stop = start, stop

    def eval(self, frame):
        self.obj.eval(frame)
        # don't eval member as it's just a name to look up in obj
        return super().eval(frame)

    @property
    def optokens(self): # the associated token if atom or representative token if operation
        return [self.op]

    @property
    def kids(self):
        return [self.obj, self.member]


class BinaryOp(ParseTreeNode):
    """Binary operation: lhs <op> rhs."""
    def __init__(self, parser, op, lhs, rhs, start, stop):
        super().__init__(parser)
        self.op, self.lhs, self.rhs = op, lhs, rhs
        self.start, self.stop = start, stop

    def eval(self, frame):
        """Evaluate both operands first, then the whole operation."""
        self.lhs.eval(frame)
        self.rhs.eval(frame)
        return super().eval(frame)

    def clarify(self):
        """
        Describe the operator and the shape of each operand that has one.
        Relies on the operand values computed by a previous eval(). When
        neither operand has a shape, fall back to naming the whole expression
        (same fallback as Call.clarify) instead of emitting a dangling
        "... on tensor " with nothing after it.
        """
        opnd_msgs = []
        lshape = tsensor.analysis._shape(self.lhs.value)
        rshape = tsensor.analysis._shape(self.rhs.value)
        if lshape:
            opnd_msgs.append(f"operand {self.lhs} w/shape {lshape}")
        if rshape:
            opnd_msgs.append(f"operand {self.rhs} w/shape {rshape}")
        if not opnd_msgs:
            return f"Cause: {self}"
        return f"Cause: {self.op} on tensor " + ' and '.join(opnd_msgs)

    @property
    def optokens(self): # the associated token if atom or representative token if operation
        return [self.op]

    @property
    def kids(self):
        return [self.lhs, self.rhs]


class UnaryOp(ParseTreeNode):
    """Unary operation: <op> opnd."""
    def __init__(self, parser, op, opnd, start, stop):
        super().__init__(parser)
        self.op = op
        self.opnd = opnd
        self.start, self.stop = start, stop

    def eval(self, frame):
        self.opnd.eval(frame)
        return super().eval(frame)

    @property
    def optokens(self):
        return [self.op]

    @property
    def kids(self):
        return [self.opnd]


class ListLiteral(ParseTreeNode):
    """List literal: [elems...]."""
    def __init__(self, parser, elems, start, stop):
        super().__init__(parser)
        self.elems = elems
        self.start, self.stop = start, stop

    def eval(self, frame):
        """Evaluate each element, then the literal as a whole."""
        for i in self.elems:
            i.eval(frame)
        return super().eval(frame)

    @property
    def kids(self):
        return self.elems


class TupleLiteral(ParseTreeNode):
    """Tuple literal: (elems...)."""
    def __init__(self, parser, elems, start, stop):
        super().__init__(parser)
        self.elems = elems
        self.start, self.stop = start, stop

    def eval(self, frame):
        """Evaluate each element, then the literal as a whole."""
        for i in self.elems:
            i.eval(frame)
        return super().eval(frame)

    @property
    def kids(self):
        return self.elems


class SubExpr(ParseTreeNode):
    # record parens for later display to keep precedence
    def __init__(self, parser, e, start, stop):
        super().__init__(parser)
        self.e = e
        self.start, self.stop = start, stop

    def eval(self, frame):
        self.value = self.e.eval(frame)
        return self.value # don't re-evaluate

    @property
    def optokens(self):
        return [self.start, self.stop]

    @property
    def kids(self):
        return [self.e]


class Atom(ParseTreeNode):
    """Leaf node wrapping a single token."""
    def __init__(self, parser, token):
        super().__init__(parser)
        self.token = token
        self.start, self.stop = token, token

    def eval(self, frame):
        if self.token.type == tsensor.parsing.COLON:
            return ':' # fake a value here
        return super().eval(frame)

    def __repr__(self):
        return self.token.value
286 | 287 | def postorder(t): 288 | nodes = [] 289 | _postorder(t, nodes) 290 | return nodes 291 | 292 | 293 | def _postorder(t, nodes): 294 | if t is None: 295 | return 296 | for sub in t.kids: 297 | _postorder(sub, nodes) 298 | nodes.append(t) 299 | 300 | 301 | def leaves(t): 302 | nodes = [] 303 | _leaves(t, nodes) 304 | return nodes 305 | 306 | 307 | def _leaves(t, nodes): 308 | if t is None: 309 | return 310 | if len(t.kids) == 0: 311 | nodes.append(t) 312 | return 313 | for sub in t.kids: 314 | _leaves(sub, nodes) 315 | 316 | 317 | def walk(t, pre=lambda x: None, post=lambda x: None): 318 | if t is None: 319 | return 320 | pre(t) 321 | for sub in t.kids: 322 | walk(sub, pre, post) 323 | post(t) 324 | 325 | 326 | class IncrEvalTrap(BaseException): 327 | """ 328 | Used during re-evaluation of python line that threw exception to trap which 329 | subexpression caused the problem. 330 | """ 331 | def __init__(self, offending_expr): 332 | self.offending_expr = offending_expr # where in tree did we get exception? 333 | -------------------------------------------------------------------------------- /tsensor/parsing.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 
class Token:
    """
    My own version of a token, with content copied from Python's TokenInfo object.

    All constructor arguments are stored verbatim as attributes of the same name.
    """
    def __init__(self, type, value,
                 index,       # token index
                 cstart_idx,  # char start
                 cstop_idx,   # one past char end index so text[start_idx:stop_idx] works
                 line):
        self.type, self.value, self.index, self.cstart_idx, self.cstop_idx, self.line = \
            type, value, index, cstart_idx, cstop_idx, line

    def __repr__(self):
        return f"<{token.tok_name[self.type]}:{self.value},{self.cstart_idx}:{self.cstop_idx}>"

    def __str__(self):
        return self.value


def mytokenize(s):
    """
    Use Python's tokenizer to lex s and collect my own token objects.

    Only NUMBER, STRING, NAME, OP and ENDMARKER tokens are kept; each kept token
    records its exact type (e.g. PLUS rather than OP), its index in the returned
    list, its start/stop columns and its line number.

    For complete input the list ends with the ENDMARKER token, whose text is
    forced to "". If the tokenizer gives up with TokenError (incomplete input,
    such as the first physical line of a bracketed multi-line statement), the
    tokens seen so far are returned terminated by an empty ERRORTOKEN instead.
    No grammar rule accepts that token, so PyExprParser.parse() rejects the
    input with an ordinary SyntaxError, which its hush_errors handling can
    swallow. Previously the TokenError escaped from the parser's constructor.

    :param s: source text of (normally) a single Python statement
    :return: non-empty list of Token objects
    """
    from tokenize import TokenError  # local import; only needed here

    tokens = []
    last_line = 0
    complete = True
    try:
        for tok in tokenize(BytesIO(s.encode('utf-8')).readline):
            if tok.type in {NUMBER, STRING, NAME, OP, ENDMARKER}:
                tokens.append(Token(tok.exact_type, tok.string, len(tokens),
                                    tok.start[1],  # start column
                                    tok.end[1],    # one past end column
                                    tok.start[0])) # line number
                last_line = tok.start[0]
            # everything else (encoding, newlines, comments, indentation) is ignored
    except TokenError:
        # Not swallowed silently: the ERRORTOKEN appended below makes the parser fail.
        complete = False

    if complete and tokens and tokens[-1].type == ENDMARKER:
        # It leaves ENDMARKER on end. set text to ""
        tokens[-1].value = ""
    else:
        tokens.append(Token(token.ERRORTOKEN, "", len(tokens), 0, 0, last_line + 1))
    return tokens
def parse(self):
    """
    Parse the token stream and return the root of the tree, or None.

    Only assignments, `return` statements and bare expressions are processed:
    a statement whose first token is any keyword other than `return` yields
    None without parsing. A SyntaxError raised while parsing is converted to
    a None result when self.hush_errors is set and propagates otherwise.
    """
    first = self.tokens[0].value
    if first != 'return' and keyword.iskeyword(first):
        return None  # some other kind of statement; not handled here
    try:
        tree = self.assignment_or_return_or_expr()
        self.match(ENDMARKER)  # the whole line must have been consumed
    except SyntaxError:
        if not self.hush_errors:
            raise
        return None
    return tree
def postexpr(self):
    """
    Parse an atom followed by any number of call, index or member suffixes:

        atom ( '(' arglist? ')' | '[' exprlist ']' | '.' NAME )*

    Each suffix wraps the tree built so far in a Call, Index or Member node
    that spans from the first token of the atom to the last token consumed.
    """
    start = self.LT(1)
    root = self.atom()
    while True:
        la = self.LA(1)
        if la == LPAR:
            lp = self.match(LPAR)
            # a call may have no arguments at all
            args = self.arglist() if self.LA(1) != RPAR else []
            stop = self.match(RPAR)
            root = tsensor.ast.Call(self, root, lp, args, start, stop)
        elif la == LSQB:
            lb = self.match(LSQB)
            indexes = self.exprlist()
            stop = self.match(RSQB)
            root = tsensor.ast.Index(self, root, lb, indexes, start, stop)
        elif la == DOT:
            op = self.match(DOT)
            member = tsensor.ast.Atom(self, self.match(NAME))
            stop = self.LT(-1)
            root = tsensor.ast.Member(self, op, root, member, start, stop)
        else:
            return root
def listatom(self):
    """
    Parse a list literal, '[' exprlist? ']', into a ListLiteral node.

    An empty list `[]` is accepted and produces a ListLiteral with no
    elements. exprlist() insists on at least one expression, so it is only
    called when something other than the closing bracket follows, mirroring
    how postexpr() treats an empty call argument list.
    """
    start = self.match(LSQB)
    elems = []
    if self.LA(1) != RSQB:
        elems = self.exprlist()
    stop = self.match(RSQB)
    return tsensor.ast.ListLiteral(self, elems, start, stop)
def parse(statement:str, hush_errors=True):
    """
    Parse statement and return a (tree, tokens) pair: the root of the parse
    tree and the parser's token list.

    The tree is None when the statement is invalid or is code this parser
    does not handle; such parsing errors are ignored unless hush_errors is
    False, in which case the SyntaxError propagates to the caller.
    """
    parser = tsensor.parsing.PyExprParser(statement, hush_errors=hush_errors)
    root = parser.parse()
    return root, parser.tokens
93 | 94 | 95 | 96 | 97 | leaf140396596004224 98 | 99 | dot 100 | 101 | 102 | 103 | 104 | leaf140396596005760 105 | 106 | ( 107 | 108 | 109 | 110 | 111 | leaf140396599103696 112 | 113 | h 114 | 115 | 116 | 117 | 118 | leaf140396599104896 119 | 120 | ) 121 | 122 | 123 | 124 | 125 | leaf140396599104944 126 | 127 | + 128 | 129 | 130 | 131 | 132 | leaf140396599104608 133 | 134 | z 135 | 136 | 137 | 138 | 139 | node140396598701936 140 | 141 | @ 142 | 2kx1 143 | <float64> 144 | 145 | 146 | 147 | node140396598701936->leaf140395790355424 148 | 149 | 150 | 151 | 152 | 153 | node140396598701936->leaf140395790355952 154 | 155 | 156 | 157 | 158 | 159 | node140396599104128 160 | 161 | + 162 | 1m 163 | <int64> 164 | 165 | 166 | 167 | node140396599104128->leaf140395790354752 168 | 169 | 170 | 171 | 172 | 173 | node140396599104128->leaf140395790352688 174 | 175 | 176 | 177 | 178 | 179 | node140396599104800 180 | 181 | . 182 | 183 | 184 | 185 | node140396599104800->leaf140396596004224 186 | 187 | 188 | 189 | 190 | 191 | node140396599104800->node140396599104128 192 | 193 | 194 | 195 | 196 | 197 | node140396599105136 198 | 199 | dot() 200 | 201 | 202 | 203 | node140396599105136->leaf140396599103696 204 | 205 | 206 | 207 | 208 | 209 | node140396599105136->node140396599104800 210 | 211 | 212 | 213 | 214 | 215 | node140396599104416 216 | 217 | + 218 | 2kx1 219 | <float64> 220 | 221 | 222 | 223 | node140396599104416->node140396598701936 224 | 225 | 226 | 227 | 228 | 229 | node140396599104416->node140396599105136 230 | 231 | 232 | 233 | 234 | 235 | node140396599104320 236 | 237 | + 238 | 2kx1 239 | <complex128> 240 | 241 | 242 | 243 | node140396599104320->leaf140396599104608 244 | 245 | 246 | 247 | 248 | 249 | node140396599104320->node140396599104416 250 | 251 | 252 | 253 | 254 | 255 | node140396599104368 256 | 257 | = 258 | 2kx1 259 | <complex128> 260 | 261 | 262 | 263 | node140396599104368->leaf140396598700880 264 | 265 | 266 | 267 | 268 | 269 | node140396599104368->node140396599104320 
270 | 271 | 272 | 273 | 274 | 275 | -------------------------------------------------------------------------------- /images/mm.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 2021-12-11T13:11:04.999774 11 | image/svg+xml 12 | 13 | 14 | Matplotlib v3.3.4, https://matplotlib.org/ 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 31 | 32 | 33 | 34 | 40 | 41 | 42 | 48 | 49 | 50 | 53 | 54 | 55 | 58 | 59 | 60 | 61 | 62 | 63 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 400 | 433 | 450 | 451 | 452 | 453 | 454 | 455 | 456 | 457 | 458 | 459 | 460 | 471 | 496 | 497 | 498 | 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 515 | 535 | 541 | 568 | 604 | 628 | 658 | 675 | 684 | 685 | 686 | 687 | 688 | 689 | 690 | 691 | 692 | 693 | 694 | 695 | 696 | 697 | 698 | 699 | 700 | 701 | 702 | 703 | 704 | 705 | 706 | 707 | 708 | 732 | 733 | 734 | 735 | 736 | 737 | 738 | 739 | 740 | 741 | 742 | 743 | 744 | 745 | 746 | 747 | 748 | 749 | 750 | 751 | 752 | 753 | 754 | 755 | 756 | 757 | 758 | 759 | 760 | -------------------------------------------------------------------------------- /tsensor/analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 Terence Parr 5 | 6 | Permission is 
hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | import os 25 | import sys 26 | import traceback 27 | import inspect 28 | import hashlib 29 | from pathlib import Path 30 | 31 | import matplotlib.pyplot as plt 32 | 33 | import tsensor 34 | 35 | 36 | class clarify: 37 | # Prevent nested clarify() calls from processing exceptions. 38 | # See https://github.com/parrt/tensor-sensor/issues/18 39 | # Probably will fail with Python `threading` package due to this class var 40 | # but only if multiple threads call clarify(). 41 | # Multiprocessing forks new processes so not a problem. Each vm has it's own class var. 
def __init__(self,
             fontname=('Consolas', 'DejaVu Sans Mono'), fontsize=13,
             dimfontname='Arial', dimfontsize=9, char_sep_scale=1.8, fontcolor='#444443',
             underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
             show:(None,'viz')='viz',
             hush_errors=True,
             dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
    """
    Augment tensor-related exceptions generated from numpy, pytorch, and tensorflow.
    Also display a visual representation of the offending Python line that
    shows the shape of tensors referenced by the code. All you have to do is wrap
    the outermost level of your code and clarify() will activate upon exception.

    Visualizations pop up in a separate window unless running from a notebook,
    in which case the visualization appears as part of the cell execution output.

    There is no runtime overhead associated with clarify() unless an exception occurs.

    The offending code is executed a second time, to identify which sub expressions
    are to blame. This implies that code with side effects could conceivably cause
    a problem, but since an exception has been generated, results are suspicious
    anyway.

    The constructor only records its arguments on the instance; all of the
    work happens in __exit__.

    Example:

    import numpy as np
    import tsensor

    b = np.array([9, 10]).reshape(2, 1)
    with tsensor.clarify():
        np.dot(b,b) # tensor code or call to a function with tensor code

    See examples.ipynb for more examples.

    :param fontname: The name of the font used to display Python code
    :param fontsize: The font size used to display Python code; default is 13.
                     Also use this to increase the size of the generated figure;
                     larger font size means larger image.
    :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes
    :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes
    :param char_sep_scale: It is notoriously difficult to discover how wide and tall
                           text is when plotted in matplotlib. In fact there's probably
                           no hope to discover this information accurately in all cases.
                           Certainly, I gave up after spending huge effort. We have a
                           situation here where the font should be constant width, so
                           we can just use a simple scalar times the font size to get
                           a reasonable approximation to the width and height of a
                           character box; the default of 1.8 seems to work reasonably
                           well for a wide range of fonts, but you might have to tweak it
                           when you change the font size.
    :param fontcolor: The color of the Python code.
    :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey
    :param ignored_color: The de-highlighted color for deemphasizing code not involved in an erroneous sub expression
    :param error_op_color: The color to use for characters associated with the erroneous operator
    :param hush_errors: Normally, error messages from true syntax errors but also
                        unhandled code caught by my parser are ignored. Turn this off
                        to see what the error messages are coming from my parser.
    :param show: Show visualization upon tensor error if show='viz'.
    :param dtype_colors: map from dtype w/o precision like 'int' to color
    :param dtype_precisions: list of bit precisions to colorize, such as [32,64,128]
    :param dtype_alpha_range: all tensors of the same type are drawn to the same color,
                              and the alpha channel is used to show precision; the
                              smaller the bit size, the lower the alpha channel. You
                              can play with the range to get better visual dynamic range
                              depending on how many precisions you want to display.

    NOTE(review): the two entries below describe `ax` and `dpi`, which are NOT
    parameters of this constructor. Presumably they belong to the lower-level
    visualization API; confirm, then move or remove them.

    ax: If not none, this is the matplotlib drawing region in which to draw the visualization
    dpi: This library tries to generate SVG files, which are vector graphics not
         2D arrays of pixels like PNG files. However, it needs to know how to
         compute the exact figure size to remove padding around the visualization.
         Matplotlib uses inches for its figure size and so we must convert
         from pixels or data units to inches, which means we have to know what the
         dots per inch, dpi, is for the image.
    """
    # Store every option under an attribute of the same name for use in __exit__.
    self.show, self.fontname, self.fontsize, self.dimfontname, self.dimfontsize, \
    self.char_sep_scale, self.fontcolor, self.underline_color, self.ignored_color, \
    self.error_op_color, self.hush_errors, \
    self.dtype_colors, self.dtype_precisions, self.dtype_alpha_range = \
        show, fontname, fontsize, dimfontname, dimfontsize, \
        char_sep_scale, fontcolor, underline_color, ignored_color, \
        error_op_color, hush_errors, \
        dtype_colors, dtype_precisions, dtype_alpha_range
def __exit__(self, exc_type, exc_value, exc_traceback):
    """
    Leave the clarify() block. Only the outermost clarify() reacts to an
    exception: if it comes from a tensor library or looks tensor-related,
    visualize the offending line (when show=='viz') and augment the message.
    Always returns None, so the exception keeps propagating.
    """
    clarify.nesting -= 1
    if clarify.nesting > 0 or exc_type is None:
        # nested inside another clarify(), or nothing went wrong
        return

    exc_frame, lib_entry_frame = tensor_lib_entry_frame(exc_traceback)
    if lib_entry_frame is None and not is_interesting_exception(exc_value):
        return

    module, name, filename, line, code = info(exc_frame)
    if code is None:
        return

    self.view = tsensor.viz.pyviz(code, exc_frame,
                                  self.fontname, self.fontsize, self.dimfontname,
                                  self.dimfontsize,
                                  self.char_sep_scale, self.fontcolor,
                                  self.underline_color, self.ignored_color,
                                  self.error_op_color,
                                  hush_errors=self.hush_errors,
                                  dtype_colors=self.dtype_colors,
                                  dtype_precisions=self.dtype_precisions,
                                  dtype_alpha_range=self.dtype_alpha_range)
    if self.view is None:
        # Ignore if we can't process code causing exception (I use a subparser)
        return
    if self.show == 'viz':
        self.view.show()
    augment_exception(exc_value, self.view.offending_expr)
using from numpy, pytorch, 174 | and tensorflow. The shape of tensors referenced by the code are displayed. 175 | 176 | Visualizations pop up in a separate window unless running from a notebook, 177 | in which case the visualization appears as part of the cell execution output. 178 | 179 | There is heavy runtime overhead associated with explain() as every line 180 | is executed twice: once by explain() and then another time by the interpreter 181 | as part of normal execution. 182 | 183 | Expressions with side effects can easily generate incorrect results. Due to 184 | this and the overhead, you should limit the use of this to code you're trying 185 | to debug. Assignments are not evaluated by explain so code `x = ...` causes 186 | an assignment to x just once, during normal execution. This explainer 187 | knows the value of x and will display it but does not assign to it. 188 | 189 | Upon exception, execution will stop as usual but, like clarify(), explain() 190 | will augment the exception to indicate the offending sub expression. Further, 191 | the visualization will deemphasize code not associated with the offending 192 | sub expression. The sizes of relevant tensor values are still visualized. 193 | 194 | Example: 195 | 196 | import numpy as np 197 | import tsensor 198 | 199 | b = np.array([9, 10]).reshape(2, 1) 200 | with tsensor.explain(): 201 | b + b # tensor code or call to a function with tensor code 202 | 203 | See examples.ipynb for more examples. 204 | 205 | :param fontname: The name of the font used to display Python code 206 | :param fontsize: The font size used to display Python code; default is 13. 207 | Also use this to increase the size of the generated figure; 208 | larger font size means larger image. 
209 | :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes 210 | :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes 211 | :param char_sep_scale: It is notoriously difficult to discover how wide and tall 212 | text is when plotted in matplotlib. In fact there's probably, 213 | no hope to discover this information accurately in all cases. 214 | Certainly, I gave up after spending huge effort. We have a 215 | situation here where the font should be constant width, so 216 | we can just use a simple scaler times the font size to get 217 | a reasonable approximation to the width and height of a 218 | character box; the default of 1.8 seems to work reasonably 219 | well for a wide range of fonts, but you might have to tweak it 220 | when you change the font size. 221 | :param fontcolor: The color of the Python code. 222 | :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey 223 | :param ignored_color: The de-highlighted color for deemphasizing code not involved in an erroneous sub expression 224 | :param error_op_color: The color to use for characters associated with the erroneous operator 225 | :param ax: If not none, this is the matplotlib drawing region in which to draw the visualization 226 | :param dpi: This library tries to generate SVG files, which are vector graphics not 227 | 2D arrays of pixels like PNG files. However, it needs to know how to 228 | compute the exact figure size to remove padding around the visualization. 229 | Matplotlib uses inches for its figure size and so we must convert 230 | from pixels or data units to inches, which means we have to know what the 231 | dots per inch, dpi, is for the image. 232 | :param hush_errors: Normally, error messages from true syntax errors but also 233 | unhandled code caught by my parser are ignored. 
def __exit__(self, exc_type, exc_value, exc_traceback):
    """
    Stop tracing and, if the block raised a tensor-related exception, augment
    its message with the offending sub expression.

    At this point we have already tried to visualize the statement. If there
    was no error, the visualization will look normal, but a matrix operation
    error will show the erroneous operator highlighted. That was artificial
    execution of the code. Now the VM has executed the statement for real and
    has found the same exception, so make sure to augment the message with
    causal information.

    Always returns None, so the exception keeps propagating.
    """
    sys.settrace(None)
    if exc_type is None:
        return
    exc_frame, lib_entry_frame = tensor_lib_entry_frame(exc_traceback)
    if lib_entry_frame is None and not is_interesting_exception(exc_value):
        return
    module, name, filename, line, code = info(exc_frame)
    if code is None:
        return
    # We've already displayed picture so just augment message
    root, tokens = tsensor.parsing.parse(code)
    if root is None:
        # Could be syntax error in statement or code I can't handle
        return
    offending_expr = None
    try:
        root.eval(exc_frame)
    except tsensor.ast.IncrEvalTrap as e:
        offending_expr = e.offending_expr
    # If re-evaluation did not trap anything there is no sub expression to
    # blame. Calling augment_exception() with None would raise AttributeError
    # here and replace the user's real exception, so leave it untouched.
    if offending_expr is not None:
        augment_exception(exc_value, offending_expr)
def line_listener(self, module, name, filename, line, info, frame):
    """
    Handle one 'line' trace event: visualize the statement about to execute
    unless it has been seen before or the parser cannot handle it.

    :param module, name, filename, line: location of the event; not used here
    :param info: frame info whose code_context supplies the source line
    :param frame: execution frame in which the statement is visualized
    """
    context = info.code_context
    if not context:
        # No source text is available for this frame (code_context is None
        # or empty), so there is nothing to parse; indexing it would raise
        # inside the trace function.
        return
    code = context[0].strip()
    if code.startswith("sys.settrace(None)"):
        return

    # Don't generate a statement visualization more than once
    h = hash(code)
    if h in self.done:
        return
    self.done.add(h)

    p = tsensor.parsing.PyExprParser(code)
    t = p.parse()
    if t is not None:
        self.linecount += 1
        self.viz_statement(code, frame)
@staticmethod
def hash(statement):
    """
    Return the MD5 hex digest of a statement's UTF-8 encoded text.

    The digest is meant as an identifier so that a visualization is not
    generated more than once; for now, the code of a statement is assumed
    to identify it uniquely.
    """
    encoded = statement.encode('utf-8')
    digest = hashlib.md5()
    digest.update(encoded)
    return digest.hexdigest()
def augment_exception(exc_value, subexpr):
    """
    Append the explanation produced by subexpr.clarify() to the message of
    an existing exception, reusing the exception object.

    :param exc_value: the exception whose message is overwritten in place
    :param subexpr: object answering clarify(), which returns the extra
                    text or None if there is nothing to add
    """
    explanation = subexpr.clarify()
    augment = "" if explanation is None else explanation
    # Reuse exception but overwrite the message
    if hasattr(exc_value, "_message"):
        # presumably TensorFlow-style errors, which keep their text in
        # _message; fall back to _message itself if there is no .message
        current = getattr(exc_value, "message", exc_value._message)
        exc_value._message = str(current) + "\n" + augment
    elif exc_value.args:
        # args[0] is not necessarily a string (e.g. KeyError(3)); as before,
        # the result is a single-element args holding the combined message
        exc_value.args = (str(exc_value.args[0]) + "\n" + augment,)
    else:
        # Exception raised without any arguments: nothing to append to
        exc_value.args = (augment,)


def is_interesting_exception(e):
    """
    Return True if e looks like a tensor-related error: either its class is
    defined in a module whose name starts with "tensorflow", or its message
    contains one of the sentinel substrings.
    """
    if e.__class__.__module__.startswith("tensorflow"):
        return True
    sentinels = ('matmul', 'THTensorMath', 'tensor', 'tensors', 'dimension',
                 'not aligned', 'size mismatch', 'shape', 'shapes', 'matrix',
                 'call to _th_addmm')
    if len(e.args) == 0:
        # Not every exception has a .message attribute; treat it as empty
        msg = getattr(e, "message", "")
    else:
        msg = e.args[0]
    msg = str(msg)  # first argument may be any object, not just a string
    return any(s in msg for s in sentinels)
def info(frame):
    """
    Summarize an execution frame.

    :param frame: the execution frame to describe
    :return: tuple (module, name, filename, line, code) where module is the
             value of __name__ in the frame's globals (None if absent),
             name is the function name, filename and line locate the
             current line, and code is the stripped text of that line or
             None if the source text is not available.
    """
    # A frame object never has a __name__ attribute of its own; the module
    # name lives in the frame's globals.
    module = frame.f_globals.get('__name__')
    frame_info = inspect.getframeinfo(frame)
    context = frame_info.code_context
    code = context[0].strip() if context else None
    filename, line = frame_info.filename, frame_info.lineno
    name = frame_info.function
    return module, name, filename, line, code
def istensor(x):
    """Return True if _shape() reports a shape for x, i.e., x is not a scalar."""
    return _shape(x) is not None


def _dtype(v) -> str:
    """
    Return the name of the element type of v as a string such as 'float32'.

    v is either an object with a dtype attribute or a dtype object itself
    (recognized by "dtype" appearing in its class name, ignoring case).
    For structured dtypes the field type names are joined with commas in
    field order. Returns None if no dtype can be found.
    """
    if hasattr(v, "dtype"):
        dtype = v.dtype
    elif "dtype" in v.__class__.__name__.lower():
        # Case-insensitive because dtype classes can be named like
        # 'Int32DType' as well as 'dtype'
        dtype = v
    else:
        return None

    if dtype.__class__.__module__ == "torch":
        # ugly but works
        return str(dtype).replace("torch.", "")
    names = getattr(dtype, "names", None)
    fields = getattr(dtype, "fields", None)
    if names is not None and fields is not None:
        # structured dtype: https://numpy.org/devdocs/user/basics.rec.html
        # Walk names rather than fields.values(): a field entry may carry a
        # third element (title), and titles can appear as extra keys.
        return ",".join([_dtype(fields[name][0]) for name in names])
    return dtype.name


def _shape(v):
    """
    Return the shape of v, or None if v has no usable shape or is a scalar.

    A shape is usable if v has a shape attribute that answers len().
    Zero-length shapes mean scalar and yield None. A torch Size is returned
    as a list; any other shape object is returned as is.
    """
    shape = getattr(v, "shape", None)
    if shape is None or not hasattr(shape, "__len__"):
        return None
    try:
        rank = len(shape)
    except (TypeError, ValueError):
        # Shape object defines __len__ but cannot answer it; pass the shape
        # through unchanged as was done before.
        return shape
    if rank == 0:
        return None  # scalar, regardless of which library produced it
    if shape.__class__.__module__ == "torch" and shape.__class__.__name__ == "Size":
        return list(shape)
    return shape
class DTypeColorInfo:
    """
    Track the colors for various types, the transparency range, and bit precisions.
    By default, green indicates floating-point, blue indicates integer, and orange
    indicates complex numbers. The more saturated the color (lower transparency),
    the higher the precision.
    """
    orangeish = '#FDD66C'
    limeish = '#A8E1B0'
    blueish = '#7FA4D3'
    grey = '#EFEFF0'
    default_dtype_colors = {'float': limeish, 'int': blueish, 'complex': orangeish, 'other': grey}
    default_dtype_precisions = [32, 64, 128]  # hard to see diff if we use [4, 8, 16, 32, 64, 128]
    default_dtype_alpha_range = (0.5, 1.0)  # use (0.1, 1.0) if more precision values

    def __init__(self, dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None):
        """
        :param dtype_colors: dict mapping a type name without precision, such as
                             'int', to a color name or hex RGB string
        :param dtype_precisions: list of bit precisions that get their own shade
        :param dtype_alpha_range: (low, high) alpha values spread over the precisions
        :raises TypeError: if dtype_colors is not a dict or its first value is
                           not a string
        Any argument left as None takes the class-level default.
        """
        if dtype_colors is None:
            dtype_colors = DTypeColorInfo.default_dtype_colors
        if dtype_precisions is None:
            dtype_precisions = DTypeColorInfo.default_dtype_precisions
        if dtype_alpha_range is None:
            dtype_alpha_range = DTypeColorInfo.default_dtype_alpha_range

        bad_colors = not isinstance(dtype_colors, dict)
        if not bad_colors and len(dtype_colors) > 0:
            bad_colors = not isinstance(list(dtype_colors.values())[0], str)
        if bad_colors:
            raise TypeError(
                "dtype_colors should be a dict mapping type name to color name or color hex RGB values."
            )

        self.dtype_colors, self.dtype_precisions, self.dtype_alpha_range = \
            dtype_colors, dtype_precisions, dtype_alpha_range

    def color(self, dtype):
        """
        Get color based on type and precision.

        Returns a list of RGB values plus alpha when both the type name and
        its precision are known; otherwise returns the 'other' color exactly
        as configured (the class-level grey if no 'other' entry exists).
        """
        fallback = self.dtype_colors.get('other', DTypeColorInfo.grey)
        dtype_name, dtype_precision = PyVizView._split_dtype_precision(dtype)
        # No trailing digits (e.g. 'bool') means there is no precision to shade by
        if dtype_name not in self.dtype_colors or not dtype_precision:
            return fallback
        dtype_precision = int(dtype_precision)
        if dtype_precision not in self.dtype_precisions:
            return fallback

        # to_rgb() accepts hex strings as well as color names and always
        # yields an RGB tuple
        rgb = mc.to_rgb(self.dtype_colors[dtype_name])
        precision_idx = self.dtype_precisions.index(dtype_precision)
        nshades = len(self.dtype_precisions)
        alphas = np.linspace(*self.dtype_alpha_range, nshades)
        alpha = alphas[precision_idx]
        return list(rgb) + [alpha]
@staticmethod
def _split_dtype_precision(s):
    """
    Split a string into (head, tail) where tail is the longest run of
    decimal digits 0-9 at the end of s and head is everything before it.
    Either part may be empty.
    """
    cut = len(s)
    # Walk backwards over the trailing digits
    while cut > 0 and s[cut - 1] in '0123456789':
        cut -= 1
    return s[:cut], s[cut:]
def savefig(self, filename):
    """
    Save viz in format according to file extension.

    While this view's matplotlib figure still exists it is rendered to
    filename. Once the figure has been closed, the previously saved file is
    copied instead, keeping the previous file extension so that the bytes
    and the extension stay consistent. In both cases self.filename ends up
    naming the file just written.

    :param filename: destination path
    :raises ValueError: if the figure is closed and nothing was saved before
    """
    if plt.fignum_exists(self.fignumber):
        # If the matplotlib figure is still active, save it. Address the
        # figure by number so this view's figure is saved even when another
        # figure has become the current one.
        self.filename = filename  # Remember the file so we can pull it back
        fig = plt.figure(self.fignumber)
        fig.savefig(filename, dpi=self.dpi, bbox_inches='tight', pad_inches=0)
        return

    # we have already closed it so try to copy to new filename from the previous
    if self.filename is None:
        raise ValueError("Figure is closed and was never saved; nothing to copy.")
    # Compare as plain strings so str and path objects for the same name match
    filename = os.fspath(filename)
    previous = os.fspath(self.filename)
    if filename != previous:
        f, ext = os.path.splitext(filename)
        prev_f, prev_ext = os.path.splitext(previous)
        if ext != prev_ext:
            print(f"File extension {ext} differs from previous {prev_ext}; uses previous.")
            ext = prev_ext
        filename = f+ext  # make sure that we don't copy raw bits and change the file extension to be inconsistent
        # Read everything before opening the target, so copying a file onto
        # itself is harmless
        with open(previous, 'rb') as src:
            img = src.read()
        with open(filename, 'wb') as dst:
            dst.write(img)
        self.filename = filename  # overwrite the filename with new name
def matrix_size(self, sh, ty):
    """
    Return (width, height) of the box drawn for a tensor of shape sh.

    A shape of exactly (1,) is delegated to vector_size(). With two or
    more dimensions, a dimension of size 1 among the first two gets the
    skinny vector extent, and a leading 1x1 gets twice that extent on both
    sides so it stays readable. Everything else is a square of matrix
    extent. ty is only passed along to vector_size().
    """
    rank = len(sh)
    if rank == 1 and sh[0] == 1:
        return self.vector_size(sh, ty)

    if rank > 1:
        one_row = sh[0] == 1
        one_col = sh[1] == 1
        if one_row and one_col:
            # 1x1 matrix extending into the screen: a bit wider than a vector
            side = 2 * self.vector_size_scaler * self.wchar
            return side, side
        if one_col:
            return self.vector_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar
        if one_row:
            return self.matrix_size_scaler * self.wchar, self.vector_size_scaler * self.wchar

    full = self.matrix_size_scaler * self.wchar
    return full, full
def draw_vector(self, ax, sub, sh, ty: str):
    """
    Draw the rectangle for a one-dimensional tensor centered under its
    subexpression, with its length above and its type name below.

    :param ax: matplotlib axes to draw into
    :param sub: subexpression node; its leftx and rightx give the horizontal
                extent of the expression text
    :param sh: shape of the tensor; only sh[0] is displayed
    :param ty: dtype name, used for the fill color and the type label
    """
    mid = (sub.leftx + sub.rightx) / 2
    w, h = self.vector_size(sh, ty)
    color = self.dtype_color_info.color(ty)
    rect1 = patches.Rectangle(xy=(mid - w/2, self.box_topy-h),
                              width=w,
                              height=h,
                              linewidth=self.linewidth,
                              facecolor=color,
                              edgecolor='grey',
                              fill=True)
    ax.add_patch(rect1)

    # Text above vector rectangle
    ax.text(mid, self.box_topy + self.dim_ypadding, self.nabbrev(sh[0]),
            horizontalalignment='center',
            fontname=self.dimfontname, fontsize=self.dimfontsize)
    # Type info at the bottom of everything. Raw string: the backslash of
    # \mathit must reach matplotlib's math text untouched.
    ax.text(mid, self.box_topy - self.hchar, r'<${\mathit{'+ty+'}}$>',
            verticalalignment='top', horizontalalignment='center',
            fontname=self.dimfontname, fontsize=self.dimfontsize-2)
@staticmethod
def nabbrev(n: int) -> str:
    """
    Abbreviate a dimension size for display: exact multiples of one million
    become e.g. '2m', exact multiples of one thousand become e.g. '3k', and
    anything else, including 0, is shown as the plain number.
    """
    if n == 0:
        # 0 is a multiple of everything; without this it would show as '0m'
        return '0'
    if n % 1_000_000 == 0:
        return str(n // 1_000_000)+'m'
    if n % 1_000 == 0:
        return str(n // 1_000)+'k'
    return str(n)
ignored_color='#B4B4B4', error_op_color='#A40227', 341 | ax=None, dpi=200, hush_errors=True, 342 | dtype_colors=None, dtype_precisions=None, dtype_alpha_range=None) -> PyVizView: 343 | """ 344 | Parse and evaluate the Python code in the statement string passed in using 345 | the indicated execution frame. The execution frame of the invoking function 346 | is used if frame is None. 347 | 348 | The visualization finds the smallest subexpressions that evaluate to 349 | tensors then underlies them and shows a box or rectangle representing 350 | the tensor dimensions. Boxes in blue (default) have two or more dimensions 351 | but rectangles in yellow (default) have one dimension with shape (n,). 352 | 353 | Upon tensor-related execution error, the offending self-expression is 354 | highlighted (by de-highlighting the other code) and the operator is shown 355 | using error_op_color. 356 | 357 | To adjust the size of the generated visualization to be smaller or bigger, 358 | decrease or increase the font size. 359 | 360 | :param statement: A string representing the line of Python code to visualize within an execution frame. 361 | :param frame: The execution frame in which to evaluate the statement. If None, 362 | use the execution frame of the invoking function 363 | :param fontname: The name of the font used to display Python code 364 | :param fontsize: The font size used to display Python code; default is 13. 365 | Also use this to increase the size of the generated figure; 366 | larger font size means larger image. 367 | :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes 368 | :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes 369 | :param char_sep_scale: It is notoriously difficult to discover how wide and tall 370 | text is when plotted in matplotlib. In fact there's probably, 371 | no hope to discover this information accurately in all cases. 
372 | Certainly, I gave up after spending huge effort. We have a 373 | situation here where the font should be constant width, so 374 | we can just use a simple scalar times the font size to get 375 | a reasonable approximation of the width and height of a 376 | character box; the default of 1.8 seems to work reasonably 377 | well for a wide range of fonts, but you might have to tweak it 378 | when you change the font size. 379 | :param fontcolor: The color of the Python code. 380 | :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey 381 | :param ignored_color: The de-highlighted color for de-emphasizing code not involved in an erroneous sub expression 382 | :param error_op_color: The color to use for characters associated with the erroneous operator 383 | :param ax: If not none, this is the matplotlib drawing region in which to draw the visualization 384 | :param dpi: This library tries to generate SVG files, which are vector graphics not 385 | 2D arrays of pixels like PNG files. However, it needs to know how to 386 | compute the exact figure size to remove padding around the visualization. 387 | Matplotlib uses inches for its figure size and so we must convert 388 | from pixels or data units to inches, which means we have to know what the 389 | dots per inch, dpi, is for the image. 390 | :param hush_errors: Normally, error messages from true syntax errors but also 391 | unhandled code caught by my parser are ignored. Turn this off 392 | to see what the error messages are coming from my parser. 393 | :param dtype_colors: map from dtype w/o precision like 'int' to color 394 | :param dtype_precisions: list of bit precisions to colorize, such as [32,64,128] 395 | :param dtype_alpha_range: all tensors of the same type are drawn to the same color, 396 | and the alpha channel is used to show precision; the 397 | smaller the bit size, the lower the alpha channel. 
You 398 | can play with the range to get better visual dynamic range 399 | depending on how many precisions you want to display. 400 | :return: Returns a PyVizView holding info about the visualization; from a notebook 401 | an SVG image will appear. Return none upon parsing error in statement. 402 | """ 403 | view = PyVizView(statement, fontname, fontsize, dimfontname, dimfontsize, char_sep_scale, dpi, 404 | dtype_colors, dtype_precisions, dtype_alpha_range) 405 | 406 | if frame is None: # use frame of caller if not passed in 407 | frame = sys._getframe().f_back 408 | root, tokens = tsensor.parsing.parse(statement, hush_errors=hush_errors) 409 | if root is None: 410 | print(f"Can't parse {statement}; root is None") 411 | # likely syntax error in statement or code I can't handle 412 | return None 413 | root_to_viz = root 414 | try: 415 | root.eval(frame) 416 | except tsensor.ast.IncrEvalTrap as e: 417 | root_to_viz = e.offending_expr 418 | view.offending_expr = e.offending_expr 419 | view.cause = e.__cause__ 420 | # Don't raise the exception; keep going to visualize code and erroneous 421 | # subexpressions. If this function is invoked from clarify() or explain(), 422 | # the statement will be executed and will fail again during normal execution; 423 | # an exception will be thrown at that time. Then explain/clarify 424 | # will update the error message 425 | subexprs = tsensor.analysis.smallest_matrix_subexpr(root_to_viz) 426 | if ax is None: 427 | fig, ax = plt.subplots(1, 1, dpi=dpi) 428 | else: 429 | fig = ax.figure 430 | view.fignumber = fig.number # track this so that we can determine if the figure has been closed 431 | 432 | ax.axis("off") 433 | 434 | # First, we need to figure out how wide the visualization components are 435 | # for each sub expression. 
If these are wider than the sub expression text, 436 | # than we need to leave space around the sub expression text 437 | lpad = np.zeros((len(statement),)) # pad for characters 438 | rpad = np.zeros((len(statement),)) 439 | maxh = 0 440 | for sub in subexprs: 441 | w, h = view.boxsize(sub.value) 442 | # update width to include horizontal room for type text like int32 443 | ty = tsensor.analysis._dtype(sub.value) 444 | w_typename = len(ty) * view.wchar_small 445 | w = max(w, w_typename) 446 | maxh = max(h, maxh) 447 | nexpr = sub.stop.cstop_idx - sub.start.cstart_idx 448 | if (sub.start.cstart_idx-1)>0 and statement[sub.start.cstart_idx - 1]== ' ': # if char to left is space 449 | nexpr += 1 450 | if sub.stop.cstop_idx view.wchar * nexpr: 453 | lpad[sub.start.cstart_idx] += (w - view.wchar) / 2 454 | rpad[sub.stop.cstop_idx - 1] += (w - view.wchar) / 2 455 | 456 | # Now we know how to place all the elements, since we know what the maximum height is 457 | view.set_locations(maxh) 458 | 459 | # Find each character's position based upon width of a character and any padding 460 | charx = np.empty((len(statement),)) 461 | x = view.leftedge 462 | for i,c in enumerate(statement): 463 | x += lpad[i] 464 | charx[i] = x 465 | x += view.wchar 466 | x += rpad[i] 467 | 468 | # Draw text for statement or expression 469 | if view.offending_expr is not None: # highlight erroneous subexpr 470 | highlight = np.full(shape=(len(statement),), fill_value=False, dtype=bool) 471 | for tok in tokens[root_to_viz.start.index:root_to_viz.stop.index+1]: 472 | highlight[tok.cstart_idx:tok.cstop_idx] = True 473 | errors = np.full(shape=(len(statement),), fill_value=False, dtype=bool) 474 | for tok in root_to_viz.optokens: 475 | errors[tok.cstart_idx:tok.cstop_idx] = True 476 | for i, c in enumerate(statement): 477 | color = ignored_color 478 | if highlight[i]: 479 | color = fontcolor 480 | if errors[i]: # override color if operator token 481 | color = error_op_color 482 | ax.text(charx[i], 
class QuietGraphvizWrapper(graphviz.Source):
    """
    A graphviz.Source built from DOT text that renders itself as SVG with
    quiet=True and can save itself in the format given by a file extension.
    """
    def __init__(self, dotsrc):
        super().__init__(source=dotsrc)

    def _repr_svg_(self):
        """Return the SVG text of the graph, rendered quietly."""
        return self.pipe(format='svg', quiet=True).decode(self._encoding)

    def savefig(self, filename):
        """
        Save the DOT source next to filename (same stem, no extension) and
        run the dot command to render it into filename; the output format
        is taken from the extension of filename.
        """
        path = Path(filename)
        # Create missing intermediate directories too, not just the last one
        path.parent.mkdir(parents=True, exist_ok=True)

        dotfilename = self.save(directory=path.parent.as_posix(), filename=path.stem)
        fmt = path.suffix[1:]  # ".svg" -> "svg" etc...
        cmd = ["dot", f"-T{fmt}", "-o", filename, dotfilename]
        # Compare version numbers numerically; as strings, '0.8' would sort
        # after '0.17' and pick the wrong API.
        if self._graphviz_version() <= (0, 17):
            graphviz.backend.run(cmd, capture_output=True, check=True, quiet=False)
        else:
            graphviz.backend.execute.run_check(cmd, capture_output=True, check=True, quiet=False)

    @staticmethod
    def _graphviz_version():
        """Return the first two components of graphviz.__version__ as a tuple of ints."""
        numbers = []
        for piece in graphviz.__version__.split('.')[:2]:
            digits = ''
            for ch in piece:  # keep the leading digits, drop suffixes like 'rc1'
                if not ch.isdigit():
                    break
                digits += ch
            numbers.append(int(digits) if digits else 0)
        return tuple(numbers)
{sz}
<{ty}>""" 565 | 566 | dtype_color_info = DTypeColorInfo(dtype_colors, dtype_precisions, dtype_alpha_range) 567 | 568 | root, tokens = tsensor.parsing.parse(statement) 569 | 570 | if frame=='current': # use frame of caller if nothing passed in 571 | frame = sys._getframe().f_back 572 | if frame.f_code.co_name=='astviz': 573 | frame = frame.f_back 574 | 575 | if frame is not None: # if the passed in None, then don't do the evaluation 576 | root.eval(frame) 577 | 578 | nodes = tsensor.ast.postorder(root) 579 | atoms = tsensor.ast.leaves(root) 580 | atomsS = set(atoms) 581 | ops = [nd for nd in nodes if nd not in atomsS] # keep order 582 | 583 | gr = """digraph G { 584 | margin=0; 585 | nodesep=.01; 586 | ranksep=.3; 587 | rankdir=BT; 588 | ordering=out; # keep order of leaves 589 | """ 590 | 591 | fontname="Consolas" 592 | fontsize=12 593 | dimfontsize = 9 594 | spread = 0 595 | 596 | # Gen leaf nodes 597 | for i in range(len(tokens)): 598 | t = tokens[i] 599 | if t.type!=token.ENDMARKER: 600 | nodetext = t.value 601 | # if ']' in nodetext: 602 | if nodetext==']': 603 | nodetext = nodetext.replace(']','‌]') # ‌ is 0-width nonjoiner. 
']' by itself is bad for DOT 604 | label = f'{nodetext}' 605 | _spread = spread 606 | if t.type==token.DOT: 607 | _spread=.1 608 | elif t.type==token.EQUAL: 609 | _spread=.25 610 | elif t.type in tsensor.parsing.ADDOP: 611 | _spread=.4 612 | elif t.type in tsensor.parsing.MULOP: 613 | _spread=.2 614 | gr += f'leaf{id(t)} [shape=box penwidth=0 margin=.001 width={_spread} label=<{label}>]\n' 615 | 616 | # Make sure leaves are on same level 617 | gr += f'{{ rank=same; ' 618 | for t in tokens: 619 | if t.type!=token.ENDMARKER: 620 | gr += f' leaf{id(t)}' 621 | gr += '\n}\n' 622 | 623 | # Make sure leaves are left to right by linking 624 | for i in range(len(tokens) - 2): 625 | t = tokens[i] 626 | t2 = tokens[i + 1] 627 | gr += f'leaf{id(t)} -> leaf{id(t2)} [style=invis];\n' 628 | 629 | # Draw internal ops nodes 630 | for nd in ops: 631 | label = internal_label(nd) 632 | sh = tsensor.analysis._shape(nd.value) 633 | if sh is None: 634 | color = "" 635 | else: 636 | ty = tsensor.analysis._dtype(nd.value) 637 | color = dtype_color_info.color(ty) 638 | color = mc.rgb2hex(color, keep_alpha=True) 639 | color = f'fillcolor="{color}" style=filled' 640 | gr += f'node{id(nd)} [shape=box {color} penwidth=0 margin=0 width=.25 height=.2 label=<{label}>]\n' 641 | 642 | # Link internal nodes to other nodes or leaves 643 | for nd in nodes: 644 | kids = nd.kids 645 | for sub in kids: 646 | if sub in atomsS: 647 | gr += f'node{id(nd)} -> leaf{id(sub.token)} [dir=back, penwidth="0.5", color="#6B6B6B", arrowsize=.3];\n' 648 | else: 649 | gr += f'node{id(nd)} -> node{id(sub)} [dir=back, penwidth="0.5", color="#6B6B6B", arrowsize=.3];\n' 650 | 651 | gr += "}\n" 652 | return gr 653 | --------------------------------------------------------------------------------