\n",
576 | "\n",
577 | " b = torch.abs( W @ b + | \n",
578 | " \n",
579 | " \n",
580 | " \n",
581 | " | \n",
582 | " 40 | | \n",
583 | " \n",
584 | " \n",
585 | " 100 | \n",
586 | " self.W | \n",
587 | " | \n",
588 | " \n",
589 | " \n",
590 | " \n",
591 | " | \n",
592 | " ) | \n",
593 | "
\n",
594 | "\"\"\"\n",
595 | "t = HTML(h)\n",
596 | "t"
597 | ]
598 | },
599 | {
600 | "cell_type": "code",
601 | "execution_count": null,
602 | "metadata": {},
603 | "outputs": [],
604 | "source": [
605 | "def matrix(nrows,ncols,text,dim_fontsize=9,fontsize=12):\n",
606 | " h = f\"\"\"\n",
607 | " \n",
608 | "
\n",
609 | "
\n",
610 | " {ncols}\n",
611 | "
\n",
612 | "
\n",
613 | " {nrows}\n",
614 | "
\n",
615 | "
\n",
616 | "
\n",
617 | " {text}\n",
618 | "
\n",
619 | "
\n",
620 | " \"\"\"\n",
621 | " sp = f\"\"\"\n",
622 | " \n",
623 | " \n",
624 | " \n",
625 | " {ncols}\n",
626 | " \n",
627 | " \n",
628 | " {nrows}\n",
629 | " \n",
630 | " \n",
631 | " \n",
632 | " {text}\n",
633 | " \n",
634 | "
\n",
635 | " \"\"\"\n",
636 | "\n",
637 | " return sp\n",
638 | "\n",
639 | "m = matrix(120,20,\"self.W\")\n",
640 | "m = f\"torch.relu({m})\"\n",
641 | "HTML(m)"
642 | ]
643 | },
644 | {
645 | "cell_type": "code",
646 | "execution_count": null,
647 | "metadata": {},
648 | "outputs": [],
649 | "source": []
650 | },
651 | {
652 | "cell_type": "code",
653 | "execution_count": null,
654 | "metadata": {},
655 | "outputs": [],
656 | "source": []
657 | }
658 | ],
659 | "metadata": {
660 | "kernelspec": {
661 | "display_name": "Python 3",
662 | "language": "python",
663 | "name": "python3"
664 | },
665 | "language_info": {
666 | "codemirror_mode": {
667 | "name": "ipython",
668 | "version": 3
669 | },
670 | "file_extension": ".py",
671 | "mimetype": "text/x-python",
672 | "name": "python",
673 | "nbconvert_exporter": "python",
674 | "pygments_lexer": "ipython3",
675 | "version": "3.8.3"
676 | }
677 | },
678 | "nbformat": 4,
679 | "nbformat_minor": 4
680 | }
681 |
--------------------------------------------------------------------------------
/testing/str_size_matplotlib.py:
--------------------------------------------------------------------------------
1 | import matplotlib as mpl
2 | import matplotlib.patches as patches
3 | from matplotlib import pyplot as plt
4 |
# From: https://stackoverflow.com/questions/22667224/matplotlib-get-text-bounding-box-independent-of-backend
def find_renderer(fig):
    """Return a renderer that can measure artists drawn in ``fig``.

    Lookup order:
      1. ``fig.canvas.get_renderer()`` -- some backends, such as TkAgg, have
         this method, which makes it easy.
      2. ``fig._get_renderer()`` -- Matplotlib >= 3.6, which removed the
         ``Figure._cachedRenderer`` attribute that the legacy trick relies on.
      3. Legacy trick (borrowed from ``backend_bases.print_figure()``): print
         the figure to a throwaway in-memory file and grab the renderer that
         Matplotlib cached on the figure while drawing it.

    :param fig: a ``matplotlib.figure.Figure``
    :return: a renderer suitable for ``Artist.get_window_extent``
    """
    if hasattr(fig.canvas, "get_renderer"):
        return fig.canvas.get_renderer()
    if hasattr(fig, "_get_renderer"):
        return fig._get_renderer()
    import io
    fig.canvas.print_pdf(io.BytesIO())
    return fig._cachedRenderer


def textdim(s, fontname='Consolas', fontsize=11):
    """Measure the rendered size of ``s`` and return ``(width, height)``.

    The text is drawn on a throwaway figure with ``transform=None`` (display
    coordinates), measured with ``get_window_extent`` using the renderer from
    ``find_renderer``, printed, and the figure is closed.

    :param s: text to measure
    :param fontname: font family name passed to ``Axes.text``
    :param fontsize: font size in points
    :return: ``(width, height)`` of the text's window extent (display units)
    """
    fig, ax = plt.subplots(1, 1)
    try:
        t = ax.text(0, 0, s, fontname=fontname, fontsize=fontsize, transform=None)
        bb = t.get_window_extent(find_renderer(fig))
    finally:
        # Close exactly this figure -- even if measuring fails -- so callers
        # holding their own figure open are not affected.
        plt.close(fig)
    print(s, bb.width, bb.height)
    return bb.width, bb.height
50 |
# print(textdim("test of foo", fontsize=11))
# print(textdim("test of foO", fontsize=11))
# print(textdim("W @ b + x * 3 + h.dot(h)", fontsize=12))

# Experiment: lay out `code` one character at a time, advancing by each
# character's measured width, then save the figure to inspect the spacing.
# code = 'W@ b + x *3 + h.dot(h)'
code = 'W@ b.f(x,y)'  # + x *3 + h.dot(h)'

fontname = "Serif"
fontsize = 16
OUTPUT_PATH = "/tmp/t.pdf"

fig, ax = plt.subplots(1, 1, figsize=(4, 1))

x = 0
for c in code:
    # transform=None draws in display coordinates -- the same units textdim()
    # measures in -- so the per-character widths can be summed directly.
    ax.text(x, 10, c, fontname=fontname, fontsize=fontsize, transform=None)
    w, h = textdim(c, fontname=fontname, fontsize=fontsize)
    x = x + w

ax.set_xlim(0, x)
ax.axis('off')

# textdim() creates and closes a figure for every character, so address this
# figure explicitly instead of relying on pyplot's "current figure".
fig.tight_layout()
fig.savefig(OUTPUT_PATH, bbox_inches='tight', pad_inches=0, dpi=200)
90 |
--------------------------------------------------------------------------------
/testing/test1.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2020 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | import numpy as np
25 | import dis
26 |
a = np.array([1, 2, 3])
g = globals()

from inspect import currentframe, getframeinfo, stack


def _source_line(filename, lineno):
    """Return 1-based line ``lineno`` of ``filename`` without its newline,
    or ``None`` if that line doesn't exist or the file can't be read."""
    import linecache

    if lineno < 1:
        return None
    linecache.checkcache(filename)  # re-read the file if it changed on disk
    text = linecache.getline(filename, lineno)
    # getline() returns '' for a missing file or an out-of-range line, while a
    # genuinely blank line comes back as '\n'.
    return text.rstrip('\n') if text else None


def info():
    """Return ``(filename, lineno)`` of the line that called ``info()``."""
    prev = stack()[1]
    return prev.filename, prev.lineno


def prevline():
    """Return the source line directly above the caller's line, or ``None``
    when the caller is on line 1 (or its source can't be read)."""
    prev = stack()[1]
    return _source_line(prev.filename, prev.lineno - 1)


def nextline(n):
    """Return the source line ``n`` lines after the caller's line, or ``None``
    if that line lies outside the caller's file (or it can't be read)."""
    prev = stack()[1]
    return _source_line(prev.filename, prev.lineno + n)
57 |
58 |
def f():
    """Scratch experiment: exec the statement two lines below the nextline()
    call, run that statement again for real, then print ``b*x`` as evaluated
    in this frame."""
    W = np.array([[1, 2], [3, 4]])
    b, x = np.array([[9], [10]]), np.array([[4], [5]])  # both (2, 1)
    code = nextline(2).strip()  # must stay exactly two lines above its target
    exec(code)
    b = W @ b + x
    print(eval("b*x"))
68 |
class dbg:
    """Scratch context manager: on entry, disassemble the source line directly
    after the ``with dbg():`` statement and print the variable and method
    names it loads. The ``with`` body itself then runs normally."""

    def __enter__(self):
        import linecache
        import sys

        prev = stack()[1]  # frame executing the `with` statement
        filename, line = prev.filename, prev.lineno
        # Inspect the line after `with dbg():` (1-based numbering). linecache
        # yields '' when the file/line is unavailable, so we analyse empty code
        # instead of raising IndexError/OSError.
        linecache.checkcache(filename)
        code = linecache.getline(filename, line + 1).strip()
        print("code to dbg", code)
        c = dis.Bytecode(code)
        print(c.dis())
        VARLOADS = {'LOAD_NAME', 'LOAD_GLOBAL'}
        varrefs = [I.argval for I in c if I.opname in VARLOADS]
        # Python 3.12 removed LOAD_METHOD: method lookups became LOAD_ATTR with
        # the low bit of the raw oparg set.
        funcrefs = [I.argval for I in c
                    if I.opname == 'LOAD_METHOD'
                    or (sys.version_info >= (3, 12)
                        and I.opname == 'LOAD_ATTR' and I.arg & 1)]
        print("symbols", set(varrefs), set(funcrefs))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("exit")
88 |
def hi():
    """Print a greeting and hand back the sentinel value 99."""
    print("hi")
    return 99
90 |
91 |
W = np.array([[1, 2], [3, 4]])
b = np.array([9, 10]).reshape(2, 1)
x = np.array([4, 5]).reshape(2, 1)

# dbg() analyses only the line right after the `with`, and that body then
# really executes. The former probe line,
#   z = torch.sigmoid(self.Whz @ h + self.Uxz @ x + self.bz)
# used `torch`, `self` and `h`, none of which exist in this script, so it raised
# NameError as soon as dbg() finished printing. Probe a runnable line instead.
with dbg():
    b = W @ b + np.abs(x)
99 |
100 | # dis.dis("f()")
101 | # dis.dis("a.f()")
102 | # dis.dis("np.pi")
103 | # dis.dis("a+3")
--------------------------------------------------------------------------------
/testing/test2.py:
--------------------------------------------------------------------------------
1 | import tsensor
2 | import numpy as np
3 | import torch
4 | import matplotlib.pyplot as plt
5 |
W = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[9], [10]])  # column vector, shape (2, 1)
x = torch.tensor([[4], [5]])   # column vector, shape (2, 1)
h = torch.tensor([1, 2])


# tsensor.pyviz("a = torch.relu(x)")
# plt.show()

# np.dot(b, b) pairs two (2, 1) operands whose inner dimensions (1 vs 2) don't
# match, so this line is expected to raise; clarify() then augments the error
# and visualises this exact source line.
with tsensor.clarify():
    W @ np.dot(b, b) + np.eye(2, 2) @ x
    # b = W @ b + x * 3 + h.dot(h)
--------------------------------------------------------------------------------
/testing/test3.py:
--------------------------------------------------------------------------------
1 | import tsensor
2 | import tensorflow as tf
3 | import matplotlib.pyplot as plt
4 |
W = tf.constant([[1, 2], [3, 4]])
b = tf.constant([[9], [10]])      # shape (2, 1)
x = tf.constant([[8], [5], [7]])  # shape (3, 1): not broadcast-compatible with b
z = 0

# tsensor.parse("z /= b + x * 3", hush_errors=False)

# with tsensor.clarify(show='viz'):
#     b + x * 3

# Draw the shape diagram for `b + x` into an explicitly created Axes.
fig, ax = plt.subplots(1, 1)
tsensor.pyviz("b + x", ax=ax)
plt.show()

