├── .gitignore
├── LICENSE
├── README.md
├── analysis
│   ├── README.md
│   ├── adversarial_linf.py
│   ├── mine.py
│   ├── train_vaes.py
│   ├── vae.py
│   ├── vae_mi.py
│   └── vgg.py
├── datasets
│   ├── __init__.py
│   ├── bengali.py
│   ├── datasets.py
│   ├── fashion.py
│   ├── google_commands
│   │   ├── README.md
│   │   ├── google_commands.py
│   │   ├── sft_transforms.py
│   │   └── transforms.py
│   ├── imagenet_a.py
│   ├── imagenet_hdf5.py
│   ├── tiny_imagenet.py
│   ├── toxic.py
│   └── toxic_bert.py
├── experiments
│   ├── bengali_experiment.sh
│   ├── cifar_experiment.sh
│   ├── fashion_experiments.sh
│   ├── google_commands_experiment.sh
│   ├── imagenet_experiment.sh
│   ├── modelnet_experiment.sh
│   ├── tiny_imagenet_experiment.sh
│   └── toxic_experiment.sh
├── fmix.py
├── fmix_3d.gif
├── fmix_example.png
├── hubconf.py
├── implementations
│   ├── lightning.py
│   ├── tensorflow_implementation.py
│   ├── test_lightning.py
│   ├── test_tensorflow.py
│   ├── test_torchbearer.py
│   └── torchbearer_implementation.py
├── models
│   ├── __init__.py
│   ├── bert.py
│   ├── densenet3.py
│   ├── models.py
│   ├── pyramid.py
│   ├── resnet.py
│   ├── senet.py
│   ├── toxic_cnn.py
│   ├── toxic_lstm.py
│   └── wide_resnet.py
├── notebooks
│   ├── example_masks.ipynb
│   └── grad_cam.ipynb
├── requirements.txt
├── trainer.py
└── utils
    ├── __init__.py
    ├── auto_augment
    │   ├── __init__.py
    │   ├── auto_augment.py
    │   └── auto_augment_aug_list.py
    ├── bengali_evaluate.py
    ├── convert_imagenet_model.py
    ├── cross_val.py
    ├── imagenet_to_hdf5.py
    ├── lr_warmup.py
    ├── macro_recall.py
    ├── msda_alternator.py
    ├── process.py
    ├── reduced_dataset_splitter.py
    └── reformulated_mixup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Ours
132 | implementations/data
133 | implementations/lightning_logs
134 | data/
135 | .idea/
136 | saved_models/
137 | logs/
138 | notebooks/*
139 | !notebooks/*.ipynb
140 | .vector_cache
141 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Vision, Learning and Control Research Group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # FMix
4 |
5 | This repository contains the __official__ implementation of the paper ['FMix: Enhancing Mixed Sample Data Augmentation'](https://arxiv.org/abs/2002.12047)
6 |
7 | [Image Classification on CIFAR-10 (Papers With Code)](https://paperswithcode.com/sota/image-classification-on-cifar-10?p=understanding-and-enhancing-mixed-sample-data)
8 | [Image Classification on Fashion-MNIST (Papers With Code)](https://paperswithcode.com/sota/image-classification-on-fashion-mnist?p=understanding-and-enhancing-mixed-sample-data)
9 |
10 |
11 | ArXiv •
12 | Papers With Code •
13 | About •
14 | Experiments •
15 | Implementations •
16 | Pre-trained Models
17 |
18 |
19 |
20 |
21 |
22 | Dive in with our example notebook in Colab!
23 |
24 |
25 |
26 |
27 | ## About
28 |
29 | FMix is a variant of MixUp, CutMix, etc., introduced in our paper ['FMix: Enhancing Mixed Sample Data Augmentation'](https://arxiv.org/abs/2002.12047). It uses masks sampled from Fourier space to mix training examples. Take a look at our [example notebook in colab](https://colab.research.google.com/github/ecs-vlc/fmix/blob/master/notebooks/example_masks.ipynb), which shows how you can generate masks in two dimensions
30 |
31 |
32 | ![Example FMix masks in two dimensions](./fmix_example.png)
33 |
34 |
35 | and in three!
36 |
37 |
38 | ![Example FMix masks in three dimensions](./fmix_3d.gif)
39 |
40 |
41 | ## Experiments
42 |
43 | ### Core Experiments
44 |
45 | Shell scripts for our core experiments can be found in the [experiments folder](./experiments). For example,
46 |
47 | ```bash
48 | bash experiments/cifar_experiment.sh cifar10 resnet fmix ./data
49 | ```
50 |
51 | will train a PreAct-ResNet18 on CIFAR-10 with FMix. More information can be found at the start of each of the shell files.
52 |
53 | ### Additional Experiments
54 |
55 | All additional classification experiments can be run via [`trainer.py`](./trainer.py)
56 |
57 | ### Analyses
58 |
59 | For Grad-CAM, take a look at the [Grad-CAM notebook in colab](https://colab.research.google.com/github/ecs-vlc/fmix/blob/master/notebooks/grad_cam.ipynb).
60 |
61 | For the other analyses, have a look in the [analysis folder](./analysis).
62 |
63 | ## Implementations
64 |
65 | The core implementation of `FMix` uses `numpy` and can be found in [`fmix.py`](./fmix.py). We provide bindings for this in [PyTorch](https://pytorch.org/) (with [Torchbearer](https://github.com/pytorchbearer/torchbearer) or [PyTorch-Lightning](https://github.com/PyTorchLightning/pytorch-lightning)) and [Tensorflow](https://www.tensorflow.org/).
66 |
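As a quick illustration of the core API, the sketch below samples a single mask and uses it to mix a synthetic batch. It assumes that `sample_mask(alpha, decay_power, shape)` in [`fmix.py`](./fmix.py) returns a `(lam, mask)` pair, as used in the example notebook; the batch and the pairing permutation here are placeholders.

```python
import numpy as np

from fmix import sample_mask  # core numpy implementation

# Placeholder batch of 8 CIFAR-sized images (N, C, H, W)
batch = np.random.rand(8, 3, 32, 32).astype(np.float32)

# Sample a mixing proportion lam and a mask thresholded from low-frequency Fourier noise
lam, mask = sample_mask(alpha=1, decay_power=3, shape=(32, 32))

# Mix each image with a randomly chosen partner; the corresponding targets
# would be weighted by lam and (1 - lam) in the loss
perm = np.random.permutation(batch.shape[0])
mixed = mask * batch + (1 - mask) * batch[perm]
```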
67 | ### Torchbearer
68 |
69 | The `FMix` callback in [`torchbearer_implementation.py`](./implementations/torchbearer_implementation.py) can be added directly to your torchbearer code:
70 |
71 | ```python
72 | from implementations.torchbearer_implementation import FMix
73 |
74 | fmix = FMix()
75 | trial = Trial(model, optimiser, fmix.loss(), callbacks=[fmix])
76 | ```
77 |
78 | See an example in [`test_torchbearer.py`](./implementations/test_torchbearer.py).
79 |
80 | ### PyTorch-Lightning
81 |
82 | For PyTorch-Lightning, we provide a class, `FMix` in [`lightning.py`](./implementations/lightning.py) that can be used in your `LightningModule`:
83 |
84 | ```python
85 | from implementations.lightning import FMix
86 |
87 | class CoolSystem(pl.LightningModule):
88 | def __init__(self):
89 | ...
90 |
91 | self.fmix = FMix()
92 |
93 | def training_step(self, batch, batch_nb):
94 | x, y = batch
95 | x = self.fmix(x)
96 |
97 | x = self.forward(x)
98 |
99 | loss = self.fmix.loss(x, y)
100 | return {'loss': loss}
101 | ```
102 |
103 | See an example in [`test_lightning.py`](./implementations/test_lightning.py).
104 |
105 | ### Tensorflow
106 |
107 | For Tensorflow, we provide a class, `FMix` in [`tensorflow_implementation.py`](./implementations/tensorflow_implementation.py) that can be used in your tensorflow code:
108 |
109 | ```python
110 | from implementations.tensorflow_implementation import FMix
111 |
112 | fmix = FMix()
113 |
114 | def loss(model, x, y, training=True):
115 | x = fmix(x)
116 | y_ = model(x, training=training)
117 | return tf.reduce_mean(fmix.loss(y_, y))
118 | ```
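
The `loss` function above can then be dropped into an ordinary TensorFlow 2 training step. A minimal sketch, assuming a constructed Keras `model` and the `fmix` object and `loss` defined above:

```python
optimizer = tf.keras.optimizers.Adam()

def train_step(x, y):
    # FMix samples a fresh mask inside `loss` on every call
    with tf.GradientTape() as tape:
        batch_loss = loss(model, x, y)
    grads = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return batch_loss
```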
119 |
120 | See an example in [`test_tensorflow.py`](./implementations/test_tensorflow.py).
121 |
122 | ## Pre-trained Models
123 |
124 | We provide pre-trained models via `torch.hub` (more coming soon). To use them, run
125 |
126 | ```python
127 | import torch
128 | model = torch.hub.load('ecs-vlc/FMix:master', ARCHITECTURE, pretrained=True)
129 | ```
130 |
131 | where `ARCHITECTURE` is one of the following:
132 |
133 | ### CIFAR-10
134 |
135 | #### PreAct-ResNet-18
136 |
137 | | Configuration | `ARCHITECTURE` | Accuracy |
138 | | ---------------- | ----------------------------------------- | -------- |
139 | | Baseline | `'preact_resnet18_cifar10_baseline'` | -------- |
140 | | + MixUp | `'preact_resnet18_cifar10_mixup'` | -------- |
141 | | + FMix | `'preact_resnet18_cifar10_fmix'` | -------- |
142 | | + Mixup + FMix | `'preact_resnet18_cifar10_fmixplusmixup'` | -------- |
143 |
144 | #### PyramidNet-200
145 |
146 | | Configuration | `ARCHITECTURE` | Accuracy |
147 | | ---------------- | ------------------------------------ | --------- |
148 | | Baseline | `'pyramidnet_cifar10_baseline'` | 98.31 |
149 | | + MixUp | `'pyramidnet_cifar10_mixup'` | 97.92 |
150 | | + FMix | `'pyramidnet_cifar10_fmix'` | __98.64__ |
151 |
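For example, a rough sketch of evaluating the PyramidNet FMix model on the CIFAR-10 test set (this assumes the hub models expect inputs normalised with the usual CIFAR-10 statistics, as in [`analysis/adversarial_linf.py`](./analysis/adversarial_linf.py)):

```python
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

model = torch.hub.load('ecs-vlc/FMix:master', 'pyramidnet_cifar10_fmix', pretrained=True)
model.eval()

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
testset = CIFAR10('./data', train=False, download=True,
                  transform=transforms.Compose([transforms.ToTensor(), normalize]))

correct = 0
with torch.no_grad():
    for images, labels in DataLoader(testset, batch_size=100):
        correct += (model(images).argmax(dim=1) == labels).sum().item()
print(f'Accuracy: {correct / len(testset):.4f}')
```
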
152 | ### ImageNet
153 |
154 | #### ResNet-101
155 |
156 | | Configuration | `ARCHITECTURE` | Accuracy (Top-1) |
157 | | ---------------- | ------------------------------------ | ---------------- |
158 | | Baseline | `'renset101_imagenet_baseline'` | 76.51 |
159 | | + MixUp | `'renset101_imagenet_mixup'` | 76.27 |
160 | | + FMix | `'renset101_imagenet_fmix'` | __76.72__ |
161 |
--------------------------------------------------------------------------------
/analysis/README.md:
--------------------------------------------------------------------------------
1 | # Analysis
2 |
3 | This directory contains the code used for the VAE and mutual information experiments from the paper.
4 |
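The scripts use relative imports, so run them as modules from the repository root. For example (argument names follow the `argparse` definitions in each script; checkpoint names follow the defaults used there):

```bash
# Train a VAE with FMix augmentation (see train_vaes.py for all options)
python -m analysis.train_vaes --mode fmix --i 1 --var 1 --epochs 100 --dir vaes

# Estimate the latent mutual information between two trained VAEs
python -m analysis.vae_mi --vae1 base_5 --vae2 cutmix_5

# Fit the layer-wise mutual information estimators to a saved VGG checkpoint (e.g. mix_3.pt)
python -m analysis.mine --model mix_3
```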
--------------------------------------------------------------------------------
/analysis/adversarial_linf.py:
--------------------------------------------------------------------------------
1 | # !pip install foolbox
2 | import torch
3 | import eagerpy as ep
4 | from foolbox import PyTorchModel, accuracy, samples
5 | import foolbox.attacks as fa
6 | import numpy as np
7 | import json
8 |
9 | from torchvision.datasets import CIFAR10
10 | import torchvision.transforms as transforms
11 | from torch.utils.data import DataLoader
12 | from torch import nn
13 |
14 | from sklearn.model_selection import ParameterGrid
15 |
16 | import argparse
17 |
18 | from tqdm import tqdm
19 |
20 | parser = argparse.ArgumentParser(description='CIFAR-10 adversarial robustness evaluation')
21 | parser.add_argument('--arr', default=0, type=int, help='point in job array')
22 | args = parser.parse_args()
23 |
24 | param_grid = ParameterGrid({
25 | 'mode': ('baseline', 'cutmix', 'mixup', 'fmix'),
26 | 'repeat': list(range(5))
27 | })
28 |
29 | params = param_grid[args.arr]
30 | mode = params['mode']
31 | repeat = params['repeat']
32 |
33 | test_transform = transforms.Compose([
34 | transforms.ToTensor() # convert to tensor
35 | ])
36 |
37 | # load data
38 | testset = CIFAR10(".", train=False, download=True, transform=test_transform)
39 | testloader = DataLoader(testset, batch_size=750, shuffle=False, num_workers=5)
40 |
41 | attacks = [
42 | fa.FGSM(),
43 | fa.LinfPGD(),
44 | fa.LinfBasicIterativeAttack(),
45 | fa.LinfAdditiveUniformNoiseAttack(),
46 | fa.LinfDeepFoolAttack(),
47 | ]
48 |
49 | epsilons = [
50 | 0.0,
51 | 0.0005,
52 | 0.001,
53 | 0.0015,
54 | 0.002,
55 | 0.003,
56 | 0.005,
57 | 0.01,
58 | 0.02,
59 | 0.03,
60 | 0.1,
61 | 0.3,
62 | 0.5,
63 | 1.0,
64 | ]
65 |
66 |
67 | def normalize_with(mean, std):
68 | mean = torch.tensor(mean)
69 | std = torch.tensor(std)
70 | return lambda x: (x - mean.to(x.device).unsqueeze(0).unsqueeze(2).unsqueeze(3)) / std.to(x.device).unsqueeze(0).unsqueeze(2).unsqueeze(3)
71 |
72 | # (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
73 |
74 |
75 | class Normalized(nn.Module):
76 | def __init__(self, model, mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]):
77 | super().__init__()
78 | self.model = model
79 | self.normalize = normalize_with(mean=mean, std=std)
80 |
81 | def forward(self, x):
82 | x = self.normalize(x)
83 | return self.model(x)
84 |
85 |
86 | results = dict()
87 |
88 | # for mode in ('baseline', 'mixup', 'fmix', 'cutmix', 'fmixplusmixup'):
89 | print(mode)
90 | results[mode] = dict()
91 |
92 | model = Normalized(torch.hub.load('ecs-vlc/FMix:master', f'preact_resnet18_cifar10_{mode}', pretrained=True, repeat=repeat))
93 | model.eval()
94 | fmodel = PyTorchModel(model, bounds=(0, 1))
95 |
96 | attack_success = np.zeros((len(attacks), len(epsilons), len(testset)), dtype=bool)
97 | for i, attack in enumerate(attacks):
98 | # print(attack)
99 | idx = 0
100 | for images, labels in tqdm(testloader):
101 | # print('.', end='')
102 | images = images.to(fmodel.device)
103 | labels = labels.to(fmodel.device)
104 |
105 | _, _, success = attack(fmodel, images, labels, epsilons=epsilons)
106 | success_ = success.cpu().numpy()
107 | attack_success[i][:, idx:idx + len(labels)] = success_
108 | idx = idx + len(labels)
109 | # print("")
110 |
111 | import pickle
112 | with open(f'adversarial_linf_{mode}_{repeat}.p', 'wb') as f:
113 | # Pickle the 'data' dictionary using the highest protocol available.
114 | pickle.dump(attack_success, f)
115 | # for i, attack in enumerate(attacks):
116 | # results[mode][str(attack)] = (1.0 - attack_success[i].mean(axis=-1)).tolist()
117 | #
118 | # robust_accuracy = 1.0 - attack_success.max(axis=0).mean(axis=-1)
119 | # results[mode]['robust_accuracy'] = robust_accuracy.tolist()
120 | #
121 | # with open('adv-results-cifar-linf.json', 'w') as fp:
122 | # json.dump(results, fp)
123 |
--------------------------------------------------------------------------------
/analysis/mine.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import torchbearer
9 | from torchbearer import state_key, callbacks
10 |
11 | T = state_key('t')
12 | T_SHUFFLED = state_key('t_shuffled')
13 | MI = state_key('mi')
14 |
15 |
16 | class Flatten(nn.Module):
17 | def forward(self, x):
18 | return x.view(x.size(0), -1)
19 |
20 |
21 | class DoNothing(nn.Module):
22 | def forward(self, x):
23 | return x
24 |
25 |
26 | def resample(x):
27 |     return F.fold(F.unfold(x, kernel_size=2, stride=2), (int(x.size(2) / 2), int(x.size(3) / 2)), 1)  # space-to-depth: halves H and W, folds each 2x2 block into the channel dimension
28 |
29 |
30 | class Estimator(nn.Module):
31 | def __init__(self, conv, in_size, pool_input=False, halves=0):
32 | super().__init__()
33 | self.pool = DoNothing()
34 | self.halves = halves
35 |
36 | if conv:
37 | in_size = in_size + 3 * 4 ** halves
38 | if pool_input:
39 | self.pool = nn.AdaptiveAvgPool2d((8, 8))
40 |
41 | self.est = nn.Sequential(
42 | nn.Conv2d(in_size, 256, kernel_size=3, padding=1),
43 | nn.ReLU(),
44 | nn.AdaptiveAvgPool2d((2, 2)),
45 | Flatten(),
46 | nn.Linear(2 * 2 * 256, 512),
47 | nn.ReLU(),
48 | nn.Linear(512, 1)
49 | )
50 | else:
51 | in_size = in_size + 32 * 32 * 3
52 | self.est = nn.Sequential(
53 | nn.Linear(in_size, 512),
54 | nn.ReLU(),
55 | nn.Linear(512, 512),
56 | nn.ReLU(),
57 | nn.Linear(512, 1)
58 | )
59 |
60 | def forward(self, x, f):
61 | if self.halves < 5:
62 | for i in range(self.halves):
63 | x = resample(x)
64 | else:
65 | x = x.view(x.size(0), -1)
66 | f = f.view(f.size(0), -1)
67 | if f.dim() == 2:
68 | x = x.view(x.size(0), -1)
69 | x = torch.cat((x, f), dim=1)
70 | x = self.pool(x)
71 |
72 | return self.est(x)
73 |
74 |
75 | cfgs = {
76 | 'A': {
77 | 'f1': lambda: Estimator(True, 64, halves=1), 'f2': lambda: Estimator(True, 128, halves=2), 'f3': lambda: Estimator(True, 256, halves=2),
78 | 'f4': lambda: Estimator(True, 256, halves=3), 'f5': lambda: Estimator(True, 512, halves=3), 'f6': lambda: Estimator(True, 512, halves=4),
79 | 'f7': lambda: Estimator(True, 512, halves=4),
80 | 'f8': lambda: Estimator(False, 512, False, halves=5), 'c1': lambda: Estimator(False, 2048), 'c2': lambda: Estimator(False, 2048)},
81 | 'B': {},
82 | 'D': {},
83 | 'E': {},
84 | }
85 |
86 |
87 | def mi(tanh):
88 | def mi_loss(state):
89 | m_t, m_t_shuffled = state[torchbearer.Y_PRED]
90 | mi = {}
91 | sum = 0.0
92 | for layer in m_t.keys():
93 | t = m_t[layer]
94 | t_shuffled = m_t_shuffled[layer]
95 | if tanh:
96 | t = t.tanh()
97 | t_shuffled = t_shuffled.tanh()
98 | tmp = t.mean() - (torch.logsumexp(t_shuffled, 0) - math.log(t_shuffled.size(0)))
99 | mi[layer] = tmp.item()
100 | sum += tmp
101 | if len(mi.keys()) == 1:
102 | state[MI] = mi[next(iter(mi.keys()))]
103 | else:
104 | state[MI] = mi
105 | return -sum
106 | return mi_loss
107 |
108 |
109 | def process(x, cache, cfg):
110 | t = {}
111 | t_shuffled = {}
112 |
113 | for layer in cfg.keys():
114 | out = cache[layer].detach()
115 | t[layer] = cfg[layer](x, out)
116 | t_shuffled[layer] = cfg[layer](x, out[torch.randperm(out.size(0))])
117 | return t, t_shuffled
118 |
119 |
120 | class MimeVGG(nn.Module):
121 | def __init__(self, vgg, cfg):
122 | super().__init__()
123 |
124 | self.vgg = vgg
125 | self.cfg = nn.ModuleDict(cfg)
126 |
127 | def forward(self, x):
128 | pred, cache = self.vgg(x)
129 |
130 | t, t_shuffled = process(x, cache, self.cfg)
131 |
132 | return t, t_shuffled
133 |
134 |
135 | if __name__ == '__main__':
136 | from torch import optim
137 | from torchvision import transforms
138 | from torchvision.datasets import CIFAR10
139 |
140 | import torchbearer
141 | from torchbearer import Trial
142 |
143 | OTHER_MI = state_key('other_mi')
144 |
145 | cfg = cfgs['A']
146 |
147 | import argparse
148 |
149 | parser = argparse.ArgumentParser(description='VGG MI')
150 | parser.add_argument('--model', default='mix_3', type=str, help='model')
151 | args = parser.parse_args()
152 |
153 | from .vgg import vgg11_bn
154 |
155 | vgg = vgg11_bn(return_cache=True)
156 | vgg.load_state_dict(torch.load(args.model + '.pt')[torchbearer.MODEL])
157 | for param in vgg.parameters():
158 | param.requires_grad = False
159 |
160 | for layer in cfg:
161 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
162 | transform_base = [transforms.ToTensor(), normalize]
163 |
164 | transform = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] + transform_base
165 |
166 | transform_train = transforms.Compose(transform)
167 | transform_test = transforms.Compose(transform_base)
168 |
169 | trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
170 | valset = CIFAR10(root='./data', train=False, download=True, transform=transform_test)
171 |
172 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=1000, shuffle=True, num_workers=8)
173 | valloader = torch.utils.data.DataLoader(valset, batch_size=5000, shuffle=True, num_workers=8)
174 |
175 | model = MimeVGG(vgg, {k: cfgs['A'][k]() for k in [layer]})
176 |
177 | optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters()), lr=5e-4)
178 |
179 | mi_false = mi(False)
180 |
181 | @callbacks.add_to_loss
182 | def mi_no_tanh(state):
183 | state[OTHER_MI] = mi_false(state)
184 | return 0
185 |
186 | trial = Trial(model, optimizer, mi(True), metrics=['loss', torchbearer.metrics.mean(OTHER_MI)], callbacks=[mi_no_tanh, callbacks.TensorBoard(write_graph=False, comment='mi_' + args.model, log_dir='mi_data')])
187 | trial.with_generators(train_generator=trainloader, val_generator=valloader).to('cuda')
188 | trial.run(20, verbose=1)
189 |
--------------------------------------------------------------------------------
/analysis/train_vaes.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | import torch.optim as optim
3 | from torch.utils.data import DataLoader
4 | from torchvision import transforms, datasets
5 | import torchbearer
6 | from torchbearer import Trial, callbacks, metrics
7 | from torchbearer.callbacks import init
8 | import torch
9 | from torch import distributions
10 |
11 | from .vae import VAE, LATENT
12 | from implementations.torchbearer_implementation import FMix
13 |
14 | import argparse
15 |
16 | parser = argparse.ArgumentParser(description='VAE Training')
17 | parser.add_argument('--mode', default='base', type=str, help='name of run')
18 | parser.add_argument('--i', default=1, type=int, help='iteration')
19 |     parser.add_argument('--var', default=1, type=float, help='decoder output variance')
20 | parser.add_argument('--epochs', default=100, type=int, help='epochs')
21 | parser.add_argument('--dir', default='vaes', type=str, help='directory')
22 | args = parser.parse_args()
23 |
24 | KL = torchbearer.state_key('KL')
25 | NLL = torchbearer.state_key('NLL')
26 | SAMPLE = torchbearer.state_key('SAMPLE')
27 |
28 | # Data
29 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
30 | inv_normalize = transforms.Normalize((-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010), (1/0.2023, 1/0.1994, 1/0.2010))
31 | transform_base = [transforms.ToTensor(), normalize]
32 |
33 | transform = [transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), transforms.RandomHorizontalFlip()] + transform_base
34 |
35 | transform_train = transforms.Compose(transform)
36 | transform_test = transforms.Compose(transform_base)
37 |
38 | train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
39 | test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
40 |
41 | train_loader = DataLoader(train_set, 128, shuffle=True, num_workers=5)
42 | test_loader = DataLoader(test_set, 100, shuffle=False, num_workers=5)
43 |
44 | # KL Divergence
45 |
46 | def kld(prior):
47 | @callbacks.add_to_loss
48 | def loss(state):
49 | res = distributions.kl_divergence(state[LATENT], prior).sum().div(state[LATENT].loc.size(0))
50 | state[KL] = res.detach()
51 | return res
52 | return loss
53 |
54 | # Negative Log Likelihood
55 |
56 | def nll(state):
57 | y_pred, y_true = state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE]
58 | res = - y_pred.log_prob(y_true).sum().div(y_true.size(0))
59 | state[NLL] = res.detach()
60 | return res
61 |
62 | # Generate Some Images
63 |
64 | @torchbearer.callbacks.on_forward
65 | @torchbearer.callbacks.on_forward_validation
66 | def sample(state):
67 | state[SAMPLE] = state[torchbearer.Y_PRED].loc
68 |
69 | # Train VAEs
70 |
71 | aug = []
72 | mode = args.mode
73 | if mode == 'mix':
74 | aug = [callbacks.Mixup()]
75 | if mode == 'cutmix':
76 | aug = [callbacks.CutMix(1, classes=10)]
77 | if mode == 'fmix':
78 | aug = [FMix(alpha=1, decay_power=3)]
79 |
80 | model = VAE(64, var=args.var)
81 | trial = Trial(model, optim.Adam(model.parameters(), lr=5e-2), nll,
82 | metrics=[
83 | metrics.MeanSquaredError(pred_key=SAMPLE),
84 | metrics.mean(NLL),
85 | metrics.mean(KL),
86 | 'loss'
87 | ],
88 | callbacks=[
89 | sample,
90 | kld(distributions.Normal(0, 1)),
91 | init.XavierNormal(targets=['Conv']),
92 | callbacks.MostRecent(args.dir + '/' + mode + '_' + str(args.i) + '.pt'),
93 | callbacks.MultiStepLR([40, 80]),
94 | callbacks.TensorBoard(write_graph=False, comment=mode + '_' + str(args.i), log_dir='vae_logs'),
95 | *aug
96 | ])
97 |
98 | if mode in ['base', 'mix', 'cutmix']:
99 | trial = trial.load_state_dict(torch.load('vaes/' + '/' + mode + '_' + str(args.i) + '.pt'))
100 |
101 | trial.with_generators(train_loader, test_loader).to('cuda').run(args.epochs, verbose=1)
102 |
--------------------------------------------------------------------------------
/analysis/vae.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import distributions
6 | from torch.distributions import constraints, register_kl
7 |
8 | import torchbearer
9 | from torchbearer import state_key
10 |
11 | LATENT = state_key('latent')
12 |
13 |
14 | class LogitNormal(distributions.Normal):
15 | arg_constraints = {'loc': constraints.real, 'log_scale': constraints.real}
16 | support = constraints.real
17 | has_rsample = True
18 |
19 | def __init__(self, loc, log_scale, validate_args=None):
20 | self.log_scale = log_scale
21 | scale = distributions.transform_to(distributions.Normal.arg_constraints['scale'])(log_scale)
22 | super().__init__(loc, scale, validate_args=validate_args)
23 |
24 | def log_prob(self, value):
25 | if self._validate_args:
26 | self._validate_sample(value)
27 | # compute the variance
28 | var = (self.scale ** 2)
29 | log_scale = self.log_scale
30 | return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
31 |
32 |
33 | @register_kl(LogitNormal, distributions.Normal)
34 | def kl_logitnormal_normal(p, q):
35 | log_var_ratio = 2 * (p.log_scale - q.scale.log())
36 | t1 = ((p.loc - q.loc) / q.scale).pow(2)
37 | return 0.5 * (log_var_ratio.exp() + t1 - 1 - log_var_ratio)
38 |
39 |
40 | @register_kl(LogitNormal, LogitNormal)
41 | def kl_logitnormal_logitnormal(p, q):
42 | log_var_ratio = 2 * (p.log_scale - q.log_scale)
43 | t1 = ((p.loc - q.loc) / q.scale).pow(2)
44 | return 0.5 * (log_var_ratio.exp() + t1 - 1 - log_var_ratio)
45 |
46 |
47 | class Flatten(nn.Module):
48 | def forward(self, x):
49 | return x.view(x.size(0), -1)
50 |
51 |
52 | class View(nn.Module):
53 | def __init__(self, *args):
54 | super().__init__()
55 | self.args = args
56 |
57 | def forward(self, x):
58 | return x.view(x.size(0), *self.args)
59 |
60 |
61 | class SimpleEncoder(nn.Sequential):
62 | def __init__(self):
63 | super().__init__(
64 | nn.Conv2d(3, 32, (4, 4), stride=2, padding=1),
65 | nn.ReLU(True),
66 | nn.BatchNorm2d(32),
67 | nn.Conv2d(32, 32, (4, 4), stride=2, padding=1),
68 | nn.ReLU(True),
69 | nn.BatchNorm2d(32),
70 | nn.Conv2d(32, 32, (4, 4), stride=2, padding=1),
71 | nn.ReLU(True),
72 | nn.BatchNorm2d(32),
73 | Flatten()
74 | )
75 |
76 | self.output_size = 32 * 4 * 4
77 |
78 |
79 | class SimpleDecoder(nn.Sequential):
80 | def __init__(self, z_dims):
81 | super().__init__(
82 | nn.Linear(z_dims, 32 * 4 * 4),
83 | View(32, 4, 4),
84 | nn.ConvTranspose2d(32, 32, (4, 4), stride=2, padding=1),
85 | nn.ReLU(True),
86 | nn.BatchNorm2d(32),
87 | nn.ConvTranspose2d(32, 32, (4, 4), stride=2, padding=1),
88 | nn.ReLU(True),
89 | nn.BatchNorm2d(32),
90 | nn.ConvTranspose2d(32, 3, (4, 4), stride=2, padding=1)
91 | )
92 |
93 |
94 | class DCGANDecoder(nn.Sequential):
95 | def __init__(self, z_dims, dim=128):
96 | super().__init__(
97 | nn.Linear(z_dims, dim * 4 * 4 * 4),
98 | nn.ReLU(True),
99 | nn.BatchNorm1d(dim * 4 * 4 * 4),
100 | View(dim * 4, 4, 4),
101 | nn.ConvTranspose2d(dim * 4, dim * 2, 5, stride=2, padding=2, output_padding=1),
102 | nn.ReLU(True),
103 | nn.BatchNorm2d(dim * 2),
104 | nn.ConvTranspose2d(dim * 2, dim, 5, stride=2, padding=2, output_padding=1),
105 | nn.ReLU(True),
106 | nn.BatchNorm2d(dim),
107 | nn.ConvTranspose2d(dim, 3, 5, stride=2, padding=2, output_padding=1)
108 | )
109 |
110 |
111 | class DCGANEncoder(nn.Sequential):
112 | def __init__(self, dim=128):
113 | super().__init__(
114 | nn.Conv2d(3, dim, 5, stride=2, padding=2),
115 | nn.ReLU(True),
116 | nn.BatchNorm2d(dim),
117 | nn.Conv2d(dim, dim * 2, 5, stride=2, padding=2),
118 | nn.ReLU(True),
119 | nn.BatchNorm2d(dim * 2),
120 | nn.Conv2d(dim * 2, dim * 4, 5, stride=2, padding=2),
121 | nn.ReLU(True),
122 | nn.BatchNorm2d(dim * 4),
123 | Flatten()
124 | )
125 |
126 | self.output_size = dim * 4 * 4 * 4
127 |
128 |
129 | class BetaVAEDecoder(nn.Sequential):
130 | def __init__(self, z_dims):
131 | super().__init__(
132 | nn.Linear(z_dims, 256),
133 | nn.ReLU(True),
134 | nn.BatchNorm1d(256),
135 | nn.Linear(256, 64 * 2 * 2),
136 | nn.ReLU(True),
137 | nn.BatchNorm1d(64 * 2 * 2),
138 | View(64, 2, 2),
139 | nn.ConvTranspose2d(64, 64, 4, stride=2, padding=1),
140 | nn.ReLU(True),
141 | nn.BatchNorm2d(64),
142 | nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
143 | nn.ReLU(True),
144 | nn.BatchNorm2d(32),
145 | nn.ConvTranspose2d(32, 32, 4, stride=2, padding=1),
146 | nn.ReLU(True),
147 | nn.BatchNorm2d(32),
148 | nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1)
149 | )
150 |
151 |
152 | class BetaVAEDecoder2(nn.Sequential):
153 | def __init__(self, z_dims):
154 | super().__init__(
155 | nn.Linear(z_dims, 256),
156 | nn.ReLU(True),
157 | # nn.BatchNorm1d(256),
158 | nn.Linear(256, 64 * 2 * 2),
159 | nn.ReLU(True),
160 | # nn.BatchNorm1d(64 * 2 * 2),
161 | View(64, 2, 2),
162 | nn.Upsample(size=5, mode='bilinear'),
163 | nn.Conv2d(64, 64, 4, padding=1),
164 | nn.ReLU(True),
165 | nn.BatchNorm2d(64),
166 | nn.Upsample(size=9, mode='bilinear'),
167 | nn.Conv2d(64, 32, 4, padding=1),
168 | nn.ReLU(True),
169 | nn.BatchNorm2d(32),
170 | nn.Upsample(size=17, mode='bilinear'),
171 | nn.Conv2d(32, 32, 4, padding=1),
172 | nn.ReLU(True),
173 | nn.BatchNorm2d(32),
174 | nn.Upsample(size=33, mode='bilinear'),
175 | nn.Conv2d(32, 3, 4, padding=1)
176 | )
177 |
178 |
179 | class BetaVAEEncoder(nn.Sequential):
180 | def __init__(self):
181 | super().__init__(
182 | nn.Conv2d( 3, 32, 4, stride=2, padding=1),
183 | nn.ReLU(True),
184 | nn.BatchNorm2d(32),
185 | nn.Conv2d(32, 32, 4, stride=2, padding=1),
186 | nn.ReLU(True),
187 | nn.BatchNorm2d(32),
188 | nn.Conv2d(32, 64, 4, stride=2, padding=1),
189 | nn.ReLU(True),
190 | nn.BatchNorm2d(64),
191 | nn.Conv2d(64, 64, 4, stride=2, padding=1),
192 | nn.ReLU(True),
193 | nn.BatchNorm2d(64),
194 | Flatten(),
195 | nn.Linear(64 * 2 * 2, 256),
196 | nn.ReLU(True),
197 | # nn.BatchNorm1d(256)
198 | )
199 |
200 | self.output_size = 256
201 |
202 |
203 | class VAE(nn.Module):
204 | def __init__(self, z_dims=64, encoder=BetaVAEEncoder, decoder=BetaVAEDecoder2, var=0.1):
205 | super(VAE, self).__init__()
206 |
207 | self.var = var
208 |
209 | self.encoder = encoder()
210 | self.decoder = decoder(z_dims)
211 |
212 | self.loc = nn.Linear(self.encoder.output_size, z_dims)
213 | self.scale = nn.Linear(self.encoder.output_size, z_dims)
214 | self.loc.weight.data.zero_()
215 | self.loc.bias.data.zero_()
216 | self.scale.weight.data.zero_()
217 | self.scale.bias.data.zero_()
218 |
219 | def encode(self, x):
220 | x = self.encoder(x)
221 | loc = self.loc(x)
222 | scale = self.scale(x)
223 | return LogitNormal(loc, scale)
224 |
225 | def forward(self, x, state=None):
226 | if state is not None:
227 | state[torchbearer.TARGET] = x.detach()
228 |
229 | latent = self.encode(x)
230 |
231 | if state is not None:
232 | state[LATENT] = latent
233 |
234 | x = self.decoder(latent.rsample())
235 | return LogitNormal(x, (torch.ones_like(x) * self.var).log())
236 |
237 |
238 | class PredictionNetwork(nn.Module):
239 | def __init__(self, encoder_a, encoder_b, z_dims=32):
240 | super().__init__()
241 | self.z_dims = z_dims
242 |
243 | self.encoder_a = encoder_a
244 | for param in self.encoder_a.parameters():
245 | param.requires_grad = False
246 |
247 | self.encoder_b = encoder_b
248 | for param in self.encoder_b.parameters():
249 | param.requires_grad = False
250 |
251 | self.net = nn.Sequential(
252 | nn.Linear(z_dims, 32),
253 | nn.ReLU(True),
254 | nn.Linear(32, 32),
255 | nn.ReLU(True),
256 | nn.Linear(32, 32),
257 | nn.ReLU(True),
258 | nn.Linear(32, z_dims * 2)
259 | )
260 |
261 | self.net[-1].weight.data.zero_()
262 | self.net[-1].bias.data.zero_()
263 |
264 | def forward(self, x, state):
265 | self.encoder_a.eval()
266 | self.encoder_b.eval()
267 |
268 | a = self.encoder_a.encode(x).rsample().detach()
269 | b = self.encoder_b.encode(x)
270 | b.loc = b.loc.detach()
271 | b.scale = b.scale.detach()
272 |
273 | state[torchbearer.TARGET] = b
274 |
275 | x = self.net(a)
276 | loc = x[:, :self.z_dims]
277 | scale = x[:, self.z_dims:]
278 | return LogitNormal(loc, scale)
279 |
280 |
281 | class MINetwork(nn.Module):
282 | def __init__(self, encoder_a, encoder_b, upper=False):
283 | super().__init__()
284 | self.upper = upper
285 |
286 | self.encoder_a = encoder_a
287 | for param in self.encoder_a.parameters():
288 | param.requires_grad = False
289 |
290 | self.encoder_b = encoder_b
291 | for param in self.encoder_b.parameters():
292 | param.requires_grad = False
293 |
294 | def forward(self, x, state):
295 | self.encoder_a.eval()
296 | self.encoder_b.eval()
297 |
298 | if not self.upper:
299 | x = self.encoder_a(x).rsample()
300 | return self.encoder_b.encode(x)
301 |
--------------------------------------------------------------------------------
/analysis/vae_mi.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | from torch.utils.data import DataLoader
3 | from torchvision import transforms, datasets
4 | import torchbearer
5 | from torchbearer import Trial
6 | import torch
7 | from torch import distributions
8 |
9 | from .vae import VAE, MINetwork
10 |
11 | import argparse
12 |
13 | parser = argparse.ArgumentParser(description='VAE MI')
14 | parser.add_argument('--vae1', default='base_5', type=str, help='VAE 1')
15 | parser.add_argument('--vae2', default='cutmix_5', type=str, help='VAE 2')
16 |     parser.add_argument('--upper', action='store_true', help='if set, use the upper bound, else the lower')
17 | args = parser.parse_args()
18 |
19 | # Data
20 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
21 | inv_normalize = transforms.Normalize((-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010), (1/0.2023, 1/0.1994, 1/0.2010))
22 | transform_base = [transforms.ToTensor(), normalize]
23 |
24 | transform_test = transforms.Compose(transform_base)
25 | test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
26 |
27 | test_loader = DataLoader(test_set, 100, shuffle=True, num_workers=5)
28 |
29 | # MI Loss
30 |
31 | def kld(state):
32 | y_pred = state[torchbearer.Y_PRED]
33 | marginal = distributions.Normal(0, 1)
34 | return distributions.kl_divergence(y_pred, marginal).sum().div(y_pred.loc.size(0))
35 |
36 | # Train VAEs
37 |
38 | vae1 = VAE(64)
39 | vae1.load_state_dict(torch.load('vaes3/' + args.vae1 + '.pt')[torchbearer.MODEL])
40 |
41 | for param in vae1.parameters():
42 | param.requires_grad = False
43 |
44 | vae2 = VAE(64)
45 | vae2.load_state_dict(torch.load('vaes3/' + args.vae2 + '.pt')[torchbearer.MODEL])
46 |
47 | for param in vae2.parameters():
48 | param.requires_grad = False
49 |
50 | model = MINetwork(vae1, vae2, upper=args.upper)
51 | trial = Trial(model, criterion=kld,
52 | metrics=[
53 | 'loss'
54 | ])
55 |
56 | trial.with_generators(test_generator=test_loader).to('cuda').evaluate(data_key=torchbearer.TEST_DATA)
57 |
--------------------------------------------------------------------------------
/analysis/vgg.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from https://pytorch.org/docs/stable/_modules/torchvision/models/vgg.html
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | __all__ = [
9 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
10 | 'vgg19_bn', 'vgg19',
11 | ]
12 |
13 |
14 | class Cache:
15 | def __init__(self):
16 | super(Cache, self).__init__()
17 | self.outputs = {}
18 |
19 | def for_name(outer, layer_name):
20 | class Inner(nn.Module):
21 | def forward(inner, x):
22 | outer.outputs[layer_name] = x
23 | return x
24 | return Inner()
25 |
26 | def get_outputs(self):
27 | tmp = self.outputs
28 | self.outputs = {}
29 | return tmp
30 |
31 |
32 | class VGG(nn.Module):
33 | def __init__(self, features, cache, return_cache=False, num_classes=10, init_weights=True):
34 | super(VGG, self).__init__()
35 | self.features = features
36 | self.return_cache = return_cache
37 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
38 | self.cache = cache
39 | self.classifier = nn.Sequential(
40 | nn.Linear(512 * 7 * 7, 2048), # Normally 4096
41 | nn.ReLU(True),
42 | self.cache.for_name('c1'),
43 | nn.Dropout(),
44 | nn.Linear(2048, 2048), # Normally 4096
45 | nn.ReLU(True),
46 | self.cache.for_name('c2'),
47 | nn.Dropout(),
48 | nn.Linear(2048, num_classes), # Normally 4096
49 | )
50 | if init_weights:
51 | self._initialize_weights()
52 |
53 | def forward(self, x):
54 | x = self.features(x)
55 | x = self.avgpool(x)
56 | x = torch.flatten(x, 1)
57 | x = self.classifier(x)
58 | outs = self.cache.get_outputs()
59 | if self.return_cache:
60 | return x, outs
61 | else:
62 | return x
63 |
64 | def _initialize_weights(self):
65 | for m in self.modules():
66 | if isinstance(m, nn.Conv2d):
67 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
68 | if m.bias is not None:
69 | nn.init.constant_(m.bias, 0)
70 | elif isinstance(m, nn.BatchNorm2d):
71 | nn.init.constant_(m.weight, 1)
72 | nn.init.constant_(m.bias, 0)
73 | elif isinstance(m, nn.Linear):
74 | nn.init.normal_(m.weight, 0, 0.01)
75 | nn.init.constant_(m.bias, 0)
76 |
77 |
78 | def make_layers(cfg, batch_norm=False):
79 | cache = Cache()
80 | layers = []
81 | in_channels = 3
82 | for v in cfg:
83 | if v == 'M':
84 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
85 |         elif isinstance(v, str):
86 | layers += [cache.for_name(v)]
87 | else:
88 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
89 | if batch_norm:
90 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
91 | else:
92 | layers += [conv2d, nn.ReLU(inplace=True)]
93 | in_channels = v
94 | return nn.Sequential(*layers), cache
95 |
96 |
97 | cfgs = {
98 | 'A': [64, 'M', 'f1', 128, 'M', 'f2', 256, 'f3', 256, 'M', 'f4', 512, 'f5', 512, 'M', 'f6', 512, 'f7', 512, 'M', 'f8'],
99 | 'B': [64, 'f1', 64, 'M', 'f2', 128, 'f3', 128, 'M', 'f4', 256, 'f5', 256, 'M', 'f6', 512, 'f7', 512, 'M', 'f8', 512, 'f9', 512, 'M', 'f10'],
100 | 'D': [64, 'f1', 64, 'M', 'f2', 128, 'f3', 128, 'M', 'f4', 256, 'f5', 256, 'f6', 256, 'M', 'f7', 512, 'f8', 512, 'f9', 512, 'M', 'f10', 512, 'f11', 512, 'f12', 512, 'M', 'f13'],
101 | 'E': [64, 'f1', 64, 'M', 'f2', 128, 'f3', 128, 'M', 'f4', 256, 'f5', 256, 'f6', 256, 'f7', 256, 'M', 'f8', 512, 'f9', 512, 'f10', 512, 'f11', 512, 'M', 'f12', 512, 'f13', 512, 'f14', 512, 'f15', 512, 'M', 'f16'],
102 | }
103 |
104 |
105 | def _vgg(cfg, batch_norm, **kwargs):
106 | model = VGG(*make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
107 | return model
108 |
109 |
110 | def vgg11(**kwargs):
111 | r"""VGG 11-layer model (configuration "A") from
112 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
113 | """
114 | return _vgg('A', False, **kwargs)
115 |
116 |
117 | def vgg11_bn(**kwargs):
118 | r"""VGG 11-layer model (configuration "A") with batch normalization
119 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
120 | """
121 | return _vgg('A', True, **kwargs)
122 |
123 |
124 | def vgg13(**kwargs):
125 | r"""VGG 13-layer model (configuration "B")
126 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
127 | """
128 | return _vgg('B', False, **kwargs)
129 |
130 |
131 | def vgg13_bn(**kwargs):
132 | r"""VGG 13-layer model (configuration "B") with batch normalization
133 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
134 | """
135 | return _vgg('B', True, **kwargs)
136 |
137 |
138 | def vgg16(**kwargs):
139 | r"""VGG 16-layer model (configuration "D")
140 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
141 | """
142 | return _vgg('D', False, **kwargs)
143 |
144 |
145 | def vgg16_bn(**kwargs):
146 | r"""VGG 16-layer model (configuration "D") with batch normalization
147 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
148 | """
149 | return _vgg('D', True, **kwargs)
150 |
151 |
152 | def vgg19(**kwargs):
153 | r"""VGG 19-layer model (configuration "E")
154 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
155 | """
156 | return _vgg('E', False, **kwargs)
157 |
158 |
159 | def vgg19_bn(**kwargs):
160 | r"""VGG 19-layer model (configuration 'E') with batch normalization
161 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_
162 | """
163 | return _vgg('E', True, **kwargs)
164 |
165 |
166 | if __name__ == '__main__':
167 | from torch import optim
168 | from torchvision import transforms
169 | from torchvision.datasets import CIFAR10
170 |
171 | from torchbearer import Trial
172 | from torchbearer.callbacks import MultiStepLR, MostRecent, Mixup, CutMix
173 | from implementations.torchbearer_implementation import FMix
174 |
175 | for mode in ['baseline', 'mix', 'fmix', 'cutmix']:
176 | for i in range(0, 3):
177 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
178 | transform_base = [transforms.ToTensor(), normalize]
179 |
180 | transform = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] + transform_base
181 |
182 | transform_train = transforms.Compose(transform)
183 | transform_test = transforms.Compose(transform_base)
184 |
185 | trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
186 | valset = CIFAR10(root='./data', train=False, download=True, transform=transform_test)
187 |
188 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)
189 | valloader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=True, num_workers=8)
190 |
191 | vgg = vgg11_bn(return_cache=False)
192 | optimizer = optim.SGD(vgg.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
193 |
194 | app = []
195 | loss = nn.CrossEntropyLoss()
196 |
197 | if mode == 'mix':
198 | app = [Mixup()]
199 | loss = Mixup.mixup_loss
200 | if mode == 'fmix':
201 | app = [FMix(alpha=1)]
202 | loss = Mixup.mixup_loss
203 | if mode == 'cutmix':
204 | app = [CutMix(1.0, classes=10, mixup_loss=True)]
205 | loss = Mixup.mixup_loss
206 |
207 | trial = Trial(vgg, optimizer, loss, metrics=['acc', 'loss'], callbacks=app + [MostRecent(mode + '_' + str(i + 1) + '.pt'), MultiStepLR([100, 150])])
208 | trial.with_generators(train_generator=trainloader, val_generator=valloader).to('cuda')
209 | trial.run(200, verbose=1)
210 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ecs-vlc/FMix/e5991dca018882734c8ea63599f10dfbe67fa0ae/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/bengali.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from https://www.kaggle.com/corochann/bengali-seresnext-training-with-pytorch
3 | """
4 | from torch.utils.data import Dataset
5 | # import os
6 | # from torchvision.datasets.folder import default_loader
7 | import numpy as np
8 | import pandas as pd
9 | import gc
10 |
11 |
12 | def prepare_image(root, indices=[0, 1, 2, 3]):
13 | # assert data_type in ['train', 'test']
14 | # if submission:
15 | # image_df_list = [pd.read_parquet(datadir / f'{data_type}_image_data_{i}.parquet')
16 | # for i in indices]
17 | # else:
18 | image_df_list = [pd.read_feather(f'{root}/train_image_data_{i}.feather') for i in indices]
19 |
20 | HEIGHT = 137
21 | WIDTH = 236
22 | images = [df.iloc[:, 1:].values.reshape(-1, HEIGHT, WIDTH) for df in image_df_list]
23 | del image_df_list
24 | gc.collect()
25 | images = np.concatenate(images, axis=0)
26 | return images
27 |
28 |
29 | class Bengali(Dataset):
30 | def __init__(self, root, targets, transform=None):
31 | self.transform = transform
32 |
33 | if isinstance(targets, list):
34 | self.labels = list(pd.read_csv(f'{root}/train.csv')[targets].itertuples(index=False, name=None))
35 | else:
36 | self.labels = pd.read_csv(f'{root}/train.csv')[targets]
37 | self.images = prepare_image(root)
38 |
39 | def __getitem__(self, index):
40 | image, label = self.images[index], self.labels[index]
41 | image = (255 - image).astype(np.float32) / 255.
42 |
43 | if self.transform is not None:
44 | image = self.transform(image)
45 |
46 | return image, label
47 |
48 | def __len__(self) -> int:
49 | return len(self.labels)
50 |
51 |
52 | class BengaliGraphemeWhole(Bengali):
53 | def __init__(self, root, transform=None):
54 | super().__init__(root, ['grapheme_root', 'vowel_diacritic', 'consonant_diacritic'], transform=transform)
55 |
56 |
57 | class BengaliGraphemeRoot(Bengali):
58 | def __init__(self, root, transform=None):
59 | super().__init__(root, 'grapheme_root', transform=transform)
60 |
61 |
62 | class BengaliVowelDiacritic(Bengali):
63 | def __init__(self, root, transform=None):
64 | super().__init__(root, 'vowel_diacritic', transform=transform)
65 |
66 |
67 | class BengaliConsonantDiacritic(Bengali):
68 | def __init__(self, root, transform=None):
69 | super().__init__(root, 'consonant_diacritic', transform=transform)
70 |
--------------------------------------------------------------------------------
/datasets/fashion.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import MNIST
2 |
3 |
4 | class OldFashionMNIST(MNIST):
5 |     """The original version of the `Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset. This had
6 |     some train / test duplicates and so was replaced. Do not use in practice; it may, however, be helpful for reproducing
7 |     results from FMix, RandomErase and other papers.
8 |
9 | Args:
10 | root (string): Root directory of dataset where ``Fashion-MNIST/processed/training.pt``
11 | and ``Fashion-MNIST/processed/test.pt`` exist.
12 | train (bool, optional): If True, creates dataset from ``training.pt``,
13 | otherwise from ``test.pt``.
14 | download (bool, optional): If true, downloads the dataset from the internet and
15 | puts it in root directory. If dataset is already downloaded, it is not
16 | downloaded again.
17 |         transform (callable, optional): A function/transform that takes in a PIL image
18 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
19 | target_transform (callable, optional): A function/transform that takes in the
20 | target and transforms it.
21 | """
22 | urls = [
23 | 'https://github.com/zalandoresearch/fashion-mnist/raw/949006594cc2f804a93b9155849734564c3545ec/data/fashion/train-images-idx3-ubyte.gz',
24 | 'https://github.com/zalandoresearch/fashion-mnist/raw/949006594cc2f804a93b9155849734564c3545ec/data/fashion/train-labels-idx1-ubyte.gz',
25 | 'https://github.com/zalandoresearch/fashion-mnist/raw/949006594cc2f804a93b9155849734564c3545ec/data/fashion/t10k-images-idx3-ubyte.gz',
26 | 'https://github.com/zalandoresearch/fashion-mnist/raw/949006594cc2f804a93b9155849734564c3545ec/data/fashion/t10k-labels-idx1-ubyte.gz',
27 | ]
28 | classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
29 | 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
30 |
--------------------------------------------------------------------------------
/datasets/google_commands/README.md:
--------------------------------------------------------------------------------
1 | Code adapted from https://github.com/tugstugi/pytorch-speech-commands
--------------------------------------------------------------------------------
/datasets/google_commands/google_commands.py:
--------------------------------------------------------------------------------
1 | """Google speech commands dataset."""
2 | __author__ = 'Yuan Xu'
3 |
4 | import os
5 | import numpy as np
6 |
7 | from torch.utils.data import Dataset
8 |
9 | __all__ = [ 'CLASSES', 'SpeechCommandsDataset', 'BackgroundNoiseDataset' ]
10 |
11 | CLASSES = 'unknown, silence, yes, no, up, down, left, right, on, off, stop, go'.split(', ')
12 |
13 |
14 | class SpeechCommandsDataset(Dataset):
15 | """Google speech commands dataset. Only 'yes', 'no', 'up', 'down', 'left',
16 | 'right', 'on', 'off', 'stop' and 'go' are treated as known classes.
17 | All other classes are used as 'unknown' samples.
18 | See for more information: https://www.kaggle.com/c/tensorflow-speech-recognition-challenge
19 | """
20 |
21 | def __init__(self, folder, transform=None, classes=CLASSES, silence_percentage=0.1):
22 | try:
23 | import librosa
24 |         except ImportError:
25 |             raise ModuleNotFoundError('The librosa package is required for google commands experiments. Try: pip install librosa')
26 |
27 | all_classes = [d for d in os.listdir(folder) if os.path.isdir(os.path.join(folder, d)) and not d.startswith('_')]
28 | #for c in classes[2:]:
29 | # assert c in all_classes
30 |
31 | class_to_idx = {classes[i]: i for i in range(len(classes))}
32 | for c in all_classes:
33 | if c not in class_to_idx:
34 | class_to_idx[c] = 0
35 |
36 | data = []
37 | for c in all_classes:
38 | d = os.path.join(folder, c)
39 | target = class_to_idx[c]
40 | for f in os.listdir(d):
41 | path = os.path.join(d, f)
42 | data.append((path, target))
43 |
44 | # add silence
45 | target = class_to_idx['silence']
46 | data += [('', target)] * int(len(data) * silence_percentage)
47 |
48 | self.classes = classes
49 | self.data = data
50 | self.transform = transform
51 |
52 | def __len__(self):
53 | return len(self.data)
54 |
55 | def __getitem__(self, index):
56 | path, target = self.data[index]
57 | data = {'path': path, 'target': target}
58 |
59 | if self.transform is not None:
60 | data = self.transform(data)
61 |
62 | return data, target
63 |
64 | def make_weights_for_balanced_classes(self):
65 |         """adapted from https://discuss.pytorch.org/t/balanced-sampling-between-classes-with-torchvision-dataloader/2703/3"""
66 |
67 | nclasses = len(self.classes)
68 | count = np.zeros(nclasses)
69 | for item in self.data:
70 | count[item[1]] += 1
71 |
72 | N = float(sum(count))
73 | weight_per_class = N / count
74 | weight = np.zeros(len(self))
75 | for idx, item in enumerate(self.data):
76 | weight[idx] = weight_per_class[item[1]]
77 | return weight
78 |
79 |
80 | class BackgroundNoiseDataset(Dataset):
81 | """Dataset for silence / background noise."""
82 |
83 | def __init__(self, folder, transform=None, sample_rate=16000, sample_length=1):
84 | try:
85 | import librosa
86 |         except ImportError:
87 |             raise ModuleNotFoundError('The librosa package is required for google commands experiments. Try: pip install librosa')
88 | audio_files = [d for d in os.listdir(folder) if os.path.isfile(os.path.join(folder, d)) and d.endswith('.wav')]
89 | samples = []
90 | for f in audio_files:
91 | path = os.path.join(folder, f)
92 | s, sr = librosa.load(path, sample_rate)
93 | samples.append(s)
94 |
95 | samples = np.hstack(samples)
96 | c = int(sample_rate * sample_length)
97 | r = len(samples) // c
98 | self.samples = samples[:r*c].reshape(-1, c)
99 | self.sample_rate = sample_rate
100 | self.classes = CLASSES
101 | self.transform = transform
102 | self.path = folder
103 |
104 | def __len__(self):
105 | return len(self.samples)
106 |
107 | def __getitem__(self, index):
108 | data = {'samples': self.samples[index], 'sample_rate': self.sample_rate, 'target': 1, 'path': self.path}
109 |
110 | if self.transform is not None:
111 | data = self.transform(data)
112 |
113 | return data
114 |
--------------------------------------------------------------------------------
/datasets/google_commands/sft_transforms.py:
--------------------------------------------------------------------------------
1 | """Transforms on the short time fourier transforms of wav samples."""
2 |
3 | __author__ = 'Erdene-Ochir Tuguldur'
4 |
5 | import random
6 |
7 | import numpy as np
8 |
9 | from torch.utils.data import Dataset
10 |
11 | from datasets.google_commands.transforms import should_apply_transform
12 |
13 |
14 | class ToSTFT(object):
15 |     """Applies the short-time Fourier transform to an audio signal."""
16 |
17 | def __init__(self, n_fft=2048, hop_length=512):
18 | self.n_fft = n_fft
19 | self.hop_length = hop_length
20 |
21 | def __call__(self, data):
22 | import librosa
23 | samples = data['samples']
24 | sample_rate = data['sample_rate']
25 | data['n_fft'] = self.n_fft
26 | data['hop_length'] = self.hop_length
27 | data['stft'] = librosa.stft(samples, n_fft=self.n_fft, hop_length=self.hop_length)
28 | data['stft_shape'] = data['stft'].shape
29 | return data
30 |
31 |
32 | class StretchAudioOnSTFT(object):
33 |     """Stretches an audio signal in the frequency domain."""
34 |
35 | def __init__(self, max_scale=0.2):
36 | self.max_scale = max_scale
37 |
38 | def __call__(self, data):
39 | import librosa
40 |
41 | if not should_apply_transform():
42 | return data
43 |
44 | stft = data['stft']
45 | sample_rate = data['sample_rate']
46 | hop_length = data['hop_length']
47 | scale = random.uniform(-self.max_scale, self.max_scale)
48 | stft_stretch = librosa.core.phase_vocoder(stft, 1 + scale, hop_length=hop_length)
49 | data['stft'] = stft_stretch
50 | return data
51 |
52 |
53 | class TimeshiftAudioOnSTFT(object):
54 |     """A simple time shift in the frequency domain, without multiplying by exp."""
55 |
56 | def __init__(self, max_shift=8):
57 | self.max_shift = max_shift
58 |
59 | def __call__(self, data):
60 | if not should_apply_transform():
61 | return data
62 |
63 | stft = data['stft']
64 | shift = random.randint(-self.max_shift, self.max_shift)
65 | a = -min(0, shift)
66 | b = max(0, shift)
67 | stft = np.pad(stft, ((0, 0), (a, b)), "constant")
68 | if a == 0:
69 | stft = stft[:, b:]
70 | else:
71 | stft = stft[:, 0:-a]
72 | data['stft'] = stft
73 | return data
74 |
75 |
76 | class AddBackgroundNoiseOnSTFT(Dataset):
77 |     """Adds random background noise in the frequency domain."""
78 |
79 | def __init__(self, bg_dataset, max_percentage=0.45):
80 | self.bg_dataset = bg_dataset
81 | self.max_percentage = max_percentage
82 |
83 | def __call__(self, data):
84 | if not should_apply_transform():
85 | return data
86 |
87 | noise = random.choice(self.bg_dataset)['stft']
88 | percentage = random.uniform(0, self.max_percentage)
89 | data['stft'] = data['stft'] * (1 - percentage) + noise * percentage
90 | return data
91 |
92 |
93 | class FixSTFTDimension(object):
94 | """Either pads or truncates in the time axis on the frequency domain, applied after stretching, time shifting etc."""
95 |
96 | def __call__(self, data):
97 | stft = data['stft']
98 | t_len = stft.shape[1]
99 | orig_t_len = data['stft_shape'][1]
100 | if t_len > orig_t_len:
101 | stft = stft[:, 0:orig_t_len]
102 | elif t_len < orig_t_len:
103 | stft = np.pad(stft, ((0, 0), (0, orig_t_len - t_len)), "constant")
104 |
105 | data['stft'] = stft
106 | return data
107 |
108 |
109 | class ToMelSpectrogramFromSTFT(object):
110 | """Creates the mel spectrogram from the short time fourier transform of a file. The result is a 32x32 matrix."""
111 |
112 | def __init__(self, n_mels=32):
113 | self.n_mels = n_mels
114 |
115 | def __call__(self, data):
116 | import librosa
117 |
118 | stft = data['stft']
119 | sample_rate = data['sample_rate']
120 | n_fft = data['n_fft']
121 | mel_basis = librosa.filters.mel(sample_rate, n_fft, self.n_mels)
122 | s = np.dot(mel_basis, np.abs(stft) ** 2.0)
123 | data['mel_spectrogram'] = librosa.power_to_db(s, ref=np.max)
124 | return data
125 |
126 |
127 | class DeleteSTFT(object):
128 |     """PyTorch doesn't like complex numbers; use this transform to remove the STFT after computing the mel spectrogram."""
129 |
130 | def __call__(self, data):
131 | del data['stft']
132 | return data
133 |
134 |
135 | class AudioFromSTFT(object):
136 | """Inverse short time fourier transform."""
137 |
138 | def __call__(self, data):
139 | import librosa
140 |
141 | stft = data['stft']
142 | data['istft_samples'] = librosa.core.istft(stft, dtype=data['samples'].dtype)
143 | return data
144 |
--------------------------------------------------------------------------------
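A minimal usage sketch for the STFT-domain transforms above (not part of the repository). The silent one-second clip, the torchvision Compose wrapper and the import path are assumptions for illustration; the dict keys match those produced by the STFT transform at the top of this file.

import numpy as np
import librosa
from torchvision.transforms import Compose
from datasets.google_commands.sft_transforms import (
    StretchAudioOnSTFT, TimeshiftAudioOnSTFT, FixSTFTDimension,
    ToMelSpectrogramFromSTFT, DeleteSTFT)

# Build the dict the transforms expect (normally produced by the STFT transform above);
# the silent one-second 16 kHz clip is just a stand-in.
samples = np.zeros(16000, dtype=np.float32)
data = {'samples': samples, 'sample_rate': 16000, 'n_fft': 2048, 'hop_length': 512}
data['stft'] = librosa.stft(samples, n_fft=2048, hop_length=512)
data['stft_shape'] = data['stft'].shape

augment = Compose([
    StretchAudioOnSTFT(max_scale=0.2),    # random phase-vocoder stretch, applied with p=0.5
    TimeshiftAudioOnSTFT(max_shift=8),    # random shift along the time axis, applied with p=0.5
    FixSTFTDimension(),                   # pad/truncate back to the original frame count
    ToMelSpectrogramFromSTFT(n_mels=32),  # 32-band log-mel spectrogram in dB
    DeleteSTFT(),                         # drop the complex STFT before batching
])
data = augment(data)
print(data['mel_spectrogram'].shape)      # (32, 32) for one second of 16 kHz audio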
/datasets/google_commands/transforms.py:
--------------------------------------------------------------------------------
1 | """Transforms on raw wav samples."""
2 |
3 | __author__ = 'Yuan Xu'
4 |
5 | import random
6 | import numpy as np
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 |
11 |
12 | def should_apply_transform(prob=0.5):
13 | """Transforms are only randomly applied with the given probability."""
14 | return random.random() < prob
15 |
16 |
17 | class LoadAudio(object):
18 | """Loads an audio into a numpy array."""
19 |
20 | def __init__(self, sample_rate=16000):
21 | self.sample_rate = sample_rate
22 |
23 | def __call__(self, data):
24 | import librosa
25 |
26 | path = data['path']
27 | if path:
28 | samples, sample_rate = librosa.load(path, self.sample_rate)
29 | else:
30 | # silence
31 | sample_rate = self.sample_rate
32 | samples = np.zeros(sample_rate, dtype=np.float32)
33 | data['samples'] = samples
34 | data['sample_rate'] = sample_rate
35 | return data
36 |
37 |
38 | class FixAudioLength(object):
39 | """Either pads or truncates an audio into a fixed length."""
40 |
41 | def __init__(self, time=1):
42 | self.time = time
43 |
44 | def __call__(self, data):
45 | samples = data['samples']
46 | sample_rate = data['sample_rate']
47 | length = int(self.time * sample_rate)
48 | if length < len(samples):
49 | data['samples'] = samples[:length]
50 | elif length > len(samples):
51 | data['samples'] = np.pad(samples, (0, length - len(samples)), "constant")
52 | return data
53 |
54 |
55 | class ChangeAmplitude(object):
56 | """Changes amplitude of an audio randomly."""
57 |
58 | def __init__(self, amplitude_range=(0.7, 1.1)):
59 | self.amplitude_range = amplitude_range
60 |
61 | def __call__(self, data):
62 | if not should_apply_transform():
63 | return data
64 |
65 | data['samples'] = data['samples'] * random.uniform(*self.amplitude_range)
66 | return data
67 |
68 |
69 | class ChangeSpeedAndPitchAudio(object):
70 | """Change the speed of an audio. This transform also changes the pitch of the audio."""
71 |
72 | def __init__(self, max_scale=0.2):
73 | self.max_scale = max_scale
74 |
75 | def __call__(self, data):
76 | if not should_apply_transform():
77 | return data
78 |
79 | samples = data['samples']
80 | sample_rate = data['sample_rate']
81 | scale = random.uniform(-self.max_scale, self.max_scale)
82 | speed_fac = 1.0 / (1 + scale)
83 | data['samples'] = np.interp(np.arange(0, len(samples), speed_fac), np.arange(0,len(samples)), samples).astype(np.float32)
84 | return data
85 |
86 |
87 | class StretchAudio(object):
88 | """Stretches an audio randomly."""
89 |
90 | def __init__(self, max_scale=0.2):
91 | self.max_scale = max_scale
92 |
93 | def __call__(self, data):
94 | import librosa
95 |
96 | if not should_apply_transform():
97 | return data
98 |
99 | scale = random.uniform(-self.max_scale, self.max_scale)
100 | data['samples'] = librosa.effects.time_stretch(data['samples'], 1+scale)
101 | return data
102 |
103 |
104 | class TimeshiftAudio(object):
105 | """Shifts an audio randomly."""
106 |
107 | def __init__(self, max_shift_seconds=0.2):
108 | self.max_shift_seconds = max_shift_seconds
109 |
110 | def __call__(self, data):
111 | if not should_apply_transform():
112 | return data
113 |
114 | samples = data['samples']
115 | sample_rate = data['sample_rate']
116 | max_shift = int(sample_rate * self.max_shift_seconds)  # randint needs integer bounds
117 | shift = random.randint(-max_shift, max_shift)
118 | a = -min(0, shift)
119 | b = max(0, shift)
120 | samples = np.pad(samples, (a, b), "constant")
121 | data['samples'] = samples[:len(samples) - a] if a else samples[b:]
122 | return data
123 |
124 |
125 | class AddBackgroundNoise(Dataset):
126 | """Adds a random background noise."""
127 |
128 | def __init__(self, bg_dataset, max_percentage=0.45):
129 | self.bg_dataset = bg_dataset
130 | self.max_percentage = max_percentage
131 |
132 | def __call__(self, data):
133 | if not should_apply_transform():
134 | return data
135 |
136 | samples = data['samples']
137 | noise = random.choice(self.bg_dataset)['samples']
138 | percentage = random.uniform(0, self.max_percentage)
139 | data['samples'] = samples * (1 - percentage) + noise * percentage
140 | return data
141 |
142 |
143 | class ToMelSpectrogram(object):
144 | """Creates the mel spectrogram from an audio. The result is a 32x32 matrix."""
145 |
146 | def __init__(self, n_mels=32):
147 | self.n_mels = n_mels
148 |
149 | def __call__(self, data):
150 | import librosa
151 |
152 | samples = data['samples']
153 | sample_rate = data['sample_rate']
154 | s = librosa.feature.melspectrogram(samples, sr=sample_rate, n_mels=self.n_mels)
155 | data['mel_spectrogram'] = librosa.power_to_db(s, ref=np.max)
156 | return data
157 |
158 |
159 | class ToTensor(object):
160 | """Converts into a tensor."""
161 |
162 | def __init__(self, np_name, tensor_name, normalize=None):
163 | self.np_name = np_name
164 | self.tensor_name = tensor_name
165 | self.normalize = normalize
166 |
167 | def __call__(self, data):
168 | tensor = torch.FloatTensor(data[self.np_name])
169 | if self.normalize is not None:
170 | mean, std = self.normalize
171 | tensor -= mean
172 | tensor /= std
173 | data[self.tensor_name] = tensor
174 | return (data['input'].unsqueeze(0) + 40) / 80  # assumes tensor_name == 'input'; adds a channel dim and rescales dB values (roughly [-80, 0]) to [-0.5, 0.5]
--------------------------------------------------------------------------------
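A minimal sketch of chaining the raw-wav transforms above into a training pipeline (not part of the repository). The wav path is a placeholder and torchvision Compose is assumed as the chaining mechanism.

from torchvision.transforms import Compose
from datasets.google_commands.transforms import (
    LoadAudio, ChangeAmplitude, ChangeSpeedAndPitchAudio, FixAudioLength,
    ToMelSpectrogram, ToTensor)

train_transform = Compose([
    LoadAudio(),                           # wav path -> float samples + sample rate
    ChangeAmplitude(),                     # random gain in [0.7, 1.1], applied with p=0.5
    ChangeSpeedAndPitchAudio(),            # random speed/pitch change, applied with p=0.5
    FixAudioLength(time=1),                # pad/truncate to exactly one second
    ToMelSpectrogram(n_mels=32),           # 32-band log-mel spectrogram in dB
    ToTensor('mel_spectrogram', 'input'),  # returns the rescaled, channel-expanded tensor
])

example = train_transform({'path': 'yes/0a7c2a8d_nohash_0.wav'})  # placeholder path
print(example.shape)  # torch.Size([1, 32, 32])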
/datasets/imagenet_a.py:
--------------------------------------------------------------------------------
1 | thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1,
2 | 13: 2, 14: -1, 15: 3, 16: -1, 17: 4, 18: -1, 19: -1, 20: -1, 21: -1, 22: 5, 23: 6, 24: -1,
3 | 25: -1, 26: -1, 27: 7, 28: -1, 29: -1, 30: 8, 31: -1, 32: -1, 33: -1, 34: -1, 35: -1, 36: -1,
4 | 37: 9, 38: -1, 39: 10, 40: -1, 41: -1, 42: 11, 43: -1, 44: -1, 45: -1, 46: -1, 47: 12, 48: -1,
5 | 49: -1, 50: 13, 51: -1, 52: -1, 53: -1, 54: -1, 55: -1, 56: -1, 57: 14, 58: -1, 59: -1, 60: -1,
6 | 61: -1, 62: -1, 63: -1, 64: -1, 65: -1, 66: -1, 67: -1, 68: -1, 69: -1, 70: 15, 71: 16, 72: -1,
7 | 73: -1, 74: -1, 75: -1, 76: 17, 77: -1, 78: -1, 79: 18, 80: -1, 81: -1, 82: -1, 83: -1, 84: -1,
8 | 85: -1, 86: -1, 87: -1, 88: -1, 89: 19, 90: 20, 91: -1, 92: -1, 93: -1, 94: 21, 95: -1, 96: 22,
9 | 97: 23, 98: -1, 99: 24, 100: -1, 101: -1, 102: -1, 103: -1, 104: -1, 105: 25, 106: -1, 107: 26,
10 | 108: 27, 109: -1, 110: 28, 111: -1, 112: -1, 113: 29, 114: -1, 115: -1, 116: -1, 117: -1,
11 | 118: -1, 119: -1, 120: -1, 121: -1, 122: -1, 123: -1, 124: 30, 125: 31, 126: -1, 127: -1,
12 | 128: -1, 129: -1, 130: 32, 131: -1, 132: 33, 133: -1, 134: -1, 135: -1, 136: -1, 137: -1,
13 | 138: -1, 139: -1, 140: -1, 141: -1, 142: -1, 143: 34, 144: 35, 145: -1, 146: -1, 147: -1,
14 | 148: -1, 149: -1, 150: 36, 151: 37, 152: -1, 153: -1, 154: -1, 155: -1, 156: -1, 157: -1,
15 | 158: -1, 159: -1, 160: -1, 161: -1, 162: -1, 163: -1, 164: -1, 165: -1, 166: -1, 167: -1,
16 | 168: -1, 169: -1, 170: -1, 171: -1, 172: -1, 173: -1, 174: -1, 175: -1, 176: -1, 177: -1,
17 | 178: -1, 179: -1, 180: -1, 181: -1, 182: -1, 183: -1, 184: -1, 185: -1, 186: -1, 187: -1,
18 | 188: -1, 189: -1, 190: -1, 191: -1, 192: -1, 193: -1, 194: -1, 195: -1, 196: -1, 197: -1,
19 | 198: -1, 199: -1, 200: -1, 201: -1, 202: -1, 203: -1, 204: -1, 205: -1, 206: -1, 207: 38,
20 | 208: -1, 209: -1, 210: -1, 211: -1, 212: -1, 213: -1, 214: -1, 215: -1, 216: -1, 217: -1,
21 | 218: -1, 219: -1, 220: -1, 221: -1, 222: -1, 223: -1, 224: -1, 225: -1, 226: -1, 227: -1,
22 | 228: -1, 229: -1, 230: -1, 231: -1, 232: -1, 233: -1, 234: 39, 235: 40, 236: -1, 237: -1,
23 | 238: -1, 239: -1, 240: -1, 241: -1, 242: -1, 243: -1, 244: -1, 245: -1, 246: -1, 247: -1,
24 | 248: -1, 249: -1, 250: -1, 251: -1, 252: -1, 253: -1, 254: 41, 255: -1, 256: -1, 257: -1,
25 | 258: -1, 259: -1, 260: -1, 261: -1, 262: -1, 263: -1, 264: -1, 265: -1, 266: -1, 267: -1,
26 | 268: -1, 269: -1, 270: -1, 271: -1, 272: -1, 273: -1, 274: -1, 275: -1, 276: -1, 277: 42,
27 | 278: -1, 279: -1, 280: -1, 281: -1, 282: -1, 283: 43, 284: -1, 285: -1, 286: -1, 287: 44,
28 | 288: -1, 289: -1, 290: -1, 291: 45, 292: -1, 293: -1, 294: -1, 295: 46, 296: -1, 297: -1,
29 | 298: 47, 299: -1, 300: -1, 301: 48, 302: -1, 303: -1, 304: -1, 305: -1, 306: 49, 307: 50,
30 | 308: 51, 309: 52, 310: 53, 311: 54, 312: -1, 313: 55, 314: 56, 315: 57, 316: -1, 317: 58,
31 | 318: -1, 319: 59, 320: -1, 321: -1, 322: -1, 323: 60, 324: 61, 325: -1, 326: 62, 327: 63,
32 | 328: -1, 329: -1, 330: 64, 331: -1, 332: -1, 333: -1, 334: 65, 335: 66, 336: 67, 337: -1,
33 | 338: -1, 339: -1, 340: -1, 341: -1, 342: -1, 343: -1, 344: -1, 345: -1, 346: -1, 347: 68,
34 | 348: -1, 349: -1, 350: -1, 351: -1, 352: -1, 353: -1, 354: -1, 355: -1, 356: -1, 357: -1,
35 | 358: -1, 359: -1, 360: -1, 361: 69, 362: -1, 363: 70, 364: -1, 365: -1, 366: -1, 367: -1,
36 | 368: -1, 369: -1, 370: -1, 371: -1, 372: 71, 373: -1, 374: -1, 375: -1, 376: -1, 377: -1,
37 | 378: 72, 379: -1, 380: -1, 381: -1, 382: -1, 383: -1, 384: -1, 385: -1, 386: 73, 387: -1,
38 | 388: -1, 389: -1, 390: -1, 391: -1, 392: -1, 393: -1, 394: -1, 395: -1, 396: -1, 397: 74,
39 | 398: -1, 399: -1, 400: 75, 401: 76, 402: 77, 403: -1, 404: 78, 405: -1, 406: -1, 407: 79,
40 | 408: -1, 409: -1, 410: -1, 411: 80, 412: -1, 413: -1, 414: -1, 415: -1, 416: 81, 417: 82,
41 | 418: -1, 419: -1, 420: 83, 421: -1, 422: -1, 423: -1, 424: -1, 425: 84, 426: -1, 427: -1,
42 | 428: 85, 429: -1, 430: 86, 431: -1, 432: -1, 433: -1, 434: -1, 435: -1, 436: -1, 437: 87,
43 | 438: 88, 439: -1, 440: -1, 441: -1, 442: -1, 443: -1, 444: -1, 445: 89, 446: -1, 447: -1,
44 | 448: -1, 449: -1, 450: -1, 451: -1, 452: -1, 453: -1, 454: -1, 455: -1, 456: 90, 457: 91,
45 | 458: -1, 459: -1, 460: -1, 461: 92, 462: 93, 463: -1, 464: -1, 465: -1, 466: -1, 467: -1,
46 | 468: -1, 469: -1, 470: 94, 471: -1, 472: 95, 473: -1, 474: -1, 475: -1, 476: -1, 477: -1,
47 | 478: -1, 479: -1, 480: -1, 481: -1, 482: -1, 483: 96, 484: -1, 485: -1, 486: 97, 487: -1,
48 | 488: 98, 489: -1, 490: -1, 491: -1, 492: 99, 493: -1, 494: -1, 495: -1, 496: 100, 497: -1,
49 | 498: -1, 499: -1, 500: -1, 501: -1, 502: -1, 503: -1, 504: -1, 505: -1, 506: -1, 507: -1,
50 | 508: -1, 509: -1, 510: -1, 511: -1, 512: -1, 513: -1, 514: 101, 515: -1, 516: 102, 517: -1,
51 | 518: -1, 519: -1, 520: -1, 521: -1, 522: -1, 523: -1, 524: -1, 525: -1, 526: -1, 527: -1,
52 | 528: 103, 529: -1, 530: 104, 531: -1, 532: -1, 533: -1, 534: -1, 535: -1, 536: -1, 537: -1,
53 | 538: -1, 539: 105, 540: -1, 541: -1, 542: 106, 543: 107, 544: -1, 545: -1, 546: -1, 547: -1,
54 | 548: -1, 549: 108, 550: -1, 551: -1, 552: 109, 553: -1, 554: -1, 555: -1, 556: -1, 557: 110,
55 | 558: -1, 559: -1, 560: -1, 561: 111, 562: 112, 563: -1, 564: -1, 565: -1, 566: -1, 567: -1,
56 | 568: -1, 569: 113, 570: -1, 571: -1, 572: 114, 573: 115, 574: -1, 575: 116, 576: -1, 577: -1,
57 | 578: -1, 579: 117, 580: -1, 581: -1, 582: -1, 583: -1, 584: -1, 585: -1, 586: -1, 587: -1,
58 | 588: -1, 589: 118, 590: -1, 591: -1, 592: -1, 593: -1, 594: -1, 595: -1, 596: -1, 597: -1,
59 | 598: -1, 599: -1, 600: -1, 601: -1, 602: -1, 603: -1, 604: -1, 605: -1, 606: 119, 607: 120,
60 | 608: -1, 609: 121, 610: -1, 611: -1, 612: -1, 613: -1, 614: 122, 615: -1, 616: -1, 617: -1,
61 | 618: -1, 619: -1, 620: -1, 621: -1, 622: -1, 623: -1, 624: -1, 625: -1, 626: 123, 627: 124,
62 | 628: -1, 629: -1, 630: -1, 631: -1, 632: -1, 633: -1, 634: -1, 635: -1, 636: -1, 637: -1,
63 | 638: -1, 639: -1, 640: 125, 641: 126, 642: 127, 643: 128, 644: -1, 645: -1, 646: -1, 647: -1,
64 | 648: -1, 649: -1, 650: -1, 651: -1, 652: -1, 653: -1, 654: -1, 655: -1, 656: -1, 657: -1,
65 | 658: 129, 659: -1, 660: -1, 661: -1, 662: -1, 663: -1, 664: -1, 665: -1, 666: -1, 667: -1,
66 | 668: 130, 669: -1, 670: -1, 671: -1, 672: -1, 673: -1, 674: -1, 675: -1, 676: -1, 677: 131,
67 | 678: -1, 679: -1, 680: -1, 681: -1, 682: 132, 683: -1, 684: 133, 685: -1, 686: -1, 687: 134,
68 | 688: -1, 689: -1, 690: -1, 691: -1, 692: -1, 693: -1, 694: -1, 695: -1, 696: -1, 697: -1,
69 | 698: -1, 699: -1, 700: -1, 701: 135, 702: -1, 703: -1, 704: 136, 705: -1, 706: -1, 707: -1,
70 | 708: -1, 709: -1, 710: -1, 711: -1, 712: -1, 713: -1, 714: -1, 715: -1, 716: -1, 717: -1,
71 | 718: -1, 719: 137, 720: -1, 721: -1, 722: -1, 723: -1, 724: -1, 725: -1, 726: -1, 727: -1,
72 | 728: -1, 729: -1, 730: -1, 731: -1, 732: -1, 733: -1, 734: -1, 735: -1, 736: 138, 737: -1,
73 | 738: -1, 739: -1, 740: -1, 741: -1, 742: -1, 743: -1, 744: -1, 745: -1, 746: 139, 747: -1,
74 | 748: -1, 749: 140, 750: -1, 751: -1, 752: 141, 753: -1, 754: -1, 755: -1, 756: -1, 757: -1,
75 | 758: 142, 759: -1, 760: -1, 761: -1, 762: -1, 763: 143, 764: -1, 765: 144, 766: -1, 767: -1,
76 | 768: 145, 769: -1, 770: -1, 771: -1, 772: -1, 773: 146, 774: 147, 775: -1, 776: 148, 777: -1,
77 | 778: -1, 779: 149, 780: 150, 781: -1, 782: -1, 783: -1, 784: -1, 785: -1, 786: 151, 787: -1,
78 | 788: -1, 789: -1, 790: -1, 791: -1, 792: 152, 793: -1, 794: -1, 795: -1, 796: -1, 797: 153,
79 | 798: -1, 799: -1, 800: -1, 801: -1, 802: 154, 803: 155, 804: 156, 805: -1, 806: -1, 807: -1,
80 | 808: -1, 809: -1, 810: -1, 811: -1, 812: -1, 813: 157, 814: -1, 815: 158, 816: -1, 817: -1,
81 | 818: -1, 819: -1, 820: 159, 821: -1, 822: -1, 823: 160, 824: -1, 825: -1, 826: -1, 827: -1,
82 | 828: -1, 829: -1, 830: -1, 831: 161, 832: -1, 833: 162, 834: -1, 835: 163, 836: -1, 837: -1,
83 | 838: -1, 839: 164, 840: -1, 841: -1, 842: -1, 843: -1, 844: -1, 845: 165, 846: -1, 847: 166,
84 | 848: -1, 849: -1, 850: 167, 851: -1, 852: -1, 853: -1, 854: -1, 855: -1, 856: -1, 857: -1,
85 | 858: -1, 859: 168, 860: -1, 861: -1, 862: 169, 863: -1, 864: -1, 865: -1, 866: -1, 867: -1,
86 | 868: -1, 869: -1, 870: 170, 871: -1, 872: -1, 873: -1, 874: -1, 875: -1, 876: -1, 877: -1,
87 | 878: -1, 879: 171, 880: 172, 881: -1, 882: -1, 883: -1, 884: -1, 885: -1, 886: -1, 887: -1,
88 | 888: 173, 889: -1, 890: 174, 891: -1, 892: -1, 893: -1, 894: -1, 895: -1, 896: -1, 897: 175,
89 | 898: -1, 899: -1, 900: 176, 901: -1, 902: -1, 903: -1, 904: -1, 905: -1, 906: -1, 907: 177,
90 | 908: -1, 909: -1, 910: -1, 911: -1, 912: -1, 913: 178, 914: -1, 915: -1, 916: -1, 917: -1,
91 | 918: -1, 919: -1, 920: -1, 921: -1, 922: -1, 923: -1, 924: 179, 925: -1, 926: -1, 927: -1,
92 | 928: -1, 929: -1, 930: -1, 931: -1, 932: 180, 933: 181, 934: 182, 935: -1, 936: -1, 937: 183,
93 | 938: -1, 939: -1, 940: -1, 941: -1, 942: -1, 943: 184, 944: -1, 945: 185, 946: -1, 947: 186,
94 | 948: -1, 949: -1, 950: -1, 951: 187, 952: -1, 953: -1, 954: 188, 955: -1, 956: 189, 957: 190,
95 | 958: -1, 959: 191, 960: -1, 961: -1, 962: -1, 963: -1, 964: -1, 965: -1, 966: -1, 967: -1,
96 | 968: -1, 969: -1, 970: -1, 971: 192, 972: 193, 973: -1, 974: -1, 975: -1, 976: -1, 977: -1,
97 | 978: -1, 979: -1, 980: 194, 981: 195, 982: -1, 983: -1, 984: 196, 985: -1, 986: 197, 987: 198,
98 | 988: 199, 989: -1, 990: -1, 991: -1, 992: -1, 993: -1, 994: -1, 995: -1, 996: -1, 997: -1,
99 | 998: -1, 999: -1}
100 |
101 | indices_in_1k = [k for k in thousand_k_to_200 if thousand_k_to_200[k] != -1]
102 |
--------------------------------------------------------------------------------
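The dictionary above maps each of the 1000 ImageNet class indices to its ImageNet-A index, or -1 when the class is not in ImageNet-A. A minimal sketch of the usual evaluation pattern (not from the repository), with a random tensor standing in for real model output:

import torch
from datasets.imagenet_a import indices_in_1k

logits_1k = torch.randn(4, 1000)          # stand-in for model(images) on an ImageNet-A batch
logits_200 = logits_1k[:, indices_in_1k]  # keep only the 200 classes present in ImageNet-A
preds = logits_200.argmax(dim=1)          # predictions in the 200-class label space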
/datasets/imagenet_hdf5.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import pickle
4 |
5 | import h5py
6 | from PIL import Image
7 | from torchvision.datasets import VisionDataset
8 |
9 |
10 | class ImageNetHDF5(VisionDataset):
11 | def __init__(self, root, cache_size=500, transform=None):
12 | super(ImageNetHDF5, self).__init__(root, transform=transform, target_transform=None)
13 |
14 | self.dest = pickle.load(open(os.path.join(root, 'dest.p'), 'rb'))
15 | self.cache = {}
16 | self.cache_size = cache_size
17 |
18 | targets = sorted(list(filter(lambda f: '.hdf5' in f, os.listdir(root))))
19 | self.targets = {f[:-5]: i for i, f in enumerate(targets)}
20 | self.fill_cache()
21 |
22 | def load(self, file, i):
23 | with h5py.File(os.path.join(self.root, file + '.hdf5'), 'r') as f:
24 | return f['data'][i]
25 |
26 | def fill_cache(self):
27 | print('Filling cache')
28 | files = (f[:-5] for f in list(filter(lambda f: '.hdf5' in f, os.listdir(self.root)))[:self.cache_size])
29 | for file in files:
30 | with h5py.File(os.path.join(self.root, file + '.hdf5'), 'r') as f:
31 | self.cache[file] = list(f['data'])
32 | print('Done')
33 |
34 | def load_from_cache(self, file, i):
35 | if file in self.cache:
36 | return self.cache[file][i]
37 | return self.load(file, i)
38 |
39 | def __getitem__(self, index):
40 | dest, i = self.dest[index]
41 |
42 | sample = self.load_from_cache(dest, i)
43 |
44 | sample = Image.open(io.BytesIO(sample))
45 | sample = sample.convert('RGB')
46 |
47 | if self.transform is not None:
48 | sample = self.transform(sample)
49 |
50 | return sample, self.targets[dest]
51 |
52 | def __len__(self):
53 | return len(self.dest)
54 |
--------------------------------------------------------------------------------
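A hypothetical usage sketch for ImageNetHDF5 (the root path is a placeholder): the root directory is expected to contain one <class>.hdf5 file per class plus the dest.p index pickle mapping dataset indices to (file, position) pairs.

from torch.utils.data import DataLoader
from torchvision import transforms
from datasets.imagenet_hdf5 import ImageNetHDF5

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
dataset = ImageNetHDF5('/path/to/imagenet_hdf5/train', cache_size=500, transform=transform)
loader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=8)
images, targets = next(iter(loader))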
/datasets/tiny_imagenet.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import os
3 | from torchvision.datasets.folder import default_loader
4 |
5 |
6 | class TinyImageNet(Dataset):
7 | def __init__(self, root, train=True, transform=None):
8 | super().__init__()
9 | self.root = root
10 | self.transform = transform
11 | self.words = self.parse_classes()
12 |
13 | if train:
14 | self.class_path = os.path.join(root, 'train')
15 | self.img_labels = self.parse_train()
16 | else:
17 | self.class_path = os.path.join(root, 'val')
18 | self.img_labels = self.parse_val_labels()
19 |
20 | def parse_classes(self):
21 | words_path = os.path.join(self.root, 'wnids.txt')
22 | words = {}
23 | i = 0
24 | with open(words_path, 'r') as f:
25 | for w in f:
26 | w = w.strip('\n')
27 | word_label = w.split('\t')[0]
28 | words[word_label] = i
29 | i += 1
30 | return words
31 |
32 | def parse_val_labels(self):
33 | val_annot = os.path.join(self.root, 'val', 'val_annotations.txt')
34 | img_label = []
35 | with open(val_annot, 'r') as f:
36 | for line in f:
37 | line = line.strip('\n')
38 | img, word, *_ = line.split('\t')
39 | img = os.path.join(self.root, 'val', 'images', img)
40 | img_label.append((img, self.words[word]))
41 | return img_label
42 |
43 | def parse_train(self):
44 | img_labels = []
45 | for c in os.listdir(self.class_path):
46 | label = self.words[c]
47 | images_path = os.path.join(self.root, 'train', c, 'images')
48 | for im in os.listdir(images_path):
49 | im_path = os.path.join(images_path, im)
50 | img_labels.append((im_path, label))
51 | return img_labels
52 |
53 | def __getitem__(self, index):
54 | img, label = self.img_labels[index]
55 | pil_img = default_loader(img)
56 |
57 | if self.transform is not None:
58 | pil_img = self.transform(pil_img)
59 |
60 | return pil_img, label
61 |
62 | def __len__(self) -> int:
63 | return len(self.img_labels)
64 |
--------------------------------------------------------------------------------
/datasets/toxic.py:
--------------------------------------------------------------------------------
1 | import spacy
2 | import re
3 |
4 | import pandas as pd
5 | import os
6 | import torch
7 |
8 | from torchbearer import Callback
9 | import torchbearer
10 |
11 |
12 | def toxic_ds(args):
13 | from torchtext import data
14 | import torchtext
15 | tok = spacy.load('en')
16 | stopwords = spacy.lang.en.stop_words.STOP_WORDS
17 | stopwords.update(['wikipedia', 'article', 'articles', 'im', 'page'])
18 |
19 | def spacy_tok(x):
20 | x = re.sub(r'[^a-zA-Z\s]', '', x)
21 | x = re.sub(r'[\n]', ' ', x)
22 | return [t.text for t in tok.tokenizer(x)]
23 |
24 | TEXT = data.Field(lower=True, tokenize=spacy_tok, eos_token='EOS', stop_words=stopwords, include_lengths=True)
25 | LABEL = data.Field(sequential=False, use_vocab=False, pad_token=None, unk_token=None)
26 | dataFields = [("id", None), ("comment_text", TEXT), ("toxic", LABEL), ("severe_toxic", LABEL), ("threat", LABEL),
27 | ("obscene", LABEL), ("insult", LABEL), ("identity_hate", LABEL)]
28 | trainset_path = os.path.join(args.dataset_path, 'train.csv')
29 | train = data.TabularDataset(path=trainset_path, format='csv', fields=dataFields, skip_header=True)
30 |
31 | TEXT.build_vocab(train, vectors='fasttext.simple.300d')
32 | traindl = torchtext.data.BucketIterator(dataset=train, batch_size=args.batch_size,
33 | sort_key=lambda x: len(x.comment_text), device=torch.device(args.device),
34 | sort_within_batch=True)
35 |
36 | test_csv_path = os.path.join(args.dataset_path, 'test.csv')
37 | test_labels_path = os.path.join(args.dataset_path, 'test_labels.csv')
38 | test_set_path = os.path.join(args.dataset_path, 'test_set.csv')
39 | if not os.path.isfile(test_set_path):
40 | a = pd.read_csv(test_csv_path)
41 | b = pd.read_csv(test_labels_path)
42 | b = b.dropna(axis=1)
43 | merged = a.merge(b, on='id')
44 | merged = merged[merged['toxic'] >= 0]
45 | merged.to_csv(test_set_path, index=False)
46 |
47 | testset = data.TabularDataset(path=test_set_path, format='csv', fields=dataFields, skip_header=True)
48 | testdl = torchtext.data.BucketIterator(dataset=testset, batch_size=64, sort_key=lambda x: len(x.comment_text),
49 | device=torch.device(args.device), sort_within_batch=True)
50 | vectors = train.fields['comment_text'].vocab.vectors.to(args.device)
51 |
52 | traindl, testdl = BatchGenerator(traindl), BatchGenerator(testdl)
53 | traindl.vectors = vectors
54 | traindl.ntokens = len(TEXT.vocab)
55 |
56 | return traindl, None, testdl
57 |
58 |
59 | class BatchGenerator:
60 | def __init__(self, dl):
61 | self.dl = dl
62 | self.yFields = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
63 | self.x = 'comment_text'
64 |
65 | def __len__(self):
66 | return len(self.dl)
67 |
68 | def __iter__(self):
69 | for batch in self.dl:
70 | X = list(getattr(batch, self.x))
71 | X = X[0].permute(1, 0)
72 | y = torch.transpose(torch.stack([getattr(batch, y) for y in self.yFields]), 0, 1)
73 | yield (X, y)
74 |
75 |
76 | class ToxicHelper(Callback):
77 | def __init__(self, to_float=True):
78 | self.convert = (lambda x: x.float()) if to_float else (lambda x: x)
79 |
80 | def on_start(self, state):
81 | super().on_start(state)
82 | vectors = state[torchbearer.TRAIN_GENERATOR].vectors
83 | ntokens = state[torchbearer.TRAIN_GENERATOR].ntokens
84 | state[torchbearer.MODEL].init_embedding(vectors, ntokens, state[torchbearer.DEVICE])
85 |
86 | def on_sample(self, state):
87 | state[torchbearer.Y_TRUE] = self.convert(state[torchbearer.Y_TRUE])
88 |
89 | state[torchbearer.X] = state[torchbearer.MODEL].embed(state[torchbearer.X])
90 |
91 | def on_sample_validation(self, state):
92 | state[torchbearer.Y_TRUE] = self.convert(state[torchbearer.Y_TRUE])
93 |
94 | state[torchbearer.X] = state[torchbearer.MODEL].embed(state[torchbearer.X])
95 |
--------------------------------------------------------------------------------
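A minimal sketch of calling toxic_ds (not part of the repository). It only needs an object exposing dataset_path, batch_size and device; the path is a placeholder for the directory holding train.csv, test.csv and test_labels.csv.

from types import SimpleNamespace
from datasets.toxic import toxic_ds

args = SimpleNamespace(dataset_path='/path/to/toxic-comments', batch_size=64, device='cuda')
trainloader, _, testloader = toxic_ds(args)  # train iterator, no validation split, test iterator
x, y = next(iter(trainloader))               # x: [batch, seq_len] token ids, y: [batch, 6] labels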
/datasets/toxic_bert.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pandas as pd
4 | import torch
5 | from torch.nn.utils.rnn import pad_sequence
6 | from torch.utils.data import Dataset
7 | from torchtext.data import Iterator, batch, pool
8 | from tqdm import tqdm
9 |
10 |
11 | class ToxicDataset(Dataset):
12 | def __init__(self, dataframe, bert_model='bert-base-cased'):
13 | from transformers import BertTokenizer
14 | self.tokenizer = BertTokenizer.from_pretrained(bert_model)
15 | self.pad_idx = self.tokenizer.pad_token_id
16 |
17 | self.X = []
18 | self.Y = []
19 | for _, row in tqdm(dataframe.iterrows()):
20 | x, y = self.row_to_tensor(self.tokenizer, row)
21 | self.X.append(x)
22 | self.Y.append(y)
23 |
24 | @staticmethod
25 | def row_to_tensor(tokenizer, row):
26 | tokens = tokenizer.encode(row["comment_text"], add_special_tokens=True, max_length=128)
27 |
28 | x = torch.LongTensor(tokens)
29 | y = torch.FloatTensor(row[["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]])
30 | return x, y
31 |
32 | def __len__(self):
33 | return len(self.X)
34 |
35 | def __getitem__(self, index):
36 | return self.X[index], self.Y[index]
37 |
38 |
39 | class NoBatchBucketIterator(Iterator):
40 | """Defines an iterator that batches examples of similar lengths together.
41 |
42 | Minimizes amount of padding needed while producing freshly shuffled
43 | batches for each new epoch. See pool for the bucketing procedure used.
44 | """
45 |
46 | def create_batches(self):
47 | if self.sort:
48 | self.batches = batch(self.data(), self.batch_size,
49 | self.batch_size_fn)
50 | else:
51 | self.batches = pool(self.data(), self.batch_size,
52 | self.sort_key, self.batch_size_fn,
53 | random_shuffler=self.random_shuffler,
54 | shuffle=self.shuffle,
55 | sort_within_batch=self.sort_within_batch)
56 |
57 | def __iter__(self):
58 | while True:
59 | self.init_epoch()
60 | for idx, minibatch in enumerate(self.batches):
61 | # fast-forward if loaded from state
62 | if self._iterations_this_epoch > idx:
63 | continue
64 | self.iterations += 1
65 | self._iterations_this_epoch += 1
66 | if self.sort_within_batch:
67 | # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
68 | # be sorted by decreasing order, which requires reversing
69 | # relative to typical sort keys
70 | if self.sort:
71 | minibatch.reverse()
72 | else:
73 | minibatch.sort(key=self.sort_key, reverse=True)
74 |
75 | x, y = list(zip(*minibatch))
76 | x = pad_sequence(x, batch_first=True, padding_value=0)
77 | y = torch.stack(y)
78 | yield x, y
79 | if not self.repeat:
80 | return
81 |
82 |
83 | def toxic_bert(args):
84 | trainset_path = os.path.join(args.dataset_path, 'train.csv')
85 | test_csv_path = os.path.join(args.dataset_path, 'test.csv')
86 | test_labels_path = os.path.join(args.dataset_path, 'test_labels.csv')
87 | testset_path = os.path.join(args.dataset_path, 'test_set.csv')
88 | if not os.path.isfile(testset_path):
89 | a = pd.read_csv(test_csv_path)
90 | b = pd.read_csv(test_labels_path)
91 | b = b.dropna(axis=1)
92 | merged = a.merge(b, on='id')
93 | merged = merged[merged['toxic'] >= 0]
94 | merged.to_csv(testset_path, index=False)
95 |
96 | train_df = pd.read_csv(trainset_path)
97 | test_df = pd.read_csv(testset_path)
98 |
99 | train_dataset = ToxicDataset(train_df)
100 | test_dataset = ToxicDataset(test_df)
101 |
102 | train_loader = NoBatchBucketIterator(dataset=train_dataset, batch_size=args.batch_size,
103 | sort_key=lambda x: x[0].size(0),
104 | device=torch.device(args.device), sort_within_batch=True)
105 | test_loader = NoBatchBucketIterator(dataset=test_dataset, batch_size=args.batch_size,
106 | sort_key=lambda x: x[0].size(0),
107 | device=torch.device(args.device), sort_within_batch=True)
108 |
109 | return train_loader, None, test_loader
110 |
--------------------------------------------------------------------------------
/experiments/bengali_experiment.sh:
--------------------------------------------------------------------------------
1 |
2 | # Usage: `bash bengali_experiment.sh (msda_mode) (type) (dataset path)`
3 | # where msda_mode is one of [fmix, mixup, None]
4 | # type is one of [r, c, v]
5 | # For multiple GPU, add --parallel=True
6 |
7 | python ../trainer.py --dataset bengali_${2} --fold 0 --model se_resnext50_32x4d --epoch 100 --schedule 50 75 --batch-size 512 --lr=0.1 --dataset-path=$3 --msda-mode=$1
8 |
--------------------------------------------------------------------------------
/experiments/cifar_experiment.sh:
--------------------------------------------------------------------------------
1 |
2 | # Usage: `bash cifar_experiment.sh (cifar) (model) (msda_mode) (dataset path)`
3 | # where:
4 | # cifar is one of [cifar10, cifar100]
5 | # model is one of [resnet, wrn, densenet, pyramidnet]
6 | # msda_mode is one of [fmix, mixup, None]
7 | # For multiple GPU, add --parallel=True
8 |
9 | if [ "$1" == "cifar10" ]
10 | then
11 | ds=cifar10
12 | fi
13 | if [ "$1" == "cifar100" ]
14 | then
15 | ds=cifar100
16 | fi
17 |
18 | if [ "$2" == "resnet" ]
19 | then
20 | model=ResNet18
21 | epoch=200
22 | schedule=(100 150)
23 | bs=128
24 | cosine=False
25 | fi
26 |
27 | if [ "$2" == "wrn" ]
28 | then
29 | model=wrn
30 | epoch=200
31 | schedule=(100 150)
32 | bs=128
33 | cosine=False
34 | fi
35 |
36 | if [ "$2" == "densenet" ]
37 | then
38 | model=DenseNet190
39 | epoch=300
40 | schedule=(100 150 225)
41 | bs=32
42 | cosine=False
43 | fi
44 |
45 | if [ "$2" == "pyramidnet" ]
46 | then
47 | model=aa_PyramidNet
48 | epoch=1800
49 | schedule=2000
50 | bs=64
51 | cosine=True
52 | fi
53 |
54 | python ../trainer.py --dataset=$ds --model=$model --epoch=$epoch --schedule ${schedule[@]} --lr=0.1 --dataset-path=$4 --msda-mode=$3 --batch-size=$bs --cosine-scheduler=$cosine
55 |
--------------------------------------------------------------------------------
/experiments/fashion_experiments.sh:
--------------------------------------------------------------------------------
1 |
2 | # Usage: `bash fashion_experiments.sh (model) (msda_mode) (dataset path)`
3 | # where:
4 | # model is one of [resnet, wrn, densenet]
5 | # msda_mode is one of [fmix, mixup, None]
6 | # For multiple GPU, add --parallel=True
7 |
8 | if [ "$1" == "resnet" ]
9 | then
10 | model=ResNet18
11 | epoch=200
12 | schedule=(100 150)
13 | bs=128
14 | fi
15 |
16 | if [ "$1" == "wrn" ]
17 | then
18 | model=wrn
19 | epoch=300
20 | schedule=(100 150 225)
21 | bs=32
22 | fi
23 |
24 | if [ "$1" == "densenet" ]
25 | then
26 | model=DenseNet190
27 | epoch=300
28 | schedule=(100 150 225)
29 | bs=32
30 | fi
31 |
32 | python ../trainer.py --dataset=fashion --model=$model --epoch=$epoch --schedule ${schedule[@]} --lr=0.1 --dataset-path=$3 --msda-mode=$2 --batch-size=$bs
33 |
--------------------------------------------------------------------------------
/experiments/google_commands_experiment.sh:
--------------------------------------------------------------------------------
1 |
2 | # Usage: `bash google_commands_experiment.sh (msda_mode) (dataset path) (alpha)`
3 | # where msda_mode is one of [fmix, mixup, None]
4 | # For multiple GPU, add --parallel=True
5 |
6 | python ../trainer.py --dataset=commands --epoch=90 --schedule 30 60 80 --lr=0.01 --dataset-path=$2 --msda-mode=$1 --alpha=$3
7 |
--------------------------------------------------------------------------------
/experiments/imagenet_experiment.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Usage: `bash imagenet_experiment.sh (msda_mode) (dataset path)`
4 | # where msda_mode is one of [fmix, mixup, None]
5 | # For multiple GPU, add --parallel=True
6 |
7 | python ../trainer.py --dataset=imagenet --epoch=90 --model=torch_resnet101 --schedule 30 60 80 --batch-size=256 --lr=0.4 --lr-warmup=True --dataset-path=$2 --msda-mode=$1
8 |
--------------------------------------------------------------------------------
/experiments/modelnet_experiment.sh:
--------------------------------------------------------------------------------
1 |
2 | # Usage: `bash modelnet_experiment.sh (msda_mode) (dataset path)`
3 | # where msda_mode is one of [fmix, mixup, None]
4 | # For multiple GPU, add --parallel=True
5 |
6 | python ../trainer.py --dataset=modelnet --epoch=50 --schedule 10 20 30 40 --lr=0.001 --dataset-path=$2 --msda-mode=$1 --batch-size=16
7 |
--------------------------------------------------------------------------------
/experiments/tiny_imagenet_experiment.sh:
--------------------------------------------------------------------------------
1 |
2 | # Usage: `bash tiny_imagenet_experiment.sh (msda_mode) (dataset path)`
3 | # where msda_mode is one of [fmix, mixup, None]
4 | # For multiple GPU, add --parallel=True
5 |
6 | python ../trainer.py --dataset=tinyimagenet --epoch=200 --schedule 150 180 --batch-size=128 --lr=0.1 --dataset-path=$2 --msda-mode=$1
7 |
--------------------------------------------------------------------------------
/experiments/toxic_experiment.sh:
--------------------------------------------------------------------------------
1 |
2 | # Usage: `bash toxic_experiment.sh (msda_mode) (dataset path)`
3 | # where msda_mode is one of [fmix, mixup, None]
4 | # For multiple GPU, add --parallel=True
5 |
6 | python ../trainer.py --dataset=toxic --epoch=10 --batch-size=64 --lr=1e-4 --dataset-path=$2 --msda-mode=$1
7 |
--------------------------------------------------------------------------------
/fmix.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import numpy as np
5 | from scipy.stats import beta
6 |
7 |
8 | def fftfreqnd(h, w=None, z=None):
9 | """ Get bin values for discrete fourier transform of size (h, w, z)
10 |
11 | :param h: Required, first dimension size
12 | :param w: Optional, second dimension size
13 | :param z: Optional, third dimension size
14 | """
15 | fz = fx = 0
16 | fy = np.fft.fftfreq(h)
17 |
18 | if w is not None:
19 | fy = np.expand_dims(fy, -1)
20 |
21 | if w % 2 == 1:
22 | fx = np.fft.fftfreq(w)[: w // 2 + 2]
23 | else:
24 | fx = np.fft.fftfreq(w)[: w // 2 + 1]
25 |
26 | if z is not None:
27 | fy = np.expand_dims(fy, -1)
28 | if z % 2 == 1:
29 | fz = np.fft.fftfreq(z)[:, None]
30 | else:
31 | fz = np.fft.fftfreq(z)[:, None]
32 |
33 | return np.sqrt(fx * fx + fy * fy + fz * fz)
34 |
35 |
36 | def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
37 | """ Samples a fourier image with given size and frequencies decayed by decay power
38 |
39 | :param freqs: Bin values for the discrete fourier transform
40 | :param decay_power: Decay power for frequency decay prop 1/f**d
41 | :param ch: Number of channels for the resulting mask
42 | :param h: Required, first dimension size
43 | :param w: Optional, second dimension size
44 | :param z: Optional, third dimension size
45 | """
46 | scale = np.ones(1) / (np.maximum(freqs, np.array([1. / max(w, h, z)])) ** decay_power)
47 |
48 | param_size = [ch] + list(freqs.shape) + [2]
49 | param = np.random.randn(*param_size)
50 |
51 | scale = np.expand_dims(scale, -1)[None, :]
52 |
53 | return scale * param
54 |
55 |
56 | def make_low_freq_image(decay, shape, ch=1):
57 | """ Sample a low frequency image from fourier space
58 |
59 | :param decay: Decay power for frequency decay prop 1/f**d
60 | :param shape: Shape of desired mask, list up to 3 dims
61 | :param ch: Number of channels for desired mask
62 | """
63 | freqs = fftfreqnd(*shape)
64 | spectrum = get_spectrum(freqs, decay, ch, *shape)#.reshape((1, *shape[:-1], -1))
65 | spectrum = spectrum[:, 0] + 1j * spectrum[:, 1]
66 | mask = np.real(np.fft.irfftn(spectrum, shape))
67 |
68 | if len(shape) == 1:
69 | mask = mask[:1, :shape[0]]
70 | if len(shape) == 2:
71 | mask = mask[:1, :shape[0], :shape[1]]
72 | if len(shape) == 3:
73 | mask = mask[:1, :shape[0], :shape[1], :shape[2]]
74 |
75 | # normalise the mask into the range [0, 1]
76 | mask = (mask - mask.min())
77 | mask = mask / mask.max()
78 | return mask
79 |
80 |
81 | def sample_lam(alpha, reformulate=False):
82 | """ Sample a lambda from symmetric beta distribution with given alpha
83 |
84 | :param alpha: Alpha value for beta distribution
85 | :param reformulate: If True, uses the reformulation of [1].
86 | """
87 | if reformulate:
88 | lam = beta.rvs(alpha+1, alpha)
89 | else:
90 | lam = beta.rvs(alpha, alpha)
91 |
92 | return lam
93 |
94 |
95 | def binarise_mask(mask, lam, in_shape, max_soft=0.0):
96 | """ Binarises a given low frequency image such that it has mean lambda.
97 |
98 | :param mask: Low frequency image, usually the result of `make_low_freq_image`
99 | :param lam: Mean value of final mask
100 | :param in_shape: Shape of inputs
101 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
102 | :return:
103 | """
104 | idx = mask.reshape(-1).argsort()[::-1]
105 | mask = mask.reshape(-1)
106 | num = math.ceil(lam * mask.size) if random.random() > 0.5 else math.floor(lam * mask.size)
107 |
108 | eff_soft = max_soft
109 | if max_soft > lam or max_soft > (1-lam):
110 | eff_soft = min(lam, 1-lam)
111 |
112 | soft = int(mask.size * eff_soft)
113 | num_low = num - soft
114 | num_high = num + soft
115 |
116 | mask[idx[:num_high]] = 1
117 | mask[idx[num_low:]] = 0
118 | mask[idx[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low))
119 |
120 | mask = mask.reshape((1, *in_shape))
121 | return mask
122 |
123 |
124 | def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
125 | """ Samples a mean lambda from beta distribution parametrised by alpha, creates a low frequency image and binarises
126 | it based on this lambda
127 |
128 | :param alpha: Alpha value for beta distribution from which to sample mean of mask
129 | :param decay_power: Decay power for frequency decay prop 1/f**d
130 | :param shape: Shape of desired mask, list up to 3 dims
131 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
132 | :param reformulate: If True, uses the reformulation of [1].
133 | """
134 | if isinstance(shape, int):
135 | shape = (shape,)
136 |
137 | # Choose lambda
138 | lam = sample_lam(alpha, reformulate)
139 |
140 | # Make mask, get mean / std
141 | mask = make_low_freq_image(decay_power, shape)
142 | mask = binarise_mask(mask, lam, shape, max_soft)
143 |
144 | return lam, mask
145 |
146 |
147 | def sample_and_apply(x, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
148 | """
149 |
150 | :param x: Image batch on which to apply fmix of shape [b, c, shape*]
151 | :param alpha: Alpha value for beta distribution from which to sample mean of mask
152 | :param decay_power: Decay power for frequency decay prop 1/f**d
153 | :param shape: Shape of desired mask, list up to 3 dims
154 | :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
155 | :param reformulate: If True, uses the reformulation of [1].
156 | :return: mixed input, permutation indices, lambda value of mix
157 | """
158 | lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
159 | index = np.random.permutation(x.shape[0])
160 |
161 | x1, x2 = x * mask, x[index] * (1-mask)
162 | return x1+x2, index, lam
163 |
164 |
165 | class FMixBase:
166 | r""" FMix augmentation
167 |
168 | Args:
169 | decay_power (float): Decay power for frequency decay prop 1/f**d
170 | alpha (float): Alpha value for beta distribution from which to sample mean of mask
171 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
172 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
173 | reformulate (bool): If True, uses the reformulation of [1].
174 | """
175 |
176 | def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False):
177 | super().__init__()
178 | self.decay_power = decay_power
179 | self.reformulate = reformulate
180 | self.size = size
181 | self.alpha = alpha
182 | self.max_soft = max_soft
183 | self.index = None
184 | self.lam = None
185 |
186 | def __call__(self, x):
187 | raise NotImplementedError
188 |
189 | def loss(self, *args, **kwargs):
190 | raise NotImplementedError
191 |
--------------------------------------------------------------------------------
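A minimal sketch of the mask-sampling API in fmix.py (not part of the repository), using a random NumPy batch as a stand-in for real images:

import numpy as np
from fmix import sample_mask, sample_and_apply

x = np.random.rand(8, 3, 32, 32).astype(np.float32)   # hypothetical batch of shape [b, c, h, w]

# Sample lam ~ Beta(alpha, alpha) and a binary mask whose mean is approximately lam
lam, mask = sample_mask(alpha=1, decay_power=3, shape=(32, 32))
print(lam, mask.shape)                                 # mask has shape (1, 32, 32)

# Or mix the batch directly against a random permutation of itself
x_mixed, index, lam = sample_and_apply(x, alpha=1, decay_power=3, shape=(32, 32))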
/fmix_3d.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ecs-vlc/FMix/e5991dca018882734c8ea63599f10dfbe67fa0ae/fmix_3d.gif
--------------------------------------------------------------------------------
/fmix_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ecs-vlc/FMix/e5991dca018882734c8ea63599f10dfbe67fa0ae/fmix_example.png
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | dependencies = ['torch', 'torchvision']
2 |
3 | from torch.hub import load_state_dict_from_url
4 |
5 |
6 | def _preact_resnet18(msda='fmix', pretrained=False, repeat=0, *args, **kwargs):
7 | from models import ResNet18
8 | model = ResNet18(*args, **kwargs)
9 |
10 | if pretrained:
11 | state = load_state_dict_from_url(
12 | f'http://marc.ecs.soton.ac.uk/pytorch-models/cifar10/preact-resnet18/cifar10_preact_resnet18_{msda}_{repeat}.pt', progress=True)
13 | model.load_state_dict(state)
14 |
15 | return model
16 |
17 |
18 | def _resnet101(msda='fmix', pretrained=False, *args, **kwargs):
19 | from torchvision.models.resnet import resnet101
20 | model = resnet101(*args, **kwargs)
21 |
22 | if pretrained:
23 | state = load_state_dict_from_url(
24 | f'http://marc.ecs.soton.ac.uk/pytorch-models/imagenet/resnet101/imagenet_resnet101_{msda}.pt', progress=True)
25 | model.load_state_dict(state)
26 |
27 | return model
28 |
29 |
30 | def _pyramidnet(msda='fmix', pretrained=False, *args, **kwargs):
31 | from models import aa_PyramidNet
32 | model = aa_PyramidNet(*args, **kwargs)
33 |
34 | if pretrained:
35 | state = load_state_dict_from_url(
36 | f'http://marc.ecs.soton.ac.uk/pytorch-models/cifar10/pyramidnet/cifar10_pyramidnet_{msda}.pt', progress=True)
37 | model.load_state_dict(state)
38 |
39 | return model
40 |
41 |
42 | def preact_resnet18_cifar10_baseline(pretrained=False, repeat=0, *args, **kwargs):
43 | return _preact_resnet18('baseline', pretrained, repeat, *args, **kwargs)
44 |
45 |
46 | def preact_resnet18_cifar10_fmix(pretrained=False, repeat=0, *args, **kwargs):
47 | return _preact_resnet18('fmix', pretrained, repeat, *args, **kwargs)
48 |
49 |
50 | def preact_resnet18_cifar10_mixup(pretrained=False, repeat=0, *args, **kwargs):
51 | return _preact_resnet18('mixup', pretrained, repeat, *args, **kwargs)
52 |
53 |
54 | def preact_resnet18_cifar10_cutmix(pretrained=False, repeat=0, *args, **kwargs):
55 | return _preact_resnet18('cutmix', pretrained, repeat, *args, **kwargs)
56 |
57 |
58 | def preact_resnet18_cifar10_fmixplusmixup(pretrained=False, repeat=0, *args, **kwargs):
59 | return _preact_resnet18('fmixplusmixup', pretrained, repeat, *args, **kwargs)
60 |
61 |
62 | def pyramidnet_cifar10_baseline(pretrained=False, *args, **kwargs):
63 | return _pyramidnet('baseline', pretrained, *args, **kwargs)
64 |
65 |
66 | def pyramidnet_cifar10_fmix(pretrained=False, *args, **kwargs):
67 | return _pyramidnet('fmix', pretrained, *args, **kwargs)
68 |
69 |
70 | def pyramidnet_cifar10_mixup(pretrained=False, *args, **kwargs):
71 | return _pyramidnet('mixup', pretrained, *args, **kwargs)
72 |
73 |
74 | def pyramidnet_cifar10_cutmix(pretrained=False, *args, **kwargs):
75 | return _pyramidnet('cutmix', pretrained, *args, **kwargs)
76 |
77 |
78 | def renset101_imagenet_baseline(pretrained=False, *args, **kwargs):
79 | return _resnet101('baseline', pretrained, *args, **kwargs)
80 |
81 |
82 | def renset101_imagenet_fmix(pretrained=False, *args, **kwargs):
83 | return _resnet101('fmix', pretrained, *args, **kwargs)
84 |
85 |
86 | def renset101_imagenet_mixup(pretrained=False, *args, **kwargs):
87 | return _resnet101('mixup', pretrained, *args, **kwargs)
88 |
89 |
90 | def renset101_imagenet_cutmix(pretrained=False, *args, **kwargs):
91 | return _resnet101('cutmix', pretrained, *args, **kwargs)
92 |
--------------------------------------------------------------------------------
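A hypothetical sketch of loading one of the entry points above through torch.hub, assuming the ecs-vlc/FMix repository name used elsewhere in this project:

import torch

model = torch.hub.load('ecs-vlc/FMix', 'preact_resnet18_cifar10_fmix', pretrained=True)
model.eval()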
/implementations/lightning.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from fmix import sample_mask, FMixBase
3 | import torch
4 |
5 |
6 | def fmix_loss(input, y1, index, lam, train=True, reformulate=False):
7 | r"""Criterion for fmix
8 |
9 | Args:
10 | input: If train, mixed input. If not train, standard input
11 | y1: Targets for first image
12 | index: Permutation for mixing
13 | lam: Lambda value of mixing
14 | train: If true, sum the cross entropy of the input with y1 and y2, weighted by lam and (1 - lam) respectively. If false, cross entropy loss with y1
15 | """
16 |
17 | if train and not reformulate:
18 | y2 = y1[index]
19 | return F.cross_entropy(input, y1) * lam + F.cross_entropy(input, y2) * (1 - lam)
20 | else:
21 | return F.cross_entropy(input, y1)
22 |
23 |
24 | class FMix(FMixBase):
25 | r""" FMix augmentation
26 |
27 | Args:
28 | decay_power (float): Decay power for frequency decay prop 1/f**d
29 | alpha (float): Alpha value for beta distribution from which to sample mean of mask
30 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
31 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
32 | reformulate (bool): If True, uses the reformulation of [1].
33 |
34 | Example
35 | -------
36 |
37 | .. code-block:: python
38 |
39 | class FMixExp(pl.LightningModule):
40 | def __init__(self, *args, **kwargs):
41 | self.fmix = FMix(...)
42 | # ...
43 |
44 | def training_step(self, batch, batch_idx):
45 | x, y = batch
46 | x = self.fmix(x)
47 |
48 | feature_maps = self.forward(x)
49 | logits = self.classifier(feature_maps)
50 | loss = self.fmix.loss(logits, y)
51 |
52 | # ...
53 | return loss
54 | """
55 | def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False):
56 | super().__init__(decay_power, alpha, size, max_soft, reformulate)
57 |
58 | def __call__(self, x):
59 | # Sample mask and generate random permutation
60 | lam, mask = sample_mask(self.alpha, self.decay_power, self.size, self.max_soft, self.reformulate)
61 | index = torch.randperm(x.size(0)).to(x.device)
62 | mask = torch.from_numpy(mask).float().to(x.device)
63 |
64 | # Mix the images
65 | x1 = mask * x
66 | x2 = (1 - mask) * x[index]
67 | self.index = index
68 | self.lam = lam
69 | return x1+x2
70 |
71 | def loss(self, y_pred, y, train=True):
72 | return fmix_loss(y_pred, y, self.index, self.lam, train, self.reformulate)
73 |
--------------------------------------------------------------------------------
/implementations/tensorflow_implementation.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from fmix import sample_mask, FMixBase
4 | softmax_cross_entropy_with_logits = tf.nn.softmax_cross_entropy_with_logits
5 |
6 |
7 | def fmix_loss(input, y1, index, lam, train=True, reformulate=False):
8 | r"""Criterion for fmix
9 |
10 | Args:
11 | input: If train, mixed input. If not train, standard input
12 | y1: Targets for first image
13 | index: Permutation for mixing
14 | lam: Lambda value of mixing
15 | train: If true, sum the cross entropy of the input with y1 and y2, weighted by lam and (1 - lam) respectively. If false, cross entropy loss with y1
16 | """
17 |
18 | if train and not reformulate:
19 | y2 = tf.gather(y1, index)
20 | y1, y2 = tf.transpose(tf.one_hot(y1, 10, axis=0)), tf.transpose(tf.one_hot(y2, 10, axis=0))
21 | return softmax_cross_entropy_with_logits(logits=input, labels=y1) * lam + softmax_cross_entropy_with_logits(logits=input, labels=y2) * (1-lam)
22 | else:
23 | y1 = tf.transpose(tf.one_hot(y1, 10, axis=0))
24 | return softmax_cross_entropy_with_logits(logits=input, labels=y1)
25 |
26 |
27 | class FMix(FMixBase):
28 | r""" FMix augmentation
29 |
30 | Args:
31 | decay_power (float): Decay power for frequency decay prop 1/f**d
32 | alpha (float): Alpha value for beta distribution from which to sample mean of mask
33 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
34 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
35 | reformulate (bool): If True, uses the reformulation of [1].
36 |
37 | Example
38 | ----------
39 |
40 | fmix = FMix(...)
41 |
42 | def loss(model, x, y, training=True):
43 | x = fmix(x)
44 | y_ = model(x, training=training)
45 | return tf.reduce_mean(fmix.loss(y_, y))
46 | """
47 | def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False):
48 | super().__init__(decay_power, alpha, size, max_soft, reformulate)
49 |
50 | def __call__(self, x):
51 | shape = [int(s) for s in x.shape][1:-1]
52 | lam, mask = sample_mask(self.alpha, self.decay_power, shape, self.max_soft, self.reformulate)
53 | index = np.random.permutation(int(x.shape[0]))
54 | index = tf.constant(index)
55 | mask = np.expand_dims(mask, -1)
56 |
57 | x1 = x * mask
58 | x2 = tf.gather(x, index) * (1 - mask)
59 | self.index = index
60 | self.lam = lam
61 |
62 | return x1 + x2
63 |
64 | def loss(self, y_pred, y, train=True):
65 | return fmix_loss(y_pred, y, self.index, self.lam, train, self.reformulate)
66 |
--------------------------------------------------------------------------------
/implementations/test_lightning.py:
--------------------------------------------------------------------------------
1 | from torchvision import datasets, transforms, models
2 | import torch
3 | from torch import optim
4 | from implementations.lightning import FMix
5 | from pytorch_lightning import LightningModule, Trainer, data_loader
6 |
7 |
8 | # ######### Data
9 | print('==> Preparing data..')
10 | classes, cifar = 10, datasets.CIFAR10
11 |
12 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
13 | transform_base = [transforms.ToTensor(), normalize]
14 | transform = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] + transform_base
15 |
16 | transform_train = transforms.Compose(transform)
17 | transform_test = transforms.Compose(transform_base)
18 | trainset = cifar(root='./data', train=True, download=True, transform=transform_train)
19 | valset = cifar(root='./data', train=False, download=True, transform=transform_test)
20 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)
21 | valloader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=True, num_workers=8)
22 |
23 |
24 | ######### Model
25 | print('==> Building model..')
26 |
27 |
28 | class FMixExp(LightningModule):
29 | def __init__(self):
30 | super().__init__()
31 | self.net = models.resnet18(False)
32 | self.fmix = FMix()
33 |
34 | def forward(self, x):
35 | return self.net(x)
36 |
37 | def training_step(self, batch, batch_nb):
38 | x, y = batch
39 | x = self.fmix(x)
40 |
41 | x = self.forward(x)
42 |
43 | loss = self.fmix.loss(x, y)
44 | return {'loss': loss}
45 |
46 | def validation_step(self, batch, batch_nb):
47 | x, y = batch
48 |
49 | x = self.forward(x)
50 |
51 | labels_hat = torch.argmax(x, dim=1)
52 | val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
53 | val_acc = torch.tensor(val_acc)
54 |
55 | loss = self.fmix.loss(x, y, train=False)
56 | output = {
57 | 'val_loss': loss,
58 | 'val_acc': val_acc,
59 | }
60 |
61 | # can also return just a scalar instead of a dict (return loss_val)
62 | return output
63 |
64 | def validation_end(self, outputs):
65 | """
66 | Called at the end of validation to aggregate outputs
67 | :param outputs: list of individual outputs of each validation step
68 | :return:
69 | """
70 | # if returned a scalar from validation_step, outputs is a list of tensor scalars
71 | # we return just the average in this case (if we want)
72 | # return torch.stack(outputs).mean()
73 |
74 | val_loss_mean = 0
75 | val_acc_mean = 0
76 | for output in outputs:
77 | val_loss = output['val_loss']
78 |
79 | # reduce manually when using dp
80 | if self.trainer.use_dp or self.trainer.use_ddp2:
81 | val_loss = torch.mean(val_loss)
82 | val_loss_mean += val_loss
83 |
84 | # reduce manually when using dp
85 | val_acc = output['val_acc']
86 | if self.trainer.use_dp or self.trainer.use_ddp2:
87 | val_acc = torch.mean(val_acc)
88 |
89 | val_acc_mean += val_acc
90 |
91 | val_loss_mean /= len(outputs)
92 | val_acc_mean /= len(outputs)
93 | tqdm_dict = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
94 | result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': val_loss_mean}
95 | return result
96 |
97 | def configure_optimizers(self):
98 | return torch.optim.SGD(self.net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
99 |
100 | @data_loader
101 | def train_dataloader(self):
102 | return trainloader
103 |
104 | @data_loader
105 | def val_dataloader(self):
106 | return valloader
107 |
108 |
109 | ######### Train
110 | print('==> Starting training..')
111 | trainer = Trainer(gpus=1, early_stop_callback=False, max_epochs=20, checkpoint_callback=False)
112 | mod = FMixExp()
113 | trainer.fit(mod)
114 |
--------------------------------------------------------------------------------
/implementations/test_tensorflow.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 |
3 | import tensorflow as tf
4 | from tensorflow.compat.v1 import enable_eager_execution
5 | from tensorflow.keras import datasets, layers, models
6 |
7 | from implementations.tensorflow_implementation import FMix
8 |
9 | enable_eager_execution()
10 | print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
11 |
12 |
13 | (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
14 | train_images, test_images = train_images / 255.0, test_images / 255.0
15 | train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
16 |
17 | fmix = FMix()
18 |
19 |
20 | def loss(model, x, y, training=True):
21 | x = fmix(x)
22 | y_ = model(x, training=training)
23 | return tf.reduce_mean(fmix.loss(y_, y))
24 |
25 |
26 | model = models.Sequential()
27 | model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
28 | model.add(layers.MaxPooling2D((2, 2)))
29 | model.add(layers.Conv2D(64, (3, 3), activation='relu'))
30 | model.add(layers.MaxPooling2D((2, 2)))
31 | model.add(layers.Conv2D(64, (3, 3), activation='relu'))
32 | model.add(layers.Flatten())
33 | model.add(layers.Dense(64, activation='relu'))
34 | model.add(layers.Dense(10))
35 |
36 | optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
37 |
38 |
39 | def train(model, train_images, train_labels):
40 | with tf.GradientTape() as t:
41 | current_loss = loss(model, train_images, train_labels)
42 | return current_loss, t.gradient(current_loss, model.trainable_variables)
43 |
44 |
45 | epochs = range(100)
46 | import tqdm
47 | epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
48 |
49 |
50 | for epoch in epochs:
51 | t = tqdm.tqdm()
52 | for batch in train_ds.shuffle(256).batch(128):
53 | x, y = batch
54 | x, y = tf.cast(x, 'float32'), tf.cast(y, 'int32')[:,0]
55 | current_loss, grads = train(model, x, y)
56 | optimizer.apply_gradients(zip(grads, model.trainable_variables))
57 | epoch_accuracy(y, model(x, training=True))
58 | t.update(1)
59 | t.set_postfix_str('Epoch: {}. Loss: {}. Acc: {}'.format(epoch, current_loss, epoch_accuracy.result()))
60 | t.close()
--------------------------------------------------------------------------------
/implementations/test_torchbearer.py:
--------------------------------------------------------------------------------
1 | from torchvision import datasets, transforms, models
2 | import torch
3 | from torch import optim
4 | from implementations.torchbearer_implementation import FMix
5 | from torchbearer import Trial
6 |
7 |
8 | # ######### Data
9 | print('==> Preparing data..')
10 | classes, cifar = 10, datasets.CIFAR10
11 |
12 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
13 | transform_base = [transforms.ToTensor(), normalize]
14 | transform = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] + transform_base
15 |
16 | transform_train = transforms.Compose(transform)
17 | transform_test = transforms.Compose(transform_base)
18 | trainset = cifar(root='./data', train=True, download=True, transform=transform_train)
19 | valset = cifar(root='./data', train=False, download=True, transform=transform_test)
20 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)
21 | valloader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=True, num_workers=8)
22 |
23 |
24 | ######### Model
25 | print('==> Building model..')
26 | net = models.resnet18(False)
27 | optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
28 | fmix = FMix()
29 | criterion = fmix.loss()
30 |
31 |
32 | ######### Trial
33 | print('==> Starting training..')
34 | trial = Trial(net, optimizer, criterion, metrics=['acc', 'loss'], callbacks=[fmix])
35 | trial.with_generators(train_generator=trainloader, val_generator=valloader).to('cuda')
36 | trial.run(100, verbose=2)
37 |
--------------------------------------------------------------------------------
/implementations/torchbearer_implementation.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torchbearer
3 | from torchbearer.callbacks import Callback
4 | from fmix import sample_mask, FMixBase
5 | import torch
6 |
7 |
8 | def fmix_loss(input, y, index, lam, train=True, reformulate=False, bce_loss=False):
9 | r"""Criterion for fmix
10 |
11 | Args:
12 | input: If train, mixed input. If not train, standard input
13 | y: Targets for first image
14 | index: Permutation for mixing
15 | lam: Lambda value of mixing
16 | train: If true, sum cross entropy of input with y1 and y2, weighted by lam/(1-lam). If false, cross entropy loss with y1
17 | """
18 | loss_fn = F.cross_entropy if not bce_loss else F.binary_cross_entropy_with_logits
19 |
20 | if train and not reformulate:
21 | y2 = y[index]
22 | return loss_fn(input, y) * lam + loss_fn(input, y2) * (1 - lam)
23 | else:
24 | return loss_fn(input, y)
25 |
26 |
27 | class FMix(FMixBase, Callback):
28 | r""" FMix augmentation
29 |
30 | Args:
31 | decay_power (float): Decay power for frequency decay prop 1/f**d
32 | alpha (float): Alpha value for beta distribution from which to sample mean of mask
33 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims. -1 computes on the fly
34 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
35 | reformulate (bool): If True, uses the reformulation of [1].
36 |
37 | Example
38 | -------
39 |
40 | .. code-block:: python
41 |
42 | fmix = FMix(...)
43 | trial = Trial(model, optimiser, fmix.loss(), callbacks=[fmix])
44 | # ...
45 | """
46 | def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False):
47 | super().__init__(decay_power, alpha, size, max_soft, reformulate)
48 |
49 | def on_sample(self, state):
50 | super().on_sample(state)
51 | x, y = state[torchbearer.X], state[torchbearer.Y_TRUE]
52 | device = state[torchbearer.DEVICE]
53 |
54 | x = self(x)
55 |
56 | # Store the results
57 | state[torchbearer.X] = x
58 | state[torchbearer.Y_TRUE] = y
59 |
60 | # Set mixup flags
61 | state[torchbearer.MIXUP_LAMBDA] = torch.tensor([self.lam], device=device) if not self.reformulate else torch.tensor([1], device=device)
62 | state[torchbearer.MIXUP_PERMUTATION] = self.index
63 |
64 | def __call__(self, x):
65 | size = []
66 | for i, s in enumerate(self.size):
67 | if s != -1:
68 | size.append(s)
69 | else:
70 | size.append(x.shape[i+1])
71 |
72 | lam, mask = sample_mask(self.alpha, self.decay_power, size, self.max_soft, self.reformulate)
73 | index = torch.randperm(x.size(0)).to(x.device)
74 | mask = torch.from_numpy(mask).float().to(x.device)
75 |
76 | if len(self.size) == 1 and x.ndim == 3:
77 | mask = mask.unsqueeze(2)
78 |
79 | # Mix the images
80 | x1 = mask * x
81 | x2 = (1 - mask) * x[index]
82 | self.index = index
83 | self.lam = lam
84 | return x1 + x2
85 |
86 | def loss(self, use_bce=False):
87 | def _fmix_loss(state):
88 | y_pred = state[torchbearer.Y_PRED]
89 | y = state[torchbearer.Y_TRUE]
90 | index = state[torchbearer.MIXUP_PERMUTATION] if torchbearer.MIXUP_PERMUTATION in state else None
91 | lam = state[torchbearer.MIXUP_LAMBDA] if torchbearer.MIXUP_LAMBDA in state else None
92 | train = state[torchbearer.MODEL].training
93 | return fmix_loss(y_pred, y, index, lam, train, self.reformulate, use_bce)
94 |
95 | return _fmix_loss
96 |
97 |
98 | class PointNetFMix(FMix):
99 | def __init__(self, resolution, decay_power=3, alpha=1, max_soft=0.0, reformulate=False):
100 | super().__init__(decay_power, alpha, [resolution, resolution, resolution], max_soft, reformulate)
101 | self.res = resolution
102 |
103 | def __call__(self, x):
104 | import kaolin.conversions as cvt
105 | x = super().__call__(x)
106 | t = []
107 | for i in range(x.shape[0]):
108 | t.append(cvt.voxelgrid_to_pointcloud(x[i], self.res, normalize=True))
109 | return torch.stack(t)
110 |
111 |
112 | from torchbearer.metrics import default as d
113 | from utils.reformulated_mixup import MixupAcc
114 | d.__loss_map__[FMix().loss().__name__] = MixupAcc
--------------------------------------------------------------------------------
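At train time (and without `reformulate`), `fmix_loss` above sums the cross entropy against the original targets and the permuted targets, weighted by `lam` and `1 - lam`; `bce_loss=True` swaps in `binary_cross_entropy_with_logits`, and at eval time or with `reformulate=True` only the first term is used. A minimal sketch of the train-time computation on dummy tensors (not part of the repository):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)        # dummy predictions: batch of 8, 10 classes
y = torch.randint(0, 10, (8,))     # targets for the un-permuted images
index = torch.randperm(8)          # permutation used when the batch was mixed
lam = 0.7                          # mask mean, sampled from Beta(alpha, alpha)

# Same formula as the train branch of fmix_loss
mixed = F.cross_entropy(logits, y) * lam + F.cross_entropy(logits, y[index]) * (1 - lam)
print(mixed.item())
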
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .wide_resnet import wrn
2 | from .densenet3 import DenseNet190
3 | from .pyramid import aa_PyramidNet
4 | from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
5 | from .senet import se_resnext50_32x4d
6 | from .toxic_lstm import LSTM
7 | from .toxic_cnn import CNN
8 | from .bert import Bert
9 |
--------------------------------------------------------------------------------
/models/bert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class Bert(nn.Module):
6 | def __init__(self, num_classes, nc=None, bert_model='bert-base-cased'):
7 | super().__init__()
8 | from transformers import BertModel
9 | self.bert = BertModel.from_pretrained(bert_model, output_hidden_states=True)
10 | self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
11 |
12 | def forward(self, x, token_type_ids=None, position_ids=None, head_mask=None):
13 | outputs = self.bert(x.long(),
14 | attention_mask=(x != 0).float(),
15 | token_type_ids=token_type_ids,
16 | position_ids=position_ids,
17 | head_mask=head_mask)
18 | cls_output = torch.cat(outputs[2], dim=1).mean(dim=1)
19 |         cls_output = self.classifier(cls_output)  # (batch, num_classes)
20 | return cls_output
21 |
--------------------------------------------------------------------------------
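The wrapper above mean-pools the concatenation of all hidden states returned by `BertModel` (enabled via `output_hidden_states=True`) and treats token id 0 as padding when building the attention mask. A hypothetical smoke test, assuming the `transformers` package is installed (it downloads the pretrained `bert-base-cased` weights on first use):

import torch
from models.bert import Bert

model = Bert(num_classes=6)               # toxic comment task uses 6 labels
tokens = torch.randint(1, 1000, (2, 16))  # fake token ids for a batch of 2 sequences
tokens[:, 12:] = 0                        # pad the tails; the forward builds the mask from (x != 0)
with torch.no_grad():
    logits = model(tokens)
print(logits.shape)                       # expected: torch.Size([2, 6])
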
/models/densenet3.py:
--------------------------------------------------------------------------------
1 | # Code modified from https://github.com/andreasveit/densenet-pytorch
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class BasicBlock(nn.Module):
10 | def __init__(self, in_planes, out_planes, dropRate=0.0):
11 | super(BasicBlock, self).__init__()
12 | self.bn1 = nn.BatchNorm2d(in_planes)
13 | self.relu = nn.ReLU(inplace=True)
14 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
15 | padding=1, bias=False)
16 | self.droprate = dropRate
17 |
18 | def forward(self, x):
19 | out = self.conv1(self.relu(self.bn1(x)))
20 | if self.droprate > 0:
21 | out = F.dropout(out, p=self.droprate, training=self.training)
22 | return torch.cat([x, out], 1)
23 |
24 |
25 | class BottleneckBlock(nn.Module):
26 | def __init__(self, in_planes, out_planes, dropRate=0.0):
27 | super(BottleneckBlock, self).__init__()
28 | inter_planes = out_planes * 4
29 | self.bn1 = nn.BatchNorm2d(in_planes)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
32 | padding=0, bias=False)
33 | self.bn2 = nn.BatchNorm2d(inter_planes)
34 | self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
35 | padding=1, bias=False)
36 | self.droprate = dropRate
37 |
38 | def forward(self, x):
39 | out = self.conv1(self.relu(self.bn1(x)))
40 | if self.droprate > 0:
41 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
42 | out = self.conv2(self.relu(self.bn2(out)))
43 | if self.droprate > 0:
44 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
45 | return torch.cat([x, out], 1)
46 |
47 |
48 | class TransitionBlock(nn.Module):
49 | def __init__(self, in_planes, out_planes, dropRate=0.0):
50 | super(TransitionBlock, self).__init__()
51 | self.bn1 = nn.BatchNorm2d(in_planes)
52 | self.relu = nn.ReLU(inplace=True)
53 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
54 | padding=0, bias=False)
55 | self.droprate = dropRate
56 |
57 | def forward(self, x):
58 | out = self.conv1(self.relu(self.bn1(x)))
59 | if self.droprate > 0:
60 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
61 | return F.avg_pool2d(out, 2)
62 |
63 |
64 | class DenseBlock(nn.Module):
65 | def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):
66 | super(DenseBlock, self).__init__()
67 | self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)
68 |
69 | def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):
70 | layers = []
71 | for i in range(nb_layers):
72 | layers.append(block(in_planes + i * growth_rate, growth_rate, dropRate))
73 | return nn.Sequential(*layers)
74 |
75 | def forward(self, x):
76 | return self.layer(x)
77 |
78 |
79 | class DenseNet3(nn.Module):
80 | def __init__(self, depth, num_classes, growth_rate=12,
81 | reduction=0.5, bottleneck=True, dropRate=0.0, efficient=False, nc=3):
82 | super(DenseNet3, self).__init__()
83 | in_planes = 2 * growth_rate
84 | n = (depth - 4) // 3
85 | if bottleneck == True:
86 | n = n // 2
87 | block = BottleneckBlock
88 | else:
89 | block = BasicBlock
90 | # 1st conv before any dense block
91 | self.conv1 = nn.Conv2d(nc, in_planes, kernel_size=3, stride=1,
92 | padding=1, bias=False)
93 | # 1st block
94 | self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
95 | in_planes = int(in_planes + n * growth_rate)
96 | self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes * reduction)), dropRate=dropRate)
97 | in_planes = int(math.floor(in_planes * reduction))
98 | # 2nd block
99 | self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
100 | in_planes = int(in_planes + n * growth_rate)
101 | self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes * reduction)), dropRate=dropRate)
102 | in_planes = int(math.floor(in_planes * reduction))
103 | # 3rd block
104 | self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
105 | in_planes = int(in_planes + n * growth_rate)
106 | # global average pooling and classifier
107 | self.bn1 = nn.BatchNorm2d(in_planes)
108 | self.relu = nn.ReLU(inplace=True)
109 | self.fc = nn.Linear(in_planes, num_classes)
110 | self.in_planes = in_planes
111 |
112 | for m in self.modules():
113 | if isinstance(m, nn.Conv2d):
114 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
115 | m.weight.data.normal_(0, math.sqrt(2. / n))
116 | elif isinstance(m, nn.BatchNorm2d):
117 | m.weight.data.fill_(1)
118 | m.bias.data.zero_()
119 | elif isinstance(m, nn.Linear):
120 | m.bias.data.zero_()
121 |
122 | def forward(self, x):
123 | out = self.conv1(x)
124 | out = self.trans1(self.block1(out))
125 | out = self.trans2(self.block2(out))
126 | out = self.block3(out)
127 | out = self.relu(self.bn1(out))
128 | # out = F.avg_pool2d(out, 8)
129 | out = F.adaptive_avg_pool2d(out, (1, 1))
130 | out = out.view(-1, self.in_planes)
131 | return self.fc(out)
132 |
133 |
134 | def DenseNet190(num_classes=10, nc=3):
135 | return DenseNet3(190, num_classes, growth_rate=40, nc=nc)
136 |
137 |
138 | def EDenseNet190():
139 | return DenseNet3(190, 10, growth_rate=40, efficient=True)
140 |
--------------------------------------------------------------------------------
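A quick shape check for `DenseNet3` on a CIFAR-sized input; `depth=40` here is only a small stand-in to keep the forward pass cheap, while the configuration the repository actually exports is `DenseNet190` (depth 190, growth rate 40):

import torch
from models.densenet3 import DenseNet3

net = DenseNet3(depth=40, num_classes=10, growth_rate=12)
out = net(torch.randn(2, 3, 32, 32))
print(out.shape)  # expected: torch.Size([2, 10])
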
/models/models.py:
--------------------------------------------------------------------------------
1 | import models
2 | from models import wrn
3 | import torchvision.models as m
4 | # from models.toxic_lstm import LSTM
5 |
6 |
7 | def get_model(args, classes, nc):
8 | # Load torchvision models with "torch_" prefix
9 | if 'torch' in args.model:
10 | return m.__dict__[args.model[6:]](num_classes=classes, pretrained=False)
11 |
12 | # Load the pyramidnet used for autoaugment experiments on cifar
13 | if args.model == 'aa_PyramidNet':
14 | return models.__dict__[args.model](dataset='cifar10', depth=272, alpha=200, num_classes=classes)
15 |
16 | # Load the WideResNet-28-10
17 | if args.model == 'wrn':
18 | return wrn(num_classes=classes, depth=28, widen_factor=10, nc=nc)
19 |
20 | if args.model == 'PointNet' or args.dataset == 'modelnet':
21 | from kaolin.models.PointNet import PointNetClassifier
22 | return PointNetClassifier(num_classes=classes)
23 |
24 | # if args.dataset == 'toxic':
25 | # return LSTM()
26 |
27 | # Otherwise return models from other files
28 | return models.__dict__[args.model](num_classes=classes, nc=nc)
29 |
--------------------------------------------------------------------------------
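For illustration, `get_model` can be exercised with a stand-in `Namespace` in place of the parsed arguments: a `torch_`-prefixed name is resolved in `torchvision.models` (the code strips the first six characters), while everything else falls through to the constructors exported from this package. This sketch assumes the repository root is on the Python path:

from argparse import Namespace
from models.models import get_model

# Resolved via models/__init__.py (models/resnet.py)
net = get_model(Namespace(model='ResNet18', dataset='cifar10'), classes=10, nc=3)

# Resolved via torchvision.models after stripping the 'torch_' prefix
tv_net = get_model(Namespace(model='torch_resnet50', dataset='cifar10'), classes=10, nc=3)
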
/models/pyramid.py:
--------------------------------------------------------------------------------
1 | # Code modified from https://github.com/kakaobrain/fast-autoaugment/blob/master/FastAutoAugment/networks/pyramidnet.py
2 |
3 | import math
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 |
9 |
10 | class ShakeDropFunction(torch.autograd.Function):
11 |
12 | @staticmethod
13 | def forward(ctx, x, training=True, p_drop=0.5, alpha_range=[-1, 1]):
14 | if training:
15 | gate = torch.cuda.FloatTensor([0]).bernoulli_(1 - p_drop)
16 | ctx.save_for_backward(gate)
17 | if gate.item() == 0:
18 | alpha = torch.cuda.FloatTensor(x.size(0)).uniform_(*alpha_range)
19 | alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x)
20 | return alpha * x
21 | else:
22 | return x
23 | else:
24 | return (1 - p_drop) * x
25 |
26 | @staticmethod
27 | def backward(ctx, grad_output):
28 | gate = ctx.saved_tensors[0]
29 | if gate.item() == 0:
30 | beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_(0, 1)
31 | beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output)
32 | beta = Variable(beta)
33 | return beta * grad_output, None, None, None
34 | else:
35 | return grad_output, None, None, None
36 |
37 |
38 | class ShakeDrop(nn.Module):
39 |
40 | def __init__(self, p_drop=0.5, alpha_range=[-1, 1]):
41 | super(ShakeDrop, self).__init__()
42 | self.p_drop = p_drop
43 | self.alpha_range = alpha_range
44 |
45 | def forward(self, x):
46 | return ShakeDropFunction.apply(x, self.training, self.p_drop, self.alpha_range)
47 |
48 |
49 | def conv3x3(in_planes, out_planes, stride=1):
50 | """
51 | 3x3 convolution with padding
52 | """
53 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
54 |
55 |
56 | class BasicBlock(nn.Module):
57 | outchannel_ratio = 1
58 |
59 | def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0):
60 | super(BasicBlock, self).__init__()
61 | self.bn1 = nn.BatchNorm2d(inplanes)
62 | self.conv1 = conv3x3(inplanes, planes, stride)
63 | self.bn2 = nn.BatchNorm2d(planes)
64 | self.conv2 = conv3x3(planes, planes)
65 | self.bn3 = nn.BatchNorm2d(planes)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.downsample = downsample
68 | self.stride = stride
69 | self.shake_drop = ShakeDrop(p_shakedrop)
70 |
71 | def forward(self, x):
72 |
73 | out = self.bn1(x)
74 | out = self.conv1(out)
75 | out = self.bn2(out)
76 | out = self.relu(out)
77 | out = self.conv2(out)
78 | out = self.bn3(out)
79 |
80 | out = self.shake_drop(out)
81 |
82 | if self.downsample is not None:
83 | shortcut = self.downsample(x)
84 | featuremap_size = shortcut.size()[2:4]
85 | else:
86 | shortcut = x
87 | featuremap_size = out.size()[2:4]
88 |
89 | batch_size = out.size()[0]
90 | residual_channel = out.size()[1]
91 | shortcut_channel = shortcut.size()[1]
92 |
93 | if residual_channel != shortcut_channel:
94 | padding = torch.autograd.Variable(
95 | torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0],
96 | featuremap_size[1]).fill_(0))
97 | out += torch.cat((shortcut, padding), 1)
98 | else:
99 | out += shortcut
100 |
101 | return out
102 |
103 |
104 | class Bottleneck(nn.Module):
105 | outchannel_ratio = 4
106 |
107 | def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0):
108 | super(Bottleneck, self).__init__()
109 | self.bn1 = nn.BatchNorm2d(inplanes)
110 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
111 | self.bn2 = nn.BatchNorm2d(planes)
112 | self.conv2 = nn.Conv2d(planes, (planes * 1), kernel_size=3, stride=stride,
113 | padding=1, bias=False)
114 | self.bn3 = nn.BatchNorm2d((planes * 1))
115 | self.conv3 = nn.Conv2d((planes * 1), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False)
116 | self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio)
117 | self.relu = nn.ReLU(inplace=True)
118 | self.downsample = downsample
119 | self.stride = stride
120 | self.shake_drop = ShakeDrop(p_shakedrop)
121 |
122 | def forward(self, x):
123 |
124 | out = self.bn1(x)
125 | out = self.conv1(out)
126 |
127 | out = self.bn2(out)
128 | out = self.relu(out)
129 | out = self.conv2(out)
130 |
131 | out = self.bn3(out)
132 | out = self.relu(out)
133 | out = self.conv3(out)
134 |
135 | out = self.bn4(out)
136 |
137 | out = self.shake_drop(out)
138 |
139 | if self.downsample is not None:
140 | shortcut = self.downsample(x)
141 | featuremap_size = shortcut.size()[2:4]
142 | else:
143 | shortcut = x
144 | featuremap_size = out.size()[2:4]
145 |
146 | batch_size = out.size()[0]
147 | residual_channel = out.size()[1]
148 | shortcut_channel = shortcut.size()[1]
149 |
150 | if residual_channel != shortcut_channel:
151 | padding = torch.autograd.Variable(
152 | torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0],
153 | featuremap_size[1]).fill_(0))
154 | out += torch.cat((shortcut, padding), 1)
155 | else:
156 | out += shortcut
157 |
158 | return out
159 |
160 |
161 | class aa_PyramidNet(nn.Module):
162 | def __init__(self, dataset='cifar10', depth=272, alpha=200, num_classes=10, bottleneck=True):
163 | super(aa_PyramidNet, self).__init__()
164 | self.dataset = dataset
165 | if self.dataset.startswith('cifar'):
166 | self.inplanes = 16
167 | if bottleneck:
168 | n = int((depth - 2) / 9)
169 | block = Bottleneck
170 | else:
171 | n = int((depth - 2) / 6)
172 | block = BasicBlock
173 |
174 | self.addrate = alpha / (3 * n * 1.0)
175 | self.ps_shakedrop = [1. - (1.0 - (0.5 / (3 * n)) * (i + 1)) for i in range(3 * n)]
176 |
177 | self.input_featuremap_dim = self.inplanes
178 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False)
179 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim)
180 |
181 | self.featuremap_dim = self.input_featuremap_dim
182 | self.layer1 = self.pyramidal_make_layer(block, n)
183 | self.layer2 = self.pyramidal_make_layer(block, n, stride=2)
184 | self.layer3 = self.pyramidal_make_layer(block, n, stride=2)
185 |
186 | self.final_featuremap_dim = self.input_featuremap_dim
187 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim)
188 | self.relu_final = nn.ReLU(inplace=True)
189 | self.avgpool = nn.AvgPool2d(8)
190 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes)
191 |
192 | elif dataset == 'imagenet':
193 | blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
194 | layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3],
195 | 200: [3, 24, 36, 3]}
196 |
197 | if layers.get(depth) is None:
198 | if bottleneck == True:
199 | blocks[depth] = Bottleneck
200 | temp_cfg = int((depth - 2) / 12)
201 | else:
202 | blocks[depth] = BasicBlock
203 | temp_cfg = int((depth - 2) / 8)
204 |
205 | layers[depth] = [temp_cfg, temp_cfg, temp_cfg, temp_cfg]
206 | print('=> the layer configuration for each stage is set to', layers[depth])
207 |
208 | self.inplanes = 64
209 | self.addrate = alpha / (sum(layers[depth]) * 1.0)
210 |
211 | self.input_featuremap_dim = self.inplanes
212 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False)
213 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim)
214 | self.relu = nn.ReLU(inplace=True)
215 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
216 |
217 | self.featuremap_dim = self.input_featuremap_dim
218 | self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0])
219 | self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2)
220 | self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2)
221 | self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2)
222 |
223 | self.final_featuremap_dim = self.input_featuremap_dim
224 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim)
225 | self.relu_final = nn.ReLU(inplace=True)
226 | self.avgpool = nn.AvgPool2d(7)
227 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes)
228 |
229 | for m in self.modules():
230 | if isinstance(m, nn.Conv2d):
231 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
232 | m.weight.data.normal_(0, math.sqrt(2. / n))
233 | elif isinstance(m, nn.BatchNorm2d):
234 | m.weight.data.fill_(1)
235 | m.bias.data.zero_()
236 |
237 | assert len(self.ps_shakedrop) == 0, self.ps_shakedrop
238 |
239 | def pyramidal_make_layer(self, block, block_depth, stride=1):
240 | downsample = None
241 | if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio:
242 | downsample = nn.AvgPool2d((2, 2), stride=(2, 2), ceil_mode=True)
243 |
244 | layers = []
245 | self.featuremap_dim = self.featuremap_dim + self.addrate
246 | layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample,
247 | p_shakedrop=self.ps_shakedrop.pop(0)))
248 | for i in range(1, block_depth):
249 | temp_featuremap_dim = self.featuremap_dim + self.addrate
250 | layers.append(
251 | block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1,
252 | p_shakedrop=self.ps_shakedrop.pop(0)))
253 | self.featuremap_dim = temp_featuremap_dim
254 | self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio
255 |
256 | return nn.Sequential(*layers)
257 |
258 | def forward(self, x):
259 | if self.dataset == 'cifar10' or self.dataset == 'cifar100':
260 | x = self.conv1(x)
261 | x = self.bn1(x)
262 |
263 | x = self.layer1(x)
264 | x = self.layer2(x)
265 | x = self.layer3(x)
266 |
267 | x = self.bn_final(x)
268 | x = self.relu_final(x)
269 | x = self.avgpool(x)
270 | x = x.view(x.size(0), -1)
271 | x = self.fc(x)
272 |
273 | elif self.dataset == 'imagenet':
274 | x = self.conv1(x)
275 | x = self.bn1(x)
276 | x = self.relu(x)
277 | x = self.maxpool(x)
278 |
279 | x = self.layer1(x)
280 | x = self.layer2(x)
281 | x = self.layer3(x)
282 | x = self.layer4(x)
283 |
284 | x = self.bn_final(x)
285 | x = self.relu_final(x)
286 | x = self.avgpool(x)
287 | x = x.view(x.size(0), -1)
288 | x = self.fc(x)
289 |
290 | return x
291 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Code modified from https://github.com/kuangliu/pytorch-cifar
2 |
3 | '''ResNet in PyTorch.
4 |
5 | BasicBlock and Bottleneck modules are from the original ResNet paper:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 |
9 | PreActBlock and PreActBottleneck modules are from the later paper:
10 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
11 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027
12 | '''
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 |
17 | from torch.autograd import Variable
18 |
19 |
20 | def conv3x3(in_planes, out_planes, stride=1):
21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
22 |
23 |
24 | class BasicBlock(nn.Module):
25 | expansion = 1
26 |
27 | def __init__(self, in_planes, planes, stride=1):
28 | super(BasicBlock, self).__init__()
29 | self.conv1 = conv3x3(in_planes, planes, stride)
30 | self.bn1 = nn.BatchNorm2d(planes)
31 | self.conv2 = conv3x3(planes, planes)
32 | self.bn2 = nn.BatchNorm2d(planes)
33 |
34 | self.shortcut = nn.Sequential()
35 | if stride != 1 or in_planes != self.expansion * planes:
36 | self.shortcut = nn.Sequential(
37 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
38 | nn.BatchNorm2d(self.expansion * planes)
39 | )
40 |
41 | def forward(self, x):
42 | out = F.relu(self.bn1(self.conv1(x)))
43 | out = self.bn2(self.conv2(out))
44 | out += self.shortcut(x)
45 | out = F.relu(out)
46 | return out
47 |
48 |
49 | class PreActBlock(nn.Module):
50 | '''Pre-activation version of the BasicBlock.'''
51 | expansion = 1
52 |
53 | def __init__(self, in_planes, planes, stride=1):
54 | super(PreActBlock, self).__init__()
55 | self.bn1 = nn.BatchNorm2d(in_planes)
56 | self.conv1 = conv3x3(in_planes, planes, stride)
57 | self.bn2 = nn.BatchNorm2d(planes)
58 | self.conv2 = conv3x3(planes, planes)
59 |
60 | self.shortcut = nn.Sequential()
61 | if stride != 1 or in_planes != self.expansion * planes:
62 | self.shortcut = nn.Sequential(
63 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
64 | )
65 |
66 | def forward(self, x):
67 | out = F.relu(self.bn1(x))
68 | shortcut = self.shortcut(out)
69 | out = self.conv1(out)
70 | out = self.conv2(F.relu(self.bn2(out)))
71 | out += shortcut
72 | return out
73 |
74 |
75 | class Bottleneck(nn.Module):
76 | expansion = 4
77 |
78 | def __init__(self, in_planes, planes, stride=1):
79 | super(Bottleneck, self).__init__()
80 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
81 | self.bn1 = nn.BatchNorm2d(planes)
82 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
83 | self.bn2 = nn.BatchNorm2d(planes)
84 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
85 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
86 |
87 | self.shortcut = nn.Sequential()
88 | if stride != 1 or in_planes != self.expansion * planes:
89 | self.shortcut = nn.Sequential(
90 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
91 | nn.BatchNorm2d(self.expansion * planes)
92 | )
93 |
94 | def forward(self, x):
95 | out = F.relu(self.bn1(self.conv1(x)))
96 | out = F.relu(self.bn2(self.conv2(out)))
97 | out = self.bn3(self.conv3(out))
98 | out += self.shortcut(x)
99 | out = F.relu(out)
100 | return out
101 |
102 |
103 | class PreActBottleneck(nn.Module):
104 | '''Pre-activation version of the original Bottleneck module.'''
105 | expansion = 4
106 |
107 | def __init__(self, in_planes, planes, stride=1):
108 | super(PreActBottleneck, self).__init__()
109 | self.bn1 = nn.BatchNorm2d(in_planes)
110 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
111 | self.bn2 = nn.BatchNorm2d(planes)
112 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
113 | self.bn3 = nn.BatchNorm2d(planes)
114 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
115 |
116 | self.shortcut = nn.Sequential()
117 | if stride != 1 or in_planes != self.expansion * planes:
118 | self.shortcut = nn.Sequential(
119 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
120 | )
121 |
122 | def forward(self, x):
123 | out = F.relu(self.bn1(x))
124 | shortcut = self.shortcut(out)
125 | out = self.conv1(out)
126 | out = self.conv2(F.relu(self.bn2(out)))
127 | out = self.conv3(F.relu(self.bn3(out)))
128 | out += shortcut
129 | return out
130 |
131 |
132 | class ResNet(nn.Module):
133 | def __init__(self, block, num_blocks, num_classes=10, nc=3):
134 | super(ResNet, self).__init__()
135 | self.in_planes = 64
136 |
137 | self.conv1 = conv3x3(nc, 64)
138 | self.bn1 = nn.BatchNorm2d(64)
139 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
140 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
141 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
142 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
143 | self.linear = nn.Linear(512 * block.expansion, num_classes)
144 |
145 | def _make_layer(self, block, planes, num_blocks, stride):
146 | strides = [stride] + [1] * (num_blocks - 1)
147 | layers = []
148 | for stride in strides:
149 | layers.append(block(self.in_planes, planes, stride))
150 | self.in_planes = planes * block.expansion
151 | return nn.Sequential(*layers)
152 |
153 | def forward(self, x, lin=0, lout=5):
154 | out = x
155 | if lin < 1 and lout > -1:
156 | out = self.conv1(out)
157 | out = self.bn1(out)
158 | out = F.relu(out)
159 | if lin < 2 and lout > 0:
160 | out = self.layer1(out)
161 | if lin < 3 and lout > 1:
162 | out = self.layer2(out)
163 | if lin < 4 and lout > 2:
164 | out = self.layer3(out)
165 | if lin < 5 and lout > 3:
166 | out = self.layer4(out)
167 | if lout > 4:
168 | # out = F.avg_pool2d(out, 4)
169 | out = F.adaptive_avg_pool2d(out, (1, 1))
170 |
171 | out = out.view(out.size(0), -1)
172 | out = self.linear(out)
173 | return out
174 |
175 |
176 | def ResNet18(num_classes=10, nc=3):
177 | return ResNet(PreActBlock, [2, 2, 2, 2], num_classes=num_classes, nc=nc)
178 |
179 |
180 | def ResNet34(num_classes=10):
181 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
182 |
183 |
184 | def ResNet50(num_classes=10):
185 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
186 |
187 |
188 | def ResNet101(num_classes=10):
189 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)
190 |
191 |
192 | def ResNet152(num_classes=10):
193 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes)
194 |
195 |
196 | def test():
197 | net = ResNet18()
198 | y = net(Variable(torch.randn(1, 3, 32, 32)))
199 | print(y.size())
200 |
201 |
202 | def resnet():
203 | return ResNet18()
204 |
--------------------------------------------------------------------------------
/models/senet.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from
3 | https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py
4 |
5 | ResNet code gently borrowed from
6 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
7 | """
8 | from __future__ import print_function, division, absolute_import
9 | from collections import OrderedDict
10 | import math
11 |
12 | import torch.nn as nn
13 |
14 | __all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
15 | 'se_resnext50_32x4d', 'se_resnext101_32x4d']
16 |
17 |
18 | class SEModule(nn.Module):
19 |
20 | def __init__(self, channels, reduction):
21 | super(SEModule, self).__init__()
22 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
23 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
24 | padding=0)
25 | self.relu = nn.ReLU(inplace=True)
26 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
27 | padding=0)
28 | self.sigmoid = nn.Sigmoid()
29 |
30 | def forward(self, x):
31 | module_input = x
32 | x = self.avg_pool(x)
33 | x = self.fc1(x)
34 | x = self.relu(x)
35 | x = self.fc2(x)
36 | x = self.sigmoid(x)
37 | return module_input * x
38 |
39 |
40 | class Bottleneck(nn.Module):
41 | """
42 | Base class for bottlenecks that implements `forward()` method.
43 | """
44 | def forward(self, x):
45 | residual = x
46 |
47 | out = self.conv1(x)
48 | out = self.bn1(out)
49 | out = self.relu(out)
50 |
51 | out = self.conv2(out)
52 | out = self.bn2(out)
53 | out = self.relu(out)
54 |
55 | out = self.conv3(out)
56 | out = self.bn3(out)
57 |
58 | if self.downsample is not None:
59 | residual = self.downsample(x)
60 |
61 | out = self.se_module(out) + residual
62 | out = self.relu(out)
63 |
64 | return out
65 |
66 |
67 | class SEBottleneck(Bottleneck):
68 | """
69 | Bottleneck for SENet154.
70 | """
71 | expansion = 4
72 |
73 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
74 | downsample=None):
75 | super(SEBottleneck, self).__init__()
76 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
77 | self.bn1 = nn.BatchNorm2d(planes * 2)
78 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3,
79 | stride=stride, padding=1, groups=groups,
80 | bias=False)
81 | self.bn2 = nn.BatchNorm2d(planes * 4)
82 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1,
83 | bias=False)
84 | self.bn3 = nn.BatchNorm2d(planes * 4)
85 | self.relu = nn.ReLU(inplace=True)
86 | self.se_module = SEModule(planes * 4, reduction=reduction)
87 | self.downsample = downsample
88 | self.stride = stride
89 |
90 |
91 | class SEResNetBottleneck(Bottleneck):
92 | """
93 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
94 | implementation and uses `stride=stride` in `conv1` and not in `conv2`
95 | (the latter is used in the torchvision implementation of ResNet).
96 | """
97 | expansion = 4
98 |
99 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
100 | downsample=None):
101 | super(SEResNetBottleneck, self).__init__()
102 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False,
103 | stride=stride)
104 | self.bn1 = nn.BatchNorm2d(planes)
105 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
106 | groups=groups, bias=False)
107 | self.bn2 = nn.BatchNorm2d(planes)
108 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
109 | self.bn3 = nn.BatchNorm2d(planes * 4)
110 | self.relu = nn.ReLU(inplace=True)
111 | self.se_module = SEModule(planes * 4, reduction=reduction)
112 | self.downsample = downsample
113 | self.stride = stride
114 |
115 |
116 | class SEResNeXtBottleneck(Bottleneck):
117 | """
118 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
119 | """
120 | expansion = 4
121 |
122 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
123 | downsample=None, base_width=4):
124 | super(SEResNeXtBottleneck, self).__init__()
125 | width = math.floor(planes * (base_width / 64)) * groups
126 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
127 | stride=1)
128 | self.bn1 = nn.BatchNorm2d(width)
129 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
130 | padding=1, groups=groups, bias=False)
131 | self.bn2 = nn.BatchNorm2d(width)
132 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
133 | self.bn3 = nn.BatchNorm2d(planes * 4)
134 | self.relu = nn.ReLU(inplace=True)
135 | self.se_module = SEModule(planes * 4, reduction=reduction)
136 | self.downsample = downsample
137 | self.stride = stride
138 |
139 |
140 | class SENet(nn.Module):
141 |
142 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
143 | inplanes=128, input_3x3=True, downsample_kernel_size=3,
144 | downsample_padding=1, num_classes=1000, nc=3):
145 | """
146 | Parameters
147 | ----------
148 | block (nn.Module): Bottleneck class.
149 | - For SENet154: SEBottleneck
150 | - For SE-ResNet models: SEResNetBottleneck
151 | - For SE-ResNeXt models: SEResNeXtBottleneck
152 | layers (list of ints): Number of residual blocks for 4 layers of the
153 | network (layer1...layer4).
154 | groups (int): Number of groups for the 3x3 convolution in each
155 | bottleneck block.
156 | - For SENet154: 64
157 | - For SE-ResNet models: 1
158 | - For SE-ResNeXt models: 32
159 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
160 | - For all models: 16
161 | dropout_p (float or None): Drop probability for the Dropout layer.
162 | If `None` the Dropout layer is not used.
163 | - For SENet154: 0.2
164 | - For SE-ResNet models: None
165 | - For SE-ResNeXt models: None
166 | inplanes (int): Number of input channels for layer1.
167 | - For SENet154: 128
168 | - For SE-ResNet models: 64
169 | - For SE-ResNeXt models: 64
170 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
171 | a single 7x7 convolution in layer0.
172 | - For SENet154: True
173 | - For SE-ResNet models: False
174 | - For SE-ResNeXt models: False
175 | downsample_kernel_size (int): Kernel size for downsampling convolutions
176 | in layer2, layer3 and layer4.
177 | - For SENet154: 3
178 | - For SE-ResNet models: 1
179 | - For SE-ResNeXt models: 1
180 | downsample_padding (int): Padding for downsampling convolutions in
181 | layer2, layer3 and layer4.
182 | - For SENet154: 1
183 | - For SE-ResNet models: 0
184 | - For SE-ResNeXt models: 0
185 | num_classes (int): Number of outputs in `last_linear` layer.
186 | - For all models: 1000
187 | """
188 | super(SENet, self).__init__()
189 | self.inplanes = inplanes
190 | if input_3x3:
191 | layer0_modules = [
192 | ('conv1', nn.Conv2d(nc, 64, 3, stride=2, padding=1,
193 | bias=False)),
194 | ('bn1', nn.BatchNorm2d(64)),
195 | ('relu1', nn.ReLU(inplace=True)),
196 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
197 | bias=False)),
198 | ('bn2', nn.BatchNorm2d(64)),
199 | ('relu2', nn.ReLU(inplace=True)),
200 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
201 | bias=False)),
202 | ('bn3', nn.BatchNorm2d(inplanes)),
203 | ('relu3', nn.ReLU(inplace=True)),
204 | ]
205 | else:
206 | layer0_modules = [
207 | ('conv1', nn.Conv2d(nc, inplanes, kernel_size=7, stride=2,
208 | padding=3, bias=False)),
209 | ('bn1', nn.BatchNorm2d(inplanes)),
210 | ('relu1', nn.ReLU(inplace=True)),
211 | ]
212 | # To preserve compatibility with Caffe weights `ceil_mode=True`
213 | # is used instead of `padding=1`.
214 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
215 | ceil_mode=True)))
216 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
217 | self.layer1 = self._make_layer(
218 | block,
219 | planes=64,
220 | blocks=layers[0],
221 | groups=groups,
222 | reduction=reduction,
223 | downsample_kernel_size=1,
224 | downsample_padding=0
225 | )
226 | self.layer2 = self._make_layer(
227 | block,
228 | planes=128,
229 | blocks=layers[1],
230 | stride=2,
231 | groups=groups,
232 | reduction=reduction,
233 | downsample_kernel_size=downsample_kernel_size,
234 | downsample_padding=downsample_padding
235 | )
236 | self.layer3 = self._make_layer(
237 | block,
238 | planes=256,
239 | blocks=layers[2],
240 | stride=2,
241 | groups=groups,
242 | reduction=reduction,
243 | downsample_kernel_size=downsample_kernel_size,
244 | downsample_padding=downsample_padding
245 | )
246 | self.layer4 = self._make_layer(
247 | block,
248 | planes=512,
249 | blocks=layers[3],
250 | stride=2,
251 | groups=groups,
252 | reduction=reduction,
253 | downsample_kernel_size=downsample_kernel_size,
254 | downsample_padding=downsample_padding
255 | )
256 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
257 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
258 | self.last_linear = nn.Linear(512 * block.expansion, num_classes)
259 |
260 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
261 | downsample_kernel_size=1, downsample_padding=0):
262 | downsample = None
263 | if stride != 1 or self.inplanes != planes * block.expansion:
264 | downsample = nn.Sequential(
265 | nn.Conv2d(self.inplanes, planes * block.expansion,
266 | kernel_size=downsample_kernel_size, stride=stride,
267 | padding=downsample_padding, bias=False),
268 | nn.BatchNorm2d(planes * block.expansion),
269 | )
270 |
271 | layers = []
272 | layers.append(block(self.inplanes, planes, groups, reduction, stride,
273 | downsample))
274 | self.inplanes = planes * block.expansion
275 | for i in range(1, blocks):
276 | layers.append(block(self.inplanes, planes, groups, reduction))
277 |
278 | return nn.Sequential(*layers)
279 |
280 | def features(self, x):
281 | x = self.layer0(x)
282 | x = self.layer1(x)
283 | x = self.layer2(x)
284 | x = self.layer3(x)
285 | x = self.layer4(x)
286 | return x
287 |
288 | def logits(self, x):
289 | x = self.avg_pool(x)
290 | if self.dropout is not None:
291 | x = self.dropout(x)
292 | x = x.view(x.size(0), -1)
293 | x = self.last_linear(x)
294 | return x
295 |
296 | def forward(self, x):
297 | x = self.features(x)
298 | x = self.logits(x)
299 | return x
300 |
301 |
302 | def senet154(num_classes=10):
303 | model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
304 | dropout_p=0.2, num_classes=num_classes)
305 | return model
306 |
307 |
308 | def se_resnet50(num_classes=10):
309 | model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
310 | dropout_p=None, inplanes=64, input_3x3=False,
311 | downsample_kernel_size=1, downsample_padding=0,
312 | num_classes=num_classes)
313 | return model
314 |
315 |
316 | def se_resnet101(num_classes=10):
317 | model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
318 | dropout_p=None, inplanes=64, input_3x3=False,
319 | downsample_kernel_size=1, downsample_padding=0,
320 | num_classes=num_classes)
321 | return model
322 |
323 |
324 | def se_resnet152(num_classes=10):
325 | model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
326 | dropout_p=None, inplanes=64, input_3x3=False,
327 | downsample_kernel_size=1, downsample_padding=0,
328 | num_classes=num_classes)
329 | return model
330 |
331 |
332 | def se_resnext50_32x4d(num_classes=10, nc=3):
333 | model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
334 | dropout_p=None, inplanes=64, input_3x3=False,
335 | downsample_kernel_size=1, downsample_padding=0,
336 | num_classes=num_classes, nc=nc)
337 | return model
338 |
339 |
340 | def se_resnext101_32x4d(num_classes=10, nc=3):
341 | model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
342 | dropout_p=None, inplanes=64, input_3x3=False,
343 | downsample_kernel_size=1, downsample_padding=0,
344 | num_classes=num_classes, nc=nc)
345 | return model
346 |
--------------------------------------------------------------------------------
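The `se_resnext50_32x4d` constructor used by this repository fills in the SE-ResNeXt choices documented above (groups=32, reduction=16, no dropout, a single 7x7 stem, 1x1 downsampling convolutions). A quick CPU forward check on a CIFAR-sized input (illustration only, not a repository script):

import torch
from models.senet import se_resnext50_32x4d

net = se_resnext50_32x4d(num_classes=10, nc=3)
out = net(torch.randn(2, 3, 32, 32))
print(out.shape)  # expected: torch.Size([2, 10])
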
/models/toxic_cnn.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | class CNN(nn.Module):
7 | def __init__(self, num_classes=6, nl=2, nc=300, hidden_sz=128):
8 | super(CNN, self).__init__()
9 | self.hidden_sz = hidden_sz
10 | self.emb_sz = nc
11 | self.embeddings = None
12 |
13 | self.conv = nn.Sequential(
14 | nn.Conv1d(nc, hidden_sz, 3, padding=1),
15 | nn.ReLU(True),
16 | nn.Conv1d(hidden_sz, hidden_sz, 3, padding=1),
17 | nn.ReLU(True),
18 | nn.Conv1d(hidden_sz, hidden_sz, 3, padding=1),
19 | nn.ReLU()
20 | )
21 |
22 | layers = []
23 | for i in range(nl):
24 | if i == 0:
25 | layers.append(nn.Linear(hidden_sz * 3, hidden_sz))
26 | else:
27 | layers.append(nn.Linear(hidden_sz, hidden_sz))
28 | layers.append(nn.ReLU())
29 | self.layers = nn.Sequential(*layers)
30 | self.output = nn.Linear(hidden_sz, num_classes)
31 |
32 | def init_embedding(self, vectors, n_tokens, device):
33 | self.embeddings = nn.Embedding(n_tokens, self.emb_sz).to(device)
34 | self.embeddings.weight.data.copy_(vectors.to(device))
35 |
36 | def embed(self, data):
37 | embedded = self.embeddings(data)
38 | return embedded
39 |
40 | def forward(self, embedded):
41 | x = self.conv(embedded.permute(0, 2, 1))
42 |
43 | avg_pool = F.adaptive_avg_pool1d(x, 1).view(embedded.size(0), -1)
44 | max_pool = F.adaptive_max_pool1d(x, 1).view(embedded.size(0), -1)
45 | x = torch.cat([avg_pool, max_pool, x.permute(0, 2, 1)[:, -1]], dim=1)
46 | x = self.layers(x)
47 | res = self.output(x)
48 | if res.size(1) == 1:
49 | res = res.squeeze(1)
50 | return res
51 |
--------------------------------------------------------------------------------
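`CNN.forward` expects already-embedded text of shape (batch, seq_len, 300) and pools the convolutional features three ways (average, max, and last position) before the classifier; `init_embedding`/`embed` attach a lookup table when pretrained word vectors are available. A shape-only sketch that feeds random embeddings directly (illustration, not a training example):

import torch
from models.toxic_cnn import CNN

net = CNN(num_classes=6)
out = net(torch.randn(4, 50, 300))  # (batch, seq_len, embedding dim)
print(out.shape)                    # expected: torch.Size([4, 6])
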
/models/toxic_lstm.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | class LSTM(nn.Module):
7 | def __init__(self, num_classes=6, nl=2, bidirectional=True, nc=300, hidden_sz=128):
8 | super(LSTM, self).__init__()
9 | self.hidden_sz = hidden_sz
10 | self.emb_sz = nc
11 | self.embeddings = None
12 |
13 | self.rnn = nn.LSTM(nc, hidden_sz, num_layers=2, bidirectional=bidirectional, dropout=0, batch_first=True)
14 | if bidirectional:
15 | hidden_sz = 2 * hidden_sz
16 |
17 | layers = []
18 | for i in range(nl):
19 | if i == 0:
20 | layers.append(nn.Linear(hidden_sz * 3, hidden_sz))
21 | else:
22 | layers.append(nn.Linear(hidden_sz, hidden_sz))
23 | layers.append(nn.ReLU())
24 | self.layers = nn.Sequential(*layers)
25 | self.output = nn.Linear(hidden_sz, num_classes)
26 |
27 | def init_embedding(self, vectors, n_tokens, device):
28 | self.embeddings = nn.Embedding(n_tokens, self.emb_sz).to(device)
29 | self.embeddings.weight.data.copy_(vectors.to(device))
30 |
31 | def embed(self, data):
32 | self.h = self.init_hidden(data.size(0))
33 | embedded = self.embeddings(data)
34 | return embedded
35 |
36 | def forward(self, embedded):
37 | rnn_out, self.h = self.rnn(embedded, (self.h, self.h))
38 |
39 | avg_pool = F.adaptive_avg_pool1d(rnn_out.permute(0, 2, 1), 1).view(embedded.size(0), -1)
40 | max_pool = F.adaptive_max_pool1d(rnn_out.permute(0, 2, 1), 1).view(embedded.size(0), -1)
41 | x = torch.cat([avg_pool, max_pool, rnn_out[:, -1]], dim=1)
42 | x = self.layers(x)
43 | res = self.output(x)
44 | if res.size(1) == 1:
45 | res = res.squeeze(1)
46 | return res
47 |
48 | def init_hidden(self, batch_size):
49 | return torch.zeros((4, batch_size, self.hidden_sz), device="cuda")
50 |
--------------------------------------------------------------------------------
/models/wide_resnet.py:
--------------------------------------------------------------------------------
1 | # Code modified from https://github.com/xternalz/WideResNet-pytorch
2 |
3 | from __future__ import division
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class BasicBlock(nn.Module):
11 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
12 | super(BasicBlock, self).__init__()
13 | self.bn1 = nn.BatchNorm2d(in_planes)
14 | self.relu1 = nn.ReLU(inplace=True)
15 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=1, bias=False)
17 | self.bn2 = nn.BatchNorm2d(out_planes)
18 | self.relu2 = nn.ReLU(inplace=True)
19 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
20 | padding=1, bias=False)
21 | self.droprate = dropRate
22 | self.equalInOut = (in_planes == out_planes)
23 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
24 | padding=0, bias=False) or None
25 |
26 | def forward(self, x):
27 | if not self.equalInOut:
28 | x = self.relu1(self.bn1(x))
29 | else:
30 | out = self.relu1(self.bn1(x))
31 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
32 | if self.droprate > 0:
33 | out = F.dropout(out, p=self.droprate, training=self.training)
34 | out = self.conv2(out)
35 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
36 |
37 |
38 | class NetworkBlock(nn.Module):
39 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
40 | super(NetworkBlock, self).__init__()
41 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
42 |
43 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
44 | layers = []
45 | for i in range(nb_layers):
46 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
47 | return nn.Sequential(*layers)
48 |
49 | def forward(self, x):
50 | return self.layer(x)
51 |
52 |
53 | class WideResNet(nn.Module):
54 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, nc=1):
55 | super(WideResNet, self).__init__()
56 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
57 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
58 | n = (depth - 4) // 6
59 | block = BasicBlock
60 | # 1st conv before any network block
61 | self.conv1 = nn.Conv2d(nc, nChannels[0], kernel_size=3, stride=1,
62 | padding=1, bias=False)
63 | # 1st block
64 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
65 | # 2nd block
66 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
67 | # 3rd block
68 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
69 | # global average pooling and classifier
70 | self.bn1 = nn.BatchNorm2d(nChannels[3])
71 | self.relu = nn.ReLU(inplace=True)
72 | self.fc = nn.Linear(nChannels[3], num_classes)
73 | self.nChannels = nChannels[3]
74 |
75 | for m in self.modules():
76 | if isinstance(m, nn.Conv2d):
77 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
78 | m.weight.data.normal_(0, math.sqrt(2. / n))
79 | elif isinstance(m, nn.BatchNorm2d):
80 | m.weight.data.fill_(1)
81 | m.bias.data.zero_()
82 | elif isinstance(m, nn.Linear):
83 | m.bias.data.zero_()
84 |
85 | def forward(self, x):
86 | out = self.conv1(x)
87 | out = self.block1(out)
88 | out = self.block2(out)
89 | out = self.block3(out)
90 | out = self.relu(self.bn1(out))
91 | out = F.avg_pool2d(out, 7)
92 | out = out.view(-1, self.nChannels)
93 | return self.fc(out)
94 |
95 |
96 | def wrn(**kwargs):
97 | """
98 |     Constructs a Wide Residual Network.
99 | """
100 | model = WideResNet(**kwargs)
101 | return model
102 |
--------------------------------------------------------------------------------
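`wrn` simply forwards its keyword arguments to `WideResNet`; with the WRN-28-10 configuration that `get_model` requests (depth=28, widen_factor=10), a CIFAR-sized input reduces to a 1x1 feature map before the classifier. A quick forward check (illustration only):

import torch
from models.wide_resnet import wrn

net = wrn(depth=28, num_classes=10, widen_factor=10, nc=3)
out = net(torch.randn(2, 3, 32, 32))
print(out.shape)  # expected: torch.Size([2, 10])
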
/notebooks/grad_cam.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 |     "# Grad-CAM\n",
8 | "\n",
9 | "In this notebook, we plot the [Grad-CAM](https://arxiv.org/abs/1610.02391) figures from the paper. For more info and other examples, have a look at [our README](https://github.com/ecs-vlc/fmix).\n",
10 | "\n",
11 | "**Note**: The easiest way to use this is as a colab notebook, which allows you to dive in with no setup.\n",
12 | "\n",
13 | "First, we load dependencies and some data from CIFAR-10:"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 1,
19 | "metadata": {},
20 | "outputs": [
21 | {
22 | "name": "stdout",
23 | "output_type": "stream",
24 | "text": [
25 | "Files already downloaded and verified\n"
26 | ]
27 | }
28 | ],
29 | "source": [
30 | "import numpy as np\n",
31 | "import torch\n",
32 | "import torch.nn as nn\n",
33 | "import torchvision\n",
34 | "from torchvision import transforms\n",
35 | "import models\n",
36 | "from torchbearer import Trial\n",
37 | "import cv2\n",
38 | "import torch.nn.functional as F\n",
39 | "\n",
40 | "inv_norm = transforms.Normalize((-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010), (1/0.2023, 1/0.1994, 1/0.2010))\n",
41 | "valset = torchvision.datasets.CIFAR10(root='./data/cifar', train=False, download=True,\n",
42 | " transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465),\n",
43 | " (0.2023, 0.1994, 0.2010)),]))\n",
44 | "\n",
45 | "valloader = torch.utils.data.DataLoader(valset, batch_size=1, shuffle=True, num_workers=8)"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {},
51 | "source": [
52 | "## Model Wrapper\n",
53 | "\n",
54 | "Next, we define a `ResNet` wrapper that will iterate up to some given block `k`:"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 2,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "class ResNet_CAM(nn.Module):\n",
64 | " def __init__(self, net, layer_k):\n",
65 | " super(ResNet_CAM, self).__init__()\n",
66 | " self.resnet = net\n",
67 | " convs = nn.Sequential(*list(net.children())[:-1])\n",
68 | " self.first_part_conv = convs[:layer_k]\n",
69 | " self.second_part_conv = convs[layer_k:]\n",
70 | " self.linear = nn.Sequential(*list(net.children())[-1:])\n",
71 | " \n",
72 | " def forward(self, x):\n",
73 | " x = self.first_part_conv(x)\n",
74 | " x.register_hook(self.activations_hook)\n",
75 | " x = self.second_part_conv(x)\n",
76 | " x = F.adaptive_avg_pool2d(x, (1,1))\n",
77 | " x = x.view((1, -1))\n",
78 | " x = self.linear(x)\n",
79 | " return x\n",
80 | " \n",
81 | " def activations_hook(self, grad):\n",
82 | " self.gradients = grad\n",
83 | " \n",
84 | " def get_activations_gradient(self):\n",
85 | " return self.gradients\n",
86 | " \n",
87 | " def get_activations(self, x):\n",
88 | " return self.first_part_conv(x)"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {},
94 | "source": [
95 | "## Grad-CAM\n",
96 | "\n",
97 | "Now for the Grad-CAM code, adapted from [implementing-grad-cam-in-pytorch](https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82):"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 3,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "def superimpose_heatmap(heatmap, img):\n",
107 | " resized_heatmap = cv2.resize(heatmap.numpy(), (img.shape[2], img.shape[3]))\n",
108 | " resized_heatmap = np.uint8(255 * resized_heatmap)\n",
109 | " resized_heatmap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)\n",
110 | " superimposed_img = torch.Tensor(cv2.cvtColor(resized_heatmap, cv2.COLOR_BGR2RGB)) * 0.006 + inv_norm(img[0]).permute(1,2,0)\n",
111 | " \n",
112 | " return superimposed_img\n",
113 | "\n",
114 | "def get_grad_cam(net, img):\n",
115 | " net.eval()\n",
116 | " pred = net(img)\n",
117 | " pred[:,pred.argmax(dim=1)].backward()\n",
118 | " gradients = net.get_activations_gradient()\n",
119 | " pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])\n",
120 | " activations = net.get_activations(img).detach()\n",
121 | " for i in range(activations.size(1)):\n",
122 | " activations[:, i, :, :] *= pooled_gradients[i]\n",
123 | " heatmap = torch.mean(activations, dim=1).squeeze()\n",
124 | " heatmap = np.maximum(heatmap, 0)\n",
125 | " heatmap /= torch.max(heatmap)\n",
126 | " \n",
127 | " return torch.Tensor(superimpose_heatmap(heatmap, img).permute(2,0,1))"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "metadata": {},
133 | "source": [
134 | "## Models From `torch.hub`\n",
135 | "\n",
136 |     "Next, we load the models from `torch.hub`:"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 4,
142 | "metadata": {},
143 | "outputs": [
144 | {
145 | "name": "stderr",
146 | "output_type": "stream",
147 | "text": [
148 | "Downloading: \"https://github.com/ecs-vlc/FMix/archive/master.zip\" to /home/ethan/.cache/torch/hub/master.zip\n",
149 | "Using cache found in /home/ethan/.cache/torch/hub/ecs-vlc_FMix_master\n",
150 | "Using cache found in /home/ethan/.cache/torch/hub/ecs-vlc_FMix_master\n",
151 | "Using cache found in /home/ethan/.cache/torch/hub/ecs-vlc_FMix_master\n"
152 | ]
153 | }
154 | ],
155 | "source": [
156 | "baseline_net = torch.hub.load('ecs-vlc/FMix:master', 'preact_resnet18_cifar10_baseline', pretrained=True)\n",
157 | "\n",
158 | "fmix_net = torch.hub.load('ecs-vlc/FMix:master', 'preact_resnet18_cifar10_fmix', pretrained=True)\n",
159 | "\n",
160 | "mixup_net = torch.hub.load('ecs-vlc/FMix:master', 'preact_resnet18_cifar10_mixup', pretrained=True)\n",
161 | "\n",
162 | "fmix_plus_net = torch.hub.load('ecs-vlc/FMix:master', 'preact_resnet18_cifar10_fmixplusmixup', pretrained=True)"
163 | ]
164 | },
165 | {
166 | "cell_type": "markdown",
167 | "metadata": {},
168 | "source": [
169 | "## Plots\n",
170 | "\n",
171 | "Finally, generate and save the Grad-CAM plots:"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 5,
177 | "metadata": {},
178 | "outputs": [],
179 | "source": [
180 | "layer_k = 4\n",
181 | "n_imgs = 10\n",
182 | "\n",
183 | "baseline_cam_net = ResNet_CAM(baseline_net, layer_k)\n",
184 | "fmix_cam_net = ResNet_CAM(fmix_net, layer_k)\n",
185 | "mixup_cam_net = ResNet_CAM(mixup_net, layer_k)\n",
186 | "fmix_plus_cam_net = ResNet_CAM(fmix_plus_net, layer_k)\n",
187 | "\n",
188 | "imgs = torch.Tensor(5, n_imgs, 3, 32, 32)\n",
189 | "it = iter(valloader)\n",
190 | "for i in range(0,n_imgs):\n",
191 | " img, _ = next(it)\n",
192 | " imgs[0][i] = inv_norm(img[0])\n",
193 | " imgs[1][i] = get_grad_cam(baseline_cam_net, img)\n",
194 | " imgs[2][i] = get_grad_cam(mixup_cam_net, img)\n",
195 | " imgs[3][i] = get_grad_cam(fmix_cam_net, img)\n",
196 | " imgs[4][i] = get_grad_cam(fmix_plus_cam_net, img)\n",
197 | "\n",
198 | "torchvision.utils.save_image(imgs.view(-1, 3, 32, 32), \"gradcam_at_layer\" + str(layer_k) + \".png\",nrow=n_imgs, pad_value=1)"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": null,
204 | "metadata": {},
205 | "outputs": [],
206 | "source": []
207 | }
208 | ],
209 | "metadata": {
210 | "kernelspec": {
211 | "display_name": "Python 3",
212 | "language": "python",
213 | "name": "python3"
214 | },
215 | "language_info": {
216 | "codemirror_mode": {
217 | "name": "ipython",
218 | "version": 3
219 | },
220 | "file_extension": ".py",
221 | "mimetype": "text/x-python",
222 | "name": "python",
223 | "nbconvert_exporter": "python",
224 | "pygments_lexer": "ipython3",
225 | "version": "3.7.3"
226 | }
227 | },
228 | "nbformat": 4,
229 | "nbformat_minor": 2
230 | }
231 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | numpy
4 | Pillow
5 | torchbearer
6 | pandas
7 | pytorch-lightning
8 | tensorflow
9 | pyarrow
10 | spacy
11 | torchtext
12 | scipy
13 | scikit-learn
14 | tensorboardX
15 | transformers
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import ast
3 | from datetime import datetime
4 |
5 | import pandas as pd
6 | import torch
7 | import torchbearer
8 | from torch import nn, optim
9 | from torchbearer import Trial
10 | from torchbearer.callbacks import MultiStepLR, CosineAnnealingLR
11 | from torchbearer.callbacks import TensorBoard, TensorBoardText, Cutout, CutMix, RandomErase, on_forward_validation
12 |
13 | from datasets.datasets import ds, dsmeta, nlp_data
14 | from implementations.torchbearer_implementation import FMix, PointNetFMix
15 | from models.models import get_model
16 | from utils import RMixup, MSDAAlternator, WarmupLR
17 | from datasets.toxic import ToxicHelper
18 |
19 | # Setup
20 | parser = argparse.ArgumentParser(description='PyTorch FMix Training')
21 | parser.add_argument('--dataset', type=str, default='cifar10',
22 | choices=['cifar10', 'cifar100', 'reduced_cifar', 'fashion', 'imagenet', 'imagenet_hdf5', 'imagenet_a', 'tinyimagenet',
23 | 'commands', 'modelnet', 'toxic', 'toxic_bert', 'bengali_r', 'bengali_c', 'bengali_v', 'imdb', 'yelp_2', 'yelp_5'])
24 | parser.add_argument('--dataset-path', type=str, default=None, help='Optional dataset path')
25 | parser.add_argument('--split-fraction', type=float, default=1., help='Fraction of total data to train on for reduced_cifar dataset')
26 | parser.add_argument('--pointcloud-resolution', default=128, type=int, help='Resolution of pointclouds in modelnet dataset')
27 | parser.add_argument('--model', default="ResNet18", type=str, help='model type')
28 | parser.add_argument('--epoch', default=200, type=int, help='total epochs to run')
29 | parser.add_argument('--train-steps', type=int, default=None, help='Number of training steps to run per "epoch"')
30 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
31 | parser.add_argument('--lr-warmup', type=ast.literal_eval, default=False, help='Use lr warmup')
32 | parser.add_argument('--batch-size', default=128, type=int, help='batch size')
33 | parser.add_argument('--device', default='cuda', type=str, help='Device on which to run')
34 | parser.add_argument('--num-workers', default=7, type=int, help='Number of dataloader workers')
35 |
36 | parser.add_argument('--auto-augment', type=ast.literal_eval, default=False, help='Use auto augment with cifar10/100')
37 | parser.add_argument('--augment', type=ast.literal_eval, default=True, help='use standard augmentation (default: True)')
38 | parser.add_argument('--parallel', type=ast.literal_eval, default=False, help='Use DataParallel')
39 | parser.add_argument('--reload', type=ast.literal_eval, default=False, help='Set to resume training from model path')
40 | parser.add_argument('--verbose', type=int, default=2, choices=[0, 1, 2])
41 | parser.add_argument('--seed', default=0, type=int, help='random seed')
42 |
43 | # Augs
44 | parser.add_argument('--random-erase', default=False, type=ast.literal_eval, help='Apply random erase')
45 | parser.add_argument('--cutout', default=False, type=ast.literal_eval, help='Apply Cutout')
46 | parser.add_argument('--msda-mode', default=None, type=str, choices=['fmix', 'cutmix', 'mixup', 'alt_mixup_fmix',
47 | 'alt_mixup_cutmix', 'alt_fmix_cutmix', 'None'])
48 |
49 | # Aug Params
50 | parser.add_argument('--alpha', default=1., type=float, help='mixup/fmix interpolation coefficient')
51 | parser.add_argument('--f-decay', default=3.0, type=float, help='decay power')
52 | parser.add_argument('--cutout_l', default=16, type=int, help='cutout/erase length')
53 | parser.add_argument('--reformulate', default=False, type=ast.literal_eval, help='Use reformulated fmix/mixup')
54 |
55 | # Scheduling
56 | parser.add_argument('--schedule', type=int, nargs='+', default=[100, 150], help='Decrease learning rate at these epochs.')
57 | parser.add_argument('--cosine-scheduler', type=ast.literal_eval, default=False, help='Set to use a cosine scheduler instead of step scheduler')
58 |
59 | # Cross validation
60 | parser.add_argument('--fold-path', type=str, default='./data/folds.npz', help='Path to object storing fold ids. Run-id 0 will regen this if not existing')
61 | parser.add_argument('--n-folds', type=int, default=6, help='Number of cross val folds')
62 | parser.add_argument('--fold', type=str, default='test', help='One of [1, ..., n] or "test"')
63 |
64 | # Logs
65 | parser.add_argument('--run-id', type=int, default=0, help='Run id')
66 | parser.add_argument('--log-dir', default='./logs/testing', help='Tensorboard log dir')
67 | parser.add_argument('--model-file', default='./saved_models/model.pt', help='Path under which to save the model, e.g. ./model.pt')
68 | args = parser.parse_args()
69 |
70 |
71 | if args.seed != 0:
72 | torch.manual_seed(args.seed)
73 |
74 |
75 | print('==> Preparing data..')
76 | data = ds[args.dataset]
77 | meta = dsmeta[args.dataset]
78 | classes, nc, size = meta['classes'], meta['nc'], meta['size']
79 |
80 | trainset, valset, testset = data(args)
81 |
82 | # The toxic comments dataset uses its own data loaders
83 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) if (trainset is not None) and (args.dataset not in nlp_data) else trainset
84 | valloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) if (valset is not None) and (args.dataset not in nlp_data) else valset
85 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) if (args.dataset not in nlp_data) else testset
86 |
87 | print('==> Building model..')
88 | net = get_model(args, classes, nc)
89 | net = nn.DataParallel(net) if args.parallel else net
90 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
91 |
92 | if (args.dataset in nlp_data) or ('modelnet' in args.dataset):
93 | optimizer = optim.Adam(net.parameters(), lr=args.lr)
94 |
95 |
96 | print('==> Setting up callbacks..')
97 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') + "-run-" + str(args.run_id)
98 | tboard = TensorBoard(write_graph=False, comment=current_time, log_dir=args.log_dir)
99 | tboardtext = TensorBoardText(write_epoch_metrics=False, comment=current_time, log_dir=args.log_dir)
100 |
101 |
102 | @torchbearer.callbacks.on_start
103 | def write_params(_):
104 | params = vars(args)
105 | params['schedule'] = str(params['schedule'])
106 | df = pd.DataFrame(params, index=[0]).transpose()
107 | tboardtext.get_writer(tboardtext.log_dir).add_text('params', df.to_html(), 1)
108 |
109 |
110 | modes = {
111 | 'fmix': FMix(decay_power=args.f_decay, alpha=args.alpha, size=size, max_soft=0, reformulate=args.reformulate),
112 | 'mixup': RMixup(args.alpha, reformulate=args.reformulate),
113 | 'cutmix': CutMix(args.alpha, classes, True),
114 | 'pointcloud_fmix': PointNetFMix(args.pointcloud_resolution, decay_power=args.f_decay, alpha=args.alpha, max_soft=0,
115 | reformulate=args.reformulate)
116 | }
117 | modes.update({
118 | 'alt_mixup_fmix': MSDAAlternator(modes['fmix'], modes['mixup']),
119 | 'alt_mixup_cutmix': MSDAAlternator(modes['mixup'], modes['cutmix']),
120 | 'alt_fmix_cutmix': MSDAAlternator(modes['fmix'], modes['cutmix']),
121 | })
122 |
123 | # Pointcloud fmix converts voxel grids back into point clouds after mixing
124 | mode = 'pointcloud_fmix' if (args.msda_mode == 'fmix' and args.dataset == 'modelnet') else args.msda_mode
125 |
126 | # CutMix callback returns mixed and original targets. We mix in the loss function instead
127 | @torchbearer.callbacks.on_sample
128 | def cutmix_reformat(state):
129 | state[torchbearer.Y_TRUE] = state[torchbearer.Y_TRUE][0]
130 |
131 | cb = [tboard, tboardtext, write_params, torchbearer.callbacks.MostRecent(args.model_file)]
132 | # Toxic helper needs to go before the msda to reshape the input
133 | if args.dataset in ['toxic', 'imdb', 'yelp_2', 'yelp_5']: cb.append(ToxicHelper(to_float=args.dataset != 'yelp_5'))
134 | if args.msda_mode not in [None, 'None']: cb.append(modes[mode])
135 | if args.cutout: cb.append(Cutout(1, args.cutout_l))
136 | if args.random_erase: cb.append(RandomErase(1, args.cutout_l))
137 | # WARNING: Schedulers appear to be broken (wrong lr output) in some versions of PyTorch, including 1.4. We used 1.3.1
138 | cb.append(MultiStepLR(args.schedule) if not args.cosine_scheduler else CosineAnnealingLR(args.epoch, eta_min=0.))
139 | if args.lr_warmup: cb.append(WarmupLR(0.1, args.lr))
140 | if args.msda_mode == 'cutmix': cb.append(cutmix_reformat)
141 |
142 | # FMix loss is equivalent to mixup loss and works for all msda in torchbearer
143 | if args.msda_mode not in [None, 'None']:
144 |     bce = args.dataset in ['toxic', 'toxic_bert', 'imdb', 'yelp_2']
145 | criterion = modes['fmix'].loss(bce)
146 | elif args.dataset in ['toxic', 'toxic_bert', 'imdb', 'yelp_2']:
147 | criterion = nn.BCEWithLogitsLoss()
148 | else:
149 | criterion = nn.CrossEntropyLoss()
150 |
151 | metrics_append = []
152 | if 'bengali' in args.dataset:
153 | from utils.macro_recall import MacroRecall
154 | metrics_append = [MacroRecall()]
155 | elif 'imagenet' in args.dataset:
156 | metrics_append = ['top_5_acc']
157 | elif 'toxic' in args.dataset:
158 | from torchbearer.metrics import to_dict, EpochLambda
159 |
160 | @to_dict
161 | class RocAucScore(EpochLambda):
162 | def __init__(self):
163 | import sklearn.metrics
164 |
165 | super().__init__('roc_auc_score',
166 | lambda y_pred, y_true: sklearn.metrics.roc_auc_score(y_true.cpu().numpy(), y_pred.detach().sigmoid().cpu().numpy()),
167 | running=False)
168 | metrics_append = [RocAucScore()]
169 |
170 | if args.dataset == 'imagenet_a':
171 | from datasets.imagenet_a import indices_in_1k
172 |
173 | @on_forward_validation
174 | def map_preds(state):
175 | state[torchbearer.PREDICTION] = state[torchbearer.PREDICTION][:, indices_in_1k]
176 | cb.append(map_preds)
177 |
178 | print('==> Training model..')
179 |
180 | acc = 'acc'
181 | if args.dataset in ['toxic', 'toxic_bert', 'imdb', 'yelp_2']:
182 | acc = 'binary_acc'
183 |
184 | trial = Trial(net, optimizer, criterion, metrics=[acc, 'loss', 'lr'] + metrics_append, callbacks=cb)
185 | trial.with_generators(train_generator=trainloader, val_generator=valloader, train_steps=args.train_steps, test_generator=testloader).to(args.device)
186 |
187 | if args.reload:
188 | state = torch.load(args.model_file)
189 | trial.load_state_dict(state, resume=args.dataset != 'imagenet_a')
190 | trial.replay()
191 |
192 | if trainloader is not None:
193 | trial.run(args.epoch, verbose=args.verbose)
194 | trial.evaluate(data_key=torchbearer.TEST_DATA)
195 |
--------------------------------------------------------------------------------
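For reference, a minimal self-contained sketch (not part of the repository) of the torchbearer wiring that trainer.py assembles: an FMix callback plus its matching loss. The placeholder model and random TensorDataset below are assumptions standing in for models.models.get_model and the real data loaders built from datasets.datasets.ds.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from torchbearer import Trial
from implementations.torchbearer_implementation import FMix

# FMix callback configured roughly as trainer.py does for CIFAR-sized inputs
fmix = FMix(decay_power=3.0, alpha=1.0, size=(32, 32), max_soft=0, reformulate=False)

# Placeholder model and random data (illustrative only)
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
data = TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 10, (256,)))
loader = DataLoader(data, batch_size=128, shuffle=True)

# fmix.loss(False) is the mixed cross-entropy criterion used whenever an MSDA mode is active
trial = Trial(net, optimizer, fmix.loss(False), metrics=['acc', 'loss'], callbacks=[fmix])
trial.with_generators(train_generator=loader, val_generator=loader).to('cpu')
trial.run(1, verbose=2)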
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .lr_warmup import WarmupLR
2 | from .msda_alternator import MSDAAlternator
3 | from .reformulated_mixup import RMixup
4 | from .cross_val import split, gen_folds
5 | from .reduced_dataset_splitter import EqualSplitter
6 | from .auto_augment.auto_augment import auto_augment, _fa_reduced_cifar10
--------------------------------------------------------------------------------
/utils/auto_augment/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ecs-vlc/FMix/e5991dca018882734c8ea63599f10dfbe67fa0ae/utils/auto_augment/__init__.py
--------------------------------------------------------------------------------
/utils/auto_augment/auto_augment_aug_list.py:
--------------------------------------------------------------------------------
1 | # Code in this file is adapted from rpmcruz/autoaugment
2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 | import random
4 |
5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
6 | import numpy as np
7 | import torch
8 |
9 | random_mirror = True
10 |
11 |
12 | def ShearX(img, v): # [-0.3, 0.3]
13 | assert -0.3 <= v <= 0.3
14 | if random_mirror and random.random() > 0.5:
15 | v = -v
16 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
17 |
18 |
19 | def ShearY(img, v): # [-0.3, 0.3]
20 | assert -0.3 <= v <= 0.3
21 | if random_mirror and random.random() > 0.5:
22 | v = -v
23 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
24 |
25 |
26 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
27 | assert -0.45 <= v <= 0.45
28 | if random_mirror and random.random() > 0.5:
29 | v = -v
30 | v = v * img.size[0]
31 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
32 |
33 |
34 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
35 | assert -0.45 <= v <= 0.45
36 | if random_mirror and random.random() > 0.5:
37 | v = -v
38 | v = v * img.size[1]
39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
40 |
41 |
42 | def TranslateXAbs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
43 | assert 0 <= v <= 10
44 | if random.random() > 0.5:
45 | v = -v
46 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
47 |
48 |
49 | def TranslateYAbs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
50 | assert 0 <= v <= 10
51 | if random.random() > 0.5:
52 | v = -v
53 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
54 |
55 |
56 | def Rotate(img, v): # [-30, 30]
57 | assert -30 <= v <= 30
58 | if random_mirror and random.random() > 0.5:
59 | v = -v
60 | return img.rotate(v)
61 |
62 |
63 | def AutoContrast(img, _):
64 | return PIL.ImageOps.autocontrast(img)
65 |
66 |
67 | def Invert(img, _):
68 | return PIL.ImageOps.invert(img)
69 |
70 |
71 | def Equalize(img, _):
72 | return PIL.ImageOps.equalize(img)
73 |
74 |
75 | def Flip(img, _): # not from the paper
76 | return PIL.ImageOps.mirror(img)
77 |
78 |
79 | def Solarize(img, v): # [0, 256]
80 | assert 0 <= v <= 256
81 | return PIL.ImageOps.solarize(img, v)
82 |
83 |
84 | def Posterize(img, v): # [4, 8]
85 | assert 4 <= v <= 8
86 | v = int(v)
87 | return PIL.ImageOps.posterize(img, v)
88 |
89 |
90 | def Posterize2(img, v): # [0, 4]
91 | assert 0 <= v <= 4
92 | v = int(v)
93 | return PIL.ImageOps.posterize(img, v)
94 |
95 |
96 | def Contrast(img, v): # [0.1,1.9]
97 | assert 0.1 <= v <= 1.9
98 | return PIL.ImageEnhance.Contrast(img).enhance(v)
99 |
100 |
101 | def Color(img, v): # [0.1,1.9]
102 | assert 0.1 <= v <= 1.9
103 | return PIL.ImageEnhance.Color(img).enhance(v)
104 |
105 |
106 | def Brightness(img, v): # [0.1,1.9]
107 | assert 0.1 <= v <= 1.9
108 | return PIL.ImageEnhance.Brightness(img).enhance(v)
109 |
110 |
111 | def Sharpness(img, v): # [0.1,1.9]
112 | assert 0.1 <= v <= 1.9
113 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
114 |
115 |
116 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
117 | assert 0.0 <= v <= 0.2
118 | if v <= 0.:
119 | return img
120 |
121 | v = v * img.size[0]
122 | return CutoutAbs(img, v)
123 |
124 |
125 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
126 | # assert 0 <= v <= 20
127 | if v < 0:
128 | return img
129 | w, h = img.size
130 | x0 = np.random.uniform(w)
131 | y0 = np.random.uniform(h)
132 |
133 | x0 = int(max(0, x0 - v / 2.))
134 | y0 = int(max(0, y0 - v / 2.))
135 | x1 = min(w, x0 + v)
136 | y1 = min(h, y0 + v)
137 |
138 | xy = (x0, y0, x1, y1)
139 | color = (125, 123, 114)
140 | # color = (0, 0, 0)
141 | img = img.copy()
142 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
143 | return img
144 |
145 |
146 | def SamplePairing(imgs): # [0, 0.4]
147 | def f(img1, v):
148 | i = np.random.choice(len(imgs))
149 | img2 = PIL.Image.fromarray(imgs[i])
150 | return PIL.Image.blend(img1, img2, v)
151 |
152 | return f
153 |
154 |
155 | def augment_list(for_autoaug=True):  # 16 operations and their ranges
156 | l = [
157 | (ShearX, -0.3, 0.3), # 0
158 | (ShearY, -0.3, 0.3), # 1
159 | (TranslateX, -0.45, 0.45), # 2
160 | (TranslateY, -0.45, 0.45), # 3
161 | (Rotate, -30, 30), # 4
162 | (AutoContrast, 0, 1), # 5
163 | (Invert, 0, 1), # 6
164 | (Equalize, 0, 1), # 7
165 | (Solarize, 0, 256), # 8
166 | (Posterize, 4, 8), # 9
167 | (Contrast, 0.1, 1.9), # 10
168 | (Color, 0.1, 1.9), # 11
169 | (Brightness, 0.1, 1.9), # 12
170 | (Sharpness, 0.1, 1.9), # 13
171 | (Cutout, 0, 0.2), # 14
172 | # (SamplePairing(imgs), 0, 0.4), # 15
173 | ]
174 | if for_autoaug:
175 | l += [
176 | (CutoutAbs, 0, 20), # compatible with auto-augment
177 |             (Posterize2, 0, 4),
178 |             (TranslateXAbs, 0, 10),
179 |             (TranslateYAbs, 0, 10),
180 | ]
181 | return l
182 |
183 |
184 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()}
185 |
186 |
187 | def get_augment(name):
188 | return augment_dict[name]
189 |
190 |
191 | def apply_augment(img, name, level):
192 | augment_fn, low, high = get_augment(name)
193 | return augment_fn(img.copy(), level * (high - low) + low)
194 |
195 |
196 | class Lighting(object):
197 | """Lighting noise(AlexNet - style PCA - based noise)"""
198 |
199 | def __init__(self, alphastd, eigval, eigvec):
200 | self.alphastd = alphastd
201 | self.eigval = torch.Tensor(eigval)
202 | self.eigvec = torch.Tensor(eigvec)
203 |
204 | def __call__(self, img):
205 | if self.alphastd == 0:
206 | return img
207 |
208 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
209 | rgb = self.eigvec.type_as(img).clone() \
210 | .mul(alpha.view(1, 3).expand(3, 3)) \
211 | .mul(self.eigval.view(1, 3).expand(3, 3)) \
212 | .sum(1).squeeze()
213 |
214 | return img.add(rgb.view(3, 1, 1).expand_as(img))
--------------------------------------------------------------------------------
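A quick illustrative use of the list above (not part of the repository): apply_augment looks an operation up by name and maps a level in [0, 1] onto that operation's (low, high) range.

import PIL.Image
from utils.auto_augment.auto_augment_aug_list import apply_augment, augment_list

img = PIL.Image.new('RGB', (32, 32), color=(128, 128, 128))  # dummy image

# level=0.5 maps to the midpoint of each operation's range
sheared = apply_augment(img, 'ShearX', 0.5)
rotated = apply_augment(img, 'Rotate', 0.5)

# Enumerate every registered operation and its range
for fn, low, high in augment_list():
    print(fn.__name__, low, high)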
/utils/bengali_evaluate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torchbearer
4 | from torchbearer import metrics, Trial
5 | from datasets.datasets import bengali
6 | import argparse
7 | from models import se_resnext50_32x4d
8 |
9 | parser = argparse.ArgumentParser(description='Bengali Evaluate')
10 |
11 | parser.add_argument('--dataset', type=str, default='bengali', help='Optional dataset path')
12 | parser.add_argument('--dataset-path', type=str, default=None, help='Optional dataset path')
13 | parser.add_argument('--fold-path', type=str, default='./data/folds.npz', help='Path to object storing fold ids. Run-id 0 will regen this if not existing')
14 | parser.add_argument('--fold', type=str, default='test', help='One of [1, ..., n] or "test"')
15 | parser.add_argument('--run-id', type=int, default=0, help='Run id')
16 |
17 | parser.add_argument('--model-r', type=str, default=None, help='Root model')
18 | parser.add_argument('--model-v', type=str, default=None, help='Vowel model')
19 | parser.add_argument('--model-c', type=str, default=None, help='Consonant model')
20 |
21 | args = parser.parse_args()
22 |
23 | _, __, testset = bengali(args)
24 |
25 | testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=6)
26 |
27 |
28 | class BengaliModelWrapper(nn.Module):
29 | def __init__(self, model_r, model_v, model_c):
30 | super().__init__()
31 |
32 | self.model_r = model_r
33 | self.model_v = model_v
34 | self.model_c = model_c
35 |
36 | def forward(self, x):
37 | return self.model_r(x), self.model_v(x), self.model_c(x)
38 |
39 |
40 | @metrics.default_for_key('grapheme')
41 | @metrics.mean
42 | class GraphemeAccuracy(metrics.Metric):
43 | def __init__(self):
44 | super().__init__('grapheme_acc')
45 |
46 | def process(self, *args):
47 | state = args[0]
48 | r_pred, v_pred, c_pred = state[torchbearer.PREDICTION]
49 | r_true, v_true, c_true = state[torchbearer.TARGET]
50 |
51 | _, r_pred = torch.max(r_pred, 1)
52 | _, v_pred = torch.max(v_pred, 1)
53 | _, c_pred = torch.max(c_pred, 1)
54 |
55 | r = (r_pred == r_true)
56 | v = (v_pred == v_true)
57 | c = (c_pred == c_true)
58 | return torch.stack((r, v, c), dim=1).all(1).float()
59 |
60 |
61 | model_r = se_resnext50_32x4d(168, 1)
62 | model_v = se_resnext50_32x4d(11, 1)
63 | model_c = se_resnext50_32x4d(7, 1)
64 |
65 | model_r.load_state_dict(torch.load(args.model_r, map_location='cpu')[torchbearer.MODEL])
66 | model_v.load_state_dict(torch.load(args.model_v, map_location='cpu')[torchbearer.MODEL])
67 | model_c.load_state_dict(torch.load(args.model_c, map_location='cpu')[torchbearer.MODEL])
68 |
69 | model = BengaliModelWrapper(model_r, model_v, model_c)
70 |
71 | trial = Trial(model, criterion=lambda state: None, metrics=['grapheme']).with_test_generator(testloader).to('cuda')
72 | trial.evaluate(data_key=torchbearer.TEST_DATA)
73 |
--------------------------------------------------------------------------------
/utils/convert_imagenet_model.py:
--------------------------------------------------------------------------------
1 | """
2 | ImageNet models trained with the original imagenet_hdf5 data set have their outputs in the wrong order. This loads a
3 | model and re-orders the output weights to be consistent with other ImageNet models.
4 | """
5 | import argparse
6 | import os
7 | from torchvision.models import resnet101
8 | import torch
9 | import torchbearer
10 | from torch import nn
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--dataset-path', type=str, default=None, help='ImageNet path')
14 | parser.add_argument('--model-file')
15 | args = parser.parse_args()
16 |
17 | root = os.path.join(args.dataset_path, 'train')
18 |
19 | old_classes = list(filter(lambda f: '.hdf5' in f, os.listdir(root)))
20 | new_classes = sorted(old_classes)
21 |
22 | model = nn.DataParallel(resnet101(False))
23 | sd = torch.load(args.model_file, map_location='cpu')[torchbearer.MODEL]
24 | model.load_state_dict(sd)
25 | model = model.module
26 |
27 | new_weights = torch.zeros_like(model.fc.weight.data)
28 | new_bias = torch.zeros_like(model.fc.bias.data)
29 |
30 | for layer in range(1000):
31 | new_layer = new_classes.index(old_classes[layer])
32 |
33 |     new_weights[new_layer, :] = model.fc.weight[layer, :]
34 |     new_bias[new_layer] = model.fc.bias[layer]
35 |
36 | model.fc.weight.data = new_weights.data
37 | model.fc.bias.data = new_bias.data
38 |
39 | torch.save(model.state_dict(), args.model_file + '_converted.pt')
40 |
--------------------------------------------------------------------------------
/utils/cross_val.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | import os
4 | from torchbearer.cv_utils import DatasetValidationSplitter
5 |
6 |
7 | def gen_folds(args, dataset, test_size):
8 | from sklearn.utils.validation import check_random_state
9 | from sklearn.model_selection._split import _validate_shuffle_split
10 |
11 | n_samples = len(dataset)
12 | n_train, n_test = _validate_shuffle_split(n_samples, test_size, None, default_test_size=0.1)
13 | rng = check_random_state(args.seed)
14 |
15 | train_folds = []
16 | test_folds = []
17 |
18 | for i in range(args.n_folds):
19 | # random partition
20 | permutation = rng.permutation(n_samples)
21 | ind_test = permutation[:n_test]
22 | ind_train = permutation[n_test:(n_test + n_train)]
23 | train_folds.append(ind_train)
24 | test_folds.append(ind_test)
25 |
26 | train_folds, test_folds = np.stack(train_folds), np.stack(test_folds)
27 | np.savez(args.fold_path, train=train_folds, test=test_folds)
28 |
29 |
30 | def split(func):
31 | def splitting(args):
32 | try:
33 | trainset, testset = func(args)
34 |
35 | if args.fold == 'test':
36 | return trainset, testset, testset
37 | except:
38 | trainset = func(args)
39 | testset = None
40 |
41 | if args.run_id == 0 and not os.path.exists(args.fold_path):
42 | gen_folds(args, trainset, len(trainset) // args.n_folds)
43 | else:
44 | time.sleep(3)
45 |
46 | folds = np.load(args.fold_path)
47 | train_ids, val_ids = folds['train'][int(args.fold)], folds['test'][int(args.fold)]
48 |
49 | splitter = DatasetValidationSplitter(len(trainset), 0.1)
50 | splitter.train_ids, splitter.valid_ids = train_ids, val_ids
51 |
52 | trainset, valset = splitter.get_train_dataset(trainset), splitter.get_val_dataset(trainset)
53 | return trainset, valset, (testset if testset is not None else valset)
54 |
55 | return splitting
56 |
--------------------------------------------------------------------------------
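A hedged sketch of how the split decorator is intended to wrap a dataset constructor; the real wrapped constructors live in datasets/datasets.py, and the CIFAR-10 builder below is only an assumed example. The wrapped function returns (trainset, testset), and split turns that into (train, val, test) according to args.fold.

from types import SimpleNamespace
from torchvision import datasets, transforms
from utils.cross_val import split

@split
def cifar(args):
    tf = transforms.ToTensor()
    train = datasets.CIFAR10(root='./data', train=True, download=True, transform=tf)
    test = datasets.CIFAR10(root='./data', train=False, download=True, transform=tf)
    return train, test

# fold='test': train on everything, validate and test on the held-out test set
args = SimpleNamespace(fold='test', run_id=0, n_folds=6, seed=0, fold_path='./data/folds.npz')
trainset, valset, testset = cifar(args)

# A numeric fold (e.g. fold='1') would instead draw train/val indices from
# folds.npz, which run-id 0 generates via gen_folds on first use.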
/utils/imagenet_to_hdf5.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | import multiprocessing
4 | import os
5 | # from PIL import Image
6 | import pickle
7 | import process
8 |
9 | SIZE = 256
10 |
11 | imagenet_dir = '/ssd/ILSVRC2012/train'
12 | target_dir = '/data/imagenet_hdf5/train'
13 |
14 | subdirs = [d for d in os.listdir(imagenet_dir) if os.path.isdir(os.path.join(imagenet_dir, d))]
15 |
16 | num_cpus = multiprocessing.cpu_count()
17 | pool = multiprocessing.Pool(num_cpus)
18 |
19 | dest_list = []
20 |
21 | for subdir in subdirs:
22 | from_dir = os.path.join(imagenet_dir, subdir)
23 | to_path = os.path.join(target_dir, subdir + '.hdf5')
24 |
25 | paths = [os.path.join(from_dir, d) for d in os.listdir(from_dir)]
26 |
27 | # arr = np.array(pool.map(process.process, paths), dtype='uint8')
28 | images = pool.map(process.read_bytes, paths)
29 | with h5py.File(to_path, 'w') as f:
30 | dt = h5py.special_dtype(vlen=np.dtype('uint8'))
31 | dset = f.create_dataset('data', (len(paths), ), dtype=dt)
32 | for i, image in enumerate(images):
33 |             dset[i] = np.frombuffer(image, dtype='uint8')
34 |
35 | for i, path in enumerate(paths):
36 | dest_list.append((subdir, i))
37 |
38 | print(subdir)
39 |
40 | pickle.dump(dest_list, open(os.path.join(target_dir, 'dest.p'), 'wb'))
41 |
--------------------------------------------------------------------------------
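A sketch (not part of the repository) of reading one image back out of the per-class HDF5 files written above; the loader actually used for training is datasets/imagenet_hdf5.py. The class file name here is hypothetical.

import io
import h5py
from PIL import Image

path = '/data/imagenet_hdf5/train/n01440764.hdf5'  # hypothetical class file

with h5py.File(path, 'r') as f:
    raw = f['data'][0]  # variable-length uint8 array holding the raw JPEG bytes
    img = Image.open(io.BytesIO(raw.tobytes())).convert('RGB')
    print(img.size)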
/utils/lr_warmup.py:
--------------------------------------------------------------------------------
1 | from torchbearer.callbacks import Callback
2 | import torchbearer
3 |
4 |
5 | class WarmupLR(Callback):
6 | def __init__(self, min_lr, max_lr, warmup_period=5):
7 | super().__init__()
8 | self.min_lr = min_lr
9 | self.max_lr = max_lr
10 | self.t = warmup_period
11 |
12 | def on_start_training(self, state):
13 | super().on_start_training(state)
14 | if state[torchbearer.EPOCH] < self.t:
15 | delta = (self.t - state[torchbearer.EPOCH])/self.t
16 | opt = state[torchbearer.OPTIMIZER]
17 |
18 | for pg in opt.param_groups:
19 | pg['lr'] = self.min_lr * delta + self.max_lr * (1-delta)
20 |
--------------------------------------------------------------------------------
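A small self-contained check (not in the repository) of the schedule implemented above: during the first warmup_period epochs the learning rate is interpolated linearly from min_lr towards max_lr, after which the callback leaves the optimizer alone and the main scheduler takes over.

import torch
import torchbearer
from torch import nn, optim
from utils.lr_warmup import WarmupLR

model = nn.Linear(2, 2)
opt = optim.SGD(model.parameters(), lr=0.1)
warmup = WarmupLR(min_lr=0.01, max_lr=0.1, warmup_period=5)

for epoch in range(6):
    state = {torchbearer.EPOCH: epoch, torchbearer.OPTIMIZER: opt}
    warmup.on_start_training(state)
    print(epoch, opt.param_groups[0]['lr'])
# epoch 0 -> 0.01, epoch 4 -> 0.082; from epoch 5 the lr is no longer modified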
/utils/macro_recall.py:
--------------------------------------------------------------------------------
1 | from torchbearer.metrics import default_for_key, to_dict, EpochLambda
2 |
3 |
4 | @default_for_key('macro_recall')
5 | @to_dict
6 | class MacroRecall(EpochLambda):
7 | def __init__(self):
8 | from sklearn import metrics
9 |
10 | def process(y_pred):
11 | return y_pred.max(1)[1]
12 |
13 | super().__init__('macro_recall', lambda y_pred, y_true: metrics.recall_score(y_true.cpu().numpy(), process(y_pred).detach().cpu().numpy(), average='macro'))
14 |
--------------------------------------------------------------------------------
/utils/msda_alternator.py:
--------------------------------------------------------------------------------
1 | from torchbearer import Callback
2 |
3 |
4 | class MSDAAlternator(Callback):
5 | def __init__(self, msda_a, msda_b, n_a=1, n_b=1):
6 | super().__init__()
7 | self.augs = ((msda_a, n_a), (msda_b, n_b))
8 | self.current_aug = 0
9 | self.current_steps = 0
10 |
11 | def on_sample(self, state):
12 | super().on_sample(state)
13 |
14 | aug, steps = self.augs[self.current_aug]
15 | aug.on_sample(state)
16 |
17 | self.current_steps = self.current_steps + 1
18 | if self.current_steps >= steps:
19 | self.current_aug = (self.current_aug + 1) % 2
20 | self.current_steps = 0
--------------------------------------------------------------------------------
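An illustrative configuration (not from the repository) mirroring how trainer.py builds its alt_* modes: alternate between FMix and reformulated mixup, here applying two FMix batches for every mixup batch.

from implementations.torchbearer_implementation import FMix
from utils.reformulated_mixup import RMixup
from utils.msda_alternator import MSDAAlternator

fmix = FMix(decay_power=3.0, alpha=1.0, size=(32, 32))
mixup = RMixup(alpha=1.0)

# n_a / n_b control how many consecutive batches each augmentation handles
alternator = MSDAAlternator(fmix, mixup, n_a=2, n_b=1)
# Pass `alternator` to a torchbearer Trial as a callback, as trainer.py does
# for its alt_mixup_fmix / alt_mixup_cutmix / alt_fmix_cutmix modes.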
/utils/process.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | SIZE = 256
5 |
6 |
7 | def process(path):
8 | with open(path, 'rb') as f:
9 | img = Image.open(f)
10 | img = img.convert('RGB')
11 | return np.array(img.resize((SIZE, SIZE), Image.BILINEAR))
12 |
13 |
14 | def read_bytes(path):
15 |     with open(path, 'rb') as f:
16 |         return f.read()
17 |
--------------------------------------------------------------------------------
/utils/reduced_dataset_splitter.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 |
4 | from torchbearer.cv_utils import SubsetDataset
5 |
6 |
7 | class EqualSplitter:
8 | """ Splits a dataset into two parts, taking equal number of samples from each class
9 |
10 | :param dataset:
11 | :param split_fraction:
12 | """
13 | def __init__(self, dataset, split_fraction):
14 | self.ds = dataset
15 | self.split_fraction = split_fraction
16 | self.train_ids, self.valid_ids = self._gen_split_ids()
17 |
18 | def _gen_split_ids(self):
19 | classes = defaultdict(list)
20 | for i in range(len(self.ds)):
21 | _, label = self.ds[i]
22 | classes[label].append(i)
23 |
24 | nc = len(classes.keys())
25 | cut_per_class = int(len(self.ds) * self.split_fraction / nc)
26 |
27 | cut_indexes = []
28 | retained_indexes = []
29 | for c in classes.keys():
30 | random.shuffle(classes[c])
31 | cut_indexes += classes[c][:cut_per_class]
32 | retained_indexes += classes[c][cut_per_class:]
33 | return cut_indexes, retained_indexes
34 |
35 | def get_train_dataset(self):
36 | """ Creates a training dataset from existing dataset
37 |
38 | Args:
39 | dataset (torch.utils.data.Dataset): Dataset to be split into a training dataset
40 |
41 | Returns:
42 | torch.utils.data.Dataset: Training dataset split from whole dataset
43 | """
44 | return SubsetDataset(self.ds, self.train_ids)
45 |
46 | def get_val_dataset(self):
47 | """ Creates a validation dataset from existing dataset
48 |
49 | Args:
50 | dataset (torch.utils.data.Dataset): Dataset to be split into a validation dataset
51 |
52 | Returns:
53 | torch.utils.data.Dataset: Validation dataset split from whole dataset
54 | """
55 | return SubsetDataset(self.ds, self.valid_ids)
56 |
57 |
--------------------------------------------------------------------------------
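A hedged usage sketch (not in the repository), presumably how the --split-fraction option for the reduced_cifar dataset is realised: keep a class-balanced 10% of CIFAR-10 for training, with the remainder returned by get_val_dataset.

from torchvision import datasets
from utils.reduced_dataset_splitter import EqualSplitter

full = datasets.CIFAR10(root='./data', train=True, download=True)

splitter = EqualSplitter(full, split_fraction=0.1)
small_train = splitter.get_train_dataset()  # ~10% of the data, equal per class
remainder = splitter.get_val_dataset()      # everything that was not selected

print(len(small_train), len(remainder))     # roughly 5000 and 45000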
/utils/reformulated_mixup.py:
--------------------------------------------------------------------------------
1 | # Code modified from https://github.com/pytorchbearer/torchbearer
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.distributions.beta import Beta
6 |
7 | import torchbearer
8 | from torchbearer.callbacks import Callback
9 | from torchbearer import metrics as m
10 |
11 |
12 | @m.running_mean
13 | @m.mean
14 | class MixupAcc(m.AdvancedMetric):
15 | def __init__(self):
16 |         super(MixupAcc, self).__init__('mixup_acc')
17 | self.cat_acc = m.CategoricalAccuracy().root
18 |
19 | def process_train(self, *args):
20 |         super(MixupAcc, self).process_train(*args)
21 | state = args[0]
22 |
23 | target1 = state[torchbearer.Y_TRUE]
24 | target2 = target1[state[torchbearer.MIXUP_PERMUTATION]]
25 | _state = args[0].copy()
26 | _state[torchbearer.Y_TRUE] = target1
27 | acc1 = self.cat_acc.process(_state)
28 |
29 | _state = args[0].copy()
30 | _state[torchbearer.Y_TRUE] = target2
31 | acc2 = self.cat_acc.process(_state)
32 |
33 | return acc1 * state[torchbearer.MIXUP_LAMBDA] + acc2 * (1 - state[torchbearer.MIXUP_LAMBDA])
34 |
35 | def process_validate(self, *args):
36 |         super(MixupAcc, self).process_validate(*args)
37 |
38 | return self.cat_acc.process(*args)
39 |
40 | def reset(self, state):
41 | self.cat_acc.reset(state)
42 |
43 |
44 | class RMixup(Callback):
45 | """Perform mixup on the model inputs. Requires use of :meth:`MixupInputs.loss`, otherwise lambdas can be found in
46 | state under :attr:`.MIXUP_LAMBDA`. Model targets will be a tuple containing the original target and permuted target.
47 |
48 | .. note::
49 |
50 | The accuracy metric for mixup is different on training to deal with the different targets,
51 | but for validation it is exactly the categorical accuracy, despite being called "val_mixup_acc"
52 |
53 | Example: ::
54 |
55 | >>> from torchbearer import Trial
56 |         >>> from utils.reformulated_mixup import RMixup
57 | 
58 |         # Example Trial which does Mixup regularisation
59 |         >>> mixup = RMixup(0.9)
60 |         >>> trial = Trial(None, criterion=RMixup.mixup_loss, callbacks=[mixup], metrics=['acc'])
61 |
62 | Args:
63 |         alpha (float): The alpha value to use in the beta distribution.
64 |         lam (float): Mix inputs by fraction lam. If RANDOM, sample lambda from Beta(alpha, alpha), or Beta(alpha + 1, alpha) when reformulate=True. Else, lambda=lam
65 | """
66 | RANDOM = -10.0
67 |
68 | def __init__(self, alpha=1.0, lam=RANDOM, reformulate=False):
69 | super(RMixup, self).__init__()
70 | self.alpha = alpha
71 | self.lam = lam
72 | self.reformulate = reformulate
73 | self.distrib = Beta(self.alpha, self.alpha) if not reformulate else Beta(self.alpha + 1, self.alpha)
74 |
75 | @staticmethod
76 | def mixup_loss(state):
77 | """The standard cross entropy loss formulated for mixup (weighted combination of `F.cross_entropy`).
78 |
79 | Args:
80 | state: The current :class:`Trial` state.
81 | """
82 | input, target = state[torchbearer.Y_PRED], state[torchbearer.Y_TRUE]
83 |
84 | if state[torchbearer.DATA] is torchbearer.TRAIN_DATA:
85 | y1, y2 = target
86 | return F.cross_entropy(input, y1) * state[torchbearer.MIXUP_LAMBDA] + F.cross_entropy(input, y2) * (1-state[torchbearer.MIXUP_LAMBDA])
87 | else:
88 | return F.cross_entropy(input, target)
89 |
90 | def on_sample(self, state):
91 | if self.lam is RMixup.RANDOM:
92 | if self.alpha > 0:
93 | lam = self.distrib.sample()
94 | else:
95 | lam = 1.0
96 | else:
97 | lam = self.lam
98 |
99 | state[torchbearer.MIXUP_LAMBDA] = lam
100 |
101 | state[torchbearer.MIXUP_PERMUTATION] = torch.randperm(state[torchbearer.X].size(0))
102 | state[torchbearer.X] = state[torchbearer.X] * state[torchbearer.MIXUP_LAMBDA] + \
103 | state[torchbearer.X][state[torchbearer.MIXUP_PERMUTATION],:] \
104 | * (1 - state[torchbearer.MIXUP_LAMBDA])
105 |
106 | if self.reformulate:
107 | state[torchbearer.MIXUP_LAMBDA] = 1
108 |
109 |
110 | from torchbearer.metrics import default as d
111 | d.__loss_map__[RMixup.mixup_loss.__name__] = MixupAcc
112 |
--------------------------------------------------------------------------------
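A small self-contained illustration (not part of the repository) of what the reformulate flag does: lambdas are drawn from Beta(alpha + 1, alpha) and the stored MIXUP_LAMBDA is forced to 1, so the loss only uses the un-permuted targets even though the inputs are still mixed.

import torch
import torchbearer
from utils.reformulated_mixup import RMixup

x = torch.randn(8, 3, 32, 32)
state = {torchbearer.X: x.clone()}

RMixup(alpha=1.0, reformulate=True).on_sample(state)

print(state[torchbearer.MIXUP_LAMBDA])          # always 1 in the reformulated setting
print(state[torchbearer.MIXUP_PERMUTATION])     # permutation used to mix the batch
print(torch.allclose(state[torchbearer.X], x))  # usually False: inputs were mixed in place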