├── .gitignore
├── LICENSE
├── README.md
├── autodiff_puzzlers.ipynb
├── autodiff_puzzlers_files
├── autodiff_puzzlers_10_2.svg
├── autodiff_puzzlers_11_0.svg
├── autodiff_puzzlers_12_0.svg
├── autodiff_puzzlers_13_2.svg
├── autodiff_puzzlers_14_2.svg
├── autodiff_puzzlers_15_2.svg
├── autodiff_puzzlers_16_2.svg
├── autodiff_puzzlers_17_2.svg
├── autodiff_puzzlers_18_2.svg
├── autodiff_puzzlers_19_2.svg
├── autodiff_puzzlers_20_2.svg
├── autodiff_puzzlers_21_2.svg
├── autodiff_puzzlers_22_2.svg
├── autodiff_puzzlers_23_2.svg
├── autodiff_puzzlers_24_2.svg
├── autodiff_puzzlers_25_2.svg
├── autodiff_puzzlers_26_2.svg
├── autodiff_puzzlers_27_2.svg
├── autodiff_puzzlers_28_2.svg
├── autodiff_puzzlers_29_4.svg
├── autodiff_puzzlers_30_4.svg
├── autodiff_puzzlers_31_4.svg
├── autodiff_puzzlers_32_4.svg
├── autodiff_puzzlers_33_4.svg
├── autodiff_puzzlers_34_4.svg
├── autodiff_puzzlers_36_2.svg
├── autodiff_puzzlers_37_2.svg
├── autodiff_puzzlers_38_2.svg
├── autodiff_puzzlers_39_2.svg
├── autodiff_puzzlers_40_4.svg
├── autodiff_puzzlers_41_4.svg
├── autodiff_puzzlers_42_2.svg
├── autodiff_puzzlers_43_2.svg
├── autodiff_puzzlers_44_4.svg
├── autodiff_puzzlers_45_4.svg
├── autodiff_puzzlers_46_4.svg
├── autodiff_puzzlers_47_4.svg
├── autodiff_puzzlers_5_2.svg
├── autodiff_puzzlers_6_2.svg
├── autodiff_puzzlers_7_2.svg
├── autodiff_puzzlers_8_2.svg
└── autodiff_puzzlers_9_2.svg
├── image.png
└── lib.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Sasha Rush
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Autodiff Puzzles
2 | - by [Sasha Rush](http://rush-nlp.com) - [srush_nlp](https://twitter.com/srush_nlp)
3 |
4 | **Click here to get started:**
5 |
6 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/Autodiff-Puzzles/blob/main/autodiff_puzzlers.ipynb)
7 |
8 | This notebook contains a series of self-contained puzzles for learning about derivatives in tensor libraries. It is the 3rd puzzle set in a series of puzzles about programming for deep learning ([Tensor Puzzles](https://github.com/srush/Tensor-Puzzles), [GPU Puzzles](https://github.com/srush/GPU-Puzzles)).
9 |
10 | 
11 |
12 |
13 | Your goal in these puzzles is to implement the derivative of each core function. In each case the function takes in a tensor $x$ and returns a tensor $f(x)$, so your job is to compute $\frac{\partial f(x)_o}{\partial x_i}$ for all indices $o$ and $i$. If you get discouraged, just remember: you did this in high school (it just had far less indexing).
14 |
--------------------------------------------------------------------------------
/autodiff_puzzlers_files/autodiff_puzzlers_10_2.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/autodiff_puzzlers_files/autodiff_puzzlers_11_0.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/autodiff_puzzlers_files/autodiff_puzzlers_12_0.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/autodiff_puzzlers_files/autodiff_puzzlers_21_2.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/autodiff_puzzlers_files/autodiff_puzzlers_5_2.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/autodiff_puzzlers_files/autodiff_puzzlers_6_2.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/autodiff_puzzlers_files/autodiff_puzzlers_7_2.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/autodiff_puzzlers_files/autodiff_puzzlers_8_2.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/autodiff_puzzlers_files/autodiff_puzzlers_9_2.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/srush/Autodiff-Puzzles/ae4586b052f3328031ad64024edabbf6fbd2e4bb/image.png
--------------------------------------------------------------------------------
/lib.py:
--------------------------------------------------------------------------------
1 | import chalk
2 | from chalk import *
3 | import torch
4 | import random
5 | from colour import Color
6 |
7 | import pandas as pd
8 | import sys
9 | sys.setrecursionlimit(10000)
10 | set_svg_height(400)
11 | set_svg_draw_height(400)
12 | from IPython.display import display, HTML
13 |
def show_dog():
    """Print a success banner and display a random puppy via IPython HTML.

    Called by ``in_out`` and ``zip`` when no mismatches between the analytic
    and numerical derivatives were found.
    """
    print("Passed Tests!")
    # Pool of IDs; one is chosen at random and substituted into the HTML
    # template below (presumably IDs of puppy images/clips -- confirm).
    pups = [
        "2m78jPG",
        "pn1e9TO",
        "MQCIwzT",
        "udLK6FS",
        "ZNem5o3",
        "DS2IZ6K",
        "aydRUz8",
        "MVUdQYK",
        "kLvno0p",
        "wScLiVz",
        "Z0TII8i",
        "F1SChho",
        "9hRi2jN",
        "lvzRF3W",
        "fqHxOGI",
        "1xeUYme",
        "6tVqKyM",
        "CCxZ6Wr",
        "lMW0OPQ",
        "wHVpHVG",
        "Wj2PGRl",
        "HlaTE8H",
        "k5jALH0",
        "3V37Hqr",
        "Eq2uMTA",
        "Vy9JShx",
        "g9I2ZmK",
        "Nu4RH7f",
        "sWp0Dqd",
        "bRKfspn",
        "qawCMl5",
        "2F6j2B4",
        "fiJxCVA",
        "pCAIlxD",
        "zJx2skh",
        "2Gdl1u7",
        "aJJAY4c",
        "ros6RLC",
        "DKLBJh7",
        "eyxH0Wc",
        "rJEkEw4"]
    # NOTE(review): as it appears in this copy, the HTML template below holds
    # only whitespace and no ``%s`` conversion specifier, so applying ``%`` to
    # it raises TypeError ("not all arguments converted during string
    # formatting"). The markup looks like it was stripped from this copy;
    # restore the original template (with a ``%s`` for the ID) before use.
    display(HTML("""

"""%(random.sample(pups, 1)[0])))
63 |
64 |
65 |
def line(y, rev=False):
    """Return the vertices of a step curve through the samples of ``y``.

    Sample ``k`` contributes the two points ``(k, -2*y[k])`` and
    ``(k + 1, -2*y[k])``, i.e. one flat segment per sample. The height is
    scaled by -2 (presumably so the drawing library's downward y axis shows
    positive values upward -- confirm). ``rev`` is accepted but unused.
    """
    points = []
    for idx in range(y.shape[0]):
        height = -2 * float(y[idx])
        points.append((idx, height))
        points.append((idx + 1, height))
    return points
70 |
def graph(y, name, splits=None):
    """Draw a labelled step plot of the 1-D tensor ``y``.

    Args:
        y: 1-D tensor of samples. Values are drawn at ``-2 * y`` against a
            vertical axis spanning [-2, 2].
        name: text of the frame label; also the ``Name`` the plot is
            registered under, so helpers such as ``get_transform`` can look
            it up with ``get_subdiagram``.
        splits: optional iterable of x positions at which to draw extra
            vertical divider lines. ``None`` (the default) draws none; this
            also lets ``two_arg`` forward its own ``gaps=None`` default.

    Returns:
        The framed plot, scaled relative to ``y.shape[0]`` and aligned
        top-left.
    """
    if splits is None:
        splits = []
    top = make_path([(0, -2), (0, 2)])
    # Thin grey zero line across the whole plot.
    top += make_path([(0, 0), (y.shape[0], 0)]).line_width(0.05).line_color(Color("grey"))
    top += make_path(line(y)).line_width(0.2)
    for s in splits:
        top += make_path([(s, -2), (s, 2)])
    top = top.named(Name(name))
    top = frame(name, top, Color("#EEEEEE"), name, h=1.5) + top
    top = top.scale_uniform_to_x((y.shape[0] + 15) / 100).align_tl()
    return top
81 |
def overgraph(d, name, y, y2, color):
    """Shade the band between curves ``y`` and ``y2`` on the plot ``name``.

    Both curves are clipped to [-1, 1]. The region between them is filled
    with ``color`` at 50% opacity, and the clipped ``y`` curve is redrawn on
    top. The overlay is scaled and translated into the coordinate frame of
    the sub-diagram ``name`` of ``d`` (see ``get_transform``), so the caller
    can compose it onto ``d``.
    """
    base_curve = torch.clamp(y, -1, 1)
    shifted_curve = torch.clamp(y2, -1, 1)

    # Closed outline: along the shifted curve, then back along the base curve.
    outline = line(shifted_curve) + list(reversed(line(base_curve)))
    band = make_path([(0, -2), (0, 2)]) + make_path(outline, True)
    band = band.line_color(color).fill_opacity(0.5).fill_color(color).line_width(0.1)

    (sx, sy), shift = get_transform(name, d, base_curve)
    base = make_path(line(base_curve))
    return (band + base).scale_x(sx).scale_y(sy).translate_by(shift)
95 |
def frame(name, d, c, l="", w=15, h=1.5):
    """Put a filled background rectangle behind the sub-diagram ``name`` of ``d``.

    The rectangle is ``w`` units wider and ``h`` times taller than the
    sub-diagram's envelope, is filled with colour ``c``, and is centred on
    the envelope. If ``l`` is non-empty, it is written as black text at the
    left edge, after a ``w / 4`` horizontal gap.
    """
    env = d.get_subdiagram(Name(name)).get_envelope()
    backdrop = rectangle(env.width + w, env.height * h).fill_color(c).line_width(0)
    if l:
        label = text(l, 2.5).fill_color(Color("black")).line_width(0)
        backdrop = backdrop.align_l() + (hstrut(w / 4) | label).align_l()
    return backdrop.center_xy().translate_by(env.center) + d
103 |
def get_transform(name, d, v):
    """Return ``((x_scale, y_scale), location)`` mapping data to the plot ``name``.

    ``x_scale`` is the envelope width per sample of ``v``. ``y_scale`` is the
    envelope height divided by 4 (presumably the 4-unit [-2, 2] vertical
    axis drawn by ``graph`` -- confirm). ``location`` is the sub-diagram's
    location in ``d``.
    """
    sub = d.get_subdiagram(Name(name))
    env = sub.get_envelope()
    per_sample = env.width / v.shape[0]
    per_unit = env.height / (2 * 2)
    return (per_sample, per_unit), sub.get_location()
111 |
def get_locations(name, d, x, y, v):
    """Map data coordinates ``(x, y)`` to diagram coordinates on plot ``name``.

    ``x`` is measured in samples and ``y`` in data units. ``y`` is doubled
    and subtracted, matching the ``-2 * value`` convention used by ``line``.
    Works element-wise when ``x`` and ``y`` are tensors.
    """
    (sx, sy), origin = get_transform(name, d, v)
    px = origin[0] + sx * x
    py = origin[1] - sy * 2 * y
    return px, py
116 |
# Default wedge colour for ``connect``; ``zip`` also uses it when x has a single segment.
ORANGE = Color("orange")
118 |
def connect(d, x, x_name, f_x, f_x_name, q,
            i_s=None,
            color=ORANGE, to_line=False, bad=None):
    """Draw translucent wedges from input cells to the output cells they affect.

    For every input index ``i`` in ``i_s`` and every output index ``j`` with
    ``q[j, i] != 0``, a triangle is drawn from input cell ``i`` on plot
    ``x_name`` to output cell ``j`` on plot ``f_x_name``. Its opacity is
    ``0.3 * |q[j, i]| / max|q|``.

    Args:
        d: diagram containing both named plots.
        x: input tensor drawn on ``x_name``.
        x_name: name of the input plot in ``d``.
        f_x: output tensor drawn on ``f_x_name``.
        f_x_name: name of the output plot in ``d``.
        q: weight matrix indexed ``q[j, i]`` (output ``j``, input ``i``),
            e.g. a Jacobian.
        i_s: input indices to draw from; defaults to every index of ``x``.
        color: wedge fill colour.
        to_line: if True, wedges land on the output curve itself; otherwise
            they land at a constant level of 1 (plus a small offset).
        bad: optional mapping ``{i: [j, ...]}`` of pairs to draw in red
            instead of ``color`` (e.g. the mismatches returned by ``check``).

    Returns:
        ``d`` with the wedges composed on top.
    """
    if bad is None:
        bad = {}
    offset = 0.1
    x1, y1 = get_locations(x_name, d,
                           torch.arange(x.shape[0] + 1),
                           torch.zeros(x.shape[0] + 1) - offset, x)

    if to_line:
        land = f_x
    else:
        # Constant landing level. (``f_x / f_x`` was used here before, which
        # yields NaN wherever f_x == 0 and produced broken path coordinates.)
        land = torch.ones_like(f_x)
    # Two padding entries keep y2[j + 1] in range for the last output; for
    # j == 0, y2[j - 1] wraps around to the last padding entry.
    land = torch.cat([land, torch.tensor([0, 0])], 0)
    x2, y2 = get_locations(f_x_name, d,
                           torch.arange(f_x.shape[0] + 2),
                           land + offset, f_x)
    if i_s is None:
        i_s = range(x.shape[0])
    m = q.abs().max()
    for i in i_s:
        for j in q[:, i].nonzero()[:, 0]:
            if q[j, i] != 0 and i < x1.shape[0] - 1 and j < y2.shape[0] - 2:
                fill = color if j not in bad.get(i, []) else Color("red")
                p = make_path([((x1[i] + x1[i + 1]) / 2, y1[i]),
                               (x2[j], (y2[j] + y2[j - 1]) / 2),
                               (x2[j + 1], (y2[j] + y2[j + 1]) / 2)], True)
                p = p.line_width(0).fill_color(fill).fill_opacity(0.3 * abs(q[j, i] / m))
                d += p
    return d
151 |
def diff_graph(d, f_x, f_x_name, d_f_x, c, amount=5):
    """Shade the effect of stepping ``f_x`` along ``d_f_x`` on plot ``f_x_name``.

    The band between ``f_x`` and ``f_x + amount * d_f_x`` is drawn in colour
    ``c`` with ``overgraph``. Previously this passed an extra positional
    argument to ``overgraph`` (a TypeError), ignored ``amount``, and returned
    nothing.

    Returns:
        ``d`` with the overlay composed on top.
    """
    updated = f_x + amount * d_f_x
    return d + overgraph(d, f_x_name, f_x, updated, c)
155 |
def double(y, f_y):
    """Stack the plot of ``y`` (named "x") 0.1 units above the plot of ``f_y`` (named "f(x)")."""
    panels = [graph(y, "x").center_xy(), graph(f_y, "f(x)").center_xy()]
    return vcat(panels, 0.1)
160 |
def two_arg(x, y, f_xy, gaps=None):
    """Lay out the plots "x" and "y" side by side above the plot "f(x)".

    Args:
        x, y: 1-D input tensors.
        f_xy: 1-D output tensor.
        gaps: optional divider positions drawn on the "x" plot. ``None``
            means no dividers; previously ``None`` was forwarded to
            ``graph``, which tried to iterate it and raised TypeError.
    """
    splits = [] if gaps is None else gaps
    inputs = hcat([graph(x, "x", splits), graph(y, "y")], 0.03).center_xy()
    return vcat([inputs, graph(f_xy, "f(x)").center_xy()], 0.1)
def outer_frame(d):
    """Surround ``d`` with a thin white frame sized to its full extent."""
    named = d.named(Name("full"))
    border = frame("full", named, Color("white"), w=0.1, h=1.2)
    return border + d
166 |
167 |
def m(inp, out, f):
    """Tabulate a 2-argument derivative function into an (out, in) matrix.

    ``f(i, j)`` is evaluated for every input index ``i`` (outer) and output
    index ``j`` (inner). The result is a float32 tensor ``J`` with
    ``J[j, i] == f(i, j)``, returned as a transposed view.
    """
    n_in, n_out = inp.shape[0], out.shape[0]
    table = torch.zeros(n_in, n_out)
    for flat in range(n_in * n_out):
        i, j = divmod(flat, n_out)
        table[i, j] = f(i, j)
    return table.T
174 |
def m2(inp, out, f):
    """Tabulate ``f(i, r, c)`` for a 1-D input and a 2-D output.

    Returns a float32 tensor ``J`` of shape ``(out0, out1, in0)`` with
    ``J[r, c, i] == f(i, r, c)``. Calls are made in row-major order over
    ``(i, r, c)``.
    """
    n_in = inp.shape[0]
    rows, cols = out.shape[0], out.shape[1]
    table = torch.zeros(n_in, rows, cols)
    per_input = rows * cols
    for flat in range(n_in * per_input):
        i, rest = divmod(flat, per_input)
        r, c = divmod(rest, cols)
        table[i, r, c] = f(i, r, c)
    return table.permute(1, 2, 0)
182 |
def m3(inp, out, f):
    """Tabulate ``f(i, j, k, l)`` for a 2-D input and a 2-D output.

    Returns a float32 tensor ``J`` of shape ``(out0, out1, in0, in1)`` with
    ``J[k, l, i, j] == f(i, j, k, l)``. Calls are made in row-major order
    over ``(i, j, k, l)``.
    """
    dims = (inp.shape[0], inp.shape[1], out.shape[0], out.shape[1])
    table = torch.zeros(*dims)
    indices = ((i, j, k, l)
               for i in range(dims[0])
               for j in range(dims[1])
               for k in range(dims[2])
               for l in range(dims[3]))
    for idx in indices:
        table[idx] = f(*idx)
    return table.permute(2, 3, 0, 1)
191 |
def v(s, f):
    """Return a float32 vector of length ``s`` whose entry ``i`` is ``f(i)``."""
    result = torch.zeros(s)
    for idx in range(result.shape[0]):
        result[idx] = f(idx)
    return result
197 |
def v2(s, f):
    """Return a float32 tensor of shape ``s`` with ``[r, c]`` set to ``f(r, c)``.

    Only the first two dimensions of ``s`` are iterated.
    """
    grid = torch.zeros(*s)
    cells = ((r, c) for r in range(s[0]) for c in range(s[1]))
    for r, c in cells:
        grid[r, c] = f(r, c)
    return grid
204 |
def numerical_deriv(fb, x, out, eps=1e-5):
    """Estimate the Jacobian of ``fb`` at ``x`` with central differences.

    Args:
        fb: callable taking an input tensor and returning a function ``g``
            where ``g(j)`` is output ``j``.
        x: 1-D input tensor. It is perturbed by float64 basis vectors, so the
            perturbed inputs are float64.
        out: tensor whose ``shape[0]`` is the number of outputs.
        eps: perturbation size (default ``1e-5``, the previously hard-coded
            value).

    Returns:
        float32 tensor of shape ``(out.shape[0], x.shape[0])`` where entry
        ``[j, i]`` is ``(g_plus(j) - g_minus(j)) / (2 * eps)``, with
        ``g_plus = fb(x + eps * e_i)`` and ``g_minus = fb(x - eps * e_i)``.
    """
    n_in = x.shape[0]
    n_out = out.shape[0]
    # Build the basis once instead of twice per input index.
    basis = torch.eye(n_in).double()
    jac = torch.zeros(n_out, n_in)
    for i in range(n_in):
        step = eps * basis[i]
        f_plus = fb(x + step)
        f_minus = fb(x - step)
        for j in range(n_out):
            jac[j, i] = (f_plus(j) - f_minus(j)) / (2 * eps)
    return jac
218 |
def two_argf(fb, x, y, out_shape):
    """Evaluate a two-argument puzzle and its analytic and numerical Jacobians.

    ``fb(x, y)`` must return ``(f, dx, dy)``: the forward function of the
    output index, and the derivative functions taking ``(input index,
    output index)``.

    Returns:
        ``(out, dx, dy, dx2, dy2)``: the forward values, the analytic
        Jacobians w.r.t. ``x`` and ``y``, and their numerical estimates.
    """
    forward, dx_fn, dy_fn = fb(x, y)
    out = v(out_shape, forward)
    dx = m(x, out, dx_fn)
    dy = m(y, out, dy_fn)

    def vary_x(x_new):
        return fb(x_new, y)[0]

    def vary_y(y_new):
        return fb(x, y_new)[0]

    dx2 = numerical_deriv(vary_x, x, out)
    dy2 = numerical_deriv(vary_y, y, out)
    return out, dx, dy, dx2, dy2
227 |
def one_argf(fb, x, out_shape):
    """Evaluate a one-argument puzzle and its analytic and numerical Jacobians.

    ``fb(x)`` must return ``(f, dx)``: the forward function of the output
    index and the derivative function taking ``(input index, output index)``.

    Returns:
        ``(out, dx, dx2)``: the forward values, the analytic Jacobian, and
        its numerical estimate.
    """
    forward, dx_fn = fb(x)
    out = v(out_shape, forward)
    dx = m(x, out, dx_fn)

    def forward_only(x_new):
        return fb(x_new)[0]

    dx2 = numerical_deriv(forward_only, x, out)
    return out, dx, dx2
234 |
235 | # def two_mat_argf(fb, x, y, out_shape, in_shape):
236 | # f, dx, dy = fb(x, y2)
237 | # out = v2(out_shape, f)
238 | # dx = m3(x, out, dx)
239 | # dy = m2(y2, out, dy)
240 | # dx2 = []
241 | # dx2 = numerical_deriv(lambda a: lambda v: fb(a, y)[0](v // out_shape[0], v % out_shape[0]), x.view(-1), out.view(-1), in_shape=in_shape)
242 | # dy2 = numerical_deriv(lambda b: fb(x, b)[0], y.view(-1), out.view(-1))
243 | # return out, dx, dy, dx2, dy2
244 |
def check(dx, dx2):
    """Find entries where an analytic Jacobian disagrees with a numerical one.

    Both matrices are indexed ``[out, in]``; entries are compared with
    ``torch.isclose(..., atol=1e-4)``.

    Returns:
        ``(bad, df)``. ``bad`` maps each failing input index to the list of
        failing output indices. ``df`` is a DataFrame with one row per
        mismatch and columns "In Index" / "Out Index".
    """
    mismatches = (~torch.isclose(dx, dx2, atol=1e-4)).nonzero()
    bad = {}
    rows = []
    for out_idx, in_idx in mismatches:
        i, j = in_idx.item(), out_idx.item()
        bad.setdefault(i, []).append(j)
        rows.append({"In Index": i, "Out Index": j})
    return bad, pd.DataFrame(rows)
254 |
# Default demo input: 50 float64 samples of a slow sine wave (amplitude 0.5)
# plus uniform noise in [-0.5, 0.5), re-drawn each time the module is loaded.
# Uses ``torch.sin`` because this module never imports ``math`` (the previous
# ``math.sin`` only worked if ``from chalk import *`` happened to leak it).
gy = (torch.sin(torch.arange(50, dtype=torch.float64) / 20) * 0.5
      + torch.tensor([random.random() - 0.5 for _ in range(50)],
                     dtype=torch.float64))
257 |
def fb_demo(x):
    """Demo puzzle: identity forward pass with a placeholder derivative.

    ``f(o)`` returns ``x[o]``. The returned ``dx(i, o)`` is 1 when ``|o - i|``
    is 1 or 3 and 0 otherwise, which is not the identity's true derivative
    (1 exactly when ``o == i``); presumably it is left as an exercise.
    """
    def f(o):
        return x[o]

    def dx(i, o):
        gap = abs(o - i)
        return (gap < 4) * (gap % 2)  # Fill in this line

    return f, dx
262 |
def in_out2(fb, fb2, pos=None, overlap=False, diff=1, out_shape=50, y=gy):
    """Visualise the chain rule for the composition ``g(f(x))``.

    Draws the stacked plots x, f(x) and g(f(x)). It then shades the
    first-order change of f and of g when the inputs in ``pos`` move by
    ``diff``, and draws derivative wedges x -> f(x) and f(x) -> g(f(x)).
    The numerical estimates returned by ``one_argf`` are discarded, so no
    correctness check is performed. ``overlap`` is accepted but unused.
    """
    set_svg_height(500)
    f_y, dx, _ = one_argf(fb, y, out_shape)
    g_f_y, dxg, _ = one_argf(fb2, f_y, out_shape)
    if pos is None:
        pos = range(y.shape[0])

    # Chain rule: d g / d x = (d g / d f) @ (d f / d x).
    chain = dxg @ dx
    f_shift = dx[:, pos].sum(-1)
    g_shift = chain[:, pos].sum(-1)

    panels = [graph(y, "x").center_xy(),
              graph(f_y, "f(x)").center_xy(),
              graph(g_f_y, "g(f(x))").center_xy()]
    d = vcat(panels, 0.1)
    d = d + overgraph(d, "f(x)", f_y, f_y + diff * f_shift, Color("red"))
    d = d + overgraph(d, "g(f(x))", g_f_y, g_f_y + diff * g_shift, Color("green"))

    d = connect(d, y, "x", f_y, "f(x)", dx, pos, to_line=True)
    touched = [idx.item() for idx in f_shift.nonzero()]
    d = connect(d, f_y, "f(x)", g_f_y, "g(f(x))", chain, touched,
                to_line=True, color=Color("lightgreen"))
    return outer_frame(d)
284 |
285 |
def in_out(fb, pos=None, overlap=False, diff=1, out_shape=50, y=gy):
    """Check and visualise a one-argument puzzle ``fb``.

    Compares the analytic Jacobian against central differences. It draws
    x above f(x), shades the first-order change of f for the inputs in
    ``pos`` (one band per position if ``overlap``, otherwise their sum) and
    connects inputs to outputs, with mismatches highlighted. Prints the first
    10 mismatches, or calls ``show_dog`` if there are none.
    """
    out, dx, dx2 = one_argf(fb, y, out_shape)
    bad, df = check(dx, dx2)

    if pos is None:
        pos = range(y.shape[0])
    set_svg_height(300)
    d = double(y, out)

    if overlap:
        overlays = []
        for p in pos:
            overlays.append((dx[:, p], Color("red")))
            overlays.append((dx2[:, p], Color("green")))
    else:
        overlays = [(dx[:, pos].sum(-1), Color("red")),
                    (dx2[:, pos].sum(-1), Color("lightyellow"))]
    for delta, shade in overlays:
        d = d + overgraph(d, "f(x)", out, out + diff * delta, shade)

    d = connect(d, y, "x", out, "f(x)", dx, pos, to_line=True, bad=bad)
    set_svg_height(300)
    if not bad:
        show_dog()
    else:
        print("Errors")
        display(df[:10])
    return outer_frame(d)
313 |
def zip(fb, split=25, pos1=None, pos2=None, out_shape=25, diff=1, overlap=False, gaps=None, y=gy):
    """Check and visualise a two-argument puzzle ``fb(x, y)``.

    ``y`` is split at ``split`` into the two inputs. The analytic Jacobians
    are compared against central differences, and the inputs are drawn above
    the output. Wedges are drawn from the x cells in ``pos1`` (coloured per
    segment delimited by ``gaps``) and from the y cells in ``pos2``, and the
    first-order changes are shaded. Mismatches are printed (first 10 of
    each), or ``show_dog`` is called if there are none.

    Args:
        gaps: start positions of the x segments; defaults to ``[0]`` (one
            segment). Each segment is the half-open range
            ``[gaps[k], gaps[k + 1])``, and the last one ends at ``len(x)``.

    Note: this function shadows the builtin ``zip`` inside this module; the
    name is kept for backward compatibility.
    """
    if gaps is None:
        gaps = [0]
    x, y2 = y[:split], y[split:]
    out, dx, dy, dx2, dy2 = two_argf(fb, x, y2, out_shape)
    bad_x, df_x = check(dx, dx2)
    bad_y, df_y = check(dy, dy2)
    if pos1 is None:
        pos1 = range(x.shape[0])
    if pos2 is None:
        pos2 = range(y2.shape[0])
    d = two_arg(x, y2, out, gaps)
    bounds = list(gaps) + [x.shape[0]]
    if len(bounds) == 2:
        colors = [ORANGE]
    else:
        colors = list(Color("yellow").range_to("darkorange", len(bounds) - 1))
    for k, c in enumerate(colors):
        # Half-open segment: the first cell of each segment is included
        # (a strict ``<`` previously dropped it, e.g. x cell 0 by default).
        segment = [p for p in pos1 if bounds[k] <= p < bounds[k + 1]]
        d = connect(d, x, "x", out, "f(x)", dx, segment,
                    bad=bad_x, color=c)

    d = connect(d, y2, "y", out, "f(x)", dy, color=Color("lightblue"), i_s=pos2, bad=bad_y)

    if overlap:
        for p in pos1:
            d += overgraph(d, "f(x)", out, out + diff * dx2[:, p], Color("lightyellow"))
            d += overgraph(d, "f(x)", out, out + diff * dx[:, p], Color("darkorange"))
        for p in pos2:
            d += overgraph(d, "f(x)", out, out + diff * dy2[:, p], Color("lightblue"))
            d += overgraph(d, "f(x)", out, out + diff * dy[:, p], Color("blue"))
    else:
        d += overgraph(d, "f(x)", out, out + diff * dx2.sum(-1), Color("lightyellow"))
        d += overgraph(d, "f(x)", out, out + diff * dy2.sum(-1), Color("lightblue"))
        d += overgraph(d, "f(x)", out, out + diff * dx.sum(-1), Color("darkorange"))
        d += overgraph(d, "f(x)", out, out + diff * dy.sum(-1), Color("blue"))
    set_svg_height(300)
    if bad_x:
        print("x Errors")
        display(df_x[:10])
    if bad_y:
        print("y Errors")
        display(df_y[:10])
    if not bad_x and not bad_y:
        show_dog()
    return outer_frame(d)
358 | # def fb_index(x):
359 | # f = lambda o: x[o+5]
360 | # dx = lambda i, o: (o + 25) == i
361 | # return f, dx
362 | # in_out(fb_index, overlap=False, out_shape=25)
363 | # def mat(fb, split, in_shape, out_shape):
364 | # x, y2 = y[:split], y[split:]
365 | # x = x.view(*in_shape)
366 | # f, dx, dy, dx2, dy2 = two_mat_argf(fb, x, y2, out_shape)
367 | # bad_y, df_y = check(dy, dy2)
368 |
369 | # d = vcat([graph(x[i], f"x{i}") for i in range(x.shape[0])], 0.0)
370 | # d = hcat([d.center_xy(), graph(y2, "y")], 0.2)
371 | # d = vcat([d.center_xy(),
372 | # vcat([graph(out[i], f"f(x){i}").center_xy() for i in range(out.shape[0])])], 0.15)
373 | # s = d
374 | # for j in range(out.shape[0]):
375 | # for i in range(x.shape[0]):
376 | # d = connect(d, x[i], f"x{i}", out[j], f"f(x){j}", dx[j, :, i],
377 | # range(x.shape[1]),
378 | # list(Color("red").range_to("orange", x.shape[0]))[i])
379 | # d = connect(d, y2, "y", out[j], f"f(x){j}", dy[j],
380 | # range(y2.shape[0]),
381 | # Color("lightblue"), bad=bad_y)
382 | # if bad_y:
383 | # print("y Errors")
384 | # display(df_y[:10])
385 | # set_svg_height(800)
386 | # return outer_frame(d)
def make_mat(fb, in_shape, out_shape):
    """Adapt a matrix-input puzzle to the flat, index-based interface.

    ``fb(X, y)`` receives ``x`` reshaped to ``in_shape`` and returns
    ``(f, d_x, d_y)`` over 2-D indices. The wrapper ``nf(x, y)`` returns the
    same three functions over flat row-major indices: ``f2(o)``,
    ``d_x2(i, o)`` and ``d_y2(j, o)``.
    """
    def nf(x, y):
        f, d_x, d_y = fb(x.view(in_shape), y)

        def split_in(i):
            return i // in_shape[1], i % in_shape[1]

        def split_out(o):
            return o // out_shape[1], o % out_shape[1]

        def f2(o):
            return f(*split_out(o))

        def d_x2(i, o):
            return d_x(*split_in(i), *split_out(o))

        def d_y2(j, o):
            return d_y(j, *split_out(o))

        return f2, d_x2, d_y2

    return nf
398 |
def make_mat2(fb, in_shape, in_shape2, out_shape):
    """Adapt a puzzle with two matrix inputs to the flat, index-based interface.

    ``fb(X, Y)`` receives ``x`` reshaped to ``in_shape`` and ``y`` reshaped
    to ``in_shape2``, and returns ``(f, d_x, d_y)`` over 2-D indices. The
    wrapper ``nf(x, y)`` returns them over flat row-major indices:
    ``f2(o)``, ``d_x2(i, o)`` and ``d_y2(j, o)``.
    """
    def nf(x, y):
        f, d_x, d_y = fb(x.view(in_shape), y.view(in_shape2))

        def split_x(i):
            return i // in_shape[1], i % in_shape[1]

        def split_y(j):
            return j // in_shape2[1], j % in_shape2[1]

        def split_out(o):
            return o // out_shape[1], o % out_shape[1]

        def f2(o):
            return f(*split_out(o))

        def d_x2(i, o):
            return d_x(*split_x(i), *split_out(o))

        def d_y2(j, o):
            return d_y(*split_y(j), *split_out(o))

        return f2, d_x2, d_y2

    return nf
410 |
--------------------------------------------------------------------------------