├── figures │   └── diagram.png ├── requirements.txt ├── metrics.py ├── costs.py ├── checkpoint.py ├── utils.py ├── test.py ├── task_test.py ├── stats.py ├── task_train.py ├── train.py ├── README.md ├── graphs.py ├── model.py └── datasets.py /figures/diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foiv0s/imc-swav-pub/HEAD/figures/diagram.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow==7.2.0 2 | numpy==1.19.2 3 | sklearn==0.0 4 | scikit-learn==0.23.2 5 | scipy==1.5.2 6 | tensorboard==1.13.1 7 | tensorboardX==2.1 8 | torch==1.6.0+cu101 9 | torchvision==0.7.0+cu101 10 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import normalized_mutual_info_score 2 | from sklearn.metrics import adjusted_rand_score 3 | import numpy as np 4 | import warnings 5 | from scipy.optimize import linear_sum_assignment as linear_assignment 6 | 7 | ari = adjusted_rand_score 8 | nmi = normalized_mutual_info_score 9 | 10 | 11 | def acc(y_true, y_pred, detailed=False): 12 | def warn(*args, **kwargs): 13 | pass 14 | 15 | warnings.warn = warn 16 | y_true = y_true.astype(np.int64) 17 | assert y_pred.size == y_true.size 18 | D = max(y_pred.max(), y_true.max()) + 1 19 | w = np.zeros((D, D), dtype=np.int64) 20 | for i in range(y_pred.size): 21 | w[y_pred[i], y_true[i]] += 1 22 | ind = linear_assignment(w.max() - w) 23 | if detailed: 24 | return sum([w[i, j] for i, j in zip(ind[0], ind[1])]) * 1.0 / y_pred.size, w, ind 25 | else: 26 | return sum([w[i, j] for i, j in zip(ind[0], ind[1])]) * 1.0 / y_pred.size 27 | -------------------------------------------------------------------------------- /costs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils import shoot_infs_v2 4 | 5 | eps = 1e-17 6 | 7 | 8 | def entropy(x1, x2=None): 9 | x2 = x1 if x2 is None else x2 10 | return -(x1 * torch.log(x2.clamp(eps, 1.))) 11 | 12 | 13 | def mi(y, u, b=4, a=1e-2): 14 | lgt_reg = a * torch.relu(torch.abs(y) - 5.).sum(-1).mean() 15 | py, pu = torch.softmax(y, -1), torch.softmax(u, -1) 16 | p_yu = torch.matmul(py.T, pu) # joint estimate, k x k' 17 | p_yu /= p_yu.sum() # normalize to sum 1 18 | p_u = p_yu.sum(0).view(1, -1) # marginal p_u 19 | p_y = p_yu.sum(1).view(-1, 1) # marginal p_y 20 | h_uy = (p_yu * (torch.log(p_u) - torch.log(p_yu))).sum() # conditional entropy H(Y|U) 21 | hy = b * (p_yu * torch.log(p_y)).sum() # b-weighted negative marginal entropy, -b * H(Y) 22 | return h_uy + hy, lgt_reg 23 | 24 | 25 | def sinkhorn(Q, nmb_iters): 26 | with torch.no_grad(): 27 | Q = shoot_infs_v2(Q) 28 | sum_Q = torch.sum(Q) 29 | Q /= sum_Q 30 | r = torch.ones(Q.shape[0]).cuda() / Q.shape[0] 31 | c = torch.ones(Q.shape[1]).cuda() / Q.shape[1] 32 | for it in range(nmb_iters): 33 | u = torch.sum(Q, dim=1) 34 | u = r / u 35 | u = shoot_infs_v2(u) 36 | Q *= u.unsqueeze(1) 37 | Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0) 38 | return (Q / torch.sum(Q)).float() 39 | -------------------------------------------------------------------------------- /checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from model import Model 4 | 5 | 6 | class Checkpointer: 7 | def 
__init__(self, output_dir=None, filename='imc_swav.cpt'): 8 | self.output_dir = output_dir 9 | self.epoch = 0 10 | self.model = None 11 | self.filename = filename 12 | 13 | def track_new_model(self, model): 14 | self.model = model 15 | 16 | def restore_model_from_checkpoint(self, cpt_path): 17 | ckp = torch.load(cpt_path) 18 | hp = ckp['hyperparams'] 19 | params = ckp['model'] 20 | self.epoch = ckp['epoch'] 21 | 22 | self.model = Model(n_classes=hp['n_classes'], encoder_size=hp['encoder_size'], prototypes=hp['prototypes'], 23 | project_dim=hp['project_dim'], tau=hp['tau'], eps=hp['eps'], model_type=hp['model_type']) 24 | 25 | model_dict = self.model.state_dict() 26 | model_dict.update(params) 27 | params = model_dict 28 | self.model.load_state_dict(params) 29 | 30 | print("***** CHECKPOINTING *****\n" 31 | "Model restored from checkpoint.\n" 32 | "Training epoch {}\n" 33 | "*************************" 34 | .format(self.epoch)) 35 | return self.model 36 | 37 | def _get_state(self): 38 | return { 39 | 'model': self.model.state_dict(), 40 | 'hyperparams': self.model.hyperparams, 41 | 'epoch': self.epoch 42 | } 43 | 44 | def update(self, epoch): 45 | self.epoch = epoch 46 | cpt_path = os.path.join(self.output_dir, self.filename) 47 | torch.save(self._get_state(), cpt_path) 48 | 49 | def get_current_position(self): 50 | return self.epoch 51 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from stats import update_train_accuracies 4 | 5 | 6 | # Evaluate accuracy on test set 7 | def test_model(model, test_loader, device, stats): 8 | # warm up norm layers 9 | _warmup_batchnorm(model, test_loader, device, batches=50, train_loader=False) 10 | 11 | model.eval() 12 | targets, predictions = [], [] 13 | for _, (images, targets_, idxs) in enumerate(test_loader): 14 | images = images.to(device) 15 | val_idxs = targets_ >= 0 16 | with torch.no_grad(): 17 | res_dict = model(x=images, eval_only=True) 18 | predictions.append(res_dict['y'][val_idxs].cpu().numpy()), targets.append(targets_[val_idxs]) 19 | targets, predictions = np.concatenate(targets).ravel(), np.concatenate(predictions).ravel() 20 | model.train() 21 | update_train_accuracies(stats, targets, predictions, 'Test Clustering ') 22 | 23 | 24 | def _warmup_batchnorm(model, data_loader, device, batches=50, train_loader=False): 25 | model.train() 26 | for i, (images, _, idxs) in enumerate(data_loader): 27 | if i == batches: 28 | break 29 | if train_loader: 30 | images = images[0] 31 | _ = model(x=images.to(device), eval_only=True) 32 | 33 | 34 | ''' 35 | The two methods below (sinkhorn, shoot_infs_v2) are based on the SwAV implementation; 36 | credits to https://github.com/facebookresearch/swav 37 | ''' 38 | 39 | 40 | def sinkhorn(Q, nmb_iters): 41 | with torch.no_grad(): 42 | Q = shoot_infs_v2(Q) 43 | sum_Q = torch.sum(Q) 44 | Q /= sum_Q 45 | r = torch.ones(Q.shape[0]).cuda() / Q.shape[0] 46 | c = torch.ones(Q.shape[1]).cuda() / Q.shape[1] 47 | for it in range(nmb_iters): 48 | u = torch.sum(Q, dim=1) 49 | u = r / u 50 | u = shoot_infs_v2(u) 51 | Q *= u.unsqueeze(1) 52 | Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0) 53 | return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float() 54 | 55 | 56 | def shoot_infs_v2(inp_tensor): 57 | """Replaces inf by maximum of tensor""" 58 | mask_inf = torch.isinf(inp_tensor) 59 | if mask_inf.sum() > 0.: 60 | inp_tensor[mask_inf] = 0 61 | m = torch.max(inp_tensor) 62 | 
inp_tensor[mask_inf] = m 63 | return inp_tensor 64 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from stats import AverageMeterSet 4 | from datasets import build_dataset, get_dataset 5 | from checkpoint import Checkpointer 6 | from utils import test_model 7 | import os 8 | 9 | parser = argparse.ArgumentParser(description='IMC-SwAV - Testing') 10 | 11 | parser.add_argument('--cpt_load_path', type=str, default=None, help='path from which to load checkpoint (if available)') 12 | parser.add_argument('--dataset', type=str, default='c10') 13 | parser.add_argument('--path', type=str, default=None, help='Root directory for the dataset') 14 | parser.add_argument('--nmb_workers', type=int, default=8, help='Number of workers on Transformation process') 15 | parser.add_argument("--nmb_crops", type=int, default=[2, 4], nargs="+", 16 | help="list of number of crops (i.e.: [2, 4])") 17 | parser.add_argument("--size_crops", type=int, default=[28, 18], nargs="+", 18 | help="crops resolutions (i.e.: [28, 18])") 19 | parser.add_argument("--max_scale_crops", type=float, default=[1., 0.4], nargs="+", 20 | help="argument in RandomResizedCrop (i.e.: [1., 0.5])") 21 | parser.add_argument("--min_scale_crops", type=float, default=[0.2, 0.08], nargs="+", 22 | help="argument in RandomResizedCrop (i.e.: [0.2, 0.08])") 23 | parser.add_argument('--batch_size', type=int, default=256, help='Batch size (default: 256)') 24 | parser.add_argument("--dev", type=str, help='GPU device number (if applying)') 25 | 26 | args = parser.parse_args() 27 | 28 | if args.dev is not None: 29 | os.environ["CUDA_VISIBLE_DEVICES"] = args.dev 30 | 31 | 32 | def test(model, test_loader, device, stats): 33 | test_model(model, test_loader, device, stats) 34 | 35 | 36 | def main(): 37 | # get the dataset 38 | dataset = get_dataset(args.dataset) 39 | 40 | _, test_loader, num_classes = \ 41 | build_dataset(dataset=dataset, batch_size=args.batch_size, nmb_workers=args.nmb_workers, 42 | nmb_crops=args.nmb_crops, size_crops=args.size_crops, 43 | min_scale_crops=args.min_scale_crops, max_scale_crops=args.max_scale_crops, path=args.path) 44 | checkpointer = Checkpointer() 45 | torch_device = torch.device('cuda') if torch.cuda.device_count() > 0 else torch.device('cpu') 46 | model = checkpointer.restore_model_from_checkpoint(args.cpt_load_path) 47 | model = model.to(torch_device) 48 | 49 | test_stats = AverageMeterSet() 50 | test(model, test_loader, torch_device, test_stats) 51 | stat_str = test_stats.pretty_string() 52 | print(stat_str) 53 | 54 | 55 | if __name__ == "__main__": 56 | print(args) 57 | main() 58 | -------------------------------------------------------------------------------- /task_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from stats import AverageMeterSet 4 | from datasets import build_dataset, get_dataset 5 | from checkpoint import Checkpointer 6 | from utils import test_model 7 | import os 8 | 9 | parser = argparse.ArgumentParser(description='IMC-SwAV - Testing') 10 | 11 | parser.add_argument('--cpt_load_path', type=str, default=None, help='path from which to load checkpoint (if available)') 12 | parser.add_argument('--dataset', type=str, default='c10') 13 | parser.add_argument('--path', type=str, default=None, help='Root directory for the dataset') 14 | parser.add_argument('--nmb_workers', type=int, default=8, help='Number of workers on Transformation process') 15 | 
parser.add_argument("--nmb_crops", type=int, default=[2, 4], nargs="+", 16 | help="list of number of crops (i.e.: [2, 4])") 17 | parser.add_argument("--size_crops", type=int, default=[28, 18], nargs="+", 18 | help="crops resolutions (i.e.: [28, 18])") 19 | parser.add_argument("--max_scale_crops", type=float, default=[1., 0.4], nargs="+", 20 | help="argument in RandomResizedCrop (i.e.: [1., 0.5])") 21 | parser.add_argument("--min_scale_crops", type=float, default=[0.2, 0.08], nargs="+", 22 | help="argument in RandomResizedCrop (i.e.: [0.2, 0.08])") 23 | parser.add_argument('--batch_size', type=int, default=256, help='Batch size (default: 256)') 24 | parser.add_argument("--dev", type=str, help='GPU device number (if applying)') 25 | 26 | args = parser.parse_args() 27 | 28 | if args.dev is not None: 29 | os.environ["CUDA_VISIBLE_DEVICES"] = args.dev 30 | 31 | 32 | def main(): 33 | # get the dataset 34 | dataset = get_dataset(args.dataset) 35 | 36 | _, test_loader, num_classes = \ 37 | build_dataset(dataset=dataset, batch_size=args.batch_size, nmb_workers=args.nmb_workers, 38 | nmb_crops=args.nmb_crops, size_crops=args.size_crops, 39 | min_scale_crops=args.min_scale_crops, max_scale_crops=args.max_scale_crops, path=args.path) 40 | checkpointer = Checkpointer() 41 | torch_device = torch.device('cuda') if torch.cuda.device_count() > 0 else torch.device('cpu') 42 | model = checkpointer.restore_model_from_checkpoint(args.cpt_load_path) 43 | model = model.to(torch_device) 44 | 45 | test_stats = AverageMeterSet() 46 | test_model(model, test_loader, torch_device, test_stats) 47 | stat_str = test_stats.pretty_string() 48 | print(stat_str) 49 | 50 | 51 | if __name__ == "__main__": 52 | print(args) 53 | main() 54 | -------------------------------------------------------------------------------- /stats.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | from metrics import ari, nmi, acc 3 | 4 | ''' 5 | Implementation of classes of AverageMeterSet and StatTracker are based on below repository 6 | https://github.com/Philip-Bachman/amdim-public 7 | ''' 8 | 9 | 10 | class AverageMeterSet: 11 | def __init__(self): 12 | self.sums = {} 13 | self.counts = {} 14 | self.avgs = {} 15 | 16 | def _compute_avgs(self): 17 | for name in self.sums: 18 | self.avgs[name] = float(self.sums[name]) / float(self.counts[name]) 19 | 20 | def update_dict(self, name_val_dict, n=1): 21 | for name, val in name_val_dict.items(): 22 | self.update(name, val, n) 23 | 24 | def update(self, name, value, n=1): 25 | if name not in self.sums: 26 | self.sums[name] = value 27 | self.counts[name] = n 28 | else: 29 | self.sums[name] = self.sums[name] + value 30 | self.counts[name] = self.counts[name] + n 31 | 32 | def pretty_string(self, ignore=('zzz')): 33 | self._compute_avgs() 34 | s = [] 35 | for name, avg in self.avgs.items(): 36 | keep = True 37 | for ign in ignore: 38 | if ign in name: 39 | keep = False 40 | if keep: 41 | s.append('{0:s}: {1:.3f}'.format(name, avg)) 42 | s = ', '.join(s) 43 | return s 44 | 45 | def averages(self, idx, prefix=''): 46 | self._compute_avgs() 47 | return {prefix + name: (avg, idx) for name, avg in self.avgs.items()} 48 | 49 | 50 | class StatTracker: 51 | 52 | def __init__(self, log_name=None, log_dir=None): 53 | assert ((log_name is None) or (log_dir is None)) 54 | if log_dir is None: 55 | self.writer = SummaryWriter(comment=log_name) 56 | else: 57 | print('log_dir: {}'.format(str(log_dir))) 58 | try: 59 | self.writer = 
SummaryWriter(logdir=log_dir) 60 | except: 61 | self.writer = SummaryWriter(log_dir=log_dir) 62 | 63 | def close(self): 64 | self.writer.close() 65 | 66 | def record_stats(self, stat_dict): 67 | for stat_name, stat_vals in stat_dict.items(): 68 | self.writer.add_scalar(stat_name, stat_vals[0], stat_vals[1]) 69 | 70 | 71 | # Helper function for tracking accuracy on training set 72 | def update_train_accuracies(epoch_stats, targets, predictions, name='train'): 73 | val_idxs = targets >= 0 74 | epoch_stats.update(name + ' ACC', acc(targets[val_idxs], predictions[val_idxs]), n=1) 75 | epoch_stats.update(name + ' NMI', nmi(targets[val_idxs], predictions[val_idxs]), n=1) 76 | epoch_stats.update(name + ' ARI', ari(targets[val_idxs], predictions[val_idxs]), n=1) 77 | 78 | -------------------------------------------------------------------------------- /task_train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import torch 4 | import torch.optim as optim 5 | from torch.optim.lr_scheduler import MultiStepLR 6 | import numpy as np 7 | from utils import test_model 8 | from stats import AverageMeterSet, update_train_accuracies 9 | 10 | 11 | def _train(model, optimizer, scheduler_inf, train_loader, test_loader, nmb_crops, stat_tracker, 12 | checkpointer, device, warmup, epochs): 13 | ''' 14 | Training loop for optimizing overall framework 15 | ''' 16 | lr_real = optimizer.param_groups[0]['lr'] 17 | torch.cuda.empty_cache() 18 | 19 | # If mixed precision is on, will add the necessary hooks into the model 20 | # and optimizer for half() conversions 21 | next_epoch = checkpointer.get_current_position() 22 | total_updates = next_epoch * len(train_loader) 23 | # run main training loop 24 | for epoch in range(next_epoch, epochs): 25 | epoch_stats = AverageMeterSet() 26 | time_epoch = time.time() 27 | targets, predictions = [], [] 28 | model.reset_membank_list() 29 | for _, ((aug_imgs, raw_imgs), targets_, idx) in enumerate(train_loader): 30 | 31 | # Perform clustering only on label idxs 32 | val_idxs = targets_ >= 0 33 | targets.append(targets_[val_idxs].numpy()) 34 | aug_imgs = [aug_img.to(device) for aug_img in aug_imgs] 35 | 36 | res_dict = model(x=aug_imgs, eval_only=False, nmb_crops=nmb_crops, eval_idxs=val_idxs) 37 | 38 | # Warmup 39 | if total_updates < warmup: 40 | lr_scale = min(1., float(total_updates + 1) / float(warmup)) 41 | for i, pg in enumerate(optimizer.param_groups): 42 | pg['lr'] = lr_scale * lr_real 43 | 44 | loss_opt = res_dict['swav_loss'] + res_dict['mi_loss'] + res_dict['lgt_reg'] 45 | optimizer.zero_grad() 46 | loss_opt.backward() 47 | 48 | # Stop gradient for prototypes till warmup is over 49 | if total_updates < warmup: 50 | model.prototypes.prototypes.weight.grad = None 51 | optimizer.step() 52 | 53 | epoch_stats.update_dict({'swav_loss': res_dict['swav_loss'].item(), }, n=1) 54 | 55 | # None can be only on STL10, if not enough labelled training instances to evaluate 56 | if res_dict['y'] is not None: 57 | predictions.append(res_dict['y'].cpu().numpy()) 58 | epoch_stats.update_dict({ 59 | 'mi_loss': res_dict['mi_loss'].item(), 60 | 'lgt_reg': res_dict['lgt_reg'].item(), 61 | }, n=1) 62 | total_updates += 1 63 | time_stop = time.time() 64 | spu = (time_stop - time_epoch) 65 | print('Epoch {0:d}, {1:.4f} sec/epoch'.format(epoch, spu)) 66 | # update learning rate 67 | scheduler_inf.step() 68 | targets, predictions = np.concatenate(targets).ravel(), np.concatenate(predictions).ravel() 69 | test_model(model, 
test_loader, device, epoch_stats) 70 | # Evaluation only for the labelled set (in case of STL10) 71 | update_train_accuracies(epoch_stats, targets[:predictions.shape[0]], predictions, 'Train Clustering ') 72 | epoch_str = epoch_stats.pretty_string() 73 | diag_str = '{0:d}: {1:s}'.format(epoch, epoch_str) 74 | print(diag_str) 75 | sys.stdout.flush() 76 | stat_tracker.record_stats(epoch_stats.averages(epoch, prefix='costs/')) 77 | checkpointer.update(epoch + 1) 78 | 79 | 80 | def train_model(model, learning_rate, train_loader, test_loader, nmb_crops, stat_tracker, 81 | checkpointer, device, warmup, epochs, l2_w): 82 | mods = [m for m in model.modules_] 83 | optimizer = optim.Adam([{'params': mod.parameters(), 'lr': learning_rate} for i, mod in enumerate(mods)], 84 | betas=(0.8, 0.999), weight_decay=l2_w) 85 | 86 | scheduler = MultiStepLR(optimizer, milestones=[150, 300, 400], gamma=0.4) 87 | _train(model, optimizer, scheduler, train_loader, test_loader, nmb_crops, stat_tracker, 88 | checkpointer, device, warmup, epochs) 89 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | 6 | from stats import StatTracker 7 | from datasets import build_dataset, get_dataset, get_encoder_size 8 | from model import Model 9 | from checkpoint import Checkpointer 10 | from task_train import train_model 11 | 12 | parser = argparse.ArgumentParser(description='IMC-SwAV - Training') 13 | 14 | parser.add_argument('--dataset', type=str, default='C10') 15 | parser.add_argument('--path', type=str, default=None, help='Root directory for the dataset') 16 | 17 | # Transformation parameters/number of crops sizes. First index reports the high resolution, low resolutions follows. 18 | parser.add_argument('--nmb_workers', type=int, default=8, help='Number of workers on Transformation process') 19 | parser.add_argument("--nmb_crops", type=int, default=[2, 4], nargs="+", 20 | help="list of number of crops (i.e.: [2, 4])") 21 | parser.add_argument("--size_crops", type=int, default=[28, 18], nargs="+", 22 | help="crops resolutions (i.e.: [28, 18])") 23 | parser.add_argument("--max_scale_crops", type=float, default=[1., 0.4], nargs="+", 24 | help="argument in RandomResizedCrop (i.e.: [1., 0.5])") 25 | parser.add_argument("--min_scale_crops", type=float, default=[0.2, 0.08], nargs="+", 26 | help="argument in RandomResizedCrop (i.e.: [0.2, 0.08])") 27 | parser.add_argument('--batch_size', type=int, default=256, help='Batch size (default: 256)') 28 | 29 | # Model and training parameters 30 | parser.add_argument('--tau', type=float, default=0.1, help='Temperature parameter on Softmax (Eq. 2)') 31 | parser.add_argument('--eps', type=float, default=0.05, help='Epsilon scalar of Sinkhorn-Knopp (Eq. 
 3)') 32 | parser.add_argument('--warmup', type=int, default=500, help='Warm-up length in training updates') 33 | parser.add_argument('--epochs', type=int, default=500, help='Number of training epochs') 34 | parser.add_argument('--learning_rate', type=float, default=0.0005, help='Learning rate') 35 | parser.add_argument("--project_dim", type=int, default=128, help="Project embedding dimension") 36 | parser.add_argument("--prototypes", type=int, default=1000, help="Number of prototypes") 37 | parser.add_argument("--model_type", type=str, default='resnet18', help="Type of ResNet") 38 | 39 | # parameters for output, logging, checkpointing, etc 40 | parser.add_argument('--output_dir', type=str, default='./default_run', 41 | help='Storing path for Tensorboard events and checkpoints') 42 | parser.add_argument('--cpt_load_path', type=str, default=None, help='Path of a checkpoint to load (if available)') 43 | parser.add_argument('--cpt_name', type=str, default='imc_swav.cpt', help='Checkpoint name during training') 44 | parser.add_argument('--run_name', type=str, default='default_run', help='Tensorboard summary name') 45 | parser.add_argument("--dev", type=str, help='GPU device number (if applying)') 46 | parser.add_argument("--l2_w", type=float, default=1e-5, help='L2 weight decay') 47 | 48 | args = parser.parse_args() 49 | 50 | if args.dev is not None: 51 | os.environ["CUDA_VISIBLE_DEVICES"] = args.dev 52 | 53 | 54 | def main(): 55 | # create output dir (only if it doesn't exist) 56 | if not os.path.isdir(args.output_dir): 57 | os.mkdir(args.output_dir) 58 | 59 | # get the dataset 60 | dataset = get_dataset(args.dataset) 61 | encoder_size = get_encoder_size(dataset) 62 | 63 | # get a helper object for tensorboard logging 64 | log_dir = os.path.join(args.output_dir, args.run_name) 65 | stat_tracker = StatTracker(log_dir=log_dir) 66 | 67 | # get training and testing loaders 68 | train_loader, test_loader, num_classes = \ 69 | build_dataset(dataset=dataset, batch_size=args.batch_size, nmb_workers=args.nmb_workers, 70 | nmb_crops=args.nmb_crops, size_crops=args.size_crops, 71 | min_scale_crops=args.min_scale_crops, max_scale_crops=args.max_scale_crops, path=args.path) 72 | 73 | torch_device = torch.device('cuda') if torch.cuda.device_count() > 0 else torch.device('cpu') 74 | checkpointer = Checkpointer(args.output_dir, args.cpt_name) 75 | if args.cpt_load_path: 76 | model = checkpointer.restore_model_from_checkpoint(args.cpt_load_path) 77 | else: 78 | # create new model with random parameters 79 | model = Model(n_classes=num_classes, encoder_size=encoder_size, prototypes=args.prototypes, 80 | project_dim=args.project_dim, tau=args.tau, eps=args.eps, model_type=args.model_type) 81 | checkpointer.track_new_model(model) 82 | 83 | model = model.to(torch_device) 84 | 85 | train_model(model, args.learning_rate, train_loader, test_loader, args.nmb_crops, stat_tracker, 86 | checkpointer, torch_device, args.warmup, args.epochs, args.l2_w) 87 | 88 | 89 | if __name__ == "__main__": 90 | print(args) 91 | main() 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Information Maximization Clustering via Multi-View Self-Labelling 2 | 3 | ## Introduction 4 | **This is an implementation of IMC-SwAV, written in Python (version 3.6.9), based on the 5 | manuscript [paper](https://arxiv.org/abs/2103.07368).** 6 | 7 | 8 |
9 | ![IMC-SwAV diagram](figures/diagram.png) 10 |
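For quick reference, the objectives implemented in costs.py can be summarized as follows (this summary and its notation are ours, not text from the paper: β denotes the weight `b` in `costs.mi`, τ the softmax temperature `tau`, and ε the Sinkhorn-Knopp scalar `eps`). Each high-resolution crop yields soft cluster codes **q** from a few Sinkhorn-Knopp iterations applied to exp(**u**/ε), and every other crop is trained to predict those codes with the cross-entropy

$$\ell(\mathbf{q}, \mathbf{p}) = -\sum_{k} q_k \log p_k, \qquad \mathbf{p} = \operatorname{softmax}(\mathbf{u}/\tau),$$

while the classifier head minimizes a weighted mutual-information objective estimated from the joint distribution of class and prototype assignments,

$$\mathcal{L}_{\mathrm{MI}} = H(Y \mid U) - \beta\, H(Y) = -\big[\, I(Y;U) + (\beta - 1)\, H(Y) \,\big].$$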
11 | 12 | 13 | ## Performance 14 | 15 | The reported performance of our proposed model is based on the ResNet18 encoder architecture. 16 | We train IMC-SwAV for 15 independent runs on the training set and report results on the test set only. 17 | 18 | ### Average Performance 19 | 20 | The table below shows the average performance and the corresponding standard deviation. 21 | 22 | Dataset | Acc | NMI | ARI 23 | --- | --- | --- | --- 24 | CIFAR-10 | 89.1 (±0.5) | 81.1 (±0.7) | 79.0 (±1.0) 25 | CIFAR-100-20 | 49.0 (±1.8) | 50.3 (±1.2) | 33.7 (±1.3) 26 | STL10 | 83.1 (±1.0) | 72.9 (±0.9) | 68.5 (±1.4) 27 | Tiny-Imagenet | 27.9 (±0.3) | 48.5 (±2.0) | 14.3 (±2.1) 28 | 29 | ### Best Performance 30 | 31 | The table below reports the best performance recorded by our model. 32 | 33 | Dataset | Acc | NMI | ARI 34 | --- | --- | --- | --- 35 | CIFAR-10 | 89.7 | 81.8 | 80.0 36 | CIFAR-100-20 | 51.9 | 52.7 | 36.1 37 | STL10 | 85.3 | 74.7 | 71.6 38 | Tiny-Imagenet | 28.2 | 52.6 | 14.6 39 | 40 | Below, we separately report the results of IMC-SwAV on the CIFAR-100 experiment (100 classes). 41 | 42 | Dataset | Top-1 ACC | Top-5 ACC | NMI | ARI 43 | --- | --- | --- | --- | --- 44 | CIFAR-100 | 45.1 | 67.5 | 60.8 | 30.7 45 | 46 | ## Usage 47 | 48 | The following hyperparameters apply across all datasets (the default setup/experiments in the submission document): 49 | 50 | Settings related to the multi-crop augmentation \ 51 | --nmb_crops 2 4 \ 52 | --max_scale_crops 1. 0.4 \ 53 | --min_scale_crops 0.2 0.08 54 | 55 | Settings related to SwAV \ 56 | --tau 0.1 \ 57 | --eps 0.05 \ 58 | --project_dim 128 \ 59 | --prototypes 1000 60 | 61 | Settings related to the training \ 62 | --learning_rate 0.0005 \ 63 | --warmup 500 \
64 | --l2_w 1e-5 65 | 66 | Settings related to the dataset
67 | --path ROOT_DIRECTORY_OF_THE_DATASET (the root folder of the dataset) 68 | 69 | To run any of the code, the directory path of the dataset is required; 70 | otherwise the dataset is automatically downloaded to './dataset'. 71 | #### CIFAR-10 72 | 73 | To run the training code: 74 | 75 | ``` 76 | python train.py --dataset C10 --path ./dataset --size_crops 28 18 \ 77 | --output_dir ./c10 --cpt_name c10.cpt 78 | ``` 79 | 80 | #### CIFAR-100/20 81 | 82 | To run the training code: 83 | 84 | ``` 85 | python train.py --dataset C20 --path ./dataset --size_crops 28 18 \ 86 | --output_dir ./c20 --cpt_name c20.cpt 87 | ``` 88 | 89 | #### STL10 90 | 91 | To run the training code: 92 | 93 | ``` 94 | python train.py --dataset STL10 --path ./dataset --size_crops 76 52 \ 95 | --output_dir ./stl10 --cpt_name stl10.cpt 96 | ``` 97 | 98 | #### CIFAR100 99 | 100 | To run the training code: 101 | 102 | ``` 103 | python train.py --dataset C100 --path ./dataset --size_crops 28 18 --batch_size 512 \ 104 | --output_dir ./c100 --cpt_name c100.cpt 105 | ``` 106 | 107 | #### Tiny-Imagenet 108 | 109 | To run the training code: 110 | 111 | ``` 112 | python train.py --dataset tiny --path ./dataset --size_crops 56 36 --batch_size 512 \ 113 | --output_dir ./tiny --cpt_name tiny.cpt 114 | ``` 115 | ## Evaluation of the model 116 | 117 | ##### Example evaluation on CIFAR10/100-20/100: 118 | 119 | The argument '--cpt_load_path' takes the full path of the stored model. 120 | 121 | ``` 122 | python test.py --dataset c10 --path ./dataset --size_crops 28 18 --cpt_load_path ./c10/c10.cpt 123 | ``` 124 | 125 | ``` 126 | python test.py --dataset c20 --path ./dataset --size_crops 28 18 --cpt_load_path ./c20/c20.cpt 127 | ``` 128 | 129 | ``` 130 | python test.py --dataset c100 --path ./dataset --size_crops 28 18 --cpt_load_path ./c100/c100.cpt 131 | ``` 132 | 133 | ##### Example evaluation on STL10: 134 | 135 | ``` 136 | python test.py --dataset STL10 --path ./dataset --size_crops 76 52 --cpt_load_path ./stl10/stl10.cpt 137 | ``` 138 | 139 | ##### Example evaluation on Tiny-Imagenet: 140 | 141 | ``` 142 | python test.py --dataset tiny --path ./dataset --size_crops 56 36 --cpt_load_path ./tiny/tiny.cpt 143 | ``` 144 | 145 | ## Notes 146 | 147 | - During training, each epoch reports the model's performance on the test set (used as a validation set) 148 | and on the training set (training-set performance is computed on the transformed instances). 149 | 150 | - The classifier head is trained and evaluated only on the labelled set of the STL10 dataset. The unlabelled part of STL10 is 151 | used only to train the encoder and prototypes. 152 | 153 | - All tests have been performed with CUDA version 10.1. 154 | 155 | ## Acknowledgement for reference repos 156 | - [AMDIM](https://github.com/Philip-Bachman/amdim-public) 157 | - [SwAV](https://github.com/facebookresearch/swav) 158 | 159 | ## Citation 160 | 161 | ```bibtex 162 | @misc{ntelemis2021information, 163 | title={Information Maximization Clustering via Multi-View Self-Labelling}, 164 | author={Foivos Ntelemis and Yaochu Jin and Spencer A.
Thomas}, 165 | year={2021}, 166 | eprint={2103.07368}, 167 | archivePrefix={arXiv}, 168 | primaryClass={cs.CV} 169 | } 170 | ``` 171 | 172 | 173 | -------------------------------------------------------------------------------- /graphs.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torchvision.models.resnet import Bottleneck, BasicBlock, conv1x1 4 | 5 | '''Modified version of PyTorch's ResNet''' 6 | 7 | 8 | class ResNet(nn.Module): 9 | def __init__(self, block, layers, conv, zero_init_residual=False, groups=1, width_per_group=64, 10 | replace_stride_with_dilation=None, norm_layer=None): 11 | super(ResNet, self).__init__() 12 | if norm_layer is None: 13 | norm_layer = nn.BatchNorm2d 14 | self._norm_layer = norm_layer 15 | self.covn1 = conv  # note: the misspelled attribute name 'covn1' is kept as-is, since renaming it would change checkpoint state-dict keys 16 | self.inplanes = 64 17 | self.dilation = 1 18 | if replace_stride_with_dilation is None: 19 | # each element in the tuple indicates if we should replace 20 | # the 2x2 stride with a dilated convolution instead 21 | replace_stride_with_dilation = [False, False, False] 22 | if len(replace_stride_with_dilation) != 3: 23 | raise ValueError("replace_stride_with_dilation should be None " 24 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 25 | self.groups = groups 26 | self.base_width = width_per_group 27 | 28 | self.layer1 = self._make_layer(block, 64, layers[0]) 29 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 30 | dilate=replace_stride_with_dilation[0]) # 16x16 31 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 32 | dilate=replace_stride_with_dilation[1]) 33 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 34 | dilate=replace_stride_with_dilation[2]) 35 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 36 | # self.fc.init_weights(1.) 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 40 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 41 | nn.init.constant_(m.weight, 1) 42 | nn.init.constant_(m.bias, 0) 43 | 44 | # Zero-initialize the last BN in each residual branch, 45 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
46 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 47 | self.layer_list = nn.ModuleList([self.covn1, self.layer1, self.layer2, self.layer3, self.layer4, self.avgpool]) 48 | if zero_init_residual: 49 | for m in self.modules(): 50 | if isinstance(m, Bottleneck): 51 | nn.init.constant_(m.bn3.weight, 0) 52 | elif isinstance(m, BasicBlock): 53 | nn.init.constant_(m.bn2.weight, 0) 54 | 55 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 56 | norm_layer = self._norm_layer 57 | downsample = None 58 | previous_dilation = self.dilation 59 | if dilate: 60 | self.dilation *= stride 61 | stride = 1 62 | if stride != 1 or self.inplanes != planes * block.expansion: 63 | downsample = nn.Sequential( 64 | conv1x1(self.inplanes, planes * block.expansion, stride), 65 | norm_layer(planes * block.expansion), 66 | ) 67 | 68 | layers = [] 69 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 70 | self.base_width, previous_dilation, norm_layer)) 71 | self.inplanes = planes * block.expansion 72 | for _ in range(1, blocks): 73 | layers.append(block(self.inplanes, planes, groups=self.groups, 74 | base_width=self.base_width, dilation=self.dilation, 75 | norm_layer=norm_layer)) 76 | 77 | return nn.Sequential(*layers) 78 | 79 | def _forward_impl(self, x): 80 | 81 | layer_acts = [x] 82 | for i, layer in enumerate(self.layer_list): 83 | layer_in = layer_acts[-1] 84 | layer_out = layer(layer_in) 85 | layer_acts.append(layer_out) 86 | return layer_acts[-1] 87 | 88 | def forward(self, x): 89 | return self._forward_impl(x) 90 | 91 | 92 | class Projection(nn.Module): 93 | def __init__(self, n_input, n_out=128): 94 | super(Projection, self).__init__() 95 | # self.project = nn.Sequential(nn.Conv2d(n_input, n_input, 1, 1, 0), 96 | # # nn.BatchNorm2d(n_input, affine=True), 97 | # # nn.ReLU(inplace=True), 98 | # nn.Conv2d(n_input, n_out, 1, 1, 0, bias=True)) 99 | self.project = nn.Sequential(nn.Linear(n_input, n_input, bias=True), 100 | # nn.BatchNorm1d(n_input, affine=True), 101 | # nn.ReLU(inplace=True), 102 | nn.Linear(n_input, n_out, bias=True)) 103 | return 104 | 105 | def forward(self, r1_x): 106 | # out = self.project(r1_x) 107 | # out = nn.functional.normalize(out, dim=1, p=2) 108 | return self.project(r1_x) 109 | 110 | 111 | class Prototypes(nn.Module): 112 | def __init__(self, n_input, n_out=1000): 113 | super(Prototypes, self).__init__() 114 | # self.prototypes = nn.Conv2d(n_input, n_out, 1, 1, 0, bias=False) 115 | self.prototypes = nn.Linear(n_input, n_out, bias=False) 116 | return 117 | 118 | def forward(self, r1_x): 119 | r1_x = nn.functional.normalize(r1_x, dim=1, p=2) 120 | # if len(r1_x.shape) != 4: 121 | # r1_x = r1_x.unsqueeze(-1).unsqueeze(-1) 122 | return self.prototypes(r1_x) # .squeeze(-1).squeeze(-1) 123 | 124 | 125 | class Classifier(nn.Module): 126 | def __init__(self, n_input, n_classes, n_hidden=1024): 127 | super(Classifier, self).__init__() 128 | self.n_input = n_input 129 | self.n_classes = n_classes 130 | self.n_hidden = n_hidden 131 | 132 | self.aux_head = nn.Sequential( 133 | nn.Linear(n_input, n_hidden, bias=True), 134 | nn.BatchNorm1d(n_hidden, affine=True), 135 | nn.ReLU(inplace=True), 136 | nn.Linear(n_hidden, n_hidden, bias=True), 137 | nn.BatchNorm1d(n_hidden, affine=True), 138 | nn.ReLU(inplace=True), 139 | nn.Linear(n_hidden, n_classes, bias=True) 140 | # nn.Conv2d(n_input, n_hidden, 1, 1, 0), 141 | # nn.BatchNorm2d(n_hidden, affine=True), 142 | # nn.ReLU(inplace=True), 143 | # nn.Conv2d(n_hidden, 
n_hidden, 1, 1, 0), 144 | # nn.BatchNorm2d(n_hidden, affine=True), 145 | # nn.ReLU(inplace=True), 146 | # nn.Conv2d(n_hidden, n_classes, 1, 1, 0) 147 | ) 148 | return 149 | 150 | def forward(self, r1_x): 151 | # Always detach so to train only omega parameters 152 | r1_x = r1_x.detach() 153 | px = self.aux_head(r1_x) 154 | return px.squeeze(-1).squeeze(-1) 155 | 156 | 157 | class MLPClassifier(nn.Module): 158 | def __init__(self, n_classes, n_input, p=0.1): 159 | super(MLPClassifier, self).__init__() 160 | self.n_input = n_input 161 | self.n_classes = n_classes 162 | 163 | self.block_forward = nn.Sequential( 164 | nn.Dropout(p=p), 165 | nn.Linear(self.n_input, n_classes, bias=True) 166 | ) 167 | 168 | def forward(self, x): 169 | x = x.detach() 170 | x = torch.flatten(x, 1) 171 | logits = self.block_forward(x) 172 | return logits 173 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from graphs import ResNet, BasicBlock, Bottleneck, Projection, Prototypes, Classifier 5 | from utils import sinkhorn 6 | from costs import entropy, mi 7 | 8 | 9 | class Model(nn.Module): 10 | def __init__(self, n_classes, encoder_size=32, prototypes=1000, project_dim=128, 11 | tau=0.1, eps=0.05, model_type='resnet18'): 12 | super(Model, self).__init__() 13 | 14 | self.hyperparams = { 15 | 'n_classes': n_classes, 16 | 'encoder_size': encoder_size, 17 | 'prototypes': prototypes, # k' number of prototypes 18 | 'project_dim': project_dim, # projection head's dimension 19 | 'tau': tau, # Tau parameter of Softmax smoothness (Eq.2) 20 | 'eps': eps, # epsilon (Eq.3) 21 | 'model_type': model_type 22 | } 23 | 24 | dummy_batch = torch.zeros((2, 3, encoder_size, encoder_size)) 25 | 26 | # encoder that provides multiscale features 27 | self.encoder = Encoder(encoder_size=encoder_size, model_type=model_type) 28 | rkhs_1 = self.encoder(dummy_batch) 29 | self.encoder = nn.DataParallel(self.encoder) 30 | self.project = Projection(rkhs_1.size(1), project_dim) 31 | self.prototypes = Prototypes(project_dim, prototypes) 32 | self.auxhead = Classifier(rkhs_1.size(1), n_classes) # Classifier 33 | self.modules_ = [self.encoder.module, self.prototypes, self.project, self.auxhead] 34 | self._t, self._e = tau, eps 35 | self._z_bank = [] 36 | self._u_bank = [] 37 | self.counter = 0 38 | 39 | def _encode(self, res_dict, aug_imgs, num_crops): 40 | res_dict['z'], res_dict['u'] = [], [] 41 | b = aug_imgs[0].size()[0] 42 | 43 | with torch.no_grad(): 44 | # l2 normalization for prototypes 45 | w = self.prototypes.prototypes.weight.data.clone() 46 | w = nn.functional.normalize(w, dim=1, p=2) 47 | self.prototypes.prototypes.weight.copy_(w) 48 | 49 | for aug_imgs_ in aug_imgs: 50 | emb = self.encoder(aug_imgs_) 51 | emb_projected = self.project(emb) 52 | ux = self.prototypes(emb_projected) 53 | res_dict['z'].append(emb) 54 | res_dict['u'].append(ux) 55 | # ''' 56 | swav_loss = [] 57 | for i in range(num_crops[0]): 58 | with torch.no_grad(): 59 | # Sinkhorn knopp algorithm (Formulation of eq. 
3 & 4, based on SWAV implementation) 60 | # (SWAV URL: https://arxiv.org/abs/2006.09882) 61 | q = res_dict['u'][i].detach().clone() 62 | q = torch.exp(q / self._e) 63 | q = sinkhorn(q.T, 3) # [-batch_size:] 64 | for p, px in enumerate(res_dict['u']): 65 | if p != i: 66 | # Equation 1 67 | swav_loss.append(entropy(q, torch.softmax(px / self._t, -1)).sum(-1).mean()) 68 | res_dict['swav_loss'] = torch.stack(swav_loss).sum() / len(swav_loss) 69 | 70 | return res_dict 71 | 72 | def encode(self, imgs, res_dict): 73 | 74 | with torch.no_grad(): 75 | z = self.encoder(imgs) 76 | res_dict['y'] = torch.flatten(self.auxhead(z), 1).argmax(-1) 77 | return res_dict 78 | 79 | def forward(self, x, nmb_crops=[2], eval_idxs=None, eval_only=False): 80 | 81 | # dict for returning various values 82 | res_dict = {} 83 | if eval_only: 84 | return self.encode(x, res_dict) 85 | 86 | res_dict = self._encode(res_dict, x, nmb_crops) 87 | 88 | ''' 89 | SLT10 contains instances of unlabelled and labelled set.. 90 | Because our method is trained in online mode together with encoder part. 91 | Due to the ratio of instances between unlabelled (100.000) and labelled set (5000 on training set), 92 | the training batch size contains a large number of unlabelled instances and a very small number 93 | of training instances. 94 | Hence, we use a short temporary bank to store representations (z) till we can reach the same number 95 | as the batch size of labelled instances. 96 | Afterwards, the membanks are used to train the classifier. 97 | ''' 98 | 99 | # The classifier is trained only on labelled training set 100 | # This actually applies only on STL10 where there is an unlabelled and labelled set 101 | # We train and evaluate the classifier only on labelled set, hence eval_idxs is a boolean array 102 | # for performing training only on labelled idxs 103 | if eval_idxs is None: 104 | eval_idxs = torch.ones(res_dict['z'][0].size(0), dtype=torch.bool) 105 | 106 | # Below lists are membanks to store embeddings and u probability distributions. 107 | # These are only used for training on STL10. 
108 | # As it is mentioned above classifier is trained only on labelled set of STL10 109 | # We just collect the labelled embedding representations till the collection reaches equal number to 110 | # the training batch size 111 | # This happens because a batch size on STL10 contains instances of label and unlabelled set 112 | # Hence each batch size contains a very small number of training instances and we use this collection 113 | # In order to maintain the batch size of classifier on average equals to the encoder batch size 114 | # ''' 115 | 116 | with torch.no_grad(): 117 | self._z_bank.append(torch.cat([z[eval_idxs].unsqueeze(1).detach() for z in res_dict['z']], 1)) 118 | self._u_bank.append(torch.cat([ux[eval_idxs].unsqueeze(1).detach() for ux in res_dict['u']], 1)) 119 | b = res_dict['z'][0].size(0) 120 | self.counter += eval_idxs.sum().item() 121 | mi_loss, lgt_reg, res_dict['y'] = [], [], None 122 | if self.counter >= b: 123 | Z = torch.cat(self._z_bank, 0).unbind(1) 124 | Y = [self.auxhead(z) for z in Z] 125 | U = torch.unbind(torch.cat(self._u_bank), 1) 126 | for j, py_j in enumerate(Y): 127 | for u, pu_i in enumerate(U): 128 | mi_loss_, lgt_reg_ = mi(py_j, pu_i / self._t) # Equation 6 129 | mi_loss.append(mi_loss_), lgt_reg.append(lgt_reg_) 130 | res_dict['y'] = torch.flatten(Y[0], 1).argmax(-1) 131 | self.reset_membank_list() 132 | zero = torch.tensor([0], device=x[0].device.type) 133 | res_dict['mi_loss'] = torch.stack(mi_loss).mean() if len(mi_loss) > 0 else zero 134 | res_dict['lgt_reg'] = torch.stack(lgt_reg).mean() if len(lgt_reg) > 0 else zero 135 | 136 | return res_dict 137 | 138 | # Reset the membanks, this actually applies only on STL10 training because of the unlabelled set 139 | def reset_membank_list(self): 140 | self._z_bank, self._u_bank = [], [] 141 | self.counter = 0 142 | 143 | 144 | class Encoder(nn.Module): 145 | def __init__(self, encoder_size=32, model_type='resnet18'): 146 | super(Encoder, self).__init__() 147 | 148 | # encoding block for local features 149 | print('Using a {}x{} encoder'.format(encoder_size, encoder_size)) 150 | inplanes = 64 151 | if encoder_size == 32: 152 | conv1 = nn.Sequential(nn.Conv2d(3, inplanes, kernel_size=3, stride=1, padding=1, bias=False), 153 | nn.BatchNorm2d(inplanes), 154 | nn.ReLU(inplace=True)) 155 | elif encoder_size == 96 or encoder_size == 64: 156 | conv1 = nn.Sequential(nn.Conv2d(3, inplanes, kernel_size=3, stride=1, padding=1, bias=False), 157 | nn.BatchNorm2d(inplanes), 158 | nn.ReLU(inplace=True), 159 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0)) 160 | 161 | else: 162 | raise RuntimeError("Could not build encoder." 
163 | "Encoder size {} is not supported".format(encoder_size)) 164 | 165 | if model_type == 'resnet18': 166 | # ResNet18 block 167 | self.model = ResNet(BasicBlock, [2, 2, 2, 2], conv1) 168 | elif model_type == 'resnet34': 169 | self.model = ResNet(BasicBlock, [3, 4, 6, 3], conv1) 170 | elif model_type == 'resnet50': 171 | self.model = ResNet(Bottleneck, [3, 4, 6, 3], conv1) 172 | else: 173 | raise RuntimeError("Wrong model type") 174 | 175 | print(self.get_param_n()) 176 | 177 | def get_param_n(self): 178 | w = 0 179 | for p in self.model.parameters(): 180 | w += np.product(p.shape) 181 | return w 182 | 183 | def forward(self, x): 184 | return torch.flatten(self.model(x), 1) 185 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sys import platform 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | from torchvision import datasets, transforms 7 | import socket 8 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive 9 | from torchvision.datasets.vision import VisionDataset 10 | import pickle 11 | from collections import defaultdict 12 | from torch.utils.data import Dataset 13 | from tqdm.autonotebook import tqdm 14 | 15 | training_datasets = ['C10', 'C20', 'C100', 'STL10', 'TINY'] 16 | 17 | 18 | def get_encoder_size(dataset_name): 19 | if dataset_name in training_datasets[:3]: 20 | return 32 21 | if dataset_name == training_datasets[-2]: 22 | return 96 23 | if dataset_name == training_datasets[-1]: 24 | return 64 25 | raise RuntimeError("Error get encoder size, unknown setup size: {}".format(dataset_name)) 26 | 27 | 28 | def get_dataset(dataset_name): 29 | dataset_name = dataset_name.upper() 30 | if dataset_name in training_datasets: 31 | return dataset_name 32 | raise KeyError("Unknown dataset '" + dataset_name + "'. 
Must be one of " 33 | + ', '.join([name for name in training_datasets])) 34 | 35 | 36 | class Transforms: 37 | 38 | def __init__(self, nmb_crops, size_crops, min_scale_crops, max_scale_crops, mu, std): 39 | assert len(size_crops) == len(nmb_crops) 40 | assert len(min_scale_crops) == len(nmb_crops) 41 | assert len(max_scale_crops) == len(nmb_crops) 42 | 43 | flip = transforms.RandomHorizontalFlip(p=0.5) 44 | normalize = transforms.Normalize(mean=mu, std=std) 45 | 46 | col_jitter = transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8) 47 | rnd_gray = transforms.RandomGrayscale(p=0.25) 48 | 49 | trans = [] 50 | for i in range(len(size_crops)): 51 | randomresizedcrop = transforms.RandomResizedCrop( 52 | size_crops[i], scale=(min_scale_crops[i], max_scale_crops[i])) 53 | trans.extend([transforms.Compose([ 54 | flip, randomresizedcrop, col_jitter, rnd_gray, transforms.ToTensor(), normalize])] * nmb_crops[i]) 55 | self.train_transform = trans 56 | 57 | self.test_transform = transforms.Compose([transforms.ToTensor(), normalize]) 58 | 59 | def __call__(self, inp): 60 | multi_crops = list(map(lambda trans: trans(inp), self.train_transform)) 61 | return multi_crops, self.test_transform(inp) 62 | 63 | 64 | def build_dataset(dataset, batch_size, nmb_workers, nmb_crops, size_crops, min_scale_crops, max_scale_crops, path): 65 | if dataset == training_datasets[0]: 66 | num_classes = 10 67 | mu = [0.4914, 0.4822, 0.4465] 68 | std = [0.2023, 0.1994, 0.2010] 69 | train_transform = Transforms(nmb_crops, size_crops, min_scale_crops, max_scale_crops, mu, std) 70 | test_transform = train_transform.test_transform 71 | train_dataset = CIFAR10(root=path, train=True, transform=train_transform, download=True) 72 | test_dataset = CIFAR10(root=path, train=False, transform=test_transform, download=True) 73 | 74 | elif dataset in training_datasets[1:3]: 75 | num_classes = 20 if dataset == training_datasets[1] else 100 76 | coarse = True if dataset == training_datasets[1] else False 77 | mu = [0.5071, 0.4867, 0.4408] 78 | std = [0.2675, 0.2565, 0.2761] 79 | train_transform = Transforms(nmb_crops, size_crops, min_scale_crops, max_scale_crops, mu, std) 80 | test_transform = train_transform.test_transform 81 | train_dataset = CIFAR100(root=path, train=True, transform=train_transform, download=True, c100_coarse=coarse) 82 | test_dataset = CIFAR100(root=path, train=False, transform=test_transform, download=True, c100_coarse=coarse) 83 | 84 | elif dataset == training_datasets[-2]: 85 | num_classes = 10 86 | mu = [0.43, 0.42, 0.39] 87 | std = [0.27, 0.26, 0.27] 88 | train_transform = Transforms(nmb_crops, size_crops, min_scale_crops, max_scale_crops, mu, std) 89 | test_transform = train_transform.test_transform 90 | train_dataset = STL10(root=path, split='train+unlabeled', transform=train_transform, download=True) 91 | test_dataset = STL10(root=path, split='test', transform=test_transform, download=True) 92 | elif dataset == training_datasets[-1]: 93 | num_classes = 200 94 | mu = [0.485, 0.456, 0.406] 95 | std = [0.229, 0.224, 0.225] 96 | train_transform = Transforms(nmb_crops, size_crops, min_scale_crops, max_scale_crops, mu, std) 97 | test_transform = train_transform.test_transform 98 | train_dataset = TinyImageNetDataset(path, transform=train_transform, download=False, preload=True) 99 | test_dataset = TinyImageNetDataset(path, mode='val', transform=test_transform, download=False, 100 | preload=False) 101 | else: 102 | raise RuntimeError("Error not supported dataset {}".format(dataset)) 103 | 104 | 
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, 105 | pin_memory=True, drop_last=True, num_workers=nmb_workers) 106 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, 107 | pin_memory=True, drop_last=False, num_workers=nmb_workers) 108 | 109 | return train_loader, test_loader, num_classes 110 | 111 | 112 | ''' 113 | Overriding PyTorch's CIFAR-10 and CIFAR-100 datasets to be able to provide the coarse labels (CIFAR-20). 114 | Additionally, all vision datasets below return the index of the iterated instance, for reference only. 115 | ''' 116 | 117 | 118 | class CIFAR10(VisionDataset): 119 | """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset. 120 | 121 | Args: 122 | root (string): Root directory of dataset where directory 123 | ``cifar-10-batches-py`` exists or will be saved to if download is set to True. 124 | train (bool, optional): If True, creates dataset from training set, otherwise 125 | creates from test set. 126 | transform (callable, optional): A function/transform that takes in an PIL image 127 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 128 | target_transform (callable, optional): A function/transform that takes in the 129 | target and transforms it. 130 | download (bool, optional): If true, downloads the dataset from the internet and 131 | puts it in root directory. If dataset is already downloaded, it is not 132 | downloaded again. 133 | 134 | """ 135 | base_folder = 'cifar-10-batches-py' 136 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 137 | filename = "cifar-10-python.tar.gz" 138 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a' 139 | train_list = [ 140 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 141 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 142 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 143 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 144 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 145 | ] 146 | 147 | test_list = [ 148 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 149 | ] 150 | meta = { 151 | 'filename': 'batches.meta', 152 | 'key': 'label_names', 153 | 'md5': '5ff9c542aee3614f3951f8cda6e48888', 154 | } 155 | 156 | def __init__(self, root, train=True, transform=None, target_transform=None, 157 | download=False, c100_coarse=True): 158 | 159 | super(CIFAR10, self).__init__(root, transform=transform, target_transform=target_transform) 160 | 161 | self.train = train # training set or test set 162 | 163 | if download: 164 | self.download() 165 | 166 | if not self._check_integrity(): 167 | raise RuntimeError('Dataset not found or corrupted.'
+ 168 | ' You can use download=True to download it') 169 | 170 | if self.train: 171 | downloaded_list = self.train_list 172 | else: 173 | downloaded_list = self.test_list 174 | 175 | self.data = [] 176 | self.targets = [] 177 | 178 | # now load the picked numpy arrays 179 | for file_name, checksum in downloaded_list: 180 | file_path = os.path.join(self.root, self.base_folder, file_name) 181 | with open(file_path, 'rb') as f: 182 | entry = pickle.load(f, encoding='latin1') 183 | self.data.append(entry['data']) 184 | if 'labels' in entry: 185 | self.targets.extend(entry['labels']) 186 | else: 187 | if c100_coarse is True: 188 | self.targets.extend(entry['coarse_labels']) 189 | self.meta['key'] = self.meta['key2'] 190 | else: 191 | self.targets.extend(entry['fine_labels']) 192 | self.meta['key'] = self.meta['key1'] 193 | 194 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 195 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 196 | self._load_meta() 197 | 198 | def _load_meta(self): 199 | path = os.path.join(self.root, self.base_folder, self.meta['filename']) 200 | if not check_integrity(path, self.meta['md5']): 201 | raise RuntimeError('Dataset metadata file not found or corrupted.' + 202 | ' You can use download=True to download it') 203 | with open(path, 'rb') as infile: 204 | data = pickle.load(infile, encoding='latin1') 205 | self.classes = data[self.meta['key']] 206 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 207 | 208 | def __getitem__(self, index): 209 | """ 210 | Args: 211 | index (int): Index 212 | 213 | Returns: 214 | tuple: (image, target) where target is index of the target class. 215 | """ 216 | img, target = self.data[index], self.targets[index] 217 | 218 | # doing this so that it is consistent with all other datasets 219 | # to return a PIL Image 220 | img = Image.fromarray(img) 221 | 222 | if self.transform is not None: 223 | img = self.transform(img) 224 | 225 | if self.target_transform is not None: 226 | target = self.target_transform(target) 227 | 228 | return img, target, index 229 | 230 | def __len__(self): 231 | return len(self.data) 232 | 233 | def _check_integrity(self): 234 | root = self.root 235 | for fentry in (self.train_list + self.test_list): 236 | filename, md5 = fentry[0], fentry[1] 237 | fpath = os.path.join(root, self.base_folder, filename) 238 | if not check_integrity(fpath, md5): 239 | return False 240 | return True 241 | 242 | def download(self): 243 | if self._check_integrity(): 244 | print('Files already downloaded and verified') 245 | return 246 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) 247 | 248 | def extra_repr(self): 249 | return "Split: {}".format("Train" if self.train is True else "Test") 250 | 251 | 252 | class CIFAR100(CIFAR10): 253 | """`CIFAR100 `_ Dataset. 254 | 255 | This is a subclass of the `CIFAR10` Dataset. 
256 | """ 257 | base_folder = 'cifar-100-python' 258 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 259 | filename = "cifar-100-python.tar.gz" 260 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 261 | train_list = [ 262 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 263 | ] 264 | 265 | test_list = [ 266 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 267 | ] 268 | meta = { 269 | 'filename': 'meta', 270 | 'key1': 'fine_label_names', 271 | 'key2': 'coarse_label_names', 272 | 'md5': '7973b15100ade9c7d40fb424638fde48', 273 | } 274 | 275 | 276 | class STL10(datasets.STL10): 277 | 278 | def __getitem__(self, index): 279 | """ 280 | Args: 281 | index (int): Index 282 | 283 | Returns: 284 | tuple: (image, target) where target is index of the target class. 285 | """ 286 | if self.labels is not None: 287 | img, target = self.data[index], int(self.labels[index]) 288 | else: 289 | img, target = self.data[index], None 290 | 291 | # doing this so that it is consistent with all other datasets 292 | # to return a PIL Image 293 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 294 | # img = img.crop((8, 8, 88, 88)) 295 | if self.transform is not None: 296 | img = self.transform(img) 297 | 298 | if self.target_transform is not None: 299 | target = self.target_transform(target) 300 | 301 | return img, target, index 302 | 303 | 304 | def download_and_unzip(URL, root_dir): 305 | error_message = "Download is not yet implemented. Please, go to {URL} urself." 306 | raise NotImplementedError(error_message.format(URL)) 307 | 308 | 309 | def _add_channels(img): 310 | if len(img.getbands()) == 1: # third axis is the channels 311 | img = np.expand_dims(np.array(img), -1) 312 | img = np.tile(img, (1, 1, 3)) 313 | img = Image.fromarray(img) 314 | return img 315 | 316 | 317 | """Creates a paths datastructure for the tiny imagenet. 
318 | Args: 319 | root_dir: Where the data is located 320 | download: Download if the data is not there 321 | Members: 322 | label_id: 323 | ids: 324 | nit_to_words: 325 | data_dict: 326 | """ 327 | 328 | 329 | class TinyImageNetPaths: 330 | def __init__(self, root_dir, download=False): 331 | if download: 332 | download_and_unzip('http://cs231n.stanford.edu/tiny-imagenet-200.zip', 333 | root_dir) 334 | train_path = os.path.join(root_dir, 'train') 335 | val_path = os.path.join(root_dir, 'val') 336 | test_path = os.path.join(root_dir, 'test') 337 | 338 | wnids_path = os.path.join(root_dir, 'wnids.txt') 339 | words_path = os.path.join(root_dir, 'words.txt') 340 | 341 | self._make_paths(train_path, val_path, test_path, 342 | wnids_path, words_path) 343 | 344 | def _make_paths(self, train_path, val_path, test_path, 345 | wnids_path, words_path): 346 | self.ids = [] 347 | with open(wnids_path, 'r') as idf: 348 | for nid in idf: 349 | nid = nid.strip() 350 | self.ids.append(nid) 351 | self.nid_to_words = defaultdict(list) 352 | with open(words_path, 'r') as wf: 353 | for line in wf: 354 | nid, labels = line.split('\t') 355 | labels = list(map(lambda x: x.strip(), labels.split(','))) 356 | self.nid_to_words[nid].extend(labels) 357 | 358 | self.paths = { 359 | 'train': [], # [img_path, id, nid, box] 360 | 'val': [], # [img_path, id, nid, box] 361 | 'test': [] # img_path 362 | } 363 | 364 | # Get the test paths 365 | self.paths['test'] = list(map(lambda x: os.path.join(test_path, x), 366 | os.listdir(test_path))) 367 | # Get the validation paths and labels 368 | with open(os.path.join(val_path, 'val_annotations.txt')) as valf: 369 | for line in valf: 370 | fname, nid, x0, y0, x1, y1 = line.split() 371 | fname = os.path.join(val_path, 'images', fname) 372 | bbox = int(x0), int(y0), int(x1), int(y1) 373 | label_id = self.ids.index(nid) 374 | self.paths['val'].append((fname, label_id, nid, bbox)) 375 | 376 | # Get the training paths 377 | train_nids = os.listdir(train_path) 378 | for nid in train_nids: 379 | anno_path = os.path.join(train_path, nid, nid + '_boxes.txt') 380 | imgs_path = os.path.join(train_path, nid, 'images') 381 | label_id = self.ids.index(nid) 382 | with open(anno_path, 'r') as annof: 383 | for line in annof: 384 | fname, x0, y0, x1, y1 = line.split() 385 | fname = os.path.join(imgs_path, fname) 386 | bbox = int(x0), int(y0), int(x1), int(y1) 387 | self.paths['train'].append((fname, label_id, nid, bbox)) 388 | 389 | 390 | """Datastructure for the tiny image dataset. 
391 | Args: 392 | root_dir: Root directory for the data 393 | mode: One of "train", "test", or "val" 394 | preload: Preload into memory 395 | load_transform: Transformation to use at the preload time 396 | transform: Transformation to use at the retrieval time 397 | download: Download the dataset 398 | Members: 399 | tinp: Instance of the TinyImageNetPaths 400 | img_data: Image data 401 | label_data: Label data 402 | """ 403 | 404 | 405 | class TinyImageNetDataset(Dataset): 406 | def __init__(self, root_dir, mode='train', preload=True, load_transform=None, 407 | transform=None, download=False, max_samples=None): 408 | tinp = TinyImageNetPaths(root_dir, download) 409 | self.mode = mode 410 | self.label_idx = 1 # from [image, id, nid, box] 411 | self.preload = preload 412 | self.transform = transform 413 | self.transform_results = dict() 414 | 415 | self.IMAGE_SHAPE = (64, 64, 3) 416 | 417 | self.img_data = [] 418 | self.label_data = [] 419 | 420 | self.max_samples = max_samples 421 | self.samples = tinp.paths[mode] 422 | self.samples_num = len(self.samples) 423 | 424 | if self.max_samples is not None: 425 | self.samples_num = min(self.max_samples, self.samples_num) 426 | self.samples = np.random.permutation(self.samples)[:self.samples_num] 427 | 428 | if self.preload: 429 | load_desc = "Preloading {} data...".format(mode) 430 | self.img_data = {} # np.zeros((self.samples_num,) + self.IMAGE_SHAPE, dtype=np.float32) 431 | self.label_data = np.zeros((self.samples_num,), dtype=np.int) 432 | for idx in tqdm(range(self.samples_num), desc=load_desc): 433 | s = self.samples[idx] 434 | # img = imageio.imread(s[0]) 435 | img_ = Image.open(s[0]) 436 | img = img_.copy() 437 | img = _add_channels(img) 438 | img_.close() 439 | self.img_data[idx] = img 440 | if mode != 'test': 441 | self.label_data[idx] = s[self.label_idx] 442 | 443 | if load_transform: 444 | for lt in load_transform: 445 | result = lt(self.img_data, self.label_data) 446 | self.img_data, self.label_data = result[:2] 447 | if len(result) > 2: 448 | self.transform_results.update(result[2]) 449 | 450 | def __len__(self): 451 | return self.samples_num 452 | 453 | def __getitem__(self, idx): 454 | if self.preload: 455 | img = self.img_data[idx] 456 | target = None if self.mode == 'test' else self.label_data[idx] 457 | else: 458 | s = self.samples[idx] 459 | # img = imageio.imread(s[0]) 460 | img = Image.open(s[0]) 461 | img = _add_channels(img) 462 | target = None if self.mode == 'test' else s[self.label_idx] 463 | # img = img.crop((4, 4, 60, 60)) 464 | 465 | # to return a PIL Image 466 | # img = Image.fromarray(np.transpose(img, (1, 2, 0))) 467 | if self.transform: 468 | img = self.transform(img) 469 | return img, target, idx 470 | 471 | 472 | dir_structure_help = r""" 473 | TinyImageNetPath 474 | ├── test 475 | │ └── images 476 | │ ├── test_0.JPEG 477 | │ ├── t... 478 | │ └── ... 479 | ├── train 480 | │ ├── n01443537 481 | │ │ ├── images 482 | │ │ │ ├── n01443537_0.JPEG 483 | │ │ │ ├── n... 484 | │ │ │ └── ... 485 | │ │ └── n01443537_boxes.txt 486 | │ ├── n01629819 487 | │ │ ├── images 488 | │ │ │ ├── n01629819_0.JPEG 489 | │ │ │ ├── n... 490 | │ │ │ └── ... 491 | │ │ └── n01629819_boxes.txt 492 | │ ├── n... 493 | │ │ ├── images 494 | │ │ │ ├── ... 495 | │ │ │ └── ... 496 | ├── val 497 | │ ├── images 498 | │ │ ├── val_0.JPEG 499 | │ │ ├── v... 500 | │ │ └── ... 501 | │ └── val_annotations.txt 502 | ├── wnids.txt 503 | └── words.txt 504 | """ 505 | --------------------------------------------------------------------------------
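As a usage illustration for the directory structure above, here is a minimal sketch of loading the Tiny-ImageNet validation split with the `TinyImageNetDataset` class from datasets.py (the root path and the single-view transform are illustrative assumptions; training itself builds multi-crop views through the `Transforms` class):

```python
import torch
from torchvision import transforms
from datasets import TinyImageNetDataset

# Single-view preprocessing for illustration only; the training pipeline uses
# the multi-crop Transforms class defined in datasets.py.
t = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Hypothetical root path: assumes the archive was extracted to
# ./dataset/tiny-imagenet-200 following the directory structure shown above.
val_set = TinyImageNetDataset('./dataset/tiny-imagenet-200', mode='val',
                              preload=False, transform=t)
loader = torch.utils.data.DataLoader(val_set, batch_size=256, shuffle=False,
                                     num_workers=4)

images, targets, idxs = next(iter(loader))  # each item is an (image, label, index) triple
print(images.shape, targets[:8])
```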