├── .gitignore ├── LICENSE ├── README.md ├── images ├── experiments.graffle ├── sample-1.svg └── sample-2.svg ├── matricks_old.py ├── setup.py ├── testing ├── examples.ipynb ├── playground.ipynb ├── str_size_matplotlib.py ├── test1.py ├── test2.py ├── test3.py ├── test_incr_eval.py ├── test_parser.py ├── test_tensorflow.py ├── test_tree_eval.py ├── testexc.py └── viz_testing.py └── tsensor ├── __init__.py ├── analysis.py ├── ast.py ├── parsing.py ├── version.py └── viz.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Terence Parr 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensor Sensor 2 | 3 | The goal of this library is to generate more helpful exception 4 | messages for numpy/pytorch/tensorflow matrix algebra expressions. Because the 5 | matrix algebra in these libraries is all done in C/C++, they do not 6 | have access to the Python execution environment so they are literally 7 | unable to give information about which Python variables and subexpression caused the problem. Only by catching the exception and then analyzing/re-executing the Python code can we get this kind of an error message. 8 | 9 | The Python `with` statement allows me to trap exceptions that occur 10 | and then I literally parse the Python code of the offending line, build an 11 | expression tree, and then incrementally evaluate the operands 12 | bottom-up until I run into an exception. That tells me which of the 13 | subexpressions caused the problem and then I can pull it apart and 14 | ask if any of those operands are matrices. 15 | 16 | Imagine you have a complicated little matrix expression like: 17 | 18 | ``` 19 | W @ torch.dot(b,b)+ torch.eye(2,2)@x + z 20 | ``` 21 | 22 | And you get this unhelpful error message from pytorch: 23 | 24 | ``` 25 | RuntimeError: 1D tensors expected, got 2D, 2D tensors at [...]/THTensorEvenMoreMath.cpp:83 26 | ``` 27 | 28 | There are two problems: it does not tell you which of the sub 29 | expressions threw the exception and it does not tell you what the 30 | shape of relevant operands are. 
This library lets you 31 | do this: 32 | 33 | ``` 34 | import tsensor 35 | with tsensor.clarify(): 36 | W @ torch.dot(b,b)+ torch.eye(2,2)@x + z 37 | ``` 38 | 39 | which then augments the exception message with the following clarification: 40 | 41 | ``` 42 | Cause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1] 43 | ``` 44 | 45 | Here’s another default error message that is almost helpful for expression `W @ z`: 46 | 47 | ``` 48 | RuntimeError: size mismatch, get 2, 2x2,3 49 | ``` 50 | 51 | But tensor-sensor gives: 52 | 53 | ``` 54 | Cause: @ on tensor operand W w/shape [2, 2] and operand z w/shape [3] 55 | ``` 56 | 57 | Non-tensor args/values are ignored. 58 | 59 | ``` 60 | with tsensor.clarify(): 61 | torch.dot(b, 3) 62 | ``` 63 | 64 | gives: 65 | 66 | ``` 67 | TypeError: dot(): argument 'tensor' (position 2) must be Tensor, not int 68 | Cause: torch.dot(b,3) tensor arg b w/shape [2, 1] 69 | ``` 70 | 71 | If there are no tensor args, it just shows the cause: 72 | 73 | ``` 74 | with tsensor.clarify(): 75 | z.reshape(1,2,2) 76 | ``` 77 | 78 | gives: 79 | 80 | ``` 81 | RuntimeError: shape '[1, 2, 2]' is invalid for input of size 3 82 | Cause: z.reshape(1,2,2) 83 | ``` 84 | 85 | ## Visualizations 86 | 87 | For more, see [examples.ipynb](testing/examples.ipynb). 88 | 89 | ```python 90 | import tsensor 91 | import graphviz 92 | import torch 93 | import sys 94 | 95 | W = torch.tensor([[1, 2], [3, 4]]) 96 | b = torch.tensor([9, 10]).reshape(2, 1) 97 | x = torch.tensor([4, 5]).reshape(2, 1) 98 | h = torch.tensor([1,2]) 99 | 100 | with tsensor.explain(): 101 | a = torch.relu(x) 102 | b = W @ b + h.dot(h) 103 | ``` 104 | 105 | Displays this in a notebook: 106 | 107 | 108 | 109 | 110 | 111 | 112 | ## Install 113 | 114 | ``` 115 | pip install -U graphviz # make sure you have latest 116 | pip install tensor-sensor 117 | ``` 118 | 119 | which gives you module `tsensor`.
I developed and tested with the following versions: 120 | 121 | ``` 122 | $ pip list | grep -i flow 123 | tensorflow 2.3.0 124 | tensorflow-estimator 2.3.0 125 | $ pip list | grep -i numpy 126 | numpy 1.18.5 127 | numpydoc 1.1.0 128 | $ pip list | grep -i torch 129 | torch 1.6.0 130 | ``` 131 | 132 | 133 | ## Limitations 134 | 135 | I rely on parsing lines that are assignments or expressions only, so the clarify and explain routines do not handle functions whose body is on the same line as the `def`, like: 136 | 137 | ``` 138 | def bar(): b + x * 3 139 | ``` 140 | 141 | Instead, use: 142 | 143 | ``` 144 | def bar(): 145 | b + x * 3 146 | ``` 147 | 148 | Watch out for side effects! I don't perform assignments, but any functions you call that have side effects will be run again while I reevaluate statements. 149 | 150 | Can't handle `\` line continuations. 151 | 152 | Also note: I've built my own parser to handle just the assignments / expressions tsensor can handle. 153 | 154 | ## Deploy (parrt's use) 155 | 156 | ```bash 157 | $ python setup.py sdist upload 158 | ``` 159 | 160 | Or download and install locally: 161 | 162 | ```bash 163 | $ cd ~/github/tensor-sensor 164 | $ pip install . 165 | ``` 166 | 167 | ## Notes 168 | 169 | The behavior of clarify: clarify imposes no runtime cost unless an exception occurs. At that point, it reevaluates the offending line looking for the subexpression that caused the problem. It not only updates the error message in the exception object, but also visualizes the error. 170 | 171 | The behavior of explain: explain imposes a significant runtime cost. Before every line is executed, explain evaluates all subexpressions and produces a visualization. Then, the line executes normally. If there is an exception in that line, we detect it during visualization and display an altered view of that statement that highlights the offending subexpression.
We also need to behave like clarify for the error message in the exception triggered when the Python VM executes that statement normally (after are visualization). 172 | 173 | So, in both cases, we trap exceptions using the `with` construct and augment exception messages. Explain differs from clarify in that we use `settrace()` to process each line of code before the VM executes it normally. Clarify never needs to deal with tracing. 174 | 175 | ### TODO 176 | 177 | * can i call pyviz in debugger? 178 | * try on real examples 179 | * `dict(W=[3,0,1,2], b=[1,0])` that would indicate (300, 30, 60, 3) would best be displayed as (30,60,3, 300) and b would be first dimension last and last dimension first -------------------------------------------------------------------------------- /images/experiments.graffle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prinshul/tensorsensor/803995a54d0f7f56e6d2d5995fcc66a48a21f1ae/images/experiments.graffle -------------------------------------------------------------------------------- /images/sample-1.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | leaf140728778297408 15 | 16 |     17 | 2x1 18 |     19 | 20 | 21 |     22 | a 23 |     24 | 25 | 26 | 27 | leaf140728778297504 28 | 29 | = 30 | 31 | 32 | 33 | 34 | leaf140728778299904 35 | 36 | torch 37 | 38 | 39 | 40 | 41 | leaf140728778300000 42 | 43 | . 
44 | 45 | 46 | 47 | 48 | leaf140728778300048 49 | 50 | relu 51 | 52 | 53 | 54 | 55 | leaf140728778300144 56 | 57 | ( 58 | 59 | 60 | 61 | 62 | leaf140728778300192 63 | 64 |     65 | 2x1 66 |     67 | 68 | 69 |     70 | x 71 |     72 | 73 | 74 | 75 | 76 | leaf140728778300288 77 | 78 | ) 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /images/sample-2.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | leaf140728778300816 15 | 16 |     17 | 2x1 18 |     19 | 20 | 21 |     22 | b 23 |     24 | 25 | 26 | 27 | leaf140728778299856 28 | 29 | = 30 | 31 | 32 | 33 | 34 | leaf140728778299376 35 | 36 |     37 | 2x2 38 |     39 | 40 | 41 |     42 | W 43 |     44 | 45 | 46 | 47 | 48 | leaf140728778300384 49 | 50 | @ 51 | 52 | 53 | 54 | 55 | leaf140728778300960 56 | 57 |     58 | 2x1 59 |     60 | 61 | 62 |     63 | b 64 |     65 | 66 | 67 | 68 | 69 | leaf140728778300912 70 | 71 | + 72 | 73 | 74 | 75 | 76 | leaf140728778301008 77 | 78 |     79 | 2 80 |     81 | 82 | 83 |     84 | h 85 |     86 | 87 | 88 | 89 | 90 | leaf140728778301104 91 | 92 | . 
93 | 94 | 95 | 96 | 97 | leaf140728778301296 98 | 99 | dot 100 | 101 | 102 | 103 | 104 | leaf140728778301392 105 | 106 | ( 107 | 108 | 109 | 110 | 111 | leaf140728778298512 112 | 113 |     114 | 2 115 |     116 | 117 | 118 |     119 | h 120 |     121 | 122 | 123 | 124 | 125 | leaf140728640009264 126 | 127 | ) 128 | 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /matricks_old.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2020 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | import inspect 25 | import typing 26 | import sys 27 | from collections import namedtuple 28 | from tokenize import tokenize, TokenInfo, NUMBER, STRING, NAME, OP 29 | from io import BytesIO 30 | 31 | ADDOP = {'+', '-'} 32 | MULOP = {'*', '/', '@', '%'} 33 | UNARYOP = {'!', '~'} 34 | OPERATORS = {'+', '-', '*', '/', '@', '%', '!', '~'} 35 | SYMBOLS = OPERATORS.union({'(', ')', '[', ']', '=', ',', ':'}) 36 | EOF = '' 37 | 38 | def idstart(c): 39 | return c[0].isalpha() or c[0]=='_' 40 | 41 | def idchar(c): # include '.'; assume single char here 42 | return c.isalpha() or c.isdigit() or c == '_' or c == '.' 43 | 44 | 45 | # Parse tree definitions 46 | 47 | class ParseTreeNode: 48 | def eval(self, frame): 49 | "Evaluate the expression represented by this (sub)tree in context of frame" 50 | return eval(str(self), frame.f_locals, frame.f_globals) 51 | @property 52 | def left(self): return None 53 | @property 54 | def right(self): return None 55 | def __str__(self): 56 | pass 57 | def __repr__(self): 58 | args = [v+'='+self.__dict__[v].__repr__() for v in self.__dict__] 59 | args = ','.join(args) 60 | return f"{self.__class__.__name__}({args})" 61 | 62 | class Assign(ParseTreeNode): 63 | def __init__(self, lhs, rhs): 64 | self.lhs, self.rhs = lhs, rhs 65 | @property 66 | def left(self): return self.lhs 67 | @property 68 | def right(self): return self.rhs 69 | def __str__(self): 70 | return str(self.lhs)+'='+str(self.rhs) 71 | 72 | class Call(ParseTreeNode): 73 | def __init__(self, name, args): 74 | self.name = name 75 | self.args = args 76 | @property 77 | def left(self): return self.args 78 | def __str__(self): 79 | if isinstance(self.args,list): 80 | args_ = ','.join([str(a) for a in self.args]) 81 | else: 82 | args_ = str(self.args) 83 | return f"{self.name}({args_})" 84 | 85 | class Index(ParseTreeNode): 86 | def __init__(self, name, index): 87 | self.name = name 88 | self.index = index 89 | @property 90 | def left(self): return self.index 91 | def 
__str__(self): 92 | i = self.index 93 | if isinstance(i,list): 94 | i = ','.join(str(v) for v in i) 95 | return f"{self.name}[{i}]" 96 | 97 | class BinaryOp(ParseTreeNode): 98 | def __init__(self, op, a, b): 99 | self.op, self.a, self.b = op, a, b 100 | @property 101 | def left(self): return self.a 102 | @property 103 | def right(self): return self.b 104 | def __str__(self): 105 | return f"{self.a}{self.op}{self.b}" 106 | 107 | class UnaryOp(ParseTreeNode): 108 | def __init__(self, op, opnd): 109 | self.op = op 110 | self.opnd = opnd 111 | @property 112 | def left(self): return self.opnd 113 | def __str__(self): 114 | return f"{self.op}{self.opnd}" 115 | 116 | class ListLiteral(ParseTreeNode): 117 | def __init__(self, elems): 118 | self.elems = elems 119 | @property 120 | def left(self): return self.elems 121 | def __str__(self): 122 | if isinstance(self.elems,list): 123 | elems_ = ','.join(str(e) for e in self.elems) 124 | else: 125 | elems_ = self.elems 126 | return f"[{elems_}]" 127 | 128 | class SubExpr(ParseTreeNode): 129 | # record parens for later display to keep precedence 130 | def __init__(self, e): 131 | self.e = e 132 | @property 133 | def left(self): return self.e 134 | def __str__(self): 135 | return f"({self.e})" 136 | 137 | class Atom(ParseTreeNode): 138 | def __init__(self, s): 139 | self.s = s 140 | def __repr__(self): 141 | return self.s 142 | def __str__(self): 143 | return self.s 144 | 145 | 146 | class PyExprParser: 147 | def __init__(self, code): 148 | self.code = code 149 | self.tokens = mytokenize(code) 150 | self.t = 0 # current lookahead 151 | 152 | def parse(self): 153 | # print("\nparse", self.code) 154 | # print(self.tokens) 155 | s = self.statement() 156 | self.match(EOF) 157 | return s 158 | 159 | def statement(self): 160 | lhs = self.expression() 161 | rhs = None 162 | if self.LA(1) == '=': 163 | self.t += 1 164 | rhs = self.expression() 165 | return Assign(lhs,rhs) 166 | return lhs 167 | 168 | def expression(self): 169 | return 
self.addexpr() 170 | 171 | def addexpr(self): 172 | elist = [] 173 | root = self.multexpr() 174 | while self.LA(1) in ADDOP: 175 | op = self.LA(1) 176 | elist.append(self.LA(1)) 177 | self.t += 1 178 | b = self.multexpr() 179 | root = BinaryOp(op, root, b) 180 | return root 181 | 182 | def multexpr(self): 183 | elist = [] 184 | root = self.unaryexpr() 185 | while self.LA(1) in MULOP: 186 | op = self.LA(1) 187 | elist.append(self.LA(1)) 188 | self.t += 1 189 | b = self.unaryexpr() 190 | root = BinaryOp(op, root, b) 191 | return root 192 | 193 | def unaryexpr(self): 194 | if self.LA(1) in UNARYOP: 195 | op = self.LA(1) 196 | self.t += 1 197 | e = self.unaryexpr() 198 | return UnaryOp(op, e) 199 | elif self.isatom() or self.isgroup(): 200 | return self.postexpr() 201 | else: 202 | print(f"missing unary expr at: {self.LA(1)}") 203 | 204 | def postexpr(self): 205 | e = self.atom() 206 | if self.LA(1)=='(': 207 | return self.funccall(e) 208 | if self.LA(1) == '[': 209 | return self.index(e) 210 | return e 211 | 212 | def atom(self): 213 | if self.LA(1) == '(': 214 | return self.subexpr() 215 | elif self.LA(1) == '[': 216 | return self.listatom() 217 | elif self.isatom() or self.isgroup(): 218 | atom = self.LA(1) 219 | self.t += 1 # match name or number 220 | return Atom(atom) 221 | else: 222 | print("error") 223 | 224 | def funccall(self, f): 225 | self.match('(') 226 | el = None 227 | if self.LA(1)!=')': 228 | el = self.exprlist() 229 | self.match(')') 230 | return Call(f, el) 231 | 232 | def index(self, e): 233 | self.match('[') 234 | el = self.exprlist() 235 | self.match(']') 236 | return Index(e, el) 237 | 238 | def exprlist(self): 239 | elist = [] 240 | e = self.expression() 241 | elist.append(e) 242 | while self.LA(1)==',': 243 | self.match(',') 244 | e = self.expression() 245 | elist.append(e) 246 | return elist if len(elist)>1 else elist[0] 247 | 248 | def subexpr(self): 249 | self.match('(') 250 | e = self.expression() 251 | self.match(')') 252 | return 
SubExpr(e) 253 | 254 | def listatom(self): 255 | self.match('[') 256 | e = self.exprlist() 257 | self.match(']') 258 | return ListLiteral(e) 259 | 260 | def isatom(self): 261 | return idstart(self.LA(1)) or self.LA(1).isdigit() or self.LA(1)==':' 262 | 263 | def isgroup(self): 264 | return self.LA(1)=='(' or self.LA(1)=='[' 265 | 266 | def LA(self, i): 267 | ahead = self.t + i - 1 268 | if ahead >= len(self.tokens): 269 | return EOF 270 | return self.tokens[ahead] 271 | 272 | def match(self, token): 273 | if self.LA(1)!=token: 274 | print(f"mismatch token {self.LA(1)}, looking for {token}") 275 | self.t += 1 276 | 277 | 278 | # def mytokenize(s): 279 | # tokensO = tokenize(BytesIO(s.encode('utf-8')).readline) 280 | # tokens = [] 281 | # for tok in tokensO: 282 | # type, value, _, _, _ = tok 283 | # if type in {NUMBER, STRING, NAME, OP}: 284 | # tokens.append(value) 285 | # else: 286 | # print("ignoring", type, value) 287 | # 288 | # return tokens + [EOF] 289 | 290 | def mytokenize(code): 291 | n = len(code) 292 | i = 0 293 | tokens = [] 294 | while i0 358 | 359 | def deepest_frame(self, exc_traceback): 360 | tb = exc_traceback 361 | while tb.tb_next != None: 362 | tb = tb.tb_next 363 | return tb.tb_frame 364 | 365 | def info(self, frame): 366 | module = frame.f_globals['__name__'] 367 | info = inspect.getframeinfo(frame) 368 | code = info.code_context[0].strip() 369 | filename, line = info.filename, info.lineno 370 | name = info.function 371 | return module, name, filename, line, code 372 | 373 | 374 | class IncrEvalTrap(BaseException): 375 | def __init__(self, expr): 376 | self.expr = expr # where in tree did we get exception? 
377 | 378 | 379 | def incr_eval(tree, frame): 380 | "Incrementally evaluate all subexpressions, looking for operation that fails; return that subtree" 381 | if tree is None: 382 | return 383 | if isinstance(tree, list): # must be args list or expr list 384 | for t in tree: 385 | incr_eval(t, frame) 386 | return 387 | if tree.left is not None and tree.right is not None: # binary 388 | incr_eval(tree.left, frame) 389 | incr_eval(tree.right, frame) 390 | elif tree.left is not None: # unary 391 | incr_eval(tree.left, frame) 392 | try: 393 | tree.eval(frame) # try to do this operator 394 | except: 395 | raise IncrEvalTrap(tree) 396 | # else all is well, just return to larger subexpr up the tree 397 | 398 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2020 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | from setuptools import setup 25 | 26 | exec(open('tsensor/version.py').read()) 27 | setup( 28 | name='tensor-sensor', 29 | version=__version__, 30 | url='https://github.com/parrt/tensor-sensor', 31 | license='MIT', 32 | py_modules=['tsensor.parsing', 'tsensor.ast', 'tsensor.analysis', 'tsensor.viz', 'tsensor.version'], 33 | author='Terence Parr', 34 | author_email='parrt@cs.usfca.edu', 35 | python_requires='>=3.6', 36 | install_requires=['graphviz>=0.14.1','numpy','torch','tensorflow', 'IPython', 'matplotlib'], 37 | description='The goal of this library is to generate more helpful exception messages for numpy/pytorch tensor algebra expressions.', 38 | # keywords='visualization data structures', 39 | classifiers=['License :: OSI Approved :: MIT License', 40 | 'Intended Audience :: Developers'] 41 | ) 42 | -------------------------------------------------------------------------------- /testing/playground.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "ename": "RuntimeError", 10 | "evalue": "1D tensors expected, got 2D, 2D tensors at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorEvenMoreMath.cpp:83\nCause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]", 11 | "output_type": "error", 12 | "traceback": [ 13 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 14 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 15 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 12\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtsensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclarify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meye\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m@\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0;31m# W[33, 33] = 3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 16 | "\u001b[0;31mRuntimeError\u001b[0m: 1D tensors expected, got 2D, 2D tensors at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorEvenMoreMath.cpp:83\nCause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]" 17 | ] 18 | } 19 | ], 20 | "source": [ 21 | "import torch\n", 22 | "import tsensor\n", 23 | "import sys\n", 24 | "\n", 25 | "W = torch.tensor([[1, 2], [3, 4]])\n", 26 | "b = torch.tensor([9, 10]).reshape(2, 1)\n", 27 | "x = torch.tensor([4, 5]).reshape(2, 1)\n", 28 | "h = torch.tensor([1,2])\n", 29 | "# z + z + W @ z\n", 30 | "# W @ z\n", 31 | "#torch.dot(b, 3)\n", 32 | "\n", 33 | "with tsensor.clarify():\n", 34 | " W @ 
torch.dot(b,b)+ torch.eye(2,2)@x + z\n", 35 | "# W[33, 33] = 3\n", 36 | "b = torch.abs( W @ b + x )" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 4, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "ename": "RuntimeError", 46 | "evalue": "1D tensors expected, got 2D, 2D tensors at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorEvenMoreMath.cpp:83\nCause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]", 47 | "output_type": "error", 48 | "traceback": [ 49 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 50 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 51 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtsensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclarify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m+\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meye\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m@\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 52 | "\u001b[0;31mRuntimeError\u001b[0m: 1D tensors expected, got 2D, 2D tensors at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorEvenMoreMath.cpp:83\nCause: torch.dot(b,b) tensor arg b w/shape [2, 1], arg b w/shape [2, 1]" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "with tsensor.clarify():\n", 58 | " W @ torch.dot(b,b)+ torch.eye(2,2)@x + z" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 
63 | "metadata": {}, 64 | "source": [ 65 | "## Graphviz" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 8, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "[, , , , , , , , , , , , , , , , , , , ]\n" 78 | ] 79 | }, 80 | { 81 | "data": { 82 | "text/plain": [ 83 | "Assign(lhs=b,rhs=BinaryOp(op=,lhs=BinaryOp(op=,lhs=W,rhs=h),rhs=BinaryOp(op=,lhs=Call(func=Member(obj=torch,member=abs),args=[x]),rhs=Call(func=Member(obj=h,member=dot),args=[h]))))" 84 | ] 85 | }, 86 | "execution_count": 8, 87 | "metadata": {}, 88 | "output_type": "execute_result" 89 | } 90 | ], 91 | "source": [ 92 | "p = tsensor.parse.PyExprParser(\"b = W @ h + torch.abs(x) *h.dot(h)\")\n", 93 | "print(p.tokens)\n", 94 | "root = p.parse()\n", 95 | "root" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 9, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "data": { 105 | "text/plain": [ 106 | "tensor([[25, 31],\n", 107 | " [30, 36]])" 108 | ] 109 | }, 110 | "execution_count": 9, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "frame = sys._getframe()\n", 117 | "result = root.eval(frame)\n", 118 | "result" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 7, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "def postorder(t):\n", 128 | " nodes = []\n", 129 | " _postorder(t, nodes)\n", 130 | " return nodes\n", 131 | "def _postorder(t, nodes):\n", 132 | " if t is None:\n", 133 | " return\n", 134 | " for sub in t.kids:\n", 135 | " _postorder(sub, nodes)\n", 136 | " nodes.append(t)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "def leaves(t):\n", 146 | " nodes = []\n", 147 | " _leaves(t, nodes)\n", 148 | " return nodes\n", 149 | "def _leaves(t, nodes):\n", 150 | " if t is None:\n", 151 | " 
return\n", 152 | " if len(t.kids)==0:\n", 153 | " nodes.append(t)\n", 154 | " return\n", 155 | " for sub in t.kids:\n", 156 | " _leaves(sub, nodes)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "nodes = postorder(root)\n", 166 | "atoms = leaves(root)\n", 167 | "atoms" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "# map tokens to nodes\n", 177 | "tok2node = {}\n", 178 | "for nd in atoms:\n", 179 | " tok2node[nd.token] = nd" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "def walk(t, pre=lambda x:None, post=lambda x:None):\n", 189 | " if t is None:\n", 190 | " return\n", 191 | " pre(t)\n", 192 | " for sub in t.kids:\n", 193 | " walk(sub, pre, post)\n", 194 | " post(t)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "import graphviz\n", 204 | "\n", 205 | "s = \"\"\"\n", 206 | "digraph G {\n", 207 | " nodesep=.1;\n", 208 | " ranksep=.3;\n", 209 | " rankdir=BT;\n", 210 | " ordering=out; # keep order of leaves\n", 211 | " node [penwidth=\"0.5\", shape=plaintext, width=.1, height=.1];\n", 212 | "\"\"\"\n", 213 | "\n", 214 | "nodesS = set(nodes)\n", 215 | "atomsS = set(atoms)\n", 216 | "ops = nodesS.difference(atomsS)\n", 217 | "\n", 218 | "# s += f'{{ rank=same;'\n", 219 | "# for a in atoms:\n", 220 | "# s += f' node{id(tok2node[a.token])};'\n", 221 | "# s += '}\\n'\n", 222 | "\n", 223 | "s += f'{{ rank=same;'\n", 224 | "for t in p.tokens:\n", 225 | " if t.type!=tsensor.ENDMARKER:\n", 226 | " x = tok2node[t] if t in tok2node else t\n", 227 | " shape = \"\"\n", 228 | " sh = tsensor._shape(x.value)\n", 229 | " if x in atomsS and sh is not None:\n", 230 | " shape = 
\"shape=box, fixedsize=shape \"\n", 231 | " if len(sh)==1:\n", 232 | " shape += f\"width={.15}, height={sh[0]/6.66}\"\n", 233 | " elif len(sh)==2:\n", 234 | " shape += f\"width={sh[1]/6.66}, height={sh[0]/6.66}\"\n", 235 | "# print(sh, shape)\n", 236 | " s += f'leaf{id(x)} [{shape} label=<{t.value}>]\\n'\n", 237 | "s += '}\\n'\n", 238 | "\n", 239 | "# for nd in ops:\n", 240 | "# s += f'leaf{id(nd)} [label=<{str(nd)}>]'\n", 241 | "\n", 242 | "for nd in nodes:\n", 243 | " if nd in ops:\n", 244 | " text = str(nd)\n", 245 | "# if isinstance(nd, tsensor.Atom):\n", 246 | "# text = str(nd)\n", 247 | "# else:\n", 248 | "# text = nd.__class__.__name__\n", 249 | "# text = str(nd)#+\"\\n\"+str(nd.value)\n", 250 | "# text = \" \"\n", 251 | " shape = \"\"\n", 252 | " sh = tsensor._shape(nd.value)\n", 253 | " if sh is not None:\n", 254 | " shape = \"shape=box, fixedsize=shape \"\n", 255 | " if len(sh)==1:\n", 256 | " shape += f\"width={.15}, height={sh[0]/6.66}\"\n", 257 | " elif len(sh)==2:\n", 258 | " shape += f\"width={sh[1]/6.66}, height={sh[0]/6.66}\"\n", 259 | " if sh is not None:\n", 260 | " text += \"
\"+'x'.join(str(s) for s in sh)\n", 261 | "# else:\n", 262 | "# text += \"=\"+str(nd.value)\n", 263 | " s += f'node{id(nd)} [{shape} label=<{text}>]\\n'\n", 264 | "\n", 265 | "# link leaves left to right\n", 266 | "for i in range(len(p.tokens)-2):\n", 267 | " t = p.tokens[i]\n", 268 | " t2 = p.tokens[i+1]\n", 269 | " x = tok2node[t] if t in tok2node else t\n", 270 | " x2 = tok2node[t2] if t2 in tok2node else t2\n", 271 | " s += f'leaf{id(x)} -> leaf{id(x2)} [style=invis];\\n'\n", 272 | " \n", 273 | "for nd in nodes:\n", 274 | " kids = nd.kids\n", 275 | "# if isinstance(nd, tsensor.Call) and isinstance(nd.kids[0], tsensor.Member):\n", 276 | "# print('ignore', nd)\n", 277 | "# kids = kids[1:]\n", 278 | " for sub in kids:\n", 279 | " if sub in atomsS:\n", 280 | " s += f'node{id(nd)} -> leaf{id(sub)} [dir=back, penwidth=\"0.5\", color=\"#444443\", arrowsize=.4];\\n'\n", 281 | " else:\n", 282 | " s += f'node{id(nd)} -> node{id(sub)} [dir=back, penwidth=\"0.5\", color=\"#444443\", arrowsize=.4];\\n'\n", 283 | "s += \"}\\n\"\n", 284 | "graphviz.Source(s)" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "import graphviz\n", 294 | "\n", 295 | "s = \"\"\"\n", 296 | "digraph G {\n", 297 | " nodesep=.1;\n", 298 | " ranksep=.3;\n", 299 | " rankdir=BT;\n", 300 | " ordering=out; # keep order of leaves\n", 301 | " node [penwidth=\"0.5\", shape=plaintext, width=.1, height=.1];\n", 302 | "\"\"\"\n", 303 | "\n", 304 | "nodesS = set(nodes)\n", 305 | "atomsS = set(atoms)\n", 306 | "ops = nodesS.difference(atomsS)\n", 307 | "\n", 308 | "s += f'{{ rank=same;'\n", 309 | "for a in atoms:\n", 310 | " s += f'node{id(a)};'\n", 311 | "s += '}\\n'\n", 312 | "\n", 313 | "for nd in nodes:\n", 314 | " if isinstance(nd, tsensor.Atom):\n", 315 | " text = str(nd)\n", 316 | " else:\n", 317 | " text = nd.__class__.__name__\n", 318 | " shape = \"\"\n", 319 | " sh = _shape(nd.value)\n", 320 | " if sh is 
not None:\n", 321 | " shape = \"shape=box, fixedsize=shape \"\n", 322 | " if len(sh)==1:\n", 323 | " shape += f\"width={.15}, height={sh[0]/6.66}\"\n", 324 | " elif len(sh)==2:\n", 325 | " shape += f\"width={sh[1]/6.66}, height={sh[0]/6.66}\"\n", 326 | " s += f'node{id(nd)} [{shape} label=<{text}>]\\n'\n", 327 | " \n", 328 | "for nd in nodes:\n", 329 | " kids = nd.kids\n", 330 | "# if isinstance(nd, tsensor.Call) and isinstance(nd.kids[0], tsensor.Member):\n", 331 | "# print('ignore', nd)\n", 332 | "# kids = kids[1:]\n", 333 | " for sub in kids:\n", 334 | " if sub in atomsS:\n", 335 | " s += f'node{id(nd)} -> node{id(sub)} [dir=back, penwidth=\"0.5\", color=\"#444443\", arrowsize=.4];'\n", 336 | " else:\n", 337 | " s += f'node{id(nd)} -> node{id(sub)} [dir=back, penwidth=\"0.5\", color=\"#444443\", arrowsize=.4];'\n", 338 | "s += \"}\\n\"\n", 339 | "graphviz.Source(s)" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "W = torch.tensor([[1, 2], [3, 4], [5, 6]])\n", 349 | "b = torch.tensor([9, 10]).reshape(2, 1)\n", 350 | "x = torch.tensor([4, 5]).reshape(2, 1)\n", 351 | "h = torch.tensor([1,2])\n", 352 | "a = 3\n", 353 | "\n", 354 | "x = torch.tensor([4, 5]).reshape(2, 1)\n", 355 | "p = tsensor.PyExprParser(\"a = a\")\n", 356 | "p = tsensor.PyExprParser(\"b = W@b + torch.abs(x) + h.dot(h)\")\n", 357 | "p = tsensor.PyExprParser(\"b = W@b + h.dot(h)\")\n", 358 | "print(p.tokens)\n", 359 | "root = p.parse()\n", 360 | "nodes = postorder(root)\n", 361 | "atoms = leaves(root)\n", 362 | "# map tokens to nodes\n", 363 | "tok2node = {}\n", 364 | "for nd in atoms:\n", 365 | " tok2node[nd.token] = nd\n", 366 | "frame = sys._getframe()\n", 367 | "result = root.eval(frame)\n", 368 | "result" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "import graphviz\n", 378 | "\n", 379 | 
"def matrix_html(nrows, ncols, label, fontsize=12, fontname=\"Consolas\", dimfontsize=9, color=\"#cfe2d4\"):\n", 380 | " isvec = ncols==None\n", 381 | " if isvec:\n", 382 | " sz = str(nrows)\n", 383 | " ncols=nrows\n", 384 | " nrows=1\n", 385 | " else:\n", 386 | " sz = f\"{nrows}x{ncols}\"\n", 387 | " w = ncols*20\n", 388 | " h = nrows*20\n", 389 | " if ncols==1:\n", 390 | " w = 15\n", 391 | " if nrows==1:\n", 392 | " h = 15\n", 393 | " html = f\"\"\"\n", 394 | " \n", 395 | " \n", 396 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 404 | " \n", 405 | "
\n", 397 | " {sz}\n", 398 | "
\n", 402 | " {label}\n", 403 | "
\"\"\"\n", 406 | " return html\n", 407 | "\n", 408 | "nodesS = set(nodes)\n", 409 | "atomsS = set(atoms)\n", 410 | "ops = nodesS.difference(atomsS)\n", 411 | "\n", 412 | "s = \"\"\"\n", 413 | "digraph G {\n", 414 | " nodesep=.0;\n", 415 | " ranksep=.3;\n", 416 | " rankdir=BT;\n", 417 | " ordering=out; # keep order of leaves\n", 418 | "\"\"\"\n", 419 | "\n", 420 | "fontname=\"Consolas\"\n", 421 | "fontsize=12\n", 422 | "spread = .2\n", 423 | "\n", 424 | "s += f'{{ rank=same; '\n", 425 | "for t in p.tokens:\n", 426 | " if t.type!=tsensor.ENDMARKER:\n", 427 | " x = tok2node[t] if t in tok2node else t\n", 428 | " shape = \"\"\n", 429 | " sh = tsensor._shape(x.value)\n", 430 | " label = f'{t.value}'\n", 431 | " matrixcolor=\"#cfe2d4\"\n", 432 | " vectorcolor=\"#fefecd\"\n", 433 | " if x in atomsS and sh is not None:\n", 434 | " if len(sh)==1:\n", 435 | " label = matrix_html(sh[0],None,t.value,fontname=fontname,fontsize=fontsize,color=vectorcolor)\n", 436 | " elif len(sh)==2:\n", 437 | " label = matrix_html(sh[0],sh[1],t.value,fontname=fontname,fontsize=fontsize,color=matrixcolor)\n", 438 | " # margin/width don't seem to do anything for shape=plain\n", 439 | " if t.type==tsensor.DOT:\n", 440 | " spread=.1\n", 441 | " if t.type==tsensor.EQUAL:\n", 442 | " spread=.25\n", 443 | " if t.type in tsensor.ADDOP:\n", 444 | " spread=.5\n", 445 | " if t.type in tsensor.MULOP:\n", 446 | " spread=.2\n", 447 | " s += f'leaf{id(x)} [shape=box penwidth=0 margin=.001 width={spread} label=<{label}>]\\n'\n", 448 | "s += '}\\n'\n", 449 | "\n", 450 | "s += \"}\\n\"\n", 451 | "# print(s)\n", 452 | "graphviz.Source(s)" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": null, 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [ 461 | "s = \"\"\"\n", 462 | "digraph foo {\n", 463 | " rankdir=TB\n", 464 | " subgraph {\n", 465 | " node1;\n", 466 | " }\n", 467 | " subgraph {\n", 468 | " node2;\n", 469 | " }\n", 470 | " node1 -> node2\n", 471 | "}\n", 472 | 
"\"\"\"\n", 473 | "display(graphviz.Source(s))\n", 474 | "display(graphviz.Source(s))" 475 | ] 476 | }, 477 | { 478 | "cell_type": "markdown", 479 | "metadata": {}, 480 | "source": [ 481 | "## Get string size in matplotlib" 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": null, 487 | "metadata": {}, 488 | "outputs": [], 489 | "source": [ 490 | "import matplotlib\n", 491 | "import matplotlib.patches as patches\n", 492 | "from matplotlib import pyplot as plt\n", 493 | "def textdim(s, fontsize=11):\n", 494 | " fig, ax = plt.subplots(1,1)\n", 495 | " t = ax.text(0, 0, s, bbox={'lw':0}, fontsize=fontsize)\n", 496 | " plt.savefig(\"/tmp/junk\")\n", 497 | " plt.close()\n", 498 | " bb = t.get_bbox_patch()\n", 499 | " w, h = bb.get_width(), bb.get_height()\n", 500 | " return w, h" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": null, 506 | "metadata": {}, 507 | "outputs": [], 508 | "source": [ 509 | "textdim(\"test of foO\", fontsize=11)" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": null, 515 | "metadata": {}, 516 | "outputs": [], 517 | "source": [ 518 | "import numpy as np\n", 519 | "a = [[1,2,3],[3,4,5]]\n", 520 | "objviz(a)" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [ 529 | "from lolviz import *\n", 530 | "print(objviz(a).source)" 531 | ] 532 | }, 533 | { 534 | "cell_type": "code", 535 | "execution_count": null, 536 | "metadata": {}, 537 | "outputs": [], 538 | "source": [ 539 | "s = \"\"\"\n", 540 | "digraph G {\n", 541 | " ranksep=0;\n", 542 | " rankdir=BT;\n", 543 | " ordering=out; # keep order of leaves\n", 544 | " node1 [shape=plain, space=\"0.0\", margin=\"0.01\",label=<\n", 545 | "\n", 546 | "\n", 547 | "\n", 548 | "\n", 549 | "\n", 550 | "\n", 551 | "\n", 552 | "
100x200
self.W
\n", 553 | " >]\n", 554 | "}\n", 555 | "\"\"\"\n", 556 | "graphviz.Source(s)" 557 | ] 558 | }, 559 | { 560 | "cell_type": "markdown", 561 | "metadata": {}, 562 | "source": [ 563 | "## CSS (yuck)" 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": null, 569 | "metadata": {}, 570 | "outputs": [], 571 | "source": [ 572 | "from IPython.core.display import display, HTML\n", 573 | "h = \"\"\"\n", 574 | "
Actief
\n", 575 | "\n", 576 | "\n", 577 | " \n", 578 | " \n", 592 | " \n", 593 | "\n", 594 | "\"\"\"\n", 595 | "t = HTML(h)\n", 596 | "t" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": null, 602 | "metadata": {}, 603 | "outputs": [], 604 | "source": [ 605 | "def matrix(nrows,ncols,text,dim_fontsize=9,fontsize=12):\n", 606 | " h = f\"\"\"\n", 607 | "
\n", 608 | "
\n", 609 | "
\n", 610 | " {ncols}\n", 611 | "
\n", 612 | "
\n", 613 | " {nrows}\n", 614 | "
\n", 615 | "
\n", 616 | "
\n", 617 | " {text}\n", 618 | "
\n", 619 | "
\n", 620 | " \"\"\"\n", 621 | " sp = f\"\"\"\n", 622 | "
\n", 623 | " \n", 624 | " \n", 625 | " {ncols}\n", 626 | " \n", 627 | " \n", 628 | " {nrows}\n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " {text}\n", 633 | " \n", 634 | "
\n", 635 | " \"\"\"\n", 636 | "\n", 637 | " return sp\n", 638 | "\n", 639 | "m = matrix(120,20,\"self.W\")\n", 640 | "m = f\"torch.relu({m})\"\n", 641 | "HTML(m)" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": null, 647 | "metadata": {}, 648 | "outputs": [], 649 | "source": [] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "metadata": {}, 655 | "outputs": [], 656 | "source": [] 657 | } 658 | ], 659 | "metadata": { 660 | "kernelspec": { 661 | "display_name": "Python 3", 662 | "language": "python", 663 | "name": "python3" 664 | }, 665 | "language_info": { 666 | "codemirror_mode": { 667 | "name": "ipython", 668 | "version": 3 669 | }, 670 | "file_extension": ".py", 671 | "mimetype": "text/x-python", 672 | "name": "python", 673 | "nbconvert_exporter": "python", 674 | "pygments_lexer": "ipython3", 675 | "version": "3.8.3" 676 | } 677 | }, 678 | "nbformat": 4, 679 | "nbformat_minor": 4 680 | } 681 | -------------------------------------------------------------------------------- /testing/str_size_matplotlib.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | import matplotlib.patches as patches 3 | from matplotlib import pyplot as plt 4 | 5 | def textdim(s, fontname='Consolas', fontsize=11): 6 | fig, ax = plt.subplots(1,1) 7 | t = ax.text(0, 0, s, bbox={'lw':0, 'pad':0}, fontname=fontname, fontsize=fontsize) 8 | # plt.savefig(tempfile.mktemp(".pdf")) 9 | plt.savefig("/tmp/font.pdf", pad_inches=0, dpi=200) 10 | print(t) 11 | plt.close() 12 | bb = t.get_bbox_patch() 13 | print(bb) 14 | w, h = bb.get_width(), bb.get_height() 15 | return w, h 16 | 17 | # print(textdim("@")) 18 | #exit() 19 | 20 | 21 | # From: https://stackoverflow.com/questions/22667224/matplotlib-get-text-bounding-box-independent-of-backend 22 | def find_renderer(fig): 23 | if hasattr(fig.canvas, "get_renderer"): 24 | #Some backends, such as TkAgg, have the get_renderer method, 
which 25 | #makes this easy. 26 | renderer = fig.canvas.get_renderer() 27 | else: 28 | #Other backends do not have the get_renderer method, so we have a work 29 | #around to find the renderer. Print the figure to a temporary file 30 | #object, and then grab the renderer that was used. 31 | #(I stole this trick from the matplotlib backend_bases.py 32 | #print_figure() method.) 33 | import io 34 | fig.canvas.print_pdf(io.BytesIO()) 35 | renderer = fig._cachedRenderer 36 | return(renderer) 37 | 38 | 39 | def textdim(s, fontname='Consolas', fontsize=11): 40 | fig, ax = plt.subplots(1, 1) 41 | t = ax.text(0, 0, s, fontname=fontname, fontsize=fontsize, transform=None) 42 | bb = t.get_window_extent(find_renderer(fig)) 43 | print(s, bb.width, bb.height) 44 | 45 | # t = mpl.textpath.TextPath(xy=(0, 0), s=s, size=fontsize, prop=fontname) 46 | # bb = t.get_extents() 47 | # print(s, "new", bb) 48 | plt.close() 49 | return bb.width, bb.height 50 | 51 | # print(textdim("test of foo", fontsize=11)) 52 | # print(textdim("test of foO", fontsize=11)) 53 | # print(textdim("W @ b + x * 3 + h.dot(h)", fontsize=12)) 54 | 55 | code = 'W@ b + x *3 + h.dot(h)' 56 | code = 'W@ b.f(x,y)'# + x *3 + h.dot(h)' 57 | 58 | fig, ax = plt.subplots(1,1,figsize=(4,1)) 59 | 60 | fontname = "Serif" 61 | fontsize = 16 62 | 63 | # for c in code: 64 | # t = ax.text(0,0,c) 65 | # bbox1 = t.get_window_extent(find_renderer(fig)) 66 | # # print(c, '->', bbox1.width, bbox1.height) 67 | # print(c, '->', textdim(c, fontname=fontname, fontsize=fontsize)) 68 | # rect1 = patches.Rectangle((0,0), bbox1.width, bbox1.height, \ 69 | # color = [0,0,0], fill = False) 70 | # fig.patches.append(rect1) 71 | 72 | x = 0 73 | for c in code: 74 | # print(f"plot {c} at {x},{0}") 75 | ax.text(x, 10, c, fontname=fontname, fontsize=fontsize, transform=None) 76 | w, h = textdim(c, fontname=fontname, fontsize=fontsize) 77 | # print(w,h,'->',x) 78 | x = x + w 79 | 80 | ax.set_xlim(0,x) 81 | #ax.set_ylim(0,10) 82 | 83 | ax.axis('off') 
84 | 85 | #plt.show() 86 | 87 | plt.tight_layout() 88 | 89 | plt.savefig("/tmp/t.pdf", bbox_inches='tight', pad_inches=0, dpi=200) 90 | -------------------------------------------------------------------------------- /testing/test1.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2020 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | import numpy as np 25 | import dis 26 | 27 | a = np.array([1,2,3]) 28 | g = globals() 29 | 30 | from inspect import currentframe, getframeinfo, stack 31 | def info(): 32 | prev = stack()[1] 33 | return prev.filename, prev.lineno 34 | 35 | def prevline(): 36 | prev = stack()[1] 37 | filename, line = prev.filename, prev.lineno 38 | line -= 1 # get previous line 39 | with open(filename, "r") as f: 40 | code = f.read() 41 | lines = code.split('\n') 42 | if line==1: 43 | return None 44 | return lines[line-1] # indexed from 1 45 | 46 | 47 | def nextline(n): 48 | prev = stack()[1] 49 | filename, line = prev.filename, prev.lineno 50 | line += n # get n lines ahead 51 | with open(filename, "r") as f: 52 | code = f.read() 53 | lines = code.split('\n') 54 | if line==1: 55 | return None 56 | return lines[line-1] # indexed from 1 57 | 58 | 59 | def f(): 60 | W = np.array([[1,2],[3,4]]) 61 | b = np.array([9,10]).reshape(2,1) 62 | x = np.array([4,5]).reshape(2,1) 63 | code = nextline(2).strip() 64 | exec(code) 65 | b = W @ b + x 66 | loc = locals() 67 | print(eval(compile("b*x", "", "eval"))) 68 | 69 | class dbg: 70 | def __enter__(self): 71 | prev = stack()[1] 72 | filename, line = prev.filename, prev.lineno 73 | with open(filename, "r") as f: 74 | code = f.read() 75 | lines = code.split('\n') 76 | line += 1 # next line 77 | code = lines[line-1].strip() # index from 0 78 | print("code to dbg", code) 79 | # c = compile(code, "", "exec") 80 | c = dis.Bytecode(code) 81 | print(c.dis()) 82 | VARLOADS = {'LOAD_NAME','LOAD_GLOBAL'} 83 | varrefs = [I.argval for I in c if I.opname in VARLOADS] 84 | funcrefs = [I.argval for I in c if I.opname in {'LOAD_METHOD'}] 85 | print("symbols",set(varrefs), set(funcrefs)) 86 | def __exit__(self, exc_type, exc_val, exc_tb): 87 | print("exit") 88 | 89 | def hi(): print("hi"); return 99 90 | 91 | 92 | W = np.array([[1, 2], [3, 4]]) 93 | b = np.array([9, 10]).reshape(2, 1) 94 | x = np.array([4, 5]).reshape(2, 1) 95 | 96 | with dbg(): 97 | 
z = torch.sigmoid(self.Whz @ h + self.Uxz @ x + self.bz) 98 | b = W @ b + np.abs(x) 99 | 100 | # dis.dis("f()") 101 | # dis.dis("a.f()") 102 | # dis.dis("np.pi") 103 | # dis.dis("a+3") -------------------------------------------------------------------------------- /testing/test2.py: -------------------------------------------------------------------------------- 1 | import tsensor 2 | import numpy as np 3 | import torch 4 | import matplotlib.pyplot as plt 5 | 6 | W = torch.tensor([[1, 2], [3, 4]]) 7 | b = torch.tensor([9, 10]).reshape(2, 1) 8 | x = torch.tensor([4, 5]).reshape(2, 1) 9 | h = torch.tensor([1,2]) 10 | 11 | 12 | # tsensor.pyviz("a = torch.relu(x)") 13 | # plt.show() 14 | # # 15 | 16 | with tsensor.clarify(): 17 | W @ np.dot(b, b) + np.eye(2, 2) @ x 18 | # b = W @ b + x * 3 + h.dot(h) -------------------------------------------------------------------------------- /testing/test3.py: -------------------------------------------------------------------------------- 1 | import tsensor 2 | import tensorflow as tf 3 | import matplotlib.pyplot as plt 4 | 5 | W = tf.constant([[1, 2], [3, 4]]) 6 | b = tf.reshape(tf.constant([[9, 10]]), (2, 1)) 7 | x = tf.reshape(tf.constant([[8, 5, 7]]), (3, 1)) 8 | z = 0 9 | 10 | # tsensor.parse("z /= b + x * 3", hush_errors=False) 11 | 12 | # with tsensor.clarify(show='viz'): 13 | # b + x * 3 14 | 15 | fig, ax = plt.subplots(1,1) 16 | tsensor.pyviz("b + x", ax=ax) 17 | plt.show() 18 | # with tsensor.explain(): 19 | # b + x 20 | 21 | -------------------------------------------------------------------------------- /testing/test_incr_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2020 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | from tsensor.ast import IncrEvalTrap 25 | from tsensor.parsing import PyExprParser 26 | import sys 27 | import numpy as np 28 | import torch 29 | 30 | def check(s,expected): 31 | frame = sys._getframe() 32 | caller = frame.f_back 33 | p = PyExprParser(s) 34 | t = p.parse() 35 | bad_subexpr = None 36 | try: 37 | t.eval(caller) 38 | except IncrEvalTrap as exc: 39 | bad_subexpr = str(exc.offending_expr) 40 | assert bad_subexpr==expected 41 | 42 | 43 | def test_missing_var(): 44 | a = 3 45 | c = 5 46 | check("a+b+c", "b") 47 | check("z+b+c", "z") 48 | 49 | def test_matrix_mult(): 50 | W = torch.tensor([[1, 2], [3, 4]]) 51 | b = torch.tensor([[1,2,3]]) 52 | check("W@b+torch.abs(b)", "W@b") 53 | 54 | def test_bad_arg(): 55 | check("torch.abs('foo')", "torch.abs('foo')") 56 | 57 | def test_parens(): 58 | a = 3 59 | b = 4 60 | c = 5 61 | check("(a+b)/0", "(a+b)/0") 62 | 63 | def test_array_literal(): 64 | a = torch.tensor([[1,2,3],[4,5,6]]) 65 | b = torch.tensor([[1,2,3]]) 66 | a+b 67 | check("a + b@2", """b@2""") 68 | 69 | def test_array_literal2(): 70 | a = torch.tensor([[1,2,3],[4,5,6]]) 71 | b = 
torch.tensor([[1,2,3]]) 72 | a+b 73 | check("(a+b)@2", """(a+b)@2""") 74 | -------------------------------------------------------------------------------- /testing/test_parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2020 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | from tsensor.parsing import * 25 | import re 26 | 27 | def check(s, expected_repr, expect_str=None): 28 | p = PyExprParser(s, hush_errors=False) 29 | t = p.parse() 30 | 31 | s = re.sub(r"\s+", "", s) 32 | result_str = str(t) 33 | result_str = re.sub(r"\s+", "", result_str) 34 | if expect_str: 35 | s = expect_str 36 | assert result_str==s 37 | 38 | result_repr = repr(t) 39 | result_repr = re.sub(r"\s+", "", result_repr) 40 | expected_repr = re.sub(r"\s+", "", expected_repr) 41 | # print("result", result_repr) 42 | # print("expected", expected) 43 | assert result_repr == expected_repr 44 | 45 | 46 | def test_assign(): 47 | check("a = 3", "Assign(op=,lhs=a,rhs=3)") 48 | 49 | 50 | def test_index(): 51 | check("a[:,i,j]", "Index(arr=a, index=[:, i, j])") 52 | 53 | 54 | def test_index2(): 55 | check("z = a[:]", "Assign(op=,lhs=z,rhs=Index(arr=a,index=[:]))") 56 | 57 | def test_index3(): 58 | check("g.W[:,:,1]", "Index(arr=Member(op=,obj=g,member=W),index=[:,:,1])") 59 | 60 | def test_literal_list(): 61 | check("[[1, 2], [3, 4]]", 62 | "ListLiteral(elems=[ListLiteral(elems=[1, 2]), ListLiteral(elems=[3, 4])])") 63 | 64 | 65 | def test_literal_array(): 66 | check("np.array([[1, 2], [3, 4]])", 67 | """ 68 | Call(func=Member(op=,obj=np,member=array), 69 | args=[ListLiteral(elems=[ListLiteral(elems=[1,2]),ListLiteral(elems=[3,4])])]) 70 | """) 71 | 72 | 73 | def test_method(): 74 | check("h = torch.tanh(h)", 75 | "Assign(op=,lhs=h,rhs=Call(func=Member(op=,obj=torch,member=tanh),args=[h]))") 76 | 77 | 78 | def test_method2(): 79 | check("np.dot(b,b)", 80 | "Call(func=Member(op=,obj=np,member=dot),args=[b,b])") 81 | 82 | 83 | def test_field(): 84 | check("a.b", "Member(op=,obj=a,member=b)") 85 | 86 | 87 | def test_member_func(): 88 | check("a.f()", "Call(func=Member(op=,obj=a,member=f),args=[])") 89 | 90 | 91 | def test_field2(): 92 | check("a.b.c", "Member(op=,obj=Member(op=,obj=a,member=b),member=c)") 93 | 94 | 95 | def test_field_and_func(): 96 | 
check("a.f().c", "Member(op=,obj=Call(func=Member(op=,obj=a,member=f),args=[]),member=c)") 97 | 98 | 99 | def test_parens(): 100 | check("(a+b)*c", "BinaryOp(op=,lhs=SubExpr(e=BinaryOp(op=,lhs=a,rhs=b)),rhs=c)") 101 | 102 | 103 | def test_1tuple(): 104 | check("(3,)", "TupleLiteral(elems=[3])") 105 | 106 | 107 | def test_2tuple(): 108 | check("(3,4)", "TupleLiteral(elems=[3,4])") 109 | 110 | 111 | def test_2tuple_with_trailing_comma(): 112 | check("(3,4,)", "TupleLiteral(elems=[3,4])", expect_str="(3,4)") 113 | 114 | 115 | def test_field_array(): 116 | check("a.b[34]", "Index(arr=Member(op=,obj=a,member=b),index=[34])") 117 | 118 | 119 | def test_field_array_func(): 120 | check("a.b[34].f()", "Call(func=Member(op=,obj=Index(arr=Member(op=,obj=a,member=b),index=[34]),member=f),args=[])") 121 | 122 | 123 | def test_arith(): 124 | check("(1-z)*h + z*h_", 125 | """BinaryOp(op=, 126 | lhs=BinaryOp(op=, 127 | lhs=SubExpr(e=BinaryOp(op=, 128 | lhs=1, 129 | rhs=z)), 130 | rhs=h), 131 | rhs=BinaryOp(op=,lhs=z,rhs=h_))""") 132 | 133 | 134 | def test_chained_op(): 135 | check("a + b + c", 136 | """BinaryOp(op=, 137 | lhs=BinaryOp(op=, lhs=a, rhs=b), 138 | rhs=c)""") 139 | 140 | 141 | def test_matrix_arith(): 142 | check("self.Whz@h + Uxz@x + bz", 143 | """ 144 | BinaryOp(op=, 145 | lhs=BinaryOp(op=, 146 | lhs=BinaryOp(op=,lhs=Member(op=,obj=self,member=Whz),rhs=h), 147 | rhs=BinaryOp(op=,lhs=Uxz,rhs=x)), 148 | rhs=bz) 149 | """) 150 | 151 | def test_kwarg(): 152 | check("torch.relu(torch.rand(size=(2000,)))", 153 | """ 154 | Call(func=Member(op=,obj=torch,member=relu), 155 | args=[Call(func=Member(op=,obj=torch,member=rand), 156 | args=[Assign(op=,lhs=size,rhs=TupleLiteral(elems=[2000]))])])""") -------------------------------------------------------------------------------- /testing/test_tensorflow.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2020 Terence Parr 5 | 6 | Permission is hereby 
granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | import tsensor 25 | import tensorflow as tf 26 | 27 | W = tf.constant([[1, 2], [3, 4]]) 28 | b = tf.reshape(tf.constant([[9, 10]]), (2, 1)) 29 | x = tf.reshape(tf.constant([[8, 5, 7]]), (3, 1)) 30 | 31 | def test_addition(): 32 | msg = "" 33 | try: 34 | with tsensor.clarify(): 35 | q = b + x + 3 36 | except tf.errors.InvalidArgumentError as iae: 37 | msg = iae.message 38 | 39 | expected = "Incompatible shapes: [2,1] vs. 
[3,1] [Op:AddV2]\n"+\ 40 | "Cause: + on tensor operand b w/shape (2, 1) and operand x w/shape (3, 1)" 41 | assert msg==expected -------------------------------------------------------------------------------- /testing/test_tree_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2020 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
"""
from tsensor.parsing import *
import tsensor.ast
import sys
import numpy as np


def check(s,expected):
    # Parse the expression text s, evaluate the resulting tree in the *caller's*
    # frame (so the caller's local variables are visible to the expression), and
    # compare the printed result against the printed expected value.
    frame = sys._getframe()
    caller = frame.f_back
    p = PyExprParser(s)
    t = p.parse()
    result = t.eval(caller)
    assert str(result)==str(expected)


# NOTE: the tests below bind local variables (a, b, c, x) that are referenced by
# name inside the source strings handed to check(); renaming them breaks the tests.
def test_int():
    check("34", 34)

def test_assign():
    check("a = 34", 34)

def test_var():
    a = 34
    check("a", 34)

def test_member_var():
    class A:
        def __init__(self):
            self.a = 34
    x = A()
    check("x.a", 34)

def test_member_func():
    class A:
        def f(self, a):
            return a+4
    x = A()
    check("x.f(30)", 34)

def test_index():
    a = [1,2,3]
    check("a[2]", 3)

def test_add():
    a = 3
    b = 4
    c = 5
    check("a+b+c", 12)

def test_add_mul():
    a = 3
    b = 4
    c = 5
    check("a+b*c", 23)

def test_parens():
    a = 3
    b = 4
    c = 5
    check("(a+b)*c", 35)

def test_list_literal():
    a = [[1,2,3],[4,5,6]]
    check("a", """[[1, 2, 3], [4, 5, 6]]""")


# NOTE(review): numpy pads these columns to equal width (e.g. "[[ 2  4  6]");
# the whitespace inside the expected literals below may have been collapsed in
# this dump -- verify against the real test file.
def test_np_literal():
    a = np.array([[1,2,3],[4,5,6]])
    check("a*2", """[[ 2 4 6]\n [ 8 10 12]]""")


def test_np_add():
    a = np.array([[1,2,3],[4,5,6]])
    check("a+a", """[[ 2 4 6]\n [ 8 10 12]]""")


def test_np_add2():
    a = np.array([[1,2,3],[4,5,6]])
    check("a+a+a", """[[ 3 6 9]\n [12 15 18]]""")

-------------------------------------------------------------------------------- /testing/testexc.py: --------------------------------------------------------------------------------
# Scratch experiment: prepend context to an exception's args and re-raise it.
# foo is None, so foo.append(...) raises AttributeError at once; the inner
# handler prefixes e.args and re-raises, the outer handler sets an ad-hoc
# e.message, prefixes e.args again and re-raises. print(foo[0]) is never reached.
foo = None

try:
    try:
        state = "bar"
        foo.append(state)

    except Exception as e:
        e.args = ("Appending '"+state+"' failed", *e.args)
        raise

    print(foo[0]) # would raise too

except Exception as e:
    e.message = "foo"
    e.args = ("print(foo) failed: " + str(foo), *e.args)
    raise
-------------------------------------------------------------------------------- /testing/viz_testing.py: --------------------------------------------------------------------------------
# Manual visualization playground: running this module imports torch, tensorflow,
# graphviz and matplotlib and executes the module-level code below.
import sys
import torch
import numpy as np
import tensorflow as tf
import graphviz
import tempfile
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm


# print('\n'.join(str(f) for f in fm.fontManager.ttflist))
import tsensor
# from tsensor.viz import pyviz, astviz

def foo():
    # W = torch.rand(size=(2000, 2000))
    W = torch.rand(size=(2000, 2000, 10, 8))
    b = torch.rand(size=(2000, 1))
    h = torch.rand(size=(1_000_000,))
    x = torch.rand(size=(2000, 1))
    # g = tsensor.astviz("b = W@b + (h+3).dot(h) + torch.abs(torch.tensor(34))",
    #                    sys._getframe())
    frame = sys._getframe()
    # The frame captured above is discarded: astviz below receives frame=None.
    frame = None
    g = tsensor.astviz("b = W[:,:,0,0]@b + (h+3).dot(h) + torch.abs(torch.tensor(34))",
                       frame)
    g.view()

#foo()

class GRU:
    def __init__(self):
        self.W = torch.rand(size=(2,20,2000,10))
        self.b = torch.rand(size=(20,1))
        # self.x = torch.tensor([4, 5]).reshape(2, 1)
        self.h = torch.rand(size=(1_000_000,))
        self.a = 3
        print(self.W.shape)
        print(self.W[:, :, 1].shape)

    def get(self):
        return torch.tensor([[1, 2], [3, 4]])

# W = torch.tensor([[1, 2], [3, 4]])
b = torch.rand(size=(2000,1))
h = torch.rand(size=(1_000_000,2))
x = torch.rand(size=(1_000_000,2))
a = 3

# NOTE: rebinding foo here shadows the foo() function defined above.
foo = torch.rand(size=(2000,))
torch.relu(foo)

g = GRU()

# Presumably a deliberately invalid op (matmul of two rank-1 constants) used to
# exercise clarify(); confirm intent.
with tsensor.clarify():
    tf.constant([1,2]) @ tf.constant([1,3])


# code = "b = g.W[0,:,:,1]@b+torch.zeros(200,1)+(h+3).dot(h)"
# code = "torch.relu(foo)"
# code = "np.dot(b,b)"
# g =
tsensor.pyviz(code, fontname='Courier New', fontsize=16, dimfontsize=9,
#                    char_sep_scale=1.8, hush_errors=False)
# plt.tight_layout()
# plt.savefig("/tmp/t.svg", dpi=200, bbox_inches='tight', pad_inches=0)

# W = torch.tensor([[1, 2], [3, 4]])
# x = torch.tensor([4, 5]).reshape(2, 1)
# with tsensor.explain():
#     b = torch.rand(size=(2000,))
#     torch.relu(b)


# g = GRU()
#
# g1 = tsensor.astviz("b = g.W@b + torch.eye(3,3)")
# g1.view()
# g1 = tsensor.pyviz("b = g.W@b")
# g1.view()
# g2 = tsensor.astviz("b = g.W@b + g.h.dot(g.h) + torch.abs(torch.tensor(34))")

-------------------------------------------------------------------------------- /tsensor/__init__.py: --------------------------------------------------------------------------------
"""
MIT License

Copyright (c) 2020 Terence Parr

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
# NOTE(review): __all__ names only the submodules, so `from tsensor import *`
# does not bind clarify, explain, eval, parse, pyviz, astviz or __version__;
# confirm that is intended (exporting `eval` would shadow the builtin for
# star-importers).
__all__ = ["ast", "parsing", "viz", "analysis", "version"]

# These classes/functions are the primary user interface so import them directly
import tsensor.ast
import tsensor.parsing
import tsensor.viz
import tsensor.analysis
# `eval` here is tsensor.analysis.eval; it replaces the builtin only inside this
# package namespace (tsensor.eval).
from tsensor.analysis import explain, clarify, eval
from tsensor.parsing import parse
from tsensor.viz import pyviz, astviz

from .version import __version__
-------------------------------------------------------------------------------- /tsensor/analysis.py: --------------------------------------------------------------------------------
"""
MIT License

Copyright (c) 2020 Terence Parr

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | import os 25 | import sys 26 | import traceback 27 | import torch 28 | import inspect 29 | 30 | import matplotlib.pyplot as plt 31 | 32 | import tsensor 33 | 34 | class clarify: 35 | def __init__(self, 36 | fontname='Consolas', fontsize=13, 37 | dimfontname='Arial', dimfontsize=9, matrixcolor="#cfe2d4", 38 | vectorcolor="#fefecd", char_sep_scale=1.8, fontcolor='#444443', 39 | underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227', 40 | show:(None,'viz')='viz'): 41 | """ 42 | Augment tensor-related exceptions generated from numpy, pytorch, and tensorflow. 43 | Also display a visual representation of the offending Python line that 44 | shows the shape of tensors referenced by the code. All you have to do is wrap 45 | the outermost level of your code and clarify() will activate upon exception. 46 | 47 | Visualizations pop up in a separate window unless running from a notebook, 48 | in which case the visualization appears as part of the cell execution output. 49 | 50 | There is no runtime overhead associated with clarify() unless an exception occurs. 51 | 52 | The offending code is executed a second time, to identify which sub expressions 53 | are to blame. This implies that code with side effects could conceivably cause 54 | a problem, but since an exception has been generated, results are suspicious 55 | anyway. 56 | 57 | Example: 58 | 59 | import numpy as np 60 | import tsensor 61 | 62 | b = np.array([9, 10]).reshape(2, 1) 63 | with tsensor.clarify(): 64 | np.dot(b,b) # tensor code or call to a function with tensor code 65 | 66 | See examples.ipynb for more examples. 
67 | 68 | :param fontname: The name of the font used to display Python code 69 | :param fontsize: The font size used to display Python code; default is 13. 70 | Also use this to increase the size of the generated figure; 71 | larger font size means larger image. 72 | :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes 73 | :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes 74 | :param matrixcolor: The color of matrix boxes 75 | :param vectorcolor: The color of vector boxes; only for tensors whose shape is (n,). 76 | :param char_sep_scale: It is notoriously difficult to discover how wide and tall 77 | text is when plotted in matplotlib. In fact there's probably, 78 | no hope to discover this information accurately in all cases. 79 | Certainly, I gave up after spending huge effort. We have a 80 | situation here where the font should be constant width, so 81 | we can just use a simple scaler times the font size to get 82 | a reasonable approximation to the width and height of a 83 | character box; the default of 1.8 seems to work reasonably 84 | well for a wide range of fonts, but you might have to tweak it 85 | when you change the font size. 86 | :param fontcolor: The color of the Python code. 87 | :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey 88 | :param ignored_color: The de-highlighted color for deemphasizing code not involved in an erroneous sub expression 89 | :param error_op_color: The color to use for characters associated with the erroneous operator 90 | :param ax: If not none, this is the matplotlib drawing region in which to draw the visualization 91 | :param dpi: This library tries to generate SVG files, which are vector graphics not 92 | 2D arrays of pixels like PNG files. However, it needs to know how to 93 | compute the exact figure size to remove padding around the visualization. 
94 | Matplotlib uses inches for its figure size and so we must convert 95 | from pixels or data units to inches, which means we have to know what the 96 | dots per inch, dpi, is for the image. 97 | :param hush_errors: Normally, error messages from true syntax errors but also 98 | unhandled code caught by my parser are ignored. Turn this off 99 | to see what the error messages are coming from my parser. 100 | :param show: Show visualization upon tensor error if show='viz'. 101 | """ 102 | self.show = show 103 | self.fontname, self.fontsize, self.dimfontname, self.dimfontsize, \ 104 | self.matrixcolor, self.vectorcolor, self.char_sep_scale,\ 105 | self.fontcolor, self.underline_color, self.ignored_color, self.error_op_color = \ 106 | fontname, fontsize, dimfontname, dimfontsize, \ 107 | matrixcolor, vectorcolor, char_sep_scale, \ 108 | fontcolor, underline_color, ignored_color, error_op_color 109 | 110 | def __enter__(self): 111 | self.frame = sys._getframe().f_back # where do we start tracking 112 | return self 113 | 114 | def __exit__(self, exc_type, exc_value, exc_traceback): 115 | if exc_type is not None and is_interesting_exception(exc_value): 116 | # print("exception:", exc_value, exc_traceback) 117 | # traceback.print_tb(exc_traceback, limit=5, file=sys.stdout) 118 | exc_frame = deepest_frame(exc_traceback) 119 | module, name, filename, line, code = info(exc_frame) 120 | # print('info', module, name, filename, line, code) 121 | if code is not None: 122 | view = tsensor.viz.pyviz(code, exc_frame, 123 | self.fontname, self.fontsize, self.dimfontname, 124 | self.dimfontsize, self.matrixcolor, self.vectorcolor, 125 | self.char_sep_scale, self.fontcolor, 126 | self.underline_color, self.ignored_color, 127 | self.error_op_color) 128 | if self.show=='viz': 129 | view.show() 130 | augment_exception(exc_value, view.offending_expr) 131 | 132 | 133 | class explain: 134 | def __init__(self, 135 | fontname='Consolas', fontsize=13, 136 | dimfontname='Arial', dimfontsize=9, 
matrixcolor="#cfe2d4", 137 | vectorcolor="#fefecd", char_sep_scale=1.8, fontcolor='#444443', 138 | underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227', 139 | savefig=None): 140 | """ 141 | As the Python virtual machine executes lines of code, generate a 142 | visualization for tensor-related expressions using from numpy, pytorch, 143 | and tensorflow. The shape of tensors referenced by the code are displayed. 144 | 145 | Visualizations pop up in a separate window unless running from a notebook, 146 | in which case the visualization appears as part of the cell execution output. 147 | 148 | There is heavy runtime overhead associated with explain() as every line 149 | is executed twice: once by explain() and then another time by the interpreter 150 | as part of normal execution. 151 | 152 | Expressions with side effects can easily generate incorrect results. Due to 153 | this and the overhead, you should limit the use of this to code you're trying 154 | to debug. Assignments are not evaluated by explain so code `x = ...` causes 155 | an assignment to x just once, during normal execution. This explainer 156 | knows the value of x and will display it but does not assign to it. 157 | 158 | Upon exception, execution will stop as usual but, like clarify(), explain() 159 | will augment the exception to indicate the offending sub expression. Further, 160 | the visualization will deemphasize code not associated with the offending 161 | sub expression. The sizes of relevant tensor values are still visualized. 162 | 163 | Example: 164 | 165 | import numpy as np 166 | import tsensor 167 | 168 | b = np.array([9, 10]).reshape(2, 1) 169 | with tsensor.explain(): 170 | b + b # tensor code or call to a function with tensor code 171 | 172 | See examples.ipynb for more examples. 173 | 174 | :param fontname: The name of the font used to display Python code 175 | :param fontsize: The font size used to display Python code; default is 13. 
176 | Also use this to increase the size of the generated figure; 177 | larger font size means larger image. 178 | :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes 179 | :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes 180 | :param matrixcolor: The color of matrix boxes 181 | :param vectorcolor: The color of vector boxes; only for tensors whose shape is (n,). 182 | :param char_sep_scale: It is notoriously difficult to discover how wide and tall 183 | text is when plotted in matplotlib. In fact there's probably, 184 | no hope to discover this information accurately in all cases. 185 | Certainly, I gave up after spending huge effort. We have a 186 | situation here where the font should be constant width, so 187 | we can just use a simple scaler times the font size to get 188 | a reasonable approximation to the width and height of a 189 | character box; the default of 1.8 seems to work reasonably 190 | well for a wide range of fonts, but you might have to tweak it 191 | when you change the font size. 192 | :param fontcolor: The color of the Python code. 193 | :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey 194 | :param ignored_color: The de-highlighted color for deemphasizing code not involved in an erroneous sub expression 195 | :param error_op_color: The color to use for characters associated with the erroneous operator 196 | :param ax: If not none, this is the matplotlib drawing region in which to draw the visualization 197 | :param dpi: This library tries to generate SVG files, which are vector graphics not 198 | 2D arrays of pixels like PNG files. However, it needs to know how to 199 | compute the exact figure size to remove padding around the visualization. 
200 | Matplotlib uses inches for its figure size and so we must convert 201 | from pixels or data units to inches, which means we have to know what the 202 | dots per inch, dpi, is for the image. 203 | :param hush_errors: Normally, error messages from true syntax errors but also 204 | unhandled code caught by my parser are ignored. Turn this off 205 | to see what the error messages are coming from my parser. 206 | :param savefig: A string indicating where to save the visualization; don't save 207 | a file if None. 208 | """ 209 | self.savefig = savefig 210 | self.fontname, self.fontsize, self.dimfontname, self.dimfontsize, \ 211 | self.matrixcolor, self.vectorcolor, self.char_sep_scale,\ 212 | self.fontcolor, self.underline_color, self.ignored_color, self.error_op_color = \ 213 | fontname, fontsize, dimfontname, dimfontsize, \ 214 | matrixcolor, vectorcolor, char_sep_scale, \ 215 | fontcolor, underline_color, ignored_color, error_op_color 216 | 217 | def __enter__(self, format="svg"): 218 | # print("ON trace") 219 | self.tracer = ExplainTensorTracer(self.savefig, format=format) 220 | sys.settrace(self.tracer.listener) 221 | frame = sys._getframe() 222 | prev = frame.f_back # get block wrapped in "with" 223 | prev.f_trace = self.tracer.listener 224 | return self.tracer 225 | 226 | def __exit__(self, exc_type, exc_value, exc_traceback): 227 | sys.settrace(None) 228 | # At this point we have already tried to visualize the statement 229 | # If there was no error, the visualization will look normal 230 | # but a matrix operation error will show the erroneous operator highlighted. 231 | # That was artificial execution of the code. Now the VM has executed 232 | # the statement for real and has found the same exception. Make sure to 233 | # augment the message with causal information. 
234 | if exc_type is not None and is_interesting_exception(exc_value): 235 | # print("exception:", exc_value, exc_traceback) 236 | # traceback.print_tb(exc_traceback, limit=5, file=sys.stdout) 237 | exc_frame = deepest_frame(exc_traceback) 238 | module, name, filename, line, code = info(exc_frame) 239 | # print('info', module, name, filename, line, code) 240 | if code is not None: 241 | # We've already displayed picture so just augment message 242 | root, tokens = tsensor.parsing.parse(code) 243 | if root is not None: # Could be syntax error in statement or code I can't handle 244 | offending_expr = None 245 | try: 246 | root.eval(exc_frame) 247 | except tsensor.ast.IncrEvalTrap as e: 248 | offending_expr = e.offending_expr 249 | augment_exception(exc_value, offending_expr) 250 | 251 | 252 | class ExplainTensorTracer: 253 | def __init__(self, savefig:str=None, format="svg", modules=['__main__'], filenames=[]): 254 | self.savefig = savefig 255 | self.format = format 256 | self.modules = modules 257 | self.filenames = filenames 258 | self.exceptions = set() 259 | self.linecount = 0 260 | self.views = [] 261 | 262 | def listener(self, frame, event, arg): 263 | module = frame.f_globals['__name__'] 264 | if module not in self.modules: 265 | return 266 | 267 | info = inspect.getframeinfo(frame) 268 | filename, line = info.filename, info.lineno 269 | name = info.function 270 | if len(self.filenames)>0 and filename not in self.filenames: 271 | return 272 | 273 | if event=='line': 274 | self.line_listener(module, name, filename, line, info, frame) 275 | 276 | return None 277 | 278 | def line_listener(self, module, name, filename, line, info, frame): 279 | code = info.code_context[0].strip() 280 | if code.startswith("sys.settrace(None)"): 281 | return 282 | self.linecount += 1 283 | p = tsensor.parsing.PyExprParser(code) 284 | t = p.parse() 285 | if t is not None: 286 | # print(f"A line encountered in {module}.{name}() at {filename}:{line}") 287 | # print("\t", code) 288 | # 
print("\t", repr(t)) 289 | ExplainTensorTracer.viz_statement(self, code, frame) 290 | 291 | @staticmethod 292 | def viz_statement(tracer, code, frame): 293 | view = tsensor.viz.pyviz(code, frame) 294 | tracer.views.append(view) 295 | if tracer.savefig is not None: 296 | svgfilename = f"{tracer.savefig}-{tracer.linecount}.svg" 297 | view.savefig(svgfilename) 298 | view.filename = svgfilename 299 | plt.close() 300 | else: 301 | view.show() 302 | return view 303 | 304 | 305 | def eval(statement:str, frame=None) -> (tsensor.ast.ParseTreeNode, object): 306 | """ 307 | Parse statement and return an ast in the context of execution frame or, if None, 308 | the invoking function's frame. Set the value field of all ast nodes. 309 | Overall result is in root.value. 310 | :param statement: A string representing the line of Python code to visualize within an execution frame. 311 | :param frame: The execution frame in which to evaluate the statement. If None, 312 | use the execution frame of the invoking function 313 | :return An abstract parse tree representing the statement; nodes are 314 | ParseTreeNode subclasses. 
315 | """ 316 | p = tsensor.parsing.PyExprParser(statement) 317 | root = p.parse() 318 | if frame is None: # use frame of caller 319 | frame = sys._getframe().f_back 320 | root.eval(frame) 321 | return root, root.value 322 | 323 | 324 | def augment_exception(exc_value, subexpr): 325 | explanation = subexpr.clarify() 326 | augment = "" 327 | if explanation is not None: 328 | augment = explanation 329 | # Reuse exception but overwrite the message 330 | # print(f"Exc type is {type(exc_value)}, len(args)={len(exc_value.args)}, has '_message'=={hasattr(exc_value, '_message')}") 331 | # print(f"Msg {str(exc_value)}") 332 | if hasattr(exc_value, "_message"): 333 | exc_value._message = exc_value.message + "\n" + augment 334 | else: 335 | exc_value.args = [exc_value.args[0] + "\n" + augment] 336 | 337 | 338 | def is_interesting_exception(e): 339 | # print(f"is_interesting_exception: type is {type(e)}") 340 | if e.__class__.__module__.startswith("tensorflow"): 341 | return True 342 | sentinels = {'matmul', 'THTensorMath', 'tensor', 'tensors', 'dimension', 343 | 'not aligned', 'size mismatch', 'shape', 'shapes', 'matrix'} 344 | if len(e.args)==0: 345 | msg = e.message 346 | else: 347 | msg = e.args[0] 348 | return sum([s in msg for s in sentinels])>0 349 | 350 | 351 | def deepest_frame(exc_traceback): 352 | """ 353 | Don't trace into internals of numpy/torch/tensorflow; we want to reset frame 354 | to where in the user's python code it asked the tensor lib to perform an 355 | invalid operation. 356 | 357 | To detect libraries, look for code whose filename has "site-packages/{package}" 358 | or "dist-packages/{package}". 
359 | """ 360 | tb = exc_traceback 361 | packages = ['numpy','torch','tensorflow'] 362 | dirs = [os.path.join('site-packages',p) for p in packages] 363 | dirs += [os.path.join('dist-packages',p) for p in packages] 364 | dirs += ['<__array_function__'] # numpy seems to not have real filename 365 | prev = tb 366 | while tb is not None: 367 | filename = tb.tb_frame.f_code.co_filename 368 | reached_lib = [p in filename for p in dirs] 369 | if sum(reached_lib)>0: 370 | break 371 | prev = tb 372 | tb = tb.tb_next 373 | return prev.tb_frame 374 | 375 | 376 | def info(frame): 377 | if hasattr(frame, '__name__'): 378 | module = frame.f_globals['__name__'] 379 | else: 380 | module = None 381 | info = inspect.getframeinfo(frame) 382 | if info.code_context is not None: 383 | code = info.code_context[0].strip() 384 | else: 385 | code = None 386 | filename, line = info.filename, info.lineno 387 | name = info.function 388 | return module, name, filename, line, code 389 | 390 | 391 | def smallest_matrix_subexpr(t): 392 | """ 393 | During visualization, we need to find the smallest expression 394 | that evaluates to a non-scalar. That corresponds to the deepest subtree 395 | that evaluates to a non-scalar. Because we do not have parent pointers, 396 | we cannot start at the leaves and walk upwards. Instead, set a Boolean 397 | in each node to indicate whether one of the descendents (but not itself) 398 | evaluates to a non-scalar. Nodes in the tree that have matrix values and 399 | not matrix_below are the ones to visualize. 400 | 401 | This routine modifies the tree nodes to turn on matrix_below where appropriate. 
402 | """ 403 | nodes = [] 404 | _smallest_matrix_subexpr(t, nodes) 405 | return nodes 406 | 407 | def _smallest_matrix_subexpr(t, nodes) -> bool: 408 | if t is None: return False # prevent buggy code from causing us to fail 409 | if len(t.kids)==0: # leaf node 410 | if istensor(t.value): 411 | nodes.append(t) 412 | return istensor(t.value) 413 | n_matrix_below = 0 # once this latches true, it's passed all the way up to the root 414 | for sub in t.kids: 415 | matrix_below = _smallest_matrix_subexpr(sub, nodes) 416 | n_matrix_below += matrix_below # how many descendents evaluated two non-scalar? 417 | # If current node is matrix and no descendents are, then this is smallest 418 | # sub expression that evaluates to a matrix; keep track 419 | if istensor(t.value) and n_matrix_below==0: 420 | nodes.append(t) 421 | # Report to caller that this node or some descendent is a matrix 422 | return istensor(t.value) or n_matrix_below > 0 423 | 424 | 425 | def istensor(x): 426 | return _shape(x) is not None 427 | 428 | 429 | def _shape(v): 430 | # do we have a shape and it answers len()? Should get stuff right. 
431 | if hasattr(v, "shape") and hasattr(v.shape,"__len__"): 432 | if isinstance(v.shape, torch.Size): 433 | if len(v.shape)==0: 434 | return None 435 | return list(v.shape) 436 | return v.shape 437 | return None 438 | -------------------------------------------------------------------------------- /tsensor/ast.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2020 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | import tsensor 25 | 26 | # Parse tree definitions 27 | # I found package ast in python3 lib after I built this. whoops. No biggie. 28 | # This tree structure is easier to visit for my purposes here. Also lets me 29 | # control the kinds of statements I process. 
class ParseTreeNode:
    """
    Base class for parse-tree nodes.

    Subclasses provide __str__ (the node's source text, which eval() re-evaluates),
    kids (child nodes), optokens (tokens to highlight) and optionally clarify()
    (a "Cause: ..." explanation for exception messages).
    """
    def __init__(self):
        self.value = None # used during evaluation
        self.start = None # start token
        self.stop = None  # end token

    def eval(self, frame):
        """
        Evaluate the expression represented by this (sub)tree in the context of
        frame, cache the result in self.value and return it.

        Names resolve as they would in the frame's own code: the frame's locals
        shadow its globals, and both are visible inside nested scopes of the
        expression (comprehensions, generator expressions, lambdas). A merged
        copy is used so the frame's real namespaces are never modified.

        Any Exception raised while evaluating is re-raised as IncrEvalTrap(self),
        chained to the original, so callers learn which subtree failed.
        KeyboardInterrupt/SystemExit propagate unchanged.
        """
        namespace = {**frame.f_globals, **frame.f_locals}
        try:
            self.value = eval(str(self), namespace)
        except Exception as e:
            raise IncrEvalTrap(self) from e
        return self.value

    @property
    def optokens(self): # the associated token if atom or representative token if operation
        return None

    @property
    def kids(self):
        return []

    def clarify(self):
        return None

    def __str__(self):
        # Subclasses override this; the base returns None, so str() on a bare
        # ParseTreeNode raises TypeError.
        pass

    def __repr__(self):
        fields = self.__dict__.copy()
        kill = ['start', 'stop', 'lbrack', 'lparen']
        for name in kill:
            if name in fields: del fields[name]
        args = [v+'='+fields[v].__repr__() for v in fields if v!='value' or fields['value'] is not None]
        args = ','.join(args)
        return f"{self.__class__.__name__}({args})"


class Assign(ParseTreeNode):
    def __init__(self, op, lhs, rhs, start, stop):
        super().__init__()
        self.op, self.lhs, self.rhs = op, lhs, rhs
        self.start, self.stop = start, stop

    def eval(self, frame):
        self.value = self.rhs.eval(frame)
        # Don't eval this node as it causes side effect of making actual assignment to lhs
        self.lhs.value = self.value
        return self.value

    @property
    def optokens(self):
        return [self.op]

    @property
    def kids(self):
        return [self.lhs, self.rhs]

    def __str__(self):
        return str(self.lhs)+'='+str(self.rhs)


class Call(ParseTreeNode):
    def __init__(self, func, lparen, args, start, stop):
        super().__init__()
        self.func = func
        self.lparen = lparen
        self.args = args
        self.start, self.stop = start, stop

    def eval(self, frame):
        self.func.eval(frame)
        for a in self.args:
            a.eval(frame)
        return super().eval(frame)

    def clarify(self):
        """Describe the call and the shapes of its tensor-valued arguments."""
        arg_msgs = []
        for a in self.args:
            ashape = tsensor.analysis._shape(a.value)
            if ashape:
                arg_msgs.append(f"arg {a} w/shape {ashape}")
        if len(arg_msgs)==0:
            return f"Cause: {self}"
        return f"Cause: {self} tensor " + ', '.join(arg_msgs)

    @property
    def optokens(self):
        f = None # assume complicated like a[i](args) with weird func expr
        if isinstance(self.func, Member):
            f = self.func.member
        elif isinstance(self.func, Atom):
            f = self.func
        if f is not None:
            return [f.token,self.lparen,self.stop]
        return [self.lparen,self.stop]

    @property
    def kids(self):
        return [self.func]+self.args

    def __str__(self):
        args_ = ','.join([str(a) for a in self.args])
        return f"{self.func}({args_})"


class Index(ParseTreeNode):
    def __init__(self, arr, lbrack, index, start, stop):
        super().__init__()
        self.arr = arr
        self.lbrack = lbrack
        self.index = index
        self.start, self.stop = start, stop

    def eval(self, frame):
        self.arr.eval(frame)
        for i in self.index:
            i.eval(frame)
        return super().eval(frame)

    @property
    def optokens(self):
        # Only the brackets are highlighted, regardless of how complex the array
        # expression is (e.g. f()[i] has no clear array variable).
        return [self.lbrack,self.stop]

    @property
    def kids(self):
        return [self.arr] + self.index

    def __str__(self):
        i = ','.join(str(v) for v in self.index)
        return f"{self.arr}[{i}]"


class Member(ParseTreeNode):
    def __init__(self, op, obj, member, start, stop):
        super().__init__()
        self.op = op # always DOT
        self.obj = obj
        self.member = member
        self.start, self.stop = start, stop

    def eval(self, frame):
        self.obj.eval(frame)
        # don't eval member as it's just a name to look up in obj
        return super().eval(frame)

    @property
    def optokens(self): # the associated token if atom or representative token if operation
        return [self.op]

    @property
    def kids(self):
        return [self.obj, self.member]

    def __str__(self):
        return f"{self.obj}.{self.member}"


class BinaryOp(ParseTreeNode):
    def __init__(self, op, lhs, rhs, start, stop):
        super().__init__()
        self.op, self.lhs, self.rhs = op, lhs, rhs
        self.start, self.stop = start, stop

    def eval(self, frame):
        self.lhs.eval(frame)
        self.rhs.eval(frame)
        return super().eval(frame)

    def clarify(self):
        """Describe the operator and the shapes of its tensor-valued operands."""
        opnd_msgs = []
        lshape = tsensor.analysis._shape(self.lhs.value)
        rshape = tsensor.analysis._shape(self.rhs.value)
        if lshape:
            opnd_msgs.append(f"operand {self.lhs} w/shape {lshape}")
        if rshape:
            opnd_msgs.append(f"operand {self.rhs} w/shape {rshape}")
        if not opnd_msgs:
            # No tensor operands to report; avoid a dangling "on tensor " and
            # describe the whole expression, as Call.clarify does.
            return f"Cause: {self}"
        return f"Cause: {self.op} on tensor " + ' and '.join(opnd_msgs)

    @property
    def optokens(self): # the associated token if atom or representative token if operation
        return [self.op]

    @property
    def kids(self):
        return [self.lhs, self.rhs]

    def __str__(self):
        return f"{self.lhs}{self.op}{self.rhs}"


class UnaryOp(ParseTreeNode):
    def __init__(self, op, opnd, start, stop):
        super().__init__()
        self.op = op
        self.opnd = opnd
        self.start, self.stop = start, stop

    def eval(self, frame):
        self.opnd.eval(frame)
        return super().eval(frame)

    @property
    def optokens(self):
        return [self.op]

    @property
    def kids(self):
        return [self.opnd]

    def __str__(self):
        return f"{self.op}{self.opnd}"


class ListLiteral(ParseTreeNode):
    def __init__(self, elems, start, stop):
        super().__init__()
        self.elems = elems
        self.start, self.stop = start, stop

    def eval(self, frame):
        for i in self.elems:
            i.eval(frame)
        return super().eval(frame)

    @property
    def kids(self):
        return self.elems

    def __str__(self):
        if isinstance(self.elems,list):
            elems_ = ','.join(str(e) for e in self.elems)
        else:
            elems_ = self.elems
        return f"[{elems_}]"


class TupleLiteral(ParseTreeNode):
    def __init__(self, elems, start, stop):
        super().__init__()
        self.elems = elems
        self.start, self.stop = start, stop

    def eval(self, frame):
        for i in self.elems:
            i.eval(frame)
        return super().eval(frame)

    @property
    def kids(self):
        return self.elems

    def __str__(self):
        if len(self.elems)==1: # trailing comma keeps a 1-tuple a tuple
            return f"({self.elems[0]},)"
        else:
            return f"({','.join(str(e) for e in self.elems)})"


class SubExpr(ParseTreeNode):
    # record parens for later display to keep precedence
    def __init__(self, e, start, stop):
        super().__init__()
        self.e = e
        self.start, self.stop = start, stop

    def eval(self, frame):
        self.value = self.e.eval(frame)
        return self.value # don't re-evaluate

    @property
    def optokens(self):
        return [self.start, self.stop]

    @property
    def kids(self):
        return [self.e]

    def __str__(self):
        return f"({self.e})"


class Atom(ParseTreeNode):
    def __init__(self, token):
        super().__init__()
        self.token = token
        self.start, self.stop = token, token
    def eval(self, frame):
        if self.token.type == tsensor.parsing.COLON:
            return
':' # fake a value here 283 | return super().eval(frame) 284 | def __repr__(self): 285 | # v = f"{{{self.value}}}" if hasattr(self,'value') and self.value is not None else "" 286 | return self.token.value 287 | def __str__(self): 288 | return self.token.value 289 | 290 | 291 | def postorder(t): 292 | nodes = [] 293 | _postorder(t, nodes) 294 | return nodes 295 | 296 | 297 | def _postorder(t, nodes): 298 | if t is None: 299 | return 300 | for sub in t.kids: 301 | _postorder(sub, nodes) 302 | nodes.append(t) 303 | 304 | 305 | def leaves(t): 306 | nodes = [] 307 | _leaves(t, nodes) 308 | return nodes 309 | 310 | 311 | def _leaves(t, nodes): 312 | if t is None: 313 | return 314 | if len(t.kids) == 0: 315 | nodes.append(t) 316 | return 317 | for sub in t.kids: 318 | _leaves(sub, nodes) 319 | 320 | 321 | def walk(t, pre=lambda x: None, post=lambda x: None): 322 | if t is None: 323 | return 324 | pre(t) 325 | for sub in t.kids: 326 | walk(sub, pre, post) 327 | post(t) 328 | 329 | 330 | class IncrEvalTrap(BaseException): 331 | """ 332 | Used during re-evaluation of python line that threw exception to trap which 333 | subexpression caused the problem. 334 | """ 335 | def __init__(self, offending_expr): 336 | self.offending_expr = offending_expr # where in tree did we get exception? 
337 | -------------------------------------------------------------------------------- /tsensor/parsing.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2020 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | from io import BytesIO 25 | import token 26 | import keyword 27 | from tokenize import tokenize, \ 28 | NUMBER, STRING, NAME, OP, ENDMARKER, LPAR, LSQB, RPAR, RSQB, COMMA, COLON,\ 29 | PLUS, MINUS, STAR, SLASH, AT, PERCENT, TILDE, DOT,\ 30 | NOTEQUAL, PERCENTEQUAL, AMPEREQUAL, DOUBLESTAREQUAL, STAREQUAL, PLUSEQUAL,\ 31 | MINEQUAL, DOUBLESLASHEQUAL, SLASHEQUAL, LEFTSHIFTEQUAL,\ 32 | LESSEQUAL, EQUAL, EQEQUAL, GREATEREQUAL, RIGHTSHIFTEQUAL, ATEQUAL,\ 33 | CIRCUMFLEXEQUAL, VBAREQUAL 34 | 35 | import tsensor.ast 36 | # Lexer (mytokenize) plus recursive-descent parser (PyExprParser) that turn one Python statement into a tsensor.ast tree. 37 | 38 | ADDOP = {PLUS, MINUS} # binary operators handled by addexpr() 39 | MULOP = {STAR, SLASH, AT, PERCENT} # binary operators handled by multexpr(); AT is the @ matrix-multiply token 40 | ASSIGNOP = {NOTEQUAL, # NOTE(review): despite the name, this also holds the comparisons !=, <=, ==, >=; assignment_or_expr() splits at any of them, so `a == b` becomes an Assign node 41 | PERCENTEQUAL, 42 | AMPEREQUAL, 43 | DOUBLESTAREQUAL, 44 | STAREQUAL, 45 | PLUSEQUAL, 46 | MINEQUAL, 47 | DOUBLESLASHEQUAL, 48 | SLASHEQUAL, 49 | LEFTSHIFTEQUAL, 50 | LESSEQUAL, 51 | EQUAL, 52 | EQEQUAL, 53 | GREATEREQUAL, 54 | RIGHTSHIFTEQUAL, 55 | ATEQUAL, 56 | CIRCUMFLEXEQUAL, 57 | VBAREQUAL} 58 | UNARYOP = {TILDE} # only ~ is a prefix operator here; unaryexpr() rejects a leading - or + 59 | 60 | class Token: 61 | """My own version of a token, with content copied from Python's TokenInfo object.""" 62 | def __init__(self, type, value, 63 | index, # token index 64 | cstart_idx, # char start 65 | cstop_idx, # one past char end index so text[start_idx:stop_idx] works 66 | line): 67 | self.type, self.value, self.index, self.cstart_idx, self.cstop_idx, self.line = \ 68 | type, value, index, cstart_idx, cstop_idx, line 69 | def __repr__(self): 70 | return f"<{token.tok_name[self.type]}:{self.value},{self.cstart_idx}:{self.cstop_idx}>" 71 | def __str__(self): 72 | return self.value 73 | 74 | 75 | def mytokenize(s): 76 | "Use Python's tokenizer to lex s and collect my own token objects" 77 | tokensO = tokenize(BytesIO(s.encode('utf-8')).readline) 78 | tokens = [] 79 | i = 0 80 | for tok in tokensO: 81 | type, value, start, end, _ = tok # TokenInfo unpacks as (type, string, start, end, line) 82 | line = start[0] 83 | start_idx = start[1] # start/end are (row, col); NOTE(review): only the column is kept, so char indexes restart on each physical line of multi-line input 84 | stop_idx = end[1] # one past end index 85 | if type in {NUMBER, STRING, NAME, OP, ENDMARKER}: # drop ENCODING, NEWLINE, COMMENT, etc.; the parser never sees them 86 |
tokens.append(Token(tok.exact_type,value,i,start_idx,stop_idx,line)) # exact_type resolves generic OP tokens to LPAR, PLUS, etc. 87 | i += 1 88 | else: 89 | # print("ignoring", type, value) 90 | pass 91 | # It leaves ENDMARKER on end. set text to "" 92 | tokens[-1].value = "" 93 | # print(tokens) 94 | return tokens 95 | 96 | 97 | class PyExprParser: 98 | """ 99 | A recursive-descent parser for a subset of Python expressions and assignments. 100 | There is a built-in parser, but I only want to process Python code this library 101 | can handle and I also want my own kind of abstract syntax tree. Consequently, 102 | it's easier if I just parse the code I care about and ignore everything else. 103 | Building this parser was certainly no great burden. 104 | """ 105 | def __init__(self, code:str, hush_errors=True): 106 | self.code = code 107 | self.hush_errors = hush_errors 108 | self.tokens = mytokenize(code) # NOTE(review): this runs outside parse()'s hush_errors try/except, so a tokenize.TokenError (e.g. from an unclosed bracket) is raised even when hush_errors=True 109 | self.t = 0 # current lookahead: index into self.tokens of the next unconsumed token 110 | 111 | def parse(self): # returns the root node, or None for keyword-led statements and (when hush_errors) unparseable code 112 | # print("\nparse", self.code) 113 | # print(self.tokens) 114 | # only process assignments and expressions 115 | root = None 116 | if not keyword.iskeyword(self.tokens[0].value): # statements starting with a keyword (if, for, def, return, ...) are skipped 117 | if self.hush_errors: 118 | try: 119 | root = self.assignment_or_expr() 120 | self.match(ENDMARKER) 121 | except SyntaxError as e: # raised via error() when the grammar does not match 122 | root = None 123 | else: 124 | root = self.assignment_or_expr() 125 | self.match(ENDMARKER) 126 | return root 127 | 128 | def assignment_or_expr(self): # assignment_or_expr : expression (ASSIGNOP expression)? 129 | start = self.LT(1) 130 | lhs = self.expression() 131 | if self.LA(1) in ASSIGNOP: 132 | eq = self.LT(1) 133 | self.t += 1 134 | rhs = self.expression() 135 | stop = self.LT(-1) 136 | return tsensor.ast.Assign(eq,lhs,rhs,start,stop) 137 | return lhs 138 | 139 | def expression(self): # expression : addexpr 140 | return self.addexpr() 141 | 142 | def addexpr(self): # addexpr : multexpr (ADDOP multexpr)* -- builds a left-associative BinaryOp chain 143 | start = self.LT(1) 144 | root = self.multexpr() 145 | while self.LA(1) in ADDOP: 146 | op = self.LT(1) 147 | self.t += 1 148 | b = self.multexpr() 149 | stop = self.LT(-1) 150 | root = tsensor.ast.BinaryOp(op, root, b, start, stop) 151 | return
root 152 | 153 | def multexpr(self): # multexpr : unaryexpr (MULOP unaryexpr)* -- binds tighter than addexpr 154 | start = self.LT(1) 155 | root = self.unaryexpr() 156 | while self.LA(1) in MULOP: 157 | op = self.LT(1) 158 | self.t += 1 159 | b = self.unaryexpr() 160 | stop = self.LT(-1) 161 | root = tsensor.ast.BinaryOp(op, root, b, start, stop) 162 | return root 163 | 164 | def unaryexpr(self): # unaryexpr : UNARYOP unaryexpr | postexpr 165 | start = self.LT(1) 166 | if self.LA(1) in UNARYOP: 167 | op = self.LT(1) 168 | self.t += 1 169 | e = self.unaryexpr() 170 | stop = self.LT(-1) 171 | return tsensor.ast.UnaryOp(op, e, start, stop) 172 | elif self.isatom() or self.isgroup(): 173 | return self.postexpr() 174 | else: 175 | self.error(f"missing unary expr at: {self.LT(1)}") 176 | 177 | def postexpr(self): # postexpr : atom ( '(' arglist? ')' | '[' exprlist ']' | '.' NAME )* 178 | start = self.LT(1) 179 | root = self.atom() 180 | while self.LA(1) in {LPAR, LSQB, DOT}: # calls, indexing and member access chain left to right, e.g. a.b(x)[0] 181 | if self.LA(1)==LPAR: 182 | lp = self.LT(1) 183 | self.match(LPAR) 184 | el = [] 185 | if self.LA(1) != RPAR: 186 | el = self.arglist() 187 | self.match(RPAR) 188 | stop = self.LT(-1) 189 | root = tsensor.ast.Call(root, lp, el, start, stop) 190 | if self.LA(1)==LSQB: 191 | lb = self.LT(1) 192 | self.match(LSQB) 193 | el = self.exprlist() 194 | self.match(RSQB) 195 | stop = self.LT(-1) 196 | root = tsensor.ast.Index(root, lb, el, start, stop) 197 | if self.LA(1)==DOT: 198 | op = self.match(DOT) 199 | m = self.match(NAME) 200 | m = tsensor.ast.Atom(m) 201 | stop = self.LT(-1) 202 | root = tsensor.ast.Member(op, root, m, start, stop) 203 | return root 204 | 205 | def atom(self): # atom : subexpr | listatom | NUMBER | NAME | STRING | ':' 206 | if self.LA(1) == LPAR: 207 | return self.subexpr() 208 | elif self.LA(1) == LSQB: 209 | return self.listatom() 210 | elif self.LA(1) in {NUMBER, NAME, STRING, COLON}: # a bare ':' is an atom so indexes like a[:,0] parse 211 | atom = self.LT(1) 212 | self.t += 1 213 | return tsensor.ast.Atom(atom) 214 | else: 215 | self.error("unknown or missing atom:"+str(self.LT(1))) 216 | 217 | def exprlist(self): # exprlist : expression (',' expression)* -- a trailing ',' before ')' is left for subexpr() 218 | elist = [] 219 | e = self.expression() 220 | elist.append(e) 221 | while self.LA(1)==COMMA and self.LA(2)!=RPAR: # could be trailing comma in a
tuple like (3,4,) 222 | self.match(COMMA) 223 | e = self.expression() 224 | elist.append(e) 225 | return elist 226 | 227 | def arglist(self): # arglist : (arg | expression) (',' (arg | expression))* 228 | elist = [] 229 | if self.LA(1)==NAME and self.LA(2)==EQUAL: # NAME '=' starts a keyword argument 230 | e = self.arg() 231 | else: 232 | e = self.expression() 233 | elist.append(e) 234 | while self.LA(1)==COMMA: # NOTE(review): a trailing comma as in f(a,) is not accepted 235 | self.match(COMMA) 236 | if self.LA(1) == NAME and self.LA(2)==EQUAL: 237 | e = self.arg() 238 | else: 239 | e = self.expression() 240 | elist.append(e) 241 | return elist 242 | 243 | def arg(self): # arg : NAME '=' expression -- keyword argument, represented as an Assign node 244 | start = self.LT(1) 245 | kwarg = self.match(NAME) 246 | eq = self.match(EQUAL) 247 | e = self.expression() 248 | kwarg = tsensor.ast.Atom(kwarg) 249 | stop = self.LT(-1) 250 | return tsensor.ast.Assign(eq, kwarg, e, start, stop) 251 | 252 | def subexpr(self): # subexpr : '(' exprlist ','? ')' -- TupleLiteral if there is a comma, else SubExpr; () is not accepted 253 | start = self.match(LPAR) 254 | e = self.exprlist() # could be a tuple like (3,4) or even (3,4,) 255 | istuple = len(e)>1 256 | if self.LA(1)==COMMA: 257 | self.match(COMMA) 258 | istuple = True 259 | stop = self.match(RPAR) 260 | if istuple: 261 | return tsensor.ast.TupleLiteral(e, start, stop) 262 | subexpr = e[0] 263 | return tsensor.ast.SubExpr(subexpr, start, stop) 264 | 265 | def listatom(self): # listatom : '[' exprlist ']' -- NOTE(review): [] and a trailing comma as in [1,2,] are not accepted 266 | start = self.LT(1) 267 | self.match(LSQB) 268 | e = self.exprlist() 269 | self.match(RSQB) 270 | stop = self.LT(-1) 271 | return tsensor.ast.ListLiteral(e, start, stop) 272 | 273 | def isatom(self): # can the lookahead token start a simple atom? 274 | return self.LA(1) in {NUMBER, NAME, STRING, COLON} 275 | # return idstart(self.LA(1)) or self.LA(1).isdigit() or self.LA(1)==':' 276 | 277 | def isgroup(self): # does the lookahead token open a '(' or '[' group? 278 | return self.LA(1)==LPAR or self.LA(1)==LSQB 279 | 280 | def LA(self, i): # token type of the i-th lookahead token (see LT) 281 | return self.LT(i).type 282 | 283 | def LT(self, i): # LT(1) is the next unconsumed token, LT(-1) the last consumed one; lookahead past the end yields the final ENDMARKER 284 | if i==0: 285 | return None 286 | if i<0: 287 | return self.tokens[self.t + i] # -1 should give prev token 288 | ahead = self.t + i - 1 289 | if ahead >= len(self.tokens): 290 | return self.tokens[-1] # return last (end marker) 291 | return self.tokens[ahead] 292 | 293 | def
match(self, type): # consume and return the next token if it has this type, else raise SyntaxError via error() 294 | if self.LA(1)!=type: 295 | self.error(f"mismatch token {self.LT(1)}, looking for {token.tok_name[type]}") 296 | tok = self.LT(1) 297 | self.t += 1 298 | return tok 299 | 300 | def error(self, msg): # parse failures detected by the parser are raised as SyntaxError 301 | raise SyntaxError(msg) 302 | 303 | 304 | def parse(statement:str, hush_errors=True): 305 | """ 306 | Parse statement and return an (ast root, token list) tuple. Parse errors from invalid code 307 | or code that I cannot parse are ignored (root is None) unless hush_errors is False. 308 | """ 309 | p = tsensor.parsing.PyExprParser(statement, hush_errors=hush_errors) # NOTE(review): a tokenize.TokenError raised by the constructor is not hushed 310 | return p.parse(), p.tokens 311 | -------------------------------------------------------------------------------- /tsensor/version.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2020 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE.
23 | """ 24 | __version__ = '0.1a27' -------------------------------------------------------------------------------- /tsensor/viz.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2020 Terence Parr 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | import sys 25 | import tempfile 26 | import graphviz 27 | import token 28 | import matplotlib.patches as patches 29 | import matplotlib.pyplot as plt 30 | from IPython.display import display, SVG 31 | from IPython import get_ipython 32 | 33 | import numpy as np 34 | import tsensor 35 | import tsensor.ast 36 | import tsensor.analysis 37 | import tsensor.parsing 38 | 39 | 40 | class PyVizView: 41 | """ 42 | An object that collects relevant information about viewing Python code 43 | with visual annotations. 
44 | """ 45 | def __init__(self, statement, fontname, fontsize, dimfontname, dimfontsize, 46 | matrixcolor, vectorcolor, char_sep_scale, dpi): 47 | self.statement = statement 48 | self.fontsize = fontsize 49 | self.fontname = fontname 50 | self.dimfontsize = dimfontsize 51 | self.dimfontname = dimfontname 52 | self.matrixcolor = matrixcolor 53 | self.vectorcolor = vectorcolor 54 | self.char_sep_scale = char_sep_scale 55 | self.dpi = dpi 56 | self.wchar = self.char_sep_scale * self.fontsize 57 | self.hchar = self.char_sep_scale * self.fontsize 58 | self.dim_ypadding = 5 59 | self.dim_xpadding = 0 60 | self.linewidth = .7 61 | self.leftedge = 25 62 | self.bottomedge = 3 63 | self.svgfilename = None 64 | self.matrix_size_scaler = 3.5 # How wide or tall as scaled fontsize is matrix? 65 | self.vector_size_scaler = 3.2 / 4 # How wide or tall as scaled fontsize is vector for skinny part? 66 | self.shift3D = 6 67 | self.cause = None # Did an exception occurred during evaluation? 68 | self.offending_expr = None 69 | 70 | def set_locations(self, maxh): 71 | """ 72 | This function finishes setting up necessary parameters about text 73 | and graphics locations for the plot. We don't know how to set these 74 | values until we know what the max height of the drawing will be. We don't know 75 | what that height is until after we've parsed and so on, which requires that 76 | we collect and store information in this view object before computing maxh. 77 | That is why this is a separate function not part of the constructor. 78 | """ 79 | line2text = self.hchar / 1.7 80 | box2line = line2text*2.6 81 | self.texty = self.bottomedge + maxh + box2line + line2text 82 | self.liney = self.bottomedge + maxh + box2line 83 | self.box_topy = self.bottomedge + maxh 84 | self.maxy = self.texty + 1.4 * self.fontsize 85 | 86 | def _repr_svg_(self): 87 | "Show an SVG rendition in a notebook" 88 | return self.svg() 89 | 90 | def svg(self): 91 | """ 92 | Render as svg and return svg text. 
Save file and store name in field svgfilename. 93 | """ 94 | if self.svgfilename is None: # cached? 95 | self.svgfilename = tempfile.mktemp(suffix='.svg') 96 | self.savefig(self.svgfilename) 97 | with open(self.svgfilename, encoding='UTF-8') as f: 98 | svg = f.read() 99 | return svg 100 | 101 | def savefig(self, filename): 102 | "Save viz in format according to file extension." 103 | plt.savefig(filename, dpi = self.dpi, bbox_inches = 'tight', pad_inches = 0) 104 | 105 | def show(self): 106 | "Display an SVG in a notebook or pop up a window if not in notebook" 107 | if get_ipython() is None: 108 | svgfilename = tempfile.mktemp(suffix='.svg') 109 | self.savefig(svgfilename) 110 | self.filename = svgfilename 111 | plt.show() 112 | else: 113 | svg = self.svg() 114 | display(SVG(svg)) 115 | plt.close() 116 | 117 | def boxsize(self, v): 118 | """ 119 | How wide and tall should we draw the box representing a vector or matrix. 120 | """ 121 | sh = tsensor.analysis._shape(v) 122 | if sh is None: return None 123 | if len(sh)==1: return self.vector_size(sh) 124 | return self.matrix_size(sh) 125 | 126 | def matrix_size(self, sh): 127 | """ 128 | How wide and tall should we draw the box representing a matrix. 129 | """ 130 | if len(sh)==1 and sh[0]==1: 131 | return self.vector_size(sh) 132 | elif len(sh) > 1 and sh[0] == 1 and sh[1] == 1: 133 | # A special case where we have a 1x1 matrix extending into the screen. 
134 | # Make the 1x1 part a little bit wider than a vector so it's more readable 135 | return (2*self.vector_size_scaler * self.wchar, 2*self.vector_size_scaler * self.wchar) 136 | elif len(sh)>1 and sh[1]==1: 137 | return (self.vector_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar) 138 | return (self.matrix_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar) 139 | 140 | def vector_size(self, sh): 141 | return (self.matrix_size_scaler * self.wchar, self.vector_size_scaler * self.wchar) 142 | 143 | def draw(self, ax, sub): 144 | sh = tsensor.analysis._shape(sub.value) 145 | if len(sh)==1: self.draw_vector(ax, sub) 146 | else: self.draw_matrix(ax, sub) 147 | 148 | def draw_vector(self,ax,sub): 149 | a, b = sub.leftx, sub.rightx 150 | mid = (a + b) / 2 151 | sh = tsensor.analysis._shape(sub.value) 152 | w,h = self.vector_size(sh) 153 | rect1 = patches.Rectangle(xy=(mid - w/2, self.box_topy-h), 154 | width=w, 155 | height=h, 156 | linewidth=self.linewidth, 157 | facecolor=self.vectorcolor, 158 | edgecolor='grey', 159 | fill=True) 160 | ax.add_patch(rect1) 161 | ax.text(mid, self.box_topy + self.dim_ypadding, self.nabbrev(sh[0]), 162 | horizontalalignment='center', 163 | fontname=self.dimfontname, fontsize=self.dimfontsize) 164 | 165 | def draw_matrix(self,ax,sub): 166 | a, b = sub.leftx, sub.rightx 167 | mid = (a + b) / 2 168 | sh = tsensor.analysis._shape(sub.value) 169 | w,h = self.matrix_size(sh) 170 | box_left = mid - w / 2 171 | if len(sh)>2: 172 | back_rect = patches.Rectangle(xy=(box_left + self.shift3D, self.box_topy - h + self.shift3D), 173 | width=w, 174 | height=h, 175 | linewidth=self.linewidth, 176 | facecolor=self.matrixcolor, 177 | edgecolor='grey', 178 | fill=True) 179 | ax.add_patch(back_rect) 180 | rect = patches.Rectangle(xy=(box_left, self.box_topy - h), 181 | width=w, 182 | height=h, 183 | linewidth=self.linewidth, 184 | facecolor=self.matrixcolor, 185 | edgecolor='grey', 186 | fill=True) 187 | ax.add_patch(rect) 188 | 
ax.text(box_left, self.box_topy - h/2, self.nabbrev(sh[0]), 189 | verticalalignment='center', horizontalalignment='right', 190 | fontname=self.dimfontname, fontsize=self.dimfontsize, rotation=90) 191 | if len(sh)>1: 192 | textx = mid 193 | texty = self.box_topy + self.dim_ypadding 194 | if len(sh) > 2: 195 | texty += self.dim_ypadding 196 | textx += self.shift3D 197 | ax.text(textx, texty, self.nabbrev(sh[1]), horizontalalignment='center', 198 | fontname=self.dimfontname, fontsize=self.dimfontsize) 199 | if len(sh)>2: 200 | ax.text(box_left+w, self.box_topy - h/2, self.nabbrev(sh[2]), 201 | verticalalignment='center', horizontalalignment='center', 202 | fontname=self.dimfontname, fontsize=self.dimfontsize, 203 | rotation=45) 204 | if len(sh)>3: 205 | remaining = "$\cdots$x"+'x'.join([self.nabbrev(sh[i]) for i in range(3,len(sh))]) 206 | ax.text(mid, self.box_topy - h - self.dim_ypadding, remaining, 207 | verticalalignment='top', horizontalalignment='center', 208 | fontname=self.dimfontname, fontsize=self.dimfontsize) 209 | 210 | @staticmethod 211 | def nabbrev(n) -> str: 212 | if n % 1_000_000 == 0: 213 | return str(n // 1_000_000)+'m' 214 | if n % 1_000 == 0: 215 | return str(n // 1000)+'k' 216 | return str(n) 217 | 218 | 219 | def pyviz(statement: str, frame=None, 220 | fontname='Consolas', fontsize=13, 221 | dimfontname='Arial', dimfontsize=9, matrixcolor="#cfe2d4", 222 | vectorcolor="#fefecd", char_sep_scale=1.8, fontcolor='#444443', 223 | underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227', 224 | dimorder=None, 225 | ax=None, dpi=200, hush_errors=True) -> PyVizView: 226 | """ 227 | Parse and evaluate the Python code in argument statement (string) using 228 | the indicated execution frame. The execution frame of the invoking function 229 | is used if frame is None. 
230 | 231 | The visualization finds the smallest subexpressions that evaluate to 232 | tensors then underlies them and shows a box or rectangle representing 233 | the tensor dimensions. Boxes in blue (default) have two or more dimensions 234 | but rectangles in yellow (default) have one dimension with shape (n,). 235 | 236 | Upon tensor-related execution error, the offending self-expression is 237 | highlighted (by de-highlighting the other code) and the operator is shown 238 | using error_op_color. 239 | 240 | To adjust the size of the generated visualization to be smaller or bigger, 241 | decrease or increase the font size. 242 | 243 | :param statement: A string representing the line of Python code to visualize within an execution frame. 244 | :param frame: The execution frame in which to evaluate the statement. If None, 245 | use the execution frame of the invoking function 246 | :param fontname: The name of the font used to display Python code 247 | :param fontsize: The font size used to display Python code; default is 13. 248 | Also use this to increase the size of the generated figure; 249 | larger font size means larger image. 250 | :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes 251 | :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes 252 | :param matrixcolor: The color of matrix boxes 253 | :param vectorcolor: The color of vector boxes; only for tensors whose shape is (n,). 254 | :param char_sep_scale: It is notoriously difficult to discover how wide and tall 255 | text is when plotted in matplotlib. In fact there's probably, 256 | no hope to discover this information accurately in all cases. 257 | Certainly, I gave up after spending huge effort. 
We have a 258 | situation here where the font should be constant width, so 259 | we can just use a simple scaler times the font size to get 260 | a reasonable approximation to the width and height of a 261 | character box; the default of 1.8 seems to work reasonably 262 | well for a wide range of fonts, but you might have to tweak it 263 | when you change the font size. 264 | :param fontcolor: The color of the Python code. 265 | :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey 266 | :param ignored_color: The de-highlighted color for deemphasizing code not involved in an erroneous sub expression 267 | :param error_op_color: The color to use for characters associated with the erroneous operator 268 | :param dimorder: When training deep learning models in batches, we must add a 269 | batch dimension to our training data matrix. The dimension order 270 | required by the deep learning model might be different than the way 271 | we want to visualize it. For example, if each input instance is 272 | a 2D image, then our training data has shape, say, (width,height,n). 273 | If we need to add the batch dimension first, that changes to 274 | (batch,width,height,n) but we still want to visualize the input 275 | as width,height as a 2D picture with a number of instances going 276 | back into the screen (like a 3D cube). It would be convenient to 277 | visualize the batch dimension as the fourth. This parameter 278 | allows you to specify the order of dimensions for any variable 279 | referenced in the code being visualized. E.g., if the input is in 280 | variable X with shape (batch,width,height,n) but we want to display 281 | X as (width,height,n,batch), pass in dimorder=dict(X=[1,2,3,0]). 282 | :param ax: If not none, this is the matplotlib drawing region in which to draw the visualization 283 | :param dpi: This library tries to generate SVG files, which are vector graphics not 284 | 2D arrays of pixels like PNG files. 
However, it needs to know how to 285 | compute the exact figure size to remove padding around the visualization. 286 | Matplotlib uses inches for its figure size and so we must convert 287 | from pixels or data units to inches, which means we have to know what the 288 | dots per inch, dpi, is for the image. 289 | :param hush_errors: Normally, error messages from true syntax errors but also 290 | unhandled code caught by my parser are ignored. Turn this off 291 | to see what the error messages are coming from my parser. 292 | :return: Returns a PyVizView holding info about the visualization; from a notebook 293 | an SVG image will appear. Return none upon parsing error in statement. 294 | """ 295 | view = PyVizView(statement, fontname, fontsize, dimfontname, dimfontsize, matrixcolor, 296 | vectorcolor, char_sep_scale, dpi) 297 | 298 | if frame is None: # use frame of caller if not passed in 299 | frame = sys._getframe().f_back 300 | root, tokens = tsensor.parsing.parse(statement, hush_errors=hush_errors) 301 | if root is None: 302 | # likely syntax error in statement or code I can't handle 303 | return None 304 | root_to_viz = root 305 | try: 306 | root.eval(frame) 307 | except tsensor.ast.IncrEvalTrap as e: 308 | root_to_viz = e.offending_expr 309 | view.offending_expr = e.offending_expr 310 | view.cause = e.__cause__ 311 | # Don't raise the exception; keep going to visualize code and erroneous 312 | # subexpressions. If this function is invoked from clarify() or explain(), 313 | # the statement will be executed and will fail again during normal execution; 314 | # an exception will be thrown at that time. 
Then explain/clarify 315 | # will update the error message 316 | subexprs = tsensor.analysis.smallest_matrix_subexpr(root_to_viz) 317 | 318 | # print(statement) # For debugging 319 | # for i in range(8): 320 | # for j in range(10): 321 | # print(j,end='') 322 | # print() 323 | 324 | if ax is None: 325 | fig, ax = plt.subplots(1, 1, dpi=dpi) 326 | else: 327 | fig = ax.figure 328 | 329 | ax.axis("off") 330 | 331 | # First, we need to figure out how wide the visualization components are 332 | # for each sub expression. If these are wider than the sub expression text, 333 | # than we need to leave space around the sub expression text 334 | lpad = np.zeros((len(statement),)) # pad for characters 335 | rpad = np.zeros((len(statement),)) 336 | maxh = 0 337 | for sub in subexprs: 338 | w, h = view.boxsize(sub.value) 339 | maxh = max(h, maxh) 340 | nexpr = sub.stop.cstop_idx - sub.start.cstart_idx 341 | if (sub.start.cstart_idx-1)>0 and statement[sub.start.cstart_idx - 1]== ' ': # if char to left is space 342 | nexpr += 1 343 | if sub.stop.cstop_idxview.wchar * nexpr: 346 | lpad[sub.start.cstart_idx] += (w - view.wchar) / 2 347 | rpad[sub.stop.cstop_idx - 1] += (w - view.wchar) / 2 348 | 349 | # Now we know how to place all the elements, since we know what the maximum height is 350 | view.set_locations(maxh) 351 | 352 | # Find each character's position based upon width of a character and any padding 353 | charx = np.empty((len(statement),)) 354 | x = view.leftedge 355 | for i,c in enumerate(statement): 356 | x += lpad[i] 357 | charx[i] = x 358 | x += view.wchar 359 | x += rpad[i] 360 | 361 | # Draw text for statement or expression 362 | if view.offending_expr is not None: # highlight erroneous subexpr 363 | highlight = np.full(shape=(len(statement),), fill_value=False, dtype=bool) 364 | for tok in tokens[root_to_viz.start.index:root_to_viz.stop.index+1]: 365 | highlight[tok.cstart_idx:tok.cstop_idx] = True 366 | errors = np.full(shape=(len(statement),), fill_value=False, 
dtype=bool) 367 | for tok in root_to_viz.optokens: 368 | errors[tok.cstart_idx:tok.cstop_idx] = True 369 | for i, c in enumerate(statement): 370 | color = ignored_color 371 | if highlight[i]: 372 | color = fontcolor 373 | if errors[i]: # override color if operator token 374 | color = error_op_color 375 | ax.text(charx[i], view.texty, c, color=color, fontname=fontname, fontsize=fontsize) 376 | else: 377 | for i, c in enumerate(statement): 378 | ax.text(charx[i], view.texty, c, color=fontcolor, fontname=fontname, fontsize=fontsize) 379 | 380 | # Compute the left and right edges of subexpressions (alter nodes with info) 381 | for i,sub in enumerate(subexprs): 382 | a = charx[sub.start.cstart_idx] 383 | b = charx[sub.stop.cstop_idx - 1] + view.wchar 384 | sub.leftx = a 385 | sub.rightx = b 386 | 387 | # Draw grey underlines and draw matrices 388 | for i,sub in enumerate(subexprs): 389 | a,b = sub.leftx, sub.rightx 390 | pad = view.wchar*0.1 391 | ax.plot([a-pad, b+pad], [view.liney,view.liney], '-', linewidth=.5, c=underline_color) 392 | view.draw(ax, sub) 393 | 394 | fig_width = charx[-1] + view.wchar + rpad[-1] 395 | fig_width_inches = (fig_width) / dpi 396 | fig_height_inches = view.maxy / dpi 397 | fig.set_size_inches(fig_width_inches, fig_height_inches) 398 | 399 | ax.set_xlim(0, (fig_width)) 400 | ax.set_ylim(0, view.maxy) 401 | 402 | return view 403 | 404 | 405 | # ---------------- SHOW AST STUFF --------------------------- 406 | 407 | class QuietGraphvizWrapper(graphviz.Source): 408 | def __init__(self, dotsrc): 409 | super().__init__(source=dotsrc) 410 | 411 | def _repr_svg_(self): 412 | return self.pipe(format='svg', quiet=True).decode(self._encoding) 413 | 414 | 415 | def astviz(statement:str, frame=None) -> graphviz.Source: 416 | return QuietGraphvizWrapper(astviz_dot(statement, frame)) 417 | 418 | 419 | def astviz_dot(statement:str, frame=None) -> str: 420 | def internal_label(node,color="yellow"): 421 | text = ''.join(str(t) for t in node.optokens) 422 | 
sh = tsensor.analysis._shape(node.value) 423 | if sh is None: 424 | return f'{text}' 425 | 426 | sz = 'x'.join([PyVizView.nabbrev(sh[i]) for i in range(len(sh))]) 427 | print(sz) 428 | return f"""{text}
{sz}""" 429 | 430 | root, tokens = tsensor.parsing.parse(statement) 431 | if frame is not None: 432 | root.eval(frame) 433 | 434 | nodes = tsensor.ast.postorder(root) 435 | atoms = tsensor.ast.leaves(root) 436 | atomsS = set(atoms) 437 | ops = [nd for nd in nodes if nd not in atomsS] # keep order 438 | 439 | gr = """digraph G { 440 | margin=0; 441 | nodesep=.01; 442 | ranksep=.3; 443 | rankdir=BT; 444 | ordering=out; # keep order of leaves 445 | """ 446 | 447 | matrixcolor = "#cfe2d4" 448 | vectorcolor = "#fefecd" 449 | fontname="Consolas" 450 | fontsize=12 451 | dimfontsize = 9 452 | spread = 0 453 | 454 | # Gen leaf nodes 455 | for i in range(len(tokens)): 456 | t = tokens[i] 457 | if t.type!=token.ENDMARKER: 458 | nodetext = t.value 459 | # if ']' in nodetext: 460 | if nodetext==']': 461 | nodetext = nodetext.replace(']','‌]') # ‌ is 0-width nonjoiner. ']' by itself is bad for DOT 462 | label = f'{nodetext}' 463 | _spread = spread 464 | if t.type==token.DOT: 465 | _spread=.1 466 | elif t.type==token.EQUAL: 467 | _spread=.25 468 | elif t.type in tsensor.parsing.ADDOP: 469 | _spread=.4 470 | elif t.type in tsensor.parsing.MULOP: 471 | _spread=.2 472 | gr += f'leaf{id(t)} [shape=box penwidth=0 margin=.001 width={_spread} label=<{label}>]\n' 473 | 474 | # Make sure leaves are on same level 475 | gr += f'{{ rank=same; ' 476 | for t in tokens: 477 | if t.type!=token.ENDMARKER: 478 | gr += f' leaf{id(t)}' 479 | gr += '\n}\n' 480 | 481 | # Make sure leaves are left to right by linking 482 | for i in range(len(tokens) - 2): 483 | t = tokens[i] 484 | t2 = tokens[i + 1] 485 | gr += f'leaf{id(t)} -> leaf{id(t2)} [style=invis];\n' 486 | 487 | # Draw internal ops nodes 488 | for nd in ops: 489 | label = internal_label(nd) 490 | sh = tsensor.analysis._shape(nd.value) 491 | if sh is None: 492 | color = "" 493 | else: 494 | if len(sh)==1: 495 | color = f'fillcolor="{vectorcolor}" style=filled' 496 | else: 497 | color = f'fillcolor="{matrixcolor}" style=filled' 498 | gr += 
f'node{id(nd)} [shape=box {color} penwidth=0 margin=0 width=.25 height=.2 label=<{label}>]\n' 499 | 500 | # Link internal nodes to other nodes or leaves 501 | for nd in nodes: 502 | kids = nd.kids 503 | for sub in kids: 504 | if sub in atomsS: 505 | gr += f'node{id(nd)} -> leaf{id(sub.token)} [dir=back, penwidth="0.5", color="#6B6B6B", arrowsize=.3];\n' 506 | else: 507 | gr += f'node{id(nd)} -> node{id(sub)} [dir=back, penwidth="0.5", color="#6B6B6B", arrowsize=.3];\n' 508 | 509 | gr += "}\n" 510 | return gr --------------------------------------------------------------------------------
b = torch.abs( W @ b +
\n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | "
40
100self.W
\n", 590 | "
\n", 591 | "
)