├── .DS_Store ├── README.md ├── args.py ├── base ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── base_data_loader.cpython-37.pyc │ ├── base_dataset.cpython-37.pyc │ ├── base_model.cpython-37.pyc │ └── base_trainer.cpython-37.pyc ├── base_data_loader.py ├── base_dataset.py ├── base_model.py └── base_trainer.py ├── configs ├── ft │ ├── DiDeMo_8f.json │ ├── HMDB_16f.json │ ├── HMDB_1f.json │ ├── HMDB_3f.json │ ├── HMDB_4f.json │ ├── HMDB_8f.json │ ├── LSMDC_8f.json │ ├── LSMDC_MC_8f.json │ ├── MSRVTT_8f.json │ ├── MSRVTT_8f_clip.json │ ├── MSRVTT_MC_8f.json │ ├── MSRVTT_QA_8f.json │ ├── MSVD_8f.json │ ├── MSVD_QA_8f.json │ └── UCF_8f.json └── pt │ ├── CC3M-WebVid2M.json │ ├── WebVid2M.json │ ├── WebVid2M_4f_ME.json │ ├── WebVid2M_clip.json │ ├── WebVid2M_clip_RL.json │ └── WebVid2M_raw.json ├── data_loader ├── ConceptualCaptions_dataset.py ├── DiDeMo_dataset.py ├── HMDB_dataset.py ├── LSMDC_dataset.py ├── LSMDC_dataset_old.py ├── MSRVTT_dataset.py ├── MSVD_dataset.py ├── UCF_dataset.py ├── WebVid_dataset.py ├── __pycache__ │ ├── ConceptualCaptions_dataset.cpython-37.pyc │ ├── DiDeMo_dataset.cpython-37.pyc │ ├── HMDB_dataset.cpython-37.pyc │ ├── LSMDC_dataset.cpython-37.pyc │ ├── MSRVTT_dataset.cpython-37.pyc │ ├── MSVD_dataset.cpython-37.pyc │ ├── UCF_dataset.cpython-37.pyc │ ├── WebVid_dataset.cpython-37.pyc │ ├── data_loader.cpython-37.pyc │ └── transforms.cpython-37.pyc ├── data_loader.py └── transforms.py ├── fine-tuning.sh ├── logger ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── logger.cpython-37.pyc │ └── visualization.cpython-37.pyc ├── logger.py ├── logger_config.json └── visualization.py ├── model ├── RegionLearner │ ├── Quantizer.py │ ├── RegionLearner.py │ └── __pycache__ │ │ ├── Quantizer.cpython-37.pyc │ │ └── RegionLearner.cpython-37.pyc ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── helper.cpython-37.pyc │ ├── loss.cpython-37.pyc │ ├── metric.cpython-37.pyc │ ├── model.cpython-37.pyc │ ├── qa_model.cpython-37.pyc │ └── video_transformer.cpython-37.pyc ├── clip │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── clip.cpython-37.pyc │ │ ├── model.cpython-37.pyc │ │ ├── simple_tokenizer.cpython-37.pyc │ │ └── tokenization_clip.cpython-37.pyc │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ ├── simple_tokenizer.py │ └── tokenization_clip.py ├── helper.py ├── loss.py ├── metric.py ├── model.py ├── qa_model.py └── video_transformer.py ├── parse_config.py ├── pre-training.sh ├── setup_myEnv.sh ├── train.py ├── trainer ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── trainer.cpython-37.pyc └── trainer.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-37.pyc ├── html.cpython-37.pyc ├── util.cpython-37.pyc └── visualizer.cpython-37.pyc ├── custom_transforms.py ├── html.py ├── util.py ├── video.py ├── visualisation.py └── visualizer.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Region_Learner 2 | The Pytorch implementation for "Video-Text Pre-training with Learned Regions" 3 | ([arxiv](https://arxiv.org/pdf/2112.01194.pdf)) 4 | 5 | ***We are still cleaning up the code further and preparing for pre-training weights.*** 6 
| 7 | ## Preparation 8 | Overall, this code is built on PyTorch with DistributedDataParallel (DDP). 9 | - Create a conda env and install the required packages via `sh setup_myEnv.sh` 10 | - Create some important folders 11 | 1. `mkdir data` (you can symlink huge datasets to this folder) 12 | 2. `mkdir meta_data` (put the metadata of each dataset here) 13 | 3. `mkdir results` 14 | - Download pre-training data 15 | 1. Download WebVid-2M (see https://github.com/m-bain/webvid) 16 | 2. Download CC3M (see https://ai.google.com/research/ConceptualCaptions/download) 17 | 18 | PS: Not all videos are still available, so you may need to modify the metadata depending on your case. We also provide our metadata [here](https://drive.google.com/drive/folders/1y9Byj2IFWSyeGiyJJwc2VPIESzakGHAh?usp=sharing). 19 | 20 | 21 | ## Pre-training 22 | - Run `sh pre-training.sh` (Commands with different settings are listed in this script.) 23 | 24 | ## Fine-tuning (on MSR-VTT) 25 | - Download data (see https://github.com/m-bain/frozen-in-time#-finetuning-benchmarks-msr-vtt) 26 | - Run `sh fine-tuning.sh`. 27 | 28 | ## Pre-trained Weights 29 | [WebVid2M + CC3M](https://drive.google.com/file/d/1ql5PDgaTqA9pQcBb1cYRGkH3IfbbUkSv/view?usp=sharing) 30 | 31 | ## Acknowledgements 32 | This code is based on [Frozen in Time](https://github.com/m-bain/frozen-in-time "Frozen in Time"). 33 | 34 | 35 | 36 | 37 | 38 | ## Citation 39 | ``` 40 | @article{yan2021video, 41 | title={Video-Text Pre-training with Learned Regions}, 42 | author={Yan, Rui and Shou, Mike Zheng and Ge, Yixiao and Wang, Alex Jinpeng and Lin, Xudong and Cai, Guanyu and Tang, Jinhui}, 43 | journal={arXiv preprint arXiv:2112.01194}, 44 | year={2021} 45 | } 46 | ``` -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(description='MILNCE'): 4 | parser = argparse.ArgumentParser(description=description) 5 | parser.add_argument( 6 | '--train_csv', 7 | type=str, 8 | default='csv/hmdb51.csv', 9 | help='train csv') 10 | parser.add_argument( 11 | '--video_path', 12 | type=str, 13 | default='', 14 | help='video_path') 15 | parser.add_argument( 16 | '--caption_root', 17 | type=str, 18 | default='', 19 | help='caption root') 20 | parser.add_argument( 21 | '--checkpoint_root', 22 | type=str, 23 | default='checkpoint', 24 | help='checkpoint dir root') 25 | parser.add_argument( 26 | '--log_root', 27 | type=str, 28 | default='log', 29 | help='log dir root') 30 | parser.add_argument( 31 | '--eval_video_root', 32 | type=str, 33 | default='', 34 | help='root folder of videos used for evaluation') 35 | parser.add_argument( 36 | '--checkpoint_dir', 37 | type=str, 38 | default='', 39 | help='checkpoint model folder') 40 | parser.add_argument( 41 | '--optimizer', type=str, default='adam', help='optimization algorithm') 42 | parser.add_argument('--weight_init', type=str, default='uniform', 43 | help='CNN weight initialization') 44 | parser.add_argument('--num_thread_reader', type=int, default=20, 45 | help='number of reader threads') 46 | parser.add_argument('--num_class', type=int, default=512, 47 | help='output embedding dimension') 48 | parser.add_argument('--num_candidates', type=int, default=1, 49 | help='num candidates for MILNCE loss') 50 | parser.add_argument('--batch_size', type=int, default=256, 51 | help='batch size') 52 | parser.add_argument('--num_windows_test', type=int, default=4, 53 | help='number of testing windows') 54 | parser.add_argument('--batch_size_val', type=int,
default=32, 55 | help='batch size eval') 56 | parser.add_argument('--momemtum', type=float, default=0.9, 57 | help='SGD momentum') 58 | parser.add_argument('--n_display', type=int, default=10, 59 | help='Information display frequency') 60 | parser.add_argument('--num_frames', type=int, default=16, 61 | help='number of frames sampled per clip') 62 | parser.add_argument('--video_size', type=int, default=224, 63 | help='input frame size') 64 | parser.add_argument('--crop_only', type=int, default=1, 65 | help='crop without resizing if set to 1') 66 | parser.add_argument('--centercrop', type=int, default=0, 67 | help='use center crop instead of random crop') 68 | parser.add_argument('--random_flip', type=int, default=1, 69 | help='use random horizontal flip') 70 | parser.add_argument('--verbose', type=int, default=1, 71 | help='verbose output') 72 | parser.add_argument('--warmup_steps', type=int, default=5000, 73 | help='number of learning-rate warmup steps') 74 | parser.add_argument('--min_time', type=float, default=5.0, 75 | help='minimum clip duration in seconds') 76 | parser.add_argument( 77 | '--pretrain_cnn_path', 78 | type=str, 79 | default='', 80 | help='path to pre-trained CNN weights') 81 | parser.add_argument( 82 | '--word2vec_path', type=str, default='data/word2vec.pth', help='path to word2vec weights') 83 | parser.add_argument('--fps', type=int, default=5, help='video decoding frame rate') 84 | parser.add_argument('--cudnn_benchmark', type=int, default=0, 85 | help='enable cudnn benchmark') 86 | parser.add_argument('--epochs', default=150, type=int, metavar='N', 87 | help='number of total epochs to run') 88 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 89 | help='manual epoch number (useful on restarts)') 90 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 91 | metavar='LR', help='initial learning rate', dest='lr') 92 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 93 | help='momentum') 94 | parser.add_argument('--resume', dest='resume', action='store_true', 95 | help='resume training from last checkpoint') 96 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 97 | help='evaluate model on validation set') 98 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 99 | help='use pre-trained model') 100 | parser.add_argument('--pin_memory', dest='pin_memory', action='store_true', 101 | help='use pin_memory') 102 | parser.add_argument('--world-size', default=-1, type=int, 103 | help='number of nodes for distributed training') 104 | parser.add_argument('--rank', default=-1, type=int, 105 | help='node rank for distributed training') 106 | parser.add_argument('--dist-file', default='dist-file', type=str, 107 | help='file used to set up distributed training') 108 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 109 | help='url used to set up distributed training') 110 | parser.add_argument('--dist-backend', default='nccl', type=str, 111 | help='distributed backend') 112 | parser.add_argument('--seed', default=1, type=int, 113 | help='seed for initializing training.') 114 | parser.add_argument('--gpu', default=None, type=int, 115 | help='GPU id to use.') 116 | parser.add_argument('--multiprocessing-distributed', action='store_true', 117 | help='Use multi-processing distributed training to launch ' 118 | 'N processes per node, which has N GPUs.
This is the ' 119 | 'fastest way to use PyTorch for either single node or ' 120 | 'multi node data parallel training') 121 | args = parser.parse_args() 122 | return args 123 | -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_data_loader import * 2 | from .base_dataset import * 3 | from .base_model import * 4 | from .base_trainer import * -------------------------------------------------------------------------------- /base/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/base/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /base/__pycache__/base_data_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/base/__pycache__/base_data_loader.cpython-37.pyc -------------------------------------------------------------------------------- /base/__pycache__/base_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/base/__pycache__/base_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /base/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/base/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /base/__pycache__/base_trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/base/__pycache__/base_trainer.cpython-37.pyc -------------------------------------------------------------------------------- /base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | from torch.utils.data.distributed import DistributedSampler 6 | 7 | class BaseDataLoader(DataLoader): 8 | """ 9 | Base class for all data loaders 10 | """ 11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 12 | self.validation_split = validation_split 13 | self.shuffle = shuffle 14 | 15 | self.batch_idx = 0 16 | self.n_samples = len(dataset) 17 | 18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 19 | 20 | self.init_kwargs = { 21 | 'dataset': dataset, 22 | 'batch_size': batch_size, 23 | 'shuffle': self.shuffle, 24 | 'collate_fn': collate_fn, 25 | 'num_workers': num_workers 26 | } 27 | super().__init__(sampler=self.sampler, **self.init_kwargs) 28 | 29 | def _split_sampler(self, split): 30 | if split == 0.0: 31 | return None, None 32 | 33 | idx_full = np.arange(self.n_samples) 34 | 35 | np.random.seed(0) 36 | 
np.random.shuffle(idx_full) 37 | 38 | if isinstance(split, int): 39 | assert split > 0 40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 41 | len_valid = split 42 | else: 43 | len_valid = int(self.n_samples * split) 44 | 45 | valid_idx = idx_full[0:len_valid] 46 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 47 | 48 | train_sampler = SubsetRandomSampler(train_idx) 49 | valid_sampler = SubsetRandomSampler(valid_idx) 50 | # turn off shuffle option which is mutually exclusive with sampler 51 | self.shuffle = False 52 | self.n_samples = len(train_idx) 53 | 54 | return train_sampler, valid_sampler 55 | 56 | def split_validation(self, diff_kwargs=None): 57 | init_kwargs = self.init_kwargs 58 | if diff_kwargs is not None: 59 | init_kwargs.update(diff_kwargs) 60 | if self.valid_sampler is None: 61 | return None 62 | else: 63 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 64 | 65 | def num_samples(self): 66 | return len(self.sampler) 67 | 68 | 69 | class BaseDataLoaderExplicitSplit(DataLoader): 70 | """ 71 | Base class for all data loaders 72 | """ 73 | def __init__(self, dataset, batch_size, shuffle, num_workers, collate_fn=default_collate): 74 | self.shuffle = shuffle 75 | 76 | self.batch_idx = 0 77 | self.n_samples = len(dataset) 78 | 79 | self.init_kwargs = { 80 | 'dataset': dataset, 81 | 'batch_size': batch_size, 82 | 'shuffle': self.shuffle, 83 | 'collate_fn': collate_fn, 84 | 'num_workers': num_workers, 85 | 'pin_memory': True 86 | } 87 | super().__init__(**self.init_kwargs) 88 | 89 | class DistBaseDataLoaderExplicitSplit(DataLoader): 90 | """ 91 | Base class for all data loaders 92 | """ 93 | def __init__(self, dataset, batch_size, shuffle, num_workers, collate_fn=default_collate): 94 | self.shuffle = shuffle 95 | 96 | self.batch_idx = 0 97 | self.n_samples = len(dataset) 98 | self.train_sampler = DistributedSampler(dataset) 99 | self.init_kwargs = { 100 | 'dataset': dataset, 101 | 'batch_size': batch_size, 102 | 'shuffle': False, 103 | 'collate_fn': collate_fn, 104 | 'num_workers': num_workers, 105 | 'pin_memory': True, 106 | 'sampler': self.train_sampler 107 | } 108 | super().__init__(**self.init_kwargs) 109 | 110 | class MultiDistBaseDataLoaderExplicitSplit(DataLoader): 111 | """ 112 | Base class for all data loaders 113 | """ 114 | def __init__(self, args, dataset, batch_size, shuffle, num_workers, collate_fn=default_collate): 115 | self.shuffle = shuffle 116 | 117 | self.batch_idx = 0 118 | self.n_samples = len(dataset) 119 | self.args = args 120 | self.train_sampler = DistributedSampler(dataset, num_replicas=self.args.world_size, rank=self.args.rank, drop_last=True) 121 | self.init_kwargs = { 122 | 'dataset': dataset, 123 | 'batch_size': batch_size, 124 | 'shuffle': False, 125 | 'collate_fn': collate_fn, 126 | 'num_workers': num_workers, 127 | 'pin_memory': True, 128 | 'sampler': self.train_sampler 129 | } 130 | super().__init__(**self.init_kwargs) 131 | 132 | class BaseMultiDataLoader: 133 | """ 134 | Currently implemented as undersample the bigger dataloaders... 
135 | """ 136 | def __init__(self, dataloaders): 137 | self.dataloaders = dataloaders 138 | self.batch_size = self.dataloaders[0].batch_size 139 | def __getitem__(self, item): 140 | dl_idx = item % len(self.dataloaders) 141 | return next(iter(self.dataloaders[dl_idx])) 142 | 143 | def __len__(self): 144 | return min([len(x) for x in self.dataloaders]) * len(self.dataloaders) 145 | 146 | def num_samples(self): 147 | return sum([len(x.sampler) for x in self.dataloaders]) 148 | -------------------------------------------------------------------------------- /base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | @abstractmethod 11 | def forward(self, *inputs): 12 | """ 13 | Forward pass logic 14 | 15 | :return: Model output 16 | """ 17 | raise NotImplementedError 18 | 19 | def __str__(self): 20 | """ 21 | Model prints with number of trainable parameters 22 | """ 23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 24 | params = sum([np.prod(p.size()) for p in model_parameters]) 25 | return super().__str__() + '\nTrainable parameters: {}'.format(params) 26 | -------------------------------------------------------------------------------- /configs/ft/DiDeMo_8f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "DIDEMO_8f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 8, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "distilbert-base-uncased", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "minimal", 23 | "load_checkpoint": "" 24 | } 25 | }, 26 | "data_loader": [{ 27 | "type": "MultiDistTextVideoDataLoader", 28 | "args": { 29 | "dataset_name": "DIDEMO", 30 | "metadata_dir": "metadata/DiDeMo", 31 | "reader":"decord", 32 | "shuffle": false, 33 | "num_workers": 16, 34 | "batch_size": 16, 35 | "split": "train", 36 | "cut": "xx", 37 | "subsample": 1, 38 | "text_params": { 39 | "input": "text" 40 | }, 41 | "video_params": { 42 | "extraction_fps": 25, 43 | "extraction_res": 256, 44 | "input_res": 224, 45 | "num_frames": 8, 46 | "loading": "lax", 47 | "stride": 1 48 | } 49 | } 50 | }], 51 | "optimizer": { 52 | "type": "AdamW", 53 | "args": { 54 | "lr": 3e-05 55 | } 56 | }, 57 | "loss": { 58 | "type": "NormSoftmaxLoss", 59 | "args": {} 60 | }, 61 | "metrics": [ 62 | "t2v_metrics", 63 | "v2t_metrics" 64 | ], 65 | "trainer": { 66 | "epochs": 100, 67 | "max_samples_per_epoch": 9000, 68 | "save_dir": "./results/ft/DiDeMo/", 69 | "save_period": 5, 70 | "verbosity": 2, 71 | "monitor": "min val_loss_0", 72 | "early_stop": 10, 73 | "neptune": false, 74 | "use_amp":false 75 | }, 76 | "visualizer": { 77 | "type": "", 78 | "args": {} 79 | } 80 | } -------------------------------------------------------------------------------- /configs/ft/HMDB_16f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "HMDB_8f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": 
"base_patch16_224", 10 | "num_frames": 16, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "distilbert-base-uncased", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "minimal", 23 | "load_checkpoint" : "" 24 | } 25 | }, 26 | "data_loader": [ 27 | { 28 | "type": "MultiDistTextVideoDataLoader", 29 | "args":{ 30 | "dataset_name": "HMDB", 31 | "data_dir": "data/HMDB51/", 32 | "reader":"decord", 33 | "shuffle": true, 34 | "num_workers": 16, 35 | "batch_size": 4, 36 | "split": "train", 37 | "cut": "1", 38 | "subsample": 1, 39 | "text_params": { 40 | "input": "text" 41 | }, 42 | "video_params": { 43 | "extraction_fps": 25, 44 | "extraction_res": 256, 45 | "input_res": 224, 46 | "num_frames": 16, 47 | "stride": 1 48 | } 49 | } 50 | } 51 | ], 52 | "optimizer": { 53 | "type": "AdamW", 54 | "args":{ 55 | "lr": 3e-5 56 | } 57 | }, 58 | "loss": { 59 | "type": "NormSoftmaxLoss", 60 | "args": { 61 | } 62 | }, 63 | "metrics": [ 64 | "cls_as_retrieval" 65 | ], 66 | "trainer": { 67 | "epochs": 100, 68 | "max_samples_per_epoch": 9000, 69 | "save_dir": "./results/ft/HMDB/", 70 | "save_period": 5, 71 | "verbosity": 2, 72 | "monitor": "min val_loss_0", 73 | "early_stop": 10, 74 | "neptune": false, 75 | "use_amp":false 76 | }, 77 | "visualizer": { 78 | "type": "", 79 | "args": { 80 | } 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /configs/ft/HMDB_1f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "HMDB_1f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 1, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "distilbert-base-uncased", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "minimal", 23 | "load_checkpoint" : "" 24 | } 25 | }, 26 | "data_loader": [ 27 | { 28 | "type": "MultiDistTextVideoDataLoader", 29 | "args":{ 30 | "dataset_name": "HMDB", 31 | "data_dir": "data/HMDB51/", 32 | "reader":"decord", 33 | "shuffle": true, 34 | "num_workers": 16, 35 | "batch_size": 16, 36 | "split": "train", 37 | "cut": "1", 38 | "subsample": 1, 39 | "text_params": { 40 | "input": "text" 41 | }, 42 | "video_params": { 43 | "extraction_fps": 25, 44 | "extraction_res": 256, 45 | "input_res": 224, 46 | "num_frames": 1, 47 | "stride": 1 48 | } 49 | } 50 | } 51 | ], 52 | "optimizer": { 53 | "type": "AdamW", 54 | "args":{ 55 | "lr": 3e-5 56 | } 57 | }, 58 | "loss": { 59 | "type": "NormSoftmaxLoss", 60 | "args": { 61 | } 62 | }, 63 | "metrics": [ 64 | "cls_as_retrieval" 65 | ], 66 | "trainer": { 67 | "epochs": 100, 68 | "max_samples_per_epoch": 9000, 69 | "save_dir": "./results/ft/HMDB/", 70 | "save_period": 5, 71 | "verbosity": 2, 72 | "monitor": "min val_loss_0", 73 | "early_stop": 10, 74 | "neptune": false, 75 | "use_amp":false 76 | }, 77 | "visualizer": { 78 | "type": "", 79 | "args": { 80 | } 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /configs/ft/HMDB_3f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "HMDB_3f", 3 | "n_gpu": 4, 4 
| "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 3, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "distilbert-base-uncased", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "minimal", 23 | "load_checkpoint" : "" 24 | } 25 | }, 26 | "data_loader": [ 27 | { 28 | "type": "MultiDistTextVideoDataLoader", 29 | "args":{ 30 | "dataset_name": "HMDB", 31 | "data_dir": "data/HMDB51/", 32 | "reader":"decord", 33 | "shuffle": true, 34 | "num_workers": 16, 35 | "batch_size": 16, 36 | "split": "train", 37 | "cut": "1", 38 | "subsample": 1, 39 | "text_params": { 40 | "input": "text" 41 | }, 42 | "video_params": { 43 | "extraction_fps": 25, 44 | "extraction_res": 256, 45 | "input_res": 224, 46 | "num_frames": 3, 47 | "stride": 1 48 | } 49 | } 50 | } 51 | ], 52 | "optimizer": { 53 | "type": "AdamW", 54 | "args":{ 55 | "lr": 3e-5 56 | } 57 | }, 58 | "loss": { 59 | "type": "NormSoftmaxLoss", 60 | "args": { 61 | } 62 | }, 63 | "metrics": [ 64 | "cls_as_retrieval" 65 | ], 66 | "trainer": { 67 | "epochs": 100, 68 | "max_samples_per_epoch": 9000, 69 | "save_dir": "./results/ft/HMDB/", 70 | "save_period": 5, 71 | "verbosity": 2, 72 | "monitor": "min val_loss_0", 73 | "early_stop": 10, 74 | "neptune": false, 75 | "use_amp":false 76 | }, 77 | "visualizer": { 78 | "type": "", 79 | "args": { 80 | } 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /configs/ft/HMDB_4f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "HMDB_8f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 4, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "distilbert-base-uncased", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "minimal", 23 | "load_checkpoint" : "" 24 | } 25 | }, 26 | "data_loader": [ 27 | { 28 | "type": "MultiDistTextVideoDataLoader", 29 | "args":{ 30 | "dataset_name": "HMDB", 31 | "data_dir": "data/HMDB51/", 32 | "reader":"decord", 33 | "shuffle": true, 34 | "num_workers": 16, 35 | "batch_size": 8, 36 | "split": "train", 37 | "cut": "3", 38 | "subsample": 1, 39 | "text_params": { 40 | "input": "text" 41 | }, 42 | "video_params": { 43 | "extraction_fps": 25, 44 | "extraction_res": 256, 45 | "input_res": 224, 46 | "num_frames": 4, 47 | "stride": 1 48 | } 49 | } 50 | } 51 | ], 52 | "optimizer": { 53 | "type": "AdamW", 54 | "args":{ 55 | "lr": 3e-5 56 | } 57 | }, 58 | "loss": { 59 | "type": "NormSoftmaxLoss", 60 | "args": { 61 | } 62 | }, 63 | "metrics": [ 64 | "cls_as_retrieval" 65 | ], 66 | "trainer": { 67 | "epochs": 100, 68 | "max_samples_per_epoch": 9000, 69 | "save_dir": "./results/ft/HMDB/", 70 | "save_period": 5, 71 | "verbosity": 2, 72 | "monitor": "min val_loss_0", 73 | "early_stop": 10, 74 | "neptune": false, 75 | "use_amp":false 76 | }, 77 | "visualizer": { 78 | "type": "", 79 | "args": { 80 | } 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- 
/configs/ft/HMDB_8f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "HMDB_8f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 8, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "distilbert-base-uncased", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "minimal", 23 | "load_checkpoint" : "" 24 | } 25 | }, 26 | "data_loader": [ 27 | { 28 | "type": "MultiDistTextVideoDataLoader", 29 | "args":{ 30 | "dataset_name": "HMDB", 31 | "data_dir": "data/HMDB51/", 32 | "reader":"decord", 33 | "shuffle": true, 34 | "num_workers": 16, 35 | "batch_size": 8, 36 | "split": "train", 37 | "cut": "1", 38 | "subsample": 1, 39 | "text_params": { 40 | "input": "text" 41 | }, 42 | "video_params": { 43 | "extraction_fps": 25, 44 | "extraction_res": 256, 45 | "input_res": 224, 46 | "num_frames": 8, 47 | "stride": 1 48 | } 49 | } 50 | } 51 | ], 52 | "optimizer": { 53 | "type": "AdamW", 54 | "args":{ 55 | "lr": 3e-5 56 | } 57 | }, 58 | "loss": { 59 | "type": "NormSoftmaxLoss", 60 | "args": { 61 | } 62 | }, 63 | "metrics": [ 64 | "cls_as_retrieval" 65 | ], 66 | "trainer": { 67 | "epochs": 100, 68 | "max_samples_per_epoch": 9000, 69 | "save_dir": "./results/ft/HMDB/", 70 | "save_period": 5, 71 | "verbosity": 2, 72 | "monitor": "min val_loss_0", 73 | "early_stop": 10, 74 | "neptune": false, 75 | "use_amp":false 76 | }, 77 | "visualizer": { 78 | "type": "", 79 | "args": { 80 | } 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /configs/ft/LSMDC_8f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "LSMDC_8f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 8, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "distilbert-base-uncased", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "minimal", 23 | "load_checkpoint": "" 24 | } 25 | }, 26 | "data_loader": [{ 27 | "type": "MultiDistTextVideoDataLoader", 28 | "args":{ 29 | "dataset_name": "LSMDC", 30 | "data_dir": "data/LSMDC", 31 | "metadata_dir": "metadata/LSMDC", 32 | "reader":"decord", 33 | "shuffle": false, 34 | "num_workers": 16, 35 | "batch_size": 16, 36 | "split": "train", 37 | "cut": "xx", 38 | "subsample": 1, 39 | "text_params": { 40 | "input": "text" 41 | }, 42 | "video_params": { 43 | "extraction_fps": 25, 44 | "extraction_res": 256, 45 | "input_res": 224, 46 | "num_frames": 8, 47 | "loading": "lax", 48 | "stride": 1 49 | } 50 | } 51 | }], 52 | "optimizer": { 53 | "type": "AdamW", 54 | "args":{ 55 | "lr": 3e-5 56 | } 57 | }, 58 | "loss": { 59 | "type": "NormSoftmaxLoss", 60 | "args": { 61 | } 62 | }, 63 | "metrics": [ 64 | "t2v_metrics", 65 | "v2t_metrics" 66 | ], 67 | "trainer": { 68 | "epochs": 100, 69 | "max_samples_per_epoch": 9000, 70 | "save_dir": "./results/ft/LSMDC/", 71 | "save_period": 5, 72 | "verbosity": 2, 73 | "monitor": "min val_loss_0", 74 | "early_stop": 10, 75 | "neptune": 
false, 76 | "use_amp":false 77 | }, 78 | "visualizer": { 79 | "type": "", 80 | "args": { 81 | } 82 | } 83 | 84 | } 85 | -------------------------------------------------------------------------------- /configs/ft/LSMDC_MC_8f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "LSMDC_MC_8f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 8, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "distilbert-base-uncased", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "minimal", 23 | "load_checkpoint": "" 24 | } 25 | }, 26 | "data_loader": [{ 27 | "type": "MultiDistTextVideoDataLoader", 28 | "args":{ 29 | "dataset_name": "LSMDC_MC", 30 | "data_dir": "data/LSMDC", 31 | "metadata_dir": "data/LSMDC/metadata/MC/", 32 | "reader":"decord", 33 | "shuffle": false, 34 | "num_workers": 16, 35 | "batch_size": 16, 36 | "split": "train", 37 | "cut": "xx", 38 | "subsample": 1, 39 | "text_params": { 40 | "input": "text" 41 | }, 42 | "video_params": { 43 | "extraction_fps": 25, 44 | "extraction_res": 256, 45 | "input_res": 224, 46 | "num_frames": 8, 47 | "loading": "lax", 48 | "stride": 1 49 | } 50 | } 51 | }], 52 | "optimizer": { 53 | "type": "AdamW", 54 | "args":{ 55 | "lr": 3e-5 56 | } 57 | }, 58 | "loss": { 59 | "type": "NormSoftmaxLoss", 60 | "args": { 61 | } 62 | }, 63 | "metrics": [ 64 | "mc_as_retrieval" 65 | ], 66 | "trainer": { 67 | "epochs": 100, 68 | "max_samples_per_epoch": 9000, 69 | "save_dir": "./results/ft/LSMDC/", 70 | "save_period": 5, 71 | "verbosity": 2, 72 | "monitor": "min val_loss_0", 73 | "early_stop": 10, 74 | "neptune": false, 75 | "use_amp":false 76 | }, 77 | "visualizer": { 78 | "type": "", 79 | "args": { 80 | } 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /configs/ft/MSRVTT_8f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "MSRVTTjsfusion_8f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 8, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "distilbert-base-uncased", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "minimal", 23 | "load_checkpoint" : "" 24 | } 25 | }, 26 | "data_loader": [ 27 | { 28 | "type": "MultiDistTextVideoDataLoader", 29 | "args":{ 30 | "dataset_name": "MSRVTT", 31 | "data_dir": "data/MSRVTT/MSRVTT_source/", 32 | "reader":"decord", 33 | "shuffle": true, 34 | "num_workers": 16, 35 | "batch_size": 16, 36 | "split": "train", 37 | "cut": "jsfusion", 38 | "subsample": 1, 39 | "text_params": { 40 | "input": "text" 41 | }, 42 | "video_params": { 43 | "extraction_fps": 25, 44 | "extraction_res": 256, 45 | "input_res": 224, 46 | "num_frames": 8, 47 | "stride": 1 48 | } 49 | } 50 | } 51 | ], 52 | "optimizer": { 53 | "type": "AdamW", 54 | "args":{ 55 | "lr": 3e-5 56 | } 57 | }, 58 | "loss": { 59 | "type": "NormSoftmaxLoss", 60 | "args": { 61 | } 62 | }, 63 | "metrics": [ 64 | "t2v_metrics", 65 | 
"v2t_metrics" 66 | ], 67 | "trainer": { 68 | "epochs": 100, 69 | "max_samples_per_epoch": 9000, 70 | "save_dir": "./results/ft/MSRVTT/", 71 | "save_period": 5, 72 | "verbosity": 2, 73 | "monitor": "min val_loss_0", 74 | "early_stop": 10, 75 | "neptune": false, 76 | "use_amp":false 77 | }, 78 | "visualizer": { 79 | "type": "", 80 | "args": { 81 | } 82 | } 83 | 84 | } 85 | -------------------------------------------------------------------------------- /configs/ft/MSRVTT_8f_clip.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "MSRVTTjsfusion_8f_clip", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "CLIP", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 8, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "CLIP", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "", 23 | "load_checkpoint" : "" 24 | } 25 | }, 26 | "data_loader": [ 27 | { 28 | "type": "MultiDistTextVideoDataLoader", 29 | "args":{ 30 | "dataset_name": "MSRVTT", 31 | "data_dir": "data/MSRVTT/MSRVTT_source/", 32 | "reader":"decord", 33 | "shuffle": true, 34 | "num_workers": 16, 35 | "batch_size": 16, 36 | "split": "train", 37 | "cut": "jsfusion", 38 | "subsample": 1, 39 | "text_params": { 40 | "input": "text" 41 | }, 42 | "video_params": { 43 | "extraction_fps": 25, 44 | "extraction_res": 256, 45 | "input_res": 224, 46 | "num_frames": 8, 47 | "stride": 1 48 | } 49 | } 50 | } 51 | ], 52 | "optimizer": { 53 | "type": "AdamW", 54 | "args":{ 55 | "lr": 3e-5 56 | } 57 | }, 58 | "loss": { 59 | "type": "NormSoftmaxLoss", 60 | "args": { 61 | } 62 | }, 63 | "metrics": [ 64 | "t2v_metrics", 65 | "v2t_metrics" 66 | ], 67 | "trainer": { 68 | "epochs": 100, 69 | "max_samples_per_epoch": 9000, 70 | "save_dir": "./results/ft/MSRVTT/", 71 | "save_period": 5, 72 | "verbosity": 2, 73 | "monitor": "min val_loss_0", 74 | "early_stop": 10, 75 | "neptune": false, 76 | "use_amp":false 77 | }, 78 | "visualizer": { 79 | "type": "", 80 | "args": { 81 | } 82 | } 83 | 84 | } 85 | -------------------------------------------------------------------------------- /configs/ft/MSRVTT_MC_8f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "MSRVTT_MC_8f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 8, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "distilbert-base-uncased", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "minimal", 23 | "load_checkpoint" : "" 24 | } 25 | }, 26 | "data_loader": [ 27 | { 28 | "type": "MultiDistTextVideoDataLoader", 29 | "args":{ 30 | "dataset_name": "MSRVTT_MC", 31 | "data_dir": "data/MSRVTT/MSRVTT_source/", 32 | "metadata_dir":"data/MSRVTT/metadata/MC/", 33 | "reader":"decord", 34 | "shuffle": true, 35 | "num_workers": 16, 36 | "batch_size": 16, 37 | "split": "train", 38 | "cut": "jsfusion", 39 | "subsample": 1, 40 | "text_params": { 41 | "input": "text" 42 | }, 43 | "video_params": { 44 | "extraction_fps": 25, 45 | "extraction_res": 256, 46 | "input_res": 224, 47 | "num_frames": 8, 
48 | "stride": 1 49 | } 50 | } 51 | } 52 | ], 53 | "optimizer": { 54 | "type": "AdamW", 55 | "args":{ 56 | "lr": 3e-5 57 | } 58 | }, 59 | "loss": { 60 | "type": "NormSoftmaxLoss", 61 | "args": { 62 | } 63 | }, 64 | "metrics": [ 65 | "mc_as_retrieval" 66 | ], 67 | "trainer": { 68 | "epochs": 100, 69 | "max_samples_per_epoch": 9000, 70 | "save_dir": "./results/ft/MSRVTT/", 71 | "save_period": 5, 72 | "verbosity": 2, 73 | "monitor": "min val_loss_0", 74 | "early_stop": 10, 75 | "neptune": false, 76 | "use_amp":false 77 | }, 78 | "visualizer": { 79 | "type": "", 80 | "args": { 81 | } 82 | } 83 | 84 | } 85 | -------------------------------------------------------------------------------- /configs/ft/MSRVTT_QA_8f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "MSRVTT_QA_8f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 8, 11 | "num_ans": 1500, 12 | "pretrained": true, 13 | "time_init": "zeros", 14 | "random_sampling": false, 15 | "temporal_type":"att" 16 | }, 17 | "text_params": { 18 | "model": "distilbert-base-uncased", 19 | "pretrained": true, 20 | "pretrained_path": "pretrained", 21 | "input": "text" 22 | }, 23 | "projection": "minimal", 24 | "load_checkpoint" : "" 25 | } 26 | }, 27 | "data_loader": [ 28 | { 29 | "type": "MultiDistTextVideoDataLoader", 30 | "args":{ 31 | "dataset_name": "MSRVTT_QA", 32 | "data_dir": "data/MSRVTT/MSRVTT_source/", 33 | "metadata_dir":"data/MSRVTT/metadata/QA/", 34 | "reader":"decord", 35 | "shuffle": true, 36 | "num_workers": 16, 37 | "batch_size": 16, 38 | "split": "train", 39 | "cut": "jsfusion", 40 | "subsample": 1, 41 | "text_params": { 42 | "input": "text" 43 | }, 44 | "video_params": { 45 | "extraction_fps": 25, 46 | "extraction_res": 256, 47 | "input_res": 224, 48 | "num_frames": 8, 49 | "stride": 1 50 | } 51 | } 52 | } 53 | ], 54 | "optimizer": { 55 | "type": "AdamW", 56 | "args":{ 57 | "lr": 3e-5 58 | } 59 | }, 60 | "loss": { 61 | "type": "CrossEntropy", 62 | "args": { 63 | } 64 | }, 65 | "metrics": [ 66 | "acc" 67 | ], 68 | "trainer": { 69 | "epochs": 100, 70 | "max_samples_per_epoch": 9000, 71 | "save_dir": "./results/ft/MSRVTT/", 72 | "save_period": 5, 73 | "verbosity": 2, 74 | "monitor": "min val_loss_0", 75 | "early_stop": 10, 76 | "neptune": false, 77 | "use_amp":false 78 | }, 79 | "visualizer": { 80 | "type": "", 81 | "args": { 82 | } 83 | } 84 | 85 | } 86 | -------------------------------------------------------------------------------- /configs/ft/MSVD_8f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "MSVD_8f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 8, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "distilbert-base-uncased", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "minimal", 23 | "load_checkpoint": "" 24 | } 25 | }, 26 | "data_loader": [ 27 | { 28 | "type": "MultiDistTextVideoDataLoader", 29 | "args": { 30 | "dataset_name": "MSVD", 31 | "data_dir": "data/MSVD/YouTubeClips", 32 | "metadata_dir": "metadata/MSVD", 33 | "reader":"decord", 34 | "shuffle": false, 35 | 
"num_workers": 16, 36 | "batch_size": 16, 37 | "split": "train", 38 | "cut": "xx", 39 | "subsample": 1, 40 | "text_params": { 41 | "input": "text" 42 | }, 43 | "video_params": { 44 | "extraction_fps": 25, 45 | "extraction_res": 256, 46 | "input_res": 224, 47 | "num_frames": 8, 48 | "stride": 1, 49 | "loading": "lax" 50 | } 51 | } 52 | } 53 | ], 54 | "optimizer": { 55 | "type": "AdamW", 56 | "args": { 57 | "lr": 3e-05 58 | } 59 | }, 60 | "loss": { 61 | "type": "NormSoftmaxLoss", 62 | "args": {} 63 | }, 64 | "metrics": [ 65 | "t2v_metrics", 66 | "v2t_metrics" 67 | ], 68 | "trainer": { 69 | "epochs": 100, 70 | "max_samples_per_epoch": 9000, 71 | "save_dir": "./results/ft/MSVD/", 72 | "save_period": 5, 73 | "verbosity": 2, 74 | "monitor": "min val_loss_0", 75 | "early_stop": 10, 76 | "neptune": false, 77 | "use_amp":false 78 | }, 79 | "visualizer": { 80 | "type": "", 81 | "args": {} 82 | } 83 | } -------------------------------------------------------------------------------- /configs/ft/MSVD_QA_8f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "MSVD_QA_8f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 8, 11 | "num_ans": 1000, 12 | "pretrained": true, 13 | "time_init": "zeros", 14 | "random_sampling": false, 15 | "temporal_type":"att" 16 | }, 17 | "text_params": { 18 | "model": "distilbert-base-uncased", 19 | "pretrained": true, 20 | "pretrained_path": "pretrained", 21 | "input": "text" 22 | }, 23 | "projection": "minimal", 24 | "load_checkpoint": "" 25 | } 26 | }, 27 | "data_loader": [ 28 | { 29 | "type": "MultiDistTextVideoDataLoader", 30 | "args": { 31 | "dataset_name": "MSVD_QA", 32 | "data_dir": "data/MSVD/YouTubeClips/", 33 | "metadata_dir": "data/MSVD/metadata/", 34 | "reader":"decord", 35 | "shuffle": true, 36 | "num_workers": 16, 37 | "batch_size": 16, 38 | "split": "train", 39 | "cut": "xx", 40 | "subsample": 1, 41 | "text_params": { 42 | "input": "text" 43 | }, 44 | "video_params": { 45 | "extraction_fps": 25, 46 | "extraction_res": 256, 47 | "input_res": 224, 48 | "num_frames": 8, 49 | "stride": 1, 50 | "loading": "lax" 51 | } 52 | } 53 | } 54 | ], 55 | "optimizer": { 56 | "type": "AdamW", 57 | "args": { 58 | "lr": 3e-05 59 | } 60 | }, 61 | "loss": { 62 | "type": "CrossEntropy", 63 | "args": {} 64 | }, 65 | "metrics": [ 66 | "acc" 67 | ], 68 | "trainer": { 69 | "epochs": 100, 70 | "max_samples_per_epoch": 9000, 71 | "save_dir": "./results/ft/MSVD/", 72 | "save_period": 5, 73 | "verbosity": 2, 74 | "monitor": "min val_loss_0", 75 | "early_stop": 10, 76 | "neptune": false, 77 | "use_amp":false 78 | }, 79 | "visualizer": { 80 | "type": "", 81 | "args": {} 82 | } 83 | } -------------------------------------------------------------------------------- /configs/ft/UCF_8f.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "UCF_8f", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 8, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "temporal_type":"att" 15 | }, 16 | "text_params": { 17 | "model": "distilbert-base-uncased", 18 | "pretrained": true, 19 | "pretrained_path": "pretrained", 20 | "input": "text" 21 | }, 22 | "projection": "minimal", 23 | 
"load_checkpoint" : "" 24 | } 25 | }, 26 | "data_loader": [ 27 | { 28 | "type": "MultiDistTextVideoDataLoader", 29 | "args":{ 30 | "dataset_name": "UCF", 31 | "data_dir": "data/UCF101/", 32 | "reader":"decord", 33 | "shuffle": true, 34 | "num_workers": 16, 35 | "batch_size": 8, 36 | "split": "train", 37 | "cut": "3", 38 | "subsample": 1, 39 | "text_params": { 40 | "input": "text" 41 | }, 42 | "video_params": { 43 | "extraction_fps": 25, 44 | "extraction_res": 256, 45 | "input_res": 224, 46 | "num_frames": 8, 47 | "stride": 1 48 | } 49 | } 50 | } 51 | ], 52 | "optimizer": { 53 | "type": "AdamW", 54 | "args":{ 55 | "lr": 3e-5 56 | } 57 | }, 58 | "loss": { 59 | "type": "NormSoftmaxLoss", 60 | "args": { 61 | } 62 | }, 63 | "metrics": [ 64 | "cls_as_retrieval" 65 | ], 66 | "trainer": { 67 | "epochs": 100, 68 | "max_samples_per_epoch": 9000, 69 | "save_dir": "./results/ft/UCF/", 70 | "save_period": 5, 71 | "verbosity": 2, 72 | "monitor": "min val_loss_0", 73 | "early_stop": 10, 74 | "neptune": false, 75 | "use_amp":false 76 | }, 77 | "visualizer": { 78 | "type": "", 79 | "args": { 80 | } 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /configs/pt/CC3M-WebVid2M.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "CC3M-WebVid2M", 3 | "n_gpu": 8, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 4, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "Motion_Excitation": false, 15 | "VQ_num_tokens":2048, 16 | "AGG_region_num":8, 17 | "Interaction_depth":1, 18 | "temporal_type":"att" 19 | }, 20 | "text_params": { 21 | "model": "distilbert-base-uncased", 22 | "pretrained": true, 23 | "pretrained_path": "pretrained", 24 | "input": "text" 25 | }, 26 | "projection": "minimal", 27 | "load_checkpoint" : "" 28 | } 29 | }, 30 | "data_loader": 31 | [ 32 | { 33 | "type": "MultiDistTextVideoDataLoader", 34 | "args":{ 35 | "dataset_name": "ConceptualCaptions3M", 36 | "data_dir": "data/CC3M/", 37 | "metadata_dir": "meta_data/CC3M/", 38 | "reader": "decord", 39 | "shuffle": true, 40 | "num_workers": 16, 41 | "batch_size": 64, 42 | "split": "train", 43 | "subsample": 1, 44 | "text_params": { 45 | "input": "text" 46 | }, 47 | "video_params": { 48 | "input_res": 224, 49 | "num_frames": 1, 50 | "loading": "lax" 51 | } 52 | } 53 | }, 54 | { 55 | "type": "MultiDistTextVideoDataLoader", 56 | "args":{ 57 | "dataset_name": "WebVid", 58 | "data_dir": "data/WebVid/", 59 | "metadata_dir": "meta_data/WebVid/", 60 | "reader": "decord", 61 | "shuffle": true, 62 | "num_workers": 16, 63 | "batch_size": 64, 64 | "split": "train", 65 | "cut": "2M", 66 | "subsample": 1, 67 | "text_params": { 68 | "input": "text" 69 | }, 70 | "video_params": { 71 | "input_res": 224, 72 | "num_frames": 1, 73 | "loading": "lax" 74 | } 75 | } 76 | } 77 | ], 78 | "optimizer": { 79 | "type": "AdamW", 80 | "args":{ 81 | "lr": 2e-4 82 | } 83 | }, 84 | "loss": { 85 | "type": "NormSoftmaxLoss", 86 | "args": { 87 | } 88 | }, 89 | "metrics": [ 90 | "t2v_metrics", 91 | "v2t_metrics" 92 | ], 93 | "trainer": { 94 | "epochs": 100, 95 | "max_samples_per_epoch": 1000000, 96 | "save_dir": "./results/pt/CC3M_WebVid/", 97 | "save_period": 5, 98 | "verbosity": 2, 99 | "monitor": "min val_loss_0", 100 | "early_stop": 10, 101 | "init_val": true, 102 | "neptune": false, 103 | "use_amp":false, 104 | 
"accum_iter":1 105 | }, 106 | "visualizer": { 107 | "type": "" 108 | } 109 | 110 | } 111 | -------------------------------------------------------------------------------- /configs/pt/WebVid2M.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "WebVid2M", 3 | "n_gpu": 8, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 4, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "Motion_Excitation": false, 15 | "VQ_num_tokens":2048, 16 | "AGG_region_num":8, 17 | "Interaction_depth":1, 18 | "temporal_type":"att" 19 | }, 20 | "text_params": { 21 | "model": "distilbert-base-uncased", 22 | "pretrained": true, 23 | "pretrained_path": "pretrained", 24 | "input": "text" 25 | }, 26 | "projection": "minimal", 27 | "load_checkpoint" : "" 28 | } 29 | }, 30 | "data_loader": 31 | [ 32 | { 33 | "type": "MultiDistTextVideoDataLoader", 34 | "args":{ 35 | "dataset_name": "WebVid", 36 | "data_dir": "data/WebVid2M/", 37 | "metadata_dir": "data/WebVid2M/", 38 | "reader": "decord", 39 | "shuffle": true, 40 | "num_workers": 16, 41 | "batch_size": 64, 42 | "split": "train", 43 | "cut": "2M", 44 | "subsample": 1, 45 | "text_params": { 46 | "input": "text" 47 | }, 48 | "video_params": { 49 | "input_res": 224, 50 | "num_frames": 1, 51 | "loading": "lax" 52 | } 53 | } 54 | } 55 | ], 56 | "optimizer": { 57 | "type": "AdamW", 58 | "args":{ 59 | "lr": 2e-4 60 | } 61 | }, 62 | "loss": { 63 | "type": "NormSoftmaxLoss", 64 | "args": { 65 | } 66 | }, 67 | "metrics": [ 68 | "t2v_metrics", 69 | "v2t_metrics" 70 | ], 71 | "trainer": { 72 | "epochs": 100, 73 | "max_samples_per_epoch": 1000000, 74 | "save_dir": "./results/pt/WebVid/", 75 | "save_period": 5, 76 | "verbosity": 2, 77 | "monitor": "min val_loss_0", 78 | "early_stop": 10, 79 | "init_val": true, 80 | "neptune": false, 81 | "use_amp":false, 82 | "accum_iter":1 83 | }, 84 | "visualizer": { 85 | "type": "" 86 | } 87 | 88 | } 89 | -------------------------------------------------------------------------------- /configs/pt/WebVid2M_4f_ME.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "WebVid2M_4f_ME", 3 | "n_gpu": 8, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 4, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "Motion_Excitation": true, 15 | "VQ_num_tokens":2048, 16 | "AGG_region_num":8, 17 | "Interaction_depth":1, 18 | "temporal_type":"att" 19 | }, 20 | "text_params": { 21 | "model": "distilbert-base-uncased", 22 | "pretrained": true, 23 | "pretrained_path": "pretrained", 24 | "input": "text" 25 | }, 26 | "projection": "minimal", 27 | "load_checkpoint" : "" 28 | } 29 | }, 30 | "data_loader": 31 | [ 32 | { 33 | "type": "MultiDistTextVideoDataLoader", 34 | "args":{ 35 | "dataset_name": "WebVid", 36 | "data_dir": "data/WebVid2M/", 37 | "metadata_dir": "data/WebVid2M/", 38 | "reader": "decord", 39 | "shuffle": true, 40 | "num_workers": 16, 41 | "batch_size": 16, 42 | "split": "train", 43 | "cut": "2M", 44 | "subsample": 1, 45 | "text_params": { 46 | "input": "text" 47 | }, 48 | "video_params": { 49 | "input_res": 224, 50 | "num_frames": 4, 51 | "loading": "lax" 52 | } 53 | } 54 | } 55 | ], 56 | "optimizer": { 57 | "type": "AdamW", 58 | "args":{ 59 
| "lr": 2e-4 60 | } 61 | }, 62 | "loss": { 63 | "type": "NormSoftmaxLoss", 64 | "args": { 65 | } 66 | }, 67 | "metrics": [ 68 | "t2v_metrics", 69 | "v2t_metrics" 70 | ], 71 | "trainer": { 72 | "epochs": 100, 73 | "max_samples_per_epoch": 1000000, 74 | "save_dir": "./results/pt/WebVid/", 75 | "save_period": 5, 76 | "verbosity": 2, 77 | "monitor": "min val_loss_0", 78 | "early_stop": 10, 79 | "init_val": true, 80 | "neptune": false, 81 | "use_amp":false, 82 | "accum_iter":1 83 | }, 84 | "visualizer": { 85 | "type": "" 86 | } 87 | 88 | } 89 | -------------------------------------------------------------------------------- /configs/pt/WebVid2M_clip.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "WebVid2M_clip", 3 | "n_gpu": 8, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "CLIP", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 4, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "Motion_Excitation": false, 15 | "VQ_num_tokens":2048, 16 | "AGG_region_num":8, 17 | "Interaction_depth":1, 18 | "temporal_type":"att" 19 | }, 20 | "text_params": { 21 | "model": "CLIP", 22 | "pretrained": true, 23 | "pretrained_path": "pretrained", 24 | "input": "text" 25 | }, 26 | "projection": "", 27 | "load_checkpoint" : "" 28 | } 29 | }, 30 | "data_loader": 31 | [ 32 | { 33 | "type": "MultiDistTextVideoDataLoader", 34 | "args":{ 35 | "dataset_name": "WebVid", 36 | "data_dir": "data/WebVid2M/", 37 | "metadata_dir": "data/WebVid2M/", 38 | "reader": "decord", 39 | "shuffle": true, 40 | "num_workers": 16, 41 | "batch_size": 128, 42 | "split": "train", 43 | "cut": "2M", 44 | "subsample": 1, 45 | "text_params": { 46 | "input": "text" 47 | }, 48 | "video_params": { 49 | "input_res": 224, 50 | "num_frames": 1, 51 | "loading": "lax" 52 | } 53 | } 54 | } 55 | ], 56 | "optimizer": { 57 | "type": "AdamW", 58 | "args":{ 59 | "lr": 2e-4 60 | } 61 | }, 62 | "loss": { 63 | "type": "NormSoftmaxLoss", 64 | "args": { 65 | } 66 | }, 67 | "metrics": [ 68 | "t2v_metrics", 69 | "v2t_metrics" 70 | ], 71 | "trainer": { 72 | "epochs": 100, 73 | "max_samples_per_epoch": 1000000, 74 | "save_dir": "./results/pt/WebVid/", 75 | "save_period": 5, 76 | "verbosity": 2, 77 | "monitor": "min val_loss_0", 78 | "early_stop": 10, 79 | "init_val": true, 80 | "neptune": false, 81 | "use_amp":false, 82 | "accum_iter":1 83 | }, 84 | "visualizer": { 85 | "type": "" 86 | } 87 | 88 | } 89 | -------------------------------------------------------------------------------- /configs/pt/WebVid2M_clip_RL.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "WebVid2M_clip_RL", 3 | "n_gpu": 8, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "CLIP", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 4, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "Motion_Excitation": false, 15 | "VQ_num_tokens":2048, 16 | "AGG_region_num":8, 17 | "Interaction_depth":1, 18 | "temporal_type":"att" 19 | }, 20 | "text_params": { 21 | "model": "distilbert-base-uncased", 22 | "pretrained": true, 23 | "pretrained_path": "pretrained", 24 | "input": "text" 25 | }, 26 | "projection": "minimal", 27 | "load_checkpoint" : "" 28 | } 29 | }, 30 | "data_loader": 31 | [ 32 | { 33 | "type": "MultiDistTextVideoDataLoader", 34 | "args":{ 35 | "dataset_name": "WebVid", 36 | "data_dir": "data/WebVid2M/", 37 | 
"metadata_dir": "data/WebVid2M/", 38 | "reader": "decord", 39 | "shuffle": true, 40 | "num_workers": 16, 41 | "batch_size": 128, 42 | "split": "train", 43 | "cut": "2M", 44 | "subsample": 1, 45 | "text_params": { 46 | "input": "text" 47 | }, 48 | "video_params": { 49 | "input_res": 224, 50 | "num_frames": 1, 51 | "loading": "lax" 52 | } 53 | } 54 | } 55 | ], 56 | "optimizer": { 57 | "type": "AdamW", 58 | "args":{ 59 | "lr": 2e-4 60 | } 61 | }, 62 | "loss": { 63 | "type": "NormSoftmaxLoss", 64 | "args": { 65 | } 66 | }, 67 | "metrics": [ 68 | "t2v_metrics", 69 | "v2t_metrics" 70 | ], 71 | "trainer": { 72 | "epochs": 100, 73 | "max_samples_per_epoch": 1000000, 74 | "save_dir": "./results/pt/WebVid/", 75 | "save_period": 5, 76 | "verbosity": 2, 77 | "monitor": "min val_loss_0", 78 | "early_stop": 10, 79 | "init_val": true, 80 | "neptune": false, 81 | "use_amp":false, 82 | "accum_iter":1 83 | }, 84 | "visualizer": { 85 | "type": "" 86 | } 87 | 88 | } 89 | -------------------------------------------------------------------------------- /configs/pt/WebVid2M_raw.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "WebVid2M_raw", 3 | "n_gpu": 8, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "args": { 7 | "video_params": { 8 | "model": "SpaceTimeTransformer", 9 | "arch_config": "base_patch16_224", 10 | "num_frames": 4, 11 | "pretrained": true, 12 | "time_init": "zeros", 13 | "random_sampling": false, 14 | "Motion_Excitation": false, 15 | "VQ_num_tokens":0, 16 | "AGG_region_num":0, 17 | "Interaction_depth":0, 18 | "temporal_type":"att" 19 | }, 20 | "text_params": { 21 | "model": "distilbert-base-uncased", 22 | "pretrained": true, 23 | "pretrained_path": "pretrained", 24 | "input": "text" 25 | }, 26 | "projection": "minimal", 27 | "load_checkpoint" : "" 28 | } 29 | }, 30 | "data_loader": 31 | [ 32 | { 33 | "type": "MultiDistTextVideoDataLoader", 34 | "args":{ 35 | "dataset_name": "WebVid", 36 | "data_dir": "data/WebVid2M/", 37 | "metadata_dir": "data/WebVid2M/", 38 | "reader": "decord", 39 | "shuffle": true, 40 | "num_workers": 16, 41 | "batch_size": 64, 42 | "split": "train", 43 | "cut": "2M", 44 | "subsample": 1, 45 | "text_params": { 46 | "input": "text" 47 | }, 48 | "video_params": { 49 | "input_res": 224, 50 | "num_frames": 1, 51 | "loading": "lax" 52 | } 53 | } 54 | } 55 | ], 56 | "optimizer": { 57 | "type": "AdamW", 58 | "args":{ 59 | "lr": 2e-4 60 | } 61 | }, 62 | "loss": { 63 | "type": "NormSoftmaxLoss", 64 | "args": { 65 | } 66 | }, 67 | "metrics": [ 68 | "t2v_metrics", 69 | "v2t_metrics" 70 | ], 71 | "trainer": { 72 | "epochs": 100, 73 | "max_samples_per_epoch": 1000000, 74 | "save_dir": "./results/pt/WebVid/", 75 | "save_period": 5, 76 | "verbosity": 2, 77 | "monitor": "min val_loss_0", 78 | "early_stop": 10, 79 | "init_val": true, 80 | "neptune": false, 81 | "use_amp":false, 82 | "accum_iter":1 83 | }, 84 | "visualizer": { 85 | "type": "" 86 | } 87 | 88 | } 89 | -------------------------------------------------------------------------------- /data_loader/ConceptualCaptions_dataset.py: -------------------------------------------------------------------------------- 1 | from base.base_dataset import TextImageDataset 2 | import pandas as pd 3 | import os 4 | import json 5 | import numpy as np 6 | import random 7 | import zlib 8 | 9 | 10 | class ConceptualCaptions3M(TextImageDataset): 11 | """ 12 | Conceptual Captions dataset. Split files are specific to my download regime. 
13 | """ 14 | 15 | def _load_metadata(self): 16 | # download specific 17 | # metadata_dir = './meta_data/CC3M/' 18 | split_files = { 19 | 'train': 'cc3m_training_success_full.tsv', 20 | 'val': 'cc3m_validation_success_full.tsv', # there is no test 21 | } 22 | target_split_fp = split_files[self.split] 23 | metadata = pd.read_csv(os.path.join(self.metadata_dir, target_split_fp), sep='\t') 24 | 25 | if self.subsample < 1: 26 | metadata = metadata.sample(frac=self.subsample) 27 | # elif self.split == 'val': 28 | # metadata = metadata.sample(1000, random_state=0) # 15k val is unnecessarily large, downsample. 29 | 30 | self.metadata = metadata 31 | 32 | def _get_video_path(self, sample): 33 | # conceptual captions uses this hashing to create the filename 34 | rel_dir = 'training' 35 | if self.split != 'train': 36 | rel_dir = 'validation' 37 | rel_fp = os.path.join(rel_dir, sample[1]) 38 | #rel_fp = os.path.join(rel_dir, str(zlib.crc32(sample['thumbnailUrl'].encode('utf-8')) & 0xffffffff)) 39 | return os.path.join(self.data_dir, rel_fp), rel_fp 40 | 41 | def _get_caption(self, sample): 42 | return sample[0] 43 | #return sample['caption'] 44 | -------------------------------------------------------------------------------- /data_loader/DiDeMo_dataset.py: -------------------------------------------------------------------------------- 1 | from base.base_dataset import TextVideoDataset 2 | import pandas as pd 3 | import os 4 | 5 | 6 | class DiDeMo(TextVideoDataset): 7 | def _load_metadata(self): 8 | # metadata_dir = './meta_data/DIDEMO' 9 | split_files = { 10 | 'train': 'DiDeMo_train.tsv', 11 | 'val': 'DiDeMo_test.tsv', # there is no test 12 | 'test': 'DiDeMo_test.tsv' 13 | } 14 | target_split_fp = split_files[self.split] 15 | metadata = pd.read_csv(os.path.join(self.metadata_dir, target_split_fp), sep='\t') 16 | if self.subsample < 1: 17 | metadata = metadata.sample(frac=self.subsample) 18 | self.metadata = metadata 19 | print("load split {}, {} samples".format(self.split, len(metadata))) 20 | # data/MSVD/YouTubeClips 21 | 22 | def _get_video_path(self, sample): 23 | rel_video_fp = sample[1] 24 | #rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') 25 | full_video_fp = os.path.join(self.data_dir, rel_video_fp) 26 | # print(full_video_fp) 27 | return full_video_fp, rel_video_fp 28 | 29 | def _get_caption(self, sample): 30 | # print(sample[0].split(',')[0]) 31 | return sample[0] # .split(',')[0] 32 | 33 | def _get_object_path(self, sample): 34 | """ 35 | get the object npy path 36 | Args: 37 | sample (dict): 38 | Returns: 39 | abs path 40 | """ 41 | rel_object_fp = os.path.join(sample[1], '1.npz') 42 | full_object_fp = os.path.join(self.object_dir, self.split, rel_object_fp) 43 | return os.path.join(self.split, rel_object_fp), full_object_fp 44 | -------------------------------------------------------------------------------- /data_loader/HMDB_dataset.py: -------------------------------------------------------------------------------- 1 | from base.base_dataset import TextVideoDataset 2 | 3 | import pandas as pd 4 | import os 5 | import json 6 | import numpy as np 7 | import random 8 | 9 | 10 | class HMDB(TextVideoDataset): 11 | """ 12 | WebVid Dataset. 13 | Assumes HMDB51 data is structured as follows. 14 | HMDB51/ 15 | videos/ 16 | action_name/ ($action_name) 17 | 1.avi (videoid.mp4) 18 | ... 19 | 5000.avi 20 | ... 
21 | """ 22 | def _load_metadata(self): 23 | metadata_dir = os.path.join(self.metadata_dir, 'metadata') 24 | metadata_fp = os.path.join(metadata_dir, f'{self.split}_{self.cut}.csv') 25 | metadata = pd.read_csv(metadata_fp) 26 | 27 | 28 | if self.subsample < 1: 29 | metadata = metadata.sample(frac=self.subsample) 30 | # elif self.split == 'val': 31 | # metadata = metadata.sample(1000, random_state=0) # 15k val is unnecessarily large, downsample. 32 | 33 | 34 | self.metadata = metadata 35 | # TODO: clean final csv so this isn't necessary 36 | self.metadata.dropna(inplace=True) 37 | self.metadata['action_str'] = self.metadata['action_str'].str[:350] 38 | 39 | def _get_video_path(self, sample): 40 | # rel_video_fp = os.path.join(sample['action_str'], str(sample['video_name'])) 41 | rel_video_fp = str(sample['video_name']) 42 | full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp) 43 | return full_video_fp, rel_video_fp 44 | 45 | def _get_caption(self, sample): 46 | action = sample['action_str'].replace('_',' ') 47 | if self.split=='train': 48 | cap_list = [ 49 | f"A video of {action}", 50 | f"A person is {action}" 51 | ] 52 | return random.choice(cap_list) 53 | else: 54 | return f"A person is {action}" 55 | # return f"{action}, a video of action" 56 | # return f"this is {action}, a video of action" 57 | # return f"human action of {action}" 58 | # return f"A video of {action}" 59 | # return f"A person is {action}" 60 | # return f"He is {action}" 61 | 62 | -------------------------------------------------------------------------------- /data_loader/LSMDC_dataset.py: -------------------------------------------------------------------------------- 1 | from base.base_dataset import TextVideoDataset 2 | import pandas as pd 3 | import os 4 | import torch 5 | 6 | class LSMDC(TextVideoDataset): 7 | def _load_metadata(self): 8 | # metadata_dir = './meta_data/LSMDC' 9 | split_files = { 10 | # 'train': 'LSMDC16_annos_training.csv', 11 | 'train': 'LSMDC16_annos_training.csv', 12 | 'val': 'LSMDC16_challenge_1000_publictect.csv', 13 | # 'val': 'LSMDC16_annos_val.csv', # there is no test 14 | 'test': 'LSMDC16_challenge_1000_publictect.csv' 15 | } 16 | target_split_fp = split_files[self.split] 17 | metadata = pd.read_csv(os.path.join(self.metadata_dir, target_split_fp), sep='\t') 18 | if self.subsample < 1: 19 | metadata = metadata.sample(frac=self.subsample) 20 | self.metadata = metadata 21 | print("load split {}, {} samples".format(self.split, len(metadata))) 22 | self.miss_vid_cnt = 0 23 | 24 | def _get_video_path(self, sample): 25 | video_fp = sample[0] 26 | sub_path = video_fp.split('.')[0] 27 | remove = sub_path.split('_')[-1] 28 | sub_path = sub_path.replace('_'+remove,'/') 29 | rel_video_fp = sub_path + video_fp + '.avi' 30 | #rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') 31 | full_video_fp = os.path.join(self.data_dir, rel_video_fp) 32 | return full_video_fp, rel_video_fp 33 | 34 | def _get_caption(self, sample): 35 | # print(sample[0].split(',')[0]) 36 | return sample[-1] # .split(',')[0] 37 | 38 | 39 | 40 | class LSMDC_MC(TextVideoDataset): 41 | def __getitem__(self, idx): 42 | ######################## Just copy from base ############################# 43 | idx = idx % len(self.metadata) 44 | sample = self.metadata[idx] # train/test for list 45 | video_fp, rel_fp = self._get_video_path(sample) 46 | ######################## Just copy from base ############################# 47 | 48 | if self.split == 'train': 49 | # just do RET 50 | caption = 
self._get_caption(sample) 51 | answer = -1 52 | else: 53 | # for MC choice Test, caption is the concated text of multiple choices 54 | # caption = {} 55 | # for i, opt in enumerate(sample["options"]): 56 | # caption[i] = opt 57 | caption = sample["options"] 58 | # print('caption:', len(caption), caption) 59 | answer = int(sample["answer"]) 60 | # print(self.split, caption) 61 | 62 | meta_arr = {'raw_captions': caption, 'paths': rel_fp, 'dataset': self.dataset_name} 63 | # data = {'video': final, 'text': caption, 'meta': meta_arr, 'frame_idxs': idxs} 64 | ######################## Just copy from base ############################# 65 | 66 | 67 | data = {'video': self.get_video(sample, video_fp), 'text': caption, 'meta': meta_arr, 'answer': answer} 68 | # print('base_dataset:\t', data.keys()) 69 | return data 70 | 71 | def _load_metadata(self): 72 | # metadata_dir = os.path.join(self.metadata_dir, "meta_data") 73 | metadata_dir = self.metadata_dir 74 | split_files = { 75 | 'train': 'LSMDC16_multiple_choice_train.csv', 76 | 'val': 'LSMDC16_multiple_choice_test_randomized.csv', # there is no test 77 | 'test': 'LSMDC16_multiple_choice_test_randomized.csv' 78 | } 79 | target_split_fp = split_files[self.split] 80 | metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t') 81 | if self.subsample < 1: 82 | metadata = metadata.sample(frac=self.subsample) 83 | 84 | datalist = [] 85 | for raw_id in range(len(metadata)): 86 | raw_d = metadata.iloc[raw_id] 87 | video_fp = raw_d[0] 88 | options = [raw_d[idx] for idx in range(5, 10)] 89 | d = dict( 90 | id=video_fp, 91 | # vid_id=rel_video_fp, 92 | answer=raw_d[-1] - 1 if self.split in ['val', 'test'] else 0, 93 | options=options, 94 | ) 95 | datalist.append(d) 96 | self.metadata = datalist 97 | # self.id2answer = {d["id"]: int(d["answer"]) for d in self.metadata} 98 | # self.id2data = {d["id"]: d for d in self.metadata} 99 | print("load split {}, {} samples".format(self.split, len(metadata))) 100 | 101 | 102 | def _get_video_path(self, sample): 103 | video_fp = sample['id'] 104 | sub_path = video_fp.split('.')[0] 105 | remove = sub_path.split('_')[-1] 106 | sub_path = sub_path.replace('_'+remove,'/') 107 | rel_video_fp = sub_path + video_fp + '.avi' 108 | full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp) 109 | # print('_get_video_path', full_video_fp, rel_video_fp) 110 | return full_video_fp, rel_video_fp 111 | 112 | def _get_caption(self, sample): 113 | return sample['options'][0] # for train only, select the first option as cap for RET. 
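

# Illustrative sketch of the path reconstruction performed by _get_video_path above.
# The clip id used here is hypothetical and only serves to show how the per-movie
# folder is derived from the clip file name.
if __name__ == "__main__":
    video_fp = "0001_Some_Movie_00.01.02.000-00.01.05.000"   # hypothetical clip id
    sub_path = video_fp.split('.')[0]                        # '0001_Some_Movie_00'
    remove = sub_path.split('_')[-1]                         # '00'
    sub_path = sub_path.replace('_' + remove, '/')           # '0001_Some_Movie/'
    print(sub_path + video_fp + '.avi')
    # -> 0001_Some_Movie/0001_Some_Movie_00.01.02.000-00.01.05.000.avi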
-------------------------------------------------------------------------------- /data_loader/LSMDC_dataset_old.py: -------------------------------------------------------------------------------- 1 | from base.base_dataset import TextVideoDataset 2 | import pandas as pd 3 | import os 4 | import json 5 | import numpy as np 6 | import random 7 | 8 | 9 | class LSMDC(TextVideoDataset): 10 | def _load_metadata(self): 11 | split_paths = {key: os.path.join(self.metadata_dir, 'structured-symlinks', f'{key}_list.txt') for key in 12 | ['train', 'val', 'test']} 13 | df_dict = {key: pd.read_csv(val, names=['videoid']) for key, val in split_paths.items()} 14 | #### subsample_val 15 | 16 | self.split_sizes = {key: len(val) for key, val in df_dict.items()} 17 | target_vids = df_dict[self.split] 18 | # target_vids = target_vids['videoid'].str.split('.').str[0] 19 | if self.subsample < 1: 20 | target_vids = target_vids.sample(frac=self.subsample) 21 | captions = np.load(os.path.join(self.metadata_dir, 'structured-symlinks', 'raw-captions.pkl'), 22 | allow_pickle=True) 23 | captions = pd.DataFrame.from_dict(captions, orient='index') 24 | captions['captions'] = captions.values.tolist() 25 | target_vids.set_index('videoid', inplace=True) 26 | target_vids['captions'] = captions['captions'] 27 | # import pdb; -.set_trace() 28 | # captions = captions[captions.index.isin(target_vids.str['videoid'].split('.').str[0])] 29 | self.metadata = target_vids 30 | frame_tar_list = pd.read_csv(os.path.join(self.metadata_dir, 'frame_tar_list.txt'), names=['fp']) 31 | 32 | frame_tar_list['fn'] = frame_tar_list['fp'].str.split('/').str[-2:].str.join('/') 33 | frame_tar_list['fn'] = frame_tar_list['fn'].str.replace('.tar', '') 34 | frame_tar_list['vid_stem'] = frame_tar_list['fn'].str.split('/').str[-1] 35 | 36 | frame_tar_list = frame_tar_list[frame_tar_list['vid_stem'].isin(self.metadata.index)] 37 | 38 | frame_tar_list.set_index('vid_stem', inplace=True) 39 | self.metadata['fn'] = frame_tar_list['fn'] 40 | self.metadata['captions'] = self.metadata['captions'].apply(lambda x: [ii for ii in x if ii is not None]) 41 | self.metadata['num_captions'] = self.metadata['captions'].str.len() 42 | self.metadata['captions'] = self.metadata['captions'].apply(lambda x: [' '.join(ii) for ii in x]) 43 | 44 | if 'videoid' not in self.metadata.columns: 45 | self.metadata['videoid'] = self.metadata.index 46 | 47 | 48 | def _get_video_path(self, sample): 49 | return os.path.join(self.data_dir, 'videos', sample['fn'] + '.avi'), sample.name + '.avi' 50 | 51 | def _get_caption(self, sample): 52 | if len(sample['captions']) != 1: 53 | raise NotImplementedError 54 | return sample['captions'][0] -------------------------------------------------------------------------------- /data_loader/MSRVTT_dataset.py: -------------------------------------------------------------------------------- 1 | from base.base_dataset import TextVideoDataset 2 | import pandas as pd 3 | import os 4 | import json 5 | import numpy as np 6 | import random 7 | 8 | import json 9 | import torch 10 | 11 | from utils.util import load_json, load_jsonl 12 | 13 | 14 | class MSRVTT(TextVideoDataset): 15 | def _load_metadata(self): 16 | json_fp = os.path.join(self.metadata_dir, 'annotation', 'MSR_VTT.json') 17 | with open(json_fp, 'r') as fid: 18 | data = json.load(fid) 19 | df = pd.DataFrame(data['annotations']) 20 | 21 | split_dir = os.path.join(self.metadata_dir, 'high-quality', 'structured-symlinks') 22 | js_test_cap_idx_path = None 23 | challenge_splits = {"val", 
"public_server_val", "public_server_test"} 24 | if self.cut == "miech": 25 | train_list_path = "train_list_miech.txt" 26 | test_list_path = "test_list_miech.txt" 27 | elif self.cut == "jsfusion": 28 | train_list_path = "train_list_jsfusion.txt" 29 | test_list_path = "val_list_jsfusion.txt" 30 | js_test_cap_idx_path = "jsfusion_val_caption_idx.pkl" 31 | elif self.cut in {"full-val", "full-test"}: 32 | train_list_path = "train_list_full.txt" 33 | if self.cut == "full-val": 34 | test_list_path = "val_list_full.txt" 35 | else: 36 | test_list_path = "test_list_full.txt" 37 | elif self.cut in challenge_splits: 38 | train_list_path = "train_list.txt" 39 | if self.cut == "val": 40 | test_list_path = f"{self.cut}_list.txt" 41 | else: 42 | test_list_path = f"{self.cut}.txt" 43 | else: 44 | msg = "unrecognised MSRVTT split: {}" 45 | raise ValueError(msg.format(self.cut)) 46 | 47 | train_df = pd.read_csv(os.path.join(split_dir, train_list_path), names=['videoid']) 48 | test_df = pd.read_csv(os.path.join(split_dir, test_list_path), names=['videoid']) 49 | self.split_sizes = {'train': len(train_df), 'val': len(test_df), 'test': len(test_df)} 50 | 51 | if self.split == 'train': 52 | df = df[df['image_id'].isin(train_df['videoid'])] 53 | else: 54 | df = df[df['image_id'].isin(test_df['videoid'])] 55 | 56 | self.metadata = df.groupby(['image_id'])['caption'].apply(list) 57 | if self.subsample < 1: 58 | self.metadata = self.metadata.sample(frac=self.subsample) 59 | 60 | # use specific caption idx's in jsfusion 61 | if js_test_cap_idx_path is not None and self.split != 'train': 62 | caps = pd.Series(np.load(os.path.join(split_dir, js_test_cap_idx_path), allow_pickle=True)) 63 | new_res = pd.DataFrame({'caps': self.metadata, 'cap_idx': caps}) 64 | new_res['test_caps'] = new_res.apply(lambda x: [x['caps'][x['cap_idx']]], axis=1) 65 | self.metadata = new_res['test_caps'] 66 | 67 | self.metadata = pd.DataFrame({'captions': self.metadata}) 68 | 69 | def _get_video_path(self, sample): 70 | return os.path.join(self.data_dir, 'videos', 'all', sample.name + '.mp4'), sample.name + '.mp4' 71 | 72 | def _get_caption(self, sample): 73 | caption_sample = self.text_params.get('caption_sample', "rand") 74 | if self.split in ['train', 'val'] and caption_sample == "rand": 75 | caption = random.choice(sample['captions']) 76 | else: 77 | caption = sample['captions'][0] 78 | return caption 79 | 80 | 81 | class MSRVTT_MC(TextVideoDataset): 82 | def __getitem__(self, idx): 83 | ######################## Just copy from base ############################# 84 | idx = idx % len(self.metadata) 85 | sample = self.metadata.iloc[idx] if self.split=='train' else self.metadata[idx] # test for jsonl 86 | video_fp, rel_fp = self._get_video_path(sample) 87 | ######################## Just copy from base ############################# 88 | 89 | if self.split == 'train': 90 | # just do RET 91 | caption = self._get_caption(sample) 92 | answer = -1 93 | else: 94 | # for MC choice Test, caption is the concated text of multiple choices 95 | # caption = {} 96 | # for i, opt in enumerate(sample["options"]): 97 | # caption[i] = opt 98 | caption = sample["options"] 99 | # print('caption:', len(caption), caption) 100 | answer = int(sample["answer"]) 101 | 102 | 103 | meta_arr = {'raw_captions': caption, 'paths': rel_fp, 'dataset': self.dataset_name} 104 | data = {'video': self.get_video(sample, video_fp), 'text': caption, 'meta': meta_arr, 'answer': answer} 105 | # print('base_dataset:\t', data.keys()) 106 | return data 107 | 108 | # added by Mr. 
YAN 109 | def _load_metadata(self): 110 | if self.split=='train': 111 | print('here is train', self.data_dir) 112 | self.metadata = self._load_ret_train_full_metadata(self.data_dir) 113 | elif self.split=='val': 114 | meta_file = os.path.join(self.metadata_dir, "mc_test.jsonl") 115 | metadata = load_jsonl(meta_file) 116 | data_size = len(metadata) 117 | if self.subsample < 1: 118 | data_size = int(data_size * self.subsample) 119 | self.metadata = metadata[:data_size] 120 | 121 | print("load split {}, {} samples".format(self.split, len(self.metadata))) 122 | 123 | def _load_ret_train_full_metadata(self, metadata_dir): 124 | # copy from training part of ret metadata loading 125 | json_fp = os.path.join(metadata_dir, 'annotation', 'MSR_VTT.json') 126 | with open(json_fp, 'r') as fid: 127 | data = json.load(fid) 128 | df = pd.DataFrame(data['annotations']) 129 | 130 | split_dir = os.path.join(metadata_dir, 'high-quality', 'structured-symlinks') 131 | train_list_path = "train_list_full.txt" 132 | train_df = pd.read_csv(os.path.join(split_dir, train_list_path), names=['videoid']) 133 | df = df[df['image_id'].isin(train_df['videoid'])] 134 | metadata = df.groupby(['image_id'])['caption'].apply(list) 135 | if self.subsample < 1: 136 | metadata = metadata.sample(frac=self.subsample) 137 | 138 | return pd.DataFrame({'captions': metadata}) 139 | 140 | def _get_video_path(self, sample): 141 | if self.split=='train': 142 | return os.path.join(self.data_dir, 'videos', 'all', sample.name + '.mp4'), sample.name + '.mp4' 143 | else: 144 | return os.path.join(self.data_dir, 'videos', 'all', sample["clip_name"] + '.mp4'), sample["clip_name"] + '.mp4' 145 | 146 | 147 | 148 | def _get_caption(self, sample): 149 | caption_sample = self.text_params.get('caption_sample', "rand") 150 | if self.split in ['train', 'val'] and caption_sample == "rand": 151 | caption = random.choice(sample['captions']) 152 | else: 153 | caption = sample['captions'][0] 154 | return caption 155 | 156 | class MSRVTT_QA(TextVideoDataset): 157 | def __getitem__(self, idx): 158 | idx = idx % len(self.metadata) 159 | sample = self.metadata[idx] 160 | video_fp, rel_fp = self._get_video_path(sample) 161 | question = sample['question'] 162 | answer_id = self._get_answer_id(sample) 163 | 164 | meta_arr = {'raw_captions': question, 'paths': rel_fp, 'dataset': self.dataset_name} 165 | # data = {'video': final, 'text': caption, 'meta': meta_arr, 'frame_idxs': idxs} 166 | data = {'video': self.get_video(sample, video_fp), 'text': question, 'meta': meta_arr, 'answer_id': answer_id} 167 | # print('base_dataset:\t', data.keys()) 168 | return data 169 | 170 | # added by Mr. 
YAN 171 | def _load_metadata(self): 172 | ans2label_file = os.path.join(self.metadata_dir, "msrvtt_train_ans2label.json") 173 | self.ans2label = load_json(ans2label_file) 174 | split_files = { 175 | 'train': "msrvtt_qa_train.jsonl", 176 | 'test': "msrvtt_qa_test.jsonl", 177 | 'val': "msrvtt_qa_val.jsonl" 178 | } 179 | target_split_fp = split_files[self.split] 180 | meta_file = os.path.join(self.metadata_dir, target_split_fp) 181 | metadata = load_jsonl(meta_file) 182 | data_size = len(metadata) 183 | if self.subsample < 1: 184 | data_size = int(data_size * self.subsample) 185 | 186 | self.metadata = metadata[:data_size] 187 | self.num_labels = len(self.ans2label) 188 | self.label2ans = {v: k for k, v in self.ans2label.items()} 189 | 190 | print("load split {}, {} samples".format(self.split, data_size)) 191 | 192 | def _get_video_path(self, sample): 193 | return os.path.join(self.data_dir, 'videos', 'all', sample["video_id"] + '.mp4'), sample["video_id"] + '.mp4' 194 | 195 | def _get_answer_id(self, sample): 196 | if sample["answer"] in self.ans2label.keys(): 197 | return self.ans2label[sample["answer"]] 198 | else: 199 | return -1 # answers of some test samples may not in vocabulary -------------------------------------------------------------------------------- /data_loader/MSVD_dataset.py: -------------------------------------------------------------------------------- 1 | from base.base_dataset import TextVideoDataset 2 | import pandas as pd 3 | import os 4 | 5 | 6 | from utils.util import load_json, load_jsonl 7 | 8 | class MSVD(TextVideoDataset): 9 | def _load_metadata(self): 10 | # metadata_dir = 'meta_data/MSVD' 11 | split_files = { 12 | 'train': 'MSVD_train.tsv', 13 | 'val': 'MSVD_test.tsv', # there is no test 14 | 'test': 'MSVD_test.tsv' 15 | } 16 | target_split_fp = split_files[self.split] 17 | metadata = pd.read_csv(os.path.join(self.metadata_dir, target_split_fp), sep='\t') 18 | if self.subsample < 1: 19 | metadata = metadata.sample(frac=self.subsample) 20 | self.metadata = metadata 21 | print("load split {}, {} samples".format(self.split, len(metadata))) 22 | # data/MSVD/YouTubeClips 23 | 24 | def _get_video_path(self, sample): 25 | rel_video_fp = sample[1] + '.avi' 26 | #rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') 27 | full_video_fp = os.path.join(self.data_dir, rel_video_fp) 28 | # print(full_video_fp) 29 | return full_video_fp, rel_video_fp 30 | 31 | def _get_caption(self, sample): 32 | # print(sample[0].split(',')[0]) 33 | return sample[0].split(',')[0] 34 | 35 | class MSVD_QA(TextVideoDataset): 36 | def __getitem__(self, idx): 37 | idx = idx % len(self.metadata) 38 | sample = self.metadata[idx] 39 | video_fp, rel_fp = self._get_video_path(sample) 40 | question = sample['question'] 41 | answer_id = self._get_answer_id(sample) 42 | 43 | meta_arr = {'raw_captions': question, 'paths': rel_fp, 'dataset': self.dataset_name} 44 | # data = {'video': final, 'text': caption, 'meta': meta_arr, 'frame_idxs': idxs} 45 | data = {'video': self.get_video(sample, video_fp), 'text': question, 'meta': meta_arr, 'answer_id': answer_id} 46 | # print('base_dataset:\t', data.keys()) 47 | return data 48 | 49 | # added by Mr. 
YAN 50 | def _load_metadata(self): 51 | ans2label_file = os.path.join(self.metadata_dir, "msvd_answer_set.txt") 52 | yt_mapping_file = os.path.join(self.metadata_dir, "msvd_youtube_mapping.txt") 53 | self.ans2label = self.load_ans_set(ans2label_file) 54 | self.video_mapper = self._get_video_mapper(yt_mapping_file) 55 | 56 | split_files = { 57 | 'train': "msvd_train_qa_encode.json", 58 | 'test': "msvd_test_qa_encode.json", 59 | 'val': "msvd_val_qa_encode.json" 60 | }# Only top 1000 answers are used 61 | 62 | target_split_fp = split_files[self.split] 63 | meta_file = os.path.join(self.metadata_dir, target_split_fp) 64 | metadata = load_json(meta_file) 65 | data_size = len(metadata) 66 | if self.subsample < 1: 67 | data_size = int(data_size * self.subsample) 68 | 69 | self.metadata = metadata[:data_size] 70 | self.num_labels = len(self.ans2label) 71 | self.label2ans = {v: k for k, v in self.ans2label.items()} 72 | 73 | print("load split {}, {} samples".format(self.split, data_size)) 74 | 75 | def _get_video_path(self, sample): 76 | video_name = self.video_mapper[str(sample["video_id"])] 77 | return os.path.join(self.data_dir, video_name + '.avi'), video_name + '.avi' 78 | 79 | def load_ans_set(self, ans_set_file): 80 | """ 81 | input: A list of answers from a txt file. 82 | """ 83 | ans2label = {} 84 | with open(ans_set_file, 'r') as f: 85 | lines = f.readlines() 86 | for i, line in enumerate(lines): 87 | ans2label[line.strip('\n')] = i 88 | 89 | return ans2label 90 | 91 | def _get_video_mapper(self, mapping_file): 92 | """ 93 | input: A list of from a txt file. 94 | """ 95 | video_mapper = {} 96 | with open(mapping_file, 'r') as f: 97 | lines = f.readlines() 98 | for line in lines: 99 | line = line.strip('\n') 100 | yt_id, vid = line.split(' ') 101 | video_mapper[vid.strip('vid')] = yt_id 102 | return video_mapper 103 | 104 | 105 | 106 | def _get_answer_id(self, sample): 107 | if sample["answer"] in self.ans2label.keys(): 108 | return self.ans2label[sample["answer"]] 109 | else: 110 | return -1 # answers of some test samples may not in vocabulary -------------------------------------------------------------------------------- /data_loader/UCF_dataset.py: -------------------------------------------------------------------------------- 1 | from base.base_dataset import TextVideoDataset 2 | 3 | import pandas as pd 4 | import os 5 | import json 6 | import numpy as np 7 | import random 8 | 9 | 10 | class UCF(TextVideoDataset): 11 | """ 12 | WebVid Dataset. 13 | Assumes UCF101 data is structured as follows. 14 | UCF101/ 15 | videos/ 16 | 1.avi (videoid.mp4) 17 | ... 18 | 5000.avi 19 | ... 20 | """ 21 | def _load_metadata(self): 22 | metadata_dir = os.path.join(self.metadata_dir, 'metadata') 23 | metadata_fp = os.path.join(metadata_dir, f'{self.split}_{self.cut}.csv') 24 | metadata = pd.read_csv(metadata_fp) 25 | 26 | 27 | if self.subsample < 1: 28 | metadata = metadata.sample(frac=self.subsample) 29 | # elif self.split == 'val': 30 | # metadata = metadata.sample(1000, random_state=0) # 15k val is unnecessarily large, downsample. 
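        # The split CSV (metadata/{split}_{cut}.csv) is expected to provide at least
        # the 'video_name' and 'action_str' columns used below.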
31 | 32 | 33 | self.metadata = metadata 34 | # TODO: clean final csv so this isn't necessary 35 | self.metadata.dropna(inplace=True) 36 | self.metadata['action_str'] = self.metadata['action_str'].str[:350] 37 | 38 | def _get_video_path(self, sample): 39 | # rel_video_fp = os.path.join(sample['action_str'], str(sample['video_name'])) 40 | rel_video_fp = str(sample['video_name']) 41 | full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp) 42 | return full_video_fp, rel_video_fp 43 | 44 | def _get_caption(self, sample): 45 | action = sample['action_str'].replace('_',' ') 46 | if self.split=='train': 47 | cap_list = [ 48 | f"A video of {action}", 49 | f"A person is {action}" 50 | ] 51 | return random.choice(cap_list) 52 | else: 53 | # return f"A person is {action}" 54 | return action 55 | 56 | -------------------------------------------------------------------------------- /data_loader/WebVid_dataset.py: -------------------------------------------------------------------------------- 1 | from base.base_dataset import TextVideoDataset 2 | import pandas as pd 3 | import os 4 | import json 5 | import numpy as np 6 | import random 7 | 8 | 9 | # class WebVid(TextVideoDataset): 10 | # """ 11 | # WebVid Dataset. 12 | # Assumes webvid data is structured as follows. 13 | # Webvid/ 14 | # videos/ 15 | # 000001_000050/ ($page_dir) 16 | # 1.mp4 (videoid.mp4) 17 | # ... 18 | # 5000.mp4 19 | # ... 20 | # """ 21 | # def _load_metadata(self): 22 | # # metadata_dir = './meta_data/WEBVID/' 23 | 24 | # # NOTE: Not all videos can be download. Maybe you need adjust the meta_data for each dataset 25 | # split_files = { 26 | # 'train': 'results_2M_train.csv', 27 | # 'val': 'results_2M_val.csv', # there is no test 28 | # } 29 | # target_split_fp = split_files[self.split] 30 | # metadata = pd.read_csv(os.path.join(self.metadata_dir, target_split_fp), sep='\t') 31 | # if self.subsample < 1: 32 | # metadata = metadata.sample(frac=self.subsample) 33 | 34 | # # modified by Mr. Yan. Use full val set. 35 | # # elif self.split == 'val': 36 | # # metadata = metadata.sample(1000, random_state=0) # 15k val is unnecessarily large, downsample. 37 | 38 | # #metadata['caption'] = metadata['name'] 39 | # #del metadata['name'] 40 | # self.metadata = metadata 41 | # # TODO: clean final csv so this isn't necessary 42 | # #self.metadata.dropna(inplace=True) 43 | # #self.metadata['caption'] = self.metadata['caption'].str[:350] 44 | 45 | # def _get_video_path(self, sample): 46 | 47 | # # rel_video_fp = sample[1] + '.mp4' 48 | # rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') 49 | # full_video_fp = os.path.join(self.data_dir, self.split, rel_video_fp) 50 | 51 | # return full_video_fp, rel_video_fp 52 | 53 | # def _get_caption(self, sample): 54 | # return sample[0] 55 | 56 | 57 | class WebVid(TextVideoDataset): 58 | """ 59 | WebVid Dataset. 60 | Assumes webvid data is structured as follows. 61 | Webvid/ 62 | videos/ 63 | 000001_000050/ ($page_dir) 64 | 1.mp4 (videoid.mp4) 65 | ... 66 | 5000.mp4 67 | ... 68 | """ 69 | def _load_metadata(self): 70 | metadata_dir = os.path.join(self.metadata_dir, 'metadata') 71 | metadata_fp = os.path.join(metadata_dir, f'results_{self.cut}_{self.split}.csv') 72 | metadata = pd.read_csv(metadata_fp) 73 | 74 | if self.subsample < 1: 75 | metadata = metadata.sample(frac=self.subsample) 76 | 77 | # need all val samples 78 | # elif self.split == 'val': 79 | # metadata = metadata.sample(1000, random_state=0) # 15k val is unnecessarily large, downsample. 
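        # The WebVid results CSV stores the caption in its 'name' column; rename it
        # to 'caption' below so _get_caption can use the generic column name.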
80 | 81 | metadata['caption'] = metadata['name'] 82 | del metadata['name'] 83 | self.metadata = metadata 84 | # TODO: clean final csv so this isn't necessary 85 | self.metadata.dropna(inplace=True) 86 | # self.metadata['caption'] = self.metadata['caption'].str[:350] 87 | 88 | def _get_video_path(self, sample): 89 | rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') 90 | full_video_fp = os.path.join(self.data_dir, 'videos', self.split, rel_video_fp) 91 | return full_video_fp, rel_video_fp 92 | 93 | def _get_caption(self, sample): 94 | return sample['caption'] -------------------------------------------------------------------------------- /data_loader/__pycache__/ConceptualCaptions_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/data_loader/__pycache__/ConceptualCaptions_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/DiDeMo_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/data_loader/__pycache__/DiDeMo_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/HMDB_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/data_loader/__pycache__/HMDB_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/LSMDC_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/data_loader/__pycache__/LSMDC_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/MSRVTT_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/data_loader/__pycache__/MSRVTT_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/MSVD_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/data_loader/__pycache__/MSVD_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/UCF_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/data_loader/__pycache__/UCF_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/WebVid_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/data_loader/__pycache__/WebVid_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/data_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/data_loader/__pycache__/data_loader.cpython-37.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/data_loader/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /data_loader/data_loader.py: -------------------------------------------------------------------------------- 1 | from base import BaseDataLoader, BaseDataLoaderExplicitSplit, DistBaseDataLoaderExplicitSplit, MultiDistBaseDataLoaderExplicitSplit, BaseMultiDataLoader 2 | from data_loader.transforms import init_transform_dict, init_video_transform_dict 3 | from data_loader.ConceptualCaptions_dataset import ConceptualCaptions3M 4 | from data_loader.MSRVTT_dataset import MSRVTT, MSRVTT_MC, MSRVTT_QA 5 | from data_loader.MSVD_dataset import MSVD, MSVD_QA 6 | from data_loader.DiDeMo_dataset import DiDeMo 7 | from data_loader.LSMDC_dataset import LSMDC, LSMDC_MC 8 | from data_loader.WebVid_dataset import WebVid 9 | from data_loader.HMDB_dataset import HMDB 10 | from data_loader.UCF_dataset import UCF 11 | 12 | 13 | def dataset_loader(dataset_name, 14 | text_params, 15 | video_params, 16 | data_dir, 17 | metadata_dir=None, 18 | split='train', 19 | tsfms=None, 20 | cut=None, 21 | subsample=1, 22 | sliding_window_stride=-1, 23 | reader='cv2'): 24 | kwargs = dict( 25 | dataset_name=dataset_name, 26 | text_params=text_params, 27 | video_params=video_params, 28 | data_dir=data_dir, 29 | metadata_dir=metadata_dir, 30 | split=split, 31 | tsfms=tsfms, 32 | cut=cut, 33 | subsample=subsample, 34 | sliding_window_stride=sliding_window_stride, 35 | reader=reader 36 | ) 37 | 38 | # TODO: change to... 39 | dataset = globals()[dataset_name](**kwargs) 40 | # ...is this safe / or just lazy? 
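    # Note: globals()[dataset_name] resolves the class from whatever is imported at
    # the top of this module, so an unknown name raises a KeyError rather than the
    # explicit NotImplementedError of the commented-out branch below. A stricter
    # alternative (sketch only, not wired in) would be an explicit registry, e.g.
    #   _DATASETS = {'WebVid': WebVid, 'MSRVTT': MSRVTT, ...}
    #   dataset = _DATASETS[dataset_name](**kwargs)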
41 | # if dataset_name == "HMDB": 42 | # dataset = HMDB(**kwargs) 43 | # elif dataset_name == "UCF": 44 | # dataset = UCF(**kwargs) 45 | # elif dataset_name == "MSRVTT": 46 | # dataset = MSRVTT(**kwargs) 47 | # elif dataset_name == "MSRVTT_MC": 48 | # dataset = MSRVTT_MC(**kwargs) 49 | # elif dataset_name == "MSRVTT_QA": 50 | # dataset = MSRVTT_QA(**kwargs) 51 | # elif dataset_name == "LSMDC": 52 | # dataset = LSMDC(**kwargs) 53 | # elif dataset_name == "LSMDC_MC": 54 | # dataset = LSMDC_MC(**kwargs) 55 | # elif dataset_name == "MSVD": 56 | # dataset = MSVD(**kwargs) 57 | # elif dataset_name == "MSVD_QA": 58 | # dataset = MSVD_QA(**kwargs) 59 | # elif dataset_name == "DIDEMO": 60 | # dataset = DiDeMo(**kwargs) 61 | # elif dataset_name == "SomethingSomethingV2": 62 | # dataset = SomethingSomethingV2(**kwargs) 63 | # elif dataset_name == "WebVid": 64 | # dataset = WebVid(**kwargs) 65 | # elif dataset_name == "ConceptualCaptions3M": 66 | # dataset = ConceptualCaptions3M(**kwargs) 67 | # elif dataset_name == "ConceptualCaptions12M": 68 | # dataset = ConceptualCaptions12M(**kwargs) 69 | # elif dataset_name == "COCOCaptions": 70 | # dataset = COCOCaptions(**kwargs) 71 | # else: 72 | # raise NotImplementedError(f"Dataset: {dataset_name} not found.") 73 | 74 | return dataset 75 | 76 | 77 | class TextVideoDataLoader(BaseDataLoaderExplicitSplit): 78 | def __init__(self, 79 | dataset_name, 80 | text_params, 81 | video_params, 82 | data_dir, 83 | metadata_dir=None, 84 | split='train', 85 | tsfm_params=None, 86 | cut=None, 87 | subsample=1, 88 | sliding_window_stride=-1, 89 | reader='cv2', 90 | batch_size=1, 91 | num_workers=1, 92 | shuffle=True): 93 | if tsfm_params is None: 94 | tsfm_params = {} 95 | tsfm_dict = init_transform_dict(**tsfm_params) 96 | tsfm = tsfm_dict[split] 97 | dataset = dataset_loader(dataset_name, text_params, video_params, data_dir, metadata_dir, split, tsfm, cut, 98 | subsample, sliding_window_stride, reader) 99 | # if split != 'train': 100 | # shuffle = False 101 | 102 | 103 | 104 | 105 | super().__init__(dataset, batch_size, shuffle, num_workers) 106 | self.dataset_name = dataset_name 107 | 108 | class DistTextVideoDataLoader(DistBaseDataLoaderExplicitSplit): 109 | def __init__(self, 110 | dataset_name, 111 | text_params, 112 | video_params, 113 | data_dir, 114 | metadata_dir=None, 115 | split='train', 116 | tsfm_params=None, 117 | cut=None, 118 | subsample=1, 119 | sliding_window_stride=-1, 120 | reader='cv2', 121 | batch_size=1, 122 | num_workers=1, 123 | shuffle=True): 124 | if tsfm_params is None: 125 | tsfm_params = {} 126 | 127 | # BUG repaired by Mr. 
YAN 128 | if video_params['num_frames'] > 1: 129 | # video data can not do flip, crop aug 130 | tsfm_dict = init_video_transform_dict(**tsfm_params) 131 | else: 132 | tsfm_dict = init_transform_dict(**tsfm_params) 133 | 134 | tsfm = tsfm_dict[split] 135 | dataset = dataset_loader(dataset_name, text_params, video_params, data_dir, metadata_dir, split, tsfm, cut, 136 | subsample, sliding_window_stride, reader) 137 | # if split != 'train': 138 | # shuffle = False 139 | super().__init__(dataset, batch_size, shuffle, num_workers) 140 | self.dataset_name = dataset_name 141 | 142 | class MultiDistTextVideoDataLoader(MultiDistBaseDataLoaderExplicitSplit): 143 | def __init__(self, 144 | args, 145 | dataset_name, 146 | text_params, 147 | video_params, 148 | data_dir, 149 | metadata_dir=None, 150 | split='train', 151 | tsfm_params=None, 152 | cut=None, 153 | subsample=1, 154 | sliding_window_stride=-1, 155 | reader='cv2', 156 | batch_size=1, 157 | num_workers=1, 158 | shuffle=True): 159 | if tsfm_params is None: 160 | tsfm_params = {} 161 | # tsfm_dict = init_transform_dict(**tsfm_params) 162 | 163 | # BUG repaired by Mr. YAN 164 | if video_params['num_frames'] > 1: 165 | # video data can not do flip, crop aug 166 | tsfm_dict = init_video_transform_dict(**tsfm_params) 167 | else: 168 | tsfm_dict = init_transform_dict(**tsfm_params) 169 | 170 | tsfm = tsfm_dict[split] 171 | dataset = dataset_loader(dataset_name, text_params, video_params, data_dir, metadata_dir, split, tsfm, cut, 172 | subsample, sliding_window_stride, reader) 173 | # if split != 'train': 174 | # shuffle = False 175 | super().__init__(args, dataset, batch_size, shuffle, num_workers) 176 | self.dataset_name = dataset_name 177 | 178 | class TextVideoMultiDataLoader(BaseMultiDataLoader): 179 | # TODO: figure out neat way to have N data_loaders 180 | # TODO: also add N weighted sampler 181 | def __init__(self, data_loader1, data_loader2): 182 | # get class from "type" in dict 183 | dls_cfg = [data_loader1, data_loader2] 184 | dls = [] 185 | for dcfg in dls_cfg: 186 | dl = globals()[dcfg['type']](**dcfg['args']) 187 | dls.append(dl) 188 | super().__init__(dls) 189 | 190 | -------------------------------------------------------------------------------- /data_loader/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torchvision.transforms._transforms_video import RandomCropVideo, RandomResizedCropVideo,CenterCropVideo, NormalizeVideo,ToTensorVideo,RandomHorizontalFlipVideo 3 | 4 | 5 | def init_transform_dict(input_res=224, 6 | center_crop=256, 7 | randcrop_scale=(0.5, 1.0), 8 | color_jitter=(0, 0, 0), 9 | norm_mean=(0.485, 0.456, 0.406), 10 | norm_std=(0.229, 0.224, 0.225)): 11 | print('Image Transform is used!') 12 | normalize = transforms.Normalize(mean=norm_mean, std=norm_std) 13 | tsfm_dict = { 14 | 'train': transforms.Compose([ 15 | transforms.RandomResizedCrop(input_res, scale=randcrop_scale), 16 | transforms.RandomHorizontalFlip(), 17 | transforms.ColorJitter(brightness=color_jitter[0], saturation=color_jitter[1], hue=color_jitter[2]), 18 | normalize, 19 | ]), 20 | 'val': transforms.Compose([ 21 | transforms.Resize(center_crop), 22 | transforms.CenterCrop(center_crop), 23 | transforms.Resize(input_res), 24 | normalize, 25 | ]), 26 | 'test': transforms.Compose([ 27 | transforms.Resize(center_crop), 28 | transforms.CenterCrop(center_crop), 29 | transforms.Resize(input_res), 30 | normalize, 31 | ]) 32 | } 33 | return tsfm_dict 34 | 35 | 36 | 37 | 
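# Note: the *_Video transforms imported from torchvision.transforms._transforms_video
# above operate on whole clip tensors rather than single PIL images, which is why the
# distributed data loaders in data_loader.py switch to init_video_transform_dict
# whenever video_params['num_frames'] > 1.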
# BUG fixed by Mr. YAN 38 | # A video-based transform is applied when the model takes more than one image. 39 | def init_video_transform_dict(input_res=224, 40 | center_crop=256, 41 | randcrop_scale=(0.5, 1.0), 42 | color_jitter=(0, 0, 0), 43 | norm_mean=(0.485, 0.456, 0.406), 44 | norm_std=(0.229, 0.224, 0.225)): 45 | print('Video Transform is used!') 46 | normalize = NormalizeVideo(mean=norm_mean, std=norm_std) 47 | tsfm_dict = { 48 | 'train': transforms.Compose([ 49 | RandomResizedCropVideo(input_res), 50 | RandomHorizontalFlipVideo(), 51 | transforms.ColorJitter(brightness=color_jitter[0], saturation=color_jitter[1], hue=color_jitter[2]), 52 | normalize, 53 | ]), 54 | 'val': transforms.Compose([ 55 | transforms.Resize(center_crop), 56 | transforms.CenterCrop(center_crop), 57 | transforms.Resize(input_res), 58 | normalize, 59 | ]), 60 | 'test': transforms.Compose([ 61 | transforms.Resize(center_crop), 62 | transforms.CenterCrop(center_crop), 63 | transforms.Resize(input_res), 64 | normalize, 65 | ]) 66 | } 67 | return tsfm_dict -------------------------------------------------------------------------------- /fine-tuning.sh: -------------------------------------------------------------------------------- 1 | 2 | # install environments 3 | sh setup_myEnv.sh 4 | 5 | 6 | # goto workdir 7 | # echo "load path ....." 8 | cd your_path/Region_Learner 9 | 10 | 11 | nproc_per_node=4 # determined by your resource 12 | 13 | 14 | # NOTE: Not all videos can be download. Maybe you need adjust the meta_data for each dataset class defined in 'data_loader'. 15 | 16 | ##################################### MSRVTT ################################### 17 | # set your path 18 | MSRVTT_root="data/MSRVTT/" 19 | MSRVTT_save_dir="./results/ft/MSRVTT/" 20 | CCWV_mw="your_path_to/model_best.pth" 21 | 22 | # fine-tuning on MSRVTT 23 | python -m torch.distributed.launch --nproc_per_node $nproc_per_node $@ train.py \ 24 | --config configs/ft/MSRVTT_8f.json --launcher pytorch \ 25 | --save_dir $MSRVTT_save_dir --load_checkpoint $CCWV_mw \ 26 | --data_dir_0 $MSRVTT_root --learning_rate1 3e-5 --schedule 101 27 | ################################################################################# 28 | 29 | 30 | 31 | ######################### TODO:Other benchmarks ################################## 32 | 33 | 34 | 35 | if [ $? != 0 ]; then 36 | echo "Fail! Exit with 1" 37 | exit 1 38 | else 39 | echo "Success! 
Exit with 0" 40 | exit 0 41 | fi 42 | 43 | -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .visualization import * -------------------------------------------------------------------------------- /logger/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/logger/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /logger/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/logger/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /logger/__pycache__/visualization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/logger/__pycache__/visualization.cpython-37.pyc -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from utils import read_json 5 | 6 | 7 | def setup_logging(save_dir, log_config='src/logger/logger_config.json', default_level=logging.INFO): 8 | """ 9 | Setup logging configuration 10 | """ 11 | log_config = Path(log_config) 12 | if log_config.is_file(): 13 | config = read_json(log_config) 14 | # modify logging paths based on run config 15 | for _, handler in config['handlers'].items(): 16 | if 'filename' in handler: 17 | handler['filename'] = str(save_dir / handler['filename']) 18 | 19 | logging.config.dictConfig(config) 20 | else: 21 | print("Warning: logging configuration file is not found in {}.".format(log_config)) 22 | logging.basicConfig(level=default_level) 23 | -------------------------------------------------------------------------------- /logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": "logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /logger/visualization.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from utils import Timer 3 | 4 | 5 | class TensorboardWriter(): 6 | def __init__(self, log_dir, logger, enabled): 7 | 
self.writer = None 8 | self.selected_module = "" 9 | 10 | if enabled: 11 | log_dir = str(log_dir) 12 | 13 | # Retrieve vizualization writer. 14 | succeeded = False 15 | for module in ["torch.utils.tensorboard", "tensorboardX"]: 16 | try: 17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 18 | succeeded = True 19 | break 20 | except ImportError: 21 | succeeded = False 22 | self.selected_module = module 23 | 24 | if not succeeded: 25 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ 26 | "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \ 27 | "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \ 28 | "the 'config.json' file." 29 | logger.warning(message) 30 | 31 | self.step = 0 32 | self.mode = '' 33 | 34 | self.tb_writer_ftns = { 35 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 36 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 37 | } 38 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} 39 | 40 | self.timer = Timer() 41 | 42 | def set_step(self, step, mode='train'): 43 | self.mode = mode 44 | self.step = step 45 | if step == 0: 46 | self.timer.reset() 47 | else: 48 | duration = self.timer.check() 49 | self.add_scalar('steps_per_sec', 1 / duration) 50 | 51 | def __getattr__(self, name): 52 | """ 53 | If visualization is configured to use: 54 | return add_data() methods of tensorboard with additional information (step, tag) added. 55 | Otherwise: 56 | return a blank function handle that does nothing 57 | """ 58 | if name in self.tb_writer_ftns: 59 | add_data = getattr(self.writer, name, None) 60 | 61 | def wrapper(tag, data, *args, **kwargs): 62 | if add_data is not None: 63 | # add mode(train/valid) tag 64 | if name not in self.tag_mode_exceptions: 65 | tag = '{}/{}'.format(tag, self.mode) 66 | add_data(tag, data, self.step, *args, **kwargs) 67 | return wrapper 68 | else: 69 | # default action for returning methods defined in this class, set_step() for instance. 70 | try: 71 | attr = object.__getattr__(name) 72 | except AttributeError: 73 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 74 | return attr 75 | 76 | 77 | class SacredNeptuneWriter(): 78 | def __init__(self): 79 | raise NotImplementedError -------------------------------------------------------------------------------- /model/RegionLearner/Quantizer.py: -------------------------------------------------------------------------------- 1 | # reference https://github.com/researchmm/soho/blob/d98b2ba52ffda2ba857aa4bc0d4e9239efcfd806/SOHO/models/necks/utils.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.distributed as dist 6 | 7 | 8 | class GatherLayer(torch.autograd.Function): 9 | '''Gather tensors from all process, supporting backward propagation. 10 | ''' 11 | @staticmethod 12 | def forward(ctx, input): 13 | ctx.save_for_backward(input) 14 | output = [torch.zeros_like(input) \ 15 | for _ in range(dist.get_world_size())] 16 | dist.all_gather(output, input) 17 | return tuple(output) 18 | 19 | @staticmethod 20 | def backward(ctx, *grads): 21 | input, = ctx.saved_tensors 22 | grad_out = torch.zeros_like(input) 23 | grad_out[:] = grads[dist.get_rank()] 24 | return grad_out 25 | 26 | 27 | @torch.no_grad() 28 | def concat_all_gather(tensor): 29 | """ 30 | Performs all_gather operation on the provided tensors. 
31 | *** Warning ***: torch.distributed.all_gather has no gradient. 32 | """ 33 | tensors_gather = [ 34 | torch.ones_like(tensor) 35 | for _ in range(torch.distributed.get_world_size()) 36 | ] 37 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 38 | 39 | output = torch.cat(tensors_gather, dim=0) 40 | return output 41 | 42 | 43 | def ema_inplace(moving_avg, new, decay): 44 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) 45 | 46 | 47 | def ema_tensor_inplace(moving_avg, new, decay): 48 | new_out = torch.mul(new, 1.0 - decay) 49 | moving_avg.data.mul_(decay).add_(new_out.detach()) 50 | 51 | 52 | def sum_inplace(sum_data, new): 53 | sum_data.data.add_(new) 54 | 55 | 56 | def laplace_smoothing(x, n_categories, eps=1e-5): 57 | return (x + eps) / (x.sum() + n_categories * eps) 58 | 59 | 60 | def laplace_smoothing_dim(x, n_categories, dim=1, eps=1e-5): 61 | return (x + eps) / (x.sum(dim=dim, keepdim=True) + n_categories * eps) 62 | 63 | 64 | class VectorQuantizer(nn.Module): 65 | """ 66 | Inputs: 67 | - num_tokens : number of embeddings 68 | - token_dim : dimension of embedding 69 | - dist : distribution training or not 70 | """ 71 | def __init__(self, 72 | num_tokens, 73 | token_dim, 74 | decay=0.1, 75 | max_decay=0.99, 76 | eps=1e-5, 77 | dist=False): 78 | super(VectorQuantizer, self).__init__() 79 | self.dist = dist 80 | self.token_dim = token_dim 81 | self.num_tokens = num_tokens 82 | embed = torch.randn(num_tokens, token_dim) 83 | self.register_buffer('embed', embed) 84 | nn.init.normal_(self.embed) 85 | self.register_buffer('cluster_size', torch.zeros(num_tokens)) 86 | self.register_buffer('cluster_sum', torch.zeros(num_tokens)) 87 | self.register_buffer('embed_avg', torch.zeros(num_tokens, token_dim)) 88 | 89 | self.decay = decay 90 | self.eps = eps 91 | self.curr_decay = self.decay 92 | self.max_decay = max_decay 93 | 94 | print("Performing VectorQuantizer to cluster %d tokens ...." %(num_tokens)) 95 | 96 | def set_decay_updates(self, num_update): 97 | self.curr_decay = min(self.decay * num_update, self.max_decay) 98 | 99 | def forward(self, inputs_flatten): 100 | 101 | distances = (torch.sum(inputs_flatten**2, dim=1, keepdim=True) + 102 | torch.sum(self.embed.data**2, dim=1) - 103 | 2 * torch.matmul(inputs_flatten, self.embed.data.t())) 104 | """ 105 | encoding_indices: Tensor containing the discrete encoding indices, ie 106 | which element of the quantized space each input element was mapped to. 
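        Shape: (N, 1) for N input rows; RegionLearner.forward later reshapes these
        indices to (-1, H, W).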
107 | """ 108 | 109 | # print('distances:\t', distances.size()) 110 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) 111 | # print('encoding_indices:\t', encoding_indices.size()) 112 | 113 | encodings = torch.zeros(encoding_indices.shape[0], 114 | self.num_tokens, 115 | dtype=torch.float, 116 | device=inputs_flatten.device) 117 | # print('encodings:\t', encodings.size()) 118 | 119 | encodings.scatter_(1, encoding_indices, 1) 120 | # print('encodings:\t', encodings.size()) 121 | 122 | if self.training: 123 | 124 | tmp_sum = torch.sum(encodings, dim=0, keepdim=True) 125 | if self.dist: 126 | encoding_sum = torch.sum(concat_all_gather(tmp_sum), dim=0) 127 | else: 128 | encoding_sum = torch.sum(tmp_sum, dim=0) 129 | 130 | sum_inplace(self.cluster_sum, encoding_sum) 131 | ema_tensor_inplace(self.cluster_size, encoding_sum, 132 | self.curr_decay) 133 | embed_sum_tmp = torch.matmul(encodings.t(), inputs_flatten) 134 | if self.dist: 135 | embed_sum = torch.sum(concat_all_gather( 136 | embed_sum_tmp.unsqueeze(dim=0)), 137 | dim=0) 138 | else: 139 | embed_sum = torch.sum(embed_sum_tmp.unsqueeze(dim=0), dim=0) 140 | 141 | ema_tensor_inplace(self.embed_avg, embed_sum, self.curr_decay) 142 | 143 | cluster_size = laplace_smoothing( 144 | self.cluster_size, self.num_tokens, 145 | self.eps) * self.cluster_size.sum() 146 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) 147 | 148 | if self.dist: 149 | world_size = dist.get_world_size() 150 | dist.all_reduce(embed_normalized.div_(world_size)) 151 | self.embed.data.copy_(embed_normalized) 152 | 153 | quantize = torch.matmul(encodings, self.embed) 154 | # print('encodings:\t', encodings) 155 | #quantize = inputs_flatten 156 | quantize = (quantize - inputs_flatten).detach() + inputs_flatten 157 | 158 | return quantize, encoding_indices 159 | -------------------------------------------------------------------------------- /model/RegionLearner/RegionLearner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import numpy as np 6 | from model.RegionLearner.Quantizer import VectorQuantizer 7 | from model.helper import Attention 8 | 9 | class Aggregation(nn.Module): 10 | """ 11 | The second step of RegionLearner. 12 | It is designed to aggregate quantized tokens into several semantic regions 13 | """ 14 | def __init__(self, token_dim, num_region=8): 15 | super(Aggregation, self).__init__() 16 | self.num_region = num_region 17 | self.token_dim= token_dim 18 | self.spatial_pooling = nn.AdaptiveAvgPool2d((1,1)) 19 | self.spatial_att = nn.Conv2d(in_channels=token_dim, 20 | out_channels=self.num_region, # each channel used as att map for capturing one region 21 | kernel_size=3, 22 | padding=1 23 | ) 24 | 25 | print("Performing Aggregation to learn %d regions ...." 
%(num_region)) 26 | 27 | def forward(self, x): 28 | # x [B, C, H, W] 29 | B = x.size(0) 30 | region_mask = self.spatial_att(x) # [B, S, H, W] 31 | learned_region_list = [] 32 | for s in range(self.num_region): 33 | # print(x.size(), att_mask[:,s,...].unsqueeze(1).size()) 34 | learned_region = x * region_mask[:,s,...].unsqueeze(1) # [B, C, H, W] * [B, 1, H, W] --> [B, C, H, W] 35 | learned_region_list.append(self.spatial_pooling(learned_region).reshape(B, self.token_dim)) # [B, C, H, W] --> [B, C, 1, 1] 36 | 37 | # learned_region_list [B, C, 1] 38 | # print(learned_region_list[0].size()) 39 | learned_regions = torch.stack(learned_region_list, dim=-1) # [B, C, S] 40 | return learned_regions, region_mask # [B, C, S] 41 | 42 | 43 | class Motion_Excitation(nn.Module): 44 | """ 45 | It is designed to xxx 46 | """ 47 | def __init__(self, token_dim): 48 | super(Motion_Excitation, self).__init__() 49 | self.ex_proj = nn.Conv2d(in_channels=token_dim, 50 | out_channels=1, # each channel used as att map for capturing one region 51 | kernel_size=3, 52 | padding=1 53 | ) 54 | print("Performing Motion Excitation ...") 55 | 56 | def ex(self, x, h, w): 57 | # X: [B, T, L, C] 58 | _, T, L, C = x.size() 59 | x = x.reshape(-1, L, C) #[B*T, L, C] 60 | x = x.transpose(1, 2) #[B*T, C, L] 61 | x = x.reshape(-1, C, h, w) 62 | motion_map = self.ex_proj(x) # [B*T, 1, H, W] 63 | motion_map = motion_map.reshape(-1, T, L) 64 | return motion_map 65 | 66 | def forward(self, x, h, w): 67 | # X: [B, T, L, C] 68 | # shift X: 69 | T = x.size(1) 70 | x_ = x.clone() 71 | x_[:,1:,...] = x[:,:T-1,...] 72 | motion_map = self.ex(x - x_, h, w) # [B, T, L] 73 | motion_map = F.softmax(motion_map, dim=-1) # [B, T, L] 74 | # motion_map = F.softmax(fc(fc(x_) - x), dim=-1) 75 | # print('motion_map:', x.size(), motion_map.size()) 76 | motaion_feas = x * motion_map.unsqueeze(-1) # [B, T, L, C] 77 | return motaion_feas 78 | 79 | 80 | class RegionLearner(nn.Module): 81 | """ 82 | Learning implicit regions without supervision from video feature map. 83 | """ 84 | def __init__(self, VQ_num_tokens=None, VQ_token_dim=768, AGG_region_num=None, Interaction_depth=None, ME=None, dist=False): 85 | super(RegionLearner, self).__init__() 86 | self.Quantization = None 87 | self.Aggregation = None 88 | self.Interaction = None 89 | self.Motion_Excitation = None 90 | 91 | 92 | if VQ_num_tokens and VQ_token_dim: 93 | self.Quantization = VectorQuantizer(VQ_num_tokens, VQ_token_dim, dist=dist) 94 | if ME: 95 | self.Motion_Excitation = Motion_Excitation(VQ_token_dim) 96 | 97 | 98 | if AGG_region_num: 99 | self.Aggregation = Aggregation(token_dim=VQ_token_dim, num_region=AGG_region_num) 100 | if Interaction_depth: 101 | print("Performing Interaction among regions with %d layers ...." %(Interaction_depth)) 102 | if Interaction_depth>1: 103 | self.Interaction = nn.ModuleList([Attention(VQ_token_dim) 104 | for i in range(Interaction_depth)]) 105 | else: 106 | # TODO Delete it, now is kept for old models. 
107 | self.Interaction = Attention(VQ_token_dim) 108 | 109 | def forward(self, in_feas, cur_f=1, epoch=0): 110 | if self.Quantization: 111 | B, L, C = in_feas.size() 112 | vd_inputs = in_feas.reshape(-1, C) # [BL, C] 113 | vd_outputs, encoding_indices = self.Quantization(vd_inputs) # [BL, C], [BL, 1] 114 | h = int(math.sqrt(L)) 115 | w = int(L//h) 116 | encoding_indices = encoding_indices.reshape(-1, h, w) 117 | 118 | # Perform Motion Excitation 119 | if cur_f>1 and self.Motion_Excitation: 120 | in_feas = in_feas.view(-1, cur_f, L, C) 121 | motion_feas = self.Motion_Excitation(in_feas, h, w) # [B, T, L, C] 122 | motion_feas = motion_feas.reshape(-1, C) 123 | # print('size:', motion_feas.size(), vd_outputs.size()) 124 | vd_outputs = vd_outputs + motion_feas 125 | # vd_outputs = torch.cat([vd_outputs,motion_feas], dim=-1) 126 | 127 | 128 | 129 | 130 | if self.Aggregation: 131 | # [B*L, C] 132 | vd_outputs = vd_outputs.reshape(B, L, C) 133 | vd_outputs = vd_outputs.transpose(1, 2) #[B, C, L] 134 | vd_outputs = vd_outputs.reshape(B, C, h, w) 135 | # [B, C, H, W] 136 | # print('learn regions, x size:\t', x.size()) 137 | vd_outputs, region_mask = self.Aggregation(vd_outputs) # [B, C, S] 138 | vd_outputs = vd_outputs.transpose(1, 2) # [B, S, C] 139 | if self.Interaction: 140 | # print('do joint att') 141 | _, S, C = vd_outputs.size() 142 | # print('vd_outputs:\t', vd_outputs.size()) 143 | T = cur_f 144 | # TODO we can do spatial-temporal on regions 145 | vd_outputs = vd_outputs.reshape(-1, T, S, C) 146 | vd_outputs = vd_outputs.reshape(-1, T*S, C) 147 | # [b, T*S, C] 148 | # print('att inputs:\t', vd_outputs.size()) 149 | vd_outputs = self.Interaction(vd_outputs) 150 | # print('att outputs:\t', vd_outputs.size()) 151 | vd_outputs = vd_outputs.reshape(-1, T, S, C) 152 | vd_outputs = vd_outputs.reshape(-1, S, C) # [B, S, C] 153 | 154 | return vd_outputs, encoding_indices, region_mask 155 | else: 156 | vd_outputs = vd_outputs.reshape(B, L, C) 157 | return vd_outputs, encoding_indices, None 158 | 159 | 160 | 161 | 162 | 163 | if __name__ == "__main__": 164 | B, T, L = 2, 2, 196 165 | token_dim = 768 166 | num_tokens = 2048 167 | RL = RegionLearner(VQ_num_tokens=num_tokens, VQ_token_dim=token_dim, AGG_region_num=8, Interaction_depth=1, dist=False, ME=True) 168 | inputs = torch.randn(B*T, L, token_dim) 169 | print('Input of RegionLearner:\t', inputs.size()) 170 | outputs, encoding_indices, region_mask = RL(inputs, T) 171 | print('Output of RegionLearner:\t', outputs.size()) 172 | print('Encoding Indices of Quantization:\t', encoding_indices.size()) 173 | if region_mask is not None: 174 | print('Region Mask of Aggregation:\t', region_mask.size()) 175 | 176 | -------------------------------------------------------------------------------- /model/RegionLearner/__pycache__/Quantizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/RegionLearner/__pycache__/Quantizer.cpython-37.pyc -------------------------------------------------------------------------------- /model/RegionLearner/__pycache__/RegionLearner.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/RegionLearner/__pycache__/RegionLearner.cpython-37.pyc -------------------------------------------------------------------------------- /model/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/__init__.py -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/helper.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/__pycache__/helper.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/metric.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/__pycache__/metric.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/qa_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/__pycache__/qa_model.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/video_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/__pycache__/video_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /model/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /model/clip/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/clip/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/clip/__pycache__/clip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/clip/__pycache__/clip.cpython-37.pyc 
-------------------------------------------------------------------------------- /model/clip/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/clip/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /model/clip/__pycache__/simple_tokenizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/clip/__pycache__/simple_tokenizer.cpython-37.pyc -------------------------------------------------------------------------------- /model/clip/__pycache__/tokenization_clip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/clip/__pycache__/tokenization_clip.cpython-37.pyc -------------------------------------------------------------------------------- /model/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/model/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /model/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": 
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | } 40 | 41 | 42 | def _download(url: str, root: str): 43 | os.makedirs(root, exist_ok=True) 44 | filename = os.path.basename(url) 45 | 46 | expected_sha256 = url.split("/")[-2] 47 | download_target = os.path.join(root, filename) 48 | 49 | if os.path.exists(download_target) and not os.path.isfile(download_target): 50 | raise RuntimeError(f"{download_target} exists and is not a regular file") 51 | 52 | if os.path.isfile(download_target): 53 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 54 | return download_target 55 | else: 56 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 57 | 58 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 59 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 60 | while True: 61 | buffer = source.read(8192) 62 | if not buffer: 63 | break 64 | 65 | output.write(buffer) 66 | loop.update(len(buffer)) 67 | 68 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 69 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 70 | 71 | return download_target 72 | 73 | 74 | def _convert_image_to_rgb(image): 75 | return image.convert("RGB") 76 | 77 | 78 | def _transform(n_px): 79 | return Compose([ 80 | Resize(n_px, interpolation=BICUBIC), 81 | CenterCrop(n_px), 82 | _convert_image_to_rgb, 83 | ToTensor(), 84 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 85 | ]) 86 | 87 | 88 | def available_models() -> List[str]: 89 | """Returns the names of available CLIP models""" 90 | return list(_MODELS.keys()) 91 | 92 | 93 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 94 | """Load a CLIP model 95 | 96 | Parameters 97 | ---------- 98 | name : str 99 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 100 | 101 | device : Union[str, torch.device] 102 | The device to put the loaded model 103 | 104 | jit : bool 105 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 106 | 107 | download_root: str 108 | path to download the model files; by default, it uses "~/.cache/clip" 109 | 110 | Returns 111 | ------- 112 | model : torch.nn.Module 113 | The CLIP model 114 | 115 | preprocess : Callable[[PIL.Image], torch.Tensor] 116 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 117 | """ 118 | if name in _MODELS: 119 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 120 | elif os.path.isfile(name): 121 | model_path = name 122 | else: 123 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 124 | 125 | try: 126 | # loading JIT archive 127 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 128 | state_dict = None 129 | except RuntimeError: 130 | # loading saved state dict 131 | if jit: 132 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 133 | jit = False 134 | state_dict = torch.load(model_path, map_location="cpu") 135 | 136 | print('state_dict', state_dict) 137 | if not jit: 138 | model = build_model(state_dict or model.state_dict()).to(device) 139 | if str(device) == "cpu": 140 | model.float() 141 | return model, _transform(model.visual.input_resolution) 142 | 143 | # patch the device names 144 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 145 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 146 | 147 | def patch_device(module): 148 | try: 149 | graphs = [module.graph] if hasattr(module, "graph") else [] 150 | except RuntimeError: 151 | graphs = [] 152 | 153 | if hasattr(module, "forward1"): 154 | graphs.append(module.forward1.graph) 155 | 156 | for graph in graphs: 157 | for node in graph.findAllNodes("prim::Constant"): 158 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 159 | node.copyAttributes(device_node) 160 | 161 | model.apply(patch_device) 162 | patch_device(model.encode_image) 163 | patch_device(model.encode_text) 164 | 165 | # patch dtype to float32 on CPU 166 | if str(device) == "cpu": 167 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 168 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 169 | float_node = float_input.node() 170 | 171 | def patch_float(module): 172 | try: 173 | graphs = [module.graph] if hasattr(module, "graph") else [] 174 | except RuntimeError: 175 | graphs = [] 176 | 177 | if hasattr(module, "forward1"): 178 | graphs.append(module.forward1.graph) 179 | 180 | for graph in graphs: 181 | for node in graph.findAllNodes("aten::to"): 182 | inputs = list(node.inputs()) 183 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 184 | if inputs[i].node()["value"] == 5: 185 | inputs[i].node().copyAttributes(float_node) 186 | 187 | model.apply(patch_float) 188 | patch_float(model.encode_image) 189 | patch_float(model.encode_text) 190 | 191 | model.float() 192 | 193 | return model, _transform(model.input_resolution.item()) 194 | 195 | 196 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 197 | """ 198 | Returns the tokenized representation of given input string(s) 199 | 200 | Parameters 201 | ---------- 202 | texts : Union[str, List[str]] 203 | An input string or a list of input strings to tokenize 204 | 205 | context_length : int 206 | The context length to use; all CLIP models use 77 as the context length 207 | 208 | truncate: bool 209 | Whether to truncate the text in case its encoding is longer than the context length 210 | 211 | Returns 212 | ------- 213 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 214 | """ 215 | if isinstance(texts, str): 216 | texts = [texts] 217 | 218 | sot_token = _tokenizer.encoder["<|startoftext|>"] 219 | eot_token = _tokenizer.encoder["<|endoftext|>"] 220 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 221 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 222 | 223 | for i, tokens in enumerate(all_tokens): 224 | if len(tokens) > context_length: 225 | if truncate: 226 | tokens = tokens[:context_length] 227 | tokens[-1] = eot_token 228 | else: 229 | raise RuntimeError(f"Input {texts[i]} is too long for context length 
{context_length}") 230 | result[i, :len(tokens)] = torch.tensor(tokens) 231 | 232 | return result 233 | -------------------------------------------------------------------------------- /model/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < 
len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /model/clip/tokenization_clip.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | self.vocab = self.encoder 81 | 82 | def bpe(self, token): 83 | if token in self.cache: 84 | return self.cache[token] 85 | word = tuple(token[:-1]) + ( token[-1] + '',) 86 | pairs = get_pairs(word) 87 | 88 | if not pairs: 89 | return token+'' 90 | 91 | while True: 92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 93 | if bigram not in self.bpe_ranks: 94 | break 95 | first, second = bigram 96 | new_word = [] 97 | i = 0 98 | while i < len(word): 99 | try: 100 | j = word.index(first, i) 101 | new_word.extend(word[i:j]) 102 | i = j 103 | except: 104 | new_word.extend(word[i:]) 105 | break 106 | 107 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 108 | new_word.append(first+second) 109 | i += 2 110 | else: 111 | new_word.append(word[i]) 112 | i += 1 113 | new_word = tuple(new_word) 114 | word = new_word 115 | if len(word) == 1: 116 | break 117 | else: 118 | pairs = get_pairs(word) 119 | word = ' '.join(word) 120 | self.cache[token] = word 121 | return word 122 | 123 | def encode(self, text): 124 | bpe_tokens = [] 125 | text = whitespace_clean(basic_clean(text)).lower() 126 | for token in re.findall(self.pat, text): 127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 129 | return bpe_tokens 130 | 131 | def decode(self, tokens): 132 | text = ''.join([self.decoder[token] for token in tokens]) 133 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 134 | return text 135 | 136 | def tokenize(self, text): 137 | tokens = [] 138 | text = whitespace_clean(basic_clean(text)).lower() 139 | for token in re.findall(self.pat, text): 140 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 141 | tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 142 | return tokens 143 | 144 | def convert_tokens_to_ids(self, tokens): 145 | return [self.encoder[bpe_token] for bpe_token in tokens] -------------------------------------------------------------------------------- 
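A minimal usage sketch for the byte-level BPE tokenizers listed above (illustrative only, not part of the repository; it assumes the snippet is run from the repository root so that `model.clip` is importable and the bundled `bpe_simple_vocab_16e6.txt.gz` is found):

# Illustrative example (assumption: run from the repo root; not part of the original code).
from model.clip.simple_tokenizer import SimpleTokenizer
from model.clip import tokenize

bpe = SimpleTokenizer()                          # loads the bundled bpe_simple_vocab_16e6.txt.gz
ids = bpe.encode("a dog runs on the grass")      # list of BPE token ids
text = bpe.decode(ids)                           # approximate round-trip back to the lower-cased text

# tokenize() (defined in model/clip/clip.py above) adds the <|startoftext|>/<|endoftext|> tokens
# and pads to the fixed 77-token context length used by the CLIP text encoder.
batch = tokenize(["a dog runs on the grass", "a cat sleeps"], truncate=True)
print(batch.shape)                               # torch.Size([2, 77])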
/model/helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script defines some useful function for model building. 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from einops import rearrange, repeat 9 | import os 10 | import json 11 | from datetime import datetime 12 | 13 | def get_random_sample_indices( 14 | seq_len, sample_ratio=0.9, num_samples=100, device=torch.device("cpu")): 15 | """ 16 | Args: 17 | seq_len: int, the sampled indices will be in the range [0, seq_len-1] 18 | num_samples: sample size 19 | device: torch.device 20 | Returns: 21 | 1D torch.LongTensor consisting of sorted sample indices 22 | (sort should not affect the results as we use transformers) 23 | """ 24 | if sample_ratio: 25 | num_samples = int(seq_len*sample_ratio) 26 | 27 | if num_samples >= seq_len: 28 | # return all indices 29 | sample_indices = np.arange(seq_len) 30 | else: 31 | sample_indices = np.random.choice( 32 | seq_len, size=num_samples, replace=False) 33 | sample_indices = np.sort(sample_indices) 34 | return torch.from_numpy(sample_indices).long().to(device) 35 | 36 | 37 | class Attention(nn.Module): 38 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., with_qkv=True): 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = qk_scale or head_dim ** -0.5 43 | self.with_qkv = with_qkv 44 | if self.with_qkv: 45 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 46 | self.proj = nn.Linear(dim, dim) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | self.attn_drop = nn.Dropout(attn_drop) 49 | 50 | def forward(self, x): 51 | B, N, C = x.shape 52 | if self.with_qkv: 53 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 54 | q, k, v = qkv[0], qkv[1], qkv[2] 55 | else: 56 | qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 57 | q, k, v = qkv, qkv, qkv 58 | 59 | attn = (q @ k.transpose(-2, -1)) * self.scale 60 | attn = attn.softmax(dim=-1) 61 | attn = self.attn_drop(attn) 62 | 63 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 64 | if self.with_qkv: 65 | x = self.proj(x) 66 | x = self.proj_drop(x) 67 | return x 68 | 69 | 70 | 71 | 72 | ############ Following functions are used to visualize results from RegionLearner ##################### 73 | 74 | def file_check(file_name): 75 | path = os.path.dirname(file_name) 76 | if not os.path.exists(path): 77 | os.makedirs(path) 78 | 79 | def json_write(data, file_name): 80 | try: 81 | file_check(file_name) 82 | with open(file_name, 'w+') as outfile: 83 | json.dump(data, outfile) 84 | print('json file saved at %s'%(file_name)) 85 | except: 86 | import traceback 87 | traceback.print_exc() 88 | print('cannot write %s'%(file_name)) 89 | 90 | def save_vis_re(data, vis_re, save_pth=None, timestamp=True): 91 | # meta_arr = {'raw_captions': caption, 'paths': rel_fp, 'dataset': self.dataset_name} 92 | # data = {'video': final, 'text': caption, 'meta': meta_arr, 'frame_idxs': idxs} 93 | # print('save_vis_re:\t', data.keys()) 94 | indices, region_mask = vis_re 95 | vids = data['meta']['paths'] 96 | raw_caps= data['meta']['raw_captions'] 97 | # TODO support multiple frames 98 | frame_idxs= data['frame_idxs'][0].tolist() 99 | 100 | re = {} 101 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else '' 102 | print(len(vids), indices.size(), region_mask.size()) 103 | for i in range(len(vids)): 104 | k = 
str(vids[i]) 105 | v = indices[i].cpu().detach().tolist() 106 | v1 = region_mask[i].cpu().detach().tolist() 107 | # print(k,v) 108 | re[k] = {'cluster_id':v, 'region_mask':v1, 'raw_caption':raw_caps[i], 'frame_idxs':frame_idxs[i]} 109 | # re[k] = {'cluster_id':v, 'region_mask':v1, 'raw_caption':raw_caps[i]} 110 | 111 | # print(re) 112 | save_pth = os.path.join(save_pth, timestamp) 113 | try: 114 | json_write(re, '%s/vis.json'%(save_pth)) 115 | except: 116 | print("failed to save results!!!") 117 | import traceback 118 | traceback.print_exc() -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | class NormSoftmaxLoss(nn.Module): 7 | def __init__(self, temperature=0.05): 8 | super().__init__() 9 | 10 | self.temperature = temperature 11 | 12 | def forward(self, x): 13 | "Assumes input x is similarity matrix of N x M \in [-1, 1], computed using the cosine similarity between normalised vectors" 14 | i_logsm = F.log_softmax(x/self.temperature, dim=1) 15 | j_logsm = F.log_softmax(x.t()/self.temperature, dim=1) 16 | 17 | # sum over positives 18 | idiag = torch.diag(i_logsm) 19 | loss_i = idiag.sum() / len(idiag) 20 | 21 | jdiag = torch.diag(j_logsm) 22 | loss_j = jdiag.sum() / len(jdiag) 23 | 24 | return - loss_i - loss_j 25 | 26 | 27 | class MaxMarginRankingLoss(nn.Module): 28 | 29 | def __init__(self, margin=1, fix_norm=True): 30 | super().__init__() 31 | self.fix_norm = fix_norm 32 | self.loss = th.nn.MarginRankingLoss(margin) 33 | self.margin = margin 34 | 35 | def forward(self, x): 36 | n = x.size()[0] 37 | 38 | x1 = th.diag(x) 39 | x1 = x1.unsqueeze(1) 40 | x1 = x1.expand(n, n) 41 | x1 = x1.contiguous().view(-1, 1) 42 | x1 = th.cat((x1, x1), 0) 43 | 44 | x2 = x.view(-1, 1) 45 | x3 = x.transpose(0, 1).contiguous().view(-1, 1) 46 | 47 | x2 = th.cat((x2, x3), 0) 48 | max_margin = F.relu(self.margin - (x1 - x2)) 49 | 50 | if self.fix_norm: 51 | # remove the elements from the diagonal 52 | keep = th.ones(x.shape) - th.eye(x.shape[0]) # 128 x 128 53 | keep1 = keep.view(-1, 1) 54 | keep2 = keep.transpose(0, 1).contiguous().view(-1, 1) 55 | keep_idx = th.nonzero(th.cat((keep1, keep2), 0).flatten()).flatten() 56 | if x1.is_cuda: 57 | keep_idx = keep_idx.cuda() 58 | x1_ = th.index_select(x1, dim=0, index=keep_idx) 59 | x2_ = th.index_select(x2, dim=0, index=keep_idx) 60 | max_margin = F.relu(self.margin - (x1_ - x2_)) 61 | 62 | return max_margin.mean() 63 | 64 | 65 | class CrossEntropy(nn.Module): 66 | def __init__(self): 67 | super().__init__() 68 | self.loss = nn.CrossEntropyLoss(ignore_index=-1) 69 | 70 | def forward(self, output, target): 71 | return self.loss(output, target) 72 | 73 | 74 | def cosine_sim(im, s): 75 | """Cosine similarity between all the image and sentence pairs 76 | """ 77 | return im.mm(s.t()) 78 | 79 | 80 | def order_sim(im, s): 81 | """Order embeddings similarity measure $max(0, s-im)$ 82 | """ 83 | YmX = (s.unsqueeze(1).expand(s.size(0), im.size(0), s.size(1)) 84 | - im.unsqueeze(0).expand(s.size(0), im.size(0), s.size(1))) 85 | score = -YmX.clamp(min=0).pow(2).sum(2).sqrt().t() 86 | return score 87 | 88 | 89 | def nll_loss(output, target): 90 | return F.nll_loss(output, target) 91 | 92 | 93 | if __name__ == "__main__": 94 | import torch 95 | 96 | random_sims = (torch.rand([10, 8]) * 2) - 1 97 | loss = NormSoftmaxLoss() 98 | loss(random_sims) 99 
| -------------------------------------------------------------------------------- /model/qa_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils.weight_norm import weight_norm 3 | 4 | 5 | 6 | 7 | class FCNet(nn.Module): 8 | """ 9 | Simple class for a multi-layer non-linear fully connected network 10 | Activation function: ReLU() 11 | """ 12 | def __init__(self, dims, dropout=0.0, norm=True): 13 | super(FCNet, self).__init__() 14 | self.num_layers = len(dims) -1 15 | self.drop = dropout 16 | self.norm = norm 17 | self.main = nn.Sequential(*self._init_layers(dims)) 18 | 19 | def _init_layers(self, dims): 20 | layers = [] 21 | for i in range(self.num_layers): 22 | in_dim = dims[i] 23 | out_dim = dims[i + 1] 24 | # layers.append(nn.Dropout(self.drop)) 25 | if self.norm: 26 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) 27 | else: 28 | layers.append(nn.Linear(in_dim, out_dim)) 29 | layers.append(nn.ReLU()) 30 | return layers 31 | 32 | def forward(self, x): 33 | return self.main(x) 34 | 35 | 36 | class SimpleClassifier(nn.Module): 37 | def __init__(self, in_dim, hid_dim, out_dim, dropout=0.0): 38 | super(SimpleClassifier, self).__init__() 39 | self.q_net = FCNet([in_dim[0], hid_dim[0]], dropout) 40 | self.v_net = FCNet([in_dim[1], hid_dim[0]], dropout) 41 | self.main = nn.Sequential( 42 | nn.Linear(hid_dim[0], hid_dim[1]), 43 | nn.ReLU(), 44 | nn.Dropout(dropout, inplace=True), 45 | nn.Linear(hid_dim[1], out_dim) 46 | ) 47 | 48 | def forward(self, q_emb, v_emb): 49 | joint_repr = self.q_net(q_emb) * self.v_net(v_emb) 50 | logits = self.main(joint_repr) 51 | return logits 52 | 53 | class BUTDQAHead(nn.Module): 54 | def __init__(self, v_dim, q_dim, hid_dim, num_ans): 55 | super(BUTDQAHead, self).__init__() 56 | self.classifier = SimpleClassifier([q_dim, v_dim], [hid_dim, hid_dim*2], num_ans) 57 | 58 | def forward(self, video_embed, question_embed): 59 | logits = self.classifier(question_embed, video_embed) 60 | return logits 61 | -------------------------------------------------------------------------------- /parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce 5 | from operator import getitem 6 | from datetime import datetime 7 | from logger import setup_logging 8 | from utils import read_json, write_json 9 | import time 10 | import inspect 11 | 12 | 13 | class ConfigParser: 14 | def __init__(self, args, options='', timestamp=True, test=False): 15 | # parse default and custom cli options 16 | for opt in options: 17 | args.add_argument(*opt.flags, default=None, type=opt.type) 18 | args = args.parse_args() 19 | self.args = args 20 | if args.device: 21 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 22 | if args.resume is None: 23 | msg_no_cfg = "A configuration file needs to be specified. Add '-c config.json', for example."
24 | assert args.config is not None, msg_no_cfg 25 | self.cfg_fname = Path(args.config) 26 | config = read_json(self.cfg_fname) 27 | self.resume = None 28 | else: 29 | self.resume = Path(args.resume) 30 | resume_cfg_fname = self.resume.parent / 'config.json' 31 | config = read_json(resume_cfg_fname) 32 | if args.config is not None: 33 | config.update(read_json(Path(args.config))) 34 | 35 | 36 | 37 | # load config file and apply custom cli options 38 | self._config = _update_config(config, options, args) 39 | 40 | # set save_dir where trained model and log will be saved. 41 | save_dir = Path(self.config['trainer']['save_dir']) 42 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else '' 43 | 44 | exper_name = self.config['name'] 45 | # self._save_dir = save_dir / 'models' / exper_name / timestamp 46 | self._web_log_dir = save_dir / 'web' / exper_name / timestamp 47 | # self._log_dir = save_dir / 'log' / exper_name / timestamp 48 | 49 | # modified by Mr. Yan. Associate log with model in the same folder 50 | if args.debug: 51 | self._save_dir = save_dir / "debug" 52 | self._log_dir = self._save_dir 53 | else: 54 | self._save_dir = save_dir / exper_name # / timestamp 55 | self._log_dir = self._save_dir 56 | 57 | if args.auto_resume: 58 | resume_cfg_fname = self._save_dir / 'config.json' 59 | if os.path.exists(self._save_dir): 60 | self.resume = self._save_dir / 'model_latest.pth' 61 | if not os.path.exists(self.resume): 62 | self.resume = None 63 | 64 | 65 | # if os.path.exists(resume_cfg_fname): 66 | # print(resume_cfg_fname) 67 | # config = read_json(resume_cfg_fname) 68 | # # if args.config is not None: 69 | # # config.update(read_json(Path(args.config))) 70 | # self.resume = self._save_dir / 'model_latest.pth' 71 | # if not os.path.exists(self.resume): 72 | # self.resume = None 73 | 74 | 75 | if not test: 76 | self.save_dir.mkdir(parents=True, exist_ok=True) 77 | self.log_dir.mkdir(parents=True, exist_ok=True) 78 | 79 | 80 | 81 | # if set, remove all previous experiments with the current config 82 | if vars(args).get("purge_exp_dir", False): 83 | for dirpath in (self._save_dir, self._log_dir, self._web_log_dir): 84 | config_dir = dirpath.parent 85 | existing = list(config_dir.glob("*")) 86 | print(f"purging {len(existing)} directories from config_dir...") 87 | tic = time.time() 88 | os.system(f"rm -rf {config_dir}") 89 | print(f"Finished purge in {time.time() - tic:.3f}s") 90 | 91 | # save updated config file to the checkpoint dir 92 | if not test: 93 | write_json(self.config, self.save_dir / 'config.json') 94 | 95 | # configure logging module 96 | setup_logging(self.log_dir) 97 | self.log_levels = { 98 | 0: logging.WARNING, 99 | 1: logging.INFO, 100 | 2: logging.DEBUG 101 | } 102 | 103 | def initialize(self, name, module, *args, index=None, **kwargs): 104 | """ 105 | finds a function handle with the name given as 'type' in config, and returns the 106 | instance initialized with corresponding keyword args given as 'args'. 107 | """ 108 | if index is None: 109 | module_name = self[name]['type'] 110 | module_args = dict(self[name]['args']) 111 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 112 | module_args.update(kwargs) 113 | else: 114 | module_name = self[name][index]['type'] 115 | module_args = dict(self[name][index]['args']) 116 | 117 | # if parameter not in config subdict, then check if it's in global config. 
118 | signature = inspect.signature(getattr(module, module_name).__init__) 119 | print(module_name) 120 | for param in signature.parameters.keys(): 121 | if param not in module_args and param in self.config: 122 | module_args[param] = self[param] 123 | if module_name == 'FrozenInTime' and param == 'args': 124 | module_args[param] = self.args 125 | if module_name == 'MultiDistTextVideoDataLoader' and param == 'args': 126 | module_args[param] = self.args 127 | 128 | return getattr(module, module_name)(*args, **module_args) 129 | 130 | def __getitem__(self, name): 131 | return self.config[name] 132 | 133 | def get_logger(self, name, verbosity=2): 134 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, 135 | self.log_levels.keys()) 136 | assert verbosity in self.log_levels, msg_verbosity 137 | logger = logging.getLogger(name) 138 | logger.setLevel(self.log_levels[verbosity]) 139 | return logger 140 | 141 | # setting read-only attributes 142 | @property 143 | def config(self): 144 | return self._config 145 | 146 | @property 147 | def save_dir(self): 148 | return self._save_dir 149 | 150 | @property 151 | def log_dir(self): 152 | return self._log_dir 153 | 154 | 155 | # helper functions used to update config dict with custom cli options 156 | def _update_config(config, options, args): 157 | for opt in options: 158 | value = getattr(args, _get_opt_name(opt.flags)) 159 | if value is not None: 160 | _set_by_path(config, opt.target, value) 161 | return config 162 | 163 | 164 | def _get_opt_name(flags): 165 | for flg in flags: 166 | if flg.startswith('--'): 167 | return flg.replace('--', '') 168 | return flags[0].replace('--', '') 169 | 170 | 171 | def _set_by_path(tree, keys, value): 172 | """Set a value in a nested object in tree by sequence of keys.""" 173 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 174 | 175 | 176 | def _get_by_path(tree, keys): 177 | """Access a nested object in tree by sequence of keys.""" 178 | return reduce(getitem, keys, tree) 179 | -------------------------------------------------------------------------------- /pre-training.sh: -------------------------------------------------------------------------------- 1 | 2 | # install environments 3 | sh setup_myEnv.sh 4 | 5 | 6 | 7 | # goto workdir 8 | # echo "load path ....." 9 | cd your_path/Region_Learner 10 | 11 | 12 | 13 | 14 | 15 | nproc_per_node=1 # determined by your resources 16 | data_root="data" 17 | 18 | 19 | 20 | 21 | 22 | # NOTE: Not all videos can be downloaded. You may need to adjust the meta_data for each dataset class defined in 'data_loader'. 23 | WebVid_root=$data_root"/WebVid" 24 | CC3M_root=$data_root"/CC3M/" 25 | save_dir="./results/pt/" 26 | 27 | 28 | # new script 29 | # Pre-training on WebVid-2M and CC3M 30 | # NOTE: It will take several minutes before the first epoch starts. Training speed depends on the status of your IO system. 31 | python -m torch.distributed.launch --nproc_per_node $nproc_per_node $@ train.py \ 32 | --config configs/pt/CC3M-WebVid2M.json --launcher pytorch \ 33 | --save_dir $save_dir --data_dir_0 $CC3M_root --data_dir_1 $WebVid_root \ 34 | --epochs 50 --schedule 30 40 35 | 36 | 37 | 38 | 39 | 40 | if [ $? != 0 ]; then 41 | echo "Fail! Exit with 1" 42 | exit 1 43 | else 44 | echo "Success!
Exit with 0" 45 | exit 0 46 | fi 47 | -------------------------------------------------------------------------------- /setup_myEnv.sh: -------------------------------------------------------------------------------- 1 | # Our code runs with CUDA 11 on 16 A100 GPUs 2 | 3 | # ----------------------------------------------------------------------------- 4 | # activate conda env 5 | # ----------------------------------------------------------------------------- 6 | conda deactivate 7 | conda activate env-3.6.8 8 | 9 | # ----------------------------------------------------------------------------- 10 | # Install some important packages. (You can also package them directly into your docker image!) 11 | # ----------------------------------------------------------------------------- 12 | pip install torch==1.8.0+cu111 torchvision==0.9.0 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 13 | pip install decord dominate sacred numpy nltk gensim textblob googletrans textaugment gensim==3.4.0 14 | pip install addict future lmdb Pillow pyyaml requests scikit-image scipy tb-nightly tqdm yapf 15 | pip install av psutil msgpack humanize ipdb scipy sklearn transformers timm==0.4.5 einops 16 | pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 black==19.3b0 flake8 isort parameterized setuptools simplejson 17 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import * 2 | -------------------------------------------------------------------------------- /trainer/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/trainer/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /trainer/__pycache__/trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/trainer/__pycache__/trainer.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/html.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/utils/__pycache__/html.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/utils/__pycache__/util.cpython-37.pyc
-------------------------------------------------------------------------------- /utils/__pycache__/visualizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/Region_Learner/edd3d1c2b95d4cef5d0cef9bda27001c6282eb3b/utils/__pycache__/visualizer.cpython-37.pyc -------------------------------------------------------------------------------- /utils/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import torch 3 | from torch import Tensor 4 | from typing import List, Tuple, Any, Optional 5 | from torchvision.transforms import functional_pil as F_pil 6 | from torchvision.transforms import functional_tensor as F_t 7 | from torchvision.transforms.functional import center_crop, crop 8 | 9 | def _get_image_size(img: Tensor) -> List[int]: 10 | """Returns image size as [w, h] 11 | """ 12 | if isinstance(img, torch.Tensor): 13 | return F_t._get_image_size(img) 14 | 15 | return F_pil._get_image_size(img) 16 | 17 | def center_plus_four_crops(img: Tensor, size: List[int], 18 | margin_h: int, margin_w: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: 19 | """Crop the given image into four tiled borders and the central crop. 20 | """ 21 | 22 | if isinstance(size, numbers.Number): 23 | size = (int(size), int(size)) 24 | elif isinstance(size, (tuple, list)) and len(size) == 1: 25 | size = (size[0], size[0]) 26 | 27 | if len(size) != 2: 28 | raise ValueError("Please provide only two dimensions (h, w) for size.") 29 | 30 | image_width, image_height = _get_image_size(img) 31 | 32 | crop_height, crop_width = size 33 | 34 | if crop_width > image_width or crop_height > image_height: 35 | msg = "Requested crop size {} is bigger than input size {}" 36 | raise ValueError(msg.format(size, (image_height, image_width))) 37 | 38 | if crop_width + margin_w > image_width: 39 | msg = "Requested margin size {} + input {} is bigger than input size {}" 40 | raise ValueError(msg.format((margin_h, margin_w), size, (image_height, image_width))) 41 | 42 | #vertical_border_height = image_height - crop_height 43 | #horizontal_border_height = image_width - crop_width 44 | 45 | #x1 = horizontal_border_height // 2 46 | x11 = (image_width - crop_width - 2 * margin_w) // 2 47 | x12 = x11 + margin_w 48 | x21 = x12 + crop_width 49 | x22 = x21 + margin_w 50 | 51 | y11 = (image_height - crop_height - 2 * margin_h) // 2 52 | y12 = y11 + margin_h 53 | y21 = y12 + crop_height 54 | y22 = y21 + margin_h 55 | 56 | tl = crop(img, y11, x11, margin_h, margin_w + crop_width) 57 | tr = crop(img, y11, x21, margin_h + crop_height, margin_w) 58 | bl = crop(img, y12, x11, margin_h + crop_height, margin_w) 59 | br = crop(img, y21, x12, margin_h, margin_w + crop_width) 60 | center = center_crop(img, [crop_height, crop_width]) 61 | 62 | return tl, tr, bl, br, center 63 | 64 | 65 | 66 | def center_plus_twohori_crops(img: Tensor, size: List[int], 67 | margin_w: int) -> Tuple[Tensor, Tensor, Tensor]: 68 | """Crop the given image into four tiled borders and the central crop. 
69 | """ 70 | 71 | if isinstance(size, numbers.Number): 72 | size = (int(size), int(size)) 73 | elif isinstance(size, (tuple, list)) and len(size) == 1: 74 | size = (size[0], size[0]) 75 | 76 | if len(size) != 2: 77 | raise ValueError("Please provide only two dimensions (h, w) for size.") 78 | 79 | image_width, image_height = _get_image_size(img) 80 | 81 | crop_height, crop_width = size 82 | 83 | if crop_width > image_width or crop_height > image_height: 84 | msg = "Requested crop size {} is bigger than input size {}" 85 | raise ValueError(msg.format(size, (image_height, image_width))) 86 | 87 | if crop_width + margin_w > image_width : 88 | msg = "Requested margin size {} + input {} is bigger than input size {}" 89 | raise ValueError(msg.format((0, margin_w), size, (image_height, image_width))) 90 | 91 | # vertical_border_height = image_height - crop_height 92 | # horizontal_border_height = image_width - crop_width 93 | 94 | # x1 = horizontal_border_height // 2 95 | x11 = (image_width - crop_width - 2 * margin_w) // 2 96 | x12 = x11 + margin_w 97 | x21 = x12 + crop_width 98 | 99 | y11 = (image_height - crop_height) // 2 100 | 101 | left = crop(img, y11, x11, crop_height, margin_w) 102 | right = crop(img, y11, x21, crop_height, margin_w) 103 | center = center_crop(img, [crop_height, crop_width]) 104 | 105 | return left, right, center 106 | 107 | from torch import nn 108 | class TwoHoriCrop(nn.Module): 109 | def __init__(self, size, margin_w): 110 | super().__init__() 111 | self.size = size 112 | self.margin_w = margin_w 113 | 114 | def forward(self, x): 115 | return center_plus_twohori_crops(x, self.size, self.margin_w) 116 | 117 | if __name__ == "__main__": 118 | from PIL import Image 119 | 120 | img = Image.open('visualisations/guitar.png') 121 | crops = center_plus_four_crops(img, [336, 336], 112, 112) 122 | order = ['tl', 'tr', 'bl', 'br', 'center'] 123 | 124 | for idx, subimg in zip(order, crops): 125 | subimg.save(f'visualisations/guitar_{idx}.png') 126 | 127 | crops = center_plus_twohori_crops(img, [448, 448], 112) 128 | order = ['left', 'right', 'center2'] 129 | 130 | for idx, subimg in zip(order, crops): 131 | subimg.save(f'visualisations/guitar_{idx}.png') 132 | -------------------------------------------------------------------------------- /utils/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br, video, source, attr 3 | from dominate.tags import span 4 | import os 5 | 6 | 7 | class HTML: 8 | """This HTML class allows us to save images and write texts into a single HTML file. 9 | 10 | It consists of functions such as (add a text header to the HTML file), 11 | (add a row of images to the HTML file), and (save the HTML to the disk). 12 | It is based on Python library 'dominate', a Python library for creating and 13 | manipulating HTML documents using a DOM API. 14 | """ 15 | 16 | def __init__(self, web_dir, title, refresh=0): 17 | """Initialize the HTML classes 18 | 19 | Parameters: 20 | web_dir (str) -- a directory that stores the webpage. 
HTML file will be 21 | created at /index.html; images will be saved at 0: 35 | with self.doc.head: 36 | meta(http_equiv="refresh", content=str(refresh)) 37 | 38 | def get_image_dir(self): 39 | """Return the directory that stores images""" 40 | return self.img_dir 41 | 42 | def add_header(self, text): 43 | """Insert a header to the HTML file 44 | 45 | Parameters: 46 | text (str) -- the header text 47 | """ 48 | with self.doc: 49 | h3(text) 50 | 51 | def add_videos(self, vids, txts, links, width=400, hidden_tag="hidden"): 52 | """add images to the HTML file 53 | 54 | Parameters: 55 | vids (str list) -- a list of image paths 56 | txts (str list) -- a list of image names shown on the website 57 | links (str list) -- a list of hyperref links; when you click an image, 58 | it will redirect you to a new page 59 | """ 60 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 61 | self.doc.add(self.t) 62 | colors = ["red", "blue", "gold", "salman"] 63 | with self.t: 64 | with tr(): 65 | for vid, txt, link in zip(vids, txts, links): 66 | td_style = "word-wrap: break-word; width:{}px".format(width) 67 | with td(style=td_style, halign="center", valign="top"): 68 | with p(): 69 | vid_path = str(vid) 70 | if vid_path == hidden_tag: 71 | p_style = "font-weight: bold; width:{}px;" 72 | p_style = p_style.format(width * 3) 73 | p("hidden video", style=p_style) 74 | else: 75 | with a(href=str(link)): 76 | with video(): 77 | attr(controls="controls") 78 | source(src=vid_path, type="video/mp4") 79 | br() 80 | rows = txt.split("
") 81 | for idx, row in enumerate(rows): 82 | color = colors[idx % len(colors)] 83 | bold_tag = "" 84 | if not row.startswith(bold_tag): 85 | s_style = "color:{};".format(color) 86 | else: 87 | s_style = "color:black; font-weight: bold;" 88 | row = row[len(bold_tag):] 89 | span(row, style=s_style) 90 | 91 | def add_images(self, ims, txts, links, width=400): 92 | """add images to the HTML file 93 | 94 | Parameters: 95 | ims (str list) -- a list of image paths 96 | txts (str list) -- a list of image names shown on the website 97 | links (str list) -- a list of hyperref links; when you click an image, 98 | it will redirect you to a new page 99 | """ 100 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 101 | self.doc.add(self.t) 102 | colors = ["red", "blue", "gold", "salman"] 103 | with self.t: 104 | with tr(): 105 | for im, txt, link in zip(ims, txts, links): 106 | td_style = "word-wrap: break-word;" 107 | with td(style=td_style, halign="center", valign="top"): 108 | with p(): 109 | with a(href=link): 110 | img( 111 | style="width:%dpx" % width, 112 | src=im, 113 | ) 114 | br() 115 | rows = txt.split("
") 116 | for idx, row in enumerate(rows): 117 | color = colors[idx % len(colors)] 118 | bold_tag = "" 119 | if not row.startswith(bold_tag): 120 | s_style = "color:{};".format(color) 121 | else: 122 | s_style = "color:black; font-weight: bold;" 123 | row = row[len(bold_tag):] 124 | span(row, style=s_style) 125 | 126 | def save(self): 127 | """save the current content to the HMTL file""" 128 | html_file = "%s/index.html" % self.web_dir 129 | f = open(html_file, "wt") 130 | f.write(self.doc.render()) 131 | f.close() 132 | 133 | 134 | if __name__ == "__main__": # we show an example usage here. 135 | html = HTML("web/", "test_html") 136 | html.add_header("hello world") 137 | 138 | ims, txts, links = [], [], [] 139 | for n in range(4): 140 | ims.append("image_%d.png" % n) 141 | txts.append("text_%d" % n) 142 | links.append("image_%d.png" % n) 143 | html.add_images(ims, txts, links) 144 | html.save() -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from datetime import datetime 4 | from itertools import repeat 5 | from collections import OrderedDict 6 | import functools 7 | import time 8 | import socket 9 | import numpy as np 10 | import psutil 11 | import msgpack 12 | import humanize 13 | import os 14 | 15 | 16 | def load_json(filename): 17 | with open(filename, 'r') as f: 18 | return json.load(f) 19 | 20 | def load_jsonl(filename): 21 | with open(filename, 'r') as f: 22 | return [json.loads(l.strip("\n")) for l in f.readlines()] 23 | 24 | def list_flatten(t): 25 | return [item for sublist in t for item in sublist] 26 | 27 | def replace_nested_dict_item(obj, key, replace_value): 28 | for k, v in obj.items(): 29 | if isinstance(v, dict): 30 | obj[k] = replace_nested_dict_item(v, key, replace_value) 31 | if key in obj: 32 | obj[key] = replace_value 33 | return obj 34 | 35 | 36 | def state_dict_data_parallel_fix(load_state_dict, curr_state_dict): 37 | load_keys = list(load_state_dict.keys()) 38 | curr_keys = list(curr_state_dict.keys()) 39 | 40 | redo_dp = False 41 | undo_dp = False 42 | if not curr_keys[0].startswith('module.') and load_keys[0].startswith('module.'): 43 | undo_dp = True 44 | elif curr_keys[0].startswith('module.') and not load_keys[0].startswith('module.'): 45 | redo_dp = True 46 | 47 | if undo_dp: 48 | from collections import OrderedDict 49 | new_state_dict = OrderedDict() 50 | for k, v in load_state_dict.items(): 51 | name = k[7:] # remove `module.` 52 | new_state_dict[name] = v 53 | # load params 54 | elif redo_dp: 55 | from collections import OrderedDict 56 | new_state_dict = OrderedDict() 57 | for k, v in load_state_dict.items(): 58 | name = 'module.' 
+ k # add `module.` prefix 59 | new_state_dict[name] = v 60 | else: 61 | new_state_dict = load_state_dict 62 | return new_state_dict 63 | 64 | def print_numpy(x, val=True, shp=False): 65 | """Print the mean, min, max, median, std, and size of a numpy array 66 | Parameters: 67 | val (bool) -- if print the values of the numpy array 68 | shp (bool) -- if print the shape of the numpy array 69 | """ 70 | x = x.astype(np.float64) 71 | if shp: 72 | print('shape,', x.shape) 73 | if val: 74 | x = x.flatten() 75 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 76 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 77 | 78 | 79 | def mkdirs(paths): 80 | """create empty directories if they don't exist 81 | Parameters: 82 | paths (str list) -- a list of directory paths 83 | """ 84 | if isinstance(paths, list) and not isinstance(paths, str): 85 | for path in paths: 86 | mkdir(path) 87 | else: 88 | mkdir(paths) 89 | 90 | 91 | def mkdir(path): 92 | """create a single empty directory if it didn't exist 93 | Parameters: 94 | path (str) -- a single directory path 95 | """ 96 | if not os.path.exists(path): 97 | os.makedirs(path) 98 | 99 | def read_json(fname): 100 | with fname.open('rt') as handle: 101 | return json.load(handle, object_hook=OrderedDict) 102 | 103 | def write_json(content, fname): 104 | with fname.open('wt') as handle: 105 | json.dump(content, handle, indent=4, sort_keys=False) 106 | 107 | def inf_loop(data_loader): 108 | ''' wrapper function for endless data loader. ''' 109 | for loader in repeat(data_loader): 110 | yield from loader 111 | 112 | def memory_summary(): 113 | vmem = psutil.virtual_memory() 114 | msg = ( 115 | f">>> Currently using {vmem.percent}% of system memory " 116 | f"{humanize.naturalsize(vmem.used)}/{humanize.naturalsize(vmem.available)}" 117 | ) 118 | print(msg) 119 | 120 | @functools.lru_cache(maxsize=64, typed=False) 121 | def memcache(path): 122 | suffix = Path(path).suffix 123 | print(f"loading features >>>", end=" ") 124 | tic = time.time() 125 | if suffix == ".npy": 126 | res = np_loader(path) 127 | else: 128 | raise ValueError(f"unknown suffix: {suffix} for path {path}") 129 | print(f"[Total: {time.time() - tic:.1f}s] ({socket.gethostname() + ':' + str(path)})") 130 | return res 131 | 132 | def np_loader(np_path, l2norm=False): 133 | with open(np_path, "rb") as f: 134 | data = np.load(f, encoding="latin1", allow_pickle=True) 135 | if isinstance(data, np.ndarray) and data.size == 1: 136 | data = data[()] # handle numpy dict storage convention 137 | if l2norm: 138 | print("L2 normalizing features") 139 | if isinstance(data, dict): 140 | for key in data: 141 | feats_ = data[key] 142 | feats_ = feats_ / max(np.linalg.norm(feats_), 1E-6) 143 | data[key] = feats_ 144 | elif data.ndim == 2: 145 | data_norm = np.linalg.norm(data, axis=1) 146 | data = data / np.maximum(data_norm.reshape(-1, 1), 1E-6) 147 | else: 148 | raise ValueError("unexpected data format {}".format(type(data))) 149 | return data 150 | 151 | 152 | class Timer: 153 | def __init__(self): 154 | self.cache = datetime.now() 155 | 156 | def check(self): 157 | now = datetime.now() 158 | duration = now - self.cache 159 | self.cache = now 160 | return duration.total_seconds() 161 | 162 | def reset(self): 163 | self.cache = datetime.now() 164 | -------------------------------------------------------------------------------- /utils/video.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy 
as np 4 | import PIL 5 | import collections 6 | import random 7 | import cv2 8 | import os 9 | import numpy as np 10 | 11 | def load_frames_from_video_path(path, num_frames, sample='rand'): 12 | cap = cv2.VideoCapture(path) 13 | assert (cap.isOpened()) 14 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 15 | acc_samples = min(num_frames, vlen) 16 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) 17 | ranges = [] 18 | for idx, interv in enumerate(intervals[:-1]): 19 | ranges.append((interv, intervals[idx + 1] - 1)) 20 | if sample == 'rand': 21 | frame_idxs = [random.choice(range(x[0], x[1])) for x in ranges] 22 | elif sample == 'uniform': 23 | frame_idxs = [(x[0] + x[1]) // 2 for x in ranges] 24 | else: 25 | raise NotImplementedError 26 | 27 | frames = [] 28 | for index in frame_idxs: 29 | cap.set(cv2.CAP_PROP_POS_FRAMES, index) 30 | ret, frame = cap.read() 31 | if ret: 32 | cv2.imwrite(f'images/{index}.jpg', frame) 33 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 34 | frame = torch.from_numpy(frame) 35 | # (H x W x C) to (C x H x W) 36 | frame = frame.permute(2, 0, 1) 37 | frames.append(frame) 38 | else: 39 | raise ValueError 40 | 41 | frames = torch.stack(frames).float() / 255 42 | cap.release() 43 | return frames, frame_idxs -------------------------------------------------------------------------------- /utils/visualisation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib 4 | 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | 8 | import ipdb 9 | 10 | def visualise_path(pred, target, window): 11 | """ 12 | :param pred: (P, 2) Tensor where P is the number of predictions, and 2 is the (i,j) coordinate 13 | :param target: (T, 2) Tensor where T is the number of targets, and 2 is the (i,j) coordinate 14 | :param dims: (H, W) tup/list the desired height and width of matrix (should be >= to max(i), max(j)) 15 | :param assignment_method: Method of assignment (dtw, minimum etc.) 16 | :return: image, visualisation of path prediction and target. 
17 | """ 18 | tp = torch.Tensor((64, 191, 64)) 19 | fp = torch.Tensor((191, 64, 64)) 20 | gt = torch.Tensor((102, 153, 255)) 21 | 22 | grid = torch.ones_like(window).unsqueeze(0).repeat(3, 1, 1) * 255 23 | inf = 130 * torch.ones_like(grid) 24 | grid = torch.where(torch.isnan(window), inf, grid) 25 | 26 | clip_idxs = [t[0] for t in target] 27 | local_idxs = np.unique(np.array(clip_idxs)).tolist() 28 | 29 | for t in target: 30 | local_idx = local_idxs.index(t[0]) 31 | grid[:, local_idx,t[1]] = gt 32 | 33 | for p in pred: 34 | local_idx = local_idxs.index(p[0]) 35 | if (grid[:, local_idx,p[1]] == gt).all(): 36 | grid[:, local_idx, p[1]] = tp 37 | else: 38 | grid[:, local_idx, p[1]] = fp 39 | 40 | return grid / 255 41 | 42 | 43 | def batch_path_vis(pred_dict, target, window): 44 | 45 | grids = [] 46 | 47 | window = window.cpu() 48 | for key, pred in pred_dict.items(): 49 | tmp_window = window 50 | if key == 'min_dist': 51 | tmp_window = torch.zeros_like(window) 52 | grids.append(visualise_path(pred, target, tmp_window)) 53 | 54 | return torch.stack(grids) 55 | 56 | 57 | 58 | if __name__ == "__main__": 59 | pred = [[1,1], [2,4]] 60 | gt = [[1,1], [3,4]] 61 | window = torch.zeros((5,6)) 62 | visualise_path(pred, gt, window) 63 | -------------------------------------------------------------------------------- /utils/visualizer.py: -------------------------------------------------------------------------------- 1 | """A simple HTML visualizer. 2 | 3 | It is based on the Cycle-GAN codebase: 4 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 5 | """ 6 | import os 7 | import numpy as np 8 | from pathlib import Path 9 | from . import util, html 10 | import pdb 11 | 12 | class RetrievalVis: 13 | """This class includes several functions that can display/save images. 14 | 15 | It uses a Python library 'visdom' for display, and a Python library 'dominate' 16 | (wrapped in 'HTML') for creating HTML files with images. 
17 | """ 18 | 19 | def __init__(self, exp_name, web_dir, src_video_dir, vis_vid_freq, num_samples=50): 20 | """Initialize the Visualizer class 21 | Create an HTML object for saveing HTML filters 22 | """ 23 | self.name = exp_name 24 | self.web_dir = web_dir 25 | self.vis_vid_freq = vis_vid_freq 26 | self.img_dir = os.path.join(self.web_dir, "images") 27 | self.num_samples = num_samples 28 | 29 | self.data_type = 'images' # 'images' or 'videos' 30 | assert self.data_type in ('images', 'videos') 31 | 32 | print(f"create web directory {self.web_dir}...") 33 | mkdirs([self.web_dir, self.img_dir]) 34 | 35 | # cluster specific 36 | if "$TMPDIR" in src_video_dir: 37 | src_video_dir = src_video_dir.replace("$TMPDIR", os.environ['TMPDIR']) 38 | 39 | src_dir = Path(src_video_dir).absolute() 40 | print(f"symlinking videos from {src_dir}...") 41 | sym_dir = (Path(self.web_dir) / "videos").absolute() 42 | if sym_dir.is_symlink(): 43 | os.remove(sym_dir) 44 | sym_dir.symlink_to(src_dir) 45 | 46 | def visualize_ranking(self, sims, epoch, meta, nested_metrics): 47 | if not (self.vis_vid_freq and epoch % self.vis_vid_freq == 0): 48 | return 49 | 50 | dists = -sims 51 | np.random.seed(0) 52 | sorted_ranks = np.argsort(dists, axis=1) 53 | gt_dists = np.diag(dists) 54 | rankings = [] 55 | vis_top_k = 5 56 | hide_gt = False 57 | # num_indep_samples = 1 58 | # random_seeds = np.arange(num_indep_samples) 59 | sample = np.random.choice(np.arange(dists.shape[0]), size=self.num_samples, 60 | replace=False) 61 | for ii in sample: 62 | ranked_idx = sorted_ranks[ii][:vis_top_k] 63 | gt_captions = meta["raw_captions"][ii] 64 | # if args.sample_single_gt_caption: 65 | # gt_captions = np.random.choice(gt_captions, 1).tolist() 66 | datum = { 67 | "gt-sim": -gt_dists[ii], 68 | "gt-captions": gt_captions, 69 | "gt-rank": np.where(sorted_ranks[ii] == ii)[0][0], 70 | "gt-path": meta["paths"][ii], 71 | "top-k-sims": -dists[ii][ranked_idx], 72 | "top-k-paths": np.array(meta["paths"])[ranked_idx], 73 | "hide-gt": hide_gt, 74 | } 75 | rankings.append(datum) 76 | self.display_current_results( 77 | rankings, 78 | epoch=epoch, 79 | metrics=nested_metrics["t2v_metrics"], 80 | ) 81 | 82 | def display_current_results(self, rankings, epoch, metrics): 83 | """Display current results on visdom; save current results to an HTML file. 84 | 85 | Parameters: 86 | visuals (OrderedDict) - - dictionary of images to display or save 87 | epoch (int) - - the current epoch 88 | save_result (bool) - - if save the current results to an HTML file 89 | """ 90 | if not Path(self.web_dir).exists(): 91 | Path(self.web_dir).mkdir(exist_ok=True, parents=True) 92 | print(f"updating webpage at {self.web_dir}") 93 | title = f"Experiment name = {self.name}" 94 | refresh = True 95 | if not refresh: 96 | print("DISABLING WEB PAGE REFRESH") 97 | webpage = html.HTML(web_dir=self.web_dir, title=title, refresh=refresh) 98 | 99 | msg = f"epoch [{epoch}] - {self.name}" 100 | webpage.add_header(msg) 101 | msg = (f"R1: {metrics['R1']:.1f}, " 102 | f"R5: {metrics['R5']:.1f}, " 103 | f"R10: {metrics['R10']:.1f}, " 104 | f"MedR: {metrics['MedR']}") 105 | webpage.add_header(msg) 106 | print(f"Top {len(rankings[0])} retreived videos at epoch: {epoch}") 107 | 108 | for ranking in rankings: 109 | vids, txts, links = [], [], [] 110 | gt_vid_path = os.path.join('videos', ranking["gt-path"]) 111 | #gt_captions = [" ".join(x) for x in ranking["gt-captions"]] 112 | gt_captions = ranking['gt-captions'] 113 | gt_captions = "
" + (gt_captions) + "
" 114 | if ranking["hide-gt"]: 115 | txts.append(gt_captions) 116 | links.append("hidden") 117 | vids.append("hidden") 118 | else: 119 | txt = (f"{gt_captions}
Rank: {ranking['gt-rank']}, " 120 | f"Sim: {ranking['gt-sim']:.3f} [{Path(ranking['gt-path']).stem}]") 121 | txts.append(txt) 122 | links.append(gt_vid_path) 123 | vids.append(gt_vid_path) 124 | 125 | for idx, (vid_path, sim) in enumerate(zip(ranking["top-k-paths"], 126 | ranking["top-k-sims"])): 127 | vid_path = Path(os.path.join('videos', vid_path)) 128 | if ranking["hide-gt"]: 129 | txt = f"choice: {idx}" 130 | else: 131 | txt = f"Rank: {idx}, Sim: {sim:.3f}, [{Path(vid_path).stem}]" 132 | txts.append(txt) 133 | vids.append(vid_path) 134 | links.append(vid_path) 135 | if self.data_type == 'videos': 136 | webpage.add_videos(vids, txts, links, width=200) 137 | elif self.data_type == 'images': 138 | webpage.add_images(vids, txts, links, width=200) 139 | print(f"added {len(vids)} videos") 140 | webpage.save() 141 | 142 | def mkdirs(paths): 143 | """create empty directories if they don't exist 144 | 145 | Parameters: 146 | paths (str list) -- a list of directory paths 147 | """ 148 | if isinstance(paths, list) and not isinstance(paths, str): 149 | for path in paths: 150 | mkdir(path) 151 | else: 152 | mkdir(paths) 153 | 154 | 155 | def mkdir(path): 156 | """create a single empty directory if it didn't exist 157 | 158 | Parameters: 159 | path (str) -- a single directory path 160 | """ 161 | if not os.path.exists(path): 162 | os.makedirs(path) 163 | --------------------------------------------------------------------------------