# with tsensor.explain():
#     b + x
20 |
21 |
--------------------------------------------------------------------------------
/testing/test_incr_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2020 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | from tsensor.ast import IncrEvalTrap
25 | from tsensor.parsing import PyExprParser
26 | import sys
27 | import numpy as np
28 | import torch
29 |
def check(s, expected):
    """Parse ``s``, evaluate it incrementally in the caller's frame, and assert
    that the subexpression blamed by ``IncrEvalTrap`` renders as ``expected``.

    If evaluation does not trap, the blamed subexpression stays ``None`` and
    the assertion fails.
    """
    caller = sys._getframe(1)  # the test function whose locals `s` refers to
    tree = PyExprParser(s).parse()
    try:
        tree.eval(caller)
    except IncrEvalTrap as exc:
        culprit = str(exc.offending_expr)
    else:
        culprit = None
    assert culprit == expected


def test_missing_var():
    a, c = 3, 5
    check("a+b+c", "b")
    check("z+b+c", "z")


def test_matrix_mult():
    W = torch.tensor([[1, 2], [3, 4]])  # (2, 2)
    b = torch.tensor([[1, 2, 3]])       # (1, 3): incompatible with W for @
    check("W@b+torch.abs(b)", "W@b")


def test_bad_arg():
    check("torch.abs('foo')", "torch.abs('foo')")


def test_parens():
    a, b, c = 3, 4, 5
    check("(a+b)/0", "(a+b)/0")


def test_array_literal():
    a = torch.tensor([[1, 2, 3], [4, 5, 6]])
    b = torch.tensor([[1, 2, 3]])
    a + b  # (2, 3) + (1, 3) broadcasts fine, so only b@2 can trap below
    check("a + b@2", """b@2""")


def test_array_literal2():
    a = torch.tensor([[1, 2, 3], [4, 5, 6]])
    b = torch.tensor([[1, 2, 3]])
    a + b  # the sum itself is valid; the trap comes from `@ 2`
    check("(a+b)@2", """(a+b)@2""")
74 |
--------------------------------------------------------------------------------
/testing/test_parser.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2020 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | from tsensor.parsing import *
25 | import re
26 |
def check(s, expected_repr, expect_str=None):
    """Parse ``s`` and verify both renderings of the resulting tree.

    ``str(tree)`` must reproduce the input source -- or ``expect_str`` when
    the parser normalises it, e.g. dropping a trailing tuple comma -- and
    ``repr(tree)`` must equal ``expected_repr``. All comparisons ignore
    whitespace, so expectations may be laid out over several lines.

    :param s: source text of the expression/statement to parse
    :param expected_repr: expected ``repr`` of the parse tree
    :param expect_str: expected ``str`` of the tree if it differs from ``s``;
                       ``None`` means "same as ``s``"
    """
    def squash(text):
        return re.sub(r"\s+", "", text)

    t = PyExprParser(s, hush_errors=False).parse()

    # Normalise the expectation exactly like the result; comparing expect_str
    # raw meant any whitespace inside it could never match.
    expected_str = s if expect_str is None else expect_str
    assert squash(str(t)) == squash(expected_str)

    assert squash(repr(t)) == squash(expected_repr)
44 |
45 |
# Parser round-trip tests: each case feeds a snippet to check(), which compares
# str() and repr() of the parse tree with the expectations, ignoring whitespace.
# NOTE(review): every expected repr in this copy shows an empty operator field
# ("op=,"); the operator token text looks like it may have been lost when this
# file was extracted -- confirm against the parser's real repr() output.
46 | def test_assign():
47 | check("a = 3", "Assign(op=,lhs=a,rhs=3)")
48 |
49 |
50 | def test_index():
51 | check("a[:,i,j]", "Index(arr=a, index=[:, i, j])")
52 |
53 |
54 | def test_index2():
55 | check("z = a[:]", "Assign(op=,lhs=z,rhs=Index(arr=a,index=[:]))")
56 |
57 | def test_index3():
58 | check("g.W[:,:,1]", "Index(arr=Member(op=,obj=g,member=W),index=[:,:,1])")
59 |
60 | def test_literal_list():
61 | check("[[1, 2], [3, 4]]",
62 | "ListLiteral(elems=[ListLiteral(elems=[1, 2]), ListLiteral(elems=[3, 4])])")
63 |
64 |
65 | def test_literal_array():
66 | check("np.array([[1, 2], [3, 4]])",
67 | """
68 | Call(func=Member(op=,obj=np,member=array),
69 | args=[ListLiteral(elems=[ListLiteral(elems=[1,2]),ListLiteral(elems=[3,4])])])
70 | """)
71 |
72 |
73 | def test_method():
74 | check("h = torch.tanh(h)",
75 | "Assign(op=,lhs=h,rhs=Call(func=Member(op=,obj=torch,member=tanh),args=[h]))")
76 |
77 |
78 | def test_method2():
79 | check("np.dot(b,b)",
80 | "Call(func=Member(op=,obj=np,member=dot),args=[b,b])")
81 |
82 |
83 | def test_field():
84 | check("a.b", "Member(op=,obj=a,member=b)")
85 |
86 |
87 | def test_member_func():
88 | check("a.f()", "Call(func=Member(op=,obj=a,member=f),args=[])")
89 |
90 |
91 | def test_field2():
92 | check("a.b.c", "Member(op=,obj=Member(op=,obj=a,member=b),member=c)")
93 |
94 |
95 | def test_field_and_func():
96 | check("a.f().c", "Member(op=,obj=Call(func=Member(op=,obj=a,member=f),args=[]),member=c)")
97 |
98 |
99 | def test_parens():  # parentheses are kept as an explicit SubExpr node
100 | check("(a+b)*c", "BinaryOp(op=,lhs=SubExpr(e=BinaryOp(op=,lhs=a,rhs=b)),rhs=c)")
101 |
102 |
103 | def test_1tuple():
104 | check("(3,)", "TupleLiteral(elems=[3])")
105 |
106 |
107 | def test_2tuple():
108 | check("(3,4)", "TupleLiteral(elems=[3,4])")
109 |
110 |
111 | def test_2tuple_with_trailing_comma():  # str() drops the trailing comma, hence expect_str
112 | check("(3,4,)", "TupleLiteral(elems=[3,4])", expect_str="(3,4)")
113 |
114 |
115 | def test_field_array():
116 | check("a.b[34]", "Index(arr=Member(op=,obj=a,member=b),index=[34])")
117 |
118 |
119 | def test_field_array_func():
120 | check("a.b[34].f()", "Call(func=Member(op=,obj=Index(arr=Member(op=,obj=a,member=b),index=[34]),member=f),args=[])")
121 |
122 |
123 | def test_arith():
124 | check("(1-z)*h + z*h_",
125 | """BinaryOp(op=,
126 | lhs=BinaryOp(op=,
127 | lhs=SubExpr(e=BinaryOp(op=,
128 | lhs=1,
129 | rhs=z)),
130 | rhs=h),
131 | rhs=BinaryOp(op=,lhs=z,rhs=h_))""")
132 |
133 |
134 | def test_chained_op():  # + groups left to right: (a + b) + c
135 | check("a + b + c",
136 | """BinaryOp(op=,
137 | lhs=BinaryOp(op=, lhs=a, rhs=b),
138 | rhs=c)""")
139 |
140 |
141 | def test_matrix_arith():
142 | check("self.Whz@h + Uxz@x + bz",
143 | """
144 | BinaryOp(op=,
145 | lhs=BinaryOp(op=,
146 | lhs=BinaryOp(op=,lhs=Member(op=,obj=self,member=Whz),rhs=h),
147 | rhs=BinaryOp(op=,lhs=Uxz,rhs=x)),
148 | rhs=bz)
149 | """)
150 |
151 | def test_kwarg():  # keyword arguments appear as Assign nodes in the args list
152 | check("torch.relu(torch.rand(size=(2000,)))",
153 | """
154 | Call(func=Member(op=,obj=torch,member=relu),
155 | args=[Call(func=Member(op=,obj=torch,member=rand),
156 | args=[Assign(op=,lhs=size,rhs=TupleLiteral(elems=[2000]))])])""")
--------------------------------------------------------------------------------
/testing/test_tensorflow.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2020 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | import tsensor
25 | import tensorflow as tf
26 |
W = tf.constant([[1, 2], [3, 4]])
b = tf.reshape(tf.constant([[9, 10]]), (2, 1))
x = tf.reshape(tf.constant([[8, 5, 7]]), (3, 1))


def test_addition():
    """clarify() should append a "Cause:" line naming the operator and both
    operand shapes to TensorFlow's own shape-mismatch message."""
    expected = ("Incompatible shapes: [2,1] vs. [3,1] [Op:AddV2]\n"
                "Cause: + on tensor operand b w/shape (2, 1) and operand x w/shape (3, 1)")
    try:
        with tsensor.clarify():
            q = b + x + 3
    except tf.errors.InvalidArgumentError as iae:
        msg = iae.message
    else:
        msg = ""
    assert msg == expected
--------------------------------------------------------------------------------
/testing/test_tree_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2020 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | from tsensor.parsing import *
25 | import tsensor.ast
26 | import sys
27 | import numpy as np
28 |
29 |
def check(s, expected):
    """Parse ``s``, evaluate it in the caller's frame, and assert that the
    result prints the same as ``expected``.

    Comparing ``str`` forms lets one helper cover ints, lists and numpy
    arrays. For numpy results pass a numpy array as ``expected`` so numpy
    itself produces the (column-aligned) formatting.
    """
    caller = sys._getframe(1)  # evaluate names in the calling test's frame
    t = PyExprParser(s).parse()
    result = t.eval(caller)
    assert str(result) == str(expected)


def test_int():
    check("34", 34)

def test_assign():
    check("a = 34", 34)

def test_var():
    a = 34
    check("a", 34)

def test_member_var():
    class A:
        def __init__(self):
            self.a = 34
    x = A()
    check("x.a", 34)

def test_member_func():
    class A:
        def f(self, a):
            return a + 4
    x = A()
    check("x.f(30)", 34)

def test_index():
    a = [1, 2, 3]
    check("a[2]", 3)

def test_add():
    a, b, c = 3, 4, 5
    check("a+b+c", 12)

def test_add_mul():
    a, b, c = 3, 4, 5
    check("a+b*c", 23)

def test_parens():
    a, b, c = 3, 4, 5
    check("(a+b)*c", 35)

def test_list_literal():
    a = [[1, 2, 3], [4, 5, 6]]
    check("a", """[[1, 2, 3], [4, 5, 6]]""")


# numpy right-aligns each column to the widest entry (e.g. "[[ 2  4  6]\n
# [ 8 10 12]]"), so hand-typed expectation strings break on any spacing
# difference; build the expected value as an array instead.
def test_np_literal():
    a = np.array([[1, 2, 3], [4, 5, 6]])
    check("a*2", np.array([[2, 4, 6], [8, 10, 12]]))


def test_np_add():
    a = np.array([[1, 2, 3], [4, 5, 6]])
    check("a+a", np.array([[2, 4, 6], [8, 10, 12]]))


def test_np_add2():
    a = np.array([[1, 2, 3], [4, 5, 6]])
    check("a+a+a", np.array([[3, 6, 9], [12, 15, 18]]))
103 |
--------------------------------------------------------------------------------
/testing/testexc.py:
--------------------------------------------------------------------------------
foo = None

try:
    state = "bar"
    try:
        foo.append(state)  # foo is None, so this raises AttributeError
    except Exception as e:
        # Prepend context to the exception's args, then re-raise the same object.
        e.args = ("Appending '"+state+"' failed", *e.args)
        raise

    print(foo[0])  # would raise too
