├── .gitignore
├── README.md
├── datasets.py
├── models.py
├── pretrain.py
├── splits
│   ├── caltech101.pth
│   ├── cub200.pth
│   ├── dog.pth
│   ├── imagenet100.txt
│   └── sun397.pth
├── trainers.py
├── transfer_few_shot.py
├── transfer_linear_eval.py
├── transforms.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | logs/
3 |
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
1 | # Improving Transferability of Representations via Augmentation-Aware Self-Supervision
2 |
3 | Accepted to NeurIPS 2021
4 |
5 |

6 | [thumbnail image]
7 |

8 |
9 | **TL;DR:** Learning augmentation-aware information by predicting the difference between two augmented samples improves the transferability of representations.
10 |
11 | ## Dependencies
12 |
13 | ```bash
14 | conda create -n AugSelf python=3.8 pytorch=1.7.1 torchvision=0.8.2 cudatoolkit=10.1 ignite -c pytorch
15 | conda activate AugSelf
16 | pip install scipy tensorboard kornia==0.4.1 scikit-learn
17 | ```
18 |
19 | ## Checkpoints
20 |
21 | We provide ImageNet100-pretrained models in [this Dropbox link](https://www.dropbox.com/sh/0hjts19ysxebmaa/AABB6bF3QQWdIOCh9vocwTGGa?dl=0).
22 |
23 | ## Pretraining
24 |
25 | Here we provide SimSiam+AugSelf pretraining scripts. To train the baseline (i.e., no AugSelf), remove the `--ss-crop` and `--ss-color` options. To use other frameworks such as SimCLR, use the `--framework` option.
26 |
27 | ### STL-10
28 | ```bash
29 | CUDA_VISIBLE_DEVICES=0 python pretrain.py \
30 |     --logdir ./logs/stl10/simsiam/aug_self \
31 |     --framework simsiam \
32 |     --dataset stl10 \
33 |     --datadir DATADIR \
34 |     --model resnet18 \
35 |     --batch-size 256 \
36 |     --max-epochs 200 \
37 |     --ss-color 1.0 --ss-crop 1.0
38 | ```
39 |
40 | ### ImageNet100
41 |
42 | ```bash
43 | python pretrain.py \
44 |     --logdir ./logs/imagenet100/simsiam/aug_self \
45 |     --framework simsiam \
46 |     --dataset imagenet100 \
47 |     --datadir DATADIR \
48 |     --batch-size 256 \
49 |     --max-epochs 500 \
50 |     --model resnet50 \
51 |     --base-lr 0.05 --wd 1e-4 \
52 |     --ckpt-freq 50 --eval-freq 50 \
53 |     --ss-crop 0.5 --ss-color 0.5 \
54 |     --num-workers 16 --distributed
55 | ```
56 |
57 | ## Evaluation
58 |
59 | Our main evaluation setups are linear evaluation on fine-grained classification datasets (Table 1) and few-shot benchmarks (Table 2).
60 |
61 | ### Linear evaluation
62 |
63 | ```bash
64 | CUDA_VISIBLE_DEVICES=0 python transfer_linear_eval.py \
65 |     --pretrain-data imagenet100 \
66 |     --ckpt CKPT \
67 |     --model resnet50 \
68 |     --dataset cifar10 \
69 |     --datadir DATADIR \
70 |     --metric top1
71 | ```
72 |
73 | ### Few-shot
74 |
75 | ```bash
76 | CUDA_VISIBLE_DEVICES=0 python transfer_few_shot.py \
77 |     --pretrain-data imagenet100 \
78 |     --ckpt CKPT \
79 |     --model resnet50 \
80 |     --dataset cub200 \
81 |     --datadir DATADIR
82 | ```
83 |
--------------------------------------------------------------------------------

/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import json
4 | from scipy.io import loadmat
5 | from PIL import Image
6 | import xml.etree.ElementTree as ET
7 | from collections import defaultdict
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.utils.data import random_split, ConcatDataset, Subset
12 |
13 | from transforms import MultiView, RandomResizedCrop, ColorJitter, GaussianBlur, RandomRotation
14 | from torchvision import transforms as T
15 | from torchvision.datasets import STL10, CIFAR10, CIFAR100, ImageFolder, ImageNet, Caltech101, Caltech256
16 |
17 | import kornia.augmentation as K
18 |
19 | class ImageList(torch.utils.data.Dataset):
20 |     def __init__(self, samples, transform=None):
21 |         self.samples = samples
22 |         self.transform = transform
23 |
24 |     def __getitem__(self, idx):
25 |         path, label = self.samples[idx]
26 |         with open(path, 'rb') as f:
27 |             img = Image.open(f)
28 |             img = img.convert('RGB')
29 |         if self.transform is not None:
30 |             img = self.transform(img)
31 |         return img, label
32 |
33 |     def __len__(self):
34 |         return len(self.samples)
35 |
36 | class ImageNet100(ImageFolder):
37
| def __init__(self, root, split, transform): 38 | with open('splits/imagenet100.txt') as f: 39 | classes = [line.strip() for line in f] 40 | class_to_idx = { cls: idx for idx, cls in enumerate(classes) } 41 | 42 | super().__init__(os.path.join(root, split), transform=transform) 43 | samples = [] 44 | for path, label in self.samples: 45 | cls = self.classes[label] 46 | if cls not in class_to_idx: 47 | continue 48 | label = class_to_idx[cls] 49 | samples.append((path, label)) 50 | 51 | self.samples = samples 52 | self.classes = classes 53 | self.class_to_idx = class_to_idx 54 | self.targets = [s[1] for s in samples] 55 | 56 | class Pets(ImageList): 57 | def __init__(self, root, split, transform=None): 58 | with open(os.path.join(root, 'annotations', f'{split}.txt')) as f: 59 | annotations = [line.split() for line in f] 60 | 61 | samples = [] 62 | for sample in annotations: 63 | path = os.path.join(root, 'images', sample[0] + '.jpg') 64 | label = int(sample[1])-1 65 | samples.append((path, label)) 66 | 67 | super().__init__(samples, transform) 68 | 69 | class Food101(ImageList): 70 | def __init__(self, root, split, transform=None): 71 | with open(os.path.join(root, 'meta', 'classes.txt')) as f: 72 | classes = [line.strip() for line in f] 73 | with open(os.path.join(root, 'meta', f'{split}.json')) as f: 74 | annotations = json.load(f) 75 | 76 | samples = [] 77 | for i, cls in enumerate(classes): 78 | for path in annotations[cls]: 79 | samples.append((os.path.join(root, 'images', f'{path}.jpg'), i)) 80 | 81 | super().__init__(samples, transform) 82 | 83 | class DTD(ImageList): 84 | def __init__(self, root, split, transform=None): 85 | with open(os.path.join(root, 'labels', f'{split}1.txt')) as f: 86 | paths = [line.strip() for line in f] 87 | 88 | classes = sorted(os.listdir(os.path.join(root, 'images'))) 89 | samples = [(os.path.join(root, 'images', path), classes.index(path.split('/')[0])) for path in paths] 90 | super().__init__(samples, transform) 91 | 92 | class SUN397(ImageList): 93 | def __init__(self, root, split, transform=None): 94 | with open(os.path.join(root, 'ClassName.txt')) as f: 95 | classes = [line.strip() for line in f] 96 | 97 | with open(os.path.join(root, f'{split}_01.txt')) as f: 98 | samples = [] 99 | for line in f: 100 | path = line.strip() 101 | for y, cls in enumerate(classes): 102 | if path.startswith(cls+'/'): 103 | samples.append((os.path.join(root, 'SUN397', path[1:]), y)) 104 | break 105 | super().__init__(samples, transform) 106 | 107 | def load_pretrain_datasets(dataset='cifar10', 108 | datadir='/data', 109 | color_aug='default'): 110 | 111 | if dataset == 'imagenet100': 112 | mean = torch.tensor([0.485, 0.456, 0.406]) 113 | std = torch.tensor([0.229, 0.224, 0.225]) 114 | train_transform = MultiView(RandomResizedCrop(224, scale=(0.2, 1.0))) 115 | test_transform = T.Compose([T.Resize(224), 116 | T.CenterCrop(224), 117 | T.ToTensor(), 118 | T.Normalize(mean, std)]) 119 | t1 = nn.Sequential(K.RandomHorizontalFlip(), 120 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8), 121 | K.RandomGrayscale(p=0.2), 122 | GaussianBlur(23, (0.1, 2.0)), 123 | K.Normalize(mean, std)) 124 | t2 = nn.Sequential(K.RandomHorizontalFlip(), 125 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8), 126 | K.RandomGrayscale(p=0.2), 127 | GaussianBlur(23, (0.1, 2.0)), 128 | K.Normalize(mean, std)) 129 | 130 | trainset = ImageNet100(datadir, split='train', transform=train_transform) 131 | valset = ImageNet100(datadir, split='train', transform=test_transform) 132 | testset = ImageNet100(datadir, split='val', 
transform=test_transform) 133 | 134 | elif dataset == 'stl10': 135 | mean = torch.tensor([0.43, 0.42, 0.39]) 136 | std = torch.tensor([0.27, 0.26, 0.27]) 137 | train_transform = MultiView(RandomResizedCrop(96, scale=(0.2, 1.0))) 138 | 139 | if color_aug == 'default': 140 | s = 1 141 | elif color_aug == 'strong': 142 | s = 2. 143 | elif color_aug == 'weak': 144 | s = 0.5 145 | test_transform = T.Compose([T.Resize(96), 146 | T.CenterCrop(96), 147 | T.ToTensor(), 148 | T.Normalize(mean, std)]) 149 | t1 = nn.Sequential(K.RandomHorizontalFlip(), 150 | ColorJitter(0.4*s, 0.4*s, 0.4*s, 0.1*s, p=0.8), 151 | K.RandomGrayscale(p=0.2*s), 152 | GaussianBlur(9, (0.1, 2.0)), 153 | K.Normalize(mean, std)) 154 | t2 = nn.Sequential(K.RandomHorizontalFlip(), 155 | ColorJitter(0.4*s, 0.4*s, 0.4*s, 0.1*s, p=0.8), 156 | K.RandomGrayscale(p=0.2*s), 157 | GaussianBlur(9, (0.1, 2.0)), 158 | K.Normalize(mean, std)) 159 | 160 | trainset = STL10(datadir, split='train+unlabeled', transform=train_transform) 161 | valset = STL10(datadir, split='train', transform=test_transform) 162 | testset = STL10(datadir, split='test', transform=test_transform) 163 | 164 | elif dataset == 'stl10_rot': 165 | mean = torch.tensor([0.43, 0.42, 0.39]) 166 | std = torch.tensor([0.27, 0.26, 0.27]) 167 | train_transform = MultiView(RandomResizedCrop(96, scale=(0.2, 1.0))) 168 | test_transform = T.Compose([T.Resize(96), 169 | T.CenterCrop(96), 170 | T.ToTensor(), 171 | T.Normalize(mean, std)]) 172 | t1 = nn.Sequential(K.RandomHorizontalFlip(), 173 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8), 174 | K.RandomGrayscale(p=0.2), 175 | GaussianBlur(9, (0.1, 2.0)), 176 | RandomRotation(p=0.5), 177 | K.Normalize(mean, std)) 178 | t2 = nn.Sequential(K.RandomHorizontalFlip(), 179 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8), 180 | K.RandomGrayscale(p=0.2), 181 | GaussianBlur(9, (0.1, 2.0)), 182 | RandomRotation(p=0.5), 183 | K.Normalize(mean, std)) 184 | 185 | trainset = STL10(datadir, split='train+unlabeled', transform=train_transform) 186 | valset = STL10(datadir, split='train', transform=test_transform) 187 | testset = STL10(datadir, split='test', transform=test_transform) 188 | 189 | elif dataset == 'stl10_sol': 190 | mean = torch.tensor([0.43, 0.42, 0.39]) 191 | std = torch.tensor([0.27, 0.26, 0.27]) 192 | train_transform = MultiView(RandomResizedCrop(96, scale=(0.2, 1.0))) 193 | 194 | test_transform = T.Compose([T.Resize(96), 195 | T.CenterCrop(96), 196 | T.ToTensor(), 197 | T.Normalize(mean, std)]) 198 | t1 = nn.Sequential(K.RandomHorizontalFlip(), 199 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8), 200 | K.RandomSolarize(0.5, 0.0, p=0.5), 201 | K.RandomGrayscale(p=0.2), 202 | GaussianBlur(9, (0.1, 2.0)), 203 | K.Normalize(mean, std)) 204 | t2 = nn.Sequential(K.RandomHorizontalFlip(), 205 | ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8), 206 | K.RandomSolarize(0.5, 0.0, p=0.5), 207 | K.RandomGrayscale(p=0.2), 208 | GaussianBlur(9, (0.1, 2.0)), 209 | K.Normalize(mean, std)) 210 | 211 | trainset = STL10(datadir, split='train+unlabeled', transform=train_transform) 212 | valset = STL10(datadir, split='train', transform=test_transform) 213 | testset = STL10(datadir, split='test', transform=test_transform) 214 | 215 | else: 216 | raise Exception(f'Unknown dataset {dataset}') 217 | 218 | return dict(train=trainset, 219 | val=valset, 220 | test=testset, 221 | t1=t1, t2=t2) 222 | 223 | def load_datasets(dataset='cifar10', 224 | datadir='/data', 225 | pretrain_data='stl10'): 226 | 227 | if pretrain_data == 'imagenet100': 228 | mean = torch.tensor([0.485, 0.456, 0.406]) 229 | 
std = torch.tensor([0.229, 0.224, 0.225]) 230 | transform = T.Compose([T.Resize(224, interpolation=Image.BICUBIC), 231 | T.CenterCrop(224), 232 | T.ToTensor(), 233 | T.Normalize(mean, std)]) 234 | 235 | elif pretrain_data == 'stl10': 236 | mean = torch.tensor([0.43, 0.42, 0.39]) 237 | std = torch.tensor([0.27, 0.26, 0.27]) 238 | transform = T.Compose([T.Resize(96, interpolation=Image.BICUBIC), 239 | T.CenterCrop(96), 240 | T.ToTensor(), 241 | T.Normalize(mean, std)]) 242 | 243 | generator = lambda seed: torch.Generator().manual_seed(seed) 244 | if dataset == 'imagenet100': 245 | trainval = ImageNet100(datadir, split='train', transform=transform) 246 | train, val = None, None 247 | test = ImageNet100(datadir, split='val', transform=transform) 248 | num_classes = 100 249 | 250 | elif dataset == 'food101': 251 | trainval = Food101(root=datadir, split='train', transform=transform) 252 | train, val = random_split(trainval, [68175, 7575], generator=generator(42)) 253 | test = Food101(root=datadir, split='test', transform=transform) 254 | num_classes = 101 255 | 256 | elif dataset == 'cifar10': 257 | trainval = CIFAR10(root=datadir, train=True, transform=transform) 258 | train, val = random_split(trainval, [45000, 5000], generator=generator(43)) 259 | test = CIFAR10(root=datadir, train=False, transform=transform) 260 | num_classes = 10 261 | 262 | elif dataset == 'cifar100': 263 | trainval = CIFAR100(root=datadir, train=True, transform=transform) 264 | train, val = random_split(trainval, [45000, 5000], generator=generator(44)) 265 | test = CIFAR100(root=datadir, train=False, transform=transform) 266 | num_classes = 100 267 | 268 | elif dataset == 'sun397': 269 | trn_indices, val_indices = torch.load('splits/sun397.pth') 270 | trainval = SUN397(root=datadir, split='Training', transform=transform) 271 | train = Subset(trainval, trn_indices) 272 | val = Subset(trainval, val_indices) 273 | test = SUN397(root=datadir, split='Testing', transform=transform) 274 | num_classes = 397 275 | 276 | elif dataset == 'dtd': 277 | train = DTD(root=datadir, split='train', transform=transform) 278 | val = DTD(root=datadir, split='val', transform=transform) 279 | trainval = ConcatDataset([train, val]) 280 | test = DTD(root=datadir, split='test', transform=transform) 281 | num_classes = 47 282 | 283 | elif dataset == 'pets': 284 | trainval = Pets(root=datadir, split='trainval', transform=transform) 285 | train, val = random_split(trainval, [2940, 740], generator=generator(49)) 286 | test = Pets(root=datadir, split='test', transform=transform) 287 | num_classes = 37 288 | 289 | elif dataset == 'caltech101': 290 | transform.transforms.insert(0, T.Lambda(lambda img: img.convert('RGB'))) 291 | D = Caltech101(datadir, transform=transform) 292 | trn_indices, val_indices, tst_indices = torch.load('splits/caltech101.pth') 293 | train = Subset(D, trn_indices) 294 | val = Subset(D, val_indices) 295 | trainval = ConcatDataset([train, val]) 296 | test = Subset(D, tst_indices) 297 | num_classes = 101 298 | 299 | elif dataset == 'flowers': 300 | train = ImageFolder(os.path.join(datadir, 'trn'), transform=transform) 301 | val = ImageFolder(os.path.join(datadir, 'val'), transform=transform) 302 | trainval = ConcatDataset([train, val]) 303 | test = ImageFolder(os.path.join(datadir, 'tst'), transform=transform) 304 | num_classes = 102 305 | 306 | elif dataset in ['flowers-5shot', 'flowers-10shot']: 307 | if dataset == 'flowers-5shot': 308 | n = 5 309 | else: 310 | n = 10 311 | train = ImageFolder(os.path.join(datadir, 'trn'), 
transform=transform) 312 | val = ImageFolder(os.path.join(datadir, 'val'), transform=transform) 313 | trainval = ImageFolder(os.path.join(datadir, 'trn'), transform=transform) 314 | trainval.samples += val.samples 315 | trainval.targets += val.targets 316 | indices = defaultdict(list) 317 | for i, y in enumerate(trainval.targets): 318 | indices[y].append(i) 319 | indices = sum([random.sample(indices[y], n) for y in indices.keys()], []) 320 | trainval = Subset(trainval, indices) 321 | test = ImageFolder(os.path.join(datadir, 'tst'), transform=transform) 322 | num_classes = 102 323 | 324 | elif dataset == 'stl10': 325 | trainval = STL10(root=datadir, split='train', transform=transform) 326 | test = STL10(root=datadir, split='test', transform=transform) 327 | train, val = random_split(trainval, [4500, 500], generator=generator(50)) 328 | num_classes = 10 329 | 330 | elif dataset == 'mit67': 331 | trainval = ImageFolder(os.path.join(datadir, 'train'), transform=transform) 332 | test = ImageFolder(os.path.join(datadir, 'test'), transform=transform) 333 | train, val = random_split(trainval, [4690, 670], generator=generator(51)) 334 | num_classes = 67 335 | 336 | elif dataset == 'cub200': 337 | trn_indices, val_indices = torch.load('splits/cub200.pth') 338 | trainval = ImageFolder(os.path.join(datadir, 'train'), transform=transform) 339 | train = Subset(trainval, trn_indices) 340 | val = Subset(trainval, val_indices) 341 | test = ImageFolder(os.path.join(datadir, 'test'), transform=transform) 342 | num_classes = 200 343 | 344 | elif dataset == 'dog': 345 | trn_indices, val_indices = torch.load('splits/dog.pth') 346 | trainval = ImageFolder(os.path.join(datadir, 'train'), transform=transform) 347 | train = Subset(trainval, trn_indices) 348 | val = Subset(trainval, val_indices) 349 | test = ImageFolder(os.path.join(datadir, 'test'), transform=transform) 350 | num_classes = 120 351 | 352 | return dict(trainval=trainval, 353 | train=train, 354 | val=val, 355 | test=test, 356 | num_classes=num_classes) 357 | 358 | 359 | def load_fewshot_datasets(dataset='cifar10', 360 | datadir='/data', 361 | pretrain_data='stl10'): 362 | 363 | if pretrain_data == 'imagenet100': 364 | mean = torch.tensor([0.485, 0.456, 0.406]) 365 | std = torch.tensor([0.229, 0.224, 0.225]) 366 | transform = T.Compose([T.Resize(224, interpolation=Image.BICUBIC), 367 | T.CenterCrop(224), 368 | T.ToTensor(), 369 | T.Normalize(mean, std)]) 370 | 371 | elif pretrain_data == 'stl10': 372 | mean = torch.tensor([0.43, 0.42, 0.39]) 373 | std = torch.tensor([0.27, 0.26, 0.27]) 374 | transform = T.Compose([T.Resize(96, interpolation=Image.BICUBIC), 375 | T.CenterCrop(96), 376 | T.ToTensor(), 377 | T.Normalize(mean, std)]) 378 | 379 | if dataset == 'cub200': 380 | train = ImageFolder(os.path.join(datadir, 'train'), transform=transform) 381 | test = ImageFolder(os.path.join(datadir, 'test'), transform=transform) 382 | test.samples = train.samples + test.samples 383 | test.targets = train.targets + test.targets 384 | 385 | elif dataset == 'fc100': 386 | train = ImageFolder(os.path.join(datadir, 'train'), transform=transform) 387 | test = ImageFolder(os.path.join(datadir, 'test'), transform=transform) 388 | 389 | elif dataset == 'plant_disease': 390 | train = ImageFolder(os.path.join(datadir, 'train'), transform=transform) 391 | test = ImageFolder(os.path.join(datadir, 'test'), transform=transform) 392 | test.samples = train.samples + test.samples 393 | test.targets = train.targets + test.targets 394 | 395 | return dict(test=test) 396 | 397 | 
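For reference, a minimal usage sketch of the loaders above (the dataset names and `/data/...` paths are illustrative placeholders):

```python
# Usage sketch for datasets.py; paths and dataset choices are illustrative.
from datasets import load_pretrain_datasets, load_datasets

# Pretraining: returns train/val/test sets plus the two kornia pipelines
# (t1, t2) that pretrain.py applies batch-wise on the GPU.
d = load_pretrain_datasets(dataset='stl10', datadir='/data/stl10')
trainset, t1, t2 = d['train'], d['t1'], d['t2']

# Transfer: returns trainval/train/val/test splits and the class count;
# the normalization statistics follow the pretraining data.
d = load_datasets(dataset='cifar10', datadir='/data/cifar10', pretrain_data='stl10')
print(d['num_classes'])  # 10
```

Because `t1`/`t2` are `nn.Sequential` modules of kornia ops, they run on batched GPU tensors, which is how `prepare_training_batch` in `trainers.py` applies them.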
-------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | 8 | from torchvision import models 9 | 10 | def reset_parameters(model): 11 | for m in model.modules(): 12 | if isinstance(m, nn.Conv2d): 13 | m.reset_parameters() 14 | 15 | if isinstance(m, nn.Linear): 16 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) 17 | bound = 1 / math.sqrt(fan_in) 18 | nn.init.uniform_(m.weight, -bound, bound) 19 | if m.bias is not None: 20 | nn.init.uniform_(m.bias, -bound, bound) 21 | 22 | def load_backbone(args): 23 | name = args.model 24 | backbone = models.__dict__[name.split('_')[-1]](zero_init_residual=True) 25 | if name.startswith('cifar_'): 26 | backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 27 | backbone.maxpool = nn.Identity() 28 | args.num_backbone_features = backbone.fc.weight.shape[1] 29 | backbone.fc = nn.Identity() 30 | reset_parameters(backbone) 31 | return backbone 32 | 33 | 34 | def load_mlp(n_in, n_hidden, n_out, num_layers=3, last_bn=True): 35 | layers = [] 36 | for i in range(num_layers-1): 37 | layers.append(nn.Linear(n_in, n_hidden, bias=False)) 38 | layers.append(nn.BatchNorm1d(n_hidden)) 39 | layers.append(nn.ReLU()) 40 | n_in = n_hidden 41 | layers.append(nn.Linear(n_hidden, n_out, bias=not last_bn)) 42 | if last_bn: 43 | layers.append(nn.BatchNorm1d(n_out)) 44 | mlp = nn.Sequential(*layers) 45 | reset_parameters(mlp) 46 | return mlp 47 | 48 | 49 | def load_ss_predictor(n_in, ss_objective, n_hidden=512): 50 | ss_predictor = {} 51 | for name, weight, n_out, _ in ss_objective.params: 52 | if weight > 0: 53 | ss_predictor[name] = load_mlp(n_in*2, n_hidden, n_out, num_layers=3, last_bn=False) 54 | 55 | return ss_predictor 56 | 57 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | import torch.backends.cudnn as cudnn 10 | 11 | import ignite 12 | from ignite.engine import Events 13 | import ignite.distributed as idist 14 | 15 | from datasets import load_pretrain_datasets 16 | from models import load_backbone, load_mlp, load_ss_predictor 17 | import trainers 18 | from trainers import SSObjective 19 | from utils import Logger 20 | 21 | def simsiam(args, t1, t2): 22 | out_dim = 2048 23 | device = idist.device() 24 | 25 | ss_objective = SSObjective( 26 | crop = args.ss_crop, 27 | color = args.ss_color, 28 | flip = args.ss_flip, 29 | blur = args.ss_blur, 30 | rot = args.ss_rot, 31 | sol = args.ss_sol, 32 | only = args.ss_only, 33 | ) 34 | 35 | build_model = partial(idist.auto_model, sync_bn=True) 36 | backbone = build_model(load_backbone(args)) 37 | projector = build_model(load_mlp(args.num_backbone_features, 38 | out_dim, 39 | out_dim, 40 | num_layers=2+int(args.dataset.startswith('imagenet')), 41 | last_bn=True)) 42 | predictor = build_model(load_mlp(out_dim, 43 | out_dim // 4, 44 | out_dim, 45 | num_layers=2, 46 | last_bn=False)) 47 | ss_predictor = load_ss_predictor(args.num_backbone_features, ss_objective) 48 | ss_predictor = { k: build_model(v) for k, v in 
ss_predictor.items() } 49 | ss_params = sum([list(v.parameters()) for v in ss_predictor.values()], []) 50 | 51 | SGD = partial(optim.SGD, lr=args.lr, weight_decay=args.wd, momentum=args.momentum) 52 | build_optim = lambda x: idist.auto_optim(SGD(x)) 53 | optimizers = [build_optim(list(backbone.parameters())+list(projector.parameters())+ss_params), 54 | build_optim(list(predictor.parameters()))] 55 | schedulers = [optim.lr_scheduler.CosineAnnealingLR(optimizers[0], args.max_epochs)] 56 | 57 | trainer = trainers.simsiam(backbone=backbone, 58 | projector=projector, 59 | predictor=predictor, 60 | ss_predictor=ss_predictor, 61 | t1=t1, t2=t2, 62 | optimizers=optimizers, 63 | device=device, 64 | ss_objective=ss_objective) 65 | 66 | return dict(backbone=backbone, 67 | projector=projector, 68 | predictor=predictor, 69 | ss_predictor=ss_predictor, 70 | optimizers=optimizers, 71 | schedulers=schedulers, 72 | trainer=trainer) 73 | 74 | 75 | def moco(args, t1, t2): 76 | out_dim = 128 77 | device = idist.device() 78 | 79 | ss_objective = SSObjective( 80 | crop = args.ss_crop, 81 | color = args.ss_color, 82 | flip = args.ss_flip, 83 | blur = args.ss_blur, 84 | only = args.ss_only, 85 | ) 86 | 87 | build_model = partial(idist.auto_model, sync_bn=True) 88 | backbone = build_model(load_backbone(args)) 89 | projector = build_model(load_mlp(args.num_backbone_features, 90 | args.num_backbone_features, 91 | out_dim, 92 | num_layers=2, 93 | last_bn=False)) 94 | ss_predictor = load_ss_predictor(args.num_backbone_features, ss_objective) 95 | ss_predictor = { k: build_model(v) for k, v in ss_predictor.items() } 96 | ss_params = sum([list(v.parameters()) for v in ss_predictor.values()], []) 97 | 98 | SGD = partial(optim.SGD, lr=args.lr, weight_decay=args.wd, momentum=args.momentum) 99 | build_optim = lambda x: idist.auto_optim(SGD(x)) 100 | optimizers = [build_optim(list(backbone.parameters())+list(projector.parameters())+ss_params)] 101 | schedulers = [optim.lr_scheduler.CosineAnnealingLR(optimizers[0], args.max_epochs)] 102 | 103 | trainer = trainers.moco( 104 | backbone=backbone, 105 | projector=projector, 106 | ss_predictor=ss_predictor, 107 | t1=t1, t2=t2, 108 | optimizers=optimizers, 109 | device=device, 110 | ss_objective=ss_objective) 111 | 112 | return dict(backbone=backbone, 113 | projector=projector, 114 | ss_predictor=ss_predictor, 115 | optimizers=optimizers, 116 | schedulers=schedulers, 117 | trainer=trainer) 118 | 119 | def simclr(args, t1, t2): 120 | out_dim = 128 121 | device = idist.device() 122 | 123 | ss_objective = SSObjective( 124 | crop = args.ss_crop, 125 | color = args.ss_color, 126 | flip = args.ss_flip, 127 | blur = args.ss_blur, 128 | only = args.ss_only, 129 | ) 130 | 131 | build_model = partial(idist.auto_model, sync_bn=True) 132 | backbone = build_model(load_backbone(args)) 133 | projector = build_model(load_mlp(args.num_backbone_features, 134 | args.num_backbone_features, 135 | out_dim, 136 | num_layers=2, 137 | last_bn=False)) 138 | ss_predictor = load_ss_predictor(args.num_backbone_features, ss_objective) 139 | ss_predictor = { k: build_model(v) for k, v in ss_predictor.items() } 140 | ss_params = sum([list(v.parameters()) for v in ss_predictor.values()], []) 141 | 142 | SGD = partial(optim.SGD, lr=args.lr, weight_decay=args.wd, momentum=args.momentum) 143 | build_optim = lambda x: idist.auto_optim(SGD(x)) 144 | optimizers = [build_optim(list(backbone.parameters())+list(projector.parameters())+ss_params)] 145 | schedulers = [optim.lr_scheduler.CosineAnnealingLR(optimizers[0], 
args.max_epochs)] 146 | 147 | trainer = trainers.simclr(backbone=backbone, 148 | projector=projector, 149 | ss_predictor=ss_predictor, 150 | t1=t1, t2=t2, 151 | optimizers=optimizers, 152 | device=device, 153 | ss_objective=ss_objective) 154 | 155 | return dict(backbone=backbone, 156 | projector=projector, 157 | ss_predictor=ss_predictor, 158 | optimizers=optimizers, 159 | schedulers=schedulers, 160 | trainer=trainer) 161 | 162 | 163 | def byol(args, t1, t2): 164 | out_dim = 256 165 | h_dim = 4096 166 | device = idist.device() 167 | 168 | ss_objective = SSObjective( 169 | crop = args.ss_crop, 170 | color = args.ss_color, 171 | flip = args.ss_flip, 172 | blur = args.ss_blur, 173 | rot = args.ss_rot, 174 | only = args.ss_only, 175 | ) 176 | 177 | build_model = partial(idist.auto_model, sync_bn=True) 178 | backbone = build_model(load_backbone(args)) 179 | projector = build_model(load_mlp(args.num_backbone_features, 180 | h_dim, 181 | out_dim, 182 | num_layers=2, 183 | last_bn=False)) 184 | predictor = build_model(load_mlp(out_dim, 185 | h_dim, 186 | out_dim, 187 | num_layers=2, 188 | last_bn=False)) 189 | ss_predictor = load_ss_predictor(args.num_backbone_features, ss_objective) 190 | ss_predictor = { k: build_model(v) for k, v in ss_predictor.items() } 191 | ss_params = sum([list(v.parameters()) for v in ss_predictor.values()], []) 192 | 193 | SGD = partial(optim.SGD, lr=args.lr, weight_decay=args.wd, momentum=args.momentum) 194 | build_optim = lambda x: idist.auto_optim(SGD(x)) 195 | optimizers = [build_optim(list(backbone.parameters())+list(projector.parameters())+ss_params+list(predictor.parameters()))] 196 | schedulers = [optim.lr_scheduler.CosineAnnealingLR(optimizers[0], args.max_epochs)] 197 | 198 | trainer = trainers.byol(backbone=backbone, 199 | projector=projector, 200 | predictor=predictor, 201 | ss_predictor=ss_predictor, 202 | t1=t1, t2=t2, 203 | optimizers=optimizers, 204 | device=device, 205 | ss_objective=ss_objective) 206 | 207 | return dict(backbone=backbone, 208 | projector=projector, 209 | predictor=predictor, 210 | ss_predictor=ss_predictor, 211 | optimizers=optimizers, 212 | schedulers=schedulers, 213 | trainer=trainer) 214 | 215 | 216 | def swav(args, t1, t2): 217 | out_dim = 128 218 | h_dim = 2048 219 | device = idist.device() 220 | 221 | ss_objective = SSObjective( 222 | crop = args.ss_crop, 223 | color = args.ss_color, 224 | flip = args.ss_flip, 225 | blur = args.ss_blur, 226 | rot = args.ss_rot, 227 | only = args.ss_only, 228 | ) 229 | 230 | build_model = partial(idist.auto_model, sync_bn=True) 231 | backbone = build_model(load_backbone(args)) 232 | projector = build_model(load_mlp(args.num_backbone_features, 233 | h_dim, 234 | out_dim, 235 | num_layers=2, 236 | last_bn=False)) 237 | prototypes = build_model(nn.Linear(out_dim, 100, bias=False)) 238 | ss_predictor = load_ss_predictor(args.num_backbone_features, ss_objective) 239 | ss_predictor = { k: build_model(v) for k, v in ss_predictor.items() } 240 | ss_params = sum([list(v.parameters()) for v in ss_predictor.values()], []) 241 | 242 | SGD = partial(optim.SGD, lr=args.lr, weight_decay=args.wd, momentum=args.momentum) 243 | build_optim = lambda x: idist.auto_optim(SGD(x)) 244 | optimizers = [build_optim(list(backbone.parameters())+list(projector.parameters())+ss_params+list(prototypes.parameters()))] 245 | schedulers = [optim.lr_scheduler.CosineAnnealingLR(optimizers[0], args.max_epochs)] 246 | 247 | trainer = trainers.swav(backbone=backbone, 248 | projector=projector, 249 | prototypes=prototypes, 250 | 
ss_predictor=ss_predictor, 251 | t1=t1, t2=t2, 252 | optimizers=optimizers, 253 | device=device, 254 | ss_objective=ss_objective) 255 | 256 | return dict(backbone=backbone, 257 | projector=projector, 258 | prototypes=prototypes, 259 | ss_predictor=ss_predictor, 260 | optimizers=optimizers, 261 | schedulers=schedulers, 262 | trainer=trainer) 263 | 264 | 265 | def main(local_rank, args): 266 | cudnn.benchmark = True 267 | device = idist.device() 268 | logger = Logger(args.logdir, args.resume) 269 | 270 | # DATASETS 271 | datasets = load_pretrain_datasets(dataset=args.dataset, 272 | datadir=args.datadir, 273 | color_aug=args.color_aug) 274 | build_dataloader = partial(idist.auto_dataloader, 275 | batch_size=args.batch_size, 276 | num_workers=args.num_workers, 277 | shuffle=True, 278 | pin_memory=True) 279 | trainloader = build_dataloader(datasets['train'], drop_last=True) 280 | valloader = build_dataloader(datasets['val'] , drop_last=False) 281 | testloader = build_dataloader(datasets['test'], drop_last=False) 282 | 283 | t1, t2 = datasets['t1'], datasets['t2'] 284 | 285 | # MODELS 286 | if args.framework == 'simsiam': 287 | models = simsiam(args, t1, t2) 288 | elif args.framework == 'moco': 289 | models = moco(args, t1, t2) 290 | elif args.framework == 'simclr': 291 | models = simclr(args, t1, t2) 292 | elif args.framework == 'byol': 293 | models = byol(args, t1, t2) 294 | elif args.framework == 'swav': 295 | models = swav(args, t1, t2) 296 | 297 | trainer = models['trainer'] 298 | evaluator = trainers.nn_evaluator(backbone=models['backbone'], 299 | trainloader=valloader, 300 | testloader=testloader, 301 | device=device) 302 | 303 | if args.distributed: 304 | @trainer.on(Events.EPOCH_STARTED) 305 | def set_epoch(engine): 306 | for loader in [trainloader, valloader, testloader]: 307 | loader.sampler.set_epoch(engine.state.epoch) 308 | 309 | @trainer.on(Events.ITERATION_STARTED) 310 | def log_lr(engine): 311 | lrs = {} 312 | for i, optimizer in enumerate(models['optimizers']): 313 | for j, pg in enumerate(optimizer.param_groups): 314 | lrs[f'lr/{i}-{j}'] = pg['lr'] 315 | logger.log(engine, engine.state.iteration, print_msg=False, **lrs) 316 | 317 | @trainer.on(Events.ITERATION_COMPLETED) 318 | def log(engine): 319 | loss = engine.state.output.pop('loss') 320 | ss_loss = engine.state.output.pop('ss/total') 321 | logger.log(engine, engine.state.iteration, 322 | print_msg=engine.state.iteration % args.print_freq == 0, 323 | loss=loss, ss_loss=ss_loss) 324 | 325 | if 'z1' in engine.state.output: 326 | with torch.no_grad(): 327 | z1 = engine.state.output.pop('z1') 328 | z2 = engine.state.output.pop('z2') 329 | z1 = F.normalize(z1, dim=-1) 330 | z2 = F.normalize(z2, dim=-1) 331 | dist = torch.einsum('ik, jk -> ij', z1, z2) 332 | diag_masks = torch.diag(torch.ones(z1.shape[0])).bool() 333 | engine.state.output['dist/intra'] = dist[diag_masks].mean().item() 334 | engine.state.output['dist/inter'] = dist[~diag_masks].mean().item() 335 | 336 | logger.log(engine, engine.state.iteration, 337 | print_msg=False, 338 | **engine.state.output) 339 | 340 | @trainer.on(Events.EPOCH_COMPLETED(every=args.eval_freq)) 341 | def evaluate(engine): 342 | acc = evaluator() 343 | logger.log(engine, engine.state.epoch, acc=acc) 344 | 345 | @trainer.on(Events.EPOCH_COMPLETED) 346 | def update_lr(engine): 347 | for scheduler in models['schedulers']: 348 | scheduler.step() 349 | 350 | @trainer.on(Events.EPOCH_COMPLETED(every=args.ckpt_freq)) 351 | def save_ckpt(engine): 352 | logger.save(engine, **models) 353 | 354 | if 
args.resume is not None: 355 | @trainer.on(Events.STARTED) 356 | def load_state(engine): 357 | ckpt = torch.load(os.path.join(args.logdir, f'ckpt-{args.resume}.pth'), map_location='cpu') 358 | for k, v in models.items(): 359 | if isinstance(v, nn.parallel.DistributedDataParallel): 360 | v = v.module 361 | 362 | if hasattr(v, 'state_dict'): 363 | v.load_state_dict(ckpt[k]) 364 | 365 | if type(v) is list and hasattr(v[0], 'state_dict'): 366 | for i, x in enumerate(v): 367 | x.load_state_dict(ckpt[k][i]) 368 | 369 | if type(v) is dict and k == 'ss_predictor': 370 | for y, x in v.items(): 371 | x.load_state_dict(ckpt[k][y]) 372 | 373 | trainer.run(trainloader, max_epochs=args.max_epochs) 374 | 375 | if __name__ == '__main__': 376 | parser = ArgumentParser() 377 | parser.add_argument('--logdir', type=str, required=True) 378 | parser.add_argument('--resume', type=int, default=None) 379 | parser.add_argument('--dataset', type=str, default='stl10') 380 | parser.add_argument('--datadir', type=str, default='/data') 381 | parser.add_argument('--batch-size', type=int, default=256) 382 | parser.add_argument('--max-epochs', type=int, default=200) 383 | parser.add_argument('--num-workers', type=int, default=4) 384 | parser.add_argument('--model', type=str, default='resnet18') 385 | parser.add_argument('--distributed', action='store_true') 386 | 387 | parser.add_argument('--framework', type=str, default='simsiam') 388 | 389 | parser.add_argument('--base-lr', type=float, default=0.03) 390 | parser.add_argument('--wd', type=float, default=5e-4) 391 | parser.add_argument('--momentum', type=float, default=0.9) 392 | 393 | parser.add_argument('--print-freq', type=int, default=10) 394 | parser.add_argument('--ckpt-freq', type=int, default=10) 395 | parser.add_argument('--eval-freq', type=int, default=1) 396 | 397 | parser.add_argument('--color-aug', type=str, default='default') 398 | 399 | parser.add_argument('--ss-crop', type=float, default=-1) 400 | parser.add_argument('--ss-color', type=float, default=-1) 401 | parser.add_argument('--ss-flip', type=float, default=-1) 402 | parser.add_argument('--ss-blur', type=float, default=-1) 403 | parser.add_argument('--ss-rot', type=float, default=-1) 404 | parser.add_argument('--ss-sol', type=float, default=-1) 405 | parser.add_argument('--ss-only', action='store_true') 406 | 407 | args = parser.parse_args() 408 | args.lr = args.base_lr * args.batch_size / 256 409 | if not args.distributed: 410 | with idist.Parallel() as parallel: 411 | parallel.run(main, args) 412 | else: 413 | with idist.Parallel('nccl', nproc_per_node=torch.cuda.device_count()) as parallel: 414 | parallel.run(main, args) 415 | 416 | -------------------------------------------------------------------------------- /splits/caltech101.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hankook/AugSelf/c131db66b5ade96af86774bc43a2cb797390bba5/splits/caltech101.pth -------------------------------------------------------------------------------- /splits/cub200.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hankook/AugSelf/c131db66b5ade96af86774bc43a2cb797390bba5/splits/cub200.pth -------------------------------------------------------------------------------- /splits/dog.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hankook/AugSelf/c131db66b5ade96af86774bc43a2cb797390bba5/splits/dog.pth 
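These `.pth` split files are serialized index tensors that `datasets.py` reads with `torch.load`; a minimal sketch of the pattern, mirroring the `cub200` branch of `load_datasets` above (`DATADIR` is a placeholder):

```python
# Sketch: consuming a split file as datasets.py does for cub200 and dog.
import torch
from torch.utils.data import Subset
from torchvision.datasets import ImageFolder

trainval = ImageFolder('DATADIR/train')  # DATADIR is a placeholder
trn_indices, val_indices = torch.load('splits/cub200.pth')
train = Subset(trainval, trn_indices)
val = Subset(trainval, val_indices)
```

Note that `splits/caltech101.pth` stores three index sets (train/val/test), matching the `caltech101` branch of `load_datasets`.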
-------------------------------------------------------------------------------- /splits/imagenet100.txt: -------------------------------------------------------------------------------- 1 | n02869837 2 | n01749939 3 | n02488291 4 | n02107142 5 | n13037406 6 | n02091831 7 | n04517823 8 | n04589890 9 | n03062245 10 | n01773797 11 | n01735189 12 | n07831146 13 | n07753275 14 | n03085013 15 | n04485082 16 | n02105505 17 | n01983481 18 | n02788148 19 | n03530642 20 | n04435653 21 | n02086910 22 | n02859443 23 | n13040303 24 | n03594734 25 | n02085620 26 | n02099849 27 | n01558993 28 | n04493381 29 | n02109047 30 | n04111531 31 | n02877765 32 | n04429376 33 | n02009229 34 | n01978455 35 | n02106550 36 | n01820546 37 | n01692333 38 | n07714571 39 | n02974003 40 | n02114855 41 | n03785016 42 | n03764736 43 | n03775546 44 | n02087046 45 | n07836838 46 | n04099969 47 | n04592741 48 | n03891251 49 | n02701002 50 | n03379051 51 | n02259212 52 | n07715103 53 | n03947888 54 | n04026417 55 | n02326432 56 | n03637318 57 | n01980166 58 | n02113799 59 | n02086240 60 | n03903868 61 | n02483362 62 | n04127249 63 | n02089973 64 | n03017168 65 | n02093428 66 | n02804414 67 | n02396427 68 | n04418357 69 | n02172182 70 | n01729322 71 | n02113978 72 | n03787032 73 | n02089867 74 | n02119022 75 | n03777754 76 | n04238763 77 | n02231487 78 | n03032252 79 | n02138441 80 | n02104029 81 | n03837869 82 | n03494278 83 | n04136333 84 | n03794056 85 | n03492542 86 | n02018207 87 | n04067472 88 | n03930630 89 | n03584829 90 | n02123045 91 | n04229816 92 | n02100583 93 | n03642806 94 | n04336792 95 | n03259280 96 | n02116738 97 | n02108089 98 | n03424325 99 | n01855672 100 | n02090622 101 | -------------------------------------------------------------------------------- /splits/sun397.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hankook/AugSelf/c131db66b5ade96af86774bc43a2cb797390bba5/splits/sun397.pth -------------------------------------------------------------------------------- /trainers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from copy import deepcopy 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ignite.engine import Engine 9 | import ignite.distributed as idist 10 | 11 | from transforms import extract_diff 12 | 13 | 14 | class SSObjective: 15 | def __init__(self, crop=-1, color=-1, flip=-1, blur=-1, rot=-1, sol=-1, only=False): 16 | self.only = only 17 | self.params = [ 18 | ('crop', crop, 4, 'regression'), 19 | ('color', color, 4, 'regression'), 20 | ('flip', flip, 1, 'binary_classification'), 21 | ('blur', blur, 1, 'regression'), 22 | ('rot', rot, 4, 'classification'), 23 | ('sol', sol, 1, 'regression'), 24 | ] 25 | 26 | def __call__(self, ss_predictor, z1, z2, d1, d2, symmetric=True): 27 | if symmetric: 28 | z = torch.cat([torch.cat([z1, z2], 1), 29 | torch.cat([z2, z1], 1)], 0) 30 | d = { k: torch.cat([d1[k], d2[k]], 0) for k in d1.keys() } 31 | else: 32 | z = torch.cat([z1, z2], 1) 33 | d = d1 34 | 35 | losses = { 'total': 0 } 36 | for name, weight, n_out, loss_type in self.params: 37 | if weight <= 0: 38 | continue 39 | 40 | p = ss_predictor[name](z) 41 | if loss_type == 'regression': 42 | losses[name] = F.mse_loss(torch.tanh(p), d[name]) 43 | elif loss_type == 'binary_classification': 44 | losses[name] = F.binary_cross_entropy_with_logits(p, d[name]) 45 | elif loss_type == 'classification': 46 | losses[name] = 
F.cross_entropy(p, d[name]) 47 | losses['total'] += losses[name] * weight 48 | 49 | return losses 50 | 51 | 52 | def prepare_training_batch(batch, t1, t2, device): 53 | ((x1, w1), (x2, w2)), _ = batch 54 | with torch.no_grad(): 55 | x1 = t1(x1.to(device)).detach() 56 | x2 = t2(x2.to(device)).detach() 57 | diff1 = { k: v.to(device) for k, v in extract_diff(t1, t2, w1, w2).items() } 58 | diff2 = { k: v.to(device) for k, v in extract_diff(t2, t1, w2, w1).items() } 59 | 60 | return x1, x2, diff1, diff2 61 | 62 | 63 | def simsiam(backbone, 64 | projector, 65 | predictor, 66 | ss_predictor, 67 | t1, 68 | t2, 69 | optimizers, 70 | device, 71 | ss_objective 72 | ): 73 | 74 | def training_step(engine, batch): 75 | backbone.train() 76 | projector.train() 77 | predictor.train() 78 | 79 | for o in optimizers: 80 | o.zero_grad() 81 | 82 | x1, x2, d1, d2 = prepare_training_batch(batch, t1, t2, device) 83 | y1, y2 = backbone(x1), backbone(x2) 84 | 85 | if not ss_objective.only: 86 | z1 = projector(y1) 87 | z2 = projector(y2) 88 | p1 = predictor(z1) 89 | p2 = predictor(z2) 90 | loss1 = F.cosine_similarity(p1, z2.detach(), dim=-1).mean().mul(-1) 91 | loss2 = F.cosine_similarity(p2, z1.detach(), dim=-1).mean().mul(-1) 92 | loss = (loss1+loss2).mul(0.5) 93 | else: 94 | loss = 0. 95 | 96 | outputs = dict(loss=loss) 97 | if not ss_objective.only: 98 | outputs['z1'] = z1 99 | outputs['z2'] = z2 100 | 101 | ss_losses = ss_objective(ss_predictor, y1, y2, d1, d2) 102 | (loss+ss_losses['total']).backward() 103 | for k, v in ss_losses.items(): 104 | outputs[f'ss/{k}'] = v 105 | 106 | for o in optimizers: 107 | o.step() 108 | 109 | return outputs 110 | 111 | return Engine(training_step) 112 | 113 | 114 | def moco(backbone, 115 | projector, 116 | ss_predictor, 117 | t1, 118 | t2, 119 | optimizers, 120 | device, 121 | ss_objective, 122 | momentum=0.999, 123 | K=65536, 124 | T=0.2, 125 | ): 126 | 127 | target_backbone = deepcopy(backbone) 128 | target_projector = deepcopy(projector) 129 | for p in list(target_backbone.parameters())+list(target_projector.parameters()): 130 | p.requires_grad = False 131 | 132 | queue = F.normalize(torch.randn(K, 128).to(device)).detach() 133 | queue.requires_grad = False 134 | queue.ptr = 0 135 | 136 | def training_step(engine, batch): 137 | backbone.train() 138 | projector.train() 139 | target_backbone.train() 140 | target_projector.train() 141 | 142 | for o in optimizers: 143 | o.zero_grad() 144 | 145 | x1, x2, d1, d2 = prepare_training_batch(batch, t1, t2, device) 146 | y1 = backbone(x1) 147 | z1 = F.normalize(projector(y1)) 148 | with torch.no_grad(): 149 | y2 = target_backbone(x2) 150 | z2 = F.normalize(target_projector(y2)) 151 | 152 | l_pos = torch.einsum('nc,nc->n', [z1, z2]).unsqueeze(-1) 153 | l_neg = torch.einsum('nc,kc->nk', [z1, queue.clone().detach()]) 154 | logits = torch.cat([l_pos, l_neg], dim=1).div(T) 155 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device) 156 | loss = F.cross_entropy(logits, labels) 157 | outputs = dict(loss=loss, z1=z1, z2=z2) 158 | 159 | ss_losses = ss_objective(ss_predictor, y1, y2, d1, d2) 160 | (loss+ss_losses['total']).backward() 161 | for k, v in ss_losses.items(): 162 | outputs[f'ss/{k}'] = v 163 | 164 | for o in optimizers: 165 | o.step() 166 | 167 | # momentum network update 168 | for online, target in [(backbone, target_backbone), (projector, target_projector)]: 169 | for p1, p2 in zip(online.parameters(), target.parameters()): 170 | p2.data.mul_(momentum).add_(p1.data, alpha=1-momentum) 171 | 172 | # queue update 173 | keys 
= idist.utils.all_gather(z1) 174 | queue[queue.ptr:queue.ptr+keys.shape[0]] = keys 175 | queue.ptr = (queue.ptr+keys.shape[0]) % K 176 | 177 | return outputs 178 | 179 | engine = Engine(training_step) 180 | return engine 181 | 182 | 183 | def simclr(backbone, 184 | projector, 185 | ss_predictor, 186 | t1, 187 | t2, 188 | optimizers, 189 | device, 190 | ss_objective, 191 | T=0.2, 192 | ): 193 | 194 | def training_step(engine, batch): 195 | backbone.train() 196 | projector.train() 197 | 198 | for o in optimizers: 199 | o.zero_grad() 200 | 201 | x1, x2, d1, d2 = prepare_training_batch(batch, t1, t2, device) 202 | y1 = backbone(x1) 203 | y2 = backbone(x2) 204 | z1 = F.normalize(projector(y1)) 205 | z2 = F.normalize(projector(y2)) 206 | 207 | z = torch.cat([z1, z2], 0) 208 | scores = torch.einsum('ik, jk -> ij', z, z).div(T) 209 | n = z1.shape[0] 210 | labels = torch.tensor(list(range(n, 2*n)) + list(range(0, n)), device=scores.device) 211 | masks = torch.zeros_like(scores, dtype=torch.bool) 212 | for i in range(2*n): 213 | masks[i, i] = True 214 | scores = scores.masked_fill(masks, float('-inf')) 215 | loss = F.cross_entropy(scores, labels) 216 | outputs = dict(loss=loss, z1=z1, z2=z2) 217 | 218 | ss_losses = ss_objective(ss_predictor, y1, y2, d1, d2) 219 | (loss+ss_losses['total']).backward() 220 | for k, v in ss_losses.items(): 221 | outputs[f'ss/{k}'] = v 222 | 223 | for o in optimizers: 224 | o.step() 225 | 226 | return outputs 227 | 228 | engine = Engine(training_step) 229 | return engine 230 | 231 | 232 | def byol(backbone, 233 | projector, 234 | predictor, 235 | ss_predictor, 236 | t1, 237 | t2, 238 | optimizers, 239 | device, 240 | ss_objective, 241 | momentum=0.996, 242 | ): 243 | 244 | target_backbone = deepcopy(backbone) 245 | target_projector = deepcopy(projector) 246 | for p in list(target_backbone.parameters())+list(target_projector.parameters()): 247 | p.requires_grad = False 248 | 249 | def training_step(engine, batch): 250 | backbone.train() 251 | projector.train() 252 | predictor.train() 253 | 254 | for o in optimizers: 255 | o.zero_grad() 256 | 257 | x1, x2, d1, d2 = prepare_training_batch(batch, t1, t2, device) 258 | y1, y2 = backbone(x1), backbone(x2) 259 | z1, z2 = projector(y1), projector(y2) 260 | p1, p2 = predictor(z1), predictor(z2) 261 | with torch.no_grad(): 262 | tgt1 = target_projector(target_backbone(x1)) 263 | tgt2 = target_projector(target_backbone(x2)) 264 | 265 | loss1 = F.cosine_similarity(p1, tgt2.detach(), dim=-1).mean().mul(-1) 266 | loss2 = F.cosine_similarity(p2, tgt1.detach(), dim=-1).mean().mul(-1) 267 | loss = (loss1+loss2).mul(2) 268 | 269 | outputs = dict(loss=loss) 270 | outputs['z1'] = z1 271 | outputs['z2'] = z2 272 | 273 | ss_losses = ss_objective(ss_predictor, y1, y2, d1, d2) 274 | (loss+ss_losses['total']).backward() 275 | for k, v in ss_losses.items(): 276 | outputs[f'ss/{k}'] = v 277 | 278 | for o in optimizers: 279 | o.step() 280 | 281 | # momentum network update 282 | m = 1 - (1-momentum)*(math.cos(math.pi*(engine.state.epoch-1)/engine.state.max_epochs)+1)/2 283 | for online, target in [(backbone, target_backbone), (projector, target_projector)]: 284 | for p1, p2 in zip(online.parameters(), target.parameters()): 285 | p2.data.mul_(m).add_(p1.data, alpha=1-m) 286 | 287 | return outputs 288 | 289 | return Engine(training_step) 290 | 291 | 292 | def distributed_sinkhorn(Q, nmb_iters): 293 | with torch.no_grad(): 294 | Q = shoot_infs(Q) 295 | sum_Q = torch.sum(Q) 296 | # idist.utils.all_reduce(sum_Q) 297 | Q /= sum_Q 298 | r = 
torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0] 299 | # c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (args.world_size * Q.shape[1]) 300 | c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (1 * Q.shape[1]) 301 | for it in range(nmb_iters): 302 | u = torch.sum(Q, dim=1) 303 | # idist.utils.all_reduce(u) 304 | u = r / u 305 | u = shoot_infs(u) 306 | Q *= u.unsqueeze(1) 307 | Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0) 308 | return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float() 309 | 310 | 311 | def shoot_infs(inp_tensor): 312 | """Replaces inf by maximum of tensor""" 313 | mask_inf = torch.isinf(inp_tensor) 314 | ind_inf = torch.nonzero(mask_inf) 315 | if len(ind_inf) > 0: 316 | for ind in ind_inf: 317 | if len(ind) == 2: 318 | inp_tensor[ind[0], ind[1]] = 0 319 | elif len(ind) == 1: 320 | inp_tensor[ind[0]] = 0 321 | m = torch.max(inp_tensor) 322 | for ind in ind_inf: 323 | if len(ind) == 2: 324 | inp_tensor[ind[0], ind[1]] = m 325 | elif len(ind) == 1: 326 | inp_tensor[ind[0]] = m 327 | return inp_tensor 328 | 329 | 330 | def swav(backbone, 331 | projector, 332 | prototypes, 333 | ss_predictor, 334 | t1, 335 | t2, 336 | optimizers, 337 | device, 338 | ss_objective, 339 | epsilon=0.05, 340 | n_iters=3, 341 | temperature=0.1, 342 | freeze_n_iters=410, 343 | ): 344 | 345 | def training_step(engine, batch): 346 | backbone.train() 347 | projector.train() 348 | prototypes.train() 349 | 350 | for o in optimizers: 351 | o.zero_grad() 352 | 353 | with torch.no_grad(): 354 | w = prototypes.weight.data.clone() 355 | w = F.normalize(w, dim=1, p=2) 356 | prototypes.weight.copy_(w) 357 | 358 | x1, x2, d1, d2 = prepare_training_batch(batch, t1, t2, device) 359 | y1, y2 = backbone(x1), backbone(x2) 360 | z1, z2 = projector(y1), projector(y2) 361 | z1 = F.normalize(z1, dim=1, p=2) 362 | z2 = F.normalize(z2, dim=1, p=2) 363 | p1, p2 = prototypes(z1), prototypes(z2) 364 | 365 | q1 = distributed_sinkhorn(torch.exp(p1 / epsilon).t(), n_iters) 366 | q2 = distributed_sinkhorn(torch.exp(p2 / epsilon).t(), n_iters) 367 | 368 | p1 = F.softmax(p1 / temperature, dim=1) 369 | p2 = F.softmax(p2 / temperature, dim=1) 370 | 371 | loss1 = -torch.mean(torch.sum(q1 * torch.log(p2), dim=1)) 372 | loss2 = -torch.mean(torch.sum(q2 * torch.log(p1), dim=1)) 373 | loss = loss1+loss2 374 | 375 | outputs = dict(loss=loss) 376 | outputs['z1'] = z1 377 | outputs['z2'] = z2 378 | 379 | ss_losses = ss_objective(ss_predictor, y1, y2, d1, d2) 380 | (loss+ss_losses['total']).backward() 381 | for k, v in ss_losses.items(): 382 | outputs[f'ss/{k}'] = v 383 | 384 | if engine.state.iteration < freeze_n_iters: 385 | for p in prototypes.parameters(): 386 | p.grad = None 387 | 388 | for o in optimizers: 389 | o.step() 390 | 391 | return outputs 392 | 393 | return Engine(training_step) 394 | 395 | 396 | def collect_features(backbone, 397 | dataloader, 398 | device, 399 | normalize=True, 400 | dst=None, 401 | verbose=False): 402 | 403 | if dst is None: 404 | dst = device 405 | 406 | backbone.eval() 407 | with torch.no_grad(): 408 | features = [] 409 | labels = [] 410 | for i, (x, y) in enumerate(dataloader): 411 | if x.ndim == 5: 412 | _, n, c, h, w = x.shape 413 | x = x.view(-1, c, h, w) 414 | y = y.view(-1, 1).repeat(1, n).view(-1) 415 | z = backbone(x.to(device)) 416 | if normalize: 417 | z = F.normalize(z, dim=-1) 418 | features.append(z.to(dst).detach()) 419 | labels.append(y.to(dst).detach()) 420 | if verbose and (i+1) % 10 == 0: 421 | print(i+1) 422 | features = idist.utils.all_gather(torch.cat(features, 
0).detach()) 423 | labels = idist.utils.all_gather(torch.cat(labels, 0).detach()) 424 | 425 | return features, labels 426 | 427 | 428 | def nn_evaluator(backbone, 429 | trainloader, 430 | testloader, 431 | device): 432 | 433 | def evaluator(): 434 | backbone.eval() 435 | with torch.no_grad(): 436 | features, labels = collect_features(backbone, trainloader, device) 437 | corrects, total = 0, 0 438 | for x, y in testloader: 439 | z = F.normalize(backbone(x.to(device)), dim=-1) 440 | scores = torch.einsum('ik, jk -> ij', z, features) 441 | preds = labels[scores.argmax(1)] 442 | 443 | corrects += (preds.cpu() == y).long().sum().item() 444 | total += y.shape[0] 445 | corrects = idist.utils.all_reduce(corrects) 446 | total = idist.utils.all_reduce(total) 447 | 448 | return corrects / total 449 | 450 | return evaluator 451 | 452 | -------------------------------------------------------------------------------- /transfer_few_shot.py: -------------------------------------------------------------------------------- 1 | import random 2 | from argparse import ArgumentParser 3 | from functools import partial 4 | from copy import deepcopy 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | import torch.backends.cudnn as cudnn 14 | 15 | import ignite.distributed as idist 16 | 17 | from datasets import load_fewshot_datasets 18 | from models import load_backbone, load_mlp 19 | from trainers import collect_features, SSObjective 20 | from utils import Logger 21 | from transforms import extract_diff 22 | 23 | from sklearn.linear_model import LogisticRegression 24 | 25 | 26 | class FewShotBatchSampler(torch.utils.data.Sampler): 27 | def __init__(self, dataset, N, K, Q, num_iterations): 28 | self.N = N 29 | self.K = K 30 | self.Q = Q 31 | self.num_iterations = num_iterations 32 | 33 | labels = [label for _, label in dataset.samples] 34 | self.label2idx = defaultdict(list) 35 | for i, y in enumerate(labels): 36 | self.label2idx[y].append(i) 37 | 38 | few_labels = [y for y, indices in self.label2idx.items() if len(indices) <= self.K] 39 | for y in few_labels: 40 | del self.label2idx[y] 41 | 42 | def __len__(self): 43 | return self.num_iterations 44 | 45 | def __iter__(self): 46 | label_set = set(list(self.label2idx.keys())) 47 | for _ in range(self.num_iterations): 48 | labels = random.sample(label_set, self.N) 49 | indices = [] 50 | for y in labels: 51 | if len(self.label2idx[y]) >= self.K+self.Q: 52 | indices.extend(list(random.sample(self.label2idx[y], self.K+self.Q))) 53 | else: 54 | tmp_indices = [i for i in self.label2idx[y]] 55 | random.shuffle(tmp_indices) 56 | indices.extend(tmp_indices[:self.K] + np.random.choice(tmp_indices[self.K:], size=self.Q).tolist()) 57 | yield indices 58 | 59 | 60 | def main(local_rank, args): 61 | cudnn.benchmark = True 62 | device = idist.device() 63 | logger = Logger(None) 64 | 65 | # DATASETS 66 | datasets = load_fewshot_datasets(dataset=args.dataset, 67 | datadir=args.datadir, 68 | pretrain_data=args.pretrain_data) 69 | build_sampler = partial(FewShotBatchSampler, 70 | N=args.N, K=args.K, Q=args.Q, num_iterations=args.num_tasks) 71 | build_dataloader = partial(torch.utils.data.DataLoader, 72 | num_workers=args.num_workers) 73 | testloader = build_dataloader(datasets['test'], batch_sampler=build_sampler(datasets['test'])) 74 | 75 | # MODELS 76 | ckpt = torch.load(args.ckpt, map_location=device) 77 | backbone = 
load_backbone(args).to(device) 78 | backbone.load_state_dict(ckpt['backbone']) 79 | backbone.eval() 80 | 81 | all_accuracies = [] 82 | for i, (batch, _) in enumerate(testloader): 83 | with torch.no_grad(): 84 | batch = batch.to(device) 85 | B, C, H, W = batch.shape 86 | batch = batch.view(args.N, args.K+args.Q, C, H, W) 87 | 88 | train_batch = batch[:, :args.K].reshape(args.N*args.K, C, H, W) 89 | test_batch = batch[:, args.K:].reshape(args.N*args.Q, C, H, W) 90 | train_labels = torch.arange(args.N).unsqueeze(1).repeat(1, args.K).to(device).view(-1) 91 | test_labels = torch.arange(args.N).unsqueeze(1).repeat(1, args.Q).to(device).view(-1) 92 | 93 | with torch.no_grad(): 94 | X_train = backbone(train_batch) 95 | Y_train = train_labels 96 | 97 | X_test = backbone(test_batch) 98 | Y_test = test_labels 99 | 100 | classifier = LogisticRegression(solver='liblinear').fit(X_train.cpu().numpy(), 101 | Y_train.cpu().numpy()) 102 | preds = classifier.predict(X_test.cpu().numpy()) 103 | acc = np.mean((Y_test.cpu().numpy() == preds).astype(float)) 104 | all_accuracies.append(acc) 105 | if (i+1) % 10 == 0: 106 | logger.log_msg(f'{i+1:3d} | {acc:.4f} (mean: {np.mean(all_accuracies):.4f})') 107 | 108 | avg = np.mean(all_accuracies) 109 | std = np.std(all_accuracies) * 1.96 / np.sqrt(len(all_accuracies)) 110 | logger.log_msg(f'mean: {avg:.4f}±{std:.4f}') 111 | 112 | 113 | if __name__ == '__main__': 114 | parser = ArgumentParser() 115 | parser.add_argument('--ckpt', type=str, required=True) 116 | parser.add_argument('--pretrain-data', type=str, default='stl10') 117 | parser.add_argument('--dataset', type=str, default='cub200') 118 | parser.add_argument('--datadir', type=str, default='/data') 119 | parser.add_argument('--N', type=int, default=5) 120 | parser.add_argument('--K', type=int, default=1) 121 | parser.add_argument('--Q', type=int, default=16) 122 | parser.add_argument('--num-workers', type=int, default=8) 123 | parser.add_argument('--model', type=str, default='resnet18') 124 | parser.add_argument('--num-tasks', type=int, default=2000) 125 | args = parser.parse_args() 126 | args.num_backbone_features = 512 if args.model.endswith('resnet18') else 2048 127 | with idist.Parallel(None) as parallel: 128 | parallel.run(main, args) 129 | 130 | -------------------------------------------------------------------------------- /transfer_linear_eval.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from functools import partial 3 | from copy import deepcopy 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | import torch.backends.cudnn as cudnn 10 | 11 | import ignite.distributed as idist 12 | 13 | from datasets import load_datasets 14 | from models import load_backbone 15 | from trainers import collect_features 16 | from utils import Logger 17 | 18 | 19 | def build_step(X, Y, classifier, optimizer, w): 20 | def step(): 21 | optimizer.zero_grad() 22 | loss = F.cross_entropy(classifier(X), Y, reduction='sum') 23 | for p in classifier.parameters(): 24 | loss = loss + p.pow(2).sum().mul(w) 25 | loss.backward() 26 | return loss 27 | return step 28 | 29 | 30 | def compute_accuracy(X, Y, classifier, metric): 31 | with torch.no_grad(): 32 | preds = classifier(X).argmax(1) 33 | if metric == 'top1': 34 | acc = (preds == Y).float().mean().item() 35 | elif metric == 'class-avg': 36 | total, count = 0., 0. 
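# class-avg is the macro average: accuracy is computed per class and then averaged over classes, so rare classes weigh as much as frequent ones.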
37 | for y in range(0, Y.max().item()+1): 38 | masks = Y == y 39 | if masks.sum() > 0: 40 | total += (preds[masks] == y).float().mean().item() 41 | count += 1 42 | acc = total / count 43 | else: 44 | raise Exception(f'Unknown metric: {metric}') 45 | return acc 46 | 47 | 48 | def main(local_rank, args): 49 | cudnn.benchmark = True 50 | device = idist.device() 51 | logger = Logger(None) 52 | 53 | # DATASETS 54 | datasets = load_datasets(dataset=args.dataset, 55 | datadir=args.datadir, 56 | pretrain_data=args.pretrain_data) 57 | build_dataloader = partial(idist.auto_dataloader, 58 | batch_size=args.batch_size, 59 | num_workers=args.num_workers, 60 | shuffle=True, 61 | pin_memory=True) 62 | trainloader = build_dataloader(datasets['train'], drop_last=False) 63 | valloader = build_dataloader(datasets['val'], drop_last=False) 64 | testloader = build_dataloader(datasets['test'], drop_last=False) 65 | num_classes = datasets['num_classes'] 66 | 67 | # MODELS 68 | ckpt = torch.load(args.ckpt, map_location=device) 69 | backbone = load_backbone(args) 70 | backbone.load_state_dict(ckpt['backbone']) 71 | 72 | build_model = partial(idist.auto_model, sync_bn=True) 73 | backbone = build_model(backbone) 74 | 75 | # EXTRACT FROZEN FEATURES 76 | logger.log_msg('collecting features ...') 77 | X_train, Y_train = collect_features(backbone, trainloader, device, normalize=False) 78 | X_val, Y_val = collect_features(backbone, valloader, device, normalize=False) 79 | X_test, Y_test = collect_features(backbone, testloader, device, normalize=False) 80 | classifier = nn.Linear(args.num_backbone_features, num_classes).to(device) 81 | optim_kwargs = { 82 | 'line_search_fn': 'strong_wolfe', 83 | 'max_iter': 5000, 84 | 'lr': 1., 85 | 'tolerance_grad': 1e-10, 86 | 'tolerance_change': 0, 87 | } 88 | logger.log_msg('collecting features ... done') 89 | 90 | best_acc = 0. 91 | best_w = 0. 
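# Sweep the L2-regularization weight w over a log-spaced grid (1e-6 to
# 1e5); each LBFGS fit below warm-starts from the previous classifier,
# and the w with the best validation accuracy is refit on train+val.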
92 | best_classifier = None 93 | for w in torch.logspace(-6, 5, steps=45).tolist(): 94 | optimizer = optim.LBFGS(classifier.parameters(), **optim_kwargs) 95 | optimizer.step(build_step(X_train, Y_train, classifier, optimizer, w)) 96 | acc = compute_accuracy(X_val, Y_val, classifier, args.metric) 97 | 98 | if best_acc < acc: 99 | best_acc = acc 100 | best_w = w 101 | best_classifier = deepcopy(classifier) 102 | 103 | logger.log_msg(f'w={w:.4e}, acc={acc:.4f}') 104 | 105 | logger.log_msg(f'BEST: w={best_w:.4e}, acc={best_acc:.4f}') 106 | 107 | X = torch.cat([X_train, X_val], 0) 108 | Y = torch.cat([Y_train, Y_val], 0) 109 | optimizer = optim.LBFGS(best_classifier.parameters(), **optim_kwargs) 110 | optimizer.step(build_step(X, Y, best_classifier, optimizer, best_w)) 111 | acc = compute_accuracy(X_test, Y_test, best_classifier, args.metric) 112 | logger.log_msg(f'test acc={acc:.4f}') 113 | 114 | if __name__ == '__main__': 115 | parser = ArgumentParser() 116 | parser.add_argument('--ckpt', type=str, required=True) 117 | parser.add_argument('--pretrain-data', type=str, default='stl10') 118 | parser.add_argument('--dataset', type=str, default='cifar10') 119 | parser.add_argument('--datadir', type=str, default='/data') 120 | parser.add_argument('--batch-size', type=int, default=256) 121 | parser.add_argument('--num-workers', type=int, default=4) 122 | parser.add_argument('--model', type=str, default='resnet18') 123 | parser.add_argument('--print-freq', type=int, default=10) 124 | parser.add_argument('--distributed', action='store_true') 125 | parser.add_argument('--metric', type=str, default='top1') 126 | args = parser.parse_args() 127 | args.backend = 'nccl' if args.distributed else None 128 | args.num_backbone_features = 512 if args.model.endswith('resnet18') else 2048 129 | with idist.Parallel(args.backend) as parallel: 130 | parallel.run(main, args) 131 | 132 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as NF 5 | import torchvision.transforms as T 6 | import torchvision.transforms.functional as F 7 | import kornia 8 | import kornia.augmentation as K 9 | import kornia.augmentation.functional as KF 10 | 11 | 12 | class MultiView: 13 | def __init__(self, transform, num_views=2): 14 | self.transform = transform 15 | self.num_views = num_views 16 | 17 | def __call__(self, x): 18 | return [self.transform(x) for _ in range(self.num_views)] 19 | 20 | 21 | class RandomResizedCrop(T.RandomResizedCrop): 22 | def forward(self, img): 23 | W, H = F._get_image_size(img) 24 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 25 | img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation) 26 | tensor = F.to_tensor(img) 27 | return tensor, torch.tensor([i/H, j/W, h/H, w/W], dtype=torch.float) 28 | 29 | 30 | def apply_adjust_brightness(img1, params): 31 | ratio = params['brightness_factor'][:, None, None, None].to(img1.device) 32 | img2 = torch.zeros_like(img1) 33 | return (ratio * img1 + (1.0-ratio) * img2).clamp(0, 1) 34 | 35 | 36 | def apply_adjust_contrast(img1, params): 37 | ratio = params['contrast_factor'][:, None, None, None].to(img1.device) 38 | img2 = 0.2989 * img1[:, 0:1] + 0.587 * img1[:, 1:2] + 0.114 * img1[:, 2:3] 39 | img2 = torch.mean(img2, dim=(-2, -1), keepdim=True) 40 | return (ratio * img1 + (1.0-ratio) * img2).clamp(0, 1) 41 | 42 | 43 | class 
ColorJitter(K.ColorJitter): 44 | def apply_transform(self, x, params): 45 | transforms = [ 46 | lambda img: apply_adjust_brightness(img, params), 47 | lambda img: apply_adjust_contrast(img, params), 48 | lambda img: KF.apply_adjust_saturation(img, params), 49 | lambda img: KF.apply_adjust_hue(img, params) 50 | ] 51 | 52 | for idx in params['order'].tolist(): 53 | t = transforms[idx] 54 | x = t(x) 55 | 56 | return x 57 | 58 | 59 | class GaussianBlur(K.AugmentationBase2D): 60 | def __init__(self, kernel_size, sigma, border_type='reflect', 61 | return_transform=False, same_on_batch=False, p=0.5): 62 | super().__init__( 63 | p=p, return_transform=return_transform, same_on_batch=same_on_batch, p_batch=1.) 64 | assert kernel_size % 2 == 1 65 | self.kernel_size = kernel_size 66 | self.sigma = sigma 67 | self.border_type = border_type 68 | 69 | def __repr__(self): 70 | return self.__class__.__name__ + f"({super().__repr__()})" 71 | 72 | def generate_parameters(self, batch_shape): 73 | return dict(sigma=torch.zeros(batch_shape[0]).uniform_(self.sigma[0], self.sigma[1])) 74 | 75 | def apply_transform(self, input, params): 76 | sigma = params['sigma'].to(input.device) 77 | k_half = self.kernel_size // 2 78 | x = torch.linspace(-k_half, k_half, steps=self.kernel_size, dtype=input.dtype, device=input.device) 79 | pdf = torch.exp(-0.5*(x[None, :] / sigma[:, None]).pow(2)) 80 | kernel1d = pdf / pdf.sum(1, keepdim=True) 81 | kernel2d = torch.bmm(kernel1d[:, :, None], kernel1d[:, None, :]) 82 | input = NF.pad(input, (k_half, k_half, k_half, k_half), mode=self.border_type) 83 | input = NF.conv2d(input.transpose(0, 1), kernel2d[:, None], groups=input.shape[0]).transpose(0, 1) 84 | return input 85 | 86 | 87 | class RandomRotation(K.AugmentationBase2D): 88 | def __init__(self, return_transform=False, same_on_batch=False, p=0.5): 89 | super().__init__( 90 | p=p, return_transform=return_transform, same_on_batch=same_on_batch, p_batch=1.) 
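# This RandomRotation is discrete: generate_parameters below draws an
# integer k in {0, 1, 2, 3} per image and apply_transform rotates by
# k * 90 degrees via torch.rot90, so the rotation bin can later be
# recovered exactly in _extract_w.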
91 | 92 | def __repr__(self): 93 | return self.__class__.__name__ + f"({super().__repr__()})" 94 | 95 | def generate_parameters(self, batch_shape): 96 | degrees = torch.randint(0, 4, (batch_shape[0], )) 97 | return dict(degrees=degrees) 98 | 99 | def apply_transform(self, input, params): 100 | degrees = params['degrees'] 101 | input = torch.stack([torch.rot90(x, k, (1, 2)) for x, k in zip(input, degrees.tolist())], 0) 102 | return input 103 | 104 | 105 | def _extract_w(t): 106 | if isinstance(t, GaussianBlur): 107 | m = t._params['batch_prob'] 108 | w = torch.zeros(m.shape[0], 1) 109 | w[m] = t._params['sigma'].unsqueeze(-1) 110 | return w 111 | 112 | elif isinstance(t, ColorJitter): 113 | to_apply = t._params['batch_prob'] 114 | w = torch.zeros(to_apply.shape[0], 4) 115 | w[to_apply, 0] = (t._params['brightness_factor'] - 1) / (t.brightness[1]-t.brightness[0]) 116 | w[to_apply, 1] = (t._params['contrast_factor'] - 1) / (t.contrast[1]-t.contrast[0]) 117 | w[to_apply, 2] = (t._params['saturation_factor'] - 1) / (t.saturation[1]-t.saturation[0]) 118 | w[to_apply, 3] = t._params['hue_factor'] / (t.hue[1]-t.hue[0]) 119 | return w 120 | 121 | elif isinstance(t, RandomRotation): 122 | to_apply = t._params['batch_prob'] 123 | w = torch.zeros(to_apply.shape[0], dtype=torch.long) 124 | w[to_apply] = t._params['degrees'] 125 | return w 126 | 127 | elif isinstance(t, K.RandomSolarize): 128 | to_apply = t._params['batch_prob'] 129 | w = torch.ones(to_apply.shape[0]) 130 | w[to_apply] = t._params['thresholds_factor'] 131 | return w 132 | 133 | 134 | def extract_diff(transforms1, transforms2, crop1, crop2): 135 | diff = {} 136 | for t1, t2 in zip(transforms1, transforms2): 137 | if isinstance(t1, K.RandomHorizontalFlip): 138 | f1 = t1._params['batch_prob'] 139 | f2 = t2._params['batch_prob'] 140 | break 141 | 142 | center1 = crop1[:, :2]+crop1[:, 2:]/2 143 | center2 = crop2[:, :2]+crop2[:, 2:]/2 144 | center1[f1, 1] = 1-center1[f1, 1] 145 | center2[f2, 1] = 1-center2[f2, 1] 146 | diff['crop'] = torch.cat([center1-center2, crop1[:, 2:]-crop2[:, 2:]], 1) 147 | diff['flip'] = (f1==f2).float().unsqueeze(-1) 148 | for t1, t2 in zip(transforms1, transforms2): 149 | if isinstance(t1, K.RandomHorizontalFlip): 150 | pass 151 | 152 | elif isinstance(t1, K.RandomGrayscale): 153 | pass 154 | 155 | elif isinstance(t1, GaussianBlur): 156 | w1 = _extract_w(t1) 157 | w2 = _extract_w(t2) 158 | diff['blur'] = w1-w2 159 | 160 | elif isinstance(t1, K.Normalize): 161 | pass 162 | 163 | elif isinstance(t1, K.ColorJitter): 164 | w1 = _extract_w(t1) 165 | w2 = _extract_w(t2) 166 | diff['color'] = w1-w2 167 | 168 | elif isinstance(t1, (nn.Identity, nn.Sequential)): 169 | pass 170 | 171 | elif isinstance(t1, RandomRotation): 172 | w1 = _extract_w(t1) 173 | w2 = _extract_w(t2) 174 | diff['rot'] = (w1-w2+4) % 4 175 | 176 | elif isinstance(t1, K.RandomSolarize): 177 | w1 = _extract_w(t1) 178 | w2 = _extract_w(t2) 179 | diff['sol'] = w1-w2 180 | 181 | else: 182 | raise Exception(f'Unknown transform: {str(t1.__class__)}') 183 | 184 | return diff 185 | 186 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | import torch 5 | from torch.optim import Optimizer 6 | from torch.utils.tensorboard import SummaryWriter 7 | import ignite.distributed as idist 8 | 9 | class Logger(object): 10 | 11 | def __init__(self, logdir, resume=None): 12 | self.logdir = logdir 13 | self.rank = 
idist.get_rank() 14 | 15 | handlers = [logging.StreamHandler(os.sys.stdout)] 16 | if logdir is not None and self.rank == 0: 17 | if resume is None: 18 | os.makedirs(logdir) 19 | handlers.append(logging.FileHandler(os.path.join(logdir, 'log.txt'))) 20 | self.writer = SummaryWriter(log_dir=logdir) 21 | else: 22 | self.writer = None 23 | 24 | logging.basicConfig(format=f"[%(asctime)s ({self.rank})] %(message)s", 25 | level=logging.INFO, 26 | handlers=handlers) 27 | logging.info(' '.join(os.sys.argv)) 28 | 29 | def log_msg(self, msg): 30 | if idist.get_rank() > 0: 31 | return 32 | logging.info(msg) 33 | 34 | def log(self, engine, global_step, print_msg=True, **kwargs): 35 | msg = f'[epoch {engine.state.epoch}] [iter {engine.state.iteration}]' 36 | for k, v in kwargs.items(): 37 | if isinstance(v, torch.Tensor): 38 | v = v.item() 39 | 40 | if type(v) is float: 41 | msg += f' [{k} {v:.4f}]' 42 | else: 43 | msg += f' [{k} {v}]' 44 | 45 | if self.writer is not None: 46 | self.writer.add_scalar(k, v, global_step) 47 | 48 | if print_msg: 49 | logging.info(msg) 50 | 51 | def save(self, engine, **kwargs): 52 | if idist.get_rank() > 0: 53 | return 54 | 55 | state = {} 56 | for k, v in kwargs.items(): 57 | if isinstance(v, torch.nn.parallel.DistributedDataParallel): 58 | v = v.module 59 | 60 | if hasattr(v, 'state_dict'): 61 | state[k] = v.state_dict() 62 | 63 | if type(v) is list and hasattr(v[0], 'state_dict'): 64 | state[k] = [x.state_dict() for x in v] 65 | 66 | if type(v) is dict and k == 'ss_predictor': 67 | state[k] = { y: x.state_dict() for y, x in v.items() } 68 | 69 | torch.save(state, os.path.join(self.logdir, f'ckpt-{engine.state.epoch}.pth')) 70 | 71 | --------------------------------------------------------------------------------
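Usage note: the sketch below shows how the pieces in `transforms.py` fit together — each view goes through its own kornia pipeline instance so that each keeps its own sampled `_params`, and `extract_diff` then turns the per-sample parameter differences into the AugSelf prediction targets. The batch size, image size, crop boxes, and pipeline composition here are illustrative stand-ins, not the exact configuration built in `pretrain.py`:

```python
import torch
import kornia.augmentation as K

from transforms import ColorJitter, GaussianBlur, extract_diff

B = 4  # toy batch size
# Stand-ins for two augmented crops of one batch; in the real pipeline these
# come from MultiView(RandomResizedCrop(...)), which also returns the
# normalized crop boxes [i/H, j/W, h/H, w/W] for each image.
view1, view2 = torch.rand(B, 3, 96, 96), torch.rand(B, 3, 96, 96)
crop1, crop2 = torch.rand(B, 4), torch.rand(B, 4)

def make_pipeline():
    # One pipeline instance per view: kornia stores the sampled parameters
    # on the module itself (t._params), so sharing instances would let the
    # second view overwrite the first view's parameters.
    return [K.RandomHorizontalFlip(p=0.5),
            ColorJitter(0.4, 0.4, 0.4, 0.1, p=0.8),
            GaussianBlur(9, (0.1, 2.0), p=0.5)]

t1, t2 = make_pipeline(), make_pipeline()
for t in t1:
    view1 = t(view1)
for t in t2:
    view2 = t(view2)

# Per-sample differences of the augmentation parameters between the views;
# these are the targets that the self-supervised prediction heads regress.
diff = extract_diff(t1, t2, crop1, crop2)
print({k: tuple(v.shape) for k, v in diff.items()})
# e.g. {'crop': (4, 4), 'flip': (4, 1), 'color': (4, 4), 'blur': (4, 1)}
```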