├── config ├── model │ ├── groundvqa_b.yaml │ └── groundvqa_s.yaml ├── dataset │ └── egovlp_internvideo.yaml └── base.yaml ├── scripts ├── train_groundvqa_s-qaego4d.sh ├── train_groundvqa_b-qaego4d_egotimeqa.sh ├── train_groundvqa_s-qaego4d_egotimeqa.sh ├── train_groundvqa_b-nlq.sh ├── evaluate_groundvqa_b-nlq.sh ├── evaluate_groundvqa_b-nlq_naq.sh ├── evaluate_groundvqa_s-qaego4d.sh ├── evaluate_groundvqa_b-qaego4d_egotimeqa.sh └── evaluate_groundvqa_s-qaego4d_egotimeqa.sh ├── requirements.txt ├── utils ├── generate_open_qa │ ├── merge.py │ └── generate.py └── generate_close_qa │ └── generate.py ├── LICENSE ├── model └── ours │ ├── model.py │ ├── dataset.py │ ├── lightning_module.py │ └── nlq_head.py ├── .gitignore ├── README.md ├── run.py ├── eval.py └── eval_nlq.py /config/model/groundvqa_b.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.ours.model.GroundVQA 2 | lm_path: google/flan-t5-base 3 | input_dim: ${dataset.feature_dim} 4 | freeze_word: True 5 | -------------------------------------------------------------------------------- /config/model/groundvqa_s.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.ours.model.GroundVQA 2 | lm_path: google/flan-t5-small 3 | input_dim: ${dataset.feature_dim} 4 | freeze_word: True 5 | -------------------------------------------------------------------------------- /scripts/train_groundvqa_s-qaego4d.sh: -------------------------------------------------------------------------------- 1 | # train with QaEgo4D 2 | CUDA_VISIBLE_DEVICES=6,7 python run.py \ 3 | model=groundvqa_s \ 4 | 'dataset.qa_train_splits=[QaEgo4D_train]' \ 5 | dataset.batch_size=64 \ 6 | trainer.gpus=2 -------------------------------------------------------------------------------- /scripts/train_groundvqa_b-qaego4d_egotimeqa.sh: -------------------------------------------------------------------------------- 1 | # train with QaEgo4D + EgoTimeQA 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python run.py \ 3 | model=groundvqa_b \ 4 | 'dataset.qa_train_splits=[QaEgo4D_train,EgoTimeQA]' \ 5 | dataset.batch_size=16 \ 6 | trainer.gpus=8 7 | -------------------------------------------------------------------------------- /scripts/train_groundvqa_s-qaego4d_egotimeqa.sh: -------------------------------------------------------------------------------- 1 | # train with QaEgo4D + EgoTimeQA 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python run.py \ 3 | model=groundvqa_s \ 4 | 'dataset.qa_train_splits=[QaEgo4D_train,EgoTimeQA]' \ 5 | dataset.batch_size=32 \ 6 | trainer.gpus=8 7 | -------------------------------------------------------------------------------- /scripts/train_groundvqa_b-nlq.sh: -------------------------------------------------------------------------------- 1 | # train with NLQv2 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python run.py \ 3 | model=groundvqa_b \ 4 | 'dataset.nlq_train_splits=[NLQ_train]' \ 5 | 'dataset.test_splits=[NLQ_val]' \ 6 | dataset.batch_size=16 \ 7 | trainer.find_unused_parameters=True \ 8 | trainer.gpus=8 9 | -------------------------------------------------------------------------------- /scripts/evaluate_groundvqa_b-nlq.sh: -------------------------------------------------------------------------------- 1 | # NLQv2 val set 2 | CUDA_VISIBLE_DEVICES=1 python run.py \ 3 | model=groundvqa_b \ 4 | 'dataset.qa_train_splits=[QaEgo4D_train]' \ 5 | 'dataset.test_splits=[NLQ_val]' \ 6 | dataset.batch_size=32 \ 7 | +trainer.test_only=True \ 8 | 
'+trainer.checkpoint_path=""' \ 9 | trainer.load_nlq_head=True 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.1 2 | torchvision==0.16.1 3 | tqdm 4 | ipdb 5 | h5py==3.10.0 6 | terminaltables 7 | nltk 8 | rouge-score==0.1.2 9 | numpy==1.23.5 10 | pytorch_lightning==1.5.10 11 | hydra-core==1.3.2 12 | optimum==1.14.1 13 | sentence_transformers==2.2.2 14 | positional_encodings==6.0.1 15 | ffmpeg-python==0.2.0 16 | transformers==4.35.2 17 | accelerate==0.24.1 18 | bitsandbytes -------------------------------------------------------------------------------- /scripts/evaluate_groundvqa_b-nlq_naq.sh: -------------------------------------------------------------------------------- 1 | # NLQv2 val set 2 | python run.py \ 3 | model=groundvqa_b \ 4 | 'dataset.qa_train_splits=[QaEgo4D_train]' \ 5 | 'dataset.test_splits=[NLQ_val]' \ 6 | dataset.batch_size=32 \ 7 | +trainer.test_only=True \ 8 | '+trainer.checkpoint_path="checkpoints/GroundVQA_B-NLQ_NaQ-finetune_NLQ-VLG-val_R1_03=29.7.ckpt"' \ 9 | trainer.load_nlq_head=True -------------------------------------------------------------------------------- /config/dataset/egovlp_internvideo.yaml: -------------------------------------------------------------------------------- 1 | data_dir: data/unified 2 | nlq_val_anno: data/nlq_v2/nlq_val.json 3 | feature_type: egovlp_internvideo 4 | feature_dim: 2304 5 | max_v_len: 1200 6 | 7 | qa_train_splits: [] 8 | nlq_train_splits: [] 9 | test_splits: ['QaEgo4D_test', 'QaEgo4D_test_close', 'NLQ_val'] 10 | closeqa_weight: 50 11 | 12 | tokenizer_path: google/flan-t5-small 13 | 14 | num_workers: 4 15 | batch_size: 16 16 | -------------------------------------------------------------------------------- /config/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: egovlp_internvideo 3 | - model: groundvqa_s 4 | - _self_ 5 | - override hydra/job_logging: none 6 | - override hydra/hydra_logging: none 7 | 8 | trainer: 9 | detect_anomaly: True 10 | max_epochs: 100 11 | accumulate_grad_batches: 1 12 | auto_resume: False 13 | gpus: 1 14 | log_every_n_steps: 1 15 | auto_lr_find: False 16 | enable_progress_bar: True 17 | monitor_variable: val_ROUGE 18 | monitor_mode: max 19 | find_unused_parameters: False 20 | precision: bf16 21 | val: False # test on the val set 22 | gradient_clip_val: 1.0 23 | save_nlq_results: null 24 | deterministic: True 25 | load_decoder: True 26 | load_nlq_head: True 27 | ignore_existing_checkpoints: True 28 | 29 | optim: 30 | optimizer: 31 | _target_: torch.optim.AdamW 32 | lr: 0.0001 33 | weight_decay: 0.0 34 | freeze: [ ] 35 | lr_scheduler: False 36 | 37 | hydra: 38 | run: 39 | dir: . 
40 | output_subdir: null 41 | -------------------------------------------------------------------------------- /utils/generate_open_qa/merge.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | import numpy as np 4 | 5 | 6 | split = 'EgoTimeQA' 7 | src1 = f'tmp/annotations.{split}_*.json' 8 | tgt = f'annotations.{split}.json' 9 | paths = glob.glob(src1) 10 | paths = sorted(paths) 11 | print('Merging') 12 | 13 | merge = [] 14 | for p in paths: 15 | print(p) 16 | x = json.load(open(p)) 17 | merge += x 18 | 19 | all_duration_sec = [(x['moment_end_frame'] - x['moment_start_frame']) / 30 for x in merge] 20 | mean_duration_sec = np.asarray(all_duration_sec).mean() 21 | # normalize duration_sec 22 | for x in merge: 23 | start_sec = x['moment_start_frame'] / 30 24 | end_sec = x['moment_end_frame'] / 30 25 | center_sec = (start_sec + end_sec) / 2 26 | duration_sec = (end_sec - start_sec) / mean_duration_sec 27 | x['moment_start_frame'] = (center_sec - duration_sec / 2) * 30 28 | x['moment_end_frame'] = (center_sec + duration_sec / 2) * 30 29 | 30 | print(f'into {tgt}') 31 | with open(tgt, 'w') as f: 32 | json.dump(merge, f) 33 | 34 | print(len(merge)) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Bright 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /scripts/evaluate_groundvqa_s-qaego4d.sh: -------------------------------------------------------------------------------- 1 | # QAEgo4D-Close test set 2 | for SEED in 0 1111 2222 3333 4444 3 | do 4 | python run.py \ 5 | model=groundvqa_s \ 6 | 'dataset.qa_train_splits=[QaEgo4D_train]' \ 7 | 'dataset.test_splits=[QaEgo4D_test_close]' \ 8 | dataset.batch_size=64 \ 9 | +trainer.test_only=True \ 10 | '+trainer.checkpoint_path="checkpoints/GroundVQA_S-QaEgo4D-COV-test_ROUGE=29.0.ckpt"' \ 11 | +trainer.random_seed=$SEED 12 | done 13 | 14 | # QAEgo4D-Open test set 15 | python run.py \ 16 | model=groundvqa_s \ 17 | 'dataset.qa_train_splits=[QaEgo4D_train]' \ 18 | 'dataset.test_splits=[QaEgo4D_test]' \ 19 | dataset.batch_size=64 \ 20 | +trainer.test_only=True \ 21 | '+trainer.checkpoint_path="checkpoints/GroundVQA_S-QaEgo4D-COV-test_ROUGE=29.0.ckpt"' 22 | 23 | # NLQv2 val set 24 | python run.py \ 25 | model=groundvqa_s \ 26 | 'dataset.qa_train_splits=[QaEgo4D_train]' \ 27 | 'dataset.test_splits=[NLQ_val]' \ 28 | dataset.batch_size=64 \ 29 | +trainer.test_only=True \ 30 | '+trainer.checkpoint_path="checkpoints/GroundVQA_S-QaEgo4D-COV-val_R1_03=11.0.ckpt"' \ 31 | trainer.load_nlq_head=True 32 | -------------------------------------------------------------------------------- /scripts/evaluate_groundvqa_b-qaego4d_egotimeqa.sh: -------------------------------------------------------------------------------- 1 | # QAEgo4D-Close test set 2 | for SEED in 0 1111 2222 3333 4444 3 | do 4 | python run.py \ 5 | model=groundvqa_b \ 6 | 'dataset.qa_train_splits=[QaEgo4D_train]' \ 7 | 'dataset.test_splits=[QaEgo4D_test_close]' \ 8 | dataset.batch_size=32 \ 9 | +trainer.test_only=True \ 10 | '+trainer.checkpoint_path="checkpoints/GroundVQA_B-QaEgo4D_EgoTimeQA-COV-test_ROUGE=30.4.ckpt"' \ 11 | +trainer.random_seed=$SEED 12 | done 13 | 14 | # QAEgo4D-Open test set 15 | python run.py \ 16 | model=groundvqa_b \ 17 | 'dataset.qa_train_splits=[QaEgo4D_train]' \ 18 | 'dataset.test_splits=[QaEgo4D_test]' \ 19 | dataset.batch_size=32 \ 20 | +trainer.test_only=True \ 21 | '+trainer.checkpoint_path="checkpoints/GroundVQA_B-QaEgo4D_EgoTimeQA-COV-test_ROUGE=30.4.ckpt"' 22 | 23 | # NLQv2 val set 24 | python run.py \ 25 | model=groundvqa_b \ 26 | 'dataset.qa_train_splits=[QaEgo4D_train]' \ 27 | 'dataset.test_splits=[NLQ_val]' \ 28 | dataset.batch_size=32 \ 29 | +trainer.test_only=True \ 30 | '+trainer.checkpoint_path="checkpoints/GroundVQA_B-QaEgo4D_EgoTimeQA-COV-val_R1_03=25.6.ckpt"' \ 31 | trainer.load_nlq_head=True -------------------------------------------------------------------------------- /scripts/evaluate_groundvqa_s-qaego4d_egotimeqa.sh: -------------------------------------------------------------------------------- 1 | # QAEgo4D-Close test set 2 | for SEED in 0 1111 2222 3333 4444 3 | do 4 | python run.py \ 5 | model=groundvqa_s \ 6 | 'dataset.qa_train_splits=[QaEgo4D_train]' \ 7 | 'dataset.test_splits=[QaEgo4D_test_close]' \ 8 | dataset.batch_size=64 \ 9 | +trainer.test_only=True \ 10 | '+trainer.checkpoint_path="checkpoints/GroundVQA_S-QaEgo4D_EgoTimeQA-COV-test_ROUGE=30.2.ckpt"' \ 11 | +trainer.random_seed=$SEED 12 | done 13 | 14 | # QAEgo4D-Open test set 15 | python run.py \ 16 | model=groundvqa_s \ 17 | 'dataset.qa_train_splits=[QaEgo4D_train]' \ 18 | 'dataset.test_splits=[QaEgo4D_test]' \ 19 | dataset.batch_size=64 \ 20 | +trainer.test_only=True \ 21 | 
'+trainer.checkpoint_path="checkpoints/GroundVQA_S-QaEgo4D_EgoTimeQA-COV-test_ROUGE=30.2.ckpt"' 22 | 23 | # NLQv2 val set 24 | python run.py \ 25 | model=groundvqa_s \ 26 | 'dataset.qa_train_splits=[QaEgo4D_train]' \ 27 | 'dataset.test_splits=[NLQ_val]' \ 28 | dataset.batch_size=64 \ 29 | +trainer.test_only=True \ 30 | '+trainer.checkpoint_path="checkpoints/GroundVQA_S-QaEgo4D_EgoTimeQA-COV-val_R1_03=23.3.ckpt"' \ 31 | trainer.load_nlq_head=True -------------------------------------------------------------------------------- /model/ours/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import PreTrainedModel, AutoModelForSeq2SeqLM 4 | from transformers.modeling_outputs import BaseModelOutput 5 | 6 | from model.ours.nlq_head import NLQHead 7 | 8 | 9 | class GroundVQA(nn.Module): 10 | def __init__(self, lm_path, input_dim, freeze_word=False, max_v_len=256): 11 | super().__init__() 12 | 13 | if not isinstance(input_dim, int): 14 | input_dim = input_dim.v_dim 15 | 16 | self.lm: PreTrainedModel = AutoModelForSeq2SeqLM.from_pretrained(lm_path, local_files_only=True) 17 | 18 | lm_dim = self.lm.get_input_embeddings().embedding_dim 19 | self.lm_proj = nn.Linear(input_dim, lm_dim) 20 | self.v_emb = nn.Parameter(torch.randn((1, 1, lm_dim))) 21 | if freeze_word: 22 | for name, param in self.lm.named_parameters(): 23 | if 'shared' in name: 24 | param.requires_grad = False 25 | 26 | self.nlq_head = NLQHead(in_dim=lm_dim, max_v_len=max_v_len) 27 | 28 | def forward(self, v_feat, v_mask, q_token, q_mask, gt_segments, gt_labels, 29 | labels=None, **remains): 30 | # encoder 31 | encoder_out, mask = self.forward_encoder(v_feat, v_mask, q_token, q_mask) 32 | 33 | # localizer 34 | encoder_out_v = encoder_out[:, -v_feat.shape[1]:] 35 | nlq_results = self.nlq_head( 36 | feat=encoder_out_v.permute(0, 2, 1), # (B, D, T) 37 | mask=v_mask.unsqueeze(1), # (B, 1, T) 38 | gt_segments=gt_segments, 39 | gt_labels=gt_labels 40 | ) 41 | time_loss = nlq_results['final_loss'] * 1.0 42 | 43 | # decoder 44 | outputs = self.lm( 45 | encoder_outputs=(encoder_out,), 46 | attention_mask=mask, 47 | labels=labels, 48 | ) 49 | lm_loss = outputs.loss 50 | 51 | total_loss = 0.5 * time_loss + 0.5 * lm_loss 52 | 53 | return total_loss, lm_loss, time_loss 54 | 55 | def generate(self, v_feat, v_mask, q_token, q_mask, v_len, **remains): 56 | encoder_out, mask = self.forward_encoder(v_feat, v_mask, q_token, q_mask) 57 | encoder_out_v = encoder_out[:, -v_feat.shape[1]:] 58 | 59 | nlq_results = self.nlq_head( 60 | feat=encoder_out_v.permute(0, 2, 1), # (B, D, T) 61 | mask=v_mask.unsqueeze(1), # (B, 1, T) 62 | training=False, 63 | v_lens=v_len 64 | ) 65 | answer_tokens = self.lm.generate( 66 | encoder_outputs=BaseModelOutput(last_hidden_state=encoder_out), 67 | attention_mask=mask, 68 | max_new_tokens=32 69 | ) 70 | 71 | return nlq_results, answer_tokens 72 | 73 | def forward_encoder(self, v_feat, v_mask, q_token, q_mask): 74 | B, L, D = v_feat.shape 75 | v_feat = self.lm_proj(v_feat) 76 | v_feat = v_feat + self.v_emb.expand((B, L, -1)) 77 | q_feat = self.lm.encoder.embed_tokens(q_token) 78 | lm_input = torch.cat([q_feat, v_feat], dim=1) 79 | lm_mask = torch.cat([q_mask, v_mask], dim=1) 80 | out = self.lm.encoder( 81 | inputs_embeds=lm_input, 82 | attention_mask=lm_mask 83 | ) 84 | return out.last_hidden_state, lm_mask 85 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Custom 2 | data/ 3 | lightning_logs 4 | utils/generate_open_qa/tmp 5 | utils/generate_close_qa/tmp 6 | *.pkl 7 | *.json 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | -------------------------------------------------------------------------------- /utils/generate_close_qa/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from ast import literal_eval 3 | 4 | from transformers import AutoTokenizer, AutoModelForCausalLM 5 | import transformers 6 | import torch 7 | from tqdm import tqdm 8 | import json 9 | 10 | 11 | token = '' # your access token to Llama2 12 | 13 | model_id = 'meta-llama/Llama-2-13b-chat-hf' 14 | batch_size = 4 15 | 16 | # fp16 17 | tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=token) 18 | pipeline = transformers.pipeline( 19 | "text-generation", 20 | model=model_id, 21 | torch_dtype=torch.float16, 22 | device_map="auto", 23 | tokenizer=tokenizer, 24 | token=token 25 | ) 26 | # fix a bug: https://discuss.huggingface.co/t/llama2-pad-token-for-batched-inference/48020 27 | pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token 28 | pipeline.model.config.pad_token_id = pipeline.model.config.eos_token_id 29 | 30 | 31 | class QADataset(torch.utils.data.Dataset): 32 | def __init__(self, qa_annotations, start, end): 33 | self.qa_annotations = qa_annotations 34 | if start is not None and end is not None: 35 | self.qa_annotations = self.qa_annotations[start:end] 36 | 37 | print('preparing LM prompts...') 38 | self.lm_inputs = [] 39 | for data in tqdm(self.qa_annotations): 40 | self.lm_inputs.append(self._get_lm_input(data['question'], data['answer'])) 41 | 42 | def _get_lm_input(self, question, answer): 43 | return f"""[INST] <> 44 | I'll provide a question and its correct answer. Generate three plausible, but incorrect, answers that closely resemble the correct one Make it challenging to identify the right answer. 45 | <> 46 | 47 | No preamble, get right to the three wrong answers and present them in a list format. Question: How many frying pans can i see on the shelf? Correct Answer: two pieces. Wrong Answers: [/INST] [\"one piece\", \"three piece\", \"five pieces\"] 48 | [INST] No preamble, get right to the three wrong answers and present them in a list format. Question: What colour bowl did i carry from the plate stand? Correct Answer: green. Wrong Answers: [/INST] [\"blue\", \"black\", \"white\"] 49 | [INST] No preamble, get right to the three wrong answers and present them in a list format. 
Question: What did i pour in the bowl? Correct Answer: boiling water. Wrong Answers: [/INST] [\"hot oil\", \"steamed milk\", \"warm broth\"] 50 | [INST] No preamble, get right to the three wrong answers and present them in a list format. Question: {question} Correct Answer: {answer}. Wrong Answers: [/INST] 51 | 52 | """ 53 | 54 | def __len__(self): 55 | return len(self.qa_annotations) 56 | 57 | def __getitem__(self, idx): 58 | return self.lm_inputs[idx] 59 | 60 | def get_data(self, idx): 61 | return self.qa_annotations[idx] 62 | 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--start', type=int, default=None) 67 | parser.add_argument('--end', type=int, default=None) 68 | args = parser.parse_args() 69 | 70 | with open('annotations.EgoTimeQA.json', 'r') as f: 71 | qa_annotations = json.load(f) 72 | 73 | dataset = QADataset(qa_annotations, args.start, args.end) 74 | 75 | errors = 0 76 | pbar = tqdm(total=len(dataset)) 77 | res = [] 78 | for idx, out in enumerate(pipeline( 79 | dataset, 80 | batch_size=64, 81 | do_sample=True, 82 | temperature=0.5, 83 | top_k=10, 84 | num_return_sequences=1, 85 | eos_token_id=tokenizer.eos_token_id, 86 | max_new_tokens=64, 87 | return_full_text=False, 88 | )): 89 | pbar.set_description('Errors: %d' % errors) 90 | pbar.update(1) 91 | gen_result = out[0]['generated_text'] 92 | 93 | try: # may not generate the desired format 94 | wrong_answers = literal_eval(gen_result) 95 | assert isinstance(wrong_answers, list) and len(wrong_answers) == 3 96 | data = dataset.get_data(idx) 97 | data['wrong_answers'] = wrong_answers 98 | res.append(data) 99 | success = True 100 | except: 101 | errors += 1 102 | print(gen_result) 103 | 104 | print(f'#{len(res)} / {len(dataset)} samples generated!') 105 | 106 | if args.start is not None and args.end is not None: 107 | with open(f'tmp/annotations.EgoTimeQA_{args.start}_{args.end}.json', 'w') as f: 108 | json.dump(res, f) 109 | else: 110 | with open('annotations.EgoTimeQA.json', 'w') as f: 111 | json.dump(res, f) 112 | -------------------------------------------------------------------------------- /utils/generate_open_qa/generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from ast import literal_eval 4 | 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | import transformers 7 | import torch 8 | import numpy as np 9 | from tqdm import tqdm 10 | import json 11 | import pickle 12 | 13 | 14 | token = '' # your access token to Llama2 15 | 16 | model_id = 'meta-llama/Llama-2-13b-chat-hf' 17 | model_id = "/root/.cache/huggingface/models--meta-llama--Llama-2-13b-chat-hf" 18 | batch_size = 4 19 | 20 | # fp16 21 | tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=token) 22 | pipeline = transformers.pipeline( 23 | "text-generation", 24 | model=model_id, 25 | torch_dtype=torch.float16, 26 | device_map="auto", 27 | tokenizer=tokenizer, 28 | token=token 29 | ) 30 | # fix a bug: https://discuss.huggingface.co/t/llama2-pad-token-for-batched-inference/48020 31 | pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token 32 | pipeline.model.config.pad_token_id = pipeline.model.config.eos_token_id 33 | 34 | 35 | class NarrationDataset(torch.utils.data.Dataset): 36 | def __init__(self, clip_narrations): 37 | self.prompt = """[INST] <> 38 | You are an AI Assistant and always write the output of your response in JSON. 
I will provide you with a series of narrations that depict my behavior. You should generate one QA pair based on the narrations in the format of {\"Q\": , \"A\": }. In the narrations, \"C\" represents me, and \"O\" represents someone else. Use as much information as possible from narrations to generate the question, and the question you generate should be able to be answered using the information provided in the narrations. The question should be in the past tense. The question should be within 10 words, and the answer should be within 5 words. 39 | <> 40 | 41 | C pours hot water from the frying pan in his left hand into the bowl in his right hand. [/INST] {\"Q\": \"What did I pour in the bowl?\", \"A\": \"boiling water\"} 42 | [INST] C searches through the cabinet. C closes the cabinet. C picks the tin from the cabinet. C places the tin on the counter. [/INST] {\"Q\": \"Where was the tin before I took it?\", \"A\": \"at the cabinet\"} 43 | [INST] C turns on sink knob. C washes the cucumber on the sink. C turns off sink knob. [/INST] {\"Q\": \"Did I wash the cucumber?\", \"A\": \"yes\"} 44 | """ 45 | self.narrations = self._prepare(clip_narrations) 46 | 47 | def _prepare(self, clip_narrations): 48 | sampled_narrations = [] 49 | for c in clip_narrations: 50 | clip_uid = c['clip_uid'] 51 | narration_pass = c['narration_pass'] 52 | narrations = c['narrations'] 53 | 54 | idx = 0 55 | while idx < len(narrations): 56 | sampled = self._sample_narrations(narrations, start=idx) 57 | start_sec = sampled[0]['timestamps'][0] 58 | end_sec = sampled[-1]['timestamps'][1] 59 | narration_texts = ' '.join([n['narration_text'] for n in sampled]) 60 | lm_input = self.prompt + "[INST] " + narration_texts + "[/INST]\n" 61 | sampled_narrations.append({ 62 | 'clip_uid': clip_uid, 63 | 'narration_pass': narration_pass, 64 | 'narrations': narration_texts, 65 | 'start_sec': start_sec, 66 | 'end_sec': end_sec, 67 | 'lm_input': lm_input, 68 | 'n_narration': len(sampled) 69 | }) 70 | idx += len(sampled) 71 | 72 | return sampled_narrations 73 | 74 | def _sample_narrations(self, narrations, start, max_n=5, max_timespan=30): 75 | end = min(len(narrations), start+max_n) 76 | while start < end - 1 and max_timespan <= narrations[end-1]['timestamps'][1] - narrations[start]['timestamps'][0]: 77 | end -= 1 78 | end = np.random.choice(np.arange(start, end), 1)[0] + 1 79 | return narrations[start:end] 80 | 81 | def __len__(self): 82 | return len(self.narrations) 83 | 84 | def __getitem__(self, idx): 85 | return self.narrations[idx]['lm_input'] 86 | 87 | def get_clip_uid(self, idx): 88 | return self.narrations[idx]['clip_uid'] 89 | 90 | def get_start_sec(self, idx): 91 | return self.narrations[idx]['start_sec'] 92 | 93 | def get_end_sec(self, idx): 94 | return self.narrations[idx]['end_sec'] 95 | 96 | def get_narrations(self, idx): 97 | return self.narrations[idx]['narrations'] 98 | 99 | 100 | if __name__ == "__main__": 101 | parser = argparse.ArgumentParser() 102 | parser.add_argument('-start', type=int, default=None) 103 | parser.add_argument('-end', type=int, default=None) 104 | args = parser.parse_args() 105 | 106 | with open('em_train_narrations.pkl', 'rb') as f: 107 | clip_narrations = pickle.load(f) 108 | print(len(clip_narrations)) 109 | os.makedirs('tmp', exist_ok=True) 110 | 111 | if args.start is not None and args.end is not None: 112 | clip_narrations = clip_narrations[args.start:args.end] 113 | save_path = f'tmp/annotations.EgoTimeQA_{args.start}_{args.end}.json' 114 | else: 115 | save_path = 
'annotations.EgoTimeQA.json' 116 | 117 | dataset = NarrationDataset(clip_narrations) 118 | 119 | errors = 0 120 | pbar = tqdm(total=len(dataset)) 121 | res = [] 122 | for idx, out in enumerate(pipeline( 123 | dataset, 124 | batch_size=32, 125 | do_sample=True, 126 | temperature=0.5, 127 | top_k=10, 128 | num_return_sequences=1, 129 | eos_token_id=tokenizer.eos_token_id, 130 | max_new_tokens=64, 131 | return_full_text=False, 132 | )): 133 | pbar.set_description(f'Errors: {errors}') 134 | pbar.update(1) 135 | gen_result = out[0]['generated_text'] 136 | try: # may not generate in JSON format 137 | qa = literal_eval(gen_result) 138 | question = qa['Q'] 139 | answer = qa['A'] 140 | start_sec = dataset.get_start_sec(idx) 141 | end_sec = dataset.get_end_sec(idx) 142 | start_frame = int(start_sec * 30) 143 | end_frame = int(end_sec * 30) 144 | res.append({ 145 | 'video_id': dataset.get_clip_uid(idx), 146 | 'sample_id': None, 147 | 'answer': answer, 148 | 'question': question, 149 | 'start_sec': start_sec, 150 | 'end_sec': end_sec, 151 | "moment_start_frame": start_frame, 152 | "moment_end_frame": end_frame, 153 | 'narrations': dataset.get_narrations(idx) 154 | }) 155 | except: 156 | errors += 1 157 | continue 158 | 159 | with open(save_path, 'w') as f: 160 | json.dump(res, f) 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GroundVQA 2 | 3 | Official PyTorch code of "Grounded Question-Answering in Long Egocentric Videos", *CVPR* 2024. 4 | 5 | [[Project page]](https://dszdsz.cn/GroundVQA/index.html) [[Paper]](https://arxiv.org/abs/2312.06505) 6 | 7 | ## News 8 | 9 | [Apr 2024] We have updated the CloseQA benchmarking results, now available on [arXiv [v4]](https://arxiv.org/pdf/2312.06505). 10 | 11 | [Feb 2024] We update the CloseQA test set with rigorous human verification. The benchmark results will be updated in our paper shortly. 12 | 13 | ## Abstract 14 | 15 | Existing approaches to video understanding, mainly designed for short videos from a third-person perspective, are limited in their applicability in certain fields, such as robotics. In this paper, we delve into **open-ended question-answering (QA) in long, egocentric videos**, which allows individuals or robots to inquire about their own past visual experiences. 16 | 17 | This task presents **unique challenges**, including the complexity of temporally grounding queries within extensive video content, the high resource demands for precise data annotation, and the inherent difficulty of evaluating open-ended answers due to their ambiguous nature. 18 | 19 | Our proposed approach tackles these challenges by 20 | 21 | - **GroundVQA**: integrating query grounding and answering within a unified model to reduce error propagation; 22 | - **EgoTimeQA**: employing large language models for efficient and scalable data synthesis; 23 | - **QaEgo4D**$`_\texttt{close}`$: introducing a close-ended QA task for evaluation, to manage answer ambiguity. 24 | 25 | Extensive experiments demonstrate the effectiveness of our method, which also achieves state-of-the-art performance on the QaEgo4D and Ego4D-NLQ benchmarks. 26 | 27 | ## Directory Structure 28 | 29 | ``` 30 | . 
31 | |-- checkpoints provided model checkpoints 32 | |-- config configs of models and datasets 33 | |-- data processed dataset and video features 34 | |-- eval.py code for evaluating QaEgo4D performance 35 | |-- eval_nlq.py code for evaluating NLQ performance 36 | |-- model code for model, dataset, and training 37 | |-- requirements.txt list of packages for building the Python environment 38 | |-- run.py entry code 39 | |-- scripts scripts for training and evaluation 40 | `-- utils code for generating OpenQA and CloseQA data from Ego4D narrations 41 | ``` 42 | 43 | ## Preparation 44 | 45 | Our setup: Ubuntu 20.04, CUDA 12.2, 8x Nvidia A100 (80GB) 46 | 47 | - Clone this repo: `https://github.com/Becomebright/GroundVQA.git` 48 | - Create the conda environment: `conda create -n groundvqa python=3.9 -y && conda activate groundvqa` 49 | - Install packages: `pip install -r requirements.txt` 50 | - Compile `nms_1d_cpu` following [here](https://github.com/happyharrycn/actionformer_release/blob/main/INSTALL.md) 51 | - Download the data, video feature, and model checkpoints from [Huggingface](https://huggingface.co/Becomebright/GroundVQA) 52 | - **data:** unzip `data.zip` under the project's root directory. 53 | - **video feature:** merge the files `cat egovlp_internvideoa* > egovlp_internvideo.hdf5` and put it under `data/unified/` 54 | - **model checkpoints**: put them under `checkpoints/` 55 | 56 | | Model | Data | Task | NLQ$`_\texttt{v2}`$ | QaEgo4D | Cost$`^{*}`$ | 57 | | ------------------------------- | ------------------------------------------------------------ | ---------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ----------------- | 58 | | $`\text{GroundVQA}_\texttt{S}`$ | QaEgo4D | CloseQA+OpenQA+VLG | [[val_R1_03=11.0]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_S-QaEgo4D-COV-val_R1_03%3D11.0.ckpt) | [[test_ROUGE=29.0]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_S-QaEgo4D-COV-test_ROUGE%3D29.0.ckpt) | 7 | 59 | | $`\text{GroundVQA}_\texttt{S}`$ | QaEgo4D+EgoTimeQA | CloseQA+OpenQA+VLG | [[val_R1_03=23.3]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_S-QaEgo4D_EgoTimeQA-COV-val_R1_03%3D23.3.ckpt) | [[test_ROUGE=30.2]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_S-QaEgo4D_EgoTimeQA-COV-test_ROUGE%3D30.2.ckpt) | 150 | 60 | | $`\text{GroundVQA}_\texttt{B}`$ | QaEgo4D+EgoTimeQA | CloseQA+OpenQA+VLG | [[val_R1_03=25.6]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_B-QaEgo4D_EgoTimeQA-COV-val_R1_03%3D25.6.ckpt) | [[test_ROUGE=30.4]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_B-QaEgo4D_EgoTimeQA-COV-test_ROUGE%3D30.4.ckpt) | 350 | 61 | | $`\text{GroundVQA}_\texttt{B}`$ | NLQ$`_\texttt{v2}`$+NaQ $\rightarrow$ NLQ$`_\texttt{v2}`$$`^{**}`$ | VLG | [[val_R1_03=29.7]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_B-NLQ_NaQ-finetune_NLQ-VLG-val_R1_03%3D29.7.ckpt) | - | 700 | 62 | 63 | \* The training costs counted by GPU hours. 64 | 65 | ** Pre-trained on NLQ$`_\texttt{v2}`$ and NaQ, and further fine-tuned on NLQ$`_\texttt{v2}`$​. 
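
After downloading, it can help to sanity-check the data before launching training or evaluation. The snippet below is only an illustrative sketch: it assumes `data.zip` was unpacked into the `data/unified/` layout used by `config/dataset/egovlp_internvideo.yaml`, and that the merged feature file sits at `data/unified/egovlp_internvideo.hdf5` with a feature dimension of 2304 (`dataset.feature_dim`).

```python
# Sketch: quick sanity check of the downloaded data (paths follow the layout described above).
import json
import h5py

with h5py.File('data/unified/egovlp_internvideo.hdf5', 'r') as f:
    clip_ids = list(f.keys())              # one entry per clip/video id
    first = f[clip_ids[0]]
    # Each entry should be a (num_frames, 2304) feature array.
    print(f'{len(clip_ids)} clips, first clip feature shape: {first.shape}')

# QA/NLQ annotations live next to the features as annotations.<split>.json.
anns = json.load(open('data/unified/annotations.QaEgo4D_train.json'))
print(f'{len(anns)} QaEgo4D_train samples, e.g. keys: {sorted(anns[0].keys())}')
```
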
66 | 67 | ## Training 68 | 69 | ```bash 70 | # train GroundVQA_S on QaEgo4D 71 | bash scripts/train_groundvqa_s-qaego4d.sh 72 | 73 | # train GroundVQA_S on QaEgo4D and EgoTimeQA 74 | bash scripts/train_groundvqa_s-qaego4d_egotimeqa.sh 75 | 76 | # train GroundVQA_B on QaEgo4D and EgoTimeQA 77 | bash scripts/train_groundvqa_b-qaego4d_egotimeqa.sh 78 | ``` 79 | 80 | ## Evaluation 81 | 82 | ```bash 83 | # evaluate GroundVQA_S trained on QaEgo4D 84 | bash scripts/evaluate_groundvqa_s-qaego4d.sh 85 | 86 | # evaluate GroundVQA_S trained on QaEgo4D and EgoTimeQA 87 | bash scripts/evaluate_groundvqa_s-qaego4d_egotimeqa.sh 88 | 89 | # evaluate GroundVQA_B trained on QaEgo4D and EgoTimeQA 90 | bash scripts/evaluate_groundvqa_b-qaego4d_egotimeqa.sh 91 | 92 | # evaluate GroundVQA_B pre-trained on NLQv2 and NaQ and further fine-tuned on NLQv2 93 | bash scripts/evaluate_groundvqa_b-nlq_naq.sh 94 | ``` 95 | 96 | ## Generate OpenQA data 97 | 98 | Download the processed Ego4D narrations [[em_train_narrations.pkl]](https://huggingface.co/Becomebright/GroundVQA/blob/main/em_train_narrations.pkl) 99 | 100 | Put it under `utils/generate_open_qa/` 101 | 102 | Generate QAs in parallel on multiple GPUs (*e.g.*, 2) 103 | 104 | ```bash 105 | cd utils/generate_open_qa 106 | 107 | # GPU-0 108 | CUDA_VISIBLE_DEVICES=0 python generate.py -start 0 -end 5000 109 | 110 | # GPU-1 111 | CUDA_VISIBLE_DEVICES=1 python generate.py -start 5000 -end 11000 # 10777 clips in total 112 | ``` 113 | 114 | Merge the results and normalize the duration of the temporal windows (`merge.py` keeps each window's center fixed and divides its duration by the mean duration over all generated QAs) 115 | 116 | ```bash 117 | python merge.py 118 | ``` 119 | 120 | ## Generate CloseQA data 121 | 122 | ```bash 123 | cd utils/generate_close_qa 124 | python generate.py 125 | ``` 126 | 127 | The above script produces three wrong answers for each EgoTimeQA sample using a single GPU. 128 | 129 | You can also run the generation in parallel on multiple GPUs (via its `--start`/`--end` arguments) or generate wrong answers for QaEgo4D. 130 | 131 | ## Citation 132 | 133 | ```latex 134 | @inproceedings{di2023groundvqa, 135 | title={Grounded Question-Answering in Long Egocentric Videos}, 136 | author={Di, Shangzhe and Xie, Weidi}, 137 | booktitle={CVPR}, 138 | year={2024} 139 | } 140 | ``` 141 | 142 | ## Acknowledgements 143 | 144 | Our code is based on [QaEgo4D](https://github.com/lbaermann/qaego4d), [GroundNLQ](https://github.com/houzhijian/GroundNLQ), and [ActionFormer](https://github.com/happyharrycn/actionformer_release).
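
The open-ended answer metrics reported by the evaluation scripts (accuracy, ROUGE, METEOR, sentence similarity) are implemented in `eval.py` via `calc_metrics`, which can also be called directly. Below is a minimal sketch with made-up toy inputs, not part of the official pipeline; METEOR additionally requires the NLTK corpora to be available.

```python
# Sketch: scoring a few predictions with the metrics from eval.py (toy inputs).
from eval import calc_metrics

predictions = ['in the drawer', 'two']
# One list of reference answers per prediction; all lists must have the same length.
references = [['inside the drawer'], ['two pieces']]

scores = calc_metrics(predictions, references)   # plain_acc, ROUGE, METEOR, ...
print(scores['ROUGE'], scores['METEOR'])
# Pass test=True to additionally compute the SentenceTransformer similarity.
```
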
145 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | from argparse import ArgumentParser, Namespace 4 | 5 | import hydra 6 | import torch 7 | import pytorch_lightning as pl 8 | from omegaconf import DictConfig, open_dict 9 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor, StochasticWeightAveraging 10 | from pytorch_lightning.plugins import DDPPlugin 11 | 12 | from model.ours.dataset import JointDataModule 13 | from model.ours.lightning_module import LightningModule 14 | 15 | 16 | def dict_parser(s: str): 17 | return eval('{' + re.sub(r'(\w+)=(["\']?\w+["\']?)', r'"\1":\2', s) + '}') 18 | 19 | def add_common_trainer_util_args(parser, default_monitor_variable='val_loss', default_monitor_mode='min'): 20 | if default_monitor_mode not in ['min', 'max']: 21 | raise ValueError(default_monitor_mode) 22 | parser.add_argument('--lr_find_kwargs', default=dict(min_lr=5e-6, max_lr=1e-2), type=dict_parser, 23 | help='Arguments for LR find (--auto_lr_find). Default "min_lr=5e-6,max_lr=1e-2"') 24 | parser.add_argument('--random_seed', default=42, type=lambda s: None if s == 'None' else int(s), 25 | help='Seed everything. Set to "None" to disable global seeding') 26 | parser.add_argument('--auto_resume', default=False, action='store_true', 27 | help='Automatically resume last saved checkpoint, if available.') 28 | parser.add_argument('--test_only', default=False, action='store_true', 29 | help='Skip fit and call only test. This implies automatically detecting newest checkpoint, ' 30 | 'if --checkpoint_path is not given.') 31 | parser.add_argument('--checkpoint_path', default=None, type=str, 32 | help='Load this checkpoint to resume training or run testing. ' 33 | 'Pass in the special value "best" to use the best checkpoint according to ' 34 | 'args.monitor_variable and args.monitor_mode. ' 35 | 'Using "best" only works with test_only mode.') 36 | parser.add_argument('--ignore_existing_checkpoints', default=False, action='store_true', 37 | help='Proceed even with training a new model, even if previous checkpoints exists.') 38 | parser.add_argument('--monitor_variable', default=default_monitor_variable, type=str, 39 | help='Variable to monitor for early stopping and for checkpoint selection. ' 40 | f'Default: {default_monitor_variable}') 41 | parser.add_argument('--monitor_mode', default=default_monitor_mode, type=str, choices=['min', 'max'], 42 | help='Mode for monitoring the monitor_variable (for early stopping and checkpoint selection). ' 43 | f'Default: {default_monitor_mode}') 44 | parser.add_argument('--reset_early_stopping_criterion', default=False, action='store_true', 45 | help='Reset the early stopping criterion when loading from checkpoint. ' 46 | 'Prevents immediate exit after switching to more complex dataset in curriculum strategy') 47 | 48 | def apply_argparse_defaults_to_hydra_config(config: DictConfig, parser: ArgumentParser, verbose=False): 49 | args = parser.parse_args([]) # Parser is not allowed to have required args, otherwise this will fail! 
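    # The parser above never parses real command-line arguments: parse_args([]) only collects
    # its default values, and _apply_defaults() below merges them into the Hydra config.
    # This is why the provided scripts pass options such as +trainer.test_only=True or
    # +trainer.checkpoint_path=... as Hydra overrides rather than argparse flags.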
50 | defaults = vars(args) 51 | 52 | def _apply_defaults(dest: DictConfig, source: dict, indentation=''): 53 | for k, v in source.items(): 54 | if k in dest and isinstance(v, dict): 55 | current_value = dest[k] 56 | if current_value is not None: 57 | assert isinstance(current_value, DictConfig) 58 | _apply_defaults(current_value, v, indentation + ' ') 59 | elif k not in dest: 60 | dest[k] = v 61 | if verbose: 62 | print(indentation, 'set default value for', k) 63 | 64 | with open_dict(config): 65 | _apply_defaults(config, defaults) 66 | 67 | 68 | def _adjust_ddp_config(trainer_cfg): 69 | trainer_cfg = dict(trainer_cfg) 70 | strategy = trainer_cfg.get('strategy', None) 71 | if trainer_cfg['gpus'] > 1 and strategy is None: 72 | strategy = 'ddp' # Select ddp by default 73 | if strategy == 'ddp': 74 | trainer_cfg['strategy'] = DDPPlugin( 75 | find_unused_parameters=trainer_cfg['find_unused_parameters'], 76 | gradient_as_bucket_view=True) 77 | return trainer_cfg 78 | 79 | 80 | @hydra.main(config_path='config', config_name='base') 81 | def train(config: DictConfig): 82 | fake_parser = ArgumentParser() 83 | add_common_trainer_util_args(fake_parser, default_monitor_variable='val_loss') 84 | apply_argparse_defaults_to_hydra_config(config.trainer, fake_parser) 85 | pl.seed_everything(config.trainer.random_seed, workers=True) 86 | trainer_cfg = Namespace(**_adjust_ddp_config(config.trainer)) 87 | 88 | data = JointDataModule(config.dataset) 89 | data.setup() 90 | 91 | total_steps = trainer_cfg.max_epochs * math.floor(len(data.train_dataset) / trainer_cfg.gpus / config.dataset.batch_size) 92 | model = LightningModule(config, total_steps) 93 | if trainer_cfg.checkpoint_path: 94 | state_dict = torch.load(trainer_cfg.checkpoint_path, map_location='cpu')['state_dict'] 95 | if not trainer_cfg.load_nlq_head: 96 | print('Train NLQ head from scratch') 97 | state_dict = {k: v for k, v in state_dict.items() if not "nlq_head" in k} 98 | if not trainer_cfg.load_decoder: 99 | print('Train LM decoder head from scratch') 100 | state_dict = {k: v for k, v in state_dict.items() if not ("decoder" in k or "lm_head" in k)} 101 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 102 | print(f'Load checkpoint: {trainer_cfg.checkpoint_path}') 103 | print(f'Missing Keys: {missing_keys}') 104 | print(f'Unexpected Keys: {unexpected_keys}') 105 | 106 | 107 | if trainer_cfg.test_only: # evaluation 108 | trainer = pl.Trainer.from_argparse_args( 109 | trainer_cfg, 110 | enable_checkpointing=False, 111 | logger=False 112 | ) 113 | if trainer_cfg.val: 114 | trainer.validate( 115 | model, data.val_dataloader(), 116 | ) 117 | else: 118 | trainer.test( 119 | model, data.test_dataloader(), 120 | ) 121 | else: # training 122 | model_checkpoint = [] 123 | if 'QaEgo4D_test' in config.dataset.test_splits: 124 | model_checkpoint.append( 125 | ModelCheckpoint( 126 | save_last=False, 127 | monitor='val_ROUGE', 128 | mode='max', 129 | save_top_k=1, 130 | filename='{step}-{' + 'val_ROUGE' + ':.3f}') 131 | ) 132 | if 'QaEgo4D_test_close' in config.dataset.test_splits: 133 | model_checkpoint.append( 134 | ModelCheckpoint( 135 | save_last=False, 136 | monitor='val_close_acc', 137 | mode='max', 138 | save_top_k=1, 139 | filename='{step}-{' + 'val_close_acc' + ':.3f}') 140 | ) 141 | if 'NLQ_val' in config.dataset.test_splits: 142 | model_checkpoint.append( 143 | ModelCheckpoint( 144 | save_last=False, 145 | monitor='val_R1_03', 146 | mode='max', 147 | save_top_k=1, 148 | filename='{step}-{' + 'val_R1_03' + ':.3f}') 149 | 
) 150 | trainer = pl.Trainer.from_argparse_args(trainer_cfg, callbacks=[ 151 | LearningRateMonitor(logging_interval='step'), 152 | # StochasticWeightAveraging(swa_lrs=1e-2), 153 | *model_checkpoint 154 | ]) 155 | trainer.fit( 156 | model, data.train_dataloader(), data.val_dataloader(), 157 | ) 158 | 159 | 160 | if __name__ == '__main__': 161 | train() 162 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | from collections import defaultdict 4 | from pathlib import Path 5 | from pprint import pprint 6 | from typing import List, Dict, Any 7 | 8 | # import bert_score 9 | import sentence_transformers 10 | from nltk.translate.meteor_score import meteor_score 11 | from rouge_score.rouge_scorer import RougeScorer 12 | from rouge_score.tokenize import tokenize 13 | # from sacrebleu.metrics import BLEU, BLEUScore 14 | from torchmetrics.functional import sacre_bleu_score 15 | from nltk.tokenize import word_tokenize 16 | from nltk.corpus import wordnet 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | 22 | def __init__(self): 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | 37 | 38 | # Check whether to use 39 | # - https://github.com/Maluuba/nlg-eval 40 | # - https://github.com/hwanheelee1993/KPQA 41 | def calc_metrics(predictions: List[str], gold_annotations: List[List[str]], test=False) -> Dict[str, Any]: 42 | """ 43 | Calculate metrics. 44 | 45 | Parameters 46 | ---------- 47 | predictions : list[str] 48 | The list of predictions 49 | gold_annotations : list[list[str]] 50 | A list with the same length as predictions. 51 | Each element is a list of possible target candidates for the corresponding prediction. 52 | All elements should have the same length. 53 | """ 54 | if len(predictions) != len(gold_annotations): 55 | raise ValueError(f'{len(predictions)} != {len(gold_annotations)}') 56 | ref_count = len(gold_annotations[0]) 57 | if any(len(refs) != ref_count for refs in gold_annotations): 58 | raise ValueError(f'All refs should have the same length {ref_count}!') 59 | 60 | acc = _calc_accuracy(predictions, gold_annotations) 61 | # bleu = _calc_bleu(predictions, gold_annotations) 62 | rouge = _calc_rouge(predictions, gold_annotations) 63 | meteor = _calc_meteor(predictions, gold_annotations) 64 | # bert_score = _calc_bertscore(predictions, gold_annotations) 65 | # wups = _calc_wups(predictions, gold_annotations) 66 | if test: 67 | sts = SentenceTransformerSimilarity() 68 | sts_score = sts.calc_st_similarity(predictions, gold_annotations) 69 | 70 | return { 71 | 'plain_acc': acc, 72 | # **bleu, 73 | 'ROUGE': rouge['rougeL']['f'], 74 | **_flatten_dict(rouge, prefix='ROUGE.'), 75 | 'METEOR': meteor, 76 | 'SentenceSimilarity': sts_score if test else 0. 
77 | # 'BERTSCORE': bert_score, 78 | # 'WUPS': wups 79 | } 80 | 81 | 82 | """ Sentence Transformer """ 83 | class SentenceTransformerSimilarity: 84 | def __init__(self): 85 | self.model = sentence_transformers.SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') 86 | 87 | def _calc_similarity(self, pred, gts): 88 | pred_emb = self.model.encode(pred) 89 | gts_emb = self.model.encode(gts) 90 | score = sentence_transformers.util.dot_score(pred_emb, gts_emb)[0,0].cpu() 91 | return float(score) 92 | 93 | def calc_st_similarity(self, predictions, gold_annotations): 94 | total_score = 0. 95 | for pred, gts in zip(predictions, gold_annotations): 96 | score = self._calc_similarity(pred, gts) 97 | total_score += score 98 | return total_score / len(predictions) 99 | 100 | 101 | """ WUPS """ 102 | # ==================================================== 103 | # @Time : 13/9/20 4:19 PM 104 | # @Author : Xiao Junbin 105 | # @Email : junbin@comp.nus.edu.sg 106 | # @File : metrics.py 107 | # ==================================================== 108 | 109 | 110 | def wup(word1, word2, alpha): 111 | """ 112 | calculate the wup similarity 113 | :param word1: 114 | :param word2: 115 | :param alpha: 116 | :return: 117 | """ 118 | # print(word1, word2) 119 | if word1 == word2: 120 | return 1.0 121 | 122 | w1 = wordnet.synsets(word1) 123 | w1_len = len(w1) 124 | if w1_len == 0: return 0.0 125 | w2 = wordnet.synsets(word2) 126 | w2_len = len(w2) 127 | if w2_len == 0: return 0.0 128 | 129 | #match the first 130 | word_sim = w1[0].wup_similarity(w2[0]) 131 | if word_sim is None: 132 | word_sim = 0.0 133 | 134 | if word_sim < alpha: 135 | word_sim = 0.1*word_sim 136 | return word_sim 137 | 138 | def wups(words1, words2, alpha): 139 | """ 140 | 141 | :param pred: 142 | :param truth: 143 | :param alpha: 144 | :return: 145 | """ 146 | sim = 1.0 147 | flag = False 148 | for w1 in words1: 149 | max_sim = 0 150 | for w2 in words2: 151 | word_sim = wup(w1, w2, alpha) 152 | if word_sim > max_sim: 153 | max_sim = word_sim 154 | if max_sim == 0: continue 155 | sim *= max_sim 156 | flag = True 157 | if not flag: 158 | sim = 0.0 159 | return sim 160 | 161 | def get_wups(pred, truth, alpha=0): 162 | """ 163 | calculate the wups score 164 | :param pred: 165 | :param truth: 166 | :return: 167 | """ 168 | pred = word_tokenize(pred) 169 | truth = word_tokenize(truth) 170 | item1 = wups(pred, truth, alpha) 171 | item2 = wups(truth, pred, alpha) 172 | value = min(item1, item2) 173 | return value 174 | 175 | def _calc_wups(predictions, gold_annotations): 176 | wups = 0 177 | for pred, gt in zip(predictions, gold_annotations): 178 | wups += get_wups(pred, gt[0]) 179 | wups /= len(predictions) 180 | return wups 181 | """ WUPS """ 182 | 183 | 184 | # def _calc_bertscore(predictions, gold_annotations): 185 | # references = [x[0] for x in gold_annotations] 186 | # P, R, F1 = bert_score.score( 187 | # predictions, references, lang='en', 188 | # model_type='microsoft/deberta-xlarge-mnli', 189 | # ) 190 | # return float(F1.mean()) 191 | 192 | 193 | def _calc_accuracy(predictions, gold_annotations): 194 | correct = 0 195 | for pred, possible_refs in zip(predictions, gold_annotations): 196 | if any(ref == pred for ref in possible_refs): 197 | correct += 1 198 | total = len(predictions) 199 | return correct / total 200 | 201 | 202 | def _calc_meteor(predictions, gold_annotations): 203 | score = AverageMeter() 204 | for pred, possible_refs in zip(predictions, gold_annotations): 205 | pred = tokenize(pred, None) 206 | # 
https://github.com/cmu-mtlab/meteor/blob/master/src/edu/cmu/meteor/util/Normalizer.java 207 | possible_refs = [tokenize(x, None) for x in possible_refs] 208 | score.update(meteor_score(possible_refs, pred)) 209 | return score.avg 210 | 211 | 212 | def _calc_rouge(predictions, gold_annotations) -> Dict[str, Dict[str, float]]: 213 | rouge_scorer = RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=False) 214 | rouge = defaultdict(lambda: defaultdict(AverageMeter)) 215 | for pred, possible_refs in zip(predictions, gold_annotations): 216 | sample_result = {} 217 | for ref in possible_refs: 218 | single_ref_result = rouge_scorer.score(ref, pred) 219 | for k, scores in single_ref_result.items(): 220 | existing_result_dict = sample_result.setdefault(k, {}) 221 | if existing_result_dict.get('f', -1) < scores.fmeasure: 222 | existing_result_dict.update(f=scores.fmeasure, p=scores.precision, r=scores.recall) 223 | for k, best_scores in sample_result.items(): 224 | rouge[k]['p'].update(best_scores['p']) 225 | rouge[k]['r'].update(best_scores['r']) 226 | rouge[k]['f'].update(best_scores['f']) 227 | return { 228 | rouge_type: { 229 | measure: score.avg 230 | for measure, score in results.items() 231 | } for rouge_type, results in rouge.items() 232 | } 233 | 234 | 235 | def _calc_bleu(predictions, gold_annotations) -> Dict[str, float]: 236 | return { 237 | 'BLEU': sacre_bleu_score(predictions, gold_annotations, n_gram=1) 238 | } 239 | # refs_transposed = [ 240 | # [refs[i] for refs in gold_annotations] 241 | # for i in range(len(gold_annotations[0])) 242 | # ] 243 | # bleu: BLEUScore = BLEU().corpus_score(predictions, refs_transposed) 244 | # return { 245 | # 'BLEU': bleu.score, 246 | # 'BLEU.bp': bleu.bp, 247 | # 'BLEU.ratio': bleu.ratio, 248 | # 'BLEU.hyp_len': float(bleu.sys_len), 249 | # 'BLEU.ref_len': float(bleu.ref_len), 250 | # } 251 | 252 | 253 | def _flatten_dict(d, prefix=''): 254 | result = {} 255 | for k, v in d.items(): 256 | my_key = prefix + k 257 | if isinstance(v, dict): 258 | result.update(_flatten_dict(v, prefix=my_key + '.')) 259 | else: 260 | result[my_key] = v 261 | return result 262 | 263 | 264 | def main(): 265 | parser = ArgumentParser('Eval output file') 266 | parser.add_argument('--gold_answers', type=str, required=True, 267 | help='Path to answers.json, containing mapping from sample_id to answer') 268 | parser.add_argument('eval_file', type=str, 269 | help='JSON File to evaluate. 
Should contain mapping from sample_id ' 270 | 'to hypothesis or array of hypotheses') 271 | args = parser.parse_args() 272 | 273 | gold_answers = json.loads(Path(args.gold_answers).read_text()) 274 | hypotheses = json.loads(Path(args.eval_file).read_text()) 275 | if isinstance(next(iter(hypotheses.values())), list): 276 | hypotheses = {k: v[0] for k, v in hypotheses.items()} 277 | assert len(hypotheses.keys() - gold_answers.keys()) == 0, 'No gold answer for some hypotheses' 278 | 279 | gold_and_hypo = [(gold_answers[k], hypotheses[k]) for k in hypotheses.keys()] 280 | hypo_list = [h for g, h in gold_and_hypo] 281 | gold_list = [[g] for g, h in gold_and_hypo] 282 | metrics = calc_metrics(hypo_list, gold_list) 283 | 284 | pprint(metrics) 285 | 286 | 287 | if __name__ == '__main__': 288 | # main() 289 | 290 | # debug 291 | st = SentenceTransformerSimilarity() 292 | score = st._calc_similarity('inside the drawer', ['inside the drawer']) 293 | print(score) # 1.0 294 | 295 | score = st._calc_similarity('inside the drawer', ['on the table']) 296 | print(score) # 0.49 297 | 298 | score = st._calc_similarity('inside the drawer', ['in the drawer']) 299 | print(score) # 0.93 300 | 301 | # mean_score = st.calc_st_similarity( 302 | # ['floor', '3'], 303 | # [['on the ground'], ['two']] 304 | # ) 305 | # print(mean_score) 306 | -------------------------------------------------------------------------------- /model/ours/dataset.py: -------------------------------------------------------------------------------- 1 | # Joint dataset of CloseQA, OpenQA, and NLQ 2 | 3 | import os 4 | import math 5 | import json 6 | import random 7 | from pathlib import Path 8 | from typing import Iterable 9 | 10 | import h5py 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | import pytorch_lightning as pl 15 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 16 | from torch.nn.utils.rnn import pad_sequence 17 | from torch.utils.data.dataset import Dataset 18 | from transformers import AutoTokenizer 19 | 20 | 21 | class BaseDataset(Dataset): 22 | def __init__(self, data_dir, split, feature_type, max_v_len): 23 | super().__init__() 24 | self.split = split 25 | self.video_features = h5py.File(os.path.join(data_dir, feature_type + '.hdf5'), 'r') 26 | self.annotations = json.loads(Path(os.path.join(data_dir, f'annotations.{split}.json')).read_text()) 27 | self.max_v_len = max_v_len 28 | print(f'{split} set: {len(self.annotations)}') 29 | 30 | def __len__(self): 31 | return len(self.annotations) 32 | 33 | def _get_video_feature(self, video_id): 34 | video_feature = torch.from_numpy(self.video_features[video_id][:]) 35 | v_len = video_feature.shape[0] 36 | sample_ratio = 1.0 37 | if v_len > self.max_v_len: 38 | sample_idx = torch.linspace(0, v_len-1, self.max_v_len).long() 39 | video_feature = video_feature[sample_idx] 40 | sample_ratio = self.max_v_len / v_len 41 | v_len = self.max_v_len 42 | return video_feature, v_len, sample_ratio 43 | 44 | 45 | class NLQDataset(BaseDataset): 46 | def __init__(self, data_dir, split, feature_type, max_v_len): 47 | super().__init__(data_dir, split, feature_type, max_v_len) 48 | 49 | def __getitem__(self, index): 50 | video_id = self.annotations[index]['video_id'] 51 | query_id = self.annotations[index].get('sample_id') 52 | question = self.annotations[index]['question'] 53 | 54 | video_feature, v_len, sample_ratio = self._get_video_feature(video_id) 55 | 56 | if 'clip_start_sec' in self.annotations[index]: 57 | start_time = 
self.annotations[index].get('clip_start_sec') 58 | end_time = self.annotations[index].get('clip_end_sec') 59 | else: 60 | start_time = self.annotations[index].get('moment_start_frame') / 30 61 | end_time = self.annotations[index].get('moment_end_frame') / 30 62 | 63 | query_type = self.annotations[index].get('query_type') 64 | if query_type == 'narration': 65 | duration = end_time - start_time 66 | center = (end_time + start_time) / 2 67 | scale_ratio = random.randint(1, 10) 68 | shift_number = random.uniform(-1, 1) * (scale_ratio - 1) * duration / 2 69 | new_center = center - shift_number 70 | start_time = new_center - scale_ratio * duration / 2 71 | end_time = new_center + scale_ratio * duration / 2 72 | 73 | segments = torch.tensor([[start_time, end_time]]) * 30 / 16.043 * sample_ratio 74 | labels = torch.zeros(len(segments), dtype=torch.int64) 75 | one_hot_labels = F.one_hot(labels, 1) # (1, 1) 76 | 77 | return { 78 | 'video_id': video_id, 79 | 'question': f"question: {question} video: ", 80 | 'answer': 'None', 81 | 'v_feat': video_feature, 82 | 'v_len': v_len, 83 | 'segments': segments, 84 | 'one_hot_labels': one_hot_labels, 85 | 'query_id': query_id, 86 | 'sample_ratio': sample_ratio, 87 | 'task': 'NLQ' 88 | } 89 | 90 | 91 | class QADataset(BaseDataset): 92 | def __init__(self, data_dir, split, feature_type, max_v_len, qa_type, CloseQA_weight=50): 93 | super().__init__(data_dir, split, feature_type, max_v_len) 94 | self.qa_type = qa_type # CloseQA, OpenQA, Mixed 95 | self.choice_indices = ['A', 'B', 'C', 'D'] 96 | self.CloseQA_weight = CloseQA_weight 97 | self.openqa_weight = 100 - CloseQA_weight 98 | 99 | def __getitem__(self, index): 100 | video_id = self.annotations[index]['video_id'] 101 | query_id = self.annotations[index].get('sample_id') 102 | question = self.annotations[index]['question'] 103 | answer = self.annotations[index]['answer'].strip() 104 | 105 | qa_type = self.qa_type 106 | if qa_type == 'Mixed': # randomly choose a qa type 107 | qa_type = random.choices(['CloseQA', 'OpenQA'], weights=[self.CloseQA_weight, self.openqa_weight], k=1)[0] 108 | if qa_type == 'OpenQA': 109 | question_str = f"question: {question} video: " 110 | answer_str = answer 111 | elif qa_type == 'CloseQA': 112 | wrong_answers = self.annotations[index]['wrong_answers'] 113 | # shuffle choices 114 | choices = [answer] + wrong_answers 115 | random.shuffle(choices) 116 | answer_index = choices.index(answer) 117 | choices = [f'({self.choice_indices[idx]}) {choices[idx]}' for idx in range(len(choices))] # ["(A) xx", "(B) xx", "(C) xx", "(D) xx"] 118 | choices_str = ' '.join(choices) # (A) xx (B) xx (C) xx (D) xx 119 | question_str = f"question: {question} choices: {choices_str}. 
video: " 120 | answer_str = choices[answer_index] # (A/B/C/D) xx 121 | else: 122 | raise NotImplementedError 123 | 124 | video_feature, v_len, sample_ratio = self._get_video_feature(video_id) 125 | 126 | start_frame = self.annotations[index].get('moment_start_frame') 127 | end_frame = self.annotations[index].get('moment_end_frame') 128 | start_time = start_frame / 30 129 | end_time = end_frame / 30 130 | 131 | if 'video_start_sec' not in self.annotations[index]: # LLM generated QA 132 | duration = end_time - start_time 133 | center = (end_time + start_time) / 2 134 | scale_ratio = random.randint(1, 10) 135 | shift_number = random.uniform(-1, 1) * (scale_ratio - 1) * duration / 2 136 | new_center = center - shift_number 137 | start_time = new_center - scale_ratio * duration / 2 138 | end_time = new_center + scale_ratio * duration / 2 139 | 140 | segments = torch.tensor([[start_time, end_time]]) * 30 / 16.043 * sample_ratio 141 | labels = torch.zeros(len(segments), dtype=torch.int64) 142 | one_hot_labels = F.one_hot(labels, 1) # (1, 1) 143 | 144 | return { 145 | 'video_id': video_id, 146 | 'question': question_str, 147 | 'answer': answer_str, 148 | 'v_feat': video_feature, 149 | 'v_len': v_len, 150 | 'segments': segments, 151 | 'one_hot_labels': one_hot_labels, 152 | 'query_id': query_id, 153 | 'sample_ratio': sample_ratio, 154 | 'task': qa_type 155 | } 156 | 157 | 158 | class JointDataset(ConcatDataset): 159 | def __init__(self, datasets: Iterable[Dataset], tokenizer_path) -> None: 160 | super().__init__(datasets) 161 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 162 | self.tokenizer.pad_token = self.tokenizer.eos_token # BUG: Set this per convenience for GPT-2 163 | 164 | def collate_fn(self, batch): 165 | question = [b['question'] for b in batch] 166 | question_tok = self.tokenizer(question, padding=True, return_tensors='pt', add_special_tokens=False) 167 | 168 | answer = [b['answer'] for b in batch] 169 | labels = self.tokenizer(answer, padding=True, return_tensors='pt').input_ids 170 | # NOTE: NLQ data does not have an answer 171 | for idx, a in enumerate(answer): 172 | if a == 'None': 173 | labels[idx] = torch.ones_like(labels[idx]) * -100 174 | 175 | video_feature = [b['v_feat'] for b in batch] 176 | video_feature_padded = pad_sequence(video_feature, batch_first=True) 177 | video_mask = pad_sequence([torch.ones(len(v)) for v in video_feature], batch_first=True).bool() 178 | 179 | result = { 180 | 'video_id': [b['video_id'] for b in batch], 181 | 'q_text': question, 182 | 'q_token': question_tok.input_ids, 183 | 'q_mask': question_tok.attention_mask.bool(), 184 | 'v_feat': video_feature_padded, 185 | 'v_mask': video_mask, 186 | 'v_len': np.asarray([b['v_len'] for b in batch], dtype=np.long), 187 | 'gt_segments': torch.stack([b['segments'] for b in batch]), 188 | 'gt_labels': torch.stack([b['one_hot_labels'] for b in batch]), 189 | 'query_id': [b['query_id'] for b in batch], 190 | 'sample_ratio': [b['sample_ratio'] for b in batch], 191 | 'a_text': answer, 192 | 'labels': labels, 193 | 'task': [b['task'] for b in batch] 194 | } 195 | 196 | return result 197 | 198 | 199 | class JointDataModule(pl.LightningDataModule): 200 | train_dataset = None 201 | val_dataset = None 202 | test_dataset = None 203 | 204 | def __init__(self, config): 205 | super().__init__() 206 | self.config = config 207 | 208 | def setup(self, stage=None): 209 | CloseQA_weight = self.config.get('closeqa_weight', 50) 210 | print(f'CloseQA percentage: {CloseQA_weight}%') 211 | self.train_dataset = 
JointDataset([ 212 | QADataset('data/unified', train_split, self.config.feature_type, self.config.max_v_len, 'Mixed', CloseQA_weight) 213 | for train_split in self.config.qa_train_splits 214 | ] + [ 215 | NLQDataset('data/unified', train_split, self.config.feature_type, self.config.max_v_len) 216 | for train_split in self.config.nlq_train_splits 217 | ], 218 | self.config.tokenizer_path 219 | ) 220 | 221 | test_datasets = [] 222 | for split in self.config.test_splits: 223 | if split == 'QaEgo4D_test': 224 | test_datasets.append(QADataset('data/unified', split, self.config.feature_type, self.config.max_v_len, 'OpenQA')) 225 | elif split == 'QaEgo4D_test_close': 226 | test_datasets.append(QADataset('data/unified', split, self.config.feature_type, self.config.max_v_len, 'CloseQA')) 227 | elif split in ['NLQ_val', 'NLQ_test_unannotated']: 228 | test_datasets.append(NLQDataset('data/unified', split, self.config.feature_type, self.config.max_v_len)) 229 | else: 230 | print(split) 231 | raise NotImplementedError 232 | self.val_dataset = self.test_dataset = JointDataset(test_datasets, self.config.tokenizer_path) 233 | 234 | print(f'#total train: {len(self.train_dataset)}') 235 | print(f'#total val: {len(self.val_dataset)}') 236 | print(f'#total test: {len(self.test_dataset)}') 237 | 238 | def train_dataloader(self): 239 | return DataLoader( 240 | self.train_dataset, 241 | batch_size=self.config.batch_size, 242 | shuffle=True, 243 | drop_last=True, 244 | num_workers=self.config.num_workers, 245 | collate_fn=self.train_dataset.collate_fn, 246 | pin_memory=True, 247 | persistent_workers=True 248 | ) 249 | 250 | def val_dataloader(self): 251 | return DataLoader( 252 | self.val_dataset, 253 | batch_size=self.config.batch_size, 254 | shuffle=False, 255 | drop_last=False, 256 | num_workers=self.config.num_workers, 257 | collate_fn=self.val_dataset.collate_fn, 258 | pin_memory=True 259 | ) 260 | 261 | def test_dataloader(self): 262 | return DataLoader( 263 | self.test_dataset, 264 | batch_size=self.config.batch_size, 265 | shuffle=False, 266 | drop_last=False, 267 | num_workers=self.config.num_workers, 268 | collate_fn=self.val_dataset.collate_fn, 269 | pin_memory=True 270 | ) 271 | -------------------------------------------------------------------------------- /model/ours/lightning_module.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import random 4 | 5 | import torch 6 | import pytorch_lightning as pl 7 | from hydra.utils import instantiate 8 | from transformers import AutoTokenizer 9 | from torch.optim.lr_scheduler import OneCycleLR 10 | 11 | from eval import calc_metrics 12 | from eval_nlq import ReferringRecall 13 | 14 | 15 | class TestLightningModule(pl.LightningModule): 16 | def __init__(self, config): 17 | super().__init__() 18 | self.config = config 19 | self.tokenizer = AutoTokenizer.from_pretrained(config.dataset.tokenizer_path) 20 | self.tokenizer.pad_token = self.tokenizer.eos_token 21 | self.model = instantiate(config.model, max_v_len=config.dataset.max_v_len) 22 | 23 | def test_step(self, batch, batch_idx): 24 | nlq_results, answer_tokens = self.model.generate(**batch) 25 | pred_answer = self.tokenizer.batch_decode(answer_tokens, skip_special_tokens=True) 26 | return { 27 | 'question': batch['q_text'], 28 | 'video_id': batch['video_id'], 29 | 'answer': batch['a_text'] if 'a_text' in batch else '', 30 | 'pred_answer': pred_answer, 31 | 'nlq_results': nlq_results, 32 | 'query_id': batch['query_id'], 33 | 
'sample_ratio': batch['sample_ratio'], 34 | 'task': batch['task'], 35 | 'clip_uid': batch['video_id'] 36 | } 37 | 38 | def test_epoch_end(self, outputs): 39 | self.save_nlq_results(outputs) 40 | 41 | def save_nlq_results(self, preds): 42 | # aggregate preds 43 | pred_dict = { 44 | "version": "1.0", 45 | "challenge": "ego4d_nlq_challenge", 46 | "results": [] 47 | } 48 | for batch_pred in preds: 49 | for i in range(len(batch_pred['video_id'])): 50 | qid = batch_pred['query_id'][i] 51 | annotation_uid, query_idx = qid.split('_') 52 | query_idx = int(query_idx) 53 | clip_uid = batch_pred['clip_uid'][i] 54 | sample_ratio = batch_pred['sample_ratio'][i] 55 | predicted_times = [ 56 | [segment[0] / sample_ratio, segment[1] / sample_ratio] 57 | for segment in batch_pred['nlq_results'][i]['segments'].cpu().detach().tolist() 58 | ] 59 | 60 | pred_dict['results'].append({ 61 | 'clip_uid': clip_uid, 62 | 'annotation_uid': annotation_uid, 63 | 'query_idx': query_idx, 64 | 'predicted_times': predicted_times 65 | }) 66 | 67 | with open('nlq_eval_results/nlq_v2.json', 'w') as f: 68 | json.dump(pred_dict, f) 69 | 70 | 71 | class LightningModule(pl.LightningModule): 72 | def __init__(self, config, total_steps): 73 | super().__init__() 74 | self.config = config 75 | self.tokenizer = AutoTokenizer.from_pretrained(config.dataset.tokenizer_path) 76 | self.tokenizer.pad_token = self.tokenizer.eos_token 77 | self.model = instantiate(config.model, max_v_len=config.dataset.max_v_len) 78 | self.nlq_evaluator = ReferringRecall( 79 | dataset="ego4d", 80 | gt_file=config.dataset.nlq_val_anno 81 | ) 82 | self._log_indices = {} 83 | self.total_steps = total_steps 84 | 85 | def training_step(self, batch, batch_idx): 86 | total_loss, ce_loss, time_loss = self.model(**batch) 87 | self.log('total_loss', total_loss, rank_zero_only=True) 88 | self.log('ce_loss', ce_loss, rank_zero_only=True) 89 | self.log('time_loss', time_loss, rank_zero_only=True) 90 | return { 91 | 'loss': total_loss, 92 | } 93 | 94 | def validation_step(self, batch, batch_idx): 95 | nlq_results, answer_tokens = self.model.generate(**batch) 96 | pred_answer = self.tokenizer.batch_decode(answer_tokens, skip_special_tokens=True) 97 | return { 98 | 'question': batch['q_text'], 99 | 'video_id': batch['video_id'], 100 | 'answer': batch['a_text'] if 'a_text' in batch else '', 101 | 'pred_answer': pred_answer, 102 | 'nlq_results': nlq_results, 103 | 'query_id': batch['query_id'], 104 | 'sample_ratio': batch['sample_ratio'], 105 | 'task': batch['task'] 106 | } 107 | 108 | def test_step(self, batch, batch_idx): 109 | return self.validation_step(batch, batch_idx) 110 | 111 | def _log_some_outputs(self, outputs, name): 112 | num_val_steps_to_log, num_samples_per_batch_to_log = 5, 3 # Could be configurable via cfg 113 | steps_to_log_indices = random.sample(range(len(outputs)), k=min(len(outputs), num_val_steps_to_log)) 114 | self._log_indices[name] = { 115 | 'steps': steps_to_log_indices, 116 | 'samples': [ 117 | random.sample( 118 | range(len(outputs[step]['answer'])), 119 | k=min(len(outputs[step]['answer']), 120 | num_samples_per_batch_to_log)) 121 | for step in steps_to_log_indices 122 | ] 123 | } 124 | for i, step in enumerate(steps_to_log_indices): 125 | indices = self._log_indices[name]['samples'][i] 126 | for b in indices: 127 | sample = ( 128 | f'Video: "{outputs[step]["video_id"][b]}". \n' 129 | f'Question: "{outputs[step]["question"][b]}". \n' 130 | f'Target: "{outputs[step]["answer"][b]}". 
\n' 131 | f'Output: "{outputs[step]["pred_answer"][b]}"' 132 | ) 133 | self.logger.experiment.add_text(f'{name} {str(i * len(indices) + b)}', sample, 134 | global_step=self.global_step) 135 | 136 | def aggregate_metrics(self, outputs, prefix): 137 | # evaluate CloseQA 138 | all_hypos = [] 139 | all_targets = [] 140 | for output in outputs: 141 | for i in range(len(output['video_id'])): 142 | if output['task'][i] == 'CloseQA': 143 | all_hypos.append(output['pred_answer'][i]) 144 | all_targets.append(output['answer'][i]) 145 | if len(all_hypos) > 0: 146 | num_correct = 0 147 | for hypo, target in zip(all_hypos, all_targets): 148 | if hypo == target: 149 | num_correct += 1 150 | acc = num_correct / len(all_targets) * 100 151 | metrics = {f'{prefix}_close_acc': acc} 152 | else: 153 | metrics = {} 154 | 155 | # evaluate OpenQA 156 | all_hypos = [] 157 | all_targets = [] 158 | for output in outputs: 159 | for i in range(len(output['video_id'])): 160 | if output['task'][i] == 'OpenQA': 161 | all_hypos.append(output['pred_answer'][i]) 162 | all_targets.append(output['answer'][i]) 163 | if len(all_hypos) > 0: 164 | open_qa_metrics = calc_metrics(all_hypos, [[x] for x in all_targets], test=prefix=='test') 165 | for k, v in open_qa_metrics.items(): 166 | metrics[f'{prefix}_{k}'] = v 167 | 168 | # evalute NLQ 169 | nlq_preds = [] 170 | for output in outputs: 171 | for i in range(len(output['video_id'])): 172 | if output['task'][i] != 'NLQ': 173 | continue 174 | qid = output['query_id'][i] 175 | temp_list = qid.split("_") 176 | sample_ratio = output['sample_ratio'][i] 177 | new_prediction = [ 178 | [ segment[0] / sample_ratio, 179 | segment[1] / sample_ratio, 180 | score ] 181 | for segment, score in zip( 182 | output['nlq_results'][i]['segments'].cpu().detach().tolist(), 183 | output['nlq_results'][i]['scores'].cpu().detach().tolist(), 184 | )] 185 | nlq_preds.append({ 186 | 'query_idx': int(temp_list[1]), 187 | 'annotation_uid': temp_list[0], 188 | 'predicted_times': new_prediction, 189 | 'clip_uid': output['video_id'][i] 190 | }) 191 | if len(nlq_preds) > 0: 192 | performance, score_str = self.nlq_evaluator.evaluate(nlq_preds, verbose=False) 193 | metrics[f'{prefix}_R1_03'] = performance[0, 0] * 100 194 | metrics[f'{prefix}_R5_03'] = performance[0, 1] * 100 195 | metrics[f'{prefix}_R1_05'] = performance[1, 0] * 100 196 | metrics[f'{prefix}_R5_05'] = performance[1, 1] * 100 197 | metrics[f'{prefix}_Mean_R1'] = (performance[0, 0] + performance[1, 0]) * 100 / 2 198 | 199 | # # save predictions 200 | # results = [] 201 | # for output in outputs: 202 | # for i in range(len(output['video_id'])): 203 | # results.append({ 204 | # 'query_id': output['query_id'][i], 205 | # 'pred_answer': output['pred_answer'][i], 206 | # 'gt_answer': output['answer'][i], 207 | # 'pred_window': (output['nlq_results'][i]['segments'].cpu().detach() / output['sample_ratio'][i]).tolist(), 208 | # 'gt_window': self.nlq_evaluator.gt_dict[(output['video_id'][i], output['query_id'][i].split('_')[0])]["language_queries"][int(output['query_id'][i].split('_')[1])] 209 | # }) 210 | # with open('analysis/VLG_OpenQA.json', 'w') as f: 211 | # json.dump(results, f) 212 | 213 | return metrics 214 | 215 | # def training_epoch_end(self, outputs): 216 | # self._log_some_outputs(outputs, 'train') 217 | # metrics = self.aggregate_metrics(outputs, prefix='train') 218 | # self.log_dict(metrics, sync_dist=True) 219 | 220 | def validation_epoch_end(self, outputs): 221 | def _mean(key): 222 | return torch.stack([data[key] for data in outputs]).mean() 
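        # `outputs` is the list of dicts returned by validation_step above (keys:
        # 'question', 'video_id', 'answer', 'pred_answer', 'nlq_results', 'query_id',
        # 'sample_ratio', 'task'). aggregate_metrics splits the samples by 'task'
        # (CloseQA / OpenQA / NLQ) and reports val_close_acc, the open-QA metrics from
        # eval.calc_metrics, and val_R1_03 / val_R5_03 / val_R1_05 / val_R5_05 /
        # val_Mean_R1; predicted segments are divided by 'sample_ratio' to map them from
        # the subsampled feature grid back to the full-length grid (cf. _get_video_feature).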
223 | 224 | # self._log_some_outputs(outputs, 'val') 225 | metrics = self.aggregate_metrics(outputs, prefix='val') 226 | metrics.update({ 227 | f'val_{name}': _mean(name) for name in outputs[0].keys() if 'loss' in name 228 | }) 229 | self.log_dict(metrics, sync_dist=True) 230 | 231 | def test_epoch_end(self, outputs): 232 | # self._log_some_outputs(outputs, 'test') 233 | metrics = self.aggregate_metrics(outputs, prefix='test') 234 | self.log_dict(metrics, sync_dist=True) 235 | if self.config.trainer.save_nlq_results is not None: 236 | src = 'data/joint/annotations.QaEgo4D_test_close.json' 237 | dst = self.config.trainer.save_nlq_results 238 | self.save_nlq_results(src, dst, outputs) 239 | 240 | def save_nlq_results(self, src, dst, preds): 241 | # aggregate preds 242 | pred_dict = {} 243 | for batch_pred in preds: 244 | for i in range(len(batch_pred['video_id'])): 245 | qid = batch_pred['query_id'][i] 246 | sample_ratio = batch_pred['sample_ratio'][i] 247 | pred_start = batch_pred['nlq_results'][i]['segments'][0].cpu().detach().tolist()[0] / sample_ratio 248 | pred_end = batch_pred['nlq_results'][i]['segments'][0].cpu().detach().tolist()[1] / sample_ratio 249 | assert qid not in pred_dict 250 | pred_dict[qid] = { 251 | 'pred_start_sec': pred_start, 252 | 'pred_end_sec': pred_end 253 | } 254 | 255 | save_results = [] 256 | for src_data in json.load(open(src)): 257 | pred_data = pred_dict[src_data['sample_id']] 258 | save_data = copy.deepcopy(src_data) 259 | save_data['moment_start_frame'] = pred_data['pred_start_sec'] * 30 260 | save_data['moment_end_frame'] = pred_data['pred_end_sec'] * 30 261 | save_results.append(save_data) 262 | with open(dst, 'w') as f: 263 | json.dump(save_results, f) 264 | 265 | def configure_optimizers(self): 266 | optimizer = instantiate( 267 | self.config.optim.optimizer, 268 | filter(lambda p: p.requires_grad, self.parameters()), 269 | lr=self.config.optim.optimizer.lr 270 | ) 271 | if self.config.optim.lr_scheduler: 272 | lr_scheduler = OneCycleLR( 273 | optimizer=optimizer, 274 | max_lr=self.config.optim.optimizer.lr, 275 | total_steps=self.total_steps, 276 | anneal_strategy='linear' 277 | ) 278 | return { 279 | 'optimizer': optimizer, 280 | 'lr_scheduler': { 281 | 'scheduler': lr_scheduler, 282 | 'interval': 'step' 283 | } 284 | } 285 | else: 286 | return optimizer -------------------------------------------------------------------------------- /eval_nlq.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | """ 3 | Script to evaluate performance of any model for Ego4d Episodic Memory. 
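A minimal usage sketch for the ReferringRecall evaluator defined below (the path and
prediction values are placeholders; each clip_uid / annotation_uid pair must exist in
the ground-truth file):

    evaluator = ReferringRecall(dataset="ego4d", gt_file="<path/to/nlq_val.json>")
    predictions = [{
        "clip_uid": "<clip uid>",
        "annotation_uid": "<annotation uid>",
        "query_idx": 0,
        "predicted_times": [[12.0, 18.5]],   # [start_sec, end_sec] per candidate
    }]
    mean_results, score_str = evaluator.evaluate(predictions)

The benchmark evaluated is: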
4 | 5 | Natural Language Queries (NLQ) 6 | """ 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | import json 12 | 13 | import numpy as np 14 | import torch 15 | import terminaltables 16 | 17 | 18 | 19 | def compute_overlap(pred, gt): 20 | # check format 21 | assert isinstance(pred, list) and isinstance(gt, list) 22 | pred_is_list = isinstance(pred[0], list) 23 | gt_is_list = isinstance(gt[0], list) 24 | pred = pred if pred_is_list else [pred] 25 | gt = gt if gt_is_list else [gt] 26 | # compute overlap 27 | pred, gt = np.array(pred), np.array(gt) 28 | inter_left = np.maximum(pred[:, 0, None], gt[None, :, 0]) 29 | inter_right = np.minimum(pred[:, 1, None], gt[None, :, 1]) 30 | inter = np.maximum(0.0, inter_right - inter_left) 31 | union_left = np.minimum(pred[:, 0, None], gt[None, :, 0]) 32 | union_right = np.maximum(pred[:, 1, None], gt[None, :, 1]) 33 | union = np.maximum(1e-12, union_right - union_left) 34 | overlap = 1.0 * inter / union 35 | # reformat output 36 | overlap = overlap if gt_is_list else overlap[:, 0] 37 | overlap = overlap if pred_is_list else overlap[0] 38 | return overlap 39 | 40 | 41 | def time_to_index(start_time, end_time, num_units, duration): 42 | s_times = np.arange(0, num_units).astype(np.float32) / float(num_units) * duration 43 | e_times = ( 44 | np.arange(1, num_units + 1).astype(np.float32) / float(num_units) * duration 45 | ) 46 | candidates = np.stack( 47 | [ 48 | np.repeat(s_times[:, None], repeats=num_units, axis=1), 49 | np.repeat(e_times[None, :], repeats=num_units, axis=0), 50 | ], 51 | axis=2, 52 | ).reshape((-1, 2)) 53 | overlaps = compute_overlap(candidates.tolist(), [start_time, end_time]).reshape( 54 | num_units, num_units 55 | ) 56 | start_index = np.argmax(overlaps) // num_units 57 | end_index = np.argmax(overlaps) % num_units 58 | return start_index, end_index, overlaps 59 | 60 | 61 | def index_to_time(start_index, end_index, num_units, duration): 62 | s_times = np.arange(0, num_units).astype(np.float32) * duration / float(num_units) 63 | e_times = ( 64 | np.arange(1, num_units + 1).astype(np.float32) * duration / float(num_units) 65 | ) 66 | start_time = s_times[start_index] 67 | end_time = e_times[end_index] 68 | return start_time, end_time 69 | 70 | 71 | def display_results(results, mIoU, thresholds, topK, title=None): 72 | display_data = [ 73 | [f"Rank@{ii}\nmIoU@{jj}" for ii in topK for jj in thresholds] + ["mIoU"] 74 | ] 75 | results *= 100 76 | mIoU *= 100 77 | display_data.append( 78 | [ 79 | f"{results[jj][ii]:.02f}" 80 | for ii in range(len(topK)) 81 | for jj in range(len(thresholds)) 82 | ] 83 | + [f"{mIoU:.02f}"] 84 | ) 85 | table = terminaltables.AsciiTable(display_data, title) 86 | for ii in range(len(thresholds) * len(topK)): 87 | table.justify_columns[ii] = "center" 88 | return table.table 89 | 90 | 91 | def compute_IoU(pred, gt): 92 | """Compute the IoU given predicted and ground truth windows.""" 93 | assert isinstance(pred, list) and isinstance(gt, list) 94 | pred_is_list = isinstance(pred[0], list) 95 | gt_is_list = isinstance(gt[0], list) 96 | if not pred_is_list: 97 | pred = [pred] 98 | if not gt_is_list: 99 | gt = [gt] 100 | pred, gt = np.array(pred), np.array(gt) 101 | inter_left = np.maximum(pred[:, 0, None], gt[None, :, 0]) 102 | inter_right = np.minimum(pred[:, 1, None], gt[None, :, 1]) 103 | inter = np.maximum(0.0, inter_right - inter_left) 104 | union_left = np.minimum(pred[:, 0, None], gt[None, :, 0]) 105 | union_right = np.maximum(pred[:, 1, 
None], gt[None, :, 1]) 106 | union = np.maximum(0.0, union_right - union_left) 107 | overlap = 1.0 * inter / union 108 | if not gt_is_list: 109 | overlap = overlap[:, 0] 110 | if not pred_is_list: 111 | overlap = overlap[0] 112 | return overlap 113 | 114 | 115 | def evaluate_nlq_performance( 116 | predictions, ground_truth, thresholds, topK, per_instance=False 117 | ): 118 | results = [[[] for _ in topK] for _ in thresholds] 119 | average_IoU = [] 120 | for pred, gt in zip(predictions, ground_truth): 121 | # Compute overlap and recalls. 122 | overlap = compute_IoU(pred, gt) 123 | average_IoU.append(np.mean(np.sort(overlap[0])[-3:])) 124 | for tt, threshold in enumerate(thresholds): 125 | for rr, KK in enumerate(topK): 126 | results[tt][rr].append((overlap > threshold)[:KK].any()) 127 | 128 | mean_results = np.array(results).mean(axis=-1) 129 | mIoU = np.mean(average_IoU) 130 | if per_instance: 131 | per_instance_results = { 132 | "overlap": overlap, 133 | "average_IoU": average_IoU, 134 | "results": results, 135 | } 136 | return mean_results, mIoU, per_instance_results 137 | else: 138 | return mean_results, mIoU 139 | 140 | 141 | def load_jsonl(filename): 142 | with open(filename, "r") as f: 143 | return [json.loads(l.strip("\n")) for l in f.readlines()] 144 | 145 | 146 | class ReferringRecall(object): 147 | thresholds = np.array([0.3, 0.5]) 148 | topK = np.array([1, 5]) 149 | def __init__( 150 | self, 151 | dataset="ego4d", 152 | gt_file="/remote-home/share/Ego4D_dsz/v1/annotations/nlq_val.json" 153 | ): 154 | self.dataset = dataset 155 | self.gt_file = gt_file 156 | print(self.gt_file) 157 | if self.dataset == "ego4d": 158 | with open(self.gt_file) as file_id: 159 | self.gt_dict, self.num_gt_queries = self.load_gt_from_json(json.load(file_id)) 160 | else: 161 | self.gt_dict = {} 162 | for d in load_jsonl(self.gt_file): 163 | # print(d) 164 | self.gt_dict[d['query_id']] = d["timestamps"] 165 | self.num_gt_queries = len(self.gt_dict) 166 | 167 | def load_gt_from_json(self, ground_truth): 168 | gt_dict = {} 169 | num_gt_queries = 0 170 | 171 | for video_datum in ground_truth["videos"]: 172 | for clip_datum in video_datum["clips"]: 173 | clip_uid = clip_datum["clip_uid"] 174 | for ann_datum in clip_datum["annotations"]: 175 | key = (clip_uid, ann_datum["annotation_uid"]) 176 | gt_dict[key] = ann_datum 177 | num_gt_queries += len(ann_datum["language_queries"]) 178 | 179 | return gt_dict, num_gt_queries 180 | 181 | def compute_IoU(self, pred, gt): 182 | """Compute the IoU given predicted and ground truth windows.""" 183 | assert isinstance(pred, list) and isinstance(gt, list) 184 | if len(pred) == 0: # FIXME: I don't know why, maybe coincidence, that the PtTransformerRegHead produces all 0 offsets at the start of training 185 | return [0] 186 | pred_is_list = isinstance(pred[0], list) 187 | gt_is_list = isinstance(gt[0], list) 188 | if not pred_is_list: 189 | pred = [pred] 190 | if not gt_is_list: 191 | gt = [gt] 192 | pred, gt = np.array(pred), np.array(gt) 193 | inter_left = np.maximum(pred[:, 0, None], gt[None, :, 0]) 194 | inter_right = np.minimum(pred[:, 1, None], gt[None, :, 1]) 195 | inter = np.maximum(0.0, inter_right - inter_left) 196 | union_left = np.minimum(pred[:, 0, None], gt[None, :, 0]) 197 | union_right = np.maximum(pred[:, 1, None], gt[None, :, 1]) 198 | union = np.maximum(0.0, union_right - union_left) 199 | overlap = 1.0 * inter / union 200 | if not gt_is_list: 201 | overlap = overlap[:, 0] 202 | if not pred_is_list: 203 | overlap = overlap[0] 204 | return overlap 205 | 206 
| def display_results_anet(self, results, title=None): 207 | display_data = [ 208 | [f"Rank@{ii}\nmIoU@{jj:.1f}" for ii in self.topK for jj in self.thresholds] 209 | ] 210 | results *= 100 211 | display_data.append( 212 | [ 213 | f"{results[ii][jj]:.02f}" 214 | for ii in range(len(self.topK)) 215 | for jj in range(len(self.thresholds)) 216 | ] 217 | ) 218 | table = terminaltables.AsciiTable(display_data, title) 219 | for ii in range(len(self.thresholds) * len(self.topK)): 220 | table.justify_columns[ii] = "center" 221 | return table.table 222 | 223 | def display_results(self, results, title=None): 224 | display_data = [ 225 | [f"Rank@{ii}\nmIoU@{jj}" for ii in self.topK for jj in self.thresholds] 226 | ] 227 | results *= 100 228 | 229 | display_data.append( 230 | [ 231 | f"{results[jj][ii]:.02f}" 232 | for ii in range(len(self.topK)) 233 | for jj in range(len(self.thresholds)) 234 | ] 235 | ) 236 | table = terminaltables.AsciiTable(display_data, title) 237 | for ii in range(len(self.thresholds) * len(self.topK)): 238 | table.justify_columns[ii] = "center" 239 | return table.table 240 | 241 | def evaluate(self, predictions, verbose=True): 242 | """Evalutes the performances.""" 243 | 244 | results = [[[] for _ in self.topK] for _ in self.thresholds] 245 | average_IoU = [] 246 | num_instances = 0 247 | 248 | for pred_datum in predictions: 249 | key = (pred_datum["clip_uid"], pred_datum["annotation_uid"]) 250 | assert key in self.gt_dict, f"{key} Instance not present!" 251 | query_id = pred_datum["query_idx"] 252 | gt_datum = self.gt_dict[key] 253 | gt_query_datum = gt_datum["language_queries"][query_id] 254 | 255 | # Compute overlap and recalls. 256 | overlap = self.compute_IoU( 257 | pred_datum["predicted_times"], 258 | [[gt_query_datum["clip_start_sec"], gt_query_datum["clip_end_sec"]]], 259 | ) 260 | average_IoU.append(overlap[0]) 261 | 262 | for tt, threshold in enumerate(self.thresholds): 263 | for rr, KK in enumerate(self.topK): 264 | results[tt][rr].append((overlap > threshold)[:KK].any()) 265 | num_instances += 1 266 | 267 | mean_results = np.array(results).mean(axis=-1) 268 | 269 | score_str = None 270 | if verbose: 271 | print(f"Evaluated: {num_instances} / {self.num_gt_queries} instances") 272 | score_str = self.display_results(mean_results) 273 | print(score_str, flush=True) 274 | 275 | return mean_results, score_str 276 | 277 | def _iou(self, candidates, gt): 278 | start, end = candidates[:, 0].float(), candidates[:, 1].float() 279 | s, e = gt[0].float(), gt[1].float() 280 | inter = end.min(e) - start.max(s) 281 | union = end.max(e) - start.min(s) 282 | return inter.clamp(min=0) / union 283 | 284 | def evaluate_anet( 285 | self, submission, verbose=True): 286 | 287 | iou_metrics = torch.tensor(self.thresholds) 288 | num_iou_metrics = len(iou_metrics) 289 | 290 | recall_metrics = torch.tensor(self.topK) 291 | max_recall = recall_metrics.max() 292 | num_recall_metrics = len(recall_metrics) 293 | recall_x_iou = torch.zeros((num_recall_metrics, len(iou_metrics))) 294 | 295 | for k in submission: 296 | # print(k) 297 | gt_grounding = torch.tensor(self.gt_dict[k['query_id']]) 298 | pred_moments = torch.tensor(k["predicted_times"][:max_recall]) 299 | mious = self._iou(pred_moments, gt_grounding) 300 | mious_len = len(mious) 301 | bools = mious[:, None].expand(mious_len, num_iou_metrics) > iou_metrics 302 | for i, r in enumerate(recall_metrics): 303 | recall_x_iou[i] += bools[:r].any(dim=0) 304 | 305 | recall_x_iou /= len(submission) 306 | 307 | if verbose: 308 | print(f"Evaluated: 
{len(submission)} / {self.num_gt_queries} instances") 309 | score_str = self.display_results_anet(recall_x_iou) 310 | print(score_str, flush=True) 311 | 312 | return recall_x_iou 313 | 314 | 315 | def segment_iou(target_segment, candidate_segments): 316 | """Compute the temporal intersection over union between a 317 | target segment and all the test segments. 318 | Parameters 319 | ---------- 320 | target_segment : 1d array 321 | Temporal target segment containing [starting, ending] times. 322 | candidate_segments : 2d array 323 | Temporal candidate segments containing N x [starting, ending] times. 324 | Outputs 325 | ------- 326 | tiou : 1d array 327 | Temporal intersection over union score of the N's candidate segments. 328 | """ 329 | tt1 = np.maximum(target_segment[0], candidate_segments[:, 0]) 330 | tt2 = np.minimum(target_segment[1], candidate_segments[:, 1]) 331 | # Intersection including Non-negative overlap score. 332 | segments_intersection = (tt2 - tt1).clip(0) 333 | # Segment union. 334 | segments_union = (candidate_segments[:, 1] - candidate_segments[:, 0]) \ 335 | + (target_segment[1] - target_segment[0]) - segments_intersection 336 | # Compute overlap as the ratio of the intersection 337 | # over union of two segments. 338 | tIoU = segments_intersection.astype(float) / segments_union 339 | return tIoU -------------------------------------------------------------------------------- /model/ours/nlq_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import nms_1d_cpu 7 | 8 | 9 | class LayerNorm(nn.Module): 10 | """ 11 | LayerNorm that supports inputs of size B, C, T 12 | """ 13 | 14 | def __init__( 15 | self, 16 | num_channels, 17 | eps=1e-5, 18 | affine=True, 19 | device=None, 20 | dtype=None, 21 | ): 22 | super().__init__() 23 | factory_kwargs = {'device': device, 'dtype': dtype} 24 | self.num_channels = num_channels 25 | self.eps = eps 26 | self.affine = affine 27 | 28 | if self.affine: 29 | self.weight = nn.Parameter( 30 | torch.ones([1, num_channels, 1], **factory_kwargs)) 31 | self.bias = nn.Parameter( 32 | torch.zeros([1, num_channels, 1], **factory_kwargs)) 33 | else: 34 | self.register_parameter('weight', None) 35 | self.register_parameter('bias', None) 36 | 37 | def forward(self, x): 38 | assert x.dim() == 3 39 | assert x.shape[1] == self.num_channels 40 | 41 | # normalization along C channels 42 | mu = torch.mean(x, dim=1, keepdim=True) 43 | res_x = x - mu 44 | sigma = torch.mean(res_x ** 2, dim=1, keepdim=True) 45 | out = res_x / torch.sqrt(sigma + self.eps) 46 | 47 | # apply weight and bias 48 | if self.affine: 49 | out *= self.weight 50 | out += self.bias 51 | 52 | return out 53 | 54 | 55 | @torch.jit.script 56 | def ctr_diou_loss_1d( 57 | input_offsets: torch.Tensor, 58 | target_offsets: torch.Tensor, 59 | reduction: str = 'none', 60 | eps: float = 1e-8, 61 | ) -> torch.Tensor: 62 | """ 63 | Distance-IoU Loss (Zheng et. 
al) 64 | https://arxiv.org/abs/1911.08287 65 | 66 | This is an implementation that assumes a 1D event is represented using 67 | the same center point with different offsets, e.g., 68 | (t1, t2) = (c - o_1, c + o_2) with o_i >= 0 69 | 70 | Reference code from 71 | https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/giou_loss.py 72 | 73 | Args: 74 | input/target_offsets (Tensor): 1D offsets of size (N, 2) 75 | reduction: 'none' | 'mean' | 'sum' 76 | 'none': No reduction will be applied to the output. 77 | 'mean': The output will be averaged. 78 | 'sum': The output will be summed. 79 | eps (float): small number to prevent division by zero 80 | """ 81 | input_offsets = input_offsets.float() 82 | target_offsets = target_offsets.float() 83 | # check all 1D events are valid 84 | assert (input_offsets >= 0.0).all(), "predicted offsets must be non-negative" 85 | assert (target_offsets >= 0.0).all(), "GT offsets must be non-negative" 86 | 87 | lp, rp = input_offsets[:, 0], input_offsets[:, 1] 88 | lg, rg = target_offsets[:, 0], target_offsets[:, 1] 89 | 90 | # intersection key points 91 | lkis = torch.min(lp, lg) 92 | rkis = torch.min(rp, rg) 93 | 94 | # iou 95 | intsctk = rkis + lkis 96 | unionk = (lp + rp) + (lg + rg) - intsctk 97 | iouk = intsctk / unionk.clamp(min=eps) 98 | 99 | # smallest enclosing box 100 | lc = torch.max(lp, lg) 101 | rc = torch.max(rp, rg) 102 | len_c = lc + rc 103 | 104 | # offset between centers 105 | rho = 0.5 * (rp - lp - rg + lg) 106 | 107 | # diou 108 | loss = 1.0 - iouk + torch.square(rho / len_c.clamp(min=eps)) 109 | 110 | if reduction == "mean": 111 | loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() 112 | elif reduction == "sum": 113 | loss = loss.sum() 114 | 115 | return loss 116 | 117 | 118 | @torch.jit.script 119 | def sigmoid_focal_loss( 120 | inputs: torch.Tensor, 121 | targets: torch.Tensor, 122 | alpha: float = 0.25, 123 | gamma: float = 2.0, 124 | reduction: str = "none", 125 | ) -> torch.Tensor: 126 | """ 127 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 128 | Taken from 129 | https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py 130 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 131 | 132 | Args: 133 | inputs: A float tensor of arbitrary shape. 134 | The predictions for each example. 135 | targets: A float tensor with the same shape as inputs. Stores the binary 136 | classification label for each element in inputs 137 | (0 for the negative class and 1 for the positive class). 138 | alpha: (optional) Weighting factor in range (0,1) to balance 139 | positive vs negative examples. Default = 0.25. 140 | gamma: Exponent of the modulating factor (1 - p_t) to 141 | balance easy vs hard examples. 142 | reduction: 'none' | 'mean' | 'sum' 143 | 'none': No reduction will be applied to the output. 144 | 'mean': The output will be averaged. 145 | 'sum': The output will be summed. 146 | Returns: 147 | Loss tensor with the reduction option applied. 
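    Example (a minimal sketch; shapes and values are illustrative only):
        >>> logits  = torch.randn(8, 1)                     # per-point logits
        >>> targets = torch.randint(0, 2, (8, 1)).float()   # binary labels in {0, 1}
        >>> loss = sigmoid_focal_loss(logits, targets, reduction='mean')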
148 | """ 149 | inputs = inputs.float() 150 | targets = targets.float() 151 | p = torch.sigmoid(inputs) 152 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 153 | p_t = p * targets + (1 - p) * (1 - targets) 154 | loss = ce_loss * ((1 - p_t) ** gamma) 155 | 156 | if alpha >= 0: 157 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 158 | loss = alpha_t * loss 159 | 160 | if reduction == "mean": 161 | loss = loss.mean() 162 | elif reduction == "sum": 163 | loss = loss.sum() 164 | 165 | return loss 166 | 167 | 168 | class BufferList(nn.Module): 169 | """ 170 | Similar to nn.ParameterList, but for buffers 171 | 172 | Taken from https://github.com/facebookresearch/detectron2/blob/master/detectron2/modeling/anchor_generator.py 173 | """ 174 | 175 | def __init__(self, buffers): 176 | super().__init__() 177 | for i, buffer in enumerate(buffers): 178 | # Use non-persistent buffer so the values are not saved in checkpoint 179 | self.register_buffer(str(i), buffer, persistent=False) 180 | 181 | def __len__(self): 182 | return len(self._buffers) 183 | 184 | def __iter__(self): 185 | return iter(self._buffers.values()) 186 | 187 | 188 | class PointGenerator(nn.Module): 189 | """ 190 | A generator for temporal "points" 191 | 192 | max_seq_len can be much larger than the actual seq length 193 | """ 194 | def __init__( 195 | self, 196 | max_seq_len, # max sequence length that the generator will buffer 197 | fpn_strides, # strides of fpn levels 198 | regression_range, # regression range (on feature grids) 199 | use_offset=False # if to align the points at grid centers 200 | ): 201 | super().__init__() 202 | # sanity check, # fpn levels and length divisible 203 | fpn_levels = len(fpn_strides) 204 | assert len(regression_range) == fpn_levels 205 | 206 | # save params 207 | self.max_seq_len = max_seq_len 208 | self.fpn_levels = fpn_levels 209 | self.fpn_strides = fpn_strides 210 | self.regression_range = regression_range 211 | self.use_offset = use_offset 212 | 213 | # generate all points and buffer the list 214 | self.buffer_points = self._generate_points() 215 | 216 | def _generate_points(self): 217 | points_list = [] 218 | # loop over all points at each pyramid level 219 | for l, stride in enumerate(self.fpn_strides): 220 | reg_range = torch.as_tensor( 221 | self.regression_range[l], dtype=torch.float) 222 | fpn_stride = torch.as_tensor(stride, dtype=torch.float) 223 | points = torch.arange(0, self.max_seq_len, stride)[:, None] 224 | # add offset if necessary (not in our current model) 225 | if self.use_offset: 226 | points += 0.5 * stride 227 | # pad the time stamp with additional regression range / stride 228 | reg_range = reg_range[None].repeat(points.shape[0], 1) 229 | fpn_stride = fpn_stride[None].repeat(points.shape[0], 1) 230 | # size: T x 4 (ts, reg_range, stride) 231 | points_list.append(torch.cat((points, reg_range, fpn_stride), dim=1)) 232 | 233 | return BufferList(points_list) 234 | 235 | def forward(self, feats): 236 | # feats will be a list of torch tensors 237 | assert len(feats) == self.fpn_levels 238 | pts_list = [] 239 | feat_lens = [feat.shape[-1] for feat in feats] 240 | for feat_len, buffer_pts in zip(feat_lens, self.buffer_points): 241 | assert feat_len <= buffer_pts.shape[0], "Reached max buffer length for point generator" 242 | pts = buffer_pts[:feat_len, :] 243 | pts_list.append(pts) 244 | return pts_list 245 | 246 | 247 | # drop path: from https://github.com/facebookresearch/SlowFast/blob/master/slowfast/models/common.py 248 | class 
Scale(nn.Module): 249 | """ 250 | Multiply the output regression range by a learnable constant value 251 | """ 252 | 253 | def __init__(self, init_value=1.0): 254 | """ 255 | init_value : initial value for the scalar 256 | """ 257 | super().__init__() 258 | self.scale = nn.Parameter( 259 | torch.tensor(init_value, dtype=torch.float32), 260 | requires_grad=True 261 | ) 262 | 263 | def forward(self, x): 264 | """ 265 | input -> scale * input 266 | """ 267 | return x * self.scale 268 | 269 | 270 | class MaskedConv1D(nn.Module): 271 | """ 272 | Masked 1D convolution. Interface remains the same as Conv1d. 273 | Only support a sub set of 1d convs 274 | """ 275 | 276 | def __init__( 277 | self, 278 | in_channels, 279 | out_channels, 280 | kernel_size, 281 | stride=1, 282 | padding=0, 283 | dilation=1, 284 | groups=1, 285 | bias=True, 286 | padding_mode='zeros' 287 | ): 288 | super().__init__() 289 | # element must be aligned 290 | assert (kernel_size % 2 == 1) and (kernel_size // 2 == padding) 291 | # stride 292 | self.stride = stride 293 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, 294 | stride, padding, dilation, groups, bias, padding_mode) 295 | # zero out the bias term if it exists 296 | if bias: 297 | torch.nn.init.constant_(self.conv.bias, 0.) 298 | 299 | def forward(self, x, mask): 300 | # x: batch size, feature channel, sequence length, 301 | # mask: batch size, 1, sequence length (bool) 302 | B, C, T = x.size() 303 | # input length must be divisible by stride 304 | assert T % self.stride == 0 305 | 306 | # conv 307 | out_conv = self.conv(x) 308 | # compute the mask 309 | if self.stride > 1: 310 | # downsample the mask using nearest neighbor 311 | out_mask = F.interpolate( 312 | mask.to(x.dtype), 313 | size=T // self.stride, 314 | mode='nearest' 315 | ) 316 | else: 317 | # masking out the features 318 | out_mask = mask.to(x.dtype) 319 | 320 | # masking the output, stop grad to mask 321 | out_conv = out_conv * out_mask.detach() 322 | out_mask = out_mask.bool() 323 | return out_conv, out_mask 324 | 325 | 326 | class PtTransformerClsHead(nn.Module): 327 | """ 328 | 1D Conv heads for classification 329 | """ 330 | 331 | def __init__( 332 | self, 333 | input_dim, 334 | feat_dim, 335 | num_classes, 336 | prior_prob=0.01, 337 | num_layers=3, 338 | kernel_size=3, 339 | act_layer=nn.ReLU, 340 | with_ln=False, 341 | empty_cls=[] 342 | ): 343 | super().__init__() 344 | self.act = act_layer() 345 | 346 | # build the head 347 | self.head = nn.ModuleList() 348 | self.norm = nn.ModuleList() 349 | for idx in range(num_layers - 1): 350 | if idx == 0: 351 | in_dim = input_dim 352 | out_dim = feat_dim 353 | else: 354 | in_dim = feat_dim 355 | out_dim = feat_dim 356 | self.head.append( 357 | MaskedConv1D( 358 | in_dim, out_dim, kernel_size, 359 | stride=1, 360 | padding=kernel_size // 2, 361 | bias=(not with_ln) 362 | ) 363 | ) 364 | if with_ln: 365 | self.norm.append( 366 | LayerNorm(out_dim) 367 | ) 368 | else: 369 | self.norm.append(nn.Identity()) 370 | 371 | # classifier 372 | self.cls_head = MaskedConv1D( 373 | feat_dim, num_classes, kernel_size, 374 | stride=1, padding=kernel_size // 2 375 | ) 376 | 377 | # use prior in model initialization to improve stability 378 | # this will overwrite other weight init 379 | bias_value = -(math.log((1 - prior_prob) / prior_prob)) 380 | torch.nn.init.constant_(self.cls_head.conv.bias, bias_value) 381 | 382 | # a quick fix to empty categories: 383 | # the weights associated with these categories will remain unchanged 384 | # we set their bias to a 
large negative value to prevent their outputs 385 | if len(empty_cls) > 0: 386 | bias_value = -(math.log((1 - 1e-6) / 1e-6)) 387 | for idx in empty_cls: 388 | torch.nn.init.constant_(self.cls_head.conv.bias[idx], bias_value) 389 | 390 | def forward(self, fpn_feats, fpn_masks): 391 | assert len(fpn_feats) == len(fpn_masks) 392 | 393 | # apply the classifier for each pyramid level 394 | out_logits = tuple() 395 | for _, (cur_feat, cur_mask) in enumerate(zip(fpn_feats, fpn_masks)): 396 | cur_out = cur_feat 397 | for idx in range(len(self.head)): 398 | cur_out, _ = self.head[idx](cur_out, cur_mask) 399 | cur_out = self.act(self.norm[idx](cur_out)) 400 | cur_logits, _ = self.cls_head(cur_out, cur_mask) 401 | out_logits += (cur_logits,) 402 | 403 | # fpn_masks remains the same 404 | return out_logits 405 | 406 | 407 | class PtTransformerRegHead(nn.Module): 408 | """ 409 | Shared 1D Conv heads for regression 410 | Simlar logic as PtTransformerClsHead with separated implementation for clarity 411 | """ 412 | 413 | def __init__( 414 | self, 415 | input_dim, 416 | feat_dim, 417 | fpn_levels, 418 | num_layers=3, 419 | kernel_size=3, 420 | act_layer=nn.ReLU, 421 | with_ln=False 422 | ): 423 | super().__init__() 424 | self.fpn_levels = fpn_levels 425 | self.act = act_layer() 426 | 427 | # build the conv head 428 | self.head = nn.ModuleList() 429 | self.norm = nn.ModuleList() 430 | for idx in range(num_layers - 1): 431 | if idx == 0: 432 | in_dim = input_dim 433 | out_dim = feat_dim 434 | else: 435 | in_dim = feat_dim 436 | out_dim = feat_dim 437 | self.head.append( 438 | MaskedConv1D( 439 | in_dim, out_dim, kernel_size, 440 | stride=1, 441 | padding=kernel_size // 2, 442 | bias=(not with_ln) 443 | ) 444 | ) 445 | if with_ln: 446 | self.norm.append( 447 | LayerNorm(out_dim) 448 | ) 449 | else: 450 | self.norm.append(nn.Identity()) 451 | 452 | self.scale = nn.ModuleList() 453 | for idx in range(fpn_levels): 454 | self.scale.append(Scale()) 455 | 456 | # segment regression 457 | self.offset_head = MaskedConv1D( 458 | feat_dim, 2, kernel_size, 459 | stride=1, padding=kernel_size // 2 460 | ) 461 | 462 | def forward(self, fpn_feats, fpn_masks): 463 | assert len(fpn_feats) == len(fpn_masks) 464 | assert len(fpn_feats) == self.fpn_levels 465 | 466 | # apply the classifier for each pyramid level 467 | out_offsets = tuple() 468 | for l, (cur_feat, cur_mask) in enumerate(zip(fpn_feats, fpn_masks)): 469 | cur_out = cur_feat 470 | for idx in range(len(self.head)): 471 | cur_out, _ = self.head[idx](cur_out, cur_mask) 472 | cur_out = self.act(self.norm[idx](cur_out)) 473 | cur_offsets, _ = self.offset_head(cur_out, cur_mask) 474 | out_offsets += (F.relu(self.scale[l](cur_offsets)),) 475 | 476 | # fpn_masks remains the same 477 | return out_offsets 478 | 479 | 480 | class NLQHead(nn.Module): 481 | def __init__(self, in_dim, max_v_len): 482 | super().__init__() 483 | self.train_center_sample = 'radius' 484 | self.train_center_sample_radius = 1.5 485 | self.train_loss_weight = 1.0 486 | self.train_cls_prior_prob = 0.01 487 | self.train_label_smoothing = 0.1 488 | 489 | self.test_pre_nms_thresh = 0.001 490 | self.test_pre_nms_topk = 2000 491 | self.test_iou_threshold = 0.1 492 | self.test_min_score = 0.001 493 | self.test_max_seg_num = 5 494 | self.test_nms_method = 'soft' 495 | self.test_duration_thresh = 0.001 496 | self.test_multiclass_nms = True 497 | self.test_nms_sigma = 0.75 498 | self.test_voting_thresh = 0.9 499 | 500 | self.loss_normalizer = 200 501 | self.loss_normalizer_momentum = 0.9 502 | 503 | self.neck = 
FPNIdentity( 504 | in_channels=[in_dim], 505 | out_channel=in_dim, 506 | start_level=0, 507 | end_level=-1, 508 | with_ln=True 509 | ) 510 | self.point_generator = PointGenerator( 511 | max_seq_len=1.0 * max_v_len, 512 | fpn_strides=[1], 513 | regression_range=[[0,10000]] 514 | ) 515 | self.cls_head = PtTransformerClsHead( 516 | in_dim, 517 | feat_dim=384, 518 | num_classes=1, 519 | kernel_size=3, 520 | prior_prob=self.train_cls_prior_prob, 521 | with_ln=True, 522 | num_layers=3, 523 | empty_cls=[] 524 | ) 525 | self.reg_head = PtTransformerRegHead( 526 | in_dim, 527 | feat_dim=384, 528 | fpn_levels=1, 529 | kernel_size=3, 530 | num_layers=3, 531 | with_ln=True 532 | ) 533 | 534 | def forward(self, feat, mask, training=True, gt_segments=None, gt_labels=None, v_lens=None): 535 | """ 536 | feat: (B, D, T) 537 | mask: (B, 1, T) 538 | """ 539 | masks = [mask] 540 | feats = [feat] 541 | 542 | fpn_feats, fpn_masks = self.neck(feats, masks) 543 | 544 | points = self.point_generator(fpn_feats) 545 | out_cls_logits = [x.permute(0, 2, 1) for x in self.cls_head(fpn_feats, fpn_masks)] # (B, T, #cls+1) 546 | out_offsets = [x.permute(0, 2, 1) for x in self.reg_head(fpn_feats, fpn_masks)] # (B, T, #cls*2) 547 | fpn_masks = [x.squeeze(1) for x in fpn_masks] # (B, T) 548 | 549 | # return loss during training 550 | if training: 551 | gt_cls_labels, gt_offsets = self.label_points(points, gt_segments, gt_labels, 1) 552 | 553 | # compute the loss and return 554 | losses = self.losses( 555 | fpn_masks, 556 | out_cls_logits, out_offsets, 557 | gt_cls_labels, gt_offsets 558 | ) 559 | return losses 560 | else: 561 | # decode the actions (sigmoid / stride, etc) 562 | results = self.inference(points, fpn_masks, out_cls_logits, out_offsets, 1, v_lens) 563 | 564 | return results 565 | 566 | @torch.no_grad() 567 | def label_points(self, points, gt_segments, gt_labels, num_classes): 568 | # concat points on all fpn levels List[T x 4] -> F T x 4 569 | # This is shared for all samples in the mini-batch 570 | num_levels = len(points) 571 | concat_points = torch.cat(points, dim=0) 572 | 573 | gt_cls, gt_offset = [], [] 574 | # loop over each video sample 575 | for gt_segment, gt_label in zip(gt_segments, gt_labels): 576 | assert len(gt_segment) == len(gt_label), (gt_segment, gt_label) 577 | cls_targets, reg_targets = self.label_points_single_video( 578 | concat_points, gt_segment, gt_label, num_classes 579 | ) 580 | # "cls_targets: " #points, num_classes 581 | # "reg_targets: " #points, 2 582 | # append to list (len = # images, each of size FT x C) 583 | gt_cls.append(cls_targets) 584 | gt_offset.append(reg_targets) 585 | 586 | return gt_cls, gt_offset 587 | 588 | @torch.no_grad() 589 | def label_points_single_video(self, concat_points, gt_segment, gt_label, num_classes): 590 | # concat_points : F T x 4 (t, regression range, stride) 591 | # gt_segment : N (#Events) x 2 592 | # gt_label : N (#Events) x 1 593 | num_pts = concat_points.shape[0] 594 | num_gts = gt_segment.shape[0] 595 | 596 | # corner case where current sample does not have actions 597 | if num_gts == 0: 598 | cls_targets = gt_segment.new_full((num_pts, num_classes), 0) 599 | reg_targets = gt_segment.new_zeros((num_pts, 2)) 600 | return cls_targets, reg_targets 601 | 602 | # compute the lengths of all segments -> F T x N 603 | lens = gt_segment[:, 1] - gt_segment[:, 0] 604 | lens = lens[None, :].repeat(num_pts, 1) 605 | 606 | # compute the distance of every point to each segment boundary 607 | # auto broadcasting for all reg target-> F T x N x 2 608 | gt_segs = 
gt_segment[None].expand(num_pts, num_gts, 2) 609 | left = concat_points[:, 0, None] - gt_segs[:, :, 0] 610 | right = gt_segs[:, :, 1] - concat_points[:, 0, None] 611 | reg_targets = torch.stack((left, right), dim=-1) 612 | 613 | if self.train_center_sample == 'radius': 614 | # center of all segments F T x N 615 | center_pts = 0.5 * (gt_segs[:, :, 0] + gt_segs[:, :, 1]) 616 | # center sampling based on stride radius 617 | # compute the new boundaries: 618 | # concat_points[:, 3] stores the stride 619 | t_mins = \ 620 | center_pts - concat_points[:, 3, None] * self.train_center_sample_radius 621 | t_maxs = \ 622 | center_pts + concat_points[:, 3, None] * self.train_center_sample_radius 623 | 624 | # prevent t_mins / maxs from over-running the action boundary 625 | # left: torch.maximum(t_mins, gt_segs[:, :, 0]) 626 | # right: torch.minimum(t_maxs, gt_segs[:, :, 1]) 627 | # F T x N (distance to the new boundary) 628 | cb_dist_left = concat_points[:, 0, None] \ 629 | - torch.maximum(t_mins, gt_segs[:, :, 0]) 630 | cb_dist_right = torch.minimum(t_maxs, gt_segs[:, :, 1]) \ 631 | - concat_points[:, 0, None] 632 | # F T x N x 2 633 | center_seg = torch.stack( 634 | (cb_dist_left, cb_dist_right), -1) 635 | 636 | # F T x N 637 | inside_gt_seg_mask = center_seg.min(-1)[0] > 0 638 | else: 639 | # inside an gt action 640 | inside_gt_seg_mask = reg_targets.min(-1)[0] > 0 641 | 642 | # limit the regression range for each location 643 | max_regress_distance = reg_targets.max(-1)[0] 644 | 645 | # F T x N 646 | inside_regress_range = torch.logical_and( 647 | (max_regress_distance >= concat_points[:, 1, None]), 648 | (max_regress_distance <= concat_points[:, 2, None]) 649 | ) 650 | 651 | # limit the regression range for each location and inside the center radius 652 | lens.masked_fill_(inside_gt_seg_mask == 0, float('inf')) 653 | lens.masked_fill_(inside_regress_range == 0, float('inf')) 654 | 655 | # if there are still more than one ground-truths for one point 656 | # pick the ground-truth with the shortest duration for the point (easiest to regress) 657 | # corner case: multiple actions with very similar durations (e.g., THUMOS14) 658 | # make sure that each point can only map with at most one ground-truth 659 | # F T x N -> F T 660 | min_len, min_len_inds = lens.min(dim=1) 661 | min_len_mask = torch.logical_and( 662 | (lens <= (min_len[:, None] + 1e-3)), (lens < float('inf')) 663 | ).to(reg_targets.dtype) 664 | 665 | # cls_targets: F T x C; reg_targets F T x 2 666 | # gt_label_one_hot = F.one_hot(gt_label, num_classes).to(reg_targets.dtype) 667 | gt_label_one_hot = gt_label.to(reg_targets.dtype) 668 | cls_targets = min_len_mask @ gt_label_one_hot 669 | # to prevent multiple GT actions with the same label and boundaries 670 | cls_targets.clamp_(min=0.0, max=1.0) 671 | 672 | # OK to use min_len_inds 673 | reg_targets = reg_targets[range(num_pts), min_len_inds] 674 | # normalization based on stride 675 | reg_targets /= concat_points[:, 3, None] 676 | return cls_targets, reg_targets 677 | 678 | def losses( 679 | self, fpn_masks, 680 | out_cls_logits, out_offsets, 681 | gt_cls_labels, gt_offsets 682 | ): 683 | # fpn_masks, out_*: F (List) [B, T_i, C] 684 | # gt_* : B (list) [F T, C] 685 | # fpn_masks -> (B, FT) 686 | valid_mask = torch.cat(fpn_masks, dim=1) 687 | 688 | # 1. 
classification loss 689 | # stack the list -> (B, FT) -> (# Valid, ) 690 | gt_cls = torch.stack(gt_cls_labels) 691 | pos_mask = torch.logical_and((gt_cls.sum(-1) > 0), valid_mask) 692 | 693 | # update the loss normalizer 694 | num_pos = pos_mask.sum().item() 695 | self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + ( 696 | 1 - self.loss_normalizer_momentum) * max(num_pos, 1) 697 | 698 | # gt_cls is already one hot encoded now, simply masking out 699 | gt_target = gt_cls[valid_mask] 700 | 701 | num_classes = gt_target.shape[-1] 702 | 703 | # optional label smoothing 704 | gt_target *= 1 - self.train_label_smoothing 705 | gt_target += self.train_label_smoothing / (num_classes + 1) 706 | 707 | # focal loss 708 | cls_loss = sigmoid_focal_loss( 709 | torch.cat(out_cls_logits, dim=1)[valid_mask], 710 | gt_target, 711 | reduction='sum' 712 | ) 713 | cls_loss /= self.loss_normalizer 714 | 715 | # 2. regression using IoU/GIoU loss (defined on positive samples) 716 | # cat the predicted offsets -> (B, FT, 2 (xC)) -> # (#Pos, 2 (xC)) 717 | pred_offsets = torch.cat(out_offsets, dim=1)[pos_mask] 718 | gt_offsets = torch.stack(gt_offsets)[pos_mask] 719 | if num_pos == 0: 720 | reg_loss = 0 * pred_offsets.sum() 721 | else: 722 | # giou loss defined on positive samples 723 | reg_loss = ctr_diou_loss_1d( 724 | pred_offsets, 725 | gt_offsets, 726 | reduction='sum' 727 | ) 728 | reg_loss /= self.loss_normalizer 729 | 730 | if self.train_loss_weight > 0: 731 | loss_weight = self.train_loss_weight 732 | else: 733 | loss_weight = cls_loss.detach() / max(reg_loss.item(), 0.01) 734 | 735 | # return a dict of losses 736 | final_loss = cls_loss + reg_loss * loss_weight 737 | return {'cls_loss': cls_loss, 738 | 'reg_loss': reg_loss, 739 | 'final_loss': final_loss} 740 | 741 | @torch.no_grad() 742 | def inference( 743 | self, 744 | points, fpn_masks, 745 | out_cls_logits, out_offsets, num_classes, v_lens 746 | ): 747 | # video_list B (list) [dict] 748 | # points F (list) [T_i, 4] 749 | # fpn_masks, out_*: F (List) [B, T_i, C] 750 | results = [] 751 | 752 | # 2: inference on each single video and gather the results 753 | # upto this point, all results use timestamps defined on feature grids 754 | for idx, vlen in enumerate(v_lens): 755 | # gather per-video outputs 756 | cls_logits_per_vid = [x[idx] for x in out_cls_logits] 757 | offsets_per_vid = [x[idx] for x in out_offsets] 758 | fpn_masks_per_vid = [x[idx] for x in fpn_masks] 759 | # inference on a single video (should always be the case) 760 | results_per_vid = self.inference_single_video( 761 | points, fpn_masks_per_vid, 762 | cls_logits_per_vid, offsets_per_vid, num_classes, 763 | ) 764 | # pass through video meta info 765 | results_per_vid['duration'] = vlen 766 | results.append(results_per_vid) 767 | 768 | # step 3: postprocessing 769 | results = self.postprocessing(results) 770 | 771 | return results 772 | 773 | @torch.no_grad() 774 | def inference_single_video( 775 | self, 776 | points, 777 | fpn_masks, 778 | out_cls_logits, 779 | out_offsets, 780 | num_classes, 781 | ): 782 | # points F (list) [T_i, 4] 783 | # fpn_masks, out_*: F (List) [T_i, C] 784 | segs_all = [] 785 | scores_all = [] 786 | cls_idxs_all = [] 787 | 788 | # loop over fpn levels 789 | for cls_i, offsets_i, pts_i, mask_i in zip( 790 | out_cls_logits, out_offsets, points, fpn_masks 791 | ): 792 | # sigmoid normalization for output logits 793 | pred_prob = (cls_i.sigmoid() * mask_i.unsqueeze(-1)).flatten() 794 | 795 | # Apply filtering to make NMS faster following 
    @torch.no_grad()
    def inference_single_video(
        self,
        points,
        fpn_masks,
        out_cls_logits,
        out_offsets,
        num_classes,
    ):
        # points F (list) [T_i, 4]
        # fpn_masks, out_*: F (List) [T_i, C]
        segs_all = []
        scores_all = []
        cls_idxs_all = []

        # loop over fpn levels
        for cls_i, offsets_i, pts_i, mask_i in zip(
            out_cls_logits, out_offsets, points, fpn_masks
        ):
            # sigmoid normalization for output logits
            pred_prob = (cls_i.sigmoid() * mask_i.unsqueeze(-1)).flatten()

            # Apply filtering to make NMS faster following detectron2
            # 1. Keep seg with confidence score > a threshold
            keep_idxs1 = (pred_prob > self.test_pre_nms_thresh)
            pred_prob = pred_prob[keep_idxs1]
            topk_idxs = keep_idxs1.nonzero(as_tuple=True)[0]

            # 2. Keep only the top-k scoring segments
            num_topk = min(self.test_pre_nms_topk, topk_idxs.size(0))
            pred_prob, idxs = pred_prob.sort(descending=True)
            pred_prob = pred_prob[:num_topk].clone()
            topk_idxs = topk_idxs[idxs[:num_topk]].clone()

            # fix a warning in pytorch 1.9
            pt_idxs = torch.div(
                topk_idxs, num_classes, rounding_mode='floor'
            )
            cls_idxs = torch.fmod(topk_idxs, num_classes)

            # 3. gather predicted offsets
            offsets = offsets_i[pt_idxs]
            pts = pts_i[pt_idxs]

            # 4. compute predicted segments (denorm by stride for output offsets)
            seg_left = pts[:, 0] - offsets[:, 0] * pts[:, 3]
            seg_right = pts[:, 0] + offsets[:, 1] * pts[:, 3]
            pred_segs = torch.stack((seg_left, seg_right), -1)

            # 5. Keep seg with duration > a threshold (relative to feature grids)
            seg_areas = seg_right - seg_left
            keep_idxs2 = seg_areas > self.test_duration_thresh

            # *_all : N (filtered # of segments) x 2 / 1
            segs_all.append(pred_segs[keep_idxs2])
            scores_all.append(pred_prob[keep_idxs2])
            cls_idxs_all.append(cls_idxs[keep_idxs2])

        # cat along the FPN levels (F N_i, C)
        segs_all, scores_all, cls_idxs_all = [
            torch.cat(x) for x in [segs_all, scores_all, cls_idxs_all]
        ]
        results = {'segments': segs_all,
                   'scores': scores_all,
                   'labels': cls_idxs_all}

        return results

    @torch.no_grad()
    def postprocessing(self, results):
        # input : list of dictionary items
        # (1) push to CPU; (2) NMS; (3) convert to actual time stamps
        processed_results = []
        fps = 30
        stride = nframes = 16.043
        for results_per_vid in results:
            # unpack the meta info
            vlen = results_per_vid['duration']
            # 1: unpack the results and move to CPU
            segs = results_per_vid['segments'].detach().cpu()
            scores = results_per_vid['scores'].detach().cpu()
            labels = results_per_vid['labels'].detach().cpu()
            if self.test_nms_method != 'none':
                # 2: batched nms (only implemented on CPU)
                segs, scores, labels = batched_nms(
                    segs, scores, labels,
                    self.test_iou_threshold,
                    self.test_min_score,
                    self.test_max_seg_num,
                    use_soft_nms=(self.test_nms_method == 'soft'),
                    multiclass=self.test_multiclass_nms,
                    sigma=self.test_nms_sigma,
                    voting_thresh=self.test_voting_thresh
                )
            # 3: convert from feature grids to seconds
            if segs.shape[0] > 0:
                # segs_sec = (segs * stride + 0.5 * nframes) / fps
                segs_sec = segs * stride / fps  # use_offset=False
                # truncate all boundaries within [0, duration]
                segs_sec[segs_sec <= 0.0] *= 0.0
                segs_sec[segs_sec >= vlen] = segs_sec[segs_sec >= vlen] * 0.0 + vlen
            # else: # FIXME: don't know why but Flan-T5-L produces segs.shape[0] == 0
            #     segs_sec = torch.zeros((1, 2))
            #     scores = torch.zeros(1)
            #     labels = torch.zeros(1, dtype=torch.int64)
            else:
                # no predictions: keep the empty (0, 2) tensor so the repacking below does not fail
                segs_sec = segs
            # 4: repack the results
            processed_results.append({
                'segments': segs_sec,
                # 'segments_feat': segs,
                'scores': scores,
                'labels': labels
            })

        return processed_results

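# A minimal sketch (illustrative only, not called anywhere in this file) of the
# feature-grid-to-seconds conversion performed in `postprocessing` above, assuming
# the same hard-coded 30 fps and 16.043-frame feature stride:
def _grid_to_seconds_sketch(segs, vlen, fps=30.0, stride=16.043):
    """Convert (N, 2) segment boundaries from feature-grid units to seconds,
    clamped to [0, vlen]; `postprocessing` is the real code path."""
    segs_sec = segs * stride / fps  # use_offset=False, as above
    return segs_sec.clamp(min=0.0, max=float(vlen))
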
class NMSop(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, segs, scores, cls_idxs,
        iou_threshold, min_score, max_num
    ):
        # vanilla nms will not change the score, so we can filter segs first
        is_filtering_by_score = (min_score > 0)
        if is_filtering_by_score:
            valid_mask = scores > min_score
            segs, scores = segs[valid_mask], scores[valid_mask]
            cls_idxs = cls_idxs[valid_mask]
            valid_inds = torch.nonzero(
                valid_mask, as_tuple=False).squeeze(dim=1)

        # nms op; returns inds sorted in descending score order
        inds = nms_1d_cpu.nms(
            segs.contiguous().cpu(),
            scores.contiguous().cpu(),
            iou_threshold=float(iou_threshold))
        # cap by max number
        if max_num > 0:
            inds = inds[:min(max_num, len(inds))]
        # return the sorted segs / scores
        sorted_segs = segs[inds]
        sorted_scores = scores[inds]
        sorted_cls_idxs = cls_idxs[inds]
        return sorted_segs.clone(), sorted_scores.clone(), sorted_cls_idxs.clone()


class SoftNMSop(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, segs, scores, cls_idxs,
        iou_threshold, sigma, min_score, method, max_num
    ):
        # pre-allocate memory for sorted results
        dets = segs.new_empty((segs.size(0), 3), device='cpu')
        # softnms op; returns dets that store the sorted segs / scores
        inds = nms_1d_cpu.softnms(
            segs.cpu(),
            scores.cpu(),
            dets.cpu(),
            iou_threshold=float(iou_threshold),
            sigma=float(sigma),
            min_score=float(min_score),
            method=int(method))
        # cap by max number
        if max_num > 0:
            n_segs = min(len(inds), max_num)
        else:
            n_segs = len(inds)
        sorted_segs = dets[:n_segs, :2]
        sorted_scores = dets[:n_segs, 2]
        sorted_cls_idxs = cls_idxs[inds]
        sorted_cls_idxs = sorted_cls_idxs[:n_segs]
        return sorted_segs.clone(), sorted_scores.clone(), sorted_cls_idxs.clone()

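# A reference sketch (illustrative only; the classes above call the compiled
# `nms_1d_cpu` extension instead) of what plain 1D hard NMS computes: greedily keep
# the highest-scoring segment and drop remaining segments whose temporal IoU with it
# exceeds the threshold.
def _nms_1d_reference(segs, scores, iou_threshold):
    """segs: (N, 2) [start, end] tensor; scores: (N,) tensor. Returns kept indices."""
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        rest = order[1:]
        left = torch.maximum(segs[i, 0], segs[rest, 0])
        right = torch.minimum(segs[i, 1], segs[rest, 1])
        inter = (right - left).clamp(min=0)
        union = (segs[i, 1] - segs[i, 0]) + (segs[rest, 1] - segs[rest, 0]) - inter
        iou = inter / union.clamp(min=1e-8)
        order = rest[iou <= iou_threshold]
    return torch.as_tensor(keep, dtype=torch.long)
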
def seg_voting(nms_segs, all_segs, all_scores, iou_threshold, score_offset=1.5):
    """
    Refine localization results by incorporating nearby segments; this is known as
    bounding box voting in the object detection literature.
    Slightly boosts the performance around iou_threshold.
    """

    # *_segs : N_i x 2, all_scores: N,
    # apply offset
    offset_scores = all_scores + score_offset

    # compute overlap between nms and all segs
    # construct the distance matrix of # N_nms x # N_all
    num_nms_segs, num_all_segs = nms_segs.shape[0], all_segs.shape[0]
    ex_nms_segs = nms_segs[:, None].expand(num_nms_segs, num_all_segs, 2)
    ex_all_segs = all_segs[None, :].expand(num_nms_segs, num_all_segs, 2)

    # compute intersection
    left = torch.maximum(ex_nms_segs[:, :, 0], ex_all_segs[:, :, 0])
    right = torch.minimum(ex_nms_segs[:, :, 1], ex_all_segs[:, :, 1])
    inter = (right - left).clamp(min=0)

    # lens of all segments
    nms_seg_lens = ex_nms_segs[:, :, 1] - ex_nms_segs[:, :, 0]
    all_seg_lens = ex_all_segs[:, :, 1] - ex_all_segs[:, :, 0]

    # iou
    iou = inter / (nms_seg_lens + all_seg_lens - inter)

    # get neighbors (# N_nms x # N_all) / weights
    seg_weights = (iou >= iou_threshold).to(all_scores.dtype) * all_scores[None, :] * iou
    seg_weights /= torch.sum(seg_weights, dim=1, keepdim=True)
    refined_segs = seg_weights @ all_segs

    return refined_segs


def batched_nms(
    segs,
    scores,
    cls_idxs,
    iou_threshold,
    min_score,
    max_seg_num,
    use_soft_nms=True,
    multiclass=True,
    sigma=0.5,
    voting_thresh=0.75,
):
    # Based on the Detectron2 implementation.
    num_segs = segs.shape[0]
    # corner case, no prediction outputs
    if num_segs == 0:
        return torch.zeros([0, 2]),\
               torch.zeros([0,]),\
               torch.zeros([0,], dtype=cls_idxs.dtype)

    if multiclass:
        # multiclass nms: apply nms on each class independently
        new_segs, new_scores, new_cls_idxs = [], [], []
        for class_id in torch.unique(cls_idxs):
            curr_indices = torch.where(cls_idxs == class_id)[0]
            # soft_nms vs nms
            if use_soft_nms:
                sorted_segs, sorted_scores, sorted_cls_idxs = SoftNMSop.apply(
                    segs[curr_indices],
                    scores[curr_indices],
                    cls_idxs[curr_indices],
                    iou_threshold,
                    sigma,
                    min_score,
                    2,
                    max_seg_num
                )
            else:
                sorted_segs, sorted_scores, sorted_cls_idxs = NMSop.apply(
                    segs[curr_indices],
                    scores[curr_indices],
                    cls_idxs[curr_indices],
                    iou_threshold,
                    min_score,
                    max_seg_num
                )
            # disable seg voting for multiclass nms: not enough segs per class

            # fill in the class index
            new_segs.append(sorted_segs)
            new_scores.append(sorted_scores)
            new_cls_idxs.append(sorted_cls_idxs)

        # cat the results
        new_segs = torch.cat(new_segs)
        new_scores = torch.cat(new_scores)
        new_cls_idxs = torch.cat(new_cls_idxs)

    else:
        # class agnostic
        if use_soft_nms:
            new_segs, new_scores, new_cls_idxs = SoftNMSop.apply(
                segs, scores, cls_idxs, iou_threshold,
                sigma, min_score, 2, max_seg_num
            )
        else:
            new_segs, new_scores, new_cls_idxs = NMSop.apply(
                segs, scores, cls_idxs, iou_threshold,
                min_score, max_seg_num
            )
        # seg voting
        if voting_thresh > 0:
            new_segs = seg_voting(
                new_segs,
                segs,
                scores,
                voting_thresh
            )

    # sort based on scores and return
    # truncate the results based on max_seg_num
    _, idxs = new_scores.sort(descending=True)
    max_seg_num = min(max_seg_num, new_segs.shape[0])
    # needed for multiclass NMS
    new_segs = new_segs[idxs[:max_seg_num]]
    new_scores = new_scores[idxs[:max_seg_num]]
    new_cls_idxs = new_cls_idxs[idxs[:max_seg_num]]
    return new_segs, new_scores, new_cls_idxs

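# Usage sketch (illustrative only, with made-up numbers; requires the compiled
# `nms_1d_cpu` extension at call time): class-agnostic hard NMS keeps the
# highest-scoring of two heavily overlapping segments plus the disjoint one, and
# `seg_voting` then refines the kept boundaries with score- and IoU-weighted
# neighbours.
def _batched_nms_usage_sketch():
    segs = torch.tensor([[10., 20.], [11., 21.], [40., 50.]])
    scores = torch.tensor([0.9, 0.6, 0.8])
    labels = torch.zeros(3, dtype=torch.long)
    return batched_nms(
        segs, scores, labels,
        iou_threshold=0.5, min_score=0.001, max_seg_num=100,
        use_soft_nms=False, multiclass=False, voting_thresh=0.75,
    )
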
class FPNIdentity(nn.Module):
    def __init__(
        self,
        in_channels,      # input feature channels, len(in_channels) = # levels
        out_channel,      # output feature channel
        start_level=0,    # start fpn level
        end_level=-1,     # end fpn level
        with_ln=True      # whether to apply layer norm at the end
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channel = out_channel

        self.start_level = start_level
        if end_level == -1:
            self.end_level = len(in_channels)
        else:
            self.end_level = end_level
        assert self.end_level <= len(in_channels)
        assert (self.start_level >= 0) and (self.start_level < self.end_level)

        self.fpn_norms = nn.ModuleList()
        for i in range(self.start_level, self.end_level):
            # check feat dims
            assert self.in_channels[i] == self.out_channel
            # layer norm for order (B C T)
            if with_ln:
                fpn_norm = LayerNorm(out_channel)
            else:
                fpn_norm = nn.Identity()
            self.fpn_norms.append(fpn_norm)

    def forward(self, inputs, fpn_masks):
        # inputs must be a list / tuple
        assert len(inputs) == len(self.in_channels), (len(inputs), len(self.in_channels))
        assert len(fpn_masks) == len(self.in_channels), (len(fpn_masks), len(self.in_channels))

        # apply the norms; fpn_masks pass through unchanged (identity FPN, no convs)
        fpn_feats = tuple()
        new_fpn_masks = tuple()
        for i in range(len(self.fpn_norms)):
            x = self.fpn_norms[i](inputs[i + self.start_level])
            fpn_feats += (x, )
            new_fpn_masks += (fpn_masks[i + self.start_level], )

        return fpn_feats, new_fpn_masks

--------------------------------------------------------------------------------