├── imgs
│   ├── ad_example.png
│   ├── backward_ad.png
│   ├── forward_ad.png
│   ├── model_optim.png
│   ├── learning_process.png
│   ├── autoencoder_bowtie.png
│   └── automatic_diff_methods.png
├── requirements.txt
├── environment.yml
├── README.md
├── .gitignore
├── addendum_tf_static_graph.ipynb
├── tip_model_persistence.ipynb
└── addendum_autograd.ipynb
/imgs/ad_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leriomaggio/euroscipy22-pytorch/main/imgs/ad_example.png -------------------------------------------------------------------------------- /imgs/backward_ad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leriomaggio/euroscipy22-pytorch/main/imgs/backward_ad.png -------------------------------------------------------------------------------- /imgs/forward_ad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leriomaggio/euroscipy22-pytorch/main/imgs/forward_ad.png -------------------------------------------------------------------------------- /imgs/model_optim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leriomaggio/euroscipy22-pytorch/main/imgs/model_optim.png -------------------------------------------------------------------------------- /imgs/learning_process.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leriomaggio/euroscipy22-pytorch/main/imgs/learning_process.png -------------------------------------------------------------------------------- /imgs/autoencoder_bowtie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leriomaggio/euroscipy22-pytorch/main/imgs/autoencoder_bowtie.png
-------------------------------------------------------------------------------- /imgs/automatic_diff_methods.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leriomaggio/euroscipy22-pytorch/main/imgs/automatic_diff_methods.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ipython==8.4.0 2 | jupyter==1.0.0 3 | matplotlib==3.5.2 4 | notebook==6.4.12 5 | numpy==1.21.5 6 | pandas==1.4.3 7 | # python==3.9.12 (the interpreter itself cannot be installed with pip; use Python 3.9) 8 | scikit-learn==1.1.1 9 | scipy==1.7.3 10 | setuptools==63.4.1 11 | tensorboard==2.6.0 12 | tensorboard-data-server==0.6.0 13 | tensorboard-plugin-wit==1.6.0 14 | notexbook-theme==2.0.1 15 | pillow==9.2.0 16 | torch==1.12.1 17 | torchvision==0.13.1 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pytorch-euroscipy 2 | channels: 3 | - defaults 4 | dependencies: 5 | - ipykernel=6.9.1 6 | - ipython=8.4.0 7 | - jupyter=1.0.0 8 | - matplotlib=3.5.2 9 | - nomkl=3.0 10 | - notebook=6.4.12 11 | - numpy=1.21.5 12 | - pandas=1.4.3 13 | - python=3.9.12 14 | - scikit-learn=1.1.1 15 | - scipy=1.7.3 16 | - setuptools=63.4.1 17 | - tensorboard=2.6.0 18 | - tensorboard-data-server=0.6.0 19 | - tensorboard-plugin-wit=1.6.0 20 | - pip: 21 | - charset-normalizer==2.1.1 22 | - notexbook-theme==2.0.1 23 | - pillow==9.2.0 24 | - torch==1.12.1 25 | - torchvision==0.13.1 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # introduction-to-pytorch 2 | Lecture materials for the Introduction to PyTorch tutorial presented at EuroSciPy 2022 in Basel, Switzerland 3 | 4 | ### Set up the environment 5 | 6 | Using `conda`:
7 | 8 | ``` 9 | conda env create -f environment.yml 10 | conda activate pytorch-euroscipy 11 | ``` 12 | 13 | Using `pip` (inside a virtual environment of your choice): 14 | 15 | ``` 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ### Run the notebooks 20 | 21 | 1. Jupyter notebook: 22 | 23 | ``` 24 | jupyter notebook 25 | ``` 26 | 27 | 2. MyBinder (no install) 28 | 29 | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/leriomaggio/euroscipy22-pytorch/HEAD) 30 | 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /addendum_tf_static_graph.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# TensorFlow (`1.x`): Static Graph\n", 8 | "\n", 9 | "**Adapted from**: [TensorFlow Static Graph](https://pytorch.org/tutorials/beginner/examples_autograd/tf_two_layer_net.html#sphx-glr-beginner-examples-autograd-tf-two-layer-net-py) \n", 10 | "\n", 11 | "A fully-connected ReLU network with one hidden layer and no biases, trained to predict y from x by minimizing squared Euclidean distance.\n", 12 | "\n", 13 | "This implementation uses basic TensorFlow operations to set up a computational graph, then executes the graph many times to actually train the network.\n", 14 | "\n", 15 | "One of the main differences between TensorFlow and PyTorch is that TensorFlow uses static computational graphs while PyTorch uses dynamic computational graphs.\n", 16 | "\n", 17 | "In TensorFlow we first set up the computational graph, then execute the same graph many times.\n" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 4, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import tensorflow\n", 27 | "import numpy as np\n", 28 | "\n", 29 | "# First we set up the computational 
graph, and disable Eager execution\n", 30 | "tensorflow.compat.v1.disable_eager_execution()\n", 31 | "\n", 32 | "# Alias the TensorFlow 1.x compatibility API as `tf`\n", 33 | "from tensorflow.compat import v1 as tf\n", 34 | "\n", 35 | "# N is batch size; D_in is input dimension;\n", 36 | "# H is hidden dimension; D_out is output dimension.\n", 37 | "N, D_in, H, D_out = 64, 1000, 100, 10\n", 38 | "\n", 39 | "# Create placeholders for the input and target data; these will be filled\n", 40 | "# with real data when we execute the graph.\n", 41 | "x = tf.placeholder(tf.float32, shape=(None, D_in))\n", 42 | "y = tf.placeholder(tf.float32, shape=(None, D_out))\n", 43 | "\n", 44 | "# Create Variables for the weights and initialize them with random data.\n", 45 | "# A TensorFlow Variable persists its value across executions of the graph.\n", 46 | "w1 = tf.Variable(tf.random_normal((D_in, H)))\n", 47 | "w2 = tf.Variable(tf.random_normal((H, D_out)))" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 5, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "data": { 57 | "text/plain": [ 58 | "" 59 | ] 60 | }, 61 | "execution_count": 5, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | } 65 | ], 66 | "source": [ 67 | "x" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 7, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "" 79 | ] 80 | }, 81 | "execution_count": 7, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "w1" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 8, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "# Forward pass: Compute the predicted y using operations on TensorFlow Tensors.\n", 97 | "# Note that this code does not actually perform any numeric operations; it\n", 98 | "# merely sets up the computational graph that we will later execute.\n", 99 | "h = tf.matmul(x, w1)\n", 100 |
"h_relu = tf.maximum(h, tf.zeros(1))\n", 101 | "y_pred = tf.matmul(h_relu, w2)\n", 102 | "\n", 103 | "# Compute loss using operations on TensorFlow Tensors\n", 104 | "loss = tf.reduce_sum((y - y_pred) ** 2.0)\n", 105 | "\n", 106 | "# Compute gradient of the loss with respect to w1 and w2.\n", 107 | "grad_w1, grad_w2 = tf.gradients(loss, [w1, w2])" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 9, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# Update the weights using gradient descent. To actually update the weights\n", 117 | "# we need to evaluate new_w1 and new_w2 when executing the graph. Note that\n", 118 | "# in TensorFlow the act of updating the value of the weights is part of\n", 119 | "# the computational graph; in PyTorch this happens outside the computational\n", 120 | "# graph.\n", 121 | "learning_rate = 1e-6\n", 122 | "new_w1 = w1.assign(w1 - learning_rate * grad_w1)\n", 123 | "new_w2 = w2.assign(w2 - learning_rate * grad_w2)" 124 | ] 124 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 10, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "name": "stdout", 133 | "output_type": "stream", 134 | "text": [ 135 | "0 23785982.0\n", 136 | "50 14178.837\n", 137 | "100 555.4907\n", 138 | "150 41.04864\n", 139 | "200 4.0571003\n", 140 | "250 0.454022\n", 141 | "300 0.053948544\n", 142 | "350 0.0068461723\n", 143 | "400 0.0011231959\n", 144 | "450 0.00029193383\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "# Now we have built our computational graph, so we enter a TensorFlow session to\n", 150 | "# actually execute the graph.\n", 151 | "with tf.Session() as sess:\n", 152 | " # Run the graph once to initialize the Variables w1 and w2.\n", 153 | " sess.run(tf.global_variables_initializer())\n", 154 | "\n", 155 | " # Create numpy arrays holding the actual data for the inputs x and targets\n", 156 | " # y\n", 157 | " x_value = np.random.randn(N, D_in)\n", 158 | " y_value =
np.random.randn(N, D_out)\n", 159 | " for t in range(500):\n", 160 | " # Execute the graph many times. Each time it executes we want to bind\n", 161 | " # x_value to x and y_value to y, specified with the feed_dict argument.\n", 162 | " # Each time we execute the graph we want to compute the values for loss,\n", 163 | " # new_w1, and new_w2; the values of these Tensors are returned as numpy\n", 164 | " # arrays.\n", 165 | " loss_value, _, _ = sess.run([loss, new_w1, new_w2],\n", 166 | " feed_dict={x: x_value, y: y_value})\n", 167 | " if t % 50 == 0:\n", 168 | " print(t, loss_value)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [] 177 | } 178 | ], 179 | "metadata": { 180 | "kernelspec": { 181 | "display_name": "Python 3 (ipykernel)", 182 | "language": "python", 183 | "name": "python3" 184 | }, 185 | "language_info": { 186 | "codemirror_mode": { 187 | "name": "ipython", 188 | "version": 3 189 | }, 190 | "file_extension": ".py", 191 | "mimetype": "text/x-python", 192 | "name": "python", 193 | "nbconvert_exporter": "python", 194 | "pygments_lexer": "ipython3", 195 | "version": "3.10.4" 196 | } 197 | }, 198 | "nbformat": 4, 199 | "nbformat_minor": 4 200 | } 201 | -------------------------------------------------------------------------------- /tip_model_persistence.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%% md\n" 8 | } 9 | }, 10 | "source": [ 11 | "# PyTorch Model Persistence" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": { 17 | "pycharm": { 18 | "name": "#%% md\n" 19 | } 20 | }, 21 | "source": [ 22 | "A paramount requirement in DL model training is the ability to **save** the internal state of a model, so that it can be restored later!"
23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": { 28 | "pycharm": { 29 | "name": "#%% md\n" 30 | } 31 | }, 32 | "source": [ 33 | "We might want to **load and save** a model for:\n", 34 | "\n", 35 | "1. **inference**;\n", 36 | "2. **restarting** training from where we left off (i.e. _checkpointing_);\n", 37 | "3. **saving** the best hyper-parameter configuration found during a (randomised) _grid search_ optimisation;\n", 38 | "4. $\\ldots$" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "pycharm": { 45 | "name": "#%% md\n" 46 | } 47 | }, 48 | "source": [ 49 | "There are **two** approaches for saving and loading models for inference in PyTorch. \n", 50 | "\n", 51 | "The **first** is saving and loading the `state_dict`, and the second is saving and loading the **entire model**." 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": { 57 | "pycharm": { 58 | "name": "#%% md\n" 59 | } 60 | }, 61 | "source": [ 62 | "##### Let's define our (usual) model and optimiser first" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 1, 68 | "metadata": { 69 | "pycharm": { 70 | "name": "#%%\n" 71 | } 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "import torch" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 2, 81 | "metadata": { 82 | "pycharm": { 83 | "name": "#%%\n" 84 | } 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "class TwoLayerNet(torch.nn.Module):\n", 89 | " def __init__(self, D_in, H, D_out):\n", 90 | " \"\"\"\n", 91 | " In the constructor we instantiate two nn.Linear modules and assign them as\n", 92 | " member variables.\n", 93 | " \"\"\"\n", 94 | " super(TwoLayerNet, self).__init__()\n", 95 | " self.linear1 = torch.nn.Linear(D_in, H)\n", 96 | " self.linear2 = torch.nn.Linear(H, D_out)\n", 97 | "\n", 98 | " def forward(self, x):\n", 99 | " \"\"\"\n", 100 | " In the forward function we accept a Tensor of input data and we must return\n", 101 | " a Tensor of output data.
We can use Modules defined in the constructor as\n", 102 | " well as arbitrary operators on Tensors.\n", 103 | " \"\"\"\n", 104 | " h_relu = torch.relu(self.linear1(x))\n", 105 | " y_pred = self.linear2(h_relu)\n", 106 | " return y_pred" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 3, 112 | "metadata": { 113 | "pycharm": { 114 | "name": "#%%\n" 115 | } 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "# N is batch size; D_in is input dimension;\n", 120 | "# H is hidden dimension; D_out is output dimension.\n", 121 | "N, D_in, H, D_out = 64, 1000, 100, 10\n", 122 | "\n", 123 | "# Create random Tensors to hold inputs and outputs\n", 124 | "x = torch.randn(N, D_in)\n", 125 | "y = torch.randn(N, D_out)\n", 126 | "\n", 127 | "# Construct our model by instantiating the class defined above\n", 128 | "model = TwoLayerNet(D_in, H, D_out)\n", 129 | "\n", 130 | "# Construct our loss function and an Optimizer. The call to model.parameters()\n", 131 | "# in the SGD constructor will contain the learnable parameters of the two\n", 132 | "# nn.Linear modules which are members of the model.\n", 133 | "criterion = torch.nn.MSELoss(reduction='sum')\n", 134 | "optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": { 140 | "jp-MarkdownHeadingCollapsed": true, 141 | "pycharm": { 142 | "name": "#%% md\n" 143 | }, 144 | "tags": [] 145 | }, 146 | "source": [ 147 | "#### Saving the entire model using `pickle`\n", 148 | "\n", 149 | "We could use the Python `pickle` module to save and load an entire model.\n", 150 | "\n", 151 | "Using this approach yields the most intuitive syntax and involves the least amount of code." 
152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 4, 157 | "metadata": { 158 | "pycharm": { 159 | "name": "#%%\n" 160 | } 161 | }, 162 | "outputs": [ 163 | { 164 | "name": "stderr", 165 | "output_type": "stream", 166 | "text": [ 167 | "/Users/gu19087/opt/anaconda3/envs/dl-torch/lib/python3.7/site-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead\n", 168 | " warnings.warn(\"pickle support for Storage will be removed in 1.5. Use `torch.save` instead\", FutureWarning)\n" 169 | ] 170 | } 171 | ], 172 | "source": [ 173 | "import pickle\n", 174 | "\n", 175 | "with open('model_serialisation.pkl', 'wb') as pkf:\n", 176 | " pickle.dump(model, pkf)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 5, 182 | "metadata": { 183 | "pycharm": { 184 | "name": "#%%\n" 185 | }, 186 | "scrolled": true 187 | }, 188 | "outputs": [ 189 | { 190 | "name": "stdout", 191 | "output_type": "stream", 192 | "text": [ 193 | "linear1.weight torch.Size([100, 1000]) 100000\n", 194 | "linear1.bias torch.Size([100]) 100\n", 195 | "linear2.weight torch.Size([10, 100]) 1000\n", 196 | "linear2.bias torch.Size([10]) 10\n" 197 | ] 198 | } 199 | ], 200 | "source": [ 201 | "with open('model_serialisation.pkl', 'rb') as pkf:\n", 202 | " model_pkl = pickle.load(pkf)\n", 203 | " for name_str, param in model_pkl.named_parameters():\n", 204 | " print(\"{:21} {:19} {}\".format(name_str, str(param.shape), param.numel()))" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": { 210 | "pycharm": { 211 | "name": "#%% md\n" 212 | } 213 | }, 214 | "source": [ 215 | "**However**, this method is not flexible: the serialized data is bound to the specific classes and the exact directory structure used when the model is saved.\n", 216 | "\n", 217 | "This is because pickle does not save the model class itself.
\n", 218 | "Rather, it stores a reference to the class (its module path and name), which must still be importable from the same location at load time.\n", 219 | "\n", 220 | "**For this reason**, your code can break in various ways when used in other projects or after refactors." 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": { 226 | "pycharm": { 227 | "name": "#%% md\n" 228 | } 229 | }, 230 | "source": [ 231 | "## Introducing `model|optim.state_dict`" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": { 237 | "pycharm": { 238 | "name": "#%% md\n" 239 | } 240 | }, 241 | "source": [ 242 | "In PyTorch, the learnable parameters (i.e. weights and biases) of a `torch.nn.Module` model are contained in the model’s parameters (accessed with `model.parameters()`).\n", 243 | "\n", 244 | "A `state_dict` is simply a Python dictionary object that maps each parameter (and persistent buffer) name, such as `linear1.weight`, to its tensor." 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 6, 250 | "metadata": { 251 | "pycharm": { 252 | "name": "#%%\n" 253 | } 254 | }, 255 | "outputs": [ 256 | { 257 | "name": "stdout", 258 | "output_type": "stream", 259 | "text": [ 260 | "linear1.weight torch.Size([100, 1000]) 100000\n", 261 | "linear1.bias torch.Size([100]) 100\n", 262 | "linear2.weight torch.Size([10, 100]) 1000\n", 263 | "linear2.bias torch.Size([10]) 10\n" 264 | ] 265 | } 266 | ], 267 | "source": [ 268 | "# model (named) parameters\n", 269 | "for name_str, param in model.named_parameters():\n", 270 | " print(\"{:21} {:19} {}\".format(name_str, str(param.shape), param.numel()))" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 7, 276 | "metadata": { 277 | "pycharm": { 278 | "name": "#%%\n" 279 | }, 280 | "scrolled": true 281 | }, 282 | "outputs": [ 283 | { 284 | "data": { 285 | "text/plain": [ 286 | "[{'params': [Parameter containing:\n", 287 | " tensor([[ 0.0129, -0.0277, 0.0117, ..., -0.0011, 0.0143, -0.0153],\n", 288 | " [ 0.0233, 0.0099, 0.0282, ..., -0.0152,
-0.0058, 0.0137],\n", 289 | " [-0.0149, 0.0187, 0.0192, ..., 0.0030, -0.0027, -0.0279],\n", 290 | " ...,\n", 291 | " [-0.0135, 0.0245, 0.0204, ..., -0.0174, 0.0021, 0.0034],\n", 292 | " [ 0.0231, 0.0282, 0.0300, ..., 0.0051, -0.0141, -0.0308],\n", 293 | " [-0.0289, 0.0309, -0.0010, ..., 0.0190, -0.0300, -0.0228]],\n", 294 | " requires_grad=True),\n", 295 | " Parameter containing:\n", 296 | " tensor([ 0.0276, 0.0058, 0.0107, -0.0192, -0.0024, 0.0151, -0.0274, 0.0171,\n", 297 | " 0.0170, 0.0036, -0.0219, 0.0019, 0.0056, -0.0313, 0.0247, -0.0212,\n", 298 | " -0.0061, 0.0081, -0.0244, -0.0102, 0.0062, 0.0146, -0.0283, -0.0042,\n", 299 | " 0.0272, 0.0168, -0.0242, 0.0041, 0.0200, 0.0264, -0.0223, 0.0292,\n", 300 | " 0.0232, 0.0261, 0.0186, -0.0127, 0.0209, -0.0078, -0.0054, -0.0058,\n", 301 | " 0.0274, 0.0140, -0.0233, -0.0233, -0.0227, 0.0311, -0.0284, -0.0142,\n", 302 | " 0.0028, -0.0015, -0.0302, 0.0288, 0.0050, 0.0204, -0.0251, -0.0151,\n", 303 | " 0.0116, 0.0073, 0.0199, -0.0227, -0.0121, 0.0102, -0.0036, -0.0301,\n", 304 | " 0.0017, -0.0155, -0.0290, 0.0002, 0.0296, -0.0193, 0.0074, 0.0119,\n", 305 | " -0.0166, 0.0154, -0.0187, 0.0072, 0.0048, -0.0152, -0.0279, -0.0140,\n", 306 | " 0.0066, -0.0037, -0.0155, 0.0060, 0.0244, -0.0207, -0.0030, 0.0056,\n", 307 | " 0.0060, -0.0171, -0.0028, 0.0223, -0.0062, -0.0147, 0.0306, 0.0030,\n", 308 | " 0.0156, -0.0204, -0.0092, -0.0282], requires_grad=True),\n", 309 | " Parameter containing:\n", 310 | " tensor([[ 6.0397e-03, -8.7106e-02, -5.7451e-02, 1.8538e-03, -6.7711e-02,\n", 311 | " 4.9320e-02, 1.4171e-02, 5.4378e-02, 1.4146e-02, -8.6058e-02,\n", 312 | " 5.0229e-02, 8.7175e-02, -1.7491e-02, -6.8559e-02, 2.4183e-02,\n", 313 | " -4.8714e-02, -3.4075e-02, -7.0591e-02, -2.2242e-02, -8.8331e-02,\n", 314 | " 5.1855e-02, -3.5466e-03, -9.7914e-02, -7.5831e-02, 2.8753e-02,\n", 315 | " 6.3270e-02, -6.4052e-02, 5.3039e-02, 6.5276e-02, 4.8598e-02,\n", 316 | " 5.6762e-02, 5.6528e-02, -5.6083e-02, -4.3567e-02, -4.1739e-02,\n", 317 | " 
-1.9754e-02, -9.7998e-02, 4.5380e-02, -8.1140e-02, -5.4604e-02,\n", 318 | " -4.0464e-02, -5.8959e-02, 1.6699e-02, -3.1998e-02, 7.6722e-03,\n", 319 | " 4.3287e-02, 6.3781e-02, 1.3740e-03, 4.1390e-02, 6.3780e-02,\n", 320 | " 8.3487e-02, -1.0379e-03, 3.4332e-02, 1.0504e-02, -7.5072e-02,\n", 321 | " 2.0289e-02, 9.3477e-02, 7.4263e-02, -3.5551e-02, 1.4834e-02,\n", 322 | " -4.3010e-02, 7.1876e-02, 3.7470e-02, 4.5498e-02, 1.2225e-02,\n", 323 | " -6.7099e-02, -9.2879e-02, -9.2626e-02, 8.3658e-02, 5.8638e-02,\n", 324 | " 2.6543e-02, 1.2648e-02, -5.2592e-02, -8.4065e-03, -6.0760e-02,\n", 325 | " -9.0605e-02, -7.7751e-03, -8.6171e-02, -2.3379e-02, 7.8544e-02,\n", 326 | " 8.0324e-02, -7.7527e-02, 2.4833e-02, -9.1391e-02, 3.4234e-02,\n", 327 | " -8.8944e-02, -1.0545e-02, 1.9541e-02, -2.2006e-02, 7.3074e-02,\n", 328 | " 8.8172e-02, -4.1541e-02, -8.9734e-02, -4.5536e-02, -1.8579e-03,\n", 329 | " 8.5630e-02, -4.3191e-02, 1.8475e-02, -8.6686e-02, 1.8877e-02],\n", 330 | " [-3.3583e-03, -3.8957e-02, -5.4169e-02, 8.7197e-02, 1.8776e-02,\n", 331 | " 5.4089e-02, 1.1461e-02, 7.9725e-02, 8.9141e-02, -8.7513e-02,\n", 332 | " -8.6997e-03, -7.8238e-02, -7.5282e-02, 2.9287e-02, 7.6636e-02,\n", 333 | " 4.1366e-02, -8.9879e-02, -1.6514e-02, -5.7966e-02, -5.9631e-02,\n", 334 | " 7.8763e-02, -8.7289e-02, 5.5093e-02, -2.6882e-02, -5.0371e-02,\n", 335 | " -9.5760e-02, -3.5820e-03, -5.2561e-02, 2.4092e-02, 8.9301e-02,\n", 336 | " -3.6397e-02, -7.7650e-02, 8.3698e-02, -9.9086e-02, 4.5744e-02,\n", 337 | " 1.8525e-02, 1.3282e-02, 5.5116e-02, 7.2584e-02, -9.9738e-03,\n", 338 | " 8.7878e-02, 4.1210e-02, 5.3266e-02, 1.7971e-02, -5.1427e-03,\n", 339 | " -6.5150e-02, 7.1435e-02, 7.8157e-02, 5.6057e-02, 7.6694e-03,\n", 340 | " -8.7516e-03, 4.3877e-02, 6.4443e-02, 9.1313e-02, 2.7966e-02,\n", 341 | " -4.9264e-02, -9.9418e-02, 1.7548e-02, 5.5363e-02, 7.8699e-02,\n", 342 | " 7.8334e-02, 3.1311e-02, -4.7621e-02, 2.2168e-02, 7.1217e-02,\n", 343 | " 8.9676e-02, -4.5493e-02, -8.3063e-02, -7.9194e-03, 4.2968e-02,\n", 
344 | " -1.4330e-02, 5.8753e-02, 2.5522e-02, -5.0606e-02, -4.8589e-03,\n", 345 | " -6.3464e-02, 2.1676e-02, 8.7757e-02, 9.0938e-02, -9.2778e-02,\n", 346 | " 8.0187e-02, -5.4807e-02, 3.9872e-02, 9.3799e-02, -3.7557e-02,\n", 347 | " 2.0098e-02, -1.4448e-02, -3.3486e-02, 6.9724e-02, 4.5098e-02,\n", 348 | " 4.7195e-02, 8.3146e-02, -1.5775e-02, 9.5246e-02, 7.9584e-02,\n", 349 | " -2.0932e-02, -9.7079e-02, -3.7770e-02, -1.4361e-02, -6.5851e-03],\n", 350 | " [-7.3068e-02, 9.6765e-02, 1.3639e-02, -5.5968e-02, -1.2070e-02,\n", 351 | " -7.2104e-02, 9.0902e-02, 7.9207e-02, 9.5885e-02, -9.7221e-02,\n", 352 | " 4.1327e-02, -8.8859e-02, -1.6648e-03, 5.6193e-03, -3.1264e-02,\n", 353 | " 4.3246e-02, 9.2361e-02, 1.0870e-02, -9.3888e-02, 2.1386e-02,\n", 354 | " 5.4324e-02, -7.1312e-02, 6.5123e-02, -9.0817e-02, 9.3773e-02,\n", 355 | " 9.9825e-02, 1.8338e-02, -8.5020e-02, 2.7074e-02, 7.3050e-02,\n", 356 | " -1.4532e-02, 8.7660e-03, -9.9246e-02, -6.9973e-02, 2.1765e-02,\n", 357 | " -2.4508e-02, 4.5531e-02, -7.9878e-03, -9.4078e-02, -6.5869e-02,\n", 358 | " 9.2926e-02, -7.0529e-02, -6.6713e-02, 2.0497e-03, 6.7184e-03,\n", 359 | " 4.3883e-02, 1.0719e-02, -2.5432e-02, 3.4595e-03, 6.4158e-02,\n", 360 | " -1.6255e-02, -7.8653e-02, 2.6729e-02, -4.4632e-02, -2.5579e-02,\n", 361 | " -3.4998e-02, -6.6887e-02, 1.8426e-02, -1.3198e-02, 2.0053e-02,\n", 362 | " -7.5295e-02, 9.5105e-02, -6.2293e-02, -3.6116e-02, -3.6820e-02,\n", 363 | " 8.5946e-02, 1.0105e-02, 1.3194e-02, 1.9207e-02, -8.4800e-02,\n", 364 | " 9.5676e-02, 9.1723e-02, 3.5921e-02, -2.6434e-02, 1.3105e-02,\n", 365 | " 9.9024e-02, -1.9139e-02, -2.7808e-02, 9.3036e-02, 8.4069e-02,\n", 366 | " -6.6241e-02, -3.2865e-02, 8.0090e-02, -2.1852e-02, 1.9757e-02,\n", 367 | " 5.5647e-02, 8.2475e-02, -1.9520e-02, -8.9646e-02, -1.6741e-02,\n", 368 | " 9.9963e-02, -6.0581e-02, 5.2808e-02, -6.7505e-02, -4.6904e-02,\n", 369 | " -8.1296e-02, 5.3573e-02, 9.1083e-02, 1.9106e-02, -2.9080e-02],\n", 370 | " [ 7.1465e-02, -4.3288e-05, -6.9280e-02, -6.7076e-02, 
6.3996e-02,\n", 371 | " 3.0872e-02, -7.9766e-02, 9.2255e-02, 7.9854e-02, 5.6929e-02,\n", 372 | " 2.1776e-02, 7.3513e-02, 2.7571e-02, 2.2885e-02, -5.8250e-02,\n", 373 | " 8.9086e-02, 3.8966e-02, -3.2023e-02, -5.8905e-02, 8.6405e-02,\n", 374 | " -7.4906e-02, -7.2720e-02, 5.7642e-03, 3.1756e-02, 1.5929e-02,\n", 375 | " 6.1962e-02, -6.3816e-02, 1.0647e-02, 9.8364e-02, -4.1501e-02,\n", 376 | " 7.0460e-03, 4.4556e-02, -9.5004e-02, -8.3058e-02, -5.3030e-02,\n", 377 | " -8.4196e-02, 3.9178e-02, 8.1794e-04, -3.9290e-02, 7.3309e-02,\n", 378 | " 5.9412e-02, 7.7620e-02, 8.3524e-02, -5.4468e-02, -4.4280e-02,\n", 379 | " -6.7692e-02, 1.7777e-02, 9.7212e-02, -9.0789e-02, 7.2118e-02,\n", 380 | " -3.2341e-02, -5.5326e-02, 4.0588e-02, 9.9750e-02, 3.8025e-02,\n", 381 | " -7.6471e-03, -6.9118e-02, -2.7383e-02, 6.1322e-03, 1.9185e-02,\n", 382 | " -2.4591e-03, 5.6286e-03, -7.7470e-02, -8.9915e-02, 7.9303e-02,\n", 383 | " 4.6957e-02, -4.9349e-03, -2.8890e-02, 7.2679e-02, -4.5378e-02,\n", 384 | " -5.2123e-02, -8.9835e-02, 1.1287e-02, -8.9529e-03, 9.0742e-04,\n", 385 | " 5.5114e-02, 9.3845e-02, 4.8538e-02, 9.1878e-02, -4.0535e-02,\n", 386 | " 1.0247e-02, -2.7165e-02, 4.6268e-02, 7.9237e-02, 4.3156e-02,\n", 387 | " 5.9693e-02, -3.3915e-02, -4.4982e-02, -8.7630e-02, -5.0843e-02,\n", 388 | " -8.2647e-02, -8.2477e-02, 8.0764e-02, 5.0915e-03, 4.6106e-02,\n", 389 | " -1.8766e-02, -3.3182e-02, 8.5344e-02, -1.5128e-03, -8.4288e-02],\n", 390 | " [ 3.4387e-03, 9.9405e-02, 8.2109e-04, -3.2104e-02, -1.1983e-02,\n", 391 | " 8.2339e-02, 7.6370e-02, -9.8138e-02, -2.4796e-02, -7.0797e-02,\n", 392 | " -5.9370e-02, 9.2313e-02, -1.6375e-02, 3.3266e-02, -9.6662e-02,\n", 393 | " -3.2316e-02, -9.0488e-02, -6.1023e-03, -8.1759e-02, -7.0254e-02,\n", 394 | " 9.6130e-03, 9.3115e-02, 3.1517e-02, 8.8886e-02, -9.9441e-02,\n", 395 | " -3.8234e-02, -9.9071e-02, -5.4025e-02, -5.7463e-02, 8.2620e-02,\n", 396 | " 3.6974e-02, 7.3258e-02, 2.6846e-02, 6.2566e-02, 2.5136e-02,\n", 397 | " 8.1305e-02, -9.2818e-02, -5.0314e-02, 
-6.4061e-02, -5.5626e-02,\n", 398 | " -5.0043e-02, -7.7857e-02, -3.9889e-02, 3.4864e-03, -7.4790e-02,\n", 399 | " 9.3081e-02, 9.8013e-02, 4.1346e-02, 5.9070e-02, 8.6867e-02,\n", 400 | " 5.3637e-02, 7.2297e-02, 9.8437e-02, 5.3412e-02, 5.6891e-02,\n", 401 | " 2.6427e-02, 4.3960e-02, -3.8807e-02, 1.9828e-02, -2.2710e-02,\n", 402 | " -2.6245e-02, 2.1453e-02, 7.5224e-02, 4.1028e-02, -5.1455e-02,\n", 403 | " 2.2243e-02, 7.9249e-02, -4.6603e-02, -6.7208e-02, -6.0396e-02,\n", 404 | " 5.0192e-02, -3.1604e-02, 3.6818e-02, 7.0960e-02, 7.7530e-02,\n", 405 | " 6.5197e-02, 6.9150e-02, -4.6160e-02, 5.2696e-02, -9.1601e-03,\n", 406 | " -9.4035e-02, -8.0848e-02, 1.2252e-02, 7.2659e-02, -2.4505e-02,\n", 407 | " 7.4934e-02, -7.6167e-02, 2.1669e-02, 9.3008e-02, 4.9432e-02,\n", 408 | " 9.4449e-02, 6.6678e-02, 4.2753e-02, -1.1910e-02, 9.6679e-02,\n", 409 | " -6.8529e-02, -1.0157e-03, -9.2545e-02, -9.4833e-02, -6.0401e-02],\n", 410 | " [-2.0554e-02, -4.1283e-02, 7.5202e-02, -5.0581e-02, 7.6867e-02,\n", 411 | " 1.0000e-01, -9.3872e-02, -6.9777e-02, -5.9833e-02, -6.9427e-02,\n", 412 | " 7.6663e-02, -8.5208e-02, 2.4230e-02, -1.5796e-02, -1.5415e-02,\n", 413 | " 6.4671e-02, 9.9584e-02, 3.2597e-02, -2.5461e-02, -5.5413e-03,\n", 414 | " -1.3413e-02, -3.1326e-02, -2.9192e-02, -4.6567e-02, 9.8243e-02,\n", 415 | " -1.1212e-02, -2.3979e-02, 7.7617e-02, -3.7327e-02, -4.7715e-02,\n", 416 | " -4.4871e-02, -3.9263e-02, 8.7015e-02, 5.9443e-03, -1.4883e-02,\n", 417 | " 1.6572e-02, -2.3925e-02, -3.1680e-02, -5.4073e-02, -9.9868e-02,\n", 418 | " 3.5647e-02, 9.3804e-02, -5.1556e-02, -6.6315e-02, 1.7316e-02,\n", 419 | " -5.9026e-02, 1.2335e-02, -8.9687e-02, -3.8119e-03, -8.3790e-02,\n", 420 | " -4.2525e-02, -1.3249e-02, 6.1686e-02, -6.5085e-02, -2.7949e-02,\n", 421 | " -8.7100e-02, -2.4767e-02, -1.7463e-02, 9.0935e-02, -9.5649e-02,\n", 422 | " -8.2664e-02, 5.0175e-03, 2.1225e-02, 7.2999e-02, 8.7847e-02,\n", 423 | " 6.7847e-02, -5.7107e-04, 5.6071e-02, -8.4907e-02, 1.7269e-02,\n", 424 | " -4.0415e-02, 
-7.4515e-02, 1.3839e-02, 9.1045e-03, -7.7219e-02,\n", 425 | " 2.3996e-02, 8.7458e-02, 5.0316e-02, -2.2645e-02, -8.2988e-02,\n", 426 | " -9.7689e-03, -9.6182e-02, 2.5362e-02, 9.9611e-02, 4.6788e-02,\n", 427 | " 1.1238e-02, -8.6436e-02, 7.3232e-02, -6.6566e-02, 8.3938e-02,\n", 428 | " -2.4745e-02, -8.5759e-02, 5.8885e-03, 2.4950e-02, -6.8162e-02,\n", 429 | " -3.2551e-02, 2.0507e-02, 7.8403e-02, -2.4033e-02, 6.6680e-02],\n", 430 | " [ 5.7149e-02, -3.7554e-02, -8.5112e-02, 9.3722e-02, 6.8870e-02,\n", 431 | " 8.3002e-02, -4.0203e-02, 2.8195e-02, 1.5695e-02, 1.4885e-02,\n", 432 | " -5.2610e-02, -5.1160e-02, -1.4406e-03, 8.5899e-02, -9.0598e-02,\n", 433 | " -6.8469e-02, -6.5987e-02, 6.2049e-02, -4.0249e-02, -7.2779e-02,\n", 434 | " 4.2711e-02, -3.8364e-03, 3.8309e-02, 9.1223e-02, 6.4874e-02,\n", 435 | " -7.0102e-02, 5.6570e-02, -7.5938e-02, -7.7024e-02, 6.4373e-02,\n", 436 | " -3.6254e-02, -7.8834e-02, 5.0244e-02, -6.0321e-02, -7.0094e-02,\n", 437 | " 5.5626e-02, -6.7325e-02, 3.0025e-02, -7.4662e-02, -6.7068e-03,\n", 438 | " 3.2788e-02, -9.8470e-02, -9.1810e-02, 3.3483e-02, 9.4110e-02,\n", 439 | " 2.5885e-02, -3.8974e-02, 8.6909e-02, 8.3467e-02, -7.8696e-02,\n", 440 | " -7.7835e-02, -8.6607e-02, -2.8971e-02, -2.6093e-02, -9.8127e-02,\n", 441 | " 3.8763e-03, 4.6259e-02, -7.1629e-02, 4.8900e-02, 6.9722e-02,\n", 442 | " -3.2478e-02, 9.2257e-02, 6.7863e-02, -3.6889e-02, -8.8230e-02,\n", 443 | " -7.6977e-03, -1.0429e-03, -2.3290e-04, -7.8065e-02, 5.5609e-02,\n", 444 | " -9.5768e-02, 5.1733e-02, -7.7599e-02, -8.3914e-02, -2.3168e-02,\n", 445 | " 1.3477e-02, 8.0447e-02, 5.4271e-03, -8.6737e-02, 3.6465e-02,\n", 446 | " -2.1326e-02, 6.7685e-03, 5.9429e-02, 7.4680e-02, 8.4911e-02,\n", 447 | " 2.6829e-02, 8.1955e-02, 1.3126e-02, 2.6068e-03, -9.2530e-02,\n", 448 | " 3.2421e-02, 4.9750e-02, -6.1601e-02, -6.4948e-02, 3.6481e-02,\n", 449 | " -5.8526e-02, 3.7691e-02, -1.7092e-02, -3.5618e-02, -6.8862e-02],\n", 450 | " [-6.6715e-02, -7.3132e-02, -3.5132e-03, 7.7850e-02, 6.3476e-02,\n", 
451 | " 7.1232e-02, -9.6672e-02, -8.2047e-02, 2.2797e-02, 3.1244e-02,\n", 452 | " -5.2365e-02, 3.2612e-02, 3.4499e-02, -4.0493e-02, 5.1946e-02,\n", 453 | " 9.1978e-02, -8.4371e-02, 8.8180e-02, 9.4906e-02, 8.1407e-02,\n", 454 | " -8.8640e-02, 2.8114e-03, 5.3087e-02, -2.2679e-02, 7.3953e-02,\n", 455 | " -8.8017e-02, -3.9196e-02, -9.2420e-02, -9.9014e-02, -4.2489e-02,\n", 456 | " 6.2517e-02, -3.1612e-02, 7.9146e-02, -7.6977e-03, 4.8802e-02,\n", 457 | " -8.3909e-02, -6.6282e-02, -1.8300e-02, -2.9578e-02, 5.8294e-02,\n", 458 | " 6.5964e-02, 3.2812e-03, 3.7626e-02, 9.6420e-02, -4.2926e-02,\n", 459 | " 5.4093e-02, -7.8043e-02, -9.1384e-02, -2.2752e-02, -2.3332e-02,\n", 460 | " -2.6411e-02, 7.2564e-02, -8.4488e-03, -5.8523e-02, -5.0776e-02,\n", 461 | " 3.2122e-02, -9.7434e-02, 3.3793e-02, 2.0783e-02, -4.5941e-02,\n", 462 | " -2.6204e-02, -9.7895e-02, -3.2598e-02, 5.6724e-02, -5.1443e-02,\n", 463 | " 2.9846e-02, 8.4554e-03, -4.5411e-02, -3.3259e-02, -2.0535e-02,\n", 464 | " -4.3371e-02, 5.0148e-02, 4.6049e-02, -1.5824e-02, -7.2475e-02,\n", 465 | " 4.3322e-02, -3.9835e-02, 8.5319e-02, -2.0438e-02, -2.2051e-02,\n", 466 | " 1.0703e-02, -6.2848e-02, -3.4325e-02, -9.3368e-02, 9.2366e-02,\n", 467 | " -8.3445e-02, 9.6797e-02, -9.0583e-02, 6.4380e-03, 2.7341e-02,\n", 468 | " 1.3602e-02, -7.7465e-03, -3.7057e-03, -6.5682e-02, 6.6202e-02,\n", 469 | " 6.0016e-02, -9.2970e-02, 8.4253e-02, 3.0414e-02, -8.5505e-02],\n", 470 | " [-9.5348e-02, -7.6113e-02, -8.3606e-02, 1.5938e-02, -2.0131e-02,\n", 471 | " -2.0699e-02, 3.4306e-04, -7.2682e-02, -6.7024e-03, -8.0483e-03,\n", 472 | " 9.7141e-03, -6.4989e-02, -6.9469e-02, 3.8024e-02, 5.0706e-02,\n", 473 | " 2.0047e-02, -2.3145e-02, -6.8218e-02, -5.4835e-02, -5.6266e-02,\n", 474 | " -1.5731e-03, 5.5056e-03, -5.9776e-02, -8.9783e-02, 4.0810e-03,\n", 475 | " 1.0393e-02, -4.2363e-02, 4.7316e-02, -7.1609e-02, -3.1711e-02,\n", 476 | " 9.2264e-02, -8.2079e-02, 6.6912e-02, 6.4212e-02, 2.1305e-02,\n", 477 | " -5.0689e-02, -1.9879e-02, -6.8521e-02, 
5.3721e-02, -6.3786e-02,\n", 478 | " 7.4574e-02, 3.9781e-02, -1.5271e-02, 1.2718e-02, -3.4332e-02,\n", 479 | " 2.2917e-02, 9.0192e-02, 5.0204e-03, 9.2888e-02, -2.6175e-02,\n", 480 | " -5.0326e-02, 9.9106e-02, -1.3483e-02, -7.1479e-02, -1.1978e-02,\n", 481 | " 4.7299e-02, 6.4482e-02, -5.9913e-02, -7.9387e-02, 9.2318e-02,\n", 482 | " 2.1284e-02, 9.9712e-02, -7.3051e-02, 9.0756e-02, 7.7301e-02,\n", 483 | " 9.8834e-02, -4.6585e-02, -6.4026e-02, -1.8898e-02, 6.0776e-02,\n", 484 | " -2.7641e-02, 2.5494e-02, 7.2721e-02, -1.5766e-02, 4.7878e-02,\n", 485 | " -4.2163e-02, 8.5962e-02, 7.3850e-02, -2.3346e-02, -2.6012e-02,\n", 486 | " 7.4358e-03, 9.6486e-02, 9.3332e-02, -6.5746e-02, 2.5491e-02,\n", 487 | " -4.3182e-02, -5.5005e-02, 2.0531e-02, -1.5092e-02, 2.5460e-02,\n", 488 | " -3.0333e-02, -8.8788e-02, 8.4583e-02, -1.8511e-03, -2.2102e-02,\n", 489 | " -9.5434e-02, -1.9180e-02, 8.7092e-03, 5.7534e-02, 3.9332e-02],\n", 490 | " [-4.6316e-02, -8.3404e-02, -7.9888e-02, 9.1124e-02, 6.2954e-02,\n", 491 | " 6.5575e-02, -9.3121e-02, 4.2876e-02, 8.9381e-02, -3.9886e-02,\n", 492 | " 4.9845e-02, -4.5368e-03, 7.6700e-02, -3.5977e-02, 5.4258e-02,\n", 493 | " 1.9079e-02, 3.3177e-02, 7.7526e-02, 8.3591e-02, -4.2260e-02,\n", 494 | " 6.7222e-02, -9.2025e-02, 8.3811e-02, -2.6344e-02, 3.6421e-02,\n", 495 | " 2.7206e-02, -1.9360e-02, 3.7738e-02, -5.1932e-02, 2.7489e-02,\n", 496 | " -5.4084e-02, -5.5596e-02, -3.9697e-02, -3.2667e-02, -9.1111e-02,\n", 497 | " -7.7667e-02, 5.7508e-02, -2.9396e-02, 7.9757e-02, -3.3773e-03,\n", 498 | " -8.4354e-02, -9.1824e-02, -1.5001e-02, -3.5265e-03, 3.3553e-02,\n", 499 | " -7.8544e-02, 7.0951e-02, 8.9415e-02, -6.8650e-02, -6.4837e-03,\n", 500 | " 4.1214e-02, 5.4215e-02, 1.6607e-02, -5.5795e-02, -2.3485e-02,\n", 501 | " -5.2561e-02, -3.0251e-02, 8.7573e-02, -8.5913e-02, -5.2281e-02,\n", 502 | " -3.9188e-02, -1.4096e-02, 6.3002e-02, -3.8720e-02, 6.5381e-02,\n", 503 | " -1.7108e-04, -6.0910e-02, 8.9747e-02, -3.4480e-03, 3.5674e-02,\n", 504 | " -9.3960e-02, 
4.7458e-02, -9.9309e-02, -7.5367e-02, 6.8660e-02,\n", 505 | " 1.0661e-02, -1.1495e-02, 5.4564e-02, 2.3786e-03, 3.5486e-02,\n", 506 | " 7.3573e-02, -8.1635e-02, -6.5729e-02, -7.0058e-02, -2.7226e-02,\n", 507 | " -6.3011e-02, 4.7394e-02, -5.4706e-02, 8.1106e-02, 2.3748e-03,\n", 508 | " -3.7955e-02, -5.7144e-02, -6.7127e-02, -9.8912e-02, 6.3665e-02,\n", 509 | " 6.8999e-02, 6.7499e-02, -8.6455e-02, 1.0371e-02, 2.6472e-02]],\n", 510 | " requires_grad=True),\n", 511 | " Parameter containing:\n", 512 | " tensor([-0.0698, -0.0899, 0.0866, -0.0496, 0.0276, -0.0060, -0.0541, 0.0843,\n", 513 | " 0.0287, -0.0502], requires_grad=True)],\n", 514 | " 'lr': 0.0001,\n", 515 | " 'momentum': 0,\n", 516 | " 'dampening': 0,\n", 517 | " 'weight_decay': 0,\n", 518 | " 'nesterov': False}]" 519 | ] 520 | }, 521 | "execution_count": 7, 522 | "metadata": {}, 523 | "output_type": "execute_result" 524 | } 525 | ], 526 | "source": [ 527 | "optimizer.param_groups" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 8, 533 | "metadata": { 534 | "pycharm": { 535 | "name": "#%%\n" 536 | } 537 | }, 538 | "outputs": [ 539 | { 540 | "data": { 541 | "text/plain": [ 542 | "dict" 543 | ] 544 | }, 545 | "execution_count": 8, 546 | "metadata": {}, 547 | "output_type": "execute_result" 548 | } 549 | ], 550 | "source": [ 551 | "p = optimizer.param_groups[0]\n", 552 | "type(p)" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": 9, 558 | "metadata": { 559 | "pycharm": { 560 | "name": "#%%\n" 561 | } 562 | }, 563 | "outputs": [ 564 | { 565 | "data": { 566 | "text/plain": [ 567 | "['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov']" 568 | ] 569 | }, 570 | "execution_count": 9, 571 | "metadata": {}, 572 | "output_type": "execute_result" 573 | } 574 | ], 575 | "source": [ 576 | "list(p.keys())" 577 | ] 578 | }, 579 | { 580 | "cell_type": "markdown", 581 | "metadata": { 582 | "pycharm": { 583 | "name": "#%% md\n" 584 | } 585 | }, 586 | "source": [ 
587 | "When saving a DL model, we always **need** to save the model parameters (e.g. for _inference_), but in other cases (e.g. a _model checkpoint_ used to resume training) we **also need** to save the **optimiser** `parameters` and `hyper-parameters`." 588 | ] 589 | }, 590 | { 591 | "cell_type": "markdown", 592 | "metadata": { 593 | "pycharm": { 594 | "name": "#%% md\n" 595 | } 596 | }, 597 | "source": [ 598 | "##### `state_dict`\n", 599 | "\n", 600 | "A `state_dict` is an integral part of saving and loading models in PyTorch. For a model, it maps each parameter or buffer name (e.g. `linear1.weight`) to its tensor. \n", 601 | "\n", 602 | "Because `state_dict` objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers. \n", 603 | "\n", 604 | "Note that **only** layers with learnable parameters and registered buffers (e.g. batchnorm’s `running_mean`) have entries in the model’s `state_dict`. Optimizer objects (`torch.optim`) also have a `state_dict`, which contains information about the optimizer’s state, as well as the **hyperparameters** used. 
" 605 | ] 606 | }, 607 | { 608 | "cell_type": "code", 609 | "execution_count": 13, 610 | "metadata": { 611 | "pycharm": { 612 | "name": "#%%\n" 613 | } 614 | }, 615 | "outputs": [ 616 | { 617 | "name": "stdout", 618 | "output_type": "stream", 619 | "text": [ 620 | "Model's state_dict:\n", 621 | "linear1.weight \t torch.Size([100, 1000])\n", 622 | "linear1.bias \t torch.Size([100])\n", 623 | "linear2.weight \t torch.Size([10, 100])\n", 624 | "linear2.bias \t torch.Size([10])\n", 625 | "\n", 626 | "Optimizer's state_dict:\n", 627 | "state \t {}\n", 628 | "param_groups \t [{'lr': 0.0001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4875991952, 4875992032, 4875992112, 4875992192]}]\n" 629 | ] 630 | } 631 | ], 632 | "source": [ 633 | "# Print model's state_dict\n", 634 | "print(\"Model's state_dict:\")\n", 635 | "for param_tensor in model.state_dict():\n", 636 | " print(param_tensor, \"\\t\", model.state_dict()[param_tensor].size())\n", 637 | "\n", 638 | "print()\n", 639 | "\n", 640 | "# Print optimizer's state_dict\n", 641 | "print(\"Optimizer's state_dict:\")\n", 642 | "for var_name in optimizer.state_dict():\n", 643 | " print(var_name, \"\\t\", optimizer.state_dict()[var_name])" 644 | ] 645 | }, 646 | { 647 | "cell_type": "markdown", 648 | "metadata": { 649 | "pycharm": { 650 | "name": "#%% md\n" 651 | } 652 | }, 653 | "source": [ 654 | "### Saving and Loading models for Inference in PyTorch" 655 | ] 656 | }, 657 | { 658 | "cell_type": "markdown", 659 | "metadata": { 660 | "pycharm": { 661 | "name": "#%% md\n" 662 | } 663 | }, 664 | "source": [ 665 | "Each instance of a `torch.nn.Module` can be saved using the `torch.save()` function.\n", 666 | "\n", 667 | "Saving the model’s `state_dict` with the `torch.save()` function will give you the most flexibility for restoring the model later. 
\n", 668 | "\n", 669 | "This is the **recommended method** for saving models, because it is only really necessary to save the trained model’s learned parameters. \n", 670 | "\n" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": 15, 676 | "metadata": { 677 | "pycharm": { 678 | "name": "#%%\n" 679 | } 680 | }, 681 | "outputs": [], 682 | "source": [ 683 | "# Save\n", 684 | "torch.save(model.state_dict(), \"model_state_dict.pt\")" 685 | ] 686 | }, 687 | { 688 | "cell_type": "code", 689 | "execution_count": 17, 690 | "metadata": { 691 | "pycharm": { 692 | "name": "#%%\n" 693 | } 694 | }, 695 | "outputs": [ 696 | { 697 | "data": { 698 | "text/plain": [ 699 | "" 700 | ] 701 | }, 702 | "execution_count": 17, 703 | "metadata": {}, 704 | "output_type": "execute_result" 705 | } 706 | ], 707 | "source": [ 708 | "# Load\n", 709 | "model = TwoLayerNet(D_in, H, D_out)\n", 710 | "model.load_state_dict(torch.load(\"model_state_dict.pt\"))" 711 | ] 712 | }, 713 | { 714 | "cell_type": "markdown", 715 | "metadata": { 716 | "pycharm": { 717 | "name": "#%% md\n" 718 | } 719 | }, 720 | "source": [ 721 | "A common PyTorch **convention** is to save models using either a `.pt` or `.pth` file extension.\n", 722 | "\n", 723 | "Notice that the `load_state_dict()` function takes a dictionary object, NOT a `path` to a saved object. \n", 724 | "\n", 725 | "This means that you **must** deserialize the saved `state_dict` before you pass it to the `load_state_dict()` function. \n", 726 | "\n", 727 | "For example, you **CANNOT** just load using `model.load_state_dict(\"path_to_file.pt\")`." 
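As a minimal sketch of the round trip described above, the `state_dict` is serialized with `torch.save()`, deserialized with `torch.load()`, and only then handed to `load_state_dict()`. It assumes a small hypothetical `TinyNet` in place of the notebook's `TwoLayerNet`, and an in-memory `io.BytesIO` buffer in place of a file on disk:

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the notebook's TwoLayerNet (two Linear layers)."""

    def __init__(self, d_in: int = 8, h: int = 4, d_out: int = 2):
        super().__init__()
        self.linear1 = nn.Linear(d_in, h)
        self.linear2 = nn.Linear(h, d_out)

    def forward(self, x):
        return self.linear2(torch.relu(self.linear1(x)))


model = TinyNet()

# Serialize only the learned parameters; the in-memory buffer stands in
# for a file such as "model_state_dict.pt".
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# load_state_dict() expects a dict, so deserialize with torch.load() first.
restored = TinyNet()
state_dict = torch.load(buffer, map_location="cpu")
result = restored.load_state_dict(state_dict)
restored.eval()  # switch to inference mode before making predictions

print(result)  # reports any missing or unexpected keys
```

`map_location="cpu"` is optional here. It is shown because it lets a `state_dict` saved on a GPU be restored on a CPU-only machine.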
728 | ] 729 | }, 730 | { 731 | "cell_type": "markdown", 732 | "metadata": { 733 | "pycharm": { 734 | "name": "#%% md\n" 735 | } 736 | }, 737 | "source": [ 738 | "###### Saving and Loading Entire Model" 739 | ] 740 | }, 741 | { 742 | "cell_type": "markdown", 743 | "metadata": { 744 | "pycharm": { 745 | "name": "#%% md\n" 746 | } 747 | }, 748 | "source": [ 749 | "Let’s try the same thing with the entire model." 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": 18, 755 | "metadata": { 756 | "pycharm": { 757 | "name": "#%%\n" 758 | } 759 | }, 760 | "outputs": [ 761 | { 762 | "name": "stderr", 763 | "output_type": "stream", 764 | "text": [ 765 | "/Users/gu19087/opt/anaconda3/envs/dl-torch/lib/python3.7/site-packages/torch/serialization.py:402: UserWarning: Couldn't retrieve source code for container of type TwoLayerNet. It won't be checked for correctness upon loading.\n", 766 | " \"type \" + obj.__name__ + \". It won't be checked \"\n" 767 | ] 768 | } 769 | ], 770 | "source": [ 771 | "# Save\n", 772 | "torch.save(model, \"model.pth\")\n", 773 | "\n", 774 | "# Load\n", 775 | "model = torch.load(\"model.pth\")" 776 | ] 777 | }, 778 | { 779 | "cell_type": "code", 780 | "execution_count": 20, 781 | "metadata": { 782 | "pycharm": { 783 | "name": "#%%\n" 784 | } 785 | }, 786 | "outputs": [ 787 | { 788 | "name": "stdout", 789 | "output_type": "stream", 790 | "text": [ 791 | "linear1.weight \t torch.Size([100, 1000])\n", 792 | "linear1.bias \t torch.Size([100])\n", 793 | "linear2.weight \t torch.Size([10, 100])\n", 794 | "linear2.bias \t torch.Size([10])\n" 795 | ] 796 | } 797 | ], 798 | "source": [ 799 | "for param_tensor in model.state_dict():\n", 800 | " print(param_tensor, \"\\t\", model.state_dict()[param_tensor].size())" 801 | ] 802 | }, 803 | { 804 | "cell_type": "markdown", 805 | "metadata": { 806 | "pycharm": { 807 | "name": "#%% md\n" 808 | } 809 | }, 810 | "source": [ 811 | "---" 812 | ] 813 | }, 814 | { 815 | "cell_type": "markdown", 816 
| "metadata": { 817 | "pycharm": { 818 | "name": "#%% md\n" 819 | } 820 | }, 821 | "source": [ 822 | "### Saving and loading model checkpoint" 823 | ] 824 | }, 825 | { 826 | "cell_type": "markdown", 827 | "metadata": { 828 | "pycharm": { 829 | "name": "#%% md\n" 830 | } 831 | }, 832 | "source": [ 833 | "Saving and loading a general **checkpoint**, either for inference or for resuming training, lets you pick up where you last left off. \n", 834 | "\n", 835 | "When saving a general checkpoint, you must save more than just the model’s `state_dict`. \n", 836 | "\n", 837 | "It is **also important** to save the **optimizer**’s `state_dict`, as this contains buffers and parameters that are updated as the model trains. \n", 838 | "\n", 839 | "**Moreover**, you might also want to save the `epoch` you left off on, the latest recorded `training loss`, external layers, and more, depending on your own algorithm. After loading, call `model.eval()` before running inference, or `model.train()` before resuming training, so that layers such as dropout and batch normalization behave correctly." 840 | ] 841 | }, 842 | { 843 | "cell_type": "code", 844 | "execution_count": 21, 845 | "metadata": { 846 | "pycharm": { 847 | "name": "#%%\n" 848 | } 849 | }, 850 | "outputs": [], 851 | "source": [ 852 | "# Additional information\n", 853 | "EPOCH = 5\n", 854 | "LOSS = 0.4\n", 855 | "CHKPOINT = \"model_checkpoint.pth\"\n", 856 | "\n", 857 | "torch.save({'epoch': EPOCH,\n", 858 | " 'model_state_dict': model.state_dict(),\n", 859 | " 'optimizer_state_dict': optimizer.state_dict(),\n", 860 | " 'loss': LOSS,\n", 861 | " }, CHKPOINT)" 862 | ] 863 | }, 864 | { 865 | "cell_type": "code", 866 | "execution_count": 22, 867 | "metadata": { 868 | "pycharm": { 869 | "name": "#%%\n" 870 | } 871 | }, 872 | "outputs": [], 873 | "source": [ 874 | "# Load\n", 875 | "model = TwoLayerNet(D_in, H, D_out)\n", 876 | "optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)\n", 877 | "\n", 878 | "checkpoint = torch.load(CHKPOINT)\n", 879 | "model.load_state_dict(checkpoint['model_state_dict'])\n", 880 | "optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", 881 | "epoch = 
checkpoint['epoch']\n", 882 | "loss = checkpoint['loss']" 883 | ] 884 | }, 885 | { 886 | "cell_type": "code", 887 | "execution_count": 23, 888 | "metadata": { 889 | "pycharm": { 890 | "name": "#%%\n" 891 | } 892 | }, 893 | "outputs": [ 894 | { 895 | "name": "stdout", 896 | "output_type": "stream", 897 | "text": [ 898 | "Model's state_dict:\n", 899 | "linear1.weight \t torch.Size([100, 1000])\n", 900 | "linear1.bias \t torch.Size([100])\n", 901 | "linear2.weight \t torch.Size([10, 100])\n", 902 | "linear2.bias \t torch.Size([10])\n", 903 | "\n", 904 | "Optimizer's state_dict:\n", 905 | "state \t {}\n", 906 | "param_groups \t [{'lr': 0.0001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4898874624, 4898873904, 4896499024, 4898875584]}]\n" 907 | ] 908 | } 909 | ], 910 | "source": [ 911 | "print(\"Model's state_dict:\")\n", 912 | "for param_tensor in model.state_dict():\n", 913 | " print(param_tensor, \"\\t\", model.state_dict()[param_tensor].size())\n", 914 | "\n", 915 | "print()\n", 916 | "\n", 917 | "# Print optimizer's state_dict\n", 918 | "print(\"Optimizer's state_dict:\")\n", 919 | "for var_name in optimizer.state_dict():\n", 920 | " print(var_name, \"\\t\", optimizer.state_dict()[var_name])" 921 | ] 922 | }, 923 | { 924 | "cell_type": "code", 925 | "execution_count": 24, 926 | "metadata": { 927 | "pycharm": { 928 | "name": "#%%\n" 929 | } 930 | }, 931 | "outputs": [ 932 | { 933 | "name": "stdout", 934 | "output_type": "stream", 935 | "text": [ 936 | "Loss from Checkpoint: 0.4\n", 937 | "Epoch: 5\n" 938 | ] 939 | } 940 | ], 941 | "source": [ 942 | "print('Loss from Checkpoint: ', loss)\n", 943 | "print('Epoch: ', epoch)" 944 | ] 945 | } 946 | ], 947 | "metadata": { 948 | "kernelspec": { 949 | "display_name": "Python 3 (ipykernel)", 950 | "language": "python", 951 | "name": "python3" 952 | }, 953 | "language_info": { 954 | "codemirror_mode": { 955 | "name": "ipython", 956 | "version": 3 957 | }, 958 | "file_extension": ".py", 959 
| "mimetype": "text/x-python", 960 | "name": "python", 961 | "nbconvert_exporter": "python", 962 | "pygments_lexer": "ipython3", 963 | "version": "3.10.4" 964 | } 965 | }, 966 | "nbformat": 4, 967 | "nbformat_minor": 4 968 | } 969 | -------------------------------------------------------------------------------- /addendum_autograd.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext notexbook" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 5, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "data": { 19 | "text/html": [ 20 | "\n", 2995 | "\n", 2996 | "\n", 3076 | "\n", 3077 | "
\n", 3078 | " The notebook is using\n", 3079 | " \n", 3080 | " no$\\TeX$book Jupyter Theme (release 2.0.1).\n", 3081 | "\n", 3082 | "
" 3083 | ], 3084 | "text/plain": [ 3085 | "" 3086 | ] 3087 | }, 3088 | "execution_count": 5, 3089 | "metadata": {}, 3090 | "output_type": "execute_result" 3091 | } 3092 | ], 3093 | "source": [ 3094 | "%texify" 3095 | ] 3096 | }, 3097 | { 3098 | "cell_type": "markdown", 3099 | "metadata": {}, 3100 | "source": [ 3101 | "# Addendum: Automatic Differentiation" 3102 | ] 3103 | }, 3104 | { 3105 | "cell_type": "markdown", 3106 | "metadata": {}, 3107 | "source": [ 3108 | "The computation of **gradients** is fundamental to Deep Learning and to neural networks in general. So the question now is: _How can we *really* calculate those gradients?_\n", 3109 | "\n", 3110 | "In addition, this calculation comes with two strict requirements. It must be: \n", 3111 | "- **computationally exact**; \n", 3112 | "- **computationally efficient**.\n" 3113 | ] 3114 | }, 3115 | { 3116 | "cell_type": "markdown", 3117 | "metadata": {}, 3118 | "source": [ 3119 | "## Introducing Automatic Differentiation" 3120 | ] 3121 | }, 3122 | { 3123 | "cell_type": "markdown", 3124 | "metadata": {}, 3125 | "source": [ 3126 | "> Methods for the computation of derivatives in computer programs can be classified into four categories: \n", 3127 | ">\n", 3128 | "> (1) **manually** working out derivatives and coding them; \n", 3129 | ">\n", 3130 | "> (2) **numerical** differentiation using finite difference approximations; \n", 3131 | ">\n", 3132 | "> (3) **symbolic** differentiation using expression manipulation in computer algebra systems; \n", 3133 | ">\n", 3134 | "> (4) **automatic** differentiation, also called *algorithmic differentiation*" 3135 | ] 3136 | }, 3137 | { 3138 | "cell_type": "markdown", 3139 | "metadata": {}, 3140 | "source": [ 3141 | "![Automatic differentiation methods](imgs/automatic_diff_methods.png)\n", 3142 | "\n", 3143 | "---\n", 3144 | "Source: [4]" 3145 | ] 3146 | }, 3147 | { 3148 | "cell_type": "markdown", 3149 | "metadata": {}, 3150 | "source": [ 3151 | "**Automatic Differentiation** is the subject matter of this notebook (and section!)" 3152 | 
] 3153 | }, 3154 | { 3155 | "cell_type": "markdown", 3156 | "metadata": {}, 3157 | "source": [ 3158 | "##### Automatic Differentiation in a Nutshell\n", 3159 | "\n", 3160 | "**TL;DR**: Automatic Differentiation lets you compute **exact** derivatives (up to floating-point precision) at a cost within a **small constant factor** of evaluating the function itself" 3161 | ] 3162 | }, 3163 | { 3164 | "cell_type": "markdown", 3165 | "metadata": {}, 3166 | "source": [ 3167 | "###### Teaser\n", 3168 | "\n", 3169 | "Automatic Differentiation is the secret sauce that powers most of the existing Deep Learning frameworks (e.g. PyTorch or TensorFlow). \n", 3170 | "\n", 3171 | "In a nutshell, Deep Learning frameworks provide the (technical) infrastructure in which computing the derivative of a function takes roughly as much time as evaluating the function itself. In particular, the design idea is: \"you define a network with a loss function, and you get a gradient *for free*\".\n", 3172 | "\n", 3173 | "\n", 3174 | "**Differentiation** in general is becoming a **first-class citizen** in programming languages, with early work started by Chris Lattner (the creator of LLVM); see the [Differentiable Programming Manifesto](https://github.com/apple/swift/blob/master/docs/DifferentiableProgramming.md) for more details.\n", 3175 | "\n", 3176 | "However, if you're still wondering whether all of this is just boring math and how it relates to computer programming, I strongly suggest watching **this** video on YouTube:\n", 3177 | "\n", 3178 | "[GOTO 2018 - Machine Learning: Alchemy for the Modern Computer Scientist by Erik Meijer](https://www.youtube.com/watch?v=Rs0uRQJdIcg)" 3179 | ] 3180 | }, 3181 | { 3182 | "cell_type": "markdown", 3183 | "metadata": {}, 3184 | "source": [ 3185 | "##### Automatic Differentiation is NOT $\\ldots$" 3186 | ] 3187 | }, 3188 | { 3189 | "cell_type": "markdown", 3190 | "metadata": {}, 3191 | "source": [ 3192 | "In the above list, we included automatic differentiation as the fourth item [$^3$](#fn3 \"Manual really?\"), separately from other differentiation 
methods such as **symbolic** or **numerical** differentiation.\n", 3193 | "\n", 3194 | "---\n", 3195 | "3: The first item is indeed the manual, hard-coded implementation of derivatives, which is not really meant as a serious option!" 3196 | ] 3197 | }, 3198 | { 3199 | "cell_type": "markdown", 3200 | "metadata": {}, 3201 | "source": [ 3202 | "###### Automatic Differentiation $\\neq$ Symbolic Differentiation\n", 3203 | "\n", 3204 | "This is the **most** obvious one, I presume. \n", 3205 | "\n", 3206 | "Symbolic differentiation works by breaking apart a complex expression into a bunch of simpler expressions using a set of rules, much like a compiler does.\n", 3207 | "\n", 3208 | "(Mathematically and) technically speaking, this is possible thanks to the [**compositionality**](https://en.wikipedia.org/wiki/Principle_of_compositionality) of derivative calculations. \n", 3209 | "\n", 3210 | "Examples of some rules:\n", 3211 | "\n", 3212 | "**Sum rule**:\n", 3213 | "\n", 3214 | "$$\n", 3215 | "\\begin{equation}\n", 3216 | "\\frac{d}{dx} (f(x) + g(x)) = \\frac{d}{dx} f(x) + \\frac{d}{dx} g(x)\n", 3217 | "\\end{equation}\n", 3218 | "$$\n", 3219 | "\n", 3220 | "**Derivatives of powers rule**:\n", 3221 | "\n", 3222 | "$$\n", 3223 | "\\begin{equation}\n", 3224 | "\\frac{d}{dx} x^r = rx^{r-1}\n", 3225 | "\\end{equation}\n", 3226 | "$$\n", 3227 | "\n", 3228 | "\n", 3229 | "**Multiplication rule**:\n", 3230 | "\n", 3231 | "$$\n", 3232 | "\\begin{equation}\n", 3233 | "\\frac{d(fg)}{dx} = \\frac{df}{dx} g + f\\frac{dg}{dx}\n", 3234 | "\\end{equation}\n", 3235 | "$$\n", 3236 | "\n", 3237 | "**Chain rule (!!)**:\n", 3238 | "\n", 3239 | "$$\n", 3240 | "\\begin{equation}\n", 3241 | "\\frac{d}{dx}[f(g(x))] = \\frac{df}{dg} \\cdot \\frac{dg}{dx}\n", 3242 | "\\end{equation}\n", 3243 | "$$\n" 3244 | ] 3245 | }, 3246 | { 3247 | "cell_type": "markdown", 3248 | "metadata": {}, 3249 | "source": [ 3250 | "**Issues**:\n", 3251 | "\n", 3252 | "- For complicated functions, the resulting expression can grow exponentially large ($O(2^n)$), a problem known as *expression swell*;\n", 3253 | 
"- It can be wasteful to keep intermediate symbolic expressions around if we only need a **numeric value** of the gradient in the end;\n", 3254 | "\n", 3255 | "**Real world examples**: `theano` (see [doc](http://deeplearning.net/software/theano/tutorial/gradients.html))\n" 3256 | ] 3257 | }, 3258 | { 3259 | "cell_type": "markdown", 3260 | "metadata": {}, 3261 | "source": [ 3262 | "###### Automatic Differentiation $\\neq$ Numerical Differentiation\n", 3263 | "\n", 3264 | "Numerical differentiation is far more intuitive and effectively the (most) natural way to compute derivatives, building on the very definition of the derivative as the **limit of the difference quotient**:\n", 3265 | "\n", 3266 | "$$\n", 3267 | "\\begin{equation}\n", 3268 | "f'(x) = \\lim_{\\epsilon \\to 0} \\frac{f(x + \\epsilon) - f(x)}{\\epsilon}\n", 3269 | "\\end{equation}\n", 3270 | "$$\n", 3271 | "\n", 3272 | "Recap on (other) notations of derivatives:\n", 3273 | "\n", 3274 | "- $f'(x)$: **Lagrange's notation**\n", 3275 | "- $\\frac{df}{dx}$: **Leibniz's notation**\n", 3276 | "\n", 3277 | "For **multivariate** functions $f: \\mathbb{R}^n \\mapsto \\mathbb{R}$, we can approximate the gradient \n", 3278 | "$\\nabla f = ( \\frac{\\partial f}{\\partial x_1}, \\ldots, \\frac{\\partial f}{\\partial x_n} )$ using:\n", 3279 | "\n", 3280 | "$$\n", 3281 | "\\frac{\\partial f({\\bf x})}{\\partial x_i} \\approx \\frac{f({\\bf x} + h{\\bf e_i}) - f({\\bf x})}{h}\n", 3282 | "$$\n", 3283 | "\n", 3284 | "where $\\bf{e_i}$ is the $i$-th unit vector and $h > 0$ is a small step size.\n", 3285 | "\n", 3286 | "This has the **advantage** of being simple to implement, but the **disadvantages** of performing $O(n)$\n", 3287 | "evaluations of $f$ for a gradient in $n$ dimensions, and of needing **careful consideration** when selecting the step size $h$.\n", 3288 | "\n", 3289 | "Indeed, numerical approximations of derivatives are inherently ill-conditioned and unstable[$^4$](#fn4).\n", 3290 | "\n", 3291 | "\n", 
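The step-size trade-off can be observed with a short pure-Python sketch. The test function $\sin$ and the point $x = 1$ are arbitrary choices, picked only because the exact derivative $\cos(x)$ is known: for large $h$ the truncation error of the difference quotient dominates, while for very small $h$ floating-point cancellation in $f(x + h) - f(x)$ takes over.

```python
import math


def forward_difference(f, x, h):
    """One-sided finite-difference approximation of f'(x) with step size h."""
    return (f(x + h) - f(x)) / h


x = 1.0
exact = math.cos(x)  # analytic derivative of sin at x

errors = {}
for h in [1e-1, 1e-4, 1e-8, 1e-12, 1e-15]:
    approx = forward_difference(math.sin, x, h)
    errors[h] = abs(approx - exact)
    print(f"h = {h:.0e}   |error| = {errors[h]:.1e}")
```

The error first shrinks as $h$ decreases and then grows again once $h$ approaches machine precision, which is exactly why the step size needs careful tuning.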
3292 | "4: Using the limit definition of the derivative for finite difference approximation commits both cardinal sins of numerical analysis: **“thou shalt not add small numbers to big numbers”, and “thou shalt not subtract numbers which are approximately equal”**." 3293 | ] 3294 | }, 3295 | { 3296 | "cell_type": "markdown", 3297 | "metadata": {}, 3298 | "source": [ 3299 | "###### Automatic Differentiation $=$ Automatic Differentiation" 3300 | ] 3301 | }, 3302 | { 3303 | "cell_type": "markdown", 3304 | "metadata": {}, 3305 | "source": [ 3306 | "Automatic Differentiation (AD) gives exact answers (up to floating-point precision) at a cost proportional to that of evaluating the function.\n", 3307 | "\n", 3308 | "AD can be thought of as performing a non-standard interpretation of a computer program, where this interpretation involves **augmenting** the standard computation with the calculation of various derivatives. \n", 3309 | "\n", 3310 | "All numerical computations are ultimately compositions of a finite set of elementary operations for which derivatives are known, and combining the derivatives of the constituent operations through the **chain rule** gives the derivative of the overall composition.\n", 3311 | "\n", 3312 | "The chain rule above told us what we **wanted**. \n", 3313 | "\n", 3314 | "What it does not say is how to compute it efficiently. \n", 3315 | "\n", 3316 | "AD is **a way** to get these gradients efficiently without having to do anything but write the objective function in computer code. \n", 3317 | "\n", 3318 | "AD is a very broadly applicable technique, and it has been studied for decades. Curiously, however, it has traditionally been used only in a limited way in machine learning, despite the ubiquity of gradient-based optimization problems in ML. \n", 3319 | "\n", 3320 | "Only recently have serious automatic differentiation systems (as opposed to naïve hand-coded backprop rules) started to make their way into mainstream deep learning toolchains. 
\n", 3321 | "\n", 3322 | "TensorFlow, for example, has the ability to compute gradients of the computational graphs it supports with its limited domain-specific language, but it is not close to a full AD system. \n", 3323 | "\n", 3324 | "Automatic differentiation can be implemented in a variety of ways, via **run-time abstractions** and also via **source code transformation**." 3325 | ] 3326 | }, 3327 | { 3328 | "cell_type": "markdown", 3329 | "metadata": {}, 3330 | "source": [ 3331 | "###### Example\n", 3332 | "\n", 3333 | "Rather than talking about large neural networks, we will seek to understand automatic differentiation via a small problem borrowed from the book by *Griewank and Walther (2008)*.\n", 3334 | "\n", 3335 | "In the following, we adopt their **three-part** notation (also used in [4]).\n", 3336 | "\n", 3337 | "A function $f: \\mathbb{R}^n \\mapsto \\mathbb{R}^m$ is constructed using intermediate variables $v_i$ such that:\n", 3338 | "\n", 3339 | "- variables $v_{i-n} = x_i$, $i = 1,\\ldots,n$ are the input variables;\n", 3340 | "- variables $v_i$, $i = 1,\\ldots,l$ are the working **intermediate** variables;\n", 3341 | "- variables $y_{m-i} = v_{l-i}$, $i = m-1,\\ldots,0$ are the output variables.\n", 3342 | "\n", 3343 | "![AD example](imgs/ad_example.png)" 3344 | ] 3345 | }, 3346 | { 3347 | "cell_type": "markdown", 3348 | "metadata": {}, 3349 | "source": [ 3350 | "The **traversal** of the graph and the **direction** in which gradients are actually computed define the two modes of AD:\n", 3351 | "\n", 3352 | "* **forward mode** AD;\n", 3353 | "* **backward** (or **reverse**) **mode** AD."
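To make the backward direction concrete, here is a hand-unrolled sketch of **reverse mode** on $y = \ln(x_1) + x_1 x_2 - \sin(x_2)$, the same function used by the forward-mode code in this notebook. One forward sweep stores the intermediate values $v_i$, and one backward sweep accumulates the adjoints $\bar{v}_i = \partial y / \partial v_i$, yielding **both** partial derivatives at once, whereas forward mode needs one pass per input. The function name and the variable names (e.g. `vm1` for $v_{-1}$) are our own illustrative choices.

```python
import math


def f_and_grad_reverse(x1: float, x2: float):
    """Hand-unrolled reverse-mode AD for y = ln(x1) + x1*x2 - sin(x2).

    Returns (y, (dy/dx1, dy/dx2)) from one forward and one backward sweep.
    """
    # Forward sweep: evaluate the primal trace, keeping every intermediate value
    vm1 = x1               # v_{-1}
    v0 = x2                # v_0
    v1 = math.log(vm1)     # v_1 = ln(v_{-1})
    v2 = vm1 * v0          # v_2 = v_{-1} * v_0
    v3 = v1 + v2           # v_3 = v_1 + v_2
    v4 = math.sin(v0)      # v_4 = sin(v_0)
    v5 = v3 - v4           # v_5 = v_3 - v_4 = y

    # Backward sweep: adjoints vbar_i = dy/dv_i, seeded with dy/dy = 1
    v5_bar = 1.0
    v3_bar = v5_bar * 1.0            # dv5/dv3 = 1
    v4_bar = v5_bar * -1.0           # dv5/dv4 = -1
    v0_bar = v4_bar * math.cos(v0)   # dv4/dv0 = cos(v0)
    v1_bar = v3_bar * 1.0            # dv3/dv1 = 1
    v2_bar = v3_bar * 1.0            # dv3/dv2 = 1
    # v_{-1} and v_0 each feed two children, so their adjoints are sums
    vm1_bar = v2_bar * v0            # dv2/dv_{-1} = v_0
    v0_bar += v2_bar * vm1           # dv2/dv0 = v_{-1}
    vm1_bar += v1_bar / vm1          # dv1/dv_{-1} = 1 / v_{-1}

    return v5, (vm1_bar, v0_bar)


y, (dy_dx1, dy_dx2) = f_and_grad_reverse(2.0, 5.0)
print(y, dy_dx1, dy_dx2)
```

The gradient should agree with the two forward-mode evaluations of `f`, but here a single backward sweep delivers it. This is why reverse mode is the natural choice when a scalar loss depends on many parameters.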
3354 | ] 3355 | }, 3356 | { 3357 | "cell_type": "code", 3358 | "execution_count": null, 3359 | "metadata": {}, 3360 | "outputs": [], 3361 | "source": [] 3362 | }, 3363 | { 3364 | "cell_type": "markdown", 3365 | "metadata": {}, 3366 | "source": [ 3367 | "###### Forward (Tangent) Mode" 3368 | ] 3369 | }, 3370 | { 3371 | "cell_type": "markdown", 3372 | "metadata": {}, 3373 | "source": [ 3374 | "The idea of forward mode automatic differentiation is that we can compute derivatives as we go, since\n", 3375 | "the chain rule tells us that the overall derivative we want is a composition of these incremental\n", 3376 | "computations.\n", 3377 | "\n", 3378 | "Let’s imagine that our overall goal is to compute $\\frac{\\partial y}{\\partial x_1}$ for the example above.\n", 3379 | "\n", 3380 | "We denote all the intermediate *partial derivatives* (**with respect to $x_1$**) as:\n", 3381 | "$$\n", 3382 | "\\dot{v}_{i} = \\frac{\\partial v_i}{\\partial x_1}\n", 3383 | "$$\n", 3384 | "\n", 3385 | "The **very same** procedure applies when we want to compute $\\frac{\\partial y}{\\partial x_2}$, this time with $\\dot{v}_{i} = \\frac{\\partial v_i}{\\partial x_2}$." 3386 | ] 3387 | }, 3388 | { 3389 | "cell_type": "markdown", 3390 | "metadata": {}, 3391 | "source": [ 3392 | "![Forward mode AD](imgs/forward_ad.png)" 3393 | ] 3394 | }, 3395 | { 3396 | "cell_type": "markdown", 3397 | "metadata": {}, 3398 | "source": [ 3399 | "**Dual Numbers**\n", 3400 | "\n", 3401 | "Mathematically, forward mode AD can be viewed as evaluating a function using **dual numbers**, which can be defined as truncated Taylor series of the form:\n", 3402 | "$$\n", 3403 | "v + \\dot{v}\\epsilon\n", 3404 | "$$\n", 3405 | "\n", 3406 | "where $v$ and $\\dot{v} \\in \\mathbb{R}$, and $\\epsilon$ is a *nilpotent* number such that $\\epsilon^2 = 0$ and\n", 3407 | "$\\epsilon \\neq 0$.\n", 3408 | "\n", 3409 | "Expanding $f$ as a Taylor series, every term in $\\epsilon^2$ or higher vanishes, therefore:\n", 3410 | "$$\n", 3411 | "f(v + \\dot{v}\\epsilon) = f(v) + f'(v)\\dot{v}\\epsilon\n", 3412 | "$$\n", 3413 | "\n", 3414 | "This suggests using dual numbers as data structures that carry the **tangent** value together with the 
primal.\n", 3415 | "This is why **forward mode AD** is also known as **tangent mode**." 3416 | ] 3417 | }, 3418 | { 3419 | "cell_type": "markdown", 3420 | "metadata": {}, 3421 | "source": [ 3422 | "---" 3423 | ] 3424 | }, 3425 | { 3426 | "cell_type": "markdown", 3427 | "metadata": {}, 3428 | "source": [ 3429 | "The interesting bit is that this bookkeeping can be carried out along the *execution trace* purely through abstraction.\n", 3430 | "\n", 3431 | "We can replace our floating-point numbers with `(value, tangent)` tuples, i.e. dual numbers, and replace the primitive functions with the following Python implementations (using only `numpy`):" 3432 | ] 3433 | }, 3434 | { 3435 | "cell_type": "code", 3436 | "execution_count": 2, 3437 | "metadata": {}, 3438 | "outputs": [], 3439 | "source": [ 3440 | "import numpy as np\n", 3441 | "\n", 3442 | "def add(atuple, btuple):\n", 3443 | " (a, adot) = atuple\n", 3444 | " (b, bdot) = btuple\n", 3445 | " return (a + b, adot + bdot)  # sum rule\n", 3446 | "\n", 3447 | "def subtract(atuple, btuple):\n", 3448 | " (a, adot) = atuple\n", 3449 | " (b, bdot) = btuple\n", 3450 | " return (a - b, adot - bdot)\n", 3451 | "\n", 3452 | "def multiply(atuple, btuple):\n", 3453 | " (a, adot) = atuple\n", 3454 | " (b, bdot) = btuple\n", 3455 | " return (a * b, adot * b + bdot * a)  # product rule\n", 3456 | "\n", 3457 | "def divide(atuple, btuple):\n", 3458 | " (a, adot) = atuple\n", 3459 | " (b, bdot) = btuple\n", 3460 | " return (a / b, (adot * b - bdot * a) / (b * b))  # quotient rule\n", 3461 | "\n", 3462 | "def ln(atuple):\n", 3463 | " (a, adot) = atuple\n", 3464 | " return (np.log(a), (1 / a) * adot)  # d ln(a) = da / a\n", 3465 | "\n", 3466 | "def sin(atuple):\n", 3467 | " (a, adot) = atuple\n", 3468 | " return (np.sin(a), np.cos(a) * adot)  # d sin(a) = cos(a) da" 3469 | ] 3470 | }, 3471 | { 3472 | "cell_type": "code", 3473 | "execution_count": 8, 3474 | "metadata": {}, 3475 | "outputs": [], 3476 | "source": [ 3477 | "def f(x1: tuple, x2: tuple):\n", 3478 | " # y = ln(x1) + x1*x2 - sin(x2)\n", 3479 | " v1 = ln(x1)\n", 3480 | " v2 = multiply(x1, x2)\n", 3481 | " v3 = add(v1, v2)\n",
3482 | " v4 = sin(x2)\n", 3483 | " v5 = subtract(v3, v4)\n", 3484 | " return v5" 3485 | ] 3486 | }, 3487 | { 3488 | "cell_type": "code", 3489 | "execution_count": 10, 3490 | "metadata": {}, 3491 | "outputs": [ 3492 | { 3493 | "data": { 3494 | "text/plain": [ 3495 | "(11.652071455223084, 5.5)" 3496 | ] 3497 | }, 3498 | "execution_count": 10, 3499 | "metadata": {}, 3500 | "output_type": "execute_result" 3501 | } 3502 | ], 3503 | "source": [ 3504 | "f(x1=(2, 1), x2=(5, 0))  # tangent seeded on x1: returns (y, dy/dx1)" 3505 | ] 3506 | }, 3507 | { 3508 | "cell_type": "code", 3509 | "execution_count": 11, 3510 | "metadata": {}, 3511 | "outputs": [ 3512 | { 3513 | "data": { 3514 | "text/plain": [ 3515 | "(11.652071455223084, 1.7163378145367738)" 3516 | ] 3517 | }, 3518 | "execution_count": 11, 3519 | "metadata": {}, 3520 | "output_type": "execute_result" 3521 | } 3522 | ], 3523 | "source": [ 3524 | "f(x1=(2, 0), x2=(5, 1))  # tangent seeded on x2: returns (y, dy/dx2)" 3525 | ] 3526 | }, 3527 | { 3528 | "cell_type": "markdown", 3529 | "metadata": {}, 3530 | "source": [ 3531 | "###### Reverse (Co-Tangent) Mode\n", 3532 | "\n", 3533 | "AD in the reverse accumulation mode corresponds to a generalized backpropagation algorithm, in that it propagates derivatives backward from a given output. This is done by complementing each intermediate variable $v_i$ with an **adjoint**, i.e. the sensitivity of the output $y$ with respect to $v_i$, which is accumulated from the adjoints of the children $v_j$ of $v_i$ in the graph:\n", 3534 | "$$\n", 3535 | "\\bar{v}_i = \\frac{\\partial y}{\\partial v_i} = \\sum_{j \\,:\\, \\text{child of } i} \\bar{v}_j \\frac{\\partial v_j}{\\partial v_i}\n", 3536 | "$$" 3537 | ] 3538 | }, 3539 | { 3540 | "cell_type": "markdown", 3541 | "metadata": {}, 3542 | "source": [ 3543 | "" 3544 | ] 3545 | }, 3546 | { 3547 | "cell_type": "markdown", 3548 | "metadata": {}, 3549 | "source": [ 3550 | "There are various ways to implement this abstraction in its full generality, but an implementation requires more code than can easily appear here.
The three major approaches are:\n", 3551 | "\n", 3552 | "**source code transformation**: The adjoint backward pass code is generated a priori from the forward computation. A clean Python example of such a system is [**Tangent**](https://colab.research.google.com/drive/1cjoX9GteBymbnqcikNMZP1uenMcwAGDe).\n", 3553 | "\n", 3554 | "**graph-based**: This approach uses an embedded mini-language to specify a graph of computations that can then be manipulated for function evaluations and gradients. \n", 3555 | "\n", 3556 | "$\\rightarrow$ The advantage of this approach is that it is amenable to intelligent graph optimizations and use of compilers. The embedded mini-language also makes it possible to build specialized hardware that targets the differentiable primitives. \n", 3557 | "\n", 3558 | "$\\rightarrow$ The downside of this approach is that you are not coding in the host language (e.g., Python) and so you can’t take advantage of its imperative design and control flow. Generally the mini-language is less expressive than the host language. Also, the lazy execution of the function represented by the graph can make it difficult to debug. TensorFlow 1.x is an example of this kind of automatic differentiation.\n", 3559 | "\n", 3560 | "**tape-based**: This approach tracks the actual composed functions as they are called during execution of the forward pass. One name for this data structure is the *Wengert list*. \n", 3561 | "\n", 3562 | "$\\rightarrow$ With the ordered sequence of computations in hand, it is then possible to walk backward through the list to compute the gradient. \n", 3563 | "\n", 3564 | "$\\rightarrow$ The advantage of this is that it can more easily use all the features of the host language and the imperative execution is easier to understand. \n", 3565 | "\n", 3566 | "$\\rightarrow$ The downside is that it can be more difficult to optimize the code and reuse computations across executions. 
\n", 3567 | "\n", 3568 | "[Autograd](https://github.com/HIPS/autograd) is an example of this. \n", 3569 | "The automatic differentiation in [PyTorch](https://pytorch.org/) also roughly follows this model." 3570 | ] 3571 | }, 3572 | { 3573 | "cell_type": "markdown", 3574 | "metadata": {}, 3575 | "source": [ 3576 | "" 3577 | ] 3578 | }, 3579 | { 3580 | "cell_type": "markdown", 3581 | "metadata": {}, 3582 | "source": [ 3583 | "### Re-inventing the Wheel: Introducing `micrograd`\n", 3584 | "\n", 3585 | "[`micrograd`](https://github.com/karpathy/micrograd) is:\n", 3586 | "\n", 3587 | "> a tiny Autograd engine (with a bite! :)). \n", 3588 | "> Implements backpropagation (reverse-mode autodiff) over a dynamically built DAG and a small neural networks library on top of it with a PyTorch-like API. \n", 3589 | ">\n", 3590 | "> Both are tiny, with about 100 and 50 lines of code respectively. \n", 3591 | ">\n", 3592 | "> The DAG only operates over scalar values, so e.g. we chop up each neuron into all of its individual tiny adds and multiplies. However, this is enough to build up entire deep neural nets doing binary classification. Potentially useful for educational purposes." 
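 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
 "To make the *dynamically built DAG* concrete, the next cell sketches a tiny scalar reverse-mode engine in the spirit of `micrograd`, using only the standard library. It is **not** micrograd's code: the class `Var` and its attributes `parents` and `local_grads` are hypothetical names chosen for illustration. During the forward pass each operation records its parents and the local partial derivatives $\\frac{\\partial v_j}{\\partial v_i}$ (this record plays the role of the tape). `backward()` then visits the nodes in reverse topological order and accumulates the adjoints $\\bar{v}_i = \\sum_j \\bar{v}_j \\frac{\\partial v_j}{\\partial v_i}$. On the running example it should reproduce, from a single backward pass, the two derivatives computed above with dual numbers ($5.5$ and approximately $1.7163$)."
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
 "import math\n",
 "\n",
 "class Var:\n",
 "    # Hypothetical minimal scalar node (micrograd-like, not micrograd's code)\n",
 "    def __init__(self, data, parents=(), local_grads=()):\n",
 "        self.data = data\n",
 "        self.grad = 0.0                 # adjoint, filled in by backward()\n",
 "        self.parents = parents          # nodes this value was computed from\n",
 "        self.local_grads = local_grads  # d(self)/d(parent), one per parent\n",
 "\n",
 "    def __add__(self, other):\n",
 "        return Var(self.data + other.data, (self, other), (1.0, 1.0))\n",
 "\n",
 "    def __sub__(self, other):\n",
 "        return Var(self.data - other.data, (self, other), (1.0, -1.0))\n",
 "\n",
 "    def __mul__(self, other):\n",
 "        return Var(self.data * other.data, (self, other), (other.data, self.data))\n",
 "\n",
 "    def log(self):\n",
 "        return Var(math.log(self.data), (self,), (1.0 / self.data,))\n",
 "\n",
 "    def sin(self):\n",
 "        return Var(math.sin(self.data), (self,), (math.cos(self.data),))\n",
 "\n",
 "    def backward(self):\n",
 "        # 1) topologically sort the DAG recorded during the forward pass\n",
 "        order, seen = [], set()\n",
 "        def visit(node):\n",
 "            if node not in seen:\n",
 "                seen.add(node)\n",
 "                for parent in node.parents:\n",
 "                    visit(parent)\n",
 "                order.append(node)\n",
 "        visit(self)\n",
 "        # 2) walk it in reverse, accumulating adjoints via the chain rule\n",
 "        self.grad = 1.0\n",
 "        for node in reversed(order):\n",
 "            for parent, local in zip(node.parents, node.local_grads):\n",
 "                parent.grad += node.grad * local\n",
 "\n",
 "x1, x2 = Var(2.0), Var(5.0)\n",
 "y = x1.log() + x1 * x2 - x2.sin()\n",
 "y.backward()\n",
 "print(y.data, x1.grad, x2.grad)"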
3593 | ] 3594 | }, 3595 | { 3596 | "cell_type": "code", 3597 | "execution_count": 1, 3598 | "metadata": {}, 3599 | "outputs": [], 3600 | "source": [ 3601 | "def example(this):\n", 3602 | " for i in range(190):\n", 3603 | " print(list(map(lambda x: x + 2,\n", 3604 | " range(32))))" 3605 | ] 3606 | }, 3607 | { 3608 | "cell_type": "markdown", 3609 | "metadata": {}, 3610 | "source": [ 3611 | "(*from the `micrograd` README*)\n", 3612 | "\n", 3613 | "```python\n", 3614 | "from micrograd.engine import Value\n", 3615 | "\n", 3616 | "a = Value(-4.0)\n", 3617 | "b = Value(2.0)\n", 3618 | "c = a + b\n", 3619 | "d = a * b + b**3\n", 3620 | "c += c + 1\n", 3621 | "c += 1 + c + (-a)\n", 3622 | "d += d * 2 + (b + a).relu()\n", 3623 | "d += 3 * d + (b - a).relu()\n", 3624 | "e = c - d\n", 3625 | "f = e**2\n", 3626 | "g = f / 2.0\n", 3627 | "g += 10.0 / f\n", 3628 | "print(f'{g.data:.4f}') # prints 24.7041, the outcome of this forward pass\n", 3629 | "g.backward()\n", 3630 | "print(f'{a.grad:.4f}') # prints 138.8338, i.e. the numerical value of dg/da\n", 3631 | "print(f'{b.grad:.4f}') # prints 645.5773, i.e. the numerical value of dg/db\n", 3632 | "```" 3633 | ] 3634 | }, 3635 | { 3636 | "cell_type": "markdown", 3637 | "metadata": {}, 3638 | "source": [ 3639 | "```python\n", 3640 | "from micrograd import nn\n", 3641 | "n = nn.Neuron(2)\n", 3642 | "x = [Value(1.0), Value(-2.0)]\n", 3643 | "y = n(x)\n", 3644 | "dot = draw_dot(y)  # draw_dot is defined in micrograd's trace_graph.ipynb\n", 3645 | "```\n", 3646 | "\n", 3647 | "" 3648 | ] 3649 | }, 3650 | { 3651 | "cell_type": "markdown", 3652 | "metadata": {}, 3653 | "source": [ 3654 | "### References and Further Reading\n", 3655 | "\n", 3656 | "1. [Automatic Differentiation Step by Step](https://medium.com/@marksaroufim/automatic-differentiation-step-by-step-24240f97a6e6)\n", 3657 | "\n", 3658 | "2. [Deep Learning with PyTorch (**free sample**) - Luca Antiga et al.](https://pytorch.org/deep-learning-with-pytorch)\n", 3659 | "\n", 3660 | "3. [Python Machine Learning, 3rd ed.
- Sebastian Raschka](https://sebastianraschka.com/books.html)\n", 3661 | "\n", 3662 | "4. [(*Paper*) Automatic Differentiation in Machine Learning: a Survey](https://arxiv.org/abs/1502.05767)" 3663 | ] 3664 | }, 3665 | { 3666 | "cell_type": "code", 3667 | "execution_count": null, 3668 | "metadata": {}, 3669 | "outputs": [], 3670 | "source": [] 3671 | } 3672 | ], 3673 | "metadata": { 3674 | "kernelspec": { 3675 | "display_name": "Python 3 (ipykernel)", 3676 | "language": "python", 3677 | "name": "python3" 3678 | }, 3679 | "language_info": { 3680 | "codemirror_mode": { 3681 | "name": "ipython", 3682 | "version": 3 3683 | }, 3684 | "file_extension": ".py", 3685 | "mimetype": "text/x-python", 3686 | "name": "python", 3687 | "nbconvert_exporter": "python", 3688 | "pygments_lexer": "ipython3", 3689 | "version": "3.10.4" 3690 | } 3691 | }, 3692 | "nbformat": 4, 3693 | "nbformat_minor": 4 3694 | } 3695 | --------------------------------------------------------------------------------