├── config
│   ├── model
│   │   ├── groundvqa_b.yaml
│   │   └── groundvqa_s.yaml
│   ├── dataset
│   │   └── egovlp_internvideo.yaml
│   └── base.yaml
├── scripts
│   ├── train_groundvqa_s-qaego4d.sh
│   ├── train_groundvqa_b-qaego4d_egotimeqa.sh
│   ├── train_groundvqa_s-qaego4d_egotimeqa.sh
│   ├── train_groundvqa_b-nlq.sh
│   ├── evaluate_groundvqa_b-nlq.sh
│   ├── evaluate_groundvqa_b-nlq_naq.sh
│   ├── evaluate_groundvqa_s-qaego4d.sh
│   ├── evaluate_groundvqa_b-qaego4d_egotimeqa.sh
│   └── evaluate_groundvqa_s-qaego4d_egotimeqa.sh
├── requirements.txt
├── utils
│   ├── generate_open_qa
│   │   ├── merge.py
│   │   └── generate.py
│   └── generate_close_qa
│       └── generate.py
├── LICENSE
├── model
│   └── ours
│       ├── model.py
│       ├── dataset.py
│       ├── lightning_module.py
│       └── nlq_head.py
├── .gitignore
├── README.md
├── run.py
├── eval.py
└── eval_nlq.py
/config/model/groundvqa_b.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.ours.model.GroundVQA
2 | lm_path: google/flan-t5-base
3 | input_dim: ${dataset.feature_dim}
4 | freeze_word: True
5 |
--------------------------------------------------------------------------------
/config/model/groundvqa_s.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.ours.model.GroundVQA
2 | lm_path: google/flan-t5-small
3 | input_dim: ${dataset.feature_dim}
4 | freeze_word: True
5 |
--------------------------------------------------------------------------------
/scripts/train_groundvqa_s-qaego4d.sh:
--------------------------------------------------------------------------------
1 | # train with QaEgo4D
2 | CUDA_VISIBLE_DEVICES=6,7 python run.py \
3 | model=groundvqa_s \
4 | 'dataset.qa_train_splits=[QaEgo4D_train]' \
5 | dataset.batch_size=64 \
6 | trainer.gpus=2
--------------------------------------------------------------------------------
/scripts/train_groundvqa_b-qaego4d_egotimeqa.sh:
--------------------------------------------------------------------------------
1 | # train with QaEgo4D + EgoTimeQA
2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python run.py \
3 | model=groundvqa_b \
4 | 'dataset.qa_train_splits=[QaEgo4D_train,EgoTimeQA]' \
5 | dataset.batch_size=16 \
6 | trainer.gpus=8
7 |
--------------------------------------------------------------------------------
/scripts/train_groundvqa_s-qaego4d_egotimeqa.sh:
--------------------------------------------------------------------------------
1 | # train with QaEgo4D + EgoTimeQA
2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python run.py \
3 | model=groundvqa_s \
4 | 'dataset.qa_train_splits=[QaEgo4D_train,EgoTimeQA]' \
5 | dataset.batch_size=32 \
6 | trainer.gpus=8
7 |
--------------------------------------------------------------------------------
/scripts/train_groundvqa_b-nlq.sh:
--------------------------------------------------------------------------------
1 | # train with NLQv2
2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python run.py \
3 | model=groundvqa_b \
4 | 'dataset.nlq_train_splits=[NLQ_train]' \
5 | 'dataset.test_splits=[NLQ_val]' \
6 | dataset.batch_size=16 \
7 | trainer.find_unused_parameters=True \
8 | trainer.gpus=8
9 |
--------------------------------------------------------------------------------
/scripts/evaluate_groundvqa_b-nlq.sh:
--------------------------------------------------------------------------------
1 | # NLQv2 val set
2 | CUDA_VISIBLE_DEVICES=1 python run.py \
3 | model=groundvqa_b \
4 | 'dataset.qa_train_splits=[QaEgo4D_train]' \
5 | 'dataset.test_splits=[NLQ_val]' \
6 | dataset.batch_size=32 \
7 | +trainer.test_only=True \
8 | '+trainer.checkpoint_path=""' \
9 | trainer.load_nlq_head=True
10 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.1.1
2 | torchvision==0.16.1
3 | tqdm
4 | ipdb
5 | h5py==3.10.0
6 | terminaltables
7 | nltk
8 | rouge-score==0.1.2
9 | numpy==1.23.5
10 | pytorch_lightning==1.5.10
11 | hydra-core==1.3.2
12 | optimum==1.14.1
13 | sentence_transformers==2.2.2
14 | positional_encodings==6.0.1
15 | ffmpeg-python==0.2.0
16 | transformers==4.35.2
17 | accelerate==0.24.1
18 | bitsandbytes
--------------------------------------------------------------------------------
/scripts/evaluate_groundvqa_b-nlq_naq.sh:
--------------------------------------------------------------------------------
1 | # NLQv2 val set
2 | python run.py \
3 | model=groundvqa_b \
4 | 'dataset.qa_train_splits=[QaEgo4D_train]' \
5 | 'dataset.test_splits=[NLQ_val]' \
6 | dataset.batch_size=32 \
7 | +trainer.test_only=True \
8 | '+trainer.checkpoint_path="checkpoints/GroundVQA_B-NLQ_NaQ-finetune_NLQ-VLG-val_R1_03=29.7.ckpt"' \
9 | trainer.load_nlq_head=True
--------------------------------------------------------------------------------
/config/dataset/egovlp_internvideo.yaml:
--------------------------------------------------------------------------------
1 | data_dir: data/unified
2 | nlq_val_anno: data/nlq_v2/nlq_val.json
3 | feature_type: egovlp_internvideo
4 | feature_dim: 2304
5 | max_v_len: 1200
6 |
7 | qa_train_splits: []
8 | nlq_train_splits: []
9 | test_splits: ['QaEgo4D_test', 'QaEgo4D_test_close', 'NLQ_val']
10 | closeqa_weight: 50
11 |
12 | tokenizer_path: google/flan-t5-small
13 |
14 | num_workers: 4
15 | batch_size: 16
16 |
--------------------------------------------------------------------------------
/config/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: egovlp_internvideo
3 | - model: groundvqa_s
4 | - _self_
5 | - override hydra/job_logging: none
6 | - override hydra/hydra_logging: none
7 |
8 | trainer:
9 | detect_anomaly: True
10 | max_epochs: 100
11 | accumulate_grad_batches: 1
12 | auto_resume: False
13 | gpus: 1
14 | log_every_n_steps: 1
15 | auto_lr_find: False
16 | enable_progress_bar: True
17 | monitor_variable: val_ROUGE
18 | monitor_mode: max
19 | find_unused_parameters: False
20 | precision: bf16
21 | val: False # test on the val set
22 | gradient_clip_val: 1.0
23 | save_nlq_results: null
24 | deterministic: True
25 | load_decoder: True
26 | load_nlq_head: True
27 | ignore_existing_checkpoints: True
28 |
29 | optim:
30 | optimizer:
31 | _target_: torch.optim.AdamW
32 | lr: 0.0001
33 | weight_decay: 0.0
34 | freeze: [ ]
35 | lr_scheduler: False
36 |
37 | hydra:
38 | run:
39 | dir: .
40 | output_subdir: null
41 |
--------------------------------------------------------------------------------
/utils/generate_open_qa/merge.py:
--------------------------------------------------------------------------------
1 | import json
2 | import glob
3 | import numpy as np
4 |
5 |
6 | split = 'EgoTimeQA'
7 | src1 = f'tmp/annotations.{split}_*.json'
8 | tgt = f'annotations.{split}.json'
9 | paths = glob.glob(src1)
10 | paths = sorted(paths)
11 | print('Merging')
12 |
13 | merge = []
14 | for p in paths:
15 | print(p)
16 | x = json.load(open(p))
17 | merge += x
18 |
19 | all_duration_sec = [(x['moment_end_frame'] - x['moment_start_frame']) / 30 for x in merge]
20 | mean_duration_sec = np.asarray(all_duration_sec).mean()
21 | # normalize each window: keep its center but divide its duration by the dataset-mean duration
22 | for x in merge:
23 | start_sec = x['moment_start_frame'] / 30
24 | end_sec = x['moment_end_frame'] / 30
25 | center_sec = (start_sec + end_sec) / 2
26 | duration_sec = (end_sec - start_sec) / mean_duration_sec
27 | x['moment_start_frame'] = (center_sec - duration_sec / 2) * 30
28 | x['moment_end_frame'] = (center_sec + duration_sec / 2) * 30
29 |
30 | print(f'into {tgt}')
31 | with open(tgt, 'w') as f:
32 | json.dump(merge, f)
33 |
34 | print(len(merge))
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Bright
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/scripts/evaluate_groundvqa_s-qaego4d.sh:
--------------------------------------------------------------------------------
1 | # QAEgo4D-Close test set
2 | for SEED in 0 1111 2222 3333 4444
3 | do
4 | python run.py \
5 | model=groundvqa_s \
6 | 'dataset.qa_train_splits=[QaEgo4D_train]' \
7 | 'dataset.test_splits=[QaEgo4D_test_close]' \
8 | dataset.batch_size=64 \
9 | +trainer.test_only=True \
10 | '+trainer.checkpoint_path="checkpoints/GroundVQA_S-QaEgo4D-COV-test_ROUGE=29.0.ckpt"' \
11 | +trainer.random_seed=$SEED
12 | done
13 |
14 | # QAEgo4D-Open test set
15 | python run.py \
16 | model=groundvqa_s \
17 | 'dataset.qa_train_splits=[QaEgo4D_train]' \
18 | 'dataset.test_splits=[QaEgo4D_test]' \
19 | dataset.batch_size=64 \
20 | +trainer.test_only=True \
21 | '+trainer.checkpoint_path="checkpoints/GroundVQA_S-QaEgo4D-COV-test_ROUGE=29.0.ckpt"'
22 |
23 | # NLQv2 val set
24 | python run.py \
25 | model=groundvqa_s \
26 | 'dataset.qa_train_splits=[QaEgo4D_train]' \
27 | 'dataset.test_splits=[NLQ_val]' \
28 | dataset.batch_size=64 \
29 | +trainer.test_only=True \
30 | '+trainer.checkpoint_path="checkpoints/GroundVQA_S-QaEgo4D-COV-val_R1_03=11.0.ckpt"' \
31 | trainer.load_nlq_head=True
32 |
--------------------------------------------------------------------------------
/scripts/evaluate_groundvqa_b-qaego4d_egotimeqa.sh:
--------------------------------------------------------------------------------
1 | # QAEgo4D-Close test set
2 | for SEED in 0 1111 2222 3333 4444
3 | do
4 | python run.py \
5 | model=groundvqa_b \
6 | 'dataset.qa_train_splits=[QaEgo4D_train]' \
7 | 'dataset.test_splits=[QaEgo4D_test_close]' \
8 | dataset.batch_size=32 \
9 | +trainer.test_only=True \
10 | '+trainer.checkpoint_path="checkpoints/GroundVQA_B-QaEgo4D_EgoTimeQA-COV-test_ROUGE=30.4.ckpt"' \
11 | +trainer.random_seed=$SEED
12 | done
13 |
14 | # QAEgo4D-Open test set
15 | python run.py \
16 | model=groundvqa_b \
17 | 'dataset.qa_train_splits=[QaEgo4D_train]' \
18 | 'dataset.test_splits=[QaEgo4D_test]' \
19 | dataset.batch_size=32 \
20 | +trainer.test_only=True \
21 | '+trainer.checkpoint_path="checkpoints/GroundVQA_B-QaEgo4D_EgoTimeQA-COV-test_ROUGE=30.4.ckpt"'
22 |
23 | # NLQv2 val set
24 | python run.py \
25 | model=groundvqa_b \
26 | 'dataset.qa_train_splits=[QaEgo4D_train]' \
27 | 'dataset.test_splits=[NLQ_val]' \
28 | dataset.batch_size=32 \
29 | +trainer.test_only=True \
30 | '+trainer.checkpoint_path="checkpoints/GroundVQA_B-QaEgo4D_EgoTimeQA-COV-val_R1_03=25.6.ckpt"' \
31 | trainer.load_nlq_head=True
--------------------------------------------------------------------------------
/scripts/evaluate_groundvqa_s-qaego4d_egotimeqa.sh:
--------------------------------------------------------------------------------
1 | # QAEgo4D-Close test set
2 | for SEED in 0 1111 2222 3333 4444
3 | do
4 | python run.py \
5 | model=groundvqa_s \
6 | 'dataset.qa_train_splits=[QaEgo4D_train]' \
7 | 'dataset.test_splits=[QaEgo4D_test_close]' \
8 | dataset.batch_size=64 \
9 | +trainer.test_only=True \
10 | '+trainer.checkpoint_path="checkpoints/GroundVQA_S-QaEgo4D_EgoTimeQA-COV-test_ROUGE=30.2.ckpt"' \
11 | +trainer.random_seed=$SEED
12 | done
13 |
14 | # QAEgo4D-Open test set
15 | python run.py \
16 | model=groundvqa_s \
17 | 'dataset.qa_train_splits=[QaEgo4D_train]' \
18 | 'dataset.test_splits=[QaEgo4D_test]' \
19 | dataset.batch_size=64 \
20 | +trainer.test_only=True \
21 | '+trainer.checkpoint_path="checkpoints/GroundVQA_S-QaEgo4D_EgoTimeQA-COV-test_ROUGE=30.2.ckpt"'
22 |
23 | # NLQv2 val set
24 | python run.py \
25 | model=groundvqa_s \
26 | 'dataset.qa_train_splits=[QaEgo4D_train]' \
27 | 'dataset.test_splits=[NLQ_val]' \
28 | dataset.batch_size=64 \
29 | +trainer.test_only=True \
30 | '+trainer.checkpoint_path="checkpoints/GroundVQA_S-QaEgo4D_EgoTimeQA-COV-val_R1_03=23.3.ckpt"' \
31 | trainer.load_nlq_head=True
--------------------------------------------------------------------------------
/model/ours/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import PreTrainedModel, AutoModelForSeq2SeqLM
4 | from transformers.modeling_outputs import BaseModelOutput
5 |
6 | from model.ours.nlq_head import NLQHead
7 |
8 |
9 | class GroundVQA(nn.Module):
10 | def __init__(self, lm_path, input_dim, freeze_word=False, max_v_len=256):
11 | super().__init__()
12 |
13 | if not isinstance(input_dim, int):
14 | input_dim = input_dim.v_dim
15 |
16 |         self.lm: PreTrainedModel = AutoModelForSeq2SeqLM.from_pretrained(lm_path, local_files_only=True)  # expects the flan-t5 weights to be cached locally
17 |
18 | lm_dim = self.lm.get_input_embeddings().embedding_dim
19 | self.lm_proj = nn.Linear(input_dim, lm_dim)
20 | self.v_emb = nn.Parameter(torch.randn((1, 1, lm_dim)))
21 | if freeze_word:
22 | for name, param in self.lm.named_parameters():
23 | if 'shared' in name:
24 | param.requires_grad = False
25 |
26 | self.nlq_head = NLQHead(in_dim=lm_dim, max_v_len=max_v_len)
27 |
28 | def forward(self, v_feat, v_mask, q_token, q_mask, gt_segments, gt_labels,
29 | labels=None, **remains):
30 | # encoder
31 | encoder_out, mask = self.forward_encoder(v_feat, v_mask, q_token, q_mask)
32 |
33 | # localizer
34 | encoder_out_v = encoder_out[:, -v_feat.shape[1]:]
35 | nlq_results = self.nlq_head(
36 | feat=encoder_out_v.permute(0, 2, 1), # (B, D, T)
37 | mask=v_mask.unsqueeze(1), # (B, 1, T)
38 | gt_segments=gt_segments,
39 | gt_labels=gt_labels
40 | )
41 | time_loss = nlq_results['final_loss'] * 1.0
42 |
43 | # decoder
44 | outputs = self.lm(
45 | encoder_outputs=(encoder_out,),
46 | attention_mask=mask,
47 | labels=labels,
48 | )
49 | lm_loss = outputs.loss
50 |
51 | total_loss = 0.5 * time_loss + 0.5 * lm_loss
52 |
53 | return total_loss, lm_loss, time_loss
54 |
55 | def generate(self, v_feat, v_mask, q_token, q_mask, v_len, **remains):
56 | encoder_out, mask = self.forward_encoder(v_feat, v_mask, q_token, q_mask)
57 | encoder_out_v = encoder_out[:, -v_feat.shape[1]:]
58 |
59 | nlq_results = self.nlq_head(
60 | feat=encoder_out_v.permute(0, 2, 1), # (B, D, T)
61 | mask=v_mask.unsqueeze(1), # (B, 1, T)
62 | training=False,
63 | v_lens=v_len
64 | )
65 | answer_tokens = self.lm.generate(
66 | encoder_outputs=BaseModelOutput(last_hidden_state=encoder_out),
67 | attention_mask=mask,
68 | max_new_tokens=32
69 | )
70 |
71 | return nlq_results, answer_tokens
72 |
73 | def forward_encoder(self, v_feat, v_mask, q_token, q_mask):
74 | B, L, D = v_feat.shape
75 | v_feat = self.lm_proj(v_feat)
76 | v_feat = v_feat + self.v_emb.expand((B, L, -1))
77 | q_feat = self.lm.encoder.embed_tokens(q_token)
78 | lm_input = torch.cat([q_feat, v_feat], dim=1)
79 | lm_mask = torch.cat([q_mask, v_mask], dim=1)
80 | out = self.lm.encoder(
81 | inputs_embeds=lm_input,
82 | attention_mask=lm_mask
83 | )
84 | return out.last_hidden_state, lm_mask
85 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom
2 | data/
3 | lightning_logs
4 | utils/generate_open_qa/tmp
5 | utils/generate_close_qa/tmp
6 | *.pkl
7 | *.json
8 |
9 | # Byte-compiled / optimized / DLL files
10 | __pycache__/
11 | *.py[cod]
12 | *$py.class
13 |
14 | # C extensions
15 | *.so
16 |
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 | cover/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 | db.sqlite3
70 | db.sqlite3-journal
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | .pybuilder/
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | # For a library or package, you might want to ignore these files since the code is
95 | # intended to run in multiple environments; otherwise, check them in:
96 | # .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # poetry
106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
107 | # This is especially recommended for binary packages to ensure reproducibility, and is more
108 | # commonly ignored for libraries.
109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
110 | #poetry.lock
111 |
112 | # pdm
113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
114 | #pdm.lock
115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
116 | # in version control.
117 | # https://pdm.fming.dev/#use-with-ide
118 | .pdm.toml
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
--------------------------------------------------------------------------------
/utils/generate_close_qa/generate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from ast import literal_eval
3 |
4 | from transformers import AutoTokenizer, AutoModelForCausalLM
5 | import transformers
6 | import torch
7 | from tqdm import tqdm
8 | import json
9 |
10 |
11 | token = '' # your access token to Llama2
12 |
13 | model_id = 'meta-llama/Llama-2-13b-chat-hf'
14 | batch_size = 4
15 |
16 | # fp16
17 | tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=token)
18 | pipeline = transformers.pipeline(
19 | "text-generation",
20 | model=model_id,
21 | torch_dtype=torch.float16,
22 | device_map="auto",
23 | tokenizer=tokenizer,
24 | token=token
25 | )
26 | # fix a bug: https://discuss.huggingface.co/t/llama2-pad-token-for-batched-inference/48020
27 | pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token
28 | pipeline.model.config.pad_token_id = pipeline.model.config.eos_token_id
29 |
30 |
31 | class QADataset(torch.utils.data.Dataset):
32 | def __init__(self, qa_annotations, start, end):
33 | self.qa_annotations = qa_annotations
34 | if start is not None and end is not None:
35 | self.qa_annotations = self.qa_annotations[start:end]
36 |
37 | print('preparing LM prompts...')
38 | self.lm_inputs = []
39 | for data in tqdm(self.qa_annotations):
40 | self.lm_inputs.append(self._get_lm_input(data['question'], data['answer']))
41 |
42 | def _get_lm_input(self, question, answer):
43 |         return f"""[INST] <<SYS>>
44 | I'll provide a question and its correct answer. Generate three plausible, but incorrect, answers that closely resemble the correct one. Make it challenging to identify the right answer.
45 | <</SYS>>
46 |
47 | No preamble, get right to the three wrong answers and present them in a list format. Question: How many frying pans can i see on the shelf? Correct Answer: two pieces. Wrong Answers: [/INST] [\"one piece\", \"three piece\", \"five pieces\"]
48 | [INST] No preamble, get right to the three wrong answers and present them in a list format. Question: What colour bowl did i carry from the plate stand? Correct Answer: green. Wrong Answers: [/INST] [\"blue\", \"black\", \"white\"]
49 | [INST] No preamble, get right to the three wrong answers and present them in a list format. Question: What did i pour in the bowl? Correct Answer: boiling water. Wrong Answers: [/INST] [\"hot oil\", \"steamed milk\", \"warm broth\"]
50 | [INST] No preamble, get right to the three wrong answers and present them in a list format. Question: {question} Correct Answer: {answer}. Wrong Answers: [/INST]
51 |
52 | """
53 |
54 | def __len__(self):
55 | return len(self.qa_annotations)
56 |
57 | def __getitem__(self, idx):
58 | return self.lm_inputs[idx]
59 |
60 | def get_data(self, idx):
61 | return self.qa_annotations[idx]
62 |
63 |
64 | if __name__ == "__main__":
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument('--start', type=int, default=None)
67 | parser.add_argument('--end', type=int, default=None)
68 | args = parser.parse_args()
69 |
70 | with open('annotations.EgoTimeQA.json', 'r') as f:
71 | qa_annotations = json.load(f)
72 |
73 | dataset = QADataset(qa_annotations, args.start, args.end)
74 |
75 | errors = 0
76 | pbar = tqdm(total=len(dataset))
77 | res = []
78 | for idx, out in enumerate(pipeline(
79 | dataset,
80 | batch_size=64,
81 | do_sample=True,
82 | temperature=0.5,
83 | top_k=10,
84 | num_return_sequences=1,
85 | eos_token_id=tokenizer.eos_token_id,
86 | max_new_tokens=64,
87 | return_full_text=False,
88 | )):
89 | pbar.set_description('Errors: %d' % errors)
90 | pbar.update(1)
91 | gen_result = out[0]['generated_text']
92 |
93 | try: # may not generate the desired format
94 | wrong_answers = literal_eval(gen_result)
95 | assert isinstance(wrong_answers, list) and len(wrong_answers) == 3
96 | data = dataset.get_data(idx)
97 | data['wrong_answers'] = wrong_answers
98 | res.append(data)
99 | success = True
100 | except:
101 | errors += 1
102 | print(gen_result)
103 |
104 | print(f'#{len(res)} / {len(dataset)} samples generated!')
105 |
106 | if args.start is not None and args.end is not None:
107 | with open(f'tmp/annotations.EgoTimeQA_{args.start}_{args.end}.json', 'w') as f:
108 | json.dump(res, f)
109 | else:
110 | with open('annotations.EgoTimeQA.json', 'w') as f:
111 | json.dump(res, f)
112 |
--------------------------------------------------------------------------------
/utils/generate_open_qa/generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from ast import literal_eval
4 |
5 | from transformers import AutoTokenizer, AutoModelForCausalLM
6 | import transformers
7 | import torch
8 | import numpy as np
9 | from tqdm import tqdm
10 | import json
11 | import pickle
12 |
13 |
14 | token = '' # your access token to Llama2
15 |
16 | model_id = 'meta-llama/Llama-2-13b-chat-hf'
17 | model_id = "/root/.cache/huggingface/models--meta-llama--Llama-2-13b-chat-hf"
18 | batch_size = 4
19 |
20 | # fp16
21 | tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=token)
22 | pipeline = transformers.pipeline(
23 | "text-generation",
24 | model=model_id,
25 | torch_dtype=torch.float16,
26 | device_map="auto",
27 | tokenizer=tokenizer,
28 | token=token
29 | )
30 | # fix a bug: https://discuss.huggingface.co/t/llama2-pad-token-for-batched-inference/48020
31 | pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token
32 | pipeline.model.config.pad_token_id = pipeline.model.config.eos_token_id
33 |
34 |
35 | class NarrationDataset(torch.utils.data.Dataset):
36 | def __init__(self, clip_narrations):
37 |         self.prompt = """[INST] <<SYS>>
38 | You are an AI Assistant and always write the output of your response in JSON. I will provide you with a series of narrations that depict my behavior. You should generate one QA pair based on the narrations in the format of {\"Q\": <question>, \"A\": <answer>}. In the narrations, \"C\" represents me, and \"O\" represents someone else. Use as much information as possible from narrations to generate the question, and the question you generate should be able to be answered using the information provided in the narrations. The question should be in the past tense. The question should be within 10 words, and the answer should be within 5 words.
39 | <</SYS>>
40 |
41 | C pours hot water from the frying pan in his left hand into the bowl in his right hand. [/INST] {\"Q\": \"What did I pour in the bowl?\", \"A\": \"boiling water\"}
42 | [INST] C searches through the cabinet. C closes the cabinet. C picks the tin from the cabinet. C places the tin on the counter. [/INST] {\"Q\": \"Where was the tin before I took it?\", \"A\": \"at the cabinet\"}
43 | [INST] C turns on sink knob. C washes the cucumber on the sink. C turns off sink knob. [/INST] {\"Q\": \"Did I wash the cucumber?\", \"A\": \"yes\"}
44 | """
45 | self.narrations = self._prepare(clip_narrations)
46 |
47 | def _prepare(self, clip_narrations):
48 | sampled_narrations = []
49 | for c in clip_narrations:
50 | clip_uid = c['clip_uid']
51 | narration_pass = c['narration_pass']
52 | narrations = c['narrations']
53 |
54 | idx = 0
55 | while idx < len(narrations):
56 | sampled = self._sample_narrations(narrations, start=idx)
57 | start_sec = sampled[0]['timestamps'][0]
58 | end_sec = sampled[-1]['timestamps'][1]
59 | narration_texts = ' '.join([n['narration_text'] for n in sampled])
60 | lm_input = self.prompt + "[INST] " + narration_texts + "[/INST]\n"
61 | sampled_narrations.append({
62 | 'clip_uid': clip_uid,
63 | 'narration_pass': narration_pass,
64 | 'narrations': narration_texts,
65 | 'start_sec': start_sec,
66 | 'end_sec': end_sec,
67 | 'lm_input': lm_input,
68 | 'n_narration': len(sampled)
69 | })
70 | idx += len(sampled)
71 |
72 | return sampled_narrations
73 |
74 | def _sample_narrations(self, narrations, start, max_n=5, max_timespan=30):
75 | end = min(len(narrations), start+max_n)
76 | while start < end - 1 and max_timespan <= narrations[end-1]['timestamps'][1] - narrations[start]['timestamps'][0]:
77 | end -= 1
78 | end = np.random.choice(np.arange(start, end), 1)[0] + 1
79 | return narrations[start:end]
80 |
81 | def __len__(self):
82 | return len(self.narrations)
83 |
84 | def __getitem__(self, idx):
85 | return self.narrations[idx]['lm_input']
86 |
87 | def get_clip_uid(self, idx):
88 | return self.narrations[idx]['clip_uid']
89 |
90 | def get_start_sec(self, idx):
91 | return self.narrations[idx]['start_sec']
92 |
93 | def get_end_sec(self, idx):
94 | return self.narrations[idx]['end_sec']
95 |
96 | def get_narrations(self, idx):
97 | return self.narrations[idx]['narrations']
98 |
99 |
100 | if __name__ == "__main__":
101 | parser = argparse.ArgumentParser()
102 | parser.add_argument('-start', type=int, default=None)
103 | parser.add_argument('-end', type=int, default=None)
104 | args = parser.parse_args()
105 |
106 | with open('em_train_narrations.pkl', 'rb') as f:
107 | clip_narrations = pickle.load(f)
108 | print(len(clip_narrations))
109 | os.makedirs('tmp', exist_ok=True)
110 |
111 | if args.start is not None and args.end is not None:
112 | clip_narrations = clip_narrations[args.start:args.end]
113 | save_path = f'tmp/annotations.EgoTimeQA_{args.start}_{args.end}.json'
114 | else:
115 | save_path = 'annotations.EgoTimeQA.json'
116 |
117 | dataset = NarrationDataset(clip_narrations)
118 |
119 | errors = 0
120 | pbar = tqdm(total=len(dataset))
121 | res = []
122 | for idx, out in enumerate(pipeline(
123 | dataset,
124 | batch_size=32,
125 | do_sample=True,
126 | temperature=0.5,
127 | top_k=10,
128 | num_return_sequences=1,
129 | eos_token_id=tokenizer.eos_token_id,
130 | max_new_tokens=64,
131 | return_full_text=False,
132 | )):
133 | pbar.set_description(f'Errors: {errors}')
134 | pbar.update(1)
135 | gen_result = out[0]['generated_text']
136 | try: # may not generate in JSON format
137 | qa = literal_eval(gen_result)
138 | question = qa['Q']
139 | answer = qa['A']
140 | start_sec = dataset.get_start_sec(idx)
141 | end_sec = dataset.get_end_sec(idx)
142 | start_frame = int(start_sec * 30)
143 | end_frame = int(end_sec * 30)
144 | res.append({
145 | 'video_id': dataset.get_clip_uid(idx),
146 | 'sample_id': None,
147 | 'answer': answer,
148 | 'question': question,
149 | 'start_sec': start_sec,
150 | 'end_sec': end_sec,
151 | "moment_start_frame": start_frame,
152 | "moment_end_frame": end_frame,
153 | 'narrations': dataset.get_narrations(idx)
154 | })
155 | except:
156 | errors += 1
157 | continue
158 |
159 | with open(save_path, 'w') as f:
160 | json.dump(res, f)
161 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GroundVQA
2 |
3 | Official PyTorch code of "Grounded Question-Answering in Long Egocentric Videos", *CVPR* 2024.
4 |
5 | [[Project page]](https://dszdsz.cn/GroundVQA/index.html) [[Paper]](https://arxiv.org/abs/2312.06505)
6 |
7 | ## News
8 |
9 | [Apr 2024] We have updated the CloseQA benchmark results, now available on [arXiv [v4]](https://arxiv.org/pdf/2312.06505).
10 | 
11 | [Feb 2024] We updated the CloseQA test set with rigorous human verification. The benchmark results will be updated in our paper shortly.
12 |
13 | ## Abstract
14 |
15 | Existing approaches to video understanding, mainly designed for short videos from a third-person perspective, are limited in their applicability in certain fields, such as robotics. In this paper, we delve into **open-ended question-answering (QA) in long, egocentric videos**, which allows individuals or robots to inquire about their own past visual experiences.
16 |
17 | This task presents **unique challenges**, including the complexity of temporally grounding queries within extensive video content, the high resource demands for precise data annotation, and the inherent difficulty of evaluating open-ended answers due to their ambiguous nature.
18 |
19 | Our proposed approach tackles these challenges by
20 |
21 | - **GroundVQA**: integrating query grounding and answering within a unified model to reduce error propagation;
22 | - **EgoTimeQA**: employing large language models for efficient and scalable data synthesis;
23 | - **QaEgo4D**$`_\texttt{close}`$: introducing a close-ended QA task for evaluation, to manage answer ambiguity.
24 |
25 | Extensive experiments demonstrate the effectiveness of our method, which also achieves state-of-the-art performance on the QaEgo4D and Ego4D-NLQ benchmarks.
26 |
27 | ## Directory Structure
28 |
29 | ```
30 | .
31 | |-- checkpoints provided model checkpoints
32 | |-- config configs of models and datasets
33 | |-- data processed dataset and video features
34 | |-- eval.py code for evaluating QaEgo4D performance
35 | |-- eval_nlq.py code for evaluating NLQ performance
36 | |-- model code for model, dataset, and training
37 | |-- requirements.txt list of packages for building the Python environment
38 | |-- run.py entry code
39 | |-- scripts scripts for training and evaluation
40 | `-- utils code for generating OpenQA and CloseQA data from Ego4D narrations
41 | ```
42 |
43 | ## Preparation
44 |
45 | Our setup: Ubuntu 20.04, CUDA 12.2, 8x Nvidia A100 (80GB)
46 |
47 | - Clone this repo: `git clone https://github.com/Becomebright/GroundVQA.git`
48 | - Create the conda environment: `conda create -n groundvqa python=3.9 -y && conda activate groundvqa`
49 | - Install packages: `pip install -r requirements.txt`
50 | - Compile `nms_1d_cpu` following [here](https://github.com/happyharrycn/actionformer_release/blob/main/INSTALL.md)
51 | - Download the data, video features, and model checkpoints from [Huggingface](https://huggingface.co/Becomebright/GroundVQA)
52 | - **data:** unzip `data.zip` under the project's root directory.
53 | - **video features:** merge the split files with `cat egovlp_internvideoa* > egovlp_internvideo.hdf5` and put the result under `data/unified/` (a quick sanity check is sketched below).
54 | - **model checkpoints**: put them under `checkpoints/`
55 |
56 | | Model | Data | Task | NLQ$`_\texttt{v2}`$ | QaEgo4D | Cost$`^{*}`$ |
57 | | ------------------------------- | ------------------------------------------------------------ | ---------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ----------------- |
58 | | $`\text{GroundVQA}_\texttt{S}`$ | QaEgo4D | CloseQA+OpenQA+VLG | [[val_R1_03=11.0]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_S-QaEgo4D-COV-val_R1_03%3D11.0.ckpt) | [[test_ROUGE=29.0]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_S-QaEgo4D-COV-test_ROUGE%3D29.0.ckpt) | 7 |
59 | | $`\text{GroundVQA}_\texttt{S}`$ | QaEgo4D+EgoTimeQA | CloseQA+OpenQA+VLG | [[val_R1_03=23.3]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_S-QaEgo4D_EgoTimeQA-COV-val_R1_03%3D23.3.ckpt) | [[test_ROUGE=30.2]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_S-QaEgo4D_EgoTimeQA-COV-test_ROUGE%3D30.2.ckpt) | 150 |
60 | | $`\text{GroundVQA}_\texttt{B}`$ | QaEgo4D+EgoTimeQA | CloseQA+OpenQA+VLG | [[val_R1_03=25.6]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_B-QaEgo4D_EgoTimeQA-COV-val_R1_03%3D25.6.ckpt) | [[test_ROUGE=30.4]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_B-QaEgo4D_EgoTimeQA-COV-test_ROUGE%3D30.4.ckpt) | 350 |
61 | | $`\text{GroundVQA}_\texttt{B}`$ | NLQ$`_\texttt{v2}`$+NaQ $\rightarrow$ NLQ$`_\texttt{v2}`$$`^{**}`$ | VLG | [[val_R1_03=29.7]](https://huggingface.co/Becomebright/GroundVQA/blob/main/GroundVQA_B-NLQ_NaQ-finetune_NLQ-VLG-val_R1_03%3D29.7.ckpt) | - | 700 |
62 |
63 | \* Training cost measured in GPU hours.
64 |
65 | ** Pre-trained on NLQ$`_\texttt{v2}`$ and NaQ, and further fine-tuned on NLQ$`_\texttt{v2}`$.
66 |
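Before training, you may want to verify the merged feature file. Below is a minimal sanity-check sketch (assuming `h5py` from `requirements.txt` and the default `data/unified/` layout that `model/ours/dataset.py` expects; the snippet itself is illustrative, not part of the repo):

```python
import h5py

# Path and feature_dim follow config/dataset/egovlp_internvideo.yaml
with h5py.File('data/unified/egovlp_internvideo.hdf5', 'r') as f:
    clip_uids = list(f.keys())
    print(f'{len(clip_uids)} clips in the merged feature file')
    first = f[clip_uids[0]][:]
    # dataset.py indexes this file by video_id and expects (num_features, 2304) arrays
    print(clip_uids[0], first.shape)
    assert first.shape[-1] == 2304, 'feature_dim does not match the dataset config'
```
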
67 | ## Training
68 |
69 | ```bash
70 | # train GroundVQA_S on QaEgo4D
71 | bash scripts/train_groundvqa_s-qaego4d.sh
72 | 
73 | # train GroundVQA_S on QaEgo4D and EgoTimeQA
74 | bash scripts/train_groundvqa_s-qaego4d_egotimeqa.sh
75 | 
76 | # train GroundVQA_B on QaEgo4D and EgoTimeQA
77 | bash scripts/train_groundvqa_b-qaego4d_egotimeqa.sh
78 | ```
79 |
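The scripts above assume 2 to 8 GPUs. With fewer GPUs, one option (a sketch, not an officially provided script) is to lower `trainer.gpus` and compensate with `trainer.accumulate_grad_batches`; both keys exist in `config/base.yaml`. Memory permitting:

```bash
# Illustrative single-GPU variant of scripts/train_groundvqa_s-qaego4d.sh
CUDA_VISIBLE_DEVICES=0 python run.py \
    model=groundvqa_s \
    'dataset.qa_train_splits=[QaEgo4D_train]' \
    dataset.batch_size=64 \
    trainer.gpus=1 \
    trainer.accumulate_grad_batches=2  # roughly matches the effective batch of the 2-GPU script
```
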
80 | ## Evaluation
81 |
82 | ```bash
83 | # evaluate GroundVQA_S trained on QaEgo4D
84 | bash scripts/evaluate_groundvqa_s-qaego4d.sh
85 | 
86 | # evaluate GroundVQA_S trained on QaEgo4D and EgoTimeQA
87 | bash scripts/evaluate_groundvqa_s-qaego4d_egotimeqa.sh
88 | 
89 | # evaluate GroundVQA_B trained on QaEgo4D and EgoTimeQA
90 | bash scripts/evaluate_groundvqa_b-qaego4d_egotimeqa.sh
91 | 
92 | # evaluate GroundVQA_B pre-trained on NLQv2 and NaQ, then fine-tuned on NLQv2
93 | bash scripts/evaluate_groundvqa_b-nlq_naq.sh
94 | ```
95 |
96 | ## Generate OpenQA data
97 |
98 | Download the processed Ego4D narrations [[em_train_narrations.pkl]](https://huggingface.co/Becomebright/GroundVQA/blob/main/em_train_narrations.pkl)
99 |
100 | Put it under `utils/generate_open_qa/`
101 |
102 | Generate QAs in parallel on multiple GPUs (*e.g.*, 2)
103 |
104 | ```bash
105 | cd utils/generate_open_qa
106 |
107 | # GPU-0
108 | CUDA_VISIBLE_DEVICES=0 python generate.py -start 0 -end 5000
109 |
110 | # GPU-1
111 | CUDA_VISIBLE_DEVICES=1 python generate.py -start 5000 -end 11000 # 10777 clips in total
112 | ```
113 |
114 | Merge the results and normalize the duration of temporal windows
115 |
116 | ```bash
117 | python merge.py
118 | ```
119 |
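For reference, the normalization in `merge.py` keeps each window's center fixed and divides its duration by the dataset-mean duration, so the average window ends up one second long. A worked example with illustrative numbers:

```python
# Illustrative numbers only; merge.py applies this to every generated QA
mean_duration_sec = 10.0                                       # mean over all generated windows
start_sec, end_sec = 12.0, 16.0                                # original 4-second window
center_sec = (start_sec + end_sec) / 2                         # 14.0
norm_duration_sec = (end_sec - start_sec) / mean_duration_sec  # 0.4
new_start_frame = (center_sec - norm_duration_sec / 2) * 30    # 414.0 (at 30 fps)
new_end_frame = (center_sec + norm_duration_sec / 2) * 30      # 426.0
```
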
120 | ## Generate CloseQA data
121 |
122 | ```bash
123 | cd utils/generate_close_qa
124 | python generate.py
125 | ```
126 |
127 | The above script produces wrong answers for EgoTimeQA using a single GPU.
128 | 
129 | You can also run the generation on multiple GPUs (see the sketch below) or generate wrong answers for QaEgo4D.
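
For example, splitting the run across two GPUs could look like the sketch below; the index ranges are placeholders, and the partial files written to `tmp/` would still need to be merged afterwards:

```bash
cd utils/generate_close_qa
mkdir -p tmp  # partial results are written here when --start/--end are given

# GPU-0 (adjust the ranges to the size of your annotation file)
CUDA_VISIBLE_DEVICES=0 python generate.py --start 0 --end 50000

# GPU-1
CUDA_VISIBLE_DEVICES=1 python generate.py --start 50000 --end 100000
```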
130 |
131 | ## Citation
132 |
133 | ```latex
134 | @inproceedings{di2023groundvqa,
135 | title={Grounded Question-Answering in Long Egocentric Videos},
136 | author={Di, Shangzhe and Xie, Weidi},
137 | booktitle={CVPR},
138 | year={2024}
139 | }
140 | ```
141 |
142 | ## Acknowledgements
143 |
144 | Our code is based on [QaEgo4D](https://github.com/lbaermann/qaego4d), [GroundNLQ](https://github.com/houzhijian/GroundNLQ), and [ActionFormer](https://github.com/happyharrycn/actionformer_release).
145 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import re
2 | import math
3 | from argparse import ArgumentParser, Namespace
4 |
5 | import hydra
6 | import torch
7 | import pytorch_lightning as pl
8 | from omegaconf import DictConfig, open_dict
9 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor, StochasticWeightAveraging
10 | from pytorch_lightning.plugins import DDPPlugin
11 |
12 | from model.ours.dataset import JointDataModule
13 | from model.ours.lightning_module import LightningModule
14 |
15 |
16 | def dict_parser(s: str):
17 | return eval('{' + re.sub(r'(\w+)=(["\']?\w+["\']?)', r'"\1":\2', s) + '}')
18 |
19 | def add_common_trainer_util_args(parser, default_monitor_variable='val_loss', default_monitor_mode='min'):
20 | if default_monitor_mode not in ['min', 'max']:
21 | raise ValueError(default_monitor_mode)
22 | parser.add_argument('--lr_find_kwargs', default=dict(min_lr=5e-6, max_lr=1e-2), type=dict_parser,
23 | help='Arguments for LR find (--auto_lr_find). Default "min_lr=5e-6,max_lr=1e-2"')
24 | parser.add_argument('--random_seed', default=42, type=lambda s: None if s == 'None' else int(s),
25 | help='Seed everything. Set to "None" to disable global seeding')
26 | parser.add_argument('--auto_resume', default=False, action='store_true',
27 | help='Automatically resume last saved checkpoint, if available.')
28 | parser.add_argument('--test_only', default=False, action='store_true',
29 | help='Skip fit and call only test. This implies automatically detecting newest checkpoint, '
30 | 'if --checkpoint_path is not given.')
31 | parser.add_argument('--checkpoint_path', default=None, type=str,
32 | help='Load this checkpoint to resume training or run testing. '
33 | 'Pass in the special value "best" to use the best checkpoint according to '
34 | 'args.monitor_variable and args.monitor_mode. '
35 | 'Using "best" only works with test_only mode.')
36 | parser.add_argument('--ignore_existing_checkpoints', default=False, action='store_true',
37 |                         help='Proceed with training a new model even if previous checkpoints exist.')
38 | parser.add_argument('--monitor_variable', default=default_monitor_variable, type=str,
39 | help='Variable to monitor for early stopping and for checkpoint selection. '
40 | f'Default: {default_monitor_variable}')
41 | parser.add_argument('--monitor_mode', default=default_monitor_mode, type=str, choices=['min', 'max'],
42 | help='Mode for monitoring the monitor_variable (for early stopping and checkpoint selection). '
43 | f'Default: {default_monitor_mode}')
44 | parser.add_argument('--reset_early_stopping_criterion', default=False, action='store_true',
45 | help='Reset the early stopping criterion when loading from checkpoint. '
46 | 'Prevents immediate exit after switching to more complex dataset in curriculum strategy')
47 |
48 | def apply_argparse_defaults_to_hydra_config(config: DictConfig, parser: ArgumentParser, verbose=False):
49 | args = parser.parse_args([]) # Parser is not allowed to have required args, otherwise this will fail!
50 | defaults = vars(args)
51 |
52 | def _apply_defaults(dest: DictConfig, source: dict, indentation=''):
53 | for k, v in source.items():
54 | if k in dest and isinstance(v, dict):
55 | current_value = dest[k]
56 | if current_value is not None:
57 | assert isinstance(current_value, DictConfig)
58 | _apply_defaults(current_value, v, indentation + ' ')
59 | elif k not in dest:
60 | dest[k] = v
61 | if verbose:
62 | print(indentation, 'set default value for', k)
63 |
64 | with open_dict(config):
65 | _apply_defaults(config, defaults)
66 |
67 |
68 | def _adjust_ddp_config(trainer_cfg):
69 | trainer_cfg = dict(trainer_cfg)
70 | strategy = trainer_cfg.get('strategy', None)
71 | if trainer_cfg['gpus'] > 1 and strategy is None:
72 | strategy = 'ddp' # Select ddp by default
73 | if strategy == 'ddp':
74 | trainer_cfg['strategy'] = DDPPlugin(
75 | find_unused_parameters=trainer_cfg['find_unused_parameters'],
76 | gradient_as_bucket_view=True)
77 | return trainer_cfg
78 |
79 |
80 | @hydra.main(config_path='config', config_name='base')
81 | def train(config: DictConfig):
82 | fake_parser = ArgumentParser()
83 | add_common_trainer_util_args(fake_parser, default_monitor_variable='val_loss')
84 | apply_argparse_defaults_to_hydra_config(config.trainer, fake_parser)
85 | pl.seed_everything(config.trainer.random_seed, workers=True)
86 | trainer_cfg = Namespace(**_adjust_ddp_config(config.trainer))
87 |
88 | data = JointDataModule(config.dataset)
89 | data.setup()
90 |
91 | total_steps = trainer_cfg.max_epochs * math.floor(len(data.train_dataset) / trainer_cfg.gpus / config.dataset.batch_size)
92 | model = LightningModule(config, total_steps)
93 | if trainer_cfg.checkpoint_path:
94 | state_dict = torch.load(trainer_cfg.checkpoint_path, map_location='cpu')['state_dict']
95 | if not trainer_cfg.load_nlq_head:
96 | print('Train NLQ head from scratch')
97 | state_dict = {k: v for k, v in state_dict.items() if not "nlq_head" in k}
98 | if not trainer_cfg.load_decoder:
99 | print('Train LM decoder head from scratch')
100 | state_dict = {k: v for k, v in state_dict.items() if not ("decoder" in k or "lm_head" in k)}
101 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
102 | print(f'Load checkpoint: {trainer_cfg.checkpoint_path}')
103 | print(f'Missing Keys: {missing_keys}')
104 | print(f'Unexpected Keys: {unexpected_keys}')
105 |
106 |
107 | if trainer_cfg.test_only: # evaluation
108 | trainer = pl.Trainer.from_argparse_args(
109 | trainer_cfg,
110 | enable_checkpointing=False,
111 | logger=False
112 | )
113 | if trainer_cfg.val:
114 | trainer.validate(
115 | model, data.val_dataloader(),
116 | )
117 | else:
118 | trainer.test(
119 | model, data.test_dataloader(),
120 | )
121 | else: # training
122 | model_checkpoint = []
123 | if 'QaEgo4D_test' in config.dataset.test_splits:
124 | model_checkpoint.append(
125 | ModelCheckpoint(
126 | save_last=False,
127 | monitor='val_ROUGE',
128 | mode='max',
129 | save_top_k=1,
130 | filename='{step}-{' + 'val_ROUGE' + ':.3f}')
131 | )
132 | if 'QaEgo4D_test_close' in config.dataset.test_splits:
133 | model_checkpoint.append(
134 | ModelCheckpoint(
135 | save_last=False,
136 | monitor='val_close_acc',
137 | mode='max',
138 | save_top_k=1,
139 | filename='{step}-{' + 'val_close_acc' + ':.3f}')
140 | )
141 | if 'NLQ_val' in config.dataset.test_splits:
142 | model_checkpoint.append(
143 | ModelCheckpoint(
144 | save_last=False,
145 | monitor='val_R1_03',
146 | mode='max',
147 | save_top_k=1,
148 | filename='{step}-{' + 'val_R1_03' + ':.3f}')
149 | )
150 | trainer = pl.Trainer.from_argparse_args(trainer_cfg, callbacks=[
151 | LearningRateMonitor(logging_interval='step'),
152 | # StochasticWeightAveraging(swa_lrs=1e-2),
153 | *model_checkpoint
154 | ])
155 | trainer.fit(
156 | model, data.train_dataloader(), data.val_dataloader(),
157 | )
158 |
159 |
160 | if __name__ == '__main__':
161 | train()
162 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | from argparse import ArgumentParser
3 | from collections import defaultdict
4 | from pathlib import Path
5 | from pprint import pprint
6 | from typing import List, Dict, Any
7 |
8 | # import bert_score
9 | import sentence_transformers
10 | from nltk.translate.meteor_score import meteor_score
11 | from rouge_score.rouge_scorer import RougeScorer
12 | from rouge_score.tokenize import tokenize
13 | # from sacrebleu.metrics import BLEU, BLEUScore
14 | from torchmetrics.functional import sacre_bleu_score
15 | from nltk.tokenize import word_tokenize
16 | from nltk.corpus import wordnet
17 |
18 |
19 | class AverageMeter(object):
20 | """Computes and stores the average and current value"""
21 |
22 | def __init__(self):
23 | self.reset()
24 |
25 | def reset(self):
26 | self.val = 0
27 | self.avg = 0
28 | self.sum = 0
29 | self.count = 0
30 |
31 | def update(self, val, n=1):
32 | self.val = val
33 | self.sum += val * n
34 | self.count += n
35 | self.avg = self.sum / self.count
36 |
37 |
38 | # Check whether to use
39 | # - https://github.com/Maluuba/nlg-eval
40 | # - https://github.com/hwanheelee1993/KPQA
41 | def calc_metrics(predictions: List[str], gold_annotations: List[List[str]], test=False) -> Dict[str, Any]:
42 | """
43 | Calculate metrics.
44 |
45 | Parameters
46 | ----------
47 | predictions : list[str]
48 | The list of predictions
49 | gold_annotations : list[list[str]]
50 | A list with the same length as predictions.
51 | Each element is a list of possible target candidates for the corresponding prediction.
52 | All elements should have the same length.
53 | """
54 | if len(predictions) != len(gold_annotations):
55 | raise ValueError(f'{len(predictions)} != {len(gold_annotations)}')
56 | ref_count = len(gold_annotations[0])
57 | if any(len(refs) != ref_count for refs in gold_annotations):
58 | raise ValueError(f'All refs should have the same length {ref_count}!')
59 |
60 | acc = _calc_accuracy(predictions, gold_annotations)
61 | # bleu = _calc_bleu(predictions, gold_annotations)
62 | rouge = _calc_rouge(predictions, gold_annotations)
63 | meteor = _calc_meteor(predictions, gold_annotations)
64 | # bert_score = _calc_bertscore(predictions, gold_annotations)
65 | # wups = _calc_wups(predictions, gold_annotations)
66 | if test:
67 | sts = SentenceTransformerSimilarity()
68 | sts_score = sts.calc_st_similarity(predictions, gold_annotations)
69 |
70 | return {
71 | 'plain_acc': acc,
72 | # **bleu,
73 | 'ROUGE': rouge['rougeL']['f'],
74 | **_flatten_dict(rouge, prefix='ROUGE.'),
75 | 'METEOR': meteor,
76 | 'SentenceSimilarity': sts_score if test else 0.
77 | # 'BERTSCORE': bert_score,
78 | # 'WUPS': wups
79 | }
80 |
81 |
82 | """ Sentence Transformer """
83 | class SentenceTransformerSimilarity:
84 | def __init__(self):
85 | self.model = sentence_transformers.SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
86 |
87 | def _calc_similarity(self, pred, gts):
88 | pred_emb = self.model.encode(pred)
89 | gts_emb = self.model.encode(gts)
90 | score = sentence_transformers.util.dot_score(pred_emb, gts_emb)[0,0].cpu()
91 | return float(score)
92 |
93 | def calc_st_similarity(self, predictions, gold_annotations):
94 | total_score = 0.
95 | for pred, gts in zip(predictions, gold_annotations):
96 | score = self._calc_similarity(pred, gts)
97 | total_score += score
98 | return total_score / len(predictions)
99 |
100 |
101 | """ WUPS """
102 | # ====================================================
103 | # @Time : 13/9/20 4:19 PM
104 | # @Author : Xiao Junbin
105 | # @Email : junbin@comp.nus.edu.sg
106 | # @File : metrics.py
107 | # ====================================================
108 |
109 |
110 | def wup(word1, word2, alpha):
111 | """
112 | calculate the wup similarity
113 | :param word1:
114 | :param word2:
115 | :param alpha:
116 | :return:
117 | """
118 | # print(word1, word2)
119 | if word1 == word2:
120 | return 1.0
121 |
122 | w1 = wordnet.synsets(word1)
123 | w1_len = len(w1)
124 | if w1_len == 0: return 0.0
125 | w2 = wordnet.synsets(word2)
126 | w2_len = len(w2)
127 | if w2_len == 0: return 0.0
128 |
129 | #match the first
130 | word_sim = w1[0].wup_similarity(w2[0])
131 | if word_sim is None:
132 | word_sim = 0.0
133 |
134 | if word_sim < alpha:
135 | word_sim = 0.1*word_sim
136 | return word_sim
137 |
138 | def wups(words1, words2, alpha):
139 | """
140 |     calculate the aggregated WUP similarity between two token lists
141 |     :param words1:
142 |     :param words2:
143 | :param alpha:
144 | :return:
145 | """
146 | sim = 1.0
147 | flag = False
148 | for w1 in words1:
149 | max_sim = 0
150 | for w2 in words2:
151 | word_sim = wup(w1, w2, alpha)
152 | if word_sim > max_sim:
153 | max_sim = word_sim
154 | if max_sim == 0: continue
155 | sim *= max_sim
156 | flag = True
157 | if not flag:
158 | sim = 0.0
159 | return sim
160 |
161 | def get_wups(pred, truth, alpha=0):
162 | """
163 | calculate the wups score
164 | :param pred:
165 | :param truth:
166 | :return:
167 | """
168 | pred = word_tokenize(pred)
169 | truth = word_tokenize(truth)
170 | item1 = wups(pred, truth, alpha)
171 | item2 = wups(truth, pred, alpha)
172 | value = min(item1, item2)
173 | return value
174 |
175 | def _calc_wups(predictions, gold_annotations):
176 | wups = 0
177 | for pred, gt in zip(predictions, gold_annotations):
178 | wups += get_wups(pred, gt[0])
179 | wups /= len(predictions)
180 | return wups
181 | """ WUPS """
182 |
183 |
184 | # def _calc_bertscore(predictions, gold_annotations):
185 | # references = [x[0] for x in gold_annotations]
186 | # P, R, F1 = bert_score.score(
187 | # predictions, references, lang='en',
188 | # model_type='microsoft/deberta-xlarge-mnli',
189 | # )
190 | # return float(F1.mean())
191 |
192 |
193 | def _calc_accuracy(predictions, gold_annotations):
194 | correct = 0
195 | for pred, possible_refs in zip(predictions, gold_annotations):
196 | if any(ref == pred for ref in possible_refs):
197 | correct += 1
198 | total = len(predictions)
199 | return correct / total
200 |
201 |
202 | def _calc_meteor(predictions, gold_annotations):
203 | score = AverageMeter()
204 | for pred, possible_refs in zip(predictions, gold_annotations):
205 | pred = tokenize(pred, None)
206 | # https://github.com/cmu-mtlab/meteor/blob/master/src/edu/cmu/meteor/util/Normalizer.java
207 | possible_refs = [tokenize(x, None) for x in possible_refs]
208 | score.update(meteor_score(possible_refs, pred))
209 | return score.avg
210 |
211 |
212 | def _calc_rouge(predictions, gold_annotations) -> Dict[str, Dict[str, float]]:
213 | rouge_scorer = RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=False)
214 | rouge = defaultdict(lambda: defaultdict(AverageMeter))
215 | for pred, possible_refs in zip(predictions, gold_annotations):
216 | sample_result = {}
217 | for ref in possible_refs:
218 | single_ref_result = rouge_scorer.score(ref, pred)
219 | for k, scores in single_ref_result.items():
220 | existing_result_dict = sample_result.setdefault(k, {})
221 | if existing_result_dict.get('f', -1) < scores.fmeasure:
222 | existing_result_dict.update(f=scores.fmeasure, p=scores.precision, r=scores.recall)
223 | for k, best_scores in sample_result.items():
224 | rouge[k]['p'].update(best_scores['p'])
225 | rouge[k]['r'].update(best_scores['r'])
226 | rouge[k]['f'].update(best_scores['f'])
227 | return {
228 | rouge_type: {
229 | measure: score.avg
230 | for measure, score in results.items()
231 | } for rouge_type, results in rouge.items()
232 | }
233 |
234 |
235 | def _calc_bleu(predictions, gold_annotations) -> Dict[str, float]:
236 | return {
237 | 'BLEU': sacre_bleu_score(predictions, gold_annotations, n_gram=1)
238 | }
239 | # refs_transposed = [
240 | # [refs[i] for refs in gold_annotations]
241 | # for i in range(len(gold_annotations[0]))
242 | # ]
243 | # bleu: BLEUScore = BLEU().corpus_score(predictions, refs_transposed)
244 | # return {
245 | # 'BLEU': bleu.score,
246 | # 'BLEU.bp': bleu.bp,
247 | # 'BLEU.ratio': bleu.ratio,
248 | # 'BLEU.hyp_len': float(bleu.sys_len),
249 | # 'BLEU.ref_len': float(bleu.ref_len),
250 | # }
251 |
252 |
253 | def _flatten_dict(d, prefix=''):
254 | result = {}
255 | for k, v in d.items():
256 | my_key = prefix + k
257 | if isinstance(v, dict):
258 | result.update(_flatten_dict(v, prefix=my_key + '.'))
259 | else:
260 | result[my_key] = v
261 | return result
262 |
263 |
264 | def main():
265 | parser = ArgumentParser('Eval output file')
266 | parser.add_argument('--gold_answers', type=str, required=True,
267 | help='Path to answers.json, containing mapping from sample_id to answer')
268 | parser.add_argument('eval_file', type=str,
269 | help='JSON File to evaluate. Should contain mapping from sample_id '
270 | 'to hypothesis or array of hypotheses')
271 | args = parser.parse_args()
272 |
273 | gold_answers = json.loads(Path(args.gold_answers).read_text())
274 | hypotheses = json.loads(Path(args.eval_file).read_text())
275 | if isinstance(next(iter(hypotheses.values())), list):
276 | hypotheses = {k: v[0] for k, v in hypotheses.items()}
277 | assert len(hypotheses.keys() - gold_answers.keys()) == 0, 'No gold answer for some hypotheses'
278 |
279 | gold_and_hypo = [(gold_answers[k], hypotheses[k]) for k in hypotheses.keys()]
280 | hypo_list = [h for g, h in gold_and_hypo]
281 | gold_list = [[g] for g, h in gold_and_hypo]
282 | metrics = calc_metrics(hypo_list, gold_list)
283 |
284 | pprint(metrics)
285 |
286 |
287 | if __name__ == '__main__':
288 | # main()
289 |
290 | # debug
291 | st = SentenceTransformerSimilarity()
292 | score = st._calc_similarity('inside the drawer', ['inside the drawer'])
293 | print(score) # 1.0
294 |
295 | score = st._calc_similarity('inside the drawer', ['on the table'])
296 | print(score) # 0.49
297 |
298 | score = st._calc_similarity('inside the drawer', ['in the drawer'])
299 | print(score) # 0.93
300 |
301 | # mean_score = st.calc_st_similarity(
302 | # ['floor', '3'],
303 | # [['on the ground'], ['two']]
304 | # )
305 | # print(mean_score)
306 |
--------------------------------------------------------------------------------
/model/ours/dataset.py:
--------------------------------------------------------------------------------
1 | # Joint dataset of CloseQA, OpenQA, and NLQ
2 |
3 | import os
4 | import math
5 | import json
6 | import random
7 | from pathlib import Path
8 | from typing import Iterable
9 |
10 | import h5py
11 | import numpy as np
12 | import torch
13 | import torch.nn.functional as F
14 | import pytorch_lightning as pl
15 | from torch.utils.data import Dataset, DataLoader, ConcatDataset
16 | from torch.nn.utils.rnn import pad_sequence
17 | from torch.utils.data.dataset import Dataset
18 | from transformers import AutoTokenizer
19 |
20 |
21 | class BaseDataset(Dataset):
22 | def __init__(self, data_dir, split, feature_type, max_v_len):
23 | super().__init__()
24 | self.split = split
25 | self.video_features = h5py.File(os.path.join(data_dir, feature_type + '.hdf5'), 'r')
26 | self.annotations = json.loads(Path(os.path.join(data_dir, f'annotations.{split}.json')).read_text())
27 | self.max_v_len = max_v_len
28 | print(f'{split} set: {len(self.annotations)}')
29 |
30 | def __len__(self):
31 | return len(self.annotations)
32 |
33 | def _get_video_feature(self, video_id):
34 | video_feature = torch.from_numpy(self.video_features[video_id][:])
35 | v_len = video_feature.shape[0]
36 | sample_ratio = 1.0
37 | if v_len > self.max_v_len:
38 | sample_idx = torch.linspace(0, v_len-1, self.max_v_len).long()
39 | video_feature = video_feature[sample_idx]
40 | sample_ratio = self.max_v_len / v_len
41 | v_len = self.max_v_len
42 | return video_feature, v_len, sample_ratio
43 |
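# Sketch of the subsampling above (illustrative numbers): a video with v_len = 2000
# features and max_v_len = 1200 is uniformly subsampled via torch.linspace down to
# 1200 features, and sample_ratio = 1200 / 2000 = 0.6 is kept so that ground-truth
# and predicted segments can be mapped between the original and subsampled time scales.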
44 |
45 | class NLQDataset(BaseDataset):
46 | def __init__(self, data_dir, split, feature_type, max_v_len):
47 | super().__init__(data_dir, split, feature_type, max_v_len)
48 |
49 | def __getitem__(self, index):
50 | video_id = self.annotations[index]['video_id']
51 | query_id = self.annotations[index].get('sample_id')
52 | question = self.annotations[index]['question']
53 |
54 | video_feature, v_len, sample_ratio = self._get_video_feature(video_id)
55 |
56 | if 'clip_start_sec' in self.annotations[index]:
57 | start_time = self.annotations[index].get('clip_start_sec')
58 | end_time = self.annotations[index].get('clip_end_sec')
59 | else:
60 | start_time = self.annotations[index].get('moment_start_frame') / 30
61 | end_time = self.annotations[index].get('moment_end_frame') / 30
62 |
63 | query_type = self.annotations[index].get('query_type')
64 | if query_type == 'narration':
65 | duration = end_time - start_time
66 | center = (end_time + start_time) / 2
67 | scale_ratio = random.randint(1, 10)  # randomly enlarge the window (1-10x) and shift it; the original span stays inside
68 | shift_number = random.uniform(-1, 1) * (scale_ratio - 1) * duration / 2
69 | new_center = center - shift_number
70 | start_time = new_center - scale_ratio * duration / 2
71 | end_time = new_center + scale_ratio * duration / 2
72 |
73 | segments = torch.tensor([[start_time, end_time]]) * 30 / 16.043 * sample_ratio  # seconds -> feature-grid units (30 fps video, 16.043 frames per feature), rescaled by sample_ratio
74 | labels = torch.zeros(len(segments), dtype=torch.int64)
75 | one_hot_labels = F.one_hot(labels, 1) # (1, 1)
76 |
77 | return {
78 | 'video_id': video_id,
79 | 'question': f"question: {question} video: ",
80 | 'answer': 'None',
81 | 'v_feat': video_feature,
82 | 'v_len': v_len,
83 | 'segments': segments,
84 | 'one_hot_labels': one_hot_labels,
85 | 'query_id': query_id,
86 | 'sample_ratio': sample_ratio,
87 | 'task': 'NLQ'
88 | }
89 |
90 |
91 | class QADataset(BaseDataset):
92 | def __init__(self, data_dir, split, feature_type, max_v_len, qa_type, CloseQA_weight=50):
93 | super().__init__(data_dir, split, feature_type, max_v_len)
94 | self.qa_type = qa_type # CloseQA, OpenQA, Mixed
95 | self.choice_indices = ['A', 'B', 'C', 'D']
96 | self.CloseQA_weight = CloseQA_weight
97 | self.openqa_weight = 100 - CloseQA_weight
98 |
99 | def __getitem__(self, index):
100 | video_id = self.annotations[index]['video_id']
101 | query_id = self.annotations[index].get('sample_id')
102 | question = self.annotations[index]['question']
103 | answer = self.annotations[index]['answer'].strip()
104 |
105 | qa_type = self.qa_type
106 | if qa_type == 'Mixed': # randomly choose a qa type
107 | qa_type = random.choices(['CloseQA', 'OpenQA'], weights=[self.CloseQA_weight, self.openqa_weight], k=1)[0]
108 | if qa_type == 'OpenQA':
109 | question_str = f"question: {question} video: "
110 | answer_str = answer
111 | elif qa_type == 'CloseQA':
112 | wrong_answers = self.annotations[index]['wrong_answers']
113 | # shuffle choices
114 | choices = [answer] + wrong_answers
115 | random.shuffle(choices)
116 | answer_index = choices.index(answer)
117 | choices = [f'({self.choice_indices[idx]}) {choices[idx]}' for idx in range(len(choices))] # ["(A) xx", "(B) xx", "(C) xx", "(D) xx"]
118 | choices_str = ' '.join(choices) # (A) xx (B) xx (C) xx (D) xx
119 | question_str = f"question: {question} choices: {choices_str}. video: "
120 | answer_str = choices[answer_index] # (A/B/C/D) xx
121 | else:
122 | raise NotImplementedError
123 |
124 | video_feature, v_len, sample_ratio = self._get_video_feature(video_id)
125 |
126 | start_frame = self.annotations[index].get('moment_start_frame')
127 | end_frame = self.annotations[index].get('moment_end_frame')
128 | start_time = start_frame / 30
129 | end_time = end_frame / 30
130 |
131 | if 'video_start_sec' not in self.annotations[index]: # LLM generated QA
132 | duration = end_time - start_time
133 | center = (end_time + start_time) / 2
134 | scale_ratio = random.randint(1, 10)  # randomly enlarge the window (1-10x) and shift it; the original span stays inside
135 | shift_number = random.uniform(-1, 1) * (scale_ratio - 1) * duration / 2
136 | new_center = center - shift_number
137 | start_time = new_center - scale_ratio * duration / 2
138 | end_time = new_center + scale_ratio * duration / 2
139 |
140 | segments = torch.tensor([[start_time, end_time]]) * 30 / 16.043 * sample_ratio  # seconds -> feature-grid units (30 fps video, 16.043 frames per feature), rescaled by sample_ratio
141 | labels = torch.zeros(len(segments), dtype=torch.int64)
142 | one_hot_labels = F.one_hot(labels, 1) # (1, 1)
143 |
144 | return {
145 | 'video_id': video_id,
146 | 'question': question_str,
147 | 'answer': answer_str,
148 | 'v_feat': video_feature,
149 | 'v_len': v_len,
150 | 'segments': segments,
151 | 'one_hot_labels': one_hot_labels,
152 | 'query_id': query_id,
153 | 'sample_ratio': sample_ratio,
154 | 'task': qa_type
155 | }
156 |
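# Illustrative prompt formats produced by __getitem__ above (the <...> placeholders are
# hypothetical, not real data):
#   OpenQA : "question: <q> video: "                                            target: the free-form answer
#   CloseQA: "question: <q> choices: (A) <c1> (B) <c2> (C) <c3> (D) <c4>. video: "  target: "(A) <c1>"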
157 |
158 | class JointDataset(ConcatDataset):
159 | def __init__(self, datasets: Iterable[Dataset], tokenizer_path) -> None:
160 | super().__init__(datasets)
161 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
162 | self.tokenizer.pad_token = self.tokenizer.eos_token  # HACK: only needed for tokenizers without a pad token (e.g. GPT-2); T5 tokenizers already define one
163 |
164 | def collate_fn(self, batch):
165 | question = [b['question'] for b in batch]
166 | question_tok = self.tokenizer(question, padding=True, return_tensors='pt', add_special_tokens=False)
167 |
168 | answer = [b['answer'] for b in batch]
169 | labels = self.tokenizer(answer, padding=True, return_tensors='pt').input_ids
170 | # NOTE: NLQ data does not have an answer
171 | for idx, a in enumerate(answer):
172 | if a == 'None':
173 | labels[idx] = torch.ones_like(labels[idx]) * -100
174 |
175 | video_feature = [b['v_feat'] for b in batch]
176 | video_feature_padded = pad_sequence(video_feature, batch_first=True)
177 | video_mask = pad_sequence([torch.ones(len(v)) for v in video_feature], batch_first=True).bool()
178 |
179 | result = {
180 | 'video_id': [b['video_id'] for b in batch],
181 | 'q_text': question,
182 | 'q_token': question_tok.input_ids,
183 | 'q_mask': question_tok.attention_mask.bool(),
184 | 'v_feat': video_feature_padded,
185 | 'v_mask': video_mask,
186 | 'v_len': np.asarray([b['v_len'] for b in batch], dtype=np.int64),  # np.long was removed in recent NumPy versions
187 | 'gt_segments': torch.stack([b['segments'] for b in batch]),
188 | 'gt_labels': torch.stack([b['one_hot_labels'] for b in batch]),
189 | 'query_id': [b['query_id'] for b in batch],
190 | 'sample_ratio': [b['sample_ratio'] for b in batch],
191 | 'a_text': answer,
192 | 'labels': labels,
193 | 'task': [b['task'] for b in batch]
194 | }
195 |
196 | return result
197 |
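# Rough shape of a collated batch (sketch, inferred from collate_fn above) for batch size B,
# max question length L, max (subsampled) video length T, feature dim D and answer length A:
#   q_token (B, L)   q_mask (B, L)   v_feat (B, T, D)   v_mask (B, T)
#   gt_segments (B, 1, 2)   gt_labels (B, 1, 1)   labels (B, A) with NLQ rows filled with -100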
198 |
199 | class JointDataModule(pl.LightningDataModule):
200 | train_dataset = None
201 | val_dataset = None
202 | test_dataset = None
203 |
204 | def __init__(self, config):
205 | super().__init__()
206 | self.config = config
207 |
208 | def setup(self, stage=None):
209 | CloseQA_weight = self.config.get('closeqa_weight', 50)
210 | print(f'CloseQA percentage: {CloseQA_weight}%')
211 | self.train_dataset = JointDataset([
212 | QADataset('data/unified', train_split, self.config.feature_type, self.config.max_v_len, 'Mixed', CloseQA_weight)
213 | for train_split in self.config.qa_train_splits
214 | ] + [
215 | NLQDataset('data/unified', train_split, self.config.feature_type, self.config.max_v_len)
216 | for train_split in self.config.nlq_train_splits
217 | ],
218 | self.config.tokenizer_path
219 | )
220 |
221 | test_datasets = []
222 | for split in self.config.test_splits:
223 | if split == 'QaEgo4D_test':
224 | test_datasets.append(QADataset('data/unified', split, self.config.feature_type, self.config.max_v_len, 'OpenQA'))
225 | elif split == 'QaEgo4D_test_close':
226 | test_datasets.append(QADataset('data/unified', split, self.config.feature_type, self.config.max_v_len, 'CloseQA'))
227 | elif split in ['NLQ_val', 'NLQ_test_unannotated']:
228 | test_datasets.append(NLQDataset('data/unified', split, self.config.feature_type, self.config.max_v_len))
229 | else:
230 | print(split)
231 | raise NotImplementedError
232 | self.val_dataset = self.test_dataset = JointDataset(test_datasets, self.config.tokenizer_path)
233 |
234 | print(f'#total train: {len(self.train_dataset)}')
235 | print(f'#total val: {len(self.val_dataset)}')
236 | print(f'#total test: {len(self.test_dataset)}')
237 |
238 | def train_dataloader(self):
239 | return DataLoader(
240 | self.train_dataset,
241 | batch_size=self.config.batch_size,
242 | shuffle=True,
243 | drop_last=True,
244 | num_workers=self.config.num_workers,
245 | collate_fn=self.train_dataset.collate_fn,
246 | pin_memory=True,
247 | persistent_workers=True
248 | )
249 |
250 | def val_dataloader(self):
251 | return DataLoader(
252 | self.val_dataset,
253 | batch_size=self.config.batch_size,
254 | shuffle=False,
255 | drop_last=False,
256 | num_workers=self.config.num_workers,
257 | collate_fn=self.val_dataset.collate_fn,
258 | pin_memory=True
259 | )
260 |
261 | def test_dataloader(self):
262 | return DataLoader(
263 | self.test_dataset,
264 | batch_size=self.config.batch_size,
265 | shuffle=False,
266 | drop_last=False,
267 | num_workers=self.config.num_workers,
268 | collate_fn=self.val_dataset.collate_fn,
269 | pin_memory=True
270 | )
271 |
--------------------------------------------------------------------------------
/model/ours/lightning_module.py:
--------------------------------------------------------------------------------
1 | import json
2 | import copy
3 | import random
4 |
5 | import torch
6 | import pytorch_lightning as pl
7 | from hydra.utils import instantiate
8 | from transformers import AutoTokenizer
9 | from torch.optim.lr_scheduler import OneCycleLR
10 |
11 | from eval import calc_metrics
12 | from eval_nlq import ReferringRecall
13 |
14 |
15 | class TestLightningModule(pl.LightningModule):
16 | def __init__(self, config):
17 | super().__init__()
18 | self.config = config
19 | self.tokenizer = AutoTokenizer.from_pretrained(config.dataset.tokenizer_path)
20 | self.tokenizer.pad_token = self.tokenizer.eos_token
21 | self.model = instantiate(config.model, max_v_len=config.dataset.max_v_len)
22 |
23 | def test_step(self, batch, batch_idx):
24 | nlq_results, answer_tokens = self.model.generate(**batch)
25 | pred_answer = self.tokenizer.batch_decode(answer_tokens, skip_special_tokens=True)
26 | return {
27 | 'question': batch['q_text'],
28 | 'video_id': batch['video_id'],
29 | 'answer': batch['a_text'] if 'a_text' in batch else '',
30 | 'pred_answer': pred_answer,
31 | 'nlq_results': nlq_results,
32 | 'query_id': batch['query_id'],
33 | 'sample_ratio': batch['sample_ratio'],
34 | 'task': batch['task'],
35 | 'clip_uid': batch['video_id']
36 | }
37 |
38 | def test_epoch_end(self, outputs):
39 | self.save_nlq_results(outputs)
40 |
41 | def save_nlq_results(self, preds):
42 | # aggregate preds
43 | pred_dict = {
44 | "version": "1.0",
45 | "challenge": "ego4d_nlq_challenge",
46 | "results": []
47 | }
48 | for batch_pred in preds:
49 | for i in range(len(batch_pred['video_id'])):
50 | qid = batch_pred['query_id'][i]
51 | annotation_uid, query_idx = qid.split('_')
52 | query_idx = int(query_idx)
53 | clip_uid = batch_pred['clip_uid'][i]
54 | sample_ratio = batch_pred['sample_ratio'][i]
55 | predicted_times = [
56 | [segment[0] / sample_ratio, segment[1] / sample_ratio]
57 | for segment in batch_pred['nlq_results'][i]['segments'].cpu().detach().tolist()
58 | ]
59 |
60 | pred_dict['results'].append({
61 | 'clip_uid': clip_uid,
62 | 'annotation_uid': annotation_uid,
63 | 'query_idx': query_idx,
64 | 'predicted_times': predicted_times
65 | })
66 |
67 | with open('nlq_eval_results/nlq_v2.json', 'w') as f:
68 | json.dump(pred_dict, f)
69 |
70 |
71 | class LightningModule(pl.LightningModule):
72 | def __init__(self, config, total_steps):
73 | super().__init__()
74 | self.config = config
75 | self.tokenizer = AutoTokenizer.from_pretrained(config.dataset.tokenizer_path)
76 | self.tokenizer.pad_token = self.tokenizer.eos_token
77 | self.model = instantiate(config.model, max_v_len=config.dataset.max_v_len)
78 | self.nlq_evaluator = ReferringRecall(
79 | dataset="ego4d",
80 | gt_file=config.dataset.nlq_val_anno
81 | )
82 | self._log_indices = {}
83 | self.total_steps = total_steps
84 |
85 | def training_step(self, batch, batch_idx):
86 | total_loss, ce_loss, time_loss = self.model(**batch)
87 | self.log('total_loss', total_loss, rank_zero_only=True)
88 | self.log('ce_loss', ce_loss, rank_zero_only=True)
89 | self.log('time_loss', time_loss, rank_zero_only=True)
90 | return {
91 | 'loss': total_loss,
92 | }
93 |
94 | def validation_step(self, batch, batch_idx):
95 | nlq_results, answer_tokens = self.model.generate(**batch)
96 | pred_answer = self.tokenizer.batch_decode(answer_tokens, skip_special_tokens=True)
97 | return {
98 | 'question': batch['q_text'],
99 | 'video_id': batch['video_id'],
100 | 'answer': batch['a_text'] if 'a_text' in batch else '',
101 | 'pred_answer': pred_answer,
102 | 'nlq_results': nlq_results,
103 | 'query_id': batch['query_id'],
104 | 'sample_ratio': batch['sample_ratio'],
105 | 'task': batch['task']
106 | }
107 |
108 | def test_step(self, batch, batch_idx):
109 | return self.validation_step(batch, batch_idx)
110 |
111 | def _log_some_outputs(self, outputs, name):
112 | num_val_steps_to_log, num_samples_per_batch_to_log = 5, 3 # Could be configurable via cfg
113 | steps_to_log_indices = random.sample(range(len(outputs)), k=min(len(outputs), num_val_steps_to_log))
114 | self._log_indices[name] = {
115 | 'steps': steps_to_log_indices,
116 | 'samples': [
117 | random.sample(
118 | range(len(outputs[step]['answer'])),
119 | k=min(len(outputs[step]['answer']),
120 | num_samples_per_batch_to_log))
121 | for step in steps_to_log_indices
122 | ]
123 | }
124 | for i, step in enumerate(steps_to_log_indices):
125 | indices = self._log_indices[name]['samples'][i]
126 | for b in indices:
127 | sample = (
128 | f'Video: "{outputs[step]["video_id"][b]}". \n'
129 | f'Question: "{outputs[step]["question"][b]}". \n'
130 | f'Target: "{outputs[step]["answer"][b]}". \n'
131 | f'Output: "{outputs[step]["pred_answer"][b]}"'
132 | )
133 | self.logger.experiment.add_text(f'{name} {str(i * len(indices) + b)}', sample,
134 | global_step=self.global_step)
135 |
136 | def aggregate_metrics(self, outputs, prefix):
137 | # evaluate CloseQA
138 | all_hypos = []
139 | all_targets = []
140 | for output in outputs:
141 | for i in range(len(output['video_id'])):
142 | if output['task'][i] == 'CloseQA':
143 | all_hypos.append(output['pred_answer'][i])
144 | all_targets.append(output['answer'][i])
145 | if len(all_hypos) > 0:
146 | num_correct = 0
147 | for hypo, target in zip(all_hypos, all_targets):
148 | if hypo == target:
149 | num_correct += 1
150 | acc = num_correct / len(all_targets) * 100
151 | metrics = {f'{prefix}_close_acc': acc}
152 | else:
153 | metrics = {}
154 |
155 | # evaluate OpenQA
156 | all_hypos = []
157 | all_targets = []
158 | for output in outputs:
159 | for i in range(len(output['video_id'])):
160 | if output['task'][i] == 'OpenQA':
161 | all_hypos.append(output['pred_answer'][i])
162 | all_targets.append(output['answer'][i])
163 | if len(all_hypos) > 0:
164 | open_qa_metrics = calc_metrics(all_hypos, [[x] for x in all_targets], test=prefix=='test')
165 | for k, v in open_qa_metrics.items():
166 | metrics[f'{prefix}_{k}'] = v
167 |
168 | # evaluate NLQ
169 | nlq_preds = []
170 | for output in outputs:
171 | for i in range(len(output['video_id'])):
172 | if output['task'][i] != 'NLQ':
173 | continue
174 | qid = output['query_id'][i]
175 | temp_list = qid.split("_")
176 | sample_ratio = output['sample_ratio'][i]
177 | new_prediction = [
178 | [ segment[0] / sample_ratio,
179 | segment[1] / sample_ratio,
180 | score ]
181 | for segment, score in zip(
182 | output['nlq_results'][i]['segments'].cpu().detach().tolist(),
183 | output['nlq_results'][i]['scores'].cpu().detach().tolist(),
184 | )]
185 | nlq_preds.append({
186 | 'query_idx': int(temp_list[1]),
187 | 'annotation_uid': temp_list[0],
188 | 'predicted_times': new_prediction,
189 | 'clip_uid': output['video_id'][i]
190 | })
191 | if len(nlq_preds) > 0:
192 | performance, score_str = self.nlq_evaluator.evaluate(nlq_preds, verbose=False)  # performance: IoU thresholds (0.3, 0.5) x top-K (1, 5)
193 | metrics[f'{prefix}_R1_03'] = performance[0, 0] * 100
194 | metrics[f'{prefix}_R5_03'] = performance[0, 1] * 100
195 | metrics[f'{prefix}_R1_05'] = performance[1, 0] * 100
196 | metrics[f'{prefix}_R5_05'] = performance[1, 1] * 100
197 | metrics[f'{prefix}_Mean_R1'] = (performance[0, 0] + performance[1, 0]) * 100 / 2
198 |
199 | # # save predictions
200 | # results = []
201 | # for output in outputs:
202 | # for i in range(len(output['video_id'])):
203 | # results.append({
204 | # 'query_id': output['query_id'][i],
205 | # 'pred_answer': output['pred_answer'][i],
206 | # 'gt_answer': output['answer'][i],
207 | # 'pred_window': (output['nlq_results'][i]['segments'].cpu().detach() / output['sample_ratio'][i]).tolist(),
208 | # 'gt_window': self.nlq_evaluator.gt_dict[(output['video_id'][i], output['query_id'][i].split('_')[0])]["language_queries"][int(output['query_id'][i].split('_')[1])]
209 | # })
210 | # with open('analysis/VLG_OpenQA.json', 'w') as f:
211 | # json.dump(results, f)
212 |
213 | return metrics
214 |
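# Sketch of the metric keys returned above for prefix='val' (assuming all three task types
# appear in the outputs): 'val_close_acc', the OpenQA keys produced by eval.calc_metrics
# prefixed with 'val_', and the grounding recalls 'val_R1_03', 'val_R5_03', 'val_R1_05',
# 'val_R5_05', 'val_Mean_R1'.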
215 | # def training_epoch_end(self, outputs):
216 | # self._log_some_outputs(outputs, 'train')
217 | # metrics = self.aggregate_metrics(outputs, prefix='train')
218 | # self.log_dict(metrics, sync_dist=True)
219 |
220 | def validation_epoch_end(self, outputs):
221 | def _mean(key):
222 | return torch.stack([data[key] for data in outputs]).mean()
223 |
224 | # self._log_some_outputs(outputs, 'val')
225 | metrics = self.aggregate_metrics(outputs, prefix='val')
226 | metrics.update({
227 | f'val_{name}': _mean(name) for name in outputs[0].keys() if 'loss' in name
228 | })
229 | self.log_dict(metrics, sync_dist=True)
230 |
231 | def test_epoch_end(self, outputs):
232 | # self._log_some_outputs(outputs, 'test')
233 | metrics = self.aggregate_metrics(outputs, prefix='test')
234 | self.log_dict(metrics, sync_dist=True)
235 | if self.config.trainer.save_nlq_results is not None:
236 | src = 'data/joint/annotations.QaEgo4D_test_close.json'
237 | dst = self.config.trainer.save_nlq_results
238 | self.save_nlq_results(src, dst, outputs)
239 |
240 | def save_nlq_results(self, src, dst, preds):
241 | # aggregate preds
242 | pred_dict = {}
243 | for batch_pred in preds:
244 | for i in range(len(batch_pred['video_id'])):
245 | qid = batch_pred['query_id'][i]
246 | sample_ratio = batch_pred['sample_ratio'][i]
247 | pred_start = batch_pred['nlq_results'][i]['segments'][0].cpu().detach().tolist()[0] / sample_ratio
248 | pred_end = batch_pred['nlq_results'][i]['segments'][0].cpu().detach().tolist()[1] / sample_ratio
249 | assert qid not in pred_dict
250 | pred_dict[qid] = {
251 | 'pred_start_sec': pred_start,
252 | 'pred_end_sec': pred_end
253 | }
254 |
255 | save_results = []
256 | for src_data in json.load(open(src)):
257 | pred_data = pred_dict[src_data['sample_id']]
258 | save_data = copy.deepcopy(src_data)
259 | save_data['moment_start_frame'] = pred_data['pred_start_sec'] * 30
260 | save_data['moment_end_frame'] = pred_data['pred_end_sec'] * 30
261 | save_results.append(save_data)
262 | with open(dst, 'w') as f:
263 | json.dump(save_results, f)
264 |
265 | def configure_optimizers(self):
266 | optimizer = instantiate(
267 | self.config.optim.optimizer,
268 | filter(lambda p: p.requires_grad, self.parameters()),
269 | lr=self.config.optim.optimizer.lr
270 | )
271 | if self.config.optim.lr_scheduler:
272 | lr_scheduler = OneCycleLR(
273 | optimizer=optimizer,
274 | max_lr=self.config.optim.optimizer.lr,
275 | total_steps=self.total_steps,
276 | anneal_strategy='linear'
277 | )
278 | return {
279 | 'optimizer': optimizer,
280 | 'lr_scheduler': {
281 | 'scheduler': lr_scheduler,
282 | 'interval': 'step'
283 | }
284 | }
285 | else:
286 | return optimizer
--------------------------------------------------------------------------------
/eval_nlq.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 | """
3 | Script to evaluate performance of any model for Ego4d Episodic Memory.
4 |
5 | Natural Language Queries (NLQ)
6 | """
7 |
8 | from __future__ import absolute_import, division, print_function, unicode_literals
9 |
10 | import argparse
11 | import json
12 |
13 | import numpy as np
14 | import torch
15 | import terminaltables
16 |
17 |
18 |
19 | def compute_overlap(pred, gt):
20 | # check format
21 | assert isinstance(pred, list) and isinstance(gt, list)
22 | pred_is_list = isinstance(pred[0], list)
23 | gt_is_list = isinstance(gt[0], list)
24 | pred = pred if pred_is_list else [pred]
25 | gt = gt if gt_is_list else [gt]
26 | # compute overlap
27 | pred, gt = np.array(pred), np.array(gt)
28 | inter_left = np.maximum(pred[:, 0, None], gt[None, :, 0])
29 | inter_right = np.minimum(pred[:, 1, None], gt[None, :, 1])
30 | inter = np.maximum(0.0, inter_right - inter_left)
31 | union_left = np.minimum(pred[:, 0, None], gt[None, :, 0])
32 | union_right = np.maximum(pred[:, 1, None], gt[None, :, 1])
33 | union = np.maximum(1e-12, union_right - union_left)
34 | overlap = 1.0 * inter / union
35 | # reformat output
36 | overlap = overlap if gt_is_list else overlap[:, 0]
37 | overlap = overlap if pred_is_list else overlap[0]
38 | return overlap
39 |
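# Worked example (sketch): compute_overlap([[0, 2]], [[1, 3]]) returns [[~0.33]]
# (inter = 1, span union = 3). For single intervals the span-based union above
# coincides with the usual set union whenever the intervals actually overlap.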
40 |
41 | def time_to_index(start_time, end_time, num_units, duration):
42 | s_times = np.arange(0, num_units).astype(np.float32) / float(num_units) * duration
43 | e_times = (
44 | np.arange(1, num_units + 1).astype(np.float32) / float(num_units) * duration
45 | )
46 | candidates = np.stack(
47 | [
48 | np.repeat(s_times[:, None], repeats=num_units, axis=1),
49 | np.repeat(e_times[None, :], repeats=num_units, axis=0),
50 | ],
51 | axis=2,
52 | ).reshape((-1, 2))
53 | overlaps = compute_overlap(candidates.tolist(), [start_time, end_time]).reshape(
54 | num_units, num_units
55 | )
56 | start_index = np.argmax(overlaps) // num_units
57 | end_index = np.argmax(overlaps) % num_units
58 | return start_index, end_index, overlaps
59 |
60 |
61 | def index_to_time(start_index, end_index, num_units, duration):
62 | s_times = np.arange(0, num_units).astype(np.float32) * duration / float(num_units)
63 | e_times = (
64 | np.arange(1, num_units + 1).astype(np.float32) * duration / float(num_units)
65 | )
66 | start_time = s_times[start_index]
67 | end_time = e_times[end_index]
68 | return start_time, end_time
69 |
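# Round-trip sketch for the two helpers above: with num_units = 4 and duration = 8,
# s_times = [0, 2, 4, 6] and e_times = [2, 4, 6, 8], so time_to_index(2, 6, 4, 8)
# picks start_index = 1, end_index = 2, and index_to_time(1, 2, 4, 8) maps back to (2.0, 6.0).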
70 |
71 | def display_results(results, mIoU, thresholds, topK, title=None):
72 | display_data = [
73 | [f"Rank@{ii}\nmIoU@{jj}" for ii in topK for jj in thresholds] + ["mIoU"]
74 | ]
75 | results *= 100
76 | mIoU *= 100
77 | display_data.append(
78 | [
79 | f"{results[jj][ii]:.02f}"
80 | for ii in range(len(topK))
81 | for jj in range(len(thresholds))
82 | ]
83 | + [f"{mIoU:.02f}"]
84 | )
85 | table = terminaltables.AsciiTable(display_data, title)
86 | for ii in range(len(thresholds) * len(topK)):
87 | table.justify_columns[ii] = "center"
88 | return table.table
89 |
90 |
91 | def compute_IoU(pred, gt):
92 | """Compute the IoU given predicted and ground truth windows."""
93 | assert isinstance(pred, list) and isinstance(gt, list)
94 | pred_is_list = isinstance(pred[0], list)
95 | gt_is_list = isinstance(gt[0], list)
96 | if not pred_is_list:
97 | pred = [pred]
98 | if not gt_is_list:
99 | gt = [gt]
100 | pred, gt = np.array(pred), np.array(gt)
101 | inter_left = np.maximum(pred[:, 0, None], gt[None, :, 0])
102 | inter_right = np.minimum(pred[:, 1, None], gt[None, :, 1])
103 | inter = np.maximum(0.0, inter_right - inter_left)
104 | union_left = np.minimum(pred[:, 0, None], gt[None, :, 0])
105 | union_right = np.maximum(pred[:, 1, None], gt[None, :, 1])
106 | union = np.maximum(0.0, union_right - union_left)
107 | overlap = 1.0 * inter / union
108 | if not gt_is_list:
109 | overlap = overlap[:, 0]
110 | if not pred_is_list:
111 | overlap = overlap[0]
112 | return overlap
113 |
114 |
115 | def evaluate_nlq_performance(
116 | predictions, ground_truth, thresholds, topK, per_instance=False
117 | ):
118 | results = [[[] for _ in topK] for _ in thresholds]
119 | average_IoU = []
120 | for pred, gt in zip(predictions, ground_truth):
121 | # Compute overlap and recalls.
122 | overlap = compute_IoU(pred, gt)
123 | average_IoU.append(np.mean(np.sort(overlap[0])[-3:]))
124 | for tt, threshold in enumerate(thresholds):
125 | for rr, KK in enumerate(topK):
126 | results[tt][rr].append((overlap > threshold)[:KK].any())
127 |
128 | mean_results = np.array(results).mean(axis=-1)
129 | mIoU = np.mean(average_IoU)
130 | if per_instance:
131 | per_instance_results = {
132 | "overlap": overlap,
133 | "average_IoU": average_IoU,
134 | "results": results,
135 | }
136 | return mean_results, mIoU, per_instance_results
137 | else:
138 | return mean_results, mIoU
139 |
140 |
141 | def load_jsonl(filename):
142 | with open(filename, "r") as f:
143 | return [json.loads(l.strip("\n")) for l in f.readlines()]
144 |
145 |
146 | class ReferringRecall(object):
147 | thresholds = np.array([0.3, 0.5])
148 | topK = np.array([1, 5])
149 | def __init__(
150 | self,
151 | dataset="ego4d",
152 | gt_file="/remote-home/share/Ego4D_dsz/v1/annotations/nlq_val.json"
153 | ):
154 | self.dataset = dataset
155 | self.gt_file = gt_file
156 | print(self.gt_file)
157 | if self.dataset == "ego4d":
158 | with open(self.gt_file) as file_id:
159 | self.gt_dict, self.num_gt_queries = self.load_gt_from_json(json.load(file_id))
160 | else:
161 | self.gt_dict = {}
162 | for d in load_jsonl(self.gt_file):
163 | # print(d)
164 | self.gt_dict[d['query_id']] = d["timestamps"]
165 | self.num_gt_queries = len(self.gt_dict)
166 |
167 | def load_gt_from_json(self, ground_truth):
168 | gt_dict = {}
169 | num_gt_queries = 0
170 |
171 | for video_datum in ground_truth["videos"]:
172 | for clip_datum in video_datum["clips"]:
173 | clip_uid = clip_datum["clip_uid"]
174 | for ann_datum in clip_datum["annotations"]:
175 | key = (clip_uid, ann_datum["annotation_uid"])
176 | gt_dict[key] = ann_datum
177 | num_gt_queries += len(ann_datum["language_queries"])
178 |
179 | return gt_dict, num_gt_queries
180 |
181 | def compute_IoU(self, pred, gt):
182 | """Compute the IoU given predicted and ground truth windows."""
183 | assert isinstance(pred, list) and isinstance(gt, list)
184 | if len(pred) == 0:  # FIXME: early in training the PtTransformerRegHead sometimes produces all-zero offsets, so no segment survives filtering and the prediction list is empty; treat that as zero overlap
185 | return [0]
186 | pred_is_list = isinstance(pred[0], list)
187 | gt_is_list = isinstance(gt[0], list)
188 | if not pred_is_list:
189 | pred = [pred]
190 | if not gt_is_list:
191 | gt = [gt]
192 | pred, gt = np.array(pred), np.array(gt)
193 | inter_left = np.maximum(pred[:, 0, None], gt[None, :, 0])
194 | inter_right = np.minimum(pred[:, 1, None], gt[None, :, 1])
195 | inter = np.maximum(0.0, inter_right - inter_left)
196 | union_left = np.minimum(pred[:, 0, None], gt[None, :, 0])
197 | union_right = np.maximum(pred[:, 1, None], gt[None, :, 1])
198 | union = np.maximum(0.0, union_right - union_left)
199 | overlap = 1.0 * inter / union
200 | if not gt_is_list:
201 | overlap = overlap[:, 0]
202 | if not pred_is_list:
203 | overlap = overlap[0]
204 | return overlap
205 |
206 | def display_results_anet(self, results, title=None):
207 | display_data = [
208 | [f"Rank@{ii}\nmIoU@{jj:.1f}" for ii in self.topK for jj in self.thresholds]
209 | ]
210 | results *= 100
211 | display_data.append(
212 | [
213 | f"{results[ii][jj]:.02f}"
214 | for ii in range(len(self.topK))
215 | for jj in range(len(self.thresholds))
216 | ]
217 | )
218 | table = terminaltables.AsciiTable(display_data, title)
219 | for ii in range(len(self.thresholds) * len(self.topK)):
220 | table.justify_columns[ii] = "center"
221 | return table.table
222 |
223 | def display_results(self, results, title=None):
224 | display_data = [
225 | [f"Rank@{ii}\nmIoU@{jj}" for ii in self.topK for jj in self.thresholds]
226 | ]
227 | results *= 100
228 |
229 | display_data.append(
230 | [
231 | f"{results[jj][ii]:.02f}"
232 | for ii in range(len(self.topK))
233 | for jj in range(len(self.thresholds))
234 | ]
235 | )
236 | table = terminaltables.AsciiTable(display_data, title)
237 | for ii in range(len(self.thresholds) * len(self.topK)):
238 | table.justify_columns[ii] = "center"
239 | return table.table
240 |
241 | def evaluate(self, predictions, verbose=True):
242 | """Evalutes the performances."""
243 |
244 | results = [[[] for _ in self.topK] for _ in self.thresholds]
245 | average_IoU = []
246 | num_instances = 0
247 |
248 | for pred_datum in predictions:
249 | key = (pred_datum["clip_uid"], pred_datum["annotation_uid"])
250 | assert key in self.gt_dict, f"{key} Instance not present!"
251 | query_id = pred_datum["query_idx"]
252 | gt_datum = self.gt_dict[key]
253 | gt_query_datum = gt_datum["language_queries"][query_id]
254 |
255 | # Compute overlap and recalls.
256 | overlap = self.compute_IoU(
257 | pred_datum["predicted_times"],
258 | [[gt_query_datum["clip_start_sec"], gt_query_datum["clip_end_sec"]]],
259 | )
260 | average_IoU.append(overlap[0])
261 |
262 | for tt, threshold in enumerate(self.thresholds):
263 | for rr, KK in enumerate(self.topK):
264 | results[tt][rr].append((overlap > threshold)[:KK].any())
265 | num_instances += 1
266 |
267 | mean_results = np.array(results).mean(axis=-1)
268 |
269 | score_str = None
270 | if verbose:
271 | print(f"Evaluated: {num_instances} / {self.num_gt_queries} instances")
272 | score_str = self.display_results(mean_results)
273 | print(score_str, flush=True)
274 |
275 | return mean_results, score_str
276 |
277 | def _iou(self, candidates, gt):
278 | start, end = candidates[:, 0].float(), candidates[:, 1].float()
279 | s, e = gt[0].float(), gt[1].float()
280 | inter = end.min(e) - start.max(s)
281 | union = end.max(e) - start.min(s)
282 | return inter.clamp(min=0) / union
283 |
284 | def evaluate_anet(
285 | self, submission, verbose=True):
286 |
287 | iou_metrics = torch.tensor(self.thresholds)
288 | num_iou_metrics = len(iou_metrics)
289 |
290 | recall_metrics = torch.tensor(self.topK)
291 | max_recall = recall_metrics.max()
292 | num_recall_metrics = len(recall_metrics)
293 | recall_x_iou = torch.zeros((num_recall_metrics, len(iou_metrics)))
294 |
295 | for k in submission:
296 | # print(k)
297 | gt_grounding = torch.tensor(self.gt_dict[k['query_id']])
298 | pred_moments = torch.tensor(k["predicted_times"][:max_recall])
299 | mious = self._iou(pred_moments, gt_grounding)
300 | mious_len = len(mious)
301 | bools = mious[:, None].expand(mious_len, num_iou_metrics) > iou_metrics
302 | for i, r in enumerate(recall_metrics):
303 | recall_x_iou[i] += bools[:r].any(dim=0)
304 |
305 | recall_x_iou /= len(submission)
306 |
307 | if verbose:
308 | print(f"Evaluated: {len(submission)} / {self.num_gt_queries} instances")
309 | score_str = self.display_results_anet(recall_x_iou)
310 | print(score_str, flush=True)
311 |
312 | return recall_x_iou
313 |
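# Usage sketch (values and path are placeholders, not from the repo): the Lightning module
# constructs this evaluator as ReferringRecall(dataset="ego4d", gt_file="<path to nlq_val.json>")
# and calls
#   performance, _ = evaluator.evaluate([
#       {"clip_uid": "<clip>", "annotation_uid": "<ann>", "query_idx": 0,
#        "predicted_times": [[12.3, 18.7, 0.9]]},
#   ], verbose=False)
# where performance[t][k] is recall at IoU thresholds (0.3, 0.5) x top-K (1, 5).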
314 |
315 | def segment_iou(target_segment, candidate_segments):
316 | """Compute the temporal intersection over union between a
317 | target segment and all the candidate segments.
318 | Parameters
319 | ----------
320 | target_segment : 1d array
321 | Temporal target segment containing [starting, ending] times.
322 | candidate_segments : 2d array
323 | Temporal candidate segments containing N x [starting, ending] times.
324 | Outputs
325 | -------
326 | tiou : 1d array
327 | Temporal intersection over union score for each of the N candidate segments.
328 | """
329 | tt1 = np.maximum(target_segment[0], candidate_segments[:, 0])
330 | tt2 = np.minimum(target_segment[1], candidate_segments[:, 1])
331 | # Intersection including Non-negative overlap score.
332 | segments_intersection = (tt2 - tt1).clip(0)
333 | # Segment union.
334 | segments_union = (candidate_segments[:, 1] - candidate_segments[:, 0]) \
335 | + (target_segment[1] - target_segment[0]) - segments_intersection
336 | # Compute overlap as the ratio of the intersection
337 | # over union of two segments.
338 | tIoU = segments_intersection.astype(float) / segments_union
339 | return tIoU
--------------------------------------------------------------------------------
/model/ours/nlq_head.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import nms_1d_cpu
7 |
8 |
9 | class LayerNorm(nn.Module):
10 | """
11 | LayerNorm that supports inputs of size B, C, T
12 | """
13 |
14 | def __init__(
15 | self,
16 | num_channels,
17 | eps=1e-5,
18 | affine=True,
19 | device=None,
20 | dtype=None,
21 | ):
22 | super().__init__()
23 | factory_kwargs = {'device': device, 'dtype': dtype}
24 | self.num_channels = num_channels
25 | self.eps = eps
26 | self.affine = affine
27 |
28 | if self.affine:
29 | self.weight = nn.Parameter(
30 | torch.ones([1, num_channels, 1], **factory_kwargs))
31 | self.bias = nn.Parameter(
32 | torch.zeros([1, num_channels, 1], **factory_kwargs))
33 | else:
34 | self.register_parameter('weight', None)
35 | self.register_parameter('bias', None)
36 |
37 | def forward(self, x):
38 | assert x.dim() == 3
39 | assert x.shape[1] == self.num_channels
40 |
41 | # normalization along C channels
42 | mu = torch.mean(x, dim=1, keepdim=True)
43 | res_x = x - mu
44 | sigma = torch.mean(res_x ** 2, dim=1, keepdim=True)
45 | out = res_x / torch.sqrt(sigma + self.eps)
46 |
47 | # apply weight and bias
48 | if self.affine:
49 | out *= self.weight
50 | out += self.bias
51 |
52 | return out
53 |
54 |
55 | @torch.jit.script
56 | def ctr_diou_loss_1d(
57 | input_offsets: torch.Tensor,
58 | target_offsets: torch.Tensor,
59 | reduction: str = 'none',
60 | eps: float = 1e-8,
61 | ) -> torch.Tensor:
62 | """
63 | Distance-IoU Loss (Zheng et al.)
64 | https://arxiv.org/abs/1911.08287
65 |
66 | This is an implementation that assumes a 1D event is represented using
67 | the same center point with different offsets, e.g.,
68 | (t1, t2) = (c - o_1, c + o_2) with o_i >= 0
69 |
70 | Reference code from
71 | https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/giou_loss.py
72 |
73 | Args:
74 | input/target_offsets (Tensor): 1D offsets of size (N, 2)
75 | reduction: 'none' | 'mean' | 'sum'
76 | 'none': No reduction will be applied to the output.
77 | 'mean': The output will be averaged.
78 | 'sum': The output will be summed.
79 | eps (float): small number to prevent division by zero
80 | """
81 | input_offsets = input_offsets.float()
82 | target_offsets = target_offsets.float()
83 | # check all 1D events are valid
84 | assert (input_offsets >= 0.0).all(), "predicted offsets must be non-negative"
85 | assert (target_offsets >= 0.0).all(), "GT offsets must be non-negative"
86 |
87 | lp, rp = input_offsets[:, 0], input_offsets[:, 1]
88 | lg, rg = target_offsets[:, 0], target_offsets[:, 1]
89 |
90 | # intersection key points
91 | lkis = torch.min(lp, lg)
92 | rkis = torch.min(rp, rg)
93 |
94 | # iou
95 | intsctk = rkis + lkis
96 | unionk = (lp + rp) + (lg + rg) - intsctk
97 | iouk = intsctk / unionk.clamp(min=eps)
98 |
99 | # smallest enclosing box
100 | lc = torch.max(lp, lg)
101 | rc = torch.max(rp, rg)
102 | len_c = lc + rc
103 |
104 | # offset between centers
105 | rho = 0.5 * (rp - lp - rg + lg)
106 |
107 | # diou
108 | loss = 1.0 - iouk + torch.square(rho / len_c.clamp(min=eps))
109 |
110 | if reduction == "mean":
111 | loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
112 | elif reduction == "sum":
113 | loss = loss.sum()
114 |
115 | return loss
116 |
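# Worked example (sketch): for a point with predicted offsets (1, 1) and target offsets (2, 2)
# around the same centre, IoU = 2/4 = 0.5 and the centre-distance term is 0, so the per-sample
# DIoU loss is 1 - 0.5 = 0.5; a perfect match gives a loss of 0.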
117 |
118 | @torch.jit.script
119 | def sigmoid_focal_loss(
120 | inputs: torch.Tensor,
121 | targets: torch.Tensor,
122 | alpha: float = 0.25,
123 | gamma: float = 2.0,
124 | reduction: str = "none",
125 | ) -> torch.Tensor:
126 | """
127 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
128 | Taken from
129 | https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
130 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
131 |
132 | Args:
133 | inputs: A float tensor of arbitrary shape.
134 | The predictions for each example.
135 | targets: A float tensor with the same shape as inputs. Stores the binary
136 | classification label for each element in inputs
137 | (0 for the negative class and 1 for the positive class).
138 | alpha: (optional) Weighting factor in range (0,1) to balance
139 | positive vs negative examples. Default = 0.25.
140 | gamma: Exponent of the modulating factor (1 - p_t) to
141 | balance easy vs hard examples.
142 | reduction: 'none' | 'mean' | 'sum'
143 | 'none': No reduction will be applied to the output.
144 | 'mean': The output will be averaged.
145 | 'sum': The output will be summed.
146 | Returns:
147 | Loss tensor with the reduction option applied.
148 | """
149 | inputs = inputs.float()
150 | targets = targets.float()
151 | p = torch.sigmoid(inputs)
152 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
153 | p_t = p * targets + (1 - p) * (1 - targets)
154 | loss = ce_loss * ((1 - p_t) ** gamma)
155 |
156 | if alpha >= 0:
157 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
158 | loss = alpha_t * loss
159 |
160 | if reduction == "mean":
161 | loss = loss.mean()
162 | elif reduction == "sum":
163 | loss = loss.sum()
164 |
165 | return loss
166 |
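# Numeric sanity check (sketch): a zero logit on a positive target gives p = 0.5 and
# ce = ln 2 ~= 0.693; with alpha = 0.25 and gamma = 2 the focal loss is
# 0.25 * (1 - 0.5)**2 * 0.693 ~= 0.043, showing how the modulating factor and alpha
# scale down the plain cross-entropy term.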
167 |
168 | class BufferList(nn.Module):
169 | """
170 | Similar to nn.ParameterList, but for buffers
171 |
172 | Taken from https://github.com/facebookresearch/detectron2/blob/master/detectron2/modeling/anchor_generator.py
173 | """
174 |
175 | def __init__(self, buffers):
176 | super().__init__()
177 | for i, buffer in enumerate(buffers):
178 | # Use non-persistent buffer so the values are not saved in checkpoint
179 | self.register_buffer(str(i), buffer, persistent=False)
180 |
181 | def __len__(self):
182 | return len(self._buffers)
183 |
184 | def __iter__(self):
185 | return iter(self._buffers.values())
186 |
187 |
188 | class PointGenerator(nn.Module):
189 | """
190 | A generator for temporal "points"
191 |
192 | max_seq_len can be much larger than the actual seq length
193 | """
194 | def __init__(
195 | self,
196 | max_seq_len, # max sequence length that the generator will buffer
197 | fpn_strides, # strides of fpn levels
198 | regression_range, # regression range (on feature grids)
199 | use_offset=False # if to align the points at grid centers
200 | ):
201 | super().__init__()
202 | # sanity check, # fpn levels and length divisible
203 | fpn_levels = len(fpn_strides)
204 | assert len(regression_range) == fpn_levels
205 |
206 | # save params
207 | self.max_seq_len = max_seq_len
208 | self.fpn_levels = fpn_levels
209 | self.fpn_strides = fpn_strides
210 | self.regression_range = regression_range
211 | self.use_offset = use_offset
212 |
213 | # generate all points and buffer the list
214 | self.buffer_points = self._generate_points()
215 |
216 | def _generate_points(self):
217 | points_list = []
218 | # loop over all points at each pyramid level
219 | for l, stride in enumerate(self.fpn_strides):
220 | reg_range = torch.as_tensor(
221 | self.regression_range[l], dtype=torch.float)
222 | fpn_stride = torch.as_tensor(stride, dtype=torch.float)
223 | points = torch.arange(0, self.max_seq_len, stride)[:, None]
224 | # add offset if necessary (not in our current model)
225 | if self.use_offset:
226 | points += 0.5 * stride
227 | # pad the time stamp with additional regression range / stride
228 | reg_range = reg_range[None].repeat(points.shape[0], 1)
229 | fpn_stride = fpn_stride[None].repeat(points.shape[0], 1)
230 | # size: T x 4 (ts, reg_range, stride)
231 | points_list.append(torch.cat((points, reg_range, fpn_stride), dim=1))
232 |
233 | return BufferList(points_list)
234 |
235 | def forward(self, feats):
236 | # feats will be a list of torch tensors
237 | assert len(feats) == self.fpn_levels
238 | pts_list = []
239 | feat_lens = [feat.shape[-1] for feat in feats]
240 | for feat_len, buffer_pts in zip(feat_lens, self.buffer_points):
241 | assert feat_len <= buffer_pts.shape[0], "Reached max buffer length for point generator"
242 | pts = buffer_pts[:feat_len, :]
243 | pts_list.append(pts)
244 | return pts_list
245 |
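# Sketch of the generated points for the single-level configuration used by NLQHead below
# (fpn_strides=[1], regression_range=[[0, 10000]]): every row of the buffered T x 4 tensor is
# [t, 0, 10000, 1] for t = 0, 1, ..., max_seq_len - 1, i.e. one point per feature-grid
# location with an effectively unbounded regression range.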
246 |
247 | # drop path: from https://github.com/facebookresearch/SlowFast/blob/master/slowfast/models/common.py
248 | class Scale(nn.Module):
249 | """
250 | Multiply the output regression range by a learnable constant value
251 | """
252 |
253 | def __init__(self, init_value=1.0):
254 | """
255 | init_value : initial value for the scalar
256 | """
257 | super().__init__()
258 | self.scale = nn.Parameter(
259 | torch.tensor(init_value, dtype=torch.float32),
260 | requires_grad=True
261 | )
262 |
263 | def forward(self, x):
264 | """
265 | input -> scale * input
266 | """
267 | return x * self.scale
268 |
269 |
270 | class MaskedConv1D(nn.Module):
271 | """
272 | Masked 1D convolution. Interface remains the same as Conv1d.
273 | Only supports a subset of 1D convolutions.
274 | """
275 |
276 | def __init__(
277 | self,
278 | in_channels,
279 | out_channels,
280 | kernel_size,
281 | stride=1,
282 | padding=0,
283 | dilation=1,
284 | groups=1,
285 | bias=True,
286 | padding_mode='zeros'
287 | ):
288 | super().__init__()
289 | # element must be aligned
290 | assert (kernel_size % 2 == 1) and (kernel_size // 2 == padding)
291 | # stride
292 | self.stride = stride
293 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
294 | stride, padding, dilation, groups, bias, padding_mode)
295 | # zero out the bias term if it exists
296 | if bias:
297 | torch.nn.init.constant_(self.conv.bias, 0.)
298 |
299 | def forward(self, x, mask):
300 | # x: batch size, feature channel, sequence length,
301 | # mask: batch size, 1, sequence length (bool)
302 | B, C, T = x.size()
303 | # input length must be divisible by stride
304 | assert T % self.stride == 0
305 |
306 | # conv
307 | out_conv = self.conv(x)
308 | # compute the mask
309 | if self.stride > 1:
310 | # downsample the mask using nearest neighbor
311 | out_mask = F.interpolate(
312 | mask.to(x.dtype),
313 | size=T // self.stride,
314 | mode='nearest'
315 | )
316 | else:
317 | # masking out the features
318 | out_mask = mask.to(x.dtype)
319 |
320 | # masking the output, stop grad to mask
321 | out_conv = out_conv * out_mask.detach()
322 | out_mask = out_mask.bool()
323 | return out_conv, out_mask
324 |
325 |
326 | class PtTransformerClsHead(nn.Module):
327 | """
328 | 1D Conv heads for classification
329 | """
330 |
331 | def __init__(
332 | self,
333 | input_dim,
334 | feat_dim,
335 | num_classes,
336 | prior_prob=0.01,
337 | num_layers=3,
338 | kernel_size=3,
339 | act_layer=nn.ReLU,
340 | with_ln=False,
341 | empty_cls=[]
342 | ):
343 | super().__init__()
344 | self.act = act_layer()
345 |
346 | # build the head
347 | self.head = nn.ModuleList()
348 | self.norm = nn.ModuleList()
349 | for idx in range(num_layers - 1):
350 | if idx == 0:
351 | in_dim = input_dim
352 | out_dim = feat_dim
353 | else:
354 | in_dim = feat_dim
355 | out_dim = feat_dim
356 | self.head.append(
357 | MaskedConv1D(
358 | in_dim, out_dim, kernel_size,
359 | stride=1,
360 | padding=kernel_size // 2,
361 | bias=(not with_ln)
362 | )
363 | )
364 | if with_ln:
365 | self.norm.append(
366 | LayerNorm(out_dim)
367 | )
368 | else:
369 | self.norm.append(nn.Identity())
370 |
371 | # classifier
372 | self.cls_head = MaskedConv1D(
373 | feat_dim, num_classes, kernel_size,
374 | stride=1, padding=kernel_size // 2
375 | )
376 |
377 | # use prior in model initialization to improve stability
378 | # this will overwrite other weight init
379 | bias_value = -(math.log((1 - prior_prob) / prior_prob))
380 | torch.nn.init.constant_(self.cls_head.conv.bias, bias_value)
381 |
382 | # a quick fix to empty categories:
383 | # the weights associated with these categories will remain unchanged
384 | # we set their bias to a large negative value to suppress their outputs
385 | if len(empty_cls) > 0:
386 | bias_value = -(math.log((1 - 1e-6) / 1e-6))
387 | for idx in empty_cls:
388 | torch.nn.init.constant_(self.cls_head.conv.bias[idx], bias_value)
389 |
390 | def forward(self, fpn_feats, fpn_masks):
391 | assert len(fpn_feats) == len(fpn_masks)
392 |
393 | # apply the classifier for each pyramid level
394 | out_logits = tuple()
395 | for _, (cur_feat, cur_mask) in enumerate(zip(fpn_feats, fpn_masks)):
396 | cur_out = cur_feat
397 | for idx in range(len(self.head)):
398 | cur_out, _ = self.head[idx](cur_out, cur_mask)
399 | cur_out = self.act(self.norm[idx](cur_out))
400 | cur_logits, _ = self.cls_head(cur_out, cur_mask)
401 | out_logits += (cur_logits,)
402 |
403 | # fpn_masks remains the same
404 | return out_logits
405 |
406 |
407 | class PtTransformerRegHead(nn.Module):
408 | """
409 | Shared 1D Conv heads for regression
410 | Similar logic to PtTransformerClsHead, implemented separately for clarity
411 | """
412 |
413 | def __init__(
414 | self,
415 | input_dim,
416 | feat_dim,
417 | fpn_levels,
418 | num_layers=3,
419 | kernel_size=3,
420 | act_layer=nn.ReLU,
421 | with_ln=False
422 | ):
423 | super().__init__()
424 | self.fpn_levels = fpn_levels
425 | self.act = act_layer()
426 |
427 | # build the conv head
428 | self.head = nn.ModuleList()
429 | self.norm = nn.ModuleList()
430 | for idx in range(num_layers - 1):
431 | if idx == 0:
432 | in_dim = input_dim
433 | out_dim = feat_dim
434 | else:
435 | in_dim = feat_dim
436 | out_dim = feat_dim
437 | self.head.append(
438 | MaskedConv1D(
439 | in_dim, out_dim, kernel_size,
440 | stride=1,
441 | padding=kernel_size // 2,
442 | bias=(not with_ln)
443 | )
444 | )
445 | if with_ln:
446 | self.norm.append(
447 | LayerNorm(out_dim)
448 | )
449 | else:
450 | self.norm.append(nn.Identity())
451 |
452 | self.scale = nn.ModuleList()
453 | for idx in range(fpn_levels):
454 | self.scale.append(Scale())
455 |
456 | # segment regression
457 | self.offset_head = MaskedConv1D(
458 | feat_dim, 2, kernel_size,
459 | stride=1, padding=kernel_size // 2
460 | )
461 |
462 | def forward(self, fpn_feats, fpn_masks):
463 | assert len(fpn_feats) == len(fpn_masks)
464 | assert len(fpn_feats) == self.fpn_levels
465 |
466 | # apply the regression head at each pyramid level
467 | out_offsets = tuple()
468 | for l, (cur_feat, cur_mask) in enumerate(zip(fpn_feats, fpn_masks)):
469 | cur_out = cur_feat
470 | for idx in range(len(self.head)):
471 | cur_out, _ = self.head[idx](cur_out, cur_mask)
472 | cur_out = self.act(self.norm[idx](cur_out))
473 | cur_offsets, _ = self.offset_head(cur_out, cur_mask)
474 | out_offsets += (F.relu(self.scale[l](cur_offsets)),)
475 |
476 | # fpn_masks remains the same
477 | return out_offsets
478 |
479 |
480 | class NLQHead(nn.Module):
481 | def __init__(self, in_dim, max_v_len):
482 | super().__init__()
483 | self.train_center_sample = 'radius'
484 | self.train_center_sample_radius = 1.5
485 | self.train_loss_weight = 1.0
486 | self.train_cls_prior_prob = 0.01
487 | self.train_label_smoothing = 0.1
488 |
489 | self.test_pre_nms_thresh = 0.001
490 | self.test_pre_nms_topk = 2000
491 | self.test_iou_threshold = 0.1
492 | self.test_min_score = 0.001
493 | self.test_max_seg_num = 5
494 | self.test_nms_method = 'soft'
495 | self.test_duration_thresh = 0.001
496 | self.test_multiclass_nms = True
497 | self.test_nms_sigma = 0.75
498 | self.test_voting_thresh = 0.9
499 |
500 | self.loss_normalizer = 200
501 | self.loss_normalizer_momentum = 0.9
502 |
503 | self.neck = FPNIdentity(
504 | in_channels=[in_dim],
505 | out_channel=in_dim,
506 | start_level=0,
507 | end_level=-1,
508 | with_ln=True
509 | )
510 | self.point_generator = PointGenerator(
511 | max_seq_len=1.0 * max_v_len,
512 | fpn_strides=[1],
513 | regression_range=[[0,10000]]
514 | )
515 | self.cls_head = PtTransformerClsHead(
516 | in_dim,
517 | feat_dim=384,
518 | num_classes=1,
519 | kernel_size=3,
520 | prior_prob=self.train_cls_prior_prob,
521 | with_ln=True,
522 | num_layers=3,
523 | empty_cls=[]
524 | )
525 | self.reg_head = PtTransformerRegHead(
526 | in_dim,
527 | feat_dim=384,
528 | fpn_levels=1,
529 | kernel_size=3,
530 | num_layers=3,
531 | with_ln=True
532 | )
533 |
534 | def forward(self, feat, mask, training=True, gt_segments=None, gt_labels=None, v_lens=None):
535 | """
536 | feat: (B, D, T)
537 | mask: (B, 1, T)
538 | """
539 | masks = [mask]
540 | feats = [feat]
541 |
542 | fpn_feats, fpn_masks = self.neck(feats, masks)
543 |
544 | points = self.point_generator(fpn_feats)
545 | out_cls_logits = [x.permute(0, 2, 1) for x in self.cls_head(fpn_feats, fpn_masks)] # (B, T, #cls+1)
546 | out_offsets = [x.permute(0, 2, 1) for x in self.reg_head(fpn_feats, fpn_masks)] # (B, T, #cls*2)
547 | fpn_masks = [x.squeeze(1) for x in fpn_masks] # (B, T)
548 |
549 | # return loss during training
550 | if training:
551 | gt_cls_labels, gt_offsets = self.label_points(points, gt_segments, gt_labels, 1)
552 |
553 | # compute the loss and return
554 | losses = self.losses(
555 | fpn_masks,
556 | out_cls_logits, out_offsets,
557 | gt_cls_labels, gt_offsets
558 | )
559 | return losses
560 | else:
561 | # decode the actions (sigmoid / stride, etc)
562 | results = self.inference(points, fpn_masks, out_cls_logits, out_offsets, 1, v_lens)
563 |
564 | return results
565 |
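# Call sketch (shapes only, inferred from the forward pass above): with feat of shape (B, D, T)
# and mask of shape (B, 1, T),
#   training=True  -> dict with 'cls_loss', 'reg_loss', 'final_loss'
#   training=False -> one dict per video with 'segments', 'scores', 'labels'
#                     (segments are converted to actual timestamps in postprocessing)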
566 | @torch.no_grad()
567 | def label_points(self, points, gt_segments, gt_labels, num_classes):
568 | # concat points on all fpn levels List[T x 4] -> F T x 4
569 | # This is shared for all samples in the mini-batch
570 | num_levels = len(points)
571 | concat_points = torch.cat(points, dim=0)
572 |
573 | gt_cls, gt_offset = [], []
574 | # loop over each video sample
575 | for gt_segment, gt_label in zip(gt_segments, gt_labels):
576 | assert len(gt_segment) == len(gt_label), (gt_segment, gt_label)
577 | cls_targets, reg_targets = self.label_points_single_video(
578 | concat_points, gt_segment, gt_label, num_classes
579 | )
580 | # "cls_targets: " #points, num_classes
581 | # "reg_targets: " #points, 2
582 | # append to list (len = # images, each of size FT x C)
583 | gt_cls.append(cls_targets)
584 | gt_offset.append(reg_targets)
585 |
586 | return gt_cls, gt_offset
587 |
588 | @torch.no_grad()
589 | def label_points_single_video(self, concat_points, gt_segment, gt_label, num_classes):
590 | # concat_points : F T x 4 (t, regression range, stride)
591 | # gt_segment : N (#Events) x 2
592 | # gt_label : N (#Events) x 1
593 | num_pts = concat_points.shape[0]
594 | num_gts = gt_segment.shape[0]
595 |
596 | # corner case where current sample does not have actions
597 | if num_gts == 0:
598 | cls_targets = gt_segment.new_full((num_pts, num_classes), 0)
599 | reg_targets = gt_segment.new_zeros((num_pts, 2))
600 | return cls_targets, reg_targets
601 |
602 | # compute the lengths of all segments -> F T x N
603 | lens = gt_segment[:, 1] - gt_segment[:, 0]
604 | lens = lens[None, :].repeat(num_pts, 1)
605 |
606 | # compute the distance of every point to each segment boundary
607 | # auto broadcasting for all reg target-> F T x N x 2
608 | gt_segs = gt_segment[None].expand(num_pts, num_gts, 2)
609 | left = concat_points[:, 0, None] - gt_segs[:, :, 0]
610 | right = gt_segs[:, :, 1] - concat_points[:, 0, None]
611 | reg_targets = torch.stack((left, right), dim=-1)
612 |
613 | if self.train_center_sample == 'radius':
614 | # center of all segments F T x N
615 | center_pts = 0.5 * (gt_segs[:, :, 0] + gt_segs[:, :, 1])
616 | # center sampling based on stride radius
617 | # compute the new boundaries:
618 | # concat_points[:, 3] stores the stride
619 | t_mins = \
620 | center_pts - concat_points[:, 3, None] * self.train_center_sample_radius
621 | t_maxs = \
622 | center_pts + concat_points[:, 3, None] * self.train_center_sample_radius
623 |
624 | # prevent t_mins / maxs from over-running the action boundary
625 | # left: torch.maximum(t_mins, gt_segs[:, :, 0])
626 | # right: torch.minimum(t_maxs, gt_segs[:, :, 1])
627 | # F T x N (distance to the new boundary)
628 | cb_dist_left = concat_points[:, 0, None] \
629 | - torch.maximum(t_mins, gt_segs[:, :, 0])
630 | cb_dist_right = torch.minimum(t_maxs, gt_segs[:, :, 1]) \
631 | - concat_points[:, 0, None]
632 | # F T x N x 2
633 | center_seg = torch.stack(
634 | (cb_dist_left, cb_dist_right), -1)
635 |
636 | # F T x N
637 | inside_gt_seg_mask = center_seg.min(-1)[0] > 0
638 | else:
639 | # point lies inside a gt action segment
640 | inside_gt_seg_mask = reg_targets.min(-1)[0] > 0
641 |
642 | # limit the regression range for each location
643 | max_regress_distance = reg_targets.max(-1)[0]
644 |
645 | # F T x N
646 | inside_regress_range = torch.logical_and(
647 | (max_regress_distance >= concat_points[:, 1, None]),
648 | (max_regress_distance <= concat_points[:, 2, None])
649 | )
650 |
651 | # limit the regression range for each location and inside the center radius
652 | lens.masked_fill_(inside_gt_seg_mask == 0, float('inf'))
653 | lens.masked_fill_(inside_regress_range == 0, float('inf'))
654 |
655 | # if there are still more than one ground-truths for one point
656 | # pick the ground-truth with the shortest duration for the point (easiest to regress)
657 | # corner case: multiple actions with very similar durations (e.g., THUMOS14)
658 | # make sure that each point can only map with at most one ground-truth
659 | # F T x N -> F T
660 | min_len, min_len_inds = lens.min(dim=1)
661 | min_len_mask = torch.logical_and(
662 | (lens <= (min_len[:, None] + 1e-3)), (lens < float('inf'))
663 | ).to(reg_targets.dtype)
664 |
665 | # cls_targets: F T x C; reg_targets F T x 2
666 | # gt_label_one_hot = F.one_hot(gt_label, num_classes).to(reg_targets.dtype)
667 | gt_label_one_hot = gt_label.to(reg_targets.dtype)
668 | cls_targets = min_len_mask @ gt_label_one_hot
669 | # to prevent multiple GT actions with the same label and boundaries
670 | cls_targets.clamp_(min=0.0, max=1.0)
671 |
672 | # OK to use min_len_inds
673 | reg_targets = reg_targets[range(num_pts), min_len_inds]
674 | # normalization based on stride
675 | reg_targets /= concat_points[:, 3, None]
676 | return cls_targets, reg_targets
677 |
678 | def losses(
679 | self, fpn_masks,
680 | out_cls_logits, out_offsets,
681 | gt_cls_labels, gt_offsets
682 | ):
683 | # fpn_masks, out_*: F (List) [B, T_i, C]
684 | # gt_* : B (list) [F T, C]
685 | # fpn_masks -> (B, FT)
686 | valid_mask = torch.cat(fpn_masks, dim=1)
687 |
688 | # 1. classification loss
689 | # stack the list -> (B, FT) -> (# Valid, )
690 | gt_cls = torch.stack(gt_cls_labels)
691 | pos_mask = torch.logical_and((gt_cls.sum(-1) > 0), valid_mask)
692 |
693 | # update the loss normalizer
694 | num_pos = pos_mask.sum().item()
695 | self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
696 | 1 - self.loss_normalizer_momentum) * max(num_pos, 1)
697 |
698 | # gt_cls is already one hot encoded now, simply masking out
699 | gt_target = gt_cls[valid_mask]
700 |
701 | num_classes = gt_target.shape[-1]
702 |
703 | # optional label smoothing
704 | gt_target *= 1 - self.train_label_smoothing
705 | gt_target += self.train_label_smoothing / (num_classes + 1)
706 |
707 | # focal loss
708 | cls_loss = sigmoid_focal_loss(
709 | torch.cat(out_cls_logits, dim=1)[valid_mask],
710 | gt_target,
711 | reduction='sum'
712 | )
713 | cls_loss /= self.loss_normalizer
714 |
715 | # 2. regression using IoU/GIoU loss (defined on positive samples)
716 | # cat the predicted offsets -> (B, FT, 2 (xC)) -> # (#Pos, 2 (xC))
717 | pred_offsets = torch.cat(out_offsets, dim=1)[pos_mask]
718 | gt_offsets = torch.stack(gt_offsets)[pos_mask]
719 | if num_pos == 0:
720 | reg_loss = 0 * pred_offsets.sum()
721 | else:
722 | # giou loss defined on positive samples
723 | reg_loss = ctr_diou_loss_1d(
724 | pred_offsets,
725 | gt_offsets,
726 | reduction='sum'
727 | )
728 | reg_loss /= self.loss_normalizer
729 |
730 | if self.train_loss_weight > 0:
731 | loss_weight = self.train_loss_weight
732 | else:
733 | loss_weight = cls_loss.detach() / max(reg_loss.item(), 0.01)
734 |
735 | # return a dict of losses
736 | final_loss = cls_loss + reg_loss * loss_weight
737 | return {'cls_loss': cls_loss,
738 | 'reg_loss': reg_loss,
739 | 'final_loss': final_loss}
740 |
741 | @torch.no_grad()
742 | def inference(
743 | self,
744 | points, fpn_masks,
745 | out_cls_logits, out_offsets, num_classes, v_lens
746 | ):
747 | # video_list B (list) [dict]
748 | # points F (list) [T_i, 4]
749 | # fpn_masks, out_*: F (List) [B, T_i, C]
750 | results = []
751 |
752 | # step 2: run inference on each video and gather the results
753 | # up to this point, all results use timestamps defined on the feature grid
754 | for idx, vlen in enumerate(v_lens):
755 | # gather per-video outputs
756 | cls_logits_per_vid = [x[idx] for x in out_cls_logits]
757 | offsets_per_vid = [x[idx] for x in out_offsets]
758 | fpn_masks_per_vid = [x[idx] for x in fpn_masks]
759 | # inference on a single video (should always be the case)
760 | results_per_vid = self.inference_single_video(
761 | points, fpn_masks_per_vid,
762 | cls_logits_per_vid, offsets_per_vid, num_classes,
763 | )
764 | # pass through video meta info
765 | results_per_vid['duration'] = vlen
766 | results.append(results_per_vid)
767 |
768 | # step 3: postprocessing
769 | results = self.postprocessing(results)
770 |
771 | return results
772 |
773 | @torch.no_grad()
774 | def inference_single_video(
775 | self,
776 | points,
777 | fpn_masks,
778 | out_cls_logits,
779 | out_offsets,
780 | num_classes,
781 | ):
782 | # points F (list) [T_i, 4]
783 | # fpn_masks, out_*: F (List) [T_i, C]
784 | segs_all = []
785 | scores_all = []
786 | cls_idxs_all = []
787 |
788 | # loop over fpn levels
789 | for cls_i, offsets_i, pts_i, mask_i in zip(
790 | out_cls_logits, out_offsets, points, fpn_masks
791 | ):
792 | # sigmoid normalization for output logits
793 | pred_prob = (cls_i.sigmoid() * mask_i.unsqueeze(-1)).flatten()
794 |
795 | # Apply filtering to make NMS faster following detectron2
796 | # 1. Keep seg with confidence score > a threshold
797 | keep_idxs1 = (pred_prob > self.test_pre_nms_thresh)
798 | pred_prob = pred_prob[keep_idxs1]
799 | topk_idxs = keep_idxs1.nonzero(as_tuple=True)[0]
800 |
801 | # 2. Keep only the top-k scoring segments
802 | num_topk = min(self.test_pre_nms_topk, topk_idxs.size(0))
803 | pred_prob, idxs = pred_prob.sort(descending=True)
804 | pred_prob = pred_prob[:num_topk].clone()
805 | topk_idxs = topk_idxs[idxs[:num_topk]].clone()
806 |
807 | # fix a warning in pytorch 1.9
808 | pt_idxs = torch.div(
809 | topk_idxs, num_classes, rounding_mode='floor'
810 | )
811 | cls_idxs = torch.fmod(topk_idxs, num_classes)
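            # topk_idxs index the flattened (T_i * C,) score vector, so
            # pt_idx = idx // C and cls_idx = idx % C recover the point / class,
            # e.g. (illustrative) idx = 7, C = 3  ->  pt_idx = 2, cls_idx = 1.
            # In a single-class setting (C = 1) all cls_idxs are zero.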
812 |
813 | # 3. gather predicted offsets
814 | offsets = offsets_i[pt_idxs]
815 | pts = pts_i[pt_idxs]
816 |
817 | # 4. compute predicted segments (denorm by stride for output offsets)
818 | seg_left = pts[:, 0] - offsets[:, 0] * pts[:, 3]
819 | seg_right = pts[:, 0] + offsets[:, 1] * pts[:, 3]
820 | pred_segs = torch.stack((seg_left, seg_right), -1)
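            # e.g. (illustrative numbers) a point at t = 12 with stride 4 and
            # predicted offsets (1.5, 2.0) yields the segment
            #   [12 - 1.5 * 4, 12 + 2.0 * 4] = [6, 20]   (in feature-grid units)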
821 |
822 | # 5. Keep seg with duration > a threshold (relative to feature grids)
823 | seg_areas = seg_right - seg_left
824 | keep_idxs2 = seg_areas > self.test_duration_thresh
825 |
826 | # *_all : N (filtered # of segments) x 2 / 1
827 | segs_all.append(pred_segs[keep_idxs2])
828 | scores_all.append(pred_prob[keep_idxs2])
829 | cls_idxs_all.append(cls_idxs[keep_idxs2])
830 |
831 | # cat along the FPN levels (F N_i, C)
832 | segs_all, scores_all, cls_idxs_all = [
833 | torch.cat(x) for x in [segs_all, scores_all, cls_idxs_all]
834 | ]
835 | results = {'segments': segs_all,
836 | 'scores': scores_all,
837 | 'labels': cls_idxs_all}
838 |
839 | return results
840 |
841 | @torch.no_grad()
842 | def postprocessing(self, results):
843 | # input : list of dictionary items
844 | # (1) push to CPU; (2) NMS; (3) convert to actual time stamps
845 | processed_results = []
846 | fps = 30
847 | stride = nframes = 16.043
848 | for results_per_vid in results:
849 | # unpack the meta info
850 | vlen = results_per_vid['duration']
851 | # 1: unpack the results and move to CPU
852 | segs = results_per_vid['segments'].detach().cpu()
853 | scores = results_per_vid['scores'].detach().cpu()
854 | labels = results_per_vid['labels'].detach().cpu()
855 | if self.test_nms_method != 'none':
856 | # 2: batched nms (only implemented on CPU)
857 | segs, scores, labels = batched_nms(
858 | segs, scores, labels,
859 | self.test_iou_threshold,
860 | self.test_min_score,
861 | self.test_max_seg_num,
862 | use_soft_nms=(self.test_nms_method == 'soft'),
863 | multiclass=self.test_multiclass_nms,
864 | sigma=self.test_nms_sigma,
865 | voting_thresh=self.test_voting_thresh
866 | )
867 | # 3: convert from feature grids to seconds
868 | if segs.shape[0] > 0:
869 | # segs_sec = (segs * stride + 0.5 * nframes) / fps
870 | segs_sec = segs * stride / fps # use_offset=False
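                # e.g. a feature-grid segment [6, 20] with stride = 16.043 and fps = 30
                # maps to [6 * 16.043 / 30, 20 * 16.043 / 30] ~ [3.21, 10.70] seconds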
871 | # truncate all boundaries within [0, duration]
872 | segs_sec[segs_sec <= 0.0] *= 0.0
873 | segs_sec[segs_sec >= vlen] = segs_sec[segs_sec >= vlen] * 0.0 + vlen
874 | # else: # FIXME: don't know why but Flan-T5-L produces segs.shape[0] == 0
875 | # segs_sec = torch.zeros((1, 2))
876 | # scores = torch.zeros(1)
877 | # labels = torch.zeros(1, dtype=torch.int64)
878 | # 4: repack the results
879 | processed_results.append({
880 | 'segments': segs_sec,
881 | # 'segments_feat': segs,
882 | 'scores': scores,
883 | 'labels': labels
884 | })
885 |
886 | return processed_results
887 |
888 |
889 | class NMSop(torch.autograd.Function):
890 | @staticmethod
891 | def forward(
892 | ctx, segs, scores, cls_idxs,
893 | iou_threshold, min_score, max_num
894 | ):
895 | # vanilla nms will not change the score, so we can filter segs first
896 | is_filtering_by_score = (min_score > 0)
897 | if is_filtering_by_score:
898 | valid_mask = scores > min_score
899 | segs, scores = segs[valid_mask], scores[valid_mask]
900 | cls_idxs = cls_idxs[valid_mask]
901 | valid_inds = torch.nonzero(
902 | valid_mask, as_tuple=False).squeeze(dim=1)
903 |
904 | # nms op; returns inds sorted in descending score order
905 | inds = nms_1d_cpu.nms(
906 | segs.contiguous().cpu(),
907 | scores.contiguous().cpu(),
908 | iou_threshold=float(iou_threshold))
909 | # cap by max number
910 | if max_num > 0:
911 | inds = inds[:min(max_num, len(inds))]
912 | # return the sorted segs / scores
913 | sorted_segs = segs[inds]
914 | sorted_scores = scores[inds]
915 | sorted_cls_idxs = cls_idxs[inds]
916 | return sorted_segs.clone(), sorted_scores.clone(), sorted_cls_idxs.clone()
917 |
918 |
919 | class SoftNMSop(torch.autograd.Function):
920 | @staticmethod
921 | def forward(
922 | ctx, segs, scores, cls_idxs,
923 | iou_threshold, sigma, min_score, method, max_num
924 | ):
925 | # pre-allocate memory for sorted results
926 | dets = segs.new_empty((segs.size(0), 3), device='cpu')
927 | # softnms op; returns dets holding the sorted segs / scores
928 | inds = nms_1d_cpu.softnms(
929 | segs.cpu(),
930 | scores.cpu(),
931 | dets.cpu(),
932 | iou_threshold=float(iou_threshold),
933 | sigma=float(sigma),
934 | min_score=float(min_score),
935 | method=int(method))
936 | # cap by max number
937 | if max_num > 0:
938 | n_segs = min(len(inds), max_num)
939 | else:
940 | n_segs = len(inds)
941 | sorted_segs = dets[:n_segs, :2]
942 | sorted_scores = dets[:n_segs, 2]
943 | sorted_cls_idxs = cls_idxs[inds]
944 | sorted_cls_idxs = sorted_cls_idxs[:n_segs]
945 | return sorted_segs.clone(), sorted_scores.clone(), sorted_cls_idxs.clone()
946 |
947 |
948 | def seg_voting(nms_segs, all_segs, all_scores, iou_threshold, score_offset=1.5):
949 | """
950 | blur the localization results by incorporating neighboring (side) segments;
951 | this is known as bounding-box voting in the object detection literature.
952 | it slightly boosts performance around iou_threshold.
953 | """
954 |
955 | # *_segs : N_i x 2, all_scores: N,
956 | # apply offset
957 | offset_scores = all_scores + score_offset
958 |
959 | # compute overlap between nms segs and all segs
960 | # construct the distance matrix of # N_nms x # N_all
961 | num_nms_segs, num_all_segs = nms_segs.shape[0], all_segs.shape[0]
962 | ex_nms_segs = nms_segs[:, None].expand(num_nms_segs, num_all_segs, 2)
963 | ex_all_segs = all_segs[None, :].expand(num_nms_segs, num_all_segs, 2)
964 |
965 | # compute intersection
966 | left = torch.maximum(ex_nms_segs[:, :, 0], ex_all_segs[:, :, 0])
967 | right = torch.minimum(ex_nms_segs[:, :, 1], ex_all_segs[:, :, 1])
968 | inter = (right-left).clamp(min=0)
969 |
970 | # lens of all segments
971 | nms_seg_lens = ex_nms_segs[:, :, 1] - ex_nms_segs[:, :, 0]
972 | all_seg_lens = ex_all_segs[:, :, 1] - ex_all_segs[:, :, 0]
973 |
974 | # iou
975 | iou = inter / (nms_seg_lens + all_seg_lens - inter)
976 |
977 | # get neighbors (# N_nms x # N_all) / weights
978 | seg_weights = (iou >= iou_threshold).to(all_scores.dtype) * all_scores[None, :] * iou
979 | seg_weights /= torch.sum(seg_weights, dim=1, keepdim=True)
980 | refined_segs = seg_weights @ all_segs
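    # In effect each kept segment is replaced by a weighted average of its
    # overlapping candidates:
    #   refined_i = sum_j w_ij * seg_j / sum_j w_ij,
    #   where w_ij = 1[iou_ij >= iou_threshold] * score_j * iou_ij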
981 |
982 | return refined_segs
983 |
984 |
985 | def batched_nms(
986 | segs,
987 | scores,
988 | cls_idxs,
989 | iou_threshold,
990 | min_score,
991 | max_seg_num,
992 | use_soft_nms=True,
993 | multiclass=True,
994 | sigma=0.5,
995 | voting_thresh=0.75,
996 | ):
997 | # Based on the Detectron2 implementation
998 | num_segs = segs.shape[0]
999 | # corner case, no prediction outputs
1000 | if num_segs == 0:
1001 | return torch.zeros([0, 2]),\
1002 | torch.zeros([0,]),\
1003 | torch.zeros([0,], dtype=cls_idxs.dtype)
1004 |
1005 | if multiclass:
1006 | # multiclass nms: apply nms on each class independently
1007 | new_segs, new_scores, new_cls_idxs = [], [], []
1008 | for class_id in torch.unique(cls_idxs):
1009 | curr_indices = torch.where(cls_idxs == class_id)[0]
1010 | # soft_nms vs nms
1011 | if use_soft_nms:
1012 | sorted_segs, sorted_scores, sorted_cls_idxs = SoftNMSop.apply(
1013 | segs[curr_indices],
1014 | scores[curr_indices],
1015 | cls_idxs[curr_indices],
1016 | iou_threshold,
1017 | sigma,
1018 | min_score,
1019 | 2,
1020 | max_seg_num
1021 | )
1022 | else:
1023 | sorted_segs, sorted_scores, sorted_cls_idxs = NMSop.apply(
1024 | segs[curr_indices],
1025 | scores[curr_indices],
1026 | cls_idxs[curr_indices],
1027 | iou_threshold,
1028 | min_score,
1029 | max_seg_num
1030 | )
1031 | # disable seg voting for multiclass nms: not enough segs per class
1032 |
1033 | # fill in the class index
1034 | new_segs.append(sorted_segs)
1035 | new_scores.append(sorted_scores)
1036 | new_cls_idxs.append(sorted_cls_idxs)
1037 |
1038 | # cat the results
1039 | new_segs = torch.cat(new_segs)
1040 | new_scores = torch.cat(new_scores)
1041 | new_cls_idxs = torch.cat(new_cls_idxs)
1042 |
1043 | else:
1044 | # class agnostic
1045 | if use_soft_nms:
1046 | new_segs, new_scores, new_cls_idxs = SoftNMSop.apply(
1047 | segs, scores, cls_idxs, iou_threshold,
1048 | sigma, min_score, 2, max_seg_num
1049 | )
1050 | else:
1051 | new_segs, new_scores, new_cls_idxs = NMSop.apply(
1052 | segs, scores, cls_idxs, iou_threshold,
1053 | min_score, max_seg_num
1054 | )
1055 | # seg voting
1056 | if voting_thresh > 0:
1057 | new_segs = seg_voting(
1058 | new_segs,
1059 | segs,
1060 | scores,
1061 | voting_thresh
1062 | )
1063 |
1064 | # sort based on scores and return
1065 | # truncate the results based on max_seg_num
1066 | _, idxs = new_scores.sort(descending=True)
1067 | max_seg_num = min(max_seg_num, new_segs.shape[0])
1068 | # needed for multiclass NMS
1069 | new_segs = new_segs[idxs[:max_seg_num]]
1070 | new_scores = new_scores[idxs[:max_seg_num]]
1071 | new_cls_idxs = new_cls_idxs[idxs[:max_seg_num]]
1072 | return new_segs, new_scores, new_cls_idxs
1073 |
1074 |
1075 | class FPNIdentity(nn.Module):
1076 | def __init__(
1077 | self,
1078 | in_channels, # input feature channels, len(in_channels) = # levels
1079 | out_channel, # output feature channel
1080 | start_level=0, # start fpn level
1081 | end_level=-1, # end fpn level
1082 | with_ln=True # whether to apply layer norm at the end
1083 | ):
1084 | super().__init__()
1085 |
1086 | self.in_channels = in_channels
1087 | self.out_channel = out_channel
1088 |
1089 | self.start_level = start_level
1090 | if end_level == -1:
1091 | self.end_level = len(in_channels)
1092 | else:
1093 | self.end_level = end_level
1094 | assert self.end_level <= len(in_channels)
1095 | assert (self.start_level >= 0) and (self.start_level < self.end_level)
1096 |
1097 | self.fpn_norms = nn.ModuleList()
1098 | for i in range(self.start_level, self.end_level):
1099 | # check feat dims
1100 | assert self.in_channels[i] == self.out_channel
1101 | # layer norm for order (B C T)
1102 | if with_ln:
1103 | fpn_norm = LayerNorm(out_channel)
1104 | else:
1105 | fpn_norm = nn.Identity()
1106 | self.fpn_norms.append(fpn_norm)
1107 |
1108 | def forward(self, inputs, fpn_masks):
1109 | # inputs must be a list / tuple
1110 | assert len(inputs) == len(self.in_channels), (len(inputs) , len(self.in_channels))
1111 | assert len(fpn_masks) == len(self.in_channels), (len(fpn_masks) , len(self.in_channels))
1112 |
1113 | # apply norms; fpn_masks remain unchanged (identity mapping)
1114 | fpn_feats = tuple()
1115 | new_fpn_masks = tuple()
1116 | for i in range(len(self.fpn_norms)):
1117 | x = self.fpn_norms[i](inputs[i + self.start_level])
1118 | fpn_feats += (x, )
1119 | new_fpn_masks += (fpn_masks[i + self.start_level], )
1120 |
1121 | return fpn_feats, new_fpn_masks
1122 |
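# Minimal usage sketch (hypothetical shapes; each level must already have
# out_channel channels, since FPNIdentity only normalizes and passes features through):
#   fpn = FPNIdentity(in_channels=[256, 256, 256], out_channel=256, with_ln=True)
#   feats = [torch.randn(2, 256, 192 // 2 ** i) for i in range(3)]   # (B, C, T_i)
#   masks = [torch.ones(2, 1, 192 // 2 ** i, dtype=torch.bool) for i in range(3)]
#   fpn_feats, fpn_masks = fpn(feats, masks)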
--------------------------------------------------------------------------------