├── .gitignore ├── LICENSE ├── README.md ├── analysis ├── README.md ├── adversarial_linf.py ├── mine.py ├── train_vaes.py ├── vae.py ├── vae_mi.py └── vgg.py ├── datasets ├── __init__.py ├── bengali.py ├── datasets.py ├── fashion.py ├── google_commands │ ├── README.md │ ├── google_commands.py │ ├── sft_transforms.py │ └── transforms.py ├── imagenet_a.py ├── imagenet_hdf5.py ├── tiny_imagenet.py ├── toxic.py └── toxic_bert.py ├── experiments ├── bengali_experiment.sh ├── cifar_experiment.sh ├── fashion_experiments.sh ├── google_commands_experiment.sh ├── imagenet_experiment.sh ├── modelnet_experiment.sh ├── tiny_imagenet_experiment.sh └── toxic_experiment.sh ├── fmix.py ├── fmix_3d.gif ├── fmix_example.png ├── hubconf.py ├── implementations ├── lightning.py ├── tensorflow_implementation.py ├── test_lightning.py ├── test_tensorflow.py ├── test_torchbearer.py └── torchbearer_implementation.py ├── models ├── __init__.py ├── bert.py ├── densenet3.py ├── models.py ├── pyramid.py ├── resnet.py ├── senet.py ├── toxic_cnn.py ├── toxic_lstm.py └── wide_resnet.py ├── notebooks ├── example_masks.ipynb └── grad_cam.ipynb ├── requirements.txt ├── trainer.py └── utils ├── __init__.py ├── auto_augment ├── __init__.py ├── auto_augment.py └── auto_augment_aug_list.py ├── bengali_evaluate.py ├── convert_imagenet_model.py ├── cross_val.py ├── imagenet_to_hdf5.py ├── lr_warmup.py ├── macro_recall.py ├── msda_alternator.py ├── process.py ├── reduced_dataset_splitter.py └── reformulated_mixup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Ours 132 | implementations/data 133 | implementations/lightning_logs 134 | data/ 135 | .idea/ 136 | saved_models/ 137 | logs/ 138 | notebooks/* 139 | !notebooks/*.ipynb 140 | .vector_cache 141 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Vision, Learning and Control Research Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # FMix 4 | 5 | This repository contains the __official__ implementation of the paper ['FMix: Enhancing Mixed Sample Data Augmentation'](https://arxiv.org/abs/2002.12047) 6 | 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/understanding-and-enhancing-mixed-sample-data/image-classification-on-cifar-10)](https://paperswithcode.com/sota/image-classification-on-cifar-10?p=understanding-and-enhancing-mixed-sample-data) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/understanding-and-enhancing-mixed-sample-data/image-classification-on-fashion-mnist)](https://paperswithcode.com/sota/image-classification-on-fashion-mnist?p=understanding-and-enhancing-mixed-sample-data) 9 | 10 |

11 | ArXiv • 12 | Papers With Code • 13 | About • 14 | Experiments • 15 | Implementations • 16 | Pre-trained Models 17 |

18 | 19 | 20 | 21 | 22 | Dive in with our example notebook in Colab! 23 | 24 | 25 |
26 | 27 | ## About 28 | 29 | FMix is a variant of MixUp, CutMix, etc. introduced in our paper ['FMix: Enhancing Mixed Sample Data Augmentation'](https://arxiv.org/abs/2002.12047). It uses masks sampled from Fourier space to mix training examples. Take a look at our [example notebook in Colab](https://colab.research.google.com/github/ecs-vlc/fmix/blob/master/notebooks/example_masks.ipynb) which shows how you can generate masks in two dimensions 30 | 31 |
32 | <img src="./fmix_example.png"/> 33 |
34 | 35 | and in three! 36 | 37 |
38 | <img src="./fmix_3d.gif"/> 39 |
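If you want to experiment with the masks directly in code, the snippet below is a minimal sketch using the core `numpy` implementation in [`fmix.py`](./fmix.py). It assumes a `sample_mask` helper that takes an `alpha` (for the mixing coefficient), a `decay_power` and a mask `shape`, and returns the coefficient `lam` together with a binary mask; check [`fmix.py`](./fmix.py) for the exact signature and any additional arguments.

```python
import numpy as np
from fmix import sample_mask  # core numpy implementation (assumed helper, see fmix.py)

# Sample a mixing coefficient (lam ~ Beta(alpha, alpha)) and a binary mask obtained by
# thresholding a low-frequency grey-scale image sampled from Fourier space.
lam, mask = sample_mask(alpha=1, decay_power=3, shape=(32, 32))

# Mix two batches of images (N, C, H, W) with the sampled mask.
x1 = np.random.rand(16, 3, 32, 32)
x2 = np.random.rand(16, 3, 32, 32)
x_mixed = mask * x1 + (1 - mask) * x2

# Training then uses the usual mixed objective, e.g.
#   loss = lam * criterion(model(x_mixed), y1) + (1 - lam) * criterion(model(x_mixed), y2)
```

The framework bindings described below under Implementations wrap this mask sampling (including shuffling within the batch and the mixed loss) for you, so you rarely need to call it by hand.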
40 | 41 | ## Experiments 42 | 43 | ### Core Experiments 44 | 45 | Shell scripts for our core experiments can be found in the [experiments folder](./experiments). For example, 46 | 47 | ```bash 48 | bash cifar_experiment cifar10 resnet fmix ./data 49 | ``` 50 | 51 | will train a PreAct-ResNet18 on CIFAR-10 with FMix. More information can be found at the start of each of the shell files. 52 | 53 | ### Additional Experiments 54 | 55 | All additional classification experiments can be run via [`trainer.py`](./trainer.py) 56 | 57 | ### Analyses 58 | 59 | For Grad-CAM, take a look at the [Grad-CAM notebook in colab](https://colab.research.google.com/github/ecs-vlc/fmix/blob/master/notebooks/grad_cam.ipynb). 60 | 61 | For the other analyses, have a look in the [analysis folder](./analysis). 62 | 63 | ## Implementations 64 | 65 | The core implementation of `FMix` uses `numpy` and can be found in [`fmix.py`](./fmix.py). We provide bindings for this in [PyTorch](https://pytorch.org/) (with [Torchbearer](https://github.com/pytorchbearer/torchbearer) or [PyTorch-Lightning](https://github.com/PyTorchLightning/pytorch-lightning)) and [Tensorflow](https://www.tensorflow.org/). 66 | 67 | ### Torchbearer 68 | 69 | The `FMix` callback in [`torchbearer_implementation.py`](./implementations/torchbearer_implementation.py) can be added directly to your torchbearer code: 70 | 71 | ```python 72 | from implementations.torchbearer_implementation import FMix 73 | 74 | fmix = FMix() 75 | trial = Trial(model, optimiser, fmix.loss(), callbacks=[fmix]) 76 | ``` 77 | 78 | See an example in [`test_torchbearer.py`](./implementations/test_torchbearer.py). 79 | 80 | ### PyTorch-Lightning 81 | 82 | For PyTorch-Lightning, we provide a class, `FMix` in [`lightning.py`](./implementations/lightning.py) that can be used in your `LightningModule`: 83 | 84 | ```python 85 | from implementations.lightning import FMix 86 | 87 | class CoolSystem(pl.LightningModule): 88 | def __init__(self): 89 | ... 90 | 91 | self.fmix = FMix() 92 | 93 | def training_step(self, batch, batch_nb): 94 | x, y = batch 95 | x = self.fmix(x) 96 | 97 | x = self.forward(x) 98 | 99 | loss = self.fmix.loss(x, y) 100 | return {'loss': loss} 101 | ``` 102 | 103 | See an example in [`test_lightning.py`](./implementations/test_lightning.py). 104 | 105 | ### Tensorflow 106 | 107 | For Tensorflow, we provide a class, `FMix` in [`tensorflow_implementation.py`](./implementations/tensorflow_implementation.py) that can be used in your tensorflow code: 108 | 109 | ```python 110 | from implementations.tensorflow_implementation import FMix 111 | 112 | fmix = FMix() 113 | 114 | def loss(model, x, y, training=True): 115 | x = fmix(x) 116 | y_ = model(x, training=training) 117 | return tf.reduce_mean(fmix.loss(y_, y)) 118 | ``` 119 | 120 | See an example in [`test_tensorflow.py`](./implementations/test_tensorflow.py). 121 | 122 | ## Pre-trained Models 123 | 124 | We provide pre-trained models via `torch.hub` (more coming soon). 
To use them, run 125 | 126 | ```python 127 | import torch 128 | model = torch.hub.load('ecs-vlc/FMix:master', ARCHITECTURE, pretrained=True) 129 | ``` 130 | 131 | where `ARCHITECTURE` is one of the following: 132 | 133 | ### CIFAR-10 134 | 135 | #### PreAct-ResNet-18 136 | 137 | | Configuration | `ARCHITECTURE` | Accuracy | 138 | | ---------------- | ----------------------------------------- | -------- | 139 | | Baseline | `'preact_resnet18_cifar10_baseline'` | -------- | 140 | | + MixUp | `'preact_resnet18_cifar10_mixup'` | -------- | 141 | | + FMix | `'preact_resnet18_cifar10_fmix'` | -------- | 142 | | + Mixup + FMix | `'preact_resnet18_cifar10_fmixplusmixup'` | -------- | 143 | 144 | #### PyramidNet-200 145 | 146 | | Configuration | `ARCHITECTURE` | Accuracy | 147 | | ---------------- | ------------------------------------ | --------- | 148 | | Baseline | `'pyramidnet_cifar10_baseline'` | 98.31 | 149 | | + MixUp | `'pyramidnet_cifar10_mixup'` | 97.92 | 150 | | + FMix | `'pyramidnet_cifar10_fmix'` | __98.64__ | 151 | 152 | ### ImageNet 153 | 154 | #### ResNet-101 155 | 156 | | Configuration | `ARCHITECTURE` | Accuracy (Top-1) | 157 | | ---------------- | ------------------------------------ | ---------------- | 158 | | Baseline | `'renset101_imagenet_baseline'` | 76.51 | 159 | | + MixUp | `'renset101_imagenet_mixup'` | 76.27 | 160 | | + FMix | `'renset101_imagenet_fmix'` | __76.72__ | 161 | -------------------------------------------------------------------------------- /analysis/README.md: -------------------------------------------------------------------------------- 1 | # Analysis 2 | 3 | This directory contains the code used for the VAE and mutual information experiments from the paper. 4 | -------------------------------------------------------------------------------- /analysis/adversarial_linf.py: -------------------------------------------------------------------------------- 1 | # !pip install foolbox 2 | import torch 3 | import eagerpy as ep 4 | from foolbox import PyTorchModel, accuracy, samples 5 | import foolbox.attacks as fa 6 | import numpy as np 7 | import json 8 | 9 | from torchvision.datasets import CIFAR10 10 | import torchvision.transforms as transforms 11 | from torch.utils.data import DataLoader 12 | from torch import nn 13 | 14 | from sklearn.model_selection import ParameterGrid 15 | 16 | import argparse 17 | 18 | from tqdm import tqdm 19 | 20 | parser = argparse.ArgumentParser(description='Imagenet Training') 21 | parser.add_argument('--arr', default=0, type=int, help='point in job array') 22 | args = parser.parse_args() 23 | 24 | param_grid = ParameterGrid({ 25 | 'mode': ('baseline', 'cutmix', 'mixup', 'fmix'), 26 | 'repeat': list(range(5)) 27 | }) 28 | 29 | params = param_grid[args.arr] 30 | mode = params['mode'] 31 | repeat = params['repeat'] 32 | 33 | test_transform = transforms.Compose([ 34 | transforms.ToTensor() # convert to tensor 35 | ]) 36 | 37 | # load data 38 | testset = CIFAR10(".", train=False, download=True, transform=test_transform) 39 | testloader = DataLoader(testset, batch_size=750, shuffle=False, num_workers=5) 40 | 41 | attacks = [ 42 | fa.FGSM(), 43 | fa.LinfPGD(), 44 | fa.LinfBasicIterativeAttack(), 45 | fa.LinfAdditiveUniformNoiseAttack(), 46 | fa.LinfDeepFoolAttack(), 47 | ] 48 | 49 | epsilons = [ 50 | 0.0, 51 | 0.0005, 52 | 0.001, 53 | 0.0015, 54 | 0.002, 55 | 0.003, 56 | 0.005, 57 | 0.01, 58 | 0.02, 59 | 0.03, 60 | 0.1, 61 | 0.3, 62 | 0.5, 63 | 1.0, 64 | ] 65 | 66 | 67 | def normalize_with(mean, std): 68 | mean = torch.tensor(mean) 69 
| std = torch.tensor(std) 70 | return lambda x: (x - mean.to(x.device).unsqueeze(0).unsqueeze(2).unsqueeze(3)) / std.to(x.device).unsqueeze(0).unsqueeze(2).unsqueeze(3) 71 | 72 | # (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 73 | 74 | 75 | class Normalized(nn.Module): 76 | def __init__(self, model, mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]): 77 | super().__init__() 78 | self.model = model 79 | self.normalize = normalize_with(mean=mean, std=std) 80 | 81 | def forward(self, x): 82 | x = self.normalize(x) 83 | return self.model(x) 84 | 85 | 86 | results = dict() 87 | 88 | # for mode in ('baseline', 'mixup', 'fmix', 'cutmix', 'fmixplusmixup'): 89 | print(mode) 90 | results[mode] = dict() 91 | 92 | model = Normalized(torch.hub.load('ecs-vlc/FMix:master', f'preact_resnet18_cifar10_{mode}', pretrained=True, repeat=repeat)) 93 | model.eval() 94 | fmodel = PyTorchModel(model, bounds=(0, 1)) 95 | 96 | attack_success = np.zeros((len(attacks), len(epsilons), len(testset)), dtype=np.bool) 97 | for i, attack in enumerate(attacks): 98 | # print(attack) 99 | idx = 0 100 | for images, labels in tqdm(testloader): 101 | # print('.', end='') 102 | images = images.to(fmodel.device) 103 | labels = labels.to(fmodel.device) 104 | 105 | _, _, success = attack(fmodel, images, labels, epsilons=epsilons) 106 | success_ = success.cpu().numpy() 107 | attack_success[i][:, idx:idx + len(labels)] = success_ 108 | idx = idx + len(labels) 109 | # print("") 110 | 111 | import pickle 112 | with open(f'adversarial_linf_{mode}_{repeat}.p', 'wb') as f: 113 | # Pickle the 'data' dictionary using the highest protocol available. 114 | pickle.dump(attack_success, f) 115 | # for i, attack in enumerate(attacks): 116 | # results[mode][str(attack)] = (1.0 - attack_success[i].mean(axis=-1)).tolist() 117 | # 118 | # robust_accuracy = 1.0 - attack_success.max(axis=0).mean(axis=-1) 119 | # results[mode]['robust_accuracy'] = robust_accuracy.tolist() 120 | # 121 | # with open('adv-results-cifar-linf.json', 'w') as fp: 122 | # json.dump(results, fp) 123 | -------------------------------------------------------------------------------- /analysis/mine.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from torchbearer import state_key 9 | from torchbearer import callbacks 10 | 11 | T = state_key('t') 12 | T_SHUFFLED = state_key('t_shuffled') 13 | MI = state_key('mi') 14 | 15 | 16 | class Flatten(nn.Module): 17 | def forward(self, x): 18 | return x.view(x.size(0), -1) 19 | 20 | 21 | class DoNothing(nn.Module): 22 | def forward(self, x): 23 | return x 24 | 25 | 26 | def resample(x): 27 | return F.fold(F.unfold(x, kernel_size=2, stride=2), (int(x.size(2) / 2), int(x.size(3) / 2)), 1) 28 | 29 | 30 | class Estimator(nn.Module): 31 | def __init__(self, conv, in_size, pool_input=False, halves=0): 32 | super().__init__() 33 | self.pool = DoNothing() 34 | self.halves = halves 35 | 36 | if conv: 37 | in_size = in_size + 3 * 4 ** halves 38 | if pool_input: 39 | self.pool = nn.AdaptiveAvgPool2d((8, 8)) 40 | 41 | self.est = nn.Sequential( 42 | nn.Conv2d(in_size, 256, kernel_size=3, padding=1), 43 | nn.ReLU(), 44 | nn.AdaptiveAvgPool2d((2, 2)), 45 | Flatten(), 46 | nn.Linear(2 * 2 * 256, 512), 47 | nn.ReLU(), 48 | nn.Linear(512, 1) 49 | ) 50 | else: 51 | in_size = in_size + 32 * 32 * 3 52 | self.est = nn.Sequential( 53 | nn.Linear(in_size, 
512), 54 | nn.ReLU(), 55 | nn.Linear(512, 512), 56 | nn.ReLU(), 57 | nn.Linear(512, 1) 58 | ) 59 | 60 | def forward(self, x, f): 61 | if self.halves < 5: 62 | for i in range(self.halves): 63 | x = resample(x) 64 | else: 65 | x = x.view(x.size(0), -1) 66 | f = f.view(f.size(0), -1) 67 | if f.dim() == 2: 68 | x = x.view(x.size(0), -1) 69 | x = torch.cat((x, f), dim=1) 70 | x = self.pool(x) 71 | 72 | return self.est(x) 73 | 74 | 75 | cfgs = { 76 | 'A': { 77 | 'f1': lambda: Estimator(True, 64, halves=1), 'f2': lambda: Estimator(True, 128, halves=2), 'f3': lambda: Estimator(True, 256, halves=2), 78 | 'f4': lambda: Estimator(True, 256, halves=3), 'f5': lambda: Estimator(True, 512, halves=3), 'f6': lambda: Estimator(True, 512, halves=4), 79 | 'f7': lambda: Estimator(True, 512, halves=4), 80 | 'f8': lambda: Estimator(False, 512, False, halves=5), 'c1': lambda: Estimator(False, 2048), 'c2': lambda: Estimator(False, 2048)}, 81 | 'B': {}, 82 | 'D': {}, 83 | 'E': {}, 84 | } 85 | 86 | 87 | def mi(tanh): 88 | def mi_loss(state): 89 | m_t, m_t_shuffled = state[torchbearer.Y_PRED] 90 | mi = {} 91 | sum = 0.0 92 | for layer in m_t.keys(): 93 | t = m_t[layer] 94 | t_shuffled = m_t_shuffled[layer] 95 | if tanh: 96 | t = t.tanh() 97 | t_shuffled = t_shuffled.tanh() 98 | tmp = t.mean() - (torch.logsumexp(t_shuffled, 0) - math.log(t_shuffled.size(0))) 99 | mi[layer] = tmp.item() 100 | sum += tmp 101 | if len(mi.keys()) == 1: 102 | state[MI] = mi[next(iter(mi.keys()))] 103 | else: 104 | state[MI] = mi 105 | return -sum 106 | return mi_loss 107 | 108 | 109 | def process(x, cache, cfg): 110 | t = {} 111 | t_shuffled = {} 112 | 113 | for layer in cfg.keys(): 114 | out = cache[layer].detach() 115 | t[layer] = cfg[layer](x, out) 116 | t_shuffled[layer] = cfg[layer](x, out[torch.randperm(out.size(0))]) 117 | return t, t_shuffled 118 | 119 | 120 | class MimeVGG(nn.Module): 121 | def __init__(self, vgg, cfg): 122 | super().__init__() 123 | 124 | self.vgg = vgg 125 | self.cfg = nn.ModuleDict(cfg) 126 | 127 | def forward(self, x): 128 | pred, cache = self.vgg(x) 129 | 130 | t, t_shuffled = process(x, cache, self.cfg) 131 | 132 | return t, t_shuffled 133 | 134 | 135 | if __name__ == '__main__': 136 | from torch import optim 137 | from torchvision import transforms 138 | from torchvision.datasets import CIFAR10 139 | 140 | import torchbearer 141 | from torchbearer import Trial 142 | 143 | OTHER_MI = state_key('other_mi') 144 | 145 | cfg = cfgs['A'] 146 | 147 | import argparse 148 | 149 | parser = argparse.ArgumentParser(description='VGG MI') 150 | parser.add_argument('--model', default='mix_3', type=str, help='model') 151 | args = parser.parse_args() 152 | 153 | from .vgg import vgg11_bn 154 | 155 | vgg = vgg11_bn(return_cache=True) 156 | vgg.load_state_dict(torch.load(args.model + '.pt')[torchbearer.MODEL]) 157 | for param in vgg.parameters(): 158 | param.requires_grad = False 159 | 160 | for layer in cfg: 161 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 162 | transform_base = [transforms.ToTensor(), normalize] 163 | 164 | transform = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] + transform_base 165 | 166 | transform_train = transforms.Compose(transform) 167 | transform_test = transforms.Compose(transform_base) 168 | 169 | trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_train) 170 | valset = CIFAR10(root='./data', train=False, download=True, transform=transform_test) 171 | 172 | trainloader = 
torch.utils.data.DataLoader(trainset, batch_size=1000, shuffle=True, num_workers=8) 173 | valloader = torch.utils.data.DataLoader(valset, batch_size=5000, shuffle=True, num_workers=8) 174 | 175 | model = MimeVGG(vgg, {k: cfgs['A'][k]() for k in [layer]}) 176 | 177 | optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters()), lr=5e-4) 178 | 179 | mi_false = mi(False) 180 | 181 | @callbacks.add_to_loss 182 | def mi_no_tanh(state): 183 | state[OTHER_MI] = mi_false(state) 184 | return 0 185 | 186 | trial = Trial(model, optimizer, mi(True), metrics=['loss', torchbearer.metrics.mean(OTHER_MI)], callbacks=[mi_no_tanh, callbacks.TensorBoard(write_graph=False, comment='mi_' + args.model, log_dir='mi_data')]) 187 | trial.with_generators(train_generator=trainloader, val_generator=valloader).to('cuda') 188 | trial.run(20, verbose=1) 189 | -------------------------------------------------------------------------------- /analysis/train_vaes.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import torch.optim as optim 3 | from torch.utils.data import DataLoader 4 | from torchvision import transforms, datasets 5 | import torchbearer 6 | from torchbearer import Trial, callbacks, metrics 7 | from torchbearer.callbacks import init 8 | import torch 9 | from torch import distributions 10 | 11 | from .vae import VAE, LATENT 12 | from implementations.torchbearer_implementation import FMix 13 | 14 | import argparse 15 | 16 | parser = argparse.ArgumentParser(description='VAE Training') 17 | parser.add_argument('--mode', default='base', type=str, help='name of run') 18 | parser.add_argument('--i', default=1, type=int, help='iteration') 19 | parser.add_argument('--var', default=1, type=float, help='iteration') 20 | parser.add_argument('--epochs', default=100, type=int, help='epochs') 21 | parser.add_argument('--dir', default='vaes', type=str, help='directory') 22 | args = parser.parse_args() 23 | 24 | KL = torchbearer.state_key('KL') 25 | NLL = torchbearer.state_key('NLL') 26 | SAMPLE = torchbearer.state_key('SAMPLE') 27 | 28 | # Data 29 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 30 | inv_normalize = transforms.Normalize((-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010), (1/0.2023, 1/0.1994, 1/0.2010)) 31 | transform_base = [transforms.ToTensor(), normalize] 32 | 33 | transform = [transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), transforms.RandomHorizontalFlip()] + transform_base 34 | 35 | transform_train = transforms.Compose(transform) 36 | transform_test = transforms.Compose(transform_base) 37 | 38 | train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 39 | test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 40 | 41 | train_loader = DataLoader(train_set, 128, shuffle=True, num_workers=5) 42 | test_loader = DataLoader(test_set, 100, shuffle=False, num_workers=5) 43 | 44 | # KL Divergence 45 | 46 | def kld(prior): 47 | @callbacks.add_to_loss 48 | def loss(state): 49 | res = distributions.kl_divergence(state[LATENT], prior).sum().div(state[LATENT].loc.size(0)) 50 | state[KL] = res.detach() 51 | return res 52 | return loss 53 | 54 | # Negative Log Likelihood 55 | 56 | def nll(state): 57 | y_pred, y_true = state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE] 58 | res = - y_pred.log_prob(y_true).sum().div(y_true.size(0)) 59 | state[NLL] = res.detach() 60 | return res 61 | 62 | # Generate Some Images 63 | 
64 | @torchbearer.callbacks.on_forward 65 | @torchbearer.callbacks.on_forward_validation 66 | def sample(state): 67 | state[SAMPLE] = state[torchbearer.Y_PRED].loc 68 | 69 | # Train VAEs 70 | 71 | aug = [] 72 | mode = args.mode 73 | if mode == 'mix': 74 | aug = [callbacks.Mixup()] 75 | if mode == 'cutmix': 76 | aug = [callbacks.CutMix(1, classes=10)] 77 | if mode == 'fmix': 78 | aug = [FMix(alpha=1, decay_power=3)] 79 | 80 | model = VAE(64, var=args.var) 81 | trial = Trial(model, optim.Adam(model.parameters(), lr=5e-2), nll, 82 | metrics=[ 83 | metrics.MeanSquaredError(pred_key=SAMPLE), 84 | metrics.mean(NLL), 85 | metrics.mean(KL), 86 | 'loss' 87 | ], 88 | callbacks=[ 89 | sample, 90 | kld(distributions.Normal(0, 1)), 91 | init.XavierNormal(targets=['Conv']), 92 | callbacks.MostRecent(args.dir + '/' + mode + '_' + str(args.i) + '.pt'), 93 | callbacks.MultiStepLR([40, 80]), 94 | callbacks.TensorBoard(write_graph=False, comment=mode + '_' + str(args.i), log_dir='vae_logs'), 95 | *aug 96 | ]) 97 | 98 | if mode in ['base', 'mix', 'cutmix']: 99 | trial = trial.load_state_dict(torch.load('vaes/' + '/' + mode + '_' + str(args.i) + '.pt')) 100 | 101 | trial.with_generators(train_loader, test_loader).to('cuda').run(args.epochs, verbose=1) 102 | -------------------------------------------------------------------------------- /analysis/vae.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import distributions 6 | from torch.distributions import constraints, register_kl 7 | 8 | import torchbearer 9 | from torchbearer import state_key 10 | 11 | LATENT = state_key('latent') 12 | 13 | 14 | class LogitNormal(distributions.Normal): 15 | arg_constraints = {'loc': constraints.real, 'log_scale': constraints.real} 16 | support = constraints.real 17 | has_rsample = True 18 | 19 | def __init__(self, loc, log_scale, validate_args=None): 20 | self.log_scale = log_scale 21 | scale = distributions.transform_to(distributions.Normal.arg_constraints['scale'])(log_scale) 22 | super().__init__(loc, scale, validate_args=validate_args) 23 | 24 | def log_prob(self, value): 25 | if self._validate_args: 26 | self._validate_sample(value) 27 | # compute the variance 28 | var = (self.scale ** 2) 29 | log_scale = self.log_scale 30 | return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi)) 31 | 32 | 33 | @register_kl(LogitNormal, distributions.Normal) 34 | def kl_logitnormal_normal(p, q): 35 | log_var_ratio = 2 * (p.log_scale - q.scale.log()) 36 | t1 = ((p.loc - q.loc) / q.scale).pow(2) 37 | return 0.5 * (log_var_ratio.exp() + t1 - 1 - log_var_ratio) 38 | 39 | 40 | @register_kl(LogitNormal, LogitNormal) 41 | def kl_logitnormal_logitnormal(p, q): 42 | log_var_ratio = 2 * (p.log_scale - q.log_scale) 43 | t1 = ((p.loc - q.loc) / q.scale).pow(2) 44 | return 0.5 * (log_var_ratio.exp() + t1 - 1 - log_var_ratio) 45 | 46 | 47 | class Flatten(nn.Module): 48 | def forward(self, x): 49 | return x.view(x.size(0), -1) 50 | 51 | 52 | class View(nn.Module): 53 | def __init__(self, *args): 54 | super().__init__() 55 | self.args = args 56 | 57 | def forward(self, x): 58 | return x.view(x.size(0), *self.args) 59 | 60 | 61 | class SimpleEncoder(nn.Sequential): 62 | def __init__(self): 63 | super().__init__( 64 | nn.Conv2d(3, 32, (4, 4), stride=2, padding=1), 65 | nn.ReLU(True), 66 | nn.BatchNorm2d(32), 67 | nn.Conv2d(32, 32, (4, 4), stride=2, padding=1), 68 | nn.ReLU(True), 69 | nn.BatchNorm2d(32), 70 
| nn.Conv2d(32, 32, (4, 4), stride=2, padding=1), 71 | nn.ReLU(True), 72 | nn.BatchNorm2d(32), 73 | Flatten() 74 | ) 75 | 76 | self.output_size = 32 * 4 * 4 77 | 78 | 79 | class SimpleDecoder(nn.Sequential): 80 | def __init__(self, z_dims): 81 | super().__init__( 82 | nn.Linear(z_dims, 32 * 4 * 4), 83 | View(32, 4, 4), 84 | nn.ConvTranspose2d(32, 32, (4, 4), stride=2, padding=1), 85 | nn.ReLU(True), 86 | nn.BatchNorm2d(32), 87 | nn.ConvTranspose2d(32, 32, (4, 4), stride=2, padding=1), 88 | nn.ReLU(True), 89 | nn.BatchNorm2d(32), 90 | nn.ConvTranspose2d(32, 3, (4, 4), stride=2, padding=1) 91 | ) 92 | 93 | 94 | class DCGANDecoder(nn.Sequential): 95 | def __init__(self, z_dims, dim=128): 96 | super().__init__( 97 | nn.Linear(z_dims, dim * 4 * 4 * 4), 98 | nn.ReLU(True), 99 | nn.BatchNorm1d(dim * 4 * 4 * 4), 100 | View(dim * 4, 4, 4), 101 | nn.ConvTranspose2d(dim * 4, dim * 2, 5, stride=2, padding=2, output_padding=1), 102 | nn.ReLU(True), 103 | nn.BatchNorm2d(dim * 2), 104 | nn.ConvTranspose2d(dim * 2, dim, 5, stride=2, padding=2, output_padding=1), 105 | nn.ReLU(True), 106 | nn.BatchNorm2d(dim), 107 | nn.ConvTranspose2d(dim, 3, 5, stride=2, padding=2, output_padding=1) 108 | ) 109 | 110 | 111 | class DCGANEncoder(nn.Sequential): 112 | def __init__(self, dim=128): 113 | super().__init__( 114 | nn.Conv2d(3, dim, 5, stride=2, padding=2), 115 | nn.ReLU(True), 116 | nn.BatchNorm2d(dim), 117 | nn.Conv2d(dim, dim * 2, 5, stride=2, padding=2), 118 | nn.ReLU(True), 119 | nn.BatchNorm2d(dim * 2), 120 | nn.Conv2d(dim * 2, dim * 4, 5, stride=2, padding=2), 121 | nn.ReLU(True), 122 | nn.BatchNorm2d(dim * 4), 123 | Flatten() 124 | ) 125 | 126 | self.output_size = dim * 4 * 4 * 4 127 | 128 | 129 | class BetaVAEDecoder(nn.Sequential): 130 | def __init__(self, z_dims): 131 | super().__init__( 132 | nn.Linear(z_dims, 256), 133 | nn.ReLU(True), 134 | nn.BatchNorm1d(256), 135 | nn.Linear(256, 64 * 2 * 2), 136 | nn.ReLU(True), 137 | nn.BatchNorm1d(64 * 2 * 2), 138 | View(64, 2, 2), 139 | nn.ConvTranspose2d(64, 64, 4, stride=2, padding=1), 140 | nn.ReLU(True), 141 | nn.BatchNorm2d(64), 142 | nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), 143 | nn.ReLU(True), 144 | nn.BatchNorm2d(32), 145 | nn.ConvTranspose2d(32, 32, 4, stride=2, padding=1), 146 | nn.ReLU(True), 147 | nn.BatchNorm2d(32), 148 | nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1) 149 | ) 150 | 151 | 152 | class BetaVAEDecoder2(nn.Sequential): 153 | def __init__(self, z_dims): 154 | super().__init__( 155 | nn.Linear(z_dims, 256), 156 | nn.ReLU(True), 157 | # nn.BatchNorm1d(256), 158 | nn.Linear(256, 64 * 2 * 2), 159 | nn.ReLU(True), 160 | # nn.BatchNorm1d(64 * 2 * 2), 161 | View(64, 2, 2), 162 | nn.Upsample(size=5, mode='bilinear'), 163 | nn.Conv2d(64, 64, 4, padding=1), 164 | nn.ReLU(True), 165 | nn.BatchNorm2d(64), 166 | nn.Upsample(size=9, mode='bilinear'), 167 | nn.Conv2d(64, 32, 4, padding=1), 168 | nn.ReLU(True), 169 | nn.BatchNorm2d(32), 170 | nn.Upsample(size=17, mode='bilinear'), 171 | nn.Conv2d(32, 32, 4, padding=1), 172 | nn.ReLU(True), 173 | nn.BatchNorm2d(32), 174 | nn.Upsample(size=33, mode='bilinear'), 175 | nn.Conv2d(32, 3, 4, padding=1) 176 | ) 177 | 178 | 179 | class BetaVAEEncoder(nn.Sequential): 180 | def __init__(self): 181 | super().__init__( 182 | nn.Conv2d( 3, 32, 4, stride=2, padding=1), 183 | nn.ReLU(True), 184 | nn.BatchNorm2d(32), 185 | nn.Conv2d(32, 32, 4, stride=2, padding=1), 186 | nn.ReLU(True), 187 | nn.BatchNorm2d(32), 188 | nn.Conv2d(32, 64, 4, stride=2, padding=1), 189 | nn.ReLU(True), 190 | nn.BatchNorm2d(64), 
191 | nn.Conv2d(64, 64, 4, stride=2, padding=1), 192 | nn.ReLU(True), 193 | nn.BatchNorm2d(64), 194 | Flatten(), 195 | nn.Linear(64 * 2 * 2, 256), 196 | nn.ReLU(True), 197 | # nn.BatchNorm1d(256) 198 | ) 199 | 200 | self.output_size = 256 201 | 202 | 203 | class VAE(nn.Module): 204 | def __init__(self, z_dims=64, encoder=BetaVAEEncoder, decoder=BetaVAEDecoder2, var=0.1): 205 | super(VAE, self).__init__() 206 | 207 | self.var = var 208 | 209 | self.encoder = encoder() 210 | self.decoder = decoder(z_dims) 211 | 212 | self.loc = nn.Linear(self.encoder.output_size, z_dims) 213 | self.scale = nn.Linear(self.encoder.output_size, z_dims) 214 | self.loc.weight.data.zero_() 215 | self.loc.bias.data.zero_() 216 | self.scale.weight.data.zero_() 217 | self.scale.bias.data.zero_() 218 | 219 | def encode(self, x): 220 | x = self.encoder(x) 221 | loc = self.loc(x) 222 | scale = self.scale(x) 223 | return LogitNormal(loc, scale) 224 | 225 | def forward(self, x, state=None): 226 | if state is not None: 227 | state[torchbearer.TARGET] = x.detach() 228 | 229 | latent = self.encode(x) 230 | 231 | if state is not None: 232 | state[LATENT] = latent 233 | 234 | x = self.decoder(latent.rsample()) 235 | return LogitNormal(x, (torch.ones_like(x) * self.var).log()) 236 | 237 | 238 | class PredictionNetwork(nn.Module): 239 | def __init__(self, encoder_a, encoder_b, z_dims=32): 240 | super().__init__() 241 | self.z_dims = z_dims 242 | 243 | self.encoder_a = encoder_a 244 | for param in self.encoder_a.parameters(): 245 | param.requires_grad = False 246 | 247 | self.encoder_b = encoder_b 248 | for param in self.encoder_b.parameters(): 249 | param.requires_grad = False 250 | 251 | self.net = nn.Sequential( 252 | nn.Linear(z_dims, 32), 253 | nn.ReLU(True), 254 | nn.Linear(32, 32), 255 | nn.ReLU(True), 256 | nn.Linear(32, 32), 257 | nn.ReLU(True), 258 | nn.Linear(32, z_dims * 2) 259 | ) 260 | 261 | self.net[-1].weight.data.zero_() 262 | self.net[-1].bias.data.zero_() 263 | 264 | def forward(self, x, state): 265 | self.encoder_a.eval() 266 | self.encoder_b.eval() 267 | 268 | a = self.encoder_a.encode(x).rsample().detach() 269 | b = self.encoder_b.encode(x) 270 | b.loc = b.loc.detach() 271 | b.scale = b.scale.detach() 272 | 273 | state[torchbearer.TARGET] = b 274 | 275 | x = self.net(a) 276 | loc = x[:, :self.z_dims] 277 | scale = x[:, self.z_dims:] 278 | return LogitNormal(loc, scale) 279 | 280 | 281 | class MINetwork(nn.Module): 282 | def __init__(self, encoder_a, encoder_b, upper=False): 283 | super().__init__() 284 | self.upper = upper 285 | 286 | self.encoder_a = encoder_a 287 | for param in self.encoder_a.parameters(): 288 | param.requires_grad = False 289 | 290 | self.encoder_b = encoder_b 291 | for param in self.encoder_b.parameters(): 292 | param.requires_grad = False 293 | 294 | def forward(self, x, state): 295 | self.encoder_a.eval() 296 | self.encoder_b.eval() 297 | 298 | if not self.upper: 299 | x = self.encoder_a(x).rsample() 300 | return self.encoder_b.encode(x) 301 | -------------------------------------------------------------------------------- /analysis/vae_mi.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms, datasets 4 | import torchbearer 5 | from torchbearer import Trial 6 | import torch 7 | from torch import distributions 8 | 9 | from .vae import VAE, MINetwork 10 | 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser(description='VAE MI') 14 | 
parser.add_argument('--vae1', default='base_5', type=str, help='VAE 1') 15 | parser.add_argument('--vae2', default='cutmix_5', type=str, help='VAE 2') 16 | parser.add_argument('--upper', default=False, type=bool, help='if True, use upper bound, else lower') 17 | args = parser.parse_args() 18 | 19 | # Data 20 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 21 | inv_normalize = transforms.Normalize((-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010), (1/0.2023, 1/0.1994, 1/0.2010)) 22 | transform_base = [transforms.ToTensor(), normalize] 23 | 24 | transform_test = transforms.Compose(transform_base) 25 | test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 26 | 27 | test_loader = DataLoader(test_set, 100, shuffle=True, num_workers=5) 28 | 29 | # MI Loss 30 | 31 | def kld(state): 32 | y_pred = state[torchbearer.Y_PRED] 33 | marginal = distributions.Normal(0, 1) 34 | return distributions.kl_divergence(y_pred, marginal).sum().div(y_pred.loc.size(0)) 35 | 36 | # Train VAEs 37 | 38 | vae1 = VAE(64) 39 | vae1.load_state_dict(torch.load('vaes3/' + args.vae1 + '.pt')[torchbearer.MODEL]) 40 | 41 | for param in vae1.parameters(): 42 | param.requires_grad = False 43 | 44 | vae2 = VAE(64) 45 | vae2.load_state_dict(torch.load('vaes3/' + args.vae2 + '.pt')[torchbearer.MODEL]) 46 | 47 | for param in vae2.parameters(): 48 | param.requires_grad = False 49 | 50 | model = MINetwork(vae1, vae2, upper=args.upper) 51 | trial = Trial(model, criterion=kld, 52 | metrics=[ 53 | 'loss' 54 | ]) 55 | 56 | trial.with_generators(test_generator=test_loader).to('cuda').evaluate(data_key=torchbearer.TEST_DATA) 57 | -------------------------------------------------------------------------------- /analysis/vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://pytorch.org/docs/stable/_modules/torchvision/models/vgg.html 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | __all__ = [ 9 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 10 | 'vgg19_bn', 'vgg19', 11 | ] 12 | 13 | 14 | class Cache: 15 | def __init__(self): 16 | super(Cache, self).__init__() 17 | self.outputs = {} 18 | 19 | def for_name(outer, layer_name): 20 | class Inner(nn.Module): 21 | def forward(inner, x): 22 | outer.outputs[layer_name] = x 23 | return x 24 | return Inner() 25 | 26 | def get_outputs(self): 27 | tmp = self.outputs 28 | self.outputs = {} 29 | return tmp 30 | 31 | 32 | class VGG(nn.Module): 33 | def __init__(self, features, cache, return_cache=False, num_classes=10, init_weights=True): 34 | super(VGG, self).__init__() 35 | self.features = features 36 | self.return_cache = return_cache 37 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 38 | self.cache = cache 39 | self.classifier = nn.Sequential( 40 | nn.Linear(512 * 7 * 7, 2048), # Normally 4096 41 | nn.ReLU(True), 42 | self.cache.for_name('c1'), 43 | nn.Dropout(), 44 | nn.Linear(2048, 2048), # Normally 4096 45 | nn.ReLU(True), 46 | self.cache.for_name('c2'), 47 | nn.Dropout(), 48 | nn.Linear(2048, num_classes), # Normally 4096 49 | ) 50 | if init_weights: 51 | self._initialize_weights() 52 | 53 | def forward(self, x): 54 | x = self.features(x) 55 | x = self.avgpool(x) 56 | x = torch.flatten(x, 1) 57 | x = self.classifier(x) 58 | outs = self.cache.get_outputs() 59 | if self.return_cache: 60 | return x, outs 61 | else: 62 | return x 63 | 64 | def _initialize_weights(self): 65 | for m in self.modules(): 66 | if isinstance(m, 
nn.Conv2d): 67 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 68 | if m.bias is not None: 69 | nn.init.constant_(m.bias, 0) 70 | elif isinstance(m, nn.BatchNorm2d): 71 | nn.init.constant_(m.weight, 1) 72 | nn.init.constant_(m.bias, 0) 73 | elif isinstance(m, nn.Linear): 74 | nn.init.normal_(m.weight, 0, 0.01) 75 | nn.init.constant_(m.bias, 0) 76 | 77 | 78 | def make_layers(cfg, batch_norm=False): 79 | cache = Cache() 80 | layers = [] 81 | in_channels = 3 82 | for v in cfg: 83 | if v == 'M': 84 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 85 | elif str(v) is v: 86 | layers += [cache.for_name(v)] 87 | else: 88 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 89 | if batch_norm: 90 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 91 | else: 92 | layers += [conv2d, nn.ReLU(inplace=True)] 93 | in_channels = v 94 | return nn.Sequential(*layers), cache 95 | 96 | 97 | cfgs = { 98 | 'A': [64, 'M', 'f1', 128, 'M', 'f2', 256, 'f3', 256, 'M', 'f4', 512, 'f5', 512, 'M', 'f6', 512, 'f7', 512, 'M', 'f8'], 99 | 'B': [64, 'f1', 64, 'M', 'f2', 128, 'f3', 128, 'M', 'f4', 256, 'f5', 256, 'M', 'f6', 512, 'f7', 512, 'M', 'f8', 512, 'f9', 512, 'M', 'f10'], 100 | 'D': [64, 'f1', 64, 'M', 'f2', 128, 'f3', 128, 'M', 'f4', 256, 'f5', 256, 'f6', 256, 'M', 'f7', 512, 'f8', 512, 'f9', 512, 'M', 'f10', 512, 'f11', 512, 'f12', 512, 'M', 'f13'], 101 | 'E': [64, 'f1', 64, 'M', 'f2', 128, 'f3', 128, 'M', 'f4', 256, 'f5', 256, 'f6', 256, 'f7', 256, 'M', 'f8', 512, 'f9', 512, 'f10', 512, 'f11', 512, 'M', 'f12', 512, 'f13', 512, 'f14', 512, 'f15', 512, 'M', 'f16'], 102 | } 103 | 104 | 105 | def _vgg(cfg, batch_norm, **kwargs): 106 | model = VGG(*make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 107 | return model 108 | 109 | 110 | def vgg11(**kwargs): 111 | r"""VGG 11-layer model (configuration "A") from 112 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 113 | """ 114 | return _vgg('A', False, **kwargs) 115 | 116 | 117 | def vgg11_bn(**kwargs): 118 | r"""VGG 11-layer model (configuration "A") with batch normalization 119 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 120 | """ 121 | return _vgg('A', True, **kwargs) 122 | 123 | 124 | def vgg13(**kwargs): 125 | r"""VGG 13-layer model (configuration "B") 126 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 127 | """ 128 | return _vgg('B', False, **kwargs) 129 | 130 | 131 | def vgg13_bn(**kwargs): 132 | r"""VGG 13-layer model (configuration "B") with batch normalization 133 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 134 | """ 135 | return _vgg('B', True, **kwargs) 136 | 137 | 138 | def vgg16(**kwargs): 139 | r"""VGG 16-layer model (configuration "D") 140 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 141 | """ 142 | return _vgg('D', False, **kwargs) 143 | 144 | 145 | def vgg16_bn(**kwargs): 146 | r"""VGG 16-layer model (configuration "D") with batch normalization 147 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 148 | """ 149 | return _vgg('D', True, **kwargs) 150 | 151 | 152 | def vgg19(**kwargs): 153 | r"""VGG 19-layer model (configuration "E") 154 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 155 | """ 156 | return _vgg('E', False, **kwargs) 157 | 158 | 159 | def vgg19_bn(**kwargs): 160 | r"""VGG 19-layer model (configuration 'E') with batch normalization 161 | `"Very Deep Convolutional Networks For Large-Scale 
Image Recognition" `_ 162 | """ 163 | return _vgg('E', True, **kwargs) 164 | 165 | 166 | if __name__ == '__main__': 167 | from torch import optim 168 | from torchvision import transforms 169 | from torchvision.datasets import CIFAR10 170 | 171 | from torchbearer import Trial 172 | from torchbearer.callbacks import MultiStepLR, MostRecent, Mixup, CutMix 173 | from implementations.torchbearer_implementation import FMix 174 | 175 | for mode in ['baseline', 'mix', 'fmix', 'cutmix']: 176 | for i in range(0, 3): 177 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 178 | transform_base = [transforms.ToTensor(), normalize] 179 | 180 | transform = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] + transform_base 181 | 182 | transform_train = transforms.Compose(transform) 183 | transform_test = transforms.Compose(transform_base) 184 | 185 | trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_train) 186 | valset = CIFAR10(root='./data', train=False, download=True, transform=transform_test) 187 | 188 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8) 189 | valloader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=True, num_workers=8) 190 | 191 | vgg = vgg11_bn(return_cache=False) 192 | optimizer = optim.SGD(vgg.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) 193 | 194 | app = [] 195 | loss = nn.CrossEntropyLoss() 196 | 197 | if mode == 'mix': 198 | app = [Mixup()] 199 | loss = Mixup.mixup_loss 200 | if mode == 'fmix': 201 | app = [FMix(alpha=1)] 202 | loss = Mixup.mixup_loss 203 | if mode == 'cutmix': 204 | app = [CutMix(1.0, classes=10, mixup_loss=True)] 205 | loss = Mixup.mixup_loss 206 | 207 | trial = Trial(vgg, optimizer, loss, metrics=['acc', 'loss'], callbacks=app + [MostRecent(mode + '_' + str(i + 1) + '.pt'), MultiStepLR([100, 150])]) 208 | trial.with_generators(train_generator=trainloader, val_generator=valloader).to('cuda') 209 | trial.run(200, verbose=1) 210 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecs-vlc/FMix/e5991dca018882734c8ea63599f10dfbe67fa0ae/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/bengali.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://www.kaggle.com/corochann/bengali-seresnext-training-with-pytorch 3 | """ 4 | from torch.utils.data import Dataset 5 | # import os 6 | # from torchvision.datasets.folder import default_loader 7 | import numpy as np 8 | import pandas as pd 9 | import gc 10 | 11 | 12 | def prepare_image(root, indices=[0, 1, 2, 3]): 13 | # assert data_type in ['train', 'test'] 14 | # if submission: 15 | # image_df_list = [pd.read_parquet(datadir / f'{data_type}_image_data_{i}.parquet') 16 | # for i in indices] 17 | # else: 18 | image_df_list = [pd.read_feather(f'{root}/train_image_data_{i}.feather') for i in indices] 19 | 20 | HEIGHT = 137 21 | WIDTH = 236 22 | images = [df.iloc[:, 1:].values.reshape(-1, HEIGHT, WIDTH) for df in image_df_list] 23 | del image_df_list 24 | gc.collect() 25 | images = np.concatenate(images, axis=0) 26 | return images 27 | 28 | 29 | class Bengali(Dataset): 30 | def __init__(self, root, targets, transform=None): 31 | self.transform = transform 32 | 33 | if 
isinstance(targets, list): 34 | self.labels = list(pd.read_csv(f'{root}/train.csv')[targets].itertuples(index=False, name=None)) 35 | else: 36 | self.labels = pd.read_csv(f'{root}/train.csv')[targets] 37 | self.images = prepare_image(root) 38 | 39 | def __getitem__(self, index): 40 | image, label = self.images[index], self.labels[index] 41 | image = (255 - image).astype(np.float32) / 255. 42 | 43 | if self.transform is not None: 44 | image = self.transform(image) 45 | 46 | return image, label 47 | 48 | def __len__(self) -> int: 49 | return len(self.labels) 50 | 51 | 52 | class BengaliGraphemeWhole(Bengali): 53 | def __init__(self, root, transform=None): 54 | super().__init__(root, ['grapheme_root', 'vowel_diacritic', 'consonant_diacritic'], transform=transform) 55 | 56 | 57 | class BengaliGraphemeRoot(Bengali): 58 | def __init__(self, root, transform=None): 59 | super().__init__(root, 'grapheme_root', transform=transform) 60 | 61 | 62 | class BengaliVowelDiacritic(Bengali): 63 | def __init__(self, root, transform=None): 64 | super().__init__(root, 'vowel_diacritic', transform=transform) 65 | 66 | 67 | class BengaliConsonantDiacritic(Bengali): 68 | def __init__(self, root, transform=None): 69 | super().__init__(root, 'consonant_diacritic', transform=transform) 70 | -------------------------------------------------------------------------------- /datasets/fashion.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import MNIST 2 | 3 | 4 | class OldFashionMNIST(MNIST): 5 | """The original version of the `Fashion-MNIST `_ Dataset. This had 6 | some train / test duplicates and so was replaced. Do not use in practice, may be helpful for reproducing results 7 | from FMix, RandomErase and other papers. 8 | 9 | Args: 10 | root (string): Root directory of dataset where ``Fashion-MNIST/processed/training.pt`` 11 | and ``Fashion-MNIST/processed/test.pt`` exist. 12 | train (bool, optional): If True, creates dataset from ``training.pt``, 13 | otherwise from ``test.pt``. 14 | download (bool, optional): If true, downloads the dataset from the internet and 15 | puts it in root directory. If dataset is already downloaded, it is not 16 | downloaded again. 17 | transform (callable, optional): A function/transform that takes in an PIL image 18 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 19 | target_transform (callable, optional): A function/transform that takes in the 20 | target and transforms it. 
21 | """ 22 | urls = [ 23 | 'https://github.com/zalandoresearch/fashion-mnist/raw/949006594cc2f804a93b9155849734564c3545ec/data/fashion/train-images-idx3-ubyte.gz', 24 | 'https://github.com/zalandoresearch/fashion-mnist/raw/949006594cc2f804a93b9155849734564c3545ec/data/fashion/train-labels-idx1-ubyte.gz', 25 | 'https://github.com/zalandoresearch/fashion-mnist/raw/949006594cc2f804a93b9155849734564c3545ec/data/fashion/t10k-images-idx3-ubyte.gz', 26 | 'https://github.com/zalandoresearch/fashion-mnist/raw/949006594cc2f804a93b9155849734564c3545ec/data/fashion/t10k-labels-idx1-ubyte.gz', 27 | ] 28 | classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 29 | 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] 30 | -------------------------------------------------------------------------------- /datasets/google_commands/README.md: -------------------------------------------------------------------------------- 1 | Code adapted from https://github.com/tugstugi/pytorch-speech-commands -------------------------------------------------------------------------------- /datasets/google_commands/google_commands.py: -------------------------------------------------------------------------------- 1 | """Google speech commands dataset.""" 2 | __author__ = 'Yuan Xu' 3 | 4 | import os 5 | import numpy as np 6 | 7 | from torch.utils.data import Dataset 8 | 9 | __all__ = [ 'CLASSES', 'SpeechCommandsDataset', 'BackgroundNoiseDataset' ] 10 | 11 | CLASSES = 'unknown, silence, yes, no, up, down, left, right, on, off, stop, go'.split(', ') 12 | 13 | 14 | class SpeechCommandsDataset(Dataset): 15 | """Google speech commands dataset. Only 'yes', 'no', 'up', 'down', 'left', 16 | 'right', 'on', 'off', 'stop' and 'go' are treated as known classes. 17 | All other classes are used as 'unknown' samples. 18 | See for more information: https://www.kaggle.com/c/tensorflow-speech-recognition-challenge 19 | """ 20 | 21 | def __init__(self, folder, transform=None, classes=CLASSES, silence_percentage=0.1): 22 | try: 23 | import librosa 24 | except: 25 | raise ModuleNotFoundError('Librosa package is reuiqred for google commands experiments. 
Try pip install') 26 | 27 | all_classes = [d for d in os.listdir(folder) if os.path.isdir(os.path.join(folder, d)) and not d.startswith('_')] 28 | #for c in classes[2:]: 29 | # assert c in all_classes 30 | 31 | class_to_idx = {classes[i]: i for i in range(len(classes))} 32 | for c in all_classes: 33 | if c not in class_to_idx: 34 | class_to_idx[c] = 0 35 | 36 | data = [] 37 | for c in all_classes: 38 | d = os.path.join(folder, c) 39 | target = class_to_idx[c] 40 | for f in os.listdir(d): 41 | path = os.path.join(d, f) 42 | data.append((path, target)) 43 | 44 | # add silence 45 | target = class_to_idx['silence'] 46 | data += [('', target)] * int(len(data) * silence_percentage) 47 | 48 | self.classes = classes 49 | self.data = data 50 | self.transform = transform 51 | 52 | def __len__(self): 53 | return len(self.data) 54 | 55 | def __getitem__(self, index): 56 | path, target = self.data[index] 57 | data = {'path': path, 'target': target} 58 | 59 | if self.transform is not None: 60 | data = self.transform(data) 61 | 62 | return data, target 63 | 64 | def make_weights_for_balanced_classes(self): 65 | """adopted from https://discuss.pytorch.org/t/balanced-sampling-between-classes-with-torchvision-dataloader/2703/3""" 66 | 67 | nclasses = len(self.classes) 68 | count = np.zeros(nclasses) 69 | for item in self.data: 70 | count[item[1]] += 1 71 | 72 | N = float(sum(count)) 73 | weight_per_class = N / count 74 | weight = np.zeros(len(self)) 75 | for idx, item in enumerate(self.data): 76 | weight[idx] = weight_per_class[item[1]] 77 | return weight 78 | 79 | 80 | class BackgroundNoiseDataset(Dataset): 81 | """Dataset for silence / background noise.""" 82 | 83 | def __init__(self, folder, transform=None, sample_rate=16000, sample_length=1): 84 | try: 85 | import librosa 86 | except: 87 | raise ModuleNotFoundError('Librosa package is reuiqred for google commands experiments. 
Try pip install') 88 | audio_files = [d for d in os.listdir(folder) if os.path.isfile(os.path.join(folder, d)) and d.endswith('.wav')] 89 | samples = [] 90 | for f in audio_files: 91 | path = os.path.join(folder, f) 92 | s, sr = librosa.load(path, sample_rate) 93 | samples.append(s) 94 | 95 | samples = np.hstack(samples) 96 | c = int(sample_rate * sample_length) 97 | r = len(samples) // c 98 | self.samples = samples[:r*c].reshape(-1, c) 99 | self.sample_rate = sample_rate 100 | self.classes = CLASSES 101 | self.transform = transform 102 | self.path = folder 103 | 104 | def __len__(self): 105 | return len(self.samples) 106 | 107 | def __getitem__(self, index): 108 | data = {'samples': self.samples[index], 'sample_rate': self.sample_rate, 'target': 1, 'path': self.path} 109 | 110 | if self.transform is not None: 111 | data = self.transform(data) 112 | 113 | return data 114 | -------------------------------------------------------------------------------- /datasets/google_commands/sft_transforms.py: -------------------------------------------------------------------------------- 1 | """Transforms on the short time fourier transforms of wav samples.""" 2 | 3 | __author__ = 'Erdene-Ochir Tuguldur' 4 | 5 | import random 6 | 7 | import numpy as np 8 | 9 | from torch.utils.data import Dataset 10 | 11 | from datasets.google_commands.transforms import should_apply_transform 12 | 13 | 14 | class ToSTFT(object): 15 | """Applies on an audio the short time fourier transform.""" 16 | 17 | def __init__(self, n_fft=2048, hop_length=512): 18 | self.n_fft = n_fft 19 | self.hop_length = hop_length 20 | 21 | def __call__(self, data): 22 | import librosa 23 | samples = data['samples'] 24 | sample_rate = data['sample_rate'] 25 | data['n_fft'] = self.n_fft 26 | data['hop_length'] = self.hop_length 27 | data['stft'] = librosa.stft(samples, n_fft=self.n_fft, hop_length=self.hop_length) 28 | data['stft_shape'] = data['stft'].shape 29 | return data 30 | 31 | 32 | class StretchAudioOnSTFT(object): 33 | """Stretches an audio on the frequency domain.""" 34 | 35 | def __init__(self, max_scale=0.2): 36 | self.max_scale = max_scale 37 | 38 | def __call__(self, data): 39 | import librosa 40 | 41 | if not should_apply_transform(): 42 | return data 43 | 44 | stft = data['stft'] 45 | sample_rate = data['sample_rate'] 46 | hop_length = data['hop_length'] 47 | scale = random.uniform(-self.max_scale, self.max_scale) 48 | stft_stretch = librosa.core.phase_vocoder(stft, 1 + scale, hop_length=hop_length) 49 | data['stft'] = stft_stretch 50 | return data 51 | 52 | 53 | class TimeshiftAudioOnSTFT(object): 54 | """A simple timeshift on the frequency domain without multiplying with exp.""" 55 | 56 | def __init__(self, max_shift=8): 57 | self.max_shift = max_shift 58 | 59 | def __call__(self, data): 60 | if not should_apply_transform(): 61 | return data 62 | 63 | stft = data['stft'] 64 | shift = random.randint(-self.max_shift, self.max_shift) 65 | a = -min(0, shift) 66 | b = max(0, shift) 67 | stft = np.pad(stft, ((0, 0), (a, b)), "constant") 68 | if a == 0: 69 | stft = stft[:, b:] 70 | else: 71 | stft = stft[:, 0:-a] 72 | data['stft'] = stft 73 | return data 74 | 75 | 76 | class AddBackgroundNoiseOnSTFT(Dataset): 77 | """Adds a random background noise on the frequency domain.""" 78 | 79 | def __init__(self, bg_dataset, max_percentage=0.45): 80 | self.bg_dataset = bg_dataset 81 | self.max_percentage = max_percentage 82 | 83 | def __call__(self, data): 84 | if not should_apply_transform(): 85 | return data 86 | 87 | noise = 
random.choice(self.bg_dataset)['stft'] 88 | percentage = random.uniform(0, self.max_percentage) 89 | data['stft'] = data['stft'] * (1 - percentage) + noise * percentage 90 | return data 91 | 92 | 93 | class FixSTFTDimension(object): 94 | """Either pads or truncates in the time axis on the frequency domain, applied after stretching, time shifting etc.""" 95 | 96 | def __call__(self, data): 97 | stft = data['stft'] 98 | t_len = stft.shape[1] 99 | orig_t_len = data['stft_shape'][1] 100 | if t_len > orig_t_len: 101 | stft = stft[:, 0:orig_t_len] 102 | elif t_len < orig_t_len: 103 | stft = np.pad(stft, ((0, 0), (0, orig_t_len - t_len)), "constant") 104 | 105 | data['stft'] = stft 106 | return data 107 | 108 | 109 | class ToMelSpectrogramFromSTFT(object): 110 | """Creates the mel spectrogram from the short time fourier transform of a file. The result is a 32x32 matrix.""" 111 | 112 | def __init__(self, n_mels=32): 113 | self.n_mels = n_mels 114 | 115 | def __call__(self, data): 116 | import librosa 117 | 118 | stft = data['stft'] 119 | sample_rate = data['sample_rate'] 120 | n_fft = data['n_fft'] 121 | mel_basis = librosa.filters.mel(sample_rate, n_fft, self.n_mels) 122 | s = np.dot(mel_basis, np.abs(stft) ** 2.0) 123 | data['mel_spectrogram'] = librosa.power_to_db(s, ref=np.max) 124 | return data 125 | 126 | 127 | class DeleteSTFT(object): 128 | """Pytorch doesn't like complex numbers, use this transform to remove STFT after computing the mel spectrogram.""" 129 | 130 | def __call__(self, data): 131 | del data['stft'] 132 | return data 133 | 134 | 135 | class AudioFromSTFT(object): 136 | """Inverse short time fourier transform.""" 137 | 138 | def __call__(self, data): 139 | import librosa 140 | 141 | stft = data['stft'] 142 | data['istft_samples'] = librosa.core.istft(stft, dtype=data['samples'].dtype) 143 | return data 144 | -------------------------------------------------------------------------------- /datasets/google_commands/transforms.py: -------------------------------------------------------------------------------- 1 | """Transforms on raw wav samples.""" 2 | 3 | __author__ = 'Yuan Xu' 4 | 5 | import random 6 | import numpy as np 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | 11 | 12 | def should_apply_transform(prob=0.5): 13 | """Transforms are only randomly applied with the given probability.""" 14 | return random.random() < prob 15 | 16 | 17 | class LoadAudio(object): 18 | """Loads an audio into a numpy array.""" 19 | 20 | def __init__(self, sample_rate=16000): 21 | self.sample_rate = sample_rate 22 | 23 | def __call__(self, data): 24 | import librosa 25 | 26 | path = data['path'] 27 | if path: 28 | samples, sample_rate = librosa.load(path, self.sample_rate) 29 | else: 30 | # silence 31 | sample_rate = self.sample_rate 32 | samples = np.zeros(sample_rate, dtype=np.float32) 33 | data['samples'] = samples 34 | data['sample_rate'] = sample_rate 35 | return data 36 | 37 | 38 | class FixAudioLength(object): 39 | """Either pads or truncates an audio into a fixed length.""" 40 | 41 | def __init__(self, time=1): 42 | self.time = time 43 | 44 | def __call__(self, data): 45 | samples = data['samples'] 46 | sample_rate = data['sample_rate'] 47 | length = int(self.time * sample_rate) 48 | if length < len(samples): 49 | data['samples'] = samples[:length] 50 | elif length > len(samples): 51 | data['samples'] = np.pad(samples, (0, length - len(samples)), "constant") 52 | return data 53 | 54 | 55 | class ChangeAmplitude(object): 56 | """Changes amplitude of an audio 
randomly.""" 57 | 58 | def __init__(self, amplitude_range=(0.7, 1.1)): 59 | self.amplitude_range = amplitude_range 60 | 61 | def __call__(self, data): 62 | if not should_apply_transform(): 63 | return data 64 | 65 | data['samples'] = data['samples'] * random.uniform(*self.amplitude_range) 66 | return data 67 | 68 | 69 | class ChangeSpeedAndPitchAudio(object): 70 | """Change the speed of an audio. This transform also changes the pitch of the audio.""" 71 | 72 | def __init__(self, max_scale=0.2): 73 | self.max_scale = max_scale 74 | 75 | def __call__(self, data): 76 | if not should_apply_transform(): 77 | return data 78 | 79 | samples = data['samples'] 80 | sample_rate = data['sample_rate'] 81 | scale = random.uniform(-self.max_scale, self.max_scale) 82 | speed_fac = 1.0 / (1 + scale) 83 | data['samples'] = np.interp(np.arange(0, len(samples), speed_fac), np.arange(0,len(samples)), samples).astype(np.float32) 84 | return data 85 | 86 | 87 | class StretchAudio(object): 88 | """Stretches an audio randomly.""" 89 | 90 | def __init__(self, max_scale=0.2): 91 | self.max_scale = max_scale 92 | 93 | def __call__(self, data): 94 | import librosa 95 | 96 | if not should_apply_transform(): 97 | return data 98 | 99 | scale = random.uniform(-self.max_scale, self.max_scale) 100 | data['samples'] = librosa.effects.time_stretch(data['samples'], 1+scale) 101 | return data 102 | 103 | 104 | class TimeshiftAudio(object): 105 | """Shifts an audio randomly.""" 106 | 107 | def __init__(self, max_shift_seconds=0.2): 108 | self.max_shift_seconds = max_shift_seconds 109 | 110 | def __call__(self, data): 111 | if not should_apply_transform(): 112 | return data 113 | 114 | samples = data['samples'] 115 | sample_rate = data['sample_rate'] 116 | max_shift = (sample_rate * self.max_shift_seconds) 117 | shift = random.randint(-max_shift, max_shift) 118 | a = -min(0, shift) 119 | b = max(0, shift) 120 | samples = np.pad(samples, (a, b), "constant") 121 | data['samples'] = samples[:len(samples) - a] if a else samples[b:] 122 | return data 123 | 124 | 125 | class AddBackgroundNoise(Dataset): 126 | """Adds a random background noise.""" 127 | 128 | def __init__(self, bg_dataset, max_percentage=0.45): 129 | self.bg_dataset = bg_dataset 130 | self.max_percentage = max_percentage 131 | 132 | def __call__(self, data): 133 | if not should_apply_transform(): 134 | return data 135 | 136 | samples = data['samples'] 137 | noise = random.choice(self.bg_dataset)['samples'] 138 | percentage = random.uniform(0, self.max_percentage) 139 | data['samples'] = samples * (1 - percentage) + noise * percentage 140 | return data 141 | 142 | 143 | class ToMelSpectrogram(object): 144 | """Creates the mel spectrogram from an audio. 
The result is a 32x32 matrix.""" 145 | 146 | def __init__(self, n_mels=32): 147 | self.n_mels = n_mels 148 | 149 | def __call__(self, data): 150 | import librosa 151 | 152 | samples = data['samples'] 153 | sample_rate = data['sample_rate'] 154 | s = librosa.feature.melspectrogram(samples, sr=sample_rate, n_mels=self.n_mels) 155 | data['mel_spectrogram'] = librosa.power_to_db(s, ref=np.max) 156 | return data 157 | 158 | 159 | class ToTensor(object): 160 | """Converts into a tensor.""" 161 | 162 | def __init__(self, np_name, tensor_name, normalize=None): 163 | self.np_name = np_name 164 | self.tensor_name = tensor_name 165 | self.normalize = normalize 166 | 167 | def __call__(self, data): 168 | tensor = torch.FloatTensor(data[self.np_name]) 169 | if self.normalize is not None: 170 | mean, std = self.normalize 171 | tensor -= mean 172 | tensor /= std 173 | data[self.tensor_name] = tensor 174 | return (data['input'].unsqueeze(0)+40)/80 -------------------------------------------------------------------------------- /datasets/imagenet_a.py: -------------------------------------------------------------------------------- 1 | thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1, 2 | 13: 2, 14: -1, 15: 3, 16: -1, 17: 4, 18: -1, 19: -1, 20: -1, 21: -1, 22: 5, 23: 6, 24: -1, 3 | 25: -1, 26: -1, 27: 7, 28: -1, 29: -1, 30: 8, 31: -1, 32: -1, 33: -1, 34: -1, 35: -1, 36: -1, 4 | 37: 9, 38: -1, 39: 10, 40: -1, 41: -1, 42: 11, 43: -1, 44: -1, 45: -1, 46: -1, 47: 12, 48: -1, 5 | 49: -1, 50: 13, 51: -1, 52: -1, 53: -1, 54: -1, 55: -1, 56: -1, 57: 14, 58: -1, 59: -1, 60: -1, 6 | 61: -1, 62: -1, 63: -1, 64: -1, 65: -1, 66: -1, 67: -1, 68: -1, 69: -1, 70: 15, 71: 16, 72: -1, 7 | 73: -1, 74: -1, 75: -1, 76: 17, 77: -1, 78: -1, 79: 18, 80: -1, 81: -1, 82: -1, 83: -1, 84: -1, 8 | 85: -1, 86: -1, 87: -1, 88: -1, 89: 19, 90: 20, 91: -1, 92: -1, 93: -1, 94: 21, 95: -1, 96: 22, 9 | 97: 23, 98: -1, 99: 24, 100: -1, 101: -1, 102: -1, 103: -1, 104: -1, 105: 25, 106: -1, 107: 26, 10 | 108: 27, 109: -1, 110: 28, 111: -1, 112: -1, 113: 29, 114: -1, 115: -1, 116: -1, 117: -1, 11 | 118: -1, 119: -1, 120: -1, 121: -1, 122: -1, 123: -1, 124: 30, 125: 31, 126: -1, 127: -1, 12 | 128: -1, 129: -1, 130: 32, 131: -1, 132: 33, 133: -1, 134: -1, 135: -1, 136: -1, 137: -1, 13 | 138: -1, 139: -1, 140: -1, 141: -1, 142: -1, 143: 34, 144: 35, 145: -1, 146: -1, 147: -1, 14 | 148: -1, 149: -1, 150: 36, 151: 37, 152: -1, 153: -1, 154: -1, 155: -1, 156: -1, 157: -1, 15 | 158: -1, 159: -1, 160: -1, 161: -1, 162: -1, 163: -1, 164: -1, 165: -1, 166: -1, 167: -1, 16 | 168: -1, 169: -1, 170: -1, 171: -1, 172: -1, 173: -1, 174: -1, 175: -1, 176: -1, 177: -1, 17 | 178: -1, 179: -1, 180: -1, 181: -1, 182: -1, 183: -1, 184: -1, 185: -1, 186: -1, 187: -1, 18 | 188: -1, 189: -1, 190: -1, 191: -1, 192: -1, 193: -1, 194: -1, 195: -1, 196: -1, 197: -1, 19 | 198: -1, 199: -1, 200: -1, 201: -1, 202: -1, 203: -1, 204: -1, 205: -1, 206: -1, 207: 38, 20 | 208: -1, 209: -1, 210: -1, 211: -1, 212: -1, 213: -1, 214: -1, 215: -1, 216: -1, 217: -1, 21 | 218: -1, 219: -1, 220: -1, 221: -1, 222: -1, 223: -1, 224: -1, 225: -1, 226: -1, 227: -1, 22 | 228: -1, 229: -1, 230: -1, 231: -1, 232: -1, 233: -1, 234: 39, 235: 40, 236: -1, 237: -1, 23 | 238: -1, 239: -1, 240: -1, 241: -1, 242: -1, 243: -1, 244: -1, 245: -1, 246: -1, 247: -1, 24 | 248: -1, 249: -1, 250: -1, 251: -1, 252: -1, 253: -1, 254: 41, 255: -1, 256: -1, 257: -1, 25 | 258: -1, 259: -1, 260: -1, 261: -1, 262: -1, 263: -1, 264: -1, 265: -1, 266: -1, 
267: -1, 26 | 268: -1, 269: -1, 270: -1, 271: -1, 272: -1, 273: -1, 274: -1, 275: -1, 276: -1, 277: 42, 27 | 278: -1, 279: -1, 280: -1, 281: -1, 282: -1, 283: 43, 284: -1, 285: -1, 286: -1, 287: 44, 28 | 288: -1, 289: -1, 290: -1, 291: 45, 292: -1, 293: -1, 294: -1, 295: 46, 296: -1, 297: -1, 29 | 298: 47, 299: -1, 300: -1, 301: 48, 302: -1, 303: -1, 304: -1, 305: -1, 306: 49, 307: 50, 30 | 308: 51, 309: 52, 310: 53, 311: 54, 312: -1, 313: 55, 314: 56, 315: 57, 316: -1, 317: 58, 31 | 318: -1, 319: 59, 320: -1, 321: -1, 322: -1, 323: 60, 324: 61, 325: -1, 326: 62, 327: 63, 32 | 328: -1, 329: -1, 330: 64, 331: -1, 332: -1, 333: -1, 334: 65, 335: 66, 336: 67, 337: -1, 33 | 338: -1, 339: -1, 340: -1, 341: -1, 342: -1, 343: -1, 344: -1, 345: -1, 346: -1, 347: 68, 34 | 348: -1, 349: -1, 350: -1, 351: -1, 352: -1, 353: -1, 354: -1, 355: -1, 356: -1, 357: -1, 35 | 358: -1, 359: -1, 360: -1, 361: 69, 362: -1, 363: 70, 364: -1, 365: -1, 366: -1, 367: -1, 36 | 368: -1, 369: -1, 370: -1, 371: -1, 372: 71, 373: -1, 374: -1, 375: -1, 376: -1, 377: -1, 37 | 378: 72, 379: -1, 380: -1, 381: -1, 382: -1, 383: -1, 384: -1, 385: -1, 386: 73, 387: -1, 38 | 388: -1, 389: -1, 390: -1, 391: -1, 392: -1, 393: -1, 394: -1, 395: -1, 396: -1, 397: 74, 39 | 398: -1, 399: -1, 400: 75, 401: 76, 402: 77, 403: -1, 404: 78, 405: -1, 406: -1, 407: 79, 40 | 408: -1, 409: -1, 410: -1, 411: 80, 412: -1, 413: -1, 414: -1, 415: -1, 416: 81, 417: 82, 41 | 418: -1, 419: -1, 420: 83, 421: -1, 422: -1, 423: -1, 424: -1, 425: 84, 426: -1, 427: -1, 42 | 428: 85, 429: -1, 430: 86, 431: -1, 432: -1, 433: -1, 434: -1, 435: -1, 436: -1, 437: 87, 43 | 438: 88, 439: -1, 440: -1, 441: -1, 442: -1, 443: -1, 444: -1, 445: 89, 446: -1, 447: -1, 44 | 448: -1, 449: -1, 450: -1, 451: -1, 452: -1, 453: -1, 454: -1, 455: -1, 456: 90, 457: 91, 45 | 458: -1, 459: -1, 460: -1, 461: 92, 462: 93, 463: -1, 464: -1, 465: -1, 466: -1, 467: -1, 46 | 468: -1, 469: -1, 470: 94, 471: -1, 472: 95, 473: -1, 474: -1, 475: -1, 476: -1, 477: -1, 47 | 478: -1, 479: -1, 480: -1, 481: -1, 482: -1, 483: 96, 484: -1, 485: -1, 486: 97, 487: -1, 48 | 488: 98, 489: -1, 490: -1, 491: -1, 492: 99, 493: -1, 494: -1, 495: -1, 496: 100, 497: -1, 49 | 498: -1, 499: -1, 500: -1, 501: -1, 502: -1, 503: -1, 504: -1, 505: -1, 506: -1, 507: -1, 50 | 508: -1, 509: -1, 510: -1, 511: -1, 512: -1, 513: -1, 514: 101, 515: -1, 516: 102, 517: -1, 51 | 518: -1, 519: -1, 520: -1, 521: -1, 522: -1, 523: -1, 524: -1, 525: -1, 526: -1, 527: -1, 52 | 528: 103, 529: -1, 530: 104, 531: -1, 532: -1, 533: -1, 534: -1, 535: -1, 536: -1, 537: -1, 53 | 538: -1, 539: 105, 540: -1, 541: -1, 542: 106, 543: 107, 544: -1, 545: -1, 546: -1, 547: -1, 54 | 548: -1, 549: 108, 550: -1, 551: -1, 552: 109, 553: -1, 554: -1, 555: -1, 556: -1, 557: 110, 55 | 558: -1, 559: -1, 560: -1, 561: 111, 562: 112, 563: -1, 564: -1, 565: -1, 566: -1, 567: -1, 56 | 568: -1, 569: 113, 570: -1, 571: -1, 572: 114, 573: 115, 574: -1, 575: 116, 576: -1, 577: -1, 57 | 578: -1, 579: 117, 580: -1, 581: -1, 582: -1, 583: -1, 584: -1, 585: -1, 586: -1, 587: -1, 58 | 588: -1, 589: 118, 590: -1, 591: -1, 592: -1, 593: -1, 594: -1, 595: -1, 596: -1, 597: -1, 59 | 598: -1, 599: -1, 600: -1, 601: -1, 602: -1, 603: -1, 604: -1, 605: -1, 606: 119, 607: 120, 60 | 608: -1, 609: 121, 610: -1, 611: -1, 612: -1, 613: -1, 614: 122, 615: -1, 616: -1, 617: -1, 61 | 618: -1, 619: -1, 620: -1, 621: -1, 622: -1, 623: -1, 624: -1, 625: -1, 626: 123, 627: 124, 62 | 628: -1, 629: -1, 630: -1, 631: -1, 632: -1, 633: -1, 634: -1, 635: -1, 636: -1, 637: -1, 63 | 
638: -1, 639: -1, 640: 125, 641: 126, 642: 127, 643: 128, 644: -1, 645: -1, 646: -1, 647: -1, 64 | 648: -1, 649: -1, 650: -1, 651: -1, 652: -1, 653: -1, 654: -1, 655: -1, 656: -1, 657: -1, 65 | 658: 129, 659: -1, 660: -1, 661: -1, 662: -1, 663: -1, 664: -1, 665: -1, 666: -1, 667: -1, 66 | 668: 130, 669: -1, 670: -1, 671: -1, 672: -1, 673: -1, 674: -1, 675: -1, 676: -1, 677: 131, 67 | 678: -1, 679: -1, 680: -1, 681: -1, 682: 132, 683: -1, 684: 133, 685: -1, 686: -1, 687: 134, 68 | 688: -1, 689: -1, 690: -1, 691: -1, 692: -1, 693: -1, 694: -1, 695: -1, 696: -1, 697: -1, 69 | 698: -1, 699: -1, 700: -1, 701: 135, 702: -1, 703: -1, 704: 136, 705: -1, 706: -1, 707: -1, 70 | 708: -1, 709: -1, 710: -1, 711: -1, 712: -1, 713: -1, 714: -1, 715: -1, 716: -1, 717: -1, 71 | 718: -1, 719: 137, 720: -1, 721: -1, 722: -1, 723: -1, 724: -1, 725: -1, 726: -1, 727: -1, 72 | 728: -1, 729: -1, 730: -1, 731: -1, 732: -1, 733: -1, 734: -1, 735: -1, 736: 138, 737: -1, 73 | 738: -1, 739: -1, 740: -1, 741: -1, 742: -1, 743: -1, 744: -1, 745: -1, 746: 139, 747: -1, 74 | 748: -1, 749: 140, 750: -1, 751: -1, 752: 141, 753: -1, 754: -1, 755: -1, 756: -1, 757: -1, 75 | 758: 142, 759: -1, 760: -1, 761: -1, 762: -1, 763: 143, 764: -1, 765: 144, 766: -1, 767: -1, 76 | 768: 145, 769: -1, 770: -1, 771: -1, 772: -1, 773: 146, 774: 147, 775: -1, 776: 148, 777: -1, 77 | 778: -1, 779: 149, 780: 150, 781: -1, 782: -1, 783: -1, 784: -1, 785: -1, 786: 151, 787: -1, 78 | 788: -1, 789: -1, 790: -1, 791: -1, 792: 152, 793: -1, 794: -1, 795: -1, 796: -1, 797: 153, 79 | 798: -1, 799: -1, 800: -1, 801: -1, 802: 154, 803: 155, 804: 156, 805: -1, 806: -1, 807: -1, 80 | 808: -1, 809: -1, 810: -1, 811: -1, 812: -1, 813: 157, 814: -1, 815: 158, 816: -1, 817: -1, 81 | 818: -1, 819: -1, 820: 159, 821: -1, 822: -1, 823: 160, 824: -1, 825: -1, 826: -1, 827: -1, 82 | 828: -1, 829: -1, 830: -1, 831: 161, 832: -1, 833: 162, 834: -1, 835: 163, 836: -1, 837: -1, 83 | 838: -1, 839: 164, 840: -1, 841: -1, 842: -1, 843: -1, 844: -1, 845: 165, 846: -1, 847: 166, 84 | 848: -1, 849: -1, 850: 167, 851: -1, 852: -1, 853: -1, 854: -1, 855: -1, 856: -1, 857: -1, 85 | 858: -1, 859: 168, 860: -1, 861: -1, 862: 169, 863: -1, 864: -1, 865: -1, 866: -1, 867: -1, 86 | 868: -1, 869: -1, 870: 170, 871: -1, 872: -1, 873: -1, 874: -1, 875: -1, 876: -1, 877: -1, 87 | 878: -1, 879: 171, 880: 172, 881: -1, 882: -1, 883: -1, 884: -1, 885: -1, 886: -1, 887: -1, 88 | 888: 173, 889: -1, 890: 174, 891: -1, 892: -1, 893: -1, 894: -1, 895: -1, 896: -1, 897: 175, 89 | 898: -1, 899: -1, 900: 176, 901: -1, 902: -1, 903: -1, 904: -1, 905: -1, 906: -1, 907: 177, 90 | 908: -1, 909: -1, 910: -1, 911: -1, 912: -1, 913: 178, 914: -1, 915: -1, 916: -1, 917: -1, 91 | 918: -1, 919: -1, 920: -1, 921: -1, 922: -1, 923: -1, 924: 179, 925: -1, 926: -1, 927: -1, 92 | 928: -1, 929: -1, 930: -1, 931: -1, 932: 180, 933: 181, 934: 182, 935: -1, 936: -1, 937: 183, 93 | 938: -1, 939: -1, 940: -1, 941: -1, 942: -1, 943: 184, 944: -1, 945: 185, 946: -1, 947: 186, 94 | 948: -1, 949: -1, 950: -1, 951: 187, 952: -1, 953: -1, 954: 188, 955: -1, 956: 189, 957: 190, 95 | 958: -1, 959: 191, 960: -1, 961: -1, 962: -1, 963: -1, 964: -1, 965: -1, 966: -1, 967: -1, 96 | 968: -1, 969: -1, 970: -1, 971: 192, 972: 193, 973: -1, 974: -1, 975: -1, 976: -1, 977: -1, 97 | 978: -1, 979: -1, 980: 194, 981: 195, 982: -1, 983: -1, 984: 196, 985: -1, 986: 197, 987: 198, 98 | 988: 199, 989: -1, 990: -1, 991: -1, 992: -1, 993: -1, 994: -1, 995: -1, 996: -1, 997: -1, 99 | 998: -1, 999: -1} 100 | 101 | indices_in_1k = [k for k in 
thousand_k_to_200 if thousand_k_to_200[k] != -1] 102 | -------------------------------------------------------------------------------- /datasets/imagenet_hdf5.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import pickle 4 | 5 | import h5py 6 | from PIL import Image 7 | from torchvision.datasets import VisionDataset 8 | 9 | 10 | class ImageNetHDF5(VisionDataset): 11 | def __init__(self, root, cache_size=500, transform=None): 12 | super(ImageNetHDF5, self).__init__(root, transform=transform, target_transform=None) 13 | 14 | self.dest = pickle.load(open(os.path.join(root, 'dest.p'), 'rb')) 15 | self.cache = {} 16 | self.cache_size = cache_size 17 | 18 | targets = sorted(list(filter(lambda f: '.hdf5' in f, os.listdir(root)))) 19 | self.targets = {f[:-5]: i for i, f in enumerate(targets)} 20 | self.fill_cache() 21 | 22 | def load(self, file, i): 23 | with h5py.File(os.path.join(self.root, file + '.hdf5'), 'r') as f: 24 | return f['data'][i] 25 | 26 | def fill_cache(self): 27 | print('Filling cache') 28 | files = (f[:-5] for f in list(filter(lambda f: '.hdf5' in f, os.listdir(self.root)))[:self.cache_size]) 29 | for file in files: 30 | with h5py.File(os.path.join(self.root, file + '.hdf5'), 'r') as f: 31 | self.cache[file] = list(f['data']) 32 | print('Done') 33 | 34 | def load_from_cache(self, file, i): 35 | if file in self.cache: 36 | return self.cache[file][i] 37 | return self.load(file, i) 38 | 39 | def __getitem__(self, index): 40 | dest, i = self.dest[index] 41 | 42 | sample = self.load_from_cache(dest, i) 43 | 44 | sample = Image.open(io.BytesIO(sample)) 45 | sample = sample.convert('RGB') 46 | 47 | if self.transform is not None: 48 | sample = self.transform(sample) 49 | 50 | return sample, self.targets[dest] 51 | 52 | def __len__(self): 53 | return len(self.dest) 54 | -------------------------------------------------------------------------------- /datasets/tiny_imagenet.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | from torchvision.datasets.folder import default_loader 4 | 5 | 6 | class TinyImageNet(Dataset): 7 | def __init__(self, root, train=True, transform=None): 8 | super().__init__() 9 | self.root = root 10 | self.transform = transform 11 | self.words = self.parse_classes() 12 | 13 | if train: 14 | self.class_path = os.path.join(root, 'train') 15 | self.img_labels = self.parse_train() 16 | else: 17 | self.class_path = os.path.join(root, 'val') 18 | self.img_labels = self.parse_val_labels() 19 | 20 | def parse_classes(self): 21 | words_path = os.path.join(self.root, 'wnids.txt') 22 | words = {} 23 | i = 0 24 | with open(words_path, 'r') as f: 25 | for w in f: 26 | w = w.strip('\n') 27 | word_label = w.split('\t')[0] 28 | words[word_label] = i 29 | i += 1 30 | return words 31 | 32 | def parse_val_labels(self): 33 | val_annot = os.path.join(self.root, 'val', 'val_annotations.txt') 34 | img_label = [] 35 | with open(val_annot, 'r') as f: 36 | for line in f: 37 | line.strip('\n') 38 | img, word, *_ = line.split('\t') 39 | img = os.path.join(self.root, 'val', 'images', img) 40 | img_label.append((img, self.words[word])) 41 | return img_label 42 | 43 | def parse_train(self): 44 | img_labels = [] 45 | for c in os.listdir(self.class_path): 46 | label = self.words[c] 47 | images_path = os.path.join(self.root, 'train', c, 'images') 48 | for im in os.listdir(images_path): 49 | im_path = os.path.join(images_path, im) 50 | 
img_labels.append((im_path, label)) 51 | return img_labels 52 | 53 | def __getitem__(self, index): 54 | img, label = self.img_labels[index] 55 | pil_img = default_loader(img) 56 | 57 | if self.transform is not None: 58 | pil_img = self.transform(pil_img) 59 | 60 | return pil_img, label 61 | 62 | def __len__(self) -> int: 63 | return len(self.img_labels) 64 | -------------------------------------------------------------------------------- /datasets/toxic.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import re 3 | 4 | import pandas as pd 5 | import os 6 | import torch 7 | 8 | from torchbearer import Callback 9 | import torchbearer 10 | 11 | 12 | def toxic_ds(args): 13 | from torchtext import data 14 | import torchtext 15 | tok = spacy.load('en') 16 | stopwords = spacy.lang.en.stop_words.STOP_WORDS 17 | stopwords.update(['wikipedia', 'article', 'articles', 'im', 'page']) 18 | 19 | def spacy_tok(x): 20 | x = re.sub(r'[^a-zA-Z\s]', '', x) 21 | x = re.sub(r'[\n]', ' ', x) 22 | return [t.text for t in tok.tokenizer(x)] 23 | 24 | TEXT = data.Field(lower=True, tokenize=spacy_tok, eos_token='EOS', stop_words=stopwords, include_lengths=True) 25 | LABEL = data.Field(sequential=False, use_vocab=False, pad_token=None, unk_token=None) 26 | dataFields = [("id", None), ("comment_text", TEXT), ("toxic", LABEL), ("severe_toxic", LABEL), ("threat", LABEL), 27 | ("obscene", LABEL), ("insult", LABEL), ("identity_hate", LABEL)] 28 | trainset_path = os.path.join(args.dataset_path, 'train.csv') 29 | train = data.TabularDataset(path=trainset_path, format='csv', fields=dataFields, skip_header=True) 30 | 31 | TEXT.build_vocab(train, vectors='fasttext.simple.300d') 32 | traindl = torchtext.data.BucketIterator(dataset=train, batch_size=args.batch_size, 33 | sort_key=lambda x: len(x.comment_text), device=torch.device(args.device), 34 | sort_within_batch=True) 35 | 36 | test_csv_path = os.path.join(args.dataset_path, 'test.csv') 37 | test_labels_path = os.path.join(args.dataset_path, 'test_labels.csv') 38 | test_set_path = os.path.join(args.dataset_path, 'test_set.csv') 39 | if not os.path.isfile(test_set_path): 40 | a = pd.read_csv(test_csv_path) 41 | b = pd.read_csv(test_labels_path) 42 | b = b.dropna(axis=1) 43 | merged = a.merge(b, on='id') 44 | merged = merged[merged['toxic'] >= 0] 45 | merged.to_csv(test_set_path, index=False) 46 | 47 | testset = data.TabularDataset(path=test_set_path, format='csv', fields=dataFields, skip_header=True) 48 | testdl = torchtext.data.BucketIterator(dataset=testset, batch_size=64, sort_key=lambda x: len(x.comment_text), 49 | device=torch.device(args.device), sort_within_batch=True) 50 | vectors = train.fields['comment_text'].vocab.vectors.to(args.device) 51 | 52 | traindl, testdl = BatchGenerator(traindl), BatchGenerator(testdl) 53 | traindl.vectors = vectors 54 | traindl.ntokens = len(TEXT.vocab) 55 | 56 | return traindl, None, testdl 57 | 58 | 59 | class BatchGenerator: 60 | def __init__(self, dl): 61 | self.dl = dl 62 | self.yFields = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] 63 | self.x = 'comment_text' 64 | 65 | def __len__(self): 66 | return len(self.dl) 67 | 68 | def __iter__(self): 69 | for batch in self.dl: 70 | X = list(getattr(batch, self.x)) 71 | X = X[0].permute(1, 0) 72 | y = torch.transpose(torch.stack([getattr(batch, y) for y in self.yFields]), 0, 1) 73 | yield (X, y) 74 | 75 | 76 | class ToxicHelper(Callback): 77 | def __init__(self, to_float=True): 78 | self.convert = (lambda 
x: x.float()) if to_float else (lambda x: x) 79 | 80 | def on_start(self, state): 81 | super().on_start(state) 82 | vectors = state[torchbearer.TRAIN_GENERATOR].vectors 83 | ntokens = state[torchbearer.TRAIN_GENERATOR].ntokens 84 | state[torchbearer.MODEL].init_embedding(vectors, ntokens, state[torchbearer.DEVICE]) 85 | 86 | def on_sample(self, state): 87 | state[torchbearer.Y_TRUE] = self.convert(state[torchbearer.Y_TRUE]) 88 | 89 | state[torchbearer.X] = state[torchbearer.MODEL].embed(state[torchbearer.X]) 90 | 91 | def on_sample_validation(self, state): 92 | state[torchbearer.Y_TRUE] = self.convert(state[torchbearer.Y_TRUE]) 93 | 94 | state[torchbearer.X] = state[torchbearer.MODEL].embed(state[torchbearer.X]) 95 | -------------------------------------------------------------------------------- /datasets/toxic_bert.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | from torch.utils.data import Dataset 7 | from torchtext.data import Iterator, batch, pool 8 | from tqdm import tqdm 9 | 10 | 11 | class ToxicDataset(Dataset): 12 | def __init__(self, dataframe, bert_model='bert-base-cased'): 13 | from transformers import BertTokenizer 14 | self.tokenizer = BertTokenizer.from_pretrained(bert_model) 15 | self.pad_idx = self.tokenizer.pad_token_id 16 | 17 | self.X = [] 18 | self.Y = [] 19 | for i, (row) in tqdm(dataframe.iterrows()): 20 | x, y = self.row_to_tensor(self.tokenizer, row) 21 | self.X.append(x) 22 | self.Y.append(y) 23 | 24 | @staticmethod 25 | def row_to_tensor(tokenizer, row): 26 | tokens = tokenizer.encode(row["comment_text"], add_special_tokens=True, max_length=128) 27 | 28 | x = torch.LongTensor(tokens) 29 | y = torch.FloatTensor(row[["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]]) 30 | return x, y 31 | 32 | def __len__(self): 33 | return len(self.X) 34 | 35 | def __getitem__(self, index): 36 | return self.X[index], self.Y[index] 37 | 38 | 39 | class NoBatchBucketIterator(Iterator): 40 | """Defines an iterator that batches examples of similar lengths together. 41 | 42 | Minimizes amount of padding needed while producing freshly shuffled 43 | batches for each new epoch. See pool for the bucketing procedure used. 
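Example (an illustrative sketch only; the dataset, batch size and sort key below are placeholders that mirror how `toxic_bert` builds its loaders further down this file):

    loader = NoBatchBucketIterator(dataset=train_dataset, batch_size=32,
                                   sort_key=lambda ex: ex[0].size(0),
                                   sort_within_batch=True)
    x, y = next(iter(loader))  # x: [batch, max_len] padded token ids, y: [batch, 6] float labels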
44 | """ 45 | 46 | def create_batches(self): 47 | if self.sort: 48 | self.batches = batch(self.data(), self.batch_size, 49 | self.batch_size_fn) 50 | else: 51 | self.batches = pool(self.data(), self.batch_size, 52 | self.sort_key, self.batch_size_fn, 53 | random_shuffler=self.random_shuffler, 54 | shuffle=self.shuffle, 55 | sort_within_batch=self.sort_within_batch) 56 | 57 | def __iter__(self): 58 | while True: 59 | self.init_epoch() 60 | for idx, minibatch in enumerate(self.batches): 61 | # fast-forward if loaded from state 62 | if self._iterations_this_epoch > idx: 63 | continue 64 | self.iterations += 1 65 | self._iterations_this_epoch += 1 66 | if self.sort_within_batch: 67 | # NOTE: `rnn.pack_padded_sequence` requires that a minibatch 68 | # be sorted by decreasing order, which requires reversing 69 | # relative to typical sort keys 70 | if self.sort: 71 | minibatch.reverse() 72 | else: 73 | minibatch.sort(key=self.sort_key, reverse=True) 74 | 75 | x, y = list(zip(*minibatch)) 76 | x = pad_sequence(x, batch_first=True, padding_value=0) 77 | y = torch.stack(y) 78 | yield x, y 79 | if not self.repeat: 80 | return 81 | 82 | 83 | def toxic_bert(args): 84 | trainset_path = os.path.join(args.dataset_path, 'train.csv') 85 | test_csv_path = os.path.join(args.dataset_path, 'test.csv') 86 | test_labels_path = os.path.join(args.dataset_path, 'test_labels.csv') 87 | testset_path = os.path.join(args.dataset_path, 'test_set.csv') 88 | if not os.path.isfile(testset_path): 89 | a = pd.read_csv(test_csv_path) 90 | b = pd.read_csv(test_labels_path) 91 | b = b.dropna(axis=1) 92 | merged = a.merge(b, on='id') 93 | merged = merged[merged['toxic'] >= 0] 94 | merged.to_csv(testset_path, index=False) 95 | 96 | train_df = pd.read_csv(trainset_path) 97 | test_df = pd.read_csv(testset_path) 98 | 99 | train_dataset = ToxicDataset(train_df) 100 | test_dataset = ToxicDataset(test_df) 101 | 102 | train_loader = NoBatchBucketIterator(dataset=train_dataset, batch_size=args.batch_size, 103 | sort_key=lambda x: x[0].size(0), 104 | device=torch.device(args.device), sort_within_batch=True) 105 | test_loader = NoBatchBucketIterator(dataset=test_dataset, batch_size=args.batch_size, 106 | sort_key=lambda x: x[0].size(0), 107 | device=torch.device(args.device), sort_within_batch=True) 108 | 109 | return train_loader, None, test_loader 110 | -------------------------------------------------------------------------------- /experiments/bengali_experiment.sh: -------------------------------------------------------------------------------- 1 | 2 | # Usage: `bash imagenet_experiment (msda_mode) (type) (dataset path)` 3 | # where msda_mode is one of [fmix, mixup, None] 4 | # type is one of [r, c, v] 5 | # For multiple GPU, add --parallel=True 6 | 7 | python ../trainer.py --dataset bengali_${2} --fold 0 --model se_resnext50_32x4d --epoch 100 --schedule 50 75 --batch-size 512 --lr=0.1 --dataset-path=$3 --msda-mode=$1 8 | -------------------------------------------------------------------------------- /experiments/cifar_experiment.sh: -------------------------------------------------------------------------------- 1 | 2 | # Usage: `bash imagenet_experiment (cifar) (model) (msda_mode) (dataset path)` 3 | # where: 4 | # cifar is one of [cifar10, cifar100] 5 | # model is one of [resnet, wrn, densenet, pyramidnet] 6 | # msda_mode is one of [fmix, mixup, None] 7 | # For multiple GPU, add --parallel=True 8 | 9 | if [ "$1" == "cifar10" ] 10 | then 11 | ds=cifar10 12 | fi 13 | if [ "$1" == "cifar100" ] 14 | then 15 | ds=cifar10 16 | fi 17 | 
18 | if [ "$2" == "resnet" ] 19 | then 20 | model=ResNet18 21 | epoch=200 22 | schedule=(100 150) 23 | bs=128 24 | cosine=False 25 | fi 26 | 27 | if [ "$2" == "wrn" ] 28 | then 29 | model=wrn 30 | epoch=200 31 | schedule=(100 150) 32 | bs=128 33 | cosine=False 34 | fi 35 | 36 | if [ "$2" == "densenet" ] 37 | then 38 | model=DenseNet190 39 | epoch=300 40 | schedule=(100 150 225) 41 | bs=32 42 | cosine=False 43 | fi 44 | 45 | if [ "$2" == "pyramidnet" ] 46 | then 47 | model=aa_PyramidNet 48 | epoch=1800 49 | schedule=2000 50 | bs=64 51 | cosine=True 52 | fi 53 | 54 | python ../trainer.py --dataset=$ds --model=$model --epoch=$epoch --schedule ${schedule[@]} --lr=0.1 --dataset-path=$4 --msda-mode=$3 --batch-size=$bs --cosine-scheduler=$cosine 55 | -------------------------------------------------------------------------------- /experiments/fashion_experiments.sh: -------------------------------------------------------------------------------- 1 | 2 | # Usage: `bash imagenet_experiment (model) (msda_mode) (dataset path)` 3 | # where: 4 | # model is one of [resnet, wrn, densenet] 5 | # msda_mode is one of [fmix, mixup, None] 6 | # For multiple GPU, add --parallel=True 7 | 8 | if [ "$1" == "resnet" ] 9 | then 10 | model=ResNet18 11 | epoch=200 12 | schedule=(100 150) 13 | bs=128 14 | fi 15 | 16 | if [ "$1" == "wrn" ] 17 | then 18 | model=wrn 19 | epoch=300 20 | schedule=(100 150 225) 21 | bs=32 22 | fi 23 | 24 | if [ "$1" == "densenet" ] 25 | then 26 | model=DenseNet190 27 | epoch=300 28 | schedule=(100 150 225) 29 | bs=32 30 | fi 31 | 32 | python ../trainer.py --dataset=fashion --model=$model --epoch=$epoch --schedule ${schedule[@]} --lr=0.1 --dataset-path=$3 --msda-mode=$2 --batch-size=$bs 33 | -------------------------------------------------------------------------------- /experiments/google_commands_experiment.sh: -------------------------------------------------------------------------------- 1 | 2 | # Usage: `bash imagenet_experiment (msda_mode) (dataset path) (alpha)` 3 | # where msda_mode is one of [fmix, mixup, None] 4 | # For multiple GPU, add --parallel=True 5 | 6 | python ../trainer.py --dataset=commands --epoch=90 --schedule 30 60 80 --lr=0.01 --dataset-path=$2 --msda-mode=$1 --alpha=$3 7 | -------------------------------------------------------------------------------- /experiments/imagenet_experiment.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Usage: `bash imagenet_experiment (msda_mode) (dataset path)` 4 | # where msda_mode is one of [fmix, mixup, None] 5 | # For multiple GPU, add --parallel=True 6 | 7 | python ../trainer.py --dataset=imagenet --epoch=90 --model=torch_resnet101 --schedule 30 60 80 --batch-size=256 --lr=0.4 --lr-warmup=True --dataset-path=$2 --msda-mode=$1 8 | -------------------------------------------------------------------------------- /experiments/modelnet_experiment.sh: -------------------------------------------------------------------------------- 1 | 2 | # Usage: `bash imagenet_experiment (msda_mode) (dataset path)` 3 | # where msda_mode is one of [fmix, mixup, None] 4 | # For multiple GPU, add --parallel=True 5 | 6 | python ../trainer.py --dataset=modelnet --epoch=50 --schedule 10 20 30 40 --lr=0.001 --dataset-path=$2 --msda-mode=$1 --batch-size=16 7 | -------------------------------------------------------------------------------- /experiments/tiny_imagenet_experiment.sh: -------------------------------------------------------------------------------- 1 | 2 | # Usage: `bash imagenet_experiment 
(msda_mode) (dataset path)` 3 | # where msda_mode is one of [fmix, mixup, None] 4 | # For multiple GPU, add --parallel=True 5 | 6 | python ../trainer.py --dataset=tinyimagenet --epoch=200 --schedule 150 180 --batch-size=128 --lr=0.1 --dataset-path=$2 --msda-mode=$1 7 | -------------------------------------------------------------------------------- /experiments/toxic_experiment.sh: -------------------------------------------------------------------------------- 1 | 2 | # Usage: `bash imagenet_experiment (msda_mode) (dataset path)` 3 | # where msda_mode is one of [fmix, mixup, None] 4 | # For multiple GPU, add --parallel=True 5 | 6 | python ../trainer.py --dataset=toxic --epoch=10 --batch-size=64 --lr=1e-4 --dataset-path=$2 --msda-mode=$1 7 | -------------------------------------------------------------------------------- /fmix.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import numpy as np 5 | from scipy.stats import beta 6 | 7 | 8 | def fftfreqnd(h, w=None, z=None): 9 | """ Get bin values for discrete fourier transform of size (h, w, z) 10 | 11 | :param h: Required, first dimension size 12 | :param w: Optional, second dimension size 13 | :param z: Optional, third dimension size 14 | """ 15 | fz = fx = 0 16 | fy = np.fft.fftfreq(h) 17 | 18 | if w is not None: 19 | fy = np.expand_dims(fy, -1) 20 | 21 | if w % 2 == 1: 22 | fx = np.fft.fftfreq(w)[: w // 2 + 2] 23 | else: 24 | fx = np.fft.fftfreq(w)[: w // 2 + 1] 25 | 26 | if z is not None: 27 | fy = np.expand_dims(fy, -1) 28 | if z % 2 == 1: 29 | fz = np.fft.fftfreq(z)[:, None] 30 | else: 31 | fz = np.fft.fftfreq(z)[:, None] 32 | 33 | return np.sqrt(fx * fx + fy * fy + fz * fz) 34 | 35 | 36 | def get_spectrum(freqs, decay_power, ch, h, w=0, z=0): 37 | """ Samples a fourier image with given size and frequencies decayed by decay power 38 | 39 | :param freqs: Bin values for the discrete fourier transform 40 | :param decay_power: Decay power for frequency decay prop 1/f**d 41 | :param ch: Number of channels for the resulting mask 42 | :param h: Required, first dimension size 43 | :param w: Optional, second dimension size 44 | :param z: Optional, third dimension size 45 | """ 46 | scale = np.ones(1) / (np.maximum(freqs, np.array([1. 
/ max(w, h, z)])) ** decay_power) 47 | 48 | param_size = [ch] + list(freqs.shape) + [2] 49 | param = np.random.randn(*param_size) 50 | 51 | scale = np.expand_dims(scale, -1)[None, :] 52 | 53 | return scale * param 54 | 55 | 56 | def make_low_freq_image(decay, shape, ch=1): 57 | """ Sample a low frequency image from fourier space 58 | 59 | :param decay_power: Decay power for frequency decay prop 1/f**d 60 | :param shape: Shape of desired mask, list up to 3 dims 61 | :param ch: Number of channels for desired mask 62 | """ 63 | freqs = fftfreqnd(*shape) 64 | spectrum = get_spectrum(freqs, decay, ch, *shape)#.reshape((1, *shape[:-1], -1)) 65 | spectrum = spectrum[:, 0] + 1j * spectrum[:, 1] 66 | mask = np.real(np.fft.irfftn(spectrum, shape)) 67 | 68 | if len(shape) == 1: 69 | mask = mask[:1, :shape[0]] 70 | if len(shape) == 2: 71 | mask = mask[:1, :shape[0], :shape[1]] 72 | if len(shape) == 3: 73 | mask = mask[:1, :shape[0], :shape[1], :shape[2]] 74 | 75 | mask = mask 76 | mask = (mask - mask.min()) 77 | mask = mask / mask.max() 78 | return mask 79 | 80 | 81 | def sample_lam(alpha, reformulate=False): 82 | """ Sample a lambda from symmetric beta distribution with given alpha 83 | 84 | :param alpha: Alpha value for beta distribution 85 | :param reformulate: If True, uses the reformulation of [1]. 86 | """ 87 | if reformulate: 88 | lam = beta.rvs(alpha+1, alpha) 89 | else: 90 | lam = beta.rvs(alpha, alpha) 91 | 92 | return lam 93 | 94 | 95 | def binarise_mask(mask, lam, in_shape, max_soft=0.0): 96 | """ Binarises a given low frequency image such that it has mean lambda. 97 | 98 | :param mask: Low frequency image, usually the result of `make_low_freq_image` 99 | :param lam: Mean value of final mask 100 | :param in_shape: Shape of inputs 101 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask. 102 | :return: 103 | """ 104 | idx = mask.reshape(-1).argsort()[::-1] 105 | mask = mask.reshape(-1) 106 | num = math.ceil(lam * mask.size) if random.random() > 0.5 else math.floor(lam * mask.size) 107 | 108 | eff_soft = max_soft 109 | if max_soft > lam or max_soft > (1-lam): 110 | eff_soft = min(lam, 1-lam) 111 | 112 | soft = int(mask.size * eff_soft) 113 | num_low = num - soft 114 | num_high = num + soft 115 | 116 | mask[idx[:num_high]] = 1 117 | mask[idx[num_low:]] = 0 118 | mask[idx[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low)) 119 | 120 | mask = mask.reshape((1, *in_shape)) 121 | return mask 122 | 123 | 124 | def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False): 125 | """ Samples a mean lambda from beta distribution parametrised by alpha, creates a low frequency image and binarises 126 | it based on this lambda 127 | 128 | :param alpha: Alpha value for beta distribution from which to sample mean of mask 129 | :param decay_power: Decay power for frequency decay prop 1/f**d 130 | :param shape: Shape of desired mask, list up to 3 dims 131 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask. 132 | :param reformulate: If True, uses the reformulation of [1]. 
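Example (an illustrative sketch; the parameter values and the image arrays x1, x2 are arbitrary):

    lam, mask = sample_mask(alpha=1, decay_power=3, shape=(32, 32))
    # mask has shape (1, 32, 32) with mean approximately lam, so it broadcasts over channels
    mixed = mask * x1 + (1 - mask) * x2  # x1, x2: arrays of shape (3, 32, 32)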
133 | """ 134 | if isinstance(shape, int): 135 | shape = (shape,) 136 | 137 | # Choose lambda 138 | lam = sample_lam(alpha, reformulate) 139 | 140 | # Make mask, get mean / std 141 | mask = make_low_freq_image(decay_power, shape) 142 | mask = binarise_mask(mask, lam, shape, max_soft) 143 | 144 | return lam, mask 145 | 146 | 147 | def sample_and_apply(x, alpha, decay_power, shape, max_soft=0.0, reformulate=False): 148 | """ 149 | 150 | :param x: Image batch on which to apply fmix of shape [b, c, shape*] 151 | :param alpha: Alpha value for beta distribution from which to sample mean of mask 152 | :param decay_power: Decay power for frequency decay prop 1/f**d 153 | :param shape: Shape of desired mask, list up to 3 dims 154 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask. 155 | :param reformulate: If True, uses the reformulation of [1]. 156 | :return: mixed input, permutation indices, lambda value of mix, 157 | """ 158 | lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate) 159 | index = np.random.permutation(x.shape[0]) 160 | 161 | x1, x2 = x * mask, x[index] * (1-mask) 162 | return x1+x2, index, lam 163 | 164 | 165 | class FMixBase: 166 | r""" FMix augmentation 167 | 168 | Args: 169 | decay_power (float): Decay power for frequency decay prop 1/f**d 170 | alpha (float): Alpha value for beta distribution from which to sample mean of mask 171 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims 172 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask. 173 | reformulate (bool): If True, uses the reformulation of [1]. 174 | """ 175 | 176 | def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False): 177 | super().__init__() 178 | self.decay_power = decay_power 179 | self.reformulate = reformulate 180 | self.size = size 181 | self.alpha = alpha 182 | self.max_soft = max_soft 183 | self.index = None 184 | self.lam = None 185 | 186 | def __call__(self, x): 187 | raise NotImplementedError 188 | 189 | def loss(self, *args, **kwargs): 190 | raise NotImplementedError 191 | -------------------------------------------------------------------------------- /fmix_3d.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecs-vlc/FMix/e5991dca018882734c8ea63599f10dfbe67fa0ae/fmix_3d.gif -------------------------------------------------------------------------------- /fmix_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecs-vlc/FMix/e5991dca018882734c8ea63599f10dfbe67fa0ae/fmix_example.png -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch', 'torchvision'] 2 | 3 | from torch.hub import load_state_dict_from_url 4 | 5 | 6 | def _preact_resnet18(msda='fmix', pretrained=False, repeat=0, *args, **kwargs): 7 | from models import ResNet18 8 | model = ResNet18(*args, **kwargs) 9 | 10 | if pretrained: 11 | state = load_state_dict_from_url( 12 | f'http://marc.ecs.soton.ac.uk/pytorch-models/cifar10/preact-resnet18/cifar10_preact_resnet18_{msda}_{repeat}.pt', progress=True) 13 | model.load_state_dict(state) 14 | 15 | return model 16 | 17 | 18 | def _resnet101(msda='fmix', pretrained=False, *args, **kwargs): 19 | from torchvision.models.resnet import resnet101 20 | 
model = resnet101(*args, **kwargs) 21 | 22 | if pretrained: 23 | state = load_state_dict_from_url( 24 | f'http://marc.ecs.soton.ac.uk/pytorch-models/imagenet/resnet101/imagenet_resnet101_{msda}.pt', progress=True) 25 | model.load_state_dict(state) 26 | 27 | return model 28 | 29 | 30 | def _pyramidnet(msda='fmix', pretrained=False, *args, **kwargs): 31 | from models import aa_PyramidNet 32 | model = aa_PyramidNet(*args, **kwargs) 33 | 34 | if pretrained: 35 | state = load_state_dict_from_url( 36 | f'http://marc.ecs.soton.ac.uk/pytorch-models/cifar10/pyramidnet/cifar10_pyramidnet_{msda}.pt', progress=True) 37 | model.load_state_dict(state) 38 | 39 | return model 40 | 41 | 42 | def preact_resnet18_cifar10_baseline(pretrained=False, repeat=0, *args, **kwargs): 43 | return _preact_resnet18('baseline', pretrained, repeat, *args, **kwargs) 44 | 45 | 46 | def preact_resnet18_cifar10_fmix(pretrained=False, repeat=0, *args, **kwargs): 47 | return _preact_resnet18('fmix', pretrained, repeat, *args, **kwargs) 48 | 49 | 50 | def preact_resnet18_cifar10_mixup(pretrained=False, repeat=0, *args, **kwargs): 51 | return _preact_resnet18('mixup', pretrained, repeat, *args, **kwargs) 52 | 53 | 54 | def preact_resnet18_cifar10_cutmix(pretrained=False, repeat=0, *args, **kwargs): 55 | return _preact_resnet18('cutmix', pretrained, repeat, *args, **kwargs) 56 | 57 | 58 | def preact_resnet18_cifar10_fmixplusmixup(pretrained=False, repeat=0, *args, **kwargs): 59 | return _preact_resnet18('fmixplusmixup', pretrained, repeat, *args, **kwargs) 60 | 61 | 62 | def pyramidnet_cifar10_baseline(pretrained=False, *args, **kwargs): 63 | return _pyramidnet('baseline', pretrained, *args, **kwargs) 64 | 65 | 66 | def pyramidnet_cifar10_fmix(pretrained=False, *args, **kwargs): 67 | return _pyramidnet('fmix', pretrained, *args, **kwargs) 68 | 69 | 70 | def pyramidnet_cifar10_mixup(pretrained=False, *args, **kwargs): 71 | return _pyramidnet('mixup', pretrained, *args, **kwargs) 72 | 73 | 74 | def pyramidnet_cifar10_cutmix(pretrained=False, *args, **kwargs): 75 | return _pyramidnet('cutmix', pretrained, *args, **kwargs) 76 | 77 | 78 | def renset101_imagenet_baseline(pretrained=False, *args, **kwargs): 79 | return _resnet101('baseline', pretrained, *args, **kwargs) 80 | 81 | 82 | def renset101_imagenet_fmix(pretrained=False, *args, **kwargs): 83 | return _resnet101('fmix', pretrained, *args, **kwargs) 84 | 85 | 86 | def renset101_imagenet_mixup(pretrained=False, *args, **kwargs): 87 | return _resnet101('mixup', pretrained, *args, **kwargs) 88 | 89 | 90 | def renset101_imagenet_cutmix(pretrained=False, *args, **kwargs): 91 | return _resnet101('cutmix', pretrained, *args, **kwargs) 92 | -------------------------------------------------------------------------------- /implementations/lightning.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from fmix import sample_mask, FMixBase 3 | import torch 4 | 5 | 6 | def fmix_loss(input, y1, index, lam, train=True, reformulate=False): 7 | r"""Criterion for fmix 8 | 9 | Args: 10 | input: If train, mixed input. If not train, standard input 11 | y1: Targets for first image 12 | index: Permutation for mixing 13 | lam: Lambda value of mixing 14 | train: If true, sum cross entropy of input with y1 and y2, weighted by lam/(1-lam). 
If false, cross entropy loss with y1 15 | """ 16 | 17 | if train and not reformulate: 18 | y2 = y1[index] 19 | return F.cross_entropy(input, y1) * lam + F.cross_entropy(input, y2) * (1 - lam) 20 | else: 21 | return F.cross_entropy(input, y1) 22 | 23 | 24 | class FMix(FMixBase): 25 | r""" FMix augmentation 26 | 27 | Args: 28 | decay_power (float): Decay power for frequency decay prop 1/f**d 29 | alpha (float): Alpha value for beta distribution from which to sample mean of mask 30 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims 31 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask. 32 | reformulate (bool): If True, uses the reformulation of [1]. 33 | 34 | Example 35 | ------- 36 | 37 | .. code-block:: python 38 | 39 | class FMixExp(pl.LightningModule): 40 | def __init__(*args, **kwargs): 41 | self.fmix = Fmix(...) 42 | # ... 43 | 44 | def training_step(self, batch, batch_idx): 45 | x, y = batch 46 | x = self.fmix(x) 47 | 48 | feature_maps = self.forward(x) 49 | logits = self.classifier(feature_maps) 50 | loss = self.fmix.loss(logits, y) 51 | 52 | # ... 53 | return loss 54 | """ 55 | def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False): 56 | super().__init__(decay_power, alpha, size, max_soft, reformulate) 57 | 58 | def __call__(self, x): 59 | # Sample mask and generate random permutation 60 | lam, mask = sample_mask(self.alpha, self.decay_power, self.size, self.max_soft, self.reformulate) 61 | index = torch.randperm(x.size(0)).to(x.device) 62 | mask = torch.from_numpy(mask).float().to(x.device) 63 | 64 | # Mix the images 65 | x1 = mask * x 66 | x2 = (1 - mask) * x[index] 67 | self.index = index 68 | self.lam = lam 69 | return x1+x2 70 | 71 | def loss(self, y_pred, y, train=True): 72 | return fmix_loss(y_pred, y, self.index, self.lam, train, self.reformulate) 73 | -------------------------------------------------------------------------------- /implementations/tensorflow_implementation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from fmix import sample_mask, FMixBase 4 | softmax_cross_entropy_with_logits = tf.nn.softmax_cross_entropy_with_logits 5 | 6 | 7 | def fmix_loss(input, y1, index, lam, train=True, reformulate=False): 8 | r"""Criterion for fmix 9 | 10 | Args: 11 | input: If train, mixed input. If not train, standard input 12 | y1: Targets for first image 13 | y2: Targets for image mixed with first image 14 | lam: Lambda value of mixing 15 | train: If true, sum cross entropy of input with y1 and y2, weighted by lam/(1-lam). 
If false, cross entropy loss with y1 16 | """ 17 | 18 | if train and not reformulate: 19 | y2 = tf.gather(y1, index) 20 | y1, y2 = tf.transpose(tf.one_hot(y1, 10, axis=0)), tf.transpose(tf.one_hot(y2, 10, axis=0)) 21 | return softmax_cross_entropy_with_logits(logits=input, labels=y1) * lam + softmax_cross_entropy_with_logits(logits=input, labels=y2) * (1-lam) 22 | else: 23 | y1 = tf.transpose(tf.one_hot(y1, 10, axis=0)) 24 | return softmax_cross_entropy_with_logits(logits=input, labels=y1) 25 | 26 | 27 | class FMix(FMixBase): 28 | r""" FMix augmentation 29 | 30 | Args: 31 | decay_power (float): Decay power for frequency decay prop 1/f**d 32 | alpha (float): Alpha value for beta distribution from which to sample mean of mask 33 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims 34 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask. 35 | reformulate (bool): If True, uses the reformulation of [1]. 36 | 37 | Example 38 | ---------- 39 | 40 | fmix = FMix(...) 41 | 42 | def loss(model, x, y, training=True): 43 | x = fmix(x) 44 | y_ = model(x, training=training) 45 | return tf.reduce_mean(fmix.loss(y_, y)) 46 | """ 47 | def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False): 48 | super().__init__(decay_power, alpha, size, max_soft, reformulate) 49 | 50 | def __call__(self, x): 51 | shape = [int(s) for s in x.shape][1:-1] 52 | lam, mask = sample_mask(self.alpha, self.decay_power, shape, self.max_soft, self.reformulate) 53 | index = np.random.permutation(int(x.shape[0])) 54 | index = tf.constant(index) 55 | mask = np.expand_dims(mask, -1) 56 | 57 | x1 = x * mask 58 | x2 = tf.gather(x, index) * (1 - mask) 59 | self.index = index 60 | self.lam = lam 61 | 62 | return x1 + x2 63 | 64 | def loss(self, y_pred, y, train=True): 65 | return fmix_loss(y_pred, y, self.index, self.lam, train, self.reformulate) 66 | -------------------------------------------------------------------------------- /implementations/test_lightning.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms, models 2 | import torch 3 | from torch import optim 4 | from implementations.lightning import FMix 5 | from pytorch_lightning import LightningModule, Trainer, data_loader 6 | 7 | 8 | # ######### Data 9 | print('==> Preparing data..') 10 | classes, cifar = 10, datasets.CIFAR10 11 | 12 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 13 | transform_base = [transforms.ToTensor(), normalize] 14 | transform = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] + transform_base 15 | 16 | transform_train = transforms.Compose(transform) 17 | transform_test = transforms.Compose(transform_base) 18 | trainset = cifar(root='./data', train=True, download=True, transform=transform_train) 19 | valset = cifar(root='./data', train=False, download=True, transform=transform_test) 20 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8) 21 | valloader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=True, num_workers=8) 22 | 23 | 24 | ######### Model 25 | print('==> Building model..') 26 | 27 | 28 | class FMixExp(LightningModule): 29 | def __init__(self): 30 | super().__init__() 31 | self.net = models.resnet18(False) 32 | self.fmix = FMix() 33 | 34 | def forward(self, x): 35 | return self.net(x) 36 | 37 | def training_step(self, batch, batch_nb): 38 
| x, y = batch 39 | x = self.fmix(x) 40 | 41 | x = self.forward(x) 42 | 43 | loss = self.fmix.loss(x, y) 44 | return {'loss': loss} 45 | 46 | def validation_step(self, batch, batch_nb): 47 | x, y = batch 48 | 49 | x = self.forward(x) 50 | 51 | labels_hat = torch.argmax(x, dim=1) 52 | val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) 53 | val_acc = torch.tensor(val_acc) 54 | 55 | loss = self.fmix.loss(x, y, train=False) 56 | output = { 57 | 'val_loss': loss, 58 | 'val_acc': val_acc, 59 | } 60 | 61 | # can also return just a scalar instead of a dict (return loss_val) 62 | return output 63 | 64 | def validation_end(self, outputs): 65 | """ 66 | Called at the end of validation to aggregate outputs 67 | :param outputs: list of individual outputs of each validation step 68 | :return: 69 | """ 70 | # if returned a scalar from validation_step, outputs is a list of tensor scalars 71 | # we return just the average in this case (if we want) 72 | # return torch.stack(outputs).mean() 73 | 74 | val_loss_mean = 0 75 | val_acc_mean = 0 76 | for output in outputs: 77 | val_loss = output['val_loss'] 78 | 79 | # reduce manually when using dp 80 | if self.trainer.use_dp or self.trainer.use_ddp2: 81 | val_loss = torch.mean(val_loss) 82 | val_loss_mean += val_loss 83 | 84 | # reduce manually when using dp 85 | val_acc = output['val_acc'] 86 | if self.trainer.use_dp or self.trainer.use_ddp2: 87 | val_acc = torch.mean(val_acc) 88 | 89 | val_acc_mean += val_acc 90 | 91 | val_loss_mean /= len(outputs) 92 | val_acc_mean /= len(outputs) 93 | tqdm_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean} 94 | result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': val_loss_mean} 95 | return result 96 | 97 | def configure_optimizers(self): 98 | return torch.optim.SGD(self.net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) 99 | 100 | @data_loader 101 | def train_dataloader(self): 102 | return trainloader 103 | 104 | @data_loader 105 | def val_dataloader(self): 106 | return valloader 107 | 108 | 109 | ######### Train 110 | print('==> Starting training..') 111 | trainer = Trainer(gpus=1, early_stop_callback=False, max_epochs=20, checkpoint_callback=False) 112 | mod = FMixExp() 113 | trainer.fit(mod) 114 | -------------------------------------------------------------------------------- /implementations/test_tensorflow.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import tensorflow as tf 4 | from tensorflow.compat.v1 import enable_eager_execution 5 | from tensorflow.keras import datasets, layers, models 6 | 7 | from implementations.tensorflow_implementation import FMix 8 | 9 | enable_eager_execution() 10 | print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU'))) 11 | 12 | 13 | (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() 14 | train_images, test_images = train_images / 255.0, test_images / 255.0 15 | train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels)) 16 | 17 | fmix = FMix() 18 | 19 | 20 | def loss(model, x, y, training=True): 21 | x = fmix(x) 22 | y_ = model(x, training=training) 23 | return tf.reduce_mean(fmix.loss(y_, y)) 24 | 25 | 26 | model = models.Sequential() 27 | model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3))) 28 | model.add(layers.MaxPooling2D((2, 2))) 29 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 30 | 
model.add(layers.MaxPooling2D((2, 2))) 31 | model.add(layers.Conv2D(64, (3, 3), activation='relu')) 32 | model.add(layers.Flatten()) 33 | model.add(layers.Dense(64, activation='relu')) 34 | model.add(layers.Dense(10)) 35 | 36 | optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9) 37 | 38 | 39 | def train(model, train_images, train_labels): 40 | with tf.GradientTape() as t: 41 | current_loss = loss(model, train_images, train_labels) 42 | return current_loss, t.gradient(current_loss, model.trainable_variables) 43 | 44 | 45 | epochs = range(100) 46 | import tqdm 47 | epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy() 48 | 49 | 50 | for epoch in epochs: 51 | t = tqdm.tqdm() 52 | for batch in train_ds.shuffle(256).batch(128): 53 | x, y = batch 54 | x, y = tf.cast(x, 'float32'), tf.cast(y, 'int32')[:,0] 55 | current_loss, grads = train(model, x, y) 56 | optimizer.apply_gradients(zip(grads, model.trainable_variables)) 57 | epoch_accuracy(y, model(x, training=True)) 58 | t.update(1) 59 | t.set_postfix_str('Epoch: {}. Loss: {}. Acc: {}'.format(epoch, current_loss, epoch_accuracy.result())) 60 | t.close() -------------------------------------------------------------------------------- /implementations/test_torchbearer.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms, models 2 | import torch 3 | from torch import optim 4 | from implementations.torchbearer_implementation import FMix 5 | from torchbearer import Trial 6 | 7 | 8 | # ######### Data 9 | print('==> Preparing data..') 10 | classes, cifar = 10, datasets.CIFAR10 11 | 12 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 13 | transform_base = [transforms.ToTensor(), normalize] 14 | transform = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] + transform_base 15 | 16 | transform_train = transforms.Compose(transform) 17 | transform_test = transforms.Compose(transform_base) 18 | trainset = cifar(root='./data', train=True, download=True, transform=transform_train) 19 | valset = cifar(root='./data', train=False, download=True, transform=transform_test) 20 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8) 21 | valloader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=True, num_workers=8) 22 | 23 | 24 | ######### Model 25 | print('==> Building model..') 26 | net = models.resnet18(False) 27 | optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) 28 | fmix = FMix() 29 | criterion = fmix.loss() 30 | 31 | 32 | ######### Trial 33 | print('==> Starting training..') 34 | trial = Trial(net, optimizer, criterion, metrics=['acc', 'loss'], callbacks=[fmix]) 35 | trial.with_generators(train_generator=trainloader, val_generator=valloader).to('cuda') 36 | trial.run(100, verbose=2) 37 | -------------------------------------------------------------------------------- /implementations/torchbearer_implementation.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torchbearer 3 | from torchbearer.callbacks import Callback 4 | from fmix import sample_mask, FMixBase 5 | import torch 6 | 7 | 8 | def fmix_loss(input, y, index, lam, train=True, reformulate=False, bce_loss=False): 9 | r"""Criterion for fmix 10 | 11 | Args: 12 | input: If train, mixed input. 
If not train, standard input 13 | y: Targets for first image 14 | index: Permutation for mixing 15 | lam: Lambda value of mixing 16 | train: If true, sum cross entropy of input with y1 and y2, weighted by lam/(1-lam). If false, cross entropy loss with y1 17 | """ 18 | loss_fn = F.cross_entropy if not bce_loss else F.binary_cross_entropy_with_logits 19 | 20 | if train and not reformulate: 21 | y2 = y[index] 22 | return loss_fn(input, y) * lam + loss_fn(input, y2) * (1 - lam) 23 | else: 24 | return loss_fn(input, y) 25 | 26 | 27 | class FMix(FMixBase, Callback): 28 | r""" FMix augmentation 29 | 30 | Args: 31 | decay_power (float): Decay power for frequency decay prop 1/f**d 32 | alpha (float): Alpha value for beta distribution from which to sample mean of mask 33 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims. -1 computes on the fly 34 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask. 35 | reformulate (bool): If True, uses the reformulation of [1]. 36 | 37 | Example 38 | ------- 39 | 40 | .. code-block:: python 41 | 42 | fmix = FMix(...) 43 | trial = Trial(model, optimiser, fmix.loss(), callbacks=[fmix]) 44 | # ... 45 | """ 46 | def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False): 47 | super().__init__(decay_power, alpha, size, max_soft, reformulate) 48 | 49 | def on_sample(self, state): 50 | super().on_sample(state) 51 | x, y = state[torchbearer.X], state[torchbearer.Y_TRUE] 52 | device = state[torchbearer.DEVICE] 53 | 54 | x = self(x) 55 | 56 | # Store the results 57 | state[torchbearer.X] = x 58 | state[torchbearer.Y_TRUE] = y 59 | 60 | # Set mixup flags 61 | state[torchbearer.MIXUP_LAMBDA] = torch.tensor([self.lam], device=device) if not self.reformulate else torch.tensor([1], device=device) 62 | state[torchbearer.MIXUP_PERMUTATION] = self.index 63 | 64 | def __call__(self, x): 65 | size = [] 66 | for i, s in enumerate(self.size): 67 | if s != -1: 68 | size.append(s) 69 | else: 70 | size.append(x.shape[i+1]) 71 | 72 | lam, mask = sample_mask(self.alpha, self.decay_power, size, self.max_soft, self.reformulate) 73 | index = torch.randperm(x.size(0)).to(x.device) 74 | mask = torch.from_numpy(mask).float().to(x.device) 75 | 76 | if len(self.size) == 1 and x.ndim == 3: 77 | mask = mask.unsqueeze(2) 78 | 79 | # Mix the images 80 | x1 = mask * x 81 | x2 = (1 - mask) * x[index] 82 | self.index = index 83 | self.lam = lam 84 | return x1 + x2 85 | 86 | def loss(self, use_bce=False): 87 | def _fmix_loss(state): 88 | y_pred = state[torchbearer.Y_PRED] 89 | y = state[torchbearer.Y_TRUE] 90 | index = state[torchbearer.MIXUP_PERMUTATION] if torchbearer.MIXUP_PERMUTATION in state else None 91 | lam = state[torchbearer.MIXUP_LAMBDA] if torchbearer.MIXUP_LAMBDA in state else None 92 | train = state[torchbearer.MODEL].training 93 | return fmix_loss(y_pred, y, index, lam, train, self.reformulate, use_bce) 94 | 95 | return _fmix_loss 96 | 97 | 98 | class PointNetFMix(FMix): 99 | def __init__(self, resolution, decay_power=3, alpha=1, max_soft=0.0, reformulate=False): 100 | super().__init__(decay_power, alpha, [resolution, resolution, resolution], max_soft, reformulate) 101 | self.res = resolution 102 | 103 | def __call__(self, x): 104 | import kaolin.conversions as cvt 105 | x = super().__call__(x) 106 | t = [] 107 | for i in range(x.shape[0]): 108 | t.append(cvt.voxelgrid_to_pointcloud(x[i], self.res, normalize=True)) 109 | return torch.stack(t) 110 | 111 | 112 | from torchbearer.metrics 
import default as d 113 | from utils.reformulated_mixup import MixupAcc 114 | d.__loss_map__[FMix().loss().__name__] = MixupAcc -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .wide_resnet import wrn 2 | from .densenet3 import DenseNet190 3 | from .pyramid import aa_PyramidNet 4 | from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 5 | from .senet import se_resnext50_32x4d 6 | from .toxic_lstm import LSTM 7 | from .toxic_cnn import CNN 8 | from .bert import Bert 9 | -------------------------------------------------------------------------------- /models/bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Bert(nn.Module): 6 | def __init__(self, num_classes, nc=None, bert_model='bert-base-cased'): 7 | super().__init__() 8 | from transformers import BertModel 9 | self.bert = BertModel.from_pretrained(bert_model, output_hidden_states=True) 10 | self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes) 11 | 12 | def forward(self, x, token_type_ids=None, position_ids=None, head_mask=None): 13 | outputs = self.bert(x.long(), 14 | attention_mask=(x != 0).float(), 15 | token_type_ids=token_type_ids, 16 | position_ids=position_ids, 17 | head_mask=head_mask) 18 | cls_output = torch.cat(outputs[2], dim=1).mean(dim=1) 19 | cls_output = self.classifier(cls_output) # batch, 6 20 | return cls_output 21 | -------------------------------------------------------------------------------- /models/densenet3.py: -------------------------------------------------------------------------------- 1 | # Code modified from https://github.com/andreasveit/densenet-pytorch 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class BasicBlock(nn.Module): 10 | def __init__(self, in_planes, out_planes, dropRate=0.0): 11 | super(BasicBlock, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.relu = nn.ReLU(inplace=True) 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 15 | padding=1, bias=False) 16 | self.droprate = dropRate 17 | 18 | def forward(self, x): 19 | out = self.conv1(self.relu(self.bn1(x))) 20 | if self.droprate > 0: 21 | out = F.dropout(out, p=self.droprate, training=self.training) 22 | return torch.cat([x, out], 1) 23 | 24 | 25 | class BottleneckBlock(nn.Module): 26 | def __init__(self, in_planes, out_planes, dropRate=0.0): 27 | super(BottleneckBlock, self).__init__() 28 | inter_planes = out_planes * 4 29 | self.bn1 = nn.BatchNorm2d(in_planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1, 32 | padding=0, bias=False) 33 | self.bn2 = nn.BatchNorm2d(inter_planes) 34 | self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1, 35 | padding=1, bias=False) 36 | self.droprate = dropRate 37 | 38 | def forward(self, x): 39 | out = self.conv1(self.relu(self.bn1(x))) 40 | if self.droprate > 0: 41 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 42 | out = self.conv2(self.relu(self.bn2(out))) 43 | if self.droprate > 0: 44 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 45 | return torch.cat([x, out], 1) 46 | 47 | 48 | class TransitionBlock(nn.Module): 49 | def __init__(self, in_planes, out_planes, dropRate=0.0): 
50 | super(TransitionBlock, self).__init__() 51 | self.bn1 = nn.BatchNorm2d(in_planes) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 54 | padding=0, bias=False) 55 | self.droprate = dropRate 56 | 57 | def forward(self, x): 58 | out = self.conv1(self.relu(self.bn1(x))) 59 | if self.droprate > 0: 60 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 61 | return F.avg_pool2d(out, 2) 62 | 63 | 64 | class DenseBlock(nn.Module): 65 | def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0): 66 | super(DenseBlock, self).__init__() 67 | self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate) 68 | 69 | def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate): 70 | layers = [] 71 | for i in range(nb_layers): 72 | layers.append(block(in_planes + i * growth_rate, growth_rate, dropRate)) 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | return self.layer(x) 77 | 78 | 79 | class DenseNet3(nn.Module): 80 | def __init__(self, depth, num_classes, growth_rate=12, 81 | reduction=0.5, bottleneck=True, dropRate=0.0, efficient=False, nc=3): 82 | super(DenseNet3, self).__init__() 83 | in_planes = 2 * growth_rate 84 | n = (depth - 4) // 3 85 | if bottleneck == True: 86 | n = n // 2 87 | block = BottleneckBlock 88 | else: 89 | block = BasicBlock 90 | # 1st conv before any dense block 91 | self.conv1 = nn.Conv2d(nc, in_planes, kernel_size=3, stride=1, 92 | padding=1, bias=False) 93 | # 1st block 94 | self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 95 | in_planes = int(in_planes + n * growth_rate) 96 | self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes * reduction)), dropRate=dropRate) 97 | in_planes = int(math.floor(in_planes * reduction)) 98 | # 2nd block 99 | self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 100 | in_planes = int(in_planes + n * growth_rate) 101 | self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes * reduction)), dropRate=dropRate) 102 | in_planes = int(math.floor(in_planes * reduction)) 103 | # 3rd block 104 | self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 105 | in_planes = int(in_planes + n * growth_rate) 106 | # global average pooling and classifier 107 | self.bn1 = nn.BatchNorm2d(in_planes) 108 | self.relu = nn.ReLU(inplace=True) 109 | self.fc = nn.Linear(in_planes, num_classes) 110 | self.in_planes = in_planes 111 | 112 | for m in self.modules(): 113 | if isinstance(m, nn.Conv2d): 114 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 115 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 116 | elif isinstance(m, nn.BatchNorm2d): 117 | m.weight.data.fill_(1) 118 | m.bias.data.zero_() 119 | elif isinstance(m, nn.Linear): 120 | m.bias.data.zero_() 121 | 122 | def forward(self, x): 123 | out = self.conv1(x) 124 | out = self.trans1(self.block1(out)) 125 | out = self.trans2(self.block2(out)) 126 | out = self.block3(out) 127 | out = self.relu(self.bn1(out)) 128 | # out = F.avg_pool2d(out, 8) 129 | out = F.adaptive_avg_pool2d(out, (1, 1)) 130 | out = out.view(-1, self.in_planes) 131 | return self.fc(out) 132 | 133 | 134 | def DenseNet190(num_classes=10, nc=3): 135 | return DenseNet3(190, num_classes, growth_rate=40, nc=nc) 136 | 137 | 138 | def EDenseNet190(): 139 | return DenseNet3(190, 10, growth_rate=40, efficient=True) 140 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import models 2 | from models import wrn 3 | import torchvision.models as m 4 | # from models.toxic_lstm import LSTM 5 | 6 | 7 | def get_model(args, classes, nc): 8 | # Load torchvision models with "torch_" prefix 9 | if 'torch' in args.model: 10 | return m.__dict__[args.model[6:]](num_classes=classes, pretrained=False) 11 | 12 | # Load the pyramidnet used for autoaugment experiments on cifar 13 | if args.model == 'aa_PyramidNet': 14 | return models.__dict__[args.model](dataset='cifar10', depth=272, alpha=200, num_classes=classes) 15 | 16 | # Load the WideResNet-28-10 17 | if args.model == 'wrn': 18 | return wrn(num_classes=classes, depth=28, widen_factor=10, nc=nc) 19 | 20 | if args.model == 'PointNet' or args.dataset == 'modelnet': 21 | from kaolin.models.PointNet import PointNetClassifier 22 | return PointNetClassifier(num_classes=classes) 23 | 24 | # if args.dataset == 'toxic': 25 | # return LSTM() 26 | 27 | # Otherwise return models from other files 28 | return models.__dict__[args.model](num_classes=classes, nc=nc) 29 | -------------------------------------------------------------------------------- /models/pyramid.py: -------------------------------------------------------------------------------- 1 | # Code modified from https://github.com/kakaobrain/fast-autoaugment/blob/master/FastAutoAugment/networks/pyramidnet.py 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | 10 | class ShakeDropFunction(torch.autograd.Function): 11 | 12 | @staticmethod 13 | def forward(ctx, x, training=True, p_drop=0.5, alpha_range=[-1, 1]): 14 | if training: 15 | gate = torch.cuda.FloatTensor([0]).bernoulli_(1 - p_drop) 16 | ctx.save_for_backward(gate) 17 | if gate.item() == 0: 18 | alpha = torch.cuda.FloatTensor(x.size(0)).uniform_(*alpha_range) 19 | alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x) 20 | return alpha * x 21 | else: 22 | return x 23 | else: 24 | return (1 - p_drop) * x 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output): 28 | gate = ctx.saved_tensors[0] 29 | if gate.item() == 0: 30 | beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_(0, 1) 31 | beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output) 32 | beta = Variable(beta) 33 | return beta * grad_output, None, None, None 34 | else: 35 | return grad_output, None, None, None 36 | 37 | 38 | class ShakeDrop(nn.Module): 39 | 40 | def __init__(self, p_drop=0.5, alpha_range=[-1, 1]): 41 | super(ShakeDrop, self).__init__() 42 | self.p_drop = p_drop 43 | self.alpha_range = alpha_range 44 | 45 | def forward(self, x): 46 | return 
ShakeDropFunction.apply(x, self.training, self.p_drop, self.alpha_range) 47 | 48 | 49 | def conv3x3(in_planes, out_planes, stride=1): 50 | """ 51 | 3x3 convolution with padding 52 | """ 53 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 54 | 55 | 56 | class BasicBlock(nn.Module): 57 | outchannel_ratio = 1 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0): 60 | super(BasicBlock, self).__init__() 61 | self.bn1 = nn.BatchNorm2d(inplanes) 62 | self.conv1 = conv3x3(inplanes, planes, stride) 63 | self.bn2 = nn.BatchNorm2d(planes) 64 | self.conv2 = conv3x3(planes, planes) 65 | self.bn3 = nn.BatchNorm2d(planes) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | self.shake_drop = ShakeDrop(p_shakedrop) 70 | 71 | def forward(self, x): 72 | 73 | out = self.bn1(x) 74 | out = self.conv1(out) 75 | out = self.bn2(out) 76 | out = self.relu(out) 77 | out = self.conv2(out) 78 | out = self.bn3(out) 79 | 80 | out = self.shake_drop(out) 81 | 82 | if self.downsample is not None: 83 | shortcut = self.downsample(x) 84 | featuremap_size = shortcut.size()[2:4] 85 | else: 86 | shortcut = x 87 | featuremap_size = out.size()[2:4] 88 | 89 | batch_size = out.size()[0] 90 | residual_channel = out.size()[1] 91 | shortcut_channel = shortcut.size()[1] 92 | 93 | if residual_channel != shortcut_channel: 94 | padding = torch.autograd.Variable( 95 | torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], 96 | featuremap_size[1]).fill_(0)) 97 | out += torch.cat((shortcut, padding), 1) 98 | else: 99 | out += shortcut 100 | 101 | return out 102 | 103 | 104 | class Bottleneck(nn.Module): 105 | outchannel_ratio = 4 106 | 107 | def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0): 108 | super(Bottleneck, self).__init__() 109 | self.bn1 = nn.BatchNorm2d(inplanes) 110 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 111 | self.bn2 = nn.BatchNorm2d(planes) 112 | self.conv2 = nn.Conv2d(planes, (planes * 1), kernel_size=3, stride=stride, 113 | padding=1, bias=False) 114 | self.bn3 = nn.BatchNorm2d((planes * 1)) 115 | self.conv3 = nn.Conv2d((planes * 1), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False) 116 | self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.downsample = downsample 119 | self.stride = stride 120 | self.shake_drop = ShakeDrop(p_shakedrop) 121 | 122 | def forward(self, x): 123 | 124 | out = self.bn1(x) 125 | out = self.conv1(out) 126 | 127 | out = self.bn2(out) 128 | out = self.relu(out) 129 | out = self.conv2(out) 130 | 131 | out = self.bn3(out) 132 | out = self.relu(out) 133 | out = self.conv3(out) 134 | 135 | out = self.bn4(out) 136 | 137 | out = self.shake_drop(out) 138 | 139 | if self.downsample is not None: 140 | shortcut = self.downsample(x) 141 | featuremap_size = shortcut.size()[2:4] 142 | else: 143 | shortcut = x 144 | featuremap_size = out.size()[2:4] 145 | 146 | batch_size = out.size()[0] 147 | residual_channel = out.size()[1] 148 | shortcut_channel = shortcut.size()[1] 149 | 150 | if residual_channel != shortcut_channel: 151 | padding = torch.autograd.Variable( 152 | torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], 153 | featuremap_size[1]).fill_(0)) 154 | out += torch.cat((shortcut, padding), 1) 155 | else: 156 | out += shortcut 157 | 158 | return out 159 | 160 | 161 | 
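# Width/regularisation schedule used by aa_PyramidNet below (a summary derived from the constructor):
# with bottleneck blocks there are n = (depth - 2) // 9 blocks per stage, each block widens the channel
# count by addrate = alpha / (3 * n), and the ShakeDrop drop probability of the i-th block is
# 0.5 * (i + 1) / (3 * n), rising linearly with depth to 0.5 at the final block. For the default
# depth=272, alpha=200 this gives n = 30 and addrate = 200 / 90 ≈ 2.22.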
class aa_PyramidNet(nn.Module): 162 | def __init__(self, dataset='cifar10', depth=272, alpha=200, num_classes=10, bottleneck=True): 163 | super(aa_PyramidNet, self).__init__() 164 | self.dataset = dataset 165 | if self.dataset.startswith('cifar'): 166 | self.inplanes = 16 167 | if bottleneck: 168 | n = int((depth - 2) / 9) 169 | block = Bottleneck 170 | else: 171 | n = int((depth - 2) / 6) 172 | block = BasicBlock 173 | 174 | self.addrate = alpha / (3 * n * 1.0) 175 | self.ps_shakedrop = [1. - (1.0 - (0.5 / (3 * n)) * (i + 1)) for i in range(3 * n)] 176 | 177 | self.input_featuremap_dim = self.inplanes 178 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False) 179 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 180 | 181 | self.featuremap_dim = self.input_featuremap_dim 182 | self.layer1 = self.pyramidal_make_layer(block, n) 183 | self.layer2 = self.pyramidal_make_layer(block, n, stride=2) 184 | self.layer3 = self.pyramidal_make_layer(block, n, stride=2) 185 | 186 | self.final_featuremap_dim = self.input_featuremap_dim 187 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim) 188 | self.relu_final = nn.ReLU(inplace=True) 189 | self.avgpool = nn.AvgPool2d(8) 190 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 191 | 192 | elif dataset == 'imagenet': 193 | blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 194 | layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 195 | 200: [3, 24, 36, 3]} 196 | 197 | if layers.get(depth) is None: 198 | if bottleneck == True: 199 | blocks[depth] = Bottleneck 200 | temp_cfg = int((depth - 2) / 12) 201 | else: 202 | blocks[depth] = BasicBlock 203 | temp_cfg = int((depth - 2) / 8) 204 | 205 | layers[depth] = [temp_cfg, temp_cfg, temp_cfg, temp_cfg] 206 | print('=> the layer configuration for each stage is set to', layers[depth]) 207 | 208 | self.inplanes = 64 209 | self.addrate = alpha / (sum(layers[depth]) * 1.0) 210 | 211 | self.input_featuremap_dim = self.inplanes 212 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False) 213 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 214 | self.relu = nn.ReLU(inplace=True) 215 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 216 | 217 | self.featuremap_dim = self.input_featuremap_dim 218 | self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0]) 219 | self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2) 220 | self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2) 221 | self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2) 222 | 223 | self.final_featuremap_dim = self.input_featuremap_dim 224 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim) 225 | self.relu_final = nn.ReLU(inplace=True) 226 | self.avgpool = nn.AvgPool2d(7) 227 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 228 | 229 | for m in self.modules(): 230 | if isinstance(m, nn.Conv2d): 231 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 232 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 233 | elif isinstance(m, nn.BatchNorm2d): 234 | m.weight.data.fill_(1) 235 | m.bias.data.zero_() 236 | 237 | assert len(self.ps_shakedrop) == 0, self.ps_shakedrop 238 | 239 | def pyramidal_make_layer(self, block, block_depth, stride=1): 240 | downsample = None 241 | if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio: 242 | downsample = nn.AvgPool2d((2, 2), stride=(2, 2), ceil_mode=True) 243 | 244 | layers = [] 245 | self.featuremap_dim = self.featuremap_dim + self.addrate 246 | layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample, 247 | p_shakedrop=self.ps_shakedrop.pop(0))) 248 | for i in range(1, block_depth): 249 | temp_featuremap_dim = self.featuremap_dim + self.addrate 250 | layers.append( 251 | block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1, 252 | p_shakedrop=self.ps_shakedrop.pop(0))) 253 | self.featuremap_dim = temp_featuremap_dim 254 | self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio 255 | 256 | return nn.Sequential(*layers) 257 | 258 | def forward(self, x): 259 | if self.dataset == 'cifar10' or self.dataset == 'cifar100': 260 | x = self.conv1(x) 261 | x = self.bn1(x) 262 | 263 | x = self.layer1(x) 264 | x = self.layer2(x) 265 | x = self.layer3(x) 266 | 267 | x = self.bn_final(x) 268 | x = self.relu_final(x) 269 | x = self.avgpool(x) 270 | x = x.view(x.size(0), -1) 271 | x = self.fc(x) 272 | 273 | elif self.dataset == 'imagenet': 274 | x = self.conv1(x) 275 | x = self.bn1(x) 276 | x = self.relu(x) 277 | x = self.maxpool(x) 278 | 279 | x = self.layer1(x) 280 | x = self.layer2(x) 281 | x = self.layer3(x) 282 | x = self.layer4(x) 283 | 284 | x = self.bn_final(x) 285 | x = self.relu_final(x) 286 | x = self.avgpool(x) 287 | x = x.view(x.size(0), -1) 288 | x = self.fc(x) 289 | 290 | return x 291 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # Code modified from https://github.com/kuangliu/pytorch-cifar 2 | 3 | '''ResNet in PyTorch. 4 | 5 | BasicBlock and Bottleneck module is from the original ResNet paper: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | 9 | PreActBlock and PreActBottleneck module is from the later paper: 10 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 11 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 12 | ''' 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | from torch.autograd import Variable 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, in_planes, planes, stride=1): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(in_planes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | 34 | self.shortcut = nn.Sequential() 35 | if stride != 1 or in_planes != self.expansion * planes: 36 | self.shortcut = nn.Sequential( 37 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 38 | nn.BatchNorm2d(self.expansion * planes) 39 | ) 40 | 41 | def forward(self, x): 42 | out = F.relu(self.bn1(self.conv1(x))) 43 | out = self.bn2(self.conv2(out)) 44 | out += self.shortcut(x) 45 | out = F.relu(out) 46 | return out 47 | 48 | 49 | class PreActBlock(nn.Module): 50 | '''Pre-activation version of the BasicBlock.''' 51 | expansion = 1 52 | 53 | def __init__(self, in_planes, planes, stride=1): 54 | super(PreActBlock, self).__init__() 55 | self.bn1 = nn.BatchNorm2d(in_planes) 56 | self.conv1 = conv3x3(in_planes, planes, stride) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv2 = conv3x3(planes, planes) 59 | 60 | self.shortcut = nn.Sequential() 61 | if stride != 1 or in_planes != self.expansion * planes: 62 | self.shortcut = nn.Sequential( 63 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False) 64 | ) 65 | 66 | def forward(self, x): 67 | out = F.relu(self.bn1(x)) 68 | shortcut = self.shortcut(out) 69 | out = self.conv1(out) 70 | out = self.conv2(F.relu(self.bn2(out))) 71 | out += shortcut 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Module): 76 | expansion = 4 77 | 78 | def __init__(self, in_planes, planes, stride=1): 79 | super(Bottleneck, self).__init__() 80 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 81 | self.bn1 = nn.BatchNorm2d(planes) 82 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 83 | self.bn2 = nn.BatchNorm2d(planes) 84 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 85 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 86 | 87 | self.shortcut = nn.Sequential() 88 | if stride != 1 or in_planes != self.expansion * planes: 89 | self.shortcut = nn.Sequential( 90 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 91 | nn.BatchNorm2d(self.expansion * planes) 92 | ) 93 | 94 | def forward(self, x): 95 | out = F.relu(self.bn1(self.conv1(x))) 96 | out = F.relu(self.bn2(self.conv2(out))) 97 | out = self.bn3(self.conv3(out)) 98 | out += self.shortcut(x) 99 | out = F.relu(out) 100 | return out 101 | 102 | 103 | class PreActBottleneck(nn.Module): 104 | '''Pre-activation version of the original Bottleneck module.''' 105 | expansion = 4 106 | 107 | def __init__(self, in_planes, planes, stride=1): 108 | super(PreActBottleneck, self).__init__() 109 | self.bn1 = nn.BatchNorm2d(in_planes) 110 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 111 | self.bn2 = nn.BatchNorm2d(planes) 112 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 113 | self.bn3 = 
nn.BatchNorm2d(planes) 114 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 115 | 116 | self.shortcut = nn.Sequential() 117 | if stride != 1 or in_planes != self.expansion * planes: 118 | self.shortcut = nn.Sequential( 119 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False) 120 | ) 121 | 122 | def forward(self, x): 123 | out = F.relu(self.bn1(x)) 124 | shortcut = self.shortcut(out) 125 | out = self.conv1(out) 126 | out = self.conv2(F.relu(self.bn2(out))) 127 | out = self.conv3(F.relu(self.bn3(out))) 128 | out += shortcut 129 | return out 130 | 131 | 132 | class ResNet(nn.Module): 133 | def __init__(self, block, num_blocks, num_classes=10, nc=3): 134 | super(ResNet, self).__init__() 135 | self.in_planes = 64 136 | 137 | self.conv1 = conv3x3(nc, 64) 138 | self.bn1 = nn.BatchNorm2d(64) 139 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 140 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 141 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 142 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 143 | self.linear = nn.Linear(512 * block.expansion, num_classes) 144 | 145 | def _make_layer(self, block, planes, num_blocks, stride): 146 | strides = [stride] + [1] * (num_blocks - 1) 147 | layers = [] 148 | for stride in strides: 149 | layers.append(block(self.in_planes, planes, stride)) 150 | self.in_planes = planes * block.expansion 151 | return nn.Sequential(*layers) 152 | 153 | def forward(self, x, lin=0, lout=5): 154 | out = x 155 | if lin < 1 and lout > -1: 156 | out = self.conv1(out) 157 | out = self.bn1(out) 158 | out = F.relu(out) 159 | if lin < 2 and lout > 0: 160 | out = self.layer1(out) 161 | if lin < 3 and lout > 1: 162 | out = self.layer2(out) 163 | if lin < 4 and lout > 2: 164 | out = self.layer3(out) 165 | if lin < 5 and lout > 3: 166 | out = self.layer4(out) 167 | if lout > 4: 168 | # out = F.avg_pool2d(out, 4) 169 | out = F.adaptive_avg_pool2d(out, (1, 1)) 170 | 171 | out = out.view(out.size(0), -1) 172 | out = self.linear(out) 173 | return out 174 | 175 | 176 | def ResNet18(num_classes=10, nc=3): 177 | return ResNet(PreActBlock, [2, 2, 2, 2], num_classes=num_classes, nc=nc) 178 | 179 | 180 | def ResNet34(num_classes=10): 181 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes) 182 | 183 | 184 | def ResNet50(num_classes=10): 185 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes) 186 | 187 | 188 | def ResNet101(num_classes=10): 189 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes) 190 | 191 | 192 | def ResNet152(num_classes=10): 193 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes) 194 | 195 | 196 | def test(): 197 | net = ResNet18() 198 | y = net(Variable(torch.randn(1, 3, 32, 32))) 199 | print(y.size()) 200 | 201 | 202 | def resnet(): 203 | return ResNet18() 204 | -------------------------------------------------------------------------------- /models/senet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from 3 | https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py 4 | 5 | ResNet code gently borrowed from 6 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 7 | """ 8 | from __future__ import print_function, division, absolute_import 9 | from collections import OrderedDict 10 | import math 11 | 12 | import torch.nn as nn 13 | 14 | 
__all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 15 | 'se_resnext50_32x4d', 'se_resnext101_32x4d'] 16 | 17 | 18 | class SEModule(nn.Module): 19 | 20 | def __init__(self, channels, reduction): 21 | super(SEModule, self).__init__() 22 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 23 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 24 | padding=0) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 27 | padding=0) 28 | self.sigmoid = nn.Sigmoid() 29 | 30 | def forward(self, x): 31 | module_input = x 32 | x = self.avg_pool(x) 33 | x = self.fc1(x) 34 | x = self.relu(x) 35 | x = self.fc2(x) 36 | x = self.sigmoid(x) 37 | return module_input * x 38 | 39 | 40 | class Bottleneck(nn.Module): 41 | """ 42 | Base class for bottlenecks that implements `forward()` method. 43 | """ 44 | def forward(self, x): 45 | residual = x 46 | 47 | out = self.conv1(x) 48 | out = self.bn1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv3(out) 56 | out = self.bn3(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out = self.se_module(out) + residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class SEBottleneck(Bottleneck): 68 | """ 69 | Bottleneck for SENet154. 70 | """ 71 | expansion = 4 72 | 73 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 74 | downsample=None): 75 | super(SEBottleneck, self).__init__() 76 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(planes * 2) 78 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, 79 | stride=stride, padding=1, groups=groups, 80 | bias=False) 81 | self.bn2 = nn.BatchNorm2d(planes * 4) 82 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, 83 | bias=False) 84 | self.bn3 = nn.BatchNorm2d(planes * 4) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.se_module = SEModule(planes * 4, reduction=reduction) 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | 91 | class SEResNetBottleneck(Bottleneck): 92 | """ 93 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe 94 | implementation and uses `stride=stride` in `conv1` and not in `conv2` 95 | (the latter is used in the torchvision implementation of ResNet). 96 | """ 97 | expansion = 4 98 | 99 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 100 | downsample=None): 101 | super(SEResNetBottleneck, self).__init__() 102 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, 103 | stride=stride) 104 | self.bn1 = nn.BatchNorm2d(planes) 105 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, 106 | groups=groups, bias=False) 107 | self.bn2 = nn.BatchNorm2d(planes) 108 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 109 | self.bn3 = nn.BatchNorm2d(planes * 4) 110 | self.relu = nn.ReLU(inplace=True) 111 | self.se_module = SEModule(planes * 4, reduction=reduction) 112 | self.downsample = downsample 113 | self.stride = stride 114 | 115 | 116 | class SEResNeXtBottleneck(Bottleneck): 117 | """ 118 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 
119 | """ 120 | expansion = 4 121 | 122 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 123 | downsample=None, base_width=4): 124 | super(SEResNeXtBottleneck, self).__init__() 125 | width = math.floor(planes * (base_width / 64)) * groups 126 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, 127 | stride=1) 128 | self.bn1 = nn.BatchNorm2d(width) 129 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, 130 | padding=1, groups=groups, bias=False) 131 | self.bn2 = nn.BatchNorm2d(width) 132 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 133 | self.bn3 = nn.BatchNorm2d(planes * 4) 134 | self.relu = nn.ReLU(inplace=True) 135 | self.se_module = SEModule(planes * 4, reduction=reduction) 136 | self.downsample = downsample 137 | self.stride = stride 138 | 139 | 140 | class SENet(nn.Module): 141 | 142 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2, 143 | inplanes=128, input_3x3=True, downsample_kernel_size=3, 144 | downsample_padding=1, num_classes=1000, nc=3): 145 | """ 146 | Parameters 147 | ---------- 148 | block (nn.Module): Bottleneck class. 149 | - For SENet154: SEBottleneck 150 | - For SE-ResNet models: SEResNetBottleneck 151 | - For SE-ResNeXt models: SEResNeXtBottleneck 152 | layers (list of ints): Number of residual blocks for 4 layers of the 153 | network (layer1...layer4). 154 | groups (int): Number of groups for the 3x3 convolution in each 155 | bottleneck block. 156 | - For SENet154: 64 157 | - For SE-ResNet models: 1 158 | - For SE-ResNeXt models: 32 159 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules. 160 | - For all models: 16 161 | dropout_p (float or None): Drop probability for the Dropout layer. 162 | If `None` the Dropout layer is not used. 163 | - For SENet154: 0.2 164 | - For SE-ResNet models: None 165 | - For SE-ResNeXt models: None 166 | inplanes (int): Number of input channels for layer1. 167 | - For SENet154: 128 168 | - For SE-ResNet models: 64 169 | - For SE-ResNeXt models: 64 170 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of 171 | a single 7x7 convolution in layer0. 172 | - For SENet154: True 173 | - For SE-ResNet models: False 174 | - For SE-ResNeXt models: False 175 | downsample_kernel_size (int): Kernel size for downsampling convolutions 176 | in layer2, layer3 and layer4. 177 | - For SENet154: 3 178 | - For SE-ResNet models: 1 179 | - For SE-ResNeXt models: 1 180 | downsample_padding (int): Padding for downsampling convolutions in 181 | layer2, layer3 and layer4. 182 | - For SENet154: 1 183 | - For SE-ResNet models: 0 184 | - For SE-ResNeXt models: 0 185 | num_classes (int): Number of outputs in `last_linear` layer. 
186 | - For all models: 1000 187 | """ 188 | super(SENet, self).__init__() 189 | self.inplanes = inplanes 190 | if input_3x3: 191 | layer0_modules = [ 192 | ('conv1', nn.Conv2d(nc, 64, 3, stride=2, padding=1, 193 | bias=False)), 194 | ('bn1', nn.BatchNorm2d(64)), 195 | ('relu1', nn.ReLU(inplace=True)), 196 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, 197 | bias=False)), 198 | ('bn2', nn.BatchNorm2d(64)), 199 | ('relu2', nn.ReLU(inplace=True)), 200 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, 201 | bias=False)), 202 | ('bn3', nn.BatchNorm2d(inplanes)), 203 | ('relu3', nn.ReLU(inplace=True)), 204 | ] 205 | else: 206 | layer0_modules = [ 207 | ('conv1', nn.Conv2d(nc, inplanes, kernel_size=7, stride=2, 208 | padding=3, bias=False)), 209 | ('bn1', nn.BatchNorm2d(inplanes)), 210 | ('relu1', nn.ReLU(inplace=True)), 211 | ] 212 | # To preserve compatibility with Caffe weights `ceil_mode=True` 213 | # is used instead of `padding=1`. 214 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, 215 | ceil_mode=True))) 216 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) 217 | self.layer1 = self._make_layer( 218 | block, 219 | planes=64, 220 | blocks=layers[0], 221 | groups=groups, 222 | reduction=reduction, 223 | downsample_kernel_size=1, 224 | downsample_padding=0 225 | ) 226 | self.layer2 = self._make_layer( 227 | block, 228 | planes=128, 229 | blocks=layers[1], 230 | stride=2, 231 | groups=groups, 232 | reduction=reduction, 233 | downsample_kernel_size=downsample_kernel_size, 234 | downsample_padding=downsample_padding 235 | ) 236 | self.layer3 = self._make_layer( 237 | block, 238 | planes=256, 239 | blocks=layers[2], 240 | stride=2, 241 | groups=groups, 242 | reduction=reduction, 243 | downsample_kernel_size=downsample_kernel_size, 244 | downsample_padding=downsample_padding 245 | ) 246 | self.layer4 = self._make_layer( 247 | block, 248 | planes=512, 249 | blocks=layers[3], 250 | stride=2, 251 | groups=groups, 252 | reduction=reduction, 253 | downsample_kernel_size=downsample_kernel_size, 254 | downsample_padding=downsample_padding 255 | ) 256 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 257 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None 258 | self.last_linear = nn.Linear(512 * block.expansion, num_classes) 259 | 260 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 261 | downsample_kernel_size=1, downsample_padding=0): 262 | downsample = None 263 | if stride != 1 or self.inplanes != planes * block.expansion: 264 | downsample = nn.Sequential( 265 | nn.Conv2d(self.inplanes, planes * block.expansion, 266 | kernel_size=downsample_kernel_size, stride=stride, 267 | padding=downsample_padding, bias=False), 268 | nn.BatchNorm2d(planes * block.expansion), 269 | ) 270 | 271 | layers = [] 272 | layers.append(block(self.inplanes, planes, groups, reduction, stride, 273 | downsample)) 274 | self.inplanes = planes * block.expansion 275 | for i in range(1, blocks): 276 | layers.append(block(self.inplanes, planes, groups, reduction)) 277 | 278 | return nn.Sequential(*layers) 279 | 280 | def features(self, x): 281 | x = self.layer0(x) 282 | x = self.layer1(x) 283 | x = self.layer2(x) 284 | x = self.layer3(x) 285 | x = self.layer4(x) 286 | return x 287 | 288 | def logits(self, x): 289 | x = self.avg_pool(x) 290 | if self.dropout is not None: 291 | x = self.dropout(x) 292 | x = x.view(x.size(0), -1) 293 | x = self.last_linear(x) 294 | return x 295 | 296 | def forward(self, x): 297 | x = self.features(x) 298 | x = self.logits(x) 
299 | return x 300 | 301 | 302 | def senet154(num_classes=10): 303 | model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, 304 | dropout_p=0.2, num_classes=num_classes) 305 | return model 306 | 307 | 308 | def se_resnet50(num_classes=10): 309 | model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, 310 | dropout_p=None, inplanes=64, input_3x3=False, 311 | downsample_kernel_size=1, downsample_padding=0, 312 | num_classes=num_classes) 313 | return model 314 | 315 | 316 | def se_resnet101(num_classes=10): 317 | model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16, 318 | dropout_p=None, inplanes=64, input_3x3=False, 319 | downsample_kernel_size=1, downsample_padding=0, 320 | num_classes=num_classes) 321 | return model 322 | 323 | 324 | def se_resnet152(num_classes=10): 325 | model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16, 326 | dropout_p=None, inplanes=64, input_3x3=False, 327 | downsample_kernel_size=1, downsample_padding=0, 328 | num_classes=num_classes) 329 | return model 330 | 331 | 332 | def se_resnext50_32x4d(num_classes=10, nc=3): 333 | model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, 334 | dropout_p=None, inplanes=64, input_3x3=False, 335 | downsample_kernel_size=1, downsample_padding=0, 336 | num_classes=num_classes, nc=nc) 337 | return model 338 | 339 | 340 | def se_resnext101_32x4d(num_classes=10, nc=3): 341 | model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, 342 | dropout_p=None, inplanes=64, input_3x3=False, 343 | downsample_kernel_size=1, downsample_padding=0, 344 | num_classes=num_classes, nc=nc) 345 | return model 346 | -------------------------------------------------------------------------------- /models/toxic_cnn.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class CNN(nn.Module): 7 | def __init__(self, num_classes=6, nl=2, nc=300, hidden_sz=128): 8 | super(CNN, self).__init__() 9 | self.hidden_sz = hidden_sz 10 | self.emb_sz = nc 11 | self.embeddings = None 12 | 13 | self.conv = nn.Sequential( 14 | nn.Conv1d(nc, hidden_sz, 3, padding=1), 15 | nn.ReLU(True), 16 | nn.Conv1d(hidden_sz, hidden_sz, 3, padding=1), 17 | nn.ReLU(True), 18 | nn.Conv1d(hidden_sz, hidden_sz, 3, padding=1), 19 | nn.ReLU() 20 | ) 21 | 22 | layers = [] 23 | for i in range(nl): 24 | if i == 0: 25 | layers.append(nn.Linear(hidden_sz * 3, hidden_sz)) 26 | else: 27 | layers.append(nn.Linear(hidden_sz, hidden_sz)) 28 | layers.append(nn.ReLU()) 29 | self.layers = nn.Sequential(*layers) 30 | self.output = nn.Linear(hidden_sz, num_classes) 31 | 32 | def init_embedding(self, vectors, n_tokens, device): 33 | self.embeddings = nn.Embedding(n_tokens, self.emb_sz).to(device) 34 | self.embeddings.weight.data.copy_(vectors.to(device)) 35 | 36 | def embed(self, data): 37 | embedded = self.embeddings(data) 38 | return embedded 39 | 40 | def forward(self, embedded): 41 | x = self.conv(embedded.permute(0, 2, 1)) 42 | 43 | avg_pool = F.adaptive_avg_pool1d(x, 1).view(embedded.size(0), -1) 44 | max_pool = F.adaptive_max_pool1d(x, 1).view(embedded.size(0), -1) 45 | x = torch.cat([avg_pool, max_pool, x.permute(0, 2, 1)[:, -1]], dim=1) 46 | x = self.layers(x) 47 | res = self.output(x) 48 | if res.size(1) == 1: 49 | res = res.squeeze(1) 50 | return res 51 | -------------------------------------------------------------------------------- /models/toxic_lstm.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class LSTM(nn.Module): 7 | def __init__(self, num_classes=6, nl=2, bidirectional=True, nc=300, hidden_sz=128): 8 | super(LSTM, self).__init__() 9 | self.hidden_sz = hidden_sz 10 | self.emb_sz = nc 11 | self.embeddings = None 12 | 13 | self.rnn = nn.LSTM(nc, hidden_sz, num_layers=2, bidirectional=bidirectional, dropout=0, batch_first=True) 14 | if bidirectional: 15 | hidden_sz = 2 * hidden_sz 16 | 17 | layers = [] 18 | for i in range(nl): 19 | if i == 0: 20 | layers.append(nn.Linear(hidden_sz * 3, hidden_sz)) 21 | else: 22 | layers.append(nn.Linear(hidden_sz, hidden_sz)) 23 | layers.append(nn.ReLU()) 24 | self.layers = nn.Sequential(*layers) 25 | self.output = nn.Linear(hidden_sz, num_classes) 26 | 27 | def init_embedding(self, vectors, n_tokens, device): 28 | self.embeddings = nn.Embedding(n_tokens, self.emb_sz).to(device) 29 | self.embeddings.weight.data.copy_(vectors.to(device)) 30 | 31 | def embed(self, data): 32 | self.h = self.init_hidden(data.size(0)) 33 | embedded = self.embeddings(data) 34 | return embedded 35 | 36 | def forward(self, embedded): 37 | rnn_out, self.h = self.rnn(embedded, (self.h, self.h)) 38 | 39 | avg_pool = F.adaptive_avg_pool1d(rnn_out.permute(0, 2, 1), 1).view(embedded.size(0), -1) 40 | max_pool = F.adaptive_max_pool1d(rnn_out.permute(0, 2, 1), 1).view(embedded.size(0), -1) 41 | x = torch.cat([avg_pool, max_pool, rnn_out[:, -1]], dim=1) 42 | x = self.layers(x) 43 | res = self.output(x) 44 | if res.size(1) == 1: 45 | res = res.squeeze(1) 46 | return res 47 | 48 | def init_hidden(self, batch_size): 49 | return torch.zeros((4, batch_size, self.hidden_sz), device="cuda") 50 | -------------------------------------------------------------------------------- /models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | # Code modified from https://github.com/xternalz/WideResNet-pytorch 2 | 3 | from __future__ import division 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 12 | super(BasicBlock, self).__init__() 13 | self.bn1 = nn.BatchNorm2d(in_planes) 14 | self.relu1 = nn.ReLU(inplace=True) 15 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | self.bn2 = nn.BatchNorm2d(out_planes) 18 | self.relu2 = nn.ReLU(inplace=True) 19 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 20 | padding=1, bias=False) 21 | self.droprate = dropRate 22 | self.equalInOut = (in_planes == out_planes) 23 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 24 | padding=0, bias=False) or None 25 | 26 | def forward(self, x): 27 | if not self.equalInOut: 28 | x = self.relu1(self.bn1(x)) 29 | else: 30 | out = self.relu1(self.bn1(x)) 31 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 32 | if self.droprate > 0: 33 | out = F.dropout(out, p=self.droprate, training=self.training) 34 | out = self.conv2(out) 35 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 36 | 37 | 38 | class NetworkBlock(nn.Module): 39 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 40 | super(NetworkBlock, self).__init__() 41 | self.layer 
= self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 42 | 43 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 44 | layers = [] 45 | for i in range(nb_layers): 46 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 47 | return nn.Sequential(*layers) 48 | 49 | def forward(self, x): 50 | return self.layer(x) 51 | 52 | 53 | class WideResNet(nn.Module): 54 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, nc=1): 55 | super(WideResNet, self).__init__() 56 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 57 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 58 | n = (depth - 4) // 6 59 | block = BasicBlock 60 | # 1st conv before any network block 61 | self.conv1 = nn.Conv2d(nc, nChannels[0], kernel_size=3, stride=1, 62 | padding=1, bias=False) 63 | # 1st block 64 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 65 | # 2nd block 66 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 67 | # 3rd block 68 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 69 | # global average pooling and classifier 70 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.fc = nn.Linear(nChannels[3], num_classes) 73 | self.nChannels = nChannels[3] 74 | 75 | for m in self.modules(): 76 | if isinstance(m, nn.Conv2d): 77 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 78 | m.weight.data.normal_(0, math.sqrt(2. / n)) 79 | elif isinstance(m, nn.BatchNorm2d): 80 | m.weight.data.fill_(1) 81 | m.bias.data.zero_() 82 | elif isinstance(m, nn.Linear): 83 | m.bias.data.zero_() 84 | 85 | def forward(self, x): 86 | out = self.conv1(x) 87 | out = self.block1(out) 88 | out = self.block2(out) 89 | out = self.block3(out) 90 | out = self.relu(self.bn1(out)) 91 | out = F.avg_pool2d(out, 7) 92 | out = out.view(-1, self.nChannels) 93 | return self.fc(out) 94 | 95 | 96 | def wrn(**kwargs): 97 | """ 98 | Constructs a Wide Residual Network. 99 | """ 100 | model = WideResNet(**kwargs) 101 | return model 102 | -------------------------------------------------------------------------------- /notebooks/grad_cam.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Grad-CAM\n", 8 | "\n", 9 | "In this notebook, we plot the [Grad-CAM](https://arxiv.org/abs/1610.02391) figures from the paper.
For more info and other examples, have a look at [our README](https://github.com/ecs-vlc/fmix).\n", 10 | "\n", 11 | "**Note**: The easiest way to use this is as a colab notebook, which allows you to dive in with no setup.\n", 12 | "\n", 13 | "First, we load dependencies and some data from CIFAR-10:" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "Files already downloaded and verified\n" 26 | ] 27 | } 28 | ], 29 | "source": [ 30 | "import numpy as np\n", 31 | "import torch\n", 32 | "import torch.nn as nn\n", 33 | "import torchvision\n", 34 | "from torchvision import transforms\n", 35 | "import models\n", 36 | "from torchbearer import Trial\n", 37 | "import cv2\n", 38 | "import torch.nn.functional as F\n", 39 | "\n", 40 | "inv_norm = transforms.Normalize((-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010), (1/0.2023, 1/0.1994, 1/0.2010))\n", 41 | "valset = torchvision.datasets.CIFAR10(root='./data/cifar', train=False, download=True,\n", 42 | " transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465),\n", 43 | " (0.2023, 0.1994, 0.2010)),]))\n", 44 | "\n", 45 | "valloader = torch.utils.data.DataLoader(valset, batch_size=1, shuffle=True, num_workers=8)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "## Model Wrapper\n", 53 | "\n", 54 | "Next, we define a `ResNet` wrapper that will iterate up to some given block `k`:" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "class ResNet_CAM(nn.Module):\n", 64 | " def __init__(self, net, layer_k):\n", 65 | " super(ResNet_CAM, self).__init__()\n", 66 | " self.resnet = net\n", 67 | " convs = nn.Sequential(*list(net.children())[:-1])\n", 68 | " self.first_part_conv = convs[:layer_k]\n", 69 | " self.second_part_conv = convs[layer_k:]\n", 70 | " self.linear = nn.Sequential(*list(net.children())[-1:])\n", 71 | " \n", 72 | " def forward(self, x):\n", 73 | " x = self.first_part_conv(x)\n", 74 | " x.register_hook(self.activations_hook)\n", 75 | " x = self.second_part_conv(x)\n", 76 | " x = F.adaptive_avg_pool2d(x, (1,1))\n", 77 | " x = x.view((1, -1))\n", 78 | " x = self.linear(x)\n", 79 | " return x\n", 80 | " \n", 81 | " def activations_hook(self, grad):\n", 82 | " self.gradients = grad\n", 83 | " \n", 84 | " def get_activations_gradient(self):\n", 85 | " return self.gradients\n", 86 | " \n", 87 | " def get_activations(self, x):\n", 88 | " return self.first_part_conv(x)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "## Grad-CAM\n", 96 | "\n", 97 | "Now for the Grad-CAM code, adapted from [implementing-grad-cam-in-pytorch](https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82):" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 3, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "def superimpose_heatmap(heatmap, img):\n", 107 | " resized_heatmap = cv2.resize(heatmap.numpy(), (img.shape[2], img.shape[3]))\n", 108 | " resized_heatmap = np.uint8(255 * resized_heatmap)\n", 109 | " resized_heatmap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)\n", 110 | " superimposed_img = torch.Tensor(cv2.cvtColor(resized_heatmap, cv2.COLOR_BGR2RGB)) * 0.006 + inv_norm(img[0]).permute(1,2,0)\n", 111 | " \n", 112 | " return 
superimposed_img\n", 113 | "\n", 114 | "def get_grad_cam(net, img):\n", 115 | " net.eval()\n", 116 | " pred = net(img)\n", 117 | " pred[:,pred.argmax(dim=1)].backward()\n", 118 | " gradients = net.get_activations_gradient()\n", 119 | " pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])\n", 120 | " activations = net.get_activations(img).detach()\n", 121 | " for i in range(activations.size(1)):\n", 122 | " activations[:, i, :, :] *= pooled_gradients[i]\n", 123 | " heatmap = torch.mean(activations, dim=1).squeeze()\n", 124 | " heatmap = np.maximum(heatmap, 0)\n", 125 | " heatmap /= torch.max(heatmap)\n", 126 | " \n", 127 | " return torch.Tensor(superimpose_heatmap(heatmap, img).permute(2,0,1))" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "## Models From `torch.hub`\n", 135 | "\n", 136 | "Next, we load in the models from `torch.hub`" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 4, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "name": "stderr", 146 | "output_type": "stream", 147 | "text": [ 148 | "Downloading: \"https://github.com/ecs-vlc/FMix/archive/master.zip\" to /home/ethan/.cache/torch/hub/master.zip\n", 149 | "Using cache found in /home/ethan/.cache/torch/hub/ecs-vlc_FMix_master\n", 150 | "Using cache found in /home/ethan/.cache/torch/hub/ecs-vlc_FMix_master\n", 151 | "Using cache found in /home/ethan/.cache/torch/hub/ecs-vlc_FMix_master\n" 152 | ] 153 | } 154 | ], 155 | "source": [ 156 | "baseline_net = torch.hub.load('ecs-vlc/FMix:master', 'preact_resnet18_cifar10_baseline', pretrained=True)\n", 157 | "\n", 158 | "fmix_net = torch.hub.load('ecs-vlc/FMix:master', 'preact_resnet18_cifar10_fmix', pretrained=True)\n", 159 | "\n", 160 | "mixup_net = torch.hub.load('ecs-vlc/FMix:master', 'preact_resnet18_cifar10_mixup', pretrained=True)\n", 161 | "\n", 162 | "fmix_plus_net = torch.hub.load('ecs-vlc/FMix:master', 'preact_resnet18_cifar10_fmixplusmixup', pretrained=True)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "## Plots\n", 170 | "\n", 171 | "Finally, generate and save the Grad-CAM plots:" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 5, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "layer_k = 4\n", 181 | "n_imgs = 10\n", 182 | "\n", 183 | "baseline_cam_net = ResNet_CAM(baseline_net, layer_k)\n", 184 | "fmix_cam_net = ResNet_CAM(fmix_net, layer_k)\n", 185 | "mixup_cam_net = ResNet_CAM(mixup_net, layer_k)\n", 186 | "fmix_plus_cam_net = ResNet_CAM(fmix_plus_net, layer_k)\n", 187 | "\n", 188 | "imgs = torch.Tensor(5, n_imgs, 3, 32, 32)\n", 189 | "it = iter(valloader)\n", 190 | "for i in range(0,n_imgs):\n", 191 | " img, _ = next(it)\n", 192 | " imgs[0][i] = inv_norm(img[0])\n", 193 | " imgs[1][i] = get_grad_cam(baseline_cam_net, img)\n", 194 | " imgs[2][i] = get_grad_cam(mixup_cam_net, img)\n", 195 | " imgs[3][i] = get_grad_cam(fmix_cam_net, img)\n", 196 | " imgs[4][i] = get_grad_cam(fmix_plus_cam_net, img)\n", 197 | "\n", 198 | "torchvision.utils.save_image(imgs.view(-1, 3, 32, 32), \"gradcam_at_layer\" + str(layer_k) + \".png\",nrow=n_imgs, pad_value=1)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [] 207 | } 208 | ], 209 | "metadata": { 210 | "kernelspec": { 211 | "display_name": "Python 3", 212 | "language": "python", 213 | "name": "python3" 214 | }, 215 | "language_info": { 216 | 
"codemirror_mode": { 217 | "name": "ipython", 218 | "version": 3 219 | }, 220 | "file_extension": ".py", 221 | "mimetype": "text/x-python", 222 | "name": "python", 223 | "nbconvert_exporter": "python", 224 | "pygments_lexer": "ipython3", 225 | "version": "3.7.3" 226 | } 227 | }, 228 | "nbformat": 4, 229 | "nbformat_minor": 2 230 | } 231 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | Pillow 5 | torchbearer 6 | pandas 7 | pytorch-lightning 8 | tensorflow 9 | pyarrow 10 | spacy 11 | torchtext 12 | scipy 13 | scikit-learn 14 | tensorboardX 15 | transformers -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ast 3 | from datetime import datetime 4 | 5 | import pandas as pd 6 | import torch 7 | import torchbearer 8 | from torch import nn, optim 9 | from torchbearer import Trial 10 | from torchbearer.callbacks import MultiStepLR, CosineAnnealingLR 11 | from torchbearer.callbacks import TensorBoard, TensorBoardText, Cutout, CutMix, RandomErase, on_forward_validation 12 | 13 | from datasets.datasets import ds, dsmeta, nlp_data 14 | from implementations.torchbearer_implementation import FMix, PointNetFMix 15 | from models.models import get_model 16 | from utils import RMixup, MSDAAlternator, WarmupLR 17 | from datasets.toxic import ToxicHelper 18 | 19 | # Setup 20 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 21 | parser.add_argument('--dataset', type=str, default='cifar10', 22 | choices=['cifar10', 'cifar100', 'reduced_cifar', 'fashion', 'imagenet', 'imagenet_hdf5', 'imagenet_a', 'tinyimagenet', 23 | 'commands', 'modelnet', 'toxic', 'toxic_bert', 'bengali_r', 'bengali_c', 'bengali_v', 'imdb', 'yelp_2', 'yelp_5']) 24 | parser.add_argument('--dataset-path', type=str, default=None, help='Optional dataset path') 25 | parser.add_argument('--split-fraction', type=float, default=1., help='Fraction of total data to train on for reduced_cifar dataset') 26 | parser.add_argument('--pointcloud-resolution', default=128, type=int, help='Resolution of pointclouds in modelnet dataset') 27 | parser.add_argument('--model', default="ResNet18", type=str, help='model type') 28 | parser.add_argument('--epoch', default=200, type=int, help='total epochs to run') 29 | parser.add_argument('--train-steps', type=int, default=None, help='Number of training steps to run per "epoch"') 30 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 31 | parser.add_argument('--lr-warmup', type=ast.literal_eval, default=False, help='Use lr warmup') 32 | parser.add_argument('--batch-size', default=128, type=int, help='batch size') 33 | parser.add_argument('--device', default='cuda', type=str, help='Device on which to run') 34 | parser.add_argument('--num-workers', default=7, type=int, help='Number of dataloader workers') 35 | 36 | parser.add_argument('--auto-augment', type=ast.literal_eval, default=False, help='Use auto augment with cifar10/100') 37 | parser.add_argument('--augment', type=ast.literal_eval, default=True, help='use standard augmentation (default: True)') 38 | parser.add_argument('--parallel', type=ast.literal_eval, default=False, help='Use DataParallel') 39 | parser.add_argument('--reload', type=ast.literal_eval, default=False, help='Set to resume 
training from model path') 40 | parser.add_argument('--verbose', type=int, default=2, choices=[0, 1, 2]) 41 | parser.add_argument('--seed', default=0, type=int, help='random seed') 42 | 43 | # Augs 44 | parser.add_argument('--random-erase', default=False, type=ast.literal_eval, help='Apply random erase') 45 | parser.add_argument('--cutout', default=False, type=ast.literal_eval, help='Apply Cutout') 46 | parser.add_argument('--msda-mode', default=None, type=str, choices=['fmix', 'cutmix', 'mixup', 'alt_mixup_fmix', 47 | 'alt_mixup_cutmix', 'alt_fmix_cutmix', 'None']) 48 | 49 | # Aug Params 50 | parser.add_argument('--alpha', default=1., type=float, help='mixup/fmix interpolation coefficient') 51 | parser.add_argument('--f-decay', default=3.0, type=float, help='decay power') 52 | parser.add_argument('--cutout_l', default=16, type=int, help='cutout/erase length') 53 | parser.add_argument('--reformulate', default=False, type=ast.literal_eval, help='Use reformulated fmix/mixup') 54 | 55 | # Scheduling 56 | parser.add_argument('--schedule', type=int, nargs='+', default=[100, 150], help='Decrease learning rate at these epochs.') 57 | parser.add_argument('--cosine-scheduler', type=ast.literal_eval, default=False, help='Set to use a cosine scheduler instead of step scheduler') 58 | 59 | # Cross validation 60 | parser.add_argument('--fold-path', type=str, default='./data/folds.npz', help='Path to object storing fold ids. Run-id 0 will regen this if not existing') 61 | parser.add_argument('--n-folds', type=int, default=6, help='Number of cross val folds') 62 | parser.add_argument('--fold', type=str, default='test', help='One of [1, ..., n] or "test"') 63 | 64 | # Logs 65 | parser.add_argument('--run-id', type=int, default=0, help='Run id') 66 | parser.add_argument('--log-dir', default='./logs/testing', help='Tensorboard log dir') 67 | parser.add_argument('--model-file', default='./saved_models/model.pt', help='Path under which to save model. 
eg ./model.py') 68 | args = parser.parse_args() 69 | 70 | 71 | if args.seed != 0: 72 | torch.manual_seed(args.seed) 73 | 74 | 75 | print('==> Preparing data..') 76 | data = ds[args.dataset] 77 | meta = dsmeta[args.dataset] 78 | classes, nc, size = meta['classes'], meta['nc'], meta['size'] 79 | 80 | trainset, valset, testset = data(args) 81 | 82 | # Toxic comments uses its own data loaders 83 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) if (trainset is not None) and (args.dataset not in nlp_data) else trainset 84 | valloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) if (valset is not None) and (args.dataset not in nlp_data) else valset 85 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) if (args.dataset not in nlp_data) else testset 86 | 87 | print('==> Building model..') 88 | net = get_model(args, classes, nc) 89 | net = nn.DataParallel(net) if args.parallel else net 90 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4) 91 | 92 | if (args.dataset in nlp_data) or ('modelnet' in args.dataset): 93 | optimizer = optim.Adam(net.parameters(), lr=args.lr) 94 | 95 | 96 | print('==> Setting up callbacks..') 97 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') + "-run-" + str(args.run_id) 98 | tboard = TensorBoard(write_graph=False, comment=current_time, log_dir=args.log_dir) 99 | tboardtext = TensorBoardText(write_epoch_metrics=False, comment=current_time, log_dir=args.log_dir) 100 | 101 | 102 | @torchbearer.callbacks.on_start 103 | def write_params(_): 104 | params = vars(args) 105 | params['schedule'] = str(params['schedule']) 106 | df = pd.DataFrame(params, index=[0]).transpose() 107 | tboardtext.get_writer(tboardtext.log_dir).add_text('params', df.to_html(), 1) 108 | 109 | 110 | modes = { 111 | 'fmix': FMix(decay_power=args.f_decay, alpha=args.alpha, size=size, max_soft=0, reformulate=args.reformulate), 112 | 'mixup': RMixup(args.alpha, reformulate=args.reformulate), 113 | 'cutmix': CutMix(args.alpha, classes, True), 114 | 'pointcloud_fmix': PointNetFMix(args.pointcloud_resolution, decay_power=args.f_decay, alpha=args.alpha, max_soft=0, 115 | reformulate=args.reformulate) 116 | } 117 | modes.update({ 118 | 'alt_mixup_fmix': MSDAAlternator(modes['fmix'], modes['mixup']), 119 | 'alt_mixup_cutmix': MSDAAlternator(modes['mixup'], modes['cutmix']), 120 | 'alt_fmix_cutmix': MSDAAlternator(modes['fmix'], modes['cutmix']), 121 | }) 122 | 123 | # Pointcloud fmix converts voxel grids back into point clouds after mixing 124 | mode = 'pointcloud_fmix' if (args.msda_mode == 'fmix' and args.dataset == 'modelnet') else args.msda_mode 125 | 126 | # CutMix callback returns mixed and original targets. 
We mix in the loss function instead 127 | @torchbearer.callbacks.on_sample 128 | def cutmix_reformat(state): 129 | state[torchbearer.Y_TRUE] = state[torchbearer.Y_TRUE][0] 130 | 131 | cb = [tboard, tboardtext, write_params, torchbearer.callbacks.MostRecent(args.model_file)] 132 | # Toxic helper needs to go before the msda to reshape the input 133 | cb.append(ToxicHelper(to_float=args.dataset != 'yelp_5')) if (args.dataset in ['toxic', 'imdb', 'yelp_2', 'yelp_5']) else [] 134 | cb.append(modes[mode]) if args.msda_mode not in [None, 'None'] else [] 135 | cb.append(Cutout(1, args.cutout_l)) if args.cutout else [] 136 | cb.append(RandomErase(1, args.cutout_l)) if args.random_erase else [] 137 | # WARNING: Schedulers appear to be broken (wrong lr output) in some versions of PyTorch, including 1.4. We used 1.3.1 138 | cb.append(MultiStepLR(args.schedule)) if not args.cosine_scheduler else cb.append(CosineAnnealingLR(args.epoch, eta_min=0.)) 139 | cb.append(WarmupLR(0.1, args.lr)) if args.lr_warmup else [] 140 | cb.append(cutmix_reformat) if args.msda_mode == 'cutmix' else [] 141 | 142 | # FMix loss is equivalent to mixup loss and works for all msda in torchbearer 143 | if args.msda_mode not in [None, 'None']: 144 | bce = True if (args.dataset in ['toxic', 'toxic_bert', 'imdb', 'yelp_2']) else False 145 | criterion = modes['fmix'].loss(bce) 146 | elif args.dataset in ['toxic', 'toxic_bert', 'imdb', 'yelp_2']: 147 | criterion = nn.BCEWithLogitsLoss() 148 | else: 149 | criterion = nn.CrossEntropyLoss() 150 | 151 | metrics_append = [] 152 | if 'bengali' in args.dataset: 153 | from utils.macro_recall import MacroRecall 154 | metrics_append = [MacroRecall()] 155 | elif 'imagenet' in args.dataset: 156 | metrics_append = ['top_5_acc'] 157 | elif 'toxic' in args.dataset: 158 | from torchbearer.metrics import to_dict, EpochLambda 159 | 160 | @to_dict 161 | class RocAucScore(EpochLambda): 162 | def __init__(self): 163 | import sklearn.metrics 164 | 165 | super().__init__('roc_auc_score', 166 | lambda y_pred, y_true: sklearn.metrics.roc_auc_score(y_true.cpu().numpy(), y_pred.detach().sigmoid().cpu().numpy()), 167 | running=False) 168 | metrics_append = [RocAucScore()] 169 | 170 | if args.dataset == 'imagenet_a': 171 | from datasets.imagenet_a import indices_in_1k 172 | 173 | @on_forward_validation 174 | def map_preds(state): 175 | state[torchbearer.PREDICTION] = state[torchbearer.PREDICTION][:, indices_in_1k] 176 | cb.append(map_preds) 177 | 178 | print('==> Training model..') 179 | 180 | acc = 'acc' 181 | if args.dataset in ['toxic', 'toxic_bert', 'imdb', 'yelp_2']: 182 | acc = 'binary_acc' 183 | 184 | trial = Trial(net, optimizer, criterion, metrics=[acc, 'loss', 'lr'] + metrics_append, callbacks=cb) 185 | trial.with_generators(train_generator=trainloader, val_generator=valloader, train_steps=args.train_steps, test_generator=testloader).to(args.device) 186 | 187 | if args.reload: 188 | state = torch.load(args.model_file) 189 | trial.load_state_dict(state, resume=args.dataset != 'imagenet_a') 190 | trial.replay() 191 | 192 | if trainloader is not None: 193 | trial.run(args.epoch, verbose=args.verbose) 194 | trial.evaluate(data_key=torchbearer.TEST_DATA) 195 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_warmup import WarmupLR 2 | from .msda_alternator import MSDAAlternator 3 | from .reformulated_mixup import RMixup 4 | from .cross_val import split, gen_folds 
5 | from .reduced_dataset_splitter import EqualSplitter 6 | from .auto_augment.auto_augment import auto_augment, _fa_reduced_cifar10 -------------------------------------------------------------------------------- /utils/auto_augment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecs-vlc/FMix/e5991dca018882734c8ea63599f10dfbe67fa0ae/utils/auto_augment/__init__.py -------------------------------------------------------------------------------- /utils/auto_augment/auto_augment_aug_list.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 6 | import numpy as np 7 | import torch 8 | 9 | random_mirror = True 10 | 11 | 12 | def ShearX(img, v): # [-0.3, 0.3] 13 | assert -0.3 <= v <= 0.3 14 | if random_mirror and random.random() > 0.5: 15 | v = -v 16 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 17 | 18 | 19 | def ShearY(img, v): # [-0.3, 0.3] 20 | assert -0.3 <= v <= 0.3 21 | if random_mirror and random.random() > 0.5: 22 | v = -v 23 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 24 | 25 | 26 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 27 | assert -0.45 <= v <= 0.45 28 | if random_mirror and random.random() > 0.5: 29 | v = -v 30 | v = v * img.size[0] 31 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 32 | 33 | 34 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 35 | assert -0.45 <= v <= 0.45 36 | if random_mirror and random.random() > 0.5: 37 | v = -v 38 | v = v * img.size[1] 39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 40 | 41 | 42 | def TranslateXAbs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 43 | assert 0 <= v <= 10 44 | if random.random() > 0.5: 45 | v = -v 46 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 47 | 48 | 49 | def TranslateYAbs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 50 | assert 0 <= v <= 10 51 | if random.random() > 0.5: 52 | v = -v 53 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 54 | 55 | 56 | def Rotate(img, v): # [-30, 30] 57 | assert -30 <= v <= 30 58 | if random_mirror and random.random() > 0.5: 59 | v = -v 60 | return img.rotate(v) 61 | 62 | 63 | def AutoContrast(img, _): 64 | return PIL.ImageOps.autocontrast(img) 65 | 66 | 67 | def Invert(img, _): 68 | return PIL.ImageOps.invert(img) 69 | 70 | 71 | def Equalize(img, _): 72 | return PIL.ImageOps.equalize(img) 73 | 74 | 75 | def Flip(img, _): # not from the paper 76 | return PIL.ImageOps.mirror(img) 77 | 78 | 79 | def Solarize(img, v): # [0, 256] 80 | assert 0 <= v <= 256 81 | return PIL.ImageOps.solarize(img, v) 82 | 83 | 84 | def Posterize(img, v): # [4, 8] 85 | assert 4 <= v <= 8 86 | v = int(v) 87 | return PIL.ImageOps.posterize(img, v) 88 | 89 | 90 | def Posterize2(img, v): # [0, 4] 91 | assert 0 <= v <= 4 92 | v = int(v) 93 | return PIL.ImageOps.posterize(img, v) 94 | 95 | 96 | def Contrast(img, v): # [0.1,1.9] 97 | assert 0.1 <= v <= 1.9 98 | return PIL.ImageEnhance.Contrast(img).enhance(v) 99 | 100 | 101 | def Color(img, v): # [0.1,1.9] 102 | assert 0.1 <= v <= 1.9 103 | return PIL.ImageEnhance.Color(img).enhance(v) 104 | 105 | 106 | def Brightness(img, v): # [0.1,1.9] 107 | 
assert 0.1 <= v <= 1.9 108 | return PIL.ImageEnhance.Brightness(img).enhance(v) 109 | 110 | 111 | def Sharpness(img, v): # [0.1,1.9] 112 | assert 0.1 <= v <= 1.9 113 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 114 | 115 | 116 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 117 | assert 0.0 <= v <= 0.2 118 | if v <= 0.: 119 | return img 120 | 121 | v = v * img.size[0] 122 | return CutoutAbs(img, v) 123 | 124 | 125 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 126 | # assert 0 <= v <= 20 127 | if v < 0: 128 | return img 129 | w, h = img.size 130 | x0 = np.random.uniform(w) 131 | y0 = np.random.uniform(h) 132 | 133 | x0 = int(max(0, x0 - v / 2.)) 134 | y0 = int(max(0, y0 - v / 2.)) 135 | x1 = min(w, x0 + v) 136 | y1 = min(h, y0 + v) 137 | 138 | xy = (x0, y0, x1, y1) 139 | color = (125, 123, 114) 140 | # color = (0, 0, 0) 141 | img = img.copy() 142 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 143 | return img 144 | 145 | 146 | def SamplePairing(imgs): # [0, 0.4] 147 | def f(img1, v): 148 | i = np.random.choice(len(imgs)) 149 | img2 = PIL.Image.fromarray(imgs[i]) 150 | return PIL.Image.blend(img1, img2, v) 151 | 152 | return f 153 | 154 | 155 | def augment_list(for_autoaug=True): # 16 oeprations and their ranges 156 | l = [ 157 | (ShearX, -0.3, 0.3), # 0 158 | (ShearY, -0.3, 0.3), # 1 159 | (TranslateX, -0.45, 0.45), # 2 160 | (TranslateY, -0.45, 0.45), # 3 161 | (Rotate, -30, 30), # 4 162 | (AutoContrast, 0, 1), # 5 163 | (Invert, 0, 1), # 6 164 | (Equalize, 0, 1), # 7 165 | (Solarize, 0, 256), # 8 166 | (Posterize, 4, 8), # 9 167 | (Contrast, 0.1, 1.9), # 10 168 | (Color, 0.1, 1.9), # 11 169 | (Brightness, 0.1, 1.9), # 12 170 | (Sharpness, 0.1, 1.9), # 13 171 | (Cutout, 0, 0.2), # 14 172 | # (SamplePairing(imgs), 0, 0.4), # 15 173 | ] 174 | if for_autoaug: 175 | l += [ 176 | (CutoutAbs, 0, 20), # compatible with auto-augment 177 | (Posterize2, 0, 4), # 9 178 | (TranslateXAbs, 0, 10), # 9 179 | (TranslateYAbs, 0, 10), # 9 180 | ] 181 | return l 182 | 183 | 184 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()} 185 | 186 | 187 | def get_augment(name): 188 | return augment_dict[name] 189 | 190 | 191 | def apply_augment(img, name, level): 192 | augment_fn, low, high = get_augment(name) 193 | return augment_fn(img.copy(), level * (high - low) + low) 194 | 195 | 196 | class Lighting(object): 197 | """Lighting noise(AlexNet - style PCA - based noise)""" 198 | 199 | def __init__(self, alphastd, eigval, eigvec): 200 | self.alphastd = alphastd 201 | self.eigval = torch.Tensor(eigval) 202 | self.eigvec = torch.Tensor(eigvec) 203 | 204 | def __call__(self, img): 205 | if self.alphastd == 0: 206 | return img 207 | 208 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 209 | rgb = self.eigvec.type_as(img).clone() \ 210 | .mul(alpha.view(1, 3).expand(3, 3)) \ 211 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 212 | .sum(1).squeeze() 213 | 214 | return img.add(rgb.view(3, 1, 1).expand_as(img)) -------------------------------------------------------------------------------- /utils/bengali_evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torchbearer 4 | from torchbearer import metrics, Trial 5 | from datasets.datasets import bengali 6 | import argparse 7 | from models import se_resnext50_32x4d 8 | 9 | parser = argparse.ArgumentParser(description='Bengali Evaluate') 10 | 11 | parser.add_argument('--dataset', type=str, default='bengali', help='Optional 
dataset path') 12 | parser.add_argument('--dataset-path', type=str, default=None, help='Optional dataset path') 13 | parser.add_argument('--fold-path', type=str, default='./data/folds.npz', help='Path to object storing fold ids. Run-id 0 will regen this if not existing') 14 | parser.add_argument('--fold', type=str, default='test', help='One of [1, ..., n] or "test"') 15 | parser.add_argument('--run-id', type=int, default=0, help='Run id') 16 | 17 | parser.add_argument('--model-r', type=str, default=None, help='Root model') 18 | parser.add_argument('--model-v', type=str, default=None, help='Vowel model') 19 | parser.add_argument('--model-c', type=str, default=None, help='Consonant model') 20 | 21 | args = parser.parse_args() 22 | 23 | _, __, testset = bengali(args) 24 | 25 | testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=6) 26 | 27 | 28 | class BengaliModelWrapper(nn.Module): 29 | def __init__(self, model_r, model_v, model_c): 30 | super().__init__() 31 | 32 | self.model_r = model_r 33 | self.model_v = model_v 34 | self.model_c = model_c 35 | 36 | def forward(self, x): 37 | return self.model_r(x), self.model_v(x), self.model_c(x) 38 | 39 | 40 | @metrics.default_for_key('grapheme') 41 | @metrics.mean 42 | class GraphemeAccuracy(metrics.Metric): 43 | def __init__(self): 44 | super().__init__('grapheme_acc') 45 | 46 | def process(self, *args): 47 | state = args[0] 48 | r_pred, v_pred, c_pred = state[torchbearer.PREDICTION] 49 | r_true, v_true, c_true = state[torchbearer.TARGET] 50 | 51 | _, r_pred = torch.max(r_pred, 1) 52 | _, v_pred = torch.max(v_pred, 1) 53 | _, c_pred = torch.max(c_pred, 1) 54 | 55 | r = (r_pred == r_true) 56 | v = (v_pred == v_true) 57 | c = (c_pred == c_true) 58 | return torch.stack((r, v, c), dim=1).all(1).float() 59 | 60 | 61 | model_r = se_resnext50_32x4d(168, 1) 62 | model_v = se_resnext50_32x4d(11, 1) 63 | model_c = se_resnext50_32x4d(7, 1) 64 | 65 | model_r.load_state_dict(torch.load(args.model_r, map_location='cpu')[torchbearer.MODEL]) 66 | model_v.load_state_dict(torch.load(args.model_v, map_location='cpu')[torchbearer.MODEL]) 67 | model_c.load_state_dict(torch.load(args.model_c, map_location='cpu')[torchbearer.MODEL]) 68 | 69 | model = BengaliModelWrapper(model_r, model_v, model_c) 70 | 71 | trial = Trial(model, criterion=lambda state: None, metrics=['grapheme']).with_test_generator(testloader).to('cuda') 72 | trial.evaluate(data_key=torchbearer.TEST_DATA) 73 | -------------------------------------------------------------------------------- /utils/convert_imagenet_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | ImageNet models trained with the original imagenet_hdf5 data set have their outputs in the wrong order. This loads a 3 | model and re-orders the output weights to be consistent with other ImageNet models. 
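Expects --dataset-path to point at the imagenet_hdf5 root, so that <dataset-path>/train contains the per-class .hdf5 files.
Example invocation (paths are illustrative): python utils/convert_imagenet_model.py --dataset-path /data/imagenet_hdf5 --model-file ./saved_models/model.pt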
4 | """ 5 | import argparse 6 | import os 7 | from torchvision.models import resnet101 8 | import torch 9 | import torchbearer 10 | from torch import nn 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--dataset-path', type=str, default=None, help='ImageNet path') 14 | parser.add_argument('--model-file') 15 | args = parser.parse_args() 16 | 17 | root = os.path.join(args.dataset_path, 'train') 18 | 19 | old_classes = list(filter(lambda f: '.hdf5' in f, os.listdir(root))) 20 | new_classes = sorted(old_classes) 21 | 22 | model = nn.DataParallel(resnet101(False)) 23 | sd = torch.load(args.model_file, map_location='cpu')[torchbearer.MODEL] 24 | model.load_state_dict(sd) 25 | model = model.module 26 | 27 | new_weights = torch.zeros_like(model.fc.weight.data) 28 | new_bias = torch.zeros_like(model.fc.bias.data) 29 | 30 | for layer in range(1000): 31 | new_layer = new_classes.index(old_classes[layer]) 32 | 33 | new_weights[layer, :] = model.fc.weight[layer, :] 34 | new_bias[layer] = model.fc.bias[layer] 35 | 36 | model.fc.weight.data = new_weights.data 37 | model.fc.bias.data = new_bias.data 38 | 39 | torch.save(model.state_dict(), args.model_file + '_converted.pt') 40 | -------------------------------------------------------------------------------- /utils/cross_val.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import os 4 | from torchbearer.cv_utils import DatasetValidationSplitter 5 | 6 | 7 | def gen_folds(args, dataset, test_size): 8 | from sklearn.utils.validation import check_random_state 9 | from sklearn.model_selection._split import _validate_shuffle_split 10 | 11 | n_samples = len(dataset) 12 | n_train, n_test = _validate_shuffle_split(n_samples, test_size, None, default_test_size=0.1) 13 | rng = check_random_state(args.seed) 14 | 15 | train_folds = [] 16 | test_folds = [] 17 | 18 | for i in range(args.n_folds): 19 | # random partition 20 | permutation = rng.permutation(n_samples) 21 | ind_test = permutation[:n_test] 22 | ind_train = permutation[n_test:(n_test + n_train)] 23 | train_folds.append(ind_train) 24 | test_folds.append(ind_test) 25 | 26 | train_folds, test_folds = np.stack(train_folds), np.stack(test_folds) 27 | np.savez(args.fold_path, train=train_folds, test=test_folds) 28 | 29 | 30 | def split(func): 31 | def splitting(args): 32 | try: 33 | trainset, testset = func(args) 34 | 35 | if args.fold == 'test': 36 | return trainset, testset, testset 37 | except: 38 | trainset = func(args) 39 | testset = None 40 | 41 | if args.run_id == 0 and not os.path.exists(args.fold_path): 42 | gen_folds(args, trainset, len(trainset) // args.n_folds) 43 | else: 44 | time.sleep(3) 45 | 46 | folds = np.load(args.fold_path) 47 | train_ids, val_ids = folds['train'][int(args.fold)], folds['test'][int(args.fold)] 48 | 49 | splitter = DatasetValidationSplitter(len(trainset), 0.1) 50 | splitter.train_ids, splitter.valid_ids = train_ids, val_ids 51 | 52 | trainset, valset = splitter.get_train_dataset(trainset), splitter.get_val_dataset(trainset) 53 | return trainset, valset, (testset if testset is not None else valset) 54 | 55 | return splitting 56 | -------------------------------------------------------------------------------- /utils/imagenet_to_hdf5.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import multiprocessing 4 | import os 5 | # from PIL import Image 6 | import pickle 7 | import process 8 | 9 | SIZE = 256 10 | 
11 | imagenet_dir = '/ssd/ILSVRC2012/train' 12 | target_dir = '/data/imagenet_hdf5/train' 13 | 14 | subdirs = [d for d in os.listdir(imagenet_dir) if os.path.isdir(os.path.join(imagenet_dir, d))] 15 | 16 | num_cpus = multiprocessing.cpu_count() 17 | pool = multiprocessing.Pool(num_cpus) 18 | 19 | dest_list = [] 20 | 21 | for subdir in subdirs: 22 | from_dir = os.path.join(imagenet_dir, subdir) 23 | to_path = os.path.join(target_dir, subdir + '.hdf5') 24 | 25 | paths = [os.path.join(from_dir, d) for d in os.listdir(from_dir)] 26 | 27 | # arr = np.array(pool.map(process.process, paths), dtype='uint8') 28 | images = pool.map(process.read_bytes, paths) 29 | with h5py.File(to_path, 'w') as f: 30 | dt = h5py.special_dtype(vlen=np.dtype('uint8')) 31 | dset = f.create_dataset('data', (len(paths), ), dtype=dt) 32 | for i, image in enumerate(images): 33 | dset[i] = np.fromstring(image, dtype='uint8') 34 | 35 | for i, path in enumerate(paths): 36 | dest_list.append((subdir, i)) 37 | 38 | print(subdir) 39 | 40 | pickle.dump(dest_list, open(os.path.join(target_dir, 'dest.p'), 'wb')) 41 | -------------------------------------------------------------------------------- /utils/lr_warmup.py: -------------------------------------------------------------------------------- 1 | from torchbearer.callbacks import Callback 2 | import torchbearer 3 | 4 | 5 | class WarmupLR(Callback): 6 | def __init__(self, min_lr, max_lr, warmup_period=5): 7 | super().__init__() 8 | self.min_lr = min_lr 9 | self.max_lr = max_lr 10 | self.t = warmup_period 11 | 12 | def on_start_training(self, state): 13 | super().on_start_training(state) 14 | if state[torchbearer.EPOCH] < self.t: 15 | delta = (self.t - state[torchbearer.EPOCH])/self.t 16 | opt = state[torchbearer.OPTIMIZER] 17 | 18 | for pg in opt.param_groups: 19 | pg['lr'] = self.min_lr * delta + self.max_lr * (1-delta) 20 | -------------------------------------------------------------------------------- /utils/macro_recall.py: -------------------------------------------------------------------------------- 1 | from torchbearer.metrics import default_for_key, to_dict, EpochLambda 2 | 3 | 4 | @default_for_key('macro_recall') 5 | @to_dict 6 | class MacroRecall(EpochLambda): 7 | def __init__(self): 8 | from sklearn import metrics 9 | 10 | def process(y_pred): 11 | return y_pred.max(1)[1] 12 | 13 | super().__init__('macro_recall', lambda y_pred, y_true: metrics.recall_score(y_true.cpu().numpy(), process(y_pred).detach().cpu().numpy(), average='macro')) 14 | -------------------------------------------------------------------------------- /utils/msda_alternator.py: -------------------------------------------------------------------------------- 1 | from torchbearer import Callback 2 | 3 | 4 | class MSDAAlternator(Callback): 5 | def __init__(self, msda_a, msda_b, n_a=1, n_b=1): 6 | super().__init__() 7 | self.augs = ((msda_a, n_a), (msda_b, n_b)) 8 | self.current_aug = 0 9 | self.current_steps = 0 10 | 11 | def on_sample(self, state): 12 | super().on_sample(state) 13 | 14 | aug, steps = self.augs[self.current_aug] 15 | aug.on_sample(state) 16 | 17 | self.current_steps = self.current_steps + 1 18 | if self.current_steps >= steps: 19 | self.current_aug = (self.current_aug + 1) % 2 20 | self.current_steps = 0 -------------------------------------------------------------------------------- /utils/process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | SIZE = 256 5 | 6 | 7 | def process(path): 8 | 
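    # Decode the image, force RGB and resize to SIZE x SIZE with bilinear filtering; imagenet_to_hdf5.py currently stores raw bytes via read_bytes instead of calling this.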
with open(path, 'rb') as f: 9 | img = Image.open(f) 10 | img = img.convert('RGB') 11 | return np.array(img.resize((SIZE, SIZE), Image.BILINEAR)) 12 | 13 | 14 | def read_bytes(path): 15 | f = open(path, 'rb') 16 | return f.read() 17 | -------------------------------------------------------------------------------- /utils/reduced_dataset_splitter.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | 4 | from torchbearer.cv_utils import SubsetDataset 5 | 6 | 7 | class EqualSplitter: 8 | """ Splits a dataset into two parts, taking equal number of samples from each class 9 | 10 | :param dataset: dataset to split; must be indexable and yield (input, label) pairs 11 | :param split_fraction: fraction of the data (drawn equally from each class) placed in the split returned by get_train_dataset 12 | """ 13 | def __init__(self, dataset, split_fraction): 14 | self.ds = dataset 15 | self.split_fraction = split_fraction 16 | self.train_ids, self.valid_ids = self._gen_split_ids() 17 | 18 | def _gen_split_ids(self): 19 | classes = defaultdict(list) 20 | for i in range(len(self.ds)): 21 | _, label = self.ds[i] 22 | classes[label].append(i) 23 | 24 | nc = len(classes.keys()) 25 | cut_per_class = int(len(self.ds) * self.split_fraction / nc) 26 | 27 | cut_indexes = [] 28 | retained_indexes = [] 29 | for c in classes.keys(): 30 | random.shuffle(classes[c]) 31 | cut_indexes += classes[c][:cut_per_class] 32 | retained_indexes += classes[c][cut_per_class:] 33 | return cut_indexes, retained_indexes 34 | 35 | def get_train_dataset(self): 36 | """ Creates a training dataset from existing dataset 37 | 38 | Args: 39 | dataset (torch.utils.data.Dataset): Dataset to be split into a training dataset 40 | 41 | Returns: 42 | torch.utils.data.Dataset: Training dataset split from whole dataset 43 | """ 44 | return SubsetDataset(self.ds, self.train_ids) 45 | 46 | def get_val_dataset(self): 47 | """ Creates a validation dataset from existing dataset 48 | 49 | Args: 50 | dataset (torch.utils.data.Dataset): Dataset to be split into a validation dataset 51 | 52 | Returns: 53 | torch.utils.data.Dataset: Validation dataset split from whole dataset 54 | """ 55 | return SubsetDataset(self.ds, self.valid_ids) 56 | 57 | -------------------------------------------------------------------------------- /utils/reformulated_mixup.py: -------------------------------------------------------------------------------- 1 | # Code modified from https://github.com/pytorchbearer/torchbearer 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.distributions.beta import Beta 6 | 7 | import torchbearer 8 | from torchbearer.callbacks import Callback 9 | from torchbearer import metrics as m 10 | 11 | 12 | @m.running_mean 13 | @m.mean 14 | class MixupAcc(m.AdvancedMetric): 15 | def __init__(self): 16 | super(MixupAcc, self).__init__('mixup_acc') 17 | self.cat_acc = m.CategoricalAccuracy().root 18 | 19 | def process_train(self, *args): 20 | super(MixupAcc, self).process_train(*args) 21 | state = args[0] 22 | 23 | target1 = state[torchbearer.Y_TRUE] 24 | target2 = target1[state[torchbearer.MIXUP_PERMUTATION]] 25 | _state = args[0].copy() 26 | _state[torchbearer.Y_TRUE] = target1 27 | acc1 = self.cat_acc.process(_state) 28 | 29 | _state = args[0].copy() 30 | _state[torchbearer.Y_TRUE] = target2 31 | acc2 = self.cat_acc.process(_state) 32 | 33 | return acc1 * state[torchbearer.MIXUP_LAMBDA] + acc2 * (1 - state[torchbearer.MIXUP_LAMBDA]) 34 | 35 | def process_validate(self, *args): 36 | super(MixupAcc, self).process_validate(*args) 37 | 38 | return self.cat_acc.process(*args) 39 | 40 | def reset(self, state):
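        # Delegate reset to the wrapped CategoricalAccuracy so the running accuracy starts fresh each epoch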
41 | self.cat_acc.reset(state) 42 | 43 | 44 | class RMixup(Callback): 45 | """Perform mixup on the model inputs. Requires use of :meth:`MixupInputs.loss`, otherwise lambdas can be found in 46 | state under :attr:`.MIXUP_LAMBDA`. Model targets will be a tuple containing the original target and permuted target. 47 | 48 | .. note:: 49 | 50 | The accuracy metric for mixup is different on training to deal with the different targets, 51 | but for validation it is exactly the categorical accuracy, despite being called "val_mixup_acc" 52 | 53 | Example: :: 54 | 55 | >>> from torchbearer import Trial 56 | >>> from torchbearer.callbacks import Mixup 57 | 58 | # Example Trial which does Mixup regularisation 59 | >>> mixup = Mixup(0.9) 60 | >>> trial = Trial(None, criterion=Mixup.mixup_loss, callbacks=[mixup], metrics=['acc']) 61 | 62 | Args: 63 | lam (float): Mixup inputs by fraction lam. If RANDOM, choose lambda from Beta(alpha, alpha). Else, lambda=lam 64 | alpha (float): The alpha value to use in the beta distribution. 65 | """ 66 | RANDOM = -10.0 67 | 68 | def __init__(self, alpha=1.0, lam=RANDOM, reformulate=False): 69 | super(RMixup, self).__init__() 70 | self.alpha = alpha 71 | self.lam = lam 72 | self.reformulate = reformulate 73 | self.distrib = Beta(self.alpha, self.alpha) if not reformulate else Beta(self.alpha + 1, self.alpha) 74 | 75 | @staticmethod 76 | def mixup_loss(state): 77 | """The standard cross entropy loss formulated for mixup (weighted combination of `F.cross_entropy`). 78 | 79 | Args: 80 | state: The current :class:`Trial` state. 81 | """ 82 | input, target = state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE] 83 | 84 | if state[torchbearer.DATA] is torchbearer.TRAIN_DATA: 85 | y1, y2 = target 86 | return F.cross_entropy(input, y1) * state[torchbearer.MIXUP_LAMBDA] + F.cross_entropy(input, y2) * (1-state[torchbearer.MIXUP_LAMBDA]) 87 | else: 88 | return F.cross_entropy(input, target) 89 | 90 | def on_sample(self, state): 91 | if self.lam is RMixup.RANDOM: 92 | if self.alpha > 0: 93 | lam = self.distrib.sample() 94 | else: 95 | lam = 1.0 96 | else: 97 | lam = self.lam 98 | 99 | state[torchbearer.MIXUP_LAMBDA] = lam 100 | 101 | state[torchbearer.MIXUP_PERMUTATION] = torch.randperm(state[torchbearer.X].size(0)) 102 | state[torchbearer.X] = state[torchbearer.X] * state[torchbearer.MIXUP_LAMBDA] + \ 103 | state[torchbearer.X][state[torchbearer.MIXUP_PERMUTATION],:] \ 104 | * (1 - state[torchbearer.MIXUP_LAMBDA]) 105 | 106 | if self.reformulate: 107 | state[torchbearer.MIXUP_LAMBDA] = 1 108 | 109 | 110 | from torchbearer.metrics import default as d 111 | d.__loss_map__[RMixup.mixup_loss.__name__] = MixupAcc 112 | --------------------------------------------------------------------------------
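The following is a minimal, self-contained sketch of how the FMix callback and its loss from implementations/torchbearer_implementation.py are wired into a torchbearer Trial, mirroring the callback/criterion/Trial pattern used in trainer.py. The network (torchvision resnet18), the CIFAR-10 loaders, size=(32, 32) and the optimiser settings are illustrative stand-ins rather than the exact experiment configuration; fmix.loss(False) follows the positional bce flag that trainer.py passes.

import torch
from torch import optim
from torchbearer import Trial
from torchvision import datasets, transforms
from torchvision.models import resnet18

from implementations.torchbearer_implementation import FMix

# FMix callback configured as in trainer.py (decay_power, alpha, size, max_soft, reformulate)
fmix = FMix(decay_power=3.0, alpha=1.0, size=(32, 32), max_soft=0, reformulate=False)

# Illustrative CIFAR-10 loaders; trainer.py builds these from datasets/datasets.py instead
transform = transforms.ToTensor()
trainset = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
valset = datasets.CIFAR10('./data', train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=7)
valloader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False, num_workers=7)

net = resnet18(num_classes=10)  # stand-in for models.models.get_model(args, classes, nc)
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
criterion = fmix.loss(False)  # mixup-style loss; True selects BCE-with-logits as in trainer.py

trial = Trial(net, optimizer, criterion, metrics=['acc', 'loss'], callbacks=[fmix])
trial.with_generators(train_generator=trainloader, val_generator=valloader).to('cuda')
trial.run(200, verbose=2)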