├── Implementation_on_visdial_bert
├── README.md
├── configs
│   ├── baseline.yml
│   ├── baseline_stage2.yml
│   └── evaluete.yml
├── environment.yml
├── evaluate.py
├── hidden_dict
│   ├── baseline_stage2.yml
│   ├── dict1.yml
│   ├── dict2.yml
│   ├── train_dict_stage1.py
│   ├── train_dict_stage2.py
│   └── train_dict_stage3.py
├── question_type
│   ├── qt.yml
│   └── train_qt.py
├── train_stage1_baseline.py
├── train_stage2_baseline.py
└── visdialch
    ├── data
    │   ├── __init__.py
    │   ├── dataset.py
    │   ├── dataset_qt.py
    │   ├── readers.py
    │   ├── readers_qt.py
    │   └── vocabulary.py
    ├── decoders
    │   ├── __init__.py
    │   ├── disc_by_round.py
    │   ├── disc_qt.py
    │   └── discvdr.py
    ├── encoders
    │   ├── Coatt.py
    │   ├── Coatt_withP1.py
    │   ├── HCIAE.py
    │   ├── HCIAE_withP1.py
    │   ├── __init__.py
    │   ├── dict_encoder.py
    │   ├── lf_enhanced.py
    │   ├── lf_enhanced_withP1.py
    │   ├── rva.py
    │   └── rva_withP1.py
    ├── metrics.py
    ├── model.py
    └── utils
        ├── __init__.py
        ├── checkpointing.py
        └── dynamic_rnn.py

/Implementation_on_visdial_bert:
--------------------------------------------------------------------------------
1 | 1. It is really hard to eliminate the direct path from history to answer, so at the current stage we suggest removing the history entirely.
2 | You can change the dataloader files to do this.
3 | 2. For the implementation of P2, we recommend the R4 loss, which was found after the paper was published. We also suggest enlarging the learning rate.
4 | In our experience, stage 1 needs about 200k iterations and stage 2 needs 3~5 epochs.
5 | 3. We also suggest down-weighting the object-prediction and LM losses.
6 | 4. We found that ensembling the VD-BERT results with ours brings a further improvement, even though a single VD-BERT model reaches about 75.5% while our ensemble alone stays below 75%.
7 | 5. About the fusion of NDCG and MRR: a good strategy (which we omitted in the challenge) is to keep only the first-place answer of the MRR-oriented result and take all remaining ranks from the NDCG-oriented result; other fusion strategies that keep the first several MRR answers can trade the two metrics off. This fusion strategy will not hurt NDCG much. On the leaderboard, when the MRR and R@10 of an entry look different from what experience suggests, the authors may have applied this strategy.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # VisDial-principles
3 |
4 | This repository is the updated PyTorch implementation of the CVPR 2020 paper "Two Causal Principles for Improving Visual Dialog", and it is also the newest version of the code from the Visual Dialog Challenge 2019 winner team (here is the [report](https://drive.google.com/file/d/1fqg0hregsp_3USM6XCHx89S9JLXt8bKp/view)). For the detailed theory, please refer to our [paper](https://arxiv.org/abs/1911.10496).
5 |
6 | Note that this repository is based on the official [code](https://github.com/batra-mlp-lab/visdial); for the newest official code, please refer to the [visdial-bert version](https://github.com/vmurahari3/visdial-bert#setup-and-dependencies).
7 |
8 | If you find this work useful in your research, please kindly consider citing:
9 |
10 | ```
11 | @inproceedings{qi2020two,
12 |   title={Two causal principles for improving visual dialog},
13 |   author={Qi, Jiaxin and Niu, Yulei and Huang, Jianqiang and Zhang, Hanwang},
14 |   booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
15 |   pages={10860--10869},
16 |   year={2020}
17 | }
18 | ```
19 | #### Dependencies
20 | Create a conda environment:
21 | ```
22 | conda env create -f environment.yml
23 | ```
24 | Download the NLTK data:
25 | ```
26 | python -c "import nltk; nltk.download('all')"
27 | ```
28 | #### Preparing (download data and pretrained model)
29 | 1. Download the data and the pretrained model:
30 |
31 | 1.1. Create the directory data/ and download the necessary files into it:
32 |
33 | From the official [website](https://visualdialog.org/data):
34 |
35 | [visdial_1.0_train.json](https://www.dropbox.com/s/ix8keeudqrd8hn8/visdial_1.0_train.zip?dl=0)
36 |
37 | [visdial_1.0_val.json](https://www.dropbox.com/s/ibs3a0zhw74zisc/visdial_1.0_val.zip?dl=0)
38 |
39 | [visdial_1.0_val_dense_annotations.json](https://www.dropbox.com/s/3knyk09ko4xekmc/visdial_1.0_val_dense_annotations.json?dl=0)
40 |
41 | [visdial_1.0_train_dense_sample.json](https://www.dropbox.com/s/1ajjfpepzyt3q4m/visdial_1.0_train_dense_sample.json?dl=0)
42 |
43 | 1.2. From us, or collected by yourself (also saved in the directory data/):
44 |
45 | [features_faster_rcnn_x101_train.h5](https://drive.google.com/open?id=1eC80EMMEdZvWsKIl3YlEFpY4XHlvN9h8)
46 |
47 | [features_faster_rcnn_x101_val.h5](https://drive.google.com/open?id=1_QoH-lbRCwPrcuiwVNjhW1yMxhqiLclB)
48 |
49 | [features_faster_rcnn_x101_test.h5](https://drive.google.com/open?id=1hyMCJLXAyaNHmnoRZM8eF3fNia49oHLl)
50 |
51 | [visdial_1.0_word_counts_train.json](https://drive.google.com/open?id=1zL8P5LnPzRbfaPxJXvFVGBlS7SumOB_g)
52 |
53 | [glove.npy](https://drive.google.com/open?id=1y4oSqAwgu2gIcyuF5ZuMuNZ-c-89NGuJ)
54 |
55 | [qt_count.json](https://drive.google.com/open?id=1hllnesIwb__kVHmn5Mtz9CLt9VXnCUS_)
56 |
57 | [qt_scores.json](https://drive.google.com/open?id=1QlKy4lVHMlZ4hqw4tVaB608WMBo-eBDs) (the key under each question type is the index of the candidate in the answer list)
58 |
59 | [100ans_feature.npy](https://drive.google.com/open?id=1vu9wMGc8GTj-83ILlUxyuk8_4aCLAIkm) (for initializing the answer dictionary)
60 |
61 | 1.3. Download the pretrained model:
62 |
63 | [baseline_withP1_checkpiont5.pth](https://drive.google.com/open?id=1LZizUL1lSnLU9ZPmePUfDDtSBQVjAyH8)
64 |
65 | #### Training
66 | 0. Check your GPU id and change it via --gpu-ids.
67 |
68 | 1. Baseline (we recommend using checkpoints 5-7 for finetuning):
69 | ```
70 | python train_stage1_baseline.py --validate --in-memory --save-model
71 | ```
72 | (Note: for other encoders, please follow the format and the notes in the code.)
73 |
74 | 2. Different loss functions for answer-score sampling (dense finetuning; R3 is the default, and because the dense samples are rare, the results may be a little unstable). We also add a newer loss function, R4 (normalized BCE), which is better than R2 and is the recommended choice; a sketch of it is given below, followed by the finetuning command.
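The R4 loss is not described in the paper, so here is a minimal sketch of how it is computed for one dialog round, distilled from the R4 branch in hidden_dict/train_dict_stage2.py; the helper name `r4_loss`, the tensor shapes, and the clamping of the sigmoid outputs (used here instead of the script's `s != 1` check) are illustrative assumptions, not the exact implementation:
```
import torch
import torch.nn.functional as F

def r4_loss(scores, relevance, gt_index, bce_weight=20.0):
    """Sketch of the R4 (normalized BCE) finetuning loss for one round.

    scores:    (100,) raw logits over the answer candidates
    relevance: (100,) dense relevance annotations in [0, 1]
    gt_index:  index of the sparse ground-truth answer
    """
    # Keep a cross-entropy term on the sparse ground-truth answer
    # (the "one-hot" part of the loss).
    ce = F.cross_entropy(scores.unsqueeze(0), torch.tensor([gt_index]))

    # L1-normalize the relevance scores and use them as soft BCE targets.
    rel = F.normalize(relevance.unsqueeze(0), p=1).squeeze(0)
    probs = torch.sigmoid(scores).clamp(1e-7, 1.0 - 1e-7)
    # Each candidate is pulled up in proportion to its relevance and pushed
    # down in proportion to (max relevance - its relevance).
    bce = -(rel * torch.log(probs) + (rel.max() - rel) * torch.log(1.0 - probs))
    return (ce + bce_weight * bce.sum()) / scores.numel()
```
The weight of 20 on the soft-BCE term mirrors the constant used in the script; enlarging the weight on the cross-entropy term is the knob mentioned later for keeping MRR from dropping. To run dense finetuning with this loss, pass `--loss-function R4` as in the command below.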
75 | ```
76 | python train_stage2_baseline.py --loss-function R4 --load-pthpath checkpoints/baseline_withP1_checkpiont5.pth
77 | ```
78 | 3. Question-type implementation (download the qt files above or create them by following our paper):
79 | ```
80 | cd question_type
81 | python train_qt.py --validate --in-memory --save-model --load-pthpath checkpoints/baseline_withP1_checkpiont5.pth
82 | ```
83 | Note that you can train it from the pretrained model or from scratch (adjust the learning rate and the decay epochs carefully). You can also try to use the question-type preference to directly help the baseline model at inference time. Note, however, that the candidate answer list used in training is different from the one used in validation, so take care to do the index conversion.
84 |
85 | 4. Dictionary learning (three steps: train the dict, finetune the dict, then finetune the whole model):
86 | ```
87 | cd hidden_dict
88 | python train_dict_stage1.py --save-model
89 | python train_dict_stage2.py --save-model --load-pthpath
90 | python train_dict_stage3.py --load-dict-pthpath --load-pthpath checkpoints/baseline_withP1_checkpiont5.pth
91 | ```
92 | Besides, after our code optimization, some implementations can get slightly better results, but this does not affect the conclusions about our principles. If you find the MRR score too low, you can try training with a larger weight on the one-hot term to keep MRR up. Furthermore, using only the top-1 candidate of the stage-1 model and taking the remaining ranks from the finetuned model gives a better balance between MRR and NDCG (see the sketch after the Evaluation section).
93 |
94 | #### Evaluation
95 | You can directly evaluate a model with the following command (please check the settings in configs/evaluete.yml):
96 | ```
97 | python evaluate.py --load-pthpath
98 | ```
99 | If you have any other questions or suggestions, please feel free to email me.
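As a concrete illustration of the MRR/NDCG trade-off mentioned in the Training section, the sketch below keeps the top-1 answer of the stage-1 (MRR-oriented) model and orders every other candidate by the dense-finetuned (NDCG-oriented) scores; the function name `fuse_ranks`, the batch shapes, and the additive boost trick are illustrative assumptions rather than code from this repository:
```
import torch

def fuse_ranks(stage1_scores, finetuned_scores, boost=1e4):
    """Keep the stage-1 winner at rank 1, rank the rest by finetuned scores.

    stage1_scores, finetuned_scores: (batch_size, 100) candidate scores.
    Returns fused scores whose descending sort gives the fused ranking.
    """
    fused = finetuned_scores.clone()
    top1 = stage1_scores.argmax(dim=-1)  # stage-1 top answer per example
    # Lift the stage-1 winner above every finetuned score so it stays first.
    fused[torch.arange(fused.size(0)), top1] = (
        finetuned_scores.max(dim=-1).values + boost
    )
    return fused
```
The same idea extends to keeping the first several MRR candidates instead of only the top one, which trades a little NDCG for more MRR, as discussed in the notes under Implementation_on_visdial_bert.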
100 | #### Acknowledgements 101 | 102 | Thanks for the source code from [the official](https://visualdialog.org/) 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /configs/baseline.yml: -------------------------------------------------------------------------------- 1 | # Dataset reader arguments 2 | dataset: 3 | image_features_train_h5: 'data/features_faster_rcnn_x101_train.h5' 4 | image_features_val_h5: 'data/features_faster_rcnn_x101_val.h5' 5 | image_features_test_h5: 'data/features_faster_rcnn_x101_test.h5' 6 | word_counts_json: 'data/visdial_1.0_word_counts_train.json' 7 | tokens_path: 'data/tokens.json' 8 | glove_npy: 'data/glove.npy' 9 | img_norm: 1 10 | concat_history: false 11 | max_sequence_length: 20 12 | vocab_min_count: 5 13 | 14 | 15 | # Model related arguments 16 | model: 17 | encoder: 'baseline_encoder_withP1' #baseline_encoder_withP1 with disc_by_round #rva(encoder) with discvdr(decoder) 18 | decoder: 'disc_by_round' 19 | 20 | img_feature_size: 2048 21 | word_embedding_size: 300 22 | lstm_hidden_size: 512 23 | lstm_num_layers: 2 24 | head_num: 3 25 | dropout: 0.4 # can change 0.4 to 0.3, because a smaller lr (about 0.3) is better 26 | dropout_fc: 0.25 # can change 0.25 to 0.1, because a smaller lr (about 0.1) is better 27 | ans_cls_num: 4 28 | 29 | # Optimization related arguments 30 | solver: 31 | batch_size: 128 # 16 for 10 rounds (rva) 32 | num_epochs: 15 33 | initial_lr: 0.004 # general is 0.004 34 | training_splits: "train" 35 | lr_gamma: 0.4 36 | lr_milestones: # epochs when lr —> lr * lr_gamma 37 | - 5 38 | - 7 39 | - 9 40 | warmup_factor: 0.2 41 | warmup_epochs: 1 42 | -------------------------------------------------------------------------------- /configs/baseline_stage2.yml: -------------------------------------------------------------------------------- 1 | # Dataset reader arguments 2 | dataset: 3 | image_features_train_h5: 'data/features_faster_rcnn_x101_train.h5' 4 | image_features_val_h5: 'data/features_faster_rcnn_x101_val.h5' 5 | image_features_test_h5: 'data/features_faster_rcnn_x101_test.h5' 6 | word_counts_json: 'data/visdial_1.0_word_counts_train.json' 7 | tokens_path: 'data/tokens.json' 8 | img_norm: 1 9 | concat_history: false 10 | max_sequence_length: 20 11 | vocab_min_count: 5 12 | 13 | 14 | # Model related arguments 15 | model: 16 | encoder: 'baseline_encoder_withP1' #baseline_encoder_withP1 with disc_by_round #rva(encoder) with discvdr(decoder) 17 | decoder: 'disc_by_round' 18 | 19 | img_feature_size: 2048 20 | word_embedding_size: 300 21 | lstm_hidden_size: 512 22 | lstm_num_layers: 2 23 | head_num: 3 24 | dropout: 0.4 25 | dropout_fc: 0.25 26 | ans_cls_num: 4 27 | 28 | # Optimization related arguments 29 | solver: 30 | batch_size: 12 31 | num_epochs: 3 # you can try more epochs or change the lr # experience is 2 or 3 32 | initial_lr: 0.002 33 | training_splits: "train" 34 | lr_gamma: 0.3 35 | lr_milestones: 36 | - 2 37 | warmup_factor: 0.5 38 | warmup_epochs: 0 39 | -------------------------------------------------------------------------------- /configs/evaluete.yml: -------------------------------------------------------------------------------- 1 | # Dataset reader arguments 2 | dataset: 3 | image_features_train_h5: 'data/features_faster_rcnn_x101_train.h5' 4 | image_features_val_h5: 'data/features_faster_rcnn_x101_val.h5' 5 | image_features_test_h5: 'data/features_faster_rcnn_x101_test.h5' 6 | word_counts_json: 'data/visdial_1.0_word_counts_train.json' 7 | tokens_path: 
'data/tokens.json' 8 | img_norm: 1 9 | concat_history: false 10 | max_sequence_length: 20 11 | vocab_min_count: 5 12 | 13 | 14 | # Model related arguments 15 | model: 16 | encoder: 'baseline_encoder_withP1' 17 | decoder: 'disc_by_round' 18 | 19 | img_feature_size: 2048 20 | word_embedding_size: 300 21 | lstm_hidden_size: 512 22 | lstm_num_layers: 2 23 | head_num: 3 24 | dropout: 0.3 # change 0.4 to 0.3, because a smaller lr (about 0.3) is better 25 | dropout_fc: 0.1 # change 0.25 to 0.1, because a smaller lr (about 0.1) is better 26 | 27 | 28 | # Optimization related arguments 29 | solver: 30 | batch_size: 128 # 64 x num_gpus is a good rule of thumb 31 | num_epochs: 15 32 | initial_lr: 0.004 # general is 0.004 33 | training_splits: "train" 34 | lr_gamma: 0.4 35 | lr_milestones: # epochs when lr —> lr * lr_gamma 36 | - 5 37 | - 7 38 | - 9 39 | warmup_factor: 0.2 40 | warmup_epochs: 1 41 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: visualdialog-principles 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=1_llvm 8 | - blas=1.0=mkl 9 | - ca-certificates=2019.11.28=hecc5488_0 10 | - certifi=2019.11.28=py36h9f0ad1d_1 11 | - cffi=1.14.0=py36hd463f26_0 12 | - cudatoolkit=10.0.130=0 13 | - cudnn=7.6.5=cuda10.0_0 14 | - freetype=2.10.1=he06d7ca_0 15 | - h5py=2.10.0=nompi_py36h513d04c_102 16 | - hdf5=1.10.5=nompi_h3c11f04_1104 17 | - jpeg=9c=h14c3975_1001 18 | - ld_impl_linux-64=2.34=h53a641e_0 19 | - libffi=3.2.1=he1b5a44_1007 20 | - libgcc-ng=9.2.0=h24d8f2e_2 21 | - libgfortran-ng=7.3.0=hdf63c60_5 22 | - libpng=1.6.37=hed695b0_1 23 | - libstdcxx-ng=9.2.0=hdf63c60_2 24 | - libtiff=4.1.0=hc7e4089_6 25 | - libwebp-base=1.1.0=h516909a_3 26 | - llvm-openmp=9.0.1=hc9558a2_2 27 | - lz4-c=1.8.3=he1b5a44_1001 28 | - mkl=2019.5=281 29 | - mkl-service=2.3.0=py36h516909a_0 30 | - mkl_fft=1.1.0=py36hc1659b7_1 31 | - mkl_random=1.1.0=py36hb3f55d8_0 32 | - ncurses=6.1=hf484d3e_1002 33 | - ninja=1.10.0=hc9558a2_0 34 | - nltk=3.4.4=py_0 35 | - numpy=1.18.1=py36h4f9e942_0 36 | - numpy-base=1.18.1=py36hde5b4d6_1 37 | - olefile=0.46=py_0 38 | - openssl=1.1.1f=h516909a_0 39 | - pillow=7.0.0=py36h8328e55_1 40 | - pip=20.0.2=py_2 41 | - pycparser=2.20=py_0 42 | - python=3.6.10=h9d8adfe_1009_cpython 43 | - python_abi=3.6=1_cp36m 44 | - pytorch=1.0.1=cuda100py36he554f03_0 45 | - readline=8.0=hf8c457e_0 46 | - setuptools=46.1.3=py36h9f0ad1d_0 47 | - six=1.14.0=py_1 48 | - sqlite=3.30.1=hcee41ef_0 49 | - tk=8.6.10=hed695b0_0 50 | - torchvision=0.2.1=py36_0 51 | - tqdm=4.44.1=pyh9f0ad1d_0 52 | - wheel=0.34.2=py_1 53 | - xz=5.2.4=h516909a_1002 54 | - zlib=1.2.11=h516909a_1006 55 | - zstd=1.4.4=h3b9ef0a_2 56 | - pip: 57 | - protobuf==3.11.3 58 | - pyyaml==5.3.1 59 | - tensorboardx==1.6 60 | 61 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import time 4 | import datetime 5 | import numpy as np 6 | from tensorboardX import SummaryWriter 7 | import torch 8 | from torch import nn, optim 9 | from torch.optim import lr_scheduler 10 | from torch.utils.data import DataLoader 11 | import yaml 12 | from bisect import bisect 13 | import random 14 | from visdialch.data.dataset import VisDialDataset 15 | from visdialch.encoders import Encoder 16 | 
from visdialch.decoders import Decoder 17 | from visdialch.metrics import SparseGTMetrics, NDCG 18 | from visdialch.model import EncoderDecoderModel 19 | from visdialch.utils.checkpointing import CheckpointManager, load_checkpoint 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "--config-yml", 24 | default="configs/evaluete.yml", 25 | help="Path to a config file listing reader, model and solver parameters.", 26 | ) 27 | parser.add_argument( 28 | "--val-json", 29 | default="data/visdial_1.0_val.json", 30 | help="Path to json file containing VisDial v1.0 validation data.", 31 | ) 32 | parser.add_argument( 33 | "--val-dense-json", 34 | default="data/visdial_1.0_val_dense_annotations.json", 35 | help="Path to json file containing VisDial v1.0 validation dense ground " 36 | "truth annotations.", 37 | ) 38 | parser.add_argument_group( 39 | "Arguments independent of experiment reproducibility" 40 | ) 41 | parser.add_argument( 42 | "--gpu-ids", 43 | nargs="+", 44 | type=int, 45 | default=[2, 3], 46 | help="List of ids of GPUs to use.", 47 | ) 48 | parser.add_argument( 49 | "--cpu-workers", 50 | type=int, 51 | default=8, 52 | help="Number of CPU workers for dataloader.", 53 | ) 54 | parser.add_argument( 55 | "--overfit", 56 | action="store_true", 57 | help="Overfit model on 5 examples, meant for debugging.", 58 | ) 59 | parser.add_argument( 60 | "--in-memory", 61 | action="store_true", 62 | help="Load the whole dataset and pre-extracted image features in memory. " 63 | "Use only in presence of large RAM, atleast few tens of GBs.", 64 | ) 65 | 66 | parser.add_argument_group("Checkpointing related arguments") 67 | parser.add_argument( 68 | "--load-pthpath", 69 | default="", 70 | help="To continue training, path to .pth file of saved checkpoint.", 71 | ) 72 | 73 | manualSeed = random.randint(1, 10000) 74 | print("Random Seed: ", manualSeed) 75 | torch.manual_seed(manualSeed) 76 | torch.cuda.manual_seed_all(manualSeed) 77 | torch.backends.cudnn.benchmark = False 78 | torch.backends.cudnn.deterministic = True 79 | 80 | # ============================================================================= 81 | # INPUT ARGUMENTS AND CONFIG 82 | # ============================================================================= 83 | 84 | args = parser.parse_args() 85 | # keys: {"dataset", "model", "solver"} 86 | config = yaml.load(open(args.config_yml)) 87 | 88 | if isinstance(args.gpu_ids, int): 89 | args.gpu_ids = [args.gpu_ids] 90 | device = ( 91 | torch.device("cuda", args.gpu_ids[0]) 92 | if args.gpu_ids[0] >= 0 93 | else torch.device("cpu") 94 | ) 95 | 96 | # Print config and args. 97 | print(yaml.dump(config, default_flow_style=False)) 98 | for arg in vars(args): 99 | print("{:<20}: {}".format(arg, getattr(args, arg))) 100 | 101 | # ============================================================================= 102 | # SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER 103 | # ============================================================================= 104 | 105 | val_dataset = VisDialDataset( 106 | config["dataset"], 107 | args.val_json, 108 | args.val_dense_json, 109 | overfit=args.overfit, 110 | in_memory=args.in_memory, 111 | return_options=True, 112 | add_boundary_toks=False, 113 | sample_flag=False 114 | ) 115 | val_dataloader = DataLoader( 116 | val_dataset, 117 | batch_size=config["solver"]["batch_size"], 118 | num_workers=args.cpu_workers, 119 | shuffle=True, 120 | ) 121 | 122 | # Pass vocabulary to construct Embedding layer. 
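# `Encoder` and `Decoder` are small factories (see visdialch/encoders/__init__.py and
# visdialch/decoders/__init__.py) that instantiate the classes named in the config,
# e.g. encoder: 'baseline_encoder_withP1' and decoder: 'disc_by_round' in
# configs/evaluete.yml. The dataset vocabulary is passed in so that both modules
# build their word embeddings over the same token indices.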
123 | encoder = Encoder(config["model"], val_dataset.vocabulary) 124 | decoder = Decoder(config["model"], val_dataset.vocabulary) 125 | print("Encoder: {}".format(config["model"]["encoder"])) 126 | print("Decoder: {}".format(config["model"]["decoder"])) 127 | 128 | # Share word embedding between encoder and decoder. 129 | if args.load_pthpath == "": 130 | print('load glove') 131 | decoder.word_embed = encoder.word_embed 132 | glove = np.load('data/glove.npy') 133 | encoder.word_embed.weight.data = torch.tensor(glove) 134 | 135 | # Wrap encoder and decoder in a model. 136 | model = EncoderDecoderModel(encoder, decoder).to(device) 137 | if -1 not in args.gpu_ids: 138 | model = nn.DataParallel(model, args.gpu_ids) 139 | 140 | # ============================================================================= 141 | # SETUP BEFORE TRAINING LOOP 142 | # ============================================================================= 143 | start_time = datetime.datetime.strftime(datetime.datetime.utcnow(), '%d-%b-%Y-%H:%M:%S') 144 | 145 | sparse_metrics = SparseGTMetrics() 146 | ndcg = NDCG() 147 | 148 | # loading checkpoint 149 | start_epoch = 0 150 | model_state_dict, _ = load_checkpoint(args.load_pthpath) 151 | if isinstance(model, nn.DataParallel): 152 | model.module.load_state_dict(model_state_dict) 153 | else: 154 | model.load_state_dict(model_state_dict) 155 | print("Loaded model from {}".format(args.load_pthpath)) 156 | 157 | def get_1round_batch_data(batch, rnd): 158 | temp_train_batch = {} 159 | for key in batch: 160 | if key in ['img_feat']: 161 | temp_train_batch[key] = batch[key].to(device) 162 | elif key in ['ques', 'opt', 'ques_len', 'opt_len', 'ans_ind']: 163 | temp_train_batch[key] = batch[key][:, rnd].to(device) 164 | elif key in ['hist_len', 'hist']: 165 | temp_train_batch[key] = batch[key][:, :rnd + 1].to(device) 166 | else: 167 | pass 168 | return temp_train_batch 169 | 170 | model.eval() 171 | for i, batch in enumerate(val_dataloader): 172 | batchsize = batch['img_ids'].shape[0] 173 | rnd = 0 174 | temp_train_batch = get_1round_batch_data(batch, rnd) 175 | output = model(temp_train_batch).view(-1, 1, 100).detach() 176 | for rnd in range(1, 10): 177 | temp_train_batch = get_1round_batch_data(batch, rnd) 178 | output = torch.cat((output, model(temp_train_batch).view(-1, 1, 100).detach()), dim=1) 179 | sparse_metrics.observe(output, batch["ans_ind"]) 180 | if "relevance" in batch: 181 | output = output[torch.arange(output.size(0)), batch["round_id"] - 1, :] 182 | ndcg.observe(output.view(-1, 100), batch["relevance"].contiguous().view(-1, 100)) 183 | # if i > 5: #for debug(like the --overfit) 184 | # break 185 | all_metrics = {} 186 | all_metrics.update(sparse_metrics.retrieve(reset=True)) 187 | all_metrics.update(ndcg.retrieve(reset=True)) 188 | for metric_name, metric_value in all_metrics.items(): 189 | print(f"{metric_name}: {metric_value}") 190 | model.train() 191 | -------------------------------------------------------------------------------- /hidden_dict/baseline_stage2.yml: -------------------------------------------------------------------------------- 1 | # Dataset reader arguments 2 | dataset: 3 | image_features_train_h5: 'data/features_faster_rcnn_x101_train.h5' 4 | image_features_val_h5: 'data/features_faster_rcnn_x101_val.h5' 5 | image_features_test_h5: 'data/features_faster_rcnn_x101_test.h5' 6 | word_counts_json: 'data/visdial_1.0_word_counts_train.json' 7 | tokens_path: 'data/tokens.json' 8 | img_norm: 1 9 | concat_history: false 10 | max_sequence_length: 20 11 | 
vocab_min_count: 5 12 | 13 | 14 | # Model related arguments 15 | model: 16 | encoder: 'baseline_encoder_withP1' 17 | decoder: 'disc_by_round' 18 | 19 | img_feature_size: 2048 20 | word_embedding_size: 300 21 | lstm_hidden_size: 512 22 | lstm_num_layers: 2 23 | head_num: 3 24 | dropout: 0.4 25 | dropout_fc: 0.25 26 | 27 | 28 | # Optimization related arguments 29 | solver: 30 | batch_size: 12 31 | num_epochs: 3 # you can try more epochs or change the lr 32 | initial_lr: 0.002 33 | training_splits: "train" 34 | lr_gamma: 0.3 35 | lr_milestones: 36 | - 2 37 | warmup_factor: 0.5 38 | warmup_epochs: 0 39 | -------------------------------------------------------------------------------- /hidden_dict/dict1.yml: -------------------------------------------------------------------------------- 1 | # Dataset reader arguments 2 | dataset: 3 | image_features_train_h5: 'data/features_faster_rcnn_x101_train.h5' 4 | image_features_val_h5: 'data/features_faster_rcnn_x101_val.h5' 5 | image_features_test_h5: 'data/features_faster_rcnn_x101_test.h5' 6 | word_counts_json: 'data/visdial_1.0_word_counts_train.json' 7 | tokens_path: 'data/tokens.json' 8 | img_norm: 1 9 | concat_history: false 10 | max_sequence_length: 20 11 | vocab_min_count: 5 12 | 13 | 14 | # Model related arguments 15 | model: 16 | encoder: 'dict_encoder' 17 | decoder: 'disc_by_round' 18 | 19 | img_feature_size: 2048 20 | word_embedding_size: 300 21 | lstm_hidden_size: 512 22 | lstm_num_layers: 2 23 | head_num: 3 24 | dropout: 0.3 # change 0.4 to 0.3, because a smaller lr (about 0.3) is better 25 | dropout_fc: 0.1 # change 0.25 to 0.1, because a smaller lr (about 0.1) is better 26 | 27 | # Optimization related arguments 28 | solver: 29 | batch_size: 128 # 64 x num_gpus is a good rule of thumb 30 | num_epochs: 6 31 | initial_lr: 0.004 # general is 0.004 32 | training_splits: "train" 33 | lr_gamma: 0.2 34 | lr_milestones: # epochs when lr —> lr * lr_gamma 35 | - 4 36 | warmup_factor: 0.2 37 | warmup_epochs: 1 38 | -------------------------------------------------------------------------------- /hidden_dict/dict2.yml: -------------------------------------------------------------------------------- 1 | # Dataset reader arguments 2 | dataset: 3 | image_features_train_h5: 'data/features_faster_rcnn_x101_train.h5' 4 | image_features_val_h5: 'data/features_faster_rcnn_x101_val.h5' 5 | image_features_test_h5: 'data/features_faster_rcnn_x101_test.h5' 6 | word_counts_json: 'data/visdial_1.0_word_counts_train.json' 7 | tokens_path: 'data/tokens.json' 8 | img_norm: 1 9 | concat_history: false 10 | max_sequence_length: 20 11 | vocab_min_count: 5 12 | 13 | 14 | # Model related arguments 15 | model: 16 | encoder: 'dict_encoder' 17 | decoder: 'disc_by_round' 18 | 19 | img_feature_size: 2048 20 | word_embedding_size: 300 21 | lstm_hidden_size: 512 22 | lstm_num_layers: 2 23 | head_num: 3 24 | dropout: 0.4 # change 0.4 to 0.3, because a smaller lr (about 0.3) is better 25 | dropout_fc: 0.25 # change 0.25 to 0.1, because a smaller lr (about 0.1) is better 26 | 27 | # Optimization related arguments 28 | solver: 29 | batch_size: 12 30 | num_epochs: 3 31 | initial_lr: 0.004 32 | training_splits: "train" 33 | lr_gamma: 0.3 34 | lr_milestones: 35 | - 3 36 | warmup_factor: 0.5 37 | warmup_epochs: 1 38 | -------------------------------------------------------------------------------- /hidden_dict/train_dict_stage1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import json 4 | import datetime 5 
| import numpy as np 6 | from tensorboardX import SummaryWriter 7 | import torch 8 | from torch import nn, optim 9 | from torch.optim import lr_scheduler 10 | from torch.utils.data import DataLoader 11 | import yaml 12 | from torch.nn import functional as F 13 | from bisect import bisect 14 | import random 15 | import sys 16 | sys.path.append('../') 17 | from visdialch.data.dataset import VisDialDataset 18 | from visdialch.encoders import Encoder 19 | from visdialch.decoders import Decoder 20 | from visdialch.metrics import SparseGTMetrics, NDCG 21 | from visdialch.model import EncoderDecoderModel 22 | from visdialch.utils.checkpointing import CheckpointManager, load_checkpoint 23 | from visdialch.encoders.dict_encoder import Dict_Encoder 24 | import os 25 | os.chdir('../') 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument( 28 | "--config-yml", 29 | default="hidden_dict/dict1.yml", 30 | help="Path to a config file listing reader, model and solver parameters.", 31 | ) 32 | parser.add_argument( 33 | "--train-json", 34 | default="data/visdial_1.0_train.json", 35 | help="Path to json file containing VisDial v1.0 training data.", 36 | ) 37 | parser.add_argument( 38 | "--val-json", 39 | default="data/visdial_1.0_val.json", 40 | help="Path to json file containing VisDial v1.0 validation data.", 41 | ) 42 | parser.add_argument( 43 | "--val-dense-json", 44 | default="data/visdial_1.0_val_dense_annotations.json", 45 | help="Path to json file containing VisDial v1.0 validation dense ground " 46 | "truth annotations.", 47 | ) 48 | parser.add_argument_group( 49 | "Arguments independent of experiment reproducibility" 50 | ) 51 | parser.add_argument( 52 | "--gpu-ids", 53 | nargs="+", 54 | type=int, 55 | default=[0, 1], 56 | help="List of ids of GPUs to use.", 57 | ) 58 | parser.add_argument( 59 | "--cpu-workers", 60 | type=int, 61 | default=8, 62 | help="Number of CPU workers for dataloader.", 63 | ) 64 | parser.add_argument( 65 | "--overfit", 66 | action="store_true", 67 | help="Overfit model on 5 examples, meant for debugging.", 68 | ) 69 | parser.add_argument( 70 | "--in-memory", 71 | action="store_true", 72 | help="Load the whole dataset and pre-extracted image features in memory. 
" 73 | "Use only in presence of large RAM, atleast few tens of GBs.", 74 | ) 75 | parser.add_argument( 76 | "--save-dirpath", 77 | default="checkpoints/", 78 | help="Path of directory to create checkpoint directory and save " 79 | "checkpoints.", 80 | ) 81 | parser.add_argument( 82 | "--load-pthpath", 83 | default="checkpoints/dict_encoder+disc_by_round/03-Apr-2020-15:31:53/checkpoint_3.pth", 84 | help="To continue training, path to .pth file of saved checkpoint.", 85 | ) 86 | parser.add_argument( 87 | "--save-model", 88 | action="store_true", 89 | help="To make the dir clear", 90 | ) 91 | 92 | manualSeed = random.randint(1, 10000) 93 | print("Random Seed: ", manualSeed) 94 | torch.manual_seed(manualSeed) 95 | torch.cuda.manual_seed_all(manualSeed) 96 | torch.backends.cudnn.benchmark = False 97 | torch.backends.cudnn.deterministic = True 98 | 99 | # ============================================================================= 100 | # INPUT ARGUMENTS AND CONFIG 101 | # ============================================================================= 102 | 103 | args = parser.parse_args() 104 | # keys: {"dataset", "model", "solver"} 105 | config = yaml.load(open(args.config_yml)) 106 | 107 | if isinstance(args.gpu_ids, int): 108 | args.gpu_ids = [args.gpu_ids] 109 | device = ( 110 | torch.device("cuda", args.gpu_ids[0]) 111 | if args.gpu_ids[0] >= 0 112 | else torch.device("cpu") 113 | ) 114 | # ============================================================================= 115 | # SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER 116 | # ============================================================================= 117 | 118 | train_dataset = VisDialDataset( 119 | config["dataset"], 120 | args.train_json, 121 | overfit=args.overfit, 122 | in_memory=args.in_memory, 123 | return_options=True, 124 | add_boundary_toks=False, 125 | sample_flag=False 126 | ) 127 | train_dataloader = DataLoader( 128 | train_dataset, 129 | batch_size=config["solver"]["batch_size"], 130 | num_workers=args.cpu_workers, 131 | shuffle=True, 132 | ) 133 | val_dataset = VisDialDataset( 134 | config["dataset"], 135 | args.val_json, 136 | args.val_dense_json, 137 | overfit=args.overfit, 138 | in_memory=args.in_memory, 139 | return_options=True, 140 | add_boundary_toks=False, 141 | sample_flag=False 142 | ) 143 | val_dataloader = DataLoader( 144 | val_dataset, 145 | batch_size=config["solver"]["batch_size"], 146 | num_workers=args.cpu_workers, 147 | ) 148 | 149 | # Pass vocabulary to construct Embedding layer. 150 | encoder_dict = Dict_Encoder(config["model"], train_dataset.vocabulary) 151 | 152 | # Share word embedding between encoder and decoder. 153 | glove = np.load('data/glove.npy') 154 | encoder_dict.word_embed.weight.data = torch.tensor(glove) 155 | 156 | # Wrap encoder and decoder in a model. 157 | model = encoder_dict.to(device) 158 | if -1 not in args.gpu_ids: 159 | model = nn.DataParallel(model, args.gpu_ids) 160 | 161 | criterion = nn.CrossEntropyLoss() 162 | iterations = len(train_dataset)// config["solver"]["batch_size"] + 1 # 迭代次数 163 | 164 | def lr_lambda_fun(current_iteration: int) -> float: 165 | """Returns a learning rate multiplier. 166 | 167 | Till `warmup_epochs`, learning rate linearly increases to `initial_lr`, 168 | and then gets multiplied by `lr_gamma` every time a milestone is crossed. 
169 | """ 170 | current_epoch = float(current_iteration) / iterations 171 | if current_epoch <= config["solver"]["warmup_epochs"]: 172 | alpha = current_epoch / float(config["solver"]["warmup_epochs"]) 173 | return config["solver"]["warmup_factor"] * (1.0 - alpha) + alpha 174 | else: 175 | idx = bisect(config["solver"]["lr_milestones"], current_epoch) 176 | return pow(config["solver"]["lr_gamma"], idx) 177 | 178 | optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"]) 179 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun) 180 | 181 | start_time = datetime.datetime.strftime(datetime.datetime.utcnow(), '%d-%b-%Y-%H:%M:%S') 182 | checkpoint_dirpath = args.save_dirpath 183 | if checkpoint_dirpath == 'checkpoints/': 184 | checkpoint_dirpath += '%s+%s/%s' % (config["model"]["encoder"], config["model"]["decoder"], start_time) 185 | if args.save_model: 186 | summary_writer = SummaryWriter(log_dir=checkpoint_dirpath) 187 | checkpoint_manager = CheckpointManager(model, optimizer, checkpoint_dirpath, config=config) 188 | 189 | sparse_metrics = SparseGTMetrics() 190 | ndcg = NDCG() 191 | # If loading from checkpoint, adjust start epoch and load parameters. 192 | if args.load_pthpath == "": 193 | start_epoch = 0 194 | else: 195 | start_epoch = int(args.load_pthpath.split("_")[-1][:-4]) 196 | 197 | model_state_dict, optimizer_state_dict = load_checkpoint(args.load_pthpath) 198 | if isinstance(model, nn.DataParallel): 199 | model.module.load_state_dict(model_state_dict) 200 | else: 201 | model.load_state_dict(model_state_dict) 202 | print("Loaded model from {}".format(args.load_pthpath)) 203 | 204 | def get_1round_batch_data(batch, rnd): 205 | temp_train_batch = {} 206 | for key in batch: 207 | if key in ['img_feat']: 208 | temp_train_batch[key] = batch[key].to(device) 209 | elif key in ['ques', 'opt', 'ques_len', 'opt_len', 'ans_ind']: 210 | temp_train_batch[key] = batch[key][:, rnd].to(device) 211 | elif key in ['hist_len', 'hist']: 212 | temp_train_batch[key] = batch[key][:, :rnd + 1].to(device) 213 | else: 214 | pass 215 | return temp_train_batch 216 | 217 | global_iteration_step = start_epoch * iterations 218 | ###start training and set functions used in training 219 | ##stage 1 220 | for epoch in range(start_epoch, config["solver"]["num_epochs"]): 221 | print('Training for epoch:', epoch, ' time:', time.asctime(time.localtime(time.time()))) 222 | count_loss = 0.0 223 | for i, batch in enumerate(train_dataloader): 224 | for rnd in range(10): 225 | temp_train_batch = get_1round_batch_data(batch, rnd) 226 | optimizer.zero_grad() 227 | output = model(temp_train_batch) 228 | target = batch["ans_ind"][:, rnd].to(device) 229 | batch_loss = criterion(output.view(-1, output.size(-1)), target.view(-1)) 230 | batch_loss.backward() 231 | count_loss += batch_loss.data.cpu().numpy() 232 | optimizer.step() 233 | 234 | if i % int(iterations / 10) == 0 and i != 0: 235 | mean_loss = (count_loss / int(iterations / 10)) / 10.0 236 | print('(step', i, 'in', int(iterations), ') mean_loss:', mean_loss, 'Time:', 237 | time.asctime(time.localtime(time.time())), 'lr:', optimizer.param_groups[0]["lr"]) 238 | count_loss = 0.0 239 | if args.save_model: 240 | summary_writer.add_scalar("train/loss", batch_loss, global_iteration_step) 241 | summary_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_iteration_step) 242 | scheduler.step(global_iteration_step) 243 | global_iteration_step += 1 244 | # if i > 5: #for debug(like the --overfit) 245 | # break 246 | if 
args.save_model: 247 | checkpoint_manager.step() 248 | -------------------------------------------------------------------------------- /hidden_dict/train_dict_stage2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import json 4 | import datetime 5 | import numpy as np 6 | from tensorboardX import SummaryWriter 7 | import torch 8 | from torch import nn, optim 9 | from torch.optim import lr_scheduler 10 | from torch.utils.data import DataLoader 11 | import yaml 12 | from torch.nn import functional as F 13 | from bisect import bisect 14 | import random 15 | import sys 16 | sys.path.append('../') 17 | from visdialch.data.dataset import VisDialDataset 18 | from visdialch.encoders import Encoder 19 | from visdialch.decoders import Decoder 20 | from visdialch.metrics import SparseGTMetrics, NDCG 21 | from visdialch.model import EncoderDecoderModel 22 | from visdialch.utils.checkpointing import CheckpointManager, load_checkpoint 23 | from visdialch.encoders.dict_encoder import Dict_Encoder 24 | import os 25 | os.chdir('../') 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument( 28 | "--config-yml", 29 | default="hidden_dict/dict2.yml", 30 | help="Path to a config file listing reader, model and solver parameters.", 31 | ) 32 | parser.add_argument( 33 | "--train-json", 34 | default="data/visdial_1.0_train.json", 35 | help="Path to json file containing VisDial v1.0 training data.", 36 | ) 37 | parser.add_argument( 38 | "--val-json", 39 | default="data/visdial_1.0_val.json", 40 | help="Path to json file containing VisDial v1.0 validation data.", 41 | ) 42 | parser.add_argument( 43 | "--val-dense-json", 44 | default="data/visdial_1.0_val_dense_annotations.json", 45 | help="Path to json file containing VisDial v1.0 validation dense ground " 46 | "truth annotations.", 47 | ) 48 | parser.add_argument_group( 49 | "Arguments independent of experiment reproducibility" 50 | ) 51 | parser.add_argument( 52 | "--gpu-ids", 53 | nargs="+", 54 | type=int, 55 | default=[2], 56 | help="List of ids of GPUs to use.", 57 | ) 58 | parser.add_argument( 59 | "--cpu-workers", 60 | type=int, 61 | default=8, 62 | help="Number of CPU workers for dataloader.", 63 | ) 64 | parser.add_argument( 65 | "--overfit", 66 | action="store_true", 67 | help="Overfit model on 5 examples, meant for debugging.", 68 | ) 69 | parser.add_argument( 70 | "--in-memory", 71 | action="store_true", 72 | help="Load the whole dataset and pre-extracted image features in memory. 
" 73 | "Use only in presence of large RAM, atleast few tens of GBs.", 74 | ) 75 | parser.add_argument( 76 | "--save-dirpath", 77 | default="checkpoints/", 78 | help="Path of directory to create checkpoint directory and save " 79 | "checkpoints.", 80 | ) 81 | parser.add_argument( 82 | "--load-pthpath", 83 | default="", 84 | help="To continue training, path to .pth file of saved checkpoint.", 85 | ) 86 | parser.add_argument( 87 | "--save-model", 88 | action="store_true", 89 | help="To make the dir clear", 90 | ) 91 | 92 | manualSeed = random.randint(1, 10000) 93 | print("Random Seed: ", manualSeed) 94 | torch.manual_seed(manualSeed) 95 | torch.cuda.manual_seed_all(manualSeed) 96 | torch.backends.cudnn.benchmark = False 97 | torch.backends.cudnn.deterministic = True 98 | 99 | # ============================================================================= 100 | # INPUT ARGUMENTS AND CONFIG 101 | # ============================================================================= 102 | 103 | args = parser.parse_args() 104 | # keys: {"dataset", "model", "solver"} 105 | config = yaml.load(open(args.config_yml)) 106 | 107 | if isinstance(args.gpu_ids, int): 108 | args.gpu_ids = [args.gpu_ids] 109 | device = ( 110 | torch.device("cuda", args.gpu_ids[0]) 111 | if args.gpu_ids[0] >= 0 112 | else torch.device("cpu") 113 | ) 114 | # ============================================================================= 115 | # SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER 116 | # ============================================================================= 117 | 118 | train_sample_dataset = VisDialDataset( 119 | config["dataset"], 120 | args.train_json, 121 | overfit=args.overfit, 122 | in_memory=args.in_memory, 123 | return_options=True, 124 | add_boundary_toks=False, 125 | sample_flag=True # only train on data with dense annotations 126 | ) 127 | train_sample_dataloader = DataLoader( 128 | train_sample_dataset, 129 | batch_size=config["solver"]["batch_size"], 130 | num_workers=args.cpu_workers, 131 | shuffle=True, 132 | ) 133 | 134 | val_dataset = VisDialDataset( 135 | config["dataset"], 136 | args.val_json, 137 | args.val_dense_json, 138 | overfit=args.overfit, 139 | in_memory=args.in_memory, 140 | return_options=True, 141 | add_boundary_toks=False, 142 | sample_flag=False 143 | ) 144 | val_dataloader = DataLoader( 145 | val_dataset, 146 | batch_size=config["solver"]["batch_size"], 147 | num_workers=args.cpu_workers, 148 | ) 149 | 150 | # Pass vocabulary to construct Embedding layer. 151 | encoder_dict = Dict_Encoder(config["model"], val_dataset.vocabulary) 152 | 153 | # Share word embedding between encoder and decoder. 154 | glove = np.load('data/glove.npy') 155 | encoder_dict.word_embed.weight.data = torch.tensor(glove) 156 | 157 | # Wrap encoder and decoder in a model. 158 | model = encoder_dict.to(device) 159 | if -1 not in args.gpu_ids: 160 | model = nn.DataParallel(model, args.gpu_ids) 161 | 162 | criterion = nn.CrossEntropyLoss() 163 | iterations = len(train_sample_dataset)// config["solver"]["batch_size"] + 1 # 迭代次数 164 | 165 | def lr_lambda_fun(current_iteration: int) -> float: 166 | """Returns a learning rate multiplier. 167 | 168 | Till `warmup_epochs`, learning rate linearly increases to `initial_lr`, 169 | and then gets multiplied by `lr_gamma` every time a milestone is crossed. 
170 | """ 171 | current_epoch = float(current_iteration) / iterations 172 | if current_epoch <= config["solver"]["warmup_epochs"]: 173 | alpha = current_epoch / float(config["solver"]["warmup_epochs"]) 174 | return config["solver"]["warmup_factor"] * (1.0 - alpha) + alpha 175 | else: 176 | idx = bisect(config["solver"]["lr_milestones"], current_epoch) 177 | return pow(config["solver"]["lr_gamma"], idx) 178 | 179 | optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"]) 180 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun) 181 | 182 | start_time = datetime.datetime.strftime(datetime.datetime.utcnow(), '%d-%b-%Y-%H:%M:%S') 183 | checkpoint_dirpath = args.save_dirpath 184 | if checkpoint_dirpath == 'checkpoints/': 185 | checkpoint_dirpath += '%s+%s/%s' % (config["model"]["encoder"], config["model"]["decoder"], start_time) 186 | if args.save_model: 187 | summary_writer = SummaryWriter(log_dir=checkpoint_dirpath) 188 | checkpoint_manager = CheckpointManager(model, optimizer, checkpoint_dirpath, config=config) 189 | 190 | sparse_metrics = SparseGTMetrics() 191 | ndcg = NDCG() 192 | # If loading from checkpoint, adjust start epoch and load parameters. 193 | if args.load_pthpath == "": 194 | start_epoch = 0 195 | else: 196 | start_epoch = 0 197 | 198 | model_state_dict, optimizer_state_dict = load_checkpoint(args.load_pthpath) 199 | if isinstance(model, nn.DataParallel): 200 | model.module.load_state_dict(model_state_dict) 201 | else: 202 | model.load_state_dict(model_state_dict) 203 | print("Loaded model from {}".format(args.load_pthpath)) 204 | 205 | def get_1round_batch_data(batch, rnd): 206 | temp_train_batch = {} 207 | for key in batch: 208 | if key in ['img_feat']: 209 | temp_train_batch[key] = batch[key].to(device) 210 | elif key in ['ques', 'opt', 'ques_len', 'opt_len', 'ans_ind']: 211 | temp_train_batch[key] = batch[key][:, rnd].to(device) 212 | elif key in ['hist_len', 'hist']: 213 | temp_train_batch[key] = batch[key][:, :rnd + 1].to(device) 214 | else: 215 | pass 216 | return temp_train_batch 217 | 218 | def get_1round_idx_batch_data(batch, rnd, idx): ##to get 1 round data with batch_size = 1 219 | temp_train_batch = {} 220 | for key in batch: 221 | if key in ['img_feat']: 222 | temp_train_batch[key] = batch[key][idx * 2:idx * 2 + 2].to(device) 223 | elif key in ['ques', 'opt', 'ques_len', 'opt_len', 'ans_ind']: 224 | temp_train_batch[key] = batch[key][idx * 2:idx * 2 + 2][:, rnd].to(device) 225 | elif key in ['hist_len', 'hist']: 226 | temp_train_batch[key] = batch[key][idx * 2:idx * 2 + 2][:, :rnd + 1].to(device) 227 | else: 228 | pass 229 | return temp_train_batch 230 | 231 | global_iteration_step = start_epoch * iterations 232 | ###load ndcg label list 233 | samplefile = open('data/visdial_1.0_train_dense_sample.json', 'r') 234 | sample = json.loads(samplefile.read()) 235 | samplefile.close() 236 | ndcg_id_list = [] 237 | for idx in range(len(sample)): 238 | ndcg_id_list.append(sample[idx]['image_id']) 239 | 240 | for epoch in range(start_epoch, config["solver"]["num_epochs"]): 241 | model.train() 242 | print('Training for epoch:', epoch, ' time:', time.asctime(time.localtime(time.time()))) 243 | count_loss = 0.0 244 | for k, batch in enumerate(train_sample_dataloader): 245 | ##### find the round 246 | batchsize = batch['img_ids'].shape[0] 247 | grad_dict = {} 248 | optimizer.zero_grad() 249 | for idx in range(int(batchsize / 2)): 250 | for b in range(2): # here is because with the batch_size = 1 will raise error 251 | sample_idx = 
ndcg_id_list.index(batch['img_ids'][idx * 2 + b].item()) 252 | final_round = sample[sample_idx]['round_id'] - 1 253 | rnd = final_round 254 | temp_train_batch = get_1round_idx_batch_data(batch, rnd, idx) 255 | output = model(temp_train_batch)[b] ## this is only for avoid bug, no other meanings 256 | target = batch["ans_ind"][b, rnd].to(device) 257 | rs_score = sample[sample_idx]['relevance'] 258 | cuda_device = output.device 259 | #use R4 loss 260 | batch_loss = criterion(output.view(-1, output.size(-1)), target.view(-1)) 261 | output_sig = torch.sigmoid(output) 262 | rs_score = torch.tensor(rs_score).to(cuda_device) 263 | rs_score = F.normalize(rs_score.unsqueeze(0), p=1).squeeze(0) # norm 264 | max_rs_score = torch.max(rs_score) 265 | for rs_idx in range(len(rs_score)): 266 | a = rs_score[rs_idx] 267 | s = output_sig[rs_idx] 268 | if s != 1: # s cannot be 1 269 | batch_loss += - 20 * (a * torch.log(s) + (max_rs_score - a) * torch.log(1 - s)) 270 | batch_loss = batch_loss / len(rs_score) 271 | ###end loss computation 272 | if batch_loss != 0: # prevent batch loss = 0 273 | batch_loss.backward() 274 | count_loss += batch_loss.data.cpu().numpy() 275 | optimizer.step() ##accumulate the whole grads in a batch (default is 12) and update weights 276 | 277 | if k % int(iterations / 10) == 0 and k != 0: 278 | mean_loss = (count_loss / int(iterations / 10)) / 10.0 279 | print('(step', k, 'in', int(iterations), ') mean_loss:', mean_loss, 'Time:', 280 | time.asctime(time.localtime(time.time())), 'lr:', optimizer.param_groups[0]["lr"]) 281 | count_loss = 0.0 282 | if args.save_model: 283 | summary_writer.add_scalar("train/loss", batch_loss, global_iteration_step) 284 | summary_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_iteration_step) 285 | scheduler.step(global_iteration_step) 286 | global_iteration_step += 1 287 | # if k == 5: #for debug 288 | # break 289 | if args.save_model: 290 | checkpoint_manager.step() 291 | 292 | model.eval() 293 | for i, batch in enumerate(val_dataloader): 294 | batchsize = batch['img_ids'].shape[0] 295 | rnd = 0 296 | temp_train_batch = get_1round_batch_data(batch, rnd) 297 | output = model(temp_train_batch).view(-1, 1, 100).detach() 298 | for rnd in range(1, 10): 299 | temp_train_batch = get_1round_batch_data(batch, rnd) 300 | output = torch.cat((output, model(temp_train_batch).view(-1, 1, 100).detach()), dim=1) 301 | sparse_metrics.observe(output, batch["ans_ind"]) 302 | if "relevance" in batch: 303 | output = output[torch.arange(output.size(0)), batch["round_id"] - 1, :] 304 | ndcg.observe(output.view(-1, 100), batch["relevance"].contiguous().view(-1, 100)) 305 | # if i > 5: #for debug(like the --overfit) 306 | # break 307 | all_metrics = {} 308 | all_metrics.update(sparse_metrics.retrieve(reset=True)) 309 | all_metrics.update(ndcg.retrieve(reset=True)) 310 | for metric_name, metric_value in all_metrics.items(): 311 | print(f"{metric_name}: {metric_value}") 312 | model.train() -------------------------------------------------------------------------------- /hidden_dict/train_dict_stage3.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import time 4 | import datetime 5 | from tensorboardX import SummaryWriter 6 | import torch 7 | from torch import nn, optim 8 | from torch.optim import lr_scheduler 9 | from torch.utils.data import DataLoader 10 | import yaml 11 | from bisect import bisect 12 | from torch.nn import functional as F 13 | import random 14 | import sys 15 | 
sys.path.append('../') 16 | from visdialch.data.dataset import VisDialDataset 17 | from visdialch.encoders import Encoder 18 | from visdialch.decoders import Decoder 19 | from visdialch.metrics import SparseGTMetrics, NDCG 20 | from visdialch.model import EncoderDecoderModel 21 | from visdialch.utils.checkpointing import CheckpointManager, load_checkpoint 22 | from visdialch.encoders.dict_encoder import Dict_Encoder 23 | import os 24 | os.chdir('../') 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | "--config-yml", 28 | default="hidden_dict/baseline_stage2.yml", 29 | help="Path to a config file listing reader, model and solver parameters.", 30 | ) 31 | parser.add_argument( 32 | "--config-dict-yml", 33 | default="hidden_dict/dict2.yml", 34 | help="Path to a config file listing reader, model and solver parameters.", 35 | ) 36 | parser.add_argument( 37 | "--train-json", 38 | default="data/visdial_1.0_train.json", 39 | help="Path to json file containing VisDial v1.0 training data.", 40 | ) 41 | parser.add_argument( 42 | "--val-json", 43 | default="data/visdial_1.0_val.json", 44 | help="Path to json file containing VisDial v1.0 validation data.", 45 | ) 46 | parser.add_argument( 47 | "--val-dense-json", 48 | default="data/visdial_1.0_val_dense_annotations.json", 49 | help="Path to json file containing VisDial v1.0 validation dense ground " 50 | "truth annotations.", 51 | ) 52 | parser.add_argument_group( 53 | "Arguments independent of experiment reproducibility" 54 | ) 55 | parser.add_argument( 56 | "--gpu-ids", 57 | nargs="+", 58 | type=int, 59 | default=[3], 60 | help="List of ids of GPUs to use.", 61 | ) 62 | parser.add_argument( 63 | "--cpu-workers", 64 | type=int, 65 | default=4, 66 | help="Number of CPU workers for dataloader.", 67 | ) 68 | parser.add_argument( 69 | "--overfit", 70 | action="store_true", 71 | help="Overfit model on 5 examples, meant for debugging.", 72 | ) 73 | parser.add_argument( 74 | "--in-memory", 75 | action="store_true", 76 | help="Load the whole dataset and pre-extracted image features in memory. " 77 | "Use only in presence of large RAM, atleast few tens of GBs.", 78 | ) 79 | parser.add_argument( 80 | "--load-pthpath", 81 | default='', 82 | help="To continue training, path to .pth file of saved checkpoint.", 83 | ) 84 | parser.add_argument( 85 | "--load-dict-pthpath", 86 | default='', 87 | help="To continue training, path to .pth file of saved checkpoint.", 88 | ) 89 | 90 | 91 | manualSeed = random.randint(1, 10000) 92 | print("Random Seed: ", manualSeed) 93 | torch.manual_seed(manualSeed) 94 | torch.cuda.manual_seed_all(manualSeed) 95 | torch.backends.cudnn.benchmark = False 96 | torch.backends.cudnn.deterministic = True 97 | 98 | # ============================================================================= 99 | # INPUT ARGUMENTS AND CONFIG 100 | # ============================================================================= 101 | 102 | args = parser.parse_args() 103 | # keys: {"dataset", "model", "solver"} 104 | config = yaml.load(open(args.config_yml)) 105 | config_dict = yaml.load(open(args.config_dict_yml)) 106 | 107 | if isinstance(args.gpu_ids, int): 108 | args.gpu_ids = [args.gpu_ids] 109 | device = ( 110 | torch.device("cuda", args.gpu_ids[0]) 111 | if args.gpu_ids[0] >= 0 112 | else torch.device("cpu") 113 | ) 114 | 115 | # Print config and args. 
116 | print(yaml.dump(config, default_flow_style=False)) 117 | for arg in vars(args): 118 | print("{:<20}: {}".format(arg, getattr(args, arg))) 119 | 120 | # ============================================================================= 121 | # SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER 122 | # ============================================================================= 123 | 124 | train_sample_dataset = VisDialDataset( 125 | config["dataset"], 126 | args.train_json, 127 | overfit=args.overfit, 128 | in_memory=args.in_memory, 129 | return_options=True, 130 | add_boundary_toks=False, 131 | sample_flag=True # only train on data with dense annotations 132 | ) 133 | train_sample_dataloader = DataLoader( 134 | train_sample_dataset, 135 | batch_size=config["solver"]["batch_size"], 136 | num_workers=args.cpu_workers, 137 | shuffle=True, 138 | ) 139 | 140 | val_dataset = VisDialDataset( 141 | config["dataset"], 142 | args.val_json, 143 | args.val_dense_json, 144 | overfit=args.overfit, 145 | in_memory=args.in_memory, 146 | return_options=True, 147 | add_boundary_toks=False, 148 | sample_flag=False 149 | ) 150 | val_dataloader = DataLoader( 151 | val_dataset, 152 | batch_size=config["solver"]["batch_size"], 153 | num_workers=args.cpu_workers, 154 | ) 155 | 156 | # Pass vocabulary to construct Embedding layer. 157 | encoder_dict = Dict_Encoder(config_dict["model"], train_sample_dataset.vocabulary) 158 | encoder = Encoder(config["model"], train_sample_dataset.vocabulary) 159 | decoder = Decoder(config["model"], train_sample_dataset.vocabulary) 160 | decoder.word_embed = encoder.word_embed 161 | model_dict = encoder_dict.to(device) 162 | # Wrap encoder and decoder in a model. 163 | model = EncoderDecoderModel(encoder, decoder).to(device) 164 | if -1 not in args.gpu_ids: 165 | model = nn.DataParallel(model, args.gpu_ids) 166 | 167 | criterion = nn.CrossEntropyLoss() 168 | criterion_bce = nn.BCEWithLogitsLoss() 169 | iterations = len(train_sample_dataset) // config["solver"]["batch_size"] + 1 170 | 171 | 172 | def lr_lambda_fun(current_iteration: int) -> float: 173 | """Returns a learning rate multiplier. 174 | 175 | Till `warmup_epochs`, learning rate linearly increases to `initial_lr`, 176 | and then gets multiplied by `lr_gamma` every time a milestone is crossed. 177 | """ 178 | current_epoch = float(current_iteration) / iterations 179 | if current_epoch < config["solver"]["warmup_epochs"]: 180 | alpha = current_epoch / float(config["solver"]["warmup_epochs"]) 181 | return config["solver"]["warmup_factor"] * (1.0 - alpha) + alpha 182 | else: 183 | idx = bisect(config["solver"]["lr_milestones"], current_epoch) 184 | return pow(config["solver"]["lr_gamma"], idx) 185 | 186 | 187 | optimizer = optim.Adamax([ 188 | {'params':model.parameters(), 'lr':config["solver"]["initial_lr"]}, 189 | {'params':model_dict.parameters(), 'lr':config["solver"]["initial_lr"]*0.05} #model is model_dict 190 | ]) 191 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun) # 可以在一个组里面调节lr参数 192 | start_time = datetime.datetime.strftime(datetime.datetime.utcnow(), '%d-%b-%Y-%H:%M:%S') 193 | sparse_metrics = SparseGTMetrics() 194 | ndcg = NDCG() 195 | 196 | # If loading from checkpoint, adjust start epoch and load parameters. 
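# Stage 3 expects two checkpoints: --load-pthpath supplies the baseline
# encoder-decoder model and --load-dict-pthpath supplies the dictionary encoder
# pretrained by train_dict_stage1.py / train_dict_stage2.py. During both training
# and validation below, their scores are fused as output_base + 0.1 * output_dict.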
197 | 198 | start_epoch = 0 199 | model_state_dict, _ = load_checkpoint(args.load_pthpath) 200 | if isinstance(model, nn.DataParallel): 201 | model.module.load_state_dict(model_state_dict) 202 | else: 203 | model.load_state_dict(model_state_dict) 204 | model_state_dict, _ = load_checkpoint(args.load_dict_pthpath) 205 | if isinstance(model_dict, nn.DataParallel): 206 | model_dict.module.load_state_dict(model_state_dict) 207 | else: 208 | model_dict.load_state_dict(model_state_dict) 209 | 210 | ###start training and set functions used in training 211 | def get_1round_batch_data(batch, rnd): ##to get 1 round data 212 | temp_train_batch = {} 213 | for key in batch: 214 | if key in ['img_feat']: 215 | temp_train_batch[key] = batch[key].to(device) 216 | elif key in ['ques', 'opt', 'ques_len', 'opt_len', 'ans_ind']: 217 | temp_train_batch[key] = batch[key][:, rnd].to(device) 218 | elif key in ['hist_len', 'hist']: 219 | temp_train_batch[key] = batch[key][:, :rnd + 1].to(device) 220 | else: 221 | pass 222 | return temp_train_batch 223 | 224 | 225 | def get_1round_idx_batch_data(batch, rnd, idx): ##to get 1 round data with batch_size = 1 226 | temp_train_batch = {} 227 | for key in batch: 228 | if key in ['img_feat']: 229 | temp_train_batch[key] = batch[key][idx * 2:idx * 2 + 2].to(device) 230 | elif key in ['ques', 'opt', 'ques_len', 'opt_len', 'ans_ind']: 231 | temp_train_batch[key] = batch[key][idx * 2:idx * 2 + 2][:, rnd].to(device) 232 | elif key in ['hist_len', 'hist']: 233 | temp_train_batch[key] = batch[key][idx * 2:idx * 2 + 2][:, :rnd + 1].to(device) 234 | else: 235 | pass 236 | return temp_train_batch 237 | 238 | global_iteration_step = start_epoch * iterations 239 | ###load ndcg label list 240 | samplefile = open('data/visdial_1.0_train_dense_sample.json', 'r') 241 | sample = json.loads(samplefile.read()) 242 | samplefile.close() 243 | ndcg_id_list = [] 244 | for idx in range(len(sample)): 245 | ndcg_id_list.append(sample[idx]['image_id']) 246 | 247 | 248 | for epoch in range(start_epoch, config["solver"]["num_epochs"]): 249 | model.train() 250 | print('Training for epoch:', epoch, ' time:', time.asctime(time.localtime(time.time()))) 251 | count_loss = 0.0 252 | for k, batch in enumerate(train_sample_dataloader): 253 | ##### find the round 254 | batchsize = batch['img_ids'].shape[0] 255 | grad_dict = {} 256 | optimizer.zero_grad() 257 | for idx in range(int(batchsize / 2)): 258 | for b in range(2): # here is because with the batch_size = 1 will raise error 259 | sample_idx = ndcg_id_list.index(batch['img_ids'][idx * 2 + b].item()) 260 | final_round = sample[sample_idx]['round_id'] - 1 261 | rnd = final_round 262 | temp_train_batch = get_1round_idx_batch_data(batch, rnd, idx) 263 | output_base = model(temp_train_batch)[b] ## this is only for avoid bug, no other meanings 264 | output_dict = model_dict(temp_train_batch)[b] 265 | output = output_base + 0.1 * output_dict 266 | target = batch["ans_ind"][b, rnd].to(device) 267 | rs_score = sample[sample_idx]['relevance'] 268 | cuda_device = output.device 269 | # R3 loss 270 | batch_loss = criterion(output.view(-1, output.size(-1)), target.view(-1)) 271 | rs_score = torch.tensor(rs_score).to(cuda_device) 272 | exp_sum = torch.sum(torch.exp(output[[idx for idx in range(len(rs_score)) if rs_score[idx] < 1]])) 273 | loss_num_count = 0 274 | for rs_idx in range(len(rs_score)): # for the candidate with relevance score 1 275 | if rs_score[rs_idx] > 0.8: 276 | exp_sum = exp_sum + torch.exp(output[rs_idx]) 277 | batch_loss += (-output[rs_idx] + 
torch.log(exp_sum)) 278 | loss_num_count += 1 279 | exp_sum = exp_sum - torch.exp(output[rs_idx]) 280 | exp_sum_2 = torch.sum( 281 | torch.exp(output[[idx for idx in range(len(rs_score)) if rs_score[idx] < 0.4]])) 282 | for rs_idx in range(len(rs_score)): # for the candidate with relevance score 0.5 283 | if rs_score[rs_idx] < 0.8 and rs_score[rs_idx] > 0.4: 284 | exp_sum_2 = exp_sum_2 + torch.exp(output[rs_idx]) 285 | batch_loss += (-output[rs_idx] + torch.log(exp_sum_2)) 286 | loss_num_count += 1 287 | exp_sum_2 = exp_sum_2 - torch.exp(output[rs_idx]) 288 | batch_loss = batch_loss / (loss_num_count + 1) 289 | if batch_loss != 0: # prevent batch loss = 0 290 | batch_loss.backward() 291 | count_loss += batch_loss.data.cpu().numpy() 292 | optimizer.step() ##accumulate the whole grads in a batch (default is 12) and update weights 293 | optimizer.zero_grad() 294 | 295 | if k % int(iterations / 5) == 0 and k != 0: 296 | mean_loss = (count_loss / (float(iterations) / 5)) / 10.0 297 | print('(step', k, 'in', int(iterations), ') mean_loss:', mean_loss, 'Time:', 298 | time.asctime(time.localtime(time.time())), 'lr:', optimizer.param_groups[0]["lr"]) 299 | count_loss = 0.0 300 | scheduler.step(global_iteration_step) 301 | global_iteration_step += 1 302 | # if k == 5: #for debug 303 | # break 304 | 305 | model.eval() 306 | for i, batch in enumerate(val_dataloader): 307 | batchsize = batch['img_ids'].shape[0] 308 | rnd = 0 309 | temp_train_batch = get_1round_batch_data(batch, rnd) 310 | output_temp = model(temp_train_batch).view(-1, 1, 100).detach() 311 | output_dict = model_dict(temp_train_batch).view(-1, 1, 100).detach() 312 | output = output_temp + 0.1 * output_dict 313 | optimizer.zero_grad() 314 | for rnd in range(1, 10): # get 10 rounds outputs to evaluate 315 | temp_train_batch = get_1round_batch_data(batch, rnd) 316 | output_temp = model(temp_train_batch).view(-1, 1, 100).detach() 317 | output_dict = model_dict(temp_train_batch).view(-1, 1, 100).detach() 318 | output = torch.cat((output, output_temp + 0.1 * output_dict), dim=1) 319 | optimizer.zero_grad() 320 | sparse_metrics.observe(output, batch["ans_ind"]) 321 | if "relevance" in batch: 322 | output = output[torch.arange(output.size(0)), batch["round_id"] - 1, :] 323 | ndcg.observe(output.view(-1, 100), batch["relevance"].contiguous().view(-1, 100)) 324 | # if i == 5: #for debug 325 | # break 326 | all_metrics = {} 327 | all_metrics.update(sparse_metrics.retrieve(reset=True)) 328 | all_metrics.update(ndcg.retrieve(reset=True)) 329 | for metric_name, metric_value in all_metrics.items(): 330 | print(f"{metric_name}: {metric_value}") 331 | model.train() 332 | -------------------------------------------------------------------------------- /question_type/qt.yml: -------------------------------------------------------------------------------- 1 | # Dataset reader arguments 2 | dataset: 3 | image_features_train_h5: 'data/features_faster_rcnn_x101_train.h5' 4 | image_features_val_h5: 'data/features_faster_rcnn_x101_val.h5' 5 | image_features_test_h5: 'data/features_faster_rcnn_x101_test.h5' 6 | word_counts_json: 'data/visdial_1.0_word_counts_train.json' 7 | tokens_path: 'data/tokens.json' 8 | img_norm: 1 9 | concat_history: false 10 | max_sequence_length: 20 11 | vocab_min_count: 5 12 | 13 | 14 | # Model related arguments 15 | model: 16 | encoder: 'baseline_encoder_withP1' 17 | decoder: 'disc_qt' 18 | 19 | img_feature_size: 2048 20 | word_embedding_size: 300 21 | lstm_hidden_size: 512 22 | lstm_num_layers: 2 23 | head_num: 3 24 | dropout: 0.4 # 
change 0.4 to 0.3, because a smaller lr (about 0.3) is better 25 | dropout_fc: 0.25 # change 0.25 to 0.1, because a smaller lr (about 0.1) is better 26 | 27 | 28 | # Optimization related arguments 29 | solver: 30 | batch_size: 128 # 64 x num_gpus is a good rule of thumb 31 | num_epochs: 6 32 | initial_lr: 0.004 # general is 0.004 33 | training_splits: "train" 34 | lr_gamma: 0.4 35 | lr_milestones: # epochs when lr —> lr * lr_gamma 36 | - 2 37 | - 5 38 | warmup_factor: 0.2 39 | warmup_epochs: 1 40 | -------------------------------------------------------------------------------- /question_type/train_qt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import datetime 4 | import numpy as np 5 | from tensorboardX import SummaryWriter 6 | import torch 7 | from torch import nn, optim 8 | from torch.optim import lr_scheduler 9 | from torch.utils.data import DataLoader 10 | import yaml 11 | from bisect import bisect 12 | from torch.nn import functional as F 13 | import os 14 | os.chdir("../") 15 | import sys 16 | sys.path.append("./") 17 | import random 18 | from visdialch.data.dataset_qt import VisDialDataset 19 | from visdialch.encoders import Encoder 20 | from visdialch.decoders import Decoder 21 | from visdialch.metrics import SparseGTMetrics, NDCG 22 | from visdialch.model import EncoderDecoderModel 23 | from visdialch.utils.checkpointing import CheckpointManager, load_checkpoint 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | "--config-yml", 28 | default="question_type/qt.yml", 29 | help="Path to a config file listing reader, model and solver parameters.", 30 | ) 31 | parser.add_argument( 32 | "--train-json", 33 | default="data/visdial_1.0_train.json", 34 | help="Path to json file containing VisDial v1.0 training data.", 35 | ) 36 | parser.add_argument( 37 | "--val-json", 38 | default="data/visdial_1.0_val.json", 39 | help="Path to json file containing VisDial v1.0 validation data.", 40 | ) 41 | parser.add_argument( 42 | "--val-dense-json", 43 | default="data/visdial_1.0_val_dense_annotations.json", 44 | help="Path to json file containing VisDial v1.0 validation dense ground " 45 | "truth annotations.", 46 | ) 47 | parser.add_argument_group( 48 | "Arguments independent of experiment reproducibility" 49 | ) 50 | parser.add_argument( 51 | "--gpu-ids", 52 | nargs="+", 53 | type=int, 54 | default=[0, 1], 55 | help="List of ids of GPUs to use.", 56 | ) 57 | parser.add_argument( 58 | "--cpu-workers", 59 | type=int, 60 | default=8, 61 | help="Number of CPU workers for dataloader.", 62 | ) 63 | parser.add_argument( 64 | "--overfit", 65 | action="store_true", 66 | help="Overfit model on 5 examples, meant for debugging.", 67 | ) 68 | parser.add_argument( 69 | "--validate", 70 | action="store_true", 71 | help="Whether to validate on val split after every epoch.", 72 | ) 73 | parser.add_argument( 74 | "--in-memory", 75 | action="store_true", 76 | help="Load the whole dataset and pre-extracted image features in memory. 
" 77 | "Use only in presence of large RAM, atleast few tens of GBs.", 78 | ) 79 | 80 | parser.add_argument_group("Checkpointing related arguments") 81 | parser.add_argument( 82 | "--save-dirpath", 83 | default="checkpoints/", 84 | help="Path of directory to create checkpoint directory and save " 85 | "checkpoints.", 86 | ) 87 | parser.add_argument( 88 | "--load-pthpath", 89 | default='checkpoints/baseline_withP1_checkpiont5.pth', 90 | help="To continue training, path to .pth file of saved checkpoint.", 91 | ) 92 | parser.add_argument( 93 | "--save-model", 94 | action="store_true", 95 | help="To make the dir clear", 96 | ) 97 | 98 | manualSeed = random.randint(1, 10000) 99 | print("Random Seed: ", manualSeed) 100 | torch.manual_seed(manualSeed) 101 | torch.cuda.manual_seed_all(manualSeed) 102 | torch.backends.cudnn.benchmark = False 103 | torch.backends.cudnn.deterministic = True 104 | 105 | # ============================================================================= 106 | # INPUT ARGUMENTS AND CONFIG 107 | # ============================================================================= 108 | 109 | args = parser.parse_args() 110 | # keys: {"dataset", "model", "solver"} 111 | config = yaml.load(open(args.config_yml)) 112 | 113 | if isinstance(args.gpu_ids, int): 114 | args.gpu_ids = [args.gpu_ids] 115 | device = ( 116 | torch.device("cuda", args.gpu_ids[0]) 117 | if args.gpu_ids[0] >= 0 118 | else torch.device("cpu") 119 | ) 120 | 121 | # Print config and args. 122 | print(yaml.dump(config, default_flow_style=False)) 123 | for arg in vars(args): 124 | print("{:<20}: {}".format(arg, getattr(args, arg))) 125 | 126 | # ============================================================================= 127 | # SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER 128 | # ============================================================================= 129 | 130 | train_dataset = VisDialDataset( 131 | config["dataset"], 132 | args.train_json, 133 | overfit=args.overfit, 134 | in_memory=args.in_memory, 135 | return_options=True, 136 | add_boundary_toks=False, 137 | sample_flag = False 138 | ) 139 | train_dataloader = DataLoader( 140 | train_dataset, 141 | batch_size=config["solver"]["batch_size"], 142 | num_workers=args.cpu_workers, 143 | shuffle=True, 144 | ) 145 | val_dataset = VisDialDataset( 146 | config["dataset"], 147 | args.val_json, 148 | args.val_dense_json, 149 | overfit=args.overfit, 150 | in_memory=args.in_memory, 151 | return_options=True, 152 | add_boundary_toks=False, 153 | sample_flag = False 154 | ) 155 | val_dataloader = DataLoader( 156 | val_dataset, 157 | batch_size=config["solver"]["batch_size"], 158 | num_workers=args.cpu_workers, 159 | ) 160 | 161 | # Pass vocabulary to construct Embedding layer. 162 | encoder = Encoder(config["model"], train_dataset.vocabulary) 163 | decoder = Decoder(config["model"], train_dataset.vocabulary) 164 | print("Encoder: {}".format(config["model"]["encoder"])) 165 | print("Decoder: {}".format(config["model"]["decoder"])) 166 | 167 | # Share word embedding between encoder and decoder. 168 | if args.load_pthpath == "": 169 | print('load glove') 170 | decoder.word_embed = encoder.word_embed 171 | glove = np.load('data/glove.npy') 172 | encoder.word_embed.weight.data = torch.tensor(glove) 173 | 174 | # Wrap encoder and decoder in a model. 
175 | model = EncoderDecoderModel(encoder, decoder).to(device) 176 | if -1 not in args.gpu_ids: 177 | model = nn.DataParallel(model, args.gpu_ids) 178 | 179 | criterion = nn.CrossEntropyLoss() 180 | criterion_bce = nn.BCEWithLogitsLoss() 181 | iterations = len(train_dataset)// config["solver"]["batch_size"] + 1 #迭代次数 182 | 183 | def lr_lambda_fun(current_iteration: int) -> float: 184 | """Returns a learning rate multiplier. 185 | 186 | Till `warmup_epochs`, learning rate linearly increases to `initial_lr`, 187 | and then gets multiplied by `lr_gamma` every time a milestone is crossed. 188 | """ 189 | current_epoch = float(current_iteration) / iterations 190 | if current_epoch <= config["solver"]["warmup_epochs"]: 191 | alpha = current_epoch / float(config["solver"]["warmup_epochs"]) 192 | return config["solver"]["warmup_factor"] * (1.0 - alpha) + alpha 193 | else: 194 | idx = bisect(config["solver"]["lr_milestones"], current_epoch) 195 | return pow(config["solver"]["lr_gamma"], idx) 196 | 197 | optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"]) 198 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun) 199 | 200 | # ============================================================================= 201 | # SETUP BEFORE TRAINING LOOP 202 | # ============================================================================= 203 | start_time = datetime.datetime.strftime(datetime.datetime.utcnow(), '%d-%b-%Y-%H:%M:%S') 204 | checkpoint_dirpath = args.save_dirpath 205 | if checkpoint_dirpath == 'checkpoints/': 206 | checkpoint_dirpath += '%s+%s/%s' % (config["model"]["encoder"], config["model"]["decoder"], start_time) 207 | if args.save_model: 208 | summary_writer = SummaryWriter(log_dir=checkpoint_dirpath) 209 | checkpoint_manager = CheckpointManager(model, optimizer, checkpoint_dirpath, config=config) 210 | 211 | sparse_metrics = SparseGTMetrics() 212 | ndcg = NDCG() 213 | 214 | # If loading from checkpoint, adjust start epoch and load parameters. 215 | if args.load_pthpath == "": 216 | start_epoch = 0 217 | else: 218 | start_epoch = 0 219 | model_state_dict, _ = load_checkpoint(args.load_pthpath) 220 | if isinstance(model, nn.DataParallel): 221 | model.module.load_state_dict(model_state_dict) 222 | else: 223 | model.load_state_dict(model_state_dict) 224 | print("Loaded model from {}".format(args.load_pthpath)) 225 | 226 | # ============================================================================= 227 | # TRAINING LOOP 228 | # ============================================================================= 229 | 230 | # Forever increasing counter to keep track of iterations (for tensorboard log). 
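# Note: in the training loop that follows, the decoder returns both the usual
# 100-way option scores and per-candidate question-type (qt) scores. The qt scores
# are turned into a soft target by keeping them only where the model's own ranking
# places a candidate in the top 10 (weight 1.0) or at ranks 10-19 (weight 0.8) and
# zeroing the rest; that target drives a BCEWithLogitsLoss alongside a 0.4-weighted
# cross-entropy that protects MRR (a commented-out R4 / normalized-BCE variant is
# kept below for reference). The function here is a vectorised sketch of the target
# construction, equivalent to the per-sample sort/unsort in the loop; it is not
# used anywhere in this repository.
def masked_qt_target_sketch(output: torch.Tensor, qt_score: torch.Tensor) -> torch.Tensor:
    """output, qt_score: (batch, 100). Returns qt_score gated by the model's own ranking."""
    ranks = output.argsort(dim=1, descending=True).argsort(dim=1)  # 0 = highest-scored candidate
    weights = torch.zeros_like(qt_score)
    weights[ranks < 10] = 1.0
    weights[(ranks >= 10) & (ranks < 20)] = 0.8
    return qt_score * weights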
231 | global_iteration_step = start_epoch * iterations 232 | 233 | ###start training and set functions used in training 234 | def get_1round_batch_data(batch,rnd): 235 | temp_train_batch = {} 236 | for key in batch: 237 | if key in ['img_feat']: 238 | temp_train_batch[key] = batch[key].to(device) 239 | elif key in ['ques', 'opt', 'ques_len', 'opt_len', 'ans_ind','qt','opt_idx']: 240 | temp_train_batch[key] = batch[key][:, rnd].to(device) 241 | elif key in ['hist_len', 'hist']: 242 | temp_train_batch[key] = batch[key][:, :rnd + 1].to(device) 243 | else: 244 | pass 245 | return temp_train_batch 246 | 247 | for epoch in range(start_epoch, config["solver"]["num_epochs"]): 248 | print('Training for epoch:',epoch,' time:', time.asctime(time.localtime(time.time()))) 249 | count_loss = 0.0 250 | model.train() 251 | for i, batch in enumerate(train_dataloader): 252 | for rnd in range(10): 253 | temp_train_batch = get_1round_batch_data(batch,rnd) 254 | optimizer.zero_grad() 255 | output,qt_score = model(temp_train_batch) 256 | batchsize_temp = output.size(0) 257 | target = batch["ans_ind"][:, rnd].to(device) 258 | sorted_output, indexes = torch.sort(output,dim=1,descending=True) 259 | _, rank = torch.sort(indexes, dim=1, descending=False) 260 | mask = torch.full_like(output[0], 0, device=output.device) 261 | mask[:10] = 1 #select some top candidates for better usage of qt scores 262 | mask[10:20] = 0.8 #according to the ranking to give the scores (updated trick) 263 | for b in range(batchsize_temp): 264 | sorted_qt = qt_score[b][indexes[b]] 265 | sorted_qt = sorted_qt * mask 266 | qt_score[b] = sorted_qt[rank[b]] 267 | batch_loss = 0.4 * criterion(output.view(-1, output.size(-1)), target.view(-1)) #for keep mrr, the weight can be adjusted 268 | batch_loss += criterion_bce(output, qt_score) 269 | # output_sig = torch.sigmoid(output) #loss can change to R4 270 | # qt_score = torch.tensor(qt_score).to(output.device) 271 | # qt_score = F.normalize(qt_score.unsqueeze(0), p=1).squeeze(0) # norm 272 | # max_qt_score = torch.max(qt_score) 273 | # for rs_idx in range(len(qt_score)): 274 | # a = qt_score[rs_idx] 275 | # s = output_sig[rs_idx] 276 | # if s != 1: # s cannot be 1 277 | # batch_loss += - 20 * (a * torch.log(s) + (max_qt_score - a) * torch.log(1 - s)) 278 | # batch_loss = batch_loss / len(qt_score) 279 | batch_loss.backward() 280 | count_loss += batch_loss.data.cpu().numpy() 281 | optimizer.step() 282 | 283 | if i % int(iterations / 10) == 0 and i != 0: 284 | mean_loss = (count_loss / float(iterations / 10)) / 10.0 285 | print('(step', i, 'in', int(iterations), ') mean_loss:', mean_loss, 'Time:',time.asctime(time.localtime(time.time())),'lr:',optimizer.param_groups[0]["lr"]) 286 | count_loss = 0.0 287 | if args.save_model: 288 | summary_writer.add_scalar("train/loss", batch_loss, global_iteration_step) 289 | summary_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_iteration_step) 290 | scheduler.step(global_iteration_step) 291 | global_iteration_step += 1 292 | # if i > 5: #for debug(like the --overfit) 293 | # break 294 | if args.save_model: 295 | checkpoint_manager.step() 296 | if args.validate: 297 | print(f"\nValidation after epoch {epoch}:") 298 | model.eval() 299 | for i, batch in enumerate(val_dataloader): 300 | batchsize = batch['img_ids'].shape[0] 301 | rnd = 0 302 | temp_train_batch = get_1round_batch_data(batch, rnd) 303 | output,_ = model(temp_train_batch) 304 | output = output.view(-1, 1, 100).detach() 305 | optimizer.zero_grad() 306 | for rnd in range(1,10): 307 | 
temp_train_batch = get_1round_batch_data(batch, rnd) 308 | output_temp, _ = model(temp_train_batch) 309 | output = torch.cat((output, output_temp.view(-1, 1, 100).detach()), dim = 1) 310 | optimizer.zero_grad() 311 | sparse_metrics.observe(output, batch["ans_ind"]) 312 | if "relevance" in batch: 313 | output = output[torch.arange(output.size(0)), batch["round_id"] - 1, :] 314 | ndcg.observe(output.view(-1,100), batch["relevance"].contiguous().view(-1,100)) 315 | # if i > 5: #for debug(like the --overfit) 316 | # break 317 | all_metrics = {} 318 | all_metrics.update(sparse_metrics.retrieve(reset=True)) 319 | all_metrics.update(ndcg.retrieve(reset=True)) 320 | for metric_name, metric_value in all_metrics.items(): 321 | print(f"{metric_name}: {metric_value}") 322 | if args.save_model: 323 | summary_writer.add_scalars("metrics", all_metrics, global_iteration_step) 324 | model.train() 325 | -------------------------------------------------------------------------------- /train_stage1_baseline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import datetime 4 | import numpy as np 5 | from tensorboardX import SummaryWriter 6 | import torch 7 | from torch import nn, optim 8 | from torch.optim import lr_scheduler 9 | from torch.utils.data import DataLoader 10 | import yaml 11 | from bisect import bisect 12 | import random 13 | from visdialch.data.dataset import VisDialDataset 14 | from visdialch.encoders import Encoder 15 | from visdialch.decoders import Decoder 16 | from visdialch.metrics import SparseGTMetrics, NDCG 17 | from visdialch.model import EncoderDecoderModel 18 | from visdialch.utils.checkpointing import CheckpointManager, load_checkpoint 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--config-yml", 23 | default="configs/baseline.yml", 24 | help="Path to a config file listing reader, model and solver parameters.", 25 | ) 26 | parser.add_argument( 27 | "--train-json", 28 | default="data/visdial_1.0_train.json", 29 | help="Path to json file containing VisDial v1.0 training data.", 30 | ) 31 | parser.add_argument( 32 | "--val-json", 33 | default="data/visdial_1.0_val.json", 34 | help="Path to json file containing VisDial v1.0 validation data.", 35 | ) 36 | parser.add_argument( 37 | "--val-dense-json", 38 | default="data/visdial_1.0_val_dense_annotations.json", 39 | help="Path to json file containing VisDial v1.0 validation dense ground " 40 | "truth annotations.", 41 | ) 42 | parser.add_argument_group( 43 | "Arguments independent of experiment reproducibility" 44 | ) 45 | parser.add_argument( 46 | "--gpu-ids", 47 | nargs="+", 48 | type=int, 49 | default=[0, 1], 50 | help="List of ids of GPUs to use.", 51 | ) 52 | parser.add_argument( 53 | "--cpu-workers", 54 | type=int, 55 | default=8, 56 | help="Number of CPU workers for dataloader.", 57 | ) 58 | parser.add_argument( 59 | "--overfit", 60 | action="store_true", 61 | help="Overfit model on 5 examples, meant for debugging.", 62 | ) 63 | parser.add_argument( 64 | "--validate", 65 | action="store_true", 66 | help="Whether to validate on val split after every epoch.", 67 | ) 68 | parser.add_argument( 69 | "--in-memory", 70 | action="store_true", 71 | help="Load the whole dataset and pre-extracted image features in memory. 
" 72 | "Use only in presence of large RAM, atleast few tens of GBs.", 73 | ) 74 | 75 | parser.add_argument_group("Checkpointing related arguments") 76 | parser.add_argument( 77 | "--save-dirpath", 78 | default="checkpoints/", 79 | help="Path of directory to create checkpoint directory and save " 80 | "checkpoints.", 81 | ) 82 | parser.add_argument( 83 | "--load-pthpath", 84 | default="", 85 | help="To continue training, path to .pth file of saved checkpoint.", 86 | ) 87 | parser.add_argument( 88 | "--save-model", 89 | action="store_true", 90 | help="To make the dir clear", 91 | ) 92 | 93 | manualSeed = random.randint(1, 10000) 94 | print("Random Seed: ", manualSeed) 95 | torch.manual_seed(manualSeed) 96 | torch.cuda.manual_seed_all(manualSeed) 97 | torch.backends.cudnn.benchmark = False 98 | torch.backends.cudnn.deterministic = True 99 | 100 | # ============================================================================= 101 | # INPUT ARGUMENTS AND CONFIG 102 | # ============================================================================= 103 | 104 | args = parser.parse_args() 105 | # keys: {"dataset", "model", "solver"} 106 | config = yaml.load(open(args.config_yml)) 107 | 108 | if isinstance(args.gpu_ids, int): 109 | args.gpu_ids = [args.gpu_ids] 110 | device = ( 111 | torch.device("cuda", args.gpu_ids[0]) 112 | if args.gpu_ids[0] >= 0 113 | else torch.device("cpu") 114 | ) 115 | 116 | # Print config and args. 117 | print(yaml.dump(config, default_flow_style=False)) 118 | for arg in vars(args): 119 | print("{:<20}: {}".format(arg, getattr(args, arg))) 120 | 121 | # ============================================================================= 122 | # SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER 123 | # ============================================================================= 124 | 125 | train_dataset = VisDialDataset( 126 | config["dataset"], 127 | args.train_json, 128 | overfit=args.overfit, 129 | in_memory=args.in_memory, 130 | return_options=True, 131 | add_boundary_toks=False, 132 | sample_flag=False 133 | ) 134 | train_dataloader = DataLoader( 135 | train_dataset, 136 | batch_size=config["solver"]["batch_size"], 137 | num_workers=args.cpu_workers, 138 | shuffle=True, 139 | ) 140 | val_dataset = VisDialDataset( 141 | config["dataset"], 142 | args.val_json, 143 | args.val_dense_json, 144 | overfit=args.overfit, 145 | in_memory=args.in_memory, 146 | return_options=True, 147 | add_boundary_toks=False, 148 | sample_flag=False 149 | ) 150 | val_dataloader = DataLoader( 151 | val_dataset, 152 | batch_size=config["solver"]["batch_size"], 153 | num_workers=args.cpu_workers, 154 | shuffle=True, 155 | ) 156 | 157 | # Pass vocabulary to construct Embedding layer. 158 | encoder = Encoder(config["model"], train_dataset.vocabulary) 159 | decoder = Decoder(config["model"], train_dataset.vocabulary) 160 | print("Encoder: {}".format(config["model"]["encoder"])) 161 | print("Decoder: {}".format(config["model"]["decoder"])) 162 | 163 | # Share word embedding between encoder and decoder. 164 | if args.load_pthpath == "": 165 | print('load glove') 166 | decoder.word_embed = encoder.word_embed 167 | glove = np.load('data/glove.npy') 168 | encoder.word_embed.weight.data = torch.tensor(glove) 169 | 170 | # Wrap encoder and decoder in a model. 
171 | model = EncoderDecoderModel(encoder, decoder).to(device) 172 | if -1 not in args.gpu_ids: 173 | model = nn.DataParallel(model, args.gpu_ids) 174 | 175 | criterion = nn.CrossEntropyLoss() 176 | iterations = len(train_dataset)// config["solver"]["batch_size"] + 1 # 迭代次数 177 | 178 | def lr_lambda_fun(current_iteration: int) -> float: 179 | """Returns a learning rate multiplier. 180 | 181 | Till `warmup_epochs`, learning rate linearly increases to `initial_lr`, 182 | and then gets multiplied by `lr_gamma` every time a milestone is crossed. 183 | """ 184 | current_epoch = float(current_iteration) / iterations 185 | if current_epoch <= config["solver"]["warmup_epochs"]: 186 | alpha = current_epoch / float(config["solver"]["warmup_epochs"]) 187 | return config["solver"]["warmup_factor"] * (1.0 - alpha) + alpha 188 | else: 189 | idx = bisect(config["solver"]["lr_milestones"], current_epoch) 190 | return pow(config["solver"]["lr_gamma"], idx) 191 | 192 | optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"]) 193 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun) 194 | # ============================================================================= 195 | # SETUP BEFORE TRAINING LOOP 196 | # ============================================================================= 197 | start_time = datetime.datetime.strftime(datetime.datetime.utcnow(), '%d-%b-%Y-%H:%M:%S') 198 | checkpoint_dirpath = args.save_dirpath 199 | if checkpoint_dirpath == 'checkpoints/': 200 | checkpoint_dirpath += '%s+%s/%s' % (config["model"]["encoder"], config["model"]["decoder"], start_time) 201 | if args.save_model: 202 | summary_writer = SummaryWriter(log_dir=checkpoint_dirpath) 203 | checkpoint_manager = CheckpointManager(model, optimizer, checkpoint_dirpath, config=config) 204 | 205 | sparse_metrics = SparseGTMetrics() 206 | ndcg = NDCG() 207 | 208 | # If loading from checkpoint, adjust start epoch and load parameters. 209 | if args.load_pthpath == "": 210 | start_epoch = 0 211 | else: 212 | start_epoch = int(args.load_pthpath.split("_")[-1][:-4]) + 1 213 | 214 | model_state_dict, optimizer_state_dict = load_checkpoint(args.load_pthpath) 215 | if isinstance(model, nn.DataParallel): 216 | model.module.load_state_dict(model_state_dict) 217 | else: 218 | model.load_state_dict(model_state_dict) 219 | print("Loaded model from {}".format(args.load_pthpath)) 220 | 221 | # ============================================================================= 222 | # TRAINING LOOP 223 | # ============================================================================= 224 | 225 | # Forever increasing counter to keep track of iterations (for tensorboard log). 
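# Note: lr_lambda_fun above returns a multiplier on initial_lr: it ramps linearly
# from warmup_factor to 1.0 over the first warmup_epochs (epochs are fractional,
# current_iteration / iterations-per-epoch), then decays by lr_gamma each time an
# lr_milestone is crossed. As a worked example with the values from
# question_type/qt.yml (warmup_factor=0.2, warmup_epochs=1, lr_gamma=0.4,
# lr_milestones=[2, 5]; the numbers in configs/baseline.yml may differ), the
# multiplier is 0.2 at epoch 0, 0.6 halfway through warmup, 1.0 from epoch 1 to 2,
# 0.4 between milestones 2 and 5, and 0.16 afterwards. The standalone restatement
# below is illustrative only and is not called anywhere in this repository.
from bisect import bisect


def _lr_multiplier_sketch(epoch, warmup_epochs=1, warmup_factor=0.2, gamma=0.4, milestones=(2, 5)):
    if epoch <= warmup_epochs:
        alpha = epoch / float(warmup_epochs)
        return warmup_factor * (1.0 - alpha) + alpha
    return gamma ** bisect(list(milestones), epoch)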
226 | global_iteration_step = start_epoch * iterations 227 | 228 | ###start training and set functions used in training 229 | def get_1round_batch_data(batch, rnd): 230 | temp_train_batch = {} 231 | for key in batch: 232 | if key in ['img_feat']: 233 | temp_train_batch[key] = batch[key].to(device) 234 | elif key in ['ques', 'opt', 'ques_len', 'opt_len', 'ans_ind']: 235 | temp_train_batch[key] = batch[key][:, rnd].to(device) 236 | elif key in ['hist_len', 'hist']: 237 | temp_train_batch[key] = batch[key][:, :rnd + 1].to(device) 238 | else: 239 | pass 240 | return temp_train_batch 241 | 242 | for epoch in range(start_epoch, config["solver"]["num_epochs"]): 243 | print('Training for epoch:', epoch, ' time:', time.asctime(time.localtime(time.time()))) 244 | count_loss = 0.0 245 | for i, batch in enumerate(train_dataloader): 246 | for rnd in range(10): 247 | temp_train_batch = get_1round_batch_data(batch, rnd) 248 | optimizer.zero_grad() 249 | output = model(temp_train_batch) 250 | target = batch["ans_ind"][:, rnd].to(device) 251 | batch_loss = criterion(output.view(-1, output.size(-1)), target.view(-1)) 252 | batch_loss.backward() 253 | count_loss += batch_loss.data.cpu().numpy() 254 | optimizer.step() 255 | ##for rva, apply 10 rounds because of the implementation of rva is hard to separate into 10 parts, according to the original authors 256 | ##note that whether separate into 10 rounds here will not influence the conclusion in our paper 257 | ##but it is interesting topic in the future 258 | # for key in batch: 259 | # batch[key] = batch[key].to(device) 260 | # output = model(batch) 261 | # target = batch["ans_ind"].to(device) 262 | # batch_loss = criterion(output.view(-1, output.size(-1)), target.view(-1)) 263 | # batch_loss.backward() 264 | # count_loss += batch_loss.data.cpu().numpy() * 10.0 265 | # optimizer.step() 266 | # optimizer.zero_grad() 267 | ###################whole 10 rounds part end 268 | if i % int(iterations / 10) == 0 and i != 0: 269 | mean_loss = (count_loss / float(iterations / 10)) / 10.0 270 | print('(step', i, 'in', int(iterations), ') mean_loss:', mean_loss, 'Time:', 271 | time.asctime(time.localtime(time.time())), 'lr:', optimizer.param_groups[0]["lr"]) 272 | count_loss = 0.0 273 | if args.save_model: 274 | summary_writer.add_scalar("train/loss", batch_loss, global_iteration_step) 275 | summary_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_iteration_step) 276 | scheduler.step(global_iteration_step) 277 | global_iteration_step += 1 278 | # if i > 5: #for debug(like the --overfit) 279 | # break 280 | # ------------------------------------------------------------------------- 281 | # ON EPOCH END (checkpointing and validation) 282 | # ------------------------------------------------------------------------- 283 | if args.save_model: 284 | checkpoint_manager.step() 285 | # Validate and report automatic metrics. 
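# Note: the validation block below calls the model once per round; each call yields
# (batch, 100) option scores, which are concatenated into a (batch, 10, 100) tensor.
# SparseGTMetrics compares that tensor against the (batch, 10) ground-truth indices
# in batch["ans_ind"], while NDCG is computed only on the single densely annotated
# round selected by batch["round_id"]. An equivalent, slightly tidier way to build
# the same tensor (illustrative sketch reusing the names defined in this file):
#
#     with torch.no_grad():
#         per_round = [model(get_1round_batch_data(batch, rnd)).view(-1, 100)
#                      for rnd in range(10)]
#     output = torch.stack(per_round, dim=1)  # (batch, 10, 100)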
286 | if args.validate: 287 | print(f"\nValidation after epoch {epoch}:") 288 | model.eval() 289 | for i, batch in enumerate(val_dataloader): 290 | batchsize = batch['img_ids'].shape[0] 291 | rnd = 0 292 | temp_train_batch = get_1round_batch_data(batch, rnd) 293 | output = model(temp_train_batch).view(-1, 1, 100).detach() 294 | optimizer.zero_grad() 295 | for rnd in range(1, 10): #should be removed if the input is the whole dialog 296 | temp_train_batch = get_1round_batch_data(batch, rnd) 297 | output = torch.cat((output, model(temp_train_batch).view(-1, 1, 100).detach()), dim=1) 298 | optimizer.zero_grad() 299 | ###for 10 rounds(rva) 300 | # with torch.no_grad(): 301 | # output = model(batch) 302 | ##end 10 rounds 303 | sparse_metrics.observe(output, batch["ans_ind"]) 304 | if "relevance" in batch: 305 | output = output[torch.arange(output.size(0)), batch["round_id"] - 1, :] 306 | ndcg.observe(output.view(-1, 100), batch["relevance"].contiguous().view(-1, 100)) 307 | # if i > 5: #for debug(like the --overfit) 308 | # break 309 | all_metrics = {} 310 | all_metrics.update(sparse_metrics.retrieve(reset=True)) 311 | all_metrics.update(ndcg.retrieve(reset=True)) 312 | for metric_name, metric_value in all_metrics.items(): 313 | print(f"{metric_name}: {metric_value}") 314 | if args.save_model: 315 | summary_writer.add_scalars("metrics", all_metrics, global_iteration_step) 316 | model.train() 317 | -------------------------------------------------------------------------------- /visdialch/data/__init__.py: -------------------------------------------------------------------------------- 1 | from visdialch.data.dataset import VisDialDataset 2 | from visdialch.data.vocabulary import Vocabulary 3 | -------------------------------------------------------------------------------- /visdialch/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | import torch 3 | from torch.nn.functional import normalize 4 | from torch.nn.utils.rnn import pad_sequence 5 | from torch.utils.data import Dataset 6 | import json 7 | from visdialch.data.readers import ( 8 | DialogsReader, 9 | DenseAnnotationsReader, 10 | ImageFeaturesHdfReader, 11 | ) 12 | from visdialch.data.vocabulary import Vocabulary 13 | 14 | 15 | class VisDialDataset(Dataset): 16 | """ 17 | A full representation of VisDial v1.0 (train/val/test) dataset. According 18 | to the appropriate split, it returns dictionary of question, image, 19 | history, ground truth answer, answer options, dense annotations etc. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | config: Dict[str, Any], 25 | dialogs_jsonpath: str, 26 | dense_annotations_jsonpath: Optional[str] = None, 27 | overfit: bool = False, 28 | in_memory: bool = False, 29 | return_options: bool = True, 30 | add_boundary_toks: bool = False, 31 | sample_flag: bool = False, 32 | ): 33 | super().__init__() 34 | self.config = config 35 | self.return_options = return_options 36 | self.add_boundary_toks = add_boundary_toks 37 | self.dialogs_reader = DialogsReader(dialogs_jsonpath,config) 38 | 39 | if "val" in self.split and dense_annotations_jsonpath is not None: 40 | self.annotations_reader = DenseAnnotationsReader( 41 | dense_annotations_jsonpath 42 | ) 43 | else: 44 | self.annotations_reader = None 45 | 46 | self.vocabulary = Vocabulary( 47 | config["word_counts_json"], min_count=config["vocab_min_count"] 48 | ) 49 | 50 | # Initialize image features reader according to split. 
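# Note: the feature file is picked by the split recorded inside the dialog json
# (e.g. "train", "val2018", "test2018"), not by the file name passed on the command
# line. Further down, `sample_flag` controls which image ids the dataset exposes:
# with sample_flag=True (used only for dense fine-tuning) the ids come from
# data/visdial_1.0_train_dense_sample.json, i.e. only the training dialogs that
# carry dense relevance annotations; with sample_flag=False every dialog in the
# json is used. Roughly (illustrative paraphrase of the code below):
#
#     ids = ([entry["image_id"] for entry in dense_sample_json] if sample_flag
#            else list(self.dialogs_reader.dialogs.keys()))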
51 | image_features_hdfpath = config["image_features_train_h5"] 52 | if "val" in self.dialogs_reader.split: 53 | image_features_hdfpath = config["image_features_val_h5"] 54 | elif "test" in self.dialogs_reader.split: 55 | image_features_hdfpath = config["image_features_test_h5"] 56 | 57 | self.hdf_reader = ImageFeaturesHdfReader( 58 | image_features_hdfpath, in_memory 59 | ) 60 | if sample_flag == False: 61 | self.image_ids = list(self.dialogs_reader.dialogs.keys()) 62 | 63 | # Keep a list of image_ids as primary keys to access data. 64 | if sample_flag == True: 65 | samplefile = open('data/visdial_1.0_train_dense_sample.json', 'r') ####for answer score sampling (fine-tune) 66 | sample = json.loads(samplefile.read()) 67 | samplefile.close() 68 | ndcg_id_list = [] 69 | for idx in range(len(sample)): 70 | ndcg_id_list.append(sample[idx]['image_id']) 71 | self.image_ids = ndcg_id_list 72 | 73 | 74 | if overfit: 75 | self.image_ids = self.image_ids[:5] 76 | 77 | @property 78 | def split(self): 79 | return self.dialogs_reader.split 80 | 81 | def __len__(self): 82 | return len(self.image_ids) 83 | 84 | def __getitem__(self, index): 85 | # Get image_id, which serves as a primary key for current instance. 86 | image_id = self.image_ids[index] 87 | # Get image features for this image_id using hdf reader. 88 | image_features = self.hdf_reader[image_id] 89 | image_features = torch.tensor(image_features) 90 | # Normalize image features at zero-th dimension (since there's no batch 91 | # dimension). 92 | if self.config["img_norm"]: 93 | image_features = normalize(image_features, dim=0, p=2) 94 | 95 | # Retrieve instance for this image_id using json reader. 96 | # print(image_id) 97 | visdial_instance = self.dialogs_reader[image_id] 98 | caption = visdial_instance["caption"] 99 | dialog = visdial_instance["dialog"] 100 | 101 | 102 | # Convert word tokens of caption, question, answer and answer options 103 | # to integers. 
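# Note: `to_indices` belongs to the Vocabulary class in
# visdialch/data/vocabulary.py; the method body is not reproduced in this excerpt,
# but its assumed behaviour is to map a list of word tokens to integer ids, falling
# back to UNK_INDEX for out-of-vocabulary words, roughly:
#
#     def to_indices(self, words):
#         return [self.word2index.get(word, self.UNK_INDEX) for word in words]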
104 | caption = self.vocabulary.to_indices(caption) 105 | for i in range(len(dialog)): 106 | dialog[i]["question"] = self.vocabulary.to_indices( 107 | dialog[i]["question"] 108 | ) 109 | if self.add_boundary_toks: 110 | dialog[i]["answer"] = self.vocabulary.to_indices( 111 | [self.vocabulary.SOS_TOKEN] 112 | + dialog[i]["answer"] 113 | + [self.vocabulary.EOS_TOKEN] 114 | ) 115 | else: 116 | dialog[i]["answer"] = self.vocabulary.to_indices( 117 | dialog[i]["answer"] 118 | ) 119 | 120 | if self.return_options: 121 | for j in range(len(dialog[i]["answer_options"])): 122 | if self.add_boundary_toks: 123 | dialog[i]["answer_options"][ 124 | j 125 | ] = self.vocabulary.to_indices( 126 | [self.vocabulary.SOS_TOKEN] 127 | + dialog[i]["answer_options"][j] 128 | + [self.vocabulary.EOS_TOKEN] 129 | ) 130 | else: 131 | dialog[i]["answer_options"][ 132 | j 133 | ] = self.vocabulary.to_indices( 134 | dialog[i]["answer_options"][j] 135 | ) 136 | 137 | questions, question_lengths = self._pad_sequences( 138 | [dialog_round["question"] for dialog_round in dialog] 139 | ) 140 | history, history_lengths = self._get_history( 141 | caption, 142 | [dialog_round["question"] for dialog_round in dialog], 143 | [dialog_round["answer"] for dialog_round in dialog], 144 | ) 145 | answers_in, answer_lengths = self._pad_sequences( 146 | [dialog_round["answer"][:-1] for dialog_round in dialog] 147 | # [dialog_round["answer"][:] for dialog_round in dialog] 148 | ) 149 | answers_out, _ = self._pad_sequences( 150 | # [dialog_round["answer"][:] for dialog_round in dialog] 151 | [dialog_round["answer"][1:] for dialog_round in dialog] 152 | ) 153 | answers, _ = self._pad_sequences( 154 | # [dialog_round["answer"][:] for dialog_round in dialog] 155 | [dialog_round["answer"] for dialog_round in dialog] 156 | ) 157 | # Collect everything as tensors for ``collate_fn`` of dataloader to 158 | # work seamlessly questions, history, etc. are converted to 159 | # LongTensors, for nn.Embedding input. 
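# Note: for reference, the per-example tensors collected into `item` below have
# roughly these shapes, with L = max_sequence_length (20 in the provided configs)
# and 100 answer options per round:
#   ques (10, L); hist (10, 2L), or (10, 2L*10) when concat_history is true;
#   opt (10, 100, L); ans / ans_in / ans_out (10, L);
#   ques_len / hist_len / ans_len (10,); opt_len (10, 100); ans_ind (10,);
#   img_feat (num_proposals, img_feature_size); relevance (100,) and round_id ()
#   on the val split. The DataLoader then prepends a batch dimension to all of them.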
160 | item = {} 161 | item["img_ids"] = torch.tensor(image_id).long() 162 | item["img_feat"] = image_features 163 | item["ques"] = questions.long() 164 | item["ans"] = answers.long() 165 | item["hist"] = history.long() 166 | item["ans_in"] = answers_in.long() 167 | item["ans_out"] = answers_out.long() 168 | item["ques_len"] = torch.tensor(question_lengths).long() 169 | item["hist_len"] = torch.tensor(history_lengths).long() 170 | item["ans_len"] = torch.tensor(answer_lengths).long() 171 | item["num_rounds"] = torch.tensor( 172 | visdial_instance["num_rounds"] 173 | ).long() 174 | 175 | if self.return_options: 176 | if self.add_boundary_toks: 177 | answer_options_in, answer_options_out, answer_options = [], [], [] 178 | answer_option_lengths = [] 179 | for dialog_round in dialog: 180 | options, option_lengths = self._pad_sequences( 181 | [ 182 | option[:-1] 183 | for option in dialog_round["answer_options"] 184 | ] 185 | ) 186 | answer_options_in.append(options) 187 | 188 | options, _ = self._pad_sequences( 189 | [ 190 | option[1:] 191 | for option in dialog_round["answer_options"] 192 | ] 193 | ) 194 | answer_options_out.append(options) 195 | 196 | options, _ = self._pad_sequences( 197 | [ 198 | option[:] 199 | for option in dialog_round["answer_options"] 200 | ] 201 | ) 202 | answer_options.append(options) 203 | 204 | answer_option_lengths.append(option_lengths) 205 | 206 | answer_options_in = torch.stack(answer_options_in, 0) 207 | answer_options_out = torch.stack(answer_options_out, 0) 208 | answer_options = torch.stack(answer_options, 0) 209 | 210 | item["opt"] = answer_options.long() 211 | item["opt_in"] = answer_options_in.long() 212 | item["opt_out"] = answer_options_out.long() 213 | item["opt_len"] = torch.tensor(answer_option_lengths).long() 214 | else: 215 | answer_options = [] 216 | answer_option_lengths = [] 217 | for dialog_round in dialog: 218 | options, option_lengths = self._pad_sequences( 219 | dialog_round["answer_options"] 220 | ) 221 | answer_options.append(options) 222 | answer_option_lengths.append(option_lengths) 223 | answer_options = torch.stack(answer_options, 0) 224 | 225 | item["opt"] = answer_options.long() 226 | item["opt_len"] = torch.tensor(answer_option_lengths).long() 227 | 228 | if "test" not in self.split: 229 | answer_indices = [ 230 | dialog_round["gt_index"] for dialog_round in dialog 231 | ] 232 | item["ans_ind"] = torch.tensor(answer_indices).long() 233 | 234 | # Gather dense annotations. 235 | if "val" in self.split: 236 | dense_annotations = self.annotations_reader[image_id] 237 | item["relevance"] = torch.tensor( 238 | dense_annotations["gt_relevance"] 239 | ).float() 240 | item["round_id"] = torch.tensor( 241 | dense_annotations["round_id"] 242 | ).long() 243 | 244 | 245 | return item 246 | 247 | def _pad_sequences(self, sequences: List[List[int]]): 248 | """Given tokenized sequences (either questions, answers or answer 249 | options, tokenized in ``__getitem__``), padding them to maximum 250 | specified sequence length. Return as a tensor of size 251 | ``(*, max_sequence_length)``. 252 | 253 | This method is only called in ``__getitem__``, chunked out separately 254 | for readability. 255 | 256 | Parameters 257 | ---------- 258 | sequences : List[List[int]] 259 | List of tokenized sequences, each sequence is typically a 260 | List[int]. 261 | 262 | Returns 263 | ------- 264 | torch.Tensor, torch.Tensor 265 | Tensor of sequences padded to max length, and length of sequences 266 | before padding. 
267 | """ 268 | 269 | for i in range(len(sequences)): 270 | sequences[i] = sequences[i][ 271 | : self.config["max_sequence_length"] - 1 272 | ] 273 | sequence_lengths = [len(sequence) for sequence in sequences] 274 | 275 | # Pad all sequences to max_sequence_length. 276 | maxpadded_sequences = torch.full( 277 | (len(sequences), self.config["max_sequence_length"]), 278 | fill_value=self.vocabulary.PAD_INDEX, 279 | ) 280 | padded_sequences = pad_sequence( 281 | [torch.tensor(sequence) for sequence in sequences], 282 | batch_first=True, 283 | padding_value=self.vocabulary.PAD_INDEX, 284 | ) 285 | maxpadded_sequences[:, : padded_sequences.size(1)] = padded_sequences 286 | return maxpadded_sequences, sequence_lengths 287 | 288 | def _get_history( 289 | self, 290 | caption: List[int], 291 | questions: List[List[int]], 292 | answers: List[List[int]], 293 | ): 294 | # Allow double length of caption, equivalent to a concatenated QA pair. 295 | caption = caption[: self.config["max_sequence_length"] * 2 - 1] 296 | 297 | for i in range(len(questions)): 298 | questions[i] = questions[i][ 299 | : self.config["max_sequence_length"] - 1 300 | ] 301 | 302 | for i in range(len(answers)): 303 | answers[i] = answers[i][: self.config["max_sequence_length"] - 1] 304 | 305 | # History for first round is caption, else concatenated QA pair of 306 | # previous round. 307 | history = [] 308 | history.append(caption) 309 | for question, answer in zip(questions, answers): 310 | history.append(question + answer + [self.vocabulary.EOS_INDEX]) 311 | # Drop last entry from history (there's no eleventh question). 312 | history = history[:-1] #切掉最后一个 313 | max_history_length = self.config["max_sequence_length"] * 2 314 | 315 | if self.config.get("concat_history", False): 316 | # Concatenated_history has similar structure as history, except it 317 | # contains concatenated QA pairs from previous rounds. 318 | concatenated_history = [] 319 | concatenated_history.append(caption) 320 | for i in range(1, len(history)): 321 | concatenated_history.append([]) 322 | for j in range(i + 1): 323 | concatenated_history[i].extend(history[j]) 324 | 325 | max_history_length = ( 326 | self.config["max_sequence_length"] * 2 * len(history) 327 | ) 328 | history = concatenated_history 329 | 330 | history_lengths = [len(round_history) for round_history in history] 331 | maxpadded_history = torch.full( 332 | (len(history), max_history_length), 333 | fill_value=self.vocabulary.PAD_INDEX, 334 | ) 335 | padded_history = pad_sequence( 336 | [torch.tensor(round_history) for round_history in history], 337 | batch_first=True, 338 | padding_value=self.vocabulary.PAD_INDEX, 339 | ) 340 | maxpadded_history[:, : padded_history.size(1)] = padded_history 341 | return maxpadded_history, history_lengths 342 | -------------------------------------------------------------------------------- /visdialch/data/readers.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Reader simply reads data from disk and returns it almost as is, based on 3 | a "primary key", which for the case of VisDial v1.0 dataset, is the 4 | ``image_id``. Readers should be utilized by torch ``Dataset``s. Any type of 5 | data pre-processing is not recommended in the reader, such as tokenizing words 6 | to integers, embedding tokens, or passing an image through a pre-trained CNN. 7 | 8 | Each reader must atleast implement three methods: 9 | - ``__len__`` to return the length of data this Reader can read. 
10 | - ``__getitem__`` to return data based on ``image_id`` in VisDial v1.0 11 | dataset. 12 | - ``keys`` to return a list of possible ``image_id``s this Reader can 13 | provide data of. 14 | """ 15 | 16 | import copy 17 | import json 18 | from typing import Dict, List, Union 19 | import h5py 20 | import os 21 | 22 | 23 | # A bit slow, and just splits sentences to list of words, can be doable in 24 | # `DialogsReader`. 25 | from nltk.tokenize import word_tokenize 26 | from tqdm import tqdm 27 | 28 | 29 | class DialogsReader(object): 30 | """ 31 | A simple reader for VisDial v1.0 dialog data. The json file must have the 32 | same structure as mentioned on ``https://visualdialog.org/data``. 33 | 34 | Parameters 35 | ---------- 36 | dialogs_jsonpath : str 37 | Path to json file containing VisDial v1.0 train, val or test data. 38 | """ 39 | 40 | def __init__(self, dialogs_jsonpath: str, config): 41 | with open(dialogs_jsonpath, "r") as visdial_file: 42 | visdial_data = json.load(visdial_file) 43 | self._split = visdial_data["split"] 44 | self.tokens_path = config["tokens_path"] 45 | # Image_id serves as key for all three dicts here. 46 | self.captions = {} 47 | self.dialogs = {} 48 | self.num_rounds = {} 49 | 50 | for dialog_for_image in visdial_data["data"]["dialogs"]: 51 | self.captions[dialog_for_image["image_id"]] = dialog_for_image["caption"] 52 | 53 | # Record original length of dialog, before padding. 54 | # 10 for train and val splits, 10 or ######less for test split######. 55 | self.num_rounds[dialog_for_image["image_id"]] = len(dialog_for_image["dialog"] ) 56 | 57 | # Pad dialog at the end with empty question and answer pairs 58 | # (for test split). 59 | while len(dialog_for_image["dialog"]) < 10: 60 | dialog_for_image["dialog"].append( 61 | {"question": -1, "answer": -1} 62 | ) 63 | 64 | # Add empty answer /answer options if not provided 65 | # (for test split). 
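# Note: only test-split dialogs can be shorter than 10 rounds; they are padded with
# {"question": -1, "answer": -1} (and, just below, 100 answer options of -1). The
# sentinel -1 is a valid index because an empty string is appended to the end of
# self.questions and self.answers later in this constructor, so padded rounds
# resolve to an empty answer and a bare "?" question after tokenization.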
66 | for i in range(len(dialog_for_image["dialog"])): 67 | if "answer" not in dialog_for_image["dialog"][i]: 68 | dialog_for_image["dialog"][i]["answer"] = -1 69 | if "answer_options" not in dialog_for_image["dialog"][i]: 70 | dialog_for_image["dialog"][i]["answer_options"] = [ 71 | -1 72 | ] * 100 73 | 74 | self.dialogs[dialog_for_image["image_id"]] = dialog_for_image[ 75 | "dialog" 76 | ] 77 | print(f"[{self._split}] Tokenizing captions...") 78 | for image_id, caption in self.captions.items(): 79 | self.captions[image_id] = word_tokenize(caption) 80 | if not os.path.exists(self.tokens_path): ##generate tokens from sentence, and create tokens.json in "data/" 81 | save_dict = {} 82 | for key in ['questions', 'answers']: 83 | save_key = key+"_"+self._split 84 | if save_key not in save_dict.keys(): 85 | save_dict[save_key] = visdial_data["data"][key] 86 | if key == 'questions': 87 | for i in range(len(save_dict[save_key])): 88 | save_dict[save_key][i] = word_tokenize(save_dict[save_key][i]+"?") 89 | else: 90 | for i in range(len(save_dict[save_key])): 91 | save_dict[save_key][i] = word_tokenize(save_dict[save_key][i]) 92 | with open(self.tokens_path, "w") as f: 93 | json.dump(save_dict, f) 94 | if os.path.exists(self.tokens_path) and self._split == 'val2018': 95 | with open(self.tokens_path, 'r') as f: 96 | save_dict = json.loads(f.read()) 97 | f.close() 98 | if 'questions_'+self._split not in save_dict.keys(): 99 | for key in ['questions', 'answers']: 100 | save_key = key + "_" + self._split 101 | if save_key not in save_dict.keys(): 102 | save_dict[save_key] = visdial_data["data"][key] 103 | if key == 'questions': 104 | for i in range(len(save_dict[save_key])): 105 | save_dict[save_key][i] = word_tokenize(save_dict[save_key][i] + "?") 106 | else: 107 | for i in range(len(save_dict[save_key])): 108 | save_dict[save_key][i] = word_tokenize(save_dict[save_key][i]) 109 | with open(self.tokens_path, "w") as f: 110 | json.dump(save_dict, f) 111 | if self._split in ['val2018', 'train']: 112 | print(self._split+'tokens load token from data/tokens.json') 113 | with open("data/tokens.json", 'r') as f: 114 | data = json.loads(f.read()) 115 | f.close() 116 | self.questions = data['questions_'+self._split] 117 | self.answers = data['answers_' + self._split] 118 | # self.captions_wait = data['captions_' + self._split] 119 | self.questions.append("") 120 | self.answers.append("") 121 | else: 122 | self.questions = visdial_data["data"]["questions"] 123 | self.answers = visdial_data["data"]["answers"] # list 124 | # Add empty question, answer at the end, useful for padding dialog 125 | # rounds for test. 126 | self.questions.append("") 127 | self.answers.append("") 128 | 129 | print(f"[{self._split}] Tokenizing questions...") 130 | for i in range(len(self.questions)): 131 | self.questions[i] = word_tokenize(self.questions[i] + "?") 132 | 133 | print(f"[{self._split}] Tokenizing answers...") 134 | for i in range(len(self.answers)): 135 | self.answers[i] = word_tokenize(self.answers[i]) 136 | 137 | def __len__(self): 138 | return len(self.dialogs) 139 | 140 | def __getitem__(self, image_id: int) -> Dict[str, Union[int, str, List]]: 141 | caption_for_image = self.captions[image_id] 142 | dialog_for_image = copy.deepcopy(self.dialogs[image_id]) ##change from copy -> deepcopy 143 | num_rounds = self.num_rounds[image_id] 144 | 145 | # Replace question and answer indices with actual word tokens. 
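# Note: on the first run, the tokenized questions and answers are cached in
# data/tokens.json (see the `tokens_path` handling above) so that later runs on the
# train and val2018 splits skip the slow nltk word_tokenize pass. The cache is a
# flat dict keyed by split, roughly:
#
#     {"questions_train": [["is", "the", "cat", "asleep", "?"], ...],
#      "answers_train":   [["yes"], ...],
#      "questions_val2018": [...], "answers_val2018": [...]}
#
# (the example sentences are made up). __getitem__ then swaps the integer
# question/answer indices stored in each dialog round for these token lists, as
# the loop below does.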
146 | for i in range(len(dialog_for_image)): 147 | dialog_for_image[i]["question"] = self.questions[ 148 | dialog_for_image[i]["question"] 149 | ] 150 | dialog_for_image[i]["answer"] = self.answers[ 151 | dialog_for_image[i]["answer"] 152 | ] 153 | for j, answer_option in enumerate( 154 | dialog_for_image[i]["answer_options"] 155 | ): 156 | dialog_for_image[i]["answer_options"][j] = self.answers[ 157 | answer_option 158 | ] 159 | 160 | return { 161 | "image_id": image_id, 162 | "caption": caption_for_image, 163 | "dialog": dialog_for_image, 164 | "num_rounds": num_rounds, 165 | } 166 | 167 | def keys(self) -> List[int]: 168 | return list(self.dialogs.keys()) 169 | 170 | @property 171 | def split(self): 172 | return self._split 173 | 174 | 175 | class DenseAnnotationsReader(object): 176 | """ 177 | A reader for dense annotations for val split. The json file must have the 178 | same structure as mentioned on ``https://visualdialog.org/data``. 179 | 180 | Parameters 181 | ---------- 182 | dense_annotations_jsonpath : str 183 | Path to a json file containing VisDial v1.0 184 | """ 185 | 186 | def __init__(self, dense_annotations_jsonpath: str): 187 | with open(dense_annotations_jsonpath, "r") as visdial_file: 188 | self._visdial_data = json.load(visdial_file) 189 | self._image_ids = [ 190 | entry["image_id"] for entry in self._visdial_data 191 | ] 192 | 193 | def __len__(self): 194 | return len(self._image_ids) 195 | 196 | def __getitem__(self, image_id: int) -> Dict[str, Union[int, List]]: 197 | index = self._image_ids.index(image_id) 198 | # keys: {"image_id", "round_id", "gt_relevance"} 199 | return self._visdial_data[index] 200 | 201 | @property 202 | def split(self): 203 | # always 204 | return "val" 205 | 206 | 207 | class ImageFeaturesHdfReader(object): 208 | """ 209 | A reader for HDF files containing pre-extracted image features. A typical 210 | HDF file is expected to have a column named "image_id", and another column 211 | named "features". 212 | 213 | Example of an HDF file: 214 | ``` 215 | visdial_train_faster_rcnn_bottomup_features.h5 216 | |--- "image_id" [shape: (num_images, )] 217 | |--- "features" [shape: (num_images, num_proposals, feature_size)] 218 | +--- .attrs ("split", "train") 219 | ``` 220 | Refer ``$PROJECT_ROOT/data/extract_bottomup.py`` script for more details 221 | about HDF structure. 222 | 223 | Parameters 224 | ---------- 225 | features_hdfpath : str 226 | Path to an HDF file containing VisDial v1.0 train, val or test split 227 | image features. 228 | in_memory : bool 229 | Whether to load the whole HDF file in memory. Beware, these files are 230 | sometimes tens of GBs in size. Set this to true if you have sufficient 231 | RAM - trade-off between speed and memory. 232 | """ 233 | 234 | def __init__(self, features_hdfpath: str, in_memory: bool = False): 235 | self.features_hdfpath = features_hdfpath 236 | self._in_memory = in_memory 237 | 238 | with h5py.File(self.features_hdfpath, "r") as features_hdf: 239 | self._split = features_hdf.attrs["split"] 240 | self.image_id_list = list(features_hdf["image_id"]) 241 | # "features" is List[np.ndarray] if the dataset is loaded in-memory 242 | # If not loaded in memory, then list of None. 
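# Note: even with in_memory=True the features are read lazily -- self.features
# starts as a list of None placeholders and each image's
# (num_proposals, feature_size) array is cached the first time __getitem__ asks for
# it, so only the first epoch pays the HDF5 read cost. Standalone usage looks
# roughly like (illustrative, assuming the val feature file from the README):
#
#     reader = ImageFeaturesHdfReader("data/features_faster_rcnn_x101_val.h5",
#                                     in_memory=True)
#     feats = reader[reader.keys()[0]]  # np.ndarray, (num_proposals, feature_size)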
243 | self.features = [None] * len(self.image_id_list) 244 | 245 | def __len__(self): 246 | return len(self.image_id_list) 247 | 248 | def __getitem__(self, image_id: int): 249 | index = self.image_id_list.index(image_id) 250 | if self._in_memory: 251 | # Load features during first epoch, all not loaded together as it 252 | # has a slow start. 253 | if self.features[index] is not None: 254 | image_id_features = self.features[index] 255 | else: 256 | with h5py.File(self.features_hdfpath, "r") as features_hdf: 257 | image_id_features = features_hdf["features"][index] 258 | self.features[index] = image_id_features 259 | else: 260 | # Read chunk from file everytime if not loaded in memory. 261 | with h5py.File(self.features_hdfpath, "r") as features_hdf: 262 | image_id_features = features_hdf["features"][index] 263 | 264 | return image_id_features 265 | 266 | def keys(self) -> List[int]: 267 | return self.image_id_list 268 | 269 | @property 270 | def split(self): 271 | return self._split 272 | -------------------------------------------------------------------------------- /visdialch/data/readers_qt.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Reader simply reads data from disk and returns it almost as is, based on 3 | a "primary key", which for the case of VisDial v1.0 dataset, is the 4 | ``image_id``. Readers should be utilized by torch ``Dataset``s. Any type of 5 | data pre-processing is not recommended in the reader, such as tokenizing words 6 | to integers, embedding tokens, or passing an image through a pre-trained CNN. 7 | 8 | Each reader must atleast implement three methods: 9 | - ``__len__`` to return the length of data this Reader can read. 10 | - ``__getitem__`` to return data based on ``image_id`` in VisDial v1.0 11 | dataset. 12 | - ``keys`` to return a list of possible ``image_id``s this Reader can 13 | provide data of. 14 | """ 15 | 16 | import copy 17 | import json 18 | from typing import Dict, List, Union 19 | import h5py 20 | 21 | 22 | # A bit slow, and just splits sentences to list of words, can be doable in 23 | # `DialogsReader`. 24 | from nltk.tokenize import word_tokenize 25 | from tqdm import tqdm 26 | 27 | 28 | class DialogsReader(object): 29 | """ 30 | A simple reader for VisDial v1.0 dialog data. The json file must have the 31 | same structure as mentioned on ``https://visualdialog.org/data``. 32 | 33 | Parameters 34 | ---------- 35 | dialogs_jsonpath : str 36 | Path to json file containing VisDial v1.0 train, val or test data. 37 | """ 38 | 39 | def __init__(self, dialogs_jsonpath: str): 40 | with open(dialogs_jsonpath, "r") as visdial_file: 41 | visdial_data = json.load(visdial_file) 42 | self._split = visdial_data["split"] 43 | 44 | # Image_id serves as key for all three dicts here. 45 | self.captions = {} 46 | self.dialogs = {} 47 | self.num_rounds = {} 48 | 49 | for dialog_for_image in visdial_data["data"]["dialogs"]: 50 | self.captions[dialog_for_image["image_id"]] = dialog_for_image["caption"] 51 | 52 | # Record original length of dialog, before padding. 53 | # 10 for train and val splits, 10 or ######less for test split######. 54 | self.num_rounds[dialog_for_image["image_id"]] = len(dialog_for_image["dialog"] ) 55 | 56 | # Pad dialog at the end with empty question and answer pairs 57 | # (for test split). 
58 | while len(dialog_for_image["dialog"]) < 10: 59 | dialog_for_image["dialog"].append( 60 | {"question": -1, "answer": -1} 61 | ) 62 | 63 | # Add empty answer /answer options if not provided 64 | # (for test split). 65 | for i in range(len(dialog_for_image["dialog"])): 66 | if "answer" not in dialog_for_image["dialog"][i]: 67 | dialog_for_image["dialog"][i]["answer"] = -1 68 | if "answer_options" not in dialog_for_image["dialog"][i]: 69 | dialog_for_image["dialog"][i]["answer_options"] = [ 70 | -1 71 | ] * 100 72 | 73 | self.dialogs[dialog_for_image["image_id"]] = dialog_for_image[ 74 | "dialog" 75 | ] 76 | print(f"[{self._split}] Tokenizing captions...") 77 | for image_id, caption in self.captions.items(): 78 | self.captions[image_id] = word_tokenize(caption) 79 | 80 | if self._split in ['val2018', 'train']: 81 | print(self._split+'tokens load token from data/tokens.json') 82 | with open("data/tokens.json", 'r') as f: 83 | data = json.loads(f.read()) 84 | f.close() 85 | self.questions = data['questions_'+self._split] 86 | self.answers = data['answers_' + self._split] 87 | # self.captions_wait = data['captions_' + self._split] 88 | self.questions.append("") 89 | self.answers.append("") 90 | else: 91 | self.questions = visdial_data["data"]["questions"] 92 | self.answers = visdial_data["data"]["answers"] # list 93 | # Add empty question, answer at the end, useful for padding dialog 94 | # rounds for test. 95 | self.questions.append("") 96 | self.answers.append("") 97 | 98 | print(f"[{self._split}] Tokenizing questions...") 99 | for i in range(len(self.questions)): 100 | self.questions[i] = word_tokenize(self.questions[i] + "?") 101 | 102 | print(f"[{self._split}] Tokenizing answers...") 103 | for i in range(len(self.answers)): 104 | self.answers[i] = word_tokenize(self.answers[i]) 105 | 106 | def __len__(self): 107 | return len(self.dialogs) 108 | 109 | def __getitem__(self, image_id: int) -> Dict[str, Union[int, str, List]]: 110 | caption_for_image = self.captions[image_id] 111 | dialog_for_image = copy.deepcopy(self.dialogs[image_id]) ##change from copy -> deepcopy 112 | num_rounds = self.num_rounds[image_id] 113 | 114 | # Replace question and answer indices with actual word tokens. 115 | opt_list = [] 116 | for i in range(len(dialog_for_image)): 117 | opt_list.append(copy.deepcopy(dialog_for_image[i]['answer_options'])) 118 | dialog_for_image[i]["question"] = self.questions[ 119 | dialog_for_image[i]["question"] 120 | ] 121 | dialog_for_image[i]["answer"] = self.answers[ 122 | dialog_for_image[i]["answer"] 123 | ] 124 | for j, answer_option in enumerate( 125 | dialog_for_image[i]["answer_options"] 126 | ): 127 | dialog_for_image[i]["answer_options"][j] = self.answers[ 128 | answer_option 129 | ] 130 | 131 | return { 132 | "image_id": image_id, 133 | "caption": caption_for_image, 134 | "dialog": dialog_for_image, 135 | "num_rounds": num_rounds, 136 | 'opt_list':opt_list, 137 | } 138 | 139 | def keys(self) -> List[int]: 140 | return list(self.dialogs.keys()) 141 | 142 | @property 143 | def split(self): 144 | return self._split 145 | 146 | 147 | class DenseAnnotationsReader(object): 148 | """ 149 | A reader for dense annotations for val split. The json file must have the 150 | same structure as mentioned on ``https://visualdialog.org/data``. 
151 | 152 | Parameters 153 | ---------- 154 | dense_annotations_jsonpath : str 155 | Path to a json file containing VisDial v1.0 156 | """ 157 | 158 | def __init__(self, dense_annotations_jsonpath: str): 159 | with open(dense_annotations_jsonpath, "r") as visdial_file: 160 | self._visdial_data = json.load(visdial_file) 161 | self._image_ids = [ 162 | entry["image_id"] for entry in self._visdial_data 163 | ] 164 | 165 | def __len__(self): 166 | return len(self._image_ids) 167 | 168 | def __getitem__(self, image_id: int) -> Dict[str, Union[int, List]]: 169 | index = self._image_ids.index(image_id) 170 | # keys: {"image_id", "round_id", "gt_relevance"} 171 | return self._visdial_data[index] 172 | 173 | @property 174 | def split(self): 175 | # always 176 | return "val" 177 | 178 | 179 | class ImageFeaturesHdfReader(object): 180 | """ 181 | A reader for HDF files containing pre-extracted image features. A typical 182 | HDF file is expected to have a column named "image_id", and another column 183 | named "features". 184 | 185 | Example of an HDF file: 186 | ``` 187 | visdial_train_faster_rcnn_bottomup_features.h5 188 | |--- "image_id" [shape: (num_images, )] 189 | |--- "features" [shape: (num_images, num_proposals, feature_size)] 190 | +--- .attrs ("split", "train") 191 | ``` 192 | Refer ``$PROJECT_ROOT/data/extract_bottomup.py`` script for more details 193 | about HDF structure. 194 | 195 | Parameters 196 | ---------- 197 | features_hdfpath : str 198 | Path to an HDF file containing VisDial v1.0 train, val or test split 199 | image features. 200 | in_memory : bool 201 | Whether to load the whole HDF file in memory. Beware, these files are 202 | sometimes tens of GBs in size. Set this to true if you have sufficient 203 | RAM - trade-off between speed and memory. 204 | """ 205 | 206 | def __init__(self, features_hdfpath: str, in_memory: bool = False): 207 | self.features_hdfpath = features_hdfpath 208 | self._in_memory = in_memory 209 | 210 | with h5py.File(self.features_hdfpath, "r") as features_hdf: 211 | self._split = features_hdf.attrs["split"] 212 | self.image_id_list = list(features_hdf["image_id"]) 213 | # "features" is List[np.ndarray] if the dataset is loaded in-memory 214 | # If not loaded in memory, then list of None. 215 | self.features = [None] * len(self.image_id_list) 216 | 217 | def __len__(self): 218 | return len(self.image_id_list) 219 | 220 | def __getitem__(self, image_id: int): 221 | index = self.image_id_list.index(image_id) 222 | if self._in_memory: 223 | # Load features during first epoch, all not loaded together as it 224 | # has a slow start. 225 | if self.features[index] is not None: 226 | image_id_features = self.features[index] 227 | else: 228 | with h5py.File(self.features_hdfpath, "r") as features_hdf: 229 | image_id_features = features_hdf["features"][index] 230 | self.features[index] = image_id_features 231 | else: 232 | # Read chunk from file everytime if not loaded in memory. 
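# In-memory mode caches each image's features after its first read; otherwise
# the branch below re-opens the HDF file and reads a single row on every access.
# A self-contained sketch of that access pattern on a tiny synthetic file
# (toy sizes and a temporary path; the real files are ~(num_images, 36, 2048)):
import os
import tempfile

import h5py
import numpy as np

toy_path = os.path.join(tempfile.mkdtemp(), "toy_features.h5")
with h5py.File(toy_path, "w") as f:
    f.create_dataset("image_id", data=np.array([101, 102, 103]))
    f.create_dataset("features", data=np.random.rand(3, 4, 8).astype(np.float32))
    f.attrs["split"] = "train"

with h5py.File(toy_path, "r") as f:
    image_ids = list(f["image_id"])      # [101, 102, 103]
    index = image_ids.index(102)
    features_102 = f["features"][index]  # reads only this row from disk
print(features_102.shape)                # (4, 8)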
233 |             with h5py.File(self.features_hdfpath, "r") as features_hdf:
234 |                 image_id_features = features_hdf["features"][index]
235 | 
236 |         return image_id_features
237 | 
238 |     def keys(self) -> List[int]:
239 |         return self.image_id_list
240 | 
241 |     @property
242 |     def split(self):
243 |         return self._split
244 | 
--------------------------------------------------------------------------------
/visdialch/data/vocabulary.py:
--------------------------------------------------------------------------------
1 | """
2 | A Vocabulary maintains a mapping between words and corresponding unique
3 | integers, holds special integers (tokens) for indicating start and end of
4 | sequence, and offers functionality to map out-of-vocabulary words to the
5 | corresponding token.
6 | """
7 | import json
8 | import os
9 | from typing import List
10 | 
11 | 
12 | class Vocabulary(object):
13 |     """
14 |     A simple Vocabulary class which maintains a mapping between words and
15 |     integer tokens. Can be initialized either by word counts from the VisDial
16 |     v1.0 train dataset, or a pre-saved vocabulary mapping.
17 | 
18 |     Parameters
19 |     ----------
20 |     word_counts_path: str
21 |         Path to a json file containing counts of each word across captions,
22 |         questions and answers of the VisDial v1.0 train dataset.
23 |     min_count : int, optional (default=5)
24 |         When initializing the vocabulary from word counts, you can specify a
25 |         minimum count, and every token with a count less than this will be
26 |         excluded from vocabulary.
27 |     """
28 | 
29 |     PAD_TOKEN = "<PAD>"
30 |     SOS_TOKEN = "<S>"
31 |     EOS_TOKEN = "</S>"
32 |     UNK_TOKEN = "<UNK>"
33 | 
34 |     PAD_INDEX = 0
35 |     SOS_INDEX = 1
36 |     EOS_INDEX = 2
37 |     UNK_INDEX = 3
38 | 
39 |     def __init__(self, word_counts_path: str, min_count: int = 5):
40 |         if not os.path.exists(word_counts_path):
41 |             raise FileNotFoundError(
42 |                 f"Word counts do not exist at {word_counts_path}"
43 |             )
44 | 
45 |         with open(word_counts_path, "r") as word_counts_file:
46 |             word_counts = json.load(word_counts_file)
47 | 
48 |         # form a list of (word, count) tuples and apply min_count threshold
49 |         word_counts = [
50 |             (word, count)
51 |             for word, count in word_counts.items()
52 |             if count >= min_count
53 |         ]
54 |         # sort in descending order of word counts
55 |         word_counts = sorted(word_counts, key=lambda wc: -wc[1])
56 |         words = [w[0] for w in word_counts]
57 | 
58 |         self.word2index = {}
59 |         self.word2index[self.PAD_TOKEN] = self.PAD_INDEX
60 |         self.word2index[self.SOS_TOKEN] = self.SOS_INDEX
61 |         self.word2index[self.EOS_TOKEN] = self.EOS_INDEX
62 |         self.word2index[self.UNK_TOKEN] = self.UNK_INDEX
63 |         for index, word in enumerate(words):
64 |             self.word2index[word] = index + 4
65 | 
66 |         self.index2word = {
67 |             index: word for word, index in self.word2index.items()
68 |         }
69 | 
70 |     @classmethod
71 |     def from_saved(cls, saved_vocabulary_path: str) -> "Vocabulary":
72 |         """Build the vocabulary from a json file saved by ``save`` method.
73 | 
74 |         Parameters
75 |         ----------
76 |         saved_vocabulary_path : str
77 |             Path to a json file containing word to integer mappings
78 |             (saved vocabulary).
79 | """ 80 | with open(saved_vocabulary_path, "r") as saved_vocabulary_file: 81 | cls.word2index = json.load(saved_vocabulary_file) 82 | cls.index2word = { 83 | index: word for word, index in cls.word2index.items() 84 | } 85 | 86 | def to_indices(self, words: List[str]) -> List[int]: 87 | return [self.word2index.get(word, self.UNK_INDEX) for word in words] 88 | 89 | def to_words(self, indices: List[int]) -> List[str]: 90 | return [ 91 | self.index2word.get(index, self.UNK_TOKEN) for index in indices 92 | ] 93 | 94 | def save(self, save_vocabulary_path: str) -> None: 95 | with open(save_vocabulary_path, "w") as save_vocabulary_file: 96 | json.dump(self.word2index, save_vocabulary_file) 97 | 98 | def __len__(self): 99 | return len(self.index2word) 100 | -------------------------------------------------------------------------------- /visdialch/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from visdialch.decoders.disc_by_round import Disc_by_round_Decoder 2 | from visdialch.decoders.disc_qt import Disc_qt_Decoder 3 | 4 | 5 | 6 | def Decoder(model_config, *args): 7 | name_dec_map = { 8 | 'disc_by_round': Disc_by_round_Decoder, 9 | 'disc_qt': Disc_qt_Decoder, 10 | } 11 | return name_dec_map[model_config["decoder"]](model_config, *args) 12 | -------------------------------------------------------------------------------- /visdialch/decoders/disc_by_round.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from visdialch.utils import DynamicRNN 4 | 5 | class Disc_by_round_Decoder(nn.Module): 6 | def __init__(self, config, vocabulary): 7 | super().__init__() 8 | self.config = config 9 | self.nhid = config["lstm_hidden_size"] 10 | 11 | self.word_embed = nn.Embedding( 12 | len(vocabulary), 13 | config["word_embedding_size"], 14 | padding_idx=vocabulary.PAD_INDEX, 15 | ) 16 | self.option_rnn = nn.LSTM( 17 | config["word_embedding_size"], 18 | config["lstm_hidden_size"], 19 | config["lstm_num_layers"], 20 | batch_first=True, 21 | dropout=config["dropout"], 22 | ) 23 | self.a2a = nn.Linear(self.nhid * 2, self.nhid) # this is useless in this version 24 | # Options are variable length padded sequences, use DynamicRNN. 25 | self.option_rnn = DynamicRNN(self.option_rnn) 26 | 27 | def forward(self, encoder_output, batch): 28 | """Given `encoder_output` + candidate option sequences, predict a score 29 | for each option sequence. 30 | 31 | Parameters 32 | ---------- 33 | encoder_output: torch.Tensor 34 | Output from the encoder through its forward pass. 
35 | (batch_size, num_rounds, lstm_hidden_size) 36 | """ 37 | options = batch["opt"] 38 | batch_size, num_options, max_sequence_length = options.size() 39 | options = options.contiguous().view(-1, max_sequence_length) 40 | 41 | options_length = batch["opt_len"] 42 | options_length = options_length.contiguous().view(-1) 43 | 44 | options_embed = self.word_embed(options) # b*100 20 300 45 | _, (options_feat, _) = self.option_rnn(options_embed, options_length) # b*100 512 46 | options_feat = options_feat.view(batch_size, num_options, self.nhid) 47 | 48 | 49 | encoder_output = encoder_output.unsqueeze(1).repeat(1, num_options, 1) 50 | scores = torch.sum(options_feat * encoder_output, -1) 51 | scores = scores.view(batch_size, num_options) 52 | 53 | return scores 54 | -------------------------------------------------------------------------------- /visdialch/decoders/disc_qt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import json 4 | from visdialch.utils import DynamicRNN 5 | 6 | 7 | class Disc_qt_Decoder(nn.Module): 8 | def __init__(self, config, vocabulary): 9 | super().__init__() 10 | self.config = config 11 | self.nhid = config["lstm_hidden_size"] 12 | 13 | self.word_embed = nn.Embedding( 14 | len(vocabulary), 15 | config["word_embedding_size"], 16 | padding_idx=vocabulary.PAD_INDEX, 17 | ) 18 | self.option_rnn = nn.LSTM( 19 | config["word_embedding_size"], 20 | config["lstm_hidden_size"], 21 | config["lstm_num_layers"], 22 | batch_first=True, 23 | dropout=config["dropout"], 24 | ) 25 | self.a2a = nn.Linear(self.nhid *2, self.nhid) 26 | self.option_rnn = DynamicRNN(self.option_rnn) 27 | path = "data/qt_scores.json" 28 | file = open(path, 'r') 29 | self.count_dict = json.loads(file.read()) 30 | file.close() 31 | file = open('data/qt_count.json','r') 32 | self.qt_file = json.loads(file.read()) 33 | self.qt_list = list(self.qt_file.keys()) 34 | file.close() 35 | 36 | def forward(self, encoder_output, batch): 37 | """Given `encoder_output` + candidate option sequences, predict a score 38 | for each option sequence. 39 | 40 | Parameters 41 | ---------- 42 | encoder_output: torch.Tensor 43 | Output from the encoder through its forward pass. 
44 | (batch_size, num_rounds, lstm_hidden_size) 45 | """ 46 | 47 | options = batch["opt"] 48 | batch_size, num_options, max_sequence_length = options.size() 49 | options = options.contiguous().view(-1, max_sequence_length) 50 | 51 | options_length = batch["opt_len"] 52 | options_length = options_length.contiguous().view(-1) 53 | 54 | options_embed = self.word_embed(options) #b*100 20 300 55 | _, (options_feat, _) = self.option_rnn(options_embed, options_length) #b*100 512 56 | options_feat = options_feat.view(batch_size, num_options, self.nhid) 57 | 58 | 59 | encoder_output = encoder_output.unsqueeze(1).repeat(1, num_options, 1) 60 | 61 | scores = torch.sum(options_feat * encoder_output, -1) 62 | scores = scores.view(batch_size, num_options) 63 | 64 | 65 | qt_score = torch.zeros_like(scores) 66 | qt_idx = batch['qt'] 67 | opt_idx = batch['opt_idx'] 68 | for b in range(batch_size): 69 | qt_key = self.qt_list[qt_idx[b]] 70 | ans_relevance = self.count_dict[qt_key] 71 | for k in range(100): 72 | idx_temp = str(opt_idx[b][k].detach().cpu().numpy()) 73 | if idx_temp in ans_relevance.keys(): 74 | qt_score[b][k] = 1 75 | 76 | return scores, qt_score 77 | 78 | 79 | -------------------------------------------------------------------------------- /visdialch/decoders/discvdr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from visdialch.utils import DynamicRNN 5 | 6 | 7 | class DiscriminativeVDRDecoder(nn.Module): 8 | def __init__(self, config, vocabulary): 9 | super().__init__() 10 | self.config = config 11 | 12 | self.word_embed = nn.Embedding( 13 | len(vocabulary), 14 | config["word_embedding_size"], 15 | padding_idx=vocabulary.PAD_INDEX, 16 | ) 17 | self.option_rnn = nn.LSTM( 18 | config["word_embedding_size"], 19 | config["lstm_hidden_size"], 20 | config["lstm_num_layers"], 21 | batch_first=True, 22 | dropout=config["dropout"], 23 | ) 24 | # self.a2a = nn.Linear(self.nhid * 2, self.nhid) # this is useless in this version 25 | # Options are variable length padded sequences, use DynamicRNN. 26 | self.option_rnn = DynamicRNN(self.option_rnn) 27 | 28 | def forward(self, encoder_output, batch): 29 | """Given `encoder_output` + candidate option sequences, predict a score 30 | for each option sequence. 31 | 32 | Parameters 33 | ---------- 34 | encoder_output: torch.Tensor 35 | Output from the encoder through its forward pass. 36 | (batch_size, num_rounds, lstm_hidden_size) 37 | """ 38 | 39 | options = batch["opt"] 40 | 41 | batch_size, num_rounds, num_options, max_sequence_length = ( 42 | options.size() 43 | ) 44 | options = options.view( 45 | batch_size * num_rounds * num_options, max_sequence_length 46 | ) 47 | 48 | options_length = batch["opt_len"] 49 | options_length = options_length.view( 50 | batch_size * num_rounds * num_options 51 | ) 52 | 53 | # Pick options with non-zero length (relevant for test split). 
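# The Disc_qt_Decoder above turns data/qt_scores.json into a 0/1 prior over the
# candidate answers: an option gets qt_score = 1 when its answer index appears
# in the relevance dict of the current round's question type. The same logic in
# isolation, on toy, hypothetical data (real batches use 100 options per round);
# the decoder simply returns both tensors and leaves their combination to the caller:
import torch

count_dict = {"is there": {"3": 1.0, "42": 0.6}}   # question type -> {answer idx: relevance}
qt_list = ["is there"]

batch_size, num_options = 2, 5
scores = torch.randn(batch_size, num_options)      # stand-in for the LSTM dot-product scores
qt_idx = torch.tensor([0, 0])                      # question-type index per sample
opt_idx = torch.tensor([[3, 9, 42, 7, 1],          # candidate answer indices
                        [0, 3, 5, 6, 8]])

qt_score = torch.zeros_like(scores)
for b in range(batch_size):
    relevant = count_dict[qt_list[int(qt_idx[b])]]
    for k in range(num_options):
        if str(int(opt_idx[b][k])) in relevant:
            qt_score[b][k] = 1.0
# qt_score == [[1, 0, 1, 0, 0], [0, 1, 0, 0, 0]]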
54 | nonzero_options_length_indices = options_length.nonzero().squeeze() 55 | nonzero_options_length = options_length[nonzero_options_length_indices] 56 | nonzero_options = options[nonzero_options_length_indices] 57 | 58 | # shape: (batch_size * num_rounds * num_options, max_sequence_length, 59 | # word_embedding_size) 60 | # FOR TEST SPLIT, shape: (batch_size * 1, num_options, 61 | # max_sequence_length, word_embedding_size) 62 | nonzero_options_embed = self.word_embed(nonzero_options) 63 | 64 | # shape: (batch_size * num_rounds * num_options, lstm_hidden_size) 65 | # FOR TEST SPLIT, shape: (batch_size * 1, num_options, 66 | # lstm_hidden_size) 67 | _, (nonzero_options_embed, _) = self.option_rnn( 68 | nonzero_options_embed, nonzero_options_length 69 | ) 70 | 71 | options_embed = torch.zeros( 72 | batch_size * num_rounds * num_options, 73 | nonzero_options_embed.size(-1), 74 | device=nonzero_options_embed.device, 75 | ) 76 | options_embed[nonzero_options_length_indices] = nonzero_options_embed 77 | 78 | # Repeat encoder output for every option. 79 | # shape: (batch_size, num_rounds, num_options, max_sequence_length) 80 | encoder_output = encoder_output.unsqueeze(2).repeat( 81 | 1, 1, num_options, 1 82 | ) 83 | 84 | # Shape now same as `options`, can calculate dot product similarity. 85 | lstm_hidden_size = self.config["lstm_hidden_size"] 86 | encoder_output = encoder_output.view( 87 | batch_size * num_rounds * num_options, 88 | lstm_hidden_size, 89 | ) 90 | # shape: (batch_size * num_rounds * num_options, 1, lstm_hidden_state) 91 | scores = torch.sum(options_embed * encoder_output, -1) 92 | scores = scores.view(batch_size, num_rounds, num_options) 93 | return scores 94 | -------------------------------------------------------------------------------- /visdialch/encoders/Coatt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from visdialch.utils import DynamicRNN 4 | 5 | 6 | class CoAtt_Encoder(nn.Module): 7 | def __init__(self, config, vocabulary): 8 | super().__init__() 9 | self.config = config 10 | self.dropout = config['dropout'] 11 | self.nhid = config['lstm_hidden_size'] 12 | self.img_feature_size = config['img_feature_size'] 13 | self.ninp = config['word_embedding_size'] 14 | self.word_embed = nn.Embedding( 15 | len(vocabulary), 16 | config["word_embedding_size"], 17 | padding_idx=vocabulary.PAD_INDEX, 18 | ) 19 | self.hist_rnn = nn.LSTM( 20 | config["word_embedding_size"], 21 | config["lstm_hidden_size"], 22 | config["lstm_num_layers"], 23 | batch_first=True, 24 | dropout=config["dropout"], 25 | ) 26 | self.ques_rnn = nn.LSTM( 27 | config["word_embedding_size"], 28 | config["lstm_hidden_size"], 29 | config["lstm_num_layers"], 30 | batch_first=True, 31 | dropout=config["dropout"], 32 | ) 33 | self.dropout = nn.Dropout(p=config["dropout_fc"]) 34 | self.hist_rnn = DynamicRNN(self.hist_rnn) 35 | self.ques_rnn = DynamicRNN(self.ques_rnn) 36 | ################################################################origin 37 | ##q c att on img 38 | self.Wq2 = nn.Linear(self.nhid, self.nhid) 39 | self.Wi2 = nn.Linear(self.img_feature_size, self.nhid) 40 | self.Wall2 = nn.Linear(self.nhid, 1) 41 | 42 | ###########################add 43 | ##q att on h 44 | self.Wq1 = nn.Linear(self.nhid, self.nhid) 45 | self.Wh1 = nn.Linear(self.nhid, self.nhid) 46 | self.Wi1 = nn.Linear(self.img_feature_size, self.nhid) 47 | self.Wqvh1 = nn.Linear(self.nhid, 1) 48 | ##hv att on q 49 | self.Wq3 = nn.Linear(self.nhid, self.nhid) 50 | 
self.Wh3 = nn.Linear(self.nhid, self.nhid) 51 | self.Wi3 = nn.Linear(self.img_feature_size, self.nhid) 52 | self.Wqvh3 = nn.Linear(self.nhid, 1) 53 | ## step4 54 | self.Wq4 = nn.Linear(self.nhid, self.nhid) 55 | self.Wh4 = nn.Linear(self.nhid, self.nhid) 56 | self.Wi4 = nn.Linear(self.img_feature_size, self.nhid) 57 | self.Wall4 = nn.Linear(self.nhid, 1) 58 | ## 59 | self.fusion = nn.Linear(self.nhid*2 + self.img_feature_size, self.nhid) 60 | ########################################new 61 | for m in self.modules(): 62 | if isinstance(m, nn.Linear): 63 | nn.init.kaiming_uniform_(m.weight.data) 64 | if m.bias is not None: 65 | nn.init.constant_(m.bias.data, 0) 66 | ###ATT STEP1 67 | def q_att_on_img(self, ques_feat, img_feat): 68 | batch_size = ques_feat.size(0) 69 | region_size = img_feat.size(1) 70 | device = ques_feat.device 71 | q_emb = self.Wq2(ques_feat).view(batch_size, -1, self.nhid) 72 | i_emb = self.Wi2(img_feat).view(batch_size, -1, self.nhid) 73 | all_score = self.Wall2( 74 | self.dropout( 75 | torch.tanh(i_emb + q_emb.repeat(1, region_size, 1)) 76 | ) 77 | ).view(batch_size, -1) 78 | img_final_feat = torch.bmm( 79 | torch.softmax(all_score, dim = -1 ) 80 | .view(batch_size,1,-1),img_feat) 81 | return img_final_feat.view(batch_size,-1) 82 | ###ATT STEP2 83 | def qv_att_on_his(self, ques_feat, img_feat, his_feat): 84 | batch_size = ques_feat.size(0) 85 | rnd = his_feat.size(1) 86 | device = ques_feat.device 87 | q_emb = self.Wq1(ques_feat).view(batch_size, -1, self.nhid) 88 | i_emb = self.Wi1(img_feat).view(batch_size, -1, self.nhid) 89 | h_emb = self.Wh1(his_feat) 90 | 91 | score = self.Wqvh1( 92 | self.dropout( 93 | torch.tanh(h_emb + q_emb.repeat(1, rnd, 1)+ i_emb.repeat(1, rnd, 1)) 94 | ) 95 | ).view(batch_size,-1) 96 | weight = torch.softmax(score, dim = -1 ) 97 | atted_his_feat = torch.bmm(weight.view(batch_size,1,-1) ,his_feat) 98 | return atted_his_feat 99 | ###ATT STEP2 100 | def hv_att_in_ques(self, his_feat, img_feat, q_output, ques_len): 101 | batch_size = q_output.size(0) 102 | q_emb_length = q_output.size(1) 103 | device = his_feat.device 104 | q_emb = self.Wq3(q_output) 105 | i_emb = self.Wi3(img_feat).view(batch_size, -1, self.nhid) 106 | h_emb = self.Wh3(his_feat).view(batch_size, -1, self.nhid) 107 | score = self.Wqvh3( 108 | self.dropout( 109 | torch.tanh(q_emb + h_emb.repeat(1, q_emb_length, 1)+ i_emb.repeat(1, q_emb_length, 1)) 110 | ) 111 | ).view(batch_size,-1) 112 | mask = score.detach().eq(0) 113 | for i in range(batch_size): 114 | mask[i,ques_len[i]:] = 1 115 | score.masked_fill_(mask, -1e5) 116 | weight = torch.softmax(score, dim = -1 ) 117 | atted_his_ques = torch.bmm(weight.view(batch_size,1,-1) , q_output) 118 | return atted_his_ques 119 | ###ATT STEP4 120 | def qh_att_in_img(self, ques_feat, his_feat, img_feat): 121 | batch_size = ques_feat.size(0) 122 | region_size = img_feat.size(1) 123 | q_emb = self.Wq4(ques_feat).view(batch_size, -1, self.nhid) 124 | h_emb = self.Wh4(his_feat).view(batch_size, -1, self.nhid) 125 | i_emb = self.Wi4(img_feat) 126 | all_score = self.Wall4( 127 | self.dropout( 128 | torch.tanh(i_emb + q_emb.repeat(1, region_size, 1) +h_emb.repeat(1, region_size, 1)) 129 | ) 130 | ).view(batch_size, -1) 131 | img_final_feat = torch.bmm( 132 | torch.softmax(all_score, dim = -1 ) 133 | .view(batch_size,1,-1),img_feat) 134 | return img_final_feat.view(batch_size,-1), torch.softmax(all_score, dim = -1) 135 | 136 | ##################################################### 137 | 138 | def forward(self, batch): 139 | img = batch["img_feat"] # b 
36 2048 140 | ques = batch["ques"] # b q_len 141 | his = batch["hist"] # b rnd q_len*2 142 | batch_size, rnd, max_his_length = his.size() 143 | ques_len = batch["ques_len"] 144 | 145 | # embed questions 146 | ques_location = batch['ques_len'].view(-1).cpu().numpy() - 1 147 | ques_embed = self.word_embed(ques) # b 20 300 148 | q_output, _ = self.ques_rnn(ques_embed, ques_len.view(-1)) # b rnd 1024 149 | ques_feat = q_output[range(batch_size), ques_location,:] 150 | 151 | ####his emb 152 | his = his.contiguous().view(-1, max_his_length) 153 | his_embed = self.word_embed(his) # b*rnd 40 300 154 | _, (his_feat, _) = self.hist_rnn(his_embed, batch["hist_len"].contiguous().view(-1)) # b*rnd step 1024 155 | his_feat = his_feat.view(batch_size, rnd, self.nhid) 156 | 157 | ############### ATT step1: q att on img -> v_1 158 | img_atted_feat_v1 = self.q_att_on_img(ques_feat, img).view(batch_size, self.img_feature_size) 159 | ############### ATT step2: q v att on his -> h_f 160 | his_atted_feat_f = self.qv_att_on_his(ques_feat, img_atted_feat_v1, his_feat).view(batch_size, self.nhid) 161 | ############### ATT step3: v_1 h_f att on ques -> q_f 162 | ques_atted_feat_f = self.hv_att_in_ques(his_atted_feat_f, img_atted_feat_v1, q_output, ques_len).view(batch_size, self.nhid) 163 | ############### ATT step4: q_f h_f att on img -> v_f 164 | img_atted_feat_f, img_att = self.qh_att_in_img(ques_atted_feat_f, his_atted_feat_f, img) 165 | img_atted_feat_f = img_atted_feat_f.view(batch_size, self.img_feature_size) 166 | 167 | fused_vector = torch.cat((ques_atted_feat_f, his_atted_feat_f, img_atted_feat_f), dim=-1) 168 | fused_embedding = torch.tanh(self.fusion(fused_vector)).view(batch_size, -1) 169 | return fused_embedding # out is b * 512 170 | -------------------------------------------------------------------------------- /visdialch/encoders/Coatt_withP1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from visdialch.utils import DynamicRNN 4 | 5 | 6 | class CoAtt_withP1_Encoder(nn.Module): 7 | def __init__(self, config, vocabulary): 8 | super().__init__() 9 | self.config = config 10 | self.dropout = config['dropout'] 11 | self.nhid = config['lstm_hidden_size'] 12 | self.img_feature_size = config['img_feature_size'] 13 | self.ninp = config['word_embedding_size'] 14 | self.word_embed = nn.Embedding( 15 | len(vocabulary), 16 | config["word_embedding_size"], 17 | padding_idx=vocabulary.PAD_INDEX, 18 | ) 19 | self.hist_rnn = nn.LSTM( 20 | config["word_embedding_size"], 21 | config["lstm_hidden_size"], 22 | config["lstm_num_layers"], 23 | batch_first=True, 24 | dropout=config["dropout"], 25 | ) 26 | self.ques_rnn = nn.LSTM( 27 | config["word_embedding_size"], 28 | config["lstm_hidden_size"], 29 | config["lstm_num_layers"], 30 | batch_first=True, 31 | dropout=config["dropout"], 32 | ) 33 | self.dropout = nn.Dropout(p=config["dropout_fc"]) 34 | self.hist_rnn = DynamicRNN(self.hist_rnn) 35 | self.ques_rnn = DynamicRNN(self.ques_rnn) 36 | ################################################################origin 37 | ##q c att on img 38 | self.Wq2 = nn.Linear(self.nhid, self.nhid) 39 | self.Wi2 = nn.Linear(self.img_feature_size, self.nhid) 40 | self.Wall2 = nn.Linear(self.nhid, 1) 41 | 42 | ###########################add 43 | ##q att on h 44 | self.Wq1 = nn.Linear(self.nhid, self.nhid) 45 | self.Wh1 = nn.Linear(self.nhid, self.nhid) 46 | self.Wi1 = nn.Linear(self.img_feature_size, self.nhid) 47 | self.Wqvh1 = nn.Linear(self.nhid, 1) 48 | 
##hv att on q 49 | self.Wq3 = nn.Linear(self.nhid, self.nhid) 50 | self.Wh3 = nn.Linear(self.nhid, self.nhid) 51 | self.Wi3 = nn.Linear(self.img_feature_size, self.nhid) 52 | self.Wqvh3 = nn.Linear(self.nhid, 1) 53 | ## step4 54 | self.Wq4 = nn.Linear(self.nhid, self.nhid) 55 | self.Wh4 = nn.Linear(self.nhid, self.nhid) 56 | self.Wi4 = nn.Linear(self.img_feature_size, self.nhid) 57 | self.Wall4 = nn.Linear(self.nhid, 1) 58 | ## 59 | self.fusion = nn.Linear(self.nhid + self.img_feature_size, self.nhid) 60 | ########################################new 61 | for m in self.modules(): 62 | if isinstance(m, nn.Linear): 63 | nn.init.kaiming_uniform_(m.weight.data) 64 | if m.bias is not None: 65 | nn.init.constant_(m.bias.data, 0) 66 | ###ATT STEP1 67 | def q_att_on_img(self, ques_feat, img_feat): 68 | batch_size = ques_feat.size(0) 69 | region_size = img_feat.size(1) 70 | device = ques_feat.device 71 | q_emb = self.Wq2(ques_feat).view(batch_size, -1, self.nhid) 72 | i_emb = self.Wi2(img_feat).view(batch_size, -1, self.nhid) 73 | all_score = self.Wall2( 74 | self.dropout( 75 | torch.tanh(i_emb + q_emb.repeat(1, region_size, 1)) 76 | ) 77 | ).view(batch_size, -1) 78 | img_final_feat = torch.bmm( 79 | torch.softmax(all_score, dim = -1 ) 80 | .view(batch_size,1,-1),img_feat) 81 | return img_final_feat.view(batch_size,-1) 82 | ###ATT STEP2 83 | def qv_att_on_his(self, ques_feat, img_feat, his_feat): 84 | batch_size = ques_feat.size(0) 85 | rnd = his_feat.size(1) 86 | device = ques_feat.device 87 | q_emb = self.Wq1(ques_feat).view(batch_size, -1, self.nhid) 88 | i_emb = self.Wi1(img_feat).view(batch_size, -1, self.nhid) 89 | h_emb = self.Wh1(his_feat) 90 | 91 | score = self.Wqvh1( 92 | self.dropout( 93 | torch.tanh(h_emb + q_emb.repeat(1, rnd, 1)+ i_emb.repeat(1, rnd, 1)) 94 | ) 95 | ).view(batch_size,-1) 96 | weight = torch.softmax(score, dim = -1 ) 97 | atted_his_feat = torch.bmm(weight.view(batch_size,1,-1) ,his_feat) 98 | return atted_his_feat 99 | ###ATT STEP2 100 | def hv_att_in_ques(self, his_feat, img_feat, q_output, ques_len): 101 | batch_size = q_output.size(0) 102 | q_emb_length = q_output.size(1) 103 | device = his_feat.device 104 | q_emb = self.Wq3(q_output) 105 | i_emb = self.Wi3(img_feat).view(batch_size, -1, self.nhid) 106 | h_emb = self.Wh3(his_feat).view(batch_size, -1, self.nhid) 107 | score = self.Wqvh3( 108 | self.dropout( 109 | torch.tanh(q_emb + h_emb.repeat(1, q_emb_length, 1)+ i_emb.repeat(1, q_emb_length, 1)) 110 | ) 111 | ).view(batch_size,-1) 112 | mask = score.detach().eq(0) 113 | for i in range(batch_size): 114 | mask[i,ques_len[i]:] = 1 115 | score.masked_fill_(mask, -1e5) 116 | weight = torch.softmax(score, dim = -1 ) 117 | atted_his_ques = torch.bmm(weight.view(batch_size,1,-1) , q_output) 118 | return atted_his_ques 119 | ###ATT STEP4 120 | def qh_att_in_img(self, ques_feat, his_feat, img_feat): 121 | batch_size = ques_feat.size(0) 122 | region_size = img_feat.size(1) 123 | q_emb = self.Wq4(ques_feat).view(batch_size, -1, self.nhid) 124 | h_emb = self.Wh4(his_feat).view(batch_size, -1, self.nhid) 125 | i_emb = self.Wi4(img_feat) 126 | all_score = self.Wall4( 127 | self.dropout( 128 | torch.tanh(i_emb + q_emb.repeat(1, region_size, 1) +h_emb.repeat(1, region_size, 1)) 129 | ) 130 | ).view(batch_size, -1) 131 | img_final_feat = torch.bmm( 132 | torch.softmax(all_score, dim = -1 ) 133 | .view(batch_size,1,-1),img_feat) 134 | return img_final_feat.view(batch_size,-1), torch.softmax(all_score, dim = -1) 135 | 136 | ##################################################### 137 | 
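# Principle-P1 variant: unlike CoAtt_Encoder, the forward pass that follows drops
# the history term from the final representation -- attention step 4 uses the
# caption round his_feat[:, 0] instead of the attended history, and self.fusion
# above takes nhid + img_feature_size (question + image) rather than the
# nhid * 2 + img_feature_size (question + history + image) used in Coatt.py.
# This appears to be how P1 (cutting the direct history-to-answer path) is
# realized for this encoder. A toy illustration of the caption-only slice:
import torch

batch_size, num_rounds, nhid = 2, 10, 512            # hypothetical sizes
his_feat = torch.randn(batch_size, num_rounds, nhid)  # encoded history, round 0 = caption
caption_only = his_feat[:, 0]                         # (batch_size, nhid)
print(caption_only.shape)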
138 | def forward(self, batch): 139 | img = batch["img_feat"] # b 36 2048 140 | ques = batch["ques"] # b q_len 141 | his = batch["hist"] # b rnd q_len*2 142 | batch_size, rnd, max_his_length = his.size() 143 | ques_len = batch["ques_len"] 144 | 145 | # embed questions 146 | ques_location = batch['ques_len'].view(-1).cpu().numpy() - 1 147 | ques_embed = self.word_embed(ques) # b 20 300 148 | q_output, _ = self.ques_rnn(ques_embed, ques_len.view(-1)) # b rnd 1024 149 | ques_feat = q_output[range(batch_size), ques_location,:] 150 | 151 | ####his emb 152 | his = his.contiguous().view(-1, max_his_length) 153 | his_embed = self.word_embed(his) # b*rnd 40 300 154 | _, (his_feat, _) = self.hist_rnn(his_embed, batch["hist_len"].contiguous().view(-1)) # b*rnd step 1024 155 | his_feat = his_feat.view(batch_size, rnd, self.nhid) 156 | 157 | ############### ATT step1: q att on img -> v_1 158 | img_atted_feat_v1 = self.q_att_on_img(ques_feat, img).view(batch_size, self.img_feature_size) 159 | ############### ATT step2: q v att on his -> h_f 160 | his_atted_feat_f = self.qv_att_on_his(ques_feat, img_atted_feat_v1, his_feat).view(batch_size, self.nhid) 161 | ############### ATT step3: v_1 h_f att on ques -> q_f 162 | ques_atted_feat_f = self.hv_att_in_ques(his_atted_feat_f, img_atted_feat_v1, q_output, ques_len).view(batch_size, self.nhid) 163 | ############### ATT step4: q_f h_f att on img -> v_f 164 | img_atted_feat_f, img_att = self.qh_att_in_img(ques_atted_feat_f, his_feat[:,0], img) 165 | img_atted_feat_f = img_atted_feat_f.view(batch_size, self.img_feature_size) 166 | 167 | fused_vector = torch.cat((img_atted_feat_f, ques_feat), dim = -1) 168 | fused_embedding = torch.tanh(self.fusion(fused_vector)).view(batch_size, -1) 169 | return fused_embedding # out is b * 512 170 | -------------------------------------------------------------------------------- /visdialch/encoders/HCIAE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from visdialch.utils import DynamicRNN 4 | 5 | class HCIAE_Encoder(nn.Module): 6 | def __init__(self, config, vocabulary): 7 | super().__init__() 8 | self.config = config 9 | self.dropout = config['dropout'] 10 | self.nhid = config['lstm_hidden_size'] 11 | self.img_feature_size = config['img_feature_size'] 12 | self.ninp = config['word_embedding_size'] 13 | self.word_embed = nn.Embedding( 14 | len(vocabulary), 15 | config["word_embedding_size"], 16 | padding_idx=vocabulary.PAD_INDEX, 17 | ) 18 | self.hist_rnn = nn.LSTM( 19 | config["word_embedding_size"], 20 | config["lstm_hidden_size"], 21 | config["lstm_num_layers"], 22 | batch_first=True, 23 | dropout=config["dropout"], 24 | ) 25 | self.ques_rnn = nn.LSTM( 26 | config["word_embedding_size"], 27 | config["lstm_hidden_size"], 28 | config["lstm_num_layers"], 29 | batch_first=True, 30 | dropout=config["dropout"], 31 | ) 32 | self.dropout = nn.Dropout(p=config["dropout_fc"]) 33 | self.hist_rnn = DynamicRNN(self.hist_rnn) 34 | self.ques_rnn = DynamicRNN(self.ques_rnn) 35 | 36 | ##q c att on img 37 | self.Wq2 = nn.Linear(self.nhid, self.nhid) 38 | self.Wh2 = nn.Linear(self.nhid, self.nhid) 39 | self.Wi2 = nn.Linear(self.img_feature_size, self.nhid) 40 | self.Wall2 = nn.Linear(self.nhid, 1) 41 | 42 | ##fusion 43 | self.Wq3 = nn.Linear(self.nhid , self.nhid ) 44 | self.Wc3 = nn.Linear(self.nhid , self.nhid ) 45 | self.fusion = nn.Linear(self.nhid * 2 + self.img_feature_size, self.nhid) 46 | ###cap att img 47 | self.Wc4 = nn.Linear(self.nhid , self.nhid) 
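# The *_att_on_* helpers in these encoders follow a common attention template
# built from Wq* / Wh* / Wi* / Wall* projections like the ones defined here:
# project the query and the candidates to nhid, broadcast-add, tanh, score with
# Linear(nhid, 1), softmax, then pool with a batched matmul (a few variants
# multiply instead of add inside the tanh). A self-contained sketch with toy
# sizes -- the encoders call .repeat() explicitly, but plain broadcasting, as
# below, computes the same thing:
import torch
from torch import nn

batch_size, num_regions, nhid, img_dim = 2, 36, 512, 2048
Wq = nn.Linear(nhid, nhid)
Wi = nn.Linear(img_dim, nhid)
Wall = nn.Linear(nhid, 1)

ques_feat = torch.randn(batch_size, nhid)                 # query vector
img_feat = torch.randn(batch_size, num_regions, img_dim)  # candidates to attend over

q_emb = Wq(ques_feat).unsqueeze(1)                    # (b, 1, nhid)
i_emb = Wi(img_feat)                                  # (b, 36, nhid)
scores = Wall(torch.tanh(i_emb + q_emb)).squeeze(-1)  # (b, 36)
weights = torch.softmax(scores, dim=-1)
attended = torch.bmm(weights.unsqueeze(1), img_feat).squeeze(1)  # (b, 2048)
print(attended.shape)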
48 | self.Wi4 = nn.Linear(self.img_feature_size, self.nhid) 49 | self.Wall4 = nn.Linear(self.nhid, 1) 50 | 51 | self.q_multi1 = nn.Linear(self.nhid, self.nhid) 52 | self.q_multi2 = nn.Linear(self.nhid, 3) 53 | 54 | ##q att on h 55 | self.Wq1 = nn.Linear(self.nhid, self.nhid) 56 | self.Wh1 = nn.Linear(self.nhid, self.nhid) 57 | self.Wqh1 = nn.Linear(self.nhid, 1) 58 | 59 | for m in self.modules(): 60 | if isinstance(m, nn.Linear): 61 | nn.init.kaiming_uniform_(m.weight.data) 62 | if m.bias is not None: 63 | nn.init.constant_(m.bias.data, 0) 64 | def qh_att_on_img(self, ques_feat, his_feat, img_feat): 65 | batch_size = ques_feat.size(0) 66 | region_size = img_feat.size(1) 67 | device = ques_feat.device 68 | q_emb = self.Wq2(ques_feat).view(batch_size, -1, self.nhid) 69 | i_emb = self.Wi2(img_feat).view(batch_size, -1, self.nhid) 70 | h_emb = self.Wh2(his_feat).view(batch_size, -1, self.nhid) 71 | all_score = self.Wall2( 72 | self.dropout( 73 | torch.tanh(i_emb + q_emb.repeat(1, region_size, 1)+ h_emb.repeat(1, region_size, 1)) 74 | ) 75 | ).view(batch_size, -1) 76 | img_final_feat = torch.bmm( 77 | torch.softmax(all_score, dim = -1 ) 78 | .view(batch_size,1,-1),img_feat) 79 | return img_final_feat.view(batch_size,-1) 80 | def ques_att_on_his(self, ques_feat, his_feat): 81 | batch_size = ques_feat.size(0) 82 | rnd = his_feat.size(1) 83 | device = ques_feat.device 84 | q_emb = self.Wq1(ques_feat).view(batch_size, -1, self.nhid) 85 | h_emb = self.Wh1(his_feat) 86 | score = self.Wqh1( 87 | self.dropout( 88 | torch.tanh(h_emb + q_emb.repeat(1, rnd, 1)) 89 | ) 90 | ).view(batch_size,-1) 91 | weight = torch.softmax(score, dim = -1 ) 92 | atted_his_feat = torch.bmm(weight.view(batch_size,1,-1) ,his_feat) 93 | return atted_his_feat 94 | ##################################################### 95 | 96 | def forward(self, batch): 97 | img = batch["img_feat"] # b 36 2048 98 | ques = batch["ques"] # b q_len 99 | his = batch["hist"] # b rnd q_len*2 100 | batch_size, rnd, max_his_length = his.size() 101 | ques_len = batch["ques_len"] 102 | 103 | # embed questions 104 | ques_location = batch['ques_len'].view(-1).cpu().numpy() - 1 105 | ques_embed = self.word_embed(ques) # b 20 300 106 | q_output, _ = self.ques_rnn(ques_embed, ques_len.view(-1)) # b rnd 1024 107 | ques_feat = q_output[range(batch_size), ques_location,:] 108 | ####his emb 109 | his = his.contiguous().view(-1, max_his_length) 110 | his_embed = self.word_embed(his) # b*rnd 40 300 111 | _, (his_feat, _) = self.hist_rnn(his_embed, batch["hist_len"].contiguous().view(-1)) # b*rnd step 1024 112 | his_feat = his_feat.view(batch_size, rnd, self.nhid) 113 | q_att_his_feat = self.ques_att_on_his(ques_feat, his_feat).view(batch_size, self.nhid) # b 512 114 | 115 | q_att_img_feat = self.qh_att_on_img(ques_feat, q_att_his_feat, img).view(batch_size, -1) 116 | 117 | fused_vector = torch.cat((ques_feat, q_att_his_feat, q_att_img_feat), dim = -1) 118 | fused_vector = self.dropout(fused_vector) 119 | fused_embedding = torch.tanh(self.fusion(fused_vector)).view(batch_size, -1) 120 | 121 | return fused_embedding # out is b * 512 122 | -------------------------------------------------------------------------------- /visdialch/encoders/HCIAE_withP1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from visdialch.utils import DynamicRNN 4 | 5 | class HCIAE_withP1_Encoder(nn.Module): 6 | def __init__(self, config, vocabulary): 7 | super().__init__() 8 | self.config = config 9 | 
self.dropout = config['dropout'] 10 | self.nhid = config['lstm_hidden_size'] 11 | self.img_feature_size = config['img_feature_size'] 12 | self.ninp = config['word_embedding_size'] 13 | 14 | self.word_embed = nn.Embedding( 15 | len(vocabulary), 16 | config["word_embedding_size"], 17 | padding_idx=vocabulary.PAD_INDEX, 18 | ) 19 | self.hist_rnn = nn.LSTM( 20 | config["word_embedding_size"], 21 | config["lstm_hidden_size"], 22 | config["lstm_num_layers"], 23 | batch_first=True, 24 | dropout=config["dropout"], 25 | ) 26 | self.ques_rnn = nn.LSTM( 27 | config["word_embedding_size"], 28 | config["lstm_hidden_size"], 29 | config["lstm_num_layers"], 30 | batch_first=True, 31 | dropout=config["dropout"], 32 | ) 33 | self.dropout = nn.Dropout(p=config["dropout_fc"]) 34 | self.hist_rnn = DynamicRNN(self.hist_rnn) 35 | self.ques_rnn = DynamicRNN(self.ques_rnn) 36 | 37 | ##q c att on img 38 | self.Wq2 = nn.Linear(self.nhid, self.nhid) 39 | self.Wh2 = nn.Linear(self.nhid, self.nhid) 40 | self.Wi2 = nn.Linear(self.img_feature_size, self.nhid) 41 | self.Wall2 = nn.Linear(self.nhid, 1) 42 | 43 | ##fusion 44 | self.Wq3 = nn.Linear(self.nhid , self.nhid ) 45 | self.Wc3 = nn.Linear(self.nhid , self.nhid ) 46 | self.fusion = nn.Linear(self.nhid + self.img_feature_size, self.nhid) 47 | ###cap att img 48 | self.Wc4 = nn.Linear(self.nhid , self.nhid) 49 | self.Wi4 = nn.Linear(self.img_feature_size, self.nhid) 50 | self.Wall4 = nn.Linear(self.nhid, 1) 51 | self.q_multi1 = nn.Linear(self.nhid, self.nhid) 52 | self.q_multi2 = nn.Linear(self.nhid, 3) 53 | 54 | ##q att on h 55 | self.Wq1 = nn.Linear(self.nhid, self.nhid) 56 | self.Wh1 = nn.Linear(self.nhid, self.nhid) 57 | self.Wqh1 = nn.Linear(self.nhid, 1) 58 | 59 | ###his on q 60 | self.Wcs5 = nn.Sequential(self.dropout,nn.Linear(self.nhid , self.nhid )) 61 | self.Wq5 = nn.Sequential(self.dropout,nn.Linear(self.nhid , self.nhid )) 62 | self.Wall5 = nn.Linear(self.nhid, 1) 63 | 64 | for m in self.modules(): 65 | if isinstance(m, nn.Linear): 66 | nn.init.kaiming_uniform_(m.weight.data) 67 | if m.bias is not None: 68 | nn.init.constant_(m.bias.data, 0) 69 | def qh_att_on_img(self, ques_feat, his_feat, img_feat): 70 | batch_size = ques_feat.size(0) 71 | region_size = img_feat.size(1) 72 | device = ques_feat.device 73 | q_emb = self.Wq2(ques_feat).view(batch_size, -1, self.nhid) 74 | i_emb = self.Wi2(img_feat).view(batch_size, -1, self.nhid) 75 | h_emb = self.Wh2(his_feat).view(batch_size, -1, self.nhid) 76 | all_score = self.Wall2( 77 | self.dropout( 78 | torch.tanh(i_emb + q_emb.repeat(1, region_size, 1)+ h_emb.repeat(1, region_size, 1)) 79 | ) 80 | ).view(batch_size, -1) 81 | img_final_feat = torch.bmm( 82 | torch.softmax(all_score, dim = -1 ) 83 | .view(batch_size,1,-1),img_feat) 84 | return img_final_feat.view(batch_size,-1) 85 | def ques_att_on_his(self, ques_feat, his_feat): 86 | batch_size = ques_feat.size(0) 87 | rnd = his_feat.size(1) 88 | device = ques_feat.device 89 | q_emb = self.Wq1(ques_feat).view(batch_size, -1, self.nhid) 90 | h_emb = self.Wh1(his_feat) 91 | score = self.Wqh1( 92 | self.dropout( 93 | torch.tanh(h_emb + q_emb.repeat(1, rnd, 1)) 94 | ) 95 | ).view(batch_size,-1) 96 | weight = torch.softmax(score, dim = -1 ) 97 | atted_his_feat = torch.bmm(weight.view(batch_size,1,-1) ,his_feat) 98 | return atted_his_feat 99 | ##################################################### 100 | 101 | def forward(self, batch): 102 | img = batch["img_feat"] # b 36 2048 103 | ques = batch["ques"] # b q_len 104 | his = batch["hist"] # b rnd q_len*2 105 | batch_size, 
rnd, max_his_length = his.size() 106 | ques_len = batch["ques_len"] 107 | 108 | # embed questions 109 | ques_location = batch['ques_len'].view(-1).cpu().numpy() - 1 110 | ques_embed = self.word_embed(ques) # b 20 300 111 | q_output, _ = self.ques_rnn(ques_embed, ques_len.view(-1)) # b rnd 1024 112 | ques_feat = q_output[range(batch_size), ques_location,:] 113 | 114 | ####his emb 115 | his = his.contiguous().view(-1, max_his_length) 116 | his_embed = self.word_embed(his) # b*rnd 40 300 117 | _, (his_feat, _) = self.hist_rnn(his_embed, batch["hist_len"].contiguous().view(-1)) # b*rnd step 1024 118 | his_feat = his_feat.view(batch_size, rnd, self.nhid) 119 | q_att_img_feat = self.qh_att_on_img(ques_feat, his_feat[:,0], img).view(batch_size, -1) 120 | 121 | fused_vector = torch.cat((ques_feat, q_att_img_feat), dim = -1) 122 | fused_vector = self.dropout(fused_vector) 123 | fused_embedding = torch.tanh(self.fusion(fused_vector)).view(batch_size, -1) 124 | 125 | return fused_embedding # out is b * 512 126 | -------------------------------------------------------------------------------- /visdialch/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from visdialch.encoders.lf_enhanced import LF_Enhanced_Encoder 2 | from visdialch.encoders.lf_enhanced_withP1 import LF_Enhanced_withP1_Encoder 3 | from visdialch.encoders.dict_encoder import Dict_Encoder 4 | 5 | 6 | def Encoder(model_config, *args): 7 | name_enc_map = { 8 | 'baseline_encoder': LF_Enhanced_Encoder, 9 | 'baseline_encoder_withP1': LF_Enhanced_withP1_Encoder, 10 | 'dict_encoder' : Dict_Encoder, 11 | } 12 | return name_enc_map[model_config["encoder"]](model_config, *args) 13 | -------------------------------------------------------------------------------- /visdialch/encoders/dict_encoder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | from visdialch.utils import DynamicRNN 6 | 7 | class Dict_Encoder(nn.Module): 8 | def __init__(self, config, vocabulary): 9 | super().__init__() 10 | self.config = config 11 | self.dropout = config['dropout'] 12 | self.nhid = config['lstm_hidden_size'] 13 | self.img_feature_size = config['img_feature_size'] 14 | self.ninp = config['word_embedding_size'] 15 | 16 | self.word_embed = nn.Embedding( 17 | len(vocabulary), 18 | config["word_embedding_size"], 19 | padding_idx=vocabulary.PAD_INDEX, 20 | ) 21 | self.hist_rnn = nn.LSTM( 22 | config["word_embedding_size"], 23 | config["lstm_hidden_size"], 24 | config["lstm_num_layers"], 25 | batch_first=True, 26 | dropout=config["dropout"], 27 | ) 28 | self.option_rnn = nn.LSTM( 29 | config["word_embedding_size"], 30 | config["lstm_hidden_size"], 31 | config["lstm_num_layers"], 32 | batch_first=True, 33 | dropout=config["dropout"], 34 | ) 35 | 36 | self.dropout = nn.Dropout(p=config["dropout_fc"]) 37 | self.hist_rnn = DynamicRNN(self.hist_rnn) 38 | self.option_rnn = DynamicRNN(self.option_rnn) 39 | self.Wc = nn.Linear(self.nhid * 2, self.nhid) 40 | self.Wd = nn.Linear(self.nhid, self.nhid) 41 | self.Wall = nn.Linear(self.nhid, 1) 42 | for m in self.modules(): 43 | if isinstance(m, nn.Linear): 44 | nn.init.kaiming_uniform_(m.weight.data) 45 | if m.bias is not None: 46 | nn.init.constant_(m.bias.data, 0) 47 | initial_path = 'data/100ans_feature.npy' 48 | initial_answer_feat = np.load(initial_path) 49 | self.user_dict = nn.Parameter(torch.FloatTensor(initial_answer_feat)) 50 | 51 | def through_dict(self, 
output_feat): #b x 100 x 512 -> b 100 x512 52 | batch_size = output_feat.size(0) 53 | q_size = output_feat.size(1) 54 | dict_size = self.user_dict.size(0) 55 | dict_feat = self.user_dict 56 | 57 | q_emb = output_feat.view(batch_size * q_size, -1, self.nhid) 58 | d_emb = self.Wd(dict_feat).view(-1, dict_size, self.nhid) 59 | all_score = self.Wall( 60 | self.dropout( 61 | torch.tanh(d_emb.repeat(batch_size * q_size, 1, 1) + q_emb.repeat(1, dict_size, 1)) 62 | ) 63 | ).view(batch_size * q_size, -1) 64 | dict_final_feat = torch.bmm( 65 | torch.softmax(all_score, dim = -1 ) 66 | .view(batch_size* q_size,1,-1),dict_feat.view(-1, dict_size, self.nhid).repeat(batch_size* q_size, 1, 1)) 67 | return dict_final_feat.view(batch_size,q_size,-1) 68 | 69 | def forward(self, batch): 70 | 71 | his = batch["hist"] # b rnd q_len*2 72 | batch_size, rnd, max_his_length = his.size() 73 | his = his.contiguous().view(-1, max_his_length) 74 | his_embed = self.word_embed(his) 75 | _, (his_feat, _) = self.hist_rnn(his_embed, batch["hist_len"].contiguous().view(-1)) # b*rnd step 1024 76 | his_feat = his_feat.view(batch_size, rnd, self.nhid) 77 | his_feat = torch.mean(his_feat, dim=1) 78 | his_feat = his_feat 79 | 80 | options = batch["opt"] 81 | batch_size, num_options, max_sequence_length = options.size() 82 | options = options.contiguous().view(-1, max_sequence_length) 83 | options_length = batch["opt_len"] 84 | options_length = options_length.contiguous().view(-1) 85 | options_embed = self.word_embed(options) 86 | _, (options_feat, _) = self.option_rnn(options_embed, options_length) 87 | options_feat = options_feat.view(batch_size, num_options, self.nhid) 88 | 89 | his_feat = his_feat.unsqueeze(1).repeat(1,options_feat.size(1),1) 90 | cat_feat = torch.cat((his_feat,options_feat),dim=-1) 91 | cat_feat = self.dropout(self.Wc(cat_feat)) 92 | 93 | output_feat = self.through_dict(cat_feat) # updated version (you also can try the key using (his; ques; ques and ans)) 94 | scores = torch.sum(options_feat * output_feat, -1) 95 | scores = scores.view(batch_size, num_options) 96 | 97 | return scores # out is b * 512 98 | -------------------------------------------------------------------------------- /visdialch/encoders/lf_enhanced.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from visdialch.utils import DynamicRNN 5 | 6 | 7 | class LF_Enhanced_Encoder(nn.Module): 8 | def __init__(self, config, vocabulary): 9 | super().__init__() 10 | self.config = config 11 | self.dropout = config['dropout'] 12 | self.nhid = config['lstm_hidden_size'] 13 | self.img_feature_size = config['img_feature_size'] 14 | self.ninp = config['word_embedding_size'] 15 | self.head_num = config['head_num'] 16 | self.word_embed = nn.Embedding( 17 | len(vocabulary), 18 | config["word_embedding_size"], 19 | padding_idx=vocabulary.PAD_INDEX, 20 | ) 21 | self.hist_rnn = nn.LSTM( 22 | config["word_embedding_size"], 23 | config["lstm_hidden_size"], 24 | config["lstm_num_layers"], 25 | batch_first=True, 26 | dropout=config["dropout"], 27 | ) 28 | self.ques_rnn = nn.LSTM( 29 | config["word_embedding_size"], 30 | config["lstm_hidden_size"], 31 | config["lstm_num_layers"], 32 | batch_first=True, 33 | dropout=config["dropout"], 34 | ) 35 | self.cap_rnn = nn.LSTM( 36 | config["word_embedding_size"], 37 | config["lstm_hidden_size"], 38 | config["lstm_num_layers"], 39 | batch_first=True, 40 | dropout=config["dropout"], 41 | ) 42 | self.dropout = nn.Dropout(p=config["dropout_fc"]) 43 
| self.hist_rnn = DynamicRNN(self.hist_rnn) 44 | self.ques_rnn = DynamicRNN(self.ques_rnn) 45 | self.cap_rnn = DynamicRNN(self.cap_rnn) 46 | 47 | ##q c att on img 48 | self.Wq2 = nn.Sequential(self.dropout, nn.Linear(self.nhid * 2, self.nhid)) 49 | self.Wi2 = nn.Sequential(self.dropout, nn.Linear(self.img_feature_size, self.nhid)) 50 | self.Wall2 = nn.Linear(self.nhid, 1) 51 | 52 | # q_att_on_cap 53 | self.Wqs3 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 54 | self.Wcs3 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 55 | self.Wc3 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 56 | self.Wall3 = nn.Linear(self.nhid, 1) 57 | self.c2c = nn.Sequential(self.dropout, nn.Linear(self.ninp, self.nhid)) 58 | 59 | # c_att_on_ques 60 | self.Wqs5 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 61 | self.Wcs5 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 62 | self.Wq5 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 63 | self.Wall5 = nn.Linear(self.nhid, 1) 64 | self.q2q = nn.Sequential(self.dropout, nn.Linear(self.ninp, self.nhid)) 65 | # q att on h 66 | self.Wq1 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 67 | self.Wh1 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 68 | self.Wqh1 = nn.Linear(self.nhid, 1) 69 | ###cap att img 70 | self.Wc4 = nn.Sequential(self.dropout, nn.Linear(self.nhid * 2, self.nhid)) 71 | self.Wi4 = nn.Sequential(self.dropout, nn.Linear(self.img_feature_size, self.nhid)) 72 | self.Wall4 = nn.Linear(self.nhid, 1) 73 | ##fusion 74 | self.fusion_1 = nn.Sequential( 75 | nn.Dropout(p=config["dropout_fc"]), 76 | nn.Linear(self.nhid * 2 + self.img_feature_size + self.nhid, self.nhid), 77 | nn.LeakyReLU() 78 | ) 79 | self.fusion_2 = nn.Sequential( 80 | nn.Dropout(p=config["dropout_fc"]), 81 | nn.Linear(self.nhid * 2 + self.img_feature_size + self.nhid, self.nhid), 82 | nn.LeakyReLU() 83 | ) 84 | self.fusion_3 = nn.Sequential( 85 | nn.Dropout(p=config["dropout_fc"]), 86 | nn.Linear(self.nhid * 2 + self.img_feature_size + self.nhid, self.nhid), 87 | nn.LeakyReLU() 88 | ) 89 | self.q_ref = nn.Sequential( 90 | nn.Dropout(p=config["dropout_fc"]), 91 | nn.Linear(self.nhid * 2, self.nhid), 92 | nn.LeakyReLU(), 93 | nn.Dropout(p=config["dropout_fc"]), 94 | nn.Linear(self.nhid, 2), 95 | nn.LeakyReLU() 96 | ) 97 | self.q_multi = nn.Sequential( 98 | nn.Dropout(p=config["dropout_fc"]), 99 | nn.Linear(self.nhid * 2, self.nhid), 100 | nn.LeakyReLU(), 101 | nn.Dropout(p=config["dropout_fc"]), 102 | nn.Linear(self.nhid, 3), 103 | nn.LeakyReLU() 104 | ) 105 | for m in self.modules(): 106 | if isinstance(m, nn.Linear): 107 | nn.init.kaiming_uniform_(m.weight.data) 108 | if m.bias is not None: 109 | nn.init.constant_(m.bias.data, 0) 110 | 111 | def q_att_on_cap(self, ques_prin, cap_prin, cap_feat, cap_len, cap_emb): 112 | batch_size = cap_feat.size(0) 113 | capfeat_len = cap_feat.size(1) 114 | q_emb = self.Wqs3(ques_prin).view(batch_size, -1, self.nhid) 115 | c_emb = self.Wcs3(cap_prin).view(batch_size, -1, self.nhid) 116 | cap_feat_new = self.Wc3(cap_feat) 117 | cap_score = self.Wall3( 118 | self.dropout( 119 | torch.tanh(cap_feat_new + q_emb.repeat(1, capfeat_len, 1) + c_emb.repeat(1, capfeat_len, 1)) 120 | ) 121 | ).view(batch_size, -1) 122 | mask = cap_score.detach().eq(0) 123 | for i in range(batch_size): 124 | mask[i, cap_len[i]:] = 1 125 | cap_score.masked_fill_(mask, -1e5) 126 | weight = torch.softmax(cap_score, dim=-1) 127 | final_cap_feat = torch.bmm( 128 | 
weight.view(batch_size, 1, -1), 129 | cap_emb).view(batch_size, -1) 130 | final_cap_feat = self.c2c(final_cap_feat) 131 | return final_cap_feat 132 | 133 | def c_att_on_ques(self, cap_prin, ques_prin, ques_feat, ques_len, ques_emb): 134 | batch_size = ques_feat.size(0) 135 | quesfeat_len = ques_feat.size(1) 136 | q_emb = self.Wqs5(ques_prin).view(batch_size, -1, self.nhid) 137 | c_emb = self.Wcs5(cap_prin).view(batch_size, -1, self.nhid) 138 | ques_feat_new = self.Wq5(ques_feat) 139 | ques_score = self.Wall5( 140 | self.dropout( 141 | torch.tanh(ques_feat_new + q_emb.repeat(1, quesfeat_len, 1) + c_emb.repeat(1, quesfeat_len, 1)) 142 | ) 143 | ).view(batch_size, -1) 144 | mask = ques_score.detach().eq(0) 145 | for i in range(batch_size): 146 | mask[i, ques_len[i]:] = 1 147 | ques_score.masked_fill_(mask, -1e5) 148 | weight = torch.softmax(ques_score, dim=-1) 149 | final_ques_feat = torch.bmm( 150 | weight.view(batch_size, 1, -1), 151 | ques_emb).view(batch_size, -1) 152 | final_ques_feat = self.q2q(final_ques_feat) 153 | return final_ques_feat 154 | 155 | def q_att_on_img(self, ques_feat, img_feat): 156 | batch_size = ques_feat.size(0) 157 | region_size = img_feat.size(1) 158 | device = ques_feat.device 159 | q_emb = self.Wq2(ques_feat).view(batch_size, -1, self.nhid) 160 | i_emb = self.Wi2(img_feat).view(batch_size, -1, self.nhid) 161 | all_score = self.Wall2( 162 | self.dropout( 163 | torch.tanh(i_emb * q_emb.repeat(1, region_size, 1)) 164 | ) 165 | ).view(batch_size, -1) 166 | img_final_feat = torch.bmm( 167 | torch.softmax(all_score, dim=-1) 168 | .view(batch_size, 1, -1), img_feat) 169 | return img_final_feat.view(batch_size, -1) 170 | 171 | def c_att_on_img(self, cap_feat, img_feat): 172 | batch_size = cap_feat.size(0) 173 | region_size = img_feat.size(1) 174 | device = cap_feat.device 175 | c_emb = self.Wc4(cap_feat).view(batch_size, -1, self.nhid) 176 | i_emb = self.Wi4(img_feat).view(batch_size, -1, self.nhid) 177 | all_score = self.Wall4( 178 | self.dropout( 179 | torch.tanh(i_emb * c_emb.repeat(1, region_size, 1)) 180 | ) 181 | ).view(batch_size, -1) 182 | img_final_feat = torch.bmm( 183 | torch.softmax(all_score, dim=-1) 184 | .view(batch_size, 1, -1), img_feat) 185 | return img_final_feat.view(batch_size, -1) 186 | 187 | ################################################add h 188 | def ques_att_on_his(self, ques_feat, his_feat): 189 | batch_size = ques_feat.size(0) 190 | rnd = his_feat.size(1) 191 | device = ques_feat.device 192 | q_emb = self.Wq1(ques_feat).view(batch_size, -1, self.nhid) 193 | h_emb = self.Wh1(his_feat) 194 | 195 | score = self.Wqh1( 196 | self.dropout( 197 | torch.tanh(h_emb + q_emb.repeat(1, rnd, 1)) 198 | ) 199 | ).view(batch_size, -1) 200 | weight = torch.softmax(score, dim=-1) 201 | atted_his_feat = torch.bmm(weight.view(batch_size, 1, -1), his_feat) 202 | return atted_his_feat 203 | 204 | ##################################################### 205 | 206 | def forward(self, batch): 207 | img = batch["img_feat"] # b 36 2048 208 | ques = batch["ques"] # b q_len 209 | his = batch["hist"] # b rnd q_len*2 210 | batch_size, rnd, max_his_length = his.size() 211 | cap = his[:, 0, :] 212 | ques_len = batch["ques_len"] 213 | cap_len = batch["hist_len"][:, 0] 214 | 215 | # embed questions 216 | ques_location = batch['ques_len'].view(-1).cpu().numpy() - 1 217 | ques_embed = self.word_embed(ques) # b 20 300 218 | q_output, _ = self.ques_rnn(ques_embed, ques_len.view(-1)) # b rnd 1024 219 | ques_encoded = q_output[range(batch_size), ques_location, :] 220 | 221 | # embed 
caption 222 | cap_location = cap_len.view(-1).cpu().numpy() - 1 223 | cap_emb = self.word_embed(cap.contiguous()) 224 | c_output, _ = self.cap_rnn(cap_emb, cap_len.view(-1)) 225 | cap_encoded = c_output[range(batch_size), cap_location, :] 226 | 227 | ####his emb 228 | his = his.contiguous().view(-1, max_his_length) 229 | his_embed = self.word_embed(his) # b*rnd 40 300 230 | _, (his_feat, _) = self.hist_rnn(his_embed, batch["hist_len"].contiguous().view(-1)) # b*rnd step 1024 231 | his_feat = his_feat.view(batch_size, rnd, self.nhid) 232 | q_att_his_feat = self.ques_att_on_his(ques_encoded, his_feat).view(batch_size, self.nhid) # b 512 233 | att_cap_feat_0 = self.q_att_on_cap(ques_encoded, cap_encoded, c_output, cap_len, 234 | cap_emb) # shape: (batch_size, 2*nhid) 235 | att_ques_feat_0 = self.c_att_on_ques(cap_encoded, ques_encoded, q_output, ques_len, ques_embed) 236 | his_att_cap = self.q_att_on_cap(q_att_his_feat, cap_encoded, c_output, cap_len, cap_emb) 237 | 238 | att_ques_feat = torch.cat((ques_encoded, att_ques_feat_0), dim=-1) 239 | att_cap_feat = torch.cat((cap_encoded, att_cap_feat_0), dim=-1) 240 | # attended ques on img 241 | q_att_img_feat = self.q_att_on_img(att_ques_feat, img).view(batch_size, -1, self.img_feature_size) 242 | # attended cap on img 243 | c_att_img_feat = self.c_att_on_img(att_cap_feat, img).view(batch_size, -1, self.img_feature_size) 244 | cated_feature = torch.cat((q_att_img_feat, c_att_img_feat), dim=1) 245 | # refine image attention 246 | q_gs = torch.softmax((self.q_ref(torch.cat((att_ques_feat_0, his_att_cap), dim=-1))), dim=-1).view(batch_size, 247 | 1, -1) 248 | final_img_feat = torch.bmm(q_gs, cated_feature).view(batch_size, -1) 249 | # final fusion 250 | fused_vector = torch.cat((att_ques_feat, final_img_feat, q_att_his_feat), dim=-1) 251 | fused_embedding_1 = self.fusion_1(fused_vector).view(batch_size, 1, -1) 252 | fused_embedding_2 = self.fusion_2(fused_vector).view(batch_size, 1, -1) 253 | fused_embedding_3 = self.fusion_3(fused_vector).view(batch_size, 1, -1) 254 | 255 | fused_embedding = torch.cat((fused_embedding_1, fused_embedding_2, fused_embedding_3), dim=1) 256 | q_multi = torch.softmax((self.q_multi(att_ques_feat)), dim=-1).view(batch_size, 1, -1) 257 | fuse_feat = torch.bmm(q_multi, fused_embedding).view(batch_size, -1) 258 | 259 | return fuse_feat # out is b * 512 260 | -------------------------------------------------------------------------------- /visdialch/encoders/lf_enhanced_withP1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from visdialch.utils import DynamicRNN 4 | 5 | 6 | class LF_Enhanced_withP1_Encoder(nn.Module): 7 | def __init__(self, config, vocabulary): 8 | super().__init__() 9 | self.config = config 10 | self.dropout = config['dropout'] 11 | self.nhid = config['lstm_hidden_size'] 12 | self.img_feature_size = config['img_feature_size'] 13 | self.ninp = config['word_embedding_size'] 14 | self.head_num = config['head_num'] 15 | self.word_embed = nn.Embedding( 16 | len(vocabulary), 17 | config["word_embedding_size"], 18 | padding_idx=vocabulary.PAD_INDEX, 19 | ) 20 | self.hist_rnn = nn.LSTM( 21 | config["word_embedding_size"], 22 | config["lstm_hidden_size"], 23 | config["lstm_num_layers"], 24 | batch_first=True, 25 | dropout=config["dropout"], 26 | ) 27 | self.ques_rnn = nn.LSTM( 28 | config["word_embedding_size"], 29 | config["lstm_hidden_size"], 30 | config["lstm_num_layers"], 31 | batch_first=True, 32 | dropout=config["dropout"], 33 | 
) 34 | self.cap_rnn = nn.LSTM( 35 | config["word_embedding_size"], 36 | config["lstm_hidden_size"], 37 | config["lstm_num_layers"], 38 | batch_first=True, 39 | dropout=config["dropout"], 40 | ) 41 | self.dropout = nn.Dropout(p=config["dropout_fc"]) 42 | self.hist_rnn = DynamicRNN(self.hist_rnn) 43 | self.ques_rnn = DynamicRNN(self.ques_rnn) 44 | self.cap_rnn = DynamicRNN(self.cap_rnn) 45 | 46 | ##q c att on img 47 | self.Wq2 = nn.Sequential(self.dropout, nn.Linear(self.nhid * 2, self.nhid)) 48 | self.Wi2 = nn.Sequential(self.dropout, nn.Linear(self.img_feature_size, self.nhid)) 49 | self.Wall2 = nn.Linear(self.nhid, 1) 50 | 51 | # q_att_on_cap 52 | self.Wqs3 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 53 | self.Wcs3 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 54 | self.Wc3 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 55 | self.Wall3 = nn.Linear(self.nhid, 1) 56 | self.c2c = nn.Sequential(self.dropout, nn.Linear(self.ninp, self.nhid)) 57 | 58 | # c_att_on_ques 59 | self.Wqs5 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 60 | self.Wcs5 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 61 | self.Wq5 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 62 | self.Wall5 = nn.Linear(self.nhid, 1) 63 | self.q2q = nn.Sequential(self.dropout, nn.Linear(self.ninp, self.nhid)) 64 | # q att on h 65 | self.Wq1 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 66 | self.Wh1 = nn.Sequential(self.dropout, nn.Linear(self.nhid, self.nhid)) 67 | self.Wqh1 = nn.Linear(self.nhid, 1) 68 | ###cap att img 69 | self.Wc4 = nn.Sequential(self.dropout, nn.Linear(self.nhid * 2, self.nhid)) 70 | self.Wi4 = nn.Sequential(self.dropout, nn.Linear(self.img_feature_size, self.nhid)) 71 | self.Wall4 = nn.Linear(self.nhid, 1) 72 | ##fusion 73 | self.i2i = nn.Sequential(self.dropout, nn.Linear(self.img_feature_size, self.nhid)) 74 | self.fusion_1 = nn.Sequential( 75 | nn.Dropout(p=config["dropout_fc"]), 76 | nn.Linear(self.nhid * 2 + self.img_feature_size + self.nhid, self.nhid), 77 | nn.LeakyReLU() 78 | ) 79 | self.fusion_2 = nn.Sequential( 80 | nn.Dropout(p=config["dropout_fc"]), 81 | nn.Linear(self.nhid * 2 + self.img_feature_size + self.nhid, self.nhid), 82 | nn.LeakyReLU() 83 | ) 84 | self.fusion_3 = nn.Sequential( 85 | nn.Dropout(p=config["dropout_fc"]), 86 | nn.Linear(self.nhid * 2 + self.img_feature_size + self.nhid, self.nhid), 87 | nn.LeakyReLU() 88 | ) 89 | self.q_ref = nn.Sequential( 90 | nn.Dropout(p=config["dropout_fc"]), 91 | nn.Linear(self.nhid * 2, self.nhid), 92 | nn.LeakyReLU(), 93 | nn.Dropout(p=config["dropout_fc"]), 94 | nn.Linear(self.nhid, 2), 95 | nn.LeakyReLU() 96 | ) 97 | self.q_multi = nn.Sequential( 98 | nn.Dropout(p=config["dropout_fc"]), 99 | nn.Linear(self.nhid * 2, self.nhid), 100 | nn.LeakyReLU(), 101 | nn.Dropout(p=config["dropout_fc"]), 102 | nn.Linear(self.nhid, 3), 103 | nn.LeakyReLU() 104 | ) 105 | for m in self.modules(): 106 | if isinstance(m, nn.Linear): 107 | nn.init.kaiming_uniform_(m.weight.data) 108 | if m.bias is not None: 109 | nn.init.constant_(m.bias.data, 0) 110 | 111 | def q_att_on_cap(self, ques_prin, cap_prin, cap_feat, cap_len, cap_emb): 112 | batch_size = cap_feat.size(0) 113 | capfeat_len = cap_feat.size(1) 114 | q_emb = self.Wqs3(ques_prin).view(batch_size, -1, self.nhid) 115 | c_emb = self.Wcs3(cap_prin).view(batch_size, -1, self.nhid) 116 | cap_feat_new = self.Wc3(cap_feat) 117 | cap_score = self.Wall3( 118 | self.dropout( 119 | torch.tanh(cap_feat_new + 
q_emb.repeat(1, capfeat_len, 1) + c_emb.repeat(1, capfeat_len, 1)) 120 | ) 121 | ).view(batch_size, -1) 122 | mask = cap_score.detach().eq(0) 123 | for i in range(batch_size): 124 | mask[i, cap_len[i]:] = 1 125 | cap_score.masked_fill_(mask, -1e5) 126 | weight = torch.softmax(cap_score, dim=-1) 127 | final_cap_feat = torch.bmm( 128 | weight.view(batch_size, 1, -1), 129 | cap_emb).view(batch_size, -1) 130 | final_cap_feat = self.c2c(final_cap_feat) 131 | return final_cap_feat 132 | 133 | def c_att_on_ques(self, cap_prin, ques_prin, ques_feat, ques_len, ques_emb): 134 | batch_size = ques_feat.size(0) 135 | quesfeat_len = ques_feat.size(1) 136 | q_emb = self.Wqs5(ques_prin).view(batch_size, -1, self.nhid) 137 | c_emb = self.Wcs5(cap_prin).view(batch_size, -1, self.nhid) 138 | ques_feat_new = self.Wq5(ques_feat) 139 | ques_score = self.Wall5( 140 | self.dropout( 141 | torch.tanh(ques_feat_new + q_emb.repeat(1, quesfeat_len, 1) + c_emb.repeat(1, quesfeat_len, 1)) 142 | ) 143 | ).view(batch_size, -1) 144 | mask = ques_score.detach().eq(0) 145 | for i in range(batch_size): 146 | mask[i, ques_len[i]:] = 1 147 | ques_score.masked_fill_(mask, -1e5) 148 | weight = torch.softmax(ques_score, dim=-1) 149 | final_ques_feat = torch.bmm( 150 | weight.view(batch_size, 1, -1), 151 | ques_emb).view(batch_size, -1) 152 | final_ques_feat = self.q2q(final_ques_feat) 153 | return final_ques_feat 154 | 155 | def q_att_on_img(self, ques_feat, img_feat): 156 | batch_size = ques_feat.size(0) 157 | region_size = img_feat.size(1) 158 | device = ques_feat.device 159 | q_emb = self.Wq2(ques_feat).view(batch_size, -1, self.nhid) 160 | i_emb = self.Wi2(img_feat).view(batch_size, -1, self.nhid) 161 | all_score = self.Wall2( 162 | self.dropout( 163 | torch.tanh(i_emb * q_emb.repeat(1, region_size, 1)) 164 | ) 165 | ).view(batch_size, -1) 166 | img_final_feat = torch.bmm( 167 | torch.softmax(all_score, dim=-1) 168 | .view(batch_size, 1, -1), img_feat) 169 | return img_final_feat.view(batch_size, -1) 170 | 171 | def c_att_on_img(self, cap_feat, img_feat): 172 | batch_size = cap_feat.size(0) 173 | region_size = img_feat.size(1) 174 | device = cap_feat.device 175 | c_emb = self.Wc4(cap_feat).view(batch_size, -1, self.nhid) 176 | i_emb = self.Wi4(img_feat).view(batch_size, -1, self.nhid) 177 | all_score = self.Wall4( 178 | self.dropout( 179 | torch.tanh(i_emb * c_emb.repeat(1, region_size, 1)) 180 | ) 181 | ).view(batch_size, -1) 182 | img_final_feat = torch.bmm( 183 | torch.softmax(all_score, dim=-1) 184 | .view(batch_size, 1, -1), img_feat) 185 | return img_final_feat.view(batch_size, -1) 186 | 187 | ################################################add h 188 | def ques_att_on_his(self, ques_feat, his_feat): 189 | batch_size = ques_feat.size(0) 190 | rnd = his_feat.size(1) 191 | device = ques_feat.device 192 | q_emb = self.Wq1(ques_feat).view(batch_size, -1, self.nhid) 193 | h_emb = self.Wh1(his_feat) 194 | 195 | score = self.Wqh1( 196 | self.dropout( 197 | torch.tanh(h_emb + q_emb.repeat(1, rnd, 1)) 198 | ) 199 | ).view(batch_size, -1) 200 | weight = torch.softmax(score, dim=-1) 201 | atted_his_feat = torch.bmm(weight.view(batch_size, 1, -1), his_feat) 202 | return atted_his_feat 203 | 204 | ##################################################### 205 | 206 | def forward(self, batch): 207 | img = batch["img_feat"] # b 36 2048 208 | ques = batch["ques"] # b q_len 209 | his = batch["hist"] # b rnd q_len*2 210 | batch_size, rnd, max_his_length = his.size() 211 | cap = his[:, 0, :] 212 | ques_len = batch["ques_len"] 213 | cap_len = 
batch["hist_len"][:, 0] 214 | 215 | # embed questions 216 | ques_location = batch['ques_len'].view(-1).cpu().numpy() - 1 217 | ques_embed = self.word_embed(ques) # b 20 300 218 | q_output, _ = self.ques_rnn(ques_embed, ques_len.view(-1)) # b rnd 1024 219 | ques_encoded = q_output[range(batch_size), ques_location, :] 220 | 221 | # embed caption 222 | cap_location = cap_len.view(-1).cpu().numpy() - 1 223 | cap_emb = self.word_embed(cap.contiguous()) 224 | c_output, _ = self.cap_rnn(cap_emb, cap_len.view(-1)) 225 | cap_encoded = c_output[range(batch_size), cap_location, :] 226 | 227 | ####his emb 228 | his = his.contiguous().view(-1, max_his_length) 229 | his_embed = self.word_embed(his) # b*rnd 40 300 230 | _, (his_feat, _) = self.hist_rnn(his_embed, batch["hist_len"].contiguous().view(-1)) # b*rnd step 1024 231 | his_feat = his_feat.view(batch_size, rnd, self.nhid) 232 | q_att_his_feat = self.ques_att_on_his(ques_encoded, his_feat).view(batch_size, self.nhid) # b 512 233 | att_cap_feat_0 = self.q_att_on_cap(ques_encoded, cap_encoded, c_output, cap_len, 234 | cap_emb) # (batch_size, 2*nhid) 235 | att_ques_feat_0 = self.c_att_on_ques(cap_encoded, ques_encoded, q_output, ques_len, ques_embed) 236 | his_att_cap = self.q_att_on_cap(q_att_his_feat, cap_encoded, c_output, cap_len, cap_emb) 237 | 238 | att_ques_feat = torch.cat((ques_encoded, att_ques_feat_0), dim=-1) 239 | att_cap_feat = torch.cat((cap_encoded, att_cap_feat_0), dim=-1) 240 | # attended ques on img 241 | q_att_img_feat = self.q_att_on_img(att_ques_feat, img).view(batch_size, -1, self.img_feature_size) 242 | # attended cap on img 243 | c_att_img_feat = self.c_att_on_img(att_cap_feat, img).view(batch_size, -1, self.img_feature_size) 244 | cated_feature = torch.cat((q_att_img_feat, c_att_img_feat), dim=1) 245 | q_gs = torch.softmax((self.q_ref(torch.cat((att_ques_feat_0, his_att_cap), dim=-1))), dim=-1).view(batch_size, 246 | 1, -1) 247 | final_img_feat = torch.bmm(q_gs, cated_feature).view(batch_size, -1) 248 | 249 | img_feat_fusion = self.i2i(final_img_feat) 250 | fused_vector = torch.cat((att_ques_feat, final_img_feat, ques_encoded * img_feat_fusion), dim=-1) 251 | fused_embedding_1 = self.fusion_1(fused_vector).view(batch_size, 1, -1) 252 | fused_embedding_2 = self.fusion_2(fused_vector).view(batch_size, 1, -1) 253 | fused_embedding_3 = self.fusion_3(fused_vector).view(batch_size, 1, -1) 254 | 255 | fused_embedding = torch.cat((fused_embedding_1, fused_embedding_2, fused_embedding_3), dim=1) 256 | q_multi = torch.softmax((self.q_multi(att_ques_feat)), dim=-1).view(batch_size, 1, -1) 257 | fuse_feat = torch.bmm(q_multi, fused_embedding).view(batch_size, -1) 258 | 259 | return fuse_feat # out is b * 512 260 | -------------------------------------------------------------------------------- /visdialch/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Metric observes output of certain model, for example, in form of logits or 3 | scores, and accumulates a particular metric with reference to some provided 4 | targets. In context of VisDial, we use Recall (@ 1, 5, 10), Mean Rank, Mean 5 | Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG). 6 | 7 | Each ``Metric`` must atleast implement three methods: 8 | - ``observe``, update accumulated metric with currently observed outputs 9 | and targets. 
10 | - ``retrieve`` to return the accumulated metric, and optionally reset the
11 |   internally accumulated metric (this is commonly done between two epochs
12 |   after validation).
13 | - ``reset`` to explicitly reset the internally accumulated metric.
14 | 
15 | Caveat: if you wish to implement your own Metric class, make sure you call
16 | ``detach`` on output tensors (such as logits), otherwise they will cause memory leaks.
17 | """
18 | import torch
19 | 
20 | 
21 | def scores_to_ranks(scores: torch.Tensor):
22 |     """Convert model output scores into ranks."""
23 |     batch_size, num_rounds, num_options = scores.size()
24 |     scores = scores.view(-1, num_options)
25 | 
26 |     # sort in descending order - largest score gets highest rank
27 |     sorted_ranks, ranked_idx = scores.sort(1, descending=True)
28 |     _, real_rank = torch.sort(ranked_idx)
29 |     # the i-th position in ranked_idx specifies which score takes this
30 |     # position, but we want the i-th position to hold the rank of the score
31 |     # at that position, so do this conversion
32 |     '''
33 |     ranks = ranked_idx.clone().fill_(0)
34 |     for i in range(ranked_idx.size(0)):
35 |         for j in range(num_options):
36 |             ranks[i][ranked_idx[i][j]] = j
37 |     # convert from 0-99 ranks to 1-100 ranks
38 |     ranks += 1
39 |     ranks = ranks.view(batch_size, num_rounds, num_options)
40 |     '''
41 |     ranks = (real_rank + 1).view(batch_size, num_rounds, -1)
42 |     return ranks
43 | 
44 | 
45 | class SparseGTMetrics(object):
46 |     """
47 |     A class to accumulate all metrics with sparse ground truth annotations.
48 |     These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
49 |     """
50 | 
51 |     def __init__(self):
52 |         self._rank_list = []
53 | 
54 |     def observe(
55 |         self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor
56 |     ):
57 |         predicted_scores = predicted_scores.detach()
58 | 
59 |         # shape: (batch_size, num_rounds, num_options)
60 |         predicted_ranks = scores_to_ranks(predicted_scores)
61 |         batch_size, num_rounds, num_options = predicted_ranks.size()
62 | 
63 |         # collapse batch dimension
64 |         predicted_ranks = predicted_ranks.view(
65 |             batch_size * num_rounds, num_options
66 |         )
67 | 
68 |         # shape: (batch_size * num_rounds, )
69 |         target_ranks = target_ranks.view(batch_size * num_rounds).long()
70 | 
71 |         # shape: (batch_size * num_rounds, )
72 |         predicted_gt_ranks = predicted_ranks[
73 |             torch.arange(batch_size * num_rounds), target_ranks
74 |         ]
75 |         self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy()))
76 | 
77 |     def retrieve(self, reset: bool = True):
78 |         num_examples = len(self._rank_list)
79 |         if num_examples > 0:
80 |             # convert to a float tensor for easy calculation.
81 |             __rank_list = torch.tensor(self._rank_list).float()
82 |             metrics = {
83 |                 "r@1": torch.mean((__rank_list <= 1).float()).item(),
84 |                 "r@5": torch.mean((__rank_list <= 5).float()).item(),
85 |                 "r@10": torch.mean((__rank_list <= 10).float()).item(),
86 |                 "mean": torch.mean(__rank_list).item(),
87 |                 "mrr": torch.mean(__rank_list.reciprocal()).item(),
88 |             }
89 |         else:
90 |             metrics = {}
91 | 
92 |         if reset:
93 |             self.reset()
94 |         return metrics
95 | 
96 |     def reset(self):
97 |         self._rank_list = []
98 | 
99 | 
100 | class NDCG(object):
101 |     def __init__(self):
102 |         self._ndcg_numerator = 0.0
103 |         self._ndcg_denominator = 0.0
104 | 
105 |     def observe(
106 |         self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor
107 |     ):
108 |         """
109 |         Observe model output scores and target ground truth relevance and
110 |         accumulate NDCG metric.
111 | 112 | Parameters 113 | ---------- 114 | predicted_scores: torch.Tensor 115 | A tensor of shape (batch_size, num_options), because dense 116 | annotations are available for 1 randomly picked round out of 10. 117 | target_relevance: torch.Tensor 118 | A tensor of shape same as predicted scores, indicating ground truth 119 | relevance of each answer option for a particular round. 120 | """ 121 | batch_size, num_options = predicted_scores.size() 122 | predicted_scores = predicted_scores.detach() 123 | 124 | # shape: (batch_size, 1, num_options) 125 | predicted_scores = predicted_scores.unsqueeze(1) 126 | predicted_ranks = scores_to_ranks(predicted_scores) 127 | 128 | # shape: (batch_size, num_options) 129 | predicted_ranks = predicted_ranks.view(batch_size, num_options) 130 | 131 | 132 | k = torch.sum(target_relevance != 0, dim=-1) 133 | 134 | # shape: (batch_size, num_options) 135 | _, rankings = torch.sort(predicted_ranks, dim=-1) 136 | # Sort relevance in descending order so highest relevance gets top rnk. 137 | _, best_rankings = torch.sort( 138 | target_relevance, dim=-1, descending=True 139 | ) 140 | 141 | # shape: (batch_size, ) 142 | batch_ndcg = [] 143 | for batch_index in range(batch_size): 144 | num_relevant = k[batch_index] 145 | dcg = self._dcg( 146 | rankings[batch_index][:num_relevant], 147 | target_relevance[batch_index], 148 | ) 149 | best_dcg = self._dcg( 150 | best_rankings[batch_index][:num_relevant], 151 | target_relevance[batch_index], 152 | ) 153 | batch_ndcg.append(dcg / best_dcg) 154 | 155 | self._ndcg_denominator += batch_size 156 | self._ndcg_numerator += sum(batch_ndcg) 157 | 158 | def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor): 159 | sorted_relevance = relevance[rankings].cpu().float() 160 | discounts = torch.log2(torch.arange(len(rankings)).float() + 2) 161 | return torch.sum(sorted_relevance / discounts, dim=-1) 162 | 163 | def retrieve(self, reset: bool = True): 164 | if self._ndcg_denominator > 0: 165 | metrics = { 166 | "ndcg": float(self._ndcg_numerator / self._ndcg_denominator) 167 | } 168 | else: 169 | metrics = {} 170 | 171 | if reset: 172 | self.reset() 173 | return metrics 174 | 175 | def reset(self): 176 | self._ndcg_numerator = 0.0 177 | self._ndcg_denominator = 0.0 178 | -------------------------------------------------------------------------------- /visdialch/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class EncoderDecoderModel(nn.Module): 5 | """Convenience wrapper module, wrapping Encoder and Decoder modules. 6 | 7 | Parameters 8 | ---------- 9 | encoder: nn.Module 10 | decoder: nn.Module 11 | """ 12 | 13 | def __init__(self, encoder, decoder): 14 | super().__init__() 15 | self.encoder = encoder 16 | self.decoder = decoder 17 | 18 | def forward(self, batch): 19 | encoder_output = self.encoder(batch) 20 | decoder_output = self.decoder(encoder_output, batch) 21 | return decoder_output 22 | -------------------------------------------------------------------------------- /visdialch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dynamic_rnn import DynamicRNN 2 | 3 | -------------------------------------------------------------------------------- /visdialch/utils/checkpointing.py: -------------------------------------------------------------------------------- 1 | """ 2 | A checkpoint manager periodically saves model and optimizer as .pth 3 | files during training. 
4 | 
5 | Checkpoint managers help with experiment reproducibility: they record
6 | the commit SHA of your current codebase in the checkpoint saving
7 | directory. When loading a checkpoint from another commit, they raise a
8 | friendly warning, a signal to inspect commit diffs for potential bugs.
9 | Moreover, they copy experiment hyper-parameters as a YAML config in
10 | this directory.
11 | 
12 | That said, always run your experiments after committing your changes;
13 | this doesn't account for untracked or staged, but uncommitted, changes.
14 | """
15 | from pathlib import Path
16 | from subprocess import PIPE, Popen
17 | import warnings
18 | 
19 | import torch
20 | from torch import nn, optim
21 | import yaml
22 | 
23 | 
24 | class CheckpointManager(object):
25 |     """A checkpoint manager saves state dicts of model and optimizer
26 |     as .pth files in a specified directory. This class closely follows
27 |     the API of PyTorch optimizers and learning rate schedulers.
28 | 
29 |     Note::
30 |         For ``DataParallel`` modules, ``model.module.state_dict()`` is
31 |         saved, instead of ``model.state_dict()``.
32 | 
33 |     Parameters
34 |     ----------
35 |     model: nn.Module
36 |         Wrapped model, which needs to be checkpointed.
37 |     optimizer: optim.Optimizer
38 |         Wrapped optimizer which needs to be checkpointed.
39 |     checkpoint_dirpath: str
40 |         Path to an empty or non-existent directory to save checkpoints.
41 |     step_size: int, optional (default=1)
42 |         Period of saving checkpoints.
43 |     last_epoch: int, optional (default=-1)
44 |         The index of the last epoch.
45 | 
46 |     Example
47 |     --------
48 |     >>> model = torch.nn.Linear(10, 2)
49 |     >>> optimizer = torch.optim.Adam(model.parameters())
50 |     >>> ckpt_manager = CheckpointManager(model, optimizer, "/tmp/ckpt")
51 |     >>> for epoch in range(20):
52 |     ...     for batch in dataloader:
53 |     ...         do_iteration(batch)
54 |     ...     ckpt_manager.step()
55 |     """
56 | 
57 |     def __init__(
58 |         self,
59 |         model,
60 |         optimizer,
61 |         checkpoint_dirpath,
62 |         step_size=1,
63 |         last_epoch=-1,
64 |         **kwargs,
65 |     ):
66 | 
67 |         if not isinstance(model, nn.Module):
68 |             raise TypeError("{} is not a Module".format(type(model).__name__))
69 | 
70 |         if not isinstance(optimizer, optim.Optimizer):
71 |             raise TypeError(
72 |                 "{} is not an Optimizer".format(type(optimizer).__name__)
73 |             )
74 | 
75 |         self.model = model
76 |         self.optimizer = optimizer
77 |         self.ckpt_dirpath = Path(checkpoint_dirpath)
78 |         self.step_size = step_size
79 |         self.last_epoch = last_epoch
80 |         self.init_directory(**kwargs)
81 | 
82 |     def init_directory(self, config={}):
83 |         """Initialize an empty checkpoint directory and record the commit SHA
84 |         in it. Also save the hyper-parameter config in this directory to
85 |         associate checkpoints with their hyper-parameters.
86 |         """
87 |         self.ckpt_dirpath.mkdir(parents=True, exist_ok=True)
88 |         # save current git commit hash in this checkpoint directory
89 |         commit_sha_subprocess = Popen(
90 |             ["git", "rev-parse", "--short", "HEAD"], stdout=PIPE, stderr=PIPE
91 |         )
92 |         commit_sha, _ = commit_sha_subprocess.communicate()
93 |         commit_sha = commit_sha.decode("utf-8").strip().replace("\n", "")
94 |         commit_sha_filepath = self.ckpt_dirpath / f".commit-{commit_sha}"
95 |         commit_sha_filepath.touch()
96 |         yaml.dump(
97 |             config,
98 |             open(str(self.ckpt_dirpath / "config.yml"), "w"),
99 |             default_flow_style=False,
100 |         )
101 | 
102 |     def step(self, epoch=None):
103 |         """Save a checkpoint if the step-size condition is met.
""" 104 | 105 | if not epoch: 106 | epoch = self.last_epoch + 1 107 | self.last_epoch = epoch 108 | 109 | if not self.last_epoch % self.step_size: 110 | torch.save( 111 | { 112 | "model": self._model_state_dict(), 113 | "optimizer": self.optimizer.state_dict(), 114 | }, 115 | self.ckpt_dirpath / f"checkpoint_{self.last_epoch}.pth", 116 | ) 117 | 118 | def _model_state_dict(self): 119 | """Returns state dict of model, taking care of DataParallel case.""" 120 | if isinstance(self.model, nn.DataParallel): 121 | return self.model.module.state_dict() 122 | else: 123 | return self.model.state_dict() 124 | 125 | 126 | def load_checkpoint(checkpoint_pthpath): 127 | """Given a path to saved checkpoint, load corresponding state dicts 128 | of model and optimizer from it. This method checks if the current 129 | commit SHA of codebase matches the commit SHA recorded when this 130 | checkpoint was saved by checkpoint manager. 131 | 132 | Parameters 133 | ---------- 134 | checkpoint_pthpath: str or pathlib.Path 135 | Path to saved checkpoint (as created by ``CheckpointManager``). 136 | 137 | Returns 138 | ------- 139 | nn.Module, optim.Optimizer 140 | Model and optimizer state dicts loaded from checkpoint. 141 | 142 | Raises 143 | ------ 144 | UserWarning 145 | If commit SHA do not match, or if the directory doesn't have 146 | the recorded commit SHA. 147 | """ 148 | 149 | if isinstance(checkpoint_pthpath, str): 150 | checkpoint_pthpath = Path(checkpoint_pthpath) 151 | checkpoint_dirpath = checkpoint_pthpath.resolve().parent 152 | checkpoint_commit_sha = list(checkpoint_dirpath.glob(".commit-*")) 153 | 154 | # if len(checkpoint_commit_sha) == 0: 155 | # raise UserWarning( 156 | # "Commit SHA was not recorded while saving checkpoints." 157 | # ) 158 | # else: 159 | # # verify commit sha, raise warning if it doesn't match 160 | # commit_sha_subprocess = Popen( 161 | # ["git", "rev-parse", "--short", "HEAD"], stdout=PIPE, stderr=PIPE 162 | # ) 163 | # commit_sha, _ = commit_sha_subprocess.communicate() 164 | # commit_sha = commit_sha.decode("utf-8").strip().replace("\n", "") 165 | # 166 | # # remove ".commit-" 167 | # checkpoint_commit_sha = checkpoint_commit_sha[0].name[8:] 168 | # 169 | # if commit_sha != checkpoint_commit_sha: 170 | # warnings.warn( 171 | # f"Current commit ({commit_sha}) and the commit " 172 | # f"({checkpoint_commit_sha}) at which checkpoint was saved," 173 | # " are different. This might affect reproducibility." 174 | # ) 175 | 176 | # load encoder, decoder, optimizer state_dicts 177 | components = torch.load(checkpoint_pthpath) 178 | return components["model"], components["optimizer"] 179 | -------------------------------------------------------------------------------- /visdialch/utils/dynamic_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 4 | 5 | 6 | class DynamicRNN(nn.Module): 7 | def __init__(self, rnn_model): 8 | super().__init__() 9 | self.rnn_model = rnn_model 10 | 11 | def forward(self, seq_input, seq_lens, initial_state=None): 12 | """A wrapper over pytorch's rnn to handle sequences of variable length. 13 | 14 | Arguments 15 | --------- 16 | seq_input : torch.Tensor 17 | Input sequence tensor (padded) for RNN model. 
18 |             Shape: (batch_size, max_sequence_length, embed_size)
19 |         seq_lens : torch.LongTensor
20 |             Length of sequences, shape (batch_size, )
21 |         initial_state : torch.Tensor
22 |             Initial (hidden, cell) states of the RNN model.
23 | 
24 |         Returns
25 |         -------
26 |             A tuple (outputs, (h_n, c_n)): the padded per-step outputs of shape
27 |             (batch_size, max_sequence_length, rnn_hidden_size), plus the last-layer
28 |             hidden and cell states, each of shape (batch_size, rnn_hidden_size).
29 |         """
30 |         max_sequence_length = seq_input.size(1)
31 |         sorted_len, fwd_order, bwd_order = self._get_sorted_order(seq_lens)
32 |         sorted_seq_input = seq_input.index_select(0, fwd_order)
33 |         packed_seq_input = pack_padded_sequence(
34 |             sorted_seq_input, lengths=sorted_len, batch_first=True
35 |         )
36 | 
37 |         if initial_state is not None:
38 |             assert initial_state[0].size(0) == self.rnn_model.num_layers
39 |             sorted_hx = tuple(h.index_select(1, fwd_order) for h in initial_state)  # reorder to the sorted batch order
40 |         else:
41 |             sorted_hx = None
42 | 
43 |         self.rnn_model.flatten_parameters()
44 |         outputs, (h_n, c_n) = self.rnn_model(packed_seq_input, sorted_hx)
45 | 
46 |         # pick hidden and cell states of last layer
47 |         h_n = h_n[-1].index_select(dim=0, index=bwd_order)
48 |         c_n = c_n[-1].index_select(dim=0, index=bwd_order)
49 | 
50 |         outputs = pad_packed_sequence(
51 |             outputs, batch_first=True, total_length=max_sequence_length
52 |         )[0].index_select(dim=0, index=bwd_order)
53 |         # outputs = pad_packed_sequence(
54 |         #     outputs, batch_first=True)[0].index_select(dim=0, index=bwd_order)
55 | 
56 |         return outputs, (h_n, c_n)
57 | 
58 |     @staticmethod
59 |     def _get_sorted_order(lens):
60 |         sorted_len, fwd_order = torch.sort(
61 |             lens.contiguous().view(-1), 0, descending=True
62 |         )
63 |         _, bwd_order = torch.sort(fwd_order)
64 |         sorted_len = list(sorted_len)
65 |         return sorted_len, fwd_order, bwd_order
66 | 
--------------------------------------------------------------------------------
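The files above define the pieces that the training and evaluation scripts wire together. Below is a minimal, illustrative sketch (not taken from the repository's own scripts) of how the two metric accumulators and the checkpoint manager are driven. The random tensors, the choice of round 0, and the `checkpoints/tmp` directory are placeholders; running it assumes the repository root is on `PYTHONPATH` and that `git` is available for `CheckpointManager`.

```
import torch

from visdialch.metrics import SparseGTMetrics, NDCG
from visdialch.utils.checkpointing import CheckpointManager

# Accumulators for retrieval metrics (sparse GT) and NDCG (dense annotations).
sparse_metrics = SparseGTMetrics()
ndcg = NDCG()

# Random stand-ins for decoder scores and ground truth.
batch_size, num_rounds, num_options = 4, 10, 100
scores = torch.randn(batch_size, num_rounds, num_options)            # (b, rounds, options)
gt_option = torch.randint(0, num_options, (batch_size, num_rounds))  # GT option index per round
sparse_metrics.observe(scores, gt_option)

# Dense relevance is available for only one round per dialog, so NDCG takes
# (batch_size, num_options) tensors; round 0 is picked arbitrarily here.
round_scores = scores[:, 0, :]
relevance = torch.rand(batch_size, num_options)
ndcg.observe(round_scores, relevance)

print(sparse_metrics.retrieve(reset=True))   # r@1 / r@5 / r@10 / mean / mrr
print(ndcg.retrieve(reset=True))             # {"ndcg": ...}

# CheckpointManager works with any nn.Module / Optimizer pair; it also records
# the current git commit SHA, so it expects to run inside a git checkout.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters())
ckpt_manager = CheckpointManager(model, optimizer, "checkpoints/tmp")
ckpt_manager.step()   # writes checkpoints/tmp/checkpoint_0.pth on the first call
```

Note that `retrieve(reset=True)` clears the internal buffers, which is how the accumulators are reused across validation epochs.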