├── GenerativeImage2Text ├── __init__.py ├── alvs_inter.py ├── common.py ├── data_layer │ ├── __init__.py │ ├── builder.py │ └── transform.py ├── data_prepare.py ├── dataloader.py ├── inference.py ├── layers │ ├── CLIP │ │ ├── __init__.py │ │ ├── clip.py │ │ └── model.py │ ├── __init__.py │ ├── bert │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── file_utils.py │ │ ├── modeling_bert.py │ │ └── modeling_utils.py │ └── decoder.py ├── main_alvs.py ├── model.py ├── torch_common.py ├── train.py ├── trie_decoder.py └── tsv_io.py ├── LAST ├── grounding_gen │ ├── dataloader_grd_gen.py │ ├── dataloader_grd_gen_vis.py │ ├── main_grd_gen.py │ ├── main_grd_gen_vis.py │ ├── nets_grd_gen.py │ ├── nets_grd_gen_vis.py │ └── visual_net.py └── net_grd_avst │ ├── __init__.py │ ├── alvs_inter.py │ ├── audio_layers.py │ ├── audio_others.py │ ├── dataloader_alvs.py │ ├── dataloader_avst.py │ ├── main_alvs.py │ ├── main_avst.py │ ├── net_alvs.py │ ├── net_avst.py │ ├── net_encoder.py │ └── visual_net.py ├── LAVISH ├── grounding_gen │ ├── dataloader_grd_gen.py │ ├── dataloader_grd_gen_vis.py │ ├── main_grd_gen.py │ ├── main_grd_gen_vis.py │ ├── nets_grd_gen.py │ ├── nets_grd_gen_vis.py │ └── visual_net.py └── net_grd_avst │ ├── __init__.py │ ├── alvs_inter.py │ ├── audio_layers.py │ ├── audio_others.py │ ├── base_options.py │ ├── compute_mean.py │ ├── dataloader_alvs.py │ ├── dataloader_avst.py │ ├── main_alvs.py │ ├── main_avst.py │ ├── net_alvs.py │ ├── net_avst.py │ ├── net_encoder.py │ └── visual_net.py ├── MovieSeq ├── example.ipynb ├── movieseq.py └── utils.py ├── NarrativeBridge ├── __init__.py ├── alvs_inter.py ├── common.py ├── data_layer │ ├── __init__.py │ ├── builder.py │ └── transform.py ├── data_prepare.py ├── dataloader.py ├── inference.py ├── layers │ ├── CLIP │ │ ├── __init__.py │ │ ├── clip.py │ │ └── model.py │ ├── __init__.py │ ├── bert │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── file_utils.py │ │ ├── modeling_bert.py │ │ └── modeling_utils.py │ └── decoder.py ├── main_alvs.py ├── model.py └── train.py ├── README.md ├── STG-CMA ├── grounding_gen │ ├── dataloader_grd_gen.py │ ├── esc_config.py │ ├── htsat.py │ ├── layers.py │ ├── main_grd_gen.py │ ├── nets_grd_gen.py │ ├── utils.py │ └── visual_net.py └── net_grd_avst │ ├── __init__.py │ ├── alvs_inter.py │ ├── audio_layers.py │ ├── audio_others.py │ ├── base_options.py │ ├── compute_mean.py │ ├── dataloader_alvs.py │ ├── dataloader_avst.py │ ├── esc_config.py │ ├── htsat.py │ ├── layers.py │ ├── main_alvs.py │ ├── main_avst.py │ ├── net_alvs.py │ ├── net_avst.py │ ├── net_encoders.py │ ├── utils.py │ └── visual_net.py ├── VindLU ├── configs │ ├── beit-base-patch16-224-pt22k-ft22k.json │ ├── config_bert.json │ ├── config_bert_large.json │ ├── data.py │ ├── model.py │ ├── pretrain.py │ ├── qa.py │ ├── qa_anet.py │ ├── qa_msrvtt.py │ ├── ret_anet.py │ ├── ret_coco.py │ ├── ret_didemo.py │ ├── ret_flickr.py │ ├── ret_msrvtt.py │ ├── ret_msrvtt_9k.py │ ├── ret_msrvtt_mc.py │ ├── ret_ssv2_label.py │ ├── ret_ssv2_template.py │ └── tvqa.py ├── dataset │ ├── __init__.py │ ├── base_dataset.py │ ├── caption_dataset.py │ ├── dataloader.py │ ├── qa_dataset.py │ ├── sqlite_dataset.py │ ├── utils.py │ └── video_utils.py ├── miscs │ └── test_flops.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── beit │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ └── st_beit.py │ │ ├── bert │ │ │ ├── __init__.py │ │ │ ├── builder.py │ │ │ ├── tokenization_bert.py │ │ │ └── xbert.py │ │ └── omnivore_swin │ │ │ ├── builder.py │ │ │ ├── 
distributed.py │ │ │ ├── omnivore_swin.py │ │ │ ├── utils.py │ │ │ └── video_swin.py │ ├── criterions.py │ ├── modules │ │ ├── __init__.py │ │ └── temporal_model.py │ ├── utils.py │ ├── vindlu.py │ ├── vindlu_qa.py │ └── vindlu_tvqa.py ├── preprocess │ ├── compress.py │ ├── create_sqlite_db.py │ ├── gen_webvid10m_label.py │ └── utils.py ├── tasks │ ├── pretrain.py │ ├── retrieval.py │ ├── retrieval_mc.py │ ├── retrieval_utils.py │ ├── shared_utils.py │ ├── trainer.py │ ├── tvqa.py │ ├── vqa.py │ └── vqa_utils.py ├── tests │ └── test_cfg.py ├── tools │ ├── run.py │ ├── submit.sh │ └── utils.py ├── utils │ ├── basic_utils.py │ ├── config.py │ ├── config_utils.py │ ├── distributed.py │ ├── easydict.py │ ├── logger.py │ ├── optimizer.py │ └── scheduler.py └── vl.yml ├── figs └── TMW_pipeline.png ├── testa ├── alvs_inter.py ├── configs │ ├── bert_config.json │ ├── config.py │ ├── med_config.json │ ├── pretrain_timesformer.yaml │ ├── pretrain_timesformer_coco.yaml │ ├── pretrain_timesformer_from_blip.yaml │ ├── pretrain_video.yaml │ ├── retrieval_activitynet_f32.yaml │ ├── retrieval_activitynet_f96.yaml │ ├── retrieval_condensedmovies_f32.yaml │ ├── retrieval_condensedmovies_f96.yaml │ ├── retrieval_didemo_f32.yaml │ ├── retrieval_didemo_f96.yaml │ ├── retrieval_queryd_f32.yaml │ ├── retrieval_queryd_f96.yaml │ └── vqa_activitynet_f32.yaml ├── data │ ├── __init__.py │ ├── compress.py │ ├── pretrain_dataset.py │ ├── randaugment.py │ ├── utils.py │ ├── video_dataset.py │ └── vqa_dataset.py ├── main_alvs.py ├── model_stats.py ├── models │ ├── __init__.py │ ├── blip.py │ ├── blip_pretrain.py │ ├── med.py │ ├── testa_retrieval.py │ ├── testa_vqa.py │ ├── timesformer │ │ ├── __init__.py │ │ ├── config │ │ │ ├── __init__.py │ │ │ └── defaults.py │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ ├── cv2_transform.py │ │ │ ├── decoder.py │ │ │ ├── kinetics.py │ │ │ ├── loader.py │ │ │ ├── multigrid_helper.py │ │ │ ├── ssv2.py │ │ │ ├── transform.py │ │ │ ├── utils.py │ │ │ └── video_container.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── batchnorm_helper.py │ │ │ ├── build.py │ │ │ ├── conv2d_same.py │ │ │ ├── custom_video_model_builder.py │ │ │ ├── features.py │ │ │ ├── head_helper.py │ │ │ ├── helpers.py │ │ │ ├── linear.py │ │ │ ├── losses.py │ │ │ ├── nonlocal_helper.py │ │ │ ├── operators.py │ │ │ ├── optimizer.py │ │ │ ├── resnet_helper.py │ │ │ ├── stem_helper.py │ │ │ ├── video_model_builder.py │ │ │ ├── vit.py │ │ │ └── vit_utils.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── ava_eval_helper.py │ │ │ ├── ava_evaluation │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── ava_action_list_v2.1_for_activitynet_2018.pbtxt.txt │ │ │ │ ├── label_map_util.py │ │ │ │ ├── metrics.py │ │ │ │ ├── np_box_list.py │ │ │ │ ├── np_box_list_ops.py │ │ │ │ ├── np_box_mask_list.py │ │ │ │ ├── np_box_mask_list_ops.py │ │ │ │ ├── np_box_ops.py │ │ │ │ ├── np_mask_ops.py │ │ │ │ ├── object_detection_evaluation.py │ │ │ │ ├── per_image_evaluation.py │ │ │ │ └── standard_fields.py │ │ │ ├── benchmark.py │ │ │ ├── bn_helper.py │ │ │ ├── c2_model_loading.py │ │ │ ├── checkpoint.py │ │ │ ├── distributed.py │ │ │ ├── env.py │ │ │ ├── logging.py │ │ │ ├── lr_policy.py │ │ │ ├── meters.py │ │ │ ├── metrics.py │ │ │ ├── misc.py │ │ │ ├── multigrid.py │ │ │ ├── multiprocessing.py │ │ │ ├── parser.py │ │ │ └── weight_init_helper.py │ │ └── visualization │ │ │ ├── __init__.py │ │ │ ├── tensorboard_vis.py │ │ │ └── utils.py │ └── vit.py ├── testa │ ├── __init__.py │ ├── merge.py │ ├── merge_original.py │ 
├── patch │ │ ├── __init__.py │ │ ├── timesformer.py │ │ ├── timesformer_prune.py │ │ └── vit.py │ ├── utils.py │ └── vis.py ├── train_video_qa.py ├── train_video_retrieval.py ├── transform │ └── randaugment.py ├── utils.py └── video_pretrain.py └── v2tactiongraph ├── alvs_inter.py ├── dataloaders ├── dataloader_msrvtt.py ├── dataloader_msrvtt_caption.py ├── dataloader_msrvtt_patch.py ├── dataloader_msvd.py ├── dataloader_msvd_caption.py ├── dataloader_msvd_patch.py └── rawvideo_util.py ├── feature_extractor ├── action_spatio_temporal_graph_feature_extractor.ipynb ├── clip4clip_theta_2_feature_extraction.ipynb ├── grid_based_spatial_action_graph.ipynb ├── grid_node_theta_1_feature_extractor.ipynb ├── model │ └── i3d │ │ ├── InceptionI3d.py │ │ └── i3d.ipynb ├── object_based_spatial_action_graph.ipynb ├── object_node_theta_1_feature_extractor.ipynb ├── temporal_similarity_graph.ipynb ├── transform-graph-to-geometric.ipynb ├── util.py └── utility │ ├── dataset.py │ ├── util.py │ └── vocabulary.py ├── main_alvs.py ├── main_task_caption_GNN.py ├── modules ├── __init__.py ├── beam.py ├── decoder-base │ └── decoder_config.json ├── file_utils.py ├── gnn │ ├── GATConvolution.py │ ├── GATv2Convolution.py │ └── TransformerConvolution.py ├── graph_gat_modelling.py ├── graph_gatv2_modelling.py ├── graph_transformer_modelling.py ├── modeling.py ├── module_bert.py ├── module_decoder.py ├── module_visual.py ├── optimization.py ├── tokenization.py ├── until_config.py ├── until_module.py └── visual-base │ └── visual_config.json └── scripts ├── msrvtt_train_GNN.sh └── msvd_train_GNN.sh /GenerativeImage2Text/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/GenerativeImage2Text/__init__.py -------------------------------------------------------------------------------- /GenerativeImage2Text/alvs_inter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def sample_vectors(f_v, k): 5 | """ 6 | Uniformly sample k vectors from f_v. 7 | """ 8 | t = f_v.size(0) 9 | indices = torch.linspace(0, t - 1, steps=k).long() 10 | sampled_vectors = f_v[indices] 11 | return sampled_vectors, indices 12 | 13 | def cosine_similarity(a, b): 14 | """ 15 | Calculate the cosine similarity between two tensors a and b. 16 | """ 17 | return F.cosine_similarity(a, b, dim=-1) 18 | 19 | def calculate_attention(f_qst, f_v_sampled): 20 | """ 21 | Calculate attention values 22 | """ 23 | q = f_qst.unsqueeze(0) # Shape: [1, d] 24 | k = f_v_sampled # Shape: [k, d] 25 | v = f_v_sampled # Shape: [k, d] 26 | 27 | # Attention calculation: att_weights = softmax(q * k^T / sqrt(d)) 28 | d = f_qst.size(-1) 29 | att_weights = F.softmax(torch.matmul(q, k.T) / (d ** 0.5), dim=-1) 30 | att_f_v = torch.matmul(att_weights, v) # Shape: [1, d] 31 | return att_f_v.squeeze(0) # Shape: [d] 32 | 33 | def process_numbers(number_list): 34 | unique_numbers = sorted(set(int(num) for num in number_list)) 35 | return unique_numbers 36 | 37 | def iterative_sampling(f_v, f_text, k, m, a1, a2): 38 | """ 39 | Perform the iterative sampling process. 
40 | """ 41 | t, d = f_v.shape 42 | indices_record = [] 43 | iter_samples = [0] 44 | 45 | 46 | for _ in range(m): 47 | # Uniformly sample k vectors from f_v 48 | f_v_sampled, sampled_indices = sample_vectors(f_v, k) 49 | sampled_indices += sum(iter_samples) 50 | 51 | # Compute cosine similarity between consecutive vectors 52 | sim1 = cosine_similarity(f_v_sampled[:-1], f_v_sampled[1:]) 53 | sim1 = torch.cat([sim1, sim1[-1].unsqueeze(0)]) # Ensure last and second-last are the same 54 | 55 | # Calculate att_f_v 56 | att_f_v = torch.stack([calculate_attention(f_text, f_v_sampled[i].unsqueeze(0)) for i in range(k)]) 57 | 58 | # Compute cosine similarity between consecutive att_f_v vectors 59 | sim2 = cosine_similarity(att_f_v[:-1], att_f_v[1:]) 60 | sim2 = torch.cat([sim2, sim2[-1].unsqueeze(0)]) # Ensure last and second-last are the same 61 | 62 | # Sum and find the max index 63 | sim = a1 * sim1 + a2 * sim2 64 | max_sim_index = torch.argmax(sim) 65 | 66 | # Use max_sim_index as center, select new vectors 67 | center_idx = sampled_indices[max_sim_index] 68 | start_idx = max(0, center_idx - t // k) 69 | end_idx = min(t, center_idx + t // k) 70 | iter_samples.append(start_idx) 71 | f_v = f_v[start_idx:end_idx] 72 | indices_record += list(sampled_indices) 73 | indices_record = process_numbers(indices_record) 74 | 75 | return indices_record -------------------------------------------------------------------------------- /GenerativeImage2Text/data_layer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/GenerativeImage2Text/data_layer/__init__.py -------------------------------------------------------------------------------- /GenerativeImage2Text/data_layer/builder.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import default_collate 2 | import torch 3 | 4 | 5 | def collate_fn(batch): 6 | # this function is designed to support any customized type and to be compatible 7 | # with the default collate function 8 | ele = batch[0] 9 | if isinstance(ele, dict): 10 | return {key: collate_fn([d[key] for d in batch]) for key in ele} 11 | elif isinstance(ele, (tuple, list)): 12 | return [collate_fn(x) for x in zip(*batch)] 13 | else: 14 | if all(isinstance(b, torch.Tensor) for b in batch) and len(batch) > 0: 15 | if not all(b.shape == batch[0].shape for b in batch[1:]): 16 | assert all(len(b.shape) == len(batch[0].shape) for b in batch[1:]) 17 | shape = torch.tensor([b.shape for b in batch]) 18 | max_shape = tuple(shape.max(dim=0)[0].tolist()) 19 | batch2 = [] 20 | for b in batch: 21 | if any(c < m for c, m in zip(b.shape, max_shape)): 22 | b2 = torch.zeros(max_shape, dtype=b.dtype, device=b.device) 23 | if b.dim() == 1: 24 | b2[:b.shape[0]] = b 25 | elif b.dim() == 2: 26 | b2[:b.shape[0], :b.shape[1]] = b 27 | elif b.dim() == 3: 28 | b2[:b.shape[0], :b.shape[1], :b.shape[2]] = b 29 | else: 30 | raise NotImplementedError 31 | b = b2 32 | batch2.append(b) 33 | batch = batch2 34 | return default_collate(batch) 35 | 36 | 37 | -------------------------------------------------------------------------------- /GenerativeImage2Text/data_prepare.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os.path as op 3 | from pprint import pformat 4 | from .common import parse_general_args, json_dump 5 | from .common import qd_tqdm as tqdm 6 | import 
logging 7 | from .common import load_list_file, read_to_buffer 8 | import json 9 | from .common import init_logging, hash_sha1, write_to_file 10 | from .taxonomy import noffset_to_synset, get_nick_name 11 | from .tsv_io import tsv_writer 12 | 13 | 14 | def get_imagenet_unique_nick_names(): 15 | txt = './aux_data/imagenet/LOC_synset_mapping.txt' 16 | noffsets = load_list_file(txt) 17 | noffsets = [x.split(' ')[0] for x in noffsets] 18 | assert hash_sha1(noffsets) == 'fb9737bbca048296520bc35582947b3755aa948f' 19 | nick_name_overwrite = { 20 | 'n02012849': 'crane bird', 21 | 'n03126707': 'crane machine', 22 | 'n02113186': 'cardigan dog', 23 | 'n02963159': 'cardigan jacket', 24 | 'n03710637': 'maillot tights', 25 | 'n03710721': 'maillot bathing suit', 26 | } 27 | nick_names = [nick_name_overwrite[n] if n in nick_name_overwrite else 28 | get_nick_name(noffset_to_synset(n)) for n in noffsets] 29 | assert hash_sha1(nick_names) == '9c1dd12d7e8120820ffd44b75ebe8b78b659a4f4' 30 | assert len(set(nick_names)) == len(nick_names) 31 | assert len(set(map(lambda n: n.replace(' ', ''), nick_names))) == len(nick_names) 32 | return nick_names 33 | 34 | def generate_imagenet_unique_names(): 35 | nick_names = get_imagenet_unique_nick_names() 36 | write_to_file('\n'.join(nick_names), 37 | './aux_data/imagenet/imagenet_unique_readable_names.txt') 38 | 39 | 40 | def prepare_coco_test(): 41 | image_folder = 'aux_data/raw_data/val2014' 42 | json_file = 'aux_data/raw_data/dataset_coco.json' 43 | infos = json.loads(read_to_buffer(json_file))['images'] 44 | infos = [i for i in infos if i['split'] == 'test'] 45 | assert all(i['filepath'] == 'val2014' for i in infos) 46 | def gen_rows(): 47 | for i in tqdm(infos): 48 | payload = base64.b64encode(read_to_buffer(op.join(image_folder, 49 | i['filename']))) 50 | yield i['cocoid'], payload 51 | tsv_writer(gen_rows(), 'data/coco_caption/test.img.tsv') 52 | 53 | def gen_cap_rows(): 54 | for i in tqdm(infos): 55 | caps = [{'caption': j['raw']} for j in i['sentences']] 56 | yield i['cocoid'], json_dump(caps) 57 | tsv_writer(gen_cap_rows(), 'data/coco_caption/test.caption.tsv') 58 | 59 | if __name__ == '__main__': 60 | init_logging() 61 | kwargs = parse_general_args() 62 | logging.info('param:\n{}'.format(pformat(kwargs))) 63 | function_name = kwargs['type'] 64 | del kwargs['type'] 65 | locals()[function_name](**kwargs) 66 | 67 | -------------------------------------------------------------------------------- /GenerativeImage2Text/layers/CLIP/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/GenerativeImage2Text/layers/CLIP/__init__.py -------------------------------------------------------------------------------- /GenerativeImage2Text/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/GenerativeImage2Text/layers/__init__.py -------------------------------------------------------------------------------- /GenerativeImage2Text/layers/bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | from .modeling_bert import BertConfig 3 | 4 | -------------------------------------------------------------------------------- /GenerativeImage2Text/layers/bert/activations.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import logging 7 | 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | def swish(x): 13 | return x * torch.sigmoid(x) 14 | 15 | 16 | def _gelu_python(x): 17 | """ 18 | Original Implementation of the gelu activation function in Google Bert repo when initially created. For 19 | information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + 20 | torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in 21 | torch.nn.functional Also see https://arxiv.org/abs/1606.08415 22 | """ 23 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 24 | 25 | 26 | def gelu_new(x): 27 | """ 28 | Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). Also see 29 | https://arxiv.org/abs/1606.08415 30 | """ 31 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 32 | 33 | 34 | if torch.__version__ < "1.4.0": 35 | gelu = _gelu_python 36 | else: 37 | gelu = F.gelu 38 | 39 | 40 | def gelu_fast(x): 41 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) 42 | 43 | 44 | def mish(x): 45 | return x * torch.tanh(torch.nn.functional.softplus(x)) 46 | 47 | 48 | def linear_act(x): 49 | return x 50 | 51 | 52 | ACT2FN = { 53 | "relu": F.relu, 54 | "swish": swish, 55 | #"gelu": gelu, 56 | 'gelu': _gelu_python, 57 | "tanh": torch.tanh, 58 | "gelu_new": gelu_new, 59 | "gelu_fast": gelu_fast, 60 | "mish": mish, 61 | "linear": linear_act, 62 | "sigmoid": torch.sigmoid, 63 | } 64 | 65 | 66 | def get_activation(activation_string): 67 | if activation_string in ACT2FN: 68 | return ACT2FN[activation_string] 69 | else: 70 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys()))) 71 | -------------------------------------------------------------------------------- /GenerativeImage2Text/model.py: -------------------------------------------------------------------------------- 1 | from .torch_common import resize_2d_pos_embed 2 | import torch 3 | from .layers.CLIP import clip 4 | from .layers.decoder import CaptioningModel 5 | from .layers.decoder import (TransformerDecoderTextualHead, 6 | AutoRegressiveBeamSearch, GeneratorWithBeamSearch) 7 | 8 | 9 | def get_git_model(tokenizer, param): 10 | image_encoder = get_image_encoder( 11 | param.get('image_encoder_type', 'CLIPViT_B_16'), 12 | input_resolution=param.get('test_crop_size', 224), 13 | ) 14 | text_decoder = TransformerDecoderTextualHead( 15 | visual_feature_size=param.get('visual_feature_size', 768), 16 | vocab_size=30522, 17 | hidden_size=768, 18 | num_layers=6, 19 | attention_heads=12, 20 | feedforward_size=768* 4, 21 | max_caption_length=1024, 22 | mask_future_positions=True, 23 | padding_idx=0, 24 | decoder_type='bert_en', 25 | visual_projection_type='linearLn', 26 | ) 27 | #decoder = AutoRegressiveBeamSearch( 28 | #eos_index=tokenizer.sep_token_id, 29 | #max_steps=40, 30 | #beam_size=1, 31 | #per_node_beam_size=1, 32 | #fix_missing_prefix=True, 33 | #) 34 | decoder = GeneratorWithBeamSearch( 35 | eos_index=tokenizer.sep_token_id, 36 | #max_steps=40, 37 | max_steps=1024, 38 | beam_size=4, 39 | length_penalty=0.6, 40 | ) 41 | 42 | #from .trie_decoder import TrieAutoRegressiveBeamSearch, get_trie 43 | #decoder = TrieAutoRegressiveBeamSearch( 44 | 
#eos_index=tokenizer.sep_token_id, 45 | #max_steps=1022, 46 | #beam_size=1, 47 | #trie=get_trie(tokenizer), 48 | #) 49 | 50 | model = CaptioningModel( 51 | image_encoder, 52 | text_decoder, 53 | decoder=decoder, 54 | sos_index=tokenizer.cls_token_id, 55 | eos_index=tokenizer.sep_token_id, 56 | tokenizer=tokenizer, 57 | use_history_for_infer=True, 58 | loss_type='smooth', 59 | num_image_with_embedding=param.get('num_image_with_embedding') 60 | ) 61 | return model 62 | 63 | def get_image_encoder(encoder_type, input_resolution=224): 64 | name_map = { 65 | 'CLIPViT_B_16': 'ViT-B/16', 66 | 'CLIPViT_L_14': 'ViT-L/14', 67 | } 68 | name_in_clip = name_map[encoder_type] 69 | model, _ = clip.load(name_in_clip, device='cpu', jit=False) 70 | model = model.train() 71 | ret = model.visual 72 | ret.to(torch.float32) 73 | ret.output_grid = True 74 | ret.grid_after_ln = True 75 | if ret.input_resolution != input_resolution: 76 | if encoder_type in ['CLIPViT_B_16', 'CLIPViT_L_14']: 77 | pos = ret.positional_embedding 78 | patch_size = ret.conv1.kernel_size[0] 79 | else: 80 | pos = ret.attnpool.positional_embedding 81 | patch_size = 32 82 | p2 = resize_2d_pos_embed(pos, 83 | ret.input_resolution, 84 | patch_size, 85 | input_resolution) 86 | ret.input_resolution = input_resolution 87 | if encoder_type in ['CLIPViT_B_16', 'CLIPViT_L_14']: 88 | ret.positional_embedding = torch.nn.Parameter(p2) 89 | else: 90 | ret.attnpool.positional_embedding = torch.nn.Parameter(p2) 91 | return ret 92 | 93 | -------------------------------------------------------------------------------- /LAST/grounding_gen/nets_grd_gen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from visual_net import resnet18 8 | 9 | 10 | class AVQA_AVatt_Grounding(nn.Module): 11 | 12 | def __init__(self): 13 | super(AVQA_AVatt_Grounding, self).__init__() 14 | 15 | # for features 16 | self.fc_a1 = nn.Linear(128, 512) 17 | self.fc_a2=nn.Linear(512,512) 18 | 19 | # visual 20 | self.visual_net = resnet18(pretrained=True) 21 | 22 | # combine 23 | self.fc1 = nn.Linear(1024, 512) 24 | self.relu1 = nn.ReLU() 25 | self.fc2 = nn.Linear(512, 256) 26 | self.relu2 = nn.ReLU() 27 | self.fc3 = nn.Linear(256, 128) 28 | self.relu3 = nn.ReLU() 29 | self.fc4 = nn.Linear(128, 2) 30 | self.relu4 = nn.ReLU() 31 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 32 | 33 | self.fc_gl=nn.Linear(1024,512) 34 | self.tanh = nn.Tanh() 35 | 36 | 37 | def forward(self, video_id, audio, visual): 38 | 39 | ## audio features 40 | audio_feat = F.relu(self.fc_a1(audio)) 41 | audio_feat=self.fc_a2(audio_feat) # [16, 20, 512] 42 | (B, T, C) = audio_feat.size() 43 | audio_feat = audio_feat.view(B*T, C) # [320, 512] 44 | 45 | ## visual, input: [16, 20, 3, 224, 224] 46 | (B, T, C, H, W) = visual.size() 47 | visual = visual.view(B * T, C, H, W) # [320, 3, 224, 224] 48 | 49 | v_feat_out_res18 = self.visual_net(visual) # [320, 512, 14, 14] 50 | v_feat=self.avgpool(v_feat_out_res18) 51 | visual_feat_before_grounding=v_feat.squeeze() # 320 512 52 | 53 | (B, C, H, W) = v_feat_out_res18.size() 54 | v_feat = v_feat_out_res18.view(B, C, H * W) 55 | v_feat = v_feat.permute(0, 2, 1) # B, HxW, C 56 | visual = nn.functional.normalize(v_feat, dim=2) 57 | 58 | ## audio-visual grounding 59 | audio_feat_aa = audio_feat.unsqueeze(-1) # [320, 512, 1] 60 | audio_feat_aa = nn.functional.normalize(audio_feat_aa, dim=1) 61 | visual_feat 
= visual 62 | x2_va = torch.matmul(visual_feat, audio_feat_aa).squeeze() 63 | 64 | x2_p = F.softmax(x2_va, dim=-1).unsqueeze(-2) # [320, 1, 196] 65 | visual_feat_grd = torch.matmul(x2_p, visual_feat) 66 | visual_feat_grd = visual_feat_grd.squeeze() # [320, 512] 67 | 68 | visual_gl=torch.cat((visual_feat_before_grounding,visual_feat_grd),dim=-1) 69 | visual_feat_grd=self.tanh(visual_gl) 70 | visual_feat_grd=self.fc_gl(visual_feat_grd) 71 | 72 | # combine a and v 73 | feat = torch.cat((audio_feat, visual_feat_grd), dim=-1) # [320, 1024] 74 | 75 | feat = F.relu(self.fc1(feat)) # (1024, 512) 76 | feat = F.relu(self.fc2(feat)) # (512, 256) 77 | feat = F.relu(self.fc3(feat)) # (256, 128) 78 | feat = self.fc4(feat) # (128, 2) 79 | 80 | return feat 81 | -------------------------------------------------------------------------------- /LAST/grounding_gen/nets_grd_gen_vis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from visual_net import resnet18 8 | 9 | 10 | class AVQA_AVatt_Grounding(nn.Module): 11 | 12 | def __init__(self): 13 | super(AVQA_AVatt_Grounding, self).__init__() 14 | 15 | # for features 16 | self.fc_a1 = nn.Linear(128, 512) 17 | self.fc_a2=nn.Linear(512,512) 18 | 19 | # visual 20 | self.visual_net = resnet18(pretrained=True) 21 | 22 | # combine 23 | self.fc1 = nn.Linear(1024, 512) 24 | self.relu1 = nn.ReLU() 25 | self.fc2 = nn.Linear(512, 256) 26 | self.relu2 = nn.ReLU() 27 | self.fc3 = nn.Linear(256, 128) 28 | self.relu3 = nn.ReLU() 29 | self.fc4 = nn.Linear(128, 2) 30 | self.relu4 = nn.ReLU() 31 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 32 | 33 | self.fc_gl=nn.Linear(1024,512) 34 | self.tanh = nn.Tanh() 35 | 36 | 37 | def forward(self, video_id, audio, visual): 38 | 39 | ## audio features 40 | audio_feat = F.relu(self.fc_a1(audio)) 41 | audio_feat=self.fc_a2(audio_feat) # [16, 20, 512] 42 | (B, T, C) = audio_feat.size() 43 | audio_feat = audio_feat.view(B*T, C) # [320, 512] 44 | 45 | ## visual, input: [16, 20, 3, 224, 224] 46 | (B, T, C, H, W) = visual.size() 47 | visual = visual.view(B * T, C, H, W) # [320, 3, 224, 224] 48 | 49 | v_feat_out_res18 = self.visual_net(visual) # [320, 512, 14, 14] 50 | v_feat=self.avgpool(v_feat_out_res18) 51 | visual_feat_before_grounding=v_feat.squeeze() # 320 512 52 | 53 | (B, C, H, W) = v_feat_out_res18.size() 54 | v_feat = v_feat_out_res18.view(B, C, H * W) 55 | v_feat = v_feat.permute(0, 2, 1) # B, HxW, C 56 | visual = nn.functional.normalize(v_feat, dim=2) 57 | 58 | ## audio-visual grounding 59 | audio_feat_aa = audio_feat.unsqueeze(-1) # [320, 512, 1] 60 | audio_feat_aa = nn.functional.normalize(audio_feat_aa, dim=1) 61 | visual_feat = visual 62 | x2_va = torch.matmul(visual_feat, audio_feat_aa).squeeze() 63 | 64 | x2_p = F.softmax(x2_va, dim=-1).unsqueeze(-2) # [320, 1, 196] 65 | visual_feat_grd = torch.matmul(x2_p, visual_feat) 66 | visual_feat_grd = visual_feat_grd.squeeze() # [320, 512] 67 | 68 | visual_gl=torch.cat((visual_feat_before_grounding,visual_feat_grd),dim=-1) 69 | visual_feat_grd=self.tanh(visual_gl) 70 | visual_feat_grd=self.fc_gl(visual_feat_grd) 71 | 72 | # combine a and v 73 | feat = torch.cat((audio_feat, visual_feat_grd), dim=-1) # [320, 1024] 74 | 75 | feat = F.relu(self.fc1(feat)) # (1024, 512) 76 | feat = F.relu(self.fc2(feat)) # (512, 256) 77 | feat = F.relu(self.fc3(feat)) # (256, 128) 78 | feat = self.fc4(feat) # (128, 2) 79 | 80 
| return x2_p, feat 81 | -------------------------------------------------------------------------------- /LAST/net_grd_avst/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/LAST/net_grd_avst/__init__.py -------------------------------------------------------------------------------- /LAST/net_grd_avst/alvs_inter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def sample_vectors(f_v, k): 5 | """ 6 | Uniformly sample k vectors from f_v. 7 | """ 8 | t = f_v.size(0) 9 | indices = torch.linspace(0, t - 1, steps=k).long() 10 | sampled_vectors = f_v[indices] 11 | return sampled_vectors, indices 12 | 13 | def cosine_similarity(a, b): 14 | """ 15 | Calculate the cosine similarity between two tensors a and b. 16 | """ 17 | return F.cosine_similarity(a, b, dim=-1) 18 | 19 | def calculate_attention(f_qst, f_v_sampled): 20 | """ 21 | Calculate attention values 22 | """ 23 | q = f_qst.unsqueeze(0) # Shape: [1, d] 24 | k = f_v_sampled # Shape: [k, d] 25 | v = f_v_sampled # Shape: [k, d] 26 | 27 | # Attention calculation: att_weights = softmax(q * k^T / sqrt(d)) 28 | d = f_qst.size(-1) 29 | att_weights = F.softmax(torch.matmul(q, k.T) / (d ** 0.5), dim=-1) 30 | att_f_v = torch.matmul(att_weights, v) # Shape: [1, d] 31 | return att_f_v.squeeze(0) # Shape: [d] 32 | 33 | def process_numbers(number_list): 34 | unique_numbers = sorted(set(int(num) for num in number_list)) 35 | return unique_numbers 36 | 37 | def iterative_sampling(f_v, f_text, k, m, a1, a2): 38 | """ 39 | Perform the iterative sampling process. 40 | """ 41 | t, d = f_v.shape 42 | indices_record = [] 43 | iter_samples = [0] 44 | 45 | 46 | for _ in range(m): 47 | # Uniformly sample k vectors from f_v 48 | f_v_sampled, sampled_indices = sample_vectors(f_v, k) 49 | sampled_indices += sum(iter_samples) 50 | 51 | # Compute cosine similarity between consecutive vectors 52 | sim1 = cosine_similarity(f_v_sampled[:-1], f_v_sampled[1:]) 53 | sim1 = torch.cat([sim1, sim1[-1].unsqueeze(0)]) # Ensure last and second-last are the same 54 | 55 | # Calculate att_f_v 56 | att_f_v = torch.stack([calculate_attention(f_text, f_v_sampled[i].unsqueeze(0)) for i in range(k)]) 57 | 58 | # Compute cosine similarity between consecutive att_f_v vectors 59 | sim2 = cosine_similarity(att_f_v[:-1], att_f_v[1:]) 60 | sim2 = torch.cat([sim2, sim2[-1].unsqueeze(0)]) # Ensure last and second-last are the same 61 | 62 | # Sum and find the max index 63 | sim = a1 * sim1 + a2 * sim2 64 | max_sim_index = torch.argmax(sim) 65 | 66 | # Use max_sim_index as center, select new vectors 67 | center_idx = sampled_indices[max_sim_index] 68 | start_idx = max(0, center_idx - t // k) 69 | end_idx = min(t, center_idx + t // k) 70 | iter_samples.append(start_idx) 71 | f_v = f_v[start_idx:end_idx] 72 | indices_record += list(sampled_indices) 73 | indices_record = process_numbers(indices_record) 74 | 75 | return indices_record -------------------------------------------------------------------------------- /LAVISH/grounding_gen/nets_grd_gen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from visual_net import resnet18 8 | 9 | 10 | class AVQA_AVatt_Grounding(nn.Module): 11 | 12 
| def __init__(self): 13 | super(AVQA_AVatt_Grounding, self).__init__() 14 | 15 | # for features 16 | self.fc_a1 = nn.Linear(128, 512) 17 | self.fc_a2=nn.Linear(512,512) 18 | 19 | # visual 20 | self.visual_net = resnet18(pretrained=True) 21 | 22 | # combine 23 | self.fc1 = nn.Linear(1024, 512) 24 | self.relu1 = nn.ReLU() 25 | self.fc2 = nn.Linear(512, 256) 26 | self.relu2 = nn.ReLU() 27 | self.fc3 = nn.Linear(256, 128) 28 | self.relu3 = nn.ReLU() 29 | self.fc4 = nn.Linear(128, 2) 30 | self.relu4 = nn.ReLU() 31 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 32 | 33 | self.fc_gl=nn.Linear(1024,512) 34 | self.tanh = nn.Tanh() 35 | 36 | 37 | def forward(self, video_id, audio, visual): 38 | 39 | ## audio features 40 | audio_feat = F.relu(self.fc_a1(audio)) 41 | audio_feat=self.fc_a2(audio_feat) # [16, 20, 512] 42 | (B, T, C) = audio_feat.size() 43 | audio_feat = audio_feat.view(B*T, C) # [320, 512] 44 | 45 | ## visual, input: [16, 20, 3, 224, 224] 46 | (B, T, C, H, W) = visual.size() 47 | visual = visual.view(B * T, C, H, W) # [320, 3, 224, 224] 48 | 49 | v_feat_out_res18 = self.visual_net(visual) # [320, 512, 14, 14] 50 | v_feat=self.avgpool(v_feat_out_res18) 51 | visual_feat_before_grounding=v_feat.squeeze() # 320 512 52 | 53 | (B, C, H, W) = v_feat_out_res18.size() 54 | v_feat = v_feat_out_res18.view(B, C, H * W) 55 | v_feat = v_feat.permute(0, 2, 1) # B, HxW, C 56 | visual = nn.functional.normalize(v_feat, dim=2) 57 | 58 | ## audio-visual grounding 59 | audio_feat_aa = audio_feat.unsqueeze(-1) # [320, 512, 1] 60 | audio_feat_aa = nn.functional.normalize(audio_feat_aa, dim=1) 61 | visual_feat = visual 62 | x2_va = torch.matmul(visual_feat, audio_feat_aa).squeeze() 63 | 64 | x2_p = F.softmax(x2_va, dim=-1).unsqueeze(-2) # [320, 1, 196] 65 | visual_feat_grd = torch.matmul(x2_p, visual_feat) 66 | visual_feat_grd = visual_feat_grd.squeeze() # [320, 512] 67 | 68 | visual_gl=torch.cat((visual_feat_before_grounding,visual_feat_grd),dim=-1) 69 | visual_feat_grd=self.tanh(visual_gl) 70 | visual_feat_grd=self.fc_gl(visual_feat_grd) 71 | 72 | # combine a and v 73 | feat = torch.cat((audio_feat, visual_feat_grd), dim=-1) # [320, 1024] 74 | 75 | feat = F.relu(self.fc1(feat)) # (1024, 512) 76 | feat = F.relu(self.fc2(feat)) # (512, 256) 77 | feat = F.relu(self.fc3(feat)) # (256, 128) 78 | feat = self.fc4(feat) # (128, 2) 79 | 80 | return feat 81 | -------------------------------------------------------------------------------- /LAVISH/grounding_gen/nets_grd_gen_vis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from visual_net import resnet18 8 | 9 | 10 | class AVQA_AVatt_Grounding(nn.Module): 11 | 12 | def __init__(self): 13 | super(AVQA_AVatt_Grounding, self).__init__() 14 | 15 | # for features 16 | self.fc_a1 = nn.Linear(128, 512) 17 | self.fc_a2=nn.Linear(512,512) 18 | 19 | # visual 20 | self.visual_net = resnet18(pretrained=True) 21 | 22 | # combine 23 | self.fc1 = nn.Linear(1024, 512) 24 | self.relu1 = nn.ReLU() 25 | self.fc2 = nn.Linear(512, 256) 26 | self.relu2 = nn.ReLU() 27 | self.fc3 = nn.Linear(256, 128) 28 | self.relu3 = nn.ReLU() 29 | self.fc4 = nn.Linear(128, 2) 30 | self.relu4 = nn.ReLU() 31 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 32 | 33 | self.fc_gl=nn.Linear(1024,512) 34 | self.tanh = nn.Tanh() 35 | 36 | 37 | def forward(self, video_id, audio, visual): 38 | 39 | ## audio features 40 | audio_feat = 
F.relu(self.fc_a1(audio)) 41 | audio_feat=self.fc_a2(audio_feat) # [16, 20, 512] 42 | (B, T, C) = audio_feat.size() 43 | audio_feat = audio_feat.view(B*T, C) # [320, 512] 44 | 45 | ## visual, input: [16, 20, 3, 224, 224] 46 | (B, T, C, H, W) = visual.size() 47 | visual = visual.view(B * T, C, H, W) # [320, 3, 224, 224] 48 | 49 | v_feat_out_res18 = self.visual_net(visual) # [320, 512, 14, 14] 50 | v_feat=self.avgpool(v_feat_out_res18) 51 | visual_feat_before_grounding=v_feat.squeeze() # 320 512 52 | 53 | (B, C, H, W) = v_feat_out_res18.size() 54 | v_feat = v_feat_out_res18.view(B, C, H * W) 55 | v_feat = v_feat.permute(0, 2, 1) # B, HxW, C 56 | visual = nn.functional.normalize(v_feat, dim=2) 57 | 58 | ## audio-visual grounding 59 | audio_feat_aa = audio_feat.unsqueeze(-1) # [320, 512, 1] 60 | audio_feat_aa = nn.functional.normalize(audio_feat_aa, dim=1) 61 | visual_feat = visual 62 | x2_va = torch.matmul(visual_feat, audio_feat_aa).squeeze() 63 | 64 | x2_p = F.softmax(x2_va, dim=-1).unsqueeze(-2) # [320, 1, 196] 65 | visual_feat_grd = torch.matmul(x2_p, visual_feat) 66 | visual_feat_grd = visual_feat_grd.squeeze() # [320, 512] 67 | 68 | visual_gl=torch.cat((visual_feat_before_grounding,visual_feat_grd),dim=-1) 69 | visual_feat_grd=self.tanh(visual_gl) 70 | visual_feat_grd=self.fc_gl(visual_feat_grd) 71 | 72 | # combine a and v 73 | feat = torch.cat((audio_feat, visual_feat_grd), dim=-1) # [320, 1024] 74 | 75 | feat = F.relu(self.fc1(feat)) # (1024, 512) 76 | feat = F.relu(self.fc2(feat)) # (512, 256) 77 | feat = F.relu(self.fc3(feat)) # (256, 128) 78 | feat = self.fc4(feat) # (128, 2) 79 | 80 | return x2_p, feat 81 | -------------------------------------------------------------------------------- /LAVISH/net_grd_avst/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/LAVISH/net_grd_avst/__init__.py -------------------------------------------------------------------------------- /LAVISH/net_grd_avst/alvs_inter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def sample_vectors(f_v, k): 5 | """ 6 | Uniformly sample k vectors from f_v. 7 | """ 8 | t = f_v.size(0) 9 | indices = torch.linspace(0, t - 1, steps=k).long() 10 | sampled_vectors = f_v[indices] 11 | return sampled_vectors, indices 12 | 13 | def cosine_similarity(a, b): 14 | """ 15 | Calculate the cosine similarity between two tensors a and b. 16 | """ 17 | return F.cosine_similarity(a, b, dim=-1) 18 | 19 | def calculate_attention(f_qst, f_v_sampled): 20 | """ 21 | Calculate attention values 22 | """ 23 | q = f_qst.unsqueeze(0) # Shape: [1, d] 24 | k = f_v_sampled # Shape: [k, d] 25 | v = f_v_sampled # Shape: [k, d] 26 | 27 | # Attention calculation: att_weights = softmax(q * k^T / sqrt(d)) 28 | d = f_qst.size(-1) 29 | att_weights = F.softmax(torch.matmul(q, k.T) / (d ** 0.5), dim=-1) 30 | att_f_v = torch.matmul(att_weights, v) # Shape: [1, d] 31 | return att_f_v.squeeze(0) # Shape: [d] 32 | 33 | def process_numbers(number_list): 34 | unique_numbers = sorted(set(int(num) for num in number_list)) 35 | return unique_numbers 36 | 37 | def iterative_sampling(f_v, f_text, k, m, a1, a2): 38 | """ 39 | Perform the iterative sampling process. 
40 | """ 41 | t, d = f_v.shape 42 | indices_record = [] 43 | iter_samples = [0] 44 | 45 | 46 | for _ in range(m): 47 | # Uniformly sample k vectors from f_v 48 | f_v_sampled, sampled_indices = sample_vectors(f_v, k) 49 | sampled_indices += sum(iter_samples) 50 | 51 | # Compute cosine similarity between consecutive vectors 52 | sim1 = cosine_similarity(f_v_sampled[:-1], f_v_sampled[1:]) 53 | sim1 = torch.cat([sim1, sim1[-1].unsqueeze(0)]) # Ensure last and second-last are the same 54 | 55 | # Calculate att_f_v 56 | att_f_v = torch.stack([calculate_attention(f_text, f_v_sampled[i].unsqueeze(0)) for i in range(k)]) 57 | 58 | # Compute cosine similarity between consecutive att_f_v vectors 59 | sim2 = cosine_similarity(att_f_v[:-1], att_f_v[1:]) 60 | sim2 = torch.cat([sim2, sim2[-1].unsqueeze(0)]) # Ensure last and second-last are the same 61 | 62 | # Sum and find the max index 63 | sim = a1 * sim1 + a2 * sim2 64 | max_sim_index = torch.argmax(sim) 65 | 66 | # Use max_sim_index as center, select new vectors 67 | center_idx = sampled_indices[max_sim_index] 68 | start_idx = max(0, center_idx - t // k) 69 | end_idx = min(t, center_idx + t // k) 70 | iter_samples.append(start_idx) 71 | f_v = f_v[start_idx:end_idx] 72 | indices_record += list(sampled_indices) 73 | indices_record = process_numbers(indices_record) 74 | 75 | return indices_record -------------------------------------------------------------------------------- /MovieSeq/movieseq.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import json 4 | import imageio 5 | import requests 6 | import argparse 7 | from tqdm import tqdm 8 | from moviepy.editor import VideoFileClip 9 | from utils import encode_image 10 | 11 | import openai 12 | from openai import OpenAI 13 | 14 | class MovieSeq: 15 | def __init__(self, 16 | model="gpt-4o", api_key=None, image_detail="auto", 17 | system_text=None): 18 | self.api_key = api_key 19 | self.model = model 20 | self.image_detail = image_detail 21 | if system_text is None: 22 | self.system_text = """ 23 | You will be provided with the following inputs: 24 | 1. A sequence of photos of characters along with their names. 25 | 2. Keyframes from a video clip and the corresponding dialogues, each associated with a speaker ID. 26 | 27 | Your task is to analyze and associate these inputs, understand the context of the video, and respond to the user's needs accordingly. 
28 | """ 29 | 30 | self.headers = { 31 | "Content-Type": "application/json", 32 | "Authorization": f"Bearer {self.api_key}" 33 | } 34 | self.url = "https://api.openai.com/v1/chat/completions" 35 | self.client = OpenAI() 36 | 37 | def get_response(self, char_bank, frame_list, diag_list, 38 | query, 39 | resize=None, temperature=0, detail="auto"): 40 | messages = [{ 41 | "role": "system", 42 | "content": [{"type": "text", "text": self.system_text,},] 43 | }] 44 | 45 | for char_name, char_url in char_bank.items(): 46 | char_image = encode_image(char_url) 47 | messages.append({ 48 | "role": "user", 49 | "content": [ 50 | f"This is the photo of {char_name}.", 51 | {'image': char_image}, 52 | ], 53 | }) 54 | 55 | assert len(diag_list) == len(frame_list) 56 | for frame_i, diag_i in zip(frame_list, diag_list): 57 | frame_image = encode_image(frame_i) 58 | messages.append({ 59 | "role": "user", 60 | "content": [ 61 | {'image': frame_image}, 62 | f"{diag_i}.", 63 | ], 64 | }) 65 | 66 | messages.append({ 67 | "role": "user", 68 | "content": [{"type": "text", "text": query,},] 69 | }) 70 | 71 | params = { 72 | "model": self.model, 73 | "messages": messages, 74 | "max_tokens": 2048, 75 | "temperature": temperature, 76 | } 77 | 78 | response = self.client.chat.completions.create(**params) 79 | json_string = response.json() 80 | json_object = json.loads(json_string) 81 | content = json_object['choices'][0]['message']['content'] 82 | return content -------------------------------------------------------------------------------- /MovieSeq/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import math 5 | import base64 6 | from PIL import Image 7 | 8 | # Function to encode the image 9 | def encode_image(image_path): 10 | with open(image_path, "rb") as image_file: 11 | return base64.b64encode(image_file.read()).decode('utf-8') 12 | 13 | def video2frame(video_path, save_dir=None, num_frames=4): 14 | if not os.path.exists(save_dir): 15 | os.makedirs(save_dir) 16 | 17 | cap = cv2.VideoCapture(video_path) 18 | if not cap.isOpened(): 19 | print("Error opening video file") 20 | return [] 21 | 22 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 23 | interval = total_frames // num_frames 24 | 25 | idx = 0 26 | saved = [] 27 | while True: 28 | ret, frame = cap.read() 29 | if not ret: 30 | print("Can't receive frame (stream end?). 
Exiting ...") 31 | break 32 | 33 | if idx % interval == 0 and len(saved) < num_frames: 34 | path = os.path.join(save_dir, f"{idx}.jpg") 35 | 36 | rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 37 | img_pil = Image.fromarray(rgb_frame) 38 | 39 | cv2.imwrite(path, frame) 40 | # print(f'Saved frame {idx} at {path}') 41 | saved.append(path) 42 | 43 | if len(saved) >= num_frames: 44 | break 45 | 46 | idx += 1 47 | 48 | cap.release() 49 | print('Done saving frames.') 50 | return saved 51 | 52 | def read_txt(file_path): 53 | try: 54 | with open(file_path, 'r') as file: 55 | text = file.read() 56 | return text 57 | except FileNotFoundError: 58 | print(f"Error: File '{file_path}' not found.") 59 | return None 60 | 61 | def read_json(json_url): 62 | with open(json_url, 'r') as file: 63 | data = json.load(file) 64 | return data 65 | 66 | def write_json(data, save_url, indent=4): 67 | with open(save_url, 'w') as file: 68 | json.dump(data, file, indent=indent) 69 | return save_url 70 | 71 | def read_task(task_json): 72 | task_list = [] 73 | for x in task_json: 74 | task_list.append(x['text']) 75 | task_str = "\n".join([f"{i+1}. {task}" for i, task in enumerate(task_list)]) 76 | return task_str 77 | 78 | def print_score(procedure_eval): 79 | num_seq = len(procedure_eval) 80 | score_dict = {'vis': 0, 'txt': 0, 'vis_txt': 0} 81 | for x in procedure_eval: 82 | for mode in ['vis', 'txt', 'vis_txt']: 83 | score_dict[mode] += x[mode]['score'] 84 | 85 | for mode in ['vis', 'txt', 'vis_txt']: 86 | score_dict[mode] /= num_seq 87 | return score_dict 88 | 89 | def point_to_rect(point, pred): 90 | # point: [x,y] 91 | # rect: [[x1,y1,x2,y2]] 92 | x, y = point 93 | x1, y1, x2, y2 = pred[0] 94 | center_x = (x1 + x2) / 2 95 | center_y = (y1 + y2) / 2 96 | distance = math.sqrt((x - center_x) ** 2 + (y - center_y) ** 2) 97 | inside = x1 <= x <= x2 and y1 <= y <= y2 98 | return distance, inside 99 | 100 | def point_to_point(point, pred): 101 | # point: [x,y] 102 | x, y = point 103 | center_x, center_y = pred 104 | distance = math.sqrt((x - center_x) ** 2 + (y - center_y) ** 2) 105 | return distance -------------------------------------------------------------------------------- /NarrativeBridge/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/NarrativeBridge/__init__.py -------------------------------------------------------------------------------- /NarrativeBridge/alvs_inter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def sample_vectors(f_v, k): 5 | """ 6 | Uniformly sample k vectors from f_v. 7 | """ 8 | t = f_v.size(0) 9 | indices = torch.linspace(0, t - 1, steps=k).long() 10 | sampled_vectors = f_v[indices] 11 | return sampled_vectors, indices 12 | 13 | def cosine_similarity(a, b): 14 | """ 15 | Calculate the cosine similarity between two tensors a and b. 
16 | """ 17 | return F.cosine_similarity(a, b, dim=-1) 18 | 19 | def calculate_attention(f_qst, f_v_sampled): 20 | """ 21 | Calculate attention values 22 | """ 23 | q = f_qst.unsqueeze(0) # Shape: [1, d] 24 | k = f_v_sampled # Shape: [k, d] 25 | v = f_v_sampled # Shape: [k, d] 26 | 27 | # Attention calculation: att_weights = softmax(q * k^T / sqrt(d)) 28 | d = f_qst.size(-1) 29 | att_weights = F.softmax(torch.matmul(q, k.T) / (d ** 0.5), dim=-1) 30 | att_f_v = torch.matmul(att_weights, v) # Shape: [1, d] 31 | return att_f_v.squeeze(0) # Shape: [d] 32 | 33 | def process_numbers(number_list): 34 | unique_numbers = sorted(set(int(num) for num in number_list)) 35 | return unique_numbers 36 | 37 | def iterative_sampling(f_v, f_text, k, m, a1, a2): 38 | """ 39 | Perform the iterative sampling process. 40 | """ 41 | t, d = f_v.shape 42 | indices_record = [] 43 | iter_samples = [0] 44 | 45 | 46 | for _ in range(m): 47 | # Uniformly sample k vectors from f_v 48 | f_v_sampled, sampled_indices = sample_vectors(f_v, k) 49 | sampled_indices += sum(iter_samples) 50 | 51 | # Compute cosine similarity between consecutive vectors 52 | sim1 = cosine_similarity(f_v_sampled[:-1], f_v_sampled[1:]) 53 | sim1 = torch.cat([sim1, sim1[-1].unsqueeze(0)]) # Ensure last and second-last are the same 54 | 55 | # Calculate att_f_v 56 | att_f_v = torch.stack([calculate_attention(f_text, f_v_sampled[i].unsqueeze(0)) for i in range(k)]) 57 | 58 | # Compute cosine similarity between consecutive att_f_v vectors 59 | sim2 = cosine_similarity(att_f_v[:-1], att_f_v[1:]) 60 | sim2 = torch.cat([sim2, sim2[-1].unsqueeze(0)]) # Ensure last and second-last are the same 61 | 62 | # Sum and find the max index 63 | sim = a1 * sim1 + a2 * sim2 64 | max_sim_index = torch.argmax(sim) 65 | 66 | # Use max_sim_index as center, select new vectors 67 | center_idx = sampled_indices[max_sim_index] 68 | start_idx = max(0, center_idx - t // k) 69 | end_idx = min(t, center_idx + t // k) 70 | iter_samples.append(start_idx) 71 | f_v = f_v[start_idx:end_idx] 72 | indices_record += list(sampled_indices) 73 | indices_record = process_numbers(indices_record) 74 | 75 | return indices_record -------------------------------------------------------------------------------- /NarrativeBridge/data_layer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/NarrativeBridge/data_layer/__init__.py -------------------------------------------------------------------------------- /NarrativeBridge/data_layer/builder.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import default_collate 2 | import torch 3 | 4 | 5 | def collate_fn(batch): 6 | # this function is designed to support any customized type and to be compatible 7 | # with the default collate function 8 | ele = batch[0] 9 | if isinstance(ele, dict): 10 | return {key: collate_fn([d[key] for d in batch]) for key in ele} 11 | elif isinstance(ele, (tuple, list)): 12 | return [collate_fn(x) for x in zip(*batch)] 13 | else: 14 | if all(isinstance(b, torch.Tensor) for b in batch) and len(batch) > 0: 15 | if not all(b.shape == batch[0].shape for b in batch[1:]): 16 | assert all(len(b.shape) == len(batch[0].shape) for b in batch[1:]) 17 | shape = torch.tensor([b.shape for b in batch]) 18 | max_shape = tuple(shape.max(dim=0)[0].tolist()) 19 | batch2 = [] 20 | for b in batch: 21 | if any(c < m for c, m in 
zip(b.shape, max_shape)): 22 | b2 = torch.zeros(max_shape, dtype=b.dtype, device=b.device) 23 | if b.dim() == 1: 24 | b2[:b.shape[0]] = b 25 | elif b.dim() == 2: 26 | b2[:b.shape[0], :b.shape[1]] = b 27 | elif b.dim() == 3: 28 | b2[:b.shape[0], :b.shape[1], :b.shape[2]] = b 29 | else: 30 | raise NotImplementedError 31 | b = b2 32 | batch2.append(b) 33 | batch = batch2 34 | return default_collate(batch) 35 | 36 | 37 | -------------------------------------------------------------------------------- /NarrativeBridge/data_prepare.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os.path as op 3 | from pprint import pformat 4 | from .common import parse_general_args, json_dump 5 | from .common import qd_tqdm as tqdm 6 | import logging 7 | from .common import load_list_file, read_to_buffer 8 | import json 9 | from .common import init_logging, hash_sha1, write_to_file 10 | from .taxonomy import noffset_to_synset, get_nick_name 11 | from .tsv_io import tsv_writer 12 | 13 | 14 | def get_imagenet_unique_nick_names(): 15 | txt = './aux_data/imagenet/LOC_synset_mapping.txt' 16 | noffsets = load_list_file(txt) 17 | noffsets = [x.split(' ')[0] for x in noffsets] 18 | assert hash_sha1(noffsets) == 'fb9737bbca048296520bc35582947b3755aa948f' 19 | nick_name_overwrite = { 20 | 'n02012849': 'crane bird', 21 | 'n03126707': 'crane machine', 22 | 'n02113186': 'cardigan dog', 23 | 'n02963159': 'cardigan jacket', 24 | 'n03710637': 'maillot tights', 25 | 'n03710721': 'maillot bathing suit', 26 | } 27 | nick_names = [nick_name_overwrite[n] if n in nick_name_overwrite else 28 | get_nick_name(noffset_to_synset(n)) for n in noffsets] 29 | assert hash_sha1(nick_names) == '9c1dd12d7e8120820ffd44b75ebe8b78b659a4f4' 30 | assert len(set(nick_names)) == len(nick_names) 31 | assert len(set(map(lambda n: n.replace(' ', ''), nick_names))) == len(nick_names) 32 | return nick_names 33 | 34 | def generate_imagenet_unique_names(): 35 | nick_names = get_imagenet_unique_nick_names() 36 | write_to_file('\n'.join(nick_names), 37 | './aux_data/imagenet/imagenet_unique_readable_names.txt') 38 | 39 | 40 | def prepare_coco_test(): 41 | image_folder = 'aux_data/raw_data/val2014' 42 | json_file = 'aux_data/raw_data/dataset_coco.json' 43 | infos = json.loads(read_to_buffer(json_file))['images'] 44 | infos = [i for i in infos if i['split'] == 'test'] 45 | assert all(i['filepath'] == 'val2014' for i in infos) 46 | def gen_rows(): 47 | for i in tqdm(infos): 48 | payload = base64.b64encode(read_to_buffer(op.join(image_folder, 49 | i['filename']))) 50 | yield i['cocoid'], payload 51 | tsv_writer(gen_rows(), 'data/coco_caption/test.img.tsv') 52 | 53 | def gen_cap_rows(): 54 | for i in tqdm(infos): 55 | caps = [{'caption': j['raw']} for j in i['sentences']] 56 | yield i['cocoid'], json_dump(caps) 57 | tsv_writer(gen_cap_rows(), 'data/coco_caption/test.caption.tsv') 58 | 59 | if __name__ == '__main__': 60 | init_logging() 61 | kwargs = parse_general_args() 62 | logging.info('param:\n{}'.format(pformat(kwargs))) 63 | function_name = kwargs['type'] 64 | del kwargs['type'] 65 | locals()[function_name](**kwargs) 66 | 67 | -------------------------------------------------------------------------------- /NarrativeBridge/layers/CLIP/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/NarrativeBridge/layers/CLIP/__init__.py 
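A minimal usage sketch for the padding-aware collate_fn defined in data_layer/builder.py above. The import path and the toy dataset are assumptions made purely for illustration.

# Sketch: exercising the ragged-tensor padding in collate_fn (assumed import path).
import torch
from torch.utils.data import DataLoader, Dataset

from NarrativeBridge.data_layer.builder import collate_fn  # assumed package layout


class ToyCaptionDataset(Dataset):
    # Each item is a dict; the 'caption_tokens' tensors have different lengths.
    def __init__(self):
        self.samples = [
            {'caption_tokens': torch.tensor([101, 2023, 102]), 'idx': 0},
            {'caption_tokens': torch.tensor([101, 2023, 2003, 1037, 102]), 'idx': 1},
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        return self.samples[i]


loader = DataLoader(ToyCaptionDataset(), batch_size=2, collate_fn=collate_fn)
batch = next(iter(loader))
# The shorter caption is zero-padded to the longest shape in the batch before
# default_collate stacks the tensors, so this prints torch.Size([2, 5]).
print(batch['caption_tokens'].shape)

Because collate_fn recurses into dicts, tuples, and lists, non-tensor fields such as the integer 'idx' above still fall through to the default collate behaviour.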
-------------------------------------------------------------------------------- /NarrativeBridge/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/NarrativeBridge/layers/__init__.py -------------------------------------------------------------------------------- /NarrativeBridge/layers/bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | from .modeling_bert import BertConfig 3 | 4 | -------------------------------------------------------------------------------- /NarrativeBridge/layers/bert/activations.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import logging 7 | 8 | 9 | logger = logging.getLogger() 10 | 11 | 12 | def swish(x): 13 | return x * torch.sigmoid(x) 14 | 15 | 16 | def _gelu_python(x): 17 | """ 18 | Original Implementation of the gelu activation function in Google Bert repo when initially created. For 19 | information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + 20 | torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in 21 | torch.nn.functional Also see https://arxiv.org/abs/1606.08415 22 | """ 23 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 24 | 25 | 26 | def gelu_new(x): 27 | """ 28 | Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). Also see 29 | https://arxiv.org/abs/1606.08415 30 | """ 31 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 32 | 33 | 34 | if torch.__version__ < "1.4.0": 35 | gelu = _gelu_python 36 | else: 37 | gelu = F.gelu 38 | 39 | 40 | def gelu_fast(x): 41 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) 42 | 43 | 44 | def mish(x): 45 | return x * torch.tanh(torch.nn.functional.softplus(x)) 46 | 47 | 48 | def linear_act(x): 49 | return x 50 | 51 | 52 | ACT2FN = { 53 | "relu": F.relu, 54 | "swish": swish, 55 | #"gelu": gelu, 56 | 'gelu': _gelu_python, 57 | "tanh": torch.tanh, 58 | "gelu_new": gelu_new, 59 | "gelu_fast": gelu_fast, 60 | "mish": mish, 61 | "linear": linear_act, 62 | "sigmoid": torch.sigmoid, 63 | } 64 | 65 | 66 | def get_activation(activation_string): 67 | if activation_string in ACT2FN: 68 | return ACT2FN[activation_string] 69 | else: 70 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys()))) 71 | -------------------------------------------------------------------------------- /NarrativeBridge/model.py: -------------------------------------------------------------------------------- 1 | from .torch_common import resize_2d_pos_embed 2 | import torch 3 | from .layers.CLIP import clip 4 | from .layers.decoder import CaptioningModel 5 | from .layers.decoder import (TransformerDecoderTextualHead, 6 | AutoRegressiveBeamSearch, GeneratorWithBeamSearch) 7 | 8 | 9 | def get_git_model(tokenizer, param): 10 | image_encoder = get_image_encoder( 11 | param.get('image_encoder_type', 'CLIPViT_B_16'), 12 | input_resolution=param.get('test_crop_size', 224), 13 | ) 14 | text_decoder = TransformerDecoderTextualHead( 15 | visual_feature_size=param.get('visual_feature_size', 768), 16 | vocab_size=30522, 17 | hidden_size=768, 18 
| num_layers=6, 19 | attention_heads=12, 20 | feedforward_size=768* 4, 21 | max_caption_length=1024, 22 | mask_future_positions=True, 23 | padding_idx=0, 24 | decoder_type='bert_en', 25 | visual_projection_type='linearLn', 26 | ) 27 | #decoder = AutoRegressiveBeamSearch( 28 | #eos_index=tokenizer.sep_token_id, 29 | #max_steps=40, 30 | #beam_size=1, 31 | #per_node_beam_size=1, 32 | #fix_missing_prefix=True, 33 | #) 34 | decoder = GeneratorWithBeamSearch( 35 | eos_index=tokenizer.sep_token_id, 36 | #max_steps=40, 37 | max_steps=1024, 38 | beam_size=4, 39 | length_penalty=0.6, 40 | ) 41 | 42 | #from .trie_decoder import TrieAutoRegressiveBeamSearch, get_trie 43 | #decoder = TrieAutoRegressiveBeamSearch( 44 | #eos_index=tokenizer.sep_token_id, 45 | #max_steps=1022, 46 | #beam_size=1, 47 | #trie=get_trie(tokenizer), 48 | #) 49 | 50 | model = CaptioningModel( 51 | image_encoder, 52 | text_decoder, 53 | decoder=decoder, 54 | sos_index=tokenizer.cls_token_id, 55 | eos_index=tokenizer.sep_token_id, 56 | tokenizer=tokenizer, 57 | use_history_for_infer=True, 58 | loss_type='smooth', 59 | num_image_with_embedding=param.get('num_image_with_embedding') 60 | ) 61 | return model 62 | 63 | def get_image_encoder(encoder_type, input_resolution=224): 64 | name_map = { 65 | 'CLIPViT_B_16': 'ViT-B/16', 66 | 'CLIPViT_L_14': 'ViT-L/14', 67 | } 68 | name_in_clip = name_map[encoder_type] 69 | model, _ = clip.load(name_in_clip, device='cpu', jit=False) 70 | model = model.train() 71 | ret = model.visual 72 | ret.to(torch.float32) 73 | ret.output_grid = True 74 | ret.grid_after_ln = True 75 | if ret.input_resolution != input_resolution: 76 | if encoder_type in ['CLIPViT_B_16', 'CLIPViT_L_14']: 77 | pos = ret.positional_embedding 78 | patch_size = ret.conv1.kernel_size[0] 79 | else: 80 | pos = ret.attnpool.positional_embedding 81 | patch_size = 32 82 | p2 = resize_2d_pos_embed(pos, 83 | ret.input_resolution, 84 | patch_size, 85 | input_resolution) 86 | ret.input_resolution = input_resolution 87 | if encoder_type in ['CLIPViT_B_16', 'CLIPViT_L_14']: 88 | ret.positional_embedding = torch.nn.Parameter(p2) 89 | else: 90 | ret.attnpool.positional_embedding = torch.nn.Parameter(p2) 91 | return ret 92 | 93 | -------------------------------------------------------------------------------- /STG-CMA/net_grd_avst/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/STG-CMA/net_grd_avst/__init__.py -------------------------------------------------------------------------------- /STG-CMA/net_grd_avst/alvs_inter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def sample_vectors(f_v, k): 5 | """ 6 | Uniformly sample k vectors from f_v. 7 | """ 8 | t = f_v.size(0) 9 | indices = torch.linspace(0, t - 1, steps=k).long() 10 | sampled_vectors = f_v[indices] 11 | return sampled_vectors, indices 12 | 13 | def cosine_similarity(a, b): 14 | """ 15 | Calculate the cosine similarity between two tensors a and b. 
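Similarity is taken along the last dimension (dim=-1), so a and b are expected to broadcast to a common shape; in iterative_sampling below they are the [k-1, d] slices f_v_sampled[:-1] and f_v_sampled[1:].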
16 | """ 17 | return F.cosine_similarity(a, b, dim=-1) 18 | 19 | def calculate_attention(f_qst, f_v_sampled): 20 | """ 21 | Calculate attention values 22 | """ 23 | q = f_qst.unsqueeze(0) # Shape: [1, d] 24 | k = f_v_sampled # Shape: [k, d] 25 | v = f_v_sampled # Shape: [k, d] 26 | 27 | # Attention calculation: att_weights = softmax(q * k^T / sqrt(d)) 28 | d = f_qst.size(-1) 29 | att_weights = F.softmax(torch.matmul(q, k.T) / (d ** 0.5), dim=-1) 30 | att_f_v = torch.matmul(att_weights, v) # Shape: [1, d] 31 | return att_f_v.squeeze(0) # Shape: [d] 32 | 33 | def process_numbers(number_list): 34 | unique_numbers = sorted(set(int(num) for num in number_list)) 35 | return unique_numbers 36 | 37 | def iterative_sampling(f_v, f_text, k, m, a1, a2): 38 | """ 39 | Perform the iterative sampling process. 40 | """ 41 | t, d = f_v.shape 42 | indices_record = [] 43 | iter_samples = [0] 44 | 45 | 46 | for _ in range(m): 47 | # Uniformly sample k vectors from f_v 48 | f_v_sampled, sampled_indices = sample_vectors(f_v, k) 49 | sampled_indices += sum(iter_samples) 50 | 51 | # Compute cosine similarity between consecutive vectors 52 | sim1 = cosine_similarity(f_v_sampled[:-1], f_v_sampled[1:]) 53 | sim1 = torch.cat([sim1, sim1[-1].unsqueeze(0)]) # Ensure last and second-last are the same 54 | 55 | # Calculate att_f_v 56 | att_f_v = torch.stack([calculate_attention(f_text, f_v_sampled[i].unsqueeze(0)) for i in range(k)]) 57 | 58 | # Compute cosine similarity between consecutive att_f_v vectors 59 | sim2 = cosine_similarity(att_f_v[:-1], att_f_v[1:]) 60 | sim2 = torch.cat([sim2, sim2[-1].unsqueeze(0)]) # Ensure last and second-last are the same 61 | 62 | # Sum and find the max index 63 | sim = a1 * sim1 + a2 * sim2 64 | max_sim_index = torch.argmax(sim) 65 | 66 | # Use max_sim_index as center, select new vectors 67 | center_idx = sampled_indices[max_sim_index] 68 | start_idx = max(0, center_idx - t // k) 69 | end_idx = min(t, center_idx + t // k) 70 | iter_samples.append(start_idx) 71 | f_v = f_v[start_idx:end_idx] 72 | indices_record += list(sampled_indices) 73 | indices_record = process_numbers(indices_record) 74 | 75 | return indices_record -------------------------------------------------------------------------------- /VindLU/configs/beit-base-patch16-224-pt22k-ft22k.json: -------------------------------------------------------------------------------- 1 | { 2 | "note": "this file is a copy of the BEiT model config, not used directly", 3 | "architectures": [ 4 | "BeitForImageClassification" 5 | ], 6 | "url": "https://huggingface.co/microsoft/beit-base-patch16-224-pt22k-ft22k/raw/main/config.json", 7 | "attention_probs_dropout_prob": 0.0, 8 | "drop_path_rate": 0.1, 9 | "hidden_act": "gelu", 10 | "hidden_dropout_prob": 0.0, 11 | "hidden_size": 768, 12 | "image_size": 224, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 3072, 15 | "layer_norm_eps": 1e-12, 16 | "layer_scale_init_value": 0.1, 17 | "model_type": "beit", 18 | "num_attention_heads": 12, 19 | "num_channels": 3, 20 | "num_hidden_layers": 12, 21 | "patch_size": 16, 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.11.0.dev0", 24 | "use_absolute_position_embeddings": false, 25 | "use_mask_token": false, 26 | "use_mean_pooling": true, 27 | "use_relative_position_bias": true, 28 | "use_shared_relative_position_bias": false, 29 | "vocab_size": 8192 30 | } 31 | -------------------------------------------------------------------------------- /VindLU/configs/config_bert.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "fusion_layer": 9, 20 | "encoder_width": 768, 21 | "cross_module": "ca" 22 | } 23 | -------------------------------------------------------------------------------- /VindLU/configs/config_bert_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "gradient_checkpointing": false, 7 | "hidden_act": "gelu", 8 | "hidden_dropout_prob": 0.1, 9 | "hidden_size": 1024, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 4096, 12 | "layer_norm_eps": 1e-12, 13 | "max_position_embeddings": 512, 14 | "model_type": "bert", 15 | "num_attention_heads": 16, 16 | "num_hidden_layers": 24, 17 | "pad_token_id": 0, 18 | "position_embedding_type": "absolute", 19 | "type_vocab_size": 2, 20 | "use_cache": true, 21 | "vocab_size": 30522, 22 | "fusion_layer": 19, 23 | "encoder_width": 768, 24 | "cross_module": "ca" 25 | } 26 | -------------------------------------------------------------------------------- /VindLU/configs/data.py: -------------------------------------------------------------------------------- 1 | import os as __os # add "__" if not want to be exported 2 | from copy import deepcopy as __deepcopy 3 | 4 | data_dir = __os.environ.get("VL_DATA_DIR") 5 | if data_dir is None: 6 | raise ValueError("please set environment `VL_DATA_DIR` before continue") 7 | 8 | data_root = __os.path.join(data_dir, "videos_images") 9 | anno_root_pt = __os.path.join(data_dir, "anno_pretrain") 10 | anno_root_downstream = __os.path.join(data_dir, "anno_downstream") 11 | 12 | # ============== pretraining datasets================= 13 | available_corpus = dict( 14 | # pretraining datasets 15 | cc3m=[f"{anno_root_pt}/cc3m_train.sqlite.db", f"{data_root}/cc3m_224"], 16 | cc12m=[f"{anno_root_pt}/cc12m.sqlite.db", f"{data_root}/cc12m_224"], 17 | sbu=[f"{anno_root_pt}/sbu.sqlite.db", f"{data_root}/sbu_224"], 18 | vg=[f"{anno_root_pt}/vg.sqlite.db", f"{data_root}/vg"], 19 | coco=[f"{anno_root_pt}/coco.sqlite.db", f"{data_root}/coco"], 20 | webvid=[f"{anno_root_pt}/webvid_train.sqlite.db", f"{data_root}/webvid_2fps_224", "video"], 21 | webvid_10m=[ 22 | f"{anno_root_pt}/webvid_10m_train.sqlite.db", 23 | f"{data_root}/webvid_10m_2fps_224", 24 | "video", 25 | ], 26 | # downstream datasets. 27 | ) 28 | 29 | # composed datasets. 
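# Each composed corpus below is just a list of the base entries defined above; a
# pretraining config picks one by name, e.g. configs/pretrain.py sets
# train_corpus = "webvid_cc3m" and points train_file at
# "${available_corpus[${train_corpus}]}" for lazy evaluation.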
30 | available_corpus["coco_vg"] = [available_corpus["coco"], available_corpus["vg"]] 31 | available_corpus["webvid_cc3m"] = [available_corpus["webvid"], available_corpus["cc3m"]] 32 | available_corpus["webvid_14m"] = [ 33 | available_corpus["webvid"], 34 | available_corpus["cc3m"], 35 | available_corpus["coco"], 36 | available_corpus["vg"], 37 | available_corpus["sbu"], 38 | available_corpus["cc12m"], 39 | ] 40 | available_corpus["webvid12m_14m"] = [ 41 | available_corpus["webvid"], 42 | available_corpus["webvid_10m"], 43 | available_corpus["cc3m"], 44 | available_corpus["coco"], 45 | available_corpus["vg"], 46 | available_corpus["sbu"], 47 | available_corpus["cc12m"], 48 | ] 49 | available_corpus["webvid10m_14m"] = [ 50 | available_corpus["webvid_10m"], 51 | available_corpus["cc3m"], 52 | available_corpus["coco"], 53 | available_corpus["vg"], 54 | available_corpus["sbu"], 55 | available_corpus["cc12m"], 56 | ] 57 | 58 | # ============== for validation ================= 59 | available_corpus["msrvtt_1k_test"] = [ 60 | f"{anno_root_downstream}/msrvtt_test1k.json", 61 | f"{data_root}/msrvtt_2fps_224", 62 | "video", 63 | ] 64 | 65 | -------------------------------------------------------------------------------- /VindLU/configs/model.py: -------------------------------------------------------------------------------- 1 | VisionEncoders = dict() 2 | VisionEncoders["beit"] = dict( 3 | name="beit_base", 4 | pretrained="microsoft/beit-base-patch16-224-pt22k-ft22k", 5 | d_model=768, 6 | ) 7 | VisionEncoders["beit_large"] = dict( 8 | name="beit_large", 9 | pretrained="microsoft/beit-large-patch16-224-pt22k-ft22k", 10 | d_model=1024, 11 | ) 12 | 13 | TextEncoders = dict() 14 | TextEncoders["bert"] = dict( 15 | name="bert_base", 16 | pretrained="bert-base-uncased", 17 | config="configs/config_bert.json", 18 | d_model=768, 19 | fusion_layer=9, 20 | ) 21 | TextEncoders["bert_large"] = dict( 22 | name="bert_large", 23 | pretrained="bert-base-uncased", 24 | config="configs/config_bert_large.json", 25 | d_model=1024, 26 | fusion_layer=19, 27 | ) 28 | -------------------------------------------------------------------------------- /VindLU/configs/pretrain.py: -------------------------------------------------------------------------------- 1 | from .data import * 2 | from .model import * 3 | 4 | # ========================= data ========================== 5 | train_corpus = "webvid_cc3m" 6 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation 7 | test_file = dict(msrvtt_1k_test=available_corpus["msrvtt_1k_test"]) 8 | test_types = ["msrvtt_1k_test"] 9 | num_workers = 6 10 | 11 | stop_key = None 12 | 13 | # ========================= input ========================== 14 | num_frames = 4 15 | num_frames_test = 4 16 | batch_size = 64 17 | max_txt_l = 32 18 | 19 | inputs = dict( 20 | image_res=224, 21 | video_input=dict( 22 | num_frames="${num_frames}", 23 | sample_type="rand", 24 | num_frames_test="${num_frames_test}", 25 | sample_type_test="middle", 26 | random_aug=False, 27 | ), 28 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), 29 | batch_size=dict(image="${batch_size}", video="${batch_size}"), 30 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"), 31 | ) 32 | 33 | # ========================= model ========================== 34 | vision_enc = "beit" 35 | text_enc = "bert" 36 | model = dict( 37 | vision_encoder="${VisionEncoders[${vision_enc}]}", 38 | text_encoder="${TextEncoders[${text_enc}]}", 39 | temporal_modeling=dict( 40 | num_frames="${num_frames}", 
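# temporal_model_block selects the temporal module added to the BEiT vision
# backbone; besides the "timesformer" block used here, the repo also exercises
# "st_adapter", "xclip", "none" and window-attention variants in miscs/test_flops.py.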
41 | temporal_model_block="timesformer", 42 | temporal_model_position="last", 43 | temporal_model_config=dict(input_dim="${model.vision_encoder.d_model}"), 44 | use_temporal_position_embedding=True, 45 | ), 46 | vit_add_ln=True, 47 | multimodal=dict(enable=True), 48 | embed_dim=256, 49 | temp=0.07, 50 | ) 51 | 52 | criterion = dict( 53 | loss_weight=dict(vtc=1.0, mlm=1.0, vtm=1.0, mvm=0.0), # 0: disabled. 54 | vtm_hard_neg=True, 55 | mlm_masking_prob=0.5, 56 | ) 57 | 58 | optimizer = dict( 59 | opt="adamW", 60 | lr=1e-4, 61 | opt_betas=[0.9, 0.999], # default 62 | weight_decay=0.02, 63 | max_grad_norm=-1, # requires a positive float, use -1 to disable 64 | # use a different lr for some modules, e.g., larger lr for new modules 65 | different_lr=dict(enable=False, module_names=[], lr=1e-3), 66 | ) 67 | 68 | scheduler = dict(sched="cosine", epochs=10, min_lr_multi=0.01, warmup_epochs=1) 69 | 70 | evaluate = False 71 | deep_fusion = False 72 | evaluation = dict( 73 | eval_frame_ensemble="concat", # [concat, max, mean, lse] 74 | eval_x_only=False, 75 | k_test=128, 76 | eval_offload=True, # offload gpu tensors to cpu to save memory. 77 | ) 78 | 79 | fp16 = True 80 | gradient_checkpointing = True 81 | 82 | # ========================= wandb ========================== 83 | wandb = dict( 84 | enable=False, 85 | entity="klauscc", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 86 | project="vindlu", # setup in your command line 87 | ) 88 | dist_url = "env://" 89 | device = "cuda" 90 | mode = "pt" 91 | 92 | # ========================= others ========================== 93 | output_dir = None # output dir 94 | resume = False # if True, load optimizer and scheduler states as well 95 | debug = False 96 | log_freq = 100 97 | seed = 42 98 | 99 | pretrained_path = "" # path to pretrained model weights, for resume only? 100 | -------------------------------------------------------------------------------- /VindLU/configs/qa.py: -------------------------------------------------------------------------------- 1 | from .pretrain import * 2 | 3 | del available_corpus 4 | 5 | criterion["loss_weight"]["mlm"] = 0.0 6 | scheduler["warmup_epochs"] = 0.5 7 | 8 | max_txt_l = 32 9 | batch_size = 32 10 | num_frames = 12 11 | 12 | optimizer["lr"] = 1e-5 13 | log_freq = 100 14 | 15 | # =========additional args for VQA ============ 16 | eos = "[SEP]" 17 | max_q_len = 25 18 | max_a_len = 5 19 | # =========end ================================ 20 | 21 | -------------------------------------------------------------------------------- /VindLU/configs/qa_anet.py: -------------------------------------------------------------------------------- 1 | from .qa import * 2 | 3 | train_file = [ 4 | [ 5 | f"{anno_root_downstream}/anet_qa_train.json", 6 | f"{data_root}/activity_net_2fps_360", 7 | "video", 8 | ] 9 | ] 10 | test_file = dict( 11 | val=[ 12 | f"{anno_root_downstream}/anet_qa_val.json", 13 | f"{data_root}/activity_net_2fps_360", 14 | "video", 15 | ], 16 | test=[ 17 | f"{anno_root_downstream}/anet_qa_test.json", 18 | f"{data_root}/activity_net_2fps_360", 19 | "video", 20 | ] 21 | ) 22 | dataset_name = "anet" 23 | 24 | answer_list = f"{anno_root_downstream}/anet_qa_answer_list.json" # list of answer words 25 | 26 | test_types = ["val"] 27 | stop_key = "val" # used to choose the best ckpt. If None, save the last. 
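# Illustrative note: downstream configs like this one are plain Python modules that
# inherit from .qa (itself built on .pretrain) and only override a few fields. A
# minimal way to inspect the fully resolved config, mirroring tests/test_cfg.py
# (assuming the config path is supplied on the command line as that script expects):
#
#     from utils.config import Config
#     cfg = Config.get_config()
#     print(Config.pretty_text(cfg))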
28 | -------------------------------------------------------------------------------- /VindLU/configs/qa_msrvtt.py: -------------------------------------------------------------------------------- 1 | from .qa import * 2 | 3 | train_file = [ 4 | [ 5 | f"{anno_root_downstream}/msrvtt_qa_train.json", 6 | f"{data_root}/msrvtt_2fps_224", 7 | "video", 8 | ] 9 | ] 10 | test_file = dict( 11 | val=[ 12 | f"{anno_root_downstream}/msrvtt_qa_val.json", 13 | f"{data_root}/msrvtt_2fps_224", 14 | "video", 15 | ], 16 | test=[ 17 | f"{anno_root_downstream}/msrvtt_qa_test.json", 18 | f"{data_root}/msrvtt_2fps_224", 19 | "video", 20 | ], 21 | ) 22 | dataset_name = "msrvtt" 23 | 24 | answer_list = f"{anno_root_downstream}/msrvtt_qa_answer_list.json" # list of answer words 25 | 26 | test_types = ["val"] 27 | stop_key = "val" # used to choose the best ckpt. If None, save the last. 28 | -------------------------------------------------------------------------------- /VindLU/configs/ret_anet.py: -------------------------------------------------------------------------------- 1 | from .pretrain import * 2 | 3 | del available_corpus 4 | 5 | train_file = [ 6 | f"{anno_root_downstream}/anet_ret_train.json", 7 | f"{data_root}/activity_net_2fps_360", 8 | "video", 9 | ] 10 | test_file = dict( 11 | test=[ 12 | f"{anno_root_downstream}/anet_ret_val_1.json", 13 | f"{data_root}/activity_net_2fps_360", 14 | "video", 15 | ], 16 | ) 17 | 18 | test_types = ["test"] 19 | stop_key = "test/" # used to choose the best ckpt. If None, save the last. 20 | is_paragraph_retrieval = True 21 | 22 | max_txt_l = 64 23 | batch_size = 32 24 | num_frames = 12 25 | 26 | optimizer["lr"] = 1e-5 27 | log_freq = 100 28 | -------------------------------------------------------------------------------- /VindLU/configs/ret_coco.py: -------------------------------------------------------------------------------- 1 | from .pretrain import * 2 | 3 | del available_corpus 4 | 5 | train_file = [ 6 | f"{anno_root_downstream}/coco_train.json", 7 | f"{data_root}/coco", 8 | "video", 9 | ] 10 | test_file = dict( 11 | val=[ 12 | f"{anno_root_downstream}/coco_val.json", 13 | f"{data_root}/coco", 14 | "video", 15 | ], 16 | test=[ 17 | f"{anno_root_downstream}/coco_test.json", 18 | f"{data_root}/coco", 19 | "video", 20 | ], 21 | ) 22 | 23 | test_types = ["val"] 24 | stop_key = "val/" # used to choose the best ckpt. If None, save the last. 25 | is_paragraph_retrieval = False 26 | 27 | criterion["loss_weight"]["mlm"] = 0.0 28 | scheduler["warmup_epochs"] = 0 29 | optimizer["lr"] = 1e-5 30 | 31 | 32 | max_txt_l = 22 33 | batch_size = 128 34 | num_frames = 1 35 | num_frames_test = 1 36 | 37 | log_freq = 100 38 | -------------------------------------------------------------------------------- /VindLU/configs/ret_didemo.py: -------------------------------------------------------------------------------- 1 | from .pretrain import * 2 | 3 | del available_corpus 4 | 5 | train_file = [ 6 | f"{anno_root_downstream}/didemo_ret_train.json", 7 | f"{data_root}/didemo_2fps_360_trimed30", 8 | "video", 9 | ] 10 | test_file = dict( 11 | val=[ 12 | f"{anno_root_downstream}/didemo_ret_val.json", 13 | f"{data_root}/didemo_2fps_360_trimed30", 14 | "video", 15 | ], 16 | test=[ 17 | f"{anno_root_downstream}/didemo_ret_test.json", 18 | f"{data_root}/didemo_2fps_360_trimed30", 19 | "video", 20 | ], 21 | ) 22 | 23 | test_types = ["val"] 24 | stop_key = "val/" # used to choose the best ckpt. If None, save the last. 
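# DiDeMo is evaluated as paragraph-to-video retrieval: the clip-level captions of a
# video are concatenated into a single query, which is why max_txt_l is raised to 64 below.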
25 | is_paragraph_retrieval = True 26 | 27 | criterion["loss_weight"]["mlm"] = 0.0 28 | scheduler["warmup_epochs"] = 0 29 | optimizer["lr"] = 1e-5 30 | 31 | 32 | max_txt_l = 64 33 | batch_size = 32 34 | num_frames = 12 35 | 36 | log_freq = 10 37 | -------------------------------------------------------------------------------- /VindLU/configs/ret_flickr.py: -------------------------------------------------------------------------------- 1 | from .pretrain import * 2 | 3 | del available_corpus 4 | 5 | train_file = [ 6 | f"{anno_root_downstream}/flickr30k_train.json", 7 | f"{data_root}/f30k", 8 | "video", 9 | ] 10 | test_file = dict( 11 | val=[ 12 | f"{anno_root_downstream}/flickr30k_val.json", 13 | f"{data_root}/f30k", 14 | "video", 15 | ], 16 | test=[ 17 | f"{anno_root_downstream}/flickr30k_test.json", 18 | f"{data_root}/f30k", 19 | "video", 20 | ], 21 | ) 22 | 23 | test_types = ["val"] 24 | stop_key = "val/" # used to choose the best ckpt. If None, save the last. 25 | is_paragraph_retrieval = False 26 | 27 | criterion["loss_weight"]["mlm"] = 0.0 28 | scheduler["warmup_epochs"] = 0 29 | optimizer["lr"] = 1e-5 30 | 31 | 32 | max_txt_l = 32 33 | batch_size = 128 34 | num_frames = 1 35 | num_frames_test = 1 36 | 37 | log_freq = 100 38 | -------------------------------------------------------------------------------- /VindLU/configs/ret_msrvtt.py: -------------------------------------------------------------------------------- 1 | from .pretrain import * 2 | 3 | del available_corpus 4 | 5 | train_file = [ 6 | f"{anno_root_downstream}/msrvtt_ret_train7k.json", 7 | f"{data_root}/msrvtt_2fps_224", 8 | "video", 9 | ] 10 | test_file = dict( 11 | test=[ 12 | f"{anno_root_downstream}/msrvtt_ret_test1k.json", 13 | f"{data_root}/msrvtt_2fps_224", 14 | "video", 15 | ], 16 | ) 17 | 18 | test_types = ["test"] 19 | stop_key = None # used to choose the best ckpt. If None, save the last. 20 | is_paragraph_retrieval = False 21 | 22 | criterion["loss_weight"]["mlm"] = 0.0 23 | scheduler["warmup_epochs"] = 0 24 | scheduler["epochs"] = 5 25 | optimizer["lr"] = 1e-5 26 | 27 | max_txt_l = 32 28 | batch_size = 32 29 | num_frames = 12 30 | 31 | log_freq = 100 32 | -------------------------------------------------------------------------------- /VindLU/configs/ret_msrvtt_9k.py: -------------------------------------------------------------------------------- 1 | from .ret_msrvtt import * 2 | 3 | train_file = [ 4 | f"{anno_root_downstream}/msrvtt_ret_train9k.json", 5 | f"{data_root}/msrvtt_2fps_224", 6 | "video", 7 | ] 8 | -------------------------------------------------------------------------------- /VindLU/configs/ret_msrvtt_mc.py: -------------------------------------------------------------------------------- 1 | from .pretrain import * 2 | 3 | del available_corpus 4 | 5 | train_file = [ 6 | f"{anno_root_downstream}/msrvtt_ret_train7k.json", 7 | f"{data_root}/msrvtt_2fps_224", 8 | "video", 9 | ] 10 | test_file = dict( 11 | mc_test=[ 12 | f"{anno_root_downstream}/msrvtt_mc_test.json", 13 | f"{data_root}/msrvtt_2fps_224", 14 | "video", 15 | ] 16 | ) 17 | 18 | test_types = ["mc_test"] 19 | stop_key = None # used to choose the best ckpt. If None, save the last. 
20 | is_paragraph_retrieval = False 21 | 22 | criterion["loss_weight"]["mlm"] = 0.0 23 | scheduler["warmup_epochs"] = 0 24 | optimizer["lr"] = 1e-5 25 | 26 | max_txt_l = 32 27 | batch_size = 32 28 | num_frames = 12 29 | 30 | log_freq = 100 31 | -------------------------------------------------------------------------------- /VindLU/configs/ret_ssv2_label.py: -------------------------------------------------------------------------------- 1 | from .ret_msrvtt import * 2 | 3 | train_file = [ 4 | f"{anno_root_downstream}/ssv2_ret_label_train.json", 5 | f"{data_root}/ssv2", 6 | "video", 7 | ] 8 | test_file = dict( 9 | val=[ 10 | f"{anno_root_downstream}/ssv2_ret_label_val_small.json", 11 | f"{data_root}/ssv2", 12 | "video", 13 | ], 14 | ) 15 | 16 | test_types = ["val"] 17 | stop_key = None # used to choose the best ckpt. If None, save the last. 18 | 19 | has_multi_vision_gt = True 20 | 21 | scheduler["epochs"] = 10 22 | optimizer["lr"] = 1e-4 23 | 24 | max_txt_l = 25 25 | -------------------------------------------------------------------------------- /VindLU/configs/ret_ssv2_template.py: -------------------------------------------------------------------------------- 1 | from .ret_msrvtt import * 2 | 3 | train_file = [ 4 | f"{anno_root_downstream}/ssv2_ret_template_train.json", 5 | f"{data_root}/ssv2", 6 | "video", 7 | ] 8 | test_file = dict( 9 | val=[ 10 | f"{anno_root_downstream}/ssv2_ret_template_val_small.json", 11 | f"{data_root}/ssv2", 12 | "video", 13 | ], 14 | ) 15 | 16 | test_types = ["val"] 17 | stop_key = None # used to choose the best ckpt. If None, save the last. 18 | 19 | has_multi_vision_gt = True 20 | 21 | scheduler["epochs"] = 10 22 | optimizer["lr"] = 1e-4 23 | 24 | max_txt_l = 22 25 | -------------------------------------------------------------------------------- /VindLU/configs/tvqa.py: -------------------------------------------------------------------------------- 1 | from .pretrain import * 2 | 3 | del available_corpus 4 | 5 | train_file = [ 6 | f"{anno_root_downstream}/tvqa_train_with_answer.json", 7 | f"{data_root}/tvqa_trimmed_3fps", 8 | "video", 9 | ] 10 | test_file = dict( 11 | val=[ 12 | f"{anno_root_downstream}/tvqa_val_with_answer.json", 13 | f"{data_root}/tvqa_trimmed_3fps", 14 | "video", 15 | ], 16 | test=[ 17 | f"{anno_root_downstream}/tvqa_test_public_with_answer.json", 18 | f"{data_root}/tvqa_trimmed_3fps", 19 | "video", 20 | ], 21 | ) 22 | 23 | test_types = ["val"] 24 | stop_key = "val" # used to choose the best ckpt. If None, save the last. 
25 | is_paragraph_retrieval = False 26 | 27 | criterion["loss_weight"]["mlm"] = 0.0 28 | optimizer["lr"] = 1e-5 29 | scheduler["warmup_epochs"] = 0.5 30 | scheduler["epochs"] = 10 31 | 32 | max_txt_l = 150 33 | batch_size = 32 34 | num_frames = 12 35 | 36 | log_freq = 100 37 | -------------------------------------------------------------------------------- /VindLU/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import os 4 | from torch.utils.data import Dataset 5 | 6 | from dataset.utils import load_image_from_path 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class ImageVideoBaseDataset(Dataset): 12 | """Base class that implements the image and video loading methods""" 13 | 14 | media_type = "video" 15 | 16 | def __init__(self): 17 | assert self.media_type in ["image", "video"] 18 | self.data_root = None 19 | self.anno_list = ( 20 | None # list(dict), each dict contains {"image": str, # image or video path} 21 | ) 22 | self.transform = None 23 | self.video_reader = None 24 | self.num_tries = None 25 | 26 | def __getitem__(self, index): 27 | raise NotImplementedError 28 | 29 | def __len__(self): 30 | raise NotImplementedError 31 | 32 | def get_anno(self, index): 33 | """obtain the annotation for one media (video or image) 34 | 35 | Args: 36 | index (int): The media index. 37 | 38 | Returns: dict. 39 | - "image" or "video": the filename. 40 | - "caption": The caption for this file. 41 | 42 | """ 43 | anno = self.anno_list[index] 44 | if self.data_root is not None: 45 | anno[self.media_type] = os.path.join(self.data_root, anno[self.media_type]) 46 | return anno 47 | 48 | def load_and_transform_media_data(self, index): 49 | if self.media_type == "image": 50 | return self.load_and_transform_media_data_image(index) 51 | else: 52 | return self.load_and_transform_media_data_video(index) 53 | 54 | def load_and_transform_media_data_image(self, index): 55 | ann = self.get_anno(index) 56 | data_path = ann["image"] 57 | image = load_image_from_path(data_path) 58 | image = self.transform(image) 59 | return image, index 60 | 61 | def load_and_transform_media_data_video(self, index): 62 | for i in range(self.num_tries): 63 | ann = self.get_anno(index) 64 | data_path = ann["image"] 65 | try: 66 | max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1 67 | frames, frame_indices, video_duration = self.video_reader( 68 | data_path, self.num_frames, self.sample_type, max_num_frames=max_num_frames 69 | ) 70 | except Exception as e: 71 | index = random.randint(0, len(self) - 1) 72 | logger.warning( 73 | f"Caught exception {e} when loading video {data_path}, " 74 | f"randomly sample a new video as replacement" 75 | ) 76 | continue 77 | 78 | frames = self.transform(frames) 79 | return frames, index 80 | else: 81 | raise RuntimeError( 82 | f"Failed to fetch video after {self.num_tries} tries. " 83 | f"This might indicate that you have many corrupted videos."
84 | ) 85 | -------------------------------------------------------------------------------- /VindLU/dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process 4 | import random 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class MetaLoader(object): 11 | """ wraps multiple data loader """ 12 | def __init__(self, name2loader): 13 | """Iterates over multiple dataloaders, it ensures all processes 14 | work on data from the same dataloader. This loader will end when 15 | the shorter dataloader raises StopIteration exception. 16 | 17 | loaders: Dict, {name: dataloader} 18 | """ 19 | self.name2loader = name2loader 20 | self.name2iter = {name: iter(l) for name, l in name2loader.items()} 21 | name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())} 22 | index2name = {v: k for k, v in name2index.items()} 23 | 24 | iter_order = [] 25 | for n, l in name2loader.items(): 26 | iter_order.extend([name2index[n]]*len(l)) 27 | 28 | random.shuffle(iter_order) 29 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8) 30 | 31 | # sync 32 | if is_dist_avail_and_initialized(): 33 | # make sure all processes have the same order so that 34 | # each step they will have data from the same loader 35 | dist.broadcast(iter_order, src=0) 36 | self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()] 37 | 38 | logger.info(str(self)) 39 | 40 | def __str__(self): 41 | output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"] 42 | for idx, (name, loader) in enumerate(self.name2loader.items()): 43 | output.append( 44 | f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} " 45 | ) 46 | return "\n".join(output) 47 | 48 | def __len__(self): 49 | return len(self.iter_order) 50 | 51 | def __iter__(self): 52 | """ this iterator will run indefinitely """ 53 | for name in self.iter_order: 54 | _iter = self.name2iter[name] 55 | batch = next(_iter) 56 | yield name, batch 57 | -------------------------------------------------------------------------------- /VindLU/dataset/qa_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataset.base_dataset import ImageVideoBaseDataset 3 | from dataset.utils import pre_text, load_anno 4 | from dataset.video_utils import VIDEO_READER_FUNCS 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class ImageQADataset(ImageVideoBaseDataset): 11 | media_type = "image" 12 | 13 | def __init__(self, ann_file, transform, eos="[SEP]", mode="train", answer_list=None): 14 | super(ImageQADataset, self).__init__() 15 | assert mode in ["train", "eval"] 16 | self.mode = mode 17 | self.transform = transform 18 | self.eos = eos 19 | 20 | self.anno_list = load_anno(ann_file) 21 | 22 | if mode == "eval": 23 | self.answer_list = json.load(open(answer_list, "r")) 24 | 25 | def __len__(self): 26 | return len(self.anno_list) 27 | 28 | def get_answers_with_weights(self, raw_answers): 29 | if isinstance(raw_answers, str): 30 | raw_answers = [raw_answers] 31 | answer_weight = {} 32 | for answer in raw_answers: 33 | if answer in answer_weight.keys(): 34 | answer_weight[answer] += 1/len(raw_answers) 35 | else: 36 | answer_weight[answer] = 1/len(raw_answers) 37 | 38 | answers = 
list(answer_weight.keys()) 39 | weights = [answer_weight[a] for a in answers] 40 | answers = [answer + " " + self.eos for answer in answers] 41 | return answers, weights 42 | 43 | def __getitem__(self, index): 44 | ann = self.anno_list[index] 45 | image, index = self.load_and_transform_media_data(index) 46 | 47 | question = pre_text(ann["question"]) 48 | if self.mode == "train": 49 | answers, weights = self.get_answers_with_weights(ann["answer"]) 50 | return image, question, answers, weights 51 | else: # self.mode == "eval": 52 | question_id = ann["question_id"] 53 | return image, question, question_id 54 | 55 | 56 | class VideoQADataset(ImageQADataset): 57 | media_type = "video" 58 | 59 | def __init__( 60 | self, ann_file, transform, eos="[SEP]", mode="train", answer_list=None, 61 | num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=1 62 | ): 63 | super(VideoQADataset, self).__init__( 64 | ann_file, transform, eos, mode, answer_list) 65 | self.num_frames = num_frames 66 | self.video_reader_type = video_reader_type 67 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type] 68 | self.sample_type = sample_type 69 | self.num_tries = num_tries 70 | -------------------------------------------------------------------------------- /VindLU/dataset/sqlite_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sqlite3 4 | from os.path import basename 5 | 6 | import numpy as np 7 | import tqdm 8 | 9 | from dataset.base_dataset import ImageVideoBaseDataset 10 | from dataset.utils import load_anno, pre_text 11 | from dataset.video_utils import VIDEO_READER_FUNCS 12 | from utils.distributed import is_main_process 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def get_anno_by_id(cur: sqlite3.Cursor, id: int): 18 | """TODO: Docstring for get_anno_by_id. 19 | 20 | Args: 21 | cur (sqlite3.Cursor): The dataset cursor. 22 | id (int): The annotation id. 23 | 24 | Returns: 25 | 26 | """ 27 | pass 28 | 29 | 30 | class SQLiteImgTxtRetTrainDataset(ImageVideoBaseDataset): 31 | media_type = "image" 32 | 33 | def __init__(self, ann_file, transform, has_multi_vision_gt=False): 34 | super().__init__() 35 | 36 | self.media_type = "video" if len(ann_file) == 3 and ann_file[2] == "video" else "image" 37 | self.label_file, self.data_root = ann_file[:2] 38 | 39 | self.con = sqlite3.connect("file:" + self.label_file + "?mode=ro", uri=True) 40 | self.cur = self.con.cursor() 41 | 42 | # enable this will get stuck on NFS. 43 | # self.cur.execute("PRAGMA temp_store = MEMORY") 44 | # self.cur.execute("PRAGMA mmap_size = 30000000000") 45 | 46 | self.transform = transform 47 | # each caption has multiple image as ground_truth, e.g., ssv2 48 | self.has_multi_vision_gt = has_multi_vision_gt 49 | assert not self.has_multi_vision_gt 50 | 51 | self.num_examples = self.get_length() 52 | 53 | def get_anno(self, index): 54 | query = f"SELECT * FROM annos WHERE id = {index};" 55 | res = self.cur.execute(query) 56 | id, filename, caption = res.fetchone() 57 | anno = {"image": os.path.join(self.data_root, filename), "caption": caption} 58 | return anno 59 | 60 | def get_length(self): 61 | """get the number of examples in this dataset. 
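Counts the rows of the `annos` table in the SQLite label file, so the annotation list never has to be materialized in memory.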
62 | Returns: 63 | 64 | """ 65 | num_rows = self.cur.execute("SELECT COUNT(*) FROM annos").fetchone()[0] 66 | return num_rows 67 | 68 | def __len__(self): 69 | return self.num_examples 70 | 71 | def __getitem__(self, index): 72 | 73 | try: 74 | ann = self.get_anno(index) 75 | image, index = self.load_and_transform_media_data(index) 76 | caption = pre_text(ann["caption"]) 77 | # key = ann["caption"] if self.has_multi_vision_gt else basename(ann["image"]) 78 | return image, caption, index 79 | except Exception as e: 80 | index = np.random.randint(0, len(self)) 81 | return self.__getitem__(index) 82 | 83 | 84 | class SQLiteVidTxtRetTrainDataset(SQLiteImgTxtRetTrainDataset): 85 | media_type = "video" 86 | 87 | def __init__( 88 | self, 89 | ann_file, 90 | transform, 91 | num_frames=4, 92 | video_reader_type="decord", 93 | sample_type="rand", 94 | num_tries=3, 95 | is_paragraph_retrieval=False, 96 | has_multi_vision_gt=False, 97 | ): 98 | super().__init__(ann_file, transform, has_multi_vision_gt) 99 | self.num_frames = num_frames 100 | self.video_reader_type = video_reader_type 101 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type] 102 | self.sample_type = sample_type 103 | self.num_tries = num_tries 104 | self.is_paragraph_retrieval = is_paragraph_retrieval 105 | 106 | if is_paragraph_retrieval: 107 | raise ValueError(f"not implemented") 108 | -------------------------------------------------------------------------------- /VindLU/miscs/test_flops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fvcore.nn import FlopCountAnalysis, flop_count_table 3 | from torch.nn import MultiheadAttention 4 | 5 | from models.beit.st_beit import BeitConfig, BeitModel 6 | from models.temporal_model import (STAdapter, TemporalAttention, 7 | WindowTemporalAttention) 8 | 9 | 10 | def mem_stat(): 11 | mem = torch.cuda.max_memory_allocated() / 1024 / 1024 12 | print(f"max memory allocated: {mem}MB") 13 | 14 | 15 | def build_backbone(tm_block="timesformer"): 16 | """TODO: Docstring for build_backbone. 
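Args:
    tm_block (str): which temporal module to benchmark; the branches below cover
        "timesformer", "st_adapter", "xclip", "none", "wa_2x2" and "wa_7x7"
        (the last two are window attention with 2x2 and 7x7 windows).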
17 | Returns: TODO 18 | 19 | """ 20 | if tm_block == "timesformer": 21 | other_cfg = dict( 22 | num_frames=12, temporal_model_block="timesformer", temporal_model_config={} 23 | ) 24 | elif tm_block == "st_adapter": 25 | other_cfg = dict( 26 | num_frames=12, temporal_model_block="st_adapter", temporal_model_config={} 27 | ) 28 | elif tm_block == "xclip": 29 | other_cfg = dict( 30 | num_frames=12, temporal_model_block="xclip", temporal_model_config={} 31 | ) 32 | elif tm_block == "none": 33 | other_cfg = dict(num_frames=12, temporal_model_block="none", temporal_model_config={}) 34 | elif tm_block == "wa_2x2": 35 | other_cfg = dict( 36 | num_frames=12, 37 | temporal_model_block="window_attention", 38 | temporal_model_config=dict(window_size=(2, 2)), 39 | ) 40 | elif tm_block == "wa_7x7": 41 | other_cfg = dict( 42 | num_frames=12, 43 | temporal_model_block="window_attention", 44 | temporal_model_config=dict(window_size=(7, 7)), 45 | ) 46 | else: 47 | raise ValueError("not exist") 48 | 49 | model_card = "microsoft/beit-base-patch16-224-pt22k-ft22k" 50 | model_config = BeitConfig.from_pretrained(model_card, image_size=224, **other_cfg) 51 | model = BeitModel(model_config) 52 | return model 53 | 54 | 55 | # model = TemporalAttention() 56 | model = build_backbone("st_adapter") 57 | model.gradient_checkpointing_enable() 58 | model.cuda() 59 | for i in range(3): 60 | x = torch.rand(32, 12, 3, 224, 224, requires_grad=True) 61 | x = x.cuda() 62 | x = x.requires_grad_() 63 | y = model(x) 64 | loss = y[0].mean() 65 | loss.backward() 66 | mem_stat() 67 | 68 | # flops = FlopCountAnalysis(model, x) 69 | # print(flop_count_table(flops)) 70 | -------------------------------------------------------------------------------- /VindLU/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/VindLU/models/__init__.py -------------------------------------------------------------------------------- /VindLU/models/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/VindLU/models/backbones/__init__.py -------------------------------------------------------------------------------- /VindLU/models/backbones/beit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/VindLU/models/backbones/beit/__init__.py -------------------------------------------------------------------------------- /VindLU/models/backbones/beit/builder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from models.utils import (interpolate_pos_relative_bias_beit, 4 | load_temp_embed_with_mismatch) 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def interpolate_pos_embed_beit(state_dict, new_model): 10 | """interpolate the positional embeddings. 11 | The spatial pe is relative and temporal pe is absolute. 12 | additional temporal pe is padded with 0. 13 | 14 | Args: 15 | state_dict (dict): The state_dict. 16 | new_model (nn.Module): The created model. 17 | 18 | Returns: dict. The state_dict with updated positional embeddings. 
19 | 20 | """ 21 | state_dict = interpolate_pos_relative_bias_beit( 22 | state_dict_old=state_dict, 23 | state_dict_new=new_model.state_dict(), 24 | patch_shape_new=new_model.vision_encoder.embeddings.patch_embeddings.patch_shape, 25 | ) 26 | # absolute temporal pos bias 27 | temporal_pe_key = "vision_encoder.embeddings.temporal_position_embeddings" 28 | if temporal_pe_key in state_dict: 29 | logger.info(f"interpolate temporal positional embeddings: {temporal_pe_key}") 30 | state_dict[temporal_pe_key] = load_temp_embed_with_mismatch( 31 | temp_embed_old=state_dict[temporal_pe_key], 32 | temp_embed_new=new_model.state_dict()[temporal_pe_key], 33 | ) 34 | return state_dict 35 | 36 | 37 | def build_beit(model_config, image_res, checkpoint): 38 | """build beit with configuration. 39 | 40 | Args: 41 | config (dict): The configs for beit. 42 | image_res (int): The image resolution. 43 | checkpoint (bool): Whether to enable gradient checkpointing. 44 | 45 | Returns: nn.Module 46 | 47 | """ 48 | from .st_beit import BeitConfig as config_cls 49 | from .st_beit import BeitModel as model_cls 50 | 51 | logger.info( 52 | f"Loading vit pre-trained weights from huggingface {model_config.vision_encoder.pretrained}." 53 | ) 54 | # BEiT uses average pooled tokens instead of [CLS] used by other models 55 | aux_kwargs = {"add_pooling_layer": True} 56 | tmp_model = model_cls.from_pretrained(model_config.vision_encoder.pretrained, **aux_kwargs) 57 | state_dict = tmp_model.state_dict() 58 | del tmp_model 59 | 60 | logger.info(f"Init new model with new image size {image_res}, and load weights.") 61 | 62 | other_cfg = model_config.temporal_modeling 63 | vit_config = config_cls.from_pretrained( 64 | model_config.vision_encoder.pretrained, image_size=image_res, **other_cfg 65 | ) 66 | model = model_cls(config=vit_config, **aux_kwargs) 67 | 68 | if checkpoint: 69 | model.gradient_checkpointing_enable() 70 | 71 | # interpolate relative pos bias 72 | state_dict = interpolate_pos_relative_bias_beit( 73 | state_dict_old=state_dict, 74 | state_dict_new=model.state_dict(), 75 | patch_shape_new=model.embeddings.patch_embeddings.patch_shape, 76 | ) 77 | 78 | # del prompt_bias_table 79 | for k in list(state_dict.keys()): 80 | if "prompt_bias_table" in k: 81 | del state_dict[k] 82 | 83 | msg = model.load_state_dict(state_dict, strict=False) 84 | logger.info(msg) 85 | return model 86 | -------------------------------------------------------------------------------- /VindLU/models/backbones/bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/VindLU/models/backbones/bert/__init__.py -------------------------------------------------------------------------------- /VindLU/models/backbones/bert/builder.py: -------------------------------------------------------------------------------- 1 | from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel 2 | 3 | 4 | def build_bert(model_config, pretrain, checkpoint): 5 | """build text encoder. 6 | 7 | Args: 8 | model_config (dict): model config. 9 | pretrain (bool): Whether to do pretrain or finetuning. 10 | checkpoint (bool): whether to do gradient_checkpointing. 
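Note: when model_config.multimodal.enable is False, fusion_layer is pushed to
    num_hidden_layers below, so the returned model is effectively a pure text
    encoder with no cross-attention fusion layers.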
11 | 12 | Returns: TODO 13 | 14 | """ 15 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config) 16 | bert_config.encoder_width = model_config.vision_encoder.d_model 17 | bert_config.gradient_checkpointing = checkpoint 18 | bert_config.fusion_layer = model_config.text_encoder.fusion_layer 19 | 20 | if not model_config.multimodal.enable: 21 | bert_config.fusion_layer = bert_config.num_hidden_layers 22 | 23 | if pretrain: 24 | text_encoder, loading_info = BertForMaskedLM.from_pretrained( 25 | model_config.text_encoder.pretrained, 26 | config=bert_config, 27 | output_loading_info=True, 28 | ) 29 | else: 30 | text_encoder, loading_info = BertModel.from_pretrained( 31 | model_config.text_encoder.pretrained, 32 | config=bert_config, 33 | add_pooling_layer=False, 34 | output_loading_info=True, 35 | ) 36 | 37 | return text_encoder 38 | 39 | 40 | def build_bert_decoder(model_config, checkpoint): 41 | """build text decoder the same as the multimodal encoder. 42 | 43 | Args: 44 | model_config (dict): model config. 45 | pretrain (bool): Whether to do pretrain or finetuning. 46 | checkpoint (bool): whether to do gradient_checkpointing. 47 | 48 | Returns: TODO 49 | 50 | """ 51 | bert_config = BertConfig.from_json_file(model_config.text_encoder.config) 52 | bert_config.encoder_width = model_config.vision_encoder.d_model 53 | bert_config.gradient_checkpointing = checkpoint 54 | 55 | bert_config.fusion_layer = 0 56 | bert_config.num_hidden_layers = ( 57 | bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer 58 | ) 59 | 60 | text_decoder, loading_info = BertLMHeadModel.from_pretrained( 61 | model_config.text_encoder.pretrained, 62 | config=bert_config, 63 | output_loading_info=True, 64 | ) 65 | 66 | return text_decoder 67 | -------------------------------------------------------------------------------- /VindLU/models/backbones/omnivore_swin/builder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import einops 5 | from torch import nn 6 | from torch.hub import load_state_dict_from_url 7 | from torch.nn import functional as F 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | # WARNING: outdated. not usable. 12 | 13 | def build_omnivore_swinb(config): 14 | """build omnivore swinb 15 | 16 | Args: 17 | config (dict): The config 18 | 19 | Returns: nn.Module. 20 | 21 | """ 22 | from .omnivore_swin import CHECKPOINT_PATHS, SwinTransformer3D 23 | 24 | # new_wd = config.video_input.num_frames 25 | new_wd = config.swin_wd 26 | vision_encoder = SwinTransformer3D( 27 | pretrained2d=False, 28 | patch_size=(2, 4, 4), 29 | embed_dim=128, 30 | depths=[2, 2, 18, 2], 31 | num_heads=[4, 8, 16, 32], 32 | window_size=(new_wd, 7, 7), 33 | drop_path_rate=0.1, # TODO: set this based on the final models 34 | patch_norm=True, # Make this the default value? 35 | depth_mode="summed_rgb_d_tokens", # Make this the default value? 36 | ) 37 | path = CHECKPOINT_PATHS[config.vit_name_or_pretrained_path] 38 | checkpoint = load_state_dict_from_url(path, progress=True, map_location="cpu") 39 | wd, wh, ww = 16, 7, 7 40 | 41 | # interpolate the rel_pos enmbedding. 
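# The pretrained relative position bias tables assume a temporal window of wd=16
# (with 7x7 spatial windows); when the requested window new_wd (config.swin_wd)
# differs, the loop below either trilinearly interpolates the temporal axis of each
# table to 2*new_wd - 1 entries or center-crops it, depending on config.pe_scale_method.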
42 | trunk_ckpt = checkpoint["trunk"] 43 | new_state_dict = OrderedDict() 44 | if new_wd != wd: 45 | for k, v in trunk_ckpt.items(): 46 | if "relative_position_bias_table" in k: 47 | # do interpolation 48 | if config.pe_scale_method == "interpolation": 49 | v = einops.rearrange( 50 | v, 51 | "(d h w c) nh -> nh c d h w", 52 | d=2 * wd - 1, 53 | h=2 * wh - 1, 54 | w=2 * ww - 1, 55 | ) 56 | v = F.interpolate( 57 | v, size=(2 * new_wd - 1, 13, 13), mode="trilinear" 58 | ) # shape: [nh, c, d, h, w] 59 | v = einops.rearrange(v, "nh c d h w -> (d h w c) nh") 60 | elif config.pe_scale_method == "crop": 61 | v = einops.rearrange( 62 | v, 63 | "(d h w c) nh -> d (h w c) nh", 64 | d=2 * wd - 1, 65 | h=2 * wh - 1, 66 | w=2 * ww - 1, 67 | ) 68 | v = v[wd - (new_wd) : wd + new_wd - 1] 69 | v = einops.rearrange( 70 | v, 71 | "d (h w c) nh -> (d h w c) nh", 72 | d=2 * new_wd - 1, 73 | h=2 * wh - 1, 74 | w=2 * ww - 1, 75 | ) 76 | else: 77 | raise ValueError("not implemented") 78 | if "relative_position_index" not in k: 79 | new_state_dict[k] = v 80 | 81 | info = vision_encoder.load_state_dict(new_state_dict, strict=False) 82 | logger.info(f"SwinTransformer3D: loaded checkpoint {path}. info:{info}") 83 | return vision_encoder 84 | -------------------------------------------------------------------------------- /VindLU/models/backbones/omnivore_swin/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Copied from https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/generic/util.py 8 | 9 | import logging 10 | from typing import Dict, List, Optional 11 | 12 | import torch 13 | from iopath.common.file_io import g_pathmgr 14 | 15 | from .distributed import broadcast_object, is_primary 16 | 17 | 18 | # constants: 19 | CHECKPOINT_FILE = "checkpoint.torch" 20 | CPU_DEVICE = torch.device("cpu") 21 | GPU_DEVICE = torch.device("cuda") 22 | 23 | 24 | def load_and_broadcast_checkpoint_list( 25 | checkpoint_paths: List[str], device: torch.device = CPU_DEVICE 26 | ): 27 | if is_primary(): 28 | for path in checkpoint_paths: 29 | checkpoint = load_checkpoint(path, device) 30 | if checkpoint is not None: 31 | break 32 | else: 33 | checkpoint = None 34 | logging.info(f"Broadcasting checkpoint loaded from {checkpoint_paths}") 35 | return broadcast_object(checkpoint) 36 | 37 | 38 | def load_and_broadcast_checkpoint( 39 | checkpoint_path: str, device: torch.device = CPU_DEVICE 40 | ) -> Optional[Dict]: 41 | """Loads a checkpoint on primary and broadcasts it to all replicas. 42 | This is a collective operation which needs to be run in sync on all replicas. 43 | See :func:`load_checkpoint` for the arguments. 44 | """ 45 | if is_primary(): 46 | checkpoint = load_checkpoint(checkpoint_path, device) 47 | else: 48 | checkpoint = None 49 | logging.info(f"Broadcasting checkpoint loaded from {checkpoint_path}") 50 | return broadcast_object(checkpoint) 51 | 52 | 53 | def load_checkpoint( 54 | checkpoint_path: str, device: torch.device = CPU_DEVICE 55 | ) -> Optional[Dict]: 56 | """Loads a checkpoint from the specified checkpoint path. 57 | Args: 58 | checkpoint_path: The path to load the checkpoint from. Can be a file or a 59 | directory. If it is a directory, the checkpoint is loaded from 60 | :py:data:`CHECKPOINT_FILE` inside the directory. 
61 | device: device to load the checkpoint to 62 | Returns: 63 | The checkpoint, if it exists, or None. 64 | """ 65 | if not checkpoint_path: 66 | return None 67 | 68 | assert device is not None, "Please specify what device to load checkpoint on" 69 | assert device.type in ["cpu", "cuda"], f"Unknown device: {device}" 70 | if device.type == "cuda": 71 | assert torch.cuda.is_available() 72 | 73 | if not g_pathmgr.exists(checkpoint_path): 74 | logging.warning(f"Checkpoint path {checkpoint_path} not found") 75 | return None 76 | if g_pathmgr.isdir(checkpoint_path): 77 | checkpoint_path = f"{checkpoint_path.rstrip('/')}/{CHECKPOINT_FILE}" 78 | 79 | if not g_pathmgr.exists(checkpoint_path): 80 | logging.warning(f"Checkpoint file {checkpoint_path} not found.") 81 | return None 82 | 83 | logging.info(f"Attempting to load checkpoint from {checkpoint_path}") 84 | # load model on specified device and not on saved device for model and return 85 | # the checkpoint 86 | with g_pathmgr.open(checkpoint_path, "rb") as f: 87 | checkpoint = torch.load(f, map_location=device) 88 | logging.info(f"Loaded checkpoint from {checkpoint_path}") 89 | return checkpoint 90 | -------------------------------------------------------------------------------- /VindLU/models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/VindLU/models/modules/__init__.py -------------------------------------------------------------------------------- /VindLU/models/vindlu_tvqa.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | from .utils import tile 8 | from .vindlu import VindLU 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class VindLU_TVQA(VindLU): 14 | """docstring for VindLU_TVQA""" 15 | 16 | def __init__(self, config, tokenizer, is_pretrain=False): 17 | super().__init__(config, tokenizer, is_pretrain) 18 | 19 | def forward(self, image, text, answer, train=True): 20 | """forward and calculate loss. 21 | 22 | Args: 23 | image (torch.Tensor): The input images. Shape: [B,T,C,H,W]. 24 | text (TODO): tokenized text. 5*B. Each image has 5 text. 25 | answer (torch.Tensor): The answers. Shape: [B,]. Each value 26 | is between 0 and 4. 
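train (bool): if True, return a dict containing the cross-entropy QA loss
    (`loss_qa`); otherwise return the predicted answer indices.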
27 | 28 | """ 29 | # ================= Dual Encoder ITC loss ================ # 30 | bsz = len(image) 31 | num_options_per_q = 5 32 | 33 | image_embeds, pooled_image_embeds = self.encode_vision(image) # (N, ) 34 | text_embeds, pooled_text_embeds = self.encode_text(text) # (5N, ) 35 | image_embeds = rearrange(image_embeds, "b t l c -> b (t l) c") 36 | 37 | # ================= Cross Encoder ITM loss ================ # 38 | image_embeds = tile(image_embeds, 0, num_options_per_q) 39 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( 40 | image_embeds.device, non_blocking=True 41 | ) 42 | 43 | output = self.get_text_encoder()( 44 | encoder_embeds=text_embeds, 45 | attention_mask=text.attention_mask, 46 | encoder_hidden_states=image_embeds, 47 | encoder_attention_mask=image_atts, 48 | return_dict=True, 49 | mode="fusion", 50 | ) 51 | itm_embeds = output.last_hidden_state[:, 0] # [CLS] (5N, ) 52 | 53 | score = self.itm_head(itm_embeds)[:, 1] # (5N, ) 54 | score = score.view(-1, num_options_per_q) # (N, 5) 55 | if train: 56 | loss_qa = F.cross_entropy(score, answer) 57 | 58 | return_dict = dict(loss_qa=loss_qa) 59 | 60 | return return_dict 61 | else: 62 | pred_ans = score.max(1)[1].cpu() 63 | return pred_ans 64 | -------------------------------------------------------------------------------- /VindLU/preprocess/create_sqlite_db.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sqlite3 4 | import time 5 | 6 | import numpy as np 7 | 8 | 9 | def convert_to_sqlite_db(src_path: str, dst_path: str, media_type: str): 10 | """TODO: Docstring for convert_to_sqlite_db. 11 | 12 | Args: 13 | src_path (str): The src json annotation file path. 14 | dst_path (str): The saved sqlite db path. 15 | media_type (str): The media type. Either "image" or "video". 16 | 17 | """ 18 | 19 | # con = sqlite3.connect("file:"+dst_path+"?mode=ro",uri=True) 20 | con = sqlite3.connect(dst_path) 21 | cur = con.cursor() 22 | print(f"creating table") 23 | cur.execute("DROP TABLE IF EXISTS annos") 24 | table_sql = f""" CREATE TABLE IF NOT EXISTS `annos` ( 25 | `id` integer PRIMARY KEY, 26 | `{media_type}` text, 27 | `caption` text 28 | )""" 29 | print(table_sql) 30 | cur.execute(table_sql) 31 | 32 | with open(src_path, "r") as f: 33 | anno_list = json.load(f) 34 | filenames = [anno[media_type] for anno in anno_list] 35 | captions = [anno["caption"] for anno in anno_list] 36 | ids = list(range(len(filenames))) 37 | records = list(zip(ids, filenames, captions)) 38 | 39 | cur.executemany(f"INSERT INTO annos (id, {media_type}, caption) VALUES (?,?,?)", records) 40 | con.commit() 41 | con.close() 42 | 43 | 44 | def read_sqlite(db_path): 45 | """TODO: Docstring for read_sqlite. 46 | 47 | Args: 48 | db_path (TODO): TODO 49 | 50 | Returns: TODO 51 | 52 | """ 53 | con = sqlite3.connect("file:" + db_path + "?mode=ro", uri=True) 54 | cur = con.cursor() 55 | ids = cur.execute("SELECT id FROM annos").fetchall() 56 | ids = [id[0] for id in ids] 57 | print("number medias:", len(ids)) 58 | np.random.shuffle(ids) 59 | for id in ids[:100]: 60 | t1 = time.time() 61 | query = f"SELECT * FROM annos WHERE id = {id};" 62 | res = cur.execute(query) 63 | newid, filename, caption = res.fetchone() 64 | t2 = time.time() 65 | print(f"time: {t2-t1}s", id, newid, filename, caption) 66 | con.close() 67 | 68 | 69 | def convert(json_filename, media_type): 70 | """convert json annotations to sqlite. 
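Args:
    json_filename (str): annotation file under $SL_DATA_DIR/anno_pretrain; the
        output db is written next to it with `.json` replaced by `.sqlite.db`.
    media_type (str): either "image" or "video".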
71 | Returns: TODO 72 | 73 | """ 74 | print(f"\n--------converting {filename}----------------") 75 | src_path = os.path.join(os.environ["SL_DATA_DIR"], "anno_pretrain", json_filename) 76 | dst_path = src_path.replace(".json", ".sqlite.db") 77 | convert_to_sqlite_db(src_path, dst_path, media_type) 78 | read_sqlite(dst_path) 79 | 80 | 81 | if __name__ == "__main__": 82 | filenames = [ 83 | ["cc12m.json", "image"], 84 | ["cc3m_train.json", "image"], 85 | ["cc3m_val.json", "image"], 86 | ["coco.json", "image"], 87 | ["sbu.json", "image"], 88 | ["vg.json", "image"], 89 | ["webvid_10m_train.json", "video"], 90 | ["webvid_10m_val.json", "video"], 91 | ["webvid_train.json", "video"], 92 | ] 93 | for filename, media_type in filenames: 94 | convert(filename, media_type) 95 | -------------------------------------------------------------------------------- /VindLU/preprocess/gen_webvid10m_label.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from multiprocessing import Pool 4 | 5 | import pandas 6 | import tqdm 7 | 8 | from utils import get_video_duration 9 | 10 | data_dir = os.path.join(os.environ["SL_DATA_DIR"], "videos_images/webvid_10m_2fps_224") 11 | downloaded_vidlist = data_dir.replace("webvid_10m_2fps_224", "webvid_10m_vidlist.txt") 12 | 13 | def gen_valid_vidlist(): 14 | """generate the valid video list. 15 | Returns: set. The valid 16 | 17 | """ 18 | with open(downloaded_vidlist, 'r') as f: 19 | videos = f.read().splitlines() 20 | return set(videos) 21 | 22 | 23 | def gen_labels(src_file, dst_file): 24 | """TODO: Docstring for gen_labels. 25 | 26 | Args: 27 | src_file (str): The original csv file 28 | dst_file (str): the output json file 29 | data_dir (str): The data to store the videos. 30 | 31 | """ 32 | df = pandas.read_csv(src_file) 33 | vids = df["videoid"].values.tolist() 34 | captions = df["name"].values.tolist() 35 | 36 | valid_videos = gen_valid_vidlist() 37 | 38 | labels = [] 39 | num_invalid = 0 40 | for vid, caption in tqdm.tqdm(zip(vids, captions), total=len(vids)): 41 | vid_name = f"{vid}.mp4" 42 | if vid_name in valid_videos: 43 | example = {"video": vid_name, "caption": caption, "duration": 0} 44 | labels.append(example) 45 | else: 46 | num_invalid += 1 47 | 48 | # pool = Pool(128) 49 | # labels = [] 50 | # for example in tqdm.tqdm(pool.imap_unordered(gen_one_example, zip(vids,captions)), total=len(vids)): 51 | # labels.append(example) 52 | print(f"number of valid videos: {len(labels)}. 
invalid: {num_invalid}") 53 | with open(dst_file, "w") as f: 54 | json.dump(labels, f) 55 | 56 | 57 | def webvid10m(subset): 58 | print(f"generate labels for subset: {subset}") 59 | assert subset in ["train", "val"] 60 | src_file = f"/data/shared/datasets/webvid-10M/raw_data/results_10M_{subset}.csv" 61 | dst_file = os.path.join( 62 | os.environ["SL_DATA_DIR"], "anno_pretrain", f"webvid_10m_{subset}.json" 63 | ) 64 | gen_labels(src_file, dst_file) 65 | 66 | 67 | if __name__ == "__main__": 68 | webvid10m("val") 69 | webvid10m("train") 70 | -------------------------------------------------------------------------------- /VindLU/preprocess/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | 4 | 5 | def get_video_duration(filename): 6 | 7 | result = subprocess.check_output( 8 | f'ffprobe -v quiet -show_streams -select_streams v:0 -of json "{filename}"', shell=True 9 | ).decode() 10 | fields = json.loads(result)["streams"][0] 11 | 12 | duration = float(fields["duration"]) 13 | return duration 14 | 15 | if __name__ == "__main__": 16 | import os 17 | fp = os.path.join(os.environ["SL_DATA_DIR"], "videos_images/webvid_10m_2fps_224/22920757.mp4") 18 | print(get_video_duration(fp)) 19 | -------------------------------------------------------------------------------- /VindLU/tests/test_cfg.py: -------------------------------------------------------------------------------- 1 | from utils.config import Config 2 | 3 | cfg = Config.get_config() 4 | 5 | cfg_text = Config.pretty_text(cfg) 6 | print(cfg_text) 7 | -------------------------------------------------------------------------------- /VindLU/tools/submit.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | mode=$1 # slurm or local 4 | nnodes=$2 5 | ngpus=$3 6 | cmd=${@:4} # the command to run. i.e. tasks/pretrain.py ... 7 | 8 | if [[ "$mode" == "slurm" ]]; then # slurm 9 | master_node=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 10 | all_nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") 11 | echo "All nodes used: ${all_nodes}" 12 | echo "Master node ${master_node}" 13 | 14 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$master_node" hostname --ip-address | awk '{print $1}') 15 | # head_node_ip=$master_node 16 | rdzv_endpoint="${head_node_ip}:${MASTER_PORT:-40000}" 17 | bin="srun" 18 | 19 | else # local 20 | rdzv_endpoint="${MASTER_ADDR:-localhost}:${MASTER_PORT:-40000}" 21 | bin="" 22 | fi 23 | 24 | echo "PYTHONPATH: ${PYTHONPATH}" 25 | which_python=$(which python) 26 | echo "which python: ${which_python}" 27 | export PYTHONPATH=${PYTHONPATH}:${which_python} 28 | export PYTHONPATH=${PYTHONPATH}:. 29 | echo "PYTHONPATH: ${PYTHONPATH}" 30 | 31 | #run command 32 | $bin torchrun --nnodes=$nnodes \ 33 | --nproc_per_node=$ngpus \ 34 | --rdzv_backend=c10d \ 35 | --rdzv_endpoint=${rdzv_endpoint} \ 36 | $cmd 37 | 38 | echo "Finish at dir: ${PWD}" 39 | ############### ======> Your training scripts [END] 40 | -------------------------------------------------------------------------------- /VindLU/tools/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import socket 4 | 5 | 6 | def has_slurm(): 7 | """determine the system has slurm or not 8 | Returns: True if has else False. 
9 | 10 | """ 11 | return shutil.which("sbatch") is not None 12 | 13 | def random_port(): 14 | """random a unused port 15 | Returns: str 16 | 17 | """ 18 | with socket.socket() as s: 19 | s.bind(("", 0)) 20 | return s.getsockname()[1] 21 | 22 | def runcmd(cmd): 23 | """run command 24 | 25 | Args: 26 | cmd (str): The command to run 27 | 28 | """ 29 | os.system(cmd) 30 | -------------------------------------------------------------------------------- /VindLU/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from os.path import dirname, join 5 | 6 | from utils.config import Config 7 | from utils.distributed import init_distributed_mode, is_main_process 8 | from utils.logger import setup_logger 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def setup_config(): 14 | """Conbine yaml config and command line config with OmegaConf. 15 | Also converts types, e.g., `'None'` (str) --> `None` (None) 16 | """ 17 | config = Config.get_config() 18 | if config.debug: 19 | config.wandb.enable = False 20 | return config 21 | 22 | 23 | def setup_evaluate_config(config): 24 | """setup evaluation default settings, e.g., disable wandb""" 25 | assert config.evaluate 26 | config.wandb.enable = False 27 | if config.output_dir is None: 28 | config.output_dir = join(dirname(config.pretrained_path), "eval") 29 | return config 30 | 31 | 32 | def setup_output_dir(output_dir, excludes=["code"]): 33 | """ensure not overwritting an exisiting/non-empty output dir""" 34 | if not os.path.exists(output_dir): 35 | os.makedirs(output_dir, exist_ok=False) 36 | else: 37 | existing_dirs_files = os.listdir(output_dir) # list 38 | remaining = set(existing_dirs_files) - set(excludes) 39 | remaining = [e for e in remaining if "slurm" not in e] 40 | remaining = [e for e in remaining if ".out" not in e] 41 | # assert len(remaining) == 0, f"remaining dirs or files: {remaining}" 42 | logger.warn(f"remaining dirs or files: {remaining}") 43 | 44 | 45 | def setup_main(): 46 | """ 47 | Setup config, logger, output_dir, etc. 48 | Shared for pretrain and all downstream tasks. 49 | """ 50 | config = setup_config() 51 | if hasattr(config, "evaluate") and config.evaluate: 52 | config = setup_evaluate_config(config) 53 | init_distributed_mode(config) 54 | 55 | if is_main_process(): 56 | setup_output_dir(config.output_dir, excludes=["code"]) 57 | setup_logger(output=config.output_dir, color=True, name="vindlu") 58 | logger.info(f"config: {Config.pretty_text(config)}") 59 | Config.dump(config, os.path.join(config.output_dir, "config.json")) 60 | return config 61 | -------------------------------------------------------------------------------- /VindLU/utils/easydict.py: -------------------------------------------------------------------------------- 1 | class EasyDict(dict): 2 | """ 3 | Get attributes 4 | 5 | >>> d = EasyDict({'foo':3}) 6 | >>> d['foo'] 7 | 3 8 | >>> d.foo 9 | 3 10 | >>> d.bar 11 | Traceback (most recent call last): 12 | ... 
13 | AttributeError: 'EasyDict' object has no attribute 'bar' 14 | 15 | Works recursively 16 | 17 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) 18 | >>> isinstance(d.bar, dict) 19 | True 20 | >>> d.bar.x 21 | 1 22 | 23 | Bullet-proof 24 | 25 | >>> EasyDict({}) 26 | {} 27 | >>> EasyDict(d={}) 28 | {} 29 | >>> EasyDict(None) 30 | {} 31 | >>> d = {'a': 1} 32 | >>> EasyDict(**d) 33 | {'a': 1} 34 | 35 | Set attributes 36 | 37 | >>> d = EasyDict() 38 | >>> d.foo = 3 39 | >>> d.foo 40 | 3 41 | >>> d.bar = {'prop': 'value'} 42 | >>> d.bar.prop 43 | 'value' 44 | >>> d 45 | {'foo': 3, 'bar': {'prop': 'value'}} 46 | >>> d.bar.prop = 'newer' 47 | >>> d.bar.prop 48 | 'newer' 49 | 50 | 51 | Values extraction 52 | 53 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) 54 | >>> isinstance(d.bar, list) 55 | True 56 | >>> from operator import attrgetter 57 | >>> map(attrgetter('x'), d.bar) 58 | [1, 3] 59 | >>> map(attrgetter('y'), d.bar) 60 | [2, 4] 61 | >>> d = EasyDict() 62 | >>> d.keys() 63 | [] 64 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) 65 | >>> d.foo 66 | 3 67 | >>> d.bar.x 68 | 1 69 | 70 | Still like a dict though 71 | 72 | >>> o = EasyDict({'clean':True}) 73 | >>> o.items() 74 | [('clean', True)] 75 | 76 | And like a class 77 | 78 | >>> class Flower(EasyDict): 79 | ... power = 1 80 | ... 81 | >>> f = Flower() 82 | >>> f.power 83 | 1 84 | >>> f = Flower({'height': 12}) 85 | >>> f.height 86 | 12 87 | >>> f['power'] 88 | 1 89 | >>> sorted(f.keys()) 90 | ['height', 'power'] 91 | 92 | update and pop items 93 | >>> d = EasyDict(a=1, b='2') 94 | >>> e = EasyDict(c=3.0, a=9.0) 95 | >>> d.update(e) 96 | >>> d.c 97 | 3.0 98 | >>> d['c'] 99 | 3.0 100 | >>> d.get('c') 101 | 3.0 102 | >>> d.update(a=4, b=4) 103 | >>> d.b 104 | 4 105 | >>> d.pop('a') 106 | 4 107 | >>> d.a 108 | Traceback (most recent call last): 109 | ... 
110 | AttributeError: 'EasyDict' object has no attribute 'a' 111 | """ 112 | 113 | def __init__(self, d=None, **kwargs): 114 | if d is None: 115 | d = {} 116 | if kwargs: 117 | d.update(**kwargs) 118 | for k, v in d.items(): 119 | setattr(self, k, v) 120 | # Class attributes 121 | for k in self.__class__.__dict__.keys(): 122 | if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"): 123 | setattr(self, k, getattr(self, k)) 124 | 125 | def __setattr__(self, name, value): 126 | if isinstance(value, (list, tuple)): 127 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value] 128 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 129 | value = self.__class__(value) 130 | super(EasyDict, self).__setattr__(name, value) 131 | super(EasyDict, self).__setitem__(name, value) 132 | 133 | __setitem__ = __setattr__ 134 | 135 | def update(self, e=None, **f): 136 | d = e or dict() 137 | d.update(f) 138 | for k in d: 139 | setattr(self, k, d[k]) 140 | 141 | def pop(self, k, d=None): 142 | if hasattr(self, k): 143 | delattr(self, k) 144 | return super(EasyDict, self).pop(k, d) 145 | 146 | 147 | if __name__ == "__main__": 148 | import doctest 149 | 150 | -------------------------------------------------------------------------------- /VindLU/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from torch.optim import Optimizer 5 | import math 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | 9 | def create_scheduler(args, optimizer): 10 | lr_scheduler = None 11 | if args.sched == 'cosine': 12 | lr_scheduler = get_cosine_schedule_with_warmup( 13 | optimizer, 14 | num_warmup_steps=args.num_warmup_steps, 15 | num_training_steps=args.num_training_steps, 16 | num_cycles=0.5, 17 | min_lr_multi=args.min_lr_multi 18 | ) 19 | return lr_scheduler 20 | 21 | 22 | def get_cosine_schedule_with_warmup( 23 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, 24 | num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1 25 | ): 26 | """ 27 | Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py 28 | 29 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 30 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 31 | initial lr set in the optimizer. 32 | Args: 33 | optimizer ([`~torch.optim.Optimizer`]): 34 | The optimizer for which to schedule the learning rate. 35 | num_warmup_steps (`int`): 36 | The number of steps for the warmup phase. 37 | num_training_steps (`int`): 38 | The total number of training steps. 39 | num_cycles (`float`, *optional*, defaults to 0.5): 40 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 41 | following a half-cosine). 42 | min_lr_multi (`float`, *optional*, defaults to 0): 43 | The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi. 44 | last_epoch (`int`, *optional*, defaults to -1): 45 | The index of the last epoch when resuming training. 46 | Return: 47 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
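
    Example (illustrative only; `model` and `train_step` are placeholders)::

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=1000, num_training_steps=20000, min_lr_multi=0.01
        )
        for step in range(20000):
            train_step()          # forward/backward pass (user-defined)
            optimizer.step()
            scheduler.step()      # call once per optimization step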
48 | """ 49 | 50 | def lr_lambda(current_step): 51 | if current_step < num_warmup_steps: 52 | return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps))) 53 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 54 | return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 55 | 56 | return LambdaLR(optimizer, lr_lambda, last_epoch) 57 | -------------------------------------------------------------------------------- /figs/TMW_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/figs/TMW_pipeline.png -------------------------------------------------------------------------------- /testa/alvs_inter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def sample_vectors(f_v, k): 5 | """ 6 | Uniformly sample k vectors from f_v. 7 | """ 8 | t = f_v.size(0) 9 | indices = torch.linspace(0, t - 1, steps=k).long() 10 | sampled_vectors = f_v[indices] 11 | return sampled_vectors, indices 12 | 13 | def cosine_similarity(a, b): 14 | """ 15 | Calculate the cosine similarity between two tensors a and b. 16 | """ 17 | return F.cosine_similarity(a, b, dim=-1) 18 | 19 | def calculate_attention(f_qst, f_v_sampled): 20 | """ 21 | Calculate attention values 22 | """ 23 | q = f_qst.unsqueeze(0) # Shape: [1, d] 24 | k = f_v_sampled # Shape: [k, d] 25 | v = f_v_sampled # Shape: [k, d] 26 | 27 | # Attention calculation: att_weights = softmax(q * k^T / sqrt(d)) 28 | d = f_qst.size(-1) 29 | att_weights = F.softmax(torch.matmul(q, k.T) / (d ** 0.5), dim=-1) 30 | att_f_v = torch.matmul(att_weights, v) # Shape: [1, d] 31 | return att_f_v.squeeze(0) # Shape: [d] 32 | 33 | def process_numbers(number_list): 34 | unique_numbers = sorted(set(int(num) for num in number_list)) 35 | return unique_numbers 36 | 37 | def iterative_sampling(f_v, f_text, k, m, a1, a2): 38 | """ 39 | Perform the iterative sampling process. 
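
    Args:
        f_v (torch.Tensor): frame features, shape [t, d].
        f_text (torch.Tensor): text/question feature, shape [d].
        k (int): number of frames uniformly sampled in each iteration.
        m (int): number of refinement iterations.
        a1, a2 (float): weights of the frame-to-frame similarity term (sim1)
            and the text-attended similarity term (sim2) when scoring frames.

    Returns:
        list[int]: sorted, de-duplicated indices of the selected frames.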
40 | """ 41 | t, d = f_v.shape 42 | indices_record = [] 43 | iter_samples = [0] 44 | 45 | 46 | for _ in range(m): 47 | # Uniformly sample k vectors from f_v 48 | f_v_sampled, sampled_indices = sample_vectors(f_v, k) 49 | sampled_indices += sum(iter_samples) 50 | 51 | # Compute cosine similarity between consecutive vectors 52 | sim1 = cosine_similarity(f_v_sampled[:-1], f_v_sampled[1:]) 53 | sim1 = torch.cat([sim1, sim1[-1].unsqueeze(0)]) # Ensure last and second-last are the same 54 | 55 | # Calculate att_f_v 56 | att_f_v = torch.stack([calculate_attention(f_text, f_v_sampled[i].unsqueeze(0)) for i in range(k)]) 57 | 58 | # Compute cosine similarity between consecutive att_f_v vectors 59 | sim2 = cosine_similarity(att_f_v[:-1], att_f_v[1:]) 60 | sim2 = torch.cat([sim2, sim2[-1].unsqueeze(0)]) # Ensure last and second-last are the same 61 | 62 | # Sum and find the max index 63 | sim = a1 * sim1 + a2 * sim2 64 | max_sim_index = torch.argmax(sim) 65 | 66 | # Use max_sim_index as center, select new vectors 67 | center_idx = sampled_indices[max_sim_index] 68 | start_idx = max(0, center_idx - t // k) 69 | end_idx = min(t, center_idx + t // k) 70 | iter_samples.append(start_idx) 71 | f_v = f_v[start_idx:end_idx] 72 | indices_record += list(sampled_indices) 73 | indices_record = process_numbers(indices_record) 74 | 75 | return indices_record -------------------------------------------------------------------------------- /testa/configs/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /testa/configs/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /testa/configs/pretrain_timesformer.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['/home/ma-user/work/renshuhuai/data/webvid/pretrain.json', '/home/ma-user/work/renshuhuai/data/cc3m/pretrain.json'] 2 | 3 | # size of vit model; base or large 4 | vit: 'timesformer' 5 | patch_size: 16 6 | learnable_temporal_scaling: True 7 | attention_type: 'divided_space_time' 8 | 9 | vit_grad_ckpt: False 10 | vit_ckpt_layer: 0 11 | 12 | video_resize: 256 13 | image_size: 224 14 | batch_size: 12 15 | num_frames: 8 16 | num_frm_train: 8 17 | max_words: 30 18 | frm_sampling_strategy: 'headtail' 19 
| 20 | queue_size: 57600 21 | alpha: 0.4 22 | 23 | # optimizer 24 | weight_decay: 0.05 25 | init_lr: 5e-06 26 | min_lr: 5e-07 27 | warmup_lr: 1e-6 28 | lr_decay_rate: 0.95 29 | max_epoch: 20 30 | warmup_steps: 5000 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /testa/configs/pretrain_timesformer_coco.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['/home/chensishuo/data/coco/pretrain.json'] 2 | 3 | 4 | 5 | # size of vit model; base or large 6 | vit: 'timesformer' 7 | patch_size: 16 8 | learnable_temporal_scaling: True 9 | attention_type: 'divided_space_time' 10 | 11 | vit_grad_ckpt: False 12 | vit_ckpt_layer: 0 13 | 14 | 15 | image_size: 224 16 | batch_size: 24 17 | num_frm_train: 2 18 | max_words: 30 19 | frm_sampling_strategy: 'headtail' 20 | 21 | queue_size: 57600 22 | alpha: 0.4 23 | 24 | # optimizer 25 | weight_decay: 0.05 26 | init_lr: 1e-4 27 | min_lr: 1e-6 28 | warmup_lr: 1e-6 29 | lr_decay_rate: 0.9 30 | max_epoch: 20 31 | warmup_steps: 3000 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /testa/configs/pretrain_timesformer_from_blip.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['/home/ma-user/work/renshuhuai/data/webvid/pretrain.json', '/home/ma-user/work/renshuhuai/data/cc3m/pretrain.json'] 2 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 3 | 4 | # size of vit model; base or large 5 | vit: 'timesformer' 6 | patch_size: 16 7 | learnable_temporal_scaling: True 8 | attention_type: 'divided_space_time' 9 | 10 | vit_grad_ckpt: False 11 | vit_ckpt_layer: 0 12 | 13 | video_resize: 256 14 | image_size: 224 15 | batch_size: 12 16 | num_frames: 8 17 | num_frm_train: 8 18 | max_words: 30 19 | frm_sampling_strategy: 'headtail' 20 | 21 | queue_size: 57600 22 | alpha: 0.4 23 | 24 | # optimizer 25 | weight_decay: 0.05 26 | init_lr: 5e-06 27 | min_lr: 5e-07 28 | warmup_lr: 1e-6 29 | lr_decay_rate: 0.95 30 | max_epoch: 20 31 | warmup_steps: 5000 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /testa/configs/pretrain_video.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['/home/chensishuo/data/webvid/webvid_demo.json'] 2 | 3 | 4 | 5 | # size of vit model; base or large 6 | vit: 'base' 7 | vit_grad_ckpt: False 8 | vit_ckpt_layer: 0 9 | 10 | image_size: 224 11 | batch_size: 4 12 | num_frm_train: 8 13 | max_words: 30 14 | 15 | queue_size: 57600 16 | alpha: 0.4 17 | 18 | # optimizer 19 | weight_decay: 0.05 20 | init_lr: 3e-4 21 | min_lr: 1e-6 22 | warmup_lr: 1e-6 23 | lr_decay_rate: 0.9 24 | max_epoch: 20 25 | warmup_steps: 3000 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /testa/configs/retrieval_activitynet_f32.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/home/renshuhuai/data/Activitynet_Captions/anet_6fps_224/' 2 | ann_root: '/home/renshuhuai/data/Activitynet_Captions/' 3 | dataset: 'retrieval_activitynet' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: '/home/renshuhuai/TESTA/pretrained/testa_model_base_pretrain.pth' 7 | 8 | # token merging 9 | token_merging: True 10 | testa_r: [1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,] 11 | merging_type: frame&patch 12 
| 13 | max_words: 64 14 | 15 | # timesformer video encoder 16 | vit: 'timesformer' 17 | patch_size: 16 18 | learnable_temporal_scaling: False 19 | attention_type: 'divided_space_time' 20 | vit_grad_ckpt: False 21 | vit_ckpt_layer: 0 22 | vision_width: 768 # 1024 23 | num_heads: 12 # 16 24 | 25 | # vit video encoder 26 | #vit: 'base' 27 | #vit_grad_ckpt: False 28 | #vit_ckpt_layer: 0 29 | 30 | batch_size_train: 4 31 | batch_size_test: 32 32 | init_lr: 1e-5 33 | 34 | num_frames: 32 35 | num_frm_train: 32 36 | num_frm_test: 32 37 | 38 | image_size: 224 39 | queue_size: 32 40 | alpha: 0.4 41 | k_test: 128 42 | negative_all_rank: True 43 | 44 | # optimizer 45 | weight_decay: 0.05 46 | min_lr: 0 47 | max_epoch: 20 48 | 49 | -------------------------------------------------------------------------------- /testa/configs/retrieval_activitynet_f96.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/home/renshuhuai/data/Activitynet_Captions/anet_6fps_224/' 2 | ann_root: '/home/renshuhuai/data/Activitynet_Captions/' 3 | dataset: 'retrieval_activitynet' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: '/home/renshuhuai/TESTA/pretrained/testa_model_base_pretrain.pth' 7 | 8 | # token merging 9 | token_merging: True 10 | testa_r: [4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,] 11 | merging_type: frame&patch 12 | 13 | max_words: 64 14 | 15 | # timesformer video encoder 16 | vit: 'timesformer' 17 | patch_size: 16 18 | learnable_temporal_scaling: False 19 | attention_type: 'divided_space_time' 20 | vit_grad_ckpt: False 21 | vit_ckpt_layer: 0 22 | vision_width: 768 # 1024 23 | num_heads: 12 # 16 24 | 25 | # vit video encoder 26 | #vit: 'base' 27 | #vit_grad_ckpt: False 28 | #vit_ckpt_layer: 0 29 | 30 | batch_size_train: 2 31 | batch_size_test: 8 32 | init_lr: 1e-5 33 | 34 | num_frames: 96 35 | num_frm_train: 96 36 | num_frm_test: 96 37 | 38 | image_size: 224 39 | queue_size: 32 40 | alpha: 0.4 41 | k_test: 128 42 | negative_all_rank: True 43 | 44 | # optimizer 45 | weight_decay: 0.05 46 | min_lr: 0 47 | max_epoch: 20 48 | 49 | -------------------------------------------------------------------------------- /testa/configs/retrieval_condensedmovies_f32.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/home/ubuntu/efs/data/CondensedMovies/videos/' 2 | ann_root: '/home/ubuntu/efs/data/CondensedMovies/metadata/' 3 | dataset: 'retrieval_condensedmovies' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: '/home/renshuhuai/TESTA/pretrained/testa_model_base_pretrain.pth' 7 | 8 | # token merging 9 | token_merging: True 10 | testa_r: [1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,] 11 | merging_type: frame&patch 12 | 13 | max_words: 32 14 | 15 | # timesformer video encoder 16 | vit: 'timesformer' 17 | patch_size: 16 18 | learnable_temporal_scaling: False 19 | attention_type: 'divided_space_time' 20 | vit_grad_ckpt: False 21 | vit_ckpt_layer: 0 22 | vision_width: 768 # 1024 23 | num_heads: 12 # 16 24 | 25 | # vit video encoder 26 | #vit: 'base' 27 | #vit_grad_ckpt: False 28 | #vit_ckpt_layer: 0 29 | 30 | batch_size_train: 16 31 | batch_size_test: 1 32 | init_lr: 1e-5 33 | 34 | num_frames: 1200 35 | num_frm_train: 1200 36 | num_frm_test: 1200 37 | 38 | image_size: 224 39 | queue_size: 32 40 | alpha: 0.4 41 | k_test: 128 42 | negative_all_rank: True 43 | 44 | # optimizer 45 | weight_decay: 0.05 46 | min_lr: 0 47 | max_epoch: 0 48 | 49 | 
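The retrieval configs above and below are plain YAML, so fields can be inspected or overridden before launching a run. A minimal sketch (assuming PyYAML; the repo's own entry points may use a different YAML loader, and the two local paths are placeholders):

import yaml  # PyYAML, assumed available

with open("testa/configs/retrieval_condensedmovies_f32.yaml") as f:
    cfg = yaml.safe_load(f)

# Point the config at local data and checkpoints instead of the authors' paths.
cfg["video_root"] = "/path/to/CondensedMovies/videos/"        # placeholder
cfg["pretrained"] = "/path/to/testa_model_base_pretrain.pth"  # placeholder

print(cfg["num_frames"], cfg["batch_size_train"], cfg["merging_type"])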
-------------------------------------------------------------------------------- /testa/configs/retrieval_condensedmovies_f96.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/home/ubuntu/efs/data/CondensedMovies/videos/' 2 | ann_root: '/home/ubuntu/efs/data/CondensedMovies/metadata/' 3 | dataset: 'retrieval_condensedmovies' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: '/home/renshuhuai/TESTA/pretrained/testa_model_base_pretrain.pth' 7 | 8 | # token merging 9 | token_merging: True 10 | testa_r: [4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,] 11 | merging_type: frame&patch 12 | 13 | max_words: 32 14 | 15 | # timesformer video encoder 16 | vit: 'timesformer' 17 | patch_size: 16 18 | learnable_temporal_scaling: False 19 | attention_type: 'divided_space_time' 20 | vit_grad_ckpt: False 21 | vit_ckpt_layer: 0 22 | vision_width: 768 # 1024 23 | num_heads: 12 # 16 24 | 25 | # vit video encoder 26 | #vit: 'base' 27 | #vit_grad_ckpt: False 28 | #vit_ckpt_layer: 0 29 | 30 | batch_size_train: 4 31 | batch_size_test: 32 32 | init_lr: 1e-5 33 | 34 | num_frames: 96 35 | num_frm_train: 96 36 | num_frm_test: 96 37 | 38 | image_size: 224 39 | queue_size: 32 40 | alpha: 0.4 41 | k_test: 128 42 | negative_all_rank: True 43 | 44 | # optimizer 45 | weight_decay: 0.05 46 | min_lr: 0 47 | max_epoch: 10 48 | 49 | -------------------------------------------------------------------------------- /testa/configs/retrieval_didemo_f32.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/home/renshuhuai/data/didemo/didemo_30fps_224_trimed30/' 2 | ann_root: '/home/renshuhuai/data/didemo/' 3 | dataset: 'retrieval_didemo' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: '/home/renshuhuai/TESTA/pretrained/testa_model_base_pretrain.pth' 7 | 8 | # token merging 9 | token_merging: True 10 | testa_r: [1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,] 11 | merging_type: frame&patch 12 | 13 | max_words: 64 14 | 15 | # timesformer video encoder 16 | vit: 'timesformer' 17 | patch_size: 16 18 | learnable_temporal_scaling: False 19 | attention_type: 'divided_space_time' 20 | vit_grad_ckpt: False 21 | vit_ckpt_layer: 0 22 | vision_width: 768 # 1024 23 | num_heads: 12 # 16 24 | 25 | # vit video encoder 26 | #vit: 'base' 27 | #vit_grad_ckpt: False 28 | #vit_ckpt_layer: 0 29 | 30 | batch_size_train: 4 31 | batch_size_test: 32 32 | init_lr: 1e-5 33 | 34 | num_frames: 32 35 | num_frm_train: 32 36 | num_frm_test: 32 37 | 38 | image_size: 224 39 | queue_size: 32 40 | alpha: 0.4 41 | k_test: 128 42 | negative_all_rank: True 43 | 44 | # optimizer 45 | weight_decay: 0.05 46 | min_lr: 0 47 | max_epoch: 10 48 | 49 | -------------------------------------------------------------------------------- /testa/configs/retrieval_didemo_f96.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/home/renshuhuai/data/didemo/didemo_30fps_224_trimed30/' 2 | ann_root: '/home/renshuhuai/data/didemo/' 3 | dataset: 'retrieval_didemo' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: '/home/renshuhuai/TESTA/pretrained/testa_model_base_pretrain.pth' 7 | 8 | # token merging 9 | token_merging: True 10 | testa_r: [4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,] 11 | merging_type: frame&patch 12 | 13 | max_words: 64 14 | 15 | # timesformer video encoder 16 | vit: 'timesformer' 17 | patch_size: 16 18 | learnable_temporal_scaling: False 19 | 
attention_type: 'divided_space_time' 20 | vit_grad_ckpt: False 21 | vit_ckpt_layer: 0 22 | vision_width: 768 # 1024 23 | num_heads: 12 # 16 24 | 25 | # vit video encoder 26 | #vit: 'base' 27 | #vit_grad_ckpt: False 28 | #vit_ckpt_layer: 0 29 | 30 | batch_size_train: 2 31 | batch_size_test: 8 32 | init_lr: 1e-5 33 | 34 | num_frames: 96 35 | num_frm_train: 96 36 | num_frm_test: 96 37 | 38 | image_size: 224 39 | queue_size: 32 40 | alpha: 0.4 41 | k_test: 128 42 | negative_all_rank: True 43 | 44 | # optimizer 45 | weight_decay: 0.05 46 | min_lr: 0 47 | max_epoch: 10 48 | 49 | -------------------------------------------------------------------------------- /testa/configs/retrieval_queryd_f32.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/home/renshuhuai/data/QuerYD/QuerYD_downloader/videos/' 2 | ann_root: '/home/renshuhuai/data/QuerYD/QuerYD-experts/data/QuerYD/structured-symlinks/' 3 | dataset: 'retrieval_queryd' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: '/home/renshuhuai/TESTA/pretrained/testa_model_base_pretrain.pth' 7 | 8 | # token merging 9 | token_merging: True 10 | testa_r: [1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,] 11 | merging_type: frame&patch 12 | 13 | max_words: 128 14 | 15 | # timesformer video encoder 16 | vit: 'timesformer' 17 | patch_size: 16 18 | learnable_temporal_scaling: False 19 | attention_type: 'divided_space_time' 20 | vit_grad_ckpt: False 21 | vit_ckpt_layer: 0 22 | vision_width: 768 # 1024 23 | num_heads: 12 # 16 24 | 25 | # vit video encoder 26 | #vit: 'base' 27 | #vit_grad_ckpt: False 28 | #vit_ckpt_layer: 0 29 | 30 | batch_size_train: 4 31 | batch_size_test: 32 32 | init_lr: 1e-5 33 | 34 | num_frames: 32 35 | num_frm_train: 32 36 | num_frm_test: 32 37 | 38 | image_size: 224 39 | queue_size: 32 40 | alpha: 0.4 41 | k_test: 128 42 | negative_all_rank: True 43 | 44 | # optimizer 45 | weight_decay: 0.05 46 | min_lr: 0 47 | max_epoch: 10 48 | 49 | -------------------------------------------------------------------------------- /testa/configs/retrieval_queryd_f96.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/home/renshuhuai/data/QuerYD/QuerYD_downloader/videos/' 2 | ann_root: '/home/renshuhuai/data/QuerYD/QuerYD-experts/data/QuerYD/structured-symlinks/' 3 | dataset: 'retrieval_queryd' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: '/home/renshuhuai/TESTA/pretrained/testa_model_base_pretrain.pth' 7 | 8 | # token merging 9 | token_merging: True 10 | testa_r: [4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,4,8,] 11 | merging_type: frame&patch 12 | 13 | max_words: 128 14 | 15 | # timesformer video encoder 16 | vit: 'timesformer' 17 | patch_size: 16 18 | learnable_temporal_scaling: False 19 | attention_type: 'divided_space_time' 20 | vit_grad_ckpt: False 21 | vit_ckpt_layer: 0 22 | vision_width: 768 # 1024 23 | num_heads: 12 # 16 24 | 25 | # vit video encoder 26 | #vit: 'base' 27 | #vit_grad_ckpt: False 28 | #vit_ckpt_layer: 0 29 | 30 | batch_size_train: 2 31 | batch_size_test: 8 32 | init_lr: 1e-5 33 | 34 | num_frames: 96 35 | num_frm_train: 96 36 | num_frm_test: 96 37 | 38 | image_size: 224 39 | queue_size: 32 40 | alpha: 0.4 41 | k_test: 128 42 | negative_all_rank: True 43 | 44 | # optimizer 45 | weight_decay: 0.05 46 | min_lr: 0 47 | max_epoch: 10 48 | 49 | -------------------------------------------------------------------------------- /testa/configs/vqa_activitynet_f32.yaml: 
-------------------------------------------------------------------------------- 1 | video_root: '/home/renshuhuai/data/Activitynet-QA/anet_6fps_224' 2 | ann_root: '/home/renshuhuai/data/Activitynet-QA/annos' 3 | dataset: 'activitynet_qa' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: '/home/renshuhuai/TESTA/pretrained/testa_model_base_pretrain.pth' 7 | 8 | # token merging 9 | token_merging: True 10 | testa_r: [1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,] 11 | merging_type: frame&patch 12 | 13 | # timesformer video encoder 14 | vit: 'timesformer' 15 | patch_size: 16 16 | learnable_temporal_scaling: False 17 | attention_type: 'divided_space_time' 18 | vit_grad_ckpt: False 19 | vit_ckpt_layer: 0 20 | vision_width: 768 # 1024 21 | num_heads: 12 # 16 22 | 23 | # vit video encoder 24 | #vit: 'base' 25 | #vit_grad_ckpt: False 26 | #vit_ckpt_layer: 0 27 | 28 | batch_size_train: 4 29 | batch_size_test: 32 30 | init_lr: 2e-5 31 | 32 | num_frames: 32 33 | num_frm_train: 32 34 | num_frm_test: 32 35 | 36 | image_size: 224 37 | 38 | k_test: 128 39 | inference: 'rank' 40 | 41 | # optimizer 42 | weight_decay: 0.05 43 | min_lr: 0 44 | max_epoch: 10 -------------------------------------------------------------------------------- /testa/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/testa/models/__init__.py -------------------------------------------------------------------------------- /testa/models/timesformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | from .utils.env import setup_environment 4 | 5 | setup_environment() 6 | -------------------------------------------------------------------------------- /testa/models/timesformer/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /testa/models/timesformer/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | from .build import DATASET_REGISTRY, build_dataset # noqa 4 | from .kinetics import Kinetics # noqa 5 | from .ssv2 import Ssv2 # noqa 6 | -------------------------------------------------------------------------------- /testa/models/timesformer/datasets/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | from fvcore.common.registry import Registry 4 | 5 | DATASET_REGISTRY = Registry("DATASET") 6 | DATASET_REGISTRY.__doc__ = """ 7 | Registry for dataset. 8 | 9 | The registered object will be called with `obj(cfg, split)`. 10 | The call should return a `torch.utils.data.Dataset` object. 11 | """ 12 | 13 | 14 | def build_dataset(dataset_name, cfg, split): 15 | """ 16 | Build a dataset, defined by `dataset_name`. 17 | Args: 18 | dataset_name (str): the name of the dataset to be constructed. 19 | cfg (CfgNode): configs. Details can be found in 20 | slowfast/config/defaults.py 21 | split (str): the split of the data loader. Options include `train`, 22 | `val`, and `test`. 
23 | Returns: 24 | Dataset: a constructed dataset specified by dataset_name. 25 | """ 26 | # Capitalize the the first letter of the dataset_name since the dataset_name 27 | # in configs may be in lowercase but the name of dataset class should always 28 | # start with an uppercase letter. 29 | name = dataset_name.capitalize() 30 | return DATASET_REGISTRY.get(name)(cfg, split) 31 | -------------------------------------------------------------------------------- /testa/models/timesformer/datasets/multigrid_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Helper functions for multigrid training.""" 4 | 5 | import numpy as np 6 | from torch._six import int_classes as _int_classes 7 | from torch.utils.data.sampler import Sampler 8 | 9 | 10 | class ShortCycleBatchSampler(Sampler): 11 | """ 12 | Extend Sampler to support "short cycle" sampling. 13 | See paper "A Multigrid Method for Efficiently Training Video Models", 14 | Wu et al., 2019 (https://arxiv.org/abs/1912.00998) for details. 15 | """ 16 | 17 | def __init__(self, sampler, batch_size, drop_last, cfg): 18 | if not isinstance(sampler, Sampler): 19 | raise ValueError( 20 | "sampler should be an instance of " 21 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 22 | ) 23 | if ( 24 | not isinstance(batch_size, _int_classes) 25 | or isinstance(batch_size, bool) 26 | or batch_size <= 0 27 | ): 28 | raise ValueError( 29 | "batch_size should be a positive integer value, " 30 | "but got batch_size={}".format(batch_size) 31 | ) 32 | if not isinstance(drop_last, bool): 33 | raise ValueError( 34 | "drop_last should be a boolean value, but got " 35 | "drop_last={}".format(drop_last) 36 | ) 37 | self.sampler = sampler 38 | self.drop_last = drop_last 39 | 40 | bs_factor = [ 41 | int( 42 | round( 43 | ( 44 | float(cfg.DATA.TRAIN_CROP_SIZE) 45 | / (s * cfg.MULTIGRID.DEFAULT_S) 46 | ) 47 | ** 2 48 | ) 49 | ) 50 | for s in cfg.MULTIGRID.SHORT_CYCLE_FACTORS 51 | ] 52 | 53 | self.batch_sizes = [ 54 | batch_size * bs_factor[0], 55 | batch_size * bs_factor[1], 56 | batch_size, 57 | ] 58 | 59 | def __iter__(self): 60 | counter = 0 61 | batch_size = self.batch_sizes[0] 62 | batch = [] 63 | for idx in self.sampler: 64 | batch.append((idx, counter % 3)) 65 | if len(batch) == batch_size: 66 | yield batch 67 | counter += 1 68 | batch_size = self.batch_sizes[counter % 3] 69 | batch = [] 70 | if len(batch) > 0 and not self.drop_last: 71 | yield batch 72 | 73 | def __len__(self): 74 | avg_batch_size = sum(self.batch_sizes) / 3.0 75 | if self.drop_last: 76 | return int(np.floor(len(self.sampler) / avg_batch_size)) 77 | else: 78 | return int(np.ceil(len(self.sampler) / avg_batch_size)) 79 | -------------------------------------------------------------------------------- /testa/models/timesformer/datasets/video_container.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | import av 4 | 5 | 6 | def get_video_container(path_to_vid, multi_thread_decode=False, backend="pyav"): 7 | """ 8 | Given the path to the video, return the pyav video container. 9 | Args: 10 | path_to_vid (str): path to the video. 11 | multi_thread_decode (bool): if True, perform multi-thread decoding. 12 | backend (str): decoder backend, options include `pyav` and 13 | `torchvision`, default is `pyav`. 
14 | Returns: 15 | container (container): video container. 16 | """ 17 | if backend == "torchvision": 18 | with open(path_to_vid, "rb") as fp: 19 | container = fp.read() 20 | return container 21 | elif backend == "pyav": 22 | #try: 23 | container = av.open(path_to_vid) 24 | if multi_thread_decode: 25 | # Enable multiple threads for decoding. 26 | container.streams.video[0].thread_type = "AUTO" 27 | #except: 28 | # container = None 29 | return container 30 | else: 31 | raise NotImplementedError("Unknown backend {}".format(backend)) 32 | -------------------------------------------------------------------------------- /testa/models/timesformer/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | from .build import MODEL_REGISTRY, build_model # noqa 4 | from .custom_video_model_builder import * # noqa 5 | from .video_model_builder import ResNet, SlowFast # noqa 6 | -------------------------------------------------------------------------------- /testa/models/timesformer/models/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Model construction functions.""" 4 | 5 | import torch 6 | from fvcore.common.registry import Registry 7 | 8 | MODEL_REGISTRY = Registry("MODEL") 9 | MODEL_REGISTRY.__doc__ = """ 10 | Registry for video model. 11 | 12 | The registered object will be called with `obj(cfg)`. 13 | The call should return a `torch.nn.Module` object. 14 | """ 15 | 16 | 17 | def build_model(cfg, gpu_id=None): 18 | """ 19 | Builds the video model. 20 | Args: 21 | cfg (configs): configs that contains the hyper-parameters to build the 22 | backbone. Details can be seen in slowfast/config/defaults.py. 23 | gpu_id (Optional[int]): specify the gpu index to build model. 24 | """ 25 | if torch.cuda.is_available(): 26 | assert ( 27 | cfg.NUM_GPUS <= torch.cuda.device_count() 28 | ), "Cannot use more GPU devices than available" 29 | else: 30 | assert ( 31 | cfg.NUM_GPUS == 0 32 | ), "Cuda is not available. Please set `NUM_GPUS: 0 for running on CPUs." 
33 | 34 | # Construct the model 35 | name = cfg.MODEL.MODEL_NAME 36 | model = MODEL_REGISTRY.get(name)(cfg) 37 | 38 | if cfg.NUM_GPUS: 39 | if gpu_id is None: 40 | # Determine the GPU used by the current process 41 | cur_device = torch.cuda.current_device() 42 | else: 43 | cur_device = gpu_id 44 | # Transfer the model to the current GPU device 45 | model = model.cuda(device=cur_device) 46 | 47 | 48 | # Use multi-process data parallel model in the multi-gpu setting 49 | if cfg.NUM_GPUS > 1: 50 | # Make model replica operate on the current device 51 | model = torch.nn.parallel.DistributedDataParallel( 52 | module=model, device_ids=[cur_device], output_device=cur_device 53 | ) 54 | return model 55 | -------------------------------------------------------------------------------- /testa/models/timesformer/models/conv2d_same.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Ross Wightman 2 | # Conv2d w/ Same Padding 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from typing import Tuple, Optional 8 | 9 | import math 10 | from typing import List, Tuple 11 | #from .padding import pad_same, get_padding_value 12 | 13 | # Dynamically pad input x with 'SAME' padding for conv with specified args 14 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 15 | ih, iw = x.size()[-2:] 16 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 17 | if pad_h > 0 or pad_w > 0: 18 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 19 | return x 20 | 21 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 22 | def get_same_padding(x: int, k: int, s: int, d: int): 23 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 24 | 25 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 26 | dynamic = False 27 | if isinstance(padding, str): 28 | # for any string padding, the padding will be calculated for you, one of three ways 29 | padding = padding.lower() 30 | if padding == 'same': 31 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 32 | if is_static_pad(kernel_size, **kwargs): 33 | # static case, no extra overhead 34 | padding = get_padding(kernel_size, **kwargs) 35 | else: 36 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 37 | padding = 0 38 | dynamic = True 39 | elif padding == 'valid': 40 | # 'VALID' padding, same as padding=0 41 | padding = 0 42 | else: 43 | # Default to PyTorch style 'same'-ish symmetric padding 44 | padding = get_padding(kernel_size, **kwargs) 45 | return padding, dynamic 46 | 47 | def conv2d_same( 48 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 49 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 50 | x = pad_same(x, weight.shape[-2:], stride, dilation) 51 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 52 | 53 | 54 | class Conv2dSame(nn.Conv2d): 55 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 56 | """ 57 | 58 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 59 | padding=0, dilation=1, groups=1, bias=True): 60 | super(Conv2dSame, self).__init__( 61 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 62 | 63 | def forward(self, x): 64 | return conv2d_same(x, self.weight, self.bias, self.stride, 
self.padding, self.dilation, self.groups) 65 | 66 | 67 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 68 | padding = kwargs.pop('padding', '') 69 | kwargs.setdefault('bias', False) 70 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 71 | if is_dynamic: 72 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 73 | else: 74 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 75 | -------------------------------------------------------------------------------- /testa/models/timesformer/models/custom_video_model_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | 4 | """A More Flexible Video models.""" 5 | -------------------------------------------------------------------------------- /testa/models/timesformer/models/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | class Linear(nn.Linear): 8 | def forward(self, input: torch.Tensor) -> torch.Tensor: 9 | if torch.jit.is_scripting(): 10 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 11 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 12 | else: 13 | return F.linear(input, self.weight, self.bias) 14 | -------------------------------------------------------------------------------- /testa/models/timesformer/models/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Loss functions.""" 4 | 5 | import torch.nn as nn 6 | 7 | _LOSSES = { 8 | "cross_entropy": nn.CrossEntropyLoss, 9 | "bce": nn.BCELoss, 10 | "bce_logit": nn.BCEWithLogitsLoss, 11 | } 12 | 13 | 14 | def get_loss_func(loss_name): 15 | """ 16 | Retrieve the loss given the loss name. 17 | Args (int): 18 | loss_name: the name of the loss to use. 19 | """ 20 | if loss_name not in _LOSSES.keys(): 21 | raise NotImplementedError("Loss {} is not supported".format(loss_name)) 22 | return _LOSSES[loss_name] 23 | -------------------------------------------------------------------------------- /testa/models/timesformer/models/operators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | 3 | """Custom operators.""" 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class Swish(nn.Module): 10 | """Swish activation function: x * sigmoid(x).""" 11 | 12 | def __init__(self): 13 | super(Swish, self).__init__() 14 | 15 | def forward(self, x): 16 | return SwishEfficient.apply(x) 17 | 18 | 19 | class SwishEfficient(torch.autograd.Function): 20 | """Swish activation function: x * sigmoid(x).""" 21 | 22 | @staticmethod 23 | def forward(ctx, x): 24 | result = x * torch.sigmoid(x) 25 | ctx.save_for_backward(x) 26 | return result 27 | 28 | @staticmethod 29 | def backward(ctx, grad_output): 30 | x = ctx.saved_variables[0] 31 | sigmoid_x = torch.sigmoid(x) 32 | return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x))) 33 | 34 | 35 | class SE(nn.Module): 36 | """Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid.""" 37 | 38 | def _round_width(self, width, multiplier, min_width=8, divisor=8): 39 | """ 40 | Round width of filters based on width multiplier 41 | Args: 42 | width (int): the channel dimensions of the input. 43 | multiplier (float): the multiplication factor. 44 | min_width (int): the minimum width after multiplication. 45 | divisor (int): the new width should be dividable by divisor. 46 | """ 47 | if not multiplier: 48 | return width 49 | 50 | width *= multiplier 51 | min_width = min_width or divisor 52 | width_out = max( 53 | min_width, int(width + divisor / 2) // divisor * divisor 54 | ) 55 | if width_out < 0.9 * width: 56 | width_out += divisor 57 | return int(width_out) 58 | 59 | def __init__(self, dim_in, ratio, relu_act=True): 60 | """ 61 | Args: 62 | dim_in (int): the channel dimensions of the input. 63 | ratio (float): the channel reduction ratio for squeeze. 64 | relu_act (bool): whether to use ReLU activation instead 65 | of Swish (default). 66 | divisor (int): the new width should be dividable by divisor. 67 | """ 68 | super(SE, self).__init__() 69 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 70 | dim_fc = self._round_width(dim_in, ratio) 71 | self.fc1 = nn.Conv3d(dim_in, dim_fc, 1, bias=True) 72 | self.fc1_act = nn.ReLU() if relu_act else Swish() 73 | self.fc2 = nn.Conv3d(dim_fc, dim_in, 1, bias=True) 74 | 75 | self.fc2_sig = nn.Sigmoid() 76 | 77 | def forward(self, x): 78 | x_in = x 79 | for module in self.children(): 80 | x = module(x) 81 | return x_in * x 82 | -------------------------------------------------------------------------------- /testa/models/timesformer/models/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Optimizer.""" 4 | 5 | import torch 6 | 7 | import timesformer.utils.lr_policy as lr_policy 8 | 9 | 10 | def construct_optimizer(model, cfg): 11 | """ 12 | Construct a stochastic gradient descent or ADAM optimizer with momentum. 13 | Details can be found in: 14 | Herbert Robbins, and Sutton Monro. "A stochastic approximation method." 15 | and 16 | Diederik P.Kingma, and Jimmy Ba. 17 | "Adam: A Method for Stochastic Optimization." 18 | 19 | Args: 20 | model (model): model to perform stochastic gradient descent 21 | optimization or ADAM optimization. 22 | cfg (config): configs of hyper-parameters of SGD or ADAM, includes base 23 | learning rate, momentum, weight_decay, dampening, and etc. 24 | """ 25 | # Batchnorm parameters. 26 | bn_params = [] 27 | # Non-batchnorm parameters. 
28 | non_bn_parameters = [] 29 | for name, p in model.named_parameters(): 30 | if "bn" in name: 31 | bn_params.append(p) 32 | else: 33 | non_bn_parameters.append(p) 34 | # Apply different weight decay to Batchnorm and non-batchnorm parameters. 35 | # In Caffe2 classification codebase the weight decay for batchnorm is 0.0. 36 | # Having a different weight decay on batchnorm might cause a performance 37 | # drop. 38 | optim_params = [ 39 | {"params": bn_params, "weight_decay": cfg.BN.WEIGHT_DECAY}, 40 | {"params": non_bn_parameters, "weight_decay": cfg.SOLVER.WEIGHT_DECAY}, 41 | ] 42 | # Check all parameters will be passed into optimizer. 43 | assert len(list(model.parameters())) == len(non_bn_parameters) + len( 44 | bn_params 45 | ), "parameter size does not match: {} + {} != {}".format( 46 | len(non_bn_parameters), len(bn_params), len(list(model.parameters())) 47 | ) 48 | 49 | if cfg.SOLVER.OPTIMIZING_METHOD == "sgd": 50 | return torch.optim.SGD( 51 | optim_params, 52 | lr=cfg.SOLVER.BASE_LR, 53 | momentum=cfg.SOLVER.MOMENTUM, 54 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 55 | dampening=cfg.SOLVER.DAMPENING, 56 | nesterov=cfg.SOLVER.NESTEROV, 57 | ) 58 | elif cfg.SOLVER.OPTIMIZING_METHOD == "adam": 59 | return torch.optim.Adam( 60 | optim_params, 61 | lr=cfg.SOLVER.BASE_LR, 62 | betas=(0.9, 0.999), 63 | eps=1e-08, 64 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 65 | ) 66 | elif cfg.SOLVER.OPTIMIZING_METHOD == "adamw": 67 | return torch.optim.AdamW( 68 | optim_params, 69 | lr=cfg.SOLVER.BASE_LR, 70 | betas=(0.9, 0.999), 71 | eps=1e-08, 72 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 73 | ) 74 | else: 75 | raise NotImplementedError( 76 | "Does not support {} optimizer".format(cfg.SOLVER.OPTIMIZING_METHOD) 77 | ) 78 | 79 | 80 | def get_epoch_lr(cur_epoch, cfg): 81 | """ 82 | Retrieves the lr for the given epoch (as specified by the lr policy). 83 | Args: 84 | cfg (config): configs of hyper-parameters of ADAM, includes base 85 | learning rate, betas, and weight decays. 86 | cur_epoch (float): the number of epoch of the current training stage. 87 | """ 88 | return lr_policy.get_lr_at_epoch(cfg, cur_epoch) 89 | 90 | 91 | def set_lr(optimizer, new_lr): 92 | """ 93 | Sets the optimizer lr to the specified value. 94 | Args: 95 | optimizer (optim): the optimizer using to optimize the current network. 96 | new_lr (float): the new learning rate to set. 97 | """ 98 | for param_group in optimizer.param_groups: 99 | param_group["lr"] = new_lr 100 | -------------------------------------------------------------------------------- /testa/models/timesformer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /testa/models/timesformer/utils/ava_evaluation/README.md: -------------------------------------------------------------------------------- 1 | The code under this folder is from the official [ActivityNet repo](https://github.com/activitynet/ActivityNet). 
2 | -------------------------------------------------------------------------------- /testa/models/timesformer/utils/ava_evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/testa/models/timesformer/utils/ava_evaluation/__init__.py -------------------------------------------------------------------------------- /testa/models/timesformer/utils/ava_evaluation/np_box_mask_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Numpy BoxMaskList classes and functions.""" 17 | 18 | from __future__ import ( 19 | absolute_import, 20 | division, 21 | print_function, 22 | unicode_literals, 23 | ) 24 | import numpy as np 25 | 26 | from . import np_box_list 27 | 28 | 29 | class BoxMaskList(np_box_list.BoxList): 30 | """Convenience wrapper for BoxList with masks. 31 | 32 | BoxMaskList extends the np_box_list.BoxList to contain masks as well. 33 | In particular, its constructor receives both boxes and masks. Note that the 34 | masks correspond to the full image. 35 | """ 36 | 37 | def __init__(self, box_data, mask_data): 38 | """Constructs box collection. 39 | 40 | Args: 41 | box_data: a numpy array of shape [N, 4] representing box coordinates 42 | mask_data: a numpy array of shape [N, height, width] representing masks 43 | with values are in {0,1}. The masks correspond to the full 44 | image. The height and the width will be equal to image height and width. 45 | 46 | Raises: 47 | ValueError: if bbox data is not a numpy array 48 | ValueError: if invalid dimensions for bbox data 49 | ValueError: if mask data is not a numpy array 50 | ValueError: if invalid dimension for mask data 51 | """ 52 | super(BoxMaskList, self).__init__(box_data) 53 | if not isinstance(mask_data, np.ndarray): 54 | raise ValueError("Mask data must be a numpy array.") 55 | if len(mask_data.shape) != 3: 56 | raise ValueError("Invalid dimensions for mask data.") 57 | if mask_data.dtype != np.uint8: 58 | raise ValueError( 59 | "Invalid data type for mask data: uint8 is required." 60 | ) 61 | if mask_data.shape[0] != box_data.shape[0]: 62 | raise ValueError( 63 | "There should be the same number of boxes and masks." 64 | ) 65 | self.data["masks"] = mask_data 66 | 67 | def get_masks(self): 68 | """Convenience function for accessing masks. 69 | 70 | Returns: 71 | a numpy array of shape [N, height, width] representing masks 72 | """ 73 | return self.get_field("masks") 74 | -------------------------------------------------------------------------------- /testa/models/timesformer/utils/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. 
and its affiliates. All Rights Reserved 2 | """ 3 | Functions for benchmarks. 4 | """ 5 | 6 | import numpy as np 7 | import pprint 8 | import torch 9 | import tqdm 10 | from fvcore.common.timer import Timer 11 | 12 | import timesformer.utils.logging as logging 13 | import timesformer.utils.misc as misc 14 | from timesformer.datasets import loader 15 | from timesformer.utils.env import setup_environment 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | 20 | def benchmark_data_loading(cfg): 21 | """ 22 | Benchmark the speed of data loading in PySlowFast. 23 | Args: 24 | 25 | cfg (CfgNode): configs. Details can be found in 26 | lib/config/defaults.py 27 | """ 28 | # Set up environment. 29 | setup_environment() 30 | # Set random seed from configs. 31 | np.random.seed(cfg.RNG_SEED) 32 | torch.manual_seed(cfg.RNG_SEED) 33 | 34 | # Setup logging format. 35 | logging.setup_logging(cfg.OUTPUT_DIR) 36 | 37 | # Print config. 38 | logger.info("Benchmark data loading with config:") 39 | logger.info(pprint.pformat(cfg)) 40 | 41 | timer = Timer() 42 | dataloader = loader.construct_loader(cfg, "train") 43 | logger.info( 44 | "Initialize loader using {:.2f} seconds.".format(timer.seconds()) 45 | ) 46 | # Total batch size across different machines. 47 | batch_size = cfg.TRAIN.BATCH_SIZE * cfg.NUM_SHARDS 48 | log_period = cfg.BENCHMARK.LOG_PERIOD 49 | epoch_times = [] 50 | # Test for a few epochs. 51 | for cur_epoch in range(cfg.BENCHMARK.NUM_EPOCHS): 52 | timer = Timer() 53 | timer_epoch = Timer() 54 | iter_times = [] 55 | if cfg.BENCHMARK.SHUFFLE: 56 | loader.shuffle_dataset(dataloader, cur_epoch) 57 | for cur_iter, _ in enumerate(tqdm.tqdm(dataloader)): 58 | if cur_iter > 0 and cur_iter % log_period == 0: 59 | iter_times.append(timer.seconds()) 60 | ram_usage, ram_total = misc.cpu_mem_usage() 61 | logger.info( 62 | "Epoch {}: {} iters ({} videos) in {:.2f} seconds. " 63 | "RAM Usage: {:.2f}/{:.2f} GB.".format( 64 | cur_epoch, 65 | log_period, 66 | log_period * batch_size, 67 | iter_times[-1], 68 | ram_usage, 69 | ram_total, 70 | ) 71 | ) 72 | timer.reset() 73 | epoch_times.append(timer_epoch.seconds()) 74 | ram_usage, ram_total = misc.cpu_mem_usage() 75 | logger.info( 76 | "Epoch {}: in total {} iters ({} videos) in {:.2f} seconds. " 77 | "RAM Usage: {:.2f}/{:.2f} GB.".format( 78 | cur_epoch, 79 | len(dataloader), 80 | len(dataloader) * batch_size, 81 | epoch_times[-1], 82 | ram_usage, 83 | ram_total, 84 | ) 85 | ) 86 | logger.info( 87 | "Epoch {}: on average every {} iters ({} videos) take {:.2f}/{:.2f} " 88 | "(avg/std) seconds.".format( 89 | cur_epoch, 90 | log_period, 91 | log_period * batch_size, 92 | np.mean(iter_times), 93 | np.std(iter_times), 94 | ) 95 | ) 96 | logger.info( 97 | "On average every epoch ({} videos) takes {:.2f}/{:.2f} " 98 | "(avg/std) seconds.".format( 99 | len(dataloader) * batch_size, 100 | np.mean(epoch_times), 101 | np.std(epoch_times), 102 | ) 103 | ) 104 | -------------------------------------------------------------------------------- /testa/models/timesformer/utils/bn_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """bn helper.""" 4 | 5 | import itertools 6 | import torch 7 | 8 | 9 | @torch.no_grad() 10 | def compute_and_update_bn_stats(model, data_loader, num_batches=200): 11 | """ 12 | Compute and update the batch norm stats to make it more precise. 
During 13 | training both bn stats and the weight are changing after every iteration, 14 | so the bn can not precisely reflect the latest stats of the current model. 15 | Here the bn stats is recomputed without change of weights, to make the 16 | running mean and running var more precise. 17 | Args: 18 | model (model): the model using to compute and update the bn stats. 19 | data_loader (dataloader): dataloader using to provide inputs. 20 | num_batches (int): running iterations using to compute the stats. 21 | """ 22 | 23 | # Prepares all the bn layers. 24 | bn_layers = [ 25 | m 26 | for m in model.modules() 27 | if any( 28 | ( 29 | isinstance(m, bn_type) 30 | for bn_type in ( 31 | torch.nn.BatchNorm1d, 32 | torch.nn.BatchNorm2d, 33 | torch.nn.BatchNorm3d, 34 | ) 35 | ) 36 | ) 37 | ] 38 | 39 | # In order to make the running stats only reflect the current batch, the 40 | # momentum is disabled. 41 | # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean 42 | # Setting the momentum to 1.0 to compute the stats without momentum. 43 | momentum_actual = [bn.momentum for bn in bn_layers] 44 | for bn in bn_layers: 45 | bn.momentum = 1.0 46 | 47 | # Calculates the running iterations for precise stats computation. 48 | running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers] 49 | running_square_mean = [torch.zeros_like(bn.running_var) for bn in bn_layers] 50 | 51 | for ind, (inputs, _, _) in enumerate( 52 | itertools.islice(data_loader, num_batches) 53 | ): 54 | # Forwards the model to update the bn stats. 55 | if isinstance(inputs, (list,)): 56 | for i in range(len(inputs)): 57 | inputs[i] = inputs[i].float().cuda(non_blocking=True) 58 | else: 59 | inputs = inputs.cuda(non_blocking=True) 60 | model(inputs) 61 | 62 | for i, bn in enumerate(bn_layers): 63 | # Accumulates the bn stats. 64 | running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1) 65 | # $E(x^2) = Var(x) + E(x)^2$. 66 | cur_square_mean = bn.running_var + bn.running_mean ** 2 67 | running_square_mean[i] += ( 68 | cur_square_mean - running_square_mean[i] 69 | ) / (ind + 1) 70 | 71 | for i, bn in enumerate(bn_layers): 72 | bn.running_mean = running_mean[i] 73 | # Var(x) = $E(x^2) - E(x)^2$. 74 | bn.running_var = running_square_mean[i] - bn.running_mean ** 2 75 | # Sets the precise bn stats. 76 | bn.momentum = momentum_actual[i] 77 | -------------------------------------------------------------------------------- /testa/models/timesformer/utils/env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Set up Environment.""" 4 | 5 | import logging as logging 6 | 7 | _ENV_SETUP_DONE = False 8 | 9 | 10 | def setup_environment(): 11 | global _ENV_SETUP_DONE 12 | if _ENV_SETUP_DONE: 13 | return 14 | _ENV_SETUP_DONE = True 15 | -------------------------------------------------------------------------------- /testa/models/timesformer/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Logging.""" 4 | 5 | import atexit 6 | import builtins 7 | import decimal 8 | import functools 9 | import logging 10 | import os 11 | import sys 12 | import simplejson 13 | from fvcore.common.file_io import PathManager 14 | 15 | import timesformer.utils.distributed as du 16 | 17 | 18 | def _suppress_print(): 19 | """ 20 | Suppresses printing from the current process. 
21 | """ 22 | 23 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 24 | pass 25 | 26 | builtins.print = print_pass 27 | 28 | 29 | @functools.lru_cache(maxsize=None) 30 | def _cached_log_stream(filename): 31 | io = PathManager.open(filename, "a", buffering=1024) 32 | atexit.register(io.close) 33 | return io 34 | 35 | 36 | def setup_logging(output_dir=None): 37 | """ 38 | Sets up the logging for multiple processes. Only enable the logging for the 39 | master process, and suppress logging for the non-master processes. 40 | """ 41 | # Set up logging format. 42 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" 43 | 44 | if du.is_master_proc(): 45 | # Enable logging for the master process. 46 | logging.root.handlers = [] 47 | else: 48 | # Suppress logging for non-master processes. 49 | _suppress_print() 50 | 51 | logger = logging.getLogger() 52 | logger.setLevel(logging.DEBUG) 53 | logger.propagate = False 54 | plain_formatter = logging.Formatter( 55 | "[%(asctime)s][%(levelname)s] %(filename)s: %(lineno)3d: %(message)s", 56 | datefmt="%m/%d %H:%M:%S", 57 | ) 58 | 59 | if du.is_master_proc(): 60 | ch = logging.StreamHandler(stream=sys.stdout) 61 | ch.setLevel(logging.DEBUG) 62 | ch.setFormatter(plain_formatter) 63 | logger.addHandler(ch) 64 | 65 | if output_dir is not None and du.is_master_proc(du.get_world_size()): 66 | filename = os.path.join(output_dir, "stdout.log") 67 | fh = logging.StreamHandler(_cached_log_stream(filename)) 68 | fh.setLevel(logging.DEBUG) 69 | fh.setFormatter(plain_formatter) 70 | logger.addHandler(fh) 71 | 72 | 73 | def get_logger(name): 74 | """ 75 | Retrieve the logger with the specified name or, if name is None, return a 76 | logger which is the root logger of the hierarchy. 77 | Args: 78 | name (string): name of the logger. 79 | """ 80 | return logging.getLogger(name) 81 | 82 | 83 | def log_json_stats(stats): 84 | """ 85 | Logs json stats. 86 | Args: 87 | stats (dict): a dictionary of statistical information to log. 88 | """ 89 | stats = { 90 | k: decimal.Decimal("{:.5f}".format(v)) if isinstance(v, float) else v 91 | for k, v in stats.items() 92 | } 93 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 94 | logger = get_logger(__name__) 95 | logger.info("json_stats: {:s}".format(json_stats)) 96 | -------------------------------------------------------------------------------- /testa/models/timesformer/utils/lr_policy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Learning rate policy.""" 4 | 5 | import math 6 | 7 | 8 | def get_lr_at_epoch(cfg, cur_epoch): 9 | """ 10 | Retrieve the learning rate of the current epoch with the option to perform 11 | warm up in the beginning of the training stage. 12 | Args: 13 | cfg (CfgNode): configs. Details can be found in 14 | slowfast/config/defaults.py 15 | cur_epoch (float): the number of epoch of the current training stage. 16 | """ 17 | lr = get_lr_func(cfg.SOLVER.LR_POLICY)(cfg, cur_epoch) 18 | # Perform warm up. 
19 | if cur_epoch < cfg.SOLVER.WARMUP_EPOCHS: 20 | lr_start = cfg.SOLVER.WARMUP_START_LR 21 | lr_end = get_lr_func(cfg.SOLVER.LR_POLICY)( 22 | cfg, cfg.SOLVER.WARMUP_EPOCHS 23 | ) 24 | alpha = (lr_end - lr_start) / cfg.SOLVER.WARMUP_EPOCHS 25 | lr = cur_epoch * alpha + lr_start 26 | return lr 27 | 28 | 29 | def lr_func_cosine(cfg, cur_epoch): 30 | """ 31 | Retrieve the learning rate to specified values at specified epoch with the 32 | cosine learning rate schedule. Details can be found in: 33 | Ilya Loshchilov, and Frank Hutter 34 | SGDR: Stochastic Gradient Descent With Warm Restarts. 35 | Args: 36 | cfg (CfgNode): configs. Details can be found in 37 | slowfast/config/defaults.py 38 | cur_epoch (float): the number of epoch of the current training stage. 39 | """ 40 | assert cfg.SOLVER.COSINE_END_LR < cfg.SOLVER.BASE_LR 41 | return ( 42 | cfg.SOLVER.COSINE_END_LR 43 | + (cfg.SOLVER.BASE_LR - cfg.SOLVER.COSINE_END_LR) 44 | * (math.cos(math.pi * cur_epoch / cfg.SOLVER.MAX_EPOCH) + 1.0) 45 | * 0.5 46 | ) 47 | 48 | 49 | def lr_func_steps_with_relative_lrs(cfg, cur_epoch): 50 | """ 51 | Retrieve the learning rate to specified values at specified epoch with the 52 | steps with relative learning rate schedule. 53 | Args: 54 | cfg (CfgNode): configs. Details can be found in 55 | slowfast/config/defaults.py 56 | cur_epoch (float): the number of epoch of the current training stage. 57 | """ 58 | ind = get_step_index(cfg, cur_epoch) 59 | return cfg.SOLVER.LRS[ind] * cfg.SOLVER.BASE_LR 60 | 61 | 62 | def get_step_index(cfg, cur_epoch): 63 | """ 64 | Retrieves the lr step index for the given epoch. 65 | Args: 66 | cfg (CfgNode): configs. Details can be found in 67 | slowfast/config/defaults.py 68 | cur_epoch (float): the number of epoch of the current training stage. 69 | """ 70 | steps = cfg.SOLVER.STEPS + [cfg.SOLVER.MAX_EPOCH] 71 | for ind, step in enumerate(steps): # NoQA 72 | if cur_epoch < step: 73 | break 74 | return ind - 1 75 | 76 | 77 | def get_lr_func(lr_policy): 78 | """ 79 | Given the configs, retrieve the specified lr policy function. 80 | Args: 81 | lr_policy (string): the learning rate policy to use for the job. 82 | """ 83 | policy = "lr_func_" + lr_policy 84 | if policy not in globals(): 85 | raise NotImplementedError("Unknown LR policy: {}".format(lr_policy)) 86 | else: 87 | return globals()[policy] 88 | -------------------------------------------------------------------------------- /testa/models/timesformer/utils/multiprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Multiprocessing helpers.""" 4 | 5 | import torch 6 | 7 | 8 | def run( 9 | local_rank, 10 | num_proc, 11 | func, 12 | init_method, 13 | shard_id, 14 | num_shards, 15 | backend, 16 | cfg, 17 | output_queue=None, 18 | ): 19 | """ 20 | Runs a function from a child process. 21 | Args: 22 | local_rank (int): rank of the current process on the current machine. 23 | num_proc (int): number of processes per machine. 24 | func (function): function to execute on each of the process. 25 | init_method (string): method to initialize the distributed training. 26 | TCP initialization: equiring a network address reachable from all 27 | processes followed by the port. 28 | Shared file-system initialization: makes use of a file system that 29 | is shared and visible from all machines. The URL should start with 30 | file:// and contain a path to a non-existent file on a shared file 31 | system. 
32 | shard_id (int): the rank of the current machine. 33 | num_shards (int): number of overall machines for the distributed 34 | training job. 35 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are 36 | supports, each with different capabilities. Details can be found 37 | here: 38 | https://pytorch.org/docs/stable/distributed.html 39 | cfg (CfgNode): configs. Details can be found in 40 | slowfast/config/defaults.py 41 | output_queue (queue): can optionally be used to return values from the 42 | master process. 43 | """ 44 | # Initialize the process group. 45 | world_size = num_proc * num_shards 46 | rank = shard_id * num_proc + local_rank 47 | 48 | try: 49 | torch.distributed.init_process_group( 50 | backend=backend, 51 | init_method=init_method, 52 | world_size=world_size, 53 | rank=rank, 54 | ) 55 | except Exception as e: 56 | raise e 57 | 58 | torch.cuda.set_device(local_rank) 59 | ret = func(cfg) 60 | if output_queue is not None and local_rank == 0: 61 | output_queue.put(ret) 62 | -------------------------------------------------------------------------------- /testa/models/timesformer/utils/parser.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Argument parser functions.""" 4 | 5 | import argparse 6 | import sys 7 | 8 | import timesformer.utils.checkpoint as cu 9 | from timesformer.config.defaults import get_cfg 10 | 11 | 12 | def parse_args(): 13 | """ 14 | Parse the following arguments for a default parser for PySlowFast users. 15 | Args: 16 | shard_id (int): shard id for the current machine. Starts from 0 to 17 | num_shards - 1. If single machine is used, then set shard id to 0. 18 | num_shards (int): number of shards using by the job. 19 | init_method (str): initialization method to launch the job with multiple 20 | devices. Options includes TCP or shared file-system for 21 | initialization. details can be find in 22 | https://pytorch.org/docs/stable/distributed.html#tcp-initialization 23 | cfg (str): path to the config file. 24 | opts (argument): provide addtional options from the command line, it 25 | overwrites the config loaded from file. 26 | """ 27 | parser = argparse.ArgumentParser( 28 | description="Provide SlowFast video training and testing pipeline." 29 | ) 30 | parser.add_argument( 31 | "--shard_id", 32 | help="The shard id of current node, Starts from 0 to num_shards - 1", 33 | default=0, 34 | type=int, 35 | ) 36 | parser.add_argument( 37 | "--num_shards", 38 | help="Number of shards using by the job", 39 | default=1, 40 | type=int, 41 | ) 42 | parser.add_argument( 43 | "--init_method", 44 | help="Initialization method, includes TCP or shared file-system", 45 | default="tcp://localhost:9999", 46 | type=str, 47 | ) 48 | parser.add_argument( 49 | "--cfg", 50 | dest="cfg_file", 51 | help="Path to the config file", 52 | default="configs/Kinetics/SLOWFAST_4x16_R50.yaml", 53 | type=str, 54 | ) 55 | parser.add_argument( 56 | "opts", 57 | help="See slowfast/config/defaults.py for all options", 58 | default=None, 59 | nargs=argparse.REMAINDER, 60 | ) 61 | if len(sys.argv) == 1: 62 | parser.print_help() 63 | return parser.parse_args() 64 | 65 | 66 | def load_config(args): 67 | """ 68 | Given the arguemnts, load and initialize the configs. 69 | Args: 70 | args (argument): arguments includes `shard_id`, `num_shards`, 71 | `init_method`, `cfg_file`, and `opts`. 72 | """ 73 | # Setup cfg. 74 | cfg = get_cfg() 75 | # Load config from cfg. 
76 | if args.cfg_file is not None: 77 | cfg.merge_from_file(args.cfg_file) 78 | # Load config from command line, overwrite config from opts. 79 | if args.opts is not None: 80 | cfg.merge_from_list(args.opts) 81 | 82 | # Inherit parameters from args. 83 | if hasattr(args, "num_shards") and hasattr(args, "shard_id"): 84 | cfg.NUM_SHARDS = args.num_shards 85 | cfg.SHARD_ID = args.shard_id 86 | if hasattr(args, "rng_seed"): 87 | cfg.RNG_SEED = args.rng_seed 88 | if hasattr(args, "output_dir"): 89 | cfg.OUTPUT_DIR = args.output_dir 90 | 91 | # Create the checkpoint dir. 92 | cu.make_checkpoint_dir(cfg.OUTPUT_DIR) 93 | return cfg 94 | -------------------------------------------------------------------------------- /testa/models/timesformer/utils/weight_init_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Utility function for weight initialization""" 4 | 5 | import torch.nn as nn 6 | from fvcore.nn.weight_init import c2_msra_fill 7 | 8 | 9 | def init_weights(model, fc_init_std=0.01, zero_init_final_bn=True): 10 | """ 11 | Performs ResNet style weight initialization. 12 | Args: 13 | fc_init_std (float): the expected standard deviation for fc layer. 14 | zero_init_final_bn (bool): if True, zero initialize the final bn for 15 | every bottleneck. 16 | """ 17 | for m in model.modules(): 18 | if isinstance(m, nn.Conv3d): 19 | """ 20 | Follow the initialization method proposed in: 21 | {He, Kaiming, et al. 22 | "Delving deep into rectifiers: Surpassing human-level 23 | performance on imagenet classification." 24 | arXiv preprint arXiv:1502.01852 (2015)} 25 | """ 26 | c2_msra_fill(m) 27 | elif isinstance(m, nn.BatchNorm3d): 28 | if ( 29 | hasattr(m, "transform_final_bn") 30 | and m.transform_final_bn 31 | and zero_init_final_bn 32 | ): 33 | batchnorm_weight = 0.0 34 | else: 35 | batchnorm_weight = 1.0 36 | if m.weight is not None: 37 | m.weight.data.fill_(batchnorm_weight) 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | if isinstance(m, nn.Linear): 41 | m.weight.data.normal_(mean=0.0, std=fc_init_std) 42 | if m.bias is not None: 43 | m.bias.data.zero_() 44 | -------------------------------------------------------------------------------- /testa/models/timesformer/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /testa/testa/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/facebookresearch/ToMe 3 | ''' 4 | 5 | from . 
import merge, patch, utils 6 | from .vis import make_visualization 7 | 8 | __all__ = ["utils", "merge", "patch", "make_visualization"] 9 | -------------------------------------------------------------------------------- /testa/testa/patch/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/facebookresearch/ToMe 3 | ''' 4 | 5 | from .vit import apply_patch as vit 6 | from .timesformer import apply_patch as timesformer 7 | 8 | __all__ = ["vit", "timesformer"] 9 | -------------------------------------------------------------------------------- /testa/testa/vis.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/facebookresearch/ToMe 3 | ''' 4 | 5 | import copy 6 | import random 7 | from typing import List, Tuple 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from PIL import Image 13 | 14 | try: 15 | from scipy.ndimage import binary_erosion 16 | except ImportError: 17 | pass # Don't fail if scipy is not installed. It's only necessary for this one file. 18 | 19 | 20 | def generate_colormap(N: int, seed: int = 0) -> List[Tuple[float, float, float]]: 21 | """Generates a equidistant colormap with N elements.""" 22 | random.seed(seed) 23 | 24 | def generate_color(): 25 | return (random.random(), random.random(), random.random()) 26 | 27 | return [generate_color() for _ in range(N)] 28 | 29 | 30 | def make_visualization( 31 | img: Image, source: torch.Tensor, patch_size: int = 16, class_token: bool = True 32 | ) -> Image: 33 | """ 34 | Create a visualization like in the paper. 35 | 36 | Args: 37 | - 38 | 39 | Returns: 40 | - A PIL image the same size as the input. 41 | """ 42 | 43 | img = np.array(img.convert("RGB")) / 255.0 44 | source = source.detach().cpu() 45 | 46 | h, w, _ = img.shape 47 | ph = h // patch_size 48 | pw = w // patch_size 49 | 50 | if class_token: 51 | source = source[:, :, 1:] 52 | 53 | vis = source.argmax(dim=1) 54 | num_groups = vis.max().item() + 1 55 | 56 | cmap = generate_colormap(num_groups) 57 | vis_img = 0 58 | colors = [] 59 | 60 | for i in range(num_groups): 61 | mask = (vis == i).float().view(1, 1, ph, pw) 62 | mask = F.interpolate(mask, size=(h, w), mode="nearest") 63 | mask = mask.view(h, w, 1).numpy() 64 | 65 | color = (mask * img).sum(axis=(0, 1)) / mask.sum() 66 | colors.append(color) 67 | mask_eroded = binary_erosion(mask[..., 0])[..., None] 68 | mask_edge = mask - mask_eroded 69 | 70 | if not np.isfinite(color).all(): 71 | color = np.zeros(3) 72 | 73 | vis_img = vis_img + mask_eroded * color.reshape(1, 1, 3) 74 | vis_img = vis_img + mask_edge * np.array(cmap[i]).reshape(1, 1, 3) 75 | img = img * (1 - mask_edge) + mask_edge * np.array(cmap[i]).reshape(1, 1, 3) 76 | 77 | # Convert back into a PIL image 78 | vis_img = Image.fromarray(np.uint8(vis_img * 255)) 79 | img = Image.fromarray(np.uint8(img * 255)) 80 | 81 | # final_img = Image.blend(img, vis_img, alpha=0.7) 82 | return vis_img, cmap, colors, img 83 | -------------------------------------------------------------------------------- /v2tactiongraph/alvs_inter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def sample_vectors(f_v, k): 5 | """ 6 | Uniformly sample k vectors from f_v. 
7 | """ 8 | t = f_v.size(0) 9 | indices = torch.linspace(0, t - 1, steps=k).long() 10 | sampled_vectors = f_v[indices] 11 | return sampled_vectors, indices 12 | 13 | def cosine_similarity(a, b): 14 | """ 15 | Calculate the cosine similarity between two tensors a and b. 16 | """ 17 | return F.cosine_similarity(a, b, dim=-1) 18 | 19 | def calculate_attention(f_qst, f_v_sampled): 20 | """ 21 | Calculate attention values 22 | """ 23 | q = f_qst.unsqueeze(0) # Shape: [1, d] 24 | k = f_v_sampled # Shape: [k, d] 25 | v = f_v_sampled # Shape: [k, d] 26 | 27 | # Attention calculation: att_weights = softmax(q * k^T / sqrt(d)) 28 | d = f_qst.size(-1) 29 | att_weights = F.softmax(torch.matmul(q, k.T) / (d ** 0.5), dim=-1) 30 | att_f_v = torch.matmul(att_weights, v) # Shape: [1, d] 31 | return att_f_v.squeeze(0) # Shape: [d] 32 | 33 | def process_numbers(number_list): 34 | unique_numbers = sorted(set(int(num) for num in number_list)) 35 | return unique_numbers 36 | 37 | def iterative_sampling(f_v, f_text, k, m, a1, a2): 38 | """ 39 | Perform the iterative sampling process. 40 | """ 41 | t, d = f_v.shape 42 | indices_record = [] 43 | iter_samples = [0] 44 | 45 | 46 | for _ in range(m): 47 | # Uniformly sample k vectors from f_v 48 | f_v_sampled, sampled_indices = sample_vectors(f_v, k) 49 | sampled_indices += sum(iter_samples) 50 | 51 | # Compute cosine similarity between consecutive vectors 52 | sim1 = cosine_similarity(f_v_sampled[:-1], f_v_sampled[1:]) 53 | sim1 = torch.cat([sim1, sim1[-1].unsqueeze(0)]) # Ensure last and second-last are the same 54 | 55 | # Calculate att_f_v 56 | att_f_v = torch.stack([calculate_attention(f_text, f_v_sampled[i].unsqueeze(0)) for i in range(k)]) 57 | 58 | # Compute cosine similarity between consecutive att_f_v vectors 59 | sim2 = cosine_similarity(att_f_v[:-1], att_f_v[1:]) 60 | sim2 = torch.cat([sim2, sim2[-1].unsqueeze(0)]) # Ensure last and second-last are the same 61 | 62 | # Sum and find the max index 63 | sim = a1 * sim1 + a2 * sim2 64 | max_sim_index = torch.argmax(sim) 65 | 66 | # Use max_sim_index as center, select new vectors 67 | center_idx = sampled_indices[max_sim_index] 68 | start_idx = max(0, center_idx - t // k) 69 | end_idx = min(t, center_idx + t // k) 70 | iter_samples.append(start_idx) 71 | f_v = f_v[start_idx:end_idx] 72 | indices_record += list(sampled_indices) 73 | indices_record = process_numbers(indices_record) 74 | 75 | return indices_record -------------------------------------------------------------------------------- /v2tactiongraph/feature_extractor/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import threading 4 | from torch._utils import ExceptionWrapper 5 | import logging 6 | 7 | def get_a_var(obj): 8 | if isinstance(obj, torch.Tensor): 9 | return obj 10 | 11 | if isinstance(obj, list) or isinstance(obj, tuple): 12 | for result in map(get_a_var, obj): 13 | if isinstance(result, torch.Tensor): 14 | return result 15 | if isinstance(obj, dict): 16 | for result in map(get_a_var, obj.items()): 17 | if isinstance(result, torch.Tensor): 18 | return result 19 | return None 20 | 21 | def parallel_apply(fct, model, inputs, device_ids): 22 | modules = nn.parallel.replicate(model, device_ids) 23 | assert len(modules) == len(inputs) 24 | lock = threading.Lock() 25 | results = {} 26 | grad_enabled = torch.is_grad_enabled() 27 | 28 | def _worker(i, module, input): 29 | torch.set_grad_enabled(grad_enabled) 30 | device = get_a_var(input).get_device() 31 | try: 
32 | with torch.cuda.device(device): 33 | # this also avoids accidental slicing of `input` if it is a Tensor 34 | if not isinstance(input, (list, tuple)): 35 | input = (input,) 36 | output = fct(module, *input) 37 | with lock: 38 | results[i] = output 39 | except Exception: 40 | with lock: 41 | results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device)) 42 | 43 | if len(modules) > 1: 44 | threads = [threading.Thread(target=_worker, args=(i, module, input)) 45 | for i, (module, input) in enumerate(zip(modules, inputs))] 46 | 47 | for thread in threads: 48 | thread.start() 49 | for thread in threads: 50 | thread.join() 51 | else: 52 | _worker(0, modules[0], inputs[0]) 53 | 54 | outputs = [] 55 | for i in range(len(inputs)): 56 | output = results[i] 57 | if isinstance(output, ExceptionWrapper): 58 | output.reraise() 59 | outputs.append(output) 60 | return outputs 61 | 62 | def get_logger(filename=None): 63 | logger = logging.getLogger('logger') 64 | logger.setLevel(logging.DEBUG) 65 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', 66 | datefmt='%m/%d/%Y %H:%M:%S', 67 | level=logging.INFO) 68 | if filename is not None: 69 | handler = logging.FileHandler(filename) 70 | handler.setLevel(logging.DEBUG) 71 | handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 72 | logging.getLogger().addHandler(handler) 73 | return logger -------------------------------------------------------------------------------- /v2tactiongraph/feature_extractor/utility/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | # we do not used this class 3 | class CustomDataset(Dataset): 4 | def __init__(self, fo_input, stgraph, target): 5 | 6 | self.fo_input = fo_input 7 | self.stgraph = stgraph 8 | self.target = target 9 | self.n_samples = len(fo_input) 10 | 11 | def __getitem__(self, index): 12 | return self.fo_input[index], self.stgraph[index], self.target[index] 13 | 14 | def __len__(self): 15 | return self.n_samples 16 | -------------------------------------------------------------------------------- /v2tactiongraph/feature_extractor/utility/vocabulary.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | # we do not used this vocabulary class 3 | class Vocabulary: 4 | PAD_token = 0 # Used for padding short sentences 5 | BOS_token = 1 # Beginning-of-sentence token 6 | EOS_token = 2 # End-of-sentence token 7 | UNK_token = 3 # Unknown word token 8 | 9 | def __init__(self): 10 | self.word2index = {} 11 | self.word2count = {} 12 | self.index2word = {self.PAD_token: "", self.BOS_token: "", self.EOS_token: "", self.UNK_token: ""} 13 | self.num_words = 4 14 | self.num_sentences = 0 15 | self.longest_sentence = 0 16 | self.tokenizer = spacy.load('en_core_web_sm') 17 | 18 | def add_word(self, word): 19 | if word not in self.word2index: 20 | # First entry of word into vocabulary 21 | self.word2index[word] = self.num_words 22 | self.word2count[word] = 1 23 | self.index2word[self.num_words] = word 24 | self.num_words += 1 25 | else: 26 | # Word exists; increase word count 27 | self.word2count[word] += 1 28 | 29 | def add_sentence(self, sentence): 30 | sentence_len = 0 31 | for word in self.tokenizer(sentence): 32 | sentence_len += 1 33 | self.add_word(str(word)) 34 | if sentence_len > self.longest_sentence: 35 | # This is the longest sentence 36 | self.longest_sentence = sentence_len 37 | # Count the number of sentences 38 
| self.num_sentences += 1 39 | 40 | def generate_vector(self, sentence="Hello", longest_sentence=None): 41 | # Validation data/test data may have longer sentence, so a parameter longest sentence provided 42 | if longest_sentence is None: 43 | longest_sentence = self.longest_sentence 44 | 45 | vector = [self.BOS_token] 46 | sentence_len = 0 47 | for word in self.tokenizer(sentence): 48 | vector.append(self.to_index(str(word))) 49 | sentence_len += 1 50 | vector.append(self.EOS_token) 51 | 52 | # Add token if needed 53 | if sentence_len < longest_sentence: 54 | for i in range(sentence_len, longest_sentence): 55 | vector.append(self.PAD_token) 56 | 57 | return vector 58 | 59 | def to_word(self, index): 60 | return self.index2word[index] 61 | 62 | def to_index(self, word): 63 | if word not in self.word2index: 64 | return self.UNK_token 65 | 66 | return self.word2index[word] 67 | 68 | def filter_vocab(self, min_word_count=0): 69 | word2count = self.word2count 70 | self.num_words = 4 71 | self.word2index = {} 72 | self.word2count = {} 73 | self.index2word = {self.PAD_token: "", self.BOS_token: "", self.EOS_token: "", self.UNK_token: ""} 74 | for word, count in word2count.items(): 75 | if count>=min_word_count: 76 | self.word2index[word] = self.num_words 77 | self.word2count[word] = count 78 | self.index2word[self.num_words] = word 79 | self.num_words += 1 80 | -------------------------------------------------------------------------------- /v2tactiongraph/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xid32/NAACL_2025_TWM/f36facd4a1388962e814723883c0bfabafd8c0ac/v2tactiongraph/modules/__init__.py -------------------------------------------------------------------------------- /v2tactiongraph/modules/decoder-base/decoder_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "num_attention_heads": 12, 9 | "num_hidden_layers": 12, 10 | "type_vocab_size": 2, 11 | "vocab_size": 30522, 12 | "num_decoder_layers": 1, 13 | "max_target_embeddings": 512 14 | } 15 | -------------------------------------------------------------------------------- /v2tactiongraph/modules/gnn/GATConvolution.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch_geometric.nn import GATConv 4 | 5 | class LinearProj(nn.Module): 6 | """Implementation of the linear projection layer for GATC. 7 | Params: 8 | node_feat_dim: input node feature dimension. 9 | d_model: output dimension. 10 | """ 11 | def __init__(self, node_feat_dim, d_model): 12 | super(LinearProj, self).__init__() 13 | 14 | self.proj = nn.Linear(node_feat_dim, d_model) 15 | 16 | def forward(self, x): 17 | return self.proj(x) 18 | 19 | class GATC(nn.Module): 20 | """Implementation of the transformer heads module from the `"Graph Attention Networks" 21 | ` paper. 22 | Params: 23 | node_feat_dim: input node feature dimension. 24 | d_model: output dimension of linear projection. 25 | edge_dim: edge feature dimension. 26 | heads: total head. Default: 4 27 | project_edge_dim: projection of edge dimension 28 | more_skip: whether to use skip connection. Default: True 29 | last_average: whether to average the multi-head attentions. 
Default: False 30 | """ 31 | def __init__(self, node_feat_dim, d_model, edge_dim, heads=4, project_edge_dim=None, more_skip=True, last_average=False): 32 | super().__init__() 33 | self.lp = LinearProj(node_feat_dim, d_model) 34 | self.more_skip = more_skip 35 | self.project_edge_dim = project_edge_dim 36 | if self.project_edge_dim is not None: 37 | self.lp_edge_attr = nn.Linear(edge_dim, project_edge_dim) 38 | edge_dim = project_edge_dim 39 | 40 | self.conv1 = GATConv(d_model, int(d_model/heads), heads, edge_dim=edge_dim, aggr='mean') 41 | 42 | self.conv2 = GATConv(d_model, int(d_model/heads), heads, edge_dim=edge_dim, aggr='mean') 43 | 44 | if last_average: 45 | self.conv3 = GATConv(d_model, d_model, heads, concat=False, edge_dim=edge_dim, aggr='mean') 46 | else: 47 | self.conv3 = GATConv(d_model, int(d_model/heads), heads, edge_dim=edge_dim, aggr='mean') 48 | 49 | def forward(self, data): 50 | x = self.lp(data.x) 51 | if self.project_edge_dim is not None: 52 | e = self.lp_edge_attr(data.edge_attr) 53 | else: 54 | e = data.edge_attr 55 | 56 | if self.more_skip: 57 | x = F.relu(x + self.conv1(x, data.edge_index, e)) 58 | x = F.relu(x + self.conv2(x, data.edge_index, e)) 59 | x = F.relu(x + self.conv3(x, data.edge_index, e)) 60 | else: 61 | x = F.relu(self.conv1(x, data.edge_index, e)) 62 | x = F.relu(self.conv2(x, data.edge_index, e)) 63 | x = F.relu(self.conv3(x, data.edge_index, e)) 64 | return x -------------------------------------------------------------------------------- /v2tactiongraph/modules/gnn/GATv2Convolution.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch_geometric.nn import GATv2Conv 4 | 5 | class LinearProj(nn.Module): 6 | """Implementation of the linear projection layer for GATv2C. 7 | Params: 8 | node_feat_dim: input node feature dimension. 9 | d_model: output dimension. 10 | """ 11 | def __init__(self, node_feat_dim, d_model): 12 | super(LinearProj, self).__init__() 13 | 14 | self.proj = nn.Linear(node_feat_dim, d_model) 15 | 16 | def forward(self, x): 17 | return self.proj(x) 18 | 19 | class GATv2C(nn.Module): 20 | """Implementation of the transformer heads module from the "How Attentive are Graph Attention Networks?" 21 | ` paper. 22 | Params: 23 | node_feat_dim: input node feature dimension. 24 | d_model: output dimension of linear projection. 25 | edge_dim: edge feature dimension. 26 | heads: total head. Default: 4 27 | project_edge_dim: projection of edge dimension 28 | more_skip: whether to use skip connection. Default: True 29 | last_average: whether to average the multi-head attentions. 
Default: False 30 | """ 31 | def __init__(self, node_feat_dim, d_model, edge_dim, heads=4, project_edge_dim=None, more_skip=True, last_average=False): 32 | super().__init__() 33 | self.lp = LinearProj(node_feat_dim, d_model) 34 | self.more_skip = more_skip 35 | 36 | self.project_edge_dim = project_edge_dim 37 | if self.project_edge_dim is not None: 38 | self.lp_edge_attr = nn.Linear(edge_dim, project_edge_dim) 39 | edge_dim = project_edge_dim 40 | 41 | self.conv1 = GATv2Conv(d_model, int(d_model/heads), heads, edge_dim=edge_dim, aggr='mean') 42 | 43 | self.conv2 = GATv2Conv(d_model, int(d_model/heads), heads, edge_dim=edge_dim, aggr='mean') 44 | 45 | if last_average: 46 | self.conv3 = GATv2Conv(d_model, d_model, heads, concat=False, edge_dim=edge_dim, aggr='mean') 47 | else: 48 | self.conv3 = GATv2Conv(d_model, int(d_model/heads), heads, edge_dim=edge_dim, aggr='mean') 49 | 50 | def forward(self, data): 51 | x = self.lp(data.x) 52 | if self.project_edge_dim is not None: 53 | e = self.lp_edge_attr(data.edge_attr) 54 | else: 55 | e = data.edge_attr 56 | if self.more_skip: 57 | x = F.relu(x + self.conv1(x, data.edge_index, e)) 58 | x = F.relu(x + self.conv2(x, data.edge_index, e)) 59 | x = F.relu(x + self.conv3(x, data.edge_index, e)) 60 | else: 61 | x = F.relu(self.conv1(x, data.edge_index, e)) 62 | x = F.relu(self.conv2(x, data.edge_index, e)) 63 | x = F.relu(self.conv3(x, data.edge_index, e)) 64 | return x 65 | -------------------------------------------------------------------------------- /v2tactiongraph/modules/gnn/TransformerConvolution.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch_geometric.nn import TransformerConv 4 | 5 | class LinearProj(nn.Module): 6 | """Implementation of the linear projection layer for TransC. 7 | Params: 8 | node_feat_dim: input node feature dimension. 9 | d_model: output dimension. 10 | """ 11 | def __init__(self, node_feat_dim, d_model): 12 | super(LinearProj, self).__init__() 13 | 14 | self.proj = nn.Linear(node_feat_dim, d_model) 15 | 16 | def forward(self, x): 17 | return self.proj(x) 18 | 19 | class TransC(nn.Module): 20 | """Implementation of the transformer heads module from the `"Masked Label Prediction: Unified Message 21 | Passing Model for Semi-Supervised Classification" ` paper. 22 | Params: 23 | node_feat_dim: input node feature dimension. 24 | d_model: output dimension of linear projection. 25 | edge_dim: edge feature dimension. 26 | heads: total head. Default: 4 27 | project_edge_dim: projection of edge dimension 28 | more_skip: whether to use skip connection. Default: True 29 | last_average: whether to average the multi-head attentions. Default: False 30 | beta: whether to enable feature combination using beta trade-off (see TransformerConv for 31 | more detailed formula). 
Default: True 32 | """ 33 | def __init__(self, node_feat_dim, d_model, edge_dim, heads=4, project_edge_dim=None, more_skip=True, last_average=False, beta=True): 34 | super().__init__() 35 | self.lp = LinearProj(node_feat_dim, d_model) 36 | self.more_skip = more_skip 37 | self.project_edge_dim = project_edge_dim 38 | if self.project_edge_dim is not None: 39 | self.lp_edge_attr = nn.Linear(edge_dim, project_edge_dim) 40 | edge_dim = project_edge_dim 41 | 42 | self.conv1 = TransformerConv(d_model, int(d_model/heads), heads, edge_dim=edge_dim, aggr='mean', beta=beta) 43 | 44 | self.conv2 = TransformerConv(d_model, int(d_model/heads), heads, edge_dim=edge_dim, aggr='mean', beta=beta) 45 | 46 | if last_average: 47 | self.conv3 = TransformerConv(d_model, d_model, heads, concat=False, edge_dim=edge_dim, aggr='mean', beta=beta) 48 | else: 49 | self.conv3 = TransformerConv(d_model, int(d_model/heads), heads, edge_dim=edge_dim, aggr='mean', beta=beta) 50 | 51 | def forward(self, data): 52 | x = self.lp(data.x) 53 | if self.project_edge_dim is not None: 54 | e = F.relu(self.lp_edge_attr(data.edge_attr)) 55 | else: 56 | e = data.edge_attr 57 | if self.more_skip: 58 | x = F.relu(x + self.conv1(x, data.edge_index, e)) 59 | x = F.relu(x + self.conv2(x, data.edge_index, e)) 60 | x = F.relu(x + self.conv3(x, data.edge_index, e)) 61 | else: 62 | x = F.relu(self.conv1(x, data.edge_index, e)) 63 | x = F.relu(self.conv2(x, data.edge_index, e)) 64 | x = F.relu(self.conv3(x, data.edge_index, e)) 65 | return x -------------------------------------------------------------------------------- /v2tactiongraph/modules/graph_gat_modelling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.modeling import CaptionGenerator 3 | from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 4 | from torch import nn 5 | from torch.nn import KLDivLoss 6 | from modules.gnn.GATConvolution import GATC 7 | 8 | 9 | class GraphGAT(nn.Module): 10 | """Implementation of the graph model using GAT head. 11 | Params: 12 | caption_generator_state_dict: model's saved state. 13 | caption_generator_cache_dir: cache directory location. 14 | args: args variable from main script. 
15 | """ 16 | def __init__(self, caption_generator_state_dict=None, caption_generator_cache_dir=None, args=None): 17 | super().__init__() 18 | self.gatc = GATC(node_feat_dim=args.node_feat_dim, d_model=args.d_model, edge_dim=args.edge_dim, 19 | project_edge_dim=args.project_edge_dim, more_skip=args.no_skip==False, last_average=args.last_average) 20 | self.caption_generator_model = CaptionGenerator.from_pretrained(args.bert_model, args.visual_model, args.decoder_model, 21 | cache_dir=caption_generator_cache_dir, state_dict=caption_generator_state_dict, task_config=args) 22 | self.avgPool = nn.AvgPool2d((args.num_object,1)) # number of patches or objects 23 | self.kl_loss_fct = KLDivLoss(reduction='batchmean') 24 | self.lp = nn.Linear(1024, 512) 25 | 26 | def forward(self, geo_graph, video_mask=None, 27 | input_caption_ids=None, decoder_mask=None, batch_size=None, n_node=None): 28 | fo_convolved = self.gatc(geo_graph) 29 | 30 | fo_convolved = fo_convolved.unflatten(0, (batch_size,n_node)) 31 | fo_convolved = self.avgPool(fo_convolved) 32 | 33 | decoder_scores = self.caption_generator_model(fo_convolved, video_mask, input_caption_ids=input_caption_ids, decoder_mask=decoder_mask) 34 | return decoder_scores 35 | -------------------------------------------------------------------------------- /v2tactiongraph/modules/graph_gatv2_modelling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.modeling import CaptionGenerator 3 | from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 4 | from torch import nn 5 | from torch.nn import KLDivLoss 6 | from modules.gnn.GATv2Convolution import GATv2C 7 | 8 | 9 | class GraphGATv2(nn.Module): 10 | """Implementation of the graph model using GATv2 head. 11 | Params: 12 | caption_generator_state_dict: model's saved state. 13 | caption_generator_cache_dir: cache directory location. 14 | args: args variable from main script. 
15 | """ 16 | def __init__(self, caption_generator_state_dict=None, caption_generator_cache_dir=None, args=None): 17 | super().__init__() 18 | self.gatv2c = GATv2C(node_feat_dim=args.node_feat_dim, d_model=args.d_model, edge_dim=args.edge_dim, 19 | project_edge_dim=args.project_edge_dim, more_skip=args.no_skip==False, last_average=args.last_average) 20 | self.caption_generator_model = CaptionGenerator.from_pretrained(args.bert_model, args.visual_model, args.decoder_model, 21 | cache_dir=caption_generator_cache_dir, state_dict=caption_generator_state_dict, task_config=args) 22 | self.avgPool = nn.AvgPool2d((args.num_object,1)) # number of patches or objects 23 | self.kl_loss_fct = KLDivLoss(reduction='batchmean') 24 | self.lp = nn.Linear(1024, 512) 25 | 26 | def forward(self, geo_graph, video_mask=None, 27 | input_caption_ids=None, decoder_mask=None, batch_size=None, n_node=None): 28 | fo_convolved = self.gatv2c(geo_graph) 29 | 30 | fo_convolved = fo_convolved.unflatten(0, (batch_size,n_node)) 31 | fo_convolved = self.avgPool(fo_convolved) 32 | 33 | decoder_scores = self.caption_generator_model(fo_convolved, video_mask, input_caption_ids=input_caption_ids, decoder_mask=decoder_mask) 34 | return decoder_scores 35 | -------------------------------------------------------------------------------- /v2tactiongraph/modules/graph_transformer_modelling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.modeling import CaptionGenerator 3 | from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 4 | from torch import nn 5 | from torch.nn import KLDivLoss 6 | from modules.gnn.TransformerConvolution import TransC 7 | 8 | 9 | class GraphTransformer(nn.Module): 10 | """Implementation of the graph model using Transformer head. 11 | Params: 12 | caption_generator_state_dict: model's saved state. 13 | caption_generator_cache_dir: cache directory location. 14 | args: args variable from main script. 
15 | """ 16 | def __init__(self, caption_generator_state_dict=None, caption_generator_cache_dir=None, args=None): 17 | super().__init__() 18 | self.transc = TransC(node_feat_dim=args.node_feat_dim, d_model=args.d_model, edge_dim=args.edge_dim, 19 | project_edge_dim=args.project_edge_dim, more_skip=args.no_skip==False, last_average=args.last_average, beta=args.no_beta_transformer==False) 20 | self.caption_generator_model = CaptionGenerator.from_pretrained(args.bert_model, args.visual_model, args.decoder_model, 21 | cache_dir=caption_generator_cache_dir, state_dict=caption_generator_state_dict, task_config=args) 22 | self.avgPool = nn.AvgPool2d((args.num_object,1)) # number of patches or objects 23 | self.kl_loss_fct = KLDivLoss(reduction='batchmean') 24 | self.lp = nn.Linear(1024, 512) 25 | 26 | def forward(self, geo_graph, video_mask=None, 27 | input_caption_ids=None, decoder_mask=None, batch_size=None, n_node=None): 28 | fo_convolved = self.transc(geo_graph) 29 | 30 | fo_convolved = fo_convolved.unflatten(0, (batch_size,n_node)) 31 | fo_convolved = self.avgPool(fo_convolved) 32 | 33 | decoder_scores = self.caption_generator_model(fo_convolved, video_mask, input_caption_ids=input_caption_ids, decoder_mask=decoder_mask) 34 | return decoder_scores 35 | 36 | def get_visual_output(self, geo_graph, video_mask=None, batch_size=None, n_node=None, action=None): 37 | fo_convolved = self.transc(geo_graph) 38 | 39 | fo_convolved = fo_convolved.unflatten(0, (batch_size, n_node)) 40 | fo_convolved = self.avgPool(fo_convolved) 41 | 42 | visual_output = self.caption_generator_model.get_visual_output(fo_convolved, video_mask) 43 | 44 | return visual_output 45 | -------------------------------------------------------------------------------- /v2tactiongraph/modules/visual-base/visual_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 1, 11 | "vocab_size": 512 12 | } 13 | -------------------------------------------------------------------------------- /v2tactiongraph/scripts/msrvtt_train_GNN.sh: -------------------------------------------------------------------------------- 1 | # Setup 2 | DATATYPE=msrvtt 3 | N_GPU=1 4 | N_THREAD=8 5 | 6 | # PATH to files 7 | DATA_PATH=../dataset/MSRVTT/MSRVTT_data.json 8 | CKPT_ROOT=../ckpts 9 | INIT_MODEL_PATH=../weight/univl.pretrained.bin 10 | FEATURES_PATH= # Change into the path to the features you extracted from CLIP4Clip 11 | DATA_GEOMETRIC_PATH= # Change into the path to the graph-based features (can be grid or object-based features) 12 | NODE_FEATURES=geometric # please only use geometric for now 13 | # Params 14 | LEARNING_RATE=(3e-4) 15 | 16 | for lr in "${LEARNING_RATE[@]}" 17 | do 18 | python -m torch.distributed.launch --nproc_per_node=${N_GPU} \ 19 | ../main_task_caption_GNN.py --do_train --num_thread_reader=${N_THREAD} \ 20 | --epochs=50 --batch_size=1024 --n_display=50 --gradient_accumulation_steps 1 \ 21 | --data_path ${DATA_PATH} --features_path ${FEATURES_PATH} \ 22 | --output_dir ${CKPT_ROOT}/${DATATYPE}_lr${lr}_gnn \ 23 | --bert_model bert-base-uncased --do_lower_case \ 24 | --lr ${lr} --max_words 48 --max_frames 20 --batch_size_val 128 \ 25 | --visual_num_hidden_layers 2 --decoder_num_hidden_layers 2 \ 26 | --datatype 
${DATATYPE} --init_model ${INIT_MODEL_PATH} \ 27 | --data_geometric_path ${DATA_GEOMETRIC_PATH} \ 28 | --node_features ${NODE_FEATURES} --node_feat_dim 512 --d_model 512 --video_dim 512 --edge_dim 1024 \ 29 | --tradeoff_theta_2 4 --tradeoff_distill 1 --gnn_model_type transformer 30 | done 31 | -------------------------------------------------------------------------------- /v2tactiongraph/scripts/msvd_train_GNN.sh: -------------------------------------------------------------------------------- 1 | # Setup 2 | DATATYPE=msvd 3 | N_GPU=1 4 | N_THREAD=8 5 | 6 | # PATH to files 7 | DATA_PATH=../dataset/MSVD 8 | CKPT_ROOT=../ckpts 9 | INIT_MODEL_PATH=../weight/univl.pretrained.bin 10 | FEATURES_PATH= # Change into the path to the features you extracted from CLIP4Clip 11 | DATA_GEOMETRIC_PATH= # Change into the path to the graph-based features (can be grid or object-based features) 12 | NODE_FEATURES=geometric # please only use geometric for now 13 | # Params 14 | LEARNING_RATE=(1e-4) 15 | 16 | for lr in "${LEARNING_RATE[@]}" 17 | do 18 | python -m torch.distributed.launch --nproc_per_node=${N_GPU} \ 19 | ../main_task_caption_GNN.py --do_train --num_thread_reader=${N_THREAD} \ 20 | --epochs=50 --batch_size=128 --n_display=50 --gradient_accumulation_steps 2 \ 21 | --data_path ${DATA_PATH} --features_path ${FEATURES_PATH} \ 22 | --output_dir ${CKPT_ROOT}/${DATATYPE}_lr${lr}_gnn \ 23 | --bert_model bert-base-uncased --do_lower_case \ 24 | --lr ${lr} --max_words 48 --max_frames 20 --batch_size_val 16 \ 25 | --visual_num_hidden_layers 2 --decoder_num_hidden_layers 2 \ 26 | --datatype ${DATATYPE} --init_model ${INIT_MODEL_PATH} \ 27 | --data_geometric_path ${DATA_GEOMETRIC_PATH} \ 28 | --node_features ${NODE_FEATURES} --node_feat_dim 512 --d_model 512 --video_dim 512 --edge_dim 1024 \ 29 | --tradeoff_theta_2 4 --tradeoff_distill 1 --gnn_model_type transformer \ 30 | --custom_input_dim 1812 31 | done 32 | --------------------------------------------------------------------------------
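
A minimal usage sketch of the text-guided frame selection implemented in v2tactiongraph/alvs_inter.py above. The tensor shapes and the hyper-parameters (k, m, a1, a2) are illustrative placeholders rather than values prescribed by the repository, and the import assumes the v2tactiongraph directory is the working directory or on PYTHONPATH:

import torch
from alvs_inter import iterative_sampling

t, d = 64, 512              # e.g. 64 frame features of dimension 512 (illustrative)
f_v = torch.randn(t, d)     # per-frame visual features
f_text = torch.randn(d)     # pooled question/text feature (1-D vector)

# k frames are sampled uniformly per round, m refinement rounds,
# a1/a2 weight the visual and the text-attended similarity terms.
keep = iterative_sampling(f_v, f_text, k=8, m=2, a1=0.5, a2=0.5)
print(keep)                 # sorted, de-duplicated frame indices accumulated across the rounds

In the same spirit, a small sketch of how the GATC head from v2tactiongraph/modules/gnn/GATConvolution.py can be exercised on a toy graph. The feature sizes mirror the --node_feat_dim 512, --d_model 512 and --edge_dim 1024 settings used in the training scripts; the chain-shaped graph itself is made up for illustration, and the snippet assumes torch_geometric is installed and that it is run from the v2tactiongraph directory:

import torch
from torch_geometric.data import Data
from modules.gnn.GATConvolution import GATC

x = torch.randn(6, 512)                       # 6 nodes with 512-d features
edge_index = torch.tensor([[0, 1, 2, 3, 4],   # directed edges forming a chain
                           [1, 2, 3, 4, 5]])
edge_attr = torch.randn(edge_index.size(1), 1024)
graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# Defaults: 4 attention heads (d_model must be divisible by heads) and skip connections.
gnn = GATC(node_feat_dim=512, d_model=512, edge_dim=1024)
out = gnn(graph)
print(out.shape)                              # torch.Size([6, 512])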