├── .gitignore
├── LICENSE
├── README.md
├── bootstrapping_loss
│   ├── __init__.py
│   └── loss.py
├── examples
│   ├── analysis.ipynb
│   └── mnist
│       ├── experiments.png
│       ├── main.py
│       └── run_experiments.sh
└── requirements.txt
/.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/python, and modified... 2 | 3 | examples/processed 4 | examples/raw 5 | examples/tensorboard_logs 6 | 7 | ## Linux backup files 8 | *~ 9 | 10 | ## Archives 11 | *.zip 12 | 13 | ### Mac OS X stuff ### 14 | Desktop DF 15 | Desktop DB 16 | .Spotlight-V100 17 | .DS_Store 18 | .Trashes 19 | .com.apple.timemachine.supported 20 | .fseventsd 21 | .syncinfo 22 | .TemporaryItems 23 | report/ 24 | 25 | ### Python ### 26 | # Byte-compiled / optimized / DLL files 27 | __pycache__/ 28 | *.py[cod] 29 | *$py.class 30 | .pytest_cache 31 | 32 | # C extensions 33 | *.so 34 | 35 | # Distribution / packaging 36 | .Python 37 | env/ 38 | build/ 39 | develop-eggs/ 40 | dist/ 41 | downloads/ 42 | eggs/ 43 | .eggs/ 44 | lib/ 45 | lib64/ 46 | parts/ 47 | sdist/ 48 | var/ 49 | *.egg-info/ 50 | .installed.cfg 51 | *.egg 52 | 53 | # PyInstaller 54 | # Usually these files are written by a python script from a template 55 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 56 | *.manifest 57 | MANIFEST 58 | *.spec 59 | 60 | # Installer logs 61 | pip-log.txt 62 | pip-delete-this-directory.txt 63 | 64 | # Unit test / coverage reports 65 | htmlcov/ 66 | .tox/ 67 | .coverage 68 | .coverage.* 69 | .cache 70 | nosetests.xml 71 | coverage.xml 72 | *,cover 73 | .hypothesis/ 74 | 75 | # Translations 76 | *.mo 77 | *.pot 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # IPython Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # Idea 92 | .idea 93 | *.iml 94 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 vfdev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bootstrapping loss function implementation 2 | Based on "Training Deep Neural Networks on Noisy Labels with Bootstrapping" 3 | [https://arxiv.org/abs/1412.6596](https://arxiv.org/abs/1412.6596) 4 | 5 | 6 | ## Experiments on MNIST 7 | 8 | Run a single experiment: 9 | ```bash 10 | cd examples/mnist && python main.py run --mode hard_bootstrap --noise_fraction=0.45 11 | cd examples/mnist && python main.py run --mode soft_bootstrap --noise_fraction=0.45 12 | cd examples/mnist && python main.py run --mode xentropy --noise_fraction=0.45 13 | ``` 14 | Or run all experiments at once: 15 | ```bash 16 | cd examples/mnist && sh run_experiments.sh >> out 2> log 17 | ``` 18 | 19 | - [Experiments on TRAINS](https://app.ignite.trains.allegro.ai/projects/276a39e824794d1093ecddd8b2afb8d0) 20 | - `WITH_TRAINS=True sh run_experiments.sh >> out 2> log` 21 | 22 | ### Requirements: 23 | 24 | - pytorch>=1.3 25 | - torchvision>=0.4.1 26 | - [pytorch-ignite](https://github.com/pytorch/ignite)>=0.4.2 27 | - [fire](https://github.com/google/python-fire)>=0.3.1 28 | 29 | ```bash 30 | pip install -r requirements.txt 31 | ``` 32 | -------------------------------------------------------------------------------- /bootstrapping_loss/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from bootstrapping_loss.loss import SoftBootstrappingLoss, HardBootstrappingLoss -------------------------------------------------------------------------------- /bootstrapping_loss/loss.py: -------------------------------------------------------------------------------- 1 | # 2 | # Training Deep Neural Networks on Noisy Labels with Bootstrapping 3 | # http://www-personal.umich.edu/~reedscot/bootstrap.pdf 4 | # 5 | 6 | import torch 7 | from torch.nn import Module 8 | import torch.nn.functional as F 9 | 10 | 11 | class SoftBootstrappingLoss(Module): 12 | """ 13 | ``Loss(t, p) = - (beta * t + (1 - beta) * p) * log(p)`` 14 | 15 | Args: 16 | beta (float): bootstrap parameter. Default, 0.95 17 | reduce (bool): computes mean of the loss. Default, True. 18 | as_pseudo_label (bool): if True, stop gradient propagation through the predictions in the term ``(1 - beta) * p``, 19 | so that they act as pseudo-labels. Default, True. 20 | """ 21 | def __init__(self, beta=0.95, reduce=True, as_pseudo_label=True): 22 | super(SoftBootstrappingLoss, self).__init__() 23 | self.beta = beta 24 | self.reduce = reduce 25 | self.as_pseudo_label = as_pseudo_label 26 | 27 | def forward(self, y_pred, y): 28 | # cross_entropy = - t * log(p) 29 | beta_xentropy = self.beta * F.cross_entropy(y_pred, y, reduction='none') 30 | 31 | y_pred_a = y_pred.detach() if self.as_pseudo_label else y_pred 32 | # second term = - (1 - beta) * p * log(p) 33 | bootstrap = - (1.0 - self.beta) * torch.sum(F.softmax(y_pred_a, dim=1) * F.log_softmax(y_pred, dim=1), dim=1) 34 | 35 | if self.reduce: 36 | return torch.mean(beta_xentropy + bootstrap) 37 | return beta_xentropy + bootstrap 38 | 39 | 40 | class HardBootstrappingLoss(Module): 41 | """ 42 | ``Loss(t, p) = - (beta * t + (1 - beta) * z) * log(p)`` 43 | where ``z = argmax(p)`` 44 | 45 | Args: 46 | beta (float): bootstrap parameter. Default, 0.8 47 | reduce (bool): computes mean of the loss. Default, True.
48 | 49 | """ 50 | def __init__(self, beta=0.8, reduce=True): 51 | super(HardBootstrappingLoss, self).__init__() 52 | self.beta = beta 53 | self.reduce = reduce 54 | 55 | def forward(self, y_pred, y): 56 | # cross_entropy = - t * log(p) 57 | beta_xentropy = self.beta * F.cross_entropy(y_pred, y, reduction='none') 58 | 59 | # z = argmax(p) 60 | z = F.softmax(y_pred.detach(), dim=1).argmax(dim=1) 61 | z = z.view(-1, 1) 62 | bootstrap = F.log_softmax(y_pred, dim=1).gather(1, z).view(-1) 63 | # second term = - (1 - beta) * z * log(p) 64 | bootstrap = - (1.0 - self.beta) * bootstrap 65 | 66 | if self.reduce: 67 | return torch.mean(beta_xentropy + bootstrap) 68 | return beta_xentropy + bootstrap 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /examples/analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Analysis of the loss function with bootstrapping\n", 8 | "https://arxiv.org/pdf/1412.6596\n", 9 | "\n", 10 | "Multinomial classification task with noisy labels. Loss function modifications:\n", 11 | "\n", 12 | "a) Soft version:\n", 13 | "\n", 14 | "$$\n", 15 | "L_{soft}(\\mathbf{q}, \\mathbf{t}) = - \\sum_{k=1}^{L} (\\beta t_{k} + (1 - \\beta) q_{k}) \\log(q_{k}),\n", 16 | "$$\n", 17 | "where $\\mathbf{q}$ is the vector of predicted class probabilities for a single image, $\\mathbf{t}$ is the ground truth and $L$ is the number of classes. Parameter $\\beta$ is chosen between $0$ and $1$.\n", 18 | "\n", 19 | "b) Hard version:\n", 20 | "\n", 21 | "$$\n", 22 | "L_{hard}(\\mathbf{q}, \\mathbf{t}) = - \\sum_{k=1}^{L} (\\beta t_{k} + (1 - \\beta) z_{k}) \\log(q_{k}),\n", 23 | "$$\n", 24 | "where $z_{k}$ is the one-hot encoding of the argmax of $\\mathbf{q}$ (same form as $\\mathbf{t}$)."
25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "1.7.0\n" 37 | ] 38 | } 39 | ], 40 | "source": [ 41 | "import torch\n", 42 | "from torch.nn import functional as F\n", 43 | "\n", 44 | "print(torch.__version__)\n", 45 | "\n", 46 | "L = 3\n", 47 | "\n", 48 | "# Some ground truth samples:\n", 49 | "t_1 = torch.tensor([0])\n", 50 | "t_2 = torch.tensor([1])\n", 51 | "logit_q_1a = torch.tensor([[0.9, 0.05, 0.05], ])\n", 52 | "logit_q_1b = torch.tensor([[0.2, 0.5, 0.3], ])\n", 53 | "logit_q_2a = torch.tensor([[0.33, 0.33, 0.33], ])\n", 54 | "logit_q_2b = torch.tensor([[0.15, 0.7, 0.15], ])" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "#### Soft version\n", 62 | "\n", 63 | "Let's first compute cross entropy term: $-\\sum_{k} t_{k} \\log(q_{k})$" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 2, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "data": { 73 | "text/plain": [ 74 | "(tensor(0.6178), tensor(1.2398), tensor(1.0986), tensor(0.7673))" 75 | ] 76 | }, 77 | "execution_count": 2, 78 | "metadata": {}, 79 | "output_type": "execute_result" 80 | } 81 | ], 82 | "source": [ 83 | "F.cross_entropy(logit_q_1a, t_1), F.cross_entropy(logit_q_1b, t_1), F.cross_entropy(logit_q_2a, t_2), F.cross_entropy(logit_q_2b, t_2)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "Now let's compute the second term (soft bootstrapping) : $-\\sum_{k} q_{k} \\log(q_{k})$" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "def soft_boostrapping(logit_q):\n", 100 | " return - torch.sum(F.softmax(logit_q, dim=1) * F.log_softmax(logit_q, dim=1), dim=1)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 4, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "text/plain": [ 111 | "(tensor([1.0095]), tensor([1.0906]), tensor([1.0986]), tensor([1.0619]))" 112 | ] 113 | }, 114 | "execution_count": 4, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "soft_boostrapping(logit_q_1a), soft_boostrapping(logit_q_1b), soft_boostrapping(logit_q_2a), soft_boostrapping(logit_q_2b)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 5, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "def soft_bootstrapping_loss(logit_q, t, beta):\n", 130 | " return F.cross_entropy(logit_q, t) * beta + (1.0 - beta) * soft_boostrapping(logit_q)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 6, 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "data": { 140 | "text/plain": [ 141 | "(tensor([0.6374]), tensor([1.2324]))" 142 | ] 143 | }, 144 | "execution_count": 6, 145 | "metadata": {}, 146 | "output_type": "execute_result" 147 | } 148 | ], 149 | "source": [ 150 | "soft_bootstrapping_loss(logit_q_1a, t_1, beta=0.95), soft_bootstrapping_loss(logit_q_1b, t_1, beta=0.95)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 7, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "data": { 160 | "text/plain": [ 161 | "(tensor([1.0986]), tensor([0.7820]))" 162 | ] 163 | }, 164 | "execution_count": 7, 165 | "metadata": {}, 166 | "output_type": "execute_result" 167 | } 168 | ], 169 | "source": [ 170 | 
"soft_bootstrapping_loss(logit_q_2a, t_2, beta=0.95), soft_bootstrapping_loss(logit_q_2b, t_2, beta=0.95)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "#### Hard version\n", 185 | "\n", 186 | "Let's first compute cross entropy term: $-\\sum_{k} t_{k} \\log(q_{k})$" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 8, 192 | "metadata": {}, 193 | "outputs": [ 194 | { 195 | "data": { 196 | "text/plain": [ 197 | "(tensor(0.6178), tensor(1.2398), tensor(1.0986), tensor(0.7673))" 198 | ] 199 | }, 200 | "execution_count": 8, 201 | "metadata": {}, 202 | "output_type": "execute_result" 203 | } 204 | ], 205 | "source": [ 206 | "F.cross_entropy(logit_q_1a, t_1), F.cross_entropy(logit_q_1b, t_1), F.cross_entropy(logit_q_2a, t_2), F.cross_entropy(logit_q_2b, t_2)" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "Now let's compute the second term (hard bootstrapping) : $-\\sum_{k} z_{k} \\log(q_{k})$" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 9, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "def hard_boostrapping(logit_q):\n", 223 | " _, z = torch.max(F.softmax(logit_q, dim=1), dim=1)\n", 224 | " z = z.view(-1, 1)\n", 225 | " return - F.log_softmax(logit_q, dim=1).gather(1, z).view(-1) " 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 10, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "data": { 235 | "text/plain": [ 236 | "(tensor([0.6178]), tensor([0.9398]), tensor([1.0986]), tensor([0.7673]))" 237 | ] 238 | }, 239 | "execution_count": 10, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "hard_boostrapping(logit_q_1a), hard_boostrapping(logit_q_1b), hard_boostrapping(logit_q_2a), hard_boostrapping(logit_q_2b)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 11, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "def hard_bootstrapping_loss(logit_q, t, beta):\n", 255 | " return F.cross_entropy(logit_q, t) * beta + (1.0 - beta) * hard_boostrapping(logit_q)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 12, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "data": { 265 | "text/plain": [ 266 | "(tensor([0.6178]), tensor([1.1798]))" 267 | ] 268 | }, 269 | "execution_count": 12, 270 | "metadata": {}, 271 | "output_type": "execute_result" 272 | } 273 | ], 274 | "source": [ 275 | "hard_bootstrapping_loss(logit_q_1a, t_1, beta=0.8), hard_bootstrapping_loss(logit_q_1b, t_1, beta=0.8)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 13, 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "data": { 285 | "text/plain": [ 286 | "(tensor([1.0986]), tensor([0.7673]))" 287 | ] 288 | }, 289 | "execution_count": 13, 290 | "metadata": {}, 291 | "output_type": "execute_result" 292 | } 293 | ], 294 | "source": [ 295 | "hard_bootstrapping_loss(logit_q_2a, t_2, beta=0.8), hard_bootstrapping_loss(logit_q_2b, t_2, beta=0.8)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 14, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 
| "y_pred = torch.rand(4, 10, requires_grad=True)" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 15, 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "data": { 321 | "text/plain": [ 322 | "tensor([3, 2, 5, 2])" 323 | ] 324 | }, 325 | "execution_count": 15, 326 | "metadata": {}, 327 | "output_type": "execute_result" 328 | } 329 | ], 330 | "source": [ 331 | "z = F.softmax(y_pred, dim=1).argmax(dim=1)\n", 332 | "z" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 17, 338 | "metadata": {}, 339 | "outputs": [ 340 | { 341 | "data": { 342 | "text/plain": [ 343 | "tensor([3, 2, 5, 2])" 344 | ] 345 | }, 346 | "execution_count": 17, 347 | "metadata": {}, 348 | "output_type": "execute_result" 349 | } 350 | ], 351 | "source": [ 352 | "_, z2 = F.softmax(y_pred, dim=1).max(dim=1)\n", 353 | "z2" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [] 362 | } 363 | ], 364 | "metadata": { 365 | "kernelspec": { 366 | "display_name": "Python 3", 367 | "language": "python", 368 | "name": "python3" 369 | }, 370 | "language_info": { 371 | "codemirror_mode": { 372 | "name": "ipython", 373 | "version": 3 374 | }, 375 | "file_extension": ".py", 376 | "mimetype": "text/x-python", 377 | "name": "python", 378 | "nbconvert_exporter": "python", 379 | "pygments_lexer": "ipython3", 380 | "version": "3.8.2" 381 | } 382 | }, 383 | "nbformat": 4, 384 | "nbformat_minor": 4 385 | } 386 | -------------------------------------------------------------------------------- /examples/mnist/experiments.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vfdev-5/BootstrappingLoss/a072e4a20cbc4a19e73a69788869fbbd3505eeef/examples/mnist/experiments.png -------------------------------------------------------------------------------- /examples/mnist/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | MNIST example with training and validation monitoring using Tensorboard 3 | 4 | Usage: 5 | Start tensorboard: 6 | ```bash 7 | tensorboard --logdir=/tmp/tensorboard_logs/ 8 | ``` 9 | Run the example: 10 | ```bash 11 | python main.py --log_dir=/tmp/tensorboard_logs 12 | ``` 13 | """ 14 | import sys 15 | import random 16 | 17 | from pathlib import Path 18 | from datetime import datetime 19 | from functools import partial 20 | 21 | import fire 22 | 23 | import torch 24 | from torch.utils.data import DataLoader 25 | from torch import nn 26 | import torch.nn.functional as F 27 | from torch.optim import SGD 28 | from torchvision.datasets import MNIST 29 | from torchvision.transforms import Compose, ToTensor, Normalize 30 | 31 | import ignite 32 | from ignite.contrib.engines import common 33 | from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator 34 | from ignite.metrics import Accuracy, Loss 35 | from ignite.utils import manual_seed, setup_logger 36 | 37 | 38 | # Add local code 39 | sys.path.insert(0, "../..") 40 | 41 | from bootstrapping_loss import SoftBootstrappingLoss, HardBootstrappingLoss 42 | 43 | 44 | class Net(nn.Module): 45 | def __init__(self): 46 | super(Net, self).__init__() 47 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 48 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 49 | self.conv2_drop = nn.Dropout2d() 50 | self.fc1 = nn.Linear(320, 50) 51 | self.fc2 = nn.Linear(50, 10) 52 | 53 | def forward(self, x): 54 | x = 
F.relu(F.max_pool2d(self.conv1(x), 2)) 55 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 56 | x = x.view(-1, 320) 57 | x = F.relu(self.fc1(x)) 58 | x = F.dropout(x, training=self.training) 59 | x = self.fc2(x) 60 | return x 61 | 62 | 63 | label_noise_pattern = { 64 | # True label -> noise label 65 | 0: 7, 66 | 1: 9, 67 | 2: 0, 68 | 3: 4, 69 | 4: 2, 70 | 5: 1, 71 | 6: 3, 72 | 7: 5, 73 | 8: 6, 74 | 9: 8 75 | } 76 | 77 | 78 | def noisy_labels(y, a=0.5): 79 | # With probability a, replace the true label by its fixed noisy counterpart 80 | if random.random() > a: 81 | return y 82 | return label_noise_pattern[y] 83 | 84 | 85 | def get_data_loaders(data_path, noise_fraction, train_batch_size, val_batch_size): 86 | data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) 87 | 88 | train_dataset = MNIST(download=True, root=data_path, transform=data_transform, 89 | target_transform=partial(noisy_labels, a=noise_fraction), 90 | train=True) 91 | 92 | train_loader = DataLoader(train_dataset, batch_size=train_batch_size, num_workers=4, shuffle=True) 93 | 94 | val_dataset = MNIST(download=False, root=data_path, transform=data_transform, train=False) 95 | val_loader = DataLoader(val_dataset, batch_size=val_batch_size, num_workers=4, shuffle=False) 96 | return train_loader, val_loader 97 | 98 | 99 | def run( 100 | data_path="/tmp/MNIST", 101 | seed=3321, 102 | mode="xentropy", 103 | noise_fraction=0.35, 104 | batch_size=64, 105 | val_batch_size=1000, 106 | num_epochs=50, 107 | lr=0.01, 108 | momentum=0.5, 109 | as_pseudo_label=None, 110 | log_dir="/tmp/output-bootstraping-loss/mnist/", 111 | with_trains=False, 112 | ): 113 | """Training on noisy labels with bootstrapping 114 | 115 | Args: 116 | data_path (str): Path to MNIST dataset. Default, "/tmp/MNIST" 117 | seed (int): Random seed to setup. Default, 3321 118 | mode (str): Loss function mode: 'xentropy', 'soft_bootstrap' or 'hard_bootstrap'. Default, "xentropy". 119 | noise_fraction (float): Label noise fraction. Default, 0.35. 120 | batch_size (int): Input batch size for training. Default, 64. 121 | val_batch_size (int): Input batch size for validation. Default, 1000. 122 | num_epochs (int): Number of epochs to train. Default, 50. 123 | lr (float): Learning rate. Default, 0.01. 124 | momentum (float): SGD momentum. Default, 0.5. 125 | as_pseudo_label (bool): stop gradients through the soft bootstrapping term, see ``SoftBootstrappingLoss``. Default, None (treated as True for 'soft_bootstrap'). 126 | log_dir (str): Log directory for TensorBoard log output. Default, "/tmp/output-bootstraping-loss/mnist/". 127 | with_trains (bool): if True, the experiment Trains logger is set up. Default, False.
128 | 129 | """ 130 | assert torch.cuda.is_available(), "Training should run on GPU" 131 | device = "cuda" 132 | 133 | manual_seed(seed) 134 | logger = setup_logger(name="MNIST-Training") 135 | 136 | now = datetime.now().strftime("%Y%m%d-%H%M%S") 137 | 138 | # Setup output path 139 | suffix = "" 140 | if mode == "soft_bootstrap" and (as_pseudo_label is not None and not as_pseudo_label): 141 | suffix = "as_xreg" 142 | output_path = Path(log_dir) / "train_{}_{}_{}_{}__{}".format(mode, noise_fraction, suffix, now, num_epochs) 143 | 144 | if not output_path.exists(): 145 | output_path.mkdir(parents=True) 146 | 147 | parameters = { 148 | "seed": seed, 149 | "mode": mode, 150 | "noise_fraction": noise_fraction, 151 | "batch_size": batch_size, 152 | "num_epochs": num_epochs, 153 | "lr": lr, 154 | "momentum": momentum, 155 | "as_pseudo_label": as_pseudo_label, 156 | } 157 | log_basic_info(logger, parameters) 158 | 159 | if with_trains: 160 | from trains import Task 161 | 162 | task = Task.init("BootstrappingLoss - Experiments on MNIST", task_name=output_path.name) 163 | # Log hyper parameters 164 | task.connect(parameters) 165 | 166 | train_loader, test_loader = get_data_loaders(data_path, noise_fraction, batch_size, val_batch_size) 167 | model = Net().to(device) 168 | optimizer = SGD(model.parameters(), lr=lr, momentum=momentum) 169 | 170 | if mode == 'xentropy': 171 | criterion = nn.CrossEntropyLoss() 172 | elif mode == 'soft_bootstrap': 173 | if as_pseudo_label is None: 174 | as_pseudo_label = True 175 | criterion = SoftBootstrappingLoss(beta=0.95, as_pseudo_label=as_pseudo_label) 176 | elif mode == 'hard_bootstrap': 177 | criterion = HardBootstrappingLoss(beta=0.8) 178 | else: 179 | raise ValueError("Wrong mode {}, expected: xentropy, soft_bootstrap or hard_bootstrap".format(mode)) 180 | 181 | trainer = create_supervised_trainer(model, optimizer, criterion, device=device, non_blocking=True) 182 | 183 | metrics = { 184 | "Accuracy": Accuracy(), 185 | "{} loss".format(mode): Loss(criterion), 186 | } 187 | if mode != "xentropy": 188 | metrics["xentropy loss"] = Loss(nn.CrossEntropyLoss()) 189 | 190 | evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True) 191 | train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True) 192 | 193 | def run_validation(engine): 194 | epoch = trainer.state.epoch 195 | state = train_evaluator.run(train_loader) 196 | log_metrics(logger, epoch, "Train", state.metrics) 197 | state = evaluator.run(test_loader) 198 | log_metrics(logger, epoch, "Test", state.metrics) 199 | 200 | trainer.add_event_handler(Events.EPOCH_COMPLETED | Events.COMPLETED, run_validation) 201 | 202 | evaluators = {"training": train_evaluator, "test": evaluator} 203 | tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators) 204 | 205 | trainer.run(train_loader, max_epochs=num_epochs) 206 | 207 | test_acc = evaluator.state.metrics["Accuracy"] 208 | tb_logger.writer.add_hparams(parameters, {"hparam/test_accuracy": test_acc}) 209 | 210 | tb_logger.close() 211 | 212 | return (mode, noise_fraction, as_pseudo_label, test_acc) 213 | 214 | 215 | def log_metrics(logger, epoch, tag, metrics): 216 | logger.info( 217 | "Epoch {} - {} metrics: {}".format( 218 | epoch, tag, " - ".join(["{}: {:.4f}".format(k, v) for k, v in metrics.items()]) 219 | ) 220 | ) 221 | 222 | 223 | def log_basic_info(logger, config): 224 | logger.info("Train on MNIST") 225 | logger.info("- PyTorch version: 
{}".format(torch.__version__)) 226 | logger.info("- Ignite version: {}".format(ignite.__version__)) 227 | 228 | logger.info("\n") 229 | logger.info("Configuration:") 230 | for key, value in config.items(): 231 | logger.info("\t{}: {}".format(key, value)) 232 | logger.info("\n") 233 | 234 | 235 | if __name__ == "__main__": 236 | fire.Fire({"run": run}) 237 | -------------------------------------------------------------------------------- /examples/mnist/run_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | for mode in "xentropy" "soft_bootstrap" "hard_bootstrap" 6 | do 7 | for noise_fraction in 0.3 0.34 0.38 0.42 0.46 0.5 8 | do 9 | python main.py run --mode=$mode --noise_fraction=$noise_fraction --num_epochs 50 --with_trains=${WITH_TRAINS:-False} 10 | done 11 | done 12 | 13 | # Run with as_pseudo_label=False 14 | for noise_fraction in 0.3 0.34 0.38 0.42 0.46 0.5 15 | do 16 | python main.py run --mode=soft_bootstrap --noise_fraction=$noise_fraction --num_epochs 50 --as_pseudo_label=False --with_trains=${WITH_TRAINS:-False} 17 | done 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch<2.0,>=1.3 2 | torchvision<1.0,>=0.4.1 3 | tensorboard<3.0,>=2.3.0 4 | pytorch-ignite<1.0,>=0.4.2 5 | fire<1.0,>=0.3.1 6 | trains<1.0,>=0.16.3 --------------------------------------------------------------------------------