├── README.md ├── data.py ├── demo.ipynb └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # EWC.pytorch 2 | 3 | An implementation of Elastic Weight Consolidation (EWC), proposed in James Kirkpatrick et al. *Overcoming catastrophic forgetting in neural networks* 2016(10.1073/pnas.1611835114). 4 | 5 | * [demo.ipynb](demo.ipynb) demonstrates EWC with supervised learning. -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torchvision import datasets 4 | 5 | 6 | class PermutedMNIST(datasets.MNIST): 7 | 8 | def __init__(self, root="~/.torch/data/mnist", train=True, permute_idx=None): 9 | super(PermutedMNIST, self).__init__(root, train, download=True) 10 | assert len(permute_idx) == 28 * 28 11 | if self.train: 12 | self.train_data = torch.stack([img.float().view(-1)[permute_idx] / 255 13 | for img in self.train_data]) 14 | else: 15 | self.test_data = torch.stack([img.float().view(-1)[permute_idx] / 255 16 | for img in self.test_data]) 17 | 18 | def __getitem__(self, index): 19 | 20 | if self.train: 21 | img, target = self.train_data[index], self.train_labels[index] 22 | else: 23 | img, target = self.test_data[index], self.test_labels[index] 24 | 25 | return img, target 26 | 27 | def get_sample(self, sample_size): 28 | sample_idx = random.sample(range(len(self)), sample_size) 29 | return [img for img in self.train_data[sample_idx]] 30 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "% matplotlib inline\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "\n", 12 | "plt.style.use(\"seaborn-white\")\n", 13 | "\n", 14 | "import random\n", 15 | "import torch\n", 16 | "from torch import nn\n", 17 | "from torch.nn import functional as F\n", 18 | "from torch import optim\n", 19 | "from tqdm import tqdm\n", 20 | "\n", 21 | "from data import PermutedMNIST\n", 22 | "from utils import EWC, ewc_train, normal_train, test\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "epochs = 50\n", 32 | "lr = 1e-3\n", 33 | "batch_size = 128\n", 34 | "sample_size = 200\n", 35 | "hidden_size = 200\n", 36 | "num_task = 3\n" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "class MLP(nn.Module):\n", 46 | " def __init__(self, hidden_size=400):\n", 47 | " super(MLP, self).__init__()\n", 48 | " self.fc1 = nn.Linear(28 * 28, hidden_size)\n", 49 | " self.fc2 = nn.Linear(hidden_size, hidden_size)\n", 50 | " self.fc3 = nn.Linear(hidden_size, hidden_size)\n", 51 | " self.fc4 = nn.Linear(hidden_size, 10)\n", 52 | "\n", 53 | " def forward(self, input):\n", 54 | " x = F.relu(self.fc1(input))\n", 55 | " x = F.relu(self.fc2(x))\n", 56 | " x = F.relu(self.fc3(x))\n", 57 | " x = F.relu(self.fc4(x))\n", 58 | " return x\n" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "def get_permute_mnist():\n", 68 | " train_loader = {}\n", 69 | " test_loader = {}\n", 70 | " idx = list(range(28 * 28))\n", 71 | " for i in range(num_task):\n", 72 | " train_loader[i] = torch.utils.data.DataLoader(PermutedMNIST(train=True, permute_idx=idx),\n", 73 | " batch_size=batch_size,\n", 74 | " num_workers=4)\n", 75 | " test_loader[i] = torch.utils.data.DataLoader(PermutedMNIST(train=False, permute_idx=idx),\n", 76 | " batch_size=batch_size)\n", 77 | " random.shuffle(idx)\n", 78 | " return train_loader, test_loader\n", 79 | "\n", 80 | "\n", 81 | "train_loader, test_loader = get_permute_mnist()\n" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 5, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "def standard_process(epochs, use_cuda=True, weight=True):\n", 91 | " model = MLP(hidden_size)\n", 92 | " if torch.cuda.is_available() and use_cuda:\n", 93 | " model.cuda()\n", 94 | " optimizer = optim.SGD(params=model.parameters(), lr=lr)\n", 95 | "\n", 96 | " loss, acc = {}, {}\n", 97 | " for task in range(num_task):\n", 98 | " loss[task] = []\n", 99 | " acc[task] = []\n", 100 | " for _ in tqdm(range(epochs)):\n", 101 | " loss[task].append(normal_train(model, optimizer, train_loader[task]))\n", 102 | " for sub_task in range(task + 1):\n", 103 | " acc[sub_task].append(test(model, test_loader[sub_task]))\n", 104 | " if task == 0 and weight:\n", 105 | " weight = model.state_dict()\n", 106 | " return loss, acc, weight\n", 107 | "\n", 108 | "\n", 109 | "def ewc_process(epochs, importance, use_cuda=True, weight=None):\n", 110 | " model = MLP(hidden_size)\n", 111 | " if torch.cuda.is_available() and use_cuda:\n", 112 | " model.cuda()\n", 113 | " optimizer = optim.SGD(params=model.parameters(), lr=lr)\n", 114 | "\n", 115 | " loss, acc, ewc = {}, {}, {}\n", 116 | " for task in range(num_task):\n", 117 | " loss[task] = []\n", 118 | " acc[task] = []\n", 119 | "\n", 120 | " if task == 0:\n", 121 | " if weight:\n", 122 | " model.load_state_dict(weight)\n", 123 | " else:\n", 124 | " for _ in tqdm(range(epochs)):\n", 125 | " loss[task].append(normal_train(model, optimizer, train_loader[task]))\n", 126 | " acc[task].append(test(model, test_loader[task]))\n", 127 | " else:\n", 128 | " old_tasks = []\n", 129 | " for sub_task in range(task):\n", 130 | " old_tasks = old_tasks + train_loader[sub_task].dataset.get_sample(sample_size)\n", 131 | " old_tasks = random.sample(old_tasks, k=sample_size)\n", 132 | " for _ in tqdm(range(epochs)):\n", 133 | " loss[task].append(ewc_train(model, optimizer, train_loader[task], EWC(model, old_tasks), importance))\n", 134 | " for sub_task in range(task + 1):\n", 135 | " acc[sub_task].append(test(model, test_loader[sub_task]))\n", 136 | "\n", 137 | " return loss, acc\n" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 6, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "def loss_plot(x):\n", 147 | " for t, v in x.items():\n", 148 | " plt.plot(list(range(t * epochs, (t + 1) * epochs)), v)\n", 149 | "\n", 150 | "def accuracy_plot(x):\n", 151 | " for t, v in x.items():\n", 152 | " plt.plot(list(range(t * epochs, num_task * epochs)), v)\n", 153 | " plt.ylim(0, 1)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 7, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stderr", 163 | "output_type": "stream", 164 | "text": [ 165 | "100%|██████████| 50/50 [01:01<00:00, 1.22s/it]\n", 166 | "100%|██████████| 50/50 [01:05<00:00, 1.31s/it]\n", 167 | "100%|██████████| 50/50 [01:08<00:00, 1.38s/it]\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "loss, acc, weight = standard_process(epochs)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 8, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "data": { 182 | "image/png": "\n", 183 | "text/plain": [ 184 | "" 185 | ] 186 | }, 187 | "metadata": {}, 188 | "output_type": "display_data" 189 | } 190 | ], 191 | "source": [ 192 | "loss_plot(loss)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 9, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "data": { 202 | "image/png": "\n", 203 | "text/plain": [ 204 | "" 205 | ] 206 | }, 207 | "metadata": {}, 208 | "output_type": "display_data" 209 | } 210 | ], 211 | "source": [ 212 | "accuracy_plot(acc)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 10, 218 | "metadata": { 219 | "scrolled": true 220 | }, 221 | "outputs": [ 222 | { 223 | "name": "stderr", 224 | "output_type": "stream", 225 | "text": [ 226 | "100%|██████████| 50/50 [01:01<00:00, 1.23s/it]\n", 227 | "100%|██████████| 50/50 [01:52<00:00, 2.25s/it]\n", 228 | "100%|██████████| 50/50 [02:04<00:00, 2.49s/it]\n" 229 | ] 230 | } 231 | ], 232 | "source": [ 233 | "loss_ewc, acc_ewc = ewc_process(epochs, importance=1000, \n", 234 | "# weight=weight\n", 235 | " )" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 11, 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "data": { 245 | "image/png": "\n", 246 | "text/plain": [ 247 | "" 248 | ] 249 | }, 250 | "metadata": {}, 251 | "output_type": "display_data" 252 | } 253 | ], 254 | "source": [ 255 | "loss_plot(loss_ewc)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 12, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "data": { 265 | "image/png": "\n", 266 | "text/plain": [ 267 | "" 268 | ] 269 | }, 270 | "metadata": {}, 271 | "output_type": "display_data" 272 | } 273 | ], 274 | "source": [ 275 | "accuracy_plot(acc_ewc)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 13, 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "data": { 285 | "text/plain": [ 286 | "" 287 | ] 288 | }, 289 | "execution_count": 13, 290 | "metadata": {}, 291 | "output_type": "execute_result" 292 | }, 293 | { 294 | "data": { 295 | "image/png": "\n", 296 | "text/plain": [ 297 | "" 298 | ] 299 | }, 300 | "metadata": {}, 301 | "output_type": "display_data" 302 | } 303 | ], 304 | "source": [ 305 | "plt.plot(acc[0], label=\"sgd\")\n", 306 | "plt.plot(acc_ewc[0], label=\"ewc\")\n", 307 | "plt.legend()" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [] 316 | } 317 | ], 318 | "metadata": { 319 | "kernelspec": { 320 | "display_name": "Python 3", 321 | "language": "python", 322 | "name": "python3" 323 | }, 324 | "language_info": { 325 | "codemirror_mode": { 326 | "name": "ipython", 327 | "version": 3 328 | }, 329 | "file_extension": ".py", 330 | "mimetype": "text/x-python", 331 | "name": "python", 332 | "nbconvert_exporter": "python", 333 | "pygments_lexer": "ipython3", 334 | "version": "3.6.3" 335 | } 336 | }, 337 | "nbformat": 4, 338 | "nbformat_minor": 2 339 | } 340 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Variable 7 | import torch.utils.data 8 | 9 | 10 | def variable(t: torch.Tensor, use_cuda=True, **kwargs): 11 | if torch.cuda.is_available() and use_cuda: 12 | t = t.cuda() 13 | return Variable(t, **kwargs) 14 | 15 | 16 | class EWC(object): 17 | def __init__(self, model: nn.Module, dataset: list): 18 | 19 | self.model = model 20 | self.dataset = dataset 21 | 22 | self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad} 23 | self._means = {} 24 | self._precision_matrices = self._diag_fisher() 25 | 26 | for n, p in deepcopy(self.params).items(): 27 | self._means[n] = variable(p.data) 28 | 29 | def _diag_fisher(self): 30 | precision_matrices = {} 31 | for n, p in deepcopy(self.params).items(): 32 | p.data.zero_() 33 | precision_matrices[n] = variable(p.data) 34 | 35 | self.model.eval() 36 | for input in self.dataset: 37 | self.model.zero_grad() 38 | input = variable(input) 39 | output = self.model(input).view(1, -1) 40 | label = output.max(1)[1].view(-1) 41 | loss = F.nll_loss(F.log_softmax(output, dim=1), label) 42 | loss.backward() 43 | 44 | for n, p in self.model.named_parameters(): 45 | precision_matrices[n].data += p.grad.data ** 2 / len(self.dataset) 46 | 47 | precision_matrices = {n: p for n, p in precision_matrices.items()} 48 | return precision_matrices 49 | 50 | def penalty(self, model: nn.Module): 51 | loss = 0 52 | for n, p in model.named_parameters(): 53 | _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2 54 | loss += _loss.sum() 55 | return loss 56 | 57 | 58 | def normal_train(model: nn.Module, optimizer: torch.optim, data_loader: torch.utils.data.DataLoader): 59 | model.train() 60 | epoch_loss = 0 61 | for input, target in data_loader: 62 | input, target = variable(input), variable(target) 63 | optimizer.zero_grad() 64 | output = model(input) 65 | loss = F.cross_entropy(output, target) 66 | epoch_loss += loss.data[0] 67 | loss.backward() 68 | optimizer.step() 69 | return epoch_loss / len(data_loader) 70 | 71 | 72 | def ewc_train(model: nn.Module, optimizer: torch.optim, data_loader: torch.utils.data.DataLoader, 73 | ewc: EWC, importance: float): 74 | model.train() 75 | epoch_loss = 0 76 | for input, target in data_loader: 77 | input, target = variable(input), variable(target) 78 | optimizer.zero_grad() 79 | output = model(input) 80 | loss = F.cross_entropy(output, target) + importance * ewc.penalty(model) 81 | epoch_loss += loss.data[0] 82 | loss.backward() 83 | optimizer.step() 84 | return epoch_loss / len(data_loader) 85 | 86 | 87 | def test(model: nn.Module, data_loader: torch.utils.data.DataLoader): 88 | model.eval() 89 | correct = 0 90 | for input, target in data_loader: 91 | input, target = variable(input), variable(target) 92 | output = model(input) 93 | correct += (F.softmax(output, dim=1).max(dim=1)[1] == target).data.sum() 94 | return correct / len(data_loader.dataset) 95 | --------------------------------------------------------------------------------