├── utils
│   ├── __init__.py
│   └── utils.py
├── datasets
│   ├── __init__.py
│   └── hotels8k.py
├── loss
│   ├── __init__.py
│   └── loss.py
├── model
│   ├── __init__.py
│   └── multi_image_model.py
├── engine
│   ├── __init__.py
│   ├── engine.py
│   ├── evaluator.py
│   └── trainer.py
├── README.md
└── main.py

/utils/__init__.py:
--------------------------------------------------------------------------------
from .utils import *

--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
from .hotels8k import HotelsDataset

--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
from .loss import MutualDistillationLoss

--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
from .multi_image_model import MultiImageHybrid

--------------------------------------------------------------------------------
/engine/__init__.py:
--------------------------------------------------------------------------------
from .engine import EngineBase
from .evaluator import Evaluator
from .trainer import TrainerEngine

--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import argparse

import numpy as np
import torch.nn.functional as F


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def to_numpy(tensor, n_dims=2):
    """Convert a torch tensor to a numpy array.

    Args:
        tensor (Tensor): a tensor object to convert.
        n_dims (int): minimum number of dimensions of the returned array;
            leading singleton axes are prepended until this is met.
    """
    try:
        nparray = tensor.detach().cpu().clone().numpy()
    except AttributeError:
        raise TypeError('tensor type should be torch.Tensor, not {}'.format(type(tensor)))

    while len(nparray.shape) < n_dims:
        nparray = np.expand_dims(nparray, axis=0)

    return nparray


def l2_normalize(tensor, axis=-1):
    """L2-normalize a tensor along the given axis."""
    return F.normalize(tensor, p=2, dim=axis)

--------------------------------------------------------------------------------
/loss/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class MutualDistillationLoss(nn.Module):
    """Symmetric KL divergence between the multi-view prediction and the
    average of the single-view predictions, both softened by a temperature."""

    def __init__(self, temp=4., lambda_hyperparam=.1):
        super(MutualDistillationLoss, self).__init__()
        self.temp = temp
        self.kl_div = nn.KLDivLoss(reduction='none')
        self.lambda_hyperparam = lambda_hyperparam

    def forward(self, multi_view_logits, single_view_logits, targets):
        # q: temperature-softened distribution of the averaged single-view logits.
        averaged_single_logits = torch.mean(single_view_logits, dim=1)
        q = torch.softmax(averaged_single_logits / self.temp, dim=1)

        # Diagnostic statistics; these do not contribute to the loss value.
        try:
            max_q, pred_q = torch.max(q, dim=1)
            q_correct = pred_q == targets
            q_correct = q_correct.float().mean().item()
            max_q = max_q.mean().item()
        except RuntimeError:
            q_correct = 0.
            max_q = 0.

        # p: temperature-softened distribution of the multi-view logits.
        p = torch.softmax(multi_view_logits / self.temp, dim=1)
        max_p, _ = torch.max(p, dim=1)
        max_p = max_p.mean().item()

        log_q = torch.log_softmax(averaged_single_logits / self.temp, dim=1)
        log_p = torch.log_softmax(multi_view_logits / self.temp, dim=1)

        # Symmetric KL: each side distills from a detached copy of the other,
        # scaled by temp^2 to keep gradients comparable across temperatures.
        loss = (1 / 2) * (self.kl_div(log_p, q.detach()).sum(dim=1).mean()
                          + self.kl_div(log_q, p.detach()).sum(dim=1).mean())
        loss_weighted = loss * (self.temp ** 2) * self.lambda_hyperparam

        return loss_weighted
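
if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the training pipeline).
    # The shapes are assumptions that mirror the trainer's call
    # md_loss(mv_logits, single_view_logits, targets[:, 0]).
    batch_size, n_views, num_classes = 4, 2, 10
    mv_logits = torch.randn(batch_size, num_classes)           # one row per collection
    sv_logits = torch.randn(batch_size, n_views, num_classes)  # one row per view
    targets = torch.randint(0, num_classes, (batch_size,))
    criterion = MutualDistillationLoss(temp=4., lambda_hyperparam=.1)
    print(criterion(mv_logits, sv_logits, targets))  # scalar loss tensor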
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Official repository for the WACV 2024 paper [Multi-View Classification Using Hybrid Fusion and Mutual Distillation](https://openaccess.thecvf.com/content/WACV2024/papers/Black_Multi-View_Classification_Using_Hybrid_Fusion_and_Mutual_Distillation_WACV_2024_paper.pdf). Here, you'll find our code to train and evaluate our method, MV-HFMD. Currently, we provide code to run MV-HFMD on the Hotels-8k dataset.

## Instructions

To train our method on Hotels-8k, first download the dataset from this [link](https://tuprd-my.sharepoint.com/:u:/g/personal/tul03156_temple_edu/EdVGFFJyQKpGqxmk-WeApP8BLzHIaQ2XYGhhR6E1s0ntqQ?e=qR5rZf) and unzip the file into the desired directory. Then, run

    python3 main.py --data_dir {DATA_DIRECTORY}

You can toggle the mutual distillation loss function with the argument

    --use_mutual_distillation_loss {True/False}

and set the number of images per collection to train and evaluate on with

    --num_images {2/3/4}

By default, the model generates classification predictions for each individual image and for the entire multi-view collection. These are given in the model output dictionary under the keys 'single' and 'mv_collection', respectively.
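
For example, a forward pass returns nested dictionaries of logits (a minimal sketch; `model` and `images` here are placeholders, with `images` shaped `(batch, num_images, 3, 224, 224)`):

    output = model(images)
    single_logits = output['single']['logits']             # (batch * num_images, num_classes)
    collection_logits = output['mv_collection']['logits']  # (batch, num_classes)
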
## Requirements:

* Python 3
* torch
* numpy
* timm
* einops

## Citation:

If you find our work helpful in your research, please consider citing:

    @inproceedings{black2024multi,
      title={Multi-View Classification Using Hybrid Fusion and Mutual Distillation},
      author={Black, Samuel and Souvenir, Richard},
      booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
      pages={270--280},
      year={2024}
    }

--------------------------------------------------------------------------------
/model/multi_image_model.py:
--------------------------------------------------------------------------------
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops


class MultiImageHybrid(nn.Module):

    def __init__(self, arch, num_classes, n, pretrained_weights=True):
        super().__init__()

        self.n = n
        self.num_classes = num_classes
        self.pretrained_weights = pretrained_weights

        drop_rate = .0 if 'tiny' in arch else .1
        self.model = timm.create_model(arch, pretrained=self.pretrained_weights,
                                       num_classes=self.num_classes, drop_rate=drop_rate)
        for block in self.model.blocks:
            block.attn.fused_attn = False
            block.attn.proj_drop = nn.Dropout(p=0.0)

        self.embed_dim = self.model.embed_dim

        # Learned per-image embedding, added to every token of the corresponding view.
        self.img_embed_matrix = nn.Parameter(torch.zeros(1, n, self.embed_dim), requires_grad=True)
        nn.init.xavier_uniform_(self.img_embed_matrix)

        nn.init.zeros_(self.model.head.weight)
        nn.init.zeros_(self.model.head.bias)

    def format_multi_image_tokens(self, x, batch_size, tokens_per_image):
        # Concatenate the token sequences of the n views of each collection.
        x = einops.rearrange(x, '(b n) s c -> b (n s) c', b=batch_size, n=self.n)
        first_img_token_idx = 0
        if self.model.cls_token is not None:
            # Keep the first CLS token and remove the n - 1 excess copies.
            for i in range(1, self.n):
                excess_cls_index = i * tokens_per_image + 1
                x = torch.cat((x[:, :excess_cls_index], x[:, excess_cls_index + 1:]), dim=1)
            first_img_token_idx = 1

        image_embeddings = F.normalize(self.img_embed_matrix, dim=-1)
        x[:, first_img_token_idx:] += torch.repeat_interleave(image_embeddings, tokens_per_image, dim=1)
        return x

    def forward(self, x):
        batch_size = len(x)
        output_dict = {'single': {}}
        if self.n > 1:
            output_dict['mv_collection'] = {}

        # Flatten the view axis so the backbone embeds each image independently.
        x = einops.rearrange(x, 'b n c h w -> (b n) c h w')
        x = self.model.patch_embed(x)

        tokens_per_image = x.shape[1]
        x = self.model._pos_embed(x)

        for view_type in output_dict:
            tokens = x.clone()
            if view_type == 'mv_collection':
                tokens = self.format_multi_image_tokens(tokens, batch_size, tokens_per_image)
            tokens = self.model.blocks(tokens)
            tokens = self.model.norm(tokens)
            output_dict[view_type]['logits'] = self.model.forward_head(tokens)

        return output_dict
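
if __name__ == '__main__':
    # Shape-only smoke test (an illustrative sketch, not part of training).
    # Assumes a timm version whose attention blocks expose `fused_attn`;
    # pretrained_weights=False avoids downloading weights.
    model = MultiImageHybrid('vit_small_r26_s32_224', num_classes=10, n=2, pretrained_weights=False)
    x = torch.randn(4, 2, 3, 224, 224)  # (batch, n views, channels, height, width)
    out = model(x)
    print(out['single']['logits'].shape)         # torch.Size([8, 10])
    print(out['mv_collection']['logits'].shape)  # torch.Size([4, 10])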
--------------------------------------------------------------------------------
/engine/engine.py:
--------------------------------------------------------------------------------
import hashlib

import torch


def torch_safe_load(module, state_dict, strict=True):
    # Strip the 'module.' prefix that nn.DataParallel adds to parameter names.
    module.load_state_dict({
        k.replace('module.', ''): v for k, v in state_dict.items()
    }, strict=strict)


class EngineBase(object):

    def __init__(self, model, optimizer_model, criterion, lr_scheduler_model, evaluator,
                 save_dir=None, md_loss=None, grad_clip_norm=None, logger=None):
        self.device = 'cuda'
        self.model = model
        self.optimizer_model = optimizer_model
        self.criterion = criterion
        self.lr_scheduler_model = lr_scheduler_model
        self.save_dir = save_dir
        self.evaluator = evaluator
        self.md_loss = md_loss
        self.grad_clip = grad_clip_norm
        self.metadata = {}
        self.logger = logger

    def model_to_device(self):
        self.model.to(self.device)

    @torch.no_grad()
    def evaluate(self, val_loader):
        if self.evaluator is None:
            self.logger.info('[Evaluate] Warning, no evaluator is defined. Skip evaluation')
            return
        scores = self.evaluator.evaluate(val_loader)
        return scores

    def save_models(self, save_to, metadata=None):
        state_dict = {
            'model': self.model.state_dict(),
            'optimizer_model': self.optimizer_model.state_dict(),
        }
        if self.lr_scheduler_model is not None:
            state_dict['lr_scheduler_model'] = self.lr_scheduler_model.state_dict()
        print('Saving model to {}'.format(save_to))
        torch.save(state_dict, save_to)
        self.logger.info('state dict is saved to {}'.format(save_to))

    def load_models(self, state_dict_path, load_keys=None):
        with open(state_dict_path, 'rb') as fin:
            model_hash = hashlib.sha1(fin.read()).hexdigest()
            self.metadata['pretrain_hash'] = model_hash

        state_dict = torch.load(state_dict_path, map_location='cpu')

        # A bare backbone checkpoint (no wrapping keys): load it into the model directly.
        if 'model' not in state_dict:
            torch_safe_load(self.model, state_dict, strict=False)
            return

        if not load_keys:
            load_keys = ['model', 'optimizer_model', 'lr_scheduler_model']

        for key in load_keys:
            try:
                torch_safe_load(getattr(self, key), state_dict[key])
            except RuntimeError as e:
                print('Unable to import state_dict, missing keys are found. {}'.format(e))
                torch_safe_load(getattr(self, key), state_dict[key], strict=False)
        print('state dict is loaded from {} (hash: {}), load_key ({})'.format(state_dict_path, model_hash, load_keys))
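
if __name__ == '__main__':
    # Tiny illustrative check of torch_safe_load (not part of training): it loads
    # a checkpoint saved from an nn.DataParallel-wrapped model into a bare module.
    layer = torch.nn.Linear(2, 2)
    wrapped = {'module.' + k: v for k, v in layer.state_dict().items()}
    torch_safe_load(layer, wrapped)
    print('loaded state dict with prefixed keys')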
--------------------------------------------------------------------------------
/datasets/hotels8k.py:
--------------------------------------------------------------------------------
import os
import itertools

import numpy as np
import torch
from PIL import Image
import torchvision.transforms as transforms


def extract_ids(im_path):
    hotel_id = im_path.split('/')[-2]
    img_id = im_path.split('/')[-1].split('.')[0]
    return img_id, hotel_id


def generate_target2indices(targets):
    target2indices = {}
    for i in range(len(targets)):
        t = targets[i]
        if t not in target2indices:
            target2indices[t] = [i]
        else:
            target2indices[t].append(i)
    return target2indices


class HotelsDataset(torch.utils.data.Dataset):

    def __init__(self, data_dir, split, n=2, train=False, classes=None):
        self.all_paths = np.load(os.path.join(data_dir, f'{split}.npy')).tolist()
        for i in range(len(self.all_paths)):
            self.all_paths[i] = f'{data_dir}/' + self.all_paths[i]
        self.all_paths = np.array(self.all_paths)

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        if split == 'train':
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ])

        if not classes:
            self.classes, self.class_to_idx = self.find_classes()
        else:
            # Ensure the same class indices as generated for the training dataset.
            self.classes = classes
            self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))}

        self.num_classes = len(self.class_to_idx)
        self.n = n
        self.train = train

        self.samples = self.make_dataset()
        self.image_paths = [s[0] for s in self.samples]
        self.targets = [int(s[1]) for s in self.samples]
        self.targets2indices = generate_target2indices(self.targets)

        if not self.train:
            # For evaluation, enumerate every n-image combination per class.
            self.samples = self.get_all_collection_combos()
            self.image_paths = [s[0] for s in self.samples]
            self.targets = [int(s[1]) for s in self.samples]

    def find_classes(self):
        classes = set()
        for path in self.all_paths:
            _, hotel_id = extract_ids(path)
            classes.add(hotel_id)
        classes = list(classes)
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

    def make_dataset(self):
        samples = []
        for path in self.all_paths:
            _, hotel_id = extract_ids(path)
            if hotel_id in self.class_to_idx:
                item = (path, self.class_to_idx[hotel_id])
                samples.append(item)
        return samples

    def get_all_collection_combos(self):
        samples = []
        for t, indices in self.targets2indices.items():
            paths = [self.image_paths[i] for i in indices]
            for subset in itertools.combinations(paths, self.n):
                samples.append([subset, t])
        return samples

    def __getitem__(self, index):
        target = self.targets[index]
        if self.train:  # select random same-class images to complete the collection
            possible_choices = self.targets2indices[target]
            if len(possible_choices) <= self.n:
                # Too few images in this class to avoid repeats, so allow them.
                paths = [self.image_paths[i] for i in possible_choices]
                unique_requirement = False
            else:
                paths = [self.image_paths[index]]
                unique_requirement = True

            while len(paths) < self.n:
                selection = np.random.choice(possible_choices)
                path = self.image_paths[selection]
                if path not in paths or not unique_requirement:
                    paths.append(path)
        else:
            paths = self.image_paths[index]
        target = torch.ones((self.n, )).long() * target
        images = torch.stack([self.transform(Image.open(p).convert('RGB')) for p in paths])
        return images, target, paths

    def __len__(self):
        return len(self.samples)
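
if __name__ == '__main__':
    # Illustrative sketch of the expected layout (the path below is a placeholder):
    # {data_dir}/{split}.npy stores relative image paths whose second-to-last
    # component is the hotel id, e.g. 'images/train/12345/67890.jpg'.
    # train_set = HotelsDataset('/path/to/hotels8k', split='train', n=2, train=True)
    # images, target, paths = train_set[0]  # images: (2, 3, 224, 224)
    pass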
--------------------------------------------------------------------------------
/engine/evaluator.py:
--------------------------------------------------------------------------------
import numpy as np
import torch

from utils import to_numpy


def top_k_acc(knn_labels, gt_labels, k):
    accuracy_per_sample = torch.any(knn_labels[:, :k] == gt_labels, dim=1).float()
    return torch.mean(accuracy_per_sample)


class Evaluator(object):

    def __init__(self, model, n=2):
        self.model = model

        self.extract_device = 'cuda'
        self.eval_device = 'cuda'

        self.num_classes = self.model.num_classes
        self.embed_dim = self.model.embed_dim
        self.n = n

    @torch.no_grad()
    def extract(self, dataloader):
        self.model.eval()
        self.model.to(self.extract_device)

        num_collections = len(dataloader.dataset)
        num_total_images = num_collections * self.n

        results_dict = {'single': {'logits': np.zeros((num_total_images, self.num_classes)),
                                   'classes': np.zeros(num_total_images),
                                   'paths': []},
                        'mv_collection': {'logits': np.zeros((num_collections, self.num_classes)),
                                          'classes': np.zeros(num_collections),
                                          'paths': []}}

        if self.n == 1:
            del results_dict['mv_collection']

        s = 0
        for data in dataloader:
            images, targets, paths = data
            images = images.to(self.extract_device)
            batch_output = self.model(images)
            e = s + len(images)
            for view_type in batch_output:
                if view_type == 'single':
                    # Each collection contributes n single-image rows.
                    multiplier = self.n
                    t = targets.view(len(images) * self.n)
                    p = np.array(paths).T.flatten().tolist()
                else:
                    multiplier = 1
                    t = targets[:, 0]
                    p = []
                    for w in range(len(paths[0])):
                        l = [paths[j][w] for j in range(self.n)]
                        p.append(l)

                results_dict[view_type]['logits'][s * multiplier: e * multiplier] = to_numpy(batch_output[view_type]['logits'])
                results_dict[view_type]['classes'][s * multiplier: e * multiplier] = to_numpy(t)
                results_dict[view_type]['paths'].extend(p)
            s = e

        # The same image can appear in many collections; score each image once.
        results_dict['single']['paths'] = np.array(results_dict['single']['paths'])
        duplicates = set()
        kept_indices = []
        for i in range(len(results_dict['single']['paths'])):
            p = results_dict['single']['paths'][i]
            if p not in duplicates:
                duplicates.add(p)
                kept_indices.append(i)

        for key in results_dict['single']:
            results_dict['single'][key] = results_dict['single'][key][kept_indices]

        for view_type in results_dict:
            results_dict[view_type]['logits'] = torch.from_numpy(results_dict[view_type]['logits']).squeeze()
            results_dict[view_type]['classes'] = results_dict[view_type]['classes'].squeeze()

        return results_dict

    def get_metrics(self, logits, gt_labels):
        try:
            predictions = torch.argsort(logits, dim=1, descending=True)
        except torch.cuda.OutOfMemoryError:
            # Fall back to the CPU when the logits matrix does not fit on the GPU.
            logits = logits.cpu()
            gt_labels = gt_labels.cpu()
            predictions = torch.argsort(logits, dim=1, descending=True)

        metrics = {}
        for k in (1, 2, 5, 10, 100):
            metrics[f'top{k}_acc'] = top_k_acc(predictions, gt_labels, k).item()
        return metrics

    def evaluate(self, dataloader):
        results_dict = self.extract(dataloader)
        metrics_dict = {}

        for view_type in results_dict:
            gt_classes = np.expand_dims(results_dict[view_type]['classes'], -1)
            metrics = self.get_metrics(results_dict[view_type]['logits'].to(self.eval_device),
                                       torch.from_numpy(gt_classes).to(self.eval_device))
            metrics_dict[view_type] = metrics

        return metrics_dict
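
if __name__ == '__main__':
    # Tiny illustration of top_k_acc: rows are class indices sorted by predicted
    # score; ground-truth labels are a column vector so the comparison broadcasts.
    ranked = torch.tensor([[2, 0, 1], [1, 2, 0]])
    gt = torch.tensor([[0], [0]])
    print(top_k_acc(ranked, gt, 1))  # tensor(0.) -- no top-1 hits
    print(top_k_acc(ranked, gt, 2))  # tensor(0.5000) -- first sample hits at rank 2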
--------------------------------------------------------------------------------
/engine/trainer.py:
--------------------------------------------------------------------------------
import datetime

import einops
import torch

from .engine import EngineBase


def cur_step(cur_epoch, idx, N, fmt=None):
    _cur_step = cur_epoch + idx / N
    if fmt:
        return fmt.format(_cur_step)
    else:
        return _cur_step


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']


class TrainerEngine(EngineBase):

    def _train_epoch(self, dataloader, cur_epoch):
        self.model.train()
        train_acc = {}

        for idx, (images, targets, _) in enumerate(dataloader):
            B, N, _, _, _ = images.shape
            images = images.to(self.device)
            targets = targets.to(self.device)

            with torch.cuda.amp.autocast():
                output = self.model(images)
                total_loss = torch.tensor(0.).to(self.device, non_blocking=True)

                # Cross-entropy on each available view type.
                for view_type in output:
                    if view_type == 'mv_collection':
                        t = targets[:, 0].flatten()
                    elif view_type == 'single':
                        t = targets.flatten()
                    logits = output[view_type]['logits']
                    base_loss = self.criterion(logits, t)
                    total_loss += base_loss

                    pred = torch.argmax(logits, dim=1)
                    if view_type not in train_acc:
                        train_acc[view_type] = {'total': 0, 'correct': 0}
                    pred = pred[t != -1]
                    t = t[t != -1]
                    train_acc[view_type]['correct'] += (pred == t).long().sum().item()
                    train_acc[view_type]['total'] += t.numel()

                # Mutual distillation between the collection and its single views.
                if self.md_loss is not None:
                    single_view_logits = einops.rearrange(output['single']['logits'], '(b n) k -> b n k', b=B, n=N)
                    mutual_distillation_loss = self.md_loss(output['mv_collection']['logits'], single_view_logits, targets[:, 0])
                    total_loss += mutual_distillation_loss

            self.optimizer_model.zero_grad()
            total_loss.backward()
            if self.grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)

            self.optimizer_model.step()
            if self.lr_scheduler_model is not None:
                try:
                    self.lr_scheduler_model.step()
                except ValueError:
                    # OneCycleLR raises once its precomputed step budget is exhausted.
                    pass

        for view_type in train_acc:
            if train_acc[view_type]['total'] > 0:
                acc = train_acc[view_type]['correct'] / train_acc[view_type]['total']
                self.logger.info(f'Epoch {cur_epoch}, {view_type} Train top1_acc: {acc}')

    def _eval_epoch(self, dataloader, cur_epoch, model_save_to, best_model_save_to, best_acc):
        scores = self.evaluate(dataloader)
        self.metadata['scores'] = scores
        save_key = 'mv_collection' if 'mv_collection' in scores else 'single'

        save_metric = 'top1_acc'
        save_score = scores[save_key][save_metric]

        if best_acc < save_score:
            self.save_models(best_model_save_to, self.metadata)
            best_acc = save_score
            self.metadata['best_score'] = best_acc
            self.metadata['best_epoch'] = cur_epoch

        for view_type in scores:
            for metric in scores[view_type]:
                self.logger.info(f'Epoch {cur_epoch}, {view_type} Validation {metric}: {scores[view_type][metric]}')

        self.save_models(model_save_to, self.metadata)
        return best_acc

    def train(self, tr_loader, n_epochs, val_loader):
        model_save_to = f'{self.save_dir}/last.pth'
        best_model_save_to = f'{self.save_dir}/best.pth'

        dt = datetime.datetime.now()
        self.model_to_device()
        best_acc = 0.
        for cur_epoch in range(1, n_epochs + 1):
            if cur_epoch == 1:
                # Baseline evaluation before any training step.
                best_acc = self._eval_epoch(val_loader, 0, model_save_to, best_model_save_to, best_acc)

            self._train_epoch(tr_loader, cur_epoch)
            self.metadata['cur_epoch'] = cur_epoch
            self.metadata['lr'] = get_lr(self.optimizer_model)
            best_acc = self._eval_epoch(val_loader, cur_epoch, model_save_to, best_model_save_to, best_acc)

            elapsed = datetime.datetime.now() - dt
            expected_total = elapsed / cur_epoch * n_epochs
            expected_remain = expected_total - elapsed
            self.logger.info('expected remain {}'.format(expected_remain))
        self.logger.info('finish engine, takes {}'.format(datetime.datetime.now() - dt))
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import logging
import os

import numpy as np
import torch
import torch.nn as nn

from datasets import HotelsDataset
from engine import TrainerEngine, Evaluator
from loss import MutualDistillationLoss
from model import MultiImageHybrid
from utils import str2bool


def main(logger):
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    dataset_train = HotelsDataset(args.data_dir, split='train', n=args.n, train=True)
    dataset_val = HotelsDataset(args.data_dir, split='val', n=args.n, classes=dataset_train.classes, train=False)

    loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.num_workers, drop_last=False, pin_memory=True)
    loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=128, shuffle=False,
                                             num_workers=args.num_workers, drop_last=False, pin_memory=True)

    model = MultiImageHybrid(args.architecture, num_classes=dataset_train.num_classes, n=args.n,
                             pretrained_weights=args.pretrained_weights)
    model.cuda()

    print('Number of classes: ', dataset_train.num_classes)

    optimizer_model = torch.optim.SGD(params=model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=args.momentum)
    # Note: with drop_last=False the real number of batches can exceed
    # len(dataset_train) // batch_size; the trainer swallows the ValueError
    # that OneCycleLR raises on those extra steps.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer_model, max_lr=args.lr,
                                                    epochs=args.num_epochs,
                                                    div_factor=10,
                                                    steps_per_epoch=len(dataset_train) // args.batch_size,
                                                    final_div_factor=1000,
                                                    pct_start=5 / args.num_epochs, anneal_strategy='cos')

    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    if args.use_mutual_distillation_loss:
        md_loss = MutualDistillationLoss(temp=args.md_temp, lambda_hyperparam=args.md_lambda)
    else:
        md_loss = None

    evaluator = Evaluator(model=model, n=args.num_images)
    trainer = TrainerEngine(model=model, lr_scheduler_model=scheduler, criterion=criterion, optimizer_model=optimizer_model,
                            evaluator=evaluator, md_loss=md_loss, grad_clip_norm=args.grad_clip_norm, logger=logger,
                            save_dir=args.save_dir)

    trainer.train(loader_train, args.num_epochs, loader_val)
    logger.info('Training done!\n')

    best_weights = torch.load(f'{args.save_dir}/best.pth')
    model.load_state_dict(best_weights['model'])

    logger.info('Evaluating on test set:\n')
    dataset_test = HotelsDataset(args.data_dir, split='test', n=args.n, classes=dataset_train.classes, train=False)
    loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=128, shuffle=False,
                                              num_workers=args.num_workers, drop_last=False, pin_memory=True)

    score_dict = evaluator.evaluate(loader_test)
    for view_type in score_dict:
        for metric in score_dict[view_type]:
            logger.info(f'Test {view_type} {metric}: {score_dict[view_type][metric]}')


if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description="Multiview Model Training", allow_abbrev=False
    )
    parser.add_argument("--data_dir", default='/', type=str, help="location of images")
    parser.add_argument("--architecture", default='vit_small_r26_s32_224', type=str, help="model architecture")
    parser.add_argument('--pretrained_weights', default=True, type=str2bool, help='use pretrained weights')
    parser.add_argument('--save_dir', default='output', type=str, help='save location for model weights and log')
    parser.add_argument('--seed', default=0, type=int, help='seed')
    parser.add_argument("--batch_size", default=64, type=int, help="batch size")
    parser.add_argument("--lr", default=0.01, type=float, help="learning rate")
    parser.add_argument("--wd", default=5e-4, type=float, help="weight decay")
    parser.add_argument("--momentum", default=.9, type=float, help="momentum")
    parser.add_argument("--num_epochs", default=50, type=int, help="number of epochs for training")
    parser.add_argument("--num_images", default=2, type=int, help="number of images per input")
    parser.add_argument("--num_workers", default=8, type=int, help="number of dataloading workers")
    parser.add_argument('--use_mutual_distillation_loss', default=True, type=str2bool, help='use mutual distillation loss')
    parser.add_argument("--md_temp", default=4., type=float, help='mutual distillation temperature')
    parser.add_argument("--md_lambda", default=.1, type=float, help='mutual distillation lambda hyperparameter')
    parser.add_argument("--grad_clip_norm", default=80., type=float, help='grad clip norm value')
    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    logfile = f'{args.save_dir}/log.txt'
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[logging.FileHandler(logfile), logging.StreamHandler()])
    logger = logging.getLogger(__name__)
    for param in vars(args):
        logger.info(f'{param}: {getattr(args, param)}')

    args.n = args.num_images
    main(logger=logger)
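
# Example invocation (illustrative; the data path is a placeholder):
#   python3 main.py --data_dir /path/to/hotels8k --num_images 2 \
#       --use_mutual_distillation_loss True --save_dir output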
--------------------------------------------------------------------------------