except Exception as e:
    # Outer handler decorates the already-decorated exception once more.
    e.message = "foo"
    e.args = ("print(foo) failed: " + str(foo), *e.args)
    raise
--------------------------------------------------------------------------------
/testing/viz_testing.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import numpy as np
4 | import tensorflow as tf
5 | import graphviz
6 | import tempfile
7 | import matplotlib.patches as patches
8 | import matplotlib.pyplot as plt
9 | import matplotlib.font_manager as fm
10 |
11 |
12 | # print('\n'.join(str(f) for f in fm.fontManager.ttflist))
13 | import tsensor
14 | # from tsensor.viz import pyviz, astviz
15 |
def foo(use_frame=False):
    """Render the AST of a sample tensor expression with ``tsensor.astviz``
    and open the resulting graph viewer.

    :param use_frame: if True, create the sample tensors named in the
        expression and pass this function's frame to ``astviz`` (presumably so
        it can resolve those names -- confirm in tsensor.viz). If False (the
        default, matching the previous hard-coded ``frame = None``), no frame
        is passed and no tensors are allocated.
    """
    frame = None
    if use_frame:
        # Allocate only when the frame is actually handed over: W alone holds
        # 2000*2000*10*8 = 320M elements (~1.3 GB assuming torch's default
        # float32), which was wasted when the frame was discarded.
        # W = torch.rand(size=(2000, 2000))
        W = torch.rand(size=(2000, 2000, 10, 8))
        b = torch.rand(size=(2000, 1))
        h = torch.rand(size=(1_000_000,))
        x = torch.rand(size=(2000, 1))
        frame = sys._getframe()
    g = tsensor.astviz("b = W[:,:,0,0]@b + (h+3).dot(h) + torch.abs(torch.tensor(34))",
                       frame)
    g.view()
29 |
30 | #foo()
31 |
class GRU:
    """Container of sample tensors for exercising tsensor visualisations;
    despite the name it holds no recurrent logic."""

    def __init__(self):
        self.W = torch.rand(size=(2, 20, 2000, 10))
        self.b = torch.rand(size=(20, 1))
        # self.x = torch.tensor([4, 5]).reshape(2, 1)
        self.h = torch.rand(size=(1_000_000,))
        self.a = 3
        # Report W's full shape, then its shape with axis 2 fixed at index 1.
        for shape in (self.W.shape, self.W[:, :, 1].shape):
            print(shape)

    def get(self):
        """Return the fixed 2x2 integer tensor [[1, 2], [3, 4]]."""
        return torch.arange(1, 5).reshape(2, 2)
44 |
# W = torch.tensor([[1, 2], [3, 4]])
b = torch.rand(2000, 1)
h = torch.rand(1_000_000, 2)
x = torch.rand(1_000_000, 2)
a = 3

# NOTE(review): this rebinds the module-level name `foo`, shadowing the foo()
# function defined above; the commented-out pyviz("torch.relu(foo)")
# experiment below expects `foo` to be this tensor.
foo = torch.rand(2000)
torch.relu(foo)

g = GRU()

# clarify() re-parses the offending source line when the exception escapes,
# so the failing expression stays on a single line inside the `with`.
with tsensor.clarify():
    tf.constant([1,2]) @ tf.constant([1,3])
