├── LICENSE ├── README.md ├── cifar10_100 ├── main.py ├── models │ ├── WideResnet.py │ ├── resnet.py │ ├── resnet18_classifier.py │ └── wrn16x8_classifier.py ├── submit_job.sh └── utils │ └── OT.py ├── imagenet ├── gen.slurm ├── main.py ├── models │ ├── resnet.py │ └── resnet50_classifier.py └── utils │ ├── OT.py │ └── util.py ├── logfiles ├── cifar10 │ ├── log_r18.txt │ └── log_wrn16x8.txt ├── cifar100 │ ├── log_r18.txt │ └── log_wrn16x8.txt └── tiny_imagenet │ └── log.txt └── tiny_imgnet ├── main.py ├── models ├── resnet.py └── resnet18_classifier.py ├── submit_job.sh └── utils ├── OT.py ├── dataLoader.py ├── wnids.txt └── words.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Shashanka Venkataramanan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AlignMixup (CVPR 2022) 2 | This repo contains the official PyTorch code for our CVPR 2022 paper AlignMixup: Improving Representations By Interpolating Aligned Features (http://arxiv.org/abs/2103.15375) 3 | 4 | ### Requirements 5 | This code has been tested with 6 | python 3.8.11 7 | torch 1.10.1 8 | torchvision 0.11.2 9 | numpy==1.21.0 10 | 11 | ### Additional package versions 12 | cuda 11.3.1 13 | cudnn 8.2.0.53-11.3 14 | tar==1.34 15 | py-virtualenv==16.7.6 16 | 17 | 18 | ### Dataset Preparation 19 | 20 | 1. For CIFAR-10/100, the dataset is downloaded automatically if no CIFAR-10/100 directory exists in the path specified when executing the code. 21 | 2. For Tiny-Imagenet-200, you can download the dataset from [here](http://cs231n.stanford.edu/tiny-imagenet-200.zip). Unzip it and specify its path in the code. 
22 | 23 | Alternatively, you can run the following command in your terminal if you have ```wget``` installed to download it to your current directory: 24 | ``` 25 | wget http://cs231n.stanford.edu/tiny-imagenet-200.zip 26 | ``` 27 | 28 | 29 | ### How to run experiments for CIFAR-10 30 | 31 | #### AlignMixup PreActResnet18 32 | ``` 33 | cd cifar10_100 34 | 35 | python main.py --dataset cifar10 --data_dir path_to_cifar10_directory \ 36 | --save_dir path_to_save_checkpoints --network resnet --epochs 2000 \ 37 | --alpha 2.0 --num_classes 10 --manualSeed 8492 38 | ``` 39 | 40 | #### AlignMixup WRN 16x8 41 | ``` 42 | cd cifar10_100 43 | 44 | python main.py --dataset cifar10 --data_dir path_to_cifar10_directory \ 45 | --save_dir path_to_save_checkpoints --network wideresnet --epochs 2000 \ 46 | --alpha 2.0 --num_classes 10 --manualSeed 8492 47 | ``` 48 | 49 | 50 | 51 | ### How to run experiments for CIFAR-100 52 | 53 | #### AlignMixup PreActResnet18 54 | ``` 55 | cd cifar10_100 56 | 57 | python main.py --dataset cifar100 --data_dir path_to_cifar100_directory \ 58 | --save_dir path_to_save_checkpoints --network resnet --epochs 2000 \ 59 | --alpha 2.0 --num_classes 100 --manualSeed 8492 60 | ``` 61 | 62 | #### AlignMixup WRN 16x8 63 | ``` 64 | cd cifar10_100 65 | 66 | python main.py --dataset cifar100 --data_dir path_to_cifar100_directory \ 67 | --save_dir path_to_save_checkpoints --network wideresnet --epochs 2000 \ 68 | --alpha 2.0 --num_classes 100 --manualSeed 8492 69 | ``` 70 | 71 | 72 | ### How to run experiments for Tiny-Imagenet-200 73 | 74 | 75 | #### AlignMixup PreActResnet18 76 | ``` 77 | cd tiny_imgnet 78 | 79 | python main.py --train_dir path_to_train_directory \ 80 | --val_dir path_to_val_directory \ 81 | --save_dir path_to_save_checkpoints --epochs 1200 \ 82 | --alpha 2.0 --num_classes 200 --manualSeed 8492 83 | ``` 84 | 85 | 86 | ### How to run experiments for Imagenet 87 | 88 | #### To run on a subset of the training set (i.e. approx. 20% of the images per class) 89 | 90 | ``` 91 | cd imagenet 92 | 93 | python main.py --data_dir path_to_imagenet_directory --save_dir path_to_save_checkpoints \ 94 | --mini_imagenet True --subset 260 --num_classes 1000 --epochs 300 --alpha 2.0 --batch_size 1024 95 | ``` 96 | 97 | 98 | #### To run on the full ImageNet (omit `--mini_imagenet`: the flag is parsed with `type=bool`, so passing the string `False` would still evaluate to `True`) 99 | 100 | ``` 101 | cd imagenet 102 | 103 | python main.py --data_dir path_to_imagenet_directory --save_dir path_to_save_checkpoints \ 104 | --num_classes 1000 --epochs 300 --alpha 2.0 --batch_size 1024 105 | ``` 106 | 107 | #### TODO 108 | ImageNet using DistributedDataParallel (multiple nodes) - coming soon 109 | 110 | 111 | ## Results 112 | 113 | | Dataset | Network | AlignMixup | Log | 114 | |:--------------:|:---------:|:----------:|---| 115 | | CIFAR-10 | ResNet-18 | 97.05% | [log](logfiles/cifar10/log_r18.txt) | 116 | | CIFAR-10 | WRN 16x8 | 96.91% | [log](logfiles/cifar10/log_wrn16x8.txt) | 117 | | CIFAR-100 | ResNet-18 | 81.71% | [log](logfiles/cifar100/log_r18.txt) | 118 | | CIFAR-100 | WRN 16x8 | 81.23% | [log](logfiles/cifar100/log_wrn16x8.txt) | 119 | | Tiny-Imagenet | ResNet-18 | 66.87% | [log](logfiles/tiny_imagenet/log.txt) | 120 | | Imagenet | ResNet-50 | 79.32% | [log](logfiles/imnet/log.txt) | 121 | 122 | ## Acknowledgement 123 | The code for the Sinkhorn-Knopp algorithm is adapted and modified from this amazing repository by [Daniel Daza](https://github.com/dfdazac/wassdistance) 124 | 125 | 126 | 127 | 128 | ## Citation 129 | 130 | If you find this work useful and use it in your own research, please cite our paper 
131 | 132 | ``` 133 | @inproceedings{venkataramanan2021alignmix, 134 | title={AlignMixup: Improving Representations By Interpolating Aligned Features}, 135 | author={Venkataramanan, Shashanka and Kijak, Ewa and Amsaleg, Laurent and Avrithis, Yannis}, 136 | booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 137 | year={2022} 138 | } 139 | 140 | ``` 141 | -------------------------------------------------------------------------------- /cifar10_100/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import argparse 5 | import itertools 6 | from models.resnet18_classifier import Resnet_classifier 7 | from models.wrn16x8_classifier import WideResNet_classifier 8 | 9 | import torch 10 | from torch import nn, optim 11 | from torch.utils.data import DataLoader 12 | from torchvision import datasets, transforms 13 | import torch.backends.cudnn as cudnn 14 | 15 | 16 | parser = argparse.ArgumentParser(description='Trains ResNet on CIFAR-10/100', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument('--network', type=str, default='resnet', choices=['resnet', 'wideresnet'], help='Choose between resnet18/WRN16x8') 18 | parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='Choose between Cifar10/100') 19 | parser.add_argument('--data_dir', type = str, default = '/srv/tempdd/svenkata/Datasets/data', 20 | help='path to train/test folder') 21 | parser.add_argument('--save_dir', type = str, default = '/nfs/pyrex/raid6/svenkata/weights/AlignMixup_CVPR22/cifar100/', 22 | help='folder where results are to be stored') 23 | 24 | # Optimization options 25 | parser.add_argument('--epochs', type=int, default=2000, help='Number of epochs to train.') 26 | parser.add_argument('--alpha', type=float, default=2.0, help='alpha parameter for mixup') 27 | parser.add_argument('--num_classes', type=int, default=10, help='number of classes, set 100 for CIFAR-100') 28 | 29 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size.') 30 | parser.add_argument('--lr_', type=float, default=0.1, help='The Learning Rate.') 31 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.') 32 | 33 | parser.add_argument('--decay', type=float, default=1e-4, help='Weight decay (L2 penalty).') 34 | parser.add_argument('--schedule', type=int, nargs='+', default=[500, 1000, 1500], help='Decrease learning rate at these epochs.') 35 | parser.add_argument('--gammas', type=float, nargs='+', default=[0.1, 0.1, 0.1], help='LR is multiplied by gamma on schedule, number of gammas should be equal to schedule') 36 | 37 | # Checkpoints 38 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 39 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)') 40 | 41 | # Acceleration 42 | parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.') 43 | parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 8)') 44 | 45 | # random seed 46 | parser.add_argument('--manualSeed', type=int, help='manual seed') 47 | 48 | args = parser.parse_args() 49 | 50 | out_str = str(args) 51 | print(out_str) 52 | 53 | 54 | device = torch.device("cuda" if args.ngpu>0 and torch.cuda.is_available() else "cpu") 55 | 56 | if args.manualSeed is None: 57 | args.manualSeed = 
random.randint(1, 10000) 58 | 59 | # args.manualSeed = 8492 60 | random.seed(args.manualSeed) 61 | np.random.seed(args.manualSeed) 62 | torch.manual_seed(args.manualSeed) 63 | torch.cuda.manual_seed_all(args.manualSeed) 64 | cudnn.benchmark = True 65 | 66 | 67 | if not os.path.exists(args.save_dir + args.network): 68 | os.makedirs(args.save_dir + args.network) 69 | 70 | 71 | transform_train = transforms.Compose([ 72 | transforms.RandomCrop(32, padding=4), 73 | transforms.RandomHorizontalFlip(), 74 | transforms.ToTensor(), 75 | transforms.Normalize((0.4914, 0.4822, 0.4465), 76 | (0.2023, 0.1994, 0.2010)), 77 | ]) 78 | 79 | 80 | transform_test = transforms.Compose([ 81 | transforms.ToTensor(), 82 | transforms.Normalize((0.4914, 0.4822, 0.4465), 83 | (0.2023, 0.1994, 0.2010)), 84 | ]) 85 | 86 | 87 | if (args.dataset == 'cifar10'): 88 | train_data = datasets.CIFAR10(root=args.data_dir, train=True, download=True, transform=transform_train) 89 | test_data = datasets.CIFAR10(root=args.data_dir, train=False, download=True, transform=transform_test) 90 | 91 | elif(args.dataset == 'cifar100'): 92 | train_data = datasets.CIFAR100(root=args.data_dir, train=True, download=True, transform=transform_train) 93 | test_data = datasets.CIFAR100(root=args.data_dir, train=False, download=True, transform=transform_test) 94 | 95 | 96 | trainloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers) 97 | testloader = DataLoader(test_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers) 98 | 99 | 100 | 101 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 102 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 103 | 104 | 105 | 106 | def adjust_learning_rate(optimizer, epoch, gammas, schedule): 107 | """Sets the learning rate to the initial LR decayed by 10 every 500 epochs""" 108 | lr = args.lr_ 109 | assert len(gammas) == len(schedule), "length of gammas and schedule should be equal" 110 | for (gamma, step) in zip(gammas, schedule): 111 | if (epoch >= step): 112 | lr = lr * gamma 113 | else: 114 | break 115 | for param_group in optimizer.param_groups: 116 | param_group['lr'] = lr 117 | 118 | return lr 119 | 120 | 121 | if args.network == 'resnet': 122 | model = Resnet_classifier(args.num_classes) 123 | elif args.network == 'wideresnet': 124 | model = WideResNet_classifier(args.num_classes) 125 | 126 | model = torch.nn.DataParallel(model) 127 | model.to(device) 128 | print(model) 129 | 130 | criterion = nn.CrossEntropyLoss() 131 | optimizer = optim.SGD(model.parameters(), lr=args.lr_, momentum=args.momentum, weight_decay=args.decay) 132 | best_acc = 0 133 | 134 | 135 | if args.resume: 136 | if os.path.isfile(args.resume): 137 | print("=> loading checkpoint '{}'".format(args.resume)) 138 | checkpoint = torch.load(args.resume) 139 | args.start_epoch = checkpoint['epoch'] 140 | model.load_state_dict(checkpoint['model']) 141 | optimizer.load_state_dict(checkpoint['optimizer']) 142 | best_acc = checkpoint['acc'] 143 | print("=> loaded checkpoint '{}' accuracy={} (epoch {})" .format(args.resume, best_acc, checkpoint['epoch'])) 144 | else: 145 | print("=> no checkpoint found at '{}'".format(args.resume)) 146 | 147 | 148 | 149 | 150 | def train(epoch): 151 | 152 | model.train() 153 | total_loss = 0 154 | correct = 0 155 | 156 | for i, (images, targets) in enumerate(trainloader): 157 | 158 | images = images.to(device) 159 | targets = targets.to(device) 160 | 161 | lam = np.random.beta(args.alpha, args.alpha) 162 | 163 | 
outputs,targets_a,targets_b = model(images, targets, lam, mode='train') 164 | 165 | loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam) 166 | 167 | optimizer.zero_grad() 168 | loss.backward() 169 | optimizer.step() 170 | 171 | total_loss += loss.item() 172 | 173 | _,pred = torch.max(outputs, dim=1) 174 | correct += (pred == targets).sum().item() 175 | 176 | print('epoch: {} --> Train loss = {:.4f} Train Accuracy = {:.4f} '.format(epoch, total_loss / len(trainloader.dataset), 100.*correct / len(trainloader.dataset))) 177 | 178 | 179 | 180 | def test(epoch): 181 | global best_acc 182 | model.eval() 183 | test_loss = 0 184 | correct = 0 185 | total = 0 186 | with torch.no_grad(): 187 | for batch_idx, (inputs, targets) in enumerate(testloader): 188 | 189 | inputs = inputs.to(device) 190 | targets = targets.to(device) 191 | 192 | outputs = model(inputs, None, None, mode='test') 193 | 194 | loss = criterion(outputs, targets) 195 | 196 | test_loss += loss.item() 197 | _, predicted = torch.max(outputs.data, 1) 198 | total += targets.size(0) 199 | correct += predicted.eq(targets.data).cpu().sum() 200 | 201 | print('------> epoch: {} --> Test loss = {:.4f} Test Accuracy = {:.4f} '.format(epoch,test_loss / len(testloader.dataset), 100.*correct / len(testloader.dataset))) 202 | 203 | acc = 100.*correct/total 204 | if acc > best_acc: 205 | checkpoint(acc, epoch) 206 | best_acc = acc 207 | 208 | return best_acc 209 | 210 | 211 | def checkpoint(acc, epoch): 212 | # Save checkpoint. 213 | print('Saving..') 214 | state = { 215 | 'model': model.state_dict(), 216 | 'optimizer' : optimizer.state_dict(), 217 | 'acc': acc, 218 | 'epoch': epoch, 219 | 'seed' : args.manualSeed 220 | } 221 | 222 | torch.save(state, args.save_dir + args.network + '/' + 'checkpoint.t7') 223 | 224 | 225 | 226 | if __name__ == '__main__': 227 | for epoch in range(args.start_epoch, args.epochs): 228 | adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule) 229 | train(epoch) 230 | best_accuracy = test(epoch) 231 | 232 | print('Best Accuracy = ', best_accuracy) 233 | 234 | 235 | -------------------------------------------------------------------------------- /cifar10_100/models/WideResnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import sys, os 8 | import numpy as np 9 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 10 | 11 | act = torch.nn.ReLU() 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | return nn.Conv2d(in_planes, 16 | out_planes, 17 | kernel_size=3, 18 | stride=stride, 19 | padding=1, 20 | bias=True) 21 | 22 | 23 | def conv_init(m): 24 | classname = m.__class__.__name__ 25 | if classname.find('Conv') != -1: 26 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 27 | init.constant(m.bias, 0) 28 | elif classname.find('BatchNorm') != -1: 29 | init.constant(m.weight, 1) 30 | init.constant(m.bias, 0) 31 | 32 | 33 | class wide_basic(nn.Module): 34 | def __init__(self, in_planes, planes, stride=1): 35 | super(wide_basic, self).__init__() 36 | self.bn1 = nn.BatchNorm2d(in_planes) 37 | self.conv1 = nn.Conv2d(in_planes, 38 | planes, 39 | kernel_size=3, 40 | padding=1, 41 | bias=True) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.conv2 = nn.Conv2d(planes, 44 | planes, 45 | kernel_size=3, 46 | stride=stride, 47 | padding=1, 48 | bias=True) 49 | 50 | self.shortcut = nn.Sequential() 
51 | if stride != 1 or in_planes != planes: 52 | self.shortcut = nn.Sequential( 53 | nn.Conv2d(in_planes, 54 | planes, 55 | kernel_size=1, 56 | stride=stride, 57 | bias=True), ) 58 | 59 | def forward(self, x): 60 | out = self.conv1(act(self.bn1(x))) 61 | out = self.conv2(act(self.bn2(out))) 62 | out += self.shortcut(x) 63 | 64 | return out 65 | 66 | 67 | class Wide_ResNet(nn.Module): 68 | def __init__(self, 69 | depth, 70 | widen_factor, 71 | num_classes, 72 | stride=1, 73 | parallel=False): 74 | super(Wide_ResNet, self).__init__() 75 | self.num_classes = num_classes 76 | self.in_planes = 16 77 | 78 | assert ((depth - 4) % 6 == 0), 'Wide-resnet_v2 depth should be 6n+4' 79 | n = int((depth - 4) / 6) 80 | k = widen_factor 81 | 82 | print('| Wide-Resnet %dx%d' % (depth, k)) 83 | nStages = [16, 16 * k, 32 * k, 64 * k] 84 | 85 | self.conv1 = conv3x3(3, nStages[0], stride=stride) 86 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, stride=1) 87 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, stride=2) 88 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, stride=2) 89 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 90 | 91 | 92 | def _wide_layer(self, block, planes, num_blocks, stride): 93 | strides = [stride] + [1] * (num_blocks - 1) 94 | layers = [] 95 | 96 | for stride in strides: 97 | layers.append(block(self.in_planes, planes, stride)) 98 | self.in_planes = planes 99 | 100 | return nn.Sequential(*layers) 101 | 102 | def forward(self, x): 103 | out = x 104 | out = self.conv1(out) 105 | out = self.layer1(out) 106 | out = self.layer2(out) 107 | out = self.layer3(out) 108 | out = act(self.bn1(out)) 109 | 110 | return out 111 | 112 | 113 | def wrn28_10(num_classes=10, dropout=False, stride=1): 114 | model = Wide_ResNet(depth=28, 115 | widen_factor=10, 116 | num_classes=num_classes, 117 | stride=stride) 118 | return model 119 | 120 | 121 | def wrn28_2(num_classes=10, dropout=False, stride=1): 122 | model = Wide_ResNet(depth=28, 123 | widen_factor=2, 124 | num_classes=num_classes, 125 | stride=stride) 126 | return model 127 | 128 | 129 | def wrn16_8(num_classes=10, dropout=False, stride=1): 130 | model = Wide_ResNet(depth=16, 131 | widen_factor=8, 132 | num_classes=num_classes, 133 | stride=stride) 134 | return model -------------------------------------------------------------------------------- /cifar10_100/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | BasicBlock and Bottleneck module is from the original ResNet paper: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | 7 | PreActBlock and PreActBottleneck module is from the later paper: 8 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 9 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 10 | ''' 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from torch.autograd import Variable 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, in_planes, planes, stride=1): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(in_planes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.conv2 = conv3x3(planes, planes) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | 32 | self.shortcut = nn.Sequential() 33 | if stride != 1 or in_planes != self.expansion*planes: 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 36 | nn.BatchNorm2d(self.expansion*planes) 37 | ) 38 | 39 | def forward(self, x): 40 | out = F.relu(self.bn1(self.conv1(x))) 41 | out = self.bn2(self.conv2(out)) 42 | out += self.shortcut(x) 43 | out = F.relu(out) 44 | return out 45 | 46 | 47 | class PreActBlock(nn.Module): 48 | '''Pre-activation version of the BasicBlock.''' 49 | expansion = 1 50 | 51 | def __init__(self, in_planes, planes, stride=1): 52 | super(PreActBlock, self).__init__() 53 | self.bn1 = nn.BatchNorm2d(in_planes) 54 | self.conv1 = conv3x3(in_planes, planes, stride) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv2 = conv3x3(planes, planes) 57 | 58 | self.shortcut = nn.Sequential() 59 | if stride != 1 or in_planes != self.expansion*planes: 60 | self.shortcut = nn.Sequential( 61 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(x)) 66 | shortcut = self.shortcut(out) 67 | out = self.conv1(out) 68 | out = self.conv2(F.relu(self.bn2(out))) 69 | out += shortcut 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | expansion = 4 75 | 76 | def __init__(self, in_planes, planes, stride=1): 77 | super(Bottleneck, self).__init__() 78 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 79 | self.bn1 = nn.BatchNorm2d(planes) 80 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 81 | self.bn2 = nn.BatchNorm2d(planes) 82 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 83 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 84 | 85 | self.shortcut = nn.Sequential() 86 | if stride != 1 or in_planes != self.expansion*planes: 87 | self.shortcut = nn.Sequential( 88 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 89 | nn.BatchNorm2d(self.expansion*planes) 90 | ) 91 | 92 | def forward(self, x): 93 | out = F.relu(self.bn1(self.conv1(x))) 94 | out = F.relu(self.bn2(self.conv2(out))) 95 | out = self.bn3(self.conv3(out)) 96 | out += self.shortcut(x) 97 | out = F.relu(out) 98 | return out 99 | 100 | 101 | class PreActBottleneck(nn.Module): 102 | '''Pre-activation version of the original Bottleneck module.''' 103 | expansion = 4 104 | 105 | def __init__(self, in_planes, planes, stride=1): 106 | super(PreActBottleneck, self).__init__() 107 | self.bn1 = nn.BatchNorm2d(in_planes) 108 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 109 | self.bn2 = nn.BatchNorm2d(planes) 110 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 111 | self.bn3 = nn.BatchNorm2d(planes) 112 | 
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 113 | 114 | self.shortcut = nn.Sequential() 115 | if stride != 1 or in_planes != self.expansion*planes: 116 | self.shortcut = nn.Sequential( 117 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 118 | ) 119 | 120 | def forward(self, x): 121 | out = F.relu(self.bn1(x)) 122 | shortcut = self.shortcut(out) 123 | out = self.conv1(out) 124 | out = self.conv2(F.relu(self.bn2(out))) 125 | out = self.conv3(F.relu(self.bn3(out))) 126 | out += shortcut 127 | return out 128 | 129 | 130 | class ResNet(nn.Module): 131 | def __init__(self, block, num_blocks, num_classes=10): 132 | super(ResNet, self).__init__() 133 | self.in_planes = 64 134 | self.conv1 = conv3x3(3,64) 135 | self.bn1 = nn.BatchNorm2d(64) 136 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 137 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 138 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 139 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 140 | 141 | def _make_layer(self, block, planes, num_blocks, stride): 142 | strides = [stride] + [1]*(num_blocks-1) 143 | layers = [] 144 | for stride in strides: 145 | layers.append(block(self.in_planes, planes, stride)) 146 | self.in_planes = planes * block.expansion 147 | return nn.Sequential(*layers) 148 | 149 | def forward(self, x, lin=0, lout=4): 150 | out = x 151 | if lin < 1 and lout > -1: 152 | out = self.conv1(out) 153 | out = self.bn1(out) 154 | out = F.relu(out) 155 | if lin < 2 and lout > 0: 156 | out = self.layer1(out) 157 | if lin < 3 and lout > 1: 158 | out = self.layer2(out) 159 | if lin < 4 and lout > 2: 160 | out = self.layer3(out) 161 | if lin < 5 and lout > 3: 162 | out = self.layer4(out) 163 | # if lout > 4: 164 | # out = F.avg_pool2d(out, 4) 165 | # out = out.view(out.size(0), -1) 166 | # out = self.linear(out) 167 | return out 168 | 169 | 170 | def ResNet18(): 171 | return ResNet(PreActBlock, [2,2,2,2]) 172 | 173 | def ResNet34(): 174 | return ResNet(BasicBlock, [3,4,6,3]) 175 | 176 | def ResNet50(): 177 | return ResNet(Bottleneck, [3,4,6,3]) 178 | 179 | def ResNet101(): 180 | return ResNet(Bottleneck, [3,4,23,3]) 181 | 182 | def ResNet152(): 183 | return ResNet(Bottleneck, [3,8,36,3]) 184 | 185 | 186 | def test(): 187 | net = ResNet18() 188 | y = net(Variable(torch.randn(1,3,32,32))) 189 | print(y.size()) 190 | 191 | # test() 192 | -------------------------------------------------------------------------------- /cifar10_100/models/resnet18_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.autograd import Variable 6 | from utils.OT import SinkhornDistance 7 | import numpy as np 8 | from models.resnet import ResNet18 9 | import random 10 | 11 | 12 | def mixup_process(out, y, lam): 13 | indices = np.random.permutation(out.size(0)) 14 | out = out*lam + out[indices]*(1-lam) 15 | y_a, y_b = y, y[indices] 16 | return out, y_a, y_b 17 | 18 | 19 | 20 | def mixup_aligned(out, y, lam): 21 | # out shape = batch_size x 512 x 4 x 4 (cifar10/100) 22 | 23 | indices = np.random.permutation(out.size(0)) 24 | feat1 = out.view(out.shape[0], out.shape[1], -1) # batch_size x 512 x 16 25 | feat2 = out[indices].view(out.shape[0], out.shape[1], -1) # batch_size x 512 x 16 26 | 27 | sinkhorn = SinkhornDistance(eps=0.1, 
max_iter=100, reduction='none') 28 | P = sinkhorn(feat1.permute(0,2,1), feat2.permute(0,2,1)).detach() # optimal plan batch x 16 x 16 29 | 30 | P = P*(out.size(2)*out.size(3)) # assignment matrix 31 | 32 | align_mix = random.randint(0,1) # uniformly choose at random, which alignmix to perform 33 | 34 | if (align_mix == 0): 35 | # \tilde{A} = A'R^{T} 36 | f1 = torch.matmul(feat2, P.permute(0,2,1).cuda()).view(out.shape) 37 | final = feat1.view(out.shape)*lam + f1*(1-lam) 38 | 39 | elif (align_mix == 1): 40 | # \tilde{A}' = AR 41 | f2 = torch.matmul(feat1, P.cuda()).view(out.shape).cuda() 42 | final = f2*lam + feat2.view(out.shape)*(1-lam) 43 | 44 | y_a, y_b = y,y[indices] 45 | 46 | return final, y_a, y_b 47 | 48 | 49 | 50 | class Resnet_classifier(nn.Module): 51 | def __init__(self, num_classes, z_dim=512): 52 | super(Resnet_classifier, self).__init__() 53 | 54 | self.encoder = ResNet18() 55 | self.classifier = nn.Linear(z_dim, num_classes) 56 | 57 | 58 | def forward(self, x, targets, lam, mode): 59 | 60 | if (mode == 'train'): 61 | 62 | layer_mix = random.randint(0,1) 63 | 64 | if layer_mix == 0: 65 | x,t_a,t_b = mixup_process(x, targets, lam) 66 | 67 | out = self.encoder(x, lin=0, lout=0) 68 | out = self.encoder.layer1(out) 69 | out = self.encoder.layer2(out) 70 | out = self.encoder.layer3(out) 71 | out = self.encoder.layer4(out) 72 | 73 | if layer_mix == 1: 74 | out,t_a,t_b = mixup_aligned(out, targets, lam) 75 | 76 | out = F.avg_pool2d(out, 4) 77 | out = out.reshape(out.size(0), -1) 78 | cls_output = self.classifier(out) 79 | 80 | return cls_output, t_a, t_b 81 | 82 | 83 | elif (mode == 'test'): 84 | out = self.encoder(x) 85 | out = F.avg_pool2d(out, 4) 86 | out = out.reshape(out.size(0), -1) 87 | cls_output = self.classifier(out) 88 | 89 | return cls_output 
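The two branches in `mixup_aligned` above are the core AlignMixup operation: the Sinkhorn transport plan `P` couples the `r = h*w` spatial positions of two feature maps, and after scaling by `r` (so rows/columns sum to roughly 1) it re-orders one map onto the geometry of the other before linear interpolation. A minimal standalone sketch of this pipeline using the `SinkhornDistance` module from `utils/OT.py`; the batch size, feature sizes, and `lam` below are illustrative only, and a GPU is assumed because `OT.py` hardcodes `.cuda()`:

```python
import torch
from utils.OT import SinkhornDistance  # defined in cifar10_100/utils/OT.py

# Toy stand-ins for a batch of two 512 x 4 x 4 feature maps (CIFAR-sized).
out = torch.randn(2, 512, 4, 4).cuda()
indices = torch.randperm(out.size(0))
feat1 = out.view(2, 512, -1)           # 2 x 512 x 16
feat2 = out[indices].view(2, 512, -1)  # 2 x 512 x 16

# Sinkhorn expects (batch, points, dims): one "point" per spatial position.
sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction='none')
P = sinkhorn(feat1.permute(0, 2, 1), feat2.permute(0, 2, 1)).detach()  # 2 x 16 x 16

print(P.sum(dim=-1)[0])  # each row sums to ~1/16 (uniform marginals)
R = P * 16               # rescaled assignment matrix, rows/columns sum to ~1

# align_mix == 0 branch: warp feat2 onto feat1's layout, then interpolate.
lam = 0.6
aligned = torch.matmul(feat2, R.permute(0, 2, 1))             # 2 x 512 x 16
mixed = (lam * feat1 + (1 - lam) * aligned).view(out.shape)
print(mixed.shape)  # torch.Size([2, 512, 4, 4])
```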
-------------------------------------------------------------------------------- /cifar10_100/models/wrn16x8_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.autograd import Variable 6 | from utils.OT import SinkhornDistance 7 | import numpy as np 8 | from models.WideResnet import Wide_ResNet 9 | import random 10 | 11 | 12 | def mixup_process(out, y, lam): 13 | indices = np.random.permutation(out.size(0)) 14 | out = out*lam + out[indices]*(1-lam) 15 | y_a, y_b = y, y[indices] 16 | return out, y_a, y_b 17 | 18 | 19 | 20 | 21 | def mixup_aligned(out, y, lam): 22 | # out shape = batch_size x 512 x 8 x 8 (WRN 16-8 on cifar10/100) 23 | 24 | indices = np.random.permutation(out.size(0)) 25 | feat1 = out.view(out.shape[0], out.shape[1], -1) # batch_size x 512 x 64 26 | feat2 = out[indices].view(out.shape[0], out.shape[1], -1) # batch_size x 512 x 64 27 | 28 | sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction='none') 29 | P = sinkhorn(feat1.permute(0,2,1), feat2.permute(0,2,1)).detach() # optimal plan batch x 64 x 64 30 | 31 | P = P*(out.size(2)*out.size(3)) # assignment matrix 32 | 33 | align_mix = random.randint(0,1) # uniformly choose at random, which alignmix to perform 34 | 35 | if (align_mix == 0): 36 | # \tilde{A} = A'R^{T} 37 | f1 = torch.matmul(feat2, P.permute(0,2,1).cuda()).view(out.shape) 38 | final = feat1.view(out.shape)*lam + f1*(1-lam) 39 | 40 | elif (align_mix == 1): 41 | # \tilde{A}' = AR 42 | f2 = torch.matmul(feat1, P.cuda()).view(out.shape).cuda() 43 | final = f2*lam + feat2.view(out.shape)*(1-lam) 44 | 45 | y_a, y_b = y,y[indices] 46 | 47 | return final, y_a, y_b 48 | 49 | 50 | 51 | 52 | class WideResNet_classifier(nn.Module): 53 | def __init__(self, num_classes, z_dim=512): 54 | super(WideResNet_classifier, self).__init__() 55 | 56 | self.encoder = Wide_ResNet(depth=16, 57 | widen_factor=8, 58 | num_classes=num_classes, 59 | stride=1) 60 | 61 | self.classifier = nn.Linear(z_dim,num_classes) 62 | 63 | 64 | 65 | def forward(self, x, targets, lam, mode): 66 | 67 | if (mode == 'train'): 68 | 69 | layer_mix = random.randint(0,1) 70 | if layer_mix == 0: 71 | x,t_a,t_b = mixup_process(x, targets, lam) 72 | 73 | out = self.encoder(x) 74 | 75 | if layer_mix == 1: 76 | out,t_a,t_b = mixup_aligned(out, targets, lam) 77 | 78 | out = F.avg_pool2d(out, 8) 79 | out = out.reshape(out.size(0), -1) 80 | cls_output = self.classifier(out) 81 | 82 | 83 | return cls_output, t_a, t_b 84 | 85 | 86 | 87 | elif (mode == 'test'): 88 | out = self.encoder(x) 89 | out = F.avg_pool2d(out, 8) 90 | out = out.reshape(out.size(0), -1) 91 | cls_output = self.classifier(out) 92 | 93 | return cls_output 94 | 95 | -------------------------------------------------------------------------------- /cifar10_100/submit_job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #OAR -l /host=1/gpu_device=1,walltime=12:00:00 4 | 5 | #OAR -O /srv/tempdd/svenkata/logFiles/AlignMixup_CVPR22/cifar100/log_temp.txt 6 | #OAR -E /srv/tempdd/svenkata/logFiles/AlignMixup_CVPR22/cifar100/log_temp.error 7 | 8 | #patch to be aware of "module" inside a job 9 | . /etc/profile.d/modules.sh 10 | 11 | 12 | echo " got the python script" 13 | 14 | 15 | module load pytorch/1.10.1-py3.8 16 | 17 | 18 | EXECUTABLE="main.py --dataset cifar10 --data_dir /srv/tempdd/svenkata/Datasets/data \ 19 | --save_dir /nfs/pyrex/raid6/svenkata/weights/AlignMixup_CVPR22/ --epochs 2000 \ 20 | --alpha 2.0 --num_classes 10 --manualSeed 8492" 21 | 22 | 23 | echo 24 | echo "=============== RUN ${OAR_JOB_ID} ===============" 25 | echo "Running ..." 26 | python ${EXECUTABLE} $* 27 | echo "Done" 28 | -------------------------------------------------------------------------------- /cifar10_100/utils/OT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Adapted from https://github.com/gpeyre/SinkhornAutoDiff 5 | class SinkhornDistance(nn.Module): 6 | r""" 7 | Given two empirical measures each with :math:`P_1` locations 8 | :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`, 9 | outputs an approximation of the regularized OT plan. 10 | Args: 11 | eps (float): regularization coefficient 12 | max_iter (int): maximum number of Sinkhorn iterations 13 | reduction (string, optional): Specifies the reduction to apply to the output: 14 | 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 15 | 'mean': the sum of the output will be divided by the number of 16 | elements in the output, 'sum': the output will be summed. 
Default: 'none' 17 | Shape: 18 | - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)` 19 | - Output: :math:`(N)` or :math:`()`, depending on `reduction` 20 | """ 21 | def __init__(self, eps, max_iter, reduction='none'): 22 | super(SinkhornDistance, self).__init__() 23 | self.eps = eps 24 | self.max_iter = max_iter 25 | self.reduction = reduction 26 | 27 | def forward(self, x, y): 28 | # The Sinkhorn algorithm takes as input three variables : 29 | C = self._cost_matrix(x, y) # Wasserstein cost function 30 | x_points = x.shape[-2] 31 | y_points = y.shape[-2] 32 | if x.dim() == 2: 33 | batch_size = 1 34 | else: 35 | batch_size = x.shape[0] 36 | 37 | # both marginals are fixed with equal weights 38 | mu = torch.empty(batch_size, x_points, dtype=torch.float, 39 | requires_grad=False).fill_(1.0 / x_points).squeeze().cuda() 40 | nu = torch.empty(batch_size, y_points, dtype=torch.float, 41 | requires_grad=False).fill_(1.0 / y_points).squeeze().cuda() 42 | 43 | u = torch.zeros_like(mu).cuda() 44 | v = torch.zeros_like(nu).cuda() 45 | # To check if algorithm terminates because of threshold 46 | # or max iterations reached 47 | actual_nits = 0 48 | # Stopping criterion 49 | thresh = 1e-1 50 | 51 | # Sinkhorn iterations 52 | for i in range(self.max_iter): 53 | u1 = u # useful to check the update 54 | u = self.eps * (torch.log(mu+1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u 55 | v = self.eps * (torch.log(nu+1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v 56 | err = (u - u1).abs().sum(-1).mean() 57 | 58 | actual_nits += 1 59 | if err.item() < thresh: 60 | break 61 | 62 | U, V = u, v 63 | # Transport plan pi = diag(a)*K*diag(b) 64 | pi = torch.exp(self.M(C, U, V)) 65 | 66 | 67 | return pi 68 | 69 | def M(self, C, u, v): 70 | # "Modified cost for logarithmic updates" 71 | # "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$" 72 | return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps 73 | 74 | @staticmethod 75 | def _cost_matrix(x, y, p=2): 76 | "Returns the matrix of $|x_i-y_j|^p$." 77 | x_col = x.unsqueeze(-2).cuda() 78 | y_lin = y.unsqueeze(-3).cuda() 79 | C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1) 80 | return C 81 | 82 | @staticmethod 83 | def ave(u, u1, tau): 84 | "Barycenter subroutine, used by kinetic acceleration through extrapolation." 85 | return tau * u + (1 - tau) * u1 -------------------------------------------------------------------------------- /imagenet/gen.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=imagenet # name of job 3 | #SBATCH --account=nug@gpu 4 | ##SBATCH -C v100-16g # reserving 16 GB GPUs only 5 | #SBATCH --partition=gpu_p2 # uncomment for gpu_p2 partition gpu_p2 6 | #SBATCH --ntasks=8 # total number of processes (= number of GPUs here) 7 | ##SBATCH --ntasks-per-node=4 8 | #SBATCH --nodes=1 # reserving 1 node 9 | #SBATCH --gres=gpu:8 # number of GPUs (1/4 of GPUs) 10 | ##SBATCH --cpus-per-task=10 # number of cores per task (1/4 of the 4-GPUs node) 11 | #SBATCH --cpus-per-task=3 # number of cores per task (with gpu_p2: 1/8 of the 8-GPUs node) 12 | # /!\ Caution, "multithread" in Slurm vocabulary refers to hyperthreading. 
13 | #SBATCH --hint=nomultithread # hyperthreading is deactivated 14 | #SBATCH --time=100:00:00 # maximum execution time requested (HH:MM:SS) 15 | #SBATCH --output=logfiles/log.out # name of output file 16 | #SBATCH --error=logfiles/log.error # name of error file (here, in common with the output file) 17 | #SBATCH --qos=qos_gpu-t4 18 | 19 | # cleans out the modules loaded in interactive and inherited by default 20 | module purge 21 | 22 | # loading of modules 23 | module load pytorch-gpu/py3/1.10.1 24 | 25 | # echo of launched commands 26 | set -x 27 | 28 | # code execution 29 | python -u main.py --data_dir /gpfsdswork/dataset/imagenet/RawImages --save_dir /gpfsstore/rech/nug/udq92qm/imagenet/ \ 30 | --num_classes 1000 --alpha 2.0 --batch_size 1024 -------------------------------------------------------------------------------- /imagenet/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import argparse 5 | import itertools 6 | from models.resnet50_classifier import Resnet_classifier 7 | from utils import util 8 | 9 | import torch 10 | from torch import nn, optim 11 | from torch.utils.data import DataLoader 12 | from torchvision import datasets, transforms 13 | import torch.backends.cudnn as cudnn 14 | 15 | 16 | parser = argparse.ArgumentParser(description='Trains ResNet-50 on ImageNet', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument('--data_dir', type = str, default = '', 18 | help='file where results are to be written') 19 | parser.add_argument('--save_dir', type = str, default = '', 20 | help='folder where results are to be stored') 21 | parser.add_argument('--mini_imagenet', type = bool, default = False, 22 | help='Use subset of imagenet for training') 23 | parser.add_argument('--subset', type = int, default = 260, 24 | help='number of samples from each class. 
Since there are ~1300 samples in each class, 260/1300 is ~20% of the training set') 25 | 26 | 27 | # Optimization options 28 | parser.add_argument('--epochs', type=int, default=300, help='Number of epochs to train.') 29 | parser.add_argument('--alpha', type=float, default=2.0, help='alpha parameter for mixup') 30 | parser.add_argument('--num_classes', type=int, default=1000, help='number of classes (1000 for ImageNet)') 31 | parser.add_argument('--decay', type=float, default=1e-4, help='Weight decay (L2 penalty).') 32 | 33 | parser.add_argument('--batch_size', type=int, default=512, help='Batch size.') 34 | parser.add_argument('--lr_', type=float, default=0.1, help='The Learning Rate.') 35 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.') 36 | 37 | 38 | # Checkpoints 39 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 40 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)') 41 | 42 | # Acceleration 43 | parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.') 44 | parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 8)') 45 | 46 | # random seed 47 | parser.add_argument('--manualSeed', type=int, help='manual seed') 48 | 49 | args = parser.parse_args() 50 | 51 | out_str = str(args) 52 | print(out_str) 53 | 54 | 55 | device = torch.device("cuda" if args.ngpu>0 and torch.cuda.is_available() else "cpu") 56 | 57 | if args.manualSeed is None: 58 | args.manualSeed = random.randint(1, 10000) 59 | 60 | random.seed(args.manualSeed) 61 | np.random.seed(args.manualSeed) 62 | torch.manual_seed(args.manualSeed) 63 | torch.cuda.manual_seed_all(args.manualSeed) 64 | cudnn.benchmark = True 65 | 66 | 67 | if not os.path.exists(args.save_dir): 68 | os.makedirs(args.save_dir) 69 | 70 | 71 | mean = [0.485, 0.456, 0.406] 72 | std = [0.229, 0.224, 0.225] 73 | normalize = transforms.Normalize(mean=mean, std=std) 74 | jittering = util.ColorJitter(brightness=0.4, contrast=0.4, 75 | saturation=0.4) 76 | lighting = util.Lighting(alphastd=0.1, 77 | eigval=[0.2175, 0.0188, 0.0045], 78 | eigvec=[[-0.5675, 0.7192, 0.4009], 79 | [-0.5808, -0.0045, -0.8140], 80 | [-0.5836, -0.6948, 0.4203]]) 81 | 82 | 83 | transform_train = transforms.Compose([ 84 | transforms.RandomResizedCrop(224), 85 | transforms.RandomHorizontalFlip(), 86 | transforms.ToTensor(), 87 | jittering, 88 | lighting, 89 | normalize, 90 | ]) 91 | 92 | 93 | transform_test = transforms.Compose([ 94 | transforms.Resize(256), 95 | transforms.CenterCrop(224), 96 | transforms.ToTensor(), 97 | transforms.Normalize((0.485, 0.456, 0.406), 98 | (0.229, 0.224, 0.225)), 99 | ]) 100 | 101 | 102 | 103 | train_data = datasets.ImageFolder(root=os.path.join(args.data_dir, 'train'), transform=transform_train) 
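For reference, the `--subset 260` arithmetic: each ImageNet class has roughly 1300 training images, so keeping the first 260 per class retains about 20% of the training set (260,000 images across 1000 classes). A small sketch mirroring the selection logic of `util.subset_of_ImageNet_train_split`, which the `--mini_imagenet` branch just below calls (the function itself appears in `imagenet/utils/util.py` further down; the helper name and toy labels here are illustrative only):

```python
from collections import defaultdict

def first_k_per_class(targets, k):
    """Return indices of the first k samples of every class (order-preserving)."""
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)
    keep = []
    for label, indices in by_class.items():
        assert len(indices) >= k, f"class {label} has fewer than {k} samples"
        keep += indices[:k]
    return keep

# Toy example: 3 classes x 5 samples each, keep 1 per class (i.e. 20%).
targets = [0]*5 + [1]*5 + [2]*5
print(first_k_per_class(targets, 1))  # [0, 5, 10] -> one index per class
```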
104 | if args.mini_imagenet: 105 | # use 20% of the training set, for researchers with limited resources 106 | train_data = util.subset_of_ImageNet_train_split(train_data, subset=args.subset) 107 | 108 | test_data = datasets.ImageFolder(root=os.path.join(args.data_dir, 'val'), transform=transform_test) 109 | 110 | trainloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers) 111 | testloader = DataLoader(test_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers) 112 | 113 | 114 | 115 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 116 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 117 | 118 | 119 | 120 | def adjust_learning_rate(optimizer, epoch): 121 | """Sets the learning rate to the initial LR decayed by 10 every 75 epochs (300-epoch runs) or every 30 epochs otherwise""" 122 | lr = args.lr_ 123 | if args.epochs == 300: 124 | lr = args.lr_ * (0.1**(epoch // 75)) 125 | else: 126 | lr = args.lr_ * (0.1**(epoch // 30)) 127 | 128 | for param_group in optimizer.param_groups: 129 | param_group['lr'] = lr 130 | 131 | return lr 132 | 133 | 134 | 135 | model = Resnet_classifier(args.num_classes) 136 | model = torch.nn.DataParallel(model) 137 | model.to(device) 138 | print(model) 139 | 140 | criterion = nn.CrossEntropyLoss() 141 | optimizer = optim.SGD(model.parameters(), lr=args.lr_, momentum=args.momentum, weight_decay=args.decay, nesterov=True) 142 | best_acc = 0 143 | 144 | 145 | if args.resume: 146 | if os.path.isfile(args.resume): 147 | print("=> loading checkpoint '{}'".format(args.resume)) 148 | checkpoint = torch.load(args.resume) 149 | args.start_epoch = checkpoint['epoch'] 150 | model.load_state_dict(checkpoint['model']) 151 | optimizer.load_state_dict(checkpoint['optimizer']) 152 | best_acc = checkpoint['acc'] 153 | print("=> loaded checkpoint '{}' accuracy={} (epoch {})" .format(args.resume, best_acc, checkpoint['epoch'])) 154 | else: 155 | print("=> no checkpoint found at '{}'".format(args.resume)) 156 | 157 | 158 | 159 | 160 | def train(epoch): 161 | 162 | model.train() 163 | total_loss = 0 164 | correct = 0 165 | 166 | for i, (images, targets) in enumerate(trainloader): 167 | 168 | images = images.to(device) 169 | targets = targets.to(device) 170 | 171 | lam = np.random.beta(args.alpha, args.alpha) 172 | outputs,targets_a,targets_b = model(images, targets, lam, mode='train') 173 | loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam) 174 | 175 | optimizer.zero_grad() 176 | loss.backward() 177 | optimizer.step() 178 | 179 | total_loss += loss.item() 180 | 181 | _,pred = torch.max(outputs, dim=1) 182 | correct += (pred == targets).sum().item() 183 | 184 | print('epoch: {} --> Train loss = {:.4f} Train Accuracy = {:.4f} '.format(epoch, total_loss / len(trainloader.dataset), 100.*correct / len(trainloader.dataset))) 185 | 186 | 187 | 188 | def test(epoch): 189 | global best_acc 190 | model.eval() 191 | test_loss = 0 192 | correct = 0 193 | total = 0 194 | with torch.no_grad(): 195 | for batch_idx, (inputs, targets) in enumerate(testloader): 196 | 197 | inputs = inputs.to(device) 198 | targets = targets.to(device) 199 | 200 | outputs = model(inputs, None, None, mode='test') 201 | loss = criterion(outputs, targets) 202 | 203 | test_loss += loss.item() 204 | _, predicted = torch.max(outputs.data, 1) 205 | total += targets.size(0) 206 | correct += predicted.eq(targets.data).cpu().sum() 207 | 208 | print('------> epoch: {} --> Test loss = {:.4f} Test Accuracy = {:.4f} '.format(epoch,test_loss / len(testloader.dataset), 100.*correct / len(testloader.dataset))) 209 | 210 | acc = 100.*correct/total 
211 | if acc > best_acc: 212 | checkpoint(acc, epoch) 213 | best_acc = acc 214 | 215 | return best_acc 216 | 217 | 218 | def checkpoint(acc, epoch): 219 | # Save a checkpoint whenever test accuracy improves. 220 | print('Saving..') 221 | state = { 222 | 'model': model.state_dict(), 223 | 'optimizer' : optimizer.state_dict(), 224 | 'acc': acc, 225 | 'epoch': epoch, 226 | } 227 | 228 | torch.save(state, args.save_dir + 'checkpoint.t7') 229 | 230 | 231 | 232 | if __name__ == '__main__': 233 | for epoch in range(args.start_epoch, args.epochs): 234 | adjust_learning_rate(optimizer, epoch) 235 | train(epoch) 236 | best_accuracy = test(epoch) 237 | 238 | print('Best Accuracy = ', best_accuracy) 239 | 240 | 241 | 
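For reference, `adjust_learning_rate` above is a step schedule: with the default 300-epoch budget the learning rate is 0.1 for epochs 0-74, 0.01 for 75-149, 0.001 for 150-224, and 0.0001 for 225-299. A quick standalone check of the milestones (the `base_lr` and `epochs` values below are just the script defaults):

```python
base_lr, epochs = 0.1, 300
period = 75 if epochs == 300 else 30  # mirrors the branch in adjust_learning_rate
milestones = {e: base_lr * (0.1 ** (e // period)) for e in range(0, epochs, period)}
print(milestones)  # approximately {0: 0.1, 75: 0.01, 150: 0.001, 225: 0.0001}
```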
-------------------------------------------------------------------------------- /imagenet/models/resnet.py: -------------------------------------------------------------------------------- 1 | 2 | # Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 3 | # This code is taken from https://github.com/clovaai/CutMix-PyTorch 4 | 5 | import torch.nn as nn 6 | import math 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = conv3x3(inplanes, planes, stride) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu = nn.ReLU(inplace=True) 24 | 25 | self.downsample = downsample 26 | self.stride = stride 27 | 28 | def forward(self, x): 29 | residual = x 30 | 31 | out = self.conv1(x) 32 | out = self.bn1(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv2(out) 36 | out = self.bn2(out) 37 | 38 | if self.downsample is not None: 39 | residual = self.downsample(x) 40 | 41 | out += residual 42 | out = self.relu(out) 43 | 44 | return out 45 | 46 | 47 | class Bottleneck(nn.Module): 48 | expansion = 4 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(Bottleneck, self).__init__() 52 | 53 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False) 58 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) 59 | self.relu = nn.ReLU(inplace=True) 60 | 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv3(out) 76 | out = self.bn3(out) 77 | if self.downsample is not None: 78 | residual = self.downsample(x) 79 | 80 | out += residual 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | class ResNet(nn.Module): 86 | def __init__(self, bottleneck=False): 87 | super(ResNet, self).__init__() 88 | 89 | depth = 50 90 | num_classes = 1000 91 | blocks ={50: Bottleneck} 92 | layers ={50: [3, 4, 6, 3]} 93 | assert depth in layers, 'invalid depth for ResNet (only depth 50 is configured in this file)' 94 | 95 | self.inplanes = 64 96 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 97 | self.bn1 = nn.BatchNorm2d(64) 98 | self.relu = nn.ReLU(inplace=True) 99 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 100 | self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0]) 101 | self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2) 102 | self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2) 103 | self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2) 104 | # self.avgpool = nn.AvgPool2d(7) 105 | # self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes) 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. / n)) 111 | elif isinstance(m, nn.BatchNorm2d): 112 | m.weight.data.fill_(1) 113 | m.bias.data.zero_() 114 | 115 | def _make_layer(self, block, planes, blocks, stride=1): 116 | downsample = None 117 | if stride != 1 or self.inplanes != planes * block.expansion: 118 | downsample = nn.Sequential( 119 | nn.Conv2d(self.inplanes, planes * block.expansion, 120 | kernel_size=1, stride=stride, bias=False), 121 | nn.BatchNorm2d(planes * block.expansion), 122 | ) 123 | 124 | layers = [] 125 | layers.append(block(self.inplanes, planes, stride, downsample)) 126 | self.inplanes = planes * block.expansion 127 | for i in range(1, blocks): 128 | layers.append(block(self.inplanes, planes)) 129 | 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, x): 133 | 134 | x = self.conv1(x) 135 | x = self.bn1(x) 136 | x = self.relu(x) 137 | x = self.maxpool(x) 138 | 139 | x = self.layer1(x) 140 | x = self.layer2(x) 141 | x = self.layer3(x) 142 | x = self.layer4(x) 143 | 144 | # x = self.avgpool(x) 145 | # x = x.view(x.size(0), -1) 146 | # x = self.fc(x) 147 | 148 | return x 149 | -------------------------------------------------------------------------------- /imagenet/models/resnet50_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.autograd import Variable 6 | from utils.OT import SinkhornDistance 7 | import numpy as np 8 | import random 9 | from torchvision.models import resnet50 10 | import torch.utils.model_zoo as model_zoo 11 | from models.resnet import ResNet 12 | 13 | 14 | def mixup_process(out, y, lam): 15 | indices = np.random.permutation(out.size(0)) 16 | out = out*lam + out[indices]*(1-lam) 17 | y_a, y_b = y, y[indices] 18 | return out, y_a, y_b 19 | 20 | 21 | 22 | def mixup_aligned(out, y, lam): 23 | # out shape = batch_size x 2048 x 7 x 7 (imagenet) 24 | 25 | indices = np.random.permutation(out.size(0)) 26 | feat1 = out.view(out.shape[0], out.shape[1], -1) # batch_size x 2048 x 49 27 | feat2 = out[indices].view(out.shape[0], out.shape[1], -1) # batch_size x 2048 x 49 28 | 29 | sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction='none') 30 | P = sinkhorn(feat1.permute(0,2,1), feat2.permute(0,2,1)).detach() # optimal plan batch x 49 x 49 31 | 32 | P = P*(out.size(2)*out.size(3)) # assignment matrix 33 | 34 | align_mix = random.randint(0,1) # uniformly choose at random, which alignmix to perform 35 | 36 | if (align_mix == 0): 37 | # \tilde{A} = A'R^{T} 38 | f1 = torch.matmul(feat2, P.permute(0,2,1).cuda()).view(out.shape) 39 | final = feat1.view(out.shape)*lam + f1*(1-lam) 40 | 41 | elif (align_mix == 1): 42 | # \tilde{A}' = AR 43 | f2 = torch.matmul(feat1, P.cuda()).view(out.shape).cuda() 44 | final = f2*lam + feat2.view(out.shape)*(1-lam) 45 | 46 | y_a, y_b = y,y[indices] 47 | 48 | return final, y_a, y_b 49 | 50 | 51 | 52 | 53 | class Resnet_classifier(nn.Module): 54 | def __init__(self, num_classes, z_dim=2048): 55 | super(Resnet_classifier, self).__init__() 56 | 57 | self.model = ResNet() 58 | self.classifier = nn.Linear(z_dim, num_classes) 59 | 60 | 61 | def forward(self, x, targets, lam, mode): 62 | 63 | if (mode == 'train'): 64 | layer_mix = random.randint(0,1) 65 | 66 | if layer_mix == 0: 67 | x,t_a,t_b = mixup_process(x, targets, lam) 68 | 69 | x = self.model(x) 70 | 71 | if layer_mix == 1: 72 | x,t_a,t_b = mixup_aligned(x, targets, lam) 73 | 74 | x = F.avg_pool2d(x, 7) 75 | x = x.view(x.size(0), -1) 76 | cls_output = self.classifier(x) 77 | 78 | return cls_output, t_a,t_b 79 | 80 | 81 | elif (mode == 'test'): 82 | out = self.model(x) 83 | out = F.avg_pool2d(out, 7) 84 | out = out.reshape(out.size(0), -1) 85 | cls_output = self.classifier(out) 86 | 87 | return cls_output 88 | 
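Shape check for the forward path above: a 224 x 224 input leaves the stride-32 `ResNet` trunk as `batch x 2048 x 7 x 7`, which is why `mixup_aligned` sees 49 spatial positions and the head pools with `F.avg_pool2d(x, 7)`. A standalone sketch on random data (CPU is fine here, since the trunk itself contains no `.cuda()` calls):

```python
import torch
import torch.nn.functional as F
from models.resnet import ResNet  # the stripped ResNet-50 trunk defined above

model = ResNet().eval()
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    feats = model(x)
print(feats.shape)   # torch.Size([2, 2048, 7, 7]) -> 49 positions for Sinkhorn
pooled = F.avg_pool2d(feats, 7).view(feats.size(0), -1)
print(pooled.shape)  # torch.Size([2, 2048]) -> input to the linear classifier
```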
-------------------------------------------------------------------------------- /imagenet/utils/OT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Adapted from https://github.com/gpeyre/SinkhornAutoDiff 5 | class SinkhornDistance(nn.Module): 6 | r""" 7 | Given two empirical measures each with :math:`P_1` locations 8 | :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`, 9 | outputs an approximation of the regularized OT plan. 10 | Args: 11 | eps (float): regularization coefficient 12 | max_iter (int): maximum number of Sinkhorn iterations 13 | reduction (string, optional): Specifies the reduction to apply to the output: 14 | 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 15 | 'mean': the sum of the output will be divided by the number of 16 | elements in the output, 'sum': the output will be summed. 
Default: 'none' 17 | Shape: 18 | - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)` 19 | - Output: :math:`(N)` or :math:`()`, depending on `reduction` 20 | """ 21 | def __init__(self, eps, max_iter, reduction='none'): 22 | super(SinkhornDistance, self).__init__() 23 | self.eps = eps 24 | self.max_iter = max_iter 25 | self.reduction = reduction 26 | 27 | def forward(self, x, y): 28 | # The Sinkhorn algorithm takes as input three variables : 29 | C = self._cost_matrix(x, y) # Wasserstein cost function 30 | x_points = x.shape[-2] 31 | y_points = y.shape[-2] 32 | if x.dim() == 2: 33 | batch_size = 1 34 | else: 35 | batch_size = x.shape[0] 36 | 37 | # both marginals are fixed with equal weights 38 | mu = torch.empty(batch_size, x_points, dtype=torch.float, 39 | requires_grad=False).fill_(1.0 / x_points).squeeze().cuda() 40 | nu = torch.empty(batch_size, y_points, dtype=torch.float, 41 | requires_grad=False).fill_(1.0 / y_points).squeeze().cuda() 42 | 43 | u = torch.zeros_like(mu).cuda() 44 | v = torch.zeros_like(nu).cuda() 45 | # To check if algorithm terminates because of threshold 46 | # or max iterations reached 47 | actual_nits = 0 48 | # Stopping criterion 49 | thresh = 1e-1 50 | 51 | # Sinkhorn iterations 52 | for i in range(self.max_iter): 53 | u1 = u # useful to check the update 54 | u = self.eps * (torch.log(mu+1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u 55 | v = self.eps * (torch.log(nu+1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v 56 | err = (u - u1).abs().sum(-1).mean() 57 | 58 | actual_nits += 1 59 | if err.item() < thresh: 60 | break 61 | 62 | U, V = u, v 63 | # Transport plan pi = diag(a)*K*diag(b) 64 | pi = torch.exp(self.M(C, U, V)) 65 | 66 | 67 | return pi 68 | 69 | def M(self, C, u, v): 70 | # "Modified cost for logarithmic updates" 71 | # "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$" 72 | return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps 73 | 74 | @staticmethod 75 | def _cost_matrix(x, y, p=2): 76 | "Returns the matrix of $|x_i-y_j|^p$." 77 | x_col = x.unsqueeze(-2).cuda() 78 | y_lin = y.unsqueeze(-3).cuda() 79 | C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1) 80 | return C 81 | 82 | @staticmethod 83 | def ave(u, u1, tau): 84 | "Barycenter subroutine, used by kinetic acceleration through extrapolation." 85 | return tau * u + (1 - tau) * u1 -------------------------------------------------------------------------------- /imagenet/utils/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import torchvision 5 | from torchvision import transforms 6 | import PIL.Image 7 | import torch 8 | import random 9 | 10 | 11 | # original code: https://github.com/eladhoffer/convNet.pytorch/blob/master/preprocess.py 12 | 13 | __all__ = ["Compose", "Lighting", "ColorJitter"] 14 | 15 | 16 | class Compose(object): 17 | """Composes several transforms together. 18 | Args: 19 | transforms (list of ``Transform`` objects): list of transforms to compose. 
20 | Example: 21 | >>> transforms.Compose([ 22 | >>> transforms.CenterCrop(10), 23 | >>> transforms.ToTensor(), 24 | >>> ]) 25 | """ 26 | 27 | def __init__(self, transforms): 28 | self.transforms = transforms 29 | 30 | def __call__(self, img): 31 | for t in self.transforms: 32 | img = t(img) 33 | return img 34 | 35 | def __repr__(self): 36 | format_string = self.__class__.__name__ + '(' 37 | for t in self.transforms: 38 | format_string += '\n' 39 | format_string += ' {0}'.format(t) 40 | format_string += '\n)' 41 | return format_string 42 | 43 | 44 | class Lighting(object): 45 | """Lighting noise(AlexNet - style PCA - based noise)""" 46 | 47 | def __init__(self, alphastd, eigval, eigvec): 48 | self.alphastd = alphastd 49 | self.eigval = torch.Tensor(eigval) 50 | self.eigvec = torch.Tensor(eigvec) 51 | 52 | def __call__(self, img): 53 | if self.alphastd == 0: 54 | return img 55 | 56 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 57 | rgb = self.eigvec.type_as(img).clone() \ 58 | .mul(alpha.view(1, 3).expand(3, 3)) \ 59 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 60 | .sum(1).squeeze() 61 | 62 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 63 | 64 | 65 | class Grayscale(object): 66 | 67 | def __call__(self, img): 68 | gs = img.clone() 69 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 70 | gs[1].copy_(gs[0]) 71 | gs[2].copy_(gs[0]) 72 | return gs 73 | 74 | 75 | class Saturation(object): 76 | 77 | def __init__(self, var): 78 | self.var = var 79 | 80 | def __call__(self, img): 81 | gs = Grayscale()(img) 82 | alpha = random.uniform(-self.var, self.var) 83 | return img.lerp(gs, alpha) 84 | 85 | 86 | class Brightness(object): 87 | 88 | def __init__(self, var): 89 | self.var = var 90 | 91 | def __call__(self, img): 92 | gs = img.new().resize_as_(img).zero_() 93 | alpha = random.uniform(-self.var, self.var) 94 | return img.lerp(gs, alpha) 95 | 96 | 97 | class Contrast(object): 98 | 99 | def __init__(self, var): 100 | self.var = var 101 | 102 | def __call__(self, img): 103 | gs = Grayscale()(img) 104 | gs.fill_(gs.mean()) 105 | alpha = random.uniform(-self.var, self.var) 106 | return img.lerp(gs, alpha) 107 | 108 | 109 | class ColorJitter(object): 110 | 111 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 112 | self.brightness = brightness 113 | self.contrast = contrast 114 | self.saturation = saturation 115 | 116 | def __call__(self, img): 117 | self.transforms = [] 118 | if self.brightness != 0: 119 | self.transforms.append(Brightness(self.brightness)) 120 | if self.contrast != 0: 121 | self.transforms.append(Contrast(self.contrast)) 122 | if self.saturation != 0: 123 | self.transforms.append(Saturation(self.saturation)) 124 | 125 | random.shuffle(self.transforms) 126 | transform = Compose(self.transforms) 127 | # print(transform) 128 | return transform(img) 129 | 130 | 131 | 132 | # Taken from https://github.com/valeoai/obow/blob/3758504f5e058275725c35ca7faca3731572b911/obow/datasets.py#L147 133 | 134 | def subset_of_ImageNet_train_split(dataset_train, subset): 135 | assert isinstance(subset, int) 136 | assert subset > 0 137 | 138 | all_indices = [] 139 | for _, img_indices in buildLabelIndex(dataset_train.targets).items(): 140 | assert len(img_indices) >= subset 141 | all_indices += img_indices[:subset] 142 | 143 | dataset_train.imgs = [dataset_train.imgs[idx] for idx in all_indices] 144 | dataset_train.samples = [dataset_train.samples[idx] for idx in all_indices] 145 | dataset_train.targets = [dataset_train.targets[idx] for idx in all_indices] 146 | 
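# the check below assumes the standard ImageNet train split of 1000 classes; the assert inside the loop above already guarantees each class provides at least `subset` images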
assert len(dataset_train) == (subset * 1000) 147 | 148 | return dataset_train 149 | 150 | 151 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 152 | import numpy as np # numpy is used only in this function and is not imported at the top of this module 153 | warmup_schedule = np.array([]) 154 | warmup_iters = warmup_epochs * niter_per_ep 155 | if warmup_epochs > 0: 156 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 157 | 158 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 159 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 160 | 161 | schedule = np.concatenate((warmup_schedule, schedule)) 162 | assert len(schedule) == epochs * niter_per_ep 163 | return schedule 164 | 165 | 166 | # groups dataset indices by class label; referenced by subset_of_ImageNet_train_split above but previously undefined in this module 167 | def buildLabelIndex(labels): 168 | label2inds = {} 169 | for idx, label in enumerate(labels): 170 | if label not in label2inds: 171 | label2inds[label] = [] 172 | label2inds[label].append(idx) 173 | return label2inds -------------------------------------------------------------------------------- /tiny_imgnet/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import argparse 5 | import itertools 6 | from models.resnet18_classifier import Resnet_classifier 7 | 8 | import torch 9 | from torch import nn, optim 10 | from torch.utils.data import DataLoader 11 | from torchvision import datasets, transforms 12 | import torch.backends.cudnn as cudnn 13 | from utils.dataLoader import dataloader_train, dataloader_val 14 | 15 | 16 | parser = argparse.ArgumentParser(description='Trains ResNet on tiny-imagenet-200', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument('--train_dir', type = str, default = '/nfs/pyrex/raid6/svenkata/Datasets/tiny-imagenet-200/train', 18 | help='path to train folder') 19 | parser.add_argument('--val_dir', type = str, default = '/nfs/pyrex/raid6/svenkata/Datasets/tiny-imagenet-200/val', 20 | help='path to val folder') 21 | parser.add_argument('--save_dir', type = str, default = '/nfs/pyrex/raid6/svenkata/weights/AlignMixup_CVPR22/tiny_imagenet/', 22 | help='folder where results are to be stored') 23 | 24 | # Optimization options 25 | 26 | parser.add_argument('--epochs', type=int, default=1200, help='Number of epochs to train.') 27 | parser.add_argument('--alpha', type=float, default=2.0, help='alpha parameter for mixup') 28 | parser.add_argument('--num_classes', type=int, default=200, help='number of classes (200 for Tiny-Imagenet-200)') 29 | 30 | parser.add_argument('--batch_size', type=int, default=256, help='Batch size.') 31 | parser.add_argument('--lr_', type=float, default=0.1, help='The Learning Rate.') 32 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.') 33 | 34 | parser.add_argument('--decay', type=float, default=1e-4, help='Weight decay (L2 penalty).') 35 | parser.add_argument('--schedule', type=int, nargs='+', default=[600, 900], help='Decrease learning rate at these epochs.') 36 | parser.add_argument('--gammas', type=float, nargs='+', default=[0.1, 0.1], help='LR is multiplied by gamma on schedule, number of gammas should be equal to schedule') 37 | 38 | # Checkpoints 39 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 40 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)') 41 | 42 | # Acceleration 43 | parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.') 44 | parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 8)') 45 | 46 | # random seed 47 | parser.add_argument('--manualSeed', type=int, help='manual seed') 48 | 49 | 
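# parse and echo the full configuration, so every log file records the exact hyperparameters of its run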
args = parser.parse_args() 50 | 51 | out_str = str(args) 52 | print(out_str) 53 | 54 | 55 | 56 | device = torch.device("cuda" if args.ngpu>0 and torch.cuda.is_available() else "cpu") 57 | 58 | if args.manualSeed is None: 59 | args.manualSeed = random.randint(1, 10000) 60 | 61 | random.seed(args.manualSeed) 62 | np.random.seed(args.manualSeed) 63 | torch.manual_seed(args.manualSeed) 64 | torch.cuda.manual_seed_all(args.manualSeed) 65 | cudnn.benchmark = True 66 | 67 | 68 | if not os.path.exists(args.save_dir): 69 | os.makedirs(args.save_dir) 70 | 71 | 72 | 73 | transform_train = transforms.Compose([ 74 | transforms.RandomCrop(64, padding=4), 75 | transforms.RandomHorizontalFlip(), 76 | transforms.ToTensor(), 77 | transforms.Normalize((0.4914, 0.4822, 0.4465), 78 | (0.2023, 0.1994, 0.2010)), 79 | ]) 80 | 81 | 82 | transform_test = transforms.Compose([ 83 | transforms.ToTensor(), 84 | transforms.Normalize((0.4914, 0.4822, 0.4465), 85 | (0.2023, 0.1994, 0.2010)), 86 | ]) 87 | 88 | 89 | 90 | train_data = dataloader_train(args.train_dir, transform_train) 91 | test_data = dataloader_val(os.path.join(args.val_dir, 'images'), os.path.join(args.val_dir, 'val_annotations.txt'), transform_test) 92 | 93 | trainloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers) 94 | testloader = DataLoader(test_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers) 95 | 96 | 97 | model = Resnet_classifier(num_classes=args.num_classes) 98 | model = torch.nn.DataParallel(model) 99 | model.to(device) 100 | print(model) 101 | 102 | optimizer = optim.SGD(model.parameters(), lr=args.lr_, momentum=args.momentum, weight_decay=args.decay) 103 | criterion = nn.CrossEntropyLoss() 104 | best_acc = 0 105 | 106 | if args.resume: 107 | if os.path.isfile(args.resume): 108 | print("=> loading checkpoint '{}'".format(args.resume)) 109 | checkpoint = torch.load(args.resume) 110 | args.start_epoch = checkpoint['epoch'] 111 | model.load_state_dict(checkpoint['model']) 112 | optimizer.load_state_dict(checkpoint['optimizer']) 113 | best_acc = checkpoint['acc'] 114 | print("=> loaded checkpoint '{}' accuracy={} (epoch {})" .format(args.resume, best_acc, checkpoint['epoch'])) 115 | else: 116 | print("=> no checkpoint found at '{}'".format(args.resume)) 117 | 118 | 119 | 120 | def train(epoch): 121 | model.train() 122 | total_loss = 0 123 | correct = 0 124 | for i, (images, targets) in enumerate(trainloader): 125 | 126 | images = images.to(device) 127 | targets = targets.to(device) 128 | 129 | lam = np.random.beta(args.alpha, args.alpha) 130 | 131 | outputs,targets_a,targets_b = model(images, targets, lam, mode='train') 132 | 133 | loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam) 134 | 135 | optimizer.zero_grad() 136 | loss.backward() 137 | optimizer.step() 138 | 139 | total_loss += loss.item() 140 | 141 | _,pred = torch.max(outputs, dim=1) 142 | correct += (pred == targets).sum().item() 143 | 144 | print('epoch: {} --> Train loss = {:.4f} Train Accuracy = {:.4f} '.format(epoch, total_loss / len(trainloader.dataset), 100.*correct / len(trainloader.dataset))) 145 | 146 | 147 | 148 | def test(epoch): 149 | #best_acc = 0 150 | global best_acc 151 | model.eval() 152 | test_loss = 0 153 | correct = 0 154 | total = 0 155 | 156 | with torch.no_grad(): 157 | for batch_idx, (inputs, targets) in enumerate(testloader): 158 | 159 | inputs = inputs.to(device) 160 | targets = targets.to(device) 161 | 162 | outputs = model(inputs, None, None, mode='test') 163 | 164 | loss 
= criterion(outputs, targets) 165 | 166 | test_loss += loss.item() 167 | _, predicted = torch.max(outputs.data, 1) 168 | total += targets.size(0) 169 | correct += predicted.eq(targets.data).cpu().sum() 170 | 171 | print('------> epoch: {} --> Test loss = {:.4f} Test Accuracy = {:.4f} '.format(epoch, test_loss / len(testloader.dataset), 100.*correct / len(testloader.dataset))) 172 | 173 | acc = 100.*correct/total 174 | if acc > best_acc: 175 | checkpoint(acc, epoch) 176 | best_acc = acc 177 | 178 | return best_acc 179 | 180 | 181 | def checkpoint(acc, epoch): 182 | # Save checkpoint. 183 | print('Saving..') 184 | state = { 185 | 'model': model.state_dict(), 186 | 'optimizer' : optimizer.state_dict(), 187 | 'acc': acc, 188 | 'epoch': epoch, 189 | 'seed' : args.manualSeed 190 | } 191 | 192 | torch.save(state, os.path.join(args.save_dir, 'checkpoint.t7')) 193 | 194 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 195 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 196 | 197 | 198 | 199 | def adjust_learning_rate(optimizer, epoch, gammas, schedule): 200 | """Sets the learning rate to the initial LR decayed by 10 at 600 and 900 epochs""" 201 | lr = args.lr_ 202 | assert len(gammas) == len(schedule), "length of gammas and schedule should be equal" 203 | for (gamma, step) in zip(gammas, schedule): 204 | if (epoch >= step): 205 | lr = lr * gamma 206 | else: 207 | break 208 | for param_group in optimizer.param_groups: 209 | param_group['lr'] = lr 210 | 211 | return lr 212 | 213 | 214 | if __name__ == '__main__': 215 | for epoch in range(args.start_epoch, args.epochs): 216 | adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule) 217 | train(epoch) 218 | best_accuracy = test(epoch) 219 | 220 | print('Best Accuracy = ', best_accuracy) 221 | 222 | -------------------------------------------------------------------------------- /tiny_imgnet/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | The BasicBlock and Bottleneck modules are from the original ResNet paper: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | 7 | The PreActBlock and PreActBottleneck modules are from the later paper: 8 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 9 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 10 | ''' 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from torch.autograd import Variable 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, in_planes, planes, stride=1): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(in_planes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.conv2 = conv3x3(planes, planes) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | 32 | self.shortcut = nn.Sequential() 33 | if stride != 1 or in_planes != self.expansion*planes: 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 36 | nn.BatchNorm2d(self.expansion*planes) 37 | ) 38 | 39 | def forward(self, x): 40 | out = F.relu(self.bn1(self.conv1(x))) 41 | out = self.bn2(self.conv2(out)) 42 | out += self.shortcut(x) 43 | out = F.relu(out) 44 | return out 45 | 46 | 47 | class PreActBlock(nn.Module): 48 | '''Pre-activation version of the BasicBlock.''' 49 | expansion = 1 50 | 51 | def __init__(self, in_planes, planes, stride=1): 52 | super(PreActBlock, self).__init__() 53 | self.bn1 = nn.BatchNorm2d(in_planes) 54 | self.conv1 = conv3x3(in_planes, planes, stride) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv2 = conv3x3(planes, planes) 57 | 58 | self.shortcut = nn.Sequential() 59 | if stride != 1 or in_planes != self.expansion*planes: 60 | self.shortcut = nn.Sequential( 61 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(x)) 66 | shortcut = self.shortcut(out) 67 | out = self.conv1(out) 68 | out = self.conv2(F.relu(self.bn2(out))) 69 | out += shortcut 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | expansion = 4 75 | 76 | def __init__(self, in_planes, planes, stride=1): 77 | super(Bottleneck, self).__init__() 78 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 79 | self.bn1 = nn.BatchNorm2d(planes) 80 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 81 | self.bn2 = nn.BatchNorm2d(planes) 82 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 83 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 84 | 85 | self.shortcut = nn.Sequential() 86 | if stride != 1 or in_planes != self.expansion*planes: 87 | self.shortcut = nn.Sequential( 88 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 89 | nn.BatchNorm2d(self.expansion*planes) 90 | ) 91 | 92 | def forward(self, x): 93 | out = F.relu(self.bn1(self.conv1(x))) 94 | out = F.relu(self.bn2(self.conv2(out))) 95 | out = self.bn3(self.conv3(out)) 96 | out += self.shortcut(x) 97 | out = F.relu(out) 98 | return out 99 | 100 | 101 | class PreActBottleneck(nn.Module): 102 | '''Pre-activation version of the original Bottleneck module.''' 103 | expansion = 4 104 | 105 | def __init__(self, in_planes, planes, stride=1): 106 | super(PreActBottleneck, self).__init__() 107 | self.bn1 = nn.BatchNorm2d(in_planes) 108 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 109 | self.bn2 = nn.BatchNorm2d(planes) 110 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 111 | self.bn3 = nn.BatchNorm2d(planes) 112 | 
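# 1x1 conv restoring the output width to expansion*planes channels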
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 113 | 114 | self.shortcut = nn.Sequential() 115 | if stride != 1 or in_planes != self.expansion*planes: 116 | self.shortcut = nn.Sequential( 117 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 118 | ) 119 | 120 | def forward(self, x): 121 | out = F.relu(self.bn1(x)) 122 | shortcut = self.shortcut(out) 123 | out = self.conv1(out) 124 | out = self.conv2(F.relu(self.bn2(out))) 125 | out = self.conv3(F.relu(self.bn3(out))) 126 | out += shortcut 127 | return out 128 | 129 | 130 | class ResNet(nn.Module): 131 | def __init__(self, block, num_blocks, num_classes=10): 132 | super(ResNet, self).__init__() 133 | self.in_planes = 64 134 | self.conv1 = conv3x3(3,64) 135 | self.bn1 = nn.BatchNorm2d(64) 136 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 137 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 138 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 139 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 140 | # self.linear = nn.Linear(512*block.expansion, num_classes) 141 | 142 | def _make_layer(self, block, planes, num_blocks, stride): 143 | strides = [stride] + [1]*(num_blocks-1) 144 | layers = [] 145 | for stride in strides: 146 | layers.append(block(self.in_planes, planes, stride)) 147 | self.in_planes = planes * block.expansion 148 | return nn.Sequential(*layers) 149 | 150 | def forward(self, x, lin=0, lout=4): 151 | out = x 152 | if lin < 1 and lout > -1: 153 | out = self.conv1(out) 154 | out = self.bn1(out) 155 | out = F.relu(out) 156 | if lin < 2 and lout > 0: 157 | out = self.layer1(out) 158 | if lin < 3 and lout > 1: 159 | out = self.layer2(out) 160 | if lin < 4 and lout > 2: 161 | out = self.layer3(out) 162 | if lin < 5 and lout > 3: 163 | out = self.layer4(out) 164 | # if lout > 4: 165 | # out = F.avg_pool2d(out, 4) 166 | # out = out.view(out.size(0), -1) 167 | # out = self.linear(out) 168 | return out 169 | 170 | 171 | def ResNet18(): 172 | return ResNet(PreActBlock, [2,2,2,2]) 173 | 174 | def ResNet34(): 175 | return ResNet(BasicBlock, [3,4,6,3]) 176 | 177 | def ResNet50(): 178 | return ResNet(Bottleneck, [3,4,6,3]) 179 | 180 | def ResNet101(): 181 | return ResNet(Bottleneck, [3,4,23,3]) 182 | 183 | def ResNet152(): 184 | return ResNet(Bottleneck, [3,8,36,3]) 185 | 186 | 187 | def test(): 188 | net = ResNet18() 189 | y = net(Variable(torch.randn(1,3,32,32))) 190 | print(y.size()) 191 | 192 | # test() 193 | -------------------------------------------------------------------------------- /tiny_imgnet/models/resnet18_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | from utils.OT import SinkhornDistance 7 | import numpy as np 8 | from models.resnet import ResNet18 9 | import random 10 | 11 | 12 | def mixup_process(out, y, lam): 13 | indices = np.random.permutation(out.size(0)) 14 | out = out*lam + out[indices]*(1-lam) 15 | y_a, y_b = y, y[indices] 16 | return out, y_a, y_b 17 | 18 | 19 | 20 | def mixup_aligned(out, y, lam): 21 | # out shape = batch_size x 512 x 8 x 8 (tiny-imagenet, 64x64 inputs; the cifar10/100 variant sees 4 x 4) 22 | 23 | indices = np.random.permutation(out.size(0)) 24 | feat1 = out.view(out.shape[0], out.shape[1], -1) # batch_size x 512 x 64 25 | feat2 = out[indices].view(out.shape[0], out.shape[1], -1) # batch_size x 512 x 64 
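# shape notes for the block below: the Sinkhorn call treats the 64 spatial cells of each
# image as a point cloud in R^512 and returns a transport plan P of shape batch x 64 x 64;
# scaling P by 64 (= 8*8, the number of cells) makes it approximately doubly stochastic,
# so matmul(feat, P) rearranges one feature map's cells onto the geometry of the other
# before the convex combination with lam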
26 | 27 | sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction='none') 28 | P = sinkhorn(feat1.permute(0,2,1), feat2.permute(0,2,1)).detach() # optimal plan: batch x 64 x 64 29 | 30 | P = P*(out.size(2)*out.size(3)) # assignment matrix 31 | 32 | align_mix = random.randint(0,1) # uniformly choose which of the two AlignMix directions to perform 33 | 34 | if (align_mix == 0): 35 | # \tilde{A} = A'R^{T} 36 | f1 = torch.matmul(feat2, P.permute(0,2,1).cuda()).view(out.shape) 37 | final = feat1.view(out.shape)*lam + f1*(1-lam) 38 | 39 | elif (align_mix == 1): 40 | # \tilde{A}' = AR 41 | f2 = torch.matmul(feat1, P.cuda()).view(out.shape) 42 | final = f2*lam + feat2.view(out.shape)*(1-lam) 43 | 44 | y_a, y_b = y, y[indices] 45 | 46 | return final, y_a, y_b 47 | 48 | 49 | 50 | class Resnet_classifier(nn.Module): 51 | def __init__(self, num_classes, z_dim=512): 52 | super(Resnet_classifier, self).__init__() 53 | 54 | self.encoder = ResNet18() 55 | self.classifier = nn.Linear(z_dim, num_classes) 56 | 57 | 58 | def forward(self, x, targets, lam, mode): 59 | 60 | if (mode == 'train'): 61 | 62 | layer_mix = random.randint(0,1) 63 | 64 | if layer_mix == 0: 65 | x,t_a,t_b = mixup_process(x, targets, lam) 66 | 67 | out = self.encoder(x, lin=0, lout=0) 68 | out = self.encoder.layer1(out) 69 | out = self.encoder.layer2(out) 70 | out = self.encoder.layer3(out) 71 | out = self.encoder.layer4(out) 72 | 73 | if layer_mix == 1: 74 | out,t_a,t_b = mixup_aligned(out, targets, lam) 75 | 76 | out = F.avg_pool2d(out, 8) 77 | out = out.reshape(out.size(0), -1) 78 | cls_output = self.classifier(out) 79 | 80 | return cls_output, t_a, t_b 81 | 82 | 83 | elif (mode == 'test'): 84 | out = self.encoder(x) 85 | out = F.avg_pool2d(out, 8) 86 | out = out.reshape(out.size(0), -1) 87 | cls_output = self.classifier(out) 88 | 89 | return cls_output -------------------------------------------------------------------------------- /tiny_imgnet/submit_job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #OAR -l /host=1/gpu_device=2,walltime=48:00:00 4 | 5 | #OAR -O /srv/tempdd/svenkata/logFiles/AlignMixup_CVPR22/tiny_imagenet/log.txt 6 | #OAR -E /srv/tempdd/svenkata/logFiles/AlignMixup_CVPR22/tiny_imagenet/log.error 7 | 8 | # patch to make the `module` command available inside a job 9 | . /etc/profile.d/modules.sh 10 | 11 | 12 | echo "got the python script" 13 | 14 | 15 | module load pytorch/1.10.1-py3.8 16 | 17 | 18 | EXECUTABLE="main.py --train_dir /nfs/pyrex/raid6/svenkata/Datasets/tiny-imagenet-200/train \ 19 | --val_dir /nfs/pyrex/raid6/svenkata/Datasets/tiny-imagenet-200/val \ 20 | --save_dir /nfs/pyrex/raid6/svenkata/weights/AlignMixup_CVPR22/tiny_imagenet --epochs 1200 \ 21 | --alpha 2.0 --num_classes 200 --manualSeed 8492" 22 | 23 | 24 | 25 | echo 26 | echo "=============== RUN ${OAR_JOB_ID} ===============" 27 | echo "Running ..." 28 | python ${EXECUTABLE} $* 29 | echo "Done" 30 | -------------------------------------------------------------------------------- /tiny_imgnet/utils/OT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Adapted from https://github.com/gpeyre/SinkhornAutoDiff 5 | class SinkhornDistance(nn.Module): 6 | r""" 7 | Given two empirical measures each with :math:`P_1` locations 8 | :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`, 9 | outputs an approximation of the regularized OT plan. 
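The forward pass returns the transport plan :math:`\pi` itself (with marginals mu and nu), not the scalar OT cost; mixup_aligned uses this plan to align spatial feature cells.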
10 | Args: 11 | eps (float): regularization coefficient 12 | max_iter (int): maximum number of Sinkhorn iterations 13 | reduction (string, optional): Specifies the reduction to apply to the output: 14 | 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 15 | 'mean': the sum of the output will be divided by the number of 16 | elements in the output, 'sum': the output will be summed. Default: 'none' 17 | Shape: 18 | - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)` 19 | - Output: :math:`(N)` or :math:`()`, depending on `reduction` 20 | """ 21 | def __init__(self, eps, max_iter, reduction='none'): 22 | super(SinkhornDistance, self).__init__() 23 | self.eps = eps 24 | self.max_iter = max_iter 25 | self.reduction = reduction 26 | 27 | def forward(self, x, y): 28 | # The Sinkhorn algorithm takes as input three variables : 29 | C = self._cost_matrix(x, y) # Wasserstein cost function 30 | x_points = x.shape[-2] 31 | y_points = y.shape[-2] 32 | if x.dim() == 2: 33 | batch_size = 1 34 | else: 35 | batch_size = x.shape[0] 36 | 37 | # both marginals are fixed with equal weights 38 | mu = torch.empty(batch_size, x_points, dtype=torch.float, 39 | requires_grad=False).fill_(1.0 / x_points).squeeze().cuda() 40 | nu = torch.empty(batch_size, y_points, dtype=torch.float, 41 | requires_grad=False).fill_(1.0 / y_points).squeeze().cuda() 42 | 43 | u = torch.zeros_like(mu).cuda() 44 | v = torch.zeros_like(nu).cuda() 45 | # To check if algorithm terminates because of threshold 46 | # or max iterations reached 47 | actual_nits = 0 48 | # Stopping criterion 49 | thresh = 1e-1 50 | 51 | # Sinkhorn iterations 52 | for i in range(self.max_iter): 53 | u1 = u # useful to check the update 54 | u = self.eps * (torch.log(mu+1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u 55 | v = self.eps * (torch.log(nu+1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v 56 | err = (u - u1).abs().sum(-1).mean() 57 | 58 | actual_nits += 1 59 | if err.item() < thresh: 60 | break 61 | 62 | U, V = u, v 63 | # Transport plan pi = diag(a)*K*diag(b) 64 | pi = torch.exp(self.M(C, U, V)) 65 | 66 | 67 | return pi 68 | 69 | def M(self, C, u, v): 70 | # "Modified cost for logarithmic updates" 71 | # "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$" 72 | return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps 73 | 74 | @staticmethod 75 | def _cost_matrix(x, y, p=2): 76 | "Returns the matrix of $|x_i-y_j|^p$." 77 | x_col = x.unsqueeze(-2).cuda() 78 | y_lin = y.unsqueeze(-3).cuda() 79 | C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1) 80 | return C 81 | 82 | @staticmethod 83 | def ave(u, u1, tau): 84 | "Barycenter subroutine, used by kinetic acceleration through extrapolation." 
85 | return tau * u + (1 - tau) * u1 -------------------------------------------------------------------------------- /tiny_imgnet/utils/dataLoader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import random 5 | import torch 6 | import torch.utils.data 7 | from torch import nn, optim 8 | from torch.autograd import Variable 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.dataset import Dataset 11 | from torch.nn import functional as F 12 | from torchvision import datasets,transforms 13 | from torchvision.utils import save_image 14 | from PIL import Image 15 | import glob 16 | import numpy 17 | 18 | textFile = 'utils/wnids.txt' 19 | classes_idx = [] 20 | classes = {} 21 | cnt = 0 22 | with open(textFile) as f: 23 | for line in f: 24 | classes_idx.append(line.rstrip('\n')) 25 | classes[line.rstrip('\n')] = cnt 26 | cnt += 1 27 | 28 | 29 | 30 | class dataloader_val(Dataset): 31 | def __init__(self, ImagePth, valtxtfile,transform=None): 32 | self.ImagePth = ImagePth 33 | self.valtxtfile = valtxtfile 34 | self.transform = transform 35 | 36 | imagelist = [] 37 | labelList = [] 38 | imgname = [] 39 | classId = [] 40 | 41 | 42 | with open(self.valtxtfile) as f: 43 | for line in f: 44 | imagelist.append(self.ImagePth + '/' + line.split()[0]) 45 | labelList.append(classes[line.split()[1]]) 46 | 47 | self.files = imagelist 48 | self.labels = labelList 49 | 50 | 51 | def __getitem__(self,index): 52 | image = Image.open(self.files[index]) 53 | img_tmp = (numpy.array(image)) 54 | if(img_tmp.ndim == 2): 55 | img_tmp = img_tmp.reshape((img_tmp.shape[0], img_tmp.shape[1], 1)) 56 | img_tmp = numpy.concatenate([img_tmp, img_tmp, img_tmp], axis=2) 57 | 58 | image = Image.fromarray(img_tmp) 59 | if(self.transform): 60 | image = self.transform(image) 61 | 62 | label = self.labels[index] 63 | 64 | 65 | return image, label 66 | 67 | def __len__(self): 68 | return len(self.files) 69 | 70 | 71 | class dataloader_train(Dataset): 72 | def __init__(self, ImagePth, transform=None): 73 | self.ImagePth = ImagePth 74 | self.transform = transform 75 | 76 | imagelist = [] 77 | labelList = [] 78 | 79 | classes_idx = sorted(os.listdir(self.ImagePth)) 80 | 81 | for i in range(len(classes)): 82 | for filename in glob.glob(self.ImagePth + '/' + classes_idx[i] + '/images/' +'*.JPEG'): 83 | 84 | imagelist.append(filename) 85 | labelList.append(classes[classes_idx[i]]) 86 | 87 | 88 | self.files = imagelist 89 | self.labels = labelList 90 | 91 | 92 | def __getitem__(self,index): 93 | image = Image.open(self.files[index]) 94 | img_tmp = (numpy.array(image)) 95 | if(img_tmp.ndim == 2): 96 | img_tmp = img_tmp.reshape((img_tmp.shape[0], img_tmp.shape[1], 1)) 97 | img_tmp = numpy.concatenate([img_tmp, img_tmp, img_tmp], axis=2) 98 | 99 | image = Image.fromarray(img_tmp) 100 | if(self.transform): 101 | image = self.transform(image) 102 | 103 | label = self.labels[index] 104 | 105 | 106 | return image, label 107 | 108 | def __len__(self): 109 | return len(self.files) 110 | 111 | 112 | -------------------------------------------------------------------------------- /tiny_imgnet/utils/wnids.txt: -------------------------------------------------------------------------------- 1 | n02124075 2 | n04067472 3 | n04540053 4 | n04099969 5 | n07749582 6 | n01641577 7 | n02802426 8 | n09246464 9 | n07920052 10 | n03970156 11 | n03891332 12 | n02106662 13 | n03201208 14 | n02279972 15 | n02132136 16 | n04146614 17 | n07873807 18 | 
n02364673 19 | n04507155 20 | n03854065 21 | n03838899 22 | n03733131 23 | n01443537 24 | n07875152 25 | n03544143 26 | n09428293 27 | n03085013 28 | n02437312 29 | n07614500 30 | n03804744 31 | n04265275 32 | n02963159 33 | n02486410 34 | n01944390 35 | n09256479 36 | n02058221 37 | n04275548 38 | n02321529 39 | n02769748 40 | n02099712 41 | n07695742 42 | n02056570 43 | n02281406 44 | n01774750 45 | n02509815 46 | n03983396 47 | n07753592 48 | n04254777 49 | n02233338 50 | n04008634 51 | n02823428 52 | n02236044 53 | n03393912 54 | n07583066 55 | n04074963 56 | n01629819 57 | n09332890 58 | n02481823 59 | n03902125 60 | n03404251 61 | n09193705 62 | n03637318 63 | n04456115 64 | n02666196 65 | n03796401 66 | n02795169 67 | n02123045 68 | n01855672 69 | n01882714 70 | n02917067 71 | n02988304 72 | n04398044 73 | n02843684 74 | n02423022 75 | n02669723 76 | n04465501 77 | n02165456 78 | n03770439 79 | n02099601 80 | n04486054 81 | n02950826 82 | n03814639 83 | n04259630 84 | n03424325 85 | n02948072 86 | n03179701 87 | n03400231 88 | n02206856 89 | n03160309 90 | n01984695 91 | n03977966 92 | n03584254 93 | n04023962 94 | n02814860 95 | n01910747 96 | n04596742 97 | n03992509 98 | n04133789 99 | n03937543 100 | n02927161 101 | n01945685 102 | n02395406 103 | n02125311 104 | n03126707 105 | n04532106 106 | n02268443 107 | n02977058 108 | n07734744 109 | n03599486 110 | n04562935 111 | n03014705 112 | n04251144 113 | n04356056 114 | n02190166 115 | n03670208 116 | n02002724 117 | n02074367 118 | n04285008 119 | n04560804 120 | n04366367 121 | n02403003 122 | n07615774 123 | n04501370 124 | n03026506 125 | n02906734 126 | n01770393 127 | n04597913 128 | n03930313 129 | n04118538 130 | n04179913 131 | n04311004 132 | n02123394 133 | n04070727 134 | n02793495 135 | n02730930 136 | n02094433 137 | n04371430 138 | n04328186 139 | n03649909 140 | n04417672 141 | n03388043 142 | n01774384 143 | n02837789 144 | n07579787 145 | n04399382 146 | n02791270 147 | n03089624 148 | n02814533 149 | n04149813 150 | n07747607 151 | n03355925 152 | n01983481 153 | n04487081 154 | n03250847 155 | n03255030 156 | n02892201 157 | n02883205 158 | n03100240 159 | n02415577 160 | n02480495 161 | n01698640 162 | n01784675 163 | n04376876 164 | n03444034 165 | n01917289 166 | n01950731 167 | n03042490 168 | n07711569 169 | n04532670 170 | n03763968 171 | n07768694 172 | n02999410 173 | n03617480 174 | n06596364 175 | n01768244 176 | n02410509 177 | n03976657 178 | n01742172 179 | n03980874 180 | n02808440 181 | n02226429 182 | n02231487 183 | n02085620 184 | n01644900 185 | n02129165 186 | n02699494 187 | n03837869 188 | n02815834 189 | n07720875 190 | n02788148 191 | n02909870 192 | n03706229 193 | n07871810 194 | n03447447 195 | n02113799 196 | n12267677 197 | n03662601 198 | n02841315 199 | n07715103 200 | n02504458 201 | --------------------------------------------------------------------------------