├── .gitignore
├── LICENSE
├── README.md
├── START_HERE.ipynb
├── contlearn
│   ├── __init__.py
│   ├── getdata.py
│   ├── getmodels.py
│   └── gettrainer.py
├── files
│   ├── F1.large.jpg
│   ├── baseline_mnist.png
│   ├── basemodel.pth
│   ├── catastrophic_forgetting.png
│   ├── ewc_training.png
│   ├── ewcmodel.pth
│   └── notebook1.png
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Misc
2 | .DS_Store
3 | docker-compose.yml
4 | docker-compose.override.yml
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Carson Lam
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # intro_continual_learning
2 |
3 | This is a tutorial that connects the mathematics and machine learning theory to practical implementations addressing the continual learning problem of artificial intelligence. We will learn this in Python by examining and deconstructing a method called [elastic weight consolidation](https://www.pnas.org/content/114/13/3521) (EWC).
4 |
5 | I wish there were more learning tools in this style: ones that directly help the learner connect the math to the code, using a simple but completely end-to-end project. While it is true that the average programmer can load an "out of the box" library in 5 minutes and be running the latest model on a common task in 15 minutes, I often hear from engineers that they feel under-developed in the math that underlies recent academic research in machine learning. I have received criticism from some who believe tutorials like this provide a shortcut for "average" engineers to "think" they understand the math behind a flashy new artificial intelligence concept, and who think the joy of reading these papers should be reserved for traditionally trained academics who have gone through years of formal coursework. I think there is nothing wrong with using a cool AI concept to motivate learners to pick up more of the fundamental math on their own.
6 |
7 | "anyone can cook" - ratatouille
8 |
9 |
10 |
11 |
12 |
13 | ### What does elastic weight consolidation do?
14 |
15 | The ability to learn tasks in a sequential fashion is crucial to the development of artificial intelligence. When an artificial neural network is trained on a new training set, unless that new training set includes all the old tasks combined with the new task, it is generally subject to catastrophic forgetting, whereby learning to solve new task B is accompanied by a degradation of performance at old task A. In contrast, human neural networks can maintain expertise on tasks that they have not experienced for a long time. EWC addresses this problem by selectively slowing down learning on the weights (i.e. parameters, synaptic strengths) important for those old tasks.
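
The idea of "selectively slowing down learning" can be sketched in a few lines. What follows is a minimal, illustrative sketch in plain Python, not the code in `contlearn`. It assumes the quadratic penalty described in the EWC paper, where each weight's squared distance from its task A value is scaled by an importance score and by a strength `lam`. The names `ewc_penalty`, `theta_a`, `fisher_diag` and `lam`, and all the numbers, are hypothetical and chosen only for illustration.

```python
# Illustrative sketch only: hypothetical names and made-up numbers,
# not the implementation used in this repository.

def ewc_penalty(theta, theta_a, fisher_diag, lam):
    """Sum over weights of (lam / 2) * F_i * (theta_i - theta_a_i) ** 2."""
    return sum(
        0.5 * lam * f * (t - ta) ** 2
        for t, ta, f in zip(theta, theta_a, fisher_diag)
    )


# Weights learned on task A, and a per-weight importance score for task A.
theta_a = [1.0, -2.0]
fisher_diag = [10.0, 0.1]  # first weight matters a lot for task A, second barely

# Move each weight by the same amount (0.5) and compare the penalties.
moved_important = [1.5, -2.0]
moved_unimportant = [1.0, -1.5]

penalty_important = ewc_penalty(moved_important, theta_a, fisher_diag, lam=1.0)
penalty_unimportant = ewc_penalty(moved_unimportant, theta_a, fisher_diag, lam=1.0)

print("penalty for moving the important weight:  ", penalty_important)
print("penalty for moving the unimportant weight:", penalty_unimportant)
```

Moving the important weight costs far more than moving the unimportant one by the same amount, so training on task B is nudged toward changing the weights that task A does not rely on.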
16 |
17 | ## Setup
18 |
19 | - Ubuntu 18.04.3 LTS (bionic)
20 | - Python 3.8
21 | - CUDA 10.1
22 | - cuDNN 7.6.4
23 | - PyTorch 1.10.0
24 |
25 | ### These same steps should also work on macOS
26 |
27 | ```console
28 | you@you:/path/to/folder$ pip3 install virtualenv
29 |
30 | you@you:/path/to/folder$ virtualenv venv --python=python3.8
31 |
32 | you@you:/path/to/folder$ source venv/bin/activate
33 |
34 | (venv) you@you:/path/to/folder$ pip3 install -r requirements.txt
35 |
36 | (venv) you@you:/path/to/folder$ jupyter notebook
37 | ```
38 |
39 | ### Credit/References:
40 |
41 | 1. [James Kirkpatrick et al. Overcoming catastrophic forgetting in neural networks. PNAS 2017 (10.1073/pnas.1611835114)](https://www.pnas.org/content/114/13/3521)
42 |
43 | 2. [shivamsaboo17](https://github.com/shivamsaboo17/Overcoming-Catastrophic-forgetting-in-Neural-Networks)
44 |
45 | 3. [moskomule](https://github.com/moskomule/ewc.pytorch)
46 |
--------------------------------------------------------------------------------
/START_HERE.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Part 1 intuition behind elastic weight consolidation\n",
8 | "\n",
9 | "In the figure below, $\\theta^{*}$ are the weights (i.e. parameters, synaptic strengths) learned by the neural network (NN) to solve old task A, shown as a vector in vector space. For example, if this neural network has 100 total weights, then the figure is a 2D representation of $\\mathbb{R}^{100}$ space. The blue horizontal arrow shows an example of catastrophic forgetting, whereby $\\theta^{*}$ moves out of the region that allows the NN to perform well at task A (grey) and into the center of a region that allows the NN to perform well at task B (cream). The downward green arrow is the update to $\\theta^{*}$ regularized by the L2 penalty $\\alpha (\\theta_{i} - \\theta_{A , i}^{*})^{2}$, which applies the same coefficient $\\alpha$ to every weight and so moves toward the cream region irrespective of the shape of the grey region. The desired update vector is the red arrow, which moves the NN weights into a region capable of performing well at both tasks A and B.\n",
10 | "\n",
11 | "How Elastic Weight Consolidation Changes Learning New Weights $\\theta^{*}$\n",
12 | "\n",
13 | "\n",
14 | "\n",
15 | "\n",
16 | "EWC encourages movement of the weights along the red path by modifying the loss function used when re-training an NN that has already been trained to convergence on task A with loss function $L_{A}$, settling on weights $\\theta_{A}^{*}$. When re-training the NN on task B using $L_{B}$, we add a term that penalizes moving a weight far from its task A value, i.e. a large $(\\theta_{i} - \\theta_{A , i}^{*})^{2}$, in proportion to $F_{i}$, which encodes the shape of the grey region.\n",
17 | "\n",
18 | "$$L(\\theta) = L_{B}(\\theta) + \\sum_{i} \\frac{\\lambda}{2} F_{i} \\left(\\theta_{i} - \\theta_{A , i}^{*}\\right)^{2}$$\n",
19 | "\n",
20 | "### But what is $F_i$?\n",
21 | "\n",
22 | "$F_i$ is the $i$-th diagonal entry of the Fisher information matrix. We use these diagonal entries to identify which parameters are more important to task A and apply larger penalties to them (the direction of the short axis of the grey oval). To learn B we should instead change those weights where $F_i$ is small (the long axis of the grey oval).\n",
23 | "\n",
24 | "In the EWC paper:\n",
25 | "\n",
26 | "\"we approximate the posterior as a Gaussian distribution with mean given by the parameters θ∗A and a diagonal precision given by the diagonal of the Fisher information matrix F. F has three key properties (20): (i) It is equivalent to the second derivative of the loss near a minimum, (ii) it can be computed from first-order derivatives alone and is thus easy to calculate even for large models, and (iii) it is guaranteed to be positive semidefinite. Note that this approach is similar to expectation propagation where each subtask is seen as a factor of the posterior (21). where LB(θ) is the loss for task B only, λ sets how important the old task is compared with the new one, and i labels each parameter.\n",
27 | "\n",
28 | "When moving to a third task, task C, EWC will try to keep the network parameters close to the learned parameters of both tasks A and B. This can be enforced either with two separate penalties or as one by noting that the sum of two quadratic penalties is itself a quadratic penalty.\"\n",
29 | "\n",
30 | "### Let's learn what F is in the example\n",
31 | "\n",
32 | "This article gives a very good explanation of F in the context of EWC: [Fisher Information Matrix by Yuan-Hong Liao](https://andrewliao11.github.io/blog/fisher-info-matrix/)\n",
33 | "\n",
34 | "To compute F_i, we sample the data from task A once and calculate the empirical Fisher Information Matrix. \n",
35 | "\n",
36 | "$$\n",
37 | "I_{\\theta_\\mathcal{A}^*} = \\frac{1}{N} \\sum_{i=1}^{N} \\nabla_\\theta \\log p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*) \\, \\nabla_\\theta \\log p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*)^T\n",
38 | "$$\n",
39 | "\n",
40 | "In other words, the equation above is how you estimate the equation below from the data. For each pair of parameters in $\\theta$ ($\\theta_i$ and $\\theta_j$), the Fisher information matrix at position $ij$ is\n",
41 | "\n",
42 | "$$\n",
43 | "I(\\theta)_{ij} = E\\left[ \\left( \\frac{\\partial}{\\partial\\theta_i}\\log f(X;\\theta) \\right)\\left( \\frac{\\partial}{\\partial\\theta_j}\\log f(X;\\theta) \\right) \\mid \\theta\\right]\n",
44 | "$$\n",
45 | "\n",
46 | "If this equation is hard to understand, don't worry: the code should make it clearer, and we will match parts of the code to the equation above so it becomes more tangible.\n",
47 | "\n",
48 | "# Part 2 A look at the data and the task\n",
49 | "\n",
50 | "### MNIST\n",
51 | "\n",
52 | "The MNIST data set contains 70,000 images of handwritten digits and their corresponding labels. The images are 28x28 with pixel values from 0 to 255. The labels are the digits from 0 to 9. By default 60,000 of these images belong to a training set and 10,000 of these images belong to a test set.\n",
53 | "\n",
54 | "### Fashion-MNIST\n",
55 | "\n",
56 | "Fashion-MNIST is a dataset of Zalando's article images—consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes. Zalando intends Fashion-MNIST to serve as a direct drop-in replacement for the original MNIST dataset for benchmarking machine learning algorithms. It shares the same image size and structure of training and testing splits.\n",
57 | "\n",
58 | "Each training and test example is assigned to one of the following labels:\n",
59 | "\n",
60 | "- 0 T-shirt/top\n",
61 | "- 1 Trouser\n",
62 | "- 2 Pullover\n",
63 | "- 3 Dress\n",
64 | "- 4 Coat\n",
65 | "- 5 Sandal\n",
66 | "- 6 Shirt\n",
67 | "- 7 Sneaker\n",
68 | "- 8 Bag\n",
69 | "- 9 Ankle boot\n",
70 | "\n",
71 | "## Task\n",
72 | "\n",
73 | "As you might guess, our goal is to train an NN that retains its ability to perform well on MNIST after being retrained on only Fashion-MNIST."
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 1,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "import numpy as np\n",
83 | "from PIL import Image\n",
84 | "from matplotlib import pyplot as plt\n",
85 | "from matplotlib.pyplot import imshow\n",
86 | "\n",
87 | "from contlearn.getdata import getMNIST, getFashionMNIST\n",
88 | "\n",
89 | "%load_ext autoreload\n",
90 | "%autoreload 2\n",
91 | "%matplotlib inline"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 2,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "# task A training and test set\n",
101 | "train_loader_mnist, test_loader_mnist = getMNIST(batch_size=32)\n",
102 | "\n",
103 | "# task B training and test set\n",
104 | "train_loader_fashion, test_loader_fashion = getFashionMNIST(batch_size=32)"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 3,
110 | "metadata": {},
111 | "outputs": [
112 | {
113 | "name": "stdout",
114 | "output_type": "stream",
115 | "text": [
116 | "tensor(1)\n",
117 | "torch.Size([28, 28])\n"
118 | ]
119 | },
120 | {
121 | "data": {
122 | "text/plain": [
123 | ""
124 | ]
125 | },
126 | "execution_count": 3,
127 | "metadata": {},
128 | "output_type": "execute_result"
129 | },
130 | {
131 | "data": {
132 | "image/png": "\n",
133 | "text/plain": [
134 | ""
135 | ]
136 | },
137 | "metadata": {
138 | "needs_background": "light"
139 | },
140 | "output_type": "display_data"
141 | }
142 | ],
143 | "source": [
144 | "input_image, target_label = next(iter(train_loader_mnist))\n",
145 | "\n",
146 | "print(target_label[0])\n",
147 | "print(input_image[0][0].shape)\n",
148 | "\n",
149 | "img = Image.fromarray(input_image[0][0].detach().cpu().numpy()*255)\n",
150 | "\n",
151 | "plt.imshow(img)"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": 4,
157 | "metadata": {},
158 | "outputs": [
159 | {
160 | "name": "stdout",
161 | "output_type": "stream",
162 | "text": [
163 | "Coat\n",
164 | "torch.Size([28, 28])\n"
165 | ]
166 | },
167 | {
168 | "data": {
169 | "text/plain": [
170 | ""
171 | ]
172 | },
173 | "execution_count": 4,
174 | "metadata": {},
175 | "output_type": "execute_result"
176 | },
177 | {
178 | "data": {
179 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAASZUlEQVR4nO3dXWxV15UH8P8KYAg23wTHYAIYkkAYGDpCaKKiURI0iJJEpC+hSFREiuo+tFIr8TBR5qHkoVI0Gtrpw6iSO4lKJ0yqSrQKD4lSSiqRIoXgRIxNQqZOwHxYBlPzZSB8GK958CFyiM9azj333nPM+v8ky9dn3X3P5tiLc+9ZZ+8tqgoiuvvdk3cHiKg6mOxEQTDZiYJgshMFwWQnCmJsNXcmIrz0X4KJEyea8ZqamtTY+PHjzbYiYsbvucc+H4wbN86MX7lyJTV2/fp1s21fX58Zp+Gp6rC/1EzJLiLrAPwCwBgA/6WqL2d5vUry/qi9EqTV3nvtgYEBM+5ZsmSJGX/ggQdSY01NTWZbL1lra2vN+OzZs834+++/nxrr6Ogw277zzjtm3GP9XrL8vkfSvohKfhsvImMA/CeAbwF4BMAmEXmkXB0jovLK8pl9FYBPVfWoqt4A8FsAG8rTLSIqtyzJPgfAySE/n0q2fYmINItIq4i0ZtgXEWVU8Qt0qtoCoAXgBTqiPGU5s3cBmDvk58ZkGxEVUJZkPwjgQRFZICI1AL4DYHd5ukVE5SZZSggish7Af2Cw9Paqqv7Uef6ofRufpYzj2b9/vxm3atUAcOHChdTY2rVrzbZTpkwx452dnWb86tWrZnzq1KmpMa/st2bNGjPe3t5uxq3X7+/vN9uOxtLabRWps6vqmwDezPIaRFQdvF2WKAgmO1EQTHaiIJjsREEw2YmCYLITBVHV8ex5KvKQRWsYKAAsWrTIjDc2NqbGXnvtNbPtgQMHzLhVwweAJ5980ozPmzcvNdbb22u2vXbtmhn3WL/T0VxHLxXP7ERBMNmJgmCyEwXBZCcKgslOFASTnSiIMKW3rKUWq33Wsp433NIbAmvNPuuVxp544gkz7g0jvXXrlhk/efJkamzhwoVmW++4eqzjOnas/afvzQicdcbgPPDMThQEk50oCCY7URBMdqIgmOxEQTDZiYJgshMFkWkq6a+9sxynkq7kEFdvpdTVq1eb8W3btpnxEydOmPEbN26kxo4cOWK29aapXr58uRnv7u424+vXr0+NXb582Wz79ttvm/F3333XjO/Zsyc15g3dLfKQaE/aVNI8sxMFwWQnCoLJThQEk50oCCY7URBMdqIgmOxEQYSps2e1bt261Nj27dvNtm1tbWb83LlzZnz69Olm3PL555+b8c8++8yMjx8/3ow3NDSY8bq6utSYd//A7NmzzfikSZNKjm/dutVse+jQITNe5Dp8RZZsFpFOAH0AbgHoV9WVWV6PiCqnHDPVPK6qfyvD6xBRBfEzO1EQWZNdAfxRRD4QkebhniAizSLSKiKtGfdFRBlkfRu/WlW7RGQWgD0i8omq7hv6BFVtAdACjO4LdESjXaYzu6p2Jd97APwBwKpydIqIyq/kZBeRWhGZdPsxgLUADperY0RUXiXX2UWkCYNnc2Dw48D/qOpPnTaj9m38W2+9lRo7evSo2ba2ttaMX7p0yYxb49U99957rxkfM2aMGffGfXt1eGv/fX19ZltvTnqv71OnTk2NjRs3zmz73HPPmfEiK3udXVWPAvj7kntERFXF0htREEx2oiCY7ERBMNmJgmCyEwURZslmj1dCskpQXhnn7NmzZtwrj3ms8qk3XfOUKVPMuFW+AvzyV0dHR2rM+3fPmDHDjHulu+vXr6fG5s2bZ7a9G/HMThQEk50oCCY7URBMdqIgmOxEQTDZiYJgshMFwTp7wqu73nNP+v+L06ZNM9t60zl7Q1gnTJhgxm/evFnyvq1aNODX2b0hsNZU016d3bs/
YebMmWZ81qxZqbGenh6zrfc7PX/+vBkvIp7ZiYJgshMFwWQnCoLJThQEk50oCCY7URBMdqIgWGdPNDU1mfFr166lxrwllZcuXWrG29vbzfiVK1dKjlv9Bvw6u9feu0fAGmvvjUf37n1YvHixGbeWhL548aLZdtmyZWZ83759ZryIeGYnCoLJThQEk50oCCY7URBMdqIgmOxEQTDZiYJgnT3x8MMPm3GrHi0y7Aq5Xxg71j7MXj25ra3NjHd2dqbGvDHjXp3cq8N79wBcvXo1NdbY2Gi2XbBggRm35hgA7HsEvH/3Qw89ZMbvyjq7iLwqIj0icnjItukiskdEOpLv9kh/IsrdSN7G/xrAuju2vQBgr6o+CGBv8jMRFZib7Kq6D8C5OzZvALAjebwDwDPl7RYRlVupn9nrVbU7eXwaQH3aE0WkGUBzifshojLJfIFOVVVEUkc7qGoLgBYAsJ5HRJVVauntjIg0AEDy3Z6qk4hyV2qy7wawJXm8BcAb5ekOEVWK+zZeRF4H8BiAmSJyCsBPALwM4Hci8jyA4wCerWQnq2HJkiVm/MyZM6kxb41zb7y7Ny67t7fXjJ88eTI1lnVOe2/9dW9teqvOf99995ltvbXlvfsXBgYGUmPefPd34/rtbrKr6qaU0Joy94WIKoi3yxIFwWQnCoLJThQEk50oCCY7URAc4pqwlvcFgOPHj6fGurq6zLbecElvGKr3+jU1Nakxbyrouro6M37r1i0z7pW/rCGw3mt7ZT+vpGm9vldy9IbXjkY8sxMFwWQnCoLJThQEk50oCCY7URBMdqIgmOxEQbDOnqitrTXj1nDLiRMnmm0nTZpkxru7u824N4zUqpV7w0StGv1I4l69Okut26vDe/dGWEOPz527c1rFL1uxYoUZH414ZicKgslOFASTnSgIJjtREEx2oiCY7ERBMNmJgghTZ/fq6N6yy/39/amxCRMmmG29cdnWVNCAX+u2+m5NpwwAqvYiPV7fvVq4dY+AV2f37hHwlmy25gnwXtu7d8K79+HmzZtmPA88sxMFwWQnCoLJThQEk50oCCY7URBMdqIgmOxEQYSps0+ePNmMe/Vkqy67fPlys623bLJX0/Xmlbfq/Nb9AYBfJ/fae8fNugfgxo0bZlvv/oKLFy+a8aVLl6bGvDq5d39CfX29GT916pQZz4N7ZheRV0WkR0QOD9m2TUS6RORQ8rW+st0koqxG8jb+1wDWDbP956q6Ivl6s7zdIqJyc5NdVfcBsOfwIaLCy3KB7oci0pa8zU/9UCoizSLSKiKtGfZFRBmVmuy/BLAQwAoA3QC2pz1RVVtUdaWqrixxX0RUBiUlu6qeUdVbqjoA4FcAVpW3W0RUbiUlu4g0DPnx2wAOpz2XiIrBrbOLyOsAHgMwU0ROAfgJgMdEZAUABdAJ4PuV62J5zJ8/34x79WRrfPKcOXPMtp988okZ99Y49+rRVt+81846nt07blat3Pt3efPt79q1y4xv3LgxNTZ16lSzbW9vrxlvbGw040Wss7vJrqqbhtn8SgX6QkQVxNtliYJgshMFwWQnCoLJThQEk50oiDBDXBctWmTGvWmNrWGk3jTUO3fuNOObN2824960x+PHj0+NeVMae0M9vdJbliGuWYfXeiVNq7T36KOPmm2PHTtmxq1lsouKZ3aiIJjsREEw2YmCYLITBcFkJwqCyU4UBJOdKIgwdfb777/fjHtDGq3hlg0NDakxwJ8K2uMNBbWWLvaGuHrLHnu1bu8eA6sO7+37woULZtwbRnrlypXU2OzZs822Xg3fm/67iHhmJwqCyU4UBJOdKAgmO1EQTHaiIJjsREEw2YmCCFNnnzFjhhn3xrNbtfKDBw+W1KfbvFq2V9O16sneeHZvPLo33t27h8Aas+79u71x/B7ruHhj6b37E7zjVkQ8sxMFwWQnCoLJThQEk50oCCY7URBMdqIgmOxEQYSps3tL9HpLF1vj4U+cOGG29Wqyp0+fNuPW
nPUAMHny5NTY+fPnzbZWLRqw56QfCavO741n9/btHdeenp7UmHffhVdn7+vrM+NF5J7ZRWSuiPxZRD4WkY9E5EfJ9ukiskdEOpLv0yrfXSIq1UjexvcD2KqqjwD4RwA/EJFHALwAYK+qPghgb/IzERWUm+yq2q2qHyaP+wAcATAHwAYAO5Kn7QDwTIX6SERl8LU+s4vIfADfAHAAQL2qdieh0wDqU9o0A2jO0EciKoMRX40XkToAuwD8WFUvDY3p4NWtYa9wqWqLqq5U1ZWZekpEmYwo2UVkHAYTfaeq/j7ZfEZEGpJ4A4D0S59ElDv3bbwMzhX8CoAjqvqzIaHdALYAeDn5/kZFelgmXmnNK0HV1NSkxt577z2z7dNPP23GL168aMa9qaStYaa1tbVmW698de3aNTPuDaG1eENYrem7AWDaNLsAdODAgdTY448/brb1jrlXmiuikfT4mwC+C6BdRA4l217EYJL/TkSeB3AcwLMV6SERlYWb7Kr6FwBpKwGsKW93iKhSeLssURBMdqIgmOxEQTDZiYJgshMFMfqKhSXy6snecMu6urrUmFXPBYCXXnrJjLe2tprxq1evmnGrFu5NmZx1mKk3/Nbqu9fWm8b6qaeeMuPWks7ecfH2bf09FBXP7ERBMNmJgmCyEwXBZCcKgslOFASTnSgIJjtREGHq7N50z96yyFOmTEmN9fb2mm33799vxmfNmmXGvXHf8+fPT41dunQpNQb44/i9+xO8uDUu3JuO2Zti24tbv5fGxkazrff3wDo7ERUWk50oCCY7URBMdqIgmOxEQTDZiYJgshMFEabO7s3Nbs29Dth1VW9s9MaNG8344NT86bw5763lqAcGBsy2Hm/+dK9vVnuvbVarVq1KjS1btizTa2ddyjoPPLMTBcFkJwqCyU4UBJOdKAgmO1EQTHaiIJjsREGMZH32uQB+A6AegAJoUdVfiMg2AN8DcDZ56ouq+malOpqVt5a3N37ZGju9ePFis21PT48Zz1pvvnDhQqb2d6vr16+nxrz7D6y2gL/ufRGN5KaafgBbVfVDEZkE4AMR2ZPEfq6q/1657hFRuYxkffZuAN3J4z4ROQJgTqU7RkTl9bU+s4vIfADfAHB7vaMfikibiLwqIsO+TxaRZhFpFRF7jSMiqqgRJ7uI1AHYBeDHqnoJwC8BLASwAoNn/u3DtVPVFlVdqaors3eXiEo1omQXkXEYTPSdqvp7AFDVM6p6S1UHAPwKQPqoAyLKnZvsMjgk6xUAR1T1Z0O2Nwx52rcBHC5/94ioXEZyNf6bAL4LoF1EDiXbXgSwSURWYLAc1wng+xXoX9k0NDSY8blz55pxa9rj+vr6kvp0m7c8cH9/f6bXH628ob9e+ezYsWMl77upqanktkU1kqvxfwEw3FEvbE2diL6Kd9ARBcFkJwqCyU4UBJOdKAgmO1EQTHaiIKTS0/l+aWci1dvZHazplgFg8+bNZtxa8nnv3r1mW29Z5LFj7Qpo1Dp7JXl1dG/Is1fD937nlaSqw96gwDM7URBMdqIgmOxEQTDZiYJgshMFwWQnCoLJThREtevsZwEcH7JpJoC/Va0DX09R+1bUfgHsW6nK2bd5qnrfcIGqJvtXdi7SWtS56Yrat6L2C2DfSlWtvvFtPFEQTHaiIPJO9pac928pat+K2i+AfStVVfqW62d2IqqevM/sRFQlTHaiIHJJdhFZJyL/JyKfisgLefQhjYh0iki7iBzKe326ZA29HhE5PGTbdBHZIyIdyXd7Lerq9m2biHQlx+6QiKzPqW9zReTPIvKxiHwkIj9Ktud67Ix+VeW4Vf0zu4iMAfBXAP8M4BSAgwA2qerHVe1IChHpBLBSVXO/AUNE/gnAZQC/UdW/S7b9G4Bzqvpy8h/lNFX9l4L0bRuAy3kv452sVtQwdJlxAM8AeA45HjujX8+iCsctjzP7KgCfqupRVb0B4LcANuTQj8JT1X0Azt2xeQOAHcnjHRj8Y6m6
lL4Vgqp2q+qHyeM+ALeXGc/12Bn9qoo8kn0OgJNDfj6FYq33rgD+KCIfiEhz3p0ZRr2qdiePTwPItvZU+bnLeFfTHcuMF+bYlbL8eVa8QPdVq1X1HwB8C8APkrerhaSDn8GKVDsd0TLe1TLMMuNfyPPYlbr8eVZ5JHsXgKGrKDYm2wpBVbuS7z0A/oDiLUV95vYKusn3npz784UiLeM93DLjKMCxy3P58zyS/SCAB0VkgYjUAPgOgN059OMrRKQ2uXACEakFsBbFW4p6N4AtyeMtAN7IsS9fUpRlvNOWGUfOxy735c9VtepfANZj8Ir8ZwD+NY8+pPSrCcD/Jl8f5d03AK9j8G3dTQxe23gewAwAewF0APgTgOkF6tt/A2gH0IbBxGrIqW+rMfgWvQ3AoeRrfd7HzuhXVY4bb5clCoIX6IiCYLITBcFkJwqCyU4UBJOdKAgmO1EQTHaiIP4fnZAvWiCTxoQAAAAASUVORK5CYII=\n",
180 | "text/plain": [
181 | ""
182 | ]
183 | },
184 | "metadata": {
185 | "needs_background": "light"
186 | },
187 | "output_type": "display_data"
188 | }
189 | ],
190 | "source": [
191 | "input_image, target_label = next(iter(train_loader_fashion))\n",
192 | "\n",
193 | "fashion_key = {\n",
194 | " 0: \"T-shirt/top\",\n",
195 | " 1: \"Trouser\",\n",
196 | " 2: \"Pullover\",\n",
197 | " 3: \"Dress\",\n",
198 | " 4: \"Coat\",\n",
199 | " 5: \"Sandal\",\n",
200 | " 6: \"Shirt\",\n",
201 | " 7: \"Sneaker\",\n",
202 | " 8: \"Bag\",\n",
203 | " 9: \"Ankle boot\",\n",
204 | "}\n",
205 | "\n",
206 | "print(fashion_key[int(target_label[0].detach().cpu().numpy())])\n",
207 | "print(input_image[0][0].shape)\n",
208 | "\n",
209 | "img = Image.fromarray(input_image[0][0].detach().cpu().numpy()*255)\n",
210 | "\n",
211 | "plt.imshow(img)"
212 | ]
213 | },
214 | {
215 | "cell_type": "markdown",
216 | "metadata": {},
217 | "source": [
218 | "# Part 3 Baseline results\n",
219 | "\n",
220 | "First we train on MNIST, and then we will observe the drop in performance once we retrain on Fashion-MNIST WITHOUT elastic weight consolidation."
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 3,
226 | "metadata": {},
227 | "outputs": [
228 | {
229 | "name": "stdout",
230 | "output_type": "stream",
231 | "text": [
232 | "The autoreload extension is already loaded. To reload it, use:\n",
233 | " %reload_ext autoreload\n",
234 | "you are using PyTorch version 1.10.0+cu102\n",
235 | "you have 2 GPUs\n",
236 | "cuda:0\n"
237 | ]
238 | }
239 | ],
240 | "source": [
241 | "import math\n",
242 | "import random\n",
243 | "\n",
244 | "import numpy as np\n",
245 | "import torch\n",
246 | "from matplotlib import pyplot as plt\n",
247 | "from matplotlib.pyplot import imshow\n",
248 | "from tqdm.notebook import tqdm\n",
249 | "\n",
250 | "from contlearn.getmodels import MLP\n",
251 | "from contlearn.gettrainer import one_epoch_baseline, test, var2device\n",
252 | "\n",
253 | "%load_ext autoreload\n",
254 | "%autoreload 2\n",
255 | "%matplotlib inline\n",
256 | "\n",
257 | "print('you are using PyTorch version ',torch.__version__)\n",
258 | "\n",
259 | "if torch.cuda.is_available():\n",
260 | " use_cuda = True\n",
261 | " print(\"you have\", torch.cuda.device_count(), \"GPUs\")\n",
262 | " device = torch.device(\"cuda:0\")\n",
263 | " print(device)\n",
264 | "else:\n",
265 | " use_cuda = False\n",
266 | " print('no GPUs detected')\n",
267 | " device = torch.device(\"cpu\")"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": 6,
273 | "metadata": {},
274 | "outputs": [
275 | {
276 | "name": "stdout",
277 | "output_type": "stream",
278 | "text": [
279 | "torch.Size([32, 28, 28])\n",
280 | "torch.Size([32, 10])\n"
281 | ]
282 | }
283 | ],
284 | "source": [
285 | "# initialize a new model\n",
286 | "\n",
287 | "model = MLP(hidden_size=256)\n",
288 | "\n",
289 | "if torch.cuda.is_available() and use_cuda:\n",
290 | " model.cuda()\n",
291 | " \n",
292 | "# push an image through it\n",
293 | "\n",
294 | "input_image, target_label = next(iter(train_loader_fashion))\n",
295 | "\n",
296 | "input_image = var2device(input_image).squeeze(1)\n",
297 | "\n",
298 | "print(input_image.shape)\n",
299 | "\n",
300 | "output = model(input_image)\n",
301 | "\n",
302 | "print(output.shape)"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": 7,
308 | "metadata": {},
309 | "outputs": [],
310 | "source": [
311 | "def baseline_training(\n",
312 | " model, \n",
313 | " epochs, \n",
314 | " train_loader,\n",
315 | " test_loader,\n",
316 | " test2_loader = None,\n",
317 | " use_cuda=True, \n",
318 | "):\n",
319 | " \n",
320 | " \"\"\"\n",
321 | " This function saves the training curve data consisting\n",
322 | " training set loss and validation set accuracy over the\n",
323 | " course of the epochs of training.\n",
324 | " \n",
325 | " I set this up such that if you provide 2 test sets,you\n",
326 | " can watch the test accuracy change together during training\n",
327 | " on train_loder\n",
328 | " \"\"\"\n",
329 | " \n",
330 | " if torch.cuda.is_available() and use_cuda:\n",
331 | " model.cuda()\n",
332 | " \n",
333 | " train_loss, val_acc, val2_acc = [], [], []\n",
334 | " \n",
335 | " for epoch in tqdm(range(epochs)):\n",
336 | "\n",
337 | " epoch_loss = one_epoch_baseline(model,train_loader)\n",
338 | " train_loss.append(epoch_loss)\n",
339 | " \n",
340 | " acc = test(model,test_loader)\n",
341 | " val_acc.append(acc.detach().cpu().numpy())\n",
342 | " \n",
343 | " if test2_loader is not None:\n",
344 | " acc2 = test(model,test2_loader)\n",
345 | " val2_acc.append(acc2.detach().cpu().numpy())\n",
346 | " \n",
347 | " return train_loss, val_acc, val2_acc, model "
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": 10,
353 | "metadata": {},
354 | "outputs": [
355 | {
356 | "data": {
357 | "application/vnd.jupyter.widget-view+json": {
358 | "model_id": "0620996aa0b143caa339803ee51b4aec",
359 | "version_major": 2,
360 | "version_minor": 0
361 | },
362 | "text/plain": [
363 | " 0%| | 0/25 [00:00, ?it/s]"
364 | ]
365 | },
366 | "metadata": {},
367 | "output_type": "display_data"
368 | }
369 | ],
370 | "source": [
371 | "# set seeds for reproducibility and train the model using the training loop called\n",
372 | "# baseline_training\n",
373 | "\n",
374 | "torch.manual_seed(0)\n",
375 | "np.random.seed(0)\n",
376 | "random.seed(0)\n",
377 | "\n",
378 | "train_loss, val_acc, val2_acc, model = baseline_training(\n",
379 | " model,\n",
380 | " epochs = 25,\n",
381 | " train_loader = train_loader_mnist,\n",
382 | " test_loader = test_loader_mnist,\n",
383 | ")\n",
384 | "\n",
385 | "# save the trained model\n",
386 | "model = model.cpu()\n",
387 | "torch.save(model, \"files/basemodel.pth\")"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": 11,
393 | "metadata": {},
394 | "outputs": [
395 | {
396 | "data": {
397 | "image/png": "\n",
398 | "text/plain": [
399 | ""
400 | ]
401 | },
402 | "metadata": {
403 | "needs_background": "light"
404 | },
405 | "output_type": "display_data"
406 | }
407 | ],
408 | "source": [
409 | "plt.figure()\n",
410 | "plt.xlabel('epochs', fontsize=25)\n",
411 | "plt.ylabel('validation accuracy', fontsize=25)\n",
412 | "plt.plot(val_acc, label='mnist')\n",
413 | "plt.legend()\n",
414 | "plt.show()"
415 | ]
416 | },
417 | {
418 | "cell_type": "markdown",
419 | "metadata": {},
420 | "source": [
421 | "### Learning curve\n",
422 | "You should get something like this, where the validation accuracy starts to plateau at around 75% for this simple feed-forward model.\n",
423 | "\n",
424 | ""
425 | ]
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": 26,
430 | "metadata": {},
431 | "outputs": [
432 | {
433 | "name": "stdout",
434 | "output_type": "stream",
435 | "text": [
436 | "mnist accuracy tensor(0.8238)\n",
437 | "fashion accuracy tensor(0.0704)\n"
438 | ]
439 | }
440 | ],
441 | "source": [
442 | "# you can use this to load the model instead of training it from scratch like above \n",
443 | "\n",
444 | "model = torch.load(\"files/basemodel.pth\")\n",
445 | "\n",
446 | "if torch.cuda.is_available() and use_cuda:\n",
447 | "    model.cuda()\n",
448 | " \n",
449 | "# tensor(0.8238, device='cuda:0') means that the test set accuracy was 82.4%\n",
450 | "# compared to a baseline accuracy of 10% if the model is choosing randomly\n",
451 | "print('mnist accuracy', test(model,test_loader_mnist))\n",
452 | "print('fashion accuracy', test(model,test_loader_fashion))"
453 | ]
454 | },
455 | {
456 | "cell_type": "markdown",
457 | "metadata": {},
458 | "source": [
459 | "### Catastrophic Forgetting\n",
460 | "\n",
461 | "This is one of the main problems we are trying to solve in the research field of continual learning, also known as lifelong learning or sequential learning. As you can see, learning Fashion MNIST degrades performance on the original MNIST. In part, this is because some of the parameters (aka weights) that were useful for the original MNIST are overwritten, or updated past the point where they are useful for MNIST, in order to become useful for Fashion MNIST.\n",
462 | "\n",
463 | "We expect something like this, in which the Fashion MNIST accuracy increases from its chance-level performance of around 10% to almost 40%, and the MNIST test accuracy drops from around 80% to below 30%.\n",
464 | "\n",
465 | "\n"
466 | ]
467 | },
468 | {
469 | "cell_type": "code",
470 | "execution_count": 16,
471 | "metadata": {},
472 | "outputs": [
473 | {
474 | "data": {
475 | "application/vnd.jupyter.widget-view+json": {
476 | "model_id": "04ee7766e75c4b2ba21f154dac08758b",
477 | "version_major": 2,
478 | "version_minor": 0
479 | },
480 | "text/plain": [
481 | " 0%| | 0/25 [00:00, ?it/s]"
482 | ]
483 | },
484 | "metadata": {},
485 | "output_type": "display_data"
486 | }
487 | ],
488 | "source": [
489 | "torch.manual_seed(0)\n",
490 | "np.random.seed(0)\n",
491 | "random.seed(0)\n",
492 | "\n",
493 | "train_loss, val_acc, val2_acc, model = baseline_training(\n",
494 | " model,\n",
495 | " epochs = 25,\n",
496 | " train_loader = train_loader_fashion,\n",
497 | " test_loader = test_loader_fashion,\n",
498 | " test2_loader = test_loader_mnist,\n",
499 | ")"
500 | ]
501 | },
502 | {
503 | "cell_type": "code",
504 | "execution_count": 17,
505 | "metadata": {},
506 | "outputs": [
507 | {
508 | "data": {
509 | "image/png": "\n",
510 | "text/plain": [
511 | ""
512 | ]
513 | },
514 | "metadata": {
515 | "needs_background": "light"
516 | },
517 | "output_type": "display_data"
518 | }
519 | ],
520 | "source": [
521 | "plt.figure()\n",
522 | "plt.xlabel('epochs', fontsize=25)\n",
523 | "plt.ylabel('validation accuracy', fontsize=25)\n",
524 | "plt.plot(val_acc, label='fashion')\n",
525 | "plt.plot(val2_acc, label='mnist')\n",
526 | "plt.legend()\n",
527 | "plt.show()"
528 | ]
529 | },
530 | {
531 | "cell_type": "markdown",
532 | "metadata": {},
533 | "source": [
534 | "# Part 4 Elastic Weight Consolidation\n",
535 | "\n",
536 | "Now that we have implemented the control group, let's implement the experimental group.\n",
537 | "\n",
538 | "Instead of revisiting every old MNIST example to build our Fisher Information Matrix, we will visit `num_samples` of them to form an approximation of the matrix.\n",
539 | "\n",
540 | "Please read the comments in the EWC class, which explain each step of the math needed\n",
541 | "to calculate the EWC loss function. Don't worry if it doesn't make sense yet; we will then go through it step by step.\n",
542 | "\n",
543 | "After gathering your examples of task A (MNIST) in the cell below, run the cells that follow it, from the second cell all the way to the one under the heading\n",
544 | "\n",
545 | "#### sum the squares of the gradients\n",
546 | "\n",
547 | "and compare the norm of `precision_matrices` when you use a pretrained model with the norm when you use a randomly initialized model."
548 | ]
549 | },
550 | {
551 | "cell_type": "code",
552 | "execution_count": 4,
553 | "metadata": {},
554 | "outputs": [
555 | {
556 | "name": "stdout",
557 | "output_type": "stream",
558 | "text": [
559 | "num_samples 512\n"
560 | ]
561 | }
562 | ],
563 | "source": [
564 | "# instead of revisiting every old MNIST example to build our Fisher Information Matrix,\n",
565 | "# use num_samples of them to calculate an approximation of the matrix\n",
566 | "\n",
567 | "torch.manual_seed(0)\n",
568 | "np.random.seed(0)\n",
569 | "random.seed(0)\n",
570 | "\n",
571 | "num_batches = 16\n",
572 | "\n",
573 | "old_tasks = []\n",
574 | "batch_iter = iter(train_loader_mnist)  # create the iterator once so every next() yields a new batch\n",
575 | "for sample in range(num_batches):\n",
576 | "    input_batch, target_batch = next(batch_iter)\n",
577 | "    for image in input_batch:\n",
578 | "        old_tasks.append(image)\n",
579 | "\n",
580 | "print(\"num_samples\", len(old_tasks))"
581 | ]
582 | },
583 | {
584 | "cell_type": "code",
585 | "execution_count": 5,
586 | "metadata": {},
587 | "outputs": [
588 | {
589 | "name": "stdout",
590 | "output_type": "stream",
591 | "text": [
592 | "The autoreload extension is already loaded. To reload it, use:\n",
593 | " %reload_ext autoreload\n",
594 | "accuracy on mnist test set 0.8237999677658081\n"
595 | ]
596 | }
597 | ],
598 | "source": [
599 | "from copy import deepcopy\n",
600 | "\n",
601 | "import torch\n",
602 | "from torch import nn\n",
603 | "from torch.nn import functional as F\n",
604 | "from torch.autograd import Variable\n",
605 | "from torch import optim\n",
606 | "import torch.utils.data\n",
607 | "\n",
608 | "from contlearn.getmodels import MLP\n",
609 | "from contlearn.gettrainer import var2device\n",
610 | "\n",
611 | "%load_ext autoreload\n",
612 | "%autoreload 2\n",
613 | "%matplotlib inline\n",
614 | "\n",
615 | "\n",
616 | "# Uncomment one of the two below lines of code that instantiates a model\n",
617 | "### birth a new randomly initiated model ###\n",
618 | "\n",
619 | "# model = MLP(hidden_size=256)\n",
620 | "\n",
621 | "### load a model previously trained on one task, task A ###\n",
622 | "\n",
623 | "model = torch.load(\"files/basemodel.pth\")\n",
624 | "\n",
625 | "#####\n",
626 | "\n",
627 | "\n",
628 | "if torch.cuda.is_available() and use_cuda:\n",
629 | "    model.cuda()\n",
630 | "# tensor(0.8238, device='cuda:0') means that the test set accuracy was 82.4%\n",
631 | "# compared to a baseline accuracy of 10% if the model is choosing randomly\n",
632 | "acc = test(model,test_loader_mnist)\n",
633 | "print(\"accuracy on mnist test set\", acc.item())"
634 | ]
635 | },
636 | {
637 | "cell_type": "code",
638 | "execution_count": 6,
639 | "metadata": {},
640 | "outputs": [],
641 | "source": [
642 | "class EWC(object):\n",
643 | "    \n",
644 | "    \"\"\"\n",
645 | "    Class to calculate the Fisher Information Matrix\n",
646 | "    used in the Elastic Weight Consolidation portion\n",
647 | "    of the loss function\n",
648 | "    \"\"\"\n",
649 | "    \n",
650 | "    def __init__(self, model: nn.Module, dataset: list):\n",
651 | "\n",
652 | "        self.model = model  # pretrained model\n",
653 | "        self.dataset = dataset  # samples from the old task or tasks\n",
654 | "        \n",
655 | "        # n is the string name of the parameter matrix p, aka theta, aka weights\n",
656 | "        # in self.params we reference all of those weights that are open to\n",
657 | "        # being updated by the gradient\n",
658 | "        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}\n",
659 | "        \n",
660 | "        # make a copy of the old weights, ie theta_A,star, ie 𝜃∗A, in the loss equation\n",
661 | "        # we need this to calculate (𝜃 - 𝜃∗A)^2 because self.params will be changing\n",
662 | "        # upon every backward pass and parameter update by the optimizer\n",
663 | "        self._means = {}\n",
664 | "        for n, p in deepcopy(self.params).items():\n",
665 | "            self._means[n] = var2device(p.data)\n",
666 | "        \n",
667 | "        # calculate the diagonal of the fisher information matrix\n",
668 | "        self._precision_matrices = self._diag_fisher()\n",
669 | "\n",
670 | "    def _diag_fisher(self):\n",
671 | "        \n",
672 | "        # save a copy of the zero'd out version of\n",
673 | "        # each layer's parameters of the same shape\n",
674 | "        # to precision_matrices[n]\n",
675 | "        precision_matrices = {}\n",
676 | "        for n, p in deepcopy(self.params).items():\n",
677 | "            p.data.zero_()\n",
678 | "            precision_matrices[n] = var2device(p.data)\n",
679 | "\n",
680 | "        # we need the model to calculate the gradient but\n",
681 | "        # we have no intention in this step to actually update the model\n",
682 | "        # that will have to wait for the combining of this EWC loss term\n",
683 | "        # with the new task's loss term\n",
684 | "        self.model.eval()\n",
685 | "        for input in self.dataset:\n",
686 | "            self.model.zero_grad()\n",
687 | "            # each sample is a single greyscale image of shape (1, h, w), so squeeze(1)\n",
688 | "            # leaves it unchanged and the leading dim of size 1 acts as the batch dim\n",
689 | "            input = input.squeeze(1)\n",
690 | "            input = var2device(input)\n",
691 | "            output = self.model(input).view(1, -1)\n",
692 | "            label = output.max(1)[1].view(-1)\n",
693 | "            # calculate loss (against the model's own prediction) and backprop\n",
694 | "            loss = F.nll_loss(F.log_softmax(output, dim=1), label)\n",
695 | "            loss.backward()\n",
696 | "\n",
697 | "            for n, p in self.model.named_parameters():\n",
698 | "                precision_matrices[n].data += p.grad.data ** 2 / len(self.dataset)\n",
699 | "\n",
700 | "        precision_matrices = {n: p for n, p in precision_matrices.items()}\n",
701 | "        return precision_matrices\n",
702 | "\n",
703 | "    def penalty(self, model: nn.Module):\n",
704 | "        loss = 0\n",
705 | "        for n, p in model.named_parameters():\n",
706 | "            _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2\n",
707 | "            loss += _loss.sum()\n",
708 | "        return loss\n",
709 | "    \n",
710 | "# place the model pretrained on mnist, but not fashion-mnist, along with some mnist examples\n",
711 | "# into the Elastic Weight Consolidation object to perform EWC related tasks like calculating the\n",
712 | "# Fisher Matrix\n",
713 | "ewc = EWC(model, old_tasks)"
714 | ]
715 | },
716 | {
717 | "cell_type": "markdown",
718 | "metadata": {},
719 | "source": [
720 | "### Empirical Estimate of the Fisher Information Matrix\n",
721 | "\n",
722 | "Let's work through the `_diag_fisher()` method, which calculates the Fisher matrix, together, step by step.\n",
723 | "\n",
727 | "As a reminder, here is the equation for the Fisher matrix one more time:\n",
728 | "\n",
729 | "$$\n",
730 | "I_{\\theta_\\mathcal{A}^*} = \\frac{1}{N} \\sum_{i=1}^{N} \\nabla_\\theta log \\ p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*) \\nabla_\\theta log \\ p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*)^T\n",
731 | "$$\n",
732 | "\n",
733 | "To my understanding, the equation says: sum $N$ of these matrices element-wise and divide by $N$, so that the resulting matrix has the same shape as each individual matrix. The gradient term $\\nabla_\\theta log \\ p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*)$ is a vector, and the expression $ \\nabla_\\theta log \\ p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*) \\nabla_\\theta log \\ p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*)^T$ is the [outer product](https://en.wikipedia.org/wiki/Outer_product) of this vector with itself, which produces a matrix. The elements of that matrix we are interested in are the diagonal entries $I_{ii}$ (written $F_{ii}$ or $F_i$ below), not the off-diagonal entries $I_{ij}$. \n",
734 | "\n",
735 | "Note: does this mean that, in addition to penalizing moving $\\theta_i$ when $F_{ii}$ is high, we could also penalize moving both $\\theta_i$ and $\\theta_j$ together when $F_{ij}$ is high?\n",
736 | "\n",
737 | "### The Fisher Information Matrix is w.r.t the log probability of the NN prediction p(y|x,theta), rather than w.r.t the log likelihood of the data p(x|theta)\n",
738 | "\n",
739 | "Notice that the term grad log p(D|theta),\n",
740 | "\n",
741 | "$$\\nabla_\\theta log \\ p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*)$$\n",
742 | "\n",
743 | "looks a lot like the formula for the gradient of the loss function with respect to the parameters.\n",
744 | "\n",
745 | "That's because, for the purposes of EWC, it is (up to a sign, which the squaring removes)!\n",
746 | "\n",
747 | "from the [Overcoming catastrophic forgetting in neural networks paper](https://www.pnas.org/content/114/13/3521):\n",
748 | "\n",
749 | "\"\"\"\n",
750 | "Note that the log probability of the data given the parameters logp(D|θ) is simply the negative of the loss function for the problem at hand −L(θ). \n",
751 | "\"\"\"\n",
752 | "\n",
753 | "The key mental leap we need to make is that in our machine learning model, p(D|θ) is the probability of the target label given the input and the parameters, p(y|x,θ), rather than the p(X|θ) we are used to seeing, which represents the likelihood of observing X given θ. \n",
754 | "\n",
755 | "Had the neural network ended in a softmax layer, the term $p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*)$ would be the softmax output, but in the code the softmax is combined with the log using the `F.log_softmax` function. \n",
756 | "\n",
757 | "$ log \\ p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*)$ is the log likelihood. In the code below this is:\n",
758 | "\n",
759 | "```\n",
760 | "F.log_softmax(output, dim=1)\n",
761 | "```\n",
762 | "\n",
763 | "from [the pytorch docs](https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html#torch.nn.NLLLoss)\n",
764 | "\n",
765 | "\"\"\"\n",
766 | "Obtaining log-probabilities in a neural network is easily achieved by adding a LogSoftmax layer in the last layer of your network. You may use CrossEntropyLoss instead, if you prefer not to add an extra layer.\n",
767 | "\"\"\"\n",
768 | "\n",
769 | "Meaning that the full line of code:\n",
770 | "\n",
771 | "```\n",
772 | "F.nll_loss(F.log_softmax(output, dim=1), label)\n",
773 | "```\n",
774 | "\n",
775 | "is essentially the cross entropy loss"
776 | ]
777 | },
778 | {
779 | "cell_type": "code",
780 | "execution_count": 8,
781 | "metadata": {},
782 | "outputs": [
783 | {
784 | "name": "stdout",
785 | "output_type": "stream",
786 | "text": [
787 | "name of this layer's weights: fc1.weight\n",
788 | "shape of this matrix W torch.Size([256, 784])\n",
789 | "norm of the matrix 0.0\n"
790 | ]
791 | }
792 | ],
793 | "source": [
794 | "precision_matrices = {}\n",
795 | "for n, p in deepcopy(ewc.params).items():\n",
796 | "    \n",
797 | "    print(\"name of this layer's weights:\", n)\n",
798 | "    p.data.zero_()\n",
799 | "    precision_matrices[n] = var2device(p.data)\n",
800 | "    print(\"shape of this matrix W\", precision_matrices[n].shape)\n",
801 | "    #print(precision_matrices[n])\n",
802 | "    print(\"norm of the matrix\",torch.norm(precision_matrices[n]).item())\n",
803 | "    break"
804 | ]
805 | },
806 | {
807 | "cell_type": "markdown",
808 | "metadata": {},
809 | "source": [
810 | "#### the code above\n",
811 | "As a reminder, your input x has shape (batch_size, 784) because 28*28 = 784. \n",
812 | "\n",
813 | "The matrix W of this first linear (aka affine, aka fully connected) layer has shape\n",
814 | "(256, 784) because xW^T = a has shape (batch_size, 256), which is the shape of the first layer's activations.\n",
815 | "\n",
816 | "#### the code below\n",
817 | "We pass just one image through our model and examine the tensors being created\n",
818 | "as they are applied to just the first layer of parameters. Of course, the\n",
819 | "real thing accumulates this update through all the samples and all the layers."
820 | ]
821 | },
822 | {
823 | "cell_type": "code",
824 | "execution_count": 9,
825 | "metadata": {},
826 | "outputs": [
827 | {
828 | "name": "stdout",
829 | "output_type": "stream",
830 | "text": [
831 | "input.shape torch.Size([1, 28, 28])\n",
832 | "[0.0, 0.0, 4.682543, 0.0, 4.328751, 0.0, 7.281813, 0.0, 0.0, 0.0]\n",
833 | "predicted number tensor([6], device='cuda:0')\n",
834 | "loss 0.12338782846927643\n"
835 | ]
836 | },
837 | {
838 | "data": {
839 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAOGElEQVR4nO3dbahdZXrG8euaOELMaNSEHo9OfIkkwlitlvgChmIdZ7D6IRmEQaON0tGMkMCohVYsZoK1IK2jH4XjC6YyGgaMVYbSGZVJbVAHo8YYTY0xJGhykqARjMkHm3j3w1kZjnr2s497rf0S7/8PDmfvde+11s02l2vt9ex1HkeEAHz7faffDQDoDcIOJEHYgSQIO5AEYQeSOKqXO7PNpX+gyyLCEy2vdWS3fYXtd21vsX1HnW0B6C53Os5ue4qkzZJ+JOlDSa9KujYi3imsw5Ed6LJuHNkvlLQlIrZGxOeSVklaUGN7ALqoTthPkfTBuOcfVsu+xPYS2+tsr6uxLwA1df0CXUSMSBqROI0H+qnOkX2HpFnjnn+/WgZgANUJ+6uS5tg+w/bRkq6R9GwzbQFoWsen8RFx0PYySb+TNEXSoxHxdmOdAWhUx0NvHe2Mz+xA13XlSzUAjhyEHUiCsANJEHYgCcIOJEHYgSR6ej97VlOmTCnWV69eXazPnTu3WL/gggta1j777LPiusiDIzuQBGEHkiDsQBKEHUiCsANJEHYgCYbeeuCoo8pv83HHHVesz5kzp1ifOnVqyxpDbziMIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4ew9cdNFFxfq5557bo06QGUd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfYeWLFiRbE+ffr0Yn3Dhg3F+v79+79pS0ioVthtb5O0T9IhSQcjYl4TTQFoXhNH9r+OiI8a2A6ALuIzO5BE3bCHpN/bfs32koleYHuJ7XW219XcF4Aa6p7Gz4+IHbb/TNJztv83Il4c/4KIGJE0Ikm2o+b+AHSo1pE9InZUv/dIelrShU00BaB5HYfd9jTbxx5+LOnHkjY21RiAZtU5jR+S9LTtw9t5IiL+q5GujjCXXXZZsX7JJZfU2v4DDzxQrB84cKDW9pFDx2GPiK2S/qLBXgB0EUNvQBKEHUiCsANJEHYgCcIOJMEtrg2YPXt2sT5lypRa29+8eXOt9QGJIzuQBmEHkiDsQBKEHUiCsANJEHYgCcIOJME4+wDYuXNnrTowGRzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtkbMGPGjFrrf/DBB7Xqg2z+/Pkta4sXL6617V27dhXrDz30UMvakfyedoojO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwTh7AxYtWlRr/a1btzbUSe8tXbq0WF++fHnLWt3vJ7Rz0003taw98sgjxXXvuuuuptvpu7ZHdtuP2t5je+O4ZSfafs72e9XvE7rbJoC6JnMa/5ikK76y7A5JL0TEHEkvVM8BDLC2YY+IFyXt/criBZJWVo9XSlrYbFsAmtbpZ/ahiBitHu+SNNTqhbaXSFrS4X4ANKT2BbqICNtRqI9IGpGk0usAdFenQ2+7bQ9LUvV7T3MtAeiGTsP+rKQbqsc3SHqmmXYAdIsjymfWtp+UdKmkmZJ2S/qlpP+Q9BtJp0raLumnEfHVi3gTbetbeRr/5ptvFutnn312sX799dcX66tWrfrGPfXKJ598Uqwfe+yxHW/7pZdeKtZnzpxZrM+dO7dlbf/+/cV1p0+fXqwPsojwRMvbfmaPiGtblH5YqyMAPcXXZYEkCDuQBGEHkiDsQBKEHUiCW1wnqTSMc/LJJxfXbTfM88Ybb3TUUxOmTZtWrD/xxBPF+vHHH1+sv/vuuy1ry5YtK677/PPPF+vDw8PFemno7tRTTy2ue/vttxfr999/f7E+iDiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAS
jLNP0qxZs1rWTjih/Md19+wp/22P0lh0t914443F+lVXXVWsHzhwoFi/++67W9bajaO3Mzo62nG99N9Taj+dNOPsAAYWYQeSIOxAEoQdSIKwA0kQdiAJwg4kwTh7cgsXLqy1/n333Vest7sfvo52U2Wfc845HW+7m333C0d2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfZJWrNmTcvapk2biuuedtppxfrFF19crL/yyivFeslJJ51UrM+ePbvjbUvS3r1tZ+ru2He+Uz4WXX311cX61KlTW9ba3Ut/JN6v3k7bI7vtR23vsb1x3LIVtnfYXl/9XNndNgHUNZnT+MckXTHB8gci4rzq5z+bbQtA09qGPSJelNS9czUAPVHnAt0y2xuq0/yWf4TN9hLb62yvq7EvADV1GvYHJZ0p6TxJo5J+1eqFETESEfMiYl6H+wLQgI7CHhG7I+JQRHwh6SFJFzbbFoCmdRR22+Pnyv2JpI2tXgtgMLQdZ7f9pKRLJc20/aGkX0q61PZ5kkLSNkk/716Lg+HQoUMta+vXry+ue9ZZZxXrzzzzTLHebjx57dq1LWvt5l9v9zfv2xkaGqq1fkm7e+UXLFhQrG/evLll7c477yyue/DgwWL9SNQ27BFx7QSLH+lCLwC6iK/LAkkQdiAJwg4kQdiBJAg7kIQjonc7s3u3swHy6aefFuvHHHNMsf7yyy8X60uXLm1Z27BhQ3Hdm2++uVh/8MEHi/XPP/+8WL/nnnta1kq3oErtp5MeHh4u1rdv396ydv755xfXbfffbJBFhCdazpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0HrrnmmmL9scceK9aPOqp8c+LHH3/csnbGGWcU1y3duitJjz/+eLHe7vbbbv772rlzZ7F++eWXt6yVbn890jHODiRH2IEkCDuQBGEHkiDsQBKEHUiCsANJMGVzD6xatapYtyccFv2TkZGRYn3GjBkta6V7uiVp8eLFxfr7779frHfTtm3bivXbbrutWP82j6V3giM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTB/exHgEWLFhXrDz/8cMva0Ucf3XQ7X9LuOwKlf19r1qwprnvLLbcU61u2bCnWs+r4fnbbs2z/wfY7tt+2/Ytq+Ym2n7P9XvW73kTfALpqMqfxByX9fUT8QNLFkpba/oGkOyS9EBFzJL1QPQcwoNqGPSJGI+L16vE+SZsknSJpgaSV1ctWSlrYpR4BNOAbfTfe9umSzpf0R0lDETFalXZJGmqxzhJJS2r0CKABk74ab/t7kp6SdGtEfGnWuxi7CjPhlZiIGImIeRExr1anAGqZVNhtf1djQf91RKyuFu+2PVzVhyXt6U6LAJrQdujNY2MrKyXtjYhbxy3/N0kfR8S9tu+QdGJE/EObbTH01gVnnnlmy9ry5cuL61533XW19r1v375ivfSnpteuXVtct9100JhYq6G3yXxmv0TS30p6y/b6atmdku6V9BvbP5O0XdJPG+gTQJe0DXtErJXU6psTP2y2HQDdwtdlgSQIO5AEYQeSIOxAEoQdSIJbXIFvGaZsBpIj7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJNqG3fYs23+w/Y7tt23/olq+wvYO2+urnyu73y6ATrWdJML2sKThiHjd9rGSXpO0UGPzsX8WEfdNemdMEgF0XatJIiYzP/uopNHq8T7bmySd0mx7ALrtG31mt326pPMl/bFatMz2BtuP2j6hxTpLbK+zva5eqwDqmPRcb7a/J+m/Jf1LRKy2PSTpI0kh6Z81dqr/d222wWk80GWtTuMnFXbb35X0W0m/i4j7J6ifLum3EfHnbbZD2IEu63hiR9uW9IikTeODXl24O+wnkjbWbRJA90zmavx8Sf8j6S1JX1SL75R0raTz
NHYav03Sz6uLeaVtcWQHuqzWaXxTCDvQfczPDiRH2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLtH5xs2EeSto97PrNaNogGtbdB7Uuit0412dtprQo9vZ/9azu310XEvL41UDCovQ1qXxK9dapXvXEaDyRB2IEk+h32kT7vv2RQexvUviR661RPeuvrZ3YAvdPvIzuAHiHsQBJ9CbvtK2y/a3uL7Tv60UMrtrfZfquahrqv89NVc+jtsb1x3LITbT9n+73q94Rz7PWpt4GYxrswzXhf37t+T3/e88/stqdI2izpR5I+lPSqpGsj4p2eNtKC7W2S5kVE37+AYfuvJH0m6d8PT61l+18l7Y2Ie6v/UZ4QEf84IL2t0DecxrtLvbWaZvxG9fG9a3L6807048h+oaQtEbE1Ij6XtErSgj70MfAi4kVJe7+yeIGkldXjlRr7x9JzLXobCBExGhGvV4/3STo8zXhf37tCXz3Rj7CfIumDcc8/1GDN9x6Sfm/7NdtL+t3MBIbGTbO1S9JQP5uZQNtpvHvpK9OMD8x718n053Vxge7r5kfEX0r6G0lLq9PVgRRjn8EGaez0QUlnamwOwFFJv+pnM9U0409JujUiPh1f6+d7N0FfPXnf+hH2HZJmjXv+/WrZQIiIHdXvPZKe1tjHjkGy+/AMutXvPX3u508iYndEHIqILyQ9pD6+d9U0409J+nVErK4W9/29m6ivXr1v/Qj7q5Lm2D7D9tGSrpH0bB/6+Brb06oLJ7I9TdKPNXhTUT8r6Ybq8Q2SnuljL18yKNN4t5pmXH1+7/o+/XlE9PxH0pUauyL/vqR/6kcPLfqaLenN6uftfvcm6UmNndb9n8aubfxM0gxJL0h6T9Lzkk4coN4e19jU3hs0FqzhPvU2X2On6Bskra9+ruz3e1foqyfvG1+XBZLgAh2QBGEHkiDsQBKEHUiCsANJEHYgCcIOJPH/hbRq7dS3XKkAAAAASUVORK5CYII=\n",
840 | "text/plain": [
841 | ""
842 | ]
843 | },
844 | "metadata": {
845 | "needs_background": "light"
846 | },
847 | "output_type": "display_data"
848 | }
849 | ],
850 | "source": [
851 | "ewc.model.eval()\n",
852 | "for input in ewc.dataset:\n",
853 | "    ewc.model.zero_grad()\n",
854 | "    # each sample is a single greyscale image of shape (1, h, w), so squeeze(1)\n",
855 | "    # leaves it unchanged and the leading dim of size 1 acts as the batch dim\n",
856 | "    input = input.squeeze(1)\n",
857 | "    input = var2device(input)\n",
858 | "    output = ewc.model(input).view(1, -1)\n",
859 | "    label = output.max(1)[1].view(-1)\n",
860 | "    loss = F.nll_loss(F.log_softmax(output, dim=1), label)\n",
861 | "    loss.backward()\n",
862 | "    break\n",
863 | "\n",
864 | "print(\"input.shape\", input.shape)\n",
865 | "img = Image.fromarray(input[0].detach().cpu().numpy()*255)\n",
866 | "plt.imshow(img)\n",
867 | "\n",
868 | "print(list(output.detach().cpu().numpy()[0]))\n",
869 | "print(\"predicted number\",label)\n",
870 | "print(\"loss\",loss.item())"
871 | ]
872 | },
873 | {
874 | "cell_type": "markdown",
875 | "metadata": {},
876 | "source": [
877 | "### sum the squares of the gradients\n",
878 | "\n",
879 | "The cell below is the last one you should run when comparing the [norm](https://pytorch.org/docs/stable/generated/torch.norm.html) of the matrix for a pretrained model against the norm for a randomly initialized model.\n",
880 | "\n",
881 | "Notice that when you run the above steps with a pretrained model, the norm is larger than when you run them with a randomly initialized model. Why is this? "
882 | ]
883 | },
884 | {
885 | "cell_type": "code",
886 | "execution_count": 11,
887 | "metadata": {},
888 | "outputs": [
889 | {
890 | "name": "stdout",
891 | "output_type": "stream",
892 | "text": [
893 | "name of this layer's weights: fc1.weight\n",
894 | "shape of p.grad.data torch.Size([256, 784])\n",
895 | "shape of p.grad.data ** 2 torch.Size([256, 784])\n",
896 | "norm of the matrix 0.0003205974353477359\n"
897 | ]
898 | }
899 | ],
900 | "source": [
901 | "for n, p in ewc.params.items():\n",
902 | "    print(\"name of this layer's weights:\", n)\n",
903 | "    print(\"shape of p.grad.data\", p.grad.data.shape)\n",
904 | "    print(\"shape of p.grad.data ** 2\", (p.grad.data ** 2).shape)\n",
905 | "    precision_matrices[n].data += p.grad.data ** 2 / len(ewc.dataset)\n",
906 | "    #print(precision_matrices[n])\n",
907 | "    print(\"norm of the matrix\",torch.norm(precision_matrices[n]).item())\n",
908 | "    break"
909 | ]
910 | },
911 | {
912 | "cell_type": "markdown",
913 | "metadata": {},
914 | "source": [
915 | "### The Math in the Code\n",
916 | "\n",
917 | "One clarification needs to be made, which was brought to my attention on [Reddit](https://www.reddit.com/r/MachineLearning/comments/t2riby/p_i_made_the_kind_of_tutorial_i_wish_someone_had/?utm_source=share&utm_medium=web2x&context=3): the equation for the empirical Fisher below suggests we are calculating the full Fisher matrix:\n",
918 | "\n",
919 | "$$\n",
920 | "I_{\\theta_\\mathcal{A}^*} = \\frac{1}{N} \\sum_{i=1}^{N} \\nabla_\\theta log \\ p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*) \\nabla_\\theta log \\ p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*)^T\n",
921 | "$$\n",
922 | "\n",
923 | "But in the code we are not. This line is where the diagonal terms are calculated:\n",
924 | "\n",
925 | "```\n",
926 | "p.grad.data ** 2\n",
927 | "```\n",
928 | "\n",
929 | "The above line of code corresponds only to the diagonal entries of the matrix $I_{\\theta_\\mathcal{A}^*}$ above.\n",
930 | "\n",
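"Written out per parameter, the diagonal entry that this line accumulates is shown below. This is only a restatement of the equation above, under the assumption that $F_{i}$ is shorthand for the diagonal entry $I_{ii}$, with $k$ indexing the $N$ samples and $i$ indexing a single parameter:\n",
"\n",
"$$\n",
"F_{i} = \\frac{1}{N} \\sum_{k=1}^{N} \\left( \\frac{\\partial}{\\partial \\theta_{i}} log \\ p(x_{\\mathcal{A}, k}|\\theta_\\mathcal{A}^*) \\right)^{2}\n",
"$$\n",
"\n",
"The $(i, i)$ entry of an outer product $g g^T$ is simply $g_i^2$, which is why squaring each gradient element is enough.\n",
"\n",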
931 | "Notice that `p.grad.data ** 2` has the same shape as the W for each layer. That's because `** 2` is an element-wise operation that squares every element, so each entry is the gradient of the loss function w.r.t. just one parameter (aka weight), multiplied by itself. If we took every layer's `p.grad.data ** 2` and flattened them all into one very long vector, that vector would be just the diagonal of the Fisher matrix we see in the above equation. So we have only calculated the $F_{ii}$ terms of the matrix and none of the $F_{ij}$ terms. \n",
932 | "\n",
933 | "Note: since every entry is a square, the diagonal must be non-negative.\n",
934 | "\n",
935 | "The Fisher diagonals are summed over all the N examples and divided by N, i.e. we average them to get the mean (the expectation), which is why in the actual EWC class the code is written\n",
936 | "\n",
937 | "```\n",
938 | "+= p.grad.data ** 2 / len(self.dataset)\n",
939 | "```\n",
940 | "\n",
941 | "and placed in the inner loop of:\n",
942 | "\n",
943 | "```\n",
944 | "for input in self.dataset:\n",
945 | "```\n",
946 | "\n",
947 | "After the code accumulates over all N examples, it completes the rest of the equation (restricted to the diagonal):\n",
948 | "\n",
949 | "$$\n",
950 | "\\frac{1}{N} \\sum_{i=1}^{N} \\nabla_\\theta log \\ p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*) \\nabla_\\theta log \\ p(x_{\\mathcal{A}, i}|\\theta_\\mathcal{A}^*)^T\n",
951 | "$$\n",
952 | "\n",
953 | "\n",
954 | "### using the F_i in the overall loss function\n",
955 | "\n",
956 | "This line of code in the `penalty()` method of the EWC class\n",
957 | "\n",
958 | "```\n",
959 | "_loss = self._precision_matrices[n] * (p - self._means[n]) ** 2\n",
960 | "```\n",
961 | "\n",
962 | "corresponds to this formula\n",
963 | "\n",
964 | "$$ F_{i} \\left(\\theta_{i} - \\theta_{A , i}^{*}\\right)^{2} $$\n",
965 | "\n",
966 | "where `self._precision_matrices[n]` = $ F_{i}$ and \n",
967 | "`(p - self._means[n]) ** 2` = $\\left(\\theta_{i} - \\theta_{A , i}^{*}\\right)^{2} $\n",
968 | "\n",
969 | "This is then summed over all parameters in the layer by `loss += _loss.sum()`\n",
970 | "and repeated for each layer, which is why the `penalty()` method places this line of code inside the loop `for n, p in model.named_parameters():`\n",
971 | "\n",
972 | "Later, in the code for EWC training, this penalty is scaled by $\\lambda$ = importance and added to the loss for task B, `F.cross_entropy(output, target)`, \n",
973 | "\n",
974 | "giving the full line of code:\n",
975 | "\n",
976 | "`loss = F.cross_entropy(output, target) + importance * ewc.penalty(model)`\n",
977 | "\n",
978 | "which is the equivalent of our EWC loss function, with the constant factor of $\\frac{1}{2}$ absorbed into `importance`:\n",
979 | "\n",
980 | "$$\n",
981 | "L \\left(\\right. \\theta \\left.\\right) = L_{B} \\left(\\right. \\theta \\left.\\right) + \\underset{i}{\\sum} \\frac{\\lambda}{2} F_{i} \\left(\\theta_{i} - \\theta_{A , i}^{*}\\right)^{2} $$\n",
982 | "\n",
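"As a small worked step that follows directly from the formula above, differentiating the penalty term with respect to a single parameter gives\n",
"\n",
"$$\n",
"\\frac{\\partial}{\\partial \\theta_{i}} \\left[ \\frac{\\lambda}{2} F_{i} \\left(\\theta_{i} - \\theta_{A , i}^{*}\\right)^{2} \\right] = \\lambda F_{i} \\left(\\theta_{i} - \\theta_{A , i}^{*}\\right)\n",
"$$\n",
"\n",
"so each parameter is pulled back toward its task A value with a strength proportional to $F_{i}$: parameters with a large Fisher value are held nearly fixed, while parameters with $F_{i}$ close to zero are free to move for task B.\n",
"\n",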
983 | "### Explanation\n",
984 | "\n",
985 | "For the full proof read [Fisher Information Matrix by Agustinus Kristiadi](https://agustinus.kristia.de/techblog/2018/03/11/fisher-information/) or [Fisher Information Matrix by Yuan-Hong Liao](https://andrewliao11.github.io/blog/fisher-info-matrix/). \n",
986 | "\n",
987 | "The end result of the proof is that \"The Fisher is the negative expectation of the Hessian of the log-likelihood\"\n",
988 | "\n",
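"In symbols, and assuming the usual regularity conditions with the expectation taken over data drawn from the model itself, the statement reads:\n",
"\n",
"$$\n",
"F = \\mathbb{E} \\left[ \\nabla_\\theta log \\ p(x|\\theta) \\ \\nabla_\\theta log \\ p(x|\\theta)^T \\right] = - \\mathbb{E} \\left[ \\nabla_\\theta^2 \\ log \\ p(x|\\theta) \\right]\n",
"$$\n",
"\n",
"The left-hand form is what the code estimates from samples; the right-hand form is the Hessian view used in the rest of this explanation.\n",
"\n",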
989 | "The Hessian is a square matrix of second-order partial derivatives of a scalar-valued function, or scalar field. It describes the local curvature of a function of many variables. So the diagonal of the Hessian holds the 2nd derivatives of the loss (negative log likelihood) with respect to each parameter. Imagine you only had 2 parameters; the Fisher information over a patch of values for $\\theta_i$ and $\\theta_j$ might look like this:\n",
990 | "\n",
991 | "Hypothetical $F_{ij}$ surface (figure omitted); ignore the vertical axis label for now.\n",
992 | "\n",
995 | "\n",
996 | "The warmer (more red) areas are the areas of higher curvature, and in this picture the areas of higher curvature are the areas around local minima and maxima. If you turn the above figure upside down, it would look just like the surface (aka topological manifold) of the loss function. And what is gradient descent trying to do? It is trying to find a local minimum on this loss surface, aka a local maximum on the Fisher information surface. \n",
997 | "\n",
998 | "In intuitive terms, both the Hessian and Fisher Matrix describe how sharp, ie curved, each point in $\\theta$ space is, with respect to the loss.\n",
999 | "\n",
1000 | "But wait a minute, can't we have a valley with a very flat and wide floor? Yes, we can, and it may well be that some of our parameters have settled into such a region while training on task A. However, if that is the case, then while $\\theta$ is in the center of such a region, it can move in various directions and still contribute to good performance on task A. That means we should not penalize changes in this region until $\\theta$ reaches the edge of the region, at which point the Fisher information should start to increase. \n",
1001 | "\n",
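1002 | "Basically, the Fisher matrix is the negative expectation of the matrix below (the Hessian of a function f), where you replace f with log[p(x|theta)] and replace x with theta:\n",
1003 | "\n",
1004 | "$$ H_f = \\begin{bmatrix} \\frac{\\partial^2 f}{\\partial x_1^2} & \\cdots & \\frac{\\partial^2 f}{\\partial x_1 \\partial x_n} \\\\\\\\ \\vdots & \\ddots & \\vdots \\\\\\\\ \\frac{\\partial^2 f}{\\partial x_n \\partial x_1} & \\cdots & \\frac{\\partial^2 f}{\\partial x_n^2} \\end{bmatrix} $$\n",
1005 | "\n",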
1006 | "\n",
1007 | "*The $F_i$ is higher for those $\\theta_i$ that are already optimized into a narrow region and thus if changed, would cause an increase in the loss for task A, thereby reducing the performance on task A described by that loss. More of these \"important and just right\" parameters will exist for trained models and thus the norm on the fisher diagonals will be higher. For models not yet trained on task A, there are fewer parameters that have been optimized to help solve task A, therefore fewer parameters that when changed will have a significant effect on the loss for task A*"
1008 | ]
1009 | },
1010 | {
1011 | "cell_type": "code",
1012 | "execution_count": 12,
1013 | "metadata": {},
1014 | "outputs": [
1015 | {
1016 | "data": {
1017 | "text/plain": [
1018 | "tensor(0., grad_fn=)"
1019 | ]
1020 | },
1021 | "execution_count": 12,
1022 | "metadata": {},
1023 | "output_type": "execute_result"
1024 | }
1025 | ],
1026 | "source": [
1027 | "# since the parameters have not changed yet, we expect (𝜃 - 𝜃∗)^2 to be zero throughout\n",
1028 | "ewc.penalty(model)"
1029 | ]
1030 | },
1031 | {
1032 | "cell_type": "markdown",
1033 | "metadata": {},
1034 | "source": [
1035 | "### EWC Training\n",
1036 | "\n",
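"As a sketch of what the training loop in this section minimizes, assuming `ewc.penalty` has the standard diagonal EWC form (an assumption here, since the class is defined earlier in the notebook), with $\\lambda$ standing for the `importance` argument and $\\theta^*$ for the parameters after MNIST training:\n",
"\n",
"$$L(\\theta) = L_{\\text{Fashion}}(\\theta) + \\lambda \\sum_i F_i \\, (\\theta_i - \\theta_i^*)^2$$\n",
"\n",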
1037 | "Load a model pretrained on MNIST, then train it on Fashion-MNIST while observing the effect of the new training on MNIST performance."
1038 | ]
1039 | },
1040 | {
1041 | "cell_type": "code",
1042 | "execution_count": 20,
1043 | "metadata": {},
1044 | "outputs": [
1045 | {
1046 | "name": "stdout",
1047 | "output_type": "stream",
1048 | "text": [
1049 | "accuracy on mnist test set 0.8238000273704529\n"
1050 | ]
1051 | }
1052 | ],
1053 | "source": [
1054 | "# verify that the model you are about to retrain is indeed pretrained\n",
1055 | "acc = test(model,test_loader_mnist)\n",
1056 | "print(\"accuracy on mnist test set\", acc.item())"
1057 | ]
1058 | },
1059 | {
1060 | "cell_type": "code",
1061 | "execution_count": 21,
1062 | "metadata": {},
1063 | "outputs": [],
1064 | "source": [
1065 | "def one_epoch_ewc(\n",
1066 | "    ewc: EWC,\n",
1067 | "    importance: float,\n",
1068 | "    model: nn.Module,\n",
1069 | "    data_loader: torch.utils.data.DataLoader,\n",
1070 | "    lr: float = 1e-3,\n",
1071 | "):\n",
1072 | "    model.train()\n",
1073 | "    epoch_loss = 0.0\n",
1074 | "    optimizer = optim.SGD(params=model.parameters(), lr=lr)\n",
1075 | "    for inputs, target in data_loader:\n",
1076 | "        # drop the channel dim: bs,1,h,w -> bs,h,w\n",
1077 | "        inputs = inputs.squeeze(1)\n",
1078 | "        inputs, target = var2device(inputs), var2device(target)\n",
1079 | "        optimizer.zero_grad()\n",
1080 | "        output = model(inputs)\n",
1081 | "        # loss on the new task plus the EWC penalty that protects the task A parameters\n",
1082 | "        loss = F.cross_entropy(output, target) + importance * ewc.penalty(model)\n",
1083 | "        epoch_loss += loss.item()\n",
1084 | "        loss.backward()\n",
1085 | "        optimizer.step()\n",
1086 | "    return epoch_loss / len(data_loader)\n",
1087 | "\n",
1088 | "def ewc_training(\n",
1089 | "    ewc,\n",
1090 | "    importance,\n",
1091 | "    model,\n",
1092 | "    epochs,\n",
1093 | "    train_loader,\n",
1094 | "    test_loader,\n",
1095 | "    test2_loader=None,\n",
1096 | "    use_cuda=True,\n",
1097 | "):\n",
1098 | "\n",
1099 | "    \"\"\"\n",
1100 | "    Train with the EWC penalty and return the training-curve\n",
1101 | "    data: training set loss and validation set accuracy for\n",
1102 | "    each epoch, along with the trained model.\n",
1103 | "\n",
1104 | "    If you provide 2 test sets, you can watch both test\n",
1105 | "    accuracies change together during training on\n",
1106 | "    train_loader.\n",
1107 | "    \"\"\"\n",
1108 | "\n",
1109 | "    if torch.cuda.is_available() and use_cuda:\n",
1110 | "        model.cuda()\n",
1111 | "\n",
1112 | "    train_loss, val_acc, val2_acc = [], [], []\n",
1113 | "\n",
1114 | "    for epoch in tqdm(range(epochs)):\n",
1115 | "\n",
1116 | "        epoch_loss = one_epoch_ewc(ewc, importance, model, train_loader)\n",
1117 | "        train_loss.append(epoch_loss)\n",
1118 | "\n",
1119 | "        acc = test(model, test_loader)\n",
1120 | "        val_acc.append(acc.detach().cpu().numpy())\n",
1121 | "\n",
1122 | "        if test2_loader is not None:\n",
1123 | "            acc2 = test(model, test2_loader)\n",
1124 | "            val2_acc.append(acc2.detach().cpu().numpy())\n",
1125 | "\n",
1126 | "    return train_loss, val_acc, val2_acc, model"
1127 | ]
1128 | },
1129 | {
1130 | "cell_type": "code",
1131 | "execution_count": 22,
1132 | "metadata": {},
1133 | "outputs": [
1134 | {
1135 | "data": {
1136 | "application/vnd.jupyter.widget-view+json": {
1137 | "model_id": "b8b24f5220c942fba0da4b1f6cc6e46c",
1138 | "version_major": 2,
1139 | "version_minor": 0
1140 | },
1141 | "text/plain": [
1142 | " 0%| | 0/25 [00:00, ?it/s]"
1143 | ]
1144 | },
1145 | "metadata": {},
1146 | "output_type": "display_data"
1147 | }
1148 | ],
1149 | "source": [
1150 | "torch.manual_seed(0)\n",
1151 | "np.random.seed(0)\n",
1152 | "random.seed(0)\n",
1153 | "\n",
1154 | "train_loss, val_acc, val2_acc, model = ewc_training(\n",
1155 | "    ewc = ewc,\n",
1156 | "    importance = 1000,\n",
1157 | "    model = model,\n",
1158 | "    epochs = 25,\n",
1159 | "    train_loader = train_loader_fashion,\n",
1160 | "    test_loader = test_loader_fashion,\n",
1161 | "    test2_loader = test_loader_mnist,\n",
1162 | ")"
1163 | ]
1164 | },
1165 | {
1166 | "cell_type": "code",
1167 | "execution_count": 23,
1168 | "metadata": {},
1169 | "outputs": [
1170 | {
1171 | "data": {
1172 | "image/png": "\n",
1173 | "text/plain": [
1174 | ""
1175 | ]
1176 | },
1177 | "metadata": {
1178 | "needs_background": "light"
1179 | },
1180 | "output_type": "display_data"
1181 | }
1182 | ],
1183 | "source": [
1184 | "plt.figure()\n",
1185 | "plt.xlabel('epochs', fontsize=25)\n",
1186 | "plt.ylabel('validation accuracy', fontsize=25)\n",
1187 | "plt.plot(val_acc, label='fashion')\n",
1188 | "plt.plot(val2_acc, label='mnist')\n",
1189 | "plt.legend()\n",
1190 | "plt.show()"
1191 | ]
1192 | },
1193 | {
1194 | "cell_type": "code",
1195 | "execution_count": 24,
1196 | "metadata": {},
1197 | "outputs": [],
1198 | "source": [
1199 | "# save the trained model\n",
1200 | "model = model.cpu()\n",
1201 | "torch.save(model, \"files/ewcmodel.pth\")"
1202 | ]
1203 | },
1204 | {
1205 | "cell_type": "markdown",
1206 | "metadata": {},
1207 | "source": [
1208 | "### Conclusion\n",
1209 | "\n",
1210 | "We expect a result like this: instead of the MNIST accuracy dropping below 30%, it degrades far less and stays above 70%, down from around 80% at the end of MNIST training. Meanwhile, learning on the Fashion-MNIST dataset still proceeds, rising above 40% from a baseline below 10%. The performance here is modest, but that is not the point. I'm sure that with architectures specialized for images, like convolutional networks or vision transformers, you could get the final accuracy much higher. The point is that we were able to preserve the previously learned capabilities of our NN.\n",
1211 | "\n",
1212 | "