├── .gitignore
├── EXPTS.md
├── LICENSE
├── README.md
├── code
│   ├── README.md
│   ├── _init_stuff.py
│   ├── contrastive_sampling.py
│   ├── dat_loader_simple.py
│   ├── eval_fn_corr.py
│   ├── eval_vsrl_corr.py
│   ├── extended_config.py
│   ├── main_dist.py
│   ├── mdl_base.py
│   ├── mdl_conc_sep.py
│   ├── mdl_conc_single.py
│   ├── mdl_selector.py
│   ├── mdl_vog.py
│   ├── transformer_code.py
│   └── visualizer.py
├── conda_env_vog.yml
├── configs
│   ├── anet_srl_cfg.yml
│   └── create_asrl_cfg.yml
├── data
│   ├── README.md
│   └── download_data.sh
├── dcode
│   ├── README.md
│   ├── asrl_creator.py
│   ├── dataset_stats.py
│   ├── download_asrl_parent_ann.sh
│   ├── preproc_anet_files.py
│   ├── preproc_ds_files.py
│   ├── process_gt_props.py
│   └── sem_role_labeller.py
├── media
│   ├── Intro_fig.png
│   ├── contrastive_examples.png
│   ├── contrastive_samples.png
│   ├── model_fig.png
│   └── tempora_spatial_concat.png
├── notebooks
│   └── data_stats.ipynb
└── utils
    ├── README.md
    ├── __init__.py
    ├── box_utils.py
    ├── mdl_srl_utils.py
    └── trn_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | cache_dir
2 | data/
3 | !data/README.md
4 | __pycache__/
5 | tmp
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Arka Sadhu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # vognet-pytorch
2 | [License](https://github.com/TheShadow29/vognet-pytorch/blob/master/LICENSE)
3 | [Python](https://www.python.org/)
4 |
5 | [arXiv:2003.10606](https://arxiv.org/abs/2003.10606)
6 |
7 |
8 | [**Video Object Grounding using Semantic Roles in Language Description**](https://arxiv.org/abs/2003.10606)
9 | [Arka Sadhu](https://theshadow29.github.io/), [Kan Chen](https://kanchen.info/), [Ram Nevatia](https://sites.usc.edu/iris-cvlab/professor-ram-nevatia/)
10 | [CVPR 2020](http://cvpr2020.thecvf.com/)
11 |
12 | **Video Object Grounding (VOG)** is the task of localizing objects in a video that are referred to in a query sentence description.
13 | We elevate the role of object relations via spatial and temporal concatenation of contrastive examples sampled from a newly contributed dataset called ActivityNet-SRL (ASRL).
14 |
15 | ![Intro figure](./media/Intro_fig.png)
16 |
17 | This repository includes:
18 | 1. code to create the ActivityNet-SRL dataset under [`dcode/`](./dcode)
19 | 1. code to run all the experiments provided in the paper under [`code/`](./code)
20 | 1. links to all trained models in the paper along with their log files, provided in [EXPTS.md](./EXPTS.md) to foster reproducibility of our results
21 |
22 | Code has been modularized from its initial implementation.
23 | It should be easy to extend the code for other datasets by inheriting relevant modules.
24 |
25 | ## Installation
26 | Requirements:
27 | - python>=3.6
28 | - pytorch==1.1 (should also work with pytorch>=1.3, but this has not been tested)
29 |
30 | To use the same environment, you can use `conda` and the provided environment file `conda_env_vog.yml`. Please refer to [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for details on installing `conda`.
31 |
32 | ```
33 | MINICONDA_ROOT=[to your Miniconda/Anaconda root directory]
34 | conda env create -f conda_env_vog.yml --prefix $MINICONDA_ROOT/envs/vog_pyt
35 | conda activate vog_pyt
36 | ```
37 |
38 | ## Quick Start
39 | 1. Clone repo:
40 | ```
41 | git clone https://github.com/TheShadow29/vognet-pytorch.git
42 | cd vognet-pytorch
43 | export ROOT=$(pwd)
44 | ```
45 | 1. Download the data (~530 GB) (see [DATA_README](./data/README.md) for more details)
46 | ```
47 | cd $ROOT/data
48 | bash download_data.sh all [data_folder]
49 | ```
50 | 1. Train Models
51 | ```
52 | cd $ROOT
53 | python code/main_dist.py "spat_vog_gt5" --ds.exp_setting='gt5' --mdl.name='vog' --mdl.obj_tx.use_rel=True --mdl.mul_tx.use_rel=True --train.prob_thresh=0.2 --train.bs=4 --train.epochs=10 --train.lr=1e-4
54 | ```
55 | ## Data Preparation
56 | If you just want to use ASRL, you can refer to [DATA_README](./data/README.md). It contains direct links to download ASRL.
57 |
58 | If, instead, you want to recreate ASRL from ActivityNet Entities and ActivityNet Captions, or perhaps want to extend it to a new dataset, refer to [DATA_PREP_README.md](./dcode/README.md).
59 |
60 | ## Training
61 | Basic usage is `python code/main_dist.py "experiment_name" --arg1=val1 --arg2=val2`, where the available arguments (arg1, arg2, ...) can be found in `configs/anet_srl_cfg.yml`.
62 |
63 | The hierarchical structure of the `yml` config is also supported using `.`.
64 | For example, if you want to change `mdl.name`, which looks like
65 | ```
66 | mdl:
67 |   name: xyz
68 | ```
69 | you can pass `--mdl.name='abc'`.
70 |
71 | As an example, training `VOGNet` using `spat` strategy with `gt5` setting:
72 |
73 | ```
74 | python code/main_dist.py "spat_vog_gt5" --ds.exp_setting='gt5' --mdl.name='vog' --mdl.obj_tx.use_rel=True --mdl.mul_tx.use_rel=True --train.prob_thresh=0.2 --train.bs=4 --train.epochs=10 --train.lr=1e-4
75 | ```
76 |
77 | You can change default settings in `configs/anet_srl_cfg.yml` directly as well.
78 |
79 | See [EXPTS.md](./EXPTS.md) for command-line instructions for all experiments.
80 |
81 | ## Logging
82 |
83 | Logs are stored inside the `tmp/` directory. When you run the code with `$exp_name`, the following are stored (see the example layout after this list):
84 | - `txt_logs/$exp_name.txt`: the config used and the training/validation losses after every epoch.
85 | - `models/$exp_name.pth`: the model, optimizer, scheduler, accuracy, and the number of epochs and iterations completed. Only the best model up to the current epoch is stored.
86 | - `ext_logs/$exp_name.txt`: uses Python's `logging` module to store the printed `logger.debug` outputs. Mainly used for debugging.
87 | - `predictions`: the validation outputs of the current best model.
88 |
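As an illustration, assuming the experiment name `spat_vog_gt5` used in the commands above, the layout under `tmp/` would look roughly like:

```
tmp/
├── txt_logs/spat_vog_gt5.txt
├── models/spat_vog_gt5.pth
├── ext_logs/spat_vog_gt5.txt
└── predictions/spat_vog_gt5/valid_0.pkl
```
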
89 | ## Evaluation
90 | To evaluate a model, you first need to load it and then pass `--only_val=True`.
91 |
92 | As an example, to validate the `VOGNet` model trained with the `spat` strategy in the `gt5` setting:
93 | ```
94 | python code/main_dist.py "spat_vog_gt5_valid" --train.resume=True --train.resume_path='./tmp/models/spat_vog_gt5.pth' --mdl.name='vog' --mdl.obj_tx.use_rel=True --mdl.mul_tx.use_rel=True --only_val=True --train.prob_thresh=0.2
95 | ```
96 |
97 | This will create `./tmp/predictions/spat_vog_gt5_valid/valid_0.pkl` and print out the metrics.
98 |
99 | You can also evaluate this file using `code/eval_fn_corr.py`. This assumes the `valid_0.pkl` file has already been generated.
100 |
101 | ```
102 | python code/eval_fn_corr.py --pred_file='./tmp/predictions/spat_vog_gt5_valid/valid_0.pkl' --split_type='valid' --train.prob_thresh=0.2
103 | ```
104 |
105 | For evaluating on `test`, simply use `--split_type='test'`.
106 |
107 | If you are using your own code but just want to use the evaluation, you must save your output in the following format:
108 | ```
109 | [
110 | {
111 | 'idx_sent': id of the input query
112 | 'pred_boxes': # num_srls x num_vids x num_frames x 5d prop boxes
113 | 'pred_scores': # num_srls x num_vids x num_frames (between 0-1)
114 | 'pred_cmp': # num_srls x num_frames (only required for sep). Basically, which video to choose
115 | 'cmp_msk': 1/0 mask indicating whether any videos were padded and hence not considered
116 | 'targ_cmp': index of the target video. This is stored in the prediction and not in the ground-truth, since we shuffle the video list at runtime
117 | },
118 | ...
119 | ]
120 | ```
121 |
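Below is a minimal sketch of how such a predictions file could be written. The keys and the pickle format follow the description above; the sizes, dummy values, and file name are purely illustrative:

```
import pickle
import numpy as np

# Illustrative sizes; use whatever your model actually produces.
num_queries, num_srls, num_vids, num_frames = 100, 5, 4, 10

results = []
for idx_sent in range(num_queries):
    results.append({
        'idx_sent': idx_sent,
        # num_srls x num_vids x num_frames x 5d proposal boxes
        'pred_boxes': np.zeros((num_srls, num_vids, num_frames, 5)).tolist(),
        # num_srls x num_vids x num_frames, scores between 0 and 1
        'pred_scores': np.zeros((num_srls, num_vids, num_frames)).tolist(),
        # num_srls x num_frames, which video to choose (only required for sep)
        'pred_cmp': np.zeros((num_srls, num_frames), dtype=int).tolist(),
        # 1/0 per video, 0 if that video was padded
        'cmp_msk': [1] * num_vids,
        # index of the target video (the video list is shuffled at runtime)
        'targ_cmp': 0,
    })

# Save, then point --pred_file of code/eval_fn_corr.py to this file.
with open('valid_0.pkl', 'wb') as f:
    pickle.dump(results, f)
```
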
122 | ## Pre-Trained Models
123 |
124 | Google Drive Link for all models: https://drive.google.com/open?id=1e3FiX4FTC8n6UrzY9fTYQzFNKWHihzoQ
125 |
126 | Also, see the individual models (with corresponding logs) in [EXPTS.md](./EXPTS.md).
127 |
128 | ## Acknowledgements:
129 |
130 | We thank:
131 | 1. @LuoweiZhou: for his codebase on GVD (https://github.com/facebookresearch/grounded-video-description) along with the extracted features.
132 | 2. [allennlp](https://github.com/allenai/allennlp) for providing [demo](https://demo.allennlp.org/semantic-role-labeling) and pre-trained model for SRL.
133 | 3. [fairseq](https://github.com/pytorch/fairseq) for providing a neat implementation of LSTM.
134 |
135 | ## Citation
136 | ```
137 | @InProceedings{Sadhu_2020_CVPR,
138 | author = {Sadhu, Arka and Chen, Kan and Nevatia, Ram},
139 | title = {Video Object Grounding using Semantic Roles in Language Description},
140 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
141 | month = {June},
142 | year = {2020}
143 | }
144 | ```
145 |
--------------------------------------------------------------------------------
/code/README.md:
--------------------------------------------------------------------------------
1 | # Model
2 |
3 | ![Model figure](../media/model_fig.png)
4 |
5 | ## File Organization
6 |
7 | 1. `main_dist.py` is the main file.
8 | 1. `dat_loader_simple.py` processes the data. In particular, the SPAT/TEMP/SEP part is modular and can be easily extended to a new dataset.
9 | 1. `contrastive_sampling.py`, as the name suggests, creates the contrastive samples for training. The validation file already contains these indices.
10 | 1. `mdl_base.py` is the base model, which just defines a bunch of functions to be filled in.
11 | 1. `mdl_conc_single.py` implements the concatenation models and losses for SPAT/TEMP. Similarly, `mdl_conc_sep.py` implements the SEP concatenation model and loss. These are kept modular so that they can be re-used with newer models with ~~minimal~~ some effort.
12 | 1. `mdl_vog.py` contains the main model implementations of the baselines and VOGNet.
13 | 1. `mdl_selector.py` returns the model, loss and evaluation function to be used based on the input arguments (see the sketch after this list).
14 | 1. `eval_vsrl_corr.py` contains the top-level evaluation functions for each of SEP/TEMP/SPAT, which process the model outputs and convert them to a uniform format for evaluation.
15 | 1. `eval_fn_corr.py` contains the main logic for evaluating the models.
16 | 1. `_init_stuff.py` initializes the paths to be included and typings, as well as the yaml float loader (otherwise 1e-4 cannot be read correctly as a float).
17 | 1. `extended_config.py` has some handy configuration utils.
18 | 1. `transformer_code.py` contains the transformer implementation, including the relative transformer which uses relative position encoding (RPE).
19 |
20 | Some other useful files are under [`utils`](../utils/).
21 |
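As a rough sketch of how these pieces fit together (mirroring `learner_init` in `main_dist.py` with the default config path; not an exact copy of that function):

```
import torch
from yacs.config import CfgNode as CN
from _init_stuff import yaml           # yaml with the float-safe loader
from dat_loader_simple import get_data
from mdl_selector import get_mdl_loss_eval

cfg = CN(yaml.safe_load(open('./configs/anet_srl_cfg.yml')))
data = get_data(cfg)                   # train/valid dataloaders
comm = data.train_dl.dataset.comm      # dataset info shared with the model

# The selector returns the constructors for the model, loss and evaluator
# appropriate for the current config (e.g. mdl.name).
mdl_loss_eval = get_mdl_loss_eval(cfg)
device = torch.device('cuda')
mdl = mdl_loss_eval['mdl'](cfg=cfg, comm=comm)
loss_fn = mdl_loss_eval['loss'](cfg, comm)
eval_fn = mdl_loss_eval['eval'](cfg, comm, device)
```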
--------------------------------------------------------------------------------
/code/_init_stuff.py:
--------------------------------------------------------------------------------
1 | """
2 | Initialize stuff
3 | """
4 |
5 | import pdb
6 | from pathlib import Path
7 | from typing import List, Dict, Any, Union
8 | from yacs.config import CfgNode as CN
9 | import numpy as np
10 | import torch
11 | from torch.utils.data import Dataset, DataLoader
12 | import os
13 | import sys
14 | import yaml
15 | import re
16 | import pandas as pd
17 |
18 | Fpath = Union[Path, str]
19 | Cft = CN
20 | Arr = Union[np.ndarray, List, torch.Tensor]
21 | DF = pd.DataFrame
22 | # Ds = Dataset
23 |
24 | # This is required to read 5e-4 as a float rather than string
25 | # at all places yaml should be imported from here
26 | yaml.SafeLoader.add_implicit_resolver(
27 | u'tag:yaml.org,2002:float',
28 | re.compile(u'''^(?:
29 | [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
30 | |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
31 | |\\.[0-9_]+(?:[eE][-+][0-9]+)?
32 | |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
33 | |[-+]?\\.(?:inf|Inf|INF)
34 | |\\.(?:nan|NaN|NAN))$''', re.X),
35 | list(u'-+0123456789.'))
36 |
37 | # _SCRIPTPATH_ =
38 | sys.path.append('./code/')
39 | sys.path.append('./utils')
40 |
41 |
42 | class ForkedPdb(pdb.Pdb):
43 | """A Pdb subclass that may be used
44 | from a forked multiprocessing child
45 | Credits:
46 | https://github.com/williamFalcon/forked-pdb/blob/master/fpdb.py
47 | """
48 |
49 | def interaction(self, *args, **kwargs):
50 | _stdin = sys.stdin
51 | try:
52 | sys.stdin = open('/dev/stdin')
53 | pdb.Pdb.interaction(self, *args, **kwargs)
54 | finally:
55 | sys.stdin = _stdin
56 |
--------------------------------------------------------------------------------
/code/contrastive_sampling.py:
--------------------------------------------------------------------------------
1 | """
2 | To create the 4-way dataset
3 | Main motivation:
4 | Currently, not sure if the models ground
5 | based only on object name, or is it really
6 | learning the roles of the visual elements
7 | correctly.
8 |
9 | Thus, we create a 4-way dataset: for every
10 | data point which has S-V-O structure, we generate
11 | counterfactuals (not sure if this is
12 | the correct name or not). For every image
13 | containing say S1-V1-O1, present it with other
14 | images with the characteristics S2-V1-O1,
15 | S1-V2-O1, S1-V1-O2 as well. Some can be
16 | reduced in case only S-V or O-V are present
17 |
18 | More generally, we would like to create
19 | counterfactuals for anything that can
20 | provide evidence.
21 |
22 | Additionally, need to check
23 | - [x] Location words shouldn't be present
24 | - [x] Perform VERB lemmatization
25 | - [x] Distinguish between what is groundable and
26 | what is not
27 | - [x] Check the groundable verbs
28 | """
29 | from pathlib import Path
30 | import pandas as pd
31 |
32 | from tqdm.auto import tqdm
33 | from collections import Counter
34 | import json
35 | import copy
36 | import ast
37 | import numpy as np
38 | from _init_stuff import CN, yaml
39 | from typing import List
40 | np.random.seed(seed=5)
41 |
42 |
43 | def create_random_list(cfg, srl_annots, ann_row_idx):
44 | """
45 | Returns 4 random videos
46 | """
47 | srl_idxs_possible = np.array(srl_annots.index)
48 |
49 | vid_segs = srl_annots.vid_seg
50 | vid_seg = vid_segs.loc[ann_row_idx]
51 | srl_row = srl_annots.loc[ann_row_idx]
52 |
53 | req_cls_pats = srl_row.req_cls_pats
54 | req_cls_pats_mask = srl_row.req_cls_pats_mask
55 | args_to_use = set(['V', 'ARG0', 'ARG1', 'ARG2', 'ARGM-LOC'])
56 |
57 | arg_keys_vis_present = []
58 | arg_keys_lang_present = []
59 | for srl_arg, srl_arg_mask in zip(req_cls_pats, req_cls_pats_mask):
60 | arg_key = srl_arg[0]
61 | arg_keys_lang_present.append(arg_key)
62 | if arg_key == 'V' or arg_key in args_to_use:
63 | arg_keys_vis_present.append(arg_key)
64 |
65 | ds4_msk = {}
66 | inds_to_use = {}
67 | num_arg_keys_vis = len(arg_keys_vis_present)
68 | other_anns = np.random.choice(
69 | srl_idxs_possible, size=10 * num_arg_keys_vis,
70 | replace=False
71 | ).reshape(num_arg_keys_vis, 10)
72 |
73 | for aind, arg_key1 in enumerate(arg_keys_vis_present):
74 | in1 = other_anns[aind].tolist()
75 | assert len(in1) == 10
76 |
77 | set1 = set(in1)
78 |
79 | set_int = [s for s in set1 if
80 | vid_segs.loc[s] != vid_seg]
81 |
82 | # TODO:
83 | # Make replace false, currently true
84 | # because some have low chances of
85 | # appearing
86 | assert len(set_int) > 0
87 | inds_to_use[arg_key1] = set_int
88 | ds4_msk[arg_key1] = 1
89 | return inds_to_use, ds4_msk
90 |
91 |
92 | def create_similar_list_new(cfg, arg_dicts, srl_annots, ann_row_idx):
93 | """
94 | Does it for one row. Assumes annotations
95 | exist and can be retrieved via `self`.
96 |
97 | The logic:
98 | Each input idx has ARG0, V, ARG1 ...,
99 | (1) Pivot across one argument, say ARG0
100 | (2) Retrieve all other indices such that they
101 | have different ARG0, but same V, ARG1 ... (do
102 | each of them separately)
103 | (3) To retrieve those indices with V, ARG1 same
104 | we can just do intersection of the two sets
105 |
106 | To facilitate (2), we first create separate
107 | dictionaries for each V, ARG1 etc. and then
108 | just reference them via self.create_dicts
109 | """
110 | srl_idxs_possible = np.array(srl_annots.index)
111 |
112 | vid_segs = srl_annots.vid_seg
113 | vid_seg = vid_segs.loc[ann_row_idx]
114 | srl_row = srl_annots.loc[ann_row_idx]
115 |
116 | req_cls_pats = srl_row.req_cls_pats
117 | req_cls_pats_mask = srl_row.req_cls_pats_mask
118 | args_to_use = set(['V', 'ARG0', 'ARG1', 'ARG2', 'ARGM-LOC'])
119 | some_inds = {}
120 | arg_keys_vis_present = []
121 | arg_keys_lang_present = []
122 | for srl_arg, srl_arg_mask in zip(req_cls_pats, req_cls_pats_mask):
123 | arg_key = srl_arg[0]
124 | arg_keys_lang_present.append(arg_key)
125 | if arg_key == 'V' or arg_key in args_to_use:
126 | arg_keys_vis_present.append(arg_key)
127 | if arg_key in args_to_use:
128 | lemma_key = 'lemma_{}'.format(
129 | arg_key.replace('-', '_').replace('V', 'verb'))
130 | lemma_arg = srl_row[lemma_key]
131 | if isinstance(lemma_arg, list):
132 | assert all([le_arg in arg_dicts[arg_key]
133 | for le_arg in lemma_arg])
134 | if len(lemma_arg) >= 1:
135 | le_arg = lemma_arg
136 | else:
137 | le_arg = cfg.ds.none_word
138 | else:
139 | le_arg = [lemma_arg]
140 | # srl_ind_list = copy.deepcopy(
141 | # arg_dicts[arg_key][le_arg])
142 | # srl_ind_list = arg_dicts[arg_key][le_arg][:]
143 | for le_ar in le_arg:
144 | srl_ind_list = arg_dicts[arg_key][le_ar][:]
145 | srl_ind_list.remove(ann_row_idx)
146 | if arg_key not in some_inds:
147 | some_inds[arg_key] = []
148 | some_inds[arg_key] += srl_ind_list
149 | # # If not groundable but in args_to_use
150 | # else:
151 | # pass
152 | num_arg_keys_vis = len(arg_keys_vis_present)
153 | other_anns = np.random.choice(
154 | srl_idxs_possible, size=10 * num_arg_keys_vis,
155 | replace=False
156 | ).reshape(num_arg_keys_vis, 10)
157 |
158 | inds_to_use = {}
159 | ds4_msk = {}
160 | for aind, arg_key1 in enumerate(arg_keys_vis_present):
161 | arg_key_to_use = [
162 | ak for ak in arg_keys_vis_present if ak != arg_key1]
163 | set1 = set(some_inds[arg_key_to_use[0]])
164 |
165 | set_int1 = set1.intersection(
166 | *[set(some_inds[ak]) for ak in arg_key_to_use[1:]])
167 | curr_set = set(some_inds[arg_key1])
168 | set_int2 = list(set_int1 - curr_set)
169 |
170 | set_int = [s for s in set_int2 if
171 | vid_segs.loc[s] != vid_seg]
172 |
173 | # TODO:
174 | # Make replace false, currently true
175 | # because some have low chances of
176 | # appearing
177 | if len(set_int) == 0:
178 | # this means similar scenario not found
179 | # inds
180 | ds4_msk[arg_key1] = 0
181 | inds_to_use[arg_key1] = other_anns[aind].tolist()
182 | # inds_to_use[arg_key1] = [-1]
183 | # cfg.ouch += 1
184 | # print('ouch')
185 | else:
186 | ds4_msk[arg_key1] = 1
187 | inds_to_use[arg_key1] = np.random.choice(
188 | set_int, 10, replace=True).tolist()
189 | # cfg.yolo += 1
190 | # print('yolo')
191 | # inds_to_use_lens = [len(v) if v[0] != -1 else 0 for k,
192 | # v in inds_to_use.items()]
193 | # if sum(inds_to_use_lens) == 0:
194 | # cfg.ouch2 += 1
195 | # else:
196 | # cfg.yolo2 += 1
197 |
198 | return inds_to_use, ds4_msk
199 |
200 |
201 | def create_similar_list(cfg, arg_dicts, srl_annots, ann_row_idx):
202 | """
203 | Does it for one row. Assumes annotations
204 | exist and can be retrieved via `self`.
205 |
206 | The logic:
207 | Each input idx has ARG0, V, ARG1 ...,
208 | (1) Pivot across one argument, say ARG0
209 | (2) Retrieve all other indices such that they
210 | have different ARG0, but same V, ARG1 ... (do
211 | each of them separately)
212 | (3) To retrieve those indices with V, ARG1 same
213 | we can just do intersection of the two sets
214 |
215 | To facilitate (2), we first create separate
216 | dictionaries for each V, ARG1 etc. and then
217 | just reference them via self.create_dicts
218 | """
219 | srl_idxs_possible = np.array(srl_annots.index)
220 |
221 | vid_segs = srl_annots.vid_seg
222 | vid_seg = vid_segs.loc[ann_row_idx]
223 | srl_row = srl_annots.loc[ann_row_idx]
224 |
225 | req_cls_pats = srl_row.req_cls_pats
226 | req_cls_pats_mask = srl_row.req_cls_pats_mask
227 | args_to_use = set(['V', 'ARG0', 'ARG1', 'ARG2', 'ARGM-LOC'])
228 | some_inds = {}
229 | arg_keys_vis_present = []
230 | arg_keys_lang_present = []
231 | for srl_arg, srl_arg_mask in zip(req_cls_pats, req_cls_pats_mask):
232 | arg_key = srl_arg[0]
233 | arg_keys_lang_present.append(arg_key)
234 | if arg_key == 'V' or arg_key in args_to_use:
235 | arg_keys_vis_present.append(arg_key)
236 | if arg_key in args_to_use:
237 | lemma_key = 'lemma_{}'.format(
238 | arg_key.replace('-', '_').replace('V', 'verb'))
239 | lemma_arg = srl_row[lemma_key]
240 | if isinstance(lemma_arg, list):
241 | assert all([le_arg in arg_dicts[arg_key]
242 | for le_arg in lemma_arg])
243 | if len(lemma_arg) >= 1:
244 | le_arg = lemma_arg[0]
245 | else:
246 | le_arg = cfg.ds.none_word
247 | else:
248 | le_arg = lemma_arg
249 | # srl_ind_list = copy.deepcopy(
250 | # arg_dicts[arg_key][le_arg])
251 | # srl_ind_list = arg_dicts[arg_key][le_arg][:]
252 | srl_ind_list = arg_dicts[arg_key][le_arg][:]
253 | srl_ind_list.remove(ann_row_idx)
254 | if arg_key not in some_inds:
255 | some_inds[arg_key] = []
256 | some_inds[arg_key] += srl_ind_list
257 | # # If not groundable but in args_to_use
258 | # else:
259 | # pass
260 | num_arg_keys_vis = len(arg_keys_vis_present)
261 | other_anns = np.random.choice(
262 | srl_idxs_possible, size=10 * num_arg_keys_vis,
263 | replace=False
264 | ).reshape(num_arg_keys_vis, 10)
265 |
266 | inds_to_use = {}
267 | ds4_msk = {}
268 | for aind, arg_key1 in enumerate(arg_keys_vis_present):
269 | arg_key_to_use = [
270 | ak for ak in arg_keys_vis_present if ak != arg_key1]
271 | set1 = set(some_inds[arg_key_to_use[0]])
272 |
273 | set_int1 = set1.intersection(
274 | *[set(some_inds[ak]) for ak in arg_key_to_use[1:]])
275 | curr_set = set(some_inds[arg_key1])
276 | set_int2 = list(set_int1 - curr_set)
277 |
278 | set_int = [s for s in set_int2 if
279 | vid_segs.loc[s] != vid_seg]
280 |
281 | # TODO:
282 | # Make replace false, currently true
283 | # because some have low chances of
284 | # appearing
285 | if len(set_int) == 0:
286 | # this means similar scenario not found
287 | # inds
288 | ds4_msk[arg_key1] = 0
289 | inds_to_use[arg_key1] = other_anns[aind].tolist()
290 | # inds_to_use[arg_key1] = [-1]
291 | # cfg.ouch += 1
292 | # print('ouch')
293 | else:
294 | ds4_msk[arg_key1] = 1
295 | inds_to_use[arg_key1] = np.random.choice(
296 | set_int, 10, replace=True).tolist()
297 | # cfg.yolo += 1
298 | # print('yolo')
299 | # inds_to_use_lens = [len(v) if v[0] != -1 else 0 for k,
300 | # v in inds_to_use.items()]
301 | # if sum(inds_to_use_lens) == 0:
302 | # cfg.ouch2 += 1
303 | # else:
304 | # cfg.yolo2 += 1
305 |
306 | return inds_to_use, ds4_msk
307 |
308 |
309 | class AnetDSCreator:
310 | def __init__(self, cfg, tdir='.'):
311 | self.cfg = cfg
312 | self.tdir = Path(tdir)
313 |
314 | def fix_via_ast(self, df):
315 | for k in df.columns:
316 | first_word = df.iloc[0][k]
317 | if isinstance(first_word, str) and (first_word[0] in '[{'):
318 | df[k] = df[k].apply(
319 | lambda x: ast.literal_eval(x))
320 | return df
321 |
322 | def get_stats(self, req_args):
323 | """
324 | Gets the counts for the argument types
325 | """
326 | c = Counter()
327 | if isinstance(req_args[0], list):
328 | for x in req_args:
329 | c += Counter(x)
330 | else:
331 | c = Counter(req_args)
332 |
333 | return c.most_common()
334 |
335 | def create_all_similar_lists(self):
336 | self.create_similar_lists(split_type='train')
337 | self.create_similar_lists(split_type='valid')
338 |
339 | def create_similar_lists(self, split_type: str = 'train'):
340 | """
341 | need to check if only
342 | creating for the validation
343 | set would be enough or not.
344 |
345 | Basically, for each input,
346 | generates list of other inputs (idxs)
347 | which have same S,V,O (at least one is same)
348 | """
349 | if split_type == 'train':
350 | srl_annot_file = self.tdir / self.cfg.ds.trn_verb_ent_file
351 | ds4_dict_file = self.tdir / self.cfg.ds.trn_ds4_dicts
352 | ds4_ind_file = self.tdir / self.cfg.ds.trn_ds4_inds
353 | elif split_type == 'valid':
354 | srl_annot_file = self.tdir / self.cfg.ds.val_verb_ent_file
355 | ds4_dict_file = self.tdir / self.cfg.ds.val_ds4_dicts
356 | ds4_ind_file = self.tdir / self.cfg.ds.val_ds4_inds
357 | elif split_type == 'trn_val':
358 | srl_annot_file = self.tdir / self.cfg.ds.verb_ent_file
359 | ds4_dict_file = self.tdir / self.cfg.ds.ds4_dicts
360 | ds4_ind_file = self.tdir / self.cfg.ds.ds4_inds
361 | elif split_type == 'only_val':
362 | srl_annot_file = Path('./data/anet_verb/val_1_verb_ent_file.csv')
363 | ds4_dict_file = Path(
364 | './data/anet_verb/val_1_srl_args_dict_obj_to_ind.json'
365 | )
366 | else:
367 | raise NotImplementedError
368 | # elif split_type == 'test':
369 | # srl_annot_file = self.tdir / self.cfg.ds.test_verb_ent_file
370 | # ds4_dict_file = self.tdir / self.cfg.ds.test_ds4_dicts
371 | # ds4_ind_file = self.tdir / self.cfg.ds.test_ds4_inds
372 | # elif split_type == 'val_test':
373 | # # validation file with validation+test indices
374 | # srl_annot_file = self.tdir / self.cfg.ds.test_verb_ent_file
375 | # ds4_dict_file = self.tdir / self.cfg.ds.test_ds4_dicts
376 | # ds4_ind_file = self.tdir / self.cfg.ds.test_ds4_inds
377 | # elif split_type == 'test_val':
378 | # # test file with validation+test indices
379 | # srl_annot_file = self.tdir / self.cfg.ds.test_verb_ent_file
380 | # ds4_dict_file = self.tdir / self.cfg.ds.test_ds4_dicts
381 | # ds4_ind_file = self.tdir / self.cfg.ds.test_ds4_inds
382 | # else:
383 | # raise NotImplementedError
384 | srl_annots = self.fix_via_ast(pd.read_csv(srl_annot_file))
385 |
386 | self.create_dicts_srl(srl_annots, ds4_dict_file)
387 |
388 | arg_dicts = json.load(open(ds4_dict_file))
389 | srl_annots_copy = copy.deepcopy(srl_annots)
390 | # inds_to_use_list = [self.create_similar_list(
391 | # row_ind) for row_ind in tqdm(range(len(self.srl_annots)))]
392 | inds_to_use_list = []
393 | ds4_msk = []
394 | rand_inds_to_use_list = []
395 |
396 | for row_ind in tqdm(range(len(srl_annots))):
397 | inds_to_use, ds4_msk_out = create_similar_list(
398 | self.cfg, arg_dicts, srl_annots, row_ind)
399 | ds4_msk.append(ds4_msk_out)
400 |
401 | inds_to_use_list.append(inds_to_use)
402 |
403 | rand_inds_to_use, _ = create_random_list(
404 | self.cfg, srl_annots, row_ind
405 | )
406 | rand_inds_to_use_list.append(rand_inds_to_use)
407 |
408 | srl_annots_copy['DS4_Inds'] = inds_to_use_list
409 | srl_annots_copy['ds4_msk'] = ds4_msk
410 |
411 | srl_annots_copy['RandDS4_Inds'] = rand_inds_to_use_list
412 | # srl_annots_copy = srl_annots_copy.iloc[ds4_msk]
413 |
414 | srl_annots_copy.to_csv(
415 | ds4_ind_file, index=False, header=True)
416 | # srl_annots_copy.to_csv(
417 | # self.tdir/self.cfg.ds.ds4_inds, index=False, header=True)
418 | # for row_ind in range(len(self.srl_annots)):
419 | # inds_to_use = self.create_similar_list(row_ind)
420 |
421 | def create_dicts_srl(self, srl_annots, out_file):
422 | def default_dict_list(key_list, val, dct):
423 | for key in key_list:
424 | if key not in dct:
425 | dct[key] = []
426 | dct[key].append(val)
427 | return dct
428 |
429 | # srl_annots = self.srl_annots
430 |
431 | # args_dict_out: Dict[str, Dict[obj_name, srl_indices]]
432 | # arg_dict_lemma: Dict[str, List[obj_name]]
433 | args_dict_out = {}
434 | args_to_use = ['ARG0', 'ARG1', 'ARG2', 'ARGM-LOC']
435 | for srl_arg in args_to_use:
436 | args_dict_out[srl_arg] = {}
437 |
438 | for row_ind, row in tqdm(srl_annots.iterrows(),
439 | total=len(srl_annots)):
440 | req_cls_pats = row.req_cls_pats
441 | req_cls_pats_mask = row.req_cls_pats_mask
442 | for srl_arg, srl_arg_mask in zip(req_cls_pats, req_cls_pats_mask):
443 | arg_key = srl_arg_mask[0]
444 | if arg_key in args_dict_out:
445 | # The argument is groundable
446 | if srl_arg_mask[1] == 1:
447 | key_list = list(set(srl_arg[1]))
448 | args_dict_out[arg_key] = default_dict_list(
449 | key_list, row_ind, args_dict_out[arg_key])
450 | else:
451 | key_list = [self.cfg.ds.none_word]
452 | args_dict_out[arg_key] = default_dict_list(
453 | key_list, row_ind, args_dict_out[arg_key])
454 |
455 | args_dict_out['V'] = {k: list(v.index) for k,
456 | v in srl_annots.groupby('lemma_verb')}
457 | json.dump(args_dict_out, open(out_file, 'w'))
458 | return args_dict_out
459 |
460 |
461 | def main(splits: List):
462 | if not isinstance(splits, list):
463 | assert isinstance(splits, str)
464 | splits = [splits]
465 | cfg = CN(yaml.safe_load(open('./configs/create_asrl_cfg.yml')))
466 | for split_type in splits:
467 | anet_ds = AnetDSCreator(cfg)
468 | anet_ds.create_similar_lists(split_type=split_type)
469 |
470 |
471 | if __name__ == '__main__':
472 | import fire
473 | fire.Fire(main)
474 | # cfg = CN(yaml.safe_load(open('./configs/create_asrl_cfg.yml')))
475 | # for split_type in ['valid', 'train']:
476 | # # for split_type in ['only_val', 'valid', 'train', 'trn_val']:
477 | # # cfg.ouch = 0
478 | # # cfg.yolo = 0
479 |
480 | # # cfg.ouch2 = 0
481 | # # cfg.yolo2 = 0
482 |
483 | # anet_ds = AnetDSCreator(cfg)
484 | # # anet_ds.create_dicts_srl()
485 | # anet_ds.create_similar_lists(split_type=split_type)
486 |
487 | # # break
488 |
489 | # # anet_ds.create_similar_lists(split_type='trn_val')
490 | # # anet_ds.create_similar_lists(split_type='train')
491 | # # anet_ds.create_similar_lists(split_type='valid')
492 | # # print(cfg.ouch, cfg.yolo, cfg.yolo+cfg.ouch)
493 | # # print(cfg.ouch2, cfg.yolo2, cfg.yolo2+cfg.ouch2)
494 |
--------------------------------------------------------------------------------
/code/eval_vsrl_corr.py:
--------------------------------------------------------------------------------
1 | """
2 | Better evaluation.
3 | Corrected a few implementations
4 | for sep, temporal, spatial.
5 | """
6 | from eval_fn_corr import (
7 | GroundEval_SEP,
8 | GroundEval_TEMP,
9 | GroundEval_SPAT
10 | )
11 | import pickle
12 | from fastprogress import progress_bar
13 | from pathlib import Path
14 | import torch
15 | from trn_utils import (
16 | compute_avg_dict,
17 | is_main_process,
18 | synchronize,
19 | get_world_size
20 | )
21 |
22 |
23 | class Evaluator(torch.nn.Module):
24 | def __init__(self, cfg, comm, device):
25 | super().__init__()
26 | self.cfg = cfg
27 | self.comm = comm
28 | self.met_keys = ['avg1', 'macro_avg1']
29 | self.num_prop_per_frm = self.comm.num_prop_per_frm
30 | self.num_frms = self.cfg.ds.num_sampled_frm
31 | self.num_props = self.num_prop_per_frm * self.num_frms
32 | self.device = device
33 | self.after_init()
34 |
35 | def after_init(self):
36 | pass
37 |
38 | def get_out_results(self, out_result):
39 | if isinstance(out_result, torch.Tensor):
40 | return out_result
41 | else:
42 | return out_result['mdl_outs']
43 |
44 | def forward_one_batch(self, out_result, inp):
45 | """
46 | The following should be returned:
47 | List[Dict]
48 | Dict = {
49 | 'idx(video)', 'idx(srl)', 'idx(arg)',
50 | 'pred_boxes', 'pred_scores'
51 | }
52 | """
53 | out_result = out_result
54 | # B x num_verbs x num_srl_args x 1000
55 | B, num_verbs, num_srl_args, num_props = out_result.shape
56 | assert self.num_props == num_props
57 | # B x num_verbs x num_srl_args x num_frms x num_prop_per_frm
58 | out_result_frame = torch.sigmoid(
59 | out_result.view(
60 | B, num_verbs, num_srl_args,
61 | self.num_frms, self.num_prop_per_frm
62 | )
63 | )
64 | # B x num_verbs x num_srl_args x num_frms x num_prop_per_frm
65 | out_result_frame_score, out_result_frame_index = torch.max(
66 | out_result_frame, dim=-1)
67 |
68 | props = inp['pad_proposals']
69 | _, num_props, prop_dim = props.shape
70 | # B x num_verbs x num_srl_args x num_frms x num_prop_per_frm x prop_dim
71 | props_reshaped = props.view(
72 | B, 1, 1, self.num_frms, self.num_prop_per_frm, prop_dim).expand(
73 | B, num_verbs, num_srl_args,
74 | self.num_frms, self.num_prop_per_frm, prop_dim)
75 |
76 | out_result_boxes = torch.gather(
77 | props_reshaped, dim=-2,
78 | index=out_result_frame_index.unsqueeze(-1).unsqueeze(-1).expand(
79 | *out_result_frame_index.shape, 1, prop_dim))
80 |
81 | pred_boxes = out_result_boxes.squeeze(-2)
82 | out_dict_list = [
83 | {
84 | 'pred_boxes': pb,
85 | 'pred_scores': ps,
86 | 'idx_vid': an_ind,
87 | 'idx_sent': srl_ann,
88 | 'idx_verb': srl_verb,
89 | 'num_verbs': nv
90 |
91 | } for pb, ps, an_ind, srl_ann, srl_verb, nv in zip(
92 | pred_boxes.detach().cpu().tolist(),
93 | out_result_frame_score.detach().cpu().tolist(),
94 | inp['ann_idx'].detach().cpu().tolist(),
95 | inp['sent_idx'].detach().cpu().tolist(),
96 | inp['srl_verb_idxs'].detach().cpu().tolist(),
97 | inp['num_verbs'].detach().cpu().tolist()
98 | )]
99 | return out_dict_list
100 |
101 | def forward(self, model, loss_fn, dl, dl_name,
102 | rank=0, pred_path=None, mb=None):
103 | fname = Path(pred_path) / f'{dl_name}_{rank}.pkl'
104 | # comm = self.comm
105 | # cfg = self.cfg
106 | model.eval()
107 | loss_keys = loss_fn.loss_keys
108 | val_losses = {k: [] for k in loss_keys}
109 | nums = []
110 | results = []
111 | for batch in progress_bar(dl, parent=mb):
112 | for b in batch.keys():
113 | batch[b] = batch[b].to(self.device)
114 | b = next(iter(batch.keys()))
115 | nums.append(batch[b].size(0))
116 | torch.cuda.empty_cache()
117 | with torch.no_grad():
118 | out = model(batch)
119 | out_loss = loss_fn(out, batch)
120 |
121 | for k in out_loss:
122 | val_losses[k].append(out_loss[k].detach().cpu())
123 | results += self.forward_one_batch(out, batch)
124 |
125 | pickle.dump(results, open(fname, 'wb'))
126 | nums = torch.tensor(nums).float()
127 | val_loss = compute_avg_dict(val_losses, nums)
128 |
129 | synchronize()
130 | if is_main_process():
131 | curr_results = results
132 | world_size = get_world_size()
133 | for w in range(1, world_size):
134 | tmp_file = Path(pred_path) / f'{dl_name}_{w}.pkl'
135 | with open(tmp_file, 'rb') as f:
136 | tmp_results = pickle.load(f)
137 | curr_results += tmp_results
138 | tmp_file.unlink()
139 | with open(fname, 'wb') as f:
140 | pickle.dump(curr_results, f)
141 | out_acc = self.grnd_eval.eval_ground_acc(fname)
142 | val_acc = {k: torch.tensor(v).to(self.device)
143 | for k, v in out_acc.items() if k in self.met_keys}
144 | # return val_loss, val_acc
145 | synchronize()
146 | if is_main_process():
147 | return val_loss, val_acc
148 | else:
149 | return {k: torch.tensor(0.).to(self.device) for k in loss_keys}, {
150 | k: torch.tensor(0.).to(self.device) for k in self.met_keys}
151 |
152 |
153 | class EvaluatorSEP(Evaluator):
154 | def after_init(self):
155 |
156 | self.met_keys = ['avg1', 'avg1_cons',
157 | 'avg1_vidf', 'avg1_strict']
158 | self.grnd_eval = GroundEval_SEP(self.cfg, self.comm)
159 |
160 | self.num_sampled_frm = self.num_frms
161 |
162 | def get_out_results_boxes(self, out_result_dict, inp):
163 | """
164 | get the correct boxes, scores, indexes per frame
165 | """
166 | assert isinstance(out_result_dict, dict)
167 | # B x num_cmp
168 | fin_scores = out_result_dict['fin_scores']
169 | B, num_cmp = fin_scores.shape
170 |
171 | # B
172 | vidf_outs = torch.argmax(fin_scores, dim=-1)
173 |
174 | # B x num_cmp x num_srl_args x num_props
175 | mdl_outs = out_result_dict['mdl_outs_eval']
176 |
177 | B, num_cmp, num_srl_args, num_props = mdl_outs.shape
178 |
179 | mdl_outs_reshaped = mdl_outs.transpose(
180 | 1, 2).contiguous().view(
181 | B, num_srl_args, num_cmp,
182 | self.num_sampled_frm, self.num_prop_per_frm
183 | )
184 |
185 | # B x num_srl_args x num_cmp x num_frms
186 | out_result_frame_score, out_result_frame_index = torch.max(
187 | mdl_outs_reshaped, dim=-1
188 | )
189 |
190 | props = inp['pad_proposals']
191 | _, num_cmp, num_props, prop_dim = props.shape
192 |
193 | props_reshaped = props.view(
194 | B, 1, num_cmp,
195 | self.num_sampled_frm, self.num_prop_per_frm, prop_dim
196 | ).expand(
197 | B, num_srl_args, num_cmp,
198 | self.num_sampled_frm, self.num_prop_per_frm, prop_dim
199 | )
200 |
201 | props_out = torch.gather(
202 | props_reshaped,
203 | dim=-2,
204 | index=out_result_frame_index.unsqueeze(-1).unsqueeze(-1).expand(
205 | B, num_srl_args, num_cmp,
206 | self.num_sampled_frm, 1, prop_dim
207 | )
208 | )
209 |
210 | props_out = props_out.squeeze(-2)
211 |
212 | # B -> B x #srl x #frms
213 | vidf_outs = vidf_outs.view(B, 1, 1).expand(
214 | B, num_srl_args, self.num_frms)
215 |
216 | return {
217 | 'boxes': props_out,
218 | 'scores': out_result_frame_score,
219 | 'indexs': vidf_outs
220 | }
221 |
222 | def forward_one_batch(self, out_result, inp):
223 | """
224 | The following should be returned:
225 | List[Dict]
226 | Dict = {
227 | 'idx(video)', 'idx(srl)', 'idx(arg)',
228 | 'pred_boxes', 'pred_scores'
229 | }
230 | """
231 | out_results = self.get_out_results_boxes(out_result, inp)
232 |
233 | out_result_boxes = out_results['boxes']
234 | out_result_frame_score = out_results['scores']
235 | out_result_frame_index = out_results['indexs']
236 |
237 | # B x num_srl_args x num_cmp x num_frms x num_props
238 | pred_boxes = out_result_boxes
239 | # B x num_srl_args x num_frms
240 | pred_cmp = out_result_frame_index
241 | # B x num_srl_args x num_cmp x num_frms
242 | pred_score = out_result_frame_score
243 | targ_cmp = inp['target_cmp'].detach().cpu().tolist()
244 | perm_list = inp['permute'].detach().cpu().tolist()
245 | perm_inv_list = inp['permute_inv'].detach().cpu().tolist()
246 |
247 | out_dict_list = [
248 | {
249 | 'pred_boxes': pb,
250 | 'pred_scores': ps,
251 | 'pred_cmp': pc,
252 | 'idx_vid': an_ind,
253 | 'idx_verbs': srl_idxs,
254 | 'idx_sent': srl_ann,
255 | 'cmp_msk': cmp_msk,
256 | 'targ_cmp': tcmp,
257 | 'perm': perm,
258 | 'perm_inv': perm_inv,
259 |
260 | } for pb, ps, pc, an_ind, srl_idxs, srl_ann, cmp_msk,
261 | tcmp, perm, perm_inv in zip(
262 | pred_boxes.detach().cpu().tolist(),
263 | pred_score.detach().cpu().tolist(),
264 | pred_cmp.detach().cpu().tolist(),
265 | inp['ann_idx'].detach().cpu().tolist(),
266 | inp['new_srl_idxs'].detach().cpu().tolist(),
267 | inp['sent_idx'].detach().cpu().tolist(),
268 | inp['num_cmp_msk'].detach().cpu().tolist(),
269 | targ_cmp,
270 | perm_list,
271 | perm_inv_list
272 | )]
273 | return out_dict_list
274 |
275 |
276 | class EvaluatorTEMP(EvaluatorSEP):
277 | def after_init(self):
278 | # self.met_keys = ['avg1', 'macro_avg1', 'avg1_cons', 'macro_avg1_cons']
279 | # self.grnd_eval = GroundEvalDS4(self.cfg, self.comm)
280 |
281 | self.met_keys = ['avg1', 'avg1_cons',
282 | 'avg1_vidf', 'avg1_strict']
283 | self.grnd_eval = GroundEval_TEMP(self.cfg, self.comm)
284 |
285 | # self.num_sampled_frm = self.cfg.misc.num_sampled_frm
286 | self.num_sampled_frm = self.num_frms
287 | # self.num_prop_per_frm = self.cfg.misc.num_prop_per_frm
288 |
289 | def get_out_results_boxes(self, out_result_dict, inp):
290 | """
291 | get the correct boxes, scores, indexes per frame
292 | """
293 | assert isinstance(out_result_dict, dict)
294 |
295 | out_result = out_result_dict['mdl_outs_eval']
296 | num_cmp = inp['new_srl_idxs'].size(1)
297 |
298 | # B x num_verbs x num_srl_args x 4000
299 | B, num_verbs, num_srl_args, num_props = out_result.shape
300 |
301 | assert num_verbs == 1
302 | # B x num_srl_args x num_props
303 | # mdl_outs = out_result.squeeze(1)
304 | mdl_outs_reshaped = out_result.view(
305 | B, num_srl_args, num_cmp,
306 | self.num_sampled_frm, self.num_prop_per_frm
307 | )
308 |
309 | # B x num_srl_args x num_cmp x num_frms
310 | out_result_frame_score, out_result_frame_index = torch.max(
311 | mdl_outs_reshaped, dim=-1
312 | )
313 |
314 | props = inp['pad_proposals']
315 |
316 | _, num_props, prop_dim = props.shape
317 | assert (num_cmp * self.num_sampled_frm *
318 | self.num_prop_per_frm == num_props)
319 | props_reshaped = props.view(
320 | B, 1, num_cmp,
321 | self.num_sampled_frm, self.num_prop_per_frm, prop_dim
322 | ).expand(
323 | B, num_srl_args, num_cmp,
324 | self.num_sampled_frm, self.num_prop_per_frm, prop_dim
325 | )
326 |
327 | props_out = torch.gather(
328 | props_reshaped,
329 | dim=-2,
330 | index=out_result_frame_index.unsqueeze(-1).unsqueeze(-1).expand(
331 | B, num_srl_args, num_cmp,
332 | self.num_sampled_frm, 1, prop_dim
333 | )
334 | )
335 |
336 | props_out = props_out.squeeze(-2)
337 | # Not used in temporal. Make it all zeros
338 | vidf_outs = torch.zeros(1, 1, 1).expand(
339 | B, num_srl_args, self.num_frms
340 | )
341 | return {
342 | 'boxes': props_out,
343 | 'scores': out_result_frame_score,
344 | 'indexs': vidf_outs
345 | }
346 |
347 |
348 | class EvaluatorSPAT(EvaluatorSEP):
349 | def after_init(self):
350 | self.met_keys = ['avg1', 'avg1_cons', 'avg1_vidf', 'avg1_strict']
351 | self.grnd_eval = GroundEval_SPAT(self.cfg, self.comm)
352 |
353 | self.num_sampled_frm = self.num_frms
354 | # self.num_sampled_frm = self.cfg.misc.num_sampled_frm
355 | # self.num_prop_per_frm = self.cfg.misc.num_prop_per_frm
356 |
357 | def get_out_results_boxes(self, out_result_dict, inp):
358 | """
359 | get the correct boxes, scores, indexes per frame
360 | """
361 | assert isinstance(out_result_dict, dict)
362 |
363 | out_result = out_result_dict['mdl_outs_eval']
364 | num_cmp = inp['new_srl_idxs'].size(1)
365 |
366 | # B x num_verbs x num_srl_args x 4000
367 | B, num_verbs, num_srl_args, num_props = out_result.shape
368 |
369 | assert num_verbs == 1
370 | # B x num_srl_args x num_props
371 | mdl_outs_reshaped = out_result.view(
372 | B, num_srl_args,
373 | self.num_sampled_frm, num_cmp, self.num_prop_per_frm
374 | )
375 |
376 | # B x num_srl_args x num_frm x num_cmp
377 | out_result_frame_score, out_result_frame_index = torch.max(
378 | mdl_outs_reshaped, dim=-1
379 | )
380 |
381 | props = inp['pad_proposals']
382 |
383 | _, num_props, prop_dim = props.shape
384 | assert (num_cmp * self.num_sampled_frm *
385 | self.num_prop_per_frm == num_props)
386 | props_reshaped = props.view(
387 | B, 1, self.num_sampled_frm,
388 | num_cmp, self.num_prop_per_frm, prop_dim
389 | ).expand(
390 | B, num_srl_args, self.num_sampled_frm,
391 | num_cmp, self.num_prop_per_frm, prop_dim
392 | )
393 |
394 | props_out = torch.gather(
395 | props_reshaped,
396 | dim=-2,
397 | index=out_result_frame_index.unsqueeze(-1).unsqueeze(-1).expand(
398 | B, num_srl_args, self.num_sampled_frm, num_cmp, 1, prop_dim
399 | )
400 | )
401 |
402 | # B x num_srl x num_frms x num_cmp
403 | props_out = props_out.squeeze(-2)
404 | # For consistency across sep, temporal, spatial
405 | props_out = props_out.transpose(2, 3).contiguous()
406 |
407 | # Used in spatial.
408 | # Divide by 100
409 | # vidf_outs = torch.div(
410 | # out_result_frame_index.squeeze(-1),
411 | # self.num_prop_per_frm
412 | # ).long()
413 |
414 | # B x num_srl_args x num_frm
415 | vidf_outs = out_result_frame_score.argmax(dim=-1)
416 |
417 | # B x num_srl_args x num_frm x num_cmp
418 | score_out = out_result_frame_score.transpose(2, 3).contiguous()
419 |
420 | return {
421 | 'boxes': props_out,
422 | 'scores': score_out,
423 | 'indexs': vidf_outs
424 | }
425 |
--------------------------------------------------------------------------------
/code/extended_config.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 | # import json
3 | # import yaml
4 | from _init_stuff import yaml
5 | from typing import Dict, Any
6 |
7 | with open('./configs/anet_srl_cfg.yml') as f:
8 | def_cfg = yaml.safe_load(f)
9 |
10 | cfg = CN(def_cfg)
11 | cfg.comm = CN()
12 |
13 | key_maps = {}
14 |
15 |
16 | def create_from_dict(dct: Dict[str, Any], prefix: str, cfg: CN):
17 | """
18 | Helper function to create yacs config from dictionary
19 | """
20 | dct_cfg = CN(dct, new_allowed=True)
21 | prefix_list = prefix.split('.')
22 | d = cfg
23 | for pref in prefix_list[:-1]:
24 | assert isinstance(d, CN)
25 | if pref not in d:
26 | setattr(d, pref, CN())
27 | d = d[pref]
28 | if hasattr(d, prefix_list[-1]):
29 | old_dct_cfg = d[prefix_list[-1]]
30 | dct_cfg.merge_from_other_cfg(old_dct_cfg)
31 |
32 | setattr(d, prefix_list[-1], dct_cfg)
33 | return cfg
34 |
35 |
36 | def update_from_dict(cfg: CN, dct: Dict[str, Any],
37 | key_maps: Dict[str, str] = None) -> CN:
38 | """
39 | Given original CfgNode (cfg) and input dictionary allows changing
40 | the cfg with the updated dictionary values
41 | Optional key_maps argument which defines a mapping between
42 | same keys of the cfg node. Only used for convenience
43 | Adapted from:
44 | https://github.com/rbgirshick/yacs/blob/master/yacs/config.py#L219
45 | """
46 | # Original cfg
47 | root = cfg
48 | if key_maps is None:
49 | key_maps = {}
50 | # Change the input dictionary using keymaps
51 | # Now it is aligned with the cfg
52 | full_key_list = list(dct.keys())
53 | for full_key in full_key_list:
54 | if full_key in key_maps:
55 | cfg[full_key] = dct[full_key]
56 | new_key = key_maps[full_key]
57 | dct[new_key] = dct.pop(full_key)
58 |
59 | # Convert the cfg using dictionary input
60 | for full_key, v in dct.items():
61 | if root.key_is_deprecated(full_key):
62 | continue
63 | if root.key_is_renamed(full_key):
64 | root.raise_key_rename_error(full_key)
65 | key_list = full_key.split(".")
66 | d = cfg
67 | for subkey in key_list[:-1]:
68 | # Most important statement
69 | assert subkey in d, f'key {full_key} does not exist'
70 | d = d[subkey]
71 |
72 | subkey = key_list[-1]
73 | # Most important statement
74 | assert subkey in d, f'key {full_key} does not exist'
75 |
76 | value = cfg._decode_cfg_value(v)
77 |
78 | assert isinstance(value, type(d[subkey]))
79 | d[subkey] = value
80 |
81 | return cfg
82 |
83 |
84 | def post_proc_config(cfg: CN):
85 | """
86 | Add any post processing based on cfg
87 | """
88 | return cfg
89 |
--------------------------------------------------------------------------------
/code/main_dist.py:
--------------------------------------------------------------------------------
1 | """
2 | Main file for distributed training
3 | """
4 | import sys
5 | # from dat_loader import get_data
6 | from dat_loader_simple import get_data
7 | from mdl_selector import get_mdl_loss_eval
8 | from trn_utils import Learner, synchronize
9 |
10 | import torch
11 | import fire
12 | from functools import partial
13 |
14 | from extended_config import (
15 | cfg as conf,
16 | key_maps,
17 | CN,
18 | update_from_dict,
19 | post_proc_config
20 | )
21 |
22 | import resource
23 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
24 | resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
25 |
26 |
27 | def get_name_from_inst(inst):
28 | return inst.__class__.__name__
29 |
30 |
31 | def learner_init(uid: str, cfg: CN) -> Learner:
32 | # = get_mdl_loss(cfg)
33 | mdl_loss_eval = get_mdl_loss_eval(cfg)
34 | get_default_net = mdl_loss_eval['mdl']
35 | get_default_loss = mdl_loss_eval['loss']
36 | get_default_eval = mdl_loss_eval['eval']
37 |
38 | device = torch.device('cuda')
39 | # device = torch.device('cpu')
40 | data = get_data(cfg)
41 | comm = data.train_dl.dataset.comm
42 | mdl = get_default_net(cfg=cfg, comm=comm)
43 |
44 | # pretrained_state_dict = torch.load(cfg.pretrained_path)
45 | # to_load_state_dict = pretrained_state_dict
46 | # mdl.load_state_dict(to_load_state_dict)
47 |
48 | loss_fn = get_default_loss(cfg, comm)
49 | loss_fn.to(device)
50 | # if cfg.do_dist:
51 | # loss_fn.to(device)
52 |
53 | eval_fn = get_default_eval(cfg, comm, device)
54 | eval_fn.to(device)
55 | opt_fn = partial(torch.optim.Adam, betas=(0.9, 0.99))
56 |
57 | # unfreeze cfg to save the names
58 | cfg.defrost()
59 | module_name = mdl
60 | cfg.mdl_data_names = CN({
61 | 'trn_data': get_name_from_inst(data.train_dl.dataset),
62 | 'val_data': get_name_from_inst(data.valid_dl.dataset),
63 | 'trn_collator': get_name_from_inst(data.train_dl.collate_fn),
64 | 'val_collator': get_name_from_inst(data.valid_dl.collate_fn),
65 | 'mdl_name': get_name_from_inst(module_name),
66 | 'loss_name': get_name_from_inst(loss_fn),
67 | 'eval_name': get_name_from_inst(eval_fn),
68 | 'opt_name': opt_fn.func.__name__
69 | })
70 | cfg.freeze()
71 |
72 | learn = Learner(uid=uid, data=data, mdl=mdl, loss_fn=loss_fn,
73 | opt_fn=opt_fn, eval_fn=eval_fn, device=device, cfg=cfg)
74 |
75 | if cfg.do_dist:
76 | mdl.to(device)
77 | mdl = torch.nn.parallel.DistributedDataParallel(
78 | mdl, device_ids=[cfg.local_rank],
79 | output_device=cfg.local_rank, broadcast_buffers=True,
80 | find_unused_parameters=True)
81 | elif cfg.do_dp:
82 | # Use data parallel
83 | mdl = torch.nn.DataParallel(mdl)
84 |
85 | mdl = mdl.to(device)
86 |
87 | return learn
88 |
89 |
90 | def main_dist(uid: str, **kwargs):
91 | """
92 | uid is a unique identifier for the experiment name
93 | Can be kept the same as a previous run; by default it will resume
94 | from the latest saved model
95 | **kwargs: allows arbitrary arguments of cfg to be changed
96 | """
97 | cfg = conf
98 | num_gpus = torch.cuda.device_count()
99 | cfg.num_gpus = num_gpus
100 | cfg.uid = uid
101 | cfg.cmd = sys.argv
102 | if num_gpus > 1:
103 | if 'local_rank' in kwargs:
104 | # We are doing distributed parallel
105 | cfg.do_dist = True
106 | torch.cuda.set_device(kwargs['local_rank'])
107 | torch.distributed.init_process_group(
108 | backend="nccl", init_method="env://"
109 | )
110 | synchronize()
111 | else:
112 | # We are doing data parallel
113 | cfg.do_dist = False
114 | # cfg.do_dp = True
115 | # Update the config file depending on the command line args
116 | cfg = update_from_dict(cfg, kwargs, key_maps)
117 | cfg = post_proc_config(cfg)
118 | # Freeze the cfg, can no longer be changed
119 | cfg.freeze()
120 | # print(cfg)
121 | # Initialize learner
122 | learn = learner_init(uid, cfg)
123 | # Train or Test
124 | if not (cfg.only_val or cfg.only_test or cfg.overfit_batch):
125 | learn.fit(epochs=cfg.train.epochs, lr=cfg.train.lr)
126 | if cfg.run_final_val:
127 | print('Running Final Validation using best model')
128 | learn.load_model_dict(
129 | resume_path=learn.model_file,
130 | load_opt=False
131 | )
132 | val_loss, val_acc, _ = learn.validate(
133 | db={'valid': learn.data.valid_dl},
134 | write_to_file=True
135 | )
136 | print(val_loss)
137 | print(val_acc)
138 | else:
139 | pass
140 | else:
141 | if cfg.overfit_batch:
142 | learn.overfit_batch(1000, 1e-4)
143 | if cfg.only_val:
144 | val_loss, val_acc, _ = learn.validate(
145 | db={'valid': learn.data.valid_dl},
146 | write_to_file=True
147 | )
148 | print(val_loss)
149 | print(val_acc)
150 | # learn.testing(learn.data.valid_dl)
151 | pass
152 | if cfg.only_test:
153 | # learn.testing(learn.data.test_dl)
154 | test_loss, test_acc, _ = learn.validate(
155 | db=learn.data.test_dl)
156 | print(test_loss)
157 | print(test_acc)
158 |
159 | return
160 |
161 |
162 | if __name__ == '__main__':
163 | fire.Fire(main_dist)
164 |
--------------------------------------------------------------------------------
/code/mdl_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Base Model and Loss
3 | Other models build on top of this.
4 | Basically, have all the required args here.
5 | """
6 | from torch import nn
7 | from munch import Munch
8 |
9 |
10 | class AnetBaseMdl(nn.Module):
11 | def __init__(self, cfg, comm):
12 | super().__init__()
13 | self.cfg = cfg
14 | # Common stuff that needs to be passed around
15 | if comm is not None:
16 | assert isinstance(comm, (dict, Munch))
17 | self.comm = Munch(comm)
18 | else:
19 | self.comm = Munch()
20 |
21 | self.set_args()
22 | self.after_init()
23 |
24 | def after_init(self):
25 | self.build_model()
26 |
27 | def build_model(self):
28 | self.build_lang_model()
29 | self.build_vis_model()
30 | self.build_conc_model()
31 |
32 | def set_args(self):
33 | """
34 | Place to set all the required arguments, taken from cfg
35 | """
36 | # Vocab size needs to be in the ds
37 | # Can be added after creation of the DATASET
38 | self.vocab_size = self.comm.vocab_size
39 |
40 | # Number of object classes
41 | # Also after creation of dataset.
42 | # Perhaps a good idea to keep all stuff
43 | # to be passed from ds to mdl in a separate
44 | # argument. Could be really helpful
45 | self.detect_size = self.comm.detect_size
46 |
47 | # Input encoding size
48 | # This is the size of the embedding for each word
49 | self.input_encoding_size = self.cfg.mdl.input_encoding_size
50 |
51 | # Hidden dimension of RNN
52 | self.rnn_size = self.cfg.mdl.rnn.rnn_size
53 |
54 | # Number of layers in RNN
55 | self.num_layers = self.cfg.mdl.rnn.num_layers
56 |
57 | # Dropout probability of LM
58 | self.drop_prob_lm = self.cfg.mdl.rnn.drop_prob_lm
59 |
60 | # itod
61 | self.itod = self.comm.itod
62 |
63 | self.num_sampled_frm = self.cfg.ds.num_sampled_frm
64 | self.num_prop_per_frm = self.comm.num_prop_per_frm
65 |
66 | self.unk_idx = int(self.comm.wtoi['UNK'])
67 |
68 | # Temporal attention size
69 | self.t_attn_size = self.cfg.ds.t_attn_size
70 |
71 | # srl_arg_len
72 | self.srl_arg_len = self.cfg.misc.srl_arg_length
73 |
74 | self.set_args_mdl()
75 | self.set_args_conc()
76 |
77 | def set_args_mdl(self):
78 | """
79 | Mdl specific args
80 | """
81 | return
82 |
83 | def set_args_conc(self):
84 | """
85 | Conc Type specific args
86 | """
87 | return
88 |
89 | def build_lang_model(self):
90 | """
91 | How to encode the input sentence
92 | """
93 | raise NotImplementedError
94 |
95 | def build_vis_model(self):
96 | """
97 | How to encode the visual features
98 | How to encode proposal features
99 | and rgb, motion features
100 | """
101 | raise NotImplementedError
102 |
103 | def build_conc_model(self):
104 | """
105 | How to concatenate language and visual features
106 | """
107 | raise NotImplementedError
108 |
109 |
110 | def main():
111 | from _init_stuff import Fpath, Arr, yaml
112 | from yacs.config import CfgNode as CN
113 | from dat_loader_simple import get_data
114 | cfg = CN(yaml.safe_load(open('./configs/anet_srl_cfg.yml')))
115 | data = get_data(cfg)
116 | comm = data.train_dl.dataset.comm
117 | mdl = AnetBaseMdl(cfg, comm)
118 | return mdl
119 |
120 |
121 | if __name__ == '__main__':
122 | main()
123 |
--------------------------------------------------------------------------------
/code/mdl_conc_sep.py:
--------------------------------------------------------------------------------
1 | """
2 | Take care of SEP case.
3 | """
4 |
5 | from mdl_conc_single import ConcBase
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 | from mdl_srl_utils import combine_first_ax
10 | from box_utils import bbox_overlaps
11 |
12 |
13 | class ConcSEP(ConcBase):
14 | def conc_encode(self, conc_feats, inp):
15 | nfrm = self.num_sampled_frm
16 | nppf = self.num_prop_per_frm
17 | ncmp = inp['new_srl_idxs'].size(1)
18 | return self.conc_encode_item(conc_feats, inp, nfrm, nppf, ncmp)
19 |
20 | def simple_obj_interact_input(self, prop_seg_feats, inp):
21 | B, num_cmp, num_props, psdim = prop_seg_feats.shape
22 | return self.simple_obj_interact(
23 | prop_seg_feats, inp,
24 | num_cmp, self.num_sampled_frm,
25 | self.num_prop_per_frm
26 | )
27 |
28 | def set_args_conc(self):
29 | self.nfrms = self.num_sampled_frm
30 | self.nppf = self.num_prop_per_frm
31 |
32 | def get_num_cmp_msk(self, inp, out_shape):
33 | num_cmp = inp['new_srl_idxs'].size(1)
34 | B, num_verbs, num_srl_args, seq_len = inp['srl_arg_words_ind'].shape
35 | num_cmp_msk = inp['num_cmp_msk'].view(
36 | B, num_cmp, 1, 1
37 | ).expand(
38 | B, num_cmp, num_srl_args,
39 | self.num_sampled_frm * self.num_prop_per_frm
40 | ).contiguous(
41 | ).view(*out_shape)
42 | return num_cmp_msk
43 |
44 | def concat_prop_seg_feats(self, prop_feats, seg_feats, inp):
45 | B, num_cmp, num_props, pdim = prop_feats.shape
46 | prop_seg_feats = torch.cat(
47 | [
48 | prop_feats.view(
49 | B, num_cmp, self.num_sampled_frm,
50 | self.num_prop_per_frm, prop_feats.size(-1)
51 | ),
52 | seg_feats.unsqueeze(-2).expand(
53 | B, num_cmp, self.num_sampled_frm,
54 | self.num_prop_per_frm, seg_feats.size(-1)
55 | )
56 | ], dim=-1
57 | ).view(
58 | B, num_cmp, self.num_sampled_frm*self.num_prop_per_frm,
59 | prop_feats.size(-1) + seg_feats.size(-1)
60 | )
61 | # B x num_cmp x nfrm*nppf x psdim
62 | return prop_seg_feats
63 |
64 | def compute_fin_scores(self, conc_out_dict, inp, vidf_outs=None):
65 | """
66 | output fin scores should be of shape
67 | B x num_cmp
68 | prop_scores: B x num_cmp x num_srl_args x num_props
69 | """
70 | prop_scores1 = conc_out_dict['conc_feats_out'].clone().detach()
71 | prop_scores = torch.sigmoid(prop_scores1)
72 | # prop_scores = prop_scores1
73 | if self.cfg.mdl.use_vis_msk:
74 | # B x num_cmp x num_srl_args
75 | prop_scores_max_boxes, _ = torch.max(prop_scores, dim=-1)
76 |
77 | # B x num_cmp x num_srl_args
78 | srl_arg_inds_msk = inp['srl_arg_inds_msk'].float()
79 | B, num_verbs, num_srl_args = srl_arg_inds_msk.shape
80 |
81 | num_cmp = prop_scores.size(1)
82 |
83 | if vidf_outs is not None:
84 | # add vidf outs to the verb places
85 | vidf_outs = torch.sigmoid(vidf_outs)
86 |
87 | # B x num_cmp -> B x num_cmp x num_srl_args
88 | vidf_outs = vidf_outs.unsqueeze(-1).expand(
89 | *prop_scores_max_boxes.shape
90 | )
91 | vmsk = inp['verb_ind_in_srl']
92 |
93 | if vmsk.size(1) == 1 and num_cmp > 1:
94 | vmsk = vmsk.expand(-1, num_cmp)
95 | # B x num_cmp
96 | vmsk = vmsk.view(
97 | B, num_cmp, 1).expand(
98 | B, num_cmp, num_srl_args
99 | )
100 | prop_scores_max_boxes.scatter_(
101 | dim=2,
102 | index=vmsk,
103 | src=vidf_outs
104 | )
105 |
106 | prop_scores_max_boxes = prop_scores_max_boxes * srl_arg_inds_msk
107 |
108 | # b x num_cmp
109 | fin_scores_eval = prop_scores_max_boxes.sum(
110 | dim=-1) / srl_arg_inds_msk.sum(dim=-1)
111 |
112 | verb_msk = inp['num_cmp_msk']
113 | fin_scores_eval = fin_scores_eval * verb_msk.float()
114 |
115 | fin_scores_loss = prop_scores_max_boxes * verb_msk.unsqueeze(
116 | -1).expand(*prop_scores_max_boxes.shape).float()
117 | return {
118 | # B x num_cmp
119 | 'fin_scores_eval': fin_scores_eval,
120 | # B x num_cmp x num_srl_args
121 | 'fin_scores_loss': fin_scores_loss
122 | }
123 |
124 | else:
125 | # B x num_cmp x num_cmp x num_srl_args
126 | prop_scores_max_boxes, _ = torch.max(prop_scores, dim=-1)
127 | # B x num_cmp x num_cmp
128 | fin_scores = prop_scores_max_boxes.sum(dim=-1)
129 | return fin_scores
130 |
131 | def forward(self, inp):
132 | """
133 | Main difference is that prop feats/seg features
134 | have an extra dimension
135 | """
136 | # B x 6 x 5 x 40
137 | # 6 is num_cmp for a sent
138 | # 5 is num args in a sent
139 | # 40 is seq length for each arg
140 | B, num_verbs, num_srl_args, seq_len = inp['srl_arg_words_ind'].shape
141 | # B*num_cmp x seq_len
142 | src_toks = self.get_srl_arg_seq_to_sent_seq(inp)
143 | # B*num_cmp x seq_len
144 | src_lens = inp['srl_arg_word_mask_len'].view(B*num_verbs, -1)
145 | # B*num_cmp x seq_len x 256
146 | lstm_outs = self.lang_encode(src_toks, src_lens)
147 | lstm_encoded = lstm_outs['lstm_full_output']
148 |
149 | # B x num_cmp x 5 x 512
150 | srl_arg_lstm_encoded = self.retrieve_srl_arg_from_lang_encode(
151 | lstm_encoded, inp
152 | )
153 |
154 | # Get visual features
155 | # B x num_cmp x 1000 x 512
156 | prop_feats = self.prop_feats_encode(inp)
157 | # B, num_cmp, num_props, pdim = prop_feats.shape
158 |
159 | # Get seg features
160 | # B x num_cmp x 10 x 512
161 | seg_feats = self.seg_feats_encode(inp)
162 |
163 | # B x num_cmp x nfrm*nppf x psdim
164 | prop_seg_feats = self.concat_prop_seg_feats(prop_feats, seg_feats, inp)
165 |
166 | prop_seg_feats = self.simple_obj_interact_input(
167 | prop_seg_feats, inp
168 | )
169 |
170 | num_cmp = inp['new_srl_idxs'].size(1)
171 | if srl_arg_lstm_encoded.size(1) == 1 and num_cmp > 1:
172 | srl_arg_lstm_encoded = srl_arg_lstm_encoded.expand(
173 | -1, num_cmp, -1, -1
174 | )
175 |
176 | conc_feats = self.concate_vis_lang_feats(
177 | prop_seg_feats, srl_arg_lstm_encoded
178 | )
179 |
180 | # B x num_cmp x num_srl_args x num_props
181 | conc_feats_out_dict = self.conc_encode(conc_feats, inp)
182 | conc_feats_out = conc_feats_out_dict['conc_feats_out']
183 |
184 | seg_feats_for_verb, verb_feats = self.get_seg_verb_feats_to_process(
185 | seg_feats, srl_arg_lstm_encoded, lstm_outs, inp
186 | )
187 |
188 | if verb_feats.size(1) == 1 and num_cmp > 1:
189 | verb_feats = verb_feats.expand(-1, num_cmp, -1)
190 |
191 | # B x num_cmp
192 | vidf_outs = self.compute_seg_verb_feats_out(
193 | seg_feats_for_verb, verb_feats
194 | )
195 | fin_scores = self.compute_fin_scores(
196 | conc_feats_out_dict, inp, vidf_outs
197 | )
198 |
199 | num_cmp_msk = self.get_num_cmp_msk(inp, conc_feats_out.shape)
200 |
201 | srl_ind_msk = inp['srl_arg_inds_msk']
202 | if srl_ind_msk.size(1) == 1 and num_cmp > 1:
203 | srl_ind_msk = srl_ind_msk.expand(
204 | -1, num_cmp, -1, -1
205 | )
206 | srl_ind_msk = srl_ind_msk.unsqueeze(-1).expand(
207 | *conc_feats_out.shape)
208 | mdl_outs_eval = torch.sigmoid(
209 | conc_feats_out) * srl_ind_msk.float() * num_cmp_msk.float()
210 |
211 | return {
212 | 'mdl_outs': conc_feats_out,
213 | 'mdl_outs_eval': mdl_outs_eval,
214 | 'vidf_outs': vidf_outs,
215 | 'fin_scores_loss': fin_scores['fin_scores_loss'],
216 | 'fin_scores': fin_scores['fin_scores_eval']
217 | }
218 |
219 |
220 | class LossB_SEP(nn.Module):
221 | """
222 | Loss Function (for a batch) for SEP case.
223 |     Specifically, we need to have a separate verb loss.
224 |     Also, the handling of some functions is different
225 |     from the single-video case.
226 | """
227 |
228 | def __init__(self, cfg, comm):
229 | super().__init__()
230 | self.cfg = cfg
231 | self.comm = comm
232 | self.loss_keys = ['loss', 'mdl_out_loss', 'verb_loss']
233 | self.loss_lambda = self.cfg.loss.loss_lambda
234 | self.after_init()
235 |
236 | def after_init(self):
237 | pass
238 |
239 | def get_targets_from_overlaps(self, overlaps, inp):
240 | """
241 | Use the given overlaps to produce the targets
242 | overlaps: B x num_cmp x 1000 x 100
243 | """
244 | targets = overlaps
245 |
246 | srl_boxes = inp['srl_boxes']
247 | B, num_verbs, num_srl_args, num_box_per_srl = srl_boxes.shape
248 | B, num_cmp, num_props, num_gt_box = targets.shape
249 |
250 | if num_verbs == 1 and num_cmp > 1:
251 | srl_boxes = srl_boxes.expand(-1, num_cmp, -1, -1)
252 |
253 | srl_boxes_reshaped = srl_boxes.view(
254 | B, num_cmp, num_srl_args, 1, num_box_per_srl).expand(
255 | B, num_cmp, num_srl_args, num_props, num_box_per_srl)
256 |
257 | targets_reshaped = targets.view(
258 | B, num_cmp, 1, num_props, num_gt_box).expand(
259 | B, num_cmp, num_srl_args, num_props, num_gt_box)
260 |
261 | # Choose only those proposals which are ground-truth
262 | # for given srl
263 | targets_to_use = torch.gather(
264 | targets_reshaped, dim=-1, index=srl_boxes_reshaped)
265 |
266 | srl_boxes_lens = inp['srl_boxes_lens']
267 | targets_to_use = (
268 | targets_to_use * srl_boxes_lens.float().unsqueeze(
269 | -2).expand(*targets_to_use.shape)
270 | )
271 |
272 | targets_to_use = targets_to_use.max(dim=-1)[0] > 0.5
273 |
274 | return targets_to_use
275 |
276 | def compute_overlaps(self, inp):
277 |
278 | pad_props = inp['pad_proposals']
279 | gt_bboxs = inp['pad_gt_bboxs']
280 | frm_msk = inp['pad_frm_mask']
281 | pnt_msk = inp['pad_pnt_mask']
282 |
283 | assert len(pnt_msk.shape) == 3
284 |
285 | B = pad_props.size(0)
286 | num_cmp = pad_props.size(1)
287 | pad_props = combine_first_ax(pad_props)
288 | gt_bboxs = combine_first_ax(gt_bboxs)
289 | frm_msk = combine_first_ax(frm_msk)
290 |
291 | pnt_msk = combine_first_ax(pnt_msk)
292 |
293 | overlaps = bbox_overlaps(
294 | pad_props, gt_bboxs,
295 | (frm_msk | pnt_msk[:, :].unsqueeze(-1)))
296 | overlaps = overlaps.view(B, num_cmp, *overlaps.shape[1:])
297 |
298 | return overlaps
299 |
300 | def compute_loss_targets(self, inp):
301 | """
302 | Compute the targets, based on iou
303 | overlaps
304 | """
305 | overlaps = self.compute_overlaps(inp)
306 | B, ncmp, nprop, ngt = overlaps.shape
307 | overlaps_msk = overlaps.new_zeros(*overlaps.shape)
308 |
309 | targ_cmp = inp['target_cmp']
310 | # overlaps_msk[:, targ_cmp, ...] = 1
311 | overlaps_msk.scatter_(
312 | dim=1,
313 | index=targ_cmp.view(B, 1, 1, 1).expand(B, ncmp, nprop, ngt),
314 | src=overlaps_msk.new_ones(*overlaps_msk.shape)
315 | )
316 |
317 | overlaps_one_targ = overlaps * overlaps_msk
318 |
319 | targets_one = self.get_targets_from_overlaps(overlaps_one_targ, inp)
320 | targets_all = self.get_targets_from_overlaps(overlaps, inp)
321 | return {
322 | 'targets_one': targets_one,
323 | 'targets_all': targets_all
324 | }
325 |
326 | def compute_mdl_loss(self, mdl_outs, targets_one, inp):
327 | weights = None
328 | tot_loss = F.binary_cross_entropy_with_logits(
329 | mdl_outs, target=targets_one.float(),
330 | weight=weights,
331 | reduction='none'
332 | )
333 |
334 | # B x num_cmp
335 | num_cmp_msk = inp['num_cmp_msk']
336 | num_cmp = num_cmp_msk.size(1)
337 | srl_arg_boxes_mask = inp['srl_arg_boxes_mask']
338 | num_verbs = srl_arg_boxes_mask.size(1)
339 | if num_verbs == 1 and num_cmp > 1:
340 | srl_arg_boxes_mask = srl_arg_boxes_mask.expand(-1, num_cmp, -1)
341 |
342 | B, num_cmp, num_srl_args = srl_arg_boxes_mask.shape
343 |
344 | boxes_msk = num_cmp_msk.unsqueeze(
345 | -1).expand(*srl_arg_boxes_mask.shape).float()
346 |
347 | # B x num_cmp x num_srl_args -> B x num_cmp x num_srl x 1000
348 | boxes_msk = boxes_msk.unsqueeze(
349 | -1).expand(*targets_one.shape)
350 |
351 | tot_loss = tot_loss * boxes_msk
352 |
353 | multiplier = tot_loss.size(-1)
354 | if srl_arg_boxes_mask.max() > 0:
355 | out_loss = torch.masked_select(tot_loss, boxes_msk.byte())
356 | else:
357 | # TODO: NEED TO check what is wrong here
358 | out_loss = tot_loss
359 |
360 | mdl_out_loss = out_loss.mean() * multiplier
361 |
362 | return mdl_out_loss
363 |
364 | def compute_vidf_loss_simple(self, vidf_outs, inp):
365 | """
366 | vidf_outs are fin scores: B x ncmp x nfrms
367 | """
368 | B, ncmp, nfrm = vidf_outs.shape
369 | targs = vidf_outs.new_zeros(*vidf_outs.shape)
370 |
371 | targ_cmp = inp['target_cmp']
372 |
373 | targs.scatter_(
374 | dim=1,
375 | index=targ_cmp.view(B, 1, 1).expand(B, ncmp, nfrm),
376 | src=targs.new_ones(*targs.shape)
377 | )
378 |
379 | # B x ncmp x nfrms
380 | out_loss = F.binary_cross_entropy(vidf_outs, targs, reduction='none')
381 |
382 | mult = 1. / nfrm
383 |
384 | # B x ncmp
385 | msk = inp['num_cmp_msk']
386 | out_loss = torch.masked_select(out_loss.sum(dim=-1) * msk.float(),
387 | msk.byte()) * mult
388 | return out_loss.mean()
389 |
390 | def compute_vidf_loss(self, vidf_outs, inp):
391 | B, num_cmp, num_srl_args = vidf_outs.shape
392 | box_msk = inp['srl_arg_boxes_mask']
393 | srl_arg_ind_msk = inp['srl_arg_inds_msk']
394 | vidf_outs = ((vidf_outs * box_msk.float()).sum(dim=-1) /
395 | srl_arg_ind_msk.sum(dim=-1).float())
396 | vidf_targs = vidf_outs.new_zeros(*vidf_outs.shape)
397 |
398 | targ_cmp = inp['target_cmp']
399 |
400 | vidf_targs.scatter_(
401 | dim=1,
402 | index=targ_cmp.unsqueeze(-1).expand(*vidf_targs.shape),
403 | src=vidf_targs.new_ones(*vidf_targs.shape)
404 | )
405 |
406 | vidf_loss = F.binary_cross_entropy( #
407 | vidf_outs, vidf_targs,
408 | reduction='none'
409 | )
410 | msk = inp['num_cmp_msk']
411 | vidf_loss = vidf_loss * msk.float()
412 | vidf_loss = torch.masked_select(vidf_loss, msk.byte())
413 | return vidf_loss.mean()
414 |
415 | def forward(self, out, inp):
416 | targets_all = self.compute_loss_targets(inp)
417 | targets_n = targets_all['targets_one']
418 |
419 | mdl_outs = out['mdl_outs']
420 |
421 | mdl_out_loss = self.compute_mdl_loss(mdl_outs, targets_n, inp)
422 |
423 | verb_outs = out['vidf_outs']
424 |
425 | verb_loss = F.binary_cross_entropy_with_logits(
426 | verb_outs,
427 | inp['verb_cmp'].float(),
428 | reduction='none'
429 | )
430 |
431 | vcc_msk = inp['verb_cross_cmp_msk'].float()
432 | vcc_msk = (vcc_msk.sum(dim=-1) > 0).float()
433 |
434 | verb_loss = verb_loss * vcc_msk
435 | verb_loss = torch.masked_select(
436 | verb_loss, vcc_msk.byte()).mean()
437 |
438 | # out_loss = mdl_out_loss + verb_loss
439 | out_loss = mdl_out_loss
440 |
441 | out_loss_dict = {
442 | 'loss': out_loss,
443 | 'mdl_out_loss': mdl_out_loss,
444 | 'verb_loss': verb_loss
445 | }
446 |
447 | return {k: v * self.loss_lambda for k, v in out_loss_dict.items()}
448 |
--------------------------------------------------------------------------------
/code/mdl_conc_single.py:
--------------------------------------------------------------------------------
1 | """
2 | Concatenate to a Single Video
3 | """
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from box_utils import bbox_overlaps
8 |
9 |
10 | class ConcBase(nn.Module):
11 | """
12 | Base Model for concatenation.
13 | Kept for Historical Reasons
14 | """
15 |
16 | def set_args_conc(self):
17 | """
18 | Conc Type specific args
19 | """
20 | return
21 |
22 |
23 | class ConcTEMP(ConcBase):
24 | def conc_encode(self, conc_feats, inp):
25 | ncmp = inp['new_srl_idxs'].size(1)
26 | nfrm = ncmp * self.num_sampled_frm
27 | nppf = self.num_prop_per_frm
28 | return self.conc_encode_item(conc_feats, inp, nfrm, nppf, 1)
29 |
30 | def simple_obj_interact_input(self, prop_seg_feats, inp):
31 | # B, num_cmp, num_props, psdim = prop_seg_feats.shape
32 | num_cmp = inp['new_srl_idxs'].size(1)
33 | return self.simple_obj_interact(
34 | prop_seg_feats, inp, 1,
35 | num_cmp * self.num_sampled_frm,
36 | self.num_prop_per_frm
37 | )
38 |
39 | def get_num_cmp_msk(self, inp, out_shape):
40 | num_cmp = inp['new_srl_idxs'].size(1)
41 | B, num_verbs, num_srl_args, seq_len = inp['srl_arg_words_ind'].shape
42 | num_cmp_msk = inp['num_cmp_msk'].view(
43 | B, 1, 1, num_cmp, 1
44 | ).expand(
45 | B, num_verbs, num_srl_args, num_cmp,
46 | self.num_sampled_frm * self.num_prop_per_frm
47 | ).contiguous(
48 | ).view(*out_shape)
49 | return num_cmp_msk
50 |
51 | def concat_prop_seg_feats(self, prop_feats, seg_feats, inp):
52 | B, num_v_frms, sdim = seg_feats.shape
53 | num_cmp = inp['new_srl_idxs'].size(1)
54 |
55 | prop_seg_feats = torch.cat(
56 | [prop_feats.view(
57 | B, 1, num_cmp * self.num_sampled_frm,
58 | self.num_prop_per_frm, prop_feats.size(-1)),
59 | seg_feats.view(B, 1, num_v_frms, 1, sdim).expand(
60 | B, 1, num_cmp * self.num_sampled_frm,
61 | self.num_prop_per_frm, sdim)
62 | ], dim=-1).view(
63 | B, 1, num_cmp * self.num_sampled_frm * self.num_prop_per_frm,
64 | prop_feats.size(-1) + seg_feats.size(-1)
65 | )
66 | return prop_seg_feats
67 |
68 | def forward(self, inp):
69 | """
70 | Main difference is that prop feats/seg features
71 | have an extra dimension
72 | """
73 | # B x 6 x 5 x 40
74 | # 6 is num_cmp for a sent
75 | # 5 is num args in a sent
76 | # 40 is seq length for each arg
77 | B, num_verbs, num_srl_args, seq_len = inp['srl_arg_words_ind'].shape
78 | # B*num_cmp x seq_len
79 | src_toks = self.get_srl_arg_seq_to_sent_seq(inp)
80 | # B*num_cmp x seq_len
81 | src_lens = inp['srl_arg_word_mask_len'].view(B*num_verbs, -1)
82 | # B*num_cmp x seq_len x 256
83 | lstm_outs = self.lang_encode(src_toks, src_lens)
84 | lstm_encoded = lstm_outs['lstm_full_output']
85 |
86 | # B x 1 x 5 x 512
87 | srl_arg_lstm_encoded = self.retrieve_srl_arg_from_lang_encode(
88 | lstm_encoded, inp
89 | )
90 |
91 | # Get visual features
92 | # B x 40*100 x 512
93 | prop_feats = self.prop_feats_encode(inp)
94 |
95 | # Get seg features
96 | # B x 40 x 512
97 | seg_feats = self.seg_feats_encode(inp)
98 | B, num_v_frms, sdim = seg_feats.shape
99 | # Simple conc seg_feats
100 | prop_seg_feats = self.concat_prop_seg_feats(
101 | prop_feats, seg_feats, inp
102 | )
103 |
104 | # Object Interaction if to be done
105 | prop_seg_feats = self.simple_obj_interact_input(
106 | prop_seg_feats, inp
107 | )
108 |
109 | # B x 1 x num_srl_args x 4*num_props x vf+lf dim
110 | conc_feats = self.concate_vis_lang_feats(
111 | prop_seg_feats, srl_arg_lstm_encoded
112 | )
113 |
114 | # B x num_cmp x num_srl_args x 4*num_props x vf+lf dim
115 | conc_feats_out_dict = self.conc_encode(conc_feats, inp)
116 | conc_feats_out = conc_feats_out_dict['conc_feats_out']
117 |
118 | num_cmp_msk = self.get_num_cmp_msk(inp, conc_feats_out.shape)
119 | srl_ind_msk = inp['srl_arg_inds_msk'].unsqueeze(-1).expand(
120 | *conc_feats_out.shape)
121 | conc_feats_out_eval = torch.sigmoid(
122 | conc_feats_out) * srl_ind_msk.float() * num_cmp_msk.float()
123 |
124 | return {
125 | 'mdl_outs': conc_feats_out,
126 | 'mdl_outs_eval': conc_feats_out_eval,
127 | }
128 |
129 |
130 | class ConcSPAT(ConcTEMP):
131 | def conc_encode(self, conc_feats, inp):
132 | ncmp = inp['new_srl_idxs'].size(1)
133 | nfrm = self.num_sampled_frm
134 | nppf = ncmp * self.num_prop_per_frm
135 | return self.conc_encode_item(conc_feats, inp, nfrm, nppf, 1)
136 |
137 | def simple_obj_interact_input(self, prop_seg_feats, inp):
138 | num_cmp = inp['new_srl_idxs'].size(1)
139 | return self.simple_obj_interact(
140 | prop_seg_feats, inp, 1,
141 | self.num_sampled_frm, num_cmp * self.num_prop_per_frm
142 | )
143 |
144 | def get_num_cmp_msk(self, inp, out_shape):
145 | num_cmp = inp['new_srl_idxs'].size(1)
146 | B, num_verbs, num_srl_args, seq_len = inp['srl_arg_words_ind'].shape
147 | num_cmp_msk = inp['num_cmp_msk'].view(
148 | B, 1, 1, 1, num_cmp, 1
149 | ).expand(
150 | B, num_verbs, num_srl_args, self.num_sampled_frm,
151 | num_cmp, self.num_prop_per_frm
152 | ).contiguous(
153 | ).view(*out_shape)
154 | return num_cmp_msk
155 |
156 | def concat_prop_seg_feats(self, prop_feats, seg_feats, inp):
157 | B, num_v_frms, sdim = seg_feats.shape
158 | num_cmp = inp['new_srl_idxs'].size(1)
159 | prop_seg_feats = torch.cat(
160 | [
161 | prop_feats.view(
162 | B, 1, self.num_sampled_frm * num_cmp,
163 | self.num_prop_per_frm, prop_feats.size(-1)
164 | ), seg_feats.view(B, 1, num_v_frms, 1, sdim).expand(
165 | B, 1, self.num_sampled_frm * num_cmp,
166 | self.num_prop_per_frm, sdim
167 | )
168 | ],
169 | dim=-1
170 | ).view(
171 | B, 1, self.num_sampled_frm * num_cmp * self.num_prop_per_frm,
172 | prop_feats.size(-1) + seg_feats.size(-1)
173 | )
174 | return prop_seg_feats
175 |
176 | def forward(self, inp):
177 | return ConcTEMP.forward(self, inp)
178 |
179 |
180 | class LossB_TEMP(nn.Module):
181 | def __init__(self, cfg, comm):
182 | super().__init__()
183 | self.cfg = cfg
184 | self.comm = comm
185 | self.loss_keys = ['loss', 'mdl_out_loss']
186 | self.loss_lambda = self.cfg.loss.loss_lambda
187 | self.after_init()
188 |
189 | def after_init(self):
190 | pass
191 |
192 | def get_targets_from_overlaps(self, overlaps, inp):
193 | """
194 | Use the given overlaps to produce the targets
195 | overlaps: B x num_cmp x 1000 x 100
196 | """
197 | # to_consider = overlaps > 0.5
198 | targets = overlaps
199 |
200 | srl_boxes = inp['srl_boxes']
201 | # B, num_cmp, num_srl_args, num_box_per_srl = srl_boxes.shape
202 | B, num_verbs, num_srl_args, num_box_per_srl = srl_boxes.shape
203 | B, num_props, num_gt_box = targets.shape
204 |
205 | srl_boxes_reshaped = srl_boxes.view(
206 | B, num_verbs, num_srl_args, 1, num_box_per_srl).expand(
207 | B, num_verbs, num_srl_args, num_props, num_box_per_srl)
208 |
209 | targets_reshaped = targets.view(
210 | B, 1, 1, num_props, num_gt_box).expand(
211 | B, num_verbs, num_srl_args, num_props, num_gt_box)
212 |
213 | # Choose only those proposals which are ground-truth
214 | # for given srl
215 | targets_to_use = torch.gather(
216 | targets_reshaped, dim=-1, index=srl_boxes_reshaped)
217 |
218 | srl_boxes_lens = inp['srl_boxes_lens']
219 | targets_to_use = (
220 | targets_to_use * srl_boxes_lens.float().unsqueeze(
221 | -2).expand(*targets_to_use.shape)
222 | )
223 |
224 | targets_to_use = targets_to_use.max(dim=-1)[0] > 0.5
225 |
226 | return targets_to_use
227 |
228 | def compute_overlaps(self, inp):
229 |
230 | pad_props = inp['pad_proposals']
231 | gt_bboxs = inp['pad_gt_bboxs']
232 | frm_msk = inp['pad_frm_mask']
233 | pnt_msk = inp['pad_pnt_mask']
234 |
235 | try:
236 | overlaps = bbox_overlaps(
237 | pad_props, gt_bboxs,
238 | (frm_msk | pnt_msk.unsqueeze(-1))
239 | )
240 | except:
241 | import pdb
242 | pdb.set_trace()
243 | overlaps = bbox_overlaps(
244 | pad_props, gt_bboxs,
245 | (frm_msk | pnt_msk.unsqueeze(-1)))
246 |
247 | return overlaps
248 |
249 | def compute_loss_targets(self, inp):
250 | """
251 | Compute the targets, based on iou
252 | overlaps
253 | """
254 | num_cmp = inp['new_srl_idxs'].size(1)
255 | overlaps = self.compute_overlaps(inp)
256 | B, num_tot_props, num_gt = overlaps.shape
257 | assert num_tot_props % num_cmp == 0
258 | num_props = num_tot_props // num_cmp
259 | overlaps_msk = overlaps.new_zeros(B, num_cmp, num_props, num_gt)
260 |
261 | targ_cmp = inp['target_cmp']
262 |
263 | overlaps_msk.scatter_(
264 | dim=1,
265 | index=targ_cmp.view(B, 1, 1, 1).expand(
266 | B, num_cmp, num_props, num_gt),
267 | src=overlaps_msk.new_ones(*overlaps_msk.shape)
268 | )
269 |
270 | overlaps_msk = overlaps_msk.view(B, num_tot_props, num_gt)
271 | overlaps_one_targ = overlaps * overlaps_msk
272 | targets_one = self.get_targets_from_overlaps(overlaps_one_targ, inp)
273 | return {
274 | 'targets_one': targets_one,
275 | }
276 |
277 | def compute_mdl_loss(self, mdl_outs, targets_one, inp):
278 | weights = None
279 | tot_loss = F.binary_cross_entropy_with_logits(
280 | mdl_outs, target=targets_one.float(),
281 | weight=weights,
282 | reduction='none'
283 | )
284 |
285 | num_cmp_msk = inp['num_cmp_msk']
286 | B, num_cmp = num_cmp_msk.shape
287 |
288 | srl_arg_boxes_mask = inp['srl_arg_boxes_mask']
289 | B, num_verbs, num_srl_args = srl_arg_boxes_mask.shape
290 |
291 | boxes_msk = (
292 | srl_arg_boxes_mask.view(
293 | B, num_verbs, num_srl_args, 1).expand(
294 | B, num_verbs, num_srl_args, num_cmp).float() *
295 | num_cmp_msk.view(
296 | B, 1, 1, num_cmp).expand(
297 | B, num_verbs, num_srl_args, num_cmp).float()
298 | )
299 | num_props_per_vid = targets_one.size(-1) // num_cmp
300 | # B x num_cmp x num_srl_args -> B x num_cmp x num_srl x 4000
301 | boxes_msk = boxes_msk.unsqueeze(
302 | -1).expand(
303 | B, num_verbs, num_srl_args, num_cmp, num_props_per_vid
304 | ).contiguous().view(
305 | B, num_verbs, num_srl_args, num_cmp * num_props_per_vid)
306 |
307 | multiplier = tot_loss.size(-1)
308 | if srl_arg_boxes_mask.max() > 0:
309 | out_loss = torch.masked_select(tot_loss, boxes_msk.byte())
310 | else:
311 | # TODO: NEED TO check what is wrong here
312 | out_loss = tot_loss
313 | mdl_out_loss = out_loss.mean() * multiplier
314 | # mdl_out_loss = out_loss * 1000
315 | return mdl_out_loss
316 |
317 | def forward(self, out, inp):
318 | targets_all = self.compute_loss_targets(inp)
319 | targets_n = targets_all['targets_one']
320 |
321 | mdl_outs = out['mdl_outs']
322 |
323 | mdl_out_loss = self.compute_mdl_loss(mdl_outs, targets_n, inp)
324 |
325 | out_loss = mdl_out_loss
326 |
327 | out_loss_dict = {
328 | 'loss': out_loss,
329 | 'mdl_out_loss': mdl_out_loss,
330 | }
331 |
332 | return {k: v * self.loss_lambda for k, v in out_loss_dict.items()}
333 |
334 |
335 | class LossB_SPAT(LossB_TEMP):
336 | def after_init(self):
337 | self.loss_keys = ['loss', 'mdl_out_loss']
338 |
339 | self.num_sampled_frm = self.cfg.ds.num_sampled_frm
340 | self.num_prop_per_frm = self.comm.num_prop_per_frm
341 |
342 | def compute_loss_targets(self, inp):
343 | """
344 | Compute the targets, based on iou
345 | overlaps
346 | """
347 | num_cmp = inp['new_srl_idxs'].size(1)
348 | overlaps = self.compute_overlaps(inp)
349 | B, num_tot_props, num_gt = overlaps.shape
350 | assert num_tot_props % num_cmp == 0
351 |
352 | overlaps_msk = overlaps.new_zeros(
353 | B, self.num_sampled_frm, num_cmp,
354 | self.num_prop_per_frm, num_gt
355 | )
356 |
357 | targ_cmp = inp['target_cmp']
358 | overlaps_msk.scatter_(
359 | dim=2,
360 | index=targ_cmp.view(B, 1, 1, 1, 1).expand(
361 | B, self.num_sampled_frm, num_cmp, self.num_prop_per_frm, num_gt
362 | ),
363 | src=overlaps_msk.new_ones(*overlaps_msk.shape)
364 | )
365 |
366 | overlaps_msk = overlaps_msk.view(B, num_tot_props, num_gt)
367 | overlaps_one_targ = overlaps * overlaps_msk
368 | targets_one = self.get_targets_from_overlaps(overlaps_one_targ, inp)
369 | return {
370 | 'targets_one': targets_one,
371 | }
372 |
373 | def compute_mdl_loss(self, mdl_outs, targets_one, inp):
374 | weights = None
375 | tot_loss = F.binary_cross_entropy_with_logits(
376 | mdl_outs, target=targets_one.float(),
377 | weight=weights,
378 | reduction='none'
379 | )
380 |
381 | num_cmp_msk = inp['num_cmp_msk']
382 | B, num_cmp = num_cmp_msk.shape
383 |
384 | srl_arg_boxes_mask = inp['srl_arg_boxes_mask']
385 |
386 | B, num_verbs, num_srl_args = srl_arg_boxes_mask.shape
387 |
388 | boxes_msk = (
389 | srl_arg_boxes_mask.view(
390 | B, num_verbs, num_srl_args, 1).expand(
391 | B, num_verbs, num_srl_args, num_cmp).float() *
392 | num_cmp_msk.view(
393 | B, 1, 1, num_cmp).expand(
394 | B, num_verbs, num_srl_args, num_cmp).float()
395 | )
396 |
397 | num_tot_props = targets_one.size(-1)
398 | # B x num_cmp x num_srl_args -> B x num_cmp x num_srl x 4000
399 | boxes_msk = boxes_msk.view(
400 | B, num_verbs, num_srl_args, 1, num_cmp, 1
401 | ).expand(
402 | B, num_verbs, num_srl_args, self.num_sampled_frm,
403 | num_cmp, self.num_prop_per_frm
404 | ).contiguous().view(
405 | B, num_verbs, num_srl_args, num_tot_props
406 | )
407 |
408 | multiplier = tot_loss.size(-1)
409 | if srl_arg_boxes_mask.max() > 0:
410 | out_loss = torch.masked_select(tot_loss, boxes_msk.byte())
411 | else:
412 | # TODO: NEED TO check what is wrong here
413 | out_loss = tot_loss
414 | mdl_out_loss = out_loss.mean() * multiplier
415 |
416 | return mdl_out_loss
417 |
418 | def forward(self, out, inp):
419 | targets_all = self.compute_loss_targets(inp)
420 | targets_n = targets_all['targets_one']
421 |
422 | mdl_outs = out['mdl_outs']
423 |
424 | mdl_out_loss = self.compute_mdl_loss(mdl_outs, targets_n, inp)
425 |
426 | out_loss = mdl_out_loss
427 |
428 | out_loss_dict = {
429 | 'loss': out_loss,
430 | 'mdl_out_loss': mdl_out_loss,
431 | }
432 |
433 | return {k: v * self.loss_lambda for k, v in out_loss_dict.items()}
434 |
--------------------------------------------------------------------------------
/code/mdl_selector.py:
--------------------------------------------------------------------------------
1 | """
2 | Select the model, loss, eval_fn
3 | """
4 | from mdl_vog import (
5 | ImgGrnd_SEP,
6 | ImgGrnd_TEMP,
7 | ImgGrnd_SPAT,
8 | VidGrnd_SEP,
9 | VidGrnd_TEMP,
10 | VidGrnd_SPAT,
11 | VOG_SEP,
12 | VOG_TEMP,
13 | VOG_SPAT,
14 | LossB_SEP,
15 | LossB_TEMP,
16 | LossB_SPAT
17 | )
18 |
19 | from eval_vsrl_corr import (
20 | EvaluatorSEP,
21 | EvaluatorTEMP,
22 | EvaluatorSPAT
23 | )
24 |
25 |
26 | def get_mdl_loss_eval(cfg):
27 | conc_type = cfg.ds.conc_type
28 | mdl_type = cfg.mdl.name
29 | if conc_type == 'sep' or conc_type == 'svsq':
30 | if mdl_type == 'igrnd':
31 | mdl = ImgGrnd_SEP
32 | elif mdl_type == 'vgrnd':
33 | mdl = VidGrnd_SEP
34 | elif mdl_type == 'vog':
35 | mdl = VOG_SEP
36 | else:
37 | raise NotImplementedError
38 | loss = LossB_SEP
39 | evl = EvaluatorSEP
40 | elif conc_type == 'temp':
41 | if mdl_type == 'igrnd':
42 | mdl = ImgGrnd_TEMP
43 | elif mdl_type == 'vgrnd':
44 | mdl = VidGrnd_TEMP
45 | elif mdl_type == 'vog':
46 | mdl = VOG_TEMP
47 | else:
48 | raise NotImplementedError
49 | loss = LossB_TEMP
50 | evl = EvaluatorTEMP
51 | elif conc_type == 'spat':
52 | if mdl_type == 'igrnd':
53 | mdl = ImgGrnd_SPAT
54 | elif mdl_type == 'vgrnd':
55 | mdl = VidGrnd_SPAT
56 | elif mdl_type == 'vog':
57 | mdl = VOG_SPAT
58 | else:
59 | raise NotImplementedError
60 | loss = LossB_SPAT
61 | evl = EvaluatorSPAT
62 | else:
63 | raise NotImplementedError
64 |
65 | return {
66 | 'mdl': mdl,
67 | 'loss': loss,
68 | 'eval': evl
69 | }
70 |
--------------------------------------------------------------------------------
/code/transformer_code.py:
--------------------------------------------------------------------------------
1 | """
2 | Transformer implementation adapted from
3 | https://github.com/facebookresearch/grounded-video-description/blob/master/misc/transformer.py
4 | """
5 | import torch
6 | import math
7 | from torch import nn
8 | from torch.nn import functional as F
9 |
10 | INF = 1e10
11 |
12 |
13 | def matmul(x, y):
14 | if x.dim() == y.dim():
15 | return torch.matmul(x, y)
16 | if x.dim() == y.dim() - 1:
17 | return torch.matmul(x.unsqueeze(-2), y).squeeze(-2)
18 | return torch.matmul(x, y.unsqueeze(-2)).squeeze(-2)
19 |
20 |
21 | class ResidualBlock(nn.Module):
22 |
23 | def __init__(self, layer, d_model, drop_ratio):
24 | super(ResidualBlock, self).__init__()
25 | self.layer = layer
26 | self.dropout = nn.Dropout(drop_ratio)
27 | # self.layernorm = LayerNorm(d_model)
28 | self.layernorm = nn.LayerNorm(d_model)
29 |
30 | def forward(self, *x):
31 | return self.layernorm(x[0] + self.dropout(self.layer(*x)))
32 |
33 |
34 | class Attention(nn.Module):
35 |
36 | def __init__(self, d_key, drop_ratio, causal):
37 | super(Attention, self).__init__()
38 | self.scale = math.sqrt(d_key)
39 | self.dropout = nn.Dropout(drop_ratio)
40 | self.causal = causal
41 |
42 | def forward(self, query, key, value):
43 | dot_products = matmul(query, key.transpose(1, 2))
44 | if query.dim() == 3 and (self is None or self.causal):
45 | tri = torch.ones(key.size(1), key.size(1)).triu(1) * INF
46 | if key.is_cuda:
47 | tri = tri.cuda(key.get_device())
48 | dot_products.data.sub_(tri.unsqueeze(0))
49 |
50 | return matmul(self.dropout(F.softmax(dot_products / self.scale, dim=-1)), value)
51 |
52 |
53 | class MultiHead(nn.Module):
54 |
55 | def __init__(self, d_key, d_value, n_heads, drop_ratio, causal=False):
56 | super(MultiHead, self).__init__()
57 | self.attention = Attention(d_key, drop_ratio, causal=causal)
58 | self.wq = nn.Linear(d_key, d_key, bias=False)
59 | self.wk = nn.Linear(d_key, d_key, bias=False)
60 | self.wv = nn.Linear(d_value, d_value, bias=False)
61 | self.wo = nn.Linear(d_value, d_key, bias=False)
62 | self.n_heads = n_heads
63 |
64 | def forward(self, query, key, value):
65 | query, key, value = self.wq(query), self.wk(key), self.wv(value)
66 |
67 | query, key, value = (
68 | x.chunk(self.n_heads, -1) for x in (query, key, value))
69 | return self.wo(torch.cat([self.attention(q, k, v)
70 | for q, k, v in zip(query, key, value)], -1))
71 |
72 |
73 | class FeedForward(nn.Module):
74 |
75 | def __init__(self, d_model, d_hidden):
76 | super(FeedForward, self).__init__()
77 | self.linear1 = nn.Linear(d_model, d_hidden)
78 | self.linear2 = nn.Linear(d_hidden, d_model)
79 |
80 | def forward(self, x):
81 | return self.linear2(F.relu(self.linear1(x)))
82 |
83 |
84 | class EncoderLayer(nn.Module):
85 |
86 | def __init__(self, d_model, d_hidden, n_heads, drop_ratio):
87 | super(EncoderLayer, self).__init__()
88 | self.selfattn = ResidualBlock(
89 | MultiHead(d_model, d_model, n_heads, drop_ratio),
90 | d_model, drop_ratio)
91 | self.feedforward = ResidualBlock(FeedForward(d_model, d_hidden),
92 | d_model, drop_ratio)
93 |
94 | def forward(self, x):
95 | return self.feedforward(self.selfattn(x, x, x))
96 |
97 |
98 | class Encoder(nn.Module):
99 |
100 | def __init__(self, d_model, d_hidden, n_vocab, n_layers, n_heads,
101 | drop_ratio, pe):
102 | super(Encoder, self).__init__()
103 | # self.linear = nn.Linear(d_model*2, d_model)
104 | self.layers = nn.ModuleList(
105 | [EncoderLayer(d_model, d_hidden, n_heads, drop_ratio)
106 | for i in range(n_layers)])
107 | self.dropout = nn.Dropout(drop_ratio)
108 | self.pe = pe
109 |
110 | def forward(self, x, mask=None):
111 | # x = self.linear(x)
112 | if self.pe:
113 | # spatial configuration is already encoded
114 | # x = x+positional_encodings_like(x)
115 | raise NotImplementedError
116 | # x = self.dropout(x) # dropout is already in the pool_embed layer
117 | if mask is not None:
118 | x = x*mask
119 | encoding = []
120 | for layer in self.layers:
121 | x = layer(x)
122 | if mask is not None:
123 | x = x*mask
124 | encoding.append(x)
125 | return encoding
126 |
127 |
128 | class RelAttention(nn.Module):
129 |
130 | def __init__(self, d_key, drop_ratio, causal):
131 | super().__init__()
132 | self.scale = math.sqrt(d_key)
133 | self.dropout = nn.Dropout(drop_ratio)
134 | self.causal = causal
135 |
136 | def forward(self, query, key, value, pe_k, pe_v):
137 | """
138 | query, key, value: B x N x 214
139 | pe_k: B x N x N x 214
140 | """
141 | dot_products = matmul(query, key.transpose(1, 2))
142 | if query.dim() == 3 and (self is None or self.causal):
143 | tri = torch.ones(key.size(1), key.size(1)).triu(1) * INF
144 | if key.is_cuda:
145 | tri = tri.cuda(key.get_device())
146 | dot_products.data.sub_(tri.unsqueeze(0))
147 |
148 | # new_dp = matmul(query, pe_k.transpose(2, 3))
149 | new_dp = pe_k.squeeze(-1)
150 | assert new_dp.shape == dot_products.shape
151 | new_dot_prods = (dot_products + new_dp) / self.scale
152 |
153 | attn = self.dropout(F.softmax(new_dot_prods, dim=-1))
154 |
155 | out_v = matmul(attn, value)
156 | # new_out_v = matmul(attn, pe_v)
157 | # new_out_v = pe_v
158 |
159 | new_outs = out_v
160 | return new_outs
161 |
162 |
163 | class RelMultiHead(nn.Module):
164 |
165 | def __init__(self, d_key, d_value, n_heads, drop_ratio, causal=False, d_pe=None):
166 | super().__init__()
167 | self.attention = RelAttention(d_key, drop_ratio, causal=causal)
168 | self.n_heads = n_heads
169 | self.wq = nn.Linear(d_key, d_key, bias=False)
170 | self.wk = nn.Linear(d_key, d_key, bias=False)
171 | self.wv = nn.Linear(d_value, d_value, bias=False)
172 | self.wo = nn.Linear(d_value, d_key, bias=False)
173 | # self.wpk = nn.Linear(d_pe, self.n_heads, bias=False)
174 | # self.wpv = nn.Linear(d_pe, self.n_heads, bias=False)
175 |
176 | def forward(self, query, key, value, pe=None):
177 | """
178 | pe is B x N x N x 1 position difference
179 | """
180 | query, key, value = self.wq(query), self.wk(key), self.wv(value)
181 | pe_k, pe_v = pe, pe
182 | query, key, value, pe_k, pe_v = (
183 | x.chunk(self.n_heads, -1) for x in (query, key, value, pe_k, pe_v))
184 | return self.wo(torch.cat([self.attention(q, k, v, pk, pv)
185 | for q, k, v, pk, pv in
186 | zip(query, key, value, pe_k, pe_v)], -1))
187 |
188 |
189 | class RelEncoderLayer(nn.Module):
190 |
191 | def __init__(self, d_model, d_hidden, n_heads,
192 | drop_ratio, d_pe=None, sa=True):
193 | super().__init__()
194 | self.selfattn = ResidualBlock(
195 | RelMultiHead(d_model, d_model, n_heads, drop_ratio, d_pe=d_pe),
196 | d_model, drop_ratio)
197 | self.feedforward = ResidualBlock(FeedForward(d_model, d_hidden),
198 | d_model, drop_ratio)
199 | self.sa = sa
200 |
201 | def forward(self, x, pe=None):
202 | if not isinstance(x, dict):
203 | return self.feedforward(self.selfattn(x, x, x, pe))
204 | else:
205 | assert not self.sa
206 | assert isinstance(x, dict)
207 | assert 'query' in x
208 | assert 'key' in x
209 | assert 'value' in x
210 | return self.feedforward(
211 | self.selfattn(x['query'], x['key'], x['value'], pe)
212 | )
213 |
214 |
215 | class RelEncoder(nn.Module):
216 |
217 | def __init__(self, d_model, d_hidden, n_vocab, n_layers, n_heads,
218 | drop_ratio, pe, d_pe, sa=True):
219 | super().__init__()
220 | # self.linear = nn.Linear(d_model*2, d_model)
221 | self.layers = nn.ModuleList(
222 | [RelEncoderLayer(d_model, d_hidden, n_heads, drop_ratio, d_pe=d_pe, sa=sa)
223 | for i in range(n_layers)])
224 | self.dropout = nn.Dropout(drop_ratio)
225 | self.pe = pe
226 |
227 | def forward(self, x, x_pe, mask=None):
228 | # x = self.linear(x)
229 | if self.pe:
230 | # spatial configuration is already encoded
231 | raise NotImplementedError
232 | # x = self.dropout(x) # dropout is already in the pool_embed layer
233 | if mask is not None:
234 | x = x*mask
235 | encoding = []
236 | for layer in self.layers:
237 | x = layer(x, pe=x_pe)
238 | if mask is not None:
239 | x = x*mask
240 | encoding.append(x)
241 | return encoding
242 |
243 |
244 | class Transformer(nn.Module):
245 |
246 | def __init__(self, d_model, n_vocab_src, vocab_trg, d_hidden=2048,
247 | n_layers=6, n_heads=8, drop_ratio=0.1, pe=False):
248 | super(Transformer, self).__init__()
249 | self.encoder = Encoder(d_model, d_hidden, n_vocab_src, n_layers,
250 | n_heads, drop_ratio, pe)
251 |
252 | def forward(self, x):
253 | encoding = self.encoder(x)
254 | return encoding[-1]
255 | # return encoding[-1], encoding
256 | # return torch.cat(encoding, 2)
257 |
258 | def all_outputs(self, x):
259 | encoding = self.encoder(x)
260 | return encoding
261 |
262 |
263 | class RelTransformer(nn.Module):
264 |
265 | def __init__(self, d_model, n_vocab_src, vocab_trg, d_hidden=2048,
266 | n_layers=6, n_heads=8, drop_ratio=0.1, pe=False, d_pe=None):
267 | super().__init__()
268 | self.encoder = RelEncoder(d_model, d_hidden, n_vocab_src, n_layers,
269 | n_heads, drop_ratio, pe, d_pe=d_pe)
270 |
271 | def forward(self, x, x_pe):
272 | encoding = self.encoder(x, x_pe)
273 | return encoding[-1]
274 | # return encoding[-1], encoding
275 | # return torch.cat(encoding, 2)
276 |
277 |     def all_outputs(self, x, x_pe):
278 |         encoding = self.encoder(x, x_pe)
279 |         return encoding
280 |
--------------------------------------------------------------------------------
/code/visualizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualize Predictions
3 | """
4 |
5 | import pandas as pd
6 | import pickle
7 | from PIL import Image
8 | from pathlib import Path
9 | from eval_fn_corr import (
10 | GroundEval_SEP,
11 | GroundEval_TEMP,
12 | GroundEval_SPAT
13 | )
14 | import fire
15 | from munch import Munch
16 | from typing import List
17 |
18 |
19 | class ASRL_Vis:
20 | def open_required_files(self, ann_file):
21 | self.annots = pd.read_csv(ann_file)
22 |
23 | def draw_boxes_all_indices(self, preds):
24 | # self.preds
25 | pass
26 |
27 | def prepare_img(self, img_list: List):
28 | """
29 |         How to concatenate the images from an image list
30 | """
31 | raise NotImplementedError
32 |
33 | def extract_bbox_per_frame(self, preds):
34 | """
35 | Obtain the bounding boxes for each frame
36 | """
37 | raise NotImplementedError
38 |
39 | def all_inds(self, pred_file, split_type):
40 | self.prepare_gt(split_type)
41 |
42 | def draw_boxes_one_index(
43 | self, pred, gt_row, conc_type
44 | ):
45 | frm_tdir = Path('/home/Datasets/ActNetEnt/frames_10frm/')
46 | vid_file_id_list = pred['idx_vid']
47 |
48 | rows = self.annots.iloc[vid_file_id_list]
49 | vid_seg_id_list = rows['id']
50 |
51 | img_file_dict = {
52 | k: sorted(
53 | [x for x in (frm_tdir/k).iterdir()],
54 | key=lambda x: int(x.stem)
55 | )
56 | for k in vid_seg_id_list
57 | }
58 | img_list_dict = {
59 | k: [Image.open(img_file) for img_file in img_file_list]
60 | for k, img_file_list in img_file_dict.items()
61 | }
62 |
63 | img = self.prepare_img(img_list_dict)
64 | pass
65 |
66 |
67 | class ASRL_Vis_SEP(GroundEval_SEP, ASRL_Vis):
68 | pass
69 |
70 |
71 | class ASRL_Vis_TEMP(GroundEval_TEMP, ASRL_Vis):
72 | pass
73 |
74 |
75 | class ASRL_Vis_SPAT(GroundEval_SPAT, ASRL_Vis):
76 | pass
77 |
78 |
79 | def main(pred_file, split_type='valid', **kwargs):
80 | if 'cfg' not in kwargs:
81 | from extended_config import (
82 | cfg as conf,
83 | key_maps,
84 | # CN,
85 | update_from_dict,
86 | # post_proc_config
87 | )
88 | cfg = conf
89 | cfg = update_from_dict(cfg, kwargs, key_maps)
90 | else:
91 | cfg = kwargs['cfg']
92 | cfg.freeze()
93 | # grnd_eval = GroundEval_Corr(cfg)
94 | # grnd_eval = GroundEvalDS4(cfg)
95 | comm = Munch()
96 | exp = cfg.ds.exp_setting
97 | if exp == 'gt5':
98 | comm.num_prop_per_frm = 5
99 | elif exp == 'p100':
100 | comm.num_prop_per_frm = 100
101 | else:
102 | raise NotImplementedError
103 |
104 | conc_type = cfg.ds.conc_type
105 | if conc_type == 'sep' or conc_type == 'svsq':
106 | avis = ASRL_Vis_SEP(cfg, comm)
107 | elif conc_type == 'temp':
108 | avis = ASRL_Vis_TEMP(cfg, comm)
109 | elif conc_type == 'spat':
110 | avis = ASRL_Vis_SPAT(cfg, comm)
111 | else:
112 | raise NotImplementedError
113 |
114 | # avis.draw_boxes_all_indices(
115 | # pred_file, split_type=split_type
116 | # )
117 |
118 | return avis
119 |
120 |
121 | if __name__ == '__main__':
122 | fire.Fire(main)
123 |
--------------------------------------------------------------------------------
/conda_env_vog.yml:
--------------------------------------------------------------------------------
1 | name: vog_pyt
2 | channels:
3 | - pytorch
4 | - fastai
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - attrs=19.3.0=py_0
9 | - backcall=0.1.0=py36_0
10 | - blas=1.0=mkl
11 | - bleach=3.1.0=py_0
12 | - ca-certificates=2020.1.1=0
13 | - certifi=2019.11.28=py36_1
14 | - cffi=1.14.0=py36h2e261b9_0
15 | - cudatoolkit=10.0.130=0
16 | - dbus=1.13.12=h746ee38_0
17 | - decorator=4.4.2=py_0
18 | - defusedxml=0.6.0=py_0
19 | - entrypoints=0.3=py36_0
20 | - expat=2.2.6=he6710b0_0
21 | - fastprogress=0.1.21=py_0
22 | - fontconfig=2.13.0=h9420a91_0
23 | - freetype=2.9.1=h8a8886c_1
24 | - glib=2.63.1=h5a9c865_0
25 | - gmp=6.1.2=h6c8ec71_1
26 | - gst-plugins-base=1.14.0=hbbd80ab_1
27 | - gstreamer=1.14.0=hb453b48_1
28 | - icu=58.2=h9c2bf20_1
29 | - importlib_metadata=1.5.0=py36_0
30 | - intel-openmp=2020.0=166
31 | - ipykernel=5.1.4=py36h39e3cac_0
32 | - ipython=7.13.0=py36h5ca1d4c_0
33 | - ipython_genutils=0.2.0=py36_0
34 | - ipywidgets=7.5.1=py_0
35 | - jedi=0.16.0=py36_1
36 | - jinja2=2.11.1=py_0
37 | - jpeg=9b=h024ee3a_2
38 | - jsonschema=3.2.0=py36_0
39 | - jupyter=1.0.0=py36_7
40 | - jupyter_client=6.1.0=py_0
41 | - jupyter_console=6.1.0=py_0
42 | - jupyter_core=4.6.1=py36_0
43 | - ld_impl_linux-64=2.33.1=h53a641e_7
44 | - libedit=3.1.20181209=hc058e9b_0
45 | - libffi=3.2.1=hd88cf55_4
46 | - libgcc-ng=9.1.0=hdf63c60_0
47 | - libgfortran-ng=7.3.0=hdf63c60_0
48 | - libpng=1.6.37=hbc83047_0
49 | - libsodium=1.0.16=h1bed415_0
50 | - libstdcxx-ng=9.1.0=hdf63c60_0
51 | - libtiff=4.1.0=h2733197_0
52 | - libuuid=1.0.3=h1bed415_2
53 | - libxcb=1.13=h1bed415_1
54 | - libxml2=2.9.9=hea5a465_1
55 | - markupsafe=1.1.1=py36h7b6447c_0
56 | - mistune=0.8.4=py36h7b6447c_0
57 | - mkl=2020.0=166
58 | - mkl-service=2.3.0=py36he904b0f_0
59 | - mkl_fft=1.0.15=py36ha843d7b_0
60 | - mkl_random=1.1.0=py36hd6b4f25_0
61 | - nbconvert=5.6.1=py36_0
62 | - nbformat=5.0.4=py_0
63 | - ncurses=6.2=he6710b0_0
64 | - ninja=1.9.0=py36hfd86e86_0
65 | - notebook=6.0.3=py36_0
66 | - numpy=1.18.1=py36h4f9e942_0
67 | - numpy-base=1.18.1=py36hde5b4d6_1
68 | - olefile=0.46=py_0
69 | - openssl=1.1.1e=h7b6447c_0
70 | - pandas=1.0.3=py36h0573a6f_0
71 | - pandoc=2.2.3.2=0
72 | - pandocfilters=1.4.2=py36_1
73 | - parso=0.6.2=py_0
74 | - pcre=8.43=he6710b0_0
75 | - pexpect=4.8.0=py36_0
76 | - pickleshare=0.7.5=py36_0
77 | - pillow=7.0.0=py36hb39fc2d_0
78 | - pip=20.0.2=py36_1
79 | - prometheus_client=0.7.1=py_0
80 | - prompt-toolkit=3.0.4=py_0
81 | - prompt_toolkit=3.0.4=0
82 | - ptyprocess=0.6.0=py36_0
83 | - pycparser=2.20=py_0
84 | - pygments=2.6.1=py_0
85 | - pyqt=5.9.2=py36h05f1152_2
86 | - pyrsistent=0.15.7=py36h7b6447c_0
87 | - python=3.6.10=hcf32534_1
88 | - python-dateutil=2.8.1=py_0
89 | - pytorch=1.1.0=py3.6_cuda10.0.130_cudnn7.5.1_0
90 | - pytz=2019.3=py_0
91 | - pyzmq=18.1.1=py36he6710b0_0
92 | - qt=5.9.7=h5867ecd_1
93 | - qtconsole=4.7.2=py_0
94 | - qtpy=1.9.0=py_0
95 | - readline=8.0=h7b6447c_0
96 | - send2trash=1.5.0=py36_0
97 | - setuptools=46.1.1=py36_0
98 | - sip=4.19.8=py36hf484d3e_0
99 | - six=1.14.0=py36_0
100 | - sqlite=3.31.1=h7b6447c_0
101 | - terminado=0.8.3=py36_0
102 | - testpath=0.4.4=py_0
103 | - tk=8.6.8=hbc83047_0
104 | - torchvision=0.3.0=py36_cu10.0.130_1
105 | - tornado=6.0.4=py36h7b6447c_1
106 | - traitlets=4.3.3=py36_0
107 | - wcwidth=0.1.8=py_0
108 | - webencodings=0.5.1=py36_1
109 | - wheel=0.34.2=py36_0
110 | - widgetsnbextension=3.5.1=py36_0
111 | - xz=5.2.4=h14c3975_4
112 | - zeromq=4.3.1=he6710b0_3
113 | - zipp=2.2.0=py_0
114 | - zlib=1.2.11=h7b6447c_3
115 | - zstd=1.3.7=h0b5b093_0
116 | - pip:
117 | - absl-py==0.9.0
118 | - alabaster==0.7.12
119 | - allennlp==0.8.5
120 | - babel==2.8.0
121 | - blis==0.2.4
122 | - boto3==1.12.31
123 | - botocore==1.15.31
124 | - cachetools==4.0.0
125 | - chardet==3.0.4
126 | - click==7.1.1
127 | - conllu==1.3.1
128 | - cycler==0.10.0
129 | - cymem==2.0.3
130 | - cython==0.29.16
131 | - dataclasses==0.7
132 | - docutils==0.15.2
133 | - editdistance==0.5.3
134 | - fairseq==0.8.0
135 | - fastbpe==0.1.0
136 | - fire==0.3.0
137 | - flaky==3.6.1
138 | - flask==1.1.1
139 | - flask-cors==3.0.8
140 | - ftfy==5.7
141 | - future==0.18.2
142 | - gevent==1.4.0
143 | - google-auth==1.12.0
144 | - google-auth-oauthlib==0.4.1
145 | - greenlet==0.4.15
146 | - grpcio==1.27.2
147 | - h5py==2.10.0
148 | - idna==2.9
149 | - imagesize==1.2.0
150 | - itsdangerous==1.1.0
151 | - jmespath==0.9.5
152 | - joblib==0.14.1
153 | - jsonnet==0.15.0
154 | - jsonpickle==1.3
155 | - kiwisolver==1.1.0
156 | - markdown==3.2.1
157 | - matplotlib==3.2.1
158 | - more-itertools==8.2.0
159 | - munch==2.5.0
160 | - murmurhash==1.0.2
161 | - nltk==3.4.5
162 | - numpydoc==0.9.2
163 | - oauthlib==3.1.0
164 | - overrides==2.8.0
165 | - packaging==20.3
166 | - parsimonious==0.8.1
167 | - plac==0.9.6
168 | - pluggy==0.13.1
169 | - portalocker==1.6.0
170 | - preshed==2.0.1
171 | - protobuf==3.11.3
172 | - py==1.8.1
173 | - pyasn1==0.4.8
174 | - pyasn1-modules==0.2.8
175 | - pyparsing==2.4.6
176 | - pytest==5.4.1
177 | - pytorch-pretrained-bert==0.6.2
178 | - pytorch-transformers==1.1.0
179 | - pyyaml==5.3.1
180 | - regex==2020.2.20
181 | - requests==2.23.0
182 | - requests-oauthlib==1.3.0
183 | - responses==0.10.12
184 | - rsa==4.0
185 | - s3transfer==0.3.3
186 | - sacrebleu==1.4.4
187 | - scikit-learn==0.22.2.post1
188 | - scipy==1.4.1
189 | - sentencepiece==0.1.85
190 | - snowballstemmer==2.0.0
191 | - spacy==2.1.9
192 | - sphinx==2.4.4
193 | - sphinxcontrib-applehelp==1.0.2
194 | - sphinxcontrib-devhelp==1.0.2
195 | - sphinxcontrib-htmlhelp==1.0.3
196 | - sphinxcontrib-jsmath==1.0.1
197 | - sphinxcontrib-qthelp==1.0.3
198 | - sphinxcontrib-serializinghtml==1.1.4
199 | - sqlparse==0.3.1
200 | - srsly==1.0.2
201 | - tensorboard==2.2.0
202 | - tensorboard-plugin-wit==1.6.0.post2
203 | - tensorboardx==2.0
204 | - termcolor==1.1.0
205 | - thinc==7.0.8
206 | - torchtext==0.5.0
207 | - tqdm==4.43.0
208 | - typing==3.7.4.1
209 | - unidecode==1.1.1
210 | - urllib3==1.25.8
211 | - wasabi==0.6.0
212 | - werkzeug==1.0.0
213 | - word2number==1.1
214 | - yacs==0.1.6
215 | prefix: /home/arka/.conda/envs/vog_pyt
216 |
217 |
--------------------------------------------------------------------------------
/configs/anet_srl_cfg.yml:
--------------------------------------------------------------------------------
1 | ds_name: "anet"
2 | ds:
3 | # where to find the rgb+flow data
4 | seg_feature_root: "data/anet/rgb_motion_1d"
5 | # choose one setting
6 | exp_setting: "gt5" #or "p100"
7 | gt5:
8 | # bounding boxes from FasterRCNN
9 | proposal_h5: "data/anet/anet_detection_vg_fc6_feat_gt5_rois.h5"
10 | # extracted features from FasterRCNN
11 | feature_root: "data/anet/fc6_feat_5rois"
12 | # number of proposals considered per frame
13 | num_prop_per_frm: 5
14 | p100:
15 | proposal_h5: "data/anet/anet_detection_vg_fc6_feat_100rois_resized.h5"
16 | feature_root: "data/anet/fc6_feat_100rois"
17 | num_prop_per_frm: 100
18 | resized_width: 720
19 | resized_height: 405
20 | num_sampled_frm: 10
21 | max_gt_box: 100
22 | t_attn_size: 480
23 | max_seq_length: 20
24 | anet_cap_file: "data/anet_cap_ent_files/anet_captions_all_splits.json"
25 | anet_ent_annot_file: "data/anet_cap_ent_files/anet_ent_cls_bbox_trainval.json"
26 | anet_ent_split_file: "data/anet_cap_ent_files/dic_anet.json"
27 | include_srl_args: ['ARG0', 'ARG1', 'ARG2', 'ARGM-LOC']
28 | # Vocab file for SRLs
29 | arg_vocab_file: "data/anet_srl_files/arg_vocab.pkl"
30 | # Annot files:
31 | trn_ann_file: "data/anet_cap_ent_files/csv_dir/train_postproc.csv"
32 | val_ann_file: "data/anet_cap_ent_files/csv_dir/val_postproc.csv"
33 | # Object Mappings:
34 | trn_ds4_dicts: "data/anet_srl_files/trn_srl_obj_to_index_dict.json"
35 | val_ds4_dicts: "data/anet_srl_files/val_srl_obj_to_index_dict.json"
36 | # ASRL with indices for SPAT/TEMP
37 | trn_ds4_inds: "data/anet_srl_files/trn_asrl_annots.csv"
38 | val_ds4_inds: "data/anet_srl_files/val_asrl_annots.csv"
39 | # Sampling mechanism
40 | trn_sample: "ds4_random"
41 | val_sample: "ds4"
42 | # Num Vids Sampled at a time (should be an int)
43 | trn_num_vid_sample: 4
44 | val_num_vid_sample: 4
45 | # Type of Concatenation, choose among ['svsq', 'sep', 'temp', 'spat']
46 | conc_type: 'spat'
47 | # Shuffle:
48 | cs_shuffle: True
49 | none_word: ""
50 |
51 | mdl:
52 | name: 'vog'
53 | seg_feat_dim: 3072
54 | prop_feat_dim: 2048
55 | input_encoding_size: 512
56 | use_vis_msk: True
57 | rnn:
58 | rnn_size: 1024
59 | num_layers: 2
60 | drop_prob_lm: 0.5
61 | vsrl:
62 | prop_encode_size: 256
63 | seg_encode_size: 256
64 | lang_encode_size: 256
65 | obj_tx:
66 | use_ddp: false
67 | to_use: true
68 | n_layers: 1
69 | n_heads: 3
70 | attn_drop: 0.2
71 | use_rel: false
72 | one_frm: false
73 | mul_tx:
74 | use_ddp: false
75 | to_use: true
76 | n_layers: 1
77 | n_heads: 3
78 | attn_drop: 0.2
79 | use_rel: false
80 | one_frm: true
81 | cross_frm: false
82 | loss:
83 | only_vid_loss: false
84 | loss_lambda: 1
85 | loss_margin: 0.1
86 | loss_margin_vid: 0.5
87 | # loss_type is either
88 | # cosine or bce
89 | loss_type: 'bce'
90 |
91 | misc:
92 | # Place to save models/logs/predictions etc
93 | tmp_path: "tmp"
94 | # Include/Exclude proposal based on the threshold
95 | prop_thresh: 0.
96 | # Whether to exclude the proposals having background class
97 | exclude_bgd_det: False
98 | # Whether to add the proposal (5d coordinate) to
99 | # the region feature
100 | add_prop_to_region: False
101 | # What context to use for average pooling segment features
102 | ctx_for_seg_feats: 0
103 | # max number of semantic roles in a sentence
104 | srl_arg_length: 5
105 | # how many boxes to consider for a particular phrase
106 | box_per_srl_arg: 4
107 | train:
108 | lr: 1e-4
109 | epochs: 10
110 | bs: 4
111 | nw: 4
112 | bsv: 4
113 | nwv: 4
114 | resume: true
115 | resume_path: ""
116 | load_opt: false
117 | load_normally: true
118 | strict_load: true
119 | use_reduce_lr_plateau: false
120 | verbose: false
121 | prob_thresh: 0.2
122 | log:
123 | deb_it: 2
124 | local_rank: 0
125 | do_dist: False
126 | do_dp: false
127 | num_gpus: 1
128 | only_val: false
129 | only_test: false
130 | run_final_val: true
131 | overfit_batch: false
132 |
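133 | # A minimal sketch (not used by the training code itself) of how the
134 | # `exp_setting` switch above selects one of the two sub-blocks, assuming the
135 | # config is loaded with yacs as in code/mdl_base.py:
136 | #
137 | #   import yaml
138 | #   from yacs.config import CfgNode as CN
139 | #   cfg = CN(yaml.safe_load(open('./configs/anet_srl_cfg.yml')))
140 | #   active = cfg.ds[cfg.ds.exp_setting]  # the 'gt5' or 'p100' block
141 | #   print(active.proposal_h5, active.feature_root, active.num_prop_per_frm)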
--------------------------------------------------------------------------------
/configs/create_asrl_cfg.yml:
--------------------------------------------------------------------------------
1 | ds_name: "asrl"
2 | ds:
3 | # AC/AE annotation files
4 | anet_cap_file: "data/anet_cap_ent_files/anet_captions_all_splits.json"
5 | anet_ent_split_file: "data/anet_cap_ent_files/dic_anet.json"
6 | anet_ent_annot_file: "data/anet_cap_ent_files/cap_anet_trainval.json"
7 | orig_anet_ent_clss: "data/anet_cap_ent_files/anet_entities_cleaned_class_thresh50_trainval.json"
8 | preproc_anet_ent_clss: "data/anet_cap_ent_files/anet_ent_cls_bbox_trainval.json"
9 | # After adding semantic roles, these are generated inside the cache dir
10 | srl_caps: "SRL_Anet_cap_annots.csv"
11 | srl_bert: "srl_bert_preds.pkl"
12 | # Resized width, height
13 | resized_width: 720
14 | resized_height: 405
15 | # Feature files
16 | vid_hw_map: "data/anet/vid_hw_dict.json"
17 | proposal_h5: "data/anet/anet_detection_vg_fc6_feat_100rois.h5"
18 | proposal_h5_resized: "data/anet/anet_detection_vg_fc6_feat_100rois_resized.h5"
19 | seg_feature_root: "data/anet/rgb_motion_1d"
20 | feature_root: "data/anet/fc6_feat_100rois"
21 | # verbs and arguments to include/exclude
22 | exclude_verb_set: ['be', 'see', 'show', "'s", 'can', 'continue', 'begin', 'start']
23 | include_srl_args: ['ARG0', 'ARG1', 'ARG2', 'ARGM-LOC']
24 | # Lemmatized verb list (created only once)
25 | verb_lemma_dict_file: "data/anet_srl_files/verb_lemma_dict.json"
26 | # SRL with verbs
27 | verb_ent_file: "data/anet_srl_files/verb_ent_file.csv"
28 | trn_verb_ent_file: "data/anet_srl_files/trn_verb_ent_file.csv"
29 | val_verb_ent_file: "data/anet_srl_files/val_verb_ent_file.csv"
30 | # Object Mappings:
31 | trn_ds4_dicts: "data/anet_srl_files/trn_srl_obj_to_index_dict.json"
32 | val_ds4_dicts: "data/anet_srl_files/val_srl_obj_to_index_dict.json"
33 | # ASRL with indices for SPAT/TEMP
34 | trn_ds4_inds: "data/anet_srl_files/trn_asrl_annots.csv"
35 | val_ds4_inds: "data/anet_srl_files/val_asrl_annots.csv"
36 | # Arg Vocab:
37 | arg_vocab_file: "data/anet_srl_files/arg_vocab.pkl"
38 | # None
39 | none_word: ""
40 | # GT5
41 | ngt_prop: 5
42 | num_frms: 10
43 | feature_gt5_root: "data/anet/fc6_feat_5rois"
44 | proposal_gt5_h5_resized: "data/anet/anet_detection_vg_fc6_feat_gt5_rois.h5"
45 | misc:
46 | cache_dir: "cache_dir"
47 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Preparing Data
2 |
3 | This part describes how to download the data and get started with the experiments.
4 |
5 | If instead you are interested in generating ActivityNet-SRL from scratch (not required in general), see [dcode](../dcode).
6 |
7 | ## Quickstart
8 |
9 | Optionally, pass a data folder as the second argument (it defaults to the current directory):
10 | ```
11 | cd $ROOT/data
12 | bash download_data.sh all [data_folder]
13 | ```
14 |
15 | After everything is downloaded successfully, the folder structure should look like:
16 |
17 | ```
18 | data
19 | |-- anet (530gb)
20 | |-- anet_detection_vg_fc6_feat_100rois.h5
21 | |-- anet_detection_vg_fc6_feat_100rois_resized.h5
22 | |-- anet_detection_vg_fc6_feat_gt5_rois.h5
23 | |-- fc6_feat_100rois
24 | |-- fc6_feat_5rois
25 | |-- rgb_motion_1d
26 | |-- anet_cap_ent_files (31M)
27 | |-- anet_captions_all_splits.json
28 | |-- anet_ent_cls_bbox_trainval.json
29 | |-- csv_dir
30 | |-- train.csv
31 | |-- train_postproc.csv
32 | |-- val.csv
33 | |-- val_postproc.csv
34 | |-- dic_anet.json
35 | |-- anet_srl_files (112M)
36 | |-- arg_vocab.pkl
37 | |-- trn_asrl_annots.csv
38 | |-- trn_srl_obj_to_index_dict.json
39 | |-- val_asrl_annots.csv
40 | |-- val_srl_obj_to_index_dict.json
41 | ```
42 |
43 | In total, this is ~530 GB of data!
44 |
45 | NOTE: It is highly advisable to keep the features on an SSD; otherwise there is a massive drop in speed!
46 |
47 |
48 | ## Details about the Data
49 | Here, each data item is described in one line.
50 | For an in-depth overview of the construction, please refer to the [DATA PREP README](../dcode/README.md).
51 |
52 | 1. `fc6_feat_Xrois`: We have 10 frames; for each frame we get X RoIs. `X=100` is obtained from a Faster R-CNN trained on Visual Genome. `X=5` is derived from `X=100` such that the ground-truth annotations are included and the remaining boxes are the top-scoring ones. The latter setting allows for easier experimentation.
53 | 1. `rgb_motion_1d`: RGB and FLOW features for frames (1fps) of the video.
54 | 1. `{trn/val}_asrl_annots.csv`: The main annotation files required for grounding.
55 | 1. `{trn/val}_srl_obj_to_index_dict.json`: Dictionary mapping helpful for sampling contrastive examples.
56 |
57 | ## Annotation File Structure:
58 | The main annotation files for ASRL are `{trn/val}_asrl_annots.csv`
59 |
60 | We use video segments of ActivityNet since we focus on trimmed videos only.
61 |
62 | ActivityNet Entities provides the bounding boxes for the noun phrases in ActivityNet Captions. For more details, please refer to [dcode](../dcode).
63 |
64 | `trn_asrl_annots.csv` has 26 columns!
65 |
66 | Let's consider the first example; you can load it as follows (a short sketch for decoding the list-valued columns follows the column descriptions below):
67 | ```
68 | import pandas as pd
69 | trn_csv = pd.read_csv('./trn_asrl_annots.csv')
70 | first_data_point = trn_csv.iloc[0]
71 | column_list = ['srl_ind', 'vid_seg']
72 | ```
73 |
74 | 1. `srl_ind`: the index in this csv file. Here it is `0`
75 | 1. `vt_split`: the split the data point belongs to. All data points in `trn_asrl_annots.csv` have this set to `train`. For `val_asrl_annots.csv`, the data points are split 50-50 between `val` and `test`.
76 | 1. `vid_seg`: the video and the segment of the video the data point belongs to. The convention used is `{vid_name}_segment_{seg_id:02d}`. Here it is `v_--0edUL8zmA_segment_00`, which means it is the 0th segment of the video `v_--0edUL8zmA`.
77 | 1. `ann_ind`: the index into the `anet_cap_ent_files/csv_dir/{trn/val}_postproc.csv` file. This index is used to retrieve the proposal boxes from `anet_detection_vg_fc6_feat_100rois_resized.h5`. Here it is `28557`, which means the 28557th row of the h5 file corresponds to this `vid_seg`.
78 | 1. `sent`: the main sentence provided in ActivityNet Captions for the given `vid_seg`. The sentence may contain multiple verbs, so data points sharing the same `vid_seg` have the same sentence. Here, the sentence is "Four men are playing dodge ball in an indoor court ."
79 | 1. `words`: simply the tokenization of `sent`. Here it is: \['Four', 'men', 'are', 'playing', 'dodge', 'ball', 'in', 'an', 'indoor', 'court', '.'\]
80 | 1. `verb`: we pass the sentence through a semantic role labeler (see [demo](https://demo.allennlp.org/semantic-role-labeling)), which extracts the verbs from the sentence and assigns semantic roles pivoted on each verb. Each verb is treated as a separate data point. Here, the verb is `playing`.
81 | 1. `tags`: The BIO tagging output from the SRL for the given verb. Here it is \['B-ARG0', 'I-ARG0', 'O', 'B-V', 'B-ARG1', 'I-ARG1', 'B-ARGM-LOC', 'I-ARGM-LOC', 'I-ARGM-LOC', 'I-ARGM-LOC', 'O'\] which basically the structure "playing: \[ARG0: Four men] are \[V: playing] \[ARG1: dodge ball] \[ARGM-LOC: in an indoor court] ."
82 | 1. `req_pat_ix`: Same information as `tags` but represented as List\[Tuple\[ArgX, List\[word indices]]. The word indices correspond to the output of `word`. Here it is `[['ARG0', [0, 1]], ['V', [3]], ['ARG1', [4, 5]], ['ARGM-LOC', [6, 7, 8, 9]]]` which suggests `word[0], word[1]` constitute ARG0 (basically \[ARG0: Four men])
83 | 1. `req_pat`: Same information as above, just the list of word indices are replaced with space separated words. Here it is: \[('ARG0', 'Four men'), ('V', 'playing'), ('ARG1', 'dodge ball'), ('ARGM-LOC', 'in an indoor court')]
84 | 1. `req_aname`: Same as `req_pat` just that it only extracts the words without the argument roles. Here it is: \['Four men', 'playing', 'dodge ball', 'in an indoor court']
85 | 1. `req_args`: Instead of the words, only stores the semantic roles. Here it is \['ARG0', 'V', 'ARG1', 'ARGM-LOC']
86 | 1. `gt_bboxes`: The ground-truth boxes (4d) provided in AE for the given vid-seg. It is List\[List\[x1,y1,x2,y2]]
87 | 1. `gt_frms`: The frames (ranging from 0-9) in which the ground-truth boxes are annotated. It is a List of length `len(gt_bboxes)`.
88 | 1. `process_idx2`: It provides the word index for the given bounding box. It is List\[List\[int]]. Here it is `[[1], [1], [1], [1], [9]]`. Note that `word[1] = men` which means the first four bounding boxes refer to the four men and the final bounding box refers to the `court`.
89 | 1. `process_clss`: Lemmatized Noun for the words in `process_idx2`. Here it is `[['man'], ['man'], ['man'], ['man'], ['court']]`
90 | 1. `req_cls_pats`: Same as `req_pat` with the words replaced with their lemmatized noun. `[('ARG0', ['man']), ('V', ['playing']), ('ARG1', ['dodge', 'ball']), ('ARGM-LOC', ['court'])]`
91 | 1. `req_cls_pats_mask`: It is List\[Tuple\[ArgX, Mask, GTBox Index list]]. ArgX is the argument name (e.g. ARG0), Mask = 1 means this role has a bounding box, 0 implies the role doesn't have a bounding box and hence is not evaluated. GTBox Index List is the list of indices of the bounding boxes referring to this role. Here it is `[('ARG0', 1, [0, 1, 2, 3]), ('V', 0, [0]), ('ARG1', 0, [0]), ('ARGM-LOC', 1, [4])]` which implies ARG0 and ARGM-LOC are groundable, while V and ARG1 are not. Moreover, the first four bounding boxes refer to ARG0 and the last bounding box refers to ARGM-LOC.
92 | 1. `lemma_ARGX`: The lemmatized verb/argument role used for contrastive sampling.
93 | 1. `DS4_Inds`: For each role, it contains indices of other data points for which everything other than the lemmatized word for that argument role is the same.
94 | 1. `ds4_msk`: Whether such contrastive samples were successfully found.
95 | 1. `RandDS4_Inds`: Simply random indices.
96 |
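As a fuller parsing sketch: the list-valued columns are stored as Python-literal strings, so they need `ast.literal_eval` before use (this mirrors how `dcode/dataset_stats.py` reads the same file; the column subset below is just an example).
```
import ast
import pandas as pd

trn_csv = pd.read_csv('./trn_asrl_annots.csv')

# Columns such as `req_pat_ix`, `gt_bboxes` or `req_cls_pats_mask` are saved
# as strings of Python literals; convert them back to lists/tuples first.
list_cols = ['words', 'req_pat_ix', 'gt_bboxes', 'gt_frms',
             'process_idx2', 'req_cls_pats_mask']
for col in list_cols:
    trn_csv[col] = trn_csv[col].apply(ast.literal_eval)

first = trn_csv.iloc[0]
print(first['sent'])        # Four men are playing dodge ball in an indoor court .
print(first['req_pat_ix'])  # [['ARG0', [0, 1]], ['V', [3]], ...]
for arg, mask, box_inds in first['req_cls_pats_mask']:
    print(arg, 'groundable' if mask == 1 else 'not groundable', box_inds)
```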
--------------------------------------------------------------------------------
/data/download_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Downloading script
3 |
4 | CUR_DIR=$(pwd)
5 | DATA_ROOT=${2:-$CUR_DIR}
6 |
7 | mkdir -p $DATA_ROOT/anet
8 |
9 | function gdrive_download () {
10 | CONFIRM=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$1" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')
11 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$CONFIRM&id=$1" -O $2
12 | rm -rf /tmp/cookies.txt
13 | }
14 |
15 | function asrl_ann_dwn(){
16 | echo "Downloading ActivityNet SRL annotations"
17 | cd $DATA_ROOT
18 | gdrive_download "1WJTRQVs-vSLmJ7I3sef3_IxE0vBtaqFX" anet_srl.zip
19 | unzip anet_srl.zip && rm anet_srl.zip
20 | cd $CUR_DIR
21 | # The above is a minimal download and should be fine
22 | # for most cases.
23 | # To get all the files:
24 | # gdrive_download 1qSsD3AbWqw-KNObNg6N8xbTnF-Bg_eZn anet_verb.zip
25 | # unzip anet_verb.zip && rm anet_verb.zip
26 | # gdrive_download 1aZyLNP-VXS3stZpenWMuCTRF_NL2gznu anet_srl_scratch.zip
27 | # unzip anet_srl_scratch.zip && rm anet_srl_scratch.zip
28 | echo "Saved Folder"
29 | }
30 |
31 | function anet_feats_dwn(){
32 | echo "Downloading ActivityNet Feats. May take some time"
33 | # Courtesy of Luowei Zhou, obtained from the repository:
34 | # https://github.com/facebookresearch/grounded-video-description/blob/master/tools/download_all.sh
35 | cd $DATA_ROOT/anet
36 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/rgb_motion_1d.tar.gz
37 | tar -xvzf rgb_motion_1d.tar.gz && rm rgb_motion_1d.tar.gz
38 |
39 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_detection_vg_fc6_feat_100rois.h5
40 |
41 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/fc6_feat_100rois.tar.gz
42 | tar -xvzf fc6_feat_100rois.tar.gz && rm fc6_feat_100rois.tar.gz
43 |
44 | gdrive_download 13tvBIEAgv4VS5dqkZBK1gvTI_Z22gRLM fc6_feat_5rois.zip
45 | unzip fc6_feat_5rois.zip && rm fc6_feat_5rois.zip
46 |
47 | gdrive_download 1a6UOK90Epz7n-dncKAeFDQP4TBgqdTS9 anet_detn_proposals_resized.zip
48 | unzip anet_detn_proposals_resized.zip && rm anet_detn_proposals_resized.zip
49 | cd $CUR_DIR
50 | }
51 |
52 | function dwn_all(){
53 | asrl_ann_dwn
54 | anet_feats_dwn
55 | }
56 |
57 |
58 | if [ "$1" = "asrl_anns" ]
59 | then
60 | asrl_ann_dwn
61 |
62 | elif [ "$1" = "anet_feats" ]
63 | then
64 | anet_feats_dwn
65 | elif [ "$1" = "all" ]
66 | then
67 | dwn_all
68 | else
69 | echo "Failed: Use download_data.sh asrl_anns | anet_feats | all"
70 | exit 1
71 | fi
72 |
--------------------------------------------------------------------------------
/dcode/README.md:
--------------------------------------------------------------------------------
1 | # Creating ActivityNet SRL (ASRL) from ActivityNet Captions (AC) and ActivityNet Entities (AE)
2 |
3 | The code is for generating the dataset from the parent datasets.
4 | If you just want to use ASRL as a training bed, you can skip this. See [data](../data).
5 |
6 | ## Quick summary
7 |
8 | Very briefly, the process is as follows:
9 | 1. Add semantic roles to captions in AC.
10 | 1. Preprocess AE. In particular, resize all the proposals and ground-truth bounding boxes (this is
11 | required for SPAT/TEMP).
12 | 1. Preprocess the features and choose only 5 proposals per frame for the GT5 setting.
13 | 1. Obtain the bounding boxes and category names from AE for the relevant phrases.
14 | 1. Filter out some verbs like "is", "are", "complete", "begin"
15 | 1. Filter some SRL Arguments based on Frequency.
16 | 1. Get Training/Validation/Test videos.
17 | 1. Do Contrastive Sampling and store the dictionary files for easier sampling during training.
18 |
19 | ## Preprocessing Steps
20 |
21 | 1. First, download the relevant files.
22 | Optionally, specify the data folder where they should be downloaded:
23 | ```
24 | bash download_asrl_parent_ann.sh [save_point]
25 | ```
26 | The folder should look like:
27 | ```
28 | anet_cap_ent_files
29 | |-- anet_captions_all_splits.json (AC captions)
30 | |-- anet_entities_test_1.json
31 | |-- anet_entities_test_2.json
32 | |-- anet_entities_val_1.json
33 | |-- anet_entities_val_2.json
34 | |-- cap_anet_trainval.json (AE Train annotations)
35 | |-- dic_anet.json (Train/Valid/Test video splits for AE)
36 | ```
37 |
38 | 1. Use the SRL labeling system from AllenAI (should take ~15 mins) to add semantic roles to the captions from AC.
39 | ```
40 | cd $ROOT
41 | python dcode/sem_role_labeller.py
42 | ```
43 |
44 | This will create `$ROOT/cache_dir` and store the output SRL files which should look like:
45 | ```
46 | cache_dir/
47 | |-- SRL_Anet
48 | |-- SRL_Anet_bert_cap_annots.csv # AC annotations in csv format to input into BERT
49 | |-- srl_bert_preds.pkl # BERT outputs
50 | ```
51 |
52 | 1. Resize the boxes in AE.
53 | ```
54 | cd $ROOT
55 | python dcode/preproc_anet_files.py --task='resize_boxes_ae'
56 | ```
57 | This takes the file `cap_anet_trainval.json` as input (this is the main AE annotation file) and outputs `anet_ent_cls_bbox_trainval.json`. The latter file contains resized ground-truth boxes.
58 | It also resizes the proposal boxes, taking `anet_detection_vg_fc6_feat_100rois.h5` as input and producing `anet_detection_vg_fc6_feat_100rois_resized.h5` as output. The latter contains the resized proposals (a small rescaling sketch is shown below).
59 |
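The resizing itself is just a per-video rescaling of box coordinates to the target size read from the config. Below is a minimal sketch of the transform applied to one set of boxes (the 224x224 target is only a placeholder, not the actual configured value):
```
import numpy as np

def resize_boxes(boxes, orig_hw, target_hw=(224, 224)):
    """Rescale [x1, y1, x2, y2] boxes from the original frame size to target_hw."""
    oh, ow = orig_hw
    nh, nw = target_hw
    sw, sh = (nw / ow, nh / oh) if (ow != 0 and oh != 0) else (1.0, 1.0)
    boxes = np.asarray(boxes, dtype=float)
    boxes[:, [0, 2]] *= sw  # x coordinates
    boxes[:, [1, 3]] *= sh  # y coordinates
    return boxes
```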
60 | 1. GT5 setting
61 | ```
62 | cd $ROOT
63 | python dcode/preproc_anet_files.py --task='choose_gt_5'
64 | ```
65 | Initially, there are `100` proposals per frame.
66 | For faster iteration, we choose only 5 proposals from each frame.
67 | If there is a ground-truth box, we include that box, and the remaining slots are filled in order of proposal score (not a fair way, but the best that could be done).
68 | If there is no ground-truth box, we choose the top-5 scoring proposals (see the selection sketch after this step).
69 |
70 | To compute the recall scores (for sanity check):
71 | ```
72 | python dcode/preproc_anet_files.py --task='compute_recall'
73 | ```
74 | By default, it computes recall scores for GT5; you can change the proposal file for other settings.
75 |
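For reference, a minimal sketch of the per-frame selection rule described above (it mirrors the logic in `dcode/preproc_anet_files.py`; the caller is assumed to have already computed, per frame, which proposals best overlap a ground-truth box):
```
from collections import OrderedDict

def select_props_for_frame(frame_scores, gt_matched_inds, nppf=5):
    """
    frame_scores: detection scores of the 100 proposals in one frame.
    gt_matched_inds: indices of proposals that best overlap a GT box in this
    frame (empty if the frame has no GT box). Returns up to nppf indices.
    """
    # GT-matched proposals come first, then top-scoring ones fill the rest.
    keep = list(gt_matched_inds)
    keep += sorted(range(len(frame_scores)),
                   key=lambda i: frame_scores[i], reverse=True)
    # De-duplicate while preserving order, then cap at nppf proposals.
    return list(OrderedDict.fromkeys(keep))[:nppf]
```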
76 | 1. Aligning SRL outputs and NounPhrases from AE to create ASRL and adding the bounding boxes to the ASRL files (<1min)
77 | ```
78 | cd $ROOT
79 | python dcode/asrl_creator.py
80 | ```
81 | Now `$ROOT/data/anet_srl_files/` should look like:
82 | ```
83 | anet_srl_files/
84 | |-- verb_ent_file.csv # main file with SRLs, BBoxes
85 | |-- verb_lemma_dict.json # dictionary of verbs corresponding to their lemma
86 | ```
87 |
88 | 1. Use the Train/Val videos from AE to create Train/Val/Test videos for ASRL (~5-7 mins).
89 | Additionally, create the vocab file for the SRL arguments
90 | ```
91 | cd $ROOT
92 | python dcode/preproc_ds_files.py
93 | ```
94 | This will create `anet_cap_ent_files/csv_dir`. It should look like:
95 | ```
96 | csv_dir
97 | |-- train.csv
98 | |-- train_postproc.csv
99 | |-- val.csv
100 | |-- val_postproc.csv
101 | ```
102 |
103 | Further, now `$ROOT/data/anet_srl_files/` should look like:
104 | ```
105 | anet_srl_files/
106 | |-- trn_verb_ent_file.csv # train file
107 | |-- val_verb_ent_file.csv # val & test file
108 | |-- verb_ent_file.csv
109 | |-- verb_lemma_dict.json
110 | ```
111 |
112 | 1. Do Contrastive Sampling for the train and validation sets (~30 mins)
113 | ```
114 | cd $ROOT
115 | python code/contrastive_sampling.py
116 | ```
117 |
118 | Now your `anet_srl_files` directory should look like:
119 | ```
120 | anet_srl_files/
121 | |-- trn_asrl_annots.csv # used for training
122 | |-- trn_srl_obj_to_index_dict.json # used for CS
123 | |-- trn_verb_ent_file.csv # not used anymore
124 | |-- val_asrl_annots.csv # used for val/test
125 | |-- val_srl_obj_to_index_dict.json # used for CS
126 | |-- val_verb_ent_file.csv # not used anymore
127 | |-- verb_ent_file.csv # not used anymore
128 | |-- verb_lemma_dict.json # not used anymore
129 | ```
130 |
131 | 1. I have provided drive links to the processed files (generated after completing all the previous steps):
132 | 1. `anet_cap_ent_files` and `anet_srl_files`: https://drive.google.com/open?id=1mH8TyVPU4w7864Hxiukzg8dnqPIyBuE3
133 | 1. `SRL_Anet`: https://drive.google.com/open?id=1vGgqc8_-ZBk3ExNroRP-On7ArWN-d8du
134 | 1. resized proposal h5 files: https://drive.google.com/open?id=1a6UOK90Epz7n-dncKAeFDQP4TBgqdTS9
135 | 1. fc6_feats_5rois: https://drive.google.com/open?id=13tvBIEAgv4VS5dqkZBK1gvTI_Z22gRLM
136 |
137 | Alternatively, you can download these files via `download_asrl_parent_ann.sh` by passing `asrl_proc_files`:
138 | ```
139 | bash download_asrl_parent_ann.sh asrl_proc_files
140 | ```
141 |
--------------------------------------------------------------------------------
/dcode/dataset_stats.py:
--------------------------------------------------------------------------------
1 | """
2 | Gives the dataset statistics
3 | in form of tables.
4 | Copy-paste to Excel for visualization
5 | """
6 | from yacs.config import CfgNode as CN
7 | import yaml
8 | from asrl_creator import Anet_SRL_Create
9 | from pathlib import Path
10 | import pandas as pd
11 | import ast
12 | from typing import Dict, List, Tuple
13 | from collections import Counter
14 | import altair as alt
15 |
16 |
17 | class AnetSRL_Vis(object):
18 | def __init__(self, cfg, do_vis=True):
19 | self.cfg = cfg
20 | self.open_req_files()
21 | self.vis = do_vis
22 |
23 | def fix_via_ast(self, df):
24 | for k in df.columns:
25 | first_word = df.iloc[0][k]
26 | if isinstance(first_word, str) and (first_word[0] in '[{'):
27 | df[k] = df[k].apply(
28 | lambda x: ast.literal_eval(x))
29 | return df
30 |
31 | def open_req_files(self):
32 | trn_asrl_file = self.cfg.ds.trn_ds4_inds
33 | val_asrl_file = self.cfg.ds.val_ds4_inds
34 |
35 | self.trn_srl_annots = self.fix_via_ast(pd.read_csv(trn_asrl_file))
36 | self.val_srl_annots = self.fix_via_ast(pd.read_csv(val_asrl_file))
37 |
38 | def print_most_common_table(self, most_comm: List[Tuple]):
39 | """
40 | Prints most common output from a Counter in the
41 | form of a table for easy copy/pasting
42 | """
43 | patt = '{}, {}\n'
44 | out_str = ''
45 | for it in most_comm:
46 | out_str += patt.format(*it)
47 | print(out_str)
48 | return
49 |
50 | def visualize_df(self, df: pd.DataFrame,
51 | x_name: str, y_name: str):
52 | bars = alt.Chart(df).mark_bar(
53 | cornerRadiusBottomRight=3,
54 | cornerRadiusTopRight=3,
55 | ).encode(
56 | x=alt.X(x_name, axis=alt.Axis(title="")),
57 | y=alt.Y(y_name, axis=alt.Axis(title=""),
58 | sort='-x'),
59 | color=alt.value('#6495ED')
60 | )
61 | text = bars.mark_text(
62 | align='left',
63 | baseline='middle',
64 | dx=3 # Nudges text to right so it doesn't appear on top of the bar
65 | ).encode(
66 | text='Count:Q'
67 | )
68 |
69 | return (bars + text).properties(height=500)
70 |
71 | def get_num_vids(self):
72 | """
73 | Return a dict with the number of unique video segments in the train/valid/test splits
74 | """
75 | nvids = {}
76 | nvids['train'] = len(self.trn_srl_annots.vid_seg.unique())
77 | nvids['valid'] = len(
78 | self.val_srl_annots[
79 | self.val_srl_annots.vt_split == 'val'
80 | ].vid_seg.unique()
81 | )
82 | nvids['test'] = len(
83 | self.val_srl_annots[
84 | self.val_srl_annots.vt_split == 'test'
85 | ].vid_seg.unique()
86 | )
87 | return nvids
88 |
89 | def get_num_noun_phrase(self):
90 | """
91 | Return number of noun-phrases for
92 | each SRL
93 | """
94 | # req_cls_pats_mask: [['ArgX', 1/0, box_num]]
95 | # get only the argument name and count
96 | arg_counts = self.trn_srl_annots.req_cls_pats_mask.apply(
97 | lambda x: [y[0] for y in x]
98 | )
99 | return Counter([ac for acs in arg_counts for ac in acs])
100 |
101 | def get_num_phrase_with_box(self):
102 | # req_cls_pats_mask: [['ArgX', 1/0, box_num]]
103 | # get only the argument name and count
104 | arg_counts = self.trn_srl_annots.req_cls_pats_mask.apply(
105 | lambda x: [y[0] for y in x if y[1] == 1]
106 | )
107 | return Counter([ac for acs in arg_counts for ac in acs])
108 |
109 | def get_num_srl_structures(self):
110 | arg_struct_counts = self.trn_srl_annots.req_args.apply(
111 | lambda x: '-'.join(x)
112 | )
113 | return Counter(list(arg_struct_counts)).most_common(20)
114 |
115 | def get_num_lemma(self, arg_list):
116 | lemma_counts = {}
117 | col_set = set(self.trn_srl_annots.columns)
118 | for agl in arg_list:
119 | if agl != 'verb':
120 | lemma_key = f'lemma_{agl}'
121 | assert lemma_key in col_set
122 | lemma_counts[lemma_key] = Counter(
123 | list(
124 | self.trn_srl_annots[lemma_key].apply(
125 | lambda x: x[0] if len(x) > 0 else ''
126 | )
127 | )
128 | )
129 | else:
130 | lemma_key = 'lemma_verb'
131 | lemma_counts[lemma_key] = Counter(
132 | list(
133 | self.trn_srl_annots[lemma_key]
134 | )
135 | )
136 | return lemma_counts
137 |
138 | def get_num_q_per_vid(self):
139 | num_q_per_vid = (
140 | len(self.trn_srl_annots) /
141 | len(self.trn_srl_annots.vid_seg.unique())
142 | )
143 |
144 | num_srl_per_q = self.trn_srl_annots.req_args.apply(
145 | lambda x: len(x)).mean()
146 |
147 | num_w_per_q = self.trn_srl_annots.req_pat_ix.apply(
148 | lambda x: sum([len(y[1]) for y in x])).mean()
149 |
150 | return num_q_per_vid, num_srl_per_q, num_w_per_q
151 |
152 | def print_all_stats(self):
153 | vis_list = []
154 | nvid = self.get_num_vids()
155 | print("Number of videos in Train/Valid/Test: "
156 | f"{nvid['train']}, {nvid['valid']}, {nvid['test']}")
157 |
158 | num_q_per_vid, num_srl_per_q, num_w_per_q = self.get_num_q_per_vid()
159 | print(f"Number of Queries per Video is {num_q_per_vid}")
160 | print(f"Number of Queries per Video is {num_srl_per_q}")
161 | print(f"Number of Queries per Video is {num_w_per_q}")
162 |
163 | num_noun_phrases_for_srl = self.get_num_noun_phrase().most_common(n=20)
164 | num_np_srl = pd.DataFrame.from_records(
165 | data=num_noun_phrases_for_srl,
166 | columns=['Arg', 'Count']
167 | )
168 | if self.vis:
169 | vis_list.append(
170 | self.visualize_df(num_np_srl, x_name='Count:Q', y_name='Arg:O')
171 | )
172 | print('Noun Phrases Count')
173 | print(num_np_srl.to_csv(index=False))
174 |
175 | num_noun_phrases_with_box_for_srl = self.get_num_phrase_with_box()
176 |
177 | num_grnd_np_srl = pd.DataFrame.from_records(
178 | data=num_noun_phrases_with_box_for_srl.most_common(n=20),
179 | columns=['Arg', 'Count']
180 | )
181 | if self.vis:
182 | vis_list.append(
183 | self.visualize_df(
184 | num_grnd_np_srl, x_name='Count:Q', y_name='Arg:O')
185 | )
186 | print('Groundable Noun Phrase Count')
187 | print(num_grnd_np_srl.to_csv(index=False))
188 |
189 | num_srl_struct = self.get_num_srl_structures()
190 | num_srl_struct_df = pd.DataFrame.from_records(
191 | data=num_srl_struct,
192 | columns=['Arg', 'Count']
193 | )
194 | if self.vis:
195 | vis_list.append(
196 | self.visualize_df(num_srl_struct_df,
197 | x_name='Count:Q', y_name='Arg:O')
198 | )
199 | print('SRL Structures Frequency')
200 | print(num_srl_struct_df.to_csv(index=False))
201 |
202 | arg_list = ['verb', 'ARG0', 'ARG1', 'ARG2', 'ARGM_LOC']
203 | lemma_counts = self.get_num_lemma(arg_list)
204 | min_t = 20
205 | num_lemma_args = {
206 | k: len([z for z in v.most_common() if z[1] > min_t])
207 | for k, v in lemma_counts.items()
208 | }
209 | print(f'Lemmatized Counts for each lemma: {num_lemma_args}')
210 |
211 | df_dict = {
212 | k: pd.DataFrame.from_records(
213 | data=v.most_common(21),
214 | columns=['String', 'Count']
215 | )
216 | for k, v in lemma_counts.items()
217 | }
218 |
219 | for k in df_dict:
220 | print(f'Most Frequent Lemmas for {k}')
221 | print(df_dict[k].to_csv(index=False))
222 |
223 | return lemma_counts
224 | # return vis_list
225 |
226 |
227 | if __name__ == '__main__':
228 | cfg = CN(yaml.safe_load(open('./configs/anet_srl_cfg.yml')))
229 | asrl_vis = AnetSRL_Vis(cfg)
230 | asrl_vis.print_all_stats()
231 |
--------------------------------------------------------------------------------
/dcode/download_asrl_parent_ann.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Downloading script
3 |
4 | CUR_DIR=$(pwd)
5 | DDIR=${2:-"../data"}
6 | DATA_ROOT=$DDIR/anet_cap_ent_files
7 |
8 | echo $DATA_ROOT
9 | mkdir -p $DDIR/anet_srl_files
10 | mkdir -p $DATA_ROOT
11 |
12 | function gdrive_download () {
13 | CONFIRM=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$1" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')
14 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$CONFIRM&id=$1" -O $2
15 | rm -rf /tmp/cookies.txt
16 | }
17 |
18 | function anet_feats_dwn(){
19 | echo "Downloading ActivityNet Feats. May take some time"
20 | # Courtesy of Luowei Zhou, obtained from the repository:
21 | # https://github.com/facebookresearch/grounded-video-description/blob/master/tools/download_all.sh
22 | cd $DDIR/anet
23 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/rgb_motion_1d.tar.gz
24 | tar -xvzf rgb_motion_1d.tar.gz && rm rgb_motion_1d.tar.gz
25 |
26 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_detection_vg_fc6_feat_100rois.h5
27 |
28 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/fc6_feat_100rois.tar.gz
29 | tar -xvzf fc6_feat_100rois.tar.gz && rm fc6_feat_100rois.tar.gz
30 |
31 | # gdrive_download 13tvBIEAgv4VS5dqkZBK1gvTI_Z22gRLM fc6_feat_5rois.zip
32 | # unzip fc6_feat_5rois.zip && rm fc6_feat_5rois.zip
33 |
34 | # gdrive_download 1a6UOK90Epz7n-dncKAeFDQP4TBgqdTS9 anet_detn_proposals_resized.zip
35 | # unzip anet_detn_proposals_resized.zip && rm anet_detn_proposals_resized.zip
36 | cd $CUR_DIR
37 | }
38 |
39 | function ac_ae_dwn(){
40 | echo "Downloading ActivityNet Captions and ActivityNet Entities"
41 | cd $DATA_ROOT
42 | # Courtesy of Louwei Zhou, obtained from the repository:
43 | # https://github.com/facebookresearch/grounded-video-description
44 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_entities_prep.tar.gz
45 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_entities_captions.tar.gz
46 | wget https://raw.githubusercontent.com/facebookresearch/ActivityNet-Entities/master/data/anet_entities_cleaned_class_thresh50_trainval.json
47 | tar -xvzf anet_entities_prep.tar.gz && rm anet_entities_prep.tar.gz
48 | tar -xvzf anet_entities_captions.tar.gz && rm anet_entities_captions.tar.gz
49 | cd $CUR_DIR
50 | echo "Finished downloading ActivityNet Captions and ActivityNet Entities"
51 | }
52 |
53 | function processed_files_dwn(){
54 | echo "Downloading ASRL processed files"
55 | cd $DDIR
56 | mkdir asrl_processed_files
57 | cd asrl_processed_files
58 | gdrive_download "1mH8TyVPU4w7864Hxiukzg8dnqPIyBuE3" anet_srl_files_all.zip
59 | gdrive_download "1vGgqc8_-ZBk3ExNroRP-On7ArWN-d8du" SRL_Anet.zip
60 | gdrive_download "1a6UOK90Epz7n-dncKAeFDQP4TBgqdTS9" anet_detn_proposals_resized.zip
61 | # gdrive_download "13tvBIEAgv4VS5dqkZBK1gvTI_Z22gRLM" fc6_feat_5rois.zip
62 | cd $CUR_DIR
63 | }
64 |
65 | function dwn_all(){
66 | ac_ae_dwn
67 | anet_feats_dwn
68 | }
69 |
70 | if [ "$1" = "ac_ae_anns" ]
71 | then
72 | ac_ae_dwn
73 | elif [ "$1" = "anet_feats" ]
74 | then
75 | anet_feats_dwn
76 | elif [ "$1" = "asrl_proc_files" ]
77 | then
78 | processed_files_dwn
79 | elif [ "$1" = "all" ]
80 | then
81 | dwn_all
82 | else
83 | echo "Failed: Use download_asrl_parent_ann.sh ac_ae_anns | anet_feats | asrl_proc_files | all"
84 | exit 1
85 | fi
86 |
--------------------------------------------------------------------------------
/dcode/preproc_anet_files.py:
--------------------------------------------------------------------------------
1 | """
2 | Small preprocessing done for Anet files
3 | In particular:
4 | [ ] Add 'He' to 'man', 'boy', similarly for 'she' to 'woman', 'girl', 'lady'
5 | [ ] Resize ground-truth box
6 | """
7 |
8 | import json
9 | from pathlib import Path
10 | from yacs.config import CfgNode as CN
11 | import yaml
12 | from tqdm import tqdm
13 | import h5py
14 | import pandas as pd
15 | import numpy as np
16 | from utils.box_utils import box_iou
17 | import copy
18 | import torch
19 | from collections import OrderedDict
20 | import fire
21 |
22 |
23 | class AnetEntFiles:
24 | def __init__(self, cfg):
25 | self.cfg = cfg
26 | self.conv_dict = {
27 | 'man': 'he',
28 | 'boy': 'he',
29 | 'woman': 'she',
30 | 'girl': 'she',
31 | 'lady': 'she'
32 | }
33 | self.open_req_files()
34 |
35 | def open_req_files(self):
36 |
37 | self.trn_anet_ent_file = Path(self.cfg.ds.anet_ent_annot_file)
38 | assert self.trn_anet_ent_file.exists()
39 | self.trn_anet_ent_data = json.load(
40 | open(self.trn_anet_ent_file))
41 |
42 | self.trn_anet_ent_preproc_file = Path(
43 | self.cfg.ds.preproc_anet_ent_clss)
44 | assert self.trn_anet_ent_preproc_file.parent.exists()
45 |
46 | self.vid_dict_df = pd.DataFrame(json.load(
47 | open(self.cfg.ds.anet_ent_split_file))['videos'])
48 | self.vid_dict_df.index.name = 'Index'
49 |
50 | # Assert region features exists
51 | self.feature_root = Path(self.cfg.ds.feature_root)
52 | assert self.feature_root.exists()
53 |
54 | self.feature_root_gt5 = Path(self.cfg.ds.feature_gt5_root)
55 | self.feature_root_gt5.mkdir(exist_ok=True)
56 | assert self.feature_root_gt5.exists()
57 |
58 | def run(self):
59 | # out_ann = self.get_vidseg_hw_map(
60 | # ann=self.trn_anet_ent_orig_data['annotations'])
61 | out_ann = self.get_vidseg_hw_map(
62 | ann=self.trn_anet_ent_data)
63 |
64 | json.dump(out_ann, open(self.trn_anet_ent_preproc_file, 'w'))
65 | self.resize_props()
66 |
67 | def add_pronouns(self, ann):
68 | def upd(segv):
69 | """
70 | segv: Dict.
71 | Keys: 'process_clss' etc
72 | Update the values for process_clss
73 | """
74 | pck = 'process_clss'
75 | if pck not in segv:
76 | pck = 'clss'
77 | assert pck in segv
78 | proc_clss = segv[pck][:]
79 | assert isinstance(proc_clss, list)
80 | if len(proc_clss) == 0:
81 | return
82 | assert isinstance(proc_clss[0], list)
83 | new_proc_clss = []
84 | for pc in proc_clss:
85 | new_pc = []
86 | for p in pc:
87 | if p in self.conv_dict:
88 | new_pc.append(p)
89 | new_pc.append(self.conv_dict[p])
90 | else:
91 | new_pc.append(p)
92 | new_proc_clss.append(new_pc)
93 | segv[pck] = new_proc_clss
94 | return
95 | out_dict_vid = {}
96 | for vidk, vidv in tqdm(ann.items()):
97 | out_dict_seg_vid = {}
98 | for segk, segv in vidv['segments'].items():
99 | upd(segv)
100 | out_dict_seg_vid[segk] = segv
101 | out_dict_vid[vidk] = {'segments': out_dict_seg_vid}
102 |
103 | return out_dict_vid
104 |
105 | def get_vidseg_hw_map(self, ann=None):
106 | def upd(segv, sw, sh):
107 | """
108 | segv: Dict
109 | Change process_bnd_box wrt hw
110 | """
111 | pbk = 'process_bnd_box'
112 | if pbk not in segv:
113 | pbk = 'bbox'
114 | assert pbk in segv
115 | if len(segv[pbk]) == 0:
116 | return
117 | process_bnd_box = np.array(
118 | segv[pbk][:]).astype(float)
119 | process_bnd_box[:, [0, 2]] *= sw
120 | process_bnd_box[:, [1, 3]] *= sh
121 | process_bnd_box = process_bnd_box.astype(int)
122 | segv[pbk] = process_bnd_box.tolist()
123 | return
124 |
125 | vid_dict_df = self.vid_dict_df
126 |
127 | h5_proposal_file = h5py.File(
128 | self.cfg.ds.proposal_h5, 'r', driver='core')
129 |
130 | # num_proposals = h5_proposal_file['dets_num'][:]
131 | # label_proposals = h5_proposal_file['dets_labels'][:]
132 |
133 | hw_vids = h5_proposal_file['hw'][:].astype(float).tolist()
134 | out_dict = {}
135 | for row_ind, row in tqdm(vid_dict_df.iterrows()):
136 | vid_id = row['vid_id']
137 | if vid_id not in out_dict:
138 | out_dict[vid_id] = hw_vids[row_ind]
139 | else:
140 | hw = hw_vids[row_ind]
141 | if not hw == [0., 0.]:
142 | assert hw == out_dict[vid_id]
143 | json.dump(out_dict, open(self.cfg.ds.vid_hw_map, 'w'))
144 |
145 | nw = self.cfg.ds.resized_width
146 | nh = self.cfg.ds.resized_height
147 | out_dict_vid = {}
148 | for vidk, vidv in tqdm(ann.items()):
149 | out_dict_seg_vid = {}
150 | oh, ow = out_dict[vidk]
151 | if ow != 0. or oh != 0.:
152 | sw = nw / ow
153 | sh = nh / oh
154 | else:
155 | sw, sh = 1., 1.
156 | for segk, segv in vidv['segments'].items():
157 | upd(segv, sw*1., sh*1.)
158 | out_dict_seg_vid[segk] = segv
159 | out_dict_vid[vidk] = {'segments': out_dict_seg_vid}
160 |
161 | return out_dict_vid
162 |
163 | def resize_props(self):
164 | h5_proposal_file = h5py.File(
165 | self.cfg.ds.proposal_h5, 'r', driver='core')
166 |
167 | hw_vids = h5_proposal_file['hw'][:].astype(float).tolist()
168 | label_proposals = h5_proposal_file['dets_labels'][:]
169 |
170 | nw = self.cfg.ds.resized_width
171 | nh = self.cfg.ds.resized_height
172 |
173 | for row_ind in tqdm(range(len(label_proposals))):
174 | oh, ow = hw_vids[row_ind]
175 | if ow != 0. or oh != 0.:
176 | sw = nw / ow
177 | sh = nh / oh
178 | else:
179 | sw, sh = 1., 1.
180 |
181 | label_proposals[row_ind, :, [0, 2]] *= sw
182 | label_proposals[row_ind, :, [1, 3]] *= sh
183 | with h5py.File(self.cfg.ds.proposal_h5_resized, 'w') as f:
184 | keys = [k for k in h5_proposal_file.keys()]
185 | for k in keys:
186 | if k != 'dets_labels':
187 | f.create_dataset(k, data=h5_proposal_file[k])
188 | else:
189 | f.create_dataset(k, data=label_proposals)
190 |
191 | return
192 |
193 | def choose_gt5(self, save=True):
194 | """
195 | Choose 5 proposals for each frame
196 | """
197 | h5_proposal_file = h5py.File(
198 | self.cfg.ds.proposal_h5_resized, 'r', driver='core')
199 | # h5_proposal_file = h5py.File(
200 | # self.cfg.ds.proposal_h5, 'r', driver='core')
201 |
202 | nppf_orig = 100
203 | nppf = self.cfg.ds.ngt_prop
204 | nfrms = self.cfg.ds.num_frms
205 | # Note these are resized labels
206 | label_proposals = h5_proposal_file['dets_labels'][:]
207 | num_proposals = h5_proposal_file['dets_num'][:]
208 | out_label_proposals = np.zeros_like(
209 | label_proposals)[:, :nfrms*nppf, ...]
210 | out_num_proposals = np.zeros_like(num_proposals)
211 | vid_dict_df = self.vid_dict_df
212 |
213 | anet_ent_preproc_data = json.load(open(self.trn_anet_ent_preproc_file))
214 | # anet_ent_preproc_data = json.load(
215 | # open(self.cfg.ds.anet_ent_annot_file))
216 |
217 | recall_num = 0
218 | recall_tot = 0
219 |
220 | for row_ind, row in tqdm(vid_dict_df.iterrows(),
221 | total=len(vid_dict_df)):
222 | # if row_ind > 1000:
223 | # break
224 | vid = row['vid_id']
225 | seg = row['seg_id']
226 | vid_seg_id = row['id']
227 |
228 | annot = anet_ent_preproc_data[vid]['segments'][seg]
229 | gt_boxs = annot['bbox']
230 | gt_frms = annot['frm_idx']
231 |
232 | prop_index = row_ind
233 |
234 | props = copy.deepcopy(label_proposals[prop_index])
235 | num_props = int(copy.deepcopy(num_proposals[prop_index]))
236 |
237 | if num_props < nfrms * nppf_orig:
238 | # import pdb
239 | # pdb.set_trace()
240 | assert np.all(props[num_props:, [0, 1, 2, 3]] == 0)
241 |
242 | region_feature_file = self.feature_root / f'{vid_seg_id}.npy'
243 | if not region_feature_file.exists():
244 | continue
245 | prop_feats_load = np.load(region_feature_file)
246 | prop_feats = np.zeros((nfrms, *prop_feats_load.shape[1:]))
247 | prop_feats[:prop_feats_load.shape[0]] = prop_feats_load
248 |
249 | out_file = self.feature_root_gt5 / f'{vid_seg_id}.npy'
250 | out_dict = self.choose_gt5_for_one_vid_seg(
251 | props, prop_feats, gt_boxs, gt_frms, out_file,
252 | save=save, nppf=nppf, nppf_orig=nppf_orig, nfrms=nfrms
253 | )
254 |
255 | if save:
256 | num_prop = out_dict['num_prop']
257 | out_label_proposals[prop_index][:num_prop] = (
258 | out_dict['out_props']
259 | )
260 | out_num_proposals[prop_index] = num_prop
261 |
262 | recall_num += out_dict['recall']
263 | recall_tot += out_dict['num_gt']
264 |
265 | recall = recall_num.item() / recall_tot
266 | print(f'Recall is {recall}')
267 | if save:
268 | with h5py.File(self.cfg.ds.proposal_gt5_h5_resized, 'w') as f:
269 | keys = [k for k in h5_proposal_file.keys()]
270 | keys.remove('dets_labels')
271 | keys.remove('dets_num')
272 | for k in keys:
273 | f.create_dataset(k, data=h5_proposal_file[k])
274 |
275 | f.create_dataset('dets_labels', data=out_label_proposals)
276 | f.create_dataset('dets_num', data=out_num_proposals)
277 |
278 | return recall
279 |
280 | def choose_gt5_for_one_vid_seg(
281 | self, props, prop_feats,
282 | gt_boxs, gt_frms, out_file,
283 | save=True, nppf=5, nppf_orig=100, nfrms=10):
284 | """
285 | Choose for 5 props per frame
286 | """
287 | # Convert to torch tensors for box_iou computations
288 | # props: 10*100 x 7
289 | props = torch.tensor(props).float()
290 | prop_feats = torch.tensor(prop_feats).float()
291 | # set for comparing
292 | gt_frms_set = set(gt_frms)
293 | gt_boxs = torch.tensor(gt_boxs).float()
294 | gt_frms = torch.tensor(gt_frms).float()
295 |
296 | # Get the frames the proposal boxes belong to
297 | prop_frms = props[:, 4]
298 | # Create a frame mask.
299 | # Basically, the iou is set to 0 if the proposal and
300 | # the ground truth box lie in different frames
301 | frm_msk = prop_frms[:, None] == gt_frms
302 | if len(gt_boxs) > 0 and len(props) > 0:
303 | ious = box_iou(props[:, :4], gt_boxs) * frm_msk.float()
304 | # get the max iou proposal for each bounding box
305 | ious_max, ious_arg_max = ious.max(dim=0)
306 | # if len(ious_arg_max) > nppf:
307 | # ious_arg_max = ious_arg_max[:nppf]
308 | out_props = props[ious_arg_max]
309 | out_props_inds = ious_arg_max % 100
310 | recall = (ious_max > 0.5).sum()
311 | ngt = len(gt_boxs)
312 | else:
313 | ngt = 1
314 | recall = 0
315 | ious = torch.zeros(props.size(0), 1)
316 | out_props = props[0]
317 | out_props_inds = torch.tensor(0)
318 |
319 | # Dictionary to store final proposals to use
320 | fin_out_props = {}
321 | # Reshape proposals and proposal features to
322 | # nfrms x nppf x ndim
323 | props1 = props.view(nfrms, nppf_orig, 7)
324 | prop_dim = prop_feats.size(-1)
325 | prop_feats1 = prop_feats.view(nfrms, nppf_orig, prop_dim)
326 |
327 | # iterate over each frame
328 | for frm in range(nfrms):
329 | if frm not in fin_out_props:
330 | fin_out_props[frm] = []
331 |
332 | # if there are gt boxes in the frame
333 | # consider the proposals which have highest iou
334 | # in the frame
335 | if frm in gt_frms_set:
336 | props_inds_gt_in_frm = out_props_inds[out_props[..., 4] == frm]
337 | # add highest iou props to the dict key
338 | fin_out_props[frm] += props_inds_gt_in_frm.tolist()
339 |
340 | # sort by their scores, and choose nppf=5 such props
341 | props_to_use_inds = props1[frm, ..., 6].argsort(descending=True)[
342 | :nppf]
343 | # add 5 such props to the list
344 | fin_out_props[frm] += props_to_use_inds.tolist()
345 |
346 | # Restrict the total to 5
347 | fin_out_props[frm] = list(
348 | OrderedDict.fromkeys(fin_out_props[frm]))[:nppf]
349 |
350 | # Saving them, init with zeros
351 | props_output = torch.zeros(nfrms, nppf, 7)
352 | prop_feats_output = torch.zeros(nfrms, nppf, prop_dim)
353 |
354 | # set for each frame
355 | for frm in fin_out_props:
356 | inds = fin_out_props[frm]
357 | props_output[frm] = props1[frm][inds]
358 | prop_feats_output[frm] = prop_feats1[frm][inds]
359 |
360 | # Reshape nfrm x nppf x ndim -> nfrm*nppf x ndim
361 | props_output = props_output.view(nfrms*nppf, 7).detach().cpu().numpy()
362 | prop_feats_output = prop_feats_output.view(
363 | nfrms, nppf, prop_dim).detach().cpu().numpy()
364 |
365 | if save:
366 | np.save(out_file, prop_feats_output)
367 |
368 | return {
369 | 'out_props': props_output,
370 | 'recall': recall,
371 | 'num_prop': nppf*nfrms,
372 | 'num_gt': ngt
373 | }
374 |
375 | def compute_recall(self, exp_setting='gt5'):
376 | """
377 | Compute recall for the created h5 file
378 | """
379 | if exp_setting == 'gt5':
380 | pfile = self.cfg.ds.proposal_gt5_h5_resized
381 | elif exp_setting == 'p100':
382 | pfile = self.cfg.ds.proposal_h5_resized
383 |
384 | with h5py.File(pfile, 'r') as f:
385 | label_proposals = f['dets_labels'][:]
386 |
387 | vid_dict_df = self.vid_dict_df
388 |
389 | anet_ent_preproc_data = json.load(open(self.trn_anet_ent_preproc_file))
390 |
391 | recall_num = 0
392 | recall_tot = 0
393 |
394 | for row_ind, row in tqdm(vid_dict_df.iterrows(),
395 | total=len(vid_dict_df)):
396 |
397 | vid = row['vid_id']
398 | seg = row['seg_id']
399 | vid_seg_id = row['id']
400 |
401 | annot = anet_ent_preproc_data[vid]['segments'][seg]
402 | gt_boxs = torch.tensor(annot['bbox']).float()
403 | gt_frms = annot['frm_idx']
404 |
405 | prop_index = row_ind
406 |
407 | region_feature_file = self.feature_root / f'{vid_seg_id}.npy'
408 | if not region_feature_file.exists():
409 | continue
410 |
411 | props = copy.deepcopy(label_proposals[prop_index])
412 | props = torch.tensor(props).float()
413 | # props = props.view(10, -1, 7)
414 |
415 | for fidx, frm in enumerate(gt_frms):
416 | prop_frms = props[props[..., 4] == frm]
417 | gt_box_in_frm = gt_boxs[fidx]
418 |
419 | ious = box_iou(prop_frms[:, :4], gt_box_in_frm)
420 |
421 | ious_max, ious_arg_max = ious.max(dim=0)
422 | # conversion to long is important, otherwise
423 | # after 256 becomes 0
424 | recall_num += (ious_max > 0.5).any().long()
425 |
426 | recall_tot += len(gt_boxs)
427 |
428 | recall = recall_num.item() / recall_tot
429 | print(f'Recall is {recall}')
430 | return
431 |
432 |
433 | def main(task: str, exp_setting='gt5'):
434 | cfg = CN(yaml.safe_load(open('./configs/create_asrl_cfg.yml')))
435 | anet_pre = AnetEntFiles(cfg)
436 | if 'resize_boxes_ae' in task:
437 | anet_pre.run()
438 | if 'choose_gt5' in task:
439 | anet_pre.choose_gt5(save=True)
440 | if 'compute_recall' in task:
441 | anet_pre.compute_recall(exp_setting)
442 |
443 |
444 | if __name__ == '__main__':
445 | fire.Fire(main)
446 | # cfg = CN(yaml.safe_load(open('./configs/create_asrl_cfg.yml')))
447 | # anet_pre = AnetEntFiles(cfg)
448 | # anet_pre.compute_recall()
449 | # anet_pre.choose_gt5(save=True)
450 | # anet_pre.add_pronouns()
451 | # anet_pre.get_vidseg_hw_map()
452 | # anet_pre.run()
453 | # anet_pre.resize_props()
454 |
--------------------------------------------------------------------------------
/dcode/preproc_ds_files.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess dataset files
3 | """
4 |
5 | import json
6 | import pandas as pd
7 | from pathlib import Path
8 | from tqdm import tqdm
9 | import yaml
10 | from yacs.config import CfgNode as CN
11 | import numpy as np
12 | import ast
13 | from collections import Counter
14 | from torchtext import vocab
15 | import pickle
16 | from munch import Munch
17 |
18 |
19 | np.random.seed(5)
20 |
21 |
22 | class AnetCSV:
23 | def __init__(self, cfg, comm=None):
24 | self.cfg = cfg
25 | if comm is not None:
26 | assert isinstance(comm, (dict, Munch))
27 | self.comm = Munch(comm)
28 | else:
29 | self.comm = Munch()
30 |
31 | inp_anet_dict_fpath = cfg.ds.anet_ent_split_file
32 | self.inp_dict_file = Path(inp_anet_dict_fpath)
33 |
34 | # Create directory to keep the generated csvs
35 | self.out_csv_dir = self.inp_dict_file.parent / 'csv_dir'
36 | self.out_csv_dir.mkdir(exist_ok=True)
37 |
38 | # Structure of anet_dict:
39 | # anet = Dict,
40 | # keys: 1. word to lemma, 2. index to word,
41 | # 3. word to detection 4. video information
42 | # We only need the video information
43 | self.vid_dict_list = json.load(open(inp_anet_dict_fpath))['videos']
44 |
45 | def create_csvs(self):
46 | """
47 | Create the Train/Val split videos
48 | """
49 | self.vid_info_df = pd.DataFrame(self.vid_dict_list)
50 | self.vid_info_df.index.name = 'Index'
51 |
52 | train_df = self.vid_info_df[self.vid_info_df.split == 'training']
53 | train_df.to_csv(self.out_csv_dir / 'train.csv',
54 | index=True, header=True)
55 |
56 | # NOTE: Test files don't have the annotations, so cannot be used.
57 | # Instead we split the validation dataframe into two parts (50/50)
58 |
59 | val_test_df = self.vid_info_df[self.vid_info_df.split == 'validation']
60 |
61 | # Randomly take half as validation, rest as test
62 | # Both are saved in val.csv, however, during evaluation
63 | # only those with "val" in the field "vt_split" are chosen
64 | msk = np.random.rand(len(val_test_df)) < 0.5
65 | val_test_df['vt_split'] = ['val' if m == 1 else 'test' for m in msk]
66 | val_test_df.to_csv(self.out_csv_dir / 'val.csv',
67 | index=True, header=True)
68 |
69 | def post_proc(self, csv_file_type):
70 | """
71 | Some videos don't have features. These are removed
72 | for convenience.
73 | (only 4-5 videos were removed)
74 | """
75 | self.seg_feature_root = Path(self.cfg.ds.seg_feature_root)
76 | assert self.seg_feature_root.exists()
77 |
78 | self.feature_root = Path(self.cfg.ds.feature_root)
79 | assert self.feature_root.exists()
80 |
81 | csv_file = self.out_csv_dir / f'{csv_file_type}.csv'
82 | csv_df = pd.read_csv(csv_file)
83 | msk = []
84 | num_segs_list = []
85 | for row_ind, row in tqdm(csv_df.iterrows(), total=len(csv_df)):
86 | vid_seg_id = row['id']
87 | vid_id = row['vid_id']
88 | num_segs = csv_df[csv_df.vid_id == vid_id].seg_id.max() + 1
89 | num_segs_list.append(num_segs)
90 |
91 | vid_id_ix, seg_id_ix = vid_seg_id.split('_segment_')
92 | seg_rgb_file = self.seg_feature_root / \
93 | f'{vid_id_ix[2:]}_resnet.npy'
94 | seg_motion_file = self.seg_feature_root / f'{vid_id_ix[2:]}_bn.npy'
95 | region_feature_file = self.feature_root / f'{vid_seg_id}.npy'
96 | out = (seg_rgb_file.exists() and seg_motion_file.exists()
97 | and region_feature_file.exists())
98 | msk.append(out)
99 |
100 | csv_df['num_segs'] = num_segs_list
101 | csv_df = csv_df[msk]
102 | csv_df.to_csv(self.out_csv_dir /
103 | f'{csv_file_type}_postproc.csv', index=False, header=True)
104 |
105 | def post_proc_srl(self, train_file, val_file, test_file=None):
106 | """
107 | Add the Index to each csv file
108 | This is required to get the correct proposals from h5 file
109 | """
110 | def get_row_id(vid_seg, ann_df):
111 | vid_dict_row = ann_df[ann_df.id ==
112 | vid_seg]
113 | if len(vid_dict_row) == 1:
114 | vid_dict_row_id = vid_dict_row.index[0]
115 | return vid_dict_row_id
116 | else:
117 | return -1
118 |
119 | self.vid_info_df = pd.DataFrame(self.vid_dict_list)
120 | self.vid_info_df.index.name = 'Index'
121 |
122 | trn_ann_df = pd.read_csv(
123 | self.out_csv_dir / f'{train_file}_postproc.csv')
124 | val_ann_df = pd.read_csv(self.out_csv_dir / f'{val_file}_postproc.csv')
125 |
126 | srl_trn_val = pd.read_csv(self.cfg.ds.verb_ent_file)
127 |
128 | trn_ann_ind = []
129 | trn_msk = []
130 |
131 | val_ann_ind = []
132 | val_msk = []
133 | vt_msk = []
134 |
135 | for srl_ind, srl in tqdm(srl_trn_val.iterrows(),
136 | total=len(srl_trn_val)):
137 | req_args = ast.literal_eval(srl.req_args)
138 | if len(req_args) == 1:
139 | continue
140 | vid_seg = srl.vid_seg
141 | vid_dict_row = self.vid_info_df[self.vid_info_df.id == vid_seg]
142 | assert len(vid_dict_row) == 1
143 | vid_dict_row = vid_dict_row.iloc[0]
144 | split = vid_dict_row.split
145 |
146 | if split == 'training':
147 | ann_ind = get_row_id(vid_seg, trn_ann_df)
148 | if ann_ind == -1:
149 | print(split, vid_seg)
150 | continue
151 | trn_ann_ind.append(ann_ind)
152 | trn_msk.append(srl_ind)
153 | elif split == 'validation':
154 | ann_ind = get_row_id(vid_seg, val_ann_df)
155 | if ann_ind == -1:
156 | print(split, vid_seg)
157 | continue
158 | val_ann_ind.append(ann_ind)
159 | val_msk.append(srl_ind)
160 | vt_msk.append(val_ann_df.loc[ann_ind].vt_split)
161 | elif split == 'testing':
162 | pass
163 | else:
164 | raise NotImplementedError
165 |
166 | srl_trn = srl_trn_val.iloc[trn_msk]
167 | srl_trn['ann_ind'] = trn_ann_ind
168 | srl_trn['srl_ind'] = trn_msk
169 | srl_trn['vt_split'] = 'train'
170 |
171 | srl_val = srl_trn_val.iloc[val_msk]
172 | srl_val['ann_ind'] = val_ann_ind
173 | srl_val['srl_ind'] = val_msk
174 | srl_val['vt_split'] = vt_msk
175 |
176 | srl_trn.to_csv(self.cfg.ds.trn_verb_ent_file,
177 | index=False, header=True)
178 | srl_val.to_csv(self.cfg.ds.val_verb_ent_file,
179 | index=False, header=True)
180 |
181 | def process_arg_vocabs(self):
182 | def create_vocab(srl_annots, key):
183 | x_counter = Counter()
184 | for x_c in srl_annots[key]:
185 | x_counter += Counter(x_c)
186 | return vocab.Vocab(x_counter, specials_first=True)
187 | srl_annots = pd.read_csv(self.cfg.ds.trn_verb_ent_file)
188 | for k in srl_annots.columns:
189 | first_word = srl_annots.iloc[0][k]
190 | if isinstance(first_word, str) and first_word[0] == '[':
191 | srl_annots[k] = srl_annots[k].apply(
192 | lambda x: ast.literal_eval(x))
193 |
194 | # arg_counter = Counter()
195 | # for r_arg in srl_annots.req_args:
196 | # arg_counter += Counter(r_arg)
197 |
198 | # arg_vocab = vocab.Vocab(arg_counter, specials_first=True)
199 | arg_vocab = create_vocab(srl_annots, 'req_args')
200 | arg_tag_vocab = create_vocab(srl_annots, 'tags')
201 | out_vocab = {'arg_vocab': arg_vocab, 'arg_tag_vocab': arg_tag_vocab}
202 | pickle.dump(out_vocab, file=open(self.cfg.ds.arg_vocab_file, 'wb'))
203 | return
204 |
205 | def glove_vocabs(self):
206 | # Load dictionaries
207 | self.comm.dic_anet = json.load(open(self.inp_dict_file))
208 | # Get detections to index
209 | self.comm.dtoi = {w: i+1 for w,
210 | i in self.comm.dic_anet['wtod'].items()}
211 | self.comm.itod = {i: w for w, i in self.comm.dtoi.items()}
212 | self.comm.itow = self.comm.dic_anet['ix_to_word']
213 | self.comm.wtoi = {w: i for i, w in self.comm.itow.items()}
214 |
215 | self.comm.vocab_size = len(self.comm.itow) + 1
216 | self.comm.detect_size = len(self.comm.itod)
217 |
218 | # Load the glove vocab
219 | self.glove = vocab.GloVe(name='6B', dim=300)
220 |
221 | # get the glove vector for the vg detection cls
222 | # From Peter's repo
223 | obj_cls_file = self.cfg.ds.vg_class_file
224 | # index 0 is the background
225 | with open(obj_cls_file) as f:
226 | data = f.readlines()
227 | classes = ['__background__']
228 | classes.extend([i.strip() for i in data])
229 |
230 | # for VG classes
231 | # self.comm.vg_cls = classes
232 |
233 | # Extract glove vectors for the Visual Genome Classes
234 | # TODO: Cleaner implementation possible
235 | # TODO: Preproc only once
236 | glove_vg_cls = np.zeros((len(classes), 300))
237 | for i, w in enumerate(classes):
238 | split_word = w.replace(',', ' ').split(' ')
239 | vector = []
240 | for word in split_word:
241 | if word in self.glove.stoi:
242 | vector.append(
243 | self.glove.vectors[self.glove.stoi[word]].numpy())
244 | else: # use a random vector instead
245 | vector.append(2*np.random.rand(300) - 1)
246 |
247 | avg_vector = np.zeros((300))
248 | for v in vector:
249 | avg_vector += v
250 |
251 | glove_vg_cls[i] = avg_vector/len(vector)
252 |
253 | # category id to labels. +1 because 0 is the background label
254 | # Extract glove vectors for the 431 classes in AnetEntDataset
255 | # TODO: Cleaner Implementation
256 | # TODO: Preproc only once
257 | glove_clss = np.zeros((len(self.comm.itod)+1, 300))
258 | glove_clss[0] = 2*np.random.rand(300) - 1 # background
259 | for i, word in enumerate(self.comm.itod.values()):
260 | if word in self.glove.stoi:
261 | vector = self.glove.vectors[self.glove.stoi[word]]
262 | else: # use a random vector instead
263 | vector = 2*np.random.rand(300) - 1
264 | glove_clss[i+1] = vector
265 |
266 | # Extract glove vectors for the words from the vocab
267 | # TODO: cleaner implementation
268 | # TODO: preproc only once
269 | glove_w = np.zeros((len(self.comm.wtoi)+1, 300))
270 | for i, word in enumerate(self.comm.wtoi.keys()):
271 | vector = np.zeros((300))
272 | count = 0
273 | for w in word.split(' '):
274 | count += 1
275 | if w in self.glove.stoi:
276 | glove_vector = self.glove.vectors[self.glove.stoi[w]]
277 | vector += glove_vector.numpy()
278 | else: # use a random vector instead
279 | random_vector = 2*np.random.rand(300) - 1
280 | vector += random_vector
281 | glove_w[i+1] = vector / count
282 |
283 | out_dict = {
284 | 'classes': classes,
285 | 'glove_vg_cls': glove_vg_cls,
286 | 'glove_clss': glove_clss,
287 | 'glove_w': glove_w
288 | }
289 | pickle.dump(out_dict, open(self.cfg.ds.glove_stuff, 'wb'))
290 |
291 |
292 | if __name__ == '__main__':
293 | cfg = CN(yaml.safe_load(open('./configs/create_asrl_cfg.yml')))
294 | anet_csv = AnetCSV(cfg)
295 |
296 | # anet_csv.create_csvs()
297 |
298 | # anet_csv.post_proc('train')
299 | # anet_csv.post_proc('val')
300 |
301 | # anet_csv.post_proc_srl('train', 'val')
302 | anet_csv.process_arg_vocabs()
303 |
--------------------------------------------------------------------------------
/dcode/process_gt_props.py:
--------------------------------------------------------------------------------
1 | """
2 | By default, we are using proposal boxes.
3 | Instead, we only consider the gts.
4 | """
5 |
6 | import numpy as np
7 | from pathlib import Path
8 | import h5py
9 | import json
10 | import pandas as pd
11 | from tqdm import tqdm
12 | import copy
13 | from box_utils import box_iou
14 | import torch
15 | from collections import OrderedDict
16 |
17 |
18 | class GTPropExtractor(object):
19 | def __init__(self, cfg):
20 | self.cfg = cfg
21 |
22 | # Assert h5 file to read from exists
23 | self.proposal_h5 = Path(self.cfg.ds.proposal_h5_resized)
24 | assert self.proposal_h5.exists()
25 |
26 | with h5py.File(self.proposal_h5, 'r',
27 | driver='core') as h5_proposal_file:
28 | self.num_proposals = h5_proposal_file['dets_num'][:]
29 | self.label_proposals = h5_proposal_file['dets_labels'][:]
30 |
31 | nppf = self.cfg.ds.ngt_prop
32 | self.out_label_proposals = np.zeros_like(
33 | self.label_proposals)[:, :10*nppf, ...]
34 | self.out_num_proposals = np.zeros_like(self.num_proposals)
35 |
36 | # Assert region features exists
37 | self.feature_root = Path(self.cfg.ds.feature_root)
38 | assert self.feature_root.exists()
39 |
40 | # Assert act ent caption file with bbox exists
41 | self.anet_ent_annot_file = Path(self.cfg.ds.anet_ent_annot_file)
42 | assert self.anet_ent_annot_file.exists()
43 |
44 | if cfg.ds.ngt_prop == 5:
45 | self.out_dir = Path(self.cfg.ds.feature_gt5_root)
46 | self.out_proposal_h5 = Path(self.cfg.ds.proposal_gt5_h5)
47 | else:
48 | raise NotImplementedError
49 |
50 | self.out_dir.mkdir(exist_ok=True)
51 | # Load anet bbox
52 | with open(self.anet_ent_annot_file) as f:
53 | self.anet_ent_captions = json.load(f)
54 |
55 | # trn_df = pd.read_csv(self.cfg.ds.trn_ann_file)
56 | # val_df = pd.read_csv(self.cfg.ds.val_ann_file)
57 |
58 | # self.req_df = pd.concat([trn_df, val_df])
59 |
60 | def do_for_all_vid_seg(self, save=True):
61 | recall_num = 0
62 | recall_tot = 0
63 | self.cfg.no_gt_count = 0
64 | for row_num, vid_seg_row in tqdm(self.req_df.iterrows(),
65 | total=len(self.req_df)):
66 | vid_seg_id = vid_seg_row['id']
67 | vid_seg = vid_seg_id.split('_segment_')
68 | vid = vid_seg[0]
69 | seg = str(int(vid_seg[1]))
70 |
71 | annot = self.anet_ent_captions[vid]['segments'][seg]
72 | gt_boxs = annot['bbox']
73 | gt_frms = annot['frm_idx']
74 |
75 | prop_index = vid_seg_row['Index']
76 | props = copy.deepcopy(self.label_proposals[prop_index])
77 | num_props = int(copy.deepcopy(self.num_proposals[prop_index]))
78 |
79 | if num_props < 1000:
80 | # import pdb
81 | # pdb.set_trace()
82 | assert np.all(props[num_props:, [0, 1, 2, 3]] == 0)
83 |
84 | region_feature_file = self.feature_root / f'{vid_seg_id}.npy'
85 | # if save:
86 | prop_feats_load = np.load(region_feature_file)
87 | prop_feats = np.zeros((10, *prop_feats_load.shape[1:]))
88 | prop_feats[:prop_feats_load.shape[0]] = prop_feats_load
89 | # prop_feats = prop_feats.reshape(-1, prop_feats.shape[2]).copy()
90 | # prop_feats = prop_feats[:num_props, ...]
91 | # assert len(prop_feats) == len(props)
92 | # assert len(props) == num_props
93 |
94 | # else:
95 | # prop_feats = None
96 |
97 | out_file = self.out_dir / f'{vid_seg_id}.npy'
98 | # out_dict = self.do_for_one_vid_seg(
99 | # props, prop_feats, gt_boxs, gt_frms, out_file,
100 | # save=save
101 | # )
102 | nppf = self.cfg.ds.ngt_prop
103 | out_dict = self.prop10_one_vid_seg(
104 | props, prop_feats, gt_boxs, gt_frms, out_file,
105 | save=save, nppf=nppf
106 | )
107 | # out_dict = self.no_gt_prop10_one_vid_seg(
108 | # props, prop_feats, gt_boxs, gt_frms, out_file,
109 | # save=save
110 | # )
111 |
112 | if save:
113 | num_prop = out_dict['num_prop']
114 | self.out_label_proposals[prop_index][:num_prop] = (
115 | out_dict['out_props']
116 | )
117 | self.out_num_proposals[prop_index] = num_prop
118 |
119 | recall_num += out_dict['recall']
120 | recall_tot += out_dict['num_gt']
121 | # if row_num > 1000:
122 | # break
123 | recall = recall_num.item() / recall_tot
124 | if save:
125 | with h5py.File(self.out_proposal_h5, 'w') as f:
126 | f['dets_labels'] = self.out_label_proposals
127 | f['dets_num'] = self.out_num_proposals
128 | return recall
129 |
130 | def prop10_one_vid_seg(self, props, prop_feats,
131 | gt_boxs, gt_frms, out_file,
132 | save=True, nppf=10):
133 | nfrms = 10
134 | props = torch.tensor(props).float()
135 | prop_feats = torch.tensor(prop_feats).float()
136 | # gt_frms_dict = {}
137 | # for gfrm, gbox in zip(gt_frms, gt_boxs):
138 | # if gfrm not in gt_frms_dict:
139 | # gt_frms_dict[gfrm] = []
140 | # gt_frms_dict[gfrm].append(gbox)
141 | gt_frms_set = set(gt_frms)
142 | gt_boxs = torch.tensor(gt_boxs).float()
143 | gt_frms = torch.tensor(gt_frms).float()
144 |
145 | ngt = len(gt_boxs)
146 |
147 | nppf = nppf
148 |
149 | prop_frms = props[:, 4]
150 | frm_msk = prop_frms[:, None] == gt_frms
151 | if len(gt_boxs) > 0 and len(props) > 0:
152 | ious = box_iou(props[:, :4], gt_boxs) * frm_msk.float()
153 | ious_max, ious_arg_max = ious.max(dim=0)
154 | if len(ious_arg_max) > nppf:
155 | ious_arg_max = ious_arg_max[:nppf]
156 | out_props = props[ious_arg_max]
157 | out_props_inds = ious_arg_max % 100
158 | recall = (ious_max > 0.5).sum()
159 | else:
160 | self.cfg.no_gt_count += 1
161 | ngt = 1
162 | recall = 0
163 | ious = torch.zeros(props.size(0), 1)
164 | out_props = props[0]
165 | out_props_inds = torch.tensor(0)
166 |
167 | fin_out_props = {}
168 | props1 = props.view(10, 100, 7)
169 | prop_dim = prop_feats.size(-1)
170 | prop_feats1 = prop_feats.view(10, 100, prop_dim)
171 |
172 | for frm in range(nfrms):
173 | if frm not in fin_out_props:
174 | fin_out_props[frm] = []
175 |
176 | if frm in gt_frms_set:
177 | props_inds_gt_in_frm = out_props_inds[out_props[..., 4] == frm]
178 | fin_out_props[frm] += props_inds_gt_in_frm.tolist()
179 |
180 | props_to_use_inds = props1[frm, ..., 6].argsort(descending=True)[
181 | :nppf]
182 | # props_to_use_inds = np.random.choice(
183 | # np.arange(100), size=10, replace=False
184 | # )
185 | fin_out_props[frm] += props_to_use_inds.tolist()
186 |
187 | fin_out_props[frm] = list(
188 | OrderedDict.fromkeys(fin_out_props[frm]))[:nppf]
189 |
190 | props_output = torch.zeros(10, nppf, 7)
191 | prop_feats_output = torch.zeros(10, nppf, prop_dim)
192 |
193 | for frm in fin_out_props:
194 | inds = fin_out_props[frm]
195 | props_output[frm] = props1[frm][inds]
196 | prop_feats_output[frm] = prop_feats1[frm][inds]
197 |
198 | props_output = props_output.view(10*nppf, 7).detach().cpu().numpy()
199 | prop_feats_output = prop_feats_output.view(
200 | 10, nppf, prop_dim).detach().cpu().numpy()
201 |
202 | if save:
203 | np.save(out_file, prop_feats_output)
204 |
205 | return {
206 | 'out_props': props_output,
207 | 'recall': recall,
208 | 'num_prop': nppf * nfrms,
209 | 'num_gt': ngt
210 | }
211 |
212 | def no_gt_prop10_one_vid_seg(self, props, prop_feats,
213 | gt_boxs, gt_frms, out_file,
214 | save=True):
215 | nfrms = 10
216 | props = torch.tensor(props).float()
217 | prop_feats = torch.tensor(prop_feats).float()
218 | # gt_frms_dict = {}
219 | # for gfrm, gbox in zip(gt_frms, gt_boxs):
220 | # if gfrm not in gt_frms_dict:
221 | # gt_frms_dict[gfrm] = []
222 | # gt_frms_dict[gfrm].append(gbox)
223 | gt_frms_set = set(gt_frms)
224 | gt_boxs = torch.tensor(gt_boxs).float()
225 | gt_frms = torch.tensor(gt_frms).float()
226 |
227 | ngt = len(gt_boxs)
228 |
229 | nppf = 100
230 |
231 | fin_out_props = {}
232 | props1 = props.view(10, 100, 7)
233 | prop_dim = prop_feats.size(-1)
234 | prop_feats1 = prop_feats.view(10, 100, prop_dim)
235 |
236 | for frm in range(nfrms):
237 | if frm not in fin_out_props:
238 | fin_out_props[frm] = []
239 |
240 | # if frm in gt_frms_set:
241 | # props_inds_gt_in_frm = out_props_inds[out_props[..., 4] == frm]
242 | # fin_out_props[frm] += props_inds_gt_in_frm.tolist()
243 | props_to_use_inds = props1[frm, ..., 6].argsort(descending=True)[
244 | :nppf]
245 | fin_out_props[frm] += props_to_use_inds.tolist()
246 |
247 | fin_out_props[frm] = list(
248 | OrderedDict.fromkeys(fin_out_props[frm]))[:nppf]
249 |
250 | props_output = torch.zeros(10, nppf, 7)
251 | prop_feats_output = torch.zeros(10, nppf, prop_dim)
252 |
253 | for frm in fin_out_props:
254 | inds = fin_out_props[frm]
255 | props_output[frm] = props1[frm][inds]
256 | prop_feats_output[frm] = prop_feats1[frm][inds]
257 |
258 | props_output = props_output.view(nfrms * nppf, 7)
259 | prop_feats_output = prop_feats_output.view(
260 | nfrms, nppf, prop_dim).detach().cpu().numpy()
261 |
262 | if len(gt_boxs) > 0 and len(props_output) > 0:
263 | prop_frms = props_output[:, 4]
264 | frm_msk = prop_frms[:, None] == gt_frms
265 | ious = box_iou(props_output[:, :4], gt_boxs) * frm_msk.float()
266 | ious_max, ious_arg_max = ious.max(dim=0)
267 | recall = (ious_max > 0.5).sum()
268 | else:
269 | self.cfg.no_gt_count += 1
270 | ngt = 1
271 | recall = 0
272 | ious = torch.zeros(props.size(0), 1)
273 |
274 | props_output = props_output.detach().cpu().numpy()
275 |
276 | if save:
277 | np.save(out_file, prop_feats_output)
278 |
279 | return {
280 | 'out_props': props_output,
281 | 'recall': recall,
282 | 'num_prop': 100,
283 | 'num_gt': ngt
284 | }
285 |
286 | def do_for_one_vid_seg(self, props, prop_feats,
287 | gt_boxs, gt_frms, out_file,
288 | save=True):
289 | """
290 | props: all the proposal boxes
291 |         gt_boxs: all the ground-truth boxes
292 |         out_props: proposals with the highest IoU with each gt box
293 |         # nframes x 1,
294 |         one-to-one correspondence with the gt boxes
295 |         Also used to calculate recall.
296 | """
297 | props = torch.tensor(props).float()
298 | gt_boxs = torch.tensor(gt_boxs).float()
299 | gt_frms = torch.tensor(gt_frms).float()
300 |
301 | ngt = len(gt_boxs)
302 |
303 | prop_frms = props[:, 4]
304 | frm_msk = prop_frms[:, None] == gt_frms
305 |
306 | if len(gt_boxs) > 0 and len(props) > 0:
307 | ious = box_iou(props[:, :4], gt_boxs) * frm_msk.float()
308 | ious_max, ious_arg_max = ious.max(dim=0)
309 | recall = (ious_max > 0.5).sum().float()
310 | out_props = props[ious_arg_max]
311 | else:
312 | self.cfg.no_gt_count += 1
313 | ngt = 1
314 | recall = 0
315 | ious = torch.zeros(props.size(0), 1)
316 | out_props = props[0]
317 |
318 | nprop = ngt
319 | if save:
320 | prop_dim = prop_feats.size(-1)
321 | prop_feats = torch.tensor(prop_feats).float()
322 | out_prop_feats = prop_feats[ious_arg_max].view(
323 | 1, ngt, prop_dim).detach().cpu().numpy()
324 | assert list(out_prop_feats.shape[:2]) == [1, ngt]
325 | np.save(out_file, out_prop_feats)
326 |
327 | return {
328 | 'out_props': out_props,
329 | 'recall': recall,
330 | 'num_prop': nprop,
331 | 'num_gt': ngt
332 | }
333 |
334 |
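def _frame_masked_iou_demo():
    # Illustrative sketch only (not used by the pipeline above): reproduces the
    # frame-masked IoU matching done in do_for_one_vid_seg. As in the methods
    # above, a proposal stores its box in columns 0-3, its frame index in
    # column 4 and its score in column 6; the numbers here are made up.
    props = torch.tensor([
        [0., 0., 10., 10., 0., 0., 0.9],   # proposal on frame 0
        [0., 0., 10., 10., 1., 0., 0.8],   # proposal on frame 1
    ])
    gt_boxs = torch.tensor([[0., 0., 10., 10.]])
    gt_frms = torch.tensor([1.])
    frm_msk = props[:, 4][:, None] == gt_frms
    ious = box_iou(props[:, :4], gt_boxs) * frm_msk.float()
    ious_max, ious_arg_max = ious.max(dim=0)
    # only the proposal on the annotated frame counts towards recall
    recall = (ious_max > 0.5).sum().float()
    return recall, props[ious_arg_max]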
335 | if __name__ == '__main__':
336 | from extended_config import cfg as conf
337 | cfg = conf
338 | gtp = GTPropExtractor(cfg)
339 | recall = gtp.do_for_all_vid_seg(save=True)
340 | print(recall)
341 |
--------------------------------------------------------------------------------
/dcode/sem_role_labeller.py:
--------------------------------------------------------------------------------
1 | """
2 | Perform semantic role labeling for input captions
3 | """
4 | from allennlp.predictors.predictor import Predictor
5 | import pandas as pd
6 | import pickle
7 | import json
8 | from pathlib import Path
9 | from tqdm import tqdm
10 | import yaml
11 | from yacs.config import CfgNode as CN
12 | import time
13 | import fire
14 | import re
15 | from typing import List, Dict, Any, Union
16 |
17 | SRL_BERT = (
18 | "https://s3-us-west-2.amazonaws.com/allennlp/models/bert-base-srl-2019.06.17.tar.gz")
19 |
20 | srl_out_patt = re.compile(r'\[(.*?)\]')
21 |
22 | Fpath = Union[Path, str]
23 | Cft = CN
24 |
25 |
26 | class SRL_DS:
27 | """
28 | A base class to perform semantic role labeling
29 | """
30 |
31 | def __init__(self, cfg: Cft, tdir: str = '.'):
32 | self.cfg = cfg
33 | self.tdir = Path(tdir)
34 | archive_path = SRL_BERT
35 | self.srl = Predictor.from_path(
36 | archive_path=archive_path,
37 | predictor_name='semantic-role-labeling',
38 | cuda_device=0
39 | )
40 | # self.srl = Predictor.from_path(
41 | # "https://s3-us-west-2.amazonaws.com/allennlp/models/srl-model-2018.05.25.tar.gz")
42 | self.name = self.__class__.__name__
43 | self.cache_dir = self.tdir / \
44 | Path(f'{self.cfg.misc.cache_dir}/{self.name}')
45 | self.cache_dir.mkdir(exist_ok=True, parents=True)
46 | self.out_file = (self.cache_dir / f'{self.cfg.ds.srl_bert}')
47 | self.after_init()
48 |
49 | def after_init(self):
50 | pass
51 |
52 | def get_annotations(self) -> pd.DataFrame:
53 | """
54 |         Expected to read a file and
55 |         create a DataFrame with the columns:
56 | vid_id, seg_id, sentence
57 | """
58 | raise NotImplementedError
59 |
60 | def do_predictions(self):
61 | annot_df = self.get_annotations()
62 | sents = annot_df.to_dict('records')
63 | st_time = time.time()
64 | out_list = []
65 | tot_len = len(sents)
66 | try:
67 | for idx in tqdm(range(0, len(sents), 50)):
68 | out = self.srl.predict_batch_json(
69 | sents[idx:min(idx+50, tot_len)])
70 | out_list += out
71 | except RuntimeError:
72 | pass
73 | finally:
74 | end_time = time.time()
75 | print(f'Took time {end_time-st_time}')
76 | pickle.dump(out_list, open(self.out_file, 'wb'))
77 | self.update_preds()
78 |
79 | def update_preds(self):
80 | preds = pickle.load(open(self.out_file, 'rb'))
81 | for pred in tqdm(preds):
82 | for verb in pred['verbs']:
83 | verb['req_pat'] = srl_out_patt.findall(verb['description'])
84 | pickle.dump(preds, open(self.out_file, 'wb'))
85 |
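# Illustrative note (not part of the original file): `srl_out_patt` above is
# used in update_preds to pull the bracketed role spans out of an AllenNLP SRL
# description string. For a description such as
# '[ARG0: A man] is [V: throwing] [ARG1: a ball]', findall would return
# ['ARG0: A man', 'V: throwing', 'ARG1: a ball'].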
86 |
87 | class SRL_Anet(SRL_DS):
88 | def after_init(self):
89 | """
90 |         Assert that the required files exist
91 |         """
92 |         # Assert that the raw caption file exists
93 | self.trn_anet_cap_file = self.tdir / Path(self.cfg.ds.anet_cap_file)
94 | assert self.trn_anet_cap_file.exists()
95 |
96 | def get_annotations(self):
97 | trn_cap_data = json.load(open(self.trn_anet_cap_file))
98 | trn_vid_list = list(trn_cap_data.keys())
99 | out_dict_list = []
100 | for trn_vid_name in tqdm(trn_vid_list):
101 | trn_vid_segs = trn_cap_data[trn_vid_name]
102 | num_segs = len(trn_vid_segs['timestamps'])
103 | for seg in range(num_segs):
104 | out_dict = {
105 | 'time_stamp': trn_vid_segs['timestamps'][seg],
106 | 'vid': trn_vid_name,
107 | 'vid_seg': f'{trn_vid_name}_segment_{seg:02d}',
108 | 'segment': seg,
109 | 'sentence': trn_vid_segs['sentences'][seg]
110 | }
111 | out_dict_list.append(out_dict)
112 | out_df = pd.DataFrame(out_dict_list)
113 | out_df.to_csv(
114 | (
115 | self.cache_dir /
116 | f'{self.cfg.ds.srl_caps}'
117 | ),
118 | header=True, index=False
119 | )
120 | return out_df
121 |
122 |
123 | def main():
124 | cfg_file = './configs/create_asrl_cfg.yml'
125 | cfg = CN(yaml.safe_load(open(cfg_file)))
126 | print(cfg)
127 | srl_ds = SRL_Anet(cfg)
128 | srl_ds.do_predictions()
129 |
130 |
131 | if __name__ == '__main__':
132 | main()
133 | # fire.Fire(main)
134 |
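# Illustrative usage (assumption: run from the repository root so that the
# relative path './configs/create_asrl_cfg.yml' used in main() resolves):
#   python dcode/sem_role_labeller.py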
--------------------------------------------------------------------------------
/media/Intro_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheShadow29/vognet-pytorch/238e93c37cf9f03a2fd376a14760bb3d334a113d/media/Intro_fig.png
--------------------------------------------------------------------------------
/media/contrastive_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheShadow29/vognet-pytorch/238e93c37cf9f03a2fd376a14760bb3d334a113d/media/contrastive_examples.png
--------------------------------------------------------------------------------
/media/contrastive_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheShadow29/vognet-pytorch/238e93c37cf9f03a2fd376a14760bb3d334a113d/media/contrastive_samples.png
--------------------------------------------------------------------------------
/media/model_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheShadow29/vognet-pytorch/238e93c37cf9f03a2fd376a14760bb3d334a113d/media/model_fig.png
--------------------------------------------------------------------------------
/media/tempora_spatial_concat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheShadow29/vognet-pytorch/238e93c37cf9f03a2fd376a14760bb3d334a113d/media/tempora_spatial_concat.png
--------------------------------------------------------------------------------
/notebooks/data_stats.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "%matplotlib inline"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 3,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "import altair as alt"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 4,
34 | "metadata": {},
35 | "outputs": [
36 | {
37 | "name": "stdout",
38 | "output_type": "stream",
39 | "text": [
40 | "/home/arka/Ark_git_files/vognet-pytorch\n"
41 | ]
42 | }
43 | ],
44 | "source": [
45 | "cd .."
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 5,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "import sys\n",
55 | "sys.path.append('./dcode')\n",
56 | "sys.path.append('./code')\n",
57 | "sys.path.append('./utils')"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 6,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "from dataset_stats import AnetSRL_Vis"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 7,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "from yacs.config import CfgNode as CN\n",
76 | "import yaml\n",
77 | "cfg = CN(yaml.safe_load(open('./configs/create_asrl_cfg.yml')))"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 8,
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "avis = AnetSRL_Vis(cfg)"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 9,
92 | "metadata": {},
93 | "outputs": [
94 | {
95 | "data": {
96 | "text/plain": [
97 | "Index(['gt_bboxes', 'gt_frms', 'lemma_ARG0', 'lemma_ARG1', 'lemma_ARG2',\n",
98 | " 'lemma_ARGM_LOC', 'lemma_verb', 'process_clss', 'process_idx2',\n",
99 | " 'req_aname', 'req_args', 'req_cls_pats', 'req_cls_pats_mask', 'req_pat',\n",
100 | " 'req_pat_ix', 'sent', 'tags', 'verb', 'vid_seg', 'words', 'ann_ind',\n",
101 | " 'srl_ind', 'vt_split', 'DS4_Inds', 'ds4_msk', 'RandDS4_Inds'],\n",
102 | " dtype='object')"
103 | ]
104 | },
105 | "execution_count": 9,
106 | "metadata": {},
107 | "output_type": "execute_result"
108 | }
109 | ],
110 | "source": [
111 | "avis.trn_srl_annots.columns"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "len(avis.trn_srl_annots.vid_seg.unique())"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "len(avis.val_srl_annots[avis.val_srl_annots.vt_split == 'val'].vid_seg.unique())"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "avis.vis=True"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 21,
144 | "metadata": {},
145 | "outputs": [
146 | {
147 | "data": {
148 | "text/plain": [
149 | "8.049772762889829"
150 | ]
151 | },
152 | "execution_count": 21,
153 | "metadata": {},
154 | "output_type": "execute_result"
155 | }
156 | ],
157 | "source": [
158 | "avis.trn_srl_annots.req_pat_ix.apply(lambda x: sum([len(y[1]) for y in x])).mean()"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 22,
164 | "metadata": {},
165 | "outputs": [
166 | {
167 | "name": "stdout",
168 | "output_type": "stream",
169 | "text": [
170 | "Number of videos in Train/Valid/Test: 31718, 3891, 3914\n",
171 | "Number of Queries per Video is 2.0117914118166342\n",
172 | "Number of Queries per Video is 3.455868986052343\n",
173 | "Number of Queries per Video is 8.049772762889829\n",
174 | "Noun Phrases Count\n",
175 | "Arg,Count\n",
176 | "V,63812\n",
177 | "ARG0,48342\n",
178 | "ARG1,47335\n",
179 | "ARG2,16200\n",
180 | "ARGM-TMP,12061\n",
181 | "ARGM-DIR,8876\n",
182 | "ARGM-LOC,7408\n",
183 | "ARGM-MNR,5702\n",
184 | "ARGM-ADV,3661\n",
185 | "ARGM-PRP,1417\n",
186 | "ARG4,1238\n",
187 | "ARGM-PRD,905\n",
188 | "ARG3,854\n",
189 | "ARGM-GOL,447\n",
190 | "R-ARG0,423\n",
191 | "ARGM-COM,314\n",
192 | "R-ARG1,303\n",
193 | "C-ARG1,286\n",
194 | "ARGM-EXT,188\n",
195 | "ARGM-DIS,145\n",
196 | "\n",
197 | "Groundable Noun Phrase Count\n",
198 | "Arg,Count\n",
199 | "ARG0,42472\n",
200 | "ARG1,32455\n",
201 | "ARG2,9520\n",
202 | "ARGM-LOC,5082\n",
203 | "ARGM-TMP,3505\n",
204 | "ARGM-DIR,2936\n",
205 | "ARGM-MNR,2168\n",
206 | "ARGM-ADV,2036\n",
207 | "ARG4,947\n",
208 | "ARGM-PRP,690\n",
209 | "ARG3,538\n",
210 | "ARGM-GOL,310\n",
211 | "ARGM-PRD,298\n",
212 | "V,256\n",
213 | "ARGM-COM,209\n",
214 | "C-ARG1,186\n",
215 | "C-ARG0,26\n",
216 | "ARGM-CAU,20\n",
217 | "ARGM-PNC,17\n",
218 | "ARGM-EXT,13\n",
219 | "\n",
220 | "SRL Structures Frequency\n",
221 | "Arg,Count\n",
222 | "ARG0-V-ARG1,13654\n",
223 | "ARG0-V-ARG1-ARG2,3372\n",
224 | "ARG1-V-ARG2,3135\n",
225 | "ARG0-V,3080\n",
226 | "ARG0-V-ARGM-DIR,2269\n",
227 | "ARG0-V-ARG2,2075\n",
228 | "ARG0-V-ARG1-ARGM-LOC,1689\n",
229 | "ARG1-V,1631\n",
230 | "V-ARG1,1383\n",
231 | "ARG0-V-ARGM-LOC,1358\n",
232 | "ARG0-V-ARG1-ARGM-TMP,1290\n",
233 | "ARG0-V-ARG1-ARGM-MNR,862\n",
234 | "ARG0-V-ARG1-ARGM-DIR,838\n",
235 | "ARG0-ARGM-TMP-V-ARG1,754\n",
236 | "ARG1-V-ARG2-ARGM-ADV,743\n",
237 | "ARG1-V-ARGM-DIR,735\n",
238 | "ARGM-TMP-ARG0-V-ARG1,729\n",
239 | "ARG1-V-ARGM-LOC,586\n",
240 | "ARG0-V-ARGM-TMP,558\n",
241 | "ARG2-V-ARG1,526\n",
242 | "\n",
243 | "Lemmatized Counts for each lemma: {'lemma_verb': 338, 'lemma_ARG0': 93, 'lemma_ARG1': 281, 'lemma_ARG2': 114, 'lemma_ARGM_LOC': 59}\n",
244 | "Most Frequent Lemmas for lemma_verb\n",
245 | "String,Count\n",
246 | "stand,2395\n",
247 | "play,2152\n",
248 | "hold,1662\n",
249 | "talk,1626\n",
250 | "put,1458\n",
251 | "sit,1402\n",
252 | "speak,1190\n",
253 | "use,1057\n",
254 | "run,1053\n",
255 | "take,993\n",
256 | "walk,990\n",
257 | "throw,945\n",
258 | "go,930\n",
259 | "ride,906\n",
260 | "move,904\n",
261 | "walks,803\n",
262 | "wear,765\n",
263 | "get,740\n",
264 | "do,737\n",
265 | "look,714\n",
266 | "hit,690\n",
267 | "\n",
268 | "Most Frequent Lemmas for lemma_ARG0\n",
269 | "String,Count\n",
270 | ",21439\n",
271 | "man,8252\n",
272 | "he,7973\n",
273 | "woman,4095\n",
274 | "she,4081\n",
275 | "people,3360\n",
276 | "they,2048\n",
277 | "person,1785\n",
278 | "girl,1067\n",
279 | "boy,1053\n",
280 | "lady,789\n",
281 | "player,436\n",
282 | "dog,372\n",
283 | "child,360\n",
284 | "team,339\n",
285 | "kid,337\n",
286 | "athlete,272\n",
287 | "shirt,263\n",
288 | "guy,250\n",
289 | "gymnast,218\n",
290 | "other,196\n",
291 | "\n",
292 | "Most Frequent Lemmas for lemma_ARG1\n",
293 | "String,Count\n",
294 | ",31440\n",
295 | "he,2459\n",
296 | "man,1967\n",
297 | "it,1433\n",
298 | "woman,1163\n",
299 | "she,1132\n",
300 | "people,1097\n",
301 | "ball,967\n",
302 | "they,792\n",
303 | "hand,467\n",
304 | "hair,413\n",
305 | "person,380\n",
306 | "dog,354\n",
307 | "car,318\n",
308 | "girl,317\n",
309 | "screen,300\n",
310 | "water,298\n",
311 | "boy,297\n",
312 | "rope,269\n",
313 | "shoe,245\n",
314 | "shirt,235\n",
315 | "\n",
316 | "Most Frequent Lemmas for lemma_ARG2\n",
317 | "String,Count\n",
318 | ",54374\n",
319 | "he,525\n",
320 | "table,322\n",
321 | "she,259\n",
322 | "woman,244\n",
323 | "water,204\n",
324 | "man,203\n",
325 | "it,186\n",
326 | "people,176\n",
327 | "floor,151\n",
328 | "field,149\n",
329 | "room,148\n",
330 | "wall,144\n",
331 | "board,141\n",
332 | "car,139\n",
333 | "ground,128\n",
334 | "chair,114\n",
335 | "tree,101\n",
336 | "bar,100\n",
337 | "they,100\n",
338 | "ball,94\n",
339 | "\n",
340 | "Most Frequent Lemmas for lemma_ARGM_LOC\n",
341 | "String,Count\n",
342 | ",58825\n",
343 | "he,243\n",
344 | "water,238\n",
345 | "room,230\n",
346 | "field,200\n",
347 | "screen,134\n",
348 | "gym,128\n",
349 | "stage,124\n",
350 | "table,106\n",
351 | "floor,99\n",
352 | "court,91\n",
353 | "bar,87\n",
354 | "beach,86\n",
355 | "street,84\n",
356 | "pool,80\n",
357 | "she,79\n",
358 | "board,77\n",
359 | "mat,75\n",
360 | "woman,72\n",
361 | "ground,68\n",
362 | "track,65\n",
363 | "\n"
364 | ]
365 | }
366 | ],
367 | "source": [
368 | "vlist = avis.print_all_stats()"
369 | ]
370 | },
371 | {
372 | "cell_type": "code",
373 | "execution_count": null,
374 | "metadata": {},
375 | "outputs": [],
376 | "source": [
377 | "vlist"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": null,
383 | "metadata": {},
384 | "outputs": [],
385 | "source": [
386 | "vlist[2]"
387 | ]
388 | },
389 | {
390 | "cell_type": "code",
391 | "execution_count": null,
392 | "metadata": {},
393 | "outputs": [],
394 | "source": [
395 | "out = avis.visualize_df(nnp_srl, x_name='Count:Q', y_name='Arg:O')"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": null,
401 | "metadata": {},
402 | "outputs": [],
403 | "source": [
404 | "out"
405 | ]
406 | }
407 | ],
408 | "metadata": {
409 | "kernelspec": {
410 | "display_name": "Python 3",
411 | "language": "python",
412 | "name": "python3"
413 | },
414 | "language_info": {
415 | "codemirror_mode": {
416 | "name": "ipython",
417 | "version": 3
418 | },
419 | "file_extension": ".py",
420 | "mimetype": "text/x-python",
421 | "name": "python",
422 | "nbconvert_exporter": "python",
423 | "pygments_lexer": "ipython3",
424 | "version": "3.7.3"
425 | }
426 | },
427 | "nbformat": 4,
428 | "nbformat_minor": 2
429 | }
430 |
--------------------------------------------------------------------------------
/utils/README.md:
--------------------------------------------------------------------------------
1 | # File Organization
2 |
3 | 1. `box_utils.py`, as the name suggests, contains utilities for box IoU and box area computation. It also contains code for box IoU across multiple frames (see the usage sketch after this list).
4 | 1. `mdl_srl_utils.py` has convenience functions for the models, such as an LSTM implementation adapted from fairseq.
5 | 1. `trn_utils.py` contains the learner, which handles model saving/loading, logging, and saving predictions, among other things.
6 |
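A minimal usage sketch for `box_iou` (illustrative only, not from the repo; assumes the repository root is on `sys.path` so that `utils` is importable):

```python
import torch
from utils.box_utils import box_iou

box1 = torch.tensor([[0., 0., 10., 10.]])   # [N, 4] in x1y1x2y2
box2 = torch.tensor([[5., 5., 10., 10.]])   # [M, 4] in x1y1x2y2
print(box_iou(box1, box2))                  # [N, M] IoU matrix -> tensor([[0.2500]])
```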
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheShadow29/vognet-pytorch/238e93c37cf9f03a2fd376a14760bb3d334a113d/utils/__init__.py
--------------------------------------------------------------------------------
/utils/box_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Helper functions for boxes
3 | Adapted from
4 | https://github.com/facebookresearch/maskrcnn-benchmark/
5 | blob/master/maskrcnn_benchmark/structures/boxlist_ops.py
6 | """
7 | import torch
8 |
9 | TO_REMOVE = 0
10 |
11 |
12 | def get_area(box):
13 | """
14 | box: [N, 4]
15 | torch.tensor of
16 | type x1y1x2y2
17 | """
18 | area = (
19 | (box[:, 2] - box[:, 0] + TO_REMOVE) *
20 | (box[:, 3] - box[:, 1] + TO_REMOVE)
21 | )
22 | return area
23 |
24 |
25 | def box_iou(box1, box2):
26 | """
27 | box1: [N, 4]
28 | box2: [M, 4]
29 | both of type torch.tensor
30 | Assumes both of type x1y1x2y2
31 | output: [N,M]
32 | """
33 | if len(box1.shape) == 1 and len(box1) == 4:
34 | box1 = box1.unsqueeze(0)
35 | if len(box2.shape) == 1 and len(box2) == 4:
36 | box2 = box2.unsqueeze(0)
37 |
38 | N = len(box1)
39 | M = len(box2)
40 |
41 | area1 = get_area(box1)
42 | area2 = get_area(box2)
43 |
44 | lt = torch.max(box1[:, None, :2], box2[:, :2]) # [N,M,2]
45 | rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # [N,M,2]
46 |
47 | wh = (rb - lt + TO_REMOVE).clamp(min=0) # [N,M,2]
48 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
49 |
50 | iou = inter / (area1[:, None] + area2 - inter)
51 | return iou
52 |
53 |
54 | def bbox_overlaps(rois, gt_box, frm_mask):
55 |
56 | overlaps = bbox_overlaps_batch(rois[:, :, :5], gt_box[:, :, :5], frm_mask)
57 |
58 | return overlaps
59 |
60 |
61 | def bbox_overlaps_batch(anchors, gt_boxes, frm_mask=None):
62 | """
63 | Source:
64 | https://github.com/facebookresearch/grounded-video-description/blob/
65 | master/misc/bbox_transform.py#L176
66 |     anchors: (b, N, 5) torch tensor of float
67 |     gt_boxes: (b, K, 5) torch tensor of float
68 |     frm_mask: (b, N, K) torch tensor of bool
69 |
70 |     overlaps: (b, N, K) torch tensor of overlap between boxes and query_boxes
71 | """
72 | batch_size = gt_boxes.size(0)
73 |
74 | N = anchors.size(1)
75 | K = gt_boxes.size(1)
76 |
77 | anchors = anchors[:, :, :5].contiguous()
78 | gt_boxes = gt_boxes[:, :, :5].contiguous()
79 |
80 | gt_boxes_x = (gt_boxes[:, :, 2] - gt_boxes[:, :, 0] + 1)
81 | gt_boxes_y = (gt_boxes[:, :, 3] - gt_boxes[:, :, 1] + 1)
82 | gt_boxes_area = (gt_boxes_x * gt_boxes_y).view(batch_size, 1, K)
83 |
84 | anchors_boxes_x = (anchors[:, :, 2] - anchors[:, :, 0] + 1)
85 | anchors_boxes_y = (anchors[:, :, 3] - anchors[:, :, 1] + 1)
86 | anchors_area = (anchors_boxes_x *
87 | anchors_boxes_y).view(batch_size, N, 1)
88 |
89 | gt_area_zero = (gt_boxes_x == 1) & (gt_boxes_y == 1)
90 | anchors_area_zero = (anchors_boxes_x == 1) & (anchors_boxes_y == 1)
91 |
92 | boxes = anchors.view(batch_size, N, 1, 5).expand(batch_size, N, K, 5)
93 | query_boxes = gt_boxes.view(
94 | batch_size, 1, K, 5).expand(batch_size, N, K, 5)
95 |
96 | iw = (torch.min(boxes[:, :, :, 2], query_boxes[:, :, :, 2]) -
97 | torch.max(boxes[:, :, :, 0], query_boxes[:, :, :, 0]) + 1)
98 | iw[iw < 0] = 0
99 |
100 | ih = (torch.min(boxes[:, :, :, 3], query_boxes[:, :, :, 3]) -
101 | torch.max(boxes[:, :, :, 1], query_boxes[:, :, :, 1]) + 1)
102 | ih[ih < 0] = 0
103 | ua = anchors_area + gt_boxes_area - (iw * ih)
104 |
105 | if frm_mask is not None:
106 | # proposal and gt should be on the same frame to overlap
107 | # print('Percentage of proposals that are in the annotated frame: {}'.format(torch.mean(frm_mask.float())))
108 |
109 | overlaps = iw * ih / ua
110 | overlaps *= frm_mask.type(overlaps.type())
111 |
112 | # mask the overlap here.
113 | overlaps.masked_fill_(gt_area_zero.view(
114 | batch_size, 1, K).expand(batch_size, N, K), 0)
115 | overlaps.masked_fill_(anchors_area_zero.view(
116 | batch_size, N, 1).expand(batch_size, N, K), -1)
117 |
118 | return overlaps
119 |
--------------------------------------------------------------------------------
/utils/mdl_srl_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Some helpful functions/classes are defined
3 | """
4 | import torch
5 | from torch import nn
6 | # from fairseq.models import FairseqEncoder
7 | from torch.nn import functional as F
8 | from fairseq import utils
9 |
10 |
11 | def combine_first_ax(inp_tensor, keepdim=False):
12 | inp_shape = inp_tensor.shape
13 | if keepdim:
14 | return inp_tensor.view(
15 | 1, inp_shape[0] * inp_shape[1], *inp_shape[2:])
16 | return inp_tensor.view(
17 | inp_shape[0] * inp_shape[1], *inp_shape[2:])
18 |
19 |
20 | def uncombine_first_ax(inp_tensor, s0):
21 | "s0 is the size(0) intended, usually B"
22 | inp_shape = inp_tensor.shape
23 | size0 = inp_tensor.size(0)
24 | assert size0 % s0 == 0
25 | s1 = size0 // s0
26 | return inp_tensor.view(
27 | s0, s1, *inp_shape[1:])
28 |
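# Illustrative note (not part of the original file): combine_first_ax folds the
# first two axes into one, e.g. a (B, N, D) tensor becomes (B*N, D), and
# uncombine_first_ax(t, B) restores the (B, N, D) shape.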
29 |
30 | def do_cross(x1, x2=None, dim1=1, op='add'):
31 | """
32 |     if x2 is None, do x1(row) + x1(col);
33 |     else x1(row) + x2(col)
34 |     dim1 is the dimension along which
35 |     the crossing is done.
36 | Both x1, x2 should have same shape except
37 | for at most one dimension
38 |
39 | if input is B x C x D x E with dim1=1
40 | B x C x D x E ->
41 | B x C x 1 x D x E -> B x C x C x D x E;
42 | B x 1 x C x D x E -> B x C x C x D x E;
43 |     and then combine according to op
44 |
45 |     op = 'add', 'subtract', 'mult' or 'concat'
46 | """
47 | x1_shape = x1.shape
48 | if x2 is None:
49 | x2 = x1
50 | assert x1.shape == x2.shape
51 | x1_dim = len(x1_shape)
52 | out_shape = tuple((*x1_shape[:dim1], x1_shape[dim1], *x1_shape[dim1:]))
53 | if dim1 < x1_dim:
54 | x1_row = x1.view(*x1_shape[:dim1+1], 1, *
55 | x1_shape[dim1+1:]).expand(out_shape)
56 | x2_col = x2.view(*x1_shape[:dim1], 1, *
57 | x1_shape[dim1:]).expand(out_shape)
58 | else:
59 | x1_row = x1.view(*x1_shape[:dim1+1], 1)
60 | x2_col = x2.view(*x1_shape[:dim1], 1, x1_shape[dim1])
61 |
62 | if op == 'add':
63 | return (x1_row + x2_col) / 2
64 | elif op == 'mult':
65 | return (x1_row * x2_col)
66 | elif op == 'concat':
67 | return torch.cat([x1_row, x2_col], dim=-1)
68 | elif op == 'subtract':
69 | return (x1_row - x2_col)
70 |
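def _do_cross_demo():
    # Illustrative sketch only (not used by the models): pairwise 'concat'
    # crossing of a B x C x D tensor along dim1=1 yields B x C x C x 2D.
    x = torch.randn(2, 3, 4)
    out = do_cross(x, dim1=1, op='concat')
    assert out.shape == (2, 3, 3, 8)
    return out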
71 |
72 | class LSTMEncoder(nn.Module):
73 | """LSTM encoder."""
74 |
75 | def __init__(
76 | self, cfg, comm, embed_dim=512, hidden_size=512, num_layers=1,
77 | dropout_in=0.1, dropout_out=0.1, bidirectional=False,
78 | left_pad=True, pretrained_embed=None, padding_value=0.,
79 | num_embeddings=0, pad_idx=0
80 | ):
81 | super().__init__()
82 | self.cfg = cfg
83 | self.comm = comm
84 | self.num_layers = num_layers
85 | self.dropout_in = dropout_in
86 | self.dropout_out = dropout_out
87 | self.bidirectional = bidirectional
88 | self.hidden_size = hidden_size
89 |
90 | num_embeddings = num_embeddings
91 | self.padding_idx = pad_idx
92 | embed_dim1 = embed_dim
93 | if pretrained_embed is None:
94 | self.embed_tokens = nn.Embedding(
95 | num_embeddings, embed_dim1, self.padding_idx
96 | )
97 | else:
98 | self.embed_tokens = pretrained_embed
99 |
100 | self.lstm = nn.LSTM(
101 | input_size=embed_dim,
102 | hidden_size=hidden_size,
103 | num_layers=num_layers,
104 | dropout=self.dropout_out if num_layers > 1 else 0.,
105 | bidirectional=bidirectional,
106 | )
107 | self.left_pad = left_pad
108 | self.padding_value = padding_value
109 |
110 | self.output_units = hidden_size
111 | if bidirectional:
112 | self.output_units *= 2
113 |
114 | def forward(self, src_tokens, src_lengths):
115 | if self.left_pad:
116 | # nn.utils.rnn.pack_padded_sequence requires right-padding;
117 | # convert left-padding to right-padding
118 | src_tokens = utils.convert_padding_direction(
119 | src_tokens,
120 | self.padding_idx,
121 | left_to_right=True,
122 | )
123 |
124 | bsz, seqlen = src_tokens.size()
125 | # embed tokens
126 | x = self.embed_tokens(src_tokens)
127 |
128 | x = F.dropout(x, p=self.dropout_in, training=self.training)
129 |
130 | # B x T x C -> T x B x C
131 | x = x.transpose(0, 1)
132 |
133 | # pack embedded source tokens into a PackedSequence
134 | packed_x = nn.utils.rnn.pack_padded_sequence(
135 | x, src_lengths.data.tolist(), enforce_sorted=False)
136 |
137 | # apply LSTM
138 | if self.bidirectional:
139 | state_size = 2 * self.num_layers, bsz, self.hidden_size
140 | else:
141 | state_size = self.num_layers, bsz, self.hidden_size
142 | h0 = x.new_zeros(*state_size)
143 | c0 = x.new_zeros(*state_size)
144 | packed_outs, (final_hiddens, final_cells) = self.lstm(
145 | packed_x, (h0, c0))
146 |
147 | # unpack outputs and apply dropout
148 | x, _ = nn.utils.rnn.pad_packed_sequence(
149 | packed_outs, padding_value=self.padding_value)
150 | x = F.dropout(x, p=self.dropout_out, training=self.training)
151 | assert list(x.size()) == [seqlen, bsz, self.output_units]
152 |
153 | if self.bidirectional:
154 |
155 | def combine_bidir(outs):
156 | out = outs.view(self.num_layers, 2, bsz, -
157 | 1).transpose(1, 2).contiguous()
158 | return out.view(self.num_layers, bsz, -1)
159 |
160 | final_hiddens = combine_bidir(final_hiddens)
161 | final_cells = combine_bidir(final_cells)
162 |
163 | encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
164 |
165 | return {
166 | 'encoder_out': (x, final_hiddens, final_cells),
167 | 'encoder_padding_mask': (encoder_padding_mask
168 | if encoder_padding_mask.any() else None)
169 | }
170 |
171 | def reorder_only_outputs(self, outputs):
172 | """
173 | outputs of shape : T x B x C -> B x T x C
174 | """
175 | return outputs.transpose(0, 1).contiguous()
176 |
177 | def reorder_encoder_out(self, encoder_out, new_order):
178 | encoder_out['encoder_out'] = tuple(
179 | eo.index_select(1, new_order)
180 | for eo in encoder_out['encoder_out']
181 | )
182 | if encoder_out['encoder_padding_mask'] is not None:
183 | encoder_out['encoder_padding_mask'] = \
184 | encoder_out['encoder_padding_mask'].index_select(1, new_order)
185 | return encoder_out
186 |
187 | def max_positions(self):
188 | """Maximum input length supported by the encoder."""
189 | return int(1e5) # an arbitrary large number
190 |
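def _lstm_encoder_demo():
    # Illustrative sketch only (not used by the models); the vocabulary size,
    # batch size and sequence length below are made up. cfg and comm are only
    # stored by the constructor, so None is passed here.
    enc = LSTMEncoder(cfg=None, comm=None, embed_dim=16, hidden_size=32,
                      num_embeddings=100, pad_idx=0, left_pad=False)
    src_tokens = torch.randint(1, 100, (4, 7))            # B x T
    src_lengths = torch.full((4,), 7, dtype=torch.long)   # all sequences full length
    out = enc(src_tokens, src_lengths)
    # encoder_out[0] is T x B x output_units; reorder to B x T x C
    return enc.reorder_only_outputs(out['encoder_out'][0])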
191 |
192 | class SimpleAttn(nn.Module):
193 | def __init__(self, qdim, hdim):
194 | super().__init__()
195 | self.lin1 = nn.Linear(qdim, hdim)
196 | self.lin2 = nn.Linear(qdim, hdim)
197 | self.lin3 = nn.Linear(hdim, 1)
198 |
199 | def forward(self, qvec, qlast, inp):
200 | """
201 |         qvec: B x nv x nsrl x qdim
202 |         qlast: B x nv x 1 x qdim
203 | """
204 | # B x nv x nsrl x hdim
205 | B, num_verbs, nsrl, qd = qvec.shape
206 | qvec_enc = self.lin1(qvec)
207 | # B x nv x 1 x hdim
208 | qlast_enc = self.lin2(qlast)
209 |
210 | hdim = qlast_enc.size(-1)
211 |
212 | # B x nv x nsrl x hdim
213 | q1_enc = torch.tanh(
214 | qvec_enc +
215 | qlast_enc.view(
216 | B, num_verbs, 1, hdim
217 | ).expand(
218 | B, num_verbs, nsrl, hdim
219 | )
220 | )
221 | # B x nv x nsrl
222 | u1 = self.lin3(q1_enc).squeeze(-1)
223 | # B x nv x nsrl
224 | attns = F.softmax(u1, dim=-1)
225 |
226 | # B x nv x nsrl x qdim
227 | qvec_out = attns.unsqueeze(-1).expand_as(qvec) * qvec
228 |
229 | return qvec_out
230 |
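def _simple_attn_demo():
    # Illustrative sketch only (not used by the models); shapes follow the
    # forward() docstring above: qvec is B x nv x nsrl x qdim and qlast is
    # B x nv x 1 x qdim. The third argument is unused by forward().
    attn = SimpleAttn(qdim=8, hdim=16)
    qvec = torch.randn(2, 3, 5, 8)
    qlast = torch.randn(2, 3, 1, 8)
    return attn(qvec, qlast, None)    # B x nv x nsrl x qdim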
--------------------------------------------------------------------------------