├── .gitignore ├── EXPTS.md ├── LICENSE ├── README.md ├── code ├── README.md ├── _init_stuff.py ├── contrastive_sampling.py ├── dat_loader_simple.py ├── eval_fn_corr.py ├── eval_vsrl_corr.py ├── extended_config.py ├── main_dist.py ├── mdl_base.py ├── mdl_conc_sep.py ├── mdl_conc_single.py ├── mdl_selector.py ├── mdl_vog.py ├── transformer_code.py └── visualizer.py ├── conda_env_vog.yml ├── configs ├── anet_srl_cfg.yml └── create_asrl_cfg.yml ├── data ├── README.md └── download_data.sh ├── dcode ├── README.md ├── asrl_creator.py ├── dataset_stats.py ├── download_asrl_parent_ann.sh ├── preproc_anet_files.py ├── preproc_ds_files.py ├── process_gt_props.py └── sem_role_labeller.py ├── media ├── Intro_fig.png ├── contrastive_examples.png ├── contrastive_samples.png ├── model_fig.png └── tempora_spatial_concat.png ├── notebooks └── data_stats.ipynb └── utils ├── README.md ├── __init__.py ├── box_utils.py ├── mdl_srl_utils.py └── trn_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | cache_dir 2 | data/ 3 | !data/README.md 4 | __pycache__/ 5 | tmp -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Arka Sadhu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # vognet-pytorch 2 | [![LICENSE](https://img.shields.io/badge/license-MIT-green)](https://github.com/TheShadow29/vognet-pytorch/blob/master/LICENSE) 3 | [![Python](https://img.shields.io/badge/python-3.6-blue)](https://www.python.org/) 4 | ![PyTorch](https://img.shields.io/badge/pytorch-1.1-yellow) 5 | [![Arxiv](https://img.shields.io/badge/Arxiv-2003.10606-purple)](https://arxiv.org/abs/2003.10606) 6 | 7 | 8 | [**Video Object Grounding using Semantic Roles in Language Description**](https://arxiv.org/abs/2003.10606)
9 | [Arka Sadhu](https://theshadow29.github.io/), [Kan Chen](https://kanchen.info/), [Ram Nevatia](https://sites.usc.edu/iris-cvlab/professor-ram-nevatia/)
10 | [CVPR 2020](http://cvpr2020.thecvf.com/) 11 | 12 | **Video Object Grounding (VOG)** is the task of localizing objects in a video referred in a query sentence description. 13 | We elevate the role of object relations via spatial and temporal concatenation of contrastive examples sampled from a newly contributed dataset called ActivityNet-SRL (ASRL). 14 | 15 | ![](./media/contrastive_examples.png) 16 | 17 | This repository includes: 18 | 1. code to create the ActivityNet-SRL dataset under [`dcode/`](./dcode) 19 | 1. code to run all the experiments provided in the paper under [`code/`](./code) 20 | 1. To foster reproducibility of our results, links to all trained models in the paper along with their log files are provided in [EXPTS.md](./EXPTS.md) 21 | 22 | Code has been modularized from its initial implementation. 23 | It should be easy to extend the code for other datasets by inheriting relevant modules. 24 | 25 | ## Installation 26 | Requirements: 27 | - python>=3.6 28 | - pytorch==1.1 (should work with pytorch >=1.3 as well but not tested) 29 | 30 | To use the same environment you can use `conda` and the environment file `conda_env_vog.yml` file provided. Please refer to [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for details on installing `conda`. 31 | 32 | ``` 33 | MINICONDA_ROOT=[to your Miniconda/Anaconda root directory] 34 | conda env create -f conda_env_vog.yml --prefix $MINICONDA_ROOT/envs/vog_pyt 35 | conda activate vog_pyt 36 | ``` 37 | 38 | ## Quick Start 39 | 1. Clone repo: 40 | ``` 41 | git clone https://github.com/TheShadow29/vognet-pytorch.git 42 | cd vognet-pytorch 43 | export ROOT=$(pwd) 44 | ``` 45 | 1. Download Data (~530gb) (See [DATA_README](./data/README.md) for more details) 46 | ``` 47 | cd $ROOT/data 48 | bash download_data.sh all [data_folder] 49 | ``` 50 | 1. Train Models 51 | ``` 52 | cd $ROOT 53 | python code/main_dist.py "spat_vog_gt5" --ds.exp_setting='gt5' --mdl.name='vog' --mdl.obj_tx.use_rel=True --mdl.mul_tx.use_rel=True --train.prob_thresh=0.2 --train.bs=4 --train.epochs=10 --train.lr=1e-4 54 | ``` 55 | ## Data Preparation 56 | If you just want to use ASRL, you can refer to [DATA_README](./data/README.md). It contains direct links to download ASRL 57 | 58 | If instead, you want to recreate ASRL from ActivityNet Entities and ActivityNet Captions, or perhaps want to extend to a newer dataset, refer to [DATA_PREP_README.md](./dcode/README.md) 59 | 60 | ## Training 61 | Basic usage is `python code/main_dist.py "experiment_name" --arg1=val1 --arg2=val2` and the arg1, arg2 can be found in `configs/anet_srl_cfg.yml`. 62 | 63 | The hierarchical structure of `yml` is also supported using `.` 64 | For example, if you want to change the `mdl name` which looks like 65 | ``` 66 | mdl: 67 | name: xyz 68 | ``` 69 | you can pass `--mdl.name='abc'` 70 | 71 | As an example, training `VOGNet` using `spat` strategy with `gt5` setting: 72 | 73 | ``` 74 | python code/main_dist.py "spat_vog_gt5" --ds.exp_setting='gt5' --mdl.name='vog' --mdl.obj_tx.use_rel=True --mdl.mul_tx.use_rel=True --train.prob_thresh=0.2 --train.bs=4 --train.epochs=10 --train.lr=1e-4 75 | ``` 76 | 77 | You can change default settings in `configs/anet_srl_cfg.yml` directly as well. 78 | 79 | See [EXPTS.md](./EXPTS.md) for command-line instructions for all experiments. 80 | 81 | ## Logging 82 | 83 | Logs are stored inside `tmp/` directory. 
When you run the code with $exp_name the following are stored: 84 | - `txt_logs/$exp_name.txt`: the config used and the training, validation losses after ever epoch. 85 | - `models/$exp_name.pth`: the model, optimizer, scheduler, accuracy, number of epochs and iterations completed are stored. Only the best model upto the current epoch is stored. 86 | - `ext_logs/$exp_name.txt`: this uses the `logging` module of python to store the `logger.debug` outputs printed. Mainly used for debugging. 87 | - `predictions`: the validation outputs of current best model. 88 | 89 | ## Evaluation 90 | To evaluate a model, you need to first load it and then pass `--only_val=True` 91 | 92 | As an example, to validate the `VOGNet` model trained in `spat` with `gt5` setting: 93 | ``` 94 | python code/main_dist.py "spat_vog_gt5_valid" --train.resume=True --train.resume_path='./tmp/models/spat_vog_gt5.pth' --mdl.name='vog' --mdl.obj_tx.use_rel=True --mdl.mul_tx.use_rel=True --only_val=True --train.prob_thresh=0.2 95 | ``` 96 | 97 | This will create `./tmp/predictions/spat_vog_gt5_valid/valid_0.pkl` and print out the metrics. 98 | 99 | You can also evaluate this file using `code/eval_fn_corr.py`. This assumes `valid_0.pkl` file is already generated. 100 | 101 | ``` 102 | python code/eval_fn_corr.py --pred_file='./tmp/predictions/spat_vog_gt5_valid/valid_0.pkl' --split_type='valid' --train.prob_thresh=0.2 103 | ``` 104 | 105 | For evaluating `test` simply use `--split_type='test'` 106 | 107 | If you are using your own code, but just want to use evaluation, you must save your output in the following format: 108 | ``` 109 | [ 110 | { 111 | 'idx_sent': id of the input query 112 | 'pred_boxes': # num_srls x num_vids x num_frames x 5d prop boxes 113 | 'pred_scores': # num_srls x num_vids x num_frames (between 0-1) 114 | 'pred_cmp': # num_srls x num_frames (only required for sep). Basically, which video to choose 115 | 'cmp_msk': 1/0s if any videos were padded and hence not considered 116 | 'targ_cmp': which is the target video. This is in prediction and not ground-truth since we shuffle the video list at runtime 117 | }, 118 | ... 119 | ] 120 | ``` 121 | 122 | ## Pre-Trained Models 123 | 124 | Google Drive Link for all models: https://drive.google.com/open?id=1e3FiX4FTC8n6UrzY9fTYQzFNKWHihzoQ 125 | 126 | Also, see individual models (with corresponding logs) at [EXPTS.md](./EXPTS.md) 127 | 128 | ## Acknowledgements: 129 | 130 | We thank: 131 | 1. @LuoweiZhou: for his codebase on GVD (https://github.com/facebookresearch/grounded-video-description) along with the extracted features. 132 | 2. [allennlp](https://github.com/allenai/allennlp) for providing [demo](https://demo.allennlp.org/semantic-role-labeling) and pre-trained model for SRL. 133 | 3. [fairseq](https://github.com/pytorch/fairseq) for providing a neat implementation of LSTM. 134 | 135 | ## Citation 136 | ``` 137 | @InProceedings{Sadhu_2020_CVPR, 138 | author = {Sadhu, Arka and Chen, Kan and Nevatia, Ram}, 139 | title = {Video Object Grounding using Semantic Roles in Language Description}, 140 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 141 | month = {June}, 142 | year = {2020} 143 | } 144 | ``` 145 | -------------------------------------------------------------------------------- /code/README.md: -------------------------------------------------------------------------------- 1 | # Model 2 | 3 | ![](../media/model_fig.png) 4 | 5 | ## File Organization 6 | 7 | 1. `main_dist.py` is the main file. 8 | 1. 
`dat_loader_simple.py` processes the data. In particular, the SPAT/TEMP/SEP part is modular, can be easily extended to a newer dataset. 9 | 1. `contrastive_sampling.py` as the name suggests creates the contrastive samples for Training. The validation file already contains these indices. 10 | 1. `mdl_base.py` is the base model, which just defines bunch of functions to be filled in. 11 | 1. `mdl_conc_single.py` implements concatenation models and losses for SPAT/TEMP. Similarly, `mdl_conc_sep.py` implements SEP concatentation model and loss. These are kept modular, so that they can be re-used with newer models with ~~minimal~~ some effort. 12 | 1. `mdl_vog.py` contains the main model implementations of baselines and vog. 13 | 1. `mdl_selector.py` returns the model, loss and evaluation function to be used based on input arguments. 14 | 1. `eval_vsrl_corr.py` is the top-level evaluation functions for each of SEP/TEMP/SPAT which processes the output of the model and converts them to uniform format for evaluation. 15 | 1. `eval_fn_corr.py` contains the main logic for evaluating the models. 16 | 1. `_init_stuff.py` initializes paths to be included, typings, as well as yaml float loader (otherwise 1e-4 cannot be read correctly). 17 | 1. `extended_config.py` has some handy configuration utils. 18 | 1. `transformer_code.py` has the transformer implementation, also has the relative transformer which uses relative position encoding (RPE). 19 | 20 | Some other useful files are under [`utils`](../utils/) 21 | -------------------------------------------------------------------------------- /code/_init_stuff.py: -------------------------------------------------------------------------------- 1 | """ 2 | Initialize stuff 3 | """ 4 | 5 | import pdb 6 | from pathlib import Path 7 | from typing import List, Dict, Any, Union 8 | from yacs.config import CfgNode as CN 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import Dataset, DataLoader 12 | import os 13 | import sys 14 | import yaml 15 | import re 16 | import pandas as pd 17 | 18 | Fpath = Union[Path, str] 19 | Cft = CN 20 | Arr = Union[np.array, List, torch.tensor] 21 | DF = pd.DataFrame 22 | # Ds = Dataset 23 | 24 | # This is required to read 5e-4 as a float rather than string 25 | # at all places yaml should be imported from here 26 | yaml.SafeLoader.add_implicit_resolver( 27 | u'tag:yaml.org,2002:float', 28 | re.compile(u'''^(?: 29 | [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? 30 | |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) 31 | |\\.[0-9_]+(?:[eE][-+][0-9]+)? 
32 | |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* 33 | |[-+]?\\.(?:inf|Inf|INF) 34 | |\\.(?:nan|NaN|NAN))$''', re.X), 35 | list(u'-+0123456789.')) 36 | 37 | # _SCRIPTPATH_ = 38 | sys.path.append('./code/') 39 | sys.path.append('./utils') 40 | 41 | 42 | class ForkedPdb(pdb.Pdb): 43 | """A Pdb subclass that may be used 44 | from a forked multiprocessing child 45 | Credits: 46 | https://github.com/williamFalcon/forked-pdb/blob/master/fpdb.py 47 | """ 48 | 49 | def interaction(self, *args, **kwargs): 50 | _stdin = sys.stdin 51 | try: 52 | sys.stdin = open('/dev/stdin') 53 | pdb.Pdb.interaction(self, *args, **kwargs) 54 | finally: 55 | sys.stdin = _stdin 56 | -------------------------------------------------------------------------------- /code/contrastive_sampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | To create the 4-way dataset 3 | Main motivation: 4 | Currently, not sure if the models ground 5 | based only on object name, or is it really 6 | learning the roles of the visual elements 7 | correctly. 8 | 9 | Thus, we create 4-way dataset, for every 10 | data which has S-V-O statistics, we generate 11 | counterfactuals (not sure if this is 12 | the correct name or not). For every image 13 | containing say S1-V1-O1, present it with other 14 | images with the characteristics S2-V1-O1, 15 | S1-V2-O1, S1-V1-O2 as well. Some can be 16 | reduced in case only S-V or O-V are present 17 | 18 | More generally, we would like to create a 19 | counterfactuals for anything that can 20 | provide evidence. 21 | 22 | Additionally, need to check 23 | - [x] Location words shouldn't be present 24 | - [x] Perform VERB lemmatization 25 | - [x] Distinguish between what is groundable and 26 | what is not 27 | - [x] Check the groundable verbs 28 | """ 29 | from pathlib import Path 30 | import pandas as pd 31 | 32 | from tqdm.auto import tqdm 33 | from collections import Counter 34 | import json 35 | import copy 36 | import ast 37 | import numpy as np 38 | from _init_stuff import CN, yaml 39 | from typing import List 40 | np.random.seed(seed=5) 41 | 42 | 43 | def create_random_list(cfg, srl_annots, ann_row_idx): 44 | """ 45 | Returns 4 random videos 46 | """ 47 | srl_idxs_possible = np.array(srl_annots.index) 48 | 49 | vid_segs = srl_annots.vid_seg 50 | vid_seg = vid_segs.loc[ann_row_idx] 51 | srl_row = srl_annots.loc[ann_row_idx] 52 | 53 | req_cls_pats = srl_row.req_cls_pats 54 | req_cls_pats_mask = srl_row.req_cls_pats_mask 55 | args_to_use = set(['V', 'ARG0', 'ARG1', 'ARG2', 'ARGM-LOC']) 56 | 57 | arg_keys_vis_present = [] 58 | arg_keys_lang_present = [] 59 | for srl_arg, srl_arg_mask in zip(req_cls_pats, req_cls_pats_mask): 60 | arg_key = srl_arg[0] 61 | arg_keys_lang_present.append(arg_key) 62 | if arg_key == 'V' or arg_key in args_to_use: 63 | arg_keys_vis_present.append(arg_key) 64 | 65 | ds4_msk = {} 66 | inds_to_use = {} 67 | num_arg_keys_vis = len(arg_keys_vis_present) 68 | other_anns = np.random.choice( 69 | srl_idxs_possible, size=10 * num_arg_keys_vis, 70 | replace=False 71 | ).reshape(num_arg_keys_vis, 10) 72 | 73 | for aind, arg_key1 in enumerate(arg_keys_vis_present): 74 | in1 = other_anns[aind].tolist() 75 | assert len(in1) == 10 76 | 77 | set1 = set(in1) 78 | 79 | set_int = [s for s in set1 if 80 | vid_segs.loc[s] != vid_seg] 81 | 82 | # TODO: 83 | # Make replace false, currently true 84 | # because some have low chances of 85 | # appearing 86 | assert len(set_int) > 0 87 | inds_to_use[arg_key1] = set_int 88 | ds4_msk[arg_key1] = 1 89 | return 
inds_to_use, ds4_msk 90 | 91 | 92 | def create_similar_list_new(cfg, arg_dicts, srl_annots, ann_row_idx): 93 | """ 94 | Does it for one row. Assumes annotations 95 | exists and can be retrieved via `self`. 96 | 97 | The logic: 98 | Each input idx has ARG0, V, ARG1 ..., 99 | (1) Pivot across one argument, say ARG0 100 | (2) Retrieve all other indices such that they 101 | have different ARG0, but same V, ARG1 ... (do 102 | each of them separately) 103 | (3) To retrieve those indices with V, ARG1 same 104 | we can just do intersection of the two sets 105 | 106 | To facilitate (2), we first create separate 107 | dictionaries for each V, ARG1 etc. and then 108 | just reference them via self.create_dicts 109 | """ 110 | srl_idxs_possible = np.array(srl_annots.index) 111 | 112 | vid_segs = srl_annots.vid_seg 113 | vid_seg = vid_segs.loc[ann_row_idx] 114 | srl_row = srl_annots.loc[ann_row_idx] 115 | 116 | req_cls_pats = srl_row.req_cls_pats 117 | req_cls_pats_mask = srl_row.req_cls_pats_mask 118 | args_to_use = set(['V', 'ARG0', 'ARG1', 'ARG2', 'ARGM-LOC']) 119 | some_inds = {} 120 | arg_keys_vis_present = [] 121 | arg_keys_lang_present = [] 122 | for srl_arg, srl_arg_mask in zip(req_cls_pats, req_cls_pats_mask): 123 | arg_key = srl_arg[0] 124 | arg_keys_lang_present.append(arg_key) 125 | if arg_key == 'V' or arg_key in args_to_use: 126 | arg_keys_vis_present.append(arg_key) 127 | if arg_key in args_to_use: 128 | lemma_key = 'lemma_{}'.format( 129 | arg_key.replace('-', '_').replace('V', 'verb')) 130 | lemma_arg = srl_row[lemma_key] 131 | if isinstance(lemma_arg, list): 132 | assert all([le_arg in arg_dicts[arg_key] 133 | for le_arg in lemma_arg]) 134 | if len(lemma_arg) >= 1: 135 | le_arg = lemma_arg 136 | else: 137 | le_arg = cfg.ds.none_word 138 | else: 139 | le_arg = [lemma_arg] 140 | # srl_ind_list = copy.deepcopy( 141 | # arg_dicts[arg_key][le_arg]) 142 | # srl_ind_list = arg_dicts[arg_key][le_arg][:] 143 | for le_ar in le_arg: 144 | srl_ind_list = arg_dicts[arg_key][le_ar][:] 145 | srl_ind_list.remove(ann_row_idx) 146 | if arg_key not in some_inds: 147 | some_inds[arg_key] = [] 148 | some_inds[arg_key] += srl_ind_list 149 | # # If not groundable but in args_to_use 150 | # else: 151 | # pass 152 | num_arg_keys_vis = len(arg_keys_vis_present) 153 | other_anns = np.random.choice( 154 | srl_idxs_possible, size=10 * num_arg_keys_vis, 155 | replace=False 156 | ).reshape(num_arg_keys_vis, 10) 157 | 158 | inds_to_use = {} 159 | ds4_msk = {} 160 | for aind, arg_key1 in enumerate(arg_keys_vis_present): 161 | arg_key_to_use = [ 162 | ak for ak in arg_keys_vis_present if ak != arg_key1] 163 | set1 = set(some_inds[arg_key_to_use[0]]) 164 | 165 | set_int1 = set1.intersection( 166 | *[set(some_inds[ak]) for ak in arg_key_to_use[1:]]) 167 | curr_set = set(some_inds[arg_key1]) 168 | set_int2 = list(set_int1 - curr_set) 169 | 170 | set_int = [s for s in set_int2 if 171 | vid_segs.loc[s] != vid_seg] 172 | 173 | # TODO: 174 | # Make replace false, currently true 175 | # because some have low chances of 176 | # appearing 177 | if len(set_int) == 0: 178 | # this means similar scenario not found 179 | # inds 180 | ds4_msk[arg_key1] = 0 181 | inds_to_use[arg_key1] = other_anns[aind].tolist() 182 | # inds_to_use[arg_key1] = [-1] 183 | # cfg.ouch += 1 184 | # print('ouch') 185 | else: 186 | ds4_msk[arg_key1] = 1 187 | inds_to_use[arg_key1] = np.random.choice( 188 | set_int, 10, replace=True).tolist() 189 | # cfg.yolo += 1 190 | # print('yolo') 191 | # inds_to_use_lens = [len(v) if v[0] != -1 else 0 for k, 192 | # v 
in inds_to_use.items()] 193 | # if sum(inds_to_use_lens) == 0: 194 | # cfg.ouch2 += 1 195 | # else: 196 | # cfg.yolo2 += 1 197 | 198 | return inds_to_use, ds4_msk 199 | 200 | 201 | def create_similar_list(cfg, arg_dicts, srl_annots, ann_row_idx): 202 | """ 203 | Does it for one row. Assumes annotations 204 | exists and can be retrieved via `self`. 205 | 206 | The logic: 207 | Each input idx has ARG0, V, ARG1 ..., 208 | (1) Pivot across one argument, say ARG0 209 | (2) Retrieve all other indices such that they 210 | have different ARG0, but same V, ARG1 ... (do 211 | each of them separately) 212 | (3) To retrieve those indices with V, ARG1 same 213 | we can just do intersection of the two sets 214 | 215 | To facilitate (2), we first create separate 216 | dictionaries for each V, ARG1 etc. and then 217 | just reference them via self.create_dicts 218 | """ 219 | srl_idxs_possible = np.array(srl_annots.index) 220 | 221 | vid_segs = srl_annots.vid_seg 222 | vid_seg = vid_segs.loc[ann_row_idx] 223 | srl_row = srl_annots.loc[ann_row_idx] 224 | 225 | req_cls_pats = srl_row.req_cls_pats 226 | req_cls_pats_mask = srl_row.req_cls_pats_mask 227 | args_to_use = set(['V', 'ARG0', 'ARG1', 'ARG2', 'ARGM-LOC']) 228 | some_inds = {} 229 | arg_keys_vis_present = [] 230 | arg_keys_lang_present = [] 231 | for srl_arg, srl_arg_mask in zip(req_cls_pats, req_cls_pats_mask): 232 | arg_key = srl_arg[0] 233 | arg_keys_lang_present.append(arg_key) 234 | if arg_key == 'V' or arg_key in args_to_use: 235 | arg_keys_vis_present.append(arg_key) 236 | if arg_key in args_to_use: 237 | lemma_key = 'lemma_{}'.format( 238 | arg_key.replace('-', '_').replace('V', 'verb')) 239 | lemma_arg = srl_row[lemma_key] 240 | if isinstance(lemma_arg, list): 241 | assert all([le_arg in arg_dicts[arg_key] 242 | for le_arg in lemma_arg]) 243 | if len(lemma_arg) >= 1: 244 | le_arg = lemma_arg[0] 245 | else: 246 | le_arg = cfg.ds.none_word 247 | else: 248 | le_arg = lemma_arg 249 | # srl_ind_list = copy.deepcopy( 250 | # arg_dicts[arg_key][le_arg]) 251 | # srl_ind_list = arg_dicts[arg_key][le_arg][:] 252 | srl_ind_list = arg_dicts[arg_key][le_arg][:] 253 | srl_ind_list.remove(ann_row_idx) 254 | if arg_key not in some_inds: 255 | some_inds[arg_key] = [] 256 | some_inds[arg_key] += srl_ind_list 257 | # # If not groundable but in args_to_use 258 | # else: 259 | # pass 260 | num_arg_keys_vis = len(arg_keys_vis_present) 261 | other_anns = np.random.choice( 262 | srl_idxs_possible, size=10 * num_arg_keys_vis, 263 | replace=False 264 | ).reshape(num_arg_keys_vis, 10) 265 | 266 | inds_to_use = {} 267 | ds4_msk = {} 268 | for aind, arg_key1 in enumerate(arg_keys_vis_present): 269 | arg_key_to_use = [ 270 | ak for ak in arg_keys_vis_present if ak != arg_key1] 271 | set1 = set(some_inds[arg_key_to_use[0]]) 272 | 273 | set_int1 = set1.intersection( 274 | *[set(some_inds[ak]) for ak in arg_key_to_use[1:]]) 275 | curr_set = set(some_inds[arg_key1]) 276 | set_int2 = list(set_int1 - curr_set) 277 | 278 | set_int = [s for s in set_int2 if 279 | vid_segs.loc[s] != vid_seg] 280 | 281 | # TODO: 282 | # Make replace false, currently true 283 | # because some have low chances of 284 | # appearing 285 | if len(set_int) == 0: 286 | # this means similar scenario not found 287 | # inds 288 | ds4_msk[arg_key1] = 0 289 | inds_to_use[arg_key1] = other_anns[aind].tolist() 290 | # inds_to_use[arg_key1] = [-1] 291 | # cfg.ouch += 1 292 | # print('ouch') 293 | else: 294 | ds4_msk[arg_key1] = 1 295 | inds_to_use[arg_key1] = np.random.choice( 296 | set_int, 10, 
replace=True).tolist() 297 | # cfg.yolo += 1 298 | # print('yolo') 299 | # inds_to_use_lens = [len(v) if v[0] != -1 else 0 for k, 300 | # v in inds_to_use.items()] 301 | # if sum(inds_to_use_lens) == 0: 302 | # cfg.ouch2 += 1 303 | # else: 304 | # cfg.yolo2 += 1 305 | 306 | return inds_to_use, ds4_msk 307 | 308 | 309 | class AnetDSCreator: 310 | def __init__(self, cfg, tdir='.'): 311 | self.cfg = cfg 312 | self.tdir = Path(tdir) 313 | 314 | def fix_via_ast(self, df): 315 | for k in df.columns: 316 | first_word = df.iloc[0][k] 317 | if isinstance(first_word, str) and (first_word[0] in '[{'): 318 | df[k] = df[k].apply( 319 | lambda x: ast.literal_eval(x)) 320 | return df 321 | 322 | def get_stats(self, req_args): 323 | """ 324 | Gets the counts for the argument types 325 | """ 326 | c = Counter() 327 | if isinstance(req_args[0], list): 328 | for x in req_args: 329 | c += Counter(x) 330 | else: 331 | c = Counter(req_args) 332 | 333 | return c.most_common() 334 | 335 | def create_all_similar_lists(self): 336 | self.create_similar_lists(split_type='train') 337 | self.create_similar_lists(split_type='valid') 338 | 339 | def create_similar_lists(self, split_type: str = 'train'): 340 | """ 341 | need to check if only 342 | creating for the validation 343 | set would be enough or not. 344 | 345 | Basically, for each input, 346 | generates list of other inputs (idxs) 347 | which have same S,V,O (at least one is same) 348 | """ 349 | if split_type == 'train': 350 | srl_annot_file = self.tdir / self.cfg.ds.trn_verb_ent_file 351 | ds4_dict_file = self.tdir / self.cfg.ds.trn_ds4_dicts 352 | ds4_ind_file = self.tdir / self.cfg.ds.trn_ds4_inds 353 | elif split_type == 'valid': 354 | srl_annot_file = self.tdir / self.cfg.ds.val_verb_ent_file 355 | ds4_dict_file = self.tdir / self.cfg.ds.val_ds4_dicts 356 | ds4_ind_file = self.tdir / self.cfg.ds.val_ds4_inds 357 | elif split_type == 'trn_val': 358 | srl_annot_file = self.tdir / self.cfg.ds.verb_ent_file 359 | ds4_dict_file = self.tdir / self.cfg.ds.ds4_dicts 360 | ds4_ind_file = self.tdir / self.cfg.ds.ds4_inds 361 | elif split_type == 'only_val': 362 | srl_annot_file = Path('./data/anet_verb/val_1_verb_ent_file.csv') 363 | ds4_dict_file = Path( 364 | './data/anet_verb/val_1_srl_args_dict_obj_to_ind.json' 365 | ) 366 | else: 367 | raise NotImplementedError 368 | # elif split_type == 'test': 369 | # srl_annot_file = self.tdir / self.cfg.ds.test_verb_ent_file 370 | # ds4_dict_file = self.tdir / self.cfg.ds.test_ds4_dicts 371 | # ds4_ind_file = self.tdir / self.cfg.ds.test_ds4_inds 372 | # elif split_type == 'val_test': 373 | # # validation file with validation+test indices 374 | # srl_annot_file = self.tdir / self.cfg.ds.test_verb_ent_file 375 | # ds4_dict_file = self.tdir / self.cfg.ds.test_ds4_dicts 376 | # ds4_ind_file = self.tdir / self.cfg.ds.test_ds4_inds 377 | # elif split_type == 'test_val': 378 | # # test file with validation+test indices 379 | # srl_annot_file = self.tdir / self.cfg.ds.test_verb_ent_file 380 | # ds4_dict_file = self.tdir / self.cfg.ds.test_ds4_dicts 381 | # ds4_ind_file = self.tdir / self.cfg.ds.test_ds4_inds 382 | # else: 383 | # raise NotImplementedError 384 | srl_annots = self.fix_via_ast(pd.read_csv(srl_annot_file)) 385 | 386 | self.create_dicts_srl(srl_annots, ds4_dict_file) 387 | 388 | arg_dicts = json.load(open(ds4_dict_file)) 389 | srl_annots_copy = copy.deepcopy(srl_annots) 390 | # inds_to_use_list = [self.create_similar_list( 391 | # row_ind) for row_ind in tqdm(range(len(self.srl_annots)))] 392 | inds_to_use_list = [] 393 
| ds4_msk = [] 394 | rand_inds_to_use_list = [] 395 | 396 | for row_ind in tqdm(range(len(srl_annots))): 397 | inds_to_use, ds4_msk_out = create_similar_list( 398 | self.cfg, arg_dicts, srl_annots, row_ind) 399 | ds4_msk.append(ds4_msk_out) 400 | 401 | inds_to_use_list.append(inds_to_use) 402 | 403 | rand_inds_to_use, _ = create_random_list( 404 | self.cfg, srl_annots, row_ind 405 | ) 406 | rand_inds_to_use_list.append(rand_inds_to_use) 407 | 408 | srl_annots_copy['DS4_Inds'] = inds_to_use_list 409 | srl_annots_copy['ds4_msk'] = ds4_msk 410 | 411 | srl_annots_copy['RandDS4_Inds'] = rand_inds_to_use_list 412 | # srl_annots_copy = srl_annots_copy.iloc[ds4_msk] 413 | 414 | srl_annots_copy.to_csv( 415 | ds4_ind_file, index=False, header=True) 416 | # srl_annots_copy.to_csv( 417 | # self.tdir/self.cfg.ds.ds4_inds, index=False, header=True) 418 | # for row_ind in range(len(self.srl_annots)): 419 | # inds_to_use = self.create_similar_list(row_ind) 420 | 421 | def create_dicts_srl(self, srl_annots, out_file): 422 | def default_dict_list(key_list, val, dct): 423 | for key in key_list: 424 | if key not in dct: 425 | dct[key] = [] 426 | dct[key].append(val) 427 | return dct 428 | 429 | # srl_annots = self.srl_annots 430 | 431 | # args_dict_out: Dict[str, Dict[obj_name, srl_indices]] 432 | # arg_dict_lemma: Dict[str, List[obj_name]] 433 | args_dict_out = {} 434 | args_to_use = ['ARG0', 'ARG1', 'ARG2', 'ARGM-LOC'] 435 | for srl_arg in args_to_use: 436 | args_dict_out[srl_arg] = {} 437 | 438 | for row_ind, row in tqdm(srl_annots.iterrows(), 439 | total=len(srl_annots)): 440 | req_cls_pats = row.req_cls_pats 441 | req_cls_pats_mask = row.req_cls_pats_mask 442 | for srl_arg, srl_arg_mask in zip(req_cls_pats, req_cls_pats_mask): 443 | arg_key = srl_arg_mask[0] 444 | if arg_key in args_dict_out: 445 | # The argument is groundable 446 | if srl_arg_mask[1] == 1: 447 | key_list = list(set(srl_arg[1])) 448 | args_dict_out[arg_key] = default_dict_list( 449 | key_list, row_ind, args_dict_out[arg_key]) 450 | else: 451 | key_list = [self.cfg.ds.none_word] 452 | args_dict_out[arg_key] = default_dict_list( 453 | key_list, row_ind, args_dict_out[arg_key]) 454 | 455 | args_dict_out['V'] = {k: list(v.index) for k, 456 | v in srl_annots.groupby('lemma_verb')} 457 | json.dump(args_dict_out, open(out_file, 'w')) 458 | return args_dict_out 459 | 460 | 461 | def main(splits: List): 462 | if not isinstance(splits, list): 463 | assert isinstance(splits, str) 464 | splits = [splits] 465 | cfg = CN(yaml.safe_load(open('./configs/create_asrl_cfg.yml'))) 466 | for split_type in splits: 467 | anet_ds = AnetDSCreator(cfg) 468 | anet_ds.create_similar_lists(split_type=split_type) 469 | 470 | 471 | if __name__ == '__main__': 472 | import fire 473 | fire.Fire(main) 474 | # cfg = CN(yaml.safe_load(open('./configs/create_asrl_cfg.yml'))) 475 | # for split_type in ['valid', 'train']: 476 | # # for split_type in ['only_val', 'valid', 'train', 'trn_val']: 477 | # # cfg.ouch = 0 478 | # # cfg.yolo = 0 479 | 480 | # # cfg.ouch2 = 0 481 | # # cfg.yolo2 = 0 482 | 483 | # anet_ds = AnetDSCreator(cfg) 484 | # # anet_ds.create_dicts_srl() 485 | # anet_ds.create_similar_lists(split_type=split_type) 486 | 487 | # # break 488 | 489 | # # anet_ds.create_similar_lists(split_type='trn_val') 490 | # # anet_ds.create_similar_lists(split_type='train') 491 | # # anet_ds.create_similar_lists(split_type='valid') 492 | # # print(cfg.ouch, cfg.yolo, cfg.yolo+cfg.ouch) 493 | # # print(cfg.ouch2, cfg.yolo2, cfg.yolo2+cfg.ouch2) 494 | 
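# Example invocation (a rough sketch; run from the repo root, and assumes the
# SRL annotation csv files referenced in ./configs/create_asrl_cfg.yml already
# exist, see dcode/README.md):
#   python code/contrastive_sampling.py train
#   python code/contrastive_sampling.py valid
# Via fire this calls main(splits=...) above; for each split it builds the
# per-argument dictionaries (create_dicts_srl) and writes the contrastive
# indices (the DS4_Inds, RandDS4_Inds and ds4_msk columns) to the *_ds4_inds
# csv configured for that split.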
-------------------------------------------------------------------------------- /code/eval_vsrl_corr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Better evaluation. 3 | Corrected a few implementations 4 | for sep, temporal, spatial. 5 | """ 6 | from eval_fn_corr import ( 7 | GroundEval_SEP, 8 | GroundEval_TEMP, 9 | GroundEval_SPAT 10 | ) 11 | import pickle 12 | from fastprogress import progress_bar 13 | from pathlib import Path 14 | import torch 15 | from trn_utils import ( 16 | compute_avg_dict, 17 | is_main_process, 18 | synchronize, 19 | get_world_size 20 | ) 21 | 22 | 23 | class Evaluator(torch.nn.Module): 24 | def __init__(self, cfg, comm, device): 25 | super().__init__() 26 | self.cfg = cfg 27 | self.comm = comm 28 | self.met_keys = ['avg1', 'macro_avg1'] 29 | self.num_prop_per_frm = self.comm.num_prop_per_frm 30 | self.num_frms = self.cfg.ds.num_sampled_frm 31 | self.num_props = self.num_prop_per_frm * self.num_frms 32 | self.device = device 33 | self.after_init() 34 | 35 | def after_init(self): 36 | pass 37 | 38 | def get_out_results(self, out_result): 39 | if isinstance(out_result, torch.Tensor): 40 | return out_result 41 | else: 42 | return out_result['mdl_outs'] 43 | 44 | def forward_one_batch(self, out_result, inp): 45 | """ 46 | The following should be returned: 47 | List[Dict] 48 | Dict = { 49 | 'idx(video)', 'idx(srl)', 'idx(arg)', 50 | 'pred_boxes', 'pred_scores' 51 | } 52 | """ 53 | out_result = out_result 54 | # B x num_verbs x num_srl_args x 1000 55 | B, num_verbs, num_srl_args, num_props = out_result.shape 56 | assert self.num_props == num_props 57 | # B x num_verbs x num_srl_args x num_frms x num_prop_per_frm 58 | out_result_frame = torch.sigmoid( 59 | out_result.view( 60 | B, num_verbs, num_srl_args, 61 | self.num_frms, self.num_prop_per_frm 62 | ) 63 | ) 64 | # B x num_verbs x num_srl_args x num_frms x num_prop_per_frm 65 | out_result_frame_score, out_result_frame_index = torch.max( 66 | out_result_frame, dim=-1) 67 | 68 | props = inp['pad_proposals'] 69 | _, num_props, prop_dim = props.shape 70 | # B x num_verbs x num_srl_args x num_frms x num_prop_per_frm x prop_dim 71 | props_reshaped = props.view( 72 | B, 1, 1, self.num_frms, self.num_prop_per_frm, prop_dim).expand( 73 | B, num_verbs, num_srl_args, 74 | self.num_frms, self.num_prop_per_frm, prop_dim) 75 | 76 | out_result_boxes = torch.gather( 77 | props_reshaped, dim=-2, 78 | index=out_result_frame_index.unsqueeze(-1).unsqueeze(-1).expand( 79 | *out_result_frame_index.shape, 1, prop_dim)) 80 | 81 | pred_boxes = out_result_boxes.squeeze(-2) 82 | out_dict_list = [ 83 | { 84 | 'pred_boxes': pb, 85 | 'pred_scores': ps, 86 | 'idx_vid': an_ind, 87 | 'idx_sent': srl_ann, 88 | 'idx_verb': srl_verb, 89 | 'num_verbs': nv 90 | 91 | } for pb, ps, an_ind, srl_ann, srl_verb, nv in zip( 92 | pred_boxes.detach().cpu().tolist(), 93 | out_result_frame_score.detach().cpu().tolist(), 94 | inp['ann_idx'].detach().cpu().tolist(), 95 | inp['sent_idx'].detach().cpu().tolist(), 96 | inp['srl_verb_idxs'].detach().cpu().tolist(), 97 | inp['num_verbs'].detach().cpu().tolist() 98 | )] 99 | return out_dict_list 100 | 101 | def forward(self, model, loss_fn, dl, dl_name, 102 | rank=0, pred_path=None, mb=None): 103 | fname = Path(pred_path) / f'{dl_name}_{rank}.pkl' 104 | # comm = self.comm 105 | # cfg = self.cfg 106 | model.eval() 107 | loss_keys = loss_fn.loss_keys 108 | val_losses = {k: [] for k in loss_keys} 109 | nums = [] 110 | results = [] 111 | for batch in progress_bar(dl, parent=mb): 
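# move every tensor in the batch to the eval device; the batch size (taken
# from the first key) is kept in `nums` for averaging the validation losses.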
112 | for b in batch.keys(): 113 | batch[b] = batch[b].to(self.device) 114 | b = next(iter(batch.keys())) 115 | nums.append(batch[b].size(0)) 116 | torch.cuda.empty_cache() 117 | with torch.no_grad(): 118 | out = model(batch) 119 | out_loss = loss_fn(out, batch) 120 | 121 | for k in out_loss: 122 | val_losses[k].append(out_loss[k].detach().cpu()) 123 | results += self.forward_one_batch(out, batch) 124 | 125 | pickle.dump(results, open(fname, 'wb')) 126 | nums = torch.tensor(nums).float() 127 | val_loss = compute_avg_dict(val_losses, nums) 128 | 129 | synchronize() 130 | if is_main_process(): 131 | curr_results = results 132 | world_size = get_world_size() 133 | for w in range(1, world_size): 134 | tmp_file = Path(pred_path) / f'{dl_name}_{w}.pkl' 135 | with open(tmp_file, 'rb') as f: 136 | tmp_results = pickle.load(f) 137 | curr_results += tmp_results 138 | tmp_file.unlink 139 | with open(fname, 'wb') as f: 140 | pickle.dump(curr_results, f) 141 | out_acc = self.grnd_eval.eval_ground_acc(fname) 142 | val_acc = {k: torch.tensor(v).to(self.device) 143 | for k, v in out_acc.items() if k in self.met_keys} 144 | # return val_loss, val_acc 145 | synchronize() 146 | if is_main_process(): 147 | return val_loss, val_acc 148 | else: 149 | return {k: torch.tensor(0.).to(self.device) for k in loss_keys}, { 150 | k: torch.tensor(0.).to(self.device) for k in self.met_keys} 151 | 152 | 153 | class EvaluatorSEP(Evaluator): 154 | def after_init(self): 155 | 156 | self.met_keys = ['avg1', 'avg1_cons', 157 | 'avg1_vidf', 'avg1_strict'] 158 | self.grnd_eval = GroundEval_SEP(self.cfg, self.comm) 159 | 160 | self.num_sampled_frm = self.num_frms 161 | 162 | def get_out_results_boxes(self, out_result_dict, inp): 163 | """ 164 | get the correct boxes, scores, indexes per frame 165 | """ 166 | assert isinstance(out_result_dict, dict) 167 | # B x num_cmp 168 | fin_scores = out_result_dict['fin_scores'] 169 | B, num_cmp = fin_scores.shape 170 | 171 | # B 172 | vidf_outs = torch.argmax(fin_scores, dim=-1) 173 | 174 | # B x num_cmp x num_srl_args x num_props 175 | mdl_outs = out_result_dict['mdl_outs_eval'] 176 | 177 | B, num_cmp, num_srl_args, num_props = mdl_outs.shape 178 | 179 | mdl_outs_reshaped = mdl_outs.transpose( 180 | 1, 2).contiguous().view( 181 | B, num_srl_args, num_cmp, 182 | self.num_sampled_frm, self.num_prop_per_frm 183 | ) 184 | 185 | # B x num_srl_args x num_cmp x num_frms 186 | out_result_frame_score, out_result_frame_index = torch.max( 187 | mdl_outs_reshaped, dim=-1 188 | ) 189 | 190 | props = inp['pad_proposals'] 191 | _, num_cmp, num_props, prop_dim = props.shape 192 | 193 | props_reshaped = props.view( 194 | B, 1, num_cmp, 195 | self.num_sampled_frm, self.num_prop_per_frm, prop_dim 196 | ).expand( 197 | B, num_srl_args, num_cmp, 198 | self.num_sampled_frm, self.num_prop_per_frm, prop_dim 199 | ) 200 | 201 | props_out = torch.gather( 202 | props_reshaped, 203 | dim=-2, 204 | index=out_result_frame_index.unsqueeze(-1).unsqueeze(-1).expand( 205 | B, num_srl_args, num_cmp, 206 | self.num_sampled_frm, 1, prop_dim 207 | ) 208 | ) 209 | 210 | props_out = props_out.squeeze(-2) 211 | 212 | # B -> B x #srl x #frms 213 | vidf_outs = vidf_outs.view(B, 1, 1).expand( 214 | B, num_srl_args, self.num_frms) 215 | 216 | return { 217 | 'boxes': props_out, 218 | 'scores': out_result_frame_score, 219 | 'indexs': vidf_outs 220 | } 221 | 222 | def forward_one_batch(self, out_result, inp): 223 | """ 224 | The following should be returned: 225 | List[Dict] 226 | Dict = { 227 | 'idx(video)', 'idx(srl)', 'idx(arg)', 228 
| 'pred_boxes', 'pred_scores' 229 | } 230 | """ 231 | out_results = self.get_out_results_boxes(out_result, inp) 232 | 233 | out_result_boxes = out_results['boxes'] 234 | out_result_frame_score = out_results['scores'] 235 | out_result_frame_index = out_results['indexs'] 236 | 237 | # B x num_srl_args x num_cmp x num_frms x num_props 238 | pred_boxes = out_result_boxes 239 | # B x num_srl_args x num_frms 240 | pred_cmp = out_result_frame_index 241 | # B x num_srl_args x num_cmp x num_frms 242 | pred_score = out_result_frame_score 243 | targ_cmp = inp['target_cmp'].detach().cpu().tolist() 244 | perm_list = inp['permute'].detach().cpu().tolist() 245 | perm_inv_list = inp['permute_inv'].detach().cpu().tolist() 246 | 247 | out_dict_list = [ 248 | { 249 | 'pred_boxes': pb, 250 | 'pred_scores': ps, 251 | 'pred_cmp': pc, 252 | 'idx_vid': an_ind, 253 | 'idx_verbs': srl_idxs, 254 | 'idx_sent': srl_ann, 255 | 'cmp_msk': cmp_msk, 256 | 'targ_cmp': tcmp, 257 | 'perm': perm, 258 | 'perm_inv': perm_inv, 259 | 260 | } for pb, ps, pc, an_ind, srl_idxs, srl_ann, cmp_msk, 261 | tcmp, perm, perm_inv in zip( 262 | pred_boxes.detach().cpu().tolist(), 263 | pred_score.detach().cpu().tolist(), 264 | pred_cmp.detach().cpu().tolist(), 265 | inp['ann_idx'].detach().cpu().tolist(), 266 | inp['new_srl_idxs'].detach().cpu().tolist(), 267 | inp['sent_idx'].detach().cpu().tolist(), 268 | inp['num_cmp_msk'].detach().cpu().tolist(), 269 | targ_cmp, 270 | perm_list, 271 | perm_inv_list 272 | )] 273 | return out_dict_list 274 | 275 | 276 | class EvaluatorTEMP(EvaluatorSEP): 277 | def after_init(self): 278 | # self.met_keys = ['avg1', 'macro_avg1', 'avg1_cons', 'macro_avg1_cons'] 279 | # self.grnd_eval = GroundEvalDS4(self.cfg, self.comm) 280 | 281 | self.met_keys = ['avg1', 'avg1_cons', 282 | 'avg1_vidf', 'avg1_strict'] 283 | self.grnd_eval = GroundEval_TEMP(self.cfg, self.comm) 284 | 285 | # self.num_sampled_frm = self.cfg.misc.num_sampled_frm 286 | self.num_sampled_frm = self.num_frms 287 | # self.num_prop_per_frm = self.cfg.misc.num_prop_per_frm 288 | 289 | def get_out_results_boxes(self, out_result_dict, inp): 290 | """ 291 | get the correct boxes, scores, indexes per frame 292 | """ 293 | assert isinstance(out_result_dict, dict) 294 | 295 | out_result = out_result_dict['mdl_outs_eval'] 296 | num_cmp = inp['new_srl_idxs'].size(1) 297 | 298 | # B x num_verbs x num_srl_args x 4000 299 | B, num_verbs, num_srl_args, num_props = out_result.shape 300 | 301 | assert num_verbs == 1 302 | # B x num_srl_args x num_props 303 | # mdl_outs = out_result.squeeze(1) 304 | mdl_outs_reshaped = out_result.view( 305 | B, num_srl_args, num_cmp, 306 | self.num_sampled_frm, self.num_prop_per_frm 307 | ) 308 | 309 | # B x num_srl_args x num_cmp x num_frms 310 | out_result_frame_score, out_result_frame_index = torch.max( 311 | mdl_outs_reshaped, dim=-1 312 | ) 313 | 314 | props = inp['pad_proposals'] 315 | 316 | _, num_props, prop_dim = props.shape 317 | assert (num_cmp * self.num_sampled_frm * 318 | self.num_prop_per_frm == num_props) 319 | props_reshaped = props.view( 320 | B, 1, num_cmp, 321 | self.num_sampled_frm, self.num_prop_per_frm, prop_dim 322 | ).expand( 323 | B, num_srl_args, num_cmp, 324 | self.num_sampled_frm, self.num_prop_per_frm, prop_dim 325 | ) 326 | 327 | props_out = torch.gather( 328 | props_reshaped, 329 | dim=-2, 330 | index=out_result_frame_index.unsqueeze(-1).unsqueeze(-1).expand( 331 | B, num_srl_args, num_cmp, 332 | self.num_sampled_frm, 1, prop_dim 333 | ) 334 | ) 335 | 336 | props_out = props_out.squeeze(-2) 337 | # 
Not used in temporal. Make it all zeros 338 | vidf_outs = torch.zeros(1, 1, 1).expand( 339 | B, num_srl_args, self.num_frms 340 | ) 341 | return { 342 | 'boxes': props_out, 343 | 'scores': out_result_frame_score, 344 | 'indexs': vidf_outs 345 | } 346 | 347 | 348 | class EvaluatorSPAT(EvaluatorSEP): 349 | def after_init(self): 350 | self.met_keys = ['avg1', 'avg1_cons', 'avg1_vidf', 'avg1_strict'] 351 | self.grnd_eval = GroundEval_SPAT(self.cfg, self.comm) 352 | 353 | self.num_sampled_frm = self.num_frms 354 | # self.num_sampled_frm = self.cfg.misc.num_sampled_frm 355 | # self.num_prop_per_frm = self.cfg.misc.num_prop_per_frm 356 | 357 | def get_out_results_boxes(self, out_result_dict, inp): 358 | """ 359 | get the correct boxes, scores, indexes per frame 360 | """ 361 | assert isinstance(out_result_dict, dict) 362 | 363 | out_result = out_result_dict['mdl_outs_eval'] 364 | num_cmp = inp['new_srl_idxs'].size(1) 365 | 366 | # B x num_verbs x num_srl_args x 4000 367 | B, num_verbs, num_srl_args, num_props = out_result.shape 368 | 369 | assert num_verbs == 1 370 | # B x num_srl_args x num_props 371 | mdl_outs_reshaped = out_result.view( 372 | B, num_srl_args, 373 | self.num_sampled_frm, num_cmp, self.num_prop_per_frm 374 | ) 375 | 376 | # B x num_srl_args x num_frm x num_cmp 377 | out_result_frame_score, out_result_frame_index = torch.max( 378 | mdl_outs_reshaped, dim=-1 379 | ) 380 | 381 | props = inp['pad_proposals'] 382 | 383 | _, num_props, prop_dim = props.shape 384 | assert (num_cmp * self.num_sampled_frm * 385 | self.num_prop_per_frm == num_props) 386 | props_reshaped = props.view( 387 | B, 1, self.num_sampled_frm, 388 | num_cmp, self.num_prop_per_frm, prop_dim 389 | ).expand( 390 | B, num_srl_args, self.num_sampled_frm, 391 | num_cmp, self.num_prop_per_frm, prop_dim 392 | ) 393 | 394 | props_out = torch.gather( 395 | props_reshaped, 396 | dim=-2, 397 | index=out_result_frame_index.unsqueeze(-1).unsqueeze(-1).expand( 398 | B, num_srl_args, self.num_sampled_frm, num_cmp, 1, prop_dim 399 | ) 400 | ) 401 | 402 | # B x num_srl x num_frms x num_cmp 403 | props_out = props_out.squeeze(-2) 404 | # For consistency across sep, temporal, spatial 405 | props_out = props_out.transpose(2, 3).contiguous() 406 | 407 | # Used in spatial. 
408 | # Divide by 100 409 | # vidf_outs = torch.div( 410 | # out_result_frame_index.squeeze(-1), 411 | # self.num_prop_per_frm 412 | # ).long() 413 | 414 | # B x num_srl_args x num_frm 415 | vidf_outs = out_result_frame_score.argmax(dim=-1) 416 | 417 | # B x num_srl_args x num_frm x num_cmp 418 | score_out = out_result_frame_score.transpose(2, 3).contiguous() 419 | 420 | return { 421 | 'boxes': props_out, 422 | 'scores': score_out, 423 | 'indexs': vidf_outs 424 | } 425 | -------------------------------------------------------------------------------- /code/extended_config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | # import json 3 | # import yaml 4 | from _init_stuff import yaml 5 | from typing import Dict, Any 6 | 7 | with open('./configs/anet_srl_cfg.yml') as f: 8 | def_cfg = yaml.safe_load(f) 9 | 10 | cfg = CN(def_cfg) 11 | cfg.comm = CN() 12 | 13 | key_maps = {} 14 | 15 | 16 | def create_from_dict(dct: Dict[str, Any], prefix: str, cfg: CN): 17 | """ 18 | Helper function to create yacs config from dictionary 19 | """ 20 | dct_cfg = CN(dct, new_allowed=True) 21 | prefix_list = prefix.split('.') 22 | d = cfg 23 | for pref in prefix_list[:-1]: 24 | assert isinstance(d, CN) 25 | if pref not in d: 26 | setattr(d, pref, CN()) 27 | d = d[pref] 28 | if hasattr(d, prefix_list[-1]): 29 | old_dct_cfg = d[prefix_list[-1]] 30 | dct_cfg.merge_from_other_cfg(old_dct_cfg) 31 | 32 | setattr(d, prefix_list[-1], dct_cfg) 33 | return cfg 34 | 35 | 36 | def update_from_dict(cfg: CN, dct: Dict[str, Any], 37 | key_maps: Dict[str, str] = None) -> CN: 38 | """ 39 | Given original CfgNode (cfg) and input dictionary allows changing 40 | the cfg with the updated dictionary values 41 | Optional key_maps argument which defines a mapping between 42 | same keys of the cfg node. 
Only used for convenience 43 | Adapted from: 44 | https://github.com/rbgirshick/yacs/blob/master/yacs/config.py#L219 45 | """ 46 | # Original cfg 47 | root = cfg 48 | if key_maps is None: 49 | key_maps = [] 50 | # Change the input dictionary using keymaps 51 | # Now it is aligned with the cfg 52 | full_key_list = list(dct.keys()) 53 | for full_key in full_key_list: 54 | if full_key in key_maps: 55 | cfg[full_key] = dct[full_key] 56 | new_key = key_maps[full_key] 57 | dct[new_key] = dct.pop(full_key) 58 | 59 | # Convert the cfg using dictionary input 60 | for full_key, v in dct.items(): 61 | if root.key_is_deprecated(full_key): 62 | continue 63 | if root.key_is_renamed(full_key): 64 | root.raise_key_rename_error(full_key) 65 | key_list = full_key.split(".") 66 | d = cfg 67 | for subkey in key_list[:-1]: 68 | # Most important statement 69 | assert subkey in d, f'key {full_key} doesnot exist' 70 | d = d[subkey] 71 | 72 | subkey = key_list[-1] 73 | # Most important statement 74 | assert subkey in d, f'key {full_key} doesnot exist' 75 | 76 | value = cfg._decode_cfg_value(v) 77 | 78 | assert isinstance(value, type(d[subkey])) 79 | d[subkey] = value 80 | 81 | return cfg 82 | 83 | 84 | def post_proc_config(cfg: CN): 85 | """ 86 | Add any post processing based on cfg 87 | """ 88 | return cfg 89 | -------------------------------------------------------------------------------- /code/main_dist.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main file for distributed training 3 | """ 4 | import sys 5 | # from dat_loader import get_data 6 | from dat_loader_simple import get_data 7 | from mdl_selector import get_mdl_loss_eval 8 | from trn_utils import Learner, synchronize 9 | 10 | import torch 11 | import fire 12 | from functools import partial 13 | 14 | from extended_config import ( 15 | cfg as conf, 16 | key_maps, 17 | CN, 18 | update_from_dict, 19 | post_proc_config 20 | ) 21 | 22 | import resource 23 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 24 | resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1])) 25 | 26 | 27 | def get_name_from_inst(inst): 28 | return inst.__class__.__name__ 29 | 30 | 31 | def learner_init(uid: str, cfg: CN) -> Learner: 32 | # = get_mdl_loss(cfg) 33 | mdl_loss_eval = get_mdl_loss_eval(cfg) 34 | get_default_net = mdl_loss_eval['mdl'] 35 | get_default_loss = mdl_loss_eval['loss'] 36 | get_default_eval = mdl_loss_eval['eval'] 37 | 38 | device = torch.device('cuda') 39 | # device = torch.device('cpu') 40 | data = get_data(cfg) 41 | comm = data.train_dl.dataset.comm 42 | mdl = get_default_net(cfg=cfg, comm=comm) 43 | 44 | # pretrained_state_dict = torch.load(cfg.pretrained_path) 45 | # to_load_state_dict = pretrained_state_dict 46 | # mdl.load_state_dict(to_load_state_dict) 47 | 48 | loss_fn = get_default_loss(cfg, comm) 49 | loss_fn.to(device) 50 | # if cfg.do_dist: 51 | # loss_fn.to(device) 52 | 53 | eval_fn = get_default_eval(cfg, comm, device) 54 | eval_fn.to(device) 55 | opt_fn = partial(torch.optim.Adam, betas=(0.9, 0.99)) 56 | 57 | # unfreeze cfg to save the names 58 | cfg.defrost() 59 | module_name = mdl 60 | cfg.mdl_data_names = CN({ 61 | 'trn_data': get_name_from_inst(data.train_dl.dataset), 62 | 'val_data': get_name_from_inst(data.valid_dl.dataset), 63 | 'trn_collator': get_name_from_inst(data.train_dl.collate_fn), 64 | 'val_collator': get_name_from_inst(data.valid_dl.collate_fn), 65 | 'mdl_name': get_name_from_inst(module_name), 66 | 'loss_name': get_name_from_inst(loss_fn), 67 | 'eval_name': 
get_name_from_inst(eval_fn), 68 | 'opt_name': opt_fn.func.__name__ 69 | }) 70 | cfg.freeze() 71 | 72 | learn = Learner(uid=uid, data=data, mdl=mdl, loss_fn=loss_fn, 73 | opt_fn=opt_fn, eval_fn=eval_fn, device=device, cfg=cfg) 74 | 75 | if cfg.do_dist: 76 | mdl.to(device) 77 | mdl = torch.nn.parallel.DistributedDataParallel( 78 | mdl, device_ids=[cfg.local_rank], 79 | output_device=cfg.local_rank, broadcast_buffers=True, 80 | find_unused_parameters=True) 81 | elif cfg.do_dp: 82 | # Use data parallel 83 | mdl = torch.nn.DataParallel(mdl) 84 | 85 | mdl = mdl.to(device) 86 | 87 | return learn 88 | 89 | 90 | def main_dist(uid: str, **kwargs): 91 | """ 92 | uid is a unique identifier for the experiment name 93 | Can be kept same as a previous run, by default will start executing 94 | from latest saved model 95 | **kwargs: allows arbit arguments of cfg to be changed 96 | """ 97 | cfg = conf 98 | num_gpus = torch.cuda.device_count() 99 | cfg.num_gpus = num_gpus 100 | cfg.uid = uid 101 | cfg.cmd = sys.argv 102 | if num_gpus > 1: 103 | if 'local_rank' in kwargs: 104 | # We are doing distributed parallel 105 | cfg.do_dist = True 106 | torch.cuda.set_device(kwargs['local_rank']) 107 | torch.distributed.init_process_group( 108 | backend="nccl", init_method="env://" 109 | ) 110 | synchronize() 111 | else: 112 | # We are doing data parallel 113 | cfg.do_dist = False 114 | # cfg.do_dp = True 115 | # Update the config file depending on the command line args 116 | cfg = update_from_dict(cfg, kwargs, key_maps) 117 | cfg = post_proc_config(cfg) 118 | # Freeze the cfg, can no longer be changed 119 | cfg.freeze() 120 | # print(cfg) 121 | # Initialize learner 122 | learn = learner_init(uid, cfg) 123 | # Train or Test 124 | if not (cfg.only_val or cfg.only_test or cfg.overfit_batch): 125 | learn.fit(epochs=cfg.train.epochs, lr=cfg.train.lr) 126 | if cfg.run_final_val: 127 | print('Running Final Validation using best model') 128 | learn.load_model_dict( 129 | resume_path=learn.model_file, 130 | load_opt=False 131 | ) 132 | val_loss, val_acc, _ = learn.validate( 133 | db={'valid': learn.data.valid_dl}, 134 | write_to_file=True 135 | ) 136 | print(val_loss) 137 | print(val_acc) 138 | else: 139 | pass 140 | else: 141 | if cfg.overfit_batch: 142 | learn.overfit_batch(1000, 1e-4) 143 | if cfg.only_val: 144 | val_loss, val_acc, _ = learn.validate( 145 | db={'valid': learn.data.valid_dl}, 146 | write_to_file=True 147 | ) 148 | print(val_loss) 149 | print(val_acc) 150 | # learn.testing(learn.data.valid_dl) 151 | pass 152 | if cfg.only_test: 153 | # learn.testing(learn.data.test_dl) 154 | test_loss, test_acc, _ = learn.validate( 155 | db=learn.data.test_dl) 156 | print(test_loss) 157 | print(test_acc) 158 | 159 | return 160 | 161 | 162 | if __name__ == '__main__': 163 | fire.Fire(main_dist) 164 | -------------------------------------------------------------------------------- /code/mdl_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base Model and Loss 3 | Other models build on top of this. 4 | Basically, have all the required args here. 
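Concrete models (e.g. in `mdl_vog.py`, chosen via `mdl_selector.py`) subclass this and
fill in `build_lang_model`, `build_vis_model` and `build_conc_model`, which raise
NotImplementedError here; `set_args` only pulls the shared sizes out of cfg / comm.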
5 | """ 6 | from torch import nn 7 | from munch import Munch 8 | 9 | 10 | class AnetBaseMdl(nn.Module): 11 | def __init__(self, cfg, comm): 12 | super().__init__() 13 | self.cfg = cfg 14 | # Common stuff that needs to be passed around 15 | if comm is not None: 16 | assert isinstance(comm, (dict, Munch)) 17 | self.comm = Munch(comm) 18 | else: 19 | self.comm = Munch() 20 | 21 | self.set_args() 22 | self.after_init() 23 | 24 | def after_init(self): 25 | self.build_model() 26 | 27 | def build_model(self): 28 | self.build_lang_model() 29 | self.build_vis_model() 30 | self.build_conc_model() 31 | 32 | def set_args(self): 33 | """ 34 | Place to set all the required arguments, taken from cfg 35 | """ 36 | # Vocab size needs to be in the ds 37 | # Can be added after after creation of the DATASET 38 | self.vocab_size = self.comm.vocab_size 39 | 40 | # Number of object classes 41 | # Also after creation of dataset. 42 | # Perhaps a good idea to keep all stuff 43 | # to be passed from ds to mdl in a separate 44 | # argument. Could be really helpful 45 | self.detect_size = self.comm.detect_size 46 | 47 | # Input encoding size 48 | # This is the size of the embedding for each word 49 | self.input_encoding_size = self.cfg.mdl.input_encoding_size 50 | 51 | # Hidden dimension of RNN 52 | self.rnn_size = self.cfg.mdl.rnn.rnn_size 53 | 54 | # Number of layers in RNN 55 | self.num_layers = self.cfg.mdl.rnn.num_layers 56 | 57 | # Dropout probability of LM 58 | self.drop_prob_lm = self.cfg.mdl.rnn.drop_prob_lm 59 | 60 | # itod 61 | self.itod = self.comm.itod 62 | 63 | self.num_sampled_frm = self.cfg.ds.num_sampled_frm 64 | self.num_prop_per_frm = self.comm.num_prop_per_frm 65 | 66 | self.unk_idx = int(self.comm.wtoi['UNK']) 67 | 68 | # Temporal attention size 69 | self.t_attn_size = self.cfg.ds.t_attn_size 70 | 71 | # srl_arg_len 72 | self.srl_arg_len = self.cfg.misc.srl_arg_length 73 | 74 | self.set_args_mdl() 75 | self.set_args_conc() 76 | 77 | def set_args_mdl(self): 78 | """ 79 | Mdl specific args 80 | """ 81 | return 82 | 83 | def set_args_conc(self): 84 | """ 85 | Conc Type specific args 86 | """ 87 | return 88 | 89 | def build_lang_model(self): 90 | """ 91 | How to encode the input sentence 92 | """ 93 | raise NotImplementedError 94 | 95 | def build_vis_model(self): 96 | """ 97 | How to encode the visual features 98 | How to encode proposal features 99 | and rgb, motion features 100 | """ 101 | raise NotImplementedError 102 | 103 | def build_conc_model(self): 104 | """ 105 | How to concatenate language and visual features 106 | """ 107 | raise NotImplementedError 108 | 109 | 110 | def main(): 111 | from _init_stuff import Fpath, Arr, yaml 112 | from yacs.config import CfgNode as CN 113 | from dat_loader_simple import get_data 114 | cfg = CN(yaml.safe_load(open('./configs/anet_srl_cfg.yml'))) 115 | data = get_data(cfg) 116 | comm = data.train_dl.dataset.comm 117 | mdl = AnetBaseMdl(cfg, comm) 118 | return mdl 119 | 120 | 121 | if __name__ == '__main__': 122 | main() 123 | -------------------------------------------------------------------------------- /code/mdl_conc_sep.py: -------------------------------------------------------------------------------- 1 | """ 2 | Take care of SEP case. 
3 | """ 4 | 5 | from mdl_conc_single import ConcBase 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from mdl_srl_utils import combine_first_ax 10 | from box_utils import bbox_overlaps 11 | 12 | 13 | class ConcSEP(ConcBase): 14 | def conc_encode(self, conc_feats, inp): 15 | nfrm = self.num_sampled_frm 16 | nppf = self.num_prop_per_frm 17 | ncmp = inp['new_srl_idxs'].size(1) 18 | return self.conc_encode_item(conc_feats, inp, nfrm, nppf, ncmp) 19 | 20 | def simple_obj_interact_input(self, prop_seg_feats, inp): 21 | B, num_cmp, num_props, psdim = prop_seg_feats.shape 22 | return self.simple_obj_interact( 23 | prop_seg_feats, inp, 24 | num_cmp, self.num_sampled_frm, 25 | self.num_prop_per_frm 26 | ) 27 | 28 | def set_args_conc(self): 29 | self.nfrms = self.num_sampled_frm 30 | self.nppf = self.num_prop_per_frm 31 | 32 | def get_num_cmp_msk(self, inp, out_shape): 33 | num_cmp = inp['new_srl_idxs'].size(1) 34 | B, num_verbs, num_srl_args, seq_len = inp['srl_arg_words_ind'].shape 35 | num_cmp_msk = inp['num_cmp_msk'].view( 36 | B, num_cmp, 1, 1 37 | ).expand( 38 | B, num_cmp, num_srl_args, 39 | self.num_sampled_frm * self.num_prop_per_frm 40 | ).contiguous( 41 | ).view(*out_shape) 42 | return num_cmp_msk 43 | 44 | def concat_prop_seg_feats(self, prop_feats, seg_feats, inp): 45 | B, num_cmp, num_props, pdim = prop_feats.shape 46 | prop_seg_feats = torch.cat( 47 | [ 48 | prop_feats.view( 49 | B, num_cmp, self.num_sampled_frm, 50 | self.num_prop_per_frm, prop_feats.size(-1) 51 | ), 52 | seg_feats.unsqueeze(-2).expand( 53 | B, num_cmp, self.num_sampled_frm, 54 | self.num_prop_per_frm, seg_feats.size(-1) 55 | ) 56 | ], dim=-1 57 | ).view( 58 | B, num_cmp, self.num_sampled_frm*self.num_prop_per_frm, 59 | prop_feats.size(-1) + seg_feats.size(-1) 60 | ) 61 | # B x num_cmp x nfrm*nppf x psdim 62 | return prop_seg_feats 63 | 64 | def compute_fin_scores(self, conc_out_dict, inp, vidf_outs=None): 65 | """ 66 | output fin scores should be of shape 67 | B x num_cmp 68 | prop_scores: B x num_cmp x num_srl_args x num_props 69 | """ 70 | prop_scores1 = conc_out_dict['conc_feats_out'].clone().detach() 71 | prop_scores = torch.sigmoid(prop_scores1) 72 | # prop_scores = prop_scores1 73 | if self.cfg.mdl.use_vis_msk: 74 | # B x num_cmp x num_srl_args 75 | prop_scores_max_boxes, _ = torch.max(prop_scores, dim=-1) 76 | 77 | # B x num_cmp x num_srl_args 78 | srl_arg_inds_msk = inp['srl_arg_inds_msk'].float() 79 | B, num_verbs, num_srl_args = srl_arg_inds_msk.shape 80 | 81 | num_cmp = prop_scores.size(1) 82 | 83 | if vidf_outs is not None: 84 | # add vidf outs to the verb places 85 | vidf_outs = torch.sigmoid(vidf_outs) 86 | 87 | # B x num_cmp -> B x num_cmp x num_srl_args 88 | vidf_outs = vidf_outs.unsqueeze(-1).expand( 89 | *prop_scores_max_boxes.shape 90 | ) 91 | vmsk = inp['verb_ind_in_srl'] 92 | 93 | if vmsk.size(1) == 1 and num_cmp > 1: 94 | vmsk = vmsk.expand(-1, num_cmp) 95 | # B x num_cmp 96 | vmsk = vmsk.view( 97 | B, num_cmp, 1).expand( 98 | B, num_cmp, num_srl_args 99 | ) 100 | prop_scores_max_boxes.scatter_( 101 | dim=2, 102 | index=vmsk, 103 | src=vidf_outs 104 | ) 105 | 106 | prop_scores_max_boxes = prop_scores_max_boxes * srl_arg_inds_msk 107 | 108 | # b x num_cmp 109 | fin_scores_eval = prop_scores_max_boxes.sum( 110 | dim=-1) / srl_arg_inds_msk.sum(dim=-1) 111 | 112 | verb_msk = inp['num_cmp_msk'] 113 | fin_scores_eval = fin_scores_eval * verb_msk.float() 114 | 115 | fin_scores_loss = prop_scores_max_boxes * verb_msk.unsqueeze( 116 | 
-1).expand(*prop_scores_max_boxes.shape).float() 117 | return { 118 | # B x num_cmp 119 | 'fin_scores_eval': fin_scores_eval, 120 | # B x num_cmp x num_srl_args 121 | 'fin_scores_loss': fin_scores_loss 122 | } 123 | 124 | else: 125 | # B x num_cmp x num_cmp x num_srl_args 126 | prop_scores_max_boxes, _ = torch.max(prop_scores, dim=-1) 127 | # B x num_cmp x num_cmp 128 | fin_scores = prop_scores_max_boxes.sum(dim=-1) 129 | return fin_scores 130 | 131 | def forward(self, inp): 132 | """ 133 | Main difference is that prop feats/seg features 134 | have an extra dimension 135 | """ 136 | # B x 6 x 5 x 40 137 | # 6 is num_cmp for a sent 138 | # 5 is num args in a sent 139 | # 40 is seq length for each arg 140 | B, num_verbs, num_srl_args, seq_len = inp['srl_arg_words_ind'].shape 141 | # B*num_cmp x seq_len 142 | src_toks = self.get_srl_arg_seq_to_sent_seq(inp) 143 | # B*num_cmp x seq_len 144 | src_lens = inp['srl_arg_word_mask_len'].view(B*num_verbs, -1) 145 | # B*num_cmp x seq_len x 256 146 | lstm_outs = self.lang_encode(src_toks, src_lens) 147 | lstm_encoded = lstm_outs['lstm_full_output'] 148 | 149 | # B x num_cmp x 5 x 512 150 | srl_arg_lstm_encoded = self.retrieve_srl_arg_from_lang_encode( 151 | lstm_encoded, inp 152 | ) 153 | 154 | # Get visual features 155 | # B x num_cmp x 1000 x 512 156 | prop_feats = self.prop_feats_encode(inp) 157 | # B, num_cmp, num_props, pdim = prop_feats.shape 158 | 159 | # Get seg features 160 | # B x num_cmp x 10 x 512 161 | seg_feats = self.seg_feats_encode(inp) 162 | 163 | # B x num_cmp x nfrm*nppf x psdim 164 | prop_seg_feats = self.concat_prop_seg_feats(prop_feats, seg_feats, inp) 165 | 166 | prop_seg_feats = self.simple_obj_interact_input( 167 | prop_seg_feats, inp 168 | ) 169 | 170 | num_cmp = inp['new_srl_idxs'].size(1) 171 | if srl_arg_lstm_encoded.size(1) == 1 and num_cmp > 1: 172 | srl_arg_lstm_encoded = srl_arg_lstm_encoded.expand( 173 | -1, num_cmp, -1, -1 174 | ) 175 | 176 | conc_feats = self.concate_vis_lang_feats( 177 | prop_seg_feats, srl_arg_lstm_encoded 178 | ) 179 | 180 | # B x num_cmp x num_srl_args x num_props 181 | conc_feats_out_dict = self.conc_encode(conc_feats, inp) 182 | conc_feats_out = conc_feats_out_dict['conc_feats_out'] 183 | 184 | seg_feats_for_verb, verb_feats = self.get_seg_verb_feats_to_process( 185 | seg_feats, srl_arg_lstm_encoded, lstm_outs, inp 186 | ) 187 | 188 | if verb_feats.size(1) == 1 and num_cmp > 1: 189 | verb_feats = verb_feats.expand(-1, num_cmp, -1) 190 | 191 | # B x num_cmp 192 | vidf_outs = self.compute_seg_verb_feats_out( 193 | seg_feats_for_verb, verb_feats 194 | ) 195 | fin_scores = self.compute_fin_scores( 196 | conc_feats_out_dict, inp, vidf_outs 197 | ) 198 | 199 | num_cmp_msk = self.get_num_cmp_msk(inp, conc_feats_out.shape) 200 | 201 | srl_ind_msk = inp['srl_arg_inds_msk'] 202 | if srl_ind_msk.size(1) == 1 and num_cmp > 1: 203 | srl_ind_msk = srl_ind_msk.expand( 204 | -1, num_cmp, -1, -1 205 | ) 206 | srl_ind_msk = srl_ind_msk.unsqueeze(-1).expand( 207 | *conc_feats_out.shape) 208 | mdl_outs_eval = torch.sigmoid( 209 | conc_feats_out) * srl_ind_msk.float() * num_cmp_msk.float() 210 | 211 | return { 212 | 'mdl_outs': conc_feats_out, 213 | 'mdl_outs_eval': mdl_outs_eval, 214 | 'vidf_outs': vidf_outs, 215 | 'fin_scores_loss': fin_scores['fin_scores_loss'], 216 | 'fin_scores': fin_scores['fin_scores_eval'] 217 | } 218 | 219 | 220 | class LossB_SEP(nn.Module): 221 | """ 222 | Loss Function (for a batch) for SEP case. 
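Targets are computed from IoU overlaps between proposals and
ground-truth boxes, keeping only the overlaps of the target video
(see compute_loss_targets).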
223 | Specifically, we need to have a separate verb loss 224 | Also, handling of some functions is different 225 | from single video case 226 | """ 227 | 228 | def __init__(self, cfg, comm): 229 | super().__init__() 230 | self.cfg = cfg 231 | self.comm = comm 232 | self.loss_keys = ['loss', 'mdl_out_loss', 'verb_loss'] 233 | self.loss_lambda = self.cfg.loss.loss_lambda 234 | self.after_init() 235 | 236 | def after_init(self): 237 | pass 238 | 239 | def get_targets_from_overlaps(self, overlaps, inp): 240 | """ 241 | Use the given overlaps to produce the targets 242 | overlaps: B x num_cmp x 1000 x 100 243 | """ 244 | targets = overlaps 245 | 246 | srl_boxes = inp['srl_boxes'] 247 | B, num_verbs, num_srl_args, num_box_per_srl = srl_boxes.shape 248 | B, num_cmp, num_props, num_gt_box = targets.shape 249 | 250 | if num_verbs == 1 and num_cmp > 1: 251 | srl_boxes = srl_boxes.expand(-1, num_cmp, -1, -1) 252 | 253 | srl_boxes_reshaped = srl_boxes.view( 254 | B, num_cmp, num_srl_args, 1, num_box_per_srl).expand( 255 | B, num_cmp, num_srl_args, num_props, num_box_per_srl) 256 | 257 | targets_reshaped = targets.view( 258 | B, num_cmp, 1, num_props, num_gt_box).expand( 259 | B, num_cmp, num_srl_args, num_props, num_gt_box) 260 | 261 | # Choose only those proposals which are ground-truth 262 | # for given srl 263 | targets_to_use = torch.gather( 264 | targets_reshaped, dim=-1, index=srl_boxes_reshaped) 265 | 266 | srl_boxes_lens = inp['srl_boxes_lens'] 267 | targets_to_use = ( 268 | targets_to_use * srl_boxes_lens.float().unsqueeze( 269 | -2).expand(*targets_to_use.shape) 270 | ) 271 | 272 | targets_to_use = targets_to_use.max(dim=-1)[0] > 0.5 273 | 274 | return targets_to_use 275 | 276 | def compute_overlaps(self, inp): 277 | 278 | pad_props = inp['pad_proposals'] 279 | gt_bboxs = inp['pad_gt_bboxs'] 280 | frm_msk = inp['pad_frm_mask'] 281 | pnt_msk = inp['pad_pnt_mask'] 282 | 283 | assert len(pnt_msk.shape) == 3 284 | 285 | B = pad_props.size(0) 286 | num_cmp = pad_props.size(1) 287 | pad_props = combine_first_ax(pad_props) 288 | gt_bboxs = combine_first_ax(gt_bboxs) 289 | frm_msk = combine_first_ax(frm_msk) 290 | 291 | pnt_msk = combine_first_ax(pnt_msk) 292 | 293 | overlaps = bbox_overlaps( 294 | pad_props, gt_bboxs, 295 | (frm_msk | pnt_msk[:, :].unsqueeze(-1))) 296 | overlaps = overlaps.view(B, num_cmp, *overlaps.shape[1:]) 297 | 298 | return overlaps 299 | 300 | def compute_loss_targets(self, inp): 301 | """ 302 | Compute the targets, based on iou 303 | overlaps 304 | """ 305 | overlaps = self.compute_overlaps(inp) 306 | B, ncmp, nprop, ngt = overlaps.shape 307 | overlaps_msk = overlaps.new_zeros(*overlaps.shape) 308 | 309 | targ_cmp = inp['target_cmp'] 310 | # overlaps_msk[:, targ_cmp, ...] 
= 1 311 | overlaps_msk.scatter_( 312 | dim=1, 313 | index=targ_cmp.view(B, 1, 1, 1).expand(B, ncmp, nprop, ngt), 314 | src=overlaps_msk.new_ones(*overlaps_msk.shape) 315 | ) 316 | 317 | overlaps_one_targ = overlaps * overlaps_msk 318 | 319 | targets_one = self.get_targets_from_overlaps(overlaps_one_targ, inp) 320 | targets_all = self.get_targets_from_overlaps(overlaps, inp) 321 | return { 322 | 'targets_one': targets_one, 323 | 'targets_all': targets_all 324 | } 325 | 326 | def compute_mdl_loss(self, mdl_outs, targets_one, inp): 327 | weights = None 328 | tot_loss = F.binary_cross_entropy_with_logits( 329 | mdl_outs, target=targets_one.float(), 330 | weight=weights, 331 | reduction='none' 332 | ) 333 | 334 | # B x num_cmp 335 | num_cmp_msk = inp['num_cmp_msk'] 336 | num_cmp = num_cmp_msk.size(1) 337 | srl_arg_boxes_mask = inp['srl_arg_boxes_mask'] 338 | num_verbs = srl_arg_boxes_mask.size(1) 339 | if num_verbs == 1 and num_cmp > 1: 340 | srl_arg_boxes_mask = srl_arg_boxes_mask.expand(-1, num_cmp, -1) 341 | 342 | B, num_cmp, num_srl_args = srl_arg_boxes_mask.shape 343 | 344 | boxes_msk = num_cmp_msk.unsqueeze( 345 | -1).expand(*srl_arg_boxes_mask.shape).float() 346 | 347 | # B x num_cmp x num_srl_args -> B x num_cmp x num_srl x 1000 348 | boxes_msk = boxes_msk.unsqueeze( 349 | -1).expand(*targets_one.shape) 350 | 351 | tot_loss = tot_loss * boxes_msk 352 | 353 | multiplier = tot_loss.size(-1) 354 | if srl_arg_boxes_mask.max() > 0: 355 | out_loss = torch.masked_select(tot_loss, boxes_msk.byte()) 356 | else: 357 | # TODO: NEED TO check what is wrong here 358 | out_loss = tot_loss 359 | 360 | mdl_out_loss = out_loss.mean() * multiplier 361 | 362 | return mdl_out_loss 363 | 364 | def compute_vidf_loss_simple(self, vidf_outs, inp): 365 | """ 366 | vidf_outs are fin scores: B x ncmp x nfrms 367 | """ 368 | B, ncmp, nfrm = vidf_outs.shape 369 | targs = vidf_outs.new_zeros(*vidf_outs.shape) 370 | 371 | targ_cmp = inp['target_cmp'] 372 | 373 | targs.scatter_( 374 | dim=1, 375 | index=targ_cmp.view(B, 1, 1).expand(B, ncmp, nfrm), 376 | src=targs.new_ones(*targs.shape) 377 | ) 378 | 379 | # B x ncmp x nfrms 380 | out_loss = F.binary_cross_entropy(vidf_outs, targs, reduction='none') 381 | 382 | mult = 1. 
/ nfrm 383 | 384 | # B x ncmp 385 | msk = inp['num_cmp_msk'] 386 | out_loss = torch.masked_select(out_loss.sum(dim=-1) * msk.float(), 387 | msk.byte()) * mult 388 | return out_loss.mean() 389 | 390 | def compute_vidf_loss(self, vidf_outs, inp): 391 | B, num_cmp, num_srl_args = vidf_outs.shape 392 | box_msk = inp['srl_arg_boxes_mask'] 393 | srl_arg_ind_msk = inp['srl_arg_inds_msk'] 394 | vidf_outs = ((vidf_outs * box_msk.float()).sum(dim=-1) / 395 | srl_arg_ind_msk.sum(dim=-1).float()) 396 | vidf_targs = vidf_outs.new_zeros(*vidf_outs.shape) 397 | 398 | targ_cmp = inp['target_cmp'] 399 | 400 | vidf_targs.scatter_( 401 | dim=1, 402 | index=targ_cmp.unsqueeze(-1).expand(*vidf_targs.shape), 403 | src=vidf_targs.new_ones(*vidf_targs.shape) 404 | ) 405 | 406 | vidf_loss = F.binary_cross_entropy( # 407 | vidf_outs, vidf_targs, 408 | reduction='none' 409 | ) 410 | msk = inp['num_cmp_msk'] 411 | vidf_loss = vidf_loss * msk.float() 412 | vidf_loss = torch.masked_select(vidf_loss, msk.byte()) 413 | return vidf_loss.mean() 414 | 415 | def forward(self, out, inp): 416 | targets_all = self.compute_loss_targets(inp) 417 | targets_n = targets_all['targets_one'] 418 | 419 | mdl_outs = out['mdl_outs'] 420 | 421 | mdl_out_loss = self.compute_mdl_loss(mdl_outs, targets_n, inp) 422 | 423 | verb_outs = out['vidf_outs'] 424 | 425 | verb_loss = F.binary_cross_entropy_with_logits( 426 | verb_outs, 427 | inp['verb_cmp'].float(), 428 | reduction='none' 429 | ) 430 | 431 | vcc_msk = inp['verb_cross_cmp_msk'].float() 432 | vcc_msk = (vcc_msk.sum(dim=-1) > 0).float() 433 | 434 | verb_loss = verb_loss * vcc_msk 435 | verb_loss = torch.masked_select( 436 | verb_loss, vcc_msk.byte()).mean() 437 | 438 | # out_loss = mdl_out_loss + verb_loss 439 | out_loss = mdl_out_loss 440 | 441 | out_loss_dict = { 442 | 'loss': out_loss, 443 | 'mdl_out_loss': mdl_out_loss, 444 | 'verb_loss': verb_loss 445 | } 446 | 447 | return {k: v * self.loss_lambda for k, v in out_loss_dict.items()} 448 | -------------------------------------------------------------------------------- /code/mdl_conc_single.py: -------------------------------------------------------------------------------- 1 | """ 2 | Concatenate to a Single Video 3 | """ 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from box_utils import bbox_overlaps 8 | 9 | 10 | class ConcBase(nn.Module): 11 | """ 12 | Base Model for concatenation. 
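Subclasses (SEP / TEMP / SPAT) differ mainly in how proposal and
segment features of the sampled videos are arranged before fusion
with the language features.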
13 | Kept for Historical Reasons 14 | """ 15 | 16 | def set_args_conc(self): 17 | """ 18 | Conc Type specific args 19 | """ 20 | return 21 | 22 | 23 | class ConcTEMP(ConcBase): 24 | def conc_encode(self, conc_feats, inp): 25 | ncmp = inp['new_srl_idxs'].size(1) 26 | nfrm = ncmp * self.num_sampled_frm 27 | nppf = self.num_prop_per_frm 28 | return self.conc_encode_item(conc_feats, inp, nfrm, nppf, 1) 29 | 30 | def simple_obj_interact_input(self, prop_seg_feats, inp): 31 | # B, num_cmp, num_props, psdim = prop_seg_feats.shape 32 | num_cmp = inp['new_srl_idxs'].size(1) 33 | return self.simple_obj_interact( 34 | prop_seg_feats, inp, 1, 35 | num_cmp * self.num_sampled_frm, 36 | self.num_prop_per_frm 37 | ) 38 | 39 | def get_num_cmp_msk(self, inp, out_shape): 40 | num_cmp = inp['new_srl_idxs'].size(1) 41 | B, num_verbs, num_srl_args, seq_len = inp['srl_arg_words_ind'].shape 42 | num_cmp_msk = inp['num_cmp_msk'].view( 43 | B, 1, 1, num_cmp, 1 44 | ).expand( 45 | B, num_verbs, num_srl_args, num_cmp, 46 | self.num_sampled_frm * self.num_prop_per_frm 47 | ).contiguous( 48 | ).view(*out_shape) 49 | return num_cmp_msk 50 | 51 | def concat_prop_seg_feats(self, prop_feats, seg_feats, inp): 52 | B, num_v_frms, sdim = seg_feats.shape 53 | num_cmp = inp['new_srl_idxs'].size(1) 54 | 55 | prop_seg_feats = torch.cat( 56 | [prop_feats.view( 57 | B, 1, num_cmp * self.num_sampled_frm, 58 | self.num_prop_per_frm, prop_feats.size(-1)), 59 | seg_feats.view(B, 1, num_v_frms, 1, sdim).expand( 60 | B, 1, num_cmp * self.num_sampled_frm, 61 | self.num_prop_per_frm, sdim) 62 | ], dim=-1).view( 63 | B, 1, num_cmp * self.num_sampled_frm * self.num_prop_per_frm, 64 | prop_feats.size(-1) + seg_feats.size(-1) 65 | ) 66 | return prop_seg_feats 67 | 68 | def forward(self, inp): 69 | """ 70 | Main difference is that prop feats/seg features 71 | have an extra dimension 72 | """ 73 | # B x 6 x 5 x 40 74 | # 6 is num_cmp for a sent 75 | # 5 is num args in a sent 76 | # 40 is seq length for each arg 77 | B, num_verbs, num_srl_args, seq_len = inp['srl_arg_words_ind'].shape 78 | # B*num_cmp x seq_len 79 | src_toks = self.get_srl_arg_seq_to_sent_seq(inp) 80 | # B*num_cmp x seq_len 81 | src_lens = inp['srl_arg_word_mask_len'].view(B*num_verbs, -1) 82 | # B*num_cmp x seq_len x 256 83 | lstm_outs = self.lang_encode(src_toks, src_lens) 84 | lstm_encoded = lstm_outs['lstm_full_output'] 85 | 86 | # B x 1 x 5 x 512 87 | srl_arg_lstm_encoded = self.retrieve_srl_arg_from_lang_encode( 88 | lstm_encoded, inp 89 | ) 90 | 91 | # Get visual features 92 | # B x 40*100 x 512 93 | prop_feats = self.prop_feats_encode(inp) 94 | 95 | # Get seg features 96 | # B x 40 x 512 97 | seg_feats = self.seg_feats_encode(inp) 98 | B, num_v_frms, sdim = seg_feats.shape 99 | # Simple conc seg_feats 100 | prop_seg_feats = self.concat_prop_seg_feats( 101 | prop_feats, seg_feats, inp 102 | ) 103 | 104 | # Object Interaction if to be done 105 | prop_seg_feats = self.simple_obj_interact_input( 106 | prop_seg_feats, inp 107 | ) 108 | 109 | # B x 1 x num_srl_args x 4*num_props x vf+lf dim 110 | conc_feats = self.concate_vis_lang_feats( 111 | prop_seg_feats, srl_arg_lstm_encoded 112 | ) 113 | 114 | # B x num_cmp x num_srl_args x 4*num_props x vf+lf dim 115 | conc_feats_out_dict = self.conc_encode(conc_feats, inp) 116 | conc_feats_out = conc_feats_out_dict['conc_feats_out'] 117 | 118 | num_cmp_msk = self.get_num_cmp_msk(inp, conc_feats_out.shape) 119 | srl_ind_msk = inp['srl_arg_inds_msk'].unsqueeze(-1).expand( 120 | *conc_feats_out.shape) 121 | conc_feats_out_eval = 
torch.sigmoid( 122 | conc_feats_out) * srl_ind_msk.float() * num_cmp_msk.float() 123 | 124 | return { 125 | 'mdl_outs': conc_feats_out, 126 | 'mdl_outs_eval': conc_feats_out_eval, 127 | } 128 | 129 | 130 | class ConcSPAT(ConcTEMP): 131 | def conc_encode(self, conc_feats, inp): 132 | ncmp = inp['new_srl_idxs'].size(1) 133 | nfrm = self.num_sampled_frm 134 | nppf = ncmp * self.num_prop_per_frm 135 | return self.conc_encode_item(conc_feats, inp, nfrm, nppf, 1) 136 | 137 | def simple_obj_interact_input(self, prop_seg_feats, inp): 138 | num_cmp = inp['new_srl_idxs'].size(1) 139 | return self.simple_obj_interact( 140 | prop_seg_feats, inp, 1, 141 | self.num_sampled_frm, num_cmp * self.num_prop_per_frm 142 | ) 143 | 144 | def get_num_cmp_msk(self, inp, out_shape): 145 | num_cmp = inp['new_srl_idxs'].size(1) 146 | B, num_verbs, num_srl_args, seq_len = inp['srl_arg_words_ind'].shape 147 | num_cmp_msk = inp['num_cmp_msk'].view( 148 | B, 1, 1, 1, num_cmp, 1 149 | ).expand( 150 | B, num_verbs, num_srl_args, self.num_sampled_frm, 151 | num_cmp, self.num_prop_per_frm 152 | ).contiguous( 153 | ).view(*out_shape) 154 | return num_cmp_msk 155 | 156 | def concat_prop_seg_feats(self, prop_feats, seg_feats, inp): 157 | B, num_v_frms, sdim = seg_feats.shape 158 | num_cmp = inp['new_srl_idxs'].size(1) 159 | prop_seg_feats = torch.cat( 160 | [ 161 | prop_feats.view( 162 | B, 1, self.num_sampled_frm * num_cmp, 163 | self.num_prop_per_frm, prop_feats.size(-1) 164 | ), seg_feats.view(B, 1, num_v_frms, 1, sdim).expand( 165 | B, 1, self.num_sampled_frm * num_cmp, 166 | self.num_prop_per_frm, sdim 167 | ) 168 | ], 169 | dim=-1 170 | ).view( 171 | B, 1, self.num_sampled_frm * num_cmp * self.num_prop_per_frm, 172 | prop_feats.size(-1) + seg_feats.size(-1) 173 | ) 174 | return prop_seg_feats 175 | 176 | def forward(self, inp): 177 | return ConcTEMP.forward(self, inp) 178 | 179 | 180 | class LossB_TEMP(nn.Module): 181 | def __init__(self, cfg, comm): 182 | super().__init__() 183 | self.cfg = cfg 184 | self.comm = comm 185 | self.loss_keys = ['loss', 'mdl_out_loss'] 186 | self.loss_lambda = self.cfg.loss.loss_lambda 187 | self.after_init() 188 | 189 | def after_init(self): 190 | pass 191 | 192 | def get_targets_from_overlaps(self, overlaps, inp): 193 | """ 194 | Use the given overlaps to produce the targets 195 | overlaps: B x num_cmp x 1000 x 100 196 | """ 197 | # to_consider = overlaps > 0.5 198 | targets = overlaps 199 | 200 | srl_boxes = inp['srl_boxes'] 201 | # B, num_cmp, num_srl_args, num_box_per_srl = srl_boxes.shape 202 | B, num_verbs, num_srl_args, num_box_per_srl = srl_boxes.shape 203 | B, num_props, num_gt_box = targets.shape 204 | 205 | srl_boxes_reshaped = srl_boxes.view( 206 | B, num_verbs, num_srl_args, 1, num_box_per_srl).expand( 207 | B, num_verbs, num_srl_args, num_props, num_box_per_srl) 208 | 209 | targets_reshaped = targets.view( 210 | B, 1, 1, num_props, num_gt_box).expand( 211 | B, num_verbs, num_srl_args, num_props, num_gt_box) 212 | 213 | # Choose only those proposals which are ground-truth 214 | # for given srl 215 | targets_to_use = torch.gather( 216 | targets_reshaped, dim=-1, index=srl_boxes_reshaped) 217 | 218 | srl_boxes_lens = inp['srl_boxes_lens'] 219 | targets_to_use = ( 220 | targets_to_use * srl_boxes_lens.float().unsqueeze( 221 | -2).expand(*targets_to_use.shape) 222 | ) 223 | 224 | targets_to_use = targets_to_use.max(dim=-1)[0] > 0.5 225 | 226 | return targets_to_use 227 | 228 | def compute_overlaps(self, inp): 229 | 230 | pad_props = inp['pad_proposals'] 231 | gt_bboxs = 
inp['pad_gt_bboxs'] 232 | frm_msk = inp['pad_frm_mask'] 233 | pnt_msk = inp['pad_pnt_mask'] 234 | 235 | try: 236 | overlaps = bbox_overlaps( 237 | pad_props, gt_bboxs, 238 | (frm_msk | pnt_msk.unsqueeze(-1)) 239 | ) 240 | except: 241 | import pdb 242 | pdb.set_trace() 243 | overlaps = bbox_overlaps( 244 | pad_props, gt_bboxs, 245 | (frm_msk | pnt_msk.unsqueeze(-1))) 246 | 247 | return overlaps 248 | 249 | def compute_loss_targets(self, inp): 250 | """ 251 | Compute the targets, based on iou 252 | overlaps 253 | """ 254 | num_cmp = inp['new_srl_idxs'].size(1) 255 | overlaps = self.compute_overlaps(inp) 256 | B, num_tot_props, num_gt = overlaps.shape 257 | assert num_tot_props % num_cmp == 0 258 | num_props = num_tot_props // num_cmp 259 | overlaps_msk = overlaps.new_zeros(B, num_cmp, num_props, num_gt) 260 | 261 | targ_cmp = inp['target_cmp'] 262 | 263 | overlaps_msk.scatter_( 264 | dim=1, 265 | index=targ_cmp.view(B, 1, 1, 1).expand( 266 | B, num_cmp, num_props, num_gt), 267 | src=overlaps_msk.new_ones(*overlaps_msk.shape) 268 | ) 269 | 270 | overlaps_msk = overlaps_msk.view(B, num_tot_props, num_gt) 271 | overlaps_one_targ = overlaps * overlaps_msk 272 | targets_one = self.get_targets_from_overlaps(overlaps_one_targ, inp) 273 | return { 274 | 'targets_one': targets_one, 275 | } 276 | 277 | def compute_mdl_loss(self, mdl_outs, targets_one, inp): 278 | weights = None 279 | tot_loss = F.binary_cross_entropy_with_logits( 280 | mdl_outs, target=targets_one.float(), 281 | weight=weights, 282 | reduction='none' 283 | ) 284 | 285 | num_cmp_msk = inp['num_cmp_msk'] 286 | B, num_cmp = num_cmp_msk.shape 287 | 288 | srl_arg_boxes_mask = inp['srl_arg_boxes_mask'] 289 | B, num_verbs, num_srl_args = srl_arg_boxes_mask.shape 290 | 291 | boxes_msk = ( 292 | srl_arg_boxes_mask.view( 293 | B, num_verbs, num_srl_args, 1).expand( 294 | B, num_verbs, num_srl_args, num_cmp).float() * 295 | num_cmp_msk.view( 296 | B, 1, 1, num_cmp).expand( 297 | B, num_verbs, num_srl_args, num_cmp).float() 298 | ) 299 | num_props_per_vid = targets_one.size(-1) // num_cmp 300 | # B x num_cmp x num_srl_args -> B x num_cmp x num_srl x 4000 301 | boxes_msk = boxes_msk.unsqueeze( 302 | -1).expand( 303 | B, num_verbs, num_srl_args, num_cmp, num_props_per_vid 304 | ).contiguous().view( 305 | B, num_verbs, num_srl_args, num_cmp * num_props_per_vid) 306 | 307 | multiplier = tot_loss.size(-1) 308 | if srl_arg_boxes_mask.max() > 0: 309 | out_loss = torch.masked_select(tot_loss, boxes_msk.byte()) 310 | else: 311 | # TODO: NEED TO check what is wrong here 312 | out_loss = tot_loss 313 | mdl_out_loss = out_loss.mean() * multiplier 314 | # mdl_out_loss = out_loss * 1000 315 | return mdl_out_loss 316 | 317 | def forward(self, out, inp): 318 | targets_all = self.compute_loss_targets(inp) 319 | targets_n = targets_all['targets_one'] 320 | 321 | mdl_outs = out['mdl_outs'] 322 | 323 | mdl_out_loss = self.compute_mdl_loss(mdl_outs, targets_n, inp) 324 | 325 | out_loss = mdl_out_loss 326 | 327 | out_loss_dict = { 328 | 'loss': out_loss, 329 | 'mdl_out_loss': mdl_out_loss, 330 | } 331 | 332 | return {k: v * self.loss_lambda for k, v in out_loss_dict.items()} 333 | 334 | 335 | class LossB_SPAT(LossB_TEMP): 336 | def after_init(self): 337 | self.loss_keys = ['loss', 'mdl_out_loss'] 338 | 339 | self.num_sampled_frm = self.cfg.ds.num_sampled_frm 340 | self.num_prop_per_frm = self.comm.num_prop_per_frm 341 | 342 | def compute_loss_targets(self, inp): 343 | """ 344 | Compute the targets, based on iou 345 | overlaps 346 | """ 347 | num_cmp = 
inp['new_srl_idxs'].size(1) 348 | overlaps = self.compute_overlaps(inp) 349 | B, num_tot_props, num_gt = overlaps.shape 350 | assert num_tot_props % num_cmp == 0 351 | 352 | overlaps_msk = overlaps.new_zeros( 353 | B, self.num_sampled_frm, num_cmp, 354 | self.num_prop_per_frm, num_gt 355 | ) 356 | 357 | targ_cmp = inp['target_cmp'] 358 | overlaps_msk.scatter_( 359 | dim=2, 360 | index=targ_cmp.view(B, 1, 1, 1, 1).expand( 361 | B, self.num_sampled_frm, num_cmp, self.num_prop_per_frm, num_gt 362 | ), 363 | src=overlaps_msk.new_ones(*overlaps_msk.shape) 364 | ) 365 | 366 | overlaps_msk = overlaps_msk.view(B, num_tot_props, num_gt) 367 | overlaps_one_targ = overlaps * overlaps_msk 368 | targets_one = self.get_targets_from_overlaps(overlaps_one_targ, inp) 369 | return { 370 | 'targets_one': targets_one, 371 | } 372 | 373 | def compute_mdl_loss(self, mdl_outs, targets_one, inp): 374 | weights = None 375 | tot_loss = F.binary_cross_entropy_with_logits( 376 | mdl_outs, target=targets_one.float(), 377 | weight=weights, 378 | reduction='none' 379 | ) 380 | 381 | num_cmp_msk = inp['num_cmp_msk'] 382 | B, num_cmp = num_cmp_msk.shape 383 | 384 | srl_arg_boxes_mask = inp['srl_arg_boxes_mask'] 385 | 386 | B, num_verbs, num_srl_args = srl_arg_boxes_mask.shape 387 | 388 | boxes_msk = ( 389 | srl_arg_boxes_mask.view( 390 | B, num_verbs, num_srl_args, 1).expand( 391 | B, num_verbs, num_srl_args, num_cmp).float() * 392 | num_cmp_msk.view( 393 | B, 1, 1, num_cmp).expand( 394 | B, num_verbs, num_srl_args, num_cmp).float() 395 | ) 396 | 397 | num_tot_props = targets_one.size(-1) 398 | # B x num_cmp x num_srl_args -> B x num_cmp x num_srl x 4000 399 | boxes_msk = boxes_msk.view( 400 | B, num_verbs, num_srl_args, 1, num_cmp, 1 401 | ).expand( 402 | B, num_verbs, num_srl_args, self.num_sampled_frm, 403 | num_cmp, self.num_prop_per_frm 404 | ).contiguous().view( 405 | B, num_verbs, num_srl_args, num_tot_props 406 | ) 407 | 408 | multiplier = tot_loss.size(-1) 409 | if srl_arg_boxes_mask.max() > 0: 410 | out_loss = torch.masked_select(tot_loss, boxes_msk.byte()) 411 | else: 412 | # TODO: NEED TO check what is wrong here 413 | out_loss = tot_loss 414 | mdl_out_loss = out_loss.mean() * multiplier 415 | 416 | return mdl_out_loss 417 | 418 | def forward(self, out, inp): 419 | targets_all = self.compute_loss_targets(inp) 420 | targets_n = targets_all['targets_one'] 421 | 422 | mdl_outs = out['mdl_outs'] 423 | 424 | mdl_out_loss = self.compute_mdl_loss(mdl_outs, targets_n, inp) 425 | 426 | out_loss = mdl_out_loss 427 | 428 | out_loss_dict = { 429 | 'loss': out_loss, 430 | 'mdl_out_loss': mdl_out_loss, 431 | } 432 | 433 | return {k: v * self.loss_lambda for k, v in out_loss_dict.items()} 434 | -------------------------------------------------------------------------------- /code/mdl_selector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Select the model, loss, eval_fn 3 | """ 4 | from mdl_vog import ( 5 | ImgGrnd_SEP, 6 | ImgGrnd_TEMP, 7 | ImgGrnd_SPAT, 8 | VidGrnd_SEP, 9 | VidGrnd_TEMP, 10 | VidGrnd_SPAT, 11 | VOG_SEP, 12 | VOG_TEMP, 13 | VOG_SPAT, 14 | LossB_SEP, 15 | LossB_TEMP, 16 | LossB_SPAT 17 | ) 18 | 19 | from eval_vsrl_corr import ( 20 | EvaluatorSEP, 21 | EvaluatorTEMP, 22 | EvaluatorSPAT 23 | ) 24 | 25 | 26 | def get_mdl_loss_eval(cfg): 27 | conc_type = cfg.ds.conc_type 28 | mdl_type = cfg.mdl.name 29 | if conc_type == 'sep' or conc_type == 'svsq': 30 | if mdl_type == 'igrnd': 31 | mdl = ImgGrnd_SEP 32 | elif mdl_type == 'vgrnd': 33 | mdl = VidGrnd_SEP 34 | elif 
mdl_type == 'vog': 35 | mdl = VOG_SEP 36 | else: 37 | raise NotImplementedError 38 | loss = LossB_SEP 39 | evl = EvaluatorSEP 40 | elif conc_type == 'temp': 41 | if mdl_type == 'igrnd': 42 | mdl = ImgGrnd_TEMP 43 | elif mdl_type == 'vgrnd': 44 | mdl = VidGrnd_TEMP 45 | elif mdl_type == 'vog': 46 | mdl = VOG_TEMP 47 | else: 48 | raise NotImplementedError 49 | loss = LossB_TEMP 50 | evl = EvaluatorTEMP 51 | elif conc_type == 'spat': 52 | if mdl_type == 'igrnd': 53 | mdl = ImgGrnd_SPAT 54 | elif mdl_type == 'vgrnd': 55 | mdl = VidGrnd_SPAT 56 | elif mdl_type == 'vog': 57 | mdl = VOG_SPAT 58 | else: 59 | raise NotImplementedError 60 | loss = LossB_SPAT 61 | evl = EvaluatorSPAT 62 | else: 63 | raise NotImplementedError 64 | 65 | return { 66 | 'mdl': mdl, 67 | 'loss': loss, 68 | 'eval': evl 69 | } 70 | -------------------------------------------------------------------------------- /code/transformer_code.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformer implementation adapted from 3 | https://github.com/facebookresearch/grounded-video-description/blob/master/misc/transformer.py 4 | """ 5 | import torch 6 | import math 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | INF = 1e10 11 | 12 | 13 | def matmul(x, y): 14 | if x.dim() == y.dim(): 15 | return torch.matmul(x, y) 16 | if x.dim() == y.dim() - 1: 17 | return torch.matmul(x.unsqueeze(-2), y).squeeze(-2) 18 | return torch.matmul(x, y.unsqueeze(-2)).squeeze(-2) 19 | 20 | 21 | class ResidualBlock(nn.Module): 22 | 23 | def __init__(self, layer, d_model, drop_ratio): 24 | super(ResidualBlock, self).__init__() 25 | self.layer = layer 26 | self.dropout = nn.Dropout(drop_ratio) 27 | # self.layernorm = LayerNorm(d_model) 28 | self.layernorm = nn.LayerNorm(d_model) 29 | 30 | def forward(self, *x): 31 | return self.layernorm(x[0] + self.dropout(self.layer(*x))) 32 | 33 | 34 | class Attention(nn.Module): 35 | 36 | def __init__(self, d_key, drop_ratio, causal): 37 | super(Attention, self).__init__() 38 | self.scale = math.sqrt(d_key) 39 | self.dropout = nn.Dropout(drop_ratio) 40 | self.causal = causal 41 | 42 | def forward(self, query, key, value): 43 | dot_products = matmul(query, key.transpose(1, 2)) 44 | if query.dim() == 3 and (self is None or self.causal): 45 | tri = torch.ones(key.size(1), key.size(1)).triu(1) * INF 46 | if key.is_cuda: 47 | tri = tri.cuda(key.get_device()) 48 | dot_products.data.sub_(tri.unsqueeze(0)) 49 | 50 | return matmul(self.dropout(F.softmax(dot_products / self.scale, dim=-1)), value) 51 | 52 | 53 | class MultiHead(nn.Module): 54 | 55 | def __init__(self, d_key, d_value, n_heads, drop_ratio, causal=False): 56 | super(MultiHead, self).__init__() 57 | self.attention = Attention(d_key, drop_ratio, causal=causal) 58 | self.wq = nn.Linear(d_key, d_key, bias=False) 59 | self.wk = nn.Linear(d_key, d_key, bias=False) 60 | self.wv = nn.Linear(d_value, d_value, bias=False) 61 | self.wo = nn.Linear(d_value, d_key, bias=False) 62 | self.n_heads = n_heads 63 | 64 | def forward(self, query, key, value): 65 | query, key, value = self.wq(query), self.wk(key), self.wv(value) 66 | 67 | query, key, value = ( 68 | x.chunk(self.n_heads, -1) for x in (query, key, value)) 69 | return self.wo(torch.cat([self.attention(q, k, v) 70 | for q, k, v in zip(query, key, value)], -1)) 71 | 72 | 73 | class FeedForward(nn.Module): 74 | 75 | def __init__(self, d_model, d_hidden): 76 | super(FeedForward, self).__init__() 77 | self.linear1 = nn.Linear(d_model, d_hidden) 78 | 
self.linear2 = nn.Linear(d_hidden, d_model) 79 | 80 | def forward(self, x): 81 | return self.linear2(F.relu(self.linear1(x))) 82 | 83 | 84 | class EncoderLayer(nn.Module): 85 | 86 | def __init__(self, d_model, d_hidden, n_heads, drop_ratio): 87 | super(EncoderLayer, self).__init__() 88 | self.selfattn = ResidualBlock( 89 | MultiHead(d_model, d_model, n_heads, drop_ratio), 90 | d_model, drop_ratio) 91 | self.feedforward = ResidualBlock(FeedForward(d_model, d_hidden), 92 | d_model, drop_ratio) 93 | 94 | def forward(self, x): 95 | return self.feedforward(self.selfattn(x, x, x)) 96 | 97 | 98 | class Encoder(nn.Module): 99 | 100 | def __init__(self, d_model, d_hidden, n_vocab, n_layers, n_heads, 101 | drop_ratio, pe): 102 | super(Encoder, self).__init__() 103 | # self.linear = nn.Linear(d_model*2, d_model) 104 | self.layers = nn.ModuleList( 105 | [EncoderLayer(d_model, d_hidden, n_heads, drop_ratio) 106 | for i in range(n_layers)]) 107 | self.dropout = nn.Dropout(drop_ratio) 108 | self.pe = pe 109 | 110 | def forward(self, x, mask=None): 111 | # x = self.linear(x) 112 | if self.pe: 113 | # spatial configuration is already encoded 114 | # x = x+positional_encodings_like(x) 115 | raise NotImplementedError 116 | # x = self.dropout(x) # dropout is already in the pool_embed layer 117 | if mask is not None: 118 | x = x*mask 119 | encoding = [] 120 | for layer in self.layers: 121 | x = layer(x) 122 | if mask is not None: 123 | x = x*mask 124 | encoding.append(x) 125 | return encoding 126 | 127 | 128 | class RelAttention(nn.Module): 129 | 130 | def __init__(self, d_key, drop_ratio, causal): 131 | super().__init__() 132 | self.scale = math.sqrt(d_key) 133 | self.dropout = nn.Dropout(drop_ratio) 134 | self.causal = causal 135 | 136 | def forward(self, query, key, value, pe_k, pe_v): 137 | """ 138 | query, key, value: B x N x 214 139 | pe_k: B x N x N x 214 140 | """ 141 | dot_products = matmul(query, key.transpose(1, 2)) 142 | if query.dim() == 3 and (self is None or self.causal): 143 | tri = torch.ones(key.size(1), key.size(1)).triu(1) * INF 144 | if key.is_cuda: 145 | tri = tri.cuda(key.get_device()) 146 | dot_products.data.sub_(tri.unsqueeze(0)) 147 | 148 | # new_dp = matmul(query, pe_k.transpose(2, 3)) 149 | new_dp = pe_k.squeeze(-1) 150 | assert new_dp.shape == dot_products.shape 151 | new_dot_prods = (dot_products + new_dp) / self.scale 152 | 153 | attn = self.dropout(F.softmax(new_dot_prods, dim=-1)) 154 | 155 | out_v = matmul(attn, value) 156 | # new_out_v = matmul(attn, pe_v) 157 | # new_out_v = pe_v 158 | 159 | new_outs = out_v 160 | return new_outs 161 | 162 | 163 | class RelMultiHead(nn.Module): 164 | 165 | def __init__(self, d_key, d_value, n_heads, drop_ratio, causal=False, d_pe=None): 166 | super().__init__() 167 | self.attention = RelAttention(d_key, drop_ratio, causal=causal) 168 | self.n_heads = n_heads 169 | self.wq = nn.Linear(d_key, d_key, bias=False) 170 | self.wk = nn.Linear(d_key, d_key, bias=False) 171 | self.wv = nn.Linear(d_value, d_value, bias=False) 172 | self.wo = nn.Linear(d_value, d_key, bias=False) 173 | # self.wpk = nn.Linear(d_pe, self.n_heads, bias=False) 174 | # self.wpv = nn.Linear(d_pe, self.n_heads, bias=False) 175 | 176 | def forward(self, query, key, value, pe=None): 177 | """ 178 | pe is B x N x N x 1 position difference 179 | """ 180 | query, key, value = self.wq(query), self.wk(key), self.wv(value) 181 | pe_k, pe_v = pe, pe 182 | query, key, value, pe_k, pe_v = ( 183 | x.chunk(self.n_heads, -1) for x in (query, key, value, pe_k, pe_v)) 184 | return 
self.wo(torch.cat([self.attention(q, k, v, pk, pv) 185 | for q, k, v, pk, pv in 186 | zip(query, key, value, pe_k, pe_v)], -1)) 187 | 188 | 189 | class RelEncoderLayer(nn.Module): 190 | 191 | def __init__(self, d_model, d_hidden, n_heads, 192 | drop_ratio, d_pe=None, sa=True): 193 | super().__init__() 194 | self.selfattn = ResidualBlock( 195 | RelMultiHead(d_model, d_model, n_heads, drop_ratio, d_pe=d_pe), 196 | d_model, drop_ratio) 197 | self.feedforward = ResidualBlock(FeedForward(d_model, d_hidden), 198 | d_model, drop_ratio) 199 | self.sa = sa 200 | 201 | def forward(self, x, pe=None): 202 | if not isinstance(x, dict): 203 | return self.feedforward(self.selfattn(x, x, x, pe)) 204 | else: 205 | assert not self.sa 206 | assert isinstance(x, dict) 207 | assert 'query' in x 208 | assert 'key' in x 209 | assert 'value' in x 210 | return self.feedforward( 211 | self.selfattn(x['query'], x['key'], x['value'], pe) 212 | ) 213 | 214 | 215 | class RelEncoder(nn.Module): 216 | 217 | def __init__(self, d_model, d_hidden, n_vocab, n_layers, n_heads, 218 | drop_ratio, pe, d_pe, sa=True): 219 | super().__init__() 220 | # self.linear = nn.Linear(d_model*2, d_model) 221 | self.layers = nn.ModuleList( 222 | [RelEncoderLayer(d_model, d_hidden, n_heads, drop_ratio, d_pe=d_pe, sa=sa) 223 | for i in range(n_layers)]) 224 | self.dropout = nn.Dropout(drop_ratio) 225 | self.pe = pe 226 | 227 | def forward(self, x, x_pe, mask=None): 228 | # x = self.linear(x) 229 | if self.pe: 230 | # spatial configuration is already encoded 231 | raise NotImplementedError 232 | # x = self.dropout(x) # dropout is already in the pool_embed layer 233 | if mask is not None: 234 | x = x*mask 235 | encoding = [] 236 | for layer in self.layers: 237 | x = layer(x, pe=x_pe) 238 | if mask is not None: 239 | x = x*mask 240 | encoding.append(x) 241 | return encoding 242 | 243 | 244 | class Transformer(nn.Module): 245 | 246 | def __init__(self, d_model, n_vocab_src, vocab_trg, d_hidden=2048, 247 | n_layers=6, n_heads=8, drop_ratio=0.1, pe=False): 248 | super(Transformer, self).__init__() 249 | self.encoder = Encoder(d_model, d_hidden, n_vocab_src, n_layers, 250 | n_heads, drop_ratio, pe) 251 | 252 | def forward(self, x): 253 | encoding = self.encoder(x) 254 | return encoding[-1] 255 | # return encoding[-1], encoding 256 | # return torch.cat(encoding, 2) 257 | 258 | def all_outputs(self, x): 259 | encoding = self.encoder(x) 260 | return encoding 261 | 262 | 263 | class RelTransformer(nn.Module): 264 | 265 | def __init__(self, d_model, n_vocab_src, vocab_trg, d_hidden=2048, 266 | n_layers=6, n_heads=8, drop_ratio=0.1, pe=False, d_pe=None): 267 | super().__init__() 268 | self.encoder = RelEncoder(d_model, d_hidden, n_vocab_src, n_layers, 269 | n_heads, drop_ratio, pe, d_pe=d_pe) 270 | 271 | def forward(self, x, x_pe): 272 | encoding = self.encoder(x, x_pe) 273 | return encoding[-1] 274 | # return encoding[-1], encoding 275 | # return torch.cat(encoding, 2) 276 | 277 | def all_outputs(self, x): 278 | encoding = self.encoder(x) 279 | return encoding 280 | -------------------------------------------------------------------------------- /code/visualizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualize Predictions 3 | """ 4 | 5 | import pandas as pd 6 | import pickle 7 | from PIL import Image 8 | from pathlib import Path 9 | from eval_fn_corr import ( 10 | GroundEval_SEP, 11 | GroundEval_TEMP, 12 | GroundEval_SPAT 13 | ) 14 | import fire 15 | from munch import Munch 16 | from typing import 
List 17 | 18 | 19 | class ASRL_Vis: 20 | def open_required_files(self, ann_file): 21 | self.annots = pd.read_csv(ann_file) 22 | 23 | def draw_boxes_all_indices(self, preds): 24 | # self.preds 25 | pass 26 | 27 | def prepare_img(self, img_list: List): 28 | """ 29 | How to concate the image from an image list 30 | """ 31 | raise NotImplementedError 32 | 33 | def extract_bbox_per_frame(self, preds): 34 | """ 35 | Obtain the bounding boxes for each frame 36 | """ 37 | raise NotImplementedError 38 | 39 | def all_inds(self, pred_file, split_type): 40 | self.prepare_gt(split_type) 41 | 42 | def draw_boxes_one_index( 43 | self, pred, gt_row, conc_type 44 | ): 45 | frm_tdir = Path('/home/Datasets/ActNetEnt/frames_10frm/') 46 | vid_file_id_list = pred['idx_vid'] 47 | 48 | rows = self.annots.iloc[vid_file_id_list] 49 | vid_seg_id_list = rows['id'] 50 | 51 | img_file_dict = { 52 | k: sorted( 53 | [x for x in (frm_tdir/k).iterdir()], 54 | key=lambda x: int(x.stem) 55 | ) 56 | for k in vid_seg_id_list 57 | } 58 | img_list_dict = { 59 | k: [Image.open(img_file) for img_file in img_file_list] 60 | for k, img_file_list in img_file_dict.items() 61 | } 62 | 63 | img = self.prepare_img(img_list_dict) 64 | pass 65 | 66 | 67 | class ASRL_Vis_SEP(GroundEval_SEP, ASRL_Vis): 68 | pass 69 | 70 | 71 | class ASRL_Vis_TEMP(GroundEval_TEMP, ASRL_Vis): 72 | pass 73 | 74 | 75 | class ASRL_Vis_SPAT(GroundEval_SPAT, ASRL_Vis): 76 | pass 77 | 78 | 79 | def main(pred_file, split_type='valid', **kwargs): 80 | if 'cfg' not in kwargs: 81 | from extended_config import ( 82 | cfg as conf, 83 | key_maps, 84 | # CN, 85 | update_from_dict, 86 | # post_proc_config 87 | ) 88 | cfg = conf 89 | cfg = update_from_dict(cfg, kwargs, key_maps) 90 | else: 91 | cfg = kwargs['cfg'] 92 | cfg.freeze() 93 | # grnd_eval = GroundEval_Corr(cfg) 94 | # grnd_eval = GroundEvalDS4(cfg) 95 | comm = Munch() 96 | exp = cfg.ds.exp_setting 97 | if exp == 'gt5': 98 | comm.num_prop_per_frm = 5 99 | elif exp == 'p100': 100 | comm.num_prop_per_frm = 100 101 | else: 102 | raise NotImplementedError 103 | 104 | conc_type = cfg.ds.conc_type 105 | if conc_type == 'sep' or conc_type == 'svsq': 106 | avis = ASRL_Vis_SEP(cfg, comm) 107 | elif conc_type == 'temp': 108 | avis = ASRL_Vis_TEMP(cfg, comm) 109 | elif conc_type == 'spat': 110 | avis = ASRL_Vis_SPAT(cfg, comm) 111 | else: 112 | raise NotImplementedError 113 | 114 | # avis.draw_boxes_all_indices( 115 | # pred_file, split_type=split_type 116 | # ) 117 | 118 | return avis 119 | 120 | 121 | if __name__ == '__main__': 122 | fire.Fire(main) 123 | -------------------------------------------------------------------------------- /conda_env_vog.yml: -------------------------------------------------------------------------------- 1 | name: vog_pyt 2 | channels: 3 | - pytorch 4 | - fastai 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - attrs=19.3.0=py_0 9 | - backcall=0.1.0=py36_0 10 | - blas=1.0=mkl 11 | - bleach=3.1.0=py_0 12 | - ca-certificates=2020.1.1=0 13 | - certifi=2019.11.28=py36_1 14 | - cffi=1.14.0=py36h2e261b9_0 15 | - cudatoolkit=10.0.130=0 16 | - dbus=1.13.12=h746ee38_0 17 | - decorator=4.4.2=py_0 18 | - defusedxml=0.6.0=py_0 19 | - entrypoints=0.3=py36_0 20 | - expat=2.2.6=he6710b0_0 21 | - fastprogress=0.1.21=py_0 22 | - fontconfig=2.13.0=h9420a91_0 23 | - freetype=2.9.1=h8a8886c_1 24 | - glib=2.63.1=h5a9c865_0 25 | - gmp=6.1.2=h6c8ec71_1 26 | - gst-plugins-base=1.14.0=hbbd80ab_1 27 | - gstreamer=1.14.0=hb453b48_1 28 | - icu=58.2=h9c2bf20_1 29 | - importlib_metadata=1.5.0=py36_0 30 | - 
intel-openmp=2020.0=166 31 | - ipykernel=5.1.4=py36h39e3cac_0 32 | - ipython=7.13.0=py36h5ca1d4c_0 33 | - ipython_genutils=0.2.0=py36_0 34 | - ipywidgets=7.5.1=py_0 35 | - jedi=0.16.0=py36_1 36 | - jinja2=2.11.1=py_0 37 | - jpeg=9b=h024ee3a_2 38 | - jsonschema=3.2.0=py36_0 39 | - jupyter=1.0.0=py36_7 40 | - jupyter_client=6.1.0=py_0 41 | - jupyter_console=6.1.0=py_0 42 | - jupyter_core=4.6.1=py36_0 43 | - ld_impl_linux-64=2.33.1=h53a641e_7 44 | - libedit=3.1.20181209=hc058e9b_0 45 | - libffi=3.2.1=hd88cf55_4 46 | - libgcc-ng=9.1.0=hdf63c60_0 47 | - libgfortran-ng=7.3.0=hdf63c60_0 48 | - libpng=1.6.37=hbc83047_0 49 | - libsodium=1.0.16=h1bed415_0 50 | - libstdcxx-ng=9.1.0=hdf63c60_0 51 | - libtiff=4.1.0=h2733197_0 52 | - libuuid=1.0.3=h1bed415_2 53 | - libxcb=1.13=h1bed415_1 54 | - libxml2=2.9.9=hea5a465_1 55 | - markupsafe=1.1.1=py36h7b6447c_0 56 | - mistune=0.8.4=py36h7b6447c_0 57 | - mkl=2020.0=166 58 | - mkl-service=2.3.0=py36he904b0f_0 59 | - mkl_fft=1.0.15=py36ha843d7b_0 60 | - mkl_random=1.1.0=py36hd6b4f25_0 61 | - nbconvert=5.6.1=py36_0 62 | - nbformat=5.0.4=py_0 63 | - ncurses=6.2=he6710b0_0 64 | - ninja=1.9.0=py36hfd86e86_0 65 | - notebook=6.0.3=py36_0 66 | - numpy=1.18.1=py36h4f9e942_0 67 | - numpy-base=1.18.1=py36hde5b4d6_1 68 | - olefile=0.46=py_0 69 | - openssl=1.1.1e=h7b6447c_0 70 | - pandas=1.0.3=py36h0573a6f_0 71 | - pandoc=2.2.3.2=0 72 | - pandocfilters=1.4.2=py36_1 73 | - parso=0.6.2=py_0 74 | - pcre=8.43=he6710b0_0 75 | - pexpect=4.8.0=py36_0 76 | - pickleshare=0.7.5=py36_0 77 | - pillow=7.0.0=py36hb39fc2d_0 78 | - pip=20.0.2=py36_1 79 | - prometheus_client=0.7.1=py_0 80 | - prompt-toolkit=3.0.4=py_0 81 | - prompt_toolkit=3.0.4=0 82 | - ptyprocess=0.6.0=py36_0 83 | - pycparser=2.20=py_0 84 | - pygments=2.6.1=py_0 85 | - pyqt=5.9.2=py36h05f1152_2 86 | - pyrsistent=0.15.7=py36h7b6447c_0 87 | - python=3.6.10=hcf32534_1 88 | - python-dateutil=2.8.1=py_0 89 | - pytorch=1.1.0=py3.6_cuda10.0.130_cudnn7.5.1_0 90 | - pytz=2019.3=py_0 91 | - pyzmq=18.1.1=py36he6710b0_0 92 | - qt=5.9.7=h5867ecd_1 93 | - qtconsole=4.7.2=py_0 94 | - qtpy=1.9.0=py_0 95 | - readline=8.0=h7b6447c_0 96 | - send2trash=1.5.0=py36_0 97 | - setuptools=46.1.1=py36_0 98 | - sip=4.19.8=py36hf484d3e_0 99 | - six=1.14.0=py36_0 100 | - sqlite=3.31.1=h7b6447c_0 101 | - terminado=0.8.3=py36_0 102 | - testpath=0.4.4=py_0 103 | - tk=8.6.8=hbc83047_0 104 | - torchvision=0.3.0=py36_cu10.0.130_1 105 | - tornado=6.0.4=py36h7b6447c_1 106 | - traitlets=4.3.3=py36_0 107 | - wcwidth=0.1.8=py_0 108 | - webencodings=0.5.1=py36_1 109 | - wheel=0.34.2=py36_0 110 | - widgetsnbextension=3.5.1=py36_0 111 | - xz=5.2.4=h14c3975_4 112 | - zeromq=4.3.1=he6710b0_3 113 | - zipp=2.2.0=py_0 114 | - zlib=1.2.11=h7b6447c_3 115 | - zstd=1.3.7=h0b5b093_0 116 | - pip: 117 | - absl-py==0.9.0 118 | - alabaster==0.7.12 119 | - allennlp==0.8.5 120 | - babel==2.8.0 121 | - blis==0.2.4 122 | - boto3==1.12.31 123 | - botocore==1.15.31 124 | - cachetools==4.0.0 125 | - chardet==3.0.4 126 | - click==7.1.1 127 | - conllu==1.3.1 128 | - cycler==0.10.0 129 | - cymem==2.0.3 130 | - cython==0.29.16 131 | - dataclasses==0.7 132 | - docutils==0.15.2 133 | - editdistance==0.5.3 134 | - fairseq==0.8.0 135 | - fastbpe==0.1.0 136 | - fire==0.3.0 137 | - flaky==3.6.1 138 | - flask==1.1.1 139 | - flask-cors==3.0.8 140 | - ftfy==5.7 141 | - future==0.18.2 142 | - gevent==1.4.0 143 | - google-auth==1.12.0 144 | - google-auth-oauthlib==0.4.1 145 | - greenlet==0.4.15 146 | - grpcio==1.27.2 147 | - h5py==2.10.0 148 | - idna==2.9 149 | - imagesize==1.2.0 150 | - 
itsdangerous==1.1.0 151 | - jmespath==0.9.5 152 | - joblib==0.14.1 153 | - jsonnet==0.15.0 154 | - jsonpickle==1.3 155 | - kiwisolver==1.1.0 156 | - markdown==3.2.1 157 | - matplotlib==3.2.1 158 | - more-itertools==8.2.0 159 | - munch==2.5.0 160 | - murmurhash==1.0.2 161 | - nltk==3.4.5 162 | - numpydoc==0.9.2 163 | - oauthlib==3.1.0 164 | - overrides==2.8.0 165 | - packaging==20.3 166 | - parsimonious==0.8.1 167 | - plac==0.9.6 168 | - pluggy==0.13.1 169 | - portalocker==1.6.0 170 | - preshed==2.0.1 171 | - protobuf==3.11.3 172 | - py==1.8.1 173 | - pyasn1==0.4.8 174 | - pyasn1-modules==0.2.8 175 | - pyparsing==2.4.6 176 | - pytest==5.4.1 177 | - pytorch-pretrained-bert==0.6.2 178 | - pytorch-transformers==1.1.0 179 | - pyyaml==5.3.1 180 | - regex==2020.2.20 181 | - requests==2.23.0 182 | - requests-oauthlib==1.3.0 183 | - responses==0.10.12 184 | - rsa==4.0 185 | - s3transfer==0.3.3 186 | - sacrebleu==1.4.4 187 | - scikit-learn==0.22.2.post1 188 | - scipy==1.4.1 189 | - sentencepiece==0.1.85 190 | - snowballstemmer==2.0.0 191 | - spacy==2.1.9 192 | - sphinx==2.4.4 193 | - sphinxcontrib-applehelp==1.0.2 194 | - sphinxcontrib-devhelp==1.0.2 195 | - sphinxcontrib-htmlhelp==1.0.3 196 | - sphinxcontrib-jsmath==1.0.1 197 | - sphinxcontrib-qthelp==1.0.3 198 | - sphinxcontrib-serializinghtml==1.1.4 199 | - sqlparse==0.3.1 200 | - srsly==1.0.2 201 | - tensorboard==2.2.0 202 | - tensorboard-plugin-wit==1.6.0.post2 203 | - tensorboardx==2.0 204 | - termcolor==1.1.0 205 | - thinc==7.0.8 206 | - torchtext==0.5.0 207 | - tqdm==4.43.0 208 | - typing==3.7.4.1 209 | - unidecode==1.1.1 210 | - urllib3==1.25.8 211 | - wasabi==0.6.0 212 | - werkzeug==1.0.0 213 | - word2number==1.1 214 | - yacs==0.1.6 215 | prefix: /home/arka/.conda/envs/vog_pyt 216 | 217 | -------------------------------------------------------------------------------- /configs/anet_srl_cfg.yml: -------------------------------------------------------------------------------- 1 | ds_name: "anet" 2 | ds: 3 | # where to find the rgb+flow data 4 | seg_feature_root: "data/anet/rgb_motion_1d" 5 | # choose one setting 6 | exp_setting: "gt5" #or "p100" 7 | gt5: 8 | # bounding boxes from FasterRCNN 9 | proposal_h5: "data/anet/anet_detection_vg_fc6_feat_gt5_rois.h5" 10 | # extracted features from FasterRCNN 11 | feature_root: "data/anet/fc6_feat_5rois" 12 | # number of proposals considered per frame 13 | num_prop_per_frm: 5 14 | p100: 15 | proposal_h5: "data/anet/anet_detection_vg_fc6_feat_100rois_resized.h5" 16 | feature_root: "data/anet/fc6_feat_100rois" 17 | num_prop_per_frm: 100 18 | resized_width: 720 19 | resized_height: 405 20 | num_sampled_frm: 10 21 | max_gt_box: 100 22 | t_attn_size: 480 23 | max_seq_length: 20 24 | anet_cap_file: "data/anet_cap_ent_files/anet_captions_all_splits.json" 25 | anet_ent_annot_file: "data/anet_cap_ent_files/anet_ent_cls_bbox_trainval.json" 26 | anet_ent_split_file: "data/anet_cap_ent_files/dic_anet.json" 27 | include_srl_args: ['ARG0', 'ARG1', 'ARG2', 'ARGM-LOC'] 28 | # Vocab file for SRLs 29 | arg_vocab_file: "data/anet_srl_files/arg_vocab.pkl" 30 | # Annot files: 31 | trn_ann_file: "data/anet_cap_ent_files/csv_dir/train_postproc.csv" 32 | val_ann_file: "data/anet_cap_ent_files/csv_dir/val_postproc.csv" 33 | # Object Mappings: 34 | trn_ds4_dicts: "data/anet_srl_files/trn_srl_obj_to_index_dict.json" 35 | val_ds4_dicts: "data/anet_srl_files/val_srl_obj_to_index_dict.json" 36 | # ASRL with indices for SPAT/TEMP 37 | trn_ds4_inds: "data/anet_srl_files/trn_asrl_annots.csv" 38 | val_ds4_inds: 
"data/anet_srl_files/val_asrl_annots.csv" 39 | # Sampling mechanism 40 | trn_sample: "ds4_random" 41 | val_sample: "ds4" 42 | # Num Vids Sampled at a time (should be an int) 43 | trn_num_vid_sample: 4 44 | val_num_vid_sample: 4 45 | # Type of Concatenation, choose among ['svsq', 'sep', 'temp', 'spat'] 46 | conc_type: 'spat' 47 | # Shuffle: 48 | cs_shuffle: True 49 | none_word: "" 50 | 51 | mdl: 52 | name: 'vog' 53 | seg_feat_dim: 3072 54 | prop_feat_dim: 2048 55 | input_encoding_size: 512 56 | use_vis_msk: True 57 | rnn: 58 | rnn_size: 1024 59 | num_layers: 2 60 | drop_prob_lm: 0.5 61 | vsrl: 62 | prop_encode_size: 256 63 | seg_encode_size: 256 64 | lang_encode_size: 256 65 | obj_tx: 66 | use_ddp: false 67 | to_use: true 68 | n_layers: 1 69 | n_heads: 3 70 | attn_drop: 0.2 71 | use_rel: false 72 | one_frm: false 73 | mul_tx: 74 | use_ddp: false 75 | to_use: true 76 | n_layers: 1 77 | n_heads: 3 78 | attn_drop: 0.2 79 | use_rel: false 80 | one_frm: true 81 | cross_frm: false 82 | loss: 83 | only_vid_loss: false 84 | loss_lambda: 1 85 | loss_margin: 0.1 86 | loss_margin_vid: 0.5 87 | # loss_type is either 88 | # cosine or bce 89 | loss_type: 'bce' 90 | 91 | misc: 92 | # Place to save models/logs/predictions etc 93 | tmp_path: "tmp" 94 | # Include/Exclude proposal based on the threshold 95 | prop_thresh: 0. 96 | # Whether to exclude the proposals having background class 97 | exclude_bgd_det: False 98 | # Whether to add the proposal (5d coordinate) to 99 | # the region feature 100 | add_prop_to_region: False 101 | # What context to use for average pooling segment features 102 | ctx_for_seg_feats: 0 103 | # max number of semantic roles in a sentence 104 | srl_arg_length: 5 105 | # how many boxes to consider for a particular phrase 106 | box_per_srl_arg: 4 107 | train: 108 | lr: 1e-4 109 | epochs: 10 110 | bs: 4 111 | nw: 4 112 | bsv: 4 113 | nwv: 4 114 | resume: true 115 | resume_path: "" 116 | load_opt: false 117 | load_normally: true 118 | strict_load: true 119 | use_reduce_lr_plateau: false 120 | verbose: false 121 | prob_thresh: 0.2 122 | log: 123 | deb_it: 2 124 | local_rank: 0 125 | do_dist: False 126 | do_dp: false 127 | num_gpus: 1 128 | only_val: false 129 | only_test: false 130 | run_final_val: true 131 | overfit_batch: false 132 | -------------------------------------------------------------------------------- /configs/create_asrl_cfg.yml: -------------------------------------------------------------------------------- 1 | ds_name: "asrl" 2 | ds: 3 | # AC/AE annotation files 4 | anet_cap_file: "data/anet_cap_ent_files/anet_captions_all_splits.json" 5 | anet_ent_split_file: "data/anet_cap_ent_files/dic_anet.json" 6 | anet_ent_annot_file: "data/anet_cap_ent_files/cap_anet_trainval.json" 7 | orig_anet_ent_clss: "data/anet_cap_ent_files/anet_entities_cleaned_class_thresh50_trainval.json" 8 | preproc_anet_ent_clss: "data/anet_cap_ent_files/anet_ent_cls_bbox_trainval.json" 9 | # After adding semantic roles, these are generated inside the cache dir 10 | srl_caps: "SRL_Anet_cap_annots.csv" 11 | srl_bert: "srl_bert_preds.pkl" 12 | # Resized width, height 13 | resized_width: 720 14 | resized_height: 405 15 | # Feature files 16 | vid_hw_map: "data/anet/vid_hw_dict.json" 17 | proposal_h5: "data/anet/anet_detection_vg_fc6_feat_100rois.h5" 18 | proposal_h5_resized: "data/anet/anet_detection_vg_fc6_feat_100rois_resized.h5" 19 | seg_feature_root: "data/anet/rgb_motion_1d" 20 | feature_root: "data/anet/fc6_feat_100rois" 21 | # verbs and arguments to include/exclude 22 | exclude_verb_set: ['be', 
'see', 'show', "'s", 'can', 'continue', 'begin', 'start'] 23 | include_srl_args: ['ARG0', 'ARG1', 'ARG2', 'ARGM-LOC'] 24 | # Lemmatized verb list (created only once) 25 | verb_lemma_dict_file: "data/anet_srl_files/verb_lemma_dict.json" 26 | # SRL with verbs 27 | verb_ent_file: "data/anet_srl_files/verb_ent_file.csv" 28 | trn_verb_ent_file: "data/anet_srl_files/trn_verb_ent_file.csv" 29 | val_verb_ent_file: "data/anet_srl_files/val_verb_ent_file.csv" 30 | # Object Mappings: 31 | trn_ds4_dicts: "data/anet_srl_files/trn_srl_obj_to_index_dict.json" 32 | val_ds4_dicts: "data/anet_srl_files/val_srl_obj_to_index_dict.json" 33 | # ASRL with indices for SPAT/TEMP 34 | trn_ds4_inds: "data/anet_srl_files/trn_asrl_annots.csv" 35 | val_ds4_inds: "data/anet_srl_files/val_asrl_annots.csv" 36 | # Arg Vocab: 37 | arg_vocab_file: "data/anet_srl_files/arg_vocab.pkl" 38 | # None 39 | none_word: "" 40 | # GT5 41 | ngt_prop: 5 42 | num_frms: 10 43 | feature_gt5_root: "data/anet/fc6_feat_5rois" 44 | proposal_gt5_h5_resized: "data/anet/anet_detection_vg_fc6_feat_gt5_rois.h5" 45 | misc: 46 | cache_dir: "cache_dir" 47 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Preparing Data 2 | 3 | This part is to download the data and start with the experiments. 4 | 5 | If instead you are interested in generating ActivityNet-SRL from scratch (not required in general), see [dcode](../dcode). 6 | 7 | ## Quickstart 8 | 9 | Optional: set the data folder. 10 | ``` 11 | cd $ROOT/data 12 | bash download_data.sh all [data_folder] 13 | ``` 14 | 15 | After everything is downloaded successfully, the folder structure should look like: 16 | 17 | ``` 18 | data 19 | |-- anet (530gb) 20 | |-- anet_detection_vg_fc6_feat_100rois.h5 21 | |-- anet_detection_vg_fc6_feat_100rois_resized.h5 22 | |-- anet_detection_vg_fc6_feat_gt5_rois.h5 23 | |-- fc6_feat_100rois 24 | |-- fc6_feat_5rois 25 | |-- rgb_motion_1d 26 | |-- anet_cap_ent_files (31M) 27 | |-- anet_captions_all_splits.json 28 | |-- anet_ent_cls_bbox_trainval.json 29 | |-- csv_dir 30 | |-- train.csv 31 | |-- train_postproc.csv 32 | |-- val.csv 33 | |-- val_postproc.csv 34 | |-- dic_anet.json 35 | |-- anet_srl_files (112M) 36 | |-- arg_vocab.pkl 37 | |-- trn_asrl_annots.csv 38 | |-- trn_srl_obj_to_index_dict.json 39 | |-- val_asrl_annots.csv 40 | |-- val_srl_obj_to_index_dict.json 41 | ``` 42 | 43 | It should ~530 gb of data !! 44 | 45 | NOTE: Highly advisable to have the features in SSD; otherwise massive drop in speed! 46 | 47 | 48 | ## Details about the Data 49 | Here, I have explained the data contents in 1-line. 50 | For an in-depth overview of the construction, please refer to [DATA PREP README](../dcode/README.md) 51 | 52 | 1. `fc6_feat_Xrois`: We have 10 frames, for each frame we get X rois. `X=100` is obtained from FasterRCNN trained on Visual Genome. `X=5` is obtained from `X=100` such that ground-truth annotations are included and the remaining are the top scoring boxes. The latter setting allows us to perform easy experimentations. 53 | 1. `rgb_motion_1d`: RGB and FLOW features for frames (1fps) of the video. 54 | 1. `{trn/val}_asrl_annots.csv`: The main annotation files required for grounding. 55 | 1. `{trn/val}_srl_obj_to_index_dict.json`: Dictionary mapping helpful for sampling contrastive examples. 
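
As a quick sanity check after downloading, the main annotation files can be inspected directly. Below is a minimal sketch (it assumes the default `data/` layout shown above; the internal structure of the `*_srl_obj_to_index_dict.json` files is not documented here, so the snippet only prints a sample entry rather than assuming a schema):

```
import json
import pandas as pd

# Main grounding annotations: one row per (video segment, verb) data point
trn_annots = pd.read_csv('data/anet_srl_files/trn_asrl_annots.csv')
print(len(trn_annots), 'training rows')
print(trn_annots.columns.tolist())

# Dictionary used for contrastive sampling; inspect an entry to see its structure
with open('data/anet_srl_files/trn_srl_obj_to_index_dict.json') as f:
    srl_obj_to_index = json.load(f)
print(len(srl_obj_to_index), 'entries; example key:', next(iter(srl_obj_to_index)))
```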
56 | 57 | ## Annotation File Structure: 58 | The main annotation files for ASRL are `{trn/val}_asrl_annots.csv` 59 | 60 | We use Video Segments of the ActivityNet since we are focussing on Trimmed videos only. 61 | 62 | ActivityNet Entities provides the bounding boxes for the noun-phrases in ActivityNet Captions. For more details please refer to [dcode](../dcode) 63 | 64 | `trn_asrl_annots.csv` has 26 columns! 65 | 66 | Lets consider the first example. You can get this using: 67 | ``` 68 | import pandas as pd 69 | trn_csv = pd.read_csv('./trn_asrl_annots.csv') 70 | first_data_point = trn_csv.iloc[0] 71 | column_list = ['srl_ind', 'vid_seg'] 72 | ``` 73 | 74 | 1. `srl_ind`: the index in this csv file. Here it is `0` 75 | 1. `vt_split`: is the split the data point belongs to. All data points in `trn_asrl_anonts.csv` have this set to `train`. However, it is 50-50 split for `val_asrl_annots.csv` for `val` and `test`. 76 | 1. `vid_seg`: the video and the segment of the video the file belongs to. The convention used is `{vid_name}_segment_{seg_id:02d}`. Here it is `v_--0edUL8zmA_segment_00` which means, it is the 0th segment of the video `v_--0edUL8zmA`. 77 | 1. `ann_ind`: this is the index in the `anet_cap_ent_files/csv_dir/{trn/val}_postproc.csv` file. This index is used to retrieve the proposal boxes from `anet_detection_vg_fc6_feat_100rois_resized.h5`. Here it is `28557` which means 28557th row of the h5 file corresponds to this `vid_seg`. 78 | 1. `sent`: this is the main sentence provided in the activitynet captions for the given vid_seg. The sentence may contain multiple verbs, and as such data points sharing the same vid seg will have the same sentence. Here, the sentence is "Four men are playing dodge ball in an indoor court ." 79 | 1. `words`: this is simply tokenization of `sent`. Here it is: \['Four', 'men', 'are', 'playing', 'dodge', 'ball', 'in', 'an', 'indoor', 'court', '.'\] 80 | 1. `verb`: we pass the sentence through a semantic role labeler (see [demo](https://demo.allennlp.org/semantic-role-labeling)) which extracts multiple verbs from the sentence and assigning semantic roles pivoted for each verb. Each verb is treated as a separate data point. Here, the verb is `playing`. 81 | 1. `tags`: The BIO tagging output from the SRL for the given verb. Here it is \['B-ARG0', 'I-ARG0', 'O', 'B-V', 'B-ARG1', 'I-ARG1', 'B-ARGM-LOC', 'I-ARGM-LOC', 'I-ARGM-LOC', 'I-ARGM-LOC', 'O'\] which basically the structure "playing: \[ARG0: Four men] are \[V: playing] \[ARG1: dodge ball] \[ARGM-LOC: in an indoor court] ." 82 | 1. `req_pat_ix`: Same information as `tags` but represented as List\[Tuple\[ArgX, List\[word indices]]. The word indices correspond to the output of `word`. Here it is `[['ARG0', [0, 1]], ['V', [3]], ['ARG1', [4, 5]], ['ARGM-LOC', [6, 7, 8, 9]]]` which suggests `word[0], word[1]` constitute ARG0 (basically \[ARG0: Four men]) 83 | 1. `req_pat`: Same information as above, just the list of word indices are replaced with space separated words. Here it is: \[('ARG0', 'Four men'), ('V', 'playing'), ('ARG1', 'dodge ball'), ('ARGM-LOC', 'in an indoor court')] 84 | 1. `req_aname`: Same as `req_pat` just that it only extracts the words without the argument roles. Here it is: \['Four men', 'playing', 'dodge ball', 'in an indoor court'] 85 | 1. `req_args`: Instead of the words, only stores the semantic roles. Here it is \['ARG0', 'V', 'ARG1', 'ARGM-LOC'] 86 | 1. `gt_bboxes`: The ground-truth boxes (4d) provided in AE for the given vid-seg. It is List\[List\[x1,y1,x2,y2]] 87 | 1. 
`gt_frms`: The frames (ranging from 0-9) where they are annotated. It is List\[\len(gt_bboxes)] 88 | 1. `process_idx2`: It provides the word index for the given bounding box. It is List\[List\[int]]. Here it is `[[1], [1], [1], [1], [9]]`. Note that `word[1] = men` which means the first four bounding boxes refer to the four men and the final bounding box refers to the `court`. 89 | 1. `process_clss`: Lemmatized Noun for the words in `process_idx2`. Here it is `[['man'], ['man'], ['man'], ['man'], ['court']]` 90 | 1. `req_cls_pats`: Same as `req_pat` with the words replaced with their lemmatized noun. `[('ARG0', ['man']), ('V', ['playing']), ('ARG1', ['dodge', 'ball']), ('ARGM-LOC', ['court'])]` 91 | 1. `req_cls_pats_mask`: It is List\[Tuple\[ArgX, Mask, GTBox Index list]]. ArgX is the Argument Name like Arg0, Mask = 1 means this role has a bounding box, 0 implies the role doesn't have a bounding box and hence is not evaluated. GTBox Index List is the list of indices of the bounding boxes refering to this role. Here it is `[('ARG0', 1, [0, 1, 2, 3]), ('V', 0, [0]), ('ARG1', 0, [0]), ('ARGM-LOC', 1, [4])]` which implies ARG0 and ARGM-LOC are groundable, while V and ARG1 are not. Moreover, the first four bounding boxes refer to ARG0 and the last bounding box refers to ARGM-LOC. 92 | 1. `lemma_ARGX`: The lemmatized verb/argument role used for contrastive sampling. 93 | 1. `DS4_Inds`: For each role, it contains indices for which everything other than the lemmatized word for the argument role is same. 94 | 1. `ds4_msk`: If such contrastive samples were successfully found. 95 | 1. `RandDS4_Inds`: Simply random indices. 96 | -------------------------------------------------------------------------------- /data/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Downloading script 3 | 4 | CUR_DIR=$(pwd) 5 | DATA_ROOT=${2:-$CUR_DIR} 6 | 7 | mkdir -p $DATA_ROOT/anet 8 | 9 | function gdrive_download () { 10 | CONFIRM=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$1" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p') 11 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$CONFIRM&id=$1" -O $2 12 | rm -rf /tmp/cookies.txt 13 | } 14 | 15 | function asrl_ann_dwn(){ 16 | echo "Downloading ActivityNet SRL annotations" 17 | cd $DATA_ROOT 18 | gdrive_download "1WJTRQVs-vSLmJ7I3sef3_IxE0vBtaqFX" anet_srl.zip 19 | unzip anet_srl.zip && rm anet_srl.zip 20 | cd $CUR_DIR 21 | # The above is minimalistic download and should be fine 22 | # for most cases. 23 | # To get all the files: 24 | # gdrive_download 1qSsD3AbWqw-KNObNg6N8xbTnF-Bg_eZn anet_verb.zip 25 | # unzip anet_verb.zip && rm anet_verb.zip 26 | # gdrive_download 1aZyLNP-VXS3stZpenWMuCTRF_NL2gznu anet_srl_scratch.zip 27 | # unzip anet_srl_scratch.zip && rm anet_srl_scratch.zip 28 | echo "Saved Folder" 29 | } 30 | 31 | function anet_feats_dwn(){ 32 | echo "Downloading ActivityNet Feats. 
May take some time" 33 | # Courtesy of Louwei Zhou, obtained from the repository: 34 | # https://github.com/facebookresearch/grounded-video-description/blob/master/tools/download_all.sh 35 | cd $DATA_ROOT/anet 36 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/rgb_motion_1d.tar.gz 37 | tar -xvzf rgb_motion_1d.tar.gz && rm rgb_motion_1d.tar.gz 38 | 39 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_detection_vg_fc6_feat_100rois.h5 40 | 41 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/fc6_feat_100rois.tar.gz 42 | tar -xvzf fc6_feat_100rois.tar.gz && rm fc6_feat_100rois.tar.gz 43 | 44 | gdrive_download 13tvBIEAgv4VS5dqkZBK1gvTI_Z22gRLM fc6_feat_5rois.zip 45 | unzip fc6_feat_5rois.zip && rm fc6_feat_5rois.zip 46 | 47 | gdrive_download 1a6UOK90Epz7n-dncKAeFDQP4TBgqdTS9 anet_detn_proposals_resized.zip 48 | unzip anet_detn_proposals_resized.zip && rm anet_detn_proposals_resized.zip 49 | cd $CUR_DIR 50 | } 51 | 52 | function dwn_all(){ 53 | asrl_ann_dwn 54 | anet_feats_dwn 55 | } 56 | 57 | 58 | if [ "$1" = "asrl_anns" ] 59 | then 60 | asrl_ann_dwn 61 | 62 | elif [ "$1" = "anet_feats" ] 63 | then 64 | anet_feats_dwn 65 | elif [ "$1" = "all" ] 66 | then 67 | dwn_all 68 | else 69 | echo "Failed: Use download_data.sh asrl_anns | anet_feats | all" 70 | exit 1 71 | fi 72 | -------------------------------------------------------------------------------- /dcode/README.md: -------------------------------------------------------------------------------- 1 | # Creating ActivityNet SRL (ASRL) from ActivityNet Captions (AC) and ActivityNet Entities (AE) 2 | 3 | The code is for generating the dataset from the parent datasets. 4 | If you just want to use use ASRL as a training bed, you can skip this. See [data](../data) 5 | 6 | ## Quick summary 7 | 8 | Very briefly, the process is as follows: 9 | 1. Add semantic roles to captions in AC. 10 | 1. Prepocess AE. In particular, resize all the proposals, ground-truth bounding boxes (this is 11 | required for SPAT/TEMP). 12 | 1. Preprocess the features and choose only 5 groundtruths for GT5 setting. 13 | 1. Obtain the bounding boxes and category names from AE for the relevant phrases. 14 | 1. Filter out some verbs like "is", "are", "complete", "begin" 15 | 1. Filter some SRL Arguments based on Frequency. 16 | 1. Get Training/Validation/Test videos. 17 | 1. Do Contrastive Sampling and store the dictionary files for easier sampling during training. 18 | 19 | ## Preprocessing Steps 20 | 21 | 1. First download relevant files. 22 | Optional: specify the data folder where it would be downloaded. 23 | ``` 24 | bash download_asrl_parent_ann.sh [save_point] 25 | ``` 26 | The folder should look like: 27 | ``` 28 | anet_cap_ent_files 29 | |-- anet_captions_all_splits.json (AC captions) 30 | |-- anet_entities_test_1.json 31 | |-- anet_entities_test_2.json 32 | |-- anet_entities_val_1.json 33 | |-- anet_entities_val_2.json 34 | |-- cap_anet_trainval.json (AE Train annotations) 35 | |-- dic_anet.json (Train/Valid/Test video splits for AE) 36 | ``` 37 | 38 | 1. Use SRL Labeling system from AllenAI (Should take ~15 mins) to add the semantic roles to the captions from AC. 
39 | ```
40 | cd $ROOT
41 | python dcode/sem_role_labeller.py
42 | ```
43 | 
44 | This will create `$ROOT/cache_dir` and store the output SRL files, which should look like:
45 | ```
46 | cache_dir/
47 | |-- SRL_Anet
48 | |-- SRL_Anet_bert_cap_annots.csv # AC annotations in csv format to input into BERT
49 | |-- srl_bert_preds.pkl # BERT outputs
50 | ```
51 | 
52 | 1. Resize the boxes in AE.
53 | ```
54 | cd $ROOT
55 | python dcode/preproc_anet_files.py --task='resize_boxes_ae'
56 | ```
57 | This takes the file `cap_anet_trainval.json` as input (this is the main AE annotation file) and outputs `anet_ent_cls_bbox_trainval.json`. The latter file contains the resized ground-truth boxes.
58 | It also resizes the proposal boxes, taking `anet_detection_vg_fc6_feat_100rois.h5` as input and producing `anet_detection_vg_fc6_feat_100rois_resized.h5` as output. The latter contains the resized proposals.
59 | 
60 | 1. GT5 setting
61 | ```
62 | cd $ROOT
63 | python dcode/preproc_anet_files.py --task='choose_gt_5'
64 | ```
65 | Initially, there are `100` proposals per frame.
66 | For faster iteration, we only choose 5 proposals from each frame.
67 | If there is a ground-truth box, we include that box, and the remaining are included in order of their proposal score (not a fair way, but the best that could be done).
68 | If there is no ground-truth box, we choose the top-5 scoring proposals.
69 | 
70 | To compute the recall scores (as a sanity check):
71 | ```
72 | python dcode/preproc_anet_files.py --task='compute_recall'
73 | ```
74 | By default, it computes recall scores for GT5; you can change the proposal file for other settings.
75 | 
76 | 1. Align the SRL outputs and noun phrases from AE to create ASRL and add the bounding boxes to the ASRL files (<1 min)
77 | ```
78 | cd $ROOT
79 | python dcode/asrl_creator.py
80 | ```
81 | Now `$ROOT/data/anet_srl_files/` should look like:
82 | ```
83 | anet_srl_files/
84 | |-- verb_ent_file.csv # main file with SRLs, BBoxes
85 | |-- verb_lemma_dict.json # dictionary of verbs corresponding to their lemma
86 | ```
87 | 
88 | 1. Use the Train/Val videos from AE to create the Train/Val/Test videos for ASRL (~5-7 mins).
89 | Additionally, create the vocab file for the SRL arguments.
90 | ```
91 | cd $ROOT
92 | python dcode/preproc_ds_files.py
93 | ```
94 | This will create `anet_cap_ent_files/csv_dir`. It should look like:
95 | ```
96 | csv_dir
97 | |-- train.csv
98 | |-- train_postproc.csv
99 | |-- val.csv
100 | |-- val_postproc.csv
101 | ```
102 | 
103 | Further, now `$ROOT/data/anet_srl_files/` should look like:
104 | ```
105 | anet_srl_files/
106 | |-- trn_verb_ent_file.csv # train file
107 | |-- val_verb_ent_file.csv # val & test file
108 | |-- verb_ent_file.csv
109 | |-- verb_lemma_dict.json
110 | ```
111 | 
112 | 1. Do Contrastive Sampling for the train and validation sets (~30 mins)
113 | ```
114 | cd $ROOT
115 | python code/contrastive_sampling.py
116 | ```
117 | 
118 | Now your `anet_srl_files` directory should look like:
119 | ```
120 | anet_srl_files/
121 | |-- trn_asrl_annots.csv # used for training
122 | |-- trn_srl_obj_to_index_dict.json # used for CS
123 | |-- trn_verb_ent_file.csv # not used anymore
124 | |-- val_asrl_annots.csv # used for val/test
125 | |-- val_srl_obj_to_index_dict.json # used for CS
126 | |-- val_verb_ent_file.csv # not used anymore
127 | |-- verb_ent_file.csv # not used anymore
128 | |-- verb_lemma_dict.json # not used anymore
129 | ```
130 | 
131 | 1. 
I have provided drive links to the processed files (generated after completing all the previous steps): 132 | 1. `anet_cap_ent_files` and `anet_srl_files`: https://drive.google.com/open?id=1mH8TyVPU4w7864Hxiukzg8dnqPIyBuE3 133 | 1. `SRL_Anet`: https://drive.google.com/open?id=1vGgqc8_-ZBk3ExNroRP-On7ArWN-d8du 134 | 1. resized proposal h5 files: https://drive.google.com/open?id=1a6UOK90Epz7n-dncKAeFDQP4TBgqdTS9 135 | 1. fc6_feats_5rois: https://drive.google.com/open?id=13tvBIEAgv4VS5dqkZBK1gvTI_Z22gRLM 136 | 137 | Alternatively, you can download these files from `download_asrl_parent_ann.sh` by passing `asrl_proc_files`: 138 | ``` 139 | bash download_asrl_parent_ann.sh asrl_proc_files 140 | ``` 141 | -------------------------------------------------------------------------------- /dcode/dataset_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gives the dataset statistics 3 | in form of tables. 4 | Copy-paste to Excel for visualization 5 | """ 6 | from yacs.config import CfgNode as CN 7 | import yaml 8 | from asrl_creator import Anet_SRL_Create 9 | from pathlib import Path 10 | import pandas as pd 11 | import ast 12 | from typing import Dict, List, Tuple 13 | from collections import Counter 14 | import altair as alt 15 | 16 | 17 | class AnetSRL_Vis(object): 18 | def __init__(self, cfg, do_vis=True): 19 | self.cfg = cfg 20 | self.open_req_files() 21 | self.vis = do_vis 22 | 23 | def fix_via_ast(self, df): 24 | for k in df.columns: 25 | first_word = df.iloc[0][k] 26 | if isinstance(first_word, str) and (first_word[0] in '[{'): 27 | df[k] = df[k].apply( 28 | lambda x: ast.literal_eval(x)) 29 | return df 30 | 31 | def open_req_files(self): 32 | trn_asrl_file = self.cfg.ds.trn_ds4_inds 33 | val_asrl_file = self.cfg.ds.val_ds4_inds 34 | 35 | self.trn_srl_annots = self.fix_via_ast(pd.read_csv(trn_asrl_file)) 36 | self.val_srl_annots = self.fix_via_ast(pd.read_csv(val_asrl_file)) 37 | 38 | def print_most_common_table(self, most_comm: List[Tuple]): 39 | """ 40 | Prints most common output from a Counter in the 41 | form of a table for easy copy/pasting 42 | """ 43 | patt = '{}, {}\n' 44 | out_str = '' 45 | for it in most_comm: 46 | out_str += patt.format(*it) 47 | print(out_str) 48 | return 49 | 50 | def visualize_df(self, df: pd.DataFrame, 51 | x_name: str, y_name: str): 52 | bars = alt.Chart(df).mark_bar( 53 | cornerRadiusBottomRight=3, 54 | cornerRadiusTopRight=3, 55 | ).encode( 56 | x=alt.X(x_name, axis=alt.Axis(title="")), 57 | y=alt.Y(y_name, axis=alt.Axis(title=""), 58 | sort='-x'), 59 | color=alt.value('#6495ED') 60 | ) 61 | text = bars.mark_text( 62 | align='left', 63 | baseline='middle', 64 | dx=3 # Nudges text to right so it doesn't appear on top of the bar 65 | ).encode( 66 | text='Count:Q' 67 | ) 68 | 69 | return (bars + text).properties(height=500) 70 | 71 | def get_num_vids(self): 72 | """ 73 | Input dictionary with train and validation df 74 | """ 75 | nvids = {} 76 | nvids['train'] = len(self.trn_srl_annots.vid_seg.unique()) 77 | nvids['valid'] = len( 78 | self.val_srl_annots[ 79 | self.val_srl_annots.vt_split == 'val' 80 | ].vid_seg.unique() 81 | ) 82 | nvids['test'] = len( 83 | self.val_srl_annots[ 84 | self.val_srl_annots.vt_split == 'test' 85 | ].vid_seg.unique() 86 | ) 87 | return nvids 88 | 89 | def get_num_noun_phrase(self): 90 | """ 91 | Return number of noun-phrases for 92 | each SRL 93 | """ 94 | # req_cls_pats_mask: [['ArgX', 1/0, box_num]] 95 | # get only the argument name and count 96 | arg_counts = 
self.trn_srl_annots.req_cls_pats_mask.apply( 97 | lambda x: [y[0] for y in x] 98 | ) 99 | return Counter([ac for acs in arg_counts for ac in acs]) 100 | 101 | def get_num_phrase_with_box(self): 102 | # req_cls_pats_mask: [['ArgX', 1/0, box_num]] 103 | # get only the argument name and count 104 | arg_counts = self.trn_srl_annots.req_cls_pats_mask.apply( 105 | lambda x: [y[0] for y in x if y[1] == 1] 106 | ) 107 | return Counter([ac for acs in arg_counts for ac in acs]) 108 | 109 | def get_num_srl_structures(self): 110 | arg_struct_counts = self.trn_srl_annots.req_args.apply( 111 | lambda x: '-'.join(x) 112 | ) 113 | return Counter(list(arg_struct_counts)).most_common(20) 114 | 115 | def get_num_lemma(self, arg_list): 116 | lemma_counts = {} 117 | col_set = set(self.trn_srl_annots.columns) 118 | for agl in arg_list: 119 | if agl != 'verb': 120 | lemma_key = f'lemma_{agl}' 121 | assert lemma_key in col_set 122 | lemma_counts[lemma_key] = Counter( 123 | list( 124 | self.trn_srl_annots[lemma_key].apply( 125 | lambda x: x[0] if len(x) > 0 else '' 126 | ) 127 | ) 128 | ) 129 | else: 130 | lemma_key = 'lemma_verb' 131 | lemma_counts[lemma_key] = Counter( 132 | list( 133 | self.trn_srl_annots[lemma_key] 134 | ) 135 | ) 136 | return lemma_counts 137 | 138 | def get_num_q_per_vid(self): 139 | num_q_per_vid = ( 140 | len(self.trn_srl_annots) / 141 | len(self.trn_srl_annots.vid_seg.unique()) 142 | ) 143 | 144 | num_srl_per_q = self.trn_srl_annots.req_args.apply( 145 | lambda x: len(x)).mean() 146 | 147 | num_w_per_q = self.trn_srl_annots.req_pat_ix.apply( 148 | lambda x: sum([len(y[1]) for y in x])).mean() 149 | 150 | return num_q_per_vid, num_srl_per_q, num_w_per_q 151 | 152 | def print_all_stats(self): 153 | vis_list = [] 154 | nvid = self.get_num_vids() 155 | print("Number of videos in Train/Valid/Test: " 156 | f"{nvid['train']}, {nvid['valid']}, {nvid['test']}") 157 | 158 | num_q_per_vid, num_srl_per_q, num_w_per_q = self.get_num_q_per_vid() 159 | print(f"Number of Queries per Video is {num_q_per_vid}") 160 | print(f"Number of Queries per Video is {num_srl_per_q}") 161 | print(f"Number of Queries per Video is {num_w_per_q}") 162 | 163 | num_noun_phrases_for_srl = self.get_num_noun_phrase().most_common(n=20) 164 | num_np_srl = pd.DataFrame.from_records( 165 | data=num_noun_phrases_for_srl, 166 | columns=['Arg', 'Count'] 167 | ) 168 | if self.vis: 169 | vis_list.append( 170 | self.visualize_df(num_np_srl, x_name='Count:Q', y_name='Arg:O') 171 | ) 172 | print('Noun Phrases Count') 173 | print(num_np_srl.to_csv(index=False)) 174 | 175 | num_noun_phrases_with_box_for_srl = self.get_num_phrase_with_box() 176 | 177 | num_grnd_np_srl = pd.DataFrame.from_records( 178 | data=num_noun_phrases_with_box_for_srl.most_common(n=20), 179 | columns=['Arg', 'Count'] 180 | ) 181 | if self.vis: 182 | vis_list.append( 183 | self.visualize_df( 184 | num_grnd_np_srl, x_name='Count:Q', y_name='Arg:O') 185 | ) 186 | print('Groundable Noun Phrase Count') 187 | print(num_grnd_np_srl.to_csv(index=False)) 188 | 189 | num_srl_struct = self.get_num_srl_structures() 190 | num_srl_struct_df = pd.DataFrame.from_records( 191 | data=num_srl_struct, 192 | columns=['Arg', 'Count'] 193 | ) 194 | if self.vis: 195 | vis_list.append( 196 | self.visualize_df(num_srl_struct_df, 197 | x_name='Count:Q', y_name='Arg:O') 198 | ) 199 | print('SRL Structures Frequency') 200 | print(num_srl_struct_df.to_csv(index=False)) 201 | 202 | arg_list = ['verb', 'ARG0', 'ARG1', 'ARG2', 'ARGM_LOC'] 203 | lemma_counts = self.get_num_lemma(arg_list) 204 | min_t 
= 20 205 | num_lemma_args = { 206 | k: len([z for z in v.most_common() if z[1] > min_t]) 207 | for k, v in lemma_counts.items() 208 | } 209 | print(f'Lemmatized Counts for each lemma: {num_lemma_args}') 210 | 211 | df_dict = { 212 | k: pd.DataFrame.from_records( 213 | data=v.most_common(21), 214 | columns=['String', 'Count'] 215 | ) 216 | for k, v in lemma_counts.items() 217 | } 218 | 219 | for k in df_dict: 220 | print(f'Most Frequent Lemmas for {k}') 221 | print(df_dict[k].to_csv(index=False)) 222 | 223 | return lemma_counts 224 | # return vis_list 225 | 226 | 227 | if __name__ == '__main__': 228 | cfg = CN(yaml.safe_load(open('./configs/anet_srl_cfg.yml'))) 229 | asrl_vis = AnetSRL_Vis(cfg) 230 | asrl_vis.print_all_stats() 231 | -------------------------------------------------------------------------------- /dcode/download_asrl_parent_ann.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Downloading script 3 | 4 | CUR_DIR=$(pwd) 5 | DDIR=${2:-"../data"} 6 | DATA_ROOT=$DDIR/anet_cap_ent_files 7 | 8 | echo $DATA_ROOT 9 | mkdir -p $DDIR/anet_srl_files 10 | mkdir -p $DATA_ROOT 11 | 12 | function gdrive_download () { 13 | CONFIRM=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=$1" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p') 14 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$CONFIRM&id=$1" -O $2 15 | rm -rf /tmp/cookies.txt 16 | } 17 | 18 | function anet_feats_dwn(){ 19 | echo "Downloading ActivityNet Feats. May take some time" 20 | # Courtesy of Louwei Zhou, obtained from the repository: 21 | # https://github.com/facebookresearch/grounded-video-description/blob/master/tools/download_all.sh 22 | cd $DDIR/anet 23 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/rgb_motion_1d.tar.gz 24 | tar -xvzf rgb_motion_1d.tar.gz && rm rgb_motion_1d.tar.gz 25 | 26 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_detection_vg_fc6_feat_100rois.h5 27 | 28 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/fc6_feat_100rois.tar.gz 29 | tar -xvzf fc6_feat_100rois.tar.gz && rm fc6_feat_100rois.tar.gz 30 | 31 | # gdrive_download 13tvBIEAgv4VS5dqkZBK1gvTI_Z22gRLM fc6_feat_5rois.zip 32 | # unzip fc6_feat_5rois.zip && rm fc6_feat_5rois.zip 33 | 34 | # gdrive_download 1a6UOK90Epz7n-dncKAeFDQP4TBgqdTS9 anet_detn_proposals_resized.zip 35 | # unzip anet_detn_proposals_resized.zip && rm anet_detn_proposals_resized.zip 36 | cd $CUR_DIR 37 | } 38 | 39 | function ac_ae_dwn(){ 40 | echo "Downloading ActivityNet Captions and ActivityNet Entities" 41 | cd $DATA_ROOT 42 | # Courtesy of Louwei Zhou, obtained from the repository: 43 | # https://github.com/facebookresearch/grounded-video-description 44 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_entities_prep.tar.gz 45 | wget https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_entities_captions.tar.gz 46 | wget https://raw.githubusercontent.com/facebookresearch/ActivityNet-Entities/master/data/anet_entities_cleaned_class_thresh50_trainval.json 47 | tar -xvzf anet_entities_prep.tar.gz && rm anet_entities_prep.tar.gz 48 | tar -xvzf anet_entities_captions.tar.gz && rm anet_entities_captions.tar.gz 49 | cd $CUR_DIR 50 | echo "Finished downloading ActivityNet Captions and ActivityNet Entities" 51 | } 52 | 53 | function 
processed_files_dwn(){ 54 | echo "Downloading ASRL processed files" 55 | cd $DDIR 56 | mkdir asrl_processed_files 57 | cd asrl_processed_files 58 | gdrive_download "1mH8TyVPU4w7864Hxiukzg8dnqPIyBuE3" anet_srl_files_all.zip 59 | gdrive_download "1vGgqc8_-ZBk3ExNroRP-On7ArWN-d8du" SRL_Anet.zip 60 | gdrive_download "1a6UOK90Epz7n-dncKAeFDQP4TBgqdTS9" anet_detn_proposals_resized.zip 61 | # gdrive_download "13tvBIEAgv4VS5dqkZBK1gvTI_Z22gRLM" fc6_feat_5rois.zip 62 | cd $CUR_DIR 63 | } 64 | 65 | function dwn_all(){ 66 | ac_ae_dwn 67 | anet_feats_dwn 68 | } 69 | 70 | if [ "$1" = "ac_ae_anns" ] 71 | then 72 | ac_ae_dwn 73 | elif [ "$1" = "anet_feats" ] 74 | then 75 | anet_feats_dwn 76 | elif [ "$1" = "asrl_proc_files" ] 77 | then 78 | processed_files_dwn 79 | elif [ "$1" = "all" ] 80 | then 81 | dwn_all 82 | else 83 | echo "Failed: Use download_asrl_parent_ann.sh ac_ae_anns | anet_feats | asrl_proc_files | all" 84 | exit 1 85 | fi 86 | -------------------------------------------------------------------------------- /dcode/preproc_anet_files.py: -------------------------------------------------------------------------------- 1 | """ 2 | Small preprocessing done for Anet files 3 | In particular: 4 | [ ] Add 'He' to 'man', 'boy', similarly for 'she' to 'woman', 'girl', 'lady' 5 | [ ] Resize ground-truth box 6 | """ 7 | 8 | import json 9 | from pathlib import Path 10 | from yacs.config import CfgNode as CN 11 | import yaml 12 | from tqdm import tqdm 13 | import h5py 14 | import pandas as pd 15 | import numpy as np 16 | from utils.box_utils import box_iou 17 | import copy 18 | import torch 19 | from collections import OrderedDict 20 | import fire 21 | 22 | 23 | class AnetEntFiles: 24 | def __init__(self, cfg): 25 | self.cfg = cfg 26 | self.conv_dict = { 27 | 'man': 'he', 28 | 'boy': 'he', 29 | 'woman': 'she', 30 | 'girl': 'she', 31 | 'lady': 'she' 32 | } 33 | self.open_req_files() 34 | 35 | def open_req_files(self): 36 | 37 | self.trn_anet_ent_file = Path(self.cfg.ds.anet_ent_annot_file) 38 | assert self.trn_anet_ent_file.exists() 39 | self.trn_anet_ent_data = json.load( 40 | open(self.trn_anet_ent_file)) 41 | 42 | self.trn_anet_ent_preproc_file = Path( 43 | self.cfg.ds.preproc_anet_ent_clss) 44 | assert self.trn_anet_ent_preproc_file.parent.exists() 45 | 46 | self.vid_dict_df = pd.DataFrame(json.load( 47 | open(self.cfg.ds.anet_ent_split_file))['videos']) 48 | self.vid_dict_df.index.name = 'Index' 49 | 50 | # Assert region features exists 51 | self.feature_root = Path(self.cfg.ds.feature_root) 52 | assert self.feature_root.exists() 53 | 54 | self.feature_root_gt5 = Path(self.cfg.ds.feature_gt5_root) 55 | self.feature_root_gt5.mkdir(exist_ok=True) 56 | assert self.feature_root_gt5.exists() 57 | 58 | def run(self): 59 | # out_ann = self.get_vidseg_hw_map( 60 | # ann=self.trn_anet_ent_orig_data['annotations']) 61 | out_ann = self.get_vidseg_hw_map( 62 | ann=self.trn_anet_ent_data) 63 | 64 | json.dump(out_ann, open(self.trn_anet_ent_preproc_file, 'w')) 65 | self.resize_props() 66 | 67 | def add_pronouns(self, ann): 68 | def upd(segv): 69 | """ 70 | segv: Dict. 
71 | Keys: 'process_clss' etc 72 | Update the values for process_clss 73 | """ 74 | pck = 'process_clss' 75 | if pck not in segv: 76 | pck = 'clss' 77 | assert pck in segv 78 | proc_clss = segv[pck][:] 79 | assert isinstance(proc_clss, list) 80 | if len(proc_clss) == 0: 81 | return 82 | assert isinstance(proc_clss[0], list) 83 | new_proc_clss = [] 84 | for pc in proc_clss: 85 | new_pc = [] 86 | for p in pc: 87 | if p in self.conv_dict: 88 | new_pc.append(p) 89 | new_pc.append(self.conv_dict[p]) 90 | else: 91 | new_pc.append(p) 92 | new_proc_clss.append(new_pc) 93 | segv[pck] = new_proc_clss 94 | return 95 | out_dict_vid = {} 96 | for vidk, vidv in tqdm(ann.items()): 97 | out_dict_seg_vid = {} 98 | for segk, segv in vidv['segments'].items(): 99 | upd(segv) 100 | out_dict_seg_vid[segk] = segv 101 | out_dict_vid[vidk] = {'segments': out_dict_seg_vid} 102 | 103 | return out_dict_vid 104 | 105 | def get_vidseg_hw_map(self, ann=None): 106 | def upd(segv, sw, sh): 107 | """ 108 | segv: Dict 109 | Change process_bnd_box wrt hw 110 | """ 111 | pbk = 'process_bnd_box' 112 | if pbk not in segv: 113 | pbk = 'bbox' 114 | assert pbk in segv 115 | if len(segv[pbk]) == 0: 116 | return 117 | process_bnd_box = np.array( 118 | segv[pbk][:]).astype(float) 119 | process_bnd_box[:, [0, 2]] *= sw 120 | process_bnd_box[:, [1, 3]] *= sh 121 | process_bnd_box = process_bnd_box.astype(int) 122 | segv[pbk] = process_bnd_box.tolist() 123 | return 124 | 125 | vid_dict_df = self.vid_dict_df 126 | 127 | h5_proposal_file = h5py.File( 128 | self.cfg.ds.proposal_h5, 'r', driver='core') 129 | 130 | # num_proposals = h5_proposal_file['dets_num'][:] 131 | # label_proposals = h5_proposal_file['dets_labels'][:] 132 | 133 | hw_vids = h5_proposal_file['hw'][:].astype(float).tolist() 134 | out_dict = {} 135 | for row_ind, row in tqdm(vid_dict_df.iterrows()): 136 | vid_id = row['vid_id'] 137 | if vid_id not in out_dict: 138 | out_dict[vid_id] = hw_vids[row_ind] 139 | else: 140 | hw = hw_vids[row_ind] 141 | if not hw == [0., 0.]: 142 | assert hw == out_dict[vid_id] 143 | json.dump(out_dict, open(self.cfg.ds.vid_hw_map, 'w')) 144 | 145 | nw = self.cfg.ds.resized_width 146 | nh = self.cfg.ds.resized_height 147 | out_dict_vid = {} 148 | for vidk, vidv in tqdm(ann.items()): 149 | out_dict_seg_vid = {} 150 | oh, ow = out_dict[vidk] 151 | if ow != 0. or oh != 0.: 152 | sw = nw / ow 153 | sh = nh / oh 154 | else: 155 | sw, sh = 1., 1. 156 | for segk, segv in vidv['segments'].items(): 157 | upd(segv, sw*1., sh*1.) 158 | out_dict_seg_vid[segk] = segv 159 | out_dict_vid[vidk] = {'segments': out_dict_seg_vid} 160 | 161 | return out_dict_vid 162 | 163 | def resize_props(self): 164 | h5_proposal_file = h5py.File( 165 | self.cfg.ds.proposal_h5, 'r', driver='core') 166 | 167 | hw_vids = h5_proposal_file['hw'][:].astype(float).tolist() 168 | label_proposals = h5_proposal_file['dets_labels'][:] 169 | 170 | nw = self.cfg.ds.resized_width 171 | nh = self.cfg.ds.resized_height 172 | 173 | for row_ind in tqdm(range(len(label_proposals))): 174 | oh, ow = hw_vids[row_ind] 175 | if ow != 0. or oh != 0.: 176 | sw = nw / ow 177 | sh = nh / oh 178 | else: 179 | sw, sh = 1., 1. 
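            # Each proposal row stores the box as [x1, y1, x2, y2] in its first
            # four columns (column 4 holds the frame index and column 6 the
            # proposal score, as used in choose_gt5 below). The x-coordinates
            # (columns 0, 2) are scaled by sw = resized_width / original_width
            # and the y-coordinates (columns 1, 3) by sh = resized_height / original_height.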
180 | 181 | label_proposals[row_ind, :, [0, 2]] *= sw 182 | label_proposals[row_ind, :, [1, 3]] *= sh 183 | with h5py.File(self.cfg.ds.proposal_h5_resized, 'w') as f: 184 | keys = [k for k in h5_proposal_file.keys()] 185 | for k in keys: 186 | if k != 'dets_labels': 187 | f.create_dataset(k, data=h5_proposal_file[k]) 188 | else: 189 | f.create_dataset(k, data=label_proposals) 190 | 191 | return 192 | 193 | def choose_gt5(self, save=True): 194 | """ 195 | Choose 5 proposals for each frame 196 | """ 197 | h5_proposal_file = h5py.File( 198 | self.cfg.ds.proposal_h5_resized, 'r', driver='core') 199 | # h5_proposal_file = h5py.File( 200 | # self.cfg.ds.proposal_h5, 'r', driver='core') 201 | 202 | nppf_orig = 100 203 | nppf = self.cfg.ds.ngt_prop 204 | nfrms = self.cfg.ds.num_frms 205 | # Note these are resized labels 206 | label_proposals = h5_proposal_file['dets_labels'][:] 207 | num_proposals = h5_proposal_file['dets_num'][:] 208 | out_label_proposals = np.zeros_like( 209 | label_proposals)[:, :nfrms*nppf, ...] 210 | out_num_proposals = np.zeros_like(num_proposals) 211 | vid_dict_df = self.vid_dict_df 212 | 213 | anet_ent_preproc_data = json.load(open(self.trn_anet_ent_preproc_file)) 214 | # anet_ent_preproc_data = json.load( 215 | # open(self.cfg.ds.anet_ent_annot_file)) 216 | 217 | recall_num = 0 218 | recall_tot = 0 219 | 220 | for row_ind, row in tqdm(vid_dict_df.iterrows(), 221 | total=len(vid_dict_df)): 222 | # if row_ind > 1000: 223 | # break 224 | vid = row['vid_id'] 225 | seg = row['seg_id'] 226 | vid_seg_id = row['id'] 227 | 228 | annot = anet_ent_preproc_data[vid]['segments'][seg] 229 | gt_boxs = annot['bbox'] 230 | gt_frms = annot['frm_idx'] 231 | 232 | prop_index = row_ind 233 | 234 | props = copy.deepcopy(label_proposals[prop_index]) 235 | num_props = int(copy.deepcopy(num_proposals[prop_index])) 236 | 237 | if num_props < nfrms * nppf_orig: 238 | # import pdb 239 | # pdb.set_trace() 240 | assert np.all(props[num_props:, [0, 1, 2, 3]] == 0) 241 | 242 | region_feature_file = self.feature_root / f'{vid_seg_id}.npy' 243 | if not region_feature_file.exists(): 244 | continue 245 | prop_feats_load = np.load(region_feature_file) 246 | prop_feats = np.zeros((nfrms, *prop_feats_load.shape[1:])) 247 | prop_feats[:prop_feats_load.shape[0]] = prop_feats_load 248 | 249 | out_file = self.feature_root_gt5 / f'{vid_seg_id}.npy' 250 | out_dict = self.choose_gt5_for_one_vid_seg( 251 | props, prop_feats, gt_boxs, gt_frms, out_file, 252 | save=save, nppf=nppf, nppf_orig=nppf_orig, nfrms=nfrms 253 | ) 254 | 255 | if save: 256 | num_prop = out_dict['num_prop'] 257 | out_label_proposals[prop_index][:num_prop] = ( 258 | out_dict['out_props'] 259 | ) 260 | out_num_proposals[prop_index] = num_prop 261 | 262 | recall_num += out_dict['recall'] 263 | recall_tot += out_dict['num_gt'] 264 | 265 | recall = recall_num.item() / recall_tot 266 | print(f'Recall is {recall}') 267 | if save: 268 | with h5py.File(self.cfg.ds.proposal_gt5_h5_resized, 'w') as f: 269 | keys = [k for k in h5_proposal_file.keys()] 270 | keys.remove('dets_labels') 271 | keys.remove('dets_num') 272 | for k in keys: 273 | f.create_dataset(k, data=h5_proposal_file[k]) 274 | 275 | f.create_dataset('dets_labels', data=out_label_proposals) 276 | f.create_dataset('dets_num', data=out_num_proposals) 277 | 278 | return recall 279 | 280 | def choose_gt5_for_one_vid_seg( 281 | self, props, prop_feats, 282 | gt_boxs, gt_frms, out_file, 283 | save=True, nppf=5, nppf_orig=100, nfrms=10): 284 | """ 285 | Choose for 5 props per frame 286 | """ 287 | # 
Convert to torch tensors for box_iou computations 288 | # props: 10*100 x 7 289 | props = torch.tensor(props).float() 290 | prop_feats = torch.tensor(prop_feats).float() 291 | # set for comparing 292 | gt_frms_set = set(gt_frms) 293 | gt_boxs = torch.tensor(gt_boxs).float() 294 | gt_frms = torch.tensor(gt_frms).float() 295 | 296 | # Get the frames for the proposal boxes are 297 | prop_frms = props[:, 4] 298 | # Create a frame mask. 299 | # Basically, if the iou = 0 if the proposal and 300 | # the ground truth box lie in different frames 301 | frm_msk = prop_frms[:, None] == gt_frms 302 | if len(gt_boxs) > 0 and len(props) > 0: 303 | ious = box_iou(props[:, :4], gt_boxs) * frm_msk.float() 304 | # get the max iou proposal for each bounding box 305 | ious_max, ious_arg_max = ious.max(dim=0) 306 | # if len(ious_arg_max) > nppf: 307 | # ious_arg_max = ious_arg_max[:nppf] 308 | out_props = props[ious_arg_max] 309 | out_props_inds = ious_arg_max % 100 310 | recall = (ious_max > 0.5).sum() 311 | ngt = len(gt_boxs) 312 | else: 313 | ngt = 1 314 | recall = 0 315 | ious = torch.zeros(props.size(0), 1) 316 | out_props = props[0] 317 | out_props_inds = torch.tensor(0) 318 | 319 | # Dictionary to store final proposals to use 320 | fin_out_props = {} 321 | # Reshape proposals and proposal features to 322 | # nfrms x nppf x ndim 323 | props1 = props.view(nfrms, nppf_orig, 7) 324 | prop_dim = prop_feats.size(-1) 325 | prop_feats1 = prop_feats.view(nfrms, nppf_orig, prop_dim) 326 | 327 | # iterate over each frame 328 | for frm in range(nfrms): 329 | if frm not in fin_out_props: 330 | fin_out_props[frm] = [] 331 | 332 | # if there are gt boxes in the frame 333 | # consider the proposals which have highest iou 334 | # in the frame 335 | if frm in gt_frms_set: 336 | props_inds_gt_in_frm = out_props_inds[out_props[..., 4] == frm] 337 | # add highest iou props to the dict key 338 | fin_out_props[frm] += props_inds_gt_in_frm.tolist() 339 | 340 | # sort by their scores, and choose nppf=5 such props 341 | props_to_use_inds = props1[frm, ..., 6].argsort(descending=True)[ 342 | :nppf] 343 | # add 5 such props to the list 344 | fin_out_props[frm] += props_to_use_inds.tolist() 345 | 346 | # Restrict the total to 5 347 | fin_out_props[frm] = list( 348 | OrderedDict.fromkeys(fin_out_props[frm]))[:nppf] 349 | 350 | # Saving them, init with zeros 351 | props_output = torch.zeros(nfrms, nppf, 7) 352 | prop_feats_output = torch.zeros(nfrms, nppf, prop_dim) 353 | 354 | # set for each frame 355 | for frm in fin_out_props: 356 | inds = fin_out_props[frm] 357 | props_output[frm] = props1[frm][inds] 358 | prop_feats_output[frm] = prop_feats1[frm][inds] 359 | 360 | # Reshape nfrm x nppf x ndim -> nfrm*nppf x ndim 361 | props_output = props_output.view(nfrms*nppf, 7).detach().cpu().numpy() 362 | prop_feats_output = prop_feats_output.view( 363 | nfrms, nppf, prop_dim).detach().cpu().numpy() 364 | 365 | if save: 366 | np.save(out_file, prop_feats_output) 367 | 368 | return { 369 | 'out_props': props_output, 370 | 'recall': recall, 371 | 'num_prop': nppf*nfrms, 372 | 'num_gt': ngt 373 | } 374 | 375 | def compute_recall(self, exp_setting='gt5'): 376 | """ 377 | Compute recall for the created h5 file 378 | """ 379 | if exp_setting == 'gt5': 380 | pfile = self.cfg.ds.proposal_gt5_h5_resized 381 | elif exp_setting == 'p100': 382 | pfile = self.cfg.ds.proposal_h5_resized 383 | 384 | with h5py.File(pfile, 'r') as f: 385 | label_proposals = f['dets_labels'][:] 386 | 387 | vid_dict_df = self.vid_dict_df 388 | 389 | anet_ent_preproc_data = 
json.load(open(self.trn_anet_ent_preproc_file)) 390 | 391 | recall_num = 0 392 | recall_tot = 0 393 | 394 | for row_ind, row in tqdm(vid_dict_df.iterrows(), 395 | total=len(vid_dict_df)): 396 | 397 | vid = row['vid_id'] 398 | seg = row['seg_id'] 399 | vid_seg_id = row['id'] 400 | 401 | annot = anet_ent_preproc_data[vid]['segments'][seg] 402 | gt_boxs = torch.tensor(annot['bbox']).float() 403 | gt_frms = annot['frm_idx'] 404 | 405 | prop_index = row_ind 406 | 407 | region_feature_file = self.feature_root / f'{vid_seg_id}.npy' 408 | if not region_feature_file.exists(): 409 | continue 410 | 411 | props = copy.deepcopy(label_proposals[prop_index]) 412 | props = torch.tensor(props).float() 413 | # props = props.view(10, -1, 7) 414 | 415 | for fidx, frm in enumerate(gt_frms): 416 | prop_frms = props[props[..., 4] == frm] 417 | gt_box_in_frm = gt_boxs[fidx] 418 | 419 | ious = box_iou(prop_frms[:, :4], gt_box_in_frm) 420 | 421 | ious_max, ious_arg_max = ious.max(dim=0) 422 | # conversion to long is important, otherwise 423 | # after 256 becomes 0 424 | recall_num += (ious_max > 0.5).any().long() 425 | 426 | recall_tot += len(gt_boxs) 427 | 428 | recall = recall_num.item() / recall_tot 429 | print(f'Recall is {recall}') 430 | return 431 | 432 | 433 | def main(task: str, exp_setting='gt5'): 434 | cfg = CN(yaml.safe_load(open('./configs/create_asrl_cfg.yml'))) 435 | anet_pre = AnetEntFiles(cfg) 436 | if 'resize_boxes_ae' in task: 437 | anet_pre.run() 438 | if 'choose_gt5' in task: 439 | anet_pre.choose_gt5(save=True) 440 | if 'compute_recall' in task: 441 | anet_pre.compute_recall(exp_setting) 442 | 443 | 444 | if __name__ == '__main__': 445 | fire.Fire(main) 446 | # cfg = CN(yaml.safe_load(open('./configs/create_asrl_cfg.yml'))) 447 | # anet_pre = AnetEntFiles(cfg) 448 | # anet_pre.compute_recall() 449 | # anet_pre.choose_gt5(save=True) 450 | # anet_pre.add_pronouns() 451 | # anet_pre.get_vidseg_hw_map() 452 | # anet_pre.run() 453 | # anet_pre.resize_props() 454 | -------------------------------------------------------------------------------- /dcode/preproc_ds_files.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess dataset files 3 | """ 4 | 5 | import json 6 | import pandas as pd 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | import numpy as np 12 | import ast 13 | from collections import Counter 14 | from torchtext import vocab 15 | import pickle 16 | from munch import Munch 17 | 18 | 19 | np.random.seed(5) 20 | 21 | 22 | class AnetCSV: 23 | def __init__(self, cfg, comm=None): 24 | self.cfg = cfg 25 | if comm is not None: 26 | assert isinstance(comm, (dict, Munch)) 27 | self.comm = Munch(comm) 28 | else: 29 | self.comm = Munch() 30 | 31 | inp_anet_dict_fpath = cfg.ds.anet_ent_split_file 32 | self.inp_dict_file = Path(inp_anet_dict_fpath) 33 | 34 | # Create directory to keep the generated csvs 35 | self.out_csv_dir = self.inp_dict_file.parent / 'csv_dir' 36 | self.out_csv_dir.mkdir(exist_ok=True) 37 | 38 | # Structure of anet_dict: 39 | # anet = Dict, 40 | # keys: 1. word to lemma, 2. index to word, 41 | # 3. word to detection 4. 
video information 42 | # We only need the video information 43 | self.vid_dict_list = json.load(open(inp_anet_dict_fpath))['videos'] 44 | 45 | def create_csvs(self): 46 | """ 47 | Create the Train/Val split videos 48 | """ 49 | self.vid_info_df = pd.DataFrame(self.vid_dict_list) 50 | self.vid_info_df.index.name = 'Index' 51 | 52 | train_df = self.vid_info_df[self.vid_info_df.split == 'training'] 53 | train_df.to_csv(self.out_csv_dir / 'train.csv', 54 | index=True, header=True) 55 | 56 | # NOTE: Test files don't have the annotations, so cannot be used. 57 | # Instead we split the validation dataframe into two parts (50/50) 58 | 59 | val_test_df = self.vid_info_df[self.vid_info_df.split == 'validation'] 60 | 61 | # Randomly take half as validation, rest as test 62 | # Both are saved in val.csv, however, during evaluation 63 | # only those with "val" in the field "vt_split" are chosen 64 | msk = np.random.rand(len(val_test_df)) < 0.5 65 | val_test_df['vt_split'] = ['val' if m == 1 else 'test' for m in msk] 66 | val_test_df.to_csv(self.out_csv_dir / 'val.csv', 67 | index=True, header=True) 68 | 69 | def post_proc(self, csv_file_type): 70 | """ 71 | Some videos don't have features. These are removed 72 | for convenience. 73 | (only 4-5 videos were removed) 74 | """ 75 | self.seg_feature_root = Path(self.cfg.ds.seg_feature_root) 76 | assert self.seg_feature_root.exists() 77 | 78 | self.feature_root = Path(self.cfg.ds.feature_root) 79 | assert self.feature_root.exists() 80 | 81 | csv_file = self.out_csv_dir / f'{csv_file_type}.csv' 82 | csv_df = pd.read_csv(csv_file) 83 | msk = [] 84 | num_segs_list = [] 85 | for row_ind, row in tqdm(csv_df.iterrows(), total=len(csv_df)): 86 | vid_seg_id = row['id'] 87 | vid_id = row['vid_id'] 88 | num_segs = csv_df[csv_df.vid_id == vid_id].seg_id.max() + 1 89 | num_segs_list.append(num_segs) 90 | 91 | vid_id_ix, seg_id_ix = vid_seg_id.split('_segment_') 92 | seg_rgb_file = self.seg_feature_root / \ 93 | f'{vid_id_ix[2:]}_resnet.npy' 94 | seg_motion_file = self.seg_feature_root / f'{vid_id_ix[2:]}_bn.npy' 95 | region_feature_file = self.feature_root / f'{vid_seg_id}.npy' 96 | out = (seg_rgb_file.exists() and seg_motion_file.exists() 97 | and region_feature_file.exists()) 98 | msk.append(out) 99 | 100 | csv_df['num_segs'] = num_segs_list 101 | csv_df = csv_df[msk] 102 | csv_df.to_csv(self.out_csv_dir / 103 | f'{csv_file_type}_postproc.csv', index=False, header=True) 104 | 105 | def post_proc_srl(self, train_file, val_file, test_file=None): 106 | """ 107 | Add the Index to each csv file 108 | This is required to get the correct proposals from h5 file 109 | """ 110 | def get_row_id(vid_seg, ann_df): 111 | vid_dict_row = ann_df[ann_df.id == 112 | vid_seg] 113 | if len(vid_dict_row) == 1: 114 | vid_dict_row_id = vid_dict_row.index[0] 115 | return vid_dict_row_id 116 | else: 117 | return -1 118 | 119 | self.vid_info_df = pd.DataFrame(self.vid_dict_list) 120 | self.vid_info_df.index.name = 'Index' 121 | 122 | trn_ann_df = pd.read_csv( 123 | self.out_csv_dir / f'{train_file}_postproc.csv') 124 | val_ann_df = pd.read_csv(self.out_csv_dir / f'{val_file}_postproc.csv') 125 | 126 | srl_trn_val = pd.read_csv(self.cfg.ds.verb_ent_file) 127 | 128 | trn_ann_ind = [] 129 | trn_msk = [] 130 | 131 | val_ann_ind = [] 132 | val_msk = [] 133 | vt_msk = [] 134 | 135 | for srl_ind, srl in tqdm(srl_trn_val.iterrows(), 136 | total=len(srl_trn_val)): 137 | req_args = ast.literal_eval(srl.req_args) 138 | if len(req_args) == 1: 139 | continue 140 | vid_seg = srl.vid_seg 141 | vid_dict_row = 
self.vid_info_df[self.vid_info_df.id == vid_seg] 142 | assert len(vid_dict_row) == 1 143 | vid_dict_row = vid_dict_row.iloc[0] 144 | split = vid_dict_row.split 145 | 146 | if split == 'training': 147 | ann_ind = get_row_id(vid_seg, trn_ann_df) 148 | if ann_ind == -1: 149 | print(split, vid_seg) 150 | continue 151 | trn_ann_ind.append(ann_ind) 152 | trn_msk.append(srl_ind) 153 | elif split == 'validation': 154 | ann_ind = get_row_id(vid_seg, val_ann_df) 155 | if ann_ind == -1: 156 | print(split, vid_seg) 157 | continue 158 | val_ann_ind.append(ann_ind) 159 | val_msk.append(srl_ind) 160 | vt_msk.append(val_ann_df.loc[ann_ind].vt_split) 161 | elif split == 'testing': 162 | pass 163 | else: 164 | raise NotImplementedError 165 | 166 | srl_trn = srl_trn_val.iloc[trn_msk] 167 | srl_trn['ann_ind'] = trn_ann_ind 168 | srl_trn['srl_ind'] = trn_msk 169 | srl_trn['vt_split'] = 'train' 170 | 171 | srl_val = srl_trn_val.iloc[val_msk] 172 | srl_val['ann_ind'] = val_ann_ind 173 | srl_val['srl_ind'] = val_msk 174 | srl_val['vt_split'] = vt_msk 175 | 176 | srl_trn.to_csv(self.cfg.ds.trn_verb_ent_file, 177 | index=False, header=True) 178 | srl_val.to_csv(self.cfg.ds.val_verb_ent_file, 179 | index=False, header=True) 180 | 181 | def process_arg_vocabs(self): 182 | def create_vocab(srl_annots, key): 183 | x_counter = Counter() 184 | for x_c in srl_annots[key]: 185 | x_counter += Counter(x_c) 186 | return vocab.Vocab(x_counter, specials_first=True) 187 | srl_annots = pd.read_csv(self.cfg.ds.trn_verb_ent_file) 188 | for k in srl_annots.columns: 189 | first_word = srl_annots.iloc[0][k] 190 | if isinstance(first_word, str) and first_word[0] == '[': 191 | srl_annots[k] = srl_annots[k].apply( 192 | lambda x: ast.literal_eval(x)) 193 | 194 | # arg_counter = Counter() 195 | # for r_arg in srl_annots.req_args: 196 | # arg_counter += Counter(r_arg) 197 | 198 | # arg_vocab = vocab.Vocab(arg_counter, specials_first=True) 199 | arg_vocab = create_vocab(srl_annots, 'req_args') 200 | arg_tag_vocab = create_vocab(srl_annots, 'tags') 201 | out_vocab = {'arg_vocab': arg_vocab, 'arg_tag_vocab': arg_tag_vocab} 202 | pickle.dump(out_vocab, file=open(self.cfg.ds.arg_vocab_file, 'wb')) 203 | return 204 | 205 | def glove_vocabs(self): 206 | # Load dictionaries 207 | self.comm.dic_anet = json.load(open(self.inp_dict_file)) 208 | # Get detections to index 209 | self.comm.dtoi = {w: i+1 for w, 210 | i in self.comm.dic_anet['wtod'].items()} 211 | self.comm.itod = {i: w for w, i in self.comm.dtoi.items()} 212 | self.comm.itow = self.comm.dic_anet['ix_to_word'] 213 | self.comm.wtoi = {w: i for i, w in self.comm.itow.items()} 214 | 215 | self.comm.vocab_size = len(self.comm.itow) + 1 216 | self.comm.detect_size = len(self.comm.itod) 217 | 218 | # Load the glove vocab 219 | self.glove = vocab.GloVe(name='6B', dim=300) 220 | 221 | # get the glove vector for the vg detection cls 222 | # From Peter's repo 223 | obj_cls_file = self.cfg.ds.vg_class_file 224 | # index 0 is the background 225 | with open(obj_cls_file) as f: 226 | data = f.readlines() 227 | classes = ['__background__'] 228 | classes.extend([i.strip() for i in data]) 229 | 230 | # for VG classes 231 | # self.comm.vg_cls = classes 232 | 233 | # Extract glove vectors for the Visual Genome Classes 234 | # TODO: Cleaner implementation possible 235 | # TODO: Preproc only once 236 | glove_vg_cls = np.zeros((len(classes), 300)) 237 | for i, w in enumerate(classes): 238 | split_word = w.replace(',', ' ').split(' ') 239 | vector = [] 240 | for word in split_word: 241 | if word in 
self.glove.stoi: 242 | vector.append( 243 | self.glove.vectors[self.glove.stoi[word]].numpy()) 244 | else: # use a random vector instead 245 | vector.append(2*np.random.rand(300) - 1) 246 | 247 | avg_vector = np.zeros((300)) 248 | for v in vector: 249 | avg_vector += v 250 | 251 | glove_vg_cls[i] = avg_vector/len(vector) 252 | 253 | # category id to labels. +1 becuase 0 is the background label 254 | # Extract glove vectors for the 431 classes in AnetEntDataset 255 | # TODO: Cleaner Implementation 256 | # TODO: Preproc only once 257 | glove_clss = np.zeros((len(self.comm.itod)+1, 300)) 258 | glove_clss[0] = 2*np.random.rand(300) - 1 # background 259 | for i, word in enumerate(self.comm.itod.values()): 260 | if word in self.glove.stoi: 261 | vector = self.glove.vectors[self.glove.stoi[word]] 262 | else: # use a random vector instead 263 | vector = 2*np.random.rand(300) - 1 264 | glove_clss[i+1] = vector 265 | 266 | # Extract glove vectors for the words from the vocab 267 | # TODO: cleaner implementation 268 | # TODO: preproc only once 269 | glove_w = np.zeros((len(self.comm.wtoi)+1, 300)) 270 | for i, word in enumerate(self.comm.wtoi.keys()): 271 | vector = np.zeros((300)) 272 | count = 0 273 | for w in word.split(' '): 274 | count += 1 275 | if w in self.glove.stoi: 276 | glove_vector = self.glove.vectors[self.glove.stoi[w]] 277 | vector += glove_vector.numpy() 278 | else: # use a random vector instead 279 | random_vector = 2*np.random.rand(300) - 1 280 | vector += random_vector 281 | glove_w[i+1] = vector / count 282 | 283 | out_dict = { 284 | 'classes': classes, 285 | 'glove_vg_cls': glove_vg_cls, 286 | 'glove_clss': glove_clss, 287 | 'glove_w': glove_w 288 | } 289 | pickle.dump(out_dict, open(self.cfg.ds.glove_stuff, 'wb')) 290 | 291 | 292 | if __name__ == '__main__': 293 | cfg = CN(yaml.safe_load(open('./configs/create_asrl_cfg.yml'))) 294 | anet_csv = AnetCSV(cfg) 295 | 296 | # anet_csv.create_csvs() 297 | 298 | # anet_csv.post_proc('train') 299 | # anet_csv.post_proc('val') 300 | 301 | # anet_csv.post_proc_srl('train', 'val') 302 | anet_csv.process_arg_vocabs() 303 | -------------------------------------------------------------------------------- /dcode/process_gt_props.py: -------------------------------------------------------------------------------- 1 | """ 2 | By default, we are using proposal boxes. 3 | Instead, we only consider the gts. 4 | """ 5 | 6 | import numpy as np 7 | from pathlib import Path 8 | import h5py 9 | import json 10 | import pandas as pd 11 | from tqdm import tqdm 12 | import copy 13 | from box_utils import box_iou 14 | import torch 15 | from collections import OrderedDict 16 | 17 | 18 | class GTPropExtractor(object): 19 | def __init__(self, cfg): 20 | self.cfg = cfg 21 | 22 | # Assert h5 file to read from exists 23 | self.proposal_h5 = Path(self.cfg.ds.proposal_h5_resized) 24 | assert self.proposal_h5.exists() 25 | 26 | with h5py.File(self.proposal_h5, 'r', 27 | driver='core') as h5_proposal_file: 28 | self.num_proposals = h5_proposal_file['dets_num'][:] 29 | self.label_proposals = h5_proposal_file['dets_labels'][:] 30 | 31 | nppf = self.cfg.ds.ngt_prop 32 | self.out_label_proposals = np.zeros_like( 33 | self.label_proposals)[:, :10*nppf, ...] 
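        # Pre-allocate output arrays that keep only nppf (= cfg.ds.ngt_prop,
        # i.e. 5) proposals for each of the 10 frames, instead of the
        # 100 proposals per frame stored in the resized proposal h5 file.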
34 | self.out_num_proposals = np.zeros_like(self.num_proposals) 35 | 36 | # Assert region features exists 37 | self.feature_root = Path(self.cfg.ds.feature_root) 38 | assert self.feature_root.exists() 39 | 40 | # Assert act ent caption file with bbox exists 41 | self.anet_ent_annot_file = Path(self.cfg.ds.anet_ent_annot_file) 42 | assert self.anet_ent_annot_file.exists() 43 | 44 | if cfg.ds.ngt_prop == 5: 45 | self.out_dir = Path(self.cfg.ds.feature_gt5_root) 46 | self.out_proposal_h5 = Path(self.cfg.ds.proposal_gt5_h5) 47 | else: 48 | raise NotImplementedError 49 | 50 | self.out_dir.mkdir(exist_ok=True) 51 | # Load anet bbox 52 | with open(self.anet_ent_annot_file) as f: 53 | self.anet_ent_captions = json.load(f) 54 | 55 | # trn_df = pd.read_csv(self.cfg.ds.trn_ann_file) 56 | # val_df = pd.read_csv(self.cfg.ds.val_ann_file) 57 | 58 | # self.req_df = pd.concat([trn_df, val_df]) 59 | 60 | def do_for_all_vid_seg(self, save=True): 61 | recall_num = 0 62 | recall_tot = 0 63 | self.cfg.no_gt_count = 0 64 | for row_num, vid_seg_row in tqdm(self.req_df.iterrows(), 65 | total=len(self.req_df)): 66 | vid_seg_id = vid_seg_row['id'] 67 | vid_seg = vid_seg_id.split('_segment_') 68 | vid = vid_seg[0] 69 | seg = str(int(vid_seg[1])) 70 | 71 | annot = self.anet_ent_captions[vid]['segments'][seg] 72 | gt_boxs = annot['bbox'] 73 | gt_frms = annot['frm_idx'] 74 | 75 | prop_index = vid_seg_row['Index'] 76 | props = copy.deepcopy(self.label_proposals[prop_index]) 77 | num_props = int(copy.deepcopy(self.num_proposals[prop_index])) 78 | 79 | if num_props < 1000: 80 | # import pdb 81 | # pdb.set_trace() 82 | assert np.all(props[num_props:, [0, 1, 2, 3]] == 0) 83 | 84 | region_feature_file = self.feature_root / f'{vid_seg_id}.npy' 85 | # if save: 86 | prop_feats_load = np.load(region_feature_file) 87 | prop_feats = np.zeros((10, *prop_feats_load.shape[1:])) 88 | prop_feats[:prop_feats_load.shape[0]] = prop_feats_load 89 | # prop_feats = prop_feats.reshape(-1, prop_feats.shape[2]).copy() 90 | # prop_feats = prop_feats[:num_props, ...] 
91 | # assert len(prop_feats) == len(props) 92 | # assert len(props) == num_props 93 | 94 | # else: 95 | # prop_feats = None 96 | 97 | out_file = self.out_dir / f'{vid_seg_id}.npy' 98 | # out_dict = self.do_for_one_vid_seg( 99 | # props, prop_feats, gt_boxs, gt_frms, out_file, 100 | # save=save 101 | # ) 102 | nppf = self.cfg.ds.ngt_prop 103 | out_dict = self.prop10_one_vid_seg( 104 | props, prop_feats, gt_boxs, gt_frms, out_file, 105 | save=save, nppf=nppf 106 | ) 107 | # out_dict = self.no_gt_prop10_one_vid_seg( 108 | # props, prop_feats, gt_boxs, gt_frms, out_file, 109 | # save=save 110 | # ) 111 | 112 | if save: 113 | num_prop = out_dict['num_prop'] 114 | self.out_label_proposals[prop_index][:num_prop] = ( 115 | out_dict['out_props'] 116 | ) 117 | self.out_num_proposals[prop_index] = num_prop 118 | 119 | recall_num += out_dict['recall'] 120 | recall_tot += out_dict['num_gt'] 121 | # if row_num > 1000: 122 | # break 123 | recall = recall_num.item() / recall_tot 124 | if save: 125 | with h5py.File(self.out_proposal_h5, 'w') as f: 126 | f['dets_labels'] = self.out_label_proposals 127 | f['dets_num'] = self.out_num_proposals 128 | return recall 129 | 130 | def prop10_one_vid_seg(self, props, prop_feats, 131 | gt_boxs, gt_frms, out_file, 132 | save=True, nppf=10): 133 | nfrms = 10 134 | props = torch.tensor(props).float() 135 | prop_feats = torch.tensor(prop_feats).float() 136 | # gt_frms_dict = {} 137 | # for gfrm, gbox in zip(gt_frms, gt_boxs): 138 | # if gfrm not in gt_frms_dict: 139 | # gt_frms_dict[gfrm] = [] 140 | # gt_frms_dict[gfrm].append(gbox) 141 | gt_frms_set = set(gt_frms) 142 | gt_boxs = torch.tensor(gt_boxs).float() 143 | gt_frms = torch.tensor(gt_frms).float() 144 | 145 | ngt = len(gt_boxs) 146 | 147 | nppf = nppf 148 | 149 | prop_frms = props[:, 4] 150 | frm_msk = prop_frms[:, None] == gt_frms 151 | if len(gt_boxs) > 0 and len(props) > 0: 152 | ious = box_iou(props[:, :4], gt_boxs) * frm_msk.float() 153 | ious_max, ious_arg_max = ious.max(dim=0) 154 | if len(ious_arg_max) > nppf: 155 | ious_arg_max = ious_arg_max[:nppf] 156 | out_props = props[ious_arg_max] 157 | out_props_inds = ious_arg_max % 100 158 | recall = (ious_max > 0.5).sum() 159 | else: 160 | self.cfg.no_gt_count += 1 161 | ngt = 1 162 | recall = 0 163 | ious = torch.zeros(props.size(0), 1) 164 | out_props = props[0] 165 | out_props_inds = torch.tensor(0) 166 | 167 | fin_out_props = {} 168 | props1 = props.view(10, 100, 7) 169 | prop_dim = prop_feats.size(-1) 170 | prop_feats1 = prop_feats.view(10, 100, prop_dim) 171 | 172 | for frm in range(nfrms): 173 | if frm not in fin_out_props: 174 | fin_out_props[frm] = [] 175 | 176 | if frm in gt_frms_set: 177 | props_inds_gt_in_frm = out_props_inds[out_props[..., 4] == frm] 178 | fin_out_props[frm] += props_inds_gt_in_frm.tolist() 179 | 180 | props_to_use_inds = props1[frm, ..., 6].argsort(descending=True)[ 181 | :nppf] 182 | # props_to_use_inds = np.random.choice( 183 | # np.arange(100), size=10, replace=False 184 | # ) 185 | fin_out_props[frm] += props_to_use_inds.tolist() 186 | 187 | fin_out_props[frm] = list( 188 | OrderedDict.fromkeys(fin_out_props[frm]))[:nppf] 189 | 190 | props_output = torch.zeros(10, nppf, 7) 191 | prop_feats_output = torch.zeros(10, nppf, prop_dim) 192 | 193 | for frm in fin_out_props: 194 | inds = fin_out_props[frm] 195 | props_output[frm] = props1[frm][inds] 196 | prop_feats_output[frm] = prop_feats1[frm][inds] 197 | 198 | props_output = props_output.view(10*nppf, 7).detach().cpu().numpy() 199 | prop_feats_output = prop_feats_output.view( 200 
| 10, nppf, prop_dim).detach().cpu().numpy() 201 | 202 | if save: 203 | np.save(out_file, prop_feats_output) 204 | 205 | return { 206 | 'out_props': props_output, 207 | 'recall': recall, 208 | 'num_prop': 100, 209 | 'num_gt': ngt 210 | } 211 | 212 | def no_gt_prop10_one_vid_seg(self, props, prop_feats, 213 | gt_boxs, gt_frms, out_file, 214 | save=True): 215 | nfrms = 10 216 | props = torch.tensor(props).float() 217 | prop_feats = torch.tensor(prop_feats).float() 218 | # gt_frms_dict = {} 219 | # for gfrm, gbox in zip(gt_frms, gt_boxs): 220 | # if gfrm not in gt_frms_dict: 221 | # gt_frms_dict[gfrm] = [] 222 | # gt_frms_dict[gfrm].append(gbox) 223 | gt_frms_set = set(gt_frms) 224 | gt_boxs = torch.tensor(gt_boxs).float() 225 | gt_frms = torch.tensor(gt_frms).float() 226 | 227 | ngt = len(gt_boxs) 228 | 229 | nppf = 100 230 | 231 | fin_out_props = {} 232 | props1 = props.view(10, 100, 7) 233 | prop_dim = prop_feats.size(-1) 234 | prop_feats1 = prop_feats.view(10, 100, prop_dim) 235 | 236 | for frm in range(nfrms): 237 | if frm not in fin_out_props: 238 | fin_out_props[frm] = [] 239 | 240 | # if frm in gt_frms_set: 241 | # props_inds_gt_in_frm = out_props_inds[out_props[..., 4] == frm] 242 | # fin_out_props[frm] += props_inds_gt_in_frm.tolist() 243 | props_to_use_inds = props1[frm, ..., 6].argsort(descending=True)[ 244 | :nppf] 245 | fin_out_props[frm] += props_to_use_inds.tolist() 246 | 247 | fin_out_props[frm] = list( 248 | OrderedDict.fromkeys(fin_out_props[frm]))[:nppf] 249 | 250 | props_output = torch.zeros(10, nppf, 7) 251 | prop_feats_output = torch.zeros(10, nppf, prop_dim) 252 | 253 | for frm in fin_out_props: 254 | inds = fin_out_props[frm] 255 | props_output[frm] = props1[frm][inds] 256 | prop_feats_output[frm] = prop_feats1[frm][inds] 257 | 258 | props_output = props_output.view(nfrms * nppf, 7) 259 | prop_feats_output = prop_feats_output.view( 260 | nfrms, nppf, prop_dim).detach().cpu().numpy() 261 | 262 | if len(gt_boxs) > 0 and len(props_output) > 0: 263 | prop_frms = props_output[:, 4] 264 | frm_msk = prop_frms[:, None] == gt_frms 265 | ious = box_iou(props_output[:, :4], gt_boxs) * frm_msk.float() 266 | ious_max, ious_arg_max = ious.max(dim=0) 267 | recall = (ious_max > 0.5).sum() 268 | else: 269 | self.cfg.no_gt_count += 1 270 | ngt = 1 271 | recall = 0 272 | ious = torch.zeros(props.size(0), 1) 273 | 274 | props_output = props_output.detach().cpu().numpy() 275 | 276 | if save: 277 | np.save(out_file, prop_feats_output) 278 | 279 | return { 280 | 'out_props': props_output, 281 | 'recall': recall, 282 | 'num_prop': 100, 283 | 'num_gt': ngt 284 | } 285 | 286 | def do_for_one_vid_seg(self, props, prop_feats, 287 | gt_boxs, gt_frms, out_file, 288 | save=True): 289 | """ 290 | props: all the proposal boxes 291 | gt_boxs: all the groundtruth_boxes 292 | out_props: props with highest IoU with gt_box 293 | # nframes x 1, 294 | one-to-one correspondence 295 | Also, used to calculate recall. 
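        Returns a dict with keys 'out_props', 'recall', 'num_prop' and 'num_gt'.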
296 | """ 297 | props = torch.tensor(props).float() 298 | gt_boxs = torch.tensor(gt_boxs).float() 299 | gt_frms = torch.tensor(gt_frms).float() 300 | 301 | ngt = len(gt_boxs) 302 | 303 | prop_frms = props[:, 4] 304 | frm_msk = prop_frms[:, None] == gt_frms 305 | 306 | if len(gt_boxs) > 0 and len(props) > 0: 307 | ious = box_iou(props[:, :4], gt_boxs) * frm_msk.float() 308 | ious_max, ious_arg_max = ious.max(dim=0) 309 | recall = (ious_max > 0.5).sum().float() 310 | out_props = props[ious_arg_max] 311 | else: 312 | self.cfg.no_gt_count += 1 313 | ngt = 1 314 | recall = 0 315 | ious = torch.zeros(props.size(0), 1) 316 | out_props = props[0] 317 | 318 | nprop = ngt 319 | if save: 320 | prop_dim = prop_feats.size(-1) 321 | prop_feats = torch.tensor(prop_feats).float() 322 | out_prop_feats = prop_feats[ious_arg_max].view( 323 | 1, ngt, prop_dim).detach().cpu().numpy() 324 | assert list(out_prop_feats.shape[:2]) == [1, ngt] 325 | np.save(out_file, out_prop_feats) 326 | 327 | return { 328 | 'out_props': out_props, 329 | 'recall': recall, 330 | 'num_prop': nprop, 331 | 'num_gt': ngt 332 | } 333 | 334 | 335 | if __name__ == '__main__': 336 | from extended_config import cfg as conf 337 | cfg = conf 338 | gtp = GTPropExtractor(cfg) 339 | recall = gtp.do_for_all_vid_seg(save=True) 340 | print(recall) 341 | -------------------------------------------------------------------------------- /dcode/sem_role_labeller.py: -------------------------------------------------------------------------------- 1 | """ 2 | Perform semantic role labeling for input captions 3 | """ 4 | from allennlp.predictors.predictor import Predictor 5 | import pandas as pd 6 | import pickle 7 | import json 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | import yaml 11 | from yacs.config import CfgNode as CN 12 | import time 13 | import fire 14 | import re 15 | from typing import List, Dict, Any, Union 16 | 17 | SRL_BERT = ( 18 | "https://s3-us-west-2.amazonaws.com/allennlp/models/bert-base-srl-2019.06.17.tar.gz") 19 | 20 | srl_out_patt = re.compile(r'\[(.*?)\]') 21 | 22 | Fpath = Union[Path, str] 23 | Cft = CN 24 | 25 | 26 | class SRL_DS: 27 | """ 28 | A base class to perform semantic role labeling 29 | """ 30 | 31 | def __init__(self, cfg: Cft, tdir: str = '.'): 32 | self.cfg = cfg 33 | self.tdir = Path(tdir) 34 | archive_path = SRL_BERT 35 | self.srl = Predictor.from_path( 36 | archive_path=archive_path, 37 | predictor_name='semantic-role-labeling', 38 | cuda_device=0 39 | ) 40 | # self.srl = Predictor.from_path( 41 | # "https://s3-us-west-2.amazonaws.com/allennlp/models/srl-model-2018.05.25.tar.gz") 42 | self.name = self.__class__.__name__ 43 | self.cache_dir = self.tdir / \ 44 | Path(f'{self.cfg.misc.cache_dir}/{self.name}') 45 | self.cache_dir.mkdir(exist_ok=True, parents=True) 46 | self.out_file = (self.cache_dir / f'{self.cfg.ds.srl_bert}') 47 | self.after_init() 48 | 49 | def after_init(self): 50 | pass 51 | 52 | def get_annotations(self) -> pd.DataFrame: 53 | """ 54 | Expected to read a file, 55 | Create a df with the columns: 56 | vid_id, seg_id, sentence 57 | """ 58 | raise NotImplementedError 59 | 60 | def do_predictions(self): 61 | annot_df = self.get_annotations() 62 | sents = annot_df.to_dict('records') 63 | st_time = time.time() 64 | out_list = [] 65 | tot_len = len(sents) 66 | try: 67 | for idx in tqdm(range(0, len(sents), 50)): 68 | out = self.srl.predict_batch_json( 69 | sents[idx:min(idx+50, tot_len)]) 70 | out_list += out 71 | except RuntimeError: 72 | pass 73 | finally: 74 | end_time = time.time() 75 
| print(f'Took time {end_time-st_time}') 76 | pickle.dump(out_list, open(self.out_file, 'wb')) 77 | self.update_preds() 78 | 79 | def update_preds(self): 80 | preds = pickle.load(open(self.out_file, 'rb')) 81 | for pred in tqdm(preds): 82 | for verb in pred['verbs']: 83 | verb['req_pat'] = srl_out_patt.findall(verb['description']) 84 | pickle.dump(preds, open(self.out_file, 'wb')) 85 | 86 | 87 | class SRL_Anet(SRL_DS): 88 | def after_init(self): 89 | """ 90 | Assert files exists 91 | """ 92 | # Assert Raw Caption Files exists 93 | self.trn_anet_cap_file = self.tdir / Path(self.cfg.ds.anet_cap_file) 94 | assert self.trn_anet_cap_file.exists() 95 | 96 | def get_annotations(self): 97 | trn_cap_data = json.load(open(self.trn_anet_cap_file)) 98 | trn_vid_list = list(trn_cap_data.keys()) 99 | out_dict_list = [] 100 | for trn_vid_name in tqdm(trn_vid_list): 101 | trn_vid_segs = trn_cap_data[trn_vid_name] 102 | num_segs = len(trn_vid_segs['timestamps']) 103 | for seg in range(num_segs): 104 | out_dict = { 105 | 'time_stamp': trn_vid_segs['timestamps'][seg], 106 | 'vid': trn_vid_name, 107 | 'vid_seg': f'{trn_vid_name}_segment_{seg:02d}', 108 | 'segment': seg, 109 | 'sentence': trn_vid_segs['sentences'][seg] 110 | } 111 | out_dict_list.append(out_dict) 112 | out_df = pd.DataFrame(out_dict_list) 113 | out_df.to_csv( 114 | ( 115 | self.cache_dir / 116 | f'{self.cfg.ds.srl_caps}' 117 | ), 118 | header=True, index=False 119 | ) 120 | return out_df 121 | 122 | 123 | def main(): 124 | cfg_file = './configs/create_asrl_cfg.yml' 125 | cfg = CN(yaml.safe_load(open(cfg_file))) 126 | print(cfg) 127 | srl_ds = SRL_Anet(cfg) 128 | srl_ds.do_predictions() 129 | 130 | 131 | if __name__ == '__main__': 132 | main() 133 | # fire.Fire(main) 134 | -------------------------------------------------------------------------------- /media/Intro_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheShadow29/vognet-pytorch/238e93c37cf9f03a2fd376a14760bb3d334a113d/media/Intro_fig.png -------------------------------------------------------------------------------- /media/contrastive_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheShadow29/vognet-pytorch/238e93c37cf9f03a2fd376a14760bb3d334a113d/media/contrastive_examples.png -------------------------------------------------------------------------------- /media/contrastive_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheShadow29/vognet-pytorch/238e93c37cf9f03a2fd376a14760bb3d334a113d/media/contrastive_samples.png -------------------------------------------------------------------------------- /media/model_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheShadow29/vognet-pytorch/238e93c37cf9f03a2fd376a14760bb3d334a113d/media/model_fig.png -------------------------------------------------------------------------------- /media/tempora_spatial_concat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheShadow29/vognet-pytorch/238e93c37cf9f03a2fd376a14760bb3d334a113d/media/tempora_spatial_concat.png -------------------------------------------------------------------------------- /notebooks/data_stats.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | 
"cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%matplotlib inline" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 3, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import altair as alt" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 4, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "/home/arka/Ark_git_files/vognet-pytorch\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "cd .." 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 5, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "import sys\n", 55 | "sys.path.append('./dcode')\n", 56 | "sys.path.append('./code')\n", 57 | "sys.path.append('./utils')" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 6, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "from dataset_stats import AnetSRL_Vis" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 7, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "from yacs.config import CfgNode as CN\n", 76 | "import yaml\n", 77 | "cfg = CN(yaml.safe_load(open('./configs/create_asrl_cfg.yml')))" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 8, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "avis = AnetSRL_Vis(cfg)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 9, 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "data": { 96 | "text/plain": [ 97 | "Index(['gt_bboxes', 'gt_frms', 'lemma_ARG0', 'lemma_ARG1', 'lemma_ARG2',\n", 98 | " 'lemma_ARGM_LOC', 'lemma_verb', 'process_clss', 'process_idx2',\n", 99 | " 'req_aname', 'req_args', 'req_cls_pats', 'req_cls_pats_mask', 'req_pat',\n", 100 | " 'req_pat_ix', 'sent', 'tags', 'verb', 'vid_seg', 'words', 'ann_ind',\n", 101 | " 'srl_ind', 'vt_split', 'DS4_Inds', 'ds4_msk', 'RandDS4_Inds'],\n", 102 | " dtype='object')" 103 | ] 104 | }, 105 | "execution_count": 9, 106 | "metadata": {}, 107 | "output_type": "execute_result" 108 | } 109 | ], 110 | "source": [ 111 | "avis.trn_srl_annots.columns" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "len(avis.trn_srl_annots.vid_seg.unique())" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "len(avis.val_srl_annots[avis.val_srl_annots.vt_split == 'val'].vid_seg.unique())" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "avis.vis=True" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 21, 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "data": { 148 | "text/plain": [ 149 | "8.049772762889829" 150 | ] 151 | }, 152 | "execution_count": 21, 153 | "metadata": {}, 154 | "output_type": "execute_result" 155 | } 156 | ], 157 | "source": [ 158 | "avis.trn_srl_annots.req_pat_ix.apply(lambda x: sum([len(y[1]) for y in x])).mean()" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 22, 164 | "metadata": {}, 165 | 
"outputs": [ 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "Number of videos in Train/Valid/Test: 31718, 3891, 3914\n", 171 | "Number of Queries per Video is 2.0117914118166342\n", 172 | "Number of Queries per Video is 3.455868986052343\n", 173 | "Number of Queries per Video is 8.049772762889829\n", 174 | "Noun Phrases Count\n", 175 | "Arg,Count\n", 176 | "V,63812\n", 177 | "ARG0,48342\n", 178 | "ARG1,47335\n", 179 | "ARG2,16200\n", 180 | "ARGM-TMP,12061\n", 181 | "ARGM-DIR,8876\n", 182 | "ARGM-LOC,7408\n", 183 | "ARGM-MNR,5702\n", 184 | "ARGM-ADV,3661\n", 185 | "ARGM-PRP,1417\n", 186 | "ARG4,1238\n", 187 | "ARGM-PRD,905\n", 188 | "ARG3,854\n", 189 | "ARGM-GOL,447\n", 190 | "R-ARG0,423\n", 191 | "ARGM-COM,314\n", 192 | "R-ARG1,303\n", 193 | "C-ARG1,286\n", 194 | "ARGM-EXT,188\n", 195 | "ARGM-DIS,145\n", 196 | "\n", 197 | "Groundable Noun Phrase Count\n", 198 | "Arg,Count\n", 199 | "ARG0,42472\n", 200 | "ARG1,32455\n", 201 | "ARG2,9520\n", 202 | "ARGM-LOC,5082\n", 203 | "ARGM-TMP,3505\n", 204 | "ARGM-DIR,2936\n", 205 | "ARGM-MNR,2168\n", 206 | "ARGM-ADV,2036\n", 207 | "ARG4,947\n", 208 | "ARGM-PRP,690\n", 209 | "ARG3,538\n", 210 | "ARGM-GOL,310\n", 211 | "ARGM-PRD,298\n", 212 | "V,256\n", 213 | "ARGM-COM,209\n", 214 | "C-ARG1,186\n", 215 | "C-ARG0,26\n", 216 | "ARGM-CAU,20\n", 217 | "ARGM-PNC,17\n", 218 | "ARGM-EXT,13\n", 219 | "\n", 220 | "SRL Structures Frequency\n", 221 | "Arg,Count\n", 222 | "ARG0-V-ARG1,13654\n", 223 | "ARG0-V-ARG1-ARG2,3372\n", 224 | "ARG1-V-ARG2,3135\n", 225 | "ARG0-V,3080\n", 226 | "ARG0-V-ARGM-DIR,2269\n", 227 | "ARG0-V-ARG2,2075\n", 228 | "ARG0-V-ARG1-ARGM-LOC,1689\n", 229 | "ARG1-V,1631\n", 230 | "V-ARG1,1383\n", 231 | "ARG0-V-ARGM-LOC,1358\n", 232 | "ARG0-V-ARG1-ARGM-TMP,1290\n", 233 | "ARG0-V-ARG1-ARGM-MNR,862\n", 234 | "ARG0-V-ARG1-ARGM-DIR,838\n", 235 | "ARG0-ARGM-TMP-V-ARG1,754\n", 236 | "ARG1-V-ARG2-ARGM-ADV,743\n", 237 | "ARG1-V-ARGM-DIR,735\n", 238 | "ARGM-TMP-ARG0-V-ARG1,729\n", 239 | "ARG1-V-ARGM-LOC,586\n", 240 | "ARG0-V-ARGM-TMP,558\n", 241 | "ARG2-V-ARG1,526\n", 242 | "\n", 243 | "Lemmatized Counts for each lemma: {'lemma_verb': 338, 'lemma_ARG0': 93, 'lemma_ARG1': 281, 'lemma_ARG2': 114, 'lemma_ARGM_LOC': 59}\n", 244 | "Most Frequent Lemmas for lemma_verb\n", 245 | "String,Count\n", 246 | "stand,2395\n", 247 | "play,2152\n", 248 | "hold,1662\n", 249 | "talk,1626\n", 250 | "put,1458\n", 251 | "sit,1402\n", 252 | "speak,1190\n", 253 | "use,1057\n", 254 | "run,1053\n", 255 | "take,993\n", 256 | "walk,990\n", 257 | "throw,945\n", 258 | "go,930\n", 259 | "ride,906\n", 260 | "move,904\n", 261 | "walks,803\n", 262 | "wear,765\n", 263 | "get,740\n", 264 | "do,737\n", 265 | "look,714\n", 266 | "hit,690\n", 267 | "\n", 268 | "Most Frequent Lemmas for lemma_ARG0\n", 269 | "String,Count\n", 270 | ",21439\n", 271 | "man,8252\n", 272 | "he,7973\n", 273 | "woman,4095\n", 274 | "she,4081\n", 275 | "people,3360\n", 276 | "they,2048\n", 277 | "person,1785\n", 278 | "girl,1067\n", 279 | "boy,1053\n", 280 | "lady,789\n", 281 | "player,436\n", 282 | "dog,372\n", 283 | "child,360\n", 284 | "team,339\n", 285 | "kid,337\n", 286 | "athlete,272\n", 287 | "shirt,263\n", 288 | "guy,250\n", 289 | "gymnast,218\n", 290 | "other,196\n", 291 | "\n", 292 | "Most Frequent Lemmas for lemma_ARG1\n", 293 | "String,Count\n", 294 | ",31440\n", 295 | "he,2459\n", 296 | "man,1967\n", 297 | "it,1433\n", 298 | "woman,1163\n", 299 | "she,1132\n", 300 | "people,1097\n", 301 | "ball,967\n", 302 | "they,792\n", 303 | "hand,467\n", 304 | "hair,413\n", 305 | 
"person,380\n", 306 | "dog,354\n", 307 | "car,318\n", 308 | "girl,317\n", 309 | "screen,300\n", 310 | "water,298\n", 311 | "boy,297\n", 312 | "rope,269\n", 313 | "shoe,245\n", 314 | "shirt,235\n", 315 | "\n", 316 | "Most Frequent Lemmas for lemma_ARG2\n", 317 | "String,Count\n", 318 | ",54374\n", 319 | "he,525\n", 320 | "table,322\n", 321 | "she,259\n", 322 | "woman,244\n", 323 | "water,204\n", 324 | "man,203\n", 325 | "it,186\n", 326 | "people,176\n", 327 | "floor,151\n", 328 | "field,149\n", 329 | "room,148\n", 330 | "wall,144\n", 331 | "board,141\n", 332 | "car,139\n", 333 | "ground,128\n", 334 | "chair,114\n", 335 | "tree,101\n", 336 | "bar,100\n", 337 | "they,100\n", 338 | "ball,94\n", 339 | "\n", 340 | "Most Frequent Lemmas for lemma_ARGM_LOC\n", 341 | "String,Count\n", 342 | ",58825\n", 343 | "he,243\n", 344 | "water,238\n", 345 | "room,230\n", 346 | "field,200\n", 347 | "screen,134\n", 348 | "gym,128\n", 349 | "stage,124\n", 350 | "table,106\n", 351 | "floor,99\n", 352 | "court,91\n", 353 | "bar,87\n", 354 | "beach,86\n", 355 | "street,84\n", 356 | "pool,80\n", 357 | "she,79\n", 358 | "board,77\n", 359 | "mat,75\n", 360 | "woman,72\n", 361 | "ground,68\n", 362 | "track,65\n", 363 | "\n" 364 | ] 365 | } 366 | ], 367 | "source": [ 368 | "vlist = avis.print_all_stats()" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "vlist" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "metadata": {}, 384 | "outputs": [], 385 | "source": [ 386 | "vlist[2]" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "out = avis.visualize_df(nnp_srl, x_name='Count:Q', y_name='Arg:O')" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": null, 401 | "metadata": {}, 402 | "outputs": [], 403 | "source": [ 404 | "out" 405 | ] 406 | } 407 | ], 408 | "metadata": { 409 | "kernelspec": { 410 | "display_name": "Python 3", 411 | "language": "python", 412 | "name": "python3" 413 | }, 414 | "language_info": { 415 | "codemirror_mode": { 416 | "name": "ipython", 417 | "version": 3 418 | }, 419 | "file_extension": ".py", 420 | "mimetype": "text/x-python", 421 | "name": "python", 422 | "nbconvert_exporter": "python", 423 | "pygments_lexer": "ipython3", 424 | "version": "3.7.3" 425 | } 426 | }, 427 | "nbformat": 4, 428 | "nbformat_minor": 2 429 | } 430 | -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # File Organization 2 | 3 | 1. `box_utils.py` as the name duly suggest, contains utils for box iou, box area computation. Also, contains code for box iou for multiple frames 4 | 1. `mdl_srl_utils.py` has convenience functions for the models (surprise surprise). Stuff like LSTM implementation adapted from fairseq. 5 | 1. `trn_utils.py` contains learner which handles the model saving/loading, logging stuff, saving predictions among other things. 
6 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheShadow29/vognet-pytorch/238e93c37cf9f03a2fd376a14760bb3d334a113d/utils/__init__.py -------------------------------------------------------------------------------- /utils/box_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions for boxes 3 | Adapted from 4 | https://github.com/facebookresearch/maskrcnn-benchmark/ 5 | blob/master/maskrcnn_benchmark/structures/boxlist_ops.py 6 | """ 7 | import torch 8 | 9 | TO_REMOVE = 0 10 | 11 | 12 | def get_area(box): 13 | """ 14 | box: [N, 4] 15 | torch.tensor of 16 | type x1y1x2y2 17 | """ 18 | area = ( 19 | (box[:, 2] - box[:, 0] + TO_REMOVE) * 20 | (box[:, 3] - box[:, 1] + TO_REMOVE) 21 | ) 22 | return area 23 | 24 | 25 | def box_iou(box1, box2): 26 | """ 27 | box1: [N, 4] 28 | box2: [M, 4] 29 | both of type torch.tensor 30 | Assumes both of type x1y1x2y2 31 | output: [N,M] 32 | """ 33 | if len(box1.shape) == 1 and len(box1) == 4: 34 | box1 = box1.unsqueeze(0) 35 | if len(box2.shape) == 1 and len(box2) == 4: 36 | box2 = box2.unsqueeze(0) 37 | 38 | N = len(box1) 39 | M = len(box2) 40 | 41 | area1 = get_area(box1) 42 | area2 = get_area(box2) 43 | 44 | lt = torch.max(box1[:, None, :2], box2[:, :2]) # [N,M,2] 45 | rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # [N,M,2] 46 | 47 | wh = (rb - lt + TO_REMOVE).clamp(min=0) # [N,M,2] 48 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 49 | 50 | iou = inter / (area1[:, None] + area2 - inter) 51 | return iou 52 | 53 | 54 | def bbox_overlaps(rois, gt_box, frm_mask): 55 | 56 | overlaps = bbox_overlaps_batch(rois[:, :, :5], gt_box[:, :, :5], frm_mask) 57 | 58 | return overlaps 59 | 60 | 61 | def bbox_overlaps_batch(anchors, gt_boxes, frm_mask=None): 62 | """ 63 | Source: 64 | https://github.com/facebookresearch/grounded-video-description/blob/ 65 | master/misc/bbox_transform.py#L176 66 | anchors: (b, N, 4) ndarray of float 67 | gt_boxes: (b, K, 5) ndarray of float 68 | frm_mask: (b, N, K) ndarray of bool 69 | 70 | overlaps: (b, N, K) ndarray of overlap between boxes and query_boxes 71 | """ 72 | batch_size = gt_boxes.size(0) 73 | 74 | N = anchors.size(1) 75 | K = gt_boxes.size(1) 76 | 77 | anchors = anchors[:, :, :5].contiguous() 78 | gt_boxes = gt_boxes[:, :, :5].contiguous() 79 | 80 | gt_boxes_x = (gt_boxes[:, :, 2] - gt_boxes[:, :, 0] + 1) 81 | gt_boxes_y = (gt_boxes[:, :, 3] - gt_boxes[:, :, 1] + 1) 82 | gt_boxes_area = (gt_boxes_x * gt_boxes_y).view(batch_size, 1, K) 83 | 84 | anchors_boxes_x = (anchors[:, :, 2] - anchors[:, :, 0] + 1) 85 | anchors_boxes_y = (anchors[:, :, 3] - anchors[:, :, 1] + 1) 86 | anchors_area = (anchors_boxes_x * 87 | anchors_boxes_y).view(batch_size, N, 1) 88 | 89 | gt_area_zero = (gt_boxes_x == 1) & (gt_boxes_y == 1) 90 | anchors_area_zero = (anchors_boxes_x == 1) & (anchors_boxes_y == 1) 91 | 92 | boxes = anchors.view(batch_size, N, 1, 5).expand(batch_size, N, K, 5) 93 | query_boxes = gt_boxes.view( 94 | batch_size, 1, K, 5).expand(batch_size, N, K, 5) 95 | 96 | iw = (torch.min(boxes[:, :, :, 2], query_boxes[:, :, :, 2]) - 97 | torch.max(boxes[:, :, :, 0], query_boxes[:, :, :, 0]) + 1) 98 | iw[iw < 0] = 0 99 | 100 | ih = (torch.min(boxes[:, :, :, 3], query_boxes[:, :, :, 3]) - 101 | torch.max(boxes[:, :, :, 1], query_boxes[:, :, :, 1]) + 1) 102 | ih[ih < 0] = 0 103 | ua = anchors_area + gt_boxes_area - 
(iw * ih) 104 | 105 | if frm_mask is not None: 106 | # proposal and gt should be on the same frame to overlap 107 | # print('Percentage of proposals that are in the annotated frame: {}'.format(torch.mean(frm_mask.float()))) 108 | 109 | overlaps = iw * ih / ua 110 | overlaps *= frm_mask.type(overlaps.type()) 111 | 112 | # mask the overlap here. 113 | overlaps.masked_fill_(gt_area_zero.view( 114 | batch_size, 1, K).expand(batch_size, N, K), 0) 115 | overlaps.masked_fill_(anchors_area_zero.view( 116 | batch_size, N, 1).expand(batch_size, N, K), -1) 117 | 118 | return overlaps 119 | -------------------------------------------------------------------------------- /utils/mdl_srl_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some helpful functions/classes are defined 3 | """ 4 | import torch 5 | from torch import nn 6 | # from fairseq.models import FairseqEncoder 7 | from torch.nn import functional as F 8 | from fairseq import utils 9 | 10 | 11 | def combine_first_ax(inp_tensor, keepdim=False): 12 | inp_shape = inp_tensor.shape 13 | if keepdim: 14 | return inp_tensor.view( 15 | 1, inp_shape[0] * inp_shape[1], *inp_shape[2:]) 16 | return inp_tensor.view( 17 | inp_shape[0] * inp_shape[1], *inp_shape[2:]) 18 | 19 | 20 | def uncombine_first_ax(inp_tensor, s0): 21 | "s0 is the size(0) intended, usually B" 22 | inp_shape = inp_tensor.shape 23 | size0 = inp_tensor.size(0) 24 | assert size0 % s0 == 0 25 | s1 = size0 // s0 26 | return inp_tensor.view( 27 | s0, s1, *inp_shape[1:]) 28 | 29 | 30 | def do_cross(x1, x2=None, dim1=1, op='add'): 31 | """ 32 | if x2 is none do x1(row) + x1(col) 33 | else x1(row) + x2(col) 34 | dim1, dim2 are first and second dimension 35 | to be used for crossing. 36 | Both x1, x2 should have same shape except 37 | for at most one dimension 38 | 39 | if input is B x C x D x E with dim1=1 40 | B x C x D x E -> 41 | B x C x 1 x D x E -> B x C x C x D x E; 42 | B x 1 x C x D x E -> B x C x C x D x E; 43 | and then add 44 | 45 | op = 'add', 'subt', 'mult' or 'concat' 46 | """ 47 | x1_shape = x1.shape 48 | if x2 is None: 49 | x2 = x1 50 | assert x1.shape == x2.shape 51 | x1_dim = len(x1_shape) 52 | out_shape = tuple((*x1_shape[:dim1], x1_shape[dim1], *x1_shape[dim1:])) 53 | if dim1 < x1_dim: 54 | x1_row = x1.view(*x1_shape[:dim1+1], 1, * 55 | x1_shape[dim1+1:]).expand(out_shape) 56 | x2_col = x2.view(*x1_shape[:dim1], 1, * 57 | x1_shape[dim1:]).expand(out_shape) 58 | else: 59 | x1_row = x1.view(*x1_shape[:dim1+1], 1) 60 | x2_col = x2.view(*x1_shape[:dim1], 1, x1_shape[dim1]) 61 | 62 | if op == 'add': 63 | return (x1_row + x2_col) / 2 64 | elif op == 'mult': 65 | return (x1_row * x2_col) 66 | elif op == 'concat': 67 | return torch.cat([x1_row, x2_col], dim=-1) 68 | elif op == 'subtract': 69 | return (x1_row - x2_col) 70 | 71 | 72 | class LSTMEncoder(nn.Module): 73 | """LSTM encoder.""" 74 | 75 | def __init__( 76 | self, cfg, comm, embed_dim=512, hidden_size=512, num_layers=1, 77 | dropout_in=0.1, dropout_out=0.1, bidirectional=False, 78 | left_pad=True, pretrained_embed=None, padding_value=0., 79 | num_embeddings=0, pad_idx=0 80 | ): 81 | super().__init__() 82 | self.cfg = cfg 83 | self.comm = comm 84 | self.num_layers = num_layers 85 | self.dropout_in = dropout_in 86 | self.dropout_out = dropout_out 87 | self.bidirectional = bidirectional 88 | self.hidden_size = hidden_size 89 | 90 | num_embeddings = num_embeddings 91 | self.padding_idx = pad_idx 92 | embed_dim1 = embed_dim 93 | if pretrained_embed is None: 94 | 
self.embed_tokens = nn.Embedding( 95 | num_embeddings, embed_dim1, self.padding_idx 96 | ) 97 | else: 98 | self.embed_tokens = pretrained_embed 99 | 100 | self.lstm = nn.LSTM( 101 | input_size=embed_dim, 102 | hidden_size=hidden_size, 103 | num_layers=num_layers, 104 | dropout=self.dropout_out if num_layers > 1 else 0., 105 | bidirectional=bidirectional, 106 | ) 107 | self.left_pad = left_pad 108 | self.padding_value = padding_value 109 | 110 | self.output_units = hidden_size 111 | if bidirectional: 112 | self.output_units *= 2 113 | 114 | def forward(self, src_tokens, src_lengths): 115 | if self.left_pad: 116 | # nn.utils.rnn.pack_padded_sequence requires right-padding; 117 | # convert left-padding to right-padding 118 | src_tokens = utils.convert_padding_direction( 119 | src_tokens, 120 | self.padding_idx, 121 | left_to_right=True, 122 | ) 123 | 124 | bsz, seqlen = src_tokens.size() 125 | # embed tokens 126 | x = self.embed_tokens(src_tokens) 127 | 128 | x = F.dropout(x, p=self.dropout_in, training=self.training) 129 | 130 | # B x T x C -> T x B x C 131 | x = x.transpose(0, 1) 132 | 133 | # pack embedded source tokens into a PackedSequence 134 | packed_x = nn.utils.rnn.pack_padded_sequence( 135 | x, src_lengths.data.tolist(), enforce_sorted=False) 136 | 137 | # apply LSTM 138 | if self.bidirectional: 139 | state_size = 2 * self.num_layers, bsz, self.hidden_size 140 | else: 141 | state_size = self.num_layers, bsz, self.hidden_size 142 | h0 = x.new_zeros(*state_size) 143 | c0 = x.new_zeros(*state_size) 144 | packed_outs, (final_hiddens, final_cells) = self.lstm( 145 | packed_x, (h0, c0)) 146 | 147 | # unpack outputs and apply dropout 148 | x, _ = nn.utils.rnn.pad_packed_sequence( 149 | packed_outs, padding_value=self.padding_value) 150 | x = F.dropout(x, p=self.dropout_out, training=self.training) 151 | assert list(x.size()) == [seqlen, bsz, self.output_units] 152 | 153 | if self.bidirectional: 154 | 155 | def combine_bidir(outs): 156 | out = outs.view(self.num_layers, 2, bsz, - 157 | 1).transpose(1, 2).contiguous() 158 | return out.view(self.num_layers, bsz, -1) 159 | 160 | final_hiddens = combine_bidir(final_hiddens) 161 | final_cells = combine_bidir(final_cells) 162 | 163 | encoder_padding_mask = src_tokens.eq(self.padding_idx).t() 164 | 165 | return { 166 | 'encoder_out': (x, final_hiddens, final_cells), 167 | 'encoder_padding_mask': (encoder_padding_mask 168 | if encoder_padding_mask.any() else None) 169 | } 170 | 171 | def reorder_only_outputs(self, outputs): 172 | """ 173 | outputs of shape : T x B x C -> B x T x C 174 | """ 175 | return outputs.transpose(0, 1).contiguous() 176 | 177 | def reorder_encoder_out(self, encoder_out, new_order): 178 | encoder_out['encoder_out'] = tuple( 179 | eo.index_select(1, new_order) 180 | for eo in encoder_out['encoder_out'] 181 | ) 182 | if encoder_out['encoder_padding_mask'] is not None: 183 | encoder_out['encoder_padding_mask'] = \ 184 | encoder_out['encoder_padding_mask'].index_select(1, new_order) 185 | return encoder_out 186 | 187 | def max_positions(self): 188 | """Maximum input length supported by the encoder.""" 189 | return int(1e5) # an arbitrary large number 190 | 191 | 192 | class SimpleAttn(nn.Module): 193 | def __init__(self, qdim, hdim): 194 | super().__init__() 195 | self.lin1 = nn.Linear(qdim, hdim) 196 | self.lin2 = nn.Linear(qdim, hdim) 197 | self.lin3 = nn.Linear(hdim, 1) 198 | 199 | def forward(self, qvec, qlast, inp): 200 | """ 201 | qvec: B x nsrl x qdim 202 | qlast: B x 1 x qdim 203 | """ 204 | # B x nv x nsrl x hdim 205 | 
B, num_verbs, nsrl, qd = qvec.shape 206 | qvec_enc = self.lin1(qvec) 207 | # B x nv x 1 x hdim 208 | qlast_enc = self.lin2(qlast) 209 | 210 | hdim = qlast_enc.size(-1) 211 | 212 | # B x nv x nsrl x hdim 213 | q1_enc = torch.tanh( 214 | qvec_enc + 215 | qlast_enc.view( 216 | B, num_verbs, 1, hdim 217 | ).expand( 218 | B, num_verbs, nsrl, hdim 219 | ) 220 | ) 221 | # B x nv x nsrl 222 | u1 = self.lin3(q1_enc).squeeze(-1) 223 | # B x nv x nsrl 224 | attns = F.softmax(u1, dim=-1) 225 | 226 | # B x nv x nsrl x qdim 227 | qvec_out = attns.unsqueeze(-1).expand_as(qvec) * qvec 228 | 229 | return qvec_out 230 | --------------------------------------------------------------------------------
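
As a supplement to `utils/mdl_srl_utils.py` above, the following is a small, hedged example of how the `do_cross` helper pairs every element with every other element along a chosen dimension (the tensor shapes and names are assumptions for illustration only; running it requires the repo environment, since `mdl_srl_utils` imports `fairseq` at module level):

```python
# Illustrative sketch: build pairwise object-object features for one batch.
import torch
from utils.mdl_srl_utils import do_cross  # assumes the repo root is on PYTHONPATH

B, nobj, feat_dim = 2, 3, 4
obj_feats = torch.randn(B, nobj, feat_dim)

# op='concat' places the "row" and "column" copies side by side on the last axis,
# giving a B x nobj x nobj x 2*feat_dim tensor of object-pair features.
pairwise = do_cross(obj_feats, x2=None, dim1=1, op='concat')
assert pairwise.shape == (B, nobj, nobj, 2 * feat_dim)

# op='add' instead averages the two copies, (row + col) / 2, keeping feat_dim unchanged.
pairwise_avg = do_cross(obj_feats, dim1=1, op='add')
assert pairwise_avg.shape == (B, nobj, nobj, feat_dim)
```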