├── README.md
├── datasets.py
├── glc.py
└── example.ipynb

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GLC
Unofficial PyTorch implementation of [Using Trusted Data to Train Deep Networks on Labels Corrupted by Severe Noise](https://arxiv.org/pdf/1802.05300.pdf) (NeurIPS 2018).

## Usage
(See example.ipynb for a walkthrough on MNIST.)
```python
from torch.utils.data import DataLoader

from datasets import GoldCorrectionDataset
from glc import CorrectionGenerator, GoldCorrectionLossFunction

c_gen = CorrectionGenerator(simulate=True, dataset=trn_ds, randomization_strength=1.0)

# Fetch both the clean (trusted) and corrupted (untrusted) datasets when in simulate mode
trusted_dataset, untrusted_dataset = c_gen.fetch_datasets()

"""
Train the model on untrusted_dataset.
"""

# Generate the label-correction matrix from the noisily trained model
label_correction_matrix = c_gen.generate_correction_matrix(trainer.model, 32)

# Wrap the trusted and untrusted datasets together using the GoldCorrectionDataset class
gold_ds = GoldCorrectionDataset(trusted_dataset, untrusted_dataset)
gold_dl = DataLoader(gold_ds, batch_size=32, shuffle=True)

# Modified loss function
gold_loss = GoldCorrectionLossFunction(label_correction_matrix)

"""
Train the model on gold_ds with gold_loss until convergence.
"""
```

## Results
### MNIST
- Regular training on the trusted data alone (~5% of the data): 61.12% accuracy.
- Gold Loss Correction with the same 5% trusted split: 95.45% accuracy. All samples in the untrusted split (the remaining 95% of the data) are corrupted by randomly assigning labels.
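
## How the corrected loss works
For intuition, here is a minimal sketch of what `GoldCorrectionLossFunction` computes. The names `z` (logits), `y` (integer labels), `m` (per-sample trusted flag as a float tensor), and `C` (the correction matrix) are assumptions for illustration; the actual implementation lives in glc.py.
```python
import torch
import torch.nn.functional as F

def glc_loss(z, y, m, C):
    # Trusted samples: ordinary cross-entropy on the raw logits.
    ce = F.cross_entropy(z, y, reduction='none')
    # Untrusted samples: push the softmax through the correction matrix,
    # so that softmax(z) @ C models the distribution over *noisy* labels.
    nll = F.nll_loss(torch.log(F.softmax(z, dim=1) @ C), y, reduction='none')
    # Mask each term by the trusted flag and average over the batch.
    return (m * ce + (1 - m) * nll).mean()
```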
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
import numpy as np
from torch.utils.data import Dataset


class IndexEnabledDataset(Dataset):
    """Wraps a dataset so that __getitem__ also returns the sample index."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        data, target = self.dataset[index]
        return data, target, index

    def __len__(self):
        return len(self.dataset)


class RandomizedDataset(Dataset):
    """Corrupts a dataset's labels: with probability p, a sample's label is
    shifted by a uniformly random offset (mod num_classes)."""

    def __init__(self, dataset, num_classes, p=0.5, mode='random_uniform'):
        self.dataset = dataset
        self.random_offsets = []
        for _ in range(len(self.dataset)):
            if np.random.uniform() < p:
                self.random_offsets.append(np.random.randint(0, num_classes))
            else:
                self.random_offsets.append(0)
        self.mode = mode
        self.num_classes = num_classes
        self.p = p

    def __getitem__(self, index):
        data, target = self.dataset[index]
        if self.mode == 'random_uniform':
            target = (target + self.random_offsets[index]) % self.num_classes
        return data, target

    def __len__(self):
        return len(self.dataset)


class GoldCorrectionDataset(Dataset):
    """Concatenates a trusted and a noisy dataset. Each item is returned as
    ([x, flag], y), where flag is 1 for trusted samples and 0 for noisy ones
    (always 1 when valid=True, e.g. for a clean validation set)."""

    def __init__(self, true_dataset, noisy_dataset, valid=False):
        self.true_dataset = true_dataset
        self.noisy_dataset = noisy_dataset
        self.valid = valid

    def __getitem__(self, index):
        if index < len(self.true_dataset):
            x, y = self.true_dataset[index]
            return [x, 1], y
        x, y = self.noisy_dataset[index - len(self.true_dataset)]
        if self.valid:
            return [x, 1], y
        return [x, 0], y

    def __len__(self):
        return len(self.true_dataset) + len(self.noisy_dataset)
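
# Illustrative smoke test for the wrappers above. This guard is an added
# example, not part of the library; it assumes torchvision's MNIST can be
# downloaded to data/.
if __name__ == '__main__':
    from torchvision.datasets import MNIST
    from torchvision.transforms import ToTensor

    ds = MNIST('data/', train=True, transform=ToTensor(), download=True)
    noisy = RandomizedDataset(ds, num_classes=10, p=1.0)
    gold = GoldCorrectionDataset(ds, noisy)

    (x, flag), y = gold[0]            # trusted half -> flag == 1
    (x2, flag2), y2 = gold[len(ds)]   # noisy half   -> flag == 0
    print('trusted flag:', flag, 'noisy flag:', flag2)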
--------------------------------------------------------------------------------
/glc.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from datasets import IndexEnabledDataset, RandomizedDataset


class CorrectionGenerator:
    """Estimates the GLC label-correction matrix. Either takes pre-made
    trusted/untrusted datasets, or (simulate=True) splits a clean dataset
    and corrupts the untrusted part with label noise."""

    def __init__(self, simulate=True, trusted_dataset=None,
                 untrusted_dataset=None, dataset=None, randomization_strength=None):
        if not simulate and ((trusted_dataset is None) or (untrusted_dataset is None)):
            raise ValueError('Provide trusted and untrusted datasets')
        if simulate and ((dataset is None) or (randomization_strength is None)):
            raise ValueError('Cannot simulate without a dataset and a randomization strength')
        if not simulate:
            self.trusted_dataset = trusted_dataset
            self.untrusted_dataset = untrusted_dataset
        else:
            self.prepare_datasets(dataset)
        self.dataset_dicts = self.prepare_class_generators(self.trusted_dataset)
        if simulate:
            # Corrupt the untrusted split with uniform label noise.
            self.untrusted_dataset = RandomizedDataset(
                self.untrusted_dataset, len(self.dataset_dicts), randomization_strength)

    def fetch_datasets(self):
        return self.trusted_dataset, self.untrusted_dataset

    def generate_correction_matrix(self, noisy_model, batch_size):
        return self.build_label_correction_matrix(noisy_model, self.dataset_dicts, batch_size)

    def random_true_noisy_split(self, dataset, true_rat=0.1):
        true_idx = int(true_rat * len(dataset))
        idxs = np.arange(0, len(dataset))
        np.random.shuffle(idxs)
        return Subset(dataset, idxs[:true_idx]), Subset(dataset, idxs[true_idx:])

    def prepare_datasets(self, dataset, trusted_rat=0.1):
        self.trusted_dataset, self.untrusted_dataset = self.random_true_noisy_split(dataset, trusted_rat)

    def prepare_indices(self, dataset):
        # Group the indices of the trusted dataset by class label.
        index_enabled_dataset = IndexEnabledDataset(dataset)
        dl = DataLoader(index_enabled_dataset, batch_size=1)
        indices = defaultdict(list)
        for (x, y, index) in tqdm(dl):
            indices[int(y[0].data)].append(int(index[0].data))
        return indices

    def prepare_class_generators(self, dataset):
        indices = self.prepare_indices(dataset)
        return {k: Subset(dataset, indices[k]) for k in indices.keys()}

    def build_label_correction_matrix(self, noisy_model, clean_ds_dicts, batch_size=32):
        num_labels = len(clean_ds_dicts)
        label_correction_matrix = torch.zeros((num_labels, num_labels))
        noisy_model.eval()
        with torch.no_grad():
            for label in clean_ds_dicts:
                clean_dl = DataLoader(clean_ds_dicts[label], batch_size=batch_size)
                pbar = tqdm(clean_dl)
                pbar.set_description(f'Processing label {label}')
                for data, target in pbar:
                    # Row `label` accumulates the batch-mean softmax of the noisy
                    # model over trusted samples of that class; normalized by the
                    # batch count it estimates p(noisy label | true label).
                    predicted_proba = F.softmax(noisy_model(data), dim=1).mean(dim=0)
                    label_correction_matrix[label, :] += predicted_proba
                label_correction_matrix[label, :] /= len(clean_dl)
        print('Done')
        return label_correction_matrix


class GoldCorrectionLossFunction(nn.Module):
    """GLC loss. Expects x = [logits, trusted_flag]: trusted samples get
    ordinary cross-entropy, untrusted samples get cross-entropy through the
    label-correction matrix."""

    def __init__(self, label_correction_matrix):
        super(GoldCorrectionLossFunction, self).__init__()
        self.label_correction_matrix = label_correction_matrix.detach()

    def forward(self, x, y):
        logits, trusted = x[0], x[1].detach().float()
        c_loss = nn.CrossEntropyLoss(reduction='none')(logits, y) * trusted
        # softmax(logits) @ C models the distribution over noisy labels.
        corrected = torch.matmul(F.softmax(logits, dim=1), self.label_correction_matrix)
        n_loss = nn.NLLLoss(reduction='none')(torch.log(corrected), y) * (1 - trusted)
        return c_loss.mean() + n_loss.mean()
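
# The walkthrough notebook drives training with a private `torched` Trainer.
# Below is a plain-PyTorch sketch of an equivalent loop, added for reference
# (an assumed helper, not used by the classes above): `criterion` can be
# nn.CrossEntropyLoss for the noisy stage or GoldCorrectionLossFunction for
# the corrected stage.
def train_epochs(model, dataloader, criterion, epochs=2, lr=1e-4):
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        for x, y in dataloader:
            opt.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            opt.step()
    return model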
"HBox(children=(IntProgress(value=0, max=7), HTML(value='')))" 111 | ] 112 | }, 113 | "metadata": {}, 114 | "output_type": "display_data" 115 | }, 116 | { 117 | "name": "stderr", 118 | "output_type": "stream", 119 | "text": [ 120 | "Train Loss 2.328309: 100%|██████████| 1688/1688 [00:40<00:00, 35.15it/s]\n", 121 | "Valid Loss 2.332816: 100%|██████████| 313/313 [00:02<00:00, 109.70it/s]\n", 122 | "Train Loss 2.299558: 100%|██████████| 1688/1688 [00:42<00:00, 39.31it/s]\n", 123 | "Valid Loss 2.330564: 100%|██████████| 313/313 [00:02<00:00, 117.04it/s]\n", 124 | "Train Loss 2.262732: 100%|██████████| 1688/1688 [00:42<00:00, 39.47it/s]\n", 125 | "Valid Loss 2.329315: 100%|██████████| 313/313 [00:02<00:00, 113.94it/s]\n", 126 | "Train Loss 2.271581: 100%|██████████| 1688/1688 [00:43<00:00, 38.89it/s]\n", 127 | "Valid Loss 2.338990: 100%|██████████| 313/313 [00:02<00:00, 119.94it/s]\n", 128 | "Train Loss 2.245886: 100%|██████████| 1688/1688 [00:43<00:00, 35.17it/s]\n", 129 | "Valid Loss 2.339327: 100%|██████████| 313/313 [00:02<00:00, 107.19it/s]\n", 130 | "Train Loss 2.205070: 100%|██████████| 1688/1688 [00:43<00:00, 38.45it/s]\n", 131 | "Valid Loss 2.340389: 100%|██████████| 313/313 [00:02<00:00, 109.98it/s]\n", 132 | "Train Loss 2.172739: 100%|██████████| 1688/1688 [00:45<00:00, 37.32it/s]\n", 133 | "Valid Loss 2.348561: 100%|██████████| 313/313 [00:02<00:00, 112.15it/s]" 134 | ] 135 | }, 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "\n" 141 | ] 142 | }, 143 | { 144 | "name": "stderr", 145 | "output_type": "stream", 146 | "text": [ 147 | "\n" 148 | ] 149 | } 150 | ], 151 | "source": [ 152 | "trainer.train(1e-4, 3, 2, crit=nn.CrossEntropyLoss(), opt='adamW')" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 15, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "name": "stderr", 162 | "output_type": "stream", 163 | "text": [ 164 | "Processing label 0: 100%|██████████| 19/19 [00:00<00:00, 114.33it/s]\n", 165 | "Processing label 1: 100%|██████████| 20/20 [00:00<00:00, 122.72it/s]\n", 166 | "Processing label 2: 100%|██████████| 20/20 [00:00<00:00, 189.87it/s]\n", 167 | "Processing label 3: 100%|██████████| 18/18 [00:00<00:00, 177.34it/s]\n", 168 | "Processing label 4: 100%|██████████| 19/19 [00:00<00:00, 175.65it/s]\n", 169 | "Processing label 5: 100%|██████████| 18/18 [00:00<00:00, 182.80it/s]\n", 170 | "Processing label 6: 100%|██████████| 19/19 [00:00<00:00, 161.49it/s]\n", 171 | "Processing label 7: 100%|██████████| 20/20 [00:00<00:00, 155.87it/s]\n", 172 | "Processing label 8: 100%|██████████| 18/18 [00:00<00:00, 168.99it/s]\n", 173 | "Processing label 9: 100%|██████████| 22/22 [00:00<00:00, 125.26it/s]" 174 | ] 175 | }, 176 | { 177 | "name": "stdout", 178 | "output_type": "stream", 179 | "text": [ 180 | "Done\n" 181 | ] 182 | }, 183 | { 184 | "name": "stderr", 185 | "output_type": "stream", 186 | "text": [ 187 | "\n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "label_correction_matrix = c_gen.generate_correction_matrix(trainer.model, 32)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 16, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "gold_ds = GoldCorrectionDataset(trusted_dataset, untrusted_dataset)\n", 202 | "gold_dl = DataLoader(gold_ds, batch_size=32, shuffle=True)\n", 203 | "g_val_ds = GoldCorrectionDataset(val_ds, val_ds, valid=True)\n", 204 | "g_val_dl = DataLoader(g_val_ds, batch_size=32)\n", 205 | "gold_loss = 
GoldCorrectionLossFunction(label_correction_matrix)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 17, 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "data": { 215 | "application/vnd.jupyter.widget-view+json": { 216 | "model_id": "4e958337334642b384135eb30568636b", 217 | "version_major": 2, 218 | "version_minor": 0 219 | }, 220 | "text/plain": [ 221 | "HBox(children=(IntProgress(value=0, max=7), HTML(value='')))" 222 | ] 223 | }, 224 | "metadata": {}, 225 | "output_type": "display_data" 226 | }, 227 | { 228 | "name": "stderr", 229 | "output_type": "stream", 230 | "text": [ 231 | "Train Loss 2.148798: 100%|██████████| 1875/1875 [00:48<00:00, 38.78it/s]\n", 232 | "Valid Loss 0.439217: 100%|██████████| 625/625 [00:06<00:00, 95.55it/s] \n", 233 | "Train Loss 2.106702: 100%|██████████| 1875/1875 [00:51<00:00, 36.76it/s]\n", 234 | "Valid Loss 0.236635: 100%|██████████| 625/625 [00:06<00:00, 98.85it/s] \n", 235 | "Train Loss 2.091937: 100%|██████████| 1875/1875 [00:52<00:00, 35.83it/s]\n", 236 | "Valid Loss 0.220722: 100%|██████████| 625/625 [00:06<00:00, 99.50it/s] \n", 237 | "Train Loss 2.090299: 100%|██████████| 1875/1875 [00:51<00:00, 34.03it/s]\n", 238 | "Valid Loss 0.183325: 100%|██████████| 625/625 [00:06<00:00, 103.53it/s]\n", 239 | "Train Loss 2.082261: 100%|██████████| 1875/1875 [00:52<00:00, 35.38it/s]\n", 240 | "Valid Loss 0.163479: 100%|██████████| 625/625 [00:06<00:00, 92.16it/s] \n", 241 | "Train Loss 2.077849: 100%|██████████| 1875/1875 [00:53<00:00, 34.91it/s]\n", 242 | "Valid Loss 0.157068: 100%|██████████| 625/625 [00:05<00:00, 104.54it/s]\n", 243 | "Train Loss 2.076205: 100%|██████████| 1875/1875 [00:54<00:00, 34.29it/s]\n", 244 | "Valid Loss 0.155766: 100%|██████████| 625/625 [00:06<00:00, 90.00it/s] \n" 245 | ] 246 | }, 247 | { 248 | "name": "stdout", 249 | "output_type": "stream", 250 | "text": [ 251 | "\n" 252 | ] 253 | } 254 | ], 255 | "source": [ 256 | "trainer.dataloader = gold_dl\n", 257 | "trainer.val_dataloader = g_val_dl\n", 258 | "trainer.train(1e-4, 3, 2, crit=gold_loss, opt='adamW')" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 18, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "def epoch(loader, model, opt=None):\n", 268 | " \"\"\"Standard training/evaluation epoch over the dataset\"\"\"\n", 269 | " total_loss, total_err = 0.,0.\n", 270 | " for X,y in loader:\n", 271 | " X,y = X, y\n", 272 | " yp = model(X)\n", 273 | " loss = nn.CrossEntropyLoss()(yp,y)\n", 274 | " if opt:\n", 275 | " opt.zero_grad()\n", 276 | " loss.backward()\n", 277 | " opt.step()\n", 278 | " \n", 279 | " total_err += (yp.max(dim=1)[1] != y).sum().item()\n", 280 | " total_loss += loss.item() * X.shape[0]\n", 281 | " return total_err / len(loader.dataset), total_loss / len(loader.dataset)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 19, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "err, loss = epoch(v_dl, trainer.model)" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 20, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "data": { 300 | "text/plain": [ 301 | "0.9545" 302 | ] 303 | }, 304 | "execution_count": 20, 305 | "metadata": {}, 306 | "output_type": "execute_result" 307 | } 308 | ], 309 | "source": [ 310 | "1 - err" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [] 319 | } 320 | ], 321 | "metadata": { 322 | 
"kernelspec": { 323 | "display_name": "Python 3", 324 | "language": "python", 325 | "name": "python3" 326 | }, 327 | "language_info": { 328 | "codemirror_mode": { 329 | "name": "ipython", 330 | "version": 3 331 | }, 332 | "file_extension": ".py", 333 | "mimetype": "text/x-python", 334 | "name": "python", 335 | "nbconvert_exporter": "python", 336 | "pygments_lexer": "ipython3", 337 | "version": "3.6.7" 338 | } 339 | }, 340 | "nbformat": 4, 341 | "nbformat_minor": 2 342 | } 343 | --------------------------------------------------------------------------------