├── .gitignore
├── README.md
├── datasets.py
├── models.py
├── pretrain.py
├── splits
│   ├── caltech101.pth
│   ├── cub200.pth
│   ├── dog.pth
│   ├── imagenet100.txt
│   └── sun397.pth
├── trainers.py
├── transfer_few_shot.py
├── transfer_linear_eval.py
├── transforms.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | logs/
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Improving Transferability of Representations via Augmentation-Aware Self-Supervision
2 |
3 | Accepted to NeurIPS 2021
4 |
5 |
6 |
7 |
8 |
9 | **TL;DR:** Learning augmentation-aware information by predicting the difference between two augmented samples improves the transferability of representations.
10 |
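As a minimal sketch of the idea (illustrative names only; the actual implementation is `SSObjective` in `trainers.py`), an auxiliary head takes the backbone features of both augmented views and predicts the difference of their augmentation parameters:

```python
# Minimal AugSelf sketch -- illustrative, not this repo's API.
import torch
import torch.nn as nn
import torch.nn.functional as F

backbone = nn.Flatten()            # stand-in for a ResNet feature extractor
aug_head = nn.Linear(2 * 512, 4)   # e.g., a 4-dim crop-parameter difference

def augself_loss(x1, x2, crop_diff):
    # concatenate the features of both views, then regress the augmentation difference
    y = torch.cat([backbone(x1), backbone(x2)], dim=1)
    return F.mse_loss(torch.tanh(aug_head(y)), crop_diff)

x1, x2 = torch.randn(8, 512), torch.randn(8, 512)
loss = augself_loss(x1, x2, crop_diff=torch.rand(8, 4) * 2 - 1)
```
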
11 | ## Dependencies
12 |
13 | ```bash
14 | conda create -n AugSelf python=3.8 pytorch=1.7.1 torchvision=0.8.2 cudatoolkit=10.1 ignite -c pytorch
15 | conda activate AugSelf
16 | pip install scipy tensorboard kornia==0.4.1 scikit-learn
17 | ```
18 |
19 | ## Checkpoints
20 |
21 | We provide ImageNet100-pretrained models in [this Dropbox link](https://www.dropbox.com/sh/0hjts19ysxebmaa/AABB6bF3QQWdIOCh9vocwTGGa?dl=0).
22 |
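Each checkpoint stores the encoder weights under the `'backbone'` key (the transfer scripts below load them the same way), so a pretrained encoder can be restored with, e.g.:

```python
import torch
from torchvision import models

backbone = models.resnet50()
backbone.fc = torch.nn.Identity()  # the training code strips the classification head
ckpt = torch.load('ckpt-500.pth', map_location='cpu')  # example filename; checkpoints follow ckpt-{epoch}.pth
backbone.load_state_dict(ckpt['backbone'])
```
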
23 | ## Pretraining
24 |
25 | Here we provide SimSiam+AugSelf pretraining scripts. To train the baseline (i.e., without AugSelf), remove the `--ss-crop` and `--ss-color` options; to use another framework such as SimCLR, set the `--framework` option.
26 |
27 | ### STL-10
28 | ```bash
29 | CUDA_VISIBLE_DEVICES=0 python pretrain.py \
30 | --logdir ./logs/stl10/simsiam/aug_self \
31 | --framework simsiam \
32 | --dataset stl10 \
33 | --datadir DATADIR \
34 | --model resnet18 \
35 | --batch-size 256 \
36 | --max-epochs 200 \
37 | --ss-color 1.0 --ss-crop 1.0
38 | ```
39 |
40 | ### ImageNet100
41 |
42 | ```bash
43 | python pretrain.py \
44 | --logdir ./logs/imagenet100/simsiam/aug_self \
45 | --framework simsiam \
46 | --dataset imagenet100 \
47 | --datadir DATADIR \
48 | --batch-size 256 \
49 | --max-epochs 500 \
50 | --model resnet50 \
51 | --base-lr 0.05 --wd 1e-4 \
52 | --ckpt-freq 50 --eval-freq 50 \
53 | --ss-crop 0.5 --ss-color 0.5 \
54 | --num-workers 16 --distributed
55 | ```
56 |
57 | ## Evaluation
58 |
59 | Our main evaluation setups are linear evaluation on fine-grained classification datasets (Table 1) and few-shot benchmarks (Table 2).
60 |
61 | ### Linear evaluation
62 |
63 | ```bash
64 | CUDA_VISIBLE_DEVICES=0 python transfer_linear_eval.py \
65 | --pretrain-data imagenet100 \
66 | --ckpt CKPT \
67 | --model resnet50 \
68 | --dataset cifar10 \
69 | --datadir DATADIR \
70 | --metric top1
71 | ```
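
The `--metric` flag supports `top1` (overall accuracy) and `class-avg` (mean per-class accuracy); see `compute_accuracy` in `transfer_linear_eval.py`.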
72 |
73 | ### Few-shot
74 |
75 | ```bash
76 | CUDA_VISIBLE_DEVICES=0 python transfer_few_shot.py \
77 | --pretrain-data imagenet100 \
78 | --ckpt CKPT \
79 | --model resnet50 \
80 | --dataset cub200 \
81 | --datadir DATADIR
82 | ```
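
By default this samples 2000 5-way 1-shot tasks (`--N 5 --K 1`) with 16 queries per class and reports the mean accuracy with a 95% confidence interval.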
83 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import json
4 | from scipy.io import loadmat
5 | from PIL import Image
6 | import xml.etree.ElementTree as ET
7 | from collections import defaultdict
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.utils.data import random_split, ConcatDataset, Subset
12 |
13 | from transforms import MultiView, RandomResizedCrop, ColorJitter, GaussianBlur, RandomRotation
14 | from torchvision import transforms as T
15 | from torchvision.datasets import STL10, CIFAR10, CIFAR100, ImageFolder, ImageNet, Caltech101, Caltech256
16 |
17 | import kornia.augmentation as K
18 |
19 | class ImageList(torch.utils.data.Dataset):
20 | def __init__(self, samples, transform=None):
21 | self.samples = samples
22 | self.transform = transform
23 |
24 | def __getitem__(self, idx):
25 | path, label = self.samples[idx]
26 | with open(path, 'rb') as f:
27 | img = Image.open(f)
28 | img = img.convert('RGB')
29 | if self.transform is not None:
30 | img = self.transform(img)
31 | return img, label
32 |
33 | def __len__(self):
34 | return len(self.samples)
35 |
36 | class ImageNet100(ImageFolder):
37 | def __init__(self, root, split, transform):
38 | with open('splits/imagenet100.txt') as f:
39 | classes = [line.strip() for line in f]
40 | class_to_idx = { cls: idx for idx, cls in enumerate(classes) }
41 |
42 | super().__init__(os.path.join(root, split), transform=transform)
43 | samples = []
44 | for path, label in self.samples:
45 | cls = self.classes[label]
46 | if cls not in class_to_idx:
47 | continue
48 | label = class_to_idx[cls]
49 | samples.append((path, label))
50 |
51 | self.samples = samples
52 | self.classes = classes
53 | self.class_to_idx = class_to_idx
54 | self.targets = [s[1] for s in samples]
55 |
56 | class Pets(ImageList):
57 | def __init__(self, root, split, transform=None):
58 | with open(os.path.join(root, 'annotations', f'{split}.txt')) as f:
59 | annotations = [line.split() for line in f]
60 |
61 | samples = []
62 | for sample in annotations:
63 | path = os.path.join(root, 'images', sample[0] + '.jpg')
64 | label = int(sample[1])-1
65 | samples.append((path, label))
66 |
67 | super().__init__(samples, transform)
68 |
69 | class Food101(ImageList):
70 | def __init__(self, root, split, transform=None):
71 | with open(os.path.join(root, 'meta', 'classes.txt')) as f:
72 | classes = [line.strip() for line in f]
73 | with open(os.path.join(root, 'meta', f'{split}.json')) as f:
74 | annotations = json.load(f)
75 |
76 | samples = []
77 | for i, cls in enumerate(classes):
78 | for path in annotations[cls]:
79 | samples.append((os.path.join(root, 'images', f'{path}.jpg'), i))
80 |
81 | super().__init__(samples, transform)
82 |
83 | class DTD(ImageList):
84 | def __init__(self, root, split, transform=None):
85 | with open(os.path.join(root, 'labels', f'{split}1.txt')) as f:
86 | paths = [line.strip() for line in f]
87 |
88 | classes = sorted(os.listdir(os.path.join(root, 'images')))
89 | samples = [(os.path.join(root, 'images', path), classes.index(path.split('/')[0])) for path in paths]
90 | super().__init__(samples, transform)
91 |
92 | class SUN397(ImageList):
93 | def __init__(self, root, split, transform=None):
94 | with open(os.path.join(root, 'ClassName.txt')) as f:
95 | classes = [line.strip() for line in f]
96 |
97 | with open(os.path.join(root, f'{split}_01.txt')) as f:
98 | samples = []
99 | for line in f:
100 | path = line.strip()
101 | for y, cls in enumerate(classes):
102 | if path.startswith(cls+'/'):
103 | samples.append((os.path.join(root, 'SUN397', path[1:]), y))
104 | break
105 | super().__init__(samples, transform)
106 |
107 | def load_pretrain_datasets(dataset='cifar10',
108 | datadir='/data',
109 | color_aug='default'):
110 |
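    # Returns the train/val/test datasets together with two Kornia GPU augmentation
    # pipelines (t1, t2), one per view; cropping itself happens on the CPU side via MultiView.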
111 | if dataset == 'imagenet100':
112 | mean = torch.tensor([0.485, 0.456, 0.406])
113 | std = torch.tensor([0.229, 0.224, 0.225])
114 | train_transform = MultiView(RandomResizedCrop(224, scale=(0.2, 1.0)))
115 | test_transform = T.Compose([T.Resize(224),
116 | T.CenterCrop(224),
117 | T.ToTensor(),
118 | T.Normalize(mean, std)])
119 | t1 = nn.Sequential(K.RandomHorizontalFlip(),
120 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
121 | K.RandomGrayscale(p=0.2),
122 | GaussianBlur(23, (0.1, 2.0)),
123 | K.Normalize(mean, std))
124 | t2 = nn.Sequential(K.RandomHorizontalFlip(),
125 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
126 | K.RandomGrayscale(p=0.2),
127 | GaussianBlur(23, (0.1, 2.0)),
128 | K.Normalize(mean, std))
129 |
130 | trainset = ImageNet100(datadir, split='train', transform=train_transform)
131 | valset = ImageNet100(datadir, split='train', transform=test_transform)
132 | testset = ImageNet100(datadir, split='val', transform=test_transform)
133 |
134 | elif dataset == 'stl10':
135 | mean = torch.tensor([0.43, 0.42, 0.39])
136 | std = torch.tensor([0.27, 0.26, 0.27])
137 | train_transform = MultiView(RandomResizedCrop(96, scale=(0.2, 1.0)))
138 |
139 | if color_aug == 'default':
140 | s = 1
141 | elif color_aug == 'strong':
142 | s = 2.
143 | elif color_aug == 'weak':
144 | s = 0.5
        else:
            raise Exception(f'Unknown color_aug {color_aug}')
145 | test_transform = T.Compose([T.Resize(96),
146 | T.CenterCrop(96),
147 | T.ToTensor(),
148 | T.Normalize(mean, std)])
149 | t1 = nn.Sequential(K.RandomHorizontalFlip(),
150 | ColorJitter(0.4*s, 0.4*s, 0.4*s, 0.1*s, p=0.8),
151 | K.RandomGrayscale(p=0.2*s),
152 | GaussianBlur(9, (0.1, 2.0)),
153 | K.Normalize(mean, std))
154 | t2 = nn.Sequential(K.RandomHorizontalFlip(),
155 | ColorJitter(0.4*s, 0.4*s, 0.4*s, 0.1*s, p=0.8),
156 | K.RandomGrayscale(p=0.2*s),
157 | GaussianBlur(9, (0.1, 2.0)),
158 | K.Normalize(mean, std))
159 |
160 | trainset = STL10(datadir, split='train+unlabeled', transform=train_transform)
161 | valset = STL10(datadir, split='train', transform=test_transform)
162 | testset = STL10(datadir, split='test', transform=test_transform)
163 |
164 | elif dataset == 'stl10_rot':
165 | mean = torch.tensor([0.43, 0.42, 0.39])
166 | std = torch.tensor([0.27, 0.26, 0.27])
167 | train_transform = MultiView(RandomResizedCrop(96, scale=(0.2, 1.0)))
168 | test_transform = T.Compose([T.Resize(96),
169 | T.CenterCrop(96),
170 | T.ToTensor(),
171 | T.Normalize(mean, std)])
172 | t1 = nn.Sequential(K.RandomHorizontalFlip(),
173 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
174 | K.RandomGrayscale(p=0.2),
175 | GaussianBlur(9, (0.1, 2.0)),
176 | RandomRotation(p=0.5),
177 | K.Normalize(mean, std))
178 | t2 = nn.Sequential(K.RandomHorizontalFlip(),
179 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
180 | K.RandomGrayscale(p=0.2),
181 | GaussianBlur(9, (0.1, 2.0)),
182 | RandomRotation(p=0.5),
183 | K.Normalize(mean, std))
184 |
185 | trainset = STL10(datadir, split='train+unlabeled', transform=train_transform)
186 | valset = STL10(datadir, split='train', transform=test_transform)
187 | testset = STL10(datadir, split='test', transform=test_transform)
188 |
189 | elif dataset == 'stl10_sol':
190 | mean = torch.tensor([0.43, 0.42, 0.39])
191 | std = torch.tensor([0.27, 0.26, 0.27])
192 | train_transform = MultiView(RandomResizedCrop(96, scale=(0.2, 1.0)))
193 |
194 | test_transform = T.Compose([T.Resize(96),
195 | T.CenterCrop(96),
196 | T.ToTensor(),
197 | T.Normalize(mean, std)])
198 | t1 = nn.Sequential(K.RandomHorizontalFlip(),
199 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
200 | K.RandomSolarize(0.5, 0.0, p=0.5),
201 | K.RandomGrayscale(p=0.2),
202 | GaussianBlur(9, (0.1, 2.0)),
203 | K.Normalize(mean, std))
204 | t2 = nn.Sequential(K.RandomHorizontalFlip(),
205 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
206 | K.RandomSolarize(0.5, 0.0, p=0.5),
207 | K.RandomGrayscale(p=0.2),
208 | GaussianBlur(9, (0.1, 2.0)),
209 | K.Normalize(mean, std))
210 |
211 | trainset = STL10(datadir, split='train+unlabeled', transform=train_transform)
212 | valset = STL10(datadir, split='train', transform=test_transform)
213 | testset = STL10(datadir, split='test', transform=test_transform)
214 |
215 | else:
216 | raise Exception(f'Unknown dataset {dataset}')
217 |
218 | return dict(train=trainset,
219 | val=valset,
220 | test=testset,
221 | t1=t1, t2=t2)
222 |
223 | def load_datasets(dataset='cifar10',
224 | datadir='/data',
225 | pretrain_data='stl10'):
226 |
227 | if pretrain_data == 'imagenet100':
228 | mean = torch.tensor([0.485, 0.456, 0.406])
229 | std = torch.tensor([0.229, 0.224, 0.225])
230 | transform = T.Compose([T.Resize(224, interpolation=Image.BICUBIC),
231 | T.CenterCrop(224),
232 | T.ToTensor(),
233 | T.Normalize(mean, std)])
234 |
235 | elif pretrain_data == 'stl10':
236 | mean = torch.tensor([0.43, 0.42, 0.39])
237 | std = torch.tensor([0.27, 0.26, 0.27])
238 | transform = T.Compose([T.Resize(96, interpolation=Image.BICUBIC),
239 | T.CenterCrop(96),
240 | T.ToTensor(),
241 | T.Normalize(mean, std)])
242 |
243 | generator = lambda seed: torch.Generator().manual_seed(seed)
244 | if dataset == 'imagenet100':
245 | trainval = ImageNet100(datadir, split='train', transform=transform)
246 | train, val = None, None
247 | test = ImageNet100(datadir, split='val', transform=transform)
248 | num_classes = 100
249 |
250 | elif dataset == 'food101':
251 | trainval = Food101(root=datadir, split='train', transform=transform)
252 | train, val = random_split(trainval, [68175, 7575], generator=generator(42))
253 | test = Food101(root=datadir, split='test', transform=transform)
254 | num_classes = 101
255 |
256 | elif dataset == 'cifar10':
257 | trainval = CIFAR10(root=datadir, train=True, transform=transform)
258 | train, val = random_split(trainval, [45000, 5000], generator=generator(43))
259 | test = CIFAR10(root=datadir, train=False, transform=transform)
260 | num_classes = 10
261 |
262 | elif dataset == 'cifar100':
263 | trainval = CIFAR100(root=datadir, train=True, transform=transform)
264 | train, val = random_split(trainval, [45000, 5000], generator=generator(44))
265 | test = CIFAR100(root=datadir, train=False, transform=transform)
266 | num_classes = 100
267 |
268 | elif dataset == 'sun397':
269 | trn_indices, val_indices = torch.load('splits/sun397.pth')
270 | trainval = SUN397(root=datadir, split='Training', transform=transform)
271 | train = Subset(trainval, trn_indices)
272 | val = Subset(trainval, val_indices)
273 | test = SUN397(root=datadir, split='Testing', transform=transform)
274 | num_classes = 397
275 |
276 | elif dataset == 'dtd':
277 | train = DTD(root=datadir, split='train', transform=transform)
278 | val = DTD(root=datadir, split='val', transform=transform)
279 | trainval = ConcatDataset([train, val])
280 | test = DTD(root=datadir, split='test', transform=transform)
281 | num_classes = 47
282 |
283 | elif dataset == 'pets':
284 | trainval = Pets(root=datadir, split='trainval', transform=transform)
285 | train, val = random_split(trainval, [2940, 740], generator=generator(49))
286 | test = Pets(root=datadir, split='test', transform=transform)
287 | num_classes = 37
288 |
289 | elif dataset == 'caltech101':
290 | transform.transforms.insert(0, T.Lambda(lambda img: img.convert('RGB')))
291 | D = Caltech101(datadir, transform=transform)
292 | trn_indices, val_indices, tst_indices = torch.load('splits/caltech101.pth')
293 | train = Subset(D, trn_indices)
294 | val = Subset(D, val_indices)
295 | trainval = ConcatDataset([train, val])
296 | test = Subset(D, tst_indices)
297 | num_classes = 101
298 |
299 | elif dataset == 'flowers':
300 | train = ImageFolder(os.path.join(datadir, 'trn'), transform=transform)
301 | val = ImageFolder(os.path.join(datadir, 'val'), transform=transform)
302 | trainval = ConcatDataset([train, val])
303 | test = ImageFolder(os.path.join(datadir, 'tst'), transform=transform)
304 | num_classes = 102
305 |
306 | elif dataset in ['flowers-5shot', 'flowers-10shot']:
307 | if dataset == 'flowers-5shot':
308 | n = 5
309 | else:
310 | n = 10
311 | train = ImageFolder(os.path.join(datadir, 'trn'), transform=transform)
312 | val = ImageFolder(os.path.join(datadir, 'val'), transform=transform)
313 | trainval = ImageFolder(os.path.join(datadir, 'trn'), transform=transform)
314 | trainval.samples += val.samples
315 | trainval.targets += val.targets
316 | indices = defaultdict(list)
317 | for i, y in enumerate(trainval.targets):
318 | indices[y].append(i)
319 | indices = sum([random.sample(indices[y], n) for y in indices.keys()], [])
320 | trainval = Subset(trainval, indices)
321 | test = ImageFolder(os.path.join(datadir, 'tst'), transform=transform)
322 | num_classes = 102
323 |
324 | elif dataset == 'stl10':
325 | trainval = STL10(root=datadir, split='train', transform=transform)
326 | test = STL10(root=datadir, split='test', transform=transform)
327 | train, val = random_split(trainval, [4500, 500], generator=generator(50))
328 | num_classes = 10
329 |
330 | elif dataset == 'mit67':
331 | trainval = ImageFolder(os.path.join(datadir, 'train'), transform=transform)
332 | test = ImageFolder(os.path.join(datadir, 'test'), transform=transform)
333 | train, val = random_split(trainval, [4690, 670], generator=generator(51))
334 | num_classes = 67
335 |
336 | elif dataset == 'cub200':
337 | trn_indices, val_indices = torch.load('splits/cub200.pth')
338 | trainval = ImageFolder(os.path.join(datadir, 'train'), transform=transform)
339 | train = Subset(trainval, trn_indices)
340 | val = Subset(trainval, val_indices)
341 | test = ImageFolder(os.path.join(datadir, 'test'), transform=transform)
342 | num_classes = 200
343 |
344 | elif dataset == 'dog':
345 | trn_indices, val_indices = torch.load('splits/dog.pth')
346 | trainval = ImageFolder(os.path.join(datadir, 'train'), transform=transform)
347 | train = Subset(trainval, trn_indices)
348 | val = Subset(trainval, val_indices)
349 | test = ImageFolder(os.path.join(datadir, 'test'), transform=transform)
350 | num_classes = 120
    else:
        raise Exception(f'Unknown dataset {dataset}')
351 |
352 | return dict(trainval=trainval,
353 | train=train,
354 | val=val,
355 | test=test,
356 | num_classes=num_classes)
357 |
358 |
359 | def load_fewshot_datasets(dataset='cifar10',
360 | datadir='/data',
361 | pretrain_data='stl10'):
362 |
363 | if pretrain_data == 'imagenet100':
364 | mean = torch.tensor([0.485, 0.456, 0.406])
365 | std = torch.tensor([0.229, 0.224, 0.225])
366 | transform = T.Compose([T.Resize(224, interpolation=Image.BICUBIC),
367 | T.CenterCrop(224),
368 | T.ToTensor(),
369 | T.Normalize(mean, std)])
370 |
371 | elif pretrain_data == 'stl10':
372 | mean = torch.tensor([0.43, 0.42, 0.39])
373 | std = torch.tensor([0.27, 0.26, 0.27])
374 | transform = T.Compose([T.Resize(96, interpolation=Image.BICUBIC),
375 | T.CenterCrop(96),
376 | T.ToTensor(),
377 | T.Normalize(mean, std)])
378 |
379 | if dataset == 'cub200':
380 | train = ImageFolder(os.path.join(datadir, 'train'), transform=transform)
381 | test = ImageFolder(os.path.join(datadir, 'test'), transform=transform)
382 | test.samples = train.samples + test.samples
383 | test.targets = train.targets + test.targets
384 |
385 | elif dataset == 'fc100':
386 | train = ImageFolder(os.path.join(datadir, 'train'), transform=transform)
387 | test = ImageFolder(os.path.join(datadir, 'test'), transform=transform)
388 |
389 | elif dataset == 'plant_disease':
390 | train = ImageFolder(os.path.join(datadir, 'train'), transform=transform)
391 | test = ImageFolder(os.path.join(datadir, 'test'), transform=transform)
392 | test.samples = train.samples + test.samples
393 | test.targets = train.targets + test.targets
394 |
395 | return dict(test=test)
396 |
397 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 |
8 | from torchvision import models
9 |
10 | def reset_parameters(model):
11 | for m in model.modules():
12 | if isinstance(m, nn.Conv2d):
13 | m.reset_parameters()
14 |
15 | if isinstance(m, nn.Linear):
16 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
17 | bound = 1 / math.sqrt(fan_in)
18 | nn.init.uniform_(m.weight, -bound, bound)
19 | if m.bias is not None:
20 | nn.init.uniform_(m.bias, -bound, bound)
21 |
22 | def load_backbone(args):
23 | name = args.model
24 | backbone = models.__dict__[name.split('_')[-1]](zero_init_residual=True)
25 | if name.startswith('cifar_'):
26 | backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
27 | backbone.maxpool = nn.Identity()
28 | args.num_backbone_features = backbone.fc.weight.shape[1]
29 | backbone.fc = nn.Identity()
30 | reset_parameters(backbone)
31 | return backbone
32 |
33 |
34 | def load_mlp(n_in, n_hidden, n_out, num_layers=3, last_bn=True):
35 | layers = []
36 | for i in range(num_layers-1):
37 | layers.append(nn.Linear(n_in, n_hidden, bias=False))
38 | layers.append(nn.BatchNorm1d(n_hidden))
39 | layers.append(nn.ReLU())
40 | n_in = n_hidden
41 | layers.append(nn.Linear(n_hidden, n_out, bias=not last_bn))
42 | if last_bn:
43 | layers.append(nn.BatchNorm1d(n_out))
44 | mlp = nn.Sequential(*layers)
45 | reset_parameters(mlp)
46 | return mlp
47 |
48 |
49 | def load_ss_predictor(n_in, ss_objective, n_hidden=512):
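    # One MLP head per enabled objective; each consumes the concatenated features
    # of both views (hence n_in*2) and predicts the augmentation-parameter difference.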
50 | ss_predictor = {}
51 | for name, weight, n_out, _ in ss_objective.params:
52 | if weight > 0:
53 | ss_predictor[name] = load_mlp(n_in*2, n_hidden, n_out, num_layers=3, last_bn=False)
54 |
55 | return ss_predictor
56 |
57 |
--------------------------------------------------------------------------------
/pretrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | from functools import partial
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | import torch.backends.cudnn as cudnn
10 |
11 | import ignite
12 | from ignite.engine import Events
13 | import ignite.distributed as idist
14 |
15 | from datasets import load_pretrain_datasets
16 | from models import load_backbone, load_mlp, load_ss_predictor
17 | import trainers
18 | from trainers import SSObjective
19 | from utils import Logger
20 |
21 | def simsiam(args, t1, t2):
22 | out_dim = 2048
23 | device = idist.device()
24 |
25 | ss_objective = SSObjective(
26 | crop = args.ss_crop,
27 | color = args.ss_color,
28 | flip = args.ss_flip,
29 | blur = args.ss_blur,
30 | rot = args.ss_rot,
31 | sol = args.ss_sol,
32 | only = args.ss_only,
33 | )
34 |
35 | build_model = partial(idist.auto_model, sync_bn=True)
36 | backbone = build_model(load_backbone(args))
37 | projector = build_model(load_mlp(args.num_backbone_features,
38 | out_dim,
39 | out_dim,
40 | num_layers=2+int(args.dataset.startswith('imagenet')),
41 | last_bn=True))
42 | predictor = build_model(load_mlp(out_dim,
43 | out_dim // 4,
44 | out_dim,
45 | num_layers=2,
46 | last_bn=False))
47 | ss_predictor = load_ss_predictor(args.num_backbone_features, ss_objective)
48 | ss_predictor = { k: build_model(v) for k, v in ss_predictor.items() }
49 | ss_params = sum([list(v.parameters()) for v in ss_predictor.values()], [])
50 |
51 | SGD = partial(optim.SGD, lr=args.lr, weight_decay=args.wd, momentum=args.momentum)
52 | build_optim = lambda x: idist.auto_optim(SGD(x))
53 | optimizers = [build_optim(list(backbone.parameters())+list(projector.parameters())+ss_params),
54 | build_optim(list(predictor.parameters()))]
55 | schedulers = [optim.lr_scheduler.CosineAnnealingLR(optimizers[0], args.max_epochs)]
56 |
57 | trainer = trainers.simsiam(backbone=backbone,
58 | projector=projector,
59 | predictor=predictor,
60 | ss_predictor=ss_predictor,
61 | t1=t1, t2=t2,
62 | optimizers=optimizers,
63 | device=device,
64 | ss_objective=ss_objective)
65 |
66 | return dict(backbone=backbone,
67 | projector=projector,
68 | predictor=predictor,
69 | ss_predictor=ss_predictor,
70 | optimizers=optimizers,
71 | schedulers=schedulers,
72 | trainer=trainer)
73 |
74 |
75 | def moco(args, t1, t2):
76 | out_dim = 128
77 | device = idist.device()
78 |
79 | ss_objective = SSObjective(
80 | crop = args.ss_crop,
81 | color = args.ss_color,
82 | flip = args.ss_flip,
83 | blur = args.ss_blur,
84 | only = args.ss_only,
85 | )
86 |
87 | build_model = partial(idist.auto_model, sync_bn=True)
88 | backbone = build_model(load_backbone(args))
89 | projector = build_model(load_mlp(args.num_backbone_features,
90 | args.num_backbone_features,
91 | out_dim,
92 | num_layers=2,
93 | last_bn=False))
94 | ss_predictor = load_ss_predictor(args.num_backbone_features, ss_objective)
95 | ss_predictor = { k: build_model(v) for k, v in ss_predictor.items() }
96 | ss_params = sum([list(v.parameters()) for v in ss_predictor.values()], [])
97 |
98 | SGD = partial(optim.SGD, lr=args.lr, weight_decay=args.wd, momentum=args.momentum)
99 | build_optim = lambda x: idist.auto_optim(SGD(x))
100 | optimizers = [build_optim(list(backbone.parameters())+list(projector.parameters())+ss_params)]
101 | schedulers = [optim.lr_scheduler.CosineAnnealingLR(optimizers[0], args.max_epochs)]
102 |
103 | trainer = trainers.moco(
104 | backbone=backbone,
105 | projector=projector,
106 | ss_predictor=ss_predictor,
107 | t1=t1, t2=t2,
108 | optimizers=optimizers,
109 | device=device,
110 | ss_objective=ss_objective)
111 |
112 | return dict(backbone=backbone,
113 | projector=projector,
114 | ss_predictor=ss_predictor,
115 | optimizers=optimizers,
116 | schedulers=schedulers,
117 | trainer=trainer)
118 |
119 | def simclr(args, t1, t2):
120 | out_dim = 128
121 | device = idist.device()
122 |
123 | ss_objective = SSObjective(
124 | crop = args.ss_crop,
125 | color = args.ss_color,
126 | flip = args.ss_flip,
127 | blur = args.ss_blur,
128 | only = args.ss_only,
129 | )
130 |
131 | build_model = partial(idist.auto_model, sync_bn=True)
132 | backbone = build_model(load_backbone(args))
133 | projector = build_model(load_mlp(args.num_backbone_features,
134 | args.num_backbone_features,
135 | out_dim,
136 | num_layers=2,
137 | last_bn=False))
138 | ss_predictor = load_ss_predictor(args.num_backbone_features, ss_objective)
139 | ss_predictor = { k: build_model(v) for k, v in ss_predictor.items() }
140 | ss_params = sum([list(v.parameters()) for v in ss_predictor.values()], [])
141 |
142 | SGD = partial(optim.SGD, lr=args.lr, weight_decay=args.wd, momentum=args.momentum)
143 | build_optim = lambda x: idist.auto_optim(SGD(x))
144 | optimizers = [build_optim(list(backbone.parameters())+list(projector.parameters())+ss_params)]
145 | schedulers = [optim.lr_scheduler.CosineAnnealingLR(optimizers[0], args.max_epochs)]
146 |
147 | trainer = trainers.simclr(backbone=backbone,
148 | projector=projector,
149 | ss_predictor=ss_predictor,
150 | t1=t1, t2=t2,
151 | optimizers=optimizers,
152 | device=device,
153 | ss_objective=ss_objective)
154 |
155 | return dict(backbone=backbone,
156 | projector=projector,
157 | ss_predictor=ss_predictor,
158 | optimizers=optimizers,
159 | schedulers=schedulers,
160 | trainer=trainer)
161 |
162 |
163 | def byol(args, t1, t2):
164 | out_dim = 256
165 | h_dim = 4096
166 | device = idist.device()
167 |
168 | ss_objective = SSObjective(
169 | crop = args.ss_crop,
170 | color = args.ss_color,
171 | flip = args.ss_flip,
172 | blur = args.ss_blur,
173 | rot = args.ss_rot,
174 | only = args.ss_only,
175 | )
176 |
177 | build_model = partial(idist.auto_model, sync_bn=True)
178 | backbone = build_model(load_backbone(args))
179 | projector = build_model(load_mlp(args.num_backbone_features,
180 | h_dim,
181 | out_dim,
182 | num_layers=2,
183 | last_bn=False))
184 | predictor = build_model(load_mlp(out_dim,
185 | h_dim,
186 | out_dim,
187 | num_layers=2,
188 | last_bn=False))
189 | ss_predictor = load_ss_predictor(args.num_backbone_features, ss_objective)
190 | ss_predictor = { k: build_model(v) for k, v in ss_predictor.items() }
191 | ss_params = sum([list(v.parameters()) for v in ss_predictor.values()], [])
192 |
193 | SGD = partial(optim.SGD, lr=args.lr, weight_decay=args.wd, momentum=args.momentum)
194 | build_optim = lambda x: idist.auto_optim(SGD(x))
195 | optimizers = [build_optim(list(backbone.parameters())+list(projector.parameters())+ss_params+list(predictor.parameters()))]
196 | schedulers = [optim.lr_scheduler.CosineAnnealingLR(optimizers[0], args.max_epochs)]
197 |
198 | trainer = trainers.byol(backbone=backbone,
199 | projector=projector,
200 | predictor=predictor,
201 | ss_predictor=ss_predictor,
202 | t1=t1, t2=t2,
203 | optimizers=optimizers,
204 | device=device,
205 | ss_objective=ss_objective)
206 |
207 | return dict(backbone=backbone,
208 | projector=projector,
209 | predictor=predictor,
210 | ss_predictor=ss_predictor,
211 | optimizers=optimizers,
212 | schedulers=schedulers,
213 | trainer=trainer)
214 |
215 |
216 | def swav(args, t1, t2):
217 | out_dim = 128
218 | h_dim = 2048
219 | device = idist.device()
220 |
221 | ss_objective = SSObjective(
222 | crop = args.ss_crop,
223 | color = args.ss_color,
224 | flip = args.ss_flip,
225 | blur = args.ss_blur,
226 | rot = args.ss_rot,
227 | only = args.ss_only,
228 | )
229 |
230 | build_model = partial(idist.auto_model, sync_bn=True)
231 | backbone = build_model(load_backbone(args))
232 | projector = build_model(load_mlp(args.num_backbone_features,
233 | h_dim,
234 | out_dim,
235 | num_layers=2,
236 | last_bn=False))
237 | prototypes = build_model(nn.Linear(out_dim, 100, bias=False))
238 | ss_predictor = load_ss_predictor(args.num_backbone_features, ss_objective)
239 | ss_predictor = { k: build_model(v) for k, v in ss_predictor.items() }
240 | ss_params = sum([list(v.parameters()) for v in ss_predictor.values()], [])
241 |
242 | SGD = partial(optim.SGD, lr=args.lr, weight_decay=args.wd, momentum=args.momentum)
243 | build_optim = lambda x: idist.auto_optim(SGD(x))
244 | optimizers = [build_optim(list(backbone.parameters())+list(projector.parameters())+ss_params+list(prototypes.parameters()))]
245 | schedulers = [optim.lr_scheduler.CosineAnnealingLR(optimizers[0], args.max_epochs)]
246 |
247 | trainer = trainers.swav(backbone=backbone,
248 | projector=projector,
249 | prototypes=prototypes,
250 | ss_predictor=ss_predictor,
251 | t1=t1, t2=t2,
252 | optimizers=optimizers,
253 | device=device,
254 | ss_objective=ss_objective)
255 |
256 | return dict(backbone=backbone,
257 | projector=projector,
258 | prototypes=prototypes,
259 | ss_predictor=ss_predictor,
260 | optimizers=optimizers,
261 | schedulers=schedulers,
262 | trainer=trainer)
263 |
264 |
265 | def main(local_rank, args):
266 | cudnn.benchmark = True
267 | device = idist.device()
268 | logger = Logger(args.logdir, args.resume)
269 |
270 | # DATASETS
271 | datasets = load_pretrain_datasets(dataset=args.dataset,
272 | datadir=args.datadir,
273 | color_aug=args.color_aug)
274 | build_dataloader = partial(idist.auto_dataloader,
275 | batch_size=args.batch_size,
276 | num_workers=args.num_workers,
277 | shuffle=True,
278 | pin_memory=True)
279 | trainloader = build_dataloader(datasets['train'], drop_last=True)
280 | valloader = build_dataloader(datasets['val'] , drop_last=False)
281 | testloader = build_dataloader(datasets['test'], drop_last=False)
282 |
283 | t1, t2 = datasets['t1'], datasets['t2']
284 |
285 | # MODELS
286 | if args.framework == 'simsiam':
287 | models = simsiam(args, t1, t2)
288 | elif args.framework == 'moco':
289 | models = moco(args, t1, t2)
290 | elif args.framework == 'simclr':
291 | models = simclr(args, t1, t2)
292 | elif args.framework == 'byol':
293 | models = byol(args, t1, t2)
294 | elif args.framework == 'swav':
295 | models = swav(args, t1, t2)
    else:
        raise Exception(f'Unknown framework {args.framework}')
296 |
297 | trainer = models['trainer']
298 | evaluator = trainers.nn_evaluator(backbone=models['backbone'],
299 | trainloader=valloader,
300 | testloader=testloader,
301 | device=device)
302 |
303 | if args.distributed:
304 | @trainer.on(Events.EPOCH_STARTED)
305 | def set_epoch(engine):
306 | for loader in [trainloader, valloader, testloader]:
307 | loader.sampler.set_epoch(engine.state.epoch)
308 |
309 | @trainer.on(Events.ITERATION_STARTED)
310 | def log_lr(engine):
311 | lrs = {}
312 | for i, optimizer in enumerate(models['optimizers']):
313 | for j, pg in enumerate(optimizer.param_groups):
314 | lrs[f'lr/{i}-{j}'] = pg['lr']
315 | logger.log(engine, engine.state.iteration, print_msg=False, **lrs)
316 |
317 | @trainer.on(Events.ITERATION_COMPLETED)
318 | def log(engine):
319 | loss = engine.state.output.pop('loss')
320 | ss_loss = engine.state.output.pop('ss/total')
321 | logger.log(engine, engine.state.iteration,
322 | print_msg=engine.state.iteration % args.print_freq == 0,
323 | loss=loss, ss_loss=ss_loss)
324 |
325 | if 'z1' in engine.state.output:
326 | with torch.no_grad():
327 | z1 = engine.state.output.pop('z1')
328 | z2 = engine.state.output.pop('z2')
329 | z1 = F.normalize(z1, dim=-1)
330 | z2 = F.normalize(z2, dim=-1)
331 | dist = torch.einsum('ik, jk -> ij', z1, z2)
332 | diag_masks = torch.diag(torch.ones(z1.shape[0])).bool()
333 | engine.state.output['dist/intra'] = dist[diag_masks].mean().item()
334 | engine.state.output['dist/inter'] = dist[~diag_masks].mean().item()
335 |
336 | logger.log(engine, engine.state.iteration,
337 | print_msg=False,
338 | **engine.state.output)
339 |
340 | @trainer.on(Events.EPOCH_COMPLETED(every=args.eval_freq))
341 | def evaluate(engine):
342 | acc = evaluator()
343 | logger.log(engine, engine.state.epoch, acc=acc)
344 |
345 | @trainer.on(Events.EPOCH_COMPLETED)
346 | def update_lr(engine):
347 | for scheduler in models['schedulers']:
348 | scheduler.step()
349 |
350 | @trainer.on(Events.EPOCH_COMPLETED(every=args.ckpt_freq))
351 | def save_ckpt(engine):
352 | logger.save(engine, **models)
353 |
354 | if args.resume is not None:
355 | @trainer.on(Events.STARTED)
356 | def load_state(engine):
357 | ckpt = torch.load(os.path.join(args.logdir, f'ckpt-{args.resume}.pth'), map_location='cpu')
358 | for k, v in models.items():
359 | if isinstance(v, nn.parallel.DistributedDataParallel):
360 | v = v.module
361 |
362 | if hasattr(v, 'state_dict'):
363 | v.load_state_dict(ckpt[k])
364 |
365 | if type(v) is list and hasattr(v[0], 'state_dict'):
366 | for i, x in enumerate(v):
367 | x.load_state_dict(ckpt[k][i])
368 |
369 | if type(v) is dict and k == 'ss_predictor':
370 | for y, x in v.items():
371 | x.load_state_dict(ckpt[k][y])
372 |
373 | trainer.run(trainloader, max_epochs=args.max_epochs)
374 |
375 | if __name__ == '__main__':
376 | parser = ArgumentParser()
377 | parser.add_argument('--logdir', type=str, required=True)
378 | parser.add_argument('--resume', type=int, default=None)
379 | parser.add_argument('--dataset', type=str, default='stl10')
380 | parser.add_argument('--datadir', type=str, default='/data')
381 | parser.add_argument('--batch-size', type=int, default=256)
382 | parser.add_argument('--max-epochs', type=int, default=200)
383 | parser.add_argument('--num-workers', type=int, default=4)
384 | parser.add_argument('--model', type=str, default='resnet18')
385 | parser.add_argument('--distributed', action='store_true')
386 |
387 | parser.add_argument('--framework', type=str, default='simsiam')
388 |
389 | parser.add_argument('--base-lr', type=float, default=0.03)
390 | parser.add_argument('--wd', type=float, default=5e-4)
391 | parser.add_argument('--momentum', type=float, default=0.9)
392 |
393 | parser.add_argument('--print-freq', type=int, default=10)
394 | parser.add_argument('--ckpt-freq', type=int, default=10)
395 | parser.add_argument('--eval-freq', type=int, default=1)
396 |
397 | parser.add_argument('--color-aug', type=str, default='default')
398 |
399 | parser.add_argument('--ss-crop', type=float, default=-1)
400 | parser.add_argument('--ss-color', type=float, default=-1)
401 | parser.add_argument('--ss-flip', type=float, default=-1)
402 | parser.add_argument('--ss-blur', type=float, default=-1)
403 | parser.add_argument('--ss-rot', type=float, default=-1)
404 | parser.add_argument('--ss-sol', type=float, default=-1)
405 | parser.add_argument('--ss-only', action='store_true')
406 |
407 | args = parser.parse_args()
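    # linear LR scaling rule: lr = base_lr * batch_size / 256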
408 | args.lr = args.base_lr * args.batch_size / 256
409 | if not args.distributed:
410 | with idist.Parallel() as parallel:
411 | parallel.run(main, args)
412 | else:
413 | with idist.Parallel('nccl', nproc_per_node=torch.cuda.device_count()) as parallel:
414 | parallel.run(main, args)
415 |
416 |
--------------------------------------------------------------------------------
/splits/caltech101.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hankook/AugSelf/c131db66b5ade96af86774bc43a2cb797390bba5/splits/caltech101.pth
--------------------------------------------------------------------------------
/splits/cub200.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hankook/AugSelf/c131db66b5ade96af86774bc43a2cb797390bba5/splits/cub200.pth
--------------------------------------------------------------------------------
/splits/dog.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hankook/AugSelf/c131db66b5ade96af86774bc43a2cb797390bba5/splits/dog.pth
--------------------------------------------------------------------------------
/splits/imagenet100.txt:
--------------------------------------------------------------------------------
1 | n02869837
2 | n01749939
3 | n02488291
4 | n02107142
5 | n13037406
6 | n02091831
7 | n04517823
8 | n04589890
9 | n03062245
10 | n01773797
11 | n01735189
12 | n07831146
13 | n07753275
14 | n03085013
15 | n04485082
16 | n02105505
17 | n01983481
18 | n02788148
19 | n03530642
20 | n04435653
21 | n02086910
22 | n02859443
23 | n13040303
24 | n03594734
25 | n02085620
26 | n02099849
27 | n01558993
28 | n04493381
29 | n02109047
30 | n04111531
31 | n02877765
32 | n04429376
33 | n02009229
34 | n01978455
35 | n02106550
36 | n01820546
37 | n01692333
38 | n07714571
39 | n02974003
40 | n02114855
41 | n03785016
42 | n03764736
43 | n03775546
44 | n02087046
45 | n07836838
46 | n04099969
47 | n04592741
48 | n03891251
49 | n02701002
50 | n03379051
51 | n02259212
52 | n07715103
53 | n03947888
54 | n04026417
55 | n02326432
56 | n03637318
57 | n01980166
58 | n02113799
59 | n02086240
60 | n03903868
61 | n02483362
62 | n04127249
63 | n02089973
64 | n03017168
65 | n02093428
66 | n02804414
67 | n02396427
68 | n04418357
69 | n02172182
70 | n01729322
71 | n02113978
72 | n03787032
73 | n02089867
74 | n02119022
75 | n03777754
76 | n04238763
77 | n02231487
78 | n03032252
79 | n02138441
80 | n02104029
81 | n03837869
82 | n03494278
83 | n04136333
84 | n03794056
85 | n03492542
86 | n02018207
87 | n04067472
88 | n03930630
89 | n03584829
90 | n02123045
91 | n04229816
92 | n02100583
93 | n03642806
94 | n04336792
95 | n03259280
96 | n02116738
97 | n02108089
98 | n03424325
99 | n01855672
100 | n02090622
101 |
--------------------------------------------------------------------------------
/splits/sun397.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hankook/AugSelf/c131db66b5ade96af86774bc43a2cb797390bba5/splits/sun397.pth
--------------------------------------------------------------------------------
/trainers.py:
--------------------------------------------------------------------------------
1 | import math
2 | from copy import deepcopy
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from ignite.engine import Engine
9 | import ignite.distributed as idist
10 |
11 | from transforms import extract_diff
12 |
13 |
14 | class SSObjective:
15 | def __init__(self, crop=-1, color=-1, flip=-1, blur=-1, rot=-1, sol=-1, only=False):
16 | self.only = only
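        # (name, loss weight, predictor output dim, loss type); a weight <= 0 disables that term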
17 | self.params = [
18 | ('crop', crop, 4, 'regression'),
19 | ('color', color, 4, 'regression'),
20 | ('flip', flip, 1, 'binary_classification'),
21 | ('blur', blur, 1, 'regression'),
22 | ('rot', rot, 4, 'classification'),
23 | ('sol', sol, 1, 'regression'),
24 | ]
25 |
26 | def __call__(self, ss_predictor, z1, z2, d1, d2, symmetric=True):
27 | if symmetric:
28 | z = torch.cat([torch.cat([z1, z2], 1),
29 | torch.cat([z2, z1], 1)], 0)
30 | d = { k: torch.cat([d1[k], d2[k]], 0) for k in d1.keys() }
31 | else:
32 | z = torch.cat([z1, z2], 1)
33 | d = d1
34 |
35 | losses = { 'total': 0 }
36 | for name, weight, n_out, loss_type in self.params:
37 | if weight <= 0:
38 | continue
39 |
40 | p = ss_predictor[name](z)
41 | if loss_type == 'regression':
42 | losses[name] = F.mse_loss(torch.tanh(p), d[name])
43 | elif loss_type == 'binary_classification':
44 | losses[name] = F.binary_cross_entropy_with_logits(p, d[name])
45 | elif loss_type == 'classification':
46 | losses[name] = F.cross_entropy(p, d[name])
47 | losses['total'] += losses[name] * weight
48 |
49 | return losses
50 |
51 |
52 | def prepare_training_batch(batch, t1, t2, device):
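    # x1/x2 are the two cropped views and w1/w2 the CPU-side augmentation parameters;
    # t1/t2 apply the remaining Kornia augmentations on-device, and extract_diff
    # converts the recorded parameters into the AugSelf prediction targets.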
53 | ((x1, w1), (x2, w2)), _ = batch
54 | with torch.no_grad():
55 | x1 = t1(x1.to(device)).detach()
56 | x2 = t2(x2.to(device)).detach()
57 | diff1 = { k: v.to(device) for k, v in extract_diff(t1, t2, w1, w2).items() }
58 | diff2 = { k: v.to(device) for k, v in extract_diff(t2, t1, w2, w1).items() }
59 |
60 | return x1, x2, diff1, diff2
61 |
62 |
63 | def simsiam(backbone,
64 | projector,
65 | predictor,
66 | ss_predictor,
67 | t1,
68 | t2,
69 | optimizers,
70 | device,
71 | ss_objective
72 | ):
73 |
74 | def training_step(engine, batch):
75 | backbone.train()
76 | projector.train()
77 | predictor.train()
78 |
79 | for o in optimizers:
80 | o.zero_grad()
81 |
82 | x1, x2, d1, d2 = prepare_training_batch(batch, t1, t2, device)
83 | y1, y2 = backbone(x1), backbone(x2)
84 |
85 | if not ss_objective.only:
86 | z1 = projector(y1)
87 | z2 = projector(y2)
88 | p1 = predictor(z1)
89 | p2 = predictor(z2)
90 | loss1 = F.cosine_similarity(p1, z2.detach(), dim=-1).mean().mul(-1)
91 | loss2 = F.cosine_similarity(p2, z1.detach(), dim=-1).mean().mul(-1)
92 | loss = (loss1+loss2).mul(0.5)
93 | else:
94 | loss = 0.
95 |
96 | outputs = dict(loss=loss)
97 | if not ss_objective.only:
98 | outputs['z1'] = z1
99 | outputs['z2'] = z2
100 |
101 | ss_losses = ss_objective(ss_predictor, y1, y2, d1, d2)
102 | (loss+ss_losses['total']).backward()
103 | for k, v in ss_losses.items():
104 | outputs[f'ss/{k}'] = v
105 |
106 | for o in optimizers:
107 | o.step()
108 |
109 | return outputs
110 |
111 | return Engine(training_step)
112 |
113 |
114 | def moco(backbone,
115 | projector,
116 | ss_predictor,
117 | t1,
118 | t2,
119 | optimizers,
120 | device,
121 | ss_objective,
122 | momentum=0.999,
123 | K=65536,
124 | T=0.2,
125 | ):
126 |
127 | target_backbone = deepcopy(backbone)
128 | target_projector = deepcopy(projector)
129 | for p in list(target_backbone.parameters())+list(target_projector.parameters()):
130 | p.requires_grad = False
131 |
132 | queue = F.normalize(torch.randn(K, 128).to(device)).detach()
133 | queue.requires_grad = False
134 | queue.ptr = 0
135 |
136 | def training_step(engine, batch):
137 | backbone.train()
138 | projector.train()
139 | target_backbone.train()
140 | target_projector.train()
141 |
142 | for o in optimizers:
143 | o.zero_grad()
144 |
145 | x1, x2, d1, d2 = prepare_training_batch(batch, t1, t2, device)
146 | y1 = backbone(x1)
147 | z1 = F.normalize(projector(y1))
148 | with torch.no_grad():
149 | y2 = target_backbone(x2)
150 | z2 = F.normalize(target_projector(y2))
151 |
152 | l_pos = torch.einsum('nc,nc->n', [z1, z2]).unsqueeze(-1)
153 | l_neg = torch.einsum('nc,kc->nk', [z1, queue.clone().detach()])
154 | logits = torch.cat([l_pos, l_neg], dim=1).div(T)
155 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
156 | loss = F.cross_entropy(logits, labels)
157 | outputs = dict(loss=loss, z1=z1, z2=z2)
158 |
159 | ss_losses = ss_objective(ss_predictor, y1, y2, d1, d2)
160 | (loss+ss_losses['total']).backward()
161 | for k, v in ss_losses.items():
162 | outputs[f'ss/{k}'] = v
163 |
164 | for o in optimizers:
165 | o.step()
166 |
167 | # momentum network update
168 | for online, target in [(backbone, target_backbone), (projector, target_projector)]:
169 | for p1, p2 in zip(online.parameters(), target.parameters()):
170 | p2.data.mul_(momentum).add_(p1.data, alpha=1-momentum)
171 |
172 | # queue update
173 | keys = idist.utils.all_gather(z1)
174 | queue[queue.ptr:queue.ptr+keys.shape[0]] = keys
175 | queue.ptr = (queue.ptr+keys.shape[0]) % K
176 |
177 | return outputs
178 |
179 | engine = Engine(training_step)
180 | return engine
181 |
182 |
183 | def simclr(backbone,
184 | projector,
185 | ss_predictor,
186 | t1,
187 | t2,
188 | optimizers,
189 | device,
190 | ss_objective,
191 | T=0.2,
192 | ):
193 |
194 | def training_step(engine, batch):
195 | backbone.train()
196 | projector.train()
197 |
198 | for o in optimizers:
199 | o.zero_grad()
200 |
201 | x1, x2, d1, d2 = prepare_training_batch(batch, t1, t2, device)
202 | y1 = backbone(x1)
203 | y2 = backbone(x2)
204 | z1 = F.normalize(projector(y1))
205 | z2 = F.normalize(projector(y2))
206 |
207 | z = torch.cat([z1, z2], 0)
208 | scores = torch.einsum('ik, jk -> ij', z, z).div(T)
209 | n = z1.shape[0]
210 | labels = torch.tensor(list(range(n, 2*n)) + list(range(0, n)), device=scores.device)
211 | masks = torch.zeros_like(scores, dtype=torch.bool)
212 | for i in range(2*n):
213 | masks[i, i] = True
214 | scores = scores.masked_fill(masks, float('-inf'))
215 | loss = F.cross_entropy(scores, labels)
216 | outputs = dict(loss=loss, z1=z1, z2=z2)
217 |
218 | ss_losses = ss_objective(ss_predictor, y1, y2, d1, d2)
219 | (loss+ss_losses['total']).backward()
220 | for k, v in ss_losses.items():
221 | outputs[f'ss/{k}'] = v
222 |
223 | for o in optimizers:
224 | o.step()
225 |
226 | return outputs
227 |
228 | engine = Engine(training_step)
229 | return engine
230 |
231 |
232 | def byol(backbone,
233 | projector,
234 | predictor,
235 | ss_predictor,
236 | t1,
237 | t2,
238 | optimizers,
239 | device,
240 | ss_objective,
241 | momentum=0.996,
242 | ):
243 |
244 | target_backbone = deepcopy(backbone)
245 | target_projector = deepcopy(projector)
246 | for p in list(target_backbone.parameters())+list(target_projector.parameters()):
247 | p.requires_grad = False
248 |
249 | def training_step(engine, batch):
250 | backbone.train()
251 | projector.train()
252 | predictor.train()
253 |
254 | for o in optimizers:
255 | o.zero_grad()
256 |
257 | x1, x2, d1, d2 = prepare_training_batch(batch, t1, t2, device)
258 | y1, y2 = backbone(x1), backbone(x2)
259 | z1, z2 = projector(y1), projector(y2)
260 | p1, p2 = predictor(z1), predictor(z2)
261 | with torch.no_grad():
262 | tgt1 = target_projector(target_backbone(x1))
263 | tgt2 = target_projector(target_backbone(x2))
264 |
265 | loss1 = F.cosine_similarity(p1, tgt2.detach(), dim=-1).mean().mul(-1)
266 | loss2 = F.cosine_similarity(p2, tgt1.detach(), dim=-1).mean().mul(-1)
267 | loss = (loss1+loss2).mul(2)
268 |
269 | outputs = dict(loss=loss)
270 | outputs['z1'] = z1
271 | outputs['z2'] = z2
272 |
273 | ss_losses = ss_objective(ss_predictor, y1, y2, d1, d2)
274 | (loss+ss_losses['total']).backward()
275 | for k, v in ss_losses.items():
276 | outputs[f'ss/{k}'] = v
277 |
278 | for o in optimizers:
279 | o.step()
280 |
281 | # momentum network update
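        # cosine schedule: m starts at `momentum` and approaches 1 by the end of training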
282 | m = 1 - (1-momentum)*(math.cos(math.pi*(engine.state.epoch-1)/engine.state.max_epochs)+1)/2
283 | for online, target in [(backbone, target_backbone), (projector, target_projector)]:
284 | for p1, p2 in zip(online.parameters(), target.parameters()):
285 | p2.data.mul_(m).add_(p1.data, alpha=1-m)
286 |
287 | return outputs
288 |
289 | return Engine(training_step)
290 |
291 |
292 | def distributed_sinkhorn(Q, nmb_iters):
293 | with torch.no_grad():
294 | Q = shoot_infs(Q)
295 | sum_Q = torch.sum(Q)
296 | # idist.utils.all_reduce(sum_Q)
297 | Q /= sum_Q
298 | r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
299 | # c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (args.world_size * Q.shape[1])
300 | c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (1 * Q.shape[1])
301 | for it in range(nmb_iters):
302 | u = torch.sum(Q, dim=1)
303 | # idist.utils.all_reduce(u)
304 | u = r / u
305 | u = shoot_infs(u)
306 | Q *= u.unsqueeze(1)
307 | Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
308 | return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
309 |
310 |
311 | def shoot_infs(inp_tensor):
312 |     """Replace inf entries with the maximum finite value of the tensor."""
313 | mask_inf = torch.isinf(inp_tensor)
314 | ind_inf = torch.nonzero(mask_inf)
315 | if len(ind_inf) > 0:
316 | for ind in ind_inf:
317 | if len(ind) == 2:
318 | inp_tensor[ind[0], ind[1]] = 0
319 | elif len(ind) == 1:
320 | inp_tensor[ind[0]] = 0
321 | m = torch.max(inp_tensor)
322 | for ind in ind_inf:
323 | if len(ind) == 2:
324 | inp_tensor[ind[0], ind[1]] = m
325 | elif len(ind) == 1:
326 | inp_tensor[ind[0]] = m
327 | return inp_tensor
328 |
329 |
330 | def swav(backbone,
331 | projector,
332 | prototypes,
333 | ss_predictor,
334 | t1,
335 | t2,
336 | optimizers,
337 | device,
338 | ss_objective,
339 | epsilon=0.05,
340 | n_iters=3,
341 | temperature=0.1,
342 | freeze_n_iters=410,
343 | ):
344 |
345 | def training_step(engine, batch):
346 | backbone.train()
347 | projector.train()
348 | prototypes.train()
349 |
350 | for o in optimizers:
351 | o.zero_grad()
352 |
353 | with torch.no_grad():
354 | w = prototypes.weight.data.clone()
355 | w = F.normalize(w, dim=1, p=2)
356 | prototypes.weight.copy_(w)
357 |
358 | x1, x2, d1, d2 = prepare_training_batch(batch, t1, t2, device)
359 | y1, y2 = backbone(x1), backbone(x2)
360 | z1, z2 = projector(y1), projector(y2)
361 | z1 = F.normalize(z1, dim=1, p=2)
362 | z2 = F.normalize(z2, dim=1, p=2)
363 | p1, p2 = prototypes(z1), prototypes(z2)
364 |
365 | q1 = distributed_sinkhorn(torch.exp(p1 / epsilon).t(), n_iters)
366 | q2 = distributed_sinkhorn(torch.exp(p2 / epsilon).t(), n_iters)
367 |
368 | p1 = F.softmax(p1 / temperature, dim=1)
369 | p2 = F.softmax(p2 / temperature, dim=1)
370 |
371 | loss1 = -torch.mean(torch.sum(q1 * torch.log(p2), dim=1))
372 | loss2 = -torch.mean(torch.sum(q2 * torch.log(p1), dim=1))
373 | loss = loss1+loss2
374 |
375 | outputs = dict(loss=loss)
376 | outputs['z1'] = z1
377 | outputs['z2'] = z2
378 |
379 | ss_losses = ss_objective(ss_predictor, y1, y2, d1, d2)
380 | (loss+ss_losses['total']).backward()
381 | for k, v in ss_losses.items():
382 | outputs[f'ss/{k}'] = v
383 |
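        # as in SwAV, keep the prototype vectors frozen for the first iterations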
384 | if engine.state.iteration < freeze_n_iters:
385 | for p in prototypes.parameters():
386 | p.grad = None
387 |
388 | for o in optimizers:
389 | o.step()
390 |
391 | return outputs
392 |
393 | return Engine(training_step)
394 |
395 |
396 | def collect_features(backbone,
397 | dataloader,
398 | device,
399 | normalize=True,
400 | dst=None,
401 | verbose=False):
402 |
403 | if dst is None:
404 | dst = device
405 |
406 | backbone.eval()
407 | with torch.no_grad():
408 | features = []
409 | labels = []
410 | for i, (x, y) in enumerate(dataloader):
411 | if x.ndim == 5:
412 | _, n, c, h, w = x.shape
413 | x = x.view(-1, c, h, w)
414 | y = y.view(-1, 1).repeat(1, n).view(-1)
415 | z = backbone(x.to(device))
416 | if normalize:
417 | z = F.normalize(z, dim=-1)
418 | features.append(z.to(dst).detach())
419 | labels.append(y.to(dst).detach())
420 | if verbose and (i+1) % 10 == 0:
421 | print(i+1)
422 | features = idist.utils.all_gather(torch.cat(features, 0).detach())
423 | labels = idist.utils.all_gather(torch.cat(labels, 0).detach())
424 |
425 | return features, labels
426 |
427 |
428 | def nn_evaluator(backbone,
429 | trainloader,
430 | testloader,
431 | device):
432 |
433 | def evaluator():
434 | backbone.eval()
435 | with torch.no_grad():
436 | features, labels = collect_features(backbone, trainloader, device)
437 | corrects, total = 0, 0
438 | for x, y in testloader:
439 | z = F.normalize(backbone(x.to(device)), dim=-1)
440 | scores = torch.einsum('ik, jk -> ij', z, features)
441 | preds = labels[scores.argmax(1)]
442 |
443 | corrects += (preds.cpu() == y).long().sum().item()
444 | total += y.shape[0]
445 | corrects = idist.utils.all_reduce(corrects)
446 | total = idist.utils.all_reduce(total)
447 |
448 | return corrects / total
449 |
450 | return evaluator
451 |
452 |
--------------------------------------------------------------------------------
/transfer_few_shot.py:
--------------------------------------------------------------------------------
1 | import random
2 | from argparse import ArgumentParser
3 | from functools import partial
4 | from copy import deepcopy
5 | from collections import defaultdict
6 |
7 | import numpy as np
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | import torch.backends.cudnn as cudnn
14 |
15 | import ignite.distributed as idist
16 |
17 | from datasets import load_fewshot_datasets
18 | from models import load_backbone, load_mlp
19 | from trainers import collect_features, SSObjective
20 | from utils import Logger
21 | from transforms import extract_diff
22 |
23 | from sklearn.linear_model import LogisticRegression
24 |
25 |
26 | class FewShotBatchSampler(torch.utils.data.Sampler):
27 | def __init__(self, dataset, N, K, Q, num_iterations):
28 | self.N = N
29 | self.K = K
30 | self.Q = Q
31 | self.num_iterations = num_iterations
32 |
33 | labels = [label for _, label in dataset.samples]
34 | self.label2idx = defaultdict(list)
35 | for i, y in enumerate(labels):
36 | self.label2idx[y].append(i)
37 |
38 | few_labels = [y for y, indices in self.label2idx.items() if len(indices) <= self.K]
39 | for y in few_labels:
40 | del self.label2idx[y]
41 |
42 | def __len__(self):
43 | return self.num_iterations
44 |
45 | def __iter__(self):
46 |         label_set = list(self.label2idx.keys())  # a sequence, so random.sample works on Python >= 3.11
47 | for _ in range(self.num_iterations):
48 | labels = random.sample(label_set, self.N)
49 | indices = []
50 | for y in labels:
51 | if len(self.label2idx[y]) >= self.K+self.Q:
52 | indices.extend(list(random.sample(self.label2idx[y], self.K+self.Q)))
53 | else:
54 | tmp_indices = [i for i in self.label2idx[y]]
55 | random.shuffle(tmp_indices)
56 | indices.extend(tmp_indices[:self.K] + np.random.choice(tmp_indices[self.K:], size=self.Q).tolist())
57 | yield indices
58 |
59 |
60 | def main(local_rank, args):
61 | cudnn.benchmark = True
62 | device = idist.device()
63 | logger = Logger(None)
64 |
65 | # DATASETS
66 | datasets = load_fewshot_datasets(dataset=args.dataset,
67 | datadir=args.datadir,
68 | pretrain_data=args.pretrain_data)
69 | build_sampler = partial(FewShotBatchSampler,
70 | N=args.N, K=args.K, Q=args.Q, num_iterations=args.num_tasks)
71 | build_dataloader = partial(torch.utils.data.DataLoader,
72 | num_workers=args.num_workers)
73 | testloader = build_dataloader(datasets['test'], batch_sampler=build_sampler(datasets['test']))
74 |
75 | # MODELS
76 | ckpt = torch.load(args.ckpt, map_location=device)
77 | backbone = load_backbone(args).to(device)
78 | backbone.load_state_dict(ckpt['backbone'])
79 | backbone.eval()
80 |
81 | all_accuracies = []
82 | for i, (batch, _) in enumerate(testloader):
83 | with torch.no_grad():
84 | batch = batch.to(device)
85 | B, C, H, W = batch.shape
86 | batch = batch.view(args.N, args.K+args.Q, C, H, W)
87 |
88 | train_batch = batch[:, :args.K].reshape(args.N*args.K, C, H, W)
89 | test_batch = batch[:, args.K:].reshape(args.N*args.Q, C, H, W)
90 | train_labels = torch.arange(args.N).unsqueeze(1).repeat(1, args.K).to(device).view(-1)
91 | test_labels = torch.arange(args.N).unsqueeze(1).repeat(1, args.Q).to(device).view(-1)
92 |
93 | with torch.no_grad():
94 | X_train = backbone(train_batch)
95 | Y_train = train_labels
96 |
97 | X_test = backbone(test_batch)
98 | Y_test = test_labels
99 |
100 | classifier = LogisticRegression(solver='liblinear').fit(X_train.cpu().numpy(),
101 | Y_train.cpu().numpy())
102 | preds = classifier.predict(X_test.cpu().numpy())
103 | acc = np.mean((Y_test.cpu().numpy() == preds).astype(float))
104 | all_accuracies.append(acc)
105 | if (i+1) % 10 == 0:
106 | logger.log_msg(f'{i+1:3d} | {acc:.4f} (mean: {np.mean(all_accuracies):.4f})')
107 |
108 |     avg = np.mean(all_accuracies)
109 |     ci95 = 1.96 * np.std(all_accuracies) / np.sqrt(len(all_accuracies))
110 |     logger.log_msg(f'mean: {avg:.4f}±{ci95:.4f}')
111 |
112 |
113 | if __name__ == '__main__':
114 | parser = ArgumentParser()
115 | parser.add_argument('--ckpt', type=str, required=True)
116 | parser.add_argument('--pretrain-data', type=str, default='stl10')
117 | parser.add_argument('--dataset', type=str, default='cub200')
118 | parser.add_argument('--datadir', type=str, default='/data')
119 | parser.add_argument('--N', type=int, default=5)
120 | parser.add_argument('--K', type=int, default=1)
121 | parser.add_argument('--Q', type=int, default=16)
122 | parser.add_argument('--num-workers', type=int, default=8)
123 | parser.add_argument('--model', type=str, default='resnet18')
124 | parser.add_argument('--num-tasks', type=int, default=2000)
125 | args = parser.parse_args()
126 | args.num_backbone_features = 512 if args.model.endswith('resnet18') else 2048
127 | with idist.Parallel(None) as parallel:
128 | parallel.run(main, args)
129 |
130 |
--------------------------------------------------------------------------------
/transfer_linear_eval.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from functools import partial
3 | from copy import deepcopy
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | import torch.backends.cudnn as cudnn
10 |
11 | import ignite.distributed as idist
12 |
13 | from datasets import load_datasets
14 | from models import load_backbone
15 | from trainers import collect_features
16 | from utils import Logger
17 |
18 |
19 | def build_step(X, Y, classifier, optimizer, w):
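    # Returns a closure for LBFGS: summed cross-entropy plus an L2 penalty of strength w.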
20 | def step():
21 | optimizer.zero_grad()
22 | loss = F.cross_entropy(classifier(X), Y, reduction='sum')
23 | for p in classifier.parameters():
24 | loss = loss + p.pow(2).sum().mul(w)
25 | loss.backward()
26 | return loss
27 | return step
28 |
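   | # NOTE: torch.optim.LBFGS re-evaluates the objective several times per .step()
   | # call, hence the closure above. The explicit p.pow(2).sum().mul(w) term supplies
   | # the L2 regularization, since LBFGS has no built-in weight_decay option.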
29 |
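   | # 'top1' is plain accuracy; 'class-avg' is per-class accuracy averaged over
   | # classes (macro recall), the usual metric for class-imbalanced datasets.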
30 | def compute_accuracy(X, Y, classifier, metric):
31 | with torch.no_grad():
32 | preds = classifier(X).argmax(1)
33 | if metric == 'top1':
34 | acc = (preds == Y).float().mean().item()
35 | elif metric == 'class-avg':
36 | total, count = 0., 0.
37 | for y in range(0, Y.max().item()+1):
38 | masks = Y == y
39 | if masks.sum() > 0:
40 | total += (preds[masks] == y).float().mean().item()
41 | count += 1
42 | acc = total / count
43 | else:
44 |         raise ValueError(f'Unknown metric: {metric}')
45 | return acc
46 |
47 |
48 | def main(local_rank, args):
49 | cudnn.benchmark = True
50 | device = idist.device()
51 | logger = Logger(None)
52 |
53 | # DATASETS
54 | datasets = load_datasets(dataset=args.dataset,
55 | datadir=args.datadir,
56 | pretrain_data=args.pretrain_data)
57 | build_dataloader = partial(idist.auto_dataloader,
58 | batch_size=args.batch_size,
59 | num_workers=args.num_workers,
60 | shuffle=True,
61 | pin_memory=True)
62 | trainloader = build_dataloader(datasets['train'], drop_last=False)
63 | valloader = build_dataloader(datasets['val'], drop_last=False)
64 | testloader = build_dataloader(datasets['test'], drop_last=False)
65 | num_classes = datasets['num_classes']
66 |
67 | # MODELS
68 | ckpt = torch.load(args.ckpt, map_location=device)
69 | backbone = load_backbone(args)
70 | backbone.load_state_dict(ckpt['backbone'])
71 |
72 | build_model = partial(idist.auto_model, sync_bn=True)
73 | backbone = build_model(backbone)
74 |
75 | # EXTRACT FROZEN FEATURES
76 | logger.log_msg('collecting features ...')
77 | X_train, Y_train = collect_features(backbone, trainloader, device, normalize=False)
78 | X_val, Y_val = collect_features(backbone, valloader, device, normalize=False)
79 | X_test, Y_test = collect_features(backbone, testloader, device, normalize=False)
80 | classifier = nn.Linear(args.num_backbone_features, num_classes).to(device)
81 | optim_kwargs = {
82 | 'line_search_fn': 'strong_wolfe',
83 | 'max_iter': 5000,
84 | 'lr': 1.,
85 | 'tolerance_grad': 1e-10,
86 | 'tolerance_change': 0,
87 | }
88 | logger.log_msg('collecting features ... done')
89 |
90 | best_acc = 0.
91 | best_w = 0.
92 | best_classifier = None
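   |     # Sweep the L2 coefficient over a log-spaced grid, selecting on the validation
   |     # split. Note the classifier is warm-started across w values (it is not
   |     # re-initialized per w); the best (w, weights) pair is refit on train+val
   |     # below before the final test evaluation.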
93 | for w in torch.logspace(-6, 5, steps=45).tolist():
94 | optimizer = optim.LBFGS(classifier.parameters(), **optim_kwargs)
95 | optimizer.step(build_step(X_train, Y_train, classifier, optimizer, w))
96 | acc = compute_accuracy(X_val, Y_val, classifier, args.metric)
97 |
98 | if best_acc < acc:
99 | best_acc = acc
100 | best_w = w
101 | best_classifier = deepcopy(classifier)
102 |
103 | logger.log_msg(f'w={w:.4e}, acc={acc:.4f}')
104 |
105 | logger.log_msg(f'BEST: w={best_w:.4e}, acc={best_acc:.4f}')
106 |
107 | X = torch.cat([X_train, X_val], 0)
108 | Y = torch.cat([Y_train, Y_val], 0)
109 | optimizer = optim.LBFGS(best_classifier.parameters(), **optim_kwargs)
110 | optimizer.step(build_step(X, Y, best_classifier, optimizer, best_w))
111 | acc = compute_accuracy(X_test, Y_test, best_classifier, args.metric)
112 | logger.log_msg(f'test acc={acc:.4f}')
113 |
114 | if __name__ == '__main__':
115 | parser = ArgumentParser()
116 | parser.add_argument('--ckpt', type=str, required=True)
117 | parser.add_argument('--pretrain-data', type=str, default='stl10')
118 | parser.add_argument('--dataset', type=str, default='cifar10')
119 | parser.add_argument('--datadir', type=str, default='/data')
120 | parser.add_argument('--batch-size', type=int, default=256)
121 | parser.add_argument('--num-workers', type=int, default=4)
122 | parser.add_argument('--model', type=str, default='resnet18')
123 | parser.add_argument('--print-freq', type=int, default=10)
124 | parser.add_argument('--distributed', action='store_true')
125 | parser.add_argument('--metric', type=str, default='top1')
126 | args = parser.parse_args()
127 | args.backend = 'nccl' if args.distributed else None
128 | args.num_backbone_features = 512 if args.model.endswith('resnet18') else 2048
129 | with idist.Parallel(args.backend) as parallel:
130 | parallel.run(main, args)
131 |
132 |
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as NF
5 | import torchvision.transforms as T
6 | import torchvision.transforms.functional as F
7 | import kornia
8 | import kornia.augmentation as K
9 | import kornia.augmentation.functional as KF
10 |
11 |
12 | class MultiView:
13 | def __init__(self, transform, num_views=2):
14 | self.transform = transform
15 | self.num_views = num_views
16 |
17 | def __call__(self, x):
18 | return [self.transform(x) for _ in range(self.num_views)]
19 |
20 |
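   | # Unlike torchvision's RandomResizedCrop, this variant also returns the sampled
   | # crop box normalized by the image size, (i/H, j/W, h/H, w/W); extract_diff below
   | # turns the difference of these boxes between two views into the crop target.
   | # (F._get_image_size is a private helper of the pinned torchvision version.)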
21 | class RandomResizedCrop(T.RandomResizedCrop):
22 | def forward(self, img):
23 | W, H = F._get_image_size(img)
24 | i, j, h, w = self.get_params(img, self.scale, self.ratio)
25 | img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
26 | tensor = F.to_tensor(img)
27 | return tensor, torch.tensor([i/H, j/W, h/H, w/W], dtype=torch.float)
28 |
29 |
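   | # These two helpers appear to mirror torchvision's definitions (blend with black
   | # for brightness; blend with the per-image mean of the Rec.601 grayscale for
   | # contrast) so that ColorJitter below applies them per-sample on batched tensors.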
30 | def apply_adjust_brightness(img1, params):
31 | ratio = params['brightness_factor'][:, None, None, None].to(img1.device)
32 | img2 = torch.zeros_like(img1)
33 | return (ratio * img1 + (1.0-ratio) * img2).clamp(0, 1)
34 |
35 |
36 | def apply_adjust_contrast(img1, params):
37 | ratio = params['contrast_factor'][:, None, None, None].to(img1.device)
38 | img2 = 0.2989 * img1[:, 0:1] + 0.587 * img1[:, 1:2] + 0.114 * img1[:, 2:3]
39 | img2 = torch.mean(img2, dim=(-2, -1), keepdim=True)
40 | return (ratio * img1 + (1.0-ratio) * img2).clamp(0, 1)
41 |
42 |
43 | class ColorJitter(K.ColorJitter):
44 | def apply_transform(self, x, params):
45 | transforms = [
46 | lambda img: apply_adjust_brightness(img, params),
47 | lambda img: apply_adjust_contrast(img, params),
48 | lambda img: KF.apply_adjust_saturation(img, params),
49 | lambda img: KF.apply_adjust_hue(img, params)
50 | ]
51 |
52 | for idx in params['order'].tolist():
53 | t = transforms[idx]
54 | x = t(x)
55 |
56 | return x
57 |
58 |
59 | class GaussianBlur(K.AugmentationBase2D):
60 | def __init__(self, kernel_size, sigma, border_type='reflect',
61 | return_transform=False, same_on_batch=False, p=0.5):
62 | super().__init__(
63 | p=p, return_transform=return_transform, same_on_batch=same_on_batch, p_batch=1.)
64 | assert kernel_size % 2 == 1
65 | self.kernel_size = kernel_size
66 | self.sigma = sigma
67 | self.border_type = border_type
68 |
69 | def __repr__(self):
70 | return self.__class__.__name__ + f"({super().__repr__()})"
71 |
72 | def generate_parameters(self, batch_shape):
73 | return dict(sigma=torch.zeros(batch_shape[0]).uniform_(self.sigma[0], self.sigma[1]))
74 |
75 | def apply_transform(self, input, params):
76 | sigma = params['sigma'].to(input.device)
77 | k_half = self.kernel_size // 2
78 | x = torch.linspace(-k_half, k_half, steps=self.kernel_size, dtype=input.dtype, device=input.device)
79 | pdf = torch.exp(-0.5*(x[None, :] / sigma[:, None]).pow(2))
80 | kernel1d = pdf / pdf.sum(1, keepdim=True)
81 | kernel2d = torch.bmm(kernel1d[:, :, None], kernel1d[:, None, :])
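   |         # Per-sample blur: kernel2d is (B, k, k), one Gaussian kernel per image.
   |         # Treating the batch dimension as channels (hence the transposes) and
   |         # convolving with groups=B applies each image's own kernel.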
82 | input = NF.pad(input, (k_half, k_half, k_half, k_half), mode=self.border_type)
83 | input = NF.conv2d(input.transpose(0, 1), kernel2d[:, None], groups=input.shape[0]).transpose(0, 1)
84 | return input
85 |
86 |
87 | class RandomRotation(K.AugmentationBase2D):
88 | def __init__(self, return_transform=False, same_on_batch=False, p=0.5):
89 | super().__init__(
90 | p=p, return_transform=return_transform, same_on_batch=same_on_batch, p_batch=1.)
91 |
92 | def __repr__(self):
93 | return self.__class__.__name__ + f"({super().__repr__()})"
94 |
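   |     # Despite the name, 'degrees' holds k in {0, 1, 2, 3}: the number of
   |     # 90-degree rotations applied via torch.rot90 below.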
95 | def generate_parameters(self, batch_shape):
96 | degrees = torch.randint(0, 4, (batch_shape[0], ))
97 | return dict(degrees=degrees)
98 |
99 | def apply_transform(self, input, params):
100 | degrees = params['degrees']
101 | input = torch.stack([torch.rot90(x, k, (1, 2)) for x, k in zip(input, degrees.tolist())], 0)
102 | return input
103 |
104 |
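   | # Pulls the per-sample augmentation strengths out of a fitted Kornia transform,
   | # rescaled so that unapplied samples (batch_prob == False) keep a neutral
   | # default (0, or 1 for solarize).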
105 | def _extract_w(t):
106 | if isinstance(t, GaussianBlur):
107 | m = t._params['batch_prob']
108 | w = torch.zeros(m.shape[0], 1)
109 | w[m] = t._params['sigma'].unsqueeze(-1)
110 | return w
111 |
112 | elif isinstance(t, ColorJitter):
113 | to_apply = t._params['batch_prob']
114 | w = torch.zeros(to_apply.shape[0], 4)
115 | w[to_apply, 0] = (t._params['brightness_factor'] - 1) / (t.brightness[1]-t.brightness[0])
116 | w[to_apply, 1] = (t._params['contrast_factor'] - 1) / (t.contrast[1]-t.contrast[0])
117 | w[to_apply, 2] = (t._params['saturation_factor'] - 1) / (t.saturation[1]-t.saturation[0])
118 | w[to_apply, 3] = t._params['hue_factor'] / (t.hue[1]-t.hue[0])
119 | return w
120 |
121 | elif isinstance(t, RandomRotation):
122 | to_apply = t._params['batch_prob']
123 | w = torch.zeros(to_apply.shape[0], dtype=torch.long)
124 | w[to_apply] = t._params['degrees']
125 | return w
126 |
127 | elif isinstance(t, K.RandomSolarize):
128 | to_apply = t._params['batch_prob']
129 | w = torch.ones(to_apply.shape[0])
130 | w[to_apply] = t._params['thresholds_factor']
131 | return w
132 |
133 |
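   | # Builds the AugSelf targets: per-sample differences of augmentation parameters
   | # between the two views. The horizontal-flip masks are read in a first pass
   | # because crop centers must be mirrored for flipped views before the crop
   | # difference is formed. Note this assumes a K.RandomHorizontalFlip is always
   | # present in the pipeline; otherwise f1/f2 would be unbound.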
134 | def extract_diff(transforms1, transforms2, crop1, crop2):
135 | diff = {}
136 | for t1, t2 in zip(transforms1, transforms2):
137 | if isinstance(t1, K.RandomHorizontalFlip):
138 | f1 = t1._params['batch_prob']
139 | f2 = t2._params['batch_prob']
140 | break
141 |
142 | center1 = crop1[:, :2]+crop1[:, 2:]/2
143 | center2 = crop2[:, :2]+crop2[:, 2:]/2
144 | center1[f1, 1] = 1-center1[f1, 1]
145 |     center2[f2, 1] = 1-center2[f2, 1]
146 | diff['crop'] = torch.cat([center1-center2, crop1[:, 2:]-crop2[:, 2:]], 1)
147 | diff['flip'] = (f1==f2).float().unsqueeze(-1)
148 | for t1, t2 in zip(transforms1, transforms2):
149 | if isinstance(t1, K.RandomHorizontalFlip):
150 | pass
151 |
152 | elif isinstance(t1, K.RandomGrayscale):
153 | pass
154 |
155 | elif isinstance(t1, GaussianBlur):
156 | w1 = _extract_w(t1)
157 | w2 = _extract_w(t2)
158 | diff['blur'] = w1-w2
159 |
160 | elif isinstance(t1, K.Normalize):
161 | pass
162 |
163 | elif isinstance(t1, K.ColorJitter):
164 | w1 = _extract_w(t1)
165 | w2 = _extract_w(t2)
166 | diff['color'] = w1-w2
167 |
168 | elif isinstance(t1, (nn.Identity, nn.Sequential)):
169 | pass
170 |
171 | elif isinstance(t1, RandomRotation):
172 | w1 = _extract_w(t1)
173 | w2 = _extract_w(t2)
174 | diff['rot'] = (w1-w2+4) % 4
175 |
176 | elif isinstance(t1, K.RandomSolarize):
177 | w1 = _extract_w(t1)
178 | w2 = _extract_w(t2)
179 | diff['sol'] = w1-w2
180 |
181 | else:
182 |         raise ValueError(f'Unknown transform: {str(t1.__class__)}')
183 |
184 | return diff
185 |
186 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import logging
3 |
4 | import torch
5 | from torch.optim import Optimizer
6 | from torch.utils.tensorboard import SummaryWriter
7 | import ignite.distributed as idist
8 |
9 | class Logger(object):
10 |
11 | def __init__(self, logdir, resume=None):
12 | self.logdir = logdir
13 | self.rank = idist.get_rank()
14 |
15 |         handlers = [logging.StreamHandler(sys.stdout)]
16 | if logdir is not None and self.rank == 0:
17 | if resume is None:
18 | os.makedirs(logdir)
19 | handlers.append(logging.FileHandler(os.path.join(logdir, 'log.txt')))
20 | self.writer = SummaryWriter(log_dir=logdir)
21 | else:
22 | self.writer = None
23 |
24 | logging.basicConfig(format=f"[%(asctime)s ({self.rank})] %(message)s",
25 | level=logging.INFO,
26 | handlers=handlers)
27 |         logging.info(' '.join(sys.argv))
28 |
29 | def log_msg(self, msg):
30 | if idist.get_rank() > 0:
31 | return
32 | logging.info(msg)
33 |
34 | def log(self, engine, global_step, print_msg=True, **kwargs):
35 | msg = f'[epoch {engine.state.epoch}] [iter {engine.state.iteration}]'
36 | for k, v in kwargs.items():
37 | if isinstance(v, torch.Tensor):
38 | v = v.item()
39 |
40 | if type(v) is float:
41 | msg += f' [{k} {v:.4f}]'
42 | else:
43 | msg += f' [{k} {v}]'
44 |
45 | if self.writer is not None:
46 | self.writer.add_scalar(k, v, global_step)
47 |
48 | if print_msg:
49 | logging.info(msg)
50 |
51 | def save(self, engine, **kwargs):
52 | if idist.get_rank() > 0:
53 | return
54 |
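   |         # Collect state_dicts, unwrapping DDP modules; lists of modules and the
   |         # 'ss_predictor' dict of per-augmentation heads are handled specially.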
55 | state = {}
56 | for k, v in kwargs.items():
57 | if isinstance(v, torch.nn.parallel.DistributedDataParallel):
58 | v = v.module
59 |
60 | if hasattr(v, 'state_dict'):
61 | state[k] = v.state_dict()
62 |
63 | if type(v) is list and hasattr(v[0], 'state_dict'):
64 | state[k] = [x.state_dict() for x in v]
65 |
66 | if type(v) is dict and k == 'ss_predictor':
67 | state[k] = { y: x.state_dict() for y, x in v.items() }
68 |
69 | torch.save(state, os.path.join(self.logdir, f'ckpt-{engine.state.epoch}.pth'))
70 |
71 |
--------------------------------------------------------------------------------