├── .gitignore ├── LICENSE ├── README.md ├── config ├── bert_base_6layer_6conect.json └── language_weights.json ├── dataloader ├── __init__.py ├── dataloader_dense_annotations.py └── dataloader_visdial.py ├── dense_annotation_finetuning.py ├── env.yml ├── evaluate.py ├── images └── teaser.png ├── models ├── __init__.py ├── language_only_dialog.py ├── language_only_dialog_encoder.py ├── vilbert_dialog.py └── visual_dialog_encoder.py ├── options.py ├── preprocessing └── pre_process_visdial.py ├── scripts ├── download_checkpoints.sh └── download_preprocessed.sh ├── train.py ├── train_language_only_baseline.py └── utils ├── __init__.py ├── data_utils.py ├── image_features_reader.py ├── optim_utils.py ├── visdial_metrics.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # Datasets, pretrained models, checkpoints and preprocessed files 29 | data/ 30 | !visdialch/data/ 31 | checkpoints/ 32 | logs/ 33 | 34 | # IPython Notebook 35 | .ipynb_checkpoints 36 | 37 | # virtualenv 38 | venv/ 39 | ENV/ 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Vishvak Murahari 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, 10 | this list of conditions and the following disclaimer. 11 | * Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | * Neither the name of nor the names of its contributors may be used to 15 | endorse or promote products derived from this software without specific 16 | prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 
29 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## VisDial-BERT ##
2 | 
3 | PyTorch implementation for the paper:
4 | 
5 | **[Large-scale Pretraining for Visual Dialog: A Simple State-of-the-Art Baseline](https://arxiv.org/abs/1912.02379)**
6 | Vishvak Murahari, Dhruv Batra, Devi Parikh, Abhishek Das
7 | 
8 | Prior work in visual dialog has focused on training deep neural models on the VisDial dataset in isolation, which has led to great progress, but is limiting and wasteful. In this work, following recent trends in representation learning for language, we introduce an approach to leverage pretraining on related large-scale vision-language datasets before transferring to visual dialog. Specifically, we adapt the recently proposed [ViLBERT][vilbert] model for multi-turn visually-grounded conversation sequences. Our model is pretrained on the Conceptual Captions and Visual Question Answering datasets, and finetuned on VisDial with a VisDial-specific input representation and the masked language modeling and next sentence prediction objectives (as in BERT). Our best single model achieves state-of-the-art on Visual Dialog, outperforming prior published work (including model ensembles) by more than 1% absolute on NDCG and MRR.
9 | 
10 | ![models](images/teaser.png)
11 | 
12 | This repository contains code for reproducing results with and without finetuning on dense annotations. All results are on [v1.0 of the Visual Dialog dataset][visdial-data]. We provide pretrained model weights and associated configs to run inference or train these models from scratch.
13 | 
14 | If you find this work useful in your research, please cite:
15 | 
16 | ```
17 | @article{visdial_bert,
18 |   title={Large-scale Pretraining for Visual Dialog: A Simple State-of-the-Art Baseline},
19 |   author={Vishvak Murahari and Dhruv Batra and Devi Parikh and Abhishek Das},
20 |   journal={arXiv preprint arXiv:1912.02379},
21 |   year={2019},
22 | }
23 | ```
24 | 
25 | 
26 | ### Table of Contents
27 | 
28 | * [Setup and Dependencies](#setup-and-dependencies)
29 | * [Usage](#usage)
30 | * [Download preprocessed data](#download-preprocessed-data)
31 | * [Pre-trained checkpoints](#pre-trained-checkpoints)
32 | * [Training](#training)
33 | * [Logging](#logging)
34 | * [Evaluation](#evaluation)
35 | * [Visualizing Results](#visualizing-results)
36 | * [Acknowledgements](#acknowledgements)
37 | * [License](#license)
38 | 
39 | ### Setup and Dependencies
40 | 
41 | Our code is implemented in PyTorch (v1.0). To set up, do the following:
42 | 
43 | 1. Install [Python 3.6](https://www.python.org/downloads/release/python-365/)
44 | 2. Get the source:
45 | ```
46 | git clone https://github.com/vmurahari3/visdial-bert.git visdial-bert
47 | ```
48 | 3.
Create the `visdial-bert` conda environment with all required packages, using [Anaconda](https://anaconda.org/anaconda/python):
49 | ```
50 | conda env create -f env.yml
51 | ```
52 | 
53 | ### Usage
54 | 
55 | Make both scripts in `scripts/` executable:
56 | 
57 | ```
58 | chmod +x scripts/download_preprocessed.sh
59 | chmod +x scripts/download_checkpoints.sh
60 | ```
61 | 
62 | #### Download preprocessed data
63 | 
64 | Download the preprocessed dataset and extracted image features:
65 | 
66 | ```
67 | sh scripts/download_preprocessed.sh
68 | ```
69 | 
70 | To generate these files from scratch, run:
71 | ```
72 | python preprocessing/pre_process_visdial.py
73 | ```
74 | 
75 | However, we recommend downloading these files directly.
76 | 
77 | #### Pre-trained checkpoints
78 | 
79 | Download pre-trained checkpoints:
80 | 
81 | ```
82 | sh scripts/download_checkpoints.sh
83 | ```
84 | 
85 | #### Training
86 | 
87 | After running the above scripts, all the pre-processed data is downloaded to `data/visdial` and the major pre-trained model checkpoints used in the paper are downloaded to `checkpoints-release/`.
88 | 
89 | Here we list the training commands for the important model variants in the paper.
90 | 
91 | To train the base model (no finetuning on dense annotations):
92 | 
93 | ```
94 | python train.py -batch_size 80 -batch_multiply 1 -lr 2e-5 -image_lr 2e-5 -mask_prob 0.1 -sequences_per_image 2 -start_path checkpoints-release/vqa_pretrained_weights
95 | ```
96 | 
97 | To finetune the base model with dense annotations:
98 | 
99 | ```
100 | python dense_annotation_finetuning.py -batch_size 80 -batch_multiply 10 -lr 1e-4 -image_lr 1e-4 -nsp_loss_coeff 0 -mask_prob 0.1 -sequences_per_image 2 -start_path checkpoints-release/basemodel
101 | ```
102 | 
103 | To finetune the base model with dense annotations and the next sentence prediction (NSP) loss:
104 | 
105 | ```
106 | python dense_annotation_finetuning.py -batch_size 80 -batch_multiply 10 -lr 1e-4 -image_lr 1e-4 -nsp_loss_coeff 1 -mask_prob 0.1 -sequences_per_image 2 -start_path checkpoints-release/basemodel
107 | ```
108 | 
109 | NOTE: Dense annotation finetuning is currently only supported for 8-GPU training, primarily due to memory constraints. To calculate the cross-entropy loss over the 100 answer options at a dialog round, all 100 dialog sequences need to be in memory at once. Since only 80 sequences fit on 8 GPUs with ~12 GB of memory each, we subsample 80 of the 100 options (always keeping the ground-truth option). Performance gets worse with fewer GPUs, as the number of answer options has to be cut down further.
110 | 
111 | #### Evaluation
112 | The command below generates a prediction file which can be submitted to the [test server](https://evalai.cloudcv.org/web/challenges/challenge-page/161/leaderboard) to get results on the test split.
113 | 
114 | ```
115 | python evaluate.py -n_gpus 8 -start_path <path to checkpoint> -save_name <save name>
116 | ```
117 | 
118 | The metrics for the pretrained checkpoints should match the numbers reported in the paper; we also list them below. These results are on v1.0 test-std.
119 | 
120 | | Checkpoint | Mean Rank | MRR | R1 | R5 | R10 | NDCG |
121 | |:--------------------------------------:|:----------:|:-----:|:-----:|:-----:|:-----:|:-----:|
122 | | basemodel | 3.32 | 67.50 | 53.85 | 84.68 | 93.25 | 63.87 |
123 | | basemodel + dense | 6.28 | 50.74 | 37.95 | 64.13 | 80.00 | 74.47 |
124 | | basemodel + dense + nsp | 4.28 | 63.92 | 50.78 | 79.53 | 89.60 | 68.08 |
125 | 
126 | 
127 | #### Logging
128 | 
129 | We use [Visdom](https://github.com/facebookresearch/visdom) for all logging.
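If a Visdom server is not already running, you can typically start a local instance before training with the command below (port 8097 is Visdom's default and is only an example; use whichever host and port you configure for training):

```
python -m visdom.server -port 8097
```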
Specify `visdom_server`, `visdom_port` and `enable_visdom` arguments in options.py to use this feature. 130 | 131 | #### Visualizing Results 132 | 133 | Coming soon 134 | 135 | ### Acknowledgements 136 | 137 | Builds on Jiasen Lu's ViLBERT [implementation](https://github.com/jiasenlu/vilbert_beta). 138 | 139 | ### License 140 | 141 | BSD 142 | 143 | [vilbert]: https://arxiv.org/abs/1908.02265 144 | [visdial-data]: https://visualdialog.org/data 145 | -------------------------------------------------------------------------------- /config/bert_base_6layer_6conect.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522, 13 | "v_feature_size": 2048, 14 | "v_target_size": 1601, 15 | "v_hidden_size": 1024, 16 | "v_num_hidden_layers":6, 17 | "v_num_attention_heads":8, 18 | "v_intermediate_size":1024, 19 | "bi_hidden_size":1024, 20 | "bi_num_attention_heads":8, 21 | "bi_intermediate_size": 1024, 22 | "bi_attention_type":1, 23 | "v_attention_probs_dropout_prob":0.1, 24 | "v_hidden_act":"gelu", 25 | "v_hidden_dropout_prob":0.1, 26 | "v_initializer_range":0.02, 27 | "v_biattention_id":[0, 1, 2, 3, 4, 5], 28 | "t_biattention_id":[6, 7, 8, 9, 10, 11], 29 | "pooling_method": "mul" 30 | } 31 | -------------------------------------------------------------------------------- /config/language_weights.json: -------------------------------------------------------------------------------- 1 | ["bert_pretrained.bert.embeddings.word_embeddings.weight", "bert_pretrained.bert.embeddings.position_embeddings.weight", "bert_pretrained.bert.embeddings.token_type_embeddings.weight", "bert_pretrained.bert.embeddings.LayerNorm.weight", "bert_pretrained.bert.embeddings.LayerNorm.bias", "bert_pretrained.bert.embeddings.token_type_embeddings_extension.weight", "bert_pretrained.bert.embeddings.sep_embeddings.weight", "bert_pretrained.bert.encoder.layer.0.attention.self.query.weight", "bert_pretrained.bert.encoder.layer.0.attention.self.query.bias", "bert_pretrained.bert.encoder.layer.0.attention.self.key.weight", "bert_pretrained.bert.encoder.layer.0.attention.self.key.bias", "bert_pretrained.bert.encoder.layer.0.attention.self.value.weight", "bert_pretrained.bert.encoder.layer.0.attention.self.value.bias", "bert_pretrained.bert.encoder.layer.0.attention.output.dense.weight", "bert_pretrained.bert.encoder.layer.0.attention.output.dense.bias", "bert_pretrained.bert.encoder.layer.0.attention.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.0.attention.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.0.intermediate.dense.weight", "bert_pretrained.bert.encoder.layer.0.intermediate.dense.bias", "bert_pretrained.bert.encoder.layer.0.output.dense.weight", "bert_pretrained.bert.encoder.layer.0.output.dense.bias", "bert_pretrained.bert.encoder.layer.0.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.0.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.1.attention.self.query.weight", "bert_pretrained.bert.encoder.layer.1.attention.self.query.bias", "bert_pretrained.bert.encoder.layer.1.attention.self.key.weight", "bert_pretrained.bert.encoder.layer.1.attention.self.key.bias", 
"bert_pretrained.bert.encoder.layer.1.attention.self.value.weight", "bert_pretrained.bert.encoder.layer.1.attention.self.value.bias", "bert_pretrained.bert.encoder.layer.1.attention.output.dense.weight", "bert_pretrained.bert.encoder.layer.1.attention.output.dense.bias", "bert_pretrained.bert.encoder.layer.1.attention.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.1.attention.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.1.intermediate.dense.weight", "bert_pretrained.bert.encoder.layer.1.intermediate.dense.bias", "bert_pretrained.bert.encoder.layer.1.output.dense.weight", "bert_pretrained.bert.encoder.layer.1.output.dense.bias", "bert_pretrained.bert.encoder.layer.1.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.1.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.2.attention.self.query.weight", "bert_pretrained.bert.encoder.layer.2.attention.self.query.bias", "bert_pretrained.bert.encoder.layer.2.attention.self.key.weight", "bert_pretrained.bert.encoder.layer.2.attention.self.key.bias", "bert_pretrained.bert.encoder.layer.2.attention.self.value.weight", "bert_pretrained.bert.encoder.layer.2.attention.self.value.bias", "bert_pretrained.bert.encoder.layer.2.attention.output.dense.weight", "bert_pretrained.bert.encoder.layer.2.attention.output.dense.bias", "bert_pretrained.bert.encoder.layer.2.attention.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.2.attention.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.2.intermediate.dense.weight", "bert_pretrained.bert.encoder.layer.2.intermediate.dense.bias", "bert_pretrained.bert.encoder.layer.2.output.dense.weight", "bert_pretrained.bert.encoder.layer.2.output.dense.bias", "bert_pretrained.bert.encoder.layer.2.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.2.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.3.attention.self.query.weight", "bert_pretrained.bert.encoder.layer.3.attention.self.query.bias", "bert_pretrained.bert.encoder.layer.3.attention.self.key.weight", "bert_pretrained.bert.encoder.layer.3.attention.self.key.bias", "bert_pretrained.bert.encoder.layer.3.attention.self.value.weight", "bert_pretrained.bert.encoder.layer.3.attention.self.value.bias", "bert_pretrained.bert.encoder.layer.3.attention.output.dense.weight", "bert_pretrained.bert.encoder.layer.3.attention.output.dense.bias", "bert_pretrained.bert.encoder.layer.3.attention.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.3.attention.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.3.intermediate.dense.weight", "bert_pretrained.bert.encoder.layer.3.intermediate.dense.bias", "bert_pretrained.bert.encoder.layer.3.output.dense.weight", "bert_pretrained.bert.encoder.layer.3.output.dense.bias", "bert_pretrained.bert.encoder.layer.3.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.3.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.4.attention.self.query.weight", "bert_pretrained.bert.encoder.layer.4.attention.self.query.bias", "bert_pretrained.bert.encoder.layer.4.attention.self.key.weight", "bert_pretrained.bert.encoder.layer.4.attention.self.key.bias", "bert_pretrained.bert.encoder.layer.4.attention.self.value.weight", "bert_pretrained.bert.encoder.layer.4.attention.self.value.bias", "bert_pretrained.bert.encoder.layer.4.attention.output.dense.weight", "bert_pretrained.bert.encoder.layer.4.attention.output.dense.bias", "bert_pretrained.bert.encoder.layer.4.attention.output.LayerNorm.weight", 
"bert_pretrained.bert.encoder.layer.4.attention.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.4.intermediate.dense.weight", "bert_pretrained.bert.encoder.layer.4.intermediate.dense.bias", "bert_pretrained.bert.encoder.layer.4.output.dense.weight", "bert_pretrained.bert.encoder.layer.4.output.dense.bias", "bert_pretrained.bert.encoder.layer.4.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.4.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.5.attention.self.query.weight", "bert_pretrained.bert.encoder.layer.5.attention.self.query.bias", "bert_pretrained.bert.encoder.layer.5.attention.self.key.weight", "bert_pretrained.bert.encoder.layer.5.attention.self.key.bias", "bert_pretrained.bert.encoder.layer.5.attention.self.value.weight", "bert_pretrained.bert.encoder.layer.5.attention.self.value.bias", "bert_pretrained.bert.encoder.layer.5.attention.output.dense.weight", "bert_pretrained.bert.encoder.layer.5.attention.output.dense.bias", "bert_pretrained.bert.encoder.layer.5.attention.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.5.attention.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.5.intermediate.dense.weight", "bert_pretrained.bert.encoder.layer.5.intermediate.dense.bias", "bert_pretrained.bert.encoder.layer.5.output.dense.weight", "bert_pretrained.bert.encoder.layer.5.output.dense.bias", "bert_pretrained.bert.encoder.layer.5.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.5.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.6.attention.self.query.weight", "bert_pretrained.bert.encoder.layer.6.attention.self.query.bias", "bert_pretrained.bert.encoder.layer.6.attention.self.key.weight", "bert_pretrained.bert.encoder.layer.6.attention.self.key.bias", "bert_pretrained.bert.encoder.layer.6.attention.self.value.weight", "bert_pretrained.bert.encoder.layer.6.attention.self.value.bias", "bert_pretrained.bert.encoder.layer.6.attention.output.dense.weight", "bert_pretrained.bert.encoder.layer.6.attention.output.dense.bias", "bert_pretrained.bert.encoder.layer.6.attention.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.6.attention.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.6.intermediate.dense.weight", "bert_pretrained.bert.encoder.layer.6.intermediate.dense.bias", "bert_pretrained.bert.encoder.layer.6.output.dense.weight", "bert_pretrained.bert.encoder.layer.6.output.dense.bias", "bert_pretrained.bert.encoder.layer.6.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.6.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.7.attention.self.query.weight", "bert_pretrained.bert.encoder.layer.7.attention.self.query.bias", "bert_pretrained.bert.encoder.layer.7.attention.self.key.weight", "bert_pretrained.bert.encoder.layer.7.attention.self.key.bias", "bert_pretrained.bert.encoder.layer.7.attention.self.value.weight", "bert_pretrained.bert.encoder.layer.7.attention.self.value.bias", "bert_pretrained.bert.encoder.layer.7.attention.output.dense.weight", "bert_pretrained.bert.encoder.layer.7.attention.output.dense.bias", "bert_pretrained.bert.encoder.layer.7.attention.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.7.attention.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.7.intermediate.dense.weight", "bert_pretrained.bert.encoder.layer.7.intermediate.dense.bias", "bert_pretrained.bert.encoder.layer.7.output.dense.weight", "bert_pretrained.bert.encoder.layer.7.output.dense.bias", "bert_pretrained.bert.encoder.layer.7.output.LayerNorm.weight", 
"bert_pretrained.bert.encoder.layer.7.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.8.attention.self.query.weight", "bert_pretrained.bert.encoder.layer.8.attention.self.query.bias", "bert_pretrained.bert.encoder.layer.8.attention.self.key.weight", "bert_pretrained.bert.encoder.layer.8.attention.self.key.bias", "bert_pretrained.bert.encoder.layer.8.attention.self.value.weight", "bert_pretrained.bert.encoder.layer.8.attention.self.value.bias", "bert_pretrained.bert.encoder.layer.8.attention.output.dense.weight", "bert_pretrained.bert.encoder.layer.8.attention.output.dense.bias", "bert_pretrained.bert.encoder.layer.8.attention.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.8.attention.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.8.intermediate.dense.weight", "bert_pretrained.bert.encoder.layer.8.intermediate.dense.bias", "bert_pretrained.bert.encoder.layer.8.output.dense.weight", "bert_pretrained.bert.encoder.layer.8.output.dense.bias", "bert_pretrained.bert.encoder.layer.8.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.8.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.9.attention.self.query.weight", "bert_pretrained.bert.encoder.layer.9.attention.self.query.bias", "bert_pretrained.bert.encoder.layer.9.attention.self.key.weight", "bert_pretrained.bert.encoder.layer.9.attention.self.key.bias", "bert_pretrained.bert.encoder.layer.9.attention.self.value.weight", "bert_pretrained.bert.encoder.layer.9.attention.self.value.bias", "bert_pretrained.bert.encoder.layer.9.attention.output.dense.weight", "bert_pretrained.bert.encoder.layer.9.attention.output.dense.bias", "bert_pretrained.bert.encoder.layer.9.attention.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.9.attention.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.9.intermediate.dense.weight", "bert_pretrained.bert.encoder.layer.9.intermediate.dense.bias", "bert_pretrained.bert.encoder.layer.9.output.dense.weight", "bert_pretrained.bert.encoder.layer.9.output.dense.bias", "bert_pretrained.bert.encoder.layer.9.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.9.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.10.attention.self.query.weight", "bert_pretrained.bert.encoder.layer.10.attention.self.query.bias", "bert_pretrained.bert.encoder.layer.10.attention.self.key.weight", "bert_pretrained.bert.encoder.layer.10.attention.self.key.bias", "bert_pretrained.bert.encoder.layer.10.attention.self.value.weight", "bert_pretrained.bert.encoder.layer.10.attention.self.value.bias", "bert_pretrained.bert.encoder.layer.10.attention.output.dense.weight", "bert_pretrained.bert.encoder.layer.10.attention.output.dense.bias", "bert_pretrained.bert.encoder.layer.10.attention.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.10.attention.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.10.intermediate.dense.weight", "bert_pretrained.bert.encoder.layer.10.intermediate.dense.bias", "bert_pretrained.bert.encoder.layer.10.output.dense.weight", "bert_pretrained.bert.encoder.layer.10.output.dense.bias", "bert_pretrained.bert.encoder.layer.10.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.10.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.11.attention.self.query.weight", "bert_pretrained.bert.encoder.layer.11.attention.self.query.bias", "bert_pretrained.bert.encoder.layer.11.attention.self.key.weight", "bert_pretrained.bert.encoder.layer.11.attention.self.key.bias", 
"bert_pretrained.bert.encoder.layer.11.attention.self.value.weight", "bert_pretrained.bert.encoder.layer.11.attention.self.value.bias", "bert_pretrained.bert.encoder.layer.11.attention.output.dense.weight", "bert_pretrained.bert.encoder.layer.11.attention.output.dense.bias", "bert_pretrained.bert.encoder.layer.11.attention.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.11.attention.output.LayerNorm.bias", "bert_pretrained.bert.encoder.layer.11.intermediate.dense.weight", "bert_pretrained.bert.encoder.layer.11.intermediate.dense.bias", "bert_pretrained.bert.encoder.layer.11.output.dense.weight", "bert_pretrained.bert.encoder.layer.11.output.dense.bias", "bert_pretrained.bert.encoder.layer.11.output.LayerNorm.weight", "bert_pretrained.bert.encoder.layer.11.output.LayerNorm.bias", "bert_pretrained.bert.pooler.dense.weight", "bert_pretrained.bert.pooler.dense.bias", "bert_pretrained.cls.predictions.bias", "bert_pretrained.cls.predictions.transform.dense.weight", "bert_pretrained.cls.predictions.transform.dense.bias", "bert_pretrained.cls.predictions.transform.LayerNorm.weight", "bert_pretrained.cls.predictions.transform.LayerNorm.bias", "bert_pretrained.cls.seq_relationship.weight", "bert_pretrained.cls.seq_relationship.bias", "inconsistency_head.bias", "inconsistency_head.transform.dense.weight", "inconsistency_head.transform.dense.bias", "inconsistency_head.transform.LayerNorm.weight", "inconsistency_head.transform.LayerNorm.bias", "inconsistency_head.decoder.weight"] -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vmurahari3/visdial-bert/87e264794c45cc5c8c1ea243ad9d2b4d94a44faf/dataloader/__init__.py -------------------------------------------------------------------------------- /dataloader/dataloader_dense_annotations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import json 4 | from pytorch_transformers.tokenization_bert import BertTokenizer 5 | import numpy as np 6 | import random 7 | from utils.data_utils import list2tensorpad, encode_input, encode_image_input 8 | from utils.image_features_reader import ImageFeaturesH5Reader 9 | 10 | class VisdialDatasetDense(data.Dataset): 11 | 12 | def __init__(self, params): 13 | 'Initialization' 14 | self.numDataPoints = {} 15 | num_samples_train = params['num_train_samples'] 16 | num_samples_val = params['num_val_samples'] 17 | self._image_features_reader = ImageFeaturesH5Reader(params['visdial_image_feats']) 18 | with open(params['visdial_processed_train_dense']) as f: 19 | self.visdial_data_train = json.load(f) 20 | if params['overfit']: 21 | if num_samples_train: 22 | self.numDataPoints['train'] = num_samples_train 23 | else: 24 | self.numDataPoints['train'] = 5 25 | else: 26 | if num_samples_train: 27 | self.numDataPoints['train'] = num_samples_train 28 | else: 29 | self.numDataPoints['train'] = len(self.visdial_data_train['data']['dialogs']) 30 | 31 | with open(params['visdial_processed_val']) as f: 32 | self.visdial_data_val = json.load(f) 33 | if params['overfit']: 34 | if num_samples_val: 35 | self.numDataPoints['val'] = num_samples_val 36 | else: 37 | self.numDataPoints['val'] = 5 38 | else: 39 | if num_samples_val: 40 | self.numDataPoints['val'] = num_samples_val 41 | else: 42 | self.numDataPoints['val'] = len(self.visdial_data_val['data']['dialogs']) 43 | 44 | 
self.overfit = params['overfit'] 45 | 46 | with open(params['visdial_processed_train_dense_annotations']) as f: 47 | self.visdial_data_train_ndcg = json.load(f) 48 | with open(params['visdial_processed_val_dense_annotations']) as f: 49 | self.visdial_data_val_ndcg = json.load(f) 50 | 51 | #train val setup 52 | self.numDataPoints['trainval'] = self.numDataPoints['train'] + self.numDataPoints['val'] 53 | 54 | self.num_options = params["num_options"] 55 | self._split = 'train' 56 | self.subsets = ['train', 'val', 'trainval'] 57 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 58 | self.tokenizer = tokenizer 59 | # fetching token indicecs of [CLS] and [SEP] 60 | tokens = ['[CLS]','[MASK]','[SEP]'] 61 | indexed_tokens = tokenizer.convert_tokens_to_ids(tokens) 62 | self.CLS = indexed_tokens[0] 63 | self.MASK = indexed_tokens[1] 64 | self.SEP = indexed_tokens[2] 65 | self.params = params 66 | self._max_region_num = 37 67 | 68 | def __len__(self): 69 | return self.numDataPoints[self._split] 70 | 71 | @property 72 | def split(self): 73 | return self._split 74 | 75 | @split.setter 76 | def split(self, split): 77 | assert split in self.subsets 78 | self._split = split 79 | 80 | def __getitem__(self, index): 81 | 82 | def pruneRounds(context, num_rounds): 83 | start_segment = 1 84 | len_context = len(context) 85 | cur_rounds = (len(context) // 2) + 1 86 | l_index = 0 87 | if cur_rounds > num_rounds: 88 | # caption is not part of the final input 89 | l_index = len_context - (2 * num_rounds) 90 | start_segment = 0 91 | return context[l_index:], start_segment 92 | 93 | # Combining all the dialog rounds with the [SEP] and [CLS] token 94 | MAX_SEQ_LEN = self.params['max_seq_len'] 95 | cur_data = None 96 | cur_dense_annotations = None 97 | if self._split == 'train': 98 | cur_data = self.visdial_data_train['data'] 99 | cur_dense_annotations = self.visdial_data_train_ndcg 100 | elif self._split == 'val': 101 | if self.overfit: 102 | cur_data = self.visdial_data_train['data'] 103 | cur_dense_annotations = self.visdial_data_train_ndcg 104 | else: 105 | cur_data = self.visdial_data_val['data'] 106 | cur_dense_annotations = self.visdial_data_val_ndcg 107 | else: 108 | if index >= self.numDataPoints['train']: 109 | cur_data = self.visdial_data_val 110 | cur_dense_annotations = self.visdial_data_val_ndcg 111 | index -= self.numDataPoints['train'] 112 | else: 113 | cur_data = self.visdial_data_train 114 | cur_dense_annotations = self.visdial_data_train_ndcg 115 | # number of options to score on 116 | num_options = self.num_options 117 | assert num_options == 100 118 | 119 | dialog = cur_data['dialogs'][index] 120 | cur_questions = cur_data['questions'] 121 | cur_answers = cur_data['answers'] 122 | img_id = dialog['image_id'] 123 | assert img_id == cur_dense_annotations[index]['image_id'] 124 | 125 | cur_rnd_utterance = [self.tokenizer.encode(dialog['caption'])] 126 | options_all = [] 127 | cur_rounds = cur_dense_annotations[index]['round_id'] 128 | for rnd,utterance in enumerate(dialog['dialog'][:cur_rounds]): 129 | cur_rnd_utterance.append(self.tokenizer.encode(cur_questions[utterance['question']])) 130 | if rnd != cur_rounds - 1: 131 | cur_rnd_utterance.append(self.tokenizer.encode(cur_answers[utterance['answer']])) 132 | for answer_option in dialog['dialog'][cur_rounds - 1]['answer_options']: 133 | cur_option = cur_rnd_utterance.copy() 134 | cur_option.append(self.tokenizer.encode(cur_answers[answer_option])) 135 | options_all.append(cur_option) 136 | assert len(cur_option) == 2 * cur_rounds + 1 
137 | 138 | gt_option = dialog['dialog'][cur_rounds - 1]['gt_index'] 139 | 140 | tokens_all = [] 141 | mask_all = [] 142 | segments_all = [] 143 | sep_indices_all = [] 144 | hist_len_all = [] 145 | 146 | for _, option in enumerate(options_all): 147 | option, start_segment = pruneRounds(option, self.params['visdial_tot_rounds']) 148 | tokens, segments, sep_indices, mask = encode_input(option, start_segment ,self.CLS, 149 | self.SEP, self.MASK ,max_seq_len=MAX_SEQ_LEN, mask_prob=0) 150 | 151 | tokens_all.append(tokens) 152 | mask_all.append(mask) 153 | segments_all.append(segments) 154 | sep_indices_all.append(sep_indices) 155 | hist_len_all.append(torch.LongTensor([len(option)-1])) 156 | 157 | tokens_all = torch.cat(tokens_all,0) 158 | mask_all = torch.cat(mask_all,0) 159 | segments_all = torch.cat(segments_all, 0) 160 | sep_indices_all = torch.cat(sep_indices_all, 0) 161 | hist_len_all = torch.cat(hist_len_all,0) 162 | 163 | item = {} 164 | item['tokens'] = tokens_all.unsqueeze(0) 165 | item['segments'] = segments_all.unsqueeze(0) 166 | item['sep_indices'] = sep_indices_all.unsqueeze(0) 167 | item['mask'] = mask_all.unsqueeze(0) 168 | item['hist_len'] = hist_len_all.unsqueeze(0) 169 | item['image_id'] = torch.LongTensor([img_id]) 170 | 171 | # add image features. Expand them to create batch * num_rounds * num options * num bbox * img feats 172 | features, num_boxes, boxes, _ , image_target = self._image_features_reader[img_id] 173 | features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, \ 174 | image_target, max_regions=self._max_region_num, mask_prob=0) 175 | 176 | item['image_feat'] = features 177 | item['image_loc'] = spatials 178 | item['image_mask'] = image_mask 179 | item['image_target'] = image_target 180 | item['image_label'] = image_label 181 | 182 | # add dense annotation fields 183 | item['gt_relevance_round_id'] = torch.LongTensor([cur_rounds]) 184 | item['gt_relevance'] = torch.Tensor(cur_dense_annotations[index]['relevance']) 185 | item['gt_option'] = torch.LongTensor([gt_option]) 186 | 187 | # add next sentence labels for training with the nsp loss as well 188 | nsp_labels = torch.ones(*tokens_all.unsqueeze(0).shape[:-1]) 189 | nsp_labels[:,gt_option] = 0 190 | item['next_sentence_labels'] = nsp_labels.long() 191 | 192 | return item -------------------------------------------------------------------------------- /dataloader/dataloader_visdial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import json 4 | from pytorch_transformers.tokenization_bert import BertTokenizer 5 | import numpy as np 6 | import random 7 | from utils.data_utils import list2tensorpad, encode_input, encode_image_input 8 | from utils.image_features_reader import ImageFeaturesH5Reader 9 | class VisdialDataset(data.Dataset): 10 | 11 | def __init__(self, params): 12 | 13 | self.numDataPoints = {} 14 | num_samples_train = params['num_train_samples'] 15 | num_samples_val = params['num_val_samples'] 16 | self._image_features_reader = ImageFeaturesH5Reader(params['visdial_image_feats']) 17 | with open(params['visdial_processed_train']) as f: 18 | self.visdial_data_train = json.load(f) 19 | if params['overfit']: 20 | if num_samples_train: 21 | self.numDataPoints['train'] = num_samples_train 22 | else: 23 | self.numDataPoints['train'] = 5 24 | else: 25 | if num_samples_train: 26 | self.numDataPoints['train'] = num_samples_train 27 | else: 28 | self.numDataPoints['train'] = 
len(self.visdial_data_train['data']['dialogs']) 29 | 30 | with open(params['visdial_processed_val']) as f: 31 | self.visdial_data_val = json.load(f) 32 | if params['overfit']: 33 | if num_samples_val: 34 | self.numDataPoints['val'] = num_samples_val 35 | else: 36 | self.numDataPoints['val'] = 5 37 | else: 38 | if num_samples_val: 39 | self.numDataPoints['val'] = num_samples_val 40 | else: 41 | self.numDataPoints['val'] = len(self.visdial_data_val['data']['dialogs']) 42 | with open(params['visdial_processed_test']) as f: 43 | self.visdial_data_test = json.load(f) 44 | self.numDataPoints['test'] = len(self.visdial_data_test['data']['dialogs']) 45 | 46 | self.overfit = params['overfit'] 47 | with open(params['visdial_processed_val_dense_annotations']) as f: 48 | self.visdial_data_val_dense = json.load(f) 49 | self.num_options = params["num_options"] 50 | self._split = 'train' 51 | self.subsets = ['train','val','test'] 52 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 53 | self.tokenizer = tokenizer 54 | # fetching token indicecs of [CLS] and [SEP] 55 | tokens = ['[CLS]','[MASK]','[SEP]'] 56 | indexed_tokens = tokenizer.convert_tokens_to_ids(tokens) 57 | self.CLS = indexed_tokens[0] 58 | self.MASK = indexed_tokens[1] 59 | self.SEP = indexed_tokens[2] 60 | self.params = params 61 | self._max_region_num = 37 62 | 63 | def __len__(self): 64 | return self.numDataPoints[self._split] 65 | 66 | @property 67 | def split(self): 68 | return self._split 69 | 70 | @split.setter 71 | def split(self, split): 72 | assert split in self.subsets 73 | self._split = split 74 | 75 | def __getitem__(self, index): 76 | 77 | def tokens2str(seq): 78 | dialog_sequence = '' 79 | for sentence in seq: 80 | for word in sentence: 81 | dialog_sequence += self.tokenizer._convert_id_to_token(word) + " " 82 | dialog_sequence += ' ' 83 | dialog_sequence = dialog_sequence.encode('utf8') 84 | return dialog_sequence 85 | 86 | def pruneRounds(context, num_rounds): 87 | start_segment = 1 88 | len_context = len(context) 89 | cur_rounds = (len(context) // 2) + 1 90 | l_index = 0 91 | if cur_rounds > num_rounds: 92 | # caption is not part of the final input 93 | l_index = len_context - (2 * num_rounds) 94 | start_segment = 0 95 | return context[l_index:], start_segment 96 | 97 | # Combining all the dialog rounds with the [SEP] and [CLS] token 98 | MAX_SEQ_LEN = self.params['max_seq_len'] 99 | cur_data = None 100 | if self._split == 'train': 101 | cur_data = self.visdial_data_train['data'] 102 | elif self._split == 'val': 103 | if self.overfit: 104 | cur_data = self.visdial_data_train['data'] 105 | else: 106 | cur_data = self.visdial_data_val['data'] 107 | else: 108 | cur_data = self.visdial_data_test['data'] 109 | 110 | # number of options to score on 111 | num_options = self.num_options 112 | assert num_options > 1 and num_options <= 100 113 | 114 | dialog = cur_data['dialogs'][index] 115 | cur_questions = cur_data['questions'] 116 | cur_answers = cur_data['answers'] 117 | img_id = dialog['image_id'] 118 | 119 | if self._split == 'train': 120 | utterances = [] 121 | utterances_random = [] 122 | tokenized_caption = self.tokenizer.encode(dialog['caption']) 123 | utterances.append([tokenized_caption]) 124 | utterances_random.append([tokenized_caption]) 125 | tot_len = len(tokenized_caption) + 2 # add a 1 for the CLS token as well as the sep tokens which follows the caption 126 | for rnd,utterance in enumerate(dialog['dialog']): 127 | cur_rnd_utterance = utterances[-1].copy() 128 | cur_rnd_utterance_random = 
utterances[-1].copy() 129 | 130 | tokenized_question = self.tokenizer.encode(cur_questions[utterance['question']]) 131 | tokenized_answer = self.tokenizer.encode(cur_answers[utterance['answer']]) 132 | cur_rnd_utterance.append(tokenized_question) 133 | cur_rnd_utterance.append(tokenized_answer) 134 | 135 | question_len = len(tokenized_question) 136 | answer_len = len(tokenized_answer) 137 | tot_len += question_len + 1 # the additional 1 is for the sep token 138 | tot_len += answer_len + 1 # the additional 1 is for the sep token 139 | 140 | cur_rnd_utterance_random.append(self.tokenizer.encode(cur_questions[utterance['question']])) 141 | # randomly select one random utterance in that round 142 | utterances.append(cur_rnd_utterance) 143 | 144 | num_inds = len(utterance['answer_options']) 145 | gt_option_ind = utterance['gt_index'] 146 | 147 | negative_samples = [] 148 | 149 | for _ in range(self.params["num_negative_samples"]): 150 | 151 | all_inds = list(range(100)) 152 | all_inds.remove(gt_option_ind) 153 | all_inds = all_inds[:(num_options-1)] 154 | tokenized_random_utterance = None 155 | option_ind = None 156 | 157 | while len(all_inds): 158 | option_ind = random.choice(all_inds) 159 | tokenized_random_utterance = self.tokenizer.encode(cur_answers[utterance['answer_options'][option_ind]]) 160 | # the 1 here is for the sep token at the end of each utterance 161 | if(MAX_SEQ_LEN >= (tot_len + len(tokenized_random_utterance) + 1)): 162 | break 163 | else: 164 | all_inds.remove(option_ind) 165 | if len(all_inds) == 0: 166 | # all the options exceed the max len. Truncate the last utterance in this case. 167 | tokenized_random_utterance = tokenized_random_utterance[:answer_len] 168 | t = cur_rnd_utterance_random.copy() 169 | t.append(tokenized_random_utterance) 170 | negative_samples.append(t) 171 | 172 | utterances_random.append(negative_samples) 173 | # removing the caption in the beginning 174 | utterances = utterances[1:] 175 | utterances_random = utterances_random[1:] 176 | assert len(utterances) == len(utterances_random) == 10 177 | 178 | tokens_all_rnd = [] 179 | mask_all_rnd = [] 180 | segments_all_rnd = [] 181 | sep_indices_all_rnd = [] 182 | next_labels_all_rnd = [] 183 | hist_len_all_rnd = [] 184 | 185 | for j,context in enumerate(utterances): 186 | tokens_all = [] 187 | mask_all = [] 188 | segments_all = [] 189 | sep_indices_all = [] 190 | next_labels_all = [] 191 | hist_len_all = [] 192 | 193 | context, start_segment = pruneRounds(context, self.params['visdial_tot_rounds']) 194 | # print("{}: {}".format(j, tokens2str(context))) 195 | tokens, segments, sep_indices, mask = encode_input(context, start_segment, self.CLS, 196 | self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.params["mask_prob"]) 197 | tokens_all.append(tokens) 198 | mask_all.append(mask) 199 | sep_indices_all.append(sep_indices) 200 | next_labels_all.append(torch.LongTensor([0])) 201 | segments_all.append(segments) 202 | hist_len_all.append(torch.LongTensor([len(context)-1])) 203 | negative_samples = utterances_random[j] 204 | 205 | for context_random in negative_samples: 206 | context_random, start_segment = pruneRounds(context_random, self.params['visdial_tot_rounds']) 207 | # print("{}: {}".format(j, tokens2str(context_random))) 208 | tokens_random, segments_random, sep_indices_random, mask_random = encode_input(context_random, start_segment, self.CLS, 209 | self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.params["mask_prob"]) 210 | tokens_all.append(tokens_random) 211 | 
mask_all.append(mask_random) 212 | sep_indices_all.append(sep_indices_random) 213 | next_labels_all.append(torch.LongTensor([1])) 214 | segments_all.append(segments_random) 215 | hist_len_all.append(torch.LongTensor([len(context_random)-1])) 216 | 217 | tokens_all_rnd.append(torch.cat(tokens_all,0).unsqueeze(0)) 218 | mask_all_rnd.append(torch.cat(mask_all,0).unsqueeze(0)) 219 | segments_all_rnd.append(torch.cat(segments_all, 0).unsqueeze(0)) 220 | sep_indices_all_rnd.append(torch.cat(sep_indices_all, 0).unsqueeze(0)) 221 | next_labels_all_rnd.append(torch.cat(next_labels_all, 0).unsqueeze(0)) 222 | hist_len_all_rnd.append(torch.cat(hist_len_all,0).unsqueeze(0)) 223 | 224 | tokens_all_rnd = torch.cat(tokens_all_rnd,0) 225 | mask_all_rnd = torch.cat(mask_all_rnd,0) 226 | segments_all_rnd = torch.cat(segments_all_rnd, 0) 227 | sep_indices_all_rnd = torch.cat(sep_indices_all_rnd, 0) 228 | next_labels_all_rnd = torch.cat(next_labels_all_rnd, 0) 229 | hist_len_all_rnd = torch.cat(hist_len_all_rnd,0) 230 | 231 | item = {} 232 | 233 | item['tokens'] = tokens_all_rnd 234 | item['segments'] = segments_all_rnd 235 | item['sep_indices'] = sep_indices_all_rnd 236 | item['mask'] = mask_all_rnd 237 | item['next_sentence_labels'] = next_labels_all_rnd 238 | item['hist_len'] = hist_len_all_rnd 239 | 240 | # get image features 241 | features, num_boxes, boxes, _ , image_target = self._image_features_reader[img_id] 242 | features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num) 243 | item['image_feat'] = features 244 | item['image_loc'] = spatials 245 | item['image_mask'] = image_mask 246 | item['image_target'] = image_target 247 | item['image_label'] = image_label 248 | return item 249 | 250 | elif self.split == 'val': 251 | # append all the 100 options and return all the 100 options concatenated with history 252 | # that will lead to 1000 forward passes for a single image 253 | gt_relevance = None 254 | utterances = [] 255 | gt_option_inds = [] 256 | utterances.append([self.tokenizer.encode(dialog['caption'])]) 257 | options_all = [] 258 | for rnd,utterance in enumerate(dialog['dialog']): 259 | cur_rnd_utterance = utterances[-1].copy() 260 | cur_rnd_utterance.append(self.tokenizer.encode(cur_questions[utterance['question']])) 261 | # current round 262 | gt_option_ind = utterance['gt_index'] 263 | option_inds = [] 264 | option_inds.append(gt_option_ind) 265 | all_inds = list(range(100)) 266 | all_inds.remove(gt_option_ind) 267 | all_inds = all_inds[:(num_options-1)] 268 | option_inds.extend(all_inds) 269 | gt_option_inds.append(0) 270 | cur_rnd_options = [] 271 | answer_options = [utterance['answer_options'][k] for k in option_inds] 272 | assert len(answer_options) == len(option_inds) == num_options 273 | assert answer_options[0] == utterance['answer'] 274 | 275 | if rnd == self.visdial_data_val_dense[index]['round_id'] - 1: 276 | gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['gt_relevance']) 277 | # shuffle based on new indices 278 | gt_relevance = gt_relevance[torch.LongTensor(option_inds)] 279 | for answer_option in answer_options: 280 | cur_rnd_cur_option = cur_rnd_utterance.copy() 281 | cur_rnd_cur_option.append(self.tokenizer.encode(cur_answers[answer_option])) 282 | cur_rnd_options.append(cur_rnd_cur_option) 283 | cur_rnd_utterance.append(self.tokenizer.encode(cur_answers[utterance['answer']])) 284 | utterances.append(cur_rnd_utterance) 285 | options_all.append(cur_rnd_options) 286 | # 
encode the input and create batch x 10 x 100 * max_len arrays (batch x num_rounds x num_options) 287 | tokens_all = [] 288 | mask_all = [] 289 | segments_all = [] 290 | sep_indices_all = [] 291 | hist_len_all = [] 292 | 293 | for rnd,cur_rnd_options in enumerate(options_all): 294 | 295 | tokens_all_rnd = [] 296 | mask_all_rnd = [] 297 | segments_all_rnd = [] 298 | sep_indices_all_rnd = [] 299 | hist_len_all_rnd = [] 300 | 301 | for j,cur_rnd_option in enumerate(cur_rnd_options): 302 | cur_rnd_option, start_segment = pruneRounds(cur_rnd_option, self.params['visdial_tot_rounds']) 303 | tokens, segments, sep_indices, mask = encode_input(cur_rnd_option, start_segment,self.CLS, 304 | self.SEP, self.MASK ,max_seq_len=MAX_SEQ_LEN, mask_prob=0) 305 | 306 | tokens_all_rnd.append(tokens) 307 | mask_all_rnd.append(mask) 308 | segments_all_rnd.append(segments) 309 | sep_indices_all_rnd.append(sep_indices) 310 | hist_len_all_rnd.append(torch.LongTensor([len(cur_rnd_option)-1])) 311 | 312 | tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0)) 313 | mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0)) 314 | segments_all.append(torch.cat(segments_all_rnd,0).unsqueeze(0)) 315 | sep_indices_all.append(torch.cat(sep_indices_all_rnd,0).unsqueeze(0)) 316 | hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0)) 317 | 318 | tokens_all = torch.cat(tokens_all,0) 319 | mask_all = torch.cat(mask_all,0) 320 | segments_all = torch.cat(segments_all, 0) 321 | sep_indices_all = torch.cat(sep_indices_all, 0) 322 | hist_len_all = torch.cat(hist_len_all,0) 323 | 324 | item = {} 325 | item['tokens'] = tokens_all 326 | item['segments'] = segments_all 327 | item['sep_indices'] = sep_indices_all 328 | item['mask'] = mask_all 329 | item['hist_len'] = hist_len_all 330 | 331 | item['gt_option_inds'] = torch.LongTensor(gt_option_inds) 332 | 333 | # return dense annotation data as well 334 | item['round_id'] = torch.LongTensor([self.visdial_data_val_dense[index]['round_id']]) 335 | item['gt_relevance'] = gt_relevance 336 | 337 | # add image features. 
Expand them to create batch * num_rounds * num options * num bbox * img feats 338 | features, num_boxes, boxes, _ , image_target = self._image_features_reader[img_id] 339 | features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, \ 340 | image_target, max_regions=self._max_region_num, mask_prob=0) 341 | 342 | item['image_feat'] = features 343 | item['image_loc'] = spatials 344 | item['image_mask'] = image_mask 345 | item['image_target'] = image_target 346 | item['image_label'] = image_label 347 | 348 | item['image_id'] = torch.LongTensor([img_id]) 349 | 350 | return item 351 | 352 | else: 353 | assert num_options == 100 354 | cur_rnd_utterance = [self.tokenizer.encode(dialog['caption'])] 355 | options_all = [] 356 | for rnd,utterance in enumerate(dialog['dialog']): 357 | cur_rnd_utterance.append(self.tokenizer.encode(cur_questions[utterance['question']])) 358 | if rnd != len(dialog['dialog'])-1: 359 | cur_rnd_utterance.append(self.tokenizer.encode(cur_answers[utterance['answer']])) 360 | for answer_option in dialog['dialog'][-1]['answer_options']: 361 | cur_option = cur_rnd_utterance.copy() 362 | cur_option.append(self.tokenizer.encode(cur_answers[answer_option])) 363 | options_all.append(cur_option) 364 | 365 | tokens_all = [] 366 | mask_all = [] 367 | segments_all = [] 368 | sep_indices_all = [] 369 | hist_len_all = [] 370 | 371 | for j, option in enumerate(options_all): 372 | option, start_segment = pruneRounds(option, self.params['visdial_tot_rounds']) 373 | print("option: {} {}".format(j, tokens2str(option))) 374 | tokens, segments, sep_indices, mask = encode_input(option, start_segment ,self.CLS, 375 | self.SEP, self.MASK ,max_seq_len=MAX_SEQ_LEN, mask_prob=0) 376 | 377 | tokens_all.append(tokens) 378 | mask_all.append(mask) 379 | segments_all.append(segments) 380 | sep_indices_all.append(sep_indices) 381 | hist_len_all.append(torch.LongTensor([len(option)-1])) 382 | 383 | tokens_all = torch.cat(tokens_all,0) 384 | mask_all = torch.cat(mask_all,0) 385 | segments_all = torch.cat(segments_all, 0) 386 | sep_indices_all = torch.cat(sep_indices_all, 0) 387 | hist_len_all = torch.cat(hist_len_all,0) 388 | 389 | item = {} 390 | item['tokens'] = tokens_all.unsqueeze(0) 391 | item['segments'] = segments_all.unsqueeze(0) 392 | item['sep_indices'] = sep_indices_all.unsqueeze(0) 393 | item['mask'] = mask_all.unsqueeze(0) 394 | item['hist_len'] = hist_len_all.unsqueeze(0) 395 | 396 | item['image_id'] = torch.LongTensor([img_id]) 397 | item['round_id'] = torch.LongTensor([dialog['round_id']]) 398 | 399 | # add image features. 
Expand them to create batch * num_rounds * num options * num bbox * img feats 400 | features, num_boxes, boxes, _ , image_target = self._image_features_reader[img_id] 401 | features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, \ 402 | image_target, max_regions=self._max_region_num, mask_prob=0) 403 | 404 | item['image_feat'] = features 405 | item['image_loc'] = spatials 406 | item['image_mask'] = image_mask 407 | item['image_target'] = image_target 408 | item['image_label'] = image_label 409 | 410 | return item -------------------------------------------------------------------------------- /dense_annotation_finetuning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torch.nn import CrossEntropyLoss 7 | from torch.utils.data import DataLoader 8 | import options 9 | from models.visual_dialog_encoder import VisualDialogEncoder 10 | import torch.optim as optim 11 | from utils.visualize import VisdomVisualize 12 | import pprint 13 | from time import gmtime, strftime 14 | from timeit import default_timer as timer 15 | from pytorch_transformers.optimization import AdamW 16 | import os 17 | from utils.visdial_metrics import SparseGTMetrics, NDCG, scores_to_ranks 18 | from pytorch_transformers.tokenization_bert import BertTokenizer 19 | from utils.data_utils import sequence_mask, batch_iter 20 | from utils.optim_utils import WarmupLinearScheduleNonZero 21 | import json 22 | import logging 23 | from dataloader.dataloader_dense_annotations import VisdialDatasetDense 24 | from dataloader.dataloader_visdial import VisdialDataset 25 | 26 | from train import forward, visdial_evaluate 27 | 28 | if __name__ == '__main__': 29 | 30 | params = options.read_command_line() 31 | os.makedirs('checkpoints', exist_ok=True) 32 | if not os.path.exists(params['save_path']): 33 | os.mkdir(params['save_path']) 34 | viz = VisdomVisualize( 35 | enable=bool(params['enable_visdom']), 36 | env_name=params['visdom_env'], 37 | server=params['visdom_server'], 38 | port=params['visdom_server_port']) 39 | pprint.pprint(params) 40 | viz.addText(pprint.pformat(params, indent=4)) 41 | 42 | dataset = VisdialDatasetDense(params) 43 | 44 | num_images_batch = 1 45 | 46 | dataset.split = 'train' 47 | dataloader = DataLoader( 48 | dataset, 49 | batch_size= num_images_batch, 50 | shuffle=False, 51 | num_workers=params['num_workers'], 52 | drop_last=True, 53 | pin_memory=False) 54 | 55 | eval_dataset = VisdialDataset(params) 56 | eval_dataset.split = 'val' 57 | 58 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 59 | params['device'] = device 60 | dialog_encoder = VisualDialogEncoder(params['model_config']) 61 | 62 | param_optimizer = list(dialog_encoder.named_parameters()) 63 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 64 | 65 | langauge_weights = None 66 | with open('config/language_weights.json') as f: 67 | langauge_weights = json.load(f) 68 | 69 | optimizer_grouped_parameters = [] 70 | for key, value in dict(dialog_encoder.named_parameters()).items(): 71 | if value.requires_grad: 72 | if key in langauge_weights: 73 | lr = params['lr'] 74 | else: 75 | lr = params['image_lr'] 76 | 77 | if any(nd in key for nd in no_decay): 78 | optimizer_grouped_parameters += [ 79 | {"params": [value], "lr": lr, "weight_decay": 0} 80 | ] 81 | 82 | if not any(nd in key 
for nd in no_decay): 83 | optimizer_grouped_parameters += [ 84 | {"params": [value], "lr": lr, "weight_decay": 0.01} 85 | ] 86 | 87 | optimizer = AdamW(optimizer_grouped_parameters, lr=params['lr']) 88 | scheduler = WarmupLinearScheduleNonZero(optimizer, warmup_steps=10000, t_total=200000) 89 | startIterID = 0 90 | 91 | if params['start_path']: 92 | 93 | pretrained_dict = torch.load(params['start_path']) 94 | 95 | if not params['continue']: 96 | if 'model_state_dict' in pretrained_dict: 97 | pretrained_dict = pretrained_dict['model_state_dict'] 98 | 99 | model_dict = dialog_encoder.state_dict() 100 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 101 | print("number of keys transferred", len(pretrained_dict)) 102 | assert len(pretrained_dict.keys()) > 0 103 | model_dict.update(pretrained_dict) 104 | dialog_encoder.load_state_dict(model_dict) 105 | else: 106 | model_dict = dialog_encoder.state_dict() 107 | optimizer_dict = optimizer.state_dict() 108 | pretrained_dict_model = pretrained_dict['model_state_dict'] 109 | pretrained_dict_optimizer = pretrained_dict['optimizer_state_dict'] 110 | pretrained_dict_scheduler = pretrained_dict['scheduler_state_dict'] 111 | pretrained_dict_model = {k: v for k, v in pretrained_dict_model.items() if k in model_dict} 112 | pretrained_dict_optimizer = {k: v for k, v in pretrained_dict_optimizer.items() if k in optimizer_dict} 113 | model_dict.update(pretrained_dict_model) 114 | optimizer_dict.update(pretrained_dict_optimizer) 115 | dialog_encoder.load_state_dict(model_dict) 116 | optimizer.load_state_dict(optimizer_dict) 117 | for state in optimizer.state.values(): 118 | for k, v in state.items(): 119 | if isinstance(v, torch.Tensor): 120 | state[k] = v.to(device) 121 | scheduler = WarmupLinearScheduleNonZero(optimizer, warmup_steps=10000, \ 122 | t_total=200000, last_epoch=pretrained_dict["iterId"]) 123 | scheduler.load_state_dict(pretrained_dict_scheduler) 124 | startIterID = pretrained_dict['iterId'] 125 | 126 | del pretrained_dict, pretrained_dict_model, pretrained_dict_optimizer, pretrained_dict_scheduler, \ 127 | model_dict, optimizer_dict 128 | torch.cuda.empty_cache() 129 | 130 | num_iter_epoch = dataset.numDataPoints['train'] // num_images_batch if not params['overfit'] else 1 131 | print('\n%d iter per epoch.' 
% num_iter_epoch) 132 | 133 | dialog_encoder = nn.DataParallel(dialog_encoder) 134 | dialog_encoder.to(device) 135 | 136 | start_t = timer() 137 | optimizer.zero_grad() 138 | # kl div reduces to ce if the target distribution is fixed 139 | ce_loss_fct = nn.KLDivLoss(reduction='batchmean') 140 | 141 | for epoch_id, idx, batch in batch_iter(dataloader, params): 142 | iter_id = startIterID + idx + (epoch_id * num_iter_epoch) 143 | dialog_encoder.train() 144 | # expand image features, 145 | features = batch['image_feat'] 146 | spatials = batch['image_loc'] 147 | image_mask = batch['image_mask'] 148 | image_label = batch['image_label'] 149 | image_target = batch['image_target'] 150 | 151 | num_rounds = batch["tokens"].shape[1] 152 | num_samples = batch["tokens"].shape[2] 153 | 154 | # sample 80 options including the gt option due to memory constraints 155 | assert num_images_batch == 1 156 | gt_option_ind = batch['gt_option'].item() 157 | all_inds_minus_gt = torch.cat([torch.arange(gt_option_ind), torch.arange(gt_option_ind + 1,100)],0) 158 | all_inds_minus_gt = all_inds_minus_gt[torch.randperm(99)[:79]] 159 | option_indices = torch.cat([batch['gt_option'].view(-1), all_inds_minus_gt] , 0) 160 | 161 | features = features.unsqueeze(1).unsqueeze(1).expand(features.shape[0], num_rounds, 80, features.shape[1], features.shape[2]) 162 | spatials = spatials.unsqueeze(1).unsqueeze(1).expand(spatials.shape[0], num_rounds, 80, spatials.shape[1], spatials.shape[2]) 163 | image_mask = image_mask.unsqueeze(1).unsqueeze(1).expand(image_mask.shape[0], num_rounds, 80, image_mask.shape[1]) 164 | image_label = image_label.unsqueeze(1).unsqueeze(1).expand(image_label.shape[0], num_rounds, 80, image_label.shape[1]) 165 | image_target = image_target.unsqueeze(1).unsqueeze(1).expand(image_target.shape[0], num_rounds, 80, image_target.shape[1],image_target.shape[2]) 166 | 167 | features = features.view(-1, features.shape[-2], features.shape[-1]) 168 | spatials = spatials.view(-1, spatials.shape[-2], spatials.shape[-1]) 169 | image_mask = image_mask.view(-1, image_mask.shape[-1]) 170 | image_label = image_label.view(-1, image_label.shape[-1]) 171 | image_target = image_target.view(-1, image_target.shape[-2], image_target.shape[-1]) 172 | 173 | # reshape text features 174 | tokens = batch['tokens'] 175 | segments = batch['segments'] 176 | sep_indices = batch['sep_indices'] 177 | mask = batch['mask'] 178 | hist_len = batch['hist_len'] 179 | nsp_labels = batch['next_sentence_labels'] 180 | 181 | # select 80 options from the 100 options including the GT option 182 | tokens = tokens[:, :, option_indices, :] 183 | segments = segments[:, :, option_indices, :] 184 | sep_indices = sep_indices[:, :, option_indices, :] 185 | mask = mask[:, :, option_indices, :] 186 | hist_len = hist_len[:, :, option_indices] 187 | nsp_labels = nsp_labels[:, :, option_indices] 188 | 189 | tokens = tokens.view(-1, tokens.shape[-1]) 190 | segments = segments.view(-1, segments.shape[-1]) 191 | sep_indices = sep_indices.view(-1, sep_indices.shape[-1]) 192 | mask = mask.view(-1, mask.shape[-1]) 193 | hist_len = hist_len.view(-1) 194 | nsp_labels = nsp_labels.view(-1) 195 | nsp_labels = nsp_labels.to(params['device']) 196 | 197 | batch['tokens'] = tokens 198 | batch['segments'] = segments 199 | batch['sep_indices'] = sep_indices 200 | batch['mask'] = mask 201 | batch['hist_len'] = hist_len 202 | batch['next_sentence_labels'] = nsp_labels 203 | 204 | batch['image_feat'] = features.contiguous() 205 | batch['image_loc'] = spatials.contiguous() 206 | 
batch['image_mask'] = image_mask.contiguous() 207 | batch['image_target'] = image_target.contiguous() 208 | batch['image_label'] = image_label.contiguous() 209 | 210 | print("token shape", tokens.shape) 211 | loss = 0 212 | nsp_loss = 0 213 | _, _, _, _, nsp_scores = forward(dialog_encoder, batch, \ 214 | params, sample_size=None, output_nsp_scores=True, evaluation=True) 215 | logging.info("nsp scores: {}".format(nsp_scores)) 216 | # calculate dense annotation ce loss 217 | nsp_scores = nsp_scores.view(-1, 80, 2) 218 | nsp_loss = F.cross_entropy(nsp_scores.view(-1,2), nsp_labels.view(-1)) 219 | nsp_scores = nsp_scores[:, :, 0] 220 | 221 | gt_relevance = batch['gt_relevance'].to(device) 222 | # shuffle the gt relevance scores as well 223 | gt_relevance = gt_relevance[:, option_indices] 224 | ce_loss = ce_loss_fct(F.log_softmax(nsp_scores, dim=1), F.softmax(gt_relevance, dim=1)) 225 | loss = ce_loss + params['nsp_loss_coeff'] * nsp_loss 226 | loss /= params['batch_multiply'] 227 | loss.backward() 228 | scheduler.step() 229 | 230 | if iter_id % params['batch_multiply'] == 0 and iter_id > 0: 231 | optimizer.step() 232 | optimizer.zero_grad() 233 | 234 | if iter_id % 10 == 0: 235 | # Update line plots 236 | viz.linePlot(iter_id, loss.item(), 'loss', 'tot loss') 237 | viz.linePlot(iter_id, nsp_loss.item(), 'loss', 'nsp loss') 238 | viz.linePlot(iter_id, ce_loss.item(), 'loss', 'ce loss') 239 | 240 | old_num_iter_epoch = num_iter_epoch 241 | if params['overfit']: 242 | num_iter_epoch = 100 243 | if iter_id % num_iter_epoch == 0: 244 | torch.save({'model_state_dict' : dialog_encoder.module.state_dict(),'scheduler_state_dict':scheduler.state_dict() \ 245 | ,'optimizer_state_dict': optimizer.state_dict(), 'iter_id':iter_id}, os.path.join(params['save_path'], 'visdial_dialog_encoder_%d.ckpt'%iter_id)) 246 | 247 | if iter_id % num_iter_epoch == 0 and iter_id > 0: 248 | viz.save() 249 | # fire evaluation 250 | print("num iteration for eval", num_iter_epoch) 251 | if iter_id % num_iter_epoch == 0 and iter_id > 0: 252 | eval_batch_size = 2 253 | if params['overfit']: 254 | eval_batch_size = 5 255 | 256 | # each image will need 1000 forward passes, (100 at each round x 10 rounds). 
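The `KLDivLoss` used above relies on the identity noted in its comment: against a fixed target distribution, KL divergence equals the soft cross-entropy minus the target's entropy, which is constant with respect to the model, so the two objectives yield the same gradients. A small numerical check with made-up scores and relevance values:

```
import torch
import torch.nn.functional as F

torch.manual_seed(0)
nsp_scores = torch.randn(2, 80)     # fake per-option scores for 2 dialogs
gt_relevance = torch.rand(2, 80)    # fake dense relevance annotations

target = F.softmax(gt_relevance, dim=1)        # fixed target distribution
log_probs = F.log_softmax(nsp_scores, dim=1)

kl = F.kl_div(log_probs, target, reduction='batchmean')
soft_ce = -(target * log_probs).sum(dim=1).mean()
target_entropy = -(target * target.log()).sum(dim=1).mean()

# KL(target || model) = soft cross-entropy - H(target), and H(target) is a constant
assert torch.allclose(kl, soft_ce - target_entropy, atol=1e-5)
```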
257 | dataloader = DataLoader( 258 | eval_dataset, 259 | batch_size=eval_batch_size, 260 | shuffle=False, 261 | num_workers=params['num_workers'], 262 | drop_last=True, 263 | pin_memory=False) 264 | all_metrics = visdial_evaluate(dataloader, params, eval_batch_size, dialog_encoder) 265 | for metric_name, metric_value in all_metrics.items(): 266 | print(f"{metric_name}: {metric_value}") 267 | if 'round' in metric_name: 268 | viz.linePlot(iter_id, metric_value, 'Retrieval Round Val Metrics Round -' + metric_name.split('_')[-1], metric_name) 269 | else: 270 | viz.linePlot(iter_id, metric_value, 'Retrieval Val Metrics', metric_name) 271 | 272 | num_iter_epoch = old_num_iter_epoch -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: visdial-bert 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - blas=1.0=mkl 8 | - ca-certificates=2019.11.27=0 9 | - certifi=2019.11.28=py36_0 10 | - cffi=1.13.2=py36h2e261b9_0 11 | - cudatoolkit=10.0.130=0 12 | - freetype=2.9.1=h8a8886c_1 13 | - intel-openmp=2019.4=243 14 | - jpeg=9b=h024ee3a_2 15 | - libedit=3.1.20181209=hc058e9b_0 16 | - libffi=3.2.1=hd88cf55_4 17 | - libgcc-ng=9.1.0=hdf63c60_0 18 | - libgfortran-ng=7.3.0=hdf63c60_0 19 | - libpng=1.6.37=hbc83047_0 20 | - libstdcxx-ng=9.1.0=hdf63c60_0 21 | - libtiff=4.1.0=h2733197_0 22 | - mkl=2019.4=243 23 | - mkl-service=2.3.0=py36he904b0f_0 24 | - mkl_fft=1.0.15=py36ha843d7b_0 25 | - mkl_random=1.1.0=py36hd6b4f25_0 26 | - ncurses=6.1=he6710b0_1 27 | - ninja=1.9.0=py36hfd86e86_0 28 | - numpy-base=1.17.4=py36hde5b4d6_0 29 | - olefile=0.46=py36_0 30 | - openssl=1.1.1d=h7b6447c_3 31 | - pillow=6.2.1=py36h34e0f95_0 32 | - pip=19.3.1=py36_0 33 | - pycparser=2.19=py36_0 34 | - python=3.6.9=h265db76_0 35 | - pytorch=1.3.1=py3.6_cuda10.0.130_cudnn7.6.3_0 36 | - readline=7.0=h7b6447c_5 37 | - setuptools=42.0.2=py36_0 38 | - six=1.13.0=py36_0 39 | - sqlite=3.30.1=h7b6447c_0 40 | - tk=8.6.8=hbc83047_0 41 | - torchvision=0.4.2=py36_cu100 42 | - wheel=0.33.6=py36_0 43 | - xz=5.2.4=h14c3975_4 44 | - zlib=1.2.11=h7b6447c_3 45 | - zstd=1.3.7=h0b5b093_0 46 | - pip: 47 | - absl-py==0.8.1 48 | - astor==0.8.1 49 | - boto3==1.10.36 50 | - botocore==1.13.36 51 | - chardet==3.0.4 52 | - click==7.0 53 | - cycler==0.10.0 54 | - cython==0.29.14 55 | - docutils==0.15.2 56 | - easydict==1.9 57 | - gast==0.3.2 58 | - grpcio==1.25.0 59 | - h5py==2.10.0 60 | - idna==2.8 61 | - jmespath==0.9.4 62 | - joblib==0.14.1 63 | - json-lines==0.5.0 64 | - jsonlines==1.2.0 65 | - jsonpatch==1.24 66 | - jsonpointer==2.0 67 | - keras-applications==1.0.8 68 | - keras-preprocessing==1.1.0 69 | - kiwisolver==1.1.0 70 | - lmdb==0.94 71 | - markdown==3.1.1 72 | - matplotlib==3.1.2 73 | - mock==3.0.5 74 | - msgpack==0.6.2 75 | - msgpack-numpy==0.4.4.3 76 | - numpy==1.17.4 77 | - protobuf==3.11.1 78 | - pyparsing==2.4.5 79 | - python-dateutil==2.8.0 80 | - python-prctl==1.7 81 | - pytorch-pretrained-bert==0.6.2 82 | - pytorch-transformers==1.2.0 83 | - pyyaml==5.1.2 84 | - pyzmq==18.1.1 85 | - regex==2019.12.9 86 | - requests==2.22.0 87 | - s3transfer==0.2.1 88 | - sacremoses==0.0.35 89 | - scipy==1.3.3 90 | - sentencepiece==0.1.83 91 | - tabulate==0.8.6 92 | - tensorboard==1.13.1 93 | - tensorboardx==1.2 94 | - tensorflow==1.13.1 95 | - tensorflow-estimator==1.13.0 96 | - tensorpack==0.9.4 97 | - termcolor==1.1.0 98 | - torchfile==0.1.0 99 | - tornado==6.0.3 100 | - tqdm==4.31.1 101 | - 
urllib3==1.25.7 102 | - visdom==0.1.8.9 103 | - websocket-client==0.56.0 104 | - werkzeug==0.16.0 105 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torch.nn import CrossEntropyLoss 7 | from torch.utils.data import DataLoader 8 | from dataloader.dataloader_visdial import VisdialDataset 9 | import options 10 | from models.visual_dialog_encoder import VisualDialogEncoder 11 | import torch.optim as optim 12 | from utils.visualize import VisdomVisualize 13 | import pprint 14 | from time import gmtime, strftime 15 | from timeit import default_timer as timer 16 | from pytorch_transformers.optimization import AdamW 17 | import os 18 | from utils.visdial_metrics import SparseGTMetrics, NDCG, scores_to_ranks 19 | from pytorch_transformers.tokenization_bert import BertTokenizer 20 | from utils.data_utils import sequence_mask, batch_iter 21 | from utils.optim_utils import WarmupLinearScheduleNonZero 22 | import json 23 | import logging 24 | from train import forward 25 | 26 | def eval_ai_generate(dataloader, params, eval_batch_size, split='test'): 27 | ranks_json = [] 28 | dialog_encoder.eval() 29 | batch_idx = 0 30 | with torch.no_grad(): 31 | batch_size = 500 * (params['n_gpus']/8) 32 | batch_size = min([1, 2, 4, 5, 100, 1000, 200, 8, 10, 40, 50, 500, 20, 25, 250, 125], \ 33 | key=lambda x: abs(x-batch_size) if x <= batch_size else float("inf")) 34 | print("batch size for evaluation", batch_size) 35 | for epochId, _, batch in batch_iter(dataloader, params): 36 | if epochId == 1: 37 | break 38 | 39 | tokens = batch['tokens'] 40 | num_rounds = tokens.shape[1] 41 | num_options = tokens.shape[2] 42 | tokens = tokens.view(-1, tokens.shape[-1]) 43 | segments = batch['segments'] 44 | segments = segments.view(-1, segments.shape[-1]) 45 | sep_indices = batch['sep_indices'] 46 | sep_indices = sep_indices.view(-1, sep_indices.shape[-1]) 47 | mask = batch['mask'] 48 | mask = mask.view(-1, mask.shape[-1]) 49 | hist_len = batch['hist_len'] 50 | hist_len = hist_len.view(-1) 51 | 52 | # get image features 53 | features = batch['image_feat'] 54 | spatials = batch['image_loc'] 55 | image_mask = batch['image_mask'] 56 | 57 | # expand the image features to match those of tokens etc. 
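The expansion that follows replicates each image's region features across every round and answer option so they line up one-to-one with the flattened text inputs. A toy version of the same unsqueeze/expand/view pattern, with tiny hypothetical sizes:

```
import torch

batch, regions, dim = 2, 3, 4        # hypothetical tiny sizes
num_rounds, num_options = 10, 100
feats = torch.randn(batch, regions, dim)

expanded = (feats.unsqueeze(1).unsqueeze(1)
                 .expand(batch, num_rounds, num_options, regions, dim)
                 .contiguous()
                 .view(-1, regions, dim))

# one copy of the image features per (dialog, round, option) triple
assert expanded.shape == (batch * num_rounds * num_options, regions, dim)
assert torch.equal(expanded[0], feats[0]) and torch.equal(expanded[-1], feats[-1])
```

Since `expand` only creates a stride-0 view, the memory cost is paid at the final `.contiguous()` call, not at the expansion itself.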
58 | max_num_regions = features.shape[-2] 59 | features = features.unsqueeze(1).unsqueeze(1).expand(eval_batch_size, num_rounds, num_options, max_num_regions, 2048).contiguous() 60 | spatials = spatials.unsqueeze(1).unsqueeze(1).expand(eval_batch_size, num_rounds, num_options, max_num_regions, 5).contiguous() 61 | image_mask = image_mask.unsqueeze(1).unsqueeze(1).expand(eval_batch_size, num_rounds, num_options, max_num_regions).contiguous() 62 | 63 | features = features.view(-1, max_num_regions, 2048) 64 | spatials = spatials.view(-1, max_num_regions, 5) 65 | image_mask = image_mask.view(-1, max_num_regions) 66 | 67 | assert tokens.shape[0] == segments.shape[0] == sep_indices.shape[0] == mask.shape[0] == \ 68 | hist_len.shape[0] == features.shape[0] == spatials.shape[0] == \ 69 | image_mask.shape[0] == num_rounds * num_options * eval_batch_size 70 | 71 | output = [] 72 | assert (eval_batch_size * num_rounds * num_options)//batch_size == (eval_batch_size * num_rounds * num_options)/batch_size 73 | for j in range((eval_batch_size * num_rounds * num_options)//batch_size): 74 | # create chunks of the original batch 75 | item = {} 76 | item['tokens'] = tokens[j*batch_size:(j+1)*batch_size,:] 77 | item['segments'] = segments[j*batch_size:(j+1)*batch_size,:] 78 | item['sep_indices'] = sep_indices[j*batch_size:(j+1)*batch_size,:] 79 | item['mask'] = mask[j*batch_size:(j+1)*batch_size,:] 80 | item['hist_len'] = hist_len[j*batch_size:(j+1)*batch_size] 81 | 82 | item['image_feat'] = features[j*batch_size:(j+1)*batch_size, : , :] 83 | item['image_loc'] = spatials[j*batch_size:(j+1)*batch_size, : , :] 84 | item['image_mask'] = image_mask[j*batch_size:(j+1)*batch_size, :] 85 | 86 | _, _, _, _, nsp_scores = forward(dialog_encoder, item, params, output_nsp_scores=True, evaluation=True) 87 | # normalize nsp scores 88 | nsp_probs = F.softmax(nsp_scores, dim=1) 89 | assert nsp_probs.shape[-1] == 2 90 | output.append(nsp_probs[:,0]) 91 | 92 | # print("output shape",torch.cat(output,0).shape) 93 | output = torch.cat(output,0).view(eval_batch_size, num_rounds, num_options) 94 | ranks = scores_to_ranks(output) 95 | ranks = ranks.squeeze(1) 96 | for i in range(eval_batch_size): 97 | ranks_json.append( 98 | { 99 | "image_id": batch["image_id"][i].item(), 100 | "round_id": int(batch["round_id"][i].item()), 101 | "ranks": [ 102 | rank.item() 103 | for rank in ranks[i][:] 104 | ], 105 | } 106 | ) 107 | 108 | batch_idx += 1 109 | return ranks_json 110 | if __name__ == '__main__': 111 | 112 | params = options.read_command_line() 113 | pprint.pprint(params) 114 | dataset = VisdialDataset(params) 115 | eval_batch_size = 5 116 | split = 'test' 117 | dataset.split = split 118 | dataloader = DataLoader( 119 | dataset, 120 | batch_size=eval_batch_size, 121 | shuffle=False, 122 | num_workers=params['num_workers'], 123 | drop_last=False, 124 | pin_memory=False) 125 | 126 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 127 | params['device'] = device 128 | dialog_encoder = VisualDialogEncoder(params['model_config']) 129 | 130 | if params['start_path']: 131 | pretrained_dict = torch.load(params['start_path']) 132 | 133 | if 'model_state_dict' in pretrained_dict: 134 | pretrained_dict = pretrained_dict['model_state_dict'] 135 | 136 | model_dict = dialog_encoder.state_dict() 137 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 138 | print("number of keys transferred", len(pretrained_dict)) 139 | assert len(pretrained_dict.keys()) > 0 140 | 
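The checkpoint loading above transfers only the keys that also exist in the current model, which tolerates heads that were added or removed between pretraining and finetuning. A self-contained sketch of the same partial state-dict transfer on throwaway modules (the modules are made up, and a shape check is added here that the original filtering does not perform):

```
import torch.nn as nn

pretrained = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))   # stand-in for the checkpoint
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3))        # same trunk, different head

pretrained_dict = pretrained.state_dict()
model_dict = model.state_dict()

# keep only tensors whose names and shapes both match the current model
transferable = {k: v for k, v in pretrained_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(transferable)
model.load_state_dict(model_dict)

print("number of keys transferred", len(transferable))   # 2: weight and bias of layer 0
```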
model_dict.update(pretrained_dict) 141 | dialog_encoder.load_state_dict(model_dict) 142 | 143 | dialog_encoder = nn.DataParallel(dialog_encoder) 144 | dialog_encoder.to(device) 145 | ranks_json = eval_ai_generate(dataloader, params, eval_batch_size, split=split) 146 | 147 | json.dump(ranks_json, open(params['save_name'] + '_predictions.txt', "w")) 148 | -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vmurahari3/visdial-bert/87e264794c45cc5c8c1ea243ad9d2b4d94a44faf/images/teaser.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vmurahari3/visdial-bert/87e264794c45cc5c8c1ea243ad9d2b4d94a44faf/models/__init__.py -------------------------------------------------------------------------------- /models/language_only_dialog.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import CrossEntropyLoss 4 | from pytorch_transformers.modeling_bert import BertForPreTraining, BertPredictionHeadTransform, BertEmbeddings, BertPreTrainingHeads, BertLayerNorm, \ 5 | BertModel, BertEncoder, BertPooler 6 | from utils.data_utils import sequence_mask 7 | 8 | class BertEmbeddingsDialog(BertEmbeddings): 9 | def __init__(self, config): 10 | super(BertEmbeddingsDialog, self).__init__(config) 11 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 12 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 13 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 14 | # add support for additional segment embeddings. 
Supporting 10 additional embedding as of now 15 | self.token_type_embeddings_extension = nn.Embedding(10,config.hidden_size) 16 | # adding specialized embeddings for sep tokens 17 | self.sep_embeddings = nn.Embedding(50,config.hidden_size) 18 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 19 | # any TensorFlow checkpoint file 20 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 21 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 22 | self.config = config 23 | 24 | def forward(self, input_ids, sep_indices=None, sep_len=None, token_type_ids=None): 25 | seq_length = input_ids.size(1) 26 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 27 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 28 | if token_type_ids is None: 29 | token_type_ids = torch.zeros_like(input_ids) 30 | 31 | words_embeddings = self.word_embeddings(input_ids) 32 | position_embeddings = self.position_embeddings(position_ids) 33 | 34 | token_type_ids_extension = token_type_ids - self.config.type_vocab_size 35 | token_type_ids_extension_mask = (token_type_ids_extension >= 0).float() 36 | token_type_ids_extension = (token_type_ids_extension.float() * token_type_ids_extension_mask).long() 37 | 38 | token_type_ids_mask = (token_type_ids < self.config.type_vocab_size).float() 39 | assert torch.sum(token_type_ids_extension_mask + token_type_ids_mask) == \ 40 | torch.numel(token_type_ids) == torch.numel(token_type_ids_mask) 41 | token_type_ids = (token_type_ids.float() * token_type_ids_mask).long() 42 | 43 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 44 | token_type_embeddings_extension = self.token_type_embeddings_extension(token_type_ids_extension) 45 | 46 | token_type_embeddings = (token_type_embeddings * token_type_ids_mask.unsqueeze(-1)) + \ 47 | (token_type_embeddings_extension * token_type_ids_extension_mask.unsqueeze(-1)) 48 | 49 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 50 | 51 | embeddings = self.LayerNorm(embeddings) 52 | embeddings = self.dropout(embeddings) 53 | return embeddings 54 | 55 | class BertModelDialog(BertModel): 56 | 57 | def __init__(self, config): 58 | super(BertModelDialog, self).__init__(config) 59 | 60 | self.embeddings = BertEmbeddingsDialog(config) 61 | self.encoder = BertEncoder(config) 62 | self.pooler = BertPooler(config) 63 | self.init_weights() 64 | 65 | def _resize_token_embeddings(self, new_num_tokens): 66 | old_embeddings = self.embeddings.word_embeddings 67 | new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) 68 | self.embeddings.word_embeddings = new_embeddings 69 | return self.embeddings.word_embeddings 70 | 71 | def _prune_heads(self, heads_to_prune): 72 | """ Prunes heads of the model. 73 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 74 | See base class PreTrainedModel 75 | """ 76 | for layer, heads in heads_to_prune.items(): 77 | self.encoder.layer[layer].attention.prune_heads(heads) 78 | 79 | def forward(self, input_ids, sep_indices=None, sep_len=None, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None): 80 | if attention_mask is None: 81 | attention_mask = torch.ones_like(input_ids) 82 | if token_type_ids is None: 83 | token_type_ids = torch.zeros_like(input_ids) 84 | 85 | # We create a 3D attention mask from a 2D tensor mask. 
86 | # Sizes are [batch_size, 1, 1, to_seq_length] 87 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 88 | # this attention mask is more simple than the triangular masking of causal attention 89 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 90 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 91 | 92 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 93 | # masked positions, this operation will create a tensor which is 0.0 for 94 | # positions we want to attend and -10000.0 for masked positions. 95 | # Since we are adding it to the raw scores before the softmax, this is 96 | # effectively the same as removing these entirely. 97 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 98 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 99 | 100 | # Prepare head mask if needed 101 | # 1.0 in head_mask indicate we keep the head 102 | # attention_probs has shape bsz x n_heads x N x N 103 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 104 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 105 | if head_mask is not None: 106 | if head_mask.dim() == 1: 107 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 108 | head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) 109 | elif head_mask.dim() == 2: 110 | head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer 111 | head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility 112 | else: 113 | head_mask = [None] * self.config.num_hidden_layers 114 | 115 | embedding_output = self.embeddings(input_ids, sep_indices=sep_indices, sep_len=sep_len, token_type_ids=token_type_ids) 116 | encoder_outputs = self.encoder(embedding_output, 117 | extended_attention_mask, 118 | head_mask=head_mask) 119 | sequence_output = encoder_outputs[0] 120 | pooled_output = self.pooler(sequence_output) 121 | 122 | outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here 123 | return outputs # sequence_output, pooled_output, (hidden_states), (attentions) 124 | 125 | class BertForPretrainingDialog(BertForPreTraining): 126 | def __init__(self, config): 127 | super(BertForPretrainingDialog, self).__init__(config) 128 | self.bert = BertModelDialog(config) 129 | self.cls = BertPreTrainingHeads(config) 130 | self.init_weights() 131 | self.tie_weights() 132 | 133 | def tie_weights(self): 134 | """ Make sure we are sharing the input and output embeddings. 135 | Export to TorchScript can't handle parameter sharing so we are cloning them instead. 
136 | """ 137 | self._tie_or_clone_weights(self.cls.predictions.decoder, 138 | self.bert.embeddings.word_embeddings) 139 | 140 | def forward(self, input_ids, sep_indices=None, sep_len=None, token_type_ids=None,attention_mask=None, 141 | masked_lm_labels=None, next_sentence_label=None, position_ids=None, head_mask=None): 142 | 143 | outputs = self.bert(input_ids, sep_indices=sep_indices, sep_len=sep_len,token_type_ids=token_type_ids\ 144 | ,attention_mask=attention_mask) 145 | 146 | sequence_output, pooled_output = outputs[:2] 147 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 148 | 149 | outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here 150 | 151 | if masked_lm_labels is not None and next_sentence_label is not None: 152 | loss_fct = CrossEntropyLoss(ignore_index=-1) 153 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 154 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 155 | total_loss = masked_lm_loss + next_sentence_loss 156 | outputs = (total_loss,) + outputs 157 | 158 | return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) 159 | -------------------------------------------------------------------------------- /models/language_only_dialog_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import CrossEntropyLoss 4 | from pytorch_transformers.modeling_bert import BertPredictionHeadTransform 5 | from models.language_only_dialog import BertForPretrainingDialog 6 | 7 | class DialogEncoder(nn.Module): 8 | 9 | def __init__(self): 10 | super(DialogEncoder, self).__init__() 11 | self.bert_pretrained = BertForPretrainingDialog.from_pretrained('bert-base-uncased',output_hidden_states=True) 12 | self.bert_pretrained.train() 13 | # add additional layers for the inconsistency loss 14 | assert self.bert_pretrained.config.output_hidden_states == True 15 | 16 | def forward(self, input_ids, sep_indices=None, sep_len=None, token_type_ids=None, attention_mask=None, 17 | masked_lm_labels=None, next_sentence_label=None, head_mask=None, output_nsp_scores=False, output_lm_scores=False): 18 | 19 | outputs = self.bert_pretrained(input_ids,sep_indices=sep_indices, sep_len=sep_len, \ 20 | token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, 21 | next_sentence_label=next_sentence_label, position_ids=None, head_mask=head_mask) 22 | 23 | loss = None 24 | if next_sentence_label is not None: 25 | loss, lm_scores, nsp_scores, hidden_states = outputs 26 | else: 27 | lm_scores, nsp_scores, hidden_states = outputs 28 | 29 | loss_fct = CrossEntropyLoss(ignore_index=-1) 30 | 31 | lm_loss = None 32 | nsp_loss = None 33 | if next_sentence_label is not None: 34 | nsp_loss = loss_fct(nsp_scores, next_sentence_label) 35 | if masked_lm_labels is not None: 36 | lm_loss = loss_fct(lm_scores.view(-1,lm_scores.shape[-1]), masked_lm_labels.view(-1)) 37 | 38 | out = (loss,lm_loss, nsp_loss) 39 | if output_nsp_scores: 40 | out = out + (nsp_scores,) 41 | if output_lm_scores: 42 | out = out + (lm_scores,) 43 | return out -------------------------------------------------------------------------------- /models/visual_dialog_encoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import 
json 3 | import torch 4 | from torch import nn 5 | from models.vilbert_dialog import BertForMultiModalPreTraining, BertConfig 6 | 7 | class VisualDialogEncoder(nn.Module): 8 | 9 | def __init__(self, config_path): 10 | super(VisualDialogEncoder, self).__init__() 11 | config = BertConfig.from_json_file(config_path) 12 | 13 | self.bert_pretrained = BertForMultiModalPreTraining.from_pretrained('bert-base-uncased',config) 14 | self.bert_pretrained.train() 15 | 16 | def forward(self, input_ids, image_feat, image_loc, sep_indices=None, sep_len=None, token_type_ids=None, 17 | attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None,random_round_indices=None, 18 | output_nsp_scores=False, output_lm_scores=False,image_attention_mask=None,image_label=None, image_target=None): 19 | 20 | masked_lm_loss = None 21 | masked_img_loss = None 22 | nsp_loss = None 23 | prediction_scores_t = None 24 | seq_relationship_score = None 25 | 26 | if next_sentence_label is not None and masked_lm_labels \ 27 | is not None and image_target is not None: 28 | # train mode, output losses 29 | masked_lm_loss, masked_img_loss, nsp_loss, _, prediction_scores_t, seq_relationship_score = \ 30 | self.bert_pretrained(input_ids, image_feat, image_loc, sep_indices=sep_indices, sep_len=sep_len, \ 31 | token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, \ 32 | next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,\ 33 | image_label=image_label, image_target=image_target) 34 | else: 35 | #inference, output scores 36 | prediction_scores_t, _, seq_relationship_score, _, _ = \ 37 | self.bert_pretrained(input_ids, image_feat, image_loc, sep_indices=sep_indices, sep_len=sep_len, \ 38 | token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, \ 39 | next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,\ 40 | image_label=image_label, image_target=image_target) 41 | 42 | out = (masked_lm_loss, masked_img_loss, nsp_loss) 43 | 44 | if output_nsp_scores: 45 | out = out + (seq_relationship_score,) 46 | if output_lm_scores: 47 | out = out + (prediction_scores_t,) 48 | return out -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from six import iteritems 4 | from itertools import product 5 | from time import gmtime, strftime 6 | 7 | def read_command_line(argv=None): 8 | parser = argparse.ArgumentParser(description='Large Scale Pretraining for Visual Dialog') 9 | 10 | #------------------------------------------------------------------------- 11 | # Data input settings 12 | parser.add_argument('-visdial_processed_train', default='data/visdial/visdial_1.0_train_processed.json', \ 13 | help='json file containing train split of visdial data') 14 | parser.add_argument('-visdial_processed_val', default='data/visdial/visdial_1.0_val_processed.json', 15 | help='json file containing val split of visdial data') 16 | parser.add_argument('-visdial_processed_test', default='data/visdial/visdial_1.0_test_processed.json', 17 | help='json file containing test split of visdial data') 18 | parser.add_argument('-visdial_image_feats', default='data/visdial/visdial_img_feat.lmdb', 19 | help='json file containing image feats for train,val and splits of visdial data') 20 | parser.add_argument('-visdial_processed_train_dense', 
default='data/visdial/visdial_1.0_train_dense_processed.json', 21 | help='samples on the train split for which dense annotations are available') 22 | parser.add_argument('-visdial_processed_train_dense_annotations', default='data/visdial/visdial_1.0_train_dense_annotations_processed.json', 23 | help='json file containing dense annotations on some instance of the train split') 24 | parser.add_argument('-visdial_processed_val_dense_annotations', default='data/visdial/visdial_1.0_val_dense_annotations_processed.json', 25 | help='JSON file with dense annotations') 26 | parser.add_argument('-start_path', default='', help='path of starting model checkpt') 27 | parser.add_argument('-model_config', default='config/bert_base_6layer_6conect.json', help='model definition of the bert model') 28 | #------------------------------------------------------------------------- 29 | # Logging settings 30 | parser.add_argument('-enable_visdom', type=int, default=0, 31 | help='Flag for enabling visdom logging') 32 | parser.add_argument('-visdom_env', type=str, default='', 33 | help='Name of visdom environment for plotting') 34 | parser.add_argument('-visdom_server', type=str, default='http://asimo.cc.gatech.edu', 35 | help='Address of visdom server instance') 36 | parser.add_argument('-visdom_server_port', type=int, default=7777, 37 | help='Port of visdom server instance') 38 | #------------------------------------------------------------------------- 39 | # Optimization / training params 40 | # Other training environmnet settings 41 | parser.add_argument('-num_workers', default=8, type=int, 42 | help='Number of worker threads in dataloader') 43 | parser.add_argument('-batch_size', default=80, type=int, 44 | help='size of mini batch') 45 | parser.add_argument('-num_epochs', default=20, type=int, 46 | help='total number of epochs') 47 | parser.add_argument('-batch_multiply', default=1, type=int, 48 | help='amplifies batch size in mini-batch training') 49 | parser.add_argument('-lr',default=2e-5,type=float,help='learning rate') 50 | parser.add_argument('-image_lr',default=2e-5,type=float,help='learning rate for vision params') 51 | 52 | parser.add_argument('-overfit', action='store_true', help='overfit for debugging') 53 | parser.add_argument('-continue', action='store_true', help='continue training') 54 | 55 | parser.add_argument('-num_train_samples',default=0,type=int, help='number of train samples, set 0 to include all') 56 | parser.add_argument('-num_val_samples',default=0, type=int, help='number of val samples, set 0 to include all') 57 | parser.add_argument('-num_options',default=100, type=int, help='number of options to use. 
Max: 100 Min: 2') 58 | parser.add_argument('-n_gpus',default=8, type=int, help='number of gpus running the job') 59 | parser.add_argument('-sequences_per_image',default=8, type=int, help='number of sequences sampled from an image during training') 60 | parser.add_argument('-visdial_tot_rounds',default=11, type=int, \ 61 | help='number of rounds to use in visdial,caption is counted as a separate round, therefore a maximum of 11 rounds possible') 62 | parser.add_argument('-max_seq_len',default=256, type=int, help='maximum sequence length for the dialog sequence') 63 | parser.add_argument('-num_negative_samples',default=1, type=int, help='number of negative samples for every positive sample for the nsp loss') 64 | 65 | parser.add_argument('-lm_loss_coeff',default=1,type=float,help='Coeff for lm loss') 66 | parser.add_argument('-nsp_loss_coeff',default=1,type=float,help='Coeff for nsp loss') 67 | parser.add_argument('-img_loss_coeff',default=1,type=float,help='Coeff for img masked loss') 68 | 69 | parser.add_argument('-mask_prob',default=0.15,type=float,help='prob used to sample masked tokens') 70 | 71 | parser.add_argument('-save_path', default='checkpoints/', 72 | help='Path to save checkpoints') 73 | parser.add_argument('-save_name', default='', 74 | help='Name of save directory within savePath') 75 | 76 | #------------------------------------------------------------------------- 77 | #------------------------------------------------------------------------- 78 | try: 79 | parsed = vars(parser.parse_args(args=argv)) 80 | if parsed['save_name']: 81 | # Custom save file path 82 | parsed['save_path'] = os.path.join(parsed['save_path'], 83 | parsed['save_name']) 84 | else: 85 | # Standard save path with time stamp 86 | import random 87 | timeStamp = strftime('%d-%b-%y-%X-%a', gmtime()) 88 | parsed['save_path'] = os.path.join(parsed['save_path'], timeStamp) 89 | parsed['save_path'] += '_{:0>6d}{}'.format(random.randint(0, 10e6),parsed['visdom_env']) 90 | 91 | assert parsed['sequences_per_image'] <= 8 92 | assert parsed['visdial_tot_rounds'] <= 11 93 | 94 | except IOError as msg: 95 | parser.error(str(msg)) 96 | 97 | return parsed -------------------------------------------------------------------------------- /preprocessing/pre_process_visdial.py: -------------------------------------------------------------------------------- 1 | import os 2 | import concurrent.futures 3 | import json 4 | import argparse 5 | import glob 6 | import importlib 7 | import sys 8 | from pytorch_transformers.tokenization_bert import BertTokenizer 9 | 10 | import torch 11 | 12 | def read_options(argv=None): 13 | parser = argparse.ArgumentParser(description='Options') 14 | #------------------------------------------------------------------------- 15 | # Data input settings 16 | parser.add_argument('-visdial_train', default='/srv/share/vmurahari3/visdial-rl/data/v1.0_data/visdial_1.0_train.json', help='json file containing train split of visdial data') 17 | 18 | parser.add_argument('-visdial_val', default='/srv/share/vmurahari3/visdial-rl/data/v1.0_data/visdial_1.0_val.json', 19 | help='json file containing val split of visdial data') 20 | parser.add_argument('-visdial_test', default='/srv/share/vmurahari3/visdial-rl/data/v1.0_data/visdial_1.0_test.json', 21 | help='json file containing test split of visdial data') 22 | parser.add_argument('-visdial_val_ndcg', default='/srv/share/vmurahari3/visdial-rl/data/v1.0_data/visdial_1.0_val_dense_annotations.json', 23 | help='JSON file with dense annotations') 24 | 
parser.add_argument('-visdial_train_ndcg', default='/srv/share/vmurahari3/visdial-rl/data/v1.0_data/visdial_1.0_train_dense_annotations.json', 25 | help='JSON file with dense annotations') 26 | 27 | parser.add_argument('-max_seq_len', default=256, type=int, 28 | help='the max len of the input representation of the dialog encoder') 29 | #------------------------------------------------------------------------- 30 | # Logging settings 31 | 32 | parser.add_argument('-save_path_train', default='/srv/share/vmurahari3/visdial-rl/data/v1.0_data/visdial_1.0_train_processed.json', 33 | help='Path to save processed train json') 34 | parser.add_argument('-save_path_val', default='/srv/share/vmurahari3/visdial-rl/data/v1.0_data/visdial_1.0_val_processed.json', 35 | help='Path to save val json') 36 | parser.add_argument('-save_path_test', default='/srv/share/vmurahari3/visdial-rl/data/v1.0_data/visdial_1.0_test_processed.json', 37 | help='Path to save test json') 38 | 39 | parser.add_argument('-save_path_train_dense_samples', default='/srv/share/vmurahari3/visdial-rl/data/v1.0_data/visdial_1.0_train_dense_processed.json', 40 | help='Path to save processed train json') 41 | parser.add_argument('-save_path_val_ndcg', default='/srv/share/vmurahari3/visdial-rl/data/v1.0_data/visdial_1.0_val_dense_annotations_processed.json', 42 | help='Path to save processed ndcg data for the val split') 43 | parser.add_argument('-save_path_train_ndcg', default='/srv/share/vmurahari3/visdial-rl/data/v1.0_data/visdial_1.0_train_dense_annotations_processed.json', 44 | help='Path to save processed ndcg data for the train split') 45 | 46 | try: 47 | parsed = vars(parser.parse_args(args=argv)) 48 | except IOError as msg: 49 | parser.error(str(msg)) 50 | return parsed 51 | 52 | if __name__ == "__main__": 53 | params = read_options() 54 | # read all the three splits 55 | 56 | f = open(params['visdial_train']) 57 | input_train = json.load(f) 58 | input_train_data = input_train['data']['dialogs'] 59 | train_questions = input_train['data']['questions'] 60 | train_answers = input_train['data']['answers'] 61 | f.close() 62 | 63 | # read train dense annotations 64 | f = open(params['visdial_train_ndcg']) 65 | input_train_ndcg = json.load(f) 66 | f.close() 67 | 68 | f = open(params['visdial_val']) 69 | input_val = json.load(f) 70 | input_val_data = input_val['data']['dialogs'] 71 | val_questions = input_val['data']['questions'] 72 | val_answers = input_val['data']['answers'] 73 | f.close() 74 | 75 | f = open(params['visdial_val_ndcg']) 76 | input_val_ncdg = json.load(f) 77 | f.close() 78 | 79 | f = open(params['visdial_test']) 80 | input_test = json.load(f) 81 | input_test_data = input_test['data']['dialogs'] 82 | test_questions = input_test['data']['questions'] 83 | test_answers = input_test['data']['answers'] 84 | f.close() 85 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 86 | 87 | max_seq_len = params["max_seq_len"] 88 | num_illegal_train = 0 89 | num_illegal_val = 0 90 | num_illegal_test = 0 91 | # process train 92 | i = 0 93 | while i < len(input_train_data): 94 | cur_dialog = input_train_data[i]['dialog'] 95 | caption = input_train_data[i]['caption'] 96 | tot_len = 22 + len(tokenizer.encode(caption)) # account for 21 sep tokens, CLS token and caption 97 | for rnd in range(len(cur_dialog)): 98 | tot_len += len(tokenizer.encode(train_answers[cur_dialog[rnd]['answer']])) 99 | tot_len += len(tokenizer.encode(train_questions[cur_dialog[rnd]['question']])) 100 | if tot_len <= max_seq_len: 101 | i += 1 102 | else: 103 
| input_train_data.pop(i) 104 | num_illegal_train += 1 105 | 106 | train_data_dense = [] 107 | train_img_id_to_index = {input_train_data[i]['image_id']: i for i in range(len(input_train_data))} 108 | # pre process dense annotations on train 109 | i = 0 110 | while i < len(input_train_ndcg): 111 | remove = False 112 | index = None 113 | train_sample = None 114 | if input_train_ndcg[i]['image_id'] in train_img_id_to_index: 115 | img_id = input_train_ndcg[i]['image_id'] 116 | index = train_img_id_to_index[img_id] 117 | cur_round = input_train_ndcg[i]['round_id'] - 1 118 | train_sample = input_train_data[index] 119 | # check if the sample is legal 120 | caption = train_sample['caption'] 121 | tot_len = 1 # CLS token 122 | tot_len += len(tokenizer.encode(caption)) + 1 123 | for rnd in range(cur_round): 124 | tot_len += len(tokenizer.encode(train_questions[cur_dialog[rnd]['question']])) + 1 125 | if rnd != input_train_ndcg[i]['round_id']: 126 | tot_len += len(tokenizer.encode(train_answers[cur_dialog[rnd]['answer']])) + 1 127 | for option in train_sample['dialog'][cur_round]['answer_options']: 128 | cur_len = len(tokenizer.encode(train_answers[option])) + 1 + tot_len 129 | if cur_len > max_seq_len: 130 | print("image id", img_id) 131 | remove = True 132 | break 133 | else: 134 | remove = True 135 | 136 | if remove: 137 | input_train_ndcg.pop(i) 138 | else: 139 | train_data_dense.append(train_sample.copy()) 140 | i += 1 141 | 142 | assert len(input_train_ndcg) == len(train_data_dense) 143 | print(len(input_train_ndcg)) 144 | input_train_dense = input_train.copy() 145 | input_train_dense['data']['dialogs'] = train_data_dense 146 | 147 | # process val 148 | i = 0 149 | while i < len(input_val_data): 150 | remove = False 151 | cur_dialog = input_val_data[i]['dialog'] 152 | caption = input_val_data[i]['caption'] 153 | tot_len = 1 # CLS token 154 | tot_len += len(tokenizer.encode(caption)) + 1 155 | for rnd in range(len(cur_dialog)): 156 | tot_len += len(tokenizer.encode(val_questions[cur_dialog[rnd]['question']])) + 1 157 | for option in cur_dialog[rnd]['answer_options']: 158 | cur_len = len(tokenizer.encode(val_answers[option])) + 1 + tot_len 159 | if cur_len > max_seq_len: 160 | input_val_data.pop(i) 161 | input_val_ncdg.pop(i) 162 | num_illegal_val += 1 163 | remove = True 164 | break 165 | if not remove: 166 | tot_len += len(tokenizer.encode(val_answers[cur_dialog[rnd]['answer']])) + 1 167 | else: 168 | break 169 | if not remove: 170 | i += 1 171 | 172 | i = 0 173 | # process test 174 | while i < len(input_test_data): 175 | remove = False 176 | cur_dialog = input_test_data[i]['dialog'] 177 | input_test_data[i]['round_id'] = len(cur_dialog) 178 | caption = input_test_data[i]['caption'] 179 | tot_len = 1 # CLS token 180 | tot_len += len(tokenizer.encode(caption)) + 1 181 | for rnd in range(len(cur_dialog)): 182 | tot_len += len(tokenizer.encode(test_questions[cur_dialog[rnd]['question']])) + 1 183 | if rnd != len(cur_dialog)-1: 184 | tot_len += len(tokenizer.encode(test_answers[cur_dialog[rnd]['answer']])) + 1 185 | 186 | max_len_cur_sample = tot_len 187 | 188 | for option in cur_dialog[-1]['answer_options']: 189 | cur_len = len(tokenizer.encode(test_answers[option])) + 1 + tot_len 190 | if cur_len > max_seq_len: 191 | print("image id", input_test_data[i]['image_id']) 192 | print(cur_len) 193 | print(len(cur_dialog)) 194 | # print(cur_dialog) 195 | remove = True 196 | if max_len_cur_sample < cur_len: 197 | max_len_cur_sample = cur_len 198 | if remove: 199 | # need to process this sample by removing 
a few rounds 200 | num_illegal_test += 1 201 | while max_len_cur_sample > max_seq_len: 202 | cur_round_len = len(tokenizer.encode(test_questions[cur_dialog[0]['question']])) + 1 + \ 203 | len(tokenizer.encode(test_answers[cur_dialog[0]['answer']])) + 1 204 | cur_dialog.pop(0) 205 | max_len_cur_sample -= cur_round_len 206 | # print("truncated dialog", cur_dialog) 207 | 208 | i += 1 209 | ''' 210 | # store processed files 211 | ''' 212 | with open(params['save_path_train'],'w') as train_out_file: 213 | json.dump(input_train, train_out_file) 214 | 215 | with open(params['save_path_val'],'w') as val_out_file: 216 | json.dump(input_val, val_out_file) 217 | with open(params['save_path_val_ndcg'],'w') as val_ndcg_out_file: 218 | json.dump(input_val_ncdg, val_ndcg_out_file) 219 | with open(params['save_path_test'],'w') as test_out_file: 220 | json.dump(input_test, test_out_file) 221 | 222 | with open(params['save_path_train_dense_samples'],'w') as train_dense_out_file: 223 | json.dump(input_train_dense, train_dense_out_file) 224 | 225 | with open(params['save_path_train_ndcg'],'w') as train_ndcg_out_file: 226 | json.dump(input_train_ndcg, train_ndcg_out_file) 227 | 228 | # spit stats 229 | 230 | print("number of illegal train samples", num_illegal_train) 231 | print("number of illegal val samples", num_illegal_val) 232 | print("number of illegal test samples", num_illegal_test) -------------------------------------------------------------------------------- /scripts/download_checkpoints.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | mkdir checkpoints-release 4 | 5 | # BaseModel -- Pretrained on Conceptual Captions and VQA 6 | wget https://s3.amazonaws.com/visdial-bert/checkpoints/bestmodel_no_dense_finetuning -O checkpoints-release/basemodel 7 | 8 | # BaseModel + Dense Annotation finetuning 9 | wget https://s3.amazonaws.com/visdial-bert/checkpoints/bestmodel_plus_dense -O checkpoints-release/basemodel_dense 10 | 11 | # BaseModel + Dense Annotation finetuning + NSP loss 12 | wget https://s3.amazonaws.com/visdial-bert/checkpoints/bestmodel_plus_dense_plus_nsp -O checkpoints-release/basemodel_dense_nsp 13 | 14 | # VQA ViLBERT pretrained weights 15 | wget https://s3.amazonaws.com/visdial-bert/checkpoints/vqa_weights -O checkpoints-release/vqa_pretrained_weights 16 | 17 | # Conceptual Caption ViLBERT pretrained weights 18 | wget https://s3.amazonaws.com/visdial-bert/checkpoints/concep_cap_weights -O checkpoints-release/concep_cap_pretrained_weights 19 | -------------------------------------------------------------------------------- /scripts/download_preprocessed.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Processed image features for VisDial v1.0 4 | # To generate these files, look in the preprocessing folder and the corresponding section in the README 5 | mkdir -p data/visdial 6 | mkdir -p data/visdial/visdial_img_feat.lmdb 7 | wget https://s3.amazonaws.com/visdial-bert/data/visdial_image_feats.lmdb/data.mdb -O data/visdial/visdial_img_feat.lmdb/data.mdb 8 | wget https://s3.amazonaws.com/visdial-bert/data/visdial_image_feats.lmdb/lock.mdb -O data/visdial/visdial_img_feat.lmdb/lock.mdb 9 | 10 | # Processed dialog data for VisDial v1.0 11 | wget https://s3.amazonaws.com/visdial-bert/data/visdial_1.0_train_processed.json -O data/visdial/visdial_1.0_train_processed.json 12 | wget https://s3.amazonaws.com/visdial-bert/data/visdial_1.0_val_processed.json -O 
data/visdial/visdial_1.0_val_processed.json 13 | wget https://s3.amazonaws.com/visdial-bert/data/visdial_1.0_test_processed.json -O data/visdial/visdial_1.0_test_processed.json 14 | 15 | # Samples on the train split with the dense annotations 16 | wget https://s3.amazonaws.com/visdial-bert/data/visdial_1.0_train_dense_processed.json -O data/visdial/visdial_1.0_train_dense_processed.json 17 | 18 | # Processed Dense Annotations 19 | wget https://s3.amazonaws.com/visdial-bert/data/visdial_1.0_train_dense_annotations_processed.json -O data/visdial/visdial_1.0_train_dense_annotations_processed.json 20 | wget https://s3.amazonaws.com/visdial-bert/data/visdial_1.0_val_dense_annotations_processed.json -O data/visdial/visdial_1.0_val_dense_annotations_processed.json -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torch.nn import CrossEntropyLoss 7 | from torch.utils.data import DataLoader 8 | import torch.optim as optim 9 | 10 | from dataloader.dataloader_visdial import VisdialDataset 11 | import options 12 | from models.visual_dialog_encoder import VisualDialogEncoder 13 | from utils.visualize import VisdomVisualize 14 | from utils.visdial_metrics import SparseGTMetrics, NDCG, scores_to_ranks 15 | from pytorch_transformers.tokenization_bert import BertTokenizer 16 | from utils.data_utils import sequence_mask, batch_iter 17 | from utils.optim_utils import WarmupLinearScheduleNonZero 18 | 19 | import pprint 20 | from time import gmtime, strftime 21 | from timeit import default_timer as timer 22 | from pytorch_transformers.optimization import AdamW 23 | import os 24 | 25 | import json 26 | import logging 27 | 28 | def forward(dialog_encoder, batch, params, output_nsp_scores=False, output_lm_scores=False, 29 | sample_size=None, evaluation=False): 30 | 31 | tokens = batch['tokens'] 32 | segments = batch['segments'] 33 | sep_indices = batch['sep_indices'] 34 | mask = batch['mask'] 35 | hist_len = batch['hist_len'] 36 | 37 | # image stuff 38 | orig_features = batch['image_feat'] 39 | orig_spatials = batch['image_loc'] 40 | orig_image_mask = batch['image_mask'] 41 | 42 | tokens = tokens.view(-1,tokens.shape[-1]) 43 | segments = segments.view(-1, segments.shape[-1]) 44 | sep_indices = sep_indices.view(-1,sep_indices.shape[-1]) 45 | mask = mask.view(-1, mask.shape[-1]) 46 | hist_len = hist_len.view(-1) 47 | 48 | features = orig_features.view(-1, orig_features.shape[-2], orig_features.shape[-1]) 49 | spatials = orig_spatials.view(-1, orig_spatials.shape[-2], orig_spatials.shape[-1]) 50 | image_mask = orig_image_mask.view(-1, orig_image_mask.shape[-1]) 51 | 52 | if sample_size: 53 | # subsample a random set 54 | sample_indices = torch.randperm(hist_len.shape[0]) 55 | sample_indices = sample_indices[:sample_size] 56 | else: 57 | sample_indices = torch.arange(hist_len.shape[0]) 58 | 59 | tokens = tokens[sample_indices, :] 60 | segments = segments[sample_indices, :] 61 | sep_indices = sep_indices[sample_indices, :] 62 | mask = mask[sample_indices, :] 63 | hist_len = hist_len[sample_indices] 64 | 65 | features = features[sample_indices, : , :] 66 | spatials = spatials[sample_indices, :, :] 67 | image_mask = image_mask[sample_indices, :] 68 | 69 | next_sentence_labels = None 70 | image_target = None 71 | image_label = None 72 | 73 | 
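A few lines below, the attention mask for the language stream is derived by looking up the position of the last separator actually used in each sequence (`hist_len` indexes into `sep_indices`) and masking everything past it. A minimal stand-in, assuming the repository's `sequence_mask` helper behaves like the usual length-to-boolean-mask utility (the numbers here are hypothetical):

```
import torch

def sequence_mask(lengths, max_len):
    # True for positions < length, False for padding (assumed equivalent of the repo helper)
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

# hypothetical mini-batch: separator positions per round, and which SEP is the last one used
sep_indices = torch.tensor([[4, 9, 14], [3, 7, 11]])
hist_len = torch.tensor([2, 1])

sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)).squeeze(1) + 1
attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=16)

print(sequence_lengths.tolist())                         # [15, 8]
print(attention_mask_lm_nsp.int().sum(dim=1).tolist())   # [15, 8] positions attended per row
```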
if not evaluation: 74 | next_sentence_labels = batch['next_sentence_labels'] 75 | next_sentence_labels = next_sentence_labels.view(-1) 76 | next_sentence_labels = next_sentence_labels[sample_indices] 77 | next_sentence_labels = next_sentence_labels.to(params['device']) 78 | 79 | orig_image_target = batch['image_target'] 80 | orig_image_label = batch['image_label'] 81 | 82 | image_target = orig_image_target.view(-1, orig_image_target.shape[-2], orig_image_target.shape[-1]) 83 | image_label = orig_image_label.view(-1, orig_image_label.shape[-1]) 84 | 85 | image_target = image_target[sample_indices, : , :] 86 | image_label = image_label[sample_indices, :] 87 | 88 | image_target = image_target.to(params['device']) 89 | image_label = image_label.to(params['device']) 90 | 91 | tokens = tokens.to(params['device']) 92 | segments = segments.to(params['device']) 93 | sep_indices = sep_indices.to(params['device']) 94 | mask = mask.to(params['device']) 95 | hist_len = hist_len.to(params['device']) 96 | 97 | features = features.to(params['device']) 98 | spatials = spatials.to(params['device']) 99 | image_mask = image_mask.to(params['device']) 100 | 101 | sequence_lengths = torch.gather(sep_indices,1,hist_len.view(-1,1)) + 1 102 | sequence_lengths = sequence_lengths.squeeze(1) 103 | attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1]) 104 | nsp_loss = None 105 | lm_loss = None 106 | loss = None 107 | lm_scores = None 108 | nsp_scores = None 109 | img_loss = None 110 | sep_len = hist_len + 1 111 | 112 | if output_nsp_scores and output_lm_scores: 113 | lm_loss, img_loss, nsp_loss, nsp_scores, lm_scores = dialog_encoder(tokens, features, spatials, sep_indices=sep_indices, 114 | sep_len=sep_len,token_type_ids=segments, masked_lm_labels=mask,attention_mask=attention_mask_lm_nsp \ 115 | ,next_sentence_label=next_sentence_labels, output_nsp_scores=output_nsp_scores, output_lm_scores=output_lm_scores \ 116 | ,image_attention_mask=image_mask, image_label=image_label, image_target=image_target) 117 | elif output_nsp_scores and not output_lm_scores: 118 | lm_loss, img_loss, nsp_loss, nsp_scores = dialog_encoder(tokens, features, spatials, sep_indices=sep_indices, 119 | sep_len=sep_len, token_type_ids=segments, masked_lm_labels=mask,attention_mask=attention_mask_lm_nsp \ 120 | ,next_sentence_label=next_sentence_labels, output_nsp_scores=output_nsp_scores, output_lm_scores=output_lm_scores \ 121 | ,image_attention_mask=image_mask, image_label=image_label, image_target=image_target) 122 | elif output_lm_scores and not output_nsp_scores: 123 | lm_loss, img_loss, nsp_loss, lm_scores = dialog_encoder(tokens, features, spatials, sep_indices=sep_indices, 124 | sep_len=sep_len, token_type_ids=segments, masked_lm_labels=mask,attention_mask=attention_mask_lm_nsp \ 125 | ,next_sentence_label=next_sentence_labels, output_nsp_scores=output_nsp_scores, output_lm_scores=output_lm_scores \ 126 | ,image_attention_mask=image_mask, image_label=image_label, image_target=image_target) 127 | else: 128 | lm_loss, img_loss, nsp_loss = dialog_encoder(tokens, features, spatials, sep_indices=sep_indices, sep_len=sep_len \ 129 | , token_type_ids=segments, masked_lm_labels=mask, attention_mask=attention_mask_lm_nsp \ 130 | , next_sentence_label=next_sentence_labels, output_nsp_scores=output_nsp_scores, output_lm_scores=output_lm_scores \ 131 | , image_attention_mask=image_mask, image_label=image_label, image_target=image_target) 132 | 133 | if not evaluation: 134 | lm_loss = lm_loss.mean() 135 | nsp_loss = 
nsp_loss.mean() 136 | img_loss = img_loss.mean() 137 | loss = (params['lm_loss_coeff'] * lm_loss) + (params['nsp_loss_coeff'] * nsp_loss) + \ 138 | (params['img_loss_coeff'] * img_loss) 139 | 140 | if output_nsp_scores and output_lm_scores: 141 | return loss, lm_loss, nsp_loss, img_loss, nsp_scores, lm_scores 142 | elif output_nsp_scores and not output_lm_scores: 143 | return loss, lm_loss, nsp_loss, img_loss, nsp_scores 144 | elif not output_nsp_scores and output_lm_scores: 145 | return loss, lm_loss, nsp_loss, img_loss, lm_scores 146 | else: 147 | return loss, lm_loss, nsp_loss, img_loss 148 | 149 | def visdial_evaluate(dataloader, params, eval_batch_size, dialog_encoder): 150 | sparse_metrics = SparseGTMetrics() 151 | ndcg = NDCG() 152 | dialog_encoder.eval() 153 | batch_idx = 0 154 | with torch.no_grad(): 155 | # we can fit approximately 500 sequences of length 256 in 8 gpus with 12 GB of memory during inference. 156 | batch_size = 500 * (params['n_gpus']/8) 157 | batch_size = min([1, 2, 4, 5, 100, 1000, 200, 8, 10, 40, 50, 500, 20, 25, 250, 125], \ 158 | key=lambda x: abs(x-batch_size) if x <= batch_size else float("inf")) 159 | print("batch size for evaluation", batch_size) 160 | for epoch_id, _, batch in batch_iter(dataloader, params): 161 | if epoch_id == 1: 162 | break 163 | tokens = batch['tokens'] 164 | num_rounds = tokens.shape[1] 165 | num_options = tokens.shape[2] 166 | tokens = tokens.view(-1, tokens.shape[-1]) 167 | segments = batch['segments'] 168 | segments = segments.view(-1, segments.shape[-1]) 169 | sep_indices = batch['sep_indices'] 170 | sep_indices = sep_indices.view(-1, sep_indices.shape[-1]) 171 | mask = batch['mask'] 172 | mask = mask.view(-1, mask.shape[-1]) 173 | hist_len = batch['hist_len'] 174 | hist_len = hist_len.view(-1) 175 | gt_option_inds = batch['gt_option_inds'] 176 | gt_relevance = batch['gt_relevance'] 177 | gt_relevance_round_id = batch['round_id'].squeeze(1) 178 | 179 | # get image features 180 | features = batch['image_feat'] 181 | spatials = batch['image_loc'] 182 | image_mask = batch['image_mask'] 183 | max_num_regions = features.shape[-2] 184 | features = features.unsqueeze(1).unsqueeze(1).expand(eval_batch_size, num_rounds, num_options, max_num_regions, 2048).contiguous() 185 | spatials = spatials.unsqueeze(1).unsqueeze(1).expand(eval_batch_size, num_rounds, num_options, max_num_regions, 5).contiguous() 186 | image_mask = image_mask.unsqueeze(1).unsqueeze(1).expand(eval_batch_size, num_rounds, num_options, max_num_regions).contiguous() 187 | 188 | features = features.view(-1, max_num_regions, 2048) 189 | spatials = spatials.view(-1, max_num_regions, 5) 190 | image_mask = image_mask.view(-1, max_num_regions) 191 | 192 | assert tokens.shape[0] == segments.shape[0] == sep_indices.shape[0] == mask.shape[0] == \ 193 | hist_len.shape[0] == features.shape[0] == spatials.shape[0] == \ 194 | image_mask.shape[0] == num_rounds * num_options * eval_batch_size 195 | 196 | output = [] 197 | assert (eval_batch_size * num_rounds * num_options)//batch_size == (eval_batch_size * num_rounds * num_options)/batch_size 198 | for j in range((eval_batch_size * num_rounds * num_options)//batch_size): 199 | # create chunks of the original batch 200 | item = {} 201 | item['tokens'] = tokens[j*batch_size:(j+1)*batch_size,:] 202 | item['segments'] = segments[j*batch_size:(j+1)*batch_size,:] 203 | item['sep_indices'] = sep_indices[j*batch_size:(j+1)*batch_size,:] 204 | item['mask'] = mask[j*batch_size:(j+1)*batch_size,:] 205 | item['hist_len'] = 
hist_len[j*batch_size:(j+1)*batch_size] 206 | 207 | item['image_feat'] = features[j*batch_size:(j+1)*batch_size, : , :] 208 | item['image_loc'] = spatials[j*batch_size:(j+1)*batch_size, : , :] 209 | item['image_mask'] = image_mask[j*batch_size:(j+1)*batch_size, :] 210 | 211 | _, _, _, _, nsp_scores = forward(dialog_encoder, item, params, output_nsp_scores=True, evaluation=True) 212 | # normalize nsp scores 213 | nsp_probs = F.softmax(nsp_scores, dim=1) 214 | assert nsp_probs.shape[-1] == 2 215 | output.append(nsp_probs[:,0]) 216 | 217 | output = torch.cat(output,0).view(eval_batch_size, num_rounds, num_options) 218 | sparse_metrics.observe(output, gt_option_inds) 219 | output = output[torch.arange(output.size(0)), gt_relevance_round_id - 1, :] 220 | ndcg.observe(output, gt_relevance) 221 | batch_idx += 1 222 | 223 | dialog_encoder.train() 224 | print("tot eval batches", batch_idx) 225 | all_metrics = {} 226 | all_metrics.update(sparse_metrics.retrieve(reset=True)) 227 | all_metrics.update(ndcg.retrieve(reset=True)) 228 | 229 | return all_metrics 230 | 231 | if __name__ == '__main__': 232 | 233 | params = options.read_command_line() 234 | os.makedirs('checkpoints', exist_ok=True) 235 | if not os.path.exists(params['save_path']): 236 | os.mkdir(params['save_path']) 237 | viz = VisdomVisualize( 238 | enable=bool(params['enable_visdom']), 239 | env_name=params['visdom_env'], 240 | server=params['visdom_server'], 241 | port=params['visdom_server_port']) 242 | pprint.pprint(params) 243 | viz.addText(pprint.pformat(params, indent=4)) 244 | 245 | dataset = VisdialDataset(params) 246 | 247 | dataset.split = 'train' 248 | dataloader = DataLoader( 249 | dataset, 250 | batch_size= params['batch_size']//params['sequences_per_image'] if (params['batch_size']//params['sequences_per_image']) \ 251 | else 1 if not params['overfit'] else 5, 252 | shuffle=True, 253 | num_workers=params['num_workers'], 254 | drop_last=True, 255 | pin_memory=False) 256 | 257 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 258 | params['device'] = device 259 | dialog_encoder = VisualDialogEncoder(params['model_config']) 260 | 261 | param_optimizer = list(dialog_encoder.named_parameters()) 262 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 263 | 264 | langauge_weights = None 265 | with open('config/language_weights.json') as f: 266 | langauge_weights = json.load(f) 267 | 268 | optimizer_grouped_parameters = [] 269 | for key, value in dict(dialog_encoder.named_parameters()).items(): 270 | if value.requires_grad: 271 | if key in langauge_weights: 272 | lr = params['lr'] 273 | else: 274 | lr = params['image_lr'] 275 | 276 | if any(nd in key for nd in no_decay): 277 | optimizer_grouped_parameters += [ 278 | {"params": [value], "lr": lr, "weight_decay": 0} 279 | ] 280 | 281 | if not any(nd in key for nd in no_decay): 282 | optimizer_grouped_parameters += [ 283 | {"params": [value], "lr": lr, "weight_decay": 0.01} 284 | ] 285 | 286 | optimizer = AdamW(optimizer_grouped_parameters, lr=params['lr']) 287 | scheduler = WarmupLinearScheduleNonZero(optimizer, warmup_steps=10000, t_total=200000) 288 | start_iter_id = 0 289 | 290 | if params['start_path']: 291 | 292 | pretrained_dict = torch.load(params['start_path']) 293 | 294 | if not params['continue']: 295 | if 'model_state_dict' in pretrained_dict: 296 | pretrained_dict = pretrained_dict['model_state_dict'] 297 | 298 | model_dict = dialog_encoder.state_dict() 299 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 300 | 
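The optimizer setup earlier in this script builds one parameter group per tensor: weights listed in config/language_weights.json get the language learning rate, all other (visual-stream) weights get `image_lr`, and bias/LayerNorm parameters are exempted from weight decay. A compact sketch of that grouping on a toy module, with `torch.optim.AdamW` standing in for the pytorch_transformers optimizer and a made-up key set:

```
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self):
        super(TinyEncoder, self).__init__()
        self.dense = nn.Linear(4, 4)
        self.LayerNorm = nn.LayerNorm(4)

model = TinyEncoder()
language_keys = {'dense.weight', 'dense.bias'}   # pretend these come from language_weights.json
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
lr_language, lr_image = 2e-5, 2e-5

optimizer_grouped_parameters = []
for key, value in model.named_parameters():
    lr = lr_language if key in language_keys else lr_image
    weight_decay = 0 if any(nd in key for nd in no_decay) else 0.01
    optimizer_grouped_parameters.append(
        {"params": [value], "lr": lr, "weight_decay": weight_decay})

optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr_language)
print([(g["lr"], g["weight_decay"]) for g in optimizer.param_groups])
```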
print("number of keys transferred", len(pretrained_dict)) 301 | assert len(pretrained_dict.keys()) > 0 302 | model_dict.update(pretrained_dict) 303 | dialog_encoder.load_state_dict(model_dict) 304 | del pretrained_dict, model_dict \ 305 | 306 | else: 307 | model_dict = dialog_encoder.state_dict() 308 | optimizer_dict = optimizer.state_dict() 309 | pretrained_dict_model = pretrained_dict['model_state_dict'] 310 | pretrained_dict_optimizer = pretrained_dict['optimizer_state_dict'] 311 | pretrained_dict_scheduler = pretrained_dict['scheduler_state_dict'] 312 | pretrained_dict_model = {k: v for k, v in pretrained_dict_model.items() if k in model_dict} 313 | pretrained_dict_optimizer = {k: v for k, v in pretrained_dict_optimizer.items() if k in optimizer_dict} 314 | model_dict.update(pretrained_dict_model) 315 | optimizer_dict.update(pretrained_dict_optimizer) 316 | dialog_encoder.load_state_dict(model_dict) 317 | optimizer.load_state_dict(optimizer_dict) 318 | for state in optimizer.state.values(): 319 | for k, v in state.items(): 320 | if isinstance(v, torch.Tensor): 321 | state[k] = v.to(device) 322 | scheduler = WarmupLinearScheduleNonZero(optimizer, warmup_steps=10000, \ 323 | t_total=200000, last_epoch=pretrained_dict["iterId"]) 324 | scheduler.load_state_dict(pretrained_dict_scheduler) 325 | start_iter_id = pretrained_dict['iterId'] 326 | 327 | del pretrained_dict, pretrained_dict_model, pretrained_dict_optimizer, pretrained_dict_scheduler, \ 328 | model_dict, optimizer_dict 329 | torch.cuda.empty_cache() 330 | 331 | num_iter_epoch = dataset.numDataPoints['train'] // (params['batch_size'] // params['sequences_per_image'] if (params['batch_size'] // params['sequences_per_image']) \ 332 | else 1 if not params['overfit'] else 5 ) 333 | print('\n%d iter per epoch.' 
% num_iter_epoch) 334 | 335 | dialog_encoder = nn.DataParallel(dialog_encoder) 336 | dialog_encoder.to(device) 337 | 338 | start_t = timer() 339 | optimizer.zero_grad() 340 | 341 | for epoch_id, idx, batch in batch_iter(dataloader, params): 342 | 343 | iter_id = start_iter_id + idx + (epoch_id * num_iter_epoch) 344 | dialog_encoder.train() 345 | # expand image features, 346 | orig_features = batch['image_feat'] 347 | orig_spatials = batch['image_loc'] 348 | orig_image_mask = batch['image_mask'] 349 | orig_image_target = batch['image_target'] 350 | orig_image_label = batch['image_label'] 351 | 352 | num_rounds = batch["tokens"].shape[1] 353 | num_samples = batch["tokens"].shape[2] 354 | 355 | features = orig_features.unsqueeze(1).unsqueeze(1).expand(orig_features.shape[0], num_rounds, num_samples, orig_features.shape[1], orig_features.shape[2]).contiguous() 356 | spatials = orig_spatials.unsqueeze(1).unsqueeze(1).expand(orig_spatials.shape[0], num_rounds, num_samples, orig_spatials.shape[1], orig_spatials.shape[2]).contiguous() 357 | image_label = orig_image_label.unsqueeze(1).unsqueeze(1).expand(orig_image_label.shape[0], num_rounds, num_samples, orig_image_label.shape[1]).contiguous() 358 | image_mask = orig_image_mask.unsqueeze(1).unsqueeze(1).expand(orig_image_mask.shape[0], num_rounds, num_samples, orig_image_mask.shape[1]).contiguous() 359 | image_target = orig_image_target.unsqueeze(1).unsqueeze(1).expand(orig_image_target.shape[0], num_rounds, num_samples, orig_image_target.shape[1], orig_image_target.shape[2]).contiguous() 360 | 361 | batch['image_feat'] = features.contiguous() 362 | batch['image_loc'] = spatials.contiguous() 363 | batch['image_mask'] = image_mask.contiguous() 364 | batch['image_target'] = image_target.contiguous() 365 | batch['image_label'] = image_label.contiguous() 366 | 367 | if params['overfit']: 368 | sample_size = 48 369 | else: 370 | sample_size = params['batch_size'] 371 | 372 | loss = None 373 | lm_loss = None 374 | nsp_loss = None 375 | img_loss = None 376 | nsp_loss = None 377 | nsp_scores = None 378 | 379 | loss, lm_loss, nsp_loss, img_loss = forward(dialog_encoder, batch, params, sample_size=sample_size) 380 | 381 | lm_nsp_loss = None 382 | if lm_loss is not None and nsp_loss is not None: 383 | lm_nsp_loss = lm_loss + nsp_loss 384 | loss /= params['batch_multiply'] 385 | loss.backward() 386 | scheduler.step() 387 | 388 | if iter_id % params['batch_multiply'] == 0 and iter_id > 0: 389 | optimizer.step() 390 | optimizer.zero_grad() 391 | 392 | if iter_id % 10 == 0: 393 | end_t = timer() 394 | cur_epoch = float(iter_id) / num_iter_epoch 395 | timestamp = strftime('%a %d %b %y %X', gmtime()) 396 | print_lm_loss = 0 397 | print_nsp_loss = 0 398 | print_lm_nsp_loss = 0 399 | print_img_loss = 0 400 | 401 | if lm_loss is not None: 402 | print_lm_loss = lm_loss.item() 403 | if nsp_loss is not None: 404 | print_nsp_loss = nsp_loss.item() 405 | if lm_nsp_loss is not None: 406 | print_lm_nsp_loss = lm_nsp_loss.item() 407 | if img_loss is not None: 408 | print_img_loss = img_loss.item() 409 | 410 | print_format = '[%s][Ep: %.2f][Iter: %d][Time: %5.2fs][NSP + LM Loss: %.3g][LM Loss: %.3g][NSP Loss: %.3g][IMG Loss: %.3g]' 411 | print_info = [ 412 | timestamp, cur_epoch, iter_id, end_t - start_t, print_lm_nsp_loss, print_lm_loss, print_nsp_loss, print_img_loss 413 | ] 414 | print(print_format % tuple(print_info)) 415 | start_t = end_t 416 | 417 | # Update line plots 418 | viz.linePlot(iter_id, loss.item(), 'loss', 'tot loss') 419 | if lm_nsp_loss is not None: 420 | 
viz.linePlot(iter_id, lm_nsp_loss.item(), 'loss', 'lm + nsp loss') 421 | if lm_loss is not None: 422 | viz.linePlot(iter_id, lm_loss.item(),'loss', 'lm loss') 423 | if nsp_loss is not None: 424 | viz.linePlot(iter_id, nsp_loss.item(), 'loss', 'nsp loss') 425 | if img_loss is not None: 426 | viz.linePlot(iter_id, img_loss.item(), 'loss', 'img loss') 427 | 428 | old_num_iter_epoch = num_iter_epoch 429 | if params['overfit']: 430 | num_iter_epoch = 100 431 | if iter_id % num_iter_epoch == 0 and iter_id > 0: 432 | torch.save({'model_state_dict' : dialog_encoder.module.state_dict(),'scheduler_state_dict':scheduler.state_dict() \ 433 | ,'optimizer_state_dict': optimizer.state_dict(), 'iter_id':iter_id}, os.path.join(params['save_path'], 'visdial_dialog_encoder_%d.ckpt'%iter_id)) 434 | 435 | if iter_id % num_iter_epoch == 0 and iter_id > 0: 436 | viz.save() 437 | # fire evaluation 438 | print("num iteration for eval", num_iter_epoch) 439 | if ((iter_id % (num_iter_epoch * (8 // params['sequences_per_image']))) == 0) and iter_id > 0: 440 | eval_batch_size = 2 441 | if params['overfit']: 442 | eval_batch_size = 5 443 | 444 | dataset.split = 'val' 445 | # each image will need 1000 forward passes, (100 at each round x 10 rounds). 446 | dataloader = DataLoader( 447 | dataset, 448 | batch_size=eval_batch_size, 449 | shuffle=False, 450 | num_workers=params['num_workers'], 451 | drop_last=True, 452 | pin_memory=False) 453 | all_metrics = visdial_evaluate(dataloader, params, eval_batch_size, dialog_encoder) 454 | for metric_name, metric_value in all_metrics.items(): 455 | print(f"{metric_name}: {metric_value}") 456 | if 'round' in metric_name: 457 | viz.linePlot(iter_id, metric_value, 'Retrieval Round Val Metrics Round -' + metric_name.split('_')[-1], metric_name) 458 | else: 459 | viz.linePlot(iter_id, metric_value, 'Retrieval Val Metrics', metric_name) 460 | 461 | dataset.split = 'train' 462 | 463 | num_iter_epoch = old_num_iter_epoch 464 | -------------------------------------------------------------------------------- /train_language_only_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torch.nn import CrossEntropyLoss 7 | from torch.utils.data import DataLoader 8 | import torch.optim as optim 9 | 10 | from dataloader.dataloader_visdial import VisdialDataset 11 | import options 12 | from models.language_only_dialog_encoder import DialogEncoder 13 | from utils.visualize import VisdomVisualize 14 | from utils.visdial_metrics import SparseGTMetrics, NDCG, scores_to_ranks 15 | from utils.data_utils import sequence_mask, batch_iter 16 | from utils.optim_utils import WarmupLinearScheduleNonZero 17 | 18 | from pytorch_transformers.tokenization_bert import BertTokenizer 19 | from pytorch_transformers.optimization import AdamW 20 | import os 21 | import pprint 22 | from time import gmtime, strftime 23 | from timeit import default_timer as timer 24 | 25 | def forward(dialog_encoder, batch, params, output_nsp_scores=False, output_lm_scores=False, sample_size=None, evaluation=False): 26 | 27 | tokens = batch['tokens'] 28 | segments = batch['segments'] 29 | sep_indices = batch['sep_indices'] 30 | mask = batch['mask'] 31 | hist_len = batch['hist_len'] 32 | 33 | tokens = tokens.view(-1,tokens.shape[-1]) 34 | segments = segments.view(-1, segments.shape[-1]) 35 | sep_indices = 
sep_indices.view(-1,sep_indices.shape[-1]) 36 | mask = mask.view(-1, mask.shape[-1]) 37 | hist_len = hist_len.view(-1) 38 | 39 | if sample_size: 40 | # subsample a random set 41 | sample_indices = torch.randperm(hist_len.shape[0]) 42 | sample_indices = sample_indices[:sample_size] 43 | else: 44 | sample_indices = torch.arange(hist_len.shape[0]) 45 | 46 | tokens = tokens[sample_indices, :] 47 | segments = segments[sample_indices, :] 48 | sep_indices = sep_indices[sample_indices, :] 49 | mask = mask[sample_indices, :] 50 | hist_len = hist_len[sample_indices] 51 | 52 | next_sentence_labels = None 53 | 54 | if not evaluation: 55 | next_sentence_labels = batch['next_sentence_labels'] 56 | next_sentence_labels = next_sentence_labels.view(-1) 57 | next_sentence_labels = next_sentence_labels[sample_indices] 58 | next_sentence_labels = next_sentence_labels.to(params['device']) 59 | 60 | tokens = tokens.to(params['device']) 61 | segments = segments.to(params['device']) 62 | sep_indices = sep_indices.to(params['device']) 63 | mask = mask.to(params['device']) 64 | hist_len = hist_len.to(params['device']) 65 | 66 | sequence_lengths = torch.gather(sep_indices,1,hist_len.view(-1,1)) + 1 67 | sequence_lengths = sequence_lengths.squeeze(1) 68 | attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1]) 69 | nsp_scores = None 70 | nsp_loss = None 71 | lm_loss = None 72 | loss = None 73 | lm_scores = None 74 | 75 | sep_len = hist_len + 1 76 | 77 | if output_nsp_scores and output_lm_scores: 78 | lm_nsp_loss, lm_loss, nsp_loss, nsp_scores, lm_scores = dialog_encoder(tokens,sep_indices=sep_indices, sep_len=sep_len\ 79 | ,token_type_ids=segments, masked_lm_labels=mask,attention_mask=attention_mask_lm_nsp\ 80 | ,next_sentence_label=next_sentence_labels,output_nsp_scores=output_nsp_scores, output_lm_scores=output_lm_scores) 81 | elif output_nsp_scores and not output_lm_scores: 82 | lm_nsp_loss, lm_loss, nsp_loss, nsp_scores = dialog_encoder(tokens,sep_indices=sep_indices, sep_len=sep_len\ 83 | ,token_type_ids=segments, masked_lm_labels=mask,attention_mask=attention_mask_lm_nsp\ 84 | ,next_sentence_label=next_sentence_labels,output_nsp_scores=output_nsp_scores, output_lm_scores=output_lm_scores) 85 | elif not output_nsp_scores and output_lm_scores: 86 | lm_nsp_loss, lm_loss, nsp_loss, lm_scores = dialog_encoder(tokens,sep_indices=sep_indices, sep_len=sep_len\ 87 | ,token_type_ids=segments, masked_lm_labels=mask,attention_mask=attention_mask_lm_nsp\ 88 | ,next_sentence_label=next_sentence_labels,output_nsp_scores=output_nsp_scores, output_lm_scores=output_lm_scores) 89 | else: 90 | lm_nsp_loss, lm_loss, nsp_loss = dialog_encoder(tokens,sep_indices=sep_indices, sep_len=sep_len\ 91 | ,token_type_ids=segments, masked_lm_labels=mask,attention_mask=attention_mask_lm_nsp\ 92 | ,next_sentence_label=next_sentence_labels,output_nsp_scores=output_nsp_scores, output_lm_scores=output_lm_scores) 93 | 94 | if not evaluation: 95 | lm_loss = lm_loss.mean() 96 | nsp_loss = nsp_loss.mean() 97 | loss = (params['lm_loss_coeff'] * lm_loss) + (params['nsp_loss_coeff'] * nsp_loss) 98 | lm_nsp_loss = loss 99 | 100 | if output_nsp_scores and output_lm_scores: 101 | return loss, lm_loss, nsp_loss, nsp_scores, lm_scores 102 | elif output_nsp_scores and not output_lm_scores: 103 | return loss, lm_loss, nsp_loss, nsp_scores 104 | elif not output_nsp_scores and output_lm_scores: 105 | return loss, lm_loss, nsp_loss, lm_scores 106 | else: 107 | return loss, lm_loss, nsp_loss 108 | 109 | def visdial_evaluate(dataloader, 
params, eval_batch_size): 110 | sparse_metrics = SparseGTMetrics() 111 | ndcg = NDCG() 112 | dialog_encoder.eval() 113 | batch_idx = 0 114 | with torch.no_grad(): 115 | batch_size = 500 * (params['n_gpus']/8) 116 | batch_size = min([1, 2, 4, 5, 100, 1000, 200, 8, 10, 40, 50, 500, 20, 25, 250, 125], \ 117 | key=lambda x: abs(x-batch_size) if x <= batch_size else float("inf")) 118 | if params['overfit']: 119 | batch_size = 100 120 | for epoch_id, _, batch in batch_iter(dataloader, params): 121 | if epoch_id == 1: 122 | break 123 | tokens = batch['tokens'] 124 | num_rounds = tokens.shape[1] 125 | num_options = tokens.shape[2] 126 | tokens = tokens.view(-1, tokens.shape[-1]) 127 | segments = batch['segments'] 128 | segments = segments.view(-1, segments.shape[-1]) 129 | sep_indices = batch['sep_indices'] 130 | sep_indices = sep_indices.view(-1, sep_indices.shape[-1]) 131 | mask = batch['mask'] 132 | mask = mask.view(-1, mask.shape[-1]) 133 | hist_len = batch['hist_len'] 134 | hist_len = hist_len.view(-1) 135 | gt_option_inds = batch['gt_option_inds'] 136 | gt_relevance = batch['gt_relevance'] 137 | gt_relevance_round_id = batch['round_id'].squeeze(1) 138 | 139 | assert tokens.shape[0] == segments.shape[0] == sep_indices.shape[0] == mask.shape[0] == \ 140 | hist_len.shape[0] == num_rounds * num_options * eval_batch_size 141 | output = [] 142 | assert (eval_batch_size * num_rounds * num_options)//batch_size == (eval_batch_size * num_rounds * num_options)/batch_size 143 | for j in range((eval_batch_size * num_rounds * num_options)//batch_size): 144 | # create chunks of the original batch 145 | item = {} 146 | item['tokens'] = tokens[j*batch_size:(j+1)*batch_size,:] 147 | item['segments'] = segments[j*batch_size:(j+1)*batch_size,:] 148 | item['sep_indices'] = sep_indices[j*batch_size:(j+1)*batch_size,:] 149 | item['mask'] = mask[j*batch_size:(j+1)*batch_size,:] 150 | item['hist_len'] = hist_len[j*batch_size:(j+1)*batch_size] 151 | _, _, _, nsp_scores = forward(dialog_encoder, item, params ,output_nsp_scores=True, evaluation=True) 152 | # normalize nsp scores 153 | nsp_probs = F.softmax(nsp_scores, dim=1) 154 | output.append(nsp_probs[:,0]) 155 | 156 | output = torch.cat(output,0).view(eval_batch_size, num_rounds, num_options) 157 | sparse_metrics.observe(output, gt_option_inds) 158 | output = output[torch.arange(output.size(0)), gt_relevance_round_id - 1, :] 159 | ndcg.observe(output, gt_relevance) 160 | batch_idx += 1 161 | 162 | dialog_encoder.train() 163 | all_metrics = {} 164 | all_metrics.update(sparse_metrics.retrieve(reset=True)) 165 | all_metrics.update(ndcg.retrieve(reset=True)) 166 | 167 | return all_metrics 168 | 169 | if __name__ == '__main__': 170 | 171 | params = options.read_command_line() 172 | os.makedirs('checkpoints', exist_ok=True) 173 | if not os.path.exists(params['save_path']): 174 | os.mkdir(params['save_path']) 175 | viz = VisdomVisualize( 176 | enable=bool(params['enable_visdom']), 177 | env_name=params['visdom_env'], 178 | server=params['visdom_server'], 179 | port=params['visdom_server_port']) 180 | pprint.pprint(params) 181 | viz.addText(pprint.pformat(params, indent=4)) 182 | 183 | dataset = VisdialDataset(params) 184 | 185 | dataset.split = 'train' 186 | dataloader = DataLoader( 187 | dataset, 188 | batch_size= params['batch_size']//params['sequences_per_image'] if (params['batch_size']//params['sequences_per_image']) \ 189 | else 1 if not params['overfit'] else 5, 190 | shuffle=False, 191 | num_workers=params['num_workers'], 192 | drop_last=True, 193 | 
pin_memory=False) 194 | 195 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 196 | params['device'] = device 197 | dialog_encoder = DialogEncoder() 198 | 199 | param_optimizer = list(dialog_encoder.named_parameters()) 200 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 201 | optimizer_grouped_parameters = [ 202 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 203 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 204 | ] 205 | 206 | optimizer = AdamW(optimizer_grouped_parameters, lr=params['lr']) 207 | scheduler = WarmupLinearScheduleNonZero(optimizer, warmup_steps=10000, t_total=200000) 208 | start_iter_id = 0 209 | 210 | if params['start_path']: 211 | pretrained_dict = torch.load(params['start_path']) 212 | if not params['continue']: 213 | model_dict = dialog_encoder.state_dict() 214 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 215 | print("pretrained dict", pretrained_dict) 216 | assert len(pretrained_dict.keys()) > 0 217 | model_dict.update(pretrained_dict) 218 | dialog_encoder.load_state_dict(model_dict) 219 | else: 220 | model_dict = dialog_encoder.state_dict() 221 | optimizer_dict = optimizer.state_dict() 222 | pretrained_dict_model = pretrained_dict['model_state_dict'] 223 | pretrained_dict_optimizer = pretrained_dict['optimizer_state_dict'] 224 | pretrained_dict_scheduler = pretrained_dict['scheduler_state_dict'] 225 | pretrained_dict_model = {k: v for k, v in pretrained_dict_model.items() if k in model_dict} 226 | pretrained_dict_optimizer = {k: v for k, v in pretrained_dict_optimizer.items() if k in optimizer_dict} 227 | model_dict.update(pretrained_dict_model) 228 | optimizer_dict.update(pretrained_dict_optimizer) 229 | dialog_encoder.load_state_dict(model_dict) 230 | optimizer.load_state_dict(optimizer_dict) 231 | for state in optimizer.state.values(): 232 | for k, v in state.items(): 233 | if isinstance(v, torch.Tensor): 234 | state[k] = v.cuda() 235 | scheduler = WarmupLinearScheduleNonZero(optimizer, warmup_steps=10000, \ 236 | t_total=200000, last_epoch=pretrained_dict["iterId"]) 237 | scheduler.load_state_dict(pretrained_dict_scheduler) 238 | start_iter_id = pretrained_dict['iterId'] 239 | 240 | num_iter_per_epoch = dataset.numDataPoints['train'] // (params['batch_size'] // params['sequences_per_image'] if (params['batch_size'] // params['sequences_per_image']) \ 241 | else 1 if not params['overfit'] else 5 ) 242 | print('\n%d iter per epoch.' 
% num_iter_per_epoch) 243 | 244 | dialog_encoder = nn.DataParallel(dialog_encoder) 245 | dialog_encoder.to(device) 246 | 247 | start_t = timer() 248 | optimizer.zero_grad() 249 | for epoch_id, idx, batch in batch_iter(dataloader, params): 250 | iter_id = start_iter_id + idx + (epoch_id * num_iter_per_epoch) 251 | dialog_encoder.train() 252 | if not params['overfit']: 253 | loss, lm_loss, nsp_loss = forward(dialog_encoder, batch, params, sample_size=params['batch_size']) 254 | else: 255 | sample_size = 64 256 | loss, lm_loss, nsp_loss = forward(dialog_encoder, batch, params, sample_size=sample_size) 257 | 258 | lm_nsp_loss = None 259 | if lm_loss is not None and nsp_loss is not None: 260 | lm_nsp_loss = lm_loss + nsp_loss 261 | loss /= params['batch_multiply'] 262 | loss.backward() 263 | scheduler.step() 264 | 265 | if iter_id % params['batch_multiply'] == 0 and iter_id > 0: 266 | optimizer.step() 267 | optimizer.zero_grad() 268 | 269 | if iter_id % 10 == 0: 270 | end_t = timer() 271 | curEpoch = float(iter_id) / num_iter_per_epoch 272 | timeStamp = strftime('%a %d %b %y %X', gmtime()) 273 | print_lm_loss = 0 274 | print_nsp_loss = 0 275 | print_inconsistency_loss = 0 276 | print_lm_nsp_loss = 0 277 | if lm_loss is not None: 278 | print_lm_loss = lm_loss.item() 279 | if nsp_loss is not None: 280 | print_nsp_loss = nsp_loss.item() 281 | if lm_nsp_loss is not None: 282 | print_lm_nsp_loss = lm_nsp_loss.item() 283 | 284 | printFormat = '[%s][Ep: %.2f][Iter: %d][Time: %5.2fs][NSP + LM Loss: %.3g][LM Loss: %.3g][NSP Loss: %.3g]' 285 | printInfo = [ 286 | timeStamp, curEpoch, iter_id, end_t - start_t, print_lm_nsp_loss, print_lm_loss, print_nsp_loss 287 | ] 288 | print(printFormat % tuple(printInfo)) 289 | 290 | start_t = end_t 291 | # Update line plots 292 | viz.linePlot(iter_id, loss.item(), 'loss', 'tot loss') 293 | if lm_nsp_loss is not None: 294 | viz.linePlot(iter_id, lm_nsp_loss.item(), 'loss', 'lm + nsp loss') 295 | if lm_loss is not None: 296 | viz.linePlot(iter_id, lm_loss.item(),'loss', 'lm loss') 297 | if nsp_loss is not None: 298 | viz.linePlot(iter_id, nsp_loss.item(), 'loss', 'nsp loss') 299 | 300 | old_num_iter_per_epoch = num_iter_per_epoch 301 | if params['overfit']: 302 | num_iter_per_epoch = 100 303 | if iter_id % num_iter_per_epoch == 0: 304 | torch.save({'model_state_dict' : dialog_encoder.module.state_dict(),'scheduler_state_dict':scheduler.state_dict() \ 305 | ,'optimizer_state_dict': optimizer.state_dict(), 'iter_id':iter_id}, os.path.join(params['save_path'], 'visdial_dialog_encoder_%d.ckpt'%iter_id)) 306 | 307 | if iter_id % num_iter_per_epoch == 0: 308 | viz.save() 309 | # fire evaluation 310 | print("num iteration for eval", num_iter_per_epoch * (8 // params['sequences_per_image'])) 311 | if ((iter_id % (num_iter_per_epoch * (8 // params['sequences_per_image']))) == 0) and iter_id > 0: 312 | eval_batch_size = 2 313 | if params['overfit']: 314 | eval_batch_size = 5 315 | 316 | dataset.split = 'val' 317 | # each image will need 1000 forward passes, (100 at each round x 10 rounds). 
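            # i.e. 100 answer options scored at each of the 10 rounds; with eval_batch_size = 2
            # that is 2 x 10 x 100 = 2,000 candidate sequences per val batch, which
            # visdial_evaluate above splits into smaller NSP sub-batches before reshaping
            # back to (eval_batch_size, num_rounds, num_options).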
318 | dataloader = DataLoader( 319 | dataset, 320 | batch_size=eval_batch_size, 321 | shuffle=False, 322 | num_workers=params['num_workers'], 323 | drop_last=True, 324 | pin_memory=False) 325 | all_metrics = visdial_evaluate(dataloader, params, eval_batch_size) 326 | for metric_name, metric_value in all_metrics.items(): 327 | print(f"{metric_name}: {metric_value}") 328 | if 'round' in metric_name: 329 | viz.linePlot(iter_id, metric_value, 'Retrieval Round Val Metrics Round -' + metric_name.split('_')[-1], metric_name) 330 | else: 331 | viz.linePlot(iter_id, metric_value, 'Retrieval Val Metrics', metric_name) 332 | 333 | dataset.split = 'train' 334 | 335 | num_iter_per_epoch = old_num_iter_per_epoch -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vmurahari3/visdial-bert/87e264794c45cc5c8c1ea243ad9d2b4d94a44faf/utils/__init__.py -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import random 4 | import numpy as np 5 | 6 | def sequence_mask(sequence_length, max_len=None): 7 | if max_len is None: 8 | max_len = sequence_length.data.max() 9 | batch_size = sequence_length.size(0) 10 | seq_range = torch.range(0, max_len - 1).long() 11 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) 12 | seq_range_expand = Variable(seq_range_expand) 13 | if sequence_length.is_cuda: 14 | seq_range_expand = seq_range_expand.cuda() 15 | seq_length_expand = (sequence_length.unsqueeze(1) 16 | .expand_as(seq_range_expand)) 17 | return seq_range_expand < seq_length_expand 18 | 19 | def batch_iter(dataloader, params): 20 | for epochId in range(params['num_epochs']): 21 | for idx, batch in enumerate(dataloader): 22 | yield epochId, idx, batch 23 | 24 | def list2tensorpad(inp_list,max_seq_len): 25 | 26 | inp_tensor = torch.LongTensor([inp_list]) 27 | inp_tensor_zeros = torch.zeros(1, max_seq_len, dtype=torch.long) 28 | inp_tensor_zeros[0,:inp_tensor.shape[1]] = inp_tensor 29 | inp_tensor = inp_tensor_zeros 30 | return inp_tensor 31 | 32 | def encode_input(utterances, start_segment, CLS, SEP, MASK, max_seq_len=256,max_sep_len=25,mask_prob=0.2): 33 | 34 | cur_segment = start_segment 35 | token_id_list = [] 36 | segment_id_list = [] 37 | sep_token_indices = [] 38 | masked_token_list = [] 39 | 40 | token_id_list.append(CLS) 41 | segment_id_list.append(cur_segment) 42 | masked_token_list.append(0) 43 | 44 | cur_sep_token_index = 0 45 | 46 | for cur_utterance in utterances: 47 | # add the masked token and keep track 48 | cur_masked_index = [1 if random.random() < mask_prob else 0 for _ in range(len(cur_utterance))] 49 | masked_token_list.extend(cur_masked_index) 50 | token_id_list.extend(cur_utterance) 51 | segment_id_list.extend([cur_segment]*len(cur_utterance)) 52 | 53 | token_id_list.append(SEP) 54 | segment_id_list.append(cur_segment) 55 | masked_token_list.append(0) 56 | cur_sep_token_index = cur_sep_token_index + len(cur_utterance) + 1 57 | sep_token_indices.append(cur_sep_token_index) 58 | cur_segment = cur_segment ^ 1 # cur segment osciallates between 0 and 1 59 | 60 | assert len(segment_id_list) == len(token_id_list) == len(masked_token_list) == sep_token_indices[-1] + 1 61 | # convert to tensors and pad to maximum seq length 62 | tokens = 
list2tensorpad(token_id_list,max_seq_len) 63 | masked_tokens = list2tensorpad(masked_token_list,max_seq_len) 64 | masked_tokens[0,masked_tokens[0,:]==0] = -1 65 | mask = masked_tokens[0,:]==1 66 | masked_tokens[0,mask] = tokens[0,mask] 67 | tokens[0,mask] = MASK 68 | 69 | # print("mask", mask) 70 | # print("tokens", tokens) 71 | # print("masked tokens", masked_tokens) 72 | # print("num mask tokens", torch.sum(mask)) 73 | 74 | segment_id_list = list2tensorpad(segment_id_list,max_seq_len) 75 | # segment_id_list += 2 76 | return tokens, segment_id_list, list2tensorpad(sep_token_indices,max_sep_len),masked_tokens 77 | 78 | def encode_image_input(features, num_boxes, boxes, image_target, max_regions=37, mask_prob=0.15): 79 | output_label = [] 80 | num_boxes = min(int(num_boxes), max_regions) 81 | 82 | mix_boxes_pad = np.zeros((max_regions, boxes.shape[-1])) 83 | mix_features_pad = np.zeros((max_regions, features.shape[-1])) 84 | mix_image_target = np.zeros((max_regions, image_target.shape[-1])) 85 | 86 | mix_boxes_pad[:num_boxes] = boxes[:num_boxes] 87 | mix_features_pad[:num_boxes] = features[:num_boxes] 88 | mix_image_target[:num_boxes] = image_target[:num_boxes] 89 | 90 | boxes = mix_boxes_pad 91 | features = mix_features_pad 92 | image_target = mix_image_target 93 | 94 | for i in range(num_boxes): 95 | prob = random.random() 96 | # mask token with 15% probability 97 | if prob < mask_prob: 98 | prob /= mask_prob 99 | 100 | # 80% randomly change token to mask token 101 | if prob < 0.9: 102 | features[i] = 0 103 | output_label.append(1) 104 | else: 105 | # no masking token (will be ignored by loss function later) 106 | output_label.append(-1) 107 | 108 | image_mask = [1] * (int(num_boxes)) 109 | while len(image_mask) < max_regions: 110 | image_mask.append(0) 111 | output_label.append(-1) 112 | 113 | # ensure we have atleast one region being predicted 114 | output_label[random.randint(1,len(output_label)-1)] = 1 115 | image_label = torch.LongTensor(output_label) 116 | image_label[0] = 0 # make sure the token doesn't contribute to the masked loss 117 | image_mask = torch.tensor(image_mask).float() 118 | 119 | features = torch.tensor(features).float() 120 | spatials = torch.tensor(boxes).float() 121 | image_target = torch.tensor(image_target).float() 122 | return features, spatials, image_mask, image_target, image_label 123 | -------------------------------------------------------------------------------- /utils/image_features_reader.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import csv 3 | import h5py 4 | import numpy as np 5 | import copy 6 | import pickle 7 | import lmdb # install lmdb by "pip install lmdb" 8 | import base64 9 | import pdb 10 | 11 | class ImageFeaturesH5Reader(object): 12 | """ 13 | A reader for H5 files containing pre-extracted image features. A typical 14 | H5 file is expected to have a column named "image_id", and another column 15 | named "features". 16 | 17 | Example of an H5 file: 18 | ``` 19 | faster_rcnn_bottomup_features.h5 20 | |--- "image_id" [shape: (num_images, )] 21 | |--- "features" [shape: (num_images, num_proposals, feature_size)] 22 | +--- .attrs ("split", "train") 23 | ``` 24 | Parameters 25 | ---------- 26 | features_h5path : str 27 | Path to an H5 file containing COCO train / val image features. 28 | in_memory : bool 29 | Whether to load the whole H5 file in memory. Beware, these files are 30 | sometimes tens of GBs in size. 
Set this to true if you have sufficient 31 | RAM - trade-off between speed and memory. 32 | """ 33 | def __init__(self, features_path: str, in_memory: bool = False): 34 | self.features_path = features_path 35 | self._in_memory = in_memory 36 | 37 | # with h5py.File(self.features_h5path, "r", libver='latest', swmr=True) as features_h5: 38 | # self._image_ids = list(features_h5["image_ids"]) 39 | # If not loaded in memory, then list of None. 40 | self.env = lmdb.open(self.features_path, max_readers=1, readonly=True, 41 | lock=False, readahead=False, meminit=False) 42 | 43 | with self.env.begin(write=False) as txn: 44 | self._image_ids = pickle.loads(txn.get('keys'.encode())) 45 | 46 | self.features = [None] * len(self._image_ids) 47 | self.num_boxes = [None] * len(self._image_ids) 48 | self.boxes = [None] * len(self._image_ids) 49 | self.boxes_ori = [None] * len(self._image_ids) 50 | self.cls_prob = [None] * len(self._image_ids) 51 | 52 | def __len__(self): 53 | return len(self._image_ids) 54 | 55 | def __getitem__(self, image_id): 56 | image_id = str(image_id).encode() 57 | index = self._image_ids.index(image_id) 58 | if self._in_memory: 59 | # Load features during first epoch, all not loaded together as it 60 | # has a slow start. 61 | if self.features[index] is not None: 62 | features = self.features[index] 63 | num_boxes = self.num_boxes[index] 64 | image_location = self.boxes[index] 65 | image_location_ori = self.boxes_ori[index] 66 | cls_prob = self.cls_prob[index] 67 | else: 68 | with self.env.begin(write=False) as txn: 69 | item = pickle.loads(txn.get(image_id)) 70 | image_id = item['image_id'] 71 | image_h = int(item['image_h']) 72 | image_w = int(item['image_w']) 73 | num_boxes = int(item['num_boxes']) 74 | features = np.frombuffer(base64.b64decode(item["features"]), dtype=np.float32).reshape(num_boxes, 2048) 75 | boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape(num_boxes, 4) 76 | 77 | cls_prob = np.frombuffer(base64.b64decode(item['cls_prob']), dtype=np.float32).reshape(num_boxes, 1601) 78 | # add an extra row at the top for the tokens 79 | g_cls_prob = np.zeros(1601, dtype=np.float32) 80 | g_cls_prob[0] = 1 81 | cls_prob = np.concatenate([np.expand_dims(g_cls_prob,axis=0), cls_prob], axis=0) 82 | 83 | self.cls_prob[index] = cls_prob 84 | 85 | g_feat = np.sum(features, axis=0) / num_boxes 86 | num_boxes = num_boxes + 1 87 | 88 | features = np.concatenate([np.expand_dims(g_feat, axis=0), features], axis=0) 89 | self.features[index] = features 90 | 91 | image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32) 92 | image_location[:,:4] = boxes 93 | image_location[:,4] = (image_location[:,3] - image_location[:,1]) * (image_location[:,2] - image_location[:,0]) / (float(image_w) * float(image_h)) 94 | 95 | image_location_ori = copy.deepcopy(image_location) 96 | 97 | image_location[:,0] = image_location[:,0] / float(image_w) 98 | image_location[:,1] = image_location[:,1] / float(image_h) 99 | image_location[:,2] = image_location[:,2] / float(image_w) 100 | image_location[:,3] = image_location[:,3] / float(image_h) 101 | 102 | g_location = np.array([0,0,1,1,1]) 103 | image_location = np.concatenate([np.expand_dims(g_location, axis=0), image_location], axis=0) 104 | self.boxes[index] = image_location 105 | 106 | g_location_ori = np.array([0,0,image_w,image_h,image_w*image_h]) 107 | image_location_ori = np.concatenate([np.expand_dims(g_location_ori, axis=0), image_location_ori], axis=0) 108 | self.boxes_ori[index] = image_location_ori 109 | 
self.num_boxes[index] = num_boxes 110 | else: 111 | # Read chunk from file everytime if not loaded in memory. 112 | with self.env.begin(write=False) as txn: 113 | item = pickle.loads(txn.get(image_id)) 114 | image_id = item['image_id'] 115 | image_h = int(item['image_h']) 116 | image_w = int(item['image_w']) 117 | num_boxes = int(item['num_boxes']) 118 | cls_prob = np.frombuffer(base64.b64decode(item['cls_prob']), dtype=np.float32).reshape(num_boxes, 1601) 119 | # add an extra row at the top for the tokens 120 | g_cls_prob = np.zeros(1601, dtype=np.float32) 121 | g_cls_prob[0] = 1 122 | cls_prob = np.concatenate([np.expand_dims(g_cls_prob,axis=0), cls_prob], axis=0) 123 | 124 | features = np.frombuffer(base64.b64decode(item["features"]), dtype=np.float32).reshape(num_boxes, 2048) 125 | boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape(num_boxes, 4) 126 | g_feat = np.sum(features, axis=0) / num_boxes 127 | num_boxes = num_boxes + 1 128 | features = np.concatenate([np.expand_dims(g_feat, axis=0), features], axis=0) 129 | 130 | image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32) 131 | image_location[:,:4] = boxes 132 | image_location[:,4] = (image_location[:,3] - image_location[:,1]) * (image_location[:,2] - image_location[:,0]) / (float(image_w) * float(image_h)) 133 | 134 | image_location_ori = copy.deepcopy(image_location) 135 | image_location[:,0] = image_location[:,0] / float(image_w) 136 | image_location[:,1] = image_location[:,1] / float(image_h) 137 | image_location[:,2] = image_location[:,2] / float(image_w) 138 | image_location[:,3] = image_location[:,3] / float(image_h) 139 | 140 | g_location = np.array([0,0,1,1,1]) 141 | image_location = np.concatenate([np.expand_dims(g_location, axis=0), image_location], axis=0) 142 | 143 | g_location_ori = np.array([0,0,image_w,image_h,image_w*image_h]) 144 | image_location_ori = np.concatenate([np.expand_dims(g_location_ori, axis=0), image_location_ori], axis=0) 145 | 146 | return features, num_boxes, image_location, image_location_ori, cls_prob 147 | 148 | def keys(self) -> List[int]: 149 | return self._image_ids -------------------------------------------------------------------------------- /utils/optim_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | from torch.optim import Optimizer 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | 8 | class WarmupLinearScheduleNonZero(_LRScheduler): 9 | """ Linear warmup and then linear decay. 10 | Linearly increases learning rate from 0 to max_lr over `warmup_steps` training steps. 11 | Linearly decreases learning rate linearly to min_lr over remaining `t_total - warmup_steps` steps. 
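    For example, with warmup_steps=10000, t_total=200000 and min_lr=1e-5 (the values used
    by the training scripts) and an illustrative base LR of 2e-5:
        step 0      -> factor 0.0, LR clamped up to min_lr = 1e-5 (hence "NonZero")
        step 10000  -> factor 1.0, LR = 2e-5
        step 105000 -> factor 0.5, LR = 1e-5
        step 200000 -> factor 0.0, LR clamped up to min_lr = 1e-5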
12 | """ 13 | def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1): 14 | self.warmup_steps = warmup_steps 15 | self.t_total = t_total 16 | self.min_lr = min_lr 17 | super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch) 18 | 19 | def get_lr(self): 20 | step = self.last_epoch 21 | if step < self.warmup_steps: 22 | lr_factor = float(step) / float(max(1, self.warmup_steps)) 23 | else: 24 | lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) 25 | 26 | return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /utils/visdial_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Metric observes output of certain model, for example, in form of logits or 3 | scores, and accumulates a particular metric with reference to some provided 4 | targets. In context of VisDial, we use Recall (@ 1, 5, 10), Mean Rank, Mean 5 | Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG). 6 | 7 | Each ``Metric`` must atleast implement three methods: 8 | - ``observe``, update accumulated metric with currently observed outputs 9 | and targets. 10 | - ``retrieve`` to return the accumulated metric., an optionally reset 11 | internally accumulated metric (this is commonly done between two epochs 12 | after validation). 13 | - ``reset`` to explicitly reset the internally accumulated metric. 14 | 15 | Caveat, if you wish to implement your own class of Metric, make sure you call 16 | ``detach`` on output tensors (like logits), else it will cause memory leaks. 17 | """ 18 | import torch 19 | import numpy as np 20 | 21 | def scores_to_ranks(scores: torch.Tensor): 22 | """Convert model output scores into ranks.""" 23 | batch_size, num_rounds, num_options = scores.size() 24 | scores = scores.view(-1, num_options) 25 | 26 | # sort in descending order - largest score gets highest rank 27 | sorted_ranks, ranked_idx = scores.sort(1, descending=True) 28 | 29 | # i-th position in ranked_idx specifies which score shall take this 30 | # position but we want i-th position to have rank of score at that 31 | # position, do this conversion 32 | ranks = ranked_idx.clone().fill_(0) 33 | for i in range(ranked_idx.size(0)): 34 | for j in range(num_options): 35 | ranks[i][ranked_idx[i][j]] = j 36 | # convert from 0-99 ranks to 1-100 ranks 37 | ranks += 1 38 | ranks = ranks.view(batch_size, num_rounds, num_options) 39 | return ranks 40 | 41 | class SparseGTMetrics(object): 42 | """ 43 | A class to accumulate all metrics with sparse ground truth annotations. 44 | These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank. 
45 | """ 46 | 47 | def __init__(self): 48 | self._rank_list = [] 49 | self._rank_list_rnd = [] 50 | self.num_rounds = None 51 | 52 | def observe( 53 | self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor 54 | ): 55 | predicted_scores = predicted_scores.detach() 56 | 57 | # shape: (batch_size, num_rounds, num_options) 58 | predicted_ranks = scores_to_ranks(predicted_scores) 59 | batch_size, num_rounds, num_options = predicted_ranks.size() 60 | self.num_rounds = num_rounds 61 | # collapse batch dimension 62 | predicted_ranks = predicted_ranks.view( 63 | batch_size * num_rounds, num_options 64 | ) 65 | 66 | # shape: (batch_size * num_rounds, ) 67 | target_ranks = target_ranks.view(batch_size * num_rounds).long() 68 | 69 | # shape: (batch_size * num_rounds, ) 70 | predicted_gt_ranks = predicted_ranks[ 71 | torch.arange(batch_size * num_rounds), target_ranks 72 | ] 73 | self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy())) 74 | 75 | predicted_gt_ranks_rnd = predicted_gt_ranks.view(batch_size, num_rounds) 76 | # predicted gt ranks 77 | self._rank_list_rnd.append(predicted_gt_ranks_rnd.cpu().numpy()) 78 | 79 | def retrieve(self, reset: bool = True): 80 | num_examples = len(self._rank_list) 81 | if num_examples > 0: 82 | # convert to numpy array for easy calculation. 83 | __rank_list = torch.tensor(self._rank_list).float() 84 | metrics = { 85 | "r@1": torch.mean((__rank_list <= 1).float()).item(), 86 | "r@5": torch.mean((__rank_list <= 5).float()).item(), 87 | "r@10": torch.mean((__rank_list <= 10).float()).item(), 88 | "mean": torch.mean(__rank_list).item(), 89 | "mrr": torch.mean(__rank_list.reciprocal()).item() 90 | } 91 | # add round metrics 92 | _rank_list_rnd = np.concatenate(self._rank_list_rnd) 93 | _rank_list_rnd = _rank_list_rnd.astype(float) 94 | r_1_rnd = np.mean(_rank_list_rnd <= 1, axis=0) 95 | r_5_rnd = np.mean(_rank_list_rnd <= 5, axis=0) 96 | r_10_rnd = np.mean(_rank_list_rnd <= 10, axis=0) 97 | mean_rnd = np.mean(_rank_list_rnd, axis=0) 98 | mrr_rnd = np.mean(np.reciprocal(_rank_list_rnd), axis=0) 99 | 100 | for rnd in range(1, self.num_rounds + 1): 101 | metrics["r_1" + "_round_" + str(rnd)] = r_1_rnd[rnd-1] 102 | metrics["r_5" + "_round_" + str(rnd)] = r_5_rnd[rnd-1] 103 | metrics["r_10" + "_round_" + str(rnd)] = r_10_rnd[rnd-1] 104 | metrics["mean" + "_round_" + str(rnd)] = mean_rnd[rnd-1] 105 | metrics["mrr" + "_round_" + str(rnd)] = mrr_rnd[rnd-1] 106 | else: 107 | metrics = {} 108 | 109 | if reset: 110 | self.reset() 111 | return metrics 112 | 113 | def reset(self): 114 | self._rank_list = [] 115 | self._rank_list_rnd = [] 116 | 117 | class NDCG(object): 118 | def __init__(self): 119 | self._ndcg_numerator = 0.0 120 | self._ndcg_denominator = 0.0 121 | 122 | def observe( 123 | self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor 124 | ): 125 | """ 126 | Observe model output scores and target ground truth relevance and 127 | accumulate NDCG metric. 128 | 129 | Parameters 130 | ---------- 131 | predicted_scores: torch.Tensor 132 | A tensor of shape (batch_size, num_options), because dense 133 | annotations are available for 1 randomly picked round out of 10. 134 | target_relevance: torch.Tensor 135 | A tensor of shape same as predicted scores, indicating ground truth 136 | relevance of each answer option for a particular round. 
137 | """ 138 | predicted_scores = predicted_scores.detach() 139 | 140 | # shape: (batch_size, 1, num_options) 141 | predicted_scores = predicted_scores.unsqueeze(1) 142 | predicted_ranks = scores_to_ranks(predicted_scores) 143 | 144 | # shape: (batch_size, num_options) 145 | predicted_ranks = predicted_ranks.squeeze() 146 | batch_size, num_options = predicted_ranks.size() 147 | 148 | k = torch.sum(target_relevance != 0, dim=-1) 149 | 150 | # shape: (batch_size, num_options) 151 | _, rankings = torch.sort(predicted_ranks, dim=-1) 152 | # Sort relevance in descending order so highest relevance gets top rnk. 153 | _, best_rankings = torch.sort( 154 | target_relevance, dim=-1, descending=True 155 | ) 156 | 157 | # shape: (batch_size, ) 158 | batch_ndcg = [] 159 | for batch_index in range(batch_size): 160 | num_relevant = k[batch_index] 161 | dcg = self._dcg( 162 | rankings[batch_index][:num_relevant], 163 | target_relevance[batch_index], 164 | ) 165 | best_dcg = self._dcg( 166 | best_rankings[batch_index][:num_relevant], 167 | target_relevance[batch_index], 168 | ) 169 | batch_ndcg.append(dcg / best_dcg) 170 | 171 | self._ndcg_denominator += batch_size 172 | self._ndcg_numerator += sum(batch_ndcg) 173 | 174 | def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor): 175 | sorted_relevance = relevance[rankings].cpu().float() 176 | discounts = torch.log2(torch.arange(len(rankings)).float() + 2) 177 | return torch.sum(sorted_relevance / discounts, dim=-1) 178 | 179 | def retrieve(self, reset: bool = True): 180 | if self._ndcg_denominator > 0: 181 | metrics = { 182 | "ndcg": float(self._ndcg_numerator / self._ndcg_denominator) 183 | } 184 | else: 185 | metrics = {} 186 | 187 | if reset: 188 | self.reset() 189 | return metrics 190 | 191 | def reset(self): 192 | self._ndcg_numerator = 0.0 193 | self._ndcg_denominator = 0.0 194 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import os.path as pth 2 | import json 3 | import numpy as np 4 | import visdom 5 | 6 | class VisdomVisualize(): 7 | def __init__(self, 8 | env_name='main', 9 | server="http://127.0.0.1", 10 | port=8893, 11 | enable=True): 12 | ''' 13 | Initialize a visdom server on server:port 14 | ''' 15 | print("Initializing visdom env [%s]" % env_name) 16 | self.is_enabled = enable 17 | self.env_name = env_name 18 | if self.is_enabled: 19 | self.viz = visdom.Visdom( 20 | port=port, 21 | env=env_name, 22 | server=server, 23 | ) 24 | else: 25 | self.viz = None 26 | self.wins = {} 27 | 28 | def linePlot(self, x, y, key, line_name, xlabel="Iterations"): 29 | ''' 30 | Add or update a line plot on the visdom server self.viz 31 | Argumens: 32 | x : Scalar -> X-coordinate on plot 33 | y : Scalar -> Value at x 34 | key : Name of plot/graph 35 | line_name : Name of line within plot/graph 36 | xlabel : Label for x-axis (default: # Iterations) 37 | 38 | Plots and lines are created if they don't exist, otherwise 39 | they are updated. 
40 | ''' 41 | key = str(key) 42 | if self.is_enabled: 43 | if key in self.wins.keys(): 44 | self.viz.line( 45 | X = np.array([x]), 46 | Y = np.array([y]), 47 | win = self.wins[key], 48 | update = 'append', 49 | name = line_name, 50 | opts = dict(showlegend=True), 51 | ) 52 | else: 53 | self.wins[key] = self.viz.line( 54 | X = np.array([x]), 55 | Y = np.array([y]), 56 | win = key, 57 | name = line_name, 58 | opts = { 59 | 'xlabel': xlabel, 60 | 'ylabel': key, 61 | 'title': key, 62 | 'showlegend': True, 63 | # 'legend': [line_name], 64 | } 65 | ) 66 | 67 | def showText(self, text, key): 68 | ''' 69 | Creates a named text window or updates an existing one with 70 | the name == key 71 | ''' 72 | key = str(key) 73 | if self.is_enabled: 74 | win = self.wins[key] if key in self.wins else None 75 | self.wins[key] = self.viz.text(text, win=win) 76 | 77 | def addText(self, text): 78 | ''' 79 | Adds an unnamed text window without keeping track of the win id 80 | ''' 81 | if self.is_enabled: 82 | self.viz.text(text) 83 | 84 | def save(self): 85 | if self.is_enabled: 86 | self.viz.save([self.env_name]) 87 | 88 | def histPlot(self, x, key): 89 | key = str(key) 90 | if self.is_enabled: 91 | if key in self.wins.keys(): 92 | self.viz.histogram( 93 | X = x.cpu().numpy(), 94 | win = self.wins[key], 95 | ) 96 | else: 97 | self.wins[key] = self.viz.histogram( 98 | X = x.cpu().numpy(), 99 | win = key 100 | ) --------------------------------------------------------------------------------