├── howtocaption
├── __init__.py
├── model
│ ├── utils
│ │ ├── __init__.py
│ │ ├── vit_per_frame.py
│ │ ├── utils.py
│ │ └── vit.py
│ └── __init__.py
├── utils
│ ├── __init__.py
│ ├── meter.py
│ ├── retrieval_metrics_from_cap4video.py
│ ├── util.py
│ └── dist_utils.py
├── trainer
│ ├── __init__.py
│ ├── utils.py
│ ├── retrieval_eval.py
│ └── coco_eval.py
├── data_loader
│ ├── __init__.py
│ ├── video_datasets
│ │ ├── __init__.py
│ │ ├── webvid.py
│ │ ├── videocc.py
│ │ ├── utils.py
│ │ ├── youcook2.py
│ │ ├── msrvtt.py
│ │ ├── lsmdc.py
│ │ ├── msvd.py
│ │ └── howto.py
│ ├── video_dataloader.py
│ ├── transforms.py
│ └── utils.py
├── base
│ ├── __init__.py
│ ├── base_model.py
│ ├── base_data_loader.py
│ └── base_trainer.py
├── lr_scheduler
│ ├── __init__.py
│ └── warmup.py
├── llm_prompting
│ ├── scripts
│ │ ├── 1_filter_available.py
│ │ ├── 3_collect_predictions.py
│ │ └── 2_create_word_blocks.py
│ └── prompt_vicuna.py
├── save_frame_embeddings.py
├── eval.py
├── parse_config.py
├── save_text_embeddings.py
├── train.py
└── align_and_filter.py
├── method.jpeg
├── .gitignore
├── setup.py
├── requirements.txt
├── configs
├── med_config.json
├── vicuna
│ └── final_prompt.yaml
├── align_and_filter
│ ├── blip.yaml
│ ├── blip_ft_1round.yaml
│ └── finetune_1round.yaml
└── VL_training
│ ├── captioning_youcook2.yaml
│ ├── captioning_msvd.yaml
│ ├── captioning_msrvtt.yaml
│ ├── baselines
│ ├── dual_encoder_retrieval_webvid.yaml
│ └── dual_encoder_retrieval_HowTo100M.yaml
│ ├── dual_encoder_retrieval.yaml
│ ├── full_encoder_decoder_ViT_L.yaml
│ └── full_encoder_decoder.yaml
├── dataset
└── readme.md
└── README.md

/howtocaption/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/howtocaption/model/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/howtocaption/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import *
--------------------------------------------------------------------------------
/howtocaption/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .vl_trainer import VL_Trainer
2 |
--------------------------------------------------------------------------------
/method.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ninatu/howtocaption/HEAD/method.jpeg
--------------------------------------------------------------------------------
/howtocaption/data_loader/__init__.py:
--------------------------------------------------------------------------------
1 | from .video_dataloader import VideoDataLoader
--------------------------------------------------------------------------------
/howtocaption/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .blip_tv_decoder import BlipVTDecoderModel
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .DS_Store
3 | __pycache__
4 | .ipynb_checkpoints
5 | .neptune
6 | data
7 | output
8 |
9 |
--------------------------------------------------------------------------------
/howtocaption/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_data_loader import *
2 | from .base_model import *
3 | from .base_trainer import *
4 |
--------------------------------------------------------------------------------
/howtocaption/lr_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import *
2 | from .warmup import CustomCosineSchedulerWithWarmup, SchedulerWithWarmup, SchedulerWithWarmupAndDecay
3 |
--------------------------------------------------------------------------------
/howtocaption/trainer/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def _move_to_device(data, device):
4 |     if torch.is_tensor(data):
5 |         return data.to(device, non_blocking=True)
6 |     else:
7 |         return data
--------------------------------------------------------------------------------
/howtocaption/data_loader/video_datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .howto import HowTo100M
2 | from .lsmdc import LSMDC
3 | from .msvd import MSVD
4 | from .msrvtt import MSRVTT
5 | from .youcook2 import YouCook2
6 | from .webvid import WebVid2M
7 | from .videocc import VideoCC3M
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 |
4 | setup(name='howtocaption',
5 |       version='0.1',
6 |       description='howtocaption',
7 |       url='',
8 |       author='',
9 |       author_email='',
10 |       license='',
11 |       packages=find_packages(),
12 |       dependency_links=[],
13 |       zip_safe=False)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.22.3
2 | tqdm==4.64.0
3 | pyyaml==5.4.1
4 | tensorboardx==2.5.1
5 | neptune-client==0.16.7
6 | neptune-sacred==0.10.0
7 | sacred==0.8.2
8 | webdataset==0.1.103
9 | scikit-image==0.19.3
10 | timm==0.4.12
11 | fairscale==0.4.4
12 | einops==0.3.0
13 | ffmpeg-python==0.2.0
14 | matplotlib==3.6.0
15 | simplejson==3.18
16 | pycocoevalcap
17 | pycocotools
18 |
19 |
20 | accelerate==0.18.0
21 | sentencepiece==0.1.97
22 | chardet==5.1.0
23 | git+https://github.com/huggingface/transformers@v4.29.0
24 | fschat==0.1.10
--------------------------------------------------------------------------------
/configs/med_config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "architectures": [
3 |     "BertModel"
4 |   ],
5 |   "attention_probs_dropout_prob": 0.1,
6 |   "hidden_act": "gelu",
7 |   "hidden_dropout_prob": 0.1,
8 |   "hidden_size": 768,
9 |   "initializer_range": 0.02,
10 |   "intermediate_size": 3072,
11 |   "layer_norm_eps": 1e-12,
12 |   "max_position_embeddings": 512,
13 |   "model_type": "bert",
14 |   "num_attention_heads": 12,
15 |   "num_hidden_layers": 12,
16 |   "pad_token_id": 0,
17 |   "type_vocab_size": 2,
18 |   "vocab_size": 30524,
19 |   "encoder_width": 768,
20 |   "add_cross_attention": true
21 | }
22 |
--------------------------------------------------------------------------------
/configs/vicuna/final_prompt.yaml:
--------------------------------------------------------------------------------
1 | device: 'cuda'
2 |
3 | word_blocks: '200w'
4 |
5 | prompt: "I will give you an automatically recognized speech with
timestamps from a video segment that is cut from a long video. Write a summary for this video segment. Write only short sentences. Describe only one action per sentence. Keep only actions that happen in the present time. Begin each sentence with an estimated timestamp. Here is this automatically recognized speech: \"{}\"." 6 | examples: [] 7 | 8 | model_generate_args: 9 | temperature: 0.7 10 | do_sample: true 11 | max_length: 2048 12 | eos_token_id: 2277 13 | 14 | batch_processing: true 15 | batch_size: 6 16 | 17 | save_dir: 'output/vicuna' -------------------------------------------------------------------------------- /howtocaption/base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | @abstractmethod 11 | def forward(self, *inputs): 12 | """ 13 | Forward pass logic 14 | 15 | :return: Model output 16 | """ 17 | raise NotImplementedError 18 | 19 | def __str__(self): 20 | """ 21 | Model prints with number of trainable parameters 22 | """ 23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 24 | params = sum([np.prod(p.size()) for p in model_parameters]) 25 | return 'Trainable parameters: {}\n'.format(params) + super().__str__() + '\nTrainable parameters: {}'.format(params) 26 | -------------------------------------------------------------------------------- /howtocaption/model/utils/vit_per_frame.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from howtocaption.model.utils.vit import VisionTransformer 3 | 4 | 5 | class VisionTransformerPerFrame(VisionTransformer): 6 | def forward(self, x, register_blk=-1): 7 | b, f, _, _, _ = x.shape 8 | x = x.view(b * f, *x.size()[2:]) 9 | B = x.shape[0] 10 | x = self.patch_embed(x) 11 | 12 | cls_tokens = self.cls_token.expand(B, -1, -1) 13 | x = torch.cat((cls_tokens, x), dim=1) 14 | 15 | x = x + self.pos_embed[:, :x.size(1), :] 16 | x = self.pos_drop(x) 17 | 18 | for i, blk in enumerate(self.blocks): 19 | x = blk(x, register_blk == i) 20 | x = self.norm(x) 21 | 22 | patches_per_frame = self.patch_embed.num_patches + 1 23 | x = x.view(b, f * patches_per_frame, *x.size()[2:]) 24 | 25 | return x 26 | -------------------------------------------------------------------------------- /configs/align_and_filter/blip.yaml: -------------------------------------------------------------------------------- 1 | arch: 2 | type: BlipVTDecoderModel 3 | args: 4 | vit: 'base' 5 | init_from_pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 6 | med_config: configs/med_config.json 7 | tie_encoder_decoder_weights: true 8 | train_contrastive: true 9 | train_captioning: false 10 | train_max_text_length: 32 11 | 12 | load_weights: null 13 | 14 | data_loader: 15 | type: VideoDataLoader 16 | args: 17 | dataset_type: HowTo100M 18 | dataset_args: 19 | csv: data/howto100m/video_path_filtered.csv 20 | video_root: data/howto100m/videos 21 | caption_path: data/howto100m/asr_filtered.pickle 22 | return_all_frames_1fps: true 23 | num_workers: 8 24 | batch_size: 1 25 | transform: 'test_resize' 26 | 27 | save_dir: output/embeddings 28 | 29 | -------------------------------------------------------------------------------- /configs/align_and_filter/blip_ft_1round.yaml: 
-------------------------------------------------------------------------------- 1 | arch: 2 | type: BlipVTDecoderModel 3 | args: 4 | vit: 'base' 5 | init_from_pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 6 | med_config: configs/med_config.json 7 | tie_encoder_decoder_weights: true 8 | train_contrastive: true 9 | train_captioning: false 10 | alpha: 0.6 11 | train_max_text_length: 32 12 | queue_size: 2048 13 | continual_learning_weight: 0.1 14 | 15 | load_weights: pretrained/finetune_1round.pth # FIXME 16 | 17 | data_loader: 18 | type: VideoDataLoader 19 | args: 20 | dataset_type: HowTo100M 21 | dataset_args: 22 | csv: data/howto100m/video_path_filtered.csv 23 | video_root: data/howto100m/videos 24 | caption_path: data/howto100m/asr_filtered.pickle 25 | return_all_frames_1fps: true 26 | num_workers: 8 27 | batch_size: 1 28 | transform: 'test_resize' 29 | 30 | save_dir: output/embeddings 31 | 32 | -------------------------------------------------------------------------------- /howtocaption/base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from howtocaption.utils.dist_utils import is_dist_avail_and_initialized, get_world_size, get_rank 4 | 5 | 6 | class BaseDataLoader(DataLoader): 7 | """ 8 | Base class for all data loaders 9 | """ 10 | def __init__(self, dataset, batch_size, shuffle, num_workers, drop_last, collate_fn=None): 11 | if is_dist_avail_and_initialized(): 12 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, 13 | num_replicas=get_world_size(), 14 | rank=get_rank(), 15 | shuffle=shuffle, 16 | drop_last=drop_last) 17 | shuffle = False 18 | else: 19 | sampler = None 20 | 21 | super().__init__(dataset, batch_size, shuffle, sampler=sampler, num_workers=num_workers, collate_fn=collate_fn, 22 | pin_memory=True, drop_last=drop_last) 23 | -------------------------------------------------------------------------------- /howtocaption/utils/meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | def __init__(self, name, fmt=':f'): 4 | self.name = name 5 | self.fmt = fmt 6 | self.reset() 7 | 8 | def reset(self): 9 | self.val = 0 10 | self.avg = 0 11 | self.sum = 0 12 | self.count = 0 13 | 14 | def update(self, val, n=1): 15 | self.val = val 16 | self.sum += val * n 17 | self.count += n 18 | self.avg = self.sum / self.count 19 | 20 | def __str__(self): 21 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 22 | return fmtstr.format(**self.__dict__) 23 | 24 | 25 | class ProgressMeter(object): 26 | def __init__(self, num_batches, meters, prefix=""): 27 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 28 | self.meters = meters 29 | self.prefix = prefix 30 | 31 | def display(self, batch): 32 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 33 | entries += [str(meter) for meter in self.meters] 34 | return '\t'.join(entries) 35 | 36 | def _get_batch_fmtstr(self, num_batches): 37 | num_digits = len(str(num_batches // 1)) 38 | fmt = '{:' + str(num_digits) + 'd}' 39 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' -------------------------------------------------------------------------------- /howtocaption/llm_prompting/scripts/1_filter_available.py: 
-------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | import pandas as pd 4 | import os 5 | import argparse 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--asr", type=str, default='data/howto100m/sentencified_htm_1200k.json') 11 | parser.add_argument("--csv", type=str, default='data/howto100m/video_path_downloaded.csv') 12 | parser.add_argument("--output_folder", type=str, default='data/howto100m') 13 | 14 | args = parser.parse_args() 15 | csv = pd.read_csv(args.csv) 16 | 17 | with open(args.asr, 'r') as fin: 18 | asrs = json.load(fin) 19 | 20 | video_ids = set(csv['video_id']).intersection(asrs.keys()) 21 | 22 | csv_available = csv[csv['video_id'].isin(video_ids)] 23 | asrs = {key: asrs[key] for key in video_ids} 24 | 25 | csv_debug = csv_available.iloc[:50] 26 | asrs_debug = {key: asrs[key] for key in csv_debug['video_id']} 27 | 28 | csv_available.to_csv(os.path.join(args.output_folder, 'video_path_filtered.csv')) 29 | csv_debug.to_csv(os.path.join(args.output_folder, 'video_path_filtered_s50.csv')) 30 | 31 | with open(os.path.join(args.output_folder, 'asr_filtered.pickle'), 'wb') as fout: 32 | pickle.dump(asrs, fout) 33 | 34 | with open(os.path.join(args.output_folder, 'asr_filtered_s50.pickle'), 'wb') as fout: 35 | pickle.dump(asrs_debug, fout) 36 | -------------------------------------------------------------------------------- /howtocaption/trainer/retrieval_eval.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | import scipy.stats 5 | 6 | 7 | def retrieval_metrics(sims, break_ties='averaging', complete_dataset_size=None): 8 | num_queries, num_vids = sims.shape 9 | if complete_dataset_size is not None: 10 | num_queries = complete_dataset_size 11 | 12 | sx = np.sort(-sims, axis=1) 13 | d = np.diag(-sims) 14 | d = d[:, np.newaxis] 15 | diff = sx - d 16 | if break_ties == 'optimistically': 17 | ind = np.argmax(diff == 0, axis=1) 18 | elif break_ties == 'averaging': 19 | locs = np.argwhere(diff == 0) 20 | grouped_locs = [list(values) for n_row, values in itertools.groupby(locs, key=lambda x: x[0])] 21 | ind = [np.mean(list(map(lambda x: x[1], locs))) for locs in grouped_locs] 22 | ind = np.array(ind) 23 | else: 24 | raise NotImplementedError 25 | return cols2metrics(ind, num_queries) 26 | 27 | 28 | def cols2metrics(cols, num_queries): 29 | metrics = {} 30 | metrics["R1"] = 100 * float(np.sum(cols == 0)) / num_queries 31 | metrics["R5"] = 100 * float(np.sum(cols < 5)) / num_queries 32 | metrics["R10"] = 100 * float(np.sum(cols < 10)) / num_queries 33 | metrics["R50"] = 100 * float(np.sum(cols < 50)) / num_queries 34 | metrics["MedR"] = np.median(cols) + 1 35 | metrics["MeanR"] = np.mean(cols) + 1 36 | stats = [metrics[x] for x in ("R1", "R5", "R10")] 37 | metrics["geometric_mean_R1-R5-R10"] = scipy.stats.mstats.gmean(stats) 38 | return metrics -------------------------------------------------------------------------------- /howtocaption/data_loader/video_dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from howtocaption.data_loader import video_datasets 7 | from howtocaption.data_loader.transforms import init_transform_dict 8 | from howtocaption.utils.dist_utils import get_rank, get_world_size 9 | 10 | 11 | class 
VideoDataLoader(DataLoader): 12 | def __init__(self, dataset_type, dataset_args, batch_size, num_workers, 13 | split='train', transform=None, shuffle=None, transform_params={}, 14 | prefetch_factor=2, 15 | pin_memory=True, 16 | **kwargs): 17 | if shuffle is None: 18 | shuffle = (split == 'train') 19 | 20 | if transform is None: 21 | transform = split 22 | transforms = init_transform_dict(**transform_params)[transform] 23 | dataset = getattr(video_datasets, dataset_type)(transforms=transforms, split=split, **dataset_args) 24 | 25 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=get_world_size(), rank=get_rank(), 26 | shuffle=shuffle) 27 | 28 | super(VideoDataLoader, self).__init__(dataset, 29 | batch_size=batch_size, 30 | num_workers=num_workers, 31 | pin_memory=pin_memory, 32 | sampler=sampler, 33 | shuffle=False, 34 | collate_fn=None, 35 | drop_last=(split == 'train'), 36 | prefetch_factor=prefetch_factor) 37 | -------------------------------------------------------------------------------- /howtocaption/llm_prompting/scripts/3_collect_predictions.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import pickle 3 | import tqdm 4 | import argparse 5 | import yaml 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--config", type=str) 10 | parser.add_argument("--asr-path", type=str) 11 | parser.add_argument("--output-path", type=str, default=None) 12 | 13 | args = parser.parse_args() 14 | 15 | with open(args.config) as fin: 16 | config = yaml.safe_load(fin) 17 | 18 | word_blocks = config['word_blocks'] 19 | save_dir = config['save_dir'] 20 | 21 | exp_name = os.path.splitext(os.path.basename(args.config))[0] 22 | result_dir = os.path.join(save_dir, exp_name) 23 | 24 | with open(args.asr_path, 'rb') as fin: 25 | data = pickle.load(fin) 26 | 27 | output = {} 28 | for id_, (key, val) in enumerate(tqdm.tqdm(data.items())): 29 | path = f'{result_dir}/{key[0]}/{key[1]}/{key}.pickle' 30 | with open(path, 'rb') as fin: 31 | pred = pickle.load(fin) 32 | if len(pred) > 0 and isinstance(pred[0], list): 33 | pred = [x[0] for x in pred] 34 | output[key] = { 35 | 'start': val[word_blocks]['start'], 36 | 'end': val[word_blocks]['end'], 37 | 'prediction': pred 38 | } 39 | 40 | output_path = args.output_path 41 | if output_path is None: 42 | output_path = os.path.join(save_dir, f'{exp_name}.pickle') 43 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 44 | print(output) 45 | with open(output_path, 'wb') as fout: 46 | pickle.dump(output, fout) 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /howtocaption/data_loader/video_datasets/webvid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from torch.utils.data import Dataset 4 | 5 | from howtocaption.data_loader.video_datasets.utils import get_video_frames 6 | 7 | 8 | class WebVid2M(Dataset): 9 | def __init__(self, data_root, split, num_frames=4, transforms=None, 10 | dataset_name='WebVid2M', 11 | sample_beginning=False, 12 | central_frames=False, 13 | ): 14 | super(WebVid2M, self).__init__() 15 | 16 | self.data_root = data_root 17 | self._load_metadata(split) 18 | self.split = split 19 | 20 | self.transforms = transforms 21 | self.num_frames = num_frames 22 | 23 | self.dataset_name = dataset_name 24 | self.sample_beginning = sample_beginning 25 | self.central_frames = central_frames 26 | 27 | def 
_load_metadata(self, split): 28 | assert split in ["train"] 29 | self.csv = pd.read_csv(os.path.join(self.data_root, f'results_2M_{split}_downloaded.csv')) 30 | 31 | def __len__(self): 32 | return len(self.csv) 33 | 34 | def __getitem__(self, idx): 35 | data = self.csv.iloc[idx] 36 | 37 | video_id = data["videoid"] 38 | rel_fp = f'videos/{data["page_dir"]}/{data["videoid"]}.mp4' 39 | video_fp = os.path.join(self.data_root, rel_fp) 40 | caption = data['name'] 41 | 42 | video = get_video_frames(video_fp, start=0, end=None, num_frames=self.num_frames, 43 | sample_beginning=self.sample_beginning, central_frames=self.central_frames) 44 | 45 | if self.transforms is not None: 46 | video = self.transforms(video) 47 | 48 | output = {'video': video, 'text': caption, 'time': 0, 49 | 'dataset': self.dataset_name} 50 | output['start_time'] = 0 51 | output['end_time'] = 0 52 | 53 | return output 54 | -------------------------------------------------------------------------------- /howtocaption/llm_prompting/scripts/2_create_word_blocks.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import tqdm 3 | import argparse 4 | 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--asr", type=str, default='data/howto100m/asr_filtered.pickle') 9 | parser.add_argument("--n_words_max", type=int, default=200) 10 | parser.add_argument("--output_key", type=str, default='200w') 11 | 12 | args = parser.parse_args() 13 | n_words_max = args.n_words_max 14 | 15 | with open(args.asr, 'rb') as fin: 16 | data = pickle.load(fin) 17 | 18 | output = {} 19 | for name, x in tqdm.tqdm(data.items()): 20 | sentences = [sent for sent in x['text']] 21 | 22 | blocks = [] 23 | starts = [] 24 | ends = [] 25 | all_words = 0 26 | 27 | cur_block = '' 28 | cur_words = 0 29 | cur_start = 1000000000 30 | cur_end = None 31 | 32 | for start, end, sent in zip(x['start'], x['end'], sentences): 33 | if len(sent.split(' ')) + cur_words <= n_words_max: 34 | cur_block += f'\n{int(start)}s: ' + sent 35 | 36 | cur_words += len(sent.split(' ')) 37 | cur_start = min(cur_start, start) 38 | cur_end = end 39 | else: 40 | if cur_block != '': 41 | blocks.append(cur_block) 42 | starts.append(cur_start) 43 | ends.append(cur_end) 44 | 45 | cur_block = sent 46 | cur_words = len(sent.split(' ')) 47 | cur_start = start 48 | cur_end = end 49 | 50 | if cur_block != '': 51 | blocks.append(cur_block) 52 | starts.append(cur_start) 53 | ends.append(cur_end) 54 | 55 | x[args.output_key] = { 56 | 'text': blocks, 57 | 'start': starts, 58 | 'end': ends, 59 | } 60 | 61 | with open(args.asr, 'wb') as fout: 62 | pickle.dump(data, fout) 63 | -------------------------------------------------------------------------------- /configs/VL_training/captioning_youcook2.yaml: -------------------------------------------------------------------------------- 1 | arch: 2 | type: BlipVTDecoderModel 3 | args: 4 | vit: 'base' 5 | init_from_pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 6 | med_config: configs/med_config.json 7 | tie_encoder_decoder_weights: false 8 | 9 | train_contrastive: false 10 | train_captioning: true 11 | train_itm: false 12 | 13 | train_max_text_length: 32 14 | queue_size: 2048 15 | 16 | train_data_loader: 17 | - type: VideoDataLoader 18 | args: 19 | dataset_type: YouCook2 20 | dataset_args: 21 | data_root: data/youcook 22 | num_frames: 16 23 | max_text_length: 32 24 | num_workers: 16 25 | batch_size: 32 
26 | split: train 27 | transform: 'train' 28 | 29 | valid_data_loader: 30 | - type: VideoDataLoader 31 | args: 32 | dataset_type: YouCook2 33 | dataset_args: 34 | data_root: data/youcook 35 | num_frames: 16 36 | max_text_length: 32 37 | num_workers: 16 38 | batch_size: 32 39 | split: val 40 | transform: 'test' 41 | 42 | 43 | optimizer: 44 | type: AdamW 45 | args: 46 | lr: 1.0e-05 47 | weight_decay: 0.05 48 | 49 | lr_scheduler: 50 | type: CosineAnnealingLR 51 | args: 52 | T_max: 10 53 | eta_min: 0 54 | 55 | save_dir: output 56 | 57 | trainer: 58 | type: VL_Trainer 59 | args: 60 | resume_only_model: True 61 | load_strict: False 62 | 63 | lr_scheduler_update: 'iter' 64 | init_retrieval: false 65 | init_nlp: false 66 | 67 | epochs: 10 68 | save_latest: True 69 | save_period: 1000000 70 | monitor: 'off' 71 | mixed_precision: true 72 | 73 | log_visual_input_at_start: True 74 | freq_visual_input: 100000 75 | nlp_freq_eval: 1 76 | freq_eval: 100000 77 | retrieval_freq_eval: 100000 78 | 79 | eval_args: 80 | num_beams: 1 81 | min_length: 0 82 | max_length: 20 83 | top_p: 1.0 84 | repetition_penalty: 1.0 85 | 86 | clip_grad: 20 87 | -------------------------------------------------------------------------------- /configs/VL_training/captioning_msvd.yaml: -------------------------------------------------------------------------------- 1 | arch: 2 | type: BlipVTDecoderModel 3 | args: 4 | vit: 'base' 5 | init_from_pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 6 | med_config: configs/med_config.json 7 | tie_encoder_decoder_weights: false 8 | train_contrastive: false 9 | train_captioning: true 10 | train_itm: false 11 | 12 | train_max_text_length: 32 13 | queue_size: 2048 14 | 15 | train_data_loader: 16 | - type: VideoDataLoader 17 | args: 18 | dataset_type: MSVD 19 | dataset_args: 20 | data_root: data/msvd 21 | num_frames: 16 22 | max_text_length: 32 23 | sample_per_video: true 24 | num_workers: 16 25 | batch_size: 32 26 | split: train 27 | transform: 'train' 28 | 29 | valid_data_loader: 30 | - type: VideoDataLoader 31 | args: 32 | dataset_type: MSVD 33 | dataset_args: 34 | data_root: data/msvd 35 | num_frames: 16 36 | max_text_length: 32 37 | multi_sentence_per_video: true 38 | num_workers: 16 39 | batch_size: 32 40 | split: test 41 | transform: 'test' 42 | 43 | 44 | optimizer: 45 | type: AdamW 46 | args: 47 | lr: 1.0e-05 48 | weight_decay: 0.05 49 | 50 | lr_scheduler: 51 | type: CosineAnnealingLR 52 | args: 53 | T_max: 40 54 | eta_min: 0 55 | 56 | save_dir: output 57 | 58 | trainer: 59 | type: VL_Trainer 60 | args: 61 | resume_only_model: True 62 | load_strict: False 63 | 64 | lr_scheduler_update: 'iter' 65 | init_retrieval: false 66 | init_nlp: false 67 | 68 | epochs: 40 69 | save_latest: true 70 | save_period: 1000000 71 | monitor: 'off' 72 | mixed_precision: true 73 | 74 | log_visual_input_at_start: True 75 | freq_visual_input: 100000 76 | nlp_freq_eval: 1 77 | freq_eval: 100000 78 | retrieval_freq_eval: 100000 79 | 80 | eval_args: 81 | num_beams: 1 82 | min_length: 0 83 | max_length: 20 84 | top_p: 1.0 85 | repetition_penalty: 1.0 86 | 87 | clip_grad: 20 88 | -------------------------------------------------------------------------------- /configs/VL_training/captioning_msrvtt.yaml: -------------------------------------------------------------------------------- 1 | arch: 2 | type: BlipVTDecoderModel 3 | args: 4 | vit: 'base' 5 | init_from_pretrained: 
'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 6 | med_config: configs/med_config.json 7 | tie_encoder_decoder_weights: false 8 | train_contrastive: false 9 | train_captioning: true 10 | train_itm: false 11 | 12 | train_max_text_length: 32 13 | queue_size: 2048 14 | 15 | train_data_loader: 16 | - type: VideoDataLoader 17 | args: 18 | dataset_type: MSRVTT 19 | dataset_args: 20 | dataset_name: MSRVTT_Cap 21 | data_root: data/msrvtt 22 | num_frames: 16 23 | max_text_length: 32 24 | cut: 'full-val' 25 | num_workers: 16 26 | batch_size: 32 27 | split: train 28 | transform: 'train' 29 | 30 | valid_data_loader: 31 | - type: VideoDataLoader 32 | args: 33 | dataset_type: MSRVTT 34 | dataset_args: 35 | dataset_name: MSRVTT_Cap 36 | data_root: data/msrvtt 37 | num_frames: 16 38 | max_text_length: 32 39 | cut: 'full-val' 40 | num_workers: 16 41 | batch_size: 32 42 | split: test 43 | transform: 'test' 44 | 45 | 46 | optimizer: 47 | type: AdamW 48 | args: 49 | lr: 1.0e-05 50 | weight_decay: 0.05 51 | 52 | lr_scheduler: 53 | type: CosineAnnealingLR 54 | args: 55 | T_max: 30 56 | eta_min: 0 57 | 58 | save_dir: output 59 | 60 | trainer: 61 | type: VL_Trainer 62 | args: 63 | resume_only_model: True 64 | load_strict: False 65 | 66 | lr_scheduler_update: 'iter' 67 | init_retrieval: false 68 | init_nlp: false 69 | 70 | epochs: 30 71 | save_latest: true 72 | save_period: 1000000 73 | monitor: 'off' 74 | mixed_precision: true 75 | 76 | log_visual_input_at_start: True 77 | freq_visual_input: 100000 78 | nlp_freq_eval: 1 79 | freq_eval: 100000 80 | retrieval_freq_eval: 100000 81 | 82 | eval_args: 83 | num_beams: 1 84 | min_length: 0 85 | max_length: 20 86 | top_p: 1.0 87 | repetition_penalty: 1.0 88 | 89 | clip_grad: 20 90 | -------------------------------------------------------------------------------- /howtocaption/data_loader/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torchvision.transforms.functional import InterpolationMode 3 | from howtocaption.data_loader.utils import MyRandomResizedCrop 4 | 5 | 6 | def init_transform_dict(input_res=224, 7 | center_crop=256, 8 | randcrop_scale=(0.5, 1.0), 9 | color_jitter=(0, 0, 0), 10 | grayscale_p=0, 11 | norm_mean=(0.48145466, 0.4578275, 0.40821073), 12 | norm_std=(0.26862954, 0.26130258, 0.27577711), 13 | antialias=True, 14 | use_old_wrong_version=False): 15 | if use_old_wrong_version: 16 | norm_mean = (0.485, 0.456, 0.406) 17 | norm_std = (0.229, 0.224, 0.225) 18 | antialias = False 19 | 20 | normalize = transforms.Normalize(mean=norm_mean, std=norm_std) 21 | tsfm_dict = { 22 | 'train': transforms.Compose([ 23 | MyRandomResizedCrop(input_res, scale=randcrop_scale, interpolation=InterpolationMode.BICUBIC, antialias=antialias), 24 | transforms.RandomHorizontalFlip(), 25 | transforms.ColorJitter(brightness=color_jitter[0], saturation=color_jitter[1], hue=color_jitter[2]), 26 | transforms.RandomGrayscale(p=grayscale_p), 27 | normalize, 28 | ]), 29 | 'test': transforms.Compose([ 30 | transforms.Resize(input_res, interpolation=InterpolationMode.BICUBIC, antialias=antialias), 31 | transforms.CenterCrop(input_res), 32 | normalize, 33 | ]), 34 | 'test_resize': transforms.Compose([ # TODO: this might be just test 35 | transforms.Resize(center_crop, interpolation=InterpolationMode.BICUBIC, antialias=antialias), 36 | transforms.CenterCrop(center_crop), 37 | transforms.Resize(input_res, interpolation=InterpolationMode.BICUBIC, 
antialias=antialias), 38 | normalize, 39 | ]), 40 | 'visualization': transforms.Compose([ 41 | transforms.Resize(240, interpolation=InterpolationMode.BICUBIC, antialias=antialias), 42 | transforms.CenterCrop((240, 320)), 43 | # normalize, 44 | ]), 45 | 'none': None, 46 | } 47 | return tsfm_dict -------------------------------------------------------------------------------- /howtocaption/lr_scheduler/warmup.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class CustomCosineSchedulerWithWarmup: 5 | def __init__(self, optimizer, T_max, lr, warmup_epochs=0, eta_min=0): 6 | self.optimizer = optimizer 7 | self.warmup_epochs = warmup_epochs 8 | self.T_max = T_max 9 | self.eta_min = eta_min 10 | 11 | if isinstance(lr, list): 12 | self.base_lr = lr 13 | else: 14 | self.base_lr = [lr] * len(self.optimizer.param_groups) 15 | assert len(self.optimizer.param_groups) == len(self.base_lr) 16 | 17 | self.step(0) 18 | 19 | def step(self, epoch): 20 | for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lr): 21 | epoch = epoch % self.T_max 22 | if epoch < self.warmup_epochs: 23 | cur_lr = epoch / self.warmup_epochs * base_lr 24 | else: 25 | cur_lr = self.eta_min + (base_lr - self.eta_min) * 0.5 * ( 26 | 1. + math.cos(math.pi * (epoch - self.warmup_epochs) / (self.T_max - self.warmup_epochs))) 27 | 28 | param_group['lr'] = cur_lr 29 | 30 | 31 | class SchedulerWithWarmup: 32 | def __init__(self, optimizer, lr, warmup_epochs=0, eta_min=0): 33 | self.optimizer = optimizer 34 | self.warmup_epochs = warmup_epochs 35 | self.eta_min = eta_min 36 | 37 | self.base_lr = lr 38 | self.step(0) 39 | 40 | def step(self, epoch): 41 | if epoch < self.warmup_epochs: 42 | cur_lr = epoch / self.warmup_epochs * self.base_lr 43 | else: 44 | cur_lr = self.base_lr 45 | 46 | for param_group in self.optimizer.param_groups: 47 | param_group['lr'] = cur_lr 48 | 49 | 50 | class SchedulerWithWarmupAndDecay: 51 | def __init__(self, optimizer, lr, warmup_epochs=0, min_lr=0, decay_rate=1): 52 | self.optimizer = optimizer 53 | self.warmup_epochs = warmup_epochs 54 | self.min_lr = min_lr 55 | self.decay_rate = decay_rate 56 | 57 | self.base_lr = lr 58 | self.step(0) 59 | 60 | def step(self, epoch): 61 | if epoch < self.warmup_epochs: 62 | cur_lr = epoch / self.warmup_epochs * self.base_lr 63 | else: 64 | cur_lr = max(self.min_lr, self.base_lr * (self.decay_rate ** (epoch - self.warmup_epochs))) 65 | 66 | for param_group in self.optimizer.param_groups: 67 | param_group['lr'] = cur_lr -------------------------------------------------------------------------------- /howtocaption/data_loader/video_datasets/videocc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from torch.utils.data import Dataset 4 | import json 5 | 6 | from howtocaption.data_loader.video_datasets.utils import get_video_frames 7 | 8 | 9 | class VideoCC3M(Dataset): 10 | def __init__(self, data_root, split, num_frames=4, transforms=None, 11 | dataset_name='VideoCC3M', 12 | sample_beginning=False, 13 | central_frames=False, 14 | ): 15 | super(VideoCC3M, self).__init__() 16 | 17 | self.data_root = data_root 18 | self._load_metadata(split) 19 | self.split = split 20 | 21 | self.transforms = transforms 22 | self.num_frames = num_frames 23 | self.dataset_name = dataset_name 24 | 25 | self.sample_beginning = sample_beginning 26 | self.central_frames = central_frames 27 | 28 | def _load_metadata(self, split): 29 | assert split 
== "train" 30 | self.csv = pd.read_csv(os.path.join(self.data_root, f'video_cc_public_downloaded.csv')) 31 | 32 | def __len__(self): 33 | return len(self.csv) 34 | 35 | def __getitem__(self, idx): 36 | data = self.csv.iloc[idx] 37 | 38 | video_id = data["video_id"] 39 | rel_fp = f'dataset/{data["page_dir"]:05d}/{data["video_id"]}.mp4' 40 | video_fp = os.path.join(self.data_root, rel_fp) 41 | caption = data['caption'] 42 | 43 | try: 44 | json_path = f'dataset/{data["page_dir"]:05d}/{data["video_id"]}.json' 45 | with open(f'data/videocc3m/{json_path}') as fin: 46 | metadata = json.load(fin) 47 | width = metadata['video_metadata']['streams'][0]['width'] 48 | height = metadata['video_metadata']['streams'][0]['height'] 49 | clips = metadata['clips'] 50 | assert len(clips) == 1 51 | end = clips[0][1] - clips[0][0] 52 | except Exception as excep: 53 | print("Warning: video path: {} error. Error: {}".format(video_fp, excep), flush=True) 54 | width, height, end = None, None, None 55 | 56 | video = get_video_frames(video_fp, start=0, end=end, num_frames=self.num_frames, 57 | sample_beginning=self.sample_beginning, central_frames=self.central_frames, 58 | width=width, height=height) 59 | 60 | if self.transforms is not None: 61 | video = self.transforms(video) 62 | 63 | output = {'video': video, 'text': caption, 'time': 0, 'dataset': self.dataset_name, 'path': rel_fp, 'idx': video_id} 64 | 65 | return output 66 | -------------------------------------------------------------------------------- /howtocaption/data_loader/video_datasets/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | import ffmpeg 4 | 5 | 6 | def get_video_frames(video_path, start, end, num_frames, fps=None, width=None, height=None, 7 | sample_beginning=False, central_frames=False): 8 | try: 9 | if (width is None) or (height is None) or (end is None): 10 | probe = ffmpeg.probe(video_path) 11 | 12 | if end is None: 13 | end = float(probe['format']['duration']) 14 | if end == 0: 15 | end = num_frames 16 | except Exception as excep: 17 | print("Warning: ffmpeg error. video path: {} error. 
Error: {}".format(video_path, excep), flush=True) 18 | 19 | num_sec = end - start 20 | if fps is None: 21 | fps = num_frames / num_sec 22 | 23 | assert (sample_beginning == False) or (central_frames == False) 24 | if sample_beginning: 25 | start = start + np.random.random() * (num_sec / num_frames) 26 | if central_frames: 27 | start = start + (num_sec / num_frames) / 2 28 | 29 | cmd = ( 30 | ffmpeg 31 | .input(video_path, ss=start, t=num_sec + 0.1) 32 | .filter('fps', fps=fps) 33 | ) 34 | for i in range(1): 35 | try: 36 | if width is None: 37 | width = int(probe['streams'][0]['width']) 38 | 39 | if height is None: 40 | height = int(probe['streams'][0]['height']) 41 | 42 | out, _ = ( 43 | cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24') 44 | .run(capture_stdout=True, quiet=True) 45 | ) 46 | 47 | video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3]) 48 | video = th.tensor(video) 49 | video = video.permute(0, 3, 1, 2) 50 | if video.shape[0] < num_frames: 51 | # print(f'Warning: sampling less frames than necessary: {video.shape[0]}') 52 | zeros = th.zeros((num_frames - video.shape[0], 3, height, width), dtype=th.uint8) 53 | video = th.cat((video, zeros), axis=0) 54 | elif video.shape[0] > num_frames: 55 | # print(f'Warning: sampling more frames than necessary: {video.shape[0]}') 56 | video = video[:num_frames] 57 | break 58 | except Exception as excep: 59 | print("Warning: ffmpeg error. video path: {} error. Error: {}".format(video_path, excep), flush=True) 60 | else: 61 | # print(f'Warning: ffmpeg error. {video_path}', flush=True) 62 | video = th.zeros((num_frames, 3, 224, 224), dtype=th.uint8) 63 | 64 | video = video.float() / 255. 65 | 66 | return video -------------------------------------------------------------------------------- /configs/VL_training/baselines/dual_encoder_retrieval_webvid.yaml: -------------------------------------------------------------------------------- 1 | arch: 2 | type: BlipVTDecoderModel 3 | args: 4 | vit: 'base' 5 | init_from_pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 6 | med_config: configs/med_config.json 7 | tie_encoder_decoder_weights: true 8 | 9 | train_contrastive: true 10 | train_captioning: false 11 | alpha: 0.6 12 | train_max_text_length: 32 13 | queue_size: 2048 14 | 15 | train_data_loader: 16 | - type: VideoDataLoader 17 | args: 18 | dataset_type: WebVid2M 19 | dataset_args: 20 | data_root: data/webvid2m 21 | num_frames: 4 22 | num_workers: 16 23 | batch_size: 128 24 | split: 'train' 25 | transform: train 26 | 27 | valid_data_loader: 28 | - type: VideoDataLoader 29 | args: 30 | dataset_type: MSRVTT 31 | dataset_args: 32 | data_root: data/msrvtt 33 | num_frames: 12 34 | max_text_length: 32 35 | cut: jsfusion 36 | num_workers: 16 37 | batch_size: 32 38 | split: test 39 | transform: 'test' 40 | 41 | 42 | - type: VideoDataLoader 43 | args: 44 | dataset_type: YouCook2 45 | dataset_args: 46 | data_root: data/youcook 47 | num_frames: 12 48 | max_text_length: 32 49 | num_workers: 16 50 | batch_size: 32 51 | split: val 52 | transform: 'test' 53 | 54 | - type: VideoDataLoader 55 | args: 56 | dataset_type: MSVD 57 | dataset_args: 58 | data_root: data/msvd 59 | num_frames: 12 60 | max_text_length: 32 61 | multi_sentence_per_video: true 62 | num_workers: 16 63 | batch_size: 32 64 | split: test 65 | transform: 'test' 66 | 67 | 68 | - type: VideoDataLoader 69 | args: 70 | dataset_type: LSMDC 71 | dataset_args: 72 | data_root: data/lsmdc 73 | num_frames: 12 74 | 
max_text_length: 32 75 | num_workers: 16 76 | batch_size: 32 77 | split: test 78 | transform: 'test' 79 | 80 | 81 | optimizer: 82 | type: AdamW 83 | args: 84 | lr: 1.0e-06 85 | weight_decay: 0.05 86 | 87 | lr_scheduler: 88 | type: SchedulerWithWarmup 89 | args: 90 | warmup_epochs: 10 91 | lr: 1.0e-06 92 | 93 | save_dir: output 94 | 95 | trainer: 96 | type: VL_Trainer 97 | args: 98 | inf_dataloaders: true 99 | len_epoch: 500 100 | lr_scheduler_update: 'iter' 101 | init_retrieval: true 102 | save_epochs: [50, 100] 103 | epochs: 100 104 | save_latest: True 105 | save_period: 1000000 106 | monitor: 'off' 107 | mixed_precision: true 108 | 109 | log_visual_input_at_start: True 110 | freq_visual_input: 100000 111 | nlp_freq_eval: 100000 112 | freq_eval: 100000 113 | retrieval_freq_eval: 10 114 | 115 | clip_grad: 20 116 | -------------------------------------------------------------------------------- /configs/VL_training/dual_encoder_retrieval.yaml: -------------------------------------------------------------------------------- 1 | arch: 2 | type: BlipVTDecoderModel 3 | args: 4 | vit: 'base' 5 | init_from_pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 6 | med_config: configs/med_config.json 7 | tie_encoder_decoder_weights: true 8 | 9 | train_contrastive: true 10 | train_captioning: false 11 | 12 | alpha: 0.6 13 | train_max_text_length: 32 14 | queue_size: 2048 15 | 16 | train_data_loader: 17 | - type: VideoDataLoader 18 | args: 19 | dataset_type: HowTo100M 20 | dataset_args: 21 | # meta_info_path: data/howto100m/meta_width_height.pickle # optional: put path to dict with width/height of videos to speed up data reading 22 | # captions_are_zipped: True 23 | 24 | csv: data/howto100m/video_path_filtered.csv 25 | video_root: data/howto100m/videos 26 | caption_path: data/howtocaption/HowToCaption.pickle 27 | num_frames: 4 28 | prefetch_factor: 1 29 | num_workers: 16 30 | batch_size: 128 31 | split: 'train' 32 | transform: train 33 | 34 | valid_data_loader: 35 | - type: VideoDataLoader 36 | args: 37 | dataset_type: MSRVTT 38 | dataset_args: 39 | data_root: data/msrvtt 40 | num_frames: 12 41 | max_text_length: 32 42 | cut: jsfusion 43 | num_workers: 16 44 | batch_size: 32 45 | split: test 46 | transform: 'test' 47 | 48 | 49 | - type: VideoDataLoader 50 | args: 51 | dataset_type: YouCook2 52 | dataset_args: 53 | data_root: data/youcook 54 | num_frames: 12 55 | max_text_length: 32 56 | num_workers: 16 57 | batch_size: 32 58 | split: val 59 | transform: 'test' 60 | 61 | - type: VideoDataLoader 62 | args: 63 | dataset_type: MSVD 64 | dataset_args: 65 | data_root: data/msvd 66 | num_frames: 12 67 | max_text_length: 32 68 | multi_sentence_per_video: true 69 | num_workers: 16 70 | batch_size: 32 71 | split: test 72 | transform: 'test' 73 | 74 | 75 | - type: VideoDataLoader 76 | args: 77 | dataset_type: LSMDC 78 | dataset_args: 79 | data_root: data/lsmdc 80 | num_frames: 12 81 | max_text_length: 32 82 | num_workers: 16 83 | batch_size: 32 84 | split: test 85 | transform: 'test' 86 | 87 | 88 | optimizer: 89 | type: AdamW 90 | args: 91 | lr: 1.0e-06 92 | weight_decay: 0.05 93 | 94 | save_dir: output 95 | 96 | trainer: 97 | type: VL_Trainer 98 | args: 99 | inf_dataloaders: true 100 | len_epoch: 500 101 | lr_scheduler_update: 'iter' 102 | init_retrieval: true 103 | save_epochs: [40, 300, 600] 104 | epochs: 600 105 | save_latest: True 106 | save_period: 1000000 107 | monitor: 'off' 108 | mixed_precision: true 109 | 110 | log_visual_input_at_start: 
True 111 | freq_visual_input: 100000 112 | nlp_freq_eval: 100000 113 | freq_eval: 100000 114 | retrieval_freq_eval: 10 115 | 116 | clip_grad: 20 117 | -------------------------------------------------------------------------------- /configs/align_and_filter/finetune_1round.yaml: -------------------------------------------------------------------------------- 1 | arch: 2 | type: BlipVTDecoderModel 3 | args: 4 | vit: 'base' 5 | init_from_pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 6 | med_config: configs/med_config.json 7 | tie_encoder_decoder_weights: true 8 | 9 | train_contrastive: true 10 | train_captioning: false 11 | alpha: 0.6 12 | train_max_text_length: 32 13 | queue_size: 2048 14 | continual_learning_weight: 0.1 15 | 16 | train_data_loader: 17 | - type: VideoDataLoader 18 | args: 19 | dataset_type: HowTo100M 20 | dataset_args: 21 | # meta_info_path: data/howto100m/meta_width_height.pickle # optional: put path to dict with width/height of videos to speed up data reading 22 | # captions_are_zipped: True 23 | 24 | csv: data/howto100m/video_path_filtered.csv 25 | video_root: data/howto100m/videos 26 | caption_path: TODO # FIXME: put path to initial alignment 27 | num_frames: 4 28 | prefetch_factor: 1 29 | num_workers: 16 30 | batch_size: 128 31 | split: 'train' 32 | transform: train 33 | 34 | valid_data_loader: 35 | - type: VideoDataLoader 36 | args: 37 | dataset_type: MSRVTT 38 | dataset_args: 39 | data_root: data/msrvtt 40 | num_frames: 12 41 | max_text_length: 32 42 | cut: jsfusion 43 | num_workers: 16 44 | batch_size: 32 45 | split: test 46 | transform: 'test' 47 | 48 | 49 | - type: VideoDataLoader 50 | args: 51 | dataset_type: YouCook2 52 | dataset_args: 53 | data_root: data/youcook 54 | num_frames: 12 55 | max_text_length: 32 56 | num_workers: 16 57 | batch_size: 32 58 | split: val 59 | transform: 'test' 60 | 61 | - type: VideoDataLoader 62 | args: 63 | dataset_type: MSVD 64 | dataset_args: 65 | data_root: data/msvd 66 | num_frames: 12 67 | max_text_length: 32 68 | multi_sentence_per_video: true 69 | num_workers: 16 70 | batch_size: 32 71 | split: test 72 | transform: 'test' 73 | 74 | 75 | - type: VideoDataLoader 76 | args: 77 | dataset_type: LSMDC 78 | dataset_args: 79 | data_root: data/lsmdc 80 | num_frames: 12 81 | max_text_length: 32 82 | num_workers: 16 83 | batch_size: 32 84 | split: test 85 | transform: 'test' 86 | 87 | 88 | optimizer: 89 | type: AdamW 90 | args: 91 | lr: 1.0e-06 92 | weight_decay: 0.05 93 | 94 | save_dir: output 95 | 96 | trainer: 97 | type: VL_Trainer 98 | args: 99 | inf_dataloaders: true 100 | len_epoch: 500 101 | lr_scheduler_update: 'iter' 102 | init_retrieval: false 103 | save_epochs: [40] 104 | epochs: 40 105 | save_latest: True 106 | save_period: 1000000 107 | monitor: 'off' 108 | mixed_precision: true 109 | 110 | log_visual_input_at_start: True 111 | freq_visual_input: 100000 112 | nlp_freq_eval: 100000 113 | freq_eval: 100000 114 | retrieval_freq_eval: 10 115 | 116 | clip_grad: 20 117 | -------------------------------------------------------------------------------- /howtocaption/trainer/coco_eval.py: -------------------------------------------------------------------------------- 1 | from pycocoevalcap.bleu.bleu import Bleu 2 | from pycocoevalcap.cider.cider import Cider 3 | from pycocoevalcap.meteor.meteor import Meteor 4 | from pycocoevalcap.rouge.rouge import Rouge 5 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 6 | # from pycocotools.spice.spice import 
Spice 7 | 8 | 9 | # Modified version of https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py 10 | # without Spice 11 | class COCOEvalCap: 12 | def __init__(self, coco, cocoRes): 13 | self.evalImgs = [] 14 | self.eval = {} 15 | self.imgToEval = {} 16 | self.coco = coco 17 | self.cocoRes = cocoRes 18 | self.params = {'image_id': coco.getImgIds()} 19 | 20 | def evaluate(self): 21 | imgIds = self.params['image_id'] 22 | # imgIds = self.coco.getImgIds() 23 | gts = {} 24 | res = {} 25 | for imgId in imgIds: 26 | gts[imgId] = self.coco.imgToAnns[imgId] 27 | res[imgId] = self.cocoRes.imgToAnns[imgId] 28 | 29 | # ================================================= 30 | # Set up scorers 31 | # ================================================= 32 | print('tokenization...') 33 | tokenizer = PTBTokenizer() 34 | gts = tokenizer.tokenize(gts) 35 | res = tokenizer.tokenize(res) 36 | 37 | # ================================================= 38 | # Set up scorers 39 | # ================================================= 40 | print('setting up scorers...') 41 | scorers = [ 42 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 43 | (Meteor(),"METEOR"), 44 | (Rouge(), "ROUGE_L"), 45 | (Cider(), "CIDEr"), 46 | # (Spice(), "SPICE") 47 | ] 48 | 49 | # ================================================= 50 | # Compute scores 51 | # ================================================= 52 | for scorer, method in scorers: 53 | print('computing %s score...'%(scorer.method())) 54 | score, scores = scorer.compute_score(gts, res) 55 | if type(method) == list: 56 | for sc, scs, m in zip(score, scores, method): 57 | self.setEval(sc, m) 58 | self.setImgToEvalImgs(scs, gts.keys(), m) 59 | print("%s: %0.3f"%(m, sc)) 60 | else: 61 | self.setEval(score, method) 62 | self.setImgToEvalImgs(scores, gts.keys(), method) 63 | print("%s: %0.3f"%(method, score)) 64 | self.setEvalImgs() 65 | 66 | def setEval(self, score, method): 67 | self.eval[method] = score 68 | 69 | def setImgToEvalImgs(self, scores, imgIds, method): 70 | for imgId, score in zip(imgIds, scores): 71 | if not imgId in self.imgToEval: 72 | self.imgToEval[imgId] = {} 73 | self.imgToEval[imgId]["image_id"] = imgId 74 | self.imgToEval[imgId][method] = score 75 | 76 | def setEvalImgs(self): 77 | self.evalImgs = [eval for imgId, eval in self.imgToEval.items()] -------------------------------------------------------------------------------- /configs/VL_training/baselines/dual_encoder_retrieval_HowTo100M.yaml: -------------------------------------------------------------------------------- 1 | arch: 2 | type: BlipVTDecoderModel 3 | args: 4 | vit: 'base' 5 | init_from_pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 6 | med_config: configs/med_config.json 7 | tie_encoder_decoder_weights: true 8 | 9 | train_contrastive: true 10 | train_captioning: false 11 | 12 | alpha: 0.6 13 | train_max_text_length: 32 14 | queue_size: 2048 15 | 16 | train_data_loader: 17 | - type: VideoDataLoader 18 | args: 19 | dataset_type: HowTo100M 20 | dataset_args: 21 | # meta_info_path: data/howto100m/meta_width_height.pickle # optional: put path to dict with width/height of videos to speed up data reading 22 | # captions_are_zipped: True 23 | 24 | csv: data/howto100m/video_path_filtered.csv 25 | video_root: data/howto100m/videos 26 | caption_path: data/howto100m/asr_filtered.pickle 27 | num_frames: 4 28 | aggregate_to_min_time: True 29 | min_time: 8 30 | prefetch_factor: 1 31 | num_workers: 16 32 | batch_size: 128 33 | 
split: 'train' 34 | transform: train 35 | 36 | valid_data_loader: 37 | - type: VideoDataLoader 38 | args: 39 | dataset_type: MSRVTT 40 | dataset_args: 41 | data_root: data/msrvtt 42 | num_frames: 12 43 | max_text_length: 32 44 | cut: jsfusion 45 | num_workers: 16 46 | batch_size: 32 47 | split: test 48 | transform: 'test' 49 | 50 | 51 | - type: VideoDataLoader 52 | args: 53 | dataset_type: YouCook2 54 | dataset_args: 55 | data_root: data/youcook 56 | num_frames: 12 57 | max_text_length: 32 58 | num_workers: 16 59 | batch_size: 32 60 | split: val 61 | transform: 'test' 62 | 63 | - type: VideoDataLoader 64 | args: 65 | dataset_type: MSVD 66 | dataset_args: 67 | data_root: data/msvd 68 | num_frames: 12 69 | max_text_length: 32 70 | multi_sentence_per_video: true 71 | num_workers: 16 72 | batch_size: 32 73 | split: test 74 | transform: 'test' 75 | 76 | 77 | - type: VideoDataLoader 78 | args: 79 | dataset_type: LSMDC 80 | dataset_args: 81 | data_root: data/lsmdc 82 | num_frames: 12 83 | max_text_length: 32 84 | num_workers: 16 85 | batch_size: 32 86 | split: test 87 | transform: 'test' 88 | 89 | 90 | optimizer: 91 | type: AdamW 92 | args: 93 | lr: 1.0e-06 94 | weight_decay: 0.05 95 | 96 | save_dir: output 97 | 98 | trainer: 99 | type: VL_Trainer 100 | args: 101 | inf_dataloaders: true 102 | len_epoch: 500 103 | lr_scheduler_update: 'iter' 104 | init_retrieval: true 105 | save_epochs: [40, 300, 600] 106 | epochs: 600 107 | save_latest: True 108 | save_period: 1000000 109 | monitor: 'off' 110 | mixed_precision: true 111 | 112 | log_visual_input_at_start: True 113 | freq_visual_input: 100000 114 | nlp_freq_eval: 100000 115 | freq_eval: 100000 116 | retrieval_freq_eval: 10 117 | 118 | clip_grad: 20 119 | -------------------------------------------------------------------------------- /howtocaption/utils/retrieval_metrics_from_cap4video.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import unicode_literals 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import torch 8 | 9 | def compute_metrics(x): 10 | sx = np.sort(-x, axis=1) 11 | d = np.diag(-x) 12 | d = d[:, np.newaxis] 13 | ind = sx - d 14 | ind = np.where(ind == 0) 15 | ind = ind[1] 16 | metrics = {} 17 | metrics['R1'] = float(np.sum(ind == 0)) * 100 / len(ind) 18 | metrics['R5'] = float(np.sum(ind < 5)) * 100 / len(ind) 19 | metrics['R10'] = float(np.sum(ind < 10)) * 100 / len(ind) 20 | metrics['R50'] = float(np.sum(ind < 50)) * 100 / len(ind) 21 | metrics['MR'] = np.median(ind) + 1 22 | metrics["MedianR"] = metrics['MR'] 23 | metrics["MeanR"] = np.mean(ind) + 1 24 | # metrics["cols"] = [int(i) for i in list(ind)] 25 | return metrics 26 | 27 | def print_computed_metrics(metrics): 28 | r1 = metrics['R1'] 29 | r5 = metrics['R5'] 30 | r10 = metrics['R10'] 31 | mr = metrics['MR'] 32 | print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr)) 33 | 34 | # below two functions directly come from: https://github.com/Deferf/Experiments 35 | def tensor_text_to_video_metrics(sim_tensor, top_k = [1,5,10,50]): 36 | if not torch.is_tensor(sim_tensor): 37 | sim_tensor = torch.tensor(sim_tensor) 38 | 39 | # Permute sim_tensor so it represents a sequence of text-video similarity matrices. 
40 | # Then obtain the double argsort to position the rank on the diagonal 41 | stacked_sim_matrices = sim_tensor.permute(1, 0, 2) 42 | first_argsort = torch.argsort(stacked_sim_matrices, dim = -1, descending= True) 43 | second_argsort = torch.argsort(first_argsort, dim = -1, descending= False) 44 | 45 | # Extracts ranks i.e diagonals 46 | ranks = torch.flatten(torch.diagonal(second_argsort, dim1 = 1, dim2 = 2)) 47 | 48 | # Now we need to extract valid ranks, as some belong to inf padding values 49 | permuted_original_data = torch.flatten(torch.diagonal(sim_tensor, dim1 = 0, dim2 = 2)) 50 | mask = ~ torch.logical_or(torch.isinf(permuted_original_data), torch.isnan(permuted_original_data)) 51 | valid_ranks = ranks[mask] 52 | # A quick dimension check validates our results, there may be other correctness tests pending 53 | # Such as dot product localization, but that is for other time. 54 | #assert int(valid_ranks.shape[0]) == sum([len(text_dict[k]) for k in text_dict]) 55 | if not torch.is_tensor(valid_ranks): 56 | valid_ranks = torch.tensor(valid_ranks) 57 | results = {f"R{k}": float(torch.sum(valid_ranks < k) * 100 / len(valid_ranks)) for k in top_k} 58 | results["MedianR"] = float(torch.median(valid_ranks + 1)) 59 | results["MeanR"] = float(np.mean(valid_ranks.numpy() + 1)) 60 | results["Std_Rank"] = float(np.std(valid_ranks.numpy() + 1)) 61 | results['MR'] = results["MedianR"] 62 | return results 63 | 64 | def tensor_video_to_text_sim(sim_tensor): 65 | if not torch.is_tensor(sim_tensor): 66 | sim_tensor = torch.tensor(sim_tensor) 67 | # Code to avoid nans 68 | sim_tensor[sim_tensor != sim_tensor] = float('-inf') 69 | # Forms a similarity matrix for use with rank at k 70 | values, _ = torch.max(sim_tensor, dim=1, keepdim=True) 71 | return torch.squeeze(values).T 72 | -------------------------------------------------------------------------------- /configs/VL_training/full_encoder_decoder_ViT_L.yaml: -------------------------------------------------------------------------------- 1 | arch: 2 | type: BlipVTDecoderModel 3 | args: 4 | vit: 'large' 5 | init_from_pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth' 6 | med_config: configs/med_config.json 7 | tie_encoder_decoder_weights: true 8 | 9 | train_contrastive: true 10 | train_captioning: true 11 | train_itm: true 12 | 13 | alpha: 0.6 14 | train_max_text_length: 32 15 | queue_size: 2048 16 | 17 | train_data_loader: 18 | - type: VideoDataLoader 19 | args: 20 | dataset_type: HowTo100M 21 | dataset_args: 22 | # meta_info_path: data/howto100m/meta_width_height.pickle # optional: put path to dict with width/height of videos to speed up data reading 23 | # captions_are_zipped: True 24 | 25 | csv: data/howto100m/video_path_filtered.csv 26 | video_root: data/howto100m/videos 27 | caption_path: data/howtocaption/HowToCaption.pickle 28 | num_frames: 4 29 | prefetch_factor: 1 30 | num_workers: 16 31 | batch_size: 112 32 | split: 'train' 33 | transform: train 34 | 35 | valid_data_loader: 36 | - type: VideoDataLoader 37 | args: 38 | dataset_type: MSRVTT 39 | dataset_args: 40 | data_root: data/msrvtt 41 | num_frames: 12 42 | max_text_length: 32 43 | cut: jsfusion 44 | num_workers: 16 45 | batch_size: 32 46 | split: test 47 | transform: 'test' 48 | 49 | - type: VideoDataLoader 50 | args: 51 | dataset_type: MSRVTT 52 | dataset_args: 53 | dataset_name: MSRVTT_Cap 54 | data_root: data/msrvtt 55 | num_frames: 12 56 | max_text_length: 32 57 | cut: 'full-val' 58 | num_workers: 16 59 | batch_size: 32 60 | 
split: test 61 | transform: 'test' 62 | 63 | - type: VideoDataLoader 64 | args: 65 | dataset_type: YouCook2 66 | dataset_args: 67 | data_root: data/youcook 68 | num_frames: 12 69 | max_text_length: 32 70 | num_workers: 16 71 | batch_size: 32 72 | split: val 73 | transform: 'test' 74 | 75 | - type: VideoDataLoader 76 | args: 77 | dataset_type: MSVD 78 | dataset_args: 79 | data_root: data/msvd 80 | num_frames: 12 81 | max_text_length: 32 82 | multi_sentence_per_video: true 83 | num_workers: 16 84 | batch_size: 32 85 | split: test 86 | transform: 'test' 87 | 88 | 89 | - type: VideoDataLoader 90 | args: 91 | dataset_type: LSMDC 92 | dataset_args: 93 | data_root: data/lsmdc 94 | num_frames: 12 95 | max_text_length: 32 96 | num_workers: 16 97 | batch_size: 32 98 | split: test 99 | transform: 'test' 100 | 101 | 102 | optimizer: 103 | type: AdamW 104 | args: 105 | lr: 1.0e-06 106 | weight_decay: 0.05 107 | 108 | save_dir: output 109 | 110 | trainer: 111 | type: VL_Trainer 112 | args: 113 | inf_dataloaders: true 114 | len_epoch: 500 115 | lr_scheduler_update: 'iter' 116 | init_retrieval: false 117 | init_nlp: false 118 | 119 | save_epochs: [40, 300] 120 | epochs: 300 121 | save_latest: True 122 | save_period: 1000000 123 | monitor: 'off' 124 | mixed_precision: true 125 | 126 | log_visual_input_at_start: True 127 | freq_visual_input: 100000 128 | nlp_freq_eval: 10 129 | freq_eval: 100000 130 | retrieval_freq_eval: 10 131 | 132 | eval_args: 133 | num_beams: 1 134 | min_length: 0 135 | max_length: 20 136 | top_p: 1.0 137 | repetition_penalty: 1.0 138 | 139 | clip_grad: 20 140 | -------------------------------------------------------------------------------- /configs/VL_training/full_encoder_decoder.yaml: -------------------------------------------------------------------------------- 1 | arch: 2 | type: BlipVTDecoderModel 3 | args: 4 | vit: 'base' 5 | init_from_pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth' 6 | med_config: configs/med_config.json 7 | tie_encoder_decoder_weights: true 8 | 9 | train_contrastive: true 10 | train_captioning: true 11 | train_itm: true 12 | 13 | alpha: 0.6 14 | train_max_text_length: 32 15 | queue_size: 2048 16 | 17 | train_data_loader: 18 | - type: VideoDataLoader 19 | args: 20 | dataset_type: HowTo100M 21 | dataset_args: 22 | # meta_info_path: data/howto100m/meta_width_height.pickle # optional: put path to dict with width/height of videos to speed up data reading 23 | # captions_are_zipped: True 24 | 25 | csv: data/howto100m/video_path_filtered.csv 26 | video_root: data/howto100m/videos 27 | caption_path: data/howtocaption/HowToCaption.pickle 28 | num_frames: 4 29 | prefetch_factor: 1 30 | num_workers: 16 31 | batch_size: 128 32 | split: 'train' 33 | transform: train 34 | 35 | valid_data_loader: 36 | - type: VideoDataLoader 37 | args: 38 | dataset_type: MSRVTT 39 | dataset_args: 40 | data_root: data/msrvtt 41 | num_frames: 12 42 | max_text_length: 32 43 | cut: jsfusion 44 | num_workers: 16 45 | batch_size: 32 46 | split: test 47 | transform: 'test' 48 | 49 | - type: VideoDataLoader 50 | args: 51 | dataset_type: MSRVTT 52 | dataset_args: 53 | dataset_name: MSRVTT_Cap 54 | data_root: data/msrvtt 55 | num_frames: 12 56 | max_text_length: 32 57 | cut: 'full-val' 58 | num_workers: 16 59 | batch_size: 32 60 | split: test 61 | transform: 'test' 62 | 63 | - type: VideoDataLoader 64 | args: 65 | dataset_type: YouCook2 66 | dataset_args: 67 | data_root: data/youcook 68 | num_frames: 12 69 | max_text_length: 32 70 | 
num_workers: 16 71 | batch_size: 32 72 | split: val 73 | transform: 'test' 74 | 75 | - type: VideoDataLoader 76 | args: 77 | dataset_type: MSVD 78 | dataset_args: 79 | data_root: data/msvd 80 | num_frames: 12 81 | max_text_length: 32 82 | multi_sentence_per_video: true 83 | num_workers: 16 84 | batch_size: 32 85 | split: test 86 | transform: 'test' 87 | 88 | 89 | - type: VideoDataLoader 90 | args: 91 | dataset_type: LSMDC 92 | dataset_args: 93 | data_root: data/lsmdc 94 | num_frames: 12 95 | max_text_length: 32 96 | num_workers: 16 97 | batch_size: 32 98 | split: test 99 | transform: 'test' 100 | 101 | 102 | optimizer: 103 | type: AdamW 104 | args: 105 | lr: 1.0e-06 106 | weight_decay: 0.05 107 | 108 | save_dir: output 109 | 110 | trainer: 111 | type: VL_Trainer 112 | args: 113 | inf_dataloaders: true 114 | len_epoch: 500 115 | lr_scheduler_update: 'iter' 116 | init_retrieval: false 117 | init_nlp: false 118 | 119 | save_epochs: [40, 300, 400] 120 | epochs: 400 121 | save_latest: True 122 | save_period: 1000000 123 | monitor: 'off' 124 | mixed_precision: true 125 | 126 | log_visual_input_at_start: True 127 | freq_visual_input: 100000 128 | nlp_freq_eval: 10 129 | freq_eval: 100000 130 | retrieval_freq_eval: 10 131 | 132 | eval_args: 133 | num_beams: 1 134 | min_length: 0 135 | max_length: 20 136 | top_p: 1.0 137 | repetition_penalty: 1.0 138 | 139 | clip_grad: 20 140 | -------------------------------------------------------------------------------- /howtocaption/data_loader/video_datasets/youcook2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | 7 | from howtocaption.data_loader.video_datasets.utils import get_video_frames 8 | from howtocaption.data_loader.transforms import init_transform_dict 9 | 10 | 11 | class YouCook2(Dataset): 12 | def __init__(self, data_root, split, num_frames=4, transforms=None, 13 | dataset_name='YouCook2', 14 | output_for_visualization=False, 15 | max_text_length=-1 16 | ): 17 | super(YouCook2, self).__init__() 18 | 19 | with open(os.path.join(data_root, 'youcookii_annotations_trainval.json'), 'rb') as fin: 20 | all_data = json.load(fin)['database'] 21 | 22 | with open(os.path.join(data_root, 'splits', f'{split}_list.txt'), 'r') as fin: 23 | video_ids = fin.readlines() 24 | video_ids = [os.path.basename(name.strip()) for name in video_ids if name.strip() != ''] 25 | 26 | self.data = [] 27 | for video_id in video_ids: 28 | for clip_data in all_data[video_id]['annotations']: 29 | clip_id = f'{video_id}_{clip_data["id"]}' 30 | clip_data['video_id'] = video_id 31 | clip_data['clip_id'] = clip_id 32 | clip_data['recipe_type'] = all_data[video_id]['recipe_type'] 33 | self.data.append(clip_data) 34 | 35 | self.data_root = data_root 36 | self.split = split 37 | self.transforms = transforms 38 | self.num_frames = num_frames 39 | self.dataset_name = dataset_name 40 | self.output_for_visualization = output_for_visualization 41 | self.max_text_length = max_text_length 42 | if output_for_visualization: 43 | self.vis_transform = init_transform_dict()['visualization'] 44 | 45 | def _get_video_path(self, sample): 46 | if self.split == 'train': 47 | folder = 'training' 48 | elif self.split == 'val': 49 | folder = 'validation' 50 | elif self.split == 'test': 51 | folder = 'testing' 52 | else: 53 | raise NotImplementedError 54 | rel_path = os.path.join(folder, sample['recipe_type'], sample['video_id'], sample['video_id']) 55 | if 
os.path.exists(os.path.join(self.data_root, rel_path + '.mp4')): 56 | rel_path = rel_path + '.mp4' 57 | else: 58 | rel_path = rel_path + '.mkv' 59 | 60 | return os.path.join(self.data_root, rel_path), rel_path 61 | 62 | def _get_caption(self, sample): 63 | return sample['sentence'] 64 | 65 | def __len__(self): 66 | return len(self.data) 67 | 68 | def __getitem__(self, idx): 69 | idx = idx % len(self.data) 70 | 71 | sample = self.data[idx] 72 | video_fp, rel_fp = self._get_video_path(sample) 73 | caption = self._get_caption(sample) 74 | 75 | start_clip, end_clip = sample['segment'] 76 | num_sec = end_clip - start_clip - 0.1 # just in case 77 | if self.split == 'train': 78 | # sample start 79 | fps = (self.num_frames + 1) / num_sec 80 | start_clip = start_clip + random.random() * num_sec / self.num_frames 81 | else: 82 | fps = self.num_frames / num_sec 83 | 84 | video = get_video_frames(video_fp, start_clip, end_clip, self.num_frames, fps=fps) 85 | 86 | if self.output_for_visualization: 87 | vis_video = self.vis_transform(video) 88 | 89 | if self.transforms is not None: 90 | video = self.transforms(video) 91 | 92 | return {'video': video, 'text': caption, 'time': num_sec, 'dataset': self.dataset_name, 'path': rel_fp, 93 | 'idx': str(idx), 94 | 'vis_video': (vis_video if self.output_for_visualization else 0), 95 | 'max_text_length': self.max_text_length, 96 | } -------------------------------------------------------------------------------- /howtocaption/utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import pandas as pd 4 | from pathlib import Path 5 | from itertools import repeat 6 | from collections import OrderedDict 7 | 8 | import yaml 9 | 10 | def read_yaml(fname): 11 | with fname.open('rt') as handle: 12 | class OrderedLoader(yaml.SafeLoader): 13 | pass 14 | def construct_mapping(loader, node): 15 | loader.flatten_mapping(node) 16 | return OrderedDict(loader.construct_pairs(node)) 17 | OrderedLoader.add_constructor( 18 | yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, 19 | construct_mapping) 20 | return yaml.load(handle, OrderedLoader) 21 | 22 | 23 | def write_yaml(content, fname): 24 | with fname.open('wt') as handle: 25 | class OrderedDumper(yaml.SafeDumper): 26 | pass 27 | def _dict_representer(dumper, data): 28 | return dumper.represent_mapping( 29 | yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, 30 | data.items()) 31 | OrderedDumper.add_representer(OrderedDict, _dict_representer) 32 | return yaml.dump(content, handle, OrderedDumper) 33 | 34 | 35 | def read_json(fname): 36 | fname = Path(fname) 37 | with fname.open('rt') as handle: 38 | return json.load(handle, object_hook=OrderedDict) 39 | 40 | 41 | def write_json(content, fname): 42 | fname = Path(fname) 43 | with fname.open('wt') as handle: 44 | json.dump(content, handle, indent=4, sort_keys=False) 45 | 46 | def read_config(fname): 47 | if fname.suffix == '.json': 48 | return read_json(fname) 49 | elif fname.suffix == '.yaml': 50 | return read_yaml(fname) 51 | else: 52 | raise NotImplementedError 53 | 54 | def ensure_dir(dirname): 55 | dirname = Path(dirname) 56 | if not dirname.is_dir(): 57 | dirname.mkdir(parents=True, exist_ok=False) 58 | 59 | 60 | def inf_loop(data_loader): 61 | ''' wrapper function for endless data loader. ''' 62 | for loader in repeat(data_loader): 63 | yield from loader 64 | 65 | 66 | def prepare_device(n_gpu_use): 67 | """ 68 | setup GPU device if available. 
get gpu device indices which are used for DataParallel 69 | """ 70 | n_gpu = torch.cuda.device_count() 71 | if n_gpu_use > 0 and n_gpu == 0: 72 | print("Warning: There\'s no GPU available on this machine," 73 | "training will be performed on CPU.") 74 | n_gpu_use = 0 75 | if n_gpu_use > n_gpu: 76 | print(f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu} are " 77 | "available on this machine.") 78 | n_gpu_use = n_gpu 79 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') 80 | list_ids = list(range(n_gpu_use)) 81 | return device, list_ids 82 | 83 | 84 | class MetricTracker: 85 | def __init__(self, *keys, neptune_run=None): 86 | self.neptune_run = neptune_run 87 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) 88 | self.reset() 89 | 90 | def reset(self): 91 | for col in self._data.columns: 92 | self._data[col].values[:] = 0 93 | 94 | def update(self, key, value, n=1, step=None): 95 | if key not in self._data.index: 96 | new_row = pd.DataFrame(index=[key], columns=['total', 'counts', 'average']) 97 | for col in new_row.columns: 98 | new_row[col].values[:] = 0 99 | self._data = pd.concat((self._data, new_row)) 100 | 101 | if self.neptune_run is not None: 102 | self.neptune_run[key].log(value, step=step) 103 | self._data.total[key] += value * n 104 | self._data.counts[key] += n 105 | self._data.average[key] = self._data.total[key] / self._data.counts[key] 106 | 107 | def avg(self, key): 108 | return self._data.average[key] 109 | 110 | def result(self): 111 | return dict(self._data.average) 112 | 113 | 114 | -------------------------------------------------------------------------------- /howtocaption/save_frame_embeddings.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import numpy as np 4 | from datetime import datetime 5 | import tqdm 6 | import os 7 | import pickle 8 | from sacred import Experiment 9 | 10 | import howtocaption.data_loader as module_data 11 | import howtocaption.model as module_arch 12 | 13 | from howtocaption.trainer.vl_trainer import _move_to_device 14 | from howtocaption.parse_config import ConfigParser 15 | from howtocaption.train import init_dataloaders 16 | from howtocaption.base.base_trainer import fix_module_in_state_dict 17 | 18 | ex = Experiment('predict', save_git_info=False) 19 | 20 | 21 | @ex.main 22 | def run(): 23 | print(f"Config: {config['name']}") 24 | 25 | print("Creating dataset") 26 | print("Setting batch_size to 1") 27 | config['data_loader']['args']['batch_size'] = 1 28 | 29 | data_loader = init_dataloaders(config, 'data_loader', module_data, 30 | process_only_part_i=args.process_only_part_i, 31 | number_of_parts=args.number_of_parts) 32 | 33 | print("Creating model") 34 | model = config.initialize('arch', module_arch) 35 | 36 | if config['load_weights'] is not None: 37 | checkpoint = torch.load(config['load_weights']) 38 | print("Loading model weights: {} ...".format(config['load_weights']), flush=True) 39 | print("from epoch: {} ...".format(checkpoint['epoch']), flush=True) 40 | state_dict = checkpoint['state_dict'] 41 | state_dict = fix_module_in_state_dict(state_dict, model) 42 | model.load_state_dict(state_dict) 43 | model.eval() 44 | device = torch.device(args.device) 45 | model = model.to(device) 46 | 47 | save_dir = config['save_dir'] 48 | config_name = config["name"] 49 | 50 | with torch.no_grad(): 51 | for dl_idx, dl in enumerate(data_loader): 52 | output = {} 53 | for data_idx, data in 
tqdm.tqdm(enumerate(dl)): 54 | assert data['video'].shape[0] == 1 # batch size = 1 55 | assert len(data['video_id']) == 1 56 | video_id = data['video_id'][0] 57 | video = data['video'] 58 | 59 | cur_size = video.size() 60 | frames = video.view(cur_size[0] * cur_size[1], *cur_size[2:]) 61 | 62 | output_embed = [] 63 | max_batch_size = args.batch_size 64 | for offset_i in range(int(np.ceil(frames.size(0) / max_batch_size))): 65 | cur_frames = frames[offset_i * max_batch_size:(offset_i + 1) * max_batch_size] 66 | cur_frames = _move_to_device(cur_frames, device) 67 | 68 | cur_embed = model.encode_image(cur_frames[None]) 69 | cur_embed /= cur_embed.norm(dim=-1, keepdim=True) 70 | output_embed.append(cur_embed) 71 | output_embed = torch.cat(output_embed, dim=0) 72 | 73 | output[video_id] = {'start': data['start_time'][0].item(), 'end': data['end_time'][0].item(), 74 | 'frames': output_embed.detach().cpu().numpy().astype('float16')} 75 | 76 | if data_idx % 100 == 0: 77 | now = datetime.now() 78 | current_time = now.strftime("%H:%M:%S") 79 | print(f'{current_time}: batch', data_idx, flush=True) 80 | 81 | if args.process_only_part_i is not None: 82 | path = os.path.join(save_dir, f'video_{config_name}_part{args.process_only_part_i}.pickle') 83 | else: 84 | path = os.path.join(save_dir, f'video_{config_name}.pickle') 85 | print(f"Saving results into {path}") 86 | save_results(output, path) 87 | 88 | 89 | def save_results(results, path): 90 | os.makedirs(os.path.dirname(path), exist_ok=True) 91 | with open(path, 'wb') as fout: 92 | pickle.dump(results, fout) 93 | 94 | 95 | if __name__ == '__main__': 96 | parser = argparse.ArgumentParser(description='PyTorch Template') 97 | parser.add_argument('-c', '--config', default=None, type=str, 98 | help='config file path (default: None)') 99 | parser.add_argument('--device', default='cuda') 100 | parser.add_argument('--batch_size', default=128, type=int) 101 | parser.add_argument('--process_only_part_i', default=None, type=int) 102 | parser.add_argument('--number_of_parts', default=None, type=int) 103 | 104 | config = ConfigParser(parser, test=True) 105 | args = config.args 106 | 107 | ex.add_config(config.config) 108 | 109 | ex.run() 110 | -------------------------------------------------------------------------------- /howtocaption/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import torch 4 | import howtocaption.data_loader as module_data 5 | import howtocaption.model as module_arch 6 | import howtocaption.trainer as module_trainer 7 | from howtocaption.parse_config import ConfigParser 8 | from sacred import Experiment 9 | from howtocaption.train import init_dataloaders 10 | 11 | import neptune.new as neptune 12 | from neptune.new.integrations.sacred import NeptuneObserver 13 | 14 | from howtocaption.utils.dist_utils import init_distributed_mode, is_main_process, get_rank, get_world_size 15 | 16 | ex = Experiment('eval', save_git_info=False) 17 | 18 | 19 | @ex.main 20 | def run(): 21 | print(f"Config: {config['name']}") 22 | torch.backends.cudnn.benchmark = True 23 | 24 | # setup data_loader instances 25 | torch.multiprocessing.set_start_method('spawn', force=True) 26 | 27 | print("Creating dataset") 28 | train_data_loader = init_dataloaders(config, 'train_data_loader', module_data) 29 | 30 | if config.config.get('valid_data_loader', None) is not None: 31 | valid_data_loader = init_dataloaders(config, 'valid_data_loader', module_data) 32 | else: 33 | valid_data_loader 
= None 34 | 35 | # build model architecture, then print to console 36 | print("Creating model") 37 | model = config.initialize('arch', module_arch) 38 | print(model) 39 | 40 | # prepare for (multi-device) GPU training 41 | device = torch.device(args.device) 42 | 43 | if args.distributed: 44 | # Apply SyncBN 45 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 46 | 47 | torch.cuda.set_device(args.gpu) 48 | model = model.cuda(args.gpu) 49 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])#, find_unused_parameters=True) 50 | model_without_ddp = model.module 51 | else: 52 | model = model.to(device) 53 | model_without_ddp = model 54 | 55 | config.config['trainer']['args']['resume_only_model'] = True 56 | trainer = config.initialize('trainer', module_trainer, 57 | model=model, 58 | loss=None, 59 | metrics=None, 60 | optimizer=None, 61 | neptune_run=neptune_run, 62 | config=config, 63 | device=device, 64 | data_loader=train_data_loader, 65 | valid_data_loader=valid_data_loader, 66 | lr_scheduler=None, 67 | model_without_ddp=model_without_ddp) 68 | if args.eval_retrieval: 69 | trainer._eval_retrieval() 70 | if args.eval_captioning: 71 | trainer._eval_nlp() 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser(description='PyTorch Template') 76 | parser.add_argument('-c', '--config', default=None, type=str, 77 | help='config file path (default: None)') 78 | parser.add_argument('-r', '--resume', default=None, type=str, 79 | help='path to latest checkpoint (default: None)') 80 | parser.add_argument('-n', '--neptune', action='store_true', 81 | help='Whether to observe (neptune)') 82 | parser.add_argument('--eval_retrieval', action='store_true', default=False) 83 | parser.add_argument('--eval_captioning', action='store_true', default=False) 84 | parser.add_argument('--seed', default=None, type=int) 85 | parser.add_argument('--device', default='cuda') 86 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 87 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 88 | parser.add_argument('--distributed', default=0, type=int) 89 | 90 | # custom cli options to modify configuration from default values given in json file. 
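    # Each CustomArgs entry maps its CLI flags onto a nested config entry addressed by the
    # semicolon-separated `target` path, e.g. `--lr`/`--learning_rate` overrides optimizer -> args -> lr
    # (see _update_config/_set_by_path in howtocaption/parse_config.py).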
91 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 92 | options = [ 93 | CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'), 94 | CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size') 95 | ] 96 | config = ConfigParser(parser, options) 97 | args = config.args 98 | 99 | ex.add_config(config.config) 100 | 101 | init_distributed_mode(args=args) 102 | 103 | neptune_run = None 104 | if args.neptune and is_main_process(): 105 | # delete this error if you have added your own neptune credentials neptune.ai 106 | raise ValueError 107 | api_token = '' 108 | project = '' 109 | 110 | neptune_run = neptune.init( 111 | project=project, 112 | api_token=api_token, 113 | source_files=['imp_videocap/**/*.py', '*.py'] 114 | ) 115 | 116 | ex.observers.append(NeptuneObserver(run=neptune_run)) 117 | 118 | ex.run() 119 | -------------------------------------------------------------------------------- /howtocaption/parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from functools import reduce 4 | from operator import getitem 5 | from datetime import datetime 6 | import time 7 | import inspect 8 | from collections import OrderedDict 9 | import json 10 | 11 | from howtocaption.utils import read_yaml, write_yaml 12 | 13 | 14 | def read_json(fname): 15 | with fname.open('rt') as handle: 16 | return json.load(handle, object_hook=OrderedDict) 17 | 18 | 19 | def write_json(content, fname): 20 | with fname.open('wt') as handle: 21 | json.dump(content, handle, indent=4, sort_keys=False) 22 | 23 | 24 | class ConfigParser: 25 | def __init__(self, args_parser, options='', timestamp=True, test=False, parse_from_string=None): 26 | # parse default and custom cli options 27 | for opt in options: 28 | args_parser.add_argument(*opt.flags, default=None, type=opt.type) 29 | 30 | if parse_from_string is not None: 31 | import shlex 32 | args = args_parser.parse_args(shlex.split(parse_from_string)) 33 | else: 34 | args = args_parser.parse_args() 35 | self.args = args 36 | 37 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 38 | assert args.config is not None, msg_no_cfg 39 | config = read_yaml(Path(args.config)) 40 | 41 | if hasattr(args, 'resume') and args.resume is not None: 42 | self.resume = Path(args.resume) 43 | else: 44 | self.resume = None 45 | 46 | # load config file and apply custom cli options 47 | self.config = _update_config(config, options, args) 48 | 49 | # set seed 50 | self.config['seed'] = self.config.get('seed', None) # set None if it's not given 51 | if hasattr(args, 'seed') and args.seed is not None: 52 | self.config['seed'] = args.seed 53 | 54 | config['name'] = os.path.splitext(os.path.basename(args.config))[0] 55 | 56 | # set save_dir where trained model and log will be saved. 
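        # The resulting layout is <save_dir>/models/<config_name>/<timestamp>,
        # e.g. output/models/full_encoder_decoder/240101_120000 (the timestamp shown here is only illustrative).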
57 | save_dir = Path(self.config['save_dir']) 58 | timestamp = datetime.now().strftime(r'%y%m%d_%H%M%S') if timestamp else '' 59 | 60 | exper_name = self.config['name'] 61 | self._save_dir = save_dir / 'models' / exper_name / timestamp 62 | 63 | if not test: 64 | make_dir=True 65 | counter= 0 66 | while make_dir: 67 | try: 68 | self.save_dir.mkdir(parents=True, exist_ok=True) 69 | make_dir = False 70 | except PermissionError: 71 | exper_name = f'{self.config["name"]}_{counter}' 72 | self._save_dir = save_dir / 'models' / exper_name / timestamp 73 | self.config['name'] = exper_name 74 | counter += 1 75 | 76 | # save updated config file to the checkpoint dir 77 | if not test: 78 | write_yaml(self.config, self.save_dir / 'config.yaml') 79 | 80 | def initialize(self, name, module, *args, index=None, **kwargs): 81 | """ 82 | finds a function handle with the name given as 'type' in config, and returns the 83 | instance initialized with corresponding keyword args given as 'args'. 84 | """ 85 | if index is None: 86 | module_name = self[name]['type'] 87 | module_args = dict(self[name]['args']) 88 | else: 89 | module_name = self[name][index]['type'] 90 | module_args = dict(self[name][index]['args']) 91 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 92 | module_args.update(kwargs) 93 | # if parameter not in config subdict, then check if it's in global config. 94 | signature = inspect.signature(getattr(module, module_name).__init__) # (self, arg1, arg2, ...) 95 | for param in list(signature.parameters.keys())[1 + len(args):]: # self + first n that takes from args 96 | if param not in module_args and param in self.config: 97 | module_args[param] = self[param] 98 | print(f'Warning: Using param {param} from global config in {name}') 99 | 100 | return getattr(module, module_name)(*args, **module_args) 101 | 102 | def __getitem__(self, name): 103 | return self.config[name] 104 | 105 | @property 106 | def save_dir(self): 107 | return self._save_dir 108 | 109 | 110 | # helper functions used to update config dict with custom cli options 111 | def _update_config(config, options, args): 112 | for opt in options: 113 | value = getattr(args, _get_opt_name(opt.flags)) 114 | if value is not None: 115 | _set_by_path(config, opt.target, value) 116 | return config 117 | 118 | 119 | def _get_opt_name(flags): 120 | for flg in flags: 121 | if flg.startswith('--'): 122 | return flg.replace('--', '') 123 | return flags[0].replace('--', '') 124 | 125 | 126 | def _set_by_path(tree, keys, value): 127 | """Set a value in a nested object in tree by sequence of keys.""" 128 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 129 | 130 | 131 | def _get_by_path(tree, keys): 132 | """Access a nested object in tree by sequence of keys.""" 133 | return reduce(getitem, keys, tree) 134 | -------------------------------------------------------------------------------- /howtocaption/data_loader/video_datasets/msrvtt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from torch.utils.data import Dataset 7 | import ffmpeg 8 | import json 9 | 10 | from howtocaption.data_loader.video_datasets.utils import get_video_frames 11 | from howtocaption.data_loader.transforms import init_transform_dict 12 | 13 | 14 | class MSRVTT(Dataset): 15 | def __init__(self, data_root, cut, split, num_frames=4, transforms=None, 16 | dataset_name='MSRVTT', 17 | 
output_for_visualization=False, 18 | max_text_length=-1 19 | ): 20 | super(MSRVTT, self).__init__() 21 | 22 | self.metadata = self._load_metadata(data_root, cut, split) 23 | self.data_root = data_root 24 | self.split = split 25 | 26 | self.transforms = transforms 27 | self.num_frames = num_frames 28 | self.dataset_name = dataset_name 29 | self.output_for_visualization = output_for_visualization 30 | self.max_text_length = max_text_length 31 | if output_for_visualization: 32 | self.vis_transform = init_transform_dict()['visualization'] 33 | 34 | def _load_metadata(self, data_root, cut, split): 35 | json_fp = os.path.join(data_root, 'annotation', 'MSR_VTT.json') 36 | with open(json_fp, 'r') as fid: 37 | data = json.load(fid) 38 | df = pd.DataFrame(data['annotations']) 39 | 40 | split_dir = os.path.join(data_root, 'high-quality', 'structured-symlinks') 41 | js_test_cap_idx_path = None 42 | challenge_splits = {"val", "public_server_val", "public_server_test"} 43 | if cut == "miech": 44 | train_list_path = "train_list_miech.txt" 45 | test_list_path = "test_list_miech.txt" 46 | elif cut == "jsfusion": 47 | train_list_path = "train_list_jsfusion.txt" 48 | test_list_path = "val_list_jsfusion.txt" 49 | js_test_cap_idx_path = "jsfusion_val_caption_idx.pkl" 50 | elif cut in {"full-val", "full-test"}: 51 | train_list_path = "train_list_full.txt" 52 | if cut == "full-val": 53 | test_list_path = "val_list_full.txt" 54 | else: 55 | test_list_path = "test_list_full.txt" 56 | elif cut in challenge_splits: 57 | train_list_path = "train_list.txt" 58 | if cut == "val": 59 | test_list_path = f"{cut}_list.txt" 60 | else: 61 | test_list_path = f"{cut}.txt" 62 | else: 63 | msg = "unrecognised MSRVTT split: {}" 64 | raise ValueError(msg.format(cut)) 65 | 66 | train_df = pd.read_csv(os.path.join(split_dir, train_list_path), names=['videoid']) 67 | test_df = pd.read_csv(os.path.join(split_dir, test_list_path), names=['videoid']) 68 | 69 | if split == 'train': 70 | df = df[df['image_id'].isin(train_df['videoid'])] 71 | else: 72 | df = df[df['image_id'].isin(test_df['videoid'])] 73 | 74 | metadata = df.groupby(['image_id'])['caption'].apply(list) 75 | 76 | # use specific caption idx's in jsfusion 77 | if js_test_cap_idx_path is not None and split != 'train': 78 | caps = pd.Series(np.load(os.path.join(split_dir, js_test_cap_idx_path), allow_pickle=True)) 79 | new_res = pd.DataFrame({'caps': metadata, 'cap_idx': caps}) 80 | new_res['test_caps'] = new_res.apply(lambda x: [x['caps'][x['cap_idx']]], axis=1) 81 | metadata = new_res['test_caps'] 82 | 83 | metadata = pd.DataFrame({'captions': metadata}) 84 | return metadata 85 | 86 | def _get_video_path(self, sample): 87 | return os.path.join(self.data_root, 'videos', 'all', sample.name + '.mp4'), sample.name + '.mp4' 88 | 89 | def _get_caption(self, sample): 90 | if self.split == 'train': 91 | caption = random.choice(sample['captions']) 92 | else: 93 | caption = sample['captions'][0] 94 | return caption 95 | 96 | def get_full_caption(self, idx): 97 | sample = self.metadata.iloc[idx] 98 | return sample['captions'] 99 | 100 | def __len__(self): 101 | return len(self.metadata) 102 | 103 | def __getitem__(self, idx): 104 | idx = idx % len(self.metadata) 105 | 106 | sample = self.metadata.iloc[idx] 107 | video_fp, rel_fp = self._get_video_path(sample) 108 | caption = self._get_caption(sample) 109 | 110 | if isinstance(caption, str): 111 | probe = ffmpeg.probe(video_fp) 112 | 113 | start_clip = 0 114 | end_clip = np.floor(float(probe['format']['duration'])) 115 | else: 116 | 
caption, start_clip, end_clip = caption
117 |
118 | num_sec = end_clip - start_clip - 0.1 # just in case
119 | if self.split == 'train':
120 | # sample start
121 | fps = (self.num_frames + 1) / num_sec
122 | start_clip = random.random() * num_sec / self.num_frames
123 | else:
124 | fps = self.num_frames / num_sec
125 |
126 | video = get_video_frames(video_fp, start_clip, end_clip, self.num_frames, fps=fps)
127 |
128 | if self.output_for_visualization:
129 | vis_video = self.vis_transform(video)
130 |
131 | if self.transforms is not None:
132 | video = self.transforms(video)
133 |
134 | output = {'video': video, 'text': caption, 'time': num_sec, 'dataset': self.dataset_name, 'path': rel_fp, 'idx': sample.name,
135 | 'vis_video': (vis_video if self.output_for_visualization else 0),
136 | 'max_text_length': self.max_text_length,
137 | }
138 | return output
139 | -------------------------------------------------------------------------------- /dataset/readme.md: -------------------------------------------------------------------------------- 1 | # HowToCaption Dataset
2 |
3 | [**[arxiv]**](https://arxiv.org/abs/2301.02009)
4 |
5 | **The HowToCaption dataset** comprises
6 | 1.2M long-term instructional videos from [the HowTo100M dataset](https://www.di.ens.fr/willow/research/howto100m/),
7 | where ASR subtitles have been transformed into proper captions
8 | via our HowToCaption method using [the Vicuna-13B LLM](https://lmsys.org/blog/2023-03-30-vicuna/) ([v0](https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md)).
9 | The captions are automatically generated
10 | and their high-quality alignment to the video is further
11 | ensured through subsequent alignment and filtering post-processing,
12 | all achieved without any human involvement. As a result, the HowToCaption dataset contains 25M aligned video-text pairs.
13 |
14 | Get the dataset:
15 | * **HowToCaption dataset** (video_id+captions+timestamps): [Link](https://drive.google.com/file/d/1GU6G29RcVO8Og9D5CsJDS24kRs3bpJ2a/view?usp=drive_link) (~1.5GB)
16 | * **Unfiltered version** with
17 | the corresponding caption-to-video-clip similarity scores (video_id+captions+timestamps+scores): [Link](https://drive.google.com/file/d/1Do_anJj-FB8lGINKbUgbj7vd8AQaByzY/view?usp=drive_link) (~4.6GB)
18 |
19 | Additionally, we provide the HowToCaption-grounded dataset, featuring captions obtained via [the MiniGPT-4 model](https://minigpt-4.github.io/):
20 | * **HowToCaption-grounded dataset** (video_id+captions+timestamps): [Link](https://drive.google.com/file/d/1zBXyCHgO8zrytd1m3ohq3eF3nFEwAMki/view?usp=drive_link) (~1.5GB)
21 | * **Unfiltered version** with
22 | the corresponding caption-to-video-clip similarity scores (video_id+captions+timestamps+scores): [Link](https://drive.google.com/file/d/1uNqxfEgviOt-Fmr9qMb0D3OhZ3LlqeIv/view?usp=drive_link) (~4.5GB)
23 |
24 | ### How To Use
25 |
26 | #### How to use the filtered HowToCaption or HowToCaption-grounded datasets:
27 | * Each file is a dictionary with video-ids as keys
28 | * For each video we provide 'start', 'end', and 'text' lists of the same length
29 | * 'start' and 'end' contain the start and end seconds of the clips in the video
30 |
31 |
32 | To note:
33 | - 'text' is a list of lists of strings, since several captions can correspond to the same position in the video
34 | - The seconds in the 'start' list are not sorted; however, each entry in 'end' corresponds to the entry at the same position in 'start'
35 |
36 |
37 | **Example**:
38 |
39 | ```
40 | <<< HowToCaption['---39MFGZ-k']
41 |
42 | {
43 | 'start': [12, 19, 29, 25, 55, 81, 82],
44 | 'end': [20, 27, 37, 33, 63, 89, 90],
45 | 'text': [
46 | ['Show how to unload a 12-gauge shotgun'],
47 | ['Loading a 12-gauge shotgun'],
48 | ['Demonstrating how to unload a 12-gauge shotgun', 'A better way to unload a gun'],
49 | ['Putting another round into the gun', 'The danger of loading a gun the usual way'],
50 | ['Loading the gun safely', 'Short stroke to load the gun', 'Loading the gun today'],
51 | ['Lifting up the bar to extract rounds'],
52 | ['Going forward and lifting up the bar to extract rounds']]
53 | }
54 | ```
55 |
56 | #### How to use the unfiltered HowToCaption or HowToCaption-grounded datasets:
57 |
58 | The difference from the standard HowToCaption dataset is that 'text' is a list of lists of (string, score) tuples.
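For reference, a minimal sketch of loading the unfiltered pickle and keeping only captions above a similarity threshold (the file name and the 0.3 threshold below are placeholders, not part of the release):

```
import pickle

# Placeholder path: point this to wherever the unfiltered pickle was downloaded
with open('HowToCaption_unfiltered.pickle', 'rb') as f:
    dataset = pickle.load(f)

threshold = 0.3  # illustrative value, not an official recommendation
for video_id, anno in dataset.items():
    for start, end, captions in zip(anno['start'], anno['end'], anno['text']):
        kept = [text for text, score in captions if score > threshold]
        if kept:
            print(video_id, start, end, kept)
```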
59 |
60 | **Example**:
61 | ```
62 | <<< HowToCaption['---39MFGZ-k']
63 |
64 | {
65 | 'start': [12, 19, 25, 29, 55, 54, 65, 81, 82, 105, 103],
66 | 'end': [20, 27, 33, 37, 63, 62, 73, 89, 90, 113, 111],
67 | 'text': [
68 | [('Show how to unload a 12-gauge shotgun', 0.5699871778488159)],
69 | [('Loading a 12-gauge shotgun', 0.5876383185386658)],
70 | [('Unloading and removing a round from the chamber', 0.31276029348373413), ('Putting another round into the gun', 0.4805337190628052), ('The danger of loading a gun the usual way', 0.4611629843711853)],
71 | [('Demonstrating how to unload a 12-gauge shotgun', 0.617999255657196), ('A better way to unload a gun', 0.5126216411590576)],
72 | [('Loading the gun safely', 0.539146363735199), ('Short stroke to load the gun', 0.5076732635498047), ('Loading the gun today', 0.4759426712989807)],
73 | [('Being nervous on camera', 0.3465729355812073), ('Nervousness on camera', 0.27738460898399353)],
74 | [('Extracting rounds by lifting up the bar', 0.41076189279556274)],
75 | [('Lifting up the bar to extract rounds', 0.4220432639122009)],
76 | [('Going forward and lifting up the bar to extract rounds', 0.42620745301246643)],
77 | [('A person is speaking and pointing out that there are no ramps present', 0.30187565088272095)],
78 | [('The speaker mentions that they can be found online', 0.30197498202323914), ('The speaker concludes the video by saying "WWE" and ending the video', 0.36031144857406616)]]
79 | }
80 | ```
81 |
82 | ### Acknowledgement
83 | * [BLIP](https://github.com/salesforce/BLIP) provides the text-video encoder and the similarity score function
84 | * [Vicuna](https://github.com/lm-sys/FastChat/tree/main) is the open-source instruction-tuned LLM used to generate the HowToCaption dataset
85 | * [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) is the open-source image-conditioned LLM used to generate the HowToCaption-grounded dataset
86 |
87 |
88 | If you're using the HowToCaption or HowToCaption-grounded dataset in your research or applications, please cite it using this BibTeX:
89 |
90 | ```
91 | @article{shvetsova2023howtocaption,
92 | title={HowToCaption: Prompting LLMs to Transform Video Annotations at Scale},
93 | author={Shvetsova, Nina and Kukleva, Anna and Hong, Xudong and Rupprecht, Christian and Schiele, Bernt and Kuehne, Hilde},
94 | journal={ECCV},
95 | year={2024}
96 | }
97 | ```
98 |
99 |
100 | ### Licence:
101 |
102 | HowToCaption and HowToCaption-grounded are based on Vicuna and MiniGPT-4, which are fine-tuned LLaMA models, and should be used under [LLaMA's model license](https://github.com/facebookresearch/llama/blob/main/LICENSE).
103 | 104 | 105 | -------------------------------------------------------------------------------- /howtocaption/data_loader/video_datasets/lsmdc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | import ffmpeg 6 | 7 | from howtocaption.data_loader.video_datasets.utils import get_video_frames 8 | from howtocaption.data_loader.transforms import init_transform_dict 9 | 10 | 11 | class LSMDC(Dataset): 12 | def __init__(self, data_root, split, num_frames=4, transforms=None, 13 | dataset_name='LSDMC', 14 | output_for_visualization=False, 15 | max_text_length=-1 16 | ): 17 | super(LSMDC, self).__init__() 18 | 19 | self.data_root = data_root 20 | self._load_metadata(split) 21 | self.split = split 22 | 23 | self.transforms = transforms 24 | self.num_frames = num_frames 25 | self.dataset_name = dataset_name 26 | self.output_for_visualization = output_for_visualization 27 | self.max_text_length = max_text_length 28 | if output_for_visualization: 29 | self.vis_transform = init_transform_dict()['visualization'] 30 | 31 | def _load_metadata(self, split): 32 | assert split in ["train", "val", "test"] 33 | video_json_path_dict = {} 34 | if split == 'train': 35 | video_json_path_dict["train"] = os.path.join(self.data_root, 'lsmdc2016', "LSMDC16_annos_training.csv") 36 | elif split == 'val': 37 | video_json_path_dict["val"] = os.path.join(self.data_root, 'lsmdc2016', "LSMDC16_annos_val.csv") 38 | else: 39 | video_json_path_dict["test"] = os.path.join(self.data_root, 'lsmdc2016', "LSMDC16_challenge_1000_publictect.csv") 40 | 41 | # \t\t\t\t\t 42 | # is not a unique identifier, i.e. the same can be associated with multiple sentences. 43 | # However, LSMDC16_challenge_1000_publictect.csv has no repeat instances 44 | video_id_list = [] 45 | caption_dict = {} 46 | with open(video_json_path_dict[split], 'r') as fp: 47 | for line in fp: 48 | line = line.strip() 49 | line_split = line.split("\t") 50 | assert len(line_split) == 6 51 | clip_id, start_aligned, end_aligned, start_extracted, end_extracted, sentence = line_split 52 | caption_dict[clip_id] = { 53 | 'start': start_aligned, 54 | 'end': end_aligned, 55 | 'text': sentence, 56 | 'clip_id': clip_id 57 | } 58 | if clip_id not in video_id_list: video_id_list.append(clip_id) 59 | 60 | self.caption_dict = caption_dict 61 | 62 | features_path = os.path.join(self.data_root, 'avi') 63 | features_path2 = os.path.join(self.data_root, 'avi-m-vad-aligned') 64 | video_dict = {} 65 | for root, dub_dir, video_files in os.walk(features_path): 66 | for video_file in video_files: 67 | video_id_ = ".".join(video_file.split(".")[:-1]) 68 | if video_id_ not in video_id_list: 69 | continue 70 | file_path_ = os.path.join(root, video_file) 71 | video_dict[video_id_] = file_path_ 72 | 73 | for root, dub_dir, video_files in os.walk(features_path2): 74 | for video_file in video_files: 75 | video_id_ = ".".join(video_file.split(".")[:-1]) 76 | if video_id_ not in video_id_list: 77 | continue 78 | file_path_ = os.path.join(root, video_file) 79 | video_dict[video_id_] = file_path_ 80 | 81 | self.video_dict = video_dict 82 | 83 | # Get all captions 84 | self.iter2video_pairs_dict = {} 85 | for v in caption_dict.values(): 86 | clip_id = v['clip_id'] 87 | if clip_id not in self.video_dict: 88 | continue 89 | self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = clip_id 90 | 91 | if split == 'test': 92 | assert len(self.iter2video_pairs_dict) == 1000 93 
| 94 | def get_full_caption(self, idx): 95 | video_id = self.iter2video_pairs_dict[idx] 96 | return self.caption_dict[video_id]['text'] 97 | 98 | def __len__(self): 99 | return len(self.iter2video_pairs_dict) 100 | 101 | def __getitem__(self, idx): 102 | video_id = self.iter2video_pairs_dict[idx] 103 | video_fp = self.video_dict[video_id] 104 | rel_fp = video_id 105 | asr = '' 106 | caption = self.caption_dict[video_id]['text'] 107 | 108 | start_clip = 0 109 | probe = ffmpeg.probe(video_fp) 110 | end_clip = np.floor(float(probe['format']['duration'])) 111 | 112 | num_sec = end_clip - start_clip - 0.1 # just in case 113 | if self.split == 'train': 114 | # sample start 115 | fps = (self.num_frames + 1) / num_sec 116 | start_clip = random.random() * num_sec / self.num_frames 117 | else: 118 | fps = self.num_frames / num_sec 119 | 120 | video = get_video_frames(video_fp, start_clip, end_clip, self.num_frames, fps=fps) 121 | 122 | if self.output_for_visualization: 123 | vis_video = self.vis_transform(video) 124 | 125 | if self.transforms is not None: 126 | video = self.transforms(video) 127 | 128 | output = {'video': video, 'asr': asr, 'text': caption, 'time': num_sec, 'dataset': self.dataset_name, 'path': rel_fp, 'idx': video_id, 129 | 'vis_video': (vis_video if self.output_for_visualization else 0), 130 | 'max_text_length': self.max_text_length, 131 | } 132 | 133 | return output 134 | -------------------------------------------------------------------------------- /howtocaption/save_text_embeddings.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import numpy as np 4 | from datetime import datetime 5 | import tqdm 6 | import pandas as pd 7 | from sacred import Experiment 8 | 9 | import os 10 | import pickle 11 | import re 12 | import sys 13 | 14 | import howtocaption.model as module_arch 15 | from howtocaption.parse_config import ConfigParser 16 | from howtocaption.base.base_trainer import fix_module_in_state_dict 17 | 18 | 19 | ex = Experiment('save_text_embeddings', save_git_info=False) 20 | 21 | 22 | @ex.main 23 | def run(): 24 | print(f"Config: {config['name']}") 25 | 26 | print("Creating model") 27 | model = config.initialize('arch', module_arch) 28 | 29 | if config['load_weights'] is not None: 30 | checkpoint = torch.load(config['load_weights']) 31 | print("Loading model weights: {} ...".format(config['load_weights']), flush=True) 32 | print("from epoch: {} ...".format(checkpoint['epoch']), flush=True) 33 | state_dict = checkpoint['state_dict'] 34 | state_dict = fix_module_in_state_dict(state_dict, model) 35 | model.load_state_dict(state_dict) 36 | model.eval() 37 | device = torch.device(args.device) 38 | model = model.to(device) 39 | 40 | save_dir = config['save_dir'] 41 | config_name = config["name"] 42 | llm_prediction_name = os.path.splitext(os.path.basename(args.llm_predictions))[0] 43 | 44 | 45 | with open(args.llm_predictions, 'rb') as fin: 46 | llm_predictions = pickle.load(fin) 47 | 48 | if args.process_only_part_i is not None: 49 | assert args.number_of_parts is not None 50 | assert args.csv is not None 51 | 52 | csv = pd.read_csv(args.csv) 53 | size = int(np.ceil(len(csv) / args.number_of_parts)) 54 | csv = csv[args.process_only_part_i * size: (args.process_only_part_i + 1) * size] 55 | allowed_video_ids = csv['video_path'].map(lambda x: os.path.splitext(os.path.basename(x))[0]).tolist() 56 | llm_predictions = {key: val for key, val in llm_predictions.items() if key in allowed_video_ids} 57 | 58 | 
def preprocess(text): 59 | replace_none = ['*', '1.', '2.', '3.', '4.', '5.', '6.', '7.', '8.', '9.'] 60 | for pat in replace_none: 61 | text = text.replace(pat, '') 62 | text = text.replace('\n', '.') 63 | return [sent.strip() for sent in text.split('.') if sent.strip() != ''] 64 | 65 | add_info_dict = {} 66 | anno_dict = {} 67 | print('Preprocessing: parsing sentences and timestamps...', flush=True) 68 | for video_id, data in tqdm.tqdm(llm_predictions.items()): 69 | texts = data['prediction'] 70 | global_starts = [np.floor(x) for x in data['start']] 71 | global_ends = [np.ceil(x) for x in data['end']] 72 | 73 | pat = re.compile('^(\d+)s:(.*)') 74 | anno_dict[video_id] = [] 75 | num_sents = [] 76 | start = [] 77 | end = [] 78 | for text, global_start, global_end in zip(texts, global_starts, global_ends): 79 | text = preprocess(text) 80 | 81 | for sent in text: 82 | match = pat.match(sent) 83 | 84 | if match is not None: 85 | s = int(match.group(1)) 86 | e = s + 8 87 | sent = match.group(2).strip() 88 | 89 | # exclude captions that were predicted outside original start end time: 90 | if not((global_start <= s) and (s <= global_end)): 91 | continue 92 | 93 | anno_dict[video_id].append(sent) 94 | start.append(s) 95 | end.append(e) 96 | num_sents.append(1) 97 | 98 | add_info_dict[video_id] = { 99 | 'num_sents': num_sents, 100 | 'start': start, 101 | 'end': end, 102 | } 103 | 104 | output = {} 105 | missed_videos = 0 106 | counter = 0 107 | 108 | anno_dict = list(anno_dict.items()) 109 | print("Found unique videos: ", len(anno_dict), flush=True) 110 | print("Found number of captions: ", sum(len(x[1]) for x in anno_dict), flush=True) 111 | 112 | with torch.no_grad(): 113 | for video_id, captions in tqdm.tqdm(anno_dict): 114 | try: 115 | text_features = model.encode_text(captions) 116 | counter += 1 117 | except Exception as e: 118 | print(e, file=sys.stderr) 119 | missed_videos += 1 120 | continue 121 | text_features /= text_features.norm(dim=-1, keepdim=True) 122 | text_features = text_features.detach().cpu().numpy() 123 | output[video_id] = {'text': captions, 'features': text_features} 124 | 125 | if len(add_info_dict) != 0: 126 | output[video_id].update(add_info_dict[video_id]) 127 | 128 | if counter % 1000 == 0: 129 | now = datetime.now() 130 | current_time = now.strftime("%H:%M:%S") 131 | print(current_time, video_id, flush=True) 132 | 133 | if args.process_only_part_i is not None: 134 | path = os.path.join(save_dir, f'text_{config_name}_{llm_prediction_name}_part{args.process_only_part_i}.pickle') 135 | else: 136 | path = os.path.join(save_dir, f'text_{config_name}_{llm_prediction_name}.pickle') 137 | 138 | print(f"Saving results into {path}") 139 | save_results(output, path) 140 | print("Missed videos", missed_videos) 141 | 142 | 143 | def save_results(results, path): 144 | os.makedirs(os.path.dirname(path), exist_ok=True) 145 | with open(path, 'wb') as fout: 146 | pickle.dump(results, fout) 147 | 148 | 149 | if __name__ == '__main__': 150 | parser = argparse.ArgumentParser(description='PyTorch Template') 151 | parser.add_argument('-c', '--config', default=None, type=str, 152 | help='config file path (default: None)') 153 | parser.add_argument('--device', default='cuda') 154 | 155 | parser.add_argument('--llm_predictions', default=None, type=str) 156 | 157 | # for processing only part of original data 158 | parser.add_argument('--csv', default=None, type=str) 159 | parser.add_argument('--process_only_part_i', default=None, type=int) 160 | parser.add_argument('--number_of_parts', 
default=None, type=int) 161 | 162 | config = ConfigParser(parser, test=True) 163 | args = config.args 164 | ex.add_config(config.config) 165 | 166 | ex.run() 167 | -------------------------------------------------------------------------------- /howtocaption/data_loader/video_datasets/msvd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | import ffmpeg 6 | import pickle 7 | 8 | from howtocaption.data_loader.video_datasets.utils import get_video_frames 9 | from howtocaption.data_loader.transforms import init_transform_dict 10 | 11 | 12 | class MSVD(Dataset): 13 | def __init__(self, data_root, split, num_frames=4, transforms=None, 14 | dataset_name='MSVD', 15 | multi_sentence_per_video=True, 16 | output_for_visualization=False, 17 | max_text_length=-1, 18 | sample_per_video=False 19 | ): 20 | super(MSVD, self).__init__() 21 | 22 | self.data_root = data_root 23 | self.multi_sentence_per_video = multi_sentence_per_video 24 | self._load_metadata(split) 25 | self.split = split 26 | 27 | self.transforms = transforms 28 | self.num_frames = num_frames 29 | 30 | self.dataset_name = dataset_name 31 | self.output_for_visualization = output_for_visualization 32 | self.max_text_length = max_text_length 33 | if output_for_visualization: 34 | self.vis_transform = init_transform_dict()['visualization'] 35 | self.sample_per_video = sample_per_video 36 | 37 | def _load_metadata(self, split): 38 | assert split in ["train", "val", "test"] 39 | video_id_path_dict = {} 40 | if split == 'train': 41 | video_id_path_dict["train"] = os.path.join(self.data_root, 'msvd_data', "train_list.txt") 42 | elif split == 'val': 43 | video_id_path_dict["val"] = os.path.join(self.data_root, 'msvd_data', "val_list.txt") 44 | else: 45 | video_id_path_dict["test"] = os.path.join(self.data_root, 'msvd_data', "test_list.txt") 46 | 47 | caption_file = os.path.join(self.data_root, 'msvd_data', "raw-captions.pkl") 48 | 49 | with open(video_id_path_dict[split], 'r') as fp: 50 | video_ids = [itm.strip() for itm in fp.readlines()] 51 | 52 | with open(caption_file, 'rb') as f: 53 | captions = pickle.load(f) 54 | self.captions = captions 55 | 56 | video_dict = {} 57 | features_path = os.path.join(self.data_root, 'YouTubeClips') 58 | for root, dub_dir, video_files in os.walk(features_path): 59 | for video_file in video_files: 60 | video_id_ = ".".join(video_file.split(".")[:-1]) 61 | if video_id_ not in video_ids: 62 | continue 63 | file_path_ = os.path.join(root, video_file) 64 | video_dict[video_id_] = file_path_ 65 | self.video_dict = video_dict 66 | self.video_ids = video_ids 67 | 68 | self.sentences_dict = {} 69 | self.cut_off_points = [] 70 | self.video_ids2video_idx = {} 71 | for idx, video_id in enumerate(video_ids): 72 | self.video_ids2video_idx[video_id] = idx 73 | assert video_id in captions 74 | for cap in captions[video_id]: 75 | cap_txt = " ".join(cap) 76 | self.sentences_dict[len(self.sentences_dict)] = (video_id, cap_txt) 77 | self.cut_off_points.append(len(self.sentences_dict)) 78 | 79 | self.multi_sentence_per_video = self.multi_sentence_per_video # !!! 
important tag for eval 80 | if split == "val" or split == "test": 81 | self.sentence_num = len(self.sentences_dict) 82 | self.video_num = len(video_ids) 83 | assert len(self.cut_off_points) == self.video_num 84 | print("For {}, sentence number: {}".format(split, self.sentence_num)) 85 | print("For {}, video number: {}".format(split, self.video_num)) 86 | 87 | print("Video number: {}".format(len(self.video_dict))) 88 | print("Total Paire: {}".format(len(self.sentences_dict))) 89 | self.sample_len = len(self.sentences_dict) 90 | 91 | def get_full_caption(self, idx): 92 | if isinstance(idx, int): 93 | video_id = self.video_ids[idx] 94 | else: 95 | video_id = idx 96 | return [' '.join(itm) for itm in self.captions[video_id]] 97 | 98 | def __len__(self): 99 | if self.sample_per_video: 100 | if self.split == 'train': 101 | return len( self.video_ids) 102 | else: 103 | raise NotImplementedError 104 | else: 105 | if self.split == 'train': 106 | return self.sample_len 107 | return len(self.video_dict) 108 | 109 | def __getitem__(self, idx): 110 | if self.sample_per_video: 111 | if self.split == 'train': 112 | video_id = self.video_ids[idx] 113 | video_fp = self.video_dict[video_id] 114 | caption = " ".join(np.random.choice(self.captions[video_id])) 115 | else: 116 | raise NotImplementedError 117 | else: 118 | if self.split == 'train': 119 | video_id, caption = self.sentences_dict[idx] 120 | video_fp = self.video_dict[video_id] 121 | else: 122 | video_id = self.video_ids[idx] 123 | video_fp = self.video_dict[video_id] 124 | caption = " ".join(self.captions[video_id][0]) 125 | 126 | rel_fp = f'{video_id}.avi' 127 | asr = '' 128 | 129 | assert isinstance(caption, str) 130 | probe = ffmpeg.probe(video_fp) 131 | 132 | start_clip = 0 133 | end_clip = np.floor(float(probe['format']['duration'])) 134 | 135 | num_sec = end_clip - start_clip - 0.1 # just in case 136 | if self.split == 'train': 137 | # sample start 138 | fps = (self.num_frames + 1) / num_sec 139 | start_clip = random.random() * num_sec / self.num_frames 140 | else: 141 | fps = self.num_frames / num_sec 142 | 143 | video = get_video_frames(video_fp, start_clip, end_clip, self.num_frames, fps=fps) 144 | 145 | if self.output_for_visualization: 146 | vis_video = self.vis_transform(video) 147 | 148 | if self.transforms is not None: 149 | video = self.transforms(video) 150 | 151 | output = {'video': video, 'asr': asr, 'text': caption, 'time': num_sec, 'dataset': self.dataset_name, 'path': rel_fp, 'idx': video_id, 152 | 'video_numerical_idx': self.video_ids2video_idx[video_id], 153 | 'vis_video': (vis_video if self.output_for_visualization else 0), 154 | 'max_text_length': self.max_text_length, 155 | } 156 | return output 157 | -------------------------------------------------------------------------------- /howtocaption/model/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from urllib.parse import urlparse 3 | import torch.nn as nn 4 | 5 | import torch 6 | from timm.models.hub import download_cached_file 7 | 8 | from transformers import BertTokenizer 9 | 10 | from howtocaption.model.utils.vit import interpolate_pos_embed, VisionTransformer 11 | from howtocaption.model.utils.vit_per_frame import VisionTransformerPerFrame 12 | 13 | from typing import List 14 | 15 | 16 | def init_tokenizer(): 17 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 18 | tokenizer.add_special_tokens({'bos_token':'[DEC]'}) # there was no bos_token. 
After this line tokenizer.bos_token_id will appear 19 | tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) 20 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 21 | return tokenizer 22 | 23 | 24 | def is_url(url_or_filename): 25 | parsed = urlparse(url_or_filename) 26 | return parsed.scheme in ("http", "https") 27 | 28 | 29 | def load_checkpoint(model, url_or_filename): 30 | if is_url(url_or_filename): 31 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 32 | checkpoint = torch.load(cached_file, map_location='cpu') 33 | elif os.path.isfile(url_or_filename): 34 | checkpoint = torch.load(url_or_filename, map_location='cpu') 35 | else: 36 | raise RuntimeError('checkpoint url or path is invalid') 37 | 38 | state_dict = checkpoint['model'] 39 | 40 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], 41 | model.visual_encoder) 42 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): 43 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], 44 | model.visual_encoder_m) 45 | for key in model.state_dict().keys(): 46 | if key in state_dict.keys(): 47 | if state_dict[key].shape != model.state_dict()[key].shape: 48 | del state_dict[key] 49 | 50 | msg = model.load_state_dict(state_dict, strict=False) 51 | 52 | print('load checkpoint from %s' % url_or_filename) 53 | return model, msg 54 | 55 | 56 | def create_vit_per_frame_model(image_size=384, vit='base'): 57 | if vit == 'base': 58 | vision_width = 768 59 | visual_encoder = VisionTransformerPerFrame(img_size=image_size, patch_size=16, embed_dim=vision_width) 60 | elif vit == 'large': 61 | vision_width = 1024 62 | visual_encoder = VisionTransformerPerFrame(img_size=image_size, patch_size=16, embed_dim=vision_width, 63 | depth=24, num_heads=16) 64 | else: 65 | raise NotImplementedError 66 | return visual_encoder, vision_width 67 | 68 | 69 | def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_keys:List[str]): 70 | uninitialized_encoder_weights: List[str] = [] 71 | if decoder.__class__ != encoder.__class__: 72 | print(f"{decoder.__class__} and {encoder.__class__} are not equal. 
In this case make sure that all encoder weights are correctly initialized.") 73 | 74 | def tie_encoder_to_decoder_recursively( 75 | decoder_pointer: nn.Module, 76 | encoder_pointer: nn.Module, 77 | module_name: str, 78 | uninitialized_encoder_weights: List[str], 79 | skip_keys: List[str], 80 | depth=0, 81 | ): 82 | assert isinstance(decoder_pointer, nn.Module) and isinstance( 83 | encoder_pointer, nn.Module 84 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" 85 | if hasattr(decoder_pointer, "weight") and all(skip_key not in module_name for skip_key in skip_keys): 86 | assert hasattr(encoder_pointer, "weight") 87 | encoder_pointer.weight = decoder_pointer.weight 88 | if hasattr(decoder_pointer, "bias"): 89 | assert hasattr(encoder_pointer, "bias") 90 | encoder_pointer.bias = decoder_pointer.bias 91 | if hasattr(decoder_pointer, "lora_A"): 92 | assert hasattr(encoder_pointer, "lora_A") 93 | encoder_pointer.lora_A = decoder_pointer.lora_A 94 | if hasattr(decoder_pointer, "lora_B"): 95 | assert hasattr(encoder_pointer, "lora_B") 96 | encoder_pointer.lora_B = decoder_pointer.lora_B 97 | # print(module_name+' is tied') 98 | return 99 | 100 | encoder_modules = encoder_pointer._modules 101 | decoder_modules = decoder_pointer._modules 102 | if len(decoder_modules) > 0: 103 | assert ( 104 | len(encoder_modules) > 0 105 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" 106 | 107 | all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) 108 | encoder_layer_pos = 0 109 | for name, module in decoder_modules.items(): 110 | if name.isdigit(): 111 | encoder_name = str(int(name) + encoder_layer_pos) 112 | decoder_name = name 113 | if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len( 114 | encoder_modules 115 | ) != len(decoder_modules): 116 | # this can happen if the name corresponds to the position in a list module list of layers 117 | # in this case the decoder has added a cross-attention that the encoder does not have 118 | # thus skip this step and subtract one layer pos from encoder 119 | encoder_layer_pos -= 1 120 | continue 121 | elif name not in encoder_modules: 122 | continue 123 | elif depth > 500: 124 | raise ValueError( 125 | "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." 
126 | ) 127 | else: 128 | decoder_name = encoder_name = name 129 | tie_encoder_to_decoder_recursively( 130 | decoder_modules[decoder_name], 131 | encoder_modules[encoder_name], 132 | module_name + "/" + name, 133 | uninitialized_encoder_weights, 134 | skip_keys, 135 | depth=depth + 1, 136 | ) 137 | all_encoder_weights.remove(module_name + "/" + encoder_name) 138 | 139 | uninitialized_encoder_weights += list(all_encoder_weights) 140 | 141 | # tie weights recursively 142 | tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_keys) 143 | -------------------------------------------------------------------------------- /howtocaption/data_loader/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from collections.abc import Sequence 4 | from typing import Tuple, List, Optional 5 | from torch import Tensor 6 | import torch 7 | import torchvision.transforms.functional 8 | 9 | from torchvision.transforms.transforms import _setup_size, InterpolationMode, _interpolation_modes_from_int 10 | 11 | 12 | class MyRandomResizedCrop(torch.nn.Module): 13 | """Crop a random portion of image and resize it to a given size. 14 | 15 | If the image is torch Tensor, it is expected 16 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions 17 | 18 | A crop of the original image is made: the crop has a random area (H * W) 19 | and a random aspect ratio. This crop is finally resized to the given 20 | size. This is popularly used to train the Inception networks. 21 | 22 | Args: 23 | size (int or sequence): expected output size of the crop, for each edge. If size is an 24 | int instead of sequence like (h, w), a square output size ``(size, size)`` is 25 | made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 26 | 27 | .. note:: 28 | In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. 29 | scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop, 30 | before resizing. The scale is defined with respect to the area of the original image. 31 | ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before 32 | resizing. 33 | interpolation (InterpolationMode): Desired interpolation enum defined by 34 | :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. 35 | If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and 36 | ``InterpolationMode.BICUBIC`` are supported. 37 | For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. 38 | 39 | """ 40 | 41 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR, 42 | antialias=False): 43 | super().__init__() 44 | self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") 45 | 46 | if not isinstance(scale, Sequence): 47 | raise TypeError("Scale should be a sequence") 48 | if not isinstance(ratio, Sequence): 49 | raise TypeError("Ratio should be a sequence") 50 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 51 | warnings.warn("Scale and ratio should be of kind (min, max)") 52 | 53 | # Backward compatibility with integer value 54 | if isinstance(interpolation, int): 55 | warnings.warn( 56 | "Argument interpolation should be of type InterpolationMode instead of int. 
" 57 | "Please, use InterpolationMode enum." 58 | ) 59 | interpolation = _interpolation_modes_from_int(interpolation) 60 | 61 | self.interpolation = interpolation 62 | self.scale = scale 63 | self.ratio = ratio 64 | self.antialias = antialias 65 | 66 | @staticmethod 67 | def get_params( 68 | img: Tensor, scale: List[float], ratio: List[float] 69 | ) -> Tuple[int, int, int, int]: 70 | """Get parameters for ``crop`` for a random sized crop. 71 | 72 | Args: 73 | img (PIL Image or Tensor): Input image. 74 | scale (list): range of scale of the origin size cropped 75 | ratio (list): range of aspect ratio of the origin aspect ratio cropped 76 | 77 | Returns: 78 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 79 | sized crop. 80 | """ 81 | width, height = torchvision.transforms.functional.get_image_size(img) 82 | area = height * width 83 | 84 | log_ratio = torch.log(torch.tensor(ratio)) 85 | for _ in range(10): 86 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 87 | aspect_ratio = torch.exp( 88 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 89 | ).item() 90 | 91 | w = int(round(math.sqrt(target_area * aspect_ratio))) 92 | h = int(round(math.sqrt(target_area / aspect_ratio))) 93 | 94 | if 0 < w <= width and 0 < h <= height: 95 | i = torch.randint(0, height - h + 1, size=(1,)).item() 96 | j = torch.randint(0, width - w + 1, size=(1,)).item() 97 | return i, j, h, w 98 | 99 | # Fallback to central crop 100 | in_ratio = float(width) / float(height) 101 | if in_ratio < min(ratio): 102 | w = width 103 | h = int(round(w / min(ratio))) 104 | elif in_ratio > max(ratio): 105 | h = height 106 | w = int(round(h * max(ratio))) 107 | else: # whole image 108 | w = width 109 | h = height 110 | i = (height - h) // 2 111 | j = (width - w) // 2 112 | return i, j, h, w 113 | 114 | def forward(self, img): 115 | """ 116 | Args: 117 | img (PIL Image or Tensor): Image to be cropped and resized. 118 | 119 | Returns: 120 | PIL Image or Tensor: Randomly cropped and resized image. 121 | """ 122 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 123 | return resized_crop(img, i, j, h, w, self.size, self.interpolation) 124 | 125 | def __repr__(self): 126 | interpolate_str = self.interpolation.value 127 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 128 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 129 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 130 | format_string += ', interpolation={0})'.format(interpolate_str) 131 | return format_string 132 | 133 | 134 | def resized_crop( 135 | img: Tensor, top: int, left: int, height: int, width: int, size: List[int], 136 | interpolation: InterpolationMode = InterpolationMode.BILINEAR, 137 | antialias: Optional[bool] = None, 138 | ) -> Tensor: 139 | """Crop the given image and resize it to desired size. 140 | If the image is torch Tensor, it is expected 141 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions 142 | 143 | Notably used in :class:`~torchvision.transforms.RandomResizedCrop`. 144 | 145 | Args: 146 | img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image. 147 | top (int): Vertical component of the top left corner of the crop box. 148 | left (int): Horizontal component of the top left corner of the crop box. 149 | height (int): Height of the crop box. 150 | width (int): Width of the crop box. 151 | size (sequence or int): Desired output size. 
Same semantics as ``resize``. 152 | interpolation (InterpolationMode): Desired interpolation enum defined by 153 | :class:`torchvision.transforms.InterpolationMode`. 154 | Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, 155 | ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. 156 | For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. 157 | 158 | Returns: 159 | PIL Image or Tensor: Cropped image. 160 | """ 161 | img = torchvision.transforms.functional.crop(img, top, left, height, width) 162 | img = torchvision.transforms.functional.resize(img, size, interpolation, antialias=antialias) 163 | return img 164 | -------------------------------------------------------------------------------- /howtocaption/llm_prompting/prompt_vicuna.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from fastchat.serve.inference import SeparatorStyle 3 | from fastchat.train.train import smart_tokenizer_and_embedding_resize, DEFAULT_PAD_TOKEN 4 | from fastchat.conversation import Conversation 5 | import pickle 6 | import os 7 | import tqdm 8 | import random 9 | import yaml 10 | from collections import OrderedDict 11 | from pathlib import Path 12 | 13 | import torch 14 | try: 15 | from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, AutoModel, LlamaForCausalLM 16 | except ImportError: 17 | from transformers import AutoTokenizer, AutoModelForCausalLM, LLaMATokenizer, AutoModel 18 | 19 | from fastchat.serve.monkey_patch_non_inplace import replace_llama_attn_with_non_inplace_operations 20 | from fastchat.serve.compression import compress_module 21 | 22 | 23 | def load_model(model_path, device, num_gpus, load_8bit=False, debug=False): 24 | if device == "cpu": 25 | kwargs = {} 26 | elif device == "cuda": 27 | kwargs = {"torch_dtype": torch.float16} 28 | if num_gpus == "auto": 29 | kwargs["device_map"] = "auto" 30 | else: 31 | num_gpus = int(num_gpus) 32 | if num_gpus != 1: 33 | kwargs.update({ 34 | "device_map": "auto", 35 | "max_memory": {i: "13GiB" for i in range(num_gpus)}, 36 | }) 37 | elif device == "mps": 38 | kwargs = {"torch_dtype": torch.float16} 39 | # Avoid bugs in mps backend by not using in-place operations. 
40 | replace_llama_attn_with_non_inplace_operations() 41 | else: 42 | raise ValueError(f"Invalid device: {device}") 43 | 44 | if "chatglm" in model_path: 45 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="left") 46 | model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda() 47 | else: 48 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side="left") 49 | model = AutoModelForCausalLM.from_pretrained(model_path, 50 | low_cpu_mem_usage=True, **kwargs) 51 | 52 | if load_8bit: 53 | compress_module(model, device) 54 | 55 | if (device == "cuda" and num_gpus == 1) or device == "mps": 56 | model.to(device) 57 | 58 | if debug: 59 | print(model) 60 | 61 | return model, tokenizer 62 | 63 | def read_yaml(fname): 64 | fname = Path(fname) 65 | with fname.open('rt') as handle: 66 | class OrderedLoader(yaml.SafeLoader): 67 | pass 68 | def construct_mapping(loader, node): 69 | loader.flatten_mapping(node) 70 | return OrderedDict(loader.construct_pairs(node)) 71 | OrderedLoader.add_constructor( 72 | yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, 73 | construct_mapping) 74 | return yaml.load(handle, OrderedLoader) 75 | 76 | 77 | @torch.inference_mode() 78 | def get_answer(model, tokenizer, asrs, config, device, debug=False): 79 | request_template = config['prompt'] 80 | examples = config['examples'] 81 | batch_processing = config.get('batch_processing', False) 82 | batch_size = config.get('batch_size', 6) 83 | 84 | messages = [] 85 | for asr_example, answer_example in examples: 86 | messages.append(["Human", request_template.format(asr_example.strip())]) 87 | messages.append(["Assistant", answer_example]) 88 | 89 | conv_template = Conversation( 90 | system="A chat between a curious human and an artificial intelligence assistant. 
" 91 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 92 | roles=["Human", "Assistant"], 93 | messages=messages, 94 | offset=2, 95 | sep_style=SeparatorStyle.SINGLE, 96 | sep="\n### ", 97 | ) 98 | 99 | if batch_processing: 100 | outputs = [] 101 | for offset in range(0, len(asrs), batch_size): 102 | 103 | prompts = [] 104 | for asr in asrs[offset:offset + batch_size]: 105 | conv = conv_template.copy() 106 | conv.append_message(conv.roles[0], request_template.format(asr.strip())) 107 | conv.append_message(conv.roles[1], None) 108 | 109 | prompt = conv.get_prompt() 110 | prompts.append(prompt) 111 | 112 | if debug: 113 | print("Prompt:") 114 | print(prompt) 115 | 116 | input_ids = tokenizer(prompts, return_tensors="pt", padding=True).to(device) 117 | output = model.generate(**input_ids, **config['model_generate_args']) 118 | output = tokenizer.batch_decode(output[:, input_ids['input_ids'].shape[1]:], skip_special_tokens=True) 119 | output = [x[:-2] for x in output] 120 | outputs.extend(output) 121 | 122 | if debug: 123 | print("Output:") 124 | for x in output: 125 | print(x) 126 | else: 127 | outputs = [] 128 | for asr in asrs: 129 | conv = conv_template.copy() 130 | conv.append_message(conv.roles[0], request_template.format(asr)) 131 | conv.append_message(conv.roles[1], None) 132 | 133 | prompt = conv.get_prompt() 134 | 135 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) 136 | output = model.generate(input_ids, **config['model_generate_args']) 137 | output = tokenizer.decode(output[0][len(input_ids[0]):-1]) 138 | outputs.append(output) 139 | 140 | return outputs 141 | 142 | 143 | if __name__ == '__main__': 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("--config", type=str) 146 | parser.add_argument("--asr-path", type=str) 147 | parser.add_argument("--model-path", type=str) 148 | 149 | parser.add_argument("--device", type=str, choices=["cpu", "cuda", "mps"], default="cuda") 150 | parser.add_argument("--num-gpus", type=str, default="1") 151 | parser.add_argument("--load-8bit", action="store_true", 152 | help="Use 8-bit quantization.") 153 | parser.add_argument("--debug", action="store_true") 154 | 155 | args = parser.parse_args() 156 | 157 | config = read_yaml(args.config) 158 | device = args.device 159 | num_gpus = args.num_gpus 160 | load_8bit = args.load_8bit 161 | debug = args.debug 162 | 163 | if args.debug: 164 | config['batch_size'] = 1 165 | 166 | model_path = args.model_path 167 | asr_path = args.asr_path 168 | word_blocks = config['word_blocks'] 169 | save_dir = config['save_dir'] 170 | 171 | exp_name = os.path.splitext(os.path.basename(args.config))[0] 172 | 173 | model, tokenizer = load_model(model_path, device, num_gpus, load_8bit, debug) 174 | smart_tokenizer_and_embedding_resize( 175 | special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), 176 | tokenizer=tokenizer, 177 | model=model, 178 | ) 179 | 180 | result_dir = os.path.join(save_dir, exp_name) 181 | os.makedirs(result_dir, exist_ok=True) 182 | 183 | with open(asr_path, 'rb') as fin: 184 | data = pickle.load(fin) 185 | 186 | data = list(data.items()) 187 | if not args.debug: 188 | random.shuffle(data) 189 | 190 | for id_, (key, val) in enumerate(tqdm.tqdm(data)): 191 | texts = val[word_blocks]['text'] 192 | 193 | output_path = f'{result_dir}/{key[0]}/{key[1]}/{key}.pickle' 194 | 195 | if os.path.exists(output_path): 196 | continue 197 | 198 | if args.debug: 199 | print(key, flush=True) 200 | 201 | results = get_answer(model, tokenizer, 
texts, config, device, debug=debug) 202 | 203 | if args.debug: 204 | continue 205 | 206 | if id_ % 100 == 0: 207 | print(f'Vicuna output: {results[0]}') 208 | 209 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 210 | 211 | with open(output_path, 'wb') as fout: 212 | pickle.dump(results, fout) 213 | -------------------------------------------------------------------------------- /howtocaption/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import torch 4 | import numpy as np 5 | import random 6 | import howtocaption.data_loader as module_data 7 | import howtocaption.lr_scheduler as module_lr_scheduler 8 | import howtocaption.model as module_arch 9 | import howtocaption.trainer as module_trainer 10 | from howtocaption.parse_config import ConfigParser 11 | from sacred import Experiment 12 | import warnings 13 | import logging 14 | 15 | import neptune.new as neptune 16 | from neptune.new.integrations.sacred import NeptuneObserver 17 | 18 | from howtocaption.utils.dist_utils import init_distributed_mode, is_main_process, get_rank, get_world_size 19 | 20 | ex = Experiment('train', save_git_info=False) 21 | 22 | 23 | @ex.main 24 | def run(): 25 | print(f"Config: {config['name']}") 26 | 27 | # fix random seeds for reproducibility 28 | if config['seed'] is not None: 29 | warnings.warn('You have chosen to seed training. ' 30 | 'This will turn on the CUDNN deterministic setting, ' 31 | 'which can slow down your training considerably! ' 32 | 'You may see unexpected behavior when restarting ' 33 | 'from checkpoints.') 34 | 35 | seed = config['seed'] + get_rank() 36 | torch.manual_seed(seed) 37 | np.random.seed(seed) 38 | random.seed(seed) 39 | torch.backends.cudnn.deterministic = True 40 | torch.backends.cudnn.benchmark = False 41 | else: 42 | torch.backends.cudnn.benchmark = True 43 | 44 | # setup data_loader instances 45 | torch.multiprocessing.set_start_method('spawn', force=True) 46 | 47 | print("Creating dataset") 48 | train_data_loader = init_dataloaders(config, 'train_data_loader', module_data) 49 | 50 | if config.config.get('valid_data_loader', None) is not None: 51 | valid_data_loader = init_dataloaders(config, 'valid_data_loader', module_data) 52 | else: 53 | valid_data_loader = None 54 | 55 | # build model architecture, then print to console 56 | print("Creating model") 57 | model = config.initialize('arch', module_arch) 58 | # print(model) 59 | 60 | # prepare for (multi-device) GPU training 61 | device = torch.device(args.device) 62 | 63 | if args.distributed: 64 | # Apply SyncBN 65 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 66 | 67 | torch.cuda.set_device(args.gpu) 68 | model = model.cuda(args.gpu) 69 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])#, find_unused_parameters=True) 70 | model_without_ddp = model.module 71 | 72 | else: 73 | model = model.to(device) 74 | model_without_ddp = model 75 | 76 | # build optimizer, learning rate scheduler 77 | trainable_params = [param for name, param in model.named_parameters() if param.requires_grad and 'lora' not in name] 78 | optimizer = config.initialize('optimizer', torch.optim, params=trainable_params) 79 | 80 | if config.config.get('lr_scheduler') is not None: 81 | if config['trainer']['args'].get('lr_scheduler_update') == 'iter': 82 | len_epoch = config['trainer']['args'].get('len_epoch') 83 | if len_epoch is None: 84 | len_epoch = min([len(dl) for dl in train_data_loader]) 85 | if 
config['lr_scheduler']['type'] == 'CosineAnnealingLR': 86 | config['lr_scheduler']['args']['T_max'] *= len_epoch 87 | else: 88 | raise NotImplementedError() 89 | lr_scheduler = config.initialize('lr_scheduler', module_lr_scheduler, optimizer=optimizer) 90 | else: 91 | lr_scheduler = None 92 | 93 | 94 | trainer = config.initialize('trainer', module_trainer, 95 | model=model, 96 | loss=None, 97 | metrics=None, 98 | optimizer=optimizer, 99 | neptune_run=neptune_run, 100 | config=config, 101 | device=device, 102 | data_loader=train_data_loader, 103 | valid_data_loader=valid_data_loader, 104 | lr_scheduler=lr_scheduler, 105 | model_without_ddp=model_without_ddp) 106 | 107 | trainer.train() 108 | 109 | 110 | def init_dataloaders(config, data_loader_name, module_data, **kwargs): 111 | def fix_args_wrt_world_size(args_to_fix): 112 | for param in ['dataset_size', 'batch_size']: 113 | if param in args_to_fix: 114 | args_to_fix[param] = int(args_to_fix[param] / get_world_size()) 115 | 116 | if "type" in config[data_loader_name] and "args" in config[data_loader_name]: 117 | fix_args_wrt_world_size(config[data_loader_name]['args']) 118 | return [config.initialize(data_loader_name, module_data, **kwargs)] 119 | elif isinstance(config[data_loader_name], list): 120 | data_loaders = [] 121 | for idx in range(len(config[data_loader_name])): 122 | fix_args_wrt_world_size(config[data_loader_name][idx]['args']) 123 | data_loaders.append(config.initialize(data_loader_name, module_data, index=idx, **kwargs)) 124 | return data_loaders 125 | else: 126 | raise ValueError("Check data_loader config, not correct format.") 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser(description='PyTorch Template') 131 | parser.add_argument('-c', '--config', default=None, type=str, 132 | help='config file path (default: None)') 133 | parser.add_argument('-r', '--resume', default=None, type=str, 134 | help='path to latest checkpoint (default: None)') 135 | parser.add_argument('-n', '--neptune', action='store_true', 136 | help='Whether to observe (neptune)') 137 | parser.add_argument('--seed', default=None, type=int) 138 | parser.add_argument('--device', default='cuda') 139 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 140 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 141 | parser.add_argument('--distributed', default=0, type=int) 142 | parser.add_argument('--neptune_mode', default='async', type=str) 143 | 144 | # custom cli options to modify configuration from default values given in json file. 
145 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 146 | options = [ 147 | CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'), 148 | CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size'), 149 | ] 150 | config = ConfigParser(parser, options) 151 | args = config.args 152 | 153 | ex.add_config(config.config) 154 | 155 | init_distributed_mode(args=args) 156 | 157 | neptune_run = None 158 | if args.neptune and is_main_process(): 159 | 160 | # for resume 161 | neptune_id_path = config.save_dir.parent / 'neptune_id.txt' 162 | if neptune_id_path.exists(): 163 | with open(neptune_id_path, 'r') as fin: 164 | with_neptune_id = fin.readline().strip() 165 | else: 166 | with_neptune_id = None 167 | 168 | # delete this error if you have added your own neptune credentials neptune.ai 169 | raise ValueError 170 | api_token = '' 171 | project = '' 172 | 173 | neptune_run = neptune.init( 174 | project=project, 175 | api_token=api_token, 176 | source_files=['imp_videocap/**/*.py', '*.py'], 177 | with_id=with_neptune_id, 178 | mode=args.neptune_mode, 179 | ) 180 | # save neptune id to be able to resume logging to the same id 181 | if not neptune_id_path.exists(): 182 | neptune_id = neptune_run["sys/id"].fetch() 183 | with open(neptune_id_path, 'w') as fout: 184 | fout.write(neptune_id) 185 | 186 | logging.getLogger("neptune.new.internal.operation_processors.async_operation_processor").setLevel( 187 | logging.CRITICAL) 188 | ex.observers.append(NeptuneObserver(run=neptune_run)) 189 | 190 | ex.run() 191 | -------------------------------------------------------------------------------- /howtocaption/data_loader/video_datasets/howto.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | 5 | import pandas as pd 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | import ffmpeg 9 | import zlib 10 | import json 11 | 12 | from howtocaption.data_loader.video_datasets.utils import get_video_frames 13 | from howtocaption.data_loader.transforms import init_transform_dict 14 | 15 | 16 | class HowTo100M(Dataset): 17 | # adopted from https://github.com/antoine77340/MIL-NCE_HowTo100M/blob/master/video_loader.py 18 | """HowTo100M Video-Text loader.""" 19 | 20 | def __init__( 21 | self, 22 | csv, 23 | video_root, 24 | caption_path, 25 | split='train', 26 | num_frames=4, 27 | transforms=None, 28 | 29 | return_all_frames_1fps=False, 30 | lm_loss_weight=1, 31 | ita_loss_weight=1, 32 | dataset_name='HowTo100M', 33 | output_for_visualization=False, 34 | meta_info_path=None, # optional path to dict {video_id: {'width': x, 'height': y}, ...} 35 | process_only_part_i=None, 36 | number_of_parts=None, 37 | captions_are_zipped=False, 38 | aggregate_to_min_time=False, 39 | min_time=8.0, 40 | ): 41 | """ 42 | Args: 43 | """ 44 | self.split = split 45 | self.video_root = video_root 46 | self.min_time = min_time 47 | self.num_frames = num_frames 48 | self.transforms = transforms 49 | self.return_all_frames_1fps = return_all_frames_1fps 50 | 51 | self.dataset_name = dataset_name 52 | 53 | self.lm_loss_weight = lm_loss_weight 54 | self.ita_loss_weight = ita_loss_weight 55 | self.aggregate_to_min_time = aggregate_to_min_time 56 | self.captions_are_zipped = captions_are_zipped 57 | 58 | assert isinstance(csv, str) 59 | assert isinstance(caption_path, str) 60 | self.csv = pd.read_csv(csv) 61 | 62 | if process_only_part_i is not None: 63 | assert number_of_parts is 
not None 64 | size = int(np.ceil(len(self.csv) / number_of_parts)) 65 | self.csv = self.csv[process_only_part_i * size: (process_only_part_i + 1) * size] 66 | 67 | with open(caption_path, 'rb') as fin: 68 | self.captions = pickle.load(fin) 69 | 70 | video_ids = set(self.csv['video_id']).intersection(self.captions.keys()) 71 | self.csv = self.csv[self.csv['video_id'].isin(video_ids)] 72 | 73 | self.output_for_visualization = output_for_visualization 74 | self.vis_transform = init_transform_dict()['visualization'] 75 | 76 | if meta_info_path is not None: 77 | with open(meta_info_path, 'rb') as fin: 78 | self.meta_info = pickle.load(fin) 79 | else: 80 | self.meta_info = None 81 | 82 | def __len__(self): 83 | return len(self.csv) 84 | 85 | def _find_nearest_candidates(self, captions, ind): 86 | start, end = ind, ind 87 | diff = captions['end'][end] - captions['start'][start] 88 | # Extend the video clip if shorter than the minimum desired clip duration 89 | offset_from_end = 1 90 | while diff < self.min_time: 91 | if start > 0 and end < len(captions['end']) - 1 - offset_from_end: 92 | d1 = captions['end'][end + 1] - captions['start'][start] 93 | d2 = captions['end'][end] - captions['start'][start - 1] 94 | # Use the closest neighboring video clip 95 | if d2 <= d1: 96 | start -= 1 97 | else: 98 | end += 1 99 | # If no video clips after it, use the clip before it 100 | elif start > 0: 101 | start -= 1 102 | # If no video clips before it, use the clip after it. 103 | elif end < len(captions['end']) - 1 - offset_from_end: 104 | end += 1 105 | # If there's no clips before or after 106 | else: 107 | break 108 | diff = captions['end'][end] - captions['start'][start] 109 | 110 | return start, end 111 | 112 | def _find_future_candidates(self, captions, ind): 113 | start, end = ind, ind 114 | diff = captions['end'][end] - captions['start'][start] 115 | # Extend the video clip if shorter than the minimum desired clip duration 116 | offset_from_end = 1 117 | while diff < self.min_time: 118 | if end < len(captions['end']) - 1 - offset_from_end: 119 | end += 1 120 | else: 121 | break 122 | diff = captions['end'][end] - captions['start'][start] 123 | return start, end 124 | 125 | def _sample_clip_ids(self, captions): 126 | if self.aggregate_to_min_time: 127 | offset_from_end = 1 128 | ind = random.randint(0, len(captions['text']) - 1 - offset_from_end) if len(captions['text']) - offset_from_end > 0 else 0 129 | 130 | start_id, end_id = self._find_nearest_candidates(captions, ind) 131 | return start_id, end_id 132 | else: 133 | ind = random.randint(0, len(captions['text']) - 1) # not the same as np.random.randint 134 | return ind, ind 135 | 136 | def _get_text(self, captions, start_id, end_id): 137 | if start_id == end_id: 138 | texts = captions['text'][start_id] 139 | if isinstance(texts, list): 140 | cur_text = np.random.choice(texts) 141 | else: 142 | cur_text = texts 143 | else: 144 | cur_text = '. 
'.join(captions['text'][start_id:end_id + 1]) 145 | return cur_text 146 | 147 | def __getitem__(self, idx): 148 | video_id = self.csv['video_id'].iloc[idx] 149 | video_path = self.csv['video_path'].iloc[idx] 150 | video_path = os.path.join(self.video_root, video_path) 151 | 152 | captions = self.captions[video_id] 153 | 154 | if self.captions_are_zipped: 155 | captions = zlib.decompress(captions) 156 | captions = json.loads(captions) 157 | 158 | # read width and height from meta to speed up video reading (avoid ffmpeg.probe) 159 | if self.meta_info is not None: 160 | if video_id in self.meta_info: 161 | width = self.meta_info[video_id]['width'] 162 | height = self.meta_info[video_id]['height'] 163 | else: 164 | print("Meta info is not found for id", video_id, flush=True) 165 | width = None 166 | height = None 167 | else: 168 | width = None 169 | height = None 170 | 171 | if self.return_all_frames_1fps: 172 | try: 173 | probe = ffmpeg.probe(video_path) 174 | secs = int(np.floor(float(probe['format']['duration']))) 175 | except Exception as excep: 176 | secs = 1 177 | print("Warning: ffmpeg error. video path: {} error. Error: {}".format(video_path, excep), flush=True) 178 | 179 | video = get_video_frames(video_path, 0, secs, num_frames=max(1, secs), fps=1, width=width, height=height) 180 | 181 | if self.output_for_visualization: 182 | vis_video = self.vis_transform(video) 183 | 184 | if self.transforms is not None: 185 | video = self.transforms(video) 186 | 187 | return {'video': video, 'text': '', 188 | 'start_time': 0, 'end_time': secs, 189 | 'time': secs, 'dataset': self.dataset_name, 'video_id': video_id, 190 | 'vis_video': (vis_video if self.output_for_visualization else 0)} 191 | else: 192 | start_id, end_id = self._sample_clip_ids(captions) 193 | 194 | start_time, end_time = captions['start'][start_id], captions['end'][end_id] 195 | cur_text = self._get_text(captions, start_id, end_id) 196 | video = get_video_frames(video_path, start_time, end_time, self.num_frames, 197 | width=width, height=height, central_frames=True) 198 | 199 | output = {} 200 | if self.output_for_visualization: 201 | vis_video = self.vis_transform(video) 202 | output['vis_video'] = vis_video 203 | output['video_id'] = video_id 204 | output['start_id'] = start_id 205 | 206 | if self.transforms is not None: 207 | video = self.transforms(video) 208 | 209 | output.update({ 210 | 'video': video, 211 | 'start_time': start_time, 'end_time': end_time, 212 | 'time': end_time - start_time, 213 | 'dataset': self.dataset_name, 214 | 'lm_loss_weight': self.lm_loss_weight, 215 | 'ita_loss_weight': self.ita_loss_weight, 216 | }) 217 | 218 | update_output = {'text': cur_text} 219 | output.update(update_output) 220 | return output 221 | 222 | -------------------------------------------------------------------------------- /howtocaption/align_and_filter.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | from collections import defaultdict, Counter 4 | import tqdm 5 | import os 6 | import numpy as np 7 | import argparse 8 | import copy 9 | import json 10 | import zlib 11 | 12 | 13 | def run(args): 14 | THRESHOLD = args.threshold 15 | TOP_QUANTILE_THRESHOLD = args.top_quantile_threshold 16 | TOP_PAIRS_THRESHOLD = args.top_pairs_threshold 17 | OFFSET_SECS = args.offset_secs 18 | SECS = args.secs 19 | 20 | print("Loading text embeddings ...") 21 | PRELOADED_VIDEO_FEATURES = [] 22 | for frame_embeddings in args.frame_embeddings: 23 | with open(frame_embeddings, 
'rb') as fin: 24 | PRELOADED_VIDEO_FEATURES.append(pickle.load(fin)) 25 | 26 | print("Loading text embeddings ...") 27 | PRELOADED_TEXT_FEATURES = [] 28 | for text_embeddings in args.text_embeddings: 29 | with open(text_embeddings, 'rb') as fin: 30 | PRELOADED_TEXT_FEATURES.append(pickle.load(fin)) 31 | assert len(PRELOADED_VIDEO_FEATURES) == len(PRELOADED_TEXT_FEATURES) 32 | 33 | def find_alignment(video_ids, threshold): 34 | output = [] 35 | found = 0 36 | selected_how_to_indices = [] 37 | 38 | for video_id in tqdm.tqdm(video_ids): 39 | try: 40 | if any((video_id not in features) for features in PRELOADED_VIDEO_FEATURES): 41 | print(f"Video features {video_id} are missing", flush=True) 42 | continue 43 | 44 | N_features = len(PRELOADED_VIDEO_FEATURES) 45 | RIGHT_PAD = int(np.ceil(SECS / 2.0)) 46 | 47 | cur_video_embeds_list = [] 48 | cur_text_embeds_list = [] 49 | for i in range(N_features): 50 | cur_video_embeds = PRELOADED_VIDEO_FEATURES[i][video_id]['frames'] 51 | cur_text_embeds = PRELOADED_TEXT_FEATURES[i][video_id]['features'] 52 | if isinstance(cur_text_embeds, np.ndarray): 53 | cur_text_embeds = torch.from_numpy(cur_text_embeds.astype(np.float32)) 54 | cur_video_embeds = torch.from_numpy(cur_video_embeds.astype(np.float32)) 55 | cur_text_embeds = cur_text_embeds.cuda() 56 | cur_video_embeds = cur_video_embeds.cuda() 57 | 58 | # get clip embeddings (average over frame embeddings) 59 | cur_video_embeds = cur_video_embeds[None] 60 | cur_video_embeds = torch.nn.functional.avg_pool2d(torch.nn.functional.pad(cur_video_embeds, (0, 0, RIGHT_PAD, SECS - RIGHT_PAD), mode='reflect'), (SECS, 1), stride=1)[0] 61 | cur_video_embeds_list.append(cur_video_embeds) 62 | cur_text_embeds_list.append(cur_text_embeds) 63 | 64 | cur_texts = PRELOADED_TEXT_FEATURES[0][video_id]['text'] 65 | starts = PRELOADED_TEXT_FEATURES[0][video_id]['start'] 66 | ends = PRELOADED_TEXT_FEATURES[0][video_id]['end'] 67 | num_sents = PRELOADED_TEXT_FEATURES[0][video_id]['num_sents'] 68 | cur_offset = 0 69 | for cur_i in range(len(starts)): 70 | start = starts[cur_i] 71 | end = ends[cur_i] 72 | 73 | # find center of the segment 74 | start = int(np.floor((start + end) / 2)) 75 | end = start + 1 76 | 77 | # apply offset 78 | start = max(0, int(np.floor(start)) - OFFSET_SECS) 79 | end = int(np.ceil(end)) + OFFSET_SECS 80 | is_empty = False 81 | sims = 0 82 | for i in range(N_features): 83 | cur_video_embeds = cur_video_embeds_list[i] 84 | cur_text_embeds = cur_text_embeds_list[i] 85 | video_embeds = cur_video_embeds[start:end] 86 | if len(video_embeds) == 0: 87 | is_empty = True 88 | break 89 | text_embeds = cur_text_embeds[cur_offset:cur_offset + num_sents[cur_i]] 90 | 91 | sims = text_embeds @ video_embeds.t() + sims 92 | if is_empty: 93 | continue 94 | 95 | texts = cur_texts[cur_offset:cur_offset + num_sents[cur_i]] 96 | cur_offset += num_sents[cur_i] 97 | values, cur_indices = torch.topk(sims, k=1, dim=1) 98 | values, cur_indices = values[:, -1], cur_indices[:, -1] 99 | if threshold is not None: 100 | cur_indices[values < threshold] = -1 101 | cur_indices = cur_indices.cpu().tolist() 102 | 103 | for ind, text, similarity in zip(cur_indices, texts, values): 104 | if ind != -1: 105 | start_time = max(0, start + ind - RIGHT_PAD) 106 | end_time = start + ind + (SECS - RIGHT_PAD) 107 | output.append((video_id, start_time, end_time, text, similarity)) 108 | found += 1 109 | selected_how_to_indices.append(f'{video_id}_{start_time}') 110 | except Exception as excep: 111 | print("Error: {}".format(excep), flush=True) 112 | 113 | 
return found, output, selected_how_to_indices 114 | 115 | all_video_ids = list(PRELOADED_TEXT_FEATURES[0].keys()) 116 | print(f"Number of video_ids in text dict {len(all_video_ids)}") 117 | print(f"Number of video_ids in video dict: {len( list(PRELOADED_VIDEO_FEATURES[0].keys()))}") 118 | 119 | print("Creating alignment ...", flush=True) 120 | 121 | if TOP_PAIRS_THRESHOLD is not None: 122 | print("top_pairs_threshold is defined --> estimating top_quantile_threshold ...", flush=True) 123 | video_ids = copy.deepcopy(all_video_ids) 124 | np.random.shuffle(video_ids) 125 | n_clips = sum(len(PRELOADED_TEXT_FEATURES[0][video_id]['text']) for video_id in video_ids) 126 | if args.number_of_parts is not None: 127 | TOP_QUANTILE_THRESHOLD = (TOP_PAIRS_THRESHOLD / args.number_of_parts) / n_clips 128 | else: 129 | TOP_QUANTILE_THRESHOLD = TOP_PAIRS_THRESHOLD / n_clips 130 | print(f"Estimated top_quantile_threshold is {TOP_QUANTILE_THRESHOLD}", flush=True) 131 | 132 | if TOP_QUANTILE_THRESHOLD is not None: 133 | print("top_quantile_threshold is defined --> estimating threshold ...", flush=True) 134 | video_ids = copy.deepcopy(all_video_ids) 135 | np.random.shuffle(video_ids) 136 | n_random_videos = min(5000, len(video_ids)) 137 | print(f"Estimate threshold on {n_random_videos} random videos", flush=True) 138 | found, output, selected_how_to_indices = find_alignment(video_ids[:n_random_videos], None) 139 | similarities = [similarity.cpu().item() for video_id, start_time, end_time, text, similarity in output] 140 | THRESHOLD = np.quantile(similarities, (1 - TOP_QUANTILE_THRESHOLD)) 141 | print(f"Estimated threshold is {THRESHOLD} ", flush=True) 142 | 143 | print(f"Alignment and filtering using threshold {THRESHOLD}", flush=True) 144 | found, output, selected_how_to_indices = find_alignment(all_video_ids, THRESHOLD) 145 | print(f"Alignment and filtering is finished!", flush=True) 146 | 147 | print(f"Only {found} text-video clip pairs are left") 148 | print("Number of unique video clips: ", len(Counter(selected_how_to_indices))) 149 | 150 | ######## ----- SAVING -------- 151 | 152 | new_data = defaultdict(lambda: {'start': [], 'end': [], 'text': []}) 153 | 154 | for video_id, start_time, end_time, text, similatiry in tqdm.tqdm(output): 155 | new_data[video_id]["start"].append(start_time) 156 | new_data[video_id]["end"].append(end_time) 157 | if args.with_scores: 158 | new_data[video_id]["text"].append([(text, similatiry.item())]) 159 | else: 160 | new_data[video_id]["text"].append([text]) 161 | 162 | print(f"Saving alignments to {args.output}") 163 | os.makedirs(os.path.dirname(args.output), exist_ok=True) 164 | with open(args.output, 'wb') as fout: 165 | pickle.dump(dict(new_data), fout) 166 | 167 | name, ext = os.path.splitext(args.output) 168 | zipped_path = f'{name}_zipped{ext}' 169 | print(f"Zipping alignments to {zipped_path}") 170 | zipped_data = {} 171 | for key, val in new_data.items(): 172 | val = json.dumps(val) 173 | val = zlib.compress(val.encode(), level=9) 174 | zipped_data[key] = val 175 | 176 | with open(zipped_path, 'wb') as fout: 177 | pickle.dump(zipped_data, fout) 178 | 179 | 180 | if __name__ == '__main__': 181 | parser = argparse.ArgumentParser(description='PyTorch Template') 182 | parser.add_argument('--frame_embeddings', type=str, nargs='+') 183 | parser.add_argument('--text_embeddings', type=str, nargs='+') 184 | parser.add_argument('--output', type=str) 185 | 186 | parser.add_argument('--threshold', default=None, type=float) 187 | parser.add_argument('--top_quantile_threshold', 
default=None, type=float) 188 | parser.add_argument('--top_pairs_threshold', default=None, type=int) 189 | 190 | parser.add_argument('--offset_secs', default=10, type=int) 191 | parser.add_argument('--secs', default=8, type=int) 192 | 193 | parser.add_argument('--process_only_part_i', default=None, type=str) 194 | parser.add_argument('--number_of_parts', default=None, type=int) 195 | parser.add_argument('--with_scores', default=0, type=int) 196 | 197 | args = parser.parse_args() 198 | assert sum([int(args.threshold is not None), 199 | int(args.top_quantile_threshold is not None), 200 | int(args.top_pairs_threshold is not None)]) <= 1 201 | run(args) -------------------------------------------------------------------------------- /howtocaption/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from collections import defaultdict, deque 4 | import datetime 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | 10 | # adopted from https://github.com/salesforce/BLIP/blob/main/utils.py 11 | def setup_for_distributed(is_master): 12 | """ 13 | This function disables printing when not in master process 14 | """ 15 | import builtins as __builtin__ 16 | builtin_print = __builtin__.print 17 | 18 | def print(*args, **kwargs): 19 | force = kwargs.pop('force', False) 20 | if is_master or force: 21 | builtin_print(*args, **kwargs) 22 | 23 | __builtin__.print = print 24 | 25 | 26 | def init_distributed_mode(args): 27 | if args.distributed: 28 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 29 | print('Using distributed mode (RANK)') 30 | args.rank = int(os.environ["RANK"]) 31 | args.world_size = int(os.environ['WORLD_SIZE']) 32 | args.gpu = int(os.environ['LOCAL_RANK']) 33 | elif 'SLURM_PROCID' in os.environ: 34 | print('Using distributed mode (SLURM_PROCID)') 35 | args.rank = int(os.environ['SLURM_PROCID']) 36 | args.gpu = args.rank % torch.cuda.device_count() 37 | else: 38 | print('Not using distributed mode') 39 | args.distributed = False 40 | return 41 | 42 | # torch.cuda.set_device(args.gpu) 43 | args.dist_backend = 'nccl' 44 | print('Distributed init (rank {}, world {}, gpu {}), url:{}'.format( 45 | args.rank, args.world_size, args.gpu, args.dist_url), flush=True) 46 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 47 | world_size=args.world_size, rank=args.rank) 48 | torch.distributed.barrier() 49 | setup_for_distributed(args.rank == 0) 50 | else: 51 | print('Not using distributed mode') 52 | 53 | 54 | def is_dist_avail_and_initialized(): 55 | if not dist.is_available(): 56 | return False 57 | if not dist.is_initialized(): 58 | return False 59 | return True 60 | 61 | 62 | def get_world_size(): 63 | if not is_dist_avail_and_initialized(): 64 | return 1 65 | return dist.get_world_size() 66 | 67 | 68 | def get_rank(): 69 | if not is_dist_avail_and_initialized(): 70 | return 0 71 | return dist.get_rank() 72 | 73 | 74 | def is_main_process(): 75 | return get_rank() == 0 76 | 77 | 78 | def save_on_master(*args, **kwargs): 79 | if is_main_process(): 80 | torch.save(*args, **kwargs) 81 | 82 | 83 | class SmoothedValue(object): 84 | """Track a series of values and provide access to smoothed values over a 85 | window or the global series average. 
86 | """ 87 | 88 | def __init__(self, window_size=20, fmt=None): 89 | if fmt is None: 90 | fmt = "{median:.4f} ({global_avg:.4f})" 91 | self.deque = deque(maxlen=window_size) 92 | self.total = 0.0 93 | self.count = 0 94 | self.fmt = fmt 95 | 96 | def update(self, value, n=1): 97 | self.deque.append(value) 98 | self.count += n 99 | self.total += value * n 100 | 101 | def synchronize_between_processes(self): 102 | """ 103 | Warning: does not synchronize the deque! 104 | """ 105 | if not is_dist_avail_and_initialized(): 106 | return 107 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 108 | dist.barrier() 109 | dist.all_reduce(t) 110 | t = t.tolist() 111 | self.count = int(t[0]) 112 | self.total = t[1] 113 | 114 | @property 115 | def median(self): 116 | d = torch.tensor(list(self.deque)) 117 | return d.median().item() 118 | 119 | @property 120 | def avg(self): 121 | d = torch.tensor(list(self.deque), dtype=torch.float32) 122 | return d.mean().item() 123 | 124 | @property 125 | def global_avg(self): 126 | return self.total / self.count 127 | 128 | @property 129 | def max(self): 130 | return max(self.deque) 131 | 132 | @property 133 | def value(self): 134 | return self.deque[-1] 135 | 136 | def __str__(self): 137 | return self.fmt.format( 138 | median=self.median, 139 | avg=self.avg, 140 | global_avg=self.global_avg, 141 | max=self.max, 142 | value=self.value) 143 | 144 | 145 | class MetricLogger(object): 146 | def __init__(self, delimiter="\t", neptune_run=None): 147 | self.meters = defaultdict(SmoothedValue) 148 | self.delimiter = delimiter 149 | self.neptune_run = neptune_run 150 | 151 | def update(self, k, v, step=None, log=True): 152 | if isinstance(v, torch.Tensor): 153 | v = v.item() 154 | assert isinstance(v, (float, int)) 155 | self.meters[k].update(v) 156 | if log and self.neptune_run is not None: 157 | self.neptune_run[k].log(v, step=step) 158 | 159 | def __getitem__(self, name): 160 | if name in self.meters: 161 | return self.meters[name] 162 | if name in self.__dict__: 163 | return self.__dict__[name] 164 | raise AttributeError("'{}' object has no attribute '{}'".format( 165 | type(self).__name__, name)) 166 | 167 | def __str__(self): 168 | loss_str = [] 169 | for name, meter in self.meters.items(): 170 | loss_str.append( 171 | "{}: {}".format(name, str(meter)) 172 | ) 173 | return self.delimiter.join(loss_str) 174 | 175 | def get_global_avg(self): 176 | return {k: meter.global_avg for k, meter in self.meters.items()} 177 | 178 | def log_global_avg(self, epoch=None): 179 | if self.neptune_run is not None: 180 | for key, value in self.get_global_avg().items(): 181 | self.neptune_run[key + '_avg_by_epoch'].log(value, step=epoch) 182 | 183 | def str_global_avg(self): 184 | loss_str = [] 185 | for name, meter in self.get_global_avg().items(): 186 | loss_str.append( 187 | "{}: {:.4f}".format(name, meter.global_avg) 188 | ) 189 | return self.delimiter.join(loss_str) 190 | 191 | def synchronize_between_processes(self): 192 | for meter in self.meters.values(): 193 | meter.synchronize_between_processes() 194 | 195 | def add_meter(self, name, meter): 196 | self.meters[name] = meter 197 | 198 | def log_every(self, iterable, print_freq, header=None): 199 | i = 0 200 | if not header: 201 | header = '' 202 | start_time = time.time() 203 | end = time.time() 204 | iter_time = SmoothedValue(fmt='{avg:.4f}') 205 | data_time = SmoothedValue(fmt='{avg:.4f}') 206 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 207 | log_msg = [ 208 | header, 209 | '[{0' + 
space_fmt + '}/{1}]', 210 | 'eta: {eta}', 211 | '{meters}', 212 | 'time: {time}', 213 | 'data: {data}' 214 | ] 215 | if torch.cuda.is_available(): 216 | log_msg.append('max mem: {memory:.0f}') 217 | log_msg = self.delimiter.join(log_msg) 218 | MB = 1024.0 * 1024.0 219 | for obj in iterable: 220 | data_time.update(time.time() - end) 221 | yield obj 222 | iter_time.update(time.time() - end) 223 | if i % print_freq == 0 or i == len(iterable) - 1: 224 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 225 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 226 | if torch.cuda.is_available(): 227 | print(log_msg.format( 228 | i, len(iterable), eta=eta_string, 229 | meters=str(self), 230 | time=str(iter_time), data=str(data_time), 231 | memory=torch.cuda.max_memory_allocated() / MB)) 232 | else: 233 | print(log_msg.format( 234 | i, len(iterable), eta=eta_string, 235 | meters=str(self), 236 | time=str(iter_time), data=str(data_time))) 237 | i += 1 238 | end = time.time() 239 | total_time = time.time() - start_time 240 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 241 | print('{} Total time: {} ({:.4f} s / it)'.format( 242 | header, total_time_str, total_time / len(iterable))) 243 | 244 | 245 | @torch.no_grad() 246 | def concat_all_gather(tensor): 247 | """ 248 | Performs all_gather operation on the provided tensors. 249 | *** Warning ***: torch.distributed.all_gather has no gradient. 250 | """ 251 | tensors_gather = [torch.ones_like(tensor) 252 | for _ in range(torch.distributed.get_world_size())] 253 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 254 | 255 | output = torch.cat(tensors_gather, dim=0) 256 | return output 257 | 258 | def evenly_divisible_concat_all_gather(tensor): 259 | """ 260 | Utility function for distributed data parallel to pad tensor to make it evenly divisible for all_gather. 261 | Args: 262 | data: source tensor to pad and execute all_gather in distributed data parallel. 263 | 264 | """ 265 | # if torch.distributed.get_world_size() <= 1: 266 | # return data 267 | # make sure the data is evenly-divisible on multi-GPUs 268 | length = torch.tensor([tensor.shape[0], 0], dtype=torch.int, device=tensor.device) 269 | tensors_gather_length = [torch.zeros_like(length) 270 | for _ in range(torch.distributed.get_world_size())] 271 | torch.distributed.all_gather(tensors_gather_length, length) 272 | all_len = [itm[0].item() for itm in tensors_gather_length] 273 | max_len = max(all_len) 274 | tensor_len = tensor.shape[0] 275 | if tensor_len < max_len: 276 | size = [max_len - tensor_len] + list(tensor.shape[1:]) 277 | tensor = torch.cat([tensor, tensor.new_full(size, float("NaN"))], dim=0) 278 | # all gather across all processes 279 | tensors_gather = [torch.ones_like(tensor) 280 | for _ in range(torch.distributed.get_world_size())] 281 | torch.distributed.all_gather(tensors_gather, tensor) 282 | # delete the padding NaN items 283 | return torch.cat([tensors_gather[i][:l, ...] 
for i, l in enumerate(all_len)], dim=0) -------------------------------------------------------------------------------- /howtocaption/base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from numpy import inf 4 | import os 5 | import sys 6 | from collections import OrderedDict 7 | 8 | from howtocaption.utils.dist_utils import is_main_process 9 | 10 | 11 | class BaseTrainer: 12 | """ 13 | Base class for all trainers 14 | """ 15 | def __init__(self, model, model_without_ddp, loss, metrics, optimizer, device, config, 16 | epochs, save_period, monitor='off', init_val=False, early_stop=inf, neptune_run=None, 17 | resume_only_model=False, resume_only_model_and_opt=False, 18 | freq_eval=1, nlp_freq_eval=10, retrieval_freq_eval=10000000, 19 | freq_visual_input=50, log_visual_input_at_start=True, 20 | init_nlp=False, init_retrieval=False, init_val_loss=False, save_epochs=None, remove_resume=False, load_strict=True): 21 | 22 | self.model = model 23 | self.model_without_ddp = model_without_ddp 24 | self.loss = loss 25 | self.metrics = metrics 26 | self.optimizer = optimizer 27 | self.device = device 28 | 29 | self.config = config 30 | 31 | self.epochs = epochs 32 | self.save_period = save_period 33 | self.save_epochs = save_epochs 34 | if self.save_epochs is not None: 35 | assert isinstance(self.save_epochs, list) 36 | self.monitor = monitor 37 | self.init_val = init_val 38 | self.init_nlp = init_nlp 39 | self.init_retrieval = init_retrieval 40 | self.init_val_loss = init_val_loss 41 | 42 | self.resume_only_model = resume_only_model 43 | self.resume_only_model_and_opt = resume_only_model_and_opt 44 | self.remove_resume = remove_resume 45 | self.load_strict = load_strict 46 | 47 | # configuration to monitor model performance and save best 48 | if self.monitor == 'off': 49 | self.mnt_mode = 'off' 50 | self.mnt_best = 0 51 | else: 52 | self.mnt_mode, self.mnt_metric = self.monitor.split() 53 | assert self.mnt_mode in ['min', 'max'] 54 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf 55 | self.early_stop = early_stop 56 | if self.early_stop <= 0: 57 | self.early_stop = inf 58 | 59 | self.start_epoch = 1 60 | self.step = 0 61 | 62 | self.checkpoint_dir = config.save_dir 63 | 64 | # # setup visualization writer instance 65 | # self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard']) 66 | self.neptune_run = neptune_run 67 | 68 | # how often to evaluate 69 | self.freq_eval = freq_eval 70 | 71 | # how often to evaluate all these caption metrics 72 | self.nlp_freq_eval = nlp_freq_eval 73 | 74 | # how often to retrieval 75 | self.retrieval_freq_eval = retrieval_freq_eval 76 | 77 | # how often to log visual input 78 | self.freq_visual_input = freq_visual_input 79 | self.log_visual_input_at_start = log_visual_input_at_start 80 | 81 | if config.resume is not None: 82 | self._resume_checkpoint(config.resume) 83 | 84 | @abstractmethod 85 | def _train_epoch(self, epoch): 86 | """ 87 | Training logic for an epoch 88 | 89 | :param epoch: Current epoch number 90 | """ 91 | raise NotImplementedError 92 | 93 | @abstractmethod 94 | def _valid_epoch(self, epoch): 95 | """ 96 | Training logic for an epoch 97 | 98 | :param epoch: Current epoch number 99 | """ 100 | raise NotImplementedError 101 | 102 | @abstractmethod 103 | def _eval_nlp(self, epoch): 104 | """ 105 | Training logic for an epoch 106 | 107 | :param epoch: Current epoch number 108 | """ 109 | raise 
NotImplementedError 110 | 111 | 112 | @abstractmethod 113 | def _eval_retrieval(self, epoch): 114 | """ 115 | Training logic for an epoch 116 | 117 | :param epoch: Current epoch number 118 | """ 119 | raise NotImplementedError 120 | 121 | def train(self): 122 | """ 123 | Full training logic 124 | """ 125 | not_improved_count = 0 126 | 127 | # self._eval_retrieval(self.start_epoch - 1) 128 | 129 | 130 | if self.init_val: 131 | log = self._valid_epoch(self.start_epoch - 1) 132 | # print logged informations to the screen 133 | for key, value in log.items(): 134 | print(' {:15s}: {}'.format(str(key), value)) 135 | 136 | if self.init_retrieval: 137 | self._eval_retrieval(self.start_epoch - 1) 138 | 139 | if self.init_nlp: 140 | self._eval_nlp(self.start_epoch - 1) 141 | 142 | for epoch in range(self.start_epoch, self.epochs + 1): 143 | # with torch.autograd.set_detect_anomaly(True): 144 | result = self._train_epoch(epoch) 145 | 146 | # save logged informations into log dict 147 | log = {'epoch': epoch} 148 | for key, value in result.items(): 149 | if key == 'metrics': 150 | log.update({mtr.__name__: value[i] 151 | for i, mtr in enumerate(self.metrics)}) 152 | elif key == 'val_metrics': 153 | log.update({'val_' + mtr.__name__: value[i] 154 | for i, mtr in enumerate(self.metrics)}) 155 | elif key == 'nested_val_metrics': 156 | # NOTE: currently only supports two layers of nesting 157 | for subkey, subval in value.items(): 158 | for subsubkey, subsubval in subval.items(): 159 | for subsubsubkey, subsubsubval in subsubval.items(): 160 | log[f"val_{subkey}_{subsubkey}_{subsubsubkey}"] = subsubsubval 161 | else: 162 | log[key] = value 163 | 164 | # print logged informations to the screen 165 | for key, value in log.items(): 166 | print(' {:15s}: {}'.format(str(key), value)) 167 | 168 | # evaluate model performance according to configured metric, save best checkpoint as model_best 169 | best = False 170 | if self.mnt_mode != 'off': 171 | try: 172 | # check whether model performance improved or not, according to specified metric(mnt_metric) 173 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ 174 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) 175 | except KeyError: 176 | print("Warning: Metric '{}' is not found. " 177 | "Model performance monitoring is disabled.".format(self.mnt_metric)) 178 | self.mnt_mode = 'off' 179 | improved = False 180 | 181 | if improved: 182 | self.mnt_best = log[self.mnt_metric] 183 | not_improved_count = 0 184 | best = True 185 | else: 186 | not_improved_count += 1 187 | 188 | if not_improved_count > self.early_stop: 189 | print("Validation performance didn\'t improve for {} epochs. 
" 190 | "Training stops.".format(self.early_stop)) 191 | break 192 | 193 | if best: 194 | self._save_checkpoint(epoch, save_best=best) 195 | 196 | if epoch % self.save_period == 0 : 197 | self._save_checkpoint(epoch) 198 | 199 | if (self.save_epochs is not None) and (epoch in self.save_epochs): 200 | self._save_checkpoint(epoch) 201 | 202 | def _save_checkpoint(self, epoch, save_best=False, save_latest=False): 203 | """ 204 | Saving checkpoints 205 | 206 | :param epoch: current epoch number 207 | :param log: logging information of the epoch 208 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 209 | """ 210 | 211 | if not is_main_process(): # in case of distributed training 212 | return 213 | 214 | arch = type(self.model).__name__ 215 | state = { 216 | 'arch': arch, 217 | 'epoch': epoch, 218 | 'step': self.step, 219 | 'state_dict': self.model.state_dict(), 220 | 'optimizer': self.optimizer.state_dict(), 221 | 'loss': self.loss.state_dict() if self.loss is not None else None, 222 | 'monitor_best': self.mnt_best, 223 | 'config': self.config.config 224 | } 225 | 226 | if save_latest: 227 | # for safety 228 | tmp_best_path = str(self.checkpoint_dir / 'tmp.pth') 229 | torch.save(state, tmp_best_path, _use_new_zipfile_serialization=False) 230 | best_path = str(self.checkpoint_dir / 'latest_model.pth') 231 | os.rename(tmp_best_path, best_path) 232 | print("Saving current model: latest_model.pth ...") # safe in terms of "No space left on device" 233 | 234 | if save_best: 235 | tmp_best_path = str(self.checkpoint_dir / 'tmp.pth') # safe in terms of "No space left on device" 236 | torch.save(state, tmp_best_path, _use_new_zipfile_serialization=False) 237 | best_path = str(self.checkpoint_dir / 'model_best.pth') 238 | os.rename(tmp_best_path, best_path) 239 | print("Saving current best: model_best.pth ...") 240 | 241 | if not(save_best or save_latest): 242 | filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) 243 | torch.save(state, filename, _use_new_zipfile_serialization=False) 244 | print("Saving checkpoint: {} ...".format(filename)) 245 | 246 | def _resume_checkpoint(self, resume_path): 247 | """ 248 | Resume from saved checkpoints 249 | 250 | :param resume_path: Checkpoint path to be resumed 251 | """ 252 | resume_path = str(resume_path) 253 | print("Loading checkpoint: {} ...".format(resume_path)) 254 | checkpoint = torch.load(resume_path, map_location='cpu') 255 | print("Loading from epoch: {} ...".format(checkpoint['epoch']), flush=True) 256 | 257 | if not (self.resume_only_model or self.resume_only_model_and_opt): 258 | self.start_epoch = checkpoint['epoch'] + 1 259 | self.step = checkpoint['step'] + 1 260 | self.mnt_best = checkpoint['monitor_best'] 261 | 262 | # load architecture params from checkpoint. 263 | state_dict = checkpoint['state_dict'] 264 | state_dict = fix_module_in_state_dict(state_dict, self.model) 265 | self.model.load_state_dict(state_dict, strict=self.load_strict) 266 | 267 | if not self.resume_only_model: 268 | # load optimizer state from checkpoint only when optimizer type is not changed. 269 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: 270 | print("Warning: Optimizer type given in config file is different from that of checkpoint. 
" 271 | "Optimizer parameters not being resumed.") 272 | else: 273 | self.optimizer.load_state_dict(checkpoint['optimizer']) 274 | 275 | if self.remove_resume: 276 | assert self.resume_only_model == False 277 | assert self.resume_only_model_and_opt == False 278 | 279 | os.unlink(resume_path) 280 | self._save_checkpoint(checkpoint['epoch'], save_latest=True) 281 | 282 | print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) 283 | 284 | 285 | def fix_module_in_state_dict(state_dict, model): 286 | load_state_dict_keys = list(state_dict.keys()) 287 | curr_state_dict_keys = list(model.state_dict().keys()) 288 | redo_dp = False 289 | if not curr_state_dict_keys[0].startswith('module.') and load_state_dict_keys[0].startswith('module.'): 290 | undo_dp = True 291 | elif curr_state_dict_keys[0].startswith('module.') and not load_state_dict_keys[0].startswith('module.'): 292 | redo_dp = True 293 | undo_dp = False 294 | else: 295 | undo_dp = False 296 | 297 | if undo_dp: 298 | new_state_dict = OrderedDict() 299 | for k, v in state_dict.items(): 300 | name = k[7:] # remove `module.` 301 | new_state_dict[name] = v 302 | # load params 303 | elif redo_dp: 304 | new_state_dict = OrderedDict() 305 | for k, v in state_dict.items(): 306 | name = 'module.' + k # remove `module.` 307 | new_state_dict[name] = v 308 | else: 309 | new_state_dict = state_dict 310 | return new_state_dict 311 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HowToCaption: Prompting LLMs to Transform Video Annotations at Scale 2 | 3 |

4 | 5 |

6 | 7 | Official PyTorch implementation of the paper ["HowToCaption: Prompting LLMs to Transform Video Annotations at Scale"](https://arxiv.org/abs/2310.04900), ECCV 2024. 8 | 9 | by [Nina Shvetsova*](https://ninatu.github.io/), 10 | [Anna Kukleva*](https://annusha.github.io/), 11 | [Xudong Hong](https://xudonghong.me/), 12 | [Christian Rupprecht](https://chrirupp.github.io/), 13 | [Bernt Schiele](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/people/bernt-schiele), 14 | [Hilde Kuehne](https://hildekuehne.github.io/). 15 | 16 | [**[arXiv]**](https://arxiv.org/abs/2310.04900) 17 | 18 | 19 | ## HowToCaption Dataset 20 | 21 | We release the **HowToCaption dataset**. Check the [**dataset readme**](dataset/) to download it. 22 | 23 | The HowToCaption dataset comprises 24 | 1.2M long-term instructional videos from [the HowTo100M dataset](https://www.di.ens.fr/willow/research/howto100m/), 25 | where ASR subtitles have been transformed into proper captions 26 | via our HowToCaption method using [the Vicuna-13B LLM](https://lmsys.org/blog/2023-03-30-vicuna/) ([v0](https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md)). 27 | The captions are automatically generated, 28 | and their high-quality alignment to the videos is further 29 | ensured through subsequent alignment and filtering post-processing, 30 | all achieved without any human involvement. As a result, the HowToCaption dataset contains 25M aligned video-text pairs. 31 | 32 | ## Pretrained Models 33 | 34 | Using the proposed HowToCaption dataset, we pretrained video-language models (initialized from the image-text [BLIP model](https://arxiv.org/abs/2201.12086)). 35 | All checkpoints are available [here](https://drive.google.com/drive/folders/1tZICCBaRW_wwBIfjNfg7qmfpD1LopQfF?usp=sharing). 
36 | 37 | | Method | Model size | Dataset | YouCook2 | | MSRVTT | | 38 | |---------|------------|------------|-----|-----|-----|-----| 39 | | | | | R1 | R10 | R1 | R10 | 40 | | [Dual Encoder Model](https://drive.google.com/file/d/1gZhQ8ZdiD-j10NXAxVn-CGluxwLm3Ecj/view?usp=sharing) | ViT-B | HowTo100M | 12.2 | 39.3 | 30.8 | 61.7 | 41 | | [Dual Encoder Model](https://drive.google.com/file/d/1vsyTKKPmaxDYUo9c9zdFGTxsT-0-tNbC/view?usp=sharing) | ViT-B | WebVid2M | 7.3 | 29.0 | 38.5 | 71.9 | 42 | | [Dual Encoder Model](https://drive.google.com/file/d/1Lr3FI9T8Yt7hJo4MzW8DB1MyfIwZHM98/view?usp=sharing) | ViT-B | **HowToCaption** | 13.4 | 44.1 | 37.6 | 73.3 | 43 | | [Full Model (with re-ranking)](https://drive.google.com/file/d/11EEoVPpz-iqVQMXariE2xQZstYDYo2vS/view?usp=sharing) | ViT-B | **HowToCaption** | 18.15 | 50.4 | 44.3 | 76.6 | 44 | | [Full Model (with re-ranking)](https://drive.google.com/file/d/1D_IN0dUbagLvGA4iw81uG101FwLsnILk/view?usp=sharing) | ViT-L | **HowToCaption** | 19.9 | 53.2 | 45.2 | 77.8 | 45 | 46 | Full Model (ViT-B) fine-tuned for video captioning: 47 | 48 | | Dataset | BLEU@4 | METEOR | ROUGE | CIDEr| 49 | |---------|------------|------------|-----|-----| 50 | | [YouCook2](https://drive.google.com/file/d/1Hl5nQx0lvuPwtOxmhIU1V8dtYMLypyGc/view?usp=sharing) | 8.8 | 15.9 | 37.3 | 116.4 | 51 | | [MSRVTT](https://drive.google.com/file/d/1KzRDBee-JspNYd4N1a-0ZcxmW0dSFHyW/view?usp=sharing) | 49.8 | 32.2 | 66.3 | 65.3 | 52 | | [MSVD](https://drive.google.com/file/d/1D_3BjPxUeHuHRCQXqk5YEOL97Q4TA_Pp/view?usp=sharing) | 70.4 | 46.4 | 83.2 | 154.2 | 53 | 54 | We also release weights for the fine-tuned [VAST](https://github.com/TXH-mercury/VAST) ViT-L model: [weights](https://drive.google.com/file/d/1biT_wj8SMsPB6i9h59StP3X3tz5l2L9L/view?usp=sharing). 55 | 56 | ## Get Started 57 | 58 | ### Set Up an Environment 59 | 60 | ```shell 61 | conda create python=3.8 -y -n howtocaption 62 | conda activate howtocaption 63 | conda install -y pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge 64 | pip install -r requirements.txt 65 | pip install -e . 66 | ``` 67 | 68 | ### Data Preparation 69 | 70 | #### HowTo100M Dataset 71 | 72 | Preprocess data into the `data/howto100m` folder: 73 | 1. Link the folder to the videos of [HowTo100M](https://www.di.ens.fr/willow/research/howto100m/) in `data/howto100m/videos` 74 | 2. Create a CSV file `data/howto100m/video_path_downloaded.csv` with video_id and video_path correspondences (path should be relative to the folder `data/howto100m/videos`). For example: 75 | 76 | ``` 77 | video_id, video_path 78 | RoupYOneCIo,food_and_entertaining/video_R/RoupYOneCIo.webm 79 | Zx3_yGY_ERs, food_and_entertaining/video_Z/Zx3_yGY_ERs.mp4 80 | ``` 81 | 82 | #### HowToCaption Dataset 83 | 84 | Follow the [**dataset readme**](dataset/) to download the HowToCaption dataset and store it in `data/howtocaption`. 85 | 86 | #### MSRVTT, YouCook2, MSVD, LSMDC 87 | 88 | Follow [CLIP4CLIP guidelines](https://github.com/ArrowLuo/CLIP4Clip) to download the MSRVTT, MSVD, and LSMDC datasets. 89 | Follow the [YouCook2 guidelines](http://youcook2.eecs.umich.edu/) to download YouCook2. Put datasets in corresponding folders: 90 | `data/msrvtt`, `data/msvd`, `data/lsmdc`, `data/youcook2`. 91 | 92 | ## Video-Language Models 93 | 94 | ### Configs 95 | 96 | This repository uses YAML files to keep all hyperparameters. The `configs` folder contains configs for LLM prompting, vision-language model 97 | training, and evaluation. 
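A quick way to see which hyperparameters a given config controls is to load it with PyYAML (pinned in `requirements.txt`) and print its top-level sections. The snippet below is an illustrative sketch only, not part of the repository; the exact keys depend entirely on the chosen file.

```python
# Illustrative sketch (not part of the repository): list the top-level sections
# of a training config. Run it from the repository root.
import yaml

with open('configs/VL_training/dual_encoder_retrieval.yaml') as f:
    config = yaml.safe_load(f)

for section, params in config.items():
    # For nested sections, show only the parameter names to keep the output short.
    print(section, '->', sorted(params) if isinstance(params, dict) else params)
```

Training and evaluation entry points take such a file via the `-c` flag, as shown in the commands below.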
98 | 99 | ### Experiment Logging 100 | 101 | This repository uses Sacred with [neptune.ai](https://neptune.ai/) for logging and tracking experiments. If you want to activate it: 102 | 103 | 1. Create a [neptune.ai](https://neptune.ai/) account (you may ask for an academic account if applicable). 104 | 2. Create a project, and copy your credentials (api_token, project name) into `train.py`. 105 | 3. Pass the `--neptune` flag to `train.py`. 106 | 107 | ### Run Evaluation 108 | 109 | Evaluate video retrieval (without re-ranking): 110 | 111 | ```shell 112 | export MASTER_PORT=12345 113 | export WORLD_SIZE=4 114 | export MASTER_ADDR=localhost 115 | 116 | torchrun --standalone --nnodes=1 --nproc_per_node=${WORLD_SIZE} \ 117 | howtocaption/eval.py \ 118 | --resume pretrained/dual_encoder_retrieval.pth \ 119 | -c configs/VL_training/dual_encoder_retrieval.yaml \ 120 | --distributed 1 \ 121 | --world_size ${WORLD_SIZE} \ 122 | --eval_retrieval 123 | ``` 124 | 125 | Evaluate video captioning: 126 | 127 | ```shell 128 | export MASTER_PORT=12345 129 | export WORLD_SIZE=4 130 | export MASTER_ADDR=localhost 131 | 132 | torchrun --standalone --nnodes=1 --nproc_per_node=${WORLD_SIZE} \ 133 | howtocaption/eval.py \ 134 | --resume pretrained/captioning_msrvtt.pth \ 135 | -c configs/VL_training/captioning_msrvtt.yaml \ 136 | --distributed 1 \ 137 | --world_size ${WORLD_SIZE} \ 138 | --eval_captioning 139 | ``` 140 | 141 | See more configs in `configs/` and models [here](https://drive.google.com/drive/folders/1tZICCBaRW_wwBIfjNfg7qmfpD1LopQfF?usp=sharing). 142 | 143 | For retrieval evaluation with re-ranking, we followed the [VAST](https://github.com/TXH-mercury/VAST) implementation. 144 | 145 | ### Run Training 146 | 147 | Train Dual-Encoder Model (initialized from BLIP) on HowToCaption Dataset: 148 | 149 | ```shell 150 | export MASTER_PORT=12345 151 | export WORLD_SIZE=4 152 | export MASTER_ADDR=localhost 153 | 154 | torchrun --standalone --nnodes=1 --nproc_per_node=${WORLD_SIZE} \ 155 | howtocaption/train.py \ 156 | -c configs/VL_training/dual_encoder_retrieval.yaml \ 157 | --distributed 1 \ 158 | --world_size ${WORLD_SIZE} 159 | ``` 160 | 161 | Train Full Encoder-Decoder Model (initialized from BLIP) on HowToCaption Dataset: 162 | 163 | ```shell 164 | export MASTER_PORT=12345 165 | export WORLD_SIZE=4 166 | export MASTER_ADDR=localhost 167 | 168 | torchrun --standalone --nnodes=1 --nproc_per_node=${WORLD_SIZE} \ 169 | howtocaption/train.py \ 170 | -c configs/VL_training/full_encoder_decoder.yaml \ 171 | --distributed 1 \ 172 | --world_size ${WORLD_SIZE} 173 | ``` 174 | 175 | See more configs in `configs/`. 176 | 177 | ### HowToCaption Framework 178 | 179 | #### LLM Prompting 180 | 181 | We share all steps of the HowToCaption framework with an example of applying it to the HowTo100M dataset. 182 | 1. Make sure `videos` and `video_path_downloaded.csv` are in `data/howto100m` (as described in **Data Preparation**). 183 | 184 | 2. Prepare the ASR annotations. Download them, filter them by the downloaded videos, and divide them into 200-word blocks. We used the [Sentencified HTM](https://www.robots.ox.ac.uk/~vgg/research/tan/index.html#dataset-summary) 185 | version of the ASR annotations, where the ASR was preprocessed into full sentences. 
186 | 187 | ```shell 188 | wget http://www.robots.ox.ac.uk/~htd/tan/sentencified_htm_1200k.json -P data/howto100m 189 | python howtocaption/llm_prompting/scripts/1_filter_available.py --asr data/howto100m/sentencified_htm_1200k.json \ 190 | --csv data/howto100m/video_path_downloaded.csv --output_folder data/howto100m/ 191 | python howtocaption/llm_prompting/scripts/2_create_word_blocks.py --asr data/howto100m/asr_filtered.pickle \ 192 | --n_words_max 200 --output_key '200w' 193 | ``` 194 | 195 | You can use `video_path_filtered_s50.pickle` with only 50 videos for a quick start with later prompting Vicuna: 196 | ```shell 197 | python howtocaption/llm_prompting/scripts/2_create_word_blocks.py --asr data/howto100m/asr_filtered_s50.pickle \ 198 | --n_words_max 200 --output_key '200w' 199 | ``` 200 | 201 | 3. Download Vicuna weights. We used [Vicuna-13B (v0)](https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md). 202 | To download LLaMA weights and Vicuna-13B delta and apply delta weights, follow the [official instruction "How to Apply Delta Weights"](https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md). 203 | 204 | 4. Prompt the Vicuna model to transform ASRs into captions. Results will be saved in a separate file for each video_id. 205 | **Tip:** Run the same script on multiple GPUs to speed up processing. You may use `asr_filtered_s50.pickle` for a quick start. 206 | 207 | ```shell 208 | python howtocaption/llm_prompting/prompt_vicuna.py --config configs/vicuna/final_prompt.yaml \ 209 | --asr-path data/howto100m/asr_filtered.pickle \ 210 | --model-path '/BS/nshvetso/work/cache/huggingface/transformers/models--vicuna-13b' 211 | ``` 212 | 213 | 5. Collect all Vicuna predictions into a single pickle with input timestamps. 214 | ```shell 215 | python howtocaption/llm_prompting/scripts/3_collect_predictions.py --config configs/vicuna/final_prompt.yaml \ 216 | --asr-path data/howto100m/asr_filtered.pickle \ 217 | --output-path output/vicuna/final_prompt.pickle 218 | ``` 219 | 220 | #### Alignment & Filtering 221 | 222 | 1. Extract embeddings for all frames: 223 | 224 | ```shell 225 | config=blip 226 | python howtocaption/save_frame_embeddings.py \ 227 | -c configs/align_and_filter/${config}.yaml 228 | ``` 229 | 230 | Tip: Use `--process_only_part_i` and `--number_of_parts` to process only part of the input data in the current process. For example: 231 | 232 | ```shell 233 | config=blip 234 | process_only_part_i=0 235 | python howtocaption/save_frame_embeddings.py \ 236 | -c configs/align_and_filter/${config}.yaml \ 237 | --process_only_part_i ${process_only_part_i} \ 238 | --number_of_parts 64 239 | ``` 240 | 241 | 2. Extract embeddings for all generated captions: 242 | 243 | ```shell 244 | config=blip 245 | llm_predictions=final_prompt 246 | python howtocaption/save_text_embeddings.py \ 247 | --llm_predictions output/vicuna/${llm_predictions}.pickle \ 248 | -c configs/align_and_filter/${config}.yaml 249 | ``` 250 | 251 | 3. Create alignment and filtering: 252 | 253 | ```shell 254 | config=blip 255 | llm_predictions=final_prompt 256 | python howtocaption/align_and_filter.py \ 257 | --frame_embeddings output/embeddings/video_${config}.pickle \ 258 | --text_embeddings output/embeddings/text_${config}_${llm_predictions}.pickle \ 259 | --top_pairs_threshold 25000000 \ 260 | --output output/generated_dataset/${config}_${llm_predictions}.pickle 261 | ``` 262 | 263 | 4. 
Fine-tune the video-language model on the initial alignments (or use our [fine-tuned model](https://drive.google.com/file/d/1qRn6FQyywv5-l6b7UzPMKcLVNCXS0YIk/view?usp=sharing)). 264 | Add the path to the aligned captions to the config 265 | `configs/align_and_filter/finetune_1round.yaml` and fine-tune the model: 266 | 267 | ```shell 268 | export MASTER_PORT=12345 269 | export WORLD_SIZE=4 270 | export MASTER_ADDR=localhost 271 | 272 | torchrun --standalone --nnodes=1 --nproc_per_node=${WORLD_SIZE} \ 273 | howtocaption/train.py \ 274 | -c configs/align_and_filter/finetune_1round.yaml \ 275 | --distributed 1 \ 276 | --world_size ${WORLD_SIZE} 277 | ``` 278 | 279 | 5. Extract text and video features with the new model and repeat the re-alignment using both the original and the new features. 280 | 281 | ```shell 282 | config1=blip 283 | config2=blip_ft_1round 284 | llm_predictions=final_prompt 285 | python howtocaption/align_and_filter.py \ 286 | --frame_embeddings output/embeddings/video_${config1}.pickle output/embeddings/video_${config2}.pickle \ 287 | --text_embeddings output/embeddings/text_${config1}_${llm_predictions}.pickle output/embeddings/text_${config2}_${llm_predictions}.pickle \ 288 | --top_pairs_threshold 25000000 \ 289 | --output output/generated_dataset/average_${llm_predictions}.pickle 290 | ``` 291 | 292 | ## Acknowledgments and Licenses 293 | 294 | The main structure of the code is based on 295 | https://github.com/victoresque/pytorch-template, which is licensed under MIT. 296 | 297 | The code is partly derived from 298 | https://github.com/salesforce/BLIP, 299 | https://github.com/ArrowLuo/CLIP4Clip, 300 | https://github.com/whwu95/Cap4Video, 301 | https://github.com/lm-sys/FastChat, 302 | and https://github.com/tylin/coco-caption, 303 | which are licensed under Apache License 2.0, MIT, or BSD-3 licenses. 304 | 305 | All other code is licensed under MIT. All license clauses are in the LICENSE file. 306 | 307 | ## Citation 308 | 309 | If you use this code in your research, please cite: 310 | 311 | ``` 312 | @inproceedings{shvetsova2024howtocaption, 313 | title={Howtocaption: Prompting llms to transform video annotations at scale}, 314 | author={Shvetsova, Nina and Kukleva, Anna and Hong, Xudong and Rupprecht, Christian and Schiele, Bernt and Kuehne, Hilde}, 315 | booktitle={European Conference on Computer Vision}, 316 | year={2024}, 317 | organization={Springer} 318 | } 319 | ``` 320 | -------------------------------------------------------------------------------- /howtocaption/model/utils/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on timm code base 8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from functools import partial 15 | 16 | from timm.models.vision_transformer import _cfg, PatchEmbed 17 | from timm.models.registry import register_model 18 | from timm.models.layers import trunc_normal_, DropPath 19 | from timm.models.helpers import named_apply, adapt_input_conv 20 | 21 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 22 | 23 | 24 | class Mlp(nn.Module): 25 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 26 | """ 27 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 28 | super().__init__() 29 | out_features = out_features or in_features 30 | hidden_features = hidden_features or in_features 31 | self.fc1 = nn.Linear(in_features, hidden_features) 32 | self.act = act_layer() 33 | self.fc2 = nn.Linear(hidden_features, out_features) 34 | self.drop = nn.Dropout(drop) 35 | 36 | def forward(self, x): 37 | x = self.fc1(x) 38 | x = self.act(x) 39 | x = self.drop(x) 40 | x = self.fc2(x) 41 | x = self.drop(x) 42 | return x 43 | 44 | 45 | class Attention(nn.Module): 46 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 47 | super().__init__() 48 | self.num_heads = num_heads 49 | head_dim = dim // num_heads 50 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 51 | self.scale = qk_scale or head_dim ** -0.5 52 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 53 | self.attn_drop = nn.Dropout(attn_drop) 54 | self.proj = nn.Linear(dim, dim) 55 | self.proj_drop = nn.Dropout(proj_drop) 56 | self.attn_gradients = None 57 | self.attention_map = None 58 | 59 | def save_attn_gradients(self, attn_gradients): 60 | self.attn_gradients = attn_gradients 61 | 62 | def get_attn_gradients(self): 63 | return self.attn_gradients 64 | 65 | def save_attention_map(self, attention_map): 66 | self.attention_map = attention_map 67 | 68 | def get_attention_map(self): 69 | return self.attention_map 70 | 71 | def forward(self, x, register_hook=False): 72 | B, N, C = x.shape 73 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 74 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 75 | 76 | attn = (q @ k.transpose(-2, -1)) * self.scale 77 | attn = attn.softmax(dim=-1) 78 | attn = self.attn_drop(attn) 79 | 80 | if register_hook: 81 | self.save_attention_map(attn) 82 | attn.register_hook(self.save_attn_gradients) 83 | 84 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 85 | x = self.proj(x) 86 | x = self.proj_drop(x) 87 | return x 88 | 89 | 90 | class Block(nn.Module): 91 | 92 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 93 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): 94 | super().__init__() 95 | self.norm1 = norm_layer(dim) 96 | self.attn = Attention( 97 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 98 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 
99 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 100 | self.norm2 = norm_layer(dim) 101 | mlp_hidden_dim = int(dim * mlp_ratio) 102 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 103 | 104 | if use_grad_checkpointing: 105 | self.attn = checkpoint_wrapper(self.attn) 106 | self.mlp = checkpoint_wrapper(self.mlp) 107 | 108 | def forward(self, x, register_hook=False): 109 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 110 | x = x + self.drop_path(self.mlp(self.norm2(x))) 111 | return x 112 | 113 | 114 | class VisionTransformer(nn.Module): 115 | """ Vision Transformer 116 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 117 | https://arxiv.org/abs/2010.11929 118 | """ 119 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 120 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 121 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 122 | use_grad_checkpointing=False, ckpt_layer=0): 123 | """ 124 | Args: 125 | img_size (int, tuple): input image size 126 | patch_size (int, tuple): patch size 127 | in_chans (int): number of input channels 128 | num_classes (int): number of classes for classification head 129 | embed_dim (int): embedding dimension 130 | depth (int): depth of transformer 131 | num_heads (int): number of attention heads 132 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 133 | qkv_bias (bool): enable bias for qkv if True 134 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 135 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 136 | drop_rate (float): dropout rate 137 | attn_drop_rate (float): attention dropout rate 138 | drop_path_rate (float): stochastic depth rate 139 | norm_layer: (nn.Module): normalization layer 140 | """ 141 | super().__init__() 142 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 143 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 144 | 145 | self.patch_embed = PatchEmbed( 146 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 147 | 148 | num_patches = self.patch_embed.num_patches 149 | 150 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 151 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 152 | self.pos_drop = nn.Dropout(p=drop_rate) 153 | 154 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 155 | self.blocks = nn.ModuleList([ 156 | Block( 157 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 158 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 159 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) 160 | ) 161 | for i in range(depth)]) 162 | self.norm = norm_layer(embed_dim) 163 | 164 | trunc_normal_(self.pos_embed, std=.02) 165 | trunc_normal_(self.cls_token, std=.02) 166 | self.apply(self._init_weights) 167 | 168 | def _init_weights(self, m): 169 | if isinstance(m, nn.Linear): 170 | trunc_normal_(m.weight, std=.02) 171 | if isinstance(m, nn.Linear) and m.bias is not None: 172 | nn.init.constant_(m.bias, 0) 173 | elif isinstance(m, nn.LayerNorm): 174 | nn.init.constant_(m.bias, 0) 175 | 
nn.init.constant_(m.weight, 1.0) 176 | 177 | @torch.jit.ignore 178 | def no_weight_decay(self): 179 | return {'pos_embed', 'cls_token'} 180 | 181 | def forward(self, x, register_blk=-1): 182 | B = x.shape[0] 183 | x = self.patch_embed(x) 184 | 185 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 186 | x = torch.cat((cls_tokens, x), dim=1) 187 | 188 | x = x + self.pos_embed[:,:x.size(1),:] 189 | x = self.pos_drop(x) 190 | 191 | for i,blk in enumerate(self.blocks): 192 | x = blk(x, register_blk==i) 193 | x = self.norm(x) 194 | 195 | return x 196 | 197 | @torch.jit.ignore() 198 | def load_pretrained(self, checkpoint_path, prefix=''): 199 | _load_weights(self, checkpoint_path, prefix) 200 | 201 | 202 | @torch.no_grad() 203 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 204 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 205 | """ 206 | import numpy as np 207 | 208 | def _n2p(w, t=True): 209 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 210 | w = w.flatten() 211 | if t: 212 | if w.ndim == 4: 213 | w = w.transpose([3, 2, 0, 1]) 214 | elif w.ndim == 3: 215 | w = w.transpose([2, 0, 1]) 216 | elif w.ndim == 2: 217 | w = w.transpose([1, 0]) 218 | return torch.from_numpy(w) 219 | 220 | w = np.load(checkpoint_path) 221 | if not prefix and 'opt/target/embedding/kernel' in w: 222 | prefix = 'opt/target/' 223 | 224 | if hasattr(model.patch_embed, 'backbone'): 225 | # hybrid 226 | backbone = model.patch_embed.backbone 227 | stem_only = not hasattr(backbone, 'stem') 228 | stem = backbone if stem_only else backbone.stem 229 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 230 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 231 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 232 | if not stem_only: 233 | for i, stage in enumerate(backbone.stages): 234 | for j, block in enumerate(stage.blocks): 235 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 236 | for r in range(3): 237 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 238 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 239 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 240 | if block.downsample is not None: 241 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 242 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 243 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 244 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 245 | else: 246 | embed_conv_w = adapt_input_conv( 247 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 248 | model.patch_embed.proj.weight.copy_(embed_conv_w) 249 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 250 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 251 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 252 | if pos_embed_w.shape != model.pos_embed.shape: 253 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 254 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 255 | model.pos_embed.copy_(pos_embed_w) 256 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 257 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 258 | # if 
isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 259 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 260 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 261 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 262 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 263 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 264 | for i, block in enumerate(model.blocks.children()): 265 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 266 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 267 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 268 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 269 | block.attn.qkv.weight.copy_(torch.cat([ 270 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 271 | block.attn.qkv.bias.copy_(torch.cat([ 272 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 273 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 274 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 275 | for r in range(2): 276 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 277 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 278 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 279 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 280 | 281 | 282 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 283 | # interpolate position embedding 284 | embedding_size = pos_embed_checkpoint.shape[-1] 285 | num_patches = visual_encoder.patch_embed.num_patches 286 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 287 | # height (== width) for the checkpoint position embedding 288 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 289 | # height (== width) for the new position embedding 290 | new_size = int(num_patches ** 0.5) 291 | 292 | if orig_size!=new_size: 293 | # class_token and dist_token are kept unchanged 294 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 295 | # only the position tokens are interpolated 296 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 297 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 298 | pos_tokens = torch.nn.functional.interpolate( 299 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 300 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 301 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 302 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) 303 | 304 | return new_pos_embed 305 | else: 306 | return pos_embed_checkpoint --------------------------------------------------------------------------------