58 |
59 |
60 | # code = "b = g.W[0,:,:,1]@b+torch.zeros(200,1)+(h+3).dot(h)"
61 | # code = "torch.relu(foo)"
62 | # code = "np.dot(b,b)"
63 | # g = tsensor.pyviz(code, fontname='Courier New', fontsize=16, dimfontsize=9,
64 | # char_sep_scale=1.8, hush_errors=False)
65 | # plt.tight_layout()
66 | # plt.savefig("/tmp/t.svg", dpi=200, bbox_inches='tight', pad_inches=0)
67 |
68 | # W = torch.tensor([[1, 2], [3, 4]])
69 | # x = torch.tensor([4, 5]).reshape(2, 1)
70 | # with tsensor.explain():
71 | # b = torch.rand(size=(2000,))
72 | # torch.relu(b)
73 |
74 |
75 | # g = GRU()
76 | #
77 | # g1 = tsensor.astviz("b = g.W@b + torch.eye(3,3)")
78 | # g1.view()
79 | # g1 = tsensor.pyviz("b = g.W@b")
80 | # g1.view()
81 | # g2 = tsensor.astviz("b = g.W@b + g.h.dot(g.h) + torch.abs(torch.tensor(34))")
82 |
--------------------------------------------------------------------------------
/tsensor/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2020 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | __all__ = ["ast", "parsing", "viz", "analysis", "version"]
25 |
26 | # These classes/functions are the primary user interface so import them directly
27 | import tsensor.ast
28 | import tsensor.parsing
29 | import tsensor.viz
30 | import tsensor.analysis
31 | from tsensor.analysis import explain, clarify, eval
32 | from tsensor.parsing import parse
33 | from tsensor.viz import pyviz, astviz
34 |
35 | from .version import __version__
--------------------------------------------------------------------------------
/tsensor/analysis.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2020 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | import os
25 | import sys
26 | import traceback
27 | import torch
28 | import inspect
29 |
30 | import matplotlib.pyplot as plt
31 |
32 | import tsensor
33 |
34 | class clarify:
35 | def __init__(self,
36 | fontname='Consolas', fontsize=13,
37 | dimfontname='Arial', dimfontsize=9, matrixcolor="#cfe2d4",
38 | vectorcolor="#fefecd", char_sep_scale=1.8, fontcolor='#444443',
39 | underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
40 | show:(None,'viz')='viz'):
41 | """
42 | Augment tensor-related exceptions generated from numpy, pytorch, and tensorflow.
43 | Also display a visual representation of the offending Python line that
44 | shows the shape of tensors referenced by the code. All you have to do is wrap
45 | the outermost level of your code and clarify() will activate upon exception.
46 |
47 | Visualizations pop up in a separate window unless running from a notebook,
48 | in which case the visualization appears as part of the cell execution output.
49 |
50 | There is no runtime overhead associated with clarify() unless an exception occurs.
51 |
52 | The offending code is executed a second time, to identify which sub expressions
53 | are to blame. This implies that code with side effects could conceivably cause
54 | a problem, but since an exception has been generated, results are suspicious
55 | anyway.
56 |
57 | Example:
58 |
59 | import numpy as np
60 | import tsensor
61 |
62 | b = np.array([9, 10]).reshape(2, 1)
63 | with tsensor.clarify():
64 | np.dot(b,b) # tensor code or call to a function with tensor code
65 |
66 | See examples.ipynb for more examples.
67 |
68 | :param fontname: The name of the font used to display Python code
69 | :param fontsize: The font size used to display Python code; default is 13.
70 | Also use this to increase the size of the generated figure;
71 | larger font size means larger image.
72 | :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes
73 | :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes
74 | :param matrixcolor: The color of matrix boxes
75 | :param vectorcolor: The color of vector boxes; only for tensors whose shape is (n,).
76 | :param char_sep_scale: It is notoriously difficult to discover how wide and tall
77 | text is when plotted in matplotlib. In fact there's probably,
78 | no hope to discover this information accurately in all cases.
79 | Certainly, I gave up after spending huge effort. We have a
80 | situation here where the font should be constant width, so
81 | we can just use a simple scaler times the font size to get
82 | a reasonable approximation to the width and height of a
83 | character box; the default of 1.8 seems to work reasonably
84 | well for a wide range of fonts, but you might have to tweak it
85 | when you change the font size.
86 | :param fontcolor: The color of the Python code.
87 | :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey
88 | :param ignored_color: The de-highlighted color for deemphasizing code not involved in an erroneous sub expression
89 | :param error_op_color: The color to use for characters associated with the erroneous operator
90 | :param ax: If not none, this is the matplotlib drawing region in which to draw the visualization
91 | :param dpi: This library tries to generate SVG files, which are vector graphics not
92 | 2D arrays of pixels like PNG files. However, it needs to know how to
93 | compute the exact figure size to remove padding around the visualization.
94 | Matplotlib uses inches for its figure size and so we must convert
95 | from pixels or data units to inches, which means we have to know what the
96 | dots per inch, dpi, is for the image.
97 | :param hush_errors: Normally, error messages from true syntax errors but also
98 | unhandled code caught by my parser are ignored. Turn this off
99 | to see what the error messages are coming from my parser.
100 | :param show: Show visualization upon tensor error if show='viz'.
101 | """
102 | self.show = show
103 | self.fontname, self.fontsize, self.dimfontname, self.dimfontsize, \
104 | self.matrixcolor, self.vectorcolor, self.char_sep_scale,\
105 | self.fontcolor, self.underline_color, self.ignored_color, self.error_op_color = \
106 | fontname, fontsize, dimfontname, dimfontsize, \
107 | matrixcolor, vectorcolor, char_sep_scale, \
108 | fontcolor, underline_color, ignored_color, error_op_color
109 |
110 | def __enter__(self):
111 | self.frame = sys._getframe().f_back # where do we start tracking
112 | return self
113 |
114 | def __exit__(self, exc_type, exc_value, exc_traceback):
115 | if exc_type is not None and is_interesting_exception(exc_value):
116 | # print("exception:", exc_value, exc_traceback)
117 | # traceback.print_tb(exc_traceback, limit=5, file=sys.stdout)
118 | exc_frame = deepest_frame(exc_traceback)
119 | module, name, filename, line, code = info(exc_frame)
120 | # print('info', module, name, filename, line, code)
121 | if code is not None:
122 | view = tsensor.viz.pyviz(code, exc_frame,
123 | self.fontname, self.fontsize, self.dimfontname,
124 | self.dimfontsize, self.matrixcolor, self.vectorcolor,
125 | self.char_sep_scale, self.fontcolor,
126 | self.underline_color, self.ignored_color,
127 | self.error_op_color)
128 | if self.show=='viz':
129 | view.show()
130 | augment_exception(exc_value, view.offending_expr)
131 |
132 |
class explain:
    def __init__(self,
                 fontname='Consolas', fontsize=13,
                 dimfontname='Arial', dimfontsize=9, matrixcolor="#cfe2d4",
                 vectorcolor="#fefecd", char_sep_scale=1.8, fontcolor='#444443',
                 underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
                 savefig=None):
        """
        As the Python virtual machine executes lines of code, generate a
        visualization for tensor-related expressions from numpy, pytorch,
        and tensorflow. The shapes of tensors referenced by the code are displayed.

        Visualizations pop up in a separate window unless running from a notebook,
        in which case the visualization appears as part of the cell execution output.

        There is heavy runtime overhead associated with explain() as every line
        is executed twice: once by explain() and then another time by the interpreter
        as part of normal execution.

        Expressions with side effects can easily generate incorrect results. Due to
        this and the overhead, you should limit the use of this to code you're trying
        to debug. Assignments are not evaluated by explain so code `x = ...` causes
        an assignment to x just once, during normal execution. This explainer
        knows the value of x and will display it but does not assign to it.

        Upon exception, execution will stop as usual but, like clarify(), explain()
        will augment the exception to indicate the offending sub expression. Further,
        the visualization will deemphasize code not associated with the offending
        sub expression. The sizes of relevant tensor values are still visualized.

        Example:

        import numpy as np
        import tsensor

        b = np.array([9, 10]).reshape(2, 1)
        with tsensor.explain():
            b + b # tensor code or call to a function with tensor code

        See examples.ipynb for more examples.

        :param fontname: The name of the font used to display Python code
        :param fontsize: The font size used to display Python code; default is 13.
                         Also use this to increase the size of the generated figure;
                         larger font size means larger image.
        :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes
        :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes
        :param matrixcolor: The color of matrix boxes
        :param vectorcolor: The color of vector boxes; only for tensors whose shape is (n,).
        :param char_sep_scale: It is notoriously difficult to discover how wide and tall
                               text is when plotted in matplotlib. In fact there's probably
                               no hope to discover this information accurately in all cases.
                               Certainly, I gave up after spending huge effort. We have a
                               situation here where the font should be constant width, so
                               we can just use a simple scaler times the font size to get
                               a reasonable approximation to the width and height of a
                               character box; the default of 1.8 seems to work reasonably
                               well for a wide range of fonts, but you might have to tweak it
                               when you change the font size.
        :param fontcolor: The color of the Python code.
        :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey
        :param ignored_color: The de-highlighted color for deemphasizing code not involved in an erroneous sub expression
        :param error_op_color: The color to use for characters associated with the erroneous operator
        :param savefig: A string indicating where to save the visualization; don't save
                        a file if None.
        """
        self.savefig = savefig
        self.fontname, self.fontsize, self.dimfontname, self.dimfontsize, \
        self.matrixcolor, self.vectorcolor, self.char_sep_scale,\
        self.fontcolor, self.underline_color, self.ignored_color, self.error_op_color = \
            fontname, fontsize, dimfontname, dimfontsize, \
            matrixcolor, vectorcolor, char_sep_scale, \
            fontcolor, underline_color, ignored_color, error_op_color

    def _viz_args(self):
        """
        Styling options in the positional order tsensor.viz.pyviz() takes them
        after (code, frame) -- the same order clarify.__exit__ passes them.
        """
        return (self.fontname, self.fontsize, self.dimfontname,
                self.dimfontsize, self.matrixcolor, self.vectorcolor,
                self.char_sep_scale, self.fontcolor,
                self.underline_color, self.ignored_color,
                self.error_op_color)

    def __enter__(self, format="svg"):
        """Install a line tracer that visualizes each tensor statement of the with-block."""
        self.tracer = ExplainTensorTracer(self.savefig, format=format,
                                          viz_args=self._viz_args())
        sys.settrace(self.tracer.listener)
        frame = sys._getframe()
        prev = frame.f_back  # get block wrapped in "with"
        # settrace() only affects frames entered later; the with-block's frame
        # is already running, so hook its per-frame trace function directly.
        prev.f_trace = self.tracer.listener
        return self.tracer

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        Remove the tracer and, on a tensor-related exception, append the
        offending subexpression to the exception message. Returns None, so
        exceptions propagate.
        """
        sys.settrace(None)
        # At this point we have already tried to visualize the statement.
        # If there was no error, the visualization will look normal
        # but a matrix operation error will show the erroneous operator highlighted.
        # That was artificial execution of the code. Now the VM has executed
        # the statement for real and has found the same exception. Make sure to
        # augment the message with causal information.
        if exc_type is not None and is_interesting_exception(exc_value):
            exc_frame = deepest_frame(exc_traceback)
            module, name, filename, line, code = info(exc_frame)
            if code is not None:
                # We've already displayed picture so just augment message
                root, tokens = tsensor.parsing.parse(code)
                if root is not None:  # Could be syntax error in statement or code I can't handle
                    offending_expr = None
                    try:
                        root.eval(exc_frame)
                    except tsensor.ast.IncrEvalTrap as e:
                        offending_expr = e.offending_expr
                    # If re-evaluation did not fail there is nothing to point at;
                    # augmenting with None would raise here and mask the real error.
                    if offending_expr is not None:
                        augment_exception(exc_value, offending_expr)


class ExplainTensorTracer:
    """
    sys.settrace() listener used by explain(): for every traced line it parses
    the statement and, when tsensor's parser accepts it, visualizes it.
    """
    def __init__(self, savefig:str=None, format="svg", modules=None, filenames=None,
                 viz_args=()):
        """
        :param savefig: file name prefix for saved SVGs; None means show each view instead.
        :param format: stored for reference; viz_statement() always writes .svg names.
        :param modules: only trace frames whose module __name__ is listed; default ['__main__'].
        :param filenames: if non-empty, only trace frames from these source files.
        :param viz_args: extra positional styling arguments forwarded to tsensor.viz.pyviz().
        """
        self.savefig = savefig
        self.format = format
        # Fresh lists per tracer instead of shared mutable defaults.
        self.modules = ['__main__'] if modules is None else modules
        self.filenames = [] if filenames is None else filenames
        self.viz_args = tuple(viz_args)
        self.exceptions = set()
        self.linecount = 0
        self.views = []

    def listener(self, frame, event, arg):
        """Trace function: forward 'line' events from the watched modules/files."""
        module = frame.f_globals['__name__']
        if module not in self.modules:
            return

        finfo = inspect.getframeinfo(frame)
        filename, line = finfo.filename, finfo.lineno
        name = finfo.function
        if len(self.filenames) > 0 and filename not in self.filenames:
            return

        if event == 'line':
            self.line_listener(module, name, filename, line, finfo, frame)

        return None

    def line_listener(self, module, name, filename, line, info, frame):
        """Visualize the current source line if tsensor can parse it."""
        import tokenize  # local import: only needed for tokenize.TokenError

        if not info.code_context:  # no source available (e.g. exec'd code)
            return
        code = info.code_context[0].strip()
        if code.startswith("sys.settrace(None)"):
            return
        self.linecount += 1
        try:
            t = tsensor.parsing.PyExprParser(code).parse()
        except (SyntaxError, tokenize.TokenError):
            # One physical line of a multi-line statement cannot be tokenized on
            # its own; raising here would escape the trace function into user code.
            return
        if t is not None:
            ExplainTensorTracer.viz_statement(self, code, frame)

    @staticmethod
    def viz_statement(tracer, code, frame):
        """
        Render code in the context of frame, record the view on tracer, and either
        save it as f"{tracer.savefig}-{tracer.linecount}.svg" or show it.
        """
        view = tsensor.viz.pyviz(code, frame, *tracer.viz_args)
        tracer.views.append(view)
        if tracer.savefig is not None:
            svgfilename = f"{tracer.savefig}-{tracer.linecount}.svg"
            view.savefig(svgfilename)
            view.filename = svgfilename
            plt.close()
        else:
            view.show()
        return view
303 |
304 |
def eval(statement:str, frame=None) -> (tsensor.ast.ParseTreeNode, object):
    """
    Parse statement and return an ast in the context of execution frame or, if None,
    the invoking function's frame. Set the value field of all ast nodes.
    Overall result is in root.value.
    :param statement: A string representing the line of Python code to visualize within an execution frame.
    :param frame: The execution frame in which to evaluate the statement. If None,
                  use the execution frame of the invoking function
    :return: (root, root.value) where root is an abstract parse tree representing
             the statement; nodes are ParseTreeNode subclasses.
    :raises SyntaxError: if tsensor's parser cannot handle statement (e.g. it
             starts with a keyword or is outside the supported expression subset).
    """
    p = tsensor.parsing.PyExprParser(statement)
    root = p.parse()
    if root is None:
        # PyExprParser.parse() returns None for keyword statements and for
        # syntax it cannot handle; fail clearly instead of on None.eval().
        raise SyntaxError(f"tsensor cannot parse statement: {statement}")
    if frame is None:  # use frame of caller
        frame = sys._getframe().f_back
    root.eval(frame)
    return root, root.value
322 |
323 |
def augment_exception(exc_value, subexpr):
    """
    Append the explanation produced by subexpr.clarify() to exc_value's message,
    modifying the exception in place.

    Does nothing when subexpr is None or its clarify() returns None.
    Exceptions carrying a `_message` attribute (TensorFlow-style) get that
    attribute extended. Otherwise args is replaced by a single message string:
    the old first argument (stringified) followed by the explanation, or just
    the explanation when args was empty.
    """
    if subexpr is None:
        return
    explanation = subexpr.clarify()
    if explanation is None:
        return
    if hasattr(exc_value, "_message"):
        # Read the same attribute whose presence we just checked.
        exc_value._message = exc_value._message + "\n" + explanation
    elif exc_value.args:
        # Collapse to one argument so str(exc) shows a clean multi-line message.
        exc_value.args = [f"{exc_value.args[0]}\n{explanation}"]
    else:
        exc_value.args = [explanation]
336 |
337 |
def is_interesting_exception(e):
    """
    Return True if exception e looks like a tensor shape/dimension error worth
    explaining: any exception whose class lives in a tensorflow module, or any
    exception whose message mentions one of the sentinel words below.
    """
    if e.__class__.__module__.startswith("tensorflow"):
        return True
    sentinels = {'matmul', 'THTensorMath', 'tensor', 'tensors', 'dimension',
                 'not aligned', 'size mismatch', 'shape', 'shapes', 'matrix'}
    if len(e.args) > 0 and isinstance(e.args[0], str):
        msg = e.args[0]
    else:
        # No string first argument (e.g. ValueError() or KeyError(5)): use a
        # .message attribute if one exists, otherwise the exception's str().
        msg = getattr(e, "message", None)
        if not isinstance(msg, str):
            msg = str(e)
    return any(s in msg for s in sentinels)
349 |
350 |
def deepest_frame(exc_traceback):
    """
    Walk exc_traceback from the outermost entry inward and return the frame of
    the last entry before execution entered numpy, torch or tensorflow. That is
    where the user's code asked the tensor library to do something invalid.

    Library code is recognized by a filename containing "site-packages/<pkg>"
    or "dist-packages/<pkg>", or numpy's "<__array_function__" pseudo-filename.
    If no library frame is found, the innermost frame is returned.
    """
    libraries = ('numpy', 'torch', 'tensorflow')
    markers = [os.path.join(install_dir, lib)
               for install_dir in ('site-packages', 'dist-packages')
               for lib in libraries]
    markers.append('<__array_function__')  # numpy seems to not have real filename

    last_user_entry = exc_traceback
    entry = exc_traceback
    while entry is not None:
        source_file = entry.tb_frame.f_code.co_filename
        if any(marker in source_file for marker in markers):
            break
        last_user_entry = entry
        entry = entry.tb_next
    return last_user_entry.tb_frame
374 |
375 |
def info(frame):
    """
    Summarize where an execution frame is.

    :return: (module, name, filename, line, code) where module is the frame's
             module __name__ (None if its globals lack one), name is the function
             name, filename/line locate the current line, and code is that source
             line stripped of surrounding whitespace, or None when unavailable.
    """
    # Frame objects have no __name__ attribute; the module name lives in globals.
    module = frame.f_globals.get('__name__')
    finfo = inspect.getframeinfo(frame)
    code = finfo.code_context[0].strip() if finfo.code_context else None
    return module, finfo.function, finfo.filename, finfo.lineno, code
389 |
390 |
def smallest_matrix_subexpr(t):
    """
    Find the smallest subexpressions of parse tree t that evaluate to a tensor.

    A node qualifies when istensor(node.value) holds and no descendant's
    subtree contains a tensor value. These are the nodes the visualization
    draws boxes for. Nodes are returned in the order the bottom-up traversal
    reaches them (children before parents). The tree itself is not modified.
    """
    found = []
    _smallest_matrix_subexpr(t, found)
    return found
406 |
def _smallest_matrix_subexpr(t, nodes) -> bool:
    """
    Recursive worker for smallest_matrix_subexpr(): append to nodes every node
    in t that evaluates to a tensor while no descendant does, and return True
    when t itself or anything below it evaluates to a tensor.
    """
    if t is None:  # prevent buggy code from causing us to fail
        return False
    children = t.kids
    if len(children) == 0:  # leaf node
        leaf_is_tensor = istensor(t.value)
        if leaf_is_tensor:
            nodes.append(t)
        return leaf_is_tensor
    # Visit every child (no short-circuit) so each subtree reports its own nodes.
    child_has_tensor = [_smallest_matrix_subexpr(kid, nodes) for kid in children]
    self_is_tensor = istensor(t.value)
    if self_is_tensor and not any(child_has_tensor):
        nodes.append(t)  # smallest subexpression evaluating to a tensor
    return self_is_tensor or any(child_has_tensor)
423 |
424 |
def istensor(x):
    """True when _shape() reports a (non-None) shape for x."""
    shape = _shape(x)
    return shape is not None
427 |
428 |
429 | def _shape(v):
430 | # do we have a shape and it answers len()? Should get stuff right.
431 | if hasattr(v, "shape") and hasattr(v.shape,"__len__"):
432 | if isinstance(v.shape, torch.Size):
433 | if len(v.shape)==0:
434 | return None
435 | return list(v.shape)
436 | return v.shape
437 | return None
438 |
--------------------------------------------------------------------------------
/tsensor/ast.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2020 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | import tsensor
25 |
26 | # Parse tree definitions
27 | # I found package ast in python3 lib after I built this. whoops. No biggie.
28 | # This tree structure is easier to visit for my purposes here. Also lets me
29 | # control the kinds of statements I process.
30 |
31 | class ParseTreeNode:
32 | def __init__(self):
33 | self.value = None # used during evaluation
34 | # self.matrix_below = False # indicates decendant has non-scalar value UNUSED
35 | self.start = None # start token
36 | self.stop = None # end token
37 | def eval(self, frame):
38 | """
39 | Evaluate the expression represented by this (sub)tree in context of frame.
40 | Try any exception found while evaluating and remember which operation that
41 | was in this tree
42 | """
43 | try:
44 | self.value = eval(str(self), frame.f_locals, frame.f_globals)
45 | except BaseException as e:
46 | raise IncrEvalTrap(self) from e
47 | # print(self, "=>", self.value)
48 | return self.value
49 | @property
50 | def optokens(self): # the associated token if atom or representative token if operation
51 | return None
52 | @property
53 | def kids(self):
54 | return []
55 | def clarify(self):
56 | return None
57 | def __str__(self):
58 | pass
59 | def __repr__(self):
60 | fields = self.__dict__.copy()
61 | kill = ['start', 'stop', 'lbrack', 'lparen']
62 | for name in kill:
63 | if name in fields: del fields[name]
64 | args = [v+'='+fields[v].__repr__() for v in fields if v!='value' or fields['value'] is not None]
65 | args = ','.join(args)
66 | return f"{self.__class__.__name__}({args})"
67 |
class Assign(ParseTreeNode):
    """Assignment `lhs op rhs`; only the right-hand side is actually evaluated."""
    def __init__(self, op, lhs, rhs, start, stop):
        super().__init__()
        self.op = op
        self.lhs = lhs
        self.rhs = rhs
        self.start = start
        self.stop = stop

    def eval(self, frame):
        result = self.rhs.eval(frame)
        self.value = result
        # Evaluating lhs would perform the assignment a second time, so just
        # copy the right-hand value onto it for display purposes.
        self.lhs.value = result
        return result

    @property
    def optokens(self):
        return [self.op]

    @property
    def kids(self):
        return [self.lhs, self.rhs]

    def __str__(self):
        return '='.join((str(self.lhs), str(self.rhs)))
86 |
class Call(ParseTreeNode):
    """Function or method call `func(args...)`."""
    def __init__(self, func, lparen, args, start, stop):
        super().__init__()
        self.func = func
        self.lparen = lparen  # the '(' token
        self.args = args      # list of argument subtrees
        self.start, self.stop = start, stop

    def eval(self, frame):
        # Give the callee and every argument their own values first,
        # then evaluate the call as a whole.
        self.func.eval(frame)
        for arg in self.args:
            arg.eval(frame)
        return super().eval(frame)

    def clarify(self):
        """Explain a failed call by listing the arguments that have tensor shapes."""
        described = []
        for arg in self.args:
            shape = tsensor.analysis._shape(arg.value)
            if shape:
                described.append(f"arg {arg} w/shape {shape}")
        if not described:
            return f"Cause: {self}"
        return f"Cause: {self} tensor " + ', '.join(described)

    @property
    def optokens(self):
        # Highlight the callee name when it is a plain or dotted name; for a
        # complicated callee like a[i](args) only the parentheses are used.
        if isinstance(self.func, Member):
            callee = self.func.member
        elif isinstance(self.func, Atom):
            callee = self.func
        else:
            callee = None
        parens = [self.lparen, self.stop]
        return [callee.token] + parens if callee else parens

    @property
    def kids(self):
        return [self.func] + self.args

    def __str__(self):
        rendered = ','.join(str(arg) for arg in self.args)
        return f"{self.func}({rendered})"
124 |
class Index(ParseTreeNode):
    """Subscript expression `arr[i, j, ...]`."""
    def __init__(self, arr, lbrack, index, start, stop):
        super().__init__()
        self.arr = arr
        self.lbrack = lbrack  # the '[' token
        self.index = index    # list of subscript subtrees
        self.start, self.stop = start, stop

    def eval(self, frame):
        # Evaluate the array and each subscript so every subtree records its
        # value, then evaluate the whole subscript expression.
        self.arr.eval(frame)
        for sub in self.index:
            sub.eval(frame)
        return super().eval(frame)

    @property
    def optokens(self):
        # The brackets stand in for the indexing operator.
        return [self.lbrack, self.stop]

    @property
    def kids(self):
        return [self.arr] + self.index

    def __str__(self):
        subscripts = ','.join(str(sub) for sub in self.index)
        return f"{self.arr}[{subscripts}]"
154 |
class Member(ParseTreeNode):
    """Attribute access `obj.member`."""
    def __init__(self, op, obj, member, start, stop):
        super().__init__()
        self.op = op          # always DOT
        self.obj = obj
        self.member = member  # Atom holding the attribute name
        self.start = start
        self.stop = stop

    def eval(self, frame):
        # member is only a name looked up on obj, so it is not evaluated alone.
        self.obj.eval(frame)
        return super().eval(frame)

    @property
    def optokens(self):  # the associated token if atom or representative token if operation
        return [self.op]

    @property
    def kids(self):
        return [self.obj, self.member]

    def __str__(self):
        return "{}.{}".format(self.obj, self.member)
174 |
class BinaryOp(ParseTreeNode):
    """Binary operation `lhs op rhs` (e.g. +, -, *, /, @, %)."""
    def __init__(self, op, lhs, rhs, start, stop):
        super().__init__()
        self.op, self.lhs, self.rhs = op, lhs, rhs
        self.start, self.stop = start, stop

    def eval(self, frame):
        # Evaluate both operands first so each records its value (and shape).
        self.lhs.eval(frame)
        self.rhs.eval(frame)
        return super().eval(frame)

    def clarify(self):
        """
        Explain a failed operation by listing the operands that have tensor shapes.
        If neither operand is a tensor, name the whole expression instead of
        emitting a dangling "on tensor " with no operands.
        """
        opnd_msgs = []
        for opnd in (self.lhs, self.rhs):
            shape = tsensor.analysis._shape(opnd.value)
            if shape:
                opnd_msgs.append(f"operand {opnd} w/shape {shape}")
        if not opnd_msgs:
            return f"Cause: {self}"
        return f"Cause: {self.op} on tensor " + ' and '.join(opnd_msgs)

    @property
    def optokens(self):  # the associated token if atom or representative token if operation
        return [self.op]

    @property
    def kids(self):
        return [self.lhs, self.rhs]

    def __str__(self):
        return f"{self.lhs}{self.op}{self.rhs}"
201 |
class UnaryOp(ParseTreeNode):
    """Prefix operator applied to one operand, e.g. `~x`."""
    def __init__(self, op, opnd, start, stop):
        super().__init__()
        self.op = op
        self.opnd = opnd
        self.start = start
        self.stop = stop

    def eval(self, frame):
        self.opnd.eval(frame)  # record the operand's own value first
        return super().eval(frame)

    @property
    def optokens(self):
        return [self.op]

    @property
    def kids(self):
        return [self.opnd]

    def __str__(self):
        return str(self.op) + str(self.opnd)
219 |
class ListLiteral(ParseTreeNode):
    """List display `[e1, e2, ...]`."""
    def __init__(self, elems, start, stop):
        super().__init__()
        self.elems = elems
        self.start = start
        self.stop = stop

    def eval(self, frame):
        for elem in self.elems:
            elem.eval(frame)
        return super().eval(frame)

    @property
    def kids(self):
        return self.elems

    def __str__(self):
        if not isinstance(self.elems, list):
            # elems is not a list of subtrees; render it verbatim inside brackets
            return f"[{self.elems}]"
        return "[" + ','.join(map(str, self.elems)) + "]"
238 |
class TupleLiteral(ParseTreeNode):
    """Tuple display `(e1, e2, ...)`; a one-element tuple keeps its trailing comma."""
    def __init__(self, elems, start, stop):
        super().__init__()
        self.elems = elems
        self.start = start
        self.stop = stop

    def eval(self, frame):
        for elem in self.elems:
            elem.eval(frame)
        return super().eval(frame)

    @property
    def kids(self):
        return self.elems

    def __str__(self):
        parts = [str(e) for e in self.elems]
        if len(parts) == 1:
            return f"({parts[0]},)"
        return "(" + ','.join(parts) + ")"
256 |
class SubExpr(ParseTreeNode):
    """Parenthesized expression; parens are kept so str() preserves precedence."""
    def __init__(self, e, start, stop):
        super().__init__()
        self.e = e
        self.start = start  # presumably the '(' token -- set by the parser
        self.stop = stop    # presumably the ')' token

    def eval(self, frame):
        # The value of (e) is just the value of e; don't re-evaluate.
        inner = self.e.eval(frame)
        self.value = inner
        return inner

    @property
    def optokens(self):
        return [self.start, self.stop]

    @property
    def kids(self):
        return [self.e]

    def __str__(self):
        return "(" + str(self.e) + ")"
274 |
class Atom(ParseTreeNode):
    """Leaf node wrapping a single token: a name, number, string or ':'."""
    def __init__(self, token):
        super().__init__()
        self.token = token
        self.start = token
        self.stop = token

    def eval(self, frame):
        if self.token.type != tsensor.parsing.COLON:
            return super().eval(frame)
        return ':'  # a bare ':' cannot be evaluated on its own; fake a value

    def __repr__(self):
        return self.token.value

    def __str__(self):
        return self.token.value
289 |
290 |
def postorder(t):
    """Return the nodes of tree t in postorder (children before parent); [] for None."""
    ordered = []
    _postorder(t, ordered)
    return ordered
295 |
296 |
297 | def _postorder(t, nodes):
298 | if t is None:
299 | return
300 | for sub in t.kids:
301 | _postorder(sub, nodes)
302 | nodes.append(t)
303 |
304 |
def leaves(t):
    """Return the leaf nodes (nodes without kids) of tree t, left to right."""
    found = []
    _leaves(t, found)
    return found
309 |
310 |
311 | def _leaves(t, nodes):
312 | if t is None:
313 | return
314 | if len(t.kids) == 0:
315 | nodes.append(t)
316 | return
317 | for sub in t.kids:
318 | _leaves(sub, nodes)
319 |
320 |
def walk(t, pre=lambda x: None, post=lambda x: None):
    """
    Depth-first traversal of tree t: call pre(node) on entering a node and
    post(node) after all of its kids have been walked. None nodes are skipped.
    Implemented with an explicit stack instead of recursion.
    """
    pending = [(t, False)]  # (node, kids already scheduled?)
    while pending:
        node, finished = pending.pop()
        if node is None:
            continue
        if finished:
            post(node)
            continue
        pre(node)
        pending.append((node, True))
        # Push kids in reverse so the leftmost kid is visited first.
        for child in reversed(node.kids):
            pending.append((child, False))
328 |
329 |
class IncrEvalTrap(BaseException):
    """
    Raised by ParseTreeNode.eval() while re-evaluating a Python line that threw
    an exception. It carries the parse-tree node whose evaluation failed.
    """
    def __init__(self, offending_expr):
        super().__init__(offending_expr)
        # The subtree of the statement that raised during evaluation.
        self.offending_expr = offending_expr
337 |
--------------------------------------------------------------------------------
/tsensor/parsing.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2020 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | from io import BytesIO
25 | import token
26 | import keyword
27 | from tokenize import tokenize, \
28 | NUMBER, STRING, NAME, OP, ENDMARKER, LPAR, LSQB, RPAR, RSQB, COMMA, COLON,\
29 | PLUS, MINUS, STAR, SLASH, AT, PERCENT, TILDE, DOT,\
30 | NOTEQUAL, PERCENTEQUAL, AMPEREQUAL, DOUBLESTAREQUAL, STAREQUAL, PLUSEQUAL,\
31 | MINEQUAL, DOUBLESLASHEQUAL, SLASHEQUAL, LEFTSHIFTEQUAL,\
32 | LESSEQUAL, EQUAL, EQEQUAL, GREATEREQUAL, RIGHTSHIFTEQUAL, ATEQUAL,\
33 | CIRCUMFLEXEQUAL, VBAREQUAL
34 |
35 | import tsensor.ast
36 |
37 |
# Operator token types (tokenize exact types) used by PyExprParser's rules:
# ADDOP by addexpr, MULOP by multexpr, ASSIGNOP by assignment_or_expr,
# UNARYOP by unaryexpr.
ADDOP = {PLUS, MINUS}
MULOP = {STAR, SLASH, AT, PERCENT}
# '=' and the augmented-assignment operators only. Comparison operators
# (==, !=, <=, >=) must not appear here: Assign evaluates only its right-hand
# side, so `a == b` parsed as an assignment would never evaluate the comparison.
ASSIGNOP = {EQUAL,
            PLUSEQUAL,
            MINEQUAL,
            STAREQUAL,
            SLASHEQUAL,
            DOUBLESLASHEQUAL,
            PERCENTEQUAL,
            DOUBLESTAREQUAL,
            ATEQUAL,
            AMPEREQUAL,
            VBAREQUAL,
            CIRCUMFLEXEQUAL,
            LEFTSHIFTEQUAL,
            RIGHTSHIFTEQUAL}
UNARYOP = {TILDE}
59 |
class Token:
    """A lightweight token whose fields are copied from Python's tokenize.TokenInfo."""
    def __init__(self, type, value,
                 index,       # position of this token in the token list
                 cstart_idx,  # column of the first character
                 cstop_idx,   # one past the last character, so text[start_idx:stop_idx] works
                 line):
        self.type = type
        self.value = value
        self.index = index
        self.cstart_idx = cstart_idx
        self.cstop_idx = cstop_idx
        self.line = line

    def __repr__(self):
        kind = token.tok_name[self.type]
        return f"<{kind}:{self.value},{self.cstart_idx}:{self.cstop_idx}>"

    def __str__(self):
        return self.value
