├── AgeDB_AllCNN_0to30_forgetting.ipynb
├── LICENSE
├── README.md
├── STSB_LSTM_Unlearn_0_to_2.ipynb
├── configs.py
├── data_formatters
│   ├── base.py
│   ├── electricity.py
│   ├── ts_dataset.py
│   └── utils.py
├── datasets.py
├── metrics.py
├── models.py
├── script_download_data.py
├── unlearn.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Ayush Kumar Tarun
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Regression Unlearning
2 | Official repo of the paper Deep Regression Unlearning, accepted at ICML 2023.
3 | 
4 | ## Description
5 | We study the unexplored field of unlearning in deep regression models and introduce two unlearning methods: blindspot unlearning and Gaussian amnesiac unlearning. We adapt metrics from earlier studies on unlearning in deep classification to the regression setting, and evaluate our methods with a variety of unlearning metrics, including several privacy attacks. We conduct regression experiments for computer vision, natural language processing and forecasting applications.
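For intuition, the Gaussian amnesiac step boils down to relabelling the forget set with draws from a normal distribution fitted to the training targets, and then briefly fine-tuning on the relabelled data. A minimal sketch of the relabelling (illustrative only; `amnesiac_relabel` is not a function in this repo — see the notebooks for the full pipelines):

```python
import numpy as np

def amnesiac_relabel(train_df, forget_mask, target_col='score', seed=0):
    """Replace forget-set targets with draws from N(mean, std) of all targets."""
    rng = np.random.default_rng(seed)
    mean, sd = train_df[target_col].mean(), train_df[target_col].std()
    relabelled = train_df.copy()
    relabelled.loc[forget_mask, target_col] = rng.normal(
        loc=mean, scale=sd, size=int(forget_mask.sum()))
    return relabelled
```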
6 | 7 | ## Paper 8 | [Deep Regression Unlearning](https://arxiv.org/pdf/2210.08196) 9 | 10 | ## BibTex 11 | @article{tarun2022deep, 12 | title={Deep Regression Unlearning}, 13 | author={Tarun, Ayush K and Chundawat, Vikram S and Mandal, Murari and Kankanhalli, Mohan}, 14 | journal={arXiv preprint arXiv:2210.08196}, 15 | year={2022} 16 | } 17 | -------------------------------------------------------------------------------- /STSB_LSTM_Unlearn_0_to_2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "3b72a11e", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/home/users/visionintelligence/Vikram/anaconda3/lib/python3.9/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.24.3\n", 14 | " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import torch\n", 20 | "import os\n", 21 | "import wget\n", 22 | "import zipfile\n", 23 | "import re\n", 24 | "\n", 25 | "import numpy as np\n", 26 | "import pandas as pd\n", 27 | "from torch.nn import functional as F\n", 28 | "from utils import clean_text\n", 29 | "from models import LSTMnetwork" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "924a107a", 35 | "metadata": {}, 36 | "source": [ 37 | "## Define some helper functions" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "id": "a44827b1", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "def training_step(model, batch, device):\n", 48 | " sent1, sent2, labels = batch \n", 49 | " sent1, sent2, labels = sent1.to(device), sent2.to(device), labels.to(device)\n", 50 | " out, *_ = model(sent1, sent2) # Generate predictions\n", 51 | " loss= F.mse_loss(out, labels) # Calculate loss\n", 52 | " return loss\n", 53 | "\n", 54 | "def validation_step(model, batch, device):\n", 55 | " sent1, sent2, labels = batch \n", 56 | " sent1, sent2, labels = sent1.to(device), sent2.to(device), labels.to(device)\n", 57 | " out, *_ = model(sent1, sent2) # Generate predictions\n", 58 | " loss= F.mse_loss(out, labels) # Calculate loss\n", 59 | " return {'Loss': loss.detach()}\n", 60 | "\n", 61 | "def validation_epoch_end(model, outputs):\n", 62 | " batch_losses = [x['Loss'] for x in outputs]\n", 63 | " epoch_loss = torch.stack(batch_losses).mean() # Combine losses\n", 64 | " return {'Loss': epoch_loss.item()}\n", 65 | "\n", 66 | "def epoch_end(model, epoch, result):\n", 67 | " print(\"Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}\".format(\n", 68 | " epoch, result['lrs'][-1], result['train_loss'], result['Loss']))\n", 69 | "\n", 70 | "\n", 71 | "\n", 72 | "@torch.no_grad()\n", 73 | "def evaluate(model, val_df, device, batch_size = 256):\n", 74 | " model.eval()\n", 75 | " outputs = []\n", 76 | " \n", 77 | " num_steps = len(val_df)//batch_size\n", 78 | " \n", 79 | " for i in range(num_steps):\n", 80 | " sent1 = torch.tensor(np.stack(val_df.iloc[i*batch_size:(i+1)*batch_size]['sentence1'])).float()\n", 81 | " sent2 = torch.tensor(np.stack(val_df.iloc[i*batch_size:(i+1)*batch_size]['sentence2'])).float()\n", 82 | " labels = torch.tensor(val_df.iloc[i*batch_size:(i+1)*batch_size]['score'].values)\n", 83 | " batch = (sent1, sent2, labels)\n", 84 | "\n", 85 | " outputs.append(validation_step(model, batch, 
device))\n", 86 | " \n", 87 | " if len(val_df)%batch_size != 0:\n", 88 | " sent1 = torch.tensor(np.stack(val_df.iloc[num_steps*batch_size:]['sentence1'])).float()\n", 89 | " sent2 = torch.tensor(np.stack(val_df.iloc[num_steps*batch_size:]['sentence2'])).float()\n", 90 | " labels = torch.tensor(val_df.iloc[num_steps*batch_size:]['score'].values)\n", 91 | " batch = (sent1, sent2, labels)\n", 92 | "\n", 93 | " outputs.append(validation_step(model, batch, device))\n", 94 | " \n", 95 | " return validation_epoch_end(model, outputs)\n", 96 | "\n", 97 | "def get_lr(optimizer):\n", 98 | " for param_group in optimizer.param_groups:\n", 99 | " return param_group['lr']\n", 100 | "\n", 101 | "def fit_one_cycle(epochs, model, train_df, val_df, device, save_path, batch_size = 256):\n", 102 | " best_loss = np.inf\n", 103 | " torch.cuda.empty_cache()\n", 104 | " history = []\n", 105 | " \n", 106 | " optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)\n", 107 | "\n", 108 | " sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)\n", 109 | " num_steps = len(train_df)//batch_size\n", 110 | " \n", 111 | " #for \n", 112 | " for epoch in range(epochs): \n", 113 | " model.train()\n", 114 | " train_losses = []\n", 115 | " lrs = []\n", 116 | " for i in range(num_steps):\n", 117 | " sent1 = torch.tensor(np.stack(train_df.iloc[i*batch_size:(i+1)*batch_size]['sentence1'])).float()\n", 118 | " sent2 = torch.tensor(np.stack(train_df.iloc[i*batch_size:(i+1)*batch_size]['sentence2'])).float()\n", 119 | " labels = torch.tensor(train_df.iloc[i*batch_size:(i+1)*batch_size]['score'].values).float()\n", 120 | " batch = (sent1, sent2, labels)\n", 121 | " loss = training_step(model, batch, device)\n", 122 | " train_losses.append(loss)\n", 123 | " loss.backward()\n", 124 | " \n", 125 | " optimizer.step()\n", 126 | " optimizer.zero_grad()\n", 127 | " \n", 128 | " lrs.append(get_lr(optimizer))\n", 129 | " if len(train_df)%batch_size != 0:\n", 130 | " sent1 = torch.tensor(np.stack(train_df.iloc[num_steps*batch_size:]['sentence1'])).float()\n", 131 | " sent2 = torch.tensor(np.stack(train_df.iloc[num_steps*batch_size:]['sentence2'])).float()\n", 132 | " labels = torch.tensor(train_df.iloc[num_steps*batch_size:]['score'].values).float()\n", 133 | " batch = (sent1, sent2, labels)\n", 134 | " loss = training_step(model, batch, device)\n", 135 | " train_losses.append(loss)\n", 136 | " loss.backward()\n", 137 | " \n", 138 | " optimizer.step()\n", 139 | " optimizer.zero_grad()\n", 140 | " \n", 141 | " lrs.append(get_lr(optimizer))\n", 142 | " \n", 143 | " \n", 144 | " # Validation phase\n", 145 | " result = evaluate(model, val_df, device, batch_size)\n", 146 | " result['train_loss'] = torch.stack(train_losses).mean().item()\n", 147 | " result['lrs'] = lrs\n", 148 | " epoch_end(model, epoch, result)\n", 149 | " history.append(result)\n", 150 | " sched.step(result['Loss'])\n", 151 | " if best_loss > result['Loss']:\n", 152 | " best_loss = result['Loss']\n", 153 | " torch.save(model.state_dict(), save_path)\n", 154 | " \n", 155 | " return history" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 3, 161 | "id": "0530992b", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "def fit_one_finetune_cycle(epochs, model, train_df, val_df, lr, device, save_path, batch_size = 256):\n", 166 | " best_loss = np.inf\n", 167 | " torch.cuda.empty_cache()\n", 168 | " history = []\n", 169 | " \n", 170 | " optimizer = 
torch.optim.Adam(model.parameters(), lr = lr)\n", 171 | "\n", 172 | " sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)\n", 173 | " num_steps = len(train_df)//batch_size\n", 174 | " \n", 175 | " #for \n", 176 | " for epoch in range(epochs): \n", 177 | " model.train()\n", 178 | " train_losses = []\n", 179 | " lrs = []\n", 180 | " for i in range(num_steps):\n", 181 | " sent1 = torch.tensor(np.stack(train_df.iloc[i*batch_size:(i+1)*batch_size]['sentence1'])).float()\n", 182 | " sent2 = torch.tensor(np.stack(train_df.iloc[i*batch_size:(i+1)*batch_size]['sentence2'])).float()\n", 183 | " labels = torch.tensor(train_df.iloc[i*batch_size:(i+1)*batch_size]['score'].values).float()\n", 184 | " batch = (sent1, sent2, labels)\n", 185 | " loss = training_step(model, batch, device)\n", 186 | " train_losses.append(loss)\n", 187 | " loss.backward()\n", 188 | " \n", 189 | " optimizer.step()\n", 190 | " optimizer.zero_grad()\n", 191 | " \n", 192 | " lrs.append(get_lr(optimizer))\n", 193 | " if len(train_df)%batch_size != 0:\n", 194 | " sent1 = torch.tensor(np.stack(train_df.iloc[num_steps*batch_size:]['sentence1'])).float()\n", 195 | " sent2 = torch.tensor(np.stack(train_df.iloc[num_steps*batch_size:]['sentence2'])).float()\n", 196 | " labels = torch.tensor(train_df.iloc[num_steps*batch_size:]['score'].values).float()\n", 197 | " batch = (sent1, sent2, labels)\n", 198 | " loss = training_step(model, batch, device)\n", 199 | " train_losses.append(loss)\n", 200 | " loss.backward()\n", 201 | " \n", 202 | " optimizer.step()\n", 203 | " optimizer.zero_grad()\n", 204 | " \n", 205 | " lrs.append(get_lr(optimizer))\n", 206 | " \n", 207 | " \n", 208 | " # Validation phase\n", 209 | " result = evaluate(model, val_df, device, batch_size)\n", 210 | " result['train_loss'] = torch.stack(train_losses).mean().item()\n", 211 | " result['lrs'] = lrs\n", 212 | " epoch_end(model, epoch, result)\n", 213 | " history.append(result)\n", 214 | " sched.step(result['Loss'])\n", 215 | " if best_loss > result['Loss']:\n", 216 | " best_loss = result['Loss']\n", 217 | " torch.save(model.state_dict(), save_path)\n", 218 | " \n", 219 | " return history" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 4, 225 | "id": "41b683e4", 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "def attention(x):\n", 230 | " \"\"\"\n", 231 | " Taken from https://github.com/szagoruyko/attention-transfer\n", 232 | " :param x = activations\n", 233 | " \"\"\"\n", 234 | " return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))\n", 235 | "\n", 236 | "\n", 237 | "def attention_diff(x, y):\n", 238 | " \"\"\"\n", 239 | " Taken from https://github.com/szagoruyko/attention-transfer\n", 240 | " :param x = activations\n", 241 | " :param y = activations\n", 242 | " \"\"\"\n", 243 | " return (attention(x) - attention(y)).pow(2).mean()\n", 244 | "\n", 245 | "\n", 246 | "\n", 247 | "def forget_loss(model_output, model_activations, proxy_output, proxy_activations, mask):\n", 248 | "\n", 249 | " loss = F.mse_loss(model_output[mask], proxy_output[mask])\n", 250 | " if AT_beta > 0:\n", 251 | " at_loss = 0\n", 252 | " for i in range(len(proxy_activations)):\n", 253 | " at_loss = at_loss + AT_beta * attention_diff(model_activations[i][mask], proxy_activations[i][mask])\n", 254 | " else:\n", 255 | " at_loss = 0\n", 256 | "\n", 257 | " total_loss = loss + at_loss\n", 258 | "\n", 259 | " return total_loss\n", 260 | "\n", 261 | "\n", 262 | "\n", 263 | "def 
fit_one_forget_cycle(epochs, model, proxy_model, train_df, val_df, lr, device, save_path, batch_size = 256):\n", 264 | " best_loss = np.inf\n", 265 | " torch.cuda.empty_cache()\n", 266 | " history = []\n", 267 | " \n", 268 | " optimizer = torch.optim.Adam(model.parameters(), lr = lr)\n", 269 | "\n", 270 | " sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)\n", 271 | " num_steps = len(train_df)//batch_size\n", 272 | " for epoch in range(epochs): \n", 273 | " model.train()\n", 274 | " train_losses = []\n", 275 | " lrs = []\n", 276 | " #for batch in train_loader:\n", 277 | " for i in range(num_steps):\n", 278 | " sent1 = torch.tensor(np.stack(train_df.iloc[i*batch_size:(i+1)*batch_size]['sentence1'])).float()\n", 279 | " sent2 = torch.tensor(np.stack(train_df.iloc[i*batch_size:(i+1)*batch_size]['sentence2'])).float()\n", 280 | " labels = torch.tensor(train_df.iloc[i*batch_size:(i+1)*batch_size]['score'].values).float()\n", 281 | " ulabels = torch.tensor(train_df.iloc[i*batch_size:(i+1)*batch_size]['forget'].values)\n", 282 | " \n", 283 | " sent1, sent2, labels, ulabels = sent1.to(device), sent2.to(device), labels.to(device), ulabels.to(device)\n", 284 | " \n", 285 | " model_out, *model_activations = model(sent1, sent2)\n", 286 | " with torch.no_grad():\n", 287 | " proxy_out, *proxy_activations = proxy_model(sent1, sent2)\n", 288 | " \n", 289 | " \n", 290 | " label_loss = 0\n", 291 | " if ulabels.sum() < len(ulabels):\n", 292 | " mask = (ulabels == 0)\n", 293 | " r_model_out = model_out[mask]\n", 294 | " r_labels = labels[mask]\n", 295 | " label_loss = F.mse_loss(r_model_out, r_labels)\n", 296 | " \n", 297 | " proxy_loss = 0\n", 298 | " if ulabels.sum() > 0:\n", 299 | " mask = (ulabels == 1)\n", 300 | " proxy_loss = forget_loss(model_out, model_activations, proxy_out, proxy_activations, mask)\n", 301 | " \n", 302 | " coeff = ulabels.sum()/len(ulabels)\n", 303 | " loss = coeff*proxy_loss + (1-coeff)*label_loss\n", 304 | " \n", 305 | " ######\n", 306 | " train_losses.append(loss)\n", 307 | " loss.backward()\n", 308 | " \n", 309 | " optimizer.step()\n", 310 | " optimizer.zero_grad()\n", 311 | " \n", 312 | " lrs.append(get_lr(optimizer))\n", 313 | " \n", 314 | " if len(train_df)%batch_size != 0:\n", 315 | " sent1 = torch.tensor(np.stack(train_df.iloc[num_steps*batch_size:]['sentence1'])).float()\n", 316 | " sent2 = torch.tensor(np.stack(train_df.iloc[num_steps*batch_size:]['sentence2'])).float()\n", 317 | " labels = torch.tensor(train_df.iloc[num_steps*batch_size:]['score'].values).float()\n", 318 | " ulabels = torch.tensor(train_df.iloc[num_steps*batch_size:]['forget'].values)\n", 319 | " \n", 320 | " sent1, sent2, labels, ulabels = sent1.to(device), sent2.to(device), labels.to(device), ulabels.to(device)\n", 321 | " \n", 322 | " model_out, *model_activations = model(sent1, sent2)\n", 323 | " with torch.no_grad():\n", 324 | " proxy_out, *proxy_activations = proxy_model(sent1, sent2)\n", 325 | " \n", 326 | " \n", 327 | " label_loss = 0\n", 328 | " if ulabels.sum() < len(ulabels):\n", 329 | " mask = (ulabels == 0)\n", 330 | " r_model_out = model_out[mask]\n", 331 | " r_labels = labels[mask]\n", 332 | " label_loss = F.mse_loss(r_model_out, r_labels)\n", 333 | " \n", 334 | " proxy_loss = 0\n", 335 | " if ulabels.sum() > 0:\n", 336 | " mask = (ulabels == 1)\n", 337 | " proxy_loss = forget_loss(model_out, model_activations, proxy_out, proxy_activations, mask)\n", 338 | " \n", 339 | " coeff = ulabels.sum()/len(ulabels)\n", 340 | " loss = 
coeff*proxy_loss + (1-coeff)*label_loss\n",
341 | "            \n",
342 | "            ######\n",
343 | "            train_losses.append(loss)\n",
344 | "            loss.backward()\n",
345 | "            \n",
346 | "            optimizer.step()\n",
347 | "            optimizer.zero_grad()\n",
348 | "            \n",
349 | "            lrs.append(get_lr(optimizer))\n",
350 | "        \n",
351 | "        \n",
352 | "        # Validation phase\n",
353 | "        result = evaluate(model, val_df, device)\n",
354 | "        result['train_loss'] = torch.stack(train_losses).mean().item()\n",
355 | "        result['lrs'] = lrs\n",
356 | "        epoch_end(model, epoch, result)\n",
357 | "        history.append(result)\n",
358 | "        #sched.step(result['Loss'])\n",
359 | "        #if best_loss > result['Loss']:\n",
360 | "        #    best_loss = result['Loss']\n",
361 | "        #    torch.save(model.state_dict(), save_path)\n",
362 | "        torch.save(model.state_dict(), save_path)\n",
363 | "    \n",
364 | "    return history"
365 | ]
366 | },
367 | {
368 | "cell_type": "code",
369 | "execution_count": 5,
370 | "id": "c619d91c",
371 | "metadata": {},
372 | "outputs": [],
373 | "source": [
374 | "text_embedding_dimension = 300\n",
375 | "\n",
376 | "def text_embed(words):\n",
377 | "    \n",
378 | "    unknown_indices = []\n",
379 | "    mean = np.zeros(text_embedding_dimension)\n",
380 | "    \n",
381 | "    for i in range(len(words)):\n",
382 | "        if words[i] in embeddings_index_300 and embeddings_index_300[ words[i] ].shape == (300, ):\n",
383 | "            words[i] = embeddings_index_300[ words[i] ]\n",
384 | "            mean += words[i]\n",
385 | "        else:\n",
386 | "            unknown_indices.append(i)\n",
387 | "    \n",
388 | "    mean /= max(len(words)-len(unknown_indices), 1)\n",
389 | "    \n",
390 | "    # unknown words in the text are represented using the mean of the known words\n",
391 | "    for i in unknown_indices:\n",
392 | "        words[i] = mean\n",
393 | "    return words\n",
394 | "\n",
395 | "def pad(x, max_len = 10):\n",
396 | "    if len(x) >= max_len:\n",
397 | "        return x[:max_len]\n",
398 | "    zeros = [np.zeros(text_embedding_dimension)]*(max_len - len(x))\n",
399 | "    return zeros + x\n",
400 | "\n"
401 | ]
402 | },
403 | {
404 | "cell_type": "markdown",
405 | "id": "36c92dd4",
406 | "metadata": {},
407 | "source": [
408 | "## Get GloVe word embeddings"
409 | ]
410 | },
411 | {
412 | "cell_type": "code",
413 | "execution_count": 6,
414 | "id": "0fa7b4a9",
415 | "metadata": {},
416 | "outputs": [
417 | {
418 | "name": "stdout",
419 | "output_type": "stream",
420 | "text": [
421 | "Downloading and extracting GloVe word embeddings...\n",
422 | "\n",
423 | "Completed!\n"
424 | ]
425 | }
426 | ],
427 | "source": [
428 | "print(\"Downloading and extracting GloVe word embeddings...\")\n",
429 | "data_file = \"./glove.840B.300d.zip\"\n",
430 | "wget.download(\"http://nlp.stanford.edu/data/glove.840B.300d.zip\", out=data_file)\n",
431 | "with zipfile.ZipFile(data_file) as zip_ref:\n",
432 | "    zip_ref.extractall('./glove')\n",
433 | "os.remove(data_file)\n",
434 | "print(\"\\nCompleted!\")"
435 | ]
436 | },
437 | {
438 | "cell_type": "code",
439 | "execution_count": 7,
440 | "id": "24ba7212",
441 | "metadata": {},
442 | "outputs": [
443 | {
444 | "name": "stderr",
445 | "output_type": "stream",
446 | "text": [
447 | "/tmp/ipykernel_3686965/1901847276.py:7: DeprecationWarning: string or file could not be read to its end due to unmatched data; this will raise a ValueError in the future.\n",
448 | "  coefs = np.fromstring(coefs, \"f\", sep=\" \")\n"
449 | ]
450 | }
451 | ],
452 | "source": [
453 | "path_to_glove_file = \"./glove/glove.840B.300d.txt\"\n",
454 | "\n",
455 | "embeddings_index_300 = {}\n",
456 | "with open(path_to_glove_file) as f:\n",
457 | 
" for line in f:\n", 458 | " word, coefs = line.split(maxsplit=1)\n", 459 | " coefs = np.fromstring(coefs, \"f\", sep=\" \")\n", 460 | " embeddings_index_300[word] = coefs" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 8, 466 | "id": "d96820e5", 467 | "metadata": {}, 468 | "outputs": [ 469 | { 470 | "name": "stdout", 471 | "output_type": "stream", 472 | "text": [ 473 | "Found 2195884 word vectors.\n" 474 | ] 475 | } 476 | ], 477 | "source": [ 478 | "print(\"Found %s word vectors.\" % len(embeddings_index_300))" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "id": "78b826fa", 484 | "metadata": {}, 485 | "source": [ 486 | "## Load Data" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": 9, 492 | "id": "744d72db", 493 | "metadata": {}, 494 | "outputs": [], 495 | "source": [ 496 | "train_df = pd.read_csv(\"./stsb_data/train_new.tsv\", sep='\\t', on_bad_lines='skip')\n", 497 | "val_df = pd.read_csv(\"./stsb_data/dev_new.tsv\", sep='\\t', on_bad_lines='skip')\n", 498 | "test_df = pd.read_csv(\"./stsb_data/test_new.tsv\", sep='\\t', on_bad_lines='skip')" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": 10, 504 | "id": "baa4f766", 505 | "metadata": {}, 506 | "outputs": [], 507 | "source": [ 508 | "train_df.dropna(subset=['score'], inplace=True)\n", 509 | "val_df.dropna(subset=['score'], inplace=True)\n", 510 | "test_df.dropna(subset=['score'], inplace=True)" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": 11, 516 | "id": "16b33a94", 517 | "metadata": {}, 518 | "outputs": [], 519 | "source": [ 520 | "train_df['sentence1'] = train_df['sentence1'].apply(lambda x: clean_text(x))\n", 521 | "train_df['sentence2'] = train_df['sentence2'].apply(lambda x: clean_text(x))\n", 522 | "\n", 523 | "val_df['sentence1'] = val_df['sentence1'].apply(lambda x: clean_text(x))\n", 524 | "val_df['sentence2'] = val_df['sentence2'].apply(lambda x: clean_text(x))\n", 525 | "\n", 526 | "test_df['sentence1'] = test_df['sentence1'].apply(lambda x: clean_text(x))\n", 527 | "test_df['sentence2'] = test_df['sentence2'].apply(lambda x: clean_text(x))" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 12, 533 | "id": "bcbbe74a", 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [ 537 | "train_df['sentence1'] = train_df['sentence1'].apply(lambda words: text_embed(words))\n", 538 | "train_df['sentence2'] = train_df['sentence2'].apply(lambda words: text_embed(words))\n", 539 | "\n", 540 | "val_df['sentence1'] = val_df['sentence1'].apply(lambda words: text_embed(words))\n", 541 | "val_df['sentence2'] = val_df['sentence2'].apply(lambda words: text_embed(words))\n", 542 | "\n", 543 | "test_df['sentence1'] = test_df['sentence1'].apply(lambda words: text_embed(words))\n", 544 | "test_df['sentence2'] = test_df['sentence2'].apply(lambda words: text_embed(words))" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": 13, 550 | "id": "ac57e5f6", 551 | "metadata": {}, 552 | "outputs": [], 553 | "source": [ 554 | "train_df['sentence1'] = train_df['sentence1'].apply(lambda words: pad(words))\n", 555 | "train_df['sentence2'] = train_df['sentence2'].apply(lambda words: pad(words))\n", 556 | "\n", 557 | "val_df['sentence1'] = val_df['sentence1'].apply(lambda words: pad(words))\n", 558 | "val_df['sentence2'] = val_df['sentence2'].apply(lambda words: pad(words))\n", 559 | "\n", 560 | "test_df['sentence1'] = test_df['sentence1'].apply(lambda words: pad(words))\n", 
561 | "test_df['sentence2'] = test_df['sentence2'].apply(lambda words: pad(words))" 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": 14, 567 | "id": "ff0a35bf", 568 | "metadata": {}, 569 | "outputs": [], 570 | "source": [ 571 | "train_df = train_df.sample(frac = 1, random_state = 0)" 572 | ] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "id": "cd8b997c", 577 | "metadata": {}, 578 | "source": [ 579 | "## Train the model" 580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": 15, 585 | "id": "20581b40", 586 | "metadata": { 587 | "scrolled": true 588 | }, 589 | "outputs": [ 590 | { 591 | "name": "stdout", 592 | "output_type": "stream", 593 | "text": [ 594 | "Epoch [0], last_lr: 0.01000, train_loss: 3.7514, val_loss: 2.1620\n", 595 | "Epoch [1], last_lr: 0.01000, train_loss: 2.2661, val_loss: 2.0303\n", 596 | "Epoch [2], last_lr: 0.01000, train_loss: 2.2280, val_loss: 1.9903\n", 597 | "Epoch [3], last_lr: 0.01000, train_loss: 2.1673, val_loss: 1.9712\n", 598 | "Epoch [4], last_lr: 0.01000, train_loss: 2.0814, val_loss: 1.9476\n", 599 | "Epoch [5], last_lr: 0.01000, train_loss: 1.9387, val_loss: 1.9810\n", 600 | "Epoch [6], last_lr: 0.01000, train_loss: 1.7108, val_loss: 2.0152\n", 601 | "Epoch [7], last_lr: 0.01000, train_loss: 1.5121, val_loss: 2.0804\n", 602 | "Epoch [8], last_lr: 0.01000, train_loss: 1.3200, val_loss: 2.1425\n", 603 | "Epoch [9], last_lr: 0.01000, train_loss: 1.2409, val_loss: 2.2455\n", 604 | "Epoch [10], last_lr: 0.01000, train_loss: 1.3078, val_loss: 3.0867\n", 605 | "Epoch [11], last_lr: 0.01000, train_loss: 1.7610, val_loss: 2.3519\n", 606 | "Epoch [12], last_lr: 0.01000, train_loss: 1.3685, val_loss: 2.1034\n", 607 | "Epoch [13], last_lr: 0.01000, train_loss: 0.7959, val_loss: 2.3659\n", 608 | "Epoch [14], last_lr: 0.01000, train_loss: 0.6454, val_loss: 2.4677\n", 609 | "Epoch [15], last_lr: 0.01000, train_loss: 0.5746, val_loss: 2.4120\n", 610 | "Epoch 16: reducing learning rate of group 0 to 1.0000e-03.\n", 611 | "Epoch [16], last_lr: 0.00100, train_loss: 0.4698, val_loss: 2.3867\n", 612 | "Epoch [17], last_lr: 0.00100, train_loss: 0.4407, val_loss: 2.4148\n", 613 | "Epoch [18], last_lr: 0.00100, train_loss: 0.4346, val_loss: 2.3970\n", 614 | "Epoch [19], last_lr: 0.00100, train_loss: 0.4076, val_loss: 2.4035\n", 615 | "Epoch [20], last_lr: 0.00100, train_loss: 0.4085, val_loss: 2.4092\n", 616 | "Epoch [21], last_lr: 0.00100, train_loss: 0.3912, val_loss: 2.4234\n", 617 | "Epoch [22], last_lr: 0.00100, train_loss: 0.3870, val_loss: 2.4152\n", 618 | "Epoch [23], last_lr: 0.00100, train_loss: 0.3731, val_loss: 2.4115\n", 619 | "Epoch [24], last_lr: 0.00100, train_loss: 0.3774, val_loss: 2.4153\n", 620 | "Epoch [25], last_lr: 0.00100, train_loss: 0.3694, val_loss: 2.4129\n", 621 | "Epoch [26], last_lr: 0.00100, train_loss: 0.3600, val_loss: 2.4117\n", 622 | "Epoch 27: reducing learning rate of group 0 to 1.0000e-04.\n", 623 | "Epoch [27], last_lr: 0.00010, train_loss: 0.3481, val_loss: 2.4155\n", 624 | "Epoch [28], last_lr: 0.00010, train_loss: 0.3343, val_loss: 2.4162\n", 625 | "Epoch [29], last_lr: 0.00010, train_loss: 0.3467, val_loss: 2.4154\n", 626 | "Epoch [30], last_lr: 0.00010, train_loss: 0.3555, val_loss: 2.4142\n", 627 | "Epoch [31], last_lr: 0.00010, train_loss: 0.3464, val_loss: 2.4153\n", 628 | "Epoch [32], last_lr: 0.00010, train_loss: 0.3443, val_loss: 2.4130\n", 629 | "Epoch [33], last_lr: 0.00010, train_loss: 0.3417, val_loss: 2.4131\n", 630 | "Epoch [34], last_lr: 0.00010, train_loss: 0.3423, 
val_loss: 2.4147\n", 631 | "Epoch [35], last_lr: 0.00010, train_loss: 0.3409, val_loss: 2.4150\n", 632 | "Epoch [36], last_lr: 0.00010, train_loss: 0.3507, val_loss: 2.4144\n", 633 | "Epoch [37], last_lr: 0.00010, train_loss: 0.3405, val_loss: 2.4174\n", 634 | "Epoch 38: reducing learning rate of group 0 to 1.0000e-05.\n", 635 | "Epoch [38], last_lr: 0.00001, train_loss: 0.3391, val_loss: 2.4176\n", 636 | "Epoch [39], last_lr: 0.00001, train_loss: 0.3467, val_loss: 2.4175\n", 637 | "Epoch [40], last_lr: 0.00001, train_loss: 0.3300, val_loss: 2.4170\n", 638 | "Epoch [41], last_lr: 0.00001, train_loss: 0.3391, val_loss: 2.4170\n", 639 | "Epoch [42], last_lr: 0.00001, train_loss: 0.3409, val_loss: 2.4170\n", 640 | "Epoch [43], last_lr: 0.00001, train_loss: 0.3369, val_loss: 2.4170\n", 641 | "Epoch [44], last_lr: 0.00001, train_loss: 0.3377, val_loss: 2.4171\n", 642 | "Epoch [45], last_lr: 0.00001, train_loss: 0.3406, val_loss: 2.4169\n", 643 | "Epoch [46], last_lr: 0.00001, train_loss: 0.3391, val_loss: 2.4169\n", 644 | "Epoch [47], last_lr: 0.00001, train_loss: 0.3513, val_loss: 2.4167\n", 645 | "Epoch [48], last_lr: 0.00001, train_loss: 0.3420, val_loss: 2.4165\n", 646 | "Epoch 49: reducing learning rate of group 0 to 1.0000e-06.\n", 647 | "Epoch [49], last_lr: 0.00000, train_loss: 0.3418, val_loss: 2.4165\n", 648 | "Epoch [50], last_lr: 0.00000, train_loss: 0.3363, val_loss: 2.4165\n", 649 | "Epoch [51], last_lr: 0.00000, train_loss: 0.3415, val_loss: 2.4166\n", 650 | "Epoch [52], last_lr: 0.00000, train_loss: 0.3249, val_loss: 2.4166\n", 651 | "Epoch [53], last_lr: 0.00000, train_loss: 0.3291, val_loss: 2.4165\n", 652 | "Epoch [54], last_lr: 0.00000, train_loss: 0.3327, val_loss: 2.4165\n", 653 | "Epoch [55], last_lr: 0.00000, train_loss: 0.3423, val_loss: 2.4166\n", 654 | "Epoch [56], last_lr: 0.00000, train_loss: 0.3345, val_loss: 2.4165\n", 655 | "Epoch [57], last_lr: 0.00000, train_loss: 0.3398, val_loss: 2.4165\n", 656 | "Epoch [58], last_lr: 0.00000, train_loss: 0.3429, val_loss: 2.4165\n", 657 | "Epoch [59], last_lr: 0.00000, train_loss: 0.3336, val_loss: 2.4165\n", 658 | "Epoch 60: reducing learning rate of group 0 to 1.0000e-07.\n", 659 | "Epoch [60], last_lr: 0.00000, train_loss: 0.3449, val_loss: 2.4165\n", 660 | "Epoch [61], last_lr: 0.00000, train_loss: 0.3385, val_loss: 2.4165\n", 661 | "Epoch [62], last_lr: 0.00000, train_loss: 0.3467, val_loss: 2.4165\n", 662 | "Epoch [63], last_lr: 0.00000, train_loss: 0.3370, val_loss: 2.4165\n", 663 | "Epoch [64], last_lr: 0.00000, train_loss: 0.3454, val_loss: 2.4165\n", 664 | "Epoch [65], last_lr: 0.00000, train_loss: 0.3425, val_loss: 2.4165\n", 665 | "Epoch [66], last_lr: 0.00000, train_loss: 0.3245, val_loss: 2.4165\n", 666 | "Epoch [67], last_lr: 0.00000, train_loss: 0.3301, val_loss: 2.4165\n", 667 | "Epoch [68], last_lr: 0.00000, train_loss: 0.3389, val_loss: 2.4165\n", 668 | "Epoch [69], last_lr: 0.00000, train_loss: 0.3454, val_loss: 2.4165\n", 669 | "Epoch [70], last_lr: 0.00000, train_loss: 0.3395, val_loss: 2.4165\n", 670 | "Epoch 71: reducing learning rate of group 0 to 1.0000e-08.\n", 671 | "Epoch [71], last_lr: 0.00000, train_loss: 0.3348, val_loss: 2.4165\n", 672 | "Epoch [72], last_lr: 0.00000, train_loss: 0.3462, val_loss: 2.4165\n", 673 | "Epoch [73], last_lr: 0.00000, train_loss: 0.3448, val_loss: 2.4165\n", 674 | "Epoch [74], last_lr: 0.00000, train_loss: 0.3434, val_loss: 2.4165\n", 675 | "Epoch [75], last_lr: 0.00000, train_loss: 0.3374, val_loss: 2.4165\n", 676 | "Epoch [76], last_lr: 0.00000, train_loss: 
0.3355, val_loss: 2.4165\n", 677 | "Epoch [77], last_lr: 0.00000, train_loss: 0.3372, val_loss: 2.4165\n", 678 | "Epoch [78], last_lr: 0.00000, train_loss: 0.3313, val_loss: 2.4165\n", 679 | "Epoch [79], last_lr: 0.00000, train_loss: 0.3378, val_loss: 2.4165\n", 680 | "Epoch [80], last_lr: 0.00000, train_loss: 0.3414, val_loss: 2.4165\n", 681 | "Epoch [81], last_lr: 0.00000, train_loss: 0.3318, val_loss: 2.4165\n", 682 | "Epoch [82], last_lr: 0.00000, train_loss: 0.3428, val_loss: 2.4165\n", 683 | "Epoch [83], last_lr: 0.00000, train_loss: 0.3395, val_loss: 2.4165\n", 684 | "Epoch [84], last_lr: 0.00000, train_loss: 0.3440, val_loss: 2.4165\n", 685 | "Epoch [85], last_lr: 0.00000, train_loss: 0.3342, val_loss: 2.4165\n", 686 | "Epoch [86], last_lr: 0.00000, train_loss: 0.3378, val_loss: 2.4165\n", 687 | "Epoch [87], last_lr: 0.00000, train_loss: 0.3456, val_loss: 2.4165\n", 688 | "Epoch [88], last_lr: 0.00000, train_loss: 0.3383, val_loss: 2.4165\n", 689 | "Epoch [89], last_lr: 0.00000, train_loss: 0.3314, val_loss: 2.4165\n", 690 | "Epoch [90], last_lr: 0.00000, train_loss: 0.3422, val_loss: 2.4165\n", 691 | "Epoch [91], last_lr: 0.00000, train_loss: 0.3396, val_loss: 2.4165\n", 692 | "Epoch [92], last_lr: 0.00000, train_loss: 0.3337, val_loss: 2.4165\n", 693 | "Epoch [93], last_lr: 0.00000, train_loss: 0.3437, val_loss: 2.4165\n", 694 | "Epoch [94], last_lr: 0.00000, train_loss: 0.3375, val_loss: 2.4165\n", 695 | "Epoch [95], last_lr: 0.00000, train_loss: 0.3427, val_loss: 2.4165\n", 696 | "Epoch [96], last_lr: 0.00000, train_loss: 0.3482, val_loss: 2.4165\n", 697 | "Epoch [97], last_lr: 0.00000, train_loss: 0.3346, val_loss: 2.4165\n", 698 | "Epoch [98], last_lr: 0.00000, train_loss: 0.3463, val_loss: 2.4165\n", 699 | "Epoch [99], last_lr: 0.00000, train_loss: 0.3376, val_loss: 2.4165\n" 700 | ] 701 | }, 702 | { 703 | "data": { 704 | "text/plain": [ 705 | "" 706 | ] 707 | }, 708 | "execution_count": 15, 709 | "metadata": {}, 710 | "output_type": "execute_result" 711 | } 712 | ], 713 | "source": [ 714 | "device = 'cuda'\n", 715 | "model = LSTMnetwork(text_embedding_dimension = text_embedding_dimension).to(device)\n", 716 | "epochs = 100\n", 717 | "save_path = \"saved_models/LSTM_STSB_100epochs.pt\"\n", 718 | "history = fit_one_cycle(epochs, model, train_df, val_df, device = device, save_path = save_path)\n", 719 | "model.load_state_dict(torch.load(save_path))" 720 | ] 721 | }, 722 | { 723 | "cell_type": "markdown", 724 | "id": "e95014d9", 725 | "metadata": {}, 726 | "source": [ 727 | "## Creating the forget and retain sets\n" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": 16, 733 | "id": "b4fa3fb5", 734 | "metadata": {}, 735 | "outputs": [], 736 | "source": [ 737 | "train_df_retain = train_df[train_df['score'] >= 2]\n", 738 | "val_df_retain = val_df[val_df['score'] >= 2]\n", 739 | "test_df_retain = test_df[test_df['score'] >= 2]" 740 | ] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "execution_count": 17, 745 | "id": "beb1c7f6", 746 | "metadata": {}, 747 | "outputs": [], 748 | "source": [ 749 | "train_df_forget = train_df[train_df['score'] < 2]\n", 750 | "val_df_forget = val_df[val_df['score'] < 2]\n", 751 | "test_df_forget = test_df[test_df['score'] < 2]" 752 | ] 753 | }, 754 | { 755 | "cell_type": "markdown", 756 | "id": "e4b8fe69", 757 | "metadata": {}, 758 | "source": [ 759 | "## Retraining the model from scratch on Retain Data" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": 19, 765 | "id": "6581b7f4", 766 | "metadata": 
{}, 767 | "outputs": [ 768 | { 769 | "name": "stdout", 770 | "output_type": "stream", 771 | "text": [ 772 | "Epoch [0], last_lr: 0.01000, train_loss: 4.4909, val_loss: 1.1749\n", 773 | "Epoch [1], last_lr: 0.01000, train_loss: 0.9404, val_loss: 1.3314\n", 774 | "Epoch [2], last_lr: 0.01000, train_loss: 0.7599, val_loss: 0.9330\n", 775 | "Epoch [3], last_lr: 0.01000, train_loss: 0.6883, val_loss: 1.0268\n", 776 | "Epoch [4], last_lr: 0.01000, train_loss: 0.6612, val_loss: 0.9592\n", 777 | "Epoch [5], last_lr: 0.01000, train_loss: 0.6337, val_loss: 1.0138\n", 778 | "Epoch [6], last_lr: 0.01000, train_loss: 0.6384, val_loss: 0.9921\n", 779 | "Epoch [7], last_lr: 0.01000, train_loss: 0.6282, val_loss: 0.9893\n", 780 | "Epoch [8], last_lr: 0.01000, train_loss: 0.6106, val_loss: 0.9939\n", 781 | "Epoch [9], last_lr: 0.01000, train_loss: 0.5835, val_loss: 1.0177\n", 782 | "Epoch [10], last_lr: 0.01000, train_loss: 0.5419, val_loss: 1.0217\n", 783 | "Epoch [11], last_lr: 0.01000, train_loss: 0.5061, val_loss: 1.0469\n", 784 | "Epoch [12], last_lr: 0.01000, train_loss: 0.4733, val_loss: 1.0764\n", 785 | "Epoch [13], last_lr: 0.01000, train_loss: 0.4352, val_loss: 1.1303\n", 786 | "Epoch 14: reducing learning rate of group 0 to 1.0000e-03.\n", 787 | "Epoch [14], last_lr: 0.00100, train_loss: 0.3902, val_loss: 1.0564\n", 788 | "Epoch [15], last_lr: 0.00100, train_loss: 0.3695, val_loss: 1.0537\n", 789 | "Epoch [16], last_lr: 0.00100, train_loss: 0.3652, val_loss: 1.0650\n", 790 | "Epoch [17], last_lr: 0.00100, train_loss: 0.3575, val_loss: 1.0767\n", 791 | "Epoch [18], last_lr: 0.00100, train_loss: 0.3496, val_loss: 1.0717\n", 792 | "Epoch [19], last_lr: 0.00100, train_loss: 0.3539, val_loss: 1.0816\n", 793 | "Epoch [20], last_lr: 0.00100, train_loss: 0.3503, val_loss: 1.0682\n", 794 | "Epoch [21], last_lr: 0.00100, train_loss: 0.3332, val_loss: 1.0863\n", 795 | "Epoch [22], last_lr: 0.00100, train_loss: 0.3259, val_loss: 1.0894\n", 796 | "Epoch [23], last_lr: 0.00100, train_loss: 0.3229, val_loss: 1.0952\n", 797 | "Epoch [24], last_lr: 0.00100, train_loss: 0.3161, val_loss: 1.0974\n", 798 | "Epoch 25: reducing learning rate of group 0 to 1.0000e-04.\n", 799 | "Epoch [25], last_lr: 0.00010, train_loss: 0.3126, val_loss: 1.0999\n", 800 | "Epoch [26], last_lr: 0.00010, train_loss: 0.3121, val_loss: 1.1028\n", 801 | "Epoch [27], last_lr: 0.00010, train_loss: 0.3107, val_loss: 1.1022\n", 802 | "Epoch [28], last_lr: 0.00010, train_loss: 0.3096, val_loss: 1.1047\n", 803 | "Epoch [29], last_lr: 0.00010, train_loss: 0.3067, val_loss: 1.1033\n", 804 | "Epoch [30], last_lr: 0.00010, train_loss: 0.3111, val_loss: 1.1007\n", 805 | "Epoch [31], last_lr: 0.00010, train_loss: 0.3119, val_loss: 1.1031\n", 806 | "Epoch [32], last_lr: 0.00010, train_loss: 0.3076, val_loss: 1.1046\n", 807 | "Epoch [33], last_lr: 0.00010, train_loss: 0.3026, val_loss: 1.1058\n", 808 | "Epoch [34], last_lr: 0.00010, train_loss: 0.3026, val_loss: 1.1040\n", 809 | "Epoch [35], last_lr: 0.00010, train_loss: 0.2972, val_loss: 1.1092\n", 810 | "Epoch 36: reducing learning rate of group 0 to 1.0000e-05.\n", 811 | "Epoch [36], last_lr: 0.00001, train_loss: 0.3002, val_loss: 1.1089\n", 812 | "Epoch [37], last_lr: 0.00001, train_loss: 0.3023, val_loss: 1.1078\n", 813 | "Epoch [38], last_lr: 0.00001, train_loss: 0.3008, val_loss: 1.1062\n", 814 | "Epoch [39], last_lr: 0.00001, train_loss: 0.3050, val_loss: 1.1055\n", 815 | "Epoch [40], last_lr: 0.00001, train_loss: 0.3099, val_loss: 1.1050\n", 816 | "Epoch [41], last_lr: 0.00001, train_loss: 
0.3022, val_loss: 1.1052\n", 817 | "Epoch [42], last_lr: 0.00001, train_loss: 0.3069, val_loss: 1.1055\n", 818 | "Epoch [43], last_lr: 0.00001, train_loss: 0.3120, val_loss: 1.1049\n", 819 | "Epoch [44], last_lr: 0.00001, train_loss: 0.3075, val_loss: 1.1048\n", 820 | "Epoch [45], last_lr: 0.00001, train_loss: 0.2988, val_loss: 1.1044\n", 821 | "Epoch [46], last_lr: 0.00001, train_loss: 0.3010, val_loss: 1.1038\n", 822 | "Epoch 47: reducing learning rate of group 0 to 1.0000e-06.\n", 823 | "Epoch [47], last_lr: 0.00000, train_loss: 0.3095, val_loss: 1.1038\n", 824 | "Epoch [48], last_lr: 0.00000, train_loss: 0.3041, val_loss: 1.1038\n", 825 | "Epoch [49], last_lr: 0.00000, train_loss: 0.3057, val_loss: 1.1039\n", 826 | "Epoch [50], last_lr: 0.00000, train_loss: 0.2969, val_loss: 1.1039\n", 827 | "Epoch [51], last_lr: 0.00000, train_loss: 0.3009, val_loss: 1.1039\n", 828 | "Epoch [52], last_lr: 0.00000, train_loss: 0.3028, val_loss: 1.1039\n", 829 | "Epoch [53], last_lr: 0.00000, train_loss: 0.3057, val_loss: 1.1039\n", 830 | "Epoch [54], last_lr: 0.00000, train_loss: 0.3009, val_loss: 1.1038\n", 831 | "Epoch [55], last_lr: 0.00000, train_loss: 0.3039, val_loss: 1.1038\n", 832 | "Epoch [56], last_lr: 0.00000, train_loss: 0.3031, val_loss: 1.1038\n", 833 | "Epoch [57], last_lr: 0.00000, train_loss: 0.3043, val_loss: 1.1039\n", 834 | "Epoch 58: reducing learning rate of group 0 to 1.0000e-07.\n", 835 | "Epoch [58], last_lr: 0.00000, train_loss: 0.3009, val_loss: 1.1039\n", 836 | "Epoch [59], last_lr: 0.00000, train_loss: 0.3036, val_loss: 1.1039\n", 837 | "Epoch [60], last_lr: 0.00000, train_loss: 0.3057, val_loss: 1.1039\n", 838 | "Epoch [61], last_lr: 0.00000, train_loss: 0.3021, val_loss: 1.1039\n", 839 | "Epoch [62], last_lr: 0.00000, train_loss: 0.3059, val_loss: 1.1039\n", 840 | "Epoch [63], last_lr: 0.00000, train_loss: 0.3051, val_loss: 1.1039\n", 841 | "Epoch [64], last_lr: 0.00000, train_loss: 0.3073, val_loss: 1.1039\n", 842 | "Epoch [65], last_lr: 0.00000, train_loss: 0.3056, val_loss: 1.1039\n", 843 | "Epoch [66], last_lr: 0.00000, train_loss: 0.3067, val_loss: 1.1039\n", 844 | "Epoch [67], last_lr: 0.00000, train_loss: 0.3114, val_loss: 1.1039\n", 845 | "Epoch [68], last_lr: 0.00000, train_loss: 0.3039, val_loss: 1.1039\n", 846 | "Epoch 69: reducing learning rate of group 0 to 1.0000e-08.\n", 847 | "Epoch [69], last_lr: 0.00000, train_loss: 0.2969, val_loss: 1.1039\n", 848 | "Epoch [70], last_lr: 0.00000, train_loss: 0.3046, val_loss: 1.1039\n", 849 | "Epoch [71], last_lr: 0.00000, train_loss: 0.3028, val_loss: 1.1039\n", 850 | "Epoch [72], last_lr: 0.00000, train_loss: 0.3094, val_loss: 1.1039\n", 851 | "Epoch [73], last_lr: 0.00000, train_loss: 0.3110, val_loss: 1.1039\n", 852 | "Epoch [74], last_lr: 0.00000, train_loss: 0.3052, val_loss: 1.1039\n", 853 | "Epoch [75], last_lr: 0.00000, train_loss: 0.2993, val_loss: 1.1039\n", 854 | "Epoch [76], last_lr: 0.00000, train_loss: 0.3104, val_loss: 1.1039\n", 855 | "Epoch [77], last_lr: 0.00000, train_loss: 0.3045, val_loss: 1.1039\n", 856 | "Epoch [78], last_lr: 0.00000, train_loss: 0.3094, val_loss: 1.1039\n", 857 | "Epoch [79], last_lr: 0.00000, train_loss: 0.3047, val_loss: 1.1039\n", 858 | "Epoch [80], last_lr: 0.00000, train_loss: 0.3046, val_loss: 1.1039\n", 859 | "Epoch [81], last_lr: 0.00000, train_loss: 0.3047, val_loss: 1.1039\n", 860 | "Epoch [82], last_lr: 0.00000, train_loss: 0.3123, val_loss: 1.1039\n", 861 | "Epoch [83], last_lr: 0.00000, train_loss: 0.3108, val_loss: 1.1039\n", 862 | "Epoch [84], last_lr: 0.00000, 
train_loss: 0.3012, val_loss: 1.1039\n", 863 | "Epoch [85], last_lr: 0.00000, train_loss: 0.3051, val_loss: 1.1039\n", 864 | "Epoch [86], last_lr: 0.00000, train_loss: 0.3100, val_loss: 1.1039\n", 865 | "Epoch [87], last_lr: 0.00000, train_loss: 0.3029, val_loss: 1.1039\n", 866 | "Epoch [88], last_lr: 0.00000, train_loss: 0.3065, val_loss: 1.1039\n", 867 | "Epoch [89], last_lr: 0.00000, train_loss: 0.3035, val_loss: 1.1039\n", 868 | "Epoch [90], last_lr: 0.00000, train_loss: 0.3059, val_loss: 1.1039\n", 869 | "Epoch [91], last_lr: 0.00000, train_loss: 0.2980, val_loss: 1.1039\n", 870 | "Epoch [92], last_lr: 0.00000, train_loss: 0.3091, val_loss: 1.1039\n", 871 | "Epoch [93], last_lr: 0.00000, train_loss: 0.3034, val_loss: 1.1039\n", 872 | "Epoch [94], last_lr: 0.00000, train_loss: 0.3047, val_loss: 1.1039\n", 873 | "Epoch [95], last_lr: 0.00000, train_loss: 0.2994, val_loss: 1.1039\n", 874 | "Epoch [96], last_lr: 0.00000, train_loss: 0.2968, val_loss: 1.1039\n", 875 | "Epoch [97], last_lr: 0.00000, train_loss: 0.3181, val_loss: 1.1039\n", 876 | "Epoch [98], last_lr: 0.00000, train_loss: 0.2990, val_loss: 1.1039\n", 877 | "Epoch [99], last_lr: 0.00000, train_loss: 0.3063, val_loss: 1.1039\n" 878 | ] 879 | }, 880 | { 881 | "data": { 882 | "text/plain": [ 883 | "" 884 | ] 885 | }, 886 | "execution_count": 19, 887 | "metadata": {}, 888 | "output_type": "execute_result" 889 | } 890 | ], 891 | "source": [ 892 | "device = 'cuda'\n", 893 | "gold_model = LSTMnetwork(text_embedding_dimension = text_embedding_dimension).to(device)\n", 894 | "\n", 895 | "epochs = 100\n", 896 | "save_path = \"saved_models/LSTM_STSB_100epochs_0to2_retrained.pt\"\n", 897 | "history = fit_one_cycle(epochs, gold_model, train_df_retain, val_df_retain, device = device, save_path = save_path)\n", 898 | "gold_model.load_state_dict(torch.load(save_path))" 899 | ] 900 | }, 901 | { 902 | "cell_type": "markdown", 903 | "id": "706b753b", 904 | "metadata": {}, 905 | "source": [ 906 | "### Evaluate the retrained model on various cohorts" 907 | ] 908 | }, 909 | { 910 | "cell_type": "code", 911 | "execution_count": 20, 912 | "id": "56387ab7", 913 | "metadata": {}, 914 | "outputs": [ 915 | { 916 | "data": { 917 | "text/plain": [ 918 | "{'Loss': 2.841952850846533}" 919 | ] 920 | }, 921 | "execution_count": 20, 922 | "metadata": {}, 923 | "output_type": "execute_result" 924 | } 925 | ], 926 | "source": [ 927 | "evaluate(model, test_df_retain, 'cuda')" 928 | ] 929 | }, 930 | { 931 | "cell_type": "code", 932 | "execution_count": 21, 933 | "id": "16bb9939", 934 | "metadata": {}, 935 | "outputs": [ 936 | { 937 | "data": { 938 | "text/plain": [ 939 | "{'Loss': 1.7932759574140893}" 940 | ] 941 | }, 942 | "execution_count": 21, 943 | "metadata": {}, 944 | "output_type": "execute_result" 945 | } 946 | ], 947 | "source": [ 948 | "evaluate(model, test_df_forget, 'cuda')" 949 | ] 950 | }, 951 | { 952 | "cell_type": "code", 953 | "execution_count": 22, 954 | "id": "f9dd8476", 955 | "metadata": {}, 956 | "outputs": [ 957 | { 958 | "data": { 959 | "text/plain": [ 960 | "{'Loss': 0.9258318761899855}" 961 | ] 962 | }, 963 | "execution_count": 22, 964 | "metadata": {}, 965 | "output_type": "execute_result" 966 | } 967 | ], 968 | "source": [ 969 | "evaluate(gold_model, test_df_retain, 'cuda')" 970 | ] 971 | }, 972 | { 973 | "cell_type": "code", 974 | "execution_count": 23, 975 | "id": "e51de613", 976 | "metadata": {}, 977 | "outputs": [ 978 | { 979 | "data": { 980 | "text/plain": [ 981 | "{'Loss': 7.221951234429545}" 982 | ] 983 | }, 984 | 
"execution_count": 23, 985 | "metadata": {}, 986 | "output_type": "execute_result" 987 | } 988 | ], 989 | "source": [ 990 | "evaluate(gold_model, test_df_forget, 'cuda')" 991 | ] 992 | }, 993 | { 994 | "cell_type": "markdown", 995 | "id": "72ce5e55", 996 | "metadata": {}, 997 | "source": [ 998 | "## Finetuning" 999 | ] 1000 | }, 1001 | { 1002 | "cell_type": "code", 1003 | "execution_count": 24, 1004 | "id": "c5b988cd", 1005 | "metadata": {}, 1006 | "outputs": [ 1007 | { 1008 | "name": "stdout", 1009 | "output_type": "stream", 1010 | "text": [ 1011 | "Epoch [0], last_lr: 0.00100, train_loss: 0.9928, val_loss: 1.1080\n", 1012 | "Epoch [1], last_lr: 0.00100, train_loss: 0.7369, val_loss: 1.0494\n", 1013 | "Epoch [2], last_lr: 0.00100, train_loss: 0.7062, val_loss: 1.1080\n", 1014 | "Epoch [3], last_lr: 0.00100, train_loss: 0.6652, val_loss: 1.0326\n", 1015 | "Epoch [4], last_lr: 0.00100, train_loss: 0.6550, val_loss: 1.0242\n", 1016 | "CPU times: user 2min 13s, sys: 45.9 ms, total: 2min 13s\n", 1017 | "Wall time: 1.46 s\n" 1018 | ] 1019 | }, 1020 | { 1021 | "data": { 1022 | "text/plain": [ 1023 | "" 1024 | ] 1025 | }, 1026 | "execution_count": 24, 1027 | "metadata": {}, 1028 | "output_type": "execute_result" 1029 | } 1030 | ], 1031 | "source": [ 1032 | "%%time\n", 1033 | "student_model = LSTMnetwork(text_embedding_dimension = text_embedding_dimension).to(device)\n", 1034 | "student_model.load_state_dict(torch.load(\"saved_models/LSTM_STSB_100epochs.pt\"))\n", 1035 | "epochs = 5\n", 1036 | "save_path = \"saved_models/LSTM_STSB_5epochs_4to5_Finetune_Forget.pt\"\n", 1037 | "history = fit_one_finetune_cycle(epochs, student_model, train_df_retain, val_df_retain, 0.001, device = device, save_path = save_path)\n", 1038 | "student_model.load_state_dict(torch.load(save_path))" 1039 | ] 1040 | }, 1041 | { 1042 | "cell_type": "code", 1043 | "execution_count": 25, 1044 | "id": "2952073d", 1045 | "metadata": {}, 1046 | "outputs": [ 1047 | { 1048 | "data": { 1049 | "text/plain": [ 1050 | "{'Loss': 1.0722167293689073}" 1051 | ] 1052 | }, 1053 | "execution_count": 25, 1054 | "metadata": {}, 1055 | "output_type": "execute_result" 1056 | } 1057 | ], 1058 | "source": [ 1059 | "evaluate(student_model, test_df_retain, 'cuda')" 1060 | ] 1061 | }, 1062 | { 1063 | "cell_type": "code", 1064 | "execution_count": 26, 1065 | "id": "0c421e03", 1066 | "metadata": {}, 1067 | "outputs": [ 1068 | { 1069 | "data": { 1070 | "text/plain": [ 1071 | "{'Loss': 5.554914107911729}" 1072 | ] 1073 | }, 1074 | "execution_count": 26, 1075 | "metadata": {}, 1076 | "output_type": "execute_result" 1077 | } 1078 | ], 1079 | "source": [ 1080 | "evaluate(student_model, test_df_forget, 'cuda')" 1081 | ] 1082 | }, 1083 | { 1084 | "cell_type": "markdown", 1085 | "id": "20853f1f", 1086 | "metadata": {}, 1087 | "source": [ 1088 | "## Amnesiac Finetuning" 1089 | ] 1090 | }, 1091 | { 1092 | "cell_type": "code", 1093 | "execution_count": 27, 1094 | "id": "9b8a0585", 1095 | "metadata": {}, 1096 | "outputs": [], 1097 | "source": [ 1098 | "mean = train_df['score'].mean()\n", 1099 | "sd = train_df['score'].std()\n", 1100 | "\n", 1101 | "random_preds = np.random.normal(loc=mean, scale=sd, size=(len(train_df[train_df['score'] < 2]),))\n", 1102 | "\n", 1103 | "amnesiac_finetune_df = train_df.copy()\n", 1104 | "amnesiac_finetune_df.loc[amnesiac_finetune_df['score'] < 2, 'score'] = random_preds" 1105 | ] 1106 | }, 1107 | { 1108 | "cell_type": "code", 1109 | "execution_count": 28, 1110 | "id": "2c35e52c", 1111 | "metadata": {}, 1112 | "outputs": [ 1113 | { 
1114 | "name": "stdout", 1115 | "output_type": "stream", 1116 | "text": [ 1117 | "Epoch [0], last_lr: 0.00100, train_loss: 1.4821, val_loss: 1.3775\n", 1118 | "Epoch [1], last_lr: 0.00100, train_loss: 1.3101, val_loss: 1.1982\n", 1119 | "Epoch [2], last_lr: 0.00100, train_loss: 1.2878, val_loss: 1.1332\n", 1120 | "Epoch [3], last_lr: 0.00100, train_loss: 1.2601, val_loss: 1.0947\n", 1121 | "Epoch [4], last_lr: 0.00100, train_loss: 1.2185, val_loss: 1.0959\n", 1122 | "CPU times: user 3min 2s, sys: 31.4 ms, total: 3min 2s\n", 1123 | "Wall time: 1.85 s\n" 1124 | ] 1125 | }, 1126 | { 1127 | "data": { 1128 | "text/plain": [ 1129 | "" 1130 | ] 1131 | }, 1132 | "execution_count": 28, 1133 | "metadata": {}, 1134 | "output_type": "execute_result" 1135 | } 1136 | ], 1137 | "source": [ 1138 | "%%time\n", 1139 | "student_model = LSTMnetwork(text_embedding_dimension = text_embedding_dimension).to(device)\n", 1140 | "student_model.load_state_dict(torch.load(\"saved_models/LSTM_STSB_100epochs.pt\"))\n", 1141 | "epochs = 5\n", 1142 | "save_path = \"saved_models/LSTM_STSB_2epochs_Amnesiac_Finetune_Forget.pt\"\n", 1143 | "history = fit_one_finetune_cycle(epochs, student_model, amnesiac_finetune_df, val_df_retain, 0.001, device = device, save_path = save_path)\n", 1144 | "student_model.load_state_dict(torch.load(save_path))" 1145 | ] 1146 | }, 1147 | { 1148 | "cell_type": "code", 1149 | "execution_count": 29, 1150 | "id": "683eca0b", 1151 | "metadata": {}, 1152 | "outputs": [ 1153 | { 1154 | "data": { 1155 | "text/plain": [ 1156 | "{'Loss': 1.2053828360775851}" 1157 | ] 1158 | }, 1159 | "execution_count": 29, 1160 | "metadata": {}, 1161 | "output_type": "execute_result" 1162 | } 1163 | ], 1164 | "source": [ 1165 | "evaluate(student_model, test_df_retain, 'cuda')" 1166 | ] 1167 | }, 1168 | { 1169 | "cell_type": "code", 1170 | "execution_count": 30, 1171 | "id": "79517fae", 1172 | "metadata": {}, 1173 | "outputs": [ 1174 | { 1175 | "data": { 1176 | "text/plain": [ 1177 | "{'Loss': 4.82110405608104}" 1178 | ] 1179 | }, 1180 | "execution_count": 30, 1181 | "metadata": {}, 1182 | "output_type": "execute_result" 1183 | } 1184 | ], 1185 | "source": [ 1186 | "evaluate(student_model, test_df_forget, 'cuda')" 1187 | ] 1188 | }, 1189 | { 1190 | "cell_type": "markdown", 1191 | "id": "215bb9e1", 1192 | "metadata": {}, 1193 | "source": [ 1194 | "## Blindspot Unlearning" 1195 | ] 1196 | }, 1197 | { 1198 | "cell_type": "code", 1199 | "execution_count": 31, 1200 | "id": "b7f20c7a", 1201 | "metadata": {}, 1202 | "outputs": [], 1203 | "source": [ 1204 | "u_train_df = train_df.copy()" 1205 | ] 1206 | }, 1207 | { 1208 | "cell_type": "code", 1209 | "execution_count": 32, 1210 | "id": "ca6e4e6f", 1211 | "metadata": {}, 1212 | "outputs": [], 1213 | "source": [ 1214 | "u_train_df['forget'] = 0\n", 1215 | "u_train_df.loc[u_train_df['score'] < 2, 'forget'] = 1" 1216 | ] 1217 | }, 1218 | { 1219 | "cell_type": "markdown", 1220 | "id": "e667bf9a", 1221 | "metadata": {}, 1222 | "source": [ 1223 | "### Training the Blindspot model" 1224 | ] 1225 | }, 1226 | { 1227 | "cell_type": "code", 1228 | "execution_count": 33, 1229 | "id": "84f6a8b9", 1230 | "metadata": {}, 1231 | "outputs": [ 1232 | { 1233 | "name": "stdout", 1234 | "output_type": "stream", 1235 | "text": [ 1236 | "Epoch [0], last_lr: 0.01000, train_loss: 5.3748, val_loss: 1.7657\n", 1237 | "Epoch [1], last_lr: 0.01000, train_loss: 1.2427, val_loss: 1.0975\n", 1238 | "Epoch [2], last_lr: 0.01000, train_loss: 0.8266, val_loss: 0.9669\n", 1239 | "Epoch [3], last_lr: 0.01000, 
train_loss: 0.7312, val_loss: 0.9212\n", 1240 | "Epoch [4], last_lr: 0.01000, train_loss: 0.7079, val_loss: 0.9279\n", 1241 | "Epoch [5], last_lr: 0.01000, train_loss: 0.6720, val_loss: 0.9683\n", 1242 | "Epoch [6], last_lr: 0.01000, train_loss: 0.6607, val_loss: 0.9701\n", 1243 | "Epoch [7], last_lr: 0.01000, train_loss: 0.6526, val_loss: 0.9699\n", 1244 | "Epoch [8], last_lr: 0.01000, train_loss: 0.6474, val_loss: 0.9719\n", 1245 | "Epoch [9], last_lr: 0.01000, train_loss: 0.6403, val_loss: 0.9790\n", 1246 | "CPU times: user 4min 24s, sys: 59.7 ms, total: 4min 24s\n", 1247 | "Wall time: 2.77 s\n" 1248 | ] 1249 | }, 1250 | { 1251 | "data": { 1252 | "text/plain": [ 1253 | "" 1254 | ] 1255 | }, 1256 | "execution_count": 33, 1257 | "metadata": {}, 1258 | "output_type": "execute_result" 1259 | } 1260 | ], 1261 | "source": [ 1262 | "%%time\n", 1263 | "device = 'cuda'\n", 1264 | "proxy_model = LSTMnetwork(text_embedding_dimension = text_embedding_dimension).to(device)\n", 1265 | "epochs = 10\n", 1266 | "save_path = \"saved_models/LSTM_STSB_blindspot.pt\"\n", 1267 | "history = fit_one_cycle(epochs, proxy_model, train_df_retain, val_df_retain, device = device, save_path = save_path)\n", 1268 | "proxy_model.load_state_dict(torch.load(save_path))" 1269 | ] 1270 | }, 1271 | { 1272 | "cell_type": "code", 1273 | "execution_count": 34, 1274 | "id": "f8b1d24f", 1275 | "metadata": {}, 1276 | "outputs": [ 1277 | { 1278 | "data": { 1279 | "text/plain": [ 1280 | "{'Loss': 0.9433637388558371}" 1281 | ] 1282 | }, 1283 | "execution_count": 34, 1284 | "metadata": {}, 1285 | "output_type": "execute_result" 1286 | } 1287 | ], 1288 | "source": [ 1289 | "evaluate(proxy_model, test_df_retain, 'cuda')" 1290 | ] 1291 | }, 1292 | { 1293 | "cell_type": "code", 1294 | "execution_count": 35, 1295 | "id": "7606e738", 1296 | "metadata": {}, 1297 | "outputs": [ 1298 | { 1299 | "data": { 1300 | "text/plain": [ 1301 | "{'Loss': 6.870501907695236}" 1302 | ] 1303 | }, 1304 | "execution_count": 35, 1305 | "metadata": {}, 1306 | "output_type": "execute_result" 1307 | } 1308 | ], 1309 | "source": [ 1310 | "evaluate(proxy_model, test_df_forget, 'cuda')" 1311 | ] 1312 | }, 1313 | { 1314 | "cell_type": "code", 1315 | "execution_count": 36, 1316 | "id": "e2550637", 1317 | "metadata": {}, 1318 | "outputs": [ 1319 | { 1320 | "name": "stdout", 1321 | "output_type": "stream", 1322 | "text": [ 1323 | "Epoch [0], last_lr: 0.00100, train_loss: 1.3176, val_loss: 2.6206\n", 1324 | "Epoch [1], last_lr: 0.00100, train_loss: 0.9288, val_loss: 3.1632\n", 1325 | "Epoch [2], last_lr: 0.00100, train_loss: 0.8149, val_loss: 3.2369\n", 1326 | "Epoch [3], last_lr: 0.00100, train_loss: 0.7305, val_loss: 3.2416\n", 1327 | "Epoch [4], last_lr: 0.00100, train_loss: 0.6652, val_loss: 3.2574\n", 1328 | "CPU times: user 4min 51s, sys: 149 ms, total: 4min 51s\n", 1329 | "Wall time: 6.44 s\n" 1330 | ] 1331 | }, 1332 | { 1333 | "data": { 1334 | "text/plain": [ 1335 | "" 1336 | ] 1337 | }, 1338 | "execution_count": 36, 1339 | "metadata": {}, 1340 | "output_type": "execute_result" 1341 | } 1342 | ], 1343 | "source": [ 1344 | "%%time\n", 1345 | "AT_beta = 50\n", 1346 | "student_model = LSTMnetwork(text_embedding_dimension = text_embedding_dimension).to(device)\n", 1347 | "student_model.load_state_dict(torch.load(\"saved_models/LSTM_STSB_100epochs.pt\"))\n", 1348 | "epochs = 5\n", 1349 | "save_path = \"saved_models/LSTM_STSB_unlearn.pt\"\n", 1350 | "history = fit_one_forget_cycle(epochs, student_model, proxy_model, u_train_df, val_df, lr = 0.001, device = device, 
save_path = save_path)\n",
1351 | "student_model.load_state_dict(torch.load(save_path))"
1352 | ]
1353 | },
1354 | {
1355 | "cell_type": "code",
1356 | "execution_count": 37,
1357 | "id": "4b5f314a",
1358 | "metadata": {},
1359 | "outputs": [
1360 | {
1361 | "data": {
1362 | "text/plain": [
1363 | "{'Loss': 0.9808568328147851}"
1364 | ]
1365 | },
1366 | "execution_count": 37,
1367 | "metadata": {},
1368 | "output_type": "execute_result"
1369 | }
1370 | ],
1371 | "source": [
1372 | "evaluate(student_model, test_df_retain, 'cuda')"
1373 | ]
1374 | },
1375 | {
1376 | "cell_type": "code",
1377 | "execution_count": 38,
1378 | "id": "2a0a2633",
1379 | "metadata": {},
1380 | "outputs": [
1381 | {
1382 | "data": {
1383 | "text/plain": [
1384 | "{'Loss': 6.394269830679027}"
1385 | ]
1386 | },
1387 | "execution_count": 38,
1388 | "metadata": {},
1389 | "output_type": "execute_result"
1390 | }
1391 | ],
1392 | "source": [
1393 | "evaluate(student_model, test_df_forget, 'cuda')"
1394 | ]
1395 | },
1396 | {
1397 | "cell_type": "code",
1398 | "execution_count": null,
1399 | "id": "cb568be7",
1400 | "metadata": {},
1401 | "outputs": [],
1402 | "source": []
1403 | },
1404 | {
1405 | "cell_type": "code",
1406 | "execution_count": null,
1407 | "id": "43a3f94e",
1408 | "metadata": {},
1409 | "outputs": [],
1410 | "source": []
1411 | }
1412 | ],
1413 | "metadata": {
1414 | "kernelspec": {
1415 | "display_name": "Python 3",
1416 | "language": "python",
1417 | "name": "python3"
1418 | },
1419 | "language_info": {
1420 | "codemirror_mode": {
1421 | "name": "ipython",
1422 | "version": 3
1423 | },
1424 | "file_extension": ".py",
1425 | "mimetype": "text/x-python",
1426 | "name": "python",
1427 | "nbconvert_exporter": "python",
1428 | "pygments_lexer": "ipython3",
1429 | "version": "3.9.7"
1430 | }
1431 | },
1432 | "nbformat": 4,
1433 | "nbformat_minor": 5
1434 | }
1435 | 
--------------------------------------------------------------------------------
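Before the forecasting utilities below, it is worth restating the objective that the unlearning cell above optimizes. A condensed sketch of the per-batch blindspot loss (illustrative; it reuses `attention_diff` from the notebook and assumes the batch contains both retained and forget samples, whereas the notebook handles the all-retain and all-forget cases explicitly):

```python
import torch.nn.functional as F

def blindspot_batch_loss(model_out, model_acts, proxy_out, proxy_acts,
                         labels, ulabels, at_beta=50):
    # Retained samples (ulabels == 0) are fit to their true labels; forget
    # samples (ulabels == 1) are pulled towards the outputs and attention
    # maps of a "blindspot" proxy model trained on retain data only.
    retain, forget = (ulabels == 0), (ulabels == 1)
    label_loss = F.mse_loss(model_out[retain], labels[retain])
    proxy_loss = F.mse_loss(model_out[forget], proxy_out[forget])
    for m_act, p_act in zip(model_acts, proxy_acts):
        proxy_loss = proxy_loss + at_beta * attention_diff(m_act[forget], p_act[forget])
    coeff = forget.float().mean()  # fraction of forget samples in the batch
    return coeff * proxy_loss + (1 - coeff) * label_loss
```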
38 | hyperparam_iterations: Default number of random search iterations for 39 | experiment. 40 | """ 41 | 42 | default_experiments = ['volatility', 'electricity', 'traffic', 'favorita'] 43 | 44 | def __init__(self, experiment='volatility', root_folder=None): 45 | """Creates configs based on default experiment chosen. 46 | 47 | Args: 48 | experiment: Name of experiment. 49 | root_folder: Root folder to save all outputs of training. 50 | """ 51 | 52 | if experiment not in self.default_experiments: 53 | raise ValueError('Unrecognised experiment={}'.format(experiment)) 54 | 55 | # Defines all relevant paths 56 | if root_folder is None: 57 | root_folder = os.path.join( 58 | os.path.dirname(os.path.realpath(__file__)), '..', 'outputs') 59 | print('Using root folder {}'.format(root_folder)) 60 | 61 | self.root_folder = root_folder 62 | self.experiment = experiment 63 | self.data_folder = os.path.join(root_folder, 'data', experiment) 64 | self.model_folder = os.path.join(root_folder, 'saved_models', experiment) 65 | self.results_folder = os.path.join(root_folder, 'results', experiment) 66 | 67 | # Creates folders if they don't exist 68 | for relevant_directory in [ 69 | self.root_folder, self.data_folder, self.model_folder, 70 | self.results_folder 71 | ]: 72 | if not os.path.exists(relevant_directory): 73 | os.makedirs(relevant_directory) 74 | 75 | @property 76 | def data_csv_path(self): 77 | csv_map = { 78 | 'volatility': 'formatted_omi_vol.csv', 79 | 'electricity': 'hourly_electricity.csv', 80 | 'traffic': 'hourly_data.csv', 81 | 'favorita': 'favorita_consolidated.csv' 82 | } 83 | 84 | return os.path.join(self.data_folder, csv_map[self.experiment]) 85 | 86 | @property 87 | def hyperparam_iterations(self): 88 | 89 | return 240 if self.experiment == 'volatility' else 60 90 | 91 | def make_data_formatter(self): 92 | """Gets a data formatter object for experiment. 93 | 94 | Returns: 95 | Default DataFormatter per experiment. 96 | """ 97 | 98 | data_formatter_class = { 99 | 'electricity': data_formatters.electricity.ElectricityFormatter 100 | } 101 | 102 | return data_formatter_class[self.experiment]() 103 | -------------------------------------------------------------------------------- /data_formatters/base.py: -------------------------------------------------------------------------------- 1 | """Default data formatting functions for experiments. 2 | 3 | For new datasets, inherit form GenericDataFormatter and implement 4 | all abstract functions. 5 | 6 | These dataset-specific methods: 7 | 1) Define the column and input types for tabular dataframes used by model 8 | 2) Perform the necessary input feature engineering & normalisation steps 9 | 3) Reverts the normalisation for predictions 10 | 4) Are responsible for train, validation and test splits 11 | 12 | 13 | """ 14 | 15 | import abc 16 | import enum 17 | 18 | 19 | # Type defintions 20 | class DataTypes(enum.IntEnum): 21 | """Defines numerical types of each column.""" 22 | REAL_VALUED = 0 23 | CATEGORICAL = 1 24 | DATE = 2 25 | 26 | 27 | class InputTypes(enum.IntEnum): 28 | """Defines input types of each column.""" 29 | TARGET = 0 30 | OBSERVED_INPUT = 1 31 | KNOWN_INPUT = 2 32 | STATIC_INPUT = 3 33 | ID = 4 # Single column used as an entity identifier 34 | TIME = 5 # Single column exclusively used as a time index 35 | 36 | 37 | class GenericDataFormatter(abc.ABC): 38 | """Abstract base class for all data formatters. 39 | 40 | User can implement the abstract methods below to perform dataset-specific 41 | manipulations. 
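  A minimal subclass sketch (illustrative only; `ToyFormatter` and its
  column names are hypothetical and not part of this repository; it relies
  on the DataTypes/InputTypes enums defined in this module):

    class ToyFormatter(GenericDataFormatter):
      _column_definition = [
          ('id', DataTypes.REAL_VALUED, InputTypes.ID),
          ('t', DataTypes.REAL_VALUED, InputTypes.TIME),
          ('y', DataTypes.REAL_VALUED, InputTypes.TARGET),
      ]

      def set_scalers(self, df):
        self._mean, self._std = df['y'].mean(), df['y'].std()

      def transform_inputs(self, df):
        out = df.copy()
        out['y'] = (out['y'] - self._mean) / self._std
        return out

      def format_predictions(self, df):
        # Toy inverse of the normalisation above.
        return df * self._std + self._mean

      def split_data(self, df):
        n = len(df)
        train, valid, test = df[:n // 2], df[n // 2:3 * n // 4], df[3 * n // 4:]
        self.set_scalers(train)  # calibrate on train only
        return tuple(self.transform_inputs(d) for d in (train, valid, test))

      def get_fixed_params(self):
        return {'total_time_steps': 30, 'num_encoder_steps': 24,
                'num_epochs': 10, 'early_stopping_patience': 3,
                'multiprocessing_workers': 2}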
42 | 
43 |   """
44 | 
45 |   @abc.abstractmethod
46 |   def set_scalers(self, df):
47 |     """Calibrates scalers using the data supplied."""
48 |     raise NotImplementedError()
49 | 
50 |   @abc.abstractmethod
51 |   def transform_inputs(self, df):
52 |     """Performs feature transformation."""
53 |     raise NotImplementedError()
54 | 
55 |   @abc.abstractmethod
56 |   def format_predictions(self, df):
57 |     """Reverts any normalisation to give predictions in original scale."""
58 |     raise NotImplementedError()
59 | 
60 |   @abc.abstractmethod
61 |   def split_data(self, df):
62 |     """Performs the default train, validation and test splits."""
63 |     raise NotImplementedError()
64 | 
65 |   @property
66 |   @abc.abstractmethod
67 |   def _column_definition(self):
68 |     """Defines order, input type and data type of each column."""
69 |     raise NotImplementedError()
70 | 
71 |   @abc.abstractmethod
72 |   def get_fixed_params(self):
73 |     """Defines the fixed parameters used by the model for training.
74 | 
75 |     Requires the following keys:
76 |       'total_time_steps': Defines the total number of time steps used by TFT
77 |       'num_encoder_steps': Determines length of LSTM encoder (i.e. history)
78 |       'num_epochs': Maximum number of epochs for training
79 |       'early_stopping_patience': Early stopping param for keras
80 |       'multiprocessing_workers': # of cpus for data processing
81 | 
82 | 
83 |     Returns:
84 |       A dictionary of fixed parameters, e.g.:
85 | 
86 |       fixed_params = {
87 |           'total_time_steps': 252 + 5,
88 |           'num_encoder_steps': 252,
89 |           'num_epochs': 100,
90 |           'early_stopping_patience': 5,
91 |           'multiprocessing_workers': 5,
92 |       }
93 |     """
94 |     raise NotImplementedError()
95 | 
96 |   # Shared functions across data-formatters
97 |   @property
98 |   def num_classes_per_cat_input(self):
99 |     """Returns number of categories per relevant input.
100 | 
101 |     This is subsequently required for keras embedding layers.
102 |     """
103 |     return self._num_classes_per_cat_input
104 | 
105 |   def get_num_samples_for_calibration(self):
106 |     """Gets the default number of training and validation samples.
107 | 
108 |     Use to sub-sample the data for network calibration; a value of -1 uses
109 |     all available samples.
110 | 
111 |     Returns:
112 |       Tuple of (training samples, validation samples)
113 |     """
114 |     return -1, -1
115 | 
116 |   def get_column_definition(self):
117 |     """Returns formatted column definition in order expected by the TFT."""
118 | 
119 |     column_definition = self._column_definition
120 | 
121 |     # Sanity checks first.
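    # (A column name may appear under more than one input type; e.g.
    # 'hours_from_start' is both the TIME index and a KNOWN input in the
    # electricity definition. The checks below only constrain ID and TIME.)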
122 | # Ensure only one ID and time column exist 123 | def _check_single_column(input_type): 124 | 125 | length = len([tup for tup in column_definition if tup[2] == input_type]) 126 | 127 | if length != 1: 128 | raise ValueError('Illegal number of inputs ({}) of type {}'.format( 129 | length, input_type)) 130 | 131 | _check_single_column(InputTypes.ID) 132 | _check_single_column(InputTypes.TIME) 133 | 134 | identifier = [tup for tup in column_definition if tup[2] == InputTypes.ID] 135 | time = [tup for tup in column_definition if tup[2] == InputTypes.TIME] 136 | real_inputs = [ 137 | tup for tup in column_definition if tup[1] == DataTypes.REAL_VALUED and 138 | tup[2] not in {InputTypes.ID, InputTypes.TIME} 139 | ] 140 | categorical_inputs = [ 141 | tup for tup in column_definition if tup[1] == DataTypes.CATEGORICAL and 142 | tup[2] not in {InputTypes.ID, InputTypes.TIME} 143 | ] 144 | 145 | return identifier + time + real_inputs + categorical_inputs 146 | 147 | def _get_input_columns(self): 148 | """Returns names of all input columns.""" 149 | return [ 150 | tup[0] 151 | for tup in self.get_column_definition() 152 | if tup[2] not in {InputTypes.ID, InputTypes.TIME} 153 | ] 154 | 155 | def _get_tft_input_indices(self): 156 | """Returns the relevant indexes and input sizes required by TFT.""" 157 | 158 | # Functions 159 | def _extract_tuples_from_data_type(data_type, defn): 160 | return [ 161 | tup for tup in defn if tup[1] == data_type and 162 | tup[2] not in {InputTypes.ID, InputTypes.TIME} 163 | ] 164 | 165 | def _get_locations(input_types, defn): 166 | return [i for i, tup in enumerate(defn) if tup[2] in input_types] 167 | 168 | # Start extraction 169 | column_definition = [ 170 | tup for tup in self.get_column_definition() 171 | if tup[2] not in {InputTypes.ID, InputTypes.TIME} 172 | ] 173 | 174 | categorical_inputs = _extract_tuples_from_data_type(DataTypes.CATEGORICAL, 175 | column_definition) 176 | real_inputs = _extract_tuples_from_data_type(DataTypes.REAL_VALUED, 177 | column_definition) 178 | 179 | locations = { 180 | 'input_size': 181 | len(self._get_input_columns()), 182 | 'output_size': 183 | len(_get_locations({InputTypes.TARGET}, column_definition)), 184 | 'category_counts': 185 | self.num_classes_per_cat_input, 186 | 'input_obs_loc': 187 | _get_locations({InputTypes.TARGET}, column_definition), 188 | 'static_input_loc': 189 | _get_locations({InputTypes.STATIC_INPUT}, column_definition), 190 | 'known_regular_inputs': 191 | _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, 192 | real_inputs), 193 | 'known_categorical_inputs': 194 | _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, 195 | categorical_inputs), 196 | } 197 | 198 | return locations 199 | 200 | def get_experiment_params(self): 201 | """Returns fixed model parameters for experiments.""" 202 | 203 | required_keys = [ 204 | 'total_time_steps', 'num_encoder_steps', 'num_epochs', 205 | 'early_stopping_patience', 'multiprocessing_workers' 206 | ] 207 | 208 | fixed_params = self.get_fixed_params() 209 | 210 | for k in required_keys: 211 | if k not in fixed_params: 212 | raise ValueError('Field {}'.format(k) + 213 | ' missing from fixed parameter definitions!') 214 | 215 | fixed_params['column_definition'] = self.get_column_definition() 216 | 217 | fixed_params.update(self._get_tft_input_indices()) 218 | 219 | return fixed_params 220 | -------------------------------------------------------------------------------- /data_formatters/electricity.py: 
-------------------------------------------------------------------------------- 1 | """Custom formatting functions for Electricity dataset. 2 | 3 | Defines dataset specific column definitions and data transformations. Uses 4 | entity specific z-score normalization. 5 | """ 6 | 7 | import data_formatters.base 8 | import data_formatters.utils as utils 9 | import pandas as pd 10 | import sklearn.preprocessing 11 | 12 | GenericDataFormatter = data_formatters.base.GenericDataFormatter 13 | DataTypes = data_formatters.base.DataTypes 14 | InputTypes = data_formatters.base.InputTypes 15 | 16 | 17 | class ElectricityFormatter(GenericDataFormatter): 18 | """Defines and formats data for the electricity dataset. 19 | 20 | Note that per-entity z-score normalization is used here, and is implemented 21 | across functions. 22 | 23 | Attributes: 24 | column_definition: Defines input and data type of column used in the 25 | experiment. 26 | identifiers: Entity identifiers used in experiments. 27 | """ 28 | 29 | _column_definition = [ 30 | ('id', DataTypes.REAL_VALUED, InputTypes.ID), 31 | ('hours_from_start', DataTypes.REAL_VALUED, InputTypes.TIME), 32 | ('power_usage', DataTypes.REAL_VALUED, InputTypes.TARGET), 33 | ('hour', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT), 34 | ('day_of_week', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT), 35 | ('hours_from_start', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT), 36 | ('categorical_id', DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT), 37 | ] 38 | 39 | def __init__(self): 40 | """Initialises formatter.""" 41 | 42 | self.identifiers = None 43 | self._real_scalers = None 44 | self._cat_scalers = None 45 | self._target_scaler = None 46 | self._num_classes_per_cat_input = None 47 | self._time_steps = self.get_fixed_params()['total_time_steps'] 48 | 49 | def split_data(self, df, valid_boundary=1315, test_boundary=1339): 50 | """Splits data frame into training-validation-test data frames. 51 | 52 | This also calibrates scaling object, and transforms data for each split. 53 | 54 | Args: 55 | df: Source data frame to split. 56 | valid_boundary: Starting year for validation data 57 | test_boundary: Starting year for test data 58 | 59 | Returns: 60 | Tuple of transformed (train, valid, test) data. 61 | """ 62 | 63 | print('Formatting train-valid-test splits.') 64 | 65 | index = df['days_from_start'] 66 | train = df.loc[index < valid_boundary] 67 | valid = df.loc[(index >= valid_boundary - 7) & (index < test_boundary)] 68 | test = df.loc[index >= test_boundary - 7] 69 | 70 | self.set_scalers(train) 71 | 72 | return (self.transform_inputs(data) for data in [train, valid, test]) 73 | 74 | def set_scalers(self, df): 75 | """Calibrates scalers using the data supplied. 76 | 77 | Args: 78 | df: Data to use to calibrate scalers. 
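    Note:
      A separate StandardScaler is fitted per entity id (one for the real
      inputs, one for the target); entities with fewer than
      `total_time_steps` rows are skipped.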
79 | """ 80 | print('Setting scalers with training data...') 81 | 82 | column_definitions = self.get_column_definition() 83 | id_column = utils.get_single_col_by_input_type(InputTypes.ID, 84 | column_definitions) 85 | target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, 86 | column_definitions) 87 | 88 | # Format real scalers 89 | real_inputs = utils.extract_cols_from_data_type( 90 | DataTypes.REAL_VALUED, column_definitions, 91 | {InputTypes.ID, InputTypes.TIME}) 92 | 93 | # Initialise scaler caches 94 | self._real_scalers = {} 95 | self._target_scaler = {} 96 | identifiers = [] 97 | for identifier, sliced in df.groupby(id_column): 98 | 99 | if len(sliced) >= self._time_steps: 100 | 101 | data = sliced[real_inputs].values 102 | targets = sliced[[target_column]].values 103 | self._real_scalers[identifier] \ 104 | = sklearn.preprocessing.StandardScaler().fit(data) 105 | 106 | self._target_scaler[identifier] \ 107 | = sklearn.preprocessing.StandardScaler().fit(targets) 108 | identifiers.append(identifier) 109 | 110 | # Format categorical scalers 111 | categorical_inputs = utils.extract_cols_from_data_type( 112 | DataTypes.CATEGORICAL, column_definitions, 113 | {InputTypes.ID, InputTypes.TIME}) 114 | 115 | categorical_scalers = {} 116 | num_classes = [] 117 | for col in categorical_inputs: 118 | # Set all to str so that we don't have mixed integer/string columns 119 | srs = df[col].apply(str) 120 | categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit( 121 | srs.values) 122 | num_classes.append(srs.nunique()) 123 | 124 | # Set categorical scaler outputs 125 | self._cat_scalers = categorical_scalers 126 | self._num_classes_per_cat_input = num_classes 127 | 128 | # Extract identifiers in case required 129 | self.identifiers = identifiers 130 | 131 | def transform_inputs(self, df): 132 | """Performs feature transformations. 133 | 134 | This includes both feature engineering, preprocessing and normalisation. 135 | 136 | Args: 137 | df: Data frame to transform. 138 | 139 | Returns: 140 | Transformed data frame. 141 | 142 | """ 143 | 144 | if self._real_scalers is None and self._cat_scalers is None: 145 | raise ValueError('Scalers have not been set!') 146 | 147 | # Extract relevant columns 148 | column_definitions = self.get_column_definition() 149 | id_col = utils.get_single_col_by_input_type(InputTypes.ID, 150 | column_definitions) 151 | real_inputs = utils.extract_cols_from_data_type( 152 | DataTypes.REAL_VALUED, column_definitions, 153 | {InputTypes.ID, InputTypes.TIME}) 154 | categorical_inputs = utils.extract_cols_from_data_type( 155 | DataTypes.CATEGORICAL, column_definitions, 156 | {InputTypes.ID, InputTypes.TIME}) 157 | 158 | # Transform real inputs per entity 159 | df_list = [] 160 | for identifier, sliced in df.groupby(id_col): 161 | 162 | # Filter out any trajectories that are too short 163 | if len(sliced) >= self._time_steps: 164 | sliced_copy = sliced.copy() 165 | sliced_copy[real_inputs] = self._real_scalers[identifier].transform( 166 | sliced_copy[real_inputs].values) 167 | df_list.append(sliced_copy) 168 | 169 | output = pd.concat(df_list, axis=0) 170 | 171 | # Format categorical inputs 172 | for col in categorical_inputs: 173 | string_df = df[col].apply(str) 174 | output[col] = self._cat_scalers[col].transform(string_df) 175 | 176 | return output 177 | 178 | def format_predictions(self, predictions): 179 | """Reverts any normalisation to give predictions in original scale. 180 | 181 | Args: 182 | predictions: Dataframe of model predictions. 
183 | 184 | Returns: 185 | Data frame of unnormalised predictions. 186 | """ 187 | 188 | if self._target_scaler is None: 189 | raise ValueError('Scalers have not been set!') 190 | 191 | column_names = predictions.columns 192 | 193 | df_list = [] 194 | for identifier, sliced in predictions.groupby('identifier'): 195 | sliced_copy = sliced.copy() 196 | target_scaler = self._target_scaler[identifier] 197 | 198 | for col in column_names: 199 | if col not in {'forecast_time', 'identifier'}: 200 | sliced_copy[col] = target_scaler.inverse_transform(sliced_copy[col]) 201 | df_list.append(sliced_copy) 202 | 203 | output = pd.concat(df_list, axis=0) 204 | 205 | return output 206 | 207 | # Default params 208 | def get_fixed_params(self): 209 | """Returns fixed model parameters for experiments.""" 210 | 211 | fixed_params = { 212 | 'total_time_steps': 8 * 24, 213 | 'num_encoder_steps': 7 * 24, 214 | 'num_epochs': 100, 215 | 'early_stopping_patience': 5, 216 | 'multiprocessing_workers': 5 217 | } 218 | 219 | return fixed_params 220 | 221 | def get_default_model_params(self): 222 | """Returns default optimised model parameters.""" 223 | 224 | model_params = { 225 | 'dropout_rate': 0.1, 226 | 'hidden_layer_size': 160, 227 | 'learning_rate': 0.001, 228 | 'minibatch_size': 64, 229 | 'max_gradient_norm': 0.01, 230 | 'num_heads': 4, 231 | 'stack_size': 1 232 | } 233 | 234 | return model_params 235 | 236 | def get_num_samples_for_calibration(self): 237 | """Gets the default number of training and validation samples. 238 | 239 | Use to sub-sample the data for network calibration and a value of -1 uses 240 | all available samples. 241 | 242 | Returns: 243 | Tuple of (training samples, validation samples) 244 | """ 245 | return 450000, 50000 246 | -------------------------------------------------------------------------------- /data_formatters/ts_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | 5 | class TSDataset(Dataset): 6 | ## Mostly adapted from original TFT Github, data_formatters 7 | def __init__(self, id_col, static_cols, time_col, input_cols, 8 | target_col, time_steps, max_samples, 9 | input_size, num_encoder_steps,num_static, 10 | output_size, data): 11 | 12 | self.time_steps = time_steps 13 | self.input_size = input_size 14 | self.output_size = output_size 15 | self.num_encoder_steps = num_encoder_steps 16 | 17 | 18 | data.sort_values(by=[id_col, time_col], inplace=True) 19 | print('Getting valid sampling locations.') 20 | 21 | valid_sampling_locations = [] 22 | split_data_map = {} 23 | for identifier, df in data.groupby(id_col): 24 | num_entries = len(df) 25 | if num_entries >= self.time_steps: 26 | valid_sampling_locations += [ 27 | (identifier, self.time_steps + i) 28 | for i in range(num_entries - self.time_steps + 1) 29 | ] 30 | split_data_map[identifier] = df 31 | 32 | 33 | 34 | if max_samples > 0 and len(valid_sampling_locations) > max_samples: 35 | print('Extracting {} samples...'.format(max_samples)) 36 | ranges = [valid_sampling_locations[i] for i in np.random.choice( 37 | len(valid_sampling_locations), max_samples, replace=False)] 38 | else: 39 | print('Max samples={} exceeds # available segments={}'.format( 40 | max_samples, len(valid_sampling_locations))) 41 | ranges = valid_sampling_locations 42 | #print(len(ranges)) 43 | self.inputs = np.zeros((len(ranges), self.time_steps, self.input_size)) 44 | self.outputs = np.zeros((len(ranges), self.time_steps, 
self.output_size)) 45 | self.time = np.empty((len(ranges), self.time_steps, 1)) 46 | self.identifiers = np.empty((len(ranges),self.time_steps, num_static)) 47 | for i, tup in enumerate(ranges): 48 | if ((i + 1) % 10000) == 0: 49 | print(i + 1, 'of', len(ranges), 'samples done...') 50 | identifier, start_idx = tup 51 | sliced = split_data_map[identifier].iloc[start_idx - 52 | self.time_steps:start_idx] 53 | self.inputs[i, :, :] = sliced[input_cols] 54 | self.outputs[i, :, :] = sliced[[target_col]] 55 | self.time[i, :, 0] = sliced[time_col] 56 | self.identifiers[i,:, :] = sliced[static_cols] 57 | 58 | self.sampled_data = { 59 | 'inputs': self.inputs, 60 | 'outputs': self.outputs[:, self.num_encoder_steps:, :], 61 | 'active_entries': np.ones_like(self.outputs[:, self.num_encoder_steps:, :]), 62 | 'time': self.time, 63 | 'identifier': self.identifiers 64 | } 65 | 66 | def __getitem__(self, index): 67 | s = { 68 | 'inputs': self.inputs[index], 69 | 'outputs': self.outputs[index, self.num_encoder_steps:, :], 70 | 'active_entries': np.ones_like(self.outputs[index, self.num_encoder_steps:, :]), 71 | 'time': self.time[index], 72 | 'identifier': self.identifiers[index] 73 | } 74 | 75 | return s 76 | def __len__(self): 77 | return self.inputs.shape[0] 78 | -------------------------------------------------------------------------------- /data_formatters/utils.py: -------------------------------------------------------------------------------- 1 | """Generic helper functions used across codebase.""" 2 | 3 | import os 4 | import pathlib 5 | import torch 6 | import numpy as np 7 | #import tensorflow as tf 8 | #from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file 9 | 10 | 11 | 12 | # Loss functions. 13 | def pytorch_quantile_loss(y, y_pred, quantile): 14 | """Computes quantile loss for tensorflow. 15 | 16 | Standard quantile loss as defined in the "Training Procedure" section of 17 | the main TFT paper 18 | 19 | Args: 20 | y: Targets 21 | y_pred: Predictions 22 | quantile: Quantile to use for loss calculations (between 0 & 1) 23 | 24 | Returns: 25 | Tensor for quantile loss. 26 | """ 27 | 28 | # Checks quantile 29 | if quantile < 0 or quantile > 1: 30 | raise ValueError( 31 | 'Illegal quantile value={}! Values should be between 0 and 1.'.format( 32 | quantile)) 33 | 34 | prediction_underflow = y - y_pred 35 | q_loss = quantile * torch.max(prediction_underflow, torch.zeros_like(prediction_underflow)) + ( 36 | 1. - quantile) * torch.max(-prediction_underflow, torch.zeros_like(prediction_underflow)) 37 | 38 | return torch.sum(q_loss, axis=-1) 39 | 40 | 41 | 42 | # Generic. 43 | def get_single_col_by_input_type(input_type, column_definition): 44 | """Returns name of single column. 45 | 46 | Args: 47 | input_type: Input type of column to extract 48 | column_definition: Column definition list for experiment 49 | """ 50 | 51 | l = [tup[0] for tup in column_definition if tup[2] == input_type] 52 | 53 | if len(l) != 1: 54 | raise ValueError('Invalid number of columns for {}'.format(input_type)) 55 | 56 | return l[0] 57 | 58 | 59 | def extract_cols_from_data_type(data_type, column_definition, 60 | excluded_input_types): 61 | """Extracts the names of columns that correspond to a define data_type. 62 | 63 | Args: 64 | data_type: DataType of columns to extract. 65 | column_definition: Column definition to use. 66 | excluded_input_types: Set of input types to exclude 67 | 68 | Returns: 69 | List of names for columns with data type specified. 
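  Example (illustrative, using the electricity column definition above):

    extract_cols_from_data_type(
        DataTypes.REAL_VALUED, column_definition,
        {InputTypes.ID, InputTypes.TIME})
    # -> ['power_usage', 'hour', 'day_of_week', 'hours_from_start']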
70 | """ 71 | return [ 72 | tup[0] 73 | for tup in column_definition 74 | if tup[1] == data_type and tup[2] not in excluded_input_types 75 | ] 76 | 77 | 78 | # Loss functions. 79 | def tensorflow_quantile_loss(y, y_pred, quantile): 80 | """Computes quantile loss for tensorflow. 81 | 82 | Standard quantile loss as defined in the "Training Procedure" section of 83 | the main TFT paper 84 | 85 | Args: 86 | y: Targets 87 | y_pred: Predictions 88 | quantile: Quantile to use for loss calculations (between 0 & 1) 89 | 90 | Returns: 91 | Tensor for quantile loss. 92 | """ 93 | 94 | # Checks quantile 95 | if quantile < 0 or quantile > 1: 96 | raise ValueError( 97 | 'Illegal quantile value={}! Values should be between 0 and 1.'.format( 98 | quantile)) 99 | 100 | prediction_underflow = y - y_pred 101 | q_loss = quantile * tf.maximum(prediction_underflow, 0.) + ( 102 | 1. - quantile) * tf.maximum(-prediction_underflow, 0.) 103 | 104 | return tf.reduce_sum(q_loss, axis=-1) 105 | 106 | 107 | def numpy_normalised_quantile_loss(y, y_pred, quantile): 108 | """Computes normalised quantile loss for numpy arrays. 109 | 110 | Uses the q-Risk metric as defined in the "Training Procedure" section of the 111 | main TFT paper. 112 | 113 | Args: 114 | y: Targets 115 | y_pred: Predictions 116 | quantile: Quantile to use for loss calculations (between 0 & 1) 117 | 118 | Returns: 119 | Float for normalised quantile loss. 120 | """ 121 | prediction_underflow = y - y_pred 122 | weighted_errors = quantile * np.maximum(prediction_underflow, 0.) \ 123 | + (1. - quantile) * np.maximum(-prediction_underflow, 0.) 124 | 125 | quantile_loss = weighted_errors.mean() 126 | normaliser = y.abs().mean() 127 | 128 | return 2 * quantile_loss / normaliser 129 | 130 | 131 | # OS related functions. 132 | def create_folder_if_not_exist(directory): 133 | """Creates folder if it doesn't exist. 134 | 135 | Args: 136 | directory: Folder path to create. 137 | """ 138 | # Also creates directories recursively 139 | pathlib.Path(directory).mkdir(parents=True, exist_ok=True) 140 | 141 | 142 | # Tensorflow related functions. 143 | def get_default_tensorflow_config(tf_device='gpu', gpu_id=0): 144 | """Creates tensorflow config for graphs to run on CPU or GPU. 145 | 146 | Specifies whether to run graph on gpu or cpu and which GPU ID to use for multi 147 | GPU machines. 148 | 149 | Args: 150 | tf_device: 'cpu' or 'gpu' 151 | gpu_id: GPU ID to use if relevant 152 | 153 | Returns: 154 | Tensorflow config. 155 | """ 156 | 157 | if tf_device == 'cpu': 158 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # for training on cpu 159 | tf_config = tf.ConfigProto( 160 | log_device_placement=False, device_count={'GPU': 0}) 161 | 162 | else: 163 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 164 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id) 165 | 166 | print('Selecting GPU ID={}'.format(gpu_id)) 167 | 168 | tf_config = tf.ConfigProto(log_device_placement=False) 169 | tf_config.gpu_options.allow_growth = True 170 | 171 | return tf_config 172 | 173 | 174 | def save(tf_session, model_folder, cp_name, scope=None): 175 | """Saves Tensorflow graph to checkpoint. 176 | 177 | Saves all trainiable variables under a given variable scope to checkpoint. 
178 | 179 | Args: 180 | tf_session: Session containing graph 181 | model_folder: Folder to save models 182 | cp_name: Name of Tensorflow checkpoint 183 | scope: Variable scope containing variables to save 184 | """ 185 | # Save model 186 | if scope is None: 187 | saver = tf.train.Saver() 188 | else: 189 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope) 190 | saver = tf.train.Saver(var_list=var_list, max_to_keep=100000) 191 | 192 | save_path = saver.save(tf_session, 193 | os.path.join(model_folder, '{0}.ckpt'.format(cp_name))) 194 | print('Model saved to: {0}'.format(save_path)) 195 | 196 | 197 | def load(tf_session, model_folder, cp_name, scope=None, verbose=False): 198 | """Loads Tensorflow graph from checkpoint. 199 | 200 | Args: 201 | tf_session: Session to load graph into 202 | model_folder: Folder containing serialised model 203 | cp_name: Name of Tensorflow checkpoint 204 | scope: Variable scope to use. 205 | verbose: Whether to print additional debugging information. 206 | """ 207 | # Load model proper 208 | load_path = os.path.join(model_folder, '{0}.ckpt'.format(cp_name)) 209 | 210 | print('Loading model from {0}'.format(load_path)) 211 | 212 | print_weights_in_checkpoint(model_folder, cp_name) 213 | 214 | initial_vars = set( 215 | [v.name for v in tf.get_default_graph().as_graph_def().node]) 216 | 217 | # Saver 218 | if scope is None: 219 | saver = tf.train.Saver() 220 | else: 221 | var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope) 222 | saver = tf.train.Saver(var_list=var_list, max_to_keep=100000) 223 | # Load 224 | saver.restore(tf_session, load_path) 225 | all_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node]) 226 | 227 | if verbose: 228 | print('Restored {0}'.format(','.join(initial_vars.difference(all_vars)))) 229 | print('Existing {0}'.format(','.join(all_vars.difference(initial_vars)))) 230 | print('All {0}'.format(','.join(all_vars))) 231 | 232 | print('Done.') 233 | 234 | 235 | def print_weights_in_checkpoint(model_folder, cp_name): 236 | """Prints all weights in Tensorflow checkpoint. 
237 | 238 | Args: 239 | model_folder: Folder containing checkpoint 240 | cp_name: Name of checkpoint 241 | 242 | Returns: 243 | 244 | """ 245 | load_path = os.path.join(model_folder, '{0}.ckpt'.format(cp_name)) 246 | 247 | print_tensors_in_checkpoint_file( 248 | file_name=load_path, 249 | tensor_name='', 250 | all_tensors=True, 251 | all_tensor_names=True) 252 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import pandas as pd 5 | from PIL import Image 6 | from torch.utils import data 7 | import torchvision.transforms as transforms 8 | 9 | class AgeDB(data.Dataset): 10 | def __init__(self, df, data_dir, img_size, split='train', reweight='none', 11 | lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2): 12 | self.df = df 13 | self.data_dir = data_dir 14 | self.img_size = img_size 15 | self.split = split 16 | 17 | def __len__(self): 18 | return len(self.df) 19 | 20 | def __getitem__(self, index): 21 | index = index % len(self.df) 22 | row = self.df.iloc[index] 23 | img = Image.open(os.path.join(self.data_dir, row['path'])).convert('RGB') 24 | transform = self.get_transform() 25 | img = transform(img) 26 | label = np.asarray([row['age']]).astype('float32') 27 | return img, label 28 | 29 | def get_transform(self): 30 | if self.split == 'train': 31 | transform = transforms.Compose([ 32 | transforms.Resize((self.img_size, self.img_size)), 33 | transforms.RandomCrop(self.img_size, padding=16), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]), 37 | ]) 38 | else: 39 | transform = transforms.Compose([ 40 | transforms.Resize((self.img_size, self.img_size)), 41 | transforms.ToTensor(), 42 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]), 43 | ]) 44 | return transform 45 | 46 | class TSDataset(data.Dataset): 47 | ## Mostly adapted from original TFT Github, data_formatters 48 | def __init__(self, id_col, static_cols, time_col, input_cols, 49 | target_col, time_steps, max_samples, 50 | input_size, num_encoder_steps,num_static, 51 | output_size, data): 52 | 53 | self.time_steps = time_steps 54 | self.input_size = input_size 55 | self.output_size = output_size 56 | self.num_encoder_steps = num_encoder_steps 57 | 58 | 59 | data.sort_values(by=[id_col, time_col], inplace=True) 60 | print('Getting valid sampling locations.') 61 | 62 | valid_sampling_locations = [] 63 | split_data_map = {} 64 | for identifier, df in data.groupby(id_col): 65 | num_entries = len(df) 66 | if num_entries >= self.time_steps: 67 | valid_sampling_locations += [ 68 | (identifier, self.time_steps + i) 69 | for i in range(num_entries - self.time_steps + 1) 70 | ] 71 | split_data_map[identifier] = df 72 | 73 | 74 | 75 | if max_samples > 0 and len(valid_sampling_locations) > max_samples: 76 | print('Extracting {} samples...'.format(max_samples)) 77 | ranges = [valid_sampling_locations[i] for i in np.random.choice( 78 | len(valid_sampling_locations), max_samples, replace=False)] 79 | else: 80 | print('Max samples={} exceeds # available segments={}'.format( 81 | max_samples, len(valid_sampling_locations))) 82 | ranges = valid_sampling_locations 83 | #print(len(ranges)) 84 | self.inputs = np.zeros((len(ranges), self.time_steps, self.input_size)) 85 | self.outputs = np.zeros((len(ranges), self.time_steps, self.output_size)) 86 | self.time = np.empty((len(ranges), 
self.time_steps, 1)) 87 | self.identifiers = np.empty((len(ranges),self.time_steps, num_static)) 88 | for i, tup in enumerate(ranges): 89 | if ((i + 1) % 10000) == 0: 90 | print(i + 1, 'of', len(ranges), 'samples done...') 91 | identifier, start_idx = tup 92 | sliced = split_data_map[identifier].iloc[start_idx - 93 | self.time_steps:start_idx] 94 | self.inputs[i, :, :] = sliced[input_cols] 95 | self.outputs[i, :, :] = sliced[[target_col]] 96 | self.time[i, :, 0] = sliced[time_col] 97 | self.identifiers[i,:, :] = sliced[static_cols] 98 | 99 | self.sampled_data = { 100 | 'inputs': self.inputs, 101 | 'outputs': self.outputs[:, self.num_encoder_steps:, :], 102 | 'active_entries': np.ones_like(self.outputs[:, self.num_encoder_steps:, :]), 103 | 'time': self.time, 104 | 'identifier': self.identifiers 105 | } 106 | 107 | def __getitem__(self, index): 108 | s = { 109 | 'inputs': self.inputs[index], 110 | 'outputs': self.outputs[index, self.num_encoder_steps:, :], 111 | 'active_entries': np.ones_like(self.outputs[index, self.num_encoder_steps:, :]), 112 | 'time': self.time[index], 113 | 'identifier': self.identifiers[index] 114 | } 115 | 116 | return s 117 | def __len__(self): 118 | return self.inputs.shape[0] 119 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | import pandas as pd 5 | from torch.utils.data import DataLoader 6 | from torch.nn import functional as F 7 | import tqdm 8 | import random 9 | from sklearn.svm import SVC 10 | import numpy as np 11 | from torch.utils import data 12 | from utils import evaluate, training_step 13 | 14 | def get_attack_features(data_loader, model, device='cuda'): 15 | data_loader = torch.utils.data.DataLoader(data_loader.dataset, batch_size=1, shuffle=False)#, num_workers = 8, prefetch_factor = 4) 16 | prefinal_gradients = [] 17 | prefinal_activations = [] 18 | predictions = [] 19 | labels = [] 20 | losses = [] 21 | 22 | model.eval() 23 | for batch in data_loader: 24 | data, target = batch 25 | data, target = data.to(device), target.to(device) 26 | output, *_ = model(data) 27 | 28 | loss = F.l1_loss(output, target) 29 | 30 | 31 | predictions.append(output.detach().cpu()) 32 | labels.append(target.detach().cpu()) 33 | losses.append(loss.detach().cpu()) 34 | 35 | predictions = torch.squeeze(torch.cat(predictions, axis = 0)) 36 | labels = torch.squeeze(torch.cat(labels, axis = 0)) 37 | losses = torch.tensor(losses) 38 | 39 | return predictions, labels, losses 40 | 41 | def get_membership_attack_data(retain_loader, test_loader, model, prediction_loaders = None): 42 | retain_preds, retain_labels, retain_losses = get_attack_features(retain_loader, model) 43 | test_preds, test_labels, test_losses = get_attack_features(test_loader, model) 44 | 45 | prediction_outputs = {} 46 | if prediction_loaders is not None: 47 | for prediction_set, prediction_loader in prediction_loaders.items(): 48 | prediction_preds, prediction_labels, prediction_losses = get_attack_features(prediction_loader, model) 49 | 50 | prediction_outputs[prediction_set] = {'prediction_preds':prediction_preds, 51 | 'prediction_labels':prediction_labels, 52 | 'prediction_losses':prediction_losses} 53 | 54 | prediction_outputs['train'] = {'prediction_preds':retain_preds, 55 | 'prediction_labels':retain_labels, 56 | 'prediction_losses':retain_losses} 57 | 58 | prediction_outputs['test'] = 
{'prediction_preds':test_preds, 59 | 'prediction_labels':test_labels, 60 | 'prediction_losses':test_losses} 61 | 62 | X_train = torch.cat([retain_losses, test_losses], axis = 0).numpy() 63 | Y_train = np.concatenate([np.ones(len(retain_losses)), np.zeros(len(test_losses))]) 64 | 65 | index_shuf = list(range(len(X_train))) 66 | random.Random().shuffle(index_shuf) 67 | X_train = X_train[index_shuf] 68 | Y_train = Y_train[index_shuf] 69 | 70 | return X_train, Y_train, prediction_outputs 71 | 72 | def get_membership_attack_prob(retain_loader, test_loader, model, prediction_loaders = None): 73 | X_train, Y_train, prediction_outputs = get_membership_attack_data(retain_loader, test_loader, model, prediction_loaders) 74 | clf = SVC(C=3,gamma='auto',kernel='rbf') 75 | clf.fit(X_train[:, np.newaxis], Y_train[:, np.newaxis]) 76 | 77 | results = {} 78 | for prediction_set, features in prediction_outputs.items(): 79 | attack_result = clf.predict(features['prediction_losses'][:, np.newaxis]) 80 | results[prediction_set] = attack_result.mean() 81 | return results 82 | 83 | def relearn_time(model, train_loader, valid_loader, reqAcc, lr, device = 'cuda'): 84 | rltime = 0 85 | curr_Acc = 0 86 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 87 | 88 | for epoch in range(10): 89 | 90 | for batch in train_loader: 91 | model.train() 92 | loss = training_step(model, batch, device) 93 | loss.backward() 94 | 95 | optimizer.step() 96 | optimizer.zero_grad() 97 | history = [evaluate(model, valid_loader, device)] 98 | curr_Acc = history[0]["Loss"] 99 | 100 | rltime += 1 101 | if(curr_Acc <= reqAcc): 102 | break 103 | 104 | if(curr_Acc <= reqAcc): 105 | break 106 | return rltime 107 | 108 | def ain(full_model, model, gold_model, train_data, val_retain, val_forget, 109 | batch_size = 256, error_range = 0.05, lr = 0.001, device = 'cuda'): 110 | # measuring performance of fully trained model on forget class 111 | forget_valid_dl = DataLoader(val_forget, batch_size, num_workers = 64) 112 | history = [evaluate(full_model, forget_valid_dl, device)] 113 | LossForget = history[0]["Loss"] 114 | 115 | retain_valid_dl = DataLoader(val_retain, batch_size, num_workers = 64) 116 | history = [evaluate(full_model, retain_valid_dl, device)] 117 | LossRetain = history[0]["Loss"] 118 | 119 | history = [evaluate(model, forget_valid_dl, device)] 120 | LossForget_Fmodel = history[0]["Loss"] 121 | 122 | history = [evaluate(model, retain_valid_dl, device)] 123 | LossRetain_Fmodel = history[0]["Loss"] 124 | 125 | history = [evaluate(gold_model, forget_valid_dl, device)] 126 | LossForget_Gmodel = history[0]["Loss"] 127 | 128 | history = [evaluate(gold_model, retain_valid_dl, device)] 129 | LossRetain_Gmodel = history[0]["Loss"] 130 | 131 | reqLossF = (1+error_range)*LossForget 132 | 133 | train_loader = DataLoader(train_data, batch_size, shuffle = True, num_workers = 64) 134 | valid_loader = DataLoader(val_forget, batch_size, num_workers = 64) 135 | rltime_gold = relearn_time(model = gold_model, train_loader = train_loader, valid_loader = valid_loader, 136 | reqAcc = reqLossF, lr = lr, device = device) 137 | 138 | rltime_forget = relearn_time(model = model, train_loader = train_loader, valid_loader = valid_loader, 139 | reqAcc = reqLossF, lr = lr, device = device) 140 | 141 | rl_coeff = rltime_forget/rltime_gold 142 | return rl_coeff -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
math 3 | import numpy as np 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torchvision.models import mobilenet_v3_large 7 | from torchvision.models import resnet18 8 | 9 | class MobileNet(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | base = mobilenet_v3_large() 13 | base_list = [*list(base.children())[:-1]] 14 | self.conv_norm1 = nn.Sequential(*base_list[0][0]) 15 | for i in range(1, 16): 16 | exec(f"self.inverted_residual_{i} = base_list[0][{i}]") 17 | self.conv_norm2 = nn.Sequential(*base_list[0][16]) 18 | self.pool1 = base_list[1] 19 | self.drop = nn.Dropout() 20 | self.final = nn.Linear(960,1) 21 | 22 | def forward(self,x): 23 | actvn1 = self.conv_norm1(x) 24 | 25 | for i in range(1, 16): 26 | exec(f"actvn{i+1} = self.inverted_residual_{i}(actvn{i})", locals(), globals()) 27 | 28 | actvn17 = self.conv_norm2(actvn16) 29 | out = self.pool1(actvn17) 30 | 31 | out = self.drop(out.view(-1,self.final.in_features)) 32 | return self.final(out), actvn1, actvn2, actvn3, actvn4, actvn5, actvn6, actvn7,\ 33 | actvn8, actvn9, actvn10, actvn11, actvn12, actvn13, actvn14, actvn15,\ 34 | actvn16, actvn17 35 | 36 | class Identity(nn.Module): 37 | def __init__(self): 38 | super(Identity, self).__init__() 39 | 40 | def forward(self, x): 41 | return x 42 | 43 | class Flatten(nn.Module): 44 | def __init__(self): 45 | super(Flatten, self).__init__() 46 | def forward(self,x): 47 | return x.view(x.size(0), -1) 48 | 49 | class ConvStandard(nn.Conv2d): 50 | 51 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, output_padding=0, w_sig =\ 52 | np.sqrt(1.0)): 53 | super(ConvStandard, self).__init__(in_channels, out_channels,kernel_size) 54 | self.in_channels=in_channels 55 | self.out_channels=out_channels 56 | self.kernel_size=kernel_size 57 | self.stride=stride 58 | self.padding=padding 59 | self.w_sig = w_sig 60 | self.reset_parameters() 61 | 62 | def reset_parameters(self): 63 | torch.nn.init.normal_(self.weight, mean=0, std=self.w_sig/(self.in_channels*np.prod(self.kernel_size))) 64 | if self.bias is not None: 65 | torch.nn.init.normal_(self.bias, mean=0, std=0) 66 | 67 | def forward(self, input): 68 | return F.conv2d(input,self.weight,self.bias,self.stride,self.padding) 69 | 70 | class Conv(nn.Sequential): 71 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None, output_padding=0, 72 | activation_fn=nn.ReLU, batch_norm=True, transpose=False): 73 | if padding is None: 74 | padding = (kernel_size - 1) // 2 75 | model = [] 76 | if not transpose: 77 | # model += [ConvStandard(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding 78 | # )] 79 | model += [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, 80 | bias=not batch_norm)] 81 | else: 82 | model += [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, 83 | output_padding=output_padding, bias=not batch_norm)] 84 | if batch_norm: 85 | model += [nn.BatchNorm2d(out_channels, affine=True)] 86 | model += [activation_fn()] 87 | super(Conv, self).__init__(*model) 88 | 89 | class AllCNN(nn.Module): 90 | def __init__(self, filters_percentage=1., n_channels=3, num_classes=10, dropout=False, batch_norm=True): 91 | super(AllCNN, self).__init__() 92 | n_filter1 = int(96 * filters_percentage) 93 | n_filter2 = int(192 * filters_percentage) 94 | 95 | self.conv1 = Conv(n_channels, n_filter1, kernel_size=3, batch_norm=batch_norm) 96 | self.conv2 = 
Conv(n_filter1, n_filter1, kernel_size=3, batch_norm=batch_norm) 97 | self.conv3 = Conv(n_filter1, n_filter2, kernel_size=3, stride=2, padding=1, batch_norm=batch_norm) 98 | 99 | self.dropout1 = self.features = nn.Sequential(nn.Dropout(inplace=True) if dropout else Identity()) 100 | 101 | self.conv4 = Conv(n_filter2, n_filter2, kernel_size=3, stride=1, batch_norm=batch_norm) 102 | self.conv5 = Conv(n_filter2, n_filter2, kernel_size=3, stride=1, batch_norm=batch_norm) 103 | self.conv6 = Conv(n_filter2, n_filter2, kernel_size=3, stride=2, padding=1, batch_norm=batch_norm) 104 | 105 | self.dropout2 = self.features = nn.Sequential(nn.Dropout(inplace=True) if dropout else Identity()) 106 | 107 | self.conv7 = Conv(n_filter2, n_filter2, kernel_size=3, stride=1, batch_norm=batch_norm) 108 | self.conv8 = Conv(n_filter2, n_filter2, kernel_size=1, stride=1, batch_norm=batch_norm) 109 | self.pool = nn.AvgPool2d(8) 110 | self.flatten = Flatten() 111 | 112 | self.classifier = nn.Sequential( 113 | nn.Linear(n_filter2, num_classes), 114 | ) 115 | 116 | def forward(self, x): 117 | out = self.conv1(x) 118 | actv1 = out 119 | 120 | out = self.conv2(out) 121 | actv2 = out 122 | 123 | out = self.conv3(out) 124 | actv3 = out 125 | 126 | out = self.dropout1(out) 127 | 128 | out = self.conv4(out) 129 | actv4 = out 130 | 131 | out = self.conv5(out) 132 | actv5 = out 133 | 134 | out = self.conv6(out) 135 | actv6 = out 136 | 137 | out = self.dropout2(out) 138 | 139 | out = self.conv7(out) 140 | actv7 = out 141 | 142 | out = self.conv8(out) 143 | actv8 = out 144 | 145 | out = self.pool(out) 146 | 147 | out = self.flatten(out) 148 | 149 | out = self.classifier(out) 150 | 151 | return out, actv1, actv2, actv3, actv4, actv5, actv6, actv7, actv8 152 | 153 | class ResNet18(nn.Module): 154 | def __init__(self): 155 | super().__init__() 156 | base = resnet18(pretrained=False) 157 | in_features = base.fc.in_features 158 | base_list = [*list(base.children())[:-1]] 159 | self.layer1 = nn.Sequential(*base_list[0:3]) 160 | self.pool1 = base_list[3] 161 | self.basic_block1 = base_list[4][0] 162 | self.basic_block2 = base_list[4][1] 163 | self.basic_block3 = base_list[5][0] 164 | self.basic_block4 = base_list[5][1] 165 | self.basic_block5 = base_list[6][0] 166 | self.basic_block6 = base_list[6][1] 167 | self.basic_block7 = base_list[7][0] 168 | self.basic_block8 = base_list[7][1] 169 | self.pool2 = base_list[8] 170 | self.drop = nn.Dropout() 171 | self.final = nn.Linear(512,1) 172 | 173 | 174 | def forward(self,x): 175 | out = self.layer1(x) 176 | actvn1 = out 177 | 178 | out = self.pool1(out) 179 | 180 | out = self.basic_block1(out) 181 | actvn2 = out 182 | 183 | out = self.basic_block2(out) 184 | actvn3 = out 185 | 186 | out = self.basic_block3(out) 187 | actvn4 = out 188 | 189 | out = self.basic_block4(out) 190 | actvn5 = out 191 | 192 | out = self.basic_block5(out) 193 | actvn6 = out 194 | 195 | out = self.basic_block6(out) 196 | actvn7 = out 197 | 198 | out = self.basic_block7(out) 199 | actvn8 = out 200 | 201 | out = self.basic_block8(out) 202 | actvn9 = out 203 | 204 | out = self.pool2(out) 205 | out = out.view(-1,self.final.in_features) 206 | 207 | out = self.final(out) 208 | 209 | return out, actvn1, actvn2, actvn3, actvn4, actvn5, actvn6, actvn7, actvn8, actvn9 210 | 211 | class TimeDistributed(nn.Module): 212 | ## Takes any module and stacks the time dimension with the batch dimenison of inputs before apply the module 213 | ## From: 
https://discuss.pytorch.org/t/any-pytorch-function-can-work-as-keras-timedistributed/1346/4 214 | def __init__(self, module, batch_first=False): 215 | super(TimeDistributed, self).__init__() 216 | self.module = module 217 | self.batch_first = batch_first 218 | 219 | def forward(self, x): 220 | 221 | if len(x.size()) <= 2: 222 | return self.module(x) 223 | 224 | # Squash samples and timesteps into a single axis 225 | x_reshape = x.contiguous().view(-1, x.size(-1)) # (samples * timesteps, input_size) 226 | 227 | y = self.module(x_reshape) 228 | 229 | # We have to reshape Y 230 | if self.batch_first: 231 | y = y.contiguous().view(x.size(0), -1, y.size(-1)) # (samples, timesteps, output_size) 232 | else: 233 | y = y.view(-1, x.size(1), y.size(-1)) # (timesteps, samples, output_size) 234 | 235 | return y 236 | 237 | class GLU(nn.Module): 238 | #Gated Linear Unit 239 | def __init__(self, input_size): 240 | super(GLU, self).__init__() 241 | 242 | self.fc1 = nn.Linear(input_size,input_size) 243 | self.fc2 = nn.Linear(input_size, input_size) 244 | self.sigmoid = nn.Sigmoid() 245 | 246 | def forward(self, x): 247 | 248 | sig = self.sigmoid(self.fc1(x)) 249 | x = self.fc2(x) 250 | return torch.mul(sig, x) 251 | 252 | class GatedResidualNetwork(nn.Module): 253 | def __init__(self, input_size,hidden_state_size, output_size, dropout, hidden_context_size=None, batch_first=False): 254 | super(GatedResidualNetwork, self).__init__() 255 | self.input_size = input_size 256 | self.output_size = output_size 257 | self.hidden_context_size = hidden_context_size 258 | self.hidden_state_size=hidden_state_size 259 | self.dropout = dropout 260 | 261 | if self.input_size!=self.output_size: 262 | self.skip_layer = TimeDistributed(nn.Linear(self.input_size, self.output_size)) 263 | 264 | self.fc1 = TimeDistributed(nn.Linear(self.input_size, self.hidden_state_size), batch_first=batch_first) 265 | self.elu1 = nn.ELU() 266 | 267 | if self.hidden_context_size is not None: 268 | self.context = TimeDistributed(nn.Linear(self.hidden_context_size, self.hidden_state_size),batch_first=batch_first) 269 | 270 | self.fc2 = TimeDistributed(nn.Linear(self.hidden_state_size, self.output_size), batch_first=batch_first) 271 | self.elu2 = nn.ELU() 272 | 273 | self.dropout = nn.Dropout(self.dropout) 274 | self.bn = TimeDistributed(nn.BatchNorm1d(self.output_size),batch_first=batch_first) 275 | self.gate = TimeDistributed(GLU(self.output_size), batch_first=batch_first) 276 | 277 | def forward(self, x, context=None): 278 | 279 | if self.input_size!=self.output_size: 280 | residual = self.skip_layer(x) 281 | else: 282 | residual = x 283 | 284 | x = self.fc1(x) 285 | if context is not None: 286 | context = self.context(context) 287 | x = x+context 288 | x = self.elu1(x) 289 | 290 | x = self.fc2(x) 291 | x = self.dropout(x) 292 | x = self.gate(x) 293 | x = x+residual 294 | x = self.bn(x) 295 | 296 | return x 297 | 298 | class PositionalEncoder(torch.nn.Module): 299 | def __init__(self, d_model, max_seq_len=160): 300 | super().__init__() 301 | self.d_model = d_model 302 | pe = torch.zeros(max_seq_len, d_model) 303 | for pos in range(max_seq_len): 304 | for i in range(0, d_model, 2): 305 | pe[pos, i] = \ 306 | math.sin(pos / (10000 ** ((2 * i) / d_model))) 307 | pe[pos, i + 1] = \ 308 | math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model))) 309 | pe = pe.unsqueeze(0) 310 | self.register_buffer('pe', pe) 311 | 312 | def forward(self, x): 313 | with torch.no_grad(): 314 | x = x * math.sqrt(self.d_model) 315 | seq_len = x.size(0) 316 | pe = self.pe[:, 
:seq_len].view(seq_len,1,self.d_model) 317 | x = x + pe 318 | return x 319 | 320 | class VariableSelectionNetwork(nn.Module): 321 | def __init__(self, input_size, num_inputs, hidden_size, dropout, context=None): 322 | super(VariableSelectionNetwork, self).__init__() 323 | 324 | self.hidden_size = hidden_size 325 | self.input_size =input_size 326 | self.num_inputs = num_inputs 327 | self.dropout = dropout 328 | self.context=context 329 | 330 | if self.context is not None: 331 | self.flattened_grn = GatedResidualNetwork(self.num_inputs*self.input_size, self.hidden_size, self.num_inputs, self.dropout, self.context) 332 | else: 333 | self.flattened_grn = GatedResidualNetwork(self.num_inputs*self.input_size, self.hidden_size, self.num_inputs, self.dropout) 334 | 335 | 336 | self.single_variable_grns = nn.ModuleList() 337 | for i in range(self.num_inputs): 338 | self.single_variable_grns.append(GatedResidualNetwork(self.input_size, self.hidden_size, self.hidden_size, self.dropout)) 339 | 340 | self.softmax = nn.Softmax() 341 | 342 | def forward(self, embedding, context=None): 343 | if context is not None: 344 | sparse_weights = self.flattened_grn(embedding, context) 345 | else: 346 | sparse_weights = self.flattened_grn(embedding) 347 | 348 | sparse_weights = self.softmax(sparse_weights).unsqueeze(2) 349 | 350 | var_outputs = [] 351 | for i in range(self.num_inputs): 352 | ##select slice of embedding belonging to a single input 353 | var_outputs.append(self.single_variable_grns[i](embedding[:,:, (i*self.input_size) : (i+1)*self.input_size])) 354 | 355 | var_outputs = torch.stack(var_outputs, axis=-1) 356 | 357 | outputs = var_outputs*sparse_weights 358 | 359 | outputs = outputs.sum(axis=-1) 360 | 361 | return outputs, sparse_weights 362 | 363 | class TFT(nn.Module): 364 | def __init__(self, config): 365 | super(TFT, self).__init__() 366 | self.device = config['device'] 367 | self.batch_size=config['batch_size'] 368 | self.static_variables = config['static_variables'] 369 | self.encode_length = config['encode_length'] 370 | self.time_varying_categoical_variables = config['time_varying_categoical_variables'] 371 | self.time_varying_real_variables_encoder = config['time_varying_real_variables_encoder'] 372 | self.time_varying_real_variables_decoder = config['time_varying_real_variables_decoder'] 373 | self.num_input_series_to_mask = config['num_masked_series'] 374 | self.hidden_size = config['lstm_hidden_dimension'] 375 | self.lstm_layers = config['lstm_layers'] 376 | self.dropout = config['dropout'] 377 | self.embedding_dim = config['embedding_dim'] 378 | self.attn_heads = config['attn_heads'] 379 | self.num_quantiles = config['num_quantiles'] 380 | self.valid_quantiles = config['valid_quantiles'] 381 | self.seq_length = config['seq_length'] 382 | 383 | self.static_embedding_layers = nn.ModuleList() 384 | for i in range(self.static_variables): 385 | emb = nn.Embedding(config['static_embedding_vocab_sizes'][i], config['embedding_dim']).to(self.device) 386 | self.static_embedding_layers.append(emb) 387 | 388 | 389 | 390 | self.time_varying_embedding_layers = nn.ModuleList() 391 | for i in range(self.time_varying_categoical_variables): 392 | emb = TimeDistributed(nn.Embedding(config['time_varying_embedding_vocab_sizes'][i], config['embedding_dim']), batch_first=True).to(self.device) 393 | self.time_varying_embedding_layers.append(emb) 394 | 395 | self.time_varying_linear_layers = nn.ModuleList() 396 | for i in range(self.time_varying_real_variables_encoder): 397 | emb = TimeDistributed(nn.Linear(1, 
config['embedding_dim']), batch_first=True).to(self.device) 398 | self.time_varying_linear_layers.append(emb) 399 | 400 | self.encoder_variable_selection = VariableSelectionNetwork(config['embedding_dim'], 401 | (config['time_varying_real_variables_encoder'] + config['time_varying_categoical_variables']), 402 | self.hidden_size, 403 | self.dropout, 404 | config['embedding_dim']*config['static_variables']).to(self.device) 405 | 406 | self.decoder_variable_selection = VariableSelectionNetwork(config['embedding_dim'], 407 | (config['time_varying_real_variables_decoder'] + config['time_varying_categoical_variables']), 408 | self.hidden_size, 409 | self.dropout, 410 | config['embedding_dim']*config['static_variables']).to(self.device) 411 | 412 | 413 | self.lstm_encoder_input_size = config['embedding_dim']*(config['time_varying_real_variables_encoder'] + 414 | config['time_varying_categoical_variables'] + 415 | config['static_variables']) 416 | 417 | self.lstm_decoder_input_size = config['embedding_dim']*(config['time_varying_real_variables_decoder'] + 418 | config['time_varying_categoical_variables'] + 419 | config['static_variables']) 420 | 421 | 422 | self.lstm_encoder = nn.LSTM(input_size=self.hidden_size, 423 | hidden_size=self.hidden_size, 424 | num_layers=self.lstm_layers, 425 | dropout=config['dropout']).to(self.device) 426 | 427 | self.lstm_decoder = nn.LSTM(input_size=self.hidden_size, 428 | hidden_size=self.hidden_size, 429 | num_layers=self.lstm_layers, 430 | dropout=config['dropout']).to(self.device) 431 | 432 | self.post_lstm_gate = TimeDistributed(GLU(self.hidden_size)).to(self.device) 433 | self.post_lstm_norm = TimeDistributed(nn.BatchNorm1d(self.hidden_size)).to(self.device) 434 | 435 | self.static_enrichment = GatedResidualNetwork(self.hidden_size,self.hidden_size, self.hidden_size, self.dropout, config['embedding_dim']*self.static_variables).to(self.device) 436 | 437 | self.position_encoding = PositionalEncoder(self.hidden_size, self.seq_length).to(self.device) 438 | 439 | self.multihead_attn = nn.MultiheadAttention(self.hidden_size, self.attn_heads).to(self.device) 440 | self.post_attn_gate = TimeDistributed(GLU(self.hidden_size)).to(self.device) 441 | 442 | self.post_attn_norm = TimeDistributed(nn.BatchNorm1d(self.hidden_size, self.hidden_size)).to(self.device) 443 | self.pos_wise_ff = GatedResidualNetwork(self.hidden_size, self.hidden_size, self.hidden_size, self.dropout).to(self.device) 444 | 445 | self.pre_output_norm = TimeDistributed(nn.BatchNorm1d(self.hidden_size, self.hidden_size)).to(self.device) 446 | self.pre_output_gate = TimeDistributed(GLU(self.hidden_size)).to(self.device) 447 | 448 | self.output_layer = TimeDistributed(nn.Linear(self.hidden_size, self.num_quantiles), batch_first=True).to(self.device) 449 | 450 | def init_hidden(self): 451 | return torch.zeros(self.lstm_layers, self.batch_size, self.hidden_size, device=self.device) 452 | 453 | def apply_embedding(self, x, static_embedding, apply_masking): 454 | ###x should have dimensions (batch_size, timesteps, input_size) 455 | ## Apply masking is used to mask variables that should not be accessed after the encoding steps 456 | #Time-varying real embeddings 457 | if apply_masking: 458 | time_varying_real_vectors = [] 459 | for i in range(self.time_varying_real_variables_decoder): 460 | emb = self.time_varying_linear_layers[i+self.num_input_series_to_mask](x[:,:,i+self.num_input_series_to_mask].view(x.size(0), -1, 1)) 461 | time_varying_real_vectors.append(emb) 462 | time_varying_real_embedding = 
torch.cat(time_varying_real_vectors, dim=2) 463 | 464 | else: 465 | time_varying_real_vectors = [] 466 | for i in range(self.time_varying_real_variables_encoder): 467 | emb = self.time_varying_linear_layers[i](x[:,:,i].view(x.size(0), -1, 1)) 468 | time_varying_real_vectors.append(emb) 469 | time_varying_real_embedding = torch.cat(time_varying_real_vectors, dim=2) 470 | 471 | 472 | ##Time-varying categorical embeddings (ie hour) 473 | time_varying_categoical_vectors = [] 474 | for i in range(self.time_varying_categoical_variables): 475 | emb = self.time_varying_embedding_layers[i](x[:, :,self.time_varying_real_variables_encoder+i].view(x.size(0), -1, 1).long()) 476 | time_varying_categoical_vectors.append(emb) 477 | time_varying_categoical_embedding = torch.cat(time_varying_categoical_vectors, dim=2) 478 | 479 | ##repeat static_embedding for all timesteps 480 | static_embedding = torch.cat(time_varying_categoical_embedding.size(1)*[static_embedding]) 481 | static_embedding = static_embedding.view(time_varying_categoical_embedding.size(0),time_varying_categoical_embedding.size(1),-1 ) 482 | 483 | ##concatenate all embeddings 484 | embeddings = torch.cat([static_embedding,time_varying_categoical_embedding,time_varying_real_embedding], dim=2) 485 | 486 | return embeddings.view(-1,x.size(0),embeddings.size(2)) 487 | 488 | def encode(self, x, hidden=None): 489 | 490 | if hidden is None: 491 | hidden = self.init_hidden() 492 | 493 | output, (hidden, cell) = self.lstm_encoder(x, (hidden, hidden)) 494 | 495 | return output, hidden 496 | 497 | def decode(self, x, hidden=None): 498 | 499 | if hidden is None: 500 | hidden = self.init_hidden() 501 | 502 | output, (hidden, cell) = self.lstm_decoder(x, (hidden,hidden)) 503 | 504 | return output, hidden 505 | 506 | 507 | def forward(self, x): 508 | 509 | ##inputs should be in this order 510 | # static 511 | # time_varying_categorical 512 | # time_varying_real 513 | 514 | embedding_vectors = [] 515 | for i in range(self.static_variables): 516 | #only need static variable from the first timestep 517 | emb = self.static_embedding_layers[i](x['identifier'][:,0, i].long().to(self.device)) 518 | embedding_vectors.append(emb) 519 | 520 | ##Embedding and variable selection 521 | static_embedding = torch.cat(embedding_vectors, dim=1) 522 | embeddings_encoder = self.apply_embedding(x['inputs'][:,:self.encode_length,:].float().to(self.device), static_embedding, apply_masking=False) 523 | embeddings_decoder = self.apply_embedding(x['inputs'][:,self.encode_length:,:].float().to(self.device), static_embedding, apply_masking=True) 524 | embeddings_encoder, encoder_sparse_weights = self.encoder_variable_selection(embeddings_encoder[:,:,:-(self.embedding_dim*self.static_variables)],embeddings_encoder[:,:,-(self.embedding_dim*self.static_variables):]) 525 | embeddings_decoder, decoder_sparse_weights = self.decoder_variable_selection(embeddings_decoder[:,:,:-(self.embedding_dim*self.static_variables)],embeddings_decoder[:,:,-(self.embedding_dim*self.static_variables):]) 526 | 527 | 528 | pe = self.position_encoding(torch.zeros(self.seq_length, 1, embeddings_encoder.size(2)).to(self.device)).to(self.device) 529 | 530 | embeddings_encoder = embeddings_encoder+pe[:self.encode_length,:,:] 531 | embeddings_decoder = embeddings_decoder+pe[self.encode_length:,:,:] 532 | 533 | ##LSTM 534 | lstm_input = torch.cat([embeddings_encoder,embeddings_decoder], dim=0) 535 | encoder_output, hidden = self.encode(embeddings_encoder) 536 | decoder_output, _ = self.decode(embeddings_decoder, 
hidden) 537 | lstm_output = torch.cat([encoder_output, decoder_output], dim=0) 538 | 539 | ##skip connection over lstm 540 | lstm_output = self.post_lstm_gate(lstm_output+lstm_input) 541 | 542 | ##static enrichment 543 | static_embedding = torch.cat(lstm_output.size(0)*[static_embedding]).view(lstm_output.size(0), lstm_output.size(1), -1) 544 | #print(lstm_output.device) 545 | #print(static_embedding.device) 546 | #print(self.static_enrichment.device) 547 | attn_input = self.static_enrichment(lstm_output, static_embedding) 548 | 549 | ##post-lstm norm (note: this reassigns attn_input, so the static enrichment output above is not used) 550 | attn_input = self.post_lstm_norm(lstm_output) 551 | 552 | #attn_input = self.position_encoding(attn_input) 553 | 554 | ##Attention 555 | attn_output, attn_output_weights = self.multihead_attn(attn_input[self.encode_length:,:,:], attn_input[:self.encode_length,:,:], attn_input[:self.encode_length,:,:]) 556 | 557 | ##skip connection over attention 558 | attn_output = self.post_attn_gate(attn_output) + attn_input[self.encode_length:,:,:] 559 | attn_output = self.post_attn_norm(attn_output) 560 | 561 | output = self.pos_wise_ff(attn_output) #[self.encode_length:,:,:]) 562 | 563 | ##skip connection over decoder 564 | output = self.pre_output_gate(output) + lstm_output[self.encode_length:,:,:] 565 | 566 | #Final output layers 567 | output = self.pre_output_norm(output) 568 | output = self.output_layer(output.view(self.batch_size, -1, self.hidden_size)) 569 | 570 | 571 | return output, encoder_output, decoder_output, attn_output, attn_output_weights, encoder_sparse_weights, decoder_sparse_weights 572 | 573 | class LSTMnetwork(nn.Module): 574 | def __init__(self, text_embedding_dimension): 575 | super().__init__() 576 | self.hidden_size = 64 577 | self.input_size = text_embedding_dimension 578 | self.num_layers = 1 579 | self.bidirectional = False 580 | self.num_directions = 1 581 | self.dropout1 = nn.Dropout(p=0.3) 582 | 583 | if self.bidirectional: 584 | self.num_directions = 2 585 | 586 | self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, 587 | bidirectional=self.bidirectional, batch_first=True) 588 | 589 | self.linear1 = nn.Linear(self.hidden_size*self.num_directions*2, 64) # *2: the two sentence encodings are concatenated 590 | self.linear2 = nn.Linear(64, 32) 591 | self.linear3 = nn.Linear(32, 16) 592 | self.linear4 = nn.Linear(16, 1) 593 | self.relu = nn.ReLU() 594 | 595 | def forward(self, sent1, sent2): 596 | 597 | lstm_out1, _ = self.lstm(sent1) 598 | 599 | x1 = self.dropout1(lstm_out1) 600 | 601 | actv1 = x1 # intermediate activations are returned for attention-based unlearning losses 602 | 603 | lstm_out2, _ = self.lstm(sent2) 604 | 605 | x2 = self.dropout1(lstm_out2) 606 | 607 | actv2 = x2 608 | 609 | output = self.linear1(torch.cat([x1[:, -1, :], x2[:, -1, :]], dim=1)) 610 | actv3 = output 611 | output = self.relu(output) 612 | 613 | 614 | output = self.linear2(output) 615 | actv4 = output 616 | output = self.relu(output) 617 | 618 | 619 | output = self.linear3(output) 620 | output = self.relu(output) 621 | output = self.linear4(output) 622 | 623 | return torch.squeeze(output), actv1, actv2, actv3, actv4 -------------------------------------------------------------------------------- /script_download_data.py: -------------------------------------------------------------------------------- 1 | """Script to download data for a default experiment. 2 | 3 | Only downloads data if the csv files are not already present, unless the "force_download" 4 | argument is supplied. For new datasets, the download_and_unzip(.)
can be reused 5 | to pull csv files from an online repository, but may require subsequent 6 | dataset-specific processing. 7 | 8 | Usage: 9 | python3 script_download_data.py {EXPT_NAME} {OUTPUT_FOLDER} {FORCE_DOWNLOAD} 10 | 11 | Command line args: 12 | EXPT_NAME: Name of experiment to download data for {e.g. volatility} 13 | OUTPUT_FOLDER: Path to folder in which downloaded data is stored. 14 | FORCE_DOWNLOAD: Whether to force data download from scratch. 15 | 16 | 17 | 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import argparse 25 | 26 | import gc 27 | import glob 28 | import os 29 | import shutil 30 | import sys 31 | 32 | from expt_settings.configs import ExperimentConfig 33 | import numpy as np 34 | import pandas as pd 35 | import pyunpack 36 | import wget 37 | 38 | 39 | # General functions for data downloading & aggregation. 40 | def download_from_url(url, output_path): 41 | """Downloads a file from a url.""" 42 | 43 | print('Pulling data from {} to {}'.format(url, output_path)) 44 | wget.download(url, output_path) 45 | print('done') 46 | 47 | 48 | def recreate_folder(path): 49 | """Deletes and recreates folder.""" 50 | 51 | shutil.rmtree(path) 52 | os.makedirs(path) 53 | 54 | 55 | def unzip(zip_path, output_file, data_folder): 56 | """Unzips files and checks successful completion.""" 57 | 58 | print('Unzipping file: {}'.format(zip_path)) 59 | pyunpack.Archive(zip_path).extractall(data_folder) 60 | 61 | # Checks if unzip was successful 62 | if not os.path.exists(output_file): 63 | raise ValueError( 64 | 'Error in unzipping process! {} not found.'.format(output_file)) 65 | 66 | 67 | def download_and_unzip(url, zip_path, csv_path, data_folder): 68 | """Downloads and unzips an online csv file. 69 | 70 | Args: 71 | url: Web address 72 | zip_path: Path to download zip file 73 | csv_path: Expected path to csv file 74 | data_folder: Folder in which data is stored. 75 | """ 76 | 77 | download_from_url(url, zip_path) 78 | 79 | unzip(zip_path, csv_path, data_folder) 80 | 81 | print('Done.') 82 | 83 | 84 | # Dataset specific download routines. 85 | def download_volatility(config): 86 | """Downloads volatility data from OMI website.""" 87 | 88 | url = 'https://realized.oxford-man.ox.ac.uk/images/oxfordmanrealizedvolatilityindices.zip' 89 | 90 | data_folder = config.data_folder 91 | csv_path = os.path.join(data_folder, 'oxfordmanrealizedvolatilityindices.csv') 92 | zip_path = os.path.join(data_folder, 'oxfordmanrealizedvolatilityindices.zip') 93 | 94 | download_and_unzip(url, zip_path, csv_path, data_folder) 95 | 96 | print('Unzip complete. Adding extra inputs') 97 | 98 | df = pd.read_csv(csv_path, index_col=0) # no explicit index 99 | 100 | # Adds additional date/day fields 101 | idx = [str(s).split('+')[0] for s in df.index 102 | ] # ignore timezones, we don't need them 103 | dates = pd.to_datetime(idx) 104 | df['date'] = dates 105 | df['days_from_start'] = (dates - pd.Timestamp(2000, 1, 3)).days # pd.datetime was removed in pandas 1.0 106 | df['day_of_week'] = dates.dayofweek 107 | df['day_of_month'] = dates.day 108 | df['week_of_year'] = dates.weekofyear 109 | df['month'] = dates.month 110 | df['year'] = dates.year 111 | df['categorical_id'] = df['Symbol'].copy() 112 | 113 | # Processes log volatility 114 | vol = df['rv5_ss'].copy() 115 | vol.loc[vol == 0.]
= np.nan 116 | df['log_vol'] = np.log(vol) 117 | 118 | # Adds static information 119 | symbol_region_mapping = { 120 | '.AEX': 'EMEA', 121 | '.AORD': 'APAC', 122 | '.BFX': 'EMEA', 123 | '.BSESN': 'APAC', 124 | '.BVLG': 'EMEA', 125 | '.BVSP': 'AMER', 126 | '.DJI': 'AMER', 127 | '.FCHI': 'EMEA', 128 | '.FTMIB': 'EMEA', 129 | '.FTSE': 'EMEA', 130 | '.GDAXI': 'EMEA', 131 | '.GSPTSE': 'AMER', 132 | '.HSI': 'APAC', 133 | '.IBEX': 'EMEA', 134 | '.IXIC': 'AMER', 135 | '.KS11': 'APAC', 136 | '.KSE': 'APAC', 137 | '.MXX': 'AMER', 138 | '.N225': 'APAC', 139 | '.NSEI': 'APAC', 140 | '.OMXC20': 'EMEA', 141 | '.OMXHPI': 'EMEA', 142 | '.OMXSPI': 'EMEA', 143 | '.OSEAX': 'EMEA', 144 | '.RUT': 'EMEA', 145 | '.SMSI': 'EMEA', 146 | '.SPX': 'AMER', 147 | '.SSEC': 'APAC', 148 | '.SSMI': 'EMEA', 149 | '.STI': 'APAC', 150 | '.STOXX50E': 'EMEA' 151 | } 152 | 153 | df['Region'] = df['Symbol'].apply(lambda k: symbol_region_mapping[k]) 154 | 155 | # Performs final processing 156 | output_df_list = [] 157 | for grp in df.groupby('Symbol'): 158 | sliced = grp[1].copy() 159 | sliced.sort_values('days_from_start', inplace=True) 160 | # Impute log volatility values 161 | sliced['log_vol'].fillna(method='ffill', inplace=True) 162 | sliced = sliced.dropna() # dropna returns a copy, so assign it back (a bare call is a no-op) 163 | output_df_list.append(sliced) 164 | 165 | df = pd.concat(output_df_list, axis=0) 166 | 167 | output_file = config.data_csv_path 168 | print('Completed formatting, saving to {}'.format(output_file)) 169 | df.to_csv(output_file) 170 | 171 | print('Done.') 172 | 173 | 174 | def download_electricity(config): 175 | """Downloads electricity dataset from UCI repository.""" 176 | 177 | url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00321/LD2011_2014.txt.zip' 178 | 179 | data_folder = config.data_folder 180 | csv_path = os.path.join(data_folder, 'LD2011_2014.txt') 181 | zip_path = csv_path + '.zip' 182 | 183 | download_and_unzip(url, zip_path, csv_path, data_folder) 184 | 185 | print('Aggregating to hourly data') 186 | 187 | df = pd.read_csv(csv_path, index_col=0, sep=';', decimal=',') 188 | df.index = pd.to_datetime(df.index) 189 | df.sort_index(inplace=True) 190 | 191 | # Used to determine the start and end dates of a series 192 | output = df.resample('1h').mean().replace(0., np.nan) 193 | 194 | earliest_time = output.index.min() 195 | 196 | df_list = [] 197 | for label in output: 198 | print('Processing {}'.format(label)) 199 | srs = output[label] 200 | 201 | start_date = min(srs.fillna(method='ffill').dropna().index) 202 | end_date = max(srs.fillna(method='bfill').dropna().index) 203 | 204 | active_range = (srs.index >= start_date) & (srs.index <= end_date) 205 | srs = srs[active_range].fillna(0.)
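# The per-series block below builds a regular hourly clock: 't' counts hours from
# earliest_time, computed from a Timedelta as seconds/3600 plus days*24. A small
# worked example of the same arithmetic (illustrative timestamps only):
#   delta = pd.Timestamp('2011-01-03 03:00') - pd.Timestamp('2011-01-01 00:00')
#   delta.seconds / 60 / 60 + delta.days * 24   # -> 3.0 + 2*24 = 51.0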
206 | 207 | tmp = pd.DataFrame({'power_usage': srs}) 208 | date = tmp.index 209 | tmp['t'] = (date - earliest_time).seconds / 60 / 60 + ( 210 | date - earliest_time).days * 24 211 | tmp['days_from_start'] = (date - earliest_time).days 212 | tmp['categorical_id'] = label 213 | tmp['date'] = date 214 | tmp['id'] = label 215 | tmp['hour'] = date.hour 216 | tmp['day'] = date.day 217 | tmp['day_of_week'] = date.dayofweek 218 | tmp['month'] = date.month 219 | 220 | df_list.append(tmp) 221 | 222 | output = pd.concat(df_list, axis=0, join='outer').reset_index(drop=True) 223 | 224 | output['categorical_id'] = output['id'].copy() 225 | output['hours_from_start'] = output['t'] 226 | output['categorical_day_of_week'] = output['day_of_week'].copy() 227 | output['categorical_hour'] = output['hour'].copy() 228 | 229 | # Filter to match range used by other academic papers 230 | output = output[(output['days_from_start'] >= 1096) 231 | & (output['days_from_start'] < 1346)].copy() 232 | 233 | output.to_csv(config.data_csv_path) 234 | 235 | print('Done.') 236 | 237 | 238 | def download_traffic(config): 239 | """Downloads traffic dataset from UCI repository.""" 240 | 241 | url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00204/PEMS-SF.zip' 242 | 243 | data_folder = config.data_folder 244 | csv_path = os.path.join(data_folder, 'PEMS_train') 245 | zip_path = os.path.join(data_folder, 'PEMS-SF.zip') 246 | 247 | download_and_unzip(url, zip_path, csv_path, data_folder) 248 | 249 | print('Aggregating to hourly data') 250 | 251 | def process_list(s, variable_type=int, delimiter=None): 252 | """Parses a line in the PEMS format to a list.""" 253 | if delimiter is None: 254 | l = [ 255 | variable_type(i) for i in s.replace('[', '').replace(']', '').split() 256 | ] 257 | else: 258 | l = [ 259 | variable_type(i) 260 | for i in s.replace('[', '').replace(']', '').split(delimiter) 261 | ] 262 | 263 | return l 264 | 265 | def read_single_list(filename): 266 | """Returns single list from a file in the PEMS-custom format.""" 267 | with open(os.path.join(data_folder, filename), 'r') as dat: 268 | l = process_list(dat.readlines()[0]) 269 | return l 270 | 271 | def read_matrix(filename): 272 | """Returns a matrix from a file in the PEMS-custom format.""" 273 | array_list = [] 274 | with open(os.path.join(data_folder, filename), 'r') as dat: 275 | 276 | lines = dat.readlines() 277 | for i, line in enumerate(lines): 278 | if (i + 1) % 50 == 0: 279 | print('Completed {} of {} rows for {}'.format(i + 1, len(lines), 280 | filename)) 281 | 282 | array = [ 283 | process_list(row_split, variable_type=float, delimiter=None) 284 | for row_split in process_list( 285 | line, variable_type=str, delimiter=';') 286 | ] 287 | array_list.append(array) 288 | 289 | return array_list 290 | 291 | shuffle_order = np.array(read_single_list('randperm')) - 1 # index from 0 292 | train_dayofweek = read_single_list('PEMS_trainlabels') 293 | train_tensor = read_matrix('PEMS_train') 294 | test_dayofweek = read_single_list('PEMS_testlabels') 295 | test_tensor = read_matrix('PEMS_test') 296 | 297 | # Invert the shuffle permutation 298 | print('Shuffling') 299 | inverse_mapping = { 300 | new_location: previous_location 301 | for previous_location, new_location in enumerate(shuffle_order) 302 | } 303 | reverse_shuffle_order = np.array([ 304 | inverse_mapping[new_location] 305 | for new_location, _ in enumerate(shuffle_order) 306 | ]) 307 | 308 | # Group and reorder based on permutation matrix 309 | print('Reordering') 310 | day_of_week =
np.array(train_dayofweek + test_dayofweek) 311 | combined_tensor = np.array(train_tensor + test_tensor) 312 | 313 | day_of_week = day_of_week[reverse_shuffle_order] 314 | combined_tensor = combined_tensor[reverse_shuffle_order] 315 | 316 | # Put everything back into a dataframe 317 | print('Parsing as dataframe') 318 | labels = ['traj_{}'.format(i) for i in read_single_list('stations_list')] 319 | 320 | hourly_list = [] 321 | for day, day_matrix in enumerate(combined_tensor): 322 | 323 | # Hourly data 324 | hourly = pd.DataFrame(day_matrix.T, columns=labels) 325 | hourly['hour_on_day'] = [int(i / 6) for i in hourly.index 326 | ] # sampled at 10 min intervals 327 | if hourly['hour_on_day'].max() > 23 or hourly['hour_on_day'].min() < 0: 328 | raise ValueError('Invalid hour! {}-{}'.format( 329 | hourly['hour_on_day'].min(), hourly['hour_on_day'].max())) 330 | 331 | hourly = hourly.groupby('hour_on_day', as_index=True).mean()[labels] 332 | hourly['sensor_day'] = day 333 | hourly['time_on_day'] = hourly.index 334 | hourly['day_of_week'] = day_of_week[day] 335 | 336 | hourly_list.append(hourly) 337 | 338 | hourly_frame = pd.concat(hourly_list, axis=0, ignore_index=True, sort=False) 339 | 340 | # Flatten such that each entity uses one row in dataframe 341 | store_columns = [c for c in hourly_frame.columns if 'traj' in c] 342 | other_columns = [c for c in hourly_frame.columns if 'traj' not in c] 343 | flat_df = pd.DataFrame(columns=['values', 'prev_values', 'next_values'] + 344 | other_columns + ['id']) 345 | 346 | def format_index_string(x): 347 | """Returns formatted string for key.""" 348 | 349 | if x < 10: 350 | return '00' + str(x) 351 | elif x < 100: 352 | return '0' + str(x) 353 | elif x < 1000: 354 | return str(x) 355 | 356 | raise ValueError('Invalid value of x {}'.format(x)) 357 | 358 | for store in store_columns: 359 | print('Processing {}'.format(store)) 360 | 361 | sliced = hourly_frame[[store] + other_columns].copy() 362 | sliced.columns = ['values'] + other_columns 363 | sliced['id'] = int(store.replace('traj_', '')) 364 | 365 | # Sort by Sensor-date-time 366 | key = sliced['id'].apply(str) \ 367 | + sliced['sensor_day'].apply(lambda x: '_' + format_index_string(x)) \ 368 | + sliced['time_on_day'].apply(lambda x: '_' + format_index_string(x)) 369 | sliced = sliced.set_index(key).sort_index() 370 | 371 | sliced['values'] = sliced['values'].fillna(method='ffill') 372 | sliced['prev_values'] = sliced['values'].shift(1) 373 | sliced['next_values'] = sliced['values'].shift(-1) 374 | 375 | flat_df = pd.concat([flat_df, sliced.dropna()], ignore_index=True, sort=False) # DataFrame.append was removed in pandas 2.0 376 | 377 | # Filter to match range used by other academic papers 378 | index = flat_df['sensor_day'] 379 | flat_df = flat_df[index < 173].copy() 380 | 381 | # Creating columns for categorical inputs 382 | flat_df['categorical_id'] = flat_df['id'].copy() 383 | flat_df['hours_from_start'] = flat_df['time_on_day'] \ 384 | + flat_df['sensor_day']*24. 385 | flat_df['categorical_day_of_week'] = flat_df['day_of_week'].copy() 386 | flat_df['categorical_time_on_day'] = flat_df['time_on_day'].copy() 387 | 388 | flat_df.to_csv(config.data_csv_path) 389 | print('Done.') 390 | 391 | 392 | def process_favorita(config): 393 | """Processes Favorita dataset.
394 | 395 | Makes use of raw files that should be manually downloaded from Kaggle @ 396 | https://www.kaggle.com/c/favorita-grocery-sales-forecasting/data 397 | 398 | Args: 399 | config: Default experiment config for Favorita 400 | """ 401 | 402 | url = 'https://www.kaggle.com/c/favorita-grocery-sales-forecasting/data' 403 | 404 | data_folder = config.data_folder 405 | 406 | # Save manual download to root folder to avoid deleting when re-processing. 407 | zip_file = os.path.join(data_folder, '..', 408 | 'favorita-grocery-sales-forecasting.zip') 409 | 410 | if not os.path.exists(zip_file): 411 | raise ValueError( 412 | 'Favorita zip file not found in {}!'.format(zip_file) + 413 | ' Please manually download data from Kaggle @ {}'.format(url)) 414 | 415 | # Unpack main zip file 416 | outputs_file = os.path.join(data_folder, 'train.csv.7z') 417 | unzip(zip_file, outputs_file, data_folder) 418 | 419 | # Unpack individually zipped files 420 | for file in glob.glob(os.path.join(data_folder, '*.7z')): 421 | 422 | csv_file = file.replace('.7z', '') 423 | 424 | unzip(file, csv_file, data_folder) 425 | 426 | print('Unzipping complete, commencing data processing...') 427 | 428 | # Extract only a subset of data to save/process for efficiency 429 | start_date = pd.Timestamp(2015, 1, 1) # pd.datetime was removed in pandas 1.0 430 | end_date = pd.Timestamp(2016, 6, 1) 431 | 432 | print('Regenerating data...') 433 | 434 | # load temporal data 435 | temporal = pd.read_csv(os.path.join(data_folder, 'train.csv'), index_col=0) 436 | 437 | store_info = pd.read_csv(os.path.join(data_folder, 'stores.csv'), index_col=0) 438 | oil = pd.read_csv( 439 | os.path.join(data_folder, 'oil.csv'), index_col=0).iloc[:, 0] 440 | holidays = pd.read_csv(os.path.join(data_folder, 'holidays_events.csv')) 441 | items = pd.read_csv(os.path.join(data_folder, 'items.csv'), index_col=0) 442 | transactions = pd.read_csv(os.path.join(data_folder, 'transactions.csv')) 443 | 444 | # Keep data inside the selected date window only 445 | temporal['date'] = pd.to_datetime(temporal['date']) 446 | 447 | # Filter dates to reduce storage space requirements 448 | if start_date is not None: 449 | temporal = temporal[(temporal['date'] >= start_date)] 450 | if end_date is not None: 451 | temporal = temporal[(temporal['date'] < end_date)] 452 | 453 | dates = temporal['date'].unique() 454 | 455 | # Add trajectory identifier 456 | temporal['traj_id'] = temporal['store_nbr'].apply( 457 | str) + '_' + temporal['item_nbr'].apply(str) 458 | temporal['unique_id'] = temporal['traj_id'] + '_' + temporal['date'].apply( 459 | str) 460 | 461 | # Remove all IDs with negative returns 462 | print('Removing returns data') 463 | min_returns = temporal['unit_sales'].groupby(temporal['traj_id']).min() 464 | valid_ids = set(min_returns[min_returns >= 0].index) 465 | selector = temporal['traj_id'].apply(lambda traj_id: traj_id in valid_ids) 466 | new_temporal = temporal[selector].copy() 467 | del temporal 468 | gc.collect() 469 | temporal = new_temporal 470 | temporal['open'] = 1 471 | 472 | # Resampling 473 | print('Resampling to regular grid') 474 | resampled_dfs = [] 475 | for traj_id, raw_sub_df in temporal.groupby('traj_id'): 476 | print('Resampling', traj_id) 477 | sub_df = raw_sub_df.set_index('date', drop=True).copy() 478 | sub_df = sub_df.resample('1d').last() 479 | sub_df['date'] = sub_df.index 480 | sub_df[['store_nbr', 'item_nbr', 'onpromotion']] \ 481 | = sub_df[['store_nbr', 'item_nbr', 'onpromotion']].fillna(method='ffill') 482 | sub_df['open'] = sub_df['open'].fillna( 483 | 0) # flag where sales data is unknown 484 |
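# unit_sales stays NaN on resampled grid days that had no raw record, so log_sales
# below is NaN there too (np.log propagates NaN); 'open' == 0 flags those rows.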
sub_df['log_sales'] = np.log(sub_df['unit_sales']) 485 | 486 | resampled_dfs.append(sub_df.reset_index(drop=True)) 487 | 488 | new_temporal = pd.concat(resampled_dfs, axis=0) 489 | del temporal 490 | gc.collect() 491 | temporal = new_temporal 492 | 493 | print('Adding oil') 494 | oil.name = 'oil' 495 | oil.index = pd.to_datetime(oil.index) 496 | temporal = temporal.join( 497 | oil.loc[dates].fillna(method='ffill'), on='date', how='left') 498 | temporal['oil'] = temporal['oil'].fillna(-1) 499 | 500 | print('Adding store info') 501 | temporal = temporal.join(store_info, on='store_nbr', how='left') 502 | 503 | print('Adding item info') 504 | temporal = temporal.join(items, on='item_nbr', how='left') 505 | 506 | transactions['date'] = pd.to_datetime(transactions['date']) 507 | temporal = temporal.merge( 508 | transactions, 509 | left_on=['date', 'store_nbr'], 510 | right_on=['date', 'store_nbr'], 511 | how='left') 512 | temporal['transactions'] = temporal['transactions'].fillna(-1) 513 | 514 | # Additional date info 515 | temporal['day_of_week'] = pd.to_datetime(temporal['date'].values).dayofweek 516 | temporal['day_of_month'] = pd.to_datetime(temporal['date'].values).day 517 | temporal['month'] = pd.to_datetime(temporal['date'].values).month 518 | 519 | # Add holiday info 520 | print('Adding holidays') 521 | holiday_subset = holidays[holidays['transferred'].apply( 522 | lambda x: not x)].copy() 523 | holiday_subset.columns = [ 524 | s if s != 'type' else 'holiday_type' for s in holiday_subset.columns 525 | ] 526 | holiday_subset['date'] = pd.to_datetime(holiday_subset['date']) 527 | local_holidays = holiday_subset[holiday_subset['locale'] == 'Local'] 528 | regional_holidays = holiday_subset[holiday_subset['locale'] == 'Regional'] 529 | national_holidays = holiday_subset[holiday_subset['locale'] == 'National'] 530 | 531 | temporal['national_hol'] = temporal.merge( 532 | national_holidays, left_on=['date'], right_on=['date'], 533 | how='left')['description'].fillna('') 534 | temporal['regional_hol'] = temporal.merge( 535 | regional_holidays, 536 | left_on=['state', 'date'], 537 | right_on=['locale_name', 'date'], 538 | how='left')['description'].fillna('') 539 | temporal['local_hol'] = temporal.merge( 540 | local_holidays, 541 | left_on=['city', 'date'], 542 | right_on=['locale_name', 'date'], 543 | how='left')['description'].fillna('') 544 | 545 | temporal.sort_values('unique_id', inplace=True) 546 | 547 | print('Saving processed file to {}'.format(config.data_csv_path)) 548 | temporal.to_csv(config.data_csv_path) 549 | 550 | 551 | # Core routine. 552 | def main(expt_name, force_download, output_folder): 553 | """Runs main download routine. 554 | 555 | Args: 556 | expt_name: Name of experiment 557 | force_download: Whether to force data download from scratch 558 | output_folder: Folder path for storing data 559 | """ 560 | 561 | print('#### Running download script ###') 562 | 563 | expt_config = ExperimentConfig(expt_name, output_folder) 564 | 565 | if os.path.exists(expt_config.data_csv_path) and not force_download: 566 | print('Data has been processed for {}. 
Skipping download...'.format( 567 | expt_name)) 568 | sys.exit(0) 569 | else: 570 | print('Resetting data folder...') 571 | recreate_folder(expt_config.data_folder) 572 | 573 | # Default download functions 574 | download_functions = { 575 | 'volatility': download_volatility, 576 | 'electricity': download_electricity, 577 | 'traffic': download_traffic, 578 | 'favorita': process_favorita 579 | } 580 | 581 | if expt_name not in download_functions: 582 | raise ValueError('Unrecognised experiment! name={}'.format(expt_name)) 583 | 584 | download_function = download_functions[expt_name] 585 | 586 | # Run data download 587 | print('Getting {} data...'.format(expt_name)) 588 | download_function(expt_config) 589 | 590 | print('Download completed.') 591 | 592 | 593 | if __name__ == '__main__': 594 | 595 | def get_args(): 596 | """Returns settings from command line.""" 597 | 598 | experiment_names = ExperimentConfig.default_experiments 599 | 600 | parser = argparse.ArgumentParser(description='Data download configs') 601 | parser.add_argument( 602 | 'expt_name', 603 | metavar='e', 604 | type=str, 605 | nargs='?', 606 | choices=experiment_names, 607 | help='Experiment Name. One of: {}'.format(','.join(experiment_names))) 608 | parser.add_argument( 609 | 'output_folder', 610 | metavar='f', 611 | type=str, 612 | nargs='?', 613 | default='.', 614 | help='Path to folder for data download') 615 | parser.add_argument( 616 | 'force_download', 617 | metavar='r', 618 | type=str, 619 | nargs='?', 620 | choices=['yes', 'no'], 621 | default='no', 622 | help='Whether to re-run data download') 623 | 624 | args = parser.parse_known_args()[0] 625 | 626 | root_folder = None if args.output_folder == '.' else args.output_folder 627 | 628 | return args.expt_name, args.force_download == 'yes', root_folder 629 | 630 | name, force, folder = get_args() 631 | main(expt_name=name, force_download=force, output_folder=folder) 632 | -------------------------------------------------------------------------------- /unlearn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import torchvision.transforms as transforms 4 | from PIL import Image 5 | import os 6 | import numpy as np 7 | from utils import get_lr, evaluate, epoch_end 8 | 9 | class UAgeDB(torch.utils.data.Dataset): 10 | def __init__(self, df, data_dir, img_size, split='train'): 11 | self.df = df 12 | self.data_dir = data_dir 13 | self.img_size = img_size 14 | self.split = split 15 | 16 | 17 | def __len__(self): 18 | return len(self.df) 19 | 20 | def __getitem__(self, index): 21 | index = index % len(self.df) 22 | row = self.df.iloc[index] 23 | img = Image.open(os.path.join(self.data_dir, row['path'])).convert('RGB') 24 | transform = self.get_transform() 25 | img = transform(img) 26 | label = np.asarray([row['age']]).astype('float32') 27 | ulabel = np.asarray(row['unlearn']).astype('float32') 28 | 29 | return img, label, ulabel 30 | 31 | def get_transform(self): 32 | if self.split == 'train': 33 | transform = transforms.Compose([ 34 | transforms.Resize((self.img_size, self.img_size)), 35 | transforms.RandomCrop(self.img_size, padding=16), 36 | transforms.RandomHorizontalFlip(), 37 | transforms.ToTensor(), 38 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]), 39 | ]) 40 | else: 41 | transform = transforms.Compose([ 42 | transforms.Resize((self.img_size, self.img_size)), 43 | transforms.ToTensor(), 44 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]), 45 | ]) 46 | return transform
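# A minimal usage sketch for the blindspot forgetting loop below (illustrative only:
# the csv path, data directory and the AllCNN constructor are assumptions, not part
# of this file; UAgeDB and fit_one_forget_cycle are defined in this module):
#
#   import pandas as pd
#   from models import AllCNN                      # hypothetical vision model import
#   df = pd.read_csv('agedb_train.csv')            # needs 'path', 'age' and a 0/1 'unlearn' column
#   train_loader = torch.utils.data.DataLoader(UAgeDB(df, 'data/AgeDB', 32),
#                                              batch_size=64, shuffle=True)
#   model, proxy = AllCNN(), AllCNN()              # proxy = blindspot model, e.g. briefly trained on retain data only
#   val_loader = ...                               # plain (image, label) loader, since utils.evaluate unpacks two items
#   history = fit_one_forget_cycle(5, model, proxy, train_loader, val_loader,
#                                  lr=1e-3, device='cuda', save_path='unlearned.pt')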
47 | 48 | def attention(x): 49 | """ 50 | Taken from https://github.com/szagoruyko/attention-transfer 51 | :param x = activations 52 | """ 53 | return F.normalize(x.pow(2).mean(1).view(x.size(0), -1)) 54 | 55 | 56 | def attention_diff(x, y): 57 | """ 58 | Taken from https://github.com/szagoruyko/attention-transfer 59 | :param x = activations 60 | :param y = activations 61 | """ 62 | return (attention(x) - attention(y)).pow(2).mean() 63 | 64 | 65 | 66 | def forget_loss(model_output, model_activations, proxy_output, proxy_activations, mask, AT_beta = 50): 67 | 68 | loss = F.l1_loss(model_output[mask], proxy_output[mask]) 69 | if AT_beta > 0: 70 | at_loss = 0 71 | for i in range(len(proxy_activations)): 72 | at_loss = at_loss + AT_beta * attention_diff(model_activations[i][mask], proxy_activations[i][mask]) 73 | else: 74 | at_loss = 0 75 | 76 | total_loss = loss + at_loss 77 | 78 | return total_loss 79 | 80 | 81 | 82 | def fit_one_forget_cycle(epochs, model, proxy_model, train_loader, val_loader, lr, device, save_path): 83 | best_loss = np.inf 84 | torch.cuda.empty_cache() 85 | history = [] 86 | 87 | optimizer = torch.optim.Adam(model.parameters(), lr = lr) 88 | 89 | sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True) 90 | 91 | for epoch in range(epochs): 92 | model.train() 93 | train_losses = [] 94 | lrs = [] 95 | for batch in train_loader: 96 | 97 | images, labels, ulabels = batch 98 | images, labels, ulabels = images.to(device), labels.to(device), ulabels.to(device) 99 | model_out, *model_activations = model(images) 100 | with torch.no_grad(): 101 | proxy_out, *proxy_activations = proxy_model(images) 102 | 103 | 104 | label_loss = 0 105 | if ulabels.sum() < len(ulabels): 106 | mask = (ulabels == 0) 107 | r_model_out = model_out[mask] 108 | r_labels = labels[mask] 109 | label_loss = F.l1_loss(r_model_out, r_labels) 110 | 111 | proxy_loss = 0 112 | if ulabels.sum() > 0: 113 | mask = (ulabels == 1) 114 | proxy_loss = forget_loss(model_out, model_activations, proxy_out, proxy_activations, mask) 115 | 116 | coeff = ulabels.sum()/len(ulabels) 117 | loss = coeff*proxy_loss + (1-coeff)*label_loss 118 | 119 | ###### 120 | train_losses.append(loss) 121 | loss.backward() 122 | 123 | optimizer.step() 124 | optimizer.zero_grad() 125 | 126 | lrs.append(get_lr(optimizer)) 127 | 128 | 129 | # Validation phase 130 | result = evaluate(model, val_loader, device) 131 | result['train_loss'] = torch.stack(train_losses).mean().item() 132 | result['lrs'] = lrs 133 | epoch_end(model, epoch, result) 134 | history.append(result) 135 | torch.save(model.state_dict(), save_path) 136 | 137 | return history -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import numpy as np 4 | from torch import nn 5 | import pandas as pd 6 | from torch.utils.data import DataLoader 7 | from torch.nn import functional as F 8 | from nltk.corpus import stopwords 9 | 10 | stop_words = set(stopwords.words('english')) 11 | 12 | class QuantileLoss(nn.Module): 13 | ## From: https://medium.com/the-artificial-impostor/quantile-regression-part-2-6fdbc26b2629 14 | 15 | def __init__(self, quantiles): 16 | ##takes a list of quantiles 17 | super().__init__() 18 | self.quantiles = quantiles 19 | 20 | def forward(self, preds, target): 21 | assert not target.requires_grad 22 | assert preds.size(0) == target.size(0) 23 | losses = [] 24 | 
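# Pinball (quantile) loss: with error e = target - pred_q, each quantile q contributes
# max((q - 1) * e, q * e), so under-predictions are penalised with weight q and
# over-predictions with weight (1 - q); q = 0.5 recovers half of the L1 loss.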
for i, q in enumerate(self.quantiles): 25 | errors = target - preds[:, i] 26 | losses.append( 27 | torch.max( 28 | (q-1) * errors, 29 | q * errors 30 | ).unsqueeze(1)) 31 | loss = torch.mean( 32 | torch.sum(torch.cat(losses, dim=1), dim=1)) 33 | return loss 34 | 35 | def training_step(model, batch, device): 36 | images, labels = batch 37 | images, labels = images.to(device), labels.to(device) 38 | out, *_ = model(images) # Generate predictions 39 | loss = F.l1_loss(out, labels) # Calculate loss 40 | return loss 41 | 42 | def validation_step(model, batch, device): 43 | images, labels= batch 44 | images, labels = images.to(device), labels.to(device) 45 | out, *_ = model(images) # Generate predictions 46 | loss = F.l1_loss(out, labels) # Calculate loss 47 | return {'Loss': loss.detach()} 48 | 49 | def validation_epoch_end(model, outputs): 50 | batch_losses = [x['Loss'] for x in outputs] 51 | epoch_loss = torch.stack(batch_losses).mean() # Combine losses 52 | return {'Loss': epoch_loss.item()} 53 | 54 | def epoch_end(model, epoch, result): 55 | print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}".format( 56 | epoch, result['lrs'][-1], result['train_loss'], result['Loss'])) 57 | 58 | 59 | 60 | @torch.no_grad() 61 | def evaluate(model, val_loader, device): 62 | model.eval() 63 | outputs = [validation_step(model, batch, device) for batch in val_loader] 64 | return validation_epoch_end(model, outputs) 65 | 66 | def get_lr(optimizer): 67 | for param_group in optimizer.param_groups: 68 | return param_group['lr'] 69 | 70 | def fit_one_cycle(epochs, model, train_loader, val_loader, device, save_path, lr=0.01): 71 | best_loss = np.inf 72 | torch.cuda.empty_cache() 73 | history = [] 74 | 75 | optimizer = torch.optim.Adam(model.parameters(), lr = lr) 76 | 77 | sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True) 78 | 79 | for epoch in range(epochs): 80 | model.train() 81 | train_losses = [] 82 | lrs = [] 83 | for batch in train_loader: 84 | loss = training_step(model, batch, device) 85 | train_losses.append(loss) 86 | loss.backward() 87 | 88 | optimizer.step() 89 | optimizer.zero_grad() 90 | 91 | lrs.append(get_lr(optimizer)) 92 | 93 | 94 | # Validation phase 95 | result = evaluate(model, val_loader, device) 96 | result['train_loss'] = torch.stack(train_losses).mean().item() 97 | result['lrs'] = lrs 98 | epoch_end(model, epoch, result) 99 | history.append(result) 100 | sched.step(result['Loss']) 101 | if best_loss > result['Loss']: 102 | best_loss = result['Loss'] 103 | torch.save(model.state_dict(), save_path) 104 | 105 | return history 106 | 107 | def inference_step(model, batch, device): 108 | images, labels= batch 109 | images, labels = images.to(device), labels.to(device) 110 | out, *_ = model(images) # Generate predictions 111 | return out 112 | 113 | 114 | @torch.no_grad() 115 | def predict(model, val_loader, device): 116 | model.eval() 117 | outputs = [inference_step(model, batch, device) for batch in val_loader] 118 | return torch.cat(outputs, axis = 0) 119 | 120 | 121 | def clean_text(text): 122 | # lower case characters only 123 | text = text.lower() 124 | 125 | # remove urls 126 | text = re.sub('http\S+', ' ', text) 127 | 128 | # only alphabets, spaces and apostrophes 129 | text = re.sub("[^a-z' ]+", ' ', text) 130 | 131 | # remove all apostrophes which are not used in word contractions 132 | text = ' ' + text + ' ' 133 | text = re.sub("[^a-z]'|'[^a-z]", ' ', text) 134 | 135 | split_sentence = text.split() 136 | 
filtered_sentence = [w for w in split_sentence if w.lower() not in stop_words] 137 | return filtered_sentence 138 | 139 | --------------------------------------------------------------------------------
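A quick sanity-check sketch for the helpers in utils.py (illustrative values only; assumes the NLTK stopwords corpus has already been downloaded):

    import torch
    from utils import clean_text, QuantileLoss

    print(clean_text('The model unlearns http://example.com data quickly!'))
    # -> ['model', 'unlearns', 'data', 'quickly']

    ql = QuantileLoss([0.1, 0.5, 0.9])
    preds = torch.zeros(4, 3)        # one prediction column per quantile
    target = torch.ones(4)
    print(ql(preds, target).item())  # e = 1 for every quantile -> 0.1 + 0.5 + 0.9 = 1.5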