├── data ├── __init__.py ├── datamodule.py └── emqa_dataset.py ├── eval ├── __init__.py ├── util.py └── eval.py ├── model ├── __init__.py ├── external │ ├── __init__.py │ ├── stm.py │ └── compressive_transformer.py ├── blind.py ├── simple_vqa.py ├── lt_ct.py ├── sparse_transformer.py ├── importance_teacher.py ├── lightning.py ├── base.py ├── mann.py ├── moment_loc.py └── rehearsal.py ├── tools ├── __init__.py ├── aggregate_features_to_hdf5.py ├── split_existing_data_aligned.py ├── create_pure_videoqa_json.py └── extract_ego4d_clip_features.py ├── .gitignore ├── config ├── model │ ├── blind.yaml │ ├── importance_teacher.yaml │ ├── moment_localization_loss │ │ ├── attn_summed.yaml │ │ └── attn_sample.yaml │ ├── simple_vqa.yaml │ ├── bigbird.yaml │ ├── longformer.yaml │ ├── stm.yaml │ ├── lt_ct.yaml │ ├── dnc.yaml │ └── rehearsal_mem.yaml ├── dataset │ └── ego4d.yaml └── base.yaml ├── .gitattributes ├── experiment └── run.sh ├── run.py ├── hydra_compat.py ├── requirements.txt ├── README.md └── lightning_util.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/external/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | venv 3 | .idea 4 | **.pyc -------------------------------------------------------------------------------- /config/model/blind.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.blind.BlindVqaModel 2 | pretrained_model: t5-base -------------------------------------------------------------------------------- /config/model/importance_teacher.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.importance_teacher.ImportanceTeacherVqaModel 2 | pretrained_enc_dec: t5-small 3 | input_size: ${dataset.feature_dim} 4 | fragment_length: 128 5 | -------------------------------------------------------------------------------- /config/model/moment_localization_loss/attn_summed.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.moment_loc.SummedAttentionTransformerMomentLocLoss 2 | softmax_temperature: 0.1 3 | att_loss_type: 'lse' 4 | lse_alpha: 20 5 | -------------------------------------------------------------------------------- /config/model/simple_vqa.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - moment_localization_loss: attn_sample 3 | 4 | _target_: model.simple_vqa.SimpleVqaModel 5 | pretrained_enc_dec: t5-base 6 | input_size: ${dataset.feature_dim} 7 | #moment_localization_loss: null 
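The config files under `config/model/` are Hydra `_target_` nodes: each file names a Python class plus its constructor arguments, and interpolations such as `${dataset.feature_dim}` are resolved against the composed config. A minimal sketch of that mechanism (not taken from the repository; it mirrors the values of `config/model/blind.yaml` above and assumes the project root is on `PYTHONPATH` and the `t5-base` checkpoint is available in the transformers cache):
```
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Mirrors config/model/blind.yaml: _target_ is the dotted path to the class,
# every remaining key is passed to its constructor as a keyword argument.
cfg = OmegaConf.create({
    '_target_': 'model.blind.BlindVqaModel',
    'pretrained_model': 't5-base',
})
model = instantiate(cfg)  # imports model.blind and calls BlindVqaModel(pretrained_model='t5-base')
```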
-------------------------------------------------------------------------------- /config/model/bigbird.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.sparse_transformer.BigBirdVqaModel 2 | input_size: ${dataset.feature_dim} 3 | pretrained_bigbird: google/bigbird-pegasus-large-arxiv 4 | use_moment_localization_loss: False 5 | gradient_checkpointing: True -------------------------------------------------------------------------------- /config/model/moment_localization_loss/attn_sample.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.moment_loc.SamplingAttentionTransformerMomentLocLoss 2 | softmax_temperature: 0.1 3 | att_loss_type: 'lse' 4 | lse_alpha: 20 5 | num_negatives: 2 6 | use_hard_negatives: False 7 | drop_topk: 0 8 | negative_pool_size: 0 9 | num_hard: 2 -------------------------------------------------------------------------------- /config/model/longformer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - moment_localization_loss: attn_sample 3 | 4 | _target_: model.sparse_transformer.LongformerVqaModel 5 | pretrained_enc_dec: t5-base 6 | input_size: ${dataset.feature_dim} 7 | pretrained_longformer: allenai/longformer-base-4096 8 | #moment_localization_loss: null 9 | -------------------------------------------------------------------------------- /config/dataset/ego4d.yaml: -------------------------------------------------------------------------------- 1 | data_dir: datasets/ego4d 2 | use_final_test: False 3 | feature_type: slowfast8x8_r101_k400 4 | feature_dim: 2304 5 | 6 | drop_val_last: True 7 | 8 | tokenizer_name: t5-base 9 | # separate_question_tok_name: allenai/longformer-base-4096 10 | 11 | workers: 16 12 | train_bsz: 8 13 | test_bsz: 16 14 | -------------------------------------------------------------------------------- /config/model/stm.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.mann.StmEmqaModel 2 | pretrained_enc_dec: t5-base 3 | segmentation_method: avg 4 | input_size: ${dataset.feature_dim} 5 | segment_length: 32 6 | stm_input_size: 1024 7 | mem_hidden_size: 768 8 | stm_step: 2 9 | stm_num_slot: 8 10 | stm_mlp_size: 256 11 | stm_slot_size: 128 12 | stm_rel_size: 128 13 | stm_out_att_size: 128 -------------------------------------------------------------------------------- /config/model/lt_ct.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.lt_ct.CompressiveTransformerEmqaModel 2 | pretrained_enc_dec: t5-base 3 | input_dim: ${dataset.feature_dim} 4 | hidden_dim: 768 5 | num_layers: 6 6 | heads: 8 7 | block_length: 64 8 | mem_length: 64 9 | cmem_lengths: [ 16, 8, 4 ] 10 | compression_factors: [ 4, 2, 2 ] 11 | use_ltmem: True 12 | memory_layers: [ 1, 4, 5, 6 ] 13 | dropout: 0.1 14 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | #Set the default behavior, in case people don't have core.autocrlf set. 2 | * text=auto 3 | 4 | # Declare files that will always have LF line endings on checkout. 
5 | *.py text eol=lf 6 | *.txt text eol=lf 7 | *.xml text eol=lf 8 | *.gitattributes text eol=lf 9 | *.gitignore text eol=lf 10 | *.json text eol=lf 11 | *.jgram text eol=lf 12 | 13 | #Forced Binary files 14 | #For example *.png binary 15 | -------------------------------------------------------------------------------- /config/model/dnc.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.mann.DncEmqaModel 2 | pretrained_enc_dec: t5-base 3 | segmentation_method: avg 4 | input_size: ${dataset.feature_dim} 5 | segment_length: 32 6 | dnc_input_size: 1024 7 | rnn_hidden_size: 512 8 | num_dnc_layers: 2 9 | num_rnn_hidden_layers: 2 10 | num_mem_cells: 16 11 | mem_hidden_size: 768 12 | #moment_loc: 13 | # _target_: model.moment_loc.SeqMomentLocalizationLossModule 14 | # seq_hidden_dim: 1024 15 | # question_hidden_dim: 512 -------------------------------------------------------------------------------- /eval/util.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.reset() 6 | 7 | def reset(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def update(self, val, n=1): 14 | self.val = val 15 | self.sum += val * n 16 | self.count += n 17 | self.avg = self.sum / self.count 18 | -------------------------------------------------------------------------------- /experiment/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Run this within the project main directory (where the directories experiment, model, ... reside) 5 | 6 | if [[ -z $TRANSFORMERS_CACHE ]]; then 7 | echo "TRANSFORMERS_CACHE env var has to be set!" 8 | exit 1 9 | fi 10 | if [[ ! -d $TRANSFORMERS_CACHE ]]; then 11 | echo "TRANSFORMERS_CACHE directory should exist (to avoid unintended downloads). Current: '$TRANSFORMERS_CACHE'" 12 | exit 1 13 | fi 14 | 15 | python run.py "${@:1}" 16 | -------------------------------------------------------------------------------- /config/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: ego4d 3 | - model: rehearsal_mem 4 | - _self_ 5 | - override hydra/job_logging: none 6 | - override hydra/hydra_logging: none 7 | 8 | trainer: 9 | detect_anomaly: True 10 | #val_check_interval: 500 11 | max_epochs: 100 12 | accumulate_grad_batches: 4 13 | auto_resume: False 14 | gpus: 1 15 | log_every_n_steps: 4 16 | auto_lr_find: True 17 | enable_progress_bar: False 18 | monitor_variable: val_lm_loss 19 | monitor_mode: min 20 | 21 | optim: 22 | optimizer: 23 | _target_: torch.optim.Adam 24 | lr: 0.00001 25 | weight_decay: 0 26 | 27 | freeze: [ ] 28 | 29 | loss_weights: 30 | # lm_loss is reference 31 | recollection_loss: 1 32 | familiarity_loss: 0.5 33 | 34 | 35 | hydra: 36 | run: 37 | dir: . 
38 | output_subdir: null 39 | 40 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import hydra 4 | from omegaconf import DictConfig 5 | 6 | from data.datamodule import EmqaDataModule 7 | from hydra_compat import apply_argparse_defaults_to_hydra_config 8 | from lightning_util import tune_fit_test, add_common_trainer_util_args 9 | from model.lightning import EmqaLightningModule 10 | 11 | 12 | @hydra.main(config_path='config', config_name='base') 13 | def main(config: DictConfig): 14 | fake_parser = ArgumentParser() 15 | add_common_trainer_util_args(fake_parser, default_monitor_variable='val_total_loss') 16 | apply_argparse_defaults_to_hydra_config(config.trainer, fake_parser) 17 | 18 | data = EmqaDataModule(config.dataset) 19 | model = EmqaLightningModule(config.model, config.optim, data.tokenizer) 20 | 21 | tune_fit_test(config.trainer, model, data) 22 | 23 | 24 | if __name__ == '__main__': 25 | main() 26 | -------------------------------------------------------------------------------- /tools/aggregate_features_to_hdf5.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import h5py 5 | import torch 6 | from tqdm import tqdm 7 | 8 | 9 | def main(): 10 | features_dir = Path(sys.argv[1]) 11 | all_feature_files = list(features_dir.glob('*.pt')) 12 | 13 | with h5py.File('slowfast8x8_r101_k400.hdf5', 'a') as h5_file, tqdm(all_feature_files) as iterator: 14 | features_file: Path 15 | for features_file in iterator: 16 | try: 17 | clip_uid = features_file.name.replace('.pt', '') 18 | if clip_uid in h5_file: 19 | continue 20 | features = torch.load(str(features_file)) 21 | h5_file.create_dataset(clip_uid, data=features, compression="gzip") 22 | except BaseException as e: 23 | raise RuntimeError(f'Error during {features_file}: {e}') 24 | 25 | 26 | if __name__ == '__main__': 27 | main() 28 | -------------------------------------------------------------------------------- /hydra_compat.py: -------------------------------------------------------------------------------- 1 | from absl.flags.argparse_flags import ArgumentParser 2 | from omegaconf import DictConfig, open_dict 3 | 4 | 5 | def apply_argparse_defaults_to_hydra_config(config: DictConfig, parser: ArgumentParser, verbose=False): 6 | args = parser.parse_args([]) # Parser is not allowed to have required args, otherwise this will fail! 
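# parse_args([]) on an empty argv yields a Namespace holding only the parser's declared defaults;
# vars() turns that Namespace into a plain dict, and _apply_defaults below copies each default into
# the Hydra config only where the key is not already present, so explicit config values are never overwritten.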
7 | defaults = vars(args) 8 | 9 | def _apply_defaults(dest: DictConfig, source: dict, indentation=''): 10 | for k, v in source.items(): 11 | if k in dest and isinstance(v, dict): 12 | current_value = dest[k] 13 | if current_value is not None: 14 | assert isinstance(current_value, DictConfig) 15 | _apply_defaults(current_value, v, indentation + ' ') 16 | elif k not in dest: 17 | dest[k] = v 18 | if verbose: 19 | print(indentation, 'set default value for', k) 20 | 21 | with open_dict(config): 22 | _apply_defaults(config, defaults) 23 | -------------------------------------------------------------------------------- /config/model/rehearsal_mem.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.rehearsal.RehearsalMemoryEmqaModel 2 | pretrained_enc_dec: t5-base 3 | 4 | rehearsal_machine: 5 | _target_: model.rehearsal.RehearsalMemoryMachine 6 | pretrained_encoder: t5-base 7 | input_dim: ${dataset.feature_dim} 8 | mem_hidden_size: 768 9 | num_memory_slots: 16 10 | segment_length: 128 11 | slot_to_item_num_heads: 1 12 | use_independent_gru_per_mem_slot: False 13 | 14 | rehearsal_trainer: 15 | _target_: model.rehearsal.RehearsalTrainingModule 16 | input_size: ${dataset.feature_dim} 17 | mem_hidden_size: ${..rehearsal_machine.mem_hidden_size} 18 | num_samples: 4 19 | sample_length: 128 20 | positive_mask_ratio: 0.5 21 | negative_replacement_ratio: 0.5 22 | invert_teacher_sequence: False 23 | pretrained_decoder: null 24 | decoder_params: 25 | hidden_size: 256 26 | num_hidden_layers: 3 27 | num_attention_heads: 4 28 | intermediate_size: 512 29 | max_position_embeddings: 132 30 | vocab_size: 16 # Vocab never used (only inputs_embeds) 31 | sampling_teacher_weights_file: ../22_04_b5aa4428_0/importance_teacher_weights.pt 32 | -------------------------------------------------------------------------------- /model/blind.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict 2 | 3 | import torch 4 | from transformers import PreTrainedModel, AutoModelForSeq2SeqLM 5 | 6 | from model.base import EmqaBaseModel 7 | 8 | 9 | class BlindVqaModel(EmqaBaseModel): 10 | 11 | def __init__(self, 12 | pretrained_model: str 13 | ) -> None: 14 | super().__init__() 15 | self.transformer: PreTrainedModel = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model) 16 | assert self.transformer.config.is_encoder_decoder 17 | 18 | def teacher_forcing_forward(self, question_tokens, question_mask, video_features, video_mask, answer_tokens, 19 | answer_mask, batch_sample_ids, 20 | moment_localization_labels) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: 21 | output = self.transformer(input_ids=question_tokens, attention_mask=question_mask, 22 | labels=answer_tokens, decoder_attention_mask=answer_mask) 23 | return {'lm_loss': output.loss}, output.logits 24 | 25 | def autoregressive_forward(self, question_tokens, question_mask, video_features, video_mask) -> torch.Tensor: 26 | return self.transformer.generate(inputs=question_tokens, attention_mask=question_mask) 27 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | aiohttp==3.8.1 3 | aiosignal==1.2.0 4 | antlr4-python3-runtime==4.8 5 | async-timeout==4.0.2 6 | attrs==21.4.0 7 | cachetools==5.0.0 8 | certifi==2021.10.8 9 | charset-normalizer==2.0.12 10 | click==8.0.4 11 | colorama==0.4.4 12 | dnc==1.1.0 13 | 
filelock==3.6.0 14 | flann==1.6.13 15 | frozenlist==1.3.0 16 | fsspec==2022.2.0 17 | future==0.18.2 18 | google-auth==2.6.0 19 | google-auth-oauthlib==0.4.6 20 | grpcio==1.44.0 21 | h5py==3.6.0 22 | huggingface-hub==0.4.0 23 | hydra-core==1.1.1 24 | idna==3.3 25 | importlib-metadata==4.11.2 26 | joblib==1.1.0 27 | Markdown==3.3.6 28 | multidict==6.0.2 29 | nltk==3.7 30 | numpy==1.22.2 31 | oauthlib==3.2.0 32 | omegaconf==2.1.1 33 | packaging==21.3 34 | portalocker==2.4.0 35 | protobuf==3.19.4 36 | pyasn1==0.4.8 37 | pyasn1-modules==0.2.8 38 | pyDeprecate==0.3.1 39 | pyparsing==3.0.7 40 | pytorch-lightning==1.5.10 41 | PyYAML==6.0 42 | regex==2022.3.2 43 | requests==2.27.1 44 | requests-oauthlib==1.3.1 45 | rouge-score==0.0.4 46 | rsa==4.8 47 | sacrebleu==2.0.0 48 | sacremoses==0.0.47 49 | six==1.16.0 50 | tabulate==0.8.9 51 | tensorboard==2.8.0 52 | tensorboard-data-server==0.6.1 53 | tensorboard-plugin-wit==1.8.1 54 | tokenizers==0.11.6 55 | torch==1.10.2 56 | torchmetrics==0.7.2 57 | tqdm==4.63.0 58 | transformers==4.16.2 59 | typing_extensions==4.1.1 60 | urllib3==1.26.8 61 | Werkzeug==2.0.3 62 | yarl==1.7.2 63 | zipp==3.7.0 64 | -------------------------------------------------------------------------------- /tools/split_existing_data_aligned.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import sys 4 | from pathlib import Path 5 | 6 | input_data_file = Path(sys.argv[1]) 7 | clip_to_video_file = Path(sys.argv[2]) 8 | input_samples = json.loads(input_data_file.read_text()) 9 | clip_to_video = json.loads(clip_to_video_file.read_text()) 10 | 11 | all_clips = {s['video_id'] for s in input_samples} 12 | all_videos = {clip_to_video[clip] for clip in all_clips} 13 | 14 | frac_val = 0.5 15 | assert frac_val < 1 16 | 17 | num_val = round(len(all_videos) * frac_val) 18 | num_test = len(all_videos) - num_val 19 | assert num_test >= 1 20 | 21 | val_videos = set(random.sample(sorted(all_videos), num_val)) 22 | test_videos = all_videos - val_videos 23 | 24 | val_clips = {clip for clip in all_clips if clip_to_video[clip] in val_videos} 25 | test_clips = {clip for clip in all_clips if clip_to_video[clip] in test_videos} 26 | 27 | val_samples = [s for s in input_samples if s['video_id'] in val_clips] 28 | test_samples = [s for s in input_samples if s['video_id'] in test_clips] 29 | 30 | print(f'Splits: videos/clips/samples ') 31 | print(f'Val : {len(val_videos)}/{len(val_clips)}/{len(val_samples)}') 32 | print(f'Test : {len(test_videos)}/{len(test_clips)}/{len(test_samples)}') 33 | 34 | Path('pure_emqa_val.json').write_text(json.dumps(val_samples)) 35 | Path('pure_emqa_test.json').write_text(json.dumps(test_samples)) 36 | 37 | Path('split_videos.json').write_text(json.dumps({'val': sorted(val_videos), 'test': sorted(test_videos)})) 38 | Path('split_clips.json').write_text(json.dumps({'val': sorted(val_clips), 'test': sorted(test_clips)})) 39 | -------------------------------------------------------------------------------- /model/simple_vqa.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | from torch import nn 5 | from transformers.modeling_outputs import Seq2SeqLMOutput 6 | 7 | from model.base import MemoryAugmentedTransformerEmqaModel 8 | from model.moment_loc import TransformerMomentLocalizationLossModule 9 | 10 | 11 | # noinspection PyAbstractClass 12 | class SimpleVqaModel(MemoryAugmentedTransformerEmqaModel): 13 | # Actually this is 
not really a MemoryAugmentedTransformerEmqaModel since it uses full attention over the input 14 | 15 | def __init__(self, input_size: int, 16 | pretrained_enc_dec: str, 17 | moment_localization_loss: TransformerMomentLocalizationLossModule = None) -> None: 18 | super().__init__(pretrained_enc_dec) 19 | self.moment_localization_loss = moment_localization_loss 20 | self.transformer.get_decoder().config.output_attentions = True 21 | 22 | hidden = self.transformer.get_input_embeddings().embedding_dim 23 | if input_size != hidden: 24 | self.transform_visual = nn.Linear(input_size, hidden, bias=False) 25 | else: 26 | self.transform_visual = nn.Identity() 27 | 28 | def forward_encoders(self, question_tokens, question_mask, video_features, video_mask, moment_localization_labels): 29 | visual_seq = self.transform_visual(video_features) 30 | context, context_mask = self._prepare_context(visual_seq, video_mask, question_tokens, question_mask) 31 | return context, context_mask, visual_seq, video_mask, {} 32 | 33 | def calc_additional_loss(self, question_tokens, question_mask, video_features, video_mask, answer_tokens, 34 | answer_mask, batch_sample_ids, context, context_mask, final_memory, mem_mask, 35 | transformer_output: Seq2SeqLMOutput, 36 | moment_localization_labels) -> Dict[str, torch.Tensor]: 37 | if self.moment_localization_loss: 38 | return { 39 | 'moment_localization': self.moment_localization_loss( 40 | question_tokens, transformer_output, moment_localization_labels, video_mask) 41 | } 42 | else: 43 | return {} 44 | -------------------------------------------------------------------------------- /tools/create_pure_videoqa_json.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from typing import Dict 5 | 6 | 7 | def convert(nlq_file: Path, answers: Dict[str, str]): 8 | """ 9 | Reads NLQ JSON file and answer map, and creates a single dict with a list of 10 | {"video_id", "sample_id", "question", "answer", "moment_start_frame", "moment_end_frame"} objects. 
11 | 12 | :param nlq_file: NLQ JSON file 13 | :param answers: Answer map 14 | """ 15 | result = [] 16 | 17 | annotations = json.loads(nlq_file.read_text()) 18 | for video in annotations['videos']: 19 | for clip in video['clips']: 20 | for annotation in clip['annotations']: 21 | for i, query in enumerate(annotation['language_queries']): 22 | if 'query' not in query or not query['query']: 23 | continue 24 | question = query['query'].replace('\n', '').replace(',', '').strip() 25 | video_id = clip['clip_uid'] 26 | sample_id = f'{annotation["annotation_uid"]}_{i}' 27 | if sample_id not in answers: 28 | continue 29 | answer = answers[sample_id].replace('\n', '').replace(',', '').strip() 30 | fps = 30 # fps = 30 is known for canonical Ego4D clips 31 | start_frame = query['clip_start_sec'] * fps 32 | end_frame = query['clip_end_sec'] * fps 33 | 34 | result.append({ 35 | 'video_id': video_id, 36 | 'sample_id': sample_id, 37 | 'answer': answer, 38 | 'question': question, 39 | 'moment_start_frame': start_frame, 40 | 'moment_end_frame': end_frame 41 | }) 42 | 43 | return result 44 | 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--ego4d', required=True, type=str, 49 | help='Directory where you placed the Ego4D download.') 50 | parser.add_argument('--qaego4d', required=True, type=str, 51 | help='Path to QaEgo4D answers.json file') 52 | args = parser.parse_args() 53 | 54 | ego4d_dir = Path(args.ego4d) 55 | ego4d_annotations_dir = ego4d_dir / 'v1' / 'annotations' 56 | qaego4d_file = Path(args.qaego4d) 57 | assert ego4d_dir.is_dir() 58 | assert ego4d_annotations_dir.is_dir() 59 | assert qaego4d_file.is_file() 60 | 61 | qaego4d_data = json.loads(qaego4d_file.read_text()) 62 | nlq_train, nlq_val = [ego4d_annotations_dir / f'nlq_{split}.json' for split in ('train', 'val')] 63 | 64 | train = convert(nlq_train, qaego4d_data['train']) 65 | val = convert(nlq_val, qaego4d_data['val']) 66 | test = convert(nlq_val, qaego4d_data['test']) 67 | 68 | output_dir = qaego4d_file.parent 69 | (output_dir / 'annotations.train.json').write_text(json.dumps(train)) 70 | (output_dir / 'annotations.val.json').write_text(json.dumps(val)) 71 | (output_dir / 'annotations.test.json').write_text(json.dumps(test)) 72 | 73 | 74 | if __name__ == '__main__': 75 | main() 76 | -------------------------------------------------------------------------------- /data/datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from transformers import PreTrainedTokenizerBase as Tokenizer, AutoTokenizer 6 | 7 | from data.emqa_dataset import EmqaDataset 8 | 9 | 10 | # noinspection PyAbstractClass 11 | class EmqaDataModule(LightningDataModule): 12 | tokenizer: Tokenizer 13 | 14 | def __init__(self, config, drop_last=True): 15 | """ 16 | 17 | :param config: Needs {tokenizer_name: str, 18 | separate_question_tok_name: Optional[str], 19 | drop_val_last: Optional[bool] = False 20 | workers: int, 21 | use_final_test: bool, 22 | train_bsz: int, 23 | test_bsz: int 24 | } + requirements from EmqaDataset.create_from_cfg 25 | """ 26 | super().__init__() 27 | self.config = config 28 | self.drop_last = drop_last 29 | self.drop_val_last = getattr(config, 'drop_val_last', False) 30 | self.train_dataset, self.val_dataset, self.test_dataset = None, None, None 31 | self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) 32 | 
self.tokenizer.pad_token = self.tokenizer.eos_token # Set this per convenience for GPT-2 33 | if getattr(config, 'separate_question_tok_name', None): 34 | self.question_tok = AutoTokenizer.from_pretrained(config.separate_question_tok_name) 35 | else: 36 | self.question_tok = None 37 | 38 | def setup(self, stage: Optional[str] = None) -> None: 39 | super().setup(stage) 40 | self.train_dataset = self._create_data('train') 41 | self.val_dataset = self._create_data('val') 42 | if self.config.use_final_test: 43 | self.test_dataset = self._create_data('test') 44 | else: 45 | self.test_dataset = self.val_dataset or self._create_data('val') 46 | 47 | def _create_data(self, split): 48 | return EmqaDataset.create_from_cfg(self.config, split, self.tokenizer, self.question_tok) 49 | 50 | def common_loader_args(self, dataset): 51 | return dict(num_workers=self.config.workers, 52 | collate_fn=dataset.collate_emv_samples, 53 | pin_memory=True) 54 | 55 | def eval_loader_args(self, dataset): 56 | return dict(**self.common_loader_args(dataset), 57 | shuffle=False, 58 | batch_size=self.config.test_bsz) 59 | 60 | def train_dataloader(self): 61 | assert self.train_dataset 62 | return DataLoader(self.train_dataset, 63 | batch_size=self.config.train_bsz, 64 | shuffle=True, 65 | drop_last=self.drop_last, 66 | **self.common_loader_args(self.train_dataset)) 67 | 68 | def val_dataloader(self): 69 | assert self.val_dataset 70 | return DataLoader(self.val_dataset, drop_last=self.drop_val_last, 71 | **self.eval_loader_args(self.val_dataset)) 72 | 73 | def test_dataloader(self): 74 | assert self.test_dataset 75 | return DataLoader(self.test_dataset, **self.eval_loader_args(self.test_dataset)) 76 | -------------------------------------------------------------------------------- /model/lt_ct.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import Tuple, Dict 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from .base import MemoryAugmentedTransformerEmqaModel 8 | from .external.compressive_transformer import CompressiveTransformer 9 | 10 | 11 | class CompressiveTransformerEmqaModel(MemoryAugmentedTransformerEmqaModel): 12 | 13 | def __init__( 14 | self, 15 | pretrained_enc_dec: str, 16 | input_dim: int, 17 | hidden_dim: int = 512, 18 | num_layers: int = 10, 19 | heads: int = 8, 20 | block_length: int = 16, 21 | mem_length: int = 32, 22 | cmem_lengths=None, 23 | compression_factors=4, 24 | use_ltmem: bool = True, 25 | memory_layers=None, 26 | dropout: float = 0.1 27 | ): 28 | super().__init__(pretrained_enc_dec) 29 | self.mem_transformer = CompressiveTransformer( 30 | num_tokens=1, # Embedding is skipped (video features input) 31 | # However, we set emb_dim to automatically use CompressiveTransformer.to_model_dim 32 | emb_dim=input_dim, 33 | dim=hidden_dim, depth=num_layers, heads=heads, 34 | seq_len=block_length, mem_len=mem_length, 35 | cmem_lengths=cmem_lengths, 36 | cmem_ratios=compression_factors, 37 | use_ltmem=use_ltmem, memory_layers=memory_layers, 38 | attn_layer_dropout=dropout, ff_dropout=dropout, attn_dropout=dropout, 39 | reconstruction_attn_dropout=dropout, 40 | gru_gated_residual=False, mogrify_gru=False, ff_glu=False, one_kv_head=False 41 | ) 42 | self.mem_transformer.token_emb = nn.Identity() 43 | self.mem_transformer.to_logits = nn.Identity() 44 | 45 | def forward_memory(self, video_features, video_mask, 46 | moment_localization_labels, 47 | question_encoding) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 48 | bsz, 
seq_len = video_features.shape[0], video_features.shape[1] 49 | 50 | block_length = self.mem_transformer.seq_len 51 | num_blocks = seq_len // block_length + (0 if seq_len % block_length == 0 else 1) 52 | 53 | x = video_features 54 | memories = (None, None, None) 55 | aux_loss_sum = torch.zeros(1, device=x.device) 56 | for i in range(num_blocks): 57 | current_slice = slice(i * block_length, (i + 1) * block_length) 58 | input_block = x[:, current_slice, :] 59 | out, memories, aux_loss = self.mem_transformer(input_block, 60 | memories=memories, 61 | mask=video_mask[:, current_slice]) 62 | aux_loss_sum += aux_loss 63 | 64 | if num_blocks > 0: 65 | mem, cmems, ltmem = memories 66 | # memories is a tuple with three items. mem and ltmem are tensors, cmems a list of tensors, all of size 67 | # (num_memory_layers x batch x memory_seq_length x hidden) 68 | # ltmem always has memory_seq_length of either 0 or 1 69 | # out is (batch x sequence x hidden) from transformers last layer. 70 | # concatenate at sequence level (dim=1), but treat each mem layer as it's own vector 71 | # Also, treat all compression layers the same and simply concatenate 72 | memories = mem, *cmems, ltmem 73 | # noinspection PyUnboundLocalVariable 74 | complete_em = torch.cat([out] + [layer_mem for layer_mem in chain(*memories)], dim=1) # B x S x H 75 | else: 76 | complete_em = torch.empty(bsz, 0, self.output_size, 77 | device=x.device, dtype=x.dtype) 78 | 79 | return complete_em, {'ct_aux_loss': aux_loss_sum} 80 | -------------------------------------------------------------------------------- /tools/extract_ego4d_clip_features.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | from argparse import ArgumentParser 5 | from dataclasses import dataclass 6 | from functools import partial 7 | from multiprocessing import Pool, Manager, JoinableQueue 8 | from pathlib import Path 9 | from typing import List 10 | 11 | import torch 12 | from tqdm import tqdm 13 | 14 | 15 | def existing_path(p: str): 16 | p = Path(p) 17 | assert p.exists(), p 18 | return p 19 | 20 | 21 | @dataclass 22 | class WorkItem: 23 | video_uid: str 24 | clip_uid: str 25 | clip_start_frame: int 26 | clip_end_frame: int 27 | 28 | def output_file(self, output_dir: Path): 29 | return output_dir / f'{self.clip_uid}.pt' 30 | 31 | def do_work(self, video_features_dir: Path, output_dir: Path, feature_window_stride: int, 32 | progress_queue: JoinableQueue): 33 | video_features_file = video_features_dir / f'{self.video_uid}.pt' 34 | output_file = self.output_file(output_dir) 35 | start_feature = self.clip_start_frame // feature_window_stride 36 | end_feature = math.ceil(self.clip_end_frame / feature_window_stride) 37 | video_features = torch.load(str(video_features_file)) 38 | clip_features = video_features[start_feature:end_feature] 39 | torch.save(clip_features, str(output_file)) 40 | progress_queue.put_nowait(1) 41 | 42 | 43 | def _extract_work_items(annotations) -> List[WorkItem]: 44 | work_items = [] 45 | for video in annotations['videos']: 46 | for clip in video['clips']: 47 | clip_uid = clip['clip_uid'] 48 | # Wanted to use video_start/end_frame, but there seems to be a bug with metadata in Ego4D data so 49 | # that length of clips would be zero. Thus, calc frames from seconds. 
50 | # 30 fps is safe to assume for canonical videos 51 | start = int(clip['video_start_sec'] * 30) 52 | end = int(clip['video_end_sec'] * 30) 53 | assert start != end, f'{start}, {end}, {clip_uid}' 54 | work_items.append(WorkItem( 55 | video['video_uid'], clip_uid, 56 | start, end 57 | )) 58 | return work_items 59 | 60 | 61 | def main(nlq_file: Path, video_features_dir: Path, output_dir: Path, 62 | feature_window_stride=16, num_workers=16): 63 | annotations = json.loads(nlq_file.read_text()) 64 | all_work_items = _extract_work_items(annotations) 65 | all_work_items = [w for w in all_work_items if not w.output_file(output_dir).is_file()] 66 | print(f'Will extract {len(all_work_items)} clip features...') 67 | 68 | with Pool(num_workers) as pool, Manager() as manager: 69 | queue = manager.Queue() 70 | pool.map_async(partial(WorkItem.do_work, 71 | video_features_dir=video_features_dir, 72 | output_dir=output_dir, 73 | feature_window_stride=feature_window_stride, 74 | progress_queue=queue), 75 | all_work_items) 76 | with tqdm(total=len(all_work_items)) as pbar: 77 | while pbar.n < len(all_work_items): 78 | pbar.update(queue.get(block=True)) 79 | 80 | 81 | def cli_main(): 82 | parser = ArgumentParser() 83 | parser.add_argument('--annotation_file', type=existing_path, required=True, 84 | help='Ego4D Annotation JSON file containing annotations for which to extract clip features. ' 85 | 'Should contain "videos" array, where each item has a "clips" array ' 86 | '(e.g. NLQ annotations).') 87 | parser.add_argument('--video_features_dir', type=existing_path, required=True, 88 | help='Directory where to find pre-extracted Ego4D video features.') 89 | parser.add_argument('--output_dir', type=existing_path, required=True, 90 | help='Directory where to place output files. They are named after the clip_uid.') 91 | parser.add_argument('--feature_window_stride', type=int, default=16, 92 | help='Stride of window used to produce the features in video_features_dir') 93 | parser.add_argument('--num_workers', type=int, default=os.cpu_count(), 94 | help='Number of parallel workers') 95 | 96 | args = parser.parse_args() 97 | main(args.annotation_file, args.video_features_dir, args.output_dir, 98 | args.feature_window_stride, args.num_workers) 99 | 100 | 101 | if __name__ == '__main__': 102 | cli_main() 103 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QaEgo4D — Episodic-Memory-Based Question Answering on Egocentric Videos 2 | 3 | This repository contains the code to reproduce the results of the paper "Where did I leave my keys? — 4 | Episodic-Memory-Based Question Answering on Egocentric Videos". See our 5 | [paper](https://openaccess.thecvf.com/content/CVPR2022W/Ego4D-EPIC/papers/Barmann_Where_Did_I_Leave_My_Keys_-_Episodic-Memory-Based_Question_Answering_CVPRW_2022_paper.pdf) 6 | for more details. 7 | 8 | ## Abstract 9 | 10 | Humans have a remarkable ability to organize, compress and retrieve episodic memories throughout their daily life. 11 | Current AI systems, however, lack comparable capabilities as they are mostly constrained to an analysis with access to 12 | the raw input sequence, assuming an unlimited amount of data storage which is not feasible in realistic deployment 13 | scenarios. 
For instance, existing Video Question Answering (VideoQA) models typically reason over the video while 14 | already being aware of the question, thus requiring the complete video to be stored in case the question is not known in 15 | advance. 16 | 17 | In this paper, we address this challenge with three main contributions: 18 | First, we propose the Episodic Memory Question Answering (EMQA) task as a specialization of VideoQA. Specifically, EMQA 19 | models are constrained to keep only a constant-sized representation of the video input, thus automatically limiting the 20 | computation requirements at query time. Second, we introduce a new egocentric VideoQA dataset 21 | called QaEgo4D. It is by far the largest egocentric VideoQA dataset, and 22 | its video length is unprecedented in VideoQA datasets in general. Third, we present extensive experiments on the new 23 | dataset, comparing various baseline models in both the VideoQA and the EMQA setting. To facilitate future 24 | research on egocentric VideoQA as well as episodic memory representation and retrieval, we publish our code and dataset. 25 | 26 | ## Using the dataset 27 | 28 | To use the QaEgo4D dataset introduced in our paper, please follow these 29 | steps: 30 | 31 | 1. QaEgo4D builds on the Ego4D v1 videos and annotations. If you do not have 32 | access to Ego4D already, you should follow the steps at the [Ego4D website](https://ego4d-data.org/docs/start-here/) 33 | 2. To get access to QaEgo4D, please fill out 34 | this [Google form](https://forms.gle/Gxs93wwC5YYJtjqh8). You will need to sign a license agreement, but there are no 35 | fees if you use the data for non-commercial research purposes. 36 | 3. Download the Ego4D annotations and NLQ clips if you have not done so already. See 37 | the [Ego4D website](https://ego4d-data.org/docs/start-here/) 38 | 4. After you have access to both Ego4D and QaEgo4D, you can generate 39 | self-contained VideoQA annotation files 40 | using `python3 tools/create_pure_videoqa_json.py --ego4d /path/to/ego4d --qaego4d /path/to/qaego4d/answers.json`. 41 | `/path/to/ego4d` is the directory where you placed the Ego4D download, containing 42 | the `v1/annotations/nlq_{train,val}.json` files. This produces `/path/to/qaego4d/annotations.{train,val,test}.json`. 43 | 44 | The `annotations.*.json` files are JSON arrays, where each object has the following structure: 45 | ``` 46 | { 47 | "video_id": "abcdef00-0000-0000-0000-123456789abc", 48 | "sample_id": "12345678-1234-1234-1234-123456789abc_3", 49 | "question": "Where did I leave my keys?", 50 | "answer": "on the table", 51 | "moment_start_frame": 42, 52 | "moment_end_frame": 53 53 | } 54 | ``` 55 | 56 | ## Code 57 | 58 | In order to reproduce the experiments, prepare your workspace: 59 | 60 | 1. Follow the instructions above to get the dataset and features. 61 | 2. Create a conda / python virtual environment (Python 3.9.7) 62 | 3. Install the requirements in `requirements.txt` 63 | 4. Prepare the features: 64 | 1. Download the pre-extracted [Ego4D features](https://ego4d-data.org/docs/data/features/) if you have not done so 65 | already. 66 | 2. Ego4D features are provided for each canonical video, while the NLQ task and thus also VideoQA works on the 67 | canonical clips.
To extract features for each clip, 68 | use `python tools/extract_ego4d_clip_features.py --annotation_file /path/to/ego4d/v1/annotations/nlq_train.json --video_features_dir /path/to/ego4d/v1/slowfast8x8_r101_k400 --output_dir /choose/your/clip_feature_dir` 69 | and do the same again with `nlq_val.json` 70 | 3. Aggregate the features into a single file 71 | using `python tools/aggregate_features_to_hdf5.py /choose/your/clip_feature_dir`. This 72 | produces `slowfast8x8_r101_k400.hdf5` in the current working directory. 73 | 5. Place or link the QaEgo4D data (`annotations.*.json` 74 | and `slowfast8x8_r101_k400.hdf5`) into `datasets/ego4d`. 75 | 76 | To run an experiment, use `bash experiment/run.sh`. All configuration files can be found in the `config` dir. 77 | 78 | 79 | ## Cite 80 | ``` 81 | @InProceedings{Baermann_2022_CVPR, 82 | author = {B\"armann, Leonard and Waibel, Alex}, 83 | title = {Where Did I Leave My Keys? - Episodic-Memory-Based Question Answering on Egocentric Videos}, 84 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 85 | month = {June}, 86 | year = {2022}, 87 | pages = {1560-1568} 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /eval/eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | from collections import defaultdict 4 | from pathlib import Path 5 | from pprint import pprint 6 | from typing import List, Dict, Any 7 | 8 | from nltk.translate.meteor_score import meteor_score 9 | from rouge_score.rouge_scorer import RougeScorer 10 | from rouge_score.tokenize import tokenize 11 | from sacrebleu.metrics import BLEU, BLEUScore 12 | 13 | from .util import AverageMeter 14 | 15 | 16 | # Check whether to use 17 | # - https://github.com/Maluuba/nlg-eval 18 | # - https://github.com/hwanheelee1993/KPQA 19 | def calc_metrics(predictions: List[str], gold_annotations: List[List[str]]) -> Dict[str, Any]: 20 | """ 21 | Calculate metrics. 22 | 23 | Parameters 24 | ---------- 25 | predictions : list[str] 26 | The list of predictions 27 | gold_annotations : list[list[str]] 28 | A list with the same length as predictions. 29 | Each element is a list of possible target candidates for the corresponding prediction. 30 | All elements should have the same length. 
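Example (illustrative; exact metric values depend on the installed scorer versions):
    calc_metrics(['on the table'], [['on the table', 'on a table']])
returns a flat dict with keys such as 'plain_acc', 'BLEU', 'ROUGE' and 'METEOR'.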
31 | """ 32 | if len(predictions) != len(gold_annotations): 33 | raise ValueError(f'{len(predictions)} != {len(gold_annotations)}') 34 | ref_count = len(gold_annotations[0]) 35 | if any(len(refs) != ref_count for refs in gold_annotations): 36 | raise ValueError(f'All refs should have the same length {ref_count}!') 37 | 38 | acc = _calc_accuracy(predictions, gold_annotations) 39 | bleu = _calc_bleu(predictions, gold_annotations) 40 | rouge = _calc_rouge(predictions, gold_annotations) 41 | meteor = _calc_meteor(predictions, gold_annotations) 42 | 43 | return { 44 | 'plain_acc': acc, 45 | **bleu, 46 | 'ROUGE': rouge['rougeL']['f'], 47 | **_flatten_dict(rouge, prefix='ROUGE.'), 48 | 'METEOR': meteor 49 | } 50 | 51 | 52 | def _calc_accuracy(predictions, gold_annotations): 53 | correct = 0 54 | for pred, possible_refs in zip(predictions, gold_annotations): 55 | if any(ref == pred for ref in possible_refs): 56 | correct += 1 57 | total = len(predictions) 58 | return correct / total 59 | 60 | 61 | def _calc_meteor(predictions, gold_annotations): 62 | score = AverageMeter() 63 | for pred, possible_refs in zip(predictions, gold_annotations): 64 | pred = tokenize(pred, None) 65 | # https://github.com/cmu-mtlab/meteor/blob/master/src/edu/cmu/meteor/util/Normalizer.java 66 | possible_refs = [tokenize(x, None) for x in possible_refs] 67 | score.update(meteor_score(possible_refs, pred)) 68 | return score.avg 69 | 70 | 71 | def _calc_rouge(predictions, gold_annotations) -> Dict[str, Dict[str, float]]: 72 | rouge_scorer = RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=False) 73 | rouge = defaultdict(lambda: defaultdict(AverageMeter)) 74 | for pred, possible_refs in zip(predictions, gold_annotations): 75 | sample_result = {} 76 | for ref in possible_refs: 77 | single_ref_result = rouge_scorer.score(ref, pred) 78 | for k, scores in single_ref_result.items(): 79 | existing_result_dict = sample_result.setdefault(k, {}) 80 | if existing_result_dict.get('f', -1) < scores.fmeasure: 81 | existing_result_dict.update(f=scores.fmeasure, p=scores.precision, r=scores.recall) 82 | for k, best_scores in sample_result.items(): 83 | rouge[k]['p'].update(best_scores['p']) 84 | rouge[k]['r'].update(best_scores['r']) 85 | rouge[k]['f'].update(best_scores['f']) 86 | return { 87 | rouge_type: { 88 | measure: score.avg 89 | for measure, score in results.items() 90 | } for rouge_type, results in rouge.items() 91 | } 92 | 93 | 94 | def _calc_bleu(predictions, gold_annotations) -> Dict[str, float]: 95 | refs_transposed = [ 96 | [refs[i] for refs in gold_annotations] 97 | for i in range(len(gold_annotations[0])) 98 | ] 99 | bleu: BLEUScore = BLEU().corpus_score(predictions, refs_transposed) 100 | return { 101 | 'BLEU': bleu.score, 102 | 'BLEU.bp': bleu.bp, 103 | 'BLEU.ratio': bleu.ratio, 104 | 'BLEU.hyp_len': float(bleu.sys_len), 105 | 'BLEU.ref_len': float(bleu.ref_len), 106 | } 107 | 108 | 109 | def _flatten_dict(d, prefix=''): 110 | result = {} 111 | for k, v in d.items(): 112 | my_key = prefix + k 113 | if isinstance(v, dict): 114 | result.update(_flatten_dict(v, prefix=my_key + '.')) 115 | else: 116 | result[my_key] = v 117 | return result 118 | 119 | 120 | def main(): 121 | parser = ArgumentParser('Eval output file') 122 | parser.add_argument('--gold_answers', type=str, required=True, 123 | help='Path to answers.json, containing mapping from sample_id to answer') 124 | parser.add_argument('eval_file', type=str, 125 | help='JSON File to evaluate. 
Should contain mapping from sample_id ' 126 | 'to hypothesis or array of hypotheses') 127 | args = parser.parse_args() 128 | 129 | gold_answers = json.loads(Path(args.gold_answers).read_text()) 130 | hypotheses = json.loads(Path(args.eval_file).read_text()) 131 | if isinstance(next(iter(hypotheses.values())), list): 132 | hypotheses = {k: v[0] for k, v in hypotheses.items()} 133 | assert len(hypotheses.keys() - gold_answers.keys()) == 0, 'No gold answer for some hypotheses' 134 | 135 | gold_and_hypo = [(gold_answers[k], hypotheses[k]) for k in hypotheses.keys()] 136 | hypo_list = [h for g, h in gold_and_hypo] 137 | gold_list = [[g] for g, h in gold_and_hypo] 138 | metrics = calc_metrics(hypo_list, gold_list) 139 | 140 | pprint(metrics) 141 | 142 | 143 | if __name__ == '__main__': 144 | main() 145 | -------------------------------------------------------------------------------- /data/emqa_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | from abc import ABC 4 | from pathlib import Path 5 | from typing import Dict, Union, List 6 | from typing import Literal 7 | 8 | import h5py 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.nn.utils.rnn import pad_sequence 12 | from torch.utils.data import Dataset 13 | from transformers import PreTrainedTokenizerBase as Tokenizer 14 | 15 | EmqaAnnotationFileDictKey = Literal[ 16 | 'sample_id', 17 | 'video_id', 18 | 'question', 19 | 'answer', 20 | 'moment_start_frame', # optional 21 | 'moment_end_frame', # optional 22 | ] 23 | EmqaSampleDictKey = Literal[ 24 | 'video_id', 25 | 'sample_id', 26 | 'question_text', 27 | 'answer_text', 28 | 'video_features', 29 | 'moment_label' 30 | ] 31 | EmqaBatchDictKey = Literal[ 32 | 'batch_video_ids', # List[int] 33 | 'batch_sample_ids', # List[int] 34 | 'batch_question_tokens', # int, B x L 35 | 'batch_question_mask', # bool, B x L 36 | 'batch_video_features', # float, B x N x H 37 | 'batch_video_mask', # bool, B x N 38 | 'batch_answer_texts', # List[str] 39 | 'batch_answer_tokens', # int, B x T 40 | 'batch_answer_mask', # bool, B x T 41 | 'batch_moment_localization_labels' # float, B x N, optional 42 | ] 43 | EmqaBatch = Dict[EmqaBatchDictKey, Union[torch.Tensor, List[int], List[str]]] 44 | EmqaSample = Dict[EmqaSampleDictKey, Union[torch.Tensor, int, str]] 45 | DatasetSplitName = Literal["train", "val", "test"] 46 | 47 | 48 | class EmqaDataset(Dataset, ABC): 49 | tokenizer: Tokenizer 50 | split: DatasetSplitName 51 | 52 | def __init__(self, video_features_file: Path, 53 | tokenizer: Tokenizer, 54 | split: DatasetSplitName, 55 | annotations: List[Dict[EmqaAnnotationFileDictKey, Union[int, str]]], 56 | normalize_video_features=False, 57 | frames_per_feature=16, 58 | separate_question_tokenizer: Tokenizer = None): 59 | super().__init__() 60 | self.video_features_file = video_features_file 61 | self.tokenizer = tokenizer 62 | self.separate_question_tokenizer = separate_question_tokenizer or tokenizer 63 | self.split = split 64 | self.annotations = annotations 65 | self.normalize_video_features = normalize_video_features 66 | self.frames_per_feature = frames_per_feature 67 | 68 | def __len__(self) -> int: 69 | return len(self.annotations) 70 | 71 | def __getitem__(self, index) -> EmqaSample: 72 | video_id = self.annotations[index]['video_id'] 73 | sample_id = self.annotations[index]['sample_id'] 74 | question = self.annotations[index]['question'] 75 | answer = self.annotations[index]['answer'] 76 | gt_start_frame = 
self.annotations[index].get('moment_start_frame') 77 | gt_end_frame = self.annotations[index].get('moment_end_frame') 78 | video_features = self._get_video_features(video_id) 79 | 80 | sample: EmqaSample = { 81 | 'video_id': video_id, 82 | 'sample_id': sample_id, 83 | 'video_features': video_features, 84 | 'question_text': question, 85 | 'answer_text': answer 86 | } 87 | if gt_start_frame is not None and gt_end_frame is not None: 88 | start = gt_start_frame // self.frames_per_feature 89 | # ensure at least one target frame even if gt_start == gt_end 90 | end = math.ceil(gt_end_frame / self.frames_per_feature) 91 | if start == end: 92 | end += 1 93 | sample['moment_label'] = torch.tensor([start, end], dtype=torch.int) 94 | return sample 95 | 96 | def _get_video_features(self, video_id): 97 | with h5py.File(self.video_features_file, 'r') as hdf5_file: 98 | features = torch.from_numpy(hdf5_file[video_id][:]).float() 99 | if self.normalize_video_features: 100 | features = F.normalize(features, dim=1) 101 | return features 102 | 103 | def collate_emv_samples(self, batch: List[EmqaSample]) -> EmqaBatch: 104 | video_ids = [b['video_id'] for b in batch] 105 | sample_ids = [b['sample_id'] for b in batch] 106 | video_features = [b['video_features'] for b in batch] 107 | questions = [b['question_text'] for b in batch] 108 | answers = [b['answer_text'] for b in batch] 109 | 110 | answers_with_eos = [a + self.tokenizer.eos_token for a in answers] 111 | tok_args = dict(padding=True, return_tensors='pt', add_special_tokens=False) 112 | question_tok = self.separate_question_tokenizer(questions, **tok_args) 113 | answers_tok = self.tokenizer(answers_with_eos, **tok_args) 114 | 115 | video_features_padded = pad_sequence(video_features, batch_first=True) 116 | video_mask = pad_sequence([torch.ones(len(v)) for v in video_features], batch_first=True).bool() 117 | 118 | result: EmqaBatch = { 119 | 'batch_video_ids': video_ids, 120 | 'batch_sample_ids': sample_ids, 121 | 'batch_question_tokens': question_tok['input_ids'], 122 | 'batch_question_mask': question_tok['attention_mask'], 123 | 'batch_answer_texts': answers, 124 | 'batch_answer_tokens': answers_tok['input_ids'], 125 | 'batch_answer_mask': answers_tok['attention_mask'], 126 | 'batch_video_features': video_features_padded, 127 | 'batch_video_mask': video_mask 128 | } 129 | if 'moment_label' in batch[0]: 130 | moment_labels = torch.zeros(len(batch), video_features_padded.shape[1]) 131 | for i, b in enumerate(batch): 132 | gt_start, gt_end = b['moment_label'] 133 | moment_labels[i, gt_start:gt_end] = 1 134 | # add smoothing before/after gt_start & end ? 135 | result['batch_moment_localization_labels'] = moment_labels 136 | 137 | return result 138 | 139 | @classmethod 140 | def create_from_cfg(cls, cfg, split: DatasetSplitName, tokenizer: Tokenizer, 141 | separate_question_tok: Tokenizer = None): 142 | """ 143 | Create EmqaDataset from cfg. 
144 | 145 | :param cfg: Needs data_dir, feature_type 146 | :param split: train / val / test 147 | :param tokenizer: huggingface tokenizer 148 | :param separate_question_tok: separate tokenizer for questions 149 | :return: EmqaDataset 150 | """ 151 | ds_dir = Path(cfg.data_dir) 152 | video_features_file = ds_dir / f'{split}-{cfg.feature_type}.hdf5' 153 | if not video_features_file.is_file(): 154 | video_features_file = ds_dir / f'{cfg.feature_type}.hdf5' 155 | if not video_features_file.is_file(): 156 | raise ValueError(str(video_features_file)) 157 | annotation_file = ds_dir / f'annotations.{split}.json' 158 | annotations = json.loads(annotation_file.read_text()) 159 | return cls(video_features_file, tokenizer, split, annotations, 160 | separate_question_tokenizer=separate_question_tok) 161 | -------------------------------------------------------------------------------- /model/sparse_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | from transformers import AutoModelForSeq2SeqLM 6 | from transformers.modeling_outputs import BaseModelOutput 7 | from transformers.models.longformer.modeling_longformer import LongformerModel 8 | from transformers.models.bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusForConditionalGeneration 9 | 10 | from model.base import EmqaBaseModel 11 | from model.moment_loc import TransformerMomentLocalizationLossModule 12 | 13 | 14 | # noinspection PyAbstractClass 15 | class LongformerVqaModel(EmqaBaseModel): 16 | # Actually this is not really an EMQA model since it uses full attention over the input 17 | 18 | def __init__(self, input_size: int, 19 | pretrained_enc_dec: str, 20 | pretrained_longformer: str, 21 | moment_localization_loss: TransformerMomentLocalizationLossModule = None) -> None: 22 | super().__init__() 23 | # Actually, only using the decoder of the enc_dec model.
Just want to use LM head + cross attention 24 | self.moment_localization_loss = moment_localization_loss 25 | self.enc_dec = AutoModelForSeq2SeqLM.from_pretrained(pretrained_enc_dec) 26 | self.enc_dec.encoder.block = None 27 | self.enc_dec.encoder.final_layer_norm = None 28 | self.enc_dec.decoder.config.output_attentions = True 29 | 30 | self.longformer = LongformerModel.from_pretrained(pretrained_longformer, add_pooling_layer=False) 31 | 32 | longformer_h = self.longformer.get_input_embeddings().embedding_dim 33 | if input_size != longformer_h: 34 | self.transform_visual = nn.Linear(input_size, longformer_h, bias=False) 35 | else: 36 | self.transform_visual = nn.Identity() 37 | decoder_h = self.enc_dec.get_input_embeddings().embedding_dim 38 | if longformer_h != decoder_h: 39 | self.transform_context = nn.Linear(longformer_h, decoder_h, bias=False) 40 | else: 41 | self.transform_context = nn.Identity() 42 | 43 | def teacher_forcing_forward(self, question_tokens, question_mask, 44 | video_features, video_mask, 45 | answer_tokens, answer_mask, 46 | batch_sample_ids, 47 | moment_localization_labels) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: 48 | context, context_mask = self.forward_encoders(question_tokens, question_mask, video_features, video_mask) 49 | output = self.enc_dec(labels=answer_tokens, decoder_attention_mask=answer_mask, 50 | encoder_outputs=(context,), attention_mask=context_mask) 51 | loss_dict = {'lm_loss': output.loss} 52 | if self.moment_localization_loss: 53 | loss_dict['moment_localization'] = self.moment_localization_loss( 54 | question_tokens, output, moment_localization_labels, video_mask) 55 | return loss_dict, output.logits 56 | 57 | def autoregressive_forward(self, question_tokens, question_mask, video_features, video_mask) -> torch.Tensor: 58 | context, context_mask = self.forward_encoders(question_tokens, question_mask, video_features, video_mask) 59 | # noinspection PyTypeChecker 60 | enc_out = BaseModelOutput(last_hidden_state=context) 61 | return self.enc_dec.generate(encoder_outputs=enc_out, attention_mask=context_mask) 62 | 63 | def forward_encoders(self, question_tokens, question_mask, video_features, video_mask): 64 | longformer_in = torch.cat([ 65 | self.transform_visual(video_features), 66 | self.longformer.get_input_embeddings()(question_tokens) 67 | ], dim=1) 68 | longformer_mask = torch.cat([ 69 | video_mask, 70 | question_mask 71 | ], dim=1) 72 | # initialize to global attention to be deactivated for all tokens 73 | global_attention_mask = torch.zeros_like(longformer_mask) 74 | # Set global attention to question tokens 75 | global_attention_mask[:, -question_tokens.shape[1]:] = 1 76 | 77 | context = self.longformer( 78 | inputs_embeds=longformer_in, attention_mask=longformer_mask, 79 | global_attention_mask=global_attention_mask, 80 | ).last_hidden_state 81 | context = self.transform_context(context) 82 | return context, longformer_mask 83 | 84 | 85 | class BigBirdVqaModel(EmqaBaseModel): 86 | # Actually this is not really a EMQA model since it uses full attention over the input 87 | 88 | def __init__(self, input_size: int, 89 | pretrained_bigbird: str, 90 | moment_localization_loss: TransformerMomentLocalizationLossModule = None, 91 | gradient_checkpointing=False) -> None: 92 | super().__init__() 93 | self.moment_localization_loss = moment_localization_loss 94 | self.bigbird = BigBirdPegasusForConditionalGeneration.from_pretrained(pretrained_bigbird) 95 | 96 | bigbird_h = self.bigbird.get_input_embeddings().embedding_dim 97 | if input_size != 
bigbird_h: 98 | self.transform_visual = nn.Linear(input_size, bigbird_h, bias=False) 99 | else: 100 | self.transform_visual = nn.Identity() 101 | 102 | self.gradient_checkpointing = gradient_checkpointing 103 | if gradient_checkpointing: 104 | self.bigbird.gradient_checkpointing_enable() 105 | 106 | def teacher_forcing_forward(self, question_tokens, question_mask, 107 | video_features, video_mask, 108 | answer_tokens, answer_mask, 109 | batch_sample_ids, 110 | moment_localization_labels) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: 111 | enc_in, enc_mask = self.prepare_encoder_input(question_tokens, question_mask, video_features, video_mask) 112 | output = self.bigbird(labels=answer_tokens, decoder_attention_mask=answer_mask, 113 | inputs_embeds=enc_in, attention_mask=enc_mask, 114 | use_cache=not self.gradient_checkpointing) 115 | loss_dict = {'lm_loss': output.loss} 116 | if self.moment_localization_loss: 117 | loss_dict['moment_localization'] = self.moment_localization_loss( 118 | question_tokens, output, moment_localization_labels, video_mask) 119 | return loss_dict, output.logits 120 | 121 | def autoregressive_forward(self, question_tokens, question_mask, video_features, video_mask) -> torch.Tensor: 122 | enc_in, enc_mask = self.prepare_encoder_input(question_tokens, question_mask, video_features, video_mask) 123 | return self.bigbird.generate(inputs_embeds=enc_in, attention_mask=enc_mask) 124 | 125 | def prepare_encoder_input(self, question_tokens, question_mask, video_features, video_mask): 126 | encoder_in = torch.cat([ 127 | self.transform_visual(video_features), 128 | self.bigbird.get_input_embeddings()(question_tokens) 129 | ], dim=1) 130 | encoder_mask = torch.cat([ 131 | video_mask, 132 | question_mask 133 | ], dim=1) 134 | return encoder_in, encoder_mask 135 | -------------------------------------------------------------------------------- /model/importance_teacher.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | from hydra import initialize, compose 5 | from omegaconf import DictConfig 6 | from torch import nn 7 | from tqdm import tqdm 8 | 9 | from data.datamodule import EmqaDataModule 10 | from data.emqa_dataset import EmqaBatch 11 | from model.base import MemoryAugmentedTransformerEmqaModel 12 | 13 | 14 | # Mean Pooling - Take attention mask into account for correct averaging 15 | def mean_pooling(model_output, attention_mask): 16 | token_embeddings = model_output[0] # First element of model_output contains all token embeddings 17 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 18 | sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) 19 | sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) 20 | return sum_embeddings / sum_mask 21 | 22 | 23 | # noinspection PyAbstractClass 24 | class ImportanceTeacherVqaModel(MemoryAugmentedTransformerEmqaModel): 25 | # Actually this is not really a MemoryAugmentedTransformerEmqaModel since it uses full attention over the input 26 | 27 | def __init__(self, fragment_length: int, 28 | input_size: int, 29 | pretrained_enc_dec: str) -> None: 30 | super().__init__(pretrained_enc_dec) 31 | self.clip_avg = nn.AvgPool1d(fragment_length, stride=1) 32 | self.question_embedding = self.transformer.get_encoder() 33 | 34 | hidden = self.transformer.get_input_embeddings().embedding_dim 35 | self.query_attn = nn.MultiheadAttention(embed_dim=hidden, 36 | num_heads=1, 
batch_first=True) 37 | 38 | sentence_emb = self.question_embedding.get_input_embeddings().embedding_dim 39 | if sentence_emb != hidden: 40 | self.transform_query = nn.Linear(sentence_emb, hidden, bias=False) 41 | else: 42 | self.transform_query = nn.Identity() 43 | if input_size != hidden: 44 | self.transform_visual = nn.Linear(input_size, hidden, bias=False) 45 | else: 46 | self.transform_visual = nn.Identity() 47 | 48 | def forward_encoders(self, question_tokens, question_mask, video_features, video_mask, moment_localization_labels): 49 | attn_context = self.attend_to_fragments(question_tokens, question_mask, video_features, video_mask) 50 | attn_mask = torch.ones(attn_context.shape[:-1], device=attn_context.device, dtype=torch.bool) 51 | 52 | context, context_mask = self._prepare_context(attn_context, attn_mask, question_tokens, question_mask) 53 | return context, context_mask, attn_context, attn_mask, {} 54 | 55 | def attend_to_fragments(self, question_tokens, question_mask, 56 | visual_input, visual_mask, 57 | return_attn_scores=False): 58 | # visual input: B x N x H. Want to avg pool over N => permute 59 | fragments = self.clip_avg(visual_input.permute(0, 2, 1)) 60 | fragments = fragments.permute(0, 2, 1) # permute back to B x N' x H 61 | # bsz x num_fragments : True <==> valid entry 62 | fragments_mask = (self.clip_avg(visual_mask.to(dtype=torch.float)) == 1).squeeze() 63 | 64 | # modify samples where fragment_length > input_length with correct averaging over non-padded values only 65 | # so that even short samples are represented by at least one fragment 66 | input_lengths = visual_mask.sum(dim=1) 67 | num_fragments = fragments_mask.sum(dim=1) 68 | too_short_videos = num_fragments == 0 69 | for b in too_short_videos.nonzero(): 70 | b = b.item() 71 | fragments[b, 0] = visual_input[b, visual_mask[b]].sum(dim=0) / input_lengths[b] 72 | fragments_mask[b, 0] = 1 73 | num_fragments[b] = 1 74 | 75 | with torch.no_grad(): 76 | model_output = self.question_embedding(question_tokens, question_mask) 77 | query = mean_pooling(model_output, question_mask) 78 | 79 | fragments = self.transform_visual(fragments) 80 | query = self.transform_query(query).unsqueeze(dim=1) # add fake seq dimension 81 | 82 | # attn_mask: batch x querys x keys = bsz x 1 x num_fragments. 
83 | # True <==> corresponding position is _not_ allowed to attend 84 | attn_mask = ~fragments_mask[:, None, :] 85 | context_token, attn_weights = self.query_attn(query, fragments, fragments, attn_mask=attn_mask, 86 | need_weights=return_attn_scores) # B x 1 x H 87 | if return_attn_scores: 88 | # trim padding indices away 89 | attn_weights = attn_weights.squeeze() 90 | return context_token, [w[:length] for w, length in zip(attn_weights, num_fragments)] 91 | else: 92 | return context_token 93 | 94 | 95 | def _extract_attn_scores(model: ImportanceTeacherVqaModel, train_iterator): 96 | all_scores = {} 97 | batch: EmqaBatch 98 | for batch in train_iterator: 99 | _, attn_scores = model.attend_to_fragments( 100 | batch['batch_question_tokens'].cuda(), 101 | batch['batch_question_mask'].cuda(), 102 | batch['batch_video_features'].cuda(), 103 | batch['batch_video_mask'].cuda(), 104 | return_attn_scores=True 105 | ) 106 | for idx, scores in zip(batch['batch_sample_ids'], attn_scores): 107 | all_scores[idx] = scores.cpu() 108 | 109 | return all_scores 110 | 111 | 112 | @torch.no_grad() 113 | def main(output_file: str, checkpoint_path: str, config: DictConfig): 114 | checkpoint = torch.load(checkpoint_path) 115 | state_dict = {k[len('model.'):]: v for k, v in checkpoint['state_dict'].items()} 116 | model_cfg = dict(config.model) 117 | model_cfg.pop('_target_') 118 | model = ImportanceTeacherVqaModel(**model_cfg) 119 | # noinspection PyTypeChecker 120 | model.load_state_dict(state_dict) 121 | 122 | data = EmqaDataModule(config.dataset, drop_last=False) 123 | data.prepare_data() 124 | data.setup() 125 | 126 | model.cuda() 127 | result = {} 128 | with tqdm(data.train_dataloader(), 'Train') as train_iterator: 129 | result.update(_extract_attn_scores(model, train_iterator)) 130 | with tqdm(data.val_dataloader(), 'Val') as val_iterator: 131 | result.update(_extract_attn_scores(model, val_iterator)) 132 | 133 | torch.save(result, output_file) 134 | 135 | 136 | def cli_main(): 137 | parser = ArgumentParser(description='Save attention scores from ImportanceTeacherVqaModel') 138 | 139 | parser.add_argument('--bsz', type=int, default=32) 140 | parser.add_argument('--output_file', type=str, required=True, help='Where to save attention scores.') 141 | parser.add_argument('--checkpoint_path', type=str, required=True, 142 | help='Where to load trained model from. 
PyTorch Lighting checkpoint from SimpleVqa training') 143 | args = parser.parse_args() 144 | 145 | initialize(config_path='../config', job_name="save_vqa_attn_scores") 146 | config = compose(config_name='base', overrides=[f'dataset.train_bsz={args.bsz}', f'dataset.test_bsz={args.bsz}']) 147 | main(args.output_file, args.checkpoint_path, config) 148 | 149 | 150 | if __name__ == '__main__': 151 | cli_main() 152 | -------------------------------------------------------------------------------- /model/lightning.py: -------------------------------------------------------------------------------- 1 | import random 2 | from itertools import chain 3 | from typing import List, Dict, Union 4 | 5 | import torch 6 | from hydra.utils import instantiate 7 | from pytorch_lightning import LightningModule 8 | from transformers import PreTrainedTokenizerBase as Tokenizer 9 | 10 | from data.emqa_dataset import EmqaBatch 11 | from eval.eval import calc_metrics 12 | from lightning_util import freeze_params 13 | from model.base import EmqaBaseModel 14 | 15 | 16 | # noinspection PyAbstractClass 17 | class EmqaLightningModule(LightningModule): 18 | 19 | def __init__(self, model_config, optim_config, tokenizer: Tokenizer) -> None: 20 | super().__init__() 21 | self.save_hyperparameters(dict(model=model_config, optim=optim_config)) 22 | self.model: EmqaBaseModel = instantiate(model_config) 23 | self.optimizer_config = optim_config.optimizer 24 | if 'loss_weights' in optim_config: 25 | self.loss_weights: Dict[str, float] = optim_config.loss_weights 26 | self._loss_calc_cache = None 27 | else: 28 | self.loss_weights = None 29 | self.tokenizer = tokenizer 30 | self.lr = self.optimizer_config.lr 31 | freeze_params(self, optim_config.freeze) 32 | self._log_indices = {} 33 | 34 | def training_step(self, batch: EmqaBatch, batch_idx): 35 | loss_dict, logits = self.model(**self._get_model_inputs(batch)) 36 | total_loss = self._modify_loss_dict(loss_dict) 37 | for k, v in loss_dict.items(): 38 | self.log(k, v) 39 | return total_loss 40 | 41 | def _modify_loss_dict(self, loss_dict: Dict[str, torch.Tensor]): 42 | if 'total_loss' in loss_dict: 43 | return loss_dict['total_loss'] 44 | if len(loss_dict) == 1: 45 | # No matter how it's called, the single value is the total loss 46 | # However, no need to add it to loss_dict again (would only produce log twice) 47 | return next(iter(loss_dict.values())) 48 | assert self.loss_weights is not None 49 | if self._loss_calc_cache: 50 | master_key, remaining_keys = self._loss_calc_cache 51 | else: 52 | specified_keys = self.loss_weights.keys() 53 | all_keys = loss_dict.keys() 54 | master_keys = all_keys - specified_keys 55 | assert len(master_keys) == 1, f'There must be exactly one loss weight not specified (master weight), ' \ 56 | f'got {all_keys} - {specified_keys} = {master_keys}' 57 | remaining_keys = all_keys - master_keys 58 | assert specified_keys == remaining_keys, f'{specified_keys} != {remaining_keys}' 59 | master_key = next(iter(master_keys)) 60 | self._loss_calc_cache = (master_key, remaining_keys) 61 | master_loss = loss_dict[master_key] 62 | total_loss = master_loss + sum(self.loss_weights[k] * loss_dict[k] for k in remaining_keys) 63 | loss_dict['total_loss'] = total_loss # Add it to dict for logging purposes 64 | return total_loss 65 | 66 | @staticmethod 67 | def _get_model_inputs(batch: EmqaBatch): 68 | return dict( 69 | question_tokens=batch['batch_question_tokens'], 70 | question_mask=batch['batch_question_mask'], 71 | video_features=batch['batch_video_features'], 72 | 
video_mask=batch['batch_video_mask'], 73 | answer_tokens=batch['batch_answer_tokens'], 74 | answer_mask=batch['batch_answer_mask'], 75 | batch_sample_ids=batch['batch_sample_ids'], 76 | moment_localization_labels=batch.get('batch_moment_localization_labels') 77 | ) 78 | 79 | def validation_step(self, batch: EmqaBatch, batch_idx): 80 | loss_dict, lm_logits = self.model(**self._get_model_inputs(batch)) 81 | self._modify_loss_dict(loss_dict) 82 | hypo_answers = self._extract_answers(lm_logits) 83 | return {'hypos': hypo_answers, 84 | 'targets': batch['batch_answer_texts'], 85 | **loss_dict} 86 | 87 | def _extract_answers(self, lm_logits): 88 | sequences = lm_logits.argmax(dim=-1) 89 | hypo_answers = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) 90 | return hypo_answers 91 | 92 | def validation_epoch_end(self, outputs: List[Dict[str, Union[torch.Tensor, List[str]]]]) -> None: 93 | def _mean(key): 94 | return torch.stack([data[key] for data in outputs]).mean() 95 | 96 | self._log_some_outputs(outputs, 'val') 97 | metrics = self.aggregate_metrics(outputs, prefix='val') 98 | metrics.update({ 99 | f'val_{name}': _mean(name) for name in outputs[0].keys() if 'loss' in name 100 | }) 101 | self.log_dict(metrics) 102 | 103 | def _log_some_outputs(self, outputs, name): 104 | num_val_steps_to_log, num_samples_per_batch_to_log = 5, 3 # Could be configurable via cfg 105 | if name in self._log_indices: 106 | steps_to_log_indices = self._log_indices[name]['steps'] 107 | else: 108 | steps_to_log_indices = random.sample(range(len(outputs)), k=min(len(outputs), num_val_steps_to_log)) 109 | self._log_indices[name] = {'steps': steps_to_log_indices, 'samples': [ 110 | random.sample(range(len(outputs[step]['targets'])), 111 | k=min(len(outputs[step]['targets']), num_samples_per_batch_to_log)) 112 | for step in steps_to_log_indices 113 | ]} 114 | for i, step in enumerate(steps_to_log_indices): 115 | output, target = outputs[step]['hypos'], outputs[step]['targets'] 116 | indices = self._log_indices[name]['samples'][i] 117 | for b in indices: 118 | sample = ( 119 | f'Target: "{target[b]}". 
\n' 120 | f'Output: "{output[b]}"' 121 | ) 122 | self.logger.experiment.add_text(f'{name} {str(i * len(indices) + b)}', sample, 123 | global_step=self.global_step) 124 | 125 | @staticmethod 126 | def aggregate_metrics(outputs, prefix): 127 | all_hypos = list(chain(*(data['hypos'] for data in outputs))) 128 | all_targets = list(chain(*(data['targets'] for data in outputs))) 129 | metrics = calc_metrics(all_hypos, [[x] for x in all_targets]) 130 | metrics = {f'{prefix}_{k}': v for k, v in metrics.items()} 131 | return metrics 132 | 133 | def test_step(self, batch: EmqaBatch, batch_idx): 134 | sequences = self.model(**self._get_model_inputs(batch), teacher_forcing=False) 135 | hypo_answers = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) 136 | return {'hypos': hypo_answers, 137 | 'targets': batch['batch_answer_texts']} 138 | 139 | def test_epoch_end(self, outputs: List[Dict[str, Union[List[str], Dict]]]) -> None: 140 | self._log_some_outputs(outputs, 'test') 141 | metrics = self.aggregate_metrics(outputs, prefix='test') 142 | self.log_dict(metrics) 143 | 144 | def configure_optimizers(self): 145 | params = filter(lambda p: p.requires_grad, self.parameters()) 146 | return instantiate(self.optimizer_config, 147 | params, 148 | # lr might be overridden by auto lr tuning 149 | lr=self.lr) 150 | -------------------------------------------------------------------------------- /model/base.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Union 2 | 3 | import torch 4 | from torch import nn 5 | from transformers import PreTrainedModel, AutoModelForSeq2SeqLM 6 | from transformers.modeling_outputs import BaseModelOutput 7 | 8 | 9 | class EmqaBaseModel(nn.Module): 10 | 11 | def forward(self, question_tokens, question_mask, 12 | video_features, video_mask, 13 | answer_tokens, answer_mask, 14 | batch_sample_ids, 15 | moment_localization_labels, 16 | teacher_forcing=True) -> Union[Tuple[Dict[str, torch.Tensor], torch.Tensor], 17 | torch.Tensor]: 18 | if teacher_forcing: 19 | return self.teacher_forcing_forward(question_tokens, question_mask, 20 | video_features, video_mask, 21 | answer_tokens, answer_mask, 22 | batch_sample_ids, moment_localization_labels) 23 | else: 24 | return self.autoregressive_forward(question_tokens, question_mask, 25 | video_features, video_mask) 26 | 27 | def teacher_forcing_forward(self, question_tokens, question_mask, 28 | video_features, video_mask, 29 | answer_tokens, answer_mask, 30 | batch_sample_ids, 31 | moment_localization_labels 32 | ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: 33 | """ 34 | Forward in teacher forcing mode. 35 | 36 | :param question_tokens: B x Q 37 | :param question_mask: B x Q 38 | :param video_features: B x N x H 39 | :param video_mask: B x N x H 40 | :param answer_tokens: B x A 41 | :param answer_mask: B x A 42 | :param batch_sample_ids: List of sample ids for this batch. 43 | :param moment_localization_labels: B x N, with entries between 0.0 and 1.0 44 | (smoothed label saying if each frame is part of ground truth moment). Might be None. 45 | :return: tuple with loss_dict and anwer logits. 46 | loss_dict should at least contain "total_loss" Tensor. 47 | """ 48 | raise NotImplementedError 49 | 50 | def autoregressive_forward(self, question_tokens, question_mask, 51 | video_features, video_mask 52 | ) -> torch.Tensor: 53 | """ 54 | Forward in autoregressive mode. 
55 | 56 | :param question_tokens: B x Q 57 | :param question_mask: B x Q 58 | :param video_features: B x N x H 59 | :param video_mask: B x N x H 60 | :return: tensor with answer tokens. B x A 61 | """ 62 | raise NotImplementedError 63 | 64 | 65 | class MemoryAugmentedTransformerEmqaModel(EmqaBaseModel): 66 | def __init__(self, 67 | pretrained_enc_dec: str 68 | ) -> None: 69 | super().__init__() 70 | self.transformer: PreTrainedModel = AutoModelForSeq2SeqLM.from_pretrained(pretrained_enc_dec) 71 | 72 | def teacher_forcing_forward(self, question_tokens, question_mask, 73 | video_features, video_mask, 74 | answer_tokens, answer_mask, 75 | batch_sample_ids, 76 | moment_localization_labels) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: 77 | context, context_mask, final_memory, mem_mask, enc_add_loss = self.forward_encoders( 78 | question_tokens, question_mask, video_features, video_mask, moment_localization_labels) 79 | output = self.transformer(labels=answer_tokens, decoder_attention_mask=answer_mask, 80 | encoder_outputs=(context,), attention_mask=context_mask) 81 | return { 82 | 'lm_loss': output.loss, 83 | **enc_add_loss, 84 | **self.calc_additional_loss(question_tokens, question_mask, 85 | video_features, video_mask, 86 | answer_tokens, answer_mask, batch_sample_ids, 87 | context, context_mask, final_memory, mem_mask, 88 | output, moment_localization_labels) 89 | }, output.logits 90 | 91 | def forward_encoders(self, question_tokens, question_mask, video_features, video_mask, moment_localization_labels): 92 | q_encoder_out = self.transformer.encoder(input_ids=question_tokens, attention_mask=question_mask) 93 | q_avg_pooled = q_encoder_out.last_hidden_state.mean(dim=1) 94 | 95 | mem_out = self.forward_memory(video_features, video_mask, moment_localization_labels, q_avg_pooled) 96 | if isinstance(mem_out, tuple): 97 | final_memory, additional_loss = mem_out 98 | else: 99 | assert torch.is_tensor(mem_out) 100 | final_memory = mem_out 101 | additional_loss = {} 102 | mem_mask = torch.ones(final_memory.shape[:-1], device=final_memory.device, dtype=torch.bool) 103 | 104 | context, context_mask = self._prepare_context(final_memory, mem_mask, question_mask=question_mask, 105 | q_encoder_out=q_encoder_out) 106 | return context, context_mask, final_memory, mem_mask, additional_loss 107 | 108 | def _prepare_context(self, final_memory, mem_mask, question_tokens=None, question_mask=None, q_encoder_out=None): 109 | if q_encoder_out: 110 | encoder_out = q_encoder_out 111 | else: 112 | encoder_out = self.transformer.encoder(input_ids=question_tokens, attention_mask=question_mask) 113 | context = torch.cat([ 114 | final_memory, 115 | encoder_out.last_hidden_state 116 | ], dim=1) 117 | context_mask = torch.cat([ 118 | mem_mask, 119 | question_mask 120 | ], dim=1) 121 | return context, context_mask 122 | 123 | def autoregressive_forward(self, question_tokens, question_mask, video_features, video_mask) -> torch.Tensor: 124 | context, context_mask, _, _, _ = self.forward_encoders(question_tokens, question_mask, video_features, 125 | video_mask, None) 126 | # noinspection PyTypeChecker 127 | enc_out = BaseModelOutput(last_hidden_state=context) 128 | return self.transformer.generate(encoder_outputs=enc_out, attention_mask=context_mask) 129 | 130 | def forward_memory(self, video_features, video_mask, 131 | moment_localization_labels, 132 | question_encoding) -> Union[torch.Tensor, 133 | Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: 134 | """ 135 | Forward video input to memory. 
136 | 137 | :param video_features: Tensor of shape Batch x Sequence x Features 138 | :param video_mask: Tensor of shape Batch x Sequence 139 | :param moment_localization_labels: Tensor of shape Batch x Sequence, with entries between 0.0 and 1.0 140 | (smoothed label saying if each frame is part of ground truth moment). Might be None. 141 | :param question_encoding: Tensor of shape Batch x Hidden, Mean-pooled representation of the question. 142 | Should only be used for additional loss, not memory construction! 143 | :return: memory, Tensor of shape Batch x MemoryLength x Hidden 144 | or tuple of (memory, additional_loss_dict) 145 | """ 146 | raise NotImplementedError 147 | 148 | # noinspection PyMethodMayBeStatic 149 | def calc_additional_loss(self, question_tokens, question_mask, 150 | video_features, video_mask, 151 | answer_tokens, answer_mask, batch_sample_ids, 152 | context, context_mask, final_memory, mem_mask, 153 | transformer_output, moment_localization_labels) -> Dict[str, torch.Tensor]: 154 | return {} 155 | -------------------------------------------------------------------------------- /model/mann.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Tuple, Dict 2 | 3 | import torch 4 | from dnc import DNC 5 | from torch import nn 6 | 7 | from model.base import MemoryAugmentedTransformerEmqaModel 8 | from model.external.stm import STM 9 | from model.moment_loc import SeqMomentLocalizationLossModule 10 | 11 | 12 | class SegmentationMemoryAugmentedTransformerEmqaModel(MemoryAugmentedTransformerEmqaModel): 13 | def __init__(self, pretrained_enc_dec: str, 14 | segmentation_method: Literal['flat', 'avg'], 15 | segment_length: int, 16 | input_size: int, # this is downscaled to dnc_input_size for each time step separately 17 | mem_input_size: int, # Actual input to Memory depends on segmentation_method 18 | moment_loc: SeqMomentLocalizationLossModule = None 19 | ) -> None: 20 | super().__init__(pretrained_enc_dec) 21 | self.segment_length = segment_length 22 | self.segmentation_method = segmentation_method 23 | self.input_downscale = (nn.Identity() if input_size == mem_input_size 24 | else nn.Linear(input_size, mem_input_size, bias=False)) 25 | self.actual_mem_input_size = (segment_length * mem_input_size 26 | if segmentation_method == 'flat' else mem_input_size) 27 | self.moment_loc = moment_loc 28 | 29 | def forward_memory(self, video_features, video_mask, 30 | moment_localization_labels, 31 | question_encoding) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 32 | # B: batch, H: downscaled hidden dim, L: segment_length, N: num_segments, S: input seq length 33 | seq_downscaled: torch.Tensor = self.input_downscale(video_features) # B x S x H 34 | if self.segmentation_method == 'flat': 35 | mem_in, mem_in_mask = self._segment_flat(seq_downscaled, video_mask) 36 | elif self.segmentation_method == 'avg': 37 | mem_in, mem_in_mask = self._segment_avg(seq_downscaled, video_mask) 38 | else: 39 | raise ValueError(self.segmentation_method) 40 | memory, item_output = self.forward_segmented_memory(mem_in) 41 | 42 | aux_loss = self._calc_aux_loss(item_output, mem_in_mask, moment_localization_labels, question_encoding) 43 | return memory, aux_loss 44 | 45 | def _calc_aux_loss(self, item_output, mem_in_mask, moment_localization_labels, question_encoding): 46 | has_moment_loc = (item_output is not None 47 | and self.moment_loc is not None 48 | and moment_localization_labels is not None) 49 | if has_moment_loc: 50 | # have: B x 
input_seq 51 | # need: B x N 52 | segments = list(moment_localization_labels.split(self.segment_length, dim=1)) # N tensors of B x ( ≤L ) 53 | segments = [segment_labels.sum(dim=1).clip(0, 1) for segment_labels in segments] # N tensors of B 54 | moment_localization_labels = torch.stack(segments, dim=1) # B x N 55 | aux_loss = {} if not has_moment_loc else { 56 | 'moment_localization': self.moment_loc(item_output, mem_in_mask, 57 | moment_localization_labels, question_encoding) 58 | } 59 | return aux_loss 60 | 61 | def _segment_flat(self, seq_downscaled, mask): 62 | bsz, _, h = seq_downscaled.shape 63 | segments = list(seq_downscaled.split(self.segment_length, dim=1)) # N tensors of B x ( ≤L ) x H 64 | mask_segments = list(mask.split(self.segment_length, dim=1)) 65 | last_segment_length = segments[-1].shape[1] 66 | if last_segment_length != self.segment_length: 67 | # Zero-pad to segment length so that mem_in can be constructed correctly 68 | segments[-1] = torch.cat([ 69 | segments[-1], 70 | torch.zeros(bsz, self.segment_length - last_segment_length, h, 71 | device=seq_downscaled.device, dtype=seq_downscaled.dtype) 72 | ], dim=1) 73 | mask_segments[-1] = torch.cat([ 74 | mask_segments[-1], 75 | torch.zeros(bsz, self.segment_length - last_segment_length, 76 | device=mask.device, dtype=mask.dtype) 77 | ], dim=1) 78 | segments = [s.view(bsz, -1) for s in segments] # N tensors of B x LH 79 | mem_in = torch.stack(segments, dim=1) # B x N x LH 80 | mem_in_mask = torch.stack(mask_segments, dim=1) # B x N x L 81 | mem_in_mask = mem_in_mask.any(dim=-1) # B x N 82 | return mem_in, mem_in_mask 83 | 84 | def _segment_avg(self, seq_downscaled, mask): 85 | bsz, _, h = seq_downscaled.shape 86 | segments = list(seq_downscaled.split(self.segment_length, dim=1)) # N tensors of B x ( ≤L ) x H 87 | avg_segments = [x.mean(dim=1) for x in segments] # N tensors of B x H 88 | mem_in = torch.stack(avg_segments, dim=1) # B x N x H 89 | mask_segments = list(mask.split(self.segment_length, dim=1)) # N tensors of B x ( ≤L ) 90 | mask_segments = [x.any(dim=1) for x in mask_segments] # N tensors of B 91 | segment_mask = torch.stack(mask_segments, dim=1) # B x N 92 | return mem_in, segment_mask 93 | 94 | def forward_segmented_memory(self, mem_in: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # mem_in: B x N x LH 96 | raise NotImplementedError 97 | 98 | 99 | class DncEmqaModel(SegmentationMemoryAugmentedTransformerEmqaModel): 100 | 101 | def __init__(self, pretrained_enc_dec: str, 102 | segmentation_method: Literal['flat', 'avg'], 103 | segment_length: int, 104 | input_size: int, # this is downscaled to dnc_input_size for each time step separately 105 | dnc_input_size: int, # Actual input to DNC is segment_length * dnc_input_size 106 | rnn_hidden_size: int, 107 | num_dnc_layers=1, 108 | num_rnn_hidden_layers=2, 109 | num_mem_cells=5, 110 | mem_hidden_size=10, 111 | moment_loc: SeqMomentLocalizationLossModule = None 112 | ) -> None: 113 | super().__init__(pretrained_enc_dec, segmentation_method, segment_length, 114 | input_size, dnc_input_size, moment_loc) 115 | self.dnc = DNC(self.actual_mem_input_size, 116 | rnn_hidden_size, num_layers=num_dnc_layers, 117 | num_hidden_layers=num_rnn_hidden_layers, 118 | nr_cells=num_mem_cells, cell_size=mem_hidden_size) 119 | 120 | def forward_segmented_memory(self, mem_in: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 121 | self._hack_dnc_gpu_ids(mem_in.device) 122 | output, (controller_hidden, memory, read_vectors) = self.dnc(mem_in, reset_experience=True) 123 | #
memory['memory']: Batch x NumMemCells x MemHiddenSize 124 | return memory['memory'], output 125 | 126 | def _hack_dnc_gpu_ids(self, device): 127 | gpu_id = -1 if device.type == 'cpu' else device.index 128 | if self.dnc.gpu_id == gpu_id: 129 | return 130 | # This is a hack because DNC is programmed not in a pytorch idiomatic way... 131 | self.dnc.gpu_id = gpu_id 132 | for m in self.dnc.memories: 133 | m.gpu_id = gpu_id 134 | m.I = m.I.to(device=device) 135 | 136 | 137 | class StmEmqaModel(SegmentationMemoryAugmentedTransformerEmqaModel): 138 | 139 | def __init__(self, pretrained_enc_dec: str, 140 | segmentation_method: Literal['flat', 'avg'], 141 | segment_length: int, 142 | input_size: int, # this is downscaled to dnc_input_size for each time step separately 143 | stm_input_size: int, # Actual input to DNC is segment_length * dnc_input_size 144 | mem_hidden_size=10, 145 | stm_step=1, 146 | stm_num_slot=8, 147 | stm_mlp_size=128, 148 | stm_slot_size=96, 149 | stm_rel_size=96, 150 | stm_out_att_size=64 151 | ) -> None: 152 | super().__init__(pretrained_enc_dec, segmentation_method, segment_length, input_size, stm_input_size) 153 | self.stm = STM(self.actual_mem_input_size, mem_hidden_size, 154 | stm_step, stm_num_slot, stm_mlp_size, stm_slot_size, 155 | stm_rel_size, stm_out_att_size) 156 | 157 | def forward_segmented_memory(self, mem_in: torch.Tensor) -> Tuple[torch.Tensor, None]: 158 | mem_in = mem_in.transpose(0, 1) # switch batch and sequence dim for STM 159 | output, (read_heads, item_memory_state, rel_memory_state) = self.stm(mem_in) 160 | # should return Batch x NumMemCells x MemHiddenSize 161 | # output is Batch x MemHiddenSize 162 | return output[:, None, :], None 163 | -------------------------------------------------------------------------------- /lightning_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | from argparse import Namespace 5 | from copy import deepcopy 6 | from fnmatch import fnmatchcase 7 | from pathlib import Path 8 | from typing import List 9 | 10 | import torch.nn 11 | from pytorch_lightning import seed_everything, Trainer, Callback 12 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor 13 | from pytorch_lightning.plugins import DDPPlugin 14 | 15 | 16 | def dict_parser(s: str): 17 | return eval('{' + re.sub(r'(\w+)=(["\']?\w+["\']?)', r'"\1":\2', s) + '}') 18 | 19 | 20 | def add_common_trainer_util_args(parser, default_monitor_variable='val_loss', default_monitor_mode='min'): 21 | if default_monitor_mode not in ['min', 'max']: 22 | raise ValueError(default_monitor_mode) 23 | parser.add_argument('--lr_find_kwargs', default=dict(min_lr=5e-6, max_lr=1e-2), type=dict_parser, 24 | help='Arguments for LR find (--auto_lr_find). Default "min_lr=5e-6,max_lr=1e-2"') 25 | parser.add_argument('--random_seed', default=42, type=lambda s: None if s == 'None' else int(s), 26 | help='Seed everything. Set to "None" to disable global seeding') 27 | parser.add_argument('--auto_resume', default=False, action='store_true', 28 | help='Automatically resume last saved checkpoint, if available.') 29 | parser.add_argument('--test_only', default=False, action='store_true', 30 | help='Skip fit and call only test. This implies automatically detecting newest checkpoint, ' 31 | 'if --checkpoint_path is not given.') 32 | parser.add_argument('--checkpoint_path', default=None, type=str, 33 | help='Load this checkpoint to resume training or run testing. 
' 34 | 'Pass in the special value "best" to use the best checkpoint according to ' 35 | 'args.monitor_variable and args.monitor_mode. ' 36 | 'Using "best" only works with test_only mode.') 37 | parser.add_argument('--ignore_existing_checkpoints', default=False, action='store_true', 38 | help='Proceed even with training a new model, even if previous checkpoints exists.') 39 | parser.add_argument('--monitor_variable', default=default_monitor_variable, type=str, 40 | help='Variable to monitor for early stopping and for checkpoint selection. ' 41 | f'Default: {default_monitor_variable}') 42 | parser.add_argument('--monitor_mode', default=default_monitor_mode, type=str, choices=['min', 'max'], 43 | help='Mode for monitoring the monitor_variable (for early stopping and checkpoint selection). ' 44 | f'Default: {default_monitor_mode}') 45 | parser.add_argument('--reset_early_stopping_criterion', default=False, action='store_true', 46 | help='Reset the early stopping criterion when loading from checkpoint. ' 47 | 'Prevents immediate exit after switching to more complex dataset in curriculum strategy') 48 | 49 | 50 | def _auto_resume_from_checkpoint(args): 51 | if getattr(args, 'resume_from_checkpoint', None) is not None: 52 | raise DeprecationWarning('Trainer.resume_from_checkpoint is deprecated. Switch to checkpoint_path argument.') 53 | best_mode = args.checkpoint_path == 'best' 54 | if best_mode: 55 | if not args.test_only: 56 | raise RuntimeError('checkpoint_path="best" only works in test_only mode!') 57 | # More "best" logic is handled below 58 | elif args.checkpoint_path is not None: 59 | return 60 | 61 | log_dir = Path(getattr(args, 'default_root_dir', None) or 'lightning_logs') 62 | existing_checkpoints = list(log_dir.glob('version_*/checkpoints/*.ckpt')) 63 | if len(existing_checkpoints) == 0: 64 | return # This is the first run 65 | if not args.test_only and not args.auto_resume: 66 | if args.ignore_existing_checkpoints: 67 | return # Explicitly requested 68 | raise RuntimeWarning(f"There already exist checkpoints, but checkpoint_path/auto_resume not set! " 69 | f"{existing_checkpoints}") 70 | if best_mode: 71 | chosen = _auto_choose_best_checkpoint(args, existing_checkpoints) 72 | else: 73 | chosen = _auto_choose_newest_checkpoint(existing_checkpoints) 74 | args.checkpoint_path = str(chosen) 75 | print(f'Auto-detected {"best" if best_mode else "newest"} checkpoint {chosen}, resuming it... 
' 76 | f'If this is not intended, use --checkpoint_path !') 77 | 78 | 79 | def _auto_choose_newest_checkpoint(existing_checkpoints): 80 | chosen = None 81 | for c in existing_checkpoints: 82 | if chosen is None or c.stat().st_mtime > chosen.stat().st_mtime: 83 | chosen = c 84 | return chosen 85 | 86 | 87 | def _auto_choose_best_checkpoint(args, existing_checkpoints): 88 | chosen = None 89 | for c in existing_checkpoints: 90 | if chosen is None or 'last.ckpt' == chosen.name: 91 | chosen = c 92 | continue 93 | if 'last.ckpt' == c.name: 94 | continue 95 | chosen_match = re.search(fr'{re.escape(args.monitor_variable)}=(\d+(?:\.\d+)?)', chosen.name) 96 | current_match = re.search(fr'{re.escape(args.monitor_variable)}=(\d+(?:\.\d+)?)', c.name) 97 | if chosen_match is None: 98 | raise ValueError(chosen) 99 | if current_match is None: 100 | raise ValueError(c) 101 | op = {'min': lambda old, new: new < old, 'max': lambda old, new: new > old}[args.monitor_mode] 102 | if op(float(chosen_match.group(1)), float(current_match.group(1))): 103 | chosen = c 104 | return chosen 105 | 106 | 107 | def apply_common_train_util_args(args) -> List[Callback]: 108 | _auto_resume_from_checkpoint(args) 109 | if args.random_seed is not None: 110 | seed_everything(args.random_seed, workers=True) 111 | 112 | early_stopping = EarlyStopping(monitor=args.monitor_variable, mode=args.monitor_mode, 113 | min_delta=0.001, patience=10, 114 | check_on_train_epoch_end=False) 115 | if args.reset_early_stopping_criterion: 116 | # Prevent loading the early stopping criterion when restoring from checkpoint 117 | early_stopping.on_load_checkpoint = lambda *args, **kwargs: None 118 | return [ 119 | early_stopping, 120 | ModelCheckpoint(save_last=True, monitor=args.monitor_variable, mode=args.monitor_mode, 121 | save_top_k=1, filename='{step}-{' + args.monitor_variable + ':.3f}') 122 | ] 123 | 124 | 125 | def _ddp_save_tune(args, model, datamodule): 126 | """ 127 | Runs LR tuning on main process only, _before_ DDP is initialized. Sets env var for communication to child processes. 128 | """ 129 | lr_env_var = os.getenv('_pl_auto_lr_find') 130 | if lr_env_var is None: 131 | # Main process running this code will not have a value for the env var, and thus perform single-process LR tune 132 | args_copy = deepcopy(args) 133 | args_copy.strategy = None 134 | args_copy.gpus = 1 135 | print('Running single GPU tune...') 136 | single_process_trainer = Trainer.from_argparse_args(args_copy) 137 | single_process_trainer.tune(model, datamodule, lr_find_kwargs=dict(args.lr_find_kwargs)) 138 | # "Broadcast" result to other ranks. Can not use actual broadcast mechanism, 139 | # since DDP env is not yet running. 
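        # DDP worker processes spawned afterwards inherit this environment variable and
        # pick the tuned LR up in the else branch below.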
140 | os.environ['_pl_auto_lr_find'] = str(model.lr) 141 | else: 142 | # Later workers executing this code will just load the best LR from the env var 143 | model.lr = float(os.environ['_pl_auto_lr_find']) 144 | 145 | 146 | def _adjust_ddp_config(trainer_cfg): 147 | trainer_cfg = dict(trainer_cfg) 148 | strategy = trainer_cfg.get('strategy', None) 149 | if trainer_cfg['gpus'] > 1 and strategy is None: 150 | strategy = 'ddp' # Select ddp by default 151 | if strategy == 'ddp': 152 | trainer_cfg['strategy'] = DDPPlugin(find_unused_parameters=False, gradient_as_bucket_view=True) 153 | return trainer_cfg 154 | 155 | 156 | def tune_fit_test(trainer_cfg, model, datamodule): 157 | callbacks = apply_common_train_util_args(trainer_cfg) 158 | 159 | trainer_cfg = Namespace(**_adjust_ddp_config(trainer_cfg)) 160 | ddp_mode = isinstance(getattr(trainer_cfg, 'strategy', None), DDPPlugin) and trainer_cfg.gpus > 1 161 | if (ddp_mode 162 | and not trainer_cfg.test_only 163 | and trainer_cfg.checkpoint_path is None 164 | and trainer_cfg.auto_lr_find): 165 | # Do tuning with a "fake" trainer on single GPU 166 | _ddp_save_tune(trainer_cfg, model, datamodule) 167 | 168 | trainer = Trainer.from_argparse_args(trainer_cfg, callbacks=[ 169 | LearningRateMonitor(logging_interval='step'), 170 | *callbacks 171 | ]) 172 | 173 | if not trainer_cfg.test_only: 174 | if trainer_cfg.checkpoint_path is None and not ddp_mode: 175 | # Do tune with the trainer directly 176 | trainer.tune(model, datamodule, lr_find_kwargs=dict(trainer_cfg.lr_find_kwargs)) 177 | 178 | trainer.fit(model, datamodule, ckpt_path=trainer_cfg.checkpoint_path) 179 | 180 | trainer.test( 181 | model, datamodule, 182 | ckpt_path= 183 | trainer_cfg.checkpoint_path if trainer_cfg.test_only 184 | else None # If fit was called before, we should _not_ restore initial train checkpoint 185 | ) 186 | 187 | if not trainer_cfg.test_only and not trainer.fast_dev_run: 188 | # test call above runs test on last after training, 189 | # also want it on best checkpoint by default 190 | # before, flush stdout and err to avoid strange order of outputs 191 | print('', file=sys.stdout, flush=True) 192 | print('', file=sys.stderr, flush=True) 193 | trainer.test(model, datamodule, ckpt_path='best') 194 | 195 | 196 | def freeze_params(model: torch.nn.Module, freeze_spec: List[str]): 197 | """ 198 | Freeze parameters that begin with any of the freeze_spec items or match any of them according to fnmatchcase. 199 | 200 | Parameters 201 | ---------- 202 | model the model 203 | freeze_spec specifies which parameters to freeze (e.g. 
'generator' or 'generator.h.*.weight') 204 | """ 205 | for name, p in model.named_parameters(): 206 | freeze = any( 207 | name.startswith(pattern) or fnmatchcase(name, pattern) 208 | for pattern in freeze_spec 209 | ) 210 | if freeze: 211 | p.requires_grad = False 212 | -------------------------------------------------------------------------------- /model/moment_loc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class SeqMomentLocalizationLossModule(nn.Module): 6 | 7 | def __init__(self, 8 | seq_hidden_dim: int, 9 | question_hidden_dim: int) -> None: 10 | super().__init__() 11 | total = seq_hidden_dim + question_hidden_dim 12 | self.projection_layer = nn.Sequential( 13 | nn.Linear(total, total // 2, bias=True), 14 | nn.LeakyReLU(), 15 | nn.Linear(total // 2, 1, bias=False), 16 | nn.Tanh() 17 | ) 18 | self.loss = nn.MSELoss(reduction='none') 19 | 20 | def forward(self, output_sequence, sequence_mask, 21 | moment_localization_labels, question_encoding) -> torch.Tensor: 22 | """ 23 | Calculate moment localization loss based on output sequence. 24 | 25 | :param output_sequence: Tensor of shape Batch x Sequence x Hidden 26 | :param sequence_mask: Tensor of shape Batch x Sequence 27 | :param moment_localization_labels: Tensor of shape Batch x Sequence, with entries between 0.0 and 1.0 28 | (smoothed label saying if each frame is part of ground truth moment). Might be None. 29 | :param question_encoding: Tensor of shape Batch x Hidden, Mean-pooled representation of the question. 30 | :return: scalar Tensor containing moment localization loss 31 | """ 32 | bsz, seq = output_sequence.shape[:2] 33 | question = question_encoding[:, None, :].expand(bsz, seq, -1) 34 | seq_with_question = torch.cat([output_sequence, question], 35 | dim=-1) # Batch x Sequence x 2*Hidden 36 | output_scores = self.projection_layer(seq_with_question) # Batch x Sequence x 1 37 | loss_per_item = self.loss(output_scores.squeeze(), moment_localization_labels) 38 | loss_per_item = loss_per_item * sequence_mask 39 | return loss_per_item.mean() 40 | 41 | 42 | class TransformerMomentLocalizationLossModule(nn.Module): 43 | 44 | def __init__(self, softmax_temperature, att_loss_type, hinge_margin, lse_alpha): 45 | # softmax_temperature: softmax softmax_temperature when rescaling video-only attention scores 46 | # sum_attentions: whether to use LSE loss on sum of attention weights (target vs non-target interval) 47 | super().__init__() 48 | self.att_loss_type = att_loss_type 49 | self.softmax_temperature = softmax_temperature 50 | self.hinge_margin = hinge_margin 51 | self.lse_alpha = lse_alpha 52 | 53 | def forward(self, question_tokens, transformer_output, 54 | moment_localization_labels, video_mask): 55 | """ 56 | 57 | :param question_tokens: Batch x QuestionSequence 58 | :param transformer_output: transformer.ModelOutputWithCrossAttentions. Input to cross attention is assumed as 59 | concatenated [video, question], so that transformer_output.cross_attentions has shape 60 | Batch x NumHeads x DecoderSequence x (VideoSeq + QuestionSeq) 61 | :param moment_localization_labels: Batch x VideoSequence. 1 for target frames, 0 elsewhere 62 | :param video_mask: Batch x VideoSequence. 
0 for masked indices, 1 for valid 63 | :return: loss 64 | """ 65 | # Tuple of tensors (one for each layer) of shape (batch, num_heads, dec_seq_len, enc_seq_len) 66 | cross_attentions = transformer_output.cross_attentions 67 | # (batch, num_layers, num_heads, dec_seq_len, enc_seq_len) 68 | cross_attentions = torch.stack(cross_attentions, dim=1) 69 | question_len = question_tokens.shape[1] 70 | 71 | # Need to take care of masked indices before re-applying softmax 72 | cross_attn_on_video = cross_attentions[:, :, :, :, :-question_len] / self.softmax_temperature 73 | mask_value = torch.scalar_tensor(-1e6, dtype=cross_attentions.dtype, device=cross_attentions.device) 74 | cross_attn_on_video = torch.where(video_mask[:, None, None, None, :], cross_attn_on_video, mask_value) 75 | cross_attn_on_video = cross_attn_on_video.softmax(dim=-1) 76 | 77 | # loss per batch 78 | loss = self.calc_attention_loss(cross_attn_on_video, moment_localization_labels, video_mask) 79 | return loss.mean() 80 | 81 | def calc_attention_loss(self, cross_attn_on_video, moment_localization_labels, video_mask): 82 | raise NotImplementedError 83 | 84 | def ranking_loss(self, pos_scores, neg_scores): 85 | if self.att_loss_type == "hinge": 86 | # max(0, m + S_pos - S_neg) 87 | loss = torch.clamp(self.hinge_margin + neg_scores - pos_scores, min=0) 88 | elif self.att_loss_type == "lse": 89 | # log[1 + exp(scale * (S_pos - S_neg))] 90 | loss = torch.log1p(torch.exp(self.lse_alpha * (neg_scores - pos_scores))) 91 | else: 92 | raise NotImplementedError("Only support hinge and lse") 93 | return loss 94 | 95 | 96 | class SummedAttentionTransformerMomentLocLoss(TransformerMomentLocalizationLossModule): 97 | 98 | def __init__(self, softmax_temperature=0.1, att_loss_type='lse', hinge_margin=0.4, lse_alpha=20): 99 | super().__init__(softmax_temperature, att_loss_type, hinge_margin, lse_alpha) 100 | 101 | def calc_attention_loss(self, cross_attn_on_video, moment_localization_labels, video_mask): 102 | bsz = len(video_mask) 103 | pos_mask = moment_localization_labels == 1 104 | neg_mask = (moment_localization_labels == 0) * video_mask 105 | mean_attn = cross_attn_on_video.mean(dim=(1, 2, 3)) # mean over heads, layers and decoder positions 106 | pos_scores = torch.stack([ 107 | mean_attn[b, pos_mask[b]].sum() 108 | for b in range(bsz) 109 | ]) 110 | neg_scores = torch.stack([ 111 | mean_attn[b, neg_mask[b]].sum() 112 | for b in range(bsz) 113 | ]) 114 | return self.ranking_loss(pos_scores, neg_scores) 115 | 116 | 117 | # This is copied & modified from https://github.com/jayleicn/TVQAplus 118 | class SamplingAttentionTransformerMomentLocLoss(TransformerMomentLocalizationLossModule): 119 | 120 | def __init__(self, num_negatives=2, use_hard_negatives=False, drop_topk=0, 121 | negative_pool_size=0, num_hard=2, 122 | softmax_temperature=0.1, att_loss_type='lse', hinge_margin=0.4, lse_alpha=20) -> None: 123 | super().__init__(softmax_temperature, att_loss_type, hinge_margin, lse_alpha) 124 | self.num_hard = num_hard 125 | self.negative_pool_size = negative_pool_size 126 | self.drop_topk = drop_topk 127 | self.use_hard_negatives = use_hard_negatives 128 | self.num_negatives = num_negatives 129 | 130 | def calc_attention_loss(self, cross_attn_on_video, att_labels, video_mask): 131 | # take max head, mean over layers and decoder positions 132 | scores = cross_attn_on_video.max(dim=2).values.mean(dim=(1, 2)) 133 | 134 | # att_labels : Batch x VideoSequence. 0 for non-target, 1 for target 135 | # scores : Batch x VideoSequence. 
Between 0 and 1 as given by softmax 136 | # rescale to [-1, 1] 137 | scores = scores * 2 - 1 138 | 139 | pos_container = [] # contains tuples of 2 elements, which are (batch_i, img_i) 140 | neg_container = [] 141 | bsz = len(att_labels) 142 | for b in range(bsz): 143 | pos_indices = att_labels[b].nonzero() # N_pos x 1 144 | neg_indices = ((1 - att_labels[b]) * video_mask[b]).nonzero() # N_neg x 1 145 | 146 | sampled_pos_indices, sampled_neg_indices = self._sample_negatives(scores[b], pos_indices, neg_indices) 147 | 148 | base_indices = torch.full((sampled_pos_indices.shape[0], 1), b, dtype=torch.long, device=pos_indices.device) 149 | pos_container.append(torch.cat([base_indices, sampled_pos_indices], dim=1)) 150 | neg_container.append(torch.cat([base_indices, sampled_neg_indices], dim=1)) 151 | 152 | pos_container = torch.cat(pos_container, dim=0) 153 | neg_container = torch.cat(neg_container, dim=0) 154 | 155 | pos_scores = scores[pos_container[:, 0], pos_container[:, 1]] 156 | neg_scores = scores[neg_container[:, 0], neg_container[:, 1]] 157 | 158 | att_loss = self.ranking_loss(pos_scores, neg_scores).mean(dim=-1) 159 | return att_loss 160 | 161 | def _sample_negatives(self, pred_score, pos_indices, neg_indices): 162 | """ Sample negatives from a set of indices. Several sampling strategies are supported: 163 | 1, random; 2, hard negatives; 3, drop_topk hard negatives; 4, mix easy and hard negatives 164 | 5, sampling within a pool of hard negatives; 6, sample across images of the same video. 165 | Args: 166 | pred_score: (num_img) 167 | pos_indices: (N_pos, 1) 168 | neg_indices: (N_neg, 1) 169 | Returns: 170 | 171 | """ 172 | num_unique_pos = len(pos_indices) 173 | sampled_pos_indices = torch.cat([pos_indices] * self.num_negatives, dim=0) 174 | if self.use_hard_negatives: 175 | # print("using use_hard_negatives") 176 | neg_scores = pred_score[neg_indices[:, 0]] 177 | max_indices = torch.sort(neg_scores, descending=True)[1].tolist() 178 | if self.negative_pool_size > self.num_negatives: # sample from a pool of hard negatives 179 | hard_pool = max_indices[self.drop_topk:self.drop_topk + self.negative_pool_size] 180 | hard_pool_indices = neg_indices[hard_pool] 181 | num_hard_negs = self.num_negatives 182 | sampled_easy_neg_indices = [] 183 | if self.num_hard < self.num_negatives: 184 | easy_pool = max_indices[self.drop_topk + self.negative_pool_size:] 185 | easy_pool_indices = neg_indices[easy_pool] 186 | num_hard_negs = self.num_hard 187 | num_easy_negs = self.num_negatives - num_hard_negs 188 | sampled_easy_neg_indices = easy_pool_indices[ 189 | torch.randint(low=0, high=len(easy_pool_indices), 190 | size=(num_easy_negs * num_unique_pos,), dtype=torch.long) 191 | ] 192 | sampled_hard_neg_indices = hard_pool_indices[ 193 | torch.randint(low=0, high=len(hard_pool_indices), 194 | size=(num_hard_negs * num_unique_pos,), dtype=torch.long) 195 | ] 196 | 197 | if len(sampled_easy_neg_indices) != 0: 198 | sampled_neg_indices = torch.cat([sampled_hard_neg_indices, sampled_easy_neg_indices], dim=0) 199 | else: 200 | sampled_neg_indices = sampled_hard_neg_indices 201 | 202 | else: # directly take the top negatives 203 | sampled_neg_indices = neg_indices[max_indices[self.drop_topk:self.drop_topk + len(sampled_pos_indices)]] 204 | else: 205 | sampled_neg_indices = neg_indices[ 206 | torch.randint(low=0, high=len(neg_indices), size=(len(sampled_pos_indices),), dtype=torch.long) 207 | ] 208 | return sampled_pos_indices, sampled_neg_indices 209 | 
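A minimal usage sketch of SummedAttentionTransformerMomentLocLoss, assuming synthetic shapes and a types.SimpleNamespace stand-in for the Hugging Face output object that normally carries cross_attentions (all sizes below are illustrative):

import torch
from types import SimpleNamespace
from model.moment_loc import SummedAttentionTransformerMomentLocLoss

bsz, num_layers, num_heads, dec_len, video_len, question_len = 2, 3, 4, 5, 10, 7
# One tensor per decoder layer: Batch x NumHeads x DecoderSeq x (VideoSeq + QuestionSeq)
cross_attentions = tuple(
    torch.rand(bsz, num_heads, dec_len, video_len + question_len).softmax(dim=-1)
    for _ in range(num_layers)
)
transformer_output = SimpleNamespace(cross_attentions=cross_attentions)  # stand-in for the real ModelOutput

question_tokens = torch.zeros(bsz, question_len, dtype=torch.long)
video_mask = torch.ones(bsz, video_len, dtype=torch.bool)
moment_localization_labels = torch.zeros(bsz, video_len)
moment_localization_labels[:, 3:6] = 1.0  # frames inside the ground-truth moment

loss_fn = SummedAttentionTransformerMomentLocLoss(softmax_temperature=0.1, att_loss_type='lse', lse_alpha=20)
loss = loss_fn(question_tokens, transformer_output, moment_localization_labels, video_mask)  # scalar loss

SamplingAttentionTransformerMomentLocLoss accepts the same call; the two classes differ only in how calc_attention_loss turns the rescaled attention map into positive and negative frame scores.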
-------------------------------------------------------------------------------- /model/external/stm.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/thaihungle/SAM/blob/master/baselines/sam/stm_basic.py 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def op_att(q, k, v): 10 | qq = q.unsqueeze(2).repeat(1, 1, k.shape[1], 1) 11 | kk = k.unsqueeze(1).repeat(1, q.shape[1], 1, 1) 12 | output = torch.matmul(F.tanh(qq * kk).unsqueeze(4), v.unsqueeze(1).repeat(1, q.shape[1], 1, 1).unsqueeze( 13 | 3)) # BxNXNxd_kq BxNxNxd_v --> BxNXNxd_kqxd_v 14 | # print(output.shape) 15 | output = torch.sum(output, dim=2) # BxNxd_kqxd_v 16 | # print(output.shape) 17 | return output 18 | 19 | 20 | def sdp_att(q, k, v): 21 | dot_product = torch.matmul(q, k.permute(0, 2, 1)) 22 | weights = F.softmax(dot_product, dim=-1) 23 | 24 | # output is [B, H, N, V] 25 | output = torch.matmul(weights, v) 26 | return output 27 | 28 | 29 | class MLP(nn.Module): 30 | def __init__(self, in_dim=28 * 28, out_dim=10, hid_dim=-1, layers=1): 31 | super(MLP, self).__init__() 32 | self.layers = layers 33 | if hid_dim <= 0: 34 | self.layers = -1 35 | if self.layers < 0: 36 | hid_dim = out_dim 37 | self.fc1 = nn.Linear(in_dim, hid_dim) 38 | # linear layer (n_hidden -> hidden_2) 39 | if self.layers > 0: 40 | self.fc2h = nn.ModuleList([nn.Linear(hid_dim, hid_dim)] * self.layers) 41 | # linear layer (n_hidden -> 10) 42 | if self.layers >= 0: 43 | self.fc3 = nn.Linear(hid_dim, out_dim) 44 | 45 | def forward(self, x): 46 | o = self.fc1(x) 47 | if self.layers > 0: 48 | for l in range(self.layers): 49 | o = self.fc2h[l](o) 50 | if self.layers >= 0: 51 | o = self.fc3(o) 52 | return o 53 | 54 | 55 | class STM(nn.Module): 56 | def __init__(self, input_size, output_size, step=1, num_slot=8, 57 | mlp_size=128, slot_size=96, rel_size=96, 58 | out_att_size=64, rd=True, 59 | init_alphas=[None, None, None], 60 | learn_init_mem=True, mlp_hid=-1): 61 | super(STM, self).__init__() 62 | self.mlp_size = mlp_size 63 | self.slot_size = slot_size 64 | self.rel_size = rel_size 65 | self.rnn_hid = slot_size 66 | self.num_slot = num_slot 67 | self.step = step 68 | self.rd = rd 69 | self.learn_init_mem = learn_init_mem 70 | 71 | self.out_att_size = out_att_size 72 | 73 | self.qkv_projector = nn.ModuleList([nn.Linear(slot_size, num_slot * 3)] * step) 74 | self.qkv_layernorm = nn.ModuleList([nn.LayerNorm([slot_size, num_slot * 3])] * step) 75 | 76 | if init_alphas[0] is None: 77 | self.alpha1 = [nn.Parameter(torch.zeros(1))] * step 78 | for ia, a in enumerate(self.alpha1): 79 | setattr(self, 'alpha1' + str(ia), self.alpha1[ia]) 80 | else: 81 | self.alpha1 = [init_alphas[0]] * step 82 | 83 | if init_alphas[1] is None: 84 | self.alpha2 = [nn.Parameter(torch.zeros(1))] * step 85 | for ia, a in enumerate(self.alpha2): 86 | setattr(self, 'alpha2' + str(ia), self.alpha2[ia]) 87 | else: 88 | self.alpha2 = [init_alphas[1]] * step 89 | 90 | if init_alphas[2] is None: 91 | self.alpha3 = [nn.Parameter(torch.zeros(1))] * step 92 | for ia, a in enumerate(self.alpha3): 93 | setattr(self, 'alpha3' + str(ia), self.alpha3[ia]) 94 | else: 95 | self.alpha3 = [init_alphas[2]] * step 96 | 97 | self.input_projector = MLP(input_size, slot_size, hid_dim=mlp_hid) 98 | self.input_projector2 = MLP(input_size, slot_size, hid_dim=mlp_hid) 99 | self.input_projector3 = MLP(input_size, num_slot, hid_dim=mlp_hid) 100 | 101 | self.input_gate_projector = 
nn.Linear(self.slot_size, self.slot_size * 2) 102 | self.memory_gate_projector = nn.Linear(self.slot_size, self.slot_size * 2) 103 | # trainable scalar gate bias tensors 104 | self.forget_bias = nn.Parameter(torch.tensor(1., dtype=torch.float32)) 105 | self.input_bias = nn.Parameter(torch.tensor(0., dtype=torch.float32)) 106 | 107 | self.rel_projector = nn.Linear(slot_size * slot_size, rel_size) 108 | self.rel_projector2 = nn.Linear(num_slot * slot_size, slot_size) 109 | self.rel_projector3 = nn.Linear(num_slot * rel_size, out_att_size) 110 | 111 | self.mlp = nn.Sequential( 112 | nn.Linear(out_att_size, self.mlp_size), 113 | nn.ReLU(), 114 | nn.Linear(self.mlp_size, self.mlp_size), 115 | nn.ReLU(), 116 | ) 117 | 118 | self.out = nn.Linear(self.mlp_size, output_size) 119 | 120 | if self.learn_init_mem: 121 | if torch.cuda.is_available(): 122 | self.register_parameter('item_memory_state_bias', 123 | torch.nn.Parameter(torch.Tensor(self.slot_size, self.slot_size).cuda())) 124 | self.register_parameter('rel_memory_state_bias', torch.nn.Parameter( 125 | torch.Tensor(self.num_slot, self.slot_size, self.slot_size).cuda())) 126 | 127 | else: 128 | self.register_parameter('item_memory_state_bias', 129 | torch.nn.Parameter(torch.Tensor(self.slot_size, self.slot_size))) 130 | self.register_parameter('rel_memory_state_bias', 131 | torch.nn.Parameter(torch.Tensor(self.num_slot, self.slot_size, self.slot_size))) 132 | 133 | stdev = 1 / (np.sqrt(self.slot_size + self.slot_size)) 134 | nn.init.uniform_(self.item_memory_state_bias, -stdev, stdev) 135 | stdev = 1 / (np.sqrt(self.slot_size + self.slot_size + self.num_slot)) 136 | nn.init.uniform_(self.rel_memory_state_bias, -stdev, stdev) 137 | 138 | def create_new_state(self, batch_size): 139 | if self.learn_init_mem: 140 | read_heads = torch.zeros(batch_size, self.out_att_size) 141 | item_memory_state = self.item_memory_state_bias.clone().repeat(batch_size, 1, 1) 142 | rel_memory_state = self.rel_memory_state_bias.clone().repeat(batch_size, 1, 1, 1) 143 | if torch.cuda.is_available(): 144 | read_heads = read_heads.cuda() 145 | else: 146 | item_memory_state = torch.stack([torch.zeros(self.slot_size, self.slot_size) for _ in range(batch_size)]) 147 | read_heads = torch.zeros(batch_size, self.out_att_size) 148 | rel_memory_state = torch.stack( 149 | [torch.zeros(self.num_slot, self.slot_size, self.slot_size) for _ in range(batch_size)]) 150 | if torch.cuda.is_available(): 151 | item_memory_state = item_memory_state.cuda() 152 | read_heads = read_heads.cuda() 153 | rel_memory_state = rel_memory_state.cuda() 154 | 155 | return read_heads, item_memory_state, rel_memory_state 156 | 157 | def compute_gates(self, inputs, memory): 158 | 159 | memory = torch.tanh(memory) 160 | if len(inputs.shape) == 3: 161 | if inputs.shape[1] > 1: 162 | raise ValueError( 163 | "input seq length is larger than 1. 
create_gate function is meant to be called for each step, with input seq length of 1") 164 | inputs = inputs.view(inputs.shape[0], -1) 165 | 166 | gate_inputs = self.input_gate_projector(inputs) 167 | gate_inputs = gate_inputs.unsqueeze(dim=1) 168 | gate_memory = self.memory_gate_projector(memory) 169 | else: 170 | raise ValueError("input shape of create_gate function is 2, expects 3") 171 | 172 | gates = gate_memory + gate_inputs 173 | gates = torch.split(gates, split_size_or_sections=int(gates.shape[2] / 2), dim=2) 174 | input_gate, forget_gate = gates 175 | assert input_gate.shape[2] == forget_gate.shape[2] 176 | 177 | input_gate = torch.sigmoid(input_gate + self.input_bias) 178 | forget_gate = torch.sigmoid(forget_gate + self.forget_bias) 179 | 180 | return input_gate, forget_gate 181 | 182 | def compute(self, input_step, prev_state): 183 | 184 | hid = prev_state[0] 185 | item_memory_state = prev_state[1] 186 | rel_memory_state = prev_state[2] 187 | 188 | # transform input 189 | controller_outp = self.input_projector(input_step) 190 | controller_outp2 = self.input_projector2(input_step) 191 | controller_outp3 = self.input_projector3(input_step) 192 | 193 | # Mi write 194 | X = torch.matmul(controller_outp.unsqueeze(2), controller_outp.unsqueeze(1)) # Bxdxd 195 | input_gate, forget_gate = self.compute_gates(controller_outp.unsqueeze(1), item_memory_state) 196 | 197 | # Mr read 198 | controller_outp3 = F.softmax(controller_outp3, dim=-1) 199 | controller_outp4 = torch.einsum('bn,bd,bndf->bf', controller_outp3, controller_outp2, rel_memory_state) 200 | X2 = torch.einsum('bd,bf->bdf', controller_outp4, controller_outp2) 201 | 202 | if self.rd: 203 | # Mi write gating 204 | R = input_gate * F.tanh(X) 205 | R += forget_gate * item_memory_state 206 | else: 207 | # Mi write 208 | R = item_memory_state + torch.matmul(controller_outp.unsqueeze(2), controller_outp.unsqueeze(1)) # Bxdxd 209 | 210 | for i in range(self.step): 211 | # SAM 212 | qkv = self.qkv_projector[i](R + self.alpha2[i] * X2) 213 | qkv = self.qkv_layernorm[i](qkv) 214 | qkv = qkv.permute(0, 2, 1) # Bx3Nxd 215 | 216 | q, k, v = torch.split(qkv, [self.num_slot] * 3, 1) # BxNxd 217 | 218 | R0 = op_att(q, k, v) # BxNxdxd 219 | 220 | # Mr transfer to Mi 221 | R2 = self.rel_projector2(R0.view(R0.shape[0], -1, R0.shape[3]).permute(0, 2, 1)) 222 | R = R + self.alpha3[i] * R2 223 | 224 | # Mr write 225 | rel_memory_state = self.alpha1[i] * rel_memory_state + R0 226 | 227 | # Mr transfer to output 228 | r_vec = self.rel_projector(rel_memory_state.view(rel_memory_state.shape[0], 229 | rel_memory_state.shape[1], 230 | -1)).view(input_step.shape[0], -1) 231 | out = self.rel_projector3(r_vec) 232 | 233 | # if self.gating_after: 234 | # #Mi write gating 235 | # input_gate, forget_gate = self.compute_gates(controller_outp.unsqueeze(1), R) 236 | # if self.rd: 237 | # R = input_gate * torch.tanh(R) 238 | # R += forget_gate * item_memory_state 239 | 240 | return out, (out, R, rel_memory_state) 241 | 242 | def forward(self, input_step, hidden=None): 243 | 244 | if len(input_step.shape) == 3: 245 | self.init_sequence(input_step.shape[1]) 246 | for i in range(input_step.shape[0]): 247 | logit, self.previous_state = self.compute(input_step[i], self.previous_state) 248 | 249 | else: 250 | if hidden is not None: 251 | logit, hidden = self.compute(input_step, hidden) 252 | else: 253 | logit, self.previous_state = self.compute(input_step, self.previous_state) 254 | mlp = self.mlp(logit) 255 | out = self.out(mlp) 256 | return out, self.previous_state 257 | 
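    # Typical call, mirroring StmEmqaModel: a 3-D input is consumed sequence-first (illustrative sizes):
    #   stm = STM(input_size=256, output_size=128)
    #   out, (read_heads, item_memory, rel_memory) = stm(torch.randn(12, 4, 256))  # NumSegments x Batch x InputSize
    # out: Batch x output_size, item_memory: Batch x slot_size x slot_size,
    # rel_memory: Batch x num_slot x slot_size x slot_size.
    # Note: with learn_init_mem=True the initial memory parameters are created on GPU whenever
    # CUDA is available, so keep the module and its input on the same device.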
258 | def init_sequence(self, batch_size): 259 | """Initializing the state.""" 260 | self.previous_state = self.create_new_state(batch_size) 261 | 262 | def calculate_num_params(self): 263 | """Returns the total number of parameters.""" 264 | num_params = 0 265 | for p in self.parameters(): 266 | num_params += p.data.view(-1).size(0) 267 | return num_params 268 | -------------------------------------------------------------------------------- /model/rehearsal.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Optional, Any 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from transformers import AutoModel, PreTrainedModel 9 | from transformers.models.roberta import RobertaModel, RobertaConfig 10 | 11 | from model.base import MemoryAugmentedTransformerEmqaModel 12 | 13 | 14 | class RehearsalTrainingModule(nn.Module): 15 | 16 | def __init__(self, 17 | input_size: int, 18 | mem_hidden_size: int, 19 | num_samples: int, 20 | sample_length: int, 21 | positive_mask_ratio: float = 0.5, 22 | negative_replacement_ratio: float = 0.5, 23 | invert_teacher_sequence: bool = False, 24 | pretrained_decoder: Optional[str] = None, 25 | decoder_params: Optional[Dict[str, Any]] = None, 26 | sampling_teacher_weights_file: Optional[str] = None 27 | ) -> None: 28 | super().__init__() 29 | if pretrained_decoder is not None: 30 | self.decoder = RobertaModel.from_pretrained(pretrained_decoder, config=RobertaConfig.from_pretrained( 31 | pretrained_decoder, is_decoder=True, add_cross_attention=True 32 | )) 33 | else: 34 | model_cfg = RobertaConfig(**decoder_params, is_decoder=True, add_cross_attention=True) 35 | self.decoder = RobertaModel(model_cfg) 36 | hidden_size = self.decoder.config.hidden_size 37 | self.input_dimension_adjustment_layer = nn.Linear(input_size, hidden_size, bias=False) 38 | self.em_dimension_adjustment_layer = nn.Linear(mem_hidden_size, hidden_size, bias=False) 39 | self.num_samples = num_samples 40 | self.sample_length = sample_length 41 | self.positive_mask_ratio = positive_mask_ratio 42 | self.negative_replacement_ratio = negative_replacement_ratio 43 | self.invert_teacher_sequence = invert_teacher_sequence 44 | 45 | empty = torch.empty(2, 1, hidden_size) 46 | nn.init.kaiming_uniform_(empty) 47 | self.class_token_emb = nn.Parameter(empty[0]) 48 | self.mask_token_emb = nn.Parameter(empty[1]) 49 | self.pos_neg_projection = nn.Sequential( 50 | nn.Linear(hidden_size, 1, bias=False), 51 | nn.Sigmoid() 52 | ) 53 | 54 | if sampling_teacher_weights_file: 55 | self.teacher_sampling_weights = torch.load(sampling_teacher_weights_file) 56 | else: 57 | self.teacher_sampling_weights = None 58 | 59 | def forward(self, memory, mem_mask, original_input, input_mask, batch_sample_ids): 60 | bsz, l_x = original_input.shape[0:2] 61 | h = self.decoder.config.hidden_size 62 | sample_length = min(self.sample_length, l_x) 63 | num_tokens_to_mask = int(sample_length * self.positive_mask_ratio) 64 | num_tokens_to_replace = int((sample_length - num_tokens_to_mask) * self.negative_replacement_ratio) 65 | dims = bsz, l_x, h, sample_length, num_tokens_to_mask, num_tokens_to_replace 66 | assert bsz > 1 # Negative sampling from other batches requires at least bsz=2 67 | 68 | samples, negative_samples, masked_items_map, padding_mask, start_indices = self._construct_samples( 69 | original_input, input_mask, dims, batch_sample_ids) 70 | hiddens_neg, hiddens_pos = 
self._forward_transformer(memory, mem_mask, samples, negative_samples, 71 | padding_mask, dims) 72 | 73 | recollection_loss = self._calc_recollection_loss(original_input, hiddens_pos, 74 | masked_items_map, start_indices, dims) 75 | familiarity_loss = self._calc_familiarity_loss(hiddens_neg, hiddens_pos, dims) 76 | 77 | return recollection_loss, familiarity_loss 78 | 79 | def _construct_samples(self, original_input, input_padding_mask, dims, batch_sample_ids): 80 | original_input = self.input_dimension_adjustment_layer(original_input) 81 | # original_input: bsz x l_x x h 82 | # input_padding_mask: bsz x l_x 83 | bsz, l_x, h, sample_length, num_tokens_to_mask, num_tokens_to_replace = dims 84 | start_indices = self._choose_start_indices(dims, batch_sample_ids) 85 | samples = torch.stack([ 86 | torch.stack([ 87 | torch.cat((self.class_token_emb, original_input[i, start:start + sample_length]), dim=0) 88 | for start in start_indices[i] 89 | ]) 90 | for i in range(bsz) 91 | ]) # bsz x num_samples x sample_length x hidden 92 | negative_samples = samples.clone() 93 | # masked_items_map is 1 for masked input items (!) i.e. 1 for items that are replaced with mask token. 94 | masked_items_map = torch.zeros(bsz, self.num_samples, 1 + sample_length, # + 1 for CLS token! 95 | device=samples.device, dtype=torch.bool) 96 | padding_mask = torch.ones_like(masked_items_map) # one for items that should be used, 0 for padding 97 | all_indices = range(1, 1 + sample_length) 98 | for i in range(bsz): 99 | for s in range(self.num_samples): 100 | start = start_indices[i, s].item() 101 | padding_mask[i, s, 1:] = input_padding_mask[i, start:start + sample_length] 102 | padding_indices = (~padding_mask).nonzero() 103 | mask_indices = np.random.choice(all_indices, num_tokens_to_mask, replace=False) 104 | # Set masked_items_map only if these items are not padded. padding_mask is zero for non-padded entries 105 | masked_items_map[i, s, mask_indices] = padding_mask[i, s, mask_indices] 106 | unmasked_indices = list(set(all_indices) - set(mask_indices) - set(padding_indices)) 107 | replacement_indices = np.random.choice(unmasked_indices, num_tokens_to_replace, replace=False) 108 | neg_sample_original_indices = replacement_indices + start - 1 # -1 because of CLS token 109 | negative_samples[i, s, replacement_indices] = original_input[(i + 1) % bsz, neg_sample_original_indices] 110 | samples[masked_items_map] = self.mask_token_emb 111 | negative_samples[masked_items_map] = self.mask_token_emb 112 | return samples, negative_samples, masked_items_map, padding_mask, start_indices 113 | 114 | def _choose_start_indices(self, dims, batch_sample_ids): 115 | bsz, l_x, _, sample_length, _, _ = dims 116 | num_fragments = l_x - sample_length + 1 117 | if self.teacher_sampling_weights is None: 118 | # uniform random sampling 119 | start_indices = np.stack([np.random.choice(max(1, num_fragments), self.num_samples, replace=False) 120 | for _ in range(bsz)]) 121 | else: 122 | # biased random sampling guided by unconstrained teacher model (see section "What to rehearse?" 
in RM paper) 123 | teacher_attn_weights = [ 124 | self.teacher_sampling_weights[sample_id] 125 | for sample_id in batch_sample_ids 126 | ] 127 | sample_distribution = torch.nn.utils.rnn.pad_sequence(teacher_attn_weights, batch_first=True) 128 | assert sample_distribution.shape[1] == num_fragments, f'{sample_distribution.shape}, {num_fragments}' 129 | if self.invert_teacher_sequence: 130 | sample_distribution = sample_distribution.flip(dims=(1,)) 131 | cum_distribution = sample_distribution.cumsum(dim=-1) 132 | rand_sources = torch.rand(bsz, self.num_samples) 133 | # noinspection PyTypeChecker 134 | start_indices = torch.sum(cum_distribution[:, None, :] < rand_sources[:, :, None], dim=-1) 135 | return start_indices 136 | 137 | def _forward_transformer(self, memory, mem_mask, samples, negative_samples, padding_mask, dims): 138 | bsz, _, h, sample_length, _, _ = dims 139 | sample_length = 1 + sample_length # CLS token at the beginning 140 | 141 | model_in = torch.cat((samples.view(-1, sample_length, h), negative_samples.view(-1, sample_length, h)), dim=0) 142 | model_padding_mask = padding_mask.view(-1, sample_length).repeat(2, 1) 143 | extended_attn_mask = model_padding_mask[:, None, :] # Extend here already so that no causal mask is added. 144 | memory = self.em_dimension_adjustment_layer(memory) 145 | # bsz must be repeated for each sample and again twice (positive vs. negative sample) 146 | memory = memory.repeat(2 * self.num_samples, 1, 1) 147 | mem_mask = mem_mask.repeat(2 * self.num_samples, 1) 148 | output = self.decoder(inputs_embeds=model_in, attention_mask=extended_attn_mask, 149 | encoder_hidden_states=memory, encoder_attention_mask=mem_mask) 150 | hiddens_pos, hiddens_neg = output.last_hidden_state.split(model_in.shape[0] // 2) 151 | hiddens_pos = hiddens_pos.reshape(bsz, self.num_samples, sample_length, h) 152 | hiddens_neg = hiddens_neg.reshape(bsz, self.num_samples, sample_length, h) 153 | return hiddens_neg, hiddens_pos 154 | 155 | def _calc_recollection_loss(self, original_input, hiddens_pos, masked_items_map, start_indices, dims): 156 | bsz, _, _, _, num_tokens_to_mask, _ = dims 157 | 158 | recollection_loss = torch.scalar_tensor(0, device=hiddens_pos.device, dtype=hiddens_pos.dtype) 159 | for i in range(bsz): 160 | for s in range(self.num_samples): 161 | mask_indices = masked_items_map[i, s].nonzero().squeeze(dim=-1) 162 | # -1 because masked_items respects CLS tokens at position 0, which is not part of original_input 163 | original_masked_item_indices = mask_indices + start_indices[i, s].item() - 1 164 | reconstructed_items = F.linear(hiddens_pos[i, s, mask_indices], 165 | self.input_dimension_adjustment_layer.weight.T) 166 | # other neg. sampling strategy? 167 | # (here, the contrastive loss is calculated between all the masked items from the original input. 168 | # This might be bad, since neighboring video clips might have actually similar features? 169 | # Other option would be to sample randomly from another batch) 170 | # Also introduce another hyperparameter: number of sampled items. 
RM uses 30 for ActivityNetQA 171 | target_items = original_input[i, original_masked_item_indices] 172 | assert reconstructed_items.shape == target_items.shape, \ 173 | f'{reconstructed_items.shape} != {target_items.shape}' 174 | products = torch.inner(reconstructed_items, target_items) # x,y = sum(reconstructed[x] * original[y]) 175 | recollection_loss += torch.log_softmax(products, dim=1).trace() 176 | recollection_loss = - recollection_loss / (num_tokens_to_mask * self.num_samples * bsz) 177 | return recollection_loss 178 | 179 | def _calc_familiarity_loss(self, hiddens_neg, hiddens_pos, dims): 180 | bsz = dims[0] 181 | pos_cls_output = hiddens_pos[..., 0, :] 182 | neg_cls_output = hiddens_neg[..., 0, :] 183 | pos_scores = self.pos_neg_projection(pos_cls_output) 184 | neg_scores = self.pos_neg_projection(neg_cls_output) 185 | # This has wrong sign in the RM Paper (Eq. 4), leading to NaN 186 | # neg_score has goal 0 => 1-neg_score has goal 1 => want to maximize its log 187 | # (thus needs negative sign due to overall loss minimization) 188 | # since log(x) in (-inf, 0) for x in (0, 1) 189 | # analogous for pos_score, which has goal 1 directly 190 | familiarity_loss = pos_scores.log() + (1 - neg_scores).log() 191 | familiarity_loss = -torch.sum(familiarity_loss) / (self.num_samples * bsz) 192 | return familiarity_loss 193 | 194 | 195 | class RehearsalMemoryMachine(nn.Module): 196 | 197 | def __init__(self, 198 | pretrained_encoder: str, 199 | input_dim: int, 200 | mem_hidden_size: int, 201 | num_memory_slots: int, 202 | segment_length: int, 203 | slot_to_item_num_heads: int = 1, 204 | use_independent_gru_per_mem_slot=False 205 | ) -> None: 206 | super().__init__() 207 | self.segment_length = segment_length 208 | self.num_memory_slots = num_memory_slots 209 | self.mem_hidden_size = mem_hidden_size 210 | self.encoder: PreTrainedModel = AutoModel.from_pretrained(pretrained_encoder) 211 | if hasattr(self.encoder, 'get_encoder'): # In case pretrained_encoder is actually encoder-decoder model 212 | self.encoder = self.encoder.get_encoder() 213 | feature_size = self.encoder.get_input_embeddings().embedding_dim 214 | if input_dim == feature_size: 215 | self.input_transform = nn.Identity() 216 | else: 217 | self.input_transform = nn.Linear(input_dim, feature_size, bias=False) 218 | self.slot_to_item_attn = nn.MultiheadAttention(embed_dim=mem_hidden_size, 219 | kdim=feature_size, vdim=feature_size, 220 | num_heads=slot_to_item_num_heads, 221 | batch_first=True) 222 | num_recurrent_units = num_memory_slots if use_independent_gru_per_mem_slot else 1 223 | self.recurrent_units = nn.ModuleList([ 224 | nn.GRUCell(input_size=mem_hidden_size, 225 | hidden_size=mem_hidden_size, 226 | bias=False) 227 | for _ in range(num_recurrent_units) 228 | ]) 229 | 230 | def forward(self, input_items, input_mask) -> torch.Tensor: 231 | """ 232 | Process the input items, and return memory state at the end. 
233 | :param input_items: input of shape Batch x Sequence x InputHidden 234 | :param input_mask: mask of shape Batch x Sequence, 1 = valid token, 0 = masked token 235 | :return: memory state of shape Batch x NumMemorySlots x MemHidden 236 | """ 237 | bsz, seq, h = input_items.shape 238 | assert input_mask.shape == (bsz, seq) 239 | 240 | x = self.input_transform(input_items) 241 | 242 | num_segments = math.ceil(seq / self.segment_length) 243 | memory = torch.zeros(bsz, self.num_memory_slots, self.mem_hidden_size, 244 | device=x.device, dtype=x.dtype) 245 | 246 | for t in range(num_segments): 247 | current_slice = slice(t * self.segment_length, (t + 1) * self.segment_length) 248 | x_t = x[:, current_slice] 249 | mask_t = input_mask[:, current_slice] 250 | active_batches = mask_t.sum(dim=1) > 0 251 | if not active_batches.any(): 252 | raise RuntimeError('Bad batch - padding only?') 253 | 254 | f_t = self.encoder(inputs_embeds=x_t[active_batches], 255 | attention_mask=mask_t[active_batches]).last_hidden_state 256 | 257 | l_t, attn_weights = self.slot_to_item_attn(query=memory[active_batches], 258 | key=f_t, value=f_t, 259 | key_padding_mask=~mask_t[active_batches]) 260 | # l_t : bsz x num_memory_slots x mem_hidden 261 | if len(self.recurrent_units) == 1: 262 | flattened_l_t = l_t.reshape(-1, self.mem_hidden_size) 263 | flattened_mem = memory[active_batches].reshape(-1, self.mem_hidden_size) 264 | new_mem = self.recurrent_units[0](input=flattened_l_t, hx=flattened_mem) 265 | active_bsz = active_batches.sum() 266 | memory[active_batches] = new_mem.view(active_bsz, self.num_memory_slots, self.mem_hidden_size) 267 | else: 268 | for i in range(self.num_memory_slots): 269 | memory[active_batches, i, :] = self.recurrent_units[i](input=l_t[:, i, :], 270 | hx=memory[active_batches, i, :]) 271 | 272 | return memory 273 | 274 | 275 | class RehearsalMemoryEmqaModel(MemoryAugmentedTransformerEmqaModel): 276 | 277 | def __init__(self, 278 | rehearsal_machine: RehearsalMemoryMachine, 279 | rehearsal_trainer: RehearsalTrainingModule, 280 | pretrained_enc_dec: str 281 | ) -> None: 282 | super().__init__(pretrained_enc_dec) 283 | self.rehearsal_machine = rehearsal_machine 284 | self.rehearsal_trainer = rehearsal_trainer 285 | 286 | def forward_memory(self, video_features, video_mask, 287 | moment_localization_labels, question_encoding # unused 288 | ): 289 | return self.rehearsal_machine(video_features, video_mask) 290 | 291 | def calc_additional_loss(self, question_tokens, question_mask, video_features, video_mask, answer_tokens, 292 | answer_mask, batch_sample_ids, context, context_mask, final_memory, mem_mask, 293 | transformer_output, moment_localization_labels): 294 | if self.rehearsal_trainer: 295 | loss_rec, loss_fam = self.rehearsal_trainer(final_memory, mem_mask, 296 | video_features, video_mask, 297 | batch_sample_ids) 298 | else: 299 | loss_rec, loss_fam = torch.zeros(2, dtype=context.dtype, device=context.device) 300 | return { 301 | 'recollection_loss': loss_rec, 302 | 'familiarity_loss': loss_fam 303 | } 304 | -------------------------------------------------------------------------------- /model/external/compressive_transformer.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/lucidrains/compressive-transformer-pytorch 2 | 3 | import math 4 | import sys 5 | from collections import namedtuple 6 | from functools import partial 7 | from inspect import isfunction 8 | from typing import Type, Tuple, List, Union 9 | 10 | import 
torch 11 | import torch.nn.functional as F 12 | from torch import nn, Tensor 13 | 14 | # structs 15 | 16 | Memory: Type[Tuple[Tensor, List[Tensor], Tensor]] = namedtuple('Memory', ['mem', 'compressed_mem', 'lt_mem']) 17 | 18 | 19 | # helper functions 20 | 21 | def to(t): 22 | return {'dtype': t.dtype, 'device': t.device} 23 | 24 | 25 | def cast_tuple(el): 26 | return el if isinstance(el, tuple) else (el,) 27 | 28 | 29 | def default(x, val): 30 | if x is not None: 31 | return x 32 | return val if not isfunction(val) else val() 33 | 34 | 35 | def max_neg_value(tensor): 36 | return -torch.finfo(tensor.dtype).max 37 | 38 | 39 | def reshape_dim(t, dim, split_dims): 40 | shape = list(t.shape) 41 | num_dims = len(shape) 42 | dim = (dim + num_dims) % num_dims 43 | shape[dim:dim + 1] = split_dims 44 | return t.reshape(shape) 45 | 46 | 47 | def split_at_index(dim, index, t): 48 | pre_slices = (slice(None),) * dim 49 | l = (*pre_slices, slice(None, index)) 50 | r = (*pre_slices, slice(index, None)) 51 | return t[l], t[r] 52 | 53 | 54 | def queue_fifo(*args, length, dim=-2): 55 | queue = torch.cat(args, dim=dim) 56 | if length > 0: 57 | return split_at_index(dim, -length, queue) 58 | 59 | device = queue.device 60 | shape = list(queue.shape) 61 | shape[dim] = 0 62 | return queue, torch.empty(shape, device=device) 63 | 64 | 65 | def shift(x): 66 | *_, i, j = x.shape 67 | zero_pad = torch.zeros((*_, i, i), **to(x)) 68 | x = torch.cat([x, zero_pad], -1) 69 | l = i + j - 1 70 | x = x.view(*_, -1) 71 | zero_pad = torch.zeros(*_, -x.size(-1) % l, **to(x)) 72 | shifted = torch.cat([x, zero_pad], -1).view(*_, -1, l) 73 | return shifted[..., :i, i - 1:] 74 | 75 | 76 | def iterate_tensor(t): 77 | length = t.shape[0] 78 | for ind in range(length): 79 | yield t[ind] 80 | 81 | 82 | # full attention for calculating auxiliary reconstruction loss 83 | 84 | def full_attn(q, k, v, dropout_fn=None): 85 | *_, dim = q.shape 86 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * (dim ** -0.5) 87 | attn = dots.softmax(dim=-1) 88 | if dropout_fn is not None: 89 | attn = dropout_fn(attn) 90 | return torch.einsum('bhij,bhjd->bhid', attn, v) 91 | 92 | 93 | # helper classes 94 | 95 | class Residual(nn.Module): 96 | def __init__(self, fn): 97 | super().__init__() 98 | self.fn = fn 99 | 100 | def forward(self, x, **kwargs): 101 | out = self.fn(x, **kwargs) 102 | out = cast_tuple(out) 103 | ret = (out[0] + x), *out[1:] 104 | return ret 105 | 106 | 107 | class GRUGating(nn.Module): 108 | def __init__(self, dim, fn, mogrify=False): 109 | super().__init__() 110 | self.dim = dim 111 | self.fn = fn 112 | self.gru = nn.GRUCell(dim, dim) 113 | if mogrify: 114 | try: 115 | # noinspection PyPackageRequirements 116 | from mogrifier import Mogrifier 117 | self.mogrify = Mogrifier(dim, factorize_k=dim // 4) if mogrify else None 118 | except ImportError: 119 | print('!! mogrify is set, but mogrifier library not available!' 
120 | ' Run "pip install mogrifier" to fix.', file=sys.stderr) 121 | 122 | def forward(self, x, **kwargs): 123 | batch, dim = x.shape[0], self.dim 124 | out = self.fn(x, **kwargs) 125 | (y, *rest) = cast_tuple(out) 126 | 127 | if self.mogrify is not None: 128 | y, x = self.mogrify(y, x) 129 | 130 | gated_output = self.gru( 131 | y.reshape(-1, dim), 132 | x.reshape(-1, dim) 133 | ) 134 | 135 | gated_output = gated_output.reshape(batch, -1, dim) 136 | ret = gated_output, *rest 137 | return ret 138 | 139 | 140 | class PreNorm(nn.Module): 141 | def __init__(self, dim, fn): 142 | super().__init__() 143 | self.norm = nn.LayerNorm(dim) 144 | self.fn = fn 145 | 146 | def forward(self, x, **kwargs): 147 | x = self.norm(x) 148 | return self.fn(x, **kwargs) 149 | 150 | 151 | class ConvCompress(nn.Module): 152 | def __init__(self, dim, ratio=4): 153 | super().__init__() 154 | self.conv = nn.Conv1d(dim, dim, ratio, stride=ratio) 155 | 156 | def forward(self, mem): 157 | mem = mem.transpose(1, 2) 158 | compressed_mem = self.conv(mem) 159 | return compressed_mem.transpose(1, 2) 160 | 161 | 162 | class DetachedConvCompress(nn.Module): 163 | def __init__(self, reference: ConvCompress): 164 | super().__init__() 165 | self.reference = reference 166 | 167 | def forward(self, mem): 168 | weight = self.reference.conv.weight.detach() 169 | bias = self.reference.conv.bias.detach() 170 | 171 | mem = mem.transpose(1, 2) 172 | compressed_mem = F.conv1d(mem, weight, bias, self.reference.conv.stride, 173 | self.reference.conv.padding, self.reference.conv.dilation, 174 | self.reference.conv.groups) 175 | return compressed_mem.transpose(1, 2) 176 | 177 | 178 | # feedforward 179 | 180 | class GELU_(nn.Module): 181 | def forward(self, x): 182 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 183 | 184 | 185 | GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_ 186 | 187 | 188 | class FeedForward(nn.Module): 189 | def __init__(self, dim, ff_dim, dropout=0., activation=None, glu=False): 190 | super().__init__() 191 | activation = default(activation, GELU) 192 | 193 | self.glu = glu 194 | self.w1 = nn.Linear(dim, ff_dim * (2 if glu else 1)) 195 | self.act = activation() 196 | self.dropout = nn.Dropout(dropout) 197 | self.w2 = nn.Linear(ff_dim, dim) 198 | 199 | def forward(self, x, **kwargs): 200 | if not self.glu: 201 | x = self.w1(x) 202 | x = self.act(x) 203 | else: 204 | x, v = self.w1(x).chunk(2, dim=-1) 205 | x = self.act(x) * v 206 | 207 | x = self.dropout(x) 208 | x = self.w2(x) 209 | return x 210 | 211 | 212 | class CompressionStage(nn.Module): 213 | 214 | def __init__(self, dim, cmem_ratio, cmem_len, attn_heads, attn_dim_heads, reconstruction_attn_dropout, 215 | prev_lvl_mem_start_index, prev_lvl_mem_len) -> None: 216 | super().__init__() 217 | self.attn_heads = attn_heads # of the containing SelfAttention object 218 | self.attn_dim_heads = attn_dim_heads 219 | self.mem_len_this_lvl = cmem_len 220 | self.prev_lvl_mem_start_index = prev_lvl_mem_start_index 221 | self.prev_lvl_mem_len = prev_lvl_mem_len 222 | 223 | assert prev_lvl_mem_len % cmem_ratio == 0, \ 224 | f'mem length of previous level ({prev_lvl_mem_len}) must be divisble by compression ratio ({cmem_ratio})' 225 | 226 | self.reconstruction_attn_dropout = nn.Dropout(reconstruction_attn_dropout) 227 | self.compress_mem_fn = ConvCompress(dim, cmem_ratio) 228 | self.compress_mem_fn_without_grad = DetachedConvCompress(self.compress_mem_fn) 229 | 230 | def forward(self, prev_cmem_this_lvl, old_mem_prev_lvl, 
prev_lvl_mem_len, q, k, v, to_kv_weight): 231 | compressed_mem = self.compress_mem_fn_without_grad(old_mem_prev_lvl) 232 | old_cmem, new_cmem = split_at_index(1, -self.mem_len_this_lvl, 233 | torch.cat((prev_cmem_this_lvl, compressed_mem), dim=1)) 234 | aux_loss = torch.zeros(1, requires_grad=True, **to(prev_cmem_this_lvl)) 235 | 236 | if not self.training: 237 | return old_cmem, new_cmem, aux_loss 238 | 239 | # calculate compressed memory auxiliary loss if training 240 | merge_heads = lambda x: reshape_dim(x, -1, (-1, self.attn_dim_heads)).transpose(1, 2) 241 | 242 | compressed_mem = self.compress_mem_fn(old_mem_prev_lvl.detach()) 243 | cmem_k, cmem_v = F.linear(compressed_mem, to_kv_weight.detach()).chunk(2, dim=-1) 244 | cmem_k, cmem_v = map(merge_heads, (cmem_k, cmem_v)) 245 | cmem_k, cmem_v = map(lambda x: x.expand(-1, self.attn_heads, -1, -1), (cmem_k, cmem_v)) 246 | 247 | old_mem_range = slice(- min(prev_lvl_mem_len, self.prev_lvl_mem_len) - self.prev_lvl_mem_start_index, 248 | -self.prev_lvl_mem_start_index) 249 | old_mem_k, old_mem_v = map(lambda x: x[:, :, old_mem_range].clone(), (k, v)) 250 | 251 | q, old_mem_k, old_mem_v = map(torch.detach, (q, old_mem_k, old_mem_v)) 252 | 253 | attn_fn = partial(full_attn, dropout_fn=self.reconstruction_attn_dropout) 254 | 255 | aux_loss = F.mse_loss( 256 | attn_fn(q, old_mem_k, old_mem_v), 257 | attn_fn(q, cmem_k, cmem_v) 258 | ) 259 | 260 | return old_cmem, new_cmem, aux_loss 261 | 262 | 263 | # attention. 264 | 265 | class SelfAttention(nn.Module): 266 | 267 | @staticmethod 268 | def validate_cmem_parameters(seq_len: int, mem_len: int, 269 | cmem_lengths: List[int], cmem_ratios: Union[List[int], int]): 270 | assert len(cmem_lengths) == len(cmem_ratios), f'{cmem_lengths}, {cmem_ratios} should have same length!' 271 | compression_levels = len(cmem_lengths) 272 | # compression stage 0 is mem -> cmem 273 | one_input_block_size = seq_len 274 | for i in range(compression_levels): 275 | assert one_input_block_size >= cmem_ratios[i], \ 276 | f'At compression level {i}, one input block of {seq_len} tokens is already reduced to ' \ 277 | f'{one_input_block_size} compressed tokens, cannot be compressed again with ratio {cmem_ratios[i]}' 278 | assert cmem_lengths[i] >= (one_input_block_size // cmem_ratios[i]), \ 279 | f'length of compressed memory at level {i + 1} should be at least the compressed input block length ' \ 280 | f'at level {i} ({one_input_block_size}) divided by the compression ratio {cmem_ratios[i]}, ' \ 281 | f'i.e. at least {int(one_input_block_size // cmem_ratios[i])}' 282 | one_input_block_size //= cmem_ratios[i] 283 | 284 | # simulate information flow 285 | log = '' 286 | mem = 0 287 | cmems = [0] * compression_levels 288 | while True: # simulate until lt mem would be filled. then, sizes do not change anymore (everything full) 289 | mem += seq_len 290 | log += f'i={seq_len} -> ' 291 | if mem <= mem_len: 292 | log += f'm={mem}\n' 293 | continue 294 | old_mem = mem - mem_len 295 | mem = mem_len 296 | log += f'm={mem} -> {old_mem}' 297 | for lvl in range(compression_levels): 298 | log += f' --/{cmem_ratios[lvl]}--> c{lvl}=' 299 | assert old_mem % cmem_ratios[lvl] == 0, \ 300 | f'mem length {old_mem} from previous layer not divisible by compression ratio {cmem_ratios[lvl]} ' \ 301 | f'at compression level {lvl}. 
Log:\n{log}' 302 | cmems[lvl] += old_mem // cmem_ratios[lvl] 303 | if cmems[lvl] <= cmem_lengths[lvl]: 304 | log += f'{cmems[lvl]}' 305 | old_mem = 0 306 | break 307 | old_mem = cmems[lvl] - cmem_lengths[lvl] 308 | cmems[lvl] = cmem_lengths[lvl] 309 | log += f'{cmems[lvl]} -> {old_mem}' 310 | log += '\n' 311 | if old_mem > 0: 312 | break 313 | 314 | def __init__(self, dim, seq_len, mem_len: int, 315 | cmem_lengths: List[int], cmem_ratios: Union[List[int], int], 316 | use_ltmem=True, 317 | heads=8, attn_dropout=0., dropout=0., 318 | reconstruction_attn_dropout=0., one_kv_head=False): 319 | super().__init__() 320 | assert (dim % heads) == 0, 'dimension must be divisible by the number of heads' 321 | if isinstance(cmem_ratios, int): 322 | cmem_ratios = [cmem_ratios] * len(cmem_lengths) 323 | SelfAttention.validate_cmem_parameters(seq_len, mem_len, cmem_lengths, cmem_ratios) 324 | 325 | self.heads = heads 326 | self.dim_head = dim // heads 327 | self.seq_len = seq_len 328 | self.mem_len = mem_len 329 | self.num_cmem_stages = len(cmem_lengths) 330 | self.cmem_lengths = cmem_lengths 331 | self.cmem_ratios = cmem_ratios 332 | self.use_ltmem = use_ltmem 333 | self.scale = self.dim_head ** (-0.5) 334 | 335 | self.compression_stages = nn.ModuleList() 336 | running_start_index = self.seq_len 337 | prev_length = mem_len 338 | for i in range(self.num_cmem_stages): 339 | self.compression_stages.append(CompressionStage( 340 | dim, cmem_ratios[i], cmem_lengths[i], heads, self.dim_head, 341 | reconstruction_attn_dropout, 342 | prev_lvl_mem_start_index=running_start_index, 343 | prev_lvl_mem_len=prev_length)) 344 | prev_length = cmem_lengths[i] 345 | running_start_index += prev_length 346 | 347 | if self.use_ltmem: 348 | self.cmem_to_ltmem_query = nn.Parameter(torch.zeros(dim), requires_grad=True) 349 | self.ltmem_tokv = nn.Linear(dim, dim * 2, bias=False) 350 | self.recurrence = nn.GRUCell(dim, dim, bias=False) 351 | 352 | self.to_q = nn.Linear(dim, dim, bias=False) 353 | 354 | kv_dim = self.dim_head if one_kv_head else dim 355 | self.to_kv = nn.Linear(dim, kv_dim * 2, bias=False) 356 | self.to_out = nn.Linear(dim, dim) 357 | 358 | self.attn_dropout = nn.Dropout(attn_dropout) 359 | self.dropout = nn.Dropout(dropout) 360 | 361 | def forward(self, x, memories=None, pos_emb=None, input_mask=None, calc_memory=True, **kwargs): 362 | b, t, e, h, dim_h = *x.shape, self.heads, self.dim_head 363 | 364 | memories: Memory = default(memories, (None, None, None)) 365 | mem, cmems, ltmem = memories 366 | 367 | init_empty_mem = lambda: torch.empty(b, 0, e, **to(x)) 368 | mem = default(mem, init_empty_mem) 369 | cmems = default(cmems, lambda: [init_empty_mem() for i in range(self.num_cmem_stages)]) 370 | ltmem = default(ltmem, init_empty_mem) 371 | 372 | mem_len = mem.shape[1] 373 | cmem_len_sum = sum(cmem.shape[1] for cmem in cmems) 374 | ltmem_len = ltmem.shape[1] 375 | assert 0 <= ltmem_len <= 1, str(ltmem) 376 | 377 | q = self.to_q(x) 378 | 379 | if self.num_cmem_stages == 0: 380 | kv_input = torch.cat((ltmem, mem, x), dim=1) 381 | else: 382 | kv_input = torch.cat((ltmem, *cmems, mem, x), dim=1) 383 | kv_len = kv_input.shape[1] 384 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 385 | 386 | merge_heads = lambda x: reshape_dim(x, -1, (-1, dim_h)).transpose(1, 2) 387 | q, k, v = map(merge_heads, (q, k, v)) 388 | 389 | k, v = map(lambda x: x.expand(-1, h, -1, -1), (k, v)) 390 | 391 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 392 | mask_value = max_neg_value(dots) 393 | 394 | if pos_emb is not None: 395 | 
pos_emb = pos_emb[:, -kv_len:].type(q.dtype) 396 | pos_dots = torch.einsum('bhid,hjd->bhij', q, pos_emb) * self.scale 397 | pos_dots = shift(pos_dots) 398 | dots = dots + pos_dots 399 | 400 | if input_mask is not None: 401 | mask = input_mask[:, None, :, None] * input_mask[:, None, None, :] 402 | mask = F.pad(mask, [mem_len + cmem_len_sum + ltmem_len, 0], value=True) 403 | dots.masked_fill_(~mask, mask_value) 404 | 405 | total_mem_len = mem_len + cmem_len_sum + ltmem_len 406 | mask = torch.ones(t, t + total_mem_len, **to(x)).triu_(diagonal=1 + total_mem_len).bool() 407 | dots.masked_fill_(mask[None, None, ...], mask_value) 408 | 409 | attn = dots.softmax(dim=-1) 410 | attn = self.attn_dropout(attn) 411 | 412 | out = torch.einsum('bhij,bhjd->bhid', attn, v) 413 | out = out.transpose(1, 2).reshape(b, t, -1) 414 | logits = self.to_out(out) 415 | logits = self.dropout(logits) 416 | 417 | new_mem = mem 418 | new_cmems = cmems 419 | new_ltmem = ltmem 420 | aux_loss = torch.zeros(1, requires_grad=False, **to(q)) 421 | 422 | if self.seq_len > t or not calc_memory: 423 | return logits, Memory(new_mem, new_cmems, new_ltmem), aux_loss 424 | 425 | # calculate memory and compressed memory 426 | 427 | old_mem, new_mem = queue_fifo(mem, x, length=self.mem_len, dim=1) 428 | old_mem_padding = old_mem.shape[1] % self.cmem_ratios[0] 429 | 430 | if old_mem_padding != 0: 431 | old_mem = F.pad(old_mem, [0, 0, old_mem_padding, 0], value=0.) 432 | 433 | if old_mem.shape[1] == 0 or self.num_cmem_stages <= 0: 434 | return logits, Memory(new_mem, new_cmems, new_ltmem), aux_loss 435 | 436 | prev_mem_len = mem_len 437 | old_mem_prev_lvl = old_mem 438 | for i in range(self.num_cmem_stages): 439 | if old_mem_prev_lvl.size(1) == 0: 440 | break 441 | old_mem_prev_lvl, new_cmems[i], lvl_aux_loss = self.compression_stages[i]( 442 | prev_cmem_this_lvl=cmems[i], 443 | old_mem_prev_lvl=old_mem_prev_lvl, 444 | prev_lvl_mem_len=prev_mem_len, 445 | q=q, k=k, v=v, 446 | to_kv_weight=self.to_kv.weight 447 | ) 448 | aux_loss += lvl_aux_loss 449 | prev_mem_len = cmems[i].size(1) 450 | 451 | if old_mem_prev_lvl.size(1) > 0 and self.use_ltmem: 452 | old_cmem_k, old_cmem_v = (self.ltmem_tokv(old_mem_prev_lvl) 453 | .unsqueeze(dim=1) # Insert fake head dimension 454 | .chunk(2, dim=-1)) 455 | to_ltmem_query = self.cmem_to_ltmem_query.expand(b, 1, 1, e) # b x 1(=h) x 1(=seq) x e 456 | ltmem_update = full_attn(to_ltmem_query, old_cmem_k, old_cmem_v) 457 | if ltmem_len > 0: 458 | new_ltmem = self.recurrence(ltmem_update.view(b, e), ltmem.squeeze(dim=1)).unsqueeze(dim=1) 459 | else: 460 | new_ltmem = ltmem_update.squeeze(dim=1) # Remove heads dimension 461 | 462 | return logits, Memory(new_mem, new_cmems, new_ltmem), aux_loss 463 | 464 | 465 | # transformer 466 | 467 | class CompressiveTransformer(nn.Module): 468 | def __init__(self, num_tokens, dim, seq_len, depth, emb_dim=None, 469 | memory_layers=None, mem_len=None, 470 | cmem_lengths: List[int] = None, cmem_ratios: Union[int, List[int]] = 4, 471 | use_ltmem=True, 472 | heads=8, gru_gated_residual=True, mogrify_gru=False, attn_dropout=0., 473 | ff_glu=False, ff_dim=None, ff_dropout=0., 474 | attn_layer_dropout=0., reconstruction_attn_dropout=0., reconstruction_loss_weight=1., 475 | one_kv_head=False): 476 | super().__init__() 477 | if isinstance(cmem_ratios, int): 478 | if cmem_lengths is None: 479 | cmem_ratios = [cmem_ratios] 480 | else: 481 | cmem_ratios = [cmem_ratios] * len(cmem_lengths) 482 | else: 483 | assert cmem_lengths is not None 484 | assert len(cmem_lengths) == 
len(cmem_ratios) 485 | 486 | ff_dim = default(ff_dim, dim * 4) 487 | emb_dim = default(emb_dim, dim) 488 | mem_len = default(mem_len, seq_len) 489 | cmem_lengths = default(cmem_lengths, [mem_len // cmem_ratios[0]]) 490 | memory_layers = default(memory_layers, list(range(1, depth + 1))) 491 | 492 | assert mem_len >= seq_len, 'length of memory should be at least the sequence length' 493 | assert all( 494 | [0 < layer <= depth for layer in memory_layers]), 'one of the indicated memory layers is invalid' 495 | 496 | self.seq_len = seq_len 497 | 498 | self.depth = depth 499 | self.memory_layers = list(memory_layers) 500 | self.num_cmem_stages = len(cmem_lengths) 501 | 502 | self.token_emb = nn.Embedding(num_tokens, emb_dim) 503 | self.to_model_dim = nn.Identity() if emb_dim == dim else nn.Linear(emb_dim, dim) 504 | 505 | seq_and_mem_len = seq_len + mem_len + sum(cmem_lengths) + (1 if use_ltmem else 0) # + 1 for LT Memory 506 | self.pos_emb = nn.Parameter(torch.zeros(heads, seq_and_mem_len, dim // heads), requires_grad=True) 507 | 508 | self.to_logits = nn.Sequential( 509 | nn.Identity() if emb_dim == dim else nn.Linear(dim, emb_dim), 510 | nn.Linear(emb_dim, num_tokens) 511 | ) 512 | 513 | wrapper = partial(GRUGating, dim, mogrify=mogrify_gru) if gru_gated_residual else Residual 514 | 515 | self.attn_layers = nn.ModuleList([ 516 | wrapper(PreNorm(dim, SelfAttention( 517 | dim, seq_len, mem_len, 518 | cmem_lengths if (i + 1) in memory_layers else [], 519 | cmem_ratios if (i + 1) in memory_layers else [], 520 | use_ltmem and (i + 1) in memory_layers, 521 | heads, dropout=attn_layer_dropout, 522 | attn_dropout=attn_dropout, 523 | reconstruction_attn_dropout=reconstruction_attn_dropout, 524 | one_kv_head=one_kv_head 525 | ))) for i in range(depth)]) 526 | self.ff_layers = nn.ModuleList( 527 | [wrapper(PreNorm(dim, FeedForward(dim, ff_dim, dropout=ff_dropout, glu=ff_glu))) for _ in range(depth)]) 528 | 529 | self.reconstruction_loss_weight = reconstruction_loss_weight 530 | 531 | def forward(self, x, memories=None, mask=None): 532 | input_device = x.device 533 | x = self.token_emb(x) 534 | x = self.to_model_dim(x) 535 | b, t, d = x.shape 536 | 537 | assert t <= self.seq_len, f'input contains a sequence length {t} that is greater than the designated maximum ' \ 538 | f'sequence length {self.seq_len} ' 539 | 540 | memories = default(memories, (None, None, None)) 541 | mem, cmems, ltmem = memories 542 | 543 | num_memory_layers = len(self.memory_layers) 544 | init_empty_mem = lambda: torch.empty(num_memory_layers, b, 0, d, **to(x)) 545 | mem = default(mem, init_empty_mem) 546 | cmems = default(cmems, lambda: [init_empty_mem() for i in range(self.num_cmem_stages)]) 547 | ltmem = default(ltmem, init_empty_mem) 548 | 549 | total_len = mem.shape[2] + sum(cmem.shape[2] for cmem in cmems) + ltmem.shape[2] + self.seq_len 550 | pos_emb = self.pos_emb[:, (self.seq_len - t):total_len] 551 | 552 | # Lists of {c,lt,}mem per transformer layer 553 | next_mem = [] 554 | next_cmems = [] 555 | next_ltmem = [] 556 | aux_loss = torch.tensor(0., requires_grad=True, **to(x)) 557 | 558 | mem_iter, ltmem_iter = map(iterate_tensor, (mem, ltmem)) 559 | cmems_iter = ([cmem[i] for cmem in cmems] for i in range(num_memory_layers)) 560 | 561 | for ind in range(self.depth): 562 | x, mem_out, cmems_out, ltmem_out, layer_aux_loss \ 563 | = self._pass_through_layer(ind, mem_iter, cmems_iter, ltmem_iter, mask, pos_emb, x) 564 | aux_loss = aux_loss + layer_aux_loss 565 | 566 | if (ind + 1) not in self.memory_layers: 567 | continue 568 | 
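            # Only layers listed in self.memory_layers reach this point (all other
            # layers `continue` above), so next_mem / next_cmems / next_ltmem collect
            # exactly one entry per memory layer, matching the stacking after the loop.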
569 | next_mem.append(mem_out) 570 | next_cmems.append(cmems_out) 571 | next_ltmem.append(ltmem_out) 572 | 573 | out = self.to_logits(x) 574 | 575 | next_mem, next_ltmem = map(torch.stack, (next_mem, next_ltmem)) 576 | next_cmems = [torch.stack([next_cmems[layer][cstage] for layer in range(num_memory_layers)]) 577 | for cstage in range(self.num_cmem_stages)] 578 | 579 | aux_loss = aux_loss * self.reconstruction_loss_weight / num_memory_layers 580 | out = out.to(device=input_device) 581 | return out, Memory(mem=next_mem, compressed_mem=next_cmems, lt_mem=next_ltmem), aux_loss 582 | 583 | def _pass_through_layer(self, ind, mem_iter, cmems_iter, ltmem_iter, mask, pos_emb, x): 584 | attn = self.attn_layers[ind] 585 | ff = self.ff_layers[ind] 586 | 587 | layer_num = ind + 1 588 | use_memory = layer_num in self.memory_layers 589 | memories = (next(mem_iter), next(cmems_iter), next(ltmem_iter)) if use_memory else None 590 | _dev = lambda t: t.to(device=x.device) 591 | memories = (_dev(memories[0]), [_dev(m) for m in memories[1]], _dev(memories[2])) if memories else None 592 | 593 | x, (mem_out, cmems_out, ltmem_out), layer_aux_loss = attn(x, memories=memories, calc_memory=use_memory, 594 | input_mask=mask, pos_emb=pos_emb) 595 | x, = ff(x) 596 | 597 | return x, mem_out, cmems_out, ltmem_out, layer_aux_loss 598 | 599 | 600 | class MultiDeviceCompressiveTransformer(CompressiveTransformer): 601 | """ 602 | CompressiveTransformer with model parallelism. 603 | Note: Start fairseq-train with 604 | --distributed-no-spawn 605 | --distributed-world-size 1 606 | to prevent data parallelism 607 | """ 608 | 609 | def __init__(self, num_tokens, dim, seq_len, depth, emb_dim=None, memory_layers=None, mem_len=None, 610 | cmem_lengths: List[int] = None, cmem_ratios: Union[int, List[int]] = 4, use_ltmem=True, heads=8, 611 | gru_gated_residual=True, mogrify_gru=False, attn_dropout=0., ff_glu=False, ff_dim=None, ff_dropout=0., 612 | attn_layer_dropout=0., reconstruction_attn_dropout=0., reconstruction_loss_weight=1., 613 | one_kv_head=False, 614 | layers_to_gpus=None): 615 | super().__init__(num_tokens, dim, seq_len, depth, emb_dim, memory_layers, mem_len, cmem_lengths, cmem_ratios, 616 | use_ltmem, heads, gru_gated_residual, mogrify_gru, attn_dropout, ff_glu, ff_dim, ff_dropout, 617 | attn_layer_dropout, reconstruction_attn_dropout, reconstruction_loss_weight, one_kv_head) 618 | 619 | gpus = torch.cuda.device_count() 620 | layers_to_gpus = default(layers_to_gpus, [int(i / self.depth * gpus) for i in range(self.depth)]) 621 | assert len(layers_to_gpus) == self.depth 622 | assert all(0 <= x < gpus for x in layers_to_gpus) 623 | self.layers_to_gpus = layers_to_gpus 624 | 625 | def cuda(self, device=None): 626 | # pos_emb, token_emb, to_model_dim and to_logits always stays on device 0 627 | self.pos_emb = nn.Parameter(self.pos_emb.cuda(), requires_grad=True) 628 | self.token_emb.to(device=0) 629 | self.to_model_dim.to(device=0) 630 | self.to_logits.to(device=torch.cuda.device_count() - 1) 631 | for i in range(self.depth): 632 | self.attn_layers[i].to(device=self.layers_to_gpus[i]) 633 | self.ff_layers[i].to(device=self.layers_to_gpus[i]) 634 | return self 635 | 636 | def _apply(self, fn): 637 | fake = torch.empty(0) 638 | if fn(fake).device.type == 'cuda' and fn(fake).device != fake.device: 639 | return self.cuda() 640 | else: 641 | # noinspection PyProtectedMember 642 | return super()._apply(fn) 643 | 644 | def _pass_through_layer(self, ind, mem_iter, cmems_iter, ltmem_iter, mask, pos_emb, x): 645 | gpu = 
self.layers_to_gpus[ind] 646 | 647 | x = x.to(device=gpu) 648 | pos_emb = pos_emb.to(device=gpu) 649 | mask = mask.to(device=gpu) if mask is not None else None  # a multi-element Tensor cannot be truth-tested directly 650 | x, mem_out, cmems_out, ltmem_out, layer_aux_loss = super()._pass_through_layer( 651 | ind, mem_iter, cmems_iter, ltmem_iter, mask, pos_emb, x) 652 | 653 | mem_out = mem_out.to(device=0) if mem_out is not None else None 654 | cmems_out = [m.to(device=0) for m in cmems_out] if cmems_out is not None else None 655 | ltmem_out = ltmem_out.to(device=0) if ltmem_out is not None else None 656 | layer_aux_loss = layer_aux_loss.to(device=0) if layer_aux_loss is not None else None 657 | 658 | return x, mem_out, cmems_out, ltmem_out, layer_aux_loss 659 | --------------------------------------------------------------------------------