73 |
74 |
def mytokenize(s):
    """
    Lex s with Python's tokenizer and return a list of Token objects.

    Only NUMBER, STRING, NAME, OP and ENDMARKER tokens are kept. Encoding,
    NEWLINE, comment and similar tokens are dropped. Each Token records the
    exact operator type, its position in the returned list, its start/stop
    columns and its line number. The final ENDMARKER's text is set to "".
    """
    wanted = {NUMBER, STRING, NAME, OP, ENDMARKER}
    readline = BytesIO(s.encode('utf-8')).readline
    result = []
    for tok in tokenize(readline):
        if tok.type not in wanted:
            continue
        (row, col_start), (_, col_stop) = tok.start, tok.end
        result.append(Token(tok.exact_type, tok.string, len(result),
                            col_start, col_stop, row))
    result[-1].value = ""  # blank out the trailing ENDMARKER
    return result
95 |
96 |
class PyExprParser:
    """
    A recursive-descent parser for subset of Python expressions and assignments.
    There is a built-in parser, but I only want to process Python code this library
    can handle and I also want my own kind of abstract syntax tree. Consequently,
    it's easier if I just parse the code I care about and ignore everything else.
    Building this parser was certainly no great burden.

    Informal grammar recognized by the methods below:

        assignment_or_expr : expression (ASSIGNOP expression)?
        expression         : addexpr
        addexpr            : multexpr (ADDOP multexpr)*
        multexpr           : unaryexpr (MULOP unaryexpr)*
        unaryexpr          : UNARYOP unaryexpr | postexpr
        postexpr           : atom ('(' arglist? ')' | '[' exprlist ']' | '.' NAME)*
        atom               : '(' exprlist ','? ')' | '[' exprlist ','? ']'
                           | NUMBER | NAME | STRING | ':'
        arglist            : argument (',' argument)* ','?
        argument           : NAME '=' expression | expression

    ASSIGNOP, ADDOP, MULOP and UNARYOP are token-type sets referenced from
    module scope.
    """
    def __init__(self, code:str, hush_errors=True):
        self.code = code
        self.hush_errors = hush_errors
        self.tokens = mytokenize(code)
        self.t = 0 # index into self.tokens of the current lookahead token

    def parse(self):
        """
        Parse self.code as a single assignment or expression and return the root
        AST node. Returns None if the code starts with a Python keyword (we only
        process assignments and expressions) or, when hush_errors is True, if the
        code cannot be parsed. When hush_errors is False, SyntaxError propagates.
        """
        if keyword.iskeyword(self.tokens[0].value):
            return None
        try:
            root = self.assignment_or_expr()
            self.match(ENDMARKER)
        except SyntaxError:
            if not self.hush_errors:
                raise
            return None
        return root

    def assignment_or_expr(self):
        """Parse an expression, optionally followed by an assignment operator and
        a right-hand side; returns an Assign node or the bare expression."""
        start = self.LT(1)
        lhs = self.expression()
        if self.LA(1) in ASSIGNOP:
            eq = self.LT(1)
            self.t += 1
            rhs = self.expression()
            stop = self.LT(-1)
            return tsensor.ast.Assign(eq,lhs,rhs,start,stop)
        return lhs

    def expression(self):
        return self.addexpr()

    def addexpr(self):
        """Left-associative chain of ADDOP operators over multexpr operands."""
        start = self.LT(1)
        root = self.multexpr()
        while self.LA(1) in ADDOP:
            op = self.LT(1)
            self.t += 1
            b = self.multexpr()
            stop = self.LT(-1)
            root = tsensor.ast.BinaryOp(op, root, b, start, stop)
        return root

    def multexpr(self):
        """Left-associative chain of MULOP operators over unaryexpr operands."""
        start = self.LT(1)
        root = self.unaryexpr()
        while self.LA(1) in MULOP:
            op = self.LT(1)
            self.t += 1
            b = self.unaryexpr()
            stop = self.LT(-1)
            root = tsensor.ast.BinaryOp(op, root, b, start, stop)
        return root

    def unaryexpr(self):
        """A (possibly repeated) prefix UNARYOP applied to a postfix expression."""
        start = self.LT(1)
        if self.LA(1) in UNARYOP:
            op = self.LT(1)
            self.t += 1
            e = self.unaryexpr()
            stop = self.LT(-1)
            return tsensor.ast.UnaryOp(op, e, start, stop)
        elif self.isatom() or self.isgroup():
            return self.postexpr()
        else:
            self.error(f"missing unary expr at: {self.LT(1)}")

    def postexpr(self):
        """An atom followed by any sequence of calls, index operations and member
        accesses; each suffix wraps the tree built so far."""
        start = self.LT(1)
        root = self.atom()
        while self.LA(1) in {LPAR, LSQB, DOT}:
            if self.LA(1)==LPAR:
                lp = self.LT(1)
                self.match(LPAR)
                el = []
                if self.LA(1) != RPAR:
                    el = self.arglist()
                self.match(RPAR)
                stop = self.LT(-1)
                root = tsensor.ast.Call(root, lp, el, start, stop)
            if self.LA(1)==LSQB:
                lb = self.LT(1)
                self.match(LSQB)
                el = self.exprlist()
                self.match(RSQB)
                stop = self.LT(-1)
                root = tsensor.ast.Index(root, lb, el, start, stop)
            if self.LA(1)==DOT:
                op = self.match(DOT)
                m = self.match(NAME)
                m = tsensor.ast.Atom(m)
                stop = self.LT(-1)
                root = tsensor.ast.Member(op, root, m, start, stop)
        return root

    def atom(self):
        """A parenthesized group/tuple, a list literal, or a single
        NUMBER/NAME/STRING/COLON token."""
        if self.LA(1) == LPAR:
            return self.subexpr()
        elif self.LA(1) == LSQB:
            return self.listatom()
        elif self.LA(1) in {NUMBER, NAME, STRING, COLON}:
            atom = self.LT(1)
            self.t += 1
            return tsensor.ast.Atom(atom)
        else:
            self.error("unknown or missing atom:"+str(self.LT(1)))

    def exprlist(self):
        """
        Parse one or more comma-separated expressions and return them as a list.
        A trailing comma directly before ')' or ']' is NOT consumed; the caller
        decides whether it is legal there (tuples and list literals allow it).
        """
        elist = [self.expression()]
        # could be trailing comma in a tuple like (3,4,) or a list like [3,4,]
        while self.LA(1)==COMMA and self.LA(2) not in {RPAR, RSQB}:
            self.match(COMMA)
            elist.append(self.expression())
        return elist

    def arglist(self):
        """Parse call arguments (positional or name=value), allowing a trailing
        comma as in f(a, b,)."""
        elist = [self._argument()]
        while self.LA(1)==COMMA:
            self.match(COMMA)
            if self.LA(1)==RPAR: # trailing comma; postexpr matches the ')'
                break
            elist.append(self._argument())
        return elist

    def _argument(self):
        """Parse a single call argument: keyword form NAME=expr or an expression."""
        if self.LA(1)==NAME and self.LA(2)==EQUAL:
            return self.arg()
        return self.expression()

    def arg(self):
        """Parse a keyword argument NAME '=' expression into an Assign node."""
        start = self.LT(1)
        kwarg = self.match(NAME)
        eq = self.match(EQUAL)
        e = self.expression()
        kwarg = tsensor.ast.Atom(kwarg)
        stop = self.LT(-1)
        return tsensor.ast.Assign(eq, kwarg, e, start, stop)

    def subexpr(self):
        """Parse '(' ... ')': a tuple if it has a comma, else a parenthesized
        subexpression."""
        start = self.match(LPAR)
        e = self.exprlist() # could be a tuple like (3,4) or even (3,4,)
        istuple = len(e)>1
        if self.LA(1)==COMMA:
            self.match(COMMA)
            istuple = True
        stop = self.match(RPAR)
        if istuple:
            return tsensor.ast.TupleLiteral(e, start, stop)
        subexpr = e[0]
        return tsensor.ast.SubExpr(subexpr, start, stop)

    def listatom(self):
        """Parse a list literal '[' exprlist ','? ']'."""
        start = self.LT(1)
        self.match(LSQB)
        e = self.exprlist()
        if self.LA(1)==COMMA: # trailing comma as in [3,4,]
            self.match(COMMA)
        self.match(RSQB)
        stop = self.LT(-1)
        return tsensor.ast.ListLiteral(e, start, stop)

    def isatom(self):
        return self.LA(1) in {NUMBER, NAME, STRING, COLON}

    def isgroup(self):
        return self.LA(1)==LPAR or self.LA(1)==LSQB

    def LA(self, i):
        """Token type of LT(i)."""
        return self.LT(i).type

    def LT(self, i):
        """
        Lookahead token: LT(1) is the current token, LT(2) the next, and LT(-1)
        the most recently consumed one. LT(0) is None. Looking past the end
        returns the last token (the end marker).
        """
        if i==0:
            return None
        if i<0:
            return self.tokens[self.t + i] # -1 should give prev token
        ahead = self.t + i - 1
        if ahead >= len(self.tokens):
            return self.tokens[-1] # return last (end marker)
        return self.tokens[ahead]

    def match(self, type):
        """Consume and return the current token if it has the given type, else
        raise SyntaxError."""
        if self.LA(1)!=type:
            self.error(f"mismatch token {self.LT(1)}, looking for {token.tok_name[type]}")
        tok = self.LT(1)
        self.t += 1
        return tok

    def error(self, msg):
        raise SyntaxError(msg)
