├── .DS_Store
├── .idea
│   ├── .gitignore
│   ├── alex_frozen_dist.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── OATrans
│   ├── .DS_Store
│   ├── __init__.py
│   ├── args.py
│   ├── base
│   │   ├── __init__.py
│   │   ├── base_augmentation.py
│   │   ├── base_data_loader.py
│   │   ├── base_dataset.py
│   │   ├── base_dataset_global_local.py
│   │   ├── base_dataset_region_mem.py
│   │   ├── base_model.py
│   │   └── base_trainer.py
│   ├── configs
│   │   ├── ft
│   │   │   └── msrvtt
│   │   │       ├── fine_tune
│   │   │       │   └── normal_1_cl.json
│   │   │       └── zsl
│   │   │           └── normal.json
│   │   └── pt
│   │       └── cc3m_webvid
│   │           ├── local-region-loss.json
│   │           └── norm.json
│   ├── data_loader
│   │   ├── ConceptualCaptions_dataset.py
│   │   ├── DiDeMo_dataset.py
│   │   ├── LSMDC_choice_dataset.py
│   │   ├── LSMDC_dataset.py
│   │   ├── MSRVTT_dataset.py
│   │   ├── MSVD_dataset.py
│   │   ├── WebVid_dataset.py
│   │   ├── __init__.py
│   │   ├── data_loader.py
│   │   ├── data_loader_v2.py
│   │   └── transforms.py
│   ├── logger
│   │   ├── __init__.py
│   │   ├── logger.py
│   │   ├── logger_config.json
│   │   └── visualization.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   ├── metric.py
│   │   ├── model.py
│   │   ├── model_dist.py
│   │   ├── oa_loss.py
│   │   ├── oa_model.py
│   │   ├── oa_model_global_local.py
│   │   ├── oa_model_region_mem.py
│   │   ├── oa_video_transformer_global_local.py
│   │   ├── oa_video_transformer_region.py
│   │   ├── prompt_learner.py
│   │   └── video_transformer.py
│   ├── options.py
│   ├── parse_config.py
│   ├── parse_config_dist_multi.py
│   ├── test.py
│   ├── test_region_mem.py
│   ├── train.py
│   ├── train_dist_multi.py
│   ├── train_dist_multi_global_local.py
│   ├── train_dist_region_mem.py
│   ├── trainer
│   │   ├── __init__.py
│   │   ├── trainer.py
│   │   ├── trainer_dist.py
│   │   ├── trainer_global_local.py
│   │   └── trainer_region_mem.py
│   └── utils
│       ├── .DS_Store
│       ├── __init__.py
│       ├── binary_classification_accuracy.py
│       ├── custom_transforms.py
│       ├── html.py
│       ├── objects_vocab.txt
│       ├── objects_vocab_fine_grained.txt
│       ├── objects_vocab_token_len
│       ├── objects_vocab_token_len.txt
│       ├── param_forzen.py
│       ├── unit_test
│       │   ├── __init__.py
│       │   ├── distill_bert.py
│       │   ├── load_msvd_video.py
│       │   └── region_roi_example.py
│       ├── util.py
│       ├── video.py
│       ├── visualization
│       │   ├── .DS_Store
│       │   ├── 3f_vto_visualize.py
│       │   ├── __init__.py
│       │   ├── learned_embedding_visualization.py
│       │   ├── msrvtt_3f_vto_visualize.py
│       │   ├── msrvtt_vto_visualization.py
│       │   ├── predict_visualization
│       │   │   ├── 0_predict.png
│       │   │   ├── 10_predict.png
│       │   │   ├── 11_predict.png
│       │   │   ├── 12_predict.png
│       │   │   ├── 13_predict.png
│       │   │   ├── 14_predict.png
│       │   │   ├── 1_predict.png
│       │   │   ├── 2_predict.png
│       │   │   ├── 3_predict.png
│       │   │   ├── 4_predict.png
│       │   │   ├── 5_predict.png
│       │   │   ├── 6_predict.png
│       │   │   ├── 7_predict.png
│       │   │   ├── 8_predict.png
│       │   │   └── 9_predict.png
│       │   ├── print_tags.py
│       │   ├── transfer_predict_visualization
│       │   │   ├── 0_predict.png
│       │   │   ├── 10_predict.png
│       │   │   ├── 11_predict.png
│       │   │   ├── 12_predict.png
│       │   │   ├── 13_predict.png
│       │   │   ├── 14_predict.png
│       │   │   ├── 15_predict.png
│       │   │   ├── 1_predict.png
│       │   │   ├── 2_predict.png
│       │   │   ├── 3_predict.png
│       │   │   ├── 4_predict.png
│       │   │   ├── 5_predict.png
│       │   │   ├── 6_predict.png
│       │   │   ├── 7_predict.png
│       │   │   ├── 8_predict.png
│       │   │   └── 9_predict.png
│       │   └── webvid_vto_visualization.py
│       └── visualizer.py
├── ObjectExtractor
│   ├── multiprocess_full_cc3m_complementary_modify_tsv_gen_from_video.py
│   └── multiprocess_full_webvid_multiframe_complementary_modify_tsv_gen_from_video.py
├── README.md
├── Visualization
│   ├── .DS_Store
│   └── Cross_Modality_Transformer_Visualization
│       ├── .DS_Store
│       ├── data
│       │   └── webvid_validation_success_full.tsv
│       ├── data_preprocess.py
│       ├── main_img.py
│       ├── main_video.py
│       ├── main_video_patches_visualization.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── text_model.py
│       │   ├── text_models
│       │   │   └── distill_bert.py
│       │   ├── vision_model.py
│       │   └── vision_models
│       │       ├── clip
│       │       │   ├── __init__.py
│       │       │   ├── bpe_simple_vocab_16e6.txt.gz
│       │       │   ├── clip.py
│       │       │   ├── model.py
│       │       │   └── simple_tokenizer.py
│       │       └── frozen.py
│       ├── parse_config.py
│       ├── patch_mask.py
│       ├── utils
│       │   ├── nltk_test.py
│       │   └── read_bboxs.py
│       └── visualize.py
├── environment.yml
├── figures
│   ├── oa_main_ppl.jpg
│   ├── oa_visualize_1.jpg
│   ├── oa_visualize_2.jpg
│   ├── objects.jpg
│   └── objects_2.png
├── object_extraction.md
├── train.md
└── visualization.md
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/.DS_Store
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/alex_frozen_dist.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/OATrans/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/.DS_Store
--------------------------------------------------------------------------------
/OATrans/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/__init__.py
--------------------------------------------------------------------------------
/OATrans/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def get_args(description='MILNCE'):
4 | parser = argparse.ArgumentParser(description=description)
5 | parser.add_argument(
6 | '--train_csv',
7 | type=str,
8 | default='csv/hmdb51.csv',
9 | help='train csv')
10 | parser.add_argument(
11 | '--video_path',
12 | type=str,
13 | default='',
14 | help='video_path')
15 | parser.add_argument(
16 | '--caption_root',
17 | type=str,
18 | default='',
19 |         help='caption root')
20 | parser.add_argument(
21 | '--checkpoint_root',
22 | type=str,
23 | default='checkpoint',
24 | help='checkpoint dir root')
25 | parser.add_argument(
26 | '--log_root',
27 | type=str,
28 | default='log',
29 | help='log dir root')
30 | parser.add_argument(
31 | '--eval_video_root',
32 | type=str,
33 | default='',
34 |         help='root folder of the videos used for evaluation')
35 | parser.add_argument(
36 | '--checkpoint_dir',
37 | type=str,
38 | default='',
39 | help='checkpoint model folder')
40 | parser.add_argument(
41 | '--optimizer', type=str, default='adam', help='opt algorithm')
42 | parser.add_argument('--weight_init', type=str, default='uniform',
43 | help='CNN weights inits')
44 | parser.add_argument('--num_thread_reader', type=int, default=20,
45 | help='')
46 | parser.add_argument('--num_class', type=int, default=512,
47 |                         help='number of classes (output embedding dimension)')
48 | parser.add_argument('--num_candidates', type=int, default=1,
49 | help='num candidates for MILNCE loss')
50 | parser.add_argument('--batch_size', type=int, default=256,
51 | help='batch size')
52 | parser.add_argument('--num_windows_test', type=int, default=4,
53 | help='number of testing windows')
54 | parser.add_argument('--batch_size_val', type=int, default=32,
55 | help='batch size eval')
56 | parser.add_argument('--momemtum', type=float, default=0.9,
57 |                         help='SGD momentum')
58 | parser.add_argument('--n_display', type=int, default=10,
59 |                         help='information display frequency')
60 |     parser.add_argument('--num_frames', type=int, default=16,
61 |                         help='number of frames per clip')
62 |     parser.add_argument('--video_size', type=int, default=224,
63 |                         help='input frame resolution')
64 |     parser.add_argument('--crop_only', type=int, default=1,
65 |                         help='crop without resizing')
66 |     parser.add_argument('--centercrop', type=int, default=0,
67 |                         help='use center crop instead of random crop')
68 |     parser.add_argument('--random_flip', type=int, default=1,
69 |                         help='apply random horizontal flip')
70 | parser.add_argument('--verbose', type=int, default=1,
71 | help='')
72 | parser.add_argument('--warmup_steps', type=int, default=5000,
73 | help='')
74 | parser.add_argument('--min_time', type=float, default=5.0,
75 | help='')
76 | parser.add_argument(
77 | '--pretrain_cnn_path',
78 | type=str,
79 | default='',
80 | help='')
81 | parser.add_argument(
82 | '--word2vec_path', type=str, default='data/word2vec.pth', help='')
83 | parser.add_argument('--fps', type=int, default=5, help='')
84 | parser.add_argument('--cudnn_benchmark', type=int, default=0,
85 | help='')
86 | parser.add_argument('--epochs', default=150, type=int, metavar='N',
87 | help='number of total epochs to run')
88 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
89 | help='manual epoch number (useful on restarts)')
90 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
91 | metavar='LR', help='initial learning rate', dest='lr')
92 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
93 | help='momentum')
94 | parser.add_argument('--resume', dest='resume', action='store_true',
95 | help='resume training from last checkpoint')
96 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
97 | help='evaluate model on validation set')
98 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
99 | help='use pre-trained model')
100 | parser.add_argument('--pin_memory', dest='pin_memory', action='store_true',
101 | help='use pin_memory')
102 | parser.add_argument('--world-size', default=-1, type=int,
103 | help='number of nodes for distributed training')
104 | parser.add_argument('--rank', default=-1, type=int,
105 | help='node rank for distributed training')
106 | parser.add_argument('--dist-file', default='dist-file', type=str,
107 |                         help='shared file used to set up distributed training')
108 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
109 | help='url used to set up distributed training')
110 | parser.add_argument('--dist-backend', default='nccl', type=str,
111 | help='distributed backend')
112 | parser.add_argument('--seed', default=1, type=int,
113 | help='seed for initializing training. ')
114 | parser.add_argument('--gpu', default=None, type=int,
115 | help='GPU id to use.')
116 | parser.add_argument('--multiprocessing-distributed', action='store_true',
117 | help='Use multi-processing distributed training to launch '
118 | 'N processes per node, which has N GPUs. This is the '
119 | 'fastest way to use PyTorch for either single node or '
120 | 'multi node data parallel training')
121 | args = parser.parse_args()
122 | return args
123 |
--------------------------------------------------------------------------------
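A minimal usage sketch (not part of the repository), assuming the repo root is on PYTHONPATH; the command-line values below are illustrative only:

import sys
from OATrans.args import get_args

# Simulate a command line; in a real run the flags come from the shell.
sys.argv = ["train.py", "--batch_size", "128", "--num_frames", "8", "--lr", "1e-4"]
args = get_args()
print(args.batch_size, args.num_frames, args.lr)   # 128 8 0.0001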
/OATrans/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_data_loader import *
2 | # from .base_dataset_v3 import *
3 | from .base_dataset import *
4 | from .base_model import *
5 | from .base_trainer import *
--------------------------------------------------------------------------------
/OATrans/base/base_augmentation.py:
--------------------------------------------------------------------------------
1 | import random
2 | from PIL import ImageFilter
3 | # import nltk
4 | # nltk.data.path.append("pretrained/nltk_data")
5 | # from textaugment import EDA
6 |
7 |
8 | def textaug_eda(caption):
9 | aug_caption = caption
10 | t = EDA()
11 | if random.random() < 0.5:
12 | if random.random() < 0.3:
13 | aug_caption = t.synonym_replacement(aug_caption)
14 | aug_caption = t.random_deletion(aug_caption, p=random.random()*0.3)
15 | if random.random() < 0.3:
16 | aug_caption = t.random_swap(aug_caption)
17 | if random.random() < 0.3:
18 | aug_caption = t.random_insertion(aug_caption)
19 | return aug_caption
20 |
21 |
22 | def textaug_advanced(caption, aug_model):
23 | return aug_model.augment(caption)
24 |
25 |
26 |
27 | def mask_aug(sentence):
28 |     words = sentence.split(' ')
29 |     word_index = random.randint(0, len(words) - 1)
30 |     words[word_index] = "[MASK]"
31 |     new_caption = ' '.join(words)
32 |     # shuffle object localization
33 |     # random drop some objects
34 |     return new_caption
35 |
36 |
37 |
38 | class GaussianBlur(object):
39 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
40 |
41 | def __init__(self, sigma=[.1, 2.]):
42 | self.sigma = sigma
43 |
44 | def __call__(self, x):
45 | sigma = random.uniform(self.sigma[0], self.sigma[1])
46 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
47 | return x
--------------------------------------------------------------------------------
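A minimal sketch of GaussianBlur used as a drop-in augmentation on a PIL image (the image is synthetic and the sigma range illustrative; textaug_eda additionally needs the commented-out textaugment import):

import numpy as np
from PIL import Image
from OATrans.base.base_augmentation import GaussianBlur

img = Image.fromarray(np.uint8(np.random.rand(224, 224, 3) * 255))
blur = GaussianBlur(sigma=[0.1, 2.0])
blurred = blur(img)      # blur radius drawn uniformly from [0.1, 2.0]
print(blurred.size)      # (224, 224)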
/OATrans/base/base_data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.dataloader import default_collate
4 | from torch.utils.data.sampler import SubsetRandomSampler
5 | from torch.utils.data.distributed import DistributedSampler
6 |
7 | class BaseDataLoader(DataLoader):
8 | """
9 | Base class for all data loaders
10 | """
11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
12 | self.validation_split = validation_split
13 | self.shuffle = shuffle
14 |
15 | self.batch_idx = 0
16 | self.n_samples = len(dataset)
17 |
18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
19 |
20 | self.init_kwargs = {
21 | 'dataset': dataset,
22 | 'batch_size': batch_size,
23 | 'shuffle': self.shuffle,
24 | 'collate_fn': collate_fn,
25 | 'num_workers': num_workers
26 | }
27 | super().__init__(sampler=self.sampler, **self.init_kwargs)
28 |
29 | def _split_sampler(self, split):
30 | if split == 0.0:
31 | return None, None
32 |
33 | idx_full = np.arange(self.n_samples)
34 |
35 | np.random.seed(0)
36 | np.random.shuffle(idx_full)
37 |
38 | if isinstance(split, int):
39 | assert split > 0
40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
41 | len_valid = split
42 | else:
43 | len_valid = int(self.n_samples * split)
44 |
45 | valid_idx = idx_full[0:len_valid]
46 | train_idx = np.delete(idx_full, np.arange(0, len_valid))
47 |
48 | train_sampler = SubsetRandomSampler(train_idx)
49 | valid_sampler = SubsetRandomSampler(valid_idx)
50 | # turn off shuffle option which is mutually exclusive with sampler
51 | self.shuffle = False
52 | self.n_samples = len(train_idx)
53 |
54 | return train_sampler, valid_sampler
55 |
56 | def split_validation(self, diff_kwargs=None):
57 | init_kwargs = self.init_kwargs
58 | if diff_kwargs is not None:
59 | init_kwargs.update(diff_kwargs)
60 | if self.valid_sampler is None:
61 | return None
62 | else:
63 |             return DataLoader(sampler=self.valid_sampler, **init_kwargs)
64 |
65 | def num_samples(self):
66 | return len(self.sampler)
67 |
68 |
69 | class BaseDataLoaderExplicitSplit(DataLoader):
70 | """
71 | Base class for all data loaders
72 | """
73 | def __init__(self, args, dataset, batch_size, shuffle, num_workers, collate_fn=default_collate):
74 | self.shuffle = shuffle
75 | self.args = args
76 | self.batch_idx = 0
77 | self.n_samples = len(dataset)
78 |
79 | self.init_kwargs = {
80 | 'dataset': dataset,
81 | 'batch_size': batch_size,
82 | 'shuffle': self.shuffle,
83 | 'collate_fn': collate_fn,
84 | 'num_workers': num_workers,
85 | 'pin_memory': True
86 | }
87 | super().__init__(**self.init_kwargs)
88 |
89 | class DistBaseDataLoaderExplicitSplit(DataLoader):
90 | """
91 | Base class for all data loaders
92 | """
93 | def __init__(self, dataset, batch_size, shuffle, num_workers, collate_fn=default_collate):
94 | self.shuffle = shuffle
95 |
96 | self.batch_idx = 0
97 | self.n_samples = len(dataset)
98 | self.train_sampler = DistributedSampler(dataset)
99 | self.init_kwargs = {
100 | 'dataset': dataset,
101 | 'batch_size': batch_size,
102 | 'shuffle': False,
103 | 'collate_fn': collate_fn,
104 | 'num_workers': num_workers,
105 | 'pin_memory': True,
106 | 'sampler': self.train_sampler
107 | }
108 | super().__init__(**self.init_kwargs)
109 |
110 | class MultiDistBaseDataLoaderExplicitSplit(DataLoader):
111 | """
112 | Base class for all data loaders
113 | """
114 | def __init__(self, args, dataset, batch_size, shuffle, num_workers, collate_fn=default_collate):
115 | self.shuffle = shuffle
116 |
117 | self.batch_idx = 0
118 | self.n_samples = len(dataset)
119 | self.args = args
120 | self.train_sampler = DistributedSampler(dataset, num_replicas=self.args.world_size, rank=self.args.rank, drop_last=True)
121 | self.init_kwargs = {
122 | 'dataset': dataset,
123 | 'batch_size': batch_size,
124 | 'shuffle': False,
125 | 'collate_fn': collate_fn,
126 | 'num_workers': num_workers,
127 | 'pin_memory': True,
128 | 'sampler': self.train_sampler
129 | }
130 | super().__init__(**self.init_kwargs)
131 |
132 | class BaseMultiDataLoader:
133 | """
134 | Currently implemented as undersample the bigger dataloaders...
135 | """
136 | def __init__(self, dataloaders):
137 | self.dataloaders = dataloaders
138 | self.batch_size = self.dataloaders[0].batch_size
139 | def __getitem__(self, item):
140 | dl_idx = item % len(self.dataloaders)
141 | return next(iter(self.dataloaders[dl_idx]))
142 |
143 | def __len__(self):
144 | return min([len(x) for x in self.dataloaders]) * len(self.dataloaders)
145 |
146 | def num_samples(self):
147 | return sum([len(x.sampler) for x in self.dataloaders])
148 |
--------------------------------------------------------------------------------
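A minimal sketch of BaseDataLoader's validation split on a toy dataset (names and sizes are illustrative):

import torch
from torch.utils.data import TensorDataset
from OATrans.base.base_data_loader import BaseDataLoader

dataset = TensorDataset(torch.arange(100).float().unsqueeze(1), torch.zeros(100))
loader = BaseDataLoader(dataset, batch_size=16, shuffle=True,
                        validation_split=0.1, num_workers=0)
val_loader = loader.split_validation()                  # SubsetRandomSampler over the held-out 10%
print(loader.num_samples(), len(val_loader.sampler))    # 90 10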
/OATrans/base/base_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | from abc import abstractmethod
4 |
5 |
6 | class BaseModel(nn.Module):
7 | """
8 | Base class for all models
9 | """
10 | @abstractmethod
11 | def forward(self, *inputs):
12 | """
13 | Forward pass logic
14 |
15 | :return: Model output
16 | """
17 | raise NotImplementedError
18 |
19 | def __str__(self):
20 | """
21 | Model prints with number of trainable parameters
22 | """
23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
24 | params = sum([np.prod(p.size()) for p in model_parameters])
25 | return super().__str__() + '\nTrainable parameters: {}'.format(params)
26 |
--------------------------------------------------------------------------------
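A minimal sketch of what BaseModel adds over nn.Module: __str__ appends the trainable-parameter count (TinyProjector is a hypothetical subclass for illustration):

import torch.nn as nn
from OATrans.base.base_model import BaseModel

class TinyProjector(BaseModel):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(768, 256)   # 768*256 weights + 256 biases

    def forward(self, x):
        return self.proj(x)

print(TinyProjector())   # ends with "Trainable parameters: 196864"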
/OATrans/configs/ft/msrvtt/fine_tune/normal_1_cl.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "MSRVTTjsfusion_4f_stformer_pt-im21k",
3 | "n_gpu": 8,
4 | "linear_evaluation": false,
5 | "arch": {
6 | "type": "FrozenInTime",
7 | "stream": 2,
8 | "object": false,
9 | "args": {
10 | "video_params": {
11 | "model": "SpaceTimeTransformer",
12 | "arch_config": "base_patch16_224",
13 | "num_frames": 4,
14 | "pretrained": true,
15 | "time_init": "zeros",
16 | "two_outputs": false,
17 | "object_pseudo_label": false
18 | },
19 | "object_params": {
20 | "model": "",
21 | "input_objects": false
22 | },
23 | "text_params": {
24 | "model": "pretrained/distilbert-base-uncased",
25 | "pretrained": true,
26 | "input": "text",
27 | "two_outputs": false
28 | },
29 | "projection": "minimal",
30 | "load_checkpoint": "exps/2stream_wtags/models/full-WebVid2M-1f-pti2k/0106_180724/checkpoint-epoch3.pth"
31 | }
32 | },
33 | "data_loader":
34 | [
35 | {
36 | "type": "TextObjectVideoDataLoader",
37 | "args":{
38 | "dataset_name": "MSRVTT",
39 | "data_dir": "MSRVTT/",
40 | "object_dir": "MSRVTT/region_features_full/",
41 | "shuffle": true,
42 | "num_workers": 8,
43 | "batch_size": 64,
44 | "split": "train",
45 | "cut": "jsfusion",
46 | "subsample": 1,
47 | "text_params": {
48 | "object_tags": false,
49 | "drop_raw_caption": false,
50 | "text_aug": false,
51 | "object_aug": false
52 | },
53 | "object_params": {
54 | "input_objects": false,
55 | "pseudo_labels": false,
56 | "input_object_bboxs": false
57 | },
58 | "video_params": {
59 | "extraction_fps": 25,
60 | "extraction_res": 256,
61 | "input_res": 224,
62 | "num_frames": 4,
63 | "stride": 1
64 | }
65 | }
66 | }
67 | ],
68 | "optimizer": {
69 | "type": "AdamW",
70 | "args":{
71 | "lr": 3e-5
72 | }
73 | },
74 | "loss": {
75 | "type": "NormSoftmaxLoss",
76 | "args": {
77 | }
78 | },
79 | "metrics": [
80 | "t2v_metrics",
81 | "v2t_metrics"
82 | ],
83 | "trainer": {
84 | "epochs": 100,
85 | "max_samples_per_epoch": 9000,
86 | "save_dir": "exps",
87 | "save_period": 5,
88 | "verbosity": 2,
89 | "monitor": "min val_loss_0",
90 | "early_stop": 10,
91 | "neptune": false
92 | },
93 | "visualizer": {
94 | "type": "",
95 | "args": {
96 | }
97 | }
98 |
99 | }
--------------------------------------------------------------------------------
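The config above is normally consumed through the repo's parse_config machinery; a minimal sketch with plain json (path relative to the repo root) just to show where the key training knobs live:

import json

with open("OATrans/configs/ft/msrvtt/fine_tune/normal_1_cl.json") as f:
    cfg = json.load(f)

print(cfg["arch"]["type"])                                        # FrozenInTime
print(cfg["arch"]["args"]["video_params"]["num_frames"])          # 4
print(cfg["optimizer"]["type"], cfg["optimizer"]["args"]["lr"])   # AdamW 3e-05
print(cfg["data_loader"][0]["args"]["batch_size"])                # 64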
/OATrans/configs/ft/msrvtt/zsl/normal.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "MSRVTTjsfusion_4f_stformer_pt-im21k",
3 | "n_gpu": 2,
4 | "arch": {
5 | "type": "FrozenInTime",
6 | "stream": 2,
7 | "object": true,
8 | "args": {
9 | "video_params": {
10 | "model": "SpaceTimeTransformer",
11 | "arch_config": "base_patch16_224",
12 | "num_frames": 4,
13 | "pretrained": true,
14 | "time_init": "zeros",
15 | "two_outputs": false,
16 | "object_pseudo_label": false
17 | },
18 | "object_params": {
19 | "model": "",
20 | "input_objects": false
21 | },
22 | "text_params": {
23 | "model": "pretrained/distilbert-base-uncased",
24 | "pretrained": true,
25 | "input": "text",
26 | "two_outputs": true
27 | },
28 | "projection": "minimal",
29 | "load_checkpoint": "exps/2stream_wtags/models/full-WebVid2M-1f-pti2k/0106_180724/checkpoint-epoch3.pth"
30 | }
31 | },
32 | "data_loader": {
33 | "type": "MultiDistTextObjectVideoDataLoader",
34 | "args":{
35 | "dataset_name": "MSRVTT",
36 | "data_dir": "MSRVTT/",
37 | "object_dir": "MSRVTT/region_features_full/",
38 | "shuffle": true,
39 | "num_workers": 8,
40 | "batch_size": 16,
41 | "split": "train",
42 | "cut": "jsfusion",
43 | "subsample": 1,
44 | "text_params": {
45 | "object_tags": true,
46 | "drop_raw_caption": false,
47 | "text_aug": false,
48 | "object_aug": false
49 | },
50 | "object_params": {
51 | "input_objects": false,
52 | "pseudo_labels": false,
53 | "input_object_bboxs":false
54 | },
55 | "video_params": {
56 | "extraction_fps": 25,
57 | "extraction_res": 256,
58 | "input_res": 224,
59 | "num_frames": 4,
60 | "stride": 1
61 | }
62 | }
63 | },
64 | "optimizer": {
65 | "type": "AdamW",
66 | "args":{
67 | "lr": 3e-5
68 | }
69 | },
70 | "loss": {
71 | "type": "NormSoftmaxLoss",
72 | "args": {
73 | }
74 | },
75 | "metrics": [
76 | "t2v_metrics",
77 | "v2t_metrics"
78 | ],
79 | "trainer": {
80 | "epochs": 100,
81 | "max_samples_per_epoch": 9000,
82 | "save_dir": "exps",
83 | "save_period": 5,
84 | "verbosity": 2,
85 | "monitor": "min val_loss",
86 | "early_stop": 10,
87 | "neptune": true
88 | },
89 | "visualizer": {
90 | "type": "",
91 | "args": {
92 | }
93 | }
94 |
95 | }
--------------------------------------------------------------------------------
/OATrans/configs/pt/cc3m_webvid/local-region-loss.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "full-cc-WebVid2M-1f-pti2k",
3 | "n_gpu": 4,
4 | "arch": {
5 | "type": "FrozenInTime",
6 | "object": true,
7 | "stream": 2,
8 | "args": {
9 | "video_params": {
10 | "model": "SpaceTimeTransformer",
11 | "arch_config": "base_patch16_224",
12 | "num_frames": 4,
13 | "pretrained": true,
14 | "time_init": "zeros",
15 | "two_outputs": false,
16 | "object_pseudo_label": false
17 | },
18 | "object_params": {
19 | "model": "",
20 | "input_objects": false
21 | },
22 | "text_params": {
23 | "model": "pretrained/distilbert-base-uncased",
24 | "pretrained": true,
25 | "input": "text",
26 | "two_outputs": true
27 | },
28 | "projection": "minimal",
29 | "load_checkpoint" : ""
30 | }
31 | },
32 | "data_loader":
33 | [
34 | {
35 | "type": "MultiDistTextObjectVideoDataLoader",
36 | "args":{
37 | "dataset_name": "ConceptualCaptions3M",
38 | "data_dir": "CC3M/",
39 | "object_dir": "CC3M/1_frame_object",
40 | "reader": "cv2",
41 | "shuffle": true,
42 | "num_workers": 8,
43 | "batch_size": 16,
44 | "split": "train",
45 | "subsample": 1,
46 | "text_params": {
47 | },
48 | "object_params": {
49 | },
50 | "video_params": {
51 | "input_res": 224,
52 | "num_frames": 1,
53 | "loading": "lax"
54 | }
55 | }
56 | },
57 | {
58 | "type": "MultiDistTextObjectVideoDataLoader",
59 | "args":{
60 | "dataset_name": "WebVid",
61 | "data_dir": "WebVid",
62 | "object_dir": "WebVid/8_frame_object",
63 | "reader": "cv2",
64 | "shuffle": true,
65 | "num_workers": 8,
66 | "batch_size": 16,
67 | "split": "train",
68 | "cut": "2M",
69 | "subsample": 1,
70 | "text_params": {
71 | },
72 | "object_params": {
73 | },
74 | "video_params": {
75 | "input_res": 224,
76 | "num_frames": 4,
77 | "loading": "lax"
78 | }
79 | }
80 | }
81 | ],
82 | "optimizer": {
83 | "type": "AdamW",
84 | "args":{
85 | "lr": 2e-4
86 | }
87 | },
88 | "loss": {
89 | "type": "NormSoftmaxLoss",
90 | "args": {
91 | }
92 | },
93 | "metrics": [
94 | "t2v_metrics",
95 | "v2t_metrics"
96 | ],
97 | "trainer": {
98 | "epochs": 100,
99 | "max_samples_per_epoch": 1000000,
100 | "save_dir": "exps/2stream_wtags",
101 | "save_period": 5,
102 | "verbosity": 2,
103 | "monitor": "min val_loss_0",
104 | "early_stop": 10,
105 | "init_val": true,
106 | "neptune": false
107 | },
108 | "visualizer": {
109 | "type": ""
110 | }
111 |
112 | }
113 |
--------------------------------------------------------------------------------
/OATrans/configs/pt/cc3m_webvid/norm.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "full-cc-WebVid2M-1f-pti2k-normal",
3 | "n_gpu": 8,
4 | "arch": {
5 | "type": "FrozenInTime",
6 | "object": false,
7 | "stream": 2,
8 | "args": {
9 | "video_params": {
10 | "model": "SpaceTimeTransformer",
11 | "arch_config": "base_patch16_224",
12 | "num_frames": 4,
13 | "pretrained": true,
14 | "time_init": "zeros",
15 | "two_outputs": false,
16 | "object_pseudo_label": false
17 | },
18 | "object_params": {
19 | "model": "",
20 | "input_objects": false
21 | },
22 | "text_params": {
23 | "model": "pretrained/distilbert-base-uncased",
24 | "pretrained": true,
25 | "input": "text",
26 | "two_outputs": false
27 | },
28 | "projection": "minimal",
29 | "load_checkpoint" : ""
30 | }
31 | },
32 | "data_loader":
33 | [
34 | {
35 | "type": "MultiDistTextObjectVideoDataLoader",
36 | "args":{
37 | "dataset_name": "ConceptualCaptions3M",
38 | "data_dir": "CC3M/",
39 | "object_dir": "CC3M/1_frame_object",
40 | "reader": "cv2",
41 | "shuffle": true,
42 | "num_workers": 8,
43 | "batch_size": 16,
44 | "split": "train",
45 | "subsample": 1,
46 | "text_params": {
47 | },
48 | "object_params": {
49 | },
50 | "video_params": {
51 | "input_res": 224,
52 | "num_frames": 1,
53 | "loading": "lax"
54 | }
55 | }
56 | },
57 | {
58 | "type": "MultiDistTextObjectVideoDataLoader",
59 | "args":{
60 | "dataset_name": "WebVid",
61 | "data_dir": "WebVid",
62 | "object_dir": "WebVid/8_frame_object",
63 | "reader": "cv2",
64 | "shuffle": true,
65 | "num_workers": 8,
66 | "batch_size": 16,
67 | "split": "train",
68 | "cut": "2M",
69 | "subsample": 1,
70 | "text_params": {
71 | },
72 | "object_params": {
73 | },
74 | "video_params": {
75 | "input_res": 224,
76 | "num_frames": 4,
77 | "loading": "lax"
78 | }
79 | }
80 | }
81 | ],
82 | "optimizer": {
83 | "type": "AdamW",
84 | "args":{
85 | "lr": 2e-4
86 | }
87 | },
88 | "loss": {
89 | "type": "NormSoftmaxLoss",
90 | "args": {
91 | }
92 | },
93 | "metrics": [
94 | "t2v_metrics",
95 | "v2t_metrics"
96 | ],
97 | "trainer": {
98 | "epochs": 100,
99 | "max_samples_per_epoch": 1000000,
100 | "save_dir": "exps/2stream_wtags",
101 | "save_period": 5,
102 | "verbosity": 2,
103 | "monitor": "min val_loss_0",
104 | "early_stop": 10,
105 | "init_val": true,
106 | "neptune": false
107 | },
108 | "visualizer": {
109 | "type": ""
110 | }
111 |
112 | }
113 |
--------------------------------------------------------------------------------
/OATrans/data_loader/ConceptualCaptions_dataset.py:
--------------------------------------------------------------------------------
1 | # from base.base_dataset import TextObjectImageDataset
2 | from OATrans.base.base_dataset_region_mem import TextObjectImageDataset
3 | import pandas as pd
4 | import os
5 |
6 |
7 | class ConceptualCaptions3M(TextObjectImageDataset):
8 | """
9 | Conceptual Captions dataset. Split files are specific to my download regime.
10 | """
11 |
12 | def _load_metadata(self):
13 | # download specific
14 | metadata_dir = './meta_data'
15 | split_files = {
16 | 'train': 'cc3m_training_success_full.tsv',
17 | 'val': 'cc3m_validation_success_full.tsv', # there is no test
18 | }
19 | target_split_fp = split_files[self.split]
20 | metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t')
21 |
22 | if self.subsample < 1:
23 | metadata = metadata.sample(frac=self.subsample)
24 | # elif self.split == 'val':
25 | # metadata = metadata.sample(1000, random_state=0) # 15k val is unnecessarily large, downsample.
26 |
27 | self.metadata = metadata
28 |
29 | def _get_video_path(self, sample):
30 | # conceptual captions uses this hashing to create the filename
31 | rel_dir = 'training'
32 | if self.split != 'train':
33 | rel_dir = 'validation'
34 | rel_fp = os.path.join(rel_dir, sample[1])
35 | #rel_fp = os.path.join(rel_dir, str(zlib.crc32(sample['thumbnailUrl'].encode('utf-8')) & 0xffffffff))
36 | return os.path.join(self.data_dir, rel_fp), rel_fp
37 |
38 | def _get_caption(self, sample):
39 | return sample[0]
40 | #return sample['caption']
41 |
42 | def _get_object_path(self, sample):
43 | """
44 | get the object npy path
45 | Args:
46 | sample (dict):
47 | Returns:
48 | abs path
49 | """
50 | # pre = sample[1].split('_')[0]
51 | # pre = pre.zfill(7)
52 | # rel_object_fp = os.path.join(pre[:4], sample[1])
53 | # rel_object_fp = os.path.join(pre[:4], sample[1] + '_1.npz')
54 | rel_object_fp = sample[1]
55 | full_object_fp = os.path.join(self.object_dir, self.split, rel_object_fp)
56 | return os.path.join(self.split, rel_object_fp), full_object_fp
--------------------------------------------------------------------------------
/OATrans/data_loader/DiDeMo_dataset.py:
--------------------------------------------------------------------------------
1 | from OATrans.base.base_dataset import TextObjectVideoDataset
2 | import pandas as pd
3 | import os
4 |
5 |
6 | class DiDeMo(TextObjectVideoDataset):
7 | def _load_metadata(self):
8 | metadata_dir = './meta_data'
9 | split_files = {
10 | 'train': 'DiDeMo_train.tsv',
11 | 'val': 'DiDeMo_val.tsv', # there is no test
12 | 'test': 'DiDeMo_test.tsv'
13 | }
14 | target_split_fp = split_files[self.split]
15 | metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t')
16 | if self.subsample < 1:
17 | metadata = metadata.sample(frac=self.subsample)
18 | self.metadata = metadata
19 | print("load split {}, {} samples".format(self.split, len(metadata)))
20 |
21 | def _get_video_path(self, sample):
22 | rel_video_fp = sample[1]
23 | #rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
24 | full_video_fp = os.path.join(self.data_dir, rel_video_fp)
25 | # print(full_video_fp)
26 | return full_video_fp, rel_video_fp
27 |
28 | def _get_caption(self, sample):
29 | # print(sample[0].split(',')[0])
30 | # return sample[0].split(',')[0]
31 | return sample[0] # .split(',')[0]
32 |
33 | def _get_object_path(self, sample, index=0):
34 | """
35 | get the object npy path
36 | Args:
37 | sample (dict):
38 | Returns:
39 | abs path
40 | """
41 | rel_object_fp = os.path.join(sample[1], '1.npz')
42 | full_object_fp = os.path.join(self.object_dir, self.split, rel_object_fp)
43 | return os.path.join(self.split, rel_object_fp), full_object_fp
--------------------------------------------------------------------------------
/OATrans/data_loader/LSMDC_choice_dataset.py:
--------------------------------------------------------------------------------
1 | from OATrans.base.base_dataset import TextVideoDataset
2 | import pandas as pd
3 | import os
4 | import numpy as np
5 |
6 |
7 | class LSMDC(TextVideoDataset):
8 | def _load_metadata(self):
9 | split_paths = {key: os.path.join(self.metadata_dir, 'structured-symlinks', f'{key}_list.txt') for key in
10 | ['train', 'val', 'test']}
11 | df_dict = {key: pd.read_csv(val, names=['videoid']) for key, val in split_paths.items()}
12 | #### subsample_val
13 |
14 | self.split_sizes = {key: len(val) for key, val in df_dict.items()}
15 | target_vids = df_dict[self.split]
16 | # target_vids = target_vids['videoid'].str.split('.').str[0]
17 | if self.subsample < 1:
18 | target_vids = target_vids.sample(frac=self.subsample)
19 | captions = np.load(os.path.join(self.metadata_dir, 'structured-symlinks', 'raw-captions.pkl'),
20 | allow_pickle=True)
21 | captions = pd.DataFrame.from_dict(captions, orient='index')
22 | captions['captions'] = captions.values.tolist()
23 | target_vids.set_index('videoid', inplace=True)
24 | target_vids['captions'] = captions['captions']
25 |         # import pdb; pdb.set_trace()
26 | # captions = captions[captions.index.isin(target_vids.str['videoid'].split('.').str[0])]
27 | self.metadata = target_vids
28 | frame_tar_list = pd.read_csv(os.path.join(self.metadata_dir, 'frame_tar_list.txt'), names=['fp'])
29 |
30 | frame_tar_list['fn'] = frame_tar_list['fp'].str.split('/').str[-2:].str.join('/')
31 | frame_tar_list['fn'] = frame_tar_list['fn'].str.replace('.tar', '')
32 | frame_tar_list['vid_stem'] = frame_tar_list['fn'].str.split('/').str[-1]
33 |
34 | frame_tar_list = frame_tar_list[frame_tar_list['vid_stem'].isin(self.metadata.index)]
35 |
36 | frame_tar_list.set_index('vid_stem', inplace=True)
37 | self.metadata['fn'] = frame_tar_list['fn']
38 | self.metadata['captions'] = self.metadata['captions'].apply(lambda x: [ii for ii in x if ii is not None])
39 | self.metadata['num_captions'] = self.metadata['captions'].str.len()
40 | self.metadata['captions'] = self.metadata['captions'].apply(lambda x: [' '.join(ii) for ii in x])
41 |
42 | if 'videoid' not in self.metadata.columns:
43 | self.metadata['videoid'] = self.metadata.index
44 |
45 | def _get_video_path(self, sample):
46 | return os.path.join(self.data_dir, 'videos', sample['fn'] + '.avi'), sample.name + '.avi'
47 |
48 | def _get_caption(self, sample):
49 | if len(sample['captions']) != 1:
50 | raise NotImplementedError
51 | return sample['captions'][0]
--------------------------------------------------------------------------------
/OATrans/data_loader/LSMDC_dataset.py:
--------------------------------------------------------------------------------
1 | from OATrans.base.base_dataset import TextVideoDataset
2 | import pandas as pd
3 | import os
4 | import numpy as np
5 |
6 |
7 | class LSMDC(TextVideoDataset):
8 | def _load_metadata(self):
9 | split_paths = {key: os.path.join(self.metadata_dir, 'structured-symlinks', f'{key}_list.txt') for key in
10 | ['train', 'val', 'test']}
11 | df_dict = {key: pd.read_csv(val, names=['videoid']) for key, val in split_paths.items()}
12 | #### subsample_val
13 |
14 | self.split_sizes = {key: len(val) for key, val in df_dict.items()}
15 | target_vids = df_dict[self.split]
16 | # target_vids = target_vids['videoid'].str.split('.').str[0]
17 | if self.subsample < 1:
18 | target_vids = target_vids.sample(frac=self.subsample)
19 | captions = np.load(os.path.join(self.metadata_dir, 'structured-symlinks', 'raw-captions.pkl'),
20 | allow_pickle=True)
21 | captions = pd.DataFrame.from_dict(captions, orient='index')
22 | captions['captions'] = captions.values.tolist()
23 | target_vids.set_index('videoid', inplace=True)
24 | target_vids['captions'] = captions['captions']
25 |         # import pdb; pdb.set_trace()
26 | # captions = captions[captions.index.isin(target_vids.str['videoid'].split('.').str[0])]
27 | self.metadata = target_vids
28 | frame_tar_list = pd.read_csv(os.path.join(self.metadata_dir, 'frame_tar_list.txt'), names=['fp'])
29 |
30 | frame_tar_list['fn'] = frame_tar_list['fp'].str.split('/').str[-2:].str.join('/')
31 | frame_tar_list['fn'] = frame_tar_list['fn'].str.replace('.tar', '')
32 | frame_tar_list['vid_stem'] = frame_tar_list['fn'].str.split('/').str[-1]
33 |
34 | frame_tar_list = frame_tar_list[frame_tar_list['vid_stem'].isin(self.metadata.index)]
35 |
36 | frame_tar_list.set_index('vid_stem', inplace=True)
37 | self.metadata['fn'] = frame_tar_list['fn']
38 | self.metadata['captions'] = self.metadata['captions'].apply(lambda x: [ii for ii in x if ii is not None])
39 | self.metadata['num_captions'] = self.metadata['captions'].str.len()
40 | self.metadata['captions'] = self.metadata['captions'].apply(lambda x: [' '.join(ii) for ii in x])
41 |
42 | if 'videoid' not in self.metadata.columns:
43 | self.metadata['videoid'] = self.metadata.index
44 |
45 | def _get_video_path(self, sample):
46 | return os.path.join(self.data_dir, 'videos', sample['fn'] + '.avi'), sample.name + '.avi'
47 |
48 | def _get_caption(self, sample):
49 | if len(sample['captions']) != 1:
50 | raise NotImplementedError
51 | return sample['captions'][0]
--------------------------------------------------------------------------------
/OATrans/data_loader/MSRVTT_dataset.py:
--------------------------------------------------------------------------------
1 | from OATrans.base.base_dataset import TextObjectVideoDataset
2 | # from base.base_dataset_global_local import TextObjectVideoDataset
3 | # from base.base_dataset_region_mem import TextObjectVideoDataset
4 | import pandas as pd
5 | import os
6 | import json
7 | import numpy as np
8 | import random
9 |
10 |
11 | class MSRVTT(TextObjectVideoDataset):
12 | def _load_metadata(self):
13 | json_fp = os.path.join(self.metadata_dir, 'annotation', 'MSR_VTT.json')
14 | with open(json_fp, 'r') as fid:
15 | data = json.load(fid)
16 | df = pd.DataFrame(data['annotations'])
17 |
18 | split_dir = os.path.join(self.metadata_dir, 'high-quality', 'structured-symlinks')
19 | js_test_cap_idx_path = None
20 | challenge_splits = {"val", "public_server_val", "public_server_test"}
21 | if self.cut == "miech":
22 | train_list_path = "train_list_miech.txt"
23 | test_list_path = "test_list_miech.txt"
24 | elif self.cut == "jsfusion":
25 | train_list_path = "train_list_jsfusion.txt"
26 | test_list_path = "val_list_jsfusion.txt"
27 | js_test_cap_idx_path = "jsfusion_val_caption_idx.pkl"
28 | elif self.cut in {"full-val", "full-test"}:
29 | train_list_path = "train_list_full.txt"
30 | if self.cut == "full-val":
31 | test_list_path = "val_list_full.txt"
32 | else:
33 | test_list_path = "test_list_full.txt"
34 | elif self.cut in challenge_splits:
35 | train_list_path = "train_list.txt"
36 | if self.cut == "val":
37 | test_list_path = f"{self.cut}_list.txt"
38 | else:
39 | test_list_path = f"{self.cut}.txt"
40 | else:
41 | msg = "unrecognised MSRVTT split: {}"
42 | raise ValueError(msg.format(self.cut))
43 |
44 | train_df = pd.read_csv(os.path.join(split_dir, train_list_path), names=['videoid'])
45 | test_df = pd.read_csv(os.path.join(split_dir, test_list_path), names=['videoid'])
46 | self.split_sizes = {'train': len(train_df), 'val': len(test_df), 'test': len(test_df)}
47 |
48 | if self.split == 'train':
49 | df = df[df['image_id'].isin(train_df['videoid'])]
50 | else:
51 | df = df[df['image_id'].isin(test_df['videoid'])]
52 |
53 | self.metadata = df.groupby(['image_id'])['caption'].apply(list)
54 | if self.subsample < 1:
55 | self.metadata = self.metadata.sample(frac=self.subsample)
56 |
57 | # use specific caption idx's in jsfusion
58 | if js_test_cap_idx_path is not None and self.split != 'train':
59 | caps = pd.Series(np.load(os.path.join(split_dir, js_test_cap_idx_path), allow_pickle=True))
60 | new_res = pd.DataFrame({'caps': self.metadata, 'cap_idx': caps})
61 | new_res['test_caps'] = new_res.apply(lambda x: [x['caps'][x['cap_idx']]], axis=1)
62 | self.metadata = new_res['test_caps']
63 |
64 | self.metadata = pd.DataFrame({'captions': self.metadata})
65 | print("load split {}, {} samples".format(self.split, len(self.metadata)))
66 |
67 | def _get_video_path(self, sample):
68 | return os.path.join(self.data_dir, 'videos', 'all', sample.name + '.mp4'), sample.name + '.mp4'
69 |
70 | def _get_caption(self, sample):
71 | caption_sample = self.text_params.get('caption_sample', "rand")
72 | if self.split in ['train', 'val'] and caption_sample == "rand":
73 | caption = random.choice(sample['captions'])
74 | else:
75 | caption = sample['captions'][0]
76 | return caption
77 |
78 | def _get_object_path(self, sample):
79 | """
80 | get the object npy path
81 | Args:
82 | sample (dict):
83 | Returns:
84 | abs path
85 | """
86 | # real_path = os.path.join(sample.name, '{}.npz'.format(index))
87 | real_path = sample.name
88 | full_object_fp = os.path.join(self.object_dir, sample.name)
89 | return real_path, full_object_fp
--------------------------------------------------------------------------------
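A minimal sketch of the metadata step in MSRVTT._load_metadata: annotations are grouped so each video id maps to a list of its captions, and _get_caption then samples one of them at train time (toy data below):

import pandas as pd

df = pd.DataFrame({
    "image_id": ["video0", "video0", "video1"],
    "caption": ["a dog runs", "a dog is running", "a person cooks"],
})
metadata = df.groupby(["image_id"])["caption"].apply(list)
metadata = pd.DataFrame({"captions": metadata})
print(metadata.loc["video0", "captions"])   # ['a dog runs', 'a dog is running']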
/OATrans/data_loader/MSVD_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from OATrans.base.base_dataset import TextObjectVideoDataset
4 | import pandas as pd
5 | import os
6 |
7 |
8 | class MSVD(TextObjectVideoDataset):
9 | def _load_metadata(self):
10 | metadata_dir = './meta_data'
11 | split_files = {
12 | 'train': 'MSVD_train.tsv',
13 | # 'val': 'MSVD_val.tsv', # there is no test
14 | 'val': 'MSVD_test.tsv', # direct output test result
15 | # 'val': 'MSVD_split_test.tsv',
16 | # 'test': 'MSVD_split_test.tsv'
17 | 'test': 'MSVD_test.tsv'
18 | }
19 | target_split_fp = split_files[self.split]
20 | metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t')
21 | if self.subsample < 1:
22 | metadata = metadata.sample(frac=self.subsample)
23 | self.metadata = metadata
24 | print("load split {}, {} samples".format(self.split, len(metadata)))
25 |
26 | def _get_video_path(self, sample):
27 | rel_video_fp = sample[1] + '.avi'
28 | #rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
29 | full_video_fp = os.path.join(self.data_dir, rel_video_fp)
30 | # print(full_video_fp)
31 | return full_video_fp, rel_video_fp
32 |
33 | # multiple sentence
34 | def _get_caption(self, sample):
35 | # print(sample[0].split(',')[0])
36 | if self.split == 'train':
37 | words = sample[0].split(',')
38 | num_word = len(words)
39 | index = random.randint(0, num_word-1)
40 | caption = words[index]
41 | else:
42 | # caption = sample[0]
43 | words = sample[0].split(',')
44 | num_word = len(words)
45 | index = random.randint(0, num_word-1)
46 | caption = words[index]
47 | # caption = None
48 | # if self.split == 'train':
49 | # indexs = sorted(random.sample(range(0, num_word-1), 5))
50 | # caption = ' '.join(words[item] for item in indexs)
51 | # else:
52 | # caption = ' '.join(words[item] for item in range(0, 5))
53 | return caption
54 |
55 | def _get_object_path(self, sample, index=1):
56 | """
57 | get the object npy path
58 | Args:
59 | sample (dict):
60 | Returns:
61 | abs path
62 | """
63 | rel_object_fp = os.path.join(sample[1], '1.npz')
64 | full_object_fp = os.path.join(self.object_dir, self.split, rel_object_fp)
65 | return os.path.join(self.split, rel_object_fp), full_object_fp
--------------------------------------------------------------------------------
/OATrans/data_loader/WebVid_dataset.py:
--------------------------------------------------------------------------------
1 | # from base.base_dataset import TextObjectVideoDataset
2 | from OATrans.base.base_dataset_region_mem import TextObjectVideoDataset
3 | # from base.base_dataset_region_single import TextObjectVideoDataset
4 | # from base.base_dataset_region_mem_bk import TextObjectVideoDataset
5 | import pandas as pd
6 | import os
7 |
8 |
9 | class WebVidObject(TextObjectVideoDataset):
10 | """
11 | WebVid Dataset.
12 | Assumes webvid data is structured as follows.
13 | Webvid/
14 | videos/
15 | 000001_000050/ ($page_dir)
16 | 1.mp4 (videoid.mp4)
17 | ...
18 | 5000.mp4
19 | ...
20 | """
21 | def _load_metadata(self):
22 | #metadata_dir = os.path.join(self.metadata_dir, 'meta_data')
23 | metadata_dir = './meta_data'
24 | split_files = {
25 | 'train': 'webvid_training_success_full.tsv',
26 | # 'train': 'webvid_1_of_10_training_success_full.tsv',
27 | # 'train': 'webvid_validation_success_full.tsv',
28 | 'val': 'webvid_validation_success_full.tsv', # there is no test
29 | }
30 |
31 | target_split_fp = split_files[self.split]
32 | metadata = pd.read_csv(os.path.join(metadata_dir, target_split_fp), sep='\t')
33 | if self.subsample < 1:
34 | metadata = metadata.sample(frac=self.subsample)
35 | # elif self.split == 'val':
36 | # metadata = metadata.sample(1000, random_state=0) # 15k val is unnecessarily large, downsample.
37 |
38 | #metadata['caption'] = metadata['name']
39 | #del metadata['name']
40 | self.metadata = metadata
41 | # TODO: clean final csv so this isn't necessary
42 | #self.metadata.dropna(inplace=True)
43 | #self.metadata['caption'] = self.metadata['caption'].str[:350]
44 |
45 | def _get_video_path(self, sample):
46 | rel_video_fp = sample[1] + '.mp4'
47 | #rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
48 | full_video_fp = os.path.join(self.data_dir, self.split, rel_video_fp)
49 | return full_video_fp, rel_video_fp
50 |
51 | def _get_caption(self, sample):
52 | return sample[0]
53 |
54 | def _get_object_path(self, sample):
55 | """
56 | get the object npy path
57 | Args:
58 | sample (dict):
59 | Returns:
60 | abs path
61 | """
62 | # rel_object_fp = sample[1] + '.pickle'
63 | rel_object_fp = sample[1] # + '.pickle'
64 | full_object_fp = os.path.join(self.object_dir, self.split, rel_object_fp)
65 | return rel_object_fp, full_object_fp
--------------------------------------------------------------------------------
/OATrans/data_loader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/data_loader/__init__.py
--------------------------------------------------------------------------------
/OATrans/data_loader/transforms.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 |
3 |
4 | def init_transform_dict(input_res=224,
5 | center_crop=256,
6 | randcrop_scale=(0.5, 1.0),
7 | color_jitter=(0, 0, 0),
8 | norm_mean=(0.485, 0.456, 0.406),
9 | norm_std=(0.229, 0.224, 0.225)):
10 | normalize = transforms.Normalize(mean=norm_mean, std=norm_std)
11 | tsfm_dict = {
12 | 'train': transforms.Compose([
13 | transforms.RandomResizedCrop(input_res, scale=randcrop_scale),
14 | transforms.RandomHorizontalFlip(),
15 | transforms.ColorJitter(brightness=color_jitter[0], saturation=color_jitter[1], hue=color_jitter[2]),
16 | normalize,
17 | ]),
18 | 'val': transforms.Compose([
19 | transforms.Resize(center_crop),
20 | transforms.CenterCrop(center_crop),
21 | transforms.Resize(input_res),
22 | normalize,
23 | ]),
24 | 'test': transforms.Compose([
25 | transforms.Resize(center_crop),
26 | transforms.CenterCrop(center_crop),
27 | transforms.Resize(input_res),
28 | normalize,
29 | ])
30 | }
31 | return tsfm_dict
32 |
--------------------------------------------------------------------------------
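A minimal sketch of init_transform_dict; note there is no ToTensor step, so the pipelines are meant to run on frame tensors, which recent torchvision versions accept in (T, C, H, W) layout. The toy clip below is random data:

import torch
from OATrans.data_loader.transforms import init_transform_dict

tsfms = init_transform_dict(input_res=224)
video = torch.rand(4, 3, 256, 256)          # (frames, channels, H, W)
train_clip = tsfms["train"](video)          # random crop/flip/jitter + normalize
val_clip = tsfms["val"](video)              # resize, center crop, normalize
print(train_clip.shape, val_clip.shape)     # torch.Size([4, 3, 224, 224]) twice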
/OATrans/logger/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import *
2 | from .visualization import *
--------------------------------------------------------------------------------
/OATrans/logger/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config
3 | from pathlib import Path
4 | from OATrans.utils import read_json
5 |
6 |
7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
8 | """
9 | Setup logging configuration
10 | """
11 | log_config = Path(log_config)
12 | if log_config.is_file():
13 | config = read_json(log_config)
14 | # modify logging paths based on run config
15 | for _, handler in config['handlers'].items():
16 | if 'filename' in handler:
17 | handler['filename'] = str(save_dir / handler['filename'])
18 |
19 | logging.config.dictConfig(config)
20 | else:
21 | print("Warning: logging configuration file is not found in {}.".format(log_config))
22 | logging.basicConfig(level=default_level)
23 |
--------------------------------------------------------------------------------
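A minimal sketch of setup_logging, assuming the repo root is on PYTHONPATH and the working directory is the repo root (so the config path below resolves); save_dir must be a Path because the handler filename is built with the / operator:

import logging
from pathlib import Path
from OATrans.logger import setup_logging

save_dir = Path("exps/logs")
save_dir.mkdir(parents=True, exist_ok=True)
setup_logging(save_dir, log_config="OATrans/logger/logger_config.json")
logging.getLogger(__name__).info("logging initialised")   # goes to console and exps/logs/info.log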
/OATrans/logger/logger_config.json:
--------------------------------------------------------------------------------
1 |
2 | {
3 | "version": 1,
4 | "disable_existing_loggers": false,
5 | "formatters": {
6 | "simple": {"format": "%(message)s"},
7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
8 | },
9 | "handlers": {
10 | "console": {
11 | "class": "logging.StreamHandler",
12 | "level": "DEBUG",
13 | "formatter": "simple",
14 | "stream": "ext://sys.stdout"
15 | },
16 | "info_file_handler": {
17 | "class": "logging.handlers.RotatingFileHandler",
18 | "level": "INFO",
19 | "formatter": "datetime",
20 | "filename": "info.log",
21 | "maxBytes": 10485760,
22 | "backupCount": 20, "encoding": "utf8"
23 | }
24 | },
25 | "root": {
26 | "level": "INFO",
27 | "handlers": [
28 | "console",
29 | "info_file_handler"
30 | ]
31 | }
32 | }
--------------------------------------------------------------------------------
/OATrans/logger/visualization.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from OATrans.utils import Timer
3 |
4 |
5 | class TensorboardWriter():
6 | def __init__(self, log_dir, logger, enabled):
7 | self.writer = None
8 | self.selected_module = ""
9 |
10 | if enabled:
11 | log_dir = str(log_dir)
12 |
13 | # Retrieve vizualization writer.
14 | succeeded = False
15 | for module in ["torch.utils.tensorboard", "tensorboardX"]:
16 | try:
17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir)
18 | succeeded = True
19 | break
20 | except ImportError:
21 | succeeded = False
22 | self.selected_module = module
23 |
24 | if not succeeded:
25 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
26 | "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
27 | "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
28 | "the 'config.json' file."
29 | logger.warning(message)
30 |
31 | self.step = 0
32 | self.mode = ''
33 |
34 | self.tb_writer_ftns = {
35 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
36 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
37 | }
38 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
39 |
40 | self.timer = Timer()
41 |
42 | def set_step(self, step, mode='train'):
43 | self.mode = mode
44 | self.step = step
45 | if step == 0:
46 | self.timer.reset()
47 | else:
48 | duration = self.timer.check()
49 | self.add_scalar('steps_per_sec', 1 / duration)
50 |
51 | def __getattr__(self, name):
52 | """
53 | If visualization is configured to use:
54 | return add_data() methods of tensorboard with additional information (step, tag) added.
55 | Otherwise:
56 | return a blank function handle that does nothing
57 | """
58 | if name in self.tb_writer_ftns:
59 | add_data = getattr(self.writer, name, None)
60 |
61 | def wrapper(tag, data, *args, **kwargs):
62 | if add_data is not None:
63 | # add mode(train/valid) tag
64 | if name not in self.tag_mode_exceptions:
65 | tag = '{}/{}'.format(tag, self.mode)
66 | add_data(tag, data, self.step, *args, **kwargs)
67 | return wrapper
68 | else:
69 | # default action for returning methods defined in this class, set_step() for instance.
70 | try:
71 | attr = object.__getattr__(name)
72 | except AttributeError:
73 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
74 | return attr
75 |
76 |
77 | class SacredNeptuneWriter():
78 | def __init__(self):
79 | raise NotImplementedError
--------------------------------------------------------------------------------
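A minimal sketch of TensorboardWriter (the log directory name is illustrative); if neither torch.utils.tensorboard nor tensorboardX is installed it logs a warning and the add_* calls become no-ops:

import logging
from OATrans.logger.visualization import TensorboardWriter

logger = logging.getLogger(__name__)
writer = TensorboardWriter("exps/tb_demo", logger, enabled=True)
for step, loss in enumerate([0.9, 0.7, 0.55]):
    writer.set_step(step, mode="train")
    writer.add_scalar("loss", loss)      # written under the tag "loss/train"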
/OATrans/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/model/__init__.py
--------------------------------------------------------------------------------
/OATrans/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch as th
3 | import torch.nn.functional as F
4 | import torch
5 |
6 |
7 | class NormSoftmaxLoss(nn.Module):
8 | def __init__(self, temperature=0.05):
9 | super().__init__()
10 |
11 | self.temperature = temperature
12 |
13 | def forward(self, x):
14 | "Assumes input x is similarity matrix of N x M \in [-1, 1], computed using the cosine similarity between normalised vectors"
15 | i_logsm = F.log_softmax(x/self.temperature, dim=1)
16 | j_logsm = F.log_softmax(x.t()/self.temperature, dim=1)
17 |
18 | # sum over positives
19 | idiag = torch.diag(i_logsm)
20 | loss_i = idiag.sum() / len(idiag)
21 |
22 | jdiag = torch.diag(j_logsm)
23 | loss_j = jdiag.sum() / len(jdiag)
24 |
25 | return - loss_i - loss_j
26 |
27 |
28 | class MaxMarginRankingLoss(nn.Module):
29 |
30 | def __init__(self, margin=1, fix_norm=True):
31 | super().__init__()
32 | self.fix_norm = fix_norm
33 | self.loss = th.nn.MarginRankingLoss(margin)
34 | self.margin = margin
35 |
36 | def forward(self, x):
37 | n = x.size()[0]
38 |
39 | x1 = th.diag(x)
40 | x1 = x1.unsqueeze(1)
41 | x1 = x1.expand(n, n)
42 | x1 = x1.contiguous().view(-1, 1)
43 | x1 = th.cat((x1, x1), 0)
44 |
45 | x2 = x.view(-1, 1)
46 | x3 = x.transpose(0, 1).contiguous().view(-1, 1)
47 |
48 | x2 = th.cat((x2, x3), 0)
49 | max_margin = F.relu(self.margin - (x1 - x2))
50 |
51 | if self.fix_norm:
52 | # remove the elements from the diagonal
53 | keep = th.ones(x.shape) - th.eye(x.shape[0]) # 128 x 128
54 | keep1 = keep.view(-1, 1)
55 | keep2 = keep.transpose(0, 1).contiguous().view(-1, 1)
56 | keep_idx = th.nonzero(th.cat((keep1, keep2), 0).flatten()).flatten()
57 | if x1.is_cuda:
58 | keep_idx = keep_idx.cuda()
59 | x1_ = th.index_select(x1, dim=0, index=keep_idx)
60 | x2_ = th.index_select(x2, dim=0, index=keep_idx)
61 | max_margin = F.relu(self.margin - (x1_ - x2_))
62 |
63 | return max_margin.mean()
64 |
65 |
66 | class CrossEntropy(nn.Module):
67 | def __init__(self):
68 | super().__init__()
69 | self.loss = nn.CrossEntropyLoss()
70 |
71 | def forward(self, output, target):
72 | return self.loss(output, target)
73 |
74 |
75 | def cosine_sim(im, s):
76 | """Cosine similarity between all the image and sentence pairs
77 | """
78 | return im.mm(s.t())
79 |
80 |
81 | def order_sim(im, s):
82 | """Order embeddings similarity measure $max(0, s-im)$
83 | """
84 | YmX = (s.unsqueeze(1).expand(s.size(0), im.size(0), s.size(1))
85 | - im.unsqueeze(0).expand(s.size(0), im.size(0), s.size(1)))
86 | score = -YmX.clamp(min=0).pow(2).sum(2).sqrt().t()
87 | return score
88 |
89 |
90 | def nll_loss(output, target):
91 | return F.nll_loss(output, target)
92 |
93 |
94 | if __name__ == "__main__":
95 | import torch
96 |
97 | random_sims = (torch.rand([10, 8]) * 2) - 1
98 | loss = NormSoftmaxLoss()
99 | loss(random_sims)
100 |
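
A minimal usage sketch (not from the repository), assuming the similarity matrix is built from L2-normalised video and text embeddings as the docstring of NormSoftmaxLoss requires; the tensors and shapes below are placeholders:

    import torch
    import torch.nn.functional as F

    video_embd = F.normalize(torch.randn(8, 256), dim=-1)   # placeholder video embeddings
    text_embd = F.normalize(torch.randn(8, 256), dim=-1)    # placeholder text embeddings
    sims = text_embd @ video_embd.t()                       # N x N cosine similarities in [-1, 1]

    nce_loss = NormSoftmaxLoss(temperature=0.05)(sims)      # bidirectional softmax loss over the diagonal
    rank_loss = MaxMarginRankingLoss(margin=1, fix_norm=True)(sims)
    print(nce_loss.item(), rank_loss.item())

Both losses treat the diagonal entries as the positive pairs, so the matrix must be square with videos and captions in matching order.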
--------------------------------------------------------------------------------
/OATrans/model/oa_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch as th
3 | import torch.nn.functional as F
4 | import torch
5 | import math
6 | import numpy as np
7 | from model.model import sim_matrix
8 | from model.loss import NormSoftmaxLoss
9 | from torch.autograd import Variable
10 |
11 |
12 |
13 | # simsiam loss
14 |
15 |
16 | def softmax_kl_loss(input_logits, target_logits):
17 | """Takes softmax on both sides and returns KL divergence
18 | Note:
19 | - Returns the sum over all examples. Divide by the batch size afterwards
20 | if you want the mean.
21 | - Sends gradients to inputs but not the targets.
22 | """
23 | assert input_logits.size() == target_logits.size()
24 | input_log_softmax = F.log_softmax(input_logits, dim=1)
25 | target_softmax = F.softmax(target_logits, dim=1)
26 |     return F.kl_div(input_log_softmax, target_softmax, reduction='sum')
27 |
28 |
29 | def softmax_mse_loss(input_logits, target_logits):
30 | """Takes softmax on both sides and returns MSE loss
31 | Note:
32 | - Returns the sum over all examples. Divide by the batch size afterwards
33 | if you want the mean.
34 | - Sends gradients to inputs but not the targets.
35 | """
36 | assert input_logits.size() == target_logits.size()
37 |     return F.mse_loss(input_logits, target_logits, reduction='sum')  # / num_classes
38 | # input_softmax = F.softmax(input_logits, dim=1)
39 | # target_softmax = F.softmax(target_logits, dim=1)
40 | # num_classes = input_logits.size()[1]
41 | # return F.mse_loss(input_softmax, target_softmax, size_average=False) # / num_classes
42 |
43 |
44 | # n_data = len(dataset)
45 | # contrast = MemoryMoCo(128, n_data, 8092*4, 0.07, use_softmax=True).cuda()
46 | # criterion = NCESoftmaxLoss()
47 | # criterion = criterion.cuda()
48 | #
49 | # out = contrast(feat_q, feat_k, feat_n, index)
50 | # contrast_loss = criterion(out)
51 |
52 |
53 | class NCESoftmaxLoss(nn.Module):
54 | """Softmax cross-entropy loss (a.k.a., info-NCE loss in CPC paper)"""
55 | def __init__(self):
56 | super(NCESoftmaxLoss, self).__init__()
57 | self.criterion = nn.CrossEntropyLoss()
58 |
59 | def forward(self, x):
60 | bsz = x.shape[0]
61 | x = x.squeeze()
62 | label = torch.zeros([bsz]).cuda().long()
63 | loss = self.criterion(x, label)
64 | return loss
65 |
66 | class MemoryMoCo(nn.Module):
67 | """Fixed-size queue with momentum encoder"""
68 | # T = 0.2 achieve best result?
69 | def __init__(self, inputSize, outputSize, K, T=0.07, use_softmax=False):
70 | super(MemoryMoCo, self).__init__()
71 | self.outputSize = outputSize
72 | self.inputSize = inputSize
73 | self.queueSize = K
74 | self.T = T
75 | self.index = 0
76 | self.use_softmax = use_softmax
77 |
78 | self.register_buffer('params', torch.tensor([-1]))
79 | stdv = 1. / math.sqrt(inputSize / 3)
80 | self.register_buffer('memory', torch.rand(self.queueSize, inputSize).mul_(2 * stdv).add_(-stdv))
81 | # self.register_buffer('spatial_memory', torch.rand(self.queueSize, inputSize).mul_(2 * stdv).add_(-stdv))
82 | print('using queue shape: ({},{})'.format(self.queueSize, inputSize))
83 |
84 | def forward(self, q, k, n):
85 | # n, sn,
86 | batchSize = q.shape[0]
87 | k = k.detach()
88 |
89 | Z = self.params[0].item()
90 |
91 | # pos logit
92 | l_pos = torch.bmm(q.view(batchSize, 1, -1), k.view(batchSize, -1, 1))
93 | l_pos = l_pos.view(batchSize, 1)
94 |
95 | # # neg logit
96 | # # queue = self.memory_bank.get_queue(self.queueSize, indexs)
97 | queue = self.memory.clone()
98 | l_neg = torch.mm(queue.detach(), q.transpose(1, 0))
99 | l_neg = l_neg.transpose(0, 1)
100 | #out = torch.cat((l_pos, l_neg), dim=1)
101 |
102 | # other negative
103 | l_neg_2 = torch.bmm(q.view(batchSize, 1, -1), n.view(batchSize, -1, 1))
104 | l_neg_2 = l_neg_2.view(batchSize, 1)
105 | #
106 | # strong negative
107 | # l_s_neg = torch.bmm(q.view(batchSize, 1, -1), sn.view(batchSize, -1, 1))
108 | # l_s_neg = l_s_neg.view(batchSize, 1)
109 |
110 | out = torch.cat((l_pos, l_neg, l_neg_2), dim=1)
111 | # out = torch.cat((l_pos, l_neg, l_neg_2, l_s_neg), dim=1)
112 |
113 | if self.use_softmax:
114 | out = torch.div(out, self.T)
115 | out = out.squeeze().contiguous()
116 | else:
117 | out = torch.exp(torch.div(out, self.T))
118 | if Z < 0:
119 | self.params[0] = out.mean() * self.outputSize
120 | Z = self.params[0].clone().detach().item()
121 | print("normalization constant Z is set to {:.1f}".format(Z))
122 | # compute the out
123 | out = torch.div(out, Z).squeeze().contiguous()
124 |
125 | # label = torch.zeros([batchSize]).cuda().long()
126 | # loss = []
127 | # for i in range(batchSize):
128 | # loss.append(self.criterion(out[i].unsqueeze(0), label[i].unsqueeze(0)))
129 | # print(loss)
130 | # self.memory_bank.batch_set(indexs, k, loss)
131 | # self.memory = self.memory_bank.update_queue(self.memory)
132 | # print(self.memory_bank.link)
133 | # update memory
134 | with torch.no_grad():
135 | out_ids = torch.arange(batchSize).cuda()
136 | out_ids += self.index
137 | out_ids = torch.fmod(out_ids, self.queueSize) # 1 fmod 1.5 = 1 2 fmod 1.5 = 0.5
138 | out_ids = out_ids.long()
139 | self.memory.index_copy_(0, out_ids, k)
140 | self.index = (self.index + batchSize) % self.queueSize
141 | # add for spatial memory
142 |
143 | return out
144 |
145 |
146 | class FineGrainedLoss(nn.Module):
147 | def __init__(self, temperature=0.05):
148 | super().__init__()
149 | self.criterion = NormSoftmaxLoss(temperature)
150 |
151 | def forward(self, vid_feats, text_feats, bboxs, object_token_len, real_len):
152 | # find the patch that contain in bboxes
153 | loss = None
154 | bboxs[:, :4] = bboxs[:, :4] * 16
155 | bboxs[:, :2] = torch.round(bboxs[:, :2])
156 | bboxs[:, 2:4] = torch.ceil(bboxs[:, 2:4])
157 | # for each sample
158 | # print(vid_feats.size(), text_feats.size()) # 128 x 196 x 256, 128 x 14 x 256
159 |
160 | # step1: for each bbox, get corresponding features in tensor [B, 10, 256]
161 |         for index, bbox in enumerate(bboxs):
162 |             patch_indexs = torch.zeros(16 * 16, dtype=torch.bool, device=vid_feats.device)
163 |             for i in range(16):
164 |                 for j in range(16):
165 |                     if i > bbox[0] and i < bbox[2] and j > bbox[1] and j < bbox[3]:
166 |                         patch_indexs[i * 16 + j] = True
167 |             # select patch features according to the boolean mask
168 |             vid_feats_related = vid_feats[:, patch_indexs]
169 |             vid_feat = torch.mean(vid_feats_related, dim=1)
170 |             # shared proj head ?
171 | 
172 |             # step2: for text, compute the corresponding text features in tensor [B, 10, 256]
173 |             # select text_feat of given bbox/ object_tokens
174 |             text_feat = text_feats[:, index]
175 | # step3: compute intra_sample_loss and inter_sample_loss
176 | if loss is None:
177 | loss = self.criterion(sim_matrix(text_feat, vid_feat))
178 | else:
179 | loss += self.criterion(sim_matrix(text_feat, vid_feat))
180 | return loss
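
A minimal sketch (not from the repository) of the MoCo-style queue above, following the commented usage near the top of the file; the feature tensors and sizes are placeholders, and a GPU is assumed because the module calls .cuda() internally:

    import torch
    import torch.nn.functional as F

    dim, n_data, queue_size = 128, 10000, 4096
    contrast = MemoryMoCo(dim, n_data, queue_size, T=0.07, use_softmax=True).cuda()
    criterion = NCESoftmaxLoss().cuda()

    feat_q = F.normalize(torch.randn(32, dim), dim=-1).cuda()  # query features
    feat_k = F.normalize(torch.randn(32, dim), dim=-1).cuda()  # momentum-encoder keys
    feat_n = F.normalize(torch.randn(32, dim), dim=-1).cuda()  # extra per-sample negatives

    logits = contrast(feat_q, feat_k, feat_n)  # [32, 1 + queue_size + 1]: positive, queue negatives, extra negative
    loss = criterion(logits)                   # cross-entropy with the positive at index 0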
--------------------------------------------------------------------------------
/OATrans/model/prompt_learner.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from clip import clip
4 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
5 |
6 | _tokenizer = _Tokenizer()
7 |
8 |
9 | class TextEncoder(nn.Module):
10 | def __init__(self, clip_model):
11 | super().__init__()
12 | self.transformer = clip_model.transformer
13 | self.positional_embedding = clip_model.positional_embedding
14 | self.ln_final = clip_model.ln_final
15 | self.text_projection = clip_model.text_projection
16 | self.dtype = clip_model.dtype
17 |
18 | def forward(self, prompts, tokenized_prompts):
19 | x = prompts + self.positional_embedding.type(self.dtype)
20 | x = x.permute(1, 0, 2) # NLD -> LND
21 | x = self.transformer(x)
22 | x = x.permute(1, 0, 2) # LND -> NLD
23 | x = self.ln_final(x).type(self.dtype)
24 |
25 | # x.shape = [batch_size, n_ctx, transformer.width]
26 | # take features from the eot embedding (eot_token is the highest number in each sequence)
27 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
28 |
29 | return x
30 |
31 |
32 | class PromptLearner(nn.Module):
33 | def __init__(self, clip_model):
34 | super().__init__()
35 | n_ctx = 8
36 | ctx_init = False
37 | CSC = False # if class specific prompt
38 | dtype = clip_model.dtype
39 | ctx_dim = clip_model.ln_final.weight.shape[0]
40 |
41 | if ctx_init:
42 | # use given words to initialize context vectors
43 | ctx_init = ctx_init.replace("_", " ")
44 | n_ctx = len(ctx_init.split(" "))
45 | prompt = clip.tokenize(ctx_init)
46 | with torch.no_grad():
47 | embedding = clip_model.token_embedding(prompt).type(dtype)
48 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
49 | self.prompt_prefix = ctx_init
50 |
51 | else:
52 | # random initialization
53 | if CSC:
54 | print("Initializing class-specific contexts")
55 | ctx_vectors = torch.empty(1, n_ctx, ctx_dim, dtype=dtype)
56 | else:
57 | print("Initializing a generic context")
58 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
59 | nn.init.normal_(ctx_vectors, std=0.02)
60 | self.prompt_prefix = " ".join(["X"] * n_ctx)
61 |
62 |         print(f'Initial context: "{self.prompt_prefix}"')
63 | print(f"Number of context words (tokens): {n_ctx}")
64 |
65 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized
66 | self.n_cls = 1
67 | self.n_ctx = n_ctx
68 | self.tokenized_prompts = None # torch.Tensor
69 | self.class_token_position = "end"
70 | self.clip_model = clip_model
71 |
72 | def forward(self, cls_name):
73 | ctx = self.ctx
74 | if ctx.dim() == 2:
75 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
76 | prompts = [self.prompt_prefix + " " + cls_name]
77 | self.tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
78 | with torch.no_grad():
79 |             embedding = self.clip_model.token_embedding(self.tokenized_prompts).type(self.clip_model.dtype)
80 |
81 | prefix = embedding[:, :1, :]
82 |         suffix = embedding[:, 1 + self.n_ctx:, :]
83 |
84 | if self.class_token_position == "end":
85 | prompts = torch.cat(
86 | [
87 | prefix, # (n_cls, 1, dim)
88 | ctx, # (n_cls, n_ctx, dim)
89 | suffix, # (n_cls, *, dim)
90 | ],
91 | dim=1,
92 | )
93 | else:
94 | raise ValueError
95 |
96 | return prompts
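
A minimal sketch (not from the repository) of how the two modules above fit together; it assumes the OpenAI `clip` package with downloadable ViT-B/32 weights and runs on CPU for simplicity:

    from clip import clip

    clip_model, _ = clip.load("ViT-B/32", device="cpu")

    learner = PromptLearner(clip_model)
    encoder = TextEncoder(clip_model)

    prompts = learner("guitar")                              # [1, 77, width] prompt embeddings
    text_feat = encoder(prompts, learner.tokenized_prompts)  # [1, embed_dim] text feature
    print(text_feat.shape)

learner.ctx is the trainable soft prompt; the CLIP backbone is typically frozen separately (e.g. via requires_grad_(False)).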
--------------------------------------------------------------------------------
/OATrans/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('--data_path', default='./data/',
5 | help='path to datasets')
6 | parser.add_argument('--data_name', default='precomp',
7 | help='{coco,f30k}_precomp')
8 | parser.add_argument('--vocab_path', default='./vocab/',
9 | help='Path to saved vocabulary json files.')
10 | parser.add_argument('--margin', default=0.2, type=float,
11 | help='Rank loss margin.')
12 | parser.add_argument('--num_epochs', default=30, type=int,
13 | help='Number of training epochs.')
14 | parser.add_argument('--batch_size', default=128, type=int,
15 | help='Size of a training mini-batch.')
16 | parser.add_argument('--word_dim', default=300, type=int,
17 | help='Dimensionality of the word embedding.')
18 | parser.add_argument('--embed_size', default=1024, type=int,
19 | help='Dimensionality of the joint embedding.')
20 | parser.add_argument('--grad_clip', default=2., type=float,
21 | help='Gradient clipping threshold.')
22 | parser.add_argument('--num_layers', default=1, type=int,
23 | help='Number of GRU layers.')
24 | parser.add_argument('--learning_rate', default=.0002, type=float,
25 | help='Initial learning rate.')
26 | parser.add_argument('--lr_update', default=15, type=int,
27 | help='Number of epochs to update the learning rate.')
28 | parser.add_argument('--workers', default=10, type=int,
29 | help='Number of data loader workers.')
30 | parser.add_argument('--log_step', default=10, type=int,
31 | help='Number of steps to print and record the log.')
32 | parser.add_argument('--val_step', default=500, type=int,
33 | help='Number of steps to run validation.')
34 | parser.add_argument('--logger_name', default='./runs/runX/log',
35 | help='Path to save Tensorboard log.')
36 | parser.add_argument('--model_name', default='./runs/runX/checkpoint',
37 | help='Path to save the model.')
38 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
39 | help='path to latest checkpoint (default: none)')
40 | parser.add_argument('--max_violation', action='store_true',
41 | help='Use max instead of sum in the rank loss.')
42 | parser.add_argument('--img_dim', default=2048, type=int,
43 | help='Dimensionality of the image embedding.')
44 | parser.add_argument('--no_imgnorm', action='store_true',
45 | help='Do not normalize the image embeddings.')
46 | parser.add_argument('--no_txtnorm', action='store_true',
47 | help='Do not normalize the text embeddings.')
48 | parser.add_argument('--raw_feature_norm', default="clipped_l2norm",
49 | help='clipped_l2norm|l2norm|clipped_l1norm|l1norm|no_norm|softmax')
50 | parser.add_argument('--agg_func', default="LogSumExp",
51 | help='LogSumExp|Mean|Max|Sum')
52 | parser.add_argument('--cross_attn', default="t2i",
53 | help='t2i|i2t')
54 | parser.add_argument('--precomp_enc_type', default="basic",
55 | help='basic|weight_norm')
56 | parser.add_argument('--bi_gru', action='store_true',
57 | help='Use bidirectional GRU.')
58 | parser.add_argument('--lambda_lse', default=6., type=float,
59 | help='LogSumExp temp.')
60 | parser.add_argument('--lambda_softmax', default=9., type=float,
61 | help='Attention softmax temperature.')
62 | opt = parser.parse_args()
63 | print(opt)
64 |
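
Because parse_args() runs at import time, the parser above is easiest to exercise with an explicit argv list; a small sketch (not from the repository):

    opt = parser.parse_args(['--margin', '0.1', '--batch_size', '64', '--max_violation'])
    print(opt.margin, opt.batch_size, opt.max_violation)  # 0.1 64 True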
--------------------------------------------------------------------------------
/OATrans/parse_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from pathlib import Path
4 | from functools import reduce
5 | from operator import getitem
6 | from datetime import datetime
7 | from OATrans.logger import setup_logging
8 | from utils import read_json, write_json
9 | import time
10 | import inspect
11 |
12 |
13 | class ConfigParser:
14 | def __init__(self, args, options='', timestamp=True, test=False):
15 | # parse default and custom cli options
16 | for opt in options:
17 | args.add_argument(*opt.flags, default=None, type=opt.type)
18 | args = args.parse_args()
19 | self.args = args
20 | if args.device:
21 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device
22 | if args.resume is None:
23 |             msg_no_cfg = "Configuration file needs to be specified. Add '-c config.json', for example."
24 | assert args.config is not None, msg_no_cfg
25 | self.cfg_fname = Path(args.config)
26 | config = read_json(self.cfg_fname)
27 | self.resume = None
28 | else:
29 | self.resume = Path(args.resume)
30 | resume_cfg_fname = self.resume.parent / 'config.json'
31 | config = read_json(resume_cfg_fname)
32 | if args.config is not None:
33 | config.update(read_json(Path(args.config)))
34 |
35 | # load config file and apply custom cli options
36 | self._config = _update_config(config, options, args)
37 |
38 | # set save_dir where trained model and log will be saved.
39 | save_dir = Path(self.config['trainer']['save_dir'])
40 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''
41 |
42 | exper_name = self.config['name']
43 | self._save_dir = save_dir / 'models' / exper_name / timestamp
44 | self._web_log_dir = save_dir / 'web' / exper_name / timestamp
45 | self._log_dir = save_dir / 'log' / exper_name / timestamp
46 |
47 | if not test:
48 | self.save_dir.mkdir(parents=True, exist_ok=True)
49 | self.log_dir.mkdir(parents=True, exist_ok=True)
50 |
51 | # if set, remove all previous experiments with the current config
52 | if vars(args).get("purge_exp_dir", False):
53 | for dirpath in (self._save_dir, self._log_dir, self._web_log_dir):
54 | config_dir = dirpath.parent
55 | existing = list(config_dir.glob("*"))
56 | print(f"purging {len(existing)} directories from config_dir...")
57 | tic = time.time()
58 | os.system(f"rm -rf {config_dir}")
59 | print(f"Finished purge in {time.time() - tic:.3f}s")
60 |
61 | # save updated config file to the checkpoint dir
62 | if not test:
63 | write_json(self.config, self.save_dir / 'config.json')
64 |
65 | # configure logging module
66 | setup_logging(self.log_dir)
67 | self.log_levels = {
68 | 0: logging.WARNING,
69 | 1: logging.INFO,
70 | 2: logging.DEBUG
71 | }
72 |
73 | def initialize(self, name, module, *args, index=None, **kwargs):
74 | """
75 | finds a function handle with the name given as 'type' in config, and returns the
76 | instance initialized with corresponding keyword args given as 'args'.
77 | """
78 | if index is None:
79 | module_name = self[name]['type']
80 | module_args = dict(self[name]['args'])
81 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
82 | module_args.update(kwargs)
83 | else:
84 | module_name = self[name][index]['type']
85 | module_args = dict(self[name][index]['args'])
86 |
87 | # if parameter not in config subdict, then check if it's in global config.
88 | signature = inspect.signature(getattr(module, module_name).__init__)
89 | print(module_name)
90 | for param in signature.parameters.keys():
91 | if param not in module_args and param in self.config:
92 | module_args[param] = self[param]
93 | if module_name == 'FrozenInTime' and param == 'args':
94 | module_args[param] = self.args
95 | if module_name == 'MultiDistTextObjectVideoDataLoader' and param == 'args':
96 | module_args[param] = self.args
97 |
98 | return getattr(module, module_name)(*args, **module_args)
99 |
100 | def __getitem__(self, name):
101 | return self.config[name]
102 |
103 | def get_logger(self, name, verbosity=2):
104 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity,
105 | self.log_levels.keys())
106 | assert verbosity in self.log_levels, msg_verbosity
107 | logger = logging.getLogger(name)
108 | logger.setLevel(self.log_levels[verbosity])
109 | return logger
110 |
111 | # setting read-only attributes
112 | @property
113 | def config(self):
114 | return self._config
115 |
116 | @property
117 | def save_dir(self):
118 | return self._save_dir
119 |
120 | @property
121 | def log_dir(self):
122 | return self._log_dir
123 |
124 |
125 | # helper functions used to update config dict with custom cli options
126 | def _update_config(config, options, args):
127 | for opt in options:
128 | value = getattr(args, _get_opt_name(opt.flags))
129 | if value is not None:
130 | _set_by_path(config, opt.target, value)
131 | return config
132 |
133 |
134 | def _get_opt_name(flags):
135 | for flg in flags:
136 | if flg.startswith('--'):
137 | return flg.replace('--', '')
138 | return flags[0].replace('--', '')
139 |
140 |
141 | def _set_by_path(tree, keys, value):
142 | """Set a value in a nested object in tree by sequence of keys."""
143 | _get_by_path(tree, keys[:-1])[keys[-1]] = value
144 |
145 |
146 | def _get_by_path(tree, keys):
147 | """Access a nested object in tree by sequence of keys."""
148 | return reduce(getitem, keys, tree)
149 |
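
A minimal sketch (not from the repository) of what the helpers above do with a CLI override such as --lr, whose target path ('optimizer', 'args', 'lr') mirrors the CustomArgs tuples used in the training scripts:

    config = {'optimizer': {'type': 'AdamW', 'args': {'lr': 1e-4}}}
    _set_by_path(config, ('optimizer', 'args', 'lr'), 3e-5)
    assert _get_by_path(config, ('optimizer', 'args')) == {'lr': 3e-5}
    print(config)  # {'optimizer': {'type': 'AdamW', 'args': {'lr': 3e-05}}}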
--------------------------------------------------------------------------------
/OATrans/parse_config_dist_multi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from pathlib import Path
4 | from functools import reduce
5 | from operator import getitem
6 | from datetime import datetime
7 | from OATrans.logger import setup_logging
8 | from utils import read_json, write_json
9 | import time
10 | import inspect
11 |
12 |
13 | class ConfigParser:
14 | def __init__(self, args, options='', timestamp=True, test=False):
15 | # parse default and custom cli options
16 | for opt in options:
17 | args.add_argument(*opt.flags, default=None, type=opt.type)
18 | args = args.parse_args()
19 | self.args = args
20 | if args.device:
21 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device
22 | if args.resume is None:
23 |             msg_no_cfg = "Configuration file needs to be specified. Add '-c config.json', for example."
24 | assert args.config is not None, msg_no_cfg
25 | self.cfg_fname = Path(args.config)
26 | config = read_json(self.cfg_fname)
27 | self.resume = None
28 | else:
29 | self.resume = Path(args.resume)
30 | resume_cfg_fname = self.resume.parent / 'config.json'
31 | config = read_json(resume_cfg_fname)
32 | if args.config is not None:
33 | config.update(read_json(Path(args.config)))
34 |
35 | # load config file and apply custom cli options
36 | self._config = _update_config(config, options, args)
37 |
38 | # set save_dir where trained model and log will be saved.
39 | save_dir = Path(self.config['trainer']['save_dir'])
40 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''
41 |
42 | exper_name = self.config['name']
43 | self._save_dir = save_dir / 'models' / exper_name / timestamp
44 | self._web_log_dir = save_dir / 'web' / exper_name / timestamp
45 | self._log_dir = save_dir / 'log' / exper_name / timestamp
46 |
47 | if not test:
48 | self.save_dir.mkdir(parents=True, exist_ok=True)
49 | self.log_dir.mkdir(parents=True, exist_ok=True)
50 |
51 | # if set, remove all previous experiments with the current config
52 | if vars(args).get("purge_exp_dir", False):
53 | for dirpath in (self._save_dir, self._log_dir, self._web_log_dir):
54 | config_dir = dirpath.parent
55 | existing = list(config_dir.glob("*"))
56 | print(f"purging {len(existing)} directories from config_dir...")
57 | tic = time.time()
58 | os.system(f"rm -rf {config_dir}")
59 | print(f"Finished purge in {time.time() - tic:.3f}s")
60 |
61 | # save updated config file to the checkpoint dir
62 | if not test:
63 | write_json(self.config, self.save_dir / 'config.json')
64 |
65 | # configure logging module
66 | setup_logging(self.log_dir)
67 | self.log_levels = {
68 | 0: logging.WARNING,
69 | 1: logging.INFO,
70 | 2: logging.DEBUG
71 | }
72 |
73 | def initialize(self, name, module, *args, index=None, **kwargs):
74 | """
75 | finds a function handle with the name given as 'type' in config, and returns the
76 | instance initialized with corresponding keyword args given as 'args'.
77 | """
78 | if index is None:
79 | module_name = self[name]['type']
80 | module_args = dict(self[name]['args'])
81 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
82 | module_args.update(kwargs)
83 | else:
84 | module_name = self[name][index]['type']
85 | module_args = dict(self[name][index]['args'])
86 |
87 | # if parameter not in config subdict, then check if it's in global config.
88 | signature = inspect.signature(getattr(module, module_name).__init__)
89 | print(module_name)
90 | for param in signature.parameters.keys():
91 | if param not in module_args and param in self.config:
92 | module_args[param] = self[param]
93 | if module_name == 'FrozenInTime' and param == 'args':
94 | module_args[param] = self.args
95 | if module_name == 'MultiDistTextObjectVideoDataLoader' and param == 'args':
96 | module_args[param] = self.args
97 | if module_name == 'TextObjectVideoDataLoader' and param == 'args':
98 | module_args[param] = self.args
99 |
100 | return getattr(module, module_name)(*args, **module_args)
101 |
102 | def __getitem__(self, name):
103 | return self.config[name]
104 |
105 | def get_logger(self, name, verbosity=2):
106 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity,
107 | self.log_levels.keys())
108 | assert verbosity in self.log_levels, msg_verbosity
109 | logger = logging.getLogger(name)
110 | logger.setLevel(self.log_levels[verbosity])
111 | return logger
112 |
113 | # setting read-only attributes
114 | @property
115 | def config(self):
116 | return self._config
117 |
118 | @property
119 | def save_dir(self):
120 | return self._save_dir
121 |
122 | @property
123 | def log_dir(self):
124 | return self._log_dir
125 |
126 |
127 | # helper functions used to update config dict with custom cli options
128 | def _update_config(config, options, args):
129 | for opt in options:
130 | value = getattr(args, _get_opt_name(opt.flags))
131 | if value is not None:
132 | _set_by_path(config, opt.target, value)
133 | return config
134 |
135 |
136 | def _get_opt_name(flags):
137 | for flg in flags:
138 | if flg.startswith('--'):
139 | return flg.replace('--', '')
140 | return flags[0].replace('--', '')
141 |
142 |
143 | def _set_by_path(tree, keys, value):
144 | """Set a value in a nested object in tree by sequence of keys."""
145 | _get_by_path(tree, keys[:-1])[keys[-1]] = value
146 |
147 |
148 | def _get_by_path(tree, keys):
149 | """Access a nested object in tree by sequence of keys."""
150 | return reduce(getitem, keys, tree)
151 |
--------------------------------------------------------------------------------
/OATrans/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | from OATrans.data_loader import data_loader as module_data
4 | from OATrans import model as module_loss, model as module_metric, model as module_arch
5 | import utils.visualizer as module_vis
6 | from utils.util import replace_nested_dict_item
7 | from parse_config_dist_multi import ConfigParser
8 | from trainer.trainer import Trainer
9 | from sacred import Experiment
10 | from neptunecontrib.monitoring.sacred import NeptuneObserver
11 | import transformers
12 | import os
13 |
14 | ex = Experiment('train')
15 |
16 |
17 | @ex.main
18 | def run():
19 | logger = config.get_logger('train')
20 | os.environ['TOKENIZERS_PARALLELISM'] = "false"
21 | # TODO: improve Create identity (do nothing) visualiser?
22 | if config['visualizer']['type'] != "":
23 | visualizer = config.initialize(
24 | name='visualizer',
25 | module=module_vis,
26 | exp_name=config['name'],
27 | web_dir=config._web_log_dir
28 | )
29 | else:
30 | visualizer = None
31 | # pdb.set_trace()
32 | # build tokenizer
33 | tokenizer = transformers.AutoTokenizer.from_pretrained(config['arch']['args']['text_params']['model'],
34 | TOKENIZERS_PARALLELISM=False)
35 |
36 | # setup data_loader instances
37 | data_loader, valid_data_loader = init_dataloaders(config, module_data)
38 | print('Train dataset: ', [x.n_samples for x in data_loader], ' samples')
39 | print('Val dataset: ', [x.n_samples for x in valid_data_loader], ' samples')
40 | # build model architecture, then print to console
41 | model = config.initialize('arch', module_arch)
42 | logger.info(model)
43 |
44 | # get function handles of loss and metrics
45 | loss = config.initialize(name="loss", module=module_loss)
46 | metrics = [getattr(module_metric, met) for met in config['metrics']]
47 |     # build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler
48 | trainable_params = filter(lambda p: p.requires_grad, model.parameters())
49 | optimizer = config.initialize('optimizer', transformers, trainable_params)
50 | lr_scheduler = None
51 | if 'lr_scheduler' in config._config:
52 | if hasattr(transformers, config._config['lr_scheduler']['type']):
53 | lr_scheduler = config.initialize('lr_scheduler', transformers, optimizer)
54 | else:
55 | print('lr scheduler not found')
56 | if config['trainer']['neptune']:
57 | writer = ex
58 | else:
59 | writer = None
60 | trainer = Trainer(model, loss, metrics, optimizer,
61 | config=config,
62 | data_loader=data_loader,
63 | valid_data_loader=valid_data_loader,
64 | lr_scheduler=lr_scheduler,
65 | visualizer=visualizer,
66 | writer=writer,
67 | tokenizer=tokenizer,
68 | max_samples_per_epoch=config['trainer']['max_samples_per_epoch'])
69 | trainer.train()
70 |
71 |
72 | def init_dataloaders(config, module_data):
73 | """
74 | We need a way to change split from 'train' to 'val'.
75 | """
76 | if "type" in config["data_loader"] and "args" in config["data_loader"]:
77 |         # then it's a single dataloader
78 | data_loader = [config.initialize("data_loader", module_data)]
79 | config['data_loader']['args'] = replace_nested_dict_item(config['data_loader']['args'], 'split', 'val')
80 | valid_data_loader = [config.initialize("data_loader", module_data)]
81 | elif isinstance(config["data_loader"], list):
82 | data_loader = [config.initialize('data_loader', module_data, index=idx) for idx in
83 | range(len(config['data_loader']))]
84 | new_cfg_li = []
85 | for dl_cfg in config['data_loader']:
86 | dl_cfg['args'] = replace_nested_dict_item(dl_cfg['args'], 'split', 'val')
87 | new_cfg_li.append(dl_cfg)
88 | config._config['data_loader'] = new_cfg_li
89 | valid_data_loader = [config.initialize('data_loader', module_data, index=idx) for idx in
90 | range(len(config['data_loader']))]
91 | else:
92 | raise ValueError("Check data_loader config, not correct format.")
93 |
94 | return data_loader, valid_data_loader
95 |
96 |
97 | if __name__ == '__main__':
98 | args = argparse.ArgumentParser(description='PyTorch Template')
99 | args.add_argument('-c', '--config', default=None, type=str,
100 | help='config file path (default: None)')
101 | args.add_argument('-r', '--resume', default=None, type=str,
102 | help='path to latest checkpoint (default: None)')
103 | args.add_argument('-d', '--device', default=None, type=str,
104 | help='indices of GPUs to enable (default: all)')
105 | args.add_argument('-o', '--observe', action='store_true',
106 | help='Whether to observe (neptune)')
107 | # custom cli options to modify configuration from default values given in json file.
108 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
109 | options = [
110 | CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
111 | CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')),
112 | ]
113 | config = ConfigParser(args, options)
114 | ex.add_config(config._config)
115 |
116 | if config['trainer']['neptune']:
117 | # delete this error if you have added your own neptune credentials neptune.ai
118 | # raise ValueError('Neptune credentials not set up yet.')
119 | ex.observers.append(NeptuneObserver(
120 | api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJkZTg4NGQ4YS01NmRlLTQwMzEtYjc2NS1mYjY3MzRiMWNjZTYifQ==',
121 | project_name='awinyimgprocess/Frozen'))
122 | ex.run()
123 | else:
124 | run()
125 |
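
For reference, a sketch (not from the repository) of the two "data_loader" layouts that init_dataloaders() accepts; the type name and args below are hypothetical placeholders:

    single = {
        "data_loader": {                            # one loader; its 'split' is later rewritten to 'val'
            "type": "TextVideoDataLoader",
            "args": {"dataset_name": "WebVid", "split": "train", "batch_size": 16},
        }
    }
    multiple = {
        "data_loader": [                            # a list builds one loader per entry
            {"type": "TextVideoDataLoader",
             "args": {"dataset_name": "WebVid", "split": "train", "batch_size": 16}},
            {"type": "TextVideoDataLoader",
             "args": {"dataset_name": "ConceptualCaptions3M", "split": "train", "batch_size": 16}},
        ]
    }

Any other shape falls through to the ValueError in the final branch.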
--------------------------------------------------------------------------------
/OATrans/train_dist_region_mem.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import torch
4 | import data_loader.data_loader as module_data
5 | import model.loss as module_loss
6 | import model.metric as module_metric
7 | import model.oa_model_region_mem as module_arch
8 | import utils.visualizer as module_vis
9 | from utils.util import replace_nested_dict_item
10 | from parse_config_dist_multi import ConfigParser
11 | from trainer.trainer_region_mem import Multi_Trainer_dist
12 | from sacred import Experiment
13 | from neptunecontrib.monitoring.sacred import NeptuneObserver
14 | import transformers
15 | import os
16 |
17 | ex = Experiment('train')
18 |
19 | @ex.main
20 | def run():
21 | logger = config.get_logger('train')
22 | os.environ['TOKENIZERS_PARALLELISM'] = "false"
23 | os.environ['TRANSFORMERS_OFFLINE'] = "1"
24 | # TODO: improve Create identity (do nothing) visualiser?
25 | if config['visualizer']['type'] != "":
26 | visualizer = config.initialize(
27 | name='visualizer',
28 | module=module_vis,
29 | exp_name=config['name'],
30 | web_dir=config._web_log_dir
31 | )
32 | else:
33 | visualizer = None
34 | torch.cuda.set_device(args.local_rank)
35 | torch.distributed.init_process_group(backend='nccl',
36 | init_method='tcp://{}:{}'.format(
37 | args.master_address, args.master_port),
38 | rank=args.rank, world_size=args.world_size)
39 | device = torch.device(f'cuda:{args.local_rank}')
40 | print('world_size', args.world_size, flush=True)
41 | print('local_rank: ', args.local_rank, flush=True)
42 | # build tokenizer
43 | tokenizer = transformers.AutoTokenizer.from_pretrained(config['arch']['args']['text_params']['model'],
44 | TOKENIZERS_PARALLELISM=False)
45 |
46 | # setup data_loader instances
47 | data_loader, valid_data_loader = init_dataloaders(config, module_data)
48 | print('Train dataset: ', [x.n_samples for x in data_loader], ' samples')
49 | print('Val dataset: ', [x.n_samples for x in valid_data_loader], ' samples')
50 | # build model architecture, then print to console
51 | model = config.initialize('arch', module_arch)
52 | if args.local_rank == 0:
53 | logger.info(model)
54 |
55 | # get function handles of loss and metrics
56 | loss = config.initialize(name="loss", module=module_loss)
57 | metrics = [getattr(module_metric, met) for met in config['metrics']]
58 |     # build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler
59 | trainable_params = filter(lambda p: p.requires_grad, model.parameters())
60 | optimizer = config.initialize('optimizer', transformers, trainable_params)
61 | lr_scheduler = None
62 | if 'lr_scheduler' in config._config:
63 | if hasattr(transformers, config._config['lr_scheduler']['type']):
64 | lr_scheduler = config.initialize('lr_scheduler', transformers, optimizer)
65 | else:
66 | print('lr scheduler not found')
67 | if config['trainer']['neptune']:
68 | writer = ex
69 | else:
70 | writer = None
71 | trainer = Multi_Trainer_dist(args, model, loss, metrics, optimizer,
72 | config=config,
73 | data_loader=data_loader,
74 | valid_data_loader=valid_data_loader,
75 | lr_scheduler=lr_scheduler,
76 | visualizer=visualizer,
77 | writer=writer,
78 | tokenizer=tokenizer,
79 | max_samples_per_epoch=config['trainer']['max_samples_per_epoch'])
80 | trainer.train()
81 |
82 |
83 | def init_dataloaders(config, module_data):
84 | """
85 | We need a way to change split from 'train' to 'val'.
86 | """
87 | if "type" in config["data_loader"] and "args" in config["data_loader"]:
88 |         # then it's a single dataloader
89 | data_loader = [config.initialize("data_loader", module_data)]
90 | config['data_loader']['args'] = replace_nested_dict_item(config['data_loader']['args'], 'split', 'val')
91 | valid_data_loader = [config.initialize("data_loader", module_data)]
92 | elif isinstance(config["data_loader"], list):
93 | data_loader = [config.initialize('data_loader', module_data, index=idx) for idx in
94 | range(len(config['data_loader']))]
95 | new_cfg_li = []
96 | for dl_cfg in config['data_loader']:
97 | dl_cfg['args'] = replace_nested_dict_item(dl_cfg['args'], 'split', 'val')
98 | new_cfg_li.append(dl_cfg)
99 | config._config['data_loader'] = new_cfg_li
100 | valid_data_loader = [config.initialize('data_loader', module_data, index=idx) for idx in
101 | range(len(config['data_loader']))]
102 | else:
103 | raise ValueError("Check data_loader config, not correct format.")
104 |
105 | return data_loader, valid_data_loader
106 |
107 |
108 | if __name__ == '__main__':
109 | args = argparse.ArgumentParser(description='PyTorch Template')
110 | args.add_argument('-c', '--config', default=None, type=str,
111 | help='config file path (default: None)')
112 | args.add_argument('-r', '--resume', default=None, type=str,
113 | help='path to latest checkpoint (default: None)')
114 | args.add_argument('-d', '--device', default=None, type=str,
115 | help='indices of GPUs to enable (default: all)')
116 | args.add_argument('-o', '--observe', action='store_true',
117 | help='Whether to observe (neptune)')
118 | args.add_argument('-l', '--launcher', choices=['none', 'pytorch'], default='none',help='job launcher')
119 | args.add_argument('-k', '--local_rank', type=int, default=0)
120 |
121 | master_address = os.environ['MASTER_ADDR']
122 | master_port = int(os.environ['MASTER_PORT'])
123 | world_size = int(os.environ['WORLD_SIZE'])
124 | # world_size = int(torch.cuda.device_count())
125 | rank = int(os.environ['RANK'])
126 | args.local_rank = int(os.environ['LOCAL_RANK'])
127 |
128 | if torch.cuda.device_count() > 1:
129 | print("Let's use", torch.cuda.device_count(), "GPUs!")
130 |
131 | args.add_argument('-ma', '--master_address', default=master_address)
132 | args.add_argument('-mp', '--master_port', type=int, default=master_port)
133 | args.add_argument('-ws', '--world_size', type=int, default=world_size)
134 | args.add_argument('-rk', '--rank', type=int, default=rank)
135 | args.add_argument('-lr1', '--learning_rate1', type=float, default=2e-4)
136 | args.add_argument('-sc', '--schedule', default=[60, 80])
137 |
138 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
139 | options = [
140 | CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
141 | CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')),
142 | ]
143 | config = ConfigParser(args, options)
144 | args = args.parse_args()
145 | ex.add_config(config._config)
146 |
147 | if config['trainer']['neptune']:
148 | # delete this error if you have added your own neptune credentials neptune.ai
149 | # raise ValueError('Neptune credentials not set up yet.')
150 | ex.observers.append(NeptuneObserver(
151 | api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJkZTg4NGQ4YS01NmRlLTQwMzEtYjc2NS1mYjY3MzRiMWNjZTYifQ==',
152 | project_name='awinyimgprocess/Frozen'))
153 | ex.run()
154 | else:
155 | run()
156 |
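
The script reads its rendezvous information from environment variables before parsing arguments. A minimal sketch (not from the repository) of the variables a distributed launcher normally exports, set by hand for a single-process dry run:

    import os

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    os.environ.setdefault('WORLD_SIZE', '1')
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('LOCAL_RANK', '0')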
--------------------------------------------------------------------------------
/OATrans/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer_dist import *
2 |
--------------------------------------------------------------------------------
/OATrans/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/.DS_Store
--------------------------------------------------------------------------------
/OATrans/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import *
2 |
--------------------------------------------------------------------------------
/OATrans/utils/binary_classification_accuracy.py:
--------------------------------------------------------------------------------
1 | def get_accuracy(y_true, y_prob):
2 | assert y_true.ndim == 1 and y_true.size() == y_prob.size()
3 | y_prob = y_prob > 0.5
4 | return (y_true == y_prob).sum().item() / y_true.size(0)
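
A minimal usage sketch (not from the repository): the helper expects 1-D tensors of {0, 1} ground-truth labels and predicted probabilities, thresholded at 0.5:

    import torch

    y_true = torch.tensor([1., 0., 1., 1.])
    y_prob = torch.tensor([0.9, 0.2, 0.4, 0.7])
    print(get_accuracy(y_true, y_prob))  # 0.75: the 0.4 prediction misses a positive label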
--------------------------------------------------------------------------------
/OATrans/utils/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import torch
3 | from torch import Tensor
4 | from typing import List, Tuple, Any, Optional
5 | from torchvision.transforms import functional_pil as F_pil
6 | from torchvision.transforms import functional_tensor as F_t
7 | from torchvision.transforms.functional import center_crop, crop
8 |
9 | def _get_image_size(img: Tensor) -> List[int]:
10 | """Returns image size as [w, h]
11 | """
12 | if isinstance(img, torch.Tensor):
13 | return F_t._get_image_size(img)
14 |
15 | return F_pil._get_image_size(img)
16 |
17 | def center_plus_four_crops(img: Tensor, size: List[int],
18 | margin_h: int, margin_w: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
19 | """Crop the given image into four tiled borders and the central crop.
20 | """
21 |
22 | if isinstance(size, numbers.Number):
23 | size = (int(size), int(size))
24 | elif isinstance(size, (tuple, list)) and len(size) == 1:
25 | size = (size[0], size[0])
26 |
27 | if len(size) != 2:
28 | raise ValueError("Please provide only two dimensions (h, w) for size.")
29 |
30 | image_width, image_height = _get_image_size(img)
31 |
32 | crop_height, crop_width = size
33 |
34 | if crop_width > image_width or crop_height > image_height:
35 | msg = "Requested crop size {} is bigger than input size {}"
36 | raise ValueError(msg.format(size, (image_height, image_width)))
37 |
38 | if crop_width + margin_w > image_width:
39 | msg = "Requested margin size {} + input {} is bigger than input size {}"
40 | raise ValueError(msg.format((margin_h, margin_w), size, (image_height, image_width)))
41 |
42 | #vertical_border_height = image_height - crop_height
43 | #horizontal_border_height = image_width - crop_width
44 |
45 | #x1 = horizontal_border_height // 2
46 | x11 = (image_width - crop_width - 2 * margin_w) // 2
47 | x12 = x11 + margin_w
48 | x21 = x12 + crop_width
49 | x22 = x21 + margin_w
50 |
51 | y11 = (image_height - crop_height - 2 * margin_h) // 2
52 | y12 = y11 + margin_h
53 | y21 = y12 + crop_height
54 | y22 = y21 + margin_h
55 |
56 | tl = crop(img, y11, x11, margin_h, margin_w + crop_width)
57 | tr = crop(img, y11, x21, margin_h + crop_height, margin_w)
58 | bl = crop(img, y12, x11, margin_h + crop_height, margin_w)
59 | br = crop(img, y21, x12, margin_h, margin_w + crop_width)
60 | center = center_crop(img, [crop_height, crop_width])
61 |
62 | return tl, tr, bl, br, center
63 |
64 |
65 |
66 | def center_plus_twohori_crops(img: Tensor, size: List[int],
67 | margin_w: int) -> Tuple[Tensor, Tensor, Tensor]:
68 | """Crop the given image into four tiled borders and the central crop.
69 | """
70 |
71 | if isinstance(size, numbers.Number):
72 | size = (int(size), int(size))
73 | elif isinstance(size, (tuple, list)) and len(size) == 1:
74 | size = (size[0], size[0])
75 |
76 | if len(size) != 2:
77 | raise ValueError("Please provide only two dimensions (h, w) for size.")
78 |
79 | image_width, image_height = _get_image_size(img)
80 |
81 | crop_height, crop_width = size
82 |
83 | if crop_width > image_width or crop_height > image_height:
84 | msg = "Requested crop size {} is bigger than input size {}"
85 | raise ValueError(msg.format(size, (image_height, image_width)))
86 |
87 | if crop_width + margin_w > image_width :
88 | msg = "Requested margin size {} + input {} is bigger than input size {}"
89 | raise ValueError(msg.format((0, margin_w), size, (image_height, image_width)))
90 |
91 | # vertical_border_height = image_height - crop_height
92 | # horizontal_border_height = image_width - crop_width
93 |
94 | # x1 = horizontal_border_height // 2
95 | x11 = (image_width - crop_width - 2 * margin_w) // 2
96 | x12 = x11 + margin_w
97 | x21 = x12 + crop_width
98 |
99 | y11 = (image_height - crop_height) // 2
100 |
101 | left = crop(img, y11, x11, crop_height, margin_w)
102 | right = crop(img, y11, x21, crop_height, margin_w)
103 | center = center_crop(img, [crop_height, crop_width])
104 |
105 | return left, right, center
106 |
107 | from torch import nn
108 | class TwoHoriCrop(nn.Module):
109 | def __init__(self, size, margin_w):
110 | super().__init__()
111 | self.size = size
112 | self.margin_w = margin_w
113 |
114 | def forward(self, x):
115 | return center_plus_twohori_crops(x, self.size, self.margin_w)
116 |
117 | if __name__ == "__main__":
118 | from PIL import Image
119 |
120 | img = Image.open('visualisations/guitar.png')
121 | crops = center_plus_four_crops(img, [336, 336], 112, 112)
122 | order = ['tl', 'tr', 'bl', 'br', 'center']
123 |
124 | for idx, subimg in zip(order, crops):
125 | subimg.save(f'visualisations/guitar_{idx}.png')
126 |
127 | crops = center_plus_twohori_crops(img, [448, 448], 112)
128 | order = ['left', 'right', 'center2']
129 |
130 | for idx, subimg in zip(order, crops):
131 | subimg.save(f'visualisations/guitar_{idx}.png')
132 |
--------------------------------------------------------------------------------
/OATrans/utils/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br, video, source, attr
3 | from dominate.tags import span
4 | import os
5 |
6 |
7 | class HTML:
8 | """This HTML class allows us to save images and write texts into a single HTML file.
9 |
10 |     It consists of functions such as add_header (add a text header to the HTML file),
11 |     add_images (add a row of images to the HTML file), and save (save the HTML to the disk).
12 | It is based on Python library 'dominate', a Python library for creating and
13 | manipulating HTML documents using a DOM API.
14 | """
15 |
16 | def __init__(self, web_dir, title, refresh=0):
17 | """Initialize the HTML classes
18 |
19 | Parameters:
20 | web_dir (str) -- a directory that stores the webpage. HTML file will be
21 |                 created at <web_dir>/index.html; images will be saved at <web_dir>/images/
22 |             title (str)   -- the webpage name
23 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
24 |         """
25 |         self.title = title
26 |         self.web_dir = web_dir
27 |         self.img_dir = os.path.join(self.web_dir, 'images')
28 |         if not os.path.exists(self.web_dir):
29 |             os.makedirs(self.web_dir)
30 |         if not os.path.exists(self.img_dir):
31 |             os.makedirs(self.img_dir)
32 | 
33 |         self.doc = dominate.document(title=title)
34 |         if refresh > 0:
35 | with self.doc.head:
36 | meta(http_equiv="refresh", content=str(refresh))
37 |
38 | def get_image_dir(self):
39 | """Return the directory that stores images"""
40 | return self.img_dir
41 |
42 | def add_header(self, text):
43 | """Insert a header to the HTML file
44 |
45 | Parameters:
46 | text (str) -- the header text
47 | """
48 | with self.doc:
49 | h3(text)
50 |
51 | def add_videos(self, vids, txts, links, width=400, hidden_tag="hidden"):
52 | """add images to the HTML file
53 |
54 | Parameters:
55 | vids (str list) -- a list of image paths
56 | txts (str list) -- a list of image names shown on the website
57 | links (str list) -- a list of hyperref links; when you click an image,
58 | it will redirect you to a new page
59 | """
60 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
61 | self.doc.add(self.t)
62 |         colors = ["red", "blue", "gold", "salmon"]
63 | with self.t:
64 | with tr():
65 | for vid, txt, link in zip(vids, txts, links):
66 | td_style = "word-wrap: break-word; width:{}px".format(width)
67 | with td(style=td_style, halign="center", valign="top"):
68 | with p():
69 | vid_path = str(vid)
70 | if vid_path == hidden_tag:
71 | p_style = "font-weight: bold; width:{}px;"
72 | p_style = p_style.format(width * 3)
73 | p("hidden video", style=p_style)
74 | else:
75 | with a(href=str(link)):
76 | with video():
77 | attr(controls="controls")
78 | source(src=vid_path, type="video/mp4")
79 | br()
80 |                             rows = txt.split("<br>")
81 | for idx, row in enumerate(rows):
82 | color = colors[idx % len(colors)]
83 |                                 bold_tag = "<b>"
84 | if not row.startswith(bold_tag):
85 | s_style = "color:{};".format(color)
86 | else:
87 | s_style = "color:black; font-weight: bold;"
88 | row = row[len(bold_tag):]
89 | span(row, style=s_style)
90 |
91 | def add_images(self, ims, txts, links, width=400):
92 | """add images to the HTML file
93 |
94 | Parameters:
95 | ims (str list) -- a list of image paths
96 | txts (str list) -- a list of image names shown on the website
97 | links (str list) -- a list of hyperref links; when you click an image,
98 | it will redirect you to a new page
99 | """
100 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
101 | self.doc.add(self.t)
102 |         colors = ["red", "blue", "gold", "salmon"]
103 | with self.t:
104 | with tr():
105 | for im, txt, link in zip(ims, txts, links):
106 | td_style = "word-wrap: break-word;"
107 | with td(style=td_style, halign="center", valign="top"):
108 | with p():
109 | with a(href=link):
110 | img(
111 | style="width:%dpx" % width,
112 | src=im,
113 | )
114 | br()
115 |                             rows = txt.split("<br>")
116 | for idx, row in enumerate(rows):
117 | color = colors[idx % len(colors)]
118 |                                 bold_tag = "<b>"
119 | if not row.startswith(bold_tag):
120 | s_style = "color:{};".format(color)
121 | else:
122 | s_style = "color:black; font-weight: bold;"
123 | row = row[len(bold_tag):]
124 | span(row, style=s_style)
125 |
126 | def save(self):
127 | """save the current content to the HMTL file"""
128 | html_file = "%s/index.html" % self.web_dir
129 | f = open(html_file, "wt")
130 | f.write(self.doc.render())
131 | f.close()
132 |
133 |
134 | if __name__ == "__main__": # we show an example usage here.
135 | html = HTML("web/", "test_html")
136 | html.add_header("hello world")
137 |
138 | ims, txts, links = [], [], []
139 | for n in range(4):
140 | ims.append("image_%d.png" % n)
141 | txts.append("text_%d" % n)
142 | links.append("image_%d.png" % n)
143 | html.add_images(ims, txts, links)
144 | html.save()
--------------------------------------------------------------------------------
/OATrans/utils/objects_vocab_token_len:
--------------------------------------------------------------------------------
1 | [2, 1, 1, 3, 1, 3, 2, 2, 1, 1, 2, 2, 3, 1, 1, 1, 2, 1, 1, 3, 1, 2, 1, 2, 1, 1, 1, 3, 2, 1, 1, 2, 2, 1, 2, 3, 1, 4, 2, 1, 1, 2, 1, 1, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 2, 1, 3, 2, 2, 2, 1, 2, 1, 1, 1, 2, 2, 2, 3, 1, 1, 1, 1, 1, 2, 1, 1, 2, 2, 1, 3, 2, 3, 1, 2, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 3, 1, 1, 3, 3, 1, 1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 2, 1, 2, 1, 5, 1, 1, 2, 1, 2, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 1, 3, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 1, 2, 1, 2, 1, 3, 3, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 1, 3, 2, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 3, 2, 3, 1, 1, 1, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 3, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 1, 3, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 2, 1, 3, 2, 2, 1, 1, 1, 1, 1, 3, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 1, 2, 2, 1, 1, 1, 2, 1, 4, 2, 1, 2, 1, 2, 1, 1, 1, 3, 2, 1, 1, 3, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 3, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 3, 1, 1, 3, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 3, 1, 2, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 1, 2, 2, 2, 1, 2, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 4, 3, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 3, 1, 1, 1, 2, 1, 3, 2, 1, 1, 1, 1, 2, 2, 3, 2, 1, 2, 1, 3, 1, 2, 1, 2, 2, 2, 3, 1, 1, 1, 1, 1, 2, 2, 3, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 3, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 3, 1, 1, 4, 2, 1, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 3, 1, 1, 1, 1, 3, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 2, 2, 3, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 3, 2, 1, 3, 1, 1, 1, 1, 2, 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 3, 2, 3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 2, 1, 1, 2, 2, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 5, 1, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 4, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 2, 3, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 2, 1, 2, 3, 1, 1, 1, 1, 2, 3, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 3, 2, 2, 1, 1, 1, 1, 3, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 1, 2, 1, 2, 2, 1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 2, 2, 2, 1, 4, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 1, 
1, 2, 1, 1, 1, 1, 3, 2, 1, 2, 1, 2, 2, 1, 1, 3, 1, 3, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 3, 2, 1, 1, 2, 1, 1, 1, 3, 1, 2, 1, 1, 2, 1, 3, 1, 3, 3, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 2, 3, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 3, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 3, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 3, 1, 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 3, 1, 1, 1, 2, 1, 1, 3, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 2, 2, 3, 1, 1, 1, 3, 2, 1, 1, 3, 1, 1, 1, 2, 1, 2, 2, 1, 1, 1, 1, 1, 3, 1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 3, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 3, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2]
--------------------------------------------------------------------------------
/OATrans/utils/param_forzen.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def forzen_param(model):
5 | for name, param in model.named_parameters():
6 | if 'vid_proj' in name or 'txt_proj' in name:
7 | param.requires_grad = True
8 | else:
9 | param.requires_grad = False
10 | return True
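
A minimal sketch (not from the repository) of the intended pattern: freeze everything except the projection heads, then build the optimizer from the parameters that still require gradients. TinyModel is a hypothetical stand-in for a model with 'vid_proj'/'txt_proj' submodules:

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(16, 16)
            self.vid_proj = nn.Linear(16, 8)
            self.txt_proj = nn.Linear(16, 8)

    model = TinyModel()
    forzen_param(model)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=3e-5)
    print(len(trainable))  # 4: only the vid_proj/txt_proj weights and biases stay trainable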
--------------------------------------------------------------------------------
/OATrans/utils/unit_test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/unit_test/__init__.py
--------------------------------------------------------------------------------
/OATrans/utils/unit_test/distill_bert.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel
2 | import transformers
3 |
4 | text = "tree"
5 |
6 | tokenizer = transformers.AutoTokenizer.from_pretrained("pretrained/distilbert-base-uncased",
7 | TOKENIZERS_PARALLELISM=False)
8 |
9 | text_data = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
10 | text_data = {key: val.cuda() for key, val in text_data.items()}
11 |
12 |
13 | text_model = AutoModel.from_pretrained("pretrained/distilbert-base-uncased").cuda()
14 |
15 | print(text_model)
16 |
17 | text_embeddings_all = text_model(**text_data).last_hidden_state
18 | print(text_embeddings_all.size())
19 | text_embeddings = text_embeddings_all[:, 0, :]
20 | print(text_embeddings)
21 |
22 |
23 | # reproduce the same forward pass manually: embeddings, then the transformer stack
24 | text_embeddings_2 = text_model.embeddings(text_data['input_ids'])
25 | head_mask = text_model.get_head_mask(None, text_model.config.num_hidden_layers)
26 | text_embeddings_2 = text_model.transformer(text_embeddings_2,
27 |                                            attn_mask=text_data['attention_mask'],
28 |                                            head_mask=head_mask,
29 |                                            output_attentions=False,
30 |                                            output_hidden_states=False,
31 |                                            return_dict=True,
32 |                                            ).last_hidden_state
33 | print(text_embeddings_all - text_embeddings_2)  # should be (close to) zero
--------------------------------------------------------------------------------
/OATrans/utils/unit_test/load_msvd_video.py:
--------------------------------------------------------------------------------
1 | import cv2
2 |
3 | video_path = "MSVD/YouTubeClips/fVWUaH2mCt4_1_7.avi"
4 | cap = cv2.VideoCapture(video_path)
5 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
6 | print(vlen)
--------------------------------------------------------------------------------
/OATrans/utils/unit_test/region_roi_example.py:
--------------------------------------------------------------------------------
1 | """
2 | usage
3 | # add for roi pooling
4 | import torch
5 | import torchvision.ops.roi_align as roi_align
6 | self.roi_align = roi_align
7 | """
8 |
9 | def region_embed(self, x, bbox):
10 | """
11 | Args:
12 | x (): the input video
13 | bbox (): bounding boxes with 4 loc + height/width; stacked for num_frame times
14 |
15 | Returns:
16 | the raw pixel region of bbox
17 | """
18 | b, t, c, h, w = x.size()
19 | x = x.view(-1, c, h, w)
20 | B, L, N = bbox.size()
21 | coordinates = torch.zeros((B * L, 5)).cuda()
22 | for i in range(B * L):
23 | coordinates[i][0] = i // L
24 | coordinates[i][1:] = bbox[i // L, i % L, :4]
25 | regions = self.roi_align(x, coordinates, output_size=[self.patch_size, self.patch_size])
26 | region_features = self.region_embedding_layer(regions)
27 | region_features = region_features.view(-1, L // t, self.embed_dim)
28 | return region_features
--------------------------------------------------------------------------------
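The snippet above is meant to be pasted into a model class; as a self-contained illustration of the same ROI-Align idea, here is a hedged per-frame sketch (the shapes, the linear projection, and the toy boxes are assumptions for illustration, not the repo's actual layers):

import torch
import torch.nn as nn
from torchvision.ops import roi_align

def embed_regions(frames, boxes, patch_size=16, embed_dim=256):
    # frames: (N, C, H, W) float tensor; boxes: (N, R, 4) pixel coords (x1, y1, x2, y2)
    n, c, _, _ = frames.shape
    r = boxes.shape[1]
    # roi_align expects rows of (batch_index, x1, y1, x2, y2)
    batch_idx = torch.arange(n, dtype=frames.dtype).repeat_interleave(r)
    rois = torch.cat([batch_idx[:, None], boxes.reshape(n * r, 4).to(frames.dtype)], dim=1)
    patches = roi_align(frames, rois, output_size=(patch_size, patch_size))  # (N*R, C, P, P)
    proj = nn.Linear(c * patch_size * patch_size, embed_dim)  # in a real model this would be a stored module
    return proj(patches.flatten(1)).view(n, r, embed_dim)

frames = torch.rand(2, 3, 224, 224)
boxes = torch.tensor([[[0., 0., 32., 32.]], [[16., 16., 64., 64.]]])
print(embed_regions(frames, boxes).shape)  # torch.Size([2, 1, 256])
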
/OATrans/utils/util.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from datetime import datetime
4 | from itertools import repeat
5 | from collections import OrderedDict
6 | import functools
7 | import time
8 | import socket
9 | import numpy as np
10 | import psutil
11 | import msgpack
12 | import humanize
13 | import os
14 |
15 | def replace_nested_dict_item(obj, key, replace_value):
16 | for k, v in obj.items():
17 | if isinstance(v, dict):
18 | obj[k] = replace_nested_dict_item(v, key, replace_value)
19 | if key in obj:
20 | obj[key] = replace_value
21 | return obj
22 |
23 |
24 | def state_dict_data_parallel_fix(load_state_dict, curr_state_dict):
25 | load_keys = list(load_state_dict.keys())
26 | curr_keys = list(curr_state_dict.keys())
27 |
28 | redo_dp = False
29 | undo_dp = False
30 | if not curr_keys[0].startswith('module.') and load_keys[0].startswith('module.'):
31 | undo_dp = True
32 | elif curr_keys[0].startswith('module.') and not load_keys[0].startswith('module.'):
33 | redo_dp = True
34 |
35 | if undo_dp:
36 | from collections import OrderedDict
37 | new_state_dict = OrderedDict()
38 | for k, v in load_state_dict.items():
39 | name = k[7:] # remove `module.`
40 | new_state_dict[name] = v
41 | # load params
42 | elif redo_dp:
43 | from collections import OrderedDict
44 | new_state_dict = OrderedDict()
45 | for k, v in load_state_dict.items():
46 | name = 'module.' + k  # add the `module.` prefix back
47 | new_state_dict[name] = v
48 | else:
49 | new_state_dict = load_state_dict
50 | return new_state_dict
51 |
52 | def print_numpy(x, val=True, shp=False):
53 | """Print the mean, min, max, median, std, and size of a numpy array
54 | Parameters:
55 | val (bool) -- if print the values of the numpy array
56 | shp (bool) -- if print the shape of the numpy array
57 | """
58 | x = x.astype(np.float64)
59 | if shp:
60 | print('shape,', x.shape)
61 | if val:
62 | x = x.flatten()
63 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
64 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
65 |
66 |
67 | def mkdirs(paths):
68 | """create empty directories if they don't exist
69 | Parameters:
70 | paths (str list) -- a list of directory paths
71 | """
72 | if isinstance(paths, list) and not isinstance(paths, str):
73 | for path in paths:
74 | mkdir(path)
75 | else:
76 | mkdir(paths)
77 |
78 |
79 | def mkdir(path):
80 | """create a single empty directory if it didn't exist
81 | Parameters:
82 | path (str) -- a single directory path
83 | """
84 | if not os.path.exists(path):
85 | os.makedirs(path)
86 |
87 | def read_json(fname):
88 | with fname.open('rt') as handle:
89 | return json.load(handle, object_hook=OrderedDict)
90 |
91 | def write_json(content, fname):
92 | with fname.open('wt') as handle:
93 | json.dump(content, handle, indent=4, sort_keys=False)
94 |
95 | def inf_loop(data_loader):
96 | ''' wrapper function for endless data loader. '''
97 | for loader in repeat(data_loader):
98 | yield from loader
99 |
100 | def memory_summary():
101 | vmem = psutil.virtual_memory()
102 | msg = (
103 | f">>> Currently using {vmem.percent}% of system memory "
104 | f"{humanize.naturalsize(vmem.used)}/{humanize.naturalsize(vmem.available)}"
105 | )
106 | print(msg)
107 |
108 | @functools.lru_cache(maxsize=64, typed=False)
109 | def memcache(path):
110 | suffix = Path(path).suffix
111 | print(f"loading features >>>", end=" ")
112 | tic = time.time()
113 | if suffix == ".npy":
114 | res = np_loader(path)
115 | else:
116 | raise ValueError(f"unknown suffix: {suffix} for path {path}")
117 | print(f"[Total: {time.time() - tic:.1f}s] ({socket.gethostname() + ':' + str(path)})")
118 | return res
119 |
120 | def np_loader(np_path, l2norm=False):
121 | with open(np_path, "rb") as f:
122 | data = np.load(f, encoding="latin1", allow_pickle=True)
123 | if isinstance(data, np.ndarray) and data.size == 1:
124 | data = data[()]  # handle numpy dict storage convention
125 | if l2norm:
126 | print("L2 normalizing features")
127 | if isinstance(data, dict):
128 | for key in data:
129 | feats_ = data[key]
130 | feats_ = feats_ / max(np.linalg.norm(feats_), 1E-6)
131 | data[key] = feats_
132 | elif data.ndim == 2:
133 | data_norm = np.linalg.norm(data, axis=1)
134 | data = data / np.maximum(data_norm.reshape(-1, 1), 1E-6)
135 | else:
136 | raise ValueError("unexpected data format {}".format(type(data)))
137 | return data
138 |
139 |
140 | class Timer:
141 | def __init__(self):
142 | self.cache = datetime.now()
143 |
144 | def check(self):
145 | now = datetime.now()
146 | duration = now - self.cache
147 | self.cache = now
148 | return duration.total_seconds()
149 |
150 | def reset(self):
151 | self.cache = datetime.now()
152 |
--------------------------------------------------------------------------------
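A hedged usage sketch for state_dict_data_parallel_fix above, with a toy module instead of the real network: it remaps the 'module.' prefix that nn.DataParallel adds to checkpoint keys so that load_state_dict succeeds either way.

import torch.nn as nn
from OATrans.utils.util import state_dict_data_parallel_fix  # assumes the repo root is on PYTHONPATH

model = nn.Sequential(nn.Linear(4, 4))  # toy stand-in for the real model

# pretend the checkpoint was saved from an nn.DataParallel wrapper, so keys carry a 'module.' prefix
saved = {"module." + k: v for k, v in model.state_dict().items()}

fixed = state_dict_data_parallel_fix(saved, model.state_dict())
model.load_state_dict(fixed)  # succeeds: the 'module.' prefix has been stripped
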
/OATrans/utils/video.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import numpy as np
4 | import PIL
5 | import collections
6 | import random
7 | import cv2
8 | import os
9 | import numpy as np
10 |
11 | def load_frames_from_video_path(path, num_frames, sample='rand'):
12 | cap = cv2.VideoCapture(path)
13 | assert (cap.isOpened())
14 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
15 | acc_samples = min(num_frames, vlen)
16 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
17 | ranges = []
18 | for idx, interv in enumerate(intervals[:-1]):
19 | ranges.append((interv, intervals[idx + 1] - 1))
20 | if sample == 'rand':
21 | frame_idxs = [random.choice(range(x[0], x[1])) for x in ranges]
22 | elif sample == 'uniform':
23 | frame_idxs = [(x[0] + x[1]) // 2 for x in ranges]
24 | else:
25 | raise NotImplementedError
26 |
27 | frames = []
28 | for index in frame_idxs:
29 | cap.set(cv2.CAP_PROP_POS_FRAMES, index)
30 | ret, frame = cap.read()
31 | if ret:
32 | cv2.imwrite(f'images/{index}.jpg', frame)
33 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
34 | frame = torch.from_numpy(frame)
35 | # (H x W x C) to (C x H x W)
36 | frame = frame.permute(2, 0, 1)
37 | frames.append(frame)
38 | else:
39 | raise ValueError
40 |
41 | frames = torch.stack(frames).float() / 255
42 | cap.release()
43 | return frames, frame_idxs
--------------------------------------------------------------------------------
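A hedged usage sketch for load_frames_from_video_path above; the video path is a placeholder, and the images/ directory is created first because the loader writes each sampled frame there as a side effect.

import os
from OATrans.utils.video import load_frames_from_video_path  # assumes the repo root is on PYTHONPATH

os.makedirs("images", exist_ok=True)  # the loader dumps the sampled frames into images/<index>.jpg
frames, frame_idxs = load_frames_from_video_path("some_video.mp4", num_frames=4, sample="uniform")  # placeholder path
print(frames.shape, frame_idxs)  # (4, 3, H, W) float tensor scaled to [0, 1], plus the sampled frame indices
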
/OATrans/utils/visualization/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/.DS_Store
--------------------------------------------------------------------------------
/OATrans/utils/visualization/3f_vto_visualize.py:
--------------------------------------------------------------------------------
1 | """
2 | visualize both image + object + text
3 | """
4 | import numpy as np
5 | import cv2
6 | from csv import reader
7 | import os
8 | import random
9 | import matplotlib.pyplot as plt
10 | import torch
11 | import pdb
12 | import textwrap
13 | import pandas as pd
14 |
15 | full_csv = "meta_data/webvid_training_success_full.tsv"
16 | data_source = "WebVid/train"
17 | feat_source = "WebVid/8_frame_object/train/"
18 | output = "WebVid2M_visualization/train"
19 |
20 |
21 | def sample_frames(num_frames, vlen, sample='rand', fix_start=None):
22 | acc_samples = min(num_frames, vlen)
23 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
24 | ranges = []
25 | for idx, interv in enumerate(intervals[:-1]):
26 | ranges.append((interv, intervals[idx + 1] - 1))
27 | if sample == 'rand':
28 | frame_idxs = [random.choice(range(x[0], x[1])) for x in ranges]
29 | elif fix_start is not None:
30 | frame_idxs = [x[0] + fix_start for x in ranges]
31 | elif sample == 'uniform':
32 | frame_idxs = [(x[0] + x[1]) // 2 for x in ranges]
33 | else:
34 | raise NotImplementedError
35 | return frame_idxs
36 |
37 |
38 | def read_frames_cv2(video_path, num_frames, sample='uniform', fix_start=None):
39 | cap = cv2.VideoCapture(video_path)
40 | assert (cap.isOpened())
41 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
42 | # get indexes of sampled frames
43 | frame_idxs = sample_frames(num_frames, vlen, sample=sample, fix_start=fix_start)
44 | frames = []
45 | success_idxs = []
46 | for index in frame_idxs:
47 | cap.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
48 | ret, frame = cap.read()
49 | if ret:
50 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
51 | frames.append(frame)
52 | success_idxs.append(index)
53 | else:
54 | pass
55 | # print(frame_idxs, ' fail ', index, f' (vlen {vlen})')
56 | cap.release()
57 | return frames
58 |
59 | # step1: open video
60 |
61 | # step2: video seq
62 |
63 | # step3: video feature
64 |
65 |
66 | def tri_region_visualize(imgs, feat_paths, caption, outpath="visualization/1.png"):
67 | concat_imgs = None
68 | for i in [0, 3, 7]:
69 | frame1 = np.load(feat_paths[i], allow_pickle=True)
70 | boxes = frame1['bbox']
71 | features = frame1['x'] # 20 x 2048
72 | confident = frame1['info'].item()['objects_conf']
73 | # step 1: re-ranking the region with confidence
74 | object_ids = frame1['info'].item()['objects_id']
75 | condident_indices = np.argsort(confident)[::-1]
76 | boxes = boxes[condident_indices]
77 | features = features[condident_indices]
78 | object_ids = object_ids[condident_indices]
79 | confident = confident[condident_indices]
80 | new_object, unique_indices = np.unique(object_ids, return_index=True)
81 | # step 2: remove region with same object class
82 | boxes = boxes[unique_indices]
83 | features = features[unique_indices]
84 | object_ids = object_ids[unique_indices]
85 | # object_ids = object_ids[unique_indices]
86 | # confident = confident[unique_indices]
87 | # # print(boxes, features)
88 | # image_width = frame1['info'].item()['image_w']
89 | # image_height = frame1['info'].item()['image_h']
90 | # box_width = boxes[:, 2] - boxes[:, 0]
91 | # box_height = boxes[:, 3] - boxes[:, 1]
92 | # scaled_width = box_width / image_width
93 | # scaled_height = box_height / image_height
94 | # scaled_x = boxes[:, 0] / image_width
95 | # scaled_y = boxes[:, 1] / image_height
96 | # scaled_width = scaled_width[..., np.newaxis]
97 | # scaled_height = scaled_height[..., np.newaxis]
98 | # scaled_x = scaled_x[..., np.newaxis]
99 | # scaled_y = scaled_y[..., np.newaxis]
100 | # spatial_features = np.concatenate(
101 | # (scaled_x, scaled_y, scaled_x + scaled_width, scaled_y + scaled_height, scaled_width, scaled_height), axis=1)
102 | # feat = torch.cat([torch.from_numpy(features), torch.from_numpy(spatial_features)], dim=1)
103 | classes = ['__background__']
104 | with open('utils/objects_vocab.txt', 'r') as f:
105 | for object in f.readlines():
106 | classes.append(object.split(',')[0].lower().strip())
107 | # print(features.shape)
108 | # plot top 5 objects
109 | img = imgs[i]
110 | if len(boxes) < 5:
111 | return False
112 | # print(img.shape)
113 | # print(img)
114 | colormap = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (155, 100, 100), (100, 155, 100)]
115 | for j in range(5):
116 | # print(boxes[j])
117 | cv2.putText(img, '%s: %s' % (classes[object_ids[j] + 1], confident[j]), (int(boxes[j][0]), int(boxes[j][1] + 15)),
118 | cv2.FONT_HERSHEY_TRIPLEX,
119 | 0.5,
120 | colormap[j],
121 | 1)
122 | cv2.rectangle(img, (int(boxes[j][0]), int(boxes[j][1])), (int(boxes[j][2]), int(boxes[j][3])),
123 | colormap[j],
124 | 1)
125 | if concat_imgs is None:
126 | concat_imgs = img
127 | else:
128 | concat_imgs = np.concatenate((concat_imgs, img), axis=1)
129 | caption_img = np.ones((50, imgs[0].shape[1] * 3, 3)) * 255
130 | cv2.putText(caption_img, caption, (10, 10), cv2.FONT_HERSHEY_TRIPLEX,
131 | 0.5,
132 | colormap[j],
133 | 1)
134 | concat_imgs = np.concatenate((concat_imgs, caption_img), axis=0)
135 | cv2.imwrite(outpath, concat_imgs)
136 | return outpath
137 |
138 |
139 | if __name__ == '__main__':
140 | metadata = pd.read_csv(full_csv, sep='\t')
141 | count = 0
142 | for i in range(len(metadata)):
143 | sample = metadata.iloc[i]
144 | count += 1
145 | if count > 200:
146 | break
147 | video_path = os.path.join(data_source, sample[1] + '.mp4')
148 | imgs = read_frames_cv2(video_path, 8, 'uniform')
149 | feat_paths = [feat_source + sample[1] + '/' + str(k) + '.npz' for k in range(8)]
150 | outpath = 'visualization/3f/{}_{}.jpg'.format(i, sample[1].split('/')[1])
151 | tri_region_visualize(imgs, feat_paths, sample[0], outpath)
152 |
153 |
--------------------------------------------------------------------------------
/OATrans/utils/visualization/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/__init__.py
--------------------------------------------------------------------------------
/OATrans/utils/visualization/learned_embedding_visualization.py:
--------------------------------------------------------------------------------
1 | """
2 | the visualization of learned embedding
3 | """
4 |
5 | # That's an impressive list of imports.
6 | import numpy as np
7 | import torch
8 | from numpy import linalg
9 | from numpy.linalg import norm
10 | from scipy.spatial.distance import squareform, pdist
11 |
12 | # We import sklearn.
13 | import sklearn
14 | # from sklearn.manifold import TSNE
15 | from sklearn.manifold._t_sne import TSNE
16 | from sklearn.datasets import load_digits
17 | from sklearn.preprocessing import scale
18 |
19 | # We'll hack a bit with the t-SNE code in sklearn 0.15.2.
20 | from sklearn.metrics.pairwise import pairwise_distances
21 | from sklearn.manifold._t_sne import (_joint_probabilities,
22 | _kl_divergence)
23 | # Random state.
24 | RS = 20150101
25 |
26 | # We'll use matplotlib for graphics.
27 | import matplotlib.pyplot as plt
28 | import matplotlib.patheffects as PathEffects
29 | import matplotlib
30 | # %matplotlib inline
31 |
32 | import random
33 | # We import seaborn to make nice plots.
34 | import seaborn as sns
35 | import os
36 | sns.set_style('darkgrid')
37 | sns.set_palette('muted')
38 | sns.set_context("notebook", font_scale=1.5,
39 | rc={"lines.linewidth": 2.5})
40 |
41 | # We'll generate an animation with matplotlib and moviepy.
42 | # from moviepy.video.io.bindings import mplfig_to_npimage
43 | # import moviepy.editor as mpy
44 |
45 |
46 | def load_data(file_name):
47 | # digits = load_digits()
48 | # digits.data.shape # 1797 x 64
49 | # print(digits.data.shape)
50 | # print(digits['DESCR'])
51 | # return digits
52 | features = np.load(file_name, allow_pickle='TRUE').tolist()
53 | return features
54 |
55 |
56 | def scatter(x, colors, num_class=10):
57 | # We choose a color palette with seaborn.
58 | palette = np.array(sns.color_palette("hls", num_class))
59 | # sns.palplot(sns.color_palette("hls", 10))
60 | # We create a scatter plot.
61 | labels=['brush_hair', 'cartwheel', 'catch', 'chew',
62 | 'clap', 'climb', 'climb_stairs', 'dive', 'draw_sword',
63 | 'dribble']
64 | f = plt.figure(figsize=(8, 8))
65 | # print(colors.astype(np.int))
66 | ax = plt.subplot(aspect='equal')
67 | # for i in range(10):
68 | # sc = ax.scatter(x[:, 0][30*i:30*(i+1)], x[:, 1][30*i:30*(i+1)], c=palette[colors.astype(np.int)][30*i:30*(i+1)],
69 | # s=40,
70 | # label=labels[i],
71 | # )
72 | sc = ax.scatter(x[:, 0], x[:, 1], c=palette[colors.astype(int)],
73 | s=150,
74 | #label=colors.astype(np.int)[30],
75 | )
76 | # ax.legend(loc="best", title="Classes", bbox_to_anchor=(0.2, 0.4))
77 | plt.xlim(-25, 25)
78 | plt.ylim(-25, 25)
79 | ax.axis('off')
80 | ax.axis('tight')
81 |
82 | # We add the labels for each digit.
83 | txts = []
84 | for i in range(num_class):
85 | # Position of each label.
86 | xtext, ytext = np.median(x[colors == i, :], axis=0)
87 | txt = ax.text(xtext, ytext, str(i), fontsize=24)
88 | # ax.legend(ytext, "a")
89 | txt.set_path_effects([
90 | PathEffects.Stroke(linewidth=5, foreground="w"),
91 | PathEffects.Normal()])
92 | txts.append(txt)
93 | # ax.legend(('a','b','c','d','e'))
94 | return f, ax, sc, txts
95 |
96 |
97 | def tsne_visualize(data, file_name, num_class=101):
98 | # nrows, ncols = 2, 5
99 | # plt.figure(figsize=(6,3))
100 | # plt.gray()
101 | # for i in range(ncols * nrows):
102 | # ax = plt.subplot(nrows, ncols, i + 1)
103 | # ax.matshow(digits.images[i,...])
104 | # plt.xticks([]); plt.yticks([])
105 | # plt.title(digits.target[i])
106 | # plt.savefig('../../../experiments/visualization/digits-generated.png', dpi=150)
107 |
108 | # We first reorder the data points according to the handwritten numbers.
109 | datas = []
110 | labels = []
111 | nums = len(data)
112 | print(nums)
113 | for j in range(nums):
114 | datas.append(data[j])
115 | X = np.vstack(datas)
116 | for j in range(nums):
117 | # labels.append(min(j+1, nums-1))
118 | labels.append(1)
119 | y = np.hstack(labels)
120 | # X = np.vstack([data['data'][data['target']==i].cpu()
121 | # for i in range(10)])
122 | # y = np.hstack([data['target'][data['target']==i].cpu()
123 | # for i in range(10)])
124 | # print(y)
125 | digits_proj = TSNE(random_state=RS).fit_transform(X)
126 | scatter(digits_proj, y, nums)
127 | plt.savefig(file_name, dpi=120)
128 |
129 |
130 | # features_file = "utils/visualization/vid_embeds.npy"
131 | # file_name = "utils/visualization/figures/vid_embeds.png"
132 | # features_file = "utils/visualization/text_embeds.npy"
133 | # file_name = "utils/visualization/figures/text_embeds.png"
134 | features_file = "utils/visualization/sims_embeds.npy"
135 | file_name = "utils/visualization/figures/sims_embeds.png"
136 | data = load_data(features_file)
137 | tsne_visualize(data, file_name, '0')
--------------------------------------------------------------------------------
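For reference, a minimal self-contained t-SNE sketch in the same spirit as the script above, using random features instead of the saved .npy embeddings and the public sklearn.manifold.TSNE import rather than the private _t_sne module:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

rng = np.random.RandomState(20150101)
feats = rng.randn(300, 256)              # 300 fake embeddings of dimension 256
labels = rng.randint(0, 10, size=300)    # 10 fake classes

proj = TSNE(random_state=20150101).fit_transform(feats)
plt.figure(figsize=(8, 8))
plt.scatter(proj[:, 0], proj[:, 1], c=labels, s=20, cmap="tab10")
plt.axis("off")
plt.savefig("tsne_example.png", dpi=120)
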
/OATrans/utils/visualization/msrvtt_3f_vto_visualize.py:
--------------------------------------------------------------------------------
1 | """
2 | visualize both image + object + text
3 | """
4 | import numpy as np
5 | import cv2
6 | from csv import reader
7 | import os
8 | import random
9 | import matplotlib.pyplot as plt
10 | import torch
11 | import pdb
12 | import textwrap
13 | import pandas as pd
14 | import json
15 |
16 | full_csv = "MSRVTT/annotation/MSR_VTT.json"
17 | data_source = "MSRVTT/videos/all"
18 | feat_source = "MSRVTT/region_features_full"
19 | output = "MSRVTT/region_visualization"
20 |
21 |
22 | def sample_frames(num_frames, vlen, sample='rand', fix_start=None):
23 | acc_samples = min(num_frames, vlen)
24 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
25 | ranges = []
26 | for idx, interv in enumerate(intervals[:-1]):
27 | ranges.append((interv, intervals[idx + 1] - 1))
28 | if sample == 'rand':
29 | frame_idxs = [random.choice(range(x[0], x[1])) for x in ranges]
30 | elif fix_start is not None:
31 | frame_idxs = [x[0] + fix_start for x in ranges]
32 | elif sample == 'uniform':
33 | frame_idxs = [(x[0] + x[1]) // 2 for x in ranges]
34 | else:
35 | raise NotImplementedError
36 | return frame_idxs
37 |
38 |
39 | def read_frames_cv2(video_path, num_frames, sample='uniform', fix_start=None):
40 | cap = cv2.VideoCapture(video_path)
41 | assert (cap.isOpened())
42 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
43 | # get indexes of sampled frames
44 | frame_idxs = sample_frames(num_frames, vlen, sample=sample, fix_start=fix_start)
45 | frames = []
46 | success_idxs = []
47 | for index in frame_idxs:
48 | cap.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
49 | ret, frame = cap.read()
50 | if ret:
51 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
52 | frames.append(frame)
53 | success_idxs.append(index)
54 | else:
55 | pass
56 | # print(frame_idxs, ' fail ', index, f' (vlen {vlen})')
57 | cap.release()
58 | return frames, success_idxs
59 |
60 | # step1: open video
61 |
62 | # step2: video seq
63 |
64 | # step3: video feature
65 |
66 |
67 | def tri_region_visualize(imgs, feat_paths, caption, outpath="visualization/1.png"):
68 | concat_imgs = None
69 | for i in [0, 3, 7]:
70 | frame1 = np.load(feat_paths[i], allow_pickle=True)
71 | boxes = frame1['bbox']
72 | features = frame1['x'] # 20 x 2048
73 | confident = frame1['info'].item()['objects_conf']
74 | # step 1: re-ranking the region with confidence
75 | object_ids = frame1['info'].item()['objects_id']
76 | condident_indices = np.argsort(confident)[::-1]
77 | boxes = boxes[condident_indices]
78 | features = features[condident_indices]
79 | object_ids = object_ids[condident_indices]
80 | confident = confident[condident_indices]
81 | new_object, unique_indices = np.unique(object_ids, return_index=True)
82 | # step 2: remove region with same object class
83 | boxes = boxes[unique_indices]
84 | features = features[unique_indices]
85 | object_ids = object_ids[unique_indices]
86 | # object_ids = object_ids[unique_indices]
87 | # confident = confident[unique_indices]
88 | # # print(boxes, features)
89 | # image_width = frame1['info'].item()['image_w']
90 | # image_height = frame1['info'].item()['image_h']
91 | # box_width = boxes[:, 2] - boxes[:, 0]
92 | # box_height = boxes[:, 3] - boxes[:, 1]
93 | # scaled_width = box_width / image_width
94 | # scaled_height = box_height / image_height
95 | # scaled_x = boxes[:, 0] / image_width
96 | # scaled_y = boxes[:, 1] / image_height
97 | # scaled_width = scaled_width[..., np.newaxis]
98 | # scaled_height = scaled_height[..., np.newaxis]
99 | # scaled_x = scaled_x[..., np.newaxis]
100 | # scaled_y = scaled_y[..., np.newaxis]
101 | # spatial_features = np.concatenate(
102 | # (scaled_x, scaled_y, scaled_x + scaled_width, scaled_y + scaled_height, scaled_width, scaled_height), axis=1)
103 | # feat = torch.cat([torch.from_numpy(features), torch.from_numpy(spatial_features)], dim=1)
104 | classes = ['__background__']
105 | with open('utils/objects_vocab.txt', 'r') as f:
106 | for object in f.readlines():
107 | classes.append(object.split(',')[0].lower().strip())
108 | # print(features.shape)
109 | # plot top 5 objects
110 | img = imgs[i]
111 | if len(boxes) < 5:
112 | return False
113 | # print(img.shape)
114 | # print(img)
115 | colormap = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (155, 100, 100), (100, 155, 100)]
116 | for j in range(5):
117 | # print(boxes[j])
118 | cv2.putText(img, '%s: %s' % (classes[object_ids[j] + 1], confident[j]), (int(boxes[j][0]), int(boxes[j][1] + 15)),
119 | cv2.FONT_HERSHEY_TRIPLEX,
120 | 0.5,
121 | colormap[j],
122 | 1)
123 | cv2.rectangle(img, (int(boxes[j][0]), int(boxes[j][1])), (int(boxes[j][2]), int(boxes[j][3])),
124 | colormap[j],
125 | 1)
126 | if concat_imgs is None:
127 | concat_imgs = img
128 | else:
129 | concat_imgs = np.concatenate((concat_imgs, img), axis=1)
130 | caption_img = np.ones((50, imgs[0].shape[1] * 3, 3)) * 255
131 | cv2.putText(caption_img, caption, (10, 10), cv2.FONT_HERSHEY_TRIPLEX,
132 | 0.5,
133 | colormap[j],
134 | 1)
135 | concat_imgs = np.concatenate((concat_imgs, caption_img), axis=0)
136 | cv2.imwrite(outpath, concat_imgs)
137 | return outpath
138 |
139 |
140 | if __name__ == '__main__':
141 |
142 | f = open(full_csv)
143 | data = json.load(f)
144 | count = 0
145 |
146 | for row in data['annotations']:
147 | count += 1
148 | if count % 10 != 0:
149 | continue
150 | # if row['id'] % 200 != 1:
151 | # continue
152 | if count > 2000:
153 | break
154 | video_path = os.path.join(data_source, row['image_id'] + '.mp4')
155 | imgs, success_idxs = read_frames_cv2(video_path, 8, 'uniform')
156 | feat_paths = [feat_source + '/' + row['image_id'] + '/' + str(success_idxs[k]) + '.npz' for k in range(8)]
157 | outpath = 'visualization/msrvtt_3f/{}_{}.jpg'.format(count, row['image_id'])
158 | tri_region_visualize(imgs, feat_paths, row['caption'], outpath)
159 |
160 |
--------------------------------------------------------------------------------
/OATrans/utils/visualization/msrvtt_vto_visualization.py:
--------------------------------------------------------------------------------
1 | """
2 | visualize both image + object + text
3 | """
4 | import numpy as np
5 | import cv2
6 | from csv import reader
7 | import os
8 | import random
9 | import matplotlib.pyplot as plt
10 | import torch
11 | import pdb
12 | import textwrap
13 | import json
14 |
15 | full_json = "MSRVTT/annotation/MSR_VTT.json"
16 | data_source = "MSRVTT/videos/all"
17 | feat_source = "MSRVTT/region_features"
18 | output = "MSRVTT/region_visualization"
19 |
20 |
21 | def feature_visualize(img1, feat_path):
22 | frame1 = np.load(feat_path, allow_pickle=True)
23 | boxes = frame1['bbox']
24 | features = frame1['x'] # 20 x 2048
25 | confident = frame1['info'].item()['objects_conf']
26 | # step 1: re-ranking the region with confidence
27 | object_ids = frame1['info'].item()['objects_id']
28 | condident_indices = np.argsort(confident)[::-1]
29 | boxes = boxes[condident_indices]
30 | features = features[condident_indices]
31 | object_ids = object_ids[condident_indices]
32 | confident = confident[condident_indices]
33 |
34 | new_object, unique_indices = np.unique(object_ids, return_index=True)
35 | # step 2: remove region with same object class
36 |
37 | boxes = boxes[unique_indices]
38 | features = features[unique_indices]
39 | object_ids = object_ids[unique_indices]
40 | confident = confident[unique_indices]
41 |
42 | # print(boxes, features)
43 | image_width = frame1['info'].item()['image_w']
44 | image_height = frame1['info'].item()['image_h']
45 |
46 | box_width = boxes[:, 2] - boxes[:, 0]
47 | box_height = boxes[:, 3] - boxes[:, 1]
48 | scaled_width = box_width / image_width
49 | scaled_height = box_height / image_height
50 | scaled_x = boxes[:, 0] / image_width
51 | scaled_y = boxes[:, 1] / image_height
52 | scaled_width = scaled_width[..., np.newaxis]
53 | scaled_height = scaled_height[..., np.newaxis]
54 | scaled_x = scaled_x[..., np.newaxis]
55 | scaled_y = scaled_y[..., np.newaxis]
56 | spatial_features = np.concatenate(
57 | (scaled_x, scaled_y, scaled_x + scaled_width, scaled_y + scaled_height, scaled_width, scaled_height), axis=1)
58 | # print(spatial_features)
59 | feat = torch.cat([torch.from_numpy(features), torch.from_numpy(spatial_features)], dim=1)
60 | classes = ['__background__']
61 | with open('../objects_vocab.txt', 'r') as f:
62 | for object in f.readlines():
63 | classes.append(object.split(',')[0].lower().strip())
64 | # print(features.shape)
65 | im = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
66 | plt.axis('off')
67 | plt.imshow(im)
68 | new_boxes = boxes
69 | for i in range(len(new_boxes)):
70 | bbox = new_boxes[i]
71 | if i < 10:
72 | plt.gca().add_patch(
73 | plt.Rectangle((bbox[0], bbox[1]),
74 | bbox[2] - bbox[0],
75 | bbox[3] - bbox[1], fill=False,
76 | edgecolor='red', linewidth=2, alpha=0.5)
77 | )
78 | plt.gca().text(bbox[0], bbox[1] - 2,
79 | '%s: %s' % (classes[object_ids[i] + 1], confident[i]),
80 | bbox=dict(facecolor='blue', alpha=0.5),
81 | fontsize=10, color='white')
82 | outpath = "test_roi.png"
83 | plt.savefig(outpath, dpi=150)
84 | plt.close()
85 | return outpath
86 |
87 | f = open(full_json)
88 | data = json.load(f)
89 | # for key in data:
90 | # print(key)
91 | # pdb.set_trace()
92 | # Iterate over each row in the csv using reader object
93 | count = 0
94 |
95 | for row in data['annotations']:
96 | count += 1
97 | if row['id'] % 200 != 1:
98 | continue
99 | video_path = os.path.join(data_source, row['image_id'] + '.mp4')
100 | cap = cv2.VideoCapture(video_path)
101 | ret, img = cap.read()
102 | feat_path = os.path.join(feat_source, row['image_id'], '1.npz')
103 | feat_img = cv2.imread(feature_visualize(img, feat_path))
104 | print(feat_img.shape)
105 | caption_img = np.ones([feat_img.shape[0]//4, feat_img.shape[1], 3]) * 255
106 |
107 | wrapped_text = textwrap.wrap(row['caption'], width=35)
108 | x, y = 10, 40
109 | font_size = 1
110 | font_thickness = 2
111 | font = cv2.FONT_HERSHEY_TRIPLEX
112 |
113 | for i, line in enumerate(wrapped_text):
114 | textsize = cv2.getTextSize(line, font, font_size, font_thickness)[0]
115 |
116 | gap = textsize[1] + 10
117 |
118 | y = 25 + i * gap
119 | x = int((caption_img.shape[1] - textsize[0]) / 2) + 20
120 |
121 | cv2.putText(caption_img, line, (x, y), font,
122 | font_size,
123 | (122, 21, 91),
124 | font_thickness,
125 | lineType=cv2.LINE_AA)
126 |
127 | # cv2.putText(caption_img, row[1], (30, 30), cv2.FONT_HERSHEY_TRIPLEX, 1, (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
128 | # 1)
129 | concat = np.concatenate([feat_img[50:-50, :, :], caption_img], axis=0)
130 | # cv2.imshow(concat)
131 | out_file = os.path.join(output, row['image_id'] + '_' + str(row['id']) + '.png')
132 | print("hello world")
133 | cv2.imwrite(out_file, concat)
134 | # if cv2.waitKey(33) == 27:
135 | # continue
136 |
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/0_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/0_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/10_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/10_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/11_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/11_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/12_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/12_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/13_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/13_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/14_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/14_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/1_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/1_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/2_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/2_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/3_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/3_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/4_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/4_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/5_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/5_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/6_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/6_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/7_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/7_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/8_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/8_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/predict_visualization/9_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/predict_visualization/9_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/print_tags.py:
--------------------------------------------------------------------------------
1 | def predict2caption(predict, vocab='utils/objects_vocab.txt'):
2 | caption = ""
3 | classes = ['__background__']
4 | with open(vocab, 'r') as f:
5 | for object in f.readlines():
6 | classes.append(object.split(',')[0].lower().strip())
7 | for n in range(len(predict)):
8 | caption += ' ' + (classes[predict[n]+1])
9 | return caption
--------------------------------------------------------------------------------
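A hedged usage sketch for predict2caption above; the indices are arbitrary examples and the vocabulary path assumes the script is run from the OATrans directory.

from OATrans.utils.visualization.print_tags import predict2caption  # assumes the repo root is on PYTHONPATH

predicted_ids = [0, 5, 17]  # arbitrary example class indices into objects_vocab.txt
print(predict2caption(predicted_ids, vocab="utils/objects_vocab.txt"))
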
/OATrans/utils/visualization/transfer_predict_visualization/0_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/0_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/10_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/10_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/11_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/11_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/12_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/12_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/13_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/13_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/14_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/14_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/15_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/15_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/1_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/1_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/2_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/2_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/3_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/3_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/4_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/4_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/5_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/5_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/6_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/6_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/7_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/7_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/8_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/8_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/transfer_predict_visualization/9_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/OATrans/utils/visualization/transfer_predict_visualization/9_predict.png
--------------------------------------------------------------------------------
/OATrans/utils/visualization/webvid_vto_visualization.py:
--------------------------------------------------------------------------------
1 | """
2 | visualize both image + object + text
3 | """
4 | import numpy as np
5 | import cv2
6 | from csv import reader
7 | import os
8 | import random
9 | import matplotlib.pyplot as plt
10 | import torch
11 | import pdb
12 | import textwrap
13 |
14 | full_csv = "WebVid2M_videos/metadata/results_subset_train.csv"
15 | data_source = "WebVid2M_videos/train_videos"
16 | feat_source = "WebVid2M_frames_region_features/train"
17 | output = "WebVid2M_visualization/train"
18 |
19 |
20 | def feature_visualize(img1, feat_path):
21 | frame1 = np.load(feat_path, allow_pickle=True)
22 | boxes = frame1['bbox']
23 | features = frame1['x'] # 20 x 2048
24 | confident = frame1['info'].item()['objects_conf']
25 | # step 1: re-ranking the region with confidence
26 | object_ids = frame1['info'].item()['objects_id']
27 | condident_indices = np.argsort(confident)[::-1]
28 | boxes = boxes[condident_indices]
29 | features = features[condident_indices]
30 | object_ids = object_ids[condident_indices]
31 | confident = confident[condident_indices]
32 |
33 | new_object, unique_indices = np.unique(object_ids, return_index=True)
34 | # step 2: remove region with same object class
35 |
36 | boxes = boxes[unique_indices]
37 | features = features[unique_indices]
38 | object_ids = object_ids[unique_indices]
39 | confident = confident[unique_indices]
40 |
41 | # print(boxes, features)
42 | image_width = frame1['info'].item()['image_w']
43 | image_height = frame1['info'].item()['image_h']
44 |
45 | box_width = boxes[:, 2] - boxes[:, 0]
46 | box_height = boxes[:, 3] - boxes[:, 1]
47 | scaled_width = box_width / image_width
48 | scaled_height = box_height / image_height
49 | scaled_x = boxes[:, 0] / image_width
50 | scaled_y = boxes[:, 1] / image_height
51 | scaled_width = scaled_width[..., np.newaxis]
52 | scaled_height = scaled_height[..., np.newaxis]
53 | scaled_x = scaled_x[..., np.newaxis]
54 | scaled_y = scaled_y[..., np.newaxis]
55 | spatial_features = np.concatenate(
56 | (scaled_x, scaled_y, scaled_x + scaled_width, scaled_y + scaled_height, scaled_width, scaled_height), axis=1)
57 | # print(spatial_features)
58 | feat = torch.cat([torch.from_numpy(features), torch.from_numpy(spatial_features)], dim=1)
59 | classes = ['__background__']
60 | with open('../objects_vocab.txt', 'r') as f:
61 | for object in f.readlines():
62 | classes.append(object.split(',')[0].lower().strip())
63 | # print(features.shape)
64 | im = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
65 | plt.axis('off')
66 | plt.imshow(im)
67 | new_boxes = boxes
68 | for i in range(len(new_boxes)):
69 | bbox = new_boxes[i]
70 | if i < 10:
71 | plt.gca().add_patch(
72 | plt.Rectangle((bbox[0], bbox[1]),
73 | bbox[2] - bbox[0],
74 | bbox[3] - bbox[1], fill=False,
75 | edgecolor='red', linewidth=2, alpha=0.5)
76 | )
77 | plt.gca().text(bbox[0], bbox[1] - 2,
78 | '%s: %s' % (classes[object_ids[i] + 1], confident[i]),
79 | bbox=dict(facecolor='blue', alpha=0.5),
80 | fontsize=10, color='white')
81 | outpath = "test_roi.png"
82 | plt.savefig(outpath, dpi=150)
83 | plt.close()
84 | return outpath
85 |
86 |
87 | with open(full_csv, 'r') as read_obj:
88 | # pass the file object to reader() to get the reader object
89 | csv_reader = reader(read_obj)
90 | # Iterate over each row in the csv using reader object
91 | count = 0
92 | for row in csv_reader:
93 | count += 1
94 | if count == 1:
95 | continue
96 | # if count > 3:
97 | # break
98 | # cv2.destroyAllWindows()
99 | if len(row[3]) < 3:
100 | continue
101 | video_path = os.path.join(data_source, row[3], row[0] + '.mp4')
102 | cap = cv2.VideoCapture(video_path)
103 | ret, img = cap.read()
104 | feat_path = os.path.join(feat_source, row[3], row[0], '1.npz')
105 | feat_img = cv2.imread(feature_visualize(img, feat_path))
106 | print(feat_img.shape)
107 | caption_img = np.ones([feat_img.shape[0]//4, feat_img.shape[1], 3]) * 255
108 |
109 | wrapped_text = textwrap.wrap(row[1], width=35)
110 | x, y = 10, 40
111 | font_size = 1
112 | font_thickness = 2
113 | font = cv2.FONT_HERSHEY_TRIPLEX
114 |
115 | for i, line in enumerate(wrapped_text):
116 | textsize = cv2.getTextSize(line, font, font_size, font_thickness)[0]
117 |
118 | gap = textsize[1] + 10
119 |
120 | y = 25 + i * gap
121 | x = int((caption_img.shape[1] - textsize[0]) / 2) + 20
122 |
123 | cv2.putText(caption_img, line, (x, y), font,
124 | font_size,
125 | (122, 21, 91),
126 | font_thickness,
127 | lineType=cv2.LINE_AA)
128 |
129 | # cv2.putText(caption_img, row[1], (30, 30), cv2.FONT_HERSHEY_TRIPLEX, 1, (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
130 | # 1)
131 | concat = np.concatenate([feat_img[50:-50, :, :], caption_img], axis=0)
132 | # cv2.imshow(concat)
133 | out_file = os.path.join(output, row[3] + row[0] + '.png')
134 | print("hello world")
135 | cv2.imwrite(out_file, concat)
136 | # if cv2.waitKey(33) == 27:
137 | # continue
138 |
--------------------------------------------------------------------------------
/OATrans/utils/visualizer.py:
--------------------------------------------------------------------------------
1 | """A simple HTML visualizer.
2 |
3 | It is based on the Cycle-GAN codebase:
4 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
5 | """
6 | import os
7 | import numpy as np
8 | from pathlib import Path
9 | from . import util, html
10 | import pdb
11 |
12 | class RetrievalVis:
13 | """This class includes several functions that can display/save images.
14 |
15 | It uses a Python library 'visdom' for display, and a Python library 'dominate'
16 | (wrapped in 'HTML') for creating HTML files with images.
17 | """
18 |
19 | def __init__(self, exp_name, web_dir, src_video_dir, vis_vid_freq, num_samples=50):
20 | """Initialize the Visualizer class
21 | Create an HTML object for saving HTML files
22 | """
23 | self.name = exp_name
24 | self.web_dir = web_dir
25 | self.vis_vid_freq = vis_vid_freq
26 | self.img_dir = os.path.join(self.web_dir, "images")
27 | self.num_samples = num_samples
28 |
29 | self.data_type = 'images' # 'images' or 'videos'
30 | assert self.data_type in ('images', 'videos')
31 |
32 | print(f"create web directory {self.web_dir}...")
33 | mkdirs([self.web_dir, self.img_dir])
34 |
35 | # cluster specific
36 | if "$TMPDIR" in src_video_dir:
37 | src_video_dir = src_video_dir.replace("$TMPDIR", os.environ['TMPDIR'])
38 |
39 | src_dir = Path(src_video_dir).absolute()
40 | print(f"symlinking videos from {src_dir}...")
41 | sym_dir = (Path(self.web_dir) / "videos").absolute()
42 | if sym_dir.is_symlink():
43 | os.remove(sym_dir)
44 | sym_dir.symlink_to(src_dir)
45 |
46 | def visualize_ranking(self, sims, epoch, meta, nested_metrics):
47 | if not (self.vis_vid_freq and epoch % self.vis_vid_freq == 0):
48 | return
49 |
50 | dists = -sims
51 | np.random.seed(0)
52 | sorted_ranks = np.argsort(dists, axis=1)
53 | gt_dists = np.diag(dists)
54 | rankings = []
55 | vis_top_k = 5
56 | hide_gt = False
57 | # num_indep_samples = 1
58 | # random_seeds = np.arange(num_indep_samples)
59 | sample = np.random.choice(np.arange(dists.shape[0]), size=self.num_samples,
60 | replace=False)
61 | for ii in sample:
62 | ranked_idx = sorted_ranks[ii][:vis_top_k]
63 | gt_captions = meta["raw_captions"][ii]
64 | # if args.sample_single_gt_caption:
65 | # gt_captions = np.random.choice(gt_captions, 1).tolist()
66 | datum = {
67 | "gt-sim": -gt_dists[ii],
68 | "gt-captions": gt_captions,
69 | "gt-rank": np.where(sorted_ranks[ii] == ii)[0][0],
70 | "gt-path": meta["paths"][ii],
71 | "top-k-sims": -dists[ii][ranked_idx],
72 | "top-k-paths": np.array(meta["paths"])[ranked_idx],
73 | "hide-gt": hide_gt,
74 | }
75 | rankings.append(datum)
76 | self.display_current_results(
77 | rankings,
78 | epoch=epoch,
79 | metrics=nested_metrics["t2v_metrics"],
80 | )
81 |
82 | def display_current_results(self, rankings, epoch, metrics):
83 | """Display current results on visdom; save current results to an HTML file.
84 |
85 | Parameters:
86 | visuals (OrderedDict) - - dictionary of images to display or save
87 | epoch (int) - - the current epoch
88 | save_result (bool) - - if save the current results to an HTML file
89 | """
90 | if not Path(self.web_dir).exists():
91 | Path(self.web_dir).mkdir(exist_ok=True, parents=True)
92 | print(f"updating webpage at {self.web_dir}")
93 | title = f"Experiment name = {self.name}"
94 | refresh = True
95 | if not refresh:
96 | print("DISABLING WEB PAGE REFRESH")
97 | webpage = html.HTML(web_dir=self.web_dir, title=title, refresh=refresh)
98 |
99 | msg = f"epoch [{epoch}] - {self.name}"
100 | webpage.add_header(msg)
101 | msg = (f"R1: {metrics['R1']:.1f}, "
102 | f"R5: {metrics['R5']:.1f}, "
103 | f"R10: {metrics['R10']:.1f}, "
104 | f"MedR: {metrics['MedR']}")
105 | webpage.add_header(msg)
106 | print(f"Top {len(rankings[0])} retrieved videos at epoch: {epoch}")
107 |
108 | for ranking in rankings:
109 | vids, txts, links = [], [], []
110 | gt_vid_path = os.path.join('videos', ranking["gt-path"])
111 | #gt_captions = [" ".join(x) for x in ranking["gt-captions"]]
112 | gt_captions = ranking['gt-captions']
113 | gt_captions = "<br>" + (gt_captions) + "<br>"
114 | if ranking["hide-gt"]:
115 | txts.append(gt_captions)
116 | links.append("hidden")
117 | vids.append("hidden")
118 | else:
119 | txt = (f"{gt_captions}<br>Rank: {ranking['gt-rank']}, "
120 | f"Sim: {ranking['gt-sim']:.3f} [{Path(ranking['gt-path']).stem}]")
121 | txts.append(txt)
122 | links.append(gt_vid_path)
123 | vids.append(gt_vid_path)
124 |
125 | for idx, (vid_path, sim) in enumerate(zip(ranking["top-k-paths"],
126 | ranking["top-k-sims"])):
127 | vid_path = Path(os.path.join('videos', vid_path))
128 | if ranking["hide-gt"]:
129 | txt = f"choice: {idx}"
130 | else:
131 | txt = f"Rank: {idx}, Sim: {sim:.3f}, [{Path(vid_path).stem}]"
132 | txts.append(txt)
133 | vids.append(vid_path)
134 | links.append(vid_path)
135 | if self.data_type == 'videos':
136 | webpage.add_videos(vids, txts, links, width=200)
137 | elif self.data_type == 'images':
138 | webpage.add_images(vids, txts, links, width=200)
139 | print(f"added {len(vids)} videos")
140 | webpage.save()
141 |
142 | def mkdirs(paths):
143 | """create empty directories if they don't exist
144 |
145 | Parameters:
146 | paths (str list) -- a list of directory paths
147 | """
148 | if isinstance(paths, list) and not isinstance(paths, str):
149 | for path in paths:
150 | mkdir(path)
151 | else:
152 | mkdir(paths)
153 |
154 |
155 | def mkdir(path):
156 | """create a single empty directory if it didn't exist
157 |
158 | Parameters:
159 | path (str) -- a single directory path
160 | """
161 | if not os.path.exists(path):
162 | os.makedirs(path)
163 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR 22] "Object-aware Video-language Pre-training for Retrieval" [arxiv](https://arxiv.org/abs/2112.00656)
2 |
3 |
4 | 
5 |
6 |
7 | ## 1. Object Feature Extractor
8 |
9 | We provide a faster pipeline to extract object features from WebVid 2.5M and CC 3M.
10 | In total we extract objects from 5.5M * 8 = 44M frames, which takes 28 days on 16 V100 GPUs.
11 |
12 | Refer to [Object Extractor.md](object_extraction.md) for more details.
13 |
14 |
15 | ## 2. OA Trans
16 |
17 | Refer to [train.md](train.md) for more details.
18 |
19 | ## 3. Visualizations
20 |
21 | In this code, we provide two ways to visualize cross-modality attention.
22 |
23 | ### Heatmap Visualization
24 | 
25 |
26 |
27 | ### Binary Map Visualization
28 | 
29 |
30 | Please refer to [visualization.md](visualization.md) for details.
31 |
32 |
33 | ## 4. News
34 | - 2021.12.5 Arxiv Version Published.
35 | - 2022.3.15 First version Code Released.
36 |
37 | ## 5. Citation
38 |
39 | If you find our work helpful, please cite our paper:
40 | ```bibtex
41 | @inproceedings{wang2022oatrans,
42 | title={Object-aware Video-language Pre-training for Retrieval},
43 | author={Wang, Alex Jinpeng and Ge, Yixiao and Cai, Guanyu and Yan, Rui and Lin, Xudong and Shan, Ying and Qie, Xiaohu and Shou, Mike Zheng},
44 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
45 | year={2022}
46 | }
47 | ```
48 |
49 | ## Acknowledgement
50 |
51 | This work is mainly based on [Frozen](https://github.com/m-bain/frozen-in-time).
--------------------------------------------------------------------------------
/Visualization/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/Visualization/.DS_Store
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/Visualization/Cross_Modality_Transformer_Visualization/.DS_Store
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/data_preprocess.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import random
4 | from torchvision import transforms
5 | import torch
6 | from PIL import Image, ImageFile
7 |
8 |
9 | def sample_frames(num_frames, vlen, sample='rand', fix_start=None):
10 | acc_samples = min(num_frames, vlen)
11 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
12 | ranges = []
13 | for idx, interv in enumerate(intervals[:-1]):
14 | ranges.append((interv, intervals[idx + 1] - 1))
15 | if sample == 'rand':
16 | frame_idxs = [random.choice(range(x[0], x[1])) for x in ranges]
17 | elif fix_start is not None:
18 | frame_idxs = [x[0] + fix_start for x in ranges]
19 | elif sample == 'uniform':
20 | frame_idxs = [(x[0] + x[1]) // 2 for x in ranges]
21 | else:
22 | raise NotImplementedError
23 | return frame_idxs
24 |
25 |
26 | def read_frames_cv2(video_path, num_frames, sample='rand', fix_start=None, numpy=False):
27 | cap = cv2.VideoCapture(video_path)
28 | assert (cap.isOpened())
29 | vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
30 | # get indexes of sampled frames
31 | frame_idxs = sample_frames(num_frames, vlen, sample=sample, fix_start=fix_start)
32 | frames = []
33 | success_idxs = []
34 | for index in frame_idxs:
35 | cap.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
36 | ret, frame = cap.read()
37 | # print(frame.shape)
38 | if ret:
39 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
40 | frame = torch.from_numpy(frame)
41 | # (H x W x C) to (C x H x W)
42 | frame = frame.permute(2, 0, 1)
43 | frames.append(frame)
44 | success_idxs.append(index)
45 | else:
46 | pass
47 | # print(frame_idxs, ' fail ', index, f' (vlen {vlen})')
48 | if not numpy:
49 | frames = torch.stack(frames).float() / 255
50 | cap.release()
51 | return frames, success_idxs, vlen
52 |
53 |
54 | def vision_preprocess(vid_src):
55 | video, _, _ = read_frames_cv2(vid_src, 1)
56 | transform = transforms.Compose(
57 | [
58 | transforms.Resize(size=(224, 224)),
59 | # transforms.RandomResizedCrop(size=(224, 224)),
60 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
61 | ]
62 | )
63 | video = transform(video)
64 | # expand one dim as batch
65 | video = video.unsqueeze(0)
66 | return video.cuda()
67 |
68 |
69 | def mask_vision_preprocess(vid_src):
70 | video, _, _ = read_frames_cv2(vid_src, 8)
71 | transform = transforms.Compose(
72 | [
73 | transforms.Resize(size=(224, 224)),
74 | # transforms.RandomResizedCrop(size=(224, 224)),
75 | # transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
76 | ]
77 | )
78 | video = transform(video[:3])
79 | return video
80 |
81 |
82 | def vision_img_preprocess(img_src):
83 | img = Image.open(img_src).convert("RGB")
84 | img = transforms.ToTensor()(img).unsqueeze(0)
85 | transform = transforms.Compose(
86 | [
87 | transforms.Resize(size=(224, 224)),
88 | # transforms.RandomResizedCrop(size=(224, 224)),
89 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
90 | ]
91 | )
92 | img = transform(img)
93 | # expand one dim as batch
94 | img = img.unsqueeze(0)
95 | return img.cuda()
96 |
97 |
98 | def clip_img_preprocess(img_src, preprocess):
99 | img = Image.open(img_src).convert("RGB")
100 | img = preprocess(img).unsqueeze(0)
101 | # print(img.size())
102 | return img.cuda().half()
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/main_img.py:
--------------------------------------------------------------------------------
1 | from model.text_model import text_encode_init
2 | from model.vision_model import vision_encode_init
3 | from data_preprocess import vision_img_preprocess, clip_img_preprocess
4 | import pandas as pd
5 | from visualize import cross_attention_visualize
6 | import parse_config as parse_config
7 | import model.vision_models.clip as clip
8 |
9 |
10 | csv_file = "data/meta_data/cc3m_training_success_full.tsv"
11 | # out_dir = 'output/featmap/'
12 | model_se = 'frozen' # 'frozen' or 'clip'
13 | # out_dir = 'output/featmap/{}/'.format(model_se)
14 | out_dir = 'output/cross_featmap/cc3m/{}_attn/'.format(model_se)
15 | video_root = 'CC3M/training/'
16 | metadata = pd.read_csv(csv_file, sep='\t')
17 | text_model = text_encode_init(model_name=model_se)
18 | img_model, preprocess = vision_encode_init(model_name=model_se)
19 |
20 | count = 0
21 | for item in range(len(metadata)):
22 | sample = metadata.iloc[item]
23 | video_src = video_root + sample[1]
24 | caption = sample[0]
25 | if model_se == 'clip':
26 | img = clip_img_preprocess(video_src, preprocess)
27 | else:
28 | img = vision_img_preprocess(video_src)
29 | print(img.size())
30 | img_patch_embedding = img_model(img)
31 |     if model_se == 'clip':
32 |         # CLIP path: add a batch dim and use the CLIP tokenizer
33 |         img = img.unsqueeze(0)
34 |         text_token = text_model(clip.tokenize(caption).cuda())
35 |     else:
36 |         text_token = text_model(caption)
37 | # print(img_patch_embedding.size())
38 | if model_se == 'clip':
39 | img = img.float()
40 | cross_attention_visualize(img_patch_embedding, img[0], caption, text_token, text_model, model_name=model_se,
41 | name=out_dir + str(item), v=1)
42 | count += 1
43 | if count > 500:
44 | break
45 |
46 |
47 |
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/main_video.py:
--------------------------------------------------------------------------------
1 | from model.text_model import text_encode_init
2 | from model.vision_model import vision_encode_init
3 | from data_preprocess import vision_preprocess
4 | import pandas as pd
5 | from visualize import cross_attention_visualize
6 | import parse_config as parse_config
7 |
8 |
9 | csv_file = "data/webvid_validation_success_full.tsv"
10 | # out_dir = 'output/featmap/'
11 | out_dir = 'output/cross_featmap/'
12 | video_root = 'WebVid/val/'
13 | metadata = pd.read_csv(csv_file, sep='\t')
14 | text_model = text_encode_init()
15 | video_model = vision_encode_init()
16 |
17 | count = 0
18 | for item in range(len(metadata)):
19 | sample = metadata.iloc[item]
20 | video_src = video_root + sample[1] + '.mp4'
21 | caption = sample[0]
22 | # print(video_src)
23 | video = vision_preprocess(video_src)
24 | # print(video.size())
25 | video_patch_embedding = video_model(video)
26 | print(caption)
27 | text_token = text_model(caption)
28 | # print(video_patch_embedding.size())
29 | # print(text_token.size())
30 | sim = cross_attention_visualize(video_patch_embedding, video[0], caption, text_token, text_model, name=out_dir + str(item))
31 |
32 | count += 1
33 | if count > 100:
34 | break
35 |
36 |
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/main_video_patches_visualization.py:
--------------------------------------------------------------------------------
1 | from data_preprocess import mask_vision_preprocess
2 | import pandas as pd
3 | import os
4 | from patch_mask import visualize_mask
5 | from utils.read_bboxs import read_bbox_from_pickle
6 |
7 |
8 | csv_file = "data/webvid_validation_success_full.tsv"
9 | # out_dir = 'output/featmap/'
10 | out_dir = 'output/mask_object_visualization/'
11 | video_root = 'WebVid/val/'
12 | metadata = pd.read_csv(csv_file, sep='\t')
13 | features_root = 'WebVid/8_frame_object'
14 |
15 |
16 |
17 | count = 0
18 | for item in range(len(metadata)):
19 | sample = metadata.iloc[item]
20 | video_src = video_root + sample[1] + '.mp4'
21 | video = mask_vision_preprocess(video_src)
22 | object_bboxs = []
23 | for i in range(3):
24 | rel_object_fp = os.path.join(sample[1], '{}.npz'.format(i))
25 | full_object_fp = os.path.join(features_root, 'val', rel_object_fp)
26 | object_bboxs.append(read_bbox_from_pickle(full_object_fp))
27 | # print(video_src)
28 | out_name = out_dir + str(item)
29 | visualize_mask(video, object_bboxs, out_name)
30 | count += 1
31 | if count > 100:
32 | break
33 |
34 |
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/Visualization/Cross_Modality_Transformer_Visualization/model/__init__.py
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/model/text_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import AutoModel
3 | import transformers
4 | import torch
5 | import model.vision_models.clip as clip
6 |
7 | class TextEncoder(nn.Module):
8 | def __init__(self):
9 | super().__init__()
10 | self.text_model = AutoModel.from_pretrained('pretrained/distilbert-base-uncased')
11 | self.tokenizer = transformers.AutoTokenizer.from_pretrained('pretrained/distilbert-base-uncased',
12 | TOKENIZERS_PARALLELISM=False)
13 | self.device = "cuda:0"
14 | self.txt_proj = nn.Sequential(nn.ReLU(),
15 | nn.Linear(768, 256),
16 | )
17 | def token_of_word(self, word):
18 | token = self.tokenizer(word, return_tensors='pt', padding=True,
19 | truncation=True)
20 | return token
21 |
22 | def forward(self, x):
23 | if self.tokenizer is not None:
24 | x = self.tokenizer(x, return_tensors='pt', padding=True,
25 | truncation=True)
26 | x = {key: val.to(self.device) for key, val in x.items()}
27 | text_embeddings_all = self.text_model(**x).last_hidden_state
28 | # print(text_embeddings_all.size()) # batch_size, sequence_length, hidden_size
29 | # text_embeddings = text_embeddings_all[:, 0, :]
30 | text_embeddings = text_embeddings_all
31 | # print(text_embeddings.size())
32 | return self.txt_proj(text_embeddings)
33 | # return text_embeddings
34 |
35 | def weight_transform(model_dict, pretrain_dict):
36 | '''
37 | :return:
38 | '''
39 | weight_dict = {k[7:]:v for k, v in pretrain_dict.items() if k[7:] in model_dict and k[:7] == 'module.'}
40 | # for k, v in pretrain_dict.items():
41 | # print(k[7:])
42 | # # pdb.set_trace()
43 | for k, v in pretrain_dict.items():
44 |         if k.startswith('module.txt_proj'):  # fix: 'module.txt_proj' is 15 chars, so k[:14] could never match it
45 | weight_dict[k[7:]] = v
46 | for k, v in weight_dict.items():
47 | print("load: {}".format(k))
48 | # print(weight_dict)
49 | model_dict.update(weight_dict)
50 | return model_dict
51 |
52 | def load_pt_weight(model):
53 | checkpoint = torch.load("pretrained/cc-webvid2m-4f_stformer_b_16_224.pth.tar", map_location="cpu")
54 | pretrained_state = checkpoint['state_dict']
55 | model_state = model.state_dict()
56 | # for k , v in model_state.items():
57 | # print(k)
58 | model.load_state_dict(weight_transform(model_state, pretrained_state))
59 | return model
60 |
61 |
62 | def text_encode_init(model_name='frozen'):
63 | if model_name == 'clip':
64 | full_model, preprocess = clip.load("pretrained/ViT-B-16.pt")
65 | model = full_model.encode_text
66 | else:
67 | model = TextEncoder()
68 | load_pt_weight(model)
69 | model = model.cuda()
70 | return model
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/model/text_models/distill_bert.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import AutoModel
3 | import transformers
4 |
5 |
6 | class DistillBert(nn.Module):
7 |     def __init__(self, text_params):
8 |         super().__init__()
9 |         # text_params['model'] is e.g. 'pretrained/distilbert-base-uncased'
10 |         self.text_model = AutoModel.from_pretrained(text_params['model'])
11 |         self.tokenizer = transformers.AutoTokenizer.from_pretrained(text_params['model'],
12 |                                                                      TOKENIZERS_PARALLELISM=False)
13 |         self.device = "cuda:0"
14 |
15 |     def forward(self, x):
16 |         if self.tokenizer is not None:
17 |             x = self.tokenizer(x, return_tensors='pt', padding=True,
18 |                                truncation=True)
19 |             x = {key: val.to(self.device) for key, val in x.items()}
20 |         text_embeddings_all = self.text_model(**x).last_hidden_state
21 |         # (batch_size, sequence_length, hidden_size); keep the [CLS] token embedding
22 |         text_embeddings = text_embeddings_all[:, 0, :]
23 |         return text_embeddings
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/model/vision_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from model.vision_models.frozen import SpaceTimeTransformer
3 | import model.vision_models.clip as clip
4 |
5 |
6 | def weight_transform(model_dict, pretrain_dict):
7 | '''
8 | :return:
9 | '''
10 | weight_dict = {k[19:]:v for k, v in pretrain_dict.items() if k[19:] in model_dict and k[:19] == 'module.video_model.'}
11 | for k, v in pretrain_dict.items():
12 | print(k[19:])
13 | # pdb.set_trace()
14 | for k, v in pretrain_dict.items():
15 | if k[:15] == 'module.vid_proj':
16 | weight_dict[k[7:]] = v
17 | for k, v in weight_dict.items():
18 | print("load: {}".format(k))
19 | # print(weight_dict)
20 | model_dict.update(weight_dict)
21 | return model_dict
22 |
23 |
24 | def load_pt_weight(model):
25 | """
26 | load the object transformer weight from clip vision transformer
27 | notice some of have failed
28 | Args:
29 | model ():
30 |
31 | Returns:
32 |
33 | """
34 | checkpoint = torch.load("pretrained/cc-webvid2m-4f_stformer_b_16_224.pth.tar", map_location="cpu")
35 | pretrained_state = checkpoint['state_dict']
36 | # model.load_state_dict(vit_checkpoint, strict=False)
37 | # pretrain_model = torch.jit.load('pretrained/ViT-B-16.pt')
38 | # pretrained_state = pretrain_model.state_dict()
39 | model_state = model.state_dict()
40 | # for k, v in model_state.items():
41 | # print(k)
42 | model.load_state_dict(weight_transform(model_state, pretrained_state))
43 | return model
44 |
45 |
46 | def vision_encode_init(model_name="frozen"):
47 | # frozen
48 | preprocess = None
49 | if model_name == 'clip':
50 | full_model, preprocess = clip.load("pretrained/ViT-B-16.pt")
51 | model = full_model.visual
52 | elif model_name == 'frozen':
53 | model = SpaceTimeTransformer()
54 | load_pt_weight(model)
55 |     else:
56 |         raise NotImplementedError("unsupported model_name: choose 'clip' or 'frozen'")
57 | model = model.cuda()
58 | return model, preprocess
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/model/vision_models/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/model/vision_models/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/Visualization/Cross_Modality_Transformer_Visualization/model/vision_models/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/model/vision_models/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a signficant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/parse_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from pathlib import Path
4 | from functools import reduce
5 | from operator import getitem
6 | from datetime import datetime
7 | # from logger import setup_logging
8 | # from utils import read_json, write_json
9 | import time
10 | import inspect
11 | import pdb
12 |
13 | class ConfigParser:
14 | def __init__(self, args, options='', timestamp=True, test=False):
15 | # parse default and custom cli options
16 | for opt in options:
17 | args.add_argument(*opt.flags, default=None, type=opt.type)
18 | args = args.parse_args()
19 |
20 | # if args.device:
21 | # os.environ["CUDA_VISIBLE_DEVICES"] = args.device
22 | # if args.resume is None:
23 | # msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
24 | # assert args.config is not None, msg_no_cfg
25 | # self.cfg_fname = Path(args.config)
26 | # config = read_json(self.cfg_fname)
27 | # self.resume = None
28 | # else:
29 | # self.resume = Path(args.resume)
30 | # resume_cfg_fname = self.resume.parent / 'config.json'
31 | # config = read_json(resume_cfg_fname)
32 | # if args.config is not None:
33 | # config.update(read_json(Path(args.config)))
34 | #
35 | # # load config file and apply custom cli options
36 | # self._config = _update_config(config, options, args)
37 | #
38 | # # set save_dir where trained model and log will be saved.
39 | # save_dir = Path(self.config['trainer']['save_dir'])
40 | # timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''
41 | #
42 | # exper_name = self.config['name']
43 | # self._save_dir = save_dir / 'models' / exper_name / timestamp
44 | # self._web_log_dir = save_dir / 'web' / exper_name / timestamp
45 | # self._log_dir = save_dir / 'log' / exper_name / timestamp
46 | #
47 | # if not test:
48 | # self.save_dir.mkdir(parents=True, exist_ok=True)
49 | # self.log_dir.mkdir(parents=True, exist_ok=True)
50 | #
51 | # # if set, remove all previous experiments with the current config
52 | # if vars(args).get("purge_exp_dir", False):
53 | # for dirpath in (self._save_dir, self._log_dir, self._web_log_dir):
54 | # config_dir = dirpath.parent
55 | # existing = list(config_dir.glob("*"))
56 | # print(f"purging {len(existing)} directories from config_dir...")
57 | # tic = time.time()
58 | # os.system(f"rm -rf {config_dir}")
59 | # print(f"Finished purge in {time.time() - tic:.3f}s")
60 | #
61 | # # save updated config file to the checkpoint dir
62 | # if not test:
63 | # write_json(self.config, self.save_dir / 'config.json')
64 | #
65 | # # configure logging module
66 | # setup_logging(self.log_dir)
67 | # self.log_levels = {
68 | # 0: logging.WARNING,
69 | # 1: logging.INFO,
70 | # 2: logging.DEBUG
71 | # }
72 |
73 | def initialize(self, name, module, *args, index=None, **kwargs):
74 | """
75 | finds a function handle with the name given as 'type' in config, and returns the
76 | instance initialized with corresponding keyword args given as 'args'.
77 | """
78 | if index is None:
79 | module_name = self[name]['type']
80 | module_args = dict(self[name]['args'])
81 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
82 | module_args.update(kwargs)
83 | else:
84 | module_name = self[name][index]['type']
85 | module_args = dict(self[name][index]['args'])
86 | # pdb.set_trace()
87 | # if parameter not in config subdict, then check if it's in global config.
88 | signature = inspect.signature(getattr(module, module_name).__init__)
89 | print(module_name)
90 | for param in signature.parameters.keys():
91 | if param not in module_args and param in self.config:
92 | module_args[param] = self[param]
93 | # pdb.set_trace()
94 |
95 | return getattr(module, module_name)(*args, **module_args)
96 |
97 | def __getitem__(self, name):
98 | return self.config[name]
99 |
100 | def get_logger(self, name, verbosity=2):
101 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity,
102 | self.log_levels.keys())
103 | assert verbosity in self.log_levels, msg_verbosity
104 | logger = logging.getLogger(name)
105 | logger.setLevel(self.log_levels[verbosity])
106 | return logger
107 |
108 | # setting read-only attributes
109 | @property
110 | def config(self):
111 | return self._config
112 |
113 | @property
114 | def save_dir(self):
115 | return self._save_dir
116 |
117 | @property
118 | def log_dir(self):
119 | return self._log_dir
120 |
121 |
122 | # helper functions used to update config dict with custom cli options
123 | def _update_config(config, options, args):
124 | for opt in options:
125 | value = getattr(args, _get_opt_name(opt.flags))
126 | if value is not None:
127 | _set_by_path(config, opt.target, value)
128 | return config
129 |
130 |
131 | def _get_opt_name(flags):
132 | for flg in flags:
133 | if flg.startswith('--'):
134 | return flg.replace('--', '')
135 | return flags[0].replace('--', '')
136 |
137 |
138 | def _set_by_path(tree, keys, value):
139 | """Set a value in a nested object in tree by sequence of keys."""
140 | _get_by_path(tree, keys[:-1])[keys[-1]] = value
141 |
142 |
143 | def _get_by_path(tree, keys):
144 | """Access a nested object in tree by sequence of keys."""
145 | return reduce(getitem, keys, tree)
146 |
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/patch_mask.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import cv2
4 |
5 |
6 | def patch_all_masks_from_bbox(bboxs, patch_rows=14):
7 | # generate patch masks from all bboxs
8 | # notice here bbox region is [1:3][0:2]
9 | patch_masks = np.zeros((patch_rows, patch_rows))
10 | bboxs[:, :4] = bboxs[:, :4] * patch_rows
11 | for index, bbox in enumerate(bboxs):
12 | if bbox[4] > 7 and bbox[5] > 7:
13 | bbox[0] += 1 / 3 * (bbox[2] - bbox[0])
14 | bbox[1] += 1 / 3 * (bbox[3] - bbox[1])
15 | bbox[2] -= 1 / 3 * (bbox[2] - bbox[0])
16 | bbox[3] -= 1 / 3 * (bbox[3] - bbox[1])
17 | patch_masks[int(bbox[1]):math.ceil(bbox[3]), int(bbox[0]):math.ceil(bbox[2])] = 1
18 | return patch_masks
19 |
20 |
21 | def image_mask_from_bbox(bboxs, img_shape):
22 | # print(img_shape)
23 | print(bboxs)
24 | w, h = img_shape[1:]
25 | mask = np.zeros((w, h))
26 | for index, bbox in enumerate(bboxs):
27 | print(int(bbox[0].item()), int(bbox[2].item()), int(bbox[1].item()), int(bbox[3].item()))
28 | # # print(bbox)
29 | if bbox[4] > 0.5 and bbox[5] > 0.5:
30 | bbox[0] += 1 / 3 * (bbox[2] - bbox[0])
31 | bbox[1] += 1 / 3 * (bbox[3] - bbox[1])
32 | bbox[2] -= 1 / 3 * (bbox[2] - bbox[0])
33 | bbox[3] -= 1 / 3 * (bbox[3] - bbox[1])
34 | # print(bbox[0])
35 | bbox[0] = bbox[0] * w
36 | bbox[1] = bbox[1] * h
37 | bbox[2] = bbox[2] * w
38 | bbox[3] = bbox[3] * h
39 | mask[int(bbox[0].item()): int(bbox[2].item()), int(bbox[1].item()):int(bbox[3].item())] = 1
40 | print(mask)
41 | return mask
42 |
43 |
44 | def visualize_mask(video, bboxs, out_path):
45 | """
46 | visualize three samples frames and show the masked videos
47 | Args:
48 | video:
49 | bboxs:
50 | out_path:
51 |
52 | Returns:
53 |
54 | """
55 | num_frames = len(video)
56 | out_imgs = None
57 | for index in range(num_frames):
58 | img = video[index] * 255.
59 | bbox_10 = bboxs[index]
60 | masks = image_mask_from_bbox(bbox_10, img.shape)
61 | mask_img = img * masks
62 | if out_imgs is None:
63 | out_imgs = np.concatenate((img, mask_img), axis=2)
64 | else:
65 | out_imgs = np.concatenate((out_imgs, mask_img), axis=2)
66 | # print(out_imgs)
67 | # print(out_imgs.shape)
68 | out_imgs = np.moveaxis(out_imgs, 0, 2)
69 | # print(out_imgs)
70 | print(out_imgs.shape)
71 | cv2.imwrite('{}.png'.format(out_path), out_imgs)
72 |
73 | # def visualize_mask(video, bboxs, out_path):
74 | # """
75 | # visualize three samples frames and show the masked videos
76 | # Args:
77 | # video:
78 | # bboxs:
79 | # out_path:
80 | #
81 | # Returns:
82 | #
83 | # """
84 | # num_frames = len(video)
85 | # out_imgs = None
86 | # for index in range(num_frames):
87 | # img = video[index] * 255.
88 | # img = img.permute(1, 2, 0)
89 | # print(img.shape)
90 | # img = cv2.resize(np.float32(img), (14, 14))
91 | # bbox_10 = bboxs[index]
92 | # # masks = image_mask_from_bbox(bbox_10, img.shape)
93 | # masks = patch_all_masks_from_bbox(bbox_10)
94 | # # print(masks)
95 | # mask_img = img * np.expand_dims(masks, axis=2)
96 | # if out_imgs is None:
97 | # out_imgs = np.concatenate((img, mask_img), axis=1)
98 | # else:
99 | # out_imgs = np.concatenate((out_imgs, mask_img), axis=1)
100 | # # print(out_imgs)
101 | # # print(out_imgs.shape)
102 | # # out_imgs = np.moveaxis(out_imgs, 0, 2)
103 | # # print(out_imgs)
104 | # print(out_imgs.shape)
105 | # cv2.imwrite('{}.png'.format(out_path), out_imgs)
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/utils/nltk_test.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | nltk.data.path.append("pretrained/nltk_data")
3 |
4 |
5 | def check_nouns(words):
6 | is_noun = lambda pos: pos[:2] == 'NN'
7 | # do the nlp stuff
8 | tokenized = nltk.word_tokenize(words)
9 | nouns = [word for (word, pos) in nltk.pos_tag(tokenized) if is_noun(pos)]
10 | if len(nouns) > 0:
11 | return True
12 | else:
13 | return False
14 |
15 | lines = 'lines is some string of words'
16 | # function to test if something is a noun
17 | is_noun = lambda pos: pos[:2] == 'NN'
18 | # do the nlp stuff
19 | tokenized = nltk.word_tokenize(lines)
20 | nouns = [word for (word, pos) in nltk.pos_tag(tokenized) if is_noun(pos)]
21 |
22 | print(nouns)
23 | word = 'woman'
24 | print(check_nouns(word))
--------------------------------------------------------------------------------
/Visualization/Cross_Modality_Transformer_Visualization/utils/read_bboxs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import torch
4 |
5 |
6 | def read_bbox_from_pickle(object_path, top_k=5, v=1):
7 | frame1 = np.load(object_path, allow_pickle=True)
8 | boxes = frame1['bbox']
9 | # rank features and boxes according to confidence
10 | confident = frame1['info'].item()['objects_conf']
11 |     confident_indices = np.argsort(confident)[::-1]  # highest confidence first
12 |     boxes = boxes[confident_indices]
13 | object_ids = frame1['info'].item()['objects_id']
14 | if v == 2:
15 | new_object, unique_indices = np.unique(object_ids, return_index=True)
16 | # step 2: remove region with same object class
17 | boxes = boxes[unique_indices]
18 | # padding with same elements if not enough
19 | if boxes.shape[0] < top_k:
20 | res = top_k - boxes.shape[0]
21 |         boxes = np.pad(boxes, ((0, res), (0, 0)), 'edge')  # pad rows only; a bare (0, res) would also widen the coordinate axis
22 | boxes = boxes[:top_k, :]
23 | image_width = frame1['info'].item()['image_w']
24 | image_height = frame1['info'].item()['image_h']
25 | box_width = boxes[:, 2] - boxes[:, 0]
26 | box_height = boxes[:, 3] - boxes[:, 1]
27 | scaled_width = box_width / image_width
28 | scaled_height = box_height / image_height
29 | scaled_x = boxes[:, 0] / image_width
30 | scaled_y = boxes[:, 1] / image_height
31 | scaled_width = scaled_width[..., np.newaxis]
32 | scaled_height = scaled_height[..., np.newaxis]
33 | scaled_x = scaled_x[..., np.newaxis]
34 | scaled_y = scaled_y[..., np.newaxis]
35 | spatial_features = np.concatenate(
36 | (scaled_x, scaled_y, scaled_x + scaled_width, scaled_y + scaled_height, scaled_width, scaled_height), axis=1)
37 | return torch.from_numpy(spatial_features)
38 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: frozen
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=1_gnu
10 | - av=6.1.0=py37he005a31_1000
11 | - backcall=0.2.0=pyhd3eb1b0_0
12 | - blas=1.0=mkl
13 | - brotlipy=0.7.0=py37h5e8e339_1001
14 | - bzip2=1.0.8=h7b6447c_0
15 | - ca-certificates=2020.12.5=ha878542_0
16 | - cairo=1.16.0=hf32fb01_1
17 | - certifi=2020.12.5=py37h89c1867_1
18 | - cffi=1.14.5=py37hc58025e_0
19 | - chardet=4.0.0=py37h89c1867_1
20 | - click=8.0.0=py37h89c1867_0
21 | - cryptography=3.4.7=py37h5d9358c_0
22 | - cudatoolkit=11.1.74=h6bb024c_0
23 | - dataclasses=0.8=pyhc8e2a94_1
24 | - decorator=5.0.7=pyhd3eb1b0_0
25 | - ffmpeg=4.0=hcdf2ecd_0
26 | - filelock=3.0.12=pyh9f0ad1d_0
27 | - fontconfig=2.13.1=h6c09931_0
28 | - freeglut=3.0.0=hf484d3e_5
29 | - freetype=2.10.4=h5ab3b9f_0
30 | - glib=2.68.1=h36276a3_0
31 | - graphite2=1.3.14=h23475e2_0
32 | - harfbuzz=1.8.8=hffaf4a1_0
33 | - hdf5=1.10.2=hba1933b_1
34 | - huggingface_hub=0.0.8=pyhd8ed1ab_0
35 | - humanize=3.5.0=pyhd8ed1ab_0
36 | - icu=58.2=he6710b0_3
37 | - idna=2.10=pyh9f0ad1d_0
38 | - importlib-metadata=4.0.1=py37h89c1867_0
39 | - importlib_metadata=4.0.1=hd8ed1ab_0
40 | - intel-openmp=2021.2.0=h06a4308_610
41 | - ipdb=0.13.7=pyhd8ed1ab_0
42 | - ipython=7.22.0=py37hb070fc8_0
43 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
44 | - jasper=2.0.14=h07fcdf6_1
45 | - jedi=0.17.0=py37_0
46 | - joblib=1.0.1=pyhd3eb1b0_0
47 | - jpeg=9b=h024ee3a_2
48 | - lcms2=2.12=h3be6417_0
49 | - ld_impl_linux-64=2.33.1=h53a641e_7
50 | - libffi=3.3=he6710b0_2
51 | - libgcc-ng=9.3.0=h2828fa1_19
52 | - libgfortran-ng=7.3.0=hdf63c60_0
53 | - libglu=9.0.0=hf484d3e_1
54 | - libgomp=9.3.0=h2828fa1_19
55 | - libopencv=3.4.2=hb342d67_1
56 | - libopus=1.3.1=h7b6447c_0
57 | - libpng=1.6.37=hbc83047_0
58 | - libstdcxx-ng=9.3.0=h6de172a_19
59 | - libtiff=4.1.0=h2733197_1
60 | - libuuid=1.0.3=h1bed415_2
61 | - libuv=1.40.0=h7b6447c_0
62 | - libvpx=1.7.0=h439df22_0
63 | - libxcb=1.14=h7b6447c_0
64 | - libxml2=2.9.10=hb55368b_3
65 | - lz4-c=1.9.3=h2531618_0
66 | - mkl=2021.2.0=h06a4308_296
67 | - mkl-service=2.3.0=py37h27cfd23_1
68 | - mkl_fft=1.3.0=py37h42c9631_2
69 | - mkl_random=1.2.1=py37ha9443f7_2
70 | - msgpack-python=1.0.2=py37hff7bd54_1
71 | - ncurses=6.2=he6710b0_1
72 | - ninja=1.10.2=hff7bd54_1
73 | - numpy=1.20.1=py37h93e21f0_0
74 | - numpy-base=1.20.1=py37h7d8b39e_0
75 | - olefile=0.46=py37_0
76 | - opencv=3.4.2=py37h6fd60c2_1
77 | - openssl=1.1.1k=h7f98852_0
78 | - packaging=20.9=pyh44b312d_0
79 | - pandas=1.1.4=py37h10a2094_0
80 | - parso=0.8.2=pyhd3eb1b0_0
81 | - pcre=8.44=he6710b0_0
82 | - pexpect=4.8.0=pyhd3eb1b0_3
83 | - pickleshare=0.7.5=pyhd3eb1b0_1003
84 | - pip=21.0.1=py37h06a4308_0
85 | - pixman=0.40.0=h7b6447c_0
86 | - prompt-toolkit=3.0.17=pyh06a4308_0
87 | - psutil=5.8.0=py37h27cfd23_1
88 | - ptyprocess=0.7.0=pyhd3eb1b0_2
89 | - py-opencv=3.4.2=py37hb342d67_1
90 | - pycparser=2.20=pyh9f0ad1d_2
91 | - pygments=2.8.1=pyhd3eb1b0_0
92 | - pyopenssl=20.0.1=pyhd8ed1ab_0
93 | - pyparsing=2.4.7=pyh9f0ad1d_0
94 | - pysocks=1.7.1=py37h89c1867_3
95 | - python=3.7.10=hdb3f193_0
96 | - python-dateutil=2.8.1=pyhd3eb1b0_0
97 | - python_abi=3.7=1_cp37m
98 | - pytorch=1.8.1=py3.7_cuda11.1_cudnn8.0.5_0
99 | - pytz=2021.1=pyhd3eb1b0_0
100 | - readline=8.1=h27cfd23_0
101 | - regex=2021.4.4=py37h5e8e339_0
102 | - requests=2.25.1=pyhd3deb0d_0
103 | - sacremoses=0.0.43=pyh9f0ad1d_0
104 | - scikit-learn=0.24.1=py37ha9443f7_0
105 | - scipy=1.6.2=py37had2a1c9_1
106 | - setuptools=52.0.0=py37h06a4308_0
107 | - six=1.15.0=py37h06a4308_0
108 | - sqlite=3.35.4=hdfb4753_0
109 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
110 | - tk=8.6.10=hbc83047_0
111 | - tokenizers=0.10.1=py37hcb7a40c_0
112 | - torchaudio=0.8.1=py37
113 | - tqdm=4.60.0=pyhd8ed1ab_0
114 | - traitlets=5.0.5=pyhd3eb1b0_0
115 | - transformers=4.6.0=pyhd8ed1ab_0
116 | - typing_extensions=3.7.4.3=pyha847dfd_0
117 | - urllib3=1.26.4=pyhd8ed1ab_0
118 | - wcwidth=0.2.5=py_0
119 | - wheel=0.36.2=pyhd3eb1b0_0
120 | - xz=5.2.5=h7b6447c_0
121 | - zipp=3.4.1=pyhd8ed1ab_0
122 | - zlib=1.2.11=h7b6447c_3
123 | - zstd=1.4.9=haebb681_0
124 | - pip:
125 | - --find-links https://download.pytorch.org/whl/torch_stable.html
126 | - attrdict==2.0.1
127 | - attrs==21.2.0
128 | - bravado==11.0.3
129 | - bravado-core==5.17.0
130 | - colorama==0.4.4
131 | - cycler==0.10.0
132 | - decord==0.6.0
133 | - dominate==2.6.0
134 | - einops==0.3.0
135 | - future==0.18.2
136 | - gitdb==4.0.7
137 | - gitpython==3.1.17
138 | - jsonpickle==1.5.2
139 | - jsonpointer==2.1
140 | - jsonref==0.2
141 | - jsonschema==3.2.0
142 | - kiwisolver==1.3.1
143 | - matplotlib==3.4.2
144 | - monotonic==1.6
145 | - munch==2.5.0
146 | - neptune-client==0.9.9
147 | - neptune-contrib==0.27.1
148 | - oauthlib==3.1.0
149 | - pillow==8.3.1
150 | - py-cpuinfo==8.0.0
151 | - pyjwt==2.1.0
152 | - pyrsistent==0.17.3
153 | - pyyaml==5.4.1
154 | - requests-oauthlib==1.3.0
155 | - rfc3987==1.3.8
156 | - sacred==0.8.2
157 | - simplejson==3.17.2
158 | - smmap==4.0.0
159 | - strict-rfc3339==0.7
160 | - swagger-spec-validator==2.7.3
161 | - timm==0.4.5
162 | - torchvision==0.9.1+cu111
163 | - webcolors==1.11.1
164 | - websocket-client==0.59.0
165 | - wrapt==1.12.1
166 | prefix: /users/maxbain/miniconda3/envs/frozen
--------------------------------------------------------------------------------
/figures/oa_main_ppl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/figures/oa_main_ppl.jpg
--------------------------------------------------------------------------------
/figures/oa_visualize_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/figures/oa_visualize_1.jpg
--------------------------------------------------------------------------------
/figures/oa_visualize_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/figures/oa_visualize_2.jpg
--------------------------------------------------------------------------------
/figures/objects.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/figures/objects.jpg
--------------------------------------------------------------------------------
/figures/objects_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FingerRec/OA-Transformer/e38bc99d03ceec5f7ae4d5a072555f82b80d4d39/figures/objects_2.png
--------------------------------------------------------------------------------
/object_extraction.md:
--------------------------------------------------------------------------------
1 | ## Install
2 |
3 |
4 | Please follow [BUTD](https://github.com/MILVLG/bottom-up-attention.pytorch) to install detectron2.
5 |
6 | Then download the pretrained model from [Google Drive](https://drive.google.com/file/d/1zFqaeNMDa6HL4tBWJd5BKu_AhCkqtacs/view?usp=sharing) and place it in the `pretrained` directory.
7 |
8 | ```bash
9 | cd bottom-up-attention.pytorch
10 | mkdir pretrained
11 | mv [path_to_downloaded_pth] pretrained/
12 | ```
13 |
14 | ### Replace feature_extract
15 | The original BUTD repo provides the script [extract_features.py](https://github.com/MILVLG/bottom-up-attention.pytorch/blob/master/extract_features.py) to extract objects with the distributed framework Ray.
16 | However, we found this tool unstable and slow,
17 | so we implemented a multiprocess version that is about 3x faster.
18 |
19 | Simply replace extract_features.py with [extract_cc3m](ObjectExtractor/multiprocess_full_cc3m_complementary_modify_tsv_gen_from_video.py) for CC3M and [extract_webvid](ObjectExtractor/multiprocess_full_webvid_multiframe_complementary_modify_tsv_gen_from_video.py) for WebVid; a sketch of the per-GPU dispatch pattern is shown below.
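
For orientation, here is a minimal sketch of the per-GPU worker-pool pattern those scripts follow. The helper `detect_objects` and the exact queue layout are illustrative assumptions rather than the real implementation; see the two scripts linked above for the actual logic.

```python
import os
import multiprocessing as mp
import numpy as np


def detect_objects(video_path):
    # placeholder for BUTD detector inference on the sampled frames;
    # replace with the real model call from the extraction scripts
    return {"bbox": np.zeros((0, 4))}


def worker(gpu_id, task_queue, out_dir):
    # pin this worker to a single GPU before any CUDA context is created
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    while True:
        video_path = task_queue.get()
        if video_path is None:          # poison pill: no more work
            break
        feats = detect_objects(video_path)
        out_path = os.path.join(out_dir, os.path.basename(video_path) + ".npz")
        np.savez(out_path, **feats)


def run(video_paths, gpus=(0, 1, 2, 3, 4, 5, 6, 7), workers_per_gpu=2, out_dir="features"):
    queue = mp.Queue()
    for path in video_paths:
        queue.put(path)
    procs = [mp.Process(target=worker, args=(gpu, queue, out_dir))
             for gpu in gpus for _ in range(workers_per_gpu)]
    for _ in procs:                     # one poison pill per worker
        queue.put(None)
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```

Pinning each worker to one GPU via `CUDA_VISIBLE_DEVICES` before the detector is built is what lets several processes share a machine without Ray.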
20 |
21 | ### Webvid 2.5M
22 | ```bash
23 | python3 multiprocess_full_webvid_multiframe_complementary_modify_tsv_gen_from_video.py --mode caffe \
24 | --num-cpus 32 --gpus '0,1,2,3,4,5,6,7' \
25 | --workers_per_gpu 2 \
26 | --sampling_frames 8 \
27 | --split "train" \
28 | --dataset_dir "WebVid" \
29 | --extract-mode roi_feats \
30 | --min-max-boxes '10,100' \
31 | --config-file configs/bua-caffe/extract-bua-caffe-r101.yaml
32 | ```
33 |
34 |
35 | ### CC3M
36 | ```bash
37 | python3 multiprocess_full_cc3m_complementary_modify_tsv_gen_from_video.py \
38 | --mode caffe --num-cpus 0 --gpus '0,1,2,3,4,5,6,7' \
39 | --extract-mode roi_feats --min-max-boxes '10,100' \
40 | --config-file configs/bua-caffe/extract-bua-caffe-r101.yaml
41 | ```
42 |
43 | ### Visualization
44 | We visualize some extracted bounding boxes as below:
45 | 
--------------------------------------------------------------------------------
/train.md:
--------------------------------------------------------------------------------
1 | ## Install
2 |
3 | ```bash
4 | conda env create
5 | pip install decord
6 | pip install ftfy
7 | cd OATrans
8 | mkdir data;
9 | mkdir exps;
10 | ```
11 |
12 |
13 | ## Pre-training
14 |
15 | ### Normal OA-Transformer for retrieval
16 | ```bash
17 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node 8 --master_port 29132 \
18 | train_dist_multi.py \
19 | --config configs/pt/cc3m_webvid/local-region-loss.json # --launcher pytorch
20 | ```
21 |
22 | ### Region-sensitive OA-Transformer for Grounding
23 |
24 | ```bash
25 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node 8 --master_port 29132 \
26 | train_dist_region_mem.py \
27 | --config configs/pt/cc3m_webvid/local-region-loss.json # --launcher pytorch
28 | ```
29 |
30 |
31 | ## Downstream
32 |
33 | ### zero-shot
34 | ```bash
35 | CUDA_VISIBLE_DEVICES=2,3 python -m torch.distributed.launch --nproc_per_node 2 \
36 | --master_port 29142 test_region_mem.py --config configs/ft/msrvtt/zsl/normal.json
37 | ```
38 |
39 | ### fine-tuning
40 | ```bash
41 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py --config configs/ft/msrvtt/fine_tune/normal_1_cl.json
42 | ```
--------------------------------------------------------------------------------
/visualization.md:
--------------------------------------------------------------------------------
1 | # Cross-modality Visualization Tools
2 |
3 | ## Heatmap Visualization
4 |
5 | ```bash
6 | cd Visualization/Cross_Modality_Transformer_Visualization
7 | mkdir pretrained
8 | cd pretrained && mkdir distilbert-base-uncased
9 | ```
10 |
11 | Then download all files from [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased/tree/main) and place them in the `pretrained/distilbert-base-uncased` directory, or fetch them programmatically as in the sketch below.
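
If you prefer not to download the files by hand, a minimal alternative (assuming network access and the `transformers` version pinned in `environment.yml`) is to let `transformers` fetch and save them into the path that `model/text_model.py` expects:

```python
# Fetch DistilBERT once and store it locally for offline use.
from transformers import AutoModel, AutoTokenizer

target = "pretrained/distilbert-base-uncased"
AutoModel.from_pretrained("distilbert-base-uncased").save_pretrained(target)
AutoTokenizer.from_pretrained("distilbert-base-uncased").save_pretrained(target)
```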
12 |
13 |
14 |
15 | ### Image
16 |
17 | ```bash
18 | python main_img.py
19 | ```
20 | We provide both _feature map visualization_ and _cross-modality attention visualization_; a minimal sketch of the heatmap computation is shown below.
21 |
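
The following is a minimal sketch of how a text-to-patch attention heatmap can be produced from the embeddings returned by the vision and text encoders. It assumes a 224x224 input split into a 14x14 patch grid with the [CLS]/frame tokens already removed, and approximates what `cross_attention_visualize` does rather than reproducing it.

```python
import cv2
import numpy as np
import torch.nn.functional as F


def text_patch_heatmap(patch_embeddings, text_embedding, image, out_path, grid=14):
    """patch_embeddings: (grid*grid, dim) tensor; text_embedding: (dim,) tensor; image: HxWx3 uint8 (BGR)."""
    patches = F.normalize(patch_embeddings, dim=-1)
    text = F.normalize(text_embedding, dim=-1)
    sim = patches @ text                                      # cosine similarity per patch
    sim = sim.reshape(grid, grid).detach().cpu().numpy().astype(np.float32)
    sim = (sim - sim.min()) / (sim.max() - sim.min() + 1e-8)  # rescale to [0, 1]
    heat = cv2.resize(sim, (image.shape[1], image.shape[0]))
    heat = cv2.applyColorMap((heat * 255).astype(np.uint8), cv2.COLORMAP_JET)
    overlay = (0.5 * image + 0.5 * heat).astype(np.uint8)     # blend heatmap over the frame
    cv2.imwrite(out_path, overlay)
```

Rescaling the similarities before colour-mapping keeps frames with globally low attention readable.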
22 |
23 | ### Video
24 |
25 | ```bash
26 | python main_video.py
27 | ```
28 |
29 | ## Binary Map Visualization
30 |
31 | If we ask the model to learn fine-grained alignment, we can generate a binary map as below; a thresholding sketch follows the figure.
32 |
33 | Refer to [test_region_mem.py](OATrans/test_region_mem.py) for details.
34 |
35 | 
--------------------------------------------------------------------------------