├── config ├── trainer │ └── defaults.yaml ├── optimizer │ └── adam.yaml ├── .DS_Store ├── dataloader │ └── defaults.yaml ├── loss │ └── triple_loss.yaml ├── augmentation │ └── defaults.yaml ├── paths │ └── public.yaml ├── dataset │ └── sentence.yaml ├── model │ └── cslr2.yaml ├── example_job_train.sh ├── cslr2.yaml ├── example_job_test.sh └── cslr2_eval.yaml ├── misc └── process_cslr_json │ ├── plots │ ├── cslr_test_anns_hist.png │ ├── cslr_test_per_sign_density.png │ └── cslr_test_per_duration_density.png │ ├── remove_star_annots_from_csv.py │ ├── run_pipeline.py │ ├── fix_boundaries.py │ └── fix_alignment.py ├── dataset ├── __init__.py ├── lmdb_loader.py └── subtitles.py ├── loops ├── __init__.py ├── retrieval_loop.py ├── train_loop.py ├── val_loop.py └── retrieval.py ├── models ├── __init__.py ├── sbert.py ├── t5.py ├── cslr2.py └── transformer_encoder.py ├── utils ├── __init__.py ├── seed.py ├── idr_torch.py ├── gather.py ├── ddp_settings.py ├── root_words.py ├── instantiate_augmentations.py ├── instantiate_model.py ├── wandb_utils.py ├── synonyms.py ├── instantiate_dataloaders.py ├── matplotlib_utils.py └── frame_level_evaluation_dict.py ├── environment.yaml ├── LICENSE ├── augmentations ├── video_augment.py └── text_augment.py ├── .gitignore ├── loss └── hn_nce.py ├── README.md ├── main.py └── extract_for_eval.py /config/trainer/defaults.yaml: -------------------------------------------------------------------------------- 1 | epochs: 20 2 | epoch_start: 0 3 | -------------------------------------------------------------------------------- /config/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_ : torch.optim.Adam 2 | lr: 0.00005 3 | -------------------------------------------------------------------------------- /config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulvarol/cslr2/HEAD/config/.DS_Store -------------------------------------------------------------------------------- /misc/process_cslr_json/plots/cslr_test_anns_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulvarol/cslr2/HEAD/misc/process_cslr_json/plots/cslr_test_anns_hist.png -------------------------------------------------------------------------------- /misc/process_cslr_json/plots/cslr_test_per_sign_density.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulvarol/cslr2/HEAD/misc/process_cslr_json/plots/cslr_test_per_sign_density.png -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | """For relative imports to work in Python >= 3.6""" 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 6 | -------------------------------------------------------------------------------- /loops/__init__.py: -------------------------------------------------------------------------------- 1 | """For relative imports to work in Python >= 3.6""" 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 6 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """For relative imports to work in 
Python >= 3.6""" 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 6 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """For relative imports to work in Python >= 3.6""" 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 6 | -------------------------------------------------------------------------------- /misc/process_cslr_json/plots/cslr_test_per_duration_density.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulvarol/cslr2/HEAD/misc/process_cslr_json/plots/cslr_test_per_duration_density.png -------------------------------------------------------------------------------- /config/dataloader/defaults.yaml: -------------------------------------------------------------------------------- 1 | dataloader: 2 | _target_: torch.utils.data.DataLoader 3 | batch_size: 128 4 | shuffle: True 5 | num_workers: 10 6 | persistent_workers: True 7 | drop_last: True 8 | pin_memory: True 9 | 10 | # other parameters (used for training on subsets of the data, etc.) 11 | N: null 12 | train_data_fraction: 1.0 13 | val_data_fraction: 1.0 14 | -------------------------------------------------------------------------------- /utils/seed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Function to set the seed for random number generators. 3 | """ 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def setup_seed(seed: int) -> None: 9 | """Set seed for numpy and torch""" 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed_all(seed) 13 | torch.backends.cudnn.deterministic = True 14 | torch.backends.cudnn.benchmark = True 15 | -------------------------------------------------------------------------------- /config/loss/triple_loss.yaml: -------------------------------------------------------------------------------- 1 | # weight parameters 2 | lda_sent_ret: 0.9 3 | lda_sign_ret: 0.0075 4 | lda_sign_cls: 0 5 | 6 | sent_ret: 7 | _target_: loss.hn_nce.HardNegativeNCE 8 | alpha: 1.0 9 | beta: 1.0 10 | temperature: 0.07 11 | 12 | sign_ret: 13 | _target_: loss.hn_nce.HardNegativeNCE 14 | alpha: 1.0 15 | beta: 0.5 16 | temperature: 0.07 17 | 18 | sign_cls: 19 | _target_: torch.nn.CrossEntropyLoss -------------------------------------------------------------------------------- /config/augmentation/defaults.yaml: -------------------------------------------------------------------------------- 1 | do_swap: False 2 | do_drop: True 3 | do_shuffle: True 4 | do_frame_drop: True 5 | 6 | swap_words: 7 | _target_: augmentations.text_augment.SwapWords 8 | nb_swaps : 1 9 | 10 | drop_words: 11 | _target_: augmentations.text_augment.DropWords 12 | p_sentence: 0.8 13 | p_word: 0.4 14 | 15 | shuffle_words: 16 | _target_: augmentations.text_augment.ShuffleWords 17 | p_shuffle: 0.5 18 | 19 | frame_drop: 20 | _target_: augmentations.video_augment.DropFrames 21 | p_sequence: 0.8 22 | p_frame: 0.5 23 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: cslr2 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.9.16 6 | - pytorch==1.12.1 7 | - torchvision==0.13.1 8 | - cudatoolkit=11.6 9 | - pandas=1.5.3 10 | - einops=0.6.0 11 
| - humanize=4.6.0 12 | - tqdm=4.65.0 13 | - pip 14 | - pip: 15 | - hydra-core==1.3.2 16 | - matplotlib==3.7.1 17 | - plotly==5.14.1 18 | - nltk==3.8.1 19 | - seaborn==0.12.2 20 | - sentence-transformers==2.2.2 21 | - wandb==0.14.0 22 | - lmdb 23 | - tabulate 24 | - opencv-python==4.7.0.72 25 | -------------------------------------------------------------------------------- /config/paths/public.yaml: -------------------------------------------------------------------------------- 1 | subset2episode: bobsl/splits/subset2episode.json 2 | vocab_pkl: bobsl/vocab/8697_vocab.pkl 3 | info_pkl: bobsl/unannotated-info/info-src_videos_2236.pkl 4 | annotations_pkl: bobsl/lmdbs/lmdb-pl_vswin_t-bs256_float16/ 5 | vid_features_lmdb: bobsl/lmdbs/lmdb-feats_vswin_t-bs256_float16/ 6 | word_embds_pkl: bobsl/vocab_embds/t5-large_8k_word_embeddings.pkl 7 | subtitles_path: bobsl/subtitles_pkl/best_delta_postpro_bobsl_best_traineval_0_pslab_ftune_.pkl 8 | aligned_subtitles_path: bobsl/subtitles_pkl/manually-aligned.pkl 9 | synonyms_pkl: bobsl/syns/synonym_pickle_english_and_signdict_and_signbank.pkl 10 | lm_root: t5_checkpoint/ 11 | rgb_frames: bobsl/lmdbs/lmdb-rgb_anon-public_1962/ 12 | log_dir: runs/ 13 | misaligned_csv_root: bobsl/cslr/manual_glosses_extended_boundaries 14 | heuristic_aligned_csv_root: bobsl/cslr/manual_glosses_extended_boundaries_fix_alignment 15 | -------------------------------------------------------------------------------- /utils/idr_torch.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch: multi-GPU and multi-node data parallelism (from the IDRIS website) 3 | Should also work on Triton (to be tested) 4 | /usr/bin/env python 5 | coding: utf-8 6 | """ 7 | import os 8 | import hostlist 9 | 10 | # get SLURM variables 11 | rank = int(os.environ['SLURM_PROCID']) 12 | local_rank = int(os.environ['SLURM_LOCALID']) 13 | size = int(os.environ['SLURM_NTASKS']) 14 | cpus_per_task = int(os.environ['SLURM_CPUS_PER_TASK']) 15 | 16 | # get node list from slurm 17 | hostnames = hostlist.expand_hostlist(os.environ['SLURM_JOB_NODELIST']) 18 | 19 | # get IDs of reserved GPU 20 | # gpu_ids = os.environ['SLURM_JOB_GPUS'].split(",") 21 | gpu_ids = os.environ['SLURM_STEP_GPUS'].split(",") 22 | 23 | # define MASTER_ADDR & MASTER_PORT 24 | os.environ['MASTER_ADDR'] = hostnames[0] 25 | # to avoid port conflict on the same node 26 | os.environ['MASTER_PORT'] = str(12345 + int(min(gpu_ids))) 27 | -------------------------------------------------------------------------------- /utils/gather.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used in DDP mode with the HN-NCE loss. 
3 | """ 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | class DiffAllGather(torch.autograd.Function): 9 | """Gathering all tensors from all processes""" 10 | @staticmethod 11 | def forward(ctx, tensor): 12 | """ 13 | Forward pass with gathering all tensors 14 | """ 15 | gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())] 16 | dist.all_gather(gathered, tensor) 17 | return tuple(gathered) 18 | 19 | @staticmethod 20 | def backward(ctx, *grad_outs): 21 | """ 22 | Backward pass with all-reduce 23 | """ 24 | grad_outs = torch.stack(grad_outs) 25 | dist.all_reduce(grad_outs) 26 | return grad_outs[dist.get_rank()] 27 | 28 | 29 | def all_gather(tensor): 30 | """All gather tensors from all processes""" 31 | return DiffAllGather.apply(tensor) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Charles Raude 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /config/dataset/sentence.yaml: -------------------------------------------------------------------------------- 1 | _target_: dataset.sentence.Sentences 2 | 3 | # general parameters to load sentences associated to subtitles 4 | subset2episode: ${paths.subset2episode} 5 | subtitles_path: ${paths.subtitles_path} 6 | subtitles_temporal_shift: 0.0 7 | subtitles_max_duration: 20.0 8 | subtitles_min_duration: 1.0 9 | temporal_pad: 1.0 10 | info_pkl: ${paths.info_pkl} 11 | filter_stop_words: False 12 | subtitles_random_offset: 0.5 13 | fps: 25 14 | 15 | # parameters to load video features 16 | load_features: True 17 | feats_lmdb: ${paths.vid_features_lmdb} 18 | feats_load_stride: 1 19 | feats_load_float16: False 20 | feats_lmdb_window_size: 16 21 | feats_lmdb_stride: 2 22 | feats_dim: 768 23 | video_augmentations: null 24 | 25 | # parameters to load video pseudo-labels 26 | load_pl: True 27 | pl_lmdb: ${paths.annotations_pkl} 28 | pl_load_stride: 1 29 | pl_load_float16: False 30 | pl_lmdb_window_size: 16 31 | pl_lmdb_stride: 2 32 | pl_filter: 0.6 33 | pl_min_count: 6 34 | pl_synonym_grouping: True 35 | synonyms_pkl: ${paths.synonyms_pkl} 36 | vocab_pkl: ${paths.vocab_pkl} 37 | 38 | # parameters to load word embeddings 39 | load_word_embds: True 40 | word_embds_pkl: ${paths.word_embds_pkl} 41 | 42 | # other parameters 43 | verbose: False -------------------------------------------------------------------------------- /config/model/cslr2.yaml: -------------------------------------------------------------------------------- 1 | freeze_transformer: False 2 | 3 | cslr2: 4 | _target_: models.cslr2.CSLR2 5 | video_encoder: 6 | _target_: models.transformer_encoder.make_model 7 | vocab: 8697 8 | N: 6 9 | d_model: 768 10 | h: 8 11 | dropout: 0.1 12 | contrastive: True 13 | text_encoder: 14 | _target_: models.t5.make_sentence_model 15 | model_name: t5-large 16 | root_path: ${paths.lm_root} 17 | video_sequence_ll: 18 | _target_: torch.nn.Linear 19 | in_features: ${model.cslr2.video_encoder.d_model} 20 | out_features: 256 21 | video_token_ll: 22 | _target_: torch.nn.Linear 23 | in_features: ${model.cslr2.video_encoder.d_model} 24 | out_features: ${model.cslr2.video_sequence_ll.out_features} 25 | text_sentence_ll: 26 | _target_: torch.nn.Linear 27 | in_features: 1024 # size of text encoder embds 28 | out_features: ${model.cslr2.video_sequence_ll.out_features} 29 | text_word_ll: 30 | _target_: torch.nn.Linear 31 | in_features: ${model.cslr2.text_sentence_ll.in_features} 32 | out_features: ${model.cslr2.video_sequence_ll.out_features} 33 | pooling: max 34 | sign_ret: True 35 | no_video_encoder: False 36 | same_text_ll: False 37 | same_video_ll: False 38 | -------------------------------------------------------------------------------- /config/example_job_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=cslr2_train # job name 3 | #SBATCH --account=vvh@v100 # project code 4 | #SBATCH -C v100-32g # choose nodes with 32G GPU memory 5 | #SBATCH --ntasks=4 # number of tasks (GPUs) 6 | #SBATCH --ntasks-per-node=4 # number of tasks (GPUs) per node 7 | #SBATCH --gres=gpu:4 # number of GPUs per node 8 | #SBATCH --qos=qos_gpu-t3 # (20h) jobs 9 | #SBATCH --cpus-per-task=10 # number of cores per tasks 10 | #SBATCH --hint=nomultithread # we get physical cores not logical 11 | #SBATCH --time=20:00:00 # maximum execution time (HH:MM:SS) 12 | #SBATCH 
--output=logs/cslr2_train%j.out # output file name 13 | #SBATCH --error=logs/cslr2_train%j.err # error file name 14 | 15 | set -x # echo launched commands 16 | 17 | module purge 18 | 19 | . ${WORK}/miniconda3/etc/profile.d/conda.sh 20 | conda activate gpu_pytorch_1.12 21 | 22 | cd "${WORK}/code/cslr2" 23 | 24 | export HYDRA_FULL_ERROR=1 # to get better error messages if job crashes 25 | export WANDB_MODE=offline 26 | 27 | echo "Do not forget to set distributed: True in config/cslr2.yaml" 28 | 29 | srun python main.py run_name=cslr2 30 | 31 | -------------------------------------------------------------------------------- /config/cslr2.yaml: -------------------------------------------------------------------------------- 1 | # configuration for running the cslr2 pipeline 2 | hydra: 3 | run: 4 | dir: ./runs/${run_name} 5 | 6 | defaults: 7 | - _self_ 8 | - paths: public # configuration file for paths 9 | - optimizer: adam # configuration file for optimizer 10 | - model: cslr2 # configuration file for model 11 | - dataset: sentence # configuration file for dataset 12 | - augmentation: defaults # configuration file for augmentation 13 | - loss: triple_loss # configuration file for loss 14 | - dataloader: defaults # configuration file for dataloader 15 | - trainer: defaults # configuration file for trainer details 16 | 17 | run_name: ?? # name of the run 18 | 19 | checkpoint: null # path to a checkpoint to load 20 | 21 | test: False # if True, test the model for retrieval 22 | 23 | # visualisation settings 24 | vis: True # if True, save visualisations 25 | worst_retrieval: False # if True, save worst retrieved results 26 | nb_vis: 0 # number of visualisations to save per epoch 27 | 28 | # wandb settings 29 | wandb_offline: False # if True, wandb will not be synced 30 | 31 | # distributed settings 32 | distributed: False 33 | world_size: null 34 | rank: null 35 | local_rank: null 36 | fixed_lr: False 37 | 38 | # other 39 | seed: 0 # for reproducibility 40 | do_print: True # variable that controls printing (only prints in the main process) 41 | -------------------------------------------------------------------------------- /utils/ddp_settings.py: -------------------------------------------------------------------------------- 1 | """Initialize DDP settings""" 2 | import builtins 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from omegaconf import DictConfig 8 | 9 | 10 | def ddp_settings(cfg: DictConfig) -> DictConfig: 11 | """Initialize DDP settings and return the updated config""" 12 | if cfg.distributed: 13 | import utils.idr_torch as idr_torch 14 | if "SLURM_PROCID" in os.environ: 15 | cfg.world_size = idr_torch.size 16 | cfg.rank = idr_torch.rank 17 | cfg.local_rank = idr_torch.local_rank 18 | cfg.optimizer.lr = cfg.optimizer.lr * cfg.world_size if not cfg.fixed_lr \ 19 | else cfg.optimizer.lr 20 | elif "LOCAL_RANK" in os.environ and int(os.environ["LOCAL_RANK"]) != -1: 21 | # for torch.distributed.launch 22 | cfg.rank = int(os.environ["LOCAL_RANK"]) 23 | cfg.local_rank = cfg.rank 24 | cfg.world_size = int(os.environ["WORLD_SIZE"]) 25 | dist.init_process_group( 26 | backend="nccl", init_method="env://", 27 | world_size=cfg.world_size, rank=cfg.rank, 28 | ) 29 | torch.cuda.set_device(cfg.local_rank) 30 | else: 31 | cfg.world_size, cfg.rank, cfg.local_rank = 1, 0, 0 32 | if cfg.rank != 0: 33 | print(f"Rank {cfg.rank} is muted") 34 | def print_pass(*args): 35 | pass 36 | builtins.print = print_pass 37 | return cfg 38 | 
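# A minimal sketch (not from this repository) of how the cfg returned by ddp_settings()
# is typically consumed downstream: moving the model to the local GPU, wrapping it in
# DistributedDataParallel, and sharding the data with a DistributedSampler. The cfg fields
# (distributed, local_rank, rank, world_size) follow the config above; the helper name,
# the dataset argument, and the batch size are illustrative assumptions only.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler


def wrap_for_ddp(cfg, model, dataset):
    """Hypothetical helper: place the model and build a rank-aware dataloader."""
    device = torch.device(f"cuda:{cfg.local_rank}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    sampler = None
    if cfg.distributed:
        # ddp_settings() has already called dist.init_process_group and set the CUDA device
        model = DDP(model, device_ids=[cfg.local_rank])
        sampler = DistributedSampler(dataset, num_replicas=cfg.world_size, rank=cfg.rank)
    # shuffle only when no sampler is used (the two options are mutually exclusive)
    loader = DataLoader(dataset, batch_size=128, sampler=sampler, shuffle=(sampler is None))
    return model, loader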
-------------------------------------------------------------------------------- /utils/root_words.py: -------------------------------------------------------------------------------- 1 | """Python file defining functions to extract root words from a given word.""" 2 | import re 3 | from typing import List, Union 4 | 5 | import numpy as np 6 | from nltk.stem import WordNetLemmatizer 7 | 8 | wordnet_lemmatizer = WordNetLemmatizer() 9 | 10 | 11 | def get_root_words( 12 | words: Union[str, List[str], np.ndarray], 13 | segment_behavior: bool = False, 14 | ) -> List[List[str]]: 15 | """Normalise + lemmatise words""" 16 | if isinstance(words, List) or isinstance(words, np.ndarray): 17 | words = " ".join(words) 18 | if segment_behavior: 19 | words = words.strip().replace("-", " ").split() 20 | words = [word.lower() for word in words] 21 | words = [ 22 | wordnet_lemmatizer.lemmatize( 23 | wordnet_lemmatizer.lemmatize( 24 | " ".join(word.replace("/", " ").split()), 25 | pos="v", 26 | ), 27 | pos="n", 28 | ) for word in words 29 | ] 30 | else: 31 | words = words.strip().split() 32 | words = [word.lower() for word in words] 33 | for word_idx, word in enumerate(words): 34 | true_words = re.split("-|/| ", word) 35 | words[word_idx] = [ 36 | wordnet_lemmatizer.lemmatize( 37 | wordnet_lemmatizer.lemmatize(true_word, pos="v"), 38 | pos="n" 39 | ) for true_word in true_words if len(true_word) > 0 40 | ] 41 | return words 42 | -------------------------------------------------------------------------------- /config/example_job_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=cslr2_test # job name 3 | #SBATCH --account=vvh@v100 # project code 4 | #SBATCH -C v100-32g # choose nodes with 32G GPU memory 5 | #SBATCH --ntasks-per-node=1 # number of MPI tasks per node 6 | #SBATCH --gres=gpu:1 # number of GPUs per node 7 | #SBATCH --qos=qos_gpu-t3 # (20h) jobs 8 | #SBATCH --cpus-per-task=8 # number of cores per tasks 9 | #SBATCH --hint=nomultithread # we get physical cores not logical 10 | #SBATCH --time=00:30:00 # maximum execution time (HH:MM:SS) 11 | #SBATCH --output=logs/cslr2_test%j.out # output file name 12 | #SBATCH --error=logs/cslr2_test%j.err # error file namea 13 | 14 | set -x # echo launched commands 15 | 16 | module purge 17 | 18 | . ${WORK}/miniconda3/etc/profile.d/conda.sh 19 | conda activate gpu_pytorch_1.12 20 | 21 | cd "${WORK}/code/cslr2" 22 | 23 | export HYDRA_FULL_ERROR=1 # to get better error messages if job crashes 24 | export WANDB_MODE=offline 25 | 26 | # Set folder where to save outputs 27 | export RUN_NAME=runs/cslr2 28 | # Set checkpoint 29 | export PATH_TO_CHECKPOINT=${RUN_NAME}/models/model_best.pth 30 | 31 | # 1) Evaluate sentence retrieval 32 | python main.py run_name=cslr2_test checkpoint=${PATH_TO_CHECKPOINT} test=True 33 | 34 | # 2) Evaluate CSLR (in two steps) 35 | python extract_for_eval.py checkpoint=${PATH_TO_CHECKPOINT} 36 | 37 | python frame_level_evaluation.py prediction_pickle_files=${RUN_NAME}/cslr/eval/nn 38 | 39 | -------------------------------------------------------------------------------- /augmentations/video_augment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for augmenting videos. 3 | """ 4 | from typing import Union 5 | 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class DropFrames(object): 11 | """ 12 | Drop frames from a video or a sequence of video frames/features. 
13 | 14 | Args: 15 | p_sequence (float): probability of augmenting a video sequence 16 | p_frame (float): probability of dropping a frame 17 | """ 18 | def __init__( 19 | self, 20 | p_sequence: float = 0.5, 21 | p_frame: float = 0.3, 22 | ) -> None: 23 | self.p_sequence = p_sequence 24 | self.p_frame = p_frame 25 | 26 | def _drop_frames( 27 | self, 28 | video: Union[np.ndarray, torch.Tensor], 29 | ) -> Union[np.ndarray, torch.Tensor]: 30 | # keep at least 25 frames 31 | if len(video) <= 25: 32 | return video, np.arange(len(video)) 33 | kept_frames = 25 34 | while kept_frames <= 25: 35 | # assign a random probability to each frame 36 | frames_probs = np.random.rand(len(video)) 37 | # get indices of frames to keep 38 | kept_indices = np.where(frames_probs >= self.p_frame)[0] 39 | kept_frames = len(kept_indices) 40 | return video[kept_indices], kept_indices 41 | 42 | def __call__( 43 | self, 44 | video: Union[np.ndarray, torch.Tensor], 45 | ) -> Union[np.ndarray, torch.Tensor]: 46 | if np.random.rand() < self.p_sequence: 47 | return self._drop_frames(video) 48 | return video, np.arange(len(video)) 49 | -------------------------------------------------------------------------------- /utils/instantiate_augmentations.py: -------------------------------------------------------------------------------- 1 | """Functions to instantiate video and text augmentations.""" 2 | import hydra 3 | from omegaconf import DictConfig 4 | from torchvision.transforms import Compose 5 | 6 | 7 | def text_augmentations(cfg: DictConfig): 8 | """Instantiate text augmentations""" 9 | augmentations = None 10 | if cfg.augmentation.do_swap: 11 | swap_words = hydra.utils.instantiate(cfg.augmentation.swap_words) 12 | if not cfg.augmentation.do_drop or not cfg.augmentation.do_shuffle: 13 | augmentations = swap_words 14 | if cfg.augmentation.do_drop: 15 | drop_words = hydra.utils.instantiate(cfg.augmentation.drop_words) 16 | if cfg.augmentation.do_swap and not cfg.augmentation.do_shuffle: 17 | augmentations = Compose([swap_words, drop_words]) 18 | else: 19 | augmentations = drop_words 20 | if cfg.augmentation.do_shuffle: 21 | shuffle_words = hydra.utils.instantiate(cfg.augmentation.shuffle_words) 22 | if cfg.augmentation.do_swap and not cfg.augmentation.do_drop: 23 | augmentations = Compose([swap_words, shuffle_words]) 24 | elif cfg.augmentation.do_drop and not cfg.augmentation.do_swap: 25 | augmentations = Compose([drop_words, shuffle_words]) 26 | elif cfg.augmentation.do_drop and cfg.augmentation.do_swap: 27 | augmentations = Compose([swap_words, drop_words, shuffle_words]) 28 | else: 29 | augmentations = shuffle_words 30 | return augmentations 31 | 32 | 33 | def vid_augmentations(cfg: DictConfig): 34 | """Instantiate vid augmentations""" 35 | augmentations = None 36 | if cfg.augmentation.do_frame_drop: 37 | augmentations = hydra.utils.instantiate(cfg.augmentation.frame_drop) 38 | return augmentations 39 | -------------------------------------------------------------------------------- /config/cslr2_eval.yaml: -------------------------------------------------------------------------------- 1 | # configuration for CSLR evaluation 2 | hydra: 3 | run: 4 | dir: .{runs}/${run_name} 5 | 6 | defaults: 7 | - _self_ 8 | - paths: public # configuration file for paths 9 | - model: cslr2 # configuration file for model 10 | - dataset: sentence # configuration file for dataset 11 | 12 | run_name: ?? # name of the run 13 | checkpoint: ?? 
# path to a checkpoint to load 14 | 15 | ## Options for feature extraction 16 | swin: False # if True, loaded model is a Swin model (used for Swin Transformer sliding window baseline) 17 | 18 | classification_only: False # if True, only save linear classification results 19 | nn_classification_only: False # if True, only save nearest neighbour classification results 20 | synonym_grouping: True # if True, merge predictions when saving features 21 | temp: 0.05 # temperature to scale the nn sim matrix 22 | 23 | ## Options for frame level evaluation 24 | prediction_pickle_files: null 25 | gt_csv_root: ${paths.heuristic_aligned_csv_root} # root of GT to evaluate against 26 | remove_synonyms_handling: False # if True, remove synonym handling from the evaluation procedure (merging predictions + if synonyms correct) 27 | remove_synonym_grouping: False # if True, remove synonym grouping from the evaluation procedure (i.e., counting synonyms correct) 28 | 29 | do_vis: False # if True, save CSLR visualisations (in png format) 30 | do_phrases_vis: False # if True, save CSLR visualisations for phrases (in png format) 31 | 32 | effect_of_post_processing: False # if True, save CSLR visualisations for the effect of post-processing (in png format) 33 | 34 | test_search: False # if True, search for the best threshold and min_count params for the test set 35 | 36 | optimal_tau: null # threshold used for evaluation on the test set 37 | optimal_mc: null # min_count used for evaluation on the test set 38 | 39 | no_save: False # if True, do not save results 40 | 41 | automatic_annotations: False # if True, predictions are from automatic spottings. 42 | 43 | fps: 25 # framerate of the dataset 44 | -------------------------------------------------------------------------------- /utils/instantiate_model.py: -------------------------------------------------------------------------------- 1 | """Functions to instantiate models with Hydra""" 2 | import hydra 3 | import torch 4 | import torch.nn as nn 5 | 6 | from omegaconf import DictConfig 7 | 8 | 9 | def instantiate_model(cfg: DictConfig) -> nn.Module: 10 | """Instantiate model from config.""" 11 | model = hydra.utils.instantiate(cfg.model.cslr2) 12 | return model 13 | 14 | 15 | def handle_model_freeze( 16 | model: nn.Module, 17 | cfg: DictConfig, 18 | ) -> nn.Module: 19 | """Freeze model parameters according to options in the config.""" 20 | # freeze generator if not training with SignCls (avoid unused parameters error) 21 | if cfg.loss.lda_sign_cls == 0: 22 | for name, param in model.named_parameters(): 23 | if "generator" in name and "text_encoder" not in name: 24 | param.requires_grad = False 25 | # freeze transformer if specified 26 | if cfg.model.freeze_transformer: 27 | for name, param in model.named_parameters(): 28 | if "generator" in name: 29 | param.requires_grad = False 30 | # handling same text and video ll 31 | if cfg.model.cslr2.same_text_ll: 32 | for name, param in model.named_parameters(): 33 | if "text_word_ll" in name: 34 | param.requires_grad = False 35 | if cfg.model.cslr2.same_video_ll: 36 | for name, param in model.named_parameters(): 37 | if "video_token_ll" in name: 38 | param.requires_grad = False 39 | # freeze text encoder 40 | for name, param in model.named_parameters(): 41 | if "text_encoder" in name: 42 | param.requires_grad = False 43 | return model 44 | 45 | 46 | def load_checkpoint( 47 | cfg: DictConfig, 48 | model: nn.Module, 49 | opt: torch.optim.Optimizer, 50 | device: torch.device, 51 | ): 52 | """Load checkpoint if specified in 
the config.""" 53 | if cfg.checkpoint is not None: 54 | # load checkpoint 55 | checkpoint = torch.load(cfg.checkpoint, map_location=device) 56 | # remove module. from checkpoint keys 57 | model_state_dict = checkpoint["model_state_dict"] 58 | model_state_dict = { 59 | k.replace("module.", ""): v for k, v in model_state_dict.items()} 60 | # to prevent errors when text encoder is not in chkpt 61 | model.load_state_dict(model_state_dict, strict=False) 62 | if "optimizer_state_dict" in checkpoint: 63 | opt.load_state_dict(checkpoint["optimizer_state_dict"]) 64 | if "epoch" in checkpoint: 65 | cfg.trainer.epoch_start = checkpoint["epoch"] 66 | print(f"Loaded checkpoint from {cfg.checkpoint}") 67 | return model, opt 68 | -------------------------------------------------------------------------------- /models/sbert.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extraction of Sentence-Level Embeddings with SBERT models. 3 | """ 4 | from typing import List, Optional, Union 5 | 6 | import torch 7 | from sentence_transformers import SentenceTransformer 8 | from torch.nn.utils.rnn import pad_sequence 9 | 10 | 11 | def make_sentence_model( 12 | model_name: str = "all-mpnet-base-v2", 13 | root_path: Optional[str] = None, 14 | ) -> SentenceTransformer: 15 | """ 16 | Setup SentenceTransformer model. 17 | 18 | Args: 19 | model_name (str): name of the model to use. 20 | root_path (Optional[str]): path to the root directory of the model. 21 | """ 22 | if root_path is not None: 23 | model = SentenceTransformer(root_path + model_name) 24 | else: 25 | model = SentenceTransformer(model_name) 26 | return model 27 | 28 | 29 | def extract_sentence_embeddings( 30 | model: SentenceTransformer, 31 | sentences: Union[str, List[str]], 32 | device: torch.device, 33 | ) -> torch.Tensor: 34 | """ 35 | Extract sentence embeddings. 36 | 37 | Args: 38 | model (SentenceTransformer): SentenceTransformer model. 39 | sentences (Union[str, List[str]]): list of sentences to encode. 40 | device (torch.device): device to use for the model. 41 | 42 | Returns: 43 | torch.Tensor: tensor of sentence-level embeddings. 44 | """ 45 | batch_size = 1 46 | if isinstance(sentences, List): 47 | batch_size = len(sentences) 48 | try: 49 | embeddings = model.encode( 50 | sentences, 51 | batch_size=batch_size, 52 | show_progress_bar=False, 53 | convert_to_tensor=True, 54 | device=device, 55 | ) 56 | except AttributeError: 57 | embeddings = model.module.encode( 58 | sentences, 59 | batch_size=batch_size, 60 | show_progress_bar=False, 61 | convert_to_tensor=True, 62 | device=device, 63 | ) 64 | return embeddings 65 | 66 | 67 | def extract_token_embeddings( 68 | model: SentenceTransformer, 69 | sentences: Union[str, List[str]], 70 | device: torch.device, 71 | ) -> torch.Tensor: 72 | """ 73 | Extract embeddings at token level. 74 | 75 | Args: 76 | model (SentenceTransformer): SentenceTransformer model. 77 | sentences (Union[str, List[str]]): list of sentences to encode. 78 | device (torch.device): device to use for the model. 79 | 80 | Returns: 81 | torch.Tensor: tensor of token-level embeddings (padded). 
82 | """ 83 | batch_size = len(sentences) if isinstance(sentences, List) else 1 84 | embeddings = model.encode( 85 | sentences, 86 | batch_size=batch_size, 87 | output_value="token_embeddings", 88 | show_progress_bar=False, 89 | device=device, 90 | ) 91 | return pad_sequence(embeddings, batch_first=True, padding_value=0) 92 | -------------------------------------------------------------------------------- /augmentations/text_augment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for augmenting text data. 3 | Implemented: 4 | - dropping words 5 | - swap words 6 | - shuffle sentences 7 | """ 8 | from typing import List 9 | 10 | import numpy as np 11 | 12 | 13 | class DropWords(object): 14 | """ 15 | Drop words from a sentence. 16 | 17 | Args: 18 | p_sentence (float): probability of augmenting a sentence 19 | p_word (float): probability of dropping a word 20 | """ 21 | def __init__( 22 | self, 23 | p_sentence: float = 0.5, 24 | p_word: float = 0.3, 25 | ) -> None: 26 | self.p_sentence = p_sentence 27 | self.p_word = p_word 28 | 29 | def _drop_words(self, sentence: str) -> str: 30 | # split sentence into words 31 | words = np.array(sentence.split()) 32 | if len(words) == 1: 33 | return sentence 34 | kept_words = -1 35 | while kept_words < 1: 36 | # assign a random probability to each word 37 | words_probs = np.random.rand(len(words)) 38 | # get indices of words to keep 39 | kept_indices = np.where(words_probs >= self.p_word)[0] 40 | kept_words = len(kept_indices) 41 | # keep only the words with indices in kept_indices 42 | kept_words = words[kept_indices] 43 | return " ".join(kept_words) 44 | 45 | def __call__(self, sentence: str) -> str: 46 | if np.random.rand() < self.p_sentence: 47 | return self._drop_words(sentence) 48 | return sentence 49 | 50 | 51 | class SwapWords(object): 52 | """ 53 | Swap words in a sentence. 54 | 55 | Args: 56 | nb_swaps (int): number of swaps 57 | """ 58 | def __init__(self, nb_swaps: int = 2): 59 | self.nb_swaps = nb_swaps 60 | 61 | def __call__(self, sentence: str) -> str: 62 | # split sentence into words 63 | words = sentence.split() 64 | if len(words) == 1: 65 | return sentence 66 | for _ in range(self.nb_swaps): 67 | words = self._swap_words(words) 68 | return " ".join(words) 69 | 70 | def _swap_words(self, words: List[str]) -> List[str]: 71 | random_idx_1 = np.random.randint(len(words)) 72 | random_idx_2 = np.random.randint(len(words)) 73 | # could do identity swap 74 | words[random_idx_1], words[random_idx_2] = words[random_idx_2], words[random_idx_1] 75 | return words 76 | 77 | 78 | class ShuffleWords(object): 79 | """ 80 | Shuffle words in a sentence. 81 | 82 | Args: 83 | p_shuffle (float): probability of shuffling a sentence 84 | """ 85 | def __init__(self, p_shuffle: float = 0.5): 86 | self.p_shuffle = p_shuffle 87 | 88 | def _shuffle_words(self, sentence: str) -> str: 89 | # split sentence into words 90 | words = sentence.split() 91 | if len(words) == 1: 92 | return sentence 93 | np.random.shuffle(words) 94 | return " ".join(words) 95 | 96 | def __call__(self, sentence: str) -> str: 97 | if np.random.rand() < self.p_shuffle: 98 | return self._shuffle_words(sentence) 99 | return sentence 100 | -------------------------------------------------------------------------------- /loops/retrieval_loop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python file defining the retrieval loop of the model. 
3 | """ 4 | import lmdb 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from omegaconf import DictConfig 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | from loops.retrieval import t2v_metrics, v2t_metrics 13 | from utils.matplotlib_utils import save_retrieval_vis 14 | 15 | def retrieval_loop( 16 | model: nn.Module, 17 | vis_loader: DataLoader, 18 | rgb_lmdb_env: lmdb.Environment, 19 | cfg: DictConfig, 20 | setname: str, 21 | epoch: int, 22 | ): 23 | """ 24 | T2V and V2T Retrieval Loop + eventual video visualisation saving. 25 | 26 | Args: 27 | model (nn.Module): model to train 28 | vis_loader (DataLoader): visualisation DataLoader 29 | rgb_lmdb_env (lmdb.Environment): lmdb environment for video retrieval 30 | cfg (DictConfig): config file 31 | setname (str): name of the dataset 32 | epoch (int): current epoch 33 | """ 34 | device = torch.device( 35 | f"cuda:{cfg.local_rank}" if torch.cuda.is_available() else "cpu") 36 | model.eval() 37 | all_text, all_cls_tokens = [], [] 38 | all_sentences = [] 39 | video_names, sub_starts, sub_ends = [], [], [] 40 | with torch.no_grad(): 41 | for _, batch in enumerate(tqdm(iter(vis_loader))): 42 | # unpack the batch 43 | subs, feats, _, _, _, _, names, starts, ends = batch 44 | feats = feats.to(device) 45 | # retrieval forward pass 46 | cls_tokens, sentence_embds = model.forward_sentret(feats, subs) if not cfg.distributed \ 47 | else model.module.forward_sentret(feats, subs) 48 | all_text.append(sentence_embds) 49 | all_cls_tokens.append(cls_tokens) 50 | all_sentences.extend(subs) 51 | video_names.extend(names) 52 | sub_starts.extend(starts) 53 | sub_ends.extend(ends) 54 | all_text = torch.cat(all_text, dim=0) 55 | all_cls_tokens = torch.cat(all_cls_tokens, dim=0) 56 | # compute similarities st sims[i, j] = 57 | sims = all_text @ all_cls_tokens.T 58 | sims = sims.detach().cpu().numpy() 59 | v2t, ranks = v2t_metrics(sims) 60 | t2v, _ = t2v_metrics(sims) 61 | 62 | if cfg.worst_retrieval: 63 | # get the worst retrieval cases 64 | indices = np.argsort(-ranks) 65 | # reorder sims 66 | sims = sims[:, indices] # first order columns 67 | sims = sims[indices, :] # then order rows 68 | all_sentences = np.array(all_sentences)[indices] 69 | sub_starts = np.array(sub_starts)[indices] 70 | sub_ends = np.array(sub_ends)[indices] 71 | video_names = np.array(video_names)[indices] 72 | if cfg.nb_vis > 0: 73 | save_retrieval_vis( 74 | cfg=cfg, 75 | sim_matrix=sims, 76 | all_sentences=all_sentences, 77 | video_names=video_names, 78 | sub_starts=sub_starts, 79 | sub_ends=sub_ends, 80 | rgb_lmdb_env=rgb_lmdb_env, 81 | setname=setname, 82 | epoch=epoch, 83 | k=5 if not cfg.worst_retrieval else 20, 84 | text_only=False, 85 | ) 86 | return v2t, t2v 87 | -------------------------------------------------------------------------------- /misc/process_cslr_json/remove_star_annots_from_csv.py: -------------------------------------------------------------------------------- 1 | """Remove star annotations from CSLR csv files.""" 2 | import argparse 3 | import os 4 | 5 | import pandas as pd 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | "--csv_root", 11 | type=str, 12 | required=True, 13 | ) 14 | args = parser.parse_args() 15 | 16 | save_dir = os.path.join( 17 | args.csv_root.replace("with_stars_", "") 18 | ) 19 | if not os.path.exists(save_dir): 20 | os.makedirs(save_dir) 21 | print(f"Created directory {save_dir}") 22 | 23 | for split in ["train", "val", "test"]: 24 | csv_dir 
= os.path.join(args.csv_root, f"0/{split}") 25 | for csv_file in os.listdir(csv_dir): 26 | if csv_file.endswith(".csv"): 27 | # open file with pandas 28 | df = pd.read_csv(os.path.join(csv_dir, csv_file)) 29 | 30 | # loop over rows 31 | for idx, row in df.iterrows(): 32 | # remove star annotations 33 | if "*" in row["approx gloss sequence"]: 34 | new_row = row["approx gloss sequence"] 35 | new_row = new_row.replace("]", "[").split("[") 36 | annots = new_row[::2] 37 | times = new_row[1::2] 38 | filtered_annots = [] 39 | filtered_times = [] 40 | for annot, time in zip(annots, times): 41 | # check if there is a lexical annotation 42 | has_lexical = any( 43 | ["*" not in a for a in annot.split("/")]) 44 | if has_lexical: 45 | annot = "/".join([ 46 | a for a in annot.split("/") 47 | if "*" not in a 48 | ]) 49 | filtered_annots.append(annot) 50 | filtered_times.append(time) 51 | else: 52 | continue 53 | if len(filtered_annots) == 0: 54 | # can drop the row in question entirely 55 | df.drop(idx, inplace=True) 56 | else: 57 | # replace the row with the filtered annotations 58 | new_annots = "" 59 | for annot, time in zip(filtered_annots, filtered_times): 60 | new_annots += f"{annot}[{time}]" 61 | df.loc[idx, "approx gloss sequence"] = new_annots 62 | # save the dataframe 63 | if not os.path.exists(os.path.join(save_dir, f"0/{split}")): 64 | os.makedirs(os.path.join(save_dir, f"0/{split}")) 65 | print(f"Created directory {save_dir}0/{split}") 66 | df.to_csv( 67 | os.path.join( 68 | save_dir, 69 | f"0/{split}/{csv_file}" 70 | ), 71 | index=False 72 | ) 73 | -------------------------------------------------------------------------------- /misc/process_cslr_json/run_pipeline.py: -------------------------------------------------------------------------------- 1 | """ 2 | Running full pipeline for CSLR data pre-processing. 
3 | """ 4 | import os 5 | 6 | from argparse import ArgumentParser 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = ArgumentParser() 11 | parser.add_argument( 12 | "--input_dir", 13 | type=str, 14 | help="Path to the directory containing CSLR json data.", 15 | required=True, 16 | ) 17 | parser.add_argument( 18 | "--output_dir", 19 | type=str, 20 | help="Path to the output directory.", 21 | required=True, 22 | ) 23 | parser.add_argument( 24 | "--subs_dir", 25 | type=str, 26 | help="Path to the directory containing subs files (manually aligned).", 27 | required=True, 28 | ) 29 | parser.add_argument( 30 | "--subset2episode", 31 | type=str, 32 | help="Path to json file containing subset to episode mapping.", 33 | required=True, 34 | ) 35 | args = parser.parse_args() 36 | args.output_dir = args.output_dir.rstrip("/") 37 | # the following assumes that the output_directory last 8 characters are date in format DD.MM.YY 38 | output_dir2 = args.output_dir[:-8] + "extended_boundaries_" + args.output_dir[-8:] 39 | output_dir3 = output_dir2[:-8] + "fix_alignment_" + output_dir2[-8:] 40 | 41 | assignment_command = "python misc/process_cslr_json/preprocess_raw_json_annotations.py " 42 | assignment_command += f"--output_dir {args.output_dir} --input_dir {args.input_dir} " 43 | assignment_command += f"--subs_dir {args.subs_dir} --subset2episode {args.subset2episode}" 44 | 45 | # run assignment command 46 | print(f"Running command: {assignment_command}") 47 | os.system(assignment_command) 48 | 49 | fix_boundaries_command = "python misc/process_cslr_json/fix_boundaries.py " 50 | fix_boundaries_command += f"--csv_file {args.output_dir}" 51 | # run fix boundaries command 52 | print(f"Running command: {fix_boundaries_command}") 53 | os.system(fix_boundaries_command) 54 | 55 | fix_alignment_command = "python misc/process_cslr_json/fix_alignment.py " 56 | fix_alignment_command += f"--csv_file {output_dir2}" 57 | fix_alignment_command2 = "python misc/process_cslr_json/preprocess_raw_json_annotations.py " 58 | fix_alignment_command2 += f"--output_dir {output_dir3} --input_dir {args.input_dir} " 59 | fix_alignment_command2 += f"--subs_dir {output_dir2} --misalignment_fix " 60 | fix_alignment_command2 += f"--subset2episode {args.subset2episode}" 61 | # run fix alignment command 62 | print(f"Running command: {fix_alignment_command}") 63 | os.system(fix_alignment_command) 64 | print(f"Running command: {fix_alignment_command2}") 65 | os.system(fix_alignment_command2) 66 | 67 | remove_stars_command = "python misc/process_cslr_json/remove_star_annots_from_csv.py " 68 | remove_stars_command += f"--csv_root {output_dir2}" 69 | remove_stars_command2 = "python misc/process_cslr_json/remove_star_annots_from_csv.py " 70 | remove_stars_command2 += f"--csv_root {output_dir3}" 71 | # run remove stars command 72 | print(f"Running command: {remove_stars_command}") 73 | os.system(remove_stars_command) 74 | print(f"Running command: {remove_stars_command2}") 75 | os.system(remove_stars_command2) 76 | -------------------------------------------------------------------------------- /utils/wandb_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for wandb 3 | """ 4 | from typing import Union 5 | 6 | import wandb 7 | from numpy import ndarray 8 | from omegaconf import DictConfig, OmegaConf 9 | from torch import Tensor 10 | 11 | 12 | def wandb_run_name(cfg: DictConfig) -> str: 13 | """ 14 | Setup wandb run name with some hyperparameters 15 | """ 16 | run_name = 
cfg.run_name 17 | return run_name 18 | 19 | 20 | def wandb_setup(cfg: DictConfig, setname: str="cslr2") -> None: 21 | """ 22 | Initialize wandb 23 | """ 24 | if cfg.do_print: 25 | wandb.config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) 26 | wandb.init( 27 | project=setname, 28 | dir=cfg.paths.log_dir, 29 | name=wandb_run_name(cfg), 30 | ) 31 | 32 | 33 | def log_retrieval_performances( 34 | train_v2t: Union[Tensor, ndarray], 35 | train_t2v: Union[Tensor, ndarray], 36 | val_v2t: Union[Tensor, ndarray], 37 | val_t2v: Union[Tensor, ndarray], 38 | epoch: int, 39 | pl_as_subtitles: bool = False, 40 | ): 41 | """ 42 | Log retrieval performances 43 | """ 44 | prefix = "pl_as_subtitles_" if pl_as_subtitles else "" 45 | wandb.log({ 46 | f"{prefix}train_v2t_R1": train_v2t["R1"], 47 | f"{prefix}train_v2t_R5": train_v2t["R5"], 48 | f"{prefix}train_v2t_R10": train_v2t["R10"], 49 | f"{prefix}train_v2t_R50": train_v2t["R50"], 50 | f"{prefix}train_v2t_medr": train_v2t["MedR"], 51 | f"{prefix}train_v2t_meanr": train_v2t["MeanR"], 52 | f"{prefix}train_v2t_geometric_mean_R1-R5-R10": train_v2t["geometric_mean_R1-R5-R10"], 53 | f"{prefix}train_t2v_R1": train_t2v["R1"], 54 | f"{prefix}train_t2v_R5": train_t2v["R5"], 55 | f"{prefix}train_t2v_R10": train_t2v["R10"], 56 | f"{prefix}train_t2v_R50": train_t2v["R50"], 57 | f"{prefix}train_t2v_medr": train_t2v["MedR"], 58 | f"{prefix}train_t2v_meanr": train_t2v["MeanR"], 59 | f"{prefix}train_t2v_geometric_mean_R1-R5-R10": train_t2v["geometric_mean_R1-R5-R10"], 60 | f"{prefix}val_v2t_R1": val_v2t["R1"], 61 | f"{prefix}val_v2t_R5": val_v2t["R5"], 62 | f"{prefix}val_v2t_R10": val_v2t["R10"], 63 | f"{prefix}val_v2t_R50": val_v2t["R50"], 64 | f"{prefix}val_v2t_medr": val_v2t["MedR"], 65 | f"{prefix}val_v2t_meanr": val_v2t["MeanR"], 66 | f"{prefix}val_v2t_geometric_mean_R1-R5-R10": val_v2t["geometric_mean_R1-R5-R10"], 67 | f"{prefix}val_t2v_R1": val_t2v["R1"], 68 | f"{prefix}val_t2v_R5": val_t2v["R5"], 69 | f"{prefix}val_t2v_R10": val_t2v["R10"], 70 | f"{prefix}val_t2v_R50": val_t2v["R50"], 71 | f"{prefix}val_t2v_medr": val_t2v["MedR"], 72 | f"{prefix}val_t2v_meanr": val_t2v["MeanR"], 73 | f"{prefix}val_t2v_geometric_mean_R1-R5-R10": val_t2v["geometric_mean_R1-R5-R10"], 74 | "epoch": epoch, 75 | }) 76 | 77 | 78 | def log_test_retrieval_performances( 79 | test_v2t: Union[Tensor, ndarray], 80 | test_t2v: Union[Tensor, ndarray], 81 | epoch: int, 82 | ): 83 | """ 84 | Log Retrieval Performances (for the test set) 85 | """ 86 | wandb.log({ 87 | "test_v2t_R1": test_v2t["R1"], 88 | "test_v2t_R5": test_v2t["R5"], 89 | "test_v2t_R10": test_v2t["R10"], 90 | "test_v2t_R50": test_v2t["R50"], 91 | "test_v2t_medr": test_v2t["MedR"], 92 | "test_v2t_meanr": test_v2t["MeanR"], 93 | "test_v2t_geometric_mean_R1-R5-R10": test_v2t["geometric_mean_R1-R5-R10"], 94 | "test_t2v_R1": test_t2v["R1"], 95 | "test_t2v_R5": test_t2v["R5"], 96 | "test_t2v_R10": test_t2v["R10"], 97 | "test_t2v_R50": test_t2v["R50"], 98 | "test_t2v_medr": test_t2v["MedR"], 99 | "test_t2v_meanr": test_t2v["MeanR"], 100 | "test_t2v_geometric_mean_R1-R5-R10": test_t2v["geometric_mean_R1-R5-R10"], 101 | "epoch": epoch, 102 | }) 103 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 
| dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | runs/ 163 | .runs/ 164 | .{runs}/ 165 | runs 166 | misc/plots/ 167 | .vscode/ 168 | 169 | bobsl 170 | t5_checkpoint 171 | logs 172 | -------------------------------------------------------------------------------- /loss/hn_nce.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hard-Negative NCE loss for contrastive learning. 3 | https://arxiv.org/pdf/2301.02280.pdf 4 | """ 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class HardNegativeNCE(nn.Module): 13 | """ 14 | Hard Negative NCE loss for contrastive learning. 15 | """ 16 | def __init__(self, temperature: float = 0.07, alpha: float = 1.0, beta: float = 0.0): 17 | """ 18 | Args: 19 | temperature: temperature for the softmax 20 | alpha: rescaling factor for positiver terms 21 | beta: concentration parameter 22 | 23 | Note: 24 | alpha = 1 and beta = 0 corresponds to the original Info-NCE loss 25 | """ 26 | super(HardNegativeNCE, self).__init__() 27 | self.temperature = temperature 28 | self.alpha = alpha 29 | self.beta = beta 30 | 31 | def forward( 32 | self, 33 | video_embds: torch.Tensor, 34 | text_embds: torch.Tensor, 35 | labels: Optional[torch.Tensor] = None, 36 | debug_test: bool = False, 37 | ) -> float: 38 | """ 39 | Args: 40 | video_embds: (batch_size, video_embd_dim) 41 | text_embds: (batch_size, text_embd_dim) 42 | debug_test: if True, then also compute Info-NCE loss 43 | """ 44 | batch_size = video_embds.size(0) 45 | # computation of the similarity matrix 46 | sim_matrix = video_embds @ text_embds.T # (batch_size, batch_size) 47 | # scale the similarity matrix with the temperature 48 | sim_matrix = sim_matrix / self.temperature 49 | sim_matrix = sim_matrix.float() 50 | if labels is not None: 51 | mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) 52 | mask = mask & (~torch.eye( 53 | len(sim_matrix), 54 | device=sim_matrix.device, 55 | dtype=mask.dtype 56 | )).bool() 57 | sim_matrix = sim_matrix.masked_fill(mask, float('-inf')) 58 | 59 | # sim_matrix(i, j) = 60 | 61 | nominator = torch.diagonal(sim_matrix) 62 | if debug_test: 63 | # V2T 64 | denominator = torch.logsumexp(sim_matrix, dim=1) 65 | 66 | # T2V 67 | denominator2 = torch.logsumexp(sim_matrix, dim=0) 68 | beta_sim = self.beta * sim_matrix 69 | w_v2t = (batch_size - 1) * torch.exp(beta_sim) / \ 70 | (torch.exp(beta_sim).sum(dim=1) - torch.exp(torch.diagonal(beta_sim))) 71 | w_t2v = (batch_size - 1) * torch.exp(beta_sim) / \ 72 | (torch.exp(beta_sim).sum(dim=0) - torch.exp(torch.diagonal(beta_sim))) 73 | # replace the diagonal terms of w_v2t and w_t2v with alpha 74 | w_v2t[range(batch_size), range(batch_size)] = self.alpha 75 | w_t2v[range(batch_size), range(batch_size)] = self.alpha 76 | denominator_v2t = torch.log((torch.exp(sim_matrix) * w_v2t).sum(dim=1)) 77 | denominator_t2v = torch.log((torch.exp(sim_matrix) * w_t2v).sum(dim=0)) 78 | hn_nce_loss = (denominator_v2t - nominator).mean() + (denominator_t2v - nominator).mean() 79 | if debug_test: 80 | info_nce_loss = (denominator - nominator).mean() + (denominator2 - nominator).mean() 81 | print(f"hn_nce_loss: {hn_nce_loss}") 82 | print(f"info_nce_loss: {info_nce_loss}") 83 | return hn_nce_loss 84 | 85 | 86 | if __name__ == "__main__": 87 | # sanity check that the loss is working 88 | # looking whether the loss is equal to the Info-NCE loss 89 | # 
when alpha = 1 and beta = 0 90 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 91 | for _ in range(10): 92 | in_video_embd = torch.randn(1024, 512).to(device) 93 | in_text_embd = torch.randn(1024, 512).to(device) 94 | 95 | # normalize 96 | in_video_embd = F.normalize(in_video_embd, dim=-1) 97 | in_text_embd = F.normalize(in_text_embd, dim=-1) 98 | loss_fn = HardNegativeNCE(beta=0.0, alpha=1.0) 99 | loss_fn(in_video_embd, in_text_embd, debug_test=True) 100 | -------------------------------------------------------------------------------- /utils/synonyms.py: -------------------------------------------------------------------------------- 1 | """Functions related to synonyms.""" 2 | from typing import List, Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def synonym_combine( 9 | labels: Union[torch.tensor, np.ndarray], 10 | probs: Union[torch.tensor, np.ndarray], 11 | synonyms: dict, 12 | verbose: bool = False, 13 | ): 14 | """ 15 | Function aggregating probabilities of synonyms. 16 | 17 | Args: 18 | labels (Union[torch.tensor, np.ndarray]): list of labels. 19 | probs (Union[torch.tensor, np.ndarray]): list of probabilities. 20 | synonyms (dict): dictionary of synonyms. 21 | verbose (bool): whether to print the output. 22 | 23 | Returns: 24 | Union[torch.tensor, np.ndarray]: aggregated probabilities. 25 | Union[torch.tensor, np.ndarray]: aggregated labels. 26 | """ 27 | change = False 28 | new_probs = [] 29 | for anchor_idx, anchor in enumerate(labels): 30 | try: 31 | anchor = anchor.replace("-", " ") 32 | syns = synonyms[anchor] 33 | if verbose: 34 | print(anchor, syns) 35 | anchor_new_prob = 0 36 | for checked_idx, checked_label in enumerate(labels): 37 | checked_label = checked_label.replace("-", " ") 38 | if checked_label in syns: 39 | if verbose: 40 | print(checked_label) 41 | anchor_new_prob += probs[checked_idx] 42 | if checked_idx != anchor_idx: 43 | change = True 44 | new_probs.append(anchor_new_prob) 45 | except KeyError: 46 | # prediction not in the synonym list 47 | new_probs.append(probs[anchor_idx]) 48 | if change: 49 | # need to sort 50 | sorted_indices = np.argsort(- 1 * np.array(new_probs)) 51 | if verbose: 52 | print(labels, new_probs) 53 | new_probs = np.array(new_probs)[sorted_indices] 54 | labels = np.array(labels)[sorted_indices] 55 | if verbose: 56 | print(labels, new_probs) 57 | else: 58 | new_probs = np.array(new_probs) 59 | labels = np.array(labels) 60 | return new_probs, labels 61 | 62 | 63 | def fix_synonyms_dict(synonyms: dict, verbose: bool = False) -> dict: 64 | """ 65 | Make sure that the synonyms dictionary satisfies the following: 66 | - if a is a synonym of b, then b is a synonym of a 67 | - a is a synonym of a 68 | 69 | Args: 70 | synonyms (dict): dictionary of synonyms. 71 | verbose (bool): whether to print the output. 72 | 73 | Returns: 74 | dict: updated dictionary of synonyms. 
75 | """ 76 | change_count = 0 77 | syn_change_count = 0 78 | for word, syns in synonyms.items(): 79 | if word not in syns: 80 | change_count += 1 81 | syns.append(word) 82 | # need to check that for each synonym in the list, the word is in the list of synonyms for that synonym 83 | for syn in syns: 84 | if word not in synonyms[syn]: 85 | synonyms[syn].append(word) 86 | syn_change_count += 1 87 | if verbose: 88 | print(f"Added {change_count} words to their own synonyms list") 89 | print(f"Total number of words: {len(synonyms)}") 90 | print(f"Added {syn_change_count} so that a is a syn of b is equivalent to b is a syn of a") 91 | return synonyms 92 | 93 | 94 | def extend(labels: List[str], synonyms: dict) -> List[str]: 95 | """ 96 | Extend a list of labels using synonyms 97 | 98 | Args: 99 | labels (List[str]): list of labels 100 | synonyms (dict): synonym dictionary 101 | 102 | Returns: 103 | new_labels (List[str]): list of labels with synonyms 104 | """ 105 | new_labels = [] 106 | for lbl in labels: 107 | if synonyms is not None and lbl in synonyms.keys(): 108 | temp_synonyms = synonyms[lbl] 109 | if lbl not in temp_synonyms: 110 | temp_synonyms.append(lbl) 111 | new_labels.extend(temp_synonyms) 112 | else: 113 | # no synonyms or no synonyms for this word 114 | new_labels.extend([lbl]) 115 | return new_labels 116 | -------------------------------------------------------------------------------- /misc/process_cslr_json/fix_boundaries.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python script to fix the boundaries of subtitles in the CSLR dataset. 3 | 4 | The fix is meant to change boundaries of subtitles such that all annotations that are 5 | associated with it completely fall within (intersection == 1). 
6 | """ 7 | import argparse 8 | import glob 9 | import os 10 | 11 | import pandas as pd 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--csv_file", 18 | type=str, 19 | required=True, 20 | help="Path to the CSV file containing the subtitles and annotations.", 21 | ) 22 | parser.add_argument( 23 | "--debug", 24 | action="store_true", 25 | ) 26 | args = parser.parse_args() 27 | # need to remove "/" from the end of the path 28 | args.csv_file = args.csv_file.rstrip("/") 29 | names = os.path.basename(args.csv_file) 30 | 31 | splits = ["train", "val", "test"] 32 | total_changes = { 33 | "all": {"start": 0, "end": 0}, 34 | "train": {"start": 0, "end": 0}, 35 | "val": {"start": 0, "end": 0}, 36 | "test": {"start": 0, "end": 0}, 37 | } 38 | for split in splits: 39 | print(f"Processing {split} split...") 40 | all_files = glob.glob(f"{args.csv_file}/0/{split}/*.csv") 41 | print(f"Found {len(all_files)} files.") 42 | for csv_file in all_files: 43 | # open csv file in question 44 | gt_df = pd.read_csv(csv_file, delimiter=",") 45 | # sort by start_sub 46 | gt_df = gt_df.sort_values(by=["start_sub"], ascending=True) 47 | gt_glosses = gt_df["approx gloss sequence"].tolist() 48 | left_boundary = [] 49 | right_boundary = [] 50 | for gt_gloss in gt_glosses: 51 | if isinstance(gt_gloss, float): 52 | # no gt 53 | left_boundary.append("") 54 | right_boundary.append("") 55 | else: 56 | gt_gloss = gt_gloss.replace("]", "[").replace("'", "") 57 | gt_annots = gt_gloss.split("[")[:-1] 58 | gt_timings = gt_gloss.replace(" ", "/").replace("--", "-") 59 | gt_timings = gt_timings.split("[")[:-1] 60 | gt_times = gt_timings[1::2] 61 | gt_annots = gt_annots[::2] 62 | gt_times, gt_annots = zip(*sorted(zip(gt_times, gt_annots))) 63 | left_boundary.append( 64 | f"{gt_annots[0].strip().replace(' ', '-')} {gt_times[0]}" 65 | ) 66 | right_boundary.append( 67 | f"{gt_annots[-1].strip().replace(' ', '-')} {gt_times[-1]}" 68 | ) 69 | gt_df["left_boundary"] = left_boundary 70 | gt_df["right_boundary"] = right_boundary 71 | 72 | # need to loop over all subtitles now 73 | starts = gt_df["start_sub"].tolist() 74 | ends = gt_df["end_sub"].tolist() 75 | subs = gt_df["english sentence"].tolist() 76 | assert ( 77 | len(starts) 78 | == len(ends) 79 | == len(subs) 80 | == len(left_boundary) 81 | == len(right_boundary) 82 | ) 83 | 84 | updated_starts, updated_ends = [], [] 85 | for start, end, sub, left_bound, right_bound in zip( 86 | starts, ends, subs, left_boundary, right_boundary 87 | ): 88 | if isinstance(left_bound, float) or isinstance(right_bound, float): 89 | continue 90 | left_bound = left_bound.split(" ") 91 | right_bound = right_bound.split(" ") 92 | left_bound = left_bound[1].split("-")[0] 93 | right_bound = right_bound[1].split("-")[-1] 94 | left_bound = float(left_bound) 95 | right_bound = float(right_bound) 96 | if left_bound < start: 97 | total_changes["all"]["start"] += 1 98 | total_changes[split]["start"] += 1 99 | if args.debug: 100 | print(f"Start: {left_bound} -> {start}") 101 | if right_bound > end: 102 | total_changes["all"]["end"] += 1 103 | total_changes[split]["end"] += 1 104 | if args.debug: 105 | print(f"End: {right_bound} -> {end}") 106 | updated_starts.append(min(start, left_bound)) 107 | updated_ends.append(max(end, right_bound)) 108 | 109 | gt_df["start_sub"] = updated_starts 110 | gt_df["end_sub"] = updated_ends 111 | # save the csv file to another location 112 | new_name = names[:-8] + "extended_boundaries_" + names[-8:] 113 | new_location = 
csv_file.replace(names, new_name) 114 | if not os.path.exists(os.path.dirname(new_location)): 115 | os.makedirs(os.path.dirname(new_location)) 116 | print(f"Created directory {os.path.dirname(new_location)}") 117 | gt_df.to_csv(new_location, index=False) 118 | print(total_changes) 119 | -------------------------------------------------------------------------------- /loops/train_loop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python file defining the training loop of the model. 3 | """ 4 | from typing import List, Optional 5 | 6 | import torch 7 | import torch.nn as nn 8 | import wandb 9 | from omegaconf import DictConfig 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from utils.gather import all_gather 14 | 15 | 16 | def train_loop( 17 | model: nn.Module, 18 | train_loader: DataLoader, 19 | opt: torch.optim.Optimizer, 20 | sent_ret_loss_fn: Optional[torch.nn.Module], 21 | sign_ret_loss_fn: Optional[torch.nn.Module], 22 | sign_cls_loss_fn: Optional[torch.nn.Module], 23 | epoch: int, 24 | cfg: DictConfig, 25 | ) -> List[float]: 26 | """ 27 | Training Loop. 28 | 29 | Args: 30 | model (nn.Module): model to train 31 | train_loader (DataLoader): training DataLoader 32 | opt (torch.optim.Optimizer): optimizer 33 | sent_ret_loss_fn (Optional[torch.nn.Module]): sentence retrieval loss function 34 | sign_ret_loss_fn (Optional[torch.nn.Module]): sign retrieval loss function 35 | sign_cls_loss_fn (Optional[torch.nn.Module]): sign classification loss function 36 | epoch (int): current epoch 37 | cfg (DictConfig): config file 38 | 39 | Returns: 40 | List[float]: train loss, sent_ret, sign_ret, sign_cls 41 | """ 42 | model.train() 43 | total_train_loss = 0 44 | total_sent_ret, total_sign_ret, total_sign_cls = 0, 0, 0 45 | device = torch.device( 46 | f"cuda:{cfg.local_rank}" if torch.cuda.is_available() else "cpu") 47 | pbar = tqdm(iter(train_loader)) if cfg.do_print else iter(train_loader) 48 | for batch_idx, batch in enumerate(pbar): 49 | model.zero_grad() 50 | opt.zero_grad() 51 | # unpack the batch 52 | subs, feats, target_indices, target_labels, target_word_embds, _, _, _, _ = batch 53 | feats = feats.to(device) 54 | word_embds = torch.cat(target_word_embds).to(device) \ 55 | if target_word_embds[0] is not None else None 56 | # forward pass on the model 57 | cls_tokens, video_tokens, sentence_embds, word_embds, output_tensor = model( 58 | video_features=feats, 59 | subtitles=subs, 60 | word_embds=word_embds, 61 | ) 62 | 63 | # computation of the different terms of the loss 64 | # computation of SentRet 65 | if sent_ret_loss_fn is not None: 66 | if cfg.distributed: 67 | cls_tokens = torch.cat(all_gather(cls_tokens), dim=0) 68 | sentence_embds = torch.cat(all_gather(sentence_embds), dim=0) 69 | sent_ret = sent_ret_loss_fn(cls_tokens, sentence_embds) 70 | else: 71 | sent_ret = torch.tensor([0]).to(device) 72 | 73 | # get the indices on where to compute the SignCls and SignRet losses 74 | if sign_cls_loss_fn is not None or sign_ret_loss_fn is not None: 75 | target_indices_batch_idx = torch.repeat_interleave( 76 | input=torch.arange(len(subs)), 77 | repeats=torch.tensor( 78 | [len(target_index) for target_index in target_indices] 79 | ), 80 | ) 81 | target_indices = torch.cat(target_indices) 82 | 83 | # computation of SignCls loss 84 | if sign_cls_loss_fn is not None: 85 | target_labels = torch.cat(target_labels).to(device, torch.long) 86 | predicted_logits = output_tensor[ 87 | target_indices_batch_idx, 
target_indices 88 | ] 89 | if cfg.loss.sign_cls._target_ == "torch.nn.BCEWithLogitsLoss": 90 | one_hot_target = torch.zeros_like(predicted_logits).to(device) 91 | one_hot_target[torch.arange( 92 | len(target_labels)), target_labels] = 1 93 | temp_target_labels, target_labels = target_labels, one_hot_target 94 | sign_cls = sign_cls_loss_fn(predicted_logits, target_labels) 95 | if cfg.loss.sign_cls._target_ == "torch.nn.BCEWithLogitsLoss": 96 | target_labels = temp_target_labels 97 | else: 98 | sign_cls = torch.tensor([0]).to(device) 99 | 100 | # computation of SignRet loss 101 | if sign_ret_loss_fn is not None: 102 | if sign_cls_loss_fn is None: 103 | target_labels = torch.cat(target_labels).to(device, torch.long) 104 | sign_ret = sign_ret_loss_fn( 105 | video_tokens[target_indices_batch_idx, target_indices], 106 | word_embds, 107 | labels=target_labels, 108 | ) 109 | else: 110 | sign_ret = torch.tensor([0]).to(device) 111 | 112 | # weighted sum of losses 113 | total_loss = cfg.loss.lda_sent_ret * sent_ret + \ 114 | cfg.loss.lda_sign_ret * sign_ret + \ 115 | cfg.loss.lda_sign_cls * sign_cls 116 | 117 | total_loss.backward() 118 | opt.step() 119 | torch.cuda.synchronize() 120 | 121 | # prepare for printing / logging 122 | current_loss = total_loss.detach().item() 123 | total_train_loss += current_loss 124 | current_sent_ret = sent_ret.detach().item() 125 | total_sent_ret += current_sent_ret 126 | current_sign_ret = sign_ret.detach().item() 127 | total_sign_ret += current_sign_ret 128 | current_sign_cls = sign_cls.detach().item() 129 | total_sign_cls += current_sign_cls 130 | if cfg.do_print: 131 | pbar.set_postfix( 132 | { 133 | "Loss": f"{current_loss:.2f}", 134 | "SentRet": f"{current_sent_ret:.2f}", 135 | "SignRet": f"{current_sign_ret:.2f}", 136 | "SignCls": f"{current_sign_cls:.2f}", 137 | } 138 | ) 139 | if cfg.do_print: 140 | wandb.log( 141 | { 142 | "train_loss_iter": current_loss, 143 | "train_sent_ret_iter": current_sent_ret, 144 | "train_sign_ret_iter": current_sign_ret, 145 | "train_sign_cls_iter": current_sign_cls, 146 | "train_iter": epoch * len(train_loader) + batch_idx, 147 | } 148 | ) 149 | return total_train_loss / len(train_loader), total_sent_ret / len(train_loader), \ 150 | total_sign_ret / len(train_loader), total_sign_cls / len(train_loader) 151 | -------------------------------------------------------------------------------- /loops/val_loop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python file defining the validation loop of the model. 3 | """ 4 | from typing import List, Optional 5 | 6 | import torch 7 | import torch.nn as nn 8 | import wandb 9 | from omegaconf import DictConfig 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from utils.gather import all_gather 14 | 15 | def val_loop( 16 | model: nn.Module, 17 | val_loader: DataLoader, 18 | sent_ret_loss_fn: Optional[nn.Module], 19 | sign_ret_loss_fn: Optional[nn.Module], 20 | sign_cls_loss_fn: Optional[nn.Module], 21 | epoch: int, 22 | cfg: DictConfig, 23 | ) -> List[float]: 24 | """ 25 | Validation Loop. 
26 | 27 | Args: 28 | model (nn.Module): model to validate 29 | val_loader (DataLoader): validation DataLoader 30 | sent_ret_loss_fn (Optional[nn.Module]): sentence retrieval loss function 31 | sign_ret_loss_fn (Optional[nn.Module]): sign retrieval loss function 32 | sign_cls_loss_fn (Optional[nn.Module]): sign classification loss function 33 | epoch (int): current epoch 34 | cfg (DictConfig): config file 35 | 36 | Returns: 37 | List[float]: val loss, sent_ret, sign_ret, sign_cls 38 | """ 39 | model.eval() 40 | total_val_loss = 0 41 | total_sent_ret, total_sign_ret, total_sign_cls = 0, 0, 0 42 | device = torch.device( 43 | f"cuda:{cfg.local_rank}" if torch.cuda.is_available() else "cpu") 44 | pbar = tqdm(iter(val_loader)) if cfg.do_print else iter(val_loader) 45 | with torch.no_grad(): 46 | for batch_idx, batch in enumerate(pbar): 47 | # unpack the batch 48 | subs, feats, target_indices, target_labels, target_word_embds, _, _, _, _ = batch 49 | feats = feats.to(device) 50 | word_embds = torch.cat(target_word_embds).to(device) \ 51 | if target_word_embds[0] is not None else None 52 | # forward pass on the model 53 | cls_tokens, video_tokens, sentence_embds, word_embds, output_tensor = model( 54 | video_features=feats, 55 | subtitles=subs, 56 | word_embds=word_embds, 57 | ) 58 | 59 | # computation of the different terms of the loss 60 | # computation of SentRet 61 | if sent_ret_loss_fn is not None: 62 | if cfg.distributed: 63 | # need to all-gather the cls_tokens 64 | cls_tokens = torch.cat(all_gather(cls_tokens), dim=0) 65 | sentence_embds = torch.cat( 66 | all_gather(sentence_embds), dim=0) 67 | # computation of SentRet loss 68 | sent_ret = sent_ret_loss_fn(cls_tokens, sentence_embds) 69 | else: 70 | sent_ret = torch.tensor([0]).to(device) 71 | 72 | # get the indices of the target labels for SignCls and SignRet losses 73 | if sign_ret_loss_fn is not None or sign_cls_loss_fn is not None: 74 | target_labels = torch.cat(target_labels).to(device, torch.long) 75 | target_indices_batch_idx = torch.repeat_interleave( 76 | input=torch.arange(len(subs)), 77 | repeats=torch.tensor( 78 | [len(target_index) for target_index in target_indices] 79 | ), 80 | ) 81 | target_indices = torch.cat(target_indices) 82 | 83 | # computation of SignCls loss 84 | if sign_cls_loss_fn is not None: 85 | predicted_logits = output_tensor[ 86 | target_indices_batch_idx, target_indices 87 | ] 88 | if cfg.loss.sign_cls._target_ == "torch.nn.BCEWithLogitsLoss": 89 | one_hot_target = torch.zeros_like( 90 | predicted_logits).to(device) 91 | one_hot_target[torch.arange( 92 | len(target_labels)), target_labels] = 1 93 | temp_target_labels, target_labels = target_labels, one_hot_target 94 | sign_cls = sign_cls_loss_fn(predicted_logits, target_labels) 95 | if cfg.loss.sign_cls._target_ == "torch.nn.BCEWithLogitsLoss": 96 | target_labels = temp_target_labels 97 | else: 98 | sign_cls = torch.tensor([0]).to(device) 99 | 100 | # computation of SignRet loss 101 | if sign_ret_loss_fn is not None: 102 | sign_ret = sign_ret_loss_fn( 103 | video_tokens[target_indices_batch_idx, target_indices], 104 | word_embds, 105 | labels=target_labels, 106 | ) 107 | else: 108 | sign_ret = torch.tensor([0]).to(device) 109 | 110 | # weighted sum of losses 111 | total_loss = cfg.loss.lda_sent_ret * sent_ret + \ 112 | cfg.loss.lda_sign_ret * sign_ret + \ 113 | cfg.loss.lda_sign_cls * sign_cls 114 | 115 | torch.cuda.synchronize() 116 | 117 | # prepare for printing / logging 118 | current_loss = total_loss.detach().item() 119 | total_val_loss += current_loss 
120 | current_sent_ret = sent_ret.detach().item() 121 | total_sent_ret += current_sent_ret 122 | current_sign_ret = sign_ret.detach().item() 123 | total_sign_ret += current_sign_ret 124 | current_sign_cls = sign_cls.detach().item() 125 | total_sign_cls += current_sign_cls 126 | if cfg.do_print: 127 | pbar.set_postfix( 128 | { 129 | "Loss": f"{current_loss:.2f}", 130 | "SentRet": f"{current_sent_ret:.2f}", 131 | "SignRet": f"{current_sign_ret:.2f}", 132 | "SignCls": f"{current_sign_cls:.2f}", 133 | } 134 | ) 135 | if cfg.do_print: 136 | wandb.log( 137 | { 138 | "val_loss_iter": current_loss, 139 | "val_sent_ret_iter": current_sent_ret, 140 | "val_sign_ret_iter": current_sign_ret, 141 | "val_sign_cls_iter": current_sign_cls, 142 | "val_iter": epoch * len(val_loader) + batch_idx, 143 | } 144 | ) 145 | return total_val_loss / len(val_loader), total_sent_ret / len(val_loader), \ 146 | total_sign_ret / len(val_loader), total_sign_cls / len(val_loader) 147 | -------------------------------------------------------------------------------- /utils/instantiate_dataloaders.py: -------------------------------------------------------------------------------- 1 | """Functions to instantiate dataloaders with Hydra""" 2 | from functools import partial 3 | from multiprocessing import Value 4 | 5 | import hydra 6 | import torch 7 | from omegaconf import DictConfig 8 | from torch.utils.data import DataLoader, Subset, get_worker_info 9 | from tqdm import tqdm 10 | 11 | from dataset.sentence import collate_fn_padd 12 | from utils.instantiate_augmentations import text_augmentations, vid_augmentations 13 | 14 | 15 | def worker_init_fn(skip_mode: bool, worker_id) -> None: 16 | """Worker init function.""" 17 | info = get_worker_info() 18 | try: 19 | info.dataset.dataset.skip_mode = skip_mode 20 | except AttributeError: 21 | info.dataset.skip_mode = skip_mode 22 | 23 | 24 | def instantiate_dataloaders(cfg: DictConfig): 25 | """DataLoader instantiation with Hydra config.""" 26 | train_dataset = hydra.utils.instantiate( 27 | cfg.dataset, 28 | setname="train", 29 | text_augmentations=text_augmentations(cfg), 30 | video_augmentations=vid_augmentations(cfg), 31 | ) 32 | val_dataset = hydra.utils.instantiate( 33 | cfg.dataset, 34 | setname="val", 35 | ) 36 | 37 | # shuffle data if train from checkpoint 38 | # to avoid getting the same batches (since seeded runs are used for reproducibility) 39 | if cfg.checkpoint is not None: 40 | train_dataset.subtitles.shuffle() 41 | val_dataset.subtitles.shuffle() 42 | 43 | # if we want to train on a fraction of the data 44 | sampler = None 45 | if cfg.dataloader.train_data_fraction < 1: 46 | train_dataset = Subset( 47 | train_dataset, 48 | torch.randperm( 49 | len(train_dataset) 50 | )[:int(len(train_dataset) * cfg.dataloader.train_data_fraction)], 51 | ) 52 | if cfg.distributed: 53 | assert cfg.world_size is not None and cfg.rank is not None 54 | sampler = torch.utils.data.distributed.DistributedSampler( 55 | train_dataset, num_replicas=cfg.world_size, rank=cfg.rank, 56 | ) 57 | cfg.dataloader.dataloader.shuffle = False 58 | train_skip_mode = Value("i", False) 59 | train_loader = hydra.utils.instantiate( 60 | cfg.dataloader.dataloader, 61 | dataset=train_dataset, 62 | collate_fn=collate_fn_padd, 63 | sampler=sampler, 64 | worker_init_fn=partial(worker_init_fn, train_skip_mode), 65 | ) 66 | 67 | # if we want to validate on a fraction of the data 68 | if cfg.dataloader.val_data_fraction < 1: 69 | val_dataset = Subset( 70 | val_dataset, 71 | torch.randperm( 72 | len(val_dataset) 73 | 
)[:int(len(val_dataset) * cfg.dataloader.val_data_fraction)], 74 | ) 75 | if cfg.distributed: 76 | sampler = torch.utils.data.distributed.DistributedSampler( 77 | val_dataset, num_replicas=cfg.world_size, rank=cfg.rank, 78 | ) 79 | val_skip_mode = Value("i", False) 80 | val_loader = hydra.utils.instantiate( 81 | cfg.dataloader.dataloader, 82 | dataset=val_dataset, 83 | collate_fn=collate_fn_padd, 84 | sampler=sampler, 85 | worker_init_fn=partial(worker_init_fn, val_skip_mode), 86 | ) 87 | cfg.dataloader.N = len(train_loader) 88 | return train_loader, val_loader, train_skip_mode, val_skip_mode 89 | 90 | 91 | def instantiate_vis_dataloaders(cfg: DictConfig): 92 | """DataLoader instantiation for visualization.""" 93 | # save visualisation on weakly aligned subtitles 94 | sub_paths = cfg.paths.subtitles_path 95 | skip_mode = Value("i", False) 96 | train_dataset = hydra.utils.instantiate( 97 | cfg.dataset, 98 | setname="train", 99 | subtitles_path=sub_paths, 100 | text_augmentations=text_augmentations(cfg), 101 | video_augmentations=vid_augmentations(cfg), 102 | load_pl=False, 103 | load_word_embds=False, 104 | ) 105 | val_dataset = hydra.utils.instantiate( 106 | cfg.dataset, 107 | setname="val", 108 | subtitles_path=sub_paths, 109 | load_pl=False, 110 | load_word_embds=False, 111 | ) 112 | gallery_size = min(len(val_dataset), 2000) 113 | train_dataset = Subset( 114 | train_dataset, 115 | torch.randperm(len(train_dataset))[:gallery_size], 116 | ) 117 | train_loader = hydra.utils.instantiate( 118 | cfg.dataloader.dataloader, 119 | dataset=train_dataset, 120 | collate_fn=collate_fn_padd, 121 | worker_init_fn=partial(worker_init_fn, skip_mode), 122 | ) 123 | if gallery_size < len(val_dataset): 124 | val_dataset = Subset( 125 | val_dataset, 126 | torch.randperm(len(val_dataset))[:gallery_size], 127 | ) 128 | val_loader = hydra.utils.instantiate( 129 | cfg.dataloader.dataloader, 130 | dataset=val_dataset, 131 | collate_fn=collate_fn_padd, 132 | worker_init_fn=partial(worker_init_fn, skip_mode), 133 | ) 134 | return train_loader, val_loader 135 | 136 | 137 | def instantiate_test_dataloader(cfg: DictConfig): 138 | """DataLoader instantiation for test.""" 139 | sub_paths = cfg.paths.aligned_subtitles_path 140 | test_dataset = hydra.utils.instantiate( 141 | cfg.dataset, 142 | setname="public_test", 143 | subtitles_path=sub_paths, 144 | ) 145 | skip_mode = Value("i", False) 146 | test_loader = hydra.utils.instantiate( 147 | cfg.dataloader.dataloader, 148 | dataset=test_dataset, 149 | collate_fn=collate_fn_padd, 150 | worker_init_fn=partial(worker_init_fn, skip_mode), 151 | ) 152 | return test_loader 153 | 154 | def skip_epochs( 155 | cfg: DictConfig, 156 | train_loader: DataLoader, 157 | val_loader: DataLoader, 158 | train_skip_mode: Value, 159 | val_skip_mode: Value, 160 | ): 161 | """Skip epochs to avoid getting the same batches.""" 162 | epoch_start = 0 if cfg.trainer.epoch_start is None else cfg.trainer.epoch_start 163 | if epoch_start != 0: 164 | # assert not cfg.dataloader.persistent_workers, \ 165 | # "When resuming training, persistent_workers must be set to False" 166 | # need to go through train and val datasets for RNG 167 | print(f"Skipping {epoch_start} epochs (for RNG)") 168 | train_skip_mode.value = True 169 | val_skip_mode.value = True 170 | for _ in range(0, epoch_start): 171 | pbar = tqdm(iter(train_loader)) if cfg.do_print else iter( 172 | train_loader) 173 | for _, _ in enumerate(pbar): 174 | pass 175 | pbar = tqdm(iter(val_loader)) if cfg.do_print else iter( 176 | val_loader) 177 | 
for _, _ in enumerate(pbar): 178 | pass 179 | train_skip_mode.value = False 180 | val_skip_mode.value = False 181 | return train_loader, val_loader 182 | -------------------------------------------------------------------------------- /models/t5.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creation of sentence_transformer model like for T5 architecture. 3 | https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py 4 | """ 5 | import logging 6 | from typing import List, Union 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from tqdm.autonotebook import trange 12 | from transformers import T5EncoderModel, T5Tokenizer 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def make_sentence_model( 18 | model_name: str = "t5-base", 19 | root_path: bool = False 20 | ) -> nn.Module: 21 | """ 22 | Setup SentenceTransformer model. 23 | 24 | Args: 25 | model_name (str): name of the model to use. 26 | root_path (bool): whether to use the root path of the model. 27 | """ 28 | if root_path is not None: 29 | model = T5SentenceTransformer(root_path + model_name) 30 | else: 31 | model = T5SentenceTransformer(model_name) 32 | return model 33 | 34 | 35 | class T5SentenceTransformer(nn.Module): 36 | """Loads or create a T5 model, that can be used to map sentences / text to embeddings.""" 37 | def __init__(self, model_path_or_name: str): 38 | """Initializes a T5 model.""" 39 | super().__init__() 40 | logger.info("Load pretrained T5 model %s", model_path_or_name) 41 | self.tokenizer = T5Tokenizer.from_pretrained(model_path_or_name) 42 | self.model = T5EncoderModel.from_pretrained(model_path_or_name) 43 | device = "cuda" if torch.cuda.is_available() else "cpu" 44 | logger.info("Use pytorch device: %s", device) 45 | self._target_device = torch.device(device) 46 | 47 | def encode( 48 | self, 49 | sentences: Union[str, List[str]], 50 | batch_size: int = 32, 51 | show_progress_bar: bool = None, 52 | output_value: str = "sentence_embedding", 53 | convert_to_numpy: bool = True, 54 | convert_to_tensor: bool = False, 55 | device: str = None, 56 | normalize_embeddings: bool = False, 57 | ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]: 58 | """ 59 | Computes sentence embeddings 60 | 61 | Args: 62 | sentences (Union[str, List[str]]): the sentences to embed 63 | batch_size (int): the batch size used for the computation 64 | show_progress_bar (bool): Output a progress bar when encode sentences 65 | output_value (str): Default, sentence_embeddings, to get sentence embeddings. 66 | Can be set to token_embeddings to get wordpiece token embeddings. 67 | convert_to_numpy (bool): If true, the output is a list of numpy vectors. 68 | Else, it is a list of pytorch tensors. 69 | convert_to_tensor (bool): If true, the output is a list of pytorch tensors. 70 | Else, it is a list of numpy vectors. 71 | device (str): Which torch.device to use for the computation 72 | normalize_embeddings (bool): If set to true, returned vectors will have length 1. 73 | 74 | Returns: 75 | By default, a list of tensors is returned. 76 | If convert_to_tensor, a stacked tensor is returned. 77 | If convert_to_numpy, a numpy matrix is returned. 
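        A minimal usage sketch (illustrative only; it assumes the "t5-base"
        checkpoint, whose hidden size is 768, is cached locally or can be
        downloaded):

            >>> model = T5SentenceTransformer("t5-base")
            >>> embds = model.encode(
            ...     ["a first sentence", "a second sentence"],
            ...     convert_to_tensor=True,
            ... )
            >>> tuple(embds.shape)
            (2, 768)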
78 | """ 79 | if show_progress_bar is None: 80 | show_progress_bar = ( 81 | logger.getEffectiveLevel() == logging.INFO or \ 82 | logger.getEffectiveLevel() == logging.DEBUG 83 | ) 84 | if convert_to_tensor: 85 | convert_to_numpy = False 86 | if output_value != "sentence_embedding": 87 | convert_to_tensor = False 88 | convert_to_numpy = False 89 | input_was_string = False 90 | if isinstance(sentences, str) or not hasattr(sentences, '__len__'): 91 | # cast an individual sentence to a list with length 1 92 | sentences = [sentences] 93 | input_was_string = True 94 | if device is None: 95 | device = self._target_device 96 | 97 | self.model.to(device) 98 | 99 | all_embeddings = [] 100 | length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences]) 101 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx] 102 | pbar = trange( 103 | 0, len(sentences_sorted), batch_size, 104 | desc="Batches", 105 | disable=not show_progress_bar, 106 | ) 107 | for start_index in pbar: 108 | sentences_batch = sentences_sorted[start_index:start_index + batch_size] 109 | features = self.tokenizer( 110 | sentences_batch, 111 | padding=True, 112 | return_tensors="pt", 113 | ) 114 | features = features.to(device) 115 | # forward (only the encoder) 116 | out_features = self.model(**features) 117 | if output_value == "token_embeddings": 118 | embeddings = [] 119 | zipped_features = zip(out_features["last_hidden_state"], features["attention_mask"]) 120 | for token_emb, attention in zipped_features: 121 | last_mask_id = len(attention) - 1 122 | while last_mask_id > 0 and attention[last_mask_id] == 0: 123 | last_mask_id -= 1 124 | embeddings.append(token_emb[0:last_mask_id + 1]) 125 | elif output_value is None: 126 | # return all outputs 127 | embeddings = [] 128 | for sent_idx in range(len(out_features["attentions"])): 129 | row = {name: out_features[name][sent_idx] for name in out_features} 130 | embeddings.append(row) 131 | else: 132 | # sentence embeddings 133 | # T5 has no pooling/cls token --> take the mean over all tokens 134 | token_embeddings = out_features["last_hidden_state"] 135 | input_mask_expanded = features["attention_mask"].unsqueeze(-1).expand( 136 | token_embeddings.size() 137 | ).float() 138 | embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / \ 139 | torch.clamp(input_mask_expanded.sum(1), min=1e-9) 140 | if normalize_embeddings: 141 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) 142 | if convert_to_numpy: 143 | embeddings = embeddings.detach() 144 | embeddings = embeddings.cpu().numpy() 145 | all_embeddings.extend(embeddings) 146 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] 147 | if convert_to_tensor: 148 | all_embeddings = torch.stack(all_embeddings) 149 | elif convert_to_numpy: 150 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) 151 | if input_was_string: 152 | all_embeddings = all_embeddings[0] 153 | return all_embeddings 154 | 155 | def get_max_seq_length(self): 156 | """ 157 | Returns the maximal sequence length for input the model accepts. 158 | Longer inputs will be truncated. 159 | """ 160 | if hasattr(self.tokenizer, "max_seq_length"): 161 | return self.tokenizer.max_seq_length 162 | return None 163 | 164 | def _text_length(self, text: Union[List[int], List[List[int]]]): 165 | """ 166 | Help function to get the length for the input text. 
167 | 168 | Args: 169 | text (Union[List[int], List[List[int]]]): 170 | list of ints (which means a signle text as input), 171 | or a tuple of list of ints (representing several text inputs to the model) 172 | """ 173 | if isinstance(text, dict): 174 | # {key: value} case 175 | return len(next(iter(text.values()))) 176 | elif not hasattr(text, "__len__"): 177 | # object has no len method 178 | return 1 179 | elif len(text) == 0 or isinstance(text[0], int): 180 | # empty string or list of ints 181 | return len(text) 182 | else: 183 | # sum of length of individual strings 184 | return sum([len(t) for t in text]) 185 | -------------------------------------------------------------------------------- /loops/retrieval.py: -------------------------------------------------------------------------------- 1 | """Module for computing retrieval metrics.""" 2 | import numpy as np 3 | import scipy.stats 4 | 5 | 6 | def t2v_metrics(sims, query_masks=None): 7 | """Compute retrieval metrics from a similiarity matrix. 8 | 9 | Args: 10 | sims (th.Tensor): N x M matrix of similarities between embeddings, where 11 | x_{i,j} = 12 | query_masks (th.Tensor): mask any missing queries from the dataset (two videos 13 | in MSRVTT only have 19, rather than 20 captions) 14 | 15 | Returns: 16 | (dict[str:float]): retrieval metrics 17 | """ 18 | assert sims.ndim == 2, "expected a matrix" 19 | num_queries, num_vids = sims.shape 20 | dists = -sims 21 | sorted_dists = np.sort(dists, axis=1) 22 | 23 | # The indices are computed such that they slice out the ground truth distances 24 | # from the psuedo-rectangular dist matrix 25 | queries_per_video = num_queries // num_vids 26 | gt_idx = [[np.ravel_multi_index([ii, jj], (num_queries, num_vids)) 27 | for ii in range(jj * queries_per_video, (jj + 1) * queries_per_video)] 28 | for jj in range(num_vids)] 29 | gt_idx = np.array(gt_idx) 30 | gt_dists = dists.reshape(-1)[gt_idx.reshape(-1)] 31 | gt_dists = gt_dists[:, np.newaxis] 32 | rows, cols = np.where((sorted_dists - gt_dists) == 0) # find column position of GT 33 | 34 | # -------------------------------- 35 | # NOTE: Breaking ties 36 | # -------------------------------- 37 | # We sometimes need to break ties (in general, these should occur extremely rarely, 38 | # but there are pathological cases when they can distort the scores, such as when 39 | # the similarity matrix is all zeros). Previous implementations (e.g. the t2i 40 | # evaluation function used 41 | # here: https://github.com/niluthpol/multimodal_vtt/blob/master/evaluation.py and 42 | # here: https://github.com/linxd5/VSE_Pytorch/blob/master/evaluation.py#L87) generally 43 | # break ties "optimistically". However, if the similarity matrix is constant this 44 | # can evaluate to a perfect ranking. A principled option is to average over all 45 | # possible partial orderings implied by the ties. See # this paper for a discussion: 46 | # McSherry, Frank, and Marc Najork, 47 | # "Computing information retrieval performance measures efficiently in the presence 48 | # of tied scores." European conference on information retrieval. Springer, Berlin, 49 | # Heidelberg, 2008. 
50 | # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.145.8892&rep=rep1&type=pdf 51 | 52 | # break_ties = "optimistically" 53 | break_ties = "averaging" 54 | 55 | if rows.size > num_queries: 56 | assert np.unique(rows).size == num_queries, "issue in metric evaluation" 57 | if break_ties == "optimistically": 58 | _, idx = np.unique(rows, return_index=True) 59 | cols = cols[idx] 60 | elif break_ties == "averaging": 61 | # fast implementation, based on this code: 62 | # https://stackoverflow.com/a/49239335 63 | locs = np.argwhere((sorted_dists - gt_dists) == 0) 64 | 65 | # Find the split indices 66 | steps = np.diff(locs[:, 0]) 67 | splits = np.nonzero(steps)[0] + 1 68 | splits = np.insert(splits, 0, 0) 69 | 70 | # Compute the result columns 71 | summed_cols = np.add.reduceat(locs[:, 1], splits) 72 | counts = np.diff(np.append(splits, locs.shape[0])) 73 | avg_cols = summed_cols / counts 74 | cols = avg_cols 75 | 76 | msg = "expected ranks to match queries ({} vs {}) " 77 | if cols.size != num_queries: 78 | raise ValueError(msg.format(cols.size, num_queries)) 79 | assert cols.size == num_queries, msg 80 | 81 | if query_masks is not None: 82 | # remove invalid queries 83 | assert query_masks.size == num_queries, "invalid query mask shape" 84 | cols = cols[query_masks.reshape(-1).astype(np.bool)] 85 | assert cols.size == query_masks.sum(), "masking was not applied correctly" 86 | # update number of queries to account for those that were missing 87 | num_queries = query_masks.sum() 88 | 89 | return cols2metrics(cols, num_queries), cols 90 | 91 | 92 | def v2t_metrics(sims, query_masks=None): 93 | """Compute retrieval metrics from a similiarity matrix. 94 | 95 | Args: 96 | sims (th.Tensor): N x M matrix of similarities between embeddings, where 97 | x_{i,j} = 98 | query_masks (th.Tensor): mask any missing captions from the dataset 99 | 100 | Returns: 101 | (dict[str:float]): retrieval metrics 102 | 103 | NOTES: We find the closest "GT caption" in the style of VSE, which corresponds 104 | to finding the rank of the closest relevant caption in embedding space: 105 | github.com/ryankiros/visual-semantic-embedding/blob/master/evaluation.py#L52-L56 106 | """ 107 | # switch axes of text and video 108 | sims = sims.T 109 | 110 | assert sims.ndim == 2, "expected a matrix" 111 | num_queries, num_caps = sims.shape 112 | dists = -sims 113 | caps_per_video = num_caps // num_queries 114 | break_ties = "averaging" 115 | 116 | MISSING_VAL = 1E8 # pylint: disable=invalid-name 117 | query_ranks = [] 118 | for ii in range(num_queries): # pylint: disable=invalid-name 119 | row_dists = dists[ii, :] 120 | if query_masks is not None: 121 | # Set missing queries to have a distance of infinity. A missing query 122 | # refers to a query position `n` for a video that had less than `n` 123 | # captions (for example, a few MSRVTT videos only have 19 queries) 124 | row_dists[np.logical_not(query_masks.reshape(-1))] = MISSING_VAL 125 | 126 | # NOTE: Using distance subtraction to perform the ranking is easier to make 127 | # deterministic than using argsort, which suffers from the issue of defining 128 | # "stability" for equal distances. 
Example of distance subtraction code: 129 | # github.com/antoine77340/Mixture-of-Embedding-Experts/blob/master/train.py 130 | sorted_dists = np.sort(row_dists) 131 | 132 | min_rank = np.inf 133 | for jj in range(ii * caps_per_video, (ii + 1) * caps_per_video): # pylint: disable=invalid-name 134 | if row_dists[jj] == MISSING_VAL: 135 | # skip rankings of missing captions 136 | continue 137 | ranks = np.where((sorted_dists - row_dists[jj]) == 0)[0] 138 | if break_ties == "optimistically": 139 | rank = ranks[0] 140 | elif break_ties == "averaging": 141 | # NOTE: If there is more than one caption per video, its possible for the 142 | # method to do "worse than chance" in the degenerate case when all 143 | # similarities are tied. TODO(Samuel): Address this case. 144 | rank = ranks.mean() 145 | else: 146 | raise ValueError(f"unknown break_ties: {break_ties}") 147 | if rank < min_rank: 148 | min_rank = rank 149 | query_ranks.append(min_rank) 150 | query_ranks = np.array(query_ranks) 151 | 152 | return cols2metrics(query_ranks, num_queries), query_ranks 153 | 154 | 155 | def cols2metrics(cols, num_queries): 156 | """ 157 | Compute retrieval metrics from a column vector of ranks. 158 | """ 159 | metrics = {} 160 | metrics["R1"] = 100 * float(np.sum(cols == 0)) / num_queries 161 | metrics["R5"] = 100 * float(np.sum(cols < 5)) / num_queries 162 | metrics["R10"] = 100 * float(np.sum(cols < 10)) / num_queries 163 | metrics["R50"] = 100 * float(np.sum(cols < 50)) / num_queries 164 | metrics["MedR"] = np.median(cols) + 1 165 | metrics["MeanR"] = np.mean(cols) + 1 166 | stats = [metrics[x] for x in ("R1", "R5", "R10")] 167 | metrics["geometric_mean_R1-R5-R10"] = scipy.stats.mstats.gmean(stats) 168 | return metrics 169 | 170 | 171 | if __name__ == "__main__": 172 | # test out the implementation with a toy example 173 | test_sims = np.ones((3, 3)) 174 | test_sims[0, 0] = 2 175 | test_sims[1, 1:2] = 2 176 | test_sims[2, :] = 2 177 | test_query_masks = None # pylint: disable=invalid-name 178 | 179 | print("v2t_metrics") 180 | print(v2t_metrics(test_sims, test_query_masks)) 181 | 182 | print("t2v_metrics") 183 | print(t2v_metrics(test_sims, test_query_masks)) 184 | -------------------------------------------------------------------------------- /dataset/lmdb_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic dataset that loads data from LMDB database. 3 | """ 4 | import warnings 5 | from pathlib import Path 6 | from typing import List, Optional, Union 7 | 8 | import lmdb 9 | import torch 10 | from einops import rearrange 11 | from torchvision.io import decode_image, write_video 12 | 13 | 14 | class LMDBLoader(object): 15 | """ 16 | Generic dataset that loads data from LMDB database. 17 | """ 18 | def __init__( 19 | self, 20 | lmdb_path: str, 21 | load_stride: int = 1, 22 | load_float16: bool = False, 23 | load_type: str = "feats", 24 | verbose: bool = False, 25 | lmdb_window_size: int = 16, 26 | lmdb_stride: int = 2, 27 | feat_dim: Optional[int] = None, 28 | ): 29 | """ 30 | Args: 31 | lmdb_path (Union[str, Path]): Path to LMDB database. 32 | load_stride (int): Stride for loading frames from LMDB database. 33 | load_float16 (bool): Whether to load frames as float16. 34 | load_type (str): Type of data to load from LMDB database. 35 | Either "feats" or "frames" or "pseudo-labels". 36 | verbose (bool): Whether to print verbose messages. 37 | lmdb_window_size (int): Window size for sliding window approach. 
38 | Only required if load_type == "feats" or "pseudo-labels". 39 | lmdb_stride (int): Stride for sliding window approach. 40 | Only required if load_type == "feats" or "pseudo-labels". 41 | feat_dim (Optional[int]): Feature dimensionality. 42 | Only required if load_type == "feats". 43 | """ 44 | assert load_type in ["feats", "frames", "pseudo-labels"], \ 45 | f"load_type must be either 'feats' or 'frames' or 'pseudo-labels', but got {load_type}." 46 | self.lmdb_path = lmdb_path 47 | self.load_stride = load_stride 48 | self.load_float16 = load_float16 49 | self.lmdb = self._init_lmdb() 50 | self.load_type = load_type 51 | self.verbose = verbose 52 | if self.load_type == "feats": 53 | assert feat_dim is not None, "feat_dim must be provided if load_type == 'feats'." 54 | self.vid_feat_dim = feat_dim 55 | if self.load_type in ["feats", "pseudo-labels"]: 56 | self.lmdb_window_size = lmdb_window_size 57 | self.lmdb_stride = lmdb_stride 58 | 59 | def _init_lmdb(self) -> lmdb.Environment: 60 | """Initialise LMDB database.""" 61 | return lmdb.open(self.lmdb_path, readonly=True, lock=False, max_readers=10000) 62 | 63 | @staticmethod 64 | def _get_feat_key(episode_name: str, frame_index: int, suffix: str = ".np") -> bytes: 65 | """Returns key for features in LMDB database.""" 66 | key_end = f"{frame_index + 1:07d}{suffix}" 67 | return f"{Path(episode_name.split('.')[0]).stem}/{key_end}".encode('ascii') 68 | 69 | @staticmethod 70 | def _get_pseudo_label_key( 71 | episode_name: str, frame_index: int, suffix: str = ".np" 72 | ) -> List[bytes]: 73 | """Returns key for pseudo-labels and corresponding probabilities in LMDB database.""" 74 | key_end = f"{frame_index + 1:07d}{suffix}" 75 | return f"{Path(episode_name.split('.')[0] + '_label').stem}/{key_end}".encode('ascii'), \ 76 | f"{Path(episode_name.split('.')[0] + '_prob').stem}/{key_end}".encode('ascii') 77 | 78 | @staticmethod 79 | def _get_rbg_key( 80 | episode_name: str, frame_index: int, suffix: str = ".jpg") -> bytes: 81 | """Returns key for RGB frames in LMDB database.""" 82 | key_end = f"{frame_index + 1:07d}{suffix}" 83 | return f"{Path(episode_name.split('.')[0]).stem}/{key_end}".encode('ascii') 84 | 85 | def feature_idx_to_frame_idx( 86 | self, 87 | feature_idx: int, 88 | ) -> int: 89 | """Convert feature index to frame index.""" 90 | begin_idx = self.lmdb_window_size // 2 - 1 91 | return begin_idx + feature_idx * self.lmdb_stride 92 | 93 | def frame_idx_to_feature_idx( 94 | self, 95 | frame_idx: int, 96 | ) -> int: 97 | """ 98 | Convert frame index to feature index. 
99 | Formula: frame_idx = begin_idx + feature_idx * stride 100 | with begin_idx = self.lmdb_window_size // 2 - 1 101 | """ 102 | begin_idx = self.lmdb_window_size // 2 - 1 103 | return max(0, (frame_idx - begin_idx) // self.lmdb_stride) 104 | 105 | def load_sequence( 106 | self, episode_name: str, begin_frame: int, end_frame: int, 107 | ) -> Union[torch.Tensor, List[torch.Tensor]]: 108 | """Loads a sequence of frames/features/pseudo-labels from LMDB database.""" 109 | if self.load_type == "feats" or self.load_type == "pseudo-labels": 110 | begin_frame = self.frame_idx_to_feature_idx(frame_idx=begin_frame) 111 | end_frame = self.frame_idx_to_feature_idx(frame_idx=end_frame) 112 | if self.load_type == "feats": 113 | all_feats = [] 114 | else: 115 | all_labels, all_probs = [], [] 116 | else: 117 | # load RGB frames 118 | frames = [] 119 | with warnings.catch_warnings(): 120 | warnings.simplefilter("ignore", category=UserWarning) 121 | for frame_idx in range(begin_frame, end_frame, self.load_stride): 122 | # get key for LMDB database 123 | if self.load_type == "feats": 124 | feats_key = self._get_feat_key(episode_name, frame_idx) 125 | with self.lmdb.begin() as txn: 126 | features = torch.zeros((self.vid_feat_dim,), dtype=torch.float16) 127 | try: 128 | features = torch.frombuffer(txn.get(feats_key), dtype=torch.float16) 129 | except (KeyError, TypeError, ValueError): 130 | if self.verbose: 131 | print(f"Key {feats_key} not found in LMDB database.") 132 | all_feats.append(features) 133 | elif self.load_type == "pseudo-labels": 134 | label_key, prob_key = self._get_pseudo_label_key(episode_name, frame_idx) 135 | with self.lmdb.begin() as txn: 136 | labels = torch.zeros((5,), dtype=torch.long) 137 | probs = torch.zeros((5,), dtype=torch.float16) 138 | try: 139 | labels = torch.frombuffer(txn.get(label_key), dtype=torch.long) 140 | probs = torch.frombuffer(txn.get(prob_key), dtype=torch.float16) 141 | except (KeyError, TypeError, ValueError): 142 | if self.verbose: 143 | print(f"Key {label_key} or {prob_key} not found in LMDB database.") 144 | all_labels.append(labels) 145 | all_probs.append(probs) 146 | else: 147 | # load rgb frames 148 | rgb_key = self._get_rbg_key(episode_name, frame_idx) 149 | with self.lmdb.begin() as txn: 150 | frame = torch.zeros((3, 256, 256), dtype=torch.uint8) 151 | try: 152 | frame = decode_image( 153 | torch.frombuffer(txn.get(rgb_key), dtype=torch.uint8) 154 | ) 155 | except (KeyError, TypeError, ValueError): 156 | if self.verbose: 157 | print(f"Key {rgb_key} not found in LMDB database.") 158 | frames.append(frame) 159 | 160 | if self.load_type == "feats": 161 | if self.load_float16: 162 | return torch.stack(all_feats).half() 163 | try: 164 | all_feats = torch.stack(all_feats).float() 165 | except: 166 | print(len(all_feats), episode_name, begin_frame, end_frame) 167 | return all_feats 168 | elif self.load_type == "pseudo-labels": 169 | if self.load_float16: 170 | return torch.stack(all_labels), torch.stack(all_probs).half() 171 | return torch.stack(all_labels), torch.stack(all_probs).float() 172 | else: 173 | # load rgb frames + rearrange from (T, C, H, W) to (T, H, W, C) 174 | return rearrange(torch.stack(frames), "t c h w -> t h w c") 175 | 176 | def save_rgb_video( 177 | self, 178 | episode_name: str, 179 | begin_frame: int, end_frame: int, 180 | save_dir: str, 181 | ) -> None: 182 | """Function to save RGB frames as video.""" 183 | # load frames 184 | frames = self.load_sequence(episode_name, begin_frame, end_frame) 185 | # save video 186 | write_video( 187 | 
filename=f"{save_dir}{episode_name}_{begin_frame}-{end_frame}.mp4", 188 | video_array=frames, 189 | fps=25, 190 | ) 191 | -------------------------------------------------------------------------------- /models/cslr2.py: -------------------------------------------------------------------------------- 1 | """Model combining video and text modalities together.""" 2 | from typing import List, Optional, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class CSLR2(nn.Module): 10 | """Model combining video and text modalities together.""" 11 | def __init__( 12 | self, 13 | video_encoder: nn.Module, 14 | text_encoder: Union[nn.Module, object], 15 | video_sequence_ll: Optional[nn.Module] = None, 16 | video_token_ll: Optional[nn.Module] = None, 17 | text_sentence_ll: Optional[nn.Module] = None, 18 | text_word_ll: Optional[nn.Module] = None, 19 | pooling: str = "max", 20 | sign_ret: bool = False, 21 | no_video_encoder: bool = False, 22 | same_text_ll: bool = False, 23 | same_video_ll: bool = False, 24 | ) -> None: 25 | """ 26 | Args: 27 | video_encoder (nn.Module): video encoder model. 28 | text_encoder (Union[nn.Module, object]): text encoder model. 29 | video_sequence_ll (Optional[nn.Module]): linear layer for video sequence embeddings. 30 | video_token_ll (Optional[nn.Module]): linear layer for video token embeddings. 31 | text_sentence_ll (Optional[nn.Module]): linear layer for text sentence embeddings. 32 | text_word_ll (Optional[nn.Module]): linear layer for text word embeddings. 33 | pooling (str): pooling method for video embeddings. 34 | sign_ret (bool): whether sign retrieval loss is used. 35 | no_video_encoder (bool): whether to use video encoder. 36 | same_text_ll (bool): whether to use the same linear layer for text embeddings. 37 | same_video_ll (bool): whether to use the same linear layer for video embeddings. 38 | """ 39 | super(CSLR2, self).__init__() 40 | self.video_encoder = video_encoder 41 | self.video_sequence_ll = video_sequence_ll 42 | self.video_token_ll = video_token_ll if sign_ret else None 43 | self.text_encoder = text_encoder 44 | self.text_sentence_ll = text_sentence_ll 45 | self.text_word_ll = text_word_ll if sign_ret else None 46 | self.pooling = pooling 47 | self.sign_ret = sign_ret 48 | self.no_video_encoder = no_video_encoder 49 | if self.no_video_encoder: 50 | self.video_encoder = None 51 | self.same_text_ll = same_text_ll 52 | self.same_video_ll = same_video_ll 53 | 54 | def extract_sentence_embeddings( 55 | self, 56 | sentences: Union[str, List[str]], 57 | device: torch.device, 58 | ) -> torch.tensor: 59 | """ 60 | Extract sentence embeddings. 61 | 62 | Args: 63 | sentences (Union[str, List[str]]): List of sentences or a single sentence. 64 | device (torch.device): Device to use for the model. 65 | 66 | Returns: 67 | Sentence embeddings (torch.tensor). 
68 | """ 69 | batch_size = len(sentences) if isinstance(sentences, List) else 1 70 | embeddings = self.text_encoder.encode( 71 | sentences, 72 | batch_size=batch_size, 73 | show_progress_bar=False, 74 | convert_to_tensor=True, 75 | device=device, 76 | ) 77 | return embeddings 78 | 79 | def project_sentence_embeddings( 80 | self, 81 | embeddings: torch.tensor, 82 | ) -> torch.tensor: 83 | """Project sentence embeddings (text).""" 84 | return F.normalize(self.text_sentence_ll(embeddings), dim=-1) 85 | 86 | def project_sequence_embeddings( 87 | self, 88 | embeddings: torch.tensor, 89 | ) -> torch.tensor: 90 | """Project sequence embeddings (video).""" 91 | return F.normalize(self.video_sequence_ll(embeddings), dim=-1) 92 | 93 | def project_word_embeddings( 94 | self, 95 | embeddings: torch.tensor, 96 | ) -> torch.tensor: 97 | """Project word embeddings (text).""" 98 | assert self.text_word_ll is not None 99 | if self.same_text_ll: 100 | return F.normalize(self.text_sentence_ll(embeddings), dim=-1) 101 | return F.normalize(self.text_word_ll(embeddings), dim=-1) 102 | 103 | def project_token_embeddings( 104 | self, 105 | embeddings: torch.tensor, 106 | ) -> torch.tensor: 107 | """Project token embeddings (video).""" 108 | assert self.video_token_ll is not None 109 | if self.same_video_ll: 110 | return F.normalize(self.video_sequence_ll(embeddings), dim=-1) 111 | return F.normalize(self.video_token_ll(embeddings), dim=-1) 112 | 113 | def video_pooling( 114 | self, 115 | embeddings: torch.tensor, 116 | input_features: torch.tensor, 117 | ) -> torch.tensor: 118 | """ 119 | Pooling of video embeddings to replace learnable CLS token. 120 | 121 | Args: 122 | embeddings (torch.tensor): embedded video features. 123 | input_features (torch.tensor): video features. 124 | 125 | Returns: 126 | Pooled video embeddings. 127 | """ 128 | video_mask = (input_features != 0).sum(-1) != 0 129 | video_mask = video_mask.to(input_features.device, non_blocking=True) 130 | pool_start_idx = 0 if self.no_video_encoder else 1 131 | if self.pooling == "mean": 132 | cls_tokens = (embeddings[:, pool_start_idx:, :] 133 | * video_mask[:, :, None]) 134 | cls_tokens = cls_tokens.sum(1) / video_mask.sum(1)[:, None] 135 | elif self.pooling == "max": 136 | cls_tokens = embeddings[:, pool_start_idx:, :].max(dim=1)[0] 137 | elif self.pooling == "median": 138 | cls_tokens = embeddings[:, pool_start_idx:, :].median(dim=1)[0] 139 | else: 140 | # learnable CLS token 141 | cls_tokens = embeddings[:, 0, :] 142 | return cls_tokens 143 | 144 | def forward( 145 | self, 146 | video_features: torch.tensor, 147 | subtitles: List[str], 148 | word_embds: Optional[torch.tensor] = None, 149 | ): 150 | """ 151 | Forward function of the model. 152 | 153 | Args: 154 | video_features (torch.tensor): video features. 155 | subtitles (List[str]): list of subtitles. 156 | word_embds (Optional[torch.tensor]): word embeddings. 157 | 158 | Returns: 159 | cls_tokens (torch.tensor): video embeddings (sequence level). 160 | video_tokens (torch.tensor): video token embeddings (token level). 161 | sentence_embds (torch.tensor): sentence embeddings (text). 162 | word_embds (torch.tensor): word embeddings (text). 163 | output_tensor (torch.tensor): embedded video features. 
164 | """ 165 | # video side 166 | if not self.no_video_encoder: 167 | cls_tokens, output_tensor = self.video_encoder(video_features) 168 | else: 169 | cls_tokens = video_features 170 | output_tensor = None 171 | if self.sign_ret: 172 | # remove CLS token 173 | tokens = cls_tokens[:, 174 | 1:] if not self.no_video_encoder else cls_tokens 175 | if self.video_token_ll is not None: 176 | video_tokens = self.project_token_embeddings(tokens) 177 | else: 178 | # normalise 179 | video_tokens = F.normalize(tokens, dim=-1) 180 | else: 181 | video_tokens = None 182 | 183 | cls_tokens = self.video_pooling(cls_tokens, video_features) 184 | if self.video_sequence_ll is not None: 185 | cls_tokens = self.project_sequence_embeddings(cls_tokens) 186 | else: 187 | # normalise 188 | cls_tokens = F.normalize(cls_tokens, dim=-1) 189 | # text side 190 | sentence_embds = self.extract_sentence_embeddings( 191 | subtitles, video_features.device) 192 | if self.text_sentence_ll is not None: 193 | sentence_embds = self.project_sentence_embeddings(sentence_embds) 194 | else: 195 | # normalise 196 | sentence_embds = F.normalize(sentence_embds, dim=-1) 197 | if self.text_word_ll is not None and word_embds is not None: 198 | word_embds = self.project_word_embeddings(word_embds) 199 | elif word_embds is not None: 200 | # normalise 201 | word_embds = F.normalize(word_embds, dim=-1) 202 | return cls_tokens, video_tokens, sentence_embds, word_embds, output_tensor 203 | 204 | def forward_sentret( 205 | self, 206 | video_features: torch.Tensor, 207 | subtitles: List[str], 208 | ): 209 | """ 210 | Forward function of the model (in the case only sentence-level retrieval is needed). 211 | 212 | Args: 213 | video_features (torch.tensor): video features. 214 | subtitles (List[str]): list of subtitles. 215 | 216 | Returns: 217 | cls_tokens (torch.tensor): video embeddings (sequence level). 218 | sentence_embds (torch.tensor): sentence embeddings (text). 
219 | """ 220 | device = video_features.device 221 | # video side 222 | if not self.no_video_encoder: 223 | cls_tokens, _ = self.video_encoder(video_features) 224 | else: 225 | cls_tokens = video_features 226 | 227 | cls_tokens = self.video_pooling(cls_tokens, video_features) 228 | if self.video_sequence_ll is not None: 229 | cls_tokens = self.project_sequence_embeddings(cls_tokens) 230 | else: 231 | # normalise 232 | cls_tokens = F.normalize(cls_tokens, dim=-1) 233 | 234 | # text side 235 | sentence_embds = self.extract_sentence_embeddings( 236 | subtitles, device) 237 | if self.text_sentence_ll is not None: 238 | sentence_embds = self.project_sentence_embeddings( 239 | sentence_embds) 240 | else: 241 | sentence_embds = F.normalize(sentence_embds, dim=-1) 242 | return cls_tokens, sentence_embds 243 | -------------------------------------------------------------------------------- /utils/matplotlib_utils.py: -------------------------------------------------------------------------------- 1 | """Functions to save predictions (retrieval) with matplotlib.""" 2 | import os 3 | import pickle 4 | from pathlib import Path 5 | from typing import List, Union 6 | 7 | import cv2 8 | import lmdb 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | import wandb 13 | from matplotlib.animation import ArtistAnimation 14 | from numpy import ndarray 15 | from omegaconf import DictConfig 16 | from torch import Tensor 17 | from tqdm import tqdm 18 | 19 | 20 | def lmdb_key_list(episode_name: str, begin_frame: int, end_frame: int) -> List: 21 | """ 22 | Returns list of keys for RGB videos 23 | 24 | Args: 25 | episode_name (str): Episode name. 26 | begin_frame (int): Begin frame. 27 | end_frame (int): End frame. 28 | 29 | Returns: 30 | List: List of keys mapping to RGB frames in lmdb environment. 31 | """ 32 | return [f"{Path(episode_name.split('.')[0])}/{frame_idx + 1:07d}.jpg".encode('ascii') \ 33 | for frame_idx in range(begin_frame, end_frame + 1)] 34 | 35 | 36 | def get_rgb_frames(lmdb_keys: List[str], lmdb_env: lmdb.Environment) -> List: 37 | """ 38 | Returns list of RGB frames 39 | 40 | Args: 41 | lmdb_keys (List[str]): List of keys mapping to RGB frames in lmdb environment. 42 | lmdb_env (lmdb.Environment): lmdb environment. 43 | 44 | Returns: 45 | frames (List): List of RGB frames. 46 | """ 47 | frames = [] 48 | for key in lmdb_keys: 49 | with lmdb_env.begin() as txn: 50 | frame = txn.get(key) 51 | frame = cv2.imdecode( 52 | np.frombuffer(frame, dtype=np.uint8), 53 | cv2.IMREAD_COLOR, 54 | ) 55 | rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 56 | frames.append(rgb_frame) 57 | return frames 58 | 59 | def save_retrieval_vis( 60 | cfg: DictConfig, 61 | sim_matrix: Union[ndarray, Tensor], 62 | all_sentences: List[str], 63 | video_names: List[str], 64 | sub_starts: List[float], 65 | sub_ends: List[float], 66 | rgb_lmdb_env: lmdb.Environment, 67 | setname: str, 68 | epoch: int, 69 | pl_as_subtitles: bool = False, 70 | k: int = 5, 71 | text_only: bool = False, 72 | ) -> None: 73 | """ 74 | Save retrieval visualization. 75 | 76 | Args: 77 | cfg (DictConfig): Config file. 78 | sim_matrix (Union[ndarray, Tensor]): Similarity matrix. sim[i, j] = . 79 | all_sentences (List[str]): List of all text sentences. 80 | video_names (List[str]): List of all video names. 81 | sub_starts (List[float]): List of all subtitle start times. 82 | sub_ends (List[float]): List of all subtitle end times. 83 | rgb_lmdb_env (lmdb.Environment): lmdb environment. 
84 | setname (str): Name of the set (train, val, test). 85 | epoch (int): Current epoch. 86 | pl_as_subtitles (bool, optional): Whether to use the PL as subtitles. Defaults to False. 87 | k (int): Number of subtitles to retrieve for visualisation. Defaults to 5. 88 | text_only (bool, optional): Whether to only save the text. Defaults to False. 89 | """ 90 | if text_only: 91 | v2t_retrieval_results = { 92 | "gt": [], "retrieved_subs": [], "gt_start": [], "gt_end": [], 93 | "retrieved_starts": [], "retrieved_ends": [], "gt_name": [], "retrieved_names": [], 94 | "sims": [], 95 | } 96 | t2v_retrieval_results = { 97 | "gt": [], "retrieved_videos": [], "gt_start": [], "gt_end": [], 98 | "retrieved_starts": [], "retrieved_ends": [], "gt_name": [], "retrieved_names": [], 99 | "sims": [], 100 | } 101 | for vis_idx in tqdm(range(cfg.nb_vis)): 102 | # get the similarity scores for the current video 103 | sim_scores = sim_matrix[:, vis_idx] 104 | # get the indices of the top-k most similar videos 105 | topk = torch.topk(torch.tensor(sim_scores), k) 106 | topk_indices, topk_values = topk.indices, topk.values 107 | # get the subtitles corresponding to the indices in question 108 | topk_subtitles = [all_sentences[idx] for idx in topk_indices] 109 | topk_subtitles_str = "" 110 | video_name = video_names[vis_idx] 111 | start, end = sub_starts[vis_idx], sub_ends[vis_idx] 112 | subtitle = all_sentences[vis_idx] 113 | if text_only: 114 | v2t_retrieval_results["gt"].append(subtitle) 115 | v2t_retrieval_results["retrieved_subs"].append(topk_subtitles) 116 | v2t_retrieval_results["gt_start"].append(start) 117 | v2t_retrieval_results["gt_end"].append(end) 118 | v2t_retrieval_results["gt_name"].append(video_name) 119 | v2t_retrieval_results["retrieved_starts"].append( 120 | [sub_starts[idx] for idx in topk_indices] 121 | ) 122 | v2t_retrieval_results["retrieved_ends"].append( 123 | [sub_ends[idx] for idx in topk_indices] 124 | ) 125 | v2t_retrieval_results["retrieved_names"].append( 126 | [video_names[idx] for idx in topk_indices] 127 | ) 128 | v2t_retrieval_results["sims"].append(topk_values) 129 | else: 130 | for idx, topk_subtitle in enumerate(topk_subtitles): 131 | topk_subtitles_str += topk_subtitle 132 | topk_subtitles_str += f" ({topk_values[idx]:.2f}) \n" 133 | # get the corresponding video frames 134 | lmdb_keys = lmdb_key_list( 135 | episode_name=video_name, 136 | begin_frame=int(start * 25), 137 | end_frame=int(end * 25), 138 | ) 139 | frames = get_rgb_frames(lmdb_keys, rgb_lmdb_env) 140 | # assemble into matplotlib figure 141 | fig = plt.figure() 142 | ax1, ax2 = fig.add_subplot(2, 1, 1), fig.add_subplot(2, 1, 2) 143 | ax1.set_xlim(0, 255) 144 | ax1.set_ylim(0, 255) 145 | ax1.tick_params( 146 | axis="both", 147 | which="both", 148 | bottom=False, 149 | top=False, 150 | left=False, 151 | right=False, 152 | labelbottom=False, 153 | labelleft=False, 154 | ) 155 | ax1.set_title(subtitle) 156 | ax2.tick_params( 157 | axis="both", 158 | which="both", 159 | bottom=False, 160 | top=False, 161 | left=False, 162 | right=False, 163 | labelbottom=False, 164 | labelleft=False, 165 | ) 166 | ax2.text( 167 | .5, .5, 168 | topk_subtitles_str, 169 | fontsize=7, 170 | horizontalalignment='center', 171 | verticalalignment='center', 172 | transform=ax2.transAxes, 173 | wrap=True, 174 | ) 175 | animated_frames = [] 176 | for idx, frame in enumerate(frames): 177 | animated_frame = [] 178 | animated_frame.append( 179 | ax1.imshow( 180 | np.flipud(frame), 181 | animated=True, 182 | interpolation="nearest", 183 | ) 184 | ) 185 
| animated_frames.append(animated_frame) 186 | fig.tight_layout() 187 | anim = ArtistAnimation( 188 | fig, 189 | animated_frames, 190 | interval=50, 191 | blit=True, 192 | repeat=False, 193 | ) 194 | # save the video locally 195 | video_prefix = f"pl_as_subtitles_{setname}" if pl_as_subtitles else setname 196 | video_path = cfg.paths.log_dir + \ 197 | f"/videos/{video_prefix}_E_{str(epoch + 1)}_I_{str(vis_idx + 1)}.mp4" 198 | if not os.path.exists(cfg.paths.log_dir + "/videos/"): 199 | os.makedirs(cfg.paths.log_dir + "/videos/") 200 | anim.save(video_path, writer="ffmpeg", fps=25) 201 | # upload to wandb 202 | wandb.log( 203 | {f"{video_prefix}--video": wandb.Video(video_path, fps=25, format="mp4")}, 204 | ) 205 | for vis_idx in tqdm(range(cfg.nb_vis)): 206 | # get the similarity scores for the current subtitle 207 | sim_scores = sim_matrix[vis_idx, :] 208 | # get the indices of the top-k most similar videos 209 | topk = torch.topk(torch.tensor(sim_scores), k) 210 | topk_indices, topk_values = topk.indices, topk.values 211 | # get the subtitles corresponding to the indices in question 212 | topk_subtitles = [all_sentences[idx] for idx in topk_indices] 213 | topk_subtitles_str = "" 214 | video_name = video_names[vis_idx] 215 | start, end = sub_starts[vis_idx], sub_ends[vis_idx] 216 | subtitle = all_sentences[vis_idx] 217 | if text_only: 218 | t2v_retrieval_results["gt"].append(subtitle) 219 | t2v_retrieval_results["retrieved_videos"].append(topk_subtitles) 220 | t2v_retrieval_results["gt_start"].append(start) 221 | t2v_retrieval_results["gt_end"].append(end) 222 | t2v_retrieval_results["gt_name"].append(video_name) 223 | t2v_retrieval_results["retrieved_starts"].append( 224 | [sub_starts[idx] for idx in topk_indices] 225 | ) 226 | t2v_retrieval_results["retrieved_ends"].append( 227 | [sub_ends[idx] for idx in topk_indices] 228 | ) 229 | t2v_retrieval_results["retrieved_names"].append( 230 | [video_names[idx] for idx in topk_indices] 231 | ) 232 | t2v_retrieval_results["sims"].append(topk_values) 233 | if text_only: 234 | # save the dictionary 235 | f_path = cfg.paths.log_dir + f"/retrieval_results_{setname}_E_{str(epoch + 1)}.pkl" 236 | saved_pickle = {"t2v": t2v_retrieval_results, "v2t": v2t_retrieval_results} 237 | pickle.dump( 238 | saved_pickle, 239 | open(f_path, "wb") 240 | ) 241 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # CSLR2 4 | ## Large-Vocabulary Continuous *Sign Language* Recognition
from *Spoken Language* Supervision 5 | 6 | Charles Raude · Prajwal KR · Liliane Momeni · 7 | Hannah Bull · Samuel Albanie · Andrew Zisserman · Gül Varol 8 | 9 | [![arXiv](https://img.shields.io/badge/arXiv-CSLR2-A10717.svg?logo=arXiv)](https://arxiv.org/abs/2405.10266) 10 | [![License](https://img.shields.io/badge/License-MIT-green.svg)]() 11 | 12 |
13 | 14 | ## Description 15 | Official PyTorch implementation of the paper: 16 |
17 | 18 | [**A Tale of Two Languages: Large-Vocabulary Continuous *Sign Language* Recognition from *Spoken Language* Supervision**](https://arxiv.org/abs/2405.10266). 19 | 20 |
21 | 22 | Please visit our [**webpage**](https://imagine.enpc.fr/~varolg/cslr2/) for more details. 23 | 24 | ### Bibtex 25 | If you find this code useful in your research, please cite: 26 | 27 | ```bibtex 28 | @article{raude2024, 29 | title={A Tale of Two Languages: Large-Vocabulary Continuous Sign Language Recognition from Spoken Language Supervision}, 30 | author={Raude, Charles and Prajwal, K R and Momeni, Liliane and Bull, Hannah and Albanie, Samuel and Zisserman, Andrew and Varol, G{\"u}l}, 31 | journal={arXiv}, 32 | year={2024} 33 | } 34 | ``` 35 | 36 | ## Installation :construction_worker: 37 | 38 |
Create environment 39 |   40 | 41 | Create a conda environment associated to this project by running the following lines: 42 | ```bash 43 | conda create -n cslr2 python=3.9.16 44 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge 45 | conda install anaconda::pandas=1.5.3 46 | conda install conda-forge::einops=0.6.0 47 | conda install conda-forge::humanize=4.6.0 48 | conda install conda-forge::tqdm=4.65.0 49 | pip install hydra-core==1.3.2 50 | pip install matplotlib==3.7.1 51 | pip install plotly==5.14.1 52 | pip install nltk==3.8.1 53 | pip install seaborn==0.12.2 54 | pip install sentence-transformers==2.2.2 55 | pip install wandb==0.14.0 56 | pip install lmdb 57 | pip install tabulate 58 | pip install opencv-python==4.7.0.72 59 | ``` 60 | You can also create the environment from the associated `environment.yaml` file using conda (this might not always work, depending on the machine and the installed conda version; if it fails, try updating conda first). 61 | 62 | ```bash 63 | conda env create --file=environment.yaml 64 | ``` 65 | 66 | After installing these packages, you will have to download a few `nltk` resources manually in Python. 67 | 68 | ```python 69 | import nltk 70 | nltk.download("wordnet") 71 | ``` 72 |
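Optionally, you can run a quick sanity check of the environment before moving on. This is only a minimal sketch (not part of the official setup): it confirms that PyTorch sees the GPU and that the WordNet resource used for lemmatisation is available.

```python
import torch
from nltk.stem import WordNetLemmatizer

print(torch.__version__, "CUDA available:", torch.cuda.is_available())
# Expect "sign" here if the WordNet download succeeded.
print(WordNetLemmatizer().lemmatize("signing", pos="v"))
```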
73 | 74 |
Set up the BOBSL data 75 | 76 | * Make sure you have the permission to use the BOBSL dataset. You can request access following the instructions at [the official BOBSL webpage](https://www.robots.ox.ac.uk/~vgg/data/bobsl/). 77 | * With the username/password obtained, you can download the two required files via the following: 78 | ``` bash 79 | # Download the pre-extracted video features [262G] 80 | wget --user ${BOBSL_USERNAME} --password ${BOBSL_PASSWORD} \ 81 | https://thor.robots.ox.ac.uk/~vgg/data/bobsl/features/lmdb/feats_vswin_t-bs256_float16/data.mdb 82 | # Download the raw video frames [1.5T] (you can skip this if purely training/testing with features, and not visualizing) 83 | wget --user ${BOBSL_USERNAME} --password ${BOBSL_PASSWORD} \ 84 | https://thor.robots.ox.ac.uk/~vgg/data/bobsl/videos/lmdb/rgb_anon-public_1962/data.mdb 85 | ``` 86 | * Download [`bobsl.zip` 87 | (1.9G)](https://drive.google.com/file/d/13pp83GCoy1SVScvZRmNsoFxtNm8ogI7h/view?usp=sharing) 88 | for the rest of the files (including annotations and metadata). Note the folder becomes 89 | 15G when decompressed. Make sure they correspond to the paths defined here: 90 | `config/paths/public.yaml`. 91 | * Download [`t5_checkpoint.zip` (1.4G)](https://drive.google.com/file/d/1hxkb8KAC0sgSYKefOLue1wyO2fJmmqxT/view?usp=sharing) for the T5 pretrained model weights, also defined at `config/paths/public.yaml`. 92 | 93 |
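As a quick sanity check that the feature LMDB was downloaded correctly, you can open it read-only and look at the number of stored entries. This is a minimal sketch: the path below is a placeholder for wherever you placed the downloaded `data.mdb`.

```python
import lmdb

# Placeholder path: the directory that contains the downloaded data.mdb file.
env = lmdb.open("path/to/feats_vswin_t-bs256_float16", readonly=True, lock=False)
print("stored entries:", env.stat()["entries"])
```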
94 | 95 | ## Training :rocket: 96 | 97 | ```python 98 | export HYDRA_FULL_ERROR=1 # to get better error messages if job crashes 99 | python main.py run_name=cslr2_train 100 | ``` 101 | permits to train the CSLR2 model with the best set of hyperparameters obtained in the paper. 102 | Using 4 x V100-32Gb, training for 20 epochs should take less than 20 hours. 103 | 104 | To change training parameters, you should be looking at changing parameters in the `config/` folder. 105 | 106 | To manually synchronise the offline jobs on wandb, one should run: `wandb sync --sync-all` in the folder of the experiment (do not forget to do `export WANDB_MODE=offline` first). 107 | 108 | Training should save one model per epoch as `$EXP_NAME/models/model_$EPOCH_NB.pth`. Also, the model that obtains the best T2V performance on validation set is saved as `$EXP_NAME/models/model_best.pth`. 109 | 110 | ## Test :bar_chart: 111 | 112 | You can download a pretrained model from [here](https://drive.google.com/file/d/1qyFHSFnxmy1rRGjlKEBfsjC8yt2kdalx/view?usp=sharing). 113 | 114 | ### 1. Retrieval on 25K manually aligned test set 115 | 116 | To test any model for the retrieval task on the 25K manually aligned test set, one should run the following command: 117 | 118 | ```python 119 | python main.py run_name=cslr2_retrieval_25k checkpoint=$PATH_TO_CHECKPOINT test=True 120 | ``` 121 | 122 | ### 2. CSLR evaluation 123 | 124 | CSLR evaluation is done in two steps. First, extract frame-level predictions and then evaluate. 125 | 126 | #### 2.1 Feature Extraction 127 | 128 | ```python 129 | python extract_for_eval.py checkpoint=$PATH_TO_CHECKPOINT 130 | ``` 131 | extracts predictions (linear layer classification, nearest neighbor classification) for both heuristic aligned subtitles and manually aligned subtitles. 132 | 133 | #### 2.2 Evaluation 134 | 135 | ```python 136 | python frame_level_evaluation.py prediction_pickle_files=$PRED_FILES gt_csv_root=$GT_CSV_ROOT 137 | ``` 138 | Note that by default, if gt_csv_root is not provided, it will use `${paths.heuristic_aligned_csv_root}`. 139 | 140 | 141 | ## Pre-processing of gloss annotations :computer: 142 | 143 | You do not need to run this pre-processing, but we release the scripts for how to convert raw 144 | gloss annotations (released from the official BOBSL webpage) into the format used for our 145 | evaluation. A total of 4 steps are required to fully pre-process gloss annotations that are 146 | stored in json files. 147 | 148 |
149 | 1. Assign each annotation to its closest subtitle 150 | 151 | ```python 152 | python misc/process_cslr_json/preprocess_raw_json_annotations.py --output_dir OUTPUT_DIR --input_dir INPUT_DIR --subs_dir SUBS_DIR --subset2episode SUBSET2EPISODE 153 | ``` 154 | where `INPUT_DIR` is the directory where json files are stored and `OUTPUT_DIR` is the directory where the assigned annotations are saved. 155 | `SUBS_DIR` is the directory where manually aligned subtitles are saved. This corresponds to the `subtitles/manually-aligned` files from the public release. 156 | `SUBSET2EPISODE` is the path to the json file containing information about splits and episodes. This corresponds to the `subset2episode.json` file from the public release. 157 |
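This step writes one CSV of assigned annotations per episode under `OUTPUT_DIR`; the downstream scripts expect them under `0/{train,val,test}` subfolders (where `0` is the tolerance used during assignment). A quick way to inspect one of these files (the episode name below is a placeholder) is:

```python
import pandas as pd

# Placeholder file name; the later steps read these CSVs from OUTPUT_DIR/0/{train,val,test}/.
df = pd.read_csv("OUTPUT_DIR/0/val/episode_name.csv")
# Columns used by the boundary/alignment fixes and by the evaluation code.
print(df[["start_sub", "end_sub", "english sentence", "approx gloss sequence"]].head())
```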
158 | 159 | 160 | 161 |
162 | 2. Fix boundaries of subtitles. 163 | 164 | During assignment, it could happen that certain annotations overlap with the boundaries of subtitles. It could even happen that certain annotations fall entirely outside the boundaries of their associated subtitle. 165 | Since, at evaluation time, we load all features corresponding to subtitle timestamps, we need to extend the boundaries of certain subtitles. 166 | 167 | ```python 168 | python misc/process_cslr_json/fix_boundaries.py --csv_file OUTPUT_DIR 169 | ``` 170 |
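Conceptually, this step widens a subtitle window just enough for all of its assigned annotations to fit inside it. The snippet below is only a simplified sketch of that idea, not the exact logic of `fix_boundaries.py`:

```python
def extend_window(sub_start, sub_end, gloss_starts, gloss_ends):
    """Widen [sub_start, sub_end] so that every assigned gloss annotation fits inside."""
    new_start = min([sub_start, *gloss_starts])
    new_end = max([sub_end, *gloss_ends])
    return new_start, new_end
```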
171 | 172 | 173 |
174 | 3. Fix alignment of subtitles. 175 | 176 | Subtitles have been manually aligned. However, since gloss annotations are timed much more precisely, it could happen that certain gloss annotations better match a neighbouring subtitle. 177 | In order to fix this, we propose an automatic re-alignment algorithm. 178 | 179 | ```python 180 | python misc/process_cslr_json/fix_alignment.py --csv_file OUTPUT_DIR2 181 | python misc/process_cslr_json/preprocess_raw_json_annotations.py --output_dir OUTPUT_DIR3 --input_dir INPUT_DIR --subs_dir OUTPUT_DIR2 --misalignment_fix 182 | ``` 183 | 184 | where `OUTPUT_DIR2 = OUTPUT_DIR[:-8] + "extended_boundaries_" + OUTPUT_DIR[-8:]` and `OUTPUT_DIR3 = OUTPUT_DIR2[:-8] + "fix_alignment_" + OUTPUT_DIR2[-8:]`. 185 | Here we assume that `OUTPUT_DIR` ends with a date in the format DD.MM.YY. 186 |
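For example, with a hypothetical `OUTPUT_DIR` named `annots_01.02.24` (ending in a DD.MM.YY date), `OUTPUT_DIR2` would be `annots_extended_boundaries_01.02.24` and `OUTPUT_DIR3` would be `annots_extended_boundaries_fix_alignment_01.02.24`.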
187 | 188 |
189 | 4. Only keep lexical annotations. 190 | 191 | We only evaluate against lexical annotations: i.e., annotations that are associated with a word. 192 | 193 | ```python 194 | python misc/process_cslr_json/remove_star_annots_from_csvs.py --csv_root OUTPUT_DIR2 # only boundary extension fix 195 | python misc/process_cslr_json/remove_star_annots_from_csvs.py --csv_root OUTPUT_DIR3 # with total alignment fix 196 | ``` 197 |
198 | 199 |
200 | Do all the steps with one command. 201 | 202 | **Alternatively, you can run all four steps at once with `python misc/process_cslr_json/run_pipeline.py --input_dir INPUT_DIR --output_dir OUTPUT_DIR --subs_dir SUBS_DIR --subset2episode SUBSET2EPISODE`.** 203 |
204 | 205 | ## License :books: 206 | This code is developed by [Charles Raude](https://github.com/charles-raude), may not be 207 | maintained, and is distributed under an [MIT LICENSE](LICENSE). 208 | 209 | Note that the code depends on other libraries, including PyTorch, T5, Hydra, and use the BOBSL dataset which each have their own respective licenses that must also be followed. 210 | 211 | The license for the BOBSL-CSLR data can be found at [https://imagine.enpc.fr/~varolg/cslr2/license.txt](https://imagine.enpc.fr/~varolg/cslr2/license.txt). 212 | -------------------------------------------------------------------------------- /dataset/subtitles.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic dataset to load subtitles from data files 3 | """ 4 | import json 5 | import pickle 6 | from typing import Optional 7 | 8 | import numpy as np 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class Subtitles(Dataset): 13 | """Generic dataset to load subtitles from data files""" 14 | def __init__( 15 | self, 16 | subset2episode: str, 17 | setname: str, 18 | subtitles_path: str, 19 | subtitles_temporal_shift: float, 20 | subtitles_max_duration: float, 21 | subtitles_min_duration: float, 22 | temporal_pad: float, 23 | info_pkl: str, 24 | filter_stop_words: bool = False, 25 | subtitles_random_offset: Optional[float] = None, 26 | text_augmentations: Optional[object] = None, 27 | fps: int = 25, 28 | verbose: bool = False, 29 | ): 30 | """ 31 | Args: 32 | subset2episode (str): path to the json file containing the mapping 33 | between subset and episode. 34 | setname (str): name of the subset to load. 35 | subtitles_path (str): path to the subtitles pickle file. 36 | subtitles_temporal_shift (float): temporal shift to apply to the 37 | subtitles. 38 | subtitles_max_duration (float): maximum duration of the subtitles. 39 | subtitles_min_duration (float): minimum duration of the subtitles. 40 | temporal_pad (float): temporal padding to apply to the subtitles. 41 | info_pkl (str): path to the info pickle file. 42 | filter_stop_words (bool, optional): whether to filter stop words 43 | subtitles_random_offset (float, optional): randomly add an offset to 44 | the subtitles. 45 | text_augmentations (object, optional): text augmentations to apply 46 | fps (int): fps of the videos associated to the subtitles. 47 | verbose: (bool, optional): verbosity. 
48 | """ 49 | self.verbose = verbose 50 | with open(subset2episode, "rb") as json_f: 51 | subset2episode = json.load(json_f) 52 | self.setname = setname 53 | self.setname_episode = subset2episode[self.setname] 54 | del subset2episode 55 | self.text_augmentations = text_augmentations 56 | if self.verbose: 57 | print(f"Loading {self.setname} subtitles.") 58 | with open(subtitles_path, "rb") as pickle_f: 59 | self.subtitles = pickle.load(pickle_f) 60 | if self.verbose: 61 | print(f"Loaded {len(self.subtitles['episode_name'])} subtitles.") 62 | self.subtitles_temporal_shift = subtitles_temporal_shift 63 | self.subtitles_random_offset = subtitles_random_offset 64 | self.subtitles_temporal_pad = temporal_pad 65 | self.fps = fps 66 | for key, val in self.subtitles.items(): 67 | if key in ["start", "end"]: 68 | self.subtitles[key] = np.array( 69 | [ 70 | self.convert_strtime_to_seconds( 71 | time=x, 72 | temporal_shift=self.subtitles_temporal_shift, 73 | ) for x in val 74 | ] 75 | ) 76 | else: 77 | self.subtitles[key] = np.array(val) 78 | # filter by episodes 79 | if self.verbose: 80 | print( 81 | f"Filtering to {self.setname} subtitles.", 82 | ) 83 | filtered_indices = np.where( 84 | np.isin(self.subtitles["episode_name"], self.setname_episode), 85 | )[0] 86 | self.filter_subtitles(filtered_indices) 87 | # filter by duration 88 | if self.verbose: 89 | print( 90 | "Filtering to subtitles with duration in" + 91 | f" [{subtitles_min_duration}, {subtitles_max_duration}].", 92 | ) 93 | filtered_indices = np.where( 94 | self.subtitles["duration"] <= subtitles_max_duration, 95 | )[0] 96 | filtered_indices = np.intersect1d( 97 | filtered_indices, 98 | np.where( 99 | self.subtitles["duration"] >= subtitles_min_duration, 100 | )[0], 101 | ) 102 | self.filter_subtitles(filtered_indices) 103 | # info file 104 | self.info_file_idx = {} 105 | with open(info_pkl, "rb") as pickle_f: 106 | info_file = pickle.load(pickle_f)["videos"] 107 | self.length = info_file["videos"]["T"] 108 | for vid_idx, vid_name in enumerate(info_file["name"]): 109 | if vid_name.split(".")[0] in self.setname_episode: 110 | self.info_file_idx[vid_name] = vid_idx 111 | del info_file 112 | 113 | self.nltk_stop_words = None 114 | if filter_stop_words: 115 | self.nltk_stop_words = { 116 | 'ourselves', 'hers', 'between', 'yourself', 'but', 'again', 117 | 'there', 'about', 'once', 'during', 'out', 'very', 'having', 118 | 'with', 'they', 'own', 'an', 'be', 'some', 'for', 'do', 'its', 119 | 'yours', 'such', 'into', 'of', 'most', 'itself', 'other', 'off', 120 | 'is', 's', 'am', 'or', 'who', 'as', 'from', 'him', 'each', 'the', 121 | 'themselves', 'until', 'below', 'are', 'we', 'these', 'your', 'his', 122 | 'through', 'don', 'nor', 'me', 'were', 'her', 'more', 'himself', 'this', 123 | 'down', 'should', 'our', 'their', 'while', 'above', 'both', 'up', 'to', 'ours', 124 | 'had', 'she', 'all', 'no', 'when', 'at', 'any', 'before', 'them', 'same', 'and', 125 | 'been', 'have', 'in', 'will', 'on', 'does', 'yourselves', 'then', 'that', 126 | 'over', 'so', 'can', 'did', 'not', 'now', 'under', 'he', 'you', 127 | 'herself', 'has', 'just', 'where', 'too', 'only', 'myself', 'which', 'those', 128 | 'i', 'after', 'few', 'whom', 't', 'being', 'if', 'theirs', 'my', 'against', 'a', 129 | 'by', 'doing', 'it', 'how', 'further', 'was', 'here', 'than', 130 | } # removed 'what', 'why' and 'because' 131 | 132 | @staticmethod 133 | def convert_strtime_to_seconds( 134 | time: str, temporal_shift: float, 135 | ) -> float: 136 | """ 137 | Convert a string time in the format 
HH:MM:SS.SSS to seconds 138 | with a potential additional temporal shift. 139 | 140 | Args: 141 | time (str): time in the format HH:MM:SS.SSS 142 | temporal_shift (float): additional temporal shift in seconds 143 | """ 144 | if isinstance(time, float): 145 | return time + temporal_shift 146 | else: 147 | assert isinstance(time, str) 148 | time = time.split(":") 149 | time = [float(x) for x in time] 150 | time = sum([x * y for x, y in zip(time, [3600, 60, 1, 1e-3])]) 151 | time += temporal_shift 152 | return time 153 | 154 | def filter_subtitles( 155 | self, filtered_indices: np.ndarray, 156 | ) -> None: 157 | """ 158 | Filter self.subtitles wrt. filtered_indices. 159 | Args: 160 | filtered_indices (np.ndarray): indices to keep 161 | """ 162 | previous_length = len(self.subtitles["episode_name"]) 163 | # filtering each key 164 | for key, val in self.subtitles.items(): 165 | self.subtitles[key] = val[filtered_indices] 166 | if self.verbose: 167 | print( 168 | f"\tFrom {previous_length} subtitles," + 169 | f" {len(filtered_indices)} are kept.", 170 | ) 171 | 172 | def shuffle(self) -> None: 173 | """Shuffle all subtitles.""" 174 | shuffled_indices = np.arange(len(self.subtitles["episode_name"])) 175 | for key, val in self.subtitles.items(): 176 | self.subtitles[key] = val[shuffled_indices] 177 | 178 | def __len__(self) -> int: 179 | return len(self.subtitles["episode_name"]) 180 | 181 | def __getitem__(self, idx: int) -> dict: 182 | """Loads subtitles[idx]""" 183 | video_name = self.subtitles["episode_name"][idx] + ".mp4" 184 | if self.subtitles_random_offset is not None and self.subtitles_random_offset > 0: 185 | # adds a random offset to the subtitles 186 | # in (- self.subtitles_random_offset, self.subtitles_random_offset) 187 | sub_start = self.subtitles["start"][idx] - \ 188 | self.subtitles_temporal_pad 189 | sub_end = self.subtitles["end"][idx] + self.subtitles_temporal_pad 190 | random_start = np.random.uniform( 191 | sub_start - self.subtitles_random_offset, 192 | min(sub_start + self.subtitles_random_offset, sub_end - 1.0), 193 | ) # ensure that the random start is at least 1 second before the end 194 | random_end = np.random.uniform( 195 | max(random_start + 1.0, sub_end - self.subtitles_random_offset), 196 | sub_end + self.subtitles_random_offset, 197 | ) 198 | sub_start = max(0, random_start) 199 | # 0.32 is 8 frames at 25 fps (16f windows) 200 | sub_end = min( 201 | random_end, self.length[self.info_file_idx[video_name]] / self.fps - 0.32) 202 | else: 203 | sub_start = max( 204 | 0, self.subtitles["start"][idx] - self.subtitles_temporal_pad 205 | ) 206 | sub_end = min( 207 | self.subtitles["end"][idx] + self.subtitles_temporal_pad, 208 | self.length[self.info_file_idx[video_name]] / self.fps - 0.32, 209 | ) 210 | subtitle = self.subtitles["subtitle"][idx] 211 | 212 | if self.nltk_stop_words is not None: 213 | subtitle = " ".join( 214 | [word for word in subtitle.split() if word not in self.nltk_stop_words] 215 | ) 216 | if self.text_augmentations is not None: 217 | subtitle = self.text_augmentations(subtitle) 218 | 219 | if sub_end - sub_start <= 0.5: 220 | # change the start of the subtitle so that it is at least 1 second long 221 | sub_start = max(0, sub_end - 1.0) 222 | subtitles, sub_starts, sub_ends, video_names = subtitle, sub_start, sub_end, video_name 223 | return { 224 | "subtitle": subtitles, 225 | "sub_start": sub_starts, 226 | "sub_end": sub_ends, 227 | "video_name": video_names, 228 | } 229 | 
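For reference, a minimal usage sketch of the `Subtitles` dataset above, with placeholder paths (in practice these values are supplied through the Hydra config):

```python
from dataset.subtitles import Subtitles

# All paths below are placeholders for the files shipped with the public release.
dataset = Subtitles(
    subset2episode="path/to/subset2episode.json",
    setname="val",
    subtitles_path="path/to/subtitles.pkl",
    subtitles_temporal_shift=0.0,
    subtitles_max_duration=20.0,
    subtitles_min_duration=1.0,
    temporal_pad=0.0,
    info_pkl="path/to/info.pkl",
    verbose=True,
)
sample = dataset[0]  # dict with keys "subtitle", "sub_start", "sub_end", "video_name"
print(sample["video_name"], sample["sub_start"], sample["sub_end"], sample["subtitle"])
```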
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main file to train the CSLR2 model 3 | """ 4 | import os 5 | import shutil 6 | from typing import Optional 7 | 8 | import humanize 9 | import hydra 10 | import lmdb 11 | import pandas as pd 12 | import torch 13 | import torch.nn as nn 14 | import wandb 15 | from omegaconf import DictConfig 16 | from tabulate import tabulate 17 | from torch.nn.parallel import DistributedDataParallel as DDP 18 | 19 | from loops.train_loop import train_loop 20 | from loops.val_loop import val_loop 21 | from loops.retrieval_loop import retrieval_loop 22 | from utils.ddp_settings import ddp_settings 23 | from utils.instantiate_dataloaders import ( 24 | instantiate_dataloaders, instantiate_test_dataloader, 25 | instantiate_vis_dataloaders, skip_epochs 26 | ) 27 | from utils.instantiate_model import ( 28 | handle_model_freeze, instantiate_model, load_checkpoint 29 | ) 30 | from utils.seed import setup_seed 31 | from utils.wandb_utils import log_retrieval_performances, wandb_setup 32 | 33 | 34 | def log_and_save( 35 | model: nn.Module, 36 | opt: torch.optim.Optimizer, 37 | train_loss: float, 38 | val_loss: float, 39 | train_sent_ret: float, 40 | val_sent_ret: float, 41 | train_sign_ret: float, 42 | val_sign_ret: float, 43 | train_sign_cls: float, 44 | val_sign_cls: float, 45 | vis_train_loader: torch.utils.data.DataLoader, 46 | vis_val_loader: torch.utils.data.DataLoader, 47 | rgb_lmdb_env: lmdb.Environment, 48 | cfg: DictConfig, 49 | epoch: int, 50 | best_t2v: int, 51 | ): 52 | """Log and save checkpoint.""" 53 | if cfg.do_print: 54 | epoch_log = { 55 | "Split": ["Train", "Val"], 56 | "Loss": [train_loss, val_loss], 57 | "SentRet": [train_sent_ret, val_sent_ret], 58 | "SignRet": [train_sign_ret, val_sign_ret], 59 | "SignCls": [train_sign_cls, val_sign_cls], 60 | } 61 | train_v2t, train_t2v = retrieval_loop( 62 | model, vis_train_loader, rgb_lmdb_env, cfg, "train", epoch, 63 | ) 64 | val_v2t, val_t2v = retrieval_loop( 65 | model, vis_val_loader, rgb_lmdb_env, cfg, "val", epoch 66 | ) 67 | epoch_log["T2V R@1"] = [train_t2v["R1"], val_t2v["R1"]] 68 | epoch_log["T2V R@5"] = [train_t2v["R5"], val_t2v["R5"]] 69 | epoch_log["T2V MedR"] = [train_t2v["MedR"], val_t2v["MedR"]] 70 | epoch_log["V2T R@1"] = [train_v2t["R1"], val_v2t["R1"]] 71 | epoch_log["V2T R@5"] = [train_v2t["R5"], val_v2t["R5"]] 72 | epoch_log["V2T MedR"] = [train_v2t["MedR"], val_v2t["MedR"]] 73 | log_df = pd.DataFrame(epoch_log) 74 | # display as table 75 | print("") 76 | print( 77 | tabulate( 78 | log_df, 79 | headers="keys", 80 | tablefmt="presto", 81 | showindex="never", 82 | floatfmt=".2f", 83 | ) 84 | ) 85 | print("") 86 | 87 | model_path = cfg.paths.log_dir + \ 88 | "/models/model_" + str(epoch + 1) + ".pth" 89 | if not os.path.exists(cfg.paths.log_dir + "/models/"): 90 | os.makedirs(cfg.paths.log_dir + "/models/") 91 | print(f"Saving model to {model_path}") 92 | model_state_dict = model.state_dict() 93 | torch.save( 94 | { 95 | "model_state_dict": model_state_dict, 96 | "optimizer_state_dict": opt.state_dict(), 97 | "epoch": epoch + 1, 98 | "loss": train_loss, 99 | }, 100 | model_path 101 | ) 102 | if best_t2v <= val_t2v["R1"]: 103 | best_t2v = val_t2v["R1"] 104 | model_path = cfg.paths.log_dir + "/models/model_best.pth" 105 | print(f"Saving new best model to {model_path}") 106 | torch.save( 107 | { 108 | "model_state_dict": model_state_dict, 109 | 
"optimizer_state_dict": opt.state_dict(), 110 | "epoch": epoch + 1, 111 | "loss": train_loss, 112 | }, 113 | model_path 114 | ) 115 | # log in wandb 116 | log_retrieval_performances( 117 | train_v2t, train_t2v, val_v2t, val_t2v, epoch 118 | ) 119 | wandb.log( 120 | { 121 | "train_loss_epoch": train_loss, 122 | "train_sent_ret_epoch": train_sent_ret, 123 | "train_sign_ret_epoch": train_sign_ret, 124 | "train_sign_cls_epoch": train_sign_cls, 125 | "val_loss_epoch": val_loss, 126 | "val_sent_ret_epoch": val_sent_ret, 127 | "val_sign_ret_epoch": val_sign_ret, 128 | "val_sign_cls_epoch": val_sign_cls, 129 | "epoch": epoch + 1, 130 | } 131 | ) 132 | 133 | 134 | def log_test_retrieval( 135 | test_t2v: dict, 136 | test_v2t: dict, 137 | cfg: DictConfig, 138 | ): 139 | """Log test retrieval performances.""" 140 | if cfg.do_print: 141 | log_dict = { 142 | "T2V R@1": test_t2v["R1"], 143 | "T2V R@5": test_t2v["R5"], 144 | "T2V R@10": test_t2v["R10"], 145 | "T2V R@50": test_t2v["R50"], 146 | "T2V MedR": test_t2v["MedR"], 147 | "T2V MeanR": test_t2v["MeanR"], 148 | "T2V geometric_mean_R1-R5-R10": test_t2v["geometric_mean_R1-R5-R10"], 149 | "V2T R@1": test_v2t["R1"], 150 | "V2T R@5": test_v2t["R5"], 151 | "V2T R@10": test_v2t["R10"], 152 | "V2T R@50": test_v2t["R50"], 153 | "V2T MedR": test_v2t["MedR"], 154 | "V2T MeanR": test_v2t["MeanR"], 155 | "V2T geometric_mean_R1-R5-R10": test_v2t["geometric_mean_R1-R5-R10"], 156 | } 157 | log_df = pd.DataFrame(log_dict, index=[0]) 158 | # display as table 159 | print("") 160 | print( 161 | tabulate( 162 | log_df, 163 | headers="keys", 164 | tablefmt="presto", 165 | showindex="never", 166 | floatfmt=".2f", 167 | ) 168 | ) 169 | print("") 170 | wandb.log(log_dict) 171 | 172 | 173 | @hydra.main(version_base=None, config_path="config", config_name="cslr2") 174 | def main(cfg: Optional[DictConfig] = None) -> None: 175 | """Main function""" 176 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 177 | # setup configuration 178 | cfg = ddp_settings(cfg) 179 | if cfg.wandb_offline: 180 | os.environ["WANDB_MODE"] = "offline" 181 | cfg.do_print = not cfg.distributed or cfg.rank == 0 182 | if not os.path.isdir(cfg.paths.log_dir): 183 | os.makedirs(cfg.paths.log_dir) 184 | print(f"Created {cfg.paths.log_dir}") 185 | cfg.paths.log_dir += f"{cfg.run_name}" 186 | if not os.path.isdir(cfg.paths.log_dir): 187 | os.mkdir(cfg.paths.log_dir) 188 | print(f"Created {cfg.paths.log_dir}") 189 | print(f"Logging to {cfg.paths.log_dir}") 190 | 191 | columns = shutil.get_terminal_size().columns 192 | # avoid deadlock in dataloader + too many open files errors 193 | torch.multiprocessing.set_sharing_strategy("file_system") 194 | 195 | # setup seed 196 | setup_seed(seed=cfg.seed) 197 | 198 | # setup logging 199 | wandb_setup(cfg, setname="CSLR2_train") 200 | 201 | # create model 202 | model = instantiate_model(cfg) 203 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 204 | model.to(device) 205 | model = handle_model_freeze(model, cfg) 206 | 207 | # optimiser 208 | opt = hydra.utils.instantiate(cfg.optimizer, params=model.parameters()) 209 | 210 | # load checkpoint 211 | model, opt = load_checkpoint(cfg, model, opt, device) 212 | 213 | # parameter count 214 | param_count = humanize.intword(sum(p.numel() for p in model.parameters())) 215 | trainable_param_count = humanize.intword( 216 | sum( 217 | p.numel() for p in model.parameters() if p.requires_grad 218 | ) 219 | ) 220 | print( 221 | f"Model has {param_count} parameters ({trainable_param_count} trainable)") 222 | 223 | if 
cfg.distributed: 224 | model = DDP(model, device_ids=[cfg.local_rank]) 225 | 226 | if cfg.vis and cfg.do_print: 227 | rgb_frames = cfg.paths.rgb_frames 228 | rgb_lmdb_env = lmdb.open( 229 | rgb_frames, readonly=True, lock=False, max_readers=512 230 | ) 231 | 232 | if cfg.test: 233 | # perform retrieval on the manually aligned test set with the loaded checkpoint 234 | test_loader = instantiate_test_dataloader(cfg) 235 | test_v2t, test_t2v = retrieval_loop( 236 | model=model, 237 | vis_loader=test_loader, 238 | rgb_lmdb_env=rgb_lmdb_env, 239 | setname="test", 240 | epoch=0, 241 | cfg=cfg, 242 | ) 243 | log_test_retrieval(test_t2v, test_v2t, cfg) 244 | 245 | else: 246 | best_t2v = -1 247 | # create dataset + dataloader 248 | train_loader, val_loader, train_skip_mode, val_skip_mode = instantiate_dataloaders( 249 | cfg) 250 | print(f"Train dataloader size: {len(train_loader)}") 251 | print(f"Val dataloader size: {len(val_loader)}") 252 | vis_train_loader, vis_val_loader = instantiate_vis_dataloaders(cfg) 253 | print(f"Train vis dataloader size: {len(vis_train_loader)}") 254 | print(f"Val vis dataloader size: {len(vis_val_loader)}") 255 | 256 | # eventually skip epochs 257 | train_loader, val_loader = skip_epochs( 258 | cfg, train_loader, val_loader, 259 | train_skip_mode, val_skip_mode, 260 | ) 261 | 262 | # loss function 263 | sent_ret_loss_fn = hydra.utils.instantiate(cfg.loss.sent_ret) 264 | # small hack to avoid getting unused_parameters 265 | # in ddp mode when not using SignRet and SignCls 266 | remove_all = ( 267 | cfg.loss.lda_sign_ret == 0 and cfg.loss.lda_sign_cls == 0 268 | ) 269 | sign_ret_loss_fn = hydra.utils.instantiate(cfg.loss.sign_ret) \ 270 | if (cfg.loss.lda_sign_ret > 0 or remove_all) else None 271 | sign_cls_loss_fn = hydra.utils.instantiate(cfg.loss.sign_cls) \ 272 | if (cfg.loss.lda_sign_cls > 0 or remove_all) else None 273 | 274 | for epoch in range(cfg.trainer.epoch_start, cfg.trainer.epochs): 275 | print("") 276 | print("-" * columns) 277 | print( 278 | f"Epoch {epoch + 1}/{cfg.trainer.epochs}".center(columns) 279 | ) 280 | train_loss, train_sent_ret, train_sign_ret, train_sign_cls = train_loop( 281 | model=model, 282 | opt=opt, 283 | sent_ret_loss_fn=sent_ret_loss_fn, 284 | sign_ret_loss_fn=sign_ret_loss_fn, 285 | sign_cls_loss_fn=sign_cls_loss_fn, 286 | train_loader=train_loader, 287 | epoch=epoch, 288 | cfg=cfg, 289 | ) 290 | val_loss, val_sent_ret, val_sign_ret, val_sign_cls = val_loop( 291 | model=model, 292 | sent_ret_loss_fn=sent_ret_loss_fn, 293 | sign_ret_loss_fn=sign_ret_loss_fn, 294 | sign_cls_loss_fn=sign_cls_loss_fn, 295 | val_loader=val_loader, 296 | epoch=epoch, 297 | cfg=cfg, 298 | ) 299 | log_and_save( 300 | model, opt, 301 | train_loss, val_loss, 302 | train_sent_ret, val_sent_ret, 303 | train_sign_ret, val_sign_ret, 304 | train_sign_cls, val_sign_cls, 305 | vis_train_loader, vis_val_loader, 306 | rgb_lmdb_env, cfg, epoch, best_t2v, 307 | ) 308 | print("Training complete!") 309 | 310 | 311 | if __name__ == "__main__": 312 | main() 313 | -------------------------------------------------------------------------------- /extract_for_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python file to extract features from the trained model for evaluation. 
3 | """ 4 | import os 5 | import pickle 6 | from glob import glob 7 | from operator import itemgetter 8 | from typing import Optional, Union 9 | 10 | import hydra 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from einops import rearrange 17 | from omegaconf import DictConfig 18 | from tqdm import tqdm 19 | 20 | from utils.instantiate_model import instantiate_model 21 | from utils.synonyms import fix_synonyms_dict, synonym_combine 22 | 23 | 24 | def load_model(cfg: DictConfig): 25 | """ 26 | Load model from checkpoint. 27 | """ 28 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | if cfg.swin: 30 | # need to load swin model 31 | chkpt = torch.load( 32 | cfg.checkpoint, 33 | map_location=device, 34 | )["state_dict"] 35 | W = chkpt["module.generator.weight"] 36 | b = chkpt["module.generator.bias"] 37 | model = {"W": W, "b": b} 38 | else: 39 | model = instantiate_model(cfg) 40 | model.to(device) 41 | model.eval() 42 | # load checkpoint 43 | checkpoint = torch.load(cfg.checkpoint, map_location=device) 44 | # remove module. prefix 45 | model_state_dict = checkpoint["model_state_dict"] 46 | model.state_dict = { 47 | k.replace( 48 | "module.", "" 49 | ).replace( 50 | "self_attn", "self_attention" 51 | ): v for k, v in model_state_dict.items() 52 | } 53 | # !!! Be careful with strict=False 54 | # some weights might not be loaded 55 | # and the model might not work as expected 56 | # this will happen silently 57 | model.load_state_dict(model.state_dict, strict=False) 58 | print(f"Loaded checkpoint {cfg.checkpoint}") 59 | return model, device 60 | 61 | 62 | def load_text_files( 63 | cfg: DictConfig, 64 | model: Optional[Union[nn.Module, dict]] = None, 65 | device: Optional[torch.device] = None, 66 | ): 67 | """ 68 | Load word embeddings, vocabulary and synonyms. 69 | """ 70 | word_embds = None 71 | if model is not None: 72 | # load word embeddings 73 | word_embds = pickle.load(open(cfg.paths.word_embds_pkl, "rb")) 74 | if isinstance(word_embds, dict): 75 | # need to convert to list 76 | word_embds = [val for _, val in word_embds.items()] 77 | word_embds = torch.stack(word_embds).to(device) 78 | if not cfg.swin: 79 | word_embds = model.project_word_embeddings(word_embds) 80 | 81 | # load vocab 82 | vocab = pickle.load(open(cfg.paths.vocab_pkl, "rb")) 83 | if "words_to_id" in vocab.keys(): 84 | vocab = vocab["words_to_id"] 85 | id2word = {v: k for k, v in vocab.items()} 86 | # load synonyms 87 | synonyms = pickle.load(open(cfg.paths.synonyms_pkl, "rb")) 88 | synonyms = fix_synonyms_dict(synonyms) 89 | return word_embds, vocab, id2word, synonyms 90 | 91 | 92 | def create_dirs( 93 | out_dir: str, 94 | ): 95 | """ 96 | Create directories for saving features. 
97 | """ 98 | if not os.path.isdir(out_dir): 99 | os.makedirs(out_dir) 100 | feats_save_dir_path = os.path.join(out_dir, "features") 101 | if not os.path.isdir(feats_save_dir_path): 102 | os.makedirs(feats_save_dir_path) 103 | classif_save_dir_path = os.path.join(out_dir, "classification") 104 | if not os.path.isdir(classif_save_dir_path): 105 | os.makedirs(classif_save_dir_path) 106 | nn_save_dir_path = os.path.join(out_dir, "nn") 107 | if not os.path.isdir(nn_save_dir_path): 108 | os.makedirs(nn_save_dir_path) 109 | return feats_save_dir_path, classif_save_dir_path, nn_save_dir_path 110 | 111 | 112 | def update_syn_combine( 113 | logits: torch.Tensor, 114 | pred_dict: dict, 115 | dict_key: str, 116 | synonyms: dict, 117 | id2word: dict, 118 | vocab: dict, 119 | synonym_grouping: bool, 120 | ) -> dict: 121 | """ 122 | Update prediction dictionaries with synonym combine if needed. 123 | 124 | Args: 125 | logits (torch.Tensor): logits tensor 126 | pred_dict (dict): prediction dict to update 127 | dict_key (str): key to update 128 | synonyms (dict): synonym dict 129 | id2word (dict): id2word dict 130 | vocab (dict): vocab dict 131 | synonym_grouping (bool): whether to use synonym grouping or not 132 | 133 | Returns: 134 | pred_dict (dict): updated prediction dict 135 | """ 136 | top5_probs, top5_labels = torch.topk( 137 | logits, k=5, dim=-1 138 | ) 139 | top5_probs = top5_probs.cpu().numpy() 140 | top5_labels = top5_labels.cpu().numpy() 141 | if synonym_grouping: 142 | labels = rearrange(top5_labels, "t k -> (t k)") 143 | words = itemgetter(*labels)(id2word) 144 | words = rearrange( 145 | np.array(words), "(t k) -> t k", k=5 146 | ) 147 | new_words, new_probs = [], [] 148 | for word, prob in zip(words, top5_probs): 149 | new_prob, new_word = synonym_combine( 150 | word, prob, synonyms, 151 | ) 152 | new_words.append(new_word) 153 | new_probs.append(new_prob) 154 | new_words = np.array(new_words) 155 | new_words = rearrange(new_words, "t k -> (t k)") 156 | labels = itemgetter(*new_words)(vocab) 157 | labels = rearrange( 158 | np.array(labels), "(t k) -> t k", k=5 159 | ) 160 | top5_labels = labels 161 | top5_probs = np.array(new_probs) 162 | 163 | # update pred_dict 164 | pred_dict[dict_key] = { 165 | "labels": [top5_labels], 166 | "probs": [top5_probs], 167 | "logits": [], 168 | } 169 | return pred_dict 170 | 171 | 172 | def save_dicts( 173 | features: dict, 174 | feats_save_dir_path: str, 175 | classification: dict, 176 | classif_save_dir_path: str, 177 | nn_classification: dict, 178 | nn_save_dir_path: str, 179 | vid_name: str, 180 | ) -> None: 181 | """ 182 | Save predictions that are stored in dictionaries. 
183 | 184 | Args: 185 | features (dict): features dict 186 | feats_save_dir_path (str): features save dir path 187 | classification (dict): classification dict 188 | classif_save_dir_path (str): classification save dir path 189 | nn_classification (dict): nn classification dict 190 | nn_save_dir_path (str): nn classification save dir path 191 | vid_name (str): video name 192 | """ 193 | vid_name = f"{vid_name.split('.')[0]}.pkl" 194 | with open(os.path.join(feats_save_dir_path, vid_name), "wb") as feats_f: 195 | pickle.dump(features, feats_f) 196 | with open(os.path.join(classif_save_dir_path, vid_name), "wb") as classif_f: 197 | pickle.dump(classification, classif_f) 198 | with open(os.path.join(nn_save_dir_path, vid_name), "wb") as nn_f: 199 | pickle.dump(nn_classification, nn_f) 200 | 201 | 202 | @hydra.main(version_base=None, config_path="config", config_name="cslr2_eval.yaml") 203 | def main(cfg: Optional[DictConfig] = None) -> None: 204 | """ 205 | Main Funcion to extract features from trained model for evaluation. 206 | """ 207 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 208 | if cfg.swin: 209 | out_dir = cfg.checkpoint 210 | else: 211 | out_dir = os.path.dirname( 212 | os.path.dirname(cfg.checkpoint) 213 | ) 214 | 215 | # load model 216 | model, device = load_model(cfg) 217 | # load other files 218 | word_embds, vocab, id2word, synonyms = load_text_files( 219 | cfg, model, device, 220 | ) 221 | 222 | # extract features (first loop on splits and then on csv roots) 223 | for split in ["train", "val", "test"]: 224 | # load dataset 225 | setname = split if split != "test" else "public_test" 226 | dataset = hydra.utils.instantiate( 227 | cfg.dataset, 228 | setname=setname, 229 | subtitles_max_duration=1000000.0, 230 | subtitles_min_duration=0.0, 231 | ) 232 | for csv_root in [cfg.paths.misaligned_csv_root, cfg.paths.heuristic_aligned_csv_root]: 233 | # 0 is the tolerance 234 | split_csv_root = os.path.join(csv_root, f"0/{split}") 235 | all_csvs = list(glob(os.path.join(split_csv_root, "*.csv"))) 236 | print(f"Found {len(all_csvs)} csvs in {split_csv_root}") 237 | # create the save directory 238 | save_dir_path = os.path.join( 239 | out_dir, 240 | os.path.join(csv_root.split("/")[-2], "eval") 241 | ) 242 | feats_save_dir_path, classif_save_dir_path, nn_save_dir_path = create_dirs( 243 | save_dir_path 244 | ) 245 | for csv in tqdm(all_csvs): 246 | gt_df = pd.read_csv(csv, delimiter=",") 247 | starts = gt_df["start_sub"].tolist() 248 | ends = gt_df["end_sub"].tolist() 249 | subs = gt_df["english sentence"].tolist() 250 | vid_name = os.path.basename(csv) 251 | features = {} 252 | classification = {} 253 | nn_classification = {} 254 | assert len(starts) == len(ends) and len(starts) == len(subs) 255 | for start, end, sub in zip(starts, ends, subs): 256 | start, end = float(start), float(end) 257 | start = max(0.0, start) 258 | end = min( 259 | end, 260 | dataset.subtitles.length[ 261 | dataset.subtitles.info_file_idx[ 262 | vid_name.replace(".csv", ".mp4") 263 | ] 264 | ] / 25 - 0.32, 265 | ) 266 | try: 267 | src = dataset.features.load_sequence( 268 | episode_name=vid_name, 269 | begin_frame=int(start * 25), 270 | end_frame=int(end * 25), 271 | ).to(device).unsqueeze(0) 272 | if cfg.swin: 273 | # swin model 274 | with torch.no_grad(): 275 | src = src.squeeze(0) 276 | logits = src @ model["W"].T + \ 277 | model["b"][None, :] 278 | logits = torch.nn.Softmax(dim=-1)(logits) 279 | dict_key = f"{round(start, 3)}--{round(end, 3)}" 280 | classification = update_syn_combine( 281 | logits, 
classification, dict_key, synonyms, 282 | id2word, vocab, cfg.synonym_grouping, 283 | ) 284 | 285 | else: 286 | with torch.no_grad(): 287 | cls_tokens, output_tensor = model.video_encoder( 288 | src) 289 | tokens = cls_tokens[:, 1:] if not model.no_video_encoder \ 290 | else cls_tokens 291 | feats = tokens.cpu().numpy() 292 | if model.video_token_ll is not None: 293 | video_tokens = model.project_token_embeddings( 294 | tokens 295 | ) 296 | else: 297 | # normalise 298 | video_tokens = F.normalize(tokens, dim=-1) 299 | video_tokens = video_tokens.squeeze(0) 300 | # compute sim matrix between video tokens and word embeddings 301 | sim_matrix = video_tokens @ word_embds.T 302 | if cfg.synonym_grouping: 303 | sim_matrix = torch.nn.Softmax( 304 | dim=-1 305 | )(sim_matrix / cfg.temp) 306 | dict_key = f"{round(start, 3)}--{round(end, 3)}" 307 | nn_classification = update_syn_combine( 308 | sim_matrix, nn_classification, dict_key, synonyms, 309 | id2word, vocab, cfg.synonym_grouping, 310 | ) 311 | features[dict_key] = feats 312 | 313 | # classification layer (from logits) 314 | output_tensor = output_tensor.squeeze(0) 315 | classification = update_syn_combine( 316 | output_tensor, classification, dict_key, synonyms, 317 | id2word, vocab, cfg.synonym_grouping, 318 | ) 319 | except AttributeError: 320 | print(f"Some error with {vid_name} at {start}--{end}") 321 | print(f"Sub: {sub}") 322 | save_dicts( 323 | features, feats_save_dir_path, 324 | classification, classif_save_dir_path, 325 | nn_classification, nn_save_dir_path, 326 | vid_name, 327 | ) 328 | 329 | 330 | if __name__ == "__main__": 331 | main() 332 | -------------------------------------------------------------------------------- /misc/process_cslr_json/fix_alignment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to fix alignment of subtitles. 3 | 4 | Heursitic implemented: 5 | - Assumption 1: glosses are well aligned. 6 | Only subtitles start and end times are to be modified. 7 | - For consecutive subtitles, look at boundary glosses: 8 | - If the boundary gloss better corresponds to the other subtitle, 9 | need to change start/end of the two subtitles. 10 | - If the boundary gloss better corresponds to the subtitle itself, 11 | no need to change. 12 | - If the boundary gloss is equally good for both subtitles, 13 | no need to change. 14 | 15 | In the naive implementation, glosses are considered to correspond to a subtitle if 16 | the gloss word is in the subtitle text. 17 | """ 18 | import argparse 19 | import glob 20 | import os 21 | import re 22 | from typing import List 23 | 24 | 25 | import pandas as pd 26 | from nltk.stem import WordNetLemmatizer 27 | 28 | 29 | def get_root_words( 30 | vocab_list: List[str], 31 | ) -> List[str]: 32 | """ 33 | Get the root words from a vocabulary 34 | 35 | Args: 36 | vocab_list (List[str]): vocabulary of words 37 | 38 | Returns: 39 | root_words (List[str]): root words. 
40 | """ 41 | wordnet_lemmatizer = WordNetLemmatizer() 42 | root_words = [ 43 | re.sub( 44 | r'[.,?!()"\']', '', word.lower().strip() 45 | ).split() 46 | for word in vocab_list 47 | ] 48 | root_words = [ 49 | word for words in root_words for word in words if word != "" 50 | ] 51 | root_words = set(root_words) 52 | root_words = [ 53 | wordnet_lemmatizer.lemmatize( 54 | wordnet_lemmatizer.lemmatize(word, pos="v"), 55 | pos="n" 56 | ) for word in root_words 57 | ] 58 | root_words = set(root_words) 59 | return root_words 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument( 65 | "--csv_file", 66 | type=str, 67 | required=True, 68 | help="Path to the csv file containing the subtitles." 69 | ) 70 | parser.add_argument( 71 | "--do_star", 72 | action="store_true", 73 | help="If set, will save the star annotations as well." 74 | ) 75 | parser.add_argument( 76 | "--max_shift", 77 | type=float, 78 | default=1.0, 79 | help="Maximum shift allowed for a subtitle (left or right), meaning it could be shifted by 2 * max_shift." # pylint: disable=line-too-long 80 | ) 81 | args = parser.parse_args() 82 | 83 | splits = ["train", "val", "test"] 84 | total_changes = { 85 | "all": {"start": 0, "end": 0, "double": 0, "bd_error": 0}, 86 | "train": {"start": 0, "end": 0, "double": 0, "bd_error": 0}, 87 | "val": {"start": 0, "end": 0, "double": 0, "bd_error": 0}, 88 | "test": {"start": 0, "end": 0, "double": 0, "bd_error": 0}, 89 | } 90 | for split in splits: 91 | print(f"Processing {split} split...") 92 | all_files = glob.glob(os.path.join(args.csv_file, f"0/{split}/*.csv")) 93 | print(f"Found {len(all_files)} files.") 94 | for csv_file in all_files: 95 | # open csv file in question 96 | gt_df = pd.read_csv(csv_file, delimiter=",") 97 | # hard copy 98 | gt_df = gt_df.copy() 99 | # sort by start_sub 100 | gt_df = gt_df.sort_values(by=["start_sub"], ascending=True) 101 | gt_glosses = gt_df["approx gloss sequence"].tolist() 102 | left_boundary = [] 103 | right_boundary = [] 104 | for gt_gloss in gt_glosses: 105 | if isinstance(gt_gloss, float): 106 | # no gt 107 | left_boundary.append("") 108 | right_boundary.append("") 109 | else: 110 | gt_gloss = gt_gloss.replace("]", "[").replace("'", "") 111 | gt_annots = gt_gloss.split("[")[:-1] 112 | gt_timings = gt_gloss.replace(" ", "/").replace("--", "-") 113 | gt_timings = gt_timings.split("[")[:-1] 114 | gt_times = gt_timings[1::2] 115 | gt_annots = gt_annots[::2] 116 | gt_times, gt_annots = zip( 117 | *sorted(zip(gt_times, gt_annots)) 118 | ) 119 | left_boundary.append( 120 | f"{gt_annots[0].strip().replace(' ', '-')} {gt_times[0]}" 121 | ) 122 | right_boundary.append( 123 | f"{gt_annots[-1].strip().replace(' ', '-')} {gt_times[-1]}" 124 | ) 125 | gt_df["left_boundary"] = left_boundary 126 | gt_df["right_boundary"] = right_boundary 127 | # need to loop by pairs of consecutive subtitles 128 | starts = gt_df["start_sub"].tolist() 129 | ends = gt_df["end_sub"].tolist() 130 | subs = gt_df["english sentence"].tolist() 131 | assert len(starts) == len(ends) and \ 132 | len(ends) == len(subs) and \ 133 | len(subs) == len(left_boundary) and \ 134 | len(left_boundary) == len(right_boundary) 135 | 136 | # loop to change both start and times of subtitles 137 | # loop from left to right in timeline 138 | updated_starts = [] 139 | updated_ends = [] 140 | starts_changes = 0 141 | ends_changes = 0 142 | double_changes = 0 # i.e no change at all (heuristic) 143 | for i in range(len(starts) - 1): 144 | if i == 0: 145 | # first subtitle 
start will not change 146 | updated_starts.append(starts[i]) 147 | 148 | first_sub = subs[i] 149 | second_sub = subs[i + 1] 150 | first_sub_end = ends[i] 151 | second_sub_start = starts[i + 1] 152 | if abs(second_sub_start - first_sub_end) < args.max_shift: 153 | # maximum shift should be of args.max_shift seconds 154 | 155 | # first check that the boundary word does not correspond to second subtitle 156 | # meaning that the first subtitle is ending too late 157 | # and that the second subtitle is starting too late as well 158 | first_sub_right_boundary = right_boundary[i] 159 | first_sub_words = get_root_words(first_sub.split(" ")) 160 | second_sub_words = get_root_words(second_sub.split(" ")) 161 | if first_sub_right_boundary in ["", " "]: 162 | # no annotation 163 | in_first_sub, in_second_sub = True, False 164 | else: 165 | try: 166 | first_sub_right_boundary_words, first_sub_right_boundary_times = \ 167 | first_sub_right_boundary.split(" ") 168 | first_sub_right_boundary_words = get_root_words( 169 | first_sub_right_boundary_words.split("/") 170 | ) 171 | in_first_sub, in_second_sub = False, False 172 | for first_sub_right_boundary_word in first_sub_right_boundary_words: 173 | if first_sub_right_boundary_word in first_sub_words: 174 | in_first_sub = True 175 | if first_sub_right_boundary_word in second_sub_words: 176 | in_second_sub = True 177 | except ValueError: 178 | # no annotation ==> no change 179 | in_first_sub, in_second_sub = True, False 180 | first_change = False 181 | if in_second_sub and not in_first_sub: 182 | first_change = True 183 | # means that the second subtitle should start earlier 184 | # and first subtitle should end earlier 185 | updated_starts.append( 186 | float( 187 | # to avoid same start and end times 188 | first_sub_right_boundary_times.split("-")[0] 189 | ) 190 | ) 191 | updated_ends.append( 192 | float( 193 | first_sub_right_boundary_times.split("-")[0] 194 | ) - 1e-8 195 | ) 196 | starts_changes += 1 197 | 198 | second_sub_left_boundary = left_boundary[i + 1] 199 | # second check that the boundary word does not correspond to first subtitle 200 | # meaning that the second subtitle is starting too late 201 | # and that the first subtitle is ending too early 202 | if second_sub_left_boundary in ["", " "]: 203 | in_first_sub, in_second_sub = False, True 204 | else: 205 | try: 206 | second_sub_left_boundary_words, second_sub_left_boundary_times = \ 207 | second_sub_left_boundary.split(" ") 208 | second_sub_left_boundary_words = get_root_words( 209 | second_sub_left_boundary_words.split("/") 210 | ) 211 | in_first_sub, in_second_sub = False, False 212 | for second_sub_left_boundary_word in second_sub_left_boundary_words: 213 | if second_sub_left_boundary_word in first_sub_words: 214 | in_first_sub = True 215 | if second_sub_left_boundary_word in second_sub_words: 216 | in_second_sub = True 217 | except ValueError: 218 | # no annotation ==> no change 219 | in_first_sub, in_second_sub = False, True 220 | second_change = False 221 | if in_first_sub and not in_second_sub: 222 | second_change = True 223 | if first_change: 224 | # both changes are needed, so no change 225 | updated_ends[-1] = first_sub_end 226 | updated_starts[-1] = second_sub_start 227 | double_changes += 1 228 | starts_changes -= 1 229 | else: 230 | # means that the first subtitle should end later 231 | # and second subtitle should start later 232 | updated_ends.append( 233 | float( 234 | second_sub_left_boundary_times.split( 235 | "-")[-1] 236 | ) 237 | ) 238 | updated_starts.append( 239 | 
float( 240 | second_sub_left_boundary_times.split( 241 | "-")[-1] 242 | ) + 1e-8 243 | ) 244 | ends_changes += 1 245 | 246 | if not (first_change or second_change): 247 | # no change 248 | updated_starts.append(starts[i + 1]) 249 | updated_ends.append(ends[i]) 250 | else: 251 | # no change 252 | updated_starts.append(starts[i + 1]) 253 | updated_ends.append(ends[i]) 254 | 255 | if i == len(starts) - 2: 256 | # last subtitle end will not change 257 | updated_ends.append( 258 | max( 259 | ends[i + 1], 260 | float( 261 | second_sub_left_boundary_times.split("-")[-1] 262 | ) 263 | ) 264 | ) 265 | try: 266 | assert len(updated_starts) == len(updated_ends) and \ 267 | len(updated_ends) == len(starts) 268 | except AssertionError: 269 | pass 270 | 271 | # check that the new starts are smaller than the new ends 272 | for idx, (start, end) in enumerate(zip(updated_starts, updated_ends)): 273 | if start >= end: 274 | print( 275 | f"Error with {csv_file}: {subs[idx]} | {start} | {end}", 276 | "Reverting to original start and end times." 277 | ) 278 | updated_starts[idx] = starts[idx] 279 | updated_ends[idx] = ends[idx] 280 | for sp in ["all", split]: 281 | total_changes[sp]["bd_error"] += 1 282 | 283 | gt_df["start sub (after alignement heuristic 1)"] = updated_starts 284 | gt_df["end sub (after alignement heuristic 1)"] = updated_ends 285 | # save the csv file 286 | gt_df.to_csv(csv_file, index=False) 287 | for sp in ["all", split]: 288 | total_changes[sp]["start"] += starts_changes 289 | total_changes[sp]["end"] += ends_changes 290 | total_changes[sp]["double"] += double_changes 291 | 292 | # open the star csv file in question 293 | if args.do_star: 294 | star_gt_df = pd.read_csv( 295 | csv_file.replace( 296 | "with_timings", "with_stars_with_timings2"), 297 | delimiter="," 298 | ) 299 | star_gt_df = star_gt_df.sort_values( 300 | by=["start_sub"], ascending=True 301 | ) 302 | star_gt_df["start sub (after alignement heuristic 1)"] = updated_starts 303 | star_gt_df["end sub (after alignement heuristic 1)"] = updated_ends 304 | star_gt_df.to_csv( 305 | csv_file.replace( 306 | "with_timings", "with_stars_with_timings2" 307 | ), 308 | index=False 309 | ) 310 | print(total_changes) 311 | -------------------------------------------------------------------------------- /utils/frame_level_evaluation_dict.py: -------------------------------------------------------------------------------- 1 | """Python file with all the functions used to work with the frame level evaluation dictionary.""" 2 | import os 3 | import pickle 4 | from copy import deepcopy 5 | from operator import itemgetter 6 | from typing import List, Optional, Union 7 | 8 | import numpy as np 9 | import pandas as pd 10 | from einops import rearrange 11 | from tqdm import tqdm 12 | 13 | from utils.cslr_metrics import get_labels_start_end_time 14 | from utils.root_words import get_root_words 15 | from utils.synonyms import synonym_combine 16 | 17 | def pred_pickles_to_frame_level_predictions( 18 | pred_pickles: Union[List[str], str], 19 | id2word_dict: dict, 20 | synonyms: Optional[dict] = None, 21 | automatic_annotations: bool = False, 22 | remove_synonym_grouping: bool = False, 23 | ) -> dict: 24 | """ 25 | Load predictions saved in pickle format. 26 | Convert to frame-level predictions. 27 | 28 | Args: 29 | pred_pickles (Union[List[str], str]): list of paths to pickle files 30 | or path to a single pickle file. 31 | id2word_dict (dict): dictionary mapping from index to word. 32 | synonyms (Optional[dict], optional): synonym dictionary. 
Defaults to None. 33 | If a synonym dictionary is provided, predictions will be combined between 34 | synonyms. 35 | automatic_annotations (bool): whether the predictions are automatic annotations. 36 | Defaults to False. 37 | remove_synonym_grouping (bool): whether to remove synonym grouping. 38 | 39 | Returns: 40 | dict: frame-level predictions with the following keys 41 | episode_name (List[str]): list of episode names 42 | sub_start (List[float]): list of start times 43 | sub_end (List[float]): list of end times 44 | labels (List[np.ndarray]): list of labels (frame-level) 45 | words (List[np.ndarray]): list of words (frame-level) 46 | probs (List[np.ndarray]): list of probabilities (frame-level) 47 | unique_key (List[str]): list of unique keys in format episode_name--start--end 48 | """ 49 | out_dict = { 50 | "episode_name": [], 51 | "sub_start": [], 52 | "sub_end": [], 53 | "labels": [], 54 | "words": [], 55 | "probs": [], 56 | "unique_key": [], 57 | } 58 | if synonyms is not None: 59 | word2id_dict = {v: k for k, v in id2word_dict.items()} 60 | if isinstance(pred_pickles, str): 61 | pred_pickles = [pred_pickles] 62 | for pred_pickle in tqdm(pred_pickles): 63 | # episode name 64 | episode_name = os.path.basename(pred_pickle).replace(".pkl", "") 65 | # load predictions 66 | predictions = pickle.load(open(pred_pickle, "rb")) 67 | for timings, preds in predictions.items(): 68 | # episode name 69 | out_dict["episode_name"].append(episode_name) 70 | # timings 71 | try: 72 | sub_start, sub_end = timings.split("--") 73 | except ValueError: 74 | sub_start, sub_end = timings.split("-") 75 | out_dict["sub_start"].append(float(sub_start)) 76 | out_dict["sub_end"].append(float(sub_end)) 77 | unique_key = f"{episode_name}--{float(sub_start):.2f}--{float(sub_end):.2f}" 78 | out_dict["unique_key"].append(unique_key) 79 | 80 | if not automatic_annotations: 81 | # labels and probds 82 | labels = np.array(preds["labels"][0]) 83 | probs = np.array(preds["probs"][0]) 84 | # check if batch size is not one 85 | if len(labels.shape) == 1: 86 | labels = np.expand_dims(labels, axis=0) 87 | probs = np.expand_dims(probs, axis=0) 88 | labels = rearrange(labels, "t k -> (t k)") 89 | words = itemgetter(*labels)(id2word_dict) 90 | words = rearrange(np.array(words), "(t k) -> t k", k=5) 91 | labels = rearrange(labels, "(t k) -> t k", k=5) 92 | if synonyms is not None and not remove_synonym_grouping: 93 | # synonym grouping 94 | new_words, new_probs = [], [] 95 | for word_top5, probs_top5 in zip(words, probs): 96 | new_probs_top5, new_word_top5 = synonym_combine( 97 | word_top5, probs_top5, synonyms 98 | ) 99 | new_words.append(new_word_top5) 100 | new_probs.append(new_probs_top5) 101 | new_words = np.array(new_words) 102 | probs = np.array(new_probs) 103 | new_words = rearrange(new_words, "t k -> (t k)") 104 | labels = itemgetter(*new_words)(word2id_dict) 105 | labels = rearrange(np.array(labels), "(t k) -> t k", k=5) 106 | words = rearrange(new_words, "(t k) -> t k", k=5) 107 | # only keep top 1 108 | labels, probs, words = labels[:, 0], probs[:, 0], words[:, 0] 109 | else: 110 | # labels and probs 111 | labels = np.array(preds["labels"]) 112 | probs = np.array(preds["probs"]) 113 | try: 114 | if len(labels) == 1: 115 | words = np.array([itemgetter(*labels)(id2word_dict)]) 116 | else: 117 | words = np.array(itemgetter(*labels)(id2word_dict)) 118 | except TypeError: 119 | words = [] 120 | # lemmatise words 121 | try: 122 | for word_idx, word in enumerate(words): 123 | if " " in word: 124 | words[word_idx] = 
word.replace(" ", "-") 125 | words = get_root_words(words) 126 | if len(words) != len(labels): 127 | raise ValueError("Length mismatch") 128 | assert len(words) == len(labels), print(words) 129 | except TypeError: 130 | pass 131 | out_dict["labels"].append(labels) 132 | out_dict["probs"].append(probs) 133 | out_dict["words"].append(words) 134 | return out_dict 135 | 136 | 137 | def gt_csvs_to_frame_level_gt( 138 | gt_csvs: Union[List[str], str], 139 | fps: int = 25, 140 | ) -> dict: 141 | """ 142 | Load ground truth saved in csv format. 143 | Convert to frame-level predictions. 144 | 145 | Args: 146 | gt_csvs (Union[List[str], str]): list of paths to csv files. 147 | 148 | Returns: 149 | dict: frame-level ground truth with the following keys 150 | episode_name (List[str]): list of episode names 151 | sub_start (List[float]): list of start times 152 | sub_end (List[float]): list of end times 153 | frame_ground_truth (List[np.ndarray]): list of ground truth (frame-level) 154 | segment_ground_truth (List[str]): list of ground truth (segment-level) 155 | raw_segment_ground_truth (List[List[str]]): list of ground truth 156 | (segment level, without collapsing) 157 | subtitles (List[np.ndarray]): list of subtitles 158 | unique_key (List[str]): list of unique keys in format episode_name--start--end 159 | """ 160 | out_dict = { 161 | "episode_name": [], 162 | "sub_start": [], 163 | "sub_end": [], 164 | "frame_ground_truth": [], 165 | "segment_ground_truth": [], 166 | "raw_segment_ground_truth": [], 167 | "subtitles": [], 168 | "unique_key": [], 169 | } 170 | 171 | if isinstance(gt_csvs, str): 172 | gt_csvs = [gt_csvs] 173 | for gt_csv in tqdm(gt_csvs): 174 | # episode_name 175 | episode_name = os.path.basename(gt_csv).replace(".csv", "") 176 | # load 177 | try: 178 | gt_df = pd.read_csv(gt_csv, delimiter=",") 179 | starts, ends = gt_df["start_sub"].tolist( 180 | ), gt_df["end_sub"].tolist() 181 | subs = gt_df["english sentence"].tolist() 182 | gt_glosses = gt_df["approx gloss sequence"].tolist() 183 | assert len(starts) == len(ends) and len(ends) == len( 184 | subs) and len(subs) == len(gt_glosses) 185 | for start, end, sub, gt_gloss in zip(starts, ends, subs, gt_glosses): 186 | if isinstance(gt_gloss, float) or start >= end: 187 | # no gt 188 | # the second condition should not happen with the new fix_alignement.py 189 | # temporary fix 190 | pass 191 | else: 192 | # episode name 193 | out_dict["episode_name"].append(episode_name) 194 | out_dict["sub_start"].append(float(start)) 195 | out_dict["sub_end"].append(float(end)) 196 | unique_key = f"{episode_name}--{start:.2f}--{end:.2f}" 197 | out_dict["unique_key"].append(unique_key) 198 | # frame-level ground truth 199 | gt_labels, gt_segment, gt_segment_raw = gloss_update( 200 | gt_gloss, 201 | start, 202 | end, 203 | fps=fps, 204 | ) 205 | out_dict["frame_ground_truth"].append(gt_labels) 206 | out_dict["segment_ground_truth"].append(gt_segment) 207 | out_dict["raw_segment_ground_truth"].append(gt_segment_raw) 208 | out_dict["subtitles"].append(sub) 209 | except Exception as gt_exception: 210 | print(f"Error with {gt_csv}: {gt_exception}") 211 | return out_dict 212 | 213 | 214 | def populate_combined_dict( 215 | combined_dictionary: dict, 216 | input_dictionary: dict, 217 | setname: str, 218 | ) -> dict: 219 | """ 220 | Populate combined dictionary with input dictionary. 
221 | 222 | Args: 223 | combined_dictionary (dict): combined dictionary 224 | input_dictionary (dict): input dictionary 225 | setname (str): set name, either "gt" or "pred" 226 | 227 | Returns: 228 | combined_dictionary: populated combined dictionary 229 | """ 230 | for key, value in input_dictionary.items(): 231 | if key in ["episode_name", "sub_start", "sub_end", "unique_key"]: 232 | key += f"_{setname}" 233 | combined_dictionary[key] = value 234 | return combined_dictionary 235 | 236 | 237 | def combine_gt_pred_dict( 238 | gt_dictionary: dict, 239 | pred_dictionary: dict, 240 | ) -> dict: 241 | """ 242 | Combine ground truth and predictions into a single dictionary. 243 | 244 | Args: 245 | gt_dictionary (dict): ground truth dictionary 246 | pred_dictionary (dict): predictions dictionary 247 | 248 | Returns: 249 | combined_dictionary: combined dictionary 250 | """ 251 | combined_dictionary = {} 252 | copy_pred_dictionary = deepcopy(pred_dictionary) 253 | 254 | pred_unique_keys = np.array(copy_pred_dictionary["unique_key"]) 255 | gt_unique_keys = np.array(gt_dictionary["unique_key"]) 256 | filtering, mapping = np.where( 257 | pred_unique_keys[:, None] == gt_unique_keys[None, :] 258 | ) 259 | 260 | # first filter out the predictions that don't have ground truth 261 | for key, value in copy_pred_dictionary.items(): 262 | copy_pred_dictionary[key] = np.array( 263 | value, dtype=object, 264 | )[filtering].tolist() 265 | # next, order gt to match pred in terms of order 266 | for key, value in gt_dictionary.items(): 267 | gt_dictionary[key] = np.array(value, dtype=object)[ 268 | mapping].tolist() 269 | # then populate the combined dictionary 270 | combined_dictionary = populate_combined_dict( 271 | combined_dictionary, gt_dictionary, "gt", 272 | ) 273 | combined_dictionary = populate_combined_dict( 274 | combined_dictionary, copy_pred_dictionary, "pred", 275 | ) 276 | return combined_dictionary 277 | 278 | 279 | def save_all_annots( 280 | combined_dictionary: dict, 281 | ) -> dict: 282 | """ 283 | Goes through all annotations and saves them in a dictionary. 284 | 285 | Args: 286 | combined_dictionary (dict): combined dictionary 287 | 288 | Returns: 289 | all_annots (dict): dictionary of all annotations. Keys are words, values are counts. 290 | """ 291 | all_annots = {} 292 | frame_wise_gt_labels = combined_dictionary["frame_ground_truth"] 293 | for gt_labels in frame_wise_gt_labels: 294 | gt_segments, _, _ = get_labels_start_end_time( 295 | gt_labels, 296 | bg_class=["no annotation"], 297 | ) 298 | for segment in gt_segments: 299 | for word in segment: 300 | try: 301 | all_annots[word] += 1 302 | except KeyError: 303 | all_annots[word] = 1 304 | return all_annots 305 | 306 | def gloss_update( 307 | gloss: str, 308 | start: Union[str, float], 309 | end: Union[str, float], 310 | fps: int, 311 | stars: bool = False, 312 | ): 313 | """ 314 | Computes frame-level ground truth from gloss (with timings). 315 | 316 | Args: 317 | gloss (str): string with glosses along timings 318 | start (Union[str, float]): start time 319 | end (Union[str, float]): end time 320 | fps (int): fps of the video 321 | stars (bool): whether behaviour has star annotations loading. 322 | Defaults to False. 
323 | 324 | Returns: 325 | labels (List[List[str]]): frame-level ground truth 326 | segment (str): segment-level ground truth in one string 327 | segment_raw (List[List[str]]): segment-level ground truth without collapsing 328 | """ 329 | labels = [["no annotation"]] * \ 330 | int((float(end) - float(start)) * fps) 331 | timings = gloss.replace("]", "[").replace("'", "") 332 | raw_annots = timings.split("[")[:-1][::2] 333 | segment_raw = [] 334 | for raw_annot in raw_annots: 335 | if (stars and "*" in raw_annot) or not stars: 336 | raw_annot = raw_annot.split("/") 337 | segment = [] 338 | for annot in raw_annot: 339 | if len(annot) > 0: 340 | if annot[0] == " ": 341 | annot = annot[1:] 342 | segment.append(annot) 343 | segment_raw.append(segment) 344 | timings = timings.replace(" ", "/").replace("--", "-") 345 | timings = timings.split('[')[:-1] 346 | annots = timings[::2] 347 | # lemmatise annotations 348 | annots = [ 349 | annot if annot[0] != "/" else annot[1:] 350 | for annot in annots 351 | ] 352 | annots_seg = get_root_words(annots, True) 353 | annots = get_root_words(annots) 354 | times = timings[1::2] 355 | assert len(annots) == len(times) 356 | for annot, time in zip(annots, times): 357 | has_star = True 358 | if stars: 359 | has_star = any(["*" in ann for ann in annot]) 360 | if has_star: 361 | start_time, end_time = time.split("-") 362 | start_time = float(start_time) 363 | end_time = float(end_time) 364 | start_idx = int( 365 | (start_time - float(start)) * fps 366 | ) 367 | end_idx = min( 368 | len(labels), 369 | int((end_time - float(start)) * fps) 370 | ) 371 | annot_len = end_idx - start_idx 372 | if annot_len < 1: 373 | # no annotation 374 | pass 375 | else: 376 | labels[start_idx:end_idx] = [annot] * annot_len 377 | return labels, " ".join(annots_seg), segment_raw 378 | -------------------------------------------------------------------------------- /models/transformer_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom Transformer Encoder model. 3 | """ 4 | import copy 5 | import math 6 | from typing import List, Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | 13 | 14 | def _xavier_uniform(module: nn.Module): 15 | """ 16 | Xavier uniform initialization for the weights of a module. 17 | 18 | Args: 19 | module (nn.Module): module to initialize. 20 | """ 21 | for _, params in module.named_parameters(): 22 | if params.dim() > 1: 23 | nn.init.xavier_uniform_(params) 24 | 25 | 26 | def clones( 27 | module: nn.Module, 28 | nb_clones: int, 29 | ) -> nn.ModuleList: 30 | """ 31 | Produce nb_clones identical layers. 32 | 33 | Args: 34 | module (nn.Module): module to clone. 35 | nb_clones (int): number of clones. 36 | 37 | Returns: 38 | nn.ModuleList: list of cloned modules. 39 | """ 40 | return nn.ModuleList([copy.deepcopy(module) for _ in range(nb_clones)]) 41 | 42 | 43 | class Encoder(nn.Module): 44 | """Encoder defined as a stack of N layers.""" 45 | def __init__( 46 | self, 47 | layer: nn.Module, 48 | N: int, 49 | final_norm: bool = True, 50 | ) -> None: 51 | """ 52 | Args: 53 | layer (nn.Module): layer to use. 54 | N (int): number of layers. 55 | final_norm (bool): whether to apply layer normalization at the end of the encoder. 
56 | """ 57 | super(Encoder, self).__init__() 58 | self.layers = clones(layer, N) 59 | if final_norm: 60 | self.norm = LayerNorm(layer.size) 61 | 62 | def forward( 63 | self, 64 | x: torch.tensor, 65 | mask: Optional[torch.tensor] = None, 66 | ) -> torch.tensor: 67 | """ 68 | Forward function of the encoder module. 69 | Pass the input (and mask) through each layer in turn. 70 | 71 | Args: 72 | x (torch.tensor): input tensor. 73 | mask (Optional[torch.tensor]): tensor of masks. 74 | 75 | Returns: 76 | torch.tensor: output tensor. 77 | """ 78 | for layer in self.layers: 79 | x = layer(x, mask) 80 | return (self.norm(x) if hasattr(self, 'norm') else x) 81 | 82 | 83 | class LayerNorm(nn.Module): 84 | """LayerNorm module.""" 85 | def __init__( 86 | self, 87 | size: List[int], 88 | eps: float = 1e-6, 89 | ) -> None: 90 | """ 91 | Args: 92 | size (List[int]): size of the layer. 93 | eps (float): epsilon value for numerical stability. 94 | """ 95 | super(LayerNorm, self).__init__() 96 | self.a_2 = nn.Parameter(torch.ones(size)) 97 | self.b_2 = nn.Parameter(torch.zeros(size)) 98 | self.eps = eps 99 | 100 | def forward( 101 | self, 102 | x: torch.tensor, 103 | ) -> torch.tensor: 104 | """Perform layer normalisation on the input tensor.""" 105 | mean = x.mean(-1, keepdim=True) 106 | std = x.std(-1, keepdim=True) 107 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 108 | 109 | 110 | class SublayerConnection(nn.Module): 111 | """Residual connection followed by a layer norm.""" 112 | def __init__( 113 | self, 114 | size: List[int], 115 | dropout: float, 116 | ) -> None: 117 | """ 118 | Args: 119 | size (List[int]): size of the layer. 120 | dropout (float): dropout rate. 121 | """ 122 | super(SublayerConnection, self).__init__() 123 | self.norm = LayerNorm(size) 124 | self.dropout = nn.Dropout(dropout) 125 | 126 | def forward( 127 | self, 128 | x: torch.tensor, 129 | sublayer: nn.Module, 130 | ): 131 | """ 132 | Apply residual connection to any sublayer with the same size. 133 | 134 | Args: 135 | x (torch.tensor): input tensor. 136 | sublayer (nn.Module): sublayer to apply. 137 | 138 | Returns: 139 | torch.tensor: output tensor. 140 | """ 141 | return x + self.dropout(sublayer(self.norm(x))) 142 | 143 | 144 | class EncoderLayer(nn.Module): 145 | """Encoder layer module""" 146 | def __init__( 147 | self, 148 | size: List[int], 149 | self_attention: nn.Module, 150 | feed_forward: nn.Module, 151 | dropout: float, 152 | ) -> None: 153 | """ 154 | Args: 155 | size (List[int]): size of the layer. 156 | self_attention (nn.Module): self-attention module. 157 | feed_forward (nn.Module): feed-forward module. 158 | dropout (float): dropout rate. 159 | """ 160 | super(EncoderLayer, self).__init__() 161 | self.self_attention = self_attention 162 | self.feed_forward = feed_forward 163 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 164 | self.size = size 165 | 166 | def forward( 167 | self, 168 | x: torch.tensor, 169 | mask: Optional[torch.tensor] = None, 170 | ) -> torch.tensor: 171 | """Perform forward pass on the input tensor.""" 172 | x = self.sublayer[0](x, lambda x: self.self_attention(x, x, x, mask)) 173 | return self.sublayer[1](x, self.feed_forward) 174 | 175 | 176 | class MultiHeadedAttention(nn.Module): 177 | """Multi-headed attention module.""" 178 | def __init__( 179 | self, 180 | h: int, 181 | d_model: int, 182 | dropout: float = 0.1, 183 | ) -> None: 184 | """ 185 | Args: 186 | h (int): number of heads. 187 | d_model (int): size of the model. 
188 | dropout (float): dropout rate. 189 | """ 190 | super(MultiHeadedAttention, self).__init__() 191 | assert d_model % h == 0 192 | # we assume d_v always equals d_k 193 | self.d_k = d_model // h 194 | self.h = h 195 | self.linears = clones(nn.Linear(d_model, d_model), 4) 196 | self.attn = None 197 | self.dropout = nn.Dropout(p=dropout) 198 | 199 | @staticmethod 200 | def attention( 201 | query: torch.tensor, 202 | key: torch.tensor, 203 | value: torch.tensor, 204 | mask: Optional[torch.tensor] = None, 205 | dropout: Optional[nn.Dropout] = None, 206 | ): 207 | """ 208 | Compute scaled dot product attention. 209 | 210 | Args: 211 | query (torch.tensor): query tensor. 212 | key (torch.tensor): key tensor. 213 | value (torch.tensor): value tensor. 214 | mask (Optional[torch.tensor]): tensor of masks. 215 | dropout (Optional[nn.Dropout]): dropout module. 216 | 217 | Returns: 218 | output tensor and attention weights. 219 | """ 220 | d_k = query.size(-1) 221 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 222 | if mask is not None: 223 | scores = scores.float() 224 | scores = scores.masked_fill(mask == 0, -1e9) 225 | p_attn = F.softmax(scores, dim=-1) 226 | if dropout is not None: 227 | p_attn = dropout(p_attn) 228 | return torch.matmul(p_attn, value), p_attn 229 | 230 | def forward( 231 | self, 232 | query: torch.tensor, 233 | key: torch.tensor, 234 | value: torch.tensor, 235 | mask: Optional[torch.tensor] = None, 236 | ) -> torch.tensor: 237 | """ 238 | Forward function of the multi-headed attention module. 239 | 240 | Args: 241 | query (torch.tensor): query tensor. 242 | key (torch.tensor): key tensor. 243 | value (torch.tensor): value tensor. 244 | mask (Optional[torch.tensor]): tensor of masks. 245 | 246 | Returns: 247 | torch.tensor: output tensor. 248 | """ 249 | if mask is not None: 250 | mask = mask.unsqueeze(1) 251 | nbatches = query.size(0) 252 | 253 | # 1) do all the linear projections in batch from d_model => h x d_k 254 | query, key, value = \ 255 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 256 | for l, x in zip(self.linears, (query, key, value))] 257 | 258 | # 2) apply attention on all the projected vectors in batch. 259 | x, self.attn = self.attention(query, key, value, mask=mask, dropout=self.dropout) 260 | 261 | # 3) "concat" using a view and apply a final linear. 262 | x = x.transpose(1, 2).contiguous() \ 263 | .view(nbatches, -1, self.h * self.d_k) 264 | return self.linears[-1](x) 265 | 266 | 267 | class PositionWiseFeedForward(nn.Module): 268 | """Implements FFN equation.""" 269 | def __init__( 270 | self, 271 | d_model: int, 272 | d_ff: int, 273 | dropout: float = 0.1, 274 | ) -> None: 275 | """ 276 | Args: 277 | d_model (int): size of the model. 278 | d_ff (int): size of the feed-forward layer. 279 | dropout (float): dropout rate. 280 | """ 281 | super(PositionWiseFeedForward, self).__init__() 282 | self.w_1 = nn.Linear(d_model, d_ff) 283 | self.w_2 = nn.Linear(d_ff, d_model) 284 | self.dropout = nn.Dropout(dropout) 285 | 286 | def forward(self, x: torch.tensor) -> torch.tensor: 287 | """Perform forward pass on the input tensor.""" 288 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 289 | 290 | 291 | class PositionalEncoding(nn.Module): 292 | """Positional Encoding module.""" 293 | def __init__( 294 | self, 295 | d_model: int, 296 | dropout: float, 297 | max_len: int = 5000, 298 | ) -> None: 299 | """ 300 | Args: 301 | d_model (int): size of the model. 302 | dropout (float): dropout rate. 
303 | max_len (int): maximum length of the sequence. 304 | """ 305 | super(PositionalEncoding, self).__init__() 306 | self.dropout = nn.Dropout(p=dropout) 307 | 308 | # compute the positional encodings once in log space 309 | pe = torch.zeros(max_len, d_model) 310 | position = torch.arange(0, max_len).unsqueeze(1) 311 | div_term = torch.exp(torch.arange(0, d_model, 2) * 312 | -(math.log(10000.0) / d_model)) 313 | pe[:, 0::2] = torch.sin(position * div_term) 314 | pe[:, 1::2] = torch.cos(position * div_term) 315 | pe = pe.unsqueeze(0) 316 | self.register_buffer('pe', pe) 317 | 318 | def forward(self, x: torch.tensor) -> torch.tensor: 319 | """Perform positional encoding on the input tensor.""" 320 | x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False) 321 | return self.dropout(x) 322 | 323 | 324 | class Embeddings(nn.Module): 325 | """Embedding module.""" 326 | def __init__( 327 | self, 328 | d_model: int, 329 | vocab: int, 330 | ) -> None: 331 | """ 332 | Args: 333 | d_model (int): size of the model. 334 | vocab (int): size of the vocabulary. 335 | """ 336 | super(Embeddings, self).__init__() 337 | self.lut = nn.Embedding(vocab, d_model) 338 | self.d_model = d_model 339 | 340 | def forward(self, x: torch.tensor) -> torch.tensor: 341 | """Perform forward pass on the input tensor using the embedding layer.""" 342 | return self.lut(x) * math.sqrt(self.d_model) 343 | 344 | 345 | class Generator(nn.Module): 346 | """Generator module (linear projection in vocab space).""" 347 | def __init__( 348 | self, 349 | d_model: int, 350 | vocab: int, 351 | ) -> None: 352 | """ 353 | Args: 354 | d_model (int): size of the model. 355 | vocab (int): size of the vocabulary. 356 | """ 357 | super(Generator, self).__init__() 358 | self.proj = nn.Linear(d_model, vocab) 359 | 360 | def forward(self, x: torch.tensor) -> torch.tensor: 361 | """Project the input tensor into the vocabulary space.""" 362 | return self.proj(x) 363 | 364 | 365 | class EncoderWithLinear(nn.Module): 366 | """Encoder with linear projection.""" 367 | def __init__( 368 | self, 369 | encoder: Encoder, 370 | generator: Generator, 371 | src_embed: Embeddings, 372 | d_model: int = 768, 373 | contrastive: bool = False, 374 | ) -> None: 375 | """ 376 | Args: 377 | encoder (Encoder): encoder module. 378 | generator (Generator): generator module. 379 | src_embed (Embeddings): positional encoding module. 380 | d_model (int): size of the model. 381 | contrastive (bool): whether to use contrastive learning. 382 | """ 383 | super(EncoderWithLinear, self).__init__() 384 | self.encoder = encoder 385 | self.generator = generator 386 | self.src_embed = src_embed 387 | self.contrastive = contrastive 388 | self.d_model = d_model 389 | self.cls_token = nn.Parameter(torch.randn(1, 1, d_model)) if self.contrastive else None 390 | 391 | def forward( 392 | self, 393 | src: Optional[torch.tensor] = None, 394 | src_mask: Optional[torch.tensor] = None, 395 | ): 396 | """ 397 | Forward function of the encoder module. 398 | 399 | Args: 400 | src (Optional[torch.tensor]): input tensor. 401 | src_mask (Optional[torch.tensor]): tensor of masks. 402 | 403 | Returns: 404 | If contrastive, return the encoder output 405 | and the linear projection of the encoder output. 406 | Else return the linear projection of the encoder output. 
407 | """ 408 | encoder_out, src_mask = self.encode(src, src_mask) 409 | if self.contrastive: 410 | return encoder_out, self.generator(encoder_out[:, 1:, :]) 411 | return self.generator(encoder_out) 412 | 413 | def encode( 414 | self, 415 | src: torch.tensor, 416 | src_mask: torch.tensor, 417 | ): 418 | """ 419 | Encode the input tensor. 420 | 421 | Args: 422 | src (torch.tensor): input tensor. 423 | src_mask (torch.tensor): tensor of masks. 424 | 425 | Returns: 426 | encoder output and mask. 427 | """ 428 | if isinstance(src, tuple): 429 | src, src_mask = src 430 | src_embeddings = self.src_embed(src) 431 | if self.contrastive: 432 | src_embeddings = torch.cat( 433 | [ 434 | self.cls_token.expand(src_embeddings.size(0), -1, -1), 435 | src_embeddings 436 | ], 437 | dim=1, 438 | ) 439 | return self.encoder(src_embeddings, src_mask), src_mask 440 | 441 | 442 | def make_model( 443 | vocab: int, 444 | N: int = 6, 445 | d_model: int = 512, 446 | d_ff: int = 2048, 447 | h: int = 8, 448 | dropout: float = 0.1, 449 | contrastive: bool = False, 450 | ) -> EncoderWithLinear: 451 | """ 452 | Function to create an instance of the EncoderWithLinear model. 453 | 454 | Args: 455 | vocab (int): size of the vocabulary. 456 | N (int): number of layers. 457 | d_model (int): size of the model. 458 | d_ff (int): size of the feed-forward layer. 459 | h (int): number of heads. 460 | dropout (float): dropout rate. 461 | contrastive (bool): whether to use contrastive learning. 462 | 463 | Returns: 464 | EncoderWithLinear: instance of the EncoderWithLinear model. 465 | """ 466 | c = copy.deepcopy 467 | attn = MultiHeadedAttention(h, d_model, dropout=dropout) 468 | ff = PositionWiseFeedForward(d_model, d_ff, dropout=dropout) 469 | position = PositionalEncoding(d_model, dropout) 470 | model = EncoderWithLinear( 471 | encoder=Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N), 472 | generator=Generator(d_model, vocab), 473 | src_embed=position, 474 | d_model=d_model, 475 | contrastive=contrastive, 476 | ) 477 | # initialise parameters with Glorot / fan_avg. 478 | for p in model.parameters(): 479 | if p.dim() > 1: 480 | nn.init.xavier_uniform_(p) 481 | return model 482 | --------------------------------------------------------------------------------
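models/transformer_encoder.py above follows the "Annotated Transformer" layout: make_model stacks N EncoderLayer blocks behind a PositionalEncoding front end and a Generator that projects every frame into the vocabulary. Because src_embed is only a positional encoding (the Embeddings module is not wired in by make_model), the returned model consumes pre-extracted per-frame features of width d_model rather than token ids. The snippet below is a minimal sketch of driving it that way; the vocabulary size, feature shapes and layer count are illustrative values, not taken from the repository configs, and the import assumes the repository root is on the Python path.

import torch

from models.transformer_encoder import make_model

# Illustrative values only; d_model must be divisible by the number of heads h.
vocab_size = 1000
batch, time, d_model = 2, 16, 768

model = make_model(
    vocab=vocab_size, N=3, d_model=d_model, d_ff=2048, h=8,
    dropout=0.1, contrastive=False,
)
model.eval()

# Per-frame features; PositionalEncoding is added inside the model.
feats = torch.randn(batch, time, d_model)
# 1 = attend, 0 = padded frame; shape (batch, 1, time) broadcasts over heads and query positions.
mask = torch.ones(batch, 1, time)

with torch.no_grad():
    logits = model(feats, mask)  # (batch, time, vocab_size): per-frame logits from the Generator

print(logits.shape)  # torch.Size([2, 16, 1000])

With contrastive=True, encode prepends the learned CLS token and forward returns both the raw encoder output and the per-frame projection of the non-CLS positions, so the attention mask would need one extra leading position (or can simply be left as None).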