├── model ├── __init__.py ├── loss.py ├── model.py └── metric.py ├── utils ├── __init__.py ├── expert_dims.py ├── visualisation.py ├── custom_transforms.py ├── util.py ├── html.py └── visualizer.py ├── .gitignore ├── exps └── README.md ├── trainer ├── __init__.py └── trainer.py ├── data └── README.md ├── logger ├── __init__.py ├── logger.py ├── logger_config.json └── visualization.py ├── base ├── __init__.py ├── base_model.py ├── base_data_loader.py └── base_trainer.py ├── data_loader ├── data_loader.py └── CondensedMovies_dataset.py ├── configs ├── baseline.json └── baseline_mini.json ├── requirements └── environment.yaml ├── train.py ├── README.md ├── parse_config.py └── test.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .DS_Store 3 | data_proc/* -------------------------------------------------------------------------------- /exps/README.md: -------------------------------------------------------------------------------- 1 | Directory to contain experiments -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import * 2 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | This directory will contain the dataset 2 | -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .visualization import * -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_data_loader import * 2 | from .base_model import * 3 | from .base_trainer import * -------------------------------------------------------------------------------- /utils/expert_dims.py: -------------------------------------------------------------------------------- 1 | expert_dims = { 2 | "resnext101": 2048, 3 | "i3d": 1024, 4 | "senet154": 2048, 5 | "r2p1d": 512, 6 | "vggish": 128, 7 | "scene": 2208, 8 | "subtitles": 768 9 | } 10 | -------------------------------------------------------------------------------- /base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | @abstractmethod 11 | def forward(self, *inputs): 12 | """ 13 | Forward pass logic 14 | 15 | :return: Model output 16 | """ 17 | raise NotImplementedError 18 | 19 | def __str__(self): 20 | """ 21 | Model prints with number of trainable parameters 22 | """ 23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 24 | params = sum([np.prod(p.size()) for p in model_parameters]) 25 | return super().__str__() + 
'\nTrainable parameters: {}'.format(params) 26 | -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from utils import read_json 5 | 6 | 7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): 8 | """ 9 | Setup logging configuration 10 | """ 11 | log_config = Path(log_config) 12 | if log_config.is_file(): 13 | config = read_json(log_config) 14 | # modify logging paths based on run config 15 | for _, handler in config['handlers'].items(): 16 | if 'filename' in handler: 17 | handler['filename'] = str(save_dir / handler['filename']) 18 | 19 | logging.config.dictConfig(config) 20 | else: 21 | print("Warning: logging configuration file is not found in {}.".format(log_config)) 22 | logging.basicConfig(level=default_level) 23 | -------------------------------------------------------------------------------- /data_loader/data_loader.py: -------------------------------------------------------------------------------- 1 | from base import BaseDataLoader, BaseDataLoaderExplicitSplit 2 | from torchvision import transforms 3 | from data_loader.CondensedMovies_dataset import CondensedMovies 4 | 5 | class CondensedMoviesDataLoader(BaseDataLoaderExplicitSplit): 6 | """ 7 | CondensedeMovies DataLoader. 8 | """ 9 | 10 | def __init__(self, data_dir, experts, batch_size, split='train', shuffle=True, num_workers=4): 11 | self.data_dir = data_dir 12 | self.dataset = CondensedMovies(data_dir, experts, split) 13 | self.dataset_name = 'CondensedMovies' 14 | # batch size of entire val test set. change this for intra-movie 15 | if split in ['train', 'val']: 16 | drop_last = True 17 | else: 18 | drop_last = False 19 | # batch_size = len(self.dataset.data'clips') 20 | super().__init__(self.dataset, batch_size, shuffle, num_workers, drop_last=drop_last) -------------------------------------------------------------------------------- /logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": "logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /utils/visualisation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib 4 | 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | 8 | import ipdb 9 | 10 | def visualise_path(pred, target, window): 11 | """ 12 | :param pred: (P, 2) Tensor where P is the number of predictions, and 2 is the (i,j) coordinate 13 | :param target: (T, 2) Tensor where T is the number of targets, and 2 is the (i,j) coordinate 14 | :param 
dims: (H, W) tup/list the desired height and width of matrix (should be >= to max(i), max(j)) 15 | :param assignment_method: Method of assignment (dtw, minimum etc.) 16 | :return: image, visualisation of path prediction and target. 17 | """ 18 | tp = torch.Tensor((64, 191, 64)) 19 | fp = torch.Tensor((191, 64, 64)) 20 | gt = torch.Tensor((102, 153, 255)) 21 | 22 | grid = torch.ones_like(window).unsqueeze(0).repeat(3, 1, 1) * 255 23 | inf = 130 * torch.ones_like(grid) 24 | grid = torch.where(torch.isnan(window), inf, grid) 25 | 26 | clip_idxs = [t[0] for t in target] 27 | local_idxs = np.unique(np.array(clip_idxs)).tolist() 28 | 29 | for t in target: 30 | local_idx = local_idxs.index(t[0]) 31 | grid[:, local_idx,t[1]] = gt 32 | 33 | for p in pred: 34 | local_idx = local_idxs.index(p[0]) 35 | if (grid[:, local_idx,p[1]] == gt).all(): 36 | grid[:, local_idx, p[1]] = tp 37 | else: 38 | grid[:, local_idx, p[1]] = fp 39 | 40 | return grid / 255 41 | 42 | 43 | def batch_path_vis(pred_dict, target, window): 44 | 45 | grids = [] 46 | 47 | window = window.cpu() 48 | for key, pred in pred_dict.items(): 49 | tmp_window = window 50 | if key == 'min_dist': 51 | tmp_window = torch.zeros_like(window) 52 | grids.append(visualise_path(pred, target, tmp_window)) 53 | 54 | return torch.stack(grids) 55 | 56 | 57 | 58 | if __name__ == "__main__": 59 | pred = [[1,1], [2,4]] 60 | gt = [[1,1], [3,4]] 61 | window = torch.zeros((5,6)) 62 | visualise_path(pred, gt, window) 63 | -------------------------------------------------------------------------------- /configs/baseline.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "baseline_allexperts", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "BaselineModel", 6 | "args": { 7 | "text_params": { 8 | "model": "distilbert-base-uncased", 9 | "pretrained": true, 10 | "input": "text" 11 | }, 12 | "projection_dim": 256 13 | } 14 | }, 15 | "data_loader": { 16 | "type": "CondensedMoviesDataLoader", 17 | "args":{ 18 | "data_dir": "data/CondensedMovies", 19 | "shuffle": true, 20 | "num_workers": 4, 21 | "batch_size": 64, 22 | "split": "train", 23 | "experts": { 24 | "resnext101": { 25 | "src": "pred_imagenet_25fps_256px_stride25_offset0/resnext101_32x48d", 26 | "max_tokens": 128, 27 | "use": true 28 | }, 29 | "senet154": { 30 | "src": "pred_imagenet_25fps_256px_stride25_offset0/senet154", 31 | "max_tokens": 128, 32 | "use": true 33 | }, 34 | "i3d": { 35 | "src": "pred_i3d_25fps_256px_stride8_offset0_inner_stride1/i3d", 36 | "max_tokens": 128, 37 | "use": true 38 | }, 39 | "vggish": { 40 | "src": "pred_audio/vggish", 41 | "max_tokens": 128, 42 | "use": true 43 | }, 44 | "scene": { 45 | "src": "pred_scene_25fps_256px_stride25_offset0/densenet161", 46 | "max_tokens": 128, 47 | "use": true 48 | }, 49 | "r2p1d": { 50 | "src": "pred_r2p1d_30fps_256px_stride16_offset0_inner_stride1/r2p1d-ig65m", 51 | "max_tokens": 128, 52 | "use": true 53 | }, 54 | "subtitles": { 55 | "src": "pred_subs/bert-base-uncased_line", 56 | "max_tokens": 128, 57 | "use": true 58 | } 59 | } 60 | } 61 | }, 62 | "optimizer": { 63 | "type": "AdamW", 64 | "args":{ 65 | "lr": 3e-4 66 | } 67 | }, 68 | "loss": { 69 | "type": "NormSoftmaxLoss", 70 | "args": { 71 | } 72 | }, 73 | "metrics": [ 74 | "t2v_metrics", 75 | "v2t_metrics" 76 | ], 77 | "trainer": { 78 | "epochs": 100, 79 | "max_samples_per_epoch": 30000, 80 | "save_dir": "exps", 81 | "save_period": 5, 82 | "verbosity": 2, 83 | "monitor": "min val_loss", 84 | "early_stop": 10, 85 | "neptune": false, 86 | 
"init_val": true 87 | }, 88 | "visualizer": { 89 | "type": "" 90 | } 91 | } -------------------------------------------------------------------------------- /configs/baseline_mini.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "baseline_subexp", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "BaselineModel", 6 | "args": { 7 | "text_params": { 8 | "model": "distilbert-base-uncased", 9 | "pretrained": true, 10 | "input": "text" 11 | }, 12 | "projection_dim": 256 13 | } 14 | }, 15 | "data_loader": { 16 | "type": "CondensedMoviesDataLoader", 17 | "args": { 18 | "data_dir": "data/CondensedMovies", 19 | "shuffle": true, 20 | "num_workers": 4, 21 | "batch_size": 64, 22 | "split": "train", 23 | "experts": { 24 | "resnext101": { 25 | "src": "pred_imagenet_25fps_256px_stride25_offset0/resnext101_32x48d", 26 | "max_tokens": 128, 27 | "use": false 28 | }, 29 | "senet154": { 30 | "src": "pred_imagenet_25fps_256px_stride25_offset0/senet154", 31 | "max_tokens": 128, 32 | "use": false 33 | }, 34 | "i3d": { 35 | "src": "pred_i3d_25fps_256px_stride8_offset0_inner_stride1/i3d", 36 | "max_tokens": 128, 37 | "use": true 38 | }, 39 | "vggish": { 40 | "src": "pred_audio/vggish", 41 | "max_tokens": 128, 42 | "use": true 43 | }, 44 | "scene": { 45 | "src": "pred_scene_25fps_256px_stride25_offset0/densenet161", 46 | "max_tokens": 128, 47 | "use": false 48 | }, 49 | "r2p1d": { 50 | "src": "pred_r2p1d_30fps_256px_stride16_offset0_inner_stride1/r2p1d-ig65m", 51 | "max_tokens": 128, 52 | "use": true 53 | }, 54 | "subtitles": { 55 | "src": "pred_subs/bert-base-uncased_line", 56 | "max_tokens": 128, 57 | "use": true 58 | } 59 | } 60 | } 61 | }, 62 | "optimizer": { 63 | "type": "AdamW", 64 | "args": { 65 | "lr": 3e-4 66 | } 67 | }, 68 | "loss": { 69 | "type": "NormSoftmaxLoss", 70 | "args": { 71 | } 72 | }, 73 | "metrics": [ 74 | "t2v_metrics", 75 | "v2t_metrics" 76 | ], 77 | "trainer": { 78 | "epochs": 100, 79 | "max_samples_per_epoch": 30000, 80 | "save_dir": "exps", 81 | "save_period": 5, 82 | "verbosity": 2, 83 | "monitor": "min val_loss", 84 | "early_stop": 10, 85 | "neptune": false, 86 | "init_val": true 87 | }, 88 | "visualizer": { 89 | "type": "" 90 | } 91 | } -------------------------------------------------------------------------------- /base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | 6 | 7 | class BaseDataLoader(DataLoader): 8 | """ 9 | Base class for all data loaders 10 | """ 11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 12 | self.validation_split = validation_split 13 | self.shuffle = shuffle 14 | 15 | self.batch_idx = 0 16 | self.n_samples = len(dataset) 17 | 18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 19 | 20 | self.init_kwargs = { 21 | 'dataset': dataset, 22 | 'batch_size': batch_size, 23 | 'shuffle': self.shuffle, 24 | 'collate_fn': collate_fn, 25 | 'num_workers': num_workers 26 | } 27 | super().__init__(sampler=self.sampler, **self.init_kwargs) 28 | 29 | def _split_sampler(self, split): 30 | if split == 0.0: 31 | return None, None 32 | 33 | idx_full = np.arange(self.n_samples) 34 | 35 | np.random.seed(0) 36 | np.random.shuffle(idx_full) 37 | 38 | if isinstance(split, int): 39 | assert split > 0 40 | 
assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 41 | len_valid = split 42 | else: 43 | len_valid = int(self.n_samples * split) 44 | 45 | valid_idx = idx_full[0:len_valid] 46 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 47 | 48 | train_sampler = SubsetRandomSampler(train_idx) 49 | valid_sampler = SubsetRandomSampler(valid_idx) 50 | # turn off shuffle option which is mutually exclusive with sampler 51 | self.shuffle = False 52 | self.n_samples = len(train_idx) 53 | 54 | return train_sampler, valid_sampler 55 | 56 | def split_validation(self, diff_kwargs=None): 57 | init_kwargs = self.init_kwargs 58 | if diff_kwargs is not None: 59 | init_kwargs.update(diff_kwargs) 60 | if self.valid_sampler is None: 61 | return None 62 | else: 63 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 64 | 65 | 66 | class BaseDataLoaderExplicitSplit(DataLoader): 67 | """ 68 | Base class for all data loaders 69 | """ 70 | def __init__(self, dataset, batch_size, shuffle, num_workers, collate_fn=default_collate, drop_last=False): 71 | self.shuffle = shuffle 72 | 73 | self.batch_idx = 0 74 | self.n_samples = len(dataset) 75 | 76 | self.init_kwargs = { 77 | 'dataset': dataset, 78 | 'batch_size': batch_size, 79 | 'shuffle': self.shuffle, 80 | 'collate_fn': collate_fn, 81 | 'num_workers': num_workers, 82 | 'drop_last': drop_last 83 | } 84 | super().__init__(**self.init_kwargs) 85 | 86 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | class NormSoftmaxLoss(nn.Module): 7 | def __init__(self, temperature=0.05): 8 | super().__init__() 9 | 10 | self.temperature = temperature 11 | 12 | def forward(self, x): 13 | "Assumes input x is similarity matrix of N x M \in [-1, 1], computed using the cosine similarity between normalised vectors" 14 | i_logsm = F.log_softmax(x/self.temperature, dim=1) 15 | j_logsm = F.log_softmax(x.t()/self.temperature, dim=1) 16 | 17 | # sum over positives 18 | idiag = torch.diag(i_logsm) 19 | loss_i = idiag.sum() / len(idiag) 20 | 21 | jdiag = torch.diag(j_logsm) 22 | loss_j = jdiag.sum() / len(jdiag) 23 | 24 | return - loss_i - loss_j 25 | 26 | 27 | class MaxMarginRankingLoss(nn.Module): 28 | 29 | def __init__(self, margin=1, fix_norm=True): 30 | super().__init__() 31 | self.fix_norm = fix_norm 32 | self.loss = th.nn.MarginRankingLoss(margin) 33 | self.margin = margin 34 | 35 | def forward(self, x): 36 | n = x.size()[0] 37 | 38 | x1 = th.diag(x) 39 | x1 = x1.unsqueeze(1) 40 | x1 = x1.expand(n, n) 41 | x1 = x1.contiguous().view(-1, 1) 42 | x1 = th.cat((x1, x1), 0) 43 | 44 | x2 = x.view(-1, 1) 45 | x3 = x.transpose(0, 1).contiguous().view(-1, 1) 46 | 47 | x2 = th.cat((x2, x3), 0) 48 | max_margin = F.relu(self.margin - (x1 - x2)) 49 | 50 | if self.fix_norm: 51 | # remove the elements from the diagonal 52 | keep = th.ones(x.shape) - th.eye(x.shape[0]) # 128 x 128 53 | keep1 = keep.view(-1, 1) 54 | keep2 = keep.transpose(0, 1).contiguous().view(-1, 1) 55 | keep_idx = th.nonzero(th.cat((keep1, keep2), 0).flatten()).flatten() 56 | if x1.is_cuda: 57 | keep_idx = keep_idx.cuda() 58 | x1_ = th.index_select(x1, dim=0, index=keep_idx) 59 | x2_ = th.index_select(x2, dim=0, index=keep_idx) 60 | max_margin = F.relu(self.margin - (x1_ - x2_)) 61 | 62 | return max_margin.mean() 63 | 64 | 65 | class 
CrossEntropy(nn.Module): 66 | def __init__(self): 67 | super().__init__() 68 | self.loss = nn.CrossEntropyLoss() 69 | 70 | def forward(self, output, target): 71 | return self.loss(output, target) 72 | 73 | 74 | def cosine_sim(im, s): 75 | """Cosine similarity between all the image and sentence pairs 76 | """ 77 | return im.mm(s.t()) 78 | 79 | 80 | def order_sim(im, s): 81 | """Order embeddings similarity measure $max(0, s-im)$ 82 | """ 83 | YmX = (s.unsqueeze(1).expand(s.size(0), im.size(0), s.size(1)) 84 | - im.unsqueeze(0).expand(s.size(0), im.size(0), s.size(1))) 85 | score = -YmX.clamp(min=0).pow(2).sum(2).sqrt().t() 86 | return score 87 | 88 | 89 | def nll_loss(output, target): 90 | return F.nll_loss(output, target) 91 | 92 | 93 | if __name__ == "__main__": 94 | import torch 95 | 96 | random_sims = (torch.rand([10, 8]) * 2) - 1 97 | loss = NormSoftmaxLoss() 98 | loss(random_sims) 99 | -------------------------------------------------------------------------------- /logger/visualization.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from utils import Timer 3 | 4 | 5 | class TensorboardWriter(): 6 | def __init__(self, log_dir, logger, enabled): 7 | self.writer = None 8 | self.selected_module = "" 9 | 10 | if enabled: 11 | log_dir = str(log_dir) 12 | 13 | # Retrieve vizualization writer. 14 | succeeded = False 15 | for module in ["torch.utils.tensorboard", "tensorboardX"]: 16 | try: 17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 18 | succeeded = True 19 | break 20 | except ImportError: 21 | succeeded = False 22 | self.selected_module = module 23 | 24 | if not succeeded: 25 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ 26 | "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \ 27 | "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \ 28 | "the 'config.json' file." 29 | logger.warning(message) 30 | 31 | self.step = 0 32 | self.mode = '' 33 | 34 | self.tb_writer_ftns = { 35 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 36 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 37 | } 38 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} 39 | 40 | self.timer = Timer() 41 | 42 | def set_step(self, step, mode='train'): 43 | self.mode = mode 44 | self.step = step 45 | if step == 0: 46 | self.timer.reset() 47 | else: 48 | duration = self.timer.check() 49 | self.add_scalar('steps_per_sec', 1 / duration) 50 | 51 | def __getattr__(self, name): 52 | """ 53 | If visualization is configured to use: 54 | return add_data() methods of tensorboard with additional information (step, tag) added. 55 | Otherwise: 56 | return a blank function handle that does nothing 57 | """ 58 | if name in self.tb_writer_ftns: 59 | add_data = getattr(self.writer, name, None) 60 | 61 | def wrapper(tag, data, *args, **kwargs): 62 | if add_data is not None: 63 | # add mode(train/valid) tag 64 | if name not in self.tag_mode_exceptions: 65 | tag = '{}/{}'.format(tag, self.mode) 66 | add_data(tag, data, self.step, *args, **kwargs) 67 | return wrapper 68 | else: 69 | # default action for returning methods defined in this class, set_step() for instance. 
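                # Note: __getattr__ is only invoked when normal attribute lookup fails, so methods
                # defined on this class (set_step, for instance) never reach this branch. The
                # object.__getattr__ lookup below finds no such attribute, and the handler re-raises
                # the resulting AttributeError with a clearer message naming the selected writer module.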
70 | try: 71 | attr = object.__getattr__(name) 72 | except AttributeError: 73 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 74 | return attr 75 | 76 | 77 | class SacredNeptuneWriter(): 78 | def __init__(self): 79 | raise NotImplementedError -------------------------------------------------------------------------------- /requirements/environment.yaml: -------------------------------------------------------------------------------- 1 | name: cmd-chall 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - _pytorch_select=0.1=cpu_0 10 | - blas=1.0=mkl 11 | - bzip2=1.0.8=h7b6447c_0 12 | - ca-certificates=2021.7.5=h06a4308_1 13 | - certifi=2021.5.30=py37h06a4308_0 14 | - cffi=1.14.6=py37h400218f_0 15 | - cudatoolkit=11.1.74=h6bb024c_0 16 | - ffmpeg=4.3=hf484d3e_0 17 | - freetype=2.10.4=h5ab3b9f_0 18 | - gmp=6.2.1=h2531618_2 19 | - gnutls=3.6.15=he1e5248_0 20 | - intel-openmp=2019.4=243 21 | - jpeg=9b=h024ee3a_2 22 | - lame=3.100=h7b6447c_0 23 | - lcms2=2.12=h3be6417_0 24 | - ld_impl_linux-64=2.35.1=h7274673_9 25 | - libffi=3.3=he6710b0_2 26 | - libgcc-ng=9.3.0=h5101ec6_17 27 | - libgomp=9.3.0=h5101ec6_17 28 | - libiconv=1.15=h63c8f33_5 29 | - libidn2=2.3.2=h7f8727e_0 30 | - libmklml=2019.0.5=0 31 | - libpng=1.6.37=hbc83047_0 32 | - libstdcxx-ng=9.3.0=hd4cf53a_17 33 | - libtasn1=4.16.0=h27cfd23_0 34 | - libtiff=4.2.0=h85742a9_0 35 | - libunistring=0.9.10=h27cfd23_0 36 | - libuv=1.40.0=h7b6447c_0 37 | - libwebp-base=1.2.0=h27cfd23_0 38 | - lz4-c=1.9.3=h2531618_0 39 | - mkl=2020.2=256 40 | - mkl-service=2.3.0=py37he8ac12f_0 41 | - mkl_fft=1.3.0=py37h54f3939_0 42 | - mkl_random=1.1.1=py37h0573a6f_0 43 | - ncurses=6.2=he6710b0_1 44 | - nettle=3.7.3=hbbd107a_1 45 | - ninja=1.10.2=hff7bd54_1 46 | - numpy=1.19.2=py37h54aff64_0 47 | - numpy-base=1.19.2=py37hfa32c7d_0 48 | - olefile=0.46=py37_0 49 | - openh264=2.1.0=hd408876_0 50 | - openjpeg=2.3.0=h05c96fa_1 51 | - openssl=1.1.1k=h27cfd23_0 52 | - pillow=8.3.1=py37h2c7a002_0 53 | - pip=21.1.3=py37h06a4308_0 54 | - pycparser=2.20=py_2 55 | - python=3.7.10=h12debd9_4 56 | - pytorch=1.9.0=py3.7_cuda11.1_cudnn8.0.5_0 57 | - readline=8.1=h27cfd23_0 58 | - setuptools=52.0.0=py37h06a4308_0 59 | - six=1.16.0=pyhd3eb1b0_0 60 | - sqlite=3.36.0=hc218d9a_0 61 | - tk=8.6.10=hbc83047_0 62 | - torchaudio=0.9.0=py37 63 | - torchvision=0.10.0=py37_cu111 64 | - typing-extensions=3.10.0.0=hd3eb1b0_0 65 | - typing_extensions=3.10.0.0=pyh06a4308_0 66 | - wheel=0.36.2=pyhd3eb1b0_0 67 | - xz=5.2.5=h7b6447c_0 68 | - zlib=1.2.11=h7b6447c_3 69 | - zstd=1.4.9=haebb681_0 70 | - pip: 71 | - backcall==0.2.0 72 | - charset-normalizer==2.0.3 73 | - click==8.0.1 74 | - colorama==0.4.4 75 | - decorator==5.0.9 76 | - dominate==2.6.0 77 | - filelock==3.0.12 78 | - gitdb==4.0.7 79 | - gitpython==3.1.18 80 | - huggingface-hub==0.0.12 81 | - humanize==3.10.0 82 | - idna==3.2 83 | - importlib-metadata==4.6.1 84 | - ipdb==0.13.9 85 | - ipython==7.25.0 86 | - ipython-genutils==0.2.0 87 | - jedi==0.18.0 88 | - joblib==1.0.1 89 | - jsonpickle==1.5.2 90 | - matplotlib-inline==0.1.2 91 | - msgpack==1.0.2 92 | - munch==2.5.0 93 | - packaging==21.0 94 | - pandas==1.3.1 95 | - parso==0.8.2 96 | - pexpect==4.8.0 97 | - pickleshare==0.7.5 98 | - prompt-toolkit==3.0.19 99 | - psutil==5.8.0 100 | - ptyprocess==0.7.0 101 | - py-cpuinfo==8.0.0 102 | - pygments==2.9.0 103 | - pyparsing==2.4.7 104 | - python-dateutil==2.8.2 105 | - pytz==2021.1 106 | - pyyaml==5.4.1 107 | 
- regex==2021.7.6 108 | - requests==2.26.0 109 | - sacred==0.8.2 110 | - sacremoses==0.0.45 111 | - scikit-learn==0.24.2 112 | - scipy==1.7.0 113 | - smmap==4.0.0 114 | - threadpoolctl==2.2.0 115 | - tokenizers==0.10.3 116 | - toml==0.10.2 117 | - tqdm==4.61.2 118 | - traitlets==5.0.5 119 | - transformers==4.9.1 120 | - urllib3==1.26.6 121 | - wcwidth==0.2.5 122 | - wrapt==1.12.1 123 | - zipp==3.5.0 124 | prefix: /users/maxbain/miniconda3/envs/cmd-chall 125 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import data_loader.data_loader as module_data 3 | import model.loss as module_loss 4 | import model.metric as module_metric 5 | import model.model as module_arch 6 | import utils.visualizer as module_vis 7 | from parse_config import ConfigParser 8 | from trainer import Trainer 9 | from sacred import Experiment 10 | import transformers 11 | 12 | ex = Experiment('train') 13 | @ex.main 14 | def run(): 15 | logger = config.get_logger('train') 16 | 17 | if config['visualizer']['type'] != "": 18 | visualizer = config.initialize( 19 | name='visualizer', 20 | module=module_vis, 21 | exp_name=config['name'], 22 | web_dir=config._web_log_dir 23 | ) 24 | else: 25 | visualizer = None 26 | 27 | # build tokenizer 28 | tokenizer = transformers.AutoTokenizer.from_pretrained(config['arch']['args']['text_params']['model'], TOKENIZERS_PARALLELISM=False) 29 | 30 | # setup data_loader instances 31 | data_loader = config.initialize('data_loader', module_data) 32 | config['data_loader']['args']['split'] = 'val' 33 | valid_data_loader = config.initialize('data_loader', module_data) 34 | print('Train dataset: ', len(data_loader.sampler), ' samples') 35 | print('Val dataset: ', len(valid_data_loader.sampler), ' samples') 36 | # build model architecture, then print to console 37 | config['arch']['args']['experts_used'] = data_loader.dataset.experts_used 38 | 39 | model = config.initialize('arch', module_arch) 40 | logger.info(model) 41 | 42 | # get function handles of loss and metrics 43 | loss = config.initialize(name="loss", module=module_loss) 44 | metrics = [getattr(module_metric, met) for met in config['metrics']] 45 | # build optimizer, learning rate scheduler. 
delete every lines containing lr_scheduler for disabling scheduler 46 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 47 | optimizer = config.initialize('optimizer', transformers, trainable_params) 48 | lr_scheduler = None 49 | if 'lr_scheduler' in config._config: 50 | if hasattr(transformers, config._config['lr_scheduler']['type']): 51 | lr_scheduler = config.initialize('lr_scheduler', transformers, optimizer) 52 | else: 53 | print('lr scheduler not found') 54 | if config['trainer']['neptune']: 55 | writer = ex 56 | else: 57 | writer = None 58 | trainer = Trainer(model, loss, metrics, optimizer, 59 | config=config, 60 | data_loader=data_loader, 61 | valid_data_loader=valid_data_loader, 62 | lr_scheduler=lr_scheduler, 63 | visualizer=visualizer, 64 | writer=writer, 65 | tokenizer=tokenizer, 66 | max_samples_per_epoch=config['trainer']['max_samples_per_epoch'], 67 | init_val=config['trainer']['init_val']) 68 | 69 | trainer.train() 70 | 71 | 72 | if __name__ == '__main__': 73 | args = argparse.ArgumentParser(description='PyTorch Template') 74 | args.add_argument('-c', '--config', default=None, type=str, 75 | help='config file path (default: None)') 76 | args.add_argument('-r', '--resume', default=None, type=str, 77 | help='path to latest checkpoint (default: None)') 78 | args.add_argument('-d', '--device', default=None, type=str, 79 | help='indices of GPUs to enable (default: all)') 80 | args.add_argument('-o', '--observe', action='store_true', 81 | help='Whether to observe (neptune)') 82 | 83 | config = ConfigParser(args) 84 | ex.add_config(config._config) 85 | 86 | if config['trainer']['neptune']: 87 | from neptunecontrib.monitoring.sacred import NeptuneObserver 88 | raise ValueError("Neptune credentials not yet added") 89 | ex.observers.append(NeptuneObserver( 90 | api_token='', 91 | project_name='')) 92 | ex.run() 93 | else: 94 | run() 95 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from base import BaseModel 4 | from utils.expert_dims import expert_dims 5 | from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification 6 | import torch 7 | 8 | 9 | class BaselineModel(BaseModel): 10 | def __init__(self, 11 | experts_used, 12 | projection_dim, 13 | text_params, 14 | token_aggregation='mean'): 15 | super().__init__() 16 | 17 | self.experts_used = experts_used 18 | self.video_GU = nn.ModuleDict({ 19 | expert: Gated_Embedding_Unit(expert_dims[expert], projection_dim) 20 | for expert in experts_used 21 | }) 22 | 23 | self.text_model = AutoModel.from_pretrained(text_params['model']) 24 | txt_dim = self.text_model.config.hidden_size 25 | self.text_GU = nn.ModuleDict({ 26 | expert: Gated_Embedding_Unit(txt_dim, projection_dim, channels=0) 27 | for expert in experts_used 28 | 29 | }) 30 | 31 | self.token_agg = token_aggregation 32 | self.text_params = text_params 33 | 34 | def forward(self, x, eval=False): 35 | 36 | video_ftrs = {} 37 | for expert in self.experts_used: 38 | ftr = x['video'][expert]['ftr'] 39 | if self.token_agg == 'mean': 40 | ftr = ftr.sum(dim=1) / x['video'][expert]['n_tokens'].unsqueeze(1) 41 | else: 42 | raise NotImplementedError 43 | video_ftrs[expert] = ftr.float() 44 | 45 | video_mod = [] 46 | for exp, ftr in video_ftrs.items(): 47 | video_mod.append(self.video_GU[exp](ftr)) 48 | 49 | video_mod = torch.stack(video_mod, dim=1) 
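        # Shapes at this point (inferred from the code above): video_mod is
        # (batch, n_experts, projection_dim), one gated embedding per expert. The text branch
        # below builds a matching (batch, n_experts, projection_dim) stack so that the einsum
        # further down can score every text-video pair per expert before averaging over experts.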
50 | 51 | if self.text_params['model'].startswith('bert'): 52 | txt_ftr = self.text_model(x['text']['input_ids'], attention_mask=x['text']['attention_mask'])[ 53 | 'pooler_output'] 54 | elif self.text_params['model'].startswith('distilbert'): 55 | txt_ftr = self.text_model(**x['text']).last_hidden_state[:, 0, :] 56 | else: 57 | raise NotImplementedError 58 | 59 | txt_mod = [self.text_GU[exp](txt_ftr) for exp in self.experts_used] 60 | txt_mod = torch.stack(txt_mod, dim=1) 61 | 62 | video_mod = F.normalize(video_mod, dim=-1) 63 | txt_mod = F.normalize(txt_mod, dim=-1) 64 | embed_stack = torch.einsum('ted,ved->tve', txt_mod, video_mod) 65 | conf_mat = embed_stack.sum(dim=2) / len(self.experts_used) 66 | 67 | if eval: 68 | return conf_mat, txt_mod, video_mod 69 | return conf_mat 70 | 71 | 72 | class Gated_Embedding_Unit(nn.Module): 73 | def __init__(self, input_dimension, output_dimension, gating=True, channels=0): 74 | super(Gated_Embedding_Unit, self).__init__() 75 | 76 | self.fc = nn.Linear(input_dimension, output_dimension) 77 | self.cg = Context_Gating(output_dimension, channels) 78 | self.gating = gating 79 | 80 | def forward(self, x): 81 | x = self.fc(x) 82 | if self.gating: 83 | x = self.cg(x) 84 | x = F.normalize(x, dim=-1) 85 | 86 | return x 87 | 88 | 89 | class Context_Gating(nn.Module): 90 | def __init__(self, dimension, channels, add_batch_norm=True): 91 | super(Context_Gating, self).__init__() 92 | self.fc = nn.Linear(dimension, dimension) 93 | self.add_batch_norm = add_batch_norm 94 | self.channels = channels 95 | if channels > 0: 96 | bn_dim = channels 97 | else: 98 | bn_dim = dimension 99 | self.batch_norm = nn.BatchNorm1d(bn_dim) 100 | 101 | def forward(self, x): 102 | x1 = self.fc(x) 103 | 104 | if self.add_batch_norm: 105 | x1 = self.batch_norm(x1) 106 | 107 | x = torch.cat((x, x1), -1) 108 | 109 | return F.glu(x, -1) 110 | 111 | 112 | def sim_matrix(a, b, eps=1e-8): 113 | """ 114 | added eps for numerical stability 115 | """ 116 | a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] 117 | a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) 118 | b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) 119 | sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) 120 | return sim_mt 121 | 122 | 123 | if __name__ == "__main__": 124 | pass 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Condensed Movies Challenge 2 | 3 | The official code repository for the 2021 Condensed Movies Challenge, held at the 4th Workshop on Closing the Loop Between Vision and Language, held in conjunction with ICCV 2021. This repository contains the code and details for the data download, baseline model, and evaluation. 4 | 5 | ## Dataset Download 6 | 7 | To participate in the challenge, you must first download the new **Challenge version of the Condensed Movies dataset.** Use the following instructions to download the challenge version of the Condensed Movies dataset: 8 | 9 | 1) Clone this repository. 10 | 11 | 2) **Download and unpack the dataset** (xGB). First, you can optionally choose where to download the features to (otherwise its downloaded to `./data`). 
To do this: 12 | - [Optional] Set environment variable `export DATA_DIR="PATH_WHERE_YOU_WANT_TO_STORE_DATA"` 13 | - Download the dataset (see [here](https://www.robots.ox.ac.uk/~vgg/research/condensed-movies/challenge.html "here")), unpack the tar.gz 14 | - [Optional] If you set custom DATA_DIR, set up a symlink so it maps to ./data: `cd data; ln -s $DATA_DIR/CondensedMovies .; cd ..` 15 | 16 | 3) **OPTIONAL:** For additional dataset queries, please contact maxbain@robots.ox.ac.uk 17 | 18 | ## Dataset Overview 19 | 20 | Here, we provide an overview and detail for the downloaded dataset. For more information on the following features in the dataset, including details architectures and datsets, see [here](https://www.robots.ox.ac.uk/~vgg/research/condensed-movies/features.html "here"). Below is an overview of the dataset tree structure with details: 21 | 22 | ``` 23 | 24 | ├── features 25 | │ ├── pred_audio 26 | │ │ └── vggish (audio features) 27 | │ ├── pred_i3d_25fps_256px_stride8_offset0_inner_stride1 28 | │ │ └── i3d (action features) 29 | │ ├── pred_imagenet_25fps_256px_stride25_offset0 30 | │ │ └── resnext101_32x48d (Instagram Hashtags, fine-tuned on ImageNet features) 31 | │ │ └── senet154 (ImageNet features) 32 | │ ├── pred_r2p1d_30fps_256px_stride16_offset0_inner_stride1 33 | │ │ └── r2p1d-ig65m (Instagram 65m, fine-tuned on Kinetics) 34 | │ └── pred_scene_25fps_256px_stride25_offset0 35 | │ │ └── densenet161 (scene features) 36 | │ └── pred_subs 37 | │ │ └── bert-base-uncased_line (BERT subtitle features) 38 | ├── metadata 39 | │ ├── subs_test.json (raw text subtitle files for the test set) 40 | │ ├── subs_train_val.json (raw text subtitle files for the train/val set) 41 | │ ├── test_challf0.csv (raw text descriptions for the test set) 42 | │ └── train_val_challf0.csv (raw text descriptions for the train/val set) 43 | 44 | ``` 45 | 46 | Below is an overview for the dataset tree structure within a specific feature directory. The features from the train/val videos are further subaranged by year (i.e. 2011 -> 2019 directories). The features from the test videos are found in the 'test' directory. 47 | 48 | ``` 49 | 50 | └── vggish 51 | ├── 2011 52 |    ├── sBjNpZ5t9S4.npy 53 |    ├── SBLMTDMdTIU.npy 54 |    ├── SbTWLdT_tgk.npy 55 |    ├── ... 56 | ├── 2012 57 |    ├── ... 58 | ├── 2013 59 |    ├── ... 60 | ├── 2014 61 |    ├── ... 62 | ├── 2015 63 |    ├── ... 64 | ├── 2016 65 |    ├── ... 66 | ├── 2017 67 |    ├── ... 68 | ├── 2018 69 |    ├── ... 70 | ├── 2019 71 |    ├── ... 72 | └── test (the features for the test videos) 73 | ├── 1061.npy 74 | ├── 1062.npy 75 | ├── 1063.npy 76 | ├── ... 77 | ``` 78 | 79 | 80 | **Train/Val/Test Splits:** the splits are contained in and read from the text description csv files (e.g. metadata/train_val_challf0.csv & metadata/test_challf0.csv) 81 | 82 | 83 | ## Training and Evaluation 84 | 85 | ### 📝 Preparation 86 | 87 | Create conda env `conda env create -f requirements/environment.yaml` (assumes CUDA 11.1, adjust if needed). 88 | 89 | Experiment checkpoints / dataset saves to `exps` by default, can become large in size, set up symlink if you want to store elsewhere. 90 | 91 | ### 🏋️‍️ Baseline Training 92 | 93 | `python train.py --config configs/baseline.json` 94 | 95 | Adjust batch_size, n_gpu, exp_name in the config file accordingly. 
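For reference, these settings live at the top level and under `data_loader.args` of the config; the snippet below mirrors the shipped `configs/baseline.json` (the experiment name is the top-level `"name"` field), so adjust the values to your hardware:

```
{
    "name": "baseline_allexperts",
    "n_gpu": 4,
    "data_loader": {
        "args": {
            "batch_size": 64,
            "num_workers": 4
        }
    }
}
```

Individual experts can also be switched on or off per run via the `"use"` flag of each entry under `data_loader.args.experts`; see `configs/baseline_mini.json` for a reduced-expert example.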
96 | 97 | ### 🏁 Evaluation 98 | 99 | #### Validation Check 100 | Evaluate on val set `python test.py --resume exps/models/{EXP_NAME}/{TIMESTAMP}/model_best.pth --split val` 101 | 102 | #### Test Submission 103 | Evaluate on test set `python test.py --resume exps/models/{EXP_NAME}/{TIMESTAMP}/model_best.pth --split test` 104 | 105 | Similarity matrix should be zipped at `exps/models/{EXP_NAME}/{TIMESTAMP}/submission.zip`. 106 | Please upload this to codalab for your submission. 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /utils/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import torch 3 | from torch import Tensor 4 | from typing import List, Tuple, Any, Optional 5 | from torchvision.transforms import functional_pil as F_pil 6 | from torchvision.transforms import functional_tensor as F_t 7 | from torchvision.transforms.functional import center_crop, crop 8 | 9 | def _get_image_size(img: Tensor) -> List[int]: 10 | """Returns image size as [w, h] 11 | """ 12 | if isinstance(img, torch.Tensor): 13 | return F_t._get_image_size(img) 14 | 15 | return F_pil._get_image_size(img) 16 | 17 | def center_plus_four_crops(img: Tensor, size: List[int], 18 | margin_h: int, margin_w: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: 19 | """Crop the given image into four tiled borders and the central crop. 20 | """ 21 | 22 | if isinstance(size, numbers.Number): 23 | size = (int(size), int(size)) 24 | elif isinstance(size, (tuple, list)) and len(size) == 1: 25 | size = (size[0], size[0]) 26 | 27 | if len(size) != 2: 28 | raise ValueError("Please provide only two dimensions (h, w) for size.") 29 | 30 | image_width, image_height = _get_image_size(img) 31 | 32 | crop_height, crop_width = size 33 | 34 | if crop_width > image_width or crop_height > image_height: 35 | msg = "Requested crop size {} is bigger than input size {}" 36 | raise ValueError(msg.format(size, (image_height, image_width))) 37 | 38 | if crop_width + margin_w > image_width: 39 | msg = "Requested margin size {} + input {} is bigger than input size {}" 40 | raise ValueError(msg.format((margin_h, margin_w), size, (image_height, image_width))) 41 | 42 | #vertical_border_height = image_height - crop_height 43 | #horizontal_border_height = image_width - crop_width 44 | 45 | #x1 = horizontal_border_height // 2 46 | x11 = (image_width - crop_width - 2 * margin_w) // 2 47 | x12 = x11 + margin_w 48 | x21 = x12 + crop_width 49 | x22 = x21 + margin_w 50 | 51 | y11 = (image_height - crop_height - 2 * margin_h) // 2 52 | y12 = y11 + margin_h 53 | y21 = y12 + crop_height 54 | y22 = y21 + margin_h 55 | 56 | tl = crop(img, y11, x11, margin_h, margin_w + crop_width) 57 | tr = crop(img, y11, x21, margin_h + crop_height, margin_w) 58 | bl = crop(img, y12, x11, margin_h + crop_height, margin_w) 59 | br = crop(img, y21, x12, margin_h, margin_w + crop_width) 60 | center = center_crop(img, [crop_height, crop_width]) 61 | 62 | return tl, tr, bl, br, center 63 | 64 | 65 | 66 | def center_plus_twohori_crops(img: Tensor, size: List[int], 67 | margin_w: int) -> Tuple[Tensor, Tensor, Tensor]: 68 | """Crop the given image into four tiled borders and the central crop. 
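    For this two-crop variant the outputs are the left and right margin strips plus the central crop, returned as (left, right, center).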
69 | """ 70 | 71 | if isinstance(size, numbers.Number): 72 | size = (int(size), int(size)) 73 | elif isinstance(size, (tuple, list)) and len(size) == 1: 74 | size = (size[0], size[0]) 75 | 76 | if len(size) != 2: 77 | raise ValueError("Please provide only two dimensions (h, w) for size.") 78 | 79 | image_width, image_height = _get_image_size(img) 80 | 81 | crop_height, crop_width = size 82 | 83 | if crop_width > image_width or crop_height > image_height: 84 | msg = "Requested crop size {} is bigger than input size {}" 85 | raise ValueError(msg.format(size, (image_height, image_width))) 86 | 87 | if crop_width + margin_w > image_width : 88 | msg = "Requested margin size {} + input {} is bigger than input size {}" 89 | raise ValueError(msg.format((0, margin_w), size, (image_height, image_width))) 90 | 91 | # vertical_border_height = image_height - crop_height 92 | # horizontal_border_height = image_width - crop_width 93 | 94 | # x1 = horizontal_border_height // 2 95 | x11 = (image_width - crop_width - 2 * margin_w) // 2 96 | x12 = x11 + margin_w 97 | x21 = x12 + crop_width 98 | 99 | y11 = (image_height - crop_height) // 2 100 | 101 | left = crop(img, y11, x11, crop_height, margin_w) 102 | right = crop(img, y11, x21, crop_height, margin_w) 103 | center = center_crop(img, [crop_height, crop_width]) 104 | 105 | return left, right, center 106 | 107 | from torch import nn 108 | class TwoHoriCrop(nn.Module): 109 | def __init__(self, size, margin_w): 110 | super().__init__() 111 | self.size = size 112 | self.margin_w = margin_w 113 | 114 | def forward(self, x): 115 | return center_plus_twohori_crops(x, self.size, self.margin_w) 116 | 117 | if __name__ == "__main__": 118 | from PIL import Image 119 | 120 | img = Image.open('visualisations/guitar.png') 121 | crops = center_plus_four_crops(img, [336, 336], 112, 112) 122 | order = ['tl', 'tr', 'bl', 'br', 'center'] 123 | 124 | for idx, subimg in zip(order, crops): 125 | subimg.save(f'visualisations/guitar_{idx}.png') 126 | 127 | crops = center_plus_twohori_crops(img, [448, 448], 112) 128 | order = ['left', 'right', 'center2'] 129 | 130 | for idx, subimg in zip(order, crops): 131 | subimg.save(f'visualisations/guitar_{idx}.png') 132 | -------------------------------------------------------------------------------- /data_loader/CondensedMovies_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import pandas as pd 4 | from os.path import join as osj 5 | import numpy as np 6 | import random 7 | from utils.expert_dims import expert_dims 8 | 9 | 10 | class CondensedMovies(Dataset): 11 | def __init__(self, 12 | data_dir, 13 | experts, 14 | split='train'): 15 | 16 | self.data_dir = data_dir 17 | self.metadata_dir = osj(self.data_dir, 'challenge', 'metadata') 18 | self.experts = experts 19 | self.experts_used = [exp for exp, params in self.experts.items() if params['use']] 20 | self.split = split 21 | self.load_metadata() 22 | 23 | def load_metadata(self): 24 | 25 | if self.split in ['train', 'val']: 26 | df = pd.read_csv(osj(self.metadata_dir, 'train_val_challf0.csv')) 27 | elif self.split == 'test': 28 | df = pd.read_csv(osj(self.metadata_dir, 'test_challf0.csv')).sort_values('videoid') 29 | df.sort_values('videoid', inplace=True) 30 | else: 31 | raise ValueError("Split should be either train, val or test") 32 | 33 | df = df[df['split'] == self.split] 34 | df['videoid'] = df['videoid'].astype(str) 35 | self.data = df 36 | 37 | def __len__(self): 38 | return 
len(self.data) 39 | 40 | def __getitem__(self, item): 41 | sample = self.data.iloc[item] 42 | 43 | datum = {} 44 | video_data = {} 45 | for expert in self.experts_used: 46 | ftr_fp = osj(self.data_dir, 'challenge', 'features', self.experts[expert]['src'], 47 | str(sample['upload_year']), 48 | sample['videoid'] + self.experts[expert].get('ext', '.npy')) 49 | if not os.path.isfile(ftr_fp): 50 | if expert != "subtitles": 51 | raise ValueError( 52 | "All features should be available for every video, except for subtitles sometimes.\n" 53 | f"{expert} ftr not found for {sample['videoid']}, {ftr_fp}") 54 | # just fill features with zeros 55 | ftr = np.zeros([1, expert_dims[expert]]) 56 | else: 57 | ftr = np.load(ftr_fp, allow_pickle=True).item()['raw_feats'] 58 | ftr, toks = self._pad_to_max_tokens(ftr, self.experts[expert]['max_tokens']) 59 | video_data[expert] = {'ftr': ftr, 'n_tokens': toks, } 60 | 61 | datum['text'] = sample['caption'] 62 | datum['video'] = video_data 63 | return datum 64 | 65 | def _pad_to_max_tokens(self, ftr, max_tokens): 66 | """ 67 | Pads or truncates features to max tokens. 68 | For truncation, at test time use center, at training use random. 69 | """ 70 | output_shape = list(ftr.shape) 71 | output_shape[0] = max_tokens 72 | output_arr = np.zeros(output_shape) 73 | 74 | if ftr.shape[0] <= max_tokens: 75 | output_arr[:ftr.shape[0]] = ftr 76 | n_tokens = ftr.shape[0] 77 | else: 78 | if self.split == 'train': 79 | start_idx = random.randint(0, ftr.shape[0] - max_tokens) 80 | else: 81 | start_idx = int((ftr.shape[0] / 2) - (max_tokens / 2)) 82 | n_tokens = max_tokens 83 | output_arr = ftr[start_idx:start_idx + max_tokens] 84 | 85 | return output_arr, n_tokens 86 | 87 | 88 | if __name__ == "__main__": 89 | # make_new_splits() 90 | ds = CondensedMovies('/scratch/local/ssd/maxbain/CondensedMovies/', 91 | { 92 | "resnext101": { 93 | "src": "pred_imagenet_25fps_256px_stride25_offset0/resnext101_32x48d", 94 | "max_tokens": 128, 95 | "use": True 96 | }, 97 | "senet154": { 98 | "src": "pred_imagenet_25fps_256px_stride25_offset0/senet154", 99 | "max_tokens": 128, 100 | "use": True 101 | }, 102 | "i3d": { 103 | "src": "pred_i3d_25fps_256px_stride8_offset0_inner_stride1/i3d", 104 | "max_tokens": 128, 105 | "use": True 106 | }, 107 | "vggish": { 108 | "src": "pred_audio/vggish", 109 | "max_tokens": 128, 110 | "use": True, 111 | }, 112 | "densenet161": { 113 | "src": "pred_scene_25fps_256px_stride25_offset0/densenet161", 114 | "max_tokens": 128, 115 | "use": True 116 | }, 117 | "r2p1d-ig65m": { 118 | "src": "pred_r2p1d_30fps_256px_stride16_offset0_inner_stride1/r2p1d-ig65m", 119 | "max_tokens": 128, 120 | "use": True 121 | }, 122 | "subtitles": { 123 | "src": "pred_subs/bert-base-uncased_line", 124 | "max_tokens": 128, 125 | "use": True 126 | } 127 | }, 128 | split='test' 129 | ) 130 | for x in range(len(ds)): 131 | ds.__getitem__(x) -------------------------------------------------------------------------------- /parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce 5 | from operator import getitem 6 | from datetime import datetime 7 | from logger import setup_logging 8 | from utils import read_json, write_json 9 | import time 10 | import inspect 11 | 12 | 13 | class ConfigParser: 14 | def __init__(self, args, options='', timestamp=True, test=False): 15 | # parse default and custom cli options 16 | for opt in options: 17 | 
args.add_argument(*opt.flags, default=None, type=opt.type) 18 | args = args.parse_args() 19 | 20 | if args.device: 21 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 22 | if args.resume is None: 23 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 24 | assert args.config is not None, msg_no_cfg 25 | self.cfg_fname = Path(args.config) 26 | config = read_json(self.cfg_fname) 27 | self.resume = None 28 | else: 29 | self.resume = Path(args.resume) 30 | resume_cfg_fname = self.resume.parent / 'config.json' 31 | config = read_json(resume_cfg_fname) 32 | if args.config is not None: 33 | config.update(read_json(Path(args.config))) 34 | 35 | # load config file and apply custom cli options 36 | self._config = _update_config(config, options, args) 37 | 38 | # set save_dir where trained model and log will be saved. 39 | save_dir = Path(self.config['trainer']['save_dir']) 40 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else '' 41 | 42 | exper_name = self.config['name'] 43 | self._save_dir = save_dir / 'models' / exper_name / timestamp 44 | self._web_log_dir = save_dir / 'web' / exper_name / timestamp 45 | self._log_dir = save_dir / 'log' / exper_name / timestamp 46 | 47 | if not test: 48 | self.save_dir.mkdir(parents=True, exist_ok=True) 49 | self.log_dir.mkdir(parents=True, exist_ok=True) 50 | 51 | # if set, remove all previous experiments with the current config 52 | if vars(args).get("purge_exp_dir", False): 53 | for dirpath in (self._save_dir, self._log_dir, self._web_log_dir): 54 | config_dir = dirpath.parent 55 | existing = list(config_dir.glob("*")) 56 | print(f"purging {len(existing)} directories from config_dir...") 57 | tic = time.time() 58 | os.system(f"rm -rf {config_dir}") 59 | print(f"Finished purge in {time.time() - tic:.3f}s") 60 | 61 | # save updated config file to the checkpoint dir 62 | if not test: 63 | write_json(self.config, self.save_dir / 'config.json') 64 | 65 | # configure logging module 66 | setup_logging(self.log_dir) 67 | self.log_levels = { 68 | 0: logging.WARNING, 69 | 1: logging.INFO, 70 | 2: logging.DEBUG 71 | } 72 | 73 | def initialize(self, name, module, *args, index=None, **kwargs): 74 | """ 75 | finds a function handle with the name given as 'type' in config, and returns the 76 | instance initialized with corresponding keyword args given as 'args'. 77 | """ 78 | if index is None: 79 | module_name = self[name]['type'] 80 | module_args = dict(self[name]['args']) 81 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 82 | module_args.update(kwargs) 83 | else: 84 | module_name = self[name][index]['type'] 85 | module_args = dict(self[name][index]['args']) 86 | 87 | # if parameter not in config subdict, then check if it's in global config. 88 | signature = inspect.signature(getattr(module, module_name).__init__) 89 | print(module_name) 90 | for param in signature.parameters.keys(): 91 | if param not in module_args and param in self.config: 92 | module_args[param] = self[param] 93 | 94 | return getattr(module, module_name)(*args, **module_args) 95 | 96 | def __getitem__(self, name): 97 | return self.config[name] 98 | 99 | def get_logger(self, name, verbosity=2): 100 | msg_verbosity = 'verbosity option {} is invalid. 
Valid options are {}.'.format(verbosity, 101 | self.log_levels.keys()) 102 | assert verbosity in self.log_levels, msg_verbosity 103 | logger = logging.getLogger(name) 104 | logger.setLevel(self.log_levels[verbosity]) 105 | return logger 106 | 107 | # setting read-only attributes 108 | @property 109 | def config(self): 110 | return self._config 111 | 112 | @property 113 | def save_dir(self): 114 | return self._save_dir 115 | 116 | @property 117 | def log_dir(self): 118 | return self._log_dir 119 | 120 | 121 | # helper functions used to update config dict with custom cli options 122 | def _update_config(config, options, args): 123 | for opt in options: 124 | value = getattr(args, _get_opt_name(opt.flags)) 125 | if value is not None: 126 | _set_by_path(config, opt.target, value) 127 | else: 128 | _set_by_path(config, opt.target, opt.default) 129 | return config 130 | 131 | 132 | def _get_opt_name(flags): 133 | for flg in flags: 134 | if flg.startswith('--'): 135 | return flg.replace('--', '') 136 | return flags[0].replace('--', '') 137 | 138 | 139 | def _set_by_path(tree, keys, value): 140 | """Set a value in a nested object in tree by sequence of keys.""" 141 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 142 | 143 | 144 | def _get_by_path(tree, keys): 145 | """Access a nested object in tree by sequence of keys.""" 146 | return reduce(getitem, keys, tree) 147 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from datetime import datetime 4 | from itertools import repeat 5 | from collections import OrderedDict 6 | import functools 7 | import time 8 | import socket 9 | import numpy as np 10 | import psutil 11 | import msgpack 12 | import humanize 13 | import os 14 | 15 | import sys 16 | 17 | def query_yes_no(question, default="yes"): 18 | """Ask a yes/no question via raw_input() and return their answer. 19 | 20 | "question" is a string that is presented to the user. 21 | "default" is the presumed answer if the user just hits . 22 | It must be "yes" (the default), "no" or None (meaning 23 | an answer is required of the user). 24 | 25 | The "answer" return value is True for "yes" or False for "no". 
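    Example (illustrative): query_yes_no("Remove existing experiment directory?", default="no") prompts with " [y/N] " and returns a bool.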
26 | """ 27 | valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False} 28 | if default is None: 29 | prompt = " [y/n] " 30 | elif default == "yes": 31 | prompt = " [Y/n] " 32 | elif default == "no": 33 | prompt = " [y/N] " 34 | else: 35 | raise ValueError("invalid default answer: '%s'" % default) 36 | 37 | while True: 38 | sys.stdout.write(question + prompt) 39 | choice = input().lower() 40 | if default is not None and choice == "": 41 | return valid[default] 42 | elif choice in valid: 43 | return valid[choice] 44 | else: 45 | sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n") 46 | 47 | def state_dict_data_parallel_fix(load_state_dict, curr_state_dict): 48 | load_keys = list(load_state_dict.keys()) 49 | curr_keys = list(curr_state_dict.keys()) 50 | 51 | redo_dp = False 52 | undo_dp = False 53 | if not curr_keys[0].startswith('module.') and load_keys[0].startswith('module.'): 54 | undo_dp = True 55 | elif curr_keys[0].startswith('module.') and not load_keys[0].startswith('module.'): 56 | redo_dp = True 57 | 58 | if undo_dp: 59 | from collections import OrderedDict 60 | new_state_dict = OrderedDict() 61 | for k, v in load_state_dict.items(): 62 | name = k[7:] # remove `module.` 63 | new_state_dict[name] = v 64 | # load params 65 | elif redo_dp: 66 | from collections import OrderedDict 67 | new_state_dict = OrderedDict() 68 | for k, v in load_state_dict.items(): 69 | name = 'module.' + k # remove `module.` 70 | new_state_dict[name] = v 71 | else: 72 | new_state_dict = load_state_dict 73 | return new_state_dict 74 | 75 | def print_numpy(x, val=True, shp=False): 76 | """Print the mean, min, max, median, std, and size of a numpy array 77 | Parameters: 78 | val (bool) -- if print the values of the numpy array 79 | shp (bool) -- if print the shape of the numpy array 80 | """ 81 | x = x.astype(np.float64) 82 | if shp: 83 | print('shape,', x.shape) 84 | if val: 85 | x = x.flatten() 86 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 87 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 88 | 89 | 90 | def mkdirs(paths): 91 | """create empty directories if they don't exist 92 | Parameters: 93 | paths (str list) -- a list of directory paths 94 | """ 95 | if isinstance(paths, list) and not isinstance(paths, str): 96 | for path in paths: 97 | mkdir(path) 98 | else: 99 | mkdir(paths) 100 | 101 | 102 | def mkdir(path): 103 | """create a single empty directory if it didn't exist 104 | Parameters: 105 | path (str) -- a single directory path 106 | """ 107 | if not os.path.exists(path): 108 | os.makedirs(path) 109 | 110 | def read_json(fname): 111 | with fname.open('rt') as handle: 112 | return json.load(handle, object_hook=OrderedDict) 113 | 114 | def write_json(content, fname): 115 | with fname.open('wt') as handle: 116 | json.dump(content, handle, indent=4, sort_keys=False) 117 | 118 | def inf_loop(data_loader): 119 | ''' wrapper function for endless data loader. 
''' 120 | for loader in repeat(data_loader): 121 | yield from loader 122 | 123 | def memory_summary(): 124 | vmem = psutil.virtual_memory() 125 | msg = ( 126 | f">>> Currently using {vmem.percent}% of system memory " 127 | f"{humanize.naturalsize(vmem.used)}/{humanize.naturalsize(vmem.available)}" 128 | ) 129 | print(msg) 130 | 131 | @functools.lru_cache(maxsize=64, typed=False) 132 | def memcache(path): 133 | suffix = Path(path).suffix 134 | print(f"loading features >>>", end=" ") 135 | tic = time.time() 136 | if suffix == ".npy": 137 | res = np_loader(path) 138 | else: 139 | raise ValueError(f"unknown suffix: {suffix} for path {path}") 140 | print(f"[Total: {time.time() - tic:.1f}s] ({socket.gethostname() + ':' + str(path)})") 141 | return res 142 | 143 | def np_loader(np_path, l2norm=False): 144 | with open(np_path, "rb") as f: 145 | data = np.load(f, encoding="latin1", allow_pickle=True) 146 | if isinstance(data, np.ndarray) and data.size == 1: 147 | data = data[()] # handle numpy dict storage convnetion 148 | if l2norm: 149 | print("L2 normalizing features") 150 | if isinstance(data, dict): 151 | for key in data: 152 | feats_ = data[key] 153 | feats_ = feats_ / max(np.linalg.norm(feats_), 1E-6) 154 | data[key] = feats_ 155 | elif data.ndim == 2: 156 | data_norm = np.linalg.norm(data, axis=1) 157 | data = data / np.maximum(data_norm.reshape(-1, 1), 1E-6) 158 | else: 159 | raise ValueError("unexpected data format {}".format(type(data))) 160 | return data 161 | 162 | 163 | class Timer: 164 | def __init__(self): 165 | self.cache = datetime.now() 166 | 167 | def check(self): 168 | now = datetime.now() 169 | duration = now - self.cache 170 | self.cache = now 171 | return duration.total_seconds() 172 | 173 | def reset(self): 174 | self.cache = datetime.now() 175 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from tqdm import tqdm 4 | import data_loader.data_loader as module_data 5 | import collections 6 | import model.metric as module_metric 7 | import model.model as module_arch 8 | from parse_config import ConfigParser 9 | from sacred import Experiment 10 | import transformers 11 | from trainer.trainer import verbose 12 | import numpy as np 13 | from utils.util import state_dict_data_parallel_fix 14 | import zipfile 15 | 16 | ex = Experiment('test') 17 | 18 | @ex.main 19 | def run(): 20 | 21 | # setup data_loader instances 22 | config['data_loader']['args']['shuffle'] = False 23 | data_loader = config.initialize('data_loader', module_data) 24 | tokenizer = transformers.AutoTokenizer.from_pretrained(config['arch']['args']['text_params']['model']) 25 | 26 | # build model architecture 27 | config['arch']['args']['experts_used'] = data_loader.dataset.experts_used 28 | model = config.initialize('arch', module_arch) 29 | 30 | # get function handles of loss and metrics 31 | metric_fns = [getattr(module_metric, met) for met in config['metrics']] 32 | 33 | if config.resume is not None: 34 | checkpoint = torch.load(config.resume) 35 | state_dict = checkpoint['state_dict'] 36 | new_state_dict = state_dict_data_parallel_fix(state_dict, model.state_dict()) 37 | model.load_state_dict(new_state_dict, strict=True) 38 | else: 39 | print('Using random weights') 40 | 41 | if config['n_gpu'] > 1: 42 | model = torch.nn.DataParallel(model) 43 | 44 | # prepare model for testing 45 | device = torch.device('cuda' if torch.cuda.is_available() else 
'cpu') 46 | model = model.to(device) 47 | model.eval() 48 | 49 | meta_arr = [] 50 | text_embed_arr = [] 51 | vid_embed_arr = [] 52 | print(len(data_loader)) 53 | with torch.no_grad(): 54 | for i, data in tqdm(enumerate(data_loader)): 55 | # leave this for now since not doing anything on the gpu 56 | if tokenizer is not None: 57 | data['text'] = tokenizer(data['text'], return_tensors='pt', padding=True, truncation=True) 58 | data['text'] = {key: val.cuda() for key, val in data['text'].items()} 59 | 60 | 61 | _, text_embed, vid_embed = model(data, eval=True) 62 | text_embed_arr.append(text_embed) 63 | vid_embed_arr.append(vid_embed) 64 | 65 | text_embeds = torch.cat(text_embed_arr) 66 | vid_embeds = torch.cat(vid_embed_arr) 67 | 68 | embed_stack = torch.einsum('ted,ved->tve', text_embeds, vid_embeds) 69 | sims = embed_stack.sum(dim=2) / embed_stack.shape[2] 70 | #sims = sim_matrix(text_embeds, vid_embeds) 71 | sims = sims.detach().cpu().numpy() 72 | 73 | # similarity matrix checks 74 | if sims.min() < 0 or sims.max() > 1: 75 | raise ValueError(f"Similarity matrix should be \in [0,1], found {sims.min(), sims.max()}") 76 | 77 | if len(sims.shape) != 2: 78 | raise ValueError(f"Similarity matrix should be 2-D, not {sims.shape}") 79 | 80 | if sims.shape[0] != sims.shape[1]: 81 | raise ValueError(f"Expects similarity matrix to be square, since num_captions == num_videos, received {sims.shape}") 82 | 83 | 84 | # save similarity matrix 85 | if config.resume is not None: 86 | sim_save_dir = config.resume.parent 87 | else: 88 | sim_save_dir = config._save_dir 89 | if not sim_save_dir.exists(): 90 | sim_save_dir.mkdir() 91 | 92 | sim_save_fp = sim_save_dir / f"sim_matrix_{data_loader.dataset.split}.npy" 93 | np.save(sim_save_fp, sims) 94 | 95 | txt_save_fp = sim_save_dir / f"txt_embeds__{data_loader.dataset.split}.npy" 96 | np.save(txt_save_fp, text_embeds.cpu().numpy()) 97 | 98 | vid_save_fp = sim_save_dir / f"vid_embeds__{data_loader.dataset.split}.npy" 99 | np.save(vid_save_fp, vid_embeds.cpu().numpy()) 100 | 101 | if data_loader.dataset.split == 'val': 102 | #if True: 103 | # load from numpy file 104 | # sims = np.load(...)
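        # A minimal offline re-scoring sketch (not part of the original flow, hence kept
        # commented out): the arrays saved above can be reloaded and pushed through the same
        # metric loop below without re-running the model, e.g.
        #   sims = np.load(sim_save_fp)
        #   text_embeds = torch.from_numpy(np.load(txt_save_fp))
        #   vid_embeds = torch.from_numpy(np.load(vid_save_fp))
        # `sim_save_fp`, `txt_save_fp` and `vid_save_fp` are the paths defined just above.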
105 | # DO what happens during evaluation code 106 | 107 | 108 | nested_metrics = {} 109 | 110 | for metric in metric_fns: 111 | metric_name = metric.__name__ 112 | res = metric(sims) 113 | verbose(epoch=0, metrics=res, name="", mode=metric_name) 114 | nested_metrics[metric_name] = res 115 | elif data_loader.dataset.split == 'test': 116 | # create zip file for submission 117 | submission_zip = sim_save_fp.parent / 'submission.zip' 118 | zipfile.ZipFile(submission_zip, mode='w').write(sim_save_fp, sim_save_fp.name) 119 | 120 | print(f"--For test submission, please upload {submission_zip} to the Codalab site.--\n" 121 | f"https://competitions.codalab.org/competitions/34124#participate-submit_results") 122 | 123 | # if config.config['visualizer']: 124 | # meta_arr_cat = {key: [] for key in meta_arr[0]} 125 | # for meta in meta_arr: 126 | # for key, val in meta.items(): 127 | # meta_arr_cat[key] += val 128 | 129 | if __name__ == '__main__': 130 | args = argparse.ArgumentParser(description='PyTorch Template') 131 | #args.add_argument('-t', '--test_submission', action='store_true', 132 | # help='whether to evaluate on test data for test submission, else val.') 133 | args.add_argument('-r', '--resume', default=None, type=str, 134 | help='path to latest checkpoint (default: None)') 135 | args.add_argument('-d', '--device', default=None, type=str, 136 | help='indices of GPUs to enable (default: all)') 137 | args.add_argument('-c', '--config', default=None, type=str, 138 | help='config file path (default: None)') 139 | 140 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type default target') 141 | options = [ 142 | CustomArgs(['--split'], type=str, default='val', target=('data_loader', 'args', 'split')), 143 | CustomArgs(['--bs', '--batch_size'], type=int, default=16, target=('data_loader', 'args', 'batch_size')), 144 | ] 145 | config = ConfigParser(args, options, test=True) 146 | 147 | if config._config['data_loader']['args']['split'] not in ['val', 'test']: 148 | raise ValueError("Split should be one of either val or test (the latter for submission), not ") 149 | 150 | ex.add_config(config.config) 151 | 152 | ex.run() 153 | -------------------------------------------------------------------------------- /utils/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br, video, source, attr 3 | from dominate.tags import span 4 | import os 5 | 6 | 7 | class HTML: 8 | """This HTML class allows us to save images and write texts into a single HTML file. 9 | 10 | It consists of functions such as (add a text header to the HTML file), 11 | (add a row of images to the HTML file), and (save the HTML to the disk). 12 | It is based on Python library 'dominate', a Python library for creating and 13 | manipulating HTML documents using a DOM API. 14 | """ 15 | 16 | def __init__(self, web_dir, title, refresh=0): 17 | """Initialize the HTML classes 18 | 19 | Parameters: 20 | web_dir (str) -- a directory that stores the webpage. 
HTML file will be 21 | created at /index.html; images will be saved at 0: 35 | with self.doc.head: 36 | meta(http_equiv="refresh", content=str(refresh)) 37 | 38 | def get_image_dir(self): 39 | """Return the directory that stores images""" 40 | return self.img_dir 41 | 42 | def add_header(self, text): 43 | """Insert a header to the HTML file 44 | 45 | Parameters: 46 | text (str) -- the header text 47 | """ 48 | with self.doc: 49 | h3(text) 50 | 51 | def add_videos(self, vids, txts, links, width=400, hidden_tag="hidden"): 52 | """add images to the HTML file 53 | 54 | Parameters: 55 | vids (str list) -- a list of image paths 56 | txts (str list) -- a list of image names shown on the website 57 | links (str list) -- a list of hyperref links; when you click an image, 58 | it will redirect you to a new page 59 | """ 60 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 61 | self.doc.add(self.t) 62 | colors = ["red", "blue", "gold", "salman"] 63 | with self.t: 64 | with tr(): 65 | for vid, txt, link in zip(vids, txts, links): 66 | td_style = "word-wrap: break-word; width:{}px".format(width) 67 | with td(style=td_style, halign="center", valign="top"): 68 | with p(): 69 | vid_path = str(vid) 70 | if vid_path == hidden_tag: 71 | p_style = "font-weight: bold; width:{}px;" 72 | p_style = p_style.format(width * 3) 73 | p("hidden video", style=p_style) 74 | else: 75 | with a(href=str(link)): 76 | with video(): 77 | attr(controls="controls") 78 | source(src=vid_path, type="video/mp4") 79 | br() 80 | rows = txt.split("
") 81 | for idx, row in enumerate(rows): 82 | color = colors[idx % len(colors)] 83 | bold_tag = "" 84 | if not row.startswith(bold_tag): 85 | s_style = "color:{};".format(color) 86 | else: 87 | s_style = "color:black; font-weight: bold;" 88 | row = row[len(bold_tag):] 89 | span(row, style=s_style) 90 | 91 | def add_images(self, ims, txts, links, width=400): 92 | """add images to the HTML file 93 | 94 | Parameters: 95 | ims (str list) -- a list of image paths 96 | txts (str list) -- a list of image names shown on the website 97 | links (str list) -- a list of hyperref links; when you click an image, 98 | it will redirect you to a new page 99 | """ 100 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 101 | self.doc.add(self.t) 102 | colors = ["red", "blue", "gold", "salman"] 103 | with self.t: 104 | with tr(): 105 | for im, txt, link in zip(ims, txts, links): 106 | td_style = "word-wrap: break-word;" 107 | with td(style=td_style, halign="center", valign="top"): 108 | with p(): 109 | with a(href=link): 110 | img( 111 | style="width:%dpx" % width, 112 | src=im, 113 | ) 114 | br() 115 | rows = txt.split("
") 116 | for idx, row in enumerate(rows): 117 | color = colors[idx % len(colors)] 118 | bold_tag = "" 119 | if not row.startswith(bold_tag): 120 | s_style = "color:{};".format(color) 121 | else: 122 | s_style = "color:black; font-weight: bold;" 123 | row = row[len(bold_tag):] 124 | span(row, style=s_style) 125 | 126 | def save(self): 127 | """save the current content to the HMTL file""" 128 | html_file = "%s/index.html" % self.web_dir 129 | f = open(html_file, "wt") 130 | f.write(self.doc.render()) 131 | f.close() 132 | 133 | 134 | if __name__ == "__main__": # we show an example usage here. 135 | html = HTML("web/", "test_html") 136 | html.add_header("hello world") 137 | 138 | ims, txts, links = [], [], [] 139 | for n in range(4): 140 | ims.append("image_%d.png" % n) 141 | txts.append("text_%d" % n) 142 | links.append("image_%d.png" % n) 143 | html.add_images(ims, txts, links) 144 | html.save() -------------------------------------------------------------------------------- /utils/visualizer.py: -------------------------------------------------------------------------------- 1 | """A simple HTML visualizer. 2 | 3 | It is based on the Cycle-GAN codebase: 4 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 5 | """ 6 | import os 7 | import numpy as np 8 | from pathlib import Path 9 | from . import util, html 10 | import pdb 11 | 12 | class RetrievalVis: 13 | """This class includes several functions that can display/save images. 14 | 15 | It uses a Python library 'visdom' for display, and a Python library 'dominate' 16 | (wrapped in 'HTML') for creating HTML files with images. 17 | """ 18 | 19 | def __init__(self, exp_name, web_dir, src_video_dir, vis_vid_freq, num_samples=50): 20 | """Initialize the Visualizer class 21 | Create an HTML object for saveing HTML filters 22 | """ 23 | self.name = exp_name 24 | self.web_dir = web_dir 25 | self.vis_vid_freq = vis_vid_freq 26 | self.img_dir = os.path.join(self.web_dir, "images") 27 | self.num_samples = num_samples 28 | 29 | self.data_type = 'images' # 'images' or 'videos' 30 | assert self.data_type in ('images', 'videos') 31 | 32 | print(f"create web directory {self.web_dir}...") 33 | mkdirs([self.web_dir, self.img_dir]) 34 | 35 | # cluster specific 36 | if "$TMPDIR" in src_video_dir: 37 | src_video_dir = src_video_dir.replace("$TMPDIR", os.environ['TMPDIR']) 38 | 39 | src_dir = Path(src_video_dir).absolute() 40 | print(f"symlinking videos from {src_dir}...") 41 | sym_dir = (Path(self.web_dir) / "videos").absolute() 42 | if sym_dir.is_symlink(): 43 | os.remove(sym_dir) 44 | sym_dir.symlink_to(src_dir) 45 | 46 | def visualize_ranking(self, sims, epoch, meta, nested_metrics): 47 | if not (self.vis_vid_freq and epoch % self.vis_vid_freq == 0): 48 | return 49 | 50 | dists = -sims 51 | np.random.seed(0) 52 | sorted_ranks = np.argsort(dists, axis=1) 53 | gt_dists = np.diag(dists) 54 | rankings = [] 55 | vis_top_k = 5 56 | hide_gt = False 57 | # num_indep_samples = 1 58 | # random_seeds = np.arange(num_indep_samples) 59 | sample = np.random.choice(np.arange(dists.shape[0]), size=self.num_samples, 60 | replace=False) 61 | for ii in sample: 62 | ranked_idx = sorted_ranks[ii][:vis_top_k] 63 | gt_captions = meta["raw_captions"][ii] 64 | # if args.sample_single_gt_caption: 65 | # gt_captions = np.random.choice(gt_captions, 1).tolist() 66 | datum = { 67 | "gt-sim": -gt_dists[ii], 68 | "gt-captions": gt_captions, 69 | "gt-rank": np.where(sorted_ranks[ii] == ii)[0][0], 70 | "gt-path": meta["paths"][ii], 71 | "top-k-sims": -dists[ii][ranked_idx], 72 | 
"top-k-paths": np.array(meta["paths"])[ranked_idx], 73 | "hide-gt": hide_gt, 74 | } 75 | rankings.append(datum) 76 | self.display_current_results( 77 | rankings, 78 | epoch=epoch, 79 | metrics=nested_metrics["t2v_metrics"], 80 | ) 81 | 82 | def display_current_results(self, rankings, epoch, metrics): 83 | """Display current results on visdom; save current results to an HTML file. 84 | 85 | Parameters: 86 | visuals (OrderedDict) - - dictionary of images to display or save 87 | epoch (int) - - the current epoch 88 | save_result (bool) - - if save the current results to an HTML file 89 | """ 90 | if not Path(self.web_dir).exists(): 91 | Path(self.web_dir).mkdir(exist_ok=True, parents=True) 92 | print(f"updating webpage at {self.web_dir}") 93 | title = f"Experiment name = {self.name}" 94 | refresh = True 95 | if not refresh: 96 | print("DISABLING WEB PAGE REFRESH") 97 | webpage = html.HTML(web_dir=self.web_dir, title=title, refresh=refresh) 98 | 99 | msg = f"epoch [{epoch}] - {self.name}" 100 | webpage.add_header(msg) 101 | msg = (f"R1: {metrics['R1']:.1f}, " 102 | f"R5: {metrics['R5']:.1f}, " 103 | f"R10: {metrics['R10']:.1f}, " 104 | f"MedR: {metrics['MedR']}") 105 | webpage.add_header(msg) 106 | print(f"Top {len(rankings[0])} retreived videos at epoch: {epoch}") 107 | 108 | for ranking in rankings: 109 | vids, txts, links = [], [], [] 110 | gt_vid_path = os.path.join('videos', ranking["gt-path"]) 111 | #gt_captions = [" ".join(x) for x in ranking["gt-captions"]] 112 | gt_captions = ranking['gt-captions'] 113 | gt_captions = "
" + (gt_captions) + "
" 114 | if ranking["hide-gt"]: 115 | txts.append(gt_captions) 116 | links.append("hidden") 117 | vids.append("hidden") 118 | else: 119 | txt = (f"{gt_captions}
Rank: {ranking['gt-rank']}, " 120 | f"Sim: {ranking['gt-sim']:.3f} [{Path(ranking['gt-path']).stem}]") 121 | txts.append(txt) 122 | links.append(gt_vid_path) 123 | vids.append(gt_vid_path) 124 | 125 | for idx, (vid_path, sim) in enumerate(zip(ranking["top-k-paths"], 126 | ranking["top-k-sims"])): 127 | vid_path = Path(os.path.join('videos', vid_path)) 128 | if ranking["hide-gt"]: 129 | txt = f"choice: {idx}" 130 | else: 131 | txt = f"Rank: {idx}, Sim: {sim:.3f}, [{Path(vid_path).stem}]" 132 | txts.append(txt) 133 | vids.append(vid_path) 134 | links.append(vid_path) 135 | if self.data_type == 'videos': 136 | webpage.add_videos(vids, txts, links, width=200) 137 | elif self.data_type == 'images': 138 | webpage.add_images(vids, txts, links, width=200) 139 | print(f"added {len(vids)} videos") 140 | webpage.save() 141 | 142 | def mkdirs(paths): 143 | """create empty directories if they don't exist 144 | 145 | Parameters: 146 | paths (str list) -- a list of directory paths 147 | """ 148 | if isinstance(paths, list) and not isinstance(paths, str): 149 | for path in paths: 150 | mkdir(path) 151 | else: 152 | mkdir(paths) 153 | 154 | 155 | def mkdir(path): 156 | """create a single empty directory if it didn't exist 157 | 158 | Parameters: 159 | path (str) -- a single directory path 160 | """ 161 | if not os.path.exists(path): 162 | os.makedirs(path) 163 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.utils import make_grid 4 | from base import BaseTrainer 5 | from utils import inf_loop 6 | from model.model import sim_matrix 7 | from itertools import cycle 8 | 9 | 10 | class Trainer(BaseTrainer): 11 | """ 12 | Trainer class 13 | 14 | Note: 15 | Inherited from BaseTrainer. 16 | """ 17 | 18 | def __init__(self, model, loss, metrics, optimizer, config, data_loader, 19 | valid_data_loader=None, lr_scheduler=None, len_epoch=None, writer=None, 20 | visualizer=None, tokenizer=None, max_samples_per_epoch=50000, init_val=True): 21 | super().__init__(model, loss, metrics, optimizer, config, writer, init_val=init_val) 22 | self.init_val = init_val 23 | self.config = config 24 | self.data_loader = data_loader 25 | if len_epoch is None: 26 | # epoch-based training 27 | self.len_epoch = len(self.data_loader) 28 | else: 29 | # iteration-based training 30 | self.data_loader = inf_loop(data_loader) 31 | self.len_epoch = len_epoch 32 | 33 | self.valid_data_loader = valid_data_loader 34 | self.do_validation = self.valid_data_loader is not None 35 | self.lr_scheduler = lr_scheduler 36 | self.log_step = int(np.sqrt(data_loader.batch_size)) 37 | self.visualizer = visualizer 38 | self.val_chunking = True 39 | self.batch_size = self.data_loader.batch_size 40 | self.tokenizer = tokenizer 41 | self.max_samples_per_epoch = max_samples_per_epoch 42 | 43 | def _eval_metrics(self, output): 44 | acc_metrics = np.zeros(len(self.metrics)) 45 | for i, metric in enumerate(self.metrics): 46 | acc_metrics[i] += metric(output) 47 | if self.writer is not None: 48 | self.writer.log_scalar('{}'.format(metric.__name__), acc_metrics[i]) 49 | return acc_metrics 50 | 51 | def _train_epoch(self, epoch): 52 | """ 53 | Training logic for an epoch 54 | 55 | :param epoch: Current training epoch. 56 | :return: A log that contains all information you want to save. 
57 | 58 | Note: 59 | If you have additional information to record, for example: 60 | > additional_log = {"x": x, "y": y} 61 | merge it with log before return. i.e. 62 | > log = {**log, **additional_log} 63 | > return log 64 | 65 | The metrics in log must have the key 'metrics'. 66 | """ 67 | self.model.train() 68 | total_loss = 0 69 | total_metrics = np.zeros(len(self.metrics)) 70 | for batch_idx, data in enumerate(self.data_loader): 71 | if (batch_idx + 1) * self.batch_size > self.max_samples_per_epoch: 72 | break 73 | # then assume we must tokenize the input, e.g. its a string 74 | if self.tokenizer is not None: 75 | data['text'] = self.tokenizer(data['text'], return_tensors='pt', padding=True, 76 | truncation=True) 77 | data['text'] = {key: val.cuda() for key, val in data['text'].items()} 78 | for key, val in data['video'].items(): 79 | data['video'][key]['ftr'] = data['video'][key]['ftr'].cuda() 80 | data['video'][key]['n_tokens'] = data['video'][key]['n_tokens'].cuda() 81 | 82 | self.optimizer.zero_grad() 83 | output = self.model(data) 84 | #output = sim_matrix(text_embeds, video_embeds) 85 | loss = self.loss(output) 86 | loss.backward() 87 | self.optimizer.step() 88 | if self.writer is not None: 89 | self.writer.log_scalar(f'loss_train', loss.detach().item()) 90 | 91 | total_loss += loss.detach().item() 92 | 93 | if batch_idx % self.log_step == 0: 94 | self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format( 95 | epoch, 96 | self._progress(batch_idx), 97 | loss.detach().item())) 98 | 99 | self.optimizer.zero_grad() 100 | 101 | if batch_idx == self.len_epoch: 102 | break 103 | 104 | log = { 105 | 'loss': total_loss / self.len_epoch, 106 | 'metrics': (total_metrics / self.len_epoch).tolist() 107 | } 108 | 109 | if self.do_validation: 110 | val_log = self._valid_epoch(epoch) 111 | log.update(val_log) 112 | 113 | if self.lr_scheduler is not None: 114 | self.lr_scheduler.step() 115 | 116 | return log 117 | 118 | def _valid_epoch(self, epoch): 119 | """ 120 | Validate after training an epoch 121 | 122 | :return: A log that contains information about validation 123 | 124 | Note: 125 | The validation metrics in log must have the key 'val_metrics'. 
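            Scoring sketch (assuming dim 1 of the stacked embeddings indexes the experts):
            the similarities computed below are the mean over experts of the per-expert
            dot products, i.e.
                sims[t, v] = (1 / E) * sum_e <text_embeds[t, e, :], vid_embeds[v, e, :]>
            which is what einsum('ted,ved->tve') followed by sum(dim=2) / shape[2]
            evaluates to.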
126 | """ 127 | self.model.eval() 128 | total_val_loss = 0 129 | total_val_metrics = np.zeros(len(self.metrics)) 130 | # self.valid_data_loader.dataset.__getitem__(0) 131 | meta_arr = [] 132 | text_embed_arr = [] 133 | vid_embed_arr = [] 134 | with torch.no_grad(): 135 | for batch_idx, data in enumerate(self.valid_data_loader): 136 | #meta_arr.append(data['meta']) 137 | if self.tokenizer is not None: 138 | data['text'] = self.tokenizer(data['text'], return_tensors='pt', padding=True, truncation=True) 139 | data['text'] = {key: val.cuda() for key, val in data['text'].items()} 140 | for key, val in data['video'].items(): 141 | data['video'][key]['ftr'] = data['video'][key]['ftr'].cuda() 142 | data['video'][key]['n_tokens'] = data['video'][key]['n_tokens'].cuda() 143 | sim_mat, text_embed, vid_embed = self.model(data, eval=True) 144 | text_embed_arr.append(text_embed.cpu()) 145 | vid_embed_arr.append(vid_embed.cpu()) 146 | loss = self.loss(sim_mat) 147 | total_val_loss += loss.item() 148 | 149 | text_embeds = torch.cat(text_embed_arr) 150 | vid_embeds = torch.cat(vid_embed_arr) 151 | #sims = sim_matrix(text_embeds, vid_embeds).detach().cpu().numpy() 152 | embed_stack = torch.einsum('ted,ved->tve', text_embeds, vid_embeds) 153 | sims = embed_stack.sum(dim=2) / embed_stack.shape[2] 154 | sims = sims.detach().cpu().numpy() 155 | 156 | # TODO: this needs a clean 157 | if self.writer is not None: 158 | self.writer.log_scalar(f'loss_val', total_val_loss / len(self.valid_data_loader)) 159 | nested_metrics = {} 160 | for metric in self.metrics: 161 | metric_name = metric.__name__ 162 | res = metric(sims) 163 | verbose(epoch=epoch, metrics=res, name=self.valid_data_loader.dataset_name, 164 | mode=metric_name) 165 | nested_metrics[metric_name] = res 166 | 167 | if self.writer is not None: 168 | to_write = format_nested_metrics_for_writer(res, mode=metric_name, 169 | name=self.valid_data_loader.dataset_name) 170 | for key, val in to_write.items(): 171 | self.writer.log_scalar(key, val) 172 | 173 | if self.visualizer is not None: 174 | meta_arr_cat = {key: [] for key in meta_arr[0]} 175 | for meta in meta_arr: 176 | for key, val in meta.items(): 177 | meta_arr_cat[key] += val 178 | self.visualizer.visualize_ranking(sims, epoch, meta_arr_cat, nested_metrics) 179 | 180 | res_dict = { 181 | 'val_loss': total_val_loss / len(self.valid_data_loader), 182 | 'nested_val_metrics': nested_metrics 183 | } 184 | 185 | return res_dict 186 | 187 | def _progress(self, batch_idx): 188 | base = '[{}/{} ({:.0f}%)]' 189 | if hasattr(self.data_loader, 'n_samples'): 190 | current = batch_idx * self.data_loader.batch_size 191 | total = self.data_loader.n_samples 192 | else: 193 | current = batch_idx 194 | total = self.len_epoch 195 | return base.format(current, total, 100.0 * current / total) 196 | 197 | 198 | def verbose(epoch, metrics, mode, name="TEST"): 199 | r1, r5, r10, r50 = metrics["R1"], metrics["R5"], metrics["R10"], metrics["R50"] 200 | msg = f"[{mode}]{name:s} epoch {epoch}, R@1: {r1:.1f}" 201 | msg += f", R@5: {r5:.1f}, R@10 {r10:.1f}, R@50 {r50:.1f}" 202 | msg += f"MedR: {metrics['MedR']:g}, MeanR: {metrics['MeanR']:.1f}" 203 | print(msg) 204 | 205 | 206 | def format_nested_metrics_for_writer(metrics, mode, name="TEST"): 207 | res = {} 208 | for key, val in metrics.items(): 209 | log_name = f"[{mode}]{name}_{key}" 210 | res[log_name] = val 211 | return res 212 | -------------------------------------------------------------------------------- /base/base_trainer.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from numpy import inf 4 | from logger import TensorboardWriter 5 | 6 | 7 | class BaseTrainer: 8 | """ 9 | Base class for all trainers 10 | """ 11 | def __init__(self, model, loss, metrics, optimizer, config, writer=None, init_val=True): 12 | self.config = config 13 | self.logger = config.get_logger('trainer', config['trainer']['verbosity']) 14 | self.init_val = init_val 15 | # setup GPU device if available, move model into configured device 16 | self.device, device_ids = self._prepare_device(config['n_gpu']) 17 | self.model = model.to(self.device) 18 | self.model.device = self.device 19 | if len(device_ids) > 1: 20 | self.model = torch.nn.DataParallel(model, device_ids=device_ids) 21 | 22 | self.loss = loss 23 | self.metrics = metrics 24 | self.optimizer = optimizer 25 | 26 | cfg_trainer = config['trainer'] 27 | self.epochs = cfg_trainer['epochs'] 28 | self.save_period = cfg_trainer['save_period'] 29 | self.monitor = cfg_trainer.get('monitor', 'off') 30 | 31 | # configuration to monitor model performance and save best 32 | if self.monitor == 'off': 33 | self.mnt_mode = 'off' 34 | self.mnt_best = 0 35 | else: 36 | self.mnt_mode, self.mnt_metric = self.monitor.split() 37 | assert self.mnt_mode in ['min', 'max'] 38 | 39 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf 40 | self.early_stop = cfg_trainer.get('early_stop', inf) 41 | 42 | self.start_epoch = 1 43 | 44 | self.checkpoint_dir = config.save_dir 45 | 46 | # setup visualization writer instance 47 | #self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard']) 48 | self.writer = writer 49 | 50 | if config.resume is not None: 51 | self._resume_checkpoint(config.resume) 52 | 53 | @abstractmethod 54 | def _train_epoch(self, epoch): 55 | """ 56 | Training logic for an epoch 57 | 58 | :param epoch: Current epoch number 59 | """ 60 | raise NotImplementedError 61 | 62 | @abstractmethod 63 | def _valid_epoch(self, epoch): 64 | """ 65 | Training logic for an epoch 66 | 67 | :param epoch: Current epoch number 68 | """ 69 | raise NotImplementedError 70 | 71 | 72 | def train(self): 73 | """ 74 | Full training logic 75 | """ 76 | not_improved_count = 0 77 | if self.init_val: 78 | _ = self._valid_epoch(-1) 79 | 80 | for epoch in range(self.start_epoch, self.epochs + 1): 81 | result = self._train_epoch(epoch) 82 | 83 | # save logged informations into log dict 84 | 85 | # save logged informations into log dict 86 | log = {'epoch': epoch} 87 | for key, value in result.items(): 88 | if key == 'metrics': 89 | log.update({mtr.__name__: value[i] 90 | for i, mtr in enumerate(self.metrics)}) 91 | elif key == 'val_metrics': 92 | log.update({'val_' + mtr.__name__: value[i] 93 | for i, mtr in enumerate(self.metrics)}) 94 | elif key == 'nested_val_metrics': 95 | # NOTE: currently only supports two layers of nesting 96 | for subkey, subval in value.items(): 97 | for subsubkey, subsubval in subval.items(): 98 | log[f"val_{subkey}_{subsubkey}"] = subsubval 99 | else: 100 | log[key] = value 101 | 102 | # print logged informations to the screen 103 | for key, value in log.items(): 104 | self.logger.info(' {:15s}: {}'.format(str(key), value)) 105 | 106 | # evaluate model performance according to configured metric, save best checkpoint as model_best 107 | best = False 108 | if self.mnt_mode != 'off': 109 | try: 110 | # check whether model performance improved or not, according to specified 
metric(mnt_metric) 111 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ 112 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) 113 | except KeyError: 114 | self.logger.warning("Warning: Metric '{}' is not found. " 115 | "Model performance monitoring is disabled.".format(self.mnt_metric)) 116 | self.mnt_mode = 'off' 117 | improved = False 118 | 119 | if improved: 120 | self.mnt_best = log[self.mnt_metric] 121 | not_improved_count = 0 122 | best = True 123 | else: 124 | not_improved_count += 1 125 | 126 | if not_improved_count > self.early_stop: 127 | self.logger.info("Validation performance didn\'t improve for {} epochs. " 128 | "Training stops.".format(self.early_stop)) 129 | break 130 | 131 | #if epoch % self.save_period == 0 or best: 132 | if best: 133 | self._save_checkpoint(epoch, save_best=best) 134 | 135 | def _prepare_device(self, n_gpu_use): 136 | """ 137 | setup GPU device if available, move model into configured device 138 | """ 139 | n_gpu = torch.cuda.device_count() 140 | if n_gpu_use > 0 and n_gpu == 0: 141 | self.logger.warning("Warning: There\'s no GPU available on this machine," 142 | "training will be performed on CPU.") 143 | n_gpu_use = 0 144 | if n_gpu_use > n_gpu: 145 | self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available " 146 | "on this machine.".format(n_gpu_use, n_gpu)) 147 | n_gpu_use = n_gpu 148 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') 149 | list_ids = list(range(n_gpu_use)) 150 | return device, list_ids 151 | 152 | def _save_checkpoint(self, epoch, save_best=False): 153 | """ 154 | Saving checkpoints 155 | 156 | :param epoch: current epoch number 157 | :param log: logging information of the epoch 158 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 159 | """ 160 | arch = type(self.model).__name__ 161 | state = { 162 | 'arch': arch, 163 | 'epoch': epoch, 164 | 'state_dict': self.model.state_dict(), 165 | 'optimizer': self.optimizer.state_dict(), 166 | 'monitor_best': self.mnt_best, 167 | 'config': self.config 168 | } 169 | filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) 170 | torch.save(state, filename) 171 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 172 | if save_best: 173 | best_path = str(self.checkpoint_dir / 'model_best.pth') 174 | torch.save(state, best_path) 175 | self.logger.info("Saving current best: model_best.pth ...") 176 | 177 | def _resume_checkpoint(self, resume_path): 178 | """ 179 | Resume from saved checkpoints 180 | 181 | :param resume_path: Checkpoint path to be resumed 182 | """ 183 | resume_path = str(resume_path) 184 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 185 | checkpoint = torch.load(resume_path) 186 | self.start_epoch = checkpoint['epoch'] + 1 187 | self.mnt_best = checkpoint['monitor_best'] 188 | 189 | # load architecture params from checkpoint. 190 | if checkpoint['config']['arch'] != self.config['arch']: 191 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of " 192 | "checkpoint. 
This may yield an exception while state_dict is being loaded.") 193 | 194 | state_dict = checkpoint['state_dict'] 195 | 196 | load_state_dict_keys = list(state_dict.keys()) 197 | curr_state_dict_keys = list(self.model.state_dict().keys()) 198 | redo_dp = False 199 | if not curr_state_dict_keys[0].startswith('module.') and load_state_dict_keys[0].startswith('module.'): 200 | undo_dp = True 201 | elif curr_state_dict_keys[0].startswith('module.') and not load_state_dict_keys[0].startswith('module.'): 202 | redo_dp = True 203 | undo_dp = False 204 | else: 205 | undo_dp = False 206 | 207 | if undo_dp: 208 | from collections import OrderedDict 209 | new_state_dict = OrderedDict() 210 | for k, v in state_dict.items(): 211 | name = k[7:] # remove `module.` 212 | new_state_dict[name] = v 213 | # load params 214 | elif redo_dp: 215 | from collections import OrderedDict 216 | new_state_dict = OrderedDict() 217 | for k, v in state_dict.items(): 218 | name = 'module.' + k # remove `module.` 219 | new_state_dict[name] = v 220 | else: 221 | new_state_dict = state_dict 222 | 223 | self.model.load_state_dict(new_state_dict) 224 | 225 | # load optimizer state from checkpoint only when optimizer type is not changed. 226 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: 227 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " 228 | "Optimizer parameters not being resumed.") 229 | else: 230 | self.optimizer.load_state_dict(checkpoint['optimizer']) 231 | 232 | self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) 233 | -------------------------------------------------------------------------------- /model/metric.py: -------------------------------------------------------------------------------- 1 | 2 | """Module for computing performance metrics 3 | 4 | """ 5 | import math 6 | import numbers 7 | from pathlib import Path 8 | import ipdb 9 | import numpy as np 10 | import torch 11 | import scipy.stats 12 | from sklearn.metrics import average_precision_score 13 | import ipdb 14 | import pdb 15 | 16 | def t2v_metrics(sims, query_masks=None): 17 | """Compute retrieval metrics from a similiarity matrix. 
18 | 19 | Args: 20 | sims (th.Tensor): N x M matrix of similarities between embeddings, where 21 | x_{i,j} = 22 | query_masks (th.Tensor): mask any missing queries from the dataset (two videos 23 | in MSRVTT only have 19, rather than 20 captions) 24 | 25 | Returns: 26 | (dict[str:float]): retrieval metrics 27 | """ 28 | assert sims.ndim == 2, "expected a matrix" 29 | num_queries, num_vids = sims.shape 30 | dists = -sims 31 | sorted_dists = np.sort(dists, axis=1) 32 | 33 | # The indices are computed such that they slice out the ground truth distances 34 | # from the psuedo-rectangular dist matrix 35 | queries_per_video = num_queries // num_vids 36 | gt_idx = [[np.ravel_multi_index([ii, jj], (num_queries, num_vids)) 37 | for ii in range(jj * queries_per_video, (jj + 1) * queries_per_video)] 38 | for jj in range(num_vids)] 39 | gt_idx = np.array(gt_idx) 40 | gt_dists = dists.reshape(-1)[gt_idx.reshape(-1)] 41 | gt_dists = gt_dists[:, np.newaxis] 42 | rows, cols = np.where((sorted_dists - gt_dists) == 0) # find column position of GT 43 | 44 | # -------------------------------- 45 | # NOTE: Breaking ties 46 | # -------------------------------- 47 | # We sometimes need to break ties (in general, these should occur extremely rarely, 48 | # but there are pathological cases when they can distort the scores, such as when 49 | # the similarity matrix is all zeros). Previous implementations (e.g. the t2i 50 | # evaluation function used 51 | # here: https://github.com/niluthpol/multimodal_vtt/blob/master/evaluation.py and 52 | # here: https://github.com/linxd5/VSE_Pytorch/blob/master/evaluation.py#L87) generally 53 | # break ties "optimistically". However, if the similarity matrix is constant this 54 | # can evaluate to a perfect ranking. A principled option is to average over all 55 | # possible partial orderings implied by the ties. See # this paper for a discussion: 56 | # McSherry, Frank, and Marc Najork, 57 | # "Computing information retrieval performance measures efficiently in the presence 58 | # of tied scores." European conference on information retrieval. Springer, Berlin, 59 | # Heidelberg, 2008. 
60 | # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.145.8892&rep=rep1&type=pdf 61 | 62 | break_ties = "optimistically" 63 | #break_ties = "averaging" 64 | 65 | if rows.size > num_queries: 66 | assert np.unique(rows).size == num_queries, "issue in metric evaluation" 67 | if break_ties == "optimistically": 68 | _, idx = np.unique(rows, return_index=True) 69 | cols = cols[idx] 70 | elif break_ties == "averaging": 71 | # fast implementation, based on this code: 72 | # https://stackoverflow.com/a/49239335 73 | locs = np.argwhere((sorted_dists - gt_dists) == 0) 74 | 75 | # Find the split indices 76 | steps = np.diff(locs[:, 0]) 77 | splits = np.nonzero(steps)[0] + 1 78 | splits = np.insert(splits, 0, 0) 79 | 80 | # Compute the result columns 81 | summed_cols = np.add.reduceat(locs[:, 1], splits) 82 | counts = np.diff(np.append(splits, locs.shape[0])) 83 | avg_cols = summed_cols / counts 84 | if False: 85 | print("Running slower code to verify rank averaging across ties") 86 | # slow, but more interpretable version, used for testing 87 | avg_cols_slow = [np.mean(cols[rows == idx]) for idx in range(num_queries)] 88 | assert np.array_equal(avg_cols, avg_cols_slow), "slow vs fast difference" 89 | print("passed num check") 90 | cols = avg_cols 91 | 92 | msg = "expected ranks to match queries ({} vs {}) " 93 | if cols.size != num_queries: 94 | import ipdb; 95 | ipdb.set_trace() 96 | assert cols.size == num_queries, msg 97 | 98 | if False: 99 | # overload mask to check that we can recover the scores for single-query 100 | # retrieval 101 | print("DEBUGGING MODE") 102 | query_masks = np.zeros_like(query_masks) 103 | query_masks[:, 0] = 1 # recover single query score 104 | 105 | if query_masks is not None: 106 | # remove invalid queries 107 | assert query_masks.size == num_queries, "invalid query mask shape" 108 | cols = cols[query_masks.reshape(-1).astype(np.bool)] 109 | assert cols.size == query_masks.sum(), "masking was not applied correctly" 110 | # update number of queries to account for those that were missing 111 | num_queries = query_masks.sum() 112 | 113 | if False: 114 | # sanity check against old logic for square matrices 115 | gt_dists_old = np.diag(dists) 116 | gt_dists_old = gt_dists_old[:, np.newaxis] 117 | _, cols_old = np.where((sorted_dists - gt_dists_old) == 0) 118 | assert np.array_equal(cols_old, cols), "new metric doesn't match" 119 | 120 | return cols2metrics(cols, num_queries) 121 | 122 | 123 | def v2t_metrics(sims, query_masks=None): 124 | """Compute retrieval metrics from a similiarity matrix. 
125 | 126 | Args: 127 | sims (th.Tensor): N x M matrix of similarities between embeddings, where 128 | x_{i,j} = 129 | query_masks (th.Tensor): mask any missing captions from the dataset 130 | 131 | Returns: 132 | (dict[str:float]): retrieval metrics 133 | 134 | NOTES: We find the closest "GT caption" in the style of VSE, which corresponds 135 | to finding the rank of the closest relevant caption in embedding space: 136 | github.com/ryankiros/visual-semantic-embedding/blob/master/evaluation.py#L52-L56 137 | """ 138 | # switch axes of text and video 139 | sims = sims.T 140 | 141 | if False: 142 | # experiment with toy example 143 | sims = np.ones((3, 3)) 144 | sims[0, 0] = 2 145 | sims[1, 1:2] = 2 146 | sims[2, :] = 2 147 | query_masks = None 148 | 149 | assert sims.ndim == 2, "expected a matrix" 150 | num_queries, num_caps = sims.shape 151 | dists = -sims 152 | caps_per_video = num_caps // num_queries 153 | break_ties = "averaging" 154 | 155 | MISSING_VAL = 1E8 156 | query_ranks = [] 157 | for ii in range(num_queries): 158 | row_dists = dists[ii, :] 159 | if query_masks is not None: 160 | # Set missing queries to have a distance of infinity. A missing query 161 | # refers to a query position `n` for a video that had less than `n` 162 | # captions (for example, a few MSRVTT videos only have 19 queries) 163 | row_dists[np.logical_not(query_masks.reshape(-1))] = MISSING_VAL 164 | 165 | # NOTE: Using distance subtraction to perform the ranking is easier to make 166 | # deterministic than using argsort, which suffers from the issue of defining 167 | # "stability" for equal distances. Example of distance subtraction code: 168 | # github.com/antoine77340/Mixture-of-Embedding-Experts/blob/master/train.py 169 | sorted_dists = np.sort(row_dists) 170 | 171 | min_rank = np.inf 172 | for jj in range(ii * caps_per_video, (ii + 1) * caps_per_video): 173 | if row_dists[jj] == MISSING_VAL: 174 | # skip rankings of missing captions 175 | continue 176 | ranks = np.where((sorted_dists - row_dists[jj]) == 0)[0] 177 | if break_ties == "optimistically": 178 | rank = ranks[0] 179 | elif break_ties == "averaging": 180 | # NOTE: If there is more than one caption per video, its possible for the 181 | # method to do "worse than chance" in the degenerate case when all 182 | # similarities are tied. TODO(Samuel): Address this case. 
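                # Worked example (illustrative only): if the GT caption ties with three other
                # captions so that sorted positions 2, 3, 4 and 5 all share its distance,
                # "averaging" assigns rank (2 + 3 + 4 + 5) / 4 = 3.5, whereas the
                # "optimistically" branch above would assign rank 2.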
183 | rank = ranks.mean() 184 | if rank < min_rank: 185 | min_rank = rank 186 | query_ranks.append(min_rank) 187 | query_ranks = np.array(query_ranks) 188 | 189 | # sanity check against old version of code 190 | if False: 191 | sorted_dists = np.sort(dists, axis=1) 192 | gt_dists_old = np.diag(dists) 193 | gt_dists_old = gt_dists_old[:, np.newaxis] 194 | rows_old, cols_old = np.where((sorted_dists - gt_dists_old) == 0) 195 | if rows_old.size > num_queries: 196 | _, idx = np.unique(rows_old, return_index=True) 197 | cols_old = cols_old[idx] 198 | num_diffs = (1 - (cols_old == query_ranks)).sum() 199 | msg = f"new metric doesn't match in {num_diffs} places" 200 | assert np.array_equal(cols_old, query_ranks), msg 201 | 202 | # visualise the distance matrix 203 | import sys 204 | import matplotlib 205 | matplotlib.use("Agg") 206 | import matplotlib.pyplot as plt 207 | sys.path.insert(0, str(Path.home() / "coding/src/zsvision/python")) 208 | from zsvision.zs_iterm import zs_dispFig # NOQA 209 | plt.matshow(dists) 210 | zs_dispFig() 211 | 212 | return cols2metrics(query_ranks, num_queries) 213 | 214 | 215 | def retrieval_as_classification(sims, query_masks=None): 216 | """Compute classification metrics from a similiarity matrix. 217 | """ 218 | assert sims.ndim == 2, "expected a matrix" 219 | 220 | # switch axes of query-labels and video 221 | sims = sims.T 222 | query_masks = query_masks.T 223 | dists = -sims 224 | num_queries, num_labels = sims.shape 225 | break_ties = "averaging" 226 | 227 | query_ranks = [] 228 | for ii in range(num_queries): 229 | row_dists = dists[ii, :] 230 | 231 | # NOTE: Using distance subtraction to perform the ranking is easier to make 232 | # deterministic than using argsort, which suffers from the issue of defining 233 | # "stability" for equal distances. Example of distance subtraction code: 234 | # github.com/antoine77340/Mixture-of-Embedding-Experts/blob/master/train.py 235 | sorted_dists = np.sort(row_dists) 236 | 237 | # min_rank = np.inf 238 | label_ranks = [] 239 | for gt_label in np.where(query_masks[ii, :])[0]: 240 | ranks = np.where((sorted_dists - row_dists[gt_label]) == 0)[0] 241 | if break_ties == "optimistically": 242 | rank = ranks[0] 243 | elif break_ties == "averaging": 244 | # NOTE: If there is more than one caption per video, its possible for the 245 | # method to do "worse than chance" in the degenerate case when all 246 | # similarities are tied. TODO(Samuel): Address this case. 247 | rank = ranks.mean() 248 | else: 249 | raise ValueError(f"unknown tie-breaking method: {break_ties}") 250 | label_ranks.append(rank) 251 | # Avoid penalising for assigning higher similarity to other gt labels. This is 252 | # done by subtracting out the better ranked query labels. Note that this step 253 | # introduces a slight skew in favour of videos with lots of labels. We can 254 | # address this later with a normalisation step if needed. 
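        # Worked example (illustrative only): for a query whose GT labels sit at ranks
        # [0, 2, 5] (ascending), the subtraction below gives [0 - 0, 2 - 1, 5 - 2] = [0, 1, 3],
        # so a label is not penalised for the other GT labels already ranked above it.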
255 | label_ranks = [x - idx for idx, x in enumerate(label_ranks)] 256 | 257 | # Include all labels in the final calculation 258 | query_ranks.extend(label_ranks) 259 | query_ranks = np.array(query_ranks) 260 | 261 | # sanity check against old version of code 262 | if False: 263 | # visualise the distance matrix 264 | import sys 265 | import matplotlib 266 | matplotlib.use("Agg") 267 | import matplotlib.pyplot as plt 268 | sys.path.insert(0, str(Path.home() / "coding/src/zsvision/python")) 269 | from zsvision.zs_iterm import zs_dispFig # NOQA 270 | # plt.matshow(dists) 271 | # zs_dispFig() 272 | plt.hist(query_ranks, bins=313, alpha=0.5) 273 | plt.grid() 274 | zs_dispFig() 275 | import ipdb; 276 | ipdb.set_trace() 277 | 278 | return cols2metrics(query_ranks, num_queries=len(query_ranks)) 279 | 280 | 281 | def cols2metrics(cols, num_queries): 282 | metrics = {} 283 | metrics["R1"] = 100 * float(np.sum(cols == 0)) / num_queries 284 | metrics["R5"] = 100 * float(np.sum(cols < 5)) / num_queries 285 | metrics["R10"] = 100 * float(np.sum(cols < 10)) / num_queries 286 | metrics["R50"] = 100 * float(np.sum(cols < 50)) / num_queries 287 | metrics["MedR"] = np.median(cols) + 1 288 | metrics["MeanR"] = np.mean(cols) + 1 289 | stats = [metrics[x] for x in ("R1", "R5", "R10")] 290 | metrics["geometric_mean_R1-R5-R10"] = scipy.stats.mstats.gmean(stats) 291 | return metrics 292 | 293 | 294 | def mean_average_precision(sims, query_masks=None): 295 | ap_meter = APMeter() 296 | ap_meter.add(output=sims.T, target=query_masks.T) 297 | return {"mAP": ap_meter.value().mean()} 298 | 299 | def acc(output, target): 300 | with torch.no_grad(): 301 | pred = torch.argmax(output, dim=1) 302 | assert pred.shape[0] == len(target) 303 | correct = 0 304 | correct += torch.sum(pred == target).item() 305 | return correct / len(target) 306 | 307 | 308 | def my_metric2(output, target, k=3): 309 | with torch.no_grad(): 310 | pred = torch.topk(output, k, dim=1)[1] 311 | assert pred.shape[0] == len(target) 312 | correct = 0 313 | for i in range(k): 314 | correct += torch.sum(pred[:, i] == target).item() 315 | return correct / len(target) 316 | 317 | 318 | def video_precision(output, target): 319 | """ percentage of videos which have been aligned to a matching text pair""" 320 | assert output.shape[0] == target.shape[0] 321 | assert output.shape[2] == target.shape[2] == 2 322 | 323 | correct = 0 324 | for bout, btarg in zip(output, target): 325 | for pair in bout: 326 | eq = torch.eq(pair, btarg) 327 | if torch.logical_and(eq[:, 0], eq[:, 1]).any(): 328 | correct += 1 329 | return correct / (target.shape[0] * target.shape[1]) 330 | 331 | def video_precision_adj(output, target): 332 | """ adjusts the video precision metric by ignoring videos which have no aligning text.""" 333 | assert output.shape[0] == target.shape[0] 334 | assert output.shape[2] == target.shape[2] == 2 335 | 336 | assert output.shape[0] == target.shape[0] 337 | assert output.shape[2] == target.shape[2] == 2 338 | 339 | correct = 0 340 | for bout, btarg in zip(output, target): 341 | for pair in bout: 342 | eq = torch.eq(pair, btarg) 343 | if torch.logical_and(eq[:, 0], eq[:, 1]).any(): 344 | correct += 1 345 | denom = len(target[:, :, 0].unique()) 346 | 347 | return correct / denom 348 | 349 | def video_precision_adj(output, target): 350 | """ adjusts the video precision metric by ignoring videos which have no aligning text.""" 351 | assert output.shape[0] == target.shape[0] 352 | assert output.shape[2] == target.shape[2] == 2 353 | 354 | assert output.shape[0] 
== target.shape[0] 355 | assert output.shape[2] == target.shape[2] == 2 356 | 357 | correct = 0 358 | for bout, btarg in zip(output, target): 359 | for pair in bout: 360 | eq = torch.eq(pair, btarg) 361 | if torch.logical_and(eq[:, 0], eq[:, 1]).any(): 362 | correct += 1 363 | denom = len(target[:, :, 0].unique()) 364 | 365 | return correct / denom --------------------------------------------------------------------------------
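Usage sketch (assumptions: a trained checkpoint saved as model_best.pth under an experiment
directory in exps/, and the provided baseline config; the flags are the ones defined in
test.py above, and <experiment> is a placeholder for your run name):

    python test.py --config configs/baseline.json --resume exps/<experiment>/model_best.pth --split val

With --split val this prints the configured retrieval metrics and saves sim_matrix_val.npy
(plus the text and video embeddings) next to the checkpoint; with --split test it instead
writes submission.zip for upload to the Codalab competition page linked in test.py.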