├── .gitignore ├── LICENSE ├── README.md ├── autodiff_puzzlers.ipynb ├── autodiff_puzzlers_files ├── autodiff_puzzlers_10_2.svg ├── autodiff_puzzlers_11_0.svg ├── autodiff_puzzlers_12_0.svg ├── autodiff_puzzlers_13_2.svg ├── autodiff_puzzlers_14_2.svg ├── autodiff_puzzlers_15_2.svg ├── autodiff_puzzlers_16_2.svg ├── autodiff_puzzlers_17_2.svg ├── autodiff_puzzlers_18_2.svg ├── autodiff_puzzlers_19_2.svg ├── autodiff_puzzlers_20_2.svg ├── autodiff_puzzlers_21_2.svg ├── autodiff_puzzlers_22_2.svg ├── autodiff_puzzlers_23_2.svg ├── autodiff_puzzlers_24_2.svg ├── autodiff_puzzlers_25_2.svg ├── autodiff_puzzlers_26_2.svg ├── autodiff_puzzlers_27_2.svg ├── autodiff_puzzlers_28_2.svg ├── autodiff_puzzlers_29_4.svg ├── autodiff_puzzlers_30_4.svg ├── autodiff_puzzlers_31_4.svg ├── autodiff_puzzlers_32_4.svg ├── autodiff_puzzlers_33_4.svg ├── autodiff_puzzlers_34_4.svg ├── autodiff_puzzlers_36_2.svg ├── autodiff_puzzlers_37_2.svg ├── autodiff_puzzlers_38_2.svg ├── autodiff_puzzlers_39_2.svg ├── autodiff_puzzlers_40_4.svg ├── autodiff_puzzlers_41_4.svg ├── autodiff_puzzlers_42_2.svg ├── autodiff_puzzlers_43_2.svg ├── autodiff_puzzlers_44_4.svg ├── autodiff_puzzlers_45_4.svg ├── autodiff_puzzlers_46_4.svg ├── autodiff_puzzlers_47_4.svg ├── autodiff_puzzlers_5_2.svg ├── autodiff_puzzlers_6_2.svg ├── autodiff_puzzlers_7_2.svg ├── autodiff_puzzlers_8_2.svg └── autodiff_puzzlers_9_2.svg ├── image.png └── lib.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files 
are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Sasha Rush 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Autodiff Puzzles 2 | - by [Sasha Rush](http://rush-nlp.com) - [srush_nlp](https://twitter.com/srush_nlp) 3 | 4 | **Click here to get started:** 5 | 6 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/Autodiff-Puzzles/blob/main/autodiff_puzzlers.ipynb) 7 | 8 | This notebook contains a series of self-contained puzzles for learning about derivatives in tensor libraries. It is the 3rd puzzle set in a series of puzzles about programming for deep learning ([Tensor Puzzles](https://github.com/srush/Tensor-Puzzles), [GPU Puzzles](https://github.com/srush/GPU-Puzzles)). 9 | 10 | ![image](https://github.com/user-attachments/assets/be04dea0-a15c-4d7d-b2d5-263c171d0f35) 11 | 12 | 13 | Your goal in these puzzles is to implement the derivatives for each core function. In each case the function takes in a tensor x and returns a tensor f(x), so your job is to compute $\frac{d f(x)_o}{dx_i}$ for all indices $o$ and $i$. If you get discouraged, just remember, you did this in high school (it just had way less indexing). 
14 | -------------------------------------------------------------------------------- /autodiff_puzzlers_files/autodiff_puzzlers_10_2.svg: -------------------------------------------------------------------------------- 1 | 2 | xf(x)xf(x) -------------------------------------------------------------------------------- /autodiff_puzzlers_files/autodiff_puzzlers_11_0.svg: -------------------------------------------------------------------------------- 1 | 2 | xf(x)xf(x) -------------------------------------------------------------------------------- /autodiff_puzzlers_files/autodiff_puzzlers_12_0.svg: -------------------------------------------------------------------------------- 1 | 2 | xf(x)xf(x) -------------------------------------------------------------------------------- /autodiff_puzzlers_files/autodiff_puzzlers_21_2.svg: -------------------------------------------------------------------------------- 1 | 2 | xf(x)xf(x) -------------------------------------------------------------------------------- /autodiff_puzzlers_files/autodiff_puzzlers_5_2.svg: -------------------------------------------------------------------------------- 1 | 2 | xf(x)xf(x) -------------------------------------------------------------------------------- /autodiff_puzzlers_files/autodiff_puzzlers_6_2.svg: -------------------------------------------------------------------------------- 1 | 2 | xf(x)xf(x) -------------------------------------------------------------------------------- /autodiff_puzzlers_files/autodiff_puzzlers_7_2.svg: -------------------------------------------------------------------------------- 1 | 2 | xf(x)xf(x) -------------------------------------------------------------------------------- /autodiff_puzzlers_files/autodiff_puzzlers_8_2.svg: -------------------------------------------------------------------------------- 1 | 2 | xf(x)xf(x) -------------------------------------------------------------------------------- 
# --- repository-dump entries that shared these lines with lib.py ---
# /autodiff_puzzlers_files/autodiff_puzzlers_9_2.svg: xf(x)xf(x)  (SVG markup lost in extraction)
# /image.png: https://raw.githubusercontent.com/srush/Autodiff-Puzzles/ae4586b052f3328031ad64024edabbf6fbd2e4bb/image.png
# /lib.py:
"""Drawing and checking helpers for the Autodiff Puzzles notebook.

A puzzle is a function ``fb`` that returns index-level callables: ``f(o)`` for
the output value at index ``o`` and ``dx(i, o)`` for the derivative of output
``o`` with respect to input ``i``.  The helpers here tabulate those callables
into tensors, compare the user's derivative against a central finite
difference, and draw the result with ``chalk``.

NOTE: this module defines functions named ``zip``, ``m`` and ``v``; inside this
file those names refer to the puzzle helpers, not to the builtin ``zip``.
"""
import chalk
# The star import stays first so that every explicit import below wins over
# any same-named symbol chalk happens to export.
from chalk import *
import torch
import random
import math  # used by `gy`; previously only available if leaked by the star import
from colour import Color

import pandas as pd
import sys
sys.setrecursionlimit(10000)
set_svg_height(400)
set_svg_draw_height(400)
from IPython.display import display, HTML


def show_dog():
    """Print a success message and display a randomly chosen puppy video."""
    print("Passed Tests!")
    pups = [
        "2m78jPG", "pn1e9TO", "MQCIwzT", "udLK6FS", "ZNem5o3", "DS2IZ6K",
        "aydRUz8", "MVUdQYK", "kLvno0p", "wScLiVz", "Z0TII8i", "F1SChho",
        "9hRi2jN", "lvzRF3W", "fqHxOGI", "1xeUYme", "6tVqKyM", "CCxZ6Wr",
        "lMW0OPQ", "wHVpHVG", "Wj2PGRl", "HlaTE8H", "k5jALH0", "3V37Hqr",
        "Eq2uMTA", "Vy9JShx", "g9I2ZmK", "Nu4RH7f", "sWp0Dqd", "bRKfspn",
        "qawCMl5", "2F6j2B4", "fiJxCVA", "pCAIlxD", "zJx2skh", "2Gdl1u7",
        "aJJAY4c", "ros6RLC", "DKLBJh7", "eyxH0Wc", "rJEkEw4",
    ]
    # NOTE(review): the markup inside this literal was stripped from the copy
    # of the file this was recovered from, leaving a `%` format with no
    # placeholder (a TypeError at runtime).  It is reconstructed here as a
    # video tag -- confirm the URL pattern against the upstream repository.
    display(HTML("""
    <video alt="test" controls autoplay=1>
        <source src="https://openpuppies.com/mp4/%s.mp4"  type="video/mp4">
    </video>
    """ % random.choice(pups)))


def line(y, rev=False):
    """Return step-plot points for the 1-D tensor ``y``.

    Sample ``i`` contributes the two points ``(i, h)`` and ``(i + 1, h)`` with
    ``h = -2 * y[i]`` (negated and doubled to match the [-2, 2] axis drawn by
    ``graph``).  ``rev`` is accepted for compatibility but is not used.
    """
    return [(i + j, -2 * float(y[i]))
            for i in range(y.shape[0])
            for j in range(2)]


def graph(y, name, splits=None):
    """Build a framed, named diagram plotting the 1-D tensor ``y``.

    Args:
        y: values to plot as a step line.
        name: diagram name (also used as the frame label); later lookups use
            ``Name(name)`` to find this plot inside a larger diagram.
        splits: optional x positions at which to draw vertical separators.
            ``None`` (the default) draws none.
    """
    top = make_path([(0, -2), (0, 2)])
    top += make_path([(0, 0), (y.shape[0], 0)]).line_width(0.05).line_color(Color("grey"))
    top += make_path(line(y)).line_width(0.2)
    # `two_arg` forwards its own `gaps=None` default here, so None must mean
    # "no splits" rather than crash the loop.
    for s in splits or ():
        top += make_path([(s, -2), (s, 2)])
    top = top.named(Name(name))
    top = frame(name, top, Color("#EEEEEE"), name, h=1.5) + top
    top = top.scale_uniform_to_x((y.shape[0] + 15) / 100).align_tl()
    return top


def overgraph(d, name, y, y2, color):
    """Return a shaded band between ``y`` and ``y2`` placed over plot ``name``.

    Both curves are clipped to [-1, 1] first.  The band is filled with
    ``color`` at 50% opacity and positioned using the transform of the
    sub-diagram called ``name`` inside ``d``.
    """
    one = torch.tensor(1)
    y = torch.maximum(torch.minimum(one, y), -one)
    y2 = torch.maximum(torch.minimum(one, y2), -one)

    bot2 = make_path([(0, -2), (0, 2)])

    # Walk along y2 and back along y to close the region between the curves.
    pts = line(y2) + list(reversed(line(y)))
    bot2 += make_path(pts, True)
    bot2 = bot2.line_color(color).fill_opacity(0.5).fill_color(color).line_width(0.1)
    scale, trans = get_transform(name, d, y)
    old = make_path(line(y))
    return (bot2 + old).scale_x(scale[0]).scale_y(scale[1]).translate_by(trans)


def frame(name, d, c, l="", w=15, h=1.5):
    """Put a filled background rectangle (and optional label) behind ``d``.

    The rectangle is sized from the envelope of the sub-diagram ``name``:
    its width plus ``w`` and its height times ``h``.  ``c`` is the fill
    colour and ``l`` an optional label drawn at the left edge.
    """
    s = d.get_subdiagram(Name(name))
    env = s.get_envelope()
    r = rectangle(env.width + w, env.height * h).fill_color(c).line_width(0)
    if l:
        r = r.align_l() + (hstrut(w/4) | text(l, 2.5).fill_color(Color("black")).line_width(0)).align_l()
    return r.center_xy().translate_by(env.center) + d


def get_transform(name, d, v):
    """Return ``((x_step, y_step), location)`` for plot ``name`` inside ``d``.

    ``x_step`` is the plot width per element of ``v``; ``y_step`` is a quarter
    of the plot height (the axis drawn by ``graph`` spans 4 units).
    """
    s = d.get_subdiagram(Name(name))
    e = s.get_envelope()
    x = e.width / v.shape[0]
    y = e.height / (2 * 2)
    l = s.get_location()
    return (x, y), l


def get_locations(name, d, x, y, v):
    """Map data coordinates ``(x, y)`` of plot ``name`` to diagram coordinates."""
    (x_step, y_step), l = get_transform(name, d, v)
    return (l[0] + x_step * x,
            l[1] - y_step * 2 * y)


ORANGE = Color("orange")


def connect(d, x, x_name, f_x, f_x_name, q,
            i_s=None,
            color=ORANGE, to_line=False, bad=None):
    """Draw a translucent wedge for every non-zero derivative entry.

    Args:
        d: diagram containing the plots named ``x_name`` and ``f_x_name``.
        x, f_x: input and output value tensors of those plots.
        q: derivative matrix indexed ``q[out, in]``.
        i_s: input indices to draw; defaults to every index of ``x``.
        color: wedge fill colour.
        to_line: if true the wedge lands on the plotted output curve,
            otherwise on the constant height 1.
        bad: optional ``{in_index: [out_index, ...]}`` of entries to draw in
            red (as produced by ``check``).

    Returns:
        ``d`` with the wedges added.  Opacity scales with ``|q| / max|q|``.
    """
    bad = {} if bad is None else bad
    offset = 0.1
    x1, y1 = get_locations(x_name, d,
                           torch.arange(x.shape[0]+1),
                           torch.zeros(x.shape[0]+1) - offset, x)

    if to_line:
        land = f_x
    else:
        # Was `f_x / f_x`, which is NaN wherever f_x == 0.
        land = torch.ones_like(f_x)
    land = torch.cat([land, torch.tensor([0, 0])], 0)
    x2, y2 = get_locations(f_x_name, d,
                           torch.arange(f_x.shape[0]+2),
                           land + offset, f_x)
    if i_s is None:
        i_s = range(x.shape[0])
    peak = q.abs().max()
    #q = torch.where(q.abs() / m > 0.1, q, 0)
    for i in i_s:
        for j in q[:, i].nonzero()[:, 0]:
            if q[j, i] != 0 and i < x1.shape[0] - 1 and j < y2.shape[0] - 2:
                p = make_path([((x1[i] + x1[i+1])/2, y1[i]),
                               (x2[j], (y2[j] + y2[j-1]) / 2),
                               (x2[j+1], (y2[j] + y2[j+1]) / 2)], True)
                fill = Color("red") if j in bad.get(i, []) else color
                p = p.line_width(0).fill_color(fill)\
                     .fill_opacity(0.3 * abs(q[j, i] / peak))

                d += p
    return d


def diff_graph(d, f_x, f_x_name, d_f_x, c, amount=5):
    """Return a band showing ``f_x`` moved by ``amount * d_f_x`` over plot ``f_x_name``.

    Previously this passed an extra positional argument to ``overgraph``
    (always a TypeError), ignored ``amount`` and returned nothing.
    """
    updated = f_x + amount * d_f_x
    return overgraph(d, f_x_name, f_x, updated, c)


def double(y, f_y):
    """Stack the plots of ``y`` (named "x") and ``f_y`` (named "f(x)") vertically."""
    d = vcat([graph(y, "x").center_xy(),
              graph(f_y, "f(x)").center_xy()], 0.1)
    return d


def two_arg(x, y, f_xy, gaps=None):
    """Lay out inputs "x" and "y" side by side above the output plot "f(x)".

    ``gaps`` are optional separator positions drawn on the "x" plot.
    """
    return vcat([hcat([graph(x, "x", gaps), graph(y, "y")], 0.03).center_xy(),
                 graph(f_xy, "f(x)").center_xy()], 0.1)


def outer_frame(d):
    """Wrap the whole diagram ``d`` in a white background frame."""
    return frame("full", d.named(Name("full")), Color("white"), w=0.1, h=1.2) + d


def m(inp, out, f):
    """Tabulate ``f(i, o)`` into a tensor indexed ``[o, i]`` (1-D in, 1-D out)."""
    ret = torch.zeros(inp.shape[0], out.shape[0])
    for i in range(ret.shape[0]):
        for j in range(ret.shape[1]):
            ret[i, j] = f(i, j)
    return ret.T


def m2(inp, out, f):
    """Tabulate ``f(i, o0, o1)`` into a tensor indexed ``[o0, o1, i]``."""
    ret = torch.zeros(inp.shape[0], out.shape[0], out.shape[1])
    for i in range(ret.shape[0]):
        for j in range(ret.shape[1]):
            for k in range(ret.shape[2]):
                ret[i, j, k] = f(i, j, k)
    return ret.permute(1, 2, 0)


def m3(inp, out, f):
    """Tabulate ``f(i0, i1, o0, o1)`` into a tensor indexed ``[o0, o1, i0, i1]``."""
    ret = torch.zeros(inp.shape[0], inp.shape[1], out.shape[0], out.shape[1])
    for i in range(ret.shape[0]):
        for j in range(ret.shape[1]):
            for k in range(ret.shape[2]):
                for l in range(ret.shape[3]):
                    ret[i, j, k, l] = f(i, j, k, l)
    return ret.permute(2, 3, 0, 1)


def v(s, f):
    """Tabulate ``f(i)`` for ``i in range(s)`` into a 1-D tensor."""
    ret = torch.zeros(s)
    for i in range(s):
        ret[i] = f(i)
    return ret


def v2(s, f):
    """Tabulate ``f(i, j)`` into a 2-D tensor of shape ``s``."""
    ret = torch.zeros(*s)
    for i in range(s[0]):
        for j in range(s[1]):
            ret[i, j] = f(i, j)
    return ret


def numerical_deriv(fb, x, out):
    """Central finite-difference derivative of ``fb`` at ``x``.

    Args:
        fb: maps an input tensor to a callable ``f(o)`` giving output ``o``.
        x: 1-D input tensor (perturbed with a double-precision basis vector).
        out: tensor whose first dimension gives the number of outputs.

    Returns:
        Tensor indexed ``[out, in]`` with
        ``(f(x + eps*e_i)(o) - f(x - eps*e_i)(o)) / (2 * eps)``, ``eps = 1e-5``.
    """
    eps = 1e-5
    s = x.shape[0]
    basis = torch.eye(s).double()  # built once instead of twice per input
    dx2 = torch.zeros(out.shape[0], s)
    for i in range(s):
        f1 = fb(x + eps * basis[i])
        f2 = fb(x - eps * basis[i])
        for j in range(out.shape[0]):
            dx2[j, i] = (f1(j) - f2(j)) / (2 * eps)
    return dx2


def two_argf(fb, x, y, out_shape):
    """Evaluate a two-input puzzle.

    Returns ``(out, dx, dy, dx2, dy2)``: the tabulated output, the user's
    derivatives w.r.t. ``x`` and ``y``, and their finite-difference references.
    """
    f, dx, dy = fb(x, y)
    out = v(out_shape, f)
    dx = m(x, out, dx)
    dy = m(y, out, dy)
    dx2 = numerical_deriv(lambda a: fb(a, y)[0], x, out)
    dy2 = numerical_deriv(lambda b: fb(x, b)[0], y, out)
    return out, dx, dy, dx2, dy2


def one_argf(fb, x, out_shape):
    """Evaluate a one-input puzzle.

    Returns ``(out, dx, dx2)``: the tabulated output, the user's derivative
    and its finite-difference reference.
    """
    f, dx = fb(x)
    out = v(out_shape, f)
    dx = m(x, out, dx)
    dx2 = numerical_deriv(lambda a: fb(a)[0], x, out)
    return out, dx, dx2

# def two_mat_argf(fb, x, y, out_shape, in_shape):
#     f, dx, dy = fb(x, y2)
#     out = v2(out_shape, f)
#     dx = m3(x, out, dx)
#     dy = m2(y2, out, dy)
#     dx2 = []
#     dx2 = numerical_deriv(lambda a: lambda v: fb(a, y)[0](v // out_shape[0], v % out_shape[0]), x.view(-1), out.view(-1), in_shape=in_shape)
#     dy2 = numerical_deriv(lambda b: fb(x, b)[0], y.view(-1), out.view(-1))
#     return out, dx, dy, dx2, dy2


def check(dx, dx2):
    """Find entries where ``dx`` and ``dx2`` differ by more than ``atol=1e-4``.

    Both tensors are indexed ``[out, in]``.

    Returns:
        ``(bad, df)`` where ``bad`` maps each input index to the list of
        mismatching output indices and ``df`` is a DataFrame with columns
        "In Index" / "Out Index" (empty when everything matches).
    """
    bad = {}
    df = []
    for j, i in (~torch.isclose(dx, dx2, atol=1e-4)).nonzero():
        #print(i.item(), j.item(), dx[i,j].item(), dx2[i,j].item())
        bad.setdefault(i.item(), []).append(j.item())
        df.append({"In Index": i.item(), "Out Index": j.item()})
    return bad, pd.DataFrame(df)


# Default 50-sample input signal: a slow sine wave plus uniform noise.
gy = torch.tensor([math.sin(x/20) * 0.5 + (random.random() - 0.5)
                   for x in range(50)]).double()


def fb_demo(x):
    """Example puzzle: identity output with a banded demo derivative."""
    f = lambda o: x[o]
    dx = lambda i, o: (abs(o-i) < 4) * (abs(o-i) % 2)  # Fill in this line
    return f, dx


def in_out2(fb, fb2, pos=None, overlap=False, diff=1, out_shape=50, y=gy):
    """For functions with point samples: draw the composition ``fb2(fb(y))``.

    Plots "x", "f(x)" and "g(f(x))" and shades how perturbing the inputs in
    ``pos`` (default: all) by ``diff`` moves each stage, using ``dx`` for the
    first stage and the chain-rule product ``dxg @ dx`` for the second.
    ``overlap`` is accepted but unused.  No correctness check is performed.
    """
    set_svg_height(500)
    f_y, dx, _ = one_argf(fb, y, out_shape)
    g_f_y, dxg, _ = one_argf(fb2, f_y, out_shape)

    if pos is None:
        pos = range(y.shape[0])
    d = vcat([graph(y, "x").center_xy(),
              graph(f_y, "f(x)").center_xy(),
              graph(g_f_y, "g(f(x))").center_xy(),
              ], 0.1)

    d += overgraph(d, "f(x)", f_y, f_y + diff * dx[:, pos].sum(-1), Color("red"))
    d += overgraph(d, "g(f(x))", g_f_y, g_f_y + diff * (dxg @ dx)[:, pos].sum(-1), Color("green"))

    d = connect(d, y, "x", f_y, "f(x)", dx, pos, to_line=True)
    d = connect(d, f_y, "f(x)", g_f_y, "g(f(x))", dxg @ dx,
                [i.item() for i in dx[:, pos].sum(-1).nonzero()], to_line=True, color=Color("lightgreen"))

    return outer_frame(d)


def in_out(fb, pos=None, overlap=False, diff=1, out_shape=50, y=gy):
    """For functions with point samples: check and draw a one-input puzzle.

    Args:
        fb: puzzle function returning ``(f, dx)`` callables.
        pos: input indices to visualise (default: all).
        overlap: draw one band per index in ``pos`` instead of their sum.
        diff: scale applied to the derivative when shading.
        out_shape: number of outputs.
        y: input signal.

    Prints the first ten mismatches against the finite-difference derivative,
    or shows a puppy when there are none, and returns the framed diagram.
    """
    out, dx, dx2 = one_argf(fb, y, out_shape)
    bad, df = check(dx, dx2)

    if pos is None:
        pos = range(y.shape[0])
    set_svg_height(300)
    d = double(y, out)
    if overlap:
        for p in pos:
            d += overgraph(d, "f(x)", out, out + diff * dx[:, p], Color("red"))
            d += overgraph(d, "f(x)", out, out + diff * dx2[:, p], Color("green"))
    else:
        d += overgraph(d, "f(x)", out, out + diff * dx[:, pos].sum(-1), Color("red"))
        d += overgraph(d, "f(x)", out, out + diff * dx2[:, pos].sum(-1), Color("lightyellow"))

    d = connect(d, y, "x", out, "f(x)", dx,
                pos, to_line=True, bad=bad)
    if bad:
        print("Errors")
        display(df[:10])
    else:
        show_dog()
    return outer_frame(d)


def zip(fb, split=25, pos1=None, pos2=None, out_shape=25, diff=1, overlap=False, gaps=None, y=gy):
    """Check and draw a two-input puzzle (``y`` is split into x and y halves).

    Args:
        fb: puzzle function returning ``(f, dx, dy)`` callables.
        split: index at which ``y`` is cut into ``x = y[:split]`` and
            ``y2 = y[split:]``.
        pos1, pos2: indices of ``x`` / ``y2`` to visualise (default: all).
        out_shape: number of outputs.
        diff: scale applied to the derivatives when shading.
        overlap: draw one band per index instead of the summed derivative.
        gaps: segment start positions within ``x``; each segment
            ``[gaps[k], gaps[k+1])`` gets its own connection colour.
            Defaults to ``[0]`` (a single segment).
        y: input signal.

    Prints mismatches for either derivative (or shows a puppy when both are
    correct) and returns the framed diagram.
    """
    gaps = [0] if gaps is None else list(gaps)
    x, y2 = y[:split], y[split:]
    out, dx, dy, dx2, dy2 = two_argf(fb, x, y2, out_shape)
    bad_x, df_x = check(dx, dx2)
    bad_y, df_y = check(dy, dy2)
    if pos1 is None:
        pos1 = range(x.shape[0])
    if pos2 is None:
        pos2 = range(y2.shape[0])
    d = two_arg(x, y2, out, gaps)
    gaps = gaps + [x.shape[0]]
    if len(gaps) == 2:
        colors = [ORANGE]
    else:
        colors = list(Color("yellow").range_to("darkorange", len(gaps)-1))
    for k, c in enumerate(colors):
        # Left bound is inclusive: with a strict `<` every index equal to a
        # segment start (including input 0) belonged to no segment and was
        # silently never drawn.
        d = connect(d, x, "x", out, "f(x)", dx,
                    [p for p in pos1 if gaps[k] <= p < gaps[k+1]],
                    bad=bad_x, color=c)

    d = connect(d, y2, "y", out, "f(x)", dy, color=Color("lightblue"), i_s=pos2, bad=bad_y)

    if overlap:
        for p in pos1:
            d += overgraph(d, "f(x)", out, out + diff * dx2[:, p], Color("lightyellow"))
            d += overgraph(d, "f(x)", out, out + diff * dx[:, p], Color("darkorange"))
        for p in pos2:
            d += overgraph(d, "f(x)", out, out + diff * dy2[:, p], Color("lightblue"))
            d += overgraph(d, "f(x)", out, out + diff * dy[:, p], Color("blue"))
    else:
        d += overgraph(d, "f(x)", out, out + diff * dx2.sum(-1), Color("lightyellow"))
        d += overgraph(d, "f(x)", out, out + diff * dy2.sum(-1), Color("lightblue"))
        d += overgraph(d, "f(x)", out, out + diff * dx.sum(-1), Color("darkorange"))
        d += overgraph(d, "f(x)", out, out + diff * dy.sum(-1), Color("blue"))
    set_svg_height(300)
    if bad_x:
        print("x Errors")
        display(df_x[:10])
    if bad_y:
        print("y Errors")
        display(df_y[:10])
    if not bad_x and not bad_y:
        show_dog()
    return outer_frame(d)

# def fb_index(x):
#     f = lambda o: x[o+5]
#     dx = lambda i, o: (o + 25) == i
#     return f, dx
# in_out(fb_index, overlap=False, out_shape=25)
# def mat(fb, split, in_shape, out_shape):
#     x, y2 = y[:split], y[split:]
#     x = x.view(*in_shape)
#     f, dx, dy, dx2, dy2 = two_mat_argf(fb, x, y2, out_shape)
#     bad_y, df_y = check(dy, dy2)

#     d = vcat([graph(x[i], f"x{i}") for i in range(x.shape[0])], 0.0)
#     d = hcat([d.center_xy(), graph(y2, "y")], 0.2)
#     d = vcat([d.center_xy(),
#               vcat([graph(out[i], f"f(x){i}").center_xy() for i in range(out.shape[0])])], 0.15)
#     s = d
#     for j in range(out.shape[0]):
#         for i in range(x.shape[0]):
#             d = connect(d, x[i], f"x{i}", out[j], f"f(x){j}", dx[j, :, i],
#                         range(x.shape[1]),
#                         list(Color("red").range_to("orange", x.shape[0]))[i])
#         d = connect(d, y2, "y", out[j], f"f(x){j}", dy[j],
#                     range(y2.shape[0]),
#                     Color("lightblue"), bad=bad_y)
#     if bad_y:
#         print("y Errors")
#         display(df_y[:10])
#     set_svg_height(800)
#     return outer_frame(d)


def make_mat(fb, in_shape, out_shape):
    """Adapt a (matrix x, vector y) -> matrix puzzle to the flat 1-D interface.

    The returned function takes flat ``x`` and ``y``, views ``x`` as
    ``in_shape``, and exposes ``f``/``d_x``/``d_y`` with flat row-major
    indices (``idx // cols``, ``idx % cols``) so it can be passed to ``zip``.
    """
    def nf(x, y):
        f, d_x, d_y = fb(x.view(in_shape), y)

        def f2(o):
            return f(o // out_shape[1], o % out_shape[1])

        def d_x2(i, o):
            return d_x(i // in_shape[1], i % in_shape[1], o // out_shape[1], o % out_shape[1])

        def d_y2(j, o):
            return d_y(j, o // out_shape[1], o % out_shape[1])

        return f2, d_x2, d_y2
    return nf


def make_mat2(fb, in_shape, in_shape2, out_shape):
    """Adapt a (matrix x, matrix y) -> matrix puzzle to the flat 1-D interface.

    Like ``make_mat`` but ``y`` is also viewed as a matrix of ``in_shape2``
    and its derivative is indexed with a flat row-major index.
    """
    def nf(x, y):
        f, d_x, d_y = fb(x.view(in_shape), y.view(in_shape2))

        def f2(o):
            return f(o // out_shape[1], o % out_shape[1])

        def d_x2(i, o):
            return d_x(i // in_shape[1], i % in_shape[1], o // out_shape[1], o % out_shape[1])

        def d_y2(j, o):
            return d_y(j // in_shape2[1], j % in_shape2[1], o // out_shape[1], o % out_shape[1])

        return f2, d_x2, d_y2
    return nf