├── .DS_Store
├── .idea
│   ├── .gitignore
│   ├── alex_frozen_dist.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── OATrans
│   ├── .DS_Store
│   ├── __init__.py
│   ├── args.py
│   ├── base
│   │   ├── __init__.py
│   │   ├── base_augmentation.py
│   │   ├── base_data_loader.py
│   │   ├── base_dataset.py
│   │   ├── base_dataset_global_local.py
│   │   ├── base_dataset_region_mem.py
│   │   ├── base_model.py
│   │   └── base_trainer.py
│   ├── configs
│   │   ├── ft
│   │   │   └── msrvtt
│   │   │       ├── fine_tune
│   │   │       │   └── normal_1_cl.json
│   │   │       └── zsl
│   │   │           └── normal.json
│   │   └── pt
│   │       └── cc3m_webvid
│   │           ├── local-region-loss.json
│   │           └── norm.json
│   ├── data_loader
│   │   ├── ConceptualCaptions_dataset.py
│   │   ├── DiDeMo_dataset.py
│   │   ├── LSMDC_choice_dataset.py
│   │   ├── LSMDC_dataset.py
│   │   ├── MSRVTT_dataset.py
│   │   ├── MSVD_dataset.py
│   │   ├── WebVid_dataset.py
│   │   ├── __init__.py
│   │   ├── data_loader.py
│   │   ├── data_loader_v2.py
│   │   └── transforms.py
│   ├── logger
│   │   ├── __init__.py
│   │   ├── logger.py
│   │   ├── logger_config.json
│   │   └── visualization.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   ├── metric.py
│   │   ├── model.py
│   │   ├── model_dist.py
│   │   ├── oa_loss.py
│   │   ├── oa_model.py
│   │   ├── oa_model_global_local.py
│   │   ├── oa_model_region_mem.py
│   │   ├── oa_video_transformer_global_local.py
│   │   ├── oa_video_transformer_region.py
│   │   ├── prompt_learner.py
│   │   └── video_transformer.py
│   ├── options.py
│   ├── parse_config.py
│   ├── parse_config_dist_multi.py
│   ├── test.py
│   ├── test_region_mem.py
│   ├── train.py
│   ├── train_dist_multi.py
│   ├── train_dist_multi_global_local.py
│   ├── train_dist_region_mem.py
│   ├── trainer
│   │   ├── __init__.py
│   │   ├── trainer.py
│   │   ├── trainer_dist.py
│   │   ├── trainer_global_local.py
│   │   └── trainer_region_mem.py
│   └── utils
│       ├── .DS_Store
│       ├── __init__.py
│       ├── binary_classification_accuracy.py
│       ├── custom_transforms.py
│       ├── html.py
│       ├── objects_vocab.txt
│       ├── objects_vocab_fine_grained.txt
│       ├── objects_vocab_token_len
│       ├── objects_vocab_token_len.txt
│       ├── param_forzen.py
│       ├── unit_test
│       │   ├── __init__.py
│       │   ├── distill_bert.py
│       │   ├── load_msvd_video.py
│       │   └── region_roi_example.py
│       ├── util.py
│       ├── video.py
│       ├── visualization
│       │   ├── .DS_Store
│       │   ├── 3f_vto_visualize.py
│       │   ├── __init__.py
│       │   ├── learned_embedding_visualization.py
│       │   ├── msrvtt_3f_vto_visualize.py
│       │   ├── msrvtt_vto_visualization.py
│       │   ├── predict_visualization
│       │   │   ├── 0_predict.png
│       │   │   ├── 10_predict.png
│       │   │   ├── 11_predict.png
│       │   │   ├── 12_predict.png
│       │   │   ├── 13_predict.png
│       │   │   ├── 14_predict.png
│       │   │   ├── 1_predict.png
│       │   │   ├── 2_predict.png
│       │   │   ├── 3_predict.png
│       │   │   ├── 4_predict.png
│       │   │   ├── 5_predict.png
│       │   │   ├── 6_predict.png
│       │   │   ├── 7_predict.png
│       │   │   ├── 8_predict.png
│       │   │   └── 9_predict.png
│       │   ├── print_tags.py
│       │   ├── transfer_predict_visualization
│       │   │   ├── 0_predict.png
│       │   │   ├── 10_predict.png
│       │   │   ├── 11_predict.png
│       │   │   ├── 12_predict.png
│       │   │   ├── 13_predict.png
│       │   │   ├── 14_predict.png
│       │   │   ├── 15_predict.png
│       │   │   ├── 1_predict.png
│       │   │   ├── 2_predict.png
│       │   │   ├── 3_predict.png
│       │   │   ├── 4_predict.png
│       │   │   ├── 5_predict.png
│       │   │   ├── 6_predict.png
│       │   │   ├── 7_predict.png
│       │   │   ├── 8_predict.png
│       │   │   └── 9_predict.png
│       │   └── webvid_vto_visualization.py
│       └── visualizer.py
├── ObjectExtractor
│   ├── multiprocess_full_cc3m_complementary_modify_tsv_gen_from_video.py
│   └── multiprocess_full_webvid_multiframe_complementary_modify_tsv_gen_from_video.py
├── README.md
├── Visualization
│   ├── .DS_Store
│   └── Cross_Modality_Transformer_Visualization
│       ├── .DS_Store
│       ├── data
│       │   └── webvid_validation_success_full.tsv
│       ├── data_preprocess.py
│       ├── main_img.py
│       ├── main_video.py
│       ├── main_video_patches_visualization.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── text_model.py
│       │   ├── text_models
│       │   │   └── distill_bert.py
│       │   ├── vision_model.py
│       │   └── vision_models
│       │       ├── clip
│       │       │   ├── __init__.py
│       │       │   ├── bpe_simple_vocab_16e6.txt.gz
│       │       │   ├── clip.py
│       │       │   ├── model.py
│       │       │   └── simple_tokenizer.py
│       │       └── frozen.py
│       ├── parse_config.py
│       ├── patch_mask.py
│       ├── utils
│       │   ├── nltk_test.py
│       │   └── read_bboxs.py
│       └── visualize.py
├── environment.yml
├── figures
│   ├── oa_main_ppl.jpg
│   ├── oa_visualize_1.jpg
│   ├── oa_visualize_2.jpg
│   ├── objects.jpg
│   └── objects_2.png
├── object_extraction.md
├── train.md
└── visualization.md

/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
--------------------------------------------------------------------------------
/OATrans/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/__init__.py
--------------------------------------------------------------------------------
/OATrans/args.py:
--------------------------------------------------------------------------------
import argparse


def get_args(description='MILNCE'):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--train_csv', type=str, default='csv/hmdb51.csv', help='train csv')
    parser.add_argument('--video_path', type=str, default='', help='video path')
    parser.add_argument('--caption_root', type=str, default='', help='caption root')
    parser.add_argument('--checkpoint_root', type=str, default='checkpoint', help='checkpoint dir root')
    parser.add_argument('--log_root', type=str, default='log', help='log dir root')
    parser.add_argument('--eval_video_root', type=str, default='', help='root folder of the videos used for evaluation')
    parser.add_argument('--checkpoint_dir', type=str, default='', help='checkpoint model folder')
    parser.add_argument('--optimizer', type=str, default='adam', help='optimization algorithm')
    parser.add_argument('--weight_init', type=str, default='uniform', help='CNN weight initialization')
    parser.add_argument('--num_thread_reader', type=int, default=20, help='number of data-loading threads')
    parser.add_argument('--num_class', type=int, default=512, help='number of classes (output embedding dimension)')
    parser.add_argument('--num_candidates', type=int, default=1, help='num candidates for MILNCE loss')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--num_windows_test', type=int, default=4, help='number of testing windows')
    parser.add_argument('--batch_size_val', type=int, default=32, help='batch size for evaluation')
    parser.add_argument('--momemtum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument('--n_display', type=int, default=10, help='information display frequency')
    parser.add_argument('--num_frames', type=int, default=16, help='number of frames per clip')
    parser.add_argument('--video_size', type=int, default=224, help='input frame size')
    parser.add_argument('--crop_only', type=int, default=1, help='whether to only crop (no resizing)')
    parser.add_argument('--centercrop', type=int, default=0, help='use center crop instead of random crop')
    parser.add_argument('--random_flip', type=int, default=1, help='apply random horizontal flip')
    parser.add_argument('--verbose', type=int, default=1, help='verbosity level')
    parser.add_argument('--warmup_steps', type=int, default=5000, help='number of warmup steps')
    parser.add_argument('--min_time', type=float, default=5.0, help='minimum clip duration (seconds)')
    parser.add_argument('--pretrain_cnn_path', type=str, default='', help='path to pretrained CNN weights')
    parser.add_argument('--word2vec_path', type=str, default='data/word2vec.pth', help='path to word2vec weights')
    parser.add_argument('--fps', type=int, default=5, help='decoding frame rate')
    parser.add_argument('--cudnn_benchmark', type=int, default=0, help='enable cudnn benchmark')
    parser.add_argument('--epochs', default=150, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--resume', dest='resume', action='store_true', help='resume training from last checkpoint')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
    parser.add_argument('--pin_memory', dest='pin_memory', action='store_true', help='use pin_memory')
    parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
    parser.add_argument('--dist-file', default='dist-file', type=str, help='file used to set up distributed training')
    parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--seed', default=1, type=int, help='seed for initializing training')
    parser.add_argument('--gpu', default=None, type=int, help='GPU id to use')
    parser.add_argument('--multiprocessing-distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    args = parser.parse_args()
    return args
--------------------------------------------------------------------------------
/OATrans/base/__init__.py:
--------------------------------------------------------------------------------
from .base_data_loader import *
# from .base_dataset_v3 import *
from .base_dataset import *
from .base_model import *
from .base_trainer import *
--------------------------------------------------------------------------------
/OATrans/base/base_augmentation.py:
--------------------------------------------------------------------------------
import random
from PIL import ImageFilter
# import nltk
# nltk.data.path.append("pretrained/nltk_data")
# from textaugment import EDA


def textaug_eda(caption):
    """EDA-based text augmentation; requires the (commented-out) textaugment import above."""
    aug_caption = caption
    t = EDA()
    if random.random() < 0.5:
        if random.random() < 0.3:
            aug_caption = t.synonym_replacement(aug_caption)
        aug_caption = t.random_deletion(aug_caption, p=random.random() * 0.3)
        if random.random() < 0.3:
            aug_caption = t.random_swap(aug_caption)
        if random.random() < 0.3:
            aug_caption = t.random_insertion(aug_caption)
    return aug_caption


def textaug_advanced(caption, aug_model):
    return aug_model.augment(caption)


def mask_aug(sentence):
    """Replace one randomly chosen word with the [MASK] token."""
    words = sentence.split(' ')
    word_index = random.randint(0, len(words) - 1)
    words[word_index] = "[MASK]"
    new_caption = ' '.join(words)
    # TODO: shuffle object locations
    # TODO: randomly drop some objects
    return new_caption


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x
--------------------------------------------------------------------------------
/OATrans/base/base_data_loader.py:
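# A minimal usage sketch for BaseDataLoader defined below: it carves a
# validation subset out of a single dataset and exposes it via
# split_validation(). The toy TensorDataset, batch size and split ratio are
# illustrative, not taken from the repository; the snippet assumes
# BaseDataLoader has been imported from this module.
import torch
from torch.utils.data import TensorDataset

toy_dataset = TensorDataset(torch.randn(100, 8), torch.zeros(100))  # 100 dummy samples
loader = BaseDataLoader(toy_dataset, batch_size=16, shuffle=True,
                        validation_split=0.1, num_workers=0)        # hold out 10% for validation
val_loader = loader.split_validation()                              # DataLoader over the held-out indices
print(loader.num_samples(), len(val_loader.sampler))                # -> 90 10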
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | from torch.utils.data.distributed import DistributedSampler 6 | 7 | class BaseDataLoader(DataLoader): 8 | """ 9 | Base class for all data loaders 10 | """ 11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 12 | self.validation_split = validation_split 13 | self.shuffle = shuffle 14 | 15 | self.batch_idx = 0 16 | self.n_samples = len(dataset) 17 | 18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 19 | 20 | self.init_kwargs = { 21 | 'dataset': dataset, 22 | 'batch_size': batch_size, 23 | 'shuffle': self.shuffle, 24 | 'collate_fn': collate_fn, 25 | 'num_workers': num_workers 26 | } 27 | super().__init__(sampler=self.sampler, **self.init_kwargs) 28 | 29 | def _split_sampler(self, split): 30 | if split == 0.0: 31 | return None, None 32 | 33 | idx_full = np.arange(self.n_samples) 34 | 35 | np.random.seed(0) 36 | np.random.shuffle(idx_full) 37 | 38 | if isinstance(split, int): 39 | assert split > 0 40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 41 | len_valid = split 42 | else: 43 | len_valid = int(self.n_samples * split) 44 | 45 | valid_idx = idx_full[0:len_valid] 46 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 47 | 48 | train_sampler = SubsetRandomSampler(train_idx) 49 | valid_sampler = SubsetRandomSampler(valid_idx) 50 | # turn off shuffle option which is mutually exclusive with sampler 51 | self.shuffle = False 52 | self.n_samples = len(train_idx) 53 | 54 | return train_sampler, valid_sampler 55 | 56 | def split_validation(self, diff_kwargs=None): 57 | init_kwargs = self.init_kwargs 58 | if diff_kwargs is not None: 59 | init_kwargs.update(diff_kwargs) 60 | if self.valid_sampler is None: 61 | return None 62 | else: 63 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 64 | 65 | def num_samples(self): 66 | return len(self.sampler) 67 | 68 | 69 | class BaseDataLoaderExplicitSplit(DataLoader): 70 | """ 71 | Base class for all data loaders 72 | """ 73 | def __init__(self, args, dataset, batch_size, shuffle, num_workers, collate_fn=default_collate): 74 | self.shuffle = shuffle 75 | self.args = args 76 | self.batch_idx = 0 77 | self.n_samples = len(dataset) 78 | 79 | self.init_kwargs = { 80 | 'dataset': dataset, 81 | 'batch_size': batch_size, 82 | 'shuffle': self.shuffle, 83 | 'collate_fn': collate_fn, 84 | 'num_workers': num_workers, 85 | 'pin_memory': True 86 | } 87 | super().__init__(**self.init_kwargs) 88 | 89 | class DistBaseDataLoaderExplicitSplit(DataLoader): 90 | """ 91 | Base class for all data loaders 92 | """ 93 | def __init__(self, dataset, batch_size, shuffle, num_workers, collate_fn=default_collate): 94 | self.shuffle = shuffle 95 | 96 | self.batch_idx = 0 97 | self.n_samples = len(dataset) 98 | self.train_sampler = DistributedSampler(dataset) 99 | self.init_kwargs = { 100 | 'dataset': dataset, 101 | 'batch_size': batch_size, 102 | 'shuffle': False, 103 | 'collate_fn': collate_fn, 104 | 'num_workers': num_workers, 105 | 'pin_memory': True, 106 | 'sampler': self.train_sampler 107 | } 108 | super().__init__(**self.init_kwargs) 109 | 110 | class MultiDistBaseDataLoaderExplicitSplit(DataLoader): 111 | """ 112 | Base class for all data 
loaders 113 | """ 114 | def __init__(self, args, dataset, batch_size, shuffle, num_workers, collate_fn=default_collate): 115 | self.shuffle = shuffle 116 | 117 | self.batch_idx = 0 118 | self.n_samples = len(dataset) 119 | self.args = args 120 | self.train_sampler = DistributedSampler(dataset, num_replicas=self.args.world_size, rank=self.args.rank, drop_last=True) 121 | self.init_kwargs = { 122 | 'dataset': dataset, 123 | 'batch_size': batch_size, 124 | 'shuffle': False, 125 | 'collate_fn': collate_fn, 126 | 'num_workers': num_workers, 127 | 'pin_memory': True, 128 | 'sampler': self.train_sampler 129 | } 130 | super().__init__(**self.init_kwargs) 131 | 132 | class BaseMultiDataLoader: 133 | """ 134 | Currently implemented as undersample the bigger dataloaders... 135 | """ 136 | def __init__(self, dataloaders): 137 | self.dataloaders = dataloaders 138 | self.batch_size = self.dataloaders[0].batch_size 139 | def __getitem__(self, item): 140 | dl_idx = item % len(self.dataloaders) 141 | return next(iter(self.dataloaders[dl_idx])) 142 | 143 | def __len__(self): 144 | return min([len(x) for x in self.dataloaders]) * len(self.dataloaders) 145 | 146 | def num_samples(self): 147 | return sum([len(x.sampler) for x in self.dataloaders]) 148 | -------------------------------------------------------------------------------- /OATrans/base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | @abstractmethod 11 | def forward(self, *inputs): 12 | """ 13 | Forward pass logic 14 | 15 | :return: Model output 16 | """ 17 | raise NotImplementedError 18 | 19 | def __str__(self): 20 | """ 21 | Model prints with number of trainable parameters 22 | """ 23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 24 | params = sum([np.prod(p.size()) for p in model_parameters]) 25 | return super().__str__() + '\nTrainable parameters: {}'.format(params) 26 | -------------------------------------------------------------------------------- /OATrans/configs/ft/msrvtt/fine_tune/normal_1_cl.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "MSRVTTjsfusion_4f_stformer_pt-im21k", 3 | "n_gpu": 8, 4 | "linear_evaluation": false, 5 | "arch": { 6 | "type": "FrozenInTime", 7 | "stream": 2, 8 | "object": false, 9 | "args": { 10 | "video_params": { 11 | "model": "SpaceTimeTransformer", 12 | "arch_config": "base_patch16_224", 13 | "num_frames": 4, 14 | "pretrained": true, 15 | "time_init": "zeros", 16 | "two_outputs": false, 17 | "object_pseudo_label": false 18 | }, 19 | "object_params": { 20 | "model": "", 21 | "input_objects": false 22 | }, 23 | "text_params": { 24 | "model": "pretrained/distilbert-base-uncased", 25 | "pretrained": true, 26 | "input": "text", 27 | "two_outputs": false 28 | }, 29 | "projection": "minimal", 30 | "load_checkpoint": "exps/2stream_wtags/models/full-WebVid2M-1f-pti2k/0106_180724/checkpoint-epoch3.pth" 31 | } 32 | }, 33 | "data_loader": 34 | [ 35 | { 36 | "type": "TextObjectVideoDataLoader", 37 | "args":{ 38 | "dataset_name": "MSRVTT", 39 | "data_dir": "MSRVTT/", 40 | "object_dir": "MSRVTT/region_features_full/", 41 | "shuffle": true, 42 | "num_workers": 8, 43 | "batch_size": 64, 44 | "split": "train", 45 | "cut": "jsfusion", 46 | "subsample": 1, 47 | "text_params": { 48 | "object_tags": false, 49 | "drop_raw_caption": 
false, 50 | "text_aug": false, 51 | "object_aug": false 52 | }, 53 | "object_params": { 54 | "input_objects": false, 55 | "pseudo_labels": false, 56 | "input_object_bboxs": false 57 | }, 58 | "video_params": { 59 | "extraction_fps": 25, 60 | "extraction_res": 256, 61 | "input_res": 224, 62 | "num_frames": 4, 63 | "stride": 1 64 | } 65 | } 66 | } 67 | ], 68 | "optimizer": { 69 | "type": "AdamW", 70 | "args":{ 71 | "lr": 3e-5 72 | } 73 | }, 74 | "loss": { 75 | "type": "NormSoftmaxLoss", 76 | "args": { 77 | } 78 | }, 79 | "metrics": [ 80 | "t2v_metrics", 81 | "v2t_metrics" 82 | ], 83 | "trainer": { 84 | "epochs": 100, 85 | "max_samples_per_epoch": 9000, 86 | "save_dir": "exps", 87 | "save_period": 5, 88 | "verbosity": 2, 89 | "monitor": "min val_loss_0", 90 | "early_stop": 10, 91 | "neptune": false 92 | }, 93 | "visualizer": { 94 | "type": "", 95 | "args": { 96 | } 97 | } 98 | 99 | } -------------------------------------------------------------------------------- /OATrans/configs/ft/msrvtt/zsl/normal.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "MSRVTTjsfusion_4f_stformer_pt-im21k", 3 | "n_gpu": 2, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "stream": 2, 7 | "object": true, 8 | "args": { 9 | "video_params": { 10 | "model": "SpaceTimeTransformer", 11 | "arch_config": "base_patch16_224", 12 | "num_frames": 4, 13 | "pretrained": true, 14 | "time_init": "zeros", 15 | "two_outputs": false, 16 | "object_pseudo_label": false 17 | }, 18 | "object_params": { 19 | "model": "", 20 | "input_objects": false 21 | }, 22 | "text_params": { 23 | "model": "pretrained/distilbert-base-uncased", 24 | "pretrained": true, 25 | "input": "text", 26 | "two_outputs": true 27 | }, 28 | "projection": "minimal", 29 | "load_checkpoint": "exps/2stream_wtags/models/full-WebVid2M-1f-pti2k/0106_180724/checkpoint-epoch3.pth" 30 | } 31 | }, 32 | "data_loader": { 33 | "type": "MultiDistTextObjectVideoDataLoader", 34 | "args":{ 35 | "dataset_name": "MSRVTT", 36 | "data_dir": "MSRVTT/", 37 | "object_dir": "MSRVTT/region_features_full/", 38 | "shuffle": true, 39 | "num_workers": 8, 40 | "batch_size": 16, 41 | "split": "train", 42 | "cut": "jsfusion", 43 | "subsample": 1, 44 | "text_params": { 45 | "object_tags": true, 46 | "drop_raw_caption": false, 47 | "text_aug": false, 48 | "object_aug": false 49 | }, 50 | "object_params": { 51 | "input_objects": false, 52 | "pseudo_labels": false, 53 | "input_object_bboxs":false 54 | }, 55 | "video_params": { 56 | "extraction_fps": 25, 57 | "extraction_res": 256, 58 | "input_res": 224, 59 | "num_frames": 4, 60 | "stride": 1 61 | } 62 | } 63 | }, 64 | "optimizer": { 65 | "type": "AdamW", 66 | "args":{ 67 | "lr": 3e-5 68 | } 69 | }, 70 | "loss": { 71 | "type": "NormSoftmaxLoss", 72 | "args": { 73 | } 74 | }, 75 | "metrics": [ 76 | "t2v_metrics", 77 | "v2t_metrics" 78 | ], 79 | "trainer": { 80 | "epochs": 100, 81 | "max_samples_per_epoch": 9000, 82 | "save_dir": "exps", 83 | "save_period": 5, 84 | "verbosity": 2, 85 | "monitor": "min val_loss", 86 | "early_stop": 10, 87 | "neptune": true 88 | }, 89 | "visualizer": { 90 | "type": "", 91 | "args": { 92 | } 93 | } 94 | 95 | } -------------------------------------------------------------------------------- /OATrans/configs/pt/cc3m_webvid/local-region-loss.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "full-cc-WebVid2M-1f-pti2k", 3 | "n_gpu": 4, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "object": true, 7 | "stream": 2, 8 | 
"args": { 9 | "video_params": { 10 | "model": "SpaceTimeTransformer", 11 | "arch_config": "base_patch16_224", 12 | "num_frames": 4, 13 | "pretrained": true, 14 | "time_init": "zeros", 15 | "two_outputs": false, 16 | "object_pseudo_label": false 17 | }, 18 | "object_params": { 19 | "model": "", 20 | "input_objects": false 21 | }, 22 | "text_params": { 23 | "model": "pretrained/distilbert-base-uncased", 24 | "pretrained": true, 25 | "input": "text", 26 | "two_outputs": true 27 | }, 28 | "projection": "minimal", 29 | "load_checkpoint" : "" 30 | } 31 | }, 32 | "data_loader": 33 | [ 34 | { 35 | "type": "MultiDistTextObjectVideoDataLoader", 36 | "args":{ 37 | "dataset_name": "ConceptualCaptions3M", 38 | "data_dir": "CC3M/", 39 | "object_dir": "CC3M/1_frame_object", 40 | "reader": "cv2", 41 | "shuffle": true, 42 | "num_workers": 8, 43 | "batch_size": 16, 44 | "split": "train", 45 | "subsample": 1, 46 | "text_params": { 47 | }, 48 | "object_params": { 49 | }, 50 | "video_params": { 51 | "input_res": 224, 52 | "num_frames": 1, 53 | "loading": "lax" 54 | } 55 | } 56 | }, 57 | { 58 | "type": "MultiDistTextObjectVideoDataLoader", 59 | "args":{ 60 | "dataset_name": "WebVid", 61 | "data_dir": "WebVid", 62 | "object_dir": "WebVid/8_frame_object", 63 | "reader": "cv2", 64 | "shuffle": true, 65 | "num_workers": 8, 66 | "batch_size": 16, 67 | "split": "train", 68 | "cut": "2M", 69 | "subsample": 1, 70 | "text_params": { 71 | }, 72 | "object_params": { 73 | }, 74 | "video_params": { 75 | "input_res": 224, 76 | "num_frames": 4, 77 | "loading": "lax" 78 | } 79 | } 80 | } 81 | ], 82 | "optimizer": { 83 | "type": "AdamW", 84 | "args":{ 85 | "lr": 2e-4 86 | } 87 | }, 88 | "loss": { 89 | "type": "NormSoftmaxLoss", 90 | "args": { 91 | } 92 | }, 93 | "metrics": [ 94 | "t2v_metrics", 95 | "v2t_metrics" 96 | ], 97 | "trainer": { 98 | "epochs": 100, 99 | "max_samples_per_epoch": 1000000, 100 | "save_dir": "exps/2stream_wtags", 101 | "save_period": 5, 102 | "verbosity": 2, 103 | "monitor": "min val_loss_0", 104 | "early_stop": 10, 105 | "init_val": true, 106 | "neptune": false 107 | }, 108 | "visualizer": { 109 | "type": "" 110 | } 111 | 112 | } 113 | -------------------------------------------------------------------------------- /OATrans/configs/pt/cc3m_webvid/norm.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "full-cc-WebVid2M-1f-pti2k-normal", 3 | "n_gpu": 8, 4 | "arch": { 5 | "type": "FrozenInTime", 6 | "object": false, 7 | "stream": 2, 8 | "args": { 9 | "video_params": { 10 | "model": "SpaceTimeTransformer", 11 | "arch_config": "base_patch16_224", 12 | "num_frames": 4, 13 | "pretrained": true, 14 | "time_init": "zeros", 15 | "two_outputs": false, 16 | "object_pseudo_label": false 17 | }, 18 | "object_params": { 19 | "model": "", 20 | "input_objects": false 21 | }, 22 | "text_params": { 23 | "model": "pretrained/distilbert-base-uncased", 24 | "pretrained": true, 25 | "input": "text", 26 | "two_outputs": false 27 | }, 28 | "projection": "minimal", 29 | "load_checkpoint" : "" 30 | } 31 | }, 32 | "data_loader": 33 | [ 34 | { 35 | "type": "MultiDistTextObjectVideoDataLoader", 36 | "args":{ 37 | "dataset_name": "ConceptualCaptions3M", 38 | "data_dir": "CC3M/", 39 | "object_dir": "CC3M/1_frame_object", 40 | "reader": "cv2", 41 | "shuffle": true, 42 | "num_workers": 8, 43 | "batch_size": 16, 44 | "split": "train", 45 | "subsample": 1, 46 | "text_params": { 47 | }, 48 | "object_params": { 49 | }, 50 | "video_params": { 51 | "input_res": 224, 52 | "num_frames": 1, 
53 | "loading": "lax" 54 | } 55 | } 56 | }, 57 | { 58 | "type": "MultiDistTextObjectVideoDataLoader", 59 | "args":{ 60 | "dataset_name": "WebVid", 61 | "data_dir": "WebVid", 62 | "object_dir": "WebVid/8_frame_object", 63 | "reader": "cv2", 64 | "shuffle": true, 65 | "num_workers": 8, 66 | "batch_size": 16, 67 | "split": "train", 68 | "cut": "2M", 69 | "subsample": 1, 70 | "text_params": { 71 | }, 72 | "object_params": { 73 | }, 74 | "video_params": { 75 | "input_res": 224, 76 | "num_frames": 4, 77 | "loading": "lax" 78 | } 79 | } 80 | } 81 | ], 82 | "optimizer": { 83 | "type": "AdamW", 84 | "args":{ 85 | "lr": 2e-4 86 | } 87 | }, 88 | "loss": { 89 | "type": "NormSoftmaxLoss", 90 | "args": { 91 | } 92 | }, 93 | "metrics": [ 94 | "t2v_metrics", 95 | "v2t_metrics" 96 | ], 97 | "trainer": { 98 | "epochs": 100, 99 | "max_samples_per_epoch": 1000000, 100 | "save_dir": "exps/2stream_wtags", 101 | "save_period": 5, 102 | "verbosity": 2, 103 | "monitor": "min val_loss_0", 104 | "early_stop": 10, 105 | "init_val": true, 106 | "neptune": false 107 | }, 108 | "visualizer": { 109 | "type": "" 110 | } 111 | 112 | } 113 | -------------------------------------------------------------------------------- /OATrans/data_loader/ConceptualCaptions_dataset.py: -------------------------------------------------------------------------------- 1 | # from base.base_dataset import TextObjectImageDataset 2 | from OATrans.base.base_dataset_region_mem import TextObjectImageDataset 3 | import pandas as pd 4 | import os 5 | 6 | 7 | class ConceptualCaptions3M(TextObjectImageDataset): 8 | """ 9 | Conceptual Captions dataset. Split files are specific to my download regime. 10 | """ 11 | 12 | def _load_metadata(self): 13 | # download specific 14 | metadata_dir = './meta_data' 15 | split_files = { 16 | 'train': 'cc3m_training_success_full.tsv', 17 | 'val': 'cc3m_validation_success_full.tsv', # there is no test 18 | } 19 | target_split_fp = split_files[self.split] 20 | metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t') 21 | 22 | if self.subsample < 1: 23 | metadata = metadata.sample(frac=self.subsample) 24 | # elif self.split == 'val': 25 | # metadata = metadata.sample(1000, random_state=0) # 15k val is unnecessarily large, downsample. 
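# Note on the metadata layout (inferred from the column indexing used in this
# class, not from repository documentation): each row of the
# cc3m_*_success_full.tsv files is expected to carry the caption in column 0
# and the image/video file name in column 1; _get_video_path() below prepends
# 'training/' or 'validation/' depending on the split.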
26 | 27 | self.metadata = metadata 28 | 29 | def _get_video_path(self, sample): 30 | # conceptual captions uses this hashing to create the filename 31 | rel_dir = 'training' 32 | if self.split != 'train': 33 | rel_dir = 'validation' 34 | rel_fp = os.path.join(rel_dir, sample[1]) 35 | #rel_fp = os.path.join(rel_dir, str(zlib.crc32(sample['thumbnailUrl'].encode('utf-8')) & 0xffffffff)) 36 | return os.path.join(self.data_dir, rel_fp), rel_fp 37 | 38 | def _get_caption(self, sample): 39 | return sample[0] 40 | #return sample['caption'] 41 | 42 | def _get_object_path(self, sample): 43 | """ 44 | get the object npy path 45 | Args: 46 | sample (dict): 47 | Returns: 48 | abs path 49 | """ 50 | # pre = sample[1].split('_')[0] 51 | # pre = pre.zfill(7) 52 | # rel_object_fp = os.path.join(pre[:4], sample[1]) 53 | # rel_object_fp = os.path.join(pre[:4], sample[1] + '_1.npz') 54 | rel_object_fp = sample[1] 55 | full_object_fp = os.path.join(self.object_dir, self.split, rel_object_fp) 56 | return os.path.join(self.split, rel_object_fp), full_object_fp -------------------------------------------------------------------------------- /OATrans/data_loader/DiDeMo_dataset.py: -------------------------------------------------------------------------------- 1 | from OATrans.base.base_dataset import TextObjectVideoDataset 2 | import pandas as pd 3 | import os 4 | 5 | 6 | class DiDeMo(TextObjectVideoDataset): 7 | def _load_metadata(self): 8 | metadata_dir = './meta_data' 9 | split_files = { 10 | 'train': 'DiDeMo_train.tsv', 11 | 'val': 'DiDeMo_val.tsv', # there is no test 12 | 'test': 'DiDeMo_test.tsv' 13 | } 14 | target_split_fp = split_files[self.split] 15 | metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t') 16 | if self.subsample < 1: 17 | metadata = metadata.sample(frac=self.subsample) 18 | self.metadata = metadata 19 | print("load split {}, {} samples".format(self.split, len(metadata))) 20 | 21 | def _get_video_path(self, sample): 22 | rel_video_fp = sample[1] 23 | #rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') 24 | full_video_fp = os.path.join(self.data_dir, rel_video_fp) 25 | # print(full_video_fp) 26 | return full_video_fp, rel_video_fp 27 | 28 | def _get_caption(self, sample): 29 | # print(sample[0].split(',')[0]) 30 | # return sample[0].split(',')[0] 31 | return sample[0] # .split(',')[0] 32 | 33 | def _get_object_path(self, sample, index=0): 34 | """ 35 | get the object npy path 36 | Args: 37 | sample (dict): 38 | Returns: 39 | abs path 40 | """ 41 | rel_object_fp = os.path.join(sample[1], '1.npz') 42 | full_object_fp = os.path.join(self.object_dir, self.split, rel_object_fp) 43 | return os.path.join(self.split, rel_object_fp), full_object_fp -------------------------------------------------------------------------------- /OATrans/data_loader/LSMDC_choice_dataset.py: -------------------------------------------------------------------------------- 1 | from OATrans.base.base_dataset import TextVideoDataset 2 | import pandas as pd 3 | import os 4 | import numpy as np 5 | 6 | 7 | class LSMDC(TextVideoDataset): 8 | def _load_metadata(self): 9 | split_paths = {key: os.path.join(self.metadata_dir, 'structured-symlinks', f'{key}_list.txt') for key in 10 | ['train', 'val', 'test']} 11 | df_dict = {key: pd.read_csv(val, names=['videoid']) for key, val in split_paths.items()} 12 | #### subsample_val 13 | 14 | self.split_sizes = {key: len(val) for key, val in df_dict.items()} 15 | target_vids = df_dict[self.split] 16 | # target_vids = 
target_vids['videoid'].str.split('.').str[0] 17 | if self.subsample < 1: 18 | target_vids = target_vids.sample(frac=self.subsample) 19 | captions = np.load(os.path.join(self.metadata_dir, 'structured-symlinks', 'raw-captions.pkl'), 20 | allow_pickle=True) 21 | captions = pd.DataFrame.from_dict(captions, orient='index') 22 | captions['captions'] = captions.values.tolist() 23 | target_vids.set_index('videoid', inplace=True) 24 | target_vids['captions'] = captions['captions'] 25 | # import pdb; -.set_trace() 26 | # captions = captions[captions.index.isin(target_vids.str['videoid'].split('.').str[0])] 27 | self.metadata = target_vids 28 | frame_tar_list = pd.read_csv(os.path.join(self.metadata_dir, 'frame_tar_list.txt'), names=['fp']) 29 | 30 | frame_tar_list['fn'] = frame_tar_list['fp'].str.split('/').str[-2:].str.join('/') 31 | frame_tar_list['fn'] = frame_tar_list['fn'].str.replace('.tar', '') 32 | frame_tar_list['vid_stem'] = frame_tar_list['fn'].str.split('/').str[-1] 33 | 34 | frame_tar_list = frame_tar_list[frame_tar_list['vid_stem'].isin(self.metadata.index)] 35 | 36 | frame_tar_list.set_index('vid_stem', inplace=True) 37 | self.metadata['fn'] = frame_tar_list['fn'] 38 | self.metadata['captions'] = self.metadata['captions'].apply(lambda x: [ii for ii in x if ii is not None]) 39 | self.metadata['num_captions'] = self.metadata['captions'].str.len() 40 | self.metadata['captions'] = self.metadata['captions'].apply(lambda x: [' '.join(ii) for ii in x]) 41 | 42 | if 'videoid' not in self.metadata.columns: 43 | self.metadata['videoid'] = self.metadata.index 44 | 45 | def _get_video_path(self, sample): 46 | return os.path.join(self.data_dir, 'videos', sample['fn'] + '.avi'), sample.name + '.avi' 47 | 48 | def _get_caption(self, sample): 49 | if len(sample['captions']) != 1: 50 | raise NotImplementedError 51 | return sample['captions'][0] -------------------------------------------------------------------------------- /OATrans/data_loader/LSMDC_dataset.py: -------------------------------------------------------------------------------- 1 | from OATrans.base.base_dataset import TextVideoDataset 2 | import pandas as pd 3 | import os 4 | import numpy as np 5 | 6 | 7 | class LSMDC(TextVideoDataset): 8 | def _load_metadata(self): 9 | split_paths = {key: os.path.join(self.metadata_dir, 'structured-symlinks', f'{key}_list.txt') for key in 10 | ['train', 'val', 'test']} 11 | df_dict = {key: pd.read_csv(val, names=['videoid']) for key, val in split_paths.items()} 12 | #### subsample_val 13 | 14 | self.split_sizes = {key: len(val) for key, val in df_dict.items()} 15 | target_vids = df_dict[self.split] 16 | # target_vids = target_vids['videoid'].str.split('.').str[0] 17 | if self.subsample < 1: 18 | target_vids = target_vids.sample(frac=self.subsample) 19 | captions = np.load(os.path.join(self.metadata_dir, 'structured-symlinks', 'raw-captions.pkl'), 20 | allow_pickle=True) 21 | captions = pd.DataFrame.from_dict(captions, orient='index') 22 | captions['captions'] = captions.values.tolist() 23 | target_vids.set_index('videoid', inplace=True) 24 | target_vids['captions'] = captions['captions'] 25 | # import pdb; -.set_trace() 26 | # captions = captions[captions.index.isin(target_vids.str['videoid'].split('.').str[0])] 27 | self.metadata = target_vids 28 | frame_tar_list = pd.read_csv(os.path.join(self.metadata_dir, 'frame_tar_list.txt'), names=['fp']) 29 | 30 | frame_tar_list['fn'] = frame_tar_list['fp'].str.split('/').str[-2:].str.join('/') 31 | frame_tar_list['fn'] = 
frame_tar_list['fn'].str.replace('.tar', '') 32 | frame_tar_list['vid_stem'] = frame_tar_list['fn'].str.split('/').str[-1] 33 | 34 | frame_tar_list = frame_tar_list[frame_tar_list['vid_stem'].isin(self.metadata.index)] 35 | 36 | frame_tar_list.set_index('vid_stem', inplace=True) 37 | self.metadata['fn'] = frame_tar_list['fn'] 38 | self.metadata['captions'] = self.metadata['captions'].apply(lambda x: [ii for ii in x if ii is not None]) 39 | self.metadata['num_captions'] = self.metadata['captions'].str.len() 40 | self.metadata['captions'] = self.metadata['captions'].apply(lambda x: [' '.join(ii) for ii in x]) 41 | 42 | if 'videoid' not in self.metadata.columns: 43 | self.metadata['videoid'] = self.metadata.index 44 | 45 | def _get_video_path(self, sample): 46 | return os.path.join(self.data_dir, 'videos', sample['fn'] + '.avi'), sample.name + '.avi' 47 | 48 | def _get_caption(self, sample): 49 | if len(sample['captions']) != 1: 50 | raise NotImplementedError 51 | return sample['captions'][0] -------------------------------------------------------------------------------- /OATrans/data_loader/MSRVTT_dataset.py: -------------------------------------------------------------------------------- 1 | from OATrans.base.base_dataset import TextObjectVideoDataset 2 | # from base.base_dataset_global_local import TextObjectVideoDataset 3 | # from base.base_dataset_region_mem import TextObjectVideoDataset 4 | import pandas as pd 5 | import os 6 | import json 7 | import numpy as np 8 | import random 9 | 10 | 11 | class MSRVTT(TextObjectVideoDataset): 12 | def _load_metadata(self): 13 | json_fp = os.path.join(self.metadata_dir, 'annotation', 'MSR_VTT.json') 14 | with open(json_fp, 'r') as fid: 15 | data = json.load(fid) 16 | df = pd.DataFrame(data['annotations']) 17 | 18 | split_dir = os.path.join(self.metadata_dir, 'high-quality', 'structured-symlinks') 19 | js_test_cap_idx_path = None 20 | challenge_splits = {"val", "public_server_val", "public_server_test"} 21 | if self.cut == "miech": 22 | train_list_path = "train_list_miech.txt" 23 | test_list_path = "test_list_miech.txt" 24 | elif self.cut == "jsfusion": 25 | train_list_path = "train_list_jsfusion.txt" 26 | test_list_path = "val_list_jsfusion.txt" 27 | js_test_cap_idx_path = "jsfusion_val_caption_idx.pkl" 28 | elif self.cut in {"full-val", "full-test"}: 29 | train_list_path = "train_list_full.txt" 30 | if self.cut == "full-val": 31 | test_list_path = "val_list_full.txt" 32 | else: 33 | test_list_path = "test_list_full.txt" 34 | elif self.cut in challenge_splits: 35 | train_list_path = "train_list.txt" 36 | if self.cut == "val": 37 | test_list_path = f"{self.cut}_list.txt" 38 | else: 39 | test_list_path = f"{self.cut}.txt" 40 | else: 41 | msg = "unrecognised MSRVTT split: {}" 42 | raise ValueError(msg.format(self.cut)) 43 | 44 | train_df = pd.read_csv(os.path.join(split_dir, train_list_path), names=['videoid']) 45 | test_df = pd.read_csv(os.path.join(split_dir, test_list_path), names=['videoid']) 46 | self.split_sizes = {'train': len(train_df), 'val': len(test_df), 'test': len(test_df)} 47 | 48 | if self.split == 'train': 49 | df = df[df['image_id'].isin(train_df['videoid'])] 50 | else: 51 | df = df[df['image_id'].isin(test_df['videoid'])] 52 | 53 | self.metadata = df.groupby(['image_id'])['caption'].apply(list) 54 | if self.subsample < 1: 55 | self.metadata = self.metadata.sample(frac=self.subsample) 56 | 57 | # use specific caption idx's in jsfusion 58 | if js_test_cap_idx_path is not None and self.split != 'train': 59 | caps = 
pd.Series(np.load(os.path.join(split_dir, js_test_cap_idx_path), allow_pickle=True)) 60 | new_res = pd.DataFrame({'caps': self.metadata, 'cap_idx': caps}) 61 | new_res['test_caps'] = new_res.apply(lambda x: [x['caps'][x['cap_idx']]], axis=1) 62 | self.metadata = new_res['test_caps'] 63 | 64 | self.metadata = pd.DataFrame({'captions': self.metadata}) 65 | print("load split {}, {} samples".format(self.split, len(self.metadata))) 66 | 67 | def _get_video_path(self, sample): 68 | return os.path.join(self.data_dir, 'videos', 'all', sample.name + '.mp4'), sample.name + '.mp4' 69 | 70 | def _get_caption(self, sample): 71 | caption_sample = self.text_params.get('caption_sample', "rand") 72 | if self.split in ['train', 'val'] and caption_sample == "rand": 73 | caption = random.choice(sample['captions']) 74 | else: 75 | caption = sample['captions'][0] 76 | return caption 77 | 78 | def _get_object_path(self, sample): 79 | """ 80 | get the object npy path 81 | Args: 82 | sample (dict): 83 | Returns: 84 | abs path 85 | """ 86 | # real_path = os.path.join(sample.name, '{}.npz'.format(index)) 87 | real_path = sample.name 88 | full_object_fp = os.path.join(self.object_dir, sample.name) 89 | return real_path, full_object_fp -------------------------------------------------------------------------------- /OATrans/data_loader/MSVD_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from OATrans.base.base_dataset import TextObjectVideoDataset 4 | import pandas as pd 5 | import os 6 | 7 | 8 | class MSVD(TextObjectVideoDataset): 9 | def _load_metadata(self): 10 | metadata_dir = './meta_data' 11 | split_files = { 12 | 'train': 'MSVD_train.tsv', 13 | # 'val': 'MSVD_val.tsv', # there is no test 14 | 'val': 'MSVD_test.tsv', # direct output test result 15 | # 'val': 'MSVD_split_test.tsv', 16 | # 'test': 'MSVD_split_test.tsv' 17 | 'test': 'MSVD_test.tsv' 18 | } 19 | target_split_fp = split_files[self.split] 20 | metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t') 21 | if self.subsample < 1: 22 | metadata = metadata.sample(frac=self.subsample) 23 | self.metadata = metadata 24 | print("load split {}, {} samples".format(self.split, len(metadata))) 25 | 26 | def _get_video_path(self, sample): 27 | rel_video_fp = sample[1] + '.avi' 28 | #rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') 29 | full_video_fp = os.path.join(self.data_dir, rel_video_fp) 30 | # print(full_video_fp) 31 | return full_video_fp, rel_video_fp 32 | 33 | # multiple sentence 34 | def _get_caption(self, sample): 35 | # print(sample[0].split(',')[0]) 36 | if self.split == 'train': 37 | words = sample[0].split(',') 38 | num_word = len(words) 39 | index = random.randint(0, num_word-1) 40 | caption = words[index] 41 | else: 42 | # caption = sample[0] 43 | words = sample[0].split(',') 44 | num_word = len(words) 45 | index = random.randint(0, num_word-1) 46 | caption = words[index] 47 | # caption = None 48 | # if self.split == 'train': 49 | # indexs = sorted(random.sample(range(0, num_word-1), 5)) 50 | # caption = ' '.join(words[item] for item in indexs) 51 | # else: 52 | # caption = ' '.join(words[item] for item in range(0, 5)) 53 | return caption 54 | 55 | def _get_object_path(self, sample, index=1): 56 | """ 57 | get the object npy path 58 | Args: 59 | sample (dict): 60 | Returns: 61 | abs path 62 | """ 63 | rel_object_fp = os.path.join(sample[1], '1.npz') 64 | full_object_fp = os.path.join(self.object_dir, self.split, rel_object_fp) 
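# Layout assumed by this method (inferred from the path construction above and
# the return statement below, not from repository documentation): pre-extracted
# object features are stored as <object_dir>/<split>/<video_id>/1.npz, and the
# method returns (path relative to object_dir, absolute path).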
65 | return os.path.join(self.split, rel_object_fp), full_object_fp -------------------------------------------------------------------------------- /OATrans/data_loader/WebVid_dataset.py: -------------------------------------------------------------------------------- 1 | # from base.base_dataset import TextObjectVideoDataset 2 | from OATrans.base.base_dataset_region_mem import TextObjectVideoDataset 3 | # from base.base_dataset_region_single import TextObjectVideoDataset 4 | # from base.base_dataset_region_mem_bk import TextObjectVideoDataset 5 | import pandas as pd 6 | import os 7 | 8 | 9 | class WebVidObject(TextObjectVideoDataset): 10 | """ 11 | WebVid Dataset. 12 | Assumes webvid data is structured as follows. 13 | Webvid/ 14 | videos/ 15 | 000001_000050/ ($page_dir) 16 | 1.mp4 (videoid.mp4) 17 | ... 18 | 5000.mp4 19 | ... 20 | """ 21 | def _load_metadata(self): 22 | #metadata_dir = os.path.join(self.metadata_dir, 'meta_data') 23 | metadata_dir = './meta_data' 24 | split_files = { 25 | 'train': 'webvid_training_success_full.tsv', 26 | # 'train': 'webvid_1_of_10_training_success_full.tsv', 27 | # 'train': 'webvid_validation_success_full.tsv', 28 | 'val': 'webvid_validation_success_full.tsv', # there is no test 29 | } 30 | 31 | target_split_fp = split_files[self.split] 32 | metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t') 33 | if self.subsample < 1: 34 | metadata = metadata.sample(frac=self.subsample) 35 | # elif self.split == 'val': 36 | # metadata = metadata.sample(1000, random_state=0) # 15k val is unnecessarily large, downsample. 37 | 38 | #metadata['caption'] = metadata['name'] 39 | #del metadata['name'] 40 | self.metadata = metadata 41 | # TODO: clean final csv so this isn't necessary 42 | #self.metadata.dropna(inplace=True) 43 | #self.metadata['caption'] = self.metadata['caption'].str[:350] 44 | 45 | def _get_video_path(self, sample): 46 | rel_video_fp = sample[1] + '.mp4' 47 | #rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') 48 | full_video_fp = os.path.join(self.data_dir, self.split, rel_video_fp) 49 | return full_video_fp, rel_video_fp 50 | 51 | def _get_caption(self, sample): 52 | return sample[0] 53 | 54 | def _get_object_path(self, sample): 55 | """ 56 | get the object npy path 57 | Args: 58 | sample (dict): 59 | Returns: 60 | abs path 61 | """ 62 | # rel_object_fp = sample[1] + '.pickle' 63 | rel_object_fp = sample[1] # + '.pickle' 64 | full_object_fp = os.path.join(self.object_dir, self.split, rel_object_fp) 65 | return rel_object_fp, full_object_fp -------------------------------------------------------------------------------- /OATrans/data_loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/data_loader/__init__.py -------------------------------------------------------------------------------- /OATrans/data_loader/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | 4 | def init_transform_dict(input_res=224, 5 | center_crop=256, 6 | randcrop_scale=(0.5, 1.0), 7 | color_jitter=(0, 0, 0), 8 | norm_mean=(0.485, 0.456, 0.406), 9 | norm_std=(0.229, 0.224, 0.225)): 10 | normalize = transforms.Normalize(mean=norm_mean, std=norm_std) 11 | tsfm_dict = { 12 | 'train': transforms.Compose([ 13 | transforms.RandomResizedCrop(input_res, scale=randcrop_scale), 14 | 
transforms.RandomHorizontalFlip(), 15 | transforms.ColorJitter(brightness=color_jitter[0], saturation=color_jitter[1], hue=color_jitter[2]), 16 | normalize, 17 | ]), 18 | 'val': transforms.Compose([ 19 | transforms.Resize(center_crop), 20 | transforms.CenterCrop(center_crop), 21 | transforms.Resize(input_res), 22 | normalize, 23 | ]), 24 | 'test': transforms.Compose([ 25 | transforms.Resize(center_crop), 26 | transforms.CenterCrop(center_crop), 27 | transforms.Resize(input_res), 28 | normalize, 29 | ]) 30 | } 31 | return tsfm_dict 32 | -------------------------------------------------------------------------------- /OATrans/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .visualization import * -------------------------------------------------------------------------------- /OATrans/logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from OATrans.utils import read_json 5 | 6 | 7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): 8 | """ 9 | Setup logging configuration 10 | """ 11 | log_config = Path(log_config) 12 | if log_config.is_file(): 13 | config = read_json(log_config) 14 | # modify logging paths based on run config 15 | for _, handler in config['handlers'].items(): 16 | if 'filename' in handler: 17 | handler['filename'] = str(save_dir / handler['filename']) 18 | 19 | logging.config.dictConfig(config) 20 | else: 21 | print("Warning: logging configuration file is not found in {}.".format(log_config)) 22 | logging.basicConfig(level=default_level) 23 | -------------------------------------------------------------------------------- /OATrans/logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": "logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /OATrans/logger/visualization.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from OATrans.utils import Timer 3 | 4 | 5 | class TensorboardWriter(): 6 | def __init__(self, log_dir, logger, enabled): 7 | self.writer = None 8 | self.selected_module = "" 9 | 10 | if enabled: 11 | log_dir = str(log_dir) 12 | 13 | # Retrieve vizualization writer. 
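# The writer is resolved lazily: torch.utils.tensorboard is tried first and
# tensorboardX is used as a fallback; if neither import succeeds, a warning is
# logged and every add_* call routed through __getattr__ becomes a no-op.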
14 | succeeded = False 15 | for module in ["torch.utils.tensorboard", "tensorboardX"]: 16 | try: 17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 18 | succeeded = True 19 | break 20 | except ImportError: 21 | succeeded = False 22 | self.selected_module = module 23 | 24 | if not succeeded: 25 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ 26 | "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \ 27 | "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \ 28 | "the 'config.json' file." 29 | logger.warning(message) 30 | 31 | self.step = 0 32 | self.mode = '' 33 | 34 | self.tb_writer_ftns = { 35 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 36 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 37 | } 38 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} 39 | 40 | self.timer = Timer() 41 | 42 | def set_step(self, step, mode='train'): 43 | self.mode = mode 44 | self.step = step 45 | if step == 0: 46 | self.timer.reset() 47 | else: 48 | duration = self.timer.check() 49 | self.add_scalar('steps_per_sec', 1 / duration) 50 | 51 | def __getattr__(self, name): 52 | """ 53 | If visualization is configured to use: 54 | return add_data() methods of tensorboard with additional information (step, tag) added. 55 | Otherwise: 56 | return a blank function handle that does nothing 57 | """ 58 | if name in self.tb_writer_ftns: 59 | add_data = getattr(self.writer, name, None) 60 | 61 | def wrapper(tag, data, *args, **kwargs): 62 | if add_data is not None: 63 | # add mode(train/valid) tag 64 | if name not in self.tag_mode_exceptions: 65 | tag = '{}/{}'.format(tag, self.mode) 66 | add_data(tag, data, self.step, *args, **kwargs) 67 | return wrapper 68 | else: 69 | # default action for returning methods defined in this class, set_step() for instance. 
70 | try: 71 | attr = object.__getattr__(name) 72 | except AttributeError: 73 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 74 | return attr 75 | 76 | 77 | class SacredNeptuneWriter(): 78 | def __init__(self): 79 | raise NotImplementedError -------------------------------------------------------------------------------- /OATrans/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/model/__init__.py -------------------------------------------------------------------------------- /OATrans/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | 7 | class NormSoftmaxLoss(nn.Module): 8 | def __init__(self, temperature=0.05): 9 | super().__init__() 10 | 11 | self.temperature = temperature 12 | 13 | def forward(self, x): 14 | "Assumes input x is similarity matrix of N x M \in [-1, 1], computed using the cosine similarity between normalised vectors" 15 | i_logsm = F.log_softmax(x/self.temperature, dim=1) 16 | j_logsm = F.log_softmax(x.t()/self.temperature, dim=1) 17 | 18 | # sum over positives 19 | idiag = torch.diag(i_logsm) 20 | loss_i = idiag.sum() / len(idiag) 21 | 22 | jdiag = torch.diag(j_logsm) 23 | loss_j = jdiag.sum() / len(jdiag) 24 | 25 | return - loss_i - loss_j 26 | 27 | 28 | class MaxMarginRankingLoss(nn.Module): 29 | 30 | def __init__(self, margin=1, fix_norm=True): 31 | super().__init__() 32 | self.fix_norm = fix_norm 33 | self.loss = th.nn.MarginRankingLoss(margin) 34 | self.margin = margin 35 | 36 | def forward(self, x): 37 | n = x.size()[0] 38 | 39 | x1 = th.diag(x) 40 | x1 = x1.unsqueeze(1) 41 | x1 = x1.expand(n, n) 42 | x1 = x1.contiguous().view(-1, 1) 43 | x1 = th.cat((x1, x1), 0) 44 | 45 | x2 = x.view(-1, 1) 46 | x3 = x.transpose(0, 1).contiguous().view(-1, 1) 47 | 48 | x2 = th.cat((x2, x3), 0) 49 | max_margin = F.relu(self.margin - (x1 - x2)) 50 | 51 | if self.fix_norm: 52 | # remove the elements from the diagonal 53 | keep = th.ones(x.shape) - th.eye(x.shape[0]) # 128 x 128 54 | keep1 = keep.view(-1, 1) 55 | keep2 = keep.transpose(0, 1).contiguous().view(-1, 1) 56 | keep_idx = th.nonzero(th.cat((keep1, keep2), 0).flatten()).flatten() 57 | if x1.is_cuda: 58 | keep_idx = keep_idx.cuda() 59 | x1_ = th.index_select(x1, dim=0, index=keep_idx) 60 | x2_ = th.index_select(x2, dim=0, index=keep_idx) 61 | max_margin = F.relu(self.margin - (x1_ - x2_)) 62 | 63 | return max_margin.mean() 64 | 65 | 66 | class CrossEntropy(nn.Module): 67 | def __init__(self): 68 | super().__init__() 69 | self.loss = nn.CrossEntropyLoss() 70 | 71 | def forward(self, output, target): 72 | return self.loss(output, target) 73 | 74 | 75 | def cosine_sim(im, s): 76 | """Cosine similarity between all the image and sentence pairs 77 | """ 78 | return im.mm(s.t()) 79 | 80 | 81 | def order_sim(im, s): 82 | """Order embeddings similarity measure $max(0, s-im)$ 83 | """ 84 | YmX = (s.unsqueeze(1).expand(s.size(0), im.size(0), s.size(1)) 85 | - im.unsqueeze(0).expand(s.size(0), im.size(0), s.size(1))) 86 | score = -YmX.clamp(min=0).pow(2).sum(2).sqrt().t() 87 | return score 88 | 89 | 90 | def nll_loss(output, target): 91 | return F.nll_loss(output, target) 92 | 93 | 94 | if __name__ == "__main__": 95 | import torch 96 | 97 | random_sims = (torch.rand([10, 8]) * 2) 
- 1 98 | loss = NormSoftmaxLoss() 99 | loss(random_sims) 100 | -------------------------------------------------------------------------------- /OATrans/model/oa_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | import torch.nn.functional as F 4 | import torch 5 | import math 6 | import numpy as np 7 | from model.model import sim_matrix 8 | from model.loss import NormSoftmaxLoss 9 | from torch.autograd import Variable 10 | 11 | 12 | 13 | # simsiam loss 14 | 15 | 16 | def softmax_kl_loss(input_logits, target_logits): 17 | """Takes softmax on both sides and returns KL divergence 18 | Note: 19 | - Returns the sum over all examples. Divide by the batch size afterwards 20 | if you want the mean. 21 | - Sends gradients to inputs but not the targets. 22 | """ 23 | assert input_logits.size() == target_logits.size() 24 | input_log_softmax = F.log_softmax(input_logits, dim=1) 25 | target_softmax = F.softmax(target_logits, dim=1) 26 | return F.kl_div(input_log_softmax, target_softmax, size_average=False) 27 | 28 | 29 | def softmax_mse_loss(input_logits, target_logits): 30 | """Takes softmax on both sides and returns MSE loss 31 | Note: 32 | - Returns the sum over all examples. Divide by the batch size afterwards 33 | if you want the mean. 34 | - Sends gradients to inputs but not the targets. 35 | """ 36 | assert input_logits.size() == target_logits.size() 37 | return F.mse_loss(input_logits, target_logits, size_average=False) # / num_classes 38 | # input_softmax = F.softmax(input_logits, dim=1) 39 | # target_softmax = F.softmax(target_logits, dim=1) 40 | # num_classes = input_logits.size()[1] 41 | # return F.mse_loss(input_softmax, target_softmax, size_average=False) # / num_classes 42 | 43 | 44 | # n_data = len(dataset) 45 | # contrast = MemoryMoCo(128, n_data, 8092*4, 0.07, use_softmax=True).cuda() 46 | # criterion = NCESoftmaxLoss() 47 | # criterion = criterion.cuda() 48 | # 49 | # out = contrast(feat_q, feat_k, feat_n, index) 50 | # contrast_loss = criterion(out) 51 | 52 | 53 | class NCESoftmaxLoss(nn.Module): 54 | """Softmax cross-entropy loss (a.k.a., info-NCE loss in CPC paper)""" 55 | def __init__(self): 56 | super(NCESoftmaxLoss, self).__init__() 57 | self.criterion = nn.CrossEntropyLoss() 58 | 59 | def forward(self, x): 60 | bsz = x.shape[0] 61 | x = x.squeeze() 62 | label = torch.zeros([bsz]).cuda().long() 63 | loss = self.criterion(x, label) 64 | return loss 65 | 66 | class MemoryMoCo(nn.Module): 67 | """Fixed-size queue with momentum encoder""" 68 | # T = 0.2 achieve best result? 69 | def __init__(self, inputSize, outputSize, K, T=0.07, use_softmax=False): 70 | super(MemoryMoCo, self).__init__() 71 | self.outputSize = outputSize 72 | self.inputSize = inputSize 73 | self.queueSize = K 74 | self.T = T 75 | self.index = 0 76 | self.use_softmax = use_softmax 77 | 78 | self.register_buffer('params', torch.tensor([-1])) 79 | stdv = 1. 
/ math.sqrt(inputSize / 3) 80 | self.register_buffer('memory', torch.rand(self.queueSize, inputSize).mul_(2 * stdv).add_(-stdv)) 81 | # self.register_buffer('spatial_memory', torch.rand(self.queueSize, inputSize).mul_(2 * stdv).add_(-stdv)) 82 | print('using queue shape: ({},{})'.format(self.queueSize, inputSize)) 83 | 84 | def forward(self, q, k, n): 85 | # n, sn, 86 | batchSize = q.shape[0] 87 | k = k.detach() 88 | 89 | Z = self.params[0].item() 90 | 91 | # pos logit 92 | l_pos = torch.bmm(q.view(batchSize, 1, -1), k.view(batchSize, -1, 1)) 93 | l_pos = l_pos.view(batchSize, 1) 94 | 95 | # # neg logit 96 | # # queue = self.memory_bank.get_queue(self.queueSize, indexs) 97 | queue = self.memory.clone() 98 | l_neg = torch.mm(queue.detach(), q.transpose(1, 0)) 99 | l_neg = l_neg.transpose(0, 1) 100 | #out = torch.cat((l_pos, l_neg), dim=1) 101 | 102 | # other negative 103 | l_neg_2 = torch.bmm(q.view(batchSize, 1, -1), n.view(batchSize, -1, 1)) 104 | l_neg_2 = l_neg_2.view(batchSize, 1) 105 | # 106 | # strong negative 107 | # l_s_neg = torch.bmm(q.view(batchSize, 1, -1), sn.view(batchSize, -1, 1)) 108 | # l_s_neg = l_s_neg.view(batchSize, 1) 109 | 110 | out = torch.cat((l_pos, l_neg, l_neg_2), dim=1) 111 | # out = torch.cat((l_pos, l_neg, l_neg_2, l_s_neg), dim=1) 112 | 113 | if self.use_softmax: 114 | out = torch.div(out, self.T) 115 | out = out.squeeze().contiguous() 116 | else: 117 | out = torch.exp(torch.div(out, self.T)) 118 | if Z < 0: 119 | self.params[0] = out.mean() * self.outputSize 120 | Z = self.params[0].clone().detach().item() 121 | print("normalization constant Z is set to {:.1f}".format(Z)) 122 | # compute the out 123 | out = torch.div(out, Z).squeeze().contiguous() 124 | 125 | # label = torch.zeros([batchSize]).cuda().long() 126 | # loss = [] 127 | # for i in range(batchSize): 128 | # loss.append(self.criterion(out[i].unsqueeze(0), label[i].unsqueeze(0))) 129 | # print(loss) 130 | # self.memory_bank.batch_set(indexs, k, loss) 131 | # self.memory = self.memory_bank.update_queue(self.memory) 132 | # print(self.memory_bank.link) 133 | # update memory 134 | with torch.no_grad(): 135 | out_ids = torch.arange(batchSize).cuda() 136 | out_ids += self.index 137 | out_ids = torch.fmod(out_ids, self.queueSize) # 1 fmod 1.5 = 1 2 fmod 1.5 = 0.5 138 | out_ids = out_ids.long() 139 | self.memory.index_copy_(0, out_ids, k) 140 | self.index = (self.index + batchSize) % self.queueSize 141 | # add for spatial memory 142 | 143 | return out 144 | 145 | 146 | class FineGrainedLoss(nn.Module): 147 | def __init__(self, temperature=0.05): 148 | super().__init__() 149 | self.criterion = NormSoftmaxLoss(temperature) 150 | 151 | def forward(self, vid_feats, text_feats, bboxs, object_token_len, real_len): 152 | # find the patch that contain in bboxes 153 | loss = None 154 | bboxs[:, :4] = bboxs[:, :4] * 16 155 | bboxs[:, :2] = torch.round(bboxs[:, :2]) 156 | bboxs[:, 2:4] = torch.ceil(bboxs[:, 2:4]) 157 | # for each sample 158 | # print(vid_feats.size(), text_feats.size()) # 128 x 196 x 256, 128 x 14 x 256 159 | 160 | # step1: for each bbox, get corresponding features in tensor [B, 10, 256] 161 | for index, bbox in enumerate(bboxs): 162 | patch_indexs = np.zeros(16*16) 163 | for i in range(16): 164 | for j in range(16): 165 | if i > bbox[:, 0] and i < bbox[:, 2] and j > bbox[:, 1] and j < bbox[:, 3]: 166 | patch_indexs[:, i*16+j] = 1 167 | # select patch features according to indexs 168 | vid_feats_related = vid_feats[:, patch_indexs] 169 | vid_feat = torch.mean(vid_feats_related, dim=1) 170 | # shared 
proj head ? 171 | 172 | # step2: for text, compute the corresponding text features in tensor [B, 10, 256] 173 | # select text_feat of given bbox/ object_tokens 174 | text_feat = text_feats[:, index] 175 | # step3: compute intra_sample_loss and inter_sample_loss 176 | if loss is None: 177 | loss = self.criterion(sim_matrix(text_feat, vid_feat)) 178 | else: 179 | loss += self.criterion(sim_matrix(text_feat, vid_feat)) 180 | return loss -------------------------------------------------------------------------------- /OATrans/model/prompt_learner.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from clip import clip 4 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 5 | 6 | _tokenizer = _Tokenizer() 7 | 8 | 9 | class TextEncoder(nn.Module): 10 | def __init__(self, clip_model): 11 | super().__init__() 12 | self.transformer = clip_model.transformer 13 | self.positional_embedding = clip_model.positional_embedding 14 | self.ln_final = clip_model.ln_final 15 | self.text_projection = clip_model.text_projection 16 | self.dtype = clip_model.dtype 17 | 18 | def forward(self, prompts, tokenized_prompts): 19 | x = prompts + self.positional_embedding.type(self.dtype) 20 | x = x.permute(1, 0, 2) # NLD -> LND 21 | x = self.transformer(x) 22 | x = x.permute(1, 0, 2) # LND -> NLD 23 | x = self.ln_final(x).type(self.dtype) 24 | 25 | # x.shape = [batch_size, n_ctx, transformer.width] 26 | # take features from the eot embedding (eot_token is the highest number in each sequence) 27 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 28 | 29 | return x 30 | 31 | 32 | class PromptLearner(nn.Module): 33 | def __init__(self, clip_model): 34 | super().__init__() 35 | n_ctx = 8 36 | ctx_init = False 37 | CSC = False # if class specific prompt 38 | dtype = clip_model.dtype 39 | ctx_dim = clip_model.ln_final.weight.shape[0] 40 | 41 | if ctx_init: 42 | # use given words to initialize context vectors 43 | ctx_init = ctx_init.replace("_", " ") 44 | n_ctx = len(ctx_init.split(" ")) 45 | prompt = clip.tokenize(ctx_init) 46 | with torch.no_grad(): 47 | embedding = clip_model.token_embedding(prompt).type(dtype) 48 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :] 49 | self.prompt_prefix = ctx_init 50 | 51 | else: 52 | # random initialization 53 | if CSC: 54 | print("Initializing class-specific contexts") 55 | ctx_vectors = torch.empty(1, n_ctx, ctx_dim, dtype=dtype) 56 | else: 57 | print("Initializing a generic context") 58 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 59 | nn.init.normal_(ctx_vectors, std=0.02) 60 | self.prompt_prefix = " ".join(["X"] * n_ctx) 61 | 62 | print(f'Initial context: "{prompt_prefix}"') 63 | print(f"Number of context words (tokens): {n_ctx}") 64 | 65 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized 66 | self.n_cls = 1 67 | self.n_ctx = n_ctx 68 | self.tokenized_prompts = None # torch.Tensor 69 | self.class_token_position = "end" 70 | self.clip_model = clip_model 71 | 72 | def forward(self, cls_name): 73 | ctx = self.ctx 74 | if ctx.dim() == 2: 75 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 76 | prompts = [self.prompt_prefix + " " + cls_name] 77 | self.tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) 78 | with torch.no_grad(): 79 | embedding = self.clip_model.token_embedding(tokenized_prompts).type(dtype) 80 | 81 | prefix = embedding[:, :1, :] 82 | suffix = embedding[:, 1 + n_ctx :, :] 83 | 84 | if 
self.class_token_position == "end": 85 | prompts = torch.cat( 86 | [ 87 | prefix, # (n_cls, 1, dim) 88 | ctx, # (n_cls, n_ctx, dim) 89 | suffix, # (n_cls, *, dim) 90 | ], 91 | dim=1, 92 | ) 93 | else: 94 | raise ValueError 95 | 96 | return prompts -------------------------------------------------------------------------------- /OATrans/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--data_path', default='./data/', 5 | help='path to datasets') 6 | parser.add_argument('--data_name', default='precomp', 7 | help='{coco,f30k}_precomp') 8 | parser.add_argument('--vocab_path', default='./vocab/', 9 | help='Path to saved vocabulary json files.') 10 | parser.add_argument('--margin', default=0.2, type=float, 11 | help='Rank loss margin.') 12 | parser.add_argument('--num_epochs', default=30, type=int, 13 | help='Number of training epochs.') 14 | parser.add_argument('--batch_size', default=128, type=int, 15 | help='Size of a training mini-batch.') 16 | parser.add_argument('--word_dim', default=300, type=int, 17 | help='Dimensionality of the word embedding.') 18 | parser.add_argument('--embed_size', default=1024, type=int, 19 | help='Dimensionality of the joint embedding.') 20 | parser.add_argument('--grad_clip', default=2., type=float, 21 | help='Gradient clipping threshold.') 22 | parser.add_argument('--num_layers', default=1, type=int, 23 | help='Number of GRU layers.') 24 | parser.add_argument('--learning_rate', default=.0002, type=float, 25 | help='Initial learning rate.') 26 | parser.add_argument('--lr_update', default=15, type=int, 27 | help='Number of epochs to update the learning rate.') 28 | parser.add_argument('--workers', default=10, type=int, 29 | help='Number of data loader workers.') 30 | parser.add_argument('--log_step', default=10, type=int, 31 | help='Number of steps to print and record the log.') 32 | parser.add_argument('--val_step', default=500, type=int, 33 | help='Number of steps to run validation.') 34 | parser.add_argument('--logger_name', default='./runs/runX/log', 35 | help='Path to save Tensorboard log.') 36 | parser.add_argument('--model_name', default='./runs/runX/checkpoint', 37 | help='Path to save the model.') 38 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 39 | help='path to latest checkpoint (default: none)') 40 | parser.add_argument('--max_violation', action='store_true', 41 | help='Use max instead of sum in the rank loss.') 42 | parser.add_argument('--img_dim', default=2048, type=int, 43 | help='Dimensionality of the image embedding.') 44 | parser.add_argument('--no_imgnorm', action='store_true', 45 | help='Do not normalize the image embeddings.') 46 | parser.add_argument('--no_txtnorm', action='store_true', 47 | help='Do not normalize the text embeddings.') 48 | parser.add_argument('--raw_feature_norm', default="clipped_l2norm", 49 | help='clipped_l2norm|l2norm|clipped_l1norm|l1norm|no_norm|softmax') 50 | parser.add_argument('--agg_func', default="LogSumExp", 51 | help='LogSumExp|Mean|Max|Sum') 52 | parser.add_argument('--cross_attn', default="t2i", 53 | help='t2i|i2t') 54 | parser.add_argument('--precomp_enc_type', default="basic", 55 | help='basic|weight_norm') 56 | parser.add_argument('--bi_gru', action='store_true', 57 | help='Use bidirectional GRU.') 58 | parser.add_argument('--lambda_lse', default=6., type=float, 59 | help='LogSumExp temp.') 60 | parser.add_argument('--lambda_softmax', default=9., 
type=float, 61 | help='Attention softmax temperature.') 62 | opt = parser.parse_args() 63 | print(opt) 64 | -------------------------------------------------------------------------------- /OATrans/parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce 5 | from operator import getitem 6 | from datetime import datetime 7 | from OATrans.logger import setup_logging 8 | from utils import read_json, write_json 9 | import time 10 | import inspect 11 | 12 | 13 | class ConfigParser: 14 | def __init__(self, args, options='', timestamp=True, test=False): 15 | # parse default and custom cli options 16 | for opt in options: 17 | args.add_argument(*opt.flags, default=None, type=opt.type) 18 | args = args.parse_args() 19 | self.args = args 20 | if args.device: 21 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 22 | if args.resume is None: 23 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 24 | assert args.config is not None, msg_no_cfg 25 | self.cfg_fname = Path(args.config) 26 | config = read_json(self.cfg_fname) 27 | self.resume = None 28 | else: 29 | self.resume = Path(args.resume) 30 | resume_cfg_fname = self.resume.parent / 'config.json' 31 | config = read_json(resume_cfg_fname) 32 | if args.config is not None: 33 | config.update(read_json(Path(args.config))) 34 | 35 | # load config file and apply custom cli options 36 | self._config = _update_config(config, options, args) 37 | 38 | # set save_dir where trained model and log will be saved. 39 | save_dir = Path(self.config['trainer']['save_dir']) 40 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else '' 41 | 42 | exper_name = self.config['name'] 43 | self._save_dir = save_dir / 'models' / exper_name / timestamp 44 | self._web_log_dir = save_dir / 'web' / exper_name / timestamp 45 | self._log_dir = save_dir / 'log' / exper_name / timestamp 46 | 47 | if not test: 48 | self.save_dir.mkdir(parents=True, exist_ok=True) 49 | self.log_dir.mkdir(parents=True, exist_ok=True) 50 | 51 | # if set, remove all previous experiments with the current config 52 | if vars(args).get("purge_exp_dir", False): 53 | for dirpath in (self._save_dir, self._log_dir, self._web_log_dir): 54 | config_dir = dirpath.parent 55 | existing = list(config_dir.glob("*")) 56 | print(f"purging {len(existing)} directories from config_dir...") 57 | tic = time.time() 58 | os.system(f"rm -rf {config_dir}") 59 | print(f"Finished purge in {time.time() - tic:.3f}s") 60 | 61 | # save updated config file to the checkpoint dir 62 | if not test: 63 | write_json(self.config, self.save_dir / 'config.json') 64 | 65 | # configure logging module 66 | setup_logging(self.log_dir) 67 | self.log_levels = { 68 | 0: logging.WARNING, 69 | 1: logging.INFO, 70 | 2: logging.DEBUG 71 | } 72 | 73 | def initialize(self, name, module, *args, index=None, **kwargs): 74 | """ 75 | finds a function handle with the name given as 'type' in config, and returns the 76 | instance initialized with corresponding keyword args given as 'args'. 
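        For example (the config keys here are illustrative), a section such as
            "loss": {"type": "NormSoftmaxLoss", "args": {"temperature": 0.05}}
        makes config.initialize('loss', module_loss) return module_loss.NormSoftmaxLoss(temperature=0.05);
        any extra **kwargs passed to initialize() are merged into the constructor arguments.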
77 | """ 78 | if index is None: 79 | module_name = self[name]['type'] 80 | module_args = dict(self[name]['args']) 81 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 82 | module_args.update(kwargs) 83 | else: 84 | module_name = self[name][index]['type'] 85 | module_args = dict(self[name][index]['args']) 86 | 87 | # if parameter not in config subdict, then check if it's in global config. 88 | signature = inspect.signature(getattr(module, module_name).__init__) 89 | print(module_name) 90 | for param in signature.parameters.keys(): 91 | if param not in module_args and param in self.config: 92 | module_args[param] = self[param] 93 | if module_name == 'FrozenInTime' and param == 'args': 94 | module_args[param] = self.args 95 | if module_name == 'MultiDistTextObjectVideoDataLoader' and param == 'args': 96 | module_args[param] = self.args 97 | 98 | return getattr(module, module_name)(*args, **module_args) 99 | 100 | def __getitem__(self, name): 101 | return self.config[name] 102 | 103 | def get_logger(self, name, verbosity=2): 104 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, 105 | self.log_levels.keys()) 106 | assert verbosity in self.log_levels, msg_verbosity 107 | logger = logging.getLogger(name) 108 | logger.setLevel(self.log_levels[verbosity]) 109 | return logger 110 | 111 | # setting read-only attributes 112 | @property 113 | def config(self): 114 | return self._config 115 | 116 | @property 117 | def save_dir(self): 118 | return self._save_dir 119 | 120 | @property 121 | def log_dir(self): 122 | return self._log_dir 123 | 124 | 125 | # helper functions used to update config dict with custom cli options 126 | def _update_config(config, options, args): 127 | for opt in options: 128 | value = getattr(args, _get_opt_name(opt.flags)) 129 | if value is not None: 130 | _set_by_path(config, opt.target, value) 131 | return config 132 | 133 | 134 | def _get_opt_name(flags): 135 | for flg in flags: 136 | if flg.startswith('--'): 137 | return flg.replace('--', '') 138 | return flags[0].replace('--', '') 139 | 140 | 141 | def _set_by_path(tree, keys, value): 142 | """Set a value in a nested object in tree by sequence of keys.""" 143 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 144 | 145 | 146 | def _get_by_path(tree, keys): 147 | """Access a nested object in tree by sequence of keys.""" 148 | return reduce(getitem, keys, tree) 149 | -------------------------------------------------------------------------------- /OATrans/parse_config_dist_multi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce 5 | from operator import getitem 6 | from datetime import datetime 7 | from OATrans.logger import setup_logging 8 | from utils import read_json, write_json 9 | import time 10 | import inspect 11 | 12 | 13 | class ConfigParser: 14 | def __init__(self, args, options='', timestamp=True, test=False): 15 | # parse default and custom cli options 16 | for opt in options: 17 | args.add_argument(*opt.flags, default=None, type=opt.type) 18 | args = args.parse_args() 19 | self.args = args 20 | if args.device: 21 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 22 | if args.resume is None: 23 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 
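            # For example (paths illustrative): train.py builds this ConfigParser, so a typical launch is
            #   python train.py -c configs/ft/msrvtt/fine_tune/normal_1_cl.json -d 0
            # while -r path/to/checkpoint.pth resumes a run and reloads the config.json saved next to that checkpoint.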
24 | assert args.config is not None, msg_no_cfg 25 | self.cfg_fname = Path(args.config) 26 | config = read_json(self.cfg_fname) 27 | self.resume = None 28 | else: 29 | self.resume = Path(args.resume) 30 | resume_cfg_fname = self.resume.parent / 'config.json' 31 | config = read_json(resume_cfg_fname) 32 | if args.config is not None: 33 | config.update(read_json(Path(args.config))) 34 | 35 | # load config file and apply custom cli options 36 | self._config = _update_config(config, options, args) 37 | 38 | # set save_dir where trained model and log will be saved. 39 | save_dir = Path(self.config['trainer']['save_dir']) 40 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else '' 41 | 42 | exper_name = self.config['name'] 43 | self._save_dir = save_dir / 'models' / exper_name / timestamp 44 | self._web_log_dir = save_dir / 'web' / exper_name / timestamp 45 | self._log_dir = save_dir / 'log' / exper_name / timestamp 46 | 47 | if not test: 48 | self.save_dir.mkdir(parents=True, exist_ok=True) 49 | self.log_dir.mkdir(parents=True, exist_ok=True) 50 | 51 | # if set, remove all previous experiments with the current config 52 | if vars(args).get("purge_exp_dir", False): 53 | for dirpath in (self._save_dir, self._log_dir, self._web_log_dir): 54 | config_dir = dirpath.parent 55 | existing = list(config_dir.glob("*")) 56 | print(f"purging {len(existing)} directories from config_dir...") 57 | tic = time.time() 58 | os.system(f"rm -rf {config_dir}") 59 | print(f"Finished purge in {time.time() - tic:.3f}s") 60 | 61 | # save updated config file to the checkpoint dir 62 | if not test: 63 | write_json(self.config, self.save_dir / 'config.json') 64 | 65 | # configure logging module 66 | setup_logging(self.log_dir) 67 | self.log_levels = { 68 | 0: logging.WARNING, 69 | 1: logging.INFO, 70 | 2: logging.DEBUG 71 | } 72 | 73 | def initialize(self, name, module, *args, index=None, **kwargs): 74 | """ 75 | finds a function handle with the name given as 'type' in config, and returns the 76 | instance initialized with corresponding keyword args given as 'args'. 77 | """ 78 | if index is None: 79 | module_name = self[name]['type'] 80 | module_args = dict(self[name]['args']) 81 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 82 | module_args.update(kwargs) 83 | else: 84 | module_name = self[name][index]['type'] 85 | module_args = dict(self[name][index]['args']) 86 | 87 | # if parameter not in config subdict, then check if it's in global config. 88 | signature = inspect.signature(getattr(module, module_name).__init__) 89 | print(module_name) 90 | for param in signature.parameters.keys(): 91 | if param not in module_args and param in self.config: 92 | module_args[param] = self[param] 93 | if module_name == 'FrozenInTime' and param == 'args': 94 | module_args[param] = self.args 95 | if module_name == 'MultiDistTextObjectVideoDataLoader' and param == 'args': 96 | module_args[param] = self.args 97 | if module_name == 'TextObjectVideoDataLoader' and param == 'args': 98 | module_args[param] = self.args 99 | 100 | return getattr(module, module_name)(*args, **module_args) 101 | 102 | def __getitem__(self, name): 103 | return self.config[name] 104 | 105 | def get_logger(self, name, verbosity=2): 106 | msg_verbosity = 'verbosity option {} is invalid. 
Valid options are {}.'.format(verbosity, 107 | self.log_levels.keys()) 108 | assert verbosity in self.log_levels, msg_verbosity 109 | logger = logging.getLogger(name) 110 | logger.setLevel(self.log_levels[verbosity]) 111 | return logger 112 | 113 | # setting read-only attributes 114 | @property 115 | def config(self): 116 | return self._config 117 | 118 | @property 119 | def save_dir(self): 120 | return self._save_dir 121 | 122 | @property 123 | def log_dir(self): 124 | return self._log_dir 125 | 126 | 127 | # helper functions used to update config dict with custom cli options 128 | def _update_config(config, options, args): 129 | for opt in options: 130 | value = getattr(args, _get_opt_name(opt.flags)) 131 | if value is not None: 132 | _set_by_path(config, opt.target, value) 133 | return config 134 | 135 | 136 | def _get_opt_name(flags): 137 | for flg in flags: 138 | if flg.startswith('--'): 139 | return flg.replace('--', '') 140 | return flags[0].replace('--', '') 141 | 142 | 143 | def _set_by_path(tree, keys, value): 144 | """Set a value in a nested object in tree by sequence of keys.""" 145 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 146 | 147 | 148 | def _get_by_path(tree, keys): 149 | """Access a nested object in tree by sequence of keys.""" 150 | return reduce(getitem, keys, tree) 151 | -------------------------------------------------------------------------------- /OATrans/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | from OATrans.data_loader import data_loader as module_data 4 | from OATrans import model as module_loss, model as module_metric, model as module_arch 5 | import utils.visualizer as module_vis 6 | from utils.util import replace_nested_dict_item 7 | from parse_config_dist_multi import ConfigParser 8 | from trainer.trainer import Trainer 9 | from sacred import Experiment 10 | from neptunecontrib.monitoring.sacred import NeptuneObserver 11 | import transformers 12 | import os 13 | 14 | ex = Experiment('train') 15 | 16 | 17 | @ex.main 18 | def run(): 19 | logger = config.get_logger('train') 20 | os.environ['TOKENIZERS_PARALLELISM'] = "false" 21 | # TODO: improve Create identity (do nothing) visualiser? 22 | if config['visualizer']['type'] != "": 23 | visualizer = config.initialize( 24 | name='visualizer', 25 | module=module_vis, 26 | exp_name=config['name'], 27 | web_dir=config._web_log_dir 28 | ) 29 | else: 30 | visualizer = None 31 | # pdb.set_trace() 32 | # build tokenizer 33 | tokenizer = transformers.AutoTokenizer.from_pretrained(config['arch']['args']['text_params']['model'], 34 | TOKENIZERS_PARALLELISM=False) 35 | 36 | # setup data_loader instances 37 | data_loader, valid_data_loader = init_dataloaders(config, module_data) 38 | print('Train dataset: ', [x.n_samples for x in data_loader], ' samples') 39 | print('Val dataset: ', [x.n_samples for x in valid_data_loader], ' samples') 40 | # build model architecture, then print to console 41 | model = config.initialize('arch', module_arch) 42 | logger.info(model) 43 | 44 | # get function handles of loss and metrics 45 | loss = config.initialize(name="loss", module=module_loss) 46 | metrics = [getattr(module_metric, met) for met in config['metrics']] 47 | # build optimizer, learning rate scheduler. 
delete every lines containing lr_scheduler for disabling scheduler 48 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 49 | optimizer = config.initialize('optimizer', transformers, trainable_params) 50 | lr_scheduler = None 51 | if 'lr_scheduler' in config._config: 52 | if hasattr(transformers, config._config['lr_scheduler']['type']): 53 | lr_scheduler = config.initialize('lr_scheduler', transformers, optimizer) 54 | else: 55 | print('lr scheduler not found') 56 | if config['trainer']['neptune']: 57 | writer = ex 58 | else: 59 | writer = None 60 | trainer = Trainer(model, loss, metrics, optimizer, 61 | config=config, 62 | data_loader=data_loader, 63 | valid_data_loader=valid_data_loader, 64 | lr_scheduler=lr_scheduler, 65 | visualizer=visualizer, 66 | writer=writer, 67 | tokenizer=tokenizer, 68 | max_samples_per_epoch=config['trainer']['max_samples_per_epoch']) 69 | trainer.train() 70 | 71 | 72 | def init_dataloaders(config, module_data): 73 | """ 74 | We need a way to change split from 'train' to 'val'. 75 | """ 76 | if "type" in config["data_loader"] and "args" in config["data_loader"]: 77 | # then its a single dataloader 78 | data_loader = [config.initialize("data_loader", module_data)] 79 | config['data_loader']['args'] = replace_nested_dict_item(config['data_loader']['args'], 'split', 'val') 80 | valid_data_loader = [config.initialize("data_loader", module_data)] 81 | elif isinstance(config["data_loader"], list): 82 | data_loader = [config.initialize('data_loader', module_data, index=idx) for idx in 83 | range(len(config['data_loader']))] 84 | new_cfg_li = [] 85 | for dl_cfg in config['data_loader']: 86 | dl_cfg['args'] = replace_nested_dict_item(dl_cfg['args'], 'split', 'val') 87 | new_cfg_li.append(dl_cfg) 88 | config._config['data_loader'] = new_cfg_li 89 | valid_data_loader = [config.initialize('data_loader', module_data, index=idx) for idx in 90 | range(len(config['data_loader']))] 91 | else: 92 | raise ValueError("Check data_loader config, not correct format.") 93 | 94 | return data_loader, valid_data_loader 95 | 96 | 97 | if __name__ == '__main__': 98 | args = argparse.ArgumentParser(description='PyTorch Template') 99 | args.add_argument('-c', '--config', default=None, type=str, 100 | help='config file path (default: None)') 101 | args.add_argument('-r', '--resume', default=None, type=str, 102 | help='path to latest checkpoint (default: None)') 103 | args.add_argument('-d', '--device', default=None, type=str, 104 | help='indices of GPUs to enable (default: all)') 105 | args.add_argument('-o', '--observe', action='store_true', 106 | help='Whether to observe (neptune)') 107 | # custom cli options to modify configuration from default values given in json file. 
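    # Each CustomArgs entry below maps a CLI flag onto a nested config key via _update_config/_set_by_path
    # when ConfigParser(args, options) is constructed, e.g. (values illustrative):
    #   python train.py -c configs/ft/msrvtt/fine_tune/normal_1_cl.json --lr 1e-5 --bs 16
    # overrides config['optimizer']['args']['lr'] and config['data_loader']['args']['batch_size'].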
108 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 109 | options = [ 110 | CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')), 111 | CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')), 112 | ] 113 | config = ConfigParser(args, options) 114 | ex.add_config(config._config) 115 | 116 | if config['trainer']['neptune']: 117 | # delete this error if you have added your own neptune credentials neptune.ai 118 | # raise ValueError('Neptune credentials not set up yet.') 119 | ex.observers.append(NeptuneObserver( 120 | api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJkZTg4NGQ4YS01NmRlLTQwMzEtYjc2NS1mYjY3MzRiMWNjZTYifQ==', 121 | project_name='awinyimgprocess/Frozen')) 122 | ex.run() 123 | else: 124 | run() 125 | -------------------------------------------------------------------------------- /OATrans/train_dist_region_mem.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import torch 4 | import data_loader.data_loader as module_data 5 | import model.loss as module_loss 6 | import model.metric as module_metric 7 | import model.oa_model_region_mem as module_arch 8 | import utils.visualizer as module_vis 9 | from utils.util import replace_nested_dict_item 10 | from parse_config_dist_multi import ConfigParser 11 | from trainer.trainer_region_mem import Multi_Trainer_dist 12 | from sacred import Experiment 13 | from neptunecontrib.monitoring.sacred import NeptuneObserver 14 | import transformers 15 | import os 16 | 17 | ex = Experiment('train') 18 | 19 | @ex.main 20 | def run(): 21 | logger = config.get_logger('train') 22 | os.environ['TOKENIZERS_PARALLELISM'] = "false" 23 | os.environ['TRANSFORMERS_OFFLINE'] = "1" 24 | # TODO: improve Create identity (do nothing) visualiser? 25 | if config['visualizer']['type'] != "": 26 | visualizer = config.initialize( 27 | name='visualizer', 28 | module=module_vis, 29 | exp_name=config['name'], 30 | web_dir=config._web_log_dir 31 | ) 32 | else: 33 | visualizer = None 34 | torch.cuda.set_device(args.local_rank) 35 | torch.distributed.init_process_group(backend='nccl', 36 | init_method='tcp://{}:{}'.format( 37 | args.master_address, args.master_port), 38 | rank=args.rank, world_size=args.world_size) 39 | device = torch.device(f'cuda:{args.local_rank}') 40 | print('world_size', args.world_size, flush=True) 41 | print('local_rank: ', args.local_rank, flush=True) 42 | # build tokenizer 43 | tokenizer = transformers.AutoTokenizer.from_pretrained(config['arch']['args']['text_params']['model'], 44 | TOKENIZERS_PARALLELISM=False) 45 | 46 | # setup data_loader instances 47 | data_loader, valid_data_loader = init_dataloaders(config, module_data) 48 | print('Train dataset: ', [x.n_samples for x in data_loader], ' samples') 49 | print('Val dataset: ', [x.n_samples for x in valid_data_loader], ' samples') 50 | # build model architecture, then print to console 51 | model = config.initialize('arch', module_arch) 52 | if args.local_rank == 0: 53 | logger.info(model) 54 | 55 | # get function handles of loss and metrics 56 | loss = config.initialize(name="loss", module=module_loss) 57 | metrics = [getattr(module_metric, met) for met in config['metrics']] 58 | # build optimizer, learning rate scheduler. 
delete every lines containing lr_scheduler for disabling scheduler 59 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 60 | optimizer = config.initialize('optimizer', transformers, trainable_params) 61 | lr_scheduler = None 62 | if 'lr_scheduler' in config._config: 63 | if hasattr(transformers, config._config['lr_scheduler']['type']): 64 | lr_scheduler = config.initialize('lr_scheduler', transformers, optimizer) 65 | else: 66 | print('lr scheduler not found') 67 | if config['trainer']['neptune']: 68 | writer = ex 69 | else: 70 | writer = None 71 | trainer = Multi_Trainer_dist(args, model, loss, metrics, optimizer, 72 | config=config, 73 | data_loader=data_loader, 74 | valid_data_loader=valid_data_loader, 75 | lr_scheduler=lr_scheduler, 76 | visualizer=visualizer, 77 | writer=writer, 78 | tokenizer=tokenizer, 79 | max_samples_per_epoch=config['trainer']['max_samples_per_epoch']) 80 | trainer.train() 81 | 82 | 83 | def init_dataloaders(config, module_data): 84 | """ 85 | We need a way to change split from 'train' to 'val'. 86 | """ 87 | if "type" in config["data_loader"] and "args" in config["data_loader"]: 88 | # then its a single dataloader 89 | data_loader = [config.initialize("data_loader", module_data)] 90 | config['data_loader']['args'] = replace_nested_dict_item(config['data_loader']['args'], 'split', 'val') 91 | valid_data_loader = [config.initialize("data_loader", module_data)] 92 | elif isinstance(config["data_loader"], list): 93 | data_loader = [config.initialize('data_loader', module_data, index=idx) for idx in 94 | range(len(config['data_loader']))] 95 | new_cfg_li = [] 96 | for dl_cfg in config['data_loader']: 97 | dl_cfg['args'] = replace_nested_dict_item(dl_cfg['args'], 'split', 'val') 98 | new_cfg_li.append(dl_cfg) 99 | config._config['data_loader'] = new_cfg_li 100 | valid_data_loader = [config.initialize('data_loader', module_data, index=idx) for idx in 101 | range(len(config['data_loader']))] 102 | else: 103 | raise ValueError("Check data_loader config, not correct format.") 104 | 105 | return data_loader, valid_data_loader 106 | 107 | 108 | if __name__ == '__main__': 109 | args = argparse.ArgumentParser(description='PyTorch Template') 110 | args.add_argument('-c', '--config', default=None, type=str, 111 | help='config file path (default: None)') 112 | args.add_argument('-r', '--resume', default=None, type=str, 113 | help='path to latest checkpoint (default: None)') 114 | args.add_argument('-d', '--device', default=None, type=str, 115 | help='indices of GPUs to enable (default: all)') 116 | args.add_argument('-o', '--observe', action='store_true', 117 | help='Whether to observe (neptune)') 118 | args.add_argument('-l', '--launcher', choices=['none', 'pytorch'], default='none',help='job launcher') 119 | args.add_argument('-k', '--local_rank', type=int, default=0) 120 | 121 | master_address = os.environ['MASTER_ADDR'] 122 | master_port = int(os.environ['MASTER_PORT']) 123 | world_size = int(os.environ['WORLD_SIZE']) 124 | # world_size = int(torch.cuda.device_count()) 125 | rank = int(os.environ['RANK']) 126 | args.local_rank = int(os.environ['LOCAL_RANK']) 127 | 128 | if torch.cuda.device_count() > 1: 129 | print("Let's use", torch.cuda.device_count(), "GPUs!") 130 | 131 | args.add_argument('-ma', '--master_address', default=master_address) 132 | args.add_argument('-mp', '--master_port', type=int, default=master_port) 133 | args.add_argument('-ws', '--world_size', type=int, default=world_size) 134 | args.add_argument('-rk', '--rank', type=int, 
default=rank) 135 | args.add_argument('-lr1', '--learning_rate1', type=float, default=2e-4) 136 | args.add_argument('-sc', '--schedule', default=[60, 80]) 137 | 138 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 139 | options = [ 140 | CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')), 141 | CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')), 142 | ] 143 | config = ConfigParser(args, options) 144 | args = args.parse_args() 145 | ex.add_config(config._config) 146 | 147 | if config['trainer']['neptune']: 148 | # delete this error if you have added your own neptune credentials neptune.ai 149 | # raise ValueError('Neptune credentials not set up yet.') 150 | ex.observers.append(NeptuneObserver( 151 | api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJkZTg4NGQ4YS01NmRlLTQwMzEtYjc2NS1mYjY3MzRiMWNjZTYifQ==', 152 | project_name='awinyimgprocess/Frozen')) 153 | ex.run() 154 | else: 155 | run() 156 | -------------------------------------------------------------------------------- /OATrans/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer_dist import * 2 | -------------------------------------------------------------------------------- /OATrans/utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/.DS_Store -------------------------------------------------------------------------------- /OATrans/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | -------------------------------------------------------------------------------- /OATrans/utils/binary_classification_accuracy.py: -------------------------------------------------------------------------------- 1 | def get_accuracy(y_true, y_prob): 2 | assert y_true.ndim == 1 and y_true.size() == y_prob.size() 3 | y_prob = y_prob > 0.5 4 | return (y_true == y_prob).sum().item() / y_true.size(0) -------------------------------------------------------------------------------- /OATrans/utils/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import torch 3 | from torch import Tensor 4 | from typing import List, Tuple, Any, Optional 5 | from torchvision.transforms import functional_pil as F_pil 6 | from torchvision.transforms import functional_tensor as F_t 7 | from torchvision.transforms.functional import center_crop, crop 8 | 9 | def _get_image_size(img: Tensor) -> List[int]: 10 | """Returns image size as [w, h] 11 | """ 12 | if isinstance(img, torch.Tensor): 13 | return F_t._get_image_size(img) 14 | 15 | return F_pil._get_image_size(img) 16 | 17 | def center_plus_four_crops(img: Tensor, size: List[int], 18 | margin_h: int, margin_w: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: 19 | """Crop the given image into four tiled borders and the central crop. 
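    Returns a 5-tuple (tl, tr, bl, br, center): four border strips cut from the margins around the
    central crop, followed by the central crop of the requested size.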
20 | """ 21 | 22 | if isinstance(size, numbers.Number): 23 | size = (int(size), int(size)) 24 | elif isinstance(size, (tuple, list)) and len(size) == 1: 25 | size = (size[0], size[0]) 26 | 27 | if len(size) != 2: 28 | raise ValueError("Please provide only two dimensions (h, w) for size.") 29 | 30 | image_width, image_height = _get_image_size(img) 31 | 32 | crop_height, crop_width = size 33 | 34 | if crop_width > image_width or crop_height > image_height: 35 | msg = "Requested crop size {} is bigger than input size {}" 36 | raise ValueError(msg.format(size, (image_height, image_width))) 37 | 38 | if crop_width + margin_w > image_width: 39 | msg = "Requested margin size {} + input {} is bigger than input size {}" 40 | raise ValueError(msg.format((margin_h, margin_w), size, (image_height, image_width))) 41 | 42 | #vertical_border_height = image_height - crop_height 43 | #horizontal_border_height = image_width - crop_width 44 | 45 | #x1 = horizontal_border_height // 2 46 | x11 = (image_width - crop_width - 2 * margin_w) // 2 47 | x12 = x11 + margin_w 48 | x21 = x12 + crop_width 49 | x22 = x21 + margin_w 50 | 51 | y11 = (image_height - crop_height - 2 * margin_h) // 2 52 | y12 = y11 + margin_h 53 | y21 = y12 + crop_height 54 | y22 = y21 + margin_h 55 | 56 | tl = crop(img, y11, x11, margin_h, margin_w + crop_width) 57 | tr = crop(img, y11, x21, margin_h + crop_height, margin_w) 58 | bl = crop(img, y12, x11, margin_h + crop_height, margin_w) 59 | br = crop(img, y21, x12, margin_h, margin_w + crop_width) 60 | center = center_crop(img, [crop_height, crop_width]) 61 | 62 | return tl, tr, bl, br, center 63 | 64 | 65 | 66 | def center_plus_twohori_crops(img: Tensor, size: List[int], 67 | margin_w: int) -> Tuple[Tensor, Tensor, Tensor]: 68 | """Crop the given image into four tiled borders and the central crop. 
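    Returns a 3-tuple (left, right, center): two side strips of width margin_w flanking the central
    crop, plus the central crop of the requested size.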
69 | """ 70 | 71 | if isinstance(size, numbers.Number): 72 | size = (int(size), int(size)) 73 | elif isinstance(size, (tuple, list)) and len(size) == 1: 74 | size = (size[0], size[0]) 75 | 76 | if len(size) != 2: 77 | raise ValueError("Please provide only two dimensions (h, w) for size.") 78 | 79 | image_width, image_height = _get_image_size(img) 80 | 81 | crop_height, crop_width = size 82 | 83 | if crop_width > image_width or crop_height > image_height: 84 | msg = "Requested crop size {} is bigger than input size {}" 85 | raise ValueError(msg.format(size, (image_height, image_width))) 86 | 87 | if crop_width + margin_w > image_width : 88 | msg = "Requested margin size {} + input {} is bigger than input size {}" 89 | raise ValueError(msg.format((0, margin_w), size, (image_height, image_width))) 90 | 91 | # vertical_border_height = image_height - crop_height 92 | # horizontal_border_height = image_width - crop_width 93 | 94 | # x1 = horizontal_border_height // 2 95 | x11 = (image_width - crop_width - 2 * margin_w) // 2 96 | x12 = x11 + margin_w 97 | x21 = x12 + crop_width 98 | 99 | y11 = (image_height - crop_height) // 2 100 | 101 | left = crop(img, y11, x11, crop_height, margin_w) 102 | right = crop(img, y11, x21, crop_height, margin_w) 103 | center = center_crop(img, [crop_height, crop_width]) 104 | 105 | return left, right, center 106 | 107 | from torch import nn 108 | class TwoHoriCrop(nn.Module): 109 | def __init__(self, size, margin_w): 110 | super().__init__() 111 | self.size = size 112 | self.margin_w = margin_w 113 | 114 | def forward(self, x): 115 | return center_plus_twohori_crops(x, self.size, self.margin_w) 116 | 117 | if __name__ == "__main__": 118 | from PIL import Image 119 | 120 | img = Image.open('visualisations/guitar.png') 121 | crops = center_plus_four_crops(img, [336, 336], 112, 112) 122 | order = ['tl', 'tr', 'bl', 'br', 'center'] 123 | 124 | for idx, subimg in zip(order, crops): 125 | subimg.save(f'visualisations/guitar_{idx}.png') 126 | 127 | crops = center_plus_twohori_crops(img, [448, 448], 112) 128 | order = ['left', 'right', 'center2'] 129 | 130 | for idx, subimg in zip(order, crops): 131 | subimg.save(f'visualisations/guitar_{idx}.png') 132 | -------------------------------------------------------------------------------- /OATrans/utils/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br, video, source, attr 3 | from dominate.tags import span 4 | import os 5 | 6 | 7 | class HTML: 8 | """This HTML class allows us to save images and write texts into a single HTML file. 9 | 10 | It consists of functions such as (add a text header to the HTML file), 11 | (add a row of images to the HTML file), and (save the HTML to the disk). 12 | It is based on Python library 'dominate', a Python library for creating and 13 | manipulating HTML documents using a DOM API. 14 | """ 15 | 16 | def __init__(self, web_dir, title, refresh=0): 17 | """Initialize the HTML classes 18 | 19 | Parameters: 20 | web_dir (str) -- a directory that stores the webpage. 
HTML file will be 21 | created at /index.html; images will be saved at 0: 35 | with self.doc.head: 36 | meta(http_equiv="refresh", content=str(refresh)) 37 | 38 | def get_image_dir(self): 39 | """Return the directory that stores images""" 40 | return self.img_dir 41 | 42 | def add_header(self, text): 43 | """Insert a header to the HTML file 44 | 45 | Parameters: 46 | text (str) -- the header text 47 | """ 48 | with self.doc: 49 | h3(text) 50 | 51 | def add_videos(self, vids, txts, links, width=400, hidden_tag="hidden"): 52 | """add images to the HTML file 53 | 54 | Parameters: 55 | vids (str list) -- a list of image paths 56 | txts (str list) -- a list of image names shown on the website 57 | links (str list) -- a list of hyperref links; when you click an image, 58 | it will redirect you to a new page 59 | """ 60 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 61 | self.doc.add(self.t) 62 | colors = ["red", "blue", "gold", "salman"] 63 | with self.t: 64 | with tr(): 65 | for vid, txt, link in zip(vids, txts, links): 66 | td_style = "word-wrap: break-word; width:{}px".format(width) 67 | with td(style=td_style, halign="center", valign="top"): 68 | with p(): 69 | vid_path = str(vid) 70 | if vid_path == hidden_tag: 71 | p_style = "font-weight: bold; width:{}px;" 72 | p_style = p_style.format(width * 3) 73 | p("hidden video", style=p_style) 74 | else: 75 | with a(href=str(link)): 76 | with video(): 77 | attr(controls="controls") 78 | source(src=vid_path, type="video/mp4") 79 | br() 80 | rows = txt.split("
") 81 | for idx, row in enumerate(rows): 82 | color = colors[idx % len(colors)] 83 | bold_tag = "" 84 | if not row.startswith(bold_tag): 85 | s_style = "color:{};".format(color) 86 | else: 87 | s_style = "color:black; font-weight: bold;" 88 | row = row[len(bold_tag):] 89 | span(row, style=s_style) 90 | 91 | def add_images(self, ims, txts, links, width=400): 92 | """add images to the HTML file 93 | 94 | Parameters: 95 | ims (str list) -- a list of image paths 96 | txts (str list) -- a list of image names shown on the website 97 | links (str list) -- a list of hyperref links; when you click an image, 98 | it will redirect you to a new page 99 | """ 100 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 101 | self.doc.add(self.t) 102 | colors = ["red", "blue", "gold", "salman"] 103 | with self.t: 104 | with tr(): 105 | for im, txt, link in zip(ims, txts, links): 106 | td_style = "word-wrap: break-word;" 107 | with td(style=td_style, halign="center", valign="top"): 108 | with p(): 109 | with a(href=link): 110 | img( 111 | style="width:%dpx" % width, 112 | src=im, 113 | ) 114 | br() 115 | rows = txt.split("
") 116 | for idx, row in enumerate(rows): 117 | color = colors[idx % len(colors)] 118 | bold_tag = "" 119 | if not row.startswith(bold_tag): 120 | s_style = "color:{};".format(color) 121 | else: 122 | s_style = "color:black; font-weight: bold;" 123 | row = row[len(bold_tag):] 124 | span(row, style=s_style) 125 | 126 | def save(self): 127 | """save the current content to the HMTL file""" 128 | html_file = "%s/index.html" % self.web_dir 129 | f = open(html_file, "wt") 130 | f.write(self.doc.render()) 131 | f.close() 132 | 133 | 134 | if __name__ == "__main__": # we show an example usage here. 135 | html = HTML("web/", "test_html") 136 | html.add_header("hello world") 137 | 138 | ims, txts, links = [], [], [] 139 | for n in range(4): 140 | ims.append("image_%d.png" % n) 141 | txts.append("text_%d" % n) 142 | links.append("image_%d.png" % n) 143 | html.add_images(ims, txts, links) 144 | html.save() -------------------------------------------------------------------------------- /OATrans/utils/objects_vocab_token_len: -------------------------------------------------------------------------------- 1 | [2, 1, 1, 3, 1, 3, 2, 2, 1, 1, 2, 2, 3, 1, 1, 1, 2, 1, 1, 3, 1, 2, 1, 2, 1, 1, 1, 3, 2, 1, 1, 2, 2, 1, 2, 3, 1, 4, 2, 1, 1, 2, 1, 1, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 2, 1, 3, 2, 2, 2, 1, 2, 1, 1, 1, 2, 2, 2, 3, 1, 1, 1, 1, 1, 2, 1, 1, 2, 2, 1, 3, 2, 3, 1, 2, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 3, 1, 1, 3, 3, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 2, 1, 2, 1, 5, 1, 1, 2, 1, 2, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 1, 3, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 1, 2, 1, 2, 1, 3, 3, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 1, 3, 2, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 3, 2, 3, 1, 1, 1, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 3, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 1, 3, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 2, 1, 3, 2, 2, 1, 1, 1, 1, 1, 3, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 1, 2, 2, 1, 1, 1, 2, 1, 4, 2, 1, 2, 1, 2, 1, 1, 1, 3, 2, 1, 1, 3, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 3, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 3, 1, 1, 3, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 3, 1, 2, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 1, 2, 2, 2, 1, 2, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 4, 3, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 3, 1, 1, 1, 2, 1, 3, 2, 1, 1, 1, 1, 2, 2, 3, 2, 1, 2, 1, 3, 1, 2, 1, 2, 2, 2, 3, 1, 1, 1, 1, 1, 2, 2, 3, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 3, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 3, 1, 1, 4, 2, 1, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 3, 1, 1, 1, 1, 3, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 2, 2, 3, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 3, 2, 1, 
3, 1, 1, 1, 1, 2, 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 3, 2, 3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 2, 1, 1, 2, 2, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 5, 1, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 4, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 2, 3, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 2, 1, 2, 3, 1, 1, 1, 1, 2, 3, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 3, 2, 2, 1, 1, 1, 1, 3, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 1, 2, 1, 2, 2, 1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 2, 2, 1, 4, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 3, 2, 1, 2, 1, 2, 2, 1, 1, 3, 1, 3, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 3, 2, 1, 1, 2, 1, 1, 1, 3, 1, 2, 1, 1, 2, 1, 3, 1, 3, 3, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 2, 3, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 3, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 3, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 3, 1, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 3, 1, 1, 1, 2, 1, 1, 3, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 2, 3, 1, 1, 1, 3, 2, 1, 1, 3, 1, 1, 1, 2, 1, 2, 2, 1, 1, 1, 1, 1, 3, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 3, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 3, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2] -------------------------------------------------------------------------------- /OATrans/utils/param_forzen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def forzen_param(model): 5 | for name, param in model.named_parameters(): 6 | if 'vid_proj' in name or 'txt_proj' in name: 7 | param.requires_grad = True 8 | else: 9 | param.requires_grad = False 10 | return True -------------------------------------------------------------------------------- /OATrans/utils/unit_test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/unit_test/__init__.py -------------------------------------------------------------------------------- /OATrans/utils/unit_test/distill_bert.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel 2 | import transformers 3 | 4 | text = "tree" 5 | 6 | tokenizer = transformers.AutoTokenizer.from_pretrained("pretrained/distilbert-base-uncased", 7 | 
TOKENIZERS_PARALLELISM=False) 8 | 9 | text_data = tokenizer(text, return_tensors='pt', padding=True, truncation=True) 10 | text_data = {key: val.cuda() for key, val in text_data.items()} 11 | 12 | 13 | text_model = AutoModel.from_pretrained("pretrained/distilbert-base-uncased").cuda() 14 | 15 | print(text_model) 16 | 17 | text_embeddings_all = text_model(**text_data).last_hidden_state 18 | print(text_embeddings_all.size()) 19 | text_embeddings = text_embeddings_all[:, 0, :] 20 | print(text_embeddings) 21 | 22 | 23 | text_embeddings_2 = text_model.embeddings(text_data['input_ids']) 24 | 25 | text_embeddings_2 = text_model.transformer(text_embeddings_2, 26 | attn_mask=attention_mask, 27 | head_mask=head_mask, 28 | output_attentions=output_attentions, 29 | output_hidden_states=output_hidden_states, 30 | return_dict=return_dict, 31 | ) 32 | 33 | print(text_embeddings - text_embeddings_2) -------------------------------------------------------------------------------- /OATrans/utils/unit_test/load_msvd_video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | video_path = "MSVD/YouTubeClips/fVWUaH2mCt4_1_7.avi" 4 | cap = cv2.VideoCapture(video_path) 5 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 6 | print(vlen) -------------------------------------------------------------------------------- /OATrans/utils/unit_test/region_roi_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | usage 3 | # add for roi pooling 4 | import torch 5 | import torchvision.ops.roi_align as roi_align 6 | self.roi_align = roi_align 7 | """ 8 | 9 | def region_embed(self, x, bbox): 10 | """ 11 | Args: 12 | x (): the input video 13 | bbox (): bounding boxes with 4 loc + height/width; stacked for num_frame times 14 | 15 | Returns: 16 | the raw pixel region of bbox 17 | """ 18 | b, t, c, h, w = x.size() 19 | x = x.view(-1, c, h, w) 20 | B, L, N = bbox.size() 21 | coordinates = torch.zeros((B * L, 5)).cuda() 22 | for i in range(B * L): 23 | coordinates[i][0] = i // L 24 | coordinates[i][1:] = bbox[i // L, i % L, :4] 25 | regions = self.roi_align(x, coordinates, output_size=[self.patch_size, self.patch_size]) 26 | region_features = self.region_embedding_layer(regions) 27 | region_features = region_features.view(-1, L // t, self.embed_dim) 28 | return region_features -------------------------------------------------------------------------------- /OATrans/utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from datetime import datetime 4 | from itertools import repeat 5 | from collections import OrderedDict 6 | import functools 7 | import time 8 | import socket 9 | import numpy as np 10 | import psutil 11 | import msgpack 12 | import humanize 13 | import os 14 | 15 | def replace_nested_dict_item(obj, key, replace_value): 16 | for k, v in obj.items(): 17 | if isinstance(v, dict): 18 | obj[k] = replace_nested_dict_item(v, key, replace_value) 19 | if key in obj: 20 | obj[key] = replace_value 21 | return obj 22 | 23 | 24 | def state_dict_data_parallel_fix(load_state_dict, curr_state_dict): 25 | load_keys = list(load_state_dict.keys()) 26 | curr_keys = list(curr_state_dict.keys()) 27 | 28 | redo_dp = False 29 | undo_dp = False 30 | if not curr_keys[0].startswith('module.') and load_keys[0].startswith('module.'): 31 | undo_dp = True 32 | elif curr_keys[0].startswith('module.') and not load_keys[0].startswith('module.'): 33 | redo_dp = 
True 34 | 35 | if undo_dp: 36 | from collections import OrderedDict 37 | new_state_dict = OrderedDict() 38 | for k, v in load_state_dict.items(): 39 | name = k[7:] # remove `module.` 40 | new_state_dict[name] = v 41 | # load params 42 | elif redo_dp: 43 | from collections import OrderedDict 44 | new_state_dict = OrderedDict() 45 | for k, v in load_state_dict.items(): 46 | name = 'module.' + k # remove `module.` 47 | new_state_dict[name] = v 48 | else: 49 | new_state_dict = load_state_dict 50 | return new_state_dict 51 | 52 | def print_numpy(x, val=True, shp=False): 53 | """Print the mean, min, max, median, std, and size of a numpy array 54 | Parameters: 55 | val (bool) -- if print the values of the numpy array 56 | shp (bool) -- if print the shape of the numpy array 57 | """ 58 | x = x.astype(np.float64) 59 | if shp: 60 | print('shape,', x.shape) 61 | if val: 62 | x = x.flatten() 63 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 64 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 65 | 66 | 67 | def mkdirs(paths): 68 | """create empty directories if they don't exist 69 | Parameters: 70 | paths (str list) -- a list of directory paths 71 | """ 72 | if isinstance(paths, list) and not isinstance(paths, str): 73 | for path in paths: 74 | mkdir(path) 75 | else: 76 | mkdir(paths) 77 | 78 | 79 | def mkdir(path): 80 | """create a single empty directory if it didn't exist 81 | Parameters: 82 | path (str) -- a single directory path 83 | """ 84 | if not os.path.exists(path): 85 | os.makedirs(path) 86 | 87 | def read_json(fname): 88 | with fname.open('rt') as handle: 89 | return json.load(handle, object_hook=OrderedDict) 90 | 91 | def write_json(content, fname): 92 | with fname.open('wt') as handle: 93 | json.dump(content, handle, indent=4, sort_keys=False) 94 | 95 | def inf_loop(data_loader): 96 | ''' wrapper function for endless data loader. 
''' 97 | for loader in repeat(data_loader): 98 | yield from loader 99 | 100 | def memory_summary(): 101 | vmem = psutil.virtual_memory() 102 | msg = ( 103 | f">>> Currently using {vmem.percent}% of system memory " 104 | f"{humanize.naturalsize(vmem.used)}/{humanize.naturalsize(vmem.available)}" 105 | ) 106 | print(msg) 107 | 108 | @functools.lru_cache(maxsize=64, typed=False) 109 | def memcache(path): 110 | suffix = Path(path).suffix 111 | print(f"loading features >>>", end=" ") 112 | tic = time.time() 113 | if suffix == ".npy": 114 | res = np_loader(path) 115 | else: 116 | raise ValueError(f"unknown suffix: {suffix} for path {path}") 117 | print(f"[Total: {time.time() - tic:.1f}s] ({socket.gethostname() + ':' + str(path)})") 118 | return res 119 | 120 | def np_loader(np_path, l2norm=False): 121 | with open(np_path, "rb") as f: 122 | data = np.load(f, encoding="latin1", allow_pickle=True) 123 | if isinstance(data, np.ndarray) and data.size == 1: 124 | data = data[()] # handle numpy dict storage convnetion 125 | if l2norm: 126 | print("L2 normalizing features") 127 | if isinstance(data, dict): 128 | for key in data: 129 | feats_ = data[key] 130 | feats_ = feats_ / max(np.linalg.norm(feats_), 1E-6) 131 | data[key] = feats_ 132 | elif data.ndim == 2: 133 | data_norm = np.linalg.norm(data, axis=1) 134 | data = data / np.maximum(data_norm.reshape(-1, 1), 1E-6) 135 | else: 136 | raise ValueError("unexpected data format {}".format(type(data))) 137 | return data 138 | 139 | 140 | class Timer: 141 | def __init__(self): 142 | self.cache = datetime.now() 143 | 144 | def check(self): 145 | now = datetime.now() 146 | duration = now - self.cache 147 | self.cache = now 148 | return duration.total_seconds() 149 | 150 | def reset(self): 151 | self.cache = datetime.now() 152 | -------------------------------------------------------------------------------- /OATrans/utils/video.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | import PIL 5 | import collections 6 | import random 7 | import cv2 8 | import os 9 | import numpy as np 10 | 11 | def load_frames_from_video_path(path, num_frames, sample='rand'): 12 | cap = cv2.VideoCapture(path) 13 | assert (cap.isOpened()) 14 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 15 | acc_samples = min(num_frames, vlen) 16 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) 17 | ranges = [] 18 | for idx, interv in enumerate(intervals[:-1]): 19 | ranges.append((interv, intervals[idx + 1] - 1)) 20 | if sample == 'rand': 21 | frame_idxs = [random.choice(range(x[0], x[1])) for x in ranges] 22 | elif sample == 'uniform': 23 | frame_idxs = [(x[0] + x[1]) // 2 for x in ranges] 24 | else: 25 | raise NotImplementedError 26 | 27 | frames = [] 28 | for index in frame_idxs: 29 | cap.set(cv2.CAP_PROP_POS_FRAMES, index) 30 | ret, frame = cap.read() 31 | if ret: 32 | cv2.imwrite(f'images/{index}.jpg', frame) 33 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 34 | frame = torch.from_numpy(frame) 35 | # (H x W x C) to (C x H x W) 36 | frame = frame.permute(2, 0, 1) 37 | frames.append(frame) 38 | else: 39 | raise ValueError 40 | 41 | frames = torch.stack(frames).float() / 255 42 | cap.release() 43 | return frames, frame_idxs -------------------------------------------------------------------------------- /OATrans/utils/visualization/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/.DS_Store -------------------------------------------------------------------------------- /OATrans/utils/visualization/3f_vto_visualize.py: -------------------------------------------------------------------------------- 1 | """ 2 | visualize both image + object + text 3 | """ 4 | import numpy as np 5 | import cv2 6 | from csv import reader 7 | import os 8 | import random 9 | import matplotlib.pyplot as plt 10 | import torch 11 | import pdb 12 | import textwrap 13 | import pandas as pd 14 | 15 | full_csv = "meta_data/webvid_training_success_full.tsv" 16 | data_source = "WebVid/train" 17 | feat_source = "WebVid/8_frame_object/train/" 18 | output = "WebVid2M_visualization/train" 19 | 20 | 21 | def sample_frames(num_frames, vlen, sample='rand', fix_start=None): 22 | acc_samples = min(num_frames, vlen) 23 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) 24 | ranges = [] 25 | for idx, interv in enumerate(intervals[:-1]): 26 | ranges.append((interv, intervals[idx + 1] - 1)) 27 | if sample == 'rand': 28 | frame_idxs = [random.choice(range(x[0], x[1])) for x in ranges] 29 | elif fix_start is not None: 30 | frame_idxs = [x[0] + fix_start for x in ranges] 31 | elif sample == 'uniform': 32 | frame_idxs = [(x[0] + x[1]) // 2 for x in ranges] 33 | else: 34 | raise NotImplementedError 35 | return frame_idxs 36 | 37 | 38 | def read_frames_cv2(video_path, num_frames, sample='uniform', fix_start=None): 39 | cap = cv2.VideoCapture(video_path) 40 | assert (cap.isOpened()) 41 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 42 | # get indexes of sampled frames 43 | frame_idxs = sample_frames(num_frames, vlen, sample=sample, fix_start=fix_start) 44 | frames = [] 45 | success_idxs = [] 46 | for index in frame_idxs: 47 | cap.set(cv2.CAP_PROP_POS_FRAMES, index - 1) 48 | ret, frame = cap.read() 49 | if ret: 50 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 51 | frames.append(frame) 52 | success_idxs.append(index) 53 | else: 54 | pass 55 | # print(frame_idxs, ' fail ', index, f' (vlen {vlen})') 56 | cap.release() 57 | return frames 58 | 59 | # step1: open video 60 | 61 | # step2: video seq 62 | 63 | # step3: video feature 64 | 65 | 66 | def tri_region_visualize(imgs, feat_paths, caption, outpath="visualization/1.png"): 67 | concat_imgs = None 68 | for i in [0, 3, 7]: 69 | frame1 = np.load(feat_paths[i], allow_pickle=True) 70 | boxes = frame1['bbox'] 71 | features = frame1['x'] # 20 x 2048 72 | confident = frame1['info'].item()['objects_conf'] 73 | # step 1: re-ranking the region with confidence 74 | object_ids = frame1['info'].item()['objects_id'] 75 | condident_indices = np.argsort(confident)[::-1] 76 | boxes = boxes[condident_indices] 77 | features = features[condident_indices] 78 | object_ids = object_ids[condident_indices] 79 | confident = confident[condident_indices] 80 | new_object, unique_indices = np.unique(object_ids, return_index=True) 81 | # step 2: remove region with same object class 82 | boxes = boxes[unique_indices] 83 | features = features[unique_indices] 84 | object_ids = object_ids[unique_indices] 85 | # object_ids = object_ids[unique_indices] 86 | # confident = confident[unique_indices] 87 | # # print(boxes, features) 88 | # image_width = frame1['info'].item()['image_w'] 89 | # image_height = frame1['info'].item()['image_h'] 90 | # box_width = boxes[:, 2] - boxes[:, 0] 91 | # box_height = boxes[:, 3] - boxes[:, 1] 92 | # scaled_width = 
box_width / image_width 93 | # scaled_height = box_height / image_height 94 | # scaled_x = boxes[:, 0] / image_width 95 | # scaled_y = boxes[:, 1] / image_height 96 | # scaled_width = scaled_width[..., np.newaxis] 97 | # scaled_height = scaled_height[..., np.newaxis] 98 | # scaled_x = scaled_x[..., np.newaxis] 99 | # scaled_y = scaled_y[..., np.newaxis] 100 | # spatial_features = np.concatenate( 101 | # (scaled_x, scaled_y, scaled_x + scaled_width, scaled_y + scaled_height, scaled_width, scaled_height), axis=1) 102 | # feat = torch.cat([torch.from_numpy(features), torch.from_numpy(spatial_features)], dim=1) 103 | classes = ['__background__'] 104 | with open('utils/objects_vocab.txt', 'r') as f: 105 | for object in f.readlines(): 106 | classes.append(object.split(',')[0].lower().strip()) 107 | # print(features.shape) 108 | # plot top 5 objects 109 | img = imgs[i] 110 | if len(boxes) < 5: 111 | return False 112 | # print(img.shape) 113 | # print(img) 114 | colormap = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (155, 100, 100), (100, 155, 100)] 115 | for j in range(5): 116 | # print(boxes[j]) 117 | cv2.putText(img, '%s: %s' % (classes[object_ids[j] + 1], confident[j]), (int(boxes[j][0]), int(boxes[j][1] + 15)), 118 | cv2.FONT_HERSHEY_TRIPLEX, 119 | 0.5, 120 | colormap[j], 121 | 1) 122 | cv2.rectangle(img, (int(boxes[j][0]), int(boxes[j][1])), (int(boxes[j][2]), int(boxes[j][3])), 123 | colormap[j], 124 | 1) 125 | if concat_imgs is None: 126 | concat_imgs = img 127 | else: 128 | concat_imgs = np.concatenate((concat_imgs, img), axis=1) 129 | caption_img = np.ones((50, imgs[0].shape[1] * 3, 3)) * 255 130 | cv2.putText(caption_img, caption, (10, 10), cv2.FONT_HERSHEY_TRIPLEX, 131 | 0.5, 132 | colormap[j], 133 | 1) 134 | concat_imgs = np.concatenate((concat_imgs, caption_img), axis=0) 135 | cv2.imwrite(outpath, concat_imgs) 136 | return outpath 137 | 138 | 139 | if __name__ == '__main__': 140 | metadata = pd.read_csv(full_csv, sep='\t') 141 | count = 0 142 | for i in range(len(metadata)): 143 | sample = metadata.iloc[i] 144 | count += 1 145 | if count > 200: 146 | break 147 | video_path = os.path.join(data_source, sample[1] + '.mp4') 148 | imgs = read_frames_cv2(video_path, 8, 'uniform') 149 | feat_paths = [feat_source + sample[1] + '/' + str(k) + '.npz' for k in range(8)] 150 | outpath = 'visualization/3f/{}_{}.jpg'.format(i, sample[1].split('/')[1]) 151 | tri_region_visualize(imgs, feat_paths, sample[0], outpath) 152 | 153 | -------------------------------------------------------------------------------- /OATrans/utils/visualization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/__init__.py -------------------------------------------------------------------------------- /OATrans/utils/visualization/learned_embedding_visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | the visualization of learned embedding 3 | """ 4 | 5 | # That's an impressive list of imports. 6 | import numpy as np 7 | import torch 8 | from numpy import linalg 9 | from numpy.linalg import norm 10 | from scipy.spatial.distance import squareform, pdist 11 | 12 | # We import sklearn. 
13 | import sklearn 14 | # from sklearn.manifold import TSNE 15 | from sklearn.manifold._t_sne import TSNE 16 | from sklearn.datasets import load_digits 17 | from sklearn.preprocessing import scale 18 | 19 | # We'll hack a bit with the t-SNE code in sklearn 0.15.2. 20 | from sklearn.metrics.pairwise import pairwise_distances 21 | from sklearn.manifold._t_sne import (_joint_probabilities, 22 | _kl_divergence) 23 | # Random state. 24 | RS = 20150101 25 | 26 | # We'll use matplotlib for graphics. 27 | import matplotlib.pyplot as plt 28 | import matplotlib.patheffects as PathEffects 29 | import matplotlib 30 | # %matplotlib inline 31 | 32 | import random 33 | # We import seaborn to make nice plots. 34 | import seaborn as sns 35 | import os 36 | sns.set_style('darkgrid') 37 | sns.set_palette('muted') 38 | sns.set_context("notebook", font_scale=1.5, 39 | rc={"lines.linewidth": 2.5}) 40 | 41 | # We'll generate an animation with matplotlib and moviepy. 42 | # from moviepy.video.io.bindings import mplfig_to_npimage 43 | # import moviepy.editor as mpy 44 | 45 | 46 | def load_data(file_name): 47 | # digits = load_digits() 48 | # digits.data.shape # 1797 x 64 49 | # print(digits.data.shape) 50 | # print(digits['DESCR']) 51 | # return digits 52 | features = np.load(file_name, allow_pickle='TRUE').tolist() 53 | return features 54 | 55 | 56 | def scatter(x, colors, num_class=10): 57 | # We choose a color palette with seaborn. 58 | palette = np.array(sns.color_palette("hls", num_class)) 59 | # sns.palplot(sns.color_palette("hls", 10)) 60 | # We create a scatter plot. 61 | labels=['brush_hair', 'cartwheel', 'catch', 'chew', 62 | 'clap', 'climb', 'climb_stairs', 'dive', 'draw_sword', 63 | 'dribble'] 64 | f = plt.figure(figsize=(8, 8)) 65 | # print(colors.astype(np.int)) 66 | ax = plt.subplot(aspect='equal') 67 | # for i in range(10): 68 | # sc = ax.scatter(x[:, 0][30*i:30*(i+1)], x[:, 1][30*i:30*(i+1)], c=palette[colors.astype(np.int)][30*i:30*(i+1)], 69 | # s=40, 70 | # label=labels[i], 71 | # ) 72 | sc = ax.scatter(x[:,0], x[:,1], c=palette[colors.astype(np.int)], 73 | s=150, 74 | #label=colors.astype(np.int)[30], 75 | ) 76 | # ax.legend(loc="best", title="Classes", bbox_to_anchor=(0.2, 0.4)) 77 | plt.xlim(-25, 25) 78 | plt.ylim(-25, 25) 79 | ax.axis('off') 80 | ax.axis('tight') 81 | 82 | # We add the labels for each digit. 83 | txts = [] 84 | for i in range(num_class): 85 | # Position of each label. 86 | xtext, ytext = np.median(x[colors == i, :], axis=0) 87 | txt = ax.text(xtext, ytext, str(i), fontsize=24) 88 | # ax.legend(ytext, "a") 89 | txt.set_path_effects([ 90 | PathEffects.Stroke(linewidth=5, foreground="w"), 91 | PathEffects.Normal()]) 92 | txts.append(txt) 93 | # ax.legend(('a','b','c','d','e')) 94 | return f, ax, sc, txts 95 | 96 | 97 | def tsne_visualize(data, file_name, num_class=101): 98 | # nrows, ncols = 2, 5 99 | # plt.figure(figsize=(6,3)) 100 | # plt.gray() 101 | # for i in range(ncols * nrows): 102 | # ax = plt.subplot(nrows, ncols, i + 1) 103 | # ax.matshow(digits.images[i,...]) 104 | # plt.xticks([]); plt.yticks([]) 105 | # plt.title(digits.target[i]) 106 | # plt.savefig('../../../experiments/visualization/digits-generated.png', dpi=150) 107 | 108 | # We first reorder the data points according to the handwritten numbers. 
109 | datas = [] 110 | labels = [] 111 | nums = len(data) 112 | print(nums) 113 | for j in range(nums): 114 | datas.append(data[j]) 115 | X = np.vstack(datas) 116 | for j in range(nums): 117 | # labels.append(min(j+1, nums-1)) 118 | labels.append(1) 119 | y = np.hstack(labels) 120 | # X = np.vstack([data['data'][data['target']==i].cpu() 121 | # for i in range(10)]) 122 | # y = np.hstack([data['target'][data['target']==i].cpu() 123 | # for i in range(10)]) 124 | # print(y) 125 | digits_proj = TSNE(random_state=RS).fit_transform(X) 126 | scatter(digits_proj, y, nums) 127 | plt.savefig(file_name, dpi=120) 128 | 129 | 130 | # features_file = "utils/visualization/vid_embeds.npy" 131 | # file_name = "utils/visualization/figures/vid_embeds.png" 132 | # features_file = "utils/visualization/text_embeds.npy" 133 | # file_name = "utils/visualization/figures/text_embeds.png" 134 | features_file = "utils/visualization/sims_embeds.npy" 135 | file_name = "utils/visualization/figures/sims_embeds.png" 136 | data = load_data(features_file) 137 | tsne_visualize(data, file_name, '0') -------------------------------------------------------------------------------- /OATrans/utils/visualization/msrvtt_3f_vto_visualize.py: -------------------------------------------------------------------------------- 1 | """ 2 | visualize both image + object + text 3 | """ 4 | import numpy as np 5 | import cv2 6 | from csv import reader 7 | import os 8 | import random 9 | import matplotlib.pyplot as plt 10 | import torch 11 | import pdb 12 | import textwrap 13 | import pandas as pd 14 | import json 15 | 16 | full_csv = "MSRVTT/annotation/MSR_VTT.json" 17 | data_source = "MSRVTT/videos/all" 18 | feat_source = "MSRVTT/region_features_full" 19 | output = "MSRVTT/region_visualization" 20 | 21 | 22 | def sample_frames(num_frames, vlen, sample='rand', fix_start=None): 23 | acc_samples = min(num_frames, vlen) 24 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) 25 | ranges = [] 26 | for idx, interv in enumerate(intervals[:-1]): 27 | ranges.append((interv, intervals[idx + 1] - 1)) 28 | if sample == 'rand': 29 | frame_idxs = [random.choice(range(x[0], x[1])) for x in ranges] 30 | elif fix_start is not None: 31 | frame_idxs = [x[0] + fix_start for x in ranges] 32 | elif sample == 'uniform': 33 | frame_idxs = [(x[0] + x[1]) // 2 for x in ranges] 34 | else: 35 | raise NotImplementedError 36 | return frame_idxs 37 | 38 | 39 | def read_frames_cv2(video_path, num_frames, sample='uniform', fix_start=None): 40 | cap = cv2.VideoCapture(video_path) 41 | assert (cap.isOpened()) 42 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 43 | # get indexes of sampled frames 44 | frame_idxs = sample_frames(num_frames, vlen, sample=sample, fix_start=fix_start) 45 | frames = [] 46 | success_idxs = [] 47 | for index in frame_idxs: 48 | cap.set(cv2.CAP_PROP_POS_FRAMES, index - 1) 49 | ret, frame = cap.read() 50 | if ret: 51 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 52 | frames.append(frame) 53 | success_idxs.append(index) 54 | else: 55 | pass 56 | # print(frame_idxs, ' fail ', index, f' (vlen {vlen})') 57 | cap.release() 58 | return frames, success_idxs 59 | 60 | # step1: open video 61 | 62 | # step2: video seq 63 | 64 | # step3: video feature 65 | 66 | 67 | def tri_region_visualize(imgs, feat_paths, caption, outpath="visualization/1.png"): 68 | concat_imgs = None 69 | for i in [0, 3, 7]: 70 | frame1 = np.load(feat_paths[i], allow_pickle=True) 71 | boxes = frame1['bbox'] 72 | features = frame1['x'] # 20 x 2048 73 | confident = 
frame1['info'].item()['objects_conf'] 74 | # step 1: re-ranking the region with confidence 75 | object_ids = frame1['info'].item()['objects_id'] 76 | condident_indices = np.argsort(confident)[::-1] 77 | boxes = boxes[condident_indices] 78 | features = features[condident_indices] 79 | object_ids = object_ids[condident_indices] 80 | confident = confident[condident_indices] 81 | new_object, unique_indices = np.unique(object_ids, return_index=True) 82 | # step 2: remove region with same object class 83 | boxes = boxes[unique_indices] 84 | features = features[unique_indices] 85 | object_ids = object_ids[unique_indices] 86 | # object_ids = object_ids[unique_indices] 87 | # confident = confident[unique_indices] 88 | # # print(boxes, features) 89 | # image_width = frame1['info'].item()['image_w'] 90 | # image_height = frame1['info'].item()['image_h'] 91 | # box_width = boxes[:, 2] - boxes[:, 0] 92 | # box_height = boxes[:, 3] - boxes[:, 1] 93 | # scaled_width = box_width / image_width 94 | # scaled_height = box_height / image_height 95 | # scaled_x = boxes[:, 0] / image_width 96 | # scaled_y = boxes[:, 1] / image_height 97 | # scaled_width = scaled_width[..., np.newaxis] 98 | # scaled_height = scaled_height[..., np.newaxis] 99 | # scaled_x = scaled_x[..., np.newaxis] 100 | # scaled_y = scaled_y[..., np.newaxis] 101 | # spatial_features = np.concatenate( 102 | # (scaled_x, scaled_y, scaled_x + scaled_width, scaled_y + scaled_height, scaled_width, scaled_height), axis=1) 103 | # feat = torch.cat([torch.from_numpy(features), torch.from_numpy(spatial_features)], dim=1) 104 | classes = ['__background__'] 105 | with open('utils/objects_vocab.txt', 'r') as f: 106 | for object in f.readlines(): 107 | classes.append(object.split(',')[0].lower().strip()) 108 | # print(features.shape) 109 | # plot top 5 objects 110 | img = imgs[i] 111 | if len(boxes) < 5: 112 | return False 113 | # print(img.shape) 114 | # print(img) 115 | colormap = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (155, 100, 100), (100, 155, 100)] 116 | for j in range(5): 117 | # print(boxes[j]) 118 | cv2.putText(img, '%s: %s' % (classes[object_ids[j] + 1], confident[j]), (int(boxes[j][0]), int(boxes[j][1] + 15)), 119 | cv2.FONT_HERSHEY_TRIPLEX, 120 | 0.5, 121 | colormap[j], 122 | 1) 123 | cv2.rectangle(img, (int(boxes[j][0]), int(boxes[j][1])), (int(boxes[j][2]), int(boxes[j][3])), 124 | colormap[j], 125 | 1) 126 | if concat_imgs is None: 127 | concat_imgs = img 128 | else: 129 | concat_imgs = np.concatenate((concat_imgs, img), axis=1) 130 | caption_img = np.ones((50, imgs[0].shape[1] * 3, 3)) * 255 131 | cv2.putText(caption_img, caption, (10, 10), cv2.FONT_HERSHEY_TRIPLEX, 132 | 0.5, 133 | colormap[j], 134 | 1) 135 | concat_imgs = np.concatenate((concat_imgs, caption_img), axis=0) 136 | cv2.imwrite(outpath, concat_imgs) 137 | return outpath 138 | 139 | 140 | if __name__ == '__main__': 141 | 142 | f = open(full_csv) 143 | data = json.load(f) 144 | count = 0 145 | 146 | for row in data['annotations']: 147 | count += 1 148 | if count % 10 != 0: 149 | continue 150 | # if row['id'] % 200 != 1: 151 | # continue 152 | if count > 2000: 153 | break 154 | video_path = os.path.join(data_source, row['image_id'] + '.mp4') 155 | imgs, success_idxs = read_frames_cv2(video_path, 8, 'uniform') 156 | feat_paths = [feat_source + '/' + row['image_id'] + '/' + str(success_idxs[k]) + '.npz' for k in range(8)] 157 | outpath = 'visualization/msrvtt_3f/{}_{}.jpg'.format(count, row['image_id']) 158 | tri_region_visualize(imgs, feat_paths, row['caption'], outpath) 159 | 160 
| -------------------------------------------------------------------------------- /OATrans/utils/visualization/msrvtt_vto_visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | visualize both image + object + text 3 | """ 4 | import numpy as np 5 | import cv2 6 | from csv import reader 7 | import os 8 | import random 9 | import matplotlib.pyplot as plt 10 | import torch 11 | import pdb 12 | import textwrap 13 | import json 14 | 15 | full_json = "MSRVTT/annotation/MSR_VTT.json" 16 | data_source = "MSRVTT/videos/all" 17 | feat_source = "MSRVTT/region_features" 18 | output = "MSRVTT/region_visualization" 19 | 20 | 21 | def feature_visualize(img1, feat_path): 22 | frame1 = np.load(feat_path, allow_pickle=True) 23 | boxes = frame1['bbox'] 24 | features = frame1['x'] # 20 x 2048 25 | confident = frame1['info'].item()['objects_conf'] 26 | # step 1: re-ranking the region with confidence 27 | object_ids = frame1['info'].item()['objects_id'] 28 | condident_indices = np.argsort(confident)[::-1] 29 | boxes = boxes[condident_indices] 30 | features = features[condident_indices] 31 | object_ids = object_ids[condident_indices] 32 | confident = confident[condident_indices] 33 | 34 | new_object, unique_indices = np.unique(object_ids, return_index=True) 35 | # step 2: remove region with same object class 36 | 37 | boxes = boxes[unique_indices] 38 | features = features[unique_indices] 39 | object_ids = object_ids[unique_indices] 40 | confident = confident[unique_indices] 41 | 42 | # print(boxes, features) 43 | image_width = frame1['info'].item()['image_w'] 44 | image_height = frame1['info'].item()['image_h'] 45 | 46 | box_width = boxes[:, 2] - boxes[:, 0] 47 | box_height = boxes[:, 3] - boxes[:, 1] 48 | scaled_width = box_width / image_width 49 | scaled_height = box_height / image_height 50 | scaled_x = boxes[:, 0] / image_width 51 | scaled_y = boxes[:, 1] / image_height 52 | scaled_width = scaled_width[..., np.newaxis] 53 | scaled_height = scaled_height[..., np.newaxis] 54 | scaled_x = scaled_x[..., np.newaxis] 55 | scaled_y = scaled_y[..., np.newaxis] 56 | spatial_features = np.concatenate( 57 | (scaled_x, scaled_y, scaled_x + scaled_width, scaled_y + scaled_height, scaled_width, scaled_height), axis=1) 58 | # print(spatial_features) 59 | feat = torch.cat([torch.from_numpy(features), torch.from_numpy(spatial_features)], dim=1) 60 | classes = ['__background__'] 61 | with open('../objects_vocab.txt', 'r') as f: 62 | for object in f.readlines(): 63 | classes.append(object.split(',')[0].lower().strip()) 64 | # print(features.shape) 65 | im = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB) 66 | plt.axis('off') 67 | plt.imshow(im) 68 | new_boxes = boxes 69 | for i in range(len(new_boxes)): 70 | bbox = new_boxes[i] 71 | if i < 10: 72 | plt.gca().add_patch( 73 | plt.Rectangle((bbox[0], bbox[1]), 74 | bbox[2] - bbox[0], 75 | bbox[3] - bbox[1], fill=False, 76 | edgecolor='red', linewidth=2, alpha=0.5) 77 | ) 78 | plt.gca().text(bbox[0], bbox[1] - 2, 79 | '%s: %s' % (classes[object_ids[i] + 1], confident[i]), 80 | bbox=dict(facecolor='blue', alpha=0.5), 81 | fontsize=10, color='white') 82 | outpath = "test_roi.png" 83 | plt.savefig(outpath, dpi=150) 84 | plt.close() 85 | return outpath 86 | 87 | f = open(full_json) 88 | data = json.load(f) 89 | # for key in data: 90 | # print(key) 91 | # pdb.set_trace() 92 | # Iterate over each row in the csv using reader object 93 | count = 0 94 | 95 | for row in data['annotations']: 96 | count += 1 97 | if row['id'] % 200 != 1: 98 | 
continue 99 | video_path = os.path.join(data_source, row['image_id'] + '.mp4') 100 | cap = cv2.VideoCapture(video_path) 101 | ret, img = cap.read() 102 | feat_path = os.path.join(feat_source, row['image_id'], '1.npz') 103 | feat_img = cv2.imread(feature_visualize(img, feat_path)) 104 | print(feat_img.shape) 105 | caption_img = np.ones([feat_img.shape[0]//4, feat_img.shape[1], 3]) * 255 106 | 107 | wrapped_text = textwrap.wrap(row['caption'], width=35) 108 | x, y = 10, 40 109 | font_size = 1 110 | font_thickness = 2 111 | font = cv2.FONT_HERSHEY_TRIPLEX 112 | 113 | for i, line in enumerate(wrapped_text): 114 | textsize = cv2.getTextSize(line, font, font_size, font_thickness)[0] 115 | 116 | gap = textsize[1] + 10 117 | 118 | y = 25 + i * gap 119 | x = int((caption_img.shape[1] - textsize[0]) / 2) + 20 120 | 121 | cv2.putText(caption_img, line, (x, y), font, 122 | font_size, 123 | (122, 21, 91), 124 | font_thickness, 125 | lineType=cv2.LINE_AA) 126 | 127 | # cv2.putText(caption_img, row[1], (30, 30), cv2.FONT_HERSHEY_TRIPLEX, 1, (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)), 128 | # 1) 129 | concat = np.concatenate([feat_img[50:-50, :, :], caption_img], axis=0) 130 | # cv2.imshow(concat) 131 | out_file = os.path.join(output, row['image_id'] + '_' + str(row['id']) + '.png') 132 | print("hello world") 133 | cv2.imwrite(out_file, concat) 134 | # if cv2.waitKey(33) == 27: 135 | # continue 136 | -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/0_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/0_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/10_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/10_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/11_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/11_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/12_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/12_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/13_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/13_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/14_predict.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/14_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/1_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/1_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/2_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/2_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/3_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/3_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/4_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/4_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/5_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/5_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/6_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/6_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/7_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/7_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/8_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/8_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/predict_visualization/9_predict.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/9_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/print_tags.py: -------------------------------------------------------------------------------- 1 | def predict2caption(predict, vocab='utils/objects_vocab.txt'): 2 | caption = "" 3 | classes = ['__background__'] 4 | with open(vocab, 'r') as f: 5 | for object in f.readlines(): 6 | classes.append(object.split(',')[0].lower().strip()) 7 | for n in range(len(predict)): 8 | caption += ' ' + (classes[predict[n]+1]) 9 | return caption -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/0_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/0_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/10_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/10_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/11_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/11_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/12_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/12_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/13_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/13_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/14_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/14_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/15_predict.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/15_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/1_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/1_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/2_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/2_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/3_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/3_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/4_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/4_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/5_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/5_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/6_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/6_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/7_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/7_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/transfer_predict_visualization/8_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/8_predict.png -------------------------------------------------------------------------------- 
/OATrans/utils/visualization/transfer_predict_visualization/9_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/9_predict.png -------------------------------------------------------------------------------- /OATrans/utils/visualization/webvid_vto_visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | visualize both image + object + text 3 | """ 4 | import numpy as np 5 | import cv2 6 | from csv import reader 7 | import os 8 | import random 9 | import matplotlib.pyplot as plt 10 | import torch 11 | import pdb 12 | import textwrap 13 | 14 | full_csv = "WebVid2M_videos/metadata/results_subset_train.csv" 15 | data_source = "WebVid2M_videos/train_videos" 16 | feat_source = "WebVid2M_frames_region_features/train" 17 | output = "WebVid2M_visualization/train" 18 | 19 | 20 | def feature_visualize(img1, feat_path): 21 | frame1 = np.load(feat_path, allow_pickle=True) 22 | boxes = frame1['bbox'] 23 | features = frame1['x'] # 20 x 2048 24 | confident = frame1['info'].item()['objects_conf'] 25 | # step 1: re-ranking the region with confidence 26 | object_ids = frame1['info'].item()['objects_id'] 27 | condident_indices = np.argsort(confident)[::-1] 28 | boxes = boxes[condident_indices] 29 | features = features[condident_indices] 30 | object_ids = object_ids[condident_indices] 31 | confident = confident[condident_indices] 32 | 33 | new_object, unique_indices = np.unique(object_ids, return_index=True) 34 | # step 2: remove region with same object class 35 | 36 | boxes = boxes[unique_indices] 37 | features = features[unique_indices] 38 | object_ids = object_ids[unique_indices] 39 | confident = confident[unique_indices] 40 | 41 | # print(boxes, features) 42 | image_width = frame1['info'].item()['image_w'] 43 | image_height = frame1['info'].item()['image_h'] 44 | 45 | box_width = boxes[:, 2] - boxes[:, 0] 46 | box_height = boxes[:, 3] - boxes[:, 1] 47 | scaled_width = box_width / image_width 48 | scaled_height = box_height / image_height 49 | scaled_x = boxes[:, 0] / image_width 50 | scaled_y = boxes[:, 1] / image_height 51 | scaled_width = scaled_width[..., np.newaxis] 52 | scaled_height = scaled_height[..., np.newaxis] 53 | scaled_x = scaled_x[..., np.newaxis] 54 | scaled_y = scaled_y[..., np.newaxis] 55 | spatial_features = np.concatenate( 56 | (scaled_x, scaled_y, scaled_x + scaled_width, scaled_y + scaled_height, scaled_width, scaled_height), axis=1) 57 | # print(spatial_features) 58 | feat = torch.cat([torch.from_numpy(features), torch.from_numpy(spatial_features)], dim=1) 59 | classes = ['__background__'] 60 | with open('../objects_vocab.txt', 'r') as f: 61 | for object in f.readlines(): 62 | classes.append(object.split(',')[0].lower().strip()) 63 | # print(features.shape) 64 | im = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB) 65 | plt.axis('off') 66 | plt.imshow(im) 67 | new_boxes = boxes 68 | for i in range(len(new_boxes)): 69 | bbox = new_boxes[i] 70 | if i < 10: 71 | plt.gca().add_patch( 72 | plt.Rectangle((bbox[0], bbox[1]), 73 | bbox[2] - bbox[0], 74 | bbox[3] - bbox[1], fill=False, 75 | edgecolor='red', linewidth=2, alpha=0.5) 76 | ) 77 | plt.gca().text(bbox[0], bbox[1] - 2, 78 | '%s: %s' % (classes[object_ids[i] + 1], confident[i]), 79 | bbox=dict(facecolor='blue', alpha=0.5), 80 | fontsize=10, color='white') 81 | outpath = "test_roi.png" 82 | 
plt.savefig(outpath, dpi=150) 83 | plt.close() 84 | return outpath 85 | 86 | 87 | with open(full_csv, 'r') as read_obj: 88 | # pass the file object to reader() to get the reader object 89 | csv_reader = reader(read_obj) 90 | # Iterate over each row in the csv using reader object 91 | count = 0 92 | for row in csv_reader: 93 | count += 1 94 | if count == 1: 95 | continue 96 | # if count > 3: 97 | # break 98 | # cv2.destroyAllWindows() 99 | if len(row[3]) < 3: 100 | continue 101 | video_path = os.path.join(data_source, row[3], row[0] + '.mp4') 102 | cap = cv2.VideoCapture(video_path) 103 | ret, img = cap.read() 104 | feat_path = os.path.join(feat_source, row[3], row[0], '1.npz') 105 | feat_img = cv2.imread(feature_visualize(img, feat_path)) 106 | print(feat_img.shape) 107 | caption_img = np.ones([feat_img.shape[0]//4, feat_img.shape[1], 3]) * 255 108 | 109 | wrapped_text = textwrap.wrap(row[1], width=35) 110 | x, y = 10, 40 111 | font_size = 1 112 | font_thickness = 2 113 | font = cv2.FONT_HERSHEY_TRIPLEX 114 | 115 | for i, line in enumerate(wrapped_text): 116 | textsize = cv2.getTextSize(line, font, font_size, font_thickness)[0] 117 | 118 | gap = textsize[1] + 10 119 | 120 | y = 25 + i * gap 121 | x = int((caption_img.shape[1] - textsize[0]) / 2) + 20 122 | 123 | cv2.putText(caption_img, line, (x, y), font, 124 | font_size, 125 | (122, 21, 91), 126 | font_thickness, 127 | lineType=cv2.LINE_AA) 128 | 129 | # cv2.putText(caption_img, row[1], (30, 30), cv2.FONT_HERSHEY_TRIPLEX, 1, (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)), 130 | # 1) 131 | concat = np.concatenate([feat_img[50:-50, :, :], caption_img], axis=0) 132 | # cv2.imshow(concat) 133 | out_file = os.path.join(output, row[3] + row[0] + '.png') 134 | print("hello world") 135 | cv2.imwrite(out_file, concat) 136 | # if cv2.waitKey(33) == 27: 137 | # continue 138 | -------------------------------------------------------------------------------- /OATrans/utils/visualizer.py: -------------------------------------------------------------------------------- 1 | """A simple HTML visualizer. 2 | 3 | It is based on the Cycle-GAN codebase: 4 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 5 | """ 6 | import os 7 | import numpy as np 8 | from pathlib import Path 9 | from . import util, html 10 | import pdb 11 | 12 | class RetrievalVis: 13 | """This class includes several functions that can display/save images. 14 | 15 | It uses a Python library 'visdom' for display, and a Python library 'dominate' 16 | (wrapped in 'HTML') for creating HTML files with images. 
17 | """ 18 | 19 | def __init__(self, exp_name, web_dir, src_video_dir, vis_vid_freq, num_samples=50): 20 | """Initialize the Visualizer class 21 | Create an HTML object for saveing HTML filters 22 | """ 23 | self.name = exp_name 24 | self.web_dir = web_dir 25 | self.vis_vid_freq = vis_vid_freq 26 | self.img_dir = os.path.join(self.web_dir, "images") 27 | self.num_samples = num_samples 28 | 29 | self.data_type = 'images' # 'images' or 'videos' 30 | assert self.data_type in ('images', 'videos') 31 | 32 | print(f"create web directory {self.web_dir}...") 33 | mkdirs([self.web_dir, self.img_dir]) 34 | 35 | # cluster specific 36 | if "$TMPDIR" in src_video_dir: 37 | src_video_dir = src_video_dir.replace("$TMPDIR", os.environ['TMPDIR']) 38 | 39 | src_dir = Path(src_video_dir).absolute() 40 | print(f"symlinking videos from {src_dir}...") 41 | sym_dir = (Path(self.web_dir) / "videos").absolute() 42 | if sym_dir.is_symlink(): 43 | os.remove(sym_dir) 44 | sym_dir.symlink_to(src_dir) 45 | 46 | def visualize_ranking(self, sims, epoch, meta, nested_metrics): 47 | if not (self.vis_vid_freq and epoch % self.vis_vid_freq == 0): 48 | return 49 | 50 | dists = -sims 51 | np.random.seed(0) 52 | sorted_ranks = np.argsort(dists, axis=1) 53 | gt_dists = np.diag(dists) 54 | rankings = [] 55 | vis_top_k = 5 56 | hide_gt = False 57 | # num_indep_samples = 1 58 | # random_seeds = np.arange(num_indep_samples) 59 | sample = np.random.choice(np.arange(dists.shape[0]), size=self.num_samples, 60 | replace=False) 61 | for ii in sample: 62 | ranked_idx = sorted_ranks[ii][:vis_top_k] 63 | gt_captions = meta["raw_captions"][ii] 64 | # if args.sample_single_gt_caption: 65 | # gt_captions = np.random.choice(gt_captions, 1).tolist() 66 | datum = { 67 | "gt-sim": -gt_dists[ii], 68 | "gt-captions": gt_captions, 69 | "gt-rank": np.where(sorted_ranks[ii] == ii)[0][0], 70 | "gt-path": meta["paths"][ii], 71 | "top-k-sims": -dists[ii][ranked_idx], 72 | "top-k-paths": np.array(meta["paths"])[ranked_idx], 73 | "hide-gt": hide_gt, 74 | } 75 | rankings.append(datum) 76 | self.display_current_results( 77 | rankings, 78 | epoch=epoch, 79 | metrics=nested_metrics["t2v_metrics"], 80 | ) 81 | 82 | def display_current_results(self, rankings, epoch, metrics): 83 | """Display current results on visdom; save current results to an HTML file. 84 | 85 | Parameters: 86 | visuals (OrderedDict) - - dictionary of images to display or save 87 | epoch (int) - - the current epoch 88 | save_result (bool) - - if save the current results to an HTML file 89 | """ 90 | if not Path(self.web_dir).exists(): 91 | Path(self.web_dir).mkdir(exist_ok=True, parents=True) 92 | print(f"updating webpage at {self.web_dir}") 93 | title = f"Experiment name = {self.name}" 94 | refresh = True 95 | if not refresh: 96 | print("DISABLING WEB PAGE REFRESH") 97 | webpage = html.HTML(web_dir=self.web_dir, title=title, refresh=refresh) 98 | 99 | msg = f"epoch [{epoch}] - {self.name}" 100 | webpage.add_header(msg) 101 | msg = (f"R1: {metrics['R1']:.1f}, " 102 | f"R5: {metrics['R5']:.1f}, " 103 | f"R10: {metrics['R10']:.1f}, " 104 | f"MedR: {metrics['MedR']}") 105 | webpage.add_header(msg) 106 | print(f"Top {len(rankings[0])} retreived videos at epoch: {epoch}") 107 | 108 | for ranking in rankings: 109 | vids, txts, links = [], [], [] 110 | gt_vid_path = os.path.join('videos', ranking["gt-path"]) 111 | #gt_captions = [" ".join(x) for x in ranking["gt-captions"]] 112 | gt_captions = ranking['gt-captions'] 113 | gt_captions = "
" + (gt_captions) + "
" 114 | if ranking["hide-gt"]: 115 | txts.append(gt_captions) 116 | links.append("hidden") 117 | vids.append("hidden") 118 | else: 119 | txt = (f"{gt_captions}
Rank: {ranking['gt-rank']}, " 120 | f"Sim: {ranking['gt-sim']:.3f} [{Path(ranking['gt-path']).stem}]") 121 | txts.append(txt) 122 | links.append(gt_vid_path) 123 | vids.append(gt_vid_path) 124 | 125 | for idx, (vid_path, sim) in enumerate(zip(ranking["top-k-paths"], 126 | ranking["top-k-sims"])): 127 | vid_path = Path(os.path.join('videos', vid_path)) 128 | if ranking["hide-gt"]: 129 | txt = f"choice: {idx}" 130 | else: 131 | txt = f"Rank: {idx}, Sim: {sim:.3f}, [{Path(vid_path).stem}]" 132 | txts.append(txt) 133 | vids.append(vid_path) 134 | links.append(vid_path) 135 | if self.data_type == 'videos': 136 | webpage.add_videos(vids, txts, links, width=200) 137 | elif self.data_type == 'images': 138 | webpage.add_images(vids, txts, links, width=200) 139 | print(f"added {len(vids)} videos") 140 | webpage.save() 141 | 142 | def mkdirs(paths): 143 | """create empty directories if they don't exist 144 | 145 | Parameters: 146 | paths (str list) -- a list of directory paths 147 | """ 148 | if isinstance(paths, list) and not isinstance(paths, str): 149 | for path in paths: 150 | mkdir(path) 151 | else: 152 | mkdir(paths) 153 | 154 | 155 | def mkdir(path): 156 | """create a single empty directory if it didn't exist 157 | 158 | Parameters: 159 | path (str) -- a single directory path 160 | """ 161 | if not os.path.exists(path): 162 | os.makedirs(path) 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [CVPR 22] "Object-aware Video-language Pre-training for Retrieval" [arxiv](https://arxiv.org/abs/2112.00656) 2 | 3 | 4 | ![](figures/oa_main_ppl.jpg) 5 | 6 | 7 | ## 1. Object Feature Extractor 8 | 9 | We provide a faster version to extract object from WebVid 2.5M and CC 3M. 10 | We extract objects of 5.5M * 8 = 44M frames in total and it takes 28 days on 16 V100 GPUs. 11 | 12 | Refer to [Object Extractor.md](object_extraction.md) for more details. 13 | 14 | 15 | ## 2. OA Trans 16 | 17 | Refer to [train.md](train.md) for more details. 18 | 19 | ## 3. Visualizations 20 | 21 | In this code, we provide two ways to visualize cross-modality attention. 22 | 23 | ### Heatmap Visualization 24 | ![](figures/oa_visualize_1.jpg) 25 | 26 | 27 | ### Binary Map Visualization 28 | ![](figures/oa_visualize_2.jpg) 29 | 30 | Please refer to [visualization.md](visualization.md) for details. 31 | 32 | 33 | ## News: 34 | - 2021.12.5 Arxiv Version Published. 35 | - 2022.3.15 First version Code Released. 36 | 37 | ## 5. Citation 38 | 39 | If you find our work helpful, please cite our paper 40 | ```bash 41 | @article{wang2022oatrans, 42 | title={Object-aware Video-language Pre-training for Retrieval}, 43 | author={Wang, Alex Jinpeng and Ge, Yixiao and Cai, Guanyu and Yan, Rui and Lin, Xudong and Shan, Ying and Qie, Xiaohu and Shou, Mike Zheng}, 44 | journal={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 45 | year={2022} 46 | } 47 | ``` 48 | 49 | ## Acknowledgement 50 | 51 | This work is mainly based on [Frozen](https://github.com/m-bain/frozen-in-time). 
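The ranking pages produced by `RetrievalVis.visualize_ranking` (and the reported R1/R5/R10/MedR numbers) are driven by a plain text-to-video similarity matrix `sims`. The snippet below is a minimal illustrative sketch of how such a cosine-similarity matrix can be computed; it is not code from this repository, and the names `cosine_sim_matrix`, `text_embeds`, and `vid_embeds` are placeholders for the projected caption and video embeddings.

```python
import torch
import torch.nn.functional as F

def cosine_sim_matrix(text_embeds: torch.Tensor, vid_embeds: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between every (caption, video) pair.

    text_embeds: (N, D) projected caption embeddings
    vid_embeds:  (M, D) projected video embeddings
    returns:     (N, M) similarity matrix, usable as the `sims` input of RetrievalVis
    """
    text_embeds = F.normalize(text_embeds, dim=-1)  # L2-normalize so dot product = cosine
    vid_embeds = F.normalize(vid_embeds, dim=-1)
    return text_embeds @ vid_embeds.t()
```

Ranking a caption against the video gallery is then a row-wise sort of `-sims`, which is exactly what `visualize_ranking` does before rendering the top-k retrieved videos.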
-------------------------------------------------------------------------------- /Visualization/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/Visualization/.DS_Store -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/Visualization/Cross_Modality_Transformer_Visualization/.DS_Store -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/data_preprocess.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | from torchvision import transforms 5 | import torch 6 | from PIL import Image, ImageFile 7 | 8 | 9 | def sample_frames(num_frames, vlen, sample='rand', fix_start=None): 10 | acc_samples = min(num_frames, vlen) 11 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) 12 | ranges = [] 13 | for idx, interv in enumerate(intervals[:-1]): 14 | ranges.append((interv, intervals[idx + 1] - 1)) 15 | if sample == 'rand': 16 | frame_idxs = [random.choice(range(x[0], x[1])) for x in ranges] 17 | elif fix_start is not None: 18 | frame_idxs = [x[0] + fix_start for x in ranges] 19 | elif sample == 'uniform': 20 | frame_idxs = [(x[0] + x[1]) // 2 for x in ranges] 21 | else: 22 | raise NotImplementedError 23 | return frame_idxs 24 | 25 | 26 | def read_frames_cv2(video_path, num_frames, sample='rand', fix_start=None, numpy=False): 27 | cap = cv2.VideoCapture(video_path) 28 | assert (cap.isOpened()) 29 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 30 | # get indexes of sampled frames 31 | frame_idxs = sample_frames(num_frames, vlen, sample=sample, fix_start=fix_start) 32 | frames = [] 33 | success_idxs = [] 34 | for index in frame_idxs: 35 | cap.set(cv2.CAP_PROP_POS_FRAMES, index - 1) 36 | ret, frame = cap.read() 37 | # print(frame.shape) 38 | if ret: 39 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 40 | frame = torch.from_numpy(frame) 41 | # (H x W x C) to (C x H x W) 42 | frame = frame.permute(2, 0, 1) 43 | frames.append(frame) 44 | success_idxs.append(index) 45 | else: 46 | pass 47 | # print(frame_idxs, ' fail ', index, f' (vlen {vlen})') 48 | if not numpy: 49 | frames = torch.stack(frames).float() / 255 50 | cap.release() 51 | return frames, success_idxs, vlen 52 | 53 | 54 | def vision_preprocess(vid_src): 55 | video, _, _ = read_frames_cv2(vid_src, 1) 56 | transform = transforms.Compose( 57 | [ 58 | transforms.Resize(size=(224, 224)), 59 | # transforms.RandomResizedCrop(size=(224, 224)), 60 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 61 | ] 62 | ) 63 | video = transform(video) 64 | # expand one dim as batch 65 | video = video.unsqueeze(0) 66 | return video.cuda() 67 | 68 | 69 | def mask_vision_preprocess(vid_src): 70 | video, _, _ = read_frames_cv2(vid_src, 8) 71 | transform = transforms.Compose( 72 | [ 73 | transforms.Resize(size=(224, 224)), 74 | # transforms.RandomResizedCrop(size=(224, 224)), 75 | # transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 76 | ] 77 | ) 78 | video = transform(video[:3]) 79 | return video 80 | 81 | 82 
| def vision_img_preprocess(img_src): 83 | img = Image.open(img_src).convert("RGB") 84 | img = transforms.ToTensor()(img).unsqueeze(0) 85 | transform = transforms.Compose( 86 | [ 87 | transforms.Resize(size=(224, 224)), 88 | # transforms.RandomResizedCrop(size=(224, 224)), 89 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 90 | ] 91 | ) 92 | img = transform(img) 93 | # expand one dim as batch 94 | img = img.unsqueeze(0) 95 | return img.cuda() 96 | 97 | 98 | def clip_img_preprocess(img_src, preprocess): 99 | img = Image.open(img_src).convert("RGB") 100 | img = preprocess(img).unsqueeze(0) 101 | # print(img.size()) 102 | return img.cuda().half() -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/main_img.py: -------------------------------------------------------------------------------- 1 | from model.text_model import text_encode_init 2 | from model.vision_model import vision_encode_init 3 | from data_preprocess import vision_img_preprocess, clip_img_preprocess 4 | import pandas as pd 5 | from visualize import cross_attention_visualize 6 | import parse_config as parse_config 7 | import model.vision_models.clip as clip 8 | 9 | 10 | csv_file = "data/meta_data/cc3m_training_success_full.tsv" 11 | # out_dir = 'output/featmap/' 12 | model_se = 'frozen' # 'frozen' or 'clip' 13 | # out_dir = 'output/featmap/{}/'.format(model_se) 14 | out_dir = 'output/cross_featmap/cc3m/{}_attn/'.format(model_se) 15 | video_root = 'CC3M/training/' 16 | metadata = pd.read_csv(csv_file, sep='\t') 17 | text_model = text_encode_init(model_name=model_se) 18 | img_model, preprocess = vision_encode_init(model_name=model_se) 19 | 20 | count = 0 21 | for item in range(len(metadata)): 22 | sample = metadata.iloc[item] 23 | video_src = video_root + sample[1] 24 | caption = sample[0] 25 | if model_se == 'clip': 26 | img = clip_img_preprocess(video_src, preprocess) 27 | else: 28 | img = vision_img_preprocess(video_src) 29 | print(img.size()) 30 | img_patch_embedding = img_model(img) 31 | if model_se == 'clip': 32 | img = img.unsqueeze(0) 33 | if model_se == 'clip': 34 | text_token = text_model(clip.tokenize(caption).cuda()) 35 | else: 36 | text_token = text_model(caption) 37 | # print(img_patch_embedding.size()) 38 | if model_se == 'clip': 39 | img = img.float() 40 | cross_attention_visualize(img_patch_embedding, img[0], caption, text_token, text_model, model_name=model_se, 41 | name=out_dir + str(item), v=1) 42 | count += 1 43 | if count > 500: 44 | break 45 | 46 | 47 | -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/main_video.py: -------------------------------------------------------------------------------- 1 | from model.text_model import text_encode_init 2 | from model.vision_model import vision_encode_init 3 | from data_preprocess import vision_preprocess 4 | import pandas as pd 5 | from visualize import cross_attention_visualize 6 | import parse_config as parse_config 7 | 8 | 9 | csv_file = "data/webvid_validation_success_full.tsv" 10 | # out_dir = 'output/featmap/' 11 | out_dir = 'output/cross_featmap/' 12 | video_root = 'WebVid/val/' 13 | metadata = pd.read_csv(csv_file, sep='\t') 14 | text_model = text_encode_init() 15 | video_model = vision_encode_init() 16 | 17 | count = 0 18 | for item in range(len(metadata)): 19 | sample = metadata.iloc[item] 20 | video_src = video_root + sample[1] + '.mp4' 21 | caption = 
sample[0] 22 | # print(video_src) 23 | video = vision_preprocess(video_src) 24 | # print(video.size()) 25 | video_patch_embedding = video_model(video) 26 | print(caption) 27 | text_token = text_model(caption) 28 | # print(video_patch_embedding.size()) 29 | # print(text_token.size()) 30 | sim = cross_attention_visualize(video_patch_embedding, video[0], caption, text_token, text_model, name=out_dir + str(item)) 31 | 32 | count += 1 33 | if count > 100: 34 | break 35 | 36 | -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/main_video_patches_visualization.py: -------------------------------------------------------------------------------- 1 | from data_preprocess import mask_vision_preprocess 2 | import pandas as pd 3 | import os 4 | from patch_mask import visualize_mask 5 | from utils.read_bboxs import read_bbox_from_pickle 6 | 7 | 8 | csv_file = "data/webvid_validation_success_full.tsv" 9 | # out_dir = 'output/featmap/' 10 | out_dir = 'output/mask_object_visualization/' 11 | video_root = 'WebVid/val/' 12 | metadata = pd.read_csv(csv_file, sep='\t') 13 | features_root = 'WebVid/8_frame_object' 14 | 15 | 16 | 17 | count = 0 18 | for item in range(len(metadata)): 19 | sample = metadata.iloc[item] 20 | video_src = video_root + sample[1] + '.mp4' 21 | video = mask_vision_preprocess(video_src) 22 | object_bboxs = [] 23 | for i in range(3): 24 | rel_object_fp = os.path.join(sample[1], '{}.npz'.format(i)) 25 | full_object_fp = os.path.join(features_root, 'val', rel_object_fp) 26 | object_bboxs.append(read_bbox_from_pickle(full_object_fp)) 27 | # print(video_src) 28 | out_name = out_dir + str(item) 29 | visualize_mask(video, object_bboxs, out_name) 30 | count += 1 31 | if count > 100: 32 | break 33 | 34 | -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/Visualization/Cross_Modality_Transformer_Visualization/model/__init__.py -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/model/text_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import AutoModel 3 | import transformers 4 | import torch 5 | import model.vision_models.clip as clip 6 | 7 | class TextEncoder(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | self.text_model = AutoModel.from_pretrained('pretrained/distilbert-base-uncased') 11 | self.tokenizer = transformers.AutoTokenizer.from_pretrained('pretrained/distilbert-base-uncased', 12 | TOKENIZERS_PARALLELISM=False) 13 | self.device = "cuda:0" 14 | self.txt_proj = nn.Sequential(nn.ReLU(), 15 | nn.Linear(768, 256), 16 | ) 17 | def token_of_word(self, word): 18 | token = self.tokenizer(word, return_tensors='pt', padding=True, 19 | truncation=True) 20 | return token 21 | 22 | def forward(self, x): 23 | if self.tokenizer is not None: 24 | x = self.tokenizer(x, return_tensors='pt', padding=True, 25 | truncation=True) 26 | x = {key: val.to(self.device) for key, val in x.items()} 27 | text_embeddings_all = self.text_model(**x).last_hidden_state 28 | # print(text_embeddings_all.size()) # batch_size, sequence_length, hidden_size 29 | # 
text_embeddings = text_embeddings_all[:, 0, :] 30 | text_embeddings = text_embeddings_all 31 | # print(text_embeddings.size()) 32 | return self.txt_proj(text_embeddings) 33 | # return text_embeddings 34 | 35 | def weight_transform(model_dict, pretrain_dict): 36 | ''' 37 | copy matching weights (stripping the 'module.' prefix) from a distributed checkpoint 38 | ''' 39 | weight_dict = {k[7:]:v for k, v in pretrain_dict.items() if k[7:] in model_dict and k[:7] == 'module.'} 40 | # for k, v in pretrain_dict.items(): 41 | # print(k[7:]) 42 | # # pdb.set_trace() 43 | for k, v in pretrain_dict.items(): 44 | if k[:15] == 'module.txt_proj': # also carry over the text projection head 45 | weight_dict[k[7:]] = v 46 | for k, v in weight_dict.items(): 47 | print("load: {}".format(k)) 48 | # print(weight_dict) 49 | model_dict.update(weight_dict) 50 | return model_dict 51 | 52 | def load_pt_weight(model): 53 | checkpoint = torch.load("pretrained/cc-webvid2m-4f_stformer_b_16_224.pth.tar", map_location="cpu") 54 | pretrained_state = checkpoint['state_dict'] 55 | model_state = model.state_dict() 56 | # for k , v in model_state.items(): 57 | # print(k) 58 | model.load_state_dict(weight_transform(model_state, pretrained_state)) 59 | return model 60 | 61 | 62 | def text_encode_init(model_name='frozen'): 63 | if model_name == 'clip': 64 | full_model, preprocess = clip.load("pretrained/ViT-B-16.pt") 65 | model = full_model.encode_text 66 | else: 67 | model = TextEncoder() 68 | load_pt_weight(model) 69 | model = model.cuda() 70 | return model -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/model/text_models/distill_bert.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import AutoModel 3 | import transformers 4 | 5 | 6 | class DistillBert(nn.Module): 7 | def __init__(self, text_params, config, device="cuda:0"): 8 | super().__init__() 9 | self.device = device 10 | self.text_model = AutoModel.from_pretrained(text_params['model']) 11 | self.tokenizer = transformers.AutoTokenizer.from_pretrained(config['arch']['args']['text_params']['model'], 12 | TOKENIZERS_PARALLELISM=False) 13 | 14 | def forward(self, x): 15 | if self.tokenizer is not None: 16 | x = self.tokenizer(x, return_tensors='pt', padding=True, 17 | truncation=True) 18 | x = {key: val.to(self.device) for key, val in x.items()} 19 | text_embeddings_all = self.text_model(**x).last_hidden_state # (batch_size, sequence_length, hidden_size) 20 | text_embeddings = text_embeddings_all[:, 0, :] # keep the [CLS] token embedding 21 | return text_embeddings -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/model/vision_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model.vision_models.frozen import SpaceTimeTransformer 3 | import model.vision_models.clip as clip 4 | 5 | 6 | def weight_transform(model_dict, pretrain_dict): 7 | ''' 8 | copy matching video-model weights (stripping the 'module.video_model.' prefix) 9 | ''' 10 | weight_dict = {k[19:]:v for k, v in pretrain_dict.items() if k[19:] in model_dict and k[:19] == 'module.video_model.'} 11 | for k, v in pretrain_dict.items(): 12 | print(k[19:]) 13 | # pdb.set_trace() 14 | for k, v in pretrain_dict.items(): 15 | if k[:15] == 'module.vid_proj': 16 | weight_dict[k[7:]] = v 17 | for k, v in weight_dict.items(): 18 | print("load: {}".format(k)) 19 | # print(weight_dict) 20 | model_dict.update(weight_dict) 21 | return model_dict 22 | 23 | 24 | def load_pt_weight(model): 25 | """ 26 | load the object transformer weight from clip vision
transformer 27 | notice some of have failed 28 | Args: 29 | model (): 30 | 31 | Returns: 32 | 33 | """ 34 | checkpoint = torch.load("pretrained/cc-webvid2m-4f_stformer_b_16_224.pth.tar", map_location="cpu") 35 | pretrained_state = checkpoint['state_dict'] 36 | # model.load_state_dict(vit_checkpoint, strict=False) 37 | # pretrain_model = torch.jit.load('pretrained/ViT-B-16.pt') 38 | # pretrained_state = pretrain_model.state_dict() 39 | model_state = model.state_dict() 40 | # for k, v in model_state.items(): 41 | # print(k) 42 | model.load_state_dict(weight_transform(model_state, pretrained_state)) 43 | return model 44 | 45 | 46 | def vision_encode_init(model_name="frozen"): 47 | # frozen 48 | preprocess = None 49 | if model_name == 'clip': 50 | full_model, preprocess = clip.load("pretrained/ViT-B-16.pt") 51 | model = full_model.visual 52 | elif model_name == 'frozen': 53 | model = SpaceTimeTransformer() 54 | load_pt_weight(model) 55 | else: 56 | print("not support") 57 | model = model.cuda() 58 | return model, preprocess -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/model/vision_models/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/model/vision_models/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/Visualization/Cross_Modality_Transformer_Visualization/model/vision_models/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/model/vision_models/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce 5 | from operator import getitem 6 | from datetime import datetime 7 | # from logger import setup_logging 8 | # from utils import read_json, write_json 9 | import time 10 | import inspect 11 | import pdb 12 | 13 | class
ConfigParser: 14 | def __init__(self, args, options='', timestamp=True, test=False): 15 | # parse default and custom cli options 16 | for opt in options: 17 | args.add_argument(*opt.flags, default=None, type=opt.type) 18 | args = args.parse_args() 19 | 20 | # if args.device: 21 | # os.environ["CUDA_VISIBLE_DEVICES"] = args.device 22 | # if args.resume is None: 23 | # msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 24 | # assert args.config is not None, msg_no_cfg 25 | # self.cfg_fname = Path(args.config) 26 | # config = read_json(self.cfg_fname) 27 | # self.resume = None 28 | # else: 29 | # self.resume = Path(args.resume) 30 | # resume_cfg_fname = self.resume.parent / 'config.json' 31 | # config = read_json(resume_cfg_fname) 32 | # if args.config is not None: 33 | # config.update(read_json(Path(args.config))) 34 | # 35 | # # load config file and apply custom cli options 36 | # self._config = _update_config(config, options, args) 37 | # 38 | # # set save_dir where trained model and log will be saved. 39 | # save_dir = Path(self.config['trainer']['save_dir']) 40 | # timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else '' 41 | # 42 | # exper_name = self.config['name'] 43 | # self._save_dir = save_dir / 'models' / exper_name / timestamp 44 | # self._web_log_dir = save_dir / 'web' / exper_name / timestamp 45 | # self._log_dir = save_dir / 'log' / exper_name / timestamp 46 | # 47 | # if not test: 48 | # self.save_dir.mkdir(parents=True, exist_ok=True) 49 | # self.log_dir.mkdir(parents=True, exist_ok=True) 50 | # 51 | # # if set, remove all previous experiments with the current config 52 | # if vars(args).get("purge_exp_dir", False): 53 | # for dirpath in (self._save_dir, self._log_dir, self._web_log_dir): 54 | # config_dir = dirpath.parent 55 | # existing = list(config_dir.glob("*")) 56 | # print(f"purging {len(existing)} directories from config_dir...") 57 | # tic = time.time() 58 | # os.system(f"rm -rf {config_dir}") 59 | # print(f"Finished purge in {time.time() - tic:.3f}s") 60 | # 61 | # # save updated config file to the checkpoint dir 62 | # if not test: 63 | # write_json(self.config, self.save_dir / 'config.json') 64 | # 65 | # # configure logging module 66 | # setup_logging(self.log_dir) 67 | # self.log_levels = { 68 | # 0: logging.WARNING, 69 | # 1: logging.INFO, 70 | # 2: logging.DEBUG 71 | # } 72 | 73 | def initialize(self, name, module, *args, index=None, **kwargs): 74 | """ 75 | finds a function handle with the name given as 'type' in config, and returns the 76 | instance initialized with corresponding keyword args given as 'args'. 77 | """ 78 | if index is None: 79 | module_name = self[name]['type'] 80 | module_args = dict(self[name]['args']) 81 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 82 | module_args.update(kwargs) 83 | else: 84 | module_name = self[name][index]['type'] 85 | module_args = dict(self[name][index]['args']) 86 | # pdb.set_trace() 87 | # if parameter not in config subdict, then check if it's in global config. 
88 | signature = inspect.signature(getattr(module, module_name).__init__) 89 | print(module_name) 90 | for param in signature.parameters.keys(): 91 | if param not in module_args and param in self.config: 92 | module_args[param] = self[param] 93 | # pdb.set_trace() 94 | 95 | return getattr(module, module_name)(*args, **module_args) 96 | 97 | def __getitem__(self, name): 98 | return self.config[name] 99 | 100 | def get_logger(self, name, verbosity=2): 101 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, 102 | self.log_levels.keys()) 103 | assert verbosity in self.log_levels, msg_verbosity 104 | logger = logging.getLogger(name) 105 | logger.setLevel(self.log_levels[verbosity]) 106 | return logger 107 | 108 | # setting read-only attributes 109 | @property 110 | def config(self): 111 | return self._config 112 | 113 | @property 114 | def save_dir(self): 115 | return self._save_dir 116 | 117 | @property 118 | def log_dir(self): 119 | return self._log_dir 120 | 121 | 122 | # helper functions used to update config dict with custom cli options 123 | def _update_config(config, options, args): 124 | for opt in options: 125 | value = getattr(args, _get_opt_name(opt.flags)) 126 | if value is not None: 127 | _set_by_path(config, opt.target, value) 128 | return config 129 | 130 | 131 | def _get_opt_name(flags): 132 | for flg in flags: 133 | if flg.startswith('--'): 134 | return flg.replace('--', '') 135 | return flags[0].replace('--', '') 136 | 137 | 138 | def _set_by_path(tree, keys, value): 139 | """Set a value in a nested object in tree by sequence of keys.""" 140 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 141 | 142 | 143 | def _get_by_path(tree, keys): 144 | """Access a nested object in tree by sequence of keys.""" 145 | return reduce(getitem, keys, tree) 146 | -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/patch_mask.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | def patch_all_masks_from_bbox(bboxs, patch_rows=14): 7 | # generate patch masks from all bboxs 8 | # notice here bbox region is [1:3][0:2] 9 | patch_masks = np.zeros((patch_rows, patch_rows)) 10 | bboxs[:, :4] = bboxs[:, :4] * patch_rows 11 | for index, bbox in enumerate(bboxs): 12 | if bbox[4] > 7 and bbox[5] > 7: 13 | bbox[0] += 1 / 3 * (bbox[2] - bbox[0]) 14 | bbox[1] += 1 / 3 * (bbox[3] - bbox[1]) 15 | bbox[2] -= 1 / 3 * (bbox[2] - bbox[0]) 16 | bbox[3] -= 1 / 3 * (bbox[3] - bbox[1]) 17 | patch_masks[int(bbox[1]):math.ceil(bbox[3]), int(bbox[0]):math.ceil(bbox[2])] = 1 18 | return patch_masks 19 | 20 | 21 | def image_mask_from_bbox(bboxs, img_shape): 22 | # print(img_shape) 23 | print(bboxs) 24 | w, h = img_shape[1:] 25 | mask = np.zeros((w, h)) 26 | for index, bbox in enumerate(bboxs): 27 | print(int(bbox[0].item()), int(bbox[2].item()), int(bbox[1].item()), int(bbox[3].item())) 28 | # # print(bbox) 29 | if bbox[4] > 0.5 and bbox[5] > 0.5: 30 | bbox[0] += 1 / 3 * (bbox[2] - bbox[0]) 31 | bbox[1] += 1 / 3 * (bbox[3] - bbox[1]) 32 | bbox[2] -= 1 / 3 * (bbox[2] - bbox[0]) 33 | bbox[3] -= 1 / 3 * (bbox[3] - bbox[1]) 34 | # print(bbox[0]) 35 | bbox[0] = bbox[0] * w 36 | bbox[1] = bbox[1] * h 37 | bbox[2] = bbox[2] * w 38 | bbox[3] = bbox[3] * h 39 | mask[int(bbox[0].item()): int(bbox[2].item()), int(bbox[1].item()):int(bbox[3].item())] = 1 40 | print(mask) 41 | return mask 42 | 43 | 44 | def 
visualize_mask(video, bboxs, out_path): 45 | """ 46 | visualize three samples frames and show the masked videos 47 | Args: 48 | video: 49 | bboxs: 50 | out_path: 51 | 52 | Returns: 53 | 54 | """ 55 | num_frames = len(video) 56 | out_imgs = None 57 | for index in range(num_frames): 58 | img = video[index] * 255. 59 | bbox_10 = bboxs[index] 60 | masks = image_mask_from_bbox(bbox_10, img.shape) 61 | mask_img = img * masks 62 | if out_imgs is None: 63 | out_imgs = np.concatenate((img, mask_img), axis=2) 64 | else: 65 | out_imgs = np.concatenate((out_imgs, mask_img), axis=2) 66 | # print(out_imgs) 67 | # print(out_imgs.shape) 68 | out_imgs = np.moveaxis(out_imgs, 0, 2) 69 | # print(out_imgs) 70 | print(out_imgs.shape) 71 | cv2.imwrite('{}.png'.format(out_path), out_imgs) 72 | 73 | # def visualize_mask(video, bboxs, out_path): 74 | # """ 75 | # visualize three samples frames and show the masked videos 76 | # Args: 77 | # video: 78 | # bboxs: 79 | # out_path: 80 | # 81 | # Returns: 82 | # 83 | # """ 84 | # num_frames = len(video) 85 | # out_imgs = None 86 | # for index in range(num_frames): 87 | # img = video[index] * 255. 88 | # img = img.permute(1, 2, 0) 89 | # print(img.shape) 90 | # img = cv2.resize(np.float32(img), (14, 14)) 91 | # bbox_10 = bboxs[index] 92 | # # masks = image_mask_from_bbox(bbox_10, img.shape) 93 | # masks = patch_all_masks_from_bbox(bbox_10) 94 | # # print(masks) 95 | # mask_img = img * np.expand_dims(masks, axis=2) 96 | # if out_imgs is None: 97 | # out_imgs = np.concatenate((img, mask_img), axis=1) 98 | # else: 99 | # out_imgs = np.concatenate((out_imgs, mask_img), axis=1) 100 | # # print(out_imgs) 101 | # # print(out_imgs.shape) 102 | # # out_imgs = np.moveaxis(out_imgs, 0, 2) 103 | # # print(out_imgs) 104 | # print(out_imgs.shape) 105 | # cv2.imwrite('{}.png'.format(out_path), out_imgs) -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/utils/nltk_test.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | nltk.data.path.append("pretrained/nltk_data") 3 | 4 | 5 | def check_nouns(words): 6 | is_noun = lambda pos: pos[:2] == 'NN' 7 | # do the nlp stuff 8 | tokenized = nltk.word_tokenize(words) 9 | nouns = [word for (word, pos) in nltk.pos_tag(tokenized) if is_noun(pos)] 10 | if len(nouns) > 0: 11 | return True 12 | else: 13 | return False 14 | 15 | lines = 'lines is some string of words' 16 | # function to test if something is a noun 17 | is_noun = lambda pos: pos[:2] == 'NN' 18 | # do the nlp stuff 19 | tokenized = nltk.word_tokenize(lines) 20 | nouns = [word for (word, pos) in nltk.pos_tag(tokenized) if is_noun(pos)] 21 | 22 | print(nouns) 23 | word = 'woman' 24 | print(check_nouns(word)) -------------------------------------------------------------------------------- /Visualization/Cross_Modality_Transformer_Visualization/utils/read_bboxs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import torch 4 | 5 | 6 | def read_bbox_from_pickle(object_path, top_k=5, v=1): 7 | frame1 = np.load(object_path, allow_pickle=True) 8 | boxes = frame1['bbox'] 9 | # rank features and boxes according to confidence 10 | confident = frame1['info'].item()['objects_conf'] 11 | condident_indices = np.argsort(confident)[::-1] 12 | boxes = boxes[condident_indices] 13 | object_ids = frame1['info'].item()['objects_id'] 14 | if v == 2: 15 | new_object, unique_indices = 
np.unique(object_ids, return_index=True) 16 | # step 2: remove region with same object class 17 | boxes = boxes[unique_indices] 18 | # padding with same elements if not enough 19 | if boxes.shape[0] < top_k: 20 | res = top_k - boxes.shape[0] 21 | boxes = np.pad(boxes, (0, res), 'edge') 22 | boxes = boxes[:top_k, :] 23 | image_width = frame1['info'].item()['image_w'] 24 | image_height = frame1['info'].item()['image_h'] 25 | box_width = boxes[:, 2] - boxes[:, 0] 26 | box_height = boxes[:, 3] - boxes[:, 1] 27 | scaled_width = box_width / image_width 28 | scaled_height = box_height / image_height 29 | scaled_x = boxes[:, 0] / image_width 30 | scaled_y = boxes[:, 1] / image_height 31 | scaled_width = scaled_width[..., np.newaxis] 32 | scaled_height = scaled_height[..., np.newaxis] 33 | scaled_x = scaled_x[..., np.newaxis] 34 | scaled_y = scaled_y[..., np.newaxis] 35 | spatial_features = np.concatenate( 36 | (scaled_x, scaled_y, scaled_x + scaled_width, scaled_y + scaled_height, scaled_width, scaled_height), axis=1) 37 | return torch.from_numpy(spatial_features) 38 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: frozen 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=1_gnu 10 | - av=6.1.0=py37he005a31_1000 11 | - backcall=0.2.0=pyhd3eb1b0_0 12 | - blas=1.0=mkl 13 | - brotlipy=0.7.0=py37h5e8e339_1001 14 | - bzip2=1.0.8=h7b6447c_0 15 | - ca-certificates=2020.12.5=ha878542_0 16 | - cairo=1.16.0=hf32fb01_1 17 | - certifi=2020.12.5=py37h89c1867_1 18 | - cffi=1.14.5=py37hc58025e_0 19 | - chardet=4.0.0=py37h89c1867_1 20 | - click=8.0.0=py37h89c1867_0 21 | - cryptography=3.4.7=py37h5d9358c_0 22 | - cudatoolkit=11.1.74=h6bb024c_0 23 | - dataclasses=0.8=pyhc8e2a94_1 24 | - decorator=5.0.7=pyhd3eb1b0_0 25 | - ffmpeg=4.0=hcdf2ecd_0 26 | - filelock=3.0.12=pyh9f0ad1d_0 27 | - fontconfig=2.13.1=h6c09931_0 28 | - freeglut=3.0.0=hf484d3e_5 29 | - freetype=2.10.4=h5ab3b9f_0 30 | - glib=2.68.1=h36276a3_0 31 | - graphite2=1.3.14=h23475e2_0 32 | - harfbuzz=1.8.8=hffaf4a1_0 33 | - hdf5=1.10.2=hba1933b_1 34 | - huggingface_hub=0.0.8=pyhd8ed1ab_0 35 | - humanize=3.5.0=pyhd8ed1ab_0 36 | - icu=58.2=he6710b0_3 37 | - idna=2.10=pyh9f0ad1d_0 38 | - importlib-metadata=4.0.1=py37h89c1867_0 39 | - importlib_metadata=4.0.1=hd8ed1ab_0 40 | - intel-openmp=2021.2.0=h06a4308_610 41 | - ipdb=0.13.7=pyhd8ed1ab_0 42 | - ipython=7.22.0=py37hb070fc8_0 43 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 44 | - jasper=2.0.14=h07fcdf6_1 45 | - jedi=0.17.0=py37_0 46 | - joblib=1.0.1=pyhd3eb1b0_0 47 | - jpeg=9b=h024ee3a_2 48 | - lcms2=2.12=h3be6417_0 49 | - ld_impl_linux-64=2.33.1=h53a641e_7 50 | - libffi=3.3=he6710b0_2 51 | - libgcc-ng=9.3.0=h2828fa1_19 52 | - libgfortran-ng=7.3.0=hdf63c60_0 53 | - libglu=9.0.0=hf484d3e_1 54 | - libgomp=9.3.0=h2828fa1_19 55 | - libopencv=3.4.2=hb342d67_1 56 | - libopus=1.3.1=h7b6447c_0 57 | - libpng=1.6.37=hbc83047_0 58 | - libstdcxx-ng=9.3.0=h6de172a_19 59 | - libtiff=4.1.0=h2733197_1 60 | - libuuid=1.0.3=h1bed415_2 61 | - libuv=1.40.0=h7b6447c_0 62 | - libvpx=1.7.0=h439df22_0 63 | - libxcb=1.14=h7b6447c_0 64 | - libxml2=2.9.10=hb55368b_3 65 | - lz4-c=1.9.3=h2531618_0 66 | - mkl=2021.2.0=h06a4308_296 67 | - mkl-service=2.3.0=py37h27cfd23_1 68 | - mkl_fft=1.3.0=py37h42c9631_2 69 | - mkl_random=1.2.1=py37ha9443f7_2 70 | - msgpack-python=1.0.2=py37hff7bd54_1 71 | - 
ncurses=6.2=he6710b0_1 72 | - ninja=1.10.2=hff7bd54_1 73 | - numpy=1.20.1=py37h93e21f0_0 74 | - numpy-base=1.20.1=py37h7d8b39e_0 75 | - olefile=0.46=py37_0 76 | - opencv=3.4.2=py37h6fd60c2_1 77 | - openssl=1.1.1k=h7f98852_0 78 | - packaging=20.9=pyh44b312d_0 79 | - pandas=1.1.4=py37h10a2094_0 80 | - parso=0.8.2=pyhd3eb1b0_0 81 | - pcre=8.44=he6710b0_0 82 | - pexpect=4.8.0=pyhd3eb1b0_3 83 | - pickleshare=0.7.5=pyhd3eb1b0_1003 84 | - pip=21.0.1=py37h06a4308_0 85 | - pixman=0.40.0=h7b6447c_0 86 | - prompt-toolkit=3.0.17=pyh06a4308_0 87 | - psutil=5.8.0=py37h27cfd23_1 88 | - ptyprocess=0.7.0=pyhd3eb1b0_2 89 | - py-opencv=3.4.2=py37hb342d67_1 90 | - pycparser=2.20=pyh9f0ad1d_2 91 | - pygments=2.8.1=pyhd3eb1b0_0 92 | - pyopenssl=20.0.1=pyhd8ed1ab_0 93 | - pyparsing=2.4.7=pyh9f0ad1d_0 94 | - pysocks=1.7.1=py37h89c1867_3 95 | - python=3.7.10=hdb3f193_0 96 | - python-dateutil=2.8.1=pyhd3eb1b0_0 97 | - python_abi=3.7=1_cp37m 98 | - pytorch=1.8.1=py3.7_cuda11.1_cudnn8.0.5_0 99 | - pytz=2021.1=pyhd3eb1b0_0 100 | - readline=8.1=h27cfd23_0 101 | - regex=2021.4.4=py37h5e8e339_0 102 | - requests=2.25.1=pyhd3deb0d_0 103 | - sacremoses=0.0.43=pyh9f0ad1d_0 104 | - scikit-learn=0.24.1=py37ha9443f7_0 105 | - scipy=1.6.2=py37had2a1c9_1 106 | - setuptools=52.0.0=py37h06a4308_0 107 | - six=1.15.0=py37h06a4308_0 108 | - sqlite=3.35.4=hdfb4753_0 109 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 110 | - tk=8.6.10=hbc83047_0 111 | - tokenizers=0.10.1=py37hcb7a40c_0 112 | - torchaudio=0.8.1=py37 113 | - tqdm=4.60.0=pyhd8ed1ab_0 114 | - traitlets=5.0.5=pyhd3eb1b0_0 115 | - transformers=4.6.0=pyhd8ed1ab_0 116 | - typing_extensions=3.7.4.3=pyha847dfd_0 117 | - urllib3=1.26.4=pyhd8ed1ab_0 118 | - wcwidth=0.2.5=py_0 119 | - wheel=0.36.2=pyhd3eb1b0_0 120 | - xz=5.2.5=h7b6447c_0 121 | - zipp=3.4.1=pyhd8ed1ab_0 122 | - zlib=1.2.11=h7b6447c_3 123 | - zstd=1.4.9=haebb681_0 124 | - pip: 125 | - --find-links https://download.pytorch.org/whl/torch_stable.html 126 | - attrdict==2.0.1 127 | - attrs==21.2.0 128 | - bravado==11.0.3 129 | - bravado-core==5.17.0 130 | - colorama==0.4.4 131 | - cycler==0.10.0 132 | - decord==0.6.0 133 | - dominate==2.6.0 134 | - einops==0.3.0 135 | - future==0.18.2 136 | - gitdb==4.0.7 137 | - gitpython==3.1.17 138 | - jsonpickle==1.5.2 139 | - jsonpointer==2.1 140 | - jsonref==0.2 141 | - jsonschema==3.2.0 142 | - kiwisolver==1.3.1 143 | - matplotlib==3.4.2 144 | - monotonic==1.6 145 | - munch==2.5.0 146 | - neptune-client==0.9.9 147 | - neptune-contrib==0.27.1 148 | - oauthlib==3.1.0 149 | - pillow==8.3.1 150 | - py-cpuinfo==8.0.0 151 | - pyjwt==2.1.0 152 | - pyrsistent==0.17.3 153 | - pyyaml==5.4.1 154 | - requests-oauthlib==1.3.0 155 | - rfc3987==1.3.8 156 | - sacred==0.8.2 157 | - simplejson==3.17.2 158 | - smmap==4.0.0 159 | - strict-rfc3339==0.7 160 | - swagger-spec-validator==2.7.3 161 | - timm==0.4.5 162 | - torchvision==0.9.1+cu111 163 | - webcolors==1.11.1 164 | - websocket-client==0.59.0 165 | - wrapt==1.12.1 166 | prefix: /users/maxbain/miniconda3/envs/frozen -------------------------------------------------------------------------------- /figures/oa_main_ppl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/figures/oa_main_ppl.jpg -------------------------------------------------------------------------------- /figures/oa_visualize_1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/figures/oa_visualize_1.jpg -------------------------------------------------------------------------------- /figures/oa_visualize_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/figures/oa_visualize_2.jpg -------------------------------------------------------------------------------- /figures/objects.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/figures/objects.jpg -------------------------------------------------------------------------------- /figures/objects_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/figures/objects_2.png -------------------------------------------------------------------------------- /object_extraction.md: -------------------------------------------------------------------------------- 1 | ## Install 2 | 3 | 4 | Please follow [BUTD](https://github.com/MILVLG/bottom-up-attention.pytorch) to install detectron2. 5 | 6 | Then download the pretrained model from [Google Drive](https://drive.google.com/file/d/1zFqaeNMDa6HL4tBWJd5BKu_AhCkqtacs/view?usp=sharing) and place it into `pretrained/`. 7 | 8 | ```bash 9 | cd bottom-up-attention.pytorch 10 | mkdir pretrained 11 | mv [path_to_downloaded_pth] pretrained/ 12 | ``` 13 | 14 | ### Replace extract_features.py 15 | The original BUTD repo provides the script [extract_features.py](https://github.com/MILVLG/bottom-up-attention.pytorch/blob/master/extract_features.py) to extract objects with the distributed framework Ray. 16 | However, we found this tool to be unstable and slow. 17 | So we implemented a multiprocess version that is roughly 3x faster. 18 | 19 | Simply replace extract_features.py with [extract_cc3m](ObjectExtractor/multiprocess_full_cc3m_complementary_modify_tsv_gen_from_video.py) and [extract_webvid](ObjectExtractor/multiprocess_full_webvid_multiframe_complementary_modify_tsv_gen_from_video.py).
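The replacement scripts shard the video list across worker processes pinned to individual GPUs (see the `--gpus` and `--workers_per_gpu` flags in the commands below). The snippet below is only a minimal sketch of that dispatch idea with illustrative names, not the actual implementation:

```python
# Minimal sketch of per-GPU multiprocess dispatch; names are illustrative only.
import os
import multiprocessing as mp

def extract_worker(gpu_id, video_paths):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)  # pin this worker to one GPU
    for path in video_paths:
        # here the real script would run the BUTD detector on sampled frames
        # of `path` and save one .npz file per frame
        print(f"[gpu {gpu_id}] {path}")

def dispatch(all_videos, gpus=(0, 1, 2, 3, 4, 5, 6, 7), workers_per_gpu=2):
    n = len(gpus) * workers_per_gpu
    shards = [all_videos[i::n] for i in range(n)]  # round-robin split of the work
    procs = [mp.Process(target=extract_worker, args=(gpus[i % len(gpus)], shard))
             for i, shard in enumerate(shards)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```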
20 | 21 | ### WebVid 2.5M 22 | ```bash 23 | python3 multiprocess_full_webvid_multiframe_complementary_modify_tsv_gen_from_video.py --mode caffe \ 24 | --num-cpus 32 --gpus '0,1,2,3,4,5,6,7' \ 25 | --workers_per_gpu 2 \ 26 | --sampling_frames 8 \ 27 | --split "train" \ 28 | --dataset_dir "WebVid" \ 29 | --extract-mode roi_feats \ 30 | --min-max-boxes '10,100' \ 31 | --config-file configs/bua-caffe/extract-bua-caffe-r101.yaml 32 | ``` 33 | 34 | 35 | ### CC3M 36 | ```bash 37 | python3 multiprocess_full_cc3m_complementary_modify_tsv_gen_from_video.py \ 38 | --mode caffe --num-cpus 0 --gpus '0,1,2,3,4,5,6,7' \ 39 | --extract-mode roi_feats --min-max-boxes '10,100' \ 40 | --config-file configs/bua-caffe/extract-bua-caffe-r101.yaml 41 | ``` 42 | 43 | ### Visualization 44 | We visualize some of the extracted bounding boxes below: 45 | ![](figures/objects.jpg) -------------------------------------------------------------------------------- /train.md: -------------------------------------------------------------------------------- 1 | ## Install 2 | 3 | ``` 4 | conda env create 5 | pip install decord 6 | pip install ftfy 7 | cd OATrans 8 | mkdir data; 9 | mkdir exps; 10 | ``` 11 | 12 | 13 | ## Pre-training 14 | 15 | ### Normal OA-Transformer for Retrieval 16 | ```bash 17 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node 8 --master_port 29132 \ 18 | train_dist_multi.py \ 19 | --config configs/pt/cc3m_webvid/local-region-loss.json # --launcher pytorch 20 | ``` 21 | 22 | ### Region-sensitive OA-Transformer for Grounding 23 | 24 | ```bash 25 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node 8 --master_port 29132 \ 26 | train_dist_region_mem.py \ 27 | --config configs/pt/cc3m_webvid/local-region-loss.json # --launcher pytorch 28 | ``` 29 | 30 | 31 | ## Downstream 32 | 33 | ### Zero-shot 34 | ```bash 35 | CUDA_VISIBLE_DEVICES=2,3 python -m torch.distributed.launch --nproc_per_node 2 \ 36 | --master_port 29142 test_region_mem.py --config configs/ft/msrvtt/zsl/normal.json 37 | ``` 38 | 39 | ### Fine-tuning 40 | ```bash 41 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py --config configs/ft/msrvtt/fine_tune/normal_1_cl.json 42 | ``` -------------------------------------------------------------------------------- /visualization.md: -------------------------------------------------------------------------------- 1 | # Cross-modality Visualization Tools 2 | 3 | ## Heatmap Visualization 4 | 5 | ```bash 6 | cd Visualization/Cross_Modality_Transformer_Visualization 7 | mkdir pretrained 8 | cd pretrained && mkdir distilbert-base-uncased 9 | ``` 10 | 11 | Then download all files from [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased/tree/main) and place them in the `distilbert-base-uncased` directory. The heatmap scripts also expect the pre-trained checkpoints referenced in the code (e.g. `pretrained/cc-webvid2m-4f_stformer_b_16_224.pth.tar`, plus `pretrained/ViT-B-16.pt` for the CLIP variant) to sit in the same `pretrained/` directory. 12 | 13 | 14 | 15 | ### Image 16 | 17 | ```bash 18 | python main_img.py 19 | ``` 20 | We provide both _feature map visualization_ and _cross-modality attention visualization_. 21 | 22 | 23 | ### Video 24 | 25 | ```bash 26 | python main_video.py 27 | ``` 28 | 29 | ## Binary Map Visualization 30 | 31 | If we ask the model to learn fine-grained alignment, we can generate a binary map as below: 32 | 33 | Refer to [test_region_mem.py](OATrans/test_region_mem.py) for details. 34 | 35 | ![](figures/objects_2.png) --------------------------------------------------------------------------------
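For orientation, here is a minimal sketch of how a binary map can be derived from patch-text similarity. This is an illustration only: the tensor names, min-max rescaling, and 0.5 threshold are assumptions, and the actual logic lives in [test_region_mem.py](OATrans/test_region_mem.py).

```python
# Illustrative only: threshold patch-text cosine similarity into a 14x14 binary map.
import torch
import torch.nn.functional as F

def binary_alignment_map(patch_embeds, text_embed, grid=14, thresh=0.5):
    # patch_embeds: (num_patches, dim) from the video encoder; text_embed: (dim,) for the query text
    sim = F.normalize(patch_embeds, dim=-1) @ F.normalize(text_embed, dim=-1)
    sim = (sim - sim.min()) / (sim.max() - sim.min() + 1e-6)  # rescale to [0, 1]
    return (sim.view(grid, grid) > thresh).float()            # 1 = patch judged to match the text
```

With ViT-B/16 on 224x224 frames there are 14x14 patches, so the resulting map can be upsampled and overlaid on the frame, as in the figure above.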