├── readme_pics
│   ├── jax_logo.png
│   ├── going_deeper.jpg
│   ├── tpu_jit_speed.PNG
│   └── api_onion_structure.jpg
├── .gitignore
├── LICENCE
├── README.md
├── Tutorial_3_JAX_Neural_Network_from_Scratch_Colab.ipynb
└── Tutorial_4_Flax_Zero2Hero_Colab.ipynb
/readme_pics/jax_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/get-started-with-JAX/HEAD/readme_pics/jax_logo.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # PyCharm IDE
2 | .idea
3 | __pycache__
4 |
5 | # Jupyter notebook checkpoints
6 | .ipynb_checkpoints
7 |
8 |
--------------------------------------------------------------------------------
/readme_pics/going_deeper.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/get-started-with-JAX/HEAD/readme_pics/going_deeper.jpg
--------------------------------------------------------------------------------
/readme_pics/tpu_jit_speed.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/get-started-with-JAX/HEAD/readme_pics/tpu_jit_speed.PNG
--------------------------------------------------------------------------------
/readme_pics/api_onion_structure.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/get-started-with-JAX/HEAD/readme_pics/api_onion_structure.jpg
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Aleksa Gordić
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Get started with JAX! :computer: :zap:
2 |
3 | The goal of this repo is to make it easier to get started with [JAX](https://github.com/google/jax), [Flax](https://github.com/google/flax), and [Haiku](https://github.com/deepmind/dm-haiku)!
4 |
5 | The `JAX` ecosystem is becoming an increasingly popular alternative to `PyTorch` and `TensorFlow`. :sunglasses:
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | *Note: I'm only going to recommend content here that I've personally analyzed and found useful.
18 | If you want a comprehensive list, check out the [awesome-jax repo](https://github.com/n2cholas/awesome-jax).*
19 |
20 | ## Table of Contents
21 | * [Machine Learning with JAX](#my-machine-learning-with-jax-tutorials)
22 |   + [Tutorial #1: From Zero to Hero](#tutorial-1-from-zero-to-hero)
23 |   + [Tutorial #2: From Hero to Hero Pro+](#tutorial-2-from-hero-to-heropro)
24 |   + [Tutorial #3: Coding a Neural Network from Scratch in Pure JAX](#tutorial-3-building-a-neural-network-from-scratch)
25 |   + [Tutorial #4: Flax From Zero to Hero](#tutorial-4-machine-learning-with-flax---from-zero-to-hero)
26 |   + [Tutorial #5: Haiku From Zero to Hero (coming soon)](#tutorial-5-coming-up-machine-learning-with-haiku---from-zero-to-hero)
27 | * [Other useful JAX resources](#other-useful-content)
28 |
29 | ## My Machine Learning with JAX Tutorials
30 |
31 | *Tip on how to use the notebooks: just open a notebook directly in Google Colab
32 | (you'll see a button at the top of the Jupyter file that will direct you to Colab).
33 | This way you can avoid having to set up a Python environment! (This was especially convenient for me since I'm on Windows, which JAX doesn't support yet.)*
34 |
35 | ### Tutorial #1: From Zero to Hero
36 |
37 | In this video, we start from the basics and then gradually dig into the nitty-gritty details
38 | of `jit`, `grad`, `vmap`, and various other idiosyncrasies of JAX.
39 |
40 | [YouTube Video (Tutorial #1)](https://youtu.be/SstuvS-tVc0)
41 | [Accompanying Jupyter Notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_1_JAX_Zero2Hero_Colab.ipynb)
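As a rough taste of how these three transformations compose, here is a minimal sketch (not code from the notebook; the toy linear-model `loss` is an assumption for illustration): `grad` differentiates with respect to the first argument, `vmap` turns the per-example gradient into a batched one, and `jit` compiles the composed function with XLA.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # squared error of a toy linear model on a single example
    return (jnp.dot(w, x) - y) ** 2

# grad differentiates w.r.t. the first argument (w) by default
grad_loss = jax.grad(loss)

# vmap vectorizes over a batch: share w (None), map over rows of x and y (axis 0)
batched_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit compiles the whole composed function with XLA
fast_batched_grads = jax.jit(batched_grads)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

# per-example gradients 2 * (w.x - y) * x  ->  [[2, 0], [0, 4], [6, 6]]
g = fast_batched_grads(w, xs, ys)
```

The order of composition is a free choice: `jit` could wrap `grad_loss` before `vmap` is applied and the result would be the same.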
42 |
43 |
44 |
46 |
47 |
48 | ### Tutorial #2: From Hero to HeroPro+
49 |
50 | In this video, we learn all the additional components needed to train ML models (such as NNs) on multiple machines!
51 | We'll train a simple MLP model, and we'll even train an ML model on 8 TPU cores!
52 |
53 | [YouTube Video (Tutorial #2)](https://www.youtube.com/watch?v=CQQaifxuFcs)
54 | [Accompanying Jupyter Notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_2_JAX_HeroPro%2B_Colab.ipynb)
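Multi-device training in JAX typically combines `jax.pmap` with a collective such as `jax.lax.pmean`. A minimal sketch, assuming only that the leading axis of the input equals `jax.local_device_count()` (the function name `mean_of_shard_sums` and the axis name `'batch'` are illustrative choices; on a single CPU device the code still runs, with one shard):

```python
import jax
import jax.numpy as jnp

# e.g. 8 on the TPU setup used in the video, 1 on a plain CPU runtime
n_devices = jax.local_device_count()

def mean_of_shard_sums(v):
    # each device sums its own shard, then pmean all-reduces the result across devices;
    # the same pattern is used to average gradients in data-parallel training
    return jax.lax.pmean(jnp.sum(v), axis_name='batch')

# pmap splits the leading axis across devices, so it must equal n_devices
x = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)
out = jax.pmap(mean_of_shard_sums, axis_name='batch')(x)
# every device ends up holding the same value: x.sum() / n_devices
```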
55 |
56 |
57 |
59 |
60 |
61 | ### Tutorial #3: Building a Neural Network from Scratch
62 |
63 | Watch me code a Neural Network from scratch! :partying_face: This is the 3rd video of the JAX tutorials series.
64 |
65 | In this video, I build an [MLP](https://en.wikipedia.org/wiki/Multilayer_perceptron) and train it as a classifier on MNIST
66 | using PyTorch's data loader (although it's trivial to use a more complex dataset) - all this in "pure" JAX (no Flax/Haiku/Optax).
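The core of a "pure" JAX training loop fits in a few functions. This is a minimal sketch under stated assumptions, not the notebook's exact code: the names `init_mlp`, `predict`, and `sgd_step` are hypothetical, random data stands in for MNIST, and the loss sums the cross-entropy over classes before averaging over the batch (the notebook's `loss_fn` averages over both axes, which only rescales the loss and gradients by 1/10).

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def init_mlp(layer_widths, key, scale=0.01):
    # one (weight, bias) pair per layer, drawn from a scaled normal distribution
    params = []
    keys = jax.random.split(key, len(layer_widths) - 1)
    for in_w, out_w, k in zip(layer_widths[:-1], layer_widths[1:], keys):
        wk, bk = jax.random.split(k)
        params.append((scale * jax.random.normal(wk, (out_w, in_w)),
                       scale * jax.random.normal(bk, (out_w,))))
    return params

def predict(params, x):
    # ReLU hidden layers, then log-softmax over the final logits
    for w, b in params[:-1]:
        x = jax.nn.relu(w @ x + b)
    w, b = params[-1]
    logits = w @ x + b
    return logits - logsumexp(logits)

batched_predict = jax.vmap(predict, in_axes=(None, 0))

def loss_fn(params, imgs, one_hot_lbls):
    # mean (over the batch) cross-entropy
    return -jnp.mean(jnp.sum(batched_predict(params, imgs) * one_hot_lbls, axis=1))

@jax.jit
def sgd_step(params, imgs, one_hot_lbls, lr=0.01):
    loss, grads = jax.value_and_grad(loss_fn)(params, imgs, one_hot_lbls)
    # plain SGD applied to every leaf of the parameter pytree
    return loss, jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# usage on random data, just to exercise the code (not MNIST)
params = init_mlp([784, 32, 10], jax.random.PRNGKey(0))
imgs = jax.random.normal(jax.random.PRNGKey(1), (16, 784))
one_hot = jax.nn.one_hot(jnp.arange(16) % 10, 10)
losses = []
for _ in range(50):
    loss, params = sgd_step(params, imgs, one_hot)
    losses.append(float(loss))
```

With such a small initialization scale the first loss sits close to log(10), the value of a uniform prediction over 10 classes.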
67 |
68 | I then do an additional analysis:
69 | * Visualize MLP's learned weights
70 | * Visualize embeddings of a batch of images using t-SNE
71 | * Finally, I analyze whether we have too many dead ReLU neurons in our network
72 |
73 | [YouTube Video (Tutorial #3)](https://www.youtube.com/watch?v=6_PqUPxRmjY)
74 | [Accompanying Jupyter Notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_3_JAX_Neural_Network_from_Scratch_Colab.ipynb) (Note: I'll soon refactor it but I'll link the original)
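The dead-ReLU check can be sketched as follows, using the same criterion as the notebook (a neuron counts as dead on a batch if its activation is zero for every example in it). The helper names and the hand-made toy parameters are hypothetical, chosen so that exactly one hidden neuron can never fire:

```python
import jax
import jax.numpy as jnp

def hidden_activations(params, x):
    # collect the post-ReLU activations of every hidden layer for one example
    acts = []
    for w, b in params[:-1]:
        x = jax.nn.relu(w @ x + b)
        acts.append(x)
    return acts

batched_hidden_activations = jax.vmap(hidden_activations, in_axes=(None, 0))

def count_dead_neurons(params, imgs):
    # list of (batch, width) arrays, one per hidden layer
    acts = batched_hidden_activations(params, imgs)
    # a neuron is dead on this batch if its output is 0 for every example
    return [int(jnp.sum(jnp.all(a == 0, axis=0))) for a in acts]

# toy check: hidden neuron 0 has zero weights and a negative bias, so relu(-1) = 0 always
params = [
    (jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]), jnp.array([-1.0, 1.0, 1.0])),
    (jnp.ones((2, 3)), jnp.zeros(2)),
]
imgs = jnp.ones((4, 2))
dead = count_dead_neurons(params, imgs)  # [1]
```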
75 |
76 |
77 |
79 |
80 |
81 | ---
82 |
83 | ### Tutorial #4: Machine Learning with Flax - From Zero to Hero
84 |
85 | In this video, I cover everything you need to know to get started with [Flax](https://github.com/google/flax)!
86 |
87 | We cover `init`, `apply`, `TrainState`, etc., and other idiosyncrasies such as the `mutable` and `rngs` keyword arguments.
88 |
89 | [YouTube Video (Tutorial #4)](https://www.youtube.com/watch?v=5eUSmJvK8WA)
90 | [Accompanying Jupyter Notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_4_Flax_Zero2Hero_Colab.ipynb)
91 |
92 |
93 |
95 |
96 |
97 | ---
98 |
99 | ### Tutorial #5 (coming up): Machine Learning with Haiku - From Zero to Hero
100 |
101 | todo
102 |
103 | ## Other useful content
104 |
105 | Aside from the [official docs](https://jax.readthedocs.io/), here are some resources that helped me.
106 |
107 | ### Videos
108 |
109 | * [Introduction to JAX](https://www.youtube.com/watch?v=0mVmRHMaOJ4&ab_channel=GoogleCloudTech) (gives a very high-level overview)
110 | * [JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas](https://www.youtube.com/watch?v=z-WSrQDXkuM&ab_channel=Enthought) (many more details)
111 | * [NeurIPS 2020: JAX Ecosystem Meetup](https://www.youtube.com/watch?v=iDxJxIyzSiM&t=1s&ab_channel=DeepMind) (the DeepMind team on the ecosystem of libraries built around JAX)
112 | * [Introduction to JAX for Machine Learning and More](https://www.youtube.com/watch?v=QkmKfzxbCLQ&ab_channel=UWaterlooDataScience) (nice, hands-on workshop)
113 | * [Day 1 Talks: JAX, Flax & Transformers | HuggingFace](https://www.youtube.com/watch?v=fuAyUQcVzTY&ab_channel=HuggingFace) (all 4 talks are good)
114 | * [Day 2 Talks: JAX, Flax & Transformers | HuggingFace](https://www.youtube.com/watch?v=__eG63ZP_5g&ab_channel=HuggingFace) (only the first 2 talks are relevant)
115 |
116 | ### Blogs
117 |
118 | * [Using JAX to accelerate our research | DeepMind](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) (similar information to the NeurIPS 2020 video)
119 | * [You don't know JAX | Colin Raffel](https://colinraffel.com/blog/you-don-t-know-jax.html)
120 |
121 | ## Acknowledgements
122 |
123 | * The notebooks were heavily inspired by the official [JAX](https://jax.readthedocs.io/), [Flax](https://flax.readthedocs.io/en/latest/), and [Haiku](https://dm-haiku.readthedocs.io/en/latest/) docs.
124 |
125 | ## Citation
126 |
127 | If you find this content useful, please cite the following:
128 |
129 | ```
130 | @misc{Gordic2021GetStartedWithJAX,
131 | author = {Gordić, Aleksa},
132 | title = {Get started with JAX},
133 | year = {2021},
134 | publisher = {GitHub},
135 | journal = {GitHub repository},
136 | howpublished = {\url{https://github.com/gordicaleksa/get-started-with-JAX}},
137 | }
138 | ```
139 |
140 | ## Connect With Me
141 |
142 | If you'd love to have some more AI-related content in your life :nerd_face:, consider:
143 | * Subscribing to my YouTube channel [The AI Epiphany](https://www.youtube.com/c/TheAiEpiphany) :bell:
144 | * Following me on [LinkedIn](https://www.linkedin.com/in/aleksagordic/) and [Twitter](https://twitter.com/gordic_aleksa) :bulb:
145 | * Following me on [Medium](https://gordicaleksa.medium.com/) :books: :heart:
146 | * Joining the [Discord](https://discord.gg/peBrCpheKE) community! :family:
147 |
148 | ## Licence
149 |
150 | [MIT License](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/LICENCE)
--------------------------------------------------------------------------------
/Tutorial_3_JAX_Neural_Network_from_Scratch_Colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Tutorial 3: JAX - Building a Neural Network from Scratch.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyMpTL6XC+tcxSqZ2FePUhlZ",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "language_info": {
17 | "name": "python"
18 | }
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "view-in-github",
25 | "colab_type": "text"
26 | },
27 | "source": [
28 | ""
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {
34 | "id": "XZuyP-M3KPUR"
35 | },
36 | "source": [
37 | "# MLP training on MNIST"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "metadata": {
43 | "id": "8-SzJ0NTKRP1"
44 | },
45 | "source": [
46 | "import numpy as np\n",
47 | "import jax.numpy as jnp\n",
48 | "from jax.scipy.special import logsumexp\n",
49 | "import jax\n",
50 | "from jax import jit, vmap, pmap, grad, value_and_grad\n",
51 | "\n",
52 | "from torchvision.datasets import MNIST\n",
53 | "from torch.utils.data import DataLoader"
54 | ],
55 | "execution_count": 1,
56 | "outputs": []
57 | },
58 | {
59 | "cell_type": "code",
60 | "metadata": {
61 | "colab": {
62 | "base_uri": "https://localhost:8080/"
63 | },
64 | "id": "G4NrxSVjKt8f",
65 | "outputId": "6bb8bef6-3098-4fd5-8ffe-62f4b0b1aa79"
66 | },
67 | "source": [
68 | "seed = 0\n",
69 | "mnist_img_size = (28, 28)\n",
70 | "\n",
71 | "def init_MLP(layer_widths, parent_key, scale=0.01):\n",
72 | "\n",
73 | "    params = []\n",
74 | "    keys = jax.random.split(parent_key, num=len(layer_widths)-1)\n",
75 | "\n",
76 | "    for in_width, out_width, key in zip(layer_widths[:-1], layer_widths[1:], keys):\n",
77 | "        weight_key, bias_key = jax.random.split(key)\n",
78 | "        params.append([\n",
79 | "            scale*jax.random.normal(weight_key, shape=(out_width, in_width)),\n",
80 | "            scale*jax.random.normal(bias_key, shape=(out_width,))\n",
81 | "            ]\n",
82 | "        )\n",
83 | "\n",
84 | "    return params\n",
85 | "\n",
86 | "# test\n",
87 | "key = jax.random.PRNGKey(seed)\n",
88 | "MLP_params = init_MLP([784, 512, 256, 10], key)\n",
89 | "print(jax.tree_util.tree_map(lambda x: x.shape, MLP_params))"
90 | ],
91 | "execution_count": 4,
92 | "outputs": [
93 | {
94 | "output_type": "stream",
95 | "name": "stdout",
96 | "text": [
97 | "[[(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]]\n"
98 | ]
99 | }
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "metadata": {
105 | "colab": {
106 | "base_uri": "https://localhost:8080/"
107 | },
108 | "id": "U_z7eLxINv9x",
109 | "outputId": "e9909f9f-6778-4977-91f1-f5b14dd9ecd4"
110 | },
111 | "source": [
112 | "def MLP_predict(params, x):\n",
113 | "    hidden_layers = params[:-1]\n",
114 | "\n",
115 | "    activation = x\n",
116 | "    for w, b in hidden_layers:\n",
117 | "        activation = jax.nn.relu(jnp.dot(w, activation) + b)\n",
118 | "\n",
119 | "    w_last, b_last = params[-1]\n",
120 | "    logits = jnp.dot(w_last, activation) + b_last\n",
121 | "\n",
122 | "    # log(exp(o1)) - log(sum(exp(o1), exp(o2), ..., exp(o10)))\n",
123 | "    # log( exp(o1) / sum(...) )\n",
124 | "    return logits - logsumexp(logits)\n",
125 | "\n",
126 | "# tests\n",
127 | "\n",
128 | "# test single example\n",
129 | "\n",
130 | "dummy_img_flat = np.random.randn(np.prod(mnist_img_size))\n",
131 | "print(dummy_img_flat.shape)\n",
132 | "\n",
133 | "prediction = MLP_predict(MLP_params, dummy_img_flat)\n",
134 | "print(prediction.shape)\n",
135 | "\n",
136 | "# test batched function\n",
137 | "batched_MLP_predict = vmap(MLP_predict, in_axes=(None, 0))\n",
138 | "\n",
139 | "dummy_imgs_flat = np.random.randn(16, np.prod(mnist_img_size))\n",
140 | "print(dummy_imgs_flat.shape)\n",
141 | "predictions = batched_MLP_predict(MLP_params, dummy_imgs_flat)\n",
142 | "print(predictions.shape)"
143 | ],
144 | "execution_count": 5,
145 | "outputs": [
146 | {
147 | "output_type": "stream",
148 | "name": "stdout",
149 | "text": [
150 | "(784,)\n",
151 | "(10,)\n",
152 | "(16, 784)\n",
153 | "(16, 10)\n"
154 | ]
155 | }
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "metadata": {
161 | "colab": {
162 | "base_uri": "https://localhost:8080/"
163 | },
164 | "id": "5pPM1dZ4QyYe",
165 | "outputId": "3317666b-e167-46b7-8cf4-b8592adc065a"
166 | },
167 | "source": [
168 | "def custom_transform(x):\n",
169 | "    return np.ravel(np.array(x, dtype=np.float32))\n",
170 | "\n",
171 | "def custom_collate_fn(batch):\n",
172 | "    transposed_data = list(zip(*batch))\n",
173 | "\n",
174 | "    labels = np.array(transposed_data[1])\n",
175 | "    imgs = np.stack(transposed_data[0])\n",
176 | "\n",
177 | "    return imgs, labels\n",
178 | "\n",
179 | "batch_size = 128\n",
180 | "train_dataset = MNIST(root='train_mnist', train=True, download=True, transform=custom_transform)\n",
181 | "test_dataset = MNIST(root='test_mnist', train=False, download=True, transform=custom_transform)\n",
182 | "\n",
183 | "train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)\n",
184 | "test_loader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)\n",
185 | "\n",
186 | "# test\n",
187 | "batch_data = next(iter(train_loader))\n",
188 | "imgs = batch_data[0]\n",
189 | "lbls = batch_data[1]\n",
190 | "print(imgs.shape, imgs[0].dtype, lbls.shape, lbls[0].dtype)\n",
191 | "\n",
192 | "# optimization - loading the whole dataset into memory\n",
193 | "train_images = jnp.array(train_dataset.data).reshape(len(train_dataset), -1)\n",
194 | "train_lbls = jnp.array(train_dataset.targets)\n",
195 | "\n",
196 | "test_images = jnp.array(test_dataset.data).reshape(len(test_dataset), -1)\n",
197 | "test_lbls = jnp.array(test_dataset.targets)"
198 | ],
199 | "execution_count": null,
200 | "outputs": [
201 | {
202 | "output_type": "stream",
203 | "name": "stdout",
204 | "text": [
205 | "(128, 784) float32 (128,) int64\n"
206 | ]
207 | }
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "metadata": {
213 | "id": "YQEYcSNzVeim"
214 | },
215 | "source": [
216 | "num_epochs = 5\n",
217 | "\n",
218 | "def loss_fn(params, imgs, gt_lbls):\n",
219 | "    predictions = batched_MLP_predict(params, imgs)\n",
220 | "\n",
221 | "    return -jnp.mean(predictions * gt_lbls)\n",
222 | "\n",
223 | "def accuracy(params, dataset_imgs, dataset_lbls):\n",
224 | "    pred_classes = jnp.argmax(batched_MLP_predict(params, dataset_imgs), axis=1)\n",
225 | "    return jnp.mean(dataset_lbls == pred_classes)\n",
226 | "\n",
227 | "@jit\n",
228 | "def update(params, imgs, gt_lbls, lr=0.01):\n",
229 | "    loss, grads = value_and_grad(loss_fn)(params, imgs, gt_lbls)\n",
230 | "\n",
231 | "    return loss, jax.tree_util.tree_map(lambda p, g: p - lr*g, params, grads)\n",
232 | "\n",
233 | "# Create a MLP\n",
234 | "MLP_params = init_MLP([np.prod(mnist_img_size), 512, 256, len(MNIST.classes)], key)\n",
235 | "\n",
236 | "for epoch in range(num_epochs):\n",
237 | "\n",
238 | "    for cnt, (imgs, lbls) in enumerate(train_loader):\n",
239 | "\n",
240 | "        gt_labels = jax.nn.one_hot(lbls, len(MNIST.classes))\n",
241 | "\n",
242 | "        loss, MLP_params = update(MLP_params, imgs, gt_labels)\n",
243 | "\n",
244 | "        if cnt % 50 == 0:\n",
245 | "            print(loss)\n",
246 | "\n",
247 | "    print(f'Epoch {epoch}, train acc = {accuracy(MLP_params, train_images, train_lbls)} test acc = {accuracy(MLP_params, test_images, test_lbls)}')\n"
248 | ],
249 | "execution_count": null,
250 | "outputs": []
251 | },
252 | {
253 | "cell_type": "code",
254 | "metadata": {
255 | "colab": {
256 | "base_uri": "https://localhost:8080/",
257 | "height": 316
258 | },
259 | "id": "YmdBRBvU1wuA",
260 | "outputId": "efcfa75e-d0bb-4f16-9fb2-e85e82a53bcf"
261 | },
262 | "source": [
263 | "imgs, lbls = next(iter(test_loader))\n",
264 | "img = imgs[0].reshape(mnist_img_size)\n",
265 | "gt_lbl = lbls[0]\n",
266 | "print(img.shape)\n",
267 | "\n",
268 | "import matplotlib.pyplot as plt\n",
269 | "\n",
270 | "pred = jnp.argmax(MLP_predict(MLP_params, np.ravel(img)))\n",
271 | "print('pred', pred)\n",
272 | "print('gt', gt_lbl)\n",
273 | "\n",
274 | "plt.imshow(img); plt.show()"
275 | ],
276 | "execution_count": null,
277 | "outputs": [
278 | {
279 | "output_type": "stream",
280 | "name": "stdout",
281 | "text": [
282 | "(28, 28)\n",
283 | "pred 7\n",
284 | "gt 7\n"
285 | ]
286 | },
287 | {
288 | "output_type": "display_data",
289 | "data": {
290 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANiklEQVR4nO3df4wc9XnH8c8n/kV8QGtDcF3j4ISQqE4aSHWBRNDKESUFImSiJBRLtVyJ5lALElRRW0QVBalVSlEIok0aySluHESgaQBhJTSNa6W1UKljg4yxgdaEmsau8QFOaxPAP/DTP24cHXD7vWNndmft5/2SVrs7z87Oo/F9PLMzO/t1RAjA8e9tbTcAoD8IO5AEYQeSIOxAEoQdSGJ6Pxc207PiBA31c5FAKq/qZzoYBzxRrVbYbV8s6XZJ0yT9bUTcXHr9CRrSeb6wziIBFGyIdR1rXe/G254m6auSLpG0WNIy24u7fT8AvVXnM/u5kp6OiGci4qCkeyQtbaYtAE2rE/YFkn4y7vnOatrr2B6xvcn2pkM6UGNxAOro+dH4iFgZEcMRMTxDs3q9OAAd1An7LkkLxz0/vZoGYADVCftGSWfZfpftmZKulLSmmbYANK3rU28Rcdj2tZL+SWOn3lZFxLbGOgPQqFrn2SPiQUkPNtQLgB7i67JAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJGoN2Wx7h6T9kl6TdDgihptoCkDzaoW98rGIeKGB9wHQQ+zGA0nUDXtI+oHtR2yPTPQC2yO2N9nedEgHai4OQLfq7sZfEBG7bJ8maa3tpyJi/fgXRMRKSSsl6WTPjZrLA9ClWlv2iNhV3Y9Kul/SuU00BaB5XYfd9pDtk44+lvRxSVubagxAs+rsxs+TdL/to+/zrYj4fiNdAWhc12GPiGcknd1gLwB6iFNvQBKEHUiCsANJEHYgCcIOJNHEhTApvPjZj3asvXP508V5nxqdV6wfPDCjWF9wd7k+e+dLHWtHNj9RnBd5sGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4zz5Ff/xH3+pY+9TQT8szn1lz4UvK5R2HX+5Yu/35j9Vc+LHrR6NndKwN3foLxXmnr3uk6XZax5YdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JwRP8GaTnZc+M8X9i35TXpZ58+r2PthQ+W/8+c82R5Hf/0V1ysz/zg/xbrt3zgvo61i97+SnHe7718YrH+idmdr5Wv65U4WKxvODBUrC854VDXy37P964u1t87srHr927ThlinfbF3wj8otuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATXs0/R0Hc2FGr13vvkerPrr39pScfan5+/qLzsfy3/5v0tS97TRUdTM/2VI8X60Jbdxfop6+8t1n91Zuff25+9o/xb/MejSbfstlfZHrW9ddy0ubbX2t5e3c/pbZsA6prKbvw3JF38hmk3SFoXEWdJWlc9BzDAJg17RKyXtPcNk5dKWl09Xi3p8ob7AtCwbj+zz4uIox+onpPUcTAz2yOSRiTpBM3ucnEA6qp9ND7GrqTpeKVHRKyMiOGIGJ6hWXUXB6BL3YZ9j+35klTdjzbXEoBe6DbsayStqB6vkPRAM+0A6JVJP7Pbvltjv1x+qu2dkr4g6WZJ37Z9laRnJV3RyyZRdvi5PR1rQ/d2rknSa5O899B3Xuyio2bs+b2PFuvvn1n+8/3S3vd1rC36u2eK8x4uVo9Nk4Y9IpZ1KB2bv0IBJMXXZYEkCDuQBGEHkiDsQBKEHUiCS1zRmulnLCzWv3LjV4r1GZ5WrP/D7b/ZsXbK7oeL8x6P2LIDSRB2IAnCDiRB
2IEkCDuQBGEHkiDsQBKcZ0drnvrDBcX6h2eVh7LedrA8HPXcJ15+yz0dz9iyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASnGdHTx34xIc71h799G2TzF0eQej3r7uuWH/7v/1okvfPhS07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBeXb01H9f0nl7cqLL59GX/ddFxfrs7z9WrEexms+kW3bbq2yP2t46btpNtnfZ3lzdLu1tmwDqmspu/DckXTzB9Nsi4pzq9mCzbQFo2qRhj4j1kvb2oRcAPVTnAN21trdUu/lzOr3I9ojtTbY3HdKBGosDUEe3Yf+apDMlnSNpt6RbO70wIlZGxHBEDM+Y5MIGAL3TVdgjYk9EvBYRRyR9XdK5zbYFoGldhd32/HFPPylpa6fXAhgMk55nt323pCWSTrW9U9IXJC2xfY7GTmXukHR1D3vEAHvbSScV68t//aGOtX1HXi3OO/rFdxfrsw5sLNbxepOGPSKWTTD5jh70AqCH+LoskARhB5Ig7EAShB1IgrADSXCJK2rZftP7i/Xvnvo3HWtLt3+qOO+sBzm11iS27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOfZUfR/v/ORYn3Lb/9Vsf7jw4c61l76y9OL887S7mIdbw1bdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgvPsyU1f8MvF+vWf//tifZbLf0JXPra8Y+0d/8j16v3Elh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuA8+3HO08v/xGd/d2ex/pkTXyzW79p/WrE+7/OdtydHinOiaZNu2W0vtP1D20/Y3mb7umr6XNtrbW+v7uf0vl0A3ZrKbvxhSZ+LiMWSPiLpGtuLJd0gaV1EnCVpXfUcwICaNOwRsTsiHq0e75f0pKQFkpZKWl29bLWky3vVJID63tJndtuLJH1I0gZJ8yLi6I+EPSdpXod5RiSNSNIJmt1tnwBqmvLReNsnSrpX0vURsW98LSJCUkw0X0SsjIjhiBieoVm1mgXQvSmF3fYMjQX9roi4r5q8x/b8qj5f0mhvWgTQhEl3421b0h2SnoyIL48rrZG0QtLN1f0DPekQ9Zz9vmL5z067s9bbf/WLnynWf/Gxh2u9P5ozlc/s50taLulx25uraTdqLOTftn2VpGclXdGbFgE0YdKwR8RDktyhfGGz7QDoFb4uCyRB2IEkCDuQBGEHkiDsQBJc4nocmLb4vR1rI/fU+/rD4lXXFOuL7vz3Wu+P/mHLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJ79OPDUH3T+Yd/LZu/rWJuK0//lYPkFMeEPFGEAsWUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4z34MePWyc4v1dZfdWqgy5BbGsGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSSmMj77QknflDRPUkhaGRG3275J0mclPV+99MaIeLBXjWb2P+dPK9bfOb37c+l37T+tWJ+xr3w9O1ezHzum8qWaw5I+FxGP2j5J0iO211a12yLiS71rD0BTpjI++25Ju6vH+20/KWlBrxsD0Ky39Jnd9iJJH5K0oZp0re0ttlfZnvC3kWyP2N5ke9MhHajVLIDuTTnstk+UdK+k6yNin6SvSTpT0jka2/JP+AXtiFgZEcMRMTxDsxpoGUA3phR22zM0FvS7IuI+SYqIPRHxWkQckfR1SeWrNQC0atKw27akOyQ9GRFfHjd9/riXfVLS1ubbA9CUqRyNP1/SckmP295cTbtR0jLb52js7MsOSVf3pEPU8hcvLi7WH/6tRcV67H68wW7QpqkcjX9IkicocU4dOIbwDTogCcIOJEHYgSQIO5AEYQeSIOxAEo4+Drl7sufGeb6wb8sDstkQ67Qv9k50qpwtO5AFYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dfz7Lafl/TsuEmnSnqhbw28NYPa26D2JdFbt5rs7YyIeMdEhb6G/U0L
tzdFxHBrDRQMam+D2pdEb93qV2/sxgNJEHYgibbDvrLl5ZcMam+D2pdEb93qS2+tfmYH0D9tb9kB9AlhB5JoJey2L7b9H7aftn1DGz10YnuH7cdtb7a9qeVeVtketb113LS5ttfa3l7dTzjGXku93WR7V7XuNtu+tKXeFtr+oe0nbG+zfV01vdV1V+irL+ut75/ZbU+T9J+SLpK0U9JGScsi4om+NtKB7R2ShiOi9S9g2P4NSS9J+mZEfKCadoukvRFxc/Uf5ZyI+JMB6e0mSS+1PYx3NVrR/PHDjEu6XNLvqsV1V+jrCvVhvbWxZT9X0tMR8UxEHJR0j6SlLfQx8CJivaS9b5i8VNLq6vFqjf2x9F2H3gZCROyOiEerx/slHR1mvNV1V+irL9oI+wJJPxn3fKcGa7z3kPQD24/YHmm7mQnMi4jd1ePnJM1rs5kJTDqMdz+9YZjxgVl33Qx/XhcH6N7sgoj4NUmXSLqm2l0dSDH2GWyQzp1OaRjvfplgmPGfa3PddTv8eV1thH2XpIXjnp9eTRsIEbGruh+VdL8GbyjqPUdH0K3uR1vu5+cGaRjviYYZ1wCsuzaHP28j7BslnWX7XbZnSrpS0poW+ngT20PVgRPZHpL0cQ3eUNRrJK2oHq+Q9ECLvbzOoAzj3WmYcbW87lof/jwi+n6TdKnGjsj/WNKfttFDh77eLemx6rat7d4k3a2x3bpDGju2cZWkUyStk7Rd0j9LmjtAvd0p6XFJWzQWrPkt9XaBxnbRt0jaXN0ubXvdFfrqy3rj67JAEhygA5Ig7EAShB1IgrADSRB2IAnCDiRB2IEk/h9BCfQTVPflJQAAAABJRU5ErkJggg==\n",
291 | "text/plain": [
292 | ""
293 | ]
294 | },
295 | "metadata": {
296 | "needs_background": "light"
297 | }
298 | }
299 | ]
300 | },
301 | {
302 | "cell_type": "markdown",
303 | "metadata": {
304 | "id": "TwgI3fZbKRqM"
305 | },
306 | "source": [
307 | "# Visualizations"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "metadata": {
313 | "colab": {
314 | "base_uri": "https://localhost:8080/",
315 | "height": 299
316 | },
317 | "id": "jddJj8zo4D1e",
318 | "outputId": "fb157d1c-4fbe-45a5-c84d-6abe38355a5e"
319 | },
320 | "source": [
321 | "w = MLP_params[0][0]\n",
322 | "print(w.shape)\n",
323 | "\n",
324 | "w_single = w[500, :].reshape(mnist_img_size)\n",
325 | "print(w_single.shape)\n",
326 | "plt.imshow(w_single); plt.show()"
327 | ],
328 | "execution_count": null,
329 | "outputs": [
330 | {
331 | "output_type": "stream",
332 | "name": "stdout",
333 | "text": [
334 | "(512, 784)\n",
335 | "(28, 28)\n"
336 | ]
337 | },
338 | {
339 | "output_type": "display_data",
340 | "data": {
341 | "image/png": "\n",
342 | "text/plain": [
343 | ""
344 | ]
345 | },
346 | "metadata": {
347 | "needs_background": "light"
348 | }
349 | }
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "metadata": {
355 | "colab": {
356 | "base_uri": "https://localhost:8080/",
357 | "height": 484
358 | },
359 | "id": "AZxm7G3j4iOS",
360 | "outputId": "521c3ad2-147d-4076-eea0-6537f32dafa0"
361 | },
362 | "source": [
363 | "# visualize embeddings of a test batch using t-SNE\n",
364 | "\n",
365 | "from sklearn.manifold import TSNE\n",
366 | "\n",
367 | "def fetch_activations(params, x):\n",
368 | "    hidden_layers = params[:-1]\n",
369 | "\n",
370 | "    activation = x\n",
371 | "    for w, b in hidden_layers:\n",
372 | "        activation = jax.nn.relu(jnp.dot(w, activation) + b)\n",
373 | "\n",
374 | "    return activation\n",
375 | "\n",
376 | "batched_fetch_activations = vmap(fetch_activations, in_axes=(None, 0))\n",
377 | "imgs, lbls = next(iter(test_loader))\n",
378 | "\n",
379 | "batch_activations = batched_fetch_activations(MLP_params, imgs)\n",
380 | "print(batch_activations.shape) # (128, 256)\n",
381 | "\n",
382 | "t_sne_embeddings = TSNE(n_components=2, perplexity=30,).fit_transform(batch_activations)\n",
383 | "mnist_label_to_color_map = {0: \"red\", 1: \"blue\", 2: \"green\", 3: \"orange\", 4: \"yellow\", 5: \"pink\", 6: \"gray\", 7: \"purple\", 8: \"brown\", 9: \"cyan\"}\n",
384 | "\n",
385 | "for class_id in range(10):\n",
386 | "    plt.scatter(t_sne_embeddings[lbls == class_id, 0], t_sne_embeddings[lbls == class_id, 1], s=20, color=mnist_label_to_color_map[class_id])\n",
387 | "plt.show()"
388 | ],
389 | "execution_count": null,
390 | "outputs": [
391 | {
392 | "output_type": "stream",
393 | "name": "stdout",
394 | "text": [
395 | "(128, 256)\n"
396 | ]
397 | },
409 | {
410 | "output_type": "display_data",
411 | "data": {
412 | "image/png": "\n",
413 | "text/plain": [
414 | ""
415 | ]
416 | },
417 | "metadata": {
418 | "needs_background": "light"
419 | }
420 | }
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "metadata": {
426 | "colab": {
427 | "base_uri": "https://localhost:8080/"
428 | },
429 | "id": "MHL27HumNgwf",
430 | "outputId": "d44b1e9c-33d6-4dd3-cf05-b3a1e19f0194"
431 | },
432 | "source": [
433 | "# count dead neurons (ReLU outputs that are 0 for every image in the batch)\n",
434 | "\n",
435 | "def fetch_activations2(params, x):\n",
436 | "    hidden_layers = params[:-1]\n",
437 | "    collector = []\n",
438 | "\n",
439 | "    activation = x\n",
440 | "    for w, b in hidden_layers:\n",
441 | "        activation = jax.nn.relu(jnp.dot(w, activation) + b)\n",
442 | "        collector.append(activation)\n",
443 | "\n",
444 | "    return collector\n",
445 | "\n",
446 | "batched_fetch_activations2 = vmap(fetch_activations2, in_axes=(None, 0))\n",
447 | "\n",
448 | "imgs, lbls = next(iter(test_loader))\n",
449 | "\n",
450 | "MLP_params2 = init_MLP([np.prod(mnist_img_size), 512, 256, len(MNIST.classes)], key)\n",
451 | "\n",
452 | "batch_activations = batched_fetch_activations2(MLP_params2, imgs)\n",
453 | "print(batch_activations[1].shape) # (128, 512/256)\n",
454 | "\n",
455 | "dead_neurons = [np.ones(act.shape[1:]) for act in batch_activations]\n",
456 | "\n",
457 | "for layer_id, activations in enumerate(batch_activations):\n",
458 | "    dead_neurons[layer_id] = np.logical_and(dead_neurons[layer_id], (activations == 0).all(axis=0))\n",
459 | "\n",
460 | "for layers in dead_neurons:\n",
461 | "    print(np.sum(layers))"
462 | ],
463 | "execution_count": null,
464 | "outputs": [
465 | {
466 | "output_type": "stream",
467 | "name": "stdout",
468 | "text": [
469 | "(128, 256)\n",
470 | "0\n",
471 | "7\n"
472 | ]
473 | }
474 | ]
475 | },
476 | {
477 | "cell_type": "markdown",
478 | "metadata": {
479 | "id": "jMmOX-VSKTjQ"
480 | },
481 | "source": [
482 | "# Parallelization"
483 | ]
484 | },
485 | {
486 | "cell_type": "code",
487 | "metadata": {
488 | "id": "1aCkdHuhKUqV"
489 | },
490 | "source": [
491 | ""
492 | ],
493 | "execution_count": null,
494 | "outputs": []
495 | }
496 | ]
497 | }
--------------------------------------------------------------------------------
/Tutorial_4_Flax_Zero2Hero_Colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "[Open in Deepnote](https://deepnote.com/launch?url=https%3A%2F%2Fgithub.com%2Fgordicaleksa%2Fget-started-with-JAX%2Fblob%2Fmain%2FTutorial_4_Flax_Zero2Hero_Colab.ipynb)"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "metadata": {},
14 | "source": [
15 | ""
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "TbMr3-5oun69"
22 | },
23 | "source": [
24 | "# Flax: From Zero to Hero!\n",
25 | "\n",
26 | "This notebook heavily relies on the [official Flax docs](https://flax.readthedocs.io/en/latest/) and [examples](https://github.com/google/flax/blob/main/examples/) + some additional code/modifications, comments/notes, etc."
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "C1qve53yeof5"
33 | },
34 | "source": [
35 | "### Enter Flax - the basics ❤️\n",
36 | "\n",
37 | "Before you jump into the Flax world, I strongly recommend you check out my JAX tutorials, as I won't be covering the details of JAX here.\n",
38 | "\n",
39 | "* (Tutorial 1) ML with JAX: From Zero to Hero ([video](https://www.youtube.com/watch?v=SstuvS-tVc0), [notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_1_JAX_Zero2Hero_Colab.ipynb))\n",
40 | "* (Tutorial 2) ML with JAX: from Hero to Hero Pro+ ([video](https://www.youtube.com/watch?v=CQQaifxuFcs), [notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_2_JAX_HeroPro%2B_Colab.ipynb))\n",
41 | "* (Tutorial 3) ML with JAX: Coding a Neural Network from Scratch in Pure JAX ([video](https://www.youtube.com/watch?v=6_PqUPxRmjY), [notebook](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_3_JAX_Neural_Network_from_Scratch_Colab.ipynb))\n",
42 | "\n",
43 | "With that out of the way, let's start with the basics!"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 1,
49 | "metadata": {
50 | "id": "GHcasJkggdZN"
51 | },
52 | "outputs": [],
53 | "source": [
54 | "# Install Flax and JAX\n",
55 | "!pip install --upgrade -q \"jax[cuda12]\"\n",
56 | "!pip install --upgrade -q git+https://github.com/google/flax.git\n",
57 | "!pip install --upgrade -q git+https://github.com/deepmind/dm-haiku # Haiku is here just for comparison purposes"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 2,
63 | "metadata": {
64 | "id": "HmVx7EjigrEZ"
65 | },
66 | "outputs": [],
67 | "source": [
68 | "import jax\n",
69 | "from jax import lax, random, numpy as jnp\n",
70 | "\n",
71 | "# NN lib built on top of JAX developed by Google Research (Brain team)\n",
72 | "# Flax was \"designed for flexibility\" hence the name (Flexibility + JAX -> Flax)\n",
73 | "import flax\n",
74 | "from flax.core import freeze, unfreeze\n",
75 | "from flax import linen as nn # nn notation also used in PyTorch and in Flax's older API\n",
76 | "from flax.training import train_state # a useful dataclass to keep train state\n",
77 | "\n",
78 | "# DeepMind's NN JAX lib - just for comparison purposes, we're not learning Haiku here\n",
79 | "import haiku as hk \n",
80 | "\n",
81 | "# JAX optimizers - a separate lib developed by DeepMind\n",
82 | "import optax\n",
83 | "\n",
84 | "# Flax doesn't have its own data loading functions - we'll be using PyTorch dataloaders\n",
85 | "from torchvision.datasets import MNIST\n",
86 | "from torch.utils.data import DataLoader\n",
87 | "\n",
88 | "# Python libs\n",
89 | "import functools # useful utilities for functional programs\n",
90 | "from typing import Any, Callable, Sequence, Optional\n",
91 | "\n",
92 | "# Other important 3rd party libs\n",
93 | "import numpy as np\n",
94 | "import matplotlib.pyplot as plt"
95 | ]
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "metadata": {
100 | "id": "aSDyQLgOesZp"
101 | },
102 | "source": [
103 | "The goal of this notebook is to get you started with Flax!\n",
104 | "\n",
105 | "I'll only cover the most essential parts of Flax (and Optax) - just as much as needed to get you started with training NNs!"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "metadata": {
112 | "id": "y1kdq0P_g7LU"
113 | },
114 | "outputs": [],
115 | "source": [
116 | "# Let's start with the simplest model possible: a single feed-forward layer (linear regression)\n",
117 | "model = nn.Dense(features=5)\n",
118 | "\n",
119 | "# All of the Flax NN layers inherit from the Module class (similarly to PyTorch)\n",
120 | "print(nn.Dense.__bases__)"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {
126 | "id": "ux9Okie5PWpw"
127 | },
128 | "source": [
129 | "So how can we do inference with this simple model? Two steps: init and apply!"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 4,
135 | "metadata": {
136 | "id": "QViTvJhFite2"
137 | },
138 | "outputs": [],
139 | "source": [
140 | "# Step 1: init\n",
141 | "seed = 23\n",
142 | "key1, key2 = random.split(random.PRNGKey(seed))\n",
143 | "x = random.normal(key1, (10,)) # create a dummy input, a 10-dimensional random vector\n",
144 | "\n",
145 | "# Initialization call - this gives us the actual model weights \n",
146 | "# (remember JAX handles state externally!)\n",
147 | "y, params = model.init_with_output(key2, x) \n",
148 | "print(y)\n",
149 | "print(jax.tree_util.tree_map(lambda x: x.shape, params))\n",
150 | "\n",
151 | "# Note1: automatic shape inference\n",
152 | "# Note2: immutable structure (FrozenDict in older Flax versions; newer versions return a regular dict)\n",
153 | "# Note3: use init_with_output if you also need the output of the initialization forward pass"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": null,
159 | "metadata": {
160 | "id": "b3yFAqeTjdLj"
161 | },
162 | "outputs": [],
163 | "source": [
164 | "# Step 2: apply\n",
165 | "y = model.apply(params, x) # this is how you run prediction in Flax, state is external!\n",
166 | "print(y)"
167 | ]
168 | },
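{
"cell_type": "markdown",
"metadata": {},
"source": [
"Side note: `init` only fixes the size of the last (feature) axis, so the same params can be reused for a batch of inputs of any size. Below is a minimal sketch that creates its own dummy `nn.Dense` layer under new names (`layer`, `layer_params`), so it doesn't overwrite the variables above. It checks that a batched `apply` matches applying the layer to each row separately."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from jax import random, numpy as jnp\n",
"from flax import linen as nn\n",
"\n",
"layer = nn.Dense(features=5)\n",
"key_x, key_batch, key_init = random.split(random.PRNGKey(0), 3)\n",
"x_single = random.normal(key_x, (10,))  # a single 10-dimensional example\n",
"layer_params = layer.init(key_init, x_single)\n",
"\n",
"# kernel shape is (in_features, out_features), bias shape is (out_features,)\n",
"print(jax.tree_util.tree_map(jnp.shape, layer_params))\n",
"\n",
"# The same params work for a batch of 3 examples: only the last axis has to match\n",
"batch = random.normal(key_batch, (3, 10))\n",
"y_batch = layer.apply(layer_params, batch)\n",
"assert y_batch.shape == (3, 5)\n",
"\n",
"# Applying the layer row by row gives the same result (small tolerance for accelerator matmul precision)\n",
"y_rows = jnp.stack([layer.apply(layer_params, row) for row in batch])\n",
"assert jnp.allclose(y_rows, y_batch, atol=1e-3)"
]
},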
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "metadata": {
173 | "id": "31O_mx-Smalq"
174 | },
175 | "outputs": [],
176 | "source": [
177 | "try:\n",
178 | " y = model(x) # this doesn't work anymore (bye bye PyTorch syntax)\n",
179 | "except Exception as e:\n",
180 | " print(e)"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "metadata": {
187 | "id": "fQYyv76sCJ25"
188 | },
189 | "outputs": [],
190 | "source": [
191 | "# todo: a small coding exercise - let's contrast Flax with Haiku"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "metadata": {
198 | "cellView": "form",
199 | "id": "UWr3hdpmFBng"
200 | },
201 | "outputs": [],
202 | "source": [
203 | "#@title Haiku vs Flax solution\n",
204 | "model = hk.transform(lambda x: hk.Linear(output_size=5)(x))\n",
205 | "\n",
206 | "seed = 23\n",
207 | "key1, key2 = random.split(random.PRNGKey(seed))\n",
208 | "x = random.normal(key1, (10,)) # create a dummy input, a 10-dimensional random vector\n",
209 | "\n",
210 | "params = model.init(key2, x)\n",
211 | "out = model.apply(params, None, x)\n",
212 | "print(out)\n",
213 | "\n",
214 | "print(hk.Linear.__bases__)"
215 | ]
216 | },
217 | {
218 | "cell_type": "markdown",
219 | "metadata": {
220 | "id": "wWBxTShUiLzW"
221 | },
222 | "source": [
223 | "All of this might (initially!) be overwhelming if you're used to the stateful, object-oriented paradigm.\n",
224 | "\n",
225 | "What Flax offers in return is high performance and flexibility (just like JAX).\n",
226 | "\n",
227 | "The HuggingFace team has published some [benchmark numbers](https://github.com/huggingface/transformers/tree/master/examples/flax/text-classification) for their Flax text-classification example.\n",
228 | "\n",
229 | ""
230 | ]
231 | },
232 | {
233 | "cell_type": "markdown",
234 | "metadata": {
235 | "id": "eUBYtd40krx1"
236 | },
237 | "source": [
238 | "Now that we have an answer to \"why should I learn Flax?\", let's start our descent into Flaxlandia!\n",
239 | "\n",
240 | "### A toy example 🚚 - training a linear regression model\n",
241 | "\n",
242 | "We'll first implement a pure-JAX approach and then we'll do it the Flax way."
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "execution_count": null,
248 | "metadata": {
249 | "id": "53-TXcbYkt9D"
250 | },
251 | "outputs": [],
252 | "source": [
253 | "# Defining a toy dataset\n",
254 | "n_samples = 150\n",
255 | "x_dim = 2 # putting small numbers here so that we can visualize the data easily\n",
256 | "y_dim = 1\n",
257 | "noise_amplitude = 0.1\n",
258 | "\n",
259 | "# Generate (random) ground truth W and b\n",
260 | "# Note: we could get W, b from a randomly initialized nn.Dense here, but we stay closer to pure JAX for now\n",
261 | "key, w_key, b_key = random.split(random.PRNGKey(seed), num=3)\n",
262 | "W = random.normal(w_key, (x_dim, y_dim)) # weight\n",
263 | "b = random.normal(b_key, (y_dim,)) # bias\n",
264 | "\n",
265 | "# This is the structure that Flax expects (recall from the previous section!)\n",
266 | "true_params = freeze({'params': {'bias': b, 'kernel': W}})\n",
267 | "\n",
268 | "# Generate samples with additional noise\n",
269 | "key, x_key, noise_key = random.split(key, num=3)\n",
270 | "xs = random.normal(x_key, (n_samples, x_dim))\n",
271 | "ys = jnp.dot(xs, W) + b\n",
272 | "ys += noise_amplitude * random.normal(noise_key, (n_samples, y_dim))\n",
273 | "print(f'xs shape = {xs.shape} ; ys shape = {ys.shape}')"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": 9,
279 | "metadata": {
280 | "colab": {
281 | "base_uri": "https://localhost:8080/",
282 | "height": 266
283 | },
284 | "id": "lc4-xoIapKCs",
285 | "outputId": "52656571-0aa5-4c6f-f522-83b0158c1b97"
286 | },
287 | "outputs": [
288 | {
289 | "data": {
290 | "text/plain": [
291 | ""
292 | ]
293 | },
294 | "execution_count": 9,
295 | "metadata": {},
296 | "output_type": "execute_result"
297 | },
298 | {
299 | "data": {
300 | "image/png": "\n",
301 | "text/plain": [
302 | ""
303 | ]
304 | },
305 | "metadata": {
306 | "needs_background": "light"
307 | },
308 | "output_type": "display_data"
309 | }
310 | ],
311 | "source": [
312 | "# Let's visualize our data (becoming one with the data paradigm <3)\n",
313 | "fig = plt.figure()\n",
314 | "ax = fig.add_subplot(111, projection='3d')\n",
315 | "assert xs.shape[-1] == 2 and ys.shape[-1] == 1 # low dimensional data so that we can plot it\n",
316 | "ax.scatter(xs[:, 0], xs[:, 1], zs=ys)\n",
317 | "\n",
318 | "# todo: exercise - let's show that our data lies (approximately, because of the noise) on a 2D plane embedded in 3D\n",
319 | "# option 1: analytic approach\n",
320 | "# option 2: data-driven approach"
321 | ]
322 | },
323 | {
324 | "cell_type": "code",
325 | "execution_count": 22,
326 | "metadata": {
327 | "id": "mKiCOyoikxcM"
328 | },
329 | "outputs": [],
330 | "source": [
331 | "def make_mse_loss(xs, ys):\n",
332 | " \n",
333 | " def mse_loss(params):\n",
334 | " \"\"\"Gives the value of the loss on the (xs, ys) dataset for the given model (params).\"\"\"\n",
335 | " \n",
336 | " # Define the squared loss for a single pair (x,y)\n",
337 | " def squared_error(x, y):\n",
338 | " pred = model.apply(params, x)\n",
339 | " # Use the inner product because 'y' can, in general, have more than one dimension\n",
340 | " return jnp.inner(y-pred, y-pred) / 2.0\n",
341 | "\n",
342 | " # Batched version via vmap\n",
343 | " return jnp.mean(jax.vmap(squared_error)(xs, ys), axis=0)\n",
344 | "\n",
345 | " return jax.jit(mse_loss) # and finally we jit the result (mse_loss is a pure function)\n",
346 | "\n",
347 | "mse_loss = make_mse_loss(xs, ys)\n",
348 | "value_and_grad_fn = jax.value_and_grad(mse_loss)"
349 | ]
350 | },
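{
"cell_type": "markdown",
"metadata": {},
"source": [
"Side note: `nn.Dense` already accepts a leading batch axis, so the `vmap` in `make_mse_loss` isn't strictly required. It mainly shows how to write the loss for a single example and let JAX batch it. The sketch below uses its own toy random data and a separate `toy_model`, so it doesn't touch `xs`, `ys` or `model`. It checks that a fully vectorized loss gives the same value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from jax import random, numpy as jnp\n",
"from flax import linen as nn\n",
"\n",
"# Toy random data, independent of the dataset above\n",
"key_x, key_y, key_init = random.split(random.PRNGKey(0), 3)\n",
"toy_xs = random.normal(key_x, (150, 2))\n",
"toy_ys = random.normal(key_y, (150, 1))\n",
"toy_model = nn.Dense(features=1)\n",
"toy_params = toy_model.init(key_init, toy_xs)\n",
"\n",
"def vmap_loss(params):\n",
"    # Per-example squared error, batched with vmap (same idea as make_mse_loss)\n",
"    def squared_error(x, y):\n",
"        pred = toy_model.apply(params, x)\n",
"        return jnp.inner(y - pred, y - pred) / 2.0\n",
"    return jnp.mean(jax.vmap(squared_error)(toy_xs, toy_ys), axis=0)\n",
"\n",
"def vectorized_loss(params):\n",
"    # nn.Dense already handles the leading batch axis, so no vmap is needed\n",
"    preds = toy_model.apply(params, toy_xs)\n",
"    return jnp.mean(jnp.sum((toy_ys - preds) ** 2, axis=-1) / 2.0)\n",
"\n",
"# Small tolerance for accelerator matmul precision\n",
"assert jnp.allclose(vmap_loss(toy_params), vectorized_loss(toy_params), atol=1e-3)"
]
},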
351 | {
352 | "cell_type": "code",
353 | "execution_count": null,
354 | "metadata": {
355 | "id": "phLYjH5ZkzLn"
356 | },
357 | "outputs": [],
358 | "source": [
359 | "# Let's reuse the simple feed-forward layer since it trivially implements linear regression\n",
360 | "model = nn.Dense(features=y_dim)\n",
361 | "params = model.init(key, xs)\n",
362 | "print(f'Initial params = {params}')\n",
363 | "\n",
364 | "# Let's set some reasonable hyperparams\n",
365 | "lr = 0.3\n",
366 | "epochs = 20\n",
367 | "log_period_epoch = 5\n",
368 | "\n",
369 | "print('-' * 50)\n",
370 | "for epoch in range(epochs):\n",
371 | " loss, grads = value_and_grad_fn(params)\n",
372 | " # SGD (closer to JAX again, but we'll progressively go towards how stuff is done in Flax)\n",
373 | " params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)\n",
374 | "\n",
375 | " if epoch % log_period_epoch == 0:\n",
376 | " print(f'epoch {epoch}, loss = {loss}')\n",
377 | "\n",
378 | "print('-' * 50)\n",
379 | "print(f'Learned params = {params}')\n",
380 | "print(f'Gt params = {true_params}')"
381 | ]
382 | },
383 | {
384 | "cell_type": "markdown",
385 | "metadata": {
386 | "id": "rvy6Oow2lLHu"
387 | },
388 | "source": [
389 | "Now let's do the same thing but this time with dedicated optimizers!\n",
390 | "\n",
391 | "Enter DeepMind's optax! ❤️🔥"
392 | ]
393 | },
394 | {
395 | "cell_type": "code",
396 | "execution_count": null,
397 | "metadata": {
398 | "id": "5hhcFZ7UlCov"
399 | },
400 | "outputs": [],
401 | "source": [
402 | "opt_sgd = optax.sgd(learning_rate=lr)\n",
403 | "opt_state = opt_sgd.init(params) # always the same pattern - handling state externally\n",
404 | "print(opt_state)\n",
405 | "# todo: exercise - compare Adam's and SGD's states"
406 | ]
407 | },
408 | {
409 | "cell_type": "code",
410 | "execution_count": null,
411 | "metadata": {
412 | "id": "t_EHHjy_lFGN"
413 | },
414 | "outputs": [],
415 | "source": [
416 | "params = model.init(key, xs) # let's start with fresh params again\n",
417 | "\n",
418 | "for epoch in range(epochs):\n",
419 | " loss, grads = value_and_grad_fn(params)\n",
420 | " updates, opt_state = opt_sgd.update(grads, opt_state) # the optimizer can implement arbitrary update logic\n",
421 | " params = optax.apply_updates(params, updates)\n",
422 | "\n",
423 | " if epoch % log_period_epoch == 0:\n",
424 | " print(f'epoch {epoch}, loss = {loss}')\n",
425 | "\n",
426 | "# Note 1: as expected, we get the same loss values as with the manual SGD loop\n",
427 | "# Note 2: we'll later see more concise ways to handle all of these state components (hint: TrainState)"
428 | ]
429 | },
430 | {
431 | "cell_type": "markdown",
432 | "metadata": {
433 | "id": "QF1gAYSzxQ1R"
434 | },
435 | "source": [
436 | "In this toy SGD example Optax may not seem that useful, but it is very powerful.\n",
437 | "\n",
438 | "You can build custom optimizers by chaining gradient transformations, and add hyperparameter schedules, parameter freezing, etc. See the [official docs here](https://optax.readthedocs.io/en/latest/)."
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "execution_count": 8,
444 | "metadata": {
445 | "cellView": "form",
446 | "id": "rKbis5O0KQYH"
447 | },
448 | "outputs": [],
449 | "source": [
450 | "#@title Optax Advanced Examples\n",
451 | "# This cell won't run as-is (ml_collections isn't imported, and config, learning_rate_fn, FLAGS and lr_schedule are undefined); it serves only as an example\n",
452 | "\n",
453 | "# Example from Flax (ImageNet example)\n",
454 | "# https://github.com/google/flax/blob/main/examples/imagenet/train.py#L88\n",
455 | "def create_learning_rate_fn(\n",
456 | " config: ml_collections.ConfigDict,\n",
457 | " base_learning_rate: float,\n",
458 | " steps_per_epoch: int):\n",
459 | " \"\"\"Create learning rate schedule.\"\"\"\n",
460 | " warmup_fn = optax.linear_schedule(\n",
461 | " init_value=0., end_value=base_learning_rate,\n",
462 | " transition_steps=config.warmup_epochs * steps_per_epoch)\n",
463 | " cosine_epochs = max(config.num_epochs - config.warmup_epochs, 1)\n",
464 | " cosine_fn = optax.cosine_decay_schedule(\n",
465 | " init_value=base_learning_rate,\n",
466 | " decay_steps=cosine_epochs * steps_per_epoch)\n",
467 | " schedule_fn = optax.join_schedules(\n",
468 | " schedules=[warmup_fn, cosine_fn],\n",
469 | " boundaries=[config.warmup_epochs * steps_per_epoch])\n",
470 | " return schedule_fn\n",
471 | "\n",
472 | "tx = optax.sgd(\n",
473 | " learning_rate=learning_rate_fn,\n",
474 | " momentum=config.momentum,\n",
475 | " nesterov=True,\n",
476 | ")\n",
477 | "\n",
478 | "# Example from Haiku (ImageNet example)\n",
479 | "# https://github.com/deepmind/dm-haiku/blob/main/examples/imagenet/train.py#L116\n",
480 | "def make_optimizer() -> optax.GradientTransformation:\n",
481 | " \"\"\"SGD with nesterov momentum and a custom lr schedule.\"\"\"\n",
482 | " return optax.chain(\n",
483 | " optax.trace(\n",
484 | " decay=FLAGS.optimizer_momentum,\n",
485 | " nesterov=FLAGS.optimizer_use_nesterov),\n",
486 | " optax.scale_by_schedule(lr_schedule), optax.scale(-1))"
487 | ]
488 | },
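{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both snippets above depend on objects that aren't defined in this notebook. Here is a minimal, runnable sketch of the same warmup + cosine-decay schedule, built with `optax.linear_schedule`, `optax.cosine_decay_schedule` and `optax.join_schedules`. The epoch counts, steps per epoch and momentum are hypothetical values that stand in for the `config` fields."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import optax\n",
"\n",
"# Hypothetical values standing in for config.warmup_epochs, config.num_epochs, etc.\n",
"base_learning_rate = 0.1\n",
"steps_per_epoch = 100\n",
"warmup_epochs = 2\n",
"num_epochs = 10\n",
"\n",
"warmup_fn = optax.linear_schedule(\n",
"    init_value=0., end_value=base_learning_rate,\n",
"    transition_steps=warmup_epochs * steps_per_epoch)\n",
"cosine_fn = optax.cosine_decay_schedule(\n",
"    init_value=base_learning_rate,\n",
"    decay_steps=(num_epochs - warmup_epochs) * steps_per_epoch)\n",
"# After the boundary, join_schedules calls cosine_fn with the step count measured from the boundary\n",
"schedule_fn = optax.join_schedules(\n",
"    schedules=[warmup_fn, cosine_fn],\n",
"    boundaries=[warmup_epochs * steps_per_epoch])\n",
"\n",
"for step in [0, 100, 200, 600, 1000]:\n",
"    print(step, float(schedule_fn(step)))\n",
"\n",
"assert abs(float(schedule_fn(100)) - 0.05) < 1e-6  # halfway through the linear warmup\n",
"assert abs(float(schedule_fn(200)) - 0.1) < 1e-6   # start of the cosine decay\n",
"\n",
"# A schedule can be passed wherever optax expects a learning rate\n",
"tx = optax.sgd(learning_rate=schedule_fn, momentum=0.9, nesterov=True)"
]
},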
489 | {
490 | "cell_type": "markdown",
491 | "metadata": {
492 | "id": "WFAeHIEwL0ZH"
493 | },
494 | "source": [
495 | "Now let's go beyond these extremely simple models!"
496 | ]
497 | },
498 | {
499 | "cell_type": "markdown",
500 | "metadata": {
501 | "id": "7_33y-bTl6bd"
502 | },
503 | "source": [
504 | "### Creating custom models ⭐"
505 | ]
506 | },
507 | {
508 | "cell_type": "code",
509 | "execution_count": null,
510 | "metadata": {
511 | "id": "JOrJHqTSl75M"
512 | },
513 | "outputs": [],
514 | "source": [
515 | "class MLP(nn.Module):\n",
516 | " num_neurons_per_layer: Sequence[int] # data field (nn.Module subclasses are Python dataclasses)\n",
517 | "\n",
518 | " def setup(self): # we use setup instead of __init__, since the dataclass auto-generates __init__\n",
519 | " self.layers = [nn.Dense(n) for n in self.num_neurons_per_layer]\n",
520 | "\n",
521 | " def __call__(self, x):\n",
522 | " activation = x\n",
523 | " for i, layer in enumerate(self.layers):\n",
524 | " activation = layer(activation)\n",
525 | " if i != len(self.layers) - 1:\n",
526 | " activation = nn.relu(activation)\n",
527 | " return activation\n",
528 | "\n",
529 | "x_key, init_key = random.split(random.PRNGKey(seed))\n",
530 | "\n",
531 | "model = MLP(num_neurons_per_layer=[16, 8, 1]) # define an MLP model\n",
532 | "x = random.uniform(x_key, (4,4)) # dummy input\n",
533 | "params = model.init(init_key, x) # initialize via init\n",
534 | "y = model.apply(params, x) # do a forward pass via apply\n",
535 | "\n",
536 | "print(jax.tree_util.tree_map(jnp.shape, params))\n",
537 | "print(f'Output: {y}')\n",
538 | "\n",
539 | "# todo: exercise - use @nn.compact pattern instead\n",
540 | "# todo: check out https://realpython.com/python-data-classes/"
541 | ]
542 | },
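{
"cell_type": "markdown",
"metadata": {},
"source": [
"One possible solution to the `@nn.compact` exercise above is sketched below; it is not the only way to do it. Submodules are declared inline inside `__call__` instead of in `setup`. As a result, the auto-generated parameter names change from `layers_0`, `layers_1`, ... to `Dense_0`, `Dense_1`, .... The cell uses new names (`CompactMLP`, `compact_model`) so the variables above are left intact."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Sequence\n",
"import jax\n",
"from jax import random, numpy as jnp\n",
"from flax import linen as nn\n",
"\n",
"class CompactMLP(nn.Module):\n",
"    num_neurons_per_layer: Sequence[int]\n",
"\n",
"    @nn.compact\n",
"    def __call__(self, x):\n",
"        for i, n in enumerate(self.num_neurons_per_layer):\n",
"            x = nn.Dense(n)(x)  # defined inline, auto-named Dense_0, Dense_1, ...\n",
"            if i != len(self.num_neurons_per_layer) - 1:\n",
"                x = nn.relu(x)\n",
"        return x\n",
"\n",
"mlp_x_key, mlp_init_key = random.split(random.PRNGKey(0))\n",
"compact_model = CompactMLP(num_neurons_per_layer=[16, 8, 1])\n",
"x_dummy = random.uniform(mlp_x_key, (4, 4))\n",
"compact_params = compact_model.init(mlp_init_key, x_dummy)\n",
"print(jax.tree_util.tree_map(jnp.shape, compact_params))\n",
"assert compact_model.apply(compact_params, x_dummy).shape == (4, 1)"
]
},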
543 | {
544 | "cell_type": "markdown",
545 | "metadata": {
546 | "id": "TEhC-WdPnAYp"
547 | },
548 | "source": [
549 | "Great! \n",
550 | "\n",
551 | "Now that we know how to build more complex models, let's dive deeper and see how the `nn.Dense` module itself is implemented.\n",
552 | "\n",
553 | "#### Introducing \"param\""
554 | ]
555 | },
556 | {
557 | "cell_type": "code",
558 | "execution_count": null,
559 | "metadata": {
560 | "id": "Z9YhSgxjnBQg"
561 | },
562 | "outputs": [],
563 | "source": [
564 | "class MyDenseImp(nn.Module):\n",
565 | " num_neurons: int\n",
566 | " weight_init: Callable = nn.initializers.lecun_normal()\n",
567 | " bias_init: Callable = nn.initializers.zeros\n",
568 | "\n",
569 | " @nn.compact\n",
570 | " def __call__(self, x):\n",
571 | " weight = self.param('weight', # parameter name (as it will appear in the params dict)\n",
572 | " self.weight_init, # initialization function, RNG passed implicitly through init fn\n",
573 | " (x.shape[-1], self.num_neurons)) # shape info\n",
574 | " bias = self.param('bias', self.bias_init, (self.num_neurons,))\n",
575 | "\n",
576 | " return jnp.dot(x, weight) + bias\n",
577 | "\n",
578 | "x_key, init_key = random.split(random.PRNGKey(seed))\n",
579 | "\n",
580 | "model = MyDenseImp(num_neurons=3) # initialize the model\n",
581 | "x = random.uniform(x_key, (4,4)) # dummy input\n",
582 | "params = model.init(init_key, x) # initialize via init\n",
583 | "y = model.apply(params, x) # do a forward pass via apply\n",
584 | "\n",
585 | "print(jax.tree_util.tree_map(jnp.shape, params))\n",
586 | "print(f'Output: {y}')\n",
587 | "\n",
588 | "# todo: exercise - check out the source code:\n",
589 | "# https://github.com/google/flax/blob/main/flax/linen/linear.py\n",
590 | "# https://github.com/google/jax/blob/main/jax/_src/nn/initializers.py#L150 <- to see why we call lecun_normal() (it returns an initializer) but pass zeros as-is (it already is an initializer)"
591 | ]
592 | },
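{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check follows as a sketch; its only assumption is the parameter renaming it performs. If we feed the same weights to Flax's built-in `nn.Dense`, whose parameters are named `kernel` and `bias`, it should produce the same output as `MyDenseImp`. The class is re-declared so the cell is self-contained."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"from jax import random, numpy as jnp\n",
"from flax import linen as nn\n",
"\n",
"class MyDenseImp(nn.Module):\n",
"    num_neurons: int\n",
"    weight_init: Callable = nn.initializers.lecun_normal()\n",
"    bias_init: Callable = nn.initializers.zeros\n",
"\n",
"    @nn.compact\n",
"    def __call__(self, x):\n",
"        weight = self.param('weight', self.weight_init, (x.shape[-1], self.num_neurons))\n",
"        bias = self.param('bias', self.bias_init, (self.num_neurons,))\n",
"        return jnp.dot(x, weight) + bias\n",
"\n",
"dense_x_key, dense_init_key = random.split(random.PRNGKey(0))\n",
"x_check = random.uniform(dense_x_key, (4, 4))\n",
"my_params = MyDenseImp(num_neurons=3).init(dense_init_key, x_check)\n",
"\n",
"# Rename 'weight' -> 'kernel' so the same values can be fed to the built-in layer\n",
"dense_params = {'params': {'kernel': my_params['params']['weight'],\n",
"                           'bias': my_params['params']['bias']}}\n",
"y_mine = MyDenseImp(num_neurons=3).apply(my_params, x_check)\n",
"y_flax = nn.Dense(features=3).apply(dense_params, x_check)\n",
"assert jnp.allclose(y_mine, y_flax, atol=1e-5)"
]
},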
593 | {
594 | "cell_type": "code",
595 | "execution_count": null,
596 | "metadata": {
597 | "id": "AqCPhl9fBI_Z"
598 | },
599 | "outputs": [],
600 | "source": [
601 | "from inspect import signature\n",
602 | "\n",
603 | "# You can see it expects a PRNG key and it is passed implicitly through the init fn (same for zeros)\n",
604 | "print(signature(nn.initializers.lecun_normal()))"
605 | ]
606 | },
607 | {
608 | "cell_type": "markdown",
609 | "metadata": {
610 | "id": "MWB8HvLHn6g0"
611 | },
612 | "source": [
613 | "So far we've only seen **trainable** params. \n",
614 | "\n",
615 | "ML models often have variables that are part of the state but are not optimized via gradient descent (e.g. BatchNorm's running statistics).\n",
616 | "\n",
617 | "Let's see how we can handle them using a simple (and contrived) example!\n",
618 | "\n",
619 | "#### Introducing \"variable\"\n",
620 | "\n",
621 | "*Note on terminology: `variable` is the broader term; it includes both params (trainable variables) and non-trainable variables.*"
622 | ]
623 | },
624 | {
625 | "cell_type": "code",
626 | "execution_count": null,
627 | "metadata": {
628 | "id": "oGE6qTHHngYh"
629 | },
630 | "outputs": [],
631 | "source": [
632 | "class BiasAdderWithRunningMean(nn.Module):\n",
633 | " decay: float = 0.99\n",
634 | "\n",
635 | " @nn.compact\n",
636 | " def __call__(self, x):\n",
637 | " is_initialized = self.has_variable('batch_stats', 'ema')\n",
638 | "\n",
639 | " # 'batch_stats' is not an arbitrary name!\n",
640 | " # Flax uses that name in its implementation of BatchNorm (hard-coded, probably not the best of designs?)\n",
641 | " ema = self.variable('batch_stats', 'ema', lambda shape: jnp.zeros(shape), x.shape[1:])\n",
642 | "\n",
643 | " # self.param will by default add this variable to the 'params' collection (vs 'batch_stats' above)\n",
644 | " # Another idiosyncrasy: the init fn must accept a key even though we don't actually use it\n",
645 | " bias = self.param('bias', lambda key, shape: jnp.zeros(shape), x.shape[1:])\n",
646 | "\n",
647 | " if is_initialized:\n",
648 | " # self.variable returns a reference, hence .value (the mean over axis 0 keeps ema's original shape)\n",
649 | " ema.value = self.decay * ema.value + (1.0 - self.decay) * jnp.mean(x, axis=0)\n",
650 | "\n",
651 | " return x - ema.value + bias\n",
652 | "\n",
653 | "x_key, init_key = random.split(random.PRNGKey(seed))\n",
654 | "\n",
655 | "model = BiasAdderWithRunningMean()\n",
656 | "x = random.uniform(x_key, (10,4)) # dummy input\n",
657 | "variables = model.init(init_key, x)\n",
658 | "print(f'Multiple collections = {variables}') # we can now see a new collection 'batch_stats'\n",
659 | "\n",
660 | "# We have to pass mutable=['batch_stats'] because, unlike regular params, these variables are modified during\n",
661 | "# the forward pass. We can't keep state internally (JAX functions are pure), so apply returns the updated variables.\n",
662 | "y, updated_non_trainable_params = model.apply(variables, x, mutable=['batch_stats'])\n",
663 | "print(updated_non_trainable_params)"
664 | ]
665 | },
666 | {
667 | "cell_type": "code",
668 | "execution_count": null,
669 | "metadata": {
670 | "id": "PuzwVt8RoHvY"
671 | },
672 | "outputs": [],
673 | "source": [
674 | "# Let's see how we could train such a model!\n",
675 | "def update_step(opt, apply_fn, x, opt_state, params, non_trainable_params):\n",
676 | "\n",
677 | " def loss_fn(params):\n",
678 | " y, updated_non_trainable_params = apply_fn(\n",
679 | " {'params': params, **non_trainable_params}, \n",
680 | " x, mutable=list(non_trainable_params.keys()))\n",
681 | " \n",
682 | " loss = ((x - y) ** 2).sum() # a dummy loss, just for demo purposes\n",
683 | "\n",
684 | " return loss, updated_non_trainable_params\n",
685 | "\n",
686 | " (loss, non_trainable_params), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)\n",
687 | " updates, opt_state = opt.update(grads, opt_state)\n",
688 | " params = optax.apply_updates(params, updates)\n",
689 | " \n",
690 | " return opt_state, params, non_trainable_params # all of these represent the state - ugly, for now\n",
691 | "\n",
692 | "model = BiasAdderWithRunningMean()\n",
693 | "x = jnp.ones((10,4)) # dummy input, using ones because it's easier to see what's going on\n",
694 | "\n",
695 | "variables = model.init(random.PRNGKey(seed), x)\n",
696 | "non_trainable_params, params = flax.core.pop(variables, 'params') # works for both FrozenDict and regular dict\n",
697 | "del variables # delete variables to avoid wasting resources (this pattern is used in the official code)\n",
698 | "\n",
699 | "sgd_opt = optax.sgd(learning_rate=0.1) # in the official examples this is usually named 'tx' (from opTaX)\n",
700 | "opt_state = sgd_opt.init(params)\n",
701 | "\n",
702 | "for _ in range(3):\n",
703 | " # We'll later see how TrainState abstraction will make this step much more elegant!\n",
704 | " opt_state, params, non_trainable_params = update_step(sgd_opt, model.apply, x, opt_state, params, non_trainable_params)\n",
705 | " print(non_trainable_params)"
706 | ]
707 | },
708 | {
709 | "cell_type": "markdown",
710 | "metadata": {
711 | "id": "gzWUq5vBrWMe"
712 | },
713 | "source": [
714 | "Now that we understand params and variables, let's go up a level of abstraction again!\n",
715 | "\n",
716 | "Certain layers like BatchNorm will use variables in the background.\n",
717 | "\n",
718 | "Let's look at one last example that is as complicated as it gets conceptually when it comes to Flax's idiosyncrasies, while still being high-level."
719 | ]
720 | },
721 | {
722 | "cell_type": "code",
723 | "execution_count": null,
724 | "metadata": {
725 | "id": "rDw2986orY0a"
726 | },
727 | "outputs": [],
728 | "source": [
729 | "class DDNBlock(nn.Module):\n",
730 | " \"\"\"Dense, dropout + batchnorm combo.\n",
731 | "\n",
732 | " Contains trainable variables (params), non-trainable variables (batch stats),\n",
733 | " and stochasticity in the forward pass (because of dropout).\n",
734 | " \"\"\"\n",
735 | " num_neurons: int\n",
736 | " training: bool\n",
737 | "\n",
738 | " @nn.compact\n",
739 | " def __call__(self, x):\n",
740 | " x = nn.Dense(self.num_neurons)(x)\n",
741 | " x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)\n",
742 | " x = nn.BatchNorm(use_running_average=not self.training)(x)\n",
743 | " return x\n",
744 | "\n",
745 | "key1, key2, key3, key4 = random.split(random.PRNGKey(seed), 4)\n",
746 | "\n",
747 | "model = DDNBlock(num_neurons=3, training=True)\n",
748 | "x = random.uniform(key1, (3,4,4))\n",
749 | "\n",
750 | "# New: because of Dropout we now have to pass a separate 'dropout' RNG key - a bit unusual, but you get used to it\n",
751 | "variables = model.init({'params': key2, 'dropout': key3}, x)\n",
752 | "print(variables)\n",
753 | "\n",
754 | "# And same here, everything else remains the same as the previous example\n",
755 | "y, non_trainable_params = model.apply(variables, x, rngs={'dropout': key4}, mutable=['batch_stats'])\n",
756 | "\n",
757 | "# Let's run these model variables during \"evaluation\":\n",
758 | "eval_model = DDNBlock(num_neurons=3, training=False)\n",
759 | "# Because training=False we have no stochasticity in the forward pass, nor do we update the stats\n",
760 | "y = eval_model.apply(variables, x)"
761 | ]
762 | },
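{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the evaluation call above still uses the initial `variables`, so the running statistics computed during the training pass are discarded. Here is a minimal sketch of how the updated `batch_stats` returned by `apply` could be merged back before evaluation. The block is re-declared so the cell is self-contained."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jax import random, numpy as jnp\n",
"from flax import linen as nn\n",
"\n",
"class DDNBlock(nn.Module):\n",
"    num_neurons: int\n",
"    training: bool\n",
"\n",
"    @nn.compact\n",
"    def __call__(self, x):\n",
"        x = nn.Dense(self.num_neurons)(x)\n",
"        x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)\n",
"        x = nn.BatchNorm(use_running_average=not self.training)(x)\n",
"        return x\n",
"\n",
"k1, k2, k3, k4 = random.split(random.PRNGKey(0), 4)\n",
"x_ddn = random.uniform(k1, (3, 4, 4))\n",
"train_block = DDNBlock(num_neurons=3, training=True)\n",
"ddn_variables = train_block.init({'params': k2, 'dropout': k3}, x_ddn)\n",
"\n",
"# One training-mode forward pass returns the updated running statistics\n",
"_, updates = train_block.apply(ddn_variables, x_ddn, rngs={'dropout': k4}, mutable=['batch_stats'])\n",
"\n",
"# Combine the (unchanged) params with the new batch_stats for evaluation\n",
"new_variables = {'params': ddn_variables['params'], 'batch_stats': updates['batch_stats']}\n",
"eval_block = DDNBlock(num_neurons=3, training=False)\n",
"y_eval = eval_block.apply(new_variables, x_ddn)  # deterministic, no mutable needed\n",
"assert y_eval.shape == (3, 4, 3)"
]
},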
763 | {
764 | "cell_type": "markdown",
765 | "metadata": {
766 | "id": "Ys1y-yM8vzT8"
767 | },
768 | "source": [
769 | "### A fully-fledged CNN on MNIST example in Flax! 💥\n",
770 | "\n",
771 | "This is a modified version of the official MNIST example: https://github.com/google/flax/tree/main/examples/mnist\n",
772 | "\n",
773 | "We'll use PyTorch data loading instead of TFDS (TensorFlow Datasets).\n",
774 | "\n",
775 | "Let's start by defining a model:"
776 | ]
777 | },
778 | {
779 | "cell_type": "code",
780 | "execution_count": 3,
781 | "metadata": {
782 | "id": "MD8t9K2Nv0yC"
783 | },
784 | "outputs": [],
785 | "source": [
786 | "class CNN(nn.Module): # lots of hardcoding, but it serves a purpose for a simple demo\n",
787 | " @nn.compact\n",
788 | " def __call__(self, x):\n",
789 | " x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n",
790 | " x = nn.relu(x)\n",
791 | " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
792 | " x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n",
793 | " x = nn.relu(x)\n",
794 | " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
795 | " x = x.reshape((x.shape[0], -1)) # flatten\n",
796 | " x = nn.Dense(features=256)(x)\n",
797 | " x = nn.relu(x)\n",
798 | " x = nn.Dense(features=10)(x)\n",
799 | " x = nn.log_softmax(x)\n",
800 | " return x"
801 | ]
802 | },
803 | {
804 | "cell_type": "markdown",
805 | "metadata": {
806 | "id": "rVgWLMhiSAYv"
807 | },
808 | "source": [
809 | "Let's add data loading support using PyTorch!\n",
810 | "\n",
811 | "I'll be reusing code from [tutorial #3](https://github.com/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_3_JAX_Neural_Network_from_Scratch_Colab.ipynb):"
812 | ]
813 | },
814 | {
815 | "cell_type": "code",
816 | "execution_count": 4,
817 | "metadata": {
818 | "id": "UZ-og2UOUUWD"
819 | },
820 | "outputs": [],
821 | "source": [
822 | "def custom_transform(x):\n",
823 | " # A couple of modifications here compared to tutorial #3 since we're using a CNN\n",
824 | " # Input: (28, 28) uint8 [0, 255] PIL.Image, Output: (28, 28, 1) float32 [0, 1] np array\n",
825 | " return np.expand_dims(np.array(x, dtype=np.float32), axis=2) / 255.\n",
826 | "\n",
827 | "def custom_collate_fn(batch):\n",
828 | " \"\"\"Provides us with batches of numpy arrays and not PyTorch's tensors.\"\"\"\n",
829 | " transposed_data = list(zip(*batch))\n",
830 | "\n",
831 | " labels = np.array(transposed_data[1])\n",
832 | " imgs = np.stack(transposed_data[0])\n",
833 | "\n",
834 | " return imgs, labels\n",
835 | "\n",
836 | "mnist_img_size = (28, 28, 1)\n",
837 | "batch_size = 128\n",
838 | "\n",
839 | "train_dataset = MNIST(root='train_mnist', train=True, download=True, transform=custom_transform)\n",
840 | "test_dataset = MNIST(root='test_mnist', train=False, download=True, transform=custom_transform)\n",
841 | "\n",
842 | "train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)\n",
843 | "test_loader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)\n",
844 | "\n",
845 | "# Optimization: load the whole dataset into memory.\n",
846 | "# Note: .data bypasses custom_transform, so we replicate it here:\n",
847 | "# (N, 28, 28) uint8 [0, 255] -> (N, 28, 28, 1) float32 [0, 1]\n",
848 | "train_images = jnp.expand_dims(jnp.array(train_dataset.data.numpy(), dtype=jnp.float32), axis=3) / 255.\n",
849 | "train_lbls = jnp.array(train_dataset.targets.numpy())\n",
850 | "\n",
851 | "test_images = jnp.expand_dims(jnp.array(test_dataset.data.numpy(), dtype=jnp.float32), axis=3) / 255.\n",
852 | "test_lbls = jnp.array(test_dataset.targets.numpy())"
853 | ]
854 | },
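{
"cell_type": "markdown",
"metadata": {},
"source": [
"This cell is a quick check of the two helpers that doesn't download anything. It is a sketch: the random uint8 arrays are hypothetical stand-ins for the PIL images that MNIST passes to the transform, and `np.array` converts either one the same way. The helpers are re-declared so the cell is self-contained."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def custom_transform(x):\n",
"    return np.expand_dims(np.array(x, dtype=np.float32), axis=2) / 255.\n",
"\n",
"def custom_collate_fn(batch):\n",
"    transposed_data = list(zip(*batch))\n",
"    labels = np.array(transposed_data[1])\n",
"    imgs = np.stack(transposed_data[0])\n",
"    return imgs, labels\n",
"\n",
"# Hypothetical fake batch: 4 random uint8 images standing in for MNIST samples\n",
"rng = np.random.default_rng(0)\n",
"fake_batch = [(custom_transform(rng.integers(0, 256, (28, 28), dtype=np.uint8)), label)\n",
"              for label in range(4)]\n",
"imgs, labels = custom_collate_fn(fake_batch)\n",
"\n",
"assert imgs.shape == (4, 28, 28, 1) and imgs.dtype == np.float32\n",
"assert 0.0 <= imgs.min() and imgs.max() <= 1.0\n",
"assert labels.tolist() == [0, 1, 2, 3]"
]
},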
855 | {
856 | "cell_type": "code",
857 | "execution_count": 5,
858 | "metadata": {
859 | "colab": {
860 | "base_uri": "https://localhost:8080/",
861 | "height": 282
862 | },
863 | "id": "2HeXX51NU0k6",
864 | "outputId": "43dad5bf-20c2-4c5a-9705-12b2e422f915"
865 | },
866 | "outputs": [
867 | {
868 | "name": "stdout",
869 | "output_type": "stream",
870 | "text": [
871 | "7\n"
872 | ]
873 | },
874 | {
875 | "data": {
876 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANiklEQVR4nO3df4wc9XnH8c8n/kV8QGtDcF3j4ISQqE4aSHWBRNDKESUFImSiJBRLtVyJ5lALElRRW0QVBalVSlEIok0aySluHESgaQBhJTSNa6W1UKljg4yxgdaEmsau8QFOaxPAP/DTP24cHXD7vWNndmft5/2SVrs7z87Oo/F9PLMzO/t1RAjA8e9tbTcAoD8IO5AEYQeSIOxAEoQdSGJ6Pxc207PiBA31c5FAKq/qZzoYBzxRrVbYbV8s6XZJ0yT9bUTcXHr9CRrSeb6wziIBFGyIdR1rXe/G254m6auSLpG0WNIy24u7fT8AvVXnM/u5kp6OiGci4qCkeyQtbaYtAE2rE/YFkn4y7vnOatrr2B6xvcn2pkM6UGNxAOro+dH4iFgZEcMRMTxDs3q9OAAd1An7LkkLxz0/vZoGYADVCftGSWfZfpftmZKulLSmmbYANK3rU28Rcdj2tZL+SWOn3lZFxLbGOgPQqFrn2SPiQUkPNtQLgB7i67JAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJGoN2Wx7h6T9kl6TdDgihptoCkDzaoW98rGIeKGB9wHQQ+zGA0nUDXtI+oHtR2yPTPQC2yO2N9nedEgHai4OQLfq7sZfEBG7bJ8maa3tpyJi/fgXRMRKSSsl6WTPjZrLA9ClWlv2iNhV3Y9Kul/SuU00BaB5XYfd9pDtk44+lvRxSVubagxAs+rsxs+TdL/to+/zrYj4fiNdAWhc12GPiGcknd1gLwB6iFNvQBKEHUiCsANJEHYgCcIOJNHEhTApvPjZj3asvXP508V5nxqdV6wfPDCjWF9wd7k+e+dLHWtHNj9RnBd5sGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4zz5Ff/xH3+pY+9TQT8szn1lz4UvK5R2HX+5Yu/35j9Vc+LHrR6NndKwN3foLxXmnr3uk6XZax5YdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JwRP8GaTnZc+M8X9i35TXpZ58+r2PthQ+W/8+c82R5Hf/0V1ysz/zg/xbrt3zgvo61i97+SnHe7718YrH+idmdr5Wv65U4WKxvODBUrC854VDXy37P964u1t87srHr927ThlinfbF3wj8otuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATXs0/R0Hc2FGr13vvkerPrr39pScfan5+/qLzsfy3/5v0tS97TRUdTM/2VI8X60Jbdxfop6+8t1n91Zuff25+9o/xb/MejSbfstlfZHrW9ddy0ubbX2t5e3c/pbZsA6prKbvw3JF38hmk3SFoXEWdJWlc9BzDAJg17RKyXtPcNk5dKWl09Xi3p8ob7AtCwbj+zz4uIox+onpPUcTAz2yOSRiTpBM3ucnEA6qp9ND7GrqTpeKVHRKyMiOGIGJ6hWXUXB6BL3YZ9j+35klTdjzbXEoBe6DbsayStqB6vkPRAM+0A6JVJP7Pbvltjv1x+qu2dkr4g6WZJ37Z9laRnJV3RyyZRdvi5PR1rQ/d2rknSa5O899B3Xuyio2bs+b2PFuvvn1n+8/3S3vd1rC36u2eK8x4uVo9Nk4Y9IpZ1KB2bv0IBJMXXZYEkCDuQBGEHkiDsQBKEHUiCS1zRmulnLCzWv3LjV4r1GZ5WrP/D7b/ZsXbK7oeL8x6P2LIDSRB2IAnCDiRB
2IEkCDuQBGEHkiDsQBKcZ0drnvrDBcX6h2eVh7LedrA8HPXcJ15+yz0dz9iyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASnGdHTx34xIc71h799G2TzF0eQej3r7uuWH/7v/1okvfPhS07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBeXb01H9f0nl7cqLL59GX/ddFxfrs7z9WrEexms+kW3bbq2yP2t46btpNtnfZ3lzdLu1tmwDqmspu/DckXTzB9Nsi4pzq9mCzbQFo2qRhj4j1kvb2oRcAPVTnAN21trdUu/lzOr3I9ojtTbY3HdKBGosDUEe3Yf+apDMlnSNpt6RbO70wIlZGxHBEDM+Y5MIGAL3TVdgjYk9EvBYRRyR9XdK5zbYFoGldhd32/HFPPylpa6fXAhgMk55nt323pCWSTrW9U9IXJC2xfY7GTmXukHR1D3vEAHvbSScV68t//aGOtX1HXi3OO/rFdxfrsw5sLNbxepOGPSKWTTD5jh70AqCH+LoskARhB5Ig7EAShB1IgrADSXCJK2rZftP7i/Xvnvo3HWtLt3+qOO+sBzm11iS27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOfZUfR/v/ORYn3Lb/9Vsf7jw4c61l76y9OL887S7mIdbw1bdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgvPsyU1f8MvF+vWf//tifZbLf0JXPra8Y+0d/8j16v3Elh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuA8+3HO08v/xGd/d2ex/pkTXyzW79p/WrE+7/OdtydHinOiaZNu2W0vtP1D20/Y3mb7umr6XNtrbW+v7uf0vl0A3ZrKbvxhSZ+LiMWSPiLpGtuLJd0gaV1EnCVpXfUcwICaNOwRsTsiHq0e75f0pKQFkpZKWl29bLWky3vVJID63tJndtuLJH1I0gZJ8yLi6I+EPSdpXod5RiSNSNIJmt1tnwBqmvLReNsnSrpX0vURsW98LSJCUkw0X0SsjIjhiBieoVm1mgXQvSmF3fYMjQX9roi4r5q8x/b8qj5f0mhvWgTQhEl3421b0h2SnoyIL48rrZG0QtLN1f0DPekQ9Zz9vmL5z067s9bbf/WLnynWf/Gxh2u9P5ozlc/s50taLulx25uraTdqLOTftn2VpGclXdGbFgE0YdKwR8RDktyhfGGz7QDoFb4uCyRB2IEkCDuQBGEHkiDsQBJc4nocmLb4vR1rI/fU+/rD4lXXFOuL7vz3Wu+P/mHLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJ79OPDUH3T+Yd/LZu/rWJuK0//lYPkFMeEPFGEAsWUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4z34MePWyc4v1dZfdWqgy5BbGsGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSSmMj77QknflDRPUkhaGRG3275J0mclPV+99MaIeLBXjWb2P+dPK9bfOb37c+l37T+tWJ+xr3w9O1ezHzum8qWaw5I+FxGP2j5J0iO211a12yLiS71rD0BTpjI++25Ju6vH+20/KWlBrxsD0Ky39Jnd9iJJH5K0oZp0re0ttlfZnvC3kWyP2N5ke9MhHajVLIDuTTnstk+UdK+k6yNin6SvSTpT0jka2/JP+AXtiFgZEcMRMTxDsxpoGUA3phR22zM0FvS7IuI+SYqIPRHxWkQckfR1SeWrNQC0atKw27akOyQ9GRFfHjd9/riXfVLS1ubbA9CUqRyNP1/SckmP295cTbtR0jLb52js7MsOSVf3pEPU8hcvLi7WH/6tRcV67H68wW7QpqkcjX9IkicocU4dOIbwDTogCcIOJEHYgSQIO5AEYQeSIOxAEo4+Drl7sufGeb6wb8sDstkQ67Qv9k50qpwtO5AFYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dfz7Lafl/TsuEmnSnqhbw28NYPa26D2JdFbt5rs7YyIeMdEhb6G/U0L
tzdFxHBrDRQMam+D2pdEb93qV2/sxgNJEHYgibbDvrLl5ZcMam+D2pdEb93qS2+tfmYH0D9tb9kB9AlhB5JoJey2L7b9H7aftn1DGz10YnuH7cdtb7a9qeVeVtketb113LS5ttfa3l7dTzjGXku93WR7V7XuNtu+tKXeFtr+oe0nbG+zfV01vdV1V+irL+ut75/ZbU+T9J+SLpK0U9JGScsi4om+NtKB7R2ShiOi9S9g2P4NSS9J+mZEfKCadoukvRFxc/Uf5ZyI+JMB6e0mSS+1PYx3NVrR/PHDjEu6XNLvqsV1V+jrCvVhvbWxZT9X0tMR8UxEHJR0j6SlLfQx8CJivaS9b5i8VNLq6vFqjf2x9F2H3gZCROyOiEerx/slHR1mvNV1V+irL9oI+wJJPxn3fKcGa7z3kPQD24/YHmm7mQnMi4jd1ePnJM1rs5kJTDqMdz+9YZjxgVl33Qx/XhcH6N7sgoj4NUmXSLqm2l0dSDH2GWyQzp1OaRjvfplgmPGfa3PddTv8eV1thH2XpIXjnp9eTRsIEbGruh+VdL8GbyjqPUdH0K3uR1vu5+cGaRjviYYZ1wCsuzaHP28j7BslnWX7XbZnSrpS0poW+ngT20PVgRPZHpL0cQ3eUNRrJK2oHq+Q9ECLvbzOoAzj3WmYcbW87lof/jwi+n6TdKnGjsj/WNKfttFDh77eLemx6rat7d4k3a2x3bpDGju2cZWkUyStk7Rd0j9LmjtAvd0p6XFJWzQWrPkt9XaBxnbRt0jaXN0ubXvdFfrqy3rj67JAEhygA5Ig7EAShB1IgrADSRB2IAnCDiRB2IEk/h9BCfQTVPflJQAAAABJRU5ErkJggg==\n",
877 | "text/plain": [
878 | ""
879 | ]
880 | },
881 | "metadata": {
882 | "needs_background": "light"
883 | },
884 | "output_type": "display_data"
885 | }
886 | ],
887 | "source": [
888 | "# Visualize a single image\n",
889 | "imgs, lbls = next(iter(test_loader))\n",
890 | "img = imgs[0].reshape(mnist_img_size)[:, :, 0]\n",
891 | "gt_lbl = lbls[0]\n",
892 | "\n",
893 | "print(gt_lbl)\n",
894 | "plt.imshow(img); plt.show()"
895 | ]
896 | },
897 | {
898 | "cell_type": "markdown",
899 | "metadata": {
900 | "id": "TsGPQKx0SPL-"
901 | },
902 | "source": [
903 | "Great, we have our data pipeline ready and the model architecture defined.\n",
904 | "\n",
905 | "Now let's define the core training functions:"
906 | ]
907 | },
908 | {
909 | "cell_type": "code",
910 | "execution_count": 6,
911 | "metadata": {
912 | "id": "qD8ztbEsVM43"
913 | },
914 | "outputs": [],
915 | "source": [
916 | "@jax.jit\n",
917 | "def train_step(state, imgs, gt_labels):\n",
918 | "    def loss_fn(params):\n",
919 | "        logits = CNN().apply({'params': params}, imgs)\n",
920 | "        one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10)\n",
921 | "        loss = -jnp.mean(jnp.sum(one_hot_gt_labels * logits, axis=-1))\n",
922 | "        return loss, logits\n",
923 | "\n",
924 | "    (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)\n",
925 | "    state = state.apply_gradients(grads=grads)  # this is the whole update now! concise!\n",
926 | "    metrics = compute_metrics(logits=logits, gt_labels=gt_labels)  # duplicating loss calculation but it's a bit cleaner\n",
927 | "    return state, metrics\n",
928 | "\n",
929 | "@jax.jit\n",
930 | "def eval_step(state, imgs, gt_labels):\n",
931 | "    logits = CNN().apply({'params': state.params}, imgs)\n",
932 | "    return compute_metrics(logits=logits, gt_labels=gt_labels)"
933 | ]
934 | },
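{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of a sanity check for the loss above. `-jnp.sum(one_hot_gt_labels * logits)` is the cross-entropy only if `CNN` already ends with a log-softmax. We assume that here, since the model is defined in an earlier cell. Under that assumption, the loss should match `optax.softmax_cross_entropy_with_integer_labels` applied to the raw, pre-softmax scores. The scores and labels below (`raw_scores`, `example_labels`) are made up purely for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"\n",
"# Made-up raw (pre-softmax) scores for a batch of 4 examples and 10 classes\n",
"raw_scores = jax.random.normal(jax.random.PRNGKey(0), (4, 10))\n",
"example_labels = jnp.array([3, 1, 0, 7])\n",
"\n",
"# Assumption: the CNN outputs log-probabilities, i.e. log_softmax of the raw scores\n",
"log_probs = jax.nn.log_softmax(raw_scores)\n",
"one_hot = jax.nn.one_hot(example_labels, num_classes=10)\n",
"manual_loss = -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))\n",
"\n",
"# optax takes the raw scores plus integer labels and returns per-example losses\n",
"optax_loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(raw_scores, example_labels))\n",
"\n",
"print(manual_loss, optax_loss)  # should agree up to floating-point error"
]
},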
935 | {
936 | "cell_type": "code",
937 | "execution_count": 7,
938 | "metadata": {
939 | "id": "v5VblVs2VWxo"
940 | },
941 | "outputs": [],
942 | "source": [
943 | "def train_one_epoch(state, dataloader, epoch):\n",
944 | "    \"\"\"Train for 1 epoch on the training set.\"\"\"\n",
945 | "    batch_metrics = []\n",
946 | "    for imgs, labels in dataloader:\n",
947 | "        state, metrics = train_step(state, imgs, labels)\n",
948 | "        batch_metrics.append(metrics)\n",
949 | "\n",
950 | "    # Aggregate the metrics\n",
951 | "    batch_metrics_np = jax.device_get(batch_metrics)  # pull from the accelerator onto host (CPU)\n",
952 | "    epoch_metrics_np = {\n",
953 | "        k: np.mean([metrics[k] for metrics in batch_metrics_np])\n",
954 | "        for k in batch_metrics_np[0]\n",
955 | "    }\n",
956 | "\n",
957 | "    return state, epoch_metrics_np\n",
958 | "\n",
959 | "def evaluate_model(state, test_imgs, test_lbls):\n",
960 | "    \"\"\"Evaluate on the test set.\"\"\"\n",
961 | "    metrics = eval_step(state, test_imgs, test_lbls)\n",
962 | "    metrics = jax.device_get(metrics)  # pull from the accelerator onto host (CPU)\n",
963 | "    metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics)  # np.ndarray -> scalar\n",
964 | "    return metrics"
965 | ]
966 | },
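{
"cell_type": "markdown",
"metadata": {},
"source": [
"The metric aggregation in `train_one_epoch` can also be written with `jax.tree_util.tree_map`. Given several dicts with the same structure, it calls the function with the matching leaf from each one. This makes it a slightly more general alternative that would also cope with nested metric dicts. The following is a minimal sketch; the per-batch numbers in `example_batch_metrics` are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import numpy as np\n",
"\n",
"# Made-up per-batch metrics, shaped like the output of jax.device_get(batch_metrics)\n",
"example_batch_metrics = [\n",
"    {'loss': np.float32(0.5), 'accuracy': np.float32(0.75)},\n",
"    {'loss': np.float32(0.3), 'accuracy': np.float32(0.875)},\n",
"]\n",
"\n",
"# With several trees, tree_map calls the lambda with the matching leaf from each dict\n",
"example_epoch_metrics = jax.tree_util.tree_map(\n",
"    lambda *xs: np.mean(xs), *example_batch_metrics\n",
")\n",
"print(example_epoch_metrics)"
]
},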
967 | {
968 | "cell_type": "code",
969 | "execution_count": 8,
970 | "metadata": {
971 | "id": "xiV5yiA4BKEk"
972 | },
973 | "outputs": [],
974 | "source": [
975 | "# This one will keep things nice and tidy compared to our previous examples\n",
976 | "def create_train_state(key, learning_rate, momentum):\n",
977 | "    cnn = CNN()\n",
978 | "    params = cnn.init(key, jnp.ones([1, *mnist_img_size]))['params']\n",
979 | "    sgd_opt = optax.sgd(learning_rate, momentum)\n",
980 | "    # TrainState is a simple built-in wrapper class that makes things a bit cleaner\n",
981 | "    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=sgd_opt)\n",
982 | "\n",
983 | "def compute_metrics(*, logits, gt_labels):\n",
984 | "    one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10)\n",
985 | "\n",
986 | "    loss = -jnp.mean(jnp.sum(one_hot_gt_labels * logits, axis=-1))\n",
987 | "    accuracy = jnp.mean(jnp.argmax(logits, -1) == gt_labels)\n",
988 | "\n",
989 | "    metrics = {\n",
990 | "        'loss': loss,\n",
991 | "        'accuracy': accuracy,\n",
992 | "    }\n",
993 | "    return metrics"
994 | ]
995 | },
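{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for what `TrainState` saves us, here is a minimal sketch of a single optimizer step done by hand with optax. It uses a made-up linear model; `toy_params`, `toy_loss`, `toy_x` and `toy_y` are all hypothetical. Our understanding is that `state.apply_gradients(grads=grads)` does essentially this internally: `tx.update`, then `optax.apply_updates`, then incrementing `state.step`. Check the Flax source if you need the exact details. The learning rate is lowered to 0.01 because 0.1 overshoots on this toy problem."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"\n",
"# Hypothetical toy model: a linear regressor with parameters stored in a dict (a pytree)\n",
"toy_params = {'w': jnp.array([1.0, -2.0]), 'b': jnp.array(0.5)}\n",
"\n",
"def toy_loss(params, x, y):\n",
"    pred = x @ params['w'] + params['b']\n",
"    return jnp.mean((pred - y) ** 2)\n",
"\n",
"toy_x = jnp.array([[1.0, 2.0], [3.0, 4.0]])\n",
"toy_y = jnp.array([1.0, 0.0])\n",
"\n",
"toy_tx = optax.sgd(learning_rate=0.01, momentum=0.9)\n",
"toy_opt_state = toy_tx.init(toy_params)  # TrainState.create does this for us\n",
"toy_step = 0\n",
"\n",
"toy_grads = jax.grad(toy_loss)(toy_params, toy_x, toy_y)\n",
"# What apply_gradients boils down to: turn grads into updates, apply them, bump the step\n",
"toy_updates, toy_opt_state = toy_tx.update(toy_grads, toy_opt_state, toy_params)\n",
"toy_new_params = optax.apply_updates(toy_params, toy_updates)\n",
"toy_step += 1\n",
"\n",
"print(toy_step, toy_loss(toy_params, toy_x, toy_y), toy_loss(toy_new_params, toy_x, toy_y))"
]
},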
996 | {
997 | "cell_type": "code",
998 | "execution_count": 9,
999 | "metadata": {
1000 | "colab": {
1001 | "base_uri": "https://localhost:8080/"
1002 | },
1003 | "id": "s8EFriHnVcJO",
1004 | "outputId": "cb40714f-6150-44d6-e1e0-290b72a23eda"
1005 | },
1006 | "outputs": [
1007 | {
1008 | "name": "stdout",
1009 | "output_type": "stream",
1010 | "text": [
1011 | "Train epoch: 1, loss: 0.2903152406215668, accuracy: 91.86198115348816\n",
1012 | "Test epoch: 1, loss: 44.35035705566406, accuracy: 94.77999806404114\n",
1013 | "Train epoch: 2, loss: 0.058339256793260574, accuracy: 98.23551177978516\n",
1014 | "Test epoch: 2, loss: 17.13631820678711, accuracy: 97.33999967575073\n"
1015 | ]
1016 | }
1017 | ],
1018 | "source": [
1019 | "# Finally, let's define the high-level training/test loop\n",
1020 | "seed = 0  # ideally these would live in a config file or be passed in as command-line flags\n",
1021 | "learning_rate = 0.1\n",
1022 | "momentum = 0.9\n",
1023 | "num_epochs = 2\n",
1024 | "batch_size = 32\n",
1025 | "\n",
1026 | "state = create_train_state(jax.random.PRNGKey(seed), learning_rate, momentum)  # not `train_state`: that name would shadow the flax module used inside create_train_state\n",
1027 | "\n",
1028 | "for epoch in range(1, num_epochs + 1):\n",
1029 | "    state, train_metrics = train_one_epoch(state, train_loader, epoch)\n",
1030 | "    print(f\"Train epoch: {epoch}, loss: {train_metrics['loss']}, accuracy: {train_metrics['accuracy'] * 100}\")\n",
1031 | "\n",
1032 | "    test_metrics = evaluate_model(state, test_images, test_lbls)\n",
1033 | "    print(f\"Test epoch: {epoch}, loss: {test_metrics['loss']}, accuracy: {test_metrics['accuracy'] * 100}\")\n",
1034 | "\n",
1035 | "# TODO (exercise): how would we go about adding dropout? What about BatchNorm? What would have to change?"
1036 | ]
1037 | },
1038 | {
1039 | "cell_type": "markdown",
1040 | "metadata": {
1041 | "id": "6U-BIjQ1v4ff"
1042 | },
1043 | "source": [
1044 | "Bonus point: a walk-through of the \"non-toy\", distributed ImageNet CNN training example.\n",
1045 | "\n",
1046 | "Head over to https://github.com/google/flax/tree/main/examples/imagenet\n",
1047 | "\n",
1048 | "You'll keep seeing the same pattern/structure in all official Flax examples."
1049 | ]
1050 | },
1051 | {
1052 | "cell_type": "markdown",
1053 | "metadata": {
1054 | "id": "6Q4C2M2tv_0J"
1055 | },
1056 | "source": [
1057 | "### Further learning resources 📚\n",
1058 | "\n",
1059 | "Aside from the [official docs](https://flax.readthedocs.io/en/latest/) and [examples](https://github.com/google/flax/tree/main/examples), I found [HuggingFace's Flax examples](https://github.com/huggingface/transformers/tree/master/examples/flax) and the resources from their [\"community week\"](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects) useful as well.\n",
1060 | "\n",
1061 | "Finally, [source code](https://github.com/google/flax) is also your friend, as the library is still evolving."
1062 | ]
1063 | },
1064 | {
1065 | "cell_type": "markdown",
1066 | "metadata": {
1067 | "id": "T5DqxlZ-SD3e"
1068 | },
1069 | "source": [
1070 | "### Connect with me ❤️\n",
1071 | "\n",
1072 | "Last but not least, I regularly post AI-related content (paper summaries, AI news, etc.) on my Twitter/LinkedIn. We also have an ever-growing Discord community (1600+ members at the time of writing). If you care about any of these, I encourage you to connect! \n",
1073 | "\n",
1074 | "Social: \n",
1075 | "💼 LinkedIn - https://www.linkedin.com/in/aleksagordic/ \n",
1076 | "🐦 Twitter - https://twitter.com/gordic_aleksa \n",
1077 | "👨‍👩‍👧‍👦 Discord - https://discord.gg/peBrCpheKE \n",
1078 | "🙏 Patreon - https://www.patreon.com/theaiepiphany \n",
1079 | "\n",
1080 | "Content: \n",
1081 | "📺 YouTube - https://www.youtube.com/c/TheAIEpiphany/ \n",
1082 | "📚 Medium - https://gordicaleksa.medium.com/ \n",
1083 | "💻 GitHub - https://github.com/gordicaleksa \n",
1084 | "📢 AI Newsletter - https://aiepiphany.substack.com/ "
1085 | ]
1086 | }
1087 | ],
1088 | "metadata": {
1089 | "accelerator": "GPU",
1090 | "colab": {
1091 | "collapsed_sections": [],
1092 | "name": "Tutorial 4: Flax Zero2Hero.ipynb",
1093 | "provenance": []
1094 | },
1095 | "kernelspec": {
1096 | "display_name": "Python 3 (ipykernel)",
1097 | "language": "python",
1098 | "name": "python3"
1099 | },
1100 | "language_info": {
1101 | "codemirror_mode": {
1102 | "name": "ipython",
1103 | "version": 3
1104 | },
1105 | "file_extension": ".py",
1106 | "mimetype": "text/x-python",
1107 | "name": "python",
1108 | "nbconvert_exporter": "python",
1109 | "pygments_lexer": "ipython3",
1110 | "version": "3.9.0"
1111 | }
1112 | },
1113 | "nbformat": 4,
1114 | "nbformat_minor": 1
1115 | }
1116 |
--------------------------------------------------------------------------------