303 |
def parse(statement:str, hush_errors=True):
    """
    Parse statement and return a pair (root, tokens): the root AST node and the
    list of token objects produced by the tokenizer.

    root is None when the statement begins with a Python keyword or when it is
    invalid code / code that I cannot parse. Such parsing errors are ignored
    unless hush_errors is False, in which case the SyntaxError propagates.
    """
    p = tsensor.parsing.PyExprParser(statement, hush_errors=hush_errors)
    return p.parse(), p.tokens
311 |
--------------------------------------------------------------------------------
/tsensor/version.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2020 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | __version__ = '0.1a27'
--------------------------------------------------------------------------------
/tsensor/viz.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2020 Terence Parr
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | import sys
25 | import tempfile
26 | import graphviz
27 | import token
28 | import matplotlib.patches as patches
29 | import matplotlib.pyplot as plt
30 | from IPython.display import display, SVG
31 | from IPython import get_ipython
32 |
33 | import numpy as np
34 | import tsensor
35 | import tsensor.ast
36 | import tsensor.analysis
37 | import tsensor.parsing
38 |
39 |
class PyVizView:
    """
    An object that collects relevant information about viewing Python code
    with visual annotations.
    """
    def __init__(self, statement, fontname, fontsize, dimfontname, dimfontsize,
                 matrixcolor, vectorcolor, char_sep_scale, dpi):
        self.statement = statement
        self.fontsize = fontsize
        self.fontname = fontname
        self.dimfontsize = dimfontsize
        self.dimfontname = dimfontname
        self.matrixcolor = matrixcolor
        self.vectorcolor = vectorcolor
        self.char_sep_scale = char_sep_scale
        self.dpi = dpi
        # Character box width/height approximated as a scale times the font size
        self.wchar = self.char_sep_scale * self.fontsize
        self.hchar = self.char_sep_scale * self.fontsize
        self.dim_ypadding = 5
        self.dim_xpadding = 0
        self.linewidth = .7
        self.leftedge = 25
        self.bottomedge = 3
        self.svgfilename = None
        self.matrix_size_scaler = 3.5 # How wide or tall as scaled fontsize is matrix?
        self.vector_size_scaler = 3.2 / 4 # How wide or tall as scaled fontsize is vector for skinny part?
        self.shift3D = 6
        self.cause = None # Did an exception occur during evaluation?
        self.offending_expr = None

    def set_locations(self, maxh):
        """
        This function finishes setting up necessary parameters about text
        and graphics locations for the plot. We don't know how to set these
        values until we know what the max height of the drawing will be. We don't know
        what that height is until after we've parsed and so on, which requires that
        we collect and store information in this view object before computing maxh.
        That is why this is a separate function not part of the constructor.
        """
        line2text = self.hchar / 1.7
        box2line = line2text*2.6
        self.texty = self.bottomedge + maxh + box2line + line2text
        self.liney = self.bottomedge + maxh + box2line
        self.box_topy = self.bottomedge + maxh
        self.maxy = self.texty + 1.4 * self.fontsize

    def _repr_svg_(self):
        "Show an SVG rendition in a notebook"
        return self.svg()

    @staticmethod
    def _new_svg_filename():
        """
        Securely create an empty temporary .svg file and return its path.
        Unlike the deprecated tempfile.mktemp(), the file is created atomically,
        so no other process can claim the name first. The file is not deleted.
        """
        with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as f:
            return f.name

    def svg(self):
        """
        Render as svg and return svg text. Save file and store name in field svgfilename.
        The file is written only once; later calls re-read the cached file.
        """
        if self.svgfilename is None: # cached?
            self.svgfilename = self._new_svg_filename()
            self.savefig(self.svgfilename)
        with open(self.svgfilename, encoding='UTF-8') as f:
            svg = f.read()
        return svg

    def savefig(self, filename):
        "Save viz in format according to file extension."
        plt.savefig(filename, dpi = self.dpi, bbox_inches = 'tight', pad_inches = 0)

    def show(self):
        "Display an SVG in a notebook or pop up a window if not in notebook"
        if get_ipython() is None:
            svgfilename = self._new_svg_filename()
            self.savefig(svgfilename)
            self.filename = svgfilename
            plt.show()
        else:
            svg = self.svg()
            display(SVG(svg))
            plt.close()

    def boxsize(self, v):
        """
        How wide and tall should we draw the box representing a vector or matrix.
        Returns None if v has no shape.
        """
        sh = tsensor.analysis._shape(v)
        if sh is None: return None
        if len(sh)==1: return self.vector_size(sh)
        return self.matrix_size(sh)

    def matrix_size(self, sh):
        """
        How wide and tall should we draw the box representing a matrix.
        Returns (width, height).
        """
        if len(sh)==1 and sh[0]==1:
            return self.vector_size(sh)
        elif len(sh) > 1 and sh[0] == 1 and sh[1] == 1:
            # A special case where we have a 1x1 matrix extending into the screen.
            # Make the 1x1 part a little bit wider than a vector so it's more readable
            return (2*self.vector_size_scaler * self.wchar, 2*self.vector_size_scaler * self.wchar)
        elif len(sh)>1 and sh[1]==1:
            return (self.vector_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar)
        return (self.matrix_size_scaler * self.wchar, self.matrix_size_scaler * self.wchar)

    def vector_size(self, sh):
        "Returns (width, height) of the wide, skinny box drawn for a 1D tensor."
        return (self.matrix_size_scaler * self.wchar, self.vector_size_scaler * self.wchar)

    def draw(self, ax, sub):
        sh = tsensor.analysis._shape(sub.value)
        if len(sh)==1: self.draw_vector(ax, sub)
        else: self.draw_matrix(ax, sub)

    def draw_vector(self,ax,sub):
        a, b = sub.leftx, sub.rightx
        mid = (a + b) / 2
        sh = tsensor.analysis._shape(sub.value)
        w,h = self.vector_size(sh)
        rect1 = patches.Rectangle(xy=(mid - w/2, self.box_topy-h),
                                  width=w,
                                  height=h,
                                  linewidth=self.linewidth,
                                  facecolor=self.vectorcolor,
                                  edgecolor='grey',
                                  fill=True)
        ax.add_patch(rect1)
        ax.text(mid, self.box_topy + self.dim_ypadding, self.nabbrev(sh[0]),
                horizontalalignment='center',
                fontname=self.dimfontname, fontsize=self.dimfontsize)

    def draw_matrix(self,ax,sub):
        a, b = sub.leftx, sub.rightx
        mid = (a + b) / 2
        sh = tsensor.analysis._shape(sub.value)
        w,h = self.matrix_size(sh)
        box_left = mid - w / 2
        if len(sh)>2:
            # Draw a shifted copy behind the front box to suggest a 3rd dimension
            back_rect = patches.Rectangle(xy=(box_left + self.shift3D, self.box_topy - h + self.shift3D),
                                          width=w,
                                          height=h,
                                          linewidth=self.linewidth,
                                          facecolor=self.matrixcolor,
                                          edgecolor='grey',
                                          fill=True)
            ax.add_patch(back_rect)
        rect = patches.Rectangle(xy=(box_left, self.box_topy - h),
                                 width=w,
                                 height=h,
                                 linewidth=self.linewidth,
                                 facecolor=self.matrixcolor,
                                 edgecolor='grey',
                                 fill=True)
        ax.add_patch(rect)
        ax.text(box_left, self.box_topy - h/2, self.nabbrev(sh[0]),
                verticalalignment='center', horizontalalignment='right',
                fontname=self.dimfontname, fontsize=self.dimfontsize, rotation=90)
        if len(sh)>1:
            textx = mid
            texty = self.box_topy + self.dim_ypadding
            if len(sh) > 2:
                texty += self.dim_ypadding
                textx += self.shift3D
            ax.text(textx, texty, self.nabbrev(sh[1]), horizontalalignment='center',
                    fontname=self.dimfontname, fontsize=self.dimfontsize)
        if len(sh)>2:
            ax.text(box_left+w, self.box_topy - h/2, self.nabbrev(sh[2]),
                    verticalalignment='center', horizontalalignment='center',
                    fontname=self.dimfontname, fontsize=self.dimfontsize,
                    rotation=45)
        if len(sh)>3:
            # raw string: "\c" is not a valid escape sequence in a normal literal
            remaining = r"$\cdots$x"+'x'.join([self.nabbrev(sh[i]) for i in range(3,len(sh))])
            ax.text(mid, self.box_topy - h - self.dim_ypadding, remaining,
                    verticalalignment='top', horizontalalignment='center',
                    fontname=self.dimfontname, fontsize=self.dimfontsize)

    @staticmethod
    def nabbrev(n) -> str:
        """
        Abbreviate a dimension size: exact multiples of a million get an 'm'
        suffix, exact multiples of a thousand a 'k' suffix, anything else is
        shown as-is. Zero (an empty dimension) is shown as '0', not '0m'.
        """
        if n == 0:
            return str(n)
        if n % 1_000_000 == 0:
            return str(n // 1_000_000)+'m'
        if n % 1_000 == 0:
            return str(n // 1000)+'k'
        return str(n)
217 |
218 |
def pyviz(statement: str, frame=None,
          fontname='Consolas', fontsize=13,
          dimfontname='Arial', dimfontsize=9, matrixcolor="#cfe2d4",
          vectorcolor="#fefecd", char_sep_scale=1.8, fontcolor='#444443',
          underline_color='#C2C2C2', ignored_color='#B4B4B4', error_op_color='#A40227',
          dimorder=None,
          ax=None, dpi=200, hush_errors=True) -> PyVizView:
    """
    Parse and evaluate the Python code in argument statement (string) using
    the indicated execution frame. The execution frame of the invoking function
    is used if frame is None.

    The visualization finds the smallest subexpressions that evaluate to
    tensors then underlines them and shows a box or rectangle representing
    the tensor dimensions. Boxes (matrixcolor, light green by default) have two
    or more dimensions but rectangles (vectorcolor, yellow by default) have one
    dimension with shape (n,).

    Upon tensor-related execution error, the offending subexpression is
    highlighted (by de-highlighting the other code) and the operator is shown
    using error_op_color.

    To adjust the size of the generated visualization to be smaller or bigger,
    decrease or increase the font size.

    :param statement: A string representing the line of Python code to visualize within an execution frame.
    :param frame: The execution frame in which to evaluate the statement. If None,
                  use the execution frame of the invoking function
    :param fontname: The name of the font used to display Python code
    :param fontsize: The font size used to display Python code; default is 13.
                     Also use this to increase the size of the generated figure;
                     larger font size means larger image.
    :param dimfontname: The name of the font used to display the dimensions on the matrix and vector boxes
    :param dimfontsize: The size of the font used to display the dimensions on the matrix and vector boxes
    :param matrixcolor: The color of matrix boxes
    :param vectorcolor: The color of vector boxes; only for tensors whose shape is (n,).
    :param char_sep_scale: It is notoriously difficult to discover how wide and tall
                           text is when plotted in matplotlib. In fact there's probably
                           no hope to discover this information accurately in all cases.
                           Certainly, I gave up after spending huge effort. We have a
                           situation here where the font should be constant width, so
                           we can just use a simple scaler times the font size to get
                           a reasonable approximation to the width and height of a
                           character box; the default of 1.8 seems to work reasonably
                           well for a wide range of fonts, but you might have to tweak it
                           when you change the font size.
    :param fontcolor: The color of the Python code.
    :param underline_color: The color of the lines that underscore tensor subexpressions; default is grey
    :param ignored_color: The de-highlighted color for deemphasizing code not involved in an erroneous sub expression
    :param error_op_color: The color to use for characters associated with the erroneous operator
    :param dimorder: When training deep learning models in batches, we must add a
                     batch dimension to our training data matrix. The dimension order
                     required by the deep learning model might be different than the way
                     we want to visualize it. For example, if each input instance is
                     a 2D image, then our training data has shape, say, (width,height,n).
                     If we need to add the batch dimension first, that changes to
                     (batch,width,height,n) but we still want to visualize the input
                     as width,height as a 2D picture with a number of instances going
                     back into the screen (like a 3D cube). It would be convenient to
                     visualize the batch dimension as the fourth. This parameter
                     allows you to specify the order of dimensions for any variable
                     referenced in the code being visualized. E.g., if the input is in
                     variable X with shape (batch,width,height,n) but we want to display
                     X as (width,height,n,batch), pass in dimorder=dict(X=[1,2,3,0]).
    :param ax: If not none, this is the matplotlib drawing region in which to draw the visualization
    :param dpi: This library tries to generate SVG files, which are vector graphics not
                2D arrays of pixels like PNG files. However, it needs to know how to
                compute the exact figure size to remove padding around the visualization.
                Matplotlib uses inches for its figure size and so we must convert
                from pixels or data units to inches, which means we have to know what the
                dots per inch, dpi, is for the image.
    :param hush_errors: Normally, error messages from true syntax errors but also
                        unhandled code caught by my parser are ignored. Turn this off
                        to see what the error messages are coming from my parser.
    :return: Returns a PyVizView holding info about the visualization; from a notebook
             an SVG image will appear. Return None upon parsing error in statement.
    """
    # NOTE(review): dimorder is accepted and documented but not used anywhere in
    # this function's body — confirm whether it should be applied.
    view = PyVizView(statement, fontname, fontsize, dimfontname, dimfontsize, matrixcolor,
                     vectorcolor, char_sep_scale, dpi)

    if frame is None: # use frame of caller if not passed in
        frame = sys._getframe().f_back
    root, tokens = tsensor.parsing.parse(statement, hush_errors=hush_errors)
    if root is None:
        # likely syntax error in statement or code I can't handle
        return None
    root_to_viz = root
    try:
        root.eval(frame)
    except tsensor.ast.IncrEvalTrap as e:
        root_to_viz = e.offending_expr
        view.offending_expr = e.offending_expr
        view.cause = e.__cause__
        # Don't raise the exception; keep going to visualize code and erroneous
        # subexpressions. If this function is invoked from clarify() or explain(),
        # the statement will be executed and will fail again during normal execution;
        # an exception will be thrown at that time. Then explain/clarify
        # will update the error message
    subexprs = tsensor.analysis.smallest_matrix_subexpr(root_to_viz)

    if ax is None:
        fig, ax = plt.subplots(1, 1, dpi=dpi)
    else:
        fig = ax.figure

    ax.axis("off")

    # First, we need to figure out how wide the visualization components are
    # for each sub expression. If these are wider than the sub expression text,
    # then we need to leave space around the sub expression text
    lpad = np.zeros((len(statement),)) # pad for characters
    rpad = np.zeros((len(statement),))
    maxh = 0
    for sub in subexprs:
        w, h = view.boxsize(sub.value)
        maxh = max(h, maxh)
        left = sub.start.cstart_idx
        right = sub.stop.cstop_idx # exclusive: statement[right] is the char just after sub
        nexpr = right - left
        # An adjacent space on either side is room the box may also occupy
        if left > 0 and statement[left - 1] == ' ': # if char to left is space
            nexpr += 1
        if right < len(statement) and statement[right] == ' ': # if char to right is space
            nexpr += 1
        if w > view.wchar * nexpr:
            lpad[left] += (w - view.wchar) / 2
            rpad[right - 1] += (w - view.wchar) / 2

    # Now we know how to place all the elements, since we know what the maximum height is
    view.set_locations(maxh)

    # Find each character's position based upon width of a character and any padding
    charx = np.empty((len(statement),))
    x = view.leftedge
    for i in range(len(statement)):
        x += lpad[i]
        charx[i] = x
        x += view.wchar
        x += rpad[i]

    # Draw text for statement or expression
    if view.offending_expr is not None: # highlight erroneous subexpr
        highlight = np.full(shape=(len(statement),), fill_value=False, dtype=bool)
        for tok in tokens[root_to_viz.start.index:root_to_viz.stop.index+1]:
            highlight[tok.cstart_idx:tok.cstop_idx] = True
        errors = np.full(shape=(len(statement),), fill_value=False, dtype=bool)
        for tok in root_to_viz.optokens:
            errors[tok.cstart_idx:tok.cstop_idx] = True
        for i, c in enumerate(statement):
            color = ignored_color
            if highlight[i]:
                color = fontcolor
            if errors[i]: # override color if operator token
                color = error_op_color
            ax.text(charx[i], view.texty, c, color=color, fontname=fontname, fontsize=fontsize)
    else:
        for i, c in enumerate(statement):
            ax.text(charx[i], view.texty, c, color=fontcolor, fontname=fontname, fontsize=fontsize)

    # Compute the left and right edges of subexpressions (alter nodes with info)
    for sub in subexprs:
        sub.leftx = charx[sub.start.cstart_idx]
        sub.rightx = charx[sub.stop.cstop_idx - 1] + view.wchar

    # Draw grey underlines and draw matrices
    pad = view.wchar*0.1
    for sub in subexprs:
        a, b = sub.leftx, sub.rightx
        ax.plot([a-pad, b+pad], [view.liney,view.liney], '-', linewidth=.5, c=underline_color)
        view.draw(ax, sub)

    fig_width = charx[-1] + view.wchar + rpad[-1]
    fig_width_inches = fig_width / dpi
    fig_height_inches = view.maxy / dpi
    fig.set_size_inches(fig_width_inches, fig_height_inches)

    ax.set_xlim(0, fig_width)
    ax.set_ylim(0, view.maxy)

    return view
403 |
404 |
405 | # ---------------- SHOW AST STUFF ---------------------------
406 |
class QuietGraphvizWrapper(graphviz.Source):
    """
    A graphviz.Source built from DOT text whose notebook SVG rendering calls
    pipe(..., quiet=True) so the layout subprocess's stderr is suppressed.
    """
    def __init__(self, dotsrc):
        super().__init__(source=dotsrc)

    def _repr_svg_(self):
        # Use the public `encoding` property instead of the private `_encoding`
        return self.pipe(format='svg', quiet=True).decode(self.encoding)
413 |
414 |
def astviz(statement:str, frame=None) -> graphviz.Source:
    """
    Build a graphviz object that draws the AST of statement; in a notebook it
    renders as SVG. frame, when given, is handed to astviz_dot, which evaluates
    the statement in that frame before generating the DOT text.
    """
    dot_source = astviz_dot(statement, frame)
    wrapper = QuietGraphvizWrapper(dot_source)
    return wrapper
417 |
418 |
def astviz_dot(statement:str, frame=None) -> str:
    """
    Return Graphviz DOT source drawing the AST of statement.

    Every token except the end marker becomes a borderless leaf box; leaves are
    forced onto one rank and chained with invisible edges to keep them left to
    right. Every non-leaf AST node becomes a box labeled with its operator
    tokens, and edges link each node to its children. If frame is not None, the
    statement is first evaluated in that frame; nodes whose value then has a
    shape show it (e.g. 10x3k) and are filled with the vector or matrix color.
    """
    def internal_label(node,color="yellow"):
        text = ''.join(str(t) for t in node.optokens)
        sh = tsensor.analysis._shape(node.value)
        if sh is None:
            return f'{text}'

        sz = 'x'.join([PyVizView.nabbrev(sh[i]) for i in range(len(sh))])
        return f"""{text}
{sz}"""

    root, tokens = tsensor.parsing.parse(statement)
    if frame is not None:
        root.eval(frame)

    nodes = tsensor.ast.postorder(root)
    atoms = tsensor.ast.leaves(root)
    atomsS = set(atoms)
    ops = [nd for nd in nodes if nd not in atomsS] # keep order

    gr = """digraph G {
    margin=0;
    nodesep=.01;
    ranksep=.3;
    rankdir=BT;
    ordering=out; # keep order of leaves
"""

    # NOTE(review): fontname, fontsize and dimfontsize are assigned but not used by
    # any label in this function; label font markup may have been lost — confirm.
    matrixcolor = "#cfe2d4"
    vectorcolor = "#fefecd"
    fontname="Consolas"
    fontsize=12
    dimfontsize = 9
    spread = 0

    # Gen leaf nodes
    for t in tokens:
        if t.type!=token.ENDMARKER:
            nodetext = t.value
            if nodetext==']':
                # Prefix with U+200C, a zero-width non-joiner: ']' by itself is bad for DOT
                nodetext = '\u200c]'
            label = f'{nodetext}'
            # Width hint per token kind so operators get some horizontal room
            _spread = spread
            if t.type==token.DOT:
                _spread=.1
            elif t.type==token.EQUAL:
                _spread=.25
            elif t.type in tsensor.parsing.ADDOP:
                _spread=.4
            elif t.type in tsensor.parsing.MULOP:
                _spread=.2
            gr += f'leaf{id(t)} [shape=box penwidth=0 margin=.001 width={_spread} label=<{label}>]\n'

    # Make sure leaves are on same level
    gr += f'{{ rank=same; '
    for t in tokens:
        if t.type!=token.ENDMARKER:
            gr += f' leaf{id(t)}'
    gr += '\n}\n'

    # Make sure leaves are left to right by linking (last token is the end marker)
    for i in range(len(tokens) - 2):
        t = tokens[i]
        t2 = tokens[i + 1]
        gr += f'leaf{id(t)} -> leaf{id(t2)} [style=invis];\n'

    # Draw internal ops nodes
    for nd in ops:
        label = internal_label(nd)
        sh = tsensor.analysis._shape(nd.value)
        if sh is None:
            color = ""
        else:
            if len(sh)==1:
                color = f'fillcolor="{vectorcolor}" style=filled'
            else:
                color = f'fillcolor="{matrixcolor}" style=filled'
        gr += f'node{id(nd)} [shape=box {color} penwidth=0 margin=0 width=.25 height=.2 label=<{label}>]\n'

    # Link internal nodes to other nodes or leaves
    for nd in nodes:
        kids = nd.kids
        for sub in kids:
            if sub in atomsS:
                gr += f'node{id(nd)} -> leaf{id(sub.token)} [dir=back, penwidth="0.5", color="#6B6B6B", arrowsize=.3];\n'
            else:
                gr += f'node{id(nd)} -> node{id(sub)} [dir=back, penwidth="0.5", color="#6B6B6B", arrowsize=.3];\n'

    gr += "}\n"
    return gr
--------------------------------------------------------------------------------