├── data
│   ├── __init__.py
│   ├── datamodule.py
│   └── emqa_dataset.py
├── eval
│   ├── __init__.py
│   ├── util.py
│   └── eval.py
├── model
│   ├── __init__.py
│   ├── external
│   │   ├── __init__.py
│   │   ├── stm.py
│   │   └── compressive_transformer.py
│   ├── blind.py
│   ├── simple_vqa.py
│   ├── lt_ct.py
│   ├── sparse_transformer.py
│   ├── importance_teacher.py
│   ├── lightning.py
│   ├── base.py
│   ├── mann.py
│   ├── moment_loc.py
│   └── rehearsal.py
├── tools
│   ├── __init__.py
│   ├── aggregate_features_to_hdf5.py
│   ├── split_existing_data_aligned.py
│   ├── create_pure_videoqa_json.py
│   └── extract_ego4d_clip_features.py
├── .gitignore
├── config
│   ├── model
│   │   ├── blind.yaml
│   │   ├── importance_teacher.yaml
│   │   ├── moment_localization_loss
│   │   │   ├── attn_summed.yaml
│   │   │   └── attn_sample.yaml
│   │   ├── simple_vqa.yaml
│   │   ├── bigbird.yaml
│   │   ├── longformer.yaml
│   │   ├── stm.yaml
│   │   ├── lt_ct.yaml
│   │   ├── dnc.yaml
│   │   └── rehearsal_mem.yaml
│   ├── dataset
│   │   └── ego4d.yaml
│   └── base.yaml
├── .gitattributes
├── experiment
│   └── run.sh
├── run.py
├── hydra_compat.py
├── requirements.txt
├── README.md
└── lightning_util.py
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/eval/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/external/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | venv
3 | .idea
4 | **.pyc
--------------------------------------------------------------------------------
/config/model/blind.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.blind.BlindVqaModel
2 | pretrained_model: t5-base
--------------------------------------------------------------------------------
/config/model/importance_teacher.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.importance_teacher.ImportanceTeacherVqaModel
2 | pretrained_enc_dec: t5-small
3 | input_size: ${dataset.feature_dim}
4 | fragment_length: 128
5 |
--------------------------------------------------------------------------------
/config/model/moment_localization_loss/attn_summed.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.moment_loc.SummedAttentionTransformerMomentLocLoss
2 | softmax_temperature: 0.1
3 | att_loss_type: 'lse'
4 | lse_alpha: 20
5 |
--------------------------------------------------------------------------------
/config/model/simple_vqa.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - moment_localization_loss: attn_sample
3 |
4 | _target_: model.simple_vqa.SimpleVqaModel
5 | pretrained_enc_dec: t5-base
6 | input_size: ${dataset.feature_dim}
7 | #moment_localization_loss: null
--------------------------------------------------------------------------------
/config/model/bigbird.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.sparse_transformer.BigBirdVqaModel
2 | input_size: ${dataset.feature_dim}
3 | pretrained_bigbird: google/bigbird-pegasus-large-arxiv
4 | use_moment_localization_loss: False
5 | gradient_checkpointing: True
--------------------------------------------------------------------------------
/config/model/moment_localization_loss/attn_sample.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.moment_loc.SamplingAttentionTransformerMomentLocLoss
2 | softmax_temperature: 0.1
3 | att_loss_type: 'lse'
4 | lse_alpha: 20
5 | num_negatives: 2
6 | use_hard_negatives: False
7 | drop_topk: 0
8 | negative_pool_size: 0
9 | num_hard: 2
--------------------------------------------------------------------------------
/config/model/longformer.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - moment_localization_loss: attn_sample
3 |
4 | _target_: model.sparse_transformer.LongformerVqaModel
5 | pretrained_enc_dec: t5-base
6 | input_size: ${dataset.feature_dim}
7 | pretrained_longformer: allenai/longformer-base-4096
8 | #moment_localization_loss: null
9 |
--------------------------------------------------------------------------------
/config/dataset/ego4d.yaml:
--------------------------------------------------------------------------------
1 | data_dir: datasets/ego4d
2 | use_final_test: False
3 | feature_type: slowfast8x8_r101_k400
4 | feature_dim: 2304
5 |
6 | drop_val_last: True
7 |
8 | tokenizer_name: t5-base
9 | # separate_question_tok_name: allenai/longformer-base-4096
10 |
11 | workers: 16
12 | train_bsz: 8
13 | test_bsz: 16
14 |
--------------------------------------------------------------------------------
/config/model/stm.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.mann.StmEmqaModel
2 | pretrained_enc_dec: t5-base
3 | segmentation_method: avg
4 | input_size: ${dataset.feature_dim}
5 | segment_length: 32
6 | stm_input_size: 1024
7 | mem_hidden_size: 768
8 | stm_step: 2
9 | stm_num_slot: 8
10 | stm_mlp_size: 256
11 | stm_slot_size: 128
12 | stm_rel_size: 128
13 | stm_out_att_size: 128
--------------------------------------------------------------------------------
/config/model/lt_ct.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.lt_ct.CompressiveTransformerEmqaModel
2 | pretrained_enc_dec: t5-base
3 | input_dim: ${dataset.feature_dim}
4 | hidden_dim: 768
5 | num_layers: 6
6 | heads: 8
7 | block_length: 64
8 | mem_length: 64
9 | cmem_lengths: [ 16, 8, 4 ]
10 | compression_factors: [ 4, 2, 2 ]
11 | use_ltmem: True
12 | memory_layers: [ 1, 4, 5, 6 ]
13 | dropout: 0.1
14 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | #Set the default behavior, in case people don't have core.autocrlf set.
2 | * text=auto
3 |
4 | # Declare files that will always have LF line endings on checkout.
5 | *.py text eol=lf
6 | *.txt text eol=lf
7 | *.xml text eol=lf
8 | *.gitattributes text eol=lf
9 | *.gitignore text eol=lf
10 | *.json text eol=lf
11 | *.jgram text eol=lf
12 |
13 | #Forced Binary files
14 | #For example *.png binary
15 |
--------------------------------------------------------------------------------
/config/model/dnc.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.mann.DncEmqaModel
2 | pretrained_enc_dec: t5-base
3 | segmentation_method: avg
4 | input_size: ${dataset.feature_dim}
5 | segment_length: 32
6 | dnc_input_size: 1024
7 | rnn_hidden_size: 512
8 | num_dnc_layers: 2
9 | num_rnn_hidden_layers: 2
10 | num_mem_cells: 16
11 | mem_hidden_size: 768
12 | #moment_loc:
13 | # _target_: model.moment_loc.SeqMomentLocalizationLossModule
14 | # seq_hidden_dim: 1024
15 | # question_hidden_dim: 512
--------------------------------------------------------------------------------
/eval/util.py:
--------------------------------------------------------------------------------
1 | class AverageMeter(object):
2 | """Computes and stores the average and current value"""
3 |
4 | def __init__(self):
5 | self.reset()
6 |
7 | def reset(self):
8 | self.val = 0
9 | self.avg = 0
10 | self.sum = 0
11 | self.count = 0
12 |
13 | def update(self, val, n=1):
14 | self.val = val
15 | self.sum += val * n
16 | self.count += n
17 | self.avg = self.sum / self.count
18 |
--------------------------------------------------------------------------------
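A minimal usage sketch of the `AverageMeter` above (toy numbers, not results from the repository):

```python
from eval.util import AverageMeter

meter = AverageMeter()
meter.update(0.5, n=2)      # a batch of 2 samples with mean value 0.5
meter.update(1.0)           # a single sample
print(round(meter.avg, 3))  # (0.5 * 2 + 1.0) / 3 = 0.667
```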
/experiment/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | # Run this within the project main directory (where the directories experiment, model, ... reside)
5 |
6 | if [[ -z $TRANSFORMERS_CACHE ]]; then
7 | echo "TRANSFORMERS_CACHE env var has to be set!"
8 | exit 1
9 | fi
10 | if [[ ! -d $TRANSFORMERS_CACHE ]]; then
11 | echo "TRANSFORMERS_CACHE directory should exist (to avoid unintended downloads). Current: '$TRANSFORMERS_CACHE'"
12 | exit 1
13 | fi
14 |
15 | python run.py "${@:1}"
16 |
--------------------------------------------------------------------------------
/config/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: ego4d
3 | - model: rehearsal_mem
4 | - _self_
5 | - override hydra/job_logging: none
6 | - override hydra/hydra_logging: none
7 |
8 | trainer:
9 | detect_anomaly: True
10 | #val_check_interval: 500
11 | max_epochs: 100
12 | accumulate_grad_batches: 4
13 | auto_resume: False
14 | gpus: 1
15 | log_every_n_steps: 4
16 | auto_lr_find: True
17 | enable_progress_bar: False
18 | monitor_variable: val_lm_loss
19 | monitor_mode: min
20 |
21 | optim:
22 | optimizer:
23 | _target_: torch.optim.Adam
24 | lr: 0.00001
25 | weight_decay: 0
26 |
27 | freeze: [ ]
28 |
29 | loss_weights:
30 | # lm_loss is reference
31 | recollection_loss: 1
32 | familiarity_loss: 0.5
33 |
34 |
35 | hydra:
36 | run:
37 | dir: .
38 | output_subdir: null
39 |
40 |
--------------------------------------------------------------------------------
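For reference, a hedged sketch of composing this config outside of `run.py` with Hydra's compose API (hydra-core 1.1, as pinned in `requirements.txt`); the override values are arbitrary examples:

```python
from hydra import initialize, compose
from omegaconf import OmegaConf

# Compose config/base.yaml with example command-line-style overrides.
with initialize(config_path='config'):
    cfg = compose(config_name='base',
                  overrides=['model=longformer', 'trainer.max_epochs=5'])
    print(OmegaConf.to_yaml(cfg.trainer))
    print(cfg.model['_target_'])  # model.sparse_transformer.LongformerVqaModel
```

The same `key=value` overrides can be appended when launching `run.py`, since it is a Hydra entry point.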
/run.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import hydra
4 | from omegaconf import DictConfig
5 |
6 | from data.datamodule import EmqaDataModule
7 | from hydra_compat import apply_argparse_defaults_to_hydra_config
8 | from lightning_util import tune_fit_test, add_common_trainer_util_args
9 | from model.lightning import EmqaLightningModule
10 |
11 |
12 | @hydra.main(config_path='config', config_name='base')
13 | def main(config: DictConfig):
14 | fake_parser = ArgumentParser()
15 | add_common_trainer_util_args(fake_parser, default_monitor_variable='val_total_loss')
16 | apply_argparse_defaults_to_hydra_config(config.trainer, fake_parser)
17 |
18 | data = EmqaDataModule(config.dataset)
19 | model = EmqaLightningModule(config.model, config.optim, data.tokenizer)
20 |
21 | tune_fit_test(config.trainer, model, data)
22 |
23 |
24 | if __name__ == '__main__':
25 | main()
26 |
--------------------------------------------------------------------------------
/tools/aggregate_features_to_hdf5.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | import h5py
5 | import torch
6 | from tqdm import tqdm
7 |
8 |
9 | def main():
10 | features_dir = Path(sys.argv[1])
11 | all_feature_files = list(features_dir.glob('*.pt'))
12 |
13 | with h5py.File('slowfast8x8_r101_k400.hdf5', 'a') as h5_file, tqdm(all_feature_files) as iterator:
14 | features_file: Path
15 | for features_file in iterator:
16 | try:
17 | clip_uid = features_file.name.replace('.pt', '')
18 | if clip_uid in h5_file:
19 | continue
20 | features = torch.load(str(features_file))
21 | h5_file.create_dataset(clip_uid, data=features, compression="gzip")
22 | except BaseException as e:
23 | raise RuntimeError(f'Error during {features_file}: {e}')
24 |
25 |
26 | if __name__ == '__main__':
27 | main()
28 |
--------------------------------------------------------------------------------
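A hedged sketch of reading one clip's features back out of the aggregated file (same access pattern as `data/emqa_dataset.py`):

```python
import h5py
import torch

# 'slowfast8x8_r101_k400.hdf5' is the file written by the script above.
with h5py.File('slowfast8x8_r101_k400.hdf5', 'r') as h5_file:
    clip_uid = next(iter(h5_file.keys()))              # any stored clip UID
    features = torch.from_numpy(h5_file[clip_uid][:])  # (num_feature_windows, 2304) for SlowFast features
    print(clip_uid, features.shape)
```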
/hydra_compat.py:
--------------------------------------------------------------------------------
1 | from absl.flags.argparse_flags import ArgumentParser
2 | from omegaconf import DictConfig, open_dict
3 |
4 |
5 | def apply_argparse_defaults_to_hydra_config(config: DictConfig, parser: ArgumentParser, verbose=False):
6 | args = parser.parse_args([]) # Parser is not allowed to have required args, otherwise this will fail!
7 | defaults = vars(args)
8 |
9 | def _apply_defaults(dest: DictConfig, source: dict, indentation=''):
10 | for k, v in source.items():
11 | if k in dest and isinstance(v, dict):
12 | current_value = dest[k]
13 | if current_value is not None:
14 | assert isinstance(current_value, DictConfig)
15 | _apply_defaults(current_value, v, indentation + ' ')
16 | elif k not in dest:
17 | dest[k] = v
18 | if verbose:
19 | print(indentation, 'set default value for', k)
20 |
21 | with open_dict(config):
22 | _apply_defaults(config, defaults)
23 |
--------------------------------------------------------------------------------
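A hedged sketch of how `apply_argparse_defaults_to_hydra_config` behaves, using a toy parser and config (the option names below are made up for illustration):

```python
from argparse import ArgumentParser

from omegaconf import OmegaConf

from hydra_compat import apply_argparse_defaults_to_hydra_config

parser = ArgumentParser()
parser.add_argument('--patience', type=int, default=3)    # hypothetical option
parser.add_argument('--save_top_k', type=int, default=1)  # hypothetical option

cfg = OmegaConf.create({'patience': 10})  # values already present in the Hydra config are kept
apply_argparse_defaults_to_hydra_config(cfg, parser, verbose=True)
print(cfg)  # {'patience': 10, 'save_top_k': 1}
```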
/config/model/rehearsal_mem.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.rehearsal.RehearsalMemoryEmqaModel
2 | pretrained_enc_dec: t5-base
3 |
4 | rehearsal_machine:
5 | _target_: model.rehearsal.RehearsalMemoryMachine
6 | pretrained_encoder: t5-base
7 | input_dim: ${dataset.feature_dim}
8 | mem_hidden_size: 768
9 | num_memory_slots: 16
10 | segment_length: 128
11 | slot_to_item_num_heads: 1
12 | use_independent_gru_per_mem_slot: False
13 |
14 | rehearsal_trainer:
15 | _target_: model.rehearsal.RehearsalTrainingModule
16 | input_size: ${dataset.feature_dim}
17 | mem_hidden_size: ${..rehearsal_machine.mem_hidden_size}
18 | num_samples: 4
19 | sample_length: 128
20 | positive_mask_ratio: 0.5
21 | negative_replacement_ratio: 0.5
22 | invert_teacher_sequence: False
23 | pretrained_decoder: null
24 | decoder_params:
25 | hidden_size: 256
26 | num_hidden_layers: 3
27 | num_attention_heads: 4
28 | intermediate_size: 512
29 | max_position_embeddings: 132
30 | vocab_size: 16 # Vocab never used (only inputs_embeds)
31 | sampling_teacher_weights_file: ../22_04_b5aa4428_0/importance_teacher_weights.pt
32 |
--------------------------------------------------------------------------------
/model/blind.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Dict
2 |
3 | import torch
4 | from transformers import PreTrainedModel, AutoModelForSeq2SeqLM
5 |
6 | from model.base import EmqaBaseModel
7 |
8 |
9 | class BlindVqaModel(EmqaBaseModel):
10 |
11 | def __init__(self,
12 | pretrained_model: str
13 | ) -> None:
14 | super().__init__()
15 | self.transformer: PreTrainedModel = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model)
16 | assert self.transformer.config.is_encoder_decoder
17 |
18 | def teacher_forcing_forward(self, question_tokens, question_mask, video_features, video_mask, answer_tokens,
19 | answer_mask, batch_sample_ids,
20 | moment_localization_labels) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
21 | output = self.transformer(input_ids=question_tokens, attention_mask=question_mask,
22 | labels=answer_tokens, decoder_attention_mask=answer_mask)
23 | return {'lm_loss': output.loss}, output.logits
24 |
25 | def autoregressive_forward(self, question_tokens, question_mask, video_features, video_mask) -> torch.Tensor:
26 | return self.transformer.generate(inputs=question_tokens, attention_mask=question_mask)
27 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | aiohttp==3.8.1
3 | aiosignal==1.2.0
4 | antlr4-python3-runtime==4.8
5 | async-timeout==4.0.2
6 | attrs==21.4.0
7 | cachetools==5.0.0
8 | certifi==2021.10.8
9 | charset-normalizer==2.0.12
10 | click==8.0.4
11 | colorama==0.4.4
12 | dnc==1.1.0
13 | filelock==3.6.0
14 | flann==1.6.13
15 | frozenlist==1.3.0
16 | fsspec==2022.2.0
17 | future==0.18.2
18 | google-auth==2.6.0
19 | google-auth-oauthlib==0.4.6
20 | grpcio==1.44.0
21 | h5py==3.6.0
22 | huggingface-hub==0.4.0
23 | hydra-core==1.1.1
24 | idna==3.3
25 | importlib-metadata==4.11.2
26 | joblib==1.1.0
27 | Markdown==3.3.6
28 | multidict==6.0.2
29 | nltk==3.7
30 | numpy==1.22.2
31 | oauthlib==3.2.0
32 | omegaconf==2.1.1
33 | packaging==21.3
34 | portalocker==2.4.0
35 | protobuf==3.19.4
36 | pyasn1==0.4.8
37 | pyasn1-modules==0.2.8
38 | pyDeprecate==0.3.1
39 | pyparsing==3.0.7
40 | pytorch-lightning==1.5.10
41 | PyYAML==6.0
42 | regex==2022.3.2
43 | requests==2.27.1
44 | requests-oauthlib==1.3.1
45 | rouge-score==0.0.4
46 | rsa==4.8
47 | sacrebleu==2.0.0
48 | sacremoses==0.0.47
49 | six==1.16.0
50 | tabulate==0.8.9
51 | tensorboard==2.8.0
52 | tensorboard-data-server==0.6.1
53 | tensorboard-plugin-wit==1.8.1
54 | tokenizers==0.11.6
55 | torch==1.10.2
56 | torchmetrics==0.7.2
57 | tqdm==4.63.0
58 | transformers==4.16.2
59 | typing_extensions==4.1.1
60 | urllib3==1.26.8
61 | Werkzeug==2.0.3
62 | yarl==1.7.2
63 | zipp==3.7.0
64 |
--------------------------------------------------------------------------------
/tools/split_existing_data_aligned.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import sys
4 | from pathlib import Path
5 |
6 | input_data_file = Path(sys.argv[1])
7 | clip_to_video_file = Path(sys.argv[2])
8 | input_samples = json.loads(input_data_file.read_text())
9 | clip_to_video = json.loads(clip_to_video_file.read_text())
10 |
11 | all_clips = {s['video_id'] for s in input_samples}
12 | all_videos = {clip_to_video[clip] for clip in all_clips}
13 |
14 | frac_val = 0.5
15 | assert frac_val < 1
16 |
17 | num_val = round(len(all_videos) * frac_val)
18 | num_test = len(all_videos) - num_val
19 | assert num_test >= 1
20 |
21 | val_videos = set(random.sample(sorted(all_videos), num_val))
22 | test_videos = all_videos - val_videos
23 |
24 | val_clips = {clip for clip in all_clips if clip_to_video[clip] in val_videos}
25 | test_clips = {clip for clip in all_clips if clip_to_video[clip] in test_videos}
26 |
27 | val_samples = [s for s in input_samples if s['video_id'] in val_clips]
28 | test_samples = [s for s in input_samples if s['video_id'] in test_clips]
29 |
30 | print('Splits: videos/clips/samples')
31 | print(f'Val : {len(val_videos)}/{len(val_clips)}/{len(val_samples)}')
32 | print(f'Test : {len(test_videos)}/{len(test_clips)}/{len(test_samples)}')
33 |
34 | Path('pure_emqa_val.json').write_text(json.dumps(val_samples))
35 | Path('pure_emqa_test.json').write_text(json.dumps(test_samples))
36 |
37 | Path('split_videos.json').write_text(json.dumps({'val': sorted(val_videos), 'test': sorted(test_videos)}))
38 | Path('split_clips.json').write_text(json.dumps({'val': sorted(val_clips), 'test': sorted(test_clips)}))
39 |
--------------------------------------------------------------------------------
/model/simple_vqa.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | from torch import nn
5 | from transformers.modeling_outputs import Seq2SeqLMOutput
6 |
7 | from model.base import MemoryAugmentedTransformerEmqaModel
8 | from model.moment_loc import TransformerMomentLocalizationLossModule
9 |
10 |
11 | # noinspection PyAbstractClass
12 | class SimpleVqaModel(MemoryAugmentedTransformerEmqaModel):
13 | # Actually this is not really a MemoryAugmentedTransformerEmqaModel since it uses full attention over the input
14 |
15 | def __init__(self, input_size: int,
16 | pretrained_enc_dec: str,
17 | moment_localization_loss: TransformerMomentLocalizationLossModule = None) -> None:
18 | super().__init__(pretrained_enc_dec)
19 | self.moment_localization_loss = moment_localization_loss
20 | self.transformer.get_decoder().config.output_attentions = True
21 |
22 | hidden = self.transformer.get_input_embeddings().embedding_dim
23 | if input_size != hidden:
24 | self.transform_visual = nn.Linear(input_size, hidden, bias=False)
25 | else:
26 | self.transform_visual = nn.Identity()
27 |
28 | def forward_encoders(self, question_tokens, question_mask, video_features, video_mask, moment_localization_labels):
29 | visual_seq = self.transform_visual(video_features)
30 | context, context_mask = self._prepare_context(visual_seq, video_mask, question_tokens, question_mask)
31 | return context, context_mask, visual_seq, video_mask, {}
32 |
33 | def calc_additional_loss(self, question_tokens, question_mask, video_features, video_mask, answer_tokens,
34 | answer_mask, batch_sample_ids, context, context_mask, final_memory, mem_mask,
35 | transformer_output: Seq2SeqLMOutput,
36 | moment_localization_labels) -> Dict[str, torch.Tensor]:
37 | if self.moment_localization_loss:
38 | return {
39 | 'moment_localization': self.moment_localization_loss(
40 | question_tokens, transformer_output, moment_localization_labels, video_mask)
41 | }
42 | else:
43 | return {}
44 |
--------------------------------------------------------------------------------
/tools/create_pure_videoqa_json.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from pathlib import Path
4 | from typing import Dict
5 |
6 |
7 | def convert(nlq_file: Path, answers: Dict[str, str]):
8 | """
9 |     Reads the NLQ JSON file and answer map, and returns a list of
10 | {"video_id", "sample_id", "question", "answer", "moment_start_frame", "moment_end_frame"} objects.
11 |
12 | :param nlq_file: NLQ JSON file
13 | :param answers: Answer map
14 | """
15 | result = []
16 |
17 | annotations = json.loads(nlq_file.read_text())
18 | for video in annotations['videos']:
19 | for clip in video['clips']:
20 | for annotation in clip['annotations']:
21 | for i, query in enumerate(annotation['language_queries']):
22 | if 'query' not in query or not query['query']:
23 | continue
24 | question = query['query'].replace('\n', '').replace(',', '').strip()
25 | video_id = clip['clip_uid']
26 | sample_id = f'{annotation["annotation_uid"]}_{i}'
27 | if sample_id not in answers:
28 | continue
29 | answer = answers[sample_id].replace('\n', '').replace(',', '').strip()
30 | fps = 30 # fps = 30 is known for canonical Ego4D clips
31 | start_frame = query['clip_start_sec'] * fps
32 | end_frame = query['clip_end_sec'] * fps
33 |
34 | result.append({
35 | 'video_id': video_id,
36 | 'sample_id': sample_id,
37 | 'answer': answer,
38 | 'question': question,
39 | 'moment_start_frame': start_frame,
40 | 'moment_end_frame': end_frame
41 | })
42 |
43 | return result
44 |
45 |
46 | def main():
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument('--ego4d', required=True, type=str,
49 | help='Directory where you placed the Ego4D download.')
50 | parser.add_argument('--qaego4d', required=True, type=str,
51 | help='Path to QaEgo4D answers.json file')
52 | args = parser.parse_args()
53 |
54 | ego4d_dir = Path(args.ego4d)
55 | ego4d_annotations_dir = ego4d_dir / 'v1' / 'annotations'
56 | qaego4d_file = Path(args.qaego4d)
57 | assert ego4d_dir.is_dir()
58 | assert ego4d_annotations_dir.is_dir()
59 | assert qaego4d_file.is_file()
60 |
61 | qaego4d_data = json.loads(qaego4d_file.read_text())
62 | nlq_train, nlq_val = [ego4d_annotations_dir / f'nlq_{split}.json' for split in ('train', 'val')]
63 |
64 | train = convert(nlq_train, qaego4d_data['train'])
65 | val = convert(nlq_val, qaego4d_data['val'])
66 | test = convert(nlq_val, qaego4d_data['test'])
67 |
68 | output_dir = qaego4d_file.parent
69 | (output_dir / 'annotations.train.json').write_text(json.dumps(train))
70 | (output_dir / 'annotations.val.json').write_text(json.dumps(val))
71 | (output_dir / 'annotations.test.json').write_text(json.dumps(test))
72 |
73 |
74 | if __name__ == '__main__':
75 | main()
76 |
--------------------------------------------------------------------------------
/data/datamodule.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from pytorch_lightning import LightningDataModule
4 | from torch.utils.data import DataLoader
5 | from transformers import PreTrainedTokenizerBase as Tokenizer, AutoTokenizer
6 |
7 | from data.emqa_dataset import EmqaDataset
8 |
9 |
10 | # noinspection PyAbstractClass
11 | class EmqaDataModule(LightningDataModule):
12 | tokenizer: Tokenizer
13 |
14 | def __init__(self, config, drop_last=True):
15 | """
16 |
17 | :param config: Needs {tokenizer_name: str,
18 | separate_question_tok_name: Optional[str],
19 | drop_val_last: Optional[bool] = False
20 | workers: int,
21 | use_final_test: bool,
22 | train_bsz: int,
23 | test_bsz: int
24 | } + requirements from EmqaDataset.create_from_cfg
25 | """
26 | super().__init__()
27 | self.config = config
28 | self.drop_last = drop_last
29 | self.drop_val_last = getattr(config, 'drop_val_last', False)
30 | self.train_dataset, self.val_dataset, self.test_dataset = None, None, None
31 | self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
32 |         self.tokenizer.pad_token = self.tokenizer.eos_token  # Set for convenience when using GPT-2, which has no pad token
33 | if getattr(config, 'separate_question_tok_name', None):
34 | self.question_tok = AutoTokenizer.from_pretrained(config.separate_question_tok_name)
35 | else:
36 | self.question_tok = None
37 |
38 | def setup(self, stage: Optional[str] = None) -> None:
39 | super().setup(stage)
40 | self.train_dataset = self._create_data('train')
41 | self.val_dataset = self._create_data('val')
42 | if self.config.use_final_test:
43 | self.test_dataset = self._create_data('test')
44 | else:
45 | self.test_dataset = self.val_dataset or self._create_data('val')
46 |
47 | def _create_data(self, split):
48 | return EmqaDataset.create_from_cfg(self.config, split, self.tokenizer, self.question_tok)
49 |
50 | def common_loader_args(self, dataset):
51 | return dict(num_workers=self.config.workers,
52 | collate_fn=dataset.collate_emv_samples,
53 | pin_memory=True)
54 |
55 | def eval_loader_args(self, dataset):
56 | return dict(**self.common_loader_args(dataset),
57 | shuffle=False,
58 | batch_size=self.config.test_bsz)
59 |
60 | def train_dataloader(self):
61 | assert self.train_dataset
62 | return DataLoader(self.train_dataset,
63 | batch_size=self.config.train_bsz,
64 | shuffle=True,
65 | drop_last=self.drop_last,
66 | **self.common_loader_args(self.train_dataset))
67 |
68 | def val_dataloader(self):
69 | assert self.val_dataset
70 | return DataLoader(self.val_dataset, drop_last=self.drop_val_last,
71 | **self.eval_loader_args(self.val_dataset))
72 |
73 | def test_dataloader(self):
74 | assert self.test_dataset
75 | return DataLoader(self.test_dataset, **self.eval_loader_args(self.test_dataset))
76 |
--------------------------------------------------------------------------------
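A hedged sketch (not part of the repository) of building the data module directly from the dataset config group, mirroring the wiring in `run.py`; it assumes the data layout from the README (`datasets/ego4d` with `annotations.*.json` and the feature HDF5 file):

```python
from hydra import initialize, compose

from data.datamodule import EmqaDataModule

with initialize(config_path='config'):
    cfg = compose(config_name='base')

data = EmqaDataModule(cfg.dataset)
data.setup()
batch = next(iter(data.val_dataloader()))
print(batch['batch_video_features'].shape)   # (test_bsz, max_video_len, feature_dim=2304)
print(batch['batch_question_tokens'].shape)  # (test_bsz, max_question_len)
```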
/model/lt_ct.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | from typing import Tuple, Dict
3 |
4 | import torch
5 | from torch import nn
6 |
7 | from .base import MemoryAugmentedTransformerEmqaModel
8 | from .external.compressive_transformer import CompressiveTransformer
9 |
10 |
11 | class CompressiveTransformerEmqaModel(MemoryAugmentedTransformerEmqaModel):
12 |
13 | def __init__(
14 | self,
15 | pretrained_enc_dec: str,
16 | input_dim: int,
17 | hidden_dim: int = 512,
18 | num_layers: int = 10,
19 | heads: int = 8,
20 | block_length: int = 16,
21 | mem_length: int = 32,
22 | cmem_lengths=None,
23 | compression_factors=4,
24 | use_ltmem: bool = True,
25 | memory_layers=None,
26 | dropout: float = 0.1
27 | ):
28 | super().__init__(pretrained_enc_dec)
29 | self.mem_transformer = CompressiveTransformer(
30 | num_tokens=1, # Embedding is skipped (video features input)
31 | # However, we set emb_dim to automatically use CompressiveTransformer.to_model_dim
32 | emb_dim=input_dim,
33 | dim=hidden_dim, depth=num_layers, heads=heads,
34 | seq_len=block_length, mem_len=mem_length,
35 | cmem_lengths=cmem_lengths,
36 | cmem_ratios=compression_factors,
37 | use_ltmem=use_ltmem, memory_layers=memory_layers,
38 | attn_layer_dropout=dropout, ff_dropout=dropout, attn_dropout=dropout,
39 | reconstruction_attn_dropout=dropout,
40 | gru_gated_residual=False, mogrify_gru=False, ff_glu=False, one_kv_head=False
41 | )
42 | self.mem_transformer.token_emb = nn.Identity()
43 | self.mem_transformer.to_logits = nn.Identity()
44 |
45 | def forward_memory(self, video_features, video_mask,
46 | moment_localization_labels,
47 | question_encoding) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
48 | bsz, seq_len = video_features.shape[0], video_features.shape[1]
49 |
50 | block_length = self.mem_transformer.seq_len
51 | num_blocks = seq_len // block_length + (0 if seq_len % block_length == 0 else 1)
52 |
53 | x = video_features
54 | memories = (None, None, None)
55 | aux_loss_sum = torch.zeros(1, device=x.device)
56 | for i in range(num_blocks):
57 | current_slice = slice(i * block_length, (i + 1) * block_length)
58 | input_block = x[:, current_slice, :]
59 | out, memories, aux_loss = self.mem_transformer(input_block,
60 | memories=memories,
61 | mask=video_mask[:, current_slice])
62 | aux_loss_sum += aux_loss
63 |
64 | if num_blocks > 0:
65 | mem, cmems, ltmem = memories
66 | # memories is a tuple with three items. mem and ltmem are tensors, cmems a list of tensors, all of size
67 | # (num_memory_layers x batch x memory_seq_length x hidden)
68 | # ltmem always has memory_seq_length of either 0 or 1
69 |             # out is (batch x sequence x hidden) from the transformer's last layer.
70 |             # concatenate at sequence level (dim=1), but treat each mem layer as its own vector
71 | # Also, treat all compression layers the same and simply concatenate
72 | memories = mem, *cmems, ltmem
73 | # noinspection PyUnboundLocalVariable
74 | complete_em = torch.cat([out] + [layer_mem for layer_mem in chain(*memories)], dim=1) # B x S x H
75 | else:
76 | complete_em = torch.empty(bsz, 0, self.output_size,
77 | device=x.device, dtype=x.dtype)
78 |
79 | return complete_em, {'ct_aux_loss': aux_loss_sum}
80 |
--------------------------------------------------------------------------------
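A worked example (illustrative numbers) of the chunking arithmetic in `forward_memory`, using `block_length: 64` from `config/model/lt_ct.yaml`:

```python
seq_len, block_length = 900, 64  # ~8 min of video at one feature per 16 frames (30 fps)
num_blocks = seq_len // block_length + (0 if seq_len % block_length == 0 else 1)
print(num_blocks)                # 15; the last block holds only the remaining 4 features
```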
/tools/extract_ego4d_clip_features.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import os
4 | from argparse import ArgumentParser
5 | from dataclasses import dataclass
6 | from functools import partial
7 | from multiprocessing import Pool, Manager, JoinableQueue
8 | from pathlib import Path
9 | from typing import List
10 |
11 | import torch
12 | from tqdm import tqdm
13 |
14 |
15 | def existing_path(p: str):
16 | p = Path(p)
17 | assert p.exists(), p
18 | return p
19 |
20 |
21 | @dataclass
22 | class WorkItem:
23 | video_uid: str
24 | clip_uid: str
25 | clip_start_frame: int
26 | clip_end_frame: int
27 |
28 | def output_file(self, output_dir: Path):
29 | return output_dir / f'{self.clip_uid}.pt'
30 |
31 | def do_work(self, video_features_dir: Path, output_dir: Path, feature_window_stride: int,
32 | progress_queue: JoinableQueue):
33 | video_features_file = video_features_dir / f'{self.video_uid}.pt'
34 | output_file = self.output_file(output_dir)
35 | start_feature = self.clip_start_frame // feature_window_stride
36 | end_feature = math.ceil(self.clip_end_frame / feature_window_stride)
37 | video_features = torch.load(str(video_features_file))
38 | clip_features = video_features[start_feature:end_feature]
39 | torch.save(clip_features, str(output_file))
40 | progress_queue.put_nowait(1)
41 |
42 |
43 | def _extract_work_items(annotations) -> List[WorkItem]:
44 | work_items = []
45 | for video in annotations['videos']:
46 | for clip in video['clips']:
47 | clip_uid = clip['clip_uid']
48 |             # Wanted to use video_start/end_frame, but a metadata bug in the Ego4D data can yield
49 |             # zero-length clips, so the frame indices are computed from seconds instead.
50 | # 30 fps is safe to assume for canonical videos
51 | start = int(clip['video_start_sec'] * 30)
52 | end = int(clip['video_end_sec'] * 30)
53 | assert start != end, f'{start}, {end}, {clip_uid}'
54 | work_items.append(WorkItem(
55 | video['video_uid'], clip_uid,
56 | start, end
57 | ))
58 | return work_items
59 |
60 |
61 | def main(nlq_file: Path, video_features_dir: Path, output_dir: Path,
62 | feature_window_stride=16, num_workers=16):
63 | annotations = json.loads(nlq_file.read_text())
64 | all_work_items = _extract_work_items(annotations)
65 | all_work_items = [w for w in all_work_items if not w.output_file(output_dir).is_file()]
66 | print(f'Will extract {len(all_work_items)} clip features...')
67 |
68 | with Pool(num_workers) as pool, Manager() as manager:
69 | queue = manager.Queue()
70 | pool.map_async(partial(WorkItem.do_work,
71 | video_features_dir=video_features_dir,
72 | output_dir=output_dir,
73 | feature_window_stride=feature_window_stride,
74 | progress_queue=queue),
75 | all_work_items)
76 | with tqdm(total=len(all_work_items)) as pbar:
77 | while pbar.n < len(all_work_items):
78 | pbar.update(queue.get(block=True))
79 |
80 |
81 | def cli_main():
82 | parser = ArgumentParser()
83 | parser.add_argument('--annotation_file', type=existing_path, required=True,
84 | help='Ego4D Annotation JSON file containing annotations for which to extract clip features. '
85 | 'Should contain "videos" array, where each item has a "clips" array '
86 | '(e.g. NLQ annotations).')
87 | parser.add_argument('--video_features_dir', type=existing_path, required=True,
88 | help='Directory where to find pre-extracted Ego4D video features.')
89 | parser.add_argument('--output_dir', type=existing_path, required=True,
90 | help='Directory where to place output files. They are named after the clip_uid.')
91 | parser.add_argument('--feature_window_stride', type=int, default=16,
92 | help='Stride of window used to produce the features in video_features_dir')
93 | parser.add_argument('--num_workers', type=int, default=os.cpu_count(),
94 | help='Number of parallel workers')
95 |
96 | args = parser.parse_args()
97 | main(args.annotation_file, args.video_features_dir, args.output_dir,
98 | args.feature_window_stride, args.num_workers)
99 |
100 |
101 | if __name__ == '__main__':
102 | cli_main()
103 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # QaEgo4D — Episodic-Memory-Based Question Answering on Egocentric Videos
2 |
3 | This repository contains the code to reproduce the results of the paper "Where did I leave my keys? —
4 | Episodic-Memory-Based Question Answering on Egocentric Videos". See our
5 | [paper](https://openaccess.thecvf.com/content/CVPR2022W/Ego4D-EPIC/papers/Barmann_Where_Did_I_Leave_My_Keys_-_Episodic-Memory-Based_Question_Answering_CVPRW_2022_paper.pdf)
6 | for more details.
7 |
8 | ## Abstract
9 |
10 | Humans have a remarkable ability to organize, compress and retrieve episodic memories throughout their daily life.
11 | Current AI systems, however, lack comparable capabilities as they are mostly constrained to an analysis with access to
12 | the raw input sequence, assuming an unlimited amount of data storage which is not feasible in realistic deployment
13 | scenarios. For instance, existing Video Question Answering (VideoQA) models typically reason over the video while
14 | already being aware of the question, thus requiring the complete video to be stored in case the question is not known in
15 | advance.
16 |
17 | In this paper, we address this challenge with three main contributions:
18 | First, we propose the Episodic Memory Question Answering (EMQA) task as a specialization of VideoQA. Specifically, EMQA
19 | models are constrained to keep only a constant-sized representation of the video input, thus automatically limiting the
20 | computation requirements at query time. Second, we introduce a new egocentric VideoQA dataset
21 | called QaEgo4D. It is by far the largest egocentric VideoQA dataset, and its
22 | video length is unprecedented among VideoQA datasets in general. Third, we present extensive experiments on the new
23 | dataset, comparing various baseline models in both the VideoQA and the EMQA setting. To facilitate future
24 | research on egocentric VideoQA as well as episodic memory representation and retrieval, we publish our code and dataset.
25 |
26 | ## Using the dataset
27 |
28 | To use the QaEgo4D dataset introduced in our paper, please follow these
29 | steps:
30 |
31 | 1. QaEgo4D builds on the Ego4D v1 videos and annotations. If you do not have
32 | access to Ego4D already, you should follow the steps at the [Ego4D website](https://ego4d-data.org/docs/start-here/)
33 | 2. To get access to QaEgo4D, please fill out
34 | this [Google form](https://forms.gle/Gxs93wwC5YYJtjqh8). You will need to sign a license agreement, but there are no
35 | fees if you use the data for non-commercial research purposes.
36 | 3. Download the Ego4D annotations and NLQ clips if you have not done so already. See
37 | the [Ego4D website](https://ego4d-data.org/docs/start-here/)
38 | 4. After you have access to both Ego4D and QaEgo4D, you can generate
39 | self-contained VideoQA annotation files
40 | using `python3 tools/create_pure_videoqa_json.py --ego4d /path/to/ego4d --qaego4d /path/to/qaego4d/answers.json`.
41 | `/path/to/ego4d` is the directory where you placed the Ego4D download, containing
42 | the `v1/annotations/nlq_{train,val}.json` files. This produces `/path/to/qaego4d/annotations.{train,val,test}.json`.
43 |
44 | The `annotations.*.json` files are JSON arrays, where each object has the following structure:
45 | ```
46 | {
47 | "video_id": "abcdef00-0000-0000-0000-123456789abc",
48 | "sample_id": "12345678-1234-1234-1234-123456789abc_3",
49 | "question": "Where did I leave my keys?",
50 | "answer": "on the table",
51 | "moment_start_frame": 42,
52 | "moment_end_frame": 53
53 | }
54 | ```
55 |
56 | ## Code
57 |
58 | In order to reproduce the experiments, prepare your workspace:
59 |
60 | 1. Follow the instructions above to get the dataset and features.
61 | 2. Create a conda / python virtual environment (Python 3.9.7)
62 | 3. Install the requirements in `requirements.txt`
63 | 4. Prepare the features:
64 | 1. Download the pre-extracted [Ego4D features](https://ego4d-data.org/docs/data/features/) if you have not done so
65 | already.
66 | 2. Ego4D features are provided for each canonical video, while the NLQ task and thus also VideoQA works on the
67 | canonical clips. To extract features for each clip,
68 | use `python tools/extract_ego4d_clip_features.py --annotation_file /path/to/ego4d/v1/annotations/nlq_train.json --video_features_dir /path/to/ego4d/v1/slowfast8x8_r101_k400 --output_dir /choose/your/clip_feature_dir`
69 | and do the same again with `nlq_val.json`
70 | 3. Aggregate the features into a single file
71 | using `python tools/aggregate_features_to_hdf5.py /choose/your/clip_feature_dir`. This
72 | produces `slowfast8x8_r101_k400.hdf5` in the current working directory.
73 | 5. Place or link the QaEgo4D data (`annotations.*.json`
74 | and `slowfast8x8_r101_k400.hdf5`) into `datasets/ego4d`.
75 |
76 | To run an experiment, use `bash experiment/run.sh`. All configuration files can be found in the `config` dir.
77 |
78 |
79 | ## Cite
80 | ```
81 | @InProceedings{Baermann_2022_CVPR,
82 | author = {B\"armann, Leonard and Waibel, Alex},
83 | title = {Where Did I Leave My Keys? - Episodic-Memory-Based Question Answering on Egocentric Videos},
84 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
85 | month = {June},
86 | year = {2022},
87 | pages = {1560-1568}
88 | }
89 | ```
90 |
--------------------------------------------------------------------------------
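A hedged sketch of loading one of the generated annotation files described in the README above; the path follows step 5 of the setup instructions (`datasets/ego4d`), so adjust it to your layout:

```python
import json
from pathlib import Path

annotations = json.loads(Path('datasets/ego4d/annotations.val.json').read_text())
print(len(annotations))
sample = annotations[0]
print(sample['question'], '->', sample['answer'])
print(sample['moment_start_frame'], sample['moment_end_frame'])
```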
/eval/eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | from argparse import ArgumentParser
3 | from collections import defaultdict
4 | from pathlib import Path
5 | from pprint import pprint
6 | from typing import List, Dict, Any
7 |
8 | from nltk.translate.meteor_score import meteor_score
9 | from rouge_score.rouge_scorer import RougeScorer
10 | from rouge_score.tokenize import tokenize
11 | from sacrebleu.metrics import BLEU, BLEUScore
12 |
13 | from .util import AverageMeter
14 |
15 |
16 | # Check whether to use
17 | # - https://github.com/Maluuba/nlg-eval
18 | # - https://github.com/hwanheelee1993/KPQA
19 | def calc_metrics(predictions: List[str], gold_annotations: List[List[str]]) -> Dict[str, Any]:
20 | """
21 | Calculate metrics.
22 |
23 | Parameters
24 | ----------
25 | predictions : list[str]
26 | The list of predictions
27 | gold_annotations : list[list[str]]
28 | A list with the same length as predictions.
29 | Each element is a list of possible target candidates for the corresponding prediction.
30 | All elements should have the same length.
31 | """
32 | if len(predictions) != len(gold_annotations):
33 | raise ValueError(f'{len(predictions)} != {len(gold_annotations)}')
34 | ref_count = len(gold_annotations[0])
35 | if any(len(refs) != ref_count for refs in gold_annotations):
36 | raise ValueError(f'All refs should have the same length {ref_count}!')
37 |
38 | acc = _calc_accuracy(predictions, gold_annotations)
39 | bleu = _calc_bleu(predictions, gold_annotations)
40 | rouge = _calc_rouge(predictions, gold_annotations)
41 | meteor = _calc_meteor(predictions, gold_annotations)
42 |
43 | return {
44 | 'plain_acc': acc,
45 | **bleu,
46 | 'ROUGE': rouge['rougeL']['f'],
47 | **_flatten_dict(rouge, prefix='ROUGE.'),
48 | 'METEOR': meteor
49 | }
50 |
51 |
52 | def _calc_accuracy(predictions, gold_annotations):
53 | correct = 0
54 | for pred, possible_refs in zip(predictions, gold_annotations):
55 | if any(ref == pred for ref in possible_refs):
56 | correct += 1
57 | total = len(predictions)
58 | return correct / total
59 |
60 |
61 | def _calc_meteor(predictions, gold_annotations):
62 | score = AverageMeter()
63 | for pred, possible_refs in zip(predictions, gold_annotations):
64 | pred = tokenize(pred, None)
65 | # https://github.com/cmu-mtlab/meteor/blob/master/src/edu/cmu/meteor/util/Normalizer.java
66 | possible_refs = [tokenize(x, None) for x in possible_refs]
67 | score.update(meteor_score(possible_refs, pred))
68 | return score.avg
69 |
70 |
71 | def _calc_rouge(predictions, gold_annotations) -> Dict[str, Dict[str, float]]:
72 | rouge_scorer = RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=False)
73 | rouge = defaultdict(lambda: defaultdict(AverageMeter))
74 | for pred, possible_refs in zip(predictions, gold_annotations):
75 | sample_result = {}
76 | for ref in possible_refs:
77 | single_ref_result = rouge_scorer.score(ref, pred)
78 | for k, scores in single_ref_result.items():
79 | existing_result_dict = sample_result.setdefault(k, {})
80 | if existing_result_dict.get('f', -1) < scores.fmeasure:
81 | existing_result_dict.update(f=scores.fmeasure, p=scores.precision, r=scores.recall)
82 | for k, best_scores in sample_result.items():
83 | rouge[k]['p'].update(best_scores['p'])
84 | rouge[k]['r'].update(best_scores['r'])
85 | rouge[k]['f'].update(best_scores['f'])
86 | return {
87 | rouge_type: {
88 | measure: score.avg
89 | for measure, score in results.items()
90 | } for rouge_type, results in rouge.items()
91 | }
92 |
93 |
94 | def _calc_bleu(predictions, gold_annotations) -> Dict[str, float]:
95 | refs_transposed = [
96 | [refs[i] for refs in gold_annotations]
97 | for i in range(len(gold_annotations[0]))
98 | ]
99 | bleu: BLEUScore = BLEU().corpus_score(predictions, refs_transposed)
100 | return {
101 | 'BLEU': bleu.score,
102 | 'BLEU.bp': bleu.bp,
103 | 'BLEU.ratio': bleu.ratio,
104 | 'BLEU.hyp_len': float(bleu.sys_len),
105 | 'BLEU.ref_len': float(bleu.ref_len),
106 | }
107 |
108 |
109 | def _flatten_dict(d, prefix=''):
110 | result = {}
111 | for k, v in d.items():
112 | my_key = prefix + k
113 | if isinstance(v, dict):
114 | result.update(_flatten_dict(v, prefix=my_key + '.'))
115 | else:
116 | result[my_key] = v
117 | return result
118 |
119 |
120 | def main():
121 | parser = ArgumentParser('Eval output file')
122 | parser.add_argument('--gold_answers', type=str, required=True,
123 | help='Path to answers.json, containing mapping from sample_id to answer')
124 | parser.add_argument('eval_file', type=str,
125 | help='JSON File to evaluate. Should contain mapping from sample_id '
126 | 'to hypothesis or array of hypotheses')
127 | args = parser.parse_args()
128 |
129 | gold_answers = json.loads(Path(args.gold_answers).read_text())
130 | hypotheses = json.loads(Path(args.eval_file).read_text())
131 | if isinstance(next(iter(hypotheses.values())), list):
132 | hypotheses = {k: v[0] for k, v in hypotheses.items()}
133 | assert len(hypotheses.keys() - gold_answers.keys()) == 0, 'No gold answer for some hypotheses'
134 |
135 | gold_and_hypo = [(gold_answers[k], hypotheses[k]) for k in hypotheses.keys()]
136 | hypo_list = [h for g, h in gold_and_hypo]
137 | gold_list = [[g] for g, h in gold_and_hypo]
138 | metrics = calc_metrics(hypo_list, gold_list)
139 |
140 | pprint(metrics)
141 |
142 |
143 | if __name__ == '__main__':
144 | main()
145 |
--------------------------------------------------------------------------------
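A hedged sketch of calling `calc_metrics` directly on toy strings (not real model outputs); `meteor_score` additionally needs the NLTK wordnet data to be available:

```python
from eval.eval import calc_metrics

predictions = ['on the table', 'in the kitchen drawer']
gold_annotations = [['on the table'], ['in the drawer']]  # one reference list per prediction
metrics = calc_metrics(predictions, gold_annotations)
print(metrics['plain_acc'], metrics['ROUGE'], metrics['BLEU'])
```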
/data/emqa_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | from abc import ABC
4 | from pathlib import Path
5 | from typing import Dict, Union, List
6 | from typing import Literal
7 |
8 | import h5py
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.nn.utils.rnn import pad_sequence
12 | from torch.utils.data import Dataset
13 | from transformers import PreTrainedTokenizerBase as Tokenizer
14 |
15 | EmqaAnnotationFileDictKey = Literal[
16 | 'sample_id',
17 | 'video_id',
18 | 'question',
19 | 'answer',
20 | 'moment_start_frame', # optional
21 | 'moment_end_frame', # optional
22 | ]
23 | EmqaSampleDictKey = Literal[
24 | 'video_id',
25 | 'sample_id',
26 | 'question_text',
27 | 'answer_text',
28 | 'video_features',
29 | 'moment_label'
30 | ]
31 | EmqaBatchDictKey = Literal[
32 | 'batch_video_ids', # List[int]
33 | 'batch_sample_ids', # List[int]
34 | 'batch_question_tokens', # int, B x L
35 | 'batch_question_mask', # bool, B x L
36 | 'batch_video_features', # float, B x N x H
37 | 'batch_video_mask', # bool, B x N
38 | 'batch_answer_texts', # List[str]
39 | 'batch_answer_tokens', # int, B x T
40 | 'batch_answer_mask', # bool, B x T
41 | 'batch_moment_localization_labels' # float, B x N, optional
42 | ]
43 | EmqaBatch = Dict[EmqaBatchDictKey, Union[torch.Tensor, List[int], List[str]]]
44 | EmqaSample = Dict[EmqaSampleDictKey, Union[torch.Tensor, int, str]]
45 | DatasetSplitName = Literal["train", "val", "test"]
46 |
47 |
48 | class EmqaDataset(Dataset, ABC):
49 | tokenizer: Tokenizer
50 | split: DatasetSplitName
51 |
52 | def __init__(self, video_features_file: Path,
53 | tokenizer: Tokenizer,
54 | split: DatasetSplitName,
55 | annotations: List[Dict[EmqaAnnotationFileDictKey, Union[int, str]]],
56 | normalize_video_features=False,
57 | frames_per_feature=16,
58 | separate_question_tokenizer: Tokenizer = None):
59 | super().__init__()
60 | self.video_features_file = video_features_file
61 | self.tokenizer = tokenizer
62 | self.separate_question_tokenizer = separate_question_tokenizer or tokenizer
63 | self.split = split
64 | self.annotations = annotations
65 | self.normalize_video_features = normalize_video_features
66 | self.frames_per_feature = frames_per_feature
67 |
68 | def __len__(self) -> int:
69 | return len(self.annotations)
70 |
71 | def __getitem__(self, index) -> EmqaSample:
72 | video_id = self.annotations[index]['video_id']
73 | sample_id = self.annotations[index]['sample_id']
74 | question = self.annotations[index]['question']
75 | answer = self.annotations[index]['answer']
76 | gt_start_frame = self.annotations[index].get('moment_start_frame')
77 | gt_end_frame = self.annotations[index].get('moment_end_frame')
78 | video_features = self._get_video_features(video_id)
79 |
80 | sample: EmqaSample = {
81 | 'video_id': video_id,
82 | 'sample_id': sample_id,
83 | 'video_features': video_features,
84 | 'question_text': question,
85 | 'answer_text': answer
86 | }
87 | if gt_start_frame is not None and gt_end_frame is not None:
88 | start = gt_start_frame // self.frames_per_feature
89 | # ensure at least one target frame even if gt_start == gt_end
90 | end = math.ceil(gt_end_frame / self.frames_per_feature)
91 | if start == end:
92 | end += 1
93 | sample['moment_label'] = torch.tensor([start, end], dtype=torch.int)
94 | return sample
95 |
96 | def _get_video_features(self, video_id):
97 | with h5py.File(self.video_features_file, 'r') as hdf5_file:
98 | features = torch.from_numpy(hdf5_file[video_id][:]).float()
99 | if self.normalize_video_features:
100 | features = F.normalize(features, dim=1)
101 | return features
102 |
103 | def collate_emv_samples(self, batch: List[EmqaSample]) -> EmqaBatch:
104 | video_ids = [b['video_id'] for b in batch]
105 | sample_ids = [b['sample_id'] for b in batch]
106 | video_features = [b['video_features'] for b in batch]
107 | questions = [b['question_text'] for b in batch]
108 | answers = [b['answer_text'] for b in batch]
109 |
110 | answers_with_eos = [a + self.tokenizer.eos_token for a in answers]
111 | tok_args = dict(padding=True, return_tensors='pt', add_special_tokens=False)
112 | question_tok = self.separate_question_tokenizer(questions, **tok_args)
113 | answers_tok = self.tokenizer(answers_with_eos, **tok_args)
114 |
115 | video_features_padded = pad_sequence(video_features, batch_first=True)
116 | video_mask = pad_sequence([torch.ones(len(v)) for v in video_features], batch_first=True).bool()
117 |
118 | result: EmqaBatch = {
119 | 'batch_video_ids': video_ids,
120 | 'batch_sample_ids': sample_ids,
121 | 'batch_question_tokens': question_tok['input_ids'],
122 | 'batch_question_mask': question_tok['attention_mask'],
123 | 'batch_answer_texts': answers,
124 | 'batch_answer_tokens': answers_tok['input_ids'],
125 | 'batch_answer_mask': answers_tok['attention_mask'],
126 | 'batch_video_features': video_features_padded,
127 | 'batch_video_mask': video_mask
128 | }
129 | if 'moment_label' in batch[0]:
130 | moment_labels = torch.zeros(len(batch), video_features_padded.shape[1])
131 | for i, b in enumerate(batch):
132 | gt_start, gt_end = b['moment_label']
133 | moment_labels[i, gt_start:gt_end] = 1
134 | # add smoothing before/after gt_start & end ?
135 | result['batch_moment_localization_labels'] = moment_labels
136 |
137 | return result
138 |
139 | @classmethod
140 | def create_from_cfg(cls, cfg, split: DatasetSplitName, tokenizer: Tokenizer,
141 | separate_question_tok: Tokenizer = None):
142 | """
143 | Create EmqaDataset from cfg.
144 |
145 | :param cfg: Needs data_dir, feature_type
146 | :param split: train / val / test
147 |         :param tokenizer: Hugging Face tokenizer
148 | :param separate_question_tok: separate tok for questions
149 | :return: EmqaDataset
150 | """
151 | ds_dir = Path(cfg.data_dir)
152 | video_features_file = ds_dir / f'{split}-{cfg.feature_type}.hdf5'
153 | if not video_features_file.is_file():
154 | video_features_file = ds_dir / f'{cfg.feature_type}.hdf5'
155 | if not video_features_file.is_file():
156 | raise ValueError(str(video_features_file))
157 | annotation_file = ds_dir / f'annotations.{split}.json'
158 | annotations = json.loads(annotation_file.read_text())
159 | return cls(video_features_file, tokenizer, split, annotations,
160 | separate_question_tokenizer=separate_question_tok)
161 |
--------------------------------------------------------------------------------
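A worked example (illustrative numbers) of the frame-to-feature-window conversion in `__getitem__`, with the default `frames_per_feature=16` and the moment from the README sample:

```python
import math

frames_per_feature = 16
gt_start_frame, gt_end_frame = 42, 53
start = gt_start_frame // frames_per_feature        # 2
end = math.ceil(gt_end_frame / frames_per_feature)  # 4
print(start, end)  # feature windows 2 and 3 are labelled positive in the collated moment labels
```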
/model/sparse_transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 |
3 | import torch
4 | from torch import nn
5 | from transformers import AutoModelForSeq2SeqLM
6 | from transformers.modeling_outputs import BaseModelOutput
7 | from transformers.models.longformer.modeling_longformer import LongformerModel
8 | from transformers.models.bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusForConditionalGeneration
9 |
10 | from model.base import EmqaBaseModel
11 | from model.moment_loc import TransformerMomentLocalizationLossModule
12 |
13 |
14 | # noinspection PyAbstractClass
15 | class LongformerVqaModel(EmqaBaseModel):
16 |     # Actually, this is not really an EMQA model since it uses full attention over the input
17 |
18 | def __init__(self, input_size: int,
19 | pretrained_enc_dec: str,
20 | pretrained_longformer: str,
21 | moment_localization_loss: TransformerMomentLocalizationLossModule = None) -> None:
22 | super().__init__()
23 | # Actually, only using the decoder of the enc_dec model. Just want to use LM head + cross attention
24 | self.moment_localization_loss = moment_localization_loss
25 | self.enc_dec = AutoModelForSeq2SeqLM.from_pretrained(pretrained_enc_dec)
26 | self.enc_dec.encoder.block = None
27 | self.enc_dec.encoder.final_layer_norm = None
28 | self.enc_dec.decoder.config.output_attentions = True
29 |
30 | self.longformer = LongformerModel.from_pretrained(pretrained_longformer, add_pooling_layer=False)
31 |
32 | longformer_h = self.longformer.get_input_embeddings().embedding_dim
33 | if input_size != longformer_h:
34 | self.transform_visual = nn.Linear(input_size, longformer_h, bias=False)
35 | else:
36 | self.transform_visual = nn.Identity()
37 | decoder_h = self.enc_dec.get_input_embeddings().embedding_dim
38 | if longformer_h != decoder_h:
39 | self.transform_context = nn.Linear(longformer_h, decoder_h, bias=False)
40 | else:
41 | self.transform_context = nn.Identity()
42 |
43 | def teacher_forcing_forward(self, question_tokens, question_mask,
44 | video_features, video_mask,
45 | answer_tokens, answer_mask,
46 | batch_sample_ids,
47 | moment_localization_labels) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
48 | context, context_mask = self.forward_encoders(question_tokens, question_mask, video_features, video_mask)
49 | output = self.enc_dec(labels=answer_tokens, decoder_attention_mask=answer_mask,
50 | encoder_outputs=(context,), attention_mask=context_mask)
51 | loss_dict = {'lm_loss': output.loss}
52 | if self.moment_localization_loss:
53 | loss_dict['moment_localization'] = self.moment_localization_loss(
54 | question_tokens, output, moment_localization_labels, video_mask)
55 | return loss_dict, output.logits
56 |
57 | def autoregressive_forward(self, question_tokens, question_mask, video_features, video_mask) -> torch.Tensor:
58 | context, context_mask = self.forward_encoders(question_tokens, question_mask, video_features, video_mask)
59 | # noinspection PyTypeChecker
60 | enc_out = BaseModelOutput(last_hidden_state=context)
61 | return self.enc_dec.generate(encoder_outputs=enc_out, attention_mask=context_mask)
62 |
63 | def forward_encoders(self, question_tokens, question_mask, video_features, video_mask):
64 | longformer_in = torch.cat([
65 | self.transform_visual(video_features),
66 | self.longformer.get_input_embeddings()(question_tokens)
67 | ], dim=1)
68 | longformer_mask = torch.cat([
69 | video_mask,
70 | question_mask
71 | ], dim=1)
72 | # initialize to global attention to be deactivated for all tokens
73 | global_attention_mask = torch.zeros_like(longformer_mask)
74 | # Set global attention to question tokens
75 | global_attention_mask[:, -question_tokens.shape[1]:] = 1
76 |
77 | context = self.longformer(
78 | inputs_embeds=longformer_in, attention_mask=longformer_mask,
79 | global_attention_mask=global_attention_mask,
80 | ).last_hidden_state
81 | context = self.transform_context(context)
82 | return context, longformer_mask
83 |
84 |
85 | class BigBirdVqaModel(EmqaBaseModel):
86 |     # Actually, this is not really an EMQA model since it uses full attention over the input
87 |
88 | def __init__(self, input_size: int,
89 | pretrained_bigbird: str,
90 | moment_localization_loss: TransformerMomentLocalizationLossModule = None,
91 | gradient_checkpointing=False) -> None:
92 | super().__init__()
93 | self.moment_localization_loss = moment_localization_loss
94 | self.bigbird = BigBirdPegasusForConditionalGeneration.from_pretrained(pretrained_bigbird)
95 |
96 | bigbird_h = self.bigbird.get_input_embeddings().embedding_dim
97 | if input_size != bigbird_h:
98 | self.transform_visual = nn.Linear(input_size, bigbird_h, bias=False)
99 | else:
100 | self.transform_visual = nn.Identity()
101 |
102 | self.gradient_checkpointing = gradient_checkpointing
103 | if gradient_checkpointing:
104 | self.bigbird.gradient_checkpointing_enable()
105 |
106 | def teacher_forcing_forward(self, question_tokens, question_mask,
107 | video_features, video_mask,
108 | answer_tokens, answer_mask,
109 | batch_sample_ids,
110 | moment_localization_labels) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
111 | enc_in, enc_mask = self.prepare_encoder_input(question_tokens, question_mask, video_features, video_mask)
112 | output = self.bigbird(labels=answer_tokens, decoder_attention_mask=answer_mask,
113 | inputs_embeds=enc_in, attention_mask=enc_mask,
114 | use_cache=not self.gradient_checkpointing)
115 | loss_dict = {'lm_loss': output.loss}
116 | if self.moment_localization_loss:
117 | loss_dict['moment_localization'] = self.moment_localization_loss(
118 | question_tokens, output, moment_localization_labels, video_mask)
119 | return loss_dict, output.logits
120 |
121 | def autoregressive_forward(self, question_tokens, question_mask, video_features, video_mask) -> torch.Tensor:
122 | enc_in, enc_mask = self.prepare_encoder_input(question_tokens, question_mask, video_features, video_mask)
123 | return self.bigbird.generate(inputs_embeds=enc_in, attention_mask=enc_mask)
124 |
125 | def prepare_encoder_input(self, question_tokens, question_mask, video_features, video_mask):
126 | encoder_in = torch.cat([
127 | self.transform_visual(video_features),
128 | self.bigbird.get_input_embeddings()(question_tokens)
129 | ], dim=1)
130 | encoder_mask = torch.cat([
131 | video_mask,
132 | question_mask
133 | ], dim=1)
134 | return encoder_in, encoder_mask
135 |
--------------------------------------------------------------------------------
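Note on the encoder input layout shared by both models above: projected video features are concatenated with question token embeddings along the sequence dimension, and the Longformer variant additionally enables global attention only on the question tokens. A minimal, self-contained sketch of that layout with made-up dimensions (no HuggingFace weights are loaded):

import torch
from torch import nn

batch, video_len, question_len = 2, 16, 8
feature_dim, hidden = 512, 768  # illustrative sizes, not this repo's config values

video_features = torch.randn(batch, video_len, feature_dim)
video_mask = torch.ones(batch, video_len, dtype=torch.long)
question_embeds = torch.randn(batch, question_len, hidden)  # stand-in for get_input_embeddings()(question_tokens)
question_mask = torch.ones(batch, question_len, dtype=torch.long)

transform_visual = nn.Linear(feature_dim, hidden, bias=False)  # projects visual features to the encoder width

encoder_in = torch.cat([transform_visual(video_features), question_embeds], dim=1)  # B x (N+Q) x H
encoder_mask = torch.cat([video_mask, question_mask], dim=1)                        # B x (N+Q)

# Longformer variant only: enable global attention on the question tokens at the end of the sequence
global_attention_mask = torch.zeros_like(encoder_mask)
global_attention_mask[:, -question_len:] = 1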
/model/importance_teacher.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import torch
4 | from hydra import initialize, compose
5 | from omegaconf import DictConfig
6 | from torch import nn
7 | from tqdm import tqdm
8 |
9 | from data.datamodule import EmqaDataModule
10 | from data.emqa_dataset import EmqaBatch
11 | from model.base import MemoryAugmentedTransformerEmqaModel
12 |
13 |
14 | # Mean Pooling - Take attention mask into account for correct averaging
15 | def mean_pooling(model_output, attention_mask):
16 | token_embeddings = model_output[0] # First element of model_output contains all token embeddings
17 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
18 | sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
19 | sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
20 | return sum_embeddings / sum_mask
21 |
22 |
23 | # noinspection PyAbstractClass
24 | class ImportanceTeacherVqaModel(MemoryAugmentedTransformerEmqaModel):
25 | # Actually this is not really a MemoryAugmentedTransformerEmqaModel since it uses full attention over the input
26 |
27 | def __init__(self, fragment_length: int,
28 | input_size: int,
29 | pretrained_enc_dec: str) -> None:
30 | super().__init__(pretrained_enc_dec)
31 | self.clip_avg = nn.AvgPool1d(fragment_length, stride=1)
32 | self.question_embedding = self.transformer.get_encoder()
33 |
34 | hidden = self.transformer.get_input_embeddings().embedding_dim
35 | self.query_attn = nn.MultiheadAttention(embed_dim=hidden,
36 | num_heads=1, batch_first=True)
37 |
38 | sentence_emb = self.question_embedding.get_input_embeddings().embedding_dim
39 | if sentence_emb != hidden:
40 | self.transform_query = nn.Linear(sentence_emb, hidden, bias=False)
41 | else:
42 | self.transform_query = nn.Identity()
43 | if input_size != hidden:
44 | self.transform_visual = nn.Linear(input_size, hidden, bias=False)
45 | else:
46 | self.transform_visual = nn.Identity()
47 |
48 | def forward_encoders(self, question_tokens, question_mask, video_features, video_mask, moment_localization_labels):
49 | attn_context = self.attend_to_fragments(question_tokens, question_mask, video_features, video_mask)
50 | attn_mask = torch.ones(attn_context.shape[:-1], device=attn_context.device, dtype=torch.bool)
51 |
52 | context, context_mask = self._prepare_context(attn_context, attn_mask, question_tokens, question_mask)
53 | return context, context_mask, attn_context, attn_mask, {}
54 |
55 | def attend_to_fragments(self, question_tokens, question_mask,
56 | visual_input, visual_mask,
57 | return_attn_scores=False):
58 | # visual input: B x N x H. Want to avg pool over N => permute
59 | fragments = self.clip_avg(visual_input.permute(0, 2, 1))
60 | fragments = fragments.permute(0, 2, 1) # permute back to B x N' x H
61 | # bsz x num_fragments : True <==> valid entry
62 | fragments_mask = (self.clip_avg(visual_mask.to(dtype=torch.float)) == 1).squeeze()
63 |
64 | # modify samples where fragment_length > input_length with correct averaging over non-padded values only
65 | # so that even short samples are represented by at least one fragment
66 | input_lengths = visual_mask.sum(dim=1)
67 | num_fragments = fragments_mask.sum(dim=1)
68 | too_short_videos = num_fragments == 0
69 | for b in too_short_videos.nonzero():
70 | b = b.item()
71 | fragments[b, 0] = visual_input[b, visual_mask[b]].sum(dim=0) / input_lengths[b]
72 | fragments_mask[b, 0] = 1
73 | num_fragments[b] = 1
74 |
75 | with torch.no_grad():
76 | model_output = self.question_embedding(question_tokens, question_mask)
77 | query = mean_pooling(model_output, question_mask)
78 |
79 | fragments = self.transform_visual(fragments)
80 | query = self.transform_query(query).unsqueeze(dim=1) # add fake seq dimension
81 |
82 | # attn_mask: batch x queries x keys = bsz x 1 x num_fragments.
83 | # True <==> corresponding position is _not_ allowed to attend
84 | attn_mask = ~fragments_mask[:, None, :]
85 | context_token, attn_weights = self.query_attn(query, fragments, fragments, attn_mask=attn_mask,
86 | need_weights=return_attn_scores) # B x 1 x H
87 | if return_attn_scores:
88 | # trim padding indices away
89 | attn_weights = attn_weights.squeeze()
90 | return context_token, [w[:length] for w, length in zip(attn_weights, num_fragments)]
91 | else:
92 | return context_token
93 |
94 |
95 | def _extract_attn_scores(model: ImportanceTeacherVqaModel, train_iterator):
96 | all_scores = {}
97 | batch: EmqaBatch
98 | for batch in train_iterator:
99 | _, attn_scores = model.attend_to_fragments(
100 | batch['batch_question_tokens'].cuda(),
101 | batch['batch_question_mask'].cuda(),
102 | batch['batch_video_features'].cuda(),
103 | batch['batch_video_mask'].cuda(),
104 | return_attn_scores=True
105 | )
106 | for idx, scores in zip(batch['batch_sample_ids'], attn_scores):
107 | all_scores[idx] = scores.cpu()
108 |
109 | return all_scores
110 |
111 |
112 | @torch.no_grad()
113 | def main(output_file: str, checkpoint_path: str, config: DictConfig):
114 | checkpoint = torch.load(checkpoint_path)
115 | state_dict = {k[len('model.'):]: v for k, v in checkpoint['state_dict'].items()}
116 | model_cfg = dict(config.model)
117 | model_cfg.pop('_target_')
118 | model = ImportanceTeacherVqaModel(**model_cfg)
119 | # noinspection PyTypeChecker
120 | model.load_state_dict(state_dict)
121 |
122 | data = EmqaDataModule(config.dataset, drop_last=False)
123 | data.prepare_data()
124 | data.setup()
125 |
126 | model.cuda()
127 | result = {}
128 | with tqdm(data.train_dataloader(), 'Train') as train_iterator:
129 | result.update(_extract_attn_scores(model, train_iterator))
130 | with tqdm(data.val_dataloader(), 'Val') as val_iterator:
131 | result.update(_extract_attn_scores(model, val_iterator))
132 |
133 | torch.save(result, output_file)
134 |
135 |
136 | def cli_main():
137 | parser = ArgumentParser(description='Save attention scores from ImportanceTeacherVqaModel')
138 |
139 | parser.add_argument('--bsz', type=int, default=32)
140 | parser.add_argument('--output_file', type=str, required=True, help='Where to save attention scores.')
141 | parser.add_argument('--checkpoint_path', type=str, required=True,
142 | help='Where to load the trained model from. PyTorch Lightning checkpoint from SimpleVqa training')
143 | args = parser.parse_args()
144 |
145 | initialize(config_path='../config', job_name="save_vqa_attn_scores")
146 | config = compose(config_name='base', overrides=[f'dataset.train_bsz={args.bsz}', f'dataset.test_bsz={args.bsz}'])
147 | main(args.output_file, args.checkpoint_path, config)
148 |
149 |
150 | if __name__ == '__main__':
151 | cli_main()
152 |
--------------------------------------------------------------------------------
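The teacher above relies on two pooling steps: masked mean pooling of the question encoding (mean_pooling) and sliding-window averaging of visual features into overlapping fragments (nn.AvgPool1d with stride 1). A toy sketch with illustrative shapes, not the training entry point:

import torch
from torch import nn

batch, seq, hidden, fragment_length = 2, 10, 4, 3

# masked mean pooling: padded positions do not contribute to the average
token_embeddings = torch.randn(batch, seq, hidden)
attention_mask = torch.tensor([[1] * 10, [1] * 6 + [0] * 4])
mask = attention_mask.unsqueeze(-1).float()
query = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)  # B x H

# sliding-window fragments: average-pool over the sequence dimension
clip_avg = nn.AvgPool1d(fragment_length, stride=1)
visual = torch.randn(batch, seq, hidden)
fragments = clip_avg(visual.permute(0, 2, 1)).permute(0, 2, 1)  # B x (S - L + 1) x H
print(query.shape, fragments.shape)  # torch.Size([2, 4]) torch.Size([2, 8, 4])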
/model/lightning.py:
--------------------------------------------------------------------------------
1 | import random
2 | from itertools import chain
3 | from typing import List, Dict, Union
4 |
5 | import torch
6 | from hydra.utils import instantiate
7 | from pytorch_lightning import LightningModule
8 | from transformers import PreTrainedTokenizerBase as Tokenizer
9 |
10 | from data.emqa_dataset import EmqaBatch
11 | from eval.eval import calc_metrics
12 | from lightning_util import freeze_params
13 | from model.base import EmqaBaseModel
14 |
15 |
16 | # noinspection PyAbstractClass
17 | class EmqaLightningModule(LightningModule):
18 |
19 | def __init__(self, model_config, optim_config, tokenizer: Tokenizer) -> None:
20 | super().__init__()
21 | self.save_hyperparameters(dict(model=model_config, optim=optim_config))
22 | self.model: EmqaBaseModel = instantiate(model_config)
23 | self.optimizer_config = optim_config.optimizer
24 | if 'loss_weights' in optim_config:
25 | self.loss_weights: Dict[str, float] = optim_config.loss_weights
26 | self._loss_calc_cache = None
27 | else:
28 | self.loss_weights = None
29 | self.tokenizer = tokenizer
30 | self.lr = self.optimizer_config.lr
31 | freeze_params(self, optim_config.freeze)
32 | self._log_indices = {}
33 |
34 | def training_step(self, batch: EmqaBatch, batch_idx):
35 | loss_dict, logits = self.model(**self._get_model_inputs(batch))
36 | total_loss = self._modify_loss_dict(loss_dict)
37 | for k, v in loss_dict.items():
38 | self.log(k, v)
39 | return total_loss
40 |
41 | def _modify_loss_dict(self, loss_dict: Dict[str, torch.Tensor]):
42 | if 'total_loss' in loss_dict:
43 | return loss_dict['total_loss']
44 | if len(loss_dict) == 1:
45 | # No matter what it is called, the single value is the total loss.
46 | # No need to add it to loss_dict under 'total_loss' again (that would log the same value twice)
47 | return next(iter(loss_dict.values()))
48 | assert self.loss_weights is not None
49 | if self._loss_calc_cache:
50 | master_key, remaining_keys = self._loss_calc_cache
51 | else:
52 | specified_keys = self.loss_weights.keys()
53 | all_keys = loss_dict.keys()
54 | master_keys = all_keys - specified_keys
55 | assert len(master_keys) == 1, f'There must be exactly one loss weight not specified (master weight), ' \
56 | f'got {all_keys} - {specified_keys} = {master_keys}'
57 | remaining_keys = all_keys - master_keys
58 | assert specified_keys == remaining_keys, f'{specified_keys} != {remaining_keys}'
59 | master_key = next(iter(master_keys))
60 | self._loss_calc_cache = (master_key, remaining_keys)
61 | master_loss = loss_dict[master_key]
62 | total_loss = master_loss + sum(self.loss_weights[k] * loss_dict[k] for k in remaining_keys)
63 | loss_dict['total_loss'] = total_loss # Add it to dict for logging purposes
64 | return total_loss
65 |
66 | @staticmethod
67 | def _get_model_inputs(batch: EmqaBatch):
68 | return dict(
69 | question_tokens=batch['batch_question_tokens'],
70 | question_mask=batch['batch_question_mask'],
71 | video_features=batch['batch_video_features'],
72 | video_mask=batch['batch_video_mask'],
73 | answer_tokens=batch['batch_answer_tokens'],
74 | answer_mask=batch['batch_answer_mask'],
75 | batch_sample_ids=batch['batch_sample_ids'],
76 | moment_localization_labels=batch.get('batch_moment_localization_labels')
77 | )
78 |
79 | def validation_step(self, batch: EmqaBatch, batch_idx):
80 | loss_dict, lm_logits = self.model(**self._get_model_inputs(batch))
81 | self._modify_loss_dict(loss_dict)
82 | hypo_answers = self._extract_answers(lm_logits)
83 | return {'hypos': hypo_answers,
84 | 'targets': batch['batch_answer_texts'],
85 | **loss_dict}
86 |
87 | def _extract_answers(self, lm_logits):
88 | sequences = lm_logits.argmax(dim=-1)
89 | hypo_answers = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
90 | return hypo_answers
91 |
92 | def validation_epoch_end(self, outputs: List[Dict[str, Union[torch.Tensor, List[str]]]]) -> None:
93 | def _mean(key):
94 | return torch.stack([data[key] for data in outputs]).mean()
95 |
96 | self._log_some_outputs(outputs, 'val')
97 | metrics = self.aggregate_metrics(outputs, prefix='val')
98 | metrics.update({
99 | f'val_{name}': _mean(name) for name in outputs[0].keys() if 'loss' in name
100 | })
101 | self.log_dict(metrics)
102 |
103 | def _log_some_outputs(self, outputs, name):
104 | num_val_steps_to_log, num_samples_per_batch_to_log = 5, 3 # Could be configurable via cfg
105 | if name in self._log_indices:
106 | steps_to_log_indices = self._log_indices[name]['steps']
107 | else:
108 | steps_to_log_indices = random.sample(range(len(outputs)), k=min(len(outputs), num_val_steps_to_log))
109 | self._log_indices[name] = {'steps': steps_to_log_indices, 'samples': [
110 | random.sample(range(len(outputs[step]['targets'])),
111 | k=min(len(outputs[step]['targets']), num_samples_per_batch_to_log))
112 | for step in steps_to_log_indices
113 | ]}
114 | for i, step in enumerate(steps_to_log_indices):
115 | output, target = outputs[step]['hypos'], outputs[step]['targets']
116 | indices = self._log_indices[name]['samples'][i]
117 | for b in indices:
118 | sample = (
119 | f'Target: "{target[b]}". \n'
120 | f'Output: "{output[b]}"'
121 | )
122 | self.logger.experiment.add_text(f'{name} {str(i * len(indices) + b)}', sample,
123 | global_step=self.global_step)
124 |
125 | @staticmethod
126 | def aggregate_metrics(outputs, prefix):
127 | all_hypos = list(chain(*(data['hypos'] for data in outputs)))
128 | all_targets = list(chain(*(data['targets'] for data in outputs)))
129 | metrics = calc_metrics(all_hypos, [[x] for x in all_targets])
130 | metrics = {f'{prefix}_{k}': v for k, v in metrics.items()}
131 | return metrics
132 |
133 | def test_step(self, batch: EmqaBatch, batch_idx):
134 | sequences = self.model(**self._get_model_inputs(batch), teacher_forcing=False)
135 | hypo_answers = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
136 | return {'hypos': hypo_answers,
137 | 'targets': batch['batch_answer_texts']}
138 |
139 | def test_epoch_end(self, outputs: List[Dict[str, Union[List[str], Dict]]]) -> None:
140 | self._log_some_outputs(outputs, 'test')
141 | metrics = self.aggregate_metrics(outputs, prefix='test')
142 | self.log_dict(metrics)
143 |
144 | def configure_optimizers(self):
145 | params = filter(lambda p: p.requires_grad, self.parameters())
146 | return instantiate(self.optimizer_config,
147 | params,
148 | # lr might be overridden by auto lr tuning
149 | lr=self.lr)
150 |
--------------------------------------------------------------------------------
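For reference, _modify_loss_dict above implements a simple weighting convention: the single loss key without an entry in optim.loss_weights is the "master" loss with an implicit weight of 1, and every other loss is scaled by its configured weight. A small worked example with made-up numbers:

loss_weights = {'moment_localization': 0.3}               # e.g. from optim_config.loss_weights
loss_dict = {'lm_loss': 2.0, 'moment_localization': 0.5}  # produced by the model's forward

master_key = next(iter(loss_dict.keys() - loss_weights.keys()))  # -> 'lm_loss'
total_loss = loss_dict[master_key] + sum(w * loss_dict[k] for k, w in loss_weights.items())
assert master_key == 'lm_loss' and total_loss == 2.0 + 0.3 * 0.5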
/model/base.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Dict, Union
2 |
3 | import torch
4 | from torch import nn
5 | from transformers import PreTrainedModel, AutoModelForSeq2SeqLM
6 | from transformers.modeling_outputs import BaseModelOutput
7 |
8 |
9 | class EmqaBaseModel(nn.Module):
10 |
11 | def forward(self, question_tokens, question_mask,
12 | video_features, video_mask,
13 | answer_tokens, answer_mask,
14 | batch_sample_ids,
15 | moment_localization_labels,
16 | teacher_forcing=True) -> Union[Tuple[Dict[str, torch.Tensor], torch.Tensor],
17 | torch.Tensor]:
18 | if teacher_forcing:
19 | return self.teacher_forcing_forward(question_tokens, question_mask,
20 | video_features, video_mask,
21 | answer_tokens, answer_mask,
22 | batch_sample_ids, moment_localization_labels)
23 | else:
24 | return self.autoregressive_forward(question_tokens, question_mask,
25 | video_features, video_mask)
26 |
27 | def teacher_forcing_forward(self, question_tokens, question_mask,
28 | video_features, video_mask,
29 | answer_tokens, answer_mask,
30 | batch_sample_ids,
31 | moment_localization_labels
32 | ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
33 | """
34 | Forward in teacher forcing mode.
35 |
36 | :param question_tokens: B x Q
37 | :param question_mask: B x Q
38 | :param video_features: B x N x H
39 | :param video_mask: B x N
40 | :param answer_tokens: B x A
41 | :param answer_mask: B x A
42 | :param batch_sample_ids: List of sample ids for this batch.
43 | :param moment_localization_labels: B x N, with entries between 0.0 and 1.0
44 | (smoothed label saying if each frame is part of ground truth moment). Might be None.
45 | :return: tuple with loss_dict and answer logits.
46 | loss_dict should contain at least one loss Tensor (e.g. "lm_loss" or "total_loss").
47 | """
48 | raise NotImplementedError
49 |
50 | def autoregressive_forward(self, question_tokens, question_mask,
51 | video_features, video_mask
52 | ) -> torch.Tensor:
53 | """
54 | Forward in autoregressive mode.
55 |
56 | :param question_tokens: B x Q
57 | :param question_mask: B x Q
58 | :param video_features: B x N x H
59 | :param video_mask: B x N
60 | :return: tensor with answer tokens. B x A
61 | """
62 | raise NotImplementedError
63 |
64 |
65 | class MemoryAugmentedTransformerEmqaModel(EmqaBaseModel):
66 | def __init__(self,
67 | pretrained_enc_dec: str
68 | ) -> None:
69 | super().__init__()
70 | self.transformer: PreTrainedModel = AutoModelForSeq2SeqLM.from_pretrained(pretrained_enc_dec)
71 |
72 | def teacher_forcing_forward(self, question_tokens, question_mask,
73 | video_features, video_mask,
74 | answer_tokens, answer_mask,
75 | batch_sample_ids,
76 | moment_localization_labels) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
77 | context, context_mask, final_memory, mem_mask, enc_add_loss = self.forward_encoders(
78 | question_tokens, question_mask, video_features, video_mask, moment_localization_labels)
79 | output = self.transformer(labels=answer_tokens, decoder_attention_mask=answer_mask,
80 | encoder_outputs=(context,), attention_mask=context_mask)
81 | return {
82 | 'lm_loss': output.loss,
83 | **enc_add_loss,
84 | **self.calc_additional_loss(question_tokens, question_mask,
85 | video_features, video_mask,
86 | answer_tokens, answer_mask, batch_sample_ids,
87 | context, context_mask, final_memory, mem_mask,
88 | output, moment_localization_labels)
89 | }, output.logits
90 |
91 | def forward_encoders(self, question_tokens, question_mask, video_features, video_mask, moment_localization_labels):
92 | q_encoder_out = self.transformer.encoder(input_ids=question_tokens, attention_mask=question_mask)
93 | q_avg_pooled = q_encoder_out.last_hidden_state.mean(dim=1)
94 |
95 | mem_out = self.forward_memory(video_features, video_mask, moment_localization_labels, q_avg_pooled)
96 | if isinstance(mem_out, tuple):
97 | final_memory, additional_loss = mem_out
98 | else:
99 | assert torch.is_tensor(mem_out)
100 | final_memory = mem_out
101 | additional_loss = {}
102 | mem_mask = torch.ones(final_memory.shape[:-1], device=final_memory.device, dtype=torch.bool)
103 |
104 | context, context_mask = self._prepare_context(final_memory, mem_mask, question_mask=question_mask,
105 | q_encoder_out=q_encoder_out)
106 | return context, context_mask, final_memory, mem_mask, additional_loss
107 |
108 | def _prepare_context(self, final_memory, mem_mask, question_tokens=None, question_mask=None, q_encoder_out=None):
109 | if q_encoder_out:
110 | encoder_out = q_encoder_out
111 | else:
112 | encoder_out = self.transformer.encoder(input_ids=question_tokens, attention_mask=question_mask)
113 | context = torch.cat([
114 | final_memory,
115 | encoder_out.last_hidden_state
116 | ], dim=1)
117 | context_mask = torch.cat([
118 | mem_mask,
119 | question_mask
120 | ], dim=1)
121 | return context, context_mask
122 |
123 | def autoregressive_forward(self, question_tokens, question_mask, video_features, video_mask) -> torch.Tensor:
124 | context, context_mask, _, _, _ = self.forward_encoders(question_tokens, question_mask, video_features,
125 | video_mask, None)
126 | # noinspection PyTypeChecker
127 | enc_out = BaseModelOutput(last_hidden_state=context)
128 | return self.transformer.generate(encoder_outputs=enc_out, attention_mask=context_mask)
129 |
130 | def forward_memory(self, video_features, video_mask,
131 | moment_localization_labels,
132 | question_encoding) -> Union[torch.Tensor,
133 | Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
134 | """
135 | Forward video input to memory.
136 |
137 | :param video_features: Tensor of shape Batch x Sequence x Features
138 | :param video_mask: Tensor of shape Batch x Sequence
139 | :param moment_localization_labels: Tensor of shape Batch x Sequence, with entries between 0.0 and 1.0
140 | (smoothed label saying if each frame is part of ground truth moment). Might be None.
141 | :param question_encoding: Tensor of shape Batch x Hidden, Mean-pooled representation of the question.
142 | Should only be used for additional loss, not memory construction!
143 | :return: memory, Tensor of shape Batch x MemoryLength x Hidden
144 | or tuple of (memory, additional_loss_dict)
145 | """
146 | raise NotImplementedError
147 |
148 | # noinspection PyMethodMayBeStatic
149 | def calc_additional_loss(self, question_tokens, question_mask,
150 | video_features, video_mask,
151 | answer_tokens, answer_mask, batch_sample_ids,
152 | context, context_mask, final_memory, mem_mask,
153 | transformer_output, moment_localization_labels) -> Dict[str, torch.Tensor]:
154 | return {}
155 |
--------------------------------------------------------------------------------
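To illustrate the forward_memory contract documented above, here is a hypothetical subclass (not part of this repo) that compresses the whole video into a single masked-mean memory vector; any real memory module only has to return a Batch x MemoryLength x Hidden tensor, optionally together with an additional loss dict:

import torch
from torch import nn
from model.base import MemoryAugmentedTransformerEmqaModel


class MeanPoolMemoryEmqaModel(MemoryAugmentedTransformerEmqaModel):
    """Hypothetical example: memory of length 1, computed as a masked mean of the video features."""

    def __init__(self, pretrained_enc_dec: str, input_size: int) -> None:
        super().__init__(pretrained_enc_dec)
        hidden = self.transformer.get_input_embeddings().embedding_dim
        self.project = nn.Linear(input_size, hidden, bias=False)

    def forward_memory(self, video_features, video_mask, moment_localization_labels, question_encoding):
        mask = video_mask.unsqueeze(-1).float()  # B x S x 1
        pooled = (self.project(video_features) * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return pooled[:, None, :]  # B x 1 x Hidden memory, no additional loss dict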
/model/mann.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Tuple, Dict
2 |
3 | import torch
4 | from dnc import DNC
5 | from torch import nn
6 |
7 | from model.base import MemoryAugmentedTransformerEmqaModel
8 | from model.external.stm import STM
9 | from model.moment_loc import SeqMomentLocalizationLossModule
10 |
11 |
12 | class SegmentationMemoryAugmentedTransformerEmqaModel(MemoryAugmentedTransformerEmqaModel):
13 | def __init__(self, pretrained_enc_dec: str,
14 | segmentation_method: Literal['flat', 'avg'],
15 | segment_length: int,
16 | input_size: int, # this is downscaled to mem_input_size for each time step separately
17 | mem_input_size: int, # Actual input to Memory depends on segmentation_method
18 | moment_loc: SeqMomentLocalizationLossModule = None
19 | ) -> None:
20 | super().__init__(pretrained_enc_dec)
21 | self.segment_length = segment_length
22 | self.segmentation_method = segmentation_method
23 | self.input_downscale = (nn.Identity() if input_size == mem_input_size
24 | else nn.Linear(input_size, mem_input_size, bias=False))
25 | self.actual_mem_input_size = (segment_length * mem_input_size
26 | if segmentation_method == 'flat' else mem_input_size)
27 | self.moment_loc = moment_loc
28 |
29 | def forward_memory(self, video_features, video_mask,
30 | moment_localization_labels,
31 | question_encoding) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
32 | # B: batch, H: downscaled hidden dim, L: segment_length, N: num_segments, S: input seq length
33 | seq_downscaled: torch.Tensor = self.input_downscale(video_features) # B x S x H
34 | if self.segmentation_method == 'flat':
35 | mem_in, mem_in_mask = self._segment_flat(seq_downscaled, video_mask)
36 | elif self.segmentation_method == 'avg':
37 | mem_in, mem_in_mask = self._segment_avg(seq_downscaled, video_mask)
38 | else:
39 | raise ValueError(self.segmentation_method)
40 | memory, item_output = self.forward_segmented_memory(mem_in)
41 |
42 | aux_loss = self._calc_aux_loss(item_output, mem_in_mask, moment_localization_labels, question_encoding)
43 | return memory, aux_loss
44 |
45 | def _calc_aux_loss(self, item_output, mem_in_mask, moment_localization_labels, question_encoding):
46 | has_moment_loc = (item_output is not None
47 | and self.moment_loc is not None
48 | and moment_localization_labels is not None)
49 | if has_moment_loc:
50 | # have: B x input_seq
51 | # need: B x N
52 | segments = list(moment_localization_labels.split(self.segment_length, dim=1)) # N tensors of B x ( ≤L )
53 | segments = [segment_labels.sum(dim=1).clip(0, 1) for segment_labels in segments] # N tensors of B
54 | moment_localization_labels = torch.stack(segments, dim=1) # B x N
55 | aux_loss = {} if not has_moment_loc else {
56 | 'moment_localization': self.moment_loc(item_output, mem_in_mask,
57 | moment_localization_labels, question_encoding)
58 | }
59 | return aux_loss
60 |
61 | def _segment_flat(self, seq_downscaled, mask):
62 | bsz, _, h = seq_downscaled.shape
63 | segments = list(seq_downscaled.split(self.segment_length, dim=1)) # N tensors of B x ( ≤L ) x H
64 | mask_segments = list(mask.split(self.segment_length, dim=1))
65 | last_segment_length = segments[-1].shape[1]
66 | if last_segment_length != self.segment_length:
67 | # Zero-pad to segment length so that dnc_in can be constructed correctly
68 | segments[-1] = torch.cat([
69 | segments[-1],
70 | torch.zeros(bsz, self.segment_length - last_segment_length, h,
71 | device=seq_downscaled.device, dtype=seq_downscaled.dtype)
72 | ], dim=1)
73 | mask_segments[-1] = torch.cat([
74 | mask_segments[-1],
75 | torch.zeros(bsz, self.segment_length - last_segment_length,
76 | device=mask.device, dtype=mask.dtype)
77 | ], dim=1)
78 | segments = [s.view(bsz, -1) for s in segments] # N tensors of B x LH
79 | mem_in = torch.stack(segments, dim=1) # B x N x LH
80 | mem_in_mask = torch.stack(mask_segments, dim=1) # B x N x L
81 | mem_in_mask = mem_in_mask.any(dim=-1) # B x N
82 | return mem_in, mem_in_mask
83 |
84 | def _segment_avg(self, seq_downscaled, mask):
85 | bsz, _, h = seq_downscaled.shape
86 | segments = list(seq_downscaled.split(self.segment_length, dim=1)) # N tensors of B x ( ≤L ) x H
87 | avg_segments = [x.mean(dim=1) for x in segments] # N tensors of B x H
88 | mem_in = torch.stack(avg_segments, dim=1) # B x N x H
89 | mask_segments = list(mask.split(self.segment_length, dim=1)) # N tensors of B x ( ≤L )
90 | mask_segments = [x.any(dim=1) for x in mask_segments] # N tensors of B
91 | segment_mask = torch.stack(mask_segments, dim=1) # B x N
92 | return mem_in, segment_mask
93 |
94 | def forward_segmented_memory(self, mem_in: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
95 | # mem_in: B x N x LH
96 | raise NotImplementedError
97 |
98 |
99 | class DncEmqaModel(SegmentationMemoryAugmentedTransformerEmqaModel):
100 |
101 | def __init__(self, pretrained_enc_dec: str,
102 | segmentation_method: Literal['flat', 'avg'],
103 | segment_length: int,
104 | input_size: int, # this is downscaled to dnc_input_size for each time step separately
105 | dnc_input_size: int, # Actual input to DNC is segment_length * dnc_input_size
106 | rnn_hidden_size: int,
107 | num_dnc_layers=1,
108 | num_rnn_hidden_layers=2,
109 | num_mem_cells=5,
110 | mem_hidden_size=10,
111 | moment_loc: SeqMomentLocalizationLossModule = None
112 | ) -> None:
113 | super().__init__(pretrained_enc_dec, segmentation_method, segment_length,
114 | input_size, dnc_input_size, moment_loc)
115 | self.dnc = DNC(self.actual_mem_input_size,
116 | rnn_hidden_size, num_layers=num_dnc_layers,
117 | num_hidden_layers=num_rnn_hidden_layers,
118 | nr_cells=num_mem_cells, cell_size=mem_hidden_size)
119 |
120 | def forward_segmented_memory(self, mem_in: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
121 | self._hack_dnc_gpu_ids(mem_in.device)
122 | output, (controller_hidden, memory, read_vectors) = self.dnc(mem_in, reset_experience=True)
123 | # memory['memory']: Batch x NumMemCells x MemHiddenSize
124 | return memory['memory'], output
125 |
126 | def _hack_dnc_gpu_ids(self, device):
127 | gpu_id = -1 if device.type == 'cpu' else device.index
128 | if self.dnc.gpu_id == gpu_id:
129 | return
130 | # This is a hack because DNC is programmed not in a pytorch idiomatic way...
131 | self.dnc.gpu_id = gpu_id
132 | for m in self.dnc.memories:
133 | m.gpu_id = gpu_id
134 | m.I = m.I.to(device=device)
135 |
136 |
137 | class StmEmqaModel(SegmentationMemoryAugmentedTransformerEmqaModel):
138 |
139 | def __init__(self, pretrained_enc_dec: str,
140 | segmentation_method: Literal['flat', 'avg'],
141 | segment_length: int,
142 | input_size: int, # this is downscaled to stm_input_size for each time step separately
143 | stm_input_size: int, # Actual input to the STM is segment_length * stm_input_size for 'flat' segmentation
144 | mem_hidden_size=10,
145 | stm_step=1,
146 | stm_num_slot=8,
147 | stm_mlp_size=128,
148 | stm_slot_size=96,
149 | stm_rel_size=96,
150 | stm_out_att_size=64
151 | ) -> None:
152 | super().__init__(pretrained_enc_dec, segmentation_method, segment_length, input_size, stm_input_size)
153 | self.stm = STM(self.actual_mem_input_size, mem_hidden_size,
154 | stm_step, stm_num_slot, stm_mlp_size, stm_slot_size,
155 | stm_rel_size, stm_out_att_size)
156 |
157 | def forward_segmented_memory(self, mem_in: torch.Tensor) -> Tuple[torch.Tensor, None]:
158 | mem_in = mem_in.transpose(0, 1) # switch batch and sequence dim for STM
159 | output, (read_heads, item_memory_state, rel_memory_state) = self.stm(mem_in)
160 | # should return Batch x NumMemCells x MemHiddenSize
161 | # output is Batch x MemHiddenSize
162 | return output[:, None, :], None
163 |
--------------------------------------------------------------------------------
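The 'avg' segmentation used by the models above can be pictured with toy tensors: the (padded) frame sequence is split into chunks of segment_length and each chunk is mean-pooled, so the memory consumes B x N segment vectors instead of B x S frames. Illustrative sketch:

import torch

batch, seq, hidden, segment_length = 2, 10, 4, 4
features = torch.randn(batch, seq, hidden)
mask = torch.ones(batch, seq, dtype=torch.bool)

segments = list(features.split(segment_length, dim=1))                    # N tensors of B x (<=L) x H
mem_in = torch.stack([s.mean(dim=1) for s in segments], dim=1)            # B x N x H, here N = 3
segment_mask = torch.stack([m.any(dim=1) for m in mask.split(segment_length, dim=1)], dim=1)  # B x N
print(mem_in.shape, segment_mask.shape)  # torch.Size([2, 3, 4]) torch.Size([2, 3])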
/lightning_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 | from argparse import Namespace
5 | from copy import deepcopy
6 | from fnmatch import fnmatchcase
7 | from pathlib import Path
8 | from typing import List
9 |
10 | import torch.nn
11 | from pytorch_lightning import seed_everything, Trainer, Callback
12 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
13 | from pytorch_lightning.plugins import DDPPlugin
14 |
15 |
16 | def dict_parser(s: str):
17 | return eval('{' + re.sub(r'(\w+)=(["\']?\w+["\']?)', r'"\1":\2', s) + '}')
18 |
19 |
20 | def add_common_trainer_util_args(parser, default_monitor_variable='val_loss', default_monitor_mode='min'):
21 | if default_monitor_mode not in ['min', 'max']:
22 | raise ValueError(default_monitor_mode)
23 | parser.add_argument('--lr_find_kwargs', default=dict(min_lr=5e-6, max_lr=1e-2), type=dict_parser,
24 | help='Arguments for LR find (--auto_lr_find). Default "min_lr=5e-6,max_lr=1e-2"')
25 | parser.add_argument('--random_seed', default=42, type=lambda s: None if s == 'None' else int(s),
26 | help='Seed everything. Set to "None" to disable global seeding')
27 | parser.add_argument('--auto_resume', default=False, action='store_true',
28 | help='Automatically resume last saved checkpoint, if available.')
29 | parser.add_argument('--test_only', default=False, action='store_true',
30 | help='Skip fit and call only test. This implies automatically detecting newest checkpoint, '
31 | 'if --checkpoint_path is not given.')
32 | parser.add_argument('--checkpoint_path', default=None, type=str,
33 | help='Load this checkpoint to resume training or run testing. '
34 | 'Pass in the special value "best" to use the best checkpoint according to '
35 | 'args.monitor_variable and args.monitor_mode. '
36 | 'Using "best" only works with test_only mode.')
37 | parser.add_argument('--ignore_existing_checkpoints', default=False, action='store_true',
38 | help='Proceed with training a new model even if previous checkpoints exist.')
39 | parser.add_argument('--monitor_variable', default=default_monitor_variable, type=str,
40 | help='Variable to monitor for early stopping and for checkpoint selection. '
41 | f'Default: {default_monitor_variable}')
42 | parser.add_argument('--monitor_mode', default=default_monitor_mode, type=str, choices=['min', 'max'],
43 | help='Mode for monitoring the monitor_variable (for early stopping and checkpoint selection). '
44 | f'Default: {default_monitor_mode}')
45 | parser.add_argument('--reset_early_stopping_criterion', default=False, action='store_true',
46 | help='Reset the early stopping criterion when loading from checkpoint. '
47 | 'Prevents immediate exit after switching to more complex dataset in curriculum strategy')
48 |
49 |
50 | def _auto_resume_from_checkpoint(args):
51 | if getattr(args, 'resume_from_checkpoint', None) is not None:
52 | raise DeprecationWarning('Trainer.resume_from_checkpoint is deprecated. Switch to checkpoint_path argument.')
53 | best_mode = args.checkpoint_path == 'best'
54 | if best_mode:
55 | if not args.test_only:
56 | raise RuntimeError('checkpoint_path="best" only works in test_only mode!')
57 | # More "best" logic is handled below
58 | elif args.checkpoint_path is not None:
59 | return
60 |
61 | log_dir = Path(getattr(args, 'default_root_dir', None) or 'lightning_logs')
62 | existing_checkpoints = list(log_dir.glob('version_*/checkpoints/*.ckpt'))
63 | if len(existing_checkpoints) == 0:
64 | return # This is the first run
65 | if not args.test_only and not args.auto_resume:
66 | if args.ignore_existing_checkpoints:
67 | return # Explicitly requested
68 | raise RuntimeWarning(f"There already exist checkpoints, but checkpoint_path/auto_resume not set! "
69 | f"{existing_checkpoints}")
70 | if best_mode:
71 | chosen = _auto_choose_best_checkpoint(args, existing_checkpoints)
72 | else:
73 | chosen = _auto_choose_newest_checkpoint(existing_checkpoints)
74 | args.checkpoint_path = str(chosen)
75 | print(f'Auto-detected {"best" if best_mode else "newest"} checkpoint {chosen}, resuming it... '
76 | f'If this is not intended, use --checkpoint_path !')
77 |
78 |
79 | def _auto_choose_newest_checkpoint(existing_checkpoints):
80 | chosen = None
81 | for c in existing_checkpoints:
82 | if chosen is None or c.stat().st_mtime > chosen.stat().st_mtime:
83 | chosen = c
84 | return chosen
85 |
86 |
87 | def _auto_choose_best_checkpoint(args, existing_checkpoints):
88 | chosen = None
89 | for c in existing_checkpoints:
90 | if chosen is None or 'last.ckpt' == chosen.name:
91 | chosen = c
92 | continue
93 | if 'last.ckpt' == c.name:
94 | continue
95 | chosen_match = re.search(fr'{re.escape(args.monitor_variable)}=(\d+(?:\.\d+)?)', chosen.name)
96 | current_match = re.search(fr'{re.escape(args.monitor_variable)}=(\d+(?:\.\d+)?)', c.name)
97 | if chosen_match is None:
98 | raise ValueError(chosen)
99 | if current_match is None:
100 | raise ValueError(c)
101 | op = {'min': lambda old, new: new < old, 'max': lambda old, new: new > old}[args.monitor_mode]
102 | if op(float(chosen_match.group(1)), float(current_match.group(1))):
103 | chosen = c
104 | return chosen
105 |
106 |
107 | def apply_common_train_util_args(args) -> List[Callback]:
108 | _auto_resume_from_checkpoint(args)
109 | if args.random_seed is not None:
110 | seed_everything(args.random_seed, workers=True)
111 |
112 | early_stopping = EarlyStopping(monitor=args.monitor_variable, mode=args.monitor_mode,
113 | min_delta=0.001, patience=10,
114 | check_on_train_epoch_end=False)
115 | if args.reset_early_stopping_criterion:
116 | # Prevent loading the early stopping criterion when restoring from checkpoint
117 | early_stopping.on_load_checkpoint = lambda *args, **kwargs: None
118 | return [
119 | early_stopping,
120 | ModelCheckpoint(save_last=True, monitor=args.monitor_variable, mode=args.monitor_mode,
121 | save_top_k=1, filename='{step}-{' + args.monitor_variable + ':.3f}')
122 | ]
123 |
124 |
125 | def _ddp_save_tune(args, model, datamodule):
126 | """
127 | Runs LR tuning on the main process only, _before_ DDP is initialized. Sets an env var to pass the result to child processes.
128 | """
129 | lr_env_var = os.getenv('_pl_auto_lr_find')
130 | if lr_env_var is None:
131 | # Main process running this code will not have a value for the env var, and thus perform single-process LR tune
132 | args_copy = deepcopy(args)
133 | args_copy.strategy = None
134 | args_copy.gpus = 1
135 | print('Running single GPU tune...')
136 | single_process_trainer = Trainer.from_argparse_args(args_copy)
137 | single_process_trainer.tune(model, datamodule, lr_find_kwargs=dict(args.lr_find_kwargs))
138 | # "Broadcast" result to other ranks. Can not use actual broadcast mechanism,
139 | # since DDP env is not yet running.
140 | os.environ['_pl_auto_lr_find'] = str(model.lr)
141 | else:
142 | # Later workers executing this code will just load the best LR from the env var
143 | model.lr = float(os.environ['_pl_auto_lr_find'])
144 |
145 |
146 | def _adjust_ddp_config(trainer_cfg):
147 | trainer_cfg = dict(trainer_cfg)
148 | strategy = trainer_cfg.get('strategy', None)
149 | if trainer_cfg['gpus'] > 1 and strategy is None:
150 | strategy = 'ddp' # Select ddp by default
151 | if strategy == 'ddp':
152 | trainer_cfg['strategy'] = DDPPlugin(find_unused_parameters=False, gradient_as_bucket_view=True)
153 | return trainer_cfg
154 |
155 |
156 | def tune_fit_test(trainer_cfg, model, datamodule):
157 | callbacks = apply_common_train_util_args(trainer_cfg)
158 |
159 | trainer_cfg = Namespace(**_adjust_ddp_config(trainer_cfg))
160 | ddp_mode = isinstance(getattr(trainer_cfg, 'strategy', None), DDPPlugin) and trainer_cfg.gpus > 1
161 | if (ddp_mode
162 | and not trainer_cfg.test_only
163 | and trainer_cfg.checkpoint_path is None
164 | and trainer_cfg.auto_lr_find):
165 | # Do tuning with a "fake" trainer on single GPU
166 | _ddp_save_tune(trainer_cfg, model, datamodule)
167 |
168 | trainer = Trainer.from_argparse_args(trainer_cfg, callbacks=[
169 | LearningRateMonitor(logging_interval='step'),
170 | *callbacks
171 | ])
172 |
173 | if not trainer_cfg.test_only:
174 | if trainer_cfg.checkpoint_path is None and not ddp_mode:
175 | # Do tune with the trainer directly
176 | trainer.tune(model, datamodule, lr_find_kwargs=dict(trainer_cfg.lr_find_kwargs))
177 |
178 | trainer.fit(model, datamodule, ckpt_path=trainer_cfg.checkpoint_path)
179 |
180 | trainer.test(
181 | model, datamodule,
182 | ckpt_path=
183 | trainer_cfg.checkpoint_path if trainer_cfg.test_only
184 | else None # If fit was called before, we should _not_ restore initial train checkpoint
185 | )
186 |
187 | if not trainer_cfg.test_only and not trainer.fast_dev_run:
188 | # test call above runs test on last after training,
189 | # also want it on best checkpoint by default
190 | # before, flush stdout and err to avoid strange order of outputs
191 | print('', file=sys.stdout, flush=True)
192 | print('', file=sys.stderr, flush=True)
193 | trainer.test(model, datamodule, ckpt_path='best')
194 |
195 |
196 | def freeze_params(model: torch.nn.Module, freeze_spec: List[str]):
197 | """
198 | Freeze parameters that begin with any of the freeze_spec items or match any of them according to fnmatchcase.
199 |
200 | Parameters
201 | ----------
202 | model : the model whose parameters should (partially) be frozen
203 | freeze_spec : specifies which parameters to freeze (e.g. 'generator' or 'generator.h.*.weight')
204 | """
205 | for name, p in model.named_parameters():
206 | freeze = any(
207 | name.startswith(pattern) or fnmatchcase(name, pattern)
208 | for pattern in freeze_spec
209 | )
210 | if freeze:
211 | p.requires_grad = False
212 |
--------------------------------------------------------------------------------
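Usage sketch for freeze_params (the patterns below are examples, not taken from this repo's configs): a parameter is frozen if its name starts with one of the given patterns or matches it fnmatch-style.

import torch
from lightning_util import freeze_params

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
freeze_params(model, ['0'])        # prefix rule: freezes '0.weight' and '0.bias'
freeze_params(model, ['*.bias'])   # fnmatch rule: freezes every bias parameter
print([name for name, p in model.named_parameters() if not p.requires_grad])
# ['0.weight', '0.bias', '1.bias']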
/model/moment_loc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class SeqMomentLocalizationLossModule(nn.Module):
6 |
7 | def __init__(self,
8 | seq_hidden_dim: int,
9 | question_hidden_dim: int) -> None:
10 | super().__init__()
11 | total = seq_hidden_dim + question_hidden_dim
12 | self.projection_layer = nn.Sequential(
13 | nn.Linear(total, total // 2, bias=True),
14 | nn.LeakyReLU(),
15 | nn.Linear(total // 2, 1, bias=False),
16 | nn.Tanh()
17 | )
18 | self.loss = nn.MSELoss(reduction='none')
19 |
20 | def forward(self, output_sequence, sequence_mask,
21 | moment_localization_labels, question_encoding) -> torch.Tensor:
22 | """
23 | Calculate moment localization loss based on output sequence.
24 |
25 | :param output_sequence: Tensor of shape Batch x Sequence x Hidden
26 | :param sequence_mask: Tensor of shape Batch x Sequence
27 | :param moment_localization_labels: Tensor of shape Batch x Sequence, with entries between 0.0 and 1.0
28 | (smoothed label saying if each frame is part of ground truth moment). Might be None.
29 | :param question_encoding: Tensor of shape Batch x Hidden, Mean-pooled representation of the question.
30 | :return: scalar Tensor containing moment localization loss
31 | """
32 | bsz, seq = output_sequence.shape[:2]
33 | question = question_encoding[:, None, :].expand(bsz, seq, -1)
34 | seq_with_question = torch.cat([output_sequence, question],
35 | dim=-1) # Batch x Sequence x 2*Hidden
36 | output_scores = self.projection_layer(seq_with_question) # Batch x Sequence x 1
37 | loss_per_item = self.loss(output_scores.squeeze(dim=-1), moment_localization_labels)
38 | loss_per_item = loss_per_item * sequence_mask
39 | return loss_per_item.mean()
40 |
41 |
42 | class TransformerMomentLocalizationLossModule(nn.Module):
43 |
44 | def __init__(self, softmax_temperature, att_loss_type, hinge_margin, lse_alpha):
45 | # softmax_temperature: temperature used when re-applying softmax to the video-only attention scores
46 | # att_loss_type: 'hinge' or 'lse' ranking loss between target and non-target attention scores
47 | super().__init__()
48 | self.att_loss_type = att_loss_type
49 | self.softmax_temperature = softmax_temperature
50 | self.hinge_margin = hinge_margin
51 | self.lse_alpha = lse_alpha
52 |
53 | def forward(self, question_tokens, transformer_output,
54 | moment_localization_labels, video_mask):
55 | """
56 |
57 | :param question_tokens: Batch x QuestionSequence
58 | :param transformer_output: transformers model output with cross_attentions. The cross-attention input is assumed
59 | to be the concatenation [video, question], so that each layer's cross_attentions entry has shape
60 | Batch x NumHeads x DecoderSequence x (VideoSeq + QuestionSeq)
61 | :param moment_localization_labels: Batch x VideoSequence. 1 for target frames, 0 elsewhere
62 | :param video_mask: Batch x VideoSequence. 0 for masked indices, 1 for valid
63 | :return: loss
64 | """
65 | # Tuple of tensors (one for each layer) of shape (batch, num_heads, dec_seq_len, enc_seq_len)
66 | cross_attentions = transformer_output.cross_attentions
67 | # (batch, num_layers, num_heads, dec_seq_len, enc_seq_len)
68 | cross_attentions = torch.stack(cross_attentions, dim=1)
69 | question_len = question_tokens.shape[1]
70 |
71 | # Need to take care of masked indices before re-applying softmax
72 | cross_attn_on_video = cross_attentions[:, :, :, :, :-question_len] / self.softmax_temperature
73 | mask_value = torch.scalar_tensor(-1e6, dtype=cross_attentions.dtype, device=cross_attentions.device)
74 | cross_attn_on_video = torch.where(video_mask[:, None, None, None, :], cross_attn_on_video, mask_value)
75 | cross_attn_on_video = cross_attn_on_video.softmax(dim=-1)
76 |
77 | # loss per batch
78 | loss = self.calc_attention_loss(cross_attn_on_video, moment_localization_labels, video_mask)
79 | return loss.mean()
80 |
81 | def calc_attention_loss(self, cross_attn_on_video, moment_localization_labels, video_mask):
82 | raise NotImplementedError
83 |
84 | def ranking_loss(self, pos_scores, neg_scores):
85 | if self.att_loss_type == "hinge":
86 | # max(0, m + S_neg - S_pos)
87 | loss = torch.clamp(self.hinge_margin + neg_scores - pos_scores, min=0)
88 | elif self.att_loss_type == "lse":
89 | # log[1 + exp(alpha * (S_neg - S_pos))]
90 | loss = torch.log1p(torch.exp(self.lse_alpha * (neg_scores - pos_scores)))
91 | else:
92 | raise NotImplementedError("Only support hinge and lse")
93 | return loss
94 |
95 |
96 | class SummedAttentionTransformerMomentLocLoss(TransformerMomentLocalizationLossModule):
97 |
98 | def __init__(self, softmax_temperature=0.1, att_loss_type='lse', hinge_margin=0.4, lse_alpha=20):
99 | super().__init__(softmax_temperature, att_loss_type, hinge_margin, lse_alpha)
100 |
101 | def calc_attention_loss(self, cross_attn_on_video, moment_localization_labels, video_mask):
102 | bsz = len(video_mask)
103 | pos_mask = moment_localization_labels == 1
104 | neg_mask = (moment_localization_labels == 0) * video_mask
105 | mean_attn = cross_attn_on_video.mean(dim=(1, 2, 3)) # mean over heads, layers and decoder positions
106 | pos_scores = torch.stack([
107 | mean_attn[b, pos_mask[b]].sum()
108 | for b in range(bsz)
109 | ])
110 | neg_scores = torch.stack([
111 | mean_attn[b, neg_mask[b]].sum()
112 | for b in range(bsz)
113 | ])
114 | return self.ranking_loss(pos_scores, neg_scores)
115 |
116 |
117 | # This is copied & modified from https://github.com/jayleicn/TVQAplus
118 | class SamplingAttentionTransformerMomentLocLoss(TransformerMomentLocalizationLossModule):
119 |
120 | def __init__(self, num_negatives=2, use_hard_negatives=False, drop_topk=0,
121 | negative_pool_size=0, num_hard=2,
122 | softmax_temperature=0.1, att_loss_type='lse', hinge_margin=0.4, lse_alpha=20) -> None:
123 | super().__init__(softmax_temperature, att_loss_type, hinge_margin, lse_alpha)
124 | self.num_hard = num_hard
125 | self.negative_pool_size = negative_pool_size
126 | self.drop_topk = drop_topk
127 | self.use_hard_negatives = use_hard_negatives
128 | self.num_negatives = num_negatives
129 |
130 | def calc_attention_loss(self, cross_attn_on_video, att_labels, video_mask):
131 | # take max head, mean over layers and decoder positions
132 | scores = cross_attn_on_video.max(dim=2).values.mean(dim=(1, 2))
133 |
134 | # att_labels : Batch x VideoSequence. 0 for non-target, 1 for target
135 | # scores : Batch x VideoSequence. Between 0 and 1 as given by softmax
136 | # rescale to [-1, 1]
137 | scores = scores * 2 - 1
138 |
139 | pos_container = [] # contains tuples of 2 elements, which are (batch_i, img_i)
140 | neg_container = []
141 | bsz = len(att_labels)
142 | for b in range(bsz):
143 | pos_indices = att_labels[b].nonzero() # N_pos x 1
144 | neg_indices = ((1 - att_labels[b]) * video_mask[b]).nonzero() # N_neg x 1
145 |
146 | sampled_pos_indices, sampled_neg_indices = self._sample_negatives(scores[b], pos_indices, neg_indices)
147 |
148 | base_indices = torch.full((sampled_pos_indices.shape[0], 1), b, dtype=torch.long, device=pos_indices.device)
149 | pos_container.append(torch.cat([base_indices, sampled_pos_indices], dim=1))
150 | neg_container.append(torch.cat([base_indices, sampled_neg_indices], dim=1))
151 |
152 | pos_container = torch.cat(pos_container, dim=0)
153 | neg_container = torch.cat(neg_container, dim=0)
154 |
155 | pos_scores = scores[pos_container[:, 0], pos_container[:, 1]]
156 | neg_scores = scores[neg_container[:, 0], neg_container[:, 1]]
157 |
158 | att_loss = self.ranking_loss(pos_scores, neg_scores).mean(dim=-1)
159 | return att_loss
160 |
161 | def _sample_negatives(self, pred_score, pos_indices, neg_indices):
162 | """ Sample negatives from a set of indices. Several sampling strategies are supported:
163 | 1, random; 2, hard negatives; 3, drop_topk hard negatives; 4, mix easy and hard negatives
164 | 5, sampling within a pool of hard negatives; 6, sample across images of the same video.
165 | Args:
166 | pred_score: (num_img)
167 | pos_indices: (N_pos, 1)
168 | neg_indices: (N_neg, 1)
169 | Returns:
170 | tuple (sampled_pos_indices, sampled_neg_indices), each of shape (~num_negatives * N_pos, 1)
171 | """
172 | num_unique_pos = len(pos_indices)
173 | sampled_pos_indices = torch.cat([pos_indices] * self.num_negatives, dim=0)
174 | if self.use_hard_negatives:
175 | # print("using use_hard_negatives")
176 | neg_scores = pred_score[neg_indices[:, 0]]
177 | max_indices = torch.sort(neg_scores, descending=True)[1].tolist()
178 | if self.negative_pool_size > self.num_negatives: # sample from a pool of hard negatives
179 | hard_pool = max_indices[self.drop_topk:self.drop_topk + self.negative_pool_size]
180 | hard_pool_indices = neg_indices[hard_pool]
181 | num_hard_negs = self.num_negatives
182 | sampled_easy_neg_indices = []
183 | if self.num_hard < self.num_negatives:
184 | easy_pool = max_indices[self.drop_topk + self.negative_pool_size:]
185 | easy_pool_indices = neg_indices[easy_pool]
186 | num_hard_negs = self.num_hard
187 | num_easy_negs = self.num_negatives - num_hard_negs
188 | sampled_easy_neg_indices = easy_pool_indices[
189 | torch.randint(low=0, high=len(easy_pool_indices),
190 | size=(num_easy_negs * num_unique_pos,), dtype=torch.long)
191 | ]
192 | sampled_hard_neg_indices = hard_pool_indices[
193 | torch.randint(low=0, high=len(hard_pool_indices),
194 | size=(num_hard_negs * num_unique_pos,), dtype=torch.long)
195 | ]
196 |
197 | if len(sampled_easy_neg_indices) != 0:
198 | sampled_neg_indices = torch.cat([sampled_hard_neg_indices, sampled_easy_neg_indices], dim=0)
199 | else:
200 | sampled_neg_indices = sampled_hard_neg_indices
201 |
202 | else: # directly take the top negatives
203 | sampled_neg_indices = neg_indices[max_indices[self.drop_topk:self.drop_topk + len(sampled_pos_indices)]]
204 | else:
205 | sampled_neg_indices = neg_indices[
206 | torch.randint(low=0, high=len(neg_indices), size=(len(sampled_pos_indices),), dtype=torch.long)
207 | ]
208 | return sampled_pos_indices, sampled_neg_indices
209 |
--------------------------------------------------------------------------------
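As a worked toy example of the two ranking losses above: for a positive attention score S_pos and a negative score S_neg, the hinge variant penalizes max(0, margin + S_neg - S_pos) and the LSE variant penalizes log(1 + exp(alpha * (S_neg - S_pos))); both push scores on the target moment above scores elsewhere. Values below are made up:

import torch

pos_scores = torch.tensor([0.8, 0.2])
neg_scores = torch.tensor([0.1, 0.4])
hinge_margin, lse_alpha = 0.4, 20

hinge = torch.clamp(hinge_margin + neg_scores - pos_scores, min=0)   # tensor([0.0000, 0.6000])
lse = torch.log1p(torch.exp(lse_alpha * (neg_scores - pos_scores)))  # near 0 when pos >> neg, large otherwise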
/model/external/stm.py:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/thaihungle/SAM/blob/master/baselines/sam/stm_basic.py
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | def op_att(q, k, v):
10 | qq = q.unsqueeze(2).repeat(1, 1, k.shape[1], 1)
11 | kk = k.unsqueeze(1).repeat(1, q.shape[1], 1, 1)
12 | output = torch.matmul(F.tanh(qq * kk).unsqueeze(4), v.unsqueeze(1).repeat(1, q.shape[1], 1, 1).unsqueeze(
13 | 3)) # BxNXNxd_kq BxNxNxd_v --> BxNXNxd_kqxd_v
14 | # print(output.shape)
15 | output = torch.sum(output, dim=2) # BxNxd_kqxd_v
16 | # print(output.shape)
17 | return output
18 |
19 |
20 | def sdp_att(q, k, v):
21 | dot_product = torch.matmul(q, k.permute(0, 2, 1))
22 | weights = F.softmax(dot_product, dim=-1)
23 |
24 | # output is [B, H, N, V]
25 | output = torch.matmul(weights, v)
26 | return output
27 |
28 |
29 | class MLP(nn.Module):
30 | def __init__(self, in_dim=28 * 28, out_dim=10, hid_dim=-1, layers=1):
31 | super(MLP, self).__init__()
32 | self.layers = layers
33 | if hid_dim <= 0:
34 | self.layers = -1
35 | if self.layers < 0:
36 | hid_dim = out_dim
37 | self.fc1 = nn.Linear(in_dim, hid_dim)
38 | # linear layer (n_hidden -> hidden_2)
39 | if self.layers > 0:
40 | self.fc2h = nn.ModuleList([nn.Linear(hid_dim, hid_dim)] * self.layers)
41 | # linear layer (n_hidden -> 10)
42 | if self.layers >= 0:
43 | self.fc3 = nn.Linear(hid_dim, out_dim)
44 |
45 | def forward(self, x):
46 | o = self.fc1(x)
47 | if self.layers > 0:
48 | for l in range(self.layers):
49 | o = self.fc2h[l](o)
50 | if self.layers >= 0:
51 | o = self.fc3(o)
52 | return o
53 |
54 |
55 | class STM(nn.Module):
56 | def __init__(self, input_size, output_size, step=1, num_slot=8,
57 | mlp_size=128, slot_size=96, rel_size=96,
58 | out_att_size=64, rd=True,
59 | init_alphas=[None, None, None],
60 | learn_init_mem=True, mlp_hid=-1):
61 | super(STM, self).__init__()
62 | self.mlp_size = mlp_size
63 | self.slot_size = slot_size
64 | self.rel_size = rel_size
65 | self.rnn_hid = slot_size
66 | self.num_slot = num_slot
67 | self.step = step
68 | self.rd = rd
69 | self.learn_init_mem = learn_init_mem
70 |
71 | self.out_att_size = out_att_size
72 |
73 | self.qkv_projector = nn.ModuleList([nn.Linear(slot_size, num_slot * 3)] * step)
74 | self.qkv_layernorm = nn.ModuleList([nn.LayerNorm([slot_size, num_slot * 3])] * step)
75 |
76 | if init_alphas[0] is None:
77 | self.alpha1 = [nn.Parameter(torch.zeros(1))] * step
78 | for ia, a in enumerate(self.alpha1):
79 | setattr(self, 'alpha1' + str(ia), self.alpha1[ia])
80 | else:
81 | self.alpha1 = [init_alphas[0]] * step
82 |
83 | if init_alphas[1] is None:
84 | self.alpha2 = [nn.Parameter(torch.zeros(1))] * step
85 | for ia, a in enumerate(self.alpha2):
86 | setattr(self, 'alpha2' + str(ia), self.alpha2[ia])
87 | else:
88 | self.alpha2 = [init_alphas[1]] * step
89 |
90 | if init_alphas[2] is None:
91 | self.alpha3 = [nn.Parameter(torch.zeros(1))] * step
92 | for ia, a in enumerate(self.alpha3):
93 | setattr(self, 'alpha3' + str(ia), self.alpha3[ia])
94 | else:
95 | self.alpha3 = [init_alphas[2]] * step
96 |
97 | self.input_projector = MLP(input_size, slot_size, hid_dim=mlp_hid)
98 | self.input_projector2 = MLP(input_size, slot_size, hid_dim=mlp_hid)
99 | self.input_projector3 = MLP(input_size, num_slot, hid_dim=mlp_hid)
100 |
101 | self.input_gate_projector = nn.Linear(self.slot_size, self.slot_size * 2)
102 | self.memory_gate_projector = nn.Linear(self.slot_size, self.slot_size * 2)
103 | # trainable scalar gate bias tensors
104 | self.forget_bias = nn.Parameter(torch.tensor(1., dtype=torch.float32))
105 | self.input_bias = nn.Parameter(torch.tensor(0., dtype=torch.float32))
106 |
107 | self.rel_projector = nn.Linear(slot_size * slot_size, rel_size)
108 | self.rel_projector2 = nn.Linear(num_slot * slot_size, slot_size)
109 | self.rel_projector3 = nn.Linear(num_slot * rel_size, out_att_size)
110 |
111 | self.mlp = nn.Sequential(
112 | nn.Linear(out_att_size, self.mlp_size),
113 | nn.ReLU(),
114 | nn.Linear(self.mlp_size, self.mlp_size),
115 | nn.ReLU(),
116 | )
117 |
118 | self.out = nn.Linear(self.mlp_size, output_size)
119 |
120 | if self.learn_init_mem:
121 | if torch.cuda.is_available():
122 | self.register_parameter('item_memory_state_bias',
123 | torch.nn.Parameter(torch.Tensor(self.slot_size, self.slot_size).cuda()))
124 | self.register_parameter('rel_memory_state_bias', torch.nn.Parameter(
125 | torch.Tensor(self.num_slot, self.slot_size, self.slot_size).cuda()))
126 |
127 | else:
128 | self.register_parameter('item_memory_state_bias',
129 | torch.nn.Parameter(torch.Tensor(self.slot_size, self.slot_size)))
130 | self.register_parameter('rel_memory_state_bias',
131 | torch.nn.Parameter(torch.Tensor(self.num_slot, self.slot_size, self.slot_size)))
132 |
133 | stdev = 1 / (np.sqrt(self.slot_size + self.slot_size))
134 | nn.init.uniform_(self.item_memory_state_bias, -stdev, stdev)
135 | stdev = 1 / (np.sqrt(self.slot_size + self.slot_size + self.num_slot))
136 | nn.init.uniform_(self.rel_memory_state_bias, -stdev, stdev)
137 |
138 | def create_new_state(self, batch_size):
139 | if self.learn_init_mem:
140 | read_heads = torch.zeros(batch_size, self.out_att_size)
141 | item_memory_state = self.item_memory_state_bias.clone().repeat(batch_size, 1, 1)
142 | rel_memory_state = self.rel_memory_state_bias.clone().repeat(batch_size, 1, 1, 1)
143 | if torch.cuda.is_available():
144 | read_heads = read_heads.cuda()
145 | else:
146 | item_memory_state = torch.stack([torch.zeros(self.slot_size, self.slot_size) for _ in range(batch_size)])
147 | read_heads = torch.zeros(batch_size, self.out_att_size)
148 | rel_memory_state = torch.stack(
149 | [torch.zeros(self.num_slot, self.slot_size, self.slot_size) for _ in range(batch_size)])
150 | if torch.cuda.is_available():
151 | item_memory_state = item_memory_state.cuda()
152 | read_heads = read_heads.cuda()
153 | rel_memory_state = rel_memory_state.cuda()
154 |
155 | return read_heads, item_memory_state, rel_memory_state
156 |
157 | def compute_gates(self, inputs, memory):
158 |
159 | memory = torch.tanh(memory)
160 | if len(inputs.shape) == 3:
161 | if inputs.shape[1] > 1:
162 | raise ValueError(
163 | "input seq length is larger than 1. create_gate function is meant to be called for each step, with input seq length of 1")
164 | inputs = inputs.view(inputs.shape[0], -1)
165 |
166 | gate_inputs = self.input_gate_projector(inputs)
167 | gate_inputs = gate_inputs.unsqueeze(dim=1)
168 | gate_memory = self.memory_gate_projector(memory)
169 | else:
170 | raise ValueError("input shape of create_gate function is 2, expects 3")
171 |
172 | gates = gate_memory + gate_inputs
173 | gates = torch.split(gates, split_size_or_sections=int(gates.shape[2] / 2), dim=2)
174 | input_gate, forget_gate = gates
175 | assert input_gate.shape[2] == forget_gate.shape[2]
176 |
177 | input_gate = torch.sigmoid(input_gate + self.input_bias)
178 | forget_gate = torch.sigmoid(forget_gate + self.forget_bias)
179 |
180 | return input_gate, forget_gate
181 |
182 | def compute(self, input_step, prev_state):
183 |
184 | hid = prev_state[0]
185 | item_memory_state = prev_state[1]
186 | rel_memory_state = prev_state[2]
187 |
188 | # transform input
189 | controller_outp = self.input_projector(input_step)
190 | controller_outp2 = self.input_projector2(input_step)
191 | controller_outp3 = self.input_projector3(input_step)
192 |
193 | # Mi write
194 | X = torch.matmul(controller_outp.unsqueeze(2), controller_outp.unsqueeze(1)) # Bxdxd
195 | input_gate, forget_gate = self.compute_gates(controller_outp.unsqueeze(1), item_memory_state)
196 |
197 | # Mr read
198 | controller_outp3 = F.softmax(controller_outp3, dim=-1)
199 | controller_outp4 = torch.einsum('bn,bd,bndf->bf', controller_outp3, controller_outp2, rel_memory_state)
200 | X2 = torch.einsum('bd,bf->bdf', controller_outp4, controller_outp2)
201 |
202 | if self.rd:
203 | # Mi write gating
204 |             R = input_gate * torch.tanh(X)
205 | R += forget_gate * item_memory_state
206 | else:
207 | # Mi write
208 |             R = item_memory_state + X  # Bxdxd (X computed above)
209 |
210 | for i in range(self.step):
211 | # SAM
212 | qkv = self.qkv_projector[i](R + self.alpha2[i] * X2)
213 | qkv = self.qkv_layernorm[i](qkv)
214 | qkv = qkv.permute(0, 2, 1) # Bx3Nxd
215 |
216 | q, k, v = torch.split(qkv, [self.num_slot] * 3, 1) # BxNxd
217 |
218 | R0 = op_att(q, k, v) # BxNxdxd
219 |
220 | # Mr transfer to Mi
221 | R2 = self.rel_projector2(R0.view(R0.shape[0], -1, R0.shape[3]).permute(0, 2, 1))
222 | R = R + self.alpha3[i] * R2
223 |
224 | # Mr write
225 | rel_memory_state = self.alpha1[i] * rel_memory_state + R0
226 |
227 | # Mr transfer to output
228 | r_vec = self.rel_projector(rel_memory_state.view(rel_memory_state.shape[0],
229 | rel_memory_state.shape[1],
230 | -1)).view(input_step.shape[0], -1)
231 | out = self.rel_projector3(r_vec)
232 |
233 | # if self.gating_after:
234 | # #Mi write gating
235 | # input_gate, forget_gate = self.compute_gates(controller_outp.unsqueeze(1), R)
236 | # if self.rd:
237 | # R = input_gate * torch.tanh(R)
238 | # R += forget_gate * item_memory_state
239 |
240 | return out, (out, R, rel_memory_state)
241 |
242 | def forward(self, input_step, hidden=None):
243 |
244 | if len(input_step.shape) == 3:
245 | self.init_sequence(input_step.shape[1])
246 | for i in range(input_step.shape[0]):
247 | logit, self.previous_state = self.compute(input_step[i], self.previous_state)
248 |
249 | else:
250 | if hidden is not None:
251 | logit, hidden = self.compute(input_step, hidden)
252 | else:
253 | logit, self.previous_state = self.compute(input_step, self.previous_state)
254 | mlp = self.mlp(logit)
255 | out = self.out(mlp)
256 | return out, self.previous_state
257 |
258 | def init_sequence(self, batch_size):
259 | """Initializing the state."""
260 | self.previous_state = self.create_new_state(batch_size)
261 |
262 | def calculate_num_params(self):
263 | """Returns the total number of parameters."""
264 | num_params = 0
265 | for p in self.parameters():
266 | num_params += p.data.view(-1).size(0)
267 | return num_params
268 |
--------------------------------------------------------------------------------
/model/rehearsal.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, Optional, Any
3 |
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 | from transformers import AutoModel, PreTrainedModel
9 | from transformers.models.roberta import RobertaModel, RobertaConfig
10 |
11 | from model.base import MemoryAugmentedTransformerEmqaModel
12 |
13 |
14 | class RehearsalTrainingModule(nn.Module):
15 |
16 | def __init__(self,
17 | input_size: int,
18 | mem_hidden_size: int,
19 | num_samples: int,
20 | sample_length: int,
21 | positive_mask_ratio: float = 0.5,
22 | negative_replacement_ratio: float = 0.5,
23 | invert_teacher_sequence: bool = False,
24 | pretrained_decoder: Optional[str] = None,
25 | decoder_params: Optional[Dict[str, Any]] = None,
26 | sampling_teacher_weights_file: Optional[str] = None
27 | ) -> None:
28 | super().__init__()
29 | if pretrained_decoder is not None:
30 | self.decoder = RobertaModel.from_pretrained(pretrained_decoder, config=RobertaConfig.from_pretrained(
31 | pretrained_decoder, is_decoder=True, add_cross_attention=True
32 | ))
33 | else:
34 | model_cfg = RobertaConfig(**decoder_params, is_decoder=True, add_cross_attention=True)
35 | self.decoder = RobertaModel(model_cfg)
36 | hidden_size = self.decoder.config.hidden_size
37 | self.input_dimension_adjustment_layer = nn.Linear(input_size, hidden_size, bias=False)
38 | self.em_dimension_adjustment_layer = nn.Linear(mem_hidden_size, hidden_size, bias=False)
39 | self.num_samples = num_samples
40 | self.sample_length = sample_length
41 | self.positive_mask_ratio = positive_mask_ratio
42 | self.negative_replacement_ratio = negative_replacement_ratio
43 | self.invert_teacher_sequence = invert_teacher_sequence
44 |
45 | empty = torch.empty(2, 1, hidden_size)
46 | nn.init.kaiming_uniform_(empty)
47 | self.class_token_emb = nn.Parameter(empty[0])
48 | self.mask_token_emb = nn.Parameter(empty[1])
49 | self.pos_neg_projection = nn.Sequential(
50 | nn.Linear(hidden_size, 1, bias=False),
51 | nn.Sigmoid()
52 | )
53 |
54 | if sampling_teacher_weights_file:
55 | self.teacher_sampling_weights = torch.load(sampling_teacher_weights_file)
56 | else:
57 | self.teacher_sampling_weights = None
58 |
59 | def forward(self, memory, mem_mask, original_input, input_mask, batch_sample_ids):
60 | bsz, l_x = original_input.shape[0:2]
61 | h = self.decoder.config.hidden_size
62 | sample_length = min(self.sample_length, l_x)
63 | num_tokens_to_mask = int(sample_length * self.positive_mask_ratio)
64 | num_tokens_to_replace = int((sample_length - num_tokens_to_mask) * self.negative_replacement_ratio)
65 | dims = bsz, l_x, h, sample_length, num_tokens_to_mask, num_tokens_to_replace
66 | assert bsz > 1 # Negative sampling from other batches requires at least bsz=2
67 |
68 | samples, negative_samples, masked_items_map, padding_mask, start_indices = self._construct_samples(
69 | original_input, input_mask, dims, batch_sample_ids)
70 | hiddens_neg, hiddens_pos = self._forward_transformer(memory, mem_mask, samples, negative_samples,
71 | padding_mask, dims)
72 |
73 | recollection_loss = self._calc_recollection_loss(original_input, hiddens_pos,
74 | masked_items_map, start_indices, dims)
75 | familiarity_loss = self._calc_familiarity_loss(hiddens_neg, hiddens_pos, dims)
76 |
77 | return recollection_loss, familiarity_loss
78 |
79 | def _construct_samples(self, original_input, input_padding_mask, dims, batch_sample_ids):
80 | original_input = self.input_dimension_adjustment_layer(original_input)
81 | # original_input: bsz x l_x x h
82 | # input_padding_mask: bsz x l_x
83 | bsz, l_x, h, sample_length, num_tokens_to_mask, num_tokens_to_replace = dims
84 | start_indices = self._choose_start_indices(dims, batch_sample_ids)
85 | samples = torch.stack([
86 | torch.stack([
87 | torch.cat((self.class_token_emb, original_input[i, start:start + sample_length]), dim=0)
88 | for start in start_indices[i]
89 | ])
90 | for i in range(bsz)
91 | ]) # bsz x num_samples x sample_length x hidden
92 | negative_samples = samples.clone()
93 | # masked_items_map is 1 for masked input items (!) i.e. 1 for items that are replaced with mask token.
94 | masked_items_map = torch.zeros(bsz, self.num_samples, 1 + sample_length, # + 1 for CLS token!
95 | device=samples.device, dtype=torch.bool)
96 | padding_mask = torch.ones_like(masked_items_map) # one for items that should be used, 0 for padding
97 | all_indices = range(1, 1 + sample_length)
98 | for i in range(bsz):
99 | for s in range(self.num_samples):
100 | start = start_indices[i, s].item()
101 | padding_mask[i, s, 1:] = input_padding_mask[i, start:start + sample_length]
102 |                 padding_indices = (~padding_mask[i, s]).nonzero().squeeze(dim=-1).tolist()
103 | mask_indices = np.random.choice(all_indices, num_tokens_to_mask, replace=False)
104 |                 # Set masked_items_map only if these items are not padded. padding_mask is zero for padded entries
105 | masked_items_map[i, s, mask_indices] = padding_mask[i, s, mask_indices]
106 | unmasked_indices = list(set(all_indices) - set(mask_indices) - set(padding_indices))
107 | replacement_indices = np.random.choice(unmasked_indices, num_tokens_to_replace, replace=False)
108 | neg_sample_original_indices = replacement_indices + start - 1 # -1 because of CLS token
109 | negative_samples[i, s, replacement_indices] = original_input[(i + 1) % bsz, neg_sample_original_indices]
110 | samples[masked_items_map] = self.mask_token_emb
111 | negative_samples[masked_items_map] = self.mask_token_emb
112 | return samples, negative_samples, masked_items_map, padding_mask, start_indices
113 |
114 | def _choose_start_indices(self, dims, batch_sample_ids):
115 | bsz, l_x, _, sample_length, _, _ = dims
116 | num_fragments = l_x - sample_length + 1
117 | if self.teacher_sampling_weights is None:
118 | # uniform random sampling
119 | start_indices = np.stack([np.random.choice(max(1, num_fragments), self.num_samples, replace=False)
120 | for _ in range(bsz)])
121 | else:
122 | # biased random sampling guided by unconstrained teacher model (see section "What to rehearse?" in RM paper)
123 | teacher_attn_weights = [
124 | self.teacher_sampling_weights[sample_id]
125 | for sample_id in batch_sample_ids
126 | ]
127 | sample_distribution = torch.nn.utils.rnn.pad_sequence(teacher_attn_weights, batch_first=True)
128 | assert sample_distribution.shape[1] == num_fragments, f'{sample_distribution.shape}, {num_fragments}'
129 | if self.invert_teacher_sequence:
130 | sample_distribution = sample_distribution.flip(dims=(1,))
131 | cum_distribution = sample_distribution.cumsum(dim=-1)
132 |             rand_sources = torch.rand(bsz, self.num_samples, device=cum_distribution.device)
133 | # noinspection PyTypeChecker
134 | start_indices = torch.sum(cum_distribution[:, None, :] < rand_sources[:, :, None], dim=-1)
135 | return start_indices
136 |
137 | def _forward_transformer(self, memory, mem_mask, samples, negative_samples, padding_mask, dims):
138 | bsz, _, h, sample_length, _, _ = dims
139 | sample_length = 1 + sample_length # CLS token at the beginning
140 |
141 | model_in = torch.cat((samples.view(-1, sample_length, h), negative_samples.view(-1, sample_length, h)), dim=0)
142 | model_padding_mask = padding_mask.view(-1, sample_length).repeat(2, 1)
143 | extended_attn_mask = model_padding_mask[:, None, :] # Extend here already so that no causal mask is added.
144 | memory = self.em_dimension_adjustment_layer(memory)
145 | # bsz must be repeated for each sample and again twice (positive vs. negative sample)
146 | memory = memory.repeat(2 * self.num_samples, 1, 1)
147 | mem_mask = mem_mask.repeat(2 * self.num_samples, 1)
148 | output = self.decoder(inputs_embeds=model_in, attention_mask=extended_attn_mask,
149 | encoder_hidden_states=memory, encoder_attention_mask=mem_mask)
150 | hiddens_pos, hiddens_neg = output.last_hidden_state.split(model_in.shape[0] // 2)
151 | hiddens_pos = hiddens_pos.reshape(bsz, self.num_samples, sample_length, h)
152 | hiddens_neg = hiddens_neg.reshape(bsz, self.num_samples, sample_length, h)
153 | return hiddens_neg, hiddens_pos
154 |
155 | def _calc_recollection_loss(self, original_input, hiddens_pos, masked_items_map, start_indices, dims):
156 | bsz, _, _, _, num_tokens_to_mask, _ = dims
157 |
158 | recollection_loss = torch.scalar_tensor(0, device=hiddens_pos.device, dtype=hiddens_pos.dtype)
159 | for i in range(bsz):
160 | for s in range(self.num_samples):
161 | mask_indices = masked_items_map[i, s].nonzero().squeeze(dim=-1)
162 | # -1 because masked_items respects CLS tokens at position 0, which is not part of original_input
163 | original_masked_item_indices = mask_indices + start_indices[i, s].item() - 1
164 | reconstructed_items = F.linear(hiddens_pos[i, s, mask_indices],
165 | self.input_dimension_adjustment_layer.weight.T)
166 | # other neg. sampling strategy?
167 | # (here, the contrastive loss is calculated between all the masked items from the original input.
168 |                 # This might be bad, since neighboring video clips might actually have similar features.
169 | # Other option would be to sample randomly from another batch)
170 | # Also introduce another hyperparameter: number of sampled items. RM uses 30 for ActivityNetQA
171 | target_items = original_input[i, original_masked_item_indices]
172 | assert reconstructed_items.shape == target_items.shape, \
173 | f'{reconstructed_items.shape} != {target_items.shape}'
174 | products = torch.inner(reconstructed_items, target_items) # x,y = sum(reconstructed[x] * original[y])
175 | recollection_loss += torch.log_softmax(products, dim=1).trace()
176 | recollection_loss = - recollection_loss / (num_tokens_to_mask * self.num_samples * bsz)
177 | return recollection_loss
178 |
179 | def _calc_familiarity_loss(self, hiddens_neg, hiddens_pos, dims):
180 | bsz = dims[0]
181 | pos_cls_output = hiddens_pos[..., 0, :]
182 | neg_cls_output = hiddens_neg[..., 0, :]
183 | pos_scores = self.pos_neg_projection(pos_cls_output)
184 | neg_scores = self.pos_neg_projection(neg_cls_output)
185 | # This has wrong sign in the RM Paper (Eq. 4), leading to NaN
186 | # neg_score has goal 0 => 1-neg_score has goal 1 => want to maximize its log
187 | # (thus needs negative sign due to overall loss minimization)
188 | # since log(x) in (-inf, 0) for x in (0, 1)
189 | # analogous for pos_score, which has goal 1 directly
190 | familiarity_loss = pos_scores.log() + (1 - neg_scores).log()
191 | familiarity_loss = -torch.sum(familiarity_loss) / (self.num_samples * bsz)
192 | return familiarity_loss
193 |
194 |
195 | class RehearsalMemoryMachine(nn.Module):
196 |
197 | def __init__(self,
198 | pretrained_encoder: str,
199 | input_dim: int,
200 | mem_hidden_size: int,
201 | num_memory_slots: int,
202 | segment_length: int,
203 | slot_to_item_num_heads: int = 1,
204 | use_independent_gru_per_mem_slot=False
205 | ) -> None:
206 | super().__init__()
207 | self.segment_length = segment_length
208 | self.num_memory_slots = num_memory_slots
209 | self.mem_hidden_size = mem_hidden_size
210 | self.encoder: PreTrainedModel = AutoModel.from_pretrained(pretrained_encoder)
211 | if hasattr(self.encoder, 'get_encoder'): # In case pretrained_encoder is actually encoder-decoder model
212 | self.encoder = self.encoder.get_encoder()
213 | feature_size = self.encoder.get_input_embeddings().embedding_dim
214 | if input_dim == feature_size:
215 | self.input_transform = nn.Identity()
216 | else:
217 | self.input_transform = nn.Linear(input_dim, feature_size, bias=False)
218 | self.slot_to_item_attn = nn.MultiheadAttention(embed_dim=mem_hidden_size,
219 | kdim=feature_size, vdim=feature_size,
220 | num_heads=slot_to_item_num_heads,
221 | batch_first=True)
222 | num_recurrent_units = num_memory_slots if use_independent_gru_per_mem_slot else 1
223 | self.recurrent_units = nn.ModuleList([
224 | nn.GRUCell(input_size=mem_hidden_size,
225 | hidden_size=mem_hidden_size,
226 | bias=False)
227 | for _ in range(num_recurrent_units)
228 | ])
229 |
230 | def forward(self, input_items, input_mask) -> torch.Tensor:
231 | """
232 | Process the input items, and return memory state at the end.
233 | :param input_items: input of shape Batch x Sequence x InputHidden
234 | :param input_mask: mask of shape Batch x Sequence, 1 = valid token, 0 = masked token
235 | :return: memory state of shape Batch x NumMemorySlots x MemHidden
236 | """
237 | bsz, seq, h = input_items.shape
238 | assert input_mask.shape == (bsz, seq)
239 |
240 | x = self.input_transform(input_items)
241 |
242 | num_segments = math.ceil(seq / self.segment_length)
243 | memory = torch.zeros(bsz, self.num_memory_slots, self.mem_hidden_size,
244 | device=x.device, dtype=x.dtype)
245 |
246 | for t in range(num_segments):
247 | current_slice = slice(t * self.segment_length, (t + 1) * self.segment_length)
248 | x_t = x[:, current_slice]
249 | mask_t = input_mask[:, current_slice]
250 | active_batches = mask_t.sum(dim=1) > 0
251 | if not active_batches.any():
252 |                 raise RuntimeError('Encountered a segment in which every batch element is entirely padding')
253 |
254 | f_t = self.encoder(inputs_embeds=x_t[active_batches],
255 | attention_mask=mask_t[active_batches]).last_hidden_state
256 |
257 | l_t, attn_weights = self.slot_to_item_attn(query=memory[active_batches],
258 | key=f_t, value=f_t,
259 | key_padding_mask=~mask_t[active_batches])
260 | # l_t : bsz x num_memory_slots x mem_hidden
261 | if len(self.recurrent_units) == 1:
262 | flattened_l_t = l_t.reshape(-1, self.mem_hidden_size)
263 | flattened_mem = memory[active_batches].reshape(-1, self.mem_hidden_size)
264 | new_mem = self.recurrent_units[0](input=flattened_l_t, hx=flattened_mem)
265 | active_bsz = active_batches.sum()
266 | memory[active_batches] = new_mem.view(active_bsz, self.num_memory_slots, self.mem_hidden_size)
267 | else:
268 | for i in range(self.num_memory_slots):
269 | memory[active_batches, i, :] = self.recurrent_units[i](input=l_t[:, i, :],
270 | hx=memory[active_batches, i, :])
271 |
272 | return memory
273 |
274 |
275 | class RehearsalMemoryEmqaModel(MemoryAugmentedTransformerEmqaModel):
276 |
277 | def __init__(self,
278 | rehearsal_machine: RehearsalMemoryMachine,
279 | rehearsal_trainer: RehearsalTrainingModule,
280 | pretrained_enc_dec: str
281 | ) -> None:
282 | super().__init__(pretrained_enc_dec)
283 | self.rehearsal_machine = rehearsal_machine
284 | self.rehearsal_trainer = rehearsal_trainer
285 |
286 | def forward_memory(self, video_features, video_mask,
287 | moment_localization_labels, question_encoding # unused
288 | ):
289 | return self.rehearsal_machine(video_features, video_mask)
290 |
291 | def calc_additional_loss(self, question_tokens, question_mask, video_features, video_mask, answer_tokens,
292 | answer_mask, batch_sample_ids, context, context_mask, final_memory, mem_mask,
293 | transformer_output, moment_localization_labels):
294 | if self.rehearsal_trainer:
295 | loss_rec, loss_fam = self.rehearsal_trainer(final_memory, mem_mask,
296 | video_features, video_mask,
297 | batch_sample_ids)
298 | else:
299 | loss_rec, loss_fam = torch.zeros(2, dtype=context.dtype, device=context.device)
300 | return {
301 | 'recollection_loss': loss_rec,
302 | 'familiarity_loss': loss_fam
303 | }
304 |
--------------------------------------------------------------------------------
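Note: the snippet below is a minimal usage sketch for the two rehearsal components above, not a file of the repository. All dimensions, the 't5-small' encoder and the tiny RoBERTa decoder_params are illustrative placeholders; run it from the repository root so that the model package is importable. The machine compresses the feature sequence segment by segment into a fixed number of memory slots; the training module then derives the recollection and familiarity losses from that memory.

import torch
from model.rehearsal import RehearsalMemoryMachine, RehearsalTrainingModule

machine = RehearsalMemoryMachine(pretrained_encoder='t5-small', input_dim=512, mem_hidden_size=512,
                                 num_memory_slots=8, segment_length=16)
trainer = RehearsalTrainingModule(input_size=512, mem_hidden_size=512, num_samples=2, sample_length=8,
                                  decoder_params=dict(hidden_size=256, num_hidden_layers=2,
                                                      num_attention_heads=4, intermediate_size=512))

bsz, seq, dim = 2, 64, 512                                # bsz > 1 is required for negative sampling
video_features = torch.randn(bsz, seq, dim)
video_mask = torch.ones(bsz, seq, dtype=torch.bool)

memory = machine(video_features, video_mask)              # bsz x num_memory_slots x mem_hidden_size
mem_mask = torch.ones(bsz, memory.shape[1], dtype=torch.bool)
loss_rec, loss_fam = trainer(memory, mem_mask, video_features, video_mask,
                             batch_sample_ids=['clip_a', 'clip_b'])  # ids only used with a teacher weights file
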
/model/external/compressive_transformer.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/lucidrains/compressive-transformer-pytorch
2 |
3 | import math
4 | import sys
5 | from collections import namedtuple
6 | from functools import partial
7 | from inspect import isfunction
8 | from typing import Type, Tuple, List, Union
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | from torch import nn, Tensor
13 |
14 | # structs
15 |
16 | Memory: Type[Tuple[Tensor, List[Tensor], Tensor]] = namedtuple('Memory', ['mem', 'compressed_mem', 'lt_mem'])
17 |
18 |
19 | # helper functions
20 |
21 | def to(t):
22 | return {'dtype': t.dtype, 'device': t.device}
23 |
24 |
25 | def cast_tuple(el):
26 | return el if isinstance(el, tuple) else (el,)
27 |
28 |
29 | def default(x, val):
30 | if x is not None:
31 | return x
32 | return val if not isfunction(val) else val()
33 |
34 |
35 | def max_neg_value(tensor):
36 | return -torch.finfo(tensor.dtype).max
37 |
38 |
39 | def reshape_dim(t, dim, split_dims):
40 | shape = list(t.shape)
41 | num_dims = len(shape)
42 | dim = (dim + num_dims) % num_dims
43 | shape[dim:dim + 1] = split_dims
44 | return t.reshape(shape)
45 |
46 |
47 | def split_at_index(dim, index, t):
48 | pre_slices = (slice(None),) * dim
49 | l = (*pre_slices, slice(None, index))
50 | r = (*pre_slices, slice(index, None))
51 | return t[l], t[r]
52 |
53 |
54 | def queue_fifo(*args, length, dim=-2):
55 | queue = torch.cat(args, dim=dim)
56 | if length > 0:
57 | return split_at_index(dim, -length, queue)
58 |
59 | device = queue.device
60 | shape = list(queue.shape)
61 | shape[dim] = 0
62 | return queue, torch.empty(shape, device=device)
63 |
64 |
65 | def shift(x):
66 | *_, i, j = x.shape
67 | zero_pad = torch.zeros((*_, i, i), **to(x))
68 | x = torch.cat([x, zero_pad], -1)
69 | l = i + j - 1
70 | x = x.view(*_, -1)
71 | zero_pad = torch.zeros(*_, -x.size(-1) % l, **to(x))
72 | shifted = torch.cat([x, zero_pad], -1).view(*_, -1, l)
73 | return shifted[..., :i, i - 1:]
74 |
75 |
76 | def iterate_tensor(t):
77 | length = t.shape[0]
78 | for ind in range(length):
79 | yield t[ind]
80 |
81 |
82 | # full attention for calculating auxiliary reconstruction loss
83 |
84 | def full_attn(q, k, v, dropout_fn=None):
85 | *_, dim = q.shape
86 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * (dim ** -0.5)
87 | attn = dots.softmax(dim=-1)
88 | if dropout_fn is not None:
89 | attn = dropout_fn(attn)
90 | return torch.einsum('bhij,bhjd->bhid', attn, v)
91 |
92 |
93 | # helper classes
94 |
95 | class Residual(nn.Module):
96 | def __init__(self, fn):
97 | super().__init__()
98 | self.fn = fn
99 |
100 | def forward(self, x, **kwargs):
101 | out = self.fn(x, **kwargs)
102 | out = cast_tuple(out)
103 | ret = (out[0] + x), *out[1:]
104 | return ret
105 |
106 |
107 | class GRUGating(nn.Module):
108 | def __init__(self, dim, fn, mogrify=False):
109 | super().__init__()
110 | self.dim = dim
111 | self.fn = fn
112 | self.gru = nn.GRUCell(dim, dim)
113 |         self.mogrify = None  # stays None if mogrify is disabled or the library is missing
114 |         if mogrify:
115 |             try:
116 |                 # noinspection PyPackageRequirements
117 |                 from mogrifier import Mogrifier
118 |                 self.mogrify = Mogrifier(dim, factorize_k=dim // 4)
119 |             except ImportError:
120 |                 print('!! mogrify is set, but mogrifier library not available! Run "pip install mogrifier" to fix.', file=sys.stderr)
121 |
122 | def forward(self, x, **kwargs):
123 | batch, dim = x.shape[0], self.dim
124 | out = self.fn(x, **kwargs)
125 | (y, *rest) = cast_tuple(out)
126 |
127 | if self.mogrify is not None:
128 | y, x = self.mogrify(y, x)
129 |
130 | gated_output = self.gru(
131 | y.reshape(-1, dim),
132 | x.reshape(-1, dim)
133 | )
134 |
135 | gated_output = gated_output.reshape(batch, -1, dim)
136 | ret = gated_output, *rest
137 | return ret
138 |
139 |
140 | class PreNorm(nn.Module):
141 | def __init__(self, dim, fn):
142 | super().__init__()
143 | self.norm = nn.LayerNorm(dim)
144 | self.fn = fn
145 |
146 | def forward(self, x, **kwargs):
147 | x = self.norm(x)
148 | return self.fn(x, **kwargs)
149 |
150 |
151 | class ConvCompress(nn.Module):
152 | def __init__(self, dim, ratio=4):
153 | super().__init__()
154 | self.conv = nn.Conv1d(dim, dim, ratio, stride=ratio)
155 |
156 | def forward(self, mem):
157 | mem = mem.transpose(1, 2)
158 | compressed_mem = self.conv(mem)
159 | return compressed_mem.transpose(1, 2)
160 |
161 |
162 | class DetachedConvCompress(nn.Module):
163 | def __init__(self, reference: ConvCompress):
164 | super().__init__()
165 | self.reference = reference
166 |
167 | def forward(self, mem):
168 | weight = self.reference.conv.weight.detach()
169 | bias = self.reference.conv.bias.detach()
170 |
171 | mem = mem.transpose(1, 2)
172 | compressed_mem = F.conv1d(mem, weight, bias, self.reference.conv.stride,
173 | self.reference.conv.padding, self.reference.conv.dilation,
174 | self.reference.conv.groups)
175 | return compressed_mem.transpose(1, 2)
176 |
177 |
178 | # feedforward
179 |
180 | class GELU_(nn.Module):
181 | def forward(self, x):
182 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
183 |
184 |
185 | GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_
186 |
187 |
188 | class FeedForward(nn.Module):
189 | def __init__(self, dim, ff_dim, dropout=0., activation=None, glu=False):
190 | super().__init__()
191 | activation = default(activation, GELU)
192 |
193 | self.glu = glu
194 | self.w1 = nn.Linear(dim, ff_dim * (2 if glu else 1))
195 | self.act = activation()
196 | self.dropout = nn.Dropout(dropout)
197 | self.w2 = nn.Linear(ff_dim, dim)
198 |
199 | def forward(self, x, **kwargs):
200 | if not self.glu:
201 | x = self.w1(x)
202 | x = self.act(x)
203 | else:
204 | x, v = self.w1(x).chunk(2, dim=-1)
205 | x = self.act(x) * v
206 |
207 | x = self.dropout(x)
208 | x = self.w2(x)
209 | return x
210 |
211 |
212 | class CompressionStage(nn.Module):
213 |
214 | def __init__(self, dim, cmem_ratio, cmem_len, attn_heads, attn_dim_heads, reconstruction_attn_dropout,
215 | prev_lvl_mem_start_index, prev_lvl_mem_len) -> None:
216 | super().__init__()
217 | self.attn_heads = attn_heads # of the containing SelfAttention object
218 | self.attn_dim_heads = attn_dim_heads
219 | self.mem_len_this_lvl = cmem_len
220 | self.prev_lvl_mem_start_index = prev_lvl_mem_start_index
221 | self.prev_lvl_mem_len = prev_lvl_mem_len
222 |
223 | assert prev_lvl_mem_len % cmem_ratio == 0, \
224 |             f'mem length of previous level ({prev_lvl_mem_len}) must be divisible by compression ratio ({cmem_ratio})'
225 |
226 | self.reconstruction_attn_dropout = nn.Dropout(reconstruction_attn_dropout)
227 | self.compress_mem_fn = ConvCompress(dim, cmem_ratio)
228 | self.compress_mem_fn_without_grad = DetachedConvCompress(self.compress_mem_fn)
229 |
230 | def forward(self, prev_cmem_this_lvl, old_mem_prev_lvl, prev_lvl_mem_len, q, k, v, to_kv_weight):
231 | compressed_mem = self.compress_mem_fn_without_grad(old_mem_prev_lvl)
232 | old_cmem, new_cmem = split_at_index(1, -self.mem_len_this_lvl,
233 | torch.cat((prev_cmem_this_lvl, compressed_mem), dim=1))
234 | aux_loss = torch.zeros(1, requires_grad=True, **to(prev_cmem_this_lvl))
235 |
236 | if not self.training:
237 | return old_cmem, new_cmem, aux_loss
238 |
239 | # calculate compressed memory auxiliary loss if training
240 | merge_heads = lambda x: reshape_dim(x, -1, (-1, self.attn_dim_heads)).transpose(1, 2)
241 |
242 | compressed_mem = self.compress_mem_fn(old_mem_prev_lvl.detach())
243 | cmem_k, cmem_v = F.linear(compressed_mem, to_kv_weight.detach()).chunk(2, dim=-1)
244 | cmem_k, cmem_v = map(merge_heads, (cmem_k, cmem_v))
245 | cmem_k, cmem_v = map(lambda x: x.expand(-1, self.attn_heads, -1, -1), (cmem_k, cmem_v))
246 |
247 | old_mem_range = slice(- min(prev_lvl_mem_len, self.prev_lvl_mem_len) - self.prev_lvl_mem_start_index,
248 | -self.prev_lvl_mem_start_index)
249 | old_mem_k, old_mem_v = map(lambda x: x[:, :, old_mem_range].clone(), (k, v))
250 |
251 | q, old_mem_k, old_mem_v = map(torch.detach, (q, old_mem_k, old_mem_v))
252 |
253 | attn_fn = partial(full_attn, dropout_fn=self.reconstruction_attn_dropout)
254 |
255 | aux_loss = F.mse_loss(
256 | attn_fn(q, old_mem_k, old_mem_v),
257 | attn_fn(q, cmem_k, cmem_v)
258 | )
259 |
260 | return old_cmem, new_cmem, aux_loss
261 |
262 |
263 | # attention.
264 |
265 | class SelfAttention(nn.Module):
266 |
267 | @staticmethod
268 | def validate_cmem_parameters(seq_len: int, mem_len: int,
269 | cmem_lengths: List[int], cmem_ratios: Union[List[int], int]):
270 | assert len(cmem_lengths) == len(cmem_ratios), f'{cmem_lengths}, {cmem_ratios} should have same length!'
271 | compression_levels = len(cmem_lengths)
272 | # compression stage 0 is mem -> cmem
273 | one_input_block_size = seq_len
274 | for i in range(compression_levels):
275 | assert one_input_block_size >= cmem_ratios[i], \
276 | f'At compression level {i}, one input block of {seq_len} tokens is already reduced to ' \
277 | f'{one_input_block_size} compressed tokens, cannot be compressed again with ratio {cmem_ratios[i]}'
278 | assert cmem_lengths[i] >= (one_input_block_size // cmem_ratios[i]), \
279 | f'length of compressed memory at level {i + 1} should be at least the compressed input block length ' \
280 | f'at level {i} ({one_input_block_size}) divided by the compression ratio {cmem_ratios[i]}, ' \
281 | f'i.e. at least {int(one_input_block_size // cmem_ratios[i])}'
282 | one_input_block_size //= cmem_ratios[i]
283 |
284 | # simulate information flow
285 | log = ''
286 | mem = 0
287 | cmems = [0] * compression_levels
288 | while True: # simulate until lt mem would be filled. then, sizes do not change anymore (everything full)
289 | mem += seq_len
290 | log += f'i={seq_len} -> '
291 | if mem <= mem_len:
292 | log += f'm={mem}\n'
293 | continue
294 | old_mem = mem - mem_len
295 | mem = mem_len
296 | log += f'm={mem} -> {old_mem}'
297 | for lvl in range(compression_levels):
298 | log += f' --/{cmem_ratios[lvl]}--> c{lvl}='
299 | assert old_mem % cmem_ratios[lvl] == 0, \
300 | f'mem length {old_mem} from previous layer not divisible by compression ratio {cmem_ratios[lvl]} ' \
301 | f'at compression level {lvl}. Log:\n{log}'
302 | cmems[lvl] += old_mem // cmem_ratios[lvl]
303 | if cmems[lvl] <= cmem_lengths[lvl]:
304 | log += f'{cmems[lvl]}'
305 | old_mem = 0
306 | break
307 | old_mem = cmems[lvl] - cmem_lengths[lvl]
308 | cmems[lvl] = cmem_lengths[lvl]
309 | log += f'{cmems[lvl]} -> {old_mem}'
310 | log += '\n'
311 | if old_mem > 0:
312 | break
313 |
314 | def __init__(self, dim, seq_len, mem_len: int,
315 | cmem_lengths: List[int], cmem_ratios: Union[List[int], int],
316 | use_ltmem=True,
317 | heads=8, attn_dropout=0., dropout=0.,
318 | reconstruction_attn_dropout=0., one_kv_head=False):
319 | super().__init__()
320 | assert (dim % heads) == 0, 'dimension must be divisible by the number of heads'
321 | if isinstance(cmem_ratios, int):
322 | cmem_ratios = [cmem_ratios] * len(cmem_lengths)
323 | SelfAttention.validate_cmem_parameters(seq_len, mem_len, cmem_lengths, cmem_ratios)
324 |
325 | self.heads = heads
326 | self.dim_head = dim // heads
327 | self.seq_len = seq_len
328 | self.mem_len = mem_len
329 | self.num_cmem_stages = len(cmem_lengths)
330 | self.cmem_lengths = cmem_lengths
331 | self.cmem_ratios = cmem_ratios
332 | self.use_ltmem = use_ltmem
333 | self.scale = self.dim_head ** (-0.5)
334 |
335 | self.compression_stages = nn.ModuleList()
336 | running_start_index = self.seq_len
337 | prev_length = mem_len
338 | for i in range(self.num_cmem_stages):
339 | self.compression_stages.append(CompressionStage(
340 | dim, cmem_ratios[i], cmem_lengths[i], heads, self.dim_head,
341 | reconstruction_attn_dropout,
342 | prev_lvl_mem_start_index=running_start_index,
343 | prev_lvl_mem_len=prev_length))
344 | prev_length = cmem_lengths[i]
345 | running_start_index += prev_length
346 |
347 | if self.use_ltmem:
348 | self.cmem_to_ltmem_query = nn.Parameter(torch.zeros(dim), requires_grad=True)
349 | self.ltmem_tokv = nn.Linear(dim, dim * 2, bias=False)
350 | self.recurrence = nn.GRUCell(dim, dim, bias=False)
351 |
352 | self.to_q = nn.Linear(dim, dim, bias=False)
353 |
354 | kv_dim = self.dim_head if one_kv_head else dim
355 | self.to_kv = nn.Linear(dim, kv_dim * 2, bias=False)
356 | self.to_out = nn.Linear(dim, dim)
357 |
358 | self.attn_dropout = nn.Dropout(attn_dropout)
359 | self.dropout = nn.Dropout(dropout)
360 |
361 | def forward(self, x, memories=None, pos_emb=None, input_mask=None, calc_memory=True, **kwargs):
362 | b, t, e, h, dim_h = *x.shape, self.heads, self.dim_head
363 |
364 | memories: Memory = default(memories, (None, None, None))
365 | mem, cmems, ltmem = memories
366 |
367 | init_empty_mem = lambda: torch.empty(b, 0, e, **to(x))
368 | mem = default(mem, init_empty_mem)
369 | cmems = default(cmems, lambda: [init_empty_mem() for i in range(self.num_cmem_stages)])
370 | ltmem = default(ltmem, init_empty_mem)
371 |
372 | mem_len = mem.shape[1]
373 | cmem_len_sum = sum(cmem.shape[1] for cmem in cmems)
374 | ltmem_len = ltmem.shape[1]
375 | assert 0 <= ltmem_len <= 1, str(ltmem)
376 |
377 | q = self.to_q(x)
378 |
379 | if self.num_cmem_stages == 0:
380 | kv_input = torch.cat((ltmem, mem, x), dim=1)
381 | else:
382 | kv_input = torch.cat((ltmem, *cmems, mem, x), dim=1)
383 | kv_len = kv_input.shape[1]
384 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
385 |
386 | merge_heads = lambda x: reshape_dim(x, -1, (-1, dim_h)).transpose(1, 2)
387 | q, k, v = map(merge_heads, (q, k, v))
388 |
389 | k, v = map(lambda x: x.expand(-1, h, -1, -1), (k, v))
390 |
391 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
392 | mask_value = max_neg_value(dots)
393 |
394 | if pos_emb is not None:
395 | pos_emb = pos_emb[:, -kv_len:].type(q.dtype)
396 | pos_dots = torch.einsum('bhid,hjd->bhij', q, pos_emb) * self.scale
397 | pos_dots = shift(pos_dots)
398 | dots = dots + pos_dots
399 |
400 | if input_mask is not None:
401 | mask = input_mask[:, None, :, None] * input_mask[:, None, None, :]
402 | mask = F.pad(mask, [mem_len + cmem_len_sum + ltmem_len, 0], value=True)
403 | dots.masked_fill_(~mask, mask_value)
404 |
405 | total_mem_len = mem_len + cmem_len_sum + ltmem_len
406 | mask = torch.ones(t, t + total_mem_len, **to(x)).triu_(diagonal=1 + total_mem_len).bool()
407 | dots.masked_fill_(mask[None, None, ...], mask_value)
408 |
409 | attn = dots.softmax(dim=-1)
410 | attn = self.attn_dropout(attn)
411 |
412 | out = torch.einsum('bhij,bhjd->bhid', attn, v)
413 | out = out.transpose(1, 2).reshape(b, t, -1)
414 | logits = self.to_out(out)
415 | logits = self.dropout(logits)
416 |
417 | new_mem = mem
418 | new_cmems = cmems
419 | new_ltmem = ltmem
420 | aux_loss = torch.zeros(1, requires_grad=False, **to(q))
421 |
422 | if self.seq_len > t or not calc_memory:
423 | return logits, Memory(new_mem, new_cmems, new_ltmem), aux_loss
424 |
425 | # calculate memory and compressed memory
426 |
427 | old_mem, new_mem = queue_fifo(mem, x, length=self.mem_len, dim=1)
428 |         if old_mem.shape[1] == 0 or self.num_cmem_stages <= 0:
429 |             return logits, Memory(new_mem, new_cmems, new_ltmem), aux_loss
430 | 
431 |         # pad old_mem so that its length is divisible by the first-level compression ratio
432 |         old_mem_padding = old_mem.shape[1] % self.cmem_ratios[0]
433 |         if old_mem_padding != 0:
434 |             old_mem = F.pad(old_mem, [0, 0, old_mem_padding, 0], value=0.)
435 |
436 | prev_mem_len = mem_len
437 | old_mem_prev_lvl = old_mem
438 | for i in range(self.num_cmem_stages):
439 | if old_mem_prev_lvl.size(1) == 0:
440 | break
441 | old_mem_prev_lvl, new_cmems[i], lvl_aux_loss = self.compression_stages[i](
442 | prev_cmem_this_lvl=cmems[i],
443 | old_mem_prev_lvl=old_mem_prev_lvl,
444 | prev_lvl_mem_len=prev_mem_len,
445 | q=q, k=k, v=v,
446 | to_kv_weight=self.to_kv.weight
447 | )
448 | aux_loss += lvl_aux_loss
449 | prev_mem_len = cmems[i].size(1)
450 |
451 | if old_mem_prev_lvl.size(1) > 0 and self.use_ltmem:
452 | old_cmem_k, old_cmem_v = (self.ltmem_tokv(old_mem_prev_lvl)
453 | .unsqueeze(dim=1) # Insert fake head dimension
454 | .chunk(2, dim=-1))
455 | to_ltmem_query = self.cmem_to_ltmem_query.expand(b, 1, 1, e) # b x 1(=h) x 1(=seq) x e
456 | ltmem_update = full_attn(to_ltmem_query, old_cmem_k, old_cmem_v)
457 | if ltmem_len > 0:
458 | new_ltmem = self.recurrence(ltmem_update.view(b, e), ltmem.squeeze(dim=1)).unsqueeze(dim=1)
459 | else:
460 | new_ltmem = ltmem_update.squeeze(dim=1) # Remove heads dimension
461 |
462 | return logits, Memory(new_mem, new_cmems, new_ltmem), aux_loss
463 |
464 |
465 | # transformer
466 |
467 | class CompressiveTransformer(nn.Module):
468 | def __init__(self, num_tokens, dim, seq_len, depth, emb_dim=None,
469 | memory_layers=None, mem_len=None,
470 | cmem_lengths: List[int] = None, cmem_ratios: Union[int, List[int]] = 4,
471 | use_ltmem=True,
472 | heads=8, gru_gated_residual=True, mogrify_gru=False, attn_dropout=0.,
473 | ff_glu=False, ff_dim=None, ff_dropout=0.,
474 | attn_layer_dropout=0., reconstruction_attn_dropout=0., reconstruction_loss_weight=1.,
475 | one_kv_head=False):
476 | super().__init__()
477 | if isinstance(cmem_ratios, int):
478 | if cmem_lengths is None:
479 | cmem_ratios = [cmem_ratios]
480 | else:
481 | cmem_ratios = [cmem_ratios] * len(cmem_lengths)
482 | else:
483 | assert cmem_lengths is not None
484 | assert len(cmem_lengths) == len(cmem_ratios)
485 |
486 | ff_dim = default(ff_dim, dim * 4)
487 | emb_dim = default(emb_dim, dim)
488 | mem_len = default(mem_len, seq_len)
489 | cmem_lengths = default(cmem_lengths, [mem_len // cmem_ratios[0]])
490 | memory_layers = default(memory_layers, list(range(1, depth + 1)))
491 |
492 | assert mem_len >= seq_len, 'length of memory should be at least the sequence length'
493 | assert all(
494 | [0 < layer <= depth for layer in memory_layers]), 'one of the indicated memory layers is invalid'
495 |
496 | self.seq_len = seq_len
497 |
498 | self.depth = depth
499 | self.memory_layers = list(memory_layers)
500 | self.num_cmem_stages = len(cmem_lengths)
501 |
502 | self.token_emb = nn.Embedding(num_tokens, emb_dim)
503 | self.to_model_dim = nn.Identity() if emb_dim == dim else nn.Linear(emb_dim, dim)
504 |
505 | seq_and_mem_len = seq_len + mem_len + sum(cmem_lengths) + (1 if use_ltmem else 0) # + 1 for LT Memory
506 | self.pos_emb = nn.Parameter(torch.zeros(heads, seq_and_mem_len, dim // heads), requires_grad=True)
507 |
508 | self.to_logits = nn.Sequential(
509 | nn.Identity() if emb_dim == dim else nn.Linear(dim, emb_dim),
510 | nn.Linear(emb_dim, num_tokens)
511 | )
512 |
513 | wrapper = partial(GRUGating, dim, mogrify=mogrify_gru) if gru_gated_residual else Residual
514 |
515 | self.attn_layers = nn.ModuleList([
516 | wrapper(PreNorm(dim, SelfAttention(
517 | dim, seq_len, mem_len,
518 | cmem_lengths if (i + 1) in memory_layers else [],
519 | cmem_ratios if (i + 1) in memory_layers else [],
520 | use_ltmem and (i + 1) in memory_layers,
521 | heads, dropout=attn_layer_dropout,
522 | attn_dropout=attn_dropout,
523 | reconstruction_attn_dropout=reconstruction_attn_dropout,
524 | one_kv_head=one_kv_head
525 | ))) for i in range(depth)])
526 | self.ff_layers = nn.ModuleList(
527 | [wrapper(PreNorm(dim, FeedForward(dim, ff_dim, dropout=ff_dropout, glu=ff_glu))) for _ in range(depth)])
528 |
529 | self.reconstruction_loss_weight = reconstruction_loss_weight
530 |
531 | def forward(self, x, memories=None, mask=None):
532 | input_device = x.device
533 | x = self.token_emb(x)
534 | x = self.to_model_dim(x)
535 | b, t, d = x.shape
536 |
537 | assert t <= self.seq_len, f'input contains a sequence length {t} that is greater than the designated maximum ' \
538 | f'sequence length {self.seq_len} '
539 |
540 | memories = default(memories, (None, None, None))
541 | mem, cmems, ltmem = memories
542 |
543 | num_memory_layers = len(self.memory_layers)
544 | init_empty_mem = lambda: torch.empty(num_memory_layers, b, 0, d, **to(x))
545 | mem = default(mem, init_empty_mem)
546 | cmems = default(cmems, lambda: [init_empty_mem() for i in range(self.num_cmem_stages)])
547 | ltmem = default(ltmem, init_empty_mem)
548 |
549 | total_len = mem.shape[2] + sum(cmem.shape[2] for cmem in cmems) + ltmem.shape[2] + self.seq_len
550 | pos_emb = self.pos_emb[:, (self.seq_len - t):total_len]
551 |
552 | # Lists of {c,lt,}mem per transformer layer
553 | next_mem = []
554 | next_cmems = []
555 | next_ltmem = []
556 | aux_loss = torch.tensor(0., requires_grad=True, **to(x))
557 |
558 | mem_iter, ltmem_iter = map(iterate_tensor, (mem, ltmem))
559 | cmems_iter = ([cmem[i] for cmem in cmems] for i in range(num_memory_layers))
560 |
561 | for ind in range(self.depth):
562 | x, mem_out, cmems_out, ltmem_out, layer_aux_loss \
563 | = self._pass_through_layer(ind, mem_iter, cmems_iter, ltmem_iter, mask, pos_emb, x)
564 | aux_loss = aux_loss + layer_aux_loss
565 |
566 | if (ind + 1) not in self.memory_layers:
567 | continue
568 |
569 | next_mem.append(mem_out)
570 | next_cmems.append(cmems_out)
571 | next_ltmem.append(ltmem_out)
572 |
573 | out = self.to_logits(x)
574 |
575 | next_mem, next_ltmem = map(torch.stack, (next_mem, next_ltmem))
576 | next_cmems = [torch.stack([next_cmems[layer][cstage] for layer in range(num_memory_layers)])
577 | for cstage in range(self.num_cmem_stages)]
578 |
579 | aux_loss = aux_loss * self.reconstruction_loss_weight / num_memory_layers
580 | out = out.to(device=input_device)
581 | return out, Memory(mem=next_mem, compressed_mem=next_cmems, lt_mem=next_ltmem), aux_loss
582 |
583 | def _pass_through_layer(self, ind, mem_iter, cmems_iter, ltmem_iter, mask, pos_emb, x):
584 | attn = self.attn_layers[ind]
585 | ff = self.ff_layers[ind]
586 |
587 | layer_num = ind + 1
588 | use_memory = layer_num in self.memory_layers
589 | memories = (next(mem_iter), next(cmems_iter), next(ltmem_iter)) if use_memory else None
590 | _dev = lambda t: t.to(device=x.device)
591 | memories = (_dev(memories[0]), [_dev(m) for m in memories[1]], _dev(memories[2])) if memories else None
592 |
593 | x, (mem_out, cmems_out, ltmem_out), layer_aux_loss = attn(x, memories=memories, calc_memory=use_memory,
594 | input_mask=mask, pos_emb=pos_emb)
595 | x, = ff(x)
596 |
597 | return x, mem_out, cmems_out, ltmem_out, layer_aux_loss
598 |
599 |
600 | class MultiDeviceCompressiveTransformer(CompressiveTransformer):
601 | """
602 | CompressiveTransformer with model parallelism.
603 | Note: Start fairseq-train with
604 | --distributed-no-spawn
605 | --distributed-world-size 1
606 | to prevent data parallelism
607 | """
608 |
609 | def __init__(self, num_tokens, dim, seq_len, depth, emb_dim=None, memory_layers=None, mem_len=None,
610 | cmem_lengths: List[int] = None, cmem_ratios: Union[int, List[int]] = 4, use_ltmem=True, heads=8,
611 | gru_gated_residual=True, mogrify_gru=False, attn_dropout=0., ff_glu=False, ff_dim=None, ff_dropout=0.,
612 | attn_layer_dropout=0., reconstruction_attn_dropout=0., reconstruction_loss_weight=1.,
613 | one_kv_head=False,
614 | layers_to_gpus=None):
615 | super().__init__(num_tokens, dim, seq_len, depth, emb_dim, memory_layers, mem_len, cmem_lengths, cmem_ratios,
616 | use_ltmem, heads, gru_gated_residual, mogrify_gru, attn_dropout, ff_glu, ff_dim, ff_dropout,
617 | attn_layer_dropout, reconstruction_attn_dropout, reconstruction_loss_weight, one_kv_head)
618 |
619 | gpus = torch.cuda.device_count()
620 | layers_to_gpus = default(layers_to_gpus, [int(i / self.depth * gpus) for i in range(self.depth)])
621 | assert len(layers_to_gpus) == self.depth
622 | assert all(0 <= x < gpus for x in layers_to_gpus)
623 | self.layers_to_gpus = layers_to_gpus
624 |
625 | def cuda(self, device=None):
626 | # pos_emb, token_emb, to_model_dim and to_logits always stays on device 0
627 | self.pos_emb = nn.Parameter(self.pos_emb.cuda(), requires_grad=True)
628 | self.token_emb.to(device=0)
629 | self.to_model_dim.to(device=0)
630 | self.to_logits.to(device=torch.cuda.device_count() - 1)
631 | for i in range(self.depth):
632 | self.attn_layers[i].to(device=self.layers_to_gpus[i])
633 | self.ff_layers[i].to(device=self.layers_to_gpus[i])
634 | return self
635 |
636 | def _apply(self, fn):
637 | fake = torch.empty(0)
638 | if fn(fake).device.type == 'cuda' and fn(fake).device != fake.device:
639 | return self.cuda()
640 | else:
641 | # noinspection PyProtectedMember
642 | return super()._apply(fn)
643 |
644 | def _pass_through_layer(self, ind, mem_iter, cmems_iter, ltmem_iter, mask, pos_emb, x):
645 | gpu = self.layers_to_gpus[ind]
646 |
647 | x = x.to(device=gpu)
648 | pos_emb = pos_emb.to(device=gpu)
649 |         mask = mask.to(device=gpu) if mask is not None else None
650 | x, mem_out, cmems_out, ltmem_out, layer_aux_loss = super()._pass_through_layer(
651 | ind, mem_iter, cmems_iter, ltmem_iter, mask, pos_emb, x)
652 |
653 | mem_out = mem_out.to(device=0) if mem_out is not None else None
654 | cmems_out = [m.to(device=0) for m in cmems_out] if cmems_out is not None else None
655 | ltmem_out = ltmem_out.to(device=0) if ltmem_out is not None else None
656 | layer_aux_loss = layer_aux_loss.to(device=0) if layer_aux_loss is not None else None
657 |
658 | return x, mem_out, cmems_out, ltmem_out, layer_aux_loss
659 |
--------------------------------------------------------------------------------
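Note: a minimal usage sketch (again not a repository file) for the CompressiveTransformer above. The hyperparameters are illustrative; they only have to satisfy SelfAttention.validate_cmem_parameters (here mem_len = seq_len = 32 is compressed into a single 8-slot compressed memory with ratio 4). The Memory namedtuple returned by forward is fed back in for the next segment, and aux_loss is the compressed-memory reconstruction loss, which becomes non-zero in training mode once tokens start being evicted from the primary memory.

import torch
from model.external.compressive_transformer import CompressiveTransformer

model = CompressiveTransformer(num_tokens=1000, dim=512, seq_len=32, depth=2,
                               mem_len=32, cmem_lengths=[8], cmem_ratios=4)

tokens = torch.randint(0, 1000, (1, 128))             # long sequence, processed segment by segment
memories, total_aux_loss = None, 0.
for segment in tokens.split(32, dim=1):               # segment length == seq_len
    logits, memories, aux_loss = model(segment, memories=memories)
    total_aux_loss = total_aux_loss + aux_loss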