├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── configs ├── lf_disc_faster_rcnn_x101.yml └── lf_gen_faster_rcnn_x101.yml ├── data └── extract_features_detectron.py ├── docker └── Dockerfile ├── evaluate.py ├── pyproject.toml ├── requirements.txt ├── setup.py ├── train.py └── visdialch ├── __init__.py ├── data ├── __init__.py ├── dataset.py ├── readers.py └── vocabulary.py ├── decoders ├── __init__.py ├── disc.py └── gen.py ├── encoders ├── __init__.py └── lf.py ├── metrics.py ├── model.py └── utils ├── __init__.py ├── checkpointing.py └── dynamic_rnn.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # Datasets, pretrained models, checkpoints and preprocessed files 29 | data/ 30 | !visdialch/data/ 31 | checkpoints/ 32 | logs/ 33 | 34 | # IPython Notebook 35 | .ipynb_checkpoints 36 | 37 | # virtualenv 38 | venv/ 39 | ENV/ 40 | 41 | .idea/ 42 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: 19.3b0 4 | hooks: 5 | - id: black 6 | language_version: python3.6 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v2.1.0 9 | hooks: 10 | - id: flake8 11 | - id: trailing-whitespace 12 | - id: check-added-large-files 13 | - id: end-of-file-fixer 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Karan Desai 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Visual Dialog Challenge Starter Code 2 | ==================================== 3 | 4 | PyTorch starter code for the [Visual Dialog Challenge 2019][1]. 5 | 6 | * [Setup and Dependencies](#setup-and-dependencies) 7 | * [Download Data](#download-data) 8 | * [Training](#training) 9 | * [Evaluation](#evaluation) 10 | * [Pretrained Checkpoint](#pretrained-checkpoint) 11 | * [Acknowledgements](#acknowledgements) 12 | 13 | If you use this code in your research, please consider citing: 14 | 15 | ```text 16 | @misc{desai2018visdialch, 17 | author = {Karan Desai and Abhishek Das and Dhruv Batra and Devi Parikh}, 18 | title = {Visual Dialog Challenge Starter Code}, 19 | howpublished = {\url{https://github.com/batra-mlp-lab/visdial-challenge-starter-pytorch}}, 20 | year = {2018} 21 | } 22 | ``` 23 | 24 | [![DOI](https://zenodo.org/badge/140210239.svg)](https://zenodo.org/badge/latestdoi/140210239) 25 | 26 | 27 | What's new with `v2019`? 28 | ------------------------ 29 | 30 | If you are a returning user (from Visual Dialog Challenge 2018), here are some key highlights about our offerings in `v2019` of this starter code: 31 | 32 | 1. _Almost_ a complete rewrite of `v2018`, which increased speed, readability, modularity and extensibility. 33 | 2. Multi-GPU support - try out specifying GPU ids to train/evaluate scripts as: `--gpu-ids 0 1 2 3` 34 | 3. Docker support - we provide a Dockerfile which can help you set up all the dependencies with ease. 35 | 4. Stronger baseline - our Late Fusion Encoder is equipped with [Bottom-up Top-Down attention][6]. We also provide pre-extracted image features (links below). 36 | 5. Minimal pre-processed data - no requirement to download tens of pre-processed data files anymore (were typically referred as `visdial_data.h5` and `visdial_params.json`). 37 | 38 | 39 | Setup and Dependencies 40 | ---------------------- 41 | 42 | This starter code is implemented using PyTorch v1.0, and provides out of the box support with CUDA 9 and CuDNN 7. 43 | There are two recommended ways to set up this codebase: Anaconda or Miniconda, and Docker. 44 | 45 | ### Anaconda or Miniconda 46 | 47 | 1. Install Anaconda or Miniconda distribution based on Python3+ from their [downloads' site][2]. 48 | 2. Clone this repository and create an environment: 49 | 50 | ```sh 51 | git clone https://www.github.com/batra-mlp-lab/visdial-challenge-starter-pytorch 52 | conda create -n visdialch python=3.6 53 | 54 | # activate the environment and install all dependencies 55 | conda activate visdialch 56 | cd visdial-challenge-starter-pytorch/ 57 | pip install -r requirements.txt 58 | 59 | # install this codebase as a package in development version 60 | python setup.py develop 61 | ``` 62 | 63 | **Note:** Docker setup is necessary if you wish to extract image features using Detectron. 
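Once the dependencies are installed, a quick, optional sanity check (assuming the pinned `torch==1.0.0` from `requirements.txt`) is to confirm that PyTorch sees your GPUs:

```python
import torch

# Expect the version pinned in requirements.txt and a working CUDA runtime
# if you intend to train on GPU.
print(torch.__version__)           # e.g. "1.0.0"
print(torch.cuda.is_available())   # True if CUDA and the driver are set up
print(torch.cuda.device_count())   # number of GPUs visible to PyTorch
```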
64 | 65 | ### Docker 66 | 67 | We provide a Dockerfile which creates a lightweight image with all the dependencies installed. 68 | 69 | 1. Install [nvidia-docker][18], which enables usage of GPUs from inside a container. 70 | 2. Build the image as: 71 | 72 | ```sh 73 | cd docker 74 | docker build -t visdialch . 75 | ``` 76 | 77 | 3. Run this image in a container by setting user+group, attaching the project root (this codebase) as a volume and setting shared memory size according to your requirements (this depends on the memory usage of your model). 78 | 79 | ```sh 80 | nvidia-docker run -u $(id -u):$(id -g) \ 81 | -v $PROJECT_ROOT:/workspace \ 82 | --shm-size 16G visdialch /bin/bash 83 | ``` 84 | 85 | We recommend this development workflow, as attaching the codebase as a volume immediately reflects source code changes inside the container environment. We also recommend keeping all the source code for data loading, models and other utilities inside the `visdialch` directory. Since it is a setuptools-style package, it makes handling of absolute/relative imports and module resolution less painful. Scripts using `visdialch` can be created anywhere in the filesystem, as long as the current conda environment is active. 86 | 87 | 88 | Download Data 89 | ------------- 90 | 91 | 1. Download the VisDial v1.0 dialog json files from [here][7] and keep them under the `$PROJECT_ROOT/data` directory for the default arguments to work. 92 | 93 | 2. Get the word counts for the VisDial v1.0 train split [here][9]. They are used to build the vocabulary. 94 | 95 | 3. We also provide pre-extracted image features of VisDial v1.0 images, using a Faster R-CNN pre-trained on Visual Genome. If you wish to extract your own image features, skip this step and download VisDial v1.0 images from [here][7] instead. Extracted features for the v1.0 train, val and test splits are available for download at these links: 96 | 97 | * [`features_faster_rcnn_x101_train.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_faster_rcnn_x101_train.h5): Bottom-up features of 36 proposals from images of `train` split. 98 | * [`features_faster_rcnn_x101_val.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_faster_rcnn_x101_val.h5): Bottom-up features of 36 proposals from images of `val` split. 99 | * [`features_faster_rcnn_x101_test.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_faster_rcnn_x101_test.h5): Bottom-up features of 36 proposals from images of `test` split. 100 | 101 | 4. We also provide pre-extracted FC7 features from VGG16, although `v2019` of this codebase no longer uses them. 102 | 103 | * [`features_vgg16_fc7_train.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_vgg16_fc7_train.h5): VGG16 FC7 features from images of `train` split. 104 | * [`features_vgg16_fc7_val.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_vgg16_fc7_val.h5): VGG16 FC7 features from images of `val` split. 105 | * [`features_vgg16_fc7_test.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_vgg16_fc7_test.h5): VGG16 FC7 features from images of `test` split. 106 | 107 |
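The feature files are plain HDF5. A minimal way to peek inside one (assuming the released files follow the same layout that `data/extract_features_detectron.py` writes, i.e. `image_ids`, `boxes`, `features`, `classes` and `scores` datasets plus a `split` attribute) is:

```python
import h5py

with h5py.File("data/features_faster_rcnn_x101_val.h5", "r") as features_h5:
    print(features_h5.attrs.get("split"))   # split name, e.g. "val"
    print(list(features_h5.keys()))         # dataset names actually present
    print(features_h5["features"].shape)    # expected (num_images, 36, 2048)
    print(features_h5["image_ids"][:5])     # first few COCO image ids
```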
108 | Training 109 | -------- 110 | 111 | This codebase supports both generative and discriminative decoding; read more [here][16]. For reference, we provide the Late Fusion Encoder from the Visual Dialog paper. 112 | 113 | We provide a training script which accepts arguments via config files. The config file should contain arguments specific to a particular experiment, such as those defining the model architecture or optimization hyperparameters. Other arguments, such as GPU ids or the number of CPU workers, should be declared in the script and passed in as argparse-style arguments. 114 | 115 | Train the baseline model provided in this repository as: 116 | 117 | ```sh 118 | python train.py --config-yml configs/lf_disc_faster_rcnn_x101.yml --gpu-ids 0 1 # provide more ids for multi-GPU execution, other args... 119 | ``` 120 | 121 | To extend this starter code, add your own encoder/decoder modules into their respective directories and include their names as choices in your config file. We have an `--overfit` flag, which can be useful for rapid debugging. It takes a batch of 5 examples and overfits the model on them. 122 | 123 | ### Saving model checkpoints 124 | 125 | This script saves a model checkpoint after every epoch in the directory specified by `--save-dirpath`. Refer to [visdialch/utils/checkpointing.py][19] for more details on how checkpointing is managed. 126 | 127 | ### Logging 128 | 129 | We use [Tensorboard][5] for logging training progress. Recommended: execute `tensorboard --logdir /path/to/save_dir --port 8008` and visit `localhost:8008` in your browser. 130 | 131 | 132 | Evaluation 133 | ---------- 134 | 135 | Evaluation of a trained model checkpoint can be done as follows: 136 | 137 | ```sh 138 | python evaluate.py --config-yml /path/to/config.yml --load-pthpath /path/to/checkpoint.pth --split val --gpu-ids 0 139 | ``` 140 | 141 | This will generate an EvalAI submission file, and report metrics from the [Visual Dialog paper][13] (Mean reciprocal rank, R@{1, 5, 10}, Mean rank), and Normalized Discounted Cumulative Gain (NDCG), introduced in the first Visual Dialog Challenge (in 2018). 142 | 143 | The metrics reported here will be the same as those reported through EvalAI for a submission in the `val` phase. To generate a submission file for the `test-std` or `test-challenge` phase, replace `--split val` with `--split test`. 144 | 145 | 146 | Results and pretrained checkpoints 147 | ---------------------------------- 148 | 149 | Performance on `v1.0 test-std` (trained on `v1.0` train + val): 150 | 151 | Model | R@1 | R@5 | R@10 | MeanR | MRR | NDCG | 152 | ------- | ------ | ------ | ------ | ------ | ------ | ------ | 153 | [lf-disc-faster-rcnn-x101][12] | 0.4617 | 0.7780 | 0.8730 | 4.7545 | 0.6041 | 0.5162 | 154 | [lf-gen-faster-rcnn-x101][20] | 0.3620 | 0.5640 | 0.6340 | 19.4458 | 0.4657 | 0.5421 | 155 | 156 | 157 | Acknowledgements 158 | ---------------- 159 | 160 | * This starter code began as a fork of [batra-mlp-lab/visdial-rl][14]. We thank its developers for doing most of the heavy lifting. 161 | * The Lua-torch codebase of Visual Dialog, at [batra-mlp-lab/visdial][15], served as an important reference while developing this codebase. 162 | * Some documentation and design strategies of the `Metric`, `Reader` and `Vocabulary` classes are inspired by [AllenNLP][17]. It is not a dependency, because the use-case in this codebase is too limited in its current state.
163 | 164 | [1]: https://visualdialog.org/challenge/2019 165 | [2]: https://conda.io/docs/user-guide/install/download.html 166 | [3]: http://images.cocodataset.org/zips/train2014.zip 167 | [4]: http://images.cocodataset.org/zips/val2014.zip 168 | [5]: https://www.github.com/lanpa/tensorboardX 169 | [6]: https://arxiv.org/abs/1707.07998 170 | [7]: https://visualdialog.org/data 171 | [9]: https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/visdial_1.0_word_counts_train.json 172 | [10]: https://visualdialog.org/data 173 | [11]: http://www.robots.ox.ac.uk/~vgg/research/very_deep/ 174 | [12]: https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/lf_disc_faster_rcnn_x101_trainval.pth 175 | [13]: https://arxiv.org/abs/1611.08669 176 | [14]: https://www.github.com/batra-mlp-lab/visdial-rl 177 | [15]: https://www.github.com/batra-mlp-lab/visdial 178 | [16]: https://visualdialog.org/challenge/2018#faq 179 | [17]: https://www.github.com/allenai/allennlp 180 | [18]: https://www.github.com/nvidia/nvidia-docker 181 | [19]: https://github.com/batra-mlp-lab/visdial-challenge-starter-pytorch/blob/master/visdialch/utils/checkpointing.py 182 | [20]: https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/lf_gen_faster_rcnn_x101_train.pth 183 | -------------------------------------------------------------------------------- /configs/lf_disc_faster_rcnn_x101.yml: -------------------------------------------------------------------------------- 1 | # Dataset reader arguments 2 | dataset: 3 | image_features_train_h5: 'data/features_faster_rcnn_x101_train.h5' 4 | image_features_val_h5: 'data/features_faster_rcnn_x101_val.h5' 5 | image_features_test_h5: 'data/features_faster_rcnn_x101_test.h5' 6 | word_counts_json: 'data/visdial_1.0_word_counts_train.json' 7 | 8 | img_norm: 1 9 | concat_history: true 10 | max_sequence_length: 20 11 | vocab_min_count: 5 12 | 13 | 14 | # Model related arguments 15 | model: 16 | encoder: 'lf' 17 | decoder: 'disc' 18 | 19 | img_feature_size: 2048 20 | word_embedding_size: 300 21 | lstm_hidden_size: 512 22 | lstm_num_layers: 2 23 | dropout: 0.5 24 | 25 | 26 | # Optimization related arguments 27 | solver: 28 | batch_size: 128 # 32 x num_gpus is a good rule of thumb 29 | num_epochs: 20 30 | initial_lr: 0.01 31 | training_splits: "train" # "trainval" 32 | lr_gamma: 0.1 33 | lr_milestones: # epochs when lr => lr * lr_gamma 34 | - 4 35 | - 7 36 | - 10 37 | warmup_factor: 0.2 38 | warmup_epochs: 1 39 | -------------------------------------------------------------------------------- /configs/lf_gen_faster_rcnn_x101.yml: -------------------------------------------------------------------------------- 1 | # Dataset reader arguments 2 | dataset: 3 | image_features_train_h5: 'data/features_faster_rcnn_x101_train.h5' 4 | image_features_val_h5: 'data/features_faster_rcnn_x101_val.h5' 5 | image_features_test_h5: 'data/features_faster_rcnn_x101_test.h5' 6 | word_counts_json: 'data/visdial_1.0_word_counts_train.json' 7 | 8 | img_norm: 1 9 | concat_history: true 10 | max_sequence_length: 20 11 | vocab_min_count: 5 12 | 13 | 14 | # Model related arguments 15 | model: 16 | encoder: 'lf' 17 | decoder: 'gen' 18 | 19 | img_feature_size: 2048 20 | word_embedding_size: 300 21 | lstm_hidden_size: 512 22 | lstm_num_layers: 2 23 | dropout: 0.2 24 | 25 | 26 | # Optimization related arguments 27 | solver: 28 | batch_size: 224 # 56 x num_gpus is a good rule of thumb 29 | num_epochs: 20 30 | initial_lr: 0.01 31 | training_splits: "train" # "trainval" 32 | lr_gamma: 0.1 33 | lr_milestones: # epochs when lr => lr * 
lr_gamma 34 | - 4 35 | - 7 36 | - 10 37 | warmup_factor: 0.2 38 | warmup_epochs: 1 39 | -------------------------------------------------------------------------------- /data/extract_features_detectron.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | 5 | import cv2 # must import before importing caffe2 due to bug in cv2 6 | from caffe2.python import workspace 7 | from tqdm import tqdm 8 | import h5py 9 | import numpy as np 10 | 11 | from detectron.core.config import assert_and_infer_cfg, merge_cfg_from_file 12 | from detectron.core.config import cfg as detectron_config 13 | from detectron.utils.boxes import nms as detectron_nms 14 | import detectron.core.test as detectron_test 15 | import detectron.core.test_engine as infer_engine 16 | import detectron.utils.c2 as c2_utils 17 | 18 | 19 | c2_utils.import_detectron_ops() 20 | # OpenCL may be enabled by default in OpenCV3; disable it because it's not 21 | # thread safe and causes unwanted GPU memory allocations. 22 | cv2.ocl.setUseOpenCL(False) 23 | 24 | parser = argparse.ArgumentParser( 25 | description="Extract bottom-up features from a model trained by Detectron" 26 | ) 27 | parser.add_argument( 28 | "--image-root", 29 | nargs="+", 30 | help="Path to a directory containing COCO/VisDial images. Note that this " 31 | "directory must have images, and not sub-directories of splits. " 32 | "Each HDF file should contain features from a single split." 33 | "Multiple paths are supported to account for VisDial v1.0 train.", 34 | ) 35 | parser.add_argument( 36 | "--config", 37 | help="Path to model config file used by Detectron (.yaml)", 38 | default="data/config_faster_rcnn_x101.yaml", 39 | ) 40 | parser.add_argument( 41 | "--weights", 42 | help="Path to model weights file saved by Detectron (.pkl)", 43 | default="data/model_faster_rcnn_x101.pkl", 44 | ) 45 | parser.add_argument( 46 | "--save-path", 47 | help="Path to output file for saving bottom-up features (.h5)", 48 | default="data/data_img_faster_rcnn_x101.h5", 49 | ) 50 | parser.add_argument( 51 | "--max-boxes", 52 | help="Maximum number of bounding box proposals per image", 53 | type=int, 54 | default=100 55 | ) 56 | parser.add_argument( 57 | "--feat-name", 58 | help="The name of the layer to extract features from.", 59 | default="fc7", 60 | ) 61 | parser.add_argument( 62 | "--feat-dims", 63 | help="Length of bottom-upfeature vectors.", 64 | type=int, 65 | default=2048, 66 | ) 67 | parser.add_argument( 68 | "--split", 69 | choices=["train", "val", "test"], 70 | help="Which split is being processed.", 71 | ) 72 | parser.add_argument( 73 | "--gpu-id", 74 | help="The GPU id to use (-1 for CPU execution)", 75 | type=int, 76 | default=0, 77 | ) 78 | 79 | 80 | def detect_image(detectron_model, image, args): 81 | """Given an image and a detectron model, extract object boxes, 82 | classes, confidences and features from the image using the model. 83 | 84 | Parameters 85 | ---------- 86 | detectron_model 87 | Detectron model. 88 | image : np.ndarray 89 | Image in BGR format. 90 | args : argparse.Namespace 91 | Parsed command-line arguments. 92 | 93 | Returns 94 | ------- 95 | np.ndarray, np.ndarray, np.ndarray, np.ndarray 96 | Object bounding boxes, classes, confidence and features. 
97 | """ 98 | 99 | scores, cls_boxes, im_scale = detectron_test.im_detect_bbox( 100 | detectron_model, 101 | image, 102 | detectron_config.TEST.SCALE, 103 | detectron_config.TEST.MAX_SIZE, 104 | boxes=None, 105 | ) 106 | num_proposals = scores.shape[0] 107 | 108 | rois = workspace.FetchBlob(f"gpu_{args.gpu_id}/rois") 109 | features = workspace.FetchBlob( 110 | f"gpu_{args.gpu_id}/{args.feat_name}" 111 | ) 112 | 113 | cls_boxes = rois[:, 1:5] / im_scale 114 | max_conf = np.zeros((num_proposals,), dtype=np.float32) 115 | max_cls = np.zeros((num_proposals,), dtype=np.int32) 116 | max_box = np.zeros((num_proposals, 4), dtype=np.float32) 117 | 118 | for cls_ind in range(1, detectron_config.MODEL.NUM_CLASSES): 119 | cls_scores = scores[:, cls_ind] 120 | dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype( 121 | np.float32 122 | ) 123 | keep = np.array(detectron_nms(dets, detectron_config.TEST.NMS)) 124 | idxs_update = np.where(cls_scores[keep] > max_conf[keep]) 125 | keep_idxs = keep[idxs_update] 126 | max_conf[keep_idxs] = cls_scores[keep_idxs] 127 | max_cls[keep_idxs] = cls_ind 128 | max_box[keep_idxs] = dets[keep_idxs][:, :4] 129 | 130 | keep_boxes = np.argsort(max_conf)[::-1][:args.max_boxes] 131 | boxes = max_box[keep_boxes, :] 132 | classes = max_cls[keep_boxes] 133 | confidence = max_conf[keep_boxes] 134 | features = features[keep_boxes, :] 135 | return boxes, features, classes, confidence 136 | 137 | 138 | def image_id_from_path(image_path): 139 | """Given a path to an image, return its id. 140 | 141 | Parameters 142 | ---------- 143 | image_path : str 144 | Path to image, e.g.: coco_train2014/COCO_train2014/000000123456.jpg 145 | 146 | Returns 147 | ------- 148 | int 149 | Corresponding image id (123456) 150 | """ 151 | 152 | return int(image_path.split("/")[-1][-16:-4]) 153 | 154 | 155 | def main(args): 156 | """Extract bottom-up features from all images in a directory using 157 | a pre-trained Detectron model, and save them in HDF format. 158 | 159 | Parameters 160 | ---------- 161 | args : argparse.Namespace 162 | Parsed command-line arguments. 
163 | """ 164 | 165 | # specifically for visual genome 166 | detectron_config.MODEL.NUM_ATTRIBUTES = -1 167 | merge_cfg_from_file(args.config) 168 | 169 | # override some config options and validate the config 170 | detectron_config.NUM_GPUS = 1 171 | detectron_config.TRAIN.CPP_RPN = "none" 172 | assert_and_infer_cfg(cache_urls=False) 173 | 174 | # initialize model 175 | detectron_model = infer_engine.initialize_model_from_cfg( 176 | args.weights, args.gpu_id 177 | ) 178 | 179 | # list of paths (example: "coco_train2014/COCO_train2014_000000123456.jpg") 180 | image_paths = [] 181 | for image_root in args.image_root: 182 | image_paths.extend( 183 | [ 184 | os.path.join(image_root, name) 185 | for name in glob.glob(os.path.join(image_root, "*.jpg")) 186 | if name not in {".", ".."} 187 | ] 188 | ) 189 | 190 | # create an output HDF to save extracted features 191 | save_h5 = h5py.File(args.save_path, "w") 192 | image_ids_h5d = save_h5.create_dataset( 193 | "image_ids", (len(image_paths),), dtype=int 194 | ) 195 | 196 | boxes_h5d = save_h5.create_dataset( 197 | "boxes", (len(image_paths), args.max_boxes, 4), 198 | ) 199 | features_h5d = save_h5.create_dataset( 200 | "features", (len(image_paths), args.max_boxes, args.feat_dims), 201 | ) 202 | classes_h5d = save_h5.create_dataset( 203 | "classes", (len(image_paths), args.max_boxes, ), 204 | ) 205 | scores_h5d = save_h5.create_dataset( 206 | "scores", (len(image_paths), args.max_boxes, ), 207 | ) 208 | 209 | with c2_utils.NamedCudaScope(args.gpu_id): 210 | for idx, image_path in enumerate(tqdm(image_paths)): 211 | try: 212 | image_ids_h5d[idx] = image_id_from_path(image_path) 213 | 214 | image = cv2.imread(image_path) 215 | boxes, features, classes, scores = detect_image(detectron_model, image, args) 216 | 217 | boxes_h5d[idx] = boxes 218 | features_h5d[idx] = features 219 | classes_h5d[idx] = classes 220 | scores_h5d[idx] = scores 221 | except: 222 | print(f"\nWarning: Failed to extract features from {idx}, {image_path}.\n") 223 | 224 | # set current split name in attributrs of file, for tractability 225 | save_h5.attrs["split"] = args.split 226 | save_h5.close() 227 | 228 | 229 | if __name__ == "__main__": 230 | # set higher log level to prevent terminal spam 231 | workspace.GlobalInit(["caffe2", "--caffe2_log_level=3"]) 232 | args = parser.parse_args() 233 | main(args) 234 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # TODO: update to pytorch 1.0 container when it is avalable 2 | # PyTorch 1.0 will be downloaded through requirements.txt anyway 3 | 4 | FROM pytorch/pytorch:0.4.1-cuda9-cudnn7-devel 5 | 6 | RUN apt-get update && apt-get install -y libglib2.0-0 libsm6 libxrender1 libxext6 7 | 8 | RUN pip install --upgrade pip && pip install cython 9 | RUN git clone --depth 1 https://www.github.com/batra-mlp-lab/visdial-challenge-starter-pytorch /workspace && \ 10 | pip install -r /workspace/requirements.txt 11 | 12 | RUN git clone --depth 1 https://www.github.com/facebookresearch/detectron /detectron && \ 13 | pip install -r /detectron/requirements.txt 14 | 15 | WORKDIR /detectron 16 | RUN make 17 | 18 | WORKDIR /workspace 19 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import torch 6 | from torch import nn 7 | from 
torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | import yaml 10 | 11 | from visdialch.data.dataset import VisDialDataset 12 | from visdialch.encoders import Encoder 13 | from visdialch.decoders import Decoder 14 | from visdialch.metrics import SparseGTMetrics, NDCG, scores_to_ranks 15 | from visdialch.model import EncoderDecoderModel 16 | from visdialch.utils.checkpointing import load_checkpoint 17 | 18 | 19 | parser = argparse.ArgumentParser( 20 | "Evaluate and/or generate EvalAI submission file." 21 | ) 22 | parser.add_argument( 23 | "--config-yml", 24 | default="configs/lf_disc_faster_rcnn_x101.yml", 25 | help="Path to a config file listing reader, model and optimization " 26 | "parameters.", 27 | ) 28 | parser.add_argument( 29 | "--split", 30 | default="val", 31 | choices=["val", "test"], 32 | help="Which split to evaluate upon.", 33 | ) 34 | parser.add_argument( 35 | "--val-json", 36 | default="data/visdial_1.0_val.json", 37 | help="Path to VisDial v1.0 val data. This argument doesn't work when " 38 | "--split=test.", 39 | ) 40 | parser.add_argument( 41 | "--val-dense-json", 42 | default="data/visdial_1.0_val_dense_annotations.json", 43 | help="Path to VisDial v1.0 val dense annotations (if evaluating on val " 44 | "split). This argument doesn't work when --split=test.", 45 | ) 46 | parser.add_argument( 47 | "--test-json", 48 | default="data/visdial_1.0_test.json", 49 | help="Path to VisDial v1.0 test data. This argument doesn't work when " 50 | "--split=val.", 51 | ) 52 | 53 | parser.add_argument_group("Evaluation related arguments") 54 | parser.add_argument( 55 | "--load-pthpath", 56 | default="checkpoints/checkpoint_xx.pth", 57 | help="Path to .pth file of pretrained checkpoint.", 58 | ) 59 | 60 | parser.add_argument_group( 61 | "Arguments independent of experiment reproducibility" 62 | ) 63 | parser.add_argument( 64 | "--gpu-ids", 65 | nargs="+", 66 | type=int, 67 | default=-1, 68 | help="List of ids of GPUs to use.", 69 | ) 70 | parser.add_argument( 71 | "--cpu-workers", 72 | type=int, 73 | default=4, 74 | help="Number of CPU workers for reading data.", 75 | ) 76 | parser.add_argument( 77 | "--overfit", 78 | action="store_true", 79 | help="Overfit model on 5 examples, meant for debugging.", 80 | ) 81 | parser.add_argument( 82 | "--in-memory", 83 | action="store_true", 84 | help="Load the whole dataset and pre-extracted image features in memory. " 85 | "Use only in presence of large RAM, atleast few tens of GBs.", 86 | ) 87 | 88 | parser.add_argument_group("Submission related arguments") 89 | parser.add_argument( 90 | "--save-ranks-path", 91 | default="logs/ranks.json", 92 | help="Path (json) to save ranks, in a EvalAI submission format.", 93 | ) 94 | 95 | # For reproducibility. 
96 | # Refer https://pytorch.org/docs/stable/notes/randomness.html 97 | torch.manual_seed(0) 98 | torch.cuda.manual_seed_all(0) 99 | torch.backends.cudnn.benchmark = False 100 | torch.backends.cudnn.deterministic = True 101 | 102 | # ============================================================================= 103 | # INPUT ARGUMENTS AND CONFIG 104 | # ============================================================================= 105 | 106 | args = parser.parse_args() 107 | 108 | # keys: {"dataset", "model", "solver"} 109 | config = yaml.load(open(args.config_yml)) 110 | 111 | if isinstance(args.gpu_ids, int): 112 | args.gpu_ids = [args.gpu_ids] 113 | device = ( 114 | torch.device("cuda", args.gpu_ids[0]) 115 | if args.gpu_ids[0] >= 0 116 | else torch.device("cpu") 117 | ) 118 | 119 | # Print config and args. 120 | print(yaml.dump(config, default_flow_style=False)) 121 | for arg in vars(args): 122 | print("{:<20}: {}".format(arg, getattr(args, arg))) 123 | 124 | 125 | # ============================================================================= 126 | # SETUP DATASET, DATALOADER, MODEL 127 | # ============================================================================= 128 | 129 | if args.split == "val": 130 | val_dataset = VisDialDataset( 131 | config["dataset"], 132 | args.val_json, 133 | args.val_dense_json, 134 | overfit=args.overfit, 135 | in_memory=args.in_memory, 136 | return_options=True, 137 | add_boundary_toks=False 138 | if config["model"]["decoder"] == "disc" 139 | else True, 140 | ) 141 | else: 142 | val_dataset = VisDialDataset( 143 | config["dataset"], 144 | args.test_json, 145 | overfit=args.overfit, 146 | in_memory=args.in_memory, 147 | return_options=True, 148 | add_boundary_toks=False 149 | if config["model"]["decoder"] == "disc" 150 | else True, 151 | ) 152 | val_dataloader = DataLoader( 153 | val_dataset, 154 | batch_size=config["solver"]["batch_size"] 155 | if config["model"]["decoder"] == "disc" 156 | else 5, 157 | num_workers=args.cpu_workers, 158 | ) 159 | 160 | # Pass vocabulary to construct Embedding layer. 161 | encoder = Encoder(config["model"], val_dataset.vocabulary) 162 | decoder = Decoder(config["model"], val_dataset.vocabulary) 163 | print("Encoder: {}".format(config["model"]["encoder"])) 164 | print("Decoder: {}".format(config["model"]["decoder"])) 165 | 166 | # Share word embedding between encoder and decoder. 167 | decoder.word_embed = encoder.word_embed 168 | 169 | # Wrap encoder and decoder in a model. 
170 | model = EncoderDecoderModel(encoder, decoder).to(device) 171 | if -1 not in args.gpu_ids: 172 | model = nn.DataParallel(model, args.gpu_ids) 173 | 174 | model_state_dict, _ = load_checkpoint(args.load_pthpath) 175 | if isinstance(model, nn.DataParallel): 176 | model.module.load_state_dict(model_state_dict) 177 | else: 178 | model.load_state_dict(model_state_dict) 179 | print("Loaded model from {}".format(args.load_pthpath)) 180 | 181 | # Declare metric accumulators (won't be used if --split=test) 182 | sparse_metrics = SparseGTMetrics() 183 | ndcg = NDCG() 184 | 185 | # ============================================================================= 186 | # EVALUATION LOOP 187 | # ============================================================================= 188 | 189 | model.eval() 190 | ranks_json = [] 191 | 192 | for _, batch in enumerate(tqdm(val_dataloader)): 193 | for key in batch: 194 | batch[key] = batch[key].to(device) 195 | with torch.no_grad(): 196 | output = model(batch) 197 | 198 | ranks = scores_to_ranks(output) 199 | for i in range(len(batch["img_ids"])): 200 | # Cast into types explicitly to ensure no errors in schema. 201 | # Round ids are 1-10, not 0-9 202 | if args.split == "test": 203 | ranks_json.append( 204 | { 205 | "image_id": batch["img_ids"][i].item(), 206 | "round_id": int(batch["num_rounds"][i].item()), 207 | "ranks": [ 208 | rank.item() 209 | for rank in ranks[i][batch["num_rounds"][i] - 1] 210 | ], 211 | } 212 | ) 213 | else: 214 | for j in range(batch["num_rounds"][i]): 215 | ranks_json.append( 216 | { 217 | "image_id": batch["img_ids"][i].item(), 218 | "round_id": int(j + 1), 219 | "ranks": [rank.item() for rank in ranks[i][j]], 220 | } 221 | ) 222 | 223 | if args.split == "val": 224 | sparse_metrics.observe(output, batch["ans_ind"]) 225 | if "gt_relevance" in batch: 226 | output = output[ 227 | torch.arange(output.size(0)), batch["round_id"] - 1, : 228 | ] 229 | ndcg.observe(output, batch["gt_relevance"]) 230 | 231 | if args.split == "val": 232 | all_metrics = {} 233 | all_metrics.update(sparse_metrics.retrieve(reset=True)) 234 | all_metrics.update(ndcg.retrieve(reset=True)) 235 | for metric_name, metric_value in all_metrics.items(): 236 | print(f"{metric_name}: {metric_value}") 237 | 238 | print("Writing ranks to {}".format(args.save_ranks_path)) 239 | os.makedirs(os.path.dirname(args.save_ranks_path), exist_ok=True) 240 | json.dump(ranks_json, open(args.save_ranks_path, "w")) 241 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 79 3 | include = '\.pyi?$' 4 | exclude = ''' 5 | /( 6 | \.git 7 | | \.hg 8 | | \.mypy_cache 9 | | \.tox 10 | | \.venv 11 | | _build 12 | | buck-out 13 | | build 14 | | dist 15 | )/ 16 | ''' 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython==0.29.1 2 | h5py==2.8.0 3 | nltk==3.6.3 4 | numpy==1.15.4 5 | Pillow==5.3.0 6 | pyyaml>=4.2b1 7 | six==1.11.0 8 | tensorboardX==1.2 9 | tensorflow==1.12.0 10 | torch==1.0.0 11 | tqdm==4.28.1 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup 3 | 4 | 5 | setup( 6 | name="visdialch", 7 | 
version="2019.0.0", 8 | author="Karan Desai", 9 | url="https://github.com/batra-mlp-lab/visdial-challenge-starter-pytorch", 10 | description="Starter code for VisDial challenge 2019", 11 | license="BSD", 12 | zip_safe=True, 13 | ) 14 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | 4 | from tensorboardX import SummaryWriter 5 | import torch 6 | from torch import nn, optim 7 | from torch.optim import lr_scheduler 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | import yaml 11 | from bisect import bisect 12 | 13 | from visdialch.data.dataset import VisDialDataset 14 | from visdialch.encoders import Encoder 15 | from visdialch.decoders import Decoder 16 | from visdialch.metrics import SparseGTMetrics, NDCG 17 | from visdialch.model import EncoderDecoderModel 18 | from visdialch.utils.checkpointing import CheckpointManager, load_checkpoint 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "--config-yml", 24 | default="configs/lf_disc_faster_rcnn_x101.yml", 25 | help="Path to a config file listing reader, model and solver parameters.", 26 | ) 27 | parser.add_argument( 28 | "--train-json", 29 | default="data/visdial_1.0_train.json", 30 | help="Path to json file containing VisDial v1.0 training data.", 31 | ) 32 | parser.add_argument( 33 | "--val-json", 34 | default="data/visdial_1.0_val.json", 35 | help="Path to json file containing VisDial v1.0 validation data.", 36 | ) 37 | parser.add_argument( 38 | "--val-dense-json", 39 | default="data/visdial_1.0_val_dense_annotations.json", 40 | help="Path to json file containing VisDial v1.0 validation dense ground " 41 | "truth annotations.", 42 | ) 43 | 44 | 45 | parser.add_argument_group( 46 | "Arguments independent of experiment reproducibility" 47 | ) 48 | parser.add_argument( 49 | "--gpu-ids", 50 | nargs="+", 51 | type=int, 52 | default=0, 53 | help="List of ids of GPUs to use.", 54 | ) 55 | parser.add_argument( 56 | "--cpu-workers", 57 | type=int, 58 | default=4, 59 | help="Number of CPU workers for dataloader.", 60 | ) 61 | parser.add_argument( 62 | "--overfit", 63 | action="store_true", 64 | help="Overfit model on 5 examples, meant for debugging.", 65 | ) 66 | parser.add_argument( 67 | "--validate", 68 | action="store_true", 69 | help="Whether to validate on val split after every epoch.", 70 | ) 71 | parser.add_argument( 72 | "--in-memory", 73 | action="store_true", 74 | help="Load the whole dataset and pre-extracted image features in memory. " 75 | "Use only in presence of large RAM, atleast few tens of GBs.", 76 | ) 77 | 78 | 79 | parser.add_argument_group("Checkpointing related arguments") 80 | parser.add_argument( 81 | "--save-dirpath", 82 | default="checkpoints/", 83 | help="Path of directory to create checkpoint directory and save " 84 | "checkpoints.", 85 | ) 86 | parser.add_argument( 87 | "--load-pthpath", 88 | default="", 89 | help="To continue training, path to .pth file of saved checkpoint.", 90 | ) 91 | 92 | # For reproducibility. 
93 | # Refer https://pytorch.org/docs/stable/notes/randomness.html 94 | torch.manual_seed(0) 95 | torch.cuda.manual_seed_all(0) 96 | torch.backends.cudnn.benchmark = False 97 | torch.backends.cudnn.deterministic = True 98 | 99 | 100 | # ============================================================================= 101 | # INPUT ARGUMENTS AND CONFIG 102 | # ============================================================================= 103 | 104 | args = parser.parse_args() 105 | 106 | # keys: {"dataset", "model", "solver"} 107 | config = yaml.load(open(args.config_yml)) 108 | 109 | if isinstance(args.gpu_ids, int): 110 | args.gpu_ids = [args.gpu_ids] 111 | device = ( 112 | torch.device("cuda", args.gpu_ids[0]) 113 | if args.gpu_ids[0] >= 0 114 | else torch.device("cpu") 115 | ) 116 | torch.cuda.set_device(device) 117 | 118 | # Print config and args. 119 | print(yaml.dump(config, default_flow_style=False)) 120 | for arg in vars(args): 121 | print("{:<20}: {}".format(arg, getattr(args, arg))) 122 | 123 | 124 | # ============================================================================= 125 | # SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER 126 | # ============================================================================= 127 | 128 | train_dataset = VisDialDataset( 129 | config["dataset"], 130 | args.train_json, 131 | overfit=args.overfit, 132 | in_memory=args.in_memory, 133 | num_workers=args.cpu_workers, 134 | return_options=True if config["model"]["decoder"] == "disc" else False, 135 | add_boundary_toks=False if config["model"]["decoder"] == "disc" else True, 136 | ) 137 | train_dataloader = DataLoader( 138 | train_dataset, 139 | batch_size=config["solver"]["batch_size"], 140 | num_workers=args.cpu_workers, 141 | shuffle=True, 142 | ) 143 | 144 | val_dataset = VisDialDataset( 145 | config["dataset"], 146 | args.val_json, 147 | args.val_dense_json, 148 | overfit=args.overfit, 149 | in_memory=args.in_memory, 150 | num_workers=args.cpu_workers, 151 | return_options=True, 152 | add_boundary_toks=False if config["model"]["decoder"] == "disc" else True, 153 | ) 154 | val_dataloader = DataLoader( 155 | val_dataset, 156 | batch_size=config["solver"]["batch_size"] 157 | if config["model"]["decoder"] == "disc" 158 | else 5, 159 | num_workers=args.cpu_workers, 160 | ) 161 | 162 | # Pass vocabulary to construct Embedding layer. 163 | encoder = Encoder(config["model"], train_dataset.vocabulary) 164 | decoder = Decoder(config["model"], train_dataset.vocabulary) 165 | print("Encoder: {}".format(config["model"]["encoder"])) 166 | print("Decoder: {}".format(config["model"]["decoder"])) 167 | 168 | # Share word embedding between encoder and decoder. 169 | decoder.word_embed = encoder.word_embed 170 | 171 | # Wrap encoder and decoder in a model. 172 | model = EncoderDecoderModel(encoder, decoder).to(device) 173 | if -1 not in args.gpu_ids: 174 | model = nn.DataParallel(model, args.gpu_ids) 175 | 176 | # Loss function. 
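# (The discriminative decoder scores each candidate answer option per round,
# so plain cross-entropy over the ground-truth option index ("ans_ind") is
# used; the generative decoder predicts answer tokens ("ans_out"), so padding
# positions are ignored in the loss.)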
177 | if config["model"]["decoder"] == "disc": 178 | criterion = nn.CrossEntropyLoss() 179 | elif config["model"]["decoder"] == "gen": 180 | criterion = nn.CrossEntropyLoss( 181 | ignore_index=train_dataset.vocabulary.PAD_INDEX 182 | ) 183 | else: 184 | raise NotImplementedError 185 | 186 | if config["solver"]["training_splits"] == "trainval": 187 | iterations = (len(train_dataset) + len(val_dataset)) // config["solver"][ 188 | "batch_size" 189 | ] + 1 190 | else: 191 | iterations = len(train_dataset) // config["solver"]["batch_size"] + 1 192 | 193 | 194 | def lr_lambda_fun(current_iteration: int) -> float: 195 | """Returns a learning rate multiplier. 196 | 197 | Till `warmup_epochs`, learning rate linearly increases to `initial_lr`, 198 | and then gets multiplied by `lr_gamma` every time a milestone is crossed. 199 | """ 200 | current_epoch = float(current_iteration) / iterations 201 | if current_epoch <= config["solver"]["warmup_epochs"]: 202 | alpha = current_epoch / float(config["solver"]["warmup_epochs"]) 203 | return config["solver"]["warmup_factor"] * (1.0 - alpha) + alpha 204 | else: 205 | idx = bisect(config["solver"]["lr_milestones"], current_epoch) 206 | return pow(config["solver"]["lr_gamma"], idx) 207 | 208 | 209 | optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"]) 210 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun) 211 | 212 | 213 | # ============================================================================= 214 | # SETUP BEFORE TRAINING LOOP 215 | # ============================================================================= 216 | 217 | summary_writer = SummaryWriter(log_dir=args.save_dirpath) 218 | checkpoint_manager = CheckpointManager( 219 | model, optimizer, args.save_dirpath, config=config 220 | ) 221 | sparse_metrics = SparseGTMetrics() 222 | ndcg = NDCG() 223 | 224 | # If loading from checkpoint, adjust start epoch and load parameters. 225 | if args.load_pthpath == "": 226 | start_epoch = 0 227 | else: 228 | # "path/to/checkpoint_xx.pth" -> xx 229 | start_epoch = int(args.load_pthpath.split("_")[-1][:-4]) 230 | 231 | model_state_dict, optimizer_state_dict = load_checkpoint(args.load_pthpath) 232 | if isinstance(model, nn.DataParallel): 233 | model.module.load_state_dict(model_state_dict) 234 | else: 235 | model.load_state_dict(model_state_dict) 236 | optimizer.load_state_dict(optimizer_state_dict) 237 | print("Loaded model from {}".format(args.load_pthpath)) 238 | 239 | # ============================================================================= 240 | # TRAINING LOOP 241 | # ============================================================================= 242 | 243 | # Forever increasing counter to keep track of iterations (for tensorboard log). 
244 | global_iteration_step = start_epoch * iterations 245 | 246 | for epoch in range(start_epoch, config["solver"]["num_epochs"]): 247 | 248 | # ------------------------------------------------------------------------- 249 | # ON EPOCH START (combine dataloaders if training on train + val) 250 | # ------------------------------------------------------------------------- 251 | if config["solver"]["training_splits"] == "trainval": 252 | combined_dataloader = itertools.chain(train_dataloader, val_dataloader) 253 | else: 254 | combined_dataloader = itertools.chain(train_dataloader) 255 | 256 | print(f"\nTraining for epoch {epoch}:") 257 | for i, batch in enumerate(tqdm(combined_dataloader)): 258 | for key in batch: 259 | batch[key] = batch[key].to(device) 260 | 261 | optimizer.zero_grad() 262 | output = model(batch) 263 | target = ( 264 | batch["ans_ind"] 265 | if config["model"]["decoder"] == "disc" 266 | else batch["ans_out"] 267 | ) 268 | batch_loss = criterion( 269 | output.view(-1, output.size(-1)), target.view(-1) 270 | ) 271 | batch_loss.backward() 272 | optimizer.step() 273 | 274 | summary_writer.add_scalar( 275 | "train/loss", batch_loss, global_iteration_step 276 | ) 277 | summary_writer.add_scalar( 278 | "train/lr", optimizer.param_groups[0]["lr"], global_iteration_step 279 | ) 280 | 281 | scheduler.step(global_iteration_step) 282 | global_iteration_step += 1 283 | torch.cuda.empty_cache() 284 | 285 | # ------------------------------------------------------------------------- 286 | # ON EPOCH END (checkpointing and validation) 287 | # ------------------------------------------------------------------------- 288 | checkpoint_manager.step() 289 | 290 | # Validate and report automatic metrics. 291 | if args.validate: 292 | 293 | # Switch dropout, batchnorm etc to the correct mode. 
294 | model.eval() 295 | 296 | print(f"\nValidation after epoch {epoch}:") 297 | for i, batch in enumerate(tqdm(val_dataloader)): 298 | for key in batch: 299 | batch[key] = batch[key].to(device) 300 | with torch.no_grad(): 301 | output = model(batch) 302 | sparse_metrics.observe(output, batch["ans_ind"]) 303 | if "gt_relevance" in batch: 304 | output = output[ 305 | torch.arange(output.size(0)), batch["round_id"] - 1, : 306 | ] 307 | ndcg.observe(output, batch["gt_relevance"]) 308 | 309 | all_metrics = {} 310 | all_metrics.update(sparse_metrics.retrieve(reset=True)) 311 | all_metrics.update(ndcg.retrieve(reset=True)) 312 | for metric_name, metric_value in all_metrics.items(): 313 | print(f"{metric_name}: {metric_value}") 314 | summary_writer.add_scalars( 315 | "metrics", all_metrics, global_iteration_step 316 | ) 317 | 318 | model.train() 319 | torch.cuda.empty_cache() 320 | -------------------------------------------------------------------------------- /visdialch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/batra-mlp-lab/visdial-challenge-starter-pytorch/5844f3d5a575e9ec1c1684feb760e7de5c912beb/visdialch/__init__.py -------------------------------------------------------------------------------- /visdialch/data/__init__.py: -------------------------------------------------------------------------------- 1 | from visdialch.data.dataset import VisDialDataset 2 | from visdialch.data.vocabulary import Vocabulary 3 | -------------------------------------------------------------------------------- /visdialch/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | import torch 4 | from torch.nn.functional import normalize 5 | from torch.nn.utils.rnn import pad_sequence 6 | from torch.utils.data import Dataset 7 | 8 | from visdialch.data.readers import ( 9 | DialogsReader, 10 | DenseAnnotationsReader, 11 | ImageFeaturesHdfReader, 12 | ) 13 | from visdialch.data.vocabulary import Vocabulary 14 | 15 | 16 | class VisDialDataset(Dataset): 17 | """ 18 | A full representation of VisDial v1.0 (train/val/test) dataset. According 19 | to the appropriate split, it returns dictionary of question, image, 20 | history, ground truth answer, answer options, dense annotations etc. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | config: Dict[str, Any], 26 | dialogs_jsonpath: str, 27 | dense_annotations_jsonpath: Optional[str] = None, 28 | overfit: bool = False, 29 | in_memory: bool = False, 30 | num_workers: int = 1, 31 | return_options: bool = True, 32 | add_boundary_toks: bool = False, 33 | ): 34 | super().__init__() 35 | self.config = config 36 | self.return_options = return_options 37 | self.add_boundary_toks = add_boundary_toks 38 | self.dialogs_reader = DialogsReader( 39 | dialogs_jsonpath, 40 | num_examples=(5 if overfit else None), 41 | num_workers=num_workers 42 | ) 43 | 44 | if "val" in self.split and dense_annotations_jsonpath is not None: 45 | self.annotations_reader = DenseAnnotationsReader( 46 | dense_annotations_jsonpath 47 | ) 48 | else: 49 | self.annotations_reader = None 50 | 51 | self.vocabulary = Vocabulary( 52 | config["word_counts_json"], min_count=config["vocab_min_count"] 53 | ) 54 | 55 | # Initialize image features reader according to split. 
56 | image_features_hdfpath = config["image_features_train_h5"] 57 | if "val" in self.dialogs_reader.split: 58 | image_features_hdfpath = config["image_features_val_h5"] 59 | elif "test" in self.dialogs_reader.split: 60 | image_features_hdfpath = config["image_features_test_h5"] 61 | 62 | self.hdf_reader = ImageFeaturesHdfReader( 63 | image_features_hdfpath, in_memory 64 | ) 65 | 66 | # Keep a list of image_ids as primary keys to access data. 67 | self.image_ids = list(self.dialogs_reader.dialogs.keys()) 68 | if overfit: 69 | self.image_ids = self.image_ids[:5] 70 | 71 | @property 72 | def split(self): 73 | return self.dialogs_reader.split 74 | 75 | def __len__(self): 76 | return len(self.image_ids) 77 | 78 | def __getitem__(self, index): 79 | # Get image_id, which serves as a primary key for current instance. 80 | image_id = self.image_ids[index] 81 | 82 | # Get image features for this image_id using hdf reader. 83 | image_features = self.hdf_reader[image_id] 84 | image_features = torch.tensor(image_features) 85 | # Normalize image features at zero-th dimension (since there's no batch 86 | # dimension). 87 | if self.config["img_norm"]: 88 | image_features = normalize(image_features, dim=0, p=2) 89 | 90 | # Retrieve instance for this image_id using json reader. 91 | visdial_instance = self.dialogs_reader[image_id] 92 | caption = visdial_instance["caption"] 93 | dialog = visdial_instance["dialog"] 94 | 95 | # Convert word tokens of caption, question, answer and answer options 96 | # to integers. 97 | caption = self.vocabulary.to_indices(caption) 98 | for i in range(len(dialog)): 99 | dialog[i]["question"] = self.vocabulary.to_indices( 100 | dialog[i]["question"] 101 | ) 102 | if self.add_boundary_toks: 103 | dialog[i]["answer"] = self.vocabulary.to_indices( 104 | [self.vocabulary.SOS_TOKEN] 105 | + dialog[i]["answer"] 106 | + [self.vocabulary.EOS_TOKEN] 107 | ) 108 | else: 109 | dialog[i]["answer"] = self.vocabulary.to_indices( 110 | dialog[i]["answer"] 111 | ) 112 | 113 | if self.return_options: 114 | for j in range(len(dialog[i]["answer_options"])): 115 | if self.add_boundary_toks: 116 | dialog[i]["answer_options"][ 117 | j 118 | ] = self.vocabulary.to_indices( 119 | [self.vocabulary.SOS_TOKEN] 120 | + dialog[i]["answer_options"][j] 121 | + [self.vocabulary.EOS_TOKEN] 122 | ) 123 | else: 124 | dialog[i]["answer_options"][ 125 | j 126 | ] = self.vocabulary.to_indices( 127 | dialog[i]["answer_options"][j] 128 | ) 129 | 130 | questions, question_lengths = self._pad_sequences( 131 | [dialog_round["question"] for dialog_round in dialog] 132 | ) 133 | history, history_lengths = self._get_history( 134 | caption, 135 | [dialog_round["question"] for dialog_round in dialog], 136 | [dialog_round["answer"] for dialog_round in dialog], 137 | ) 138 | answers_in, answer_lengths = self._pad_sequences( 139 | [dialog_round["answer"][:-1] for dialog_round in dialog] 140 | ) 141 | answers_out, _ = self._pad_sequences( 142 | [dialog_round["answer"][1:] for dialog_round in dialog] 143 | ) 144 | 145 | # Collect everything as tensors for ``collate_fn`` of dataloader to 146 | # work seamlessly questions, history, etc. are converted to 147 | # LongTensors, for nn.Embedding input. 
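# (Shape note, per the padding helpers below and the usual 10 dialog rounds:
# "ques", "ans_in" and "ans_out" are padded to (10, max_sequence_length);
# "hist" is padded to (10, max_sequence_length * 2), or
# (10, max_sequence_length * 2 * 10) when concat_history is enabled.)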
148 | item = {} 149 | item["img_ids"] = torch.tensor(image_id).long() 150 | item["img_feat"] = image_features 151 | item["ques"] = questions.long() 152 | item["hist"] = history.long() 153 | item["ans_in"] = answers_in.long() 154 | item["ans_out"] = answers_out.long() 155 | item["ques_len"] = torch.tensor(question_lengths).long() 156 | item["hist_len"] = torch.tensor(history_lengths).long() 157 | item["ans_len"] = torch.tensor(answer_lengths).long() 158 | item["num_rounds"] = torch.tensor( 159 | visdial_instance["num_rounds"] 160 | ).long() 161 | 162 | if self.return_options: 163 | if self.add_boundary_toks: 164 | answer_options_in, answer_options_out = [], [] 165 | answer_option_lengths = [] 166 | for dialog_round in dialog: 167 | options, option_lengths = self._pad_sequences( 168 | [ 169 | option[:-1] 170 | for option in dialog_round["answer_options"] 171 | ] 172 | ) 173 | answer_options_in.append(options) 174 | 175 | options, _ = self._pad_sequences( 176 | [ 177 | option[1:] 178 | for option in dialog_round["answer_options"] 179 | ] 180 | ) 181 | answer_options_out.append(options) 182 | 183 | answer_option_lengths.append(option_lengths) 184 | answer_options_in = torch.stack(answer_options_in, 0) 185 | answer_options_out = torch.stack(answer_options_out, 0) 186 | 187 | item["opt_in"] = answer_options_in.long() 188 | item["opt_out"] = answer_options_out.long() 189 | item["opt_len"] = torch.tensor(answer_option_lengths).long() 190 | else: 191 | answer_options = [] 192 | answer_option_lengths = [] 193 | for dialog_round in dialog: 194 | options, option_lengths = self._pad_sequences( 195 | dialog_round["answer_options"] 196 | ) 197 | answer_options.append(options) 198 | answer_option_lengths.append(option_lengths) 199 | answer_options = torch.stack(answer_options, 0) 200 | 201 | item["opt"] = answer_options.long() 202 | item["opt_len"] = torch.tensor(answer_option_lengths).long() 203 | 204 | if "test" not in self.split: 205 | answer_indices = [ 206 | dialog_round["gt_index"] for dialog_round in dialog 207 | ] 208 | item["ans_ind"] = torch.tensor(answer_indices).long() 209 | 210 | # Gather dense annotations. 211 | if "val" in self.split: 212 | dense_annotations = self.annotations_reader[image_id] 213 | item["gt_relevance"] = torch.tensor( 214 | dense_annotations["gt_relevance"] 215 | ).float() 216 | item["round_id"] = torch.tensor( 217 | dense_annotations["round_id"] 218 | ).long() 219 | 220 | return item 221 | 222 | def _pad_sequences(self, sequences: List[List[int]]): 223 | """Given tokenized sequences (either questions, answers or answer 224 | options, tokenized in ``__getitem__``), padding them to maximum 225 | specified sequence length. Return as a tensor of size 226 | ``(*, max_sequence_length)``. 227 | 228 | This method is only called in ``__getitem__``, chunked out separately 229 | for readability. 230 | 231 | Parameters 232 | ---------- 233 | sequences : List[List[int]] 234 | List of tokenized sequences, each sequence is typically a 235 | List[int]. 236 | 237 | Returns 238 | ------- 239 | torch.Tensor, torch.Tensor 240 | Tensor of sequences padded to max length, and length of sequences 241 | before padding. 242 | """ 243 | 244 | for i in range(len(sequences)): 245 | sequences[i] = sequences[i][ 246 | : self.config["max_sequence_length"] - 1 247 | ] 248 | sequence_lengths = [len(sequence) for sequence in sequences] 249 | 250 | # Pad all sequences to max_sequence_length. 
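# (pad_sequence only pads up to the longest sequence in this particular list,
# so a full-width tensor of PAD_INDEX is allocated first and the padded block
# is copied into it, guaranteeing a fixed max_sequence_length width.)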
251 | maxpadded_sequences = torch.full( 252 | (len(sequences), self.config["max_sequence_length"]), 253 | fill_value=self.vocabulary.PAD_INDEX, 254 | ) 255 | padded_sequences = pad_sequence( 256 | [torch.tensor(sequence) for sequence in sequences], 257 | batch_first=True, 258 | padding_value=self.vocabulary.PAD_INDEX, 259 | ) 260 | maxpadded_sequences[:, : padded_sequences.size(1)] = padded_sequences 261 | return maxpadded_sequences, sequence_lengths 262 | 263 | def _get_history( 264 | self, 265 | caption: List[int], 266 | questions: List[List[int]], 267 | answers: List[List[int]], 268 | ): 269 | # Allow double length of caption, equivalent to a concatenated QA pair. 270 | caption = caption[: self.config["max_sequence_length"] * 2 - 1] 271 | 272 | for i in range(len(questions)): 273 | questions[i] = questions[i][ 274 | : self.config["max_sequence_length"] - 1 275 | ] 276 | 277 | for i in range(len(answers)): 278 | answers[i] = answers[i][: self.config["max_sequence_length"] - 1] 279 | 280 | # History for first round is caption, else concatenated QA pair of 281 | # previous round. 282 | history = [] 283 | history.append(caption) 284 | for question, answer in zip(questions, answers): 285 | history.append(question + answer + [self.vocabulary.EOS_INDEX]) 286 | # Drop last entry from history (there's no eleventh question). 287 | history = history[:-1] 288 | max_history_length = self.config["max_sequence_length"] * 2 289 | 290 | if self.config.get("concat_history", False): 291 | # Concatenated_history has similar structure as history, except it 292 | # contains concatenated QA pairs from previous rounds. 293 | concatenated_history = [] 294 | concatenated_history.append(caption) 295 | for i in range(1, len(history)): 296 | concatenated_history.append([]) 297 | for j in range(i + 1): 298 | concatenated_history[i].extend(history[j]) 299 | 300 | max_history_length = ( 301 | self.config["max_sequence_length"] * 2 * len(history) 302 | ) 303 | history = concatenated_history 304 | 305 | history_lengths = [len(round_history) for round_history in history] 306 | maxpadded_history = torch.full( 307 | (len(history), max_history_length), 308 | fill_value=self.vocabulary.PAD_INDEX, 309 | ) 310 | padded_history = pad_sequence( 311 | [torch.tensor(round_history) for round_history in history], 312 | batch_first=True, 313 | padding_value=self.vocabulary.PAD_INDEX, 314 | ) 315 | maxpadded_history[:, : padded_history.size(1)] = padded_history 316 | return maxpadded_history, history_lengths 317 | -------------------------------------------------------------------------------- /visdialch/data/readers.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Reader simply reads data from disk and returns it almost as is, based on 3 | a "primary key", which for the case of VisDial v1.0 dataset, is the 4 | ``image_id``. Readers should be utilized by torch ``Dataset``s. Any type of 5 | data pre-processing is not recommended in the reader, such as tokenizing words 6 | to integers, embedding tokens, or passing an image through a pre-trained CNN. 7 | 8 | Each reader must atleast implement three methods: 9 | - ``__len__`` to return the length of data this Reader can read. 10 | - ``__getitem__`` to return data based on ``image_id`` in VisDial v1.0 11 | dataset. 12 | - ``keys`` to return a list of possible ``image_id``s this Reader can 13 | provide data of. 
14 | """ 15 | 16 | import copy 17 | import json 18 | import multiprocessing as mp 19 | from typing import Any, Dict, List, Optional, Set, Union 20 | 21 | import h5py 22 | 23 | # A bit slow, and just splits sentences to list of words, can be doable in 24 | # `DialogsReader`. 25 | from nltk.tokenize import word_tokenize 26 | from tqdm import tqdm 27 | 28 | 29 | class DialogsReader(object): 30 | """ 31 | A simple reader for VisDial v1.0 dialog data. The json file must have the 32 | same structure as mentioned on ``https://visualdialog.org/data``. 33 | 34 | Parameters 35 | ---------- 36 | dialogs_jsonpath : str 37 | Path to json file containing VisDial v1.0 train, val or test data. 38 | num_examples: int, optional (default = None) 39 | Process first ``num_examples`` from the split. Useful to speed up while 40 | debugging. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | dialogs_jsonpath: str, 46 | num_examples: Optional[int] = None, 47 | num_workers: int = 1, 48 | ): 49 | with open(dialogs_jsonpath, "r") as visdial_file: 50 | visdial_data = json.load(visdial_file) 51 | self._split = visdial_data["split"] 52 | 53 | # Maintain questions and answers as a dict instead of list because 54 | # they are referenced by index in dialogs. We drop elements from 55 | # these in "overfit" mode to save time (tokenization is slow). 56 | self.questions = { 57 | i: question for i, question in 58 | enumerate(visdial_data["data"]["questions"]) 59 | } 60 | self.answers = { 61 | i: answer for i, answer in 62 | enumerate(visdial_data["data"]["answers"]) 63 | } 64 | 65 | # Add empty question, answer - useful for padding dialog rounds 66 | # for test split. 67 | self.questions[-1] = "" 68 | self.answers[-1] = "" 69 | 70 | # ``image_id``` serves as key for all three dicts here. 71 | self.captions: Dict[int, Any] = {} 72 | self.dialogs: Dict[int, Any] = {} 73 | self.num_rounds: Dict[int, Any] = {} 74 | 75 | all_dialogs = visdial_data["data"]["dialogs"] 76 | 77 | # Retain only first ``num_examples`` dialogs if specified. 78 | if num_examples is not None: 79 | all_dialogs = all_dialogs[:num_examples] 80 | 81 | for _dialog in all_dialogs: 82 | 83 | self.captions[_dialog["image_id"]] = _dialog["caption"] 84 | 85 | # Record original length of dialog, before padding. 86 | # 10 for train and val splits, 10 or less for test split. 87 | self.num_rounds[_dialog["image_id"]] = len(_dialog["dialog"]) 88 | 89 | # Pad dialog at the end with empty question and answer pairs 90 | # (for test split). 91 | while len(_dialog["dialog"]) < 10: 92 | _dialog["dialog"].append({"question": -1, "answer": -1}) 93 | 94 | # Add empty answer (and answer options) if not provided 95 | # (for test split). We use "-1" as a key for empty questions 96 | # and answers. 97 | for i in range(len(_dialog["dialog"])): 98 | if "answer" not in _dialog["dialog"][i]: 99 | _dialog["dialog"][i]["answer"] = -1 100 | if "answer_options" not in _dialog["dialog"][i]: 101 | _dialog["dialog"][i]["answer_options"] = [-1] * 100 102 | 103 | self.dialogs[_dialog["image_id"]] = _dialog["dialog"] 104 | 105 | # If ``num_examples`` is specified, collect questions and answers 106 | # included in those examples, and drop the rest to save time while 107 | # tokenizing. Collecting these should be fast because num_examples 108 | # during debugging are generally small. 
109 | if num_examples is not None: 110 | questions_included: Set[int] = set() 111 | answers_included: Set[int] = set() 112 | 113 | for _dialog in self.dialogs.values(): 114 | for _dialog_round in _dialog: 115 | questions_included.add(_dialog_round["question"]) 116 | answers_included.add(_dialog_round["answer"]) 117 | for _answer_option in _dialog_round["answer_options"]: 118 | answers_included.add(_answer_option) 119 | 120 | self.questions = { 121 | i: self.questions[i] for i in questions_included 122 | } 123 | self.answers = { 124 | i: self.answers[i] for i in answers_included 125 | } 126 | 127 | self._multiprocess_tokenize(num_workers) 128 | 129 | def _multiprocess_tokenize(self, num_workers: int): 130 | """ 131 | Tokenize captions, questions and answers in parallel processes. This 132 | method uses multiprocessing module internally. 133 | 134 | Since questions, answers and captions are dicts - and multiprocessing 135 | map utilities operate on lists, we convert these to lists first and 136 | then back to dicts. 137 | 138 | Parameters 139 | ---------- 140 | num_workers: int 141 | Number of workers (processes) to run in parallel. 142 | """ 143 | 144 | # While displaying progress bar through tqdm, specify total number of 145 | # sequences to tokenize, because tqdm won't know in case of pool.imap 146 | with mp.Pool(num_workers) as pool: 147 | print(f"[{self._split}] Tokenizing questions...") 148 | _question_tuples = self.questions.items() 149 | _question_indices = [t[0] for t in _question_tuples] 150 | _questions = list( 151 | tqdm( 152 | pool.imap(word_tokenize, [t[1] for t in _question_tuples]), 153 | total=len(self.questions) 154 | ) 155 | ) 156 | self.questions = { 157 | i: question + ["?"] for i, question in 158 | zip(_question_indices, _questions) 159 | } 160 | # Delete variables to free memory. 161 | del _question_tuples, _question_indices, _questions 162 | 163 | print(f"[{self._split}] Tokenizing answers...") 164 | _answer_tuples = self.answers.items() 165 | _answer_indices = [t[0] for t in _answer_tuples] 166 | _answers = list( 167 | tqdm( 168 | pool.imap(word_tokenize, [t[1] for t in _answer_tuples]), 169 | total=len(self.answers) 170 | ) 171 | ) 172 | self.answers = { 173 | i: answer + ["?"] for i, answer in 174 | zip(_answer_indices, _answers) 175 | } 176 | # Delete variables to free memory. 177 | del _answer_tuples, _answer_indices, _answers 178 | 179 | print(f"[{self._split}] Tokenizing captions...") 180 | # Convert dict to separate lists of image_ids and captions. 181 | _caption_tuples = self.captions.items() 182 | _image_ids = [t[0] for t in _caption_tuples] 183 | _captions = list( 184 | tqdm( 185 | pool.imap(word_tokenize, [t[1] for t in _caption_tuples]), 186 | total=(len(_caption_tuples)) 187 | ) 188 | ) 189 | # Convert tokenized captions back to a dict. 190 | self.captions = {i: c for i, c in zip(_image_ids, _captions)} 191 | 192 | def __len__(self): 193 | return len(self.dialogs) 194 | 195 | def __getitem__(self, image_id: int) -> Dict[str, Union[int, str, List]]: 196 | caption_for_image = self.captions[image_id] 197 | dialog = copy.copy(self.dialogs[image_id]) 198 | num_rounds = self.num_rounds[image_id] 199 | 200 | # Replace question and answer indices with actual word tokens. 
201 | for i in range(len(dialog)): 202 | dialog[i]["question"] = self.questions[ 203 | dialog[i]["question"] 204 | ] 205 | dialog[i]["answer"] = self.answers[ 206 | dialog[i]["answer"] 207 | ] 208 | for j, answer_option in enumerate( 209 | dialog[i]["answer_options"] 210 | ): 211 | dialog[i]["answer_options"][j] = self.answers[ 212 | answer_option 213 | ] 214 | 215 | return { 216 | "image_id": image_id, 217 | "caption": caption_for_image, 218 | "dialog": dialog, 219 | "num_rounds": num_rounds, 220 | } 221 | 222 | def keys(self) -> List[int]: 223 | return list(self.dialogs.keys()) 224 | 225 | @property 226 | def split(self): 227 | return self._split 228 | 229 | 230 | class DenseAnnotationsReader(object): 231 | """ 232 | A reader for dense annotations for val split. The json file must have the 233 | same structure as mentioned on ``https://visualdialog.org/data``. 234 | 235 | Parameters 236 | ---------- 237 | dense_annotations_jsonpath : str 238 | Path to a json file containing VisDial v1.0 239 | """ 240 | 241 | def __init__(self, dense_annotations_jsonpath: str): 242 | with open(dense_annotations_jsonpath, "r") as visdial_file: 243 | self._visdial_data = json.load(visdial_file) 244 | self._image_ids = [ 245 | entry["image_id"] for entry in self._visdial_data 246 | ] 247 | 248 | def __len__(self): 249 | return len(self._image_ids) 250 | 251 | def __getitem__(self, image_id: int) -> Dict[str, Union[int, List]]: 252 | index = self._image_ids.index(image_id) 253 | # keys: {"image_id", "round_id", "gt_relevance"} 254 | return self._visdial_data[index] 255 | 256 | @property 257 | def split(self): 258 | # always 259 | return "val" 260 | 261 | 262 | class ImageFeaturesHdfReader(object): 263 | """ 264 | A reader for HDF files containing pre-extracted image features. A typical 265 | HDF file is expected to have a column named "image_id", and another column 266 | named "features". 267 | 268 | Example of an HDF file: 269 | ``` 270 | visdial_train_faster_rcnn_bottomup_features.h5 271 | |--- "image_id" [shape: (num_images, )] 272 | |--- "features" [shape: (num_images, num_proposals, feature_size)] 273 | +--- .attrs ("split", "train") 274 | ``` 275 | Refer ``$PROJECT_ROOT/data/extract_bottomup.py`` script for more details 276 | about HDF structure. 277 | 278 | Parameters 279 | ---------- 280 | features_hdfpath : str 281 | Path to an HDF file containing VisDial v1.0 train, val or test split 282 | image features. 283 | in_memory : bool 284 | Whether to load the whole HDF file in memory. Beware, these files are 285 | sometimes tens of GBs in size. Set this to true if you have sufficient 286 | RAM - trade-off between speed and memory. 287 | """ 288 | 289 | def __init__(self, features_hdfpath: str, in_memory: bool = False): 290 | self.features_hdfpath = features_hdfpath 291 | self._in_memory = in_memory 292 | 293 | with h5py.File(self.features_hdfpath, "r") as features_hdf: 294 | self._split = features_hdf.attrs["split"] 295 | self._image_id_list = list(features_hdf["image_id"]) 296 | # "features" is List[np.ndarray] if the dataset is loaded in-memory 297 | # If not loaded in memory, then list of None. 298 | self.features = [None] * len(self._image_id_list) 299 | 300 | def __len__(self): 301 | return len(self._image_id_list) 302 | 303 | def __getitem__(self, image_id: int): 304 | index = self._image_id_list.index(image_id) 305 | if self._in_memory: 306 | # Load features during first epoch, all not loaded together as it 307 | # has a slow start. 
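# When ``in_memory`` is True, caching is lazy: a feature vector is read from
# the HDF file the first time its image_id is requested and stored in
# ``self.features``, so subsequent epochs skip the disk read for that entry.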
308 | if self.features[index] is not None:
309 | image_id_features = self.features[index]
310 | else:
311 | with h5py.File(self.features_hdfpath, "r") as features_hdf:
312 | image_id_features = features_hdf["features"][index]
313 | self.features[index] = image_id_features
314 | else:
315 | # Read chunk from file every time if not loaded in memory.
316 | with h5py.File(self.features_hdfpath, "r") as features_hdf:
317 | image_id_features = features_hdf["features"][index]
318 |
319 | return image_id_features
320 |
321 | def keys(self) -> List[int]:
322 | return self._image_id_list
323 |
324 | @property
325 | def split(self):
326 | return self._split
327 | --------------------------------------------------------------------------------
/visdialch/data/vocabulary.py:
--------------------------------------------------------------------------------
1 | """
2 | A Vocabulary maintains a mapping between words and corresponding unique
3 | integers, holds special integers (tokens) for indicating start and end of
4 | sequence, and offers functionality to map out-of-vocabulary words to the
5 | unknown (UNK) token.
6 | """
7 | import json
8 | import os
9 | from typing import List
10 |
11 |
12 | class Vocabulary(object):
13 | """
14 | A simple Vocabulary class which maintains a mapping between words and
15 | integer tokens. Can be initialized either by word counts from the VisDial
16 | v1.0 train dataset, or a pre-saved vocabulary mapping.
17 |
18 | Parameters
19 | ----------
20 | word_counts_path: str
21 | Path to a json file containing counts of each word across captions,
22 | questions and answers of the VisDial v1.0 train dataset.
23 | min_count : int, optional (default=5)
24 | When initializing the vocabulary from word counts, you can specify a
25 | minimum count, and every token with a count less than this will be
26 | excluded from the vocabulary.
27 | """
28 |
29 | PAD_TOKEN = "<PAD>"
30 | SOS_TOKEN = "<S>"
31 | EOS_TOKEN = "</S>"
32 | UNK_TOKEN = "<UNK>"
33 |
34 | PAD_INDEX = 0
35 | SOS_INDEX = 1
36 | EOS_INDEX = 2
37 | UNK_INDEX = 3
38 |
39 | def __init__(self, word_counts_path: str, min_count: int = 5):
40 | if not os.path.exists(word_counts_path):
41 | raise FileNotFoundError(
42 | f"Word counts do not exist at {word_counts_path}"
43 | )
44 |
45 | with open(word_counts_path, "r") as word_counts_file:
46 | word_counts = json.load(word_counts_file)
47 |
48 | # form a list of (word, count) tuples and apply min_count threshold
49 | word_counts = [
50 | (word, count)
51 | for word, count in word_counts.items()
52 | if count >= min_count
53 | ]
54 | # sort in descending order of word counts
55 | word_counts = sorted(word_counts, key=lambda wc: -wc[1])
56 | words = [w[0] for w in word_counts]
57 |
58 | self.word2index = {}
59 | self.word2index[self.PAD_TOKEN] = self.PAD_INDEX
60 | self.word2index[self.SOS_TOKEN] = self.SOS_INDEX
61 | self.word2index[self.EOS_TOKEN] = self.EOS_INDEX
62 | self.word2index[self.UNK_TOKEN] = self.UNK_INDEX
63 | for index, word in enumerate(words):
64 | self.word2index[word] = index + 4
65 |
66 | self.index2word = {
67 | index: word for word, index in self.word2index.items()
68 | }
69 |
70 | @classmethod
71 | def from_saved(cls, saved_vocabulary_path: str) -> "Vocabulary":
72 | """Build the vocabulary from a json file saved by the ``save`` method.
73 |
74 | Parameters
75 | ----------
76 | saved_vocabulary_path : str
77 | Path to a json file containing word to integer mappings
78 | (saved vocabulary).
79 | """ 80 | with open(saved_vocabulary_path, "r") as saved_vocabulary_file: 81 | cls.word2index = json.load(saved_vocabulary_file) 82 | cls.index2word = { 83 | index: word for word, index in cls.word2index.items() 84 | } 85 | 86 | def to_indices(self, words: List[str]) -> List[int]: 87 | return [self.word2index.get(word, self.UNK_INDEX) for word in words] 88 | 89 | def to_words(self, indices: List[int]) -> List[str]: 90 | return [ 91 | self.index2word.get(index, self.UNK_TOKEN) for index in indices 92 | ] 93 | 94 | def save(self, save_vocabulary_path: str) -> None: 95 | with open(save_vocabulary_path, "w") as save_vocabulary_file: 96 | json.dump(self.word2index, save_vocabulary_file) 97 | 98 | def __len__(self): 99 | return len(self.index2word) 100 | -------------------------------------------------------------------------------- /visdialch/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from visdialch.decoders.disc import DiscriminativeDecoder 2 | from visdialch.decoders.gen import GenerativeDecoder 3 | 4 | 5 | def Decoder(model_config, *args): 6 | name_dec_map = {"disc": DiscriminativeDecoder, "gen": GenerativeDecoder} 7 | return name_dec_map[model_config["decoder"]](model_config, *args) 8 | -------------------------------------------------------------------------------- /visdialch/decoders/disc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from visdialch.utils import DynamicRNN 5 | 6 | 7 | class DiscriminativeDecoder(nn.Module): 8 | def __init__(self, config, vocabulary): 9 | super().__init__() 10 | self.config = config 11 | 12 | self.word_embed = nn.Embedding( 13 | len(vocabulary), 14 | config["word_embedding_size"], 15 | padding_idx=vocabulary.PAD_INDEX, 16 | ) 17 | self.option_rnn = nn.LSTM( 18 | config["word_embedding_size"], 19 | config["lstm_hidden_size"], 20 | config["lstm_num_layers"], 21 | batch_first=True, 22 | dropout=config["dropout"], 23 | ) 24 | 25 | # Options are variable length padded sequences, use DynamicRNN. 26 | self.option_rnn = DynamicRNN(self.option_rnn) 27 | 28 | def forward(self, encoder_output, batch): 29 | """Given `encoder_output` + candidate option sequences, predict a score 30 | for each option sequence. 31 | 32 | Parameters 33 | ---------- 34 | encoder_output: torch.Tensor 35 | Output from the encoder through its forward pass. 36 | (batch_size, num_rounds, lstm_hidden_size) 37 | """ 38 | 39 | options = batch["opt"] 40 | batch_size, num_rounds, num_options, max_sequence_length = ( 41 | options.size() 42 | ) 43 | options = options.view( 44 | batch_size * num_rounds * num_options, max_sequence_length 45 | ) 46 | 47 | options_length = batch["opt_len"] 48 | options_length = options_length.view( 49 | batch_size * num_rounds * num_options 50 | ) 51 | 52 | # Pick options with non-zero length (relevant for test split). 
53 | nonzero_options_length_indices = options_length.nonzero().squeeze() 54 | nonzero_options_length = options_length[nonzero_options_length_indices] 55 | nonzero_options = options[nonzero_options_length_indices] 56 | 57 | # shape: (batch_size * num_rounds * num_options, max_sequence_length, 58 | # word_embedding_size) 59 | # FOR TEST SPLIT, shape: (batch_size * 1, num_options, 60 | # max_sequence_length, word_embedding_size) 61 | nonzero_options_embed = self.word_embed(nonzero_options) 62 | 63 | # shape: (batch_size * num_rounds * num_options, lstm_hidden_size) 64 | # FOR TEST SPLIT, shape: (batch_size * 1, num_options, 65 | # lstm_hidden_size) 66 | _, (nonzero_options_embed, _) = self.option_rnn( 67 | nonzero_options_embed, nonzero_options_length 68 | ) 69 | 70 | options_embed = torch.zeros( 71 | batch_size * num_rounds * num_options, 72 | nonzero_options_embed.size(-1), 73 | device=nonzero_options_embed.device, 74 | ) 75 | options_embed[nonzero_options_length_indices] = nonzero_options_embed 76 | 77 | # Repeat encoder output for every option. 78 | # shape: (batch_size, num_rounds, num_options, max_sequence_length) 79 | encoder_output = encoder_output.unsqueeze(2).repeat( 80 | 1, 1, num_options, 1 81 | ) 82 | 83 | # Shape now same as `options`, can calculate dot product similarity. 84 | # shape: (batch_size * num_rounds * num_options, lstm_hidden_state) 85 | encoder_output = encoder_output.view( 86 | batch_size * num_rounds * num_options, 87 | self.config["lstm_hidden_size"], 88 | ) 89 | 90 | # shape: (batch_size * num_rounds * num_options) 91 | scores = torch.sum(options_embed * encoder_output, 1) 92 | # shape: (batch_size, num_rounds, num_options) 93 | scores = scores.view(batch_size, num_rounds, num_options) 94 | return scores 95 | -------------------------------------------------------------------------------- /visdialch/decoders/gen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class GenerativeDecoder(nn.Module): 6 | def __init__(self, config, vocabulary): 7 | super().__init__() 8 | self.config = config 9 | 10 | self.word_embed = nn.Embedding( 11 | len(vocabulary), 12 | config["word_embedding_size"], 13 | padding_idx=vocabulary.PAD_INDEX, 14 | ) 15 | self.answer_rnn = nn.LSTM( 16 | config["word_embedding_size"], 17 | config["lstm_hidden_size"], 18 | config["lstm_num_layers"], 19 | batch_first=True, 20 | dropout=config["dropout"], 21 | ) 22 | 23 | self.lstm_to_words = nn.Linear( 24 | self.config["lstm_hidden_size"], len(vocabulary) 25 | ) 26 | 27 | self.dropout = nn.Dropout(p=config["dropout"]) 28 | self.logsoftmax = nn.LogSoftmax(dim=-1) 29 | 30 | def forward(self, encoder_output, batch): 31 | """Given `encoder_output`, learn to autoregressively predict 32 | ground-truth answer word-by-word during training. 33 | 34 | During evaluation, assign log-likelihood scores to all answer options. 35 | 36 | Parameters 37 | ---------- 38 | encoder_output: torch.Tensor 39 | Output from the encoder through its forward pass. 40 | (batch_size, num_rounds, lstm_hidden_size) 41 | """ 42 | 43 | if self.training: 44 | 45 | ans_in = batch["ans_in"] 46 | batch_size, num_rounds, max_sequence_length = ans_in.size() 47 | 48 | ans_in = ans_in.view(batch_size * num_rounds, max_sequence_length) 49 | 50 | # shape: (batch_size * num_rounds, max_sequence_length, 51 | # word_embedding_size) 52 | ans_in_embed = self.word_embed(ans_in) 53 | 54 | # reshape encoder output to be set as initial hidden state of LSTM. 
55 | # shape: (lstm_num_layers, batch_size * num_rounds, 56 | # lstm_hidden_size) 57 | init_hidden = encoder_output.view(1, batch_size * num_rounds, -1) 58 | init_hidden = init_hidden.repeat( 59 | self.config["lstm_num_layers"], 1, 1 60 | ) 61 | init_cell = torch.zeros_like(init_hidden) 62 | 63 | # shape: (batch_size * num_rounds, max_sequence_length, 64 | # lstm_hidden_size) 65 | ans_out, (hidden, cell) = self.answer_rnn( 66 | ans_in_embed, (init_hidden, init_cell) 67 | ) 68 | ans_out = self.dropout(ans_out) 69 | 70 | # shape: (batch_size * num_rounds, max_sequence_length, 71 | # vocabulary_size) 72 | ans_word_scores = self.lstm_to_words(ans_out) 73 | return ans_word_scores 74 | 75 | else: 76 | 77 | ans_in = batch["opt_in"] 78 | batch_size, num_rounds, num_options, max_sequence_length = ( 79 | ans_in.size() 80 | ) 81 | 82 | ans_in = ans_in.view( 83 | batch_size * num_rounds * num_options, max_sequence_length 84 | ) 85 | 86 | # shape: (batch_size * num_rounds * num_options, max_sequence_length 87 | # word_embedding_size) 88 | ans_in_embed = self.word_embed(ans_in) 89 | 90 | # reshape encoder output to be set as initial hidden state of LSTM. 91 | # shape: (lstm_num_layers, batch_size * num_rounds * num_options, 92 | # lstm_hidden_size) 93 | init_hidden = encoder_output.view(batch_size, num_rounds, 1, -1) 94 | init_hidden = init_hidden.repeat(1, 1, num_options, 1) 95 | init_hidden = init_hidden.view( 96 | 1, batch_size * num_rounds * num_options, -1 97 | ) 98 | init_hidden = init_hidden.repeat( 99 | self.config["lstm_num_layers"], 1, 1 100 | ) 101 | init_cell = torch.zeros_like(init_hidden) 102 | 103 | # shape: (batch_size * num_rounds * num_options, 104 | # max_sequence_length, lstm_hidden_size) 105 | ans_out, (hidden, cell) = self.answer_rnn( 106 | ans_in_embed, (init_hidden, init_cell) 107 | ) 108 | 109 | # shape: (batch_size * num_rounds * num_options, 110 | # max_sequence_length, vocabulary_size) 111 | ans_word_scores = self.logsoftmax(self.lstm_to_words(ans_out)) 112 | 113 | # shape: (batch_size * num_rounds * num_options, 114 | # max_sequence_length) 115 | target_ans_out = batch["opt_out"].view( 116 | batch_size * num_rounds * num_options, -1 117 | ) 118 | 119 | # shape: (batch_size * num_rounds * num_options, 120 | # max_sequence_length) 121 | ans_word_scores = torch.gather( 122 | ans_word_scores, -1, target_ans_out.unsqueeze(-1) 123 | ).squeeze() 124 | ans_word_scores = ( 125 | ans_word_scores * (target_ans_out > 0).float().cuda() 126 | ) # ugly 127 | 128 | ans_scores = torch.sum(ans_word_scores, -1) 129 | ans_scores = ans_scores.view(batch_size, num_rounds, num_options) 130 | 131 | return ans_scores 132 | -------------------------------------------------------------------------------- /visdialch/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from visdialch.encoders.lf import LateFusionEncoder 2 | 3 | 4 | def Encoder(model_config, *args): 5 | name_enc_map = {"lf": LateFusionEncoder} 6 | return name_enc_map[model_config["encoder"]](model_config, *args) 7 | -------------------------------------------------------------------------------- /visdialch/encoders/lf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from visdialch.utils import DynamicRNN 6 | 7 | 8 | class LateFusionEncoder(nn.Module): 9 | def __init__(self, config, vocabulary): 10 | super().__init__() 11 | self.config = config 12 | 13 | 
self.word_embed = nn.Embedding( 14 | len(vocabulary), 15 | config["word_embedding_size"], 16 | padding_idx=vocabulary.PAD_INDEX, 17 | ) 18 | self.hist_rnn = nn.LSTM( 19 | config["word_embedding_size"], 20 | config["lstm_hidden_size"], 21 | config["lstm_num_layers"], 22 | batch_first=True, 23 | dropout=config["dropout"], 24 | ) 25 | self.ques_rnn = nn.LSTM( 26 | config["word_embedding_size"], 27 | config["lstm_hidden_size"], 28 | config["lstm_num_layers"], 29 | batch_first=True, 30 | dropout=config["dropout"], 31 | ) 32 | self.dropout = nn.Dropout(p=config["dropout"]) 33 | 34 | # questions and history are right padded sequences of variable length 35 | # use the DynamicRNN utility module to handle them properly 36 | self.hist_rnn = DynamicRNN(self.hist_rnn) 37 | self.ques_rnn = DynamicRNN(self.ques_rnn) 38 | 39 | # project image features to lstm_hidden_size for computing attention 40 | self.image_features_projection = nn.Linear( 41 | config["img_feature_size"], config["lstm_hidden_size"] 42 | ) 43 | 44 | # fc layer for image * question to attention weights 45 | self.attention_proj = nn.Linear(config["lstm_hidden_size"], 1) 46 | 47 | # fusion layer (attended_image_features + question + history) 48 | fusion_size = ( 49 | config["img_feature_size"] + config["lstm_hidden_size"] * 2 50 | ) 51 | self.fusion = nn.Linear(fusion_size, config["lstm_hidden_size"]) 52 | 53 | nn.init.kaiming_uniform_(self.image_features_projection.weight) 54 | nn.init.constant_(self.image_features_projection.bias, 0) 55 | nn.init.kaiming_uniform_(self.fusion.weight) 56 | nn.init.constant_(self.fusion.bias, 0) 57 | 58 | def forward(self, batch): 59 | # shape: (batch_size, img_feature_size) - CNN fc7 features 60 | # shape: (batch_size, num_proposals, img_feature_size) - RCNN features 61 | img = batch["img_feat"] 62 | # shape: (batch_size, 10, max_sequence_length) 63 | ques = batch["ques"] 64 | # shape: (batch_size, 10, max_sequence_length * 2 * 10) 65 | # concatenated qa * 10 rounds 66 | hist = batch["hist"] 67 | # num_rounds = 10, even for test (padded dialog rounds at the end) 68 | batch_size, num_rounds, max_sequence_length = ques.size() 69 | 70 | # embed questions 71 | ques = ques.view(batch_size * num_rounds, max_sequence_length) 72 | ques_embed = self.word_embed(ques) 73 | 74 | # shape: (batch_size * num_rounds, max_sequence_length, 75 | # lstm_hidden_size) 76 | _, (ques_embed, _) = self.ques_rnn(ques_embed, batch["ques_len"]) 77 | 78 | # project down image features and ready for attention 79 | # shape: (batch_size, num_proposals, lstm_hidden_size) 80 | projected_image_features = self.image_features_projection(img) 81 | 82 | # repeat image feature vectors to be provided for every round 83 | # shape: (batch_size * num_rounds, num_proposals, lstm_hidden_size) 84 | projected_image_features = ( 85 | projected_image_features.view( 86 | batch_size, 1, -1, self.config["lstm_hidden_size"] 87 | ) 88 | .repeat(1, num_rounds, 1, 1) 89 | .view(batch_size * num_rounds, -1, self.config["lstm_hidden_size"]) 90 | ) 91 | 92 | # computing attention weights 93 | # shape: (batch_size * num_rounds, num_proposals) 94 | projected_ques_features = ques_embed.unsqueeze(1).repeat( 95 | 1, img.shape[1], 1 96 | ) 97 | projected_ques_image = ( 98 | projected_ques_features * projected_image_features 99 | ) 100 | projected_ques_image = self.dropout(projected_ques_image) 101 | image_attention_weights = self.attention_proj( 102 | projected_ques_image 103 | ).squeeze() 104 | image_attention_weights = F.softmax(image_attention_weights, dim=-1) 105 | 
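# The softmax above normalizes the attention weights over proposals for each
# question round; the block below repeats the raw image features per round
# and takes the attention-weighted sum to obtain a single attended image
# feature vector per round.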
106 | # shape: (batch_size * num_rounds, num_proposals, img_feature_size)
107 | img = (
108 | img.view(batch_size, 1, -1, self.config["img_feature_size"])
109 | .repeat(1, num_rounds, 1, 1)
110 | .view(batch_size * num_rounds, -1, self.config["img_feature_size"])
111 | )
112 |
113 | # multiply image features with their attention weights
114 | # shape: (batch_size * num_rounds, num_proposals, img_feature_size)
115 | image_attention_weights = image_attention_weights.unsqueeze(-1).repeat(
116 | 1, 1, self.config["img_feature_size"]
117 | )
118 | # shape: (batch_size * num_rounds, img_feature_size)
119 | attended_image_features = (image_attention_weights * img).sum(1)
120 | img = attended_image_features
121 |
122 | # embed history
123 | hist = hist.view(batch_size * num_rounds, max_sequence_length * 20)
124 | hist_embed = self.word_embed(hist)
125 |
126 | # shape: (batch_size * num_rounds, lstm_hidden_size)
127 | _, (hist_embed, _) = self.hist_rnn(hist_embed, batch["hist_len"])
128 |
129 | fused_vector = torch.cat((img, ques_embed, hist_embed), 1)
130 | fused_vector = self.dropout(fused_vector)
131 |
132 | fused_embedding = torch.tanh(self.fusion(fused_vector))
133 | # shape: (batch_size, num_rounds, lstm_hidden_size)
134 | fused_embedding = fused_embedding.view(batch_size, num_rounds, -1)
135 | return fused_embedding
136 | --------------------------------------------------------------------------------
/visdialch/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | A Metric observes the output of a certain model, for example in the form of
3 | logits or scores, and accumulates a particular metric with reference to some
4 | provided targets. In the context of VisDial, we use Recall (@ 1, 5, 10), Mean
5 | Rank, Mean Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG).
6 |
7 | Each ``Metric`` must implement at least three methods:
8 | - ``observe`` to update the accumulated metric with currently observed
9 | outputs and targets.
10 | - ``retrieve`` to return the accumulated metric, and optionally reset the
11 | internally accumulated metric (this is commonly done between two epochs,
12 | after validation).
13 | - ``reset`` to explicitly reset the internally accumulated metric.
14 |
15 | Caveat: if you wish to implement your own class of Metric, make sure you call
16 | ``detach`` on output tensors (like logits), else it will cause memory leaks.
17 | """
18 | import torch
19 |
20 |
21 | def scores_to_ranks(scores: torch.Tensor):
22 | """Convert model output scores into ranks."""
23 | batch_size, num_rounds, num_options = scores.size()
24 | scores = scores.view(-1, num_options)
25 |
26 | # sort in descending order - largest score gets highest rank
27 | sorted_ranks, ranked_idx = scores.sort(1, descending=True)
28 |
29 | # The i-th position in ranked_idx specifies which score should take this
30 | # position, but we want the i-th position to hold the rank of the score at
31 | # that position, so invert this permutation.
32 | ranks = ranked_idx.clone().fill_(0)
33 | for i in range(ranked_idx.size(0)):
34 | for j in range(num_options):
35 | ranks[i][ranked_idx[i][j]] = j
36 | # convert from 0-99 ranks to 1-100 ranks
37 | ranks += 1
38 | ranks = ranks.view(batch_size, num_rounds, num_options)
39 | return ranks
40 |
41 |
42 | class SparseGTMetrics(object):
43 | """
44 | A class to accumulate all metrics with sparse ground truth annotations.
45 | These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
46 | """ 47 | 48 | def __init__(self): 49 | self._rank_list = [] 50 | 51 | def observe( 52 | self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor 53 | ): 54 | predicted_scores = predicted_scores.detach() 55 | 56 | # shape: (batch_size, num_rounds, num_options) 57 | predicted_ranks = scores_to_ranks(predicted_scores) 58 | batch_size, num_rounds, num_options = predicted_ranks.size() 59 | 60 | # collapse batch dimension 61 | predicted_ranks = predicted_ranks.view( 62 | batch_size * num_rounds, num_options 63 | ) 64 | 65 | # shape: (batch_size * num_rounds, ) 66 | target_ranks = target_ranks.view(batch_size * num_rounds).long() 67 | 68 | # shape: (batch_size * num_rounds, ) 69 | predicted_gt_ranks = predicted_ranks[ 70 | torch.arange(batch_size * num_rounds), target_ranks 71 | ] 72 | self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy())) 73 | 74 | def retrieve(self, reset: bool = True): 75 | num_examples = len(self._rank_list) 76 | if num_examples > 0: 77 | # convert to numpy array for easy calculation. 78 | __rank_list = torch.tensor(self._rank_list).float() 79 | metrics = { 80 | "r@1": torch.mean((__rank_list <= 1).float()).item(), 81 | "r@5": torch.mean((__rank_list <= 5).float()).item(), 82 | "r@10": torch.mean((__rank_list <= 10).float()).item(), 83 | "mean": torch.mean(__rank_list).item(), 84 | "mrr": torch.mean(__rank_list.reciprocal()).item(), 85 | } 86 | else: 87 | metrics = {} 88 | 89 | if reset: 90 | self.reset() 91 | return metrics 92 | 93 | def reset(self): 94 | self._rank_list = [] 95 | 96 | 97 | class NDCG(object): 98 | def __init__(self): 99 | self._ndcg_numerator = 0.0 100 | self._ndcg_denominator = 0.0 101 | 102 | def observe( 103 | self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor 104 | ): 105 | """ 106 | Observe model output scores and target ground truth relevance and 107 | accumulate NDCG metric. 108 | 109 | Parameters 110 | ---------- 111 | predicted_scores: torch.Tensor 112 | A tensor of shape (batch_size, num_options), because dense 113 | annotations are available for 1 randomly picked round out of 10. 114 | target_relevance: torch.Tensor 115 | A tensor of shape same as predicted scores, indicating ground truth 116 | relevance of each answer option for a particular round. 117 | """ 118 | predicted_scores = predicted_scores.detach() 119 | 120 | # shape: (batch_size, 1, num_options) 121 | predicted_scores = predicted_scores.unsqueeze(1) 122 | predicted_ranks = scores_to_ranks(predicted_scores) 123 | 124 | # shape: (batch_size, num_options) 125 | predicted_ranks = predicted_ranks.squeeze() 126 | batch_size, num_options = predicted_ranks.size() 127 | 128 | k = torch.sum(target_relevance != 0, dim=-1) 129 | 130 | # shape: (batch_size, num_options) 131 | _, rankings = torch.sort(predicted_ranks, dim=-1) 132 | # Sort relevance in descending order so highest relevance gets top rnk. 
133 | _, best_rankings = torch.sort( 134 | target_relevance, dim=-1, descending=True 135 | ) 136 | 137 | # shape: (batch_size, ) 138 | batch_ndcg = [] 139 | for batch_index in range(batch_size): 140 | num_relevant = k[batch_index] 141 | dcg = self._dcg( 142 | rankings[batch_index][:num_relevant], 143 | target_relevance[batch_index], 144 | ) 145 | best_dcg = self._dcg( 146 | best_rankings[batch_index][:num_relevant], 147 | target_relevance[batch_index], 148 | ) 149 | batch_ndcg.append(dcg / best_dcg) 150 | 151 | self._ndcg_denominator += batch_size 152 | self._ndcg_numerator += sum(batch_ndcg) 153 | 154 | def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor): 155 | sorted_relevance = relevance[rankings].cpu().float() 156 | discounts = torch.log2(torch.arange(len(rankings)).float() + 2) 157 | return torch.sum(sorted_relevance / discounts, dim=-1) 158 | 159 | def retrieve(self, reset: bool = True): 160 | if self._ndcg_denominator > 0: 161 | metrics = { 162 | "ndcg": float(self._ndcg_numerator / self._ndcg_denominator) 163 | } 164 | else: 165 | metrics = {} 166 | 167 | if reset: 168 | self.reset() 169 | return metrics 170 | 171 | def reset(self): 172 | self._ndcg_numerator = 0.0 173 | self._ndcg_denominator = 0.0 174 | -------------------------------------------------------------------------------- /visdialch/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class EncoderDecoderModel(nn.Module): 5 | """Convenience wrapper module, wrapping Encoder and Decoder modules. 6 | 7 | Parameters 8 | ---------- 9 | encoder: nn.Module 10 | decoder: nn.Module 11 | """ 12 | 13 | def __init__(self, encoder, decoder): 14 | super().__init__() 15 | self.encoder = encoder 16 | self.decoder = decoder 17 | 18 | def forward(self, batch): 19 | encoder_output = self.encoder(batch) 20 | decoder_output = self.decoder(encoder_output, batch) 21 | return decoder_output 22 | -------------------------------------------------------------------------------- /visdialch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dynamic_rnn import DynamicRNN # noqa: F401 2 | -------------------------------------------------------------------------------- /visdialch/utils/checkpointing.py: -------------------------------------------------------------------------------- 1 | """ 2 | A checkpoint manager periodically saves model and optimizer as .pth 3 | files during training. 4 | 5 | Checkpoint managers help with experiment reproducibility, they record 6 | the commit SHA of your current codebase in the checkpoint saving 7 | directory. While loading any checkpoint from other commit, they raise a 8 | friendly warning, a signal to inspect commit diffs for potential bugs. 9 | Moreover, they copy experiment hyper-parameters as a YAML config in 10 | this directory. 11 | 12 | That said, always run your experiments after committing your changes, 13 | this doesn't account for untracked or staged, but uncommitted changes. 14 | """ 15 | from pathlib import Path 16 | from subprocess import PIPE, Popen 17 | import warnings 18 | 19 | import torch 20 | from torch import nn, optim 21 | import yaml 22 | 23 | 24 | class CheckpointManager(object): 25 | """A checkpoint manager saves state dicts of model and optimizer 26 | as .pth files in a specified directory. This class closely follows 27 | the API of PyTorch optimizers and learning rate schedulers. 
28 | 29 | Note:: 30 | For ``DataParallel`` modules, ``model.module.state_dict()`` is 31 | saved, instead of ``model.state_dict()``. 32 | 33 | Parameters 34 | ---------- 35 | model: nn.Module 36 | Wrapped model, which needs to be checkpointed. 37 | optimizer: optim.Optimizer 38 | Wrapped optimizer which needs to be checkpointed. 39 | checkpoint_dirpath: str 40 | Path to an empty or non-existent directory to save checkpoints. 41 | step_size: int, optional (default=1) 42 | Period of saving checkpoints. 43 | last_epoch: int, optional (default=-1) 44 | The index of last epoch. 45 | 46 | Example 47 | -------- 48 | >>> model = torch.nn.Linear(10, 2) 49 | >>> optimizer = torch.optim.Adam(model.parameters()) 50 | >>> ckpt_manager = CheckpointManager(model, optimizer, "/tmp/ckpt") 51 | >>> for epoch in range(20): 52 | ... for batch in dataloader: 53 | ... do_iteration(batch) 54 | ... ckpt_manager.step() 55 | """ 56 | 57 | def __init__( 58 | self, 59 | model, 60 | optimizer, 61 | checkpoint_dirpath, 62 | step_size=1, 63 | last_epoch=-1, 64 | **kwargs, 65 | ): 66 | 67 | if not isinstance(model, nn.Module): 68 | raise TypeError("{} is not a Module".format(type(model).__name__)) 69 | 70 | if not isinstance(optimizer, optim.Optimizer): 71 | raise TypeError( 72 | "{} is not an Optimizer".format(type(optimizer).__name__) 73 | ) 74 | 75 | self.model = model 76 | self.optimizer = optimizer 77 | self.ckpt_dirpath = Path(checkpoint_dirpath) 78 | self.step_size = step_size 79 | self.last_epoch = last_epoch 80 | self.init_directory(**kwargs) 81 | 82 | def init_directory(self, config={}): 83 | """Initialize empty checkpoint directory and record commit SHA 84 | in it. Also save hyper-parameters config in this directory to 85 | associate checkpoints with their hyper-parameters. 86 | """ 87 | 88 | self.ckpt_dirpath.mkdir(parents=True, exist_ok=True) 89 | # save current git commit hash in this checkpoint directory 90 | commit_sha_subprocess = Popen( 91 | ["git", "rev-parse", "--short", "HEAD"], stdout=PIPE, stderr=PIPE 92 | ) 93 | commit_sha, _ = commit_sha_subprocess.communicate() 94 | commit_sha = commit_sha.decode("utf-8").strip().replace("\n", "") 95 | commit_sha_filepath = self.ckpt_dirpath / f".commit-{commit_sha}" 96 | commit_sha_filepath.touch() 97 | yaml.dump( 98 | config, 99 | open(str(self.ckpt_dirpath / "config.yml"), "w"), 100 | default_flow_style=False, 101 | ) 102 | 103 | def step(self, epoch=None): 104 | """Save checkpoint if step size conditions meet. """ 105 | 106 | if not epoch: 107 | epoch = self.last_epoch + 1 108 | self.last_epoch = epoch 109 | 110 | if not self.last_epoch % self.step_size: 111 | torch.save( 112 | { 113 | "model": self._model_state_dict(), 114 | "optimizer": self.optimizer.state_dict(), 115 | }, 116 | self.ckpt_dirpath / f"checkpoint_{self.last_epoch}.pth", 117 | ) 118 | 119 | def _model_state_dict(self): 120 | """Returns state dict of model, taking care of DataParallel case.""" 121 | if isinstance(self.model, nn.DataParallel): 122 | return self.model.module.state_dict() 123 | else: 124 | return self.model.state_dict() 125 | 126 | 127 | def load_checkpoint(checkpoint_pthpath): 128 | """Given a path to saved checkpoint, load corresponding state dicts 129 | of model and optimizer from it. This method checks if the current 130 | commit SHA of codebase matches the commit SHA recorded when this 131 | checkpoint was saved by checkpoint manager. 
132 | 133 | Parameters 134 | ---------- 135 | checkpoint_pthpath: str or pathlib.Path 136 | Path to saved checkpoint (as created by ``CheckpointManager``). 137 | 138 | Returns 139 | ------- 140 | nn.Module, optim.Optimizer 141 | Model and optimizer state dicts loaded from checkpoint. 142 | 143 | Raises 144 | ------ 145 | UserWarning 146 | If commit SHA do not match, or if the directory doesn't have 147 | the recorded commit SHA. 148 | """ 149 | 150 | if isinstance(checkpoint_pthpath, str): 151 | checkpoint_pthpath = Path(checkpoint_pthpath) 152 | checkpoint_dirpath = checkpoint_pthpath.resolve().parent 153 | checkpoint_commit_sha = list(checkpoint_dirpath.glob(".commit-*")) 154 | 155 | if len(checkpoint_commit_sha) == 0: 156 | warnings.warn( 157 | "Commit SHA was not recorded while saving checkpoints." 158 | ) 159 | else: 160 | # verify commit sha, raise warning if it doesn't match 161 | commit_sha_subprocess = Popen( 162 | ["git", "rev-parse", "--short", "HEAD"], stdout=PIPE, stderr=PIPE 163 | ) 164 | commit_sha, _ = commit_sha_subprocess.communicate() 165 | commit_sha = commit_sha.decode("utf-8").strip().replace("\n", "") 166 | 167 | # remove ".commit-" 168 | checkpoint_commit_sha = checkpoint_commit_sha[0].name[8:] 169 | 170 | if commit_sha != checkpoint_commit_sha: 171 | warnings.warn( 172 | f"Current commit ({commit_sha}) and the commit " 173 | f"({checkpoint_commit_sha}) at which checkpoint was saved," 174 | " are different. This might affect reproducibility." 175 | ) 176 | 177 | # load encoder, decoder, optimizer state_dicts 178 | components = torch.load(checkpoint_pthpath) 179 | return components["model"], components["optimizer"] 180 | -------------------------------------------------------------------------------- /visdialch/utils/dynamic_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 4 | 5 | 6 | class DynamicRNN(nn.Module): 7 | def __init__(self, rnn_model): 8 | super().__init__() 9 | self.rnn_model = rnn_model 10 | 11 | def forward(self, seq_input, seq_lens, initial_state=None): 12 | """A wrapper over pytorch's rnn to handle sequences of variable length. 13 | 14 | Arguments 15 | --------- 16 | seq_input : torch.Tensor 17 | Input sequence tensor (padded) for RNN model. 18 | Shape: (batch_size, max_sequence_length, embed_size) 19 | seq_lens : torch.LongTensor 20 | Length of sequences (b, ) 21 | initial_state : torch.Tensor 22 | Initial (hidden, cell) states of RNN model. 23 | 24 | Returns 25 | ------- 26 | Single tensor of shape (batch_size, rnn_hidden_size) corresponding 27 | to the outputs of the RNN model at the last time step of each input 28 | sequence. 
29 | """ 30 | max_sequence_length = seq_input.size(1) 31 | sorted_len, fwd_order, bwd_order = self._get_sorted_order(seq_lens) 32 | sorted_seq_input = seq_input.index_select(0, fwd_order) 33 | packed_seq_input = pack_padded_sequence( 34 | sorted_seq_input, lengths=sorted_len, batch_first=True 35 | ) 36 | 37 | if initial_state is not None: 38 | hx = initial_state 39 | assert hx[0].size(0) == self.rnn_model.num_layers 40 | else: 41 | hx = None 42 | 43 | self.rnn_model.flatten_parameters() 44 | outputs, (h_n, c_n) = self.rnn_model(packed_seq_input, hx) 45 | 46 | # pick hidden and cell states of last layer 47 | h_n = h_n[-1].index_select(dim=0, index=bwd_order) 48 | c_n = c_n[-1].index_select(dim=0, index=bwd_order) 49 | 50 | outputs = pad_packed_sequence( 51 | outputs, batch_first=True, total_length=max_sequence_length 52 | ) 53 | return outputs, (h_n, c_n) 54 | 55 | @staticmethod 56 | def _get_sorted_order(lens): 57 | sorted_len, fwd_order = torch.sort( 58 | lens.contiguous().view(-1), 0, descending=True 59 | ) 60 | _, bwd_order = torch.sort(fwd_order) 61 | sorted_len = list(sorted_len) 62 | return sorted_len, fwd_order, bwd_order 63 | --------------------------------------------------------------------------------