├── figures
│   └── diagram.png
├── requirements.txt
├── metrics.py
├── costs.py
├── checkpoint.py
├── utils.py
├── test.py
├── task_test.py
├── stats.py
├── task_train.py
├── train.py
├── README.md
├── graphs.py
├── model.py
└── datasets.py
/figures/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foiv0s/imc-swav-pub/HEAD/figures/diagram.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Pillow==7.2.0
2 | numpy==1.19.2
3 | sklearn==0.0
4 | scikit-learn==0.23.2
5 | scipy==1.5.2
6 | tensorboard==1.13.1
7 | tensorboardX==2.1
8 | torch==1.6.0+cu101
9 | torchvision==0.7.0+cu101
10 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import normalized_mutual_info_score
2 | from sklearn.metrics import adjusted_rand_score
3 | import numpy as np
4 | import warnings
5 | from scipy.optimize import linear_sum_assignment as linear_assignment
6 |
7 | ari = adjusted_rand_score
8 | nmi = normalized_mutual_info_score
9 |
10 |
11 | def acc(y_true, y_pred, detailed=False):
12 | def warn(*args, **kwargs):
13 | pass
14 |
15 |     warnings.warn = warn  # mute warnings raised during the metric computation
16 | y_true = y_true.astype(np.int64)
17 | assert y_pred.size == y_true.size
18 | D = max(y_pred.max(), y_true.max()) + 1
19 | w = np.zeros((D, D), dtype=np.int64)
20 | for i in range(y_pred.size):
21 | w[y_pred[i], y_true[i]] += 1
22 | ind = linear_assignment(w.max() - w)
23 | if detailed:
24 | return sum([w[i, j] for i, j in zip(ind[0], ind[1])]) * 1.0 / y_pred.size, w, ind
25 | else:
26 | return sum([w[i, j] for i, j in zip(ind[0], ind[1])]) * 1.0 / y_pred.size
27 |
--------------------------------------------------------------------------------
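A minimal usage sketch (not part of the repository files): the clustering metrics are permutation-invariant, so predictions that match the ground truth up to a relabelling score perfectly; `acc` recovers the best label mapping with the Hungarian algorithm (`linear_sum_assignment`).

```python
import numpy as np
from metrics import acc, nmi, ari

# Predictions agree with the ground truth up to swapping labels 0 and 1.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([1, 1, 0, 0, 2, 2])

print(acc(y_true, y_pred))  # 1.0 -- Hungarian matching recovers the permutation
print(nmi(y_true, y_pred))  # 1.0
print(ari(y_true, y_pred))  # 1.0
```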
/costs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from utils import shoot_infs_v2
4 |
5 | eps = 1e-17
6 |
7 |
8 | def entropy(x1, x2=None):
9 | x2 = x1 if x2 is None else x2
10 | return -(x1 * torch.log(x2.clamp(eps, 1.)))
11 |
12 |
13 | def mi(y, u, b=4, a=1e-2):
14 |     lgt_reg = a * torch.relu(torch.abs(y) - 5.).sum(-1).mean()  # penalize logits larger than 5 in magnitude
15 | py, pu = torch.softmax(y, -1), torch.softmax(u, -1)
16 | p_yu = torch.matmul(py.T, pu) # k x k’
17 | p_yu /= p_yu.sum() # normalize to sum 1
18 | p_u = p_yu.sum(0).view(1, -1) # marginal p_u
19 | p_y = p_yu.sum(1).view(-1, 1) # marginal p_y
20 |     h_uy = (p_yu * (torch.log(p_u) - torch.log(p_yu))).sum()  # conditional entropy H(Y|U)
21 |     hy = b * (p_yu * torch.log(p_y)).sum()  # negative weighted marginal entropy: -b * H(Y)
22 |     return h_uy + hy, lgt_reg  # minimizing this maximizes b * H(Y) - H(Y|U)
23 |
24 |
25 | def sinkhorn(Q, nmb_iters):
26 | with torch.no_grad():
27 | Q = shoot_infs_v2(Q)
28 | sum_Q = torch.sum(Q)
29 | Q /= sum_Q
30 | r = torch.ones(Q.shape[0]).cuda() / Q.shape[0]
31 |         c = torch.ones(Q.shape[1]).cuda() / Q.shape[1]
32 | for it in range(nmb_iters):
33 | u = torch.sum(Q, dim=1)
34 | u = r / u
35 | u = shoot_infs_v2(u)
36 | Q *= u.unsqueeze(1)
37 | Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
38 | return (Q / torch.sum(Q)).float()
39 |
--------------------------------------------------------------------------------
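A small sanity check of the Sinkhorn-Knopp routine (illustrative only, and assuming a CUDA device, since the implementation allocates its marginals with `.cuda()`): after a few iterations the per-prototype mass becomes roughly uniform, which is what enforces balanced assignments.

```python
import torch
from costs import sinkhorn

# Raw scores for 8 samples over 4 prototypes, sharpened as in the model (exp(u / eps)).
scores = torch.randn(8, 4).cuda()
q = sinkhorn(torch.exp(scores / 0.05), nmb_iters=3)

print(q.sum())   # total mass is normalized to 1
print(q.sum(0))  # per-prototype mass, roughly uniform: ~0.25 each
```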
/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from model import Model
4 |
5 |
6 | class Checkpointer:
7 | def __init__(self, output_dir=None, filename='imc_swav.cpt'):
8 | self.output_dir = output_dir
9 | self.epoch = 0
10 | self.model = None
11 | self.filename = filename
12 |
13 | def track_new_model(self, model):
14 | self.model = model
15 |
16 | def restore_model_from_checkpoint(self, cpt_path):
17 |         ckp = torch.load(cpt_path, map_location='cpu')  # load on CPU; callers move the model to the device
18 | hp = ckp['hyperparams']
19 | params = ckp['model']
20 | self.epoch = ckp['epoch']
21 |
22 | self.model = Model(n_classes=hp['n_classes'], encoder_size=hp['encoder_size'], prototypes=hp['prototypes'],
23 |                            project_dim=hp['project_dim'], tau=hp['tau'], eps=hp['eps'], model_type=hp['model_type'])
24 |
25 | model_dict = self.model.state_dict()
26 | model_dict.update(params)
27 | params = model_dict
28 | self.model.load_state_dict(params)
29 |
30 | print("***** CHECKPOINTING *****\n"
31 | "Model restored from checkpoint.\n"
32 | "Training epoch {}\n"
33 | "*************************"
34 | .format(self.epoch))
35 | return self.model
36 |
37 | def _get_state(self):
38 | return {
39 | 'model': self.model.state_dict(),
40 | 'hyperparams': self.model.hyperparams,
41 | 'epoch': self.epoch
42 | }
43 |
44 | def update(self, epoch):
45 | self.epoch = epoch
46 | cpt_path = os.path.join(self.output_dir, self.filename)
47 | torch.save(self._get_state(), cpt_path)
48 |
49 | def get_current_position(self):
50 | return self.epoch
51 |
--------------------------------------------------------------------------------
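A minimal checkpointing round trip, under the assumption that the default hyperparameters are used (the paths here are illustrative):

```python
import os
from checkpoint import Checkpointer
from model import Model

os.makedirs('./default_run', exist_ok=True)

model = Model(n_classes=10, encoder_size=32, prototypes=1000, project_dim=128)
ckpt = Checkpointer(output_dir='./default_run', filename='imc_swav.cpt')
ckpt.track_new_model(model)
ckpt.update(epoch=1)  # writes ./default_run/imc_swav.cpt with weights + hyperparams

# Later (e.g. in test.py): rebuild the model from the stored hyperparameters.
restored = Checkpointer().restore_model_from_checkpoint('./default_run/imc_swav.cpt')
```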
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from stats import update_train_accuracies
4 |
5 |
6 | # Evaluate accuracy on test set
7 | def test_model(model, test_loader, device, stats):
8 | # warm up norm layers
9 | _warmup_batchnorm(model, test_loader, device, batches=50, train_loader=False)
10 |
11 | model.eval()
12 | targets, predictions = [], []
13 | for _, (images, targets_, idxs) in enumerate(test_loader):
14 | images = images.to(device)
15 | val_idxs = targets_ >= 0
16 | with torch.no_grad():
17 | res_dict = model(x=images, eval_only=True)
18 | predictions.append(res_dict['y'][val_idxs].cpu().numpy()), targets.append(targets_[val_idxs])
19 | targets, predictions = np.concatenate(targets).ravel(), np.concatenate(predictions).ravel()
20 | model.train()
21 | update_train_accuracies(stats, targets, predictions, 'Test Clustering ')
22 |
23 |
24 | def _warmup_batchnorm(model, data_loader, device, batches=50, train_loader=False):
25 | model.train()
26 | for i, (images, _, idxs) in enumerate(data_loader):
27 | if i == batches:
28 | break
29 | if train_loader:
30 | images = images[0]
31 | _ = model(x=images.to(device), eval_only=True)
32 |
33 |
34 | '''
35 | The following two methods (sinkhorn, shoot_infs_v2) are based on the SwAV implementation;
36 | credits to https://github.com/facebookresearch/swav
37 | '''
38 |
39 |
40 | def sinkhorn(Q, nmb_iters):
41 | with torch.no_grad():
42 | Q = shoot_infs_v2(Q)
43 | sum_Q = torch.sum(Q)
44 | Q /= sum_Q
45 | r = torch.ones(Q.shape[0]).cuda() / Q.shape[0]
46 |         c = torch.ones(Q.shape[1]).cuda() / Q.shape[1]
47 | for it in range(nmb_iters):
48 | u = torch.sum(Q, dim=1)
49 | u = r / u
50 | u = shoot_infs_v2(u)
51 | Q *= u.unsqueeze(1)
52 | Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
53 | return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
54 |
55 |
56 | def shoot_infs_v2(inp_tensor):
57 | """Replaces inf by maximum of tensor"""
58 | mask_inf = torch.isinf(inp_tensor)
59 | if mask_inf.sum() > 0.:
60 | inp_tensor[mask_inf] = 0
61 | m = torch.max(inp_tensor)
62 | inp_tensor[mask_inf] = m
63 | return inp_tensor
64 |
--------------------------------------------------------------------------------
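For reference, a tiny demonstration of `shoot_infs_v2`, which keeps the Sinkhorn iterations numerically stable by replacing any `inf` entries with the largest finite value in the tensor:

```python
import torch
from utils import shoot_infs_v2

t = torch.tensor([1.0, float('inf'), 3.0])
print(shoot_infs_v2(t))  # tensor([1., 3., 3.]) -- the inf becomes the finite max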
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from stats import AverageMeterSet
4 | from datasets import build_dataset, get_dataset
5 | from checkpoint import Checkpointer
6 | from utils import test_model
7 | import os
8 |
9 | parser = argparse.ArgumentParser(description='IMC-SwAV - Testing')
10 |
11 | parser.add_argument('--cpt_load_path', type=str, default=None, help='path from which to load checkpoint (if available)')
12 | parser.add_argument('--dataset', type=str, default='c10')
13 | parser.add_argument('--path', type=str, default='./dataset', help='Root directory for the dataset')
14 | parser.add_argument('--nmb_workers', type=int, default=8, help='Number of workers on Transformation process')
15 | parser.add_argument("--nmb_crops", type=int, default=[2, 4], nargs="+",
16 |                     help="list of number of crops (i.e.: [2, 4])")
17 | parser.add_argument("--size_crops", type=int, default=[28, 18], nargs="+",
18 |                     help="crops resolutions (i.e.: [28, 18])")
19 | parser.add_argument("--max_scale_crops", type=float, default=[1., 0.4], nargs="+",
20 |                     help="argument in RandomResizedCrop (i.e.: [1., 0.5])")
21 | parser.add_argument("--min_scale_crops", type=float, default=[0.2, 0.08], nargs="+",
22 |                     help="argument in RandomResizedCrop (i.e.: [0.2, 0.08])")
23 | parser.add_argument('--batch_size', type=int, default=256, help='Batch size (default: 256)')
24 | parser.add_argument("--dev", type=str, help='GPU device number (if applying)')
25 | 
26 | args = parser.parse_args()
27 | 
28 | if args.dev is not None:
29 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.dev
30 | 
31 | 
32 | def test(model, test_loader, device, stats):
33 |     test_model(model, test_loader, device, stats)
34 | 
35 | 
36 | def main():
37 |     # get the dataset
38 |     dataset = get_dataset(args.dataset)
39 | 
40 |     _, test_loader, num_classes = \
41 |         build_dataset(dataset=dataset, batch_size=args.batch_size, nmb_workers=args.nmb_workers,
42 |                       nmb_crops=args.nmb_crops, size_crops=args.size_crops,
43 |                       min_scale_crops=args.min_scale_crops, max_scale_crops=args.max_scale_crops, path=args.path)
44 |     checkpointer = Checkpointer()
45 |     torch_device = torch.device('cuda') if torch.cuda.device_count() > 0 else torch.device('cpu')
46 |     model = checkpointer.restore_model_from_checkpoint(args.cpt_load_path)
47 |     model = model.to(torch_device)
48 | 
49 |     test_stats = AverageMeterSet()
50 |     test(model, test_loader, torch_device, test_stats)
51 |     stat_str = test_stats.pretty_string()
52 |     print(stat_str)
53 | 
54 | 
55 | if __name__ == "__main__":
56 |     print(args)
57 |     main()
58 | 
--------------------------------------------------------------------------------
/task_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from stats import AverageMeterSet
4 | from datasets import build_dataset, get_dataset
5 | from checkpoint import Checkpointer
6 | from utils import test_model
7 | import os
8 |
9 | parser = argparse.ArgumentParser(description='IMC-SwAV - Testing')
10 |
11 | parser.add_argument('--cpt_load_path', type=str, default=None, help='path from which to load checkpoint (if available)')
12 | parser.add_argument('--dataset', type=str, default='c10')
13 | parser.add_argument('--path', type=str, default='./dataset', help='Root directory for the dataset')
14 | parser.add_argument('--nmb_workers', type=int, default=8, help='Number of workers on Transformation process')
15 | parser.add_argument("--nmb_crops", type=int, default=[2, 4], nargs="+",
16 | help="list of number of crops (i.e.: [2, 4])")
17 | parser.add_argument("--size_crops", type=int, default=[28, 18], nargs="+",
18 | help="crops resolutions (i.e.: [28, 18])")
19 | parser.add_argument("--max_scale_crops", type=float, default=[1., 0.4], nargs="+",
20 | help="argument in RandomResizedCrop (i.e.: [1., 0.5])")
21 | parser.add_argument("--min_scale_crops", type=float, default=[0.2, 0.08], nargs="+",
22 | help="argument in RandomResizedCrop (i.e.: [0.2, 0.08])")
23 | parser.add_argument('--batch_size', type=int, default=256, help='Batch size (default: 256)')
24 | parser.add_argument("--dev", type=str, help='GPU device number (if applying)')
25 |
26 | args = parser.parse_args()
27 |
28 | if args.dev is not None:
29 | os.environ["CUDA_VISIBLE_DEVICES"] = args.dev
30 |
31 |
32 | def main():
33 | # get the dataset
34 | dataset = get_dataset(args.dataset)
35 |
36 | _, test_loader, num_classes = \
37 | build_dataset(dataset=dataset, batch_size=args.batch_size, nmb_workers=args.nmb_workers,
38 | nmb_crops=args.nmb_crops, size_crops=args.size_crops,
39 | min_scale_crops=args.min_scale_crops, max_scale_crops=args.max_scale_crops, path=args.path)
40 | checkpointer = Checkpointer()
41 | torch_device = torch.device('cuda') if torch.cuda.device_count() > 0 else torch.device('cpu')
42 | model = checkpointer.restore_model_from_checkpoint(args.cpt_load_path)
43 | model = model.to(torch_device)
44 |
45 | test_stats = AverageMeterSet()
46 | test_model(model, test_loader, torch_device, test_stats)
47 | stat_str = test_stats.pretty_string()
48 | print(stat_str)
49 |
50 |
51 | if __name__ == "__main__":
52 | print(args)
53 | main()
54 |
--------------------------------------------------------------------------------
/stats.py:
--------------------------------------------------------------------------------
1 | from tensorboardX import SummaryWriter
2 | from metrics import ari, nmi, acc
3 |
4 | '''
5 | Implementation of classes of AverageMeterSet and StatTracker are based on below repository
6 | https://github.com/Philip-Bachman/amdim-public
7 | '''
8 |
9 |
10 | class AverageMeterSet:
11 | def __init__(self):
12 | self.sums = {}
13 | self.counts = {}
14 | self.avgs = {}
15 |
16 | def _compute_avgs(self):
17 | for name in self.sums:
18 | self.avgs[name] = float(self.sums[name]) / float(self.counts[name])
19 |
20 | def update_dict(self, name_val_dict, n=1):
21 | for name, val in name_val_dict.items():
22 | self.update(name, val, n)
23 |
24 | def update(self, name, value, n=1):
25 | if name not in self.sums:
26 | self.sums[name] = value
27 | self.counts[name] = n
28 | else:
29 | self.sums[name] = self.sums[name] + value
30 | self.counts[name] = self.counts[name] + n
31 |
32 |     def pretty_string(self, ignore=('zzz',)):
33 | self._compute_avgs()
34 | s = []
35 | for name, avg in self.avgs.items():
36 | keep = True
37 | for ign in ignore:
38 | if ign in name:
39 | keep = False
40 | if keep:
41 | s.append('{0:s}: {1:.3f}'.format(name, avg))
42 | s = ', '.join(s)
43 | return s
44 |
45 | def averages(self, idx, prefix=''):
46 | self._compute_avgs()
47 | return {prefix + name: (avg, idx) for name, avg in self.avgs.items()}
48 |
49 |
50 | class StatTracker:
51 |
52 | def __init__(self, log_name=None, log_dir=None):
53 | assert ((log_name is None) or (log_dir is None))
54 | if log_dir is None:
55 | self.writer = SummaryWriter(comment=log_name)
56 | else:
57 | print('log_dir: {}'.format(str(log_dir)))
58 | try:
59 | self.writer = SummaryWriter(logdir=log_dir)
60 |             except TypeError:
61 | self.writer = SummaryWriter(log_dir=log_dir)
62 |
63 | def close(self):
64 | self.writer.close()
65 |
66 | def record_stats(self, stat_dict):
67 | for stat_name, stat_vals in stat_dict.items():
68 | self.writer.add_scalar(stat_name, stat_vals[0], stat_vals[1])
69 |
70 |
71 | # Helper function for tracking accuracy on training set
72 | def update_train_accuracies(epoch_stats, targets, predictions, name='train'):
73 | val_idxs = targets >= 0
74 | epoch_stats.update(name + ' ACC', acc(targets[val_idxs], predictions[val_idxs]), n=1)
75 | epoch_stats.update(name + ' NMI', nmi(targets[val_idxs], predictions[val_idxs]), n=1)
76 | epoch_stats.update(name + ' ARI', ari(targets[val_idxs], predictions[val_idxs]), n=1)
77 |
78 |
--------------------------------------------------------------------------------
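An illustrative use of `AverageMeterSet`, which accumulates running sums and exposes the averages for printing and for TensorBoard logging:

```python
from stats import AverageMeterSet

meters = AverageMeterSet()
meters.update('loss', 2.0)
meters.update('loss', 4.0)

print(meters.pretty_string())                   # loss: 3.000
print(meters.averages(idx=5, prefix='costs/'))  # {'costs/loss': (3.0, 5)}
```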
/task_train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import torch
4 | import torch.optim as optim
5 | from torch.optim.lr_scheduler import MultiStepLR
6 | import numpy as np
7 | from utils import test_model
8 | from stats import AverageMeterSet, update_train_accuracies
9 |
10 |
11 | def _train(model, optimizer, scheduler_inf, train_loader, test_loader, nmb_crops, stat_tracker,
12 | checkpointer, device, warmup, epochs):
13 | '''
14 |     Training loop for optimizing the overall framework.
15 | '''
16 | lr_real = optimizer.param_groups[0]['lr']
17 | torch.cuda.empty_cache()
18 |
19 |     # Resume from the last checkpointed epoch (0 for a fresh run);
20 |     # total_updates counts optimizer steps for the warm-up schedule.
21 | next_epoch = checkpointer.get_current_position()
22 | total_updates = next_epoch * len(train_loader)
23 | # run main training loop
24 | for epoch in range(next_epoch, epochs):
25 | epoch_stats = AverageMeterSet()
26 | time_epoch = time.time()
27 | targets, predictions = [], []
28 | model.reset_membank_list()
29 | for _, ((aug_imgs, raw_imgs), targets_, idx) in enumerate(train_loader):
30 |
31 | # Perform clustering only on label idxs
32 | val_idxs = targets_ >= 0
33 | targets.append(targets_[val_idxs].numpy())
34 | aug_imgs = [aug_img.to(device) for aug_img in aug_imgs]
35 |
36 | res_dict = model(x=aug_imgs, eval_only=False, nmb_crops=nmb_crops, eval_idxs=val_idxs)
37 |
38 | # Warmup
39 | if total_updates < warmup:
40 | lr_scale = min(1., float(total_updates + 1) / float(warmup))
41 | for i, pg in enumerate(optimizer.param_groups):
42 | pg['lr'] = lr_scale * lr_real
43 |
44 | loss_opt = res_dict['swav_loss'] + res_dict['mi_loss'] + res_dict['lgt_reg']
45 | optimizer.zero_grad()
46 | loss_opt.backward()
47 |
48 | # Stop gradient for prototypes till warmup is over
49 | if total_updates < warmup:
50 | model.prototypes.prototypes.weight.grad = None
51 | optimizer.step()
52 |
53 | epoch_stats.update_dict({'swav_loss': res_dict['swav_loss'].item(), }, n=1)
54 |
55 |             # 'y' can be None only on STL10, when there are not enough labelled training instances to evaluate
56 | if res_dict['y'] is not None:
57 | predictions.append(res_dict['y'].cpu().numpy())
58 | epoch_stats.update_dict({
59 | 'mi_loss': res_dict['mi_loss'].item(),
60 | 'lgt_reg': res_dict['lgt_reg'].item(),
61 | }, n=1)
62 | total_updates += 1
63 | time_stop = time.time()
64 | spu = (time_stop - time_epoch)
65 | print('Epoch {0:d}, {1:.4f} sec/epoch'.format(epoch, spu))
66 | # update learning rate
67 | scheduler_inf.step()
68 | targets, predictions = np.concatenate(targets).ravel(), np.concatenate(predictions).ravel()
69 | test_model(model, test_loader, device, epoch_stats)
70 | # Evaluation only for the labelled set (in case of STL10)
71 | update_train_accuracies(epoch_stats, targets[:predictions.shape[0]], predictions, 'Train Clustering ')
72 | epoch_str = epoch_stats.pretty_string()
73 | diag_str = '{0:d}: {1:s}'.format(epoch, epoch_str)
74 | print(diag_str)
75 | sys.stdout.flush()
76 | stat_tracker.record_stats(epoch_stats.averages(epoch, prefix='costs/'))
77 | checkpointer.update(epoch + 1)
78 |
79 |
80 | def train_model(model, learning_rate, train_loader, test_loader, nmb_crops, stat_tracker,
81 | checkpointer, device, warmup, epochs, l2_w):
82 | mods = [m for m in model.modules_]
83 | optimizer = optim.Adam([{'params': mod.parameters(), 'lr': learning_rate} for i, mod in enumerate(mods)],
84 | betas=(0.8, 0.999), weight_decay=l2_w)
85 |
86 | scheduler = MultiStepLR(optimizer, milestones=[150, 300, 400], gamma=0.4)
87 | _train(model, optimizer, scheduler, train_loader, test_loader, nmb_crops, stat_tracker,
88 | checkpointer, device, warmup, epochs)
89 |
--------------------------------------------------------------------------------
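The warm-up logic in `_train` scales the learning rate linearly over the first `warmup` optimizer updates; a standalone sketch with the default values (illustrative numbers only):

```python
lr_real, warmup = 5e-4, 500

for total_updates in (0, 249, 499, 500):
    lr_scale = min(1., float(total_updates + 1) / float(warmup))
    print(total_updates, lr_scale * lr_real)
# 0 -> 1e-06, 249 -> 2.5e-04, 499 -> 5e-04; from update 500 on, lr stays at lr_real
```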
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 |
6 | from stats import StatTracker
7 | from datasets import build_dataset, get_dataset, get_encoder_size
8 | from model import Model
9 | from checkpoint import Checkpointer
10 | from task_train import train_model
11 |
12 | parser = argparse.ArgumentParser(description='IMC-SwAV - Training')
13 |
14 | parser.add_argument('--dataset', type=str, default='C10')
15 | parser.add_argument('--path', type=str, default='./dataset', help='Root directory for the dataset')
16 |
17 | # Transformation parameters / numbers and sizes of crops. The first index is the high resolution; low resolutions follow.
18 | parser.add_argument('--nmb_workers', type=int, default=8, help='Number of workers on Transformation process')
19 | parser.add_argument("--nmb_crops", type=int, default=[2, 4], nargs="+",
20 | help="list of number of crops (i.e.: [2, 4])")
21 | parser.add_argument("--size_crops", type=int, default=[28, 18], nargs="+",
22 | help="crops resolutions (i.e.: [28, 18])")
23 | parser.add_argument("--max_scale_crops", type=float, default=[1., 0.4], nargs="+",
24 | help="argument in RandomResizedCrop (i.e.: [1., 0.5])")
25 | parser.add_argument("--min_scale_crops", type=float, default=[0.2, 0.08], nargs="+",
26 | help="argument in RandomResizedCrop (i.e.: [0.2, 0.08])")
27 | parser.add_argument('--batch_size', type=int, default=256, help='Batch size (default: 256)')
28 |
29 | # Model and training parameters
30 | parser.add_argument('--tau', type=float, default=0.1, help='Temperature parameter on Softmax (Eq. 2)')
31 | parser.add_argument('--eps', type=float, default=0.05, help='Epsilon scalar of Sinkhorn-Knopp (Eq. 3)')
32 | parser.add_argument('--warmup', type=int, default=500, help='Number of warm-up updates (optimizer steps)')
33 | parser.add_argument('--epochs', type=int, default=500, help='Number of training epochs')
34 | parser.add_argument('--learning_rate', type=float, default=0.0005, help='Learning rate')
35 | parser.add_argument("--project_dim", type=int, default=128, help="Project embedding dimension")
36 | parser.add_argument("--prototypes", type=int, default=1000, help="Number of prototypes")
37 | parser.add_argument("--model_type", type=str, default='resnet18', help="Type of ResNet")
38 |
39 | # parameters for output, logging, checkpointing, etc
40 | parser.add_argument('--output_dir', type=str, default='./default_run',
41 | help='Storing path for Tensorboard events and checkpoints')
42 | parser.add_argument('--cpt_load_path', type=str, default=None, help='Load checkpoint path+name(if available)')
43 | parser.add_argument('--cpt_name', type=str, default='imc_swav.cpt', help='Checkpoint name during training')
44 | parser.add_argument('--run_name', type=str, default='default_run', help='Tensorboard summary name')
45 | parser.add_argument("--dev", type=str, help='GPU device number (if applying)')
46 | parser.add_argument("--l2_w", type=float, default=1e-5, help='l_2 weights')
47 |
48 | args = parser.parse_args()
49 |
50 | if args.dev is not None:
51 | os.environ["CUDA_VISIBLE_DEVICES"] = args.dev
52 |
53 |
54 | def main():
55 | # create output dir (only if it doesn't exist)
56 | if not os.path.isdir(args.output_dir):
57 | os.mkdir(args.output_dir)
58 |
59 | # get the dataset
60 | dataset = get_dataset(args.dataset)
61 | encoder_size = get_encoder_size(dataset)
62 |
63 | # get a helper object for tensorboard logging
64 | log_dir = os.path.join(args.output_dir, args.run_name)
65 | stat_tracker = StatTracker(log_dir=log_dir)
66 |
67 | # get training and testing loaders
68 | train_loader, test_loader, num_classes = \
69 | build_dataset(dataset=dataset, batch_size=args.batch_size, nmb_workers=args.nmb_workers,
70 | nmb_crops=args.nmb_crops, size_crops=args.size_crops,
71 | min_scale_crops=args.min_scale_crops, max_scale_crops=args.max_scale_crops, path=args.path)
72 |
73 | torch_device = torch.device('cuda') if torch.cuda.device_count() > 0 else torch.device('cpu')
74 | checkpointer = Checkpointer(args.output_dir, args.cpt_name)
75 | if args.cpt_load_path:
76 | model = checkpointer.restore_model_from_checkpoint(args.cpt_load_path)
77 | else:
78 | # create new model with random parameters
79 | model = Model(n_classes=num_classes, encoder_size=encoder_size, prototypes=args.prototypes,
80 | project_dim=args.project_dim, tau=args.tau, eps=args.eps, model_type=args.model_type)
81 | checkpointer.track_new_model(model)
82 |
83 | model = model.to(torch_device)
84 |
85 | train_model(model, args.learning_rate, train_loader, test_loader, args.nmb_crops, stat_tracker,
86 | checkpointer, torch_device, args.warmup, args.epochs, args.l2_w)
87 |
88 |
89 | if __name__ == "__main__":
90 | print(args)
91 | main()
92 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Information Maximization Clustering via Multi-View Self-Labelling
2 |
3 | ## Introduction
4 | This is a Python (3.6.9) implementation of IMC-SwAV, based on the manuscript [paper](https://arxiv.org/abs/2103.07368).
5 | 
6 |
7 | ![IMC-SwAV overview](figures/diagram.png)
8 | 
9 | 
10 | 
11 | 
12 | 
13 | ## Performance
14 |
15 | The reported performance of our model is based on the ResNet18 encoder architecture.
16 | We train IMC-SwAV for 15 independent runs on the training set and report results on the test set only.
17 |
18 | ### Average Performance
19 |
20 | Average performance and the corresponding standard deviation over the 15 runs.
21 |
22 | Dataset | Acc | NMI | ARI
23 | --- | --- | --- | ---
24 | CIFAR-10|89.1 (±0.5) | 81.1 (±0.7)| 79.0 (±1.0)
25 | CIFAR-100-20 | 49.0 (±1.8)| 50.3 (±1.2) | 33.7 (±1.3)
26 | STL10| 83.1 (±1.0) | 72.9 (±0.9) | 68.5 (±1.4)
27 | Tiny-Imagenet| 27.9 (±0.3) | 48.5 (±2.0) | 14.3 (±2.1)
28 |
29 | ### Best Performance
30 |
31 | The table below reports the best recorded performance of our model.
32 |
33 | Dataset | Acc | NMI | ARI
34 | --- | --- | --- | ---
35 | CIFAR-10|89.7 | 81.8| 80.0
36 | CIFAR-100-20 | 51.9| 52.7 | 36.1
37 | STL10| 85.3 | 74.7 | 71.6
38 | Tiny-Imagenet| 28.2 | 52.6 | 14.6
39 |
40 | Below, we separately report the result of our proposed IMC-SwAV on the CIFAR-100 experiment (100 classes).
41 |
42 | Dataset | Top-1 ACC | Top-5 ACC | NMI | ARI
43 | --- | --- | --- | --- | ---
44 | CIFAR-100| 45.1 | 67.5 | 60.8 | 30.7
45 |
46 | ## Usage
47 |
48 | The following hyperparameters apply across all datasets (the default setup/experiment in the submission document):
49 |
50 | Settings related to the multi-crop augmentation \
51 | --nmb_crops 2 4 \
52 | --max_scale_crops 1. 0.4 \
53 | --min_scale_crops 0.2 0.08
54 |
55 | Settings related to SwAV \
56 | --tau 0.1 \
57 | --eps 0.05 \
58 | --project_dim 128 \
59 | --prototypes 1000
60 |
61 | Settings related to the training \
62 | --learning_rate 0.0005 \
63 | --warmup 500 \
64 | --l2_w 1e-5
65 |
66 | Settings related to the dataset \
67 | --path ROOT_DIRECTORY_OF_THE_DATASET (the path folder of the dataset)
68 |
69 | To run any of the code, the root directory of the dataset can be given via --path;
70 | otherwise, the dataset is automatically downloaded to './dataset'.
71 | #### CIFAR-10
72 |
73 | To run the training code:
74 |
75 | ```
76 | python train.py --dataset C10 --path ./dataset --size_crops 28 18 \
77 | --output_dir ./c10 --cpt_name c10.cpt
78 | ```
79 |
80 | #### CIFAR-100/20
81 |
82 | To run the training code:
83 |
84 | ```
85 | python train.py --dataset C20 --path ./dataset --size_crops 28 18 \
86 | --output_dir ./c20 --cpt_name c20.cpt
87 | ```
88 |
89 | #### STL10
90 |
91 | To run the training code:
92 |
93 | ```
94 | python train.py --dataset STL10 --path ./dataset --size_crops 76 52 \
95 |                 --output_dir ./stl10 --cpt_name stl10.cpt
96 | ```
97 |
98 | #### CIFAR100
99 |
100 | To run the training code:
101 |
102 | ```
103 | python train.py --dataset C100 --path ./dataset --size_crops 28 18 --batch_size 512 \
104 | --output_dir ./c100 --cpt_name c100.cpt
105 | ```
106 |
107 | #### Tiny-Imagenet
108 |
109 | To run the training code:
110 |
111 | ```
112 | python train.py --dataset tiny --path ./dataset --size_crops 56 36 --batch_size 512 \
113 | --output_dir ./tiny --cpt_name tiny.cpt
114 | ```
115 | ## Evaluation of the model
116 |
117 | ##### Example evaluation on CIFAR10/100-20/100:
118 |
119 | The argument `--cpt_load_path` gives the full path of the stored model.
120 |
121 | ```
122 | python test.py --dataset c10 --path ./dataset --size_crops 28 18 --cpt_load_path ./c10/c10.cpt
123 | ```
124 |
125 | ```
126 | python test.py --dataset c20 --path ./dataset --size_crops 28 18 --cpt_load_path ./c20/c20.cpt
127 | ```
128 |
129 | ```
130 | python test.py --dataset c100 --path ./dataset --size_crops 28 18 --cpt_load_path ./c100/c100.cpt
131 | ```
132 |
133 | ##### Example evaluation on STL10:
134 |
135 | ```
136 | python test.py --dataset STL10 --path ./dataset --size_crops 76 52 --cpt_load_path ./stl10/stl10.cpt
137 | ```
138 |
139 | ##### Example evaluation on Tiny-Imagenet:
140 |
141 | ```
142 | python test.py --dataset tiny --path ./dataset --size_crops 56 36 --cpt_load_path ./tiny/tiny.cpt
143 | ```
144 |
145 | ## Notes
146 |
147 | - During training, each epoch reports the model's performance on the test set (used as a validation set)
148 | and on the training set (training-set performance is computed on transformed instances).
149 |
150 | - The classifier head is trained and evaluated only on the labelled set of the STL10 dataset. The unlabelled part of STL10 is
151 | used only to train the encoder and the prototypes.
152 |
153 | - All tests have been performed with CUDA version 10.1.
154 |
155 | ## Acknowledgement for reference repos
156 | - [AMDIM](https://github.com/Philip-Bachman/amdim-public)
157 | - [SwAV](https://github.com/facebookresearch/swav)
158 |
159 | ## Citation
160 |
161 | ```bibtex
162 | @misc{ntelemis2021information,
163 | title={Information Maximization Clustering via Multi-View Self-Labelling},
164 | author={Foivos Ntelemis and Yaochu Jin and Spencer A. Thomas},
165 | year={2021},
166 | eprint={2103.07368},
167 | archivePrefix={arXiv},
168 | primaryClass={cs.CV}
169 | }
170 | ```
171 |
172 |
173 |
--------------------------------------------------------------------------------
/graphs.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from torchvision.models.resnet import Bottleneck, BasicBlock, conv1x1
4 |
5 | '''Modified version of PyTorch's ResNet'''
6 |
7 |
8 | class ResNet(nn.Module):
9 | def __init__(self, block, layers, conv, zero_init_residual=False, groups=1, width_per_group=64,
10 | replace_stride_with_dilation=None, norm_layer=None):
11 | super(ResNet, self).__init__()
12 | if norm_layer is None:
13 | norm_layer = nn.BatchNorm2d
14 | self._norm_layer = norm_layer
15 |         self.conv1 = conv
16 | self.inplanes = 64
17 | self.dilation = 1
18 | if replace_stride_with_dilation is None:
19 | # each element in the tuple indicates if we should replace
20 | # the 2x2 stride with a dilated convolution instead
21 | replace_stride_with_dilation = [False, False, False]
22 | if len(replace_stride_with_dilation) != 3:
23 | raise ValueError("replace_stride_with_dilation should be None "
24 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
25 | self.groups = groups
26 | self.base_width = width_per_group
27 |
28 | self.layer1 = self._make_layer(block, 64, layers[0])
29 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
30 | dilate=replace_stride_with_dilation[0]) # 16x16
31 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
32 | dilate=replace_stride_with_dilation[1])
33 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
34 | dilate=replace_stride_with_dilation[2])
35 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
36 | # self.fc.init_weights(1.)
37 | for m in self.modules():
38 | if isinstance(m, nn.Conv2d):
39 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
40 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
41 | nn.init.constant_(m.weight, 1)
42 | nn.init.constant_(m.bias, 0)
43 |
44 | # Zero-initialize the last BN in each residual branch,
45 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
46 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
47 |         self.layer_list = nn.ModuleList([self.conv1, self.layer1, self.layer2, self.layer3, self.layer4, self.avgpool])
48 | if zero_init_residual:
49 | for m in self.modules():
50 | if isinstance(m, Bottleneck):
51 | nn.init.constant_(m.bn3.weight, 0)
52 | elif isinstance(m, BasicBlock):
53 | nn.init.constant_(m.bn2.weight, 0)
54 |
55 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
56 | norm_layer = self._norm_layer
57 | downsample = None
58 | previous_dilation = self.dilation
59 | if dilate:
60 | self.dilation *= stride
61 | stride = 1
62 | if stride != 1 or self.inplanes != planes * block.expansion:
63 | downsample = nn.Sequential(
64 | conv1x1(self.inplanes, planes * block.expansion, stride),
65 | norm_layer(planes * block.expansion),
66 | )
67 |
68 | layers = []
69 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
70 | self.base_width, previous_dilation, norm_layer))
71 | self.inplanes = planes * block.expansion
72 | for _ in range(1, blocks):
73 | layers.append(block(self.inplanes, planes, groups=self.groups,
74 | base_width=self.base_width, dilation=self.dilation,
75 | norm_layer=norm_layer))
76 |
77 | return nn.Sequential(*layers)
78 |
79 | def _forward_impl(self, x):
80 |
81 | layer_acts = [x]
82 | for i, layer in enumerate(self.layer_list):
83 | layer_in = layer_acts[-1]
84 | layer_out = layer(layer_in)
85 | layer_acts.append(layer_out)
86 | return layer_acts[-1]
87 |
88 | def forward(self, x):
89 | return self._forward_impl(x)
90 |
91 |
92 | class Projection(nn.Module):
93 | def __init__(self, n_input, n_out=128):
94 | super(Projection, self).__init__()
95 | # self.project = nn.Sequential(nn.Conv2d(n_input, n_input, 1, 1, 0),
96 | # # nn.BatchNorm2d(n_input, affine=True),
97 | # # nn.ReLU(inplace=True),
98 | # nn.Conv2d(n_input, n_out, 1, 1, 0, bias=True))
99 | self.project = nn.Sequential(nn.Linear(n_input, n_input, bias=True),
100 | # nn.BatchNorm1d(n_input, affine=True),
101 | # nn.ReLU(inplace=True),
102 | nn.Linear(n_input, n_out, bias=True))
103 | return
104 |
105 | def forward(self, r1_x):
106 | # out = self.project(r1_x)
107 | # out = nn.functional.normalize(out, dim=1, p=2)
108 | return self.project(r1_x)
109 |
110 |
111 | class Prototypes(nn.Module):
112 | def __init__(self, n_input, n_out=1000):
113 | super(Prototypes, self).__init__()
114 | # self.prototypes = nn.Conv2d(n_input, n_out, 1, 1, 0, bias=False)
115 | self.prototypes = nn.Linear(n_input, n_out, bias=False)
116 | return
117 |
118 | def forward(self, r1_x):
119 | r1_x = nn.functional.normalize(r1_x, dim=1, p=2)
120 | # if len(r1_x.shape) != 4:
121 | # r1_x = r1_x.unsqueeze(-1).unsqueeze(-1)
122 | return self.prototypes(r1_x) # .squeeze(-1).squeeze(-1)
123 |
124 |
125 | class Classifier(nn.Module):
126 | def __init__(self, n_input, n_classes, n_hidden=1024):
127 | super(Classifier, self).__init__()
128 | self.n_input = n_input
129 | self.n_classes = n_classes
130 | self.n_hidden = n_hidden
131 |
132 | self.aux_head = nn.Sequential(
133 | nn.Linear(n_input, n_hidden, bias=True),
134 | nn.BatchNorm1d(n_hidden, affine=True),
135 | nn.ReLU(inplace=True),
136 | nn.Linear(n_hidden, n_hidden, bias=True),
137 | nn.BatchNorm1d(n_hidden, affine=True),
138 | nn.ReLU(inplace=True),
139 | nn.Linear(n_hidden, n_classes, bias=True)
140 | # nn.Conv2d(n_input, n_hidden, 1, 1, 0),
141 | # nn.BatchNorm2d(n_hidden, affine=True),
142 | # nn.ReLU(inplace=True),
143 | # nn.Conv2d(n_hidden, n_hidden, 1, 1, 0),
144 | # nn.BatchNorm2d(n_hidden, affine=True),
145 | # nn.ReLU(inplace=True),
146 | # nn.Conv2d(n_hidden, n_classes, 1, 1, 0)
147 | )
148 | return
149 |
150 | def forward(self, r1_x):
151 | # Always detach so to train only omega parameters
152 | r1_x = r1_x.detach()
153 | px = self.aux_head(r1_x)
154 | return px.squeeze(-1).squeeze(-1)
155 |
156 |
157 | class MLPClassifier(nn.Module):
158 | def __init__(self, n_classes, n_input, p=0.1):
159 | super(MLPClassifier, self).__init__()
160 | self.n_input = n_input
161 | self.n_classes = n_classes
162 |
163 | self.block_forward = nn.Sequential(
164 | nn.Dropout(p=p),
165 | nn.Linear(self.n_input, n_classes, bias=True)
166 | )
167 |
168 | def forward(self, x):
169 | x = x.detach()
170 | x = torch.flatten(x, 1)
171 | logits = self.block_forward(x)
172 | return logits
173 |
--------------------------------------------------------------------------------
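A shape walk-through of the graph components, mirroring how `Encoder` in model.py assembles them for `encoder_size=32` (the stem below is copied from that code path; the concrete sizes are illustrative):

```python
import torch
from torch import nn
from torchvision.models.resnet import BasicBlock
from graphs import ResNet, Projection, Prototypes

# CIFAR-style stem: 3x3 conv, no initial downsampling (as in Encoder for 32x32 inputs).
stem = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
                     nn.BatchNorm2d(64), nn.ReLU(inplace=True))
net = ResNet(BasicBlock, [2, 2, 2, 2], stem)  # ResNet18 body

feats = torch.flatten(net(torch.randn(2, 3, 32, 32)), 1)  # (2, 512) pooled features
z = Projection(feats.size(1), 128)(feats)                 # (2, 128) projected embedding
u = Prototypes(128, 1000)(z)                              # (2, 1000) prototype scores
print(feats.shape, z.shape, u.shape)
```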
/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from graphs import ResNet, BasicBlock, Bottleneck, Projection, Prototypes, Classifier
5 | from utils import sinkhorn
6 | from costs import entropy, mi
7 |
8 |
9 | class Model(nn.Module):
10 | def __init__(self, n_classes, encoder_size=32, prototypes=1000, project_dim=128,
11 | tau=0.1, eps=0.05, model_type='resnet18'):
12 | super(Model, self).__init__()
13 |
14 | self.hyperparams = {
15 | 'n_classes': n_classes,
16 | 'encoder_size': encoder_size,
17 | 'prototypes': prototypes, # k' number of prototypes
18 | 'project_dim': project_dim, # projection head's dimension
19 | 'tau': tau, # Tau parameter of Softmax smoothness (Eq.2)
20 | 'eps': eps, # epsilon (Eq.3)
21 | 'model_type': model_type
22 | }
23 |
24 | dummy_batch = torch.zeros((2, 3, encoder_size, encoder_size))
25 |
26 | # encoder that provides multiscale features
27 | self.encoder = Encoder(encoder_size=encoder_size, model_type=model_type)
28 | rkhs_1 = self.encoder(dummy_batch)
29 | self.encoder = nn.DataParallel(self.encoder)
30 | self.project = Projection(rkhs_1.size(1), project_dim)
31 | self.prototypes = Prototypes(project_dim, prototypes)
32 | self.auxhead = Classifier(rkhs_1.size(1), n_classes) # Classifier
33 | self.modules_ = [self.encoder.module, self.prototypes, self.project, self.auxhead]
34 | self._t, self._e = tau, eps
35 | self._z_bank = []
36 | self._u_bank = []
37 | self.counter = 0
38 |
39 | def _encode(self, res_dict, aug_imgs, num_crops):
40 | res_dict['z'], res_dict['u'] = [], []
41 | b = aug_imgs[0].size()[0]
42 |
43 | with torch.no_grad():
44 | # l2 normalization for prototypes
45 | w = self.prototypes.prototypes.weight.data.clone()
46 | w = nn.functional.normalize(w, dim=1, p=2)
47 | self.prototypes.prototypes.weight.copy_(w)
48 |
49 | for aug_imgs_ in aug_imgs:
50 | emb = self.encoder(aug_imgs_)
51 | emb_projected = self.project(emb)
52 | ux = self.prototypes(emb_projected)
53 | res_dict['z'].append(emb)
54 | res_dict['u'].append(ux)
55 |         # multi-crop swapped-prediction loss (Eq. 1)
56 | swav_loss = []
57 | for i in range(num_crops[0]):
58 | with torch.no_grad():
59 |                 # Sinkhorn-Knopp algorithm (formulation of Eq. 3 & 4, based on the SwAV implementation)
60 | # (SWAV URL: https://arxiv.org/abs/2006.09882)
61 | q = res_dict['u'][i].detach().clone()
62 | q = torch.exp(q / self._e)
63 | q = sinkhorn(q.T, 3) # [-batch_size:]
64 | for p, px in enumerate(res_dict['u']):
65 | if p != i:
66 | # Equation 1
67 | swav_loss.append(entropy(q, torch.softmax(px / self._t, -1)).sum(-1).mean())
68 | res_dict['swav_loss'] = torch.stack(swav_loss).sum() / len(swav_loss)
69 |
70 | return res_dict
71 |
72 | def encode(self, imgs, res_dict):
73 |
74 | with torch.no_grad():
75 | z = self.encoder(imgs)
76 | res_dict['y'] = torch.flatten(self.auxhead(z), 1).argmax(-1)
77 | return res_dict
78 |
79 | def forward(self, x, nmb_crops=[2], eval_idxs=None, eval_only=False):
80 |
81 | # dict for returning various values
82 | res_dict = {}
83 | if eval_only:
84 | return self.encode(x, res_dict)
85 |
86 | res_dict = self._encode(res_dict, x, nmb_crops)
87 |
88 |         '''
89 |         STL10 contains both an unlabelled and a labelled set.
90 |         Our method is trained in online mode together with the encoder.
91 |         Due to the ratio between the unlabelled set (100,000 instances) and the labelled
92 |         training set (5,000 instances), a training batch contains a large number of
93 |         unlabelled instances and only a very small number of labelled ones.
94 |         Hence, we use a short temporary bank to store representations (z) until the number
95 |         of collected labelled instances reaches the batch size.
96 |         Afterwards, the membanks are used to train the classifier.
97 |         '''
98 |
99 |         # The classifier is trained only on the labelled training set.
100 |         # This actually matters only for STL10, which has both an unlabelled and a labelled set.
101 |         # We train and evaluate the classifier only on the labelled set; hence eval_idxs is a
102 |         # boolean mask that restricts training to the labelled idxs.
103 |         if eval_idxs is None:
104 |             eval_idxs = torch.ones(res_dict['z'][0].size(0), dtype=torch.bool)
105 | 
106 |         # The lists below are membanks that store embeddings (z) and the u probability
107 |         # distributions. They are only used when training on STL10, where, as noted above,
108 |         # the classifier is trained only on the labelled set. A batch on STL10 contains
109 |         # instances from both the labelled and the unlabelled set, so each batch holds
110 |         # only a very small number of labelled training instances. We therefore collect
111 |         # labelled representations until the collection reaches the training batch size,
112 |         # which keeps the classifier's batch size equal, on average, to the encoder's
113 |         # batch size.
114 | 
115 |
116 | with torch.no_grad():
117 | self._z_bank.append(torch.cat([z[eval_idxs].unsqueeze(1).detach() for z in res_dict['z']], 1))
118 | self._u_bank.append(torch.cat([ux[eval_idxs].unsqueeze(1).detach() for ux in res_dict['u']], 1))
119 | b = res_dict['z'][0].size(0)
120 | self.counter += eval_idxs.sum().item()
121 | mi_loss, lgt_reg, res_dict['y'] = [], [], None
122 | if self.counter >= b:
123 | Z = torch.cat(self._z_bank, 0).unbind(1)
124 | Y = [self.auxhead(z) for z in Z]
125 | U = torch.unbind(torch.cat(self._u_bank), 1)
126 | for j, py_j in enumerate(Y):
127 | for u, pu_i in enumerate(U):
128 | mi_loss_, lgt_reg_ = mi(py_j, pu_i / self._t) # Equation 6
129 | mi_loss.append(mi_loss_), lgt_reg.append(lgt_reg_)
130 | res_dict['y'] = torch.flatten(Y[0], 1).argmax(-1)
131 | self.reset_membank_list()
132 | zero = torch.tensor([0], device=x[0].device.type)
133 | res_dict['mi_loss'] = torch.stack(mi_loss).mean() if len(mi_loss) > 0 else zero
134 | res_dict['lgt_reg'] = torch.stack(lgt_reg).mean() if len(lgt_reg) > 0 else zero
135 |
136 | return res_dict
137 |
138 | # Reset the membanks, this actually applies only on STL10 training because of the unlabelled set
139 | def reset_membank_list(self):
140 | self._z_bank, self._u_bank = [], []
141 | self.counter = 0
142 |
143 |
144 | class Encoder(nn.Module):
145 | def __init__(self, encoder_size=32, model_type='resnet18'):
146 | super(Encoder, self).__init__()
147 |
148 | # encoding block for local features
149 | print('Using a {}x{} encoder'.format(encoder_size, encoder_size))
150 | inplanes = 64
151 | if encoder_size == 32:
152 | conv1 = nn.Sequential(nn.Conv2d(3, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
153 | nn.BatchNorm2d(inplanes),
154 | nn.ReLU(inplace=True))
155 | elif encoder_size == 96 or encoder_size == 64:
156 | conv1 = nn.Sequential(nn.Conv2d(3, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
157 | nn.BatchNorm2d(inplanes),
158 | nn.ReLU(inplace=True),
159 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
160 |
161 | else:
162 | raise RuntimeError("Could not build encoder."
163 | "Encoder size {} is not supported".format(encoder_size))
164 |
165 | if model_type == 'resnet18':
166 | # ResNet18 block
167 | self.model = ResNet(BasicBlock, [2, 2, 2, 2], conv1)
168 | elif model_type == 'resnet34':
169 | self.model = ResNet(BasicBlock, [3, 4, 6, 3], conv1)
170 | elif model_type == 'resnet50':
171 | self.model = ResNet(Bottleneck, [3, 4, 6, 3], conv1)
172 | else:
173 | raise RuntimeError("Wrong model type")
174 |
175 | print(self.get_param_n())
176 |
177 | def get_param_n(self):
178 | w = 0
179 | for p in self.model.parameters():
180 |             w += np.prod(p.shape)
181 | return w
182 |
183 | def forward(self, x):
184 | return torch.flatten(self.model(x), 1)
185 |
--------------------------------------------------------------------------------
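An end-to-end forward sketch with the default multi-crop setup (2 crops at 28px, 4 at 18px); batch size and inputs are illustrative. It assumes a CUDA device, because the Sinkhorn helper in utils.py allocates its marginals with `.cuda()`:

```python
import torch
from model import Model

model = Model(n_classes=10, encoder_size=32, prototypes=1000, project_dim=128).cuda()
crops = [torch.randn(8, 3, 28, 28).cuda() for _ in range(2)] + \
        [torch.randn(8, 3, 18, 18).cuda() for _ in range(4)]

out = model(x=crops, nmb_crops=[2, 4], eval_only=False)
print(out['swav_loss'].item(), out['mi_loss'].item())
print(out['y'].shape)  # (8,) -- predicted cluster index per instance
```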
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from sys import platform
3 | from PIL import Image
4 | import numpy as np
5 | import torch
6 | from torchvision import datasets, transforms
7 | import socket
8 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive
9 | from torchvision.datasets.vision import VisionDataset
10 | import pickle
11 | from collections import defaultdict
12 | from torch.utils.data import Dataset
13 | from tqdm.autonotebook import tqdm
14 |
15 | training_datasets = ['C10', 'C20', 'C100', 'STL10', 'TINY']
16 |
17 |
18 | def get_encoder_size(dataset_name):
19 | if dataset_name in training_datasets[:3]:
20 | return 32
21 | if dataset_name == training_datasets[-2]:
22 | return 96
23 | if dataset_name == training_datasets[-1]:
24 | return 64
25 | raise RuntimeError("Error get encoder size, unknown setup size: {}".format(dataset_name))
26 |
27 |
28 | def get_dataset(dataset_name):
29 | dataset_name = dataset_name.upper()
30 | if dataset_name in training_datasets:
31 | return dataset_name
32 | raise KeyError("Unknown dataset '" + dataset_name + "'. Must be one of "
33 | + ', '.join([name for name in training_datasets]))
34 |
35 |
36 | class Transforms:
37 |
38 | def __init__(self, nmb_crops, size_crops, min_scale_crops, max_scale_crops, mu, std):
39 | assert len(size_crops) == len(nmb_crops)
40 | assert len(min_scale_crops) == len(nmb_crops)
41 | assert len(max_scale_crops) == len(nmb_crops)
42 |
43 | flip = transforms.RandomHorizontalFlip(p=0.5)
44 | normalize = transforms.Normalize(mean=mu, std=std)
45 |
46 | col_jitter = transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8)
47 | rnd_gray = transforms.RandomGrayscale(p=0.25)
48 |
49 | trans = []
50 | for i in range(len(size_crops)):
51 | randomresizedcrop = transforms.RandomResizedCrop(
52 | size_crops[i], scale=(min_scale_crops[i], max_scale_crops[i]))
53 | trans.extend([transforms.Compose([
54 | flip, randomresizedcrop, col_jitter, rnd_gray, transforms.ToTensor(), normalize])] * nmb_crops[i])
55 | self.train_transform = trans
56 |
57 | self.test_transform = transforms.Compose([transforms.ToTensor(), normalize])
58 |
59 | def __call__(self, inp):
60 | multi_crops = list(map(lambda trans: trans(inp), self.train_transform))
61 | return multi_crops, self.test_transform(inp)
62 |
63 |
64 | def build_dataset(dataset, batch_size, nmb_workers, nmb_crops, size_crops, min_scale_crops, max_scale_crops, path):
65 | if dataset == training_datasets[0]:
66 | num_classes = 10
67 | mu = [0.4914, 0.4822, 0.4465]
68 | std = [0.2023, 0.1994, 0.2010]
69 | train_transform = Transforms(nmb_crops, size_crops, min_scale_crops, max_scale_crops, mu, std)
70 | test_transform = train_transform.test_transform
71 | train_dataset = CIFAR10(root=path, train=True, transform=train_transform, download=True)
72 | test_dataset = CIFAR10(root=path, train=False, transform=test_transform, download=True)
73 |
74 | elif dataset in training_datasets[1:3]:
75 | num_classes = 20 if dataset == training_datasets[1] else 100
76 | coarse = True if dataset == training_datasets[1] else False
77 | mu = [0.5071, 0.4867, 0.4408]
78 | std = [0.2675, 0.2565, 0.2761]
79 | train_transform = Transforms(nmb_crops, size_crops, min_scale_crops, max_scale_crops, mu, std)
80 | test_transform = train_transform.test_transform
81 | train_dataset = CIFAR100(root=path, train=True, transform=train_transform, download=True, c100_coarse=coarse)
82 | test_dataset = CIFAR100(root=path, train=False, transform=test_transform, download=True, c100_coarse=coarse)
83 |
84 | elif dataset == training_datasets[-2]:
85 | num_classes = 10
86 | mu = [0.43, 0.42, 0.39]
87 | std = [0.27, 0.26, 0.27]
88 | train_transform = Transforms(nmb_crops, size_crops, min_scale_crops, max_scale_crops, mu, std)
89 | test_transform = train_transform.test_transform
90 | train_dataset = STL10(root=path, split='train+unlabeled', transform=train_transform, download=True)
91 | test_dataset = STL10(root=path, split='test', transform=test_transform, download=True)
92 | elif dataset == training_datasets[-1]:
93 | num_classes = 200
94 | mu = [0.485, 0.456, 0.406]
95 | std = [0.229, 0.224, 0.225]
96 | train_transform = Transforms(nmb_crops, size_crops, min_scale_crops, max_scale_crops, mu, std)
97 | test_transform = train_transform.test_transform
98 | train_dataset = TinyImageNetDataset(path, transform=train_transform, download=False, preload=True)
99 | test_dataset = TinyImageNetDataset(path, mode='val', transform=test_transform, download=False,
100 | preload=False)
101 | else:
102 | raise RuntimeError("Error not supported dataset {}".format(dataset))
103 |
104 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
105 | pin_memory=True, drop_last=True, num_workers=nmb_workers)
106 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True,
107 | pin_memory=True, drop_last=False, num_workers=nmb_workers)
108 |
109 | return train_loader, test_loader, num_classes
110 |
111 |
112 | '''
113 | Overriding PyTorch's CIFAR-10 and CIFAR-100 dataset classes to be able to provide the coarse labels (CIFAR-20).
114 | Additionally, all vision datasets below return the index of the iterated instance, for reference only.
115 | '''
116 |
117 |
118 | class CIFAR10(VisionDataset):
119 | """`CIFAR10 `_ Dataset.
120 |
121 | Args:
122 | root (string): Root directory of dataset where directory
123 | ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
124 | train (bool, optional): If True, creates dataset from training set, otherwise
125 | creates from test set.
126 | transform (callable, optional): A function/transform that takes in an PIL image
127 | and returns a transformed version. E.g, ``transforms.RandomCrop``
128 | target_transform (callable, optional): A function/transform that takes in the
129 | target and transforms it.
130 | download (bool, optional): If true, downloads the dataset from the internet and
131 | puts it in root directory. If dataset is already downloaded, it is not
132 | downloaded again.
133 |
134 | """
135 | base_folder = 'cifar-10-batches-py'
136 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
137 | filename = "cifar-10-python.tar.gz"
138 | tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
139 | train_list = [
140 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
141 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
142 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
143 | ['data_batch_4', '634d18415352ddfa80567beed471001a'],
144 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
145 | ]
146 |
147 | test_list = [
148 | ['test_batch', '40351d587109b95175f43aff81a1287e'],
149 | ]
150 | meta = {
151 | 'filename': 'batches.meta',
152 | 'key': 'label_names',
153 | 'md5': '5ff9c542aee3614f3951f8cda6e48888',
154 | }
155 |
156 | def __init__(self, root, train=True, transform=None, target_transform=None,
157 | download=False, c100_coarse=True):
158 |
159 | super(CIFAR10, self).__init__(root, transform=transform, target_transform=target_transform)
160 |
161 | self.train = train # training set or test set
162 |
163 | if download:
164 | self.download()
165 |
166 | if not self._check_integrity():
167 | raise RuntimeError('Dataset not found or corrupted.' +
168 | ' You can use download=True to download it')
169 |
170 | if self.train:
171 | downloaded_list = self.train_list
172 | else:
173 | downloaded_list = self.test_list
174 |
175 | self.data = []
176 | self.targets = []
177 |
178 | # now load the picked numpy arrays
179 | for file_name, checksum in downloaded_list:
180 | file_path = os.path.join(self.root, self.base_folder, file_name)
181 | with open(file_path, 'rb') as f:
182 | entry = pickle.load(f, encoding='latin1')
183 | self.data.append(entry['data'])
184 | if 'labels' in entry:
185 | self.targets.extend(entry['labels'])
186 | else:
187 | if c100_coarse is True:
188 | self.targets.extend(entry['coarse_labels'])
189 | self.meta['key'] = self.meta['key2']
190 | else:
191 | self.targets.extend(entry['fine_labels'])
192 | self.meta['key'] = self.meta['key1']
193 |
194 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
195 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
196 | self._load_meta()
197 |
198 | def _load_meta(self):
199 | path = os.path.join(self.root, self.base_folder, self.meta['filename'])
200 | if not check_integrity(path, self.meta['md5']):
201 | raise RuntimeError('Dataset metadata file not found or corrupted.' +
202 | ' You can use download=True to download it')
203 | with open(path, 'rb') as infile:
204 | data = pickle.load(infile, encoding='latin1')
205 | self.classes = data[self.meta['key']]
206 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
207 |
208 | def __getitem__(self, index):
209 | """
210 | Args:
211 | index (int): Index
212 |
213 | Returns:
214 | tuple: (image, target) where target is index of the target class.
215 | """
216 | img, target = self.data[index], self.targets[index]
217 |
218 | # doing this so that it is consistent with all other datasets
219 | # to return a PIL Image
220 | img = Image.fromarray(img)
221 |
222 | if self.transform is not None:
223 | img = self.transform(img)
224 |
225 | if self.target_transform is not None:
226 | target = self.target_transform(target)
227 |
228 | return img, target, index
229 |
230 | def __len__(self):
231 | return len(self.data)
232 |
233 | def _check_integrity(self):
234 | root = self.root
235 | for fentry in (self.train_list + self.test_list):
236 | filename, md5 = fentry[0], fentry[1]
237 | fpath = os.path.join(root, self.base_folder, filename)
238 | if not check_integrity(fpath, md5):
239 | return False
240 | return True
241 |
242 | def download(self):
243 | if self._check_integrity():
244 | print('Files already downloaded and verified')
245 | return
246 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
247 |
248 | def extra_repr(self):
249 | return "Split: {}".format("Train" if self.train is True else "Test")
250 |
251 |
252 | class CIFAR100(CIFAR10):
253 | """`CIFAR100 `_ Dataset.
254 |
255 | This is a subclass of the `CIFAR10` Dataset.
256 | """
257 | base_folder = 'cifar-100-python'
258 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
259 | filename = "cifar-100-python.tar.gz"
260 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
261 | train_list = [
262 | ['train', '16019d7e3df5f24257cddd939b257f8d'],
263 | ]
264 |
265 | test_list = [
266 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
267 | ]
268 | meta = {
269 | 'filename': 'meta',
270 | 'key1': 'fine_label_names',
271 | 'key2': 'coarse_label_names',
272 | 'md5': '7973b15100ade9c7d40fb424638fde48',
273 | }
274 |
275 |
276 | class STL10(datasets.STL10):
277 |
278 | def __getitem__(self, index):
279 | """
280 | Args:
281 | index (int): Index
282 |
283 | Returns:
284 | tuple: (image, target) where target is index of the target class.
285 | """
286 | if self.labels is not None:
287 | img, target = self.data[index], int(self.labels[index])
288 | else:
289 | img, target = self.data[index], None
290 |
291 | # doing this so that it is consistent with all other datasets
292 | # to return a PIL Image
293 | img = Image.fromarray(np.transpose(img, (1, 2, 0)))
294 | # img = img.crop((8, 8, 88, 88))
295 | if self.transform is not None:
296 | img = self.transform(img)
297 |
298 | if self.target_transform is not None:
299 | target = self.target_transform(target)
300 |
301 | return img, target, index
302 |
303 |
304 | def download_and_unzip(URL, root_dir):
305 |     error_message = "Download is not yet implemented. Please download it yourself from {}."
306 |     raise NotImplementedError(error_message.format(URL))
307 |
308 |
309 | def _add_channels(img):
310 | if len(img.getbands()) == 1: # third axis is the channels
311 | img = np.expand_dims(np.array(img), -1)
312 | img = np.tile(img, (1, 1, 3))
313 | img = Image.fromarray(img)
314 | return img
315 |
316 |
317 | """Creates a paths datastructure for the tiny imagenet.
318 | Args:
319 | root_dir: Where the data is located
320 | download: Download if the data is not there
321 | Members:
322 | label_id:
323 | ids:
324 |   nid_to_words:
325 | data_dict:
326 | """
327 |
328 |
329 | class TinyImageNetPaths:
330 | def __init__(self, root_dir, download=False):
331 | if download:
332 | download_and_unzip('http://cs231n.stanford.edu/tiny-imagenet-200.zip',
333 | root_dir)
334 | train_path = os.path.join(root_dir, 'train')
335 | val_path = os.path.join(root_dir, 'val')
336 | test_path = os.path.join(root_dir, 'test')
337 |
338 | wnids_path = os.path.join(root_dir, 'wnids.txt')
339 | words_path = os.path.join(root_dir, 'words.txt')
340 |
341 | self._make_paths(train_path, val_path, test_path,
342 | wnids_path, words_path)
343 |
344 | def _make_paths(self, train_path, val_path, test_path,
345 | wnids_path, words_path):
346 | self.ids = []
347 | with open(wnids_path, 'r') as idf:
348 | for nid in idf:
349 | nid = nid.strip()
350 | self.ids.append(nid)
351 | self.nid_to_words = defaultdict(list)
352 | with open(words_path, 'r') as wf:
353 | for line in wf:
354 | nid, labels = line.split('\t')
355 | labels = list(map(lambda x: x.strip(), labels.split(',')))
356 | self.nid_to_words[nid].extend(labels)
357 |
358 | self.paths = {
359 | 'train': [], # [img_path, id, nid, box]
360 | 'val': [], # [img_path, id, nid, box]
361 | 'test': [] # img_path
362 | }
363 |
364 | # Get the test paths
365 | self.paths['test'] = list(map(lambda x: os.path.join(test_path, x),
366 | os.listdir(test_path)))
367 | # Get the validation paths and labels
368 | with open(os.path.join(val_path, 'val_annotations.txt')) as valf:
369 | for line in valf:
370 | fname, nid, x0, y0, x1, y1 = line.split()
371 | fname = os.path.join(val_path, 'images', fname)
372 | bbox = int(x0), int(y0), int(x1), int(y1)
373 | label_id = self.ids.index(nid)
374 | self.paths['val'].append((fname, label_id, nid, bbox))
375 |
376 | # Get the training paths
377 | train_nids = os.listdir(train_path)
378 | for nid in train_nids:
379 | anno_path = os.path.join(train_path, nid, nid + '_boxes.txt')
380 | imgs_path = os.path.join(train_path, nid, 'images')
381 | label_id = self.ids.index(nid)
382 | with open(anno_path, 'r') as annof:
383 | for line in annof:
384 | fname, x0, y0, x1, y1 = line.split()
385 | fname = os.path.join(imgs_path, fname)
386 | bbox = int(x0), int(y0), int(x1), int(y1)
387 | self.paths['train'].append((fname, label_id, nid, bbox))
388 |
389 |
390 | """Datastructure for the tiny image dataset.
391 | Args:
392 | root_dir: Root directory for the data
393 | mode: One of "train", "test", or "val"
394 | preload: Preload into memory
395 | load_transform: Transformation to use at the preload time
396 | transform: Transformation to use at the retrieval time
397 | download: Download the dataset
398 | Members:
399 | tinp: Instance of the TinyImageNetPaths
400 | img_data: Image data
401 | label_data: Label data
402 | """
403 |
404 |
405 | class TinyImageNetDataset(Dataset):
406 | def __init__(self, root_dir, mode='train', preload=True, load_transform=None,
407 | transform=None, download=False, max_samples=None):
408 | tinp = TinyImageNetPaths(root_dir, download)
409 | self.mode = mode
410 | self.label_idx = 1 # from [image, id, nid, box]
411 | self.preload = preload
412 | self.transform = transform
413 | self.transform_results = dict()
414 |
415 | self.IMAGE_SHAPE = (64, 64, 3)
416 |
417 | self.img_data = []
418 | self.label_data = []
419 |
420 | self.max_samples = max_samples
421 | self.samples = tinp.paths[mode]
422 | self.samples_num = len(self.samples)
423 |
424 | if self.max_samples is not None:
425 | self.samples_num = min(self.max_samples, self.samples_num)
426 | self.samples = np.random.permutation(self.samples)[:self.samples_num]
427 |
428 | if self.preload:
429 | load_desc = "Preloading {} data...".format(mode)
430 | self.img_data = {} # np.zeros((self.samples_num,) + self.IMAGE_SHAPE, dtype=np.float32)
431 |             self.label_data = np.zeros((self.samples_num,), dtype=np.int64)
432 | for idx in tqdm(range(self.samples_num), desc=load_desc):
433 | s = self.samples[idx]
434 | # img = imageio.imread(s[0])
435 | img_ = Image.open(s[0])
436 | img = img_.copy()
437 | img = _add_channels(img)
438 | img_.close()
439 | self.img_data[idx] = img
440 | if mode != 'test':
441 | self.label_data[idx] = s[self.label_idx]
442 |
443 | if load_transform:
444 | for lt in load_transform:
445 | result = lt(self.img_data, self.label_data)
446 | self.img_data, self.label_data = result[:2]
447 | if len(result) > 2:
448 | self.transform_results.update(result[2])
449 |
450 | def __len__(self):
451 | return self.samples_num
452 |
453 | def __getitem__(self, idx):
454 | if self.preload:
455 | img = self.img_data[idx]
456 | target = None if self.mode == 'test' else self.label_data[idx]
457 | else:
458 | s = self.samples[idx]
459 | # img = imageio.imread(s[0])
460 | img = Image.open(s[0])
461 | img = _add_channels(img)
462 | target = None if self.mode == 'test' else s[self.label_idx]
463 | # img = img.crop((4, 4, 60, 60))
464 |
465 | # to return a PIL Image
466 | # img = Image.fromarray(np.transpose(img, (1, 2, 0)))
467 | if self.transform:
468 | img = self.transform(img)
469 | return img, target, idx
470 |
471 |
472 | dir_structure_help = r"""
473 | TinyImageNetPath
474 | ├── test
475 | │ └── images
476 | │ ├── test_0.JPEG
477 | │ ├── t...
478 | │ └── ...
479 | ├── train
480 | │ ├── n01443537
481 | │ │ ├── images
482 | │ │ │ ├── n01443537_0.JPEG
483 | │ │ │ ├── n...
484 | │ │ │ └── ...
485 | │ │ └── n01443537_boxes.txt
486 | │ ├── n01629819
487 | │ │ ├── images
488 | │ │ │ ├── n01629819_0.JPEG
489 | │ │ │ ├── n...
490 | │ │ │ └── ...
491 | │ │ └── n01629819_boxes.txt
492 | │ ├── n...
493 | │ │ ├── images
494 | │ │ │ ├── ...
495 | │ │ │ └── ...
496 | ├── val
497 | │ ├── images
498 | │ │ ├── val_0.JPEG
499 | │ │ ├── v...
500 | │ │ └── ...
501 | │ └── val_annotations.txt
502 | ├── wnids.txt
503 | └── words.txt
504 | """
505 |
--------------------------------------------------------------------------------
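Finally, an illustrative construction of the multi-crop pipeline defined by the `Transforms` class (the values match the CIFAR-10 defaults; the input image is a stand-in):

```python
from PIL import Image
from datasets import Transforms

t = Transforms(nmb_crops=[2, 4], size_crops=[28, 18],
               min_scale_crops=[0.2, 0.08], max_scale_crops=[1., 0.4],
               mu=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])

img = Image.new('RGB', (32, 32))  # stand-in for a CIFAR-10 image
crops, raw = t(img)               # multi-crop views + the plain test view

print([tuple(c.shape) for c in crops])  # 2x (3, 28, 28) then 4x (3, 18, 18)
print(tuple(raw.shape))                 # (3, 32, 32)
```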