├── .gitignore
├── 0_intro_material
│   ├── Lab_01_Intro_Colab.ipynb
│   ├── Lab_02_Intro_JAX.ipynb
│   ├── Lab_03_Intro_Numpy.ipynb
│   ├── Lab_04_Intro_Plotting.ipynb
│   └── README.md
├── 1_vision
│   ├── ComputerVisionPart1.ipynb
│   ├── ComputerVisionPart1_solution.ipynb
│   ├── ComputerVisionPart2.ipynb
│   ├── ComputerVisionPart2_solution.ipynb
│   ├── ComputerVisionPart3.ipynb
│   ├── ComputerVisionPart3_solution.ipynb
│   └── README.md
├── 2_nlp
│   ├── NLP_tutorial.ipynb
│   ├── NLP_tutorial_solutions.ipynb
│   └── README.md
├── 3_generative
│   ├── README.md
│   ├── VAE_Tutorial_Solutions.ipynb
│   └── VAE_Tutorial_Start.ipynb
├── 4_rl
│   ├── README.md
│   ├── RL_Tutorial.ipynb
│   └── RL_Tutorial_solutions.ipynb
├── LICENSE
├── README.md
└── assets
    ├── vision_part1_im_rotate.jpg
    └── vision_part1_pixel_permutation.jpg
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
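
Note: the character-class pattern `*.py[cod]` in the .gitignore above uses glob syntax. As a rough illustration (not full gitignore semantics — git additionally handles directory suffixes like `build/`, negation, and `**`), Python's `fnmatch` evaluates the same class syntax:

```python
import fnmatch

# "*.py[cod]" matches byte-compiled files: .pyc, .pyo and .pyd.
pattern = "*.py[cod]"
candidates = ["module.py", "module.pyc", "module.pyo", "module.pyd", "module.pyx"]
matches = [name for name in candidates if fnmatch.fnmatch(name, pattern)]
print(matches)  # ['module.pyc', 'module.pyo', 'module.pyd']
```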
/0_intro_material/Lab_01_Intro_Colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Lab_01_Intro_Colab.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "toc_visible": true
10 | },
11 | "kernelspec": {
12 | "display_name": "Python 3",
13 | "language": "python",
14 | "name": "python3"
15 | },
16 | "language_info": {
17 | "codemirror_mode": {
18 | "name": "ipython",
19 | "version": 3
20 | },
21 | "file_extension": ".py",
22 | "mimetype": "text/x-python",
23 | "name": "python",
24 | "nbconvert_exporter": "python",
25 | "pygments_lexer": "ipython3",
26 | "version": "3.8.5"
27 | }
28 | },
29 | "cells": [
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "IBAupUTp-1ao"
34 | },
35 | "source": [
36 | "# What is Google Colab?\n",
37 | "\n",
38 | "[Colaboratory](https://colab.sandbox.google.com/notebooks/welcome.ipynb) is a [Jupyter](http://jupyter.org/) notebook environment that requires no setup to use. It allows you to create and share documents that contain:\n",
39 | "\n",
40 | "* Live, runnable code\n",
41 | "* Visualizations\n",
42 | "* Explanatory text\n",
43 | "\n",
44 | "It's also a great tool for prototyping and quick development. Let's give it a try. "
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "metadata": {
50 | "id": "Cg9or9RcAG0-"
51 | },
52 | "source": [
53 | "Run the following so-called *(Code) Cell* by moving the cursor into it, and either\n",
54 | "\n",
55 | "* Pressing the \"play\" icon on the left of the cell, or\n",
56 | "* Hitting **`Shift + Enter`**."
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "metadata": {
62 | "id": "qyNrs_gC-2JU"
63 | },
64 | "source": [
65 | "print('Hello, M2L!')"
66 | ],
67 | "execution_count": null,
68 | "outputs": []
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {
73 | "id": "8RGeFACL-2Ps"
74 | },
75 | "source": [
 76 |         "You should see `Hello, M2L!` printed under the code."
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {
82 | "id": "7UwtwzObZbKt"
83 | },
84 | "source": [
85 | "The code is executed on a virtual machine dedicated to your account, with the results sent back to your browser. This has some positive and negative consequences."
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {
91 | "id": "-iBdZmzvbPAh"
92 | },
93 | "source": [
94 | "## Using a GPU\n",
95 | "\n",
96 | "You can connect to a virtual machine with a GPU. To select the hardware you want to use, follow either\n",
97 | "\n",
98 | "* **Edit > Notebook settings**, or\n",
99 | "* **Runtime > Change runtime type**\n",
100 | "\n",
101 | "and choose an accelerator.\n",
102 | "\n",
103 | "Note: since January 2020, only **Python 3** is supported in the Google Colab environment."
104 | ]
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "metadata": {
109 | "id": "zxRZTbjtbPco"
110 | },
111 | "source": [
112 | "## Losing Connection\n",
113 | "\n",
114 | "You may lose connection to your virtual machine. The two most common causes are:\n",
115 | "\n",
116 | "* Virtual machines are recycled when idle for a while, and have a maximum lifetime enforced by the system.\n",
117 | "* Long-running background computations, particularly on GPUs, may be stopped.\n",
118 | "\n",
119 | "**If you lose connection**, the state of your notebook will also be lost. You will need to **rerun all cells** up to the one you are currently working on. To do so\n",
120 | "\n",
121 | "1. Select (place the cursor into) the cell you are working on. \n",
122 | "2. Follow **Runtime > Run before**."
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "metadata": {
128 | "id": "kUhiagDUC28G"
129 | },
130 | "source": [
131 | "## Pretty Printing by colab\n",
132 | "1) If the **last operation** of a given cell returns a value, it will be pretty printed by colab.\n"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "metadata": {
138 | "id": "Hz9y44oYDpXI"
139 | },
140 | "source": [
141 | "6 * 7"
142 | ],
143 | "execution_count": null,
144 | "outputs": []
145 | },
146 | {
147 | "cell_type": "code",
148 | "metadata": {
149 | "id": "YqgK5BvLFaU3"
150 | },
151 | "source": [
152 |         "my_dict = {'one': 1, 'some set': {4, 2, 2}, 'a range': range(5)}"
153 | ],
154 | "execution_count": null,
155 | "outputs": []
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "metadata": {
160 | "id": "1Rqgt29CFb4J"
161 | },
162 | "source": [
163 | "There is no output from the second cell, as assignment does not return anything."
164 | ]
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "metadata": {
169 | "id": "ZgekhVgsC3Bu"
170 | },
171 | "source": [
172 | "2) You can explicitly **print** anything before the last operation, or **suppress** the output of the last operation by adding a semicolon."
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "metadata": {
178 | "id": "7VKvo7dvDxBy"
179 | },
180 | "source": [
181 | "print(my_dict)\n",
182 | "my_dict['one'] * 10 + 1"
183 | ],
184 | "execution_count": null,
185 | "outputs": []
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "metadata": {
190 | "id": "z2cu-vPJJL2o"
191 | },
192 | "source": [
193 | "## Scoping and Execution Model"
194 | ]
195 | },
196 | {
197 | "cell_type": "markdown",
198 | "metadata": {
199 | "id": "nKnGjUJvI5Gg"
200 | },
201 | "source": [
202 | "Notice that in the previous code cell we worked with `my_dict`, while it was defined in an even earlier cell.\n",
203 | "\n",
204 | "1) In colabs, variables defined at cell root have **global** scope.\n",
205 | "\n",
206 | "Modify `my_dict`:"
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "metadata": {
212 | "id": "XETX0TevKW_h"
213 | },
214 | "source": [
215 | "my_dict['I\\'ve been changed!'] = True"
216 | ],
217 | "execution_count": null,
218 | "outputs": []
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "metadata": {
223 | "id": "FMPTBE_HDxIH"
224 | },
225 | "source": [
226 | "2) Cells can be **run** in any **arbitrary order**, and global state is maintained between them.\n",
227 | "\n",
228 |         "Try re-running the cell where we printed `my_dict`. You should now see the additional item `\"I've been changed!\": True`.\n"
229 | ]
230 | },
231 | {
232 | "cell_type": "markdown",
233 | "metadata": {
234 | "id": "wiNYzBdyLimN"
235 | },
236 | "source": [
237 | "3) Unintentionally reusing a global variable can lead to bugs. If all else fails, you can uncomment and run the following line to **clear all global variables**."
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "metadata": {
243 | "id": "_iO10nIILkqs"
244 | },
245 | "source": [
246 | "# %reset -f"
247 | ],
248 | "execution_count": null,
249 | "outputs": []
250 | },
251 | {
252 | "cell_type": "markdown",
253 | "metadata": {
254 | "id": "vIVCsaGYMj1k"
255 | },
256 | "source": [
257 | "You will have to re-run the setup cells after."
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {
263 | "id": "qXw6vbMpFtFF"
264 | },
265 | "source": [
266 | "## Autocomplete / Documentation\n",
267 | "\n",
268 | "Code completions are displayed automatically while you are writing your code.\n",
269 | "\n",
270 |         "If you find this annoying, you can go to:\n",
271 |         "`Tools > Settings ... > Editor > Automatically trigger code completions` (UNCHECK).\n",
272 |         "\n",
273 |         "In this case, suggestions are manually invoked with the *`Tab`* key:\n",
274 |         "* Pressing *`Tab`* after typing a prefix will show the available variables / commands.\n",
275 |         "* Pressing *`Tab`* on a function parameter list will show the function documentation."
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "metadata": {
281 | "id": "hwwF9RtoGMl4"
282 | },
283 | "source": [
284 | "direction_a, direction_b = ['UP', 'DOWN']"
285 | ],
286 | "execution_count": null,
287 | "outputs": []
288 | },
289 | {
290 | "cell_type": "markdown",
291 | "metadata": {
292 | "id": "_6WGj-j5UQ6k"
293 | },
294 | "source": [
295 |         "Uncomment and hit *`Tab`* after '**dir**':"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "metadata": {
301 | "id": "CVXKeD7PkIUp"
302 | },
303 | "source": [
304 | "# dir"
305 | ],
306 | "execution_count": null,
307 | "outputs": []
308 | },
309 | {
310 | "cell_type": "markdown",
311 | "metadata": {
312 | "id": "Rqy-TemM4fki"
313 | },
314 | "source": [
315 |         "Uncomment and hit *`Tab`* after **`print(`**:"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "metadata": {
321 | "id": "88NbgSBsVKeM"
322 | },
323 | "source": [
324 | "# print("
325 | ],
326 | "execution_count": null,
327 | "outputs": []
328 | },
329 | {
330 | "cell_type": "markdown",
331 | "metadata": {
332 | "id": "xUAFxf1uVkam"
333 | },
334 | "source": [
335 | "Alternatively, the question mark (**?**) works as a special character which gives us information about variables and functions. In this case you need to run the cell.\n",
336 | "\n"
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "metadata": {
342 | "id": "h0nEx_NUVXbv"
343 | },
344 | "source": [
345 | "range?"
346 | ],
347 | "execution_count": null,
348 | "outputs": []
349 | },
350 | {
351 | "cell_type": "markdown",
352 | "metadata": {
353 | "id": "6B17wzOd8S3D"
354 | },
355 | "source": [
356 | "## Shortcuts\n",
357 | "\n",
358 | "The 4 most useful colab-specific shortcuts are:\n",
359 | "\n",
360 | "* `Ctrl+/` which toggles comments. Can be applied across multiple lines. **Try below.**\n",
361 | "* `Ctrl+M b` which creates a new code cell below the current one, placing the cursor in it.\n",
362 | "* `Ctrl+M -` which splits the current cell into 2 at the location of the cursor.\n",
363 | "* `Ctrl+M d` which deletes the current cell.\n",
364 | "\n",
365 | "There is of course a search and replace functionality as well.\n",
366 | "\n"
367 | ]
368 | },
369 | {
370 | "cell_type": "code",
371 | "metadata": {
372 | "id": "PJon8nGv8TM4"
373 | },
374 | "source": [
375 | "# print('comment')\n",
376 | "# print('me')\n",
377 | "# print('out')\n",
378 | "# print('in one go')"
379 | ],
380 | "execution_count": null,
381 | "outputs": []
382 | },
383 | {
384 | "cell_type": "markdown",
385 | "metadata": {
386 | "id": "dUTBihZIYC32"
387 | },
388 | "source": [
389 | "## Setup and Imports\n",
390 | "\n",
391 |         "Python packages need to be imported into your colab notebook, the same way you would import them in a Python script. For example, to use `numpy`, you would do"
392 | ]
393 | },
394 | {
395 | "cell_type": "code",
396 | "metadata": {
397 | "id": "Rhk-7l_dYExv"
398 | },
399 | "source": [
400 | "import numpy as np"
401 | ],
402 | "execution_count": null,
403 | "outputs": []
404 | },
405 | {
406 | "cell_type": "markdown",
407 | "metadata": {
408 | "id": "fQE00MVwYfej"
409 | },
410 | "source": [
411 |         "While many packages (all the packages you will need!) can just be imported, some (e.g. `haiku`) may not be immediately available. With colab, you can install any Python package from `pip` for the duration of your connection."
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "metadata": {
417 | "id": "yfP92hNwYfqv"
418 | },
419 | "source": [
420 | "!pip install -q dm-haiku"
421 | ],
422 | "execution_count": null,
423 | "outputs": []
424 | },
425 | {
426 | "cell_type": "markdown",
427 | "metadata": {
428 | "id": "YUQ8oIhHfgYp"
429 | },
430 | "source": [
431 |         "You will then be able to `import haiku as hk` as you normally would."
432 | ]
433 | },
434 | {
435 | "cell_type": "markdown",
436 | "metadata": {
437 | "id": "qCIXC0D6_VEm"
438 | },
439 | "source": [
440 |         "Notice that we ran the shell command `pip` above. You can run any shell command by starting it with `!`.\n",
441 | "There is an example below, and you can read more [here](https://colab.research.google.com/github/jakevdp/PythonDataScienceHandbook/blob/master/notebooks/01.05-IPython-And-Shell-Commands.ipynb#scrollTo=Tts2ysMs-xIz)."
442 | ]
443 | },
444 | {
445 | "cell_type": "code",
446 | "metadata": {
447 | "id": "8iCFyEUP_jpX"
448 | },
449 | "source": [
450 | "!ls\n",
451 | "!mkdir {\"test\"}\n",
452 | "contents = !ls\n",
453 | "print(contents)"
454 | ],
455 | "execution_count": null,
456 | "outputs": []
457 | },
458 | {
459 | "cell_type": "markdown",
460 | "metadata": {
461 | "id": "pWAjydYTuVUz"
462 | },
463 | "source": [
464 | "## Debugging\n",
465 | "\n",
466 | "You can debug code in cells with `(i)pdb` (IPython debugger).\n",
467 | "There are multiple options:\n",
468 | "\n",
469 |         "* Create a cell containing only `%debug` *after* encountering an error.\n",
470 | "* Add `%%debug` as the first line of an existing cell to start debugging from the beginning.\n",
471 | "* Add `import pdb; pdb.set_trace()` on a line to pause execution there.\n",
472 | "\n",
473 | "Use 'n' to run the next line of code, and 'c' to resume execution."
474 | ]
475 | },
476 | {
477 | "cell_type": "markdown",
478 | "metadata": {
479 | "id": "EFTr9-pQGNB8"
480 | },
481 | "source": [
482 | "### Use of %debug\n"
483 | ]
484 | },
485 | {
486 | "cell_type": "code",
487 | "metadata": {
488 | "id": "cRb-UKo7GWH_"
489 | },
490 | "source": [
491 | "a_list = list()\n",
492 | "a_list['some_key'] = 4"
493 | ],
494 | "execution_count": null,
495 | "outputs": []
496 | },
497 | {
498 | "cell_type": "code",
499 | "metadata": {
500 | "id": "-UvYB2Y-GV2M"
501 | },
502 | "source": [
503 | "%debug"
504 | ],
505 | "execution_count": null,
506 | "outputs": []
507 | },
508 | {
509 | "cell_type": "markdown",
510 | "metadata": {
511 | "id": "fu9KuBnFGWzq"
512 | },
513 | "source": [
514 | "### Use of %%debug"
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "metadata": {
520 | "id": "uyzsDJwGq_Dy"
521 | },
522 | "source": [
523 | "%%debug\n",
524 | "print('Let me get started.')\n",
525 | "message = 'We are almost done.'\n",
526 | "print('You can see the global variables, step through code, etc.')"
527 | ],
528 | "execution_count": null,
529 | "outputs": []
530 | },
531 | {
532 | "cell_type": "markdown",
533 | "metadata": {
534 | "id": "d1ZsQaxZGZ7W"
535 | },
536 | "source": [
537 | "### Use of pdb"
538 | ]
539 | },
540 | {
541 | "cell_type": "code",
542 | "metadata": {
543 | "id": "_tQM1XFNFooh"
544 | },
545 | "source": [
546 | "print('Let me get started.')\n",
547 | "message = 'We are almost done.'\n",
548 | "import pdb; pdb.set_trace()\n",
549 | "print('You can see the global variables, step through code, etc.')"
550 | ],
551 | "execution_count": null,
552 | "outputs": []
553 | },
554 | {
555 | "cell_type": "markdown",
556 | "metadata": {
557 | "id": "tiSKbgPe9hqI"
558 | },
559 | "source": [
560 | "## Forms\n",
561 | "\n",
562 |         "With colab it is easy to take input from the user in code cells through so-called forms. The simplest example is shown below."
563 | ]
564 | },
565 | {
566 | "cell_type": "code",
567 | "metadata": {
568 | "id": "Gu30itgd97l7"
569 | },
570 | "source": [
571 | "# @title This text shows up as a title.\n",
572 | "\n",
573 | "a = 2 # @param {type: 'integer'}\n",
574 | "b = 3 # @param\n",
575 | "\n",
576 | "print('a+b =', str(a+b))"
577 | ],
578 | "execution_count": null,
579 | "outputs": []
580 | },
581 | {
582 | "cell_type": "markdown",
583 | "metadata": {
584 | "id": "_-w9fM9BcUBS"
585 | },
586 | "source": [
587 |         "In order to expose a variable as a parameter you just add `#@param` after it.\n",
588 |         "\n",
589 |         "You can use the GUI on the right hand side to change parameter values and types. **Try setting the value of a=5 and rerun the cell above.**\n",
590 | "\n",
591 | "You can read more about this on the official starting colab.\n"
592 | ]
593 | },
594 | {
595 | "cell_type": "markdown",
596 | "metadata": {
597 | "id": "XdQo7XSl-4wm"
598 | },
599 | "source": [
600 | "\n",
601 |         "Cells with forms let you toggle which of these is visible:\n",
602 | "\n",
603 | "* the code,\n",
604 | "* the form,\n",
605 | "* or both\n",
606 | "\n",
607 | "**Try switching between these 3 options for the above cell.** This is how you do this:\n",
608 | "\n",
609 | "1. Click anywhere over the area of the cell with the form to highlight it.\n",
610 |         "2. Click on the \"three vertically arranged dots\" icon in the top right of the cell.\n",
611 |         "3. Go to \"Form >\", and select your desired action."
612 | ]
613 | },
614 | {
615 | "cell_type": "markdown",
616 | "metadata": {
617 | "id": "SZKTwLixhlc5"
618 | },
619 | "source": [
620 | "## Guided exercise: write a decoder for a text encoder\n",
621 | "\n",
622 |         "We defined a (very) simple text encoding function in ``encode()``.\n",
623 | "Your job is to understand it, and write the corresponding decoder in ``decode()``, so that **`text == decoder(encoder(text))`**.\n"
624 | ]
625 | },
626 | {
627 | "cell_type": "code",
628 | "metadata": {
629 | "id": "dJrw_uFKhqCh",
630 | "cellView": "code"
631 | },
632 | "source": [
633 | "# Code\n",
634 | "laws = \"\"\"\n",
635 | "1. A robot may not injure a human being or,\n",
636 | " through inaction, allow a human being to come to harm.\n",
637 | "2. A robot must obey orders given it by human\n",
638 | " beings except where such orders would conflict with the First Law.\n",
639 | "3. A robot must protect its own existence as\n",
640 | " long as such protection does not conflict with the First or Second Law.\n",
641 | "\"\"\"\n",
642 | "\n",
643 | "\n",
644 | "def encode(plain_text):\n",
645 | " new_letters = [chr(ord(letter)+1) for letter in plain_text]\n",
646 | " return ''.join(new_letters)\n",
647 | "\n",
648 | "\n",
649 | "def decode(encoded_text):\n",
650 | " ### Your Code Here ###\n",
651 | " return decoded_text"
652 | ],
653 | "execution_count": null,
654 | "outputs": []
655 | },
656 | {
657 | "cell_type": "code",
658 | "metadata": {
659 | "id": "4X3bhmgLolXZ"
660 | },
661 | "source": [
662 | "# Basic Test\n",
663 | "\n",
664 | "encoded_text = encode(laws)\n",
665 | "print('The encoded text:')\n",
666 | "print(encoded_text)\n",
667 | "assert encoded_text != laws, (\n",
668 | " 'The encoded text should be different from the original')\n",
669 | "print()\n",
670 | "\n",
671 | "# decoded_text = decode(encoded_text)\n",
672 | "# print('The decoded text:')\n",
673 | "# print(decoded_text)\n",
674 | "# assert decoded_text == laws, (\n",
675 | "# 'The decoded text should be the same as the original')"
676 | ],
677 | "execution_count": null,
678 | "outputs": []
679 | },
680 | {
681 | "cell_type": "code",
682 | "metadata": {
683 | "id": "odNeTetkrS2x",
684 | "cellView": "form"
685 | },
686 | "source": [
687 | "# @title Solution\n",
688 | "\n",
689 | "# def encode(plain_text):\n",
690 | "# new_letters = [chr(ord(letter)+1) for letter in plain_text]\n",
691 | "# return ''.join(new_letters)\n",
692 | "\n",
693 | "\n",
694 | "# def decode(encoded_text):\n",
695 | "# new_letters = [chr(ord(letter)-1) for letter in encoded_text]\n",
696 | "# return ''.join(new_letters)"
697 | ],
698 | "execution_count": null,
699 | "outputs": []
700 | },
701 | {
702 | "cell_type": "markdown",
703 | "metadata": {
704 | "id": "B6g2I-twVvp5"
705 | },
706 | "source": [
707 | "## Some additional tips\n",
708 | "\n",
709 | "* You can access an outline of the colab by clicking the arrow on the right hand side.\n",
710 | "* The [official colab landing colab](https://colab.sandbox.google.com/notebooks/welcome.ipynb) has some more examples and info as well."
711 | ]
712 | }
713 | ]
714 | }
--------------------------------------------------------------------------------
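
For reference, the encoder/decoder exercise at the end of Lab_01 boils down to the following round-trip (restating the notebook's `encode` and its solution cell as a standalone script):

```python
def encode(plain_text):
    # Shift every character one code point up.
    return ''.join(chr(ord(letter) + 1) for letter in plain_text)


def decode(encoded_text):
    # Invert the encoder: shift every character one code point down.
    return ''.join(chr(ord(letter) - 1) for letter in encoded_text)


text = 'Hello, M2L!'
assert decode(encode(text)) == text
print(encode(text))  # 'Ifmmp-!N3M"'
```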
/0_intro_material/Lab_02_Intro_JAX.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "Lab_02_Intro_JAX.ipynb",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "display_name": "Python 3",
13 | "language": "python",
14 | "name": "python3"
15 | },
16 | "language_info": {
17 | "codemirror_mode": {
18 | "name": "ipython",
19 | "version": 3
20 | },
21 | "file_extension": ".py",
22 | "mimetype": "text/x-python",
23 | "name": "python",
24 | "nbconvert_exporter": "python",
25 | "pygments_lexer": "ipython3",
26 | "version": "3.8.5"
27 | }
28 | },
29 | "cells": [
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "rjPF8rKiize7"
34 | },
35 | "source": [
36 | "# JAX\n",
37 | "[JAX](https://jax.readthedocs.io/en/latest/jax.html) is a library to conduct machine learning research directly on NumPy-like code and data.\n",
38 | "\n",
39 | "- It automatically differentiates python code and NumPy code (with [Autograd](https://github.com/hips/autograd)).\n",
40 | "- It compiles and runs NumPy code efficiently on accelerators like GPU and TPU (with [XLA](https://www.tensorflow.org/xla))."
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {
46 | "id": "NDmbQoL5mYmA"
47 | },
48 | "source": [
49 | "### JAX and random number generators\n",
 50 |         "To use pseudo-random generators in JAX, you need to explicitly generate a random key, and pass it to the operations that work with random numbers (e.g. model initialization, dropout, etc.).\n",
 51 |         "\n",
 52 |         "A call to a random function with the same key does not change the state of the generator. This has to be done explicitly with `split()` (or `next_rng_key()` in `haiku` transformed functions)."
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "metadata": {
58 | "id": "0j5kLgolmlCl"
59 | },
60 | "source": [
61 | "import jax.numpy as jnp\n",
62 | "from jax import random\n",
63 | "\n",
64 | "# Generating a single key will result in duplicate pseudo-random numbers\n",
65 | "key = random.PRNGKey(0)\n",
66 | "x1 = random.normal(key, (3,))\n",
67 | "print(x1)\n",
68 | "x2 = random.normal(key, (3,))\n",
69 | "print(x2)"
70 | ],
71 | "execution_count": null,
72 | "outputs": []
73 | },
74 | {
75 | "cell_type": "code",
76 | "metadata": {
77 | "id": "IZNs5VAdfmpG"
78 | },
79 | "source": [
80 | "# Let's split the key to be able to generate different random values\n",
81 | "key, new_key = random.split(key)\n",
82 | "x1 = random.normal(key, (3,))\n",
83 | "print(x1)\n",
84 | "x2 = random.normal(new_key, (3,))\n",
85 | "print(x2)"
86 | ],
87 | "execution_count": null,
88 | "outputs": []
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "metadata": {
93 | "id": "-hiqCH1Nkvdv"
94 | },
95 | "source": [
96 | "### JAX program transformations with examples \n",
97 | "* `jit` (just-in-time compilation) -- speeds up your code by running all the ops inside the jit-ed function as a *fused* op; it compiles the function when it is called the first time, and uses the compiled (optimized) version from the second call onwards.\n",
98 | "* `grad` (automatic differentiation) -- returns derivatives of a function with respect to the model weights passed as parameters.\n",
99 | "* `vmap` (automatic batching) -- returns a new function that can apply the original (per-sample) function to a batch.\n",
100 | "\n"
101 | ]
102 | },
103 | {
104 | "cell_type": "markdown",
105 | "metadata": {
106 | "id": "PJ49KxvRklLG"
107 | },
108 | "source": [
109 | "**Just-in-time compilation**"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {
115 | "id": "kgUXa0jMo2gx"
116 | },
117 | "source": [
118 | "A function can be \"jit-ed\" in two ways:\n",
119 | "* by defining a new one as ``jit(original_function)`` (shown here)\n",
120 | "* by using the ``@jit`` decorator in the function definition (shown later).\n"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "metadata": {
126 | "id": "qiWR4CPjlc6T"
127 | },
128 | "source": [
129 | "from jax import jit\n",
130 | "\n",
131 | "\n",
132 | "# Function and input definition\n",
133 | "def selu(x, alpha=1.67, lmbda=1.05):\n",
134 | " return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n",
135 | "\n",
136 | "\n",
137 | "x = random.normal(key, (1000000,))\n",
138 | "\n",
139 | "# Execute the function without jit\n",
140 | "%timeit selu(x).block_until_ready()\n",
141 | "\n",
142 | "# Execute the function with jit\n",
143 | "selu_jit = jit(selu)\n",
144 | "%timeit selu_jit(x).block_until_ready()"
145 | ],
146 | "execution_count": null,
147 | "outputs": []
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "metadata": {
152 | "id": "N5r5pqGzk_jK"
153 | },
154 | "source": [
155 | "Note: we used ``block_until_ready`` to time the function calls because JAX, by default, runs operations asynchronously."
156 | ]
157 | },
158 | {
159 | "cell_type": "markdown",
160 | "metadata": {
161 | "id": "25H9JFFQkw2I"
162 | },
163 | "source": [
164 | "**Automatic differentiation**"
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "metadata": {
170 | "id": "IBHX97Uy5kBD"
171 | },
172 | "source": [
173 | "Also known as \"autograd\", automatic differentiation can be obtained in JAX by calling `grad(original_function)`."
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "metadata": {
179 | "id": "rP70-aR3oouX"
180 | },
181 | "source": [
182 | "from jax import grad, jit\n",
183 | "\n",
184 | "\n",
185 | "# Function definition\n",
186 | "def simple_fun(x):\n",
187 | " return jnp.sin(x) / x\n",
188 | "\n",
189 | "\n",
190 | "# Get the gradient of simple_fun with respect to its input x\n",
191 | "grad_simple_fun = grad(simple_fun)\n",
192 | "\n",
193 | "# Get higher order derivatives of simple_fun (e.g. Hessian)\n",
194 | "grad_grad_simple_fun = grad(grad(simple_fun))"
195 | ],
196 | "execution_count": null,
197 | "outputs": []
198 | },
199 | {
200 | "cell_type": "markdown",
201 | "metadata": {
202 | "id": "BeegVpqw6e6k"
203 | },
204 | "source": [
205 |         "Note: `grad_simple_fun()` accepts the same input type as `simple_fun()`.\n",
206 | "So if we want to pass a vector in our example, we have to either use a list comprehension (shown here), or use a proper batching mechanism (shown later)."
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "metadata": {
212 | "id": "FH8uTx364T0m"
213 | },
214 | "source": [
215 | "# Let's plot the result\n",
216 | "import matplotlib.pyplot as plt\n",
217 | "x_range = jnp.arange(-8, 8, .1)\n",
218 | "plt.plot(x_range, simple_fun(x_range), 'b')\n",
219 | "plt.plot(x_range, [grad_simple_fun(xi) for xi in x_range], 'r')\n",
220 | "plt.plot(x_range, [grad_grad_simple_fun(xi) for xi in x_range], '--g')\n",
221 | "plt.legend(('simple_fun(x)', 'grad_simple_fun(x)', 'grad_grad_simple_fun(x)'))\n",
222 | "plt.show()"
223 | ],
224 | "execution_count": null,
225 | "outputs": []
226 | },
227 | {
228 | "cell_type": "markdown",
229 | "metadata": {
230 | "id": "O7dRXkw8k1c0"
231 | },
232 | "source": [
233 | "**Automatic batching**"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "metadata": {
239 | "id": "W6mCivd-nTay"
240 | },
241 | "source": [
242 | "from jax import vmap\n",
243 | "\n",
244 |         "# In the example above, we can use vmap instead of a loop to compute gradients\n",
245 | "# (We stop at the first-order derivative)\n",
246 | "grad_vect_simple_fun = vmap(grad_simple_fun)(x_range)\n",
247 | "\n",
248 | "# Plot again and visually check that the gradients are identical\n",
249 | "plt.plot(x_range, simple_fun(x_range), 'b')\n",
250 | "plt.plot(x_range, grad_vect_simple_fun, 'c', linewidth=5)\n",
251 | "plt.plot(x_range, [grad_simple_fun(xi) for xi in x_range], 'r')\n",
252 | "plt.legend(('simple_fun(x)', 'grad_vect_simple_fun(x)', 'grad_simple_fun(x)'))\n",
253 | "plt.show()"
254 | ],
255 | "execution_count": null,
256 | "outputs": []
257 | },
258 | {
259 | "cell_type": "code",
260 | "metadata": {
261 | "id": "dibhhYsjphE3"
262 | },
263 | "source": [
264 | "# Let's time them!\n",
265 | "\n",
266 | "# naive batching\n",
267 | "def naively_batched(x):\n",
268 | " return jnp.stack([grad_simple_fun(xi) for xi in x])\n",
269 | "\n",
270 | "\n",
271 | "# manual batching with jit\n",
272 | "@jit\n",
273 | "def manual_batched(x):\n",
274 | " return jnp.stack([grad_simple_fun(xi) for xi in x])\n",
275 | "\n",
276 | "\n",
277 | "# Batching using vmap and jit\n",
278 | "@jit\n",
279 | "def vmap_batched(x):\n",
280 | " return vmap(grad_simple_fun)(x)\n",
281 | "\n",
282 | "\n",
283 | "print('Naively batched')\n",
284 | "%timeit naively_batched(x_range).block_until_ready()\n",
285 | "print('jit batched')\n",
286 | "%timeit manual_batched(x_range).block_until_ready()\n",
287 | "print('With jit vmap')\n",
288 | "%timeit vmap_batched(x_range).block_until_ready()"
289 | ],
290 | "execution_count": null,
291 | "outputs": []
292 | },
293 | {
294 | "cell_type": "markdown",
295 | "metadata": {
296 | "id": "d0lHGcAwdfQq"
297 | },
298 | "source": [
299 | "### Read the doc for [common gotchas](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in JAX!"
300 | ]
301 | },
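As a quick taste of those gotchas, here is a minimal sketch (assuming `jax` is installed) of the most common one, array immutability:

```python
import jax.numpy as jnp

# JAX arrays are immutable, so numpy-style in-place item assignment fails;
# the functional .at[...].set(...) syntax returns a new array instead.
x = jnp.zeros(3)
try:
    x[0] = 1.0  # numpy-style in-place update
except TypeError as e:
    print('in-place update rejected:', e)

y = x.at[0].set(1.0)  # returns a fresh array; x is unchanged
print(y)
```

In numpy the assignment would succeed and mutate `x`; in JAX every update returns a fresh array, which is part of what makes transformations like `jit` and `grad` safe.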
302 | {
303 | "cell_type": "markdown",
304 | "metadata": {
305 | "id": "x7ViyN6Q8kOa"
306 | },
307 | "source": [
308 | "# Haiku\n",
309 |         "[Haiku](https://github.com/deepmind/dm-haiku) is an object-oriented library for building neural networks on top of JAX.\n",
310 | "\n",
311 | "Notable entities:\n",
312 |         "* `hk.Module`: this is the base class for all layers and neural networks in Haiku. You can implement your own as a subclass of it: `class MyModule(hk.Module): [...]`\n",
313 |         "* `hk.transform`: this is used to convert modules (stateful elements) into pure functions (stateless elements). All JAX transformations (e.g. `jax.grad`) require you to pass in a pure function for correct behaviour.\n",
314 | "\n"
315 | ]
316 | },
317 | {
318 | "cell_type": "markdown",
319 | "metadata": {
320 | "id": "3h4MVQzkck_R"
321 | },
322 | "source": [
323 | "**Example: Train a Multi-Layer Perceptron classifier (MLP) on the MNIST dataset**"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "metadata": {
329 | "id": "G6dDQBr71quQ"
330 | },
331 | "source": [
332 | "# We will use haiku on top of jax\n",
333 | "\n",
334 | "! pip install -q dm-haiku\n",
335 | "! pip install -q optax"
336 | ],
337 | "execution_count": null,
338 | "outputs": []
339 | },
340 | {
341 | "cell_type": "code",
342 | "metadata": {
343 | "id": "NG3krUq1WK_A"
344 | },
345 | "source": [
346 | "import contextlib\n",
347 | "from typing import Any, Mapping, Generator, Tuple\n",
348 | "\n",
349 | "import haiku as hk\n",
350 | "\n",
351 | "import jax\n",
352 | "import optax # package for optimizers\n",
353 | "import jax.numpy as jnp\n",
354 | "import numpy as np\n",
355 | "import enum\n",
356 | "\n",
357 | "# Dataset library\n",
358 | "import tensorflow.compat.v2 as tf\n",
359 | "import tensorflow_datasets as tfds\n",
360 | "\n",
361 | "# Plotting library\n",
362 | "from matplotlib import pyplot as plt\n",
363 | "import pylab as pl\n",
364 | "from IPython import display\n",
365 | "\n",
366 | "# Don't forget to select GPU runtime environment\n",
367 | "# in Runtime -> Change runtime type\n",
368 | "device_name = tf.test.gpu_device_name()\n",
369 | "if device_name != '/device:GPU:0':\n",
370 | " raise SystemError('GPU device not found')\n",
371 | "print('Found GPU at: {}'.format(device_name))\n",
372 | "\n",
373 | "# define some useful types\n",
374 | "OptState = Any\n",
375 | "Batch = Mapping[str, np.ndarray]"
376 | ],
377 | "execution_count": null,
378 | "outputs": []
379 | },
380 | {
381 | "cell_type": "markdown",
382 | "metadata": {
383 | "id": "6r0EfeAqIZkS"
384 | },
385 | "source": [
386 | "### Define the dataset: MNIST"
387 | ]
388 | },
389 | {
390 | "cell_type": "markdown",
391 | "metadata": {
392 | "id": "gaqS6S2Y3oZh"
393 | },
394 | "source": [
395 | "MNIST dataset: [[Reference 1](http://yann.lecun.com/exdb/mnist/)] [[Reference 2](https://en.wikipedia.org/wiki/MNIST_database)]"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "metadata": {
401 | "id": "_lWzTqgO1K3R"
402 | },
403 | "source": [
404 |         "# JAX itself does not provide data loading or preprocessing utilities,\n",
405 | "# so we use TensorFlow datasets (tfds).\n",
406 | "# We define a load_dataset() function that selectively imports,\n",
407 | "# shuffles and batches part of MNIST.\n",
408 | "# The function returns a Generator which produces batches of\n",
409 | "# MNIST data.\n",
410 | "\n",
411 | "NUM_CLASSES = 10\n",
412 | "\n",
413 |         "def load_dataset(split: str, *, is_training: bool, batch_size: int) -> Generator[Batch, None, None]:\n",
414 | " \"\"\"Loads the dataset as a generator of batches.\"\"\"\n",
415 | " ds = tfds.load('mnist:3.*.*', split=split).cache().repeat()\n",
416 | " if is_training:\n",
417 | " ds = ds.shuffle(10 * batch_size, seed=0)\n",
418 | " ds = ds.batch(batch_size)\n",
419 | " return tfds.as_numpy(ds)"
420 | ],
421 | "execution_count": null,
422 | "outputs": []
423 | },
424 | {
425 | "cell_type": "code",
426 | "metadata": {
427 | "id": "GMXfrKnof2L3"
428 | },
429 | "source": [
430 | "# Make datasets for train and test\n",
431 | "train_dataset = iter(load_dataset('train',\n",
432 | " is_training=True,\n",
433 | " batch_size=1000))\n",
434 | "train_eval_dataset = iter(load_dataset('train',\n",
435 | " is_training=False,\n",
436 | " batch_size=10000))\n",
437 | "test_eval_dataset = iter(load_dataset('test',\n",
438 | " is_training=False,\n",
439 | " batch_size=10000))"
440 | ],
441 | "execution_count": null,
442 | "outputs": []
443 | },
444 | {
445 | "cell_type": "markdown",
446 | "metadata": {
447 | "id": "00m3RHu0BTir"
448 | },
449 | "source": [
450 | "### Define the classifier: a simple MLP"
451 | ]
452 | },
453 | {
454 | "cell_type": "markdown",
455 | "metadata": {
456 | "id": "8zqQS22S-qcu"
457 | },
458 | "source": [
459 | "Architecture:\n",
460 | "* Flatten (unroll 28$\\times$28 image into 784-long vector)\n",
461 | "* Linear mapping to 300-long vector (fully-connected)\n",
462 | "* ReLU (non-linearity)\n",
463 | "* Linear mapping to 100-long vector\n",
464 | "* ReLU\n",
465 | "* Linear mapping to the final problem size (10 classes)"
466 | ]
467 | },
468 | {
469 | "cell_type": "code",
470 | "metadata": {
471 | "id": "7j58v6LhWjrQ"
472 | },
473 | "source": [
474 | "def net_fn(batch: Batch) -> jnp.ndarray:\n",
475 | " \"\"\"Standard LeNet-300-100 MLP network.\"\"\"\n",
476 | " # The images are in [0, 255], uint8; \n",
477 | " # we need to convert to float and normalize\n",
478 | " x = batch['image'].astype(jnp.float32) / 255.\n",
479 | " # We use hk.Sequential to chain the modules in the network\n",
480 | " mlp = hk.Sequential([\n",
481 | " # The input images are 28x28, so we first flatten them\n",
482 | " # to apply linear (fully-connected) layers\n",
483 | " hk.Flatten(),\n",
484 | " hk.Linear(300), jax.nn.relu,\n",
485 | " hk.Linear(100), jax.nn.relu,\n",
486 | " hk.Linear(10),\n",
487 | " ])\n",
488 | " return mlp(x)"
489 | ],
490 | "execution_count": null,
491 | "outputs": []
492 | },
493 | {
494 | "cell_type": "markdown",
495 | "metadata": {
496 | "id": "YeSPeUblANWV"
497 | },
498 | "source": [
499 | "``hk.transform`` turns functions that use object-oriented, functionally \"impure\" modules into pure functions that can be used with ``jax.jit``, ``jax.grad``, ``jax.pmap``, etc.\n",
500 | "\n",
501 |         "Note: since we do not store additional state (e.g. the running statistics needed by batch norm), we use `hk.transform`.\n",
502 |         "If we defined a `batch_norm` layer, we would use `hk.transform_with_state` instead."
503 | ]
504 | },
505 | {
506 | "cell_type": "code",
507 | "metadata": {
508 | "id": "a9YzGJ_4WuE5"
509 | },
510 | "source": [
511 | "net = hk.transform(net_fn)"
512 | ],
513 | "execution_count": null,
514 | "outputs": []
515 | },
516 | {
517 | "cell_type": "code",
518 | "metadata": {
519 | "id": "jgEI2ACn_0C0"
520 | },
521 | "source": [
522 | "print(type(net_fn))\n",
523 | "print(type(net))"
524 | ],
525 | "execution_count": null,
526 | "outputs": []
527 | },
528 | {
529 | "cell_type": "markdown",
530 | "metadata": {
531 | "id": "qfm8NSdQOR1n"
532 | },
533 | "source": [
534 | "### Define the optimizer"
535 | ]
536 | },
537 | {
538 | "cell_type": "markdown",
539 | "metadata": {
540 | "id": "l4yK8gSlBFjR"
541 | },
542 | "source": [
543 |         "We use [Optax](https://optax.readthedocs.io/en/latest/) for optimization; see its documentation for the available optimizers."
544 | ]
545 | },
546 | {
547 | "cell_type": "code",
548 | "metadata": {
549 | "id": "xxXMm5OzOPCO"
550 | },
551 | "source": [
552 | "# We use Adam optimizer here. Others are possible, e.g. SGD with momentum.\n",
553 | "lr = 1e-3\n",
554 | "opt = optax.adam(lr)"
555 | ],
556 | "execution_count": null,
557 | "outputs": []
558 | },
559 | {
560 | "cell_type": "markdown",
561 | "metadata": {
562 | "id": "GA81esw-OoBK"
563 | },
564 | "source": [
565 | "### Define the optimization objective (loss function)"
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "metadata": {
571 | "id": "OA1FdFSwWwtb"
572 | },
573 | "source": [
574 | "# Training loss: cross-entropy plus regularization weight decay\n",
575 | "def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:\n",
576 | " \"\"\"Compute the loss of the network, including L2 for regularization.\"\"\"\n",
577 | "\n",
578 | " # Get network predictions\n",
579 | " logits = net.apply(params, None, batch)\n",
580 | "\n",
581 | " # Generate one_hot labels from index classes\n",
582 | " labels = jax.nn.one_hot(batch['label'], NUM_CLASSES)\n",
583 | "\n",
584 | " # Compute mean softmax cross entropy over the batch\n",
585 | " softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))\n",
586 | " softmax_xent /= labels.shape[0]\n",
587 | "\n",
588 | " # Compute the weight decay loss by penalising the norm of parameters\n",
589 | " l2_loss = 0.5 * sum(jnp.sum(jnp.square(p))\n",
590 |         " for p in jax.tree_util.tree_leaves(params))\n",
591 | "\n",
592 | " return softmax_xent + 1e-4 * l2_loss"
593 | ],
594 | "execution_count": null,
595 | "outputs": []
596 | },
597 | {
598 | "cell_type": "markdown",
599 | "metadata": {
600 | "id": "KA_CZhFHQDDN"
601 | },
602 | "source": [
603 | "### Evaluation metric"
604 | ]
605 | },
606 | {
607 | "cell_type": "code",
608 | "metadata": {
609 | "id": "dlEMA3RCXBU_"
610 | },
611 | "source": [
612 | "# Classification accuracy\n",
613 | "@jax.jit\n",
614 | "def accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:\n",
615 | " # Get network predictions\n",
616 | " predictions = net.apply(params, None, batch)\n",
617 | " # Return accuracy = how many predictions match the ground truth\n",
618 | " return jnp.mean(jnp.argmax(predictions, axis=-1) == batch['label'])"
619 | ],
620 | "execution_count": null,
621 | "outputs": []
622 | },
623 | {
624 | "cell_type": "markdown",
625 | "metadata": {
626 | "id": "0SGCtelDQfOv"
627 | },
628 | "source": [
629 | "### Define training step (parameters update)"
630 | ]
631 | },
632 | {
633 | "cell_type": "code",
634 | "metadata": {
635 | "id": "PsjUbkMWQlt-"
636 | },
637 | "source": [
638 | "@jax.jit\n",
639 | "def update(\n",
640 | " params: hk.Params,\n",
641 | " opt_state: OptState,\n",
642 | " batch: Batch,\n",
643 | ") -> Tuple[hk.Params, OptState]:\n",
644 |         " \"\"\"Single optimization step: compute gradients and apply the update.\"\"\"\n",
645 | " # Use jax transformation `grad` to compute gradients;\n",
646 |         " # it expects the parameters of the model and the input batch\n",
647 | " grads = jax.grad(loss)(params, batch)\n",
648 | "\n",
649 | " # Compute parameters updates based on gradients and optimizer state\n",
650 | " updates, opt_state = opt.update(grads, opt_state)\n",
651 | "\n",
652 | " # Apply updates to parameters\n",
653 | " new_params = optax.apply_updates(params, updates)\n",
654 | " return new_params, opt_state"
655 | ],
656 | "execution_count": null,
657 | "outputs": []
658 | },
659 | {
660 | "cell_type": "markdown",
661 | "metadata": {
662 | "id": "RubrTm2_WOiP"
663 | },
664 | "source": [
665 | "### Initialize the model and the optimizer"
666 | ]
667 | },
668 | {
669 | "cell_type": "code",
670 | "metadata": {
671 | "id": "2a8lZDaAmH3B"
672 | },
673 | "source": [
674 | "# Initialize model and optimizer; note that a sample input is needed to compute\n",
675 | "# shapes of parameters\n",
676 | "\n",
677 | "# Generate a data batch\n",
678 | "batch = next(train_dataset)\n",
679 | "# Initialize model\n",
680 | "params = net.init(jax.random.PRNGKey(42), batch)\n",
681 | "# Initialize optimizer\n",
682 | "opt_state = opt.init(params)"
683 | ],
684 | "execution_count": null,
685 | "outputs": []
686 | },
687 | {
688 | "cell_type": "markdown",
689 | "metadata": {
690 | "id": "h2yunmBjWyPQ"
691 | },
692 | "source": [
693 | "### Visualize data and parameter shapes"
694 | ]
695 | },
696 | {
697 | "cell_type": "code",
698 | "metadata": {
699 | "id": "A2OYWEKKgVEJ"
700 | },
701 | "source": [
702 | "# We define a gallery() function to display images\n",
703 | "\n",
704 | "MAX_IMAGES = 10\n",
705 | "\n",
706 | "def gallery(images, label, title='Input images'):\n",
707 | " class_dict = [u'zero', u'one', u'two', u'three',\n",
708 | " u'four', u'five', u'six', u'seven',\n",
709 | " u'eight', u'nine']\n",
710 | " num_frames, h, w, num_channels = images.shape\n",
711 | " num_frames = min(num_frames, MAX_IMAGES)\n",
712 | " ff, axes = plt.subplots(1, num_frames,\n",
713 | " figsize=(30, 30),\n",
714 | " subplot_kw={'xticks': [], 'yticks': []})\n",
715 | " if images.min() < 0:\n",
716 | " images = (images + 1.) / 2.\n",
717 | " for i in range(0, num_frames):\n",
718 | " if num_channels == 3:\n",
719 | " axes[i].imshow(np.squeeze(images[i]))\n",
720 | " else:\n",
721 | " axes[i].imshow(np.squeeze(images[i]), cmap='gray')\n",
722 | " axes[i].set_title(class_dict[label[i]], fontsize=28)\n",
723 | " plt.setp(axes[i].get_xticklabels(), visible=False)\n",
724 | " plt.setp(axes[i].get_yticklabels(), visible=False)\n",
725 | " ff.subplots_adjust(wspace=0.1)\n",
726 | " plt.show()"
727 | ],
728 | "execution_count": null,
729 | "outputs": []
730 | },
731 | {
732 | "cell_type": "code",
733 | "metadata": {
734 | "id": "SSmLtlnUgBeS"
735 | },
736 | "source": [
737 | "# Display shapes and images\n",
738 | "print(batch['image'].shape)\n",
739 | "print(batch['label'].shape)\n",
740 | "gallery(batch['image'], batch['label'])"
741 | ],
742 | "execution_count": null,
743 | "outputs": []
744 | },
745 | {
746 | "cell_type": "code",
747 | "metadata": {
748 | "id": "euhMKRF_W5_D"
749 | },
750 | "source": [
751 | "# Let's see how many parameters are in our network and their shapes\n",
752 | "def get_num_params(params: hk.Params):\n",
753 | " num_params = 0\n",
754 |         " for p in jax.tree_util.tree_leaves(params):\n",
755 | " print(p.shape)\n",
756 | " num_params = num_params + np.prod(p.shape)\n",
757 | " return num_params\n",
758 | "\n",
759 | "print('Total number of parameters %d' % get_num_params(params))"
760 | ],
761 | "execution_count": null,
762 | "outputs": []
763 | },
764 | {
765 | "cell_type": "markdown",
766 | "metadata": {
767 | "id": "Uq1OddpfX-6y"
768 | },
769 | "source": [
770 | "### Accuracy of the untrained model (should be ~10%)"
771 | ]
772 | },
773 | {
774 | "cell_type": "code",
775 | "metadata": {
776 | "id": "dHzblZx6X9jH"
777 | },
778 | "source": [
779 | "# Compute accuracy on the test dataset\n",
780 | "test_accuracy = accuracy(params, next(test_eval_dataset))\n",
781 | "print('Test accuracy %f ' % test_accuracy)"
782 | ],
783 | "execution_count": null,
784 | "outputs": []
785 | },
786 | {
787 | "cell_type": "code",
788 | "metadata": {
789 | "id": "wIiw1dPDZEuN"
790 | },
791 | "source": [
792 | "# Let's visualize some network predictions\n",
793 | "# before training; if some are correct,\n",
794 | "# they are correct by chance.\n",
795 | "predictions = net.apply(params, None, batch)\n",
796 | "pred_labels = jnp.argmax(predictions, axis=-1)\n",
797 | "gallery(batch['image'], pred_labels)"
798 | ],
799 | "execution_count": null,
800 | "outputs": []
801 | },
802 | {
803 | "cell_type": "markdown",
804 | "metadata": {
805 | "id": "NjeMJxKYaeN-"
806 | },
807 | "source": [
808 | "### Run one training step"
809 | ]
810 | },
811 | {
812 | "cell_type": "code",
813 | "metadata": {
814 | "id": "ejfhiuoiaiZ6"
815 | },
816 | "source": [
817 |         "# First, let's do one step and check that the update decreases the loss\n",
818 | "loss_before_train = loss(params, batch)\n",
819 | "print('Loss before train %f' % loss_before_train)\n",
820 | "params, opt_state = update(params, opt_state, batch)\n",
821 | "new_loss = loss(params, next(train_dataset))\n",
822 | "new_loss_same_batch = loss(params, batch)\n",
823 | "print('Loss after one step of training, same batch %f, different batch %f'\n",
824 | " % (new_loss_same_batch, new_loss))"
825 | ],
826 | "execution_count": null,
827 | "outputs": []
828 | },
829 | {
830 | "cell_type": "markdown",
831 | "metadata": {
832 | "id": "n9QFlfNTW-GZ"
833 | },
834 | "source": [
835 | "### Run training steps in a loop. We also run evaluation periodically."
836 | ]
837 | },
838 | {
839 | "cell_type": "code",
840 | "metadata": {
841 | "id": "fAR5joBwV5cT"
842 | },
843 | "source": [
844 | "# Train/eval loop.\n",
845 | "for step in range(5001):\n",
846 | " if step % 1000 == 0:\n",
847 | " # Periodically evaluate classification accuracy on train & test sets.\n",
848 | " train_accuracy = accuracy(params, next(train_eval_dataset))\n",
849 | " test_accuracy = accuracy(params, next(test_eval_dataset))\n",
850 | " train_accuracy, test_accuracy = jax.device_get(\n",
851 | " (train_accuracy, test_accuracy))\n",
852 | " print('Step %d Train / Test accuracy: %f / %f'\n",
853 | " % (step, train_accuracy, test_accuracy))\n",
854 | "\n",
855 | " # Do SGD on a batch of training examples.\n",
856 | " params, opt_state = update(params, opt_state, next(train_dataset))"
857 | ],
858 | "execution_count": null,
859 | "outputs": []
860 | },
861 | {
862 | "cell_type": "markdown",
863 | "metadata": {
864 | "id": "JNYLWtvmao4i"
865 | },
866 | "source": [
867 | "### Visualize network predictions after training. Most of the predictions should be correct."
868 | ]
869 | },
870 | {
871 | "cell_type": "code",
872 | "metadata": {
873 | "id": "mGb8B6n6au9L"
874 | },
875 | "source": [
876 | "# Get predictions for the same batch\n",
877 | "predictions = net.apply(params, None, batch)\n",
878 | "pred_labels = jnp.argmax(predictions, axis=-1)\n",
879 | "gallery(batch['image'], pred_labels)"
880 | ],
881 | "execution_count": null,
882 | "outputs": []
883 | }
884 | ]
885 | }
886 |
--------------------------------------------------------------------------------
/0_intro_material/Lab_03_Intro_Numpy.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Lab_03_Intro_Numpy.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "display_name": "Python 3",
12 | "language": "python",
13 | "name": "python3"
14 | },
15 | "language_info": {
16 | "codemirror_mode": {
17 | "name": "ipython",
18 | "version": 3
19 | },
20 | "file_extension": ".py",
21 | "mimetype": "text/x-python",
22 | "name": "python",
23 | "nbconvert_exporter": "python",
24 | "pygments_lexer": "ipython3",
25 | "version": "3.8.5"
26 | }
27 | },
28 | "cells": [
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "TK_s931vVvwi"
33 | },
34 | "source": [
35 | "# Numpy\n",
36 | "\n",
37 | "\" NumPy is the fundamental package for scientific computing with Python. It contains among other things:\n",
38 | "\n",
39 | "* a powerful N-dimensional array object\n",
40 | "* sophisticated (broadcasting) functions\n",
41 | "* useful linear algebra, Fourier transform, and random number capabilities \"\n",
42 | "\n",
43 | "\n",
44 | "-- From the [NumPy](http://www.numpy.org/) landing page.\n",
45 | "\n"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {
51 | "id": "kVwjoPsFGgbH"
52 | },
53 | "source": [
54 |         "Before learning about numpy, we introduce...\n",
55 | "\n",
56 | "### The NXOR Function\n",
57 | "\n",
58 | "Many of the exercises involve working with the $\\mathrm{NXOR} \\colon \\; [-1, 1]^2 \\rightarrow \\{-1, +1\\}$ function defined as \n",
59 | "\n",
60 | "$$ (x_1, x_2) \\longmapsto \\mathrm{sgn}(x_1 \\cdot x_2) .$$\n",
61 | "\n",
62 | "where for $x_1 \\cdot x_2 = 0$ we let $\\mathrm{NXOR}(x_1, x_2) = -1$.\n",
63 | "\n",
64 | "We can visualize this function as\n",
65 | "\n",
66 | "![A set of points in \\[-1, +1\\]^2 with green and red markers denoting the value assigned to them by the NXOR function](https://github.com/tmlss2018/PracticalSessions/blob/master/assets/nxor_labels.png?raw=true)\n",
67 | "\n",
68 | "where each point in $ [-1, 1]^2$ is marked by green (+1) or red (-1) according to the value assigned to it by the NXOR function.\n",
69 | "\n",
70 | "\n"
71 | ]
72 | },
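The definition above translates directly into numpy; here is a minimal sketch (the lab exercises below generate and plot such data properly):

```python
import numpy as np

# NXOR(x1, x2) = sgn(x1 * x2), with the convention that the result is -1
# whenever x1 * x2 == 0.
def nxor(points):
    prod = points[:, 0] * points[:, 1]
    return np.where(prod > 0, 1, -1)

points = np.array([[0.5, 0.5], [-0.5, 0.5], [0.0, 0.3]])
print(nxor(points))  # same signs -> +1; opposite signs or zero -> -1
```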
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {
76 | "id": "9HfKNF9JG3Sg"
77 | },
78 | "source": [
79 | "\n",
80 | "Over the course of the intro lab exercises we will\n",
81 | "\n",
82 | "1. Generate such data with numpy.\n",
83 | "2. Create the plot above with matplotlib.\n",
84 | "3. Train a model to learn this function.\n"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "metadata": {
90 | "id": "psAhyeala4Qa"
91 | },
92 | "source": [
93 | "### Setup and imports. Run the following cell."
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "metadata": {
99 | "id": "jJmfT0IMa494"
100 | },
101 | "source": [
102 | "import numpy as np"
103 | ],
104 | "execution_count": 1,
105 | "outputs": []
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "metadata": {
110 | "id": "209T8819ws6R"
111 | },
112 | "source": [
113 | "### Random numbers in numpy"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "metadata": {
119 | "colab": {
120 | "base_uri": "https://localhost:8080/"
121 | },
122 | "id": "HPoRLRxdwyHs",
123 | "outputId": "8ab46604-035f-4ade-93a9-9bd699bba33f"
124 | },
125 | "source": [
126 | "np.random.random((3, 2)) # Array of shape (3, 2), entries uniform in [0, 1)."
127 | ],
128 | "execution_count": 2,
129 | "outputs": [
130 | {
131 | "output_type": "execute_result",
132 | "data": {
133 | "text/plain": [
134 | "array([[0.24979127, 0.59265351],\n",
135 | " [0.33590299, 0.1185103 ],\n",
136 | " [0.58447674, 0.26072642]])"
137 | ]
138 | },
139 | "metadata": {
140 | "tags": []
141 | },
142 | "execution_count": 2
143 | }
144 | ]
145 | },
146 | {
147 | "cell_type": "markdown",
148 | "metadata": {
149 | "id": "d1xeqjN_eGrM"
150 | },
151 | "source": [
152 |         "Note that (as usual in computing) numpy produces pseudo-random numbers based on a seed, or more precisely a random state. To make random sequences, and calculations based on them, reproducible, use\n",
153 | "\n",
154 | "* the [`np.random.seed()`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.random.seed.html) function to set the default global seed, or\n",
155 | "* the [`np.random.RandomState`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.random.RandomState.html) class which is a container for a pseudo-random number generator and exposes methods for generating random numbers.\n"
156 | ]
157 | },
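For example (a small sketch), two `RandomState` instances seeded identically produce identical, independent streams without touching the global state:

```python
import numpy as np

# Each RandomState instance carries its own generator state.
rng_a = np.random.RandomState(42)
rng_b = np.random.RandomState(42)

draws_a = rng_a.random_sample(3)
draws_b = rng_b.random_sample(3)
print(draws_a)
print(np.array_equal(draws_a, draws_b))  # True: same seed, same stream
```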
158 | {
159 | "cell_type": "code",
160 | "metadata": {
161 | "colab": {
162 | "base_uri": "https://localhost:8080/"
163 | },
164 | "id": "knUtcFWoFqK0",
165 | "outputId": "d454fca2-dbc5-4fc1-ca45-599baa0640f4"
166 | },
167 | "source": [
168 | "np.random.seed(0)\n",
169 | "print(np.random.random(2))\n",
170 | "# Reset the global random state to the same state.\n",
171 | "np.random.seed(0)\n",
172 | "print(np.random.random(2))"
173 | ],
174 | "execution_count": 3,
175 | "outputs": [
176 | {
177 | "output_type": "stream",
178 | "text": [
179 | "[0.5488135 0.71518937]\n",
180 | "[0.5488135 0.71518937]\n"
181 | ],
182 | "name": "stdout"
183 | }
184 | ]
185 | },
186 | {
187 | "cell_type": "markdown",
188 | "metadata": {
189 | "id": "Tz435THaxePN"
190 | },
191 | "source": [
192 | "### Numpy Array Operations 1\n",
193 | "\n",
194 | "There are a large number of operations you can run on any numpy array. Here we showcase some common ones."
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "metadata": {
200 | "colab": {
201 | "base_uri": "https://localhost:8080/"
202 | },
203 | "id": "YZbsqkyXxgUo",
204 | "outputId": "5d0a5c64-93df-4614-d693-d02cbf4a970f"
205 | },
206 | "source": [
207 | "# Create one from hard-coded data:\n",
208 | "ar = np.array([\n",
209 | " [0.0, 0.2],\n",
210 | " [0.9, 0.5],\n",
211 | " [0.3, 0.7],\n",
212 | "], dtype=np.float64) # float64 is the default.\n",
213 | "\n",
214 | "print('The array:\\n', ar)\n",
215 | "print()\n",
216 | "\n",
217 | "print('data type', ar.dtype)\n",
218 | "print('transpose\\n', ar.T)\n",
219 | "print('shape', ar.shape)\n",
220 | "print('reshaping an array', ar.reshape((6)))"
221 | ],
222 | "execution_count": 4,
223 | "outputs": [
224 | {
225 | "output_type": "stream",
226 | "text": [
227 | "The array:\n",
228 | " [[0. 0.2]\n",
229 | " [0.9 0.5]\n",
230 | " [0.3 0.7]]\n",
231 | "\n",
232 | "data type float64\n",
233 | "transpose\n",
234 | " [[0. 0.9 0.3]\n",
235 | " [0.2 0.5 0.7]]\n",
236 | "shape (3, 2)\n",
237 | "reshaping an array [0. 0.2 0.9 0.5 0.3 0.7]\n"
238 | ],
239 | "name": "stdout"
240 | }
241 | ]
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "metadata": {
246 | "id": "s5wC8i8_4WBf"
247 | },
248 | "source": [
249 |         "Many numpy operations are available both as np module functions and as array methods. For example, we can also reshape as"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "metadata": {
255 | "colab": {
256 | "base_uri": "https://localhost:8080/"
257 | },
258 | "id": "AjGy5myl4jxo",
259 | "outputId": "ca4e1bbd-8835-47d8-bb61-3259ae702534"
260 | },
261 | "source": [
262 | "print('reshape v2', np.reshape(ar, (6, 1)))"
263 | ],
264 | "execution_count": 5,
265 | "outputs": [
266 | {
267 | "output_type": "stream",
268 | "text": [
269 | "reshape v2 [[0. ]\n",
270 | " [0.2]\n",
271 | " [0.9]\n",
272 | " [0.5]\n",
273 | " [0.3]\n",
274 | " [0.7]]\n"
275 | ],
276 | "name": "stdout"
277 | }
278 | ]
279 | },
280 | {
281 | "cell_type": "markdown",
282 | "metadata": {
283 | "id": "Tt4G37QAGWl4"
284 | },
285 | "source": [
286 | "### Numpy Indexing and selectors\n",
287 | "\n",
288 | "Here are some basic indexing examples from numpy."
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "metadata": {
294 | "colab": {
295 | "base_uri": "https://localhost:8080/"
296 | },
297 | "id": "4Bbdmb0BGXPc",
298 | "outputId": "d473b981-9993-4aae-fdb3-068c3d8961f8"
299 | },
300 | "source": [
301 | "ar"
302 | ],
303 | "execution_count": 6,
304 | "outputs": [
305 | {
306 | "output_type": "execute_result",
307 | "data": {
308 | "text/plain": [
309 | "array([[0. , 0.2],\n",
310 | " [0.9, 0.5],\n",
311 | " [0.3, 0.7]])"
312 | ]
313 | },
314 | "metadata": {
315 | "tags": []
316 | },
317 | "execution_count": 6
318 | }
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "metadata": {
324 | "colab": {
325 | "base_uri": "https://localhost:8080/"
326 | },
327 | "id": "6lk0NQGGGpRK",
328 | "outputId": "a437878d-6641-426f-a06b-7329537a2117"
329 | },
330 | "source": [
331 | "ar[0, 1] # row, column"
332 | ],
333 | "execution_count": 7,
334 | "outputs": [
335 | {
336 | "output_type": "execute_result",
337 | "data": {
338 | "text/plain": [
339 | "0.2"
340 | ]
341 | },
342 | "metadata": {
343 | "tags": []
344 | },
345 | "execution_count": 7
346 | }
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "metadata": {
352 | "colab": {
353 | "base_uri": "https://localhost:8080/"
354 | },
355 | "id": "Eh1zKgqMGpa-",
356 | "outputId": "d8c9ee03-0f4e-4f87-e536-d22606fcfee9"
357 | },
358 | "source": [
359 |         "ar[:, 1] # slice: take every row (all of axis 0), at column 1."
360 | ],
361 | "execution_count": 8,
362 | "outputs": [
363 | {
364 | "output_type": "execute_result",
365 | "data": {
366 | "text/plain": [
367 | "array([0.2, 0.5, 0.7])"
368 | ]
369 | },
370 | "metadata": {
371 | "tags": []
372 | },
373 | "execution_count": 8
374 | }
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "metadata": {
380 | "colab": {
381 | "base_uri": "https://localhost:8080/"
382 | },
383 | "id": "QVo9SdiWSZqn",
384 | "outputId": "c1f9291b-7c54-4ac6-d08a-ada8db0e518b"
385 | },
386 | "source": [
387 | "ar[1:2, 1] # slices with syntax from:to, selecting [from, to)."
388 | ],
389 | "execution_count": 9,
390 | "outputs": [
391 | {
392 | "output_type": "execute_result",
393 | "data": {
394 | "text/plain": [
395 | "array([0.5])"
396 | ]
397 | },
398 | "metadata": {
399 | "tags": []
400 | },
401 | "execution_count": 9
402 | }
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "metadata": {
408 | "colab": {
409 | "base_uri": "https://localhost:8080/"
410 | },
411 | "id": "n64DQFr7Subs",
412 | "outputId": "c11d3ff2-52f1-4abb-800b-7c415a4fd6f0"
413 | },
414 | "source": [
415 | "ar[1:, 1] # Omit `to` to go all the way to the end"
416 | ],
417 | "execution_count": 10,
418 | "outputs": [
419 | {
420 | "output_type": "execute_result",
421 | "data": {
422 | "text/plain": [
423 | "array([0.5, 0.7])"
424 | ]
425 | },
426 | "metadata": {
427 | "tags": []
428 | },
429 | "execution_count": 10
430 | }
431 | ]
432 | },
433 | {
434 | "cell_type": "code",
435 | "metadata": {
436 | "colab": {
437 | "base_uri": "https://localhost:8080/"
438 | },
439 | "id": "Uq7bmNPVTD2B",
440 | "outputId": "21d0fb40-031e-49aa-a9ed-27144ff80433"
441 | },
442 | "source": [
443 | "ar[:2, 1] # Omit `from` to start from the beginning"
444 | ],
445 | "execution_count": 11,
446 | "outputs": [
447 | {
448 | "output_type": "execute_result",
449 | "data": {
450 | "text/plain": [
451 | "array([0.2, 0.5])"
452 | ]
453 | },
454 | "metadata": {
455 | "tags": []
456 | },
457 | "execution_count": 11
458 | }
459 | ]
460 | },
461 | {
462 | "cell_type": "code",
463 | "metadata": {
464 | "colab": {
465 | "base_uri": "https://localhost:8080/"
466 | },
467 | "id": "uJ7A1vVWTKQc",
468 | "outputId": "bb487b04-2175-478e-e17f-2e8213529eba"
469 | },
470 | "source": [
471 | "ar[0:-1, 1] # Use negative indexing to count elements from the back."
472 | ],
473 | "execution_count": 12,
474 | "outputs": [
475 | {
476 | "output_type": "execute_result",
477 | "data": {
478 | "text/plain": [
479 | "array([0.2, 0.5])"
480 | ]
481 | },
482 | "metadata": {
483 | "tags": []
484 | },
485 | "execution_count": 12
486 | }
487 | ]
488 | },
489 | {
490 | "cell_type": "markdown",
491 | "metadata": {
492 | "id": "CVizP583IolT"
493 | },
494 | "source": [
495 |         "We can also pass boolean arrays as indices. These specify exactly which elements to select."
496 | ]
497 | },
498 | {
499 | "cell_type": "code",
500 | "metadata": {
501 | "colab": {
502 | "base_uri": "https://localhost:8080/"
503 | },
504 | "id": "-91q2Si7IBGN",
505 | "outputId": "14d3d425-6f71-4a37-e7e3-70e20040f0bf"
506 | },
507 | "source": [
508 | "ar[np.array([\n",
509 | " [True, False],\n",
510 | " [False, True],\n",
511 | " [True, False],\n",
512 | "])]"
513 | ],
514 | "execution_count": 13,
515 | "outputs": [
516 | {
517 | "output_type": "execute_result",
518 | "data": {
519 | "text/plain": [
520 | "array([0. , 0.5, 0.3])"
521 | ]
522 | },
523 | "metadata": {
524 | "tags": []
525 | },
526 | "execution_count": 13
527 | }
528 | ]
529 | },
530 | {
531 | "cell_type": "markdown",
532 | "metadata": {
533 | "id": "g9VttyKZI2aX"
534 | },
535 | "source": [
536 | "Boolean arrays can be created with logical operations, then used as selectors. Logical operators apply elementwise."
537 | ]
538 | },
539 | {
540 | "cell_type": "code",
541 | "metadata": {
542 | "colab": {
543 | "base_uri": "https://localhost:8080/"
544 | },
545 | "id": "KyH_50JVIWOK",
546 | "outputId": "0e8d76b8-8ff5-41f7-bd57-b0616b74c9af"
547 | },
548 | "source": [
549 | "ar_2 = np.array([ # Nearly the same as ar\n",
550 | " [0.0, 0.1],\n",
551 | " [0.9, 0.5],\n",
552 | " [0.0, 0.7],\n",
553 | "])\n",
554 | "\n",
555 | "# Where ar_2 is smaller than ar, let ar_2 be -inf.\n",
556 | "ar_2[ar_2 < ar] = -np.inf\n",
557 | "ar_2"
558 | ],
559 | "execution_count": 14,
560 | "outputs": [
561 | {
562 | "output_type": "execute_result",
563 | "data": {
564 | "text/plain": [
565 | "array([[ 0. , -inf],\n",
566 | " [ 0.9, 0.5],\n",
567 | " [-inf, 0.7]])"
568 | ]
569 | },
570 | "metadata": {
571 | "tags": []
572 | },
573 | "execution_count": 14
574 | }
575 | ]
576 | },
577 | {
578 | "cell_type": "markdown",
579 | "metadata": {
580 | "id": "-GwBK4LJGpjK"
581 | },
582 | "source": [
583 | "### Numpy Operations 2"
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "metadata": {
589 | "colab": {
590 | "base_uri": "https://localhost:8080/"
591 | },
592 | "id": "gHElhCDy4SrG",
593 | "outputId": "cc2f8875-9ef7-419b-c567-8d4528f9adbf"
594 | },
595 | "source": [
596 | "print('array:\\n', ar)\n",
597 | "print()\n",
598 | "\n",
599 | "print('sum across axis 0 (rows):', ar.sum(axis=0))\n",
600 | "print('mean', ar.mean())\n",
601 | "print('min', ar.min())\n",
602 | "print('row-wise min', ar.min(axis=1))"
603 | ],
604 | "execution_count": 15,
605 | "outputs": [
606 | {
607 | "output_type": "stream",
608 | "text": [
609 | "array:\n",
610 | " [[0. 0.2]\n",
611 | " [0.9 0.5]\n",
612 | " [0.3 0.7]]\n",
613 | "\n",
614 | "sum across axis 0 (rows): [1.2 1.4]\n",
615 | "mean 0.43333333333333335\n",
616 | "min 0.0\n",
617 | "row-wise min [0. 0.5 0.3]\n"
618 | ],
619 | "name": "stdout"
620 | }
621 | ]
622 | },
623 | {
624 | "cell_type": "markdown",
625 | "metadata": {
626 | "id": "bSNi0aLs03J1"
627 | },
628 | "source": [
629 | "We can also take element-wise minimums between two arrays.\n",
630 | "\n",
631 |         "We may want to do this when \"clipping\" values in a matrix, that is, setting any value larger than, say, 0.6 to 0.6. In numpy we would do this with `np.minimum`:\n",
632 | "\n",
633 | "### Broadcasting (and selectors)"
634 | ]
635 | },
636 | {
637 | "cell_type": "code",
638 | "metadata": {
639 | "colab": {
640 | "base_uri": "https://localhost:8080/"
641 | },
642 | "id": "GLCN0umLxuLv",
643 | "outputId": "1ed6704a-d63a-4230-9ce7-c27e7a2816ca"
644 | },
645 | "source": [
646 | "np.minimum(ar, 0.6)"
647 | ],
648 | "execution_count": 16,
649 | "outputs": [
650 | {
651 | "output_type": "execute_result",
652 | "data": {
653 | "text/plain": [
654 | "array([[0. , 0.2],\n",
655 | " [0.6, 0.5],\n",
656 | " [0.3, 0.6]])"
657 | ]
658 | },
659 | "metadata": {
660 | "tags": []
661 | },
662 | "execution_count": 16
663 | }
664 | ]
665 | },
666 | {
667 | "cell_type": "markdown",
668 | "metadata": {
669 | "id": "Kmq7Yf_73GgC"
670 | },
671 | "source": [
672 |         "Numpy automatically turns the scalar 0.6 into an array the same size as `ar` in order to take the element-wise minimum.\n",
673 | "\n"
674 | ]
675 | },
676 | {
677 | "cell_type": "markdown",
678 | "metadata": {
679 | "id": "vkM1cL-pBrEE"
680 | },
681 | "source": [
682 | "Broadcasting can save us a lot of typing, but in complicated cases it may require a good understanding of the exact rules followed.\n",
683 | "\n",
684 | "Some references:\n",
685 | "\n",
686 | "* [Numpy page that explains broadcasting](https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html)\n",
687 | "* [Similar content with some visualizations](http://scipy.github.io/old-wiki/pages/EricsBroadcastingDoc)\n",
688 | "\n",
689 |         "Below we continue with a selection of other useful broadcasting examples.\n"
690 | ]
691 | },
692 | {
693 | "cell_type": "code",
694 | "metadata": {
695 | "colab": {
696 | "base_uri": "https://localhost:8080/"
697 | },
698 | "id": "S52hBSnG5JRa",
699 | "outputId": "901e7125-95ee-41f4-e635-e965f7d5a17e"
700 | },
701 | "source": [
702 | "# Centering our array.\n",
703 | "print('centered array:\\n', ar - np.mean(ar))"
704 | ],
705 | "execution_count": 17,
706 | "outputs": [
707 | {
708 | "output_type": "stream",
709 | "text": [
710 | "centered array:\n",
711 | " [[-0.43333333 -0.23333333]\n",
712 | " [ 0.46666667 0.06666667]\n",
713 | " [-0.13333333 0.26666667]]\n"
714 | ],
715 | "name": "stdout"
716 | }
717 | ]
718 | },
719 | {
720 | "cell_type": "markdown",
721 | "metadata": {
722 | "id": "mS5IDFrA6ig3"
723 | },
724 | "source": [
725 |         "Note that `np.mean(ar)` is a scalar, but it is automatically broadcast and subtracted from every element of `ar`.\n"
726 | ]
727 | },
728 | {
729 | "cell_type": "markdown",
730 | "metadata": {
731 | "id": "KzlVmOY_FPhQ"
732 | },
733 | "source": [
734 |         "We can also implement the clipping ourselves, using a boolean selector."
735 | ]
736 | },
737 | {
738 | "cell_type": "code",
739 | "metadata": {
740 | "colab": {
741 | "base_uri": "https://localhost:8080/"
742 | },
743 | "id": "zatGgBsOFNdF",
744 | "outputId": "b8f3f370-9a28-4f97-c7b1-018830ad88ca"
745 | },
746 | "source": [
747 | "clipped_ar = ar.copy() # So that ar is not modified.\n",
748 | "clipped_ar[clipped_ar > 0.6] = 0.6\n",
749 | "clipped_ar"
750 | ],
751 | "execution_count": 18,
752 | "outputs": [
753 | {
754 | "output_type": "execute_result",
755 | "data": {
756 | "text/plain": [
757 | "array([[0. , 0.2],\n",
758 | " [0.6, 0.5],\n",
759 | " [0.3, 0.6]])"
760 | ]
761 | },
762 | "metadata": {
763 | "tags": []
764 | },
765 | "execution_count": 18
766 | }
767 | ]
768 | },
769 | {
770 | "cell_type": "markdown",
771 | "metadata": {
772 | "id": "SnDqUDgtFyF5"
773 | },
774 | "source": [
775 | "A few things happened here:\n",
776 | "\n",
777 | "1. 0.6 was broadcast in for the greater than (>) operation\n",
778 | "2. The greater than operation defined a selector, selecting a subset of the elements of the array\n",
779 | "3. 0.6 was broadcast to the right number of elements for assignment."
780 | ]
781 | },
782 | {
783 | "cell_type": "markdown",
784 | "metadata": {
785 | "id": "QzLtiYXUFN5P"
786 | },
787 | "source": [
788 | "Vectors may also be broadcast into matrices."
789 | ]
790 | },
791 | {
792 | "cell_type": "code",
793 | "metadata": {
794 | "colab": {
795 | "base_uri": "https://localhost:8080/"
796 | },
797 | "id": "n2gv3YVf5JZL",
798 | "outputId": "19801a00-5678-4d96-a9b8-99ce9b4b9a9e"
799 | },
800 | "source": [
801 | "vec = np.array([1, 2])\n",
802 | "ar + vec"
803 | ],
804 | "execution_count": 19,
805 | "outputs": [
806 | {
807 | "output_type": "execute_result",
808 | "data": {
809 | "text/plain": [
810 | "array([[1. , 2.2],\n",
811 | " [1.9, 2.5],\n",
812 | " [1.3, 2.7]])"
813 | ]
814 | },
815 | "metadata": {
816 | "tags": []
817 | },
818 | "execution_count": 19
819 | }
820 | ]
821 | },
822 | {
823 | "cell_type": "markdown",
824 | "metadata": {
825 | "id": "7ehzlV8J75ar"
826 | },
827 | "source": [
828 | "Here the shapes of the involved arrays are:\n",
829 | "```\n",
830 |         "ar (2d array): 3 x 2\n",
831 |         "vec (1d array): 2\n",
832 |         "Result (2d array): 3 x 2\n",
833 | "```\n",
834 | "\n",
835 |         "When either of the compared dimensions is 1 (even implicitly, as with the missing leading dimension of `vec`), the array is stretched, or “copied”, along that dimension to match the other.\n",
836 | "\n",
837 | "Here, this meant that the `[1, 2]` row was repeated to match the number of rows in `ar`, then added together.\n"
838 | ]
839 | },
840 | {
841 | "cell_type": "markdown",
842 | "metadata": {
843 | "id": "bBLjzt84A2ZL"
844 | },
845 | "source": [
846 |         "If there is a shape mismatch, numpy will raise an error. To see this, uncomment the line below and run it."
847 | ]
848 | },
849 | {
850 | "cell_type": "code",
851 | "metadata": {
852 | "id": "sViqTRjXAWhN"
853 | },
854 | "source": [
855 | "# ar + np.array([[1, 2, 3]])"
856 | ],
857 | "execution_count": 20,
858 | "outputs": []
859 | },
860 | {
861 | "cell_type": "markdown",
862 | "metadata": {
863 | "id": "K4pvND75AWuG"
864 | },
865 | "source": [
866 | "#### Exercise\n",
867 | "\n",
868 | "Broadcast and add the vector `[10, 20, 30]` across the columns of `ar`. \n",
869 | "\n",
870 | "You should get \n",
871 | "```\n",
872 | "array([[10. , 10.2],\n",
873 | " [20.9, 20.5],\n",
874 | " [30.3, 30.7]])\n",
875 | " ```\n"
876 | ]
877 | },
878 | {
879 | "cell_type": "code",
880 | "metadata": {
881 | "id": "AgnJiYHm_ENT"
882 | },
883 | "source": [
884 | "# @title Code\n",
885 | "\n",
886 | "# Recall that you can use vec.shape to verify that your array has the\n",
887 | "# shape you expect.\n",
888 | "\n",
889 | "### Your code here ###"
890 | ],
891 | "execution_count": 21,
892 | "outputs": []
893 | },
894 | {
895 | "cell_type": "code",
896 | "metadata": {
897 | "cellView": "form",
898 | "colab": {
899 | "base_uri": "https://localhost:8080/"
900 | },
901 | "id": "gS1X9kt6_pLW",
902 | "outputId": "1e8e05f5-1a52-446d-a522-99880417c119"
903 | },
904 | "source": [
905 | "# @title Solution\n",
906 | "\n",
907 | "vec = np.array([[10], [20], [30]])\n",
908 | "ar + vec"
909 | ],
910 | "execution_count": 22,
911 | "outputs": [
912 | {
913 | "output_type": "execute_result",
914 | "data": {
915 | "text/plain": [
916 | "array([[10. , 10.2],\n",
917 | " [20.9, 20.5],\n",
918 | " [30.3, 30.7]])"
919 | ]
920 | },
921 | "metadata": {
922 | "tags": []
923 | },
924 | "execution_count": 22
925 | }
926 | ]
927 | },
928 | {
929 | "cell_type": "markdown",
930 | "metadata": {
931 | "id": "ZHGWztEX99vd"
932 | },
933 | "source": [
934 | "### `np.newaxis`\n",
935 | "\n",
936 |         "We can use another numpy feature, `np.newaxis`, to conveniently form the column vector required in the example above. It inserts a singleton dimension into an array at the desired location:"
937 | ]
938 | },
939 | {
940 | "cell_type": "code",
941 | "metadata": {
942 | "colab": {
943 | "base_uri": "https://localhost:8080/"
944 | },
945 | "id": "ac9tsc4e5JhA",
946 | "outputId": "1bfc04ba-69d1-4acd-e35c-fabfe7c0ec49"
947 | },
948 | "source": [
949 | "vec = np.array([1, 2])\n",
950 | "vec.shape"
951 | ],
952 | "execution_count": 23,
953 | "outputs": [
954 | {
955 | "output_type": "execute_result",
956 | "data": {
957 | "text/plain": [
958 | "(2,)"
959 | ]
960 | },
961 | "metadata": {
962 | "tags": []
963 | },
964 | "execution_count": 23
965 | }
966 | ]
967 | },
968 | {
969 | "cell_type": "code",
970 | "metadata": {
971 | "colab": {
972 | "base_uri": "https://localhost:8080/"
973 | },
974 | "id": "WpG-B92VCNAL",
975 | "outputId": "f817123e-6130-4cf2-9540-879c28f13e1d"
976 | },
977 | "source": [
978 | "vec[np.newaxis, :].shape"
979 | ],
980 | "execution_count": 24,
981 | "outputs": [
982 | {
983 | "output_type": "execute_result",
984 | "data": {
985 | "text/plain": [
986 | "(1, 2)"
987 | ]
988 | },
989 | "metadata": {
990 | "tags": []
991 | },
992 | "execution_count": 24
993 | }
994 | ]
995 | },
996 | {
997 | "cell_type": "code",
998 | "metadata": {
999 | "colab": {
1000 | "base_uri": "https://localhost:8080/"
1001 | },
1002 | "id": "YdbrkV1OCz8V",
1003 | "outputId": "474f5777-27dc-468d-9b3c-094c17020a81"
1004 | },
1005 | "source": [
1006 | "vec[:, np.newaxis].shape"
1007 | ],
1008 | "execution_count": 25,
1009 | "outputs": [
1010 | {
1011 | "output_type": "execute_result",
1012 | "data": {
1013 | "text/plain": [
1014 | "(2, 1)"
1015 | ]
1016 | },
1017 | "metadata": {
1018 | "tags": []
1019 | },
1020 | "execution_count": 25
1021 | }
1022 | ]
1023 | },
1024 | {
1025 | "cell_type": "markdown",
1026 | "metadata": {
1027 | "id": "KQZcrYaJxOOR"
1028 | },
1029 | "source": [
1030 | "Now you know more than enough to generate some example data for our `NXOR` function.\n",
1031 | "\n",
1032 | "\n",
1033 | "### Exercise: Generate Data for NXOR\n",
1034 | "\n",
1035 | "Write a function `get_data(num_examples)` that returns two numpy arrays\n",
1036 | "\n",
1037 | "* `inputs` of shape `num_examples x 2` with points selected uniformly from the $[-1, 1]^2$ domain.\n",
1038 | "* `labels` of shape `num_examples` with the associated output of `NXOR`."
1039 | ]
1040 | },
1041 | {
1042 | "cell_type": "code",
1043 | "metadata": {
1044 | "id": "SlZD-vcTVv-t"
1045 | },
1046 | "source": [
1047 | "# @title Code\n",
1048 | "\n",
1049 | "def get_data(num_examples):\n",
1050 | " # Replace with your code.\n",
1051 | " return np.zeros((num_examples, 2)), np.zeros((num_examples))"
1052 | ],
1053 | "execution_count": 26,
1054 | "outputs": []
1055 | },
1056 | {
1057 | "cell_type": "code",
1058 | "metadata": {
1059 | "cellView": "form",
1060 | "id": "JWHBltJ7fimG"
1061 | },
1062 | "source": [
1063 | "# @title Solution\n",
1064 | "\n",
1065 | "# Solution 1.\n",
1066 | "def get_data(num_examples):\n",
1067 | " inputs = 2*np.random.random((num_examples, 2)) - 1\n",
1068 | " labels = np.prod(inputs, axis=1)\n",
1069 | " labels[labels <= 0] = -1\n",
1070 | " labels[labels > 0] = 1\n",
1071 | " return inputs, labels\n",
1072 | "\n",
1073 |         "# Solution 2.\n",
1074 | "# def get_data(num_examples):\n",
1075 | "# inputs = 2*np.random.random((num_examples, 2)) - 1\n",
1076 | "# labels = np.sign(np.prod(inputs, axis=1))\n",
1077 | "# labels[labels == 0] = -1\n",
1078 | "# return inputs, labels"
1079 | ],
1080 | "execution_count": 27,
1081 | "outputs": []
1082 | },
1083 | {
1084 | "cell_type": "code",
1085 | "metadata": {
1086 | "colab": {
1087 | "base_uri": "https://localhost:8080/"
1088 | },
1089 | "id": "8HhjadmmvZyc",
1090 | "outputId": "8ac3d30a-d8a4-4577-b76a-9fdd65cfe422"
1091 | },
1092 | "source": [
1093 | "get_data(4)"
1094 | ],
1095 | "execution_count": 28,
1096 | "outputs": [
1097 | {
1098 | "output_type": "execute_result",
1099 | "data": {
1100 | "text/plain": [
1101 | "(array([[ 0.20552675, 0.08976637],\n",
1102 | " [-0.1526904 , 0.29178823],\n",
1103 | " [-0.12482558, 0.783546 ],\n",
1104 | " [ 0.92732552, -0.23311696]]), array([ 1., -1., -1., -1.]))"
1105 | ]
1106 | },
1107 | "metadata": {
1108 | "tags": []
1109 | },
1110 | "execution_count": 28
1111 | }
1112 | ]
1113 | },
1114 | {
1115 | "cell_type": "markdown",
1116 | "metadata": {
1117 | "id": "IQ0Ip-SVb3Nc"
1118 | },
1119 | "source": [
1120 | "## That's all, folks!\n",
1121 | "\n",
1122 | "For now."
1123 | ]
1124 | }
1125 | ]
1126 | }
--------------------------------------------------------------------------------
/0_intro_material/README.md:
--------------------------------------------------------------------------------
1 | # Hello, M2L!
2 | by Marco Buzzelli, Luigi Celona, Flavio Piccoli, and Simone Zini
3 |
4 | _Designed for educational purposes. Please do not distribute without permission_.
5 |
6 | This Colaboratory (Colab) notebook will help you prepare for the rest of the practical sessions at M2L 2021.
7 | Here you will get familiar with the environment and tools used in the rest of the practical sessions.
8 |
9 | **We strongly encourage you to cover the Colab and JAX tutorials.**
10 |
11 | After that, if you want to know a bit more about the basics, you can browse the other notebooks (Numpy and Plotting).
12 |
13 |
14 | ## Lab 01: What is Google Colab?
15 | This tutorial teaches you about Colab and its main features. You will need to know this for the rest of the labs, as you will write all code in Colab.
16 |
17 | Open the file _Lab_01_Intro_Colab.ipynb_ to access the colab.
18 |
19 | [](https://colab.research.google.com/github/m2lschool/tutorials2021/blob/master/0_intro_material/Lab_01_Intro_Colab.ipynb)
20 |
21 | ## Lab 02: JAX
22 | This colab is designed as an all-encompassing guide to get started with the fundamentals of using TensorFlow/JAX/Haiku in a Google Colab environment.
23 |
24 | Open the file _Lab_02_Intro_JAX.ipynb_ to access the colab.
25 |
26 | [](https://colab.research.google.com/github/m2lschool/tutorials2021/blob/master/0_intro_material/Lab_02_Intro_JAX.ipynb)
27 |
28 | ## Lab 03: Numpy
29 | This colab introduces you to numpy, the Python package we use for numerical computing. Topics such as
30 |
31 | * array creation
32 | * operations on arrays
33 | * indexing and selection on arrays
34 | * broadcasting
35 |
36 | are covered.
37 |
38 | By the end of this colab you will have written a function to generate datasets for learning the NXOR function.
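
As a quick preview of these topics, here is a minimal, hedged sketch in plain numpy (the variable names are illustrative and not taken from the notebook):

```python
import numpy as np

# Array creation: a 3 x 2 matrix.
ar = np.array([[0.0, 0.2],
               [0.9, 0.5],
               [0.3, 0.7]])

# Operations on arrays: reductions along an axis.
col_sums = ar.sum(axis=0)  # Sums down each column -> array([1.2, 1.4]).

# Indexing and selection: a boolean mask used as a selector.
clipped = ar.copy()
clipped[clipped > 0.6] = 0.6  # Values above 0.6 are set to 0.6.

# Broadcasting: a (3, 1) column vector is stretched across the columns.
shifted = ar + np.array([[10], [20], [30]])
```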
39 |
40 | Open the file _Lab_03_Intro_Numpy.ipynb_ to access the colab.
41 |
42 | [](https://colab.research.google.com/github/m2lschool/tutorials2021/blob/master/0_intro_material/Lab_03_Intro_Numpy.ipynb)
43 |
44 |
45 | ## Lab 04: Plotting with matplotlib, more numpy
46 | In this colab we generate the plot included with the definition of the NXOR function. In the process we use some more features of numpy.
47 |
48 | Other than the plot above, we also see how to use matplotlib to
49 |
50 | * draw line plots so we can visualize training curves later, and
51 | * display images, or galleries of images so we can visualize datasets and the output of learnt models.
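
Both uses can be sketched in a few lines of matplotlib (the data below is made up purely for illustration):

```python
import numpy as np
from matplotlib import pyplot as plt

# A line plot, e.g. to visualize a (hypothetical) training curve.
steps = np.arange(100)
loss = np.exp(-steps / 20.0)  # Made-up, smoothly decaying loss values.
plt.plot(steps, loss)
plt.xlabel('step')
plt.ylabel('loss')

# An image, e.g. one sample from a dataset (random noise here).
plt.figure()
plt.imshow(np.random.random((8, 8, 3)))
plt.axis('off')
plt.show()
```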
52 |
53 | Open the file _Lab_04_Intro_Plotting.ipynb_ to access the colab.
54 |
55 | [](https://colab.research.google.com/github/m2lschool/tutorials2021/blob/master/0_intro_material/Lab_04_Intro_Plotting.ipynb)
56 |
57 |
--------------------------------------------------------------------------------
/1_vision/ComputerVisionPart3.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "_xnMOsbqHz61"
7 | },
8 | "source": [
9 | "## M2L: Image to Image Translation Tutorial (PART III)"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {
15 | "id": "riM2huyZjKZ-"
16 | },
17 | "source": [
18 |         "### Image to image translation using conditional GANs, as described in [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004)\n",
19 | "### by Marco Buzzelli, Luigi Celona, Flavio Piccoli, and Simone Zini\n",
20 | "\n",
21 |         "* Exercise: Convert building facades to real buildings"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {
27 | "id": "k7drmZ2pj2F-"
28 | },
29 | "source": [
30 | "We will use the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/), helpfully provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep our example short, we will use a preprocessed [copy](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) of this dataset, created by the authors of the [paper](https://arxiv.org/abs/1611.07004) above.\n",
31 | "\n",
32 | "Each epoch takes around 15 seconds on a single V100 GPU.\n",
33 | "\n",
34 | "Below is the output generated after training the model for 200 epochs.\n",
35 | "\n",
36 | "\n",
37 | ""
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {
43 | "id": "e1_Y75QXJS6h"
44 | },
45 | "source": [
46 | "## Import JAX, Haiku, and other libraries"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {
53 | "id": "TmG8rR8D1fZa"
54 | },
55 | "outputs": [],
56 | "source": [
57 | "!pip install ipdb &> /dev/null\n",
58 | "!pip install git+https://github.com/deepmind/dm-haiku &> /dev/null\n",
59 | "!pip install -U tensorboard &> /dev/null\n",
60 | "!pip install git+https://github.com/deepmind/optax.git &> /dev/null"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {
67 | "id": "YfIk2es3hJEd"
68 | },
69 | "outputs": [],
70 | "source": [
71 | "import os\n",
72 | "import time\n",
73 | "import pickle\n",
74 | "import functools\n",
75 | "import numpy as np\n",
76 | "\n",
77 | "# Dataset libraries.\n",
78 | "import tensorflow as tf\n",
79 | "\n",
80 | "\n",
81 | "import haiku as hk\n",
82 | "import jax\n",
83 | "import optax # Package for optimizer.\n",
84 | "import jax.numpy as jnp\n",
85 | "\n",
86 | "# Plotting libraries.\n",
87 | "from matplotlib import pyplot as plt\n",
88 | "from IPython import display\n",
89 | "\n",
90 | "from typing import Mapping, Optional, Tuple, NamedTuple, Any"
91 | ]
92 | },
93 | {
94 | "cell_type": "markdown",
95 | "metadata": {
96 | "id": "iYn4MdZnKCey"
97 | },
98 | "source": [
99 | "## Download the dataset\n",
100 | "\n",
101 | "You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets). As mentioned in the [paper](https://arxiv.org/abs/1611.07004) we apply random jittering and mirroring to the training dataset.\n",
102 | "\n",
103 | "* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`\n",
104 |         "* In random mirroring, the image is randomly flipped horizontally, i.e., left to right."
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": null,
110 | "metadata": {
111 | "id": "Kn-k8kTXuAlv"
112 | },
113 | "outputs": [],
114 | "source": [
115 | "_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'\n",
116 | "path_to_zip = tf.keras.utils.get_file('facades.tar.gz',\n",
117 | " origin=_URL,\n",
118 | " extract=True)\n",
119 | "PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {
125 | "id": "ZjMDvmR_j2GA"
126 | },
127 | "source": [
128 | "## Hyper-parameters for data preprocessing and training"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "metadata": {
135 | "id": "2CbTEt448b4R"
136 | },
137 | "outputs": [],
138 | "source": [
139 | "BUFFER_SIZE = 400 #@param\n",
140 | "BATCH_SIZE = 1 #@param\n",
141 | "IMG_WIDTH = 256 #@param\n",
142 | "IMG_HEIGHT = 256 #@param\n",
143 | "TRAIN_INIT_RANDOM_SEED = 1729 #@param\n",
144 | "LAMBDA = 100 #@param\n",
145 | "EPOCHS = 150\n",
146 | "\n",
147 | "# We need a random key for initialization.\n",
148 | "rng = jax.random.PRNGKey(TRAIN_INIT_RANDOM_SEED)"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "metadata": {
155 | "id": "aO9ZAGH5K3SY"
156 | },
157 | "outputs": [],
158 | "source": [
159 | "#@title Dataset loading and preprocessing\n",
160 | "# We use tensorflow readers; JAX does not have support for input data reading\n",
161 | "# and pre-processing.\n",
162 | "def load(image_file):\n",
163 | " image = tf.io.read_file(image_file)\n",
164 | " image = tf.image.decode_jpeg(image)\n",
165 | "\n",
166 | " w = tf.shape(image)[1]\n",
167 | "\n",
168 | " w = w // 2\n",
169 | " real_image = image[:, :w, :]\n",
170 | " input_image = image[:, w:, :]\n",
171 | "\n",
172 | " input_image = tf.cast(input_image, tf.float32)\n",
173 | " real_image = tf.cast(real_image, tf.float32)\n",
174 | "\n",
175 | " return input_image, real_image"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {
182 | "id": "4OLHMpsQ5aOv"
183 | },
184 | "outputs": [],
185 | "source": [
186 | "inp, re = load(PATH + 'train/100.jpg')\n",
187 |         "# Scaling to [0, 1] so matplotlib can show the image.\n",
188 | "plt.figure()\n",
189 | "plt.imshow(inp/255.0)\n",
190 | "plt.figure()\n",
191 | "plt.imshow(re/255.0)"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "metadata": {
198 | "id": "rwwYQpu9FzDu"
199 | },
200 | "outputs": [],
201 | "source": [
202 | "def resize(input_image, real_image, height, width):\n",
203 | " input_image = tf.image.resize(input_image, [height, width],\n",
204 | " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
205 | " real_image = tf.image.resize(real_image, [height, width],\n",
206 | " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
207 | "\n",
208 | " return input_image, real_image"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": null,
214 | "metadata": {
215 | "id": "Yn3IwqhiIszt"
216 | },
217 | "outputs": [],
218 | "source": [
219 | "def random_crop(input_image, real_image):\n",
220 | " stacked_image = tf.stack([input_image, real_image], axis=0)\n",
221 | " cropped_image = tf.image.random_crop(\n",
222 | " stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n",
223 | "\n",
224 | " return cropped_image[0], cropped_image[1]"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": null,
230 | "metadata": {
231 | "id": "muhR2cgbLKWW"
232 | },
233 | "outputs": [],
234 | "source": [
235 | "# Normalizes the input images to [-1, 1].\n",
236 | "def normalize(input_image, real_image):\n",
237 | " input_image = (input_image / 127.5) - 1\n",
238 | " real_image = (real_image / 127.5) - 1\n",
239 | "\n",
240 | " return input_image, real_image"
241 | ]
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "metadata": {
246 | "id": "G2RHtcYPj2GC"
247 | },
248 | "source": [
249 | "Random jittering as described in the paper is composed of the following steps:\n",
250 | "1. Resize an image to a bigger height and width\n",
251 | "2. Randomly crop to the target size\n",
252 | "3. Randomly flip the image horizontally"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "metadata": {
259 | "id": "fVQOjcPVLrUc"
260 | },
261 | "outputs": [],
262 | "source": [
263 | "#@title Data augmentation { form-width: \"40%\"}\n",
264 | "@tf.function()\n",
265 | "def random_jitter(input_image, real_image):\n",
266 | " # Resizing to 286 x 286 x 3.\n",
267 | " input_image, real_image = resize(input_image, real_image, 286, 286)\n",
268 | "\n",
269 | " # Randomly cropping to 256 x 256 x 3.\n",
270 | " input_image, real_image = random_crop(input_image, real_image)\n",
271 | "\n",
272 | " if tf.random.uniform(()) > 0.5:\n",
273 | " # Random mirroring.\n",
274 | " input_image = tf.image.flip_left_right(input_image)\n",
275 | " real_image = tf.image.flip_left_right(real_image)\n",
276 | "\n",
277 | " return input_image, real_image"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": null,
283 | "metadata": {
284 | "id": "n0OGdi6D92kM"
285 | },
286 | "outputs": [],
287 | "source": [
288 | "plt.figure(figsize=(6, 6))\n",
289 | "for i in range(4):\n",
290 | " rj_inp, rj_re = random_jitter(inp, re)\n",
291 | " plt.subplot(2, 2, i + 1)\n",
292 | " plt.imshow(rj_inp / 255.0)\n",
293 | " plt.axis('off')\n",
294 | "plt.show()"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": null,
300 | "metadata": {
301 | "id": "tyaP4hLJ8b4W"
302 | },
303 | "outputs": [],
304 | "source": [
305 | "def load_image_train(image_file):\n",
306 | " input_image, real_image = load(image_file)\n",
307 | " input_image, real_image = random_jitter(input_image, real_image)\n",
308 | " input_image, real_image = normalize(input_image, real_image)\n",
309 | "\n",
310 | " return input_image, real_image"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": null,
316 | "metadata": {
317 | "id": "VB3Z6D_zKSru"
318 | },
319 | "outputs": [],
320 | "source": [
321 | "def load_image_test(image_file):\n",
322 | " input_image, real_image = load(image_file)\n",
323 | " input_image, real_image = resize(input_image, real_image,\n",
324 | " IMG_HEIGHT, IMG_WIDTH)\n",
325 | " input_image, real_image = normalize(input_image, real_image)\n",
326 | "\n",
327 | " return input_image, real_image"
328 | ]
329 | },
330 | {
331 | "cell_type": "markdown",
332 | "metadata": {
333 | "id": "PIGN6ouoQxt3"
334 | },
335 | "source": [
336 | "## Input Pipeline"
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "execution_count": null,
342 | "metadata": {
343 | "id": "SQHmYSmk8b4b"
344 | },
345 | "outputs": [],
346 | "source": [
347 | "train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')\n",
348 | "train_dataset = train_dataset.map(load_image_train,\n",
349 | " num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
350 | "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n",
351 | "train_dataset = train_dataset.batch(BATCH_SIZE)"
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "execution_count": null,
357 | "metadata": {
358 | "id": "MS9J0yA58b4g"
359 | },
360 | "outputs": [],
361 | "source": [
362 | "test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')\n",
363 | "test_dataset = test_dataset.map(load_image_test)\n",
364 | "test_dataset = test_dataset.batch(BATCH_SIZE)"
365 | ]
366 | },
367 | {
368 | "cell_type": "markdown",
369 | "metadata": {
370 | "id": "THY-sZMiQ4UV"
371 | },
372 | "source": [
373 | "## Build the Generator\n",
374 | "The architecture of the generator is a modified U-Net\n",
375 | " * Each block in the encoder is (Conv -> Batchnorm -> Leaky ReLU)\n",
376 | " * Each block in the decoder is (Transposed Conv -> Batchnorm -> Dropout (applied to the first 3 blocks) -> ReLU)\n",
377 | " * There are skip connections between the encoder and decoder (as in U-Net)"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": null,
383 | "metadata": {
384 | "id": "3R09ATE_SH9P"
385 | },
386 | "outputs": [],
387 | "source": [
388 | "#@title Encoder definition (Conv -> Batchnorm -> Leaky ReLU) { form-width: \"40%\" }\n",
389 | "class Encoder(hk.Module):\n",
390 | " def __init__(self,\n",
391 | " channels: int,\n",
392 | " size: int,\n",
393 | " apply_batchnorm=True):\n",
394 | " super().__init__()\n",
395 | " self.channels = channels\n",
396 | " self.size = size\n",
397 | " self.initializer = hk.initializers.RandomNormal(mean=0.0, stddev=0.02)\n",
398 | " self.apply_batchnorm = apply_batchnorm\n",
399 | "\n",
400 | " def __call__(self, inputs, is_training):\n",
401 | " ################\n",
402 | " # YOUR CODE HERE\n",
403 | " # Encoder steps:\n",
404 | " # 1. conv layer (channels, size, stride=2, init, pad='SAME', nobias)\n",
405 | " # 2. batch_norm\n",
406 |         "    # 3. leakyReLU (negative_slope=0.2)\n",
407 | "\n",
408 | " # YOUR CODE HERE out = hk.Conv2D(...\n",
409 | " \n",
410 | " if self.apply_batchnorm:\n",
411 | " # YOUR CODE HERE bn = hk.BatchNorm(... \n",
412 | " # YOUR CODE HERE out = ...\n",
413 | " \n",
414 | " # YOUR CODE HERE out = ...\n",
415 | " return out"
416 | ]
417 | },
418 | {
419 | "cell_type": "code",
420 | "execution_count": null,
421 | "metadata": {
422 | "id": "nhgDsHClSQzP"
423 | },
424 | "outputs": [],
425 | "source": [
426 | "#@title Decoder definition (Transposed Conv -> Batchnorm -> Dropout (applied to the first 3 blocks) -> ReLU) { form-width: \"40%\" }\n",
427 | "class Decoder(hk.Module):\n",
428 | " def __init__(self,\n",
429 | " channels: int,\n",
430 | " size: int,\n",
431 | " apply_dropout=False):\n",
432 | " super().__init__()\n",
433 | " self.initializer = hk.initializers.RandomNormal(mean=0.0,\n",
434 | " stddev=0.02)\n",
435 | " self.channels = channels\n",
436 | " self.size = size\n",
437 | " self.apply_dropout = apply_dropout\n",
438 | "\n",
439 | " def __call__(self, inputs, is_training):\n",
440 | " ################\n",
441 | " # YOUR CODE HERE\n",
442 | " # Decoder steps:\n",
443 | " # 1. transpose conv layer (channels, size, stride=2, init, pad='SAME', nobias)\n",
444 | " # 2. batch_norm\n",
445 | " # 3. dropout\n",
446 | " # 4. ReLU\n",
447 | "\n",
448 | " # YOUR CODE HERE out = hk.Conv2DTranspose(...\n",
449 | " # YOUR CODE HERE out = hk.BatchNorm(...\n",
450 | " \n",
451 | " if self.apply_dropout and is_training:\n",
452 | " # Apply 0.5 probability dropout to out\n",
453 | " out = hk.dropout(rng, rate=0.5, x=out)\n",
454 | "\n",
455 | " # relu activation\n",
456 | " # YOUR CODE HERE out = ...\n",
457 | " \n",
458 | " return out"
459 | ]
460 | },
461 | {
462 | "cell_type": "code",
463 | "execution_count": null,
464 | "metadata": {
465 | "id": "lFPI4Nu-8b4q"
466 | },
467 | "outputs": [],
468 | "source": [
469 | "class Generator(hk.Module):\n",
470 | " def __init__(self):\n",
471 | " super().__init__()\n",
472 |         "    # The output size of each block is noted in the comments; `bs` is the batch size.\n",
473 | " self.down_stack = [\n",
474 | " Encoder(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)\n",
475 | " Encoder(128, 4), # (bs, 64, 64, 128)\n",
476 | " Encoder(256, 4), # (bs, 32, 32, 256)\n",
477 | " Encoder(512, 4), # (bs, 16, 16, 512)\n",
478 | " Encoder(512, 4), # (bs, 8, 8, 512)\n",
479 | " Encoder(512, 4), # (bs, 4, 4, 512)\n",
480 | " Encoder(512, 4), # (bs, 2, 2, 512)\n",
481 | " Encoder(512, 4), # (bs, 1, 1, 512)\n",
482 | " ]\n",
483 | "\n",
484 | " self.up_stack = [\n",
485 | " Decoder(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)\n",
486 | " Decoder(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)\n",
487 | " Decoder(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)\n",
488 | " Decoder(512, 4), # (bs, 16, 16, 1024)\n",
489 | " Decoder(256, 4), # (bs, 32, 32, 512)\n",
490 | " Decoder(128, 4), # (bs, 64, 64, 256)\n",
491 | " Decoder(64, 4), # (bs, 128, 128, 128)\n",
492 | " ]\n",
493 | "\n",
494 | " initializer = hk.initializers.RandomNormal(mean=0.0, stddev=0.02)\n",
495 | " self.last = hk.Conv2DTranspose(3, 4,\n",
496 | " stride=2,\n",
497 | " padding='SAME',\n",
498 | " w_init=initializer) # (bs, 256, 256, 3)\n",
499 | "\n",
500 | " def __call__(self, x, is_training):\n",
501 | " ################\n",
502 | " # YOUR CODE HERE\n",
503 | "\n",
504 | " # Downsampling through the model\n",
505 | " # YOUR CODE HERE ...\n",
506 | "\n",
507 | " # Upsampling and establishing the skip connections\n",
508 | " # YOUR CODE HERE ...\n",
509 | "\n",
510 | " x = self.last(x)\n",
511 | " return x"
512 | ]
513 | },
514 | {
515 | "cell_type": "markdown",
516 | "metadata": {
517 | "id": "dpDPEQXIAiQO"
518 | },
519 | "source": [
520 | "### Generator loss\n",
521 |         " * It is a sigmoid cross entropy loss between the discriminator's output for the generated images and an **array of ones**.\n",
522 | " * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.\n",
523 | " * This allows the generated image to become structurally similar to the target image.\n",
524 | " * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004)."
525 | ]
526 | },
527 | {
528 | "cell_type": "code",
529 | "execution_count": null,
530 | "metadata": {
531 | "id": "XuzZAKFjqzqA"
532 | },
533 | "outputs": [],
534 | "source": [
535 | "# Computes binary cross entropy for classification.\n",
536 | "def bce_w_logits(\n",
537 | " logits: jnp.ndarray,\n",
538 | " target: jnp.ndarray\n",
539 | ") -> jnp.ndarray:\n",
540 | " \"\"\"\n",
541 | " Binary Cross Entropy Loss\n",
542 | " :param logits: Input tensor\n",
543 | " :param target: Target tensor\n",
544 | " :return: Scalar value\n",
545 | " \"\"\"\n",
546 | " max_val = jnp.clip(logits, 0, None)\n",
547 | " loss = logits - logits * target + max_val + \\\n",
548 | " jnp.log(jnp.exp(-max_val) + jnp.exp((-logits - max_val)))\n",
549 | "\n",
550 | " return jnp.mean(loss)"
551 | ]
552 | },
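The expression in `bce_w_logits` above is the numerically stable log-sum-exp form of sigmoid cross-entropy. A quick plain-Python check (scalar re-implementation for illustration, using only the standard `math` module) that it matches the textbook form `-(t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x)))` on moderate logits, while also surviving logits where the naive form would overflow:

```python
import math

def stable_bce(x, t):
    # Same algebra as bce_w_logits, for a single scalar logit x and target t.
    m = max(x, 0.0)
    return x - x * t + m + math.log(math.exp(-m) + math.exp(-x - m))

def naive_bce(x, t):
    # Textbook sigmoid cross-entropy; fails for large |x|.
    s = 1.0 / (1.0 + math.exp(-x))
    return -(t * math.log(s) + (1 - t) * math.log(1 - s))

# Both forms agree on moderate logits.
for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    for t in (0.0, 1.0):
        assert math.isclose(stable_bce(x, t), naive_bce(x, t), rel_tol=1e-9)

# The stable form still works where naive_bce would hit log(0):
print(round(stable_bce(500.0, 0.0), 6))  # 500.0
```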
553 | {
554 | "cell_type": "code",
555 | "execution_count": null,
556 | "metadata": {
557 | "id": "90BIcCKcDMxz"
558 | },
559 | "outputs": [],
560 | "source": [
561 | "def generator_loss(\n",
562 | " disc_generated_output: jnp.ndarray,\n",
563 | " gen_output: jnp.ndarray,\n",
564 | " target: jnp.ndarray\n",
565 | ") -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:\n",
566 | " \"\"\"Computes the generator loss for the given batch.\"\"\"\n",
567 | " gan_loss = bce_w_logits(disc_generated_output,\n",
568 | " jnp.ones_like(disc_generated_output))\n",
569 | "\n",
570 | " # Mean absolute error.\n",
571 | " l1_loss = jnp.mean(jnp.abs(target - gen_output))\n",
572 | " total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n",
573 | "\n",
574 | " return total_gen_loss, gan_loss, l1_loss"
575 | ]
576 | },
577 | {
578 | "cell_type": "markdown",
579 | "metadata": {
580 | "id": "TlB-XMY5Awj9"
581 | },
582 | "source": [
583 | "\n"
584 | ]
585 | },
586 | {
587 | "cell_type": "markdown",
588 | "metadata": {
589 | "id": "ZTKZfoaoEF22"
590 | },
591 | "source": [
592 |     "## Build the Discriminator\n",
593 |     "The Discriminator is a PatchGAN.\n",
594 |     "\n",
595 |     " * Each block in the discriminator is (Conv -> BatchNorm -> Leaky ReLU).\n",
596 |     " * The shape of the output after the last layer is (batch_size, 30, 30, 1).\n",
597 |     " * Each unit of the 30x30 output classifies a 70x70 portion of the input image (such an architecture is called a PatchGAN).\n",
598 |     " * The discriminator receives 2 inputs:\n",
599 |     "   * the input image and the target image, which it should classify as real;\n",
600 |     "   * the input image and the generated image (the output of the generator), which it should classify as fake.\n",
601 |     " * In the code we concatenate these 2 inputs along the channel axis (`jax.numpy.concatenate([inp, tar], axis=-1)`)."
602 | ]
603 | },
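The 70x70 figure can be checked by tracing the receptive field through the discriminator's conv stack: three stride-2 Encoders (assuming, as their output shapes indicate, that each Encoder is a single 4x4 stride-2 convolution) followed by two stride-1 4x4 convolutions. A plain-Python sketch of that calculation, together with the VALID-padding output sizes noted in the code below (padding does not change the receptive field size):

```python
# (kernel, stride) for each conv in the discriminator, input to output.
layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]

# Standard receptive-field recurrence: rf += (kernel - 1) * jump; jump *= stride.
rf, jump = 1, 1
for k, s in layers:
    rf += (k - 1) * jump
    jump *= s
print(rf)  # 70 -> each output unit sees a 70x70 input patch

# VALID-padding output size: out = (in - kernel) // stride + 1.
print((34 - 4) // 1 + 1)  # 31: padded 34x34 -> 31x31 after the 512-channel conv
print((33 - 4) // 1 + 1)  # 30: padded 33x33 -> 30x30 after the final conv
```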
604 | {
605 | "cell_type": "code",
606 | "execution_count": null,
607 | "metadata": {
608 | "id": "ll6aNeQx8b4v"
609 | },
610 | "outputs": [],
611 | "source": [
612 | "class Discriminator(hk.Module):\n",
613 | " def __init__(self):\n",
614 | " super().__init__()\n",
615 | " initializer = hk.initializers.RandomNormal(mean=0.0, stddev=0.02)\n",
616 | "\n",
617 | " self.down1 = Encoder(64, 4, apply_batchnorm=False)\n",
618 | " self.down2 = Encoder(128, 4)\n",
619 | " self.down3 = Encoder(256, 4)\n",
620 | "\n",
621 | " self.conv = hk.Conv2D(512, 4, stride=1, w_init=initializer,\n",
622 | " padding='VALID', with_bias=False)\n",
623 | " self.bn = hk.BatchNorm(create_scale=True, create_offset=True,\n",
624 | " decay_rate=0.999, eps=0.001)\n",
625 | " self.last = hk.Conv2D(1, 4, stride=1, padding='VALID',\n",
626 | " w_init=initializer)\n",
627 | "\n",
628 | " def __call__(self, x, is_training): # (bs, 256, 256, channels*2)\n",
629 | " x = self.down1(x, is_training) # (bs, 128, 128, 64)\n",
630 | " x = self.down2(x, is_training) # (bs, 64, 64, 128)\n",
631 | " x = self.down3(x, is_training) # (bs, 32, 32, 256)\n",
632 | " x = jnp.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0))) # (bs, 34, 34, 256)\n",
633 | " x = self.conv(x) # (bs, 31, 31, 512)\n",
634 | " x = self.bn(x, is_training)\n",
635 | " x = jax.nn.leaky_relu(x, negative_slope=0.2)\n",
636 |     "    x = jnp.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0))) # (bs, 33, 33, 512)\n",
637 | " x = self.last(x) # (bs, 30, 30, 1)\n",
638 | " return x"
639 | ]
640 | },
641 | {
642 | "cell_type": "markdown",
643 | "metadata": {
644 | "id": "AOqg1dhUAWoD"
645 | },
646 | "source": [
647 |     "### Discriminator loss\n",
648 |     " * The discriminator loss function takes 2 inputs: the discriminator's outputs on the **real images** and on the **generated images**.\n",
649 |     " * real_loss is a sigmoid cross-entropy between the output on the **real images** and an **array of ones (since these are the real images)**.\n",
650 |     " * generated_loss is a sigmoid cross-entropy between the output on the **generated images** and an **array of zeros (since these are the fake images)**.\n",
651 |     " * The total_loss is the sum of real_loss and generated_loss.\n"
652 | ]
653 | },
654 | {
655 | "cell_type": "code",
656 | "execution_count": null,
657 | "metadata": {
658 | "id": "wkMNfBWlT-PV"
659 | },
660 | "outputs": [],
661 | "source": [
662 | "def discriminator_loss(disc_real_output, disc_generated_output):\n",
663 | " real_loss = bce_w_logits(disc_real_output,\n",
664 | " jnp.ones_like(disc_real_output))\n",
665 | " generated_loss = bce_w_logits(disc_generated_output,\n",
666 | " jnp.zeros_like(disc_generated_output))\n",
667 | " total_disc_loss = real_loss + generated_loss\n",
668 | " return total_disc_loss"
669 | ]
670 | },
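For intuition about the scale of `total_disc_loss`: an uninformative discriminator that outputs logit 0 everywhere assigns probability 0.5 to both classes, so each term is ln 2 and the total is 2·ln 2 ≈ 1.386. A plain-Python check (re-implementing the scalar BCE so the snippet stands alone):

```python
import math

def bce(x, t):
    # Scalar version of bce_w_logits above.
    m = max(x, 0.0)
    return x - x * t + m + math.log(math.exp(-m) + math.exp(-x - m))

# A discriminator that always outputs logit 0 (probability 0.5):
real_loss = bce(0.0, 1.0)       # vs. the "ones" target
generated_loss = bce(0.0, 0.0)  # vs. the "zeros" target
total = real_loss + generated_loss

print(round(total, 4))  # 1.3863, i.e. 2 * ln(2)
```

During training, a `disc_loss` hovering near this value suggests the discriminator cannot separate real from generated images.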
671 | {
672 | "cell_type": "markdown",
673 | "metadata": {
674 | "id": "-ede4p2YELFa"
675 | },
676 | "source": [
677 | "The training procedure for the discriminator is shown below.\n",
678 | "\n",
679 |     "To learn more about the architecture and the hyperparameters you can refer to the [paper](https://arxiv.org/abs/1611.07004)."
680 | ]
681 | },
682 | {
683 | "cell_type": "markdown",
684 | "metadata": {
685 | "id": "IS9sHa-1BoAF"
686 | },
687 | "source": [
688 | "\n"
689 | ]
690 | },
691 | {
692 | "cell_type": "markdown",
693 | "metadata": {
694 | "id": "NLKOG55MErD0"
695 | },
696 | "source": [
697 | "## Training\n",
698 | "\n",
699 |     "* For each example input, the generator produces an output.\n",
700 |     "* The discriminator receives the input image paired with the generated image as its first input, and the input image paired with the target image as its second input.\n",
701 |     "* Next, we calculate the generator and the discriminator losses.\n",
702 |     "* Then, we calculate the gradients of each loss with respect to the generator and discriminator parameters, and apply them with the corresponding optimizers.\n",
703 |     "* Last, we log the losses to TensorBoard."
704 | ]
705 | },
706 | {
707 | "cell_type": "markdown",
708 | "metadata": {
709 | "id": "0FMYgY_mPfTi"
710 | },
711 | "source": [
712 | "### Define the Checkpoint-saver\n"
713 | ]
714 | },
715 | {
716 | "cell_type": "code",
717 | "execution_count": null,
718 | "metadata": {
719 | "id": "7micePl8XVtF"
720 | },
721 | "outputs": [],
722 | "source": [
723 | "checkpoint_dir = './training_checkpoints'\n",
724 | "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
725 | "if not os.path.exists(checkpoint_prefix):\n",
726 | " os.makedirs(checkpoint_prefix)"
727 | ]
728 | },
729 | {
730 | "cell_type": "code",
731 | "execution_count": null,
732 | "metadata": {
733 | "id": "hh0Q_XuscND5"
734 | },
735 | "outputs": [],
736 | "source": [
737 | "class P2PTuple(NamedTuple):\n",
738 | " gen: Any\n",
739 | " disc: Any\n",
740 | "\n",
741 | "\n",
742 | "class P2PState(NamedTuple):\n",
743 | " params: P2PTuple\n",
744 | " states: P2PTuple\n",
745 | " opt_state: P2PTuple\n",
746 | "\n",
747 | "\n",
748 | "class Pix2Pix:\n",
749 | " \"\"\"Pix2Pix model.\"\"\"\n",
750 | "\n",
751 | " def __init__(self):\n",
752 | " self.gen_transform = hk.transform_with_state(\n",
753 | " lambda *args: Generator()(*args)\n",
754 | " )\n",
755 | " self.disc_transform = hk.transform_with_state(\n",
756 | " lambda *args: Discriminator()(*args)\n",
757 | " )\n",
758 | "\n",
759 | " # Build the optimizers.\n",
760 | " self.gen_optimizer = optax.adam(2e-4, b1=0.5, b2=0.999)\n",
761 | " self.disc_optimizer = optax.adam(2e-4, b1=0.5, b2=0.999)\n",
762 | "\n",
763 | " @functools.partial(jax.jit, static_argnums=0)\n",
764 | " def initial_state(self,\n",
765 | " rng: jnp.ndarray,\n",
766 | " batch: Tuple[jnp.ndarray, jnp.ndarray]):\n",
767 |     "    \"\"\"Returns the initial parameters and optimizer states of the model.\n",
768 | " \"\"\"\n",
769 | " rng, gen_rng, disc_rng = jax.random.split(rng, 3)\n",
770 | " gen_params, gen_state = self.gen_transform.init(gen_rng, batch[0], True)\n",
771 | " disc_params, disc_state = \\\n",
772 | " self.disc_transform.init(disc_rng,\n",
773 | " jnp.concatenate(batch, axis=-1),\n",
774 | " True)\n",
775 | " params = P2PTuple(gen=gen_params, disc=disc_params)\n",
776 | " states = P2PTuple(gen=gen_state, disc=disc_state)\n",
777 | "\n",
778 | " # Initialize the optimizers.\n",
779 | " opt_state = P2PTuple(gen=self.gen_optimizer.init(params.gen),\n",
780 | " disc=self.disc_optimizer.init(params.disc)\n",
781 | " )\n",
782 | " return P2PState(params=params, states=states, opt_state=opt_state)\n",
783 | "\n",
784 | " def generate_images(self,\n",
785 | " params: P2PTuple,\n",
786 | " state: P2PTuple,\n",
787 | " test_input):\n",
788 |     "    # Note: passing `is_training=True` is intentional here, since\n",
789 |     "    # we want the batch statistics of the test batch while running\n",
790 |     "    # the model on the test dataset. With `is_training=False` we\n",
791 |     "    # would get the accumulated statistics learned from the\n",
792 |     "    # training dataset (which we don't want).\n",
793 | " prediction, _ = self.gen_transform.apply(\n",
794 | " params, state, None, test_input, True\n",
795 | " )\n",
796 | "\n",
797 | " return prediction\n",
798 | "\n",
799 | " def gen_loss(self,\n",
800 | " gen_params: P2PTuple,\n",
801 | " gen_state: P2PTuple,\n",
802 | " batch: Tuple[jnp.ndarray, jnp.ndarray],\n",
803 | " disc_params: P2PTuple,\n",
804 | " disc_state: P2PTuple,\n",
805 | " rng_gen, rng_disc):\n",
806 |     "    \"\"\"Computes the generator loss (GAN + L1) for the given batch.\"\"\"\n",
807 | "\n",
808 | " input, target = batch\n",
809 | " \n",
810 | " output, gen_state = self.gen_transform.apply(\n",
811 | " gen_params, gen_state, rng_gen, input, True\n",
812 | " )\n",
813 | "\n",
814 | " # Evaluate using the discriminator.\n",
815 | " disc_generated_output, disc_state = self.disc_transform.apply(\n",
816 | " disc_params, disc_state, rng_disc,\n",
817 | " jnp.concatenate([input, output], axis=-1), True\n",
818 | " )\n",
819 | "\n",
820 | " states = P2PTuple(gen=gen_state, disc=disc_state)\n",
821 | "\n",
822 |     "    # Compute the generator loss.\n",
823 | " total_loss, gan_loss, l1_loss = generator_loss(\n",
824 | " disc_generated_output, output, target\n",
825 | " )\n",
826 | "\n",
827 | " return total_loss, (output, states, gan_loss, l1_loss)\n",
828 | "\n",
829 | " def disc_loss(self,\n",
830 | " params: P2PTuple,\n",
831 | " state: P2PTuple,\n",
832 | " batch: Tuple[jnp.ndarray, jnp.ndarray],\n",
833 | " gen_output: jnp.ndarray, rng):\n",
834 |     "    \"\"\"Computes the discriminator loss for the given batch.\"\"\"\n",
835 | " input, target = batch\n",
836 | " real_output, state = self.disc_transform.apply(\n",
837 | " params, state, rng, jnp.concatenate([input, target], axis=-1), True\n",
838 | " )\n",
839 | "\n",
840 | " generated_output, state = self.disc_transform.apply(\n",
841 | " params, state, rng,\n",
842 | " jnp.concatenate([input, gen_output], axis=-1), True\n",
843 | " )\n",
844 | "\n",
845 | " # Compute discriminator loss.\n",
846 | " loss = discriminator_loss(real_output, generated_output)\n",
847 | " return loss, state\n",
848 | "\n",
849 | " @functools.partial(jax.jit, static_argnums=0)\n",
850 | " def update(self, rng, p2p_state, batch):\n",
851 | " \"\"\" Performs a parameter update. \"\"\"\n",
852 | " rng, gen_rng, disc_rng = jax.random.split(rng, 3)\n",
853 | "\n",
854 | " # Update the generator.\n",
855 | " (gen_loss, gen_aux), gen_grads = \\\n",
856 | " jax.value_and_grad(self.gen_loss,\n",
857 | " has_aux=True)(\n",
858 | " p2p_state.params.gen,\n",
859 | " p2p_state.states.gen,\n",
860 | " batch,\n",
861 | " p2p_state.params.disc,\n",
862 | " p2p_state.states.disc,\n",
863 | " gen_rng, disc_rng)\n",
864 | "\n",
865 | " generated_output, states, gan_loss, l1_loss = gen_aux\n",
866 | " gen_update, gen_opt_state = self.gen_optimizer.update(\n",
867 | " gen_grads, p2p_state.opt_state.gen)\n",
868 | " gen_params = optax.apply_updates(p2p_state.params.gen, gen_update)\n",
869 | "\n",
870 | " # Update the discriminator.\n",
871 | " (disc_loss, disc_state), disc_grads = \\\n",
872 | " jax.value_and_grad(self.disc_loss,\n",
873 | " has_aux=True)(\n",
874 | " p2p_state.params.disc,\n",
875 | " states.disc,\n",
876 | " batch,\n",
877 | " generated_output,\n",
878 | " disc_rng)\n",
879 | "\n",
880 | " disc_update, disc_opt_state = self.disc_optimizer.update(\n",
881 | " disc_grads, p2p_state.opt_state.disc)\n",
882 | " disc_params = optax.apply_updates(p2p_state.params.disc, disc_update)\n",
883 | "\n",
884 | " params = P2PTuple(gen=gen_params, disc=disc_params)\n",
885 | " states = P2PTuple(gen=states.gen, disc=disc_state)\n",
886 | " opt_state = P2PTuple(gen=gen_opt_state, disc=disc_opt_state)\n",
887 | " p2p_state = P2PState(params=params, states=states, opt_state=opt_state)\n",
888 | "\n",
889 | " return p2p_state, gen_loss, disc_loss, gan_loss, l1_loss"
890 | ]
891 | },
892 | {
893 | "cell_type": "code",
894 | "execution_count": null,
895 | "metadata": {
896 | "id": "lYVSnyvwjCUb"
897 | },
898 | "outputs": [],
899 | "source": [
900 | "# The model.\n",
901 | "net = Pix2Pix()\n",
902 | "\n",
903 | "# Initialize the network and optimizer.\n",
904 | "for input, target in train_dataset.take(1):\n",
905 | " net_state = net.initial_state(rng, (jnp.asarray(input),\n",
906 | " jnp.asarray(target)))"
907 | ]
908 | },
909 | {
910 | "cell_type": "code",
911 | "execution_count": null,
912 | "metadata": {
913 | "id": "xNNMDBNH12q-"
914 | },
915 | "outputs": [],
916 | "source": [
917 | "import datetime\n",
918 | "log_dir = \"logs/\"\n",
919 | "\n",
920 | "summary_writer = tf.summary.create_file_writer(\n",
921 | " log_dir + \"fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))"
922 | ]
923 | },
924 | {
925 | "cell_type": "markdown",
926 | "metadata": {
927 | "id": "hx7s-vBHFKdh"
928 | },
929 | "source": [
930 | "The actual training loop:\n",
931 | "\n",
932 | "* Iterates over the number of epochs.\n",
933 |     "* On each epoch it clears the display and runs `generate_images` to show its progress.\n",
934 | "* On each epoch it iterates over the training dataset, printing a '.' for each example.\n",
935 | "* It saves a checkpoint every 20 epochs."
936 | ]
937 | },
938 | {
939 | "cell_type": "code",
940 | "execution_count": null,
941 | "metadata": {
942 | "id": "2M7LmLtGEMQJ"
943 | },
944 | "outputs": [],
945 | "source": [
946 | "def fit(train_ds, epochs, test_ds, net_state):\n",
947 | " for epoch in range(epochs):\n",
948 | " start = time.time()\n",
949 | "\n",
950 | " display.clear_output(wait=True)\n",
951 | "\n",
952 | " for example_input, example_target in test_ds.take(1):\n",
953 | " prediction = net.generate_images(net_state.params.gen,\n",
954 | " net_state.states.gen,\n",
955 | " jnp.asarray(example_input))\n",
956 | " plt.figure(figsize=(15, 15))\n",
957 | "\n",
958 | " display_list = [example_input[0], example_target[0], prediction[0]]\n",
959 | " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n",
960 | "\n",
961 | " for i in range(3):\n",
962 | " plt.subplot(1, 3, i+1)\n",
963 | " plt.title(title[i])\n",
964 |     "      # Scale the pixel values from [-1, 1] to [0, 1] for plotting.\n",
965 | " plt.imshow(display_list[i] * 0.5 + 0.5)\n",
966 | " plt.axis('off')\n",
967 | " plt.show()\n",
968 | "\n",
969 | " print(\"Epoch: \", epoch)\n",
970 | "\n",
971 | " # Train loop.\n",
972 | " for n, (input_image, target) in train_ds.enumerate():\n",
973 | " # Take a training step.\n",
974 | " print('.', end='')\n",
975 | " if (n+1) % 100 == 0:\n",
976 | " print()\n",
977 | "\n",
978 | " net_state, gen_total_loss, disc_loss, \\\n",
979 | " gen_gan_loss, gen_l1_loss = \\\n",
980 | " net.update(rng, net_state,\n",
981 | " (jnp.asarray(input_image), jnp.asarray(target)))\n",
982 | "\n",
983 | " with summary_writer.as_default():\n",
984 | " tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)\n",
985 | " tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)\n",
986 | " tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)\n",
987 | " tf.summary.scalar('disc_loss', disc_loss, step=epoch)\n",
988 | " \n",
989 | " print()\n",
990 | "\n",
991 | " # Save (checkpoint) the model every 20 epochs.\n",
992 | " if (epoch + 1) % 20 == 0:\n",
993 | " with open(\n",
994 | " os.path.join(checkpoint_prefix, 'pix2pix_params.pkl'),\n",
995 | " 'wb') as handle:\n",
996 | " pickle.dump(net_state.params, handle,\n",
997 | " protocol=pickle.HIGHEST_PROTOCOL)\n",
998 | "\n",
999 | " with open(\n",
1000 | " os.path.join(checkpoint_prefix, 'pix2pix_states.pkl'),\n",
1001 | " 'wb') as handle:\n",
1002 | " pickle.dump(net_state.states, handle,\n",
1003 | " protocol=pickle.HIGHEST_PROTOCOL)\n",
1004 | "\n",
1005 | " print('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n",
1006 | " time.time()-start))\n",
1007 | "\n",
1008 | " # Save the last checkpoint.\n",
1009 | " with open(\n",
1010 | " os.path.join(checkpoint_prefix, 'pix2pix_params.pkl'),\n",
1011 | " 'wb') as handle:\n",
1012 | " pickle.dump(net_state.params, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
1013 | "\n",
1014 | " with open(\n",
1015 | " os.path.join(checkpoint_prefix, 'pix2pix_states.pkl'),\n",
1016 | " 'wb') as handle:\n",
1017 | " pickle.dump(net_state.states, handle, protocol=pickle.HIGHEST_PROTOCOL)"
1018 | ]
1019 | },
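The `fit` loop above only writes the pickled checkpoints; restoring them is the mirror image with `pickle.load`. A minimal sketch of the round trip, with two caveats: the helper names here are made up for illustration, and the optimizer state is not checkpointed above, so resuming training would require re-initializing the optimizers. The demo uses plain dicts as a stand-in for Haiku's dict-like parameter pytrees:

```python
import os
import pickle
import tempfile

def save_pytree(path, tree):
    # Same pickling convention as the checkpoints written in fit().
    with open(path, 'wb') as f:
        pickle.dump(tree, f, protocol=pickle.HIGHEST_PROTOCOL)

def load_pytree(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

# Round-trip demo with a stand-in for net_state.params.
with tempfile.TemporaryDirectory() as ckpt_dir:
    fake_params = {'generator': {'w': [1.0, 2.0]}, 'discriminator': {'w': [3.0]}}
    path = os.path.join(ckpt_dir, 'pix2pix_params.pkl')
    save_pytree(path, fake_params)
    restored = load_pytree(path)
    assert restored == fake_params
```

In the notebook you would load `pix2pix_params.pkl` and `pix2pix_states.pkl` from `checkpoint_prefix` the same way, then rebuild a `P2PState` with freshly initialized optimizer states.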
1020 | {
1021 | "cell_type": "markdown",
1022 | "metadata": {
1023 | "id": "wozqyTh2wmCu"
1024 | },
1025 | "source": [
1026 |     "This training loop saves logs you can easily view in TensorBoard to monitor the training progress. Working locally, you would launch a separate TensorBoard process. In a notebook, if you want to monitor with TensorBoard it's easiest to launch the viewer before starting the training.\n",
1027 | "\n",
1028 | "To launch the viewer run the following cell:"
1029 | ]
1030 | },
1031 | {
1032 | "cell_type": "code",
1033 | "execution_count": null,
1034 | "metadata": {
1035 | "id": "Ot22ujrlLhOd"
1036 | },
1037 | "outputs": [],
1038 | "source": [
1039 | "#docs_infra: no_execute\n",
1040 | "%load_ext tensorboard\n",
1041 | "%tensorboard --logdir {log_dir}"
1042 | ]
1043 | },
1044 | {
1045 | "cell_type": "markdown",
1046 | "metadata": {
1047 | "id": "Pe0-8Bzg22ox"
1048 | },
1049 | "source": [
1050 | "Now run the training loop:"
1051 | ]
1052 | },
1053 | {
1054 | "cell_type": "code",
1055 | "execution_count": null,
1056 | "metadata": {
1057 | "id": "a1zZmKmvOH85"
1058 | },
1059 | "outputs": [],
1060 | "source": [
1061 | "fit(train_dataset, EPOCHS, test_dataset, net_state)"
1062 | ]
1063 | },
1064 | {
1065 | "cell_type": "markdown",
1066 | "metadata": {
1067 | "id": "oeq9sByu86-B"
1068 | },
1069 | "source": [
1070 | "If you want to share the TensorBoard results _publicly_ you can upload the logs to [TensorBoard.dev](https://tensorboard.dev/) by copying the following into a code-cell.\n",
1071 | "\n",
1072 | "Note: This requires a Google account.\n",
1073 | "\n",
1074 | "```\n",
1075 | "!tensorboard dev upload --logdir {log_dir}\n",
1076 | "```"
1077 | ]
1078 | },
1079 | {
1080 | "cell_type": "markdown",
1081 | "metadata": {
1082 | "id": "l-kT7WHRKz-E"
1083 | },
1084 | "source": [
1085 | "Caution: This command does not terminate. It's designed to continuously upload the results of long-running experiments. Once your data is uploaded you need to stop it using the \"interrupt execution\" option in your notebook tool."
1086 | ]
1087 | },
1088 | {
1089 | "cell_type": "markdown",
1090 | "metadata": {
1091 | "id": "-lGhS_LfwQoL"
1092 | },
1093 | "source": [
1094 | "You can view the [results of a previous run](https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw) of this notebook on [TensorBoard.dev](https://tensorboard.dev/).\n",
1095 | "\n",
1096 | "TensorBoard.dev is a managed experience for hosting, tracking, and sharing ML experiments with everyone.\n",
1097 | "\n",
1098 |     "It can also be included inline using an `