├── .gitignore ├── LICENSE ├── README.md ├── config ├── uc2-base.json ├── uc2_mscoco_itm.json ├── uc2_pretrain.json ├── uniter-base-cased.json └── uniter-base.json ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── data.cpython-36.pyc │ ├── itm.cpython-36.pyc │ ├── loader.cpython-36.pyc │ ├── mlm.cpython-36.pyc │ ├── mrm.cpython-36.pyc │ ├── mrm_nce.cpython-36.pyc │ ├── nlvr2.cpython-36.pyc │ ├── sampler.cpython-36.pyc │ ├── ve.cpython-36.pyc │ └── vqa.cpython-36.pyc ├── data.py ├── img_label_objects.txt ├── itm.py ├── loader.py ├── mlm.py ├── mrm.py ├── mrm_nce.py ├── nlvr2.py ├── re.py ├── sampler.py ├── test_data │ ├── input0.txt │ ├── input1.txt │ ├── input2.txt │ ├── input3.txt │ ├── input4.txt │ ├── input5.txt │ ├── input6.txt │ └── input7.txt ├── vcr.py ├── ve.py └── vqa.py ├── eval ├── __pycache__ │ └── itm.cpython-36.pyc ├── itm.py └── nlvr2.py ├── itm.py ├── launch_container.sh ├── misc ├── ans2label.json ├── ans2label_en_trans2_ja.json ├── ans2label_ja.json ├── ans2label_ja_trans2_en.json └── ans2label_vg.json ├── model ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── attention.cpython-36.pyc │ ├── const_variable.cpython-36.pyc │ ├── itm.cpython-36.pyc │ ├── layer.cpython-36.pyc │ ├── model.cpython-36.pyc │ ├── nlvr2.cpython-36.pyc │ ├── ot.cpython-36.pyc │ ├── ve.cpython-36.pyc │ └── vqa.cpython-36.pyc ├── attention.py ├── const_variable.py ├── gqa.py ├── img_label_objects.txt ├── itm.py ├── layer.py ├── model.py ├── nlvr2.py ├── ot.py ├── re.py ├── vcr.py ├── ve.py └── vqa.py ├── object_labels ├── img_label_objects.txt ├── img_label_objects_cs.txt ├── img_label_objects_de.txt ├── img_label_objects_fr.txt ├── img_label_objects_ja.txt └── img_label_objects_zh.txt ├── optim ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── adamw.cpython-36.pyc │ ├── misc.cpython-36.pyc │ └── sched.cpython-36.pyc ├── adamw.py ├── misc.py └── sched.py ├── pretrain.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── const.cpython-36.pyc ├── distributed.cpython-36.pyc ├── logger.cpython-36.pyc ├── misc.cpython-36.pyc ├── save.cpython-36.pyc ├── visual_entailment.cpython-36.pyc └── vqa.cpython-36.pyc ├── const.py ├── distributed.py ├── itm.py ├── logger.py ├── m3p_tokenizer.py ├── misc.py ├── ms_internal_mt.py ├── ms_internal_mt_label.py ├── ms_internal_mt_popen.py ├── save.py ├── visual_entailment.py └── vqa.py /.gitignore: -------------------------------------------------------------------------------- 1 | config/*_test.json 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Mingyang Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## UC2 2 | [UC2: Universal Cross-lingual Cross-modal Vision-and-Language Pre-training](https://arxiv.org/abs/2104.00332) 3 |
4 | [Mingyang Zhou](https://github.com/zmykevin), [Luowei Zhou](https://luoweizhou.github.io/), [Shuohang Wang](https://www.microsoft.com/en-us/research/people/shuowa/), [Yu Cheng](https://sites.google.com/site/chengyu05), [Linjie Li](https://www.microsoft.com/en-us/research/people/linjli/), [Zhou Yu](http://www.cs.columbia.edu/~zhouyu/), [Jingjing Liu](https://www.linkedin.com/in/jingjing-liu-65703431/) 5 |
6 | This is the official repository of UC2, a multilingual multi-modal pre-training framework. In this repository we support end-to-end pretraining and finetuning for image-text retrieval on COCO. 7 | 8 | ### Requirements 9 | We provide a Docker image to run our code. Please install the following: 10 | - [nvidia driver (418+)](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#package-manager-installation), 11 | - [Docker (19.03+)](https://docs.docker.com/engine/install/ubuntu/), 12 | - [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-docker#quickstart). 13 | 14 | To run the docker command without sudo, the user needs to have [docker group membership](https://docs.docker.com/engine/install/linux-postinstall/). Our code only supports Linux with NVIDIA GPUs. We have tested our code on Ubuntu 18.04 with V100 cards. 15 | 16 | ## Data and Pretrained Checkpoints 17 | Download the pre-processed text features and pretrained checkpoints with the following command: 18 | ``` 19 | wget https://mmaisharables.blob.core.windows.net/uc2/UC2_DATA.tar.gz 20 | 21 | ``` 22 | The image features for MSCOCO can be obtained from [UNITER](https://github.com/ChenRocks/UNITER) via this [code script](https://github.com/ChenRocks/UNITER/blob/master/scripts/download_itm.sh). As CC's image features are large and inconvenient to download directly, please contact UNITER's authors to obtain the image features if you are interested in pretraining. 23 | 24 | The text translations of the CC captions used for pre-training can also be found here: 25 | https://drive.google.com/drive/folders/1K6dGXMz8egutKk-UTXbFrp3HQ9QqPfmM?usp=sharing 26 | 27 | ## Launch the Docker Container for Experiments 28 | Once the data and checkpoints are set up properly, please run the following command to launch a docker container and start the pretraining process. 29 | ``` 30 | source launch_container_pretrain.sh /PATH_TO_STORAGE/txt_db /PATH_TO_STORAGE/img_db /PATH_TO_STORAGE/finetune /PATH_TO_STORAGE/pretrain 31 | ``` 32 | ## Pretraining 33 | (Inside the Docker Container) If the user wants to run pretraining, please use the following command: 34 | ``` 35 | horovodrun -np $N_GPU python pretrain.py --config config/uc2_pretrain.json 36 | ``` 37 | ## Downstream Task Finetuning 38 | **Text-to-Image Retrieval** 39 | To run the finetuning experiment for the text-to-image retrieval task, please use the following command: 40 | ``` 41 | horovodrun -np $N_GPU python itm.py --config config/uc2_mscoco_itm.json 42 | ``` 43 |
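Each training run is configured entirely through one of the JSON files under `config/`. Before launching `horovodrun`, it can be useful to sanity-check the config you are about to use; the snippet below is a minimal sketch that only uses the Python standard library (the training scripts parse these files themselves, so this is purely for inspection):
```
import json

# Load the MSCOCO retrieval finetuning config shipped with the repo.
# Paths inside the config refer to mount points inside the Docker container.
with open("config/uc2_mscoco_itm.json") as f:
    cfg = json.load(f)

# A few fields worth double-checking before a run.
for key in ("checkpoint", "output_dir", "model_config",
            "train_batch_size", "learning_rate", "num_train_steps"):
    print(f"{key}: {cfg[key]}")

# The text and image db lists are parallel: one image db entry per text db.
for txt_db, img_db in zip(cfg["train_txt_dbs"], cfg["train_img_dbs"]):
    print(txt_db, "->", img_db)
```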
44 | ## Citation 45 | If you find this code useful for your research, please consider citing: 46 | ``` 47 | @InProceedings{zhou2021uc, 48 | author = {Zhou, Mingyang and Zhou, Luowei and Wang, Shuohang and Cheng, Yu and Li, Linjie and Yu, Zhou and Liu, Jingjing}, 49 | title = {UC2: Universal Cross-lingual Cross-modal Vision-and-Language Pre-training}, 50 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR 2021)}, 51 | year = {2021}, 52 | month = {June}, 53 | abstract = {Vision-and-language pre-training has achieved impressive success in learning multimodal representations between vision and language. To generalize this success to non-English languages, we introduce UC2, the first machine translation-augmented framework for cross-lingual cross-modal representation learning. To tackle the scarcity problem of multilingual captions for image datasets, we first augment existing English-only datasets with other languages via machine translation (MT). Then we extend the standard Masked Language Modeling and Image-Text Matching training objectives to multilingual setting, where alignment between different languages is captured through shared visual context (i.e., using image as pivot). To facilitate the learning of a joint embedding space of images and all languages of interest, we further propose two novel pre-training tasks, namely Masked Region-to-Token Modeling (MRTM) and Visual Translation Language Modeling (VTLM), leveraging MT-enhanced translated data. Evaluation on multilingual image-text retrieval and multilingual visual question answering benchmarks demonstrates that our proposed framework achieves new state of the art on diverse non-English benchmarks while maintaining comparable performance to monolingual pre-trained models on English tasks.}, 54 | url = {https://www.microsoft.com/en-us/research/publication/uc2-universal-cross-lingual-cross-modal-vision-and-language-pre-training/}, 55 | } 56 | 57 | ``` 58 | 59 | ## Acknowledgement 60 | Our code is mainly based on [Linjie Li](https://www.microsoft.com/en-us/research/people/linjli/) and [Yen-Chun Chen](https://www.microsoft.com/en-us/research/people/yenche/)'s project [UNITER](https://github.com/ChenRocks/UNITER). We thank the authors for open-sourcing their code and for helpful discussions on the implementation. Portions of the code also use resources from [transformers](https://github.com/huggingface/transformers). 61 | 62 | ## License 63 | MIT 64 | -------------------------------------------------------------------------------- /config/uc2-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 514, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "model_type": "xlm-roberta", 12 | "output_past": true, 13 | "type_vocab_size": 2, 14 | "layer_norm_eps": 1e-5, 15 | "pad_token_id": 1, 16 | "vocab_size": 250002 17 | } 18 | -------------------------------------------------------------------------------- /config/uc2_mscoco_itm.json: -------------------------------------------------------------------------------- 1 | { 2 | "compressed_db": false, 3 | "checkpoint": "/pretrain/uc2/ckpt/best.pt", 4 | "output_dir": "/storage/uc2_mscoco_all_lan", 5 | "max_txt_len": 60, 6 | "conf_th": 0.2, 7 | "max_bb": 100, 8 | "min_bb": 10, 9 | "num_bb": 36, 10 | "train_batch_size": 40, 11 | "negative_size": 1, 12 | "hard_neg_size": 0, 13 | "hard_neg_pool_size": 20, 14 | "steps_per_hard_neg": -1, 15 | "inf_minibatch_size": 300, 16 | "margin": 0.2, 17 | "gradient_accumulation_steps": 8, 18 | "learning_rate": 0.0001, 19 | "xlmr_lr": 1e-07, 20 | "valid_steps": 10000, 21 | "num_train_steps": 50000, 22 | "optim": "adamw", 23 | "betas": [ 24 | 0.9, 25 | 0.98 26 | ], 27 | "decay": "linear", 28 | "dropout": 0.1, 29 | "weight_decay": 0, 30 | "grad_norm": 2.0, 31 | "warmup_steps": 5000, 32 | "seed": 42, 33 | "full_val": false, 34 | "fp16": true, 35 | "n_workers": 4, 36 | "pin_mem": false, 37 | "rename_checkpoints": false, 38 | "separate_lr": false, 39 | "load_embedding_only": false, 40 | "load_layer": 0, 41 | "save_steps": 200, 42 | "train_txt_dbs": [ 43 | "/db/uc2_mscoco/itm_coco_en_train.db/", 44 | "/db/uc2_mscoco/itm_coco_en_restval.db/", 45 | "/db/uc2_mscoco/itm_coco_ja_train.db/", 46 | "/db/uc2_mscoco/itm_coco_zh_train.db/" 47 | 48 |
], 49 | "train_img_dbs": [ 50 | "/img/coco_train2014/", 51 | "/img/coco_val2014/", 52 | "/img/coco_train2014/", 53 | [ 54 | "/img/coco_train2014/", 55 | "/img/coco_val2014/" 56 | ] 57 | 58 | ], 59 | "val_txt_db": [ 60 | "/db/uc2_mscoco/itm_coco_en_val.db/" 61 | ], 62 | "val_img_db": [ 63 | "/img/coco_val2014/" 64 | ], 65 | "test_txt_db": [ 66 | "/db/uc2_mscoco/itm_coco_en_test_1k_0.db/", 67 | "/db/uc2_mscoco/itm_coco_en_test_1k_1.db/", 68 | "/db/uc2_mscoco/itm_coco_en_test_1k_2.db/", 69 | "/db/uc2_mscoco/itm_coco_en_test_1k_3.db/", 70 | "/db/uc2_mscoco/itm_coco_en_test_1k_4.db/", 71 | "/db/uc2_mscoco/itm_coco_ja_test_1k_0.db/", 72 | "/db/uc2_mscoco/itm_coco_ja_test_1k_1.db/", 73 | "/db/uc2_mscoco/itm_coco_ja_test_1k_2.db/", 74 | "/db/uc2_mscoco/itm_coco_ja_test_1k_3.db/", 75 | "/db/uc2_mscoco/itm_coco_ja_test_1k_4.db/", 76 | "/db/uc2_mscoco/itm_coco_zh_test.db/" 77 | ], 78 | "test_img_db": [ 79 | "/img/coco_val2014/", 80 | "/img/coco_val2014/", 81 | "/img/coco_val2014/", 82 | "/img/coco_val2014/", 83 | "/img/coco_val2014/", 84 | "/img/coco_val2014/", 85 | "/img/coco_val2014/", 86 | "/img/coco_val2014/", 87 | "/img/coco_val2014/", 88 | "/img/coco_val2014/", 89 | [ 90 | "/img/coco_train2014/", 91 | "/img/coco_val2014/" 92 | ] 93 | 94 | ], 95 | "model_config": "config/uc2-base.json", 96 | "mcan": false, 97 | "cut_bert": -1, 98 | "neg_sample_from": "i", 99 | "train_neg_sample_p": 0.67, 100 | "train_neg_sample_size": 1, 101 | "val_minibatch_size": 400, 102 | "test_minibatch_size": 400, 103 | "train_loss": "rank", 104 | "img_rank_weight": 1.0, 105 | "txt_rank_weight": 1.0, 106 | "no_cuda": false, 107 | "local_rank": 0, 108 | "wave_length": 1000, 109 | "val_batch_size": 4096, 110 | "eval_method": "rank", 111 | "rank": 0 112 | } -------------------------------------------------------------------------------- /config/uc2_pretrain.json: -------------------------------------------------------------------------------- 1 | { 2 | "compressed_db": false, 3 | "model_config": "config/uc2-base.json", 4 | "checkpoint": "/pretrain/xlm-roberta-base/pytorch_model.bin", 5 | "output_dir": "/storage/uc2_pretrain", 6 | "mrm_prob": 0.15, 7 | "neg_size": 128, 8 | "nce_temp": 1.0, 9 | "itm_neg_prob": 0.5, 10 | "itm_ot_lambda": 0.0, 11 | "ot_pos_only": false, 12 | "max_txt_len": 60, 13 | "conf_th": 0.2, 14 | "max_bb": 100, 15 | "min_bb": 10, 16 | "num_bb": 36, 17 | "train_batch_size": 10240, 18 | "val_batch_size": 10240, 19 | "gradient_accumulation_steps": 3, 20 | "learning_rate": 4e-05, 21 | "valid_steps": 5000, 22 | "num_train_steps": 200000, 23 | "optim": "adamw", 24 | "betas": [ 25 | 0.9, 26 | 0.98 27 | ], 28 | "decay": "linear", 29 | "dropout": 0.1, 30 | "weight_decay": 0.01, 31 | "grad_norm": 5.0, 32 | "warmup_steps": 10000, 33 | "seed": 42, 34 | "fp16": true, 35 | "n_workers": 4, 36 | "pin_mem": true, 37 | "rename_checkpoints": false, 38 | "itm_hard_neg": false, 39 | "co_masking": true, 40 | "co_masking_mode": "mix", 41 | "save_steps": 200, 42 | "early_adaptation": true, 43 | "early_adaptation_checkpoint": "/pretrain/early_adaptation_multilingual/ckpt/best.pt", 44 | "multilingual_vmlm": true, 45 | "train_datasets": [ 46 | { 47 | "name": "cc", 48 | "db": [ 49 | "/db/uc2_pretrain/conceptual_caption_en_train.db", 50 | "/db/uc2_pretrain/conceptual_caption_de_train.db", 51 | "/db/uc2_pretrain/conceptual_caption_zh_trian.db", 52 | "/db/uc2_pretrain/conceptual_caption_ja_train.db", 53 | "/db/uc2_pretrain/conceptual_caption_fr_train.db", 54 | "/db/uc2_pretrain/conceptual_caption_cs_train.db" 55 | ], 56 | "img": [ 57 | 
"/img/gcc_train", 58 | "/img/gcc_train", 59 | "/img/gcc_train", 60 | "/img/gcc_train", 61 | "/img/gcc_train", 62 | "/img/gcc_train" 63 | ], 64 | "tasks": [ 65 | "itm", 66 | "mlm", 67 | "vmlm" 68 | ], 69 | "img_token_soft_label": [ 70 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_train/train_img_token_soft_label_batch.db" 71 | ], 72 | "mix_ratio": [ 73 | 9, 74 | 12, 75 | 9 76 | ] 77 | }, 78 | { 79 | "name": "cc_paired", 80 | "db": [ 81 | "/db/uc2_pretrain/conceptual_caption_ende_train.db/", 82 | "/db/uc2_pretrain/conceptual_caption_enzh_train.db/", 83 | "/db/uc2_pretrain/conceptual_caption_enja_train.db/", 84 | "/db/uc2_pretrain/conceptual_caption_encs_train.db/", 85 | "/db/uc2_pretrain/conceptual_caption_enfr_train.db/" 86 | ], 87 | "img": [ 88 | "/img/gcc_train", 89 | "/img/gcc_train", 90 | "/img/gcc_train", 91 | "/img/gcc_train", 92 | "/img/gcc_train" 93 | ], 94 | "tasks": [ 95 | "tlm" 96 | ], 97 | "img_token_soft_label": [ 98 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_train/train_img_token_soft_label_batch.db" 99 | ], 100 | "mix_ratio": [ 101 | 3 102 | ] 103 | } 104 | ], 105 | "val_datasets": [ 106 | { 107 | "name": "cc_en", 108 | "db": [ 109 | "/db/uc2_pretrain/conceptual_caption_en_val.db/" 110 | ], 111 | "img": [ 112 | "/img/gcc_val" 113 | ], 114 | "img_token_soft_label": [ 115 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_val/val_img_token_soft_label_batch.db" 116 | ], 117 | "tasks": [ 118 | "itm", 119 | "mlm", 120 | "vmlm" 121 | ] 122 | }, 123 | { 124 | "name": "cc_de", 125 | "db": [ 126 | "/db/uc2_pretrain/conceptual_caption_de_val.db/" 127 | ], 128 | "img": [ 129 | "/img/gcc_val" 130 | ], 131 | "img_token_soft_label": [ 132 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_val/val_img_token_soft_label_batch.db" 133 | ], 134 | "tasks": [ 135 | "itm", 136 | "mlm", 137 | "vmlm" 138 | ] 139 | }, 140 | { 141 | "name": "cc_zh", 142 | "db": [ 143 | "/db/uc2_pretrain/conceptual_caption_zh_val.db/" 144 | ], 145 | "img": [ 146 | "/img/gcc_val" 147 | ], 148 | "img_token_soft_label": [ 149 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_val/val_img_token_soft_label_batch.db" 150 | ], 151 | "tasks": [ 152 | "itm", 153 | "mlm", 154 | "vmlm" 155 | ] 156 | }, 157 | { 158 | "name": "cc_ja", 159 | "db": [ 160 | "/db/uc2_pretrain/conceptual_caption_ja_val.db/" 161 | ], 162 | "img": [ 163 | "/img/gcc_val" 164 | ], 165 | "img_token_soft_label": [ 166 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_val/val_img_token_soft_label_batch.db" 167 | ], 168 | "tasks": [ 169 | "itm", 170 | "mlm", 171 | "vmlm" 172 | ] 173 | }, 174 | { 175 | "name": "cc_fr", 176 | "db": [ 177 | "/db/uc2_pretrain/conceptual_caption_fr_val.db/" 178 | ], 179 | "img": [ 180 | "/img/gcc_val" 181 | ], 182 | "img_token_soft_label": [ 183 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_val/val_img_token_soft_label_batch.db" 184 | ], 185 | "tasks": [ 186 | "itm", 187 | "mlm", 188 | "vmlm" 189 | ] 190 | }, 191 | { 192 | "name": "cc_cs", 193 | "db": [ 194 | "/db/uc2_pretrain/conceptual_caption_cs_val.db/" 195 | ], 196 | "img": [ 197 | "/img/gcc_val" 198 | ], 199 | "img_token_soft_label": [ 200 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_val/val_img_token_soft_label_batch.db" 201 | ], 202 | "tasks": [ 203 | "itm", 204 | "mlm", 205 | "vmlm" 206 | ] 207 | }, 208 | ], 209 | "ans2label": "", 210 | "rank": 0 211 | } -------------------------------------------------------------------------------- /config/uniter-base-cased.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 28996 13 | } 14 | -------------------------------------------------------------------------------- /config/uniter-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import (TxtTokLmdb, DetectFeatLmdb, Img_SoftLabel_Lmdb, 2 | ConcatDatasetWithLens, ImageLmdbGroup, DetectFeatTxtTokDataset, get_ids_and_lens, pad_tensors, get_gather_index) 3 | from .mlm import (MlmDataset, MlmDataset_VLXLMR, VmlmDataset, Vmlm_Softlabel_Dataset, MmxlmDataset, Mmxlm_Softlabel_Dataset, MlmDataset_Dmasking, MlmEvalDataset, 4 | BlindMlmDataset, BlindMlmEvalDataset, 5 | mlm_collate, xlmr_mlm_collate, xlmr_mlm_dmasking_collate, xlmr_tlm_ni_dmasking_collate, xlmr_mmxlm_collate, xlmr_mmxlm_softlabel_collate, mlm_eval_collate, 6 | mlm_blind_collate, mlm_blind_eval_collate) 7 | from .mrm import (MrfrDataset, OnlyImgMrfrDataset, 8 | MrcDataset, OnlyImgMrcDataset, 9 | mrfr_collate, xlmr_mrfr_collate, mrfr_only_img_collate, 10 | mrc_collate, xlmr_mrc_collate, mrc_only_img_collate, _mask_img_feat, _get_targets) 11 | from .itm import (TokenBucketSamplerForItm, 12 | ItmDataset, ItmDataset_HardNeg, itm_collate, itm_ot_collate, xlmr_itm_collate, xlmr_itm_ot_collate, 13 | ItmRankDataset, ItmRankDataset_COCO_CN, ItmRankDatasetHardNeg, itm_rank_collate, 14 | ItmRankDatasetHardNegFromText, 15 | ItmRankDatasetHardNegFromImage, itm_rank_hnv2_collate, 16 | ItmHardNegDataset, itm_hn_collate, 17 | ItmValDataset, ItmValDataset_COCO_CN, itm_val_collate, 18 | ItmEvalDataset, ItmEvalDataset_COCO_CN, itm_eval_collate, xlmr_itm_rank_collate) 19 | from .sampler import TokenBucketSampler, DistributedSampler 20 | from .loader import MetaLoader, PrefetchLoader 21 | 22 | from .vqa import VqaDataset, vqa_collate, VqaEvalDataset, vqa_eval_collate, xlmr_vqa_collate, xlmr_vqa_eval_collate 23 | from .nlvr2 import (Nlvr2PairedDataset, nlvr2_paired_collate, 24 | Nlvr2PairedEvalDataset, nlvr2_paired_eval_collate, 25 | Nlvr2TripletDataset, nlvr2_triplet_collate, 26 | Nlvr2TripletEvalDataset, nlvr2_triplet_eval_collate) 27 | from .ve import VeDataset, ve_collate, VeEvalDataset, ve_eval_collate 28 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/itm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/itm.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/loader.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/mlm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/mlm.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/mrm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/mrm.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/mrm_nce.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/mrm_nce.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/nlvr2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/nlvr2.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/sampler.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/ve.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/ve.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/vqa.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/vqa.cpython-36.pyc -------------------------------------------------------------------------------- /data/loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | A meta data loader for sampling from different datasets / training tasks 3 | A prefetch loader to speedup data loading 4 | """ 5 | import random 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | from utils.distributed import any_broadcast 11 | 12 | 13 | class 
MetaLoader(object): 14 | """ wraps multiple data loader """ 15 | def __init__(self, loaders, accum_steps=1, distributed=False): 16 | assert isinstance(loaders, dict) 17 | self.name2loader = {} 18 | self.name2iter = {} 19 | self.sampling_pools = [] 20 | for n, l in loaders.items(): 21 | if isinstance(l, tuple): 22 | l, r = l 23 | elif isinstance(l, DataLoader): 24 | r = 1 25 | else: 26 | raise ValueError() 27 | # print(n) 28 | # print(l) 29 | self.name2loader[n] = l 30 | self.name2iter[n] = iter(l) 31 | self.sampling_pools.extend([n]*r) 32 | 33 | self.accum_steps = accum_steps 34 | self.distributed = distributed 35 | self.step = 0 36 | 37 | def __iter__(self): 38 | """ this iterator will run indefinitely """ 39 | task = self.sampling_pools[0] 40 | while True: 41 | if self.step % self.accum_steps == 0: 42 | task = random.choice(self.sampling_pools) 43 | if self.distributed: 44 | # make sure all process is training same task 45 | task = any_broadcast(task, 0) 46 | self.step += 1 47 | iter_ = self.name2iter[task] 48 | try: 49 | batch = next(iter_) 50 | except StopIteration: 51 | iter_ = iter(self.name2loader[task]) 52 | batch = next(iter_) 53 | self.name2iter[task] = iter_ 54 | 55 | yield task, batch 56 | 57 | 58 | def move_to_cuda(batch): 59 | if isinstance(batch, torch.Tensor): 60 | return batch.cuda(non_blocking=True) 61 | elif isinstance(batch, list): 62 | new_batch = [move_to_cuda(t) for t in batch] 63 | elif isinstance(batch, tuple): 64 | new_batch = tuple(move_to_cuda(t) for t in batch) 65 | elif isinstance(batch, dict): 66 | new_batch = {n: move_to_cuda(t) for n, t in batch.items()} 67 | else: 68 | return batch 69 | return new_batch 70 | 71 | 72 | def record_cuda_stream(batch): 73 | if isinstance(batch, torch.Tensor): 74 | batch.record_stream(torch.cuda.current_stream()) 75 | elif isinstance(batch, list) or isinstance(batch, tuple): 76 | for t in batch: 77 | record_cuda_stream(t) 78 | elif isinstance(batch, dict): 79 | for t in batch.values(): 80 | record_cuda_stream(t) 81 | else: 82 | pass 83 | 84 | 85 | class PrefetchLoader(object): 86 | """ 87 | overlap compute and cuda data transfer 88 | (copied and then modified from nvidia apex) 89 | """ 90 | def __init__(self, loader): 91 | self.loader = loader 92 | self.stream = torch.cuda.Stream() 93 | 94 | def __iter__(self): 95 | loader_it = iter(self.loader) 96 | self.preload(loader_it) 97 | batch = self.next(loader_it) 98 | while batch is not None: 99 | yield batch 100 | batch = self.next(loader_it) 101 | 102 | def __len__(self): 103 | return len(self.loader) 104 | 105 | def preload(self, it): 106 | try: 107 | self.batch = next(it) 108 | except StopIteration: 109 | self.batch = None 110 | return 111 | # if record_stream() doesn't work, another option is to make sure 112 | # device inputs are created on the main stream. 113 | # self.next_input_gpu = torch.empty_like(self.next_input, 114 | # device='cuda') 115 | # self.next_target_gpu = torch.empty_like(self.next_target, 116 | # device='cuda') 117 | # Need to make sure the memory allocated for next_* is not still in use 118 | # by the main stream at the time we start copying to next_*: 119 | # self.stream.wait_stream(torch.cuda.current_stream()) 120 | with torch.cuda.stream(self.stream): 121 | self.batch = move_to_cuda(self.batch) 122 | # more code for the alternative if record_stream() doesn't work: 123 | # copy_ will record the use of the pinned source tensor in this 124 | # side stream. 
125 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 126 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 127 | # self.next_input = self.next_input_gpu 128 | # self.next_target = self.next_target_gpu 129 | 130 | def next(self, it): 131 | torch.cuda.current_stream().wait_stream(self.stream) 132 | batch = self.batch 133 | if batch is not None: 134 | record_cuda_stream(batch) 135 | self.preload(it) 136 | return batch 137 | 138 | def __getattr__(self, name): 139 | method = self.loader.__getattribute__(name) 140 | return method 141 | -------------------------------------------------------------------------------- /data/mrm.py: -------------------------------------------------------------------------------- 1 | """ 2 | MRM Datasets 3 | """ 4 | import random 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torch.nn.utils.rnn import pad_sequence 9 | from toolz.sandbox import unzip 10 | from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index 11 | 12 | 13 | def _get_img_mask(mask_prob, num_bb): 14 | img_mask = [random.random() < mask_prob for _ in range(num_bb)] 15 | if not any(img_mask): 16 | # at least mask 1 17 | img_mask[random.choice(range(num_bb))] = True 18 | img_mask = torch.tensor(img_mask) 19 | return img_mask 20 | 21 | 22 | def _get_img_tgt_mask(img_mask, txt_len): 23 | z = torch.zeros(txt_len, dtype=torch.uint8) 24 | img_mask_tgt = torch.cat([z, img_mask], dim=0) 25 | return img_mask_tgt 26 | 27 | 28 | def _get_feat_target(img_feat, img_masks): 29 | img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) # (n, m, d) 30 | feat_dim = img_feat.size(-1) 31 | feat_targets = img_feat[img_masks_ext].contiguous().view( 32 | -1, feat_dim) # (s, d) 33 | return feat_targets 34 | 35 | 36 | def _mask_img_feat(img_feat, img_masks): 37 | img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) 38 | img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0) 39 | return img_feat_masked 40 | 41 | 42 | class MrfrDataset(DetectFeatTxtTokDataset): 43 | def __init__(self, mask_prob, *args, **kwargs): 44 | super().__init__(*args, **kwargs) 45 | self.mask_prob = mask_prob 46 | 47 | def __getitem__(self, i): 48 | """ 49 | Return: 50 | - input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded 51 | - img_feat : (num_bb, d) 52 | - img_pos_feat : (num_bb, 7) 53 | - attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1] 54 | - img_mask : (num_bb, ) between {0, 1} 55 | """ 56 | example = super().__getitem__(i) 57 | # text input 58 | input_ids = example['input_ids'] 59 | input_ids = self.txt_db.combine_inputs(input_ids) 60 | 61 | # image input features 62 | img_feat, img_pos_feat, num_bb = self._get_img_feat( 63 | example['img_fname']) 64 | img_mask = _get_img_mask(self.mask_prob, num_bb) 65 | img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids)) 66 | 67 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) 68 | 69 | return (input_ids, img_feat, img_pos_feat, 70 | attn_masks, img_mask, img_mask_tgt) 71 | 72 | #Added by Mingyang Zhou 73 | def xlmr_mrfr_collate(inputs): 74 | """ 75 | Return: 76 | - input_ids : (n, max_L), i.e., [cls, wd, wd, ..., sep, 1, 1], 1s padded 77 | - position_ids : (n, max_L) 78 | - txt_lens : list of [input_len] 79 | - img_feat : (n, max_num_bb, d) 80 | - img_pos_feat : (n, max_num_bb, 7) 81 | - num_bbs : list of [num_bb] 82 | - attn_masks : (n, max_{L + num_bb}), ie., [1, 1, ..., 0, 0, 1, 1] 83 | - img_masks : (n, max_num_bb) between {0, 1} 84 | """ 85 | (input_ids, img_feats, 
img_pos_feats, attn_masks, img_masks, img_mask_tgts, 86 | ) = map(list, unzip(inputs)) 87 | 88 | txt_lens = [i.size(0) for i in input_ids] 89 | 90 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=1) 91 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 92 | ).unsqueeze(0) 93 | 94 | num_bbs = [f.size(0) for f in img_feats] 95 | img_feat = pad_tensors(img_feats, num_bbs) 96 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 97 | 98 | # mask features 99 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) 100 | feat_targets = _get_feat_target(img_feat, img_masks) 101 | img_feat = _mask_img_feat(img_feat, img_masks) 102 | img_mask_tgt = pad_sequence(img_mask_tgts, 103 | batch_first=True, padding_value=0) 104 | 105 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 106 | bs, max_tl = input_ids.size() 107 | out_size = attn_masks.size(1) 108 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 109 | 110 | batch = {'input_ids': input_ids, 111 | 'position_ids': position_ids, 112 | 'img_feat': img_feat, 113 | 'img_pos_feat': img_pos_feat, 114 | 'attn_masks': attn_masks, 115 | 'gather_index': gather_index, 116 | 'feat_targets': feat_targets, 117 | 'img_masks': img_masks, 118 | 'img_mask_tgt': img_mask_tgt} 119 | return batch 120 | 121 | def mrfr_collate(inputs): 122 | """ 123 | Return: 124 | - input_ids : (n, max_L), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded 125 | - position_ids : (n, max_L) 126 | - txt_lens : list of [input_len] 127 | - img_feat : (n, max_num_bb, d) 128 | - img_pos_feat : (n, max_num_bb, 7) 129 | - num_bbs : list of [num_bb] 130 | - attn_masks : (n, max_{L + num_bb}), ie., [1, 1, ..., 0, 0, 1, 1] 131 | - img_masks : (n, max_num_bb) between {0, 1} 132 | """ 133 | (input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts, 134 | ) = map(list, unzip(inputs)) 135 | 136 | txt_lens = [i.size(0) for i in input_ids] 137 | 138 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 139 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 140 | ).unsqueeze(0) 141 | 142 | num_bbs = [f.size(0) for f in img_feats] 143 | img_feat = pad_tensors(img_feats, num_bbs) 144 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 145 | 146 | # mask features 147 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) 148 | feat_targets = _get_feat_target(img_feat, img_masks) 149 | img_feat = _mask_img_feat(img_feat, img_masks) 150 | img_mask_tgt = pad_sequence(img_mask_tgts, 151 | batch_first=True, padding_value=0) 152 | 153 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 154 | bs, max_tl = input_ids.size() 155 | out_size = attn_masks.size(1) 156 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 157 | 158 | batch = {'input_ids': input_ids, 159 | 'position_ids': position_ids, 160 | 'img_feat': img_feat, 161 | 'img_pos_feat': img_pos_feat, 162 | 'attn_masks': attn_masks, 163 | 'gather_index': gather_index, 164 | 'feat_targets': feat_targets, 165 | 'img_masks': img_masks, 166 | 'img_mask_tgt': img_mask_tgt} 167 | return batch 168 | 169 | 170 | class OnlyImgMrfrDataset(Dataset): 171 | """ an image-only MRM """ 172 | def __init__(self, mask_prob, img_db): 173 | self.ids, self.lens = map(list, unzip(self.img_db.name2nbb.items())) 174 | 175 | def __getitem__(self, i): 176 | id_ = self.ids[i] 177 | img_feat, img_pos_feat, num_bb = self._get_img_feat(id_) 178 | attn_masks = 
torch.ones(num_bb, dtype=torch.long) 179 | img_mask = _get_img_mask(self.mask_prob, num_bb) 180 | 181 | return img_feat, img_pos_feat, attn_masks, img_mask 182 | 183 | def _get_img_feat(self, fname): 184 | img_feat, bb = self.img_db[fname] 185 | img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) 186 | num_bb = img_feat.size(0) 187 | return img_feat, img_bb, num_bb 188 | 189 | 190 | def mrfr_only_img_collate(inputs): 191 | img_feats, img_pos_feats, attn_masks, img_masks = map(list, unzip(inputs)) 192 | 193 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 194 | 195 | num_bbs = [f.size(0) for f in img_feats] 196 | img_feat = pad_tensors(img_feats, num_bbs) 197 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 198 | 199 | # mask features 200 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) 201 | feat_targets = _get_feat_target(img_feat, img_masks) 202 | img_feat = _mask_img_feat(img_feat, img_masks) 203 | 204 | batch = {'img_feat': img_feat, 205 | 'img_pos_feat': img_pos_feat, 206 | 'attn_masks': attn_masks, 207 | 'feat_targets': feat_targets, 208 | 'img_masks': img_masks, 209 | 'img_mask_tgt': img_masks} 210 | return batch 211 | 212 | 213 | def _get_targets(img_masks, img_soft_label): 214 | soft_label_dim = img_soft_label.size(-1) 215 | img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label) 216 | label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view( 217 | -1, soft_label_dim) 218 | return label_targets 219 | 220 | 221 | class MrcDataset(DetectFeatTxtTokDataset): 222 | def __init__(self, mask_prob, *args, **kwargs): 223 | super().__init__(*args, **kwargs) 224 | self.mask_prob = mask_prob 225 | 226 | def _get_img_feat(self, fname): 227 | img_dump = self.img_db.get_dump(fname) 228 | num_bb = self.img_db.name2nbb[fname] 229 | img_feat = torch.tensor(img_dump['features']) 230 | bb = torch.tensor(img_dump['norm_bb']) 231 | img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) 232 | img_soft_label = torch.tensor(img_dump['soft_labels']) 233 | return img_feat, img_bb, img_soft_label, num_bb 234 | 235 | def __getitem__(self, i): 236 | example = super().__getitem__(i) 237 | img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat( 238 | example['img_fname']) 239 | 240 | # image input features 241 | img_mask = _get_img_mask(self.mask_prob, num_bb) 242 | 243 | # text input 244 | input_ids = example['input_ids'] 245 | input_ids = self.txt_db.combine_inputs(input_ids) 246 | img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids)) 247 | 248 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) 249 | 250 | return (input_ids, img_feat, img_pos_feat, 251 | img_soft_labels, attn_masks, img_mask, img_mask_tgt) 252 | 253 | def xlmr_mrc_collate(inputs): 254 | (input_ids, img_feats, img_pos_feats, img_soft_labels, 255 | attn_masks, img_masks, img_mask_tgts) = map(list, unzip(inputs)) 256 | 257 | txt_lens = [i.size(0) for i in input_ids] 258 | num_bbs = [f.size(0) for f in img_feats] 259 | 260 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=1) 261 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 262 | ).unsqueeze(0) 263 | 264 | img_feat = pad_tensors(img_feats, num_bbs) 265 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 266 | img_soft_label = pad_tensors(img_soft_labels, num_bbs) 267 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) 268 | label_targets = _get_targets(img_masks, img_soft_label) 269 | 270 | img_feat = 
_mask_img_feat(img_feat, img_masks) 271 | img_mask_tgt = pad_sequence(img_mask_tgts, 272 | batch_first=True, padding_value=0) 273 | 274 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 275 | bs, max_tl = input_ids.size() 276 | out_size = attn_masks.size(1) 277 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 278 | 279 | batch = {'input_ids': input_ids, 280 | 'position_ids': position_ids, 281 | 'img_feat': img_feat, 282 | 'img_pos_feat': img_pos_feat, 283 | 'attn_masks': attn_masks, 284 | 'gather_index': gather_index, 285 | 'img_masks': img_masks, 286 | 'img_mask_tgt': img_mask_tgt, 287 | 'label_targets': label_targets} 288 | return batch 289 | 290 | def mrc_collate(inputs): 291 | (input_ids, img_feats, img_pos_feats, img_soft_labels, 292 | attn_masks, img_masks, img_mask_tgts) = map(list, unzip(inputs)) 293 | 294 | txt_lens = [i.size(0) for i in input_ids] 295 | num_bbs = [f.size(0) for f in img_feats] 296 | 297 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 298 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 299 | ).unsqueeze(0) 300 | 301 | img_feat = pad_tensors(img_feats, num_bbs) 302 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 303 | img_soft_label = pad_tensors(img_soft_labels, num_bbs) 304 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) 305 | label_targets = _get_targets(img_masks, img_soft_label) 306 | 307 | img_feat = _mask_img_feat(img_feat, img_masks) 308 | img_mask_tgt = pad_sequence(img_mask_tgts, 309 | batch_first=True, padding_value=0) 310 | 311 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 312 | bs, max_tl = input_ids.size() 313 | out_size = attn_masks.size(1) 314 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 315 | 316 | batch = {'input_ids': input_ids, 317 | 'position_ids': position_ids, 318 | 'img_feat': img_feat, 319 | 'img_pos_feat': img_pos_feat, 320 | 'attn_masks': attn_masks, 321 | 'gather_index': gather_index, 322 | 'img_masks': img_masks, 323 | 'img_mask_tgt': img_mask_tgt, 324 | 'label_targets': label_targets} 325 | return batch 326 | 327 | 328 | class OnlyImgMrcDataset(OnlyImgMrfrDataset): 329 | """ an image-only MRC """ 330 | def __getitem__(self, i): 331 | id_ = self.ids[i] 332 | (img_feat, img_pos_feat, img_soft_labels, num_bb 333 | ) = self._get_img_feat(id_) 334 | attn_masks = torch.ones(num_bb, dtype=torch.long) 335 | img_mask = _get_img_mask(self.mask_prob, num_bb) 336 | 337 | return img_feat, img_pos_feat, img_soft_labels, attn_masks, img_mask 338 | 339 | def _get_img_feat(self, fname): 340 | img_dump = self.img_db.get_dump(fname) 341 | num_bb = self.img_db.name2nbb[fname] 342 | img_feat = torch.tensor(img_dump['features']) 343 | bb = torch.tensor(img_dump['norm_bb']) 344 | img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) 345 | img_soft_labels = torch.tensor(img_dump['soft_labels']) 346 | return img_feat, img_bb, img_soft_labels, num_bb 347 | 348 | 349 | def mrc_only_img_collate(inputs): 350 | (img_feats, img_pos_feats, img_soft_labels, attn_masks, img_masks 351 | ) = map(list, unzip(inputs)) 352 | 353 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 354 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) 355 | num_bbs = [f.size(0) for f in img_feats] 356 | 357 | img_feat = pad_tensors(img_feats, num_bbs) 358 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 359 | img_soft_label = pad_tensors(img_soft_labels, 
num_bbs) 360 | label_targets = _get_targets(img_masks, img_soft_label) 361 | 362 | # mask features 363 | img_feat = _mask_img_feat(img_feat, img_masks) 364 | 365 | batch = {'img_feat': img_feat, 366 | 'img_pos_feat': img_pos_feat, 367 | 'attn_masks': attn_masks, 368 | 'img_masks': img_masks, 369 | 'img_mask_tgt': img_masks, 370 | 'label_targets': label_targets} 371 | return batch 372 | -------------------------------------------------------------------------------- /data/mrm_nce.py: -------------------------------------------------------------------------------- 1 | """ 2 | MRM Datasets (contrastive learning version) 3 | """ 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | from toolz.sandbox import unzip 7 | from cytoolz import curry 8 | 9 | from .data import (DetectFeatLmdb, DetectFeatTxtTokDataset, 10 | pad_tensors, get_gather_index) 11 | from .mrm import _get_img_mask, _get_img_tgt_mask, _get_feat_target 12 | from .itm import sample_negative 13 | 14 | 15 | # FIXME diff implementation from mrfr, mrc 16 | def _mask_img_feat(img_feat, img_masks, neg_feats, 17 | noop_prob=0.1, change_prob=0.1): 18 | rand = torch.rand(*img_masks.size()) 19 | noop_mask = rand < noop_prob 20 | change_mask = ~noop_mask & (rand < (noop_prob+change_prob)) & img_masks 21 | img_masks_in = img_masks & ~noop_mask & ~change_mask 22 | 23 | img_masks_ext = img_masks_in.unsqueeze(-1).expand_as(img_feat) 24 | img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0) 25 | 26 | n_neg = change_mask.sum().item() 27 | feat_dim = neg_feats.size(-1) 28 | index = torch.arange(0, change_mask.numel(), dtype=torch.long 29 | ).masked_select(change_mask.view(-1)) 30 | index = index.unsqueeze(-1).expand(-1, feat_dim) 31 | img_feat_out = img_feat_masked.view(-1, feat_dim).scatter( 32 | dim=0, index=index, src=neg_feats[:n_neg]).view(*img_feat.size()) 33 | 34 | return img_feat_out, img_masks_in 35 | 36 | 37 | class MrmNceDataset(DetectFeatTxtTokDataset): 38 | def __init__(self, mask_prob, *args, **kwargs): 39 | super().__init__(*args, **kwargs) 40 | self.mask_prob = mask_prob 41 | 42 | def __getitem__(self, i): 43 | example = super().__getitem__(i) 44 | # text input 45 | input_ids = example['input_ids'] 46 | input_ids = self.txt_db.combine_inputs(input_ids) 47 | 48 | # image input features 49 | img_feat, img_pos_feat, num_bb = self._get_img_feat( 50 | example['img_fname']) 51 | img_mask = _get_img_mask(self.mask_prob, num_bb) 52 | img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids)) 53 | 54 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) 55 | 56 | return (input_ids, img_feat, img_pos_feat, 57 | attn_masks, img_mask, img_mask_tgt, 58 | example['img_fname']) 59 | 60 | 61 | class NegativeImageSampler(object): 62 | def __init__(self, img_dbs, neg_size, size_mul=8): 63 | if not isinstance(img_dbs, list): 64 | assert isinstance(img_dbs, DetectFeatLmdb) 65 | img_dbs = [img_dbs] 66 | self.neg_size = neg_size 67 | self.img_db = JoinedDetectFeatLmdb(img_dbs) 68 | all_imgs = [] 69 | for db in img_dbs: 70 | all_imgs.extend(db.name2nbb.keys()) 71 | self.all_imgs = all_imgs 72 | 73 | def sample_negative_feats(self, pos_imgs): 74 | neg_img_ids = sample_negative(self.all_imgs, pos_imgs, self.neg_size) 75 | all_neg_feats = torch.cat([self.img_db[img][0] for img in neg_img_ids], 76 | dim=0) 77 | # only use multiples of 8 for tensorcores 78 | n_cut = all_neg_feats.size(0) % 8 79 | if n_cut != 0: 80 | return all_neg_feats[:-n_cut] 81 | else: 82 | return all_neg_feats 83 | 84 | 85 | class 
JoinedDetectFeatLmdb(object): 86 | def __init__(self, img_dbs): 87 | assert all(isinstance(db, DetectFeatLmdb) for db in img_dbs) 88 | self.img_dbs = img_dbs 89 | 90 | def __getitem__(self, file_name): 91 | for db in self.img_dbs: 92 | if file_name in db: 93 | return db[file_name] 94 | raise ValueError("image does not exists") 95 | 96 | 97 | @curry 98 | def mrm_nce_collate(neg_sampler, inputs): 99 | (input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts, 100 | positive_imgs) = map(list, unzip(inputs)) 101 | 102 | txt_lens = [i.size(0) for i in input_ids] 103 | 104 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 105 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 106 | ).unsqueeze(0) 107 | 108 | num_bbs = [f.size(0) for f in img_feats] 109 | img_feat = pad_tensors(img_feats, num_bbs) 110 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 111 | neg_feats = neg_sampler.sample_negative_feats(positive_imgs) 112 | 113 | # mask features 114 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) 115 | feat_targets = _get_feat_target(img_feat, img_masks) 116 | img_feat, img_masks_in = _mask_img_feat(img_feat, img_masks, neg_feats) 117 | img_mask_tgt = pad_sequence(img_mask_tgts, 118 | batch_first=True, padding_value=0) 119 | 120 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 121 | bs, max_tl = input_ids.size() 122 | out_size = attn_masks.size(1) 123 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 124 | 125 | batch = {'input_ids': input_ids, 126 | 'position_ids': position_ids, 127 | 'img_feat': img_feat, 128 | 'img_pos_feat': img_pos_feat, 129 | 'attn_masks': attn_masks, 130 | 'gather_index': gather_index, 131 | 'feat_targets': feat_targets, 132 | 'img_masks': img_masks, 133 | 'img_masks_in': img_masks_in, 134 | 'img_mask_tgt': img_mask_tgt, 135 | 'neg_feats': neg_feats} 136 | return batch 137 | -------------------------------------------------------------------------------- /data/nlvr2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | NLVR2 dataset 6 | """ 7 | import copy 8 | 9 | import torch 10 | from torch.nn.utils.rnn import pad_sequence 11 | from toolz.sandbox import unzip 12 | from cytoolz import concat 13 | 14 | from .data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb, 15 | get_ids_and_lens, pad_tensors, get_gather_index) 16 | 17 | 18 | class Nlvr2PairedDataset(DetectFeatTxtTokDataset): 19 | def __init__(self, txt_db, img_db, use_img_type=True): 20 | assert isinstance(txt_db, TxtTokLmdb) 21 | assert isinstance(img_db, DetectFeatLmdb) 22 | self.txt_db = txt_db 23 | self.img_db = img_db 24 | txt_lens, self.ids = get_ids_and_lens(txt_db) 25 | 26 | txt2img = txt_db.txt2img 27 | self.lens = [2*tl + sum(self.img_db.name2nbb[img] 28 | for img in txt2img[id_]) 29 | for tl, id_ in zip(txt_lens, self.ids)] 30 | 31 | self.use_img_type = use_img_type 32 | 33 | def __getitem__(self, i): 34 | """ 35 | [[txt, img1], 36 | [txt, img2]] 37 | """ 38 | example = super().__getitem__(i) 39 | target = example['target'] 40 | outs = [] 41 | for i, img in enumerate(example['img_fname']): 42 | img_feat, img_pos_feat, num_bb = self._get_img_feat(img) 43 | 44 | # text input 45 | input_ids = copy.deepcopy(example['input_ids']) 46 | 47 | input_ids = [self.txt_db.cls_] + input_ids + [self.txt_db.sep] 48 | attn_masks = [1] * (len(input_ids) + num_bb) 49 | input_ids = torch.tensor(input_ids) 50 | attn_masks = torch.tensor(attn_masks) 51 | if self.use_img_type: 52 | img_type_ids = torch.tensor([i+1]*num_bb) 53 | else: 54 | img_type_ids = None 55 | 56 | outs.append((input_ids, img_feat, img_pos_feat, 57 | attn_masks, img_type_ids)) 58 | return tuple(outs), target 59 | 60 | 61 | def nlvr2_paired_collate(inputs): 62 | (input_ids, img_feats, img_pos_feats, attn_masks, 63 | img_type_ids) = map(list, unzip(concat(outs for outs, _ in inputs))) 64 | 65 | txt_lens = [i.size(0) for i in input_ids] 66 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 67 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 68 | ).unsqueeze(0) 69 | 70 | # image batches 71 | num_bbs = [f.size(0) for f in img_feats] 72 | img_feat = pad_tensors(img_feats, num_bbs) 73 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 74 | if img_type_ids[0] is None: 75 | img_type_ids = None 76 | else: 77 | img_type_ids = pad_sequence(img_type_ids, 78 | batch_first=True, padding_value=0) 79 | 80 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 81 | targets = torch.Tensor([t for _, t in inputs]).long() 82 | 83 | bs, max_tl = input_ids.size() 84 | out_size = attn_masks.size(1) 85 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 86 | 87 | batch = {'input_ids': input_ids, 88 | 'position_ids': position_ids, 89 | 'img_feat': img_feat, 90 | 'img_pos_feat': img_pos_feat, 91 | 'attn_masks': attn_masks, 92 | 'gather_index': gather_index, 93 | 'img_type_ids': img_type_ids, 94 | 'targets': targets} 95 | return batch 96 | 97 | 98 | class Nlvr2PairedEvalDataset(Nlvr2PairedDataset): 99 | def __getitem__(self, i): 100 | qid = self.ids[i] 101 | outs, targets = super().__getitem__(i) 102 | return qid, outs, targets 103 | 104 | 105 | def nlvr2_paired_eval_collate(inputs): 106 | qids, batch = [], [] 107 | for id_, *tensors in inputs: 108 | qids.append(id_) 109 | batch.append(tensors) 110 | batch = nlvr2_paired_collate(batch) 111 | batch['qids'] = qids 112 | return batch 113 | 114 | 115 | class Nlvr2TripletDataset(DetectFeatTxtTokDataset): 116 | def __init__(self, txt_db, img_db, use_img_type=True): 117 | 
assert isinstance(txt_db, TxtTokLmdb) 118 | assert isinstance(img_db, DetectFeatLmdb) 119 | self.txt_db = txt_db 120 | self.img_db = img_db 121 | txt_lens, self.ids = get_ids_and_lens(txt_db) 122 | 123 | txt2img = txt_db.txt2img 124 | self.lens = [tl + sum(self.img_db.name2nbb[img] 125 | for img in txt2img[id_]) 126 | for tl, id_ in zip(txt_lens, self.ids)] 127 | 128 | self.use_img_type = use_img_type 129 | 130 | def __getitem__(self, i): 131 | """ 132 | [[txt, img1], 133 | [txt, img2]] 134 | """ 135 | example = super().__getitem__(i) 136 | target = example['target'] 137 | img_feats = [] 138 | img_pos_feats = [] 139 | num_bb = 0 140 | img_type_ids = [] 141 | for i, img in enumerate(example['img_fname']): 142 | feat, pos, nbb = self._get_img_feat(img) 143 | img_feats.append(feat) 144 | img_pos_feats.append(pos) 145 | num_bb += nbb 146 | if self.use_img_type: 147 | img_type_ids.extend([i+1]*nbb) 148 | img_feat = torch.cat(img_feats, dim=0) 149 | img_pos_feat = torch.cat(img_pos_feats, dim=0) 150 | if self.use_img_type: 151 | img_type_ids = torch.tensor(img_type_ids) 152 | else: 153 | img_type_ids = None 154 | 155 | # text input 156 | input_ids = copy.deepcopy(example['input_ids']) 157 | 158 | input_ids = [self.txt_db.cls_] + input_ids + [self.txt_db.sep] 159 | attn_masks = [1] * (len(input_ids) + num_bb) 160 | input_ids = torch.tensor(input_ids) 161 | attn_masks = torch.tensor(attn_masks) 162 | 163 | return (input_ids, img_feat, img_pos_feat, attn_masks, 164 | img_type_ids, target) 165 | 166 | 167 | def nlvr2_triplet_collate(inputs): 168 | (input_ids, img_feats, img_pos_feats, 169 | attn_masks, img_type_ids, targets) = map(list, unzip(inputs)) 170 | 171 | txt_lens = [i.size(0) for i in input_ids] 172 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 173 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 174 | ).unsqueeze(0) 175 | 176 | # image batches 177 | num_bbs = [f.size(0) for f in img_feats] 178 | img_feat = pad_tensors(img_feats, num_bbs) 179 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 180 | if img_type_ids[0] is None: 181 | img_type_ids = None 182 | else: 183 | img_type_ids = pad_sequence(img_type_ids, 184 | batch_first=True, padding_value=0) 185 | 186 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 187 | targets = torch.Tensor(targets).long() 188 | 189 | bs, max_tl = input_ids.size() 190 | out_size = attn_masks.size(1) 191 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 192 | 193 | batch = {'input_ids': input_ids, 194 | 'position_ids': position_ids, 195 | 'img_feat': img_feat, 196 | 'img_pos_feat': img_pos_feat, 197 | 'attn_masks': attn_masks, 198 | 'gather_index': gather_index, 199 | 'img_type_ids': img_type_ids, 200 | 'targets': targets} 201 | return batch 202 | 203 | 204 | class Nlvr2TripletEvalDataset(Nlvr2TripletDataset): 205 | def __getitem__(self, i): 206 | qid = self.ids[i] 207 | tensors = super().__getitem__(i) 208 | return (qid, *tensors) 209 | 210 | 211 | def nlvr2_triplet_eval_collate(inputs): 212 | qids, batch = [], [] 213 | for id_, *tensors in inputs: 214 | qids.append(id_) 215 | batch.append(tensors) 216 | batch = nlvr2_triplet_collate(batch) 217 | batch['qids'] = qids 218 | return batch 219 | -------------------------------------------------------------------------------- /data/re.py: -------------------------------------------------------------------------------- 1 | """ 2 | Referring Expression Comprehension dataset 3 | """ 4 | import sys 5 | import json 6 | 
import random 7 | import numpy as np 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | from torch.nn.utils.rnn import pad_sequence 12 | from toolz.sandbox import unzip 13 | 14 | from .data import TxtLmdb 15 | 16 | 17 | class ReImageFeatDir(object): 18 | def __init__(self, img_dir): 19 | self.img_dir = img_dir 20 | 21 | def __getitem__(self, file_name): 22 | img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True) 23 | img_feat = torch.tensor(img_dump['features']) 24 | img_bb = torch.tensor(img_dump['norm_bb']) 25 | return img_feat, img_bb 26 | 27 | 28 | class ReDetectFeatDir(object): 29 | def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36, 30 | format_='npz'): 31 | assert format_ == 'npz', 'only support npz for now.' 32 | assert isinstance(img_dir, str), 'img_dir is path, not db.' 33 | self.img_dir = img_dir 34 | self.conf_th = conf_th 35 | self.max_bb = max_bb 36 | self.min_bb = min_bb 37 | self.num_bb = num_bb 38 | 39 | def _compute_num_bb(self, img_dump): 40 | num_bb = max(self.min_bb, (img_dump['conf'] > self.conf_th).sum()) 41 | num_bb = min(self.max_bb, num_bb) 42 | return num_bb 43 | 44 | def __getitem__(self, file_name): 45 | # image input features 46 | img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True) 47 | num_bb = self._compute_num_bb(img_dump) 48 | img_feat = torch.tensor(img_dump['features'][:num_bb, :]) 49 | img_bb = torch.tensor(img_dump['norm_bb'][:num_bb, :]) 50 | return img_feat, img_bb 51 | 52 | 53 | class ReferringExpressionDataset(Dataset): 54 | def __init__(self, db_dir, img_dir, max_txt_len=60): 55 | assert isinstance(img_dir, ReImageFeatDir) or \ 56 | isinstance(img_dir, ReDetectFeatDir) 57 | self.img_dir = img_dir 58 | 59 | # load refs = [{ref_id, sent_ids, ann_id, image_id, sentences, split}] 60 | refs = json.load(open(f'{db_dir}/refs.json', 'r')) 61 | self.ref_ids = [ref['ref_id'] for ref in refs] 62 | self.Refs = {ref['ref_id']: ref for ref in refs} 63 | 64 | # load annotations = [{id, area, bbox, image_id, category_id}] 65 | anns = json.load(open(f'{db_dir}/annotations.json', 'r')) 66 | self.Anns = {ann['id']: ann for ann in anns} 67 | 68 | # load categories = [{id, name, supercategory}] 69 | categories = json.load(open(f'{db_dir}/categories.json', 'r')) 70 | self.Cats = {cat['id']: cat['name'] for cat in categories} 71 | 72 | # load images = [{id, file_name, ann_ids, height, width}] 73 | images = json.load(open(f'{db_dir}/images.json', 'r')) 74 | self.Images = {img['id']: img for img in images} 75 | 76 | # id2len: sent_id -> sent_len 77 | id2len = json.load(open(f'{db_dir}/id2len.json', 'r')) 78 | self.id2len = {int(_id): _len for _id, _len in id2len.items()} 79 | self.max_txt_len = max_txt_len 80 | self.sent_ids = self._get_sent_ids() 81 | 82 | # db[str(sent_id)] = 83 | # {sent_id, sent, ref_id, ann_id, image_id, 84 | # bbox, input_ids, toked_sent} 85 | self.db = TxtLmdb(db_dir, readonly=True) 86 | 87 | # meta 88 | meta = json.load(open(f'{db_dir}/meta.json', 'r')) 89 | self.cls_ = meta['CLS'] 90 | self.sep = meta['SEP'] 91 | self.mask = meta['MASK'] 92 | self.v_range = meta['v_range'] 93 | 94 | def shuffle(self): 95 | # we shuffle ref_ids and make sent_ids according to ref_ids 96 | random.shuffle(self.ref_ids) 97 | self.sent_ids = self._get_sent_ids() 98 | 99 | def _get_sent_ids(self): 100 | sent_ids = [] 101 | for ref_id in self.ref_ids: 102 | for sent_id in self.Refs[ref_id]['sent_ids']: 103 | sent_len = self.id2len[sent_id] 104 | if self.max_txt_len == -1 or sent_len < 
self.max_txt_len: 105 | sent_ids.append(sent_id) 106 | return sent_ids 107 | 108 | def _get_img_feat(self, fname): 109 | img_feat, bb = self.img_dir[fname] 110 | img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) 111 | num_bb = img_feat.size(0) 112 | return img_feat, img_bb, num_bb 113 | 114 | def __len__(self): 115 | return len(self.sent_ids) 116 | 117 | def __getitem__(self, i): 118 | """ 119 | Return: 120 | :input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0] 121 | :position_ids : range(L) 122 | :img_feat : (num_bb, d) 123 | :img_pos_feat : (num_bb, 7) 124 | :attn_masks : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1] 125 | :obj_masks : (num_bb, ) all 0's 126 | :target : (1, ) 127 | """ 128 | # {sent_id, sent, ref_id, ann_id, image_id, 129 | # bbox, input_ids, toked_sent} 130 | sent_id = self.sent_ids[i] 131 | txt_dump = self.db[str(sent_id)] 132 | image_id = txt_dump['image_id'] 133 | fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz' 134 | img_feat, img_pos_feat, num_bb = self._get_img_feat(fname) 135 | 136 | # text input 137 | input_ids = txt_dump['input_ids'] 138 | input_ids = [self.cls_] + input_ids + [self.sep] 139 | attn_masks = [1] * len(input_ids) 140 | position_ids = list(range(len(input_ids))) 141 | attn_masks += [1] * num_bb 142 | 143 | input_ids = torch.tensor(input_ids) 144 | position_ids = torch.tensor(position_ids) 145 | attn_masks = torch.tensor(attn_masks) 146 | 147 | # target bbox 148 | img = self.Images[image_id] 149 | assert len(img['ann_ids']) == num_bb, \ 150 | 'Please use visual_grounding_coco_gt' 151 | target = img['ann_ids'].index(txt_dump['ann_id']) 152 | target = torch.tensor([target]) 153 | 154 | # obj_masks, to be padded with 1, for masking out non-object prob. 155 | obj_masks = torch.tensor([0]*len(img['ann_ids'])).byte() 156 | 157 | return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks, 158 | obj_masks, target) 159 | 160 | 161 | def re_collate(inputs): 162 | """ 163 | Return: 164 | :input_ids : (n, max_L) padded with 0 165 | :position_ids : (n, max_L) padded with 0 166 | :txt_lens : list of [txt_len] 167 | :img_feat : (n, max_num_bb, feat_dim) 168 | :img_pos_feat : (n, max_num_bb, 7) 169 | :num_bbs : list of [num_bb] 170 | :attn_masks : (n, max_{L+num_bb}) padded with 0 171 | :obj_masks : (n, max_num_bb) padded with 1 172 | :targets : (n, ) 173 | """ 174 | (input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks, 175 | targets) = map(list, unzip(inputs)) 176 | 177 | txt_lens = [i.size(0) for i in input_ids] 178 | num_bbs = [f.size(0) for f in img_feats] 179 | 180 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 181 | position_ids = pad_sequence(position_ids, 182 | batch_first=True, padding_value=0) 183 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 184 | targets = torch.cat(targets, dim=0) 185 | obj_masks = pad_sequence(obj_masks, 186 | batch_first=True, padding_value=1).byte() 187 | 188 | batch_size = len(img_feats) 189 | num_bb = max(num_bbs) 190 | feat_dim = img_feats[0].size(1) 191 | pos_dim = img_pos_feats[0].size(1) 192 | img_feat = torch.zeros(batch_size, num_bb, feat_dim) 193 | img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim) 194 | for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)): 195 | len_ = im.size(0) 196 | img_feat.data[i, :len_, :] = im.data 197 | img_pos_feat.data[i, :len_, :] = pos.data 198 | 199 | return (input_ids, position_ids, txt_lens, 200 | img_feat, img_pos_feat, num_bbs, 201 | attn_masks, obj_masks, targets) 202 | 203 | 
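# ---------------------------------------------------------------------------
# Aside (added usage sketch, not part of the original file): a minimal way to
# wire ReferringExpressionDataset and re_collate into a PyTorch DataLoader.
# The two directory paths are hypothetical placeholders; everything else uses
# only the classes and the collate function defined above in this file.
def _example_re_dataloader(db_dir='/db/refcoco_train',
                           feat_dir='/img/visual_grounding_coco_gt'):
    from torch.utils.data import DataLoader
    img_dir = ReImageFeatDir(feat_dir)   # gt features, as __getitem__ expects
    dataset = ReferringExpressionDataset(db_dir, img_dir, max_txt_len=60)
    loader = DataLoader(dataset, batch_size=32, shuffle=False,
                        collate_fn=re_collate)
    dataset.shuffle()   # call once per epoch: reshuffles ref_ids and sent_ids
    for (input_ids, position_ids, txt_lens, img_feat, img_pos_feat,
         num_bbs, attn_masks, obj_masks, targets) in loader:
        # input_ids: (n, max_L), img_feat: (n, max_num_bb, d), targets: (n,)
        break
    return loader
# ---------------------------------------------------------------------------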
204 | class ReferringExpressionEvalDataset(ReferringExpressionDataset): 205 | def __getitem__(self, i): 206 | """ 207 | Return: 208 | :input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0] 209 | :position_ids : range(L) 210 | :img_feat : (num_bb, d) 211 | :img_pos_feat : (num_bb, 7) 212 | :attn_masks : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1] 213 | :obj_masks : (num_bb, ) all 0's 214 | :tgt_box : ndarray (4, ) xywh 215 | :obj_boxes : ndarray (num_bb, 4) xywh 216 | :sent_id 217 | """ 218 | # {sent_id, sent, ref_id, ann_id, image_id, 219 | # bbox, input_ids, toked_sent} 220 | sent_id = self.sent_ids[i] 221 | txt_dump = self.db[str(sent_id)] 222 | image_id = txt_dump['image_id'] 223 | if isinstance(self.img_dir, ReImageFeatDir): 224 | if '_gt' in self.img_dir.img_dir: 225 | fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz' 226 | elif '_det' in self.img_dir.img_dir: 227 | fname = f'visual_grounding_det_coco_{int(image_id):012}.npz' 228 | elif isinstance(self.img_dir, ReDetectFeatDir): 229 | fname = f'coco_train2014_{int(image_id):012}.npz' 230 | else: 231 | sys.exit('%s not supported.' % self.img_dir) 232 | img_feat, img_pos_feat, num_bb = self._get_img_feat(fname) 233 | 234 | # image info 235 | img = self.Images[image_id] 236 | im_width, im_height = img['width'], img['height'] 237 | 238 | # object boxes, img_pos_feat (xyxywha) -> xywh 239 | obj_boxes = np.stack([img_pos_feat[:, 0]*im_width, 240 | img_pos_feat[:, 1]*im_height, 241 | img_pos_feat[:, 4]*im_width, 242 | img_pos_feat[:, 5]*im_height], axis=1) 243 | obj_masks = torch.tensor([0]*num_bb).byte() 244 | 245 | # target box 246 | tgt_box = np.array(txt_dump['bbox']) # xywh 247 | 248 | # text input 249 | input_ids = txt_dump['input_ids'] 250 | input_ids = [self.cls_] + input_ids + [self.sep] 251 | attn_masks = [1] * len(input_ids) 252 | position_ids = list(range(len(input_ids))) 253 | attn_masks += [1] * num_bb 254 | 255 | input_ids = torch.tensor(input_ids) 256 | position_ids = torch.tensor(position_ids) 257 | attn_masks = torch.tensor(attn_masks) 258 | 259 | return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks, 260 | obj_masks, tgt_box, obj_boxes, sent_id) 261 | 262 | # IoU function 263 | def computeIoU(self, box1, box2): 264 | # each box is of [x1, y1, w, h] 265 | inter_x1 = max(box1[0], box2[0]) 266 | inter_y1 = max(box1[1], box2[1]) 267 | inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1) 268 | inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1) 269 | 270 | if inter_x1 < inter_x2 and inter_y1 < inter_y2: 271 | inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1) 272 | else: 273 | inter = 0 274 | union = box1[2]*box1[3] + box2[2]*box2[3] - inter 275 | return float(inter)/union 276 | 277 | 278 | def re_eval_collate(inputs): 279 | """ 280 | Return: 281 | :input_ids : (n, max_L) 282 | :position_ids : (n, max_L) 283 | :txt_lens : list of [txt_len] 284 | :img_feat : (n, max_num_bb, d) 285 | :img_pos_feat : (n, max_num_bb, 7) 286 | :num_bbs : list of [num_bb] 287 | :attn_masks : (n, max{L+num_bb}) 288 | :obj_masks : (n, max_num_bb) 289 | :tgt_box : list of n [xywh] 290 | :obj_boxes : list of n [[xywh, xywh, ...]] 291 | :sent_ids : list of n [sent_id] 292 | """ 293 | (input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks, 294 | tgt_box, obj_boxes, sent_ids) = map(list, unzip(inputs)) 295 | 296 | txt_lens = [i.size(0) for i in input_ids] 297 | num_bbs = [f.size(0) for f in img_feats] 298 | 299 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 300 | position_ids = 
pad_sequence(position_ids, 301 | batch_first=True, padding_value=0) 302 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 303 | obj_masks = pad_sequence(obj_masks, 304 | batch_first=True, padding_value=1).byte() 305 | 306 | batch_size = len(img_feats) 307 | num_bb = max(num_bbs) 308 | feat_dim = img_feats[0].size(1) 309 | pos_dim = img_pos_feats[0].size(1) 310 | img_feat = torch.zeros(batch_size, num_bb, feat_dim) 311 | img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim) 312 | for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)): 313 | len_ = im.size(0) 314 | img_feat.data[i, :len_, :] = im.data 315 | img_pos_feat.data[i, :len_, :] = pos.data 316 | 317 | return (input_ids, position_ids, txt_lens, 318 | img_feat, img_pos_feat, num_bbs, 319 | attn_masks, obj_masks, tgt_box, obj_boxes, sent_ids) 320 | -------------------------------------------------------------------------------- /data/sampler.py: -------------------------------------------------------------------------------- 1 | """ sampler for length bucketing (batch by tokens) """ 2 | import math 3 | import random 4 | 5 | import torch 6 | import horovod.torch as hvd 7 | from torch.utils.data import Sampler 8 | from cytoolz import partition_all 9 | 10 | 11 | class TokenBucketSampler(Sampler): 12 | def __init__(self, lens, bucket_size, batch_size, 13 | droplast=False, size_multiple=8): 14 | self._lens = lens 15 | self._max_tok = batch_size 16 | self._bucket_size = bucket_size 17 | self._droplast = droplast 18 | self._size_mul = size_multiple 19 | 20 | def _create_ids(self): 21 | return list(range(len(self._lens))) 22 | 23 | def _sort_fn(self, i): 24 | return self._lens[i] 25 | 26 | def __iter__(self): 27 | ids = self._create_ids() 28 | random.shuffle(ids) 29 | buckets = [sorted(ids[i:i+self._bucket_size], 30 | key=self._sort_fn, reverse=True) 31 | for i in range(0, len(ids), self._bucket_size)] 32 | # fill batches until max_token (include padding) 33 | batches = [] 34 | for bucket in buckets: 35 | max_len = 0 36 | batch_indices = [] 37 | for indices in partition_all(self._size_mul, bucket): 38 | max_len = max(max_len, max(self._lens[i] for i in indices)) 39 | # print(max_len * (len(batch_indices) + self._size_mul)) 40 | # print(self._max_tok) 41 | # print(len(batch_indices)) 42 | if (max_len * (len(batch_indices) + self._size_mul) 43 | > self._max_tok): 44 | if not batch_indices: 45 | raise ValueError( 46 | "max_tokens too small / max_seq_len too long") 47 | assert len(batch_indices) % self._size_mul == 0 48 | batches.append(batch_indices) 49 | batch_indices = list(indices) 50 | else: 51 | batch_indices.extend(indices) 52 | if not self._droplast and batch_indices: 53 | batches.append(batch_indices) 54 | random.shuffle(batches) 55 | return iter(batches) 56 | 57 | def __len__(self): 58 | raise ValueError("NOT supported. " 59 | "This has some randomness across epochs") 60 | 61 | 62 | class DistributedSampler(Sampler): 63 | """Sampler that restricts data loading to a subset of the dataset. 64 | 65 | It is especially useful in conjunction with 66 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 67 | process can pass a DistributedSampler instance as a DataLoader sampler, 68 | and load a subset of the original dataset that is exclusive to it. 69 | 70 | .. note:: 71 | Dataset is assumed to be of constant size. 72 | 73 | Arguments: 74 | dataset: Dataset used for sampling. 75 | num_replicas (optional): Number of processes participating in 76 | distributed training. 
77 | rank (optional): Rank of the current process within num_replicas. 78 | shuffle (optional): If true (default), sampler will shuffle the indices 79 | """ 80 | 81 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 82 | if num_replicas is None: 83 | num_replicas = hvd.size() 84 | if rank is None: 85 | rank = hvd.rank() 86 | self.dataset = dataset 87 | self.num_replicas = num_replicas 88 | self.rank = rank 89 | self.epoch = 0 90 | self.num_samples = int(math.ceil(len(self.dataset) 91 | * 1.0 / self.num_replicas)) 92 | self.total_size = self.num_samples * self.num_replicas 93 | self.shuffle = shuffle 94 | 95 | def __iter__(self): 96 | # deterministically shuffle based on epoch 97 | g = torch.Generator() 98 | g.manual_seed(self.epoch) 99 | 100 | indices = list(range(len(self.dataset))) 101 | # add extra samples to make it evenly divisible 102 | indices += indices[:(self.total_size - len(indices))] 103 | assert len(indices) == self.total_size 104 | 105 | # subsample 106 | indices = indices[self.rank:self.total_size:self.num_replicas] 107 | 108 | if self.shuffle: 109 | shufle_ind = torch.randperm(len(indices), generator=g).tolist() 110 | indices = [indices[i] for i in shufle_ind] 111 | assert len(indices) == self.num_samples 112 | 113 | return iter(indices) 114 | 115 | def __len__(self): 116 | return self.num_samples 117 | 118 | def set_epoch(self, epoch): 119 | self.epoch = epoch 120 | -------------------------------------------------------------------------------- /data/test_data/input0.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input0.txt -------------------------------------------------------------------------------- /data/test_data/input1.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input1.txt -------------------------------------------------------------------------------- /data/test_data/input2.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input2.txt -------------------------------------------------------------------------------- /data/test_data/input3.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input3.txt -------------------------------------------------------------------------------- /data/test_data/input4.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input4.txt -------------------------------------------------------------------------------- /data/test_data/input5.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input5.txt -------------------------------------------------------------------------------- /data/test_data/input6.txt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input6.txt -------------------------------------------------------------------------------- /data/test_data/input7.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input7.txt -------------------------------------------------------------------------------- /data/ve.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visual entailment dataset 3 | # NOTE: basically reuse VQA dataset 4 | """ 5 | from .vqa import VqaDataset, VqaEvalDataset, vqa_collate, vqa_eval_collate 6 | 7 | 8 | class VeDataset(VqaDataset): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(3, *args, **kwargs) 11 | 12 | 13 | class VeEvalDataset(VqaEvalDataset): 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(3, *args, **kwargs) 16 | 17 | 18 | ve_collate = vqa_collate 19 | ve_eval_collate = vqa_eval_collate 20 | -------------------------------------------------------------------------------- /data/vqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | VQA dataset 3 | """ 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | from toolz.sandbox import unzip 7 | 8 | from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index 9 | 10 | 11 | def _get_vqa_target(example, num_answers): 12 | target = torch.zeros(num_answers) 13 | labels = example['target']['labels'] 14 | scores = example['target']['scores'] 15 | if labels and scores: 16 | target.scatter_(0, torch.tensor(labels), torch.tensor(scores)) 17 | return target 18 | 19 | 20 | class VqaDataset(DetectFeatTxtTokDataset): 21 | """ NOTE: This handels distributed inside """ 22 | def __init__(self, num_answers, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.num_answers = num_answers 25 | 26 | def __getitem__(self, i): 27 | example = super().__getitem__(i) 28 | img_feat, img_pos_feat, num_bb = self._get_img_feat( 29 | example['img_fname']) 30 | 31 | # text input 32 | input_ids = example['input_ids'] 33 | input_ids = self.txt_db.combine_inputs(input_ids) 34 | 35 | target = _get_vqa_target(example, self.num_answers) 36 | 37 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) 38 | 39 | return input_ids, img_feat, img_pos_feat, attn_masks, target 40 | 41 | 42 | def xlmr_vqa_collate(inputs): 43 | (input_ids, img_feats, img_pos_feats, attn_masks, targets 44 | ) = map(list, unzip(inputs)) 45 | 46 | txt_lens = [i.size(0) for i in input_ids] 47 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=1) 48 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 49 | ).unsqueeze(0) 50 | 51 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 52 | targets = torch.stack(targets, dim=0) 53 | 54 | num_bbs = [f.size(0) for f in img_feats] 55 | img_feat = pad_tensors(img_feats, num_bbs) 56 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 57 | 58 | bs, max_tl = input_ids.size() 59 | out_size = attn_masks.size(1) 60 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 61 | 62 | batch = {'input_ids': input_ids, 63 | 'position_ids': position_ids, 64 | 'img_feat': img_feat, 65 | 'img_pos_feat': img_pos_feat, 66 | 'attn_masks': attn_masks, 67 | 'gather_index': gather_index, 68 | 'targets': targets} 69 | return 
batch 70 | 71 | def vqa_collate(inputs): 72 | (input_ids, img_feats, img_pos_feats, attn_masks, targets 73 | ) = map(list, unzip(inputs)) 74 | 75 | txt_lens = [i.size(0) for i in input_ids] 76 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 77 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 78 | ).unsqueeze(0) 79 | 80 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 81 | targets = torch.stack(targets, dim=0) 82 | 83 | num_bbs = [f.size(0) for f in img_feats] 84 | img_feat = pad_tensors(img_feats, num_bbs) 85 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 86 | 87 | bs, max_tl = input_ids.size() 88 | out_size = attn_masks.size(1) 89 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 90 | 91 | batch = {'input_ids': input_ids, 92 | 'position_ids': position_ids, 93 | 'img_feat': img_feat, 94 | 'img_pos_feat': img_pos_feat, 95 | 'attn_masks': attn_masks, 96 | 'gather_index': gather_index, 97 | 'targets': targets} 98 | return batch 99 | 100 | 101 | class VqaEvalDataset(VqaDataset): 102 | def __getitem__(self, i): 103 | qid = self.ids[i] 104 | example = DetectFeatTxtTokDataset.__getitem__(self, i) 105 | img_feat, img_pos_feat, num_bb = self._get_img_feat( 106 | example['img_fname']) 107 | 108 | # text input 109 | input_ids = example['input_ids'] 110 | input_ids = self.txt_db.combine_inputs(input_ids) 111 | 112 | if 'target' in example: 113 | target = _get_vqa_target(example, self.num_answers) 114 | else: 115 | target = None 116 | 117 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) 118 | 119 | return qid, input_ids, img_feat, img_pos_feat, attn_masks, target 120 | 121 | 122 | def xlmr_vqa_eval_collate(inputs): 123 | (qids, input_ids, img_feats, img_pos_feats, attn_masks, targets 124 | ) = map(list, unzip(inputs)) 125 | 126 | txt_lens = [i.size(0) for i in input_ids] 127 | 128 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=1) 129 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 130 | ).unsqueeze(0) 131 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 132 | if targets[0] is None: 133 | targets = None 134 | else: 135 | targets = torch.stack(targets, dim=0) 136 | 137 | num_bbs = [f.size(0) for f in img_feats] 138 | img_feat = pad_tensors(img_feats, num_bbs) 139 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 140 | 141 | bs, max_tl = input_ids.size() 142 | out_size = attn_masks.size(1) 143 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 144 | 145 | batch = {'qids': qids, 146 | 'input_ids': input_ids, 147 | 'position_ids': position_ids, 148 | 'img_feat': img_feat, 149 | 'img_pos_feat': img_pos_feat, 150 | 'attn_masks': attn_masks, 151 | 'gather_index': gather_index, 152 | 'targets': targets} 153 | return batch 154 | 155 | def vqa_eval_collate(inputs): 156 | (qids, input_ids, img_feats, img_pos_feats, attn_masks, targets 157 | ) = map(list, unzip(inputs)) 158 | 159 | txt_lens = [i.size(0) for i in input_ids] 160 | 161 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 162 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 163 | ).unsqueeze(0) 164 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 165 | if targets[0] is None: 166 | targets = None 167 | else: 168 | targets = torch.stack(targets, dim=0) 169 | 170 | num_bbs = [f.size(0) for f in img_feats] 171 | img_feat = pad_tensors(img_feats, num_bbs) 172 | img_pos_feat = 
pad_tensors(img_pos_feats, num_bbs) 173 | 174 | bs, max_tl = input_ids.size() 175 | out_size = attn_masks.size(1) 176 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 177 | 178 | batch = {'qids': qids, 179 | 'input_ids': input_ids, 180 | 'position_ids': position_ids, 181 | 'img_feat': img_feat, 182 | 'img_pos_feat': img_pos_feat, 183 | 'attn_masks': attn_masks, 184 | 'gather_index': gather_index, 185 | 'targets': targets} 186 | return batch 187 | -------------------------------------------------------------------------------- /eval/__pycache__/itm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/eval/__pycache__/itm.cpython-36.pyc -------------------------------------------------------------------------------- /eval/itm.py: -------------------------------------------------------------------------------- 1 | """ Image Text Retrieval evaluation helper """ 2 | import torch 3 | 4 | 5 | @torch.no_grad() 6 | def itm_eval(score_matrix, txt_ids, img_ids, txt2img, img2txts): 7 | # image retrieval 8 | img2j = {i: j for j, i in enumerate(img_ids)} 9 | _, rank_txt = score_matrix.topk(10, dim=1) 10 | gt_img_j = torch.LongTensor([img2j[txt2img[txt_id]] 11 | for txt_id in txt_ids], 12 | ).to(rank_txt.device 13 | ).unsqueeze(1).expand_as(rank_txt) 14 | rank = (rank_txt == gt_img_j).nonzero() 15 | if rank.numel(): 16 | ir_r1 = (rank < 1).sum().item() / len(txt_ids) 17 | ir_r5 = (rank < 5).sum().item() / len(txt_ids) 18 | ir_r10 = (rank < 10).sum().item() / len(txt_ids) 19 | else: 20 | ir_r1, ir_r5, ir_r10 = 0, 0, 0 21 | 22 | # text retrieval 23 | txt2i = {t: i for i, t in enumerate(txt_ids)} 24 | _, rank_img = score_matrix.topk(10, dim=0) 25 | tr_r1, tr_r5, tr_r10 = 0, 0, 0 26 | for j, img_id in enumerate(img_ids): 27 | gt_is = [txt2i[t] for t in img2txts[img_id]] 28 | ranks = [(rank_img[:, j] == i).nonzero() for i in gt_is] 29 | rank = min([10] + [r.item() for r in ranks if r.numel()]) 30 | if rank < 1: 31 | tr_r1 += 1 32 | if rank < 5: 33 | tr_r5 += 1 34 | if rank < 10: 35 | tr_r10 += 1 36 | tr_r1 /= len(img_ids) 37 | tr_r5 /= len(img_ids) 38 | tr_r10 /= len(img_ids) 39 | 40 | tr_mean = (tr_r1 + tr_r5 + tr_r10) / 3 41 | ir_mean = (ir_r1 + ir_r5 + ir_r10) / 3 42 | r_mean = (tr_mean + ir_mean) / 2 43 | 44 | eval_log = {'txt_r1': tr_r1, 45 | 'txt_r5': tr_r5, 46 | 'txt_r10': tr_r10, 47 | 'txt_r_mean': tr_mean, 48 | 'img_r1': ir_r1, 49 | 'img_r5': ir_r5, 50 | 'img_r10': ir_r10, 51 | 'img_r_mean': ir_mean, 52 | 'r_mean': r_mean} 53 | return eval_log 54 | -------------------------------------------------------------------------------- /eval/nlvr2.py: -------------------------------------------------------------------------------- 1 | """ 2 | copied from official NLVR2 github 3 | python eval/nlvr2.py 4 | """ 5 | import json 6 | import sys 7 | 8 | # Load the predictions file. Assume it is a CSV. 9 | predictions = { } 10 | for line in open(sys.argv[1]).readlines(): 11 | if line: 12 | splits = line.strip().split(",") 13 | # We assume identifiers are in the format "split-####-#-#.png". 14 | identifier = splits[0] 15 | prediction = splits[1] 16 | predictions[identifier] = prediction 17 | 18 | # Load the labeled examples. 19 | labeled_examples = [json.loads(line) for line in open(sys.argv[2]).readlines() if line] 20 | 21 | # If not, identify the ones that are missing, and exit. 
22 | total_num = len(labeled_examples) 23 | if len(predictions) < total_num: 24 | print("Some predictions are missing!") 25 | print("Got " + str(len(predictions)) + " predictions but expected " + str(total_num)) 26 | 27 | for example in labeled_examples: 28 | lookup = example["identifier"] 29 | if not lookup in predictions: 30 | print("Missing prediction for item " + str(lookup)) 31 | exit() 32 | 33 | # Get the precision by iterating through the examples and checking the value 34 | # that was predicted. 35 | # Also update the "consistency" dictionary that keeps track of whether all 36 | # predictions for a given sentence were correct. 37 | num_correct = 0. 38 | consistency_dict = { } 39 | 40 | for example in labeled_examples: 41 | anon_label = example["identifier"].split("-") 42 | anon_label[2] = '' 43 | anon_label = '-'.join(anon_label) 44 | if not anon_label in consistency_dict: 45 | consistency_dict[anon_label] = True 46 | lookup = example["identifier"] 47 | prediction = predictions[lookup] 48 | if prediction.lower() == example["label"].lower(): 49 | num_correct += 1. 50 | else: 51 | consistency_dict[anon_label] = False 52 | 53 | # Calculate consistency. 54 | num_consistent = 0. 55 | unique_sentence = len(consistency_dict) 56 | for identifier, consistent in consistency_dict.items(): 57 | if consistent: 58 | num_consistent += 1 59 | 60 | # Report values. 61 | print("accuracy=" + str(num_correct / total_num)) 62 | print("consistency=" + str(num_consistent / unique_sentence)) 63 | -------------------------------------------------------------------------------- /launch_container.sh: -------------------------------------------------------------------------------- 1 | TXT_DB=$1 2 | IMG_DIR=$2 3 | OUTPUT=$3 4 | PRETRAIN_DIR=$4 5 | # TXT_RAW=$6 6 | 7 | if [ -z $CUDA_VISIBLE_DEVICES ]; then 8 | CUDA_VISIBLE_DEVICES='all' 9 | fi 10 | 11 | # if [ "$5" = "--prepro" ]; then 12 | # RO="" 13 | # else 14 | # RO=",readonly" 15 | # fi 16 | 17 | 18 | # docker run --gpus "device=$CUDA_VISIBLE_DEVICES" --ipc=host --rm -it \ 19 | # -e NCCL_IB_CUDA_SUPPORT=0 \ 20 | 21 | 22 | 23 | docker run -p 8892:8892 --runtime=nvidia --ipc=host --rm -it \ 24 | --mount src=$(pwd),dst=/src,type=bind \ 25 | --mount src=$OUTPUT,dst=/storage,type=bind \ 26 | --mount src=$PRETRAIN_DIR,dst=/pretrain,type=bind,readonly \ 27 | --mount src=$TXT_DB,dst=/db,type=bind$RO \ 28 | --mount src=$IMG_DIR,dst=/img,type=bind \ 29 | -e NVIDIA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ 30 | -e NCCL_IB_CUDA_SUPPORT=0 \ 31 | -w /src zmykevin/uc2:latest_version -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import UniterForPretraining, UniterConfig, VLXLMRForPretraining 2 | from .vqa import UniterForVisualQuestionAnswering, VLXLMRForVisualQuestionAnswering 3 | from .nlvr2 import (UniterForNlvr2PairedAttn, UniterForNlvr2Paired, 4 | UniterForNlvr2Triplet) 5 | from .itm import (UniterForImageTextRetrieval, 6 | UniterForImageTextRetrievalHardNeg, 7 | VLXLMRForImageTextRetrieval) 8 | from .ve import UniterForVisualEntailment 9 | -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/__init__.cpython-36.pyc 
-------------------------------------------------------------------------------- /model/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/attention.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/const_variable.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/const_variable.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/itm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/itm.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/layer.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/nlvr2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/nlvr2.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/ot.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/ot.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/ve.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/ve.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/vqa.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/vqa.cpython-36.pyc -------------------------------------------------------------------------------- /model/const_variable.py: -------------------------------------------------------------------------------- 1 | from transformers import XLMRobertaTokenizer 2 | import numpy as np 3 | import torch 4 | 5 | #Load the XLMR_TOKER 6 | XLMR_TOKER = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base') 7 | with open('object_labels/img_label_objects.txt', "r") as f: 8 | IMG_LABEL_OBJECTS = f.readlines() 9 | IMG_LABEL_OBJECTS = ['background'] + [x.strip() for x in IMG_LABEL_OBJECTS] 10 | 11 | 
#initialize LABEL2TOKEN_MATRIX 12 | VALID_XLMR_TOKEN_IDS = [] 13 | #LABEL2TOKEN = {} 14 | LABEL2TOKEN_MATRIX = np.zeros((1601, XLMR_TOKER.vocab_size)) 15 | for i,label_word in enumerate(IMG_LABEL_OBJECTS): 16 | label_tokens = XLMR_TOKER.tokenize(label_word) 17 | label_tokens_ids = [XLMR_TOKER._convert_token_to_id(w) for w in label_tokens] 18 | #LABEL2TOKEN[label_word] = label_tokens_ids 19 | LABEL2TOKEN_MATRIX[i][label_tokens_ids] = 1 20 | VALID_XLMR_TOKEN_IDS.extend(label_tokens_ids) 21 | #torch.tensor(LABEL2TOKEN_MATRIX) 22 | VALID_XLMR_TOKEN_IDS = list(set(VALID_XLMR_TOKEN_IDS)) 23 | VALID_XLMR_TOKEN_IDS.sort() 24 | -------------------------------------------------------------------------------- /model/gqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bert for VCR model 3 | """ 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from pytorch_pretrained_bert.modeling import ( 7 | BertOnlyMLMHead) 8 | from .model import (BertForImageTextPretraining, 9 | _get_image_hidden, 10 | mask_img_feat, 11 | RegionFeatureRegression, 12 | mask_img_feat_for_mrc, 13 | RegionClassification) 14 | import torch 15 | import random 16 | 17 | 18 | class BertForImageTextPretrainingForGQA(BertForImageTextPretraining): 19 | def init_type_embedding(self): 20 | new_emb = nn.Embedding(3, self.bert.config.hidden_size) 21 | new_emb.apply(self.init_bert_weights) 22 | for i in [0, 1]: 23 | emb = self.bert.embeddings.token_type_embeddings.weight.data[i, :] 24 | new_emb.weight.data[i, :].copy_(emb) 25 | emb = self.bert.embeddings.token_type_embeddings.weight.data[0, :] 26 | new_emb.weight.data[2, :].copy_(emb) 27 | self.bert.embeddings.token_type_embeddings = new_emb 28 | 29 | def forward(self, input_ids, position_ids, txt_type_ids, txt_lens, 30 | img_feat, img_pos_feat, num_bbs, 31 | attention_mask, labels, task, compute_loss=True): 32 | if task == 'mlm': 33 | txt_labels = labels 34 | return self.forward_mlm(input_ids, position_ids, txt_type_ids, 35 | txt_lens, 36 | img_feat, img_pos_feat, num_bbs, 37 | attention_mask, txt_labels, compute_loss) 38 | elif task == 'mrm': 39 | img_mask = labels 40 | return self.forward_mrm(input_ids, position_ids, txt_type_ids, 41 | txt_lens, 42 | img_feat, img_pos_feat, num_bbs, 43 | attention_mask, img_mask, compute_loss) 44 | elif task.startswith('mrc'): 45 | img_mask, mrc_label_target = labels 46 | return self.forward_mrc(input_ids, position_ids, txt_type_ids, 47 | txt_lens, 48 | img_feat, img_pos_feat, num_bbs, 49 | attention_mask, img_mask, 50 | mrc_label_target, task, compute_loss) 51 | else: 52 | raise ValueError('invalid task') 53 | 54 | # MLM 55 | def forward_mlm(self, input_ids, position_ids, txt_type_ids, txt_lens, 56 | img_feat, img_pos_feat, num_bbs, 57 | attention_mask, txt_labels, compute_loss=True): 58 | sequence_output = self.bert(input_ids, position_ids, txt_lens, 59 | img_feat, img_pos_feat, num_bbs, 60 | attention_mask, 61 | output_all_encoded_layers=False, 62 | txt_type_ids=txt_type_ids) 63 | # get only the text part 64 | sequence_output = sequence_output[:, :input_ids.size(1), :] 65 | # only compute masked tokens for better efficiency 66 | prediction_scores = self.masked_compute_scores( 67 | sequence_output, txt_labels != -1) 68 | if self.vocab_pad: 69 | prediction_scores = prediction_scores[:, :-self.vocab_pad] 70 | 71 | if compute_loss: 72 | masked_lm_loss = F.cross_entropy(prediction_scores, 73 | txt_labels[txt_labels != -1], 74 | reduction='none') 75 | return masked_lm_loss 76 | else: 77 | return 
prediction_scores 78 | 79 | # MRM 80 | def forward_mrm(self, input_ids, position_ids, txt_type_ids, txt_lens, 81 | img_feat, img_pos_feat, num_bbs, 82 | attention_mask, img_masks, compute_loss=True): 83 | img_feat, feat_targets = mask_img_feat(img_feat, img_masks) 84 | sequence_output = self.bert(input_ids, position_ids, txt_lens, 85 | img_feat, img_pos_feat, num_bbs, 86 | attention_mask, 87 | output_all_encoded_layers=False, 88 | txt_type_ids=txt_type_ids) 89 | # get only the text part 90 | sequence_output = _get_image_hidden(sequence_output, txt_lens, num_bbs) 91 | # only compute masked tokens for better efficiency 92 | prediction_feat = self.masked_compute_feat( 93 | sequence_output, img_masks) 94 | 95 | if compute_loss: 96 | mrm_loss = F.mse_loss(prediction_feat, feat_targets, 97 | reduction='none') 98 | return mrm_loss 99 | else: 100 | return prediction_feat 101 | 102 | # MRC 103 | def forward_mrc(self, input_ids, position_ids, txt_type_ids, txt_lens, 104 | img_feat, img_pos_feat, num_bbs, 105 | attention_mask, img_masks, 106 | label_targets, task, compute_loss=True): 107 | img_feat = mask_img_feat_for_mrc(img_feat, img_masks) 108 | sequence_output = self.bert(input_ids, position_ids, txt_lens, 109 | img_feat, img_pos_feat, num_bbs, 110 | attention_mask, 111 | output_all_encoded_layers=False, 112 | txt_type_ids=txt_type_ids) 113 | # get only the image part 114 | sequence_output = _get_image_hidden(sequence_output, txt_lens, num_bbs) 115 | # only compute masked tokens for better efficiency 116 | prediction_soft_label = self.masked_predict_labels( 117 | sequence_output, img_masks) 118 | 119 | if compute_loss: 120 | if "kl" in task: 121 | prediction_soft_label = F.log_softmax( 122 | prediction_soft_label, dim=-1) 123 | mrc_loss = F.kl_div( 124 | prediction_soft_label, label_targets, reduction='none') 125 | else: 126 | label_targets = torch.max( 127 | label_targets, -1)[1] # argmax 128 | mrc_loss = F.cross_entropy( 129 | prediction_soft_label, label_targets, 130 | ignore_index=0, reduction='none') 131 | return mrc_loss 132 | else: 133 | return prediction_soft_label 134 | -------------------------------------------------------------------------------- /model/itm.py: -------------------------------------------------------------------------------- 1 | """ 2 | UNITER for ITM model 3 | """ 4 | from collections import defaultdict 5 | 6 | import torch 7 | from torch import nn 8 | from .model import UniterPreTrainedModel, UniterModel, VLXLMRPreTrainedModel, VLXLMRModel 9 | 10 | #########################vl-xlmr######################## 11 | 12 | class VLXLMRForImageTextRetrieval(VLXLMRPreTrainedModel): 13 | """ Finetune UNITER for image text retrieval 14 | """ 15 | def __init__(self, config, img_dim, margin=0.2): 16 | super().__init__(config) 17 | self.roberta = VLXLMRModel(config, img_dim) 18 | self.itm_output = nn.Linear(config.hidden_size, 2) 19 | self.rank_output = nn.Linear(config.hidden_size, 1) 20 | self.margin = margin 21 | self.apply(self.init_weights) 22 | 23 | def init_output(self): 24 | """ need to be called after from pretrained """ 25 | self.rank_output.weight.data = self.itm_output.weight.data[1:, :] 26 | self.rank_output.bias.data = self.itm_output.bias.data[1:] 27 | 28 | def forward(self, batch, compute_loss=True): 29 | batch = defaultdict(lambda: None, batch) 30 | input_ids = batch['input_ids'] 31 | #position_ids = batch['position_ids'] 32 | position_ids=None 33 | img_feat = batch['img_feat'] 34 | img_pos_feat = batch['img_pos_feat'] 35 | attention_mask = batch['attn_masks'] 36 | 
gather_index = batch['gather_index'] 37 | 38 | sequence_output = self.roberta(input_ids, position_ids, 39 | img_feat, img_pos_feat, 40 | attention_mask, gather_index, 41 | output_all_encoded_layers=False) 42 | pooled_output = self.roberta.pooler(sequence_output) 43 | rank_scores = self.rank_output(pooled_output) 44 | 45 | if compute_loss: 46 | # triplet loss 47 | rank_scores_sigmoid = torch.sigmoid(rank_scores) 48 | sample_size = batch['sample_size'] 49 | scores = rank_scores_sigmoid.contiguous().view(-1, sample_size) 50 | pos = scores[:, :1] 51 | neg = scores[:, 1:] 52 | rank_loss = torch.clamp(self.margin + neg - pos, 0) 53 | return rank_loss 54 | else: 55 | return rank_scores 56 | ####################################################### 57 | class UniterForImageTextRetrieval(UniterPreTrainedModel): 58 | """ Finetune UNITER for image text retrieval 59 | """ 60 | def __init__(self, config, img_dim, margin=0.2): 61 | super().__init__(config) 62 | self.bert = UniterModel(config, img_dim) 63 | self.itm_output = nn.Linear(config.hidden_size, 2) 64 | #self.itm_output = nn.Linear(config.hidden_size, 1) 65 | self.rank_output = nn.Linear(config.hidden_size, 1) 66 | self.margin = margin 67 | self.apply(self.init_weights) 68 | 69 | def init_output(self): 70 | """ need to be called after from pretrained """ 71 | self.rank_output.weight.data = self.itm_output.weight.data[1:, :] 72 | self.rank_output.bias.data = self.itm_output.bias.data[1:] 73 | # self.rank_output.weight.data = self.itm_output.weight.data 74 | # self.rank_output.bias.data = self.itm_output.bias.data 75 | 76 | def forward(self, batch, compute_loss=True): 77 | batch = defaultdict(lambda: None, batch) 78 | input_ids = batch['input_ids'] 79 | position_ids = batch['position_ids'] 80 | img_feat = batch['img_feat'] 81 | img_pos_feat = batch['img_pos_feat'] 82 | attention_mask = batch['attn_masks'] 83 | gather_index = batch['gather_index'] 84 | sequence_output = self.bert(input_ids, position_ids, 85 | img_feat, img_pos_feat, 86 | attention_mask, gather_index, 87 | output_all_encoded_layers=False) 88 | pooled_output = self.bert.pooler(sequence_output) 89 | #print(pooled_output) 90 | rank_scores = self.rank_output(pooled_output) 91 | 92 | if compute_loss: 93 | # triplet loss 94 | rank_scores_sigmoid = torch.sigmoid(rank_scores) 95 | sample_size = batch['sample_size'] 96 | scores = rank_scores_sigmoid.contiguous().view(-1, sample_size) 97 | pos = scores[:, :1] 98 | neg = scores[:, 1:] 99 | rank_loss = torch.clamp(self.margin + neg - pos, 0) 100 | return rank_loss 101 | else: 102 | return rank_scores 103 | 104 | 105 | class UniterForImageTextRetrievalHardNeg(UniterForImageTextRetrieval): 106 | """ Finetune UNITER for image text retrieval 107 | """ 108 | def __init__(self, config, img_dim, margin=0.2, hard_size=16): 109 | super().__init__(config, img_dim, margin) 110 | self.hard_size = hard_size 111 | 112 | def forward(self, batch, sample_from='t', compute_loss=True): 113 | # expect same input_ids for all pairs 114 | batch_size = batch['attn_masks'].size(0) 115 | input_ids = batch['input_ids'] 116 | img_feat = batch['img_feat'] 117 | img_pos_feat = batch['img_pos_feat'] 118 | if sample_from == 't': 119 | if input_ids.size(0) == 1: 120 | batch['input_ids'] = input_ids.expand(batch_size, -1) 121 | elif sample_from == 'i': 122 | if img_feat.size(0) == 1: 123 | batch['img_feat'] = img_feat.expand(batch_size, -1, -1) 124 | if img_pos_feat.size(0) == 1: 125 | batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1) 126 | else: 127 | raise 
ValueError() 128 | 129 | if self.training and compute_loss: 130 | with torch.no_grad(): 131 | self.eval() 132 | scores = super().forward(batch, compute_loss=False) 133 | hard_batch = self._get_hard_batch(batch, scores, sample_from) 134 | self.train() 135 | return super().forward(hard_batch, compute_loss=True) 136 | else: 137 | return super().forward(batch, compute_loss) 138 | 139 | def _get_hard_batch(self, batch, scores, sample_from='t'): 140 | batch = defaultdict(lambda: None, batch) 141 | input_ids = batch['input_ids'] 142 | position_ids = batch['position_ids'] 143 | img_feat = batch['img_feat'] 144 | img_pos_feat = batch['img_pos_feat'] 145 | attention_mask = batch['attn_masks'] 146 | gather_index = batch['gather_index'] 147 | hard_batch = {'sample_size': self.hard_size + 1} 148 | 149 | # NOTE first example is positive 150 | hard_indices = scores.squeeze(-1)[1:].topk( 151 | self.hard_size, sorted=False)[1] + 1 152 | indices = torch.cat([torch.zeros(1, dtype=torch.long, 153 | device=hard_indices.device), 154 | hard_indices]) 155 | 156 | attention_mask = attention_mask.index_select(0, indices) 157 | gather_index = gather_index.index_select(0, indices) 158 | if position_ids.size(0) != 1: 159 | position_ids = position_ids[:self.hard_size+1] 160 | 161 | if sample_from == 't': 162 | # cut to minimum padding 163 | max_len = attention_mask.sum(dim=1).max().item() 164 | max_i = max_len - input_ids.size(1) 165 | attention_mask = attention_mask[:, :max_len] 166 | gather_index = gather_index[:, :max_len] 167 | img_feat = img_feat.index_select(0, indices)[:, :max_i, :] 168 | img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :] 169 | # expect same input_ids for all pairs 170 | input_ids = input_ids[:self.hard_size+1] 171 | elif sample_from == 'i': 172 | input_ids = input_ids.index_select(0, indices) 173 | # expect same image features for all pairs 174 | img_feat = img_feat[:self.hard_size+1] 175 | img_pos_feat = img_pos_feat[:self.hard_size+1] 176 | else: 177 | raise ValueError() 178 | 179 | hard_batch['input_ids'] = input_ids 180 | hard_batch['position_ids'] = position_ids 181 | hard_batch['img_feat'] = img_feat 182 | hard_batch['img_pos_feat'] = img_pos_feat 183 | hard_batch['attn_masks'] = attention_mask 184 | hard_batch['gather_index'] = gather_index 185 | 186 | return hard_batch 187 | -------------------------------------------------------------------------------- /model/layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | BERT layers from the huggingface implementation 3 | (https://github.com/huggingface/transformers) 4 | """ 5 | # coding=utf-8 6 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 7 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
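# ---------------------------------------------------------------------------
# Aside (added illustrative sketch, not part of the original layer.py): the
# hard-negative mining in model/itm.py's UniterForImageTextRetrievalHardNeg
# above boils down to two steps -- pick the top-scoring negatives (index 0 is
# the positive pair), then rank the positive against them with a margin loss.
# The random tensors below are stand-ins just to show the shapes; `hard_size`
# and `margin` mirror the constructor defaults.
def _demo_hard_negative_ranking():
    import torch
    hard_size, margin = 16, 0.2
    scores = torch.rand(64, 1)                    # ITM scores, positive at row 0
    hard = scores.squeeze(-1)[1:].topk(hard_size, sorted=False)[1] + 1
    indices = torch.cat([torch.zeros(1, dtype=torch.long), hard])
    # `indices` selects the positive plus its hard_size hardest negatives,
    # giving a hard batch with sample_size = hard_size + 1, as _get_hard_batch does.
    rank_scores = torch.sigmoid(torch.rand(len(indices), 1)).view(-1, hard_size + 1)
    pos, neg = rank_scores[:, :1], rank_scores[:, 1:]
    rank_loss = torch.clamp(margin + neg - pos, 0)   # same triplet loss as forward()
    return indices, rank_loss
# ---------------------------------------------------------------------------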
20 | import logging 21 | import math 22 | 23 | import torch 24 | from torch import nn 25 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | def gelu(x): 32 | """Implementation of the gelu activation function. 33 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 34 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 35 | Also see https://arxiv.org/abs/1606.08415 36 | """ 37 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 38 | 39 | 40 | def swish(x): 41 | return x * torch.sigmoid(x) 42 | 43 | 44 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 45 | 46 | 47 | class GELU(nn.Module): 48 | def forward(self, input_): 49 | output = gelu(input_) 50 | return output 51 | 52 | 53 | class BertSelfAttention(nn.Module): 54 | def __init__(self, config): 55 | super(BertSelfAttention, self).__init__() 56 | if config.hidden_size % config.num_attention_heads != 0: 57 | raise ValueError( 58 | "The hidden size (%d) is not a multiple of the number of attention " 59 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 60 | self.num_attention_heads = config.num_attention_heads 61 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 62 | self.all_head_size = self.num_attention_heads * self.attention_head_size 63 | 64 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 65 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 66 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 67 | 68 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 69 | 70 | def transpose_for_scores(self, x): 71 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 72 | x = x.view(*new_x_shape) 73 | return x.permute(0, 2, 1, 3) 74 | 75 | def forward(self, hidden_states, attention_mask): 76 | mixed_query_layer = self.query(hidden_states) 77 | mixed_key_layer = self.key(hidden_states) 78 | mixed_value_layer = self.value(hidden_states) 79 | 80 | query_layer = self.transpose_for_scores(mixed_query_layer) 81 | key_layer = self.transpose_for_scores(mixed_key_layer) 82 | value_layer = self.transpose_for_scores(mixed_value_layer) 83 | 84 | # Take the dot product between "query" and "key" to get the raw attention scores. 85 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 86 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 87 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 88 | attention_scores = attention_scores + attention_mask 89 | 90 | # Normalize the attention scores to probabilities. 91 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 92 | 93 | # This is actually dropping out entire tokens to attend to, which might 94 | # seem a bit unusual, but is taken from the original Transformer paper. 
95 | attention_probs = self.dropout(attention_probs) 96 | 97 | context_layer = torch.matmul(attention_probs, value_layer) 98 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 99 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 100 | context_layer = context_layer.view(*new_context_layer_shape) 101 | return context_layer 102 | 103 | 104 | class BertSelfOutput(nn.Module): 105 | def __init__(self, config): 106 | super(BertSelfOutput, self).__init__() 107 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 108 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 109 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 110 | 111 | def forward(self, hidden_states, input_tensor): 112 | hidden_states = self.dense(hidden_states) 113 | hidden_states = self.dropout(hidden_states) 114 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 115 | return hidden_states 116 | 117 | 118 | class BertAttention(nn.Module): 119 | def __init__(self, config): 120 | super(BertAttention, self).__init__() 121 | self.self = BertSelfAttention(config) 122 | self.output = BertSelfOutput(config) 123 | 124 | def forward(self, input_tensor, attention_mask): 125 | self_output = self.self(input_tensor, attention_mask) 126 | attention_output = self.output(self_output, input_tensor) 127 | return attention_output 128 | 129 | 130 | class BertIntermediate(nn.Module): 131 | def __init__(self, config): 132 | super(BertIntermediate, self).__init__() 133 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 134 | if isinstance(config.hidden_act, str): 135 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 136 | else: 137 | self.intermediate_act_fn = config.hidden_act 138 | 139 | def forward(self, hidden_states): 140 | hidden_states = self.dense(hidden_states) 141 | hidden_states = self.intermediate_act_fn(hidden_states) 142 | return hidden_states 143 | 144 | 145 | class BertOutput(nn.Module): 146 | def __init__(self, config): 147 | super(BertOutput, self).__init__() 148 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 149 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 150 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 151 | 152 | def forward(self, hidden_states, input_tensor): 153 | hidden_states = self.dense(hidden_states) 154 | hidden_states = self.dropout(hidden_states) 155 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 156 | return hidden_states 157 | 158 | 159 | class BertLayer(nn.Module): 160 | def __init__(self, config): 161 | super(BertLayer, self).__init__() 162 | self.attention = BertAttention(config) 163 | self.intermediate = BertIntermediate(config) 164 | self.output = BertOutput(config) 165 | 166 | def forward(self, hidden_states, attention_mask): 167 | attention_output = self.attention(hidden_states, attention_mask) 168 | intermediate_output = self.intermediate(attention_output) 169 | layer_output = self.output(intermediate_output, attention_output) 170 | return layer_output 171 | 172 | 173 | class BertPooler(nn.Module): 174 | def __init__(self, config): 175 | super(BertPooler, self).__init__() 176 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 177 | self.activation = nn.Tanh() 178 | 179 | def forward(self, hidden_states): 180 | # We "pool" the model by simply taking the hidden state corresponding 181 | # to the first token. 
182 | first_token_tensor = hidden_states[:, 0] 183 | pooled_output = self.dense(first_token_tensor) 184 | pooled_output = self.activation(pooled_output) 185 | return pooled_output 186 | 187 | 188 | class BertPredictionHeadTransform(nn.Module): 189 | def __init__(self, config): 190 | super(BertPredictionHeadTransform, self).__init__() 191 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 192 | if isinstance(config.hidden_act, str): 193 | self.transform_act_fn = ACT2FN[config.hidden_act] 194 | else: 195 | self.transform_act_fn = config.hidden_act 196 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 197 | 198 | def forward(self, hidden_states): 199 | hidden_states = self.dense(hidden_states) 200 | hidden_states = self.transform_act_fn(hidden_states) 201 | hidden_states = self.LayerNorm(hidden_states) 202 | return hidden_states 203 | 204 | 205 | class BertLMPredictionHead(nn.Module): 206 | def __init__(self, config, bert_model_embedding_weights): 207 | super(BertLMPredictionHead, self).__init__() 208 | self.transform = BertPredictionHeadTransform(config) 209 | 210 | # The output weights are the same as the input embeddings, but there is 211 | # an output-only bias for each token. 212 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 213 | bert_model_embedding_weights.size(0), 214 | bias=False) 215 | self.decoder.weight = bert_model_embedding_weights 216 | self.bias = nn.Parameter( 217 | torch.zeros(bert_model_embedding_weights.size(0))) 218 | 219 | def forward(self, hidden_states): 220 | hidden_states = self.transform(hidden_states) 221 | hidden_states = self.decoder(hidden_states) + self.bias 222 | return hidden_states 223 | 224 | 225 | class BertOnlyMLMHead(nn.Module): 226 | def __init__(self, config, bert_model_embedding_weights): 227 | super(BertOnlyMLMHead, self).__init__() 228 | self.predictions = BertLMPredictionHead(config, 229 | bert_model_embedding_weights) 230 | 231 | def forward(self, sequence_output): 232 | prediction_scores = self.predictions(sequence_output) 233 | return prediction_scores 234 | 235 | ####################Mingyang Zhou############################## 236 | class RobertaLMHead(nn.Module): 237 | """Roberta Head for masked language modeling.""" 238 | 239 | def __init__(self, config, roberta_model_embedding_weights): 240 | super().__init__() 241 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 242 | self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 243 | 244 | #self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 245 | self.decoder = nn.Linear(roberta_model_embedding_weights.size(1), 246 | roberta_model_embedding_weights.size(0), 247 | bias=False 248 | ) 249 | self.decoder.weight = roberta_model_embedding_weights 250 | self.bias = nn.Parameter( 251 | torch.zeros(roberta_model_embedding_weights.size(0))) 252 | #self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 253 | 254 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 255 | self.decoder.bias = self.bias 256 | 257 | def forward(self, features, **kwargs): 258 | x = self.dense(features) 259 | x = gelu(x) 260 | x = self.layer_norm(x) 261 | 262 | # project back to size of vocabulary with bias 263 | x = self.decoder(x) 264 | 265 | return x 266 | 267 | class VisualRobertaLMHead(nn.Module): 268 | def __init__(self, config, roberta_model_embedding_weights, valid_roberta_model_embedding_index): 269 | super().__init__() 270 | self.dense = 
nn.Linear(config.hidden_size, config.hidden_size) 271 | self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 272 | 273 | #self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 274 | self.decoder = nn.Linear(roberta_model_embedding_weights.size(1), 275 | len(valid_roberta_model_embedding_index), 276 | bias=False 277 | ) 278 | self.decoder.weight = nn.Parameter(roberta_model_embedding_weights[valid_roberta_model_embedding_index]) 279 | self.bias = nn.Parameter( 280 | torch.zeros(len(valid_roberta_model_embedding_index))) 281 | #self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 282 | 283 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 284 | self.decoder.bias = self.bias 285 | 286 | def forward(self, features, **kwargs): 287 | x = self.dense(features) 288 | x = gelu(x) 289 | x = self.layer_norm(x) 290 | 291 | # project back to size of vocabulary with bias 292 | x = self.decoder(x) 293 | 294 | return x -------------------------------------------------------------------------------- /model/nlvr2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Uniter for NLVR2 model 6 | """ 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from .layer import GELU 12 | from .model import UniterPreTrainedModel, UniterModel 13 | from .attention import MultiheadAttention 14 | 15 | 16 | class UniterForNlvr2Paired(UniterPreTrainedModel): 17 | """ Finetune UNITER for NLVR2 (paired format) 18 | """ 19 | def __init__(self, config, img_dim): 20 | super().__init__(config) 21 | self.bert = UniterModel(config, img_dim) 22 | self.nlvr2_output = nn.Linear(config.hidden_size*2, 2) 23 | self.apply(self.init_weights) 24 | 25 | def init_type_embedding(self): 26 | new_emb = nn.Embedding(3, self.bert.config.hidden_size) 27 | new_emb.apply(self.init_weights) 28 | for i in [0, 1]: 29 | emb = self.bert.embeddings.token_type_embeddings\ 30 | .weight.data[i, :] 31 | new_emb.weight.data[i, :].copy_(emb) 32 | new_emb.weight.data[2, :].copy_(emb) 33 | self.bert.embeddings.token_type_embeddings = new_emb 34 | 35 | def forward(self, input_ids, position_ids, img_feat, img_pos_feat, 36 | attn_masks, gather_index, 37 | img_type_ids, targets, compute_loss=True): 38 | sequence_output = self.bert(input_ids, position_ids, 39 | img_feat, img_pos_feat, 40 | attn_masks, gather_index, 41 | output_all_encoded_layers=False, 42 | img_type_ids=img_type_ids) 43 | pooled_output = self.bert.pooler(sequence_output) 44 | # concat CLS of the pair 45 | n_pair = pooled_output.size(0) // 2 46 | reshaped_output = pooled_output.contiguous().view(n_pair, -1) 47 | answer_scores = self.nlvr2_output(reshaped_output) 48 | 49 | if compute_loss: 50 | nlvr2_loss = F.cross_entropy( 51 | answer_scores, targets, reduction='none') 52 | return nlvr2_loss 53 | else: 54 | return answer_scores 55 | 56 | 57 | class UniterForNlvr2Triplet(UniterPreTrainedModel): 58 | """ Finetune UNITER for NLVR2 (triplet format) 59 | """ 60 | def __init__(self, config, img_dim): 61 | super().__init__(config) 62 | self.bert = UniterModel(config, img_dim) 63 | self.nlvr2_output = nn.Linear(config.hidden_size, 2) 64 | self.apply(self.init_weights) 65 | 66 | def init_type_embedding(self): 67 | new_emb = nn.Embedding(3, self.bert.config.hidden_size) 68 | new_emb.apply(self.init_weights) 69 | for i in [0, 1]: 70 | emb = 
self.bert.embeddings.token_type_embeddings\ 71 | .weight.data[i, :] 72 | new_emb.weight.data[i, :].copy_(emb) 73 | new_emb.weight.data[2, :].copy_(emb) 74 | self.bert.embeddings.token_type_embeddings = new_emb 75 | 76 | def forward(self, input_ids, position_ids, img_feat, img_pos_feat, 77 | attn_masks, gather_index, 78 | img_type_ids, targets, compute_loss=True): 79 | sequence_output = self.bert(input_ids, position_ids, 80 | img_feat, img_pos_feat, 81 | attn_masks, gather_index, 82 | output_all_encoded_layers=False, 83 | img_type_ids=img_type_ids) 84 | pooled_output = self.bert.pooler(sequence_output) 85 | answer_scores = self.nlvr2_output(pooled_output) 86 | 87 | if compute_loss: 88 | nlvr2_loss = F.cross_entropy( 89 | answer_scores, targets, reduction='none') 90 | return nlvr2_loss 91 | else: 92 | return answer_scores 93 | 94 | 95 | class AttentionPool(nn.Module): 96 | """ attention pooling layer """ 97 | def __init__(self, hidden_size, drop=0.0): 98 | super().__init__() 99 | self.fc = nn.Sequential(nn.Linear(hidden_size, 1), GELU()) 100 | self.dropout = nn.Dropout(drop) 101 | 102 | def forward(self, input_, mask=None): 103 | """input: [B, T, D], mask = [B, T]""" 104 | score = self.fc(input_).squeeze(-1) 105 | if mask is not None: 106 | mask = mask.to(dtype=input_.dtype) * -1e4 107 | score = score + mask 108 | norm_score = self.dropout(F.softmax(score, dim=1)) 109 | output = norm_score.unsqueeze(1).matmul(input_).squeeze(1) 110 | return output 111 | 112 | 113 | class UniterForNlvr2PairedAttn(UniterPreTrainedModel): 114 | """ Finetune UNITER for NLVR2 115 | (paired format with additional attention layer) 116 | """ 117 | def __init__(self, config, img_dim): 118 | super().__init__(config) 119 | self.bert = UniterModel(config, img_dim) 120 | self.attn1 = MultiheadAttention(config.hidden_size, 121 | config.num_attention_heads, 122 | config.attention_probs_dropout_prob) 123 | self.attn2 = MultiheadAttention(config.hidden_size, 124 | config.num_attention_heads, 125 | config.attention_probs_dropout_prob) 126 | self.fc = nn.Sequential( 127 | nn.Linear(2*config.hidden_size, config.hidden_size), 128 | GELU(), 129 | nn.Dropout(config.hidden_dropout_prob)) 130 | self.attn_pool = AttentionPool(config.hidden_size, 131 | config.attention_probs_dropout_prob) 132 | self.nlvr2_output = nn.Linear(2*config.hidden_size, 2) 133 | self.apply(self.init_weights) 134 | 135 | def init_type_embedding(self): 136 | new_emb = nn.Embedding(3, self.bert.config.hidden_size) 137 | new_emb.apply(self.init_weights) 138 | for i in [0, 1]: 139 | emb = self.bert.embeddings.token_type_embeddings\ 140 | .weight.data[i, :] 141 | new_emb.weight.data[i, :].copy_(emb) 142 | new_emb.weight.data[2, :].copy_(emb) 143 | self.bert.embeddings.token_type_embeddings = new_emb 144 | 145 | def forward(self, input_ids, position_ids, img_feat, img_pos_feat, 146 | attn_masks, gather_index, 147 | img_type_ids, targets, compute_loss=True): 148 | sequence_output = self.bert(input_ids, position_ids, 149 | img_feat, img_pos_feat, 150 | attn_masks, gather_index, 151 | output_all_encoded_layers=False, 152 | img_type_ids=img_type_ids) 153 | # separate left image and right image 154 | bs, tl, d = sequence_output.size() 155 | left_out, right_out = sequence_output.contiguous().view( 156 | bs//2, tl*2, d).chunk(2, dim=1) 157 | # bidirectional attention 158 | mask = attn_masks == 0 159 | left_mask, right_mask = mask.contiguous().view(bs//2, tl*2 160 | ).chunk(2, dim=1) 161 | left_out = left_out.transpose(0, 1) 162 | right_out = right_out.transpose(0, 1) 163 | 
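# Bidirectional cross-attention between the two images of the NLVR2 pair: the
# sequences were just transposed to a sequence-first (T, B, D) layout for the
# MultiheadAttention calls below.  attn1 uses the left sequence as queries and the
# right sequence as keys/values (left attends to right), attn2 does the reverse,
# and key_padding_mask (attn_masks == 0) marks padded positions so they are ignored.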
l2r_attn, _ = self.attn1(left_out, right_out, right_out, 164 | key_padding_mask=right_mask) 165 | r2l_attn, _ = self.attn2(right_out, left_out, left_out, 166 | key_padding_mask=left_mask) 167 | left_out = self.fc(torch.cat([l2r_attn, left_out], dim=-1) 168 | ).transpose(0, 1) 169 | right_out = self.fc(torch.cat([r2l_attn, right_out], dim=-1) 170 | ).transpose(0, 1) 171 | # attention pooling and final prediction 172 | left_out = self.attn_pool(left_out, left_mask) 173 | right_out = self.attn_pool(right_out, right_mask) 174 | answer_scores = self.nlvr2_output( 175 | torch.cat([left_out, right_out], dim=-1)) 176 | 177 | if compute_loss: 178 | nlvr2_loss = F.cross_entropy( 179 | answer_scores, targets, reduction='none') 180 | return nlvr2_loss 181 | else: 182 | return answer_scores 183 | -------------------------------------------------------------------------------- /model/ot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wasserstein Distance (Optimal Transport) 3 | """ 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | 8 | def cost_matrix_cosine(x, y, eps=1e-5): 9 | """ Compute cosine distnace across every pairs of x, y (batched) 10 | [B, L_x, D] [B, L_y, D] -> [B, Lx, Ly]""" 11 | assert x.dim() == y.dim() 12 | assert x.size(0) == y.size(0) 13 | assert x.size(2) == y.size(2) 14 | x_norm = F.normalize(x, p=2, dim=-1, eps=eps) 15 | y_norm = F.normalize(y, p=2, dim=-1, eps=eps) 16 | cosine_sim = x_norm.matmul(y_norm.transpose(1, 2)) 17 | cosine_dist = 1 - cosine_sim 18 | return cosine_dist 19 | 20 | 21 | def trace(x): 22 | """ compute trace of input tensor (batched) """ 23 | b, m, n = x.size() 24 | assert m == n 25 | mask = torch.eye(n, dtype=torch.uint8, device=x.device 26 | ).unsqueeze(0).expand_as(x) 27 | trace = x.masked_select(mask).contiguous().view( 28 | b, n).sum(dim=-1, keepdim=False) 29 | return trace 30 | 31 | 32 | @torch.no_grad() 33 | def ipot(C, x_len, x_pad, y_len, y_pad, joint_pad, beta, iteration, k): 34 | """ [B, M, N], [B], [B, M], [B], [B, N], [B, M, N]""" 35 | b, m, n = C.size() 36 | sigma = torch.ones(b, m, dtype=C.dtype, device=C.device 37 | ) / x_len.unsqueeze(1) 38 | T = torch.ones(b, n, m, dtype=C.dtype, device=C.device) 39 | A = torch.exp(-C.transpose(1, 2)/beta) 40 | 41 | # mask padded positions 42 | sigma.masked_fill_(x_pad, 0) 43 | joint_pad = joint_pad.transpose(1, 2) 44 | T.masked_fill_(joint_pad, 0) 45 | A.masked_fill_(joint_pad, 0) 46 | 47 | # broadcastable lengths 48 | x_len = x_len.unsqueeze(1).unsqueeze(2) 49 | y_len = y_len.unsqueeze(1).unsqueeze(2) 50 | 51 | # mask to zero out padding in delta and sigma 52 | x_mask = (x_pad.to(C.dtype) * 1e4).unsqueeze(1) 53 | y_mask = (y_pad.to(C.dtype) * 1e4).unsqueeze(1) 54 | 55 | for _ in range(iteration): 56 | Q = A * T # bs * n * m 57 | sigma = sigma.view(b, m, 1) 58 | for _ in range(k): 59 | delta = 1 / (y_len * Q.matmul(sigma).view(b, 1, n) + y_mask) 60 | sigma = 1 / (x_len * delta.matmul(Q) + x_mask) 61 | T = delta.view(b, n, 1) * Q * sigma 62 | T.masked_fill_(joint_pad, 0) 63 | return T 64 | 65 | 66 | def optimal_transport_dist(txt_emb, img_emb, txt_pad, img_pad, 67 | beta=0.5, iteration=50, k=1): 68 | """ [B, M, D], [B, N, D], [B, M], [B, N]""" 69 | cost = cost_matrix_cosine(txt_emb, img_emb) 70 | # mask the padded inputs 71 | joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2) 72 | cost.masked_fill_(joint_pad, 0) 73 | 74 | txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False) 75 | ).to(dtype=cost.dtype) 76 | img_len = 
(img_pad.size(1) - img_pad.sum(dim=1, keepdim=False) 77 | ).to(dtype=cost.dtype) 78 | 79 | T = ipot(cost.detach(), txt_len, txt_pad, img_len, img_pad, joint_pad, 80 | beta, iteration, k) 81 | distance = trace(cost.matmul(T.detach())) 82 | return distance 83 | -------------------------------------------------------------------------------- /model/re.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bert for Referring Expression Comprehension 3 | """ 4 | import sys 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertLayerNorm 9 | 10 | from .model import BertVisionLanguageEncoder 11 | 12 | import numpy as np 13 | import random 14 | 15 | 16 | class BertForReferringExpressionComprehension(BertPreTrainedModel): 17 | """Finetune multi-model BERT for Referring Expression Comprehension 18 | """ 19 | def __init__(self, config, img_dim, loss="cls", 20 | margin=0.2, hard_ratio=0.3, mlp=1): 21 | super().__init__(config) 22 | self.bert = BertVisionLanguageEncoder(config, img_dim) 23 | if mlp == 1: 24 | self.re_output = nn.Linear(config.hidden_size, 1) 25 | elif mlp == 2: 26 | self.re_output = nn.Sequential( 27 | nn.Linear(config.hidden_size, config.hidden_size), 28 | nn.ReLU(), 29 | BertLayerNorm(config.hidden_size, eps=1e-12), 30 | nn.Linear(config.hidden_size, 1) 31 | ) 32 | else: 33 | sys.exit("MLP restricted to be 1 or 2 layers.") 34 | self.loss = loss 35 | assert self.loss in ['cls', 'rank'] 36 | if self.loss == 'rank': 37 | self.margin = margin 38 | self.hard_ratio = hard_ratio 39 | else: 40 | self.crit = nn.CrossEntropyLoss(reduction='none') 41 | # initialize 42 | self.apply(self.init_bert_weights) 43 | 44 | def forward(self, input_ids, position_ids, txt_lens, img_feat, 45 | img_pos_feat, num_bbs, attn_masks, obj_masks, 46 | targets, compute_loss=True): 47 | sequence_output = self.bert(input_ids, position_ids, txt_lens, 48 | img_feat, img_pos_feat, num_bbs, 49 | attn_masks, 50 | output_all_encoded_layers=False) 51 | # get only the region part 52 | sequence_output = self._get_image_hidden(sequence_output, txt_lens, 53 | num_bbs) 54 | 55 | # re score (n, max_num_bb) 56 | scores = self.re_output(sequence_output).squeeze(2) 57 | scores = scores.masked_fill(obj_masks, -1e4) # mask out non-objects 58 | 59 | # loss 60 | if compute_loss: 61 | if self.loss == 'cls': 62 | ce_loss = self.crit(scores, targets) # (n, ) as no reduction 63 | return ce_loss 64 | else: 65 | # ranking 66 | _n = len(num_bbs) 67 | # positive (target) 68 | pos_ix = targets 69 | pos_sc = scores.gather(1, pos_ix.view(_n, 1)) # (n, 1) 70 | pos_sc = torch.sigmoid(pos_sc).view(-1) # (n, ) sc[0, 1] 71 | # negative 72 | neg_ix = self.sample_neg_ix(scores, targets, num_bbs) 73 | neg_sc = scores.gather(1, neg_ix.view(_n, 1)) # (n, 1) 74 | neg_sc = torch.sigmoid(neg_sc).view(-1) # (n, ) sc[0, 1] 75 | # ranking 76 | mm_loss = torch.clamp(self.margin + neg_sc - pos_sc, 0) # (n, ) 77 | return mm_loss 78 | else: 79 | # (n, max_num_bb) 80 | return scores 81 | 82 | def sample_neg_ix(self, scores, targets, num_bbs): 83 | """ 84 | Inputs: 85 | :scores (n, max_num_bb) 86 | :targets (n, ) 87 | :num_bbs list of [num_bb] 88 | return: 89 | :neg_ix (n, ) easy/hard negative (!= target) 90 | """ 91 | neg_ix = [] 92 | cand_ixs = torch.argsort(scores, dim=-1, descending=True) # (n, num_bb) 93 | for i in range(len(num_bbs)): 94 | num_bb = num_bbs[i] 95 | if np.random.uniform(0, 1, 1) < self.hard_ratio: 96 | # 
sample hard negative, w/ highest score 97 | for ix in cand_ixs[i].tolist(): 98 | if ix != targets[i]: 99 | assert ix < num_bb, f'ix={ix}, num_bb={num_bb}' 100 | neg_ix.append(ix) 101 | break 102 | else: 103 | # sample easy negative, i.e., random one 104 | ix = random.randint(0, num_bb-1) # [0, num_bb-1] 105 | while ix == targets[i]: 106 | ix = random.randint(0, num_bb-1) 107 | neg_ix.append(ix) 108 | neg_ix = torch.tensor(neg_ix).type(targets.type()) 109 | assert neg_ix.numel() == targets.numel() 110 | return neg_ix 111 | 112 | def _get_image_hidden(self, sequence_output, txt_lens, num_bbs): 113 | """ 114 | Extracting the img_hidden part from sequence_output. 115 | Inputs: 116 | - sequence_output: (n, txt_len+num_bb, hid_size) 117 | - txt_lens : [txt_len] 118 | - num_bbs : [num_bb] 119 | Output: 120 | - img_hidden : (n, max_num_bb, hid_size) 121 | """ 122 | outputs = [] 123 | max_bb = max(num_bbs) 124 | hid_size = sequence_output.size(-1) 125 | for seq_out, len_, nbb in zip(sequence_output.split(1, dim=0), 126 | txt_lens, num_bbs): 127 | img_hid = seq_out[:, len_:len_+nbb, :] 128 | if nbb < max_bb: 129 | img_hid = torch.cat( 130 | [img_hid, self._get_pad(img_hid, max_bb-nbb, hid_size)], 131 | dim=1) 132 | outputs.append(img_hid) 133 | 134 | img_hidden = torch.cat(outputs, dim=0) 135 | return img_hidden 136 | 137 | def _get_pad(self, t, len_, hidden_size): 138 | pad = torch.zeros(1, len_, hidden_size, dtype=t.dtype, device=t.device) 139 | return pad 140 | 141 | -------------------------------------------------------------------------------- /model/vcr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bert for VCR model 3 | """ 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from pytorch_pretrained_bert.modeling import ( 7 | BertPreTrainedModel, BertEmbeddings, BertEncoder, BertLayerNorm, 8 | BertPooler, BertOnlyMLMHead) 9 | from .model import (BertTextEmbeddings, BertImageEmbeddings, 10 | BertForImageTextMaskedLM, 11 | BertVisionLanguageEncoder, 12 | BertForImageTextPretraining, 13 | _get_image_hidden, 14 | mask_img_feat, 15 | RegionFeatureRegression, 16 | mask_img_feat_for_mrc, 17 | RegionClassification) 18 | import torch 19 | import random 20 | 21 | 22 | class BertVisionLanguageEncoderForVCR(BertVisionLanguageEncoder): 23 | """ Modification for Joint Vision-Language Encoding 24 | """ 25 | def __init__(self, config, img_dim, num_region_toks): 26 | BertPreTrainedModel.__init__(self, config) 27 | self.embeddings = BertTextEmbeddings(config) 28 | self.img_embeddings = BertImageEmbeddings(config, img_dim) 29 | self.num_region_toks = num_region_toks 30 | self.region_token_embeddings = nn.Embedding( 31 | num_region_toks, 32 | config.hidden_size) 33 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 34 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 35 | self.encoder = BertEncoder(config) 36 | self.pooler = BertPooler(config) 37 | self.apply(self.init_bert_weights) 38 | 39 | def forward(self, input_ids, position_ids, txt_lens, 40 | img_feat, img_pos_feat, num_bbs, 41 | attention_mask, output_all_encoded_layers=True, 42 | txt_type_ids=None, img_type_ids=None, region_tok_ids=None): 43 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 44 | extended_attention_mask = extended_attention_mask.to( 45 | dtype=next(self.parameters()).dtype) # fp16 compatibility 46 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 47 | 48 | embedding_output = self._compute_img_txt_embeddings( 49 
| input_ids, position_ids, txt_lens, 50 | img_feat, img_pos_feat, num_bbs, attention_mask.size(1), 51 | txt_type_ids, img_type_ids) 52 | if region_tok_ids is not None: 53 | region_tok_embeddings = self.region_token_embeddings( 54 | region_tok_ids) 55 | embedding_output += region_tok_embeddings 56 | embedding_output = self.LayerNorm(embedding_output) 57 | embedding_output = self.dropout(embedding_output) 58 | encoded_layers = self.encoder( 59 | embedding_output, extended_attention_mask, 60 | output_all_encoded_layers=output_all_encoded_layers) 61 | if not output_all_encoded_layers: 62 | encoded_layers = encoded_layers[-1] 63 | return encoded_layers 64 | 65 | 66 | class BertForVisualCommonsenseReasoning(BertPreTrainedModel): 67 | """ Finetune multi-modal BERT for ITM 68 | """ 69 | def __init__(self, config, img_dim, obj_cls=True, img_label_dim=81): 70 | super().__init__(config, img_dim) 71 | self.bert = BertVisionLanguageEncoder( 72 | config, img_dim) 73 | # self.vcr_output = nn.Linear(config.hidden_size, 1) 74 | # self.vcr_output = nn.Linear(config.hidden_size, 2) 75 | self.vcr_output = nn.Sequential( 76 | nn.Linear(config.hidden_size, config.hidden_size*2), 77 | nn.ReLU(), 78 | BertLayerNorm(config.hidden_size*2, eps=1e-12), 79 | nn.Linear(config.hidden_size*2, 2) 80 | ) 81 | self.apply(self.init_bert_weights) 82 | self.obj_cls = obj_cls 83 | if self.obj_cls: 84 | self.region_classifier = RegionClassification( 85 | config.hidden_size, img_label_dim) 86 | 87 | def init_type_embedding(self): 88 | new_emb = nn.Embedding(4, self.bert.config.hidden_size) 89 | new_emb.apply(self.init_bert_weights) 90 | for i in [0, 1]: 91 | emb = self.bert.embeddings.token_type_embeddings.weight.data[i, :] 92 | new_emb.weight.data[i, :].copy_(emb) 93 | emb = self.bert.embeddings.token_type_embeddings.weight.data[0, :] 94 | new_emb.weight.data[2, :].copy_(emb) 95 | new_emb.weight.data[3, :].copy_(emb) 96 | self.bert.embeddings.token_type_embeddings = new_emb 97 | 98 | def init_word_embedding(self, num_special_tokens): 99 | orig_word_num = self.bert.embeddings.word_embeddings.weight.size(0) 100 | new_emb = nn.Embedding( 101 | orig_word_num + num_special_tokens, self.bert.config.hidden_size) 102 | new_emb.apply(self.init_bert_weights) 103 | emb = self.bert.embeddings.word_embeddings.weight.data 104 | new_emb.weight.data[:orig_word_num, :].copy_(emb) 105 | self.bert.embeddings.word_embeddings = new_emb 106 | 107 | def masked_predict_labels(self, sequence_output, mask): 108 | # only compute masked outputs 109 | mask = mask.unsqueeze(-1).expand_as(sequence_output) 110 | sequence_output_masked = sequence_output[mask].contiguous().view( 111 | -1, self.config.hidden_size) 112 | prediction_soft_label = self.region_classifier(sequence_output_masked) 113 | 114 | return prediction_soft_label 115 | 116 | def forward(self, input_ids, position_ids, txt_lens, txt_type_ids, 117 | img_feat, img_pos_feat, num_bbs, 118 | attention_mask, targets, obj_targets=None, img_masks=None, 119 | region_tok_ids=None, compute_loss=True): 120 | sequence_output = self.bert(input_ids, position_ids, txt_lens, 121 | img_feat, img_pos_feat, num_bbs, 122 | attention_mask, 123 | output_all_encoded_layers=False, 124 | txt_type_ids=txt_type_ids) 125 | pooled_output = self.bert.pooler(sequence_output) 126 | rank_scores = self.vcr_output(pooled_output) 127 | # rank_scores = rank_scores.reshape((-1, 4)) 128 | 129 | if self.obj_cls and img_masks is not None: 130 | img_feat = mask_img_feat_for_mrc(img_feat, img_masks) 131 | masked_sequence_output = self.bert( 
132 | input_ids, position_ids, txt_lens, 133 | img_feat, img_pos_feat, num_bbs, 134 | attention_mask, 135 | output_all_encoded_layers=False, 136 | txt_type_ids=txt_type_ids) 137 | # get only the image part 138 | img_sequence_output = _get_image_hidden( 139 | masked_sequence_output, txt_lens, num_bbs) 140 | # only compute masked tokens for better efficiency 141 | predicted_obj_label = self.masked_predict_labels( 142 | img_sequence_output, img_masks) 143 | 144 | if compute_loss: 145 | vcr_loss = F.cross_entropy( 146 | rank_scores, targets.squeeze(-1), 147 | reduction='mean') 148 | if self.obj_cls: 149 | obj_cls_loss = F.cross_entropy( 150 | predicted_obj_label, obj_targets.long(), 151 | ignore_index=0, reduction='mean') 152 | else: 153 | obj_cls_loss = torch.tensor([0.], device=vcr_loss.device) 154 | return vcr_loss, obj_cls_loss 155 | else: 156 | rank_scores = rank_scores[:, 1:] 157 | return rank_scores 158 | 159 | 160 | class BertForImageTextPretrainingForVCR(BertForImageTextPretraining): 161 | def init_type_embedding(self): 162 | new_emb = nn.Embedding(4, self.bert.config.hidden_size) 163 | new_emb.apply(self.init_bert_weights) 164 | for i in [0, 1]: 165 | emb = self.bert.embeddings.token_type_embeddings.weight.data[i, :] 166 | new_emb.weight.data[i, :].copy_(emb) 167 | emb = self.bert.embeddings.token_type_embeddings.weight.data[0, :] 168 | new_emb.weight.data[2, :].copy_(emb) 169 | new_emb.weight.data[3, :].copy_(emb) 170 | self.bert.embeddings.token_type_embeddings = new_emb 171 | 172 | def init_word_embedding(self, num_special_tokens): 173 | orig_word_num = self.bert.embeddings.word_embeddings.weight.size(0) 174 | new_emb = nn.Embedding( 175 | orig_word_num + num_special_tokens, self.bert.config.hidden_size) 176 | new_emb.apply(self.init_bert_weights) 177 | emb = self.bert.embeddings.word_embeddings.weight.data 178 | new_emb.weight.data[:orig_word_num, :].copy_(emb) 179 | self.bert.embeddings.word_embeddings = new_emb 180 | self.cls = BertOnlyMLMHead( 181 | self.bert.config, self.bert.embeddings.word_embeddings.weight) 182 | 183 | def forward(self, input_ids, position_ids, txt_type_ids, txt_lens, 184 | img_feat, img_pos_feat, num_bbs, 185 | attention_mask, labels, task, compute_loss=True): 186 | if task == 'mlm': 187 | txt_labels = labels 188 | return self.forward_mlm(input_ids, position_ids, txt_type_ids, 189 | txt_lens, 190 | img_feat, img_pos_feat, num_bbs, 191 | attention_mask, txt_labels, compute_loss) 192 | elif task == 'mrm': 193 | img_mask = labels 194 | return self.forward_mrm(input_ids, position_ids, txt_type_ids, 195 | txt_lens, 196 | img_feat, img_pos_feat, num_bbs, 197 | attention_mask, img_mask, compute_loss) 198 | elif task.startswith('mrc'): 199 | img_mask, mrc_label_target = labels 200 | return self.forward_mrc(input_ids, position_ids, txt_type_ids, 201 | txt_lens, 202 | img_feat, img_pos_feat, num_bbs, 203 | attention_mask, img_mask, 204 | mrc_label_target, task, compute_loss) 205 | else: 206 | raise ValueError('invalid task') 207 | 208 | # MLM 209 | def forward_mlm(self, input_ids, position_ids, txt_type_ids, txt_lens, 210 | img_feat, img_pos_feat, num_bbs, 211 | attention_mask, txt_labels, compute_loss=True): 212 | sequence_output = self.bert(input_ids, position_ids, txt_lens, 213 | img_feat, img_pos_feat, num_bbs, 214 | attention_mask, 215 | output_all_encoded_layers=False, 216 | txt_type_ids=txt_type_ids) 217 | # get only the text part 218 | sequence_output = sequence_output[:, :input_ids.size(1), :] 219 | # only compute masked tokens for better efficiency 220 | 
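# txt_labels is -1 everywhere except at the masked positions, so (txt_labels != -1)
# selects exactly the tokens whose scores need to be predicted; any extra vocab
# padding (self.vocab_pad) is stripped from the logits before computing the loss.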
prediction_scores = self.masked_compute_scores( 221 | sequence_output, txt_labels != -1) 222 | if self.vocab_pad: 223 | prediction_scores = prediction_scores[:, :-self.vocab_pad] 224 | 225 | if compute_loss: 226 | masked_lm_loss = F.cross_entropy(prediction_scores, 227 | txt_labels[txt_labels != -1], 228 | reduction='none') 229 | return masked_lm_loss 230 | else: 231 | return prediction_scores 232 | 233 | # MRM 234 | def forward_mrm(self, input_ids, position_ids, txt_type_ids, txt_lens, 235 | img_feat, img_pos_feat, num_bbs, 236 | attention_mask, img_masks, compute_loss=True): 237 | img_feat, feat_targets = mask_img_feat(img_feat, img_masks) 238 | sequence_output = self.bert(input_ids, position_ids, txt_lens, 239 | img_feat, img_pos_feat, num_bbs, 240 | attention_mask, 241 | output_all_encoded_layers=False, 242 | txt_type_ids=txt_type_ids) 243 | # get only the text part 244 | sequence_output = _get_image_hidden(sequence_output, txt_lens, num_bbs) 245 | # only compute masked tokens for better efficiency 246 | prediction_feat = self.masked_compute_feat( 247 | sequence_output, img_masks) 248 | 249 | if compute_loss: 250 | mrm_loss = F.mse_loss(prediction_feat, feat_targets, 251 | reduction='none') 252 | return mrm_loss 253 | else: 254 | return prediction_feat 255 | 256 | # MRC 257 | def forward_mrc(self, input_ids, position_ids, txt_type_ids, txt_lens, 258 | img_feat, img_pos_feat, num_bbs, 259 | attention_mask, img_masks, 260 | label_targets, task, compute_loss=True): 261 | img_feat = mask_img_feat_for_mrc(img_feat, img_masks) 262 | sequence_output = self.bert(input_ids, position_ids, txt_lens, 263 | img_feat, img_pos_feat, num_bbs, 264 | attention_mask, 265 | output_all_encoded_layers=False, 266 | txt_type_ids=txt_type_ids) 267 | # get only the image part 268 | sequence_output = _get_image_hidden(sequence_output, txt_lens, num_bbs) 269 | # only compute masked tokens for better efficiency 270 | prediction_soft_label = self.masked_predict_labels( 271 | sequence_output, img_masks) 272 | 273 | if compute_loss: 274 | if "kl" in task: 275 | prediction_soft_label = F.log_softmax( 276 | prediction_soft_label, dim=-1) 277 | mrc_loss = F.kl_div( 278 | prediction_soft_label, label_targets, reduction='none') 279 | else: 280 | label_targets = torch.max( 281 | label_targets, -1)[1] # argmax 282 | mrc_loss = F.cross_entropy( 283 | prediction_soft_label, label_targets, 284 | ignore_index=0, reduction='none') 285 | return mrc_loss 286 | else: 287 | return prediction_soft_label 288 | -------------------------------------------------------------------------------- /model/ve.py: -------------------------------------------------------------------------------- 1 | """ 2 | UNITER for VE model 3 | """ 4 | from .vqa import UniterForVisualQuestionAnswering 5 | 6 | 7 | class UniterForVisualEntailment(UniterForVisualQuestionAnswering): 8 | """ Finetune multi-modal BERT for VE 9 | """ 10 | def __init__(self, config, img_dim): 11 | super().__init__(config, img_dim, 3) 12 | -------------------------------------------------------------------------------- /model/vqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bert for VQA model 3 | """ 4 | from collections import defaultdict 5 | 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm 9 | 10 | from .layer import GELU 11 | from .model import UniterPreTrainedModel, UniterModel, VLXLMRPreTrainedModel, VLXLMRModel 12 | 13 | 14 | class 
VLXLMRForVisualQuestionAnswering(VLXLMRPreTrainedModel): 15 | """ Finetune multi-modal BERT for VQA 16 | """ 17 | def __init__(self, config, img_dim, num_answer): 18 | super().__init__(config) 19 | self.roberta = VLXLMRModel(config, img_dim) 20 | self.vqa_output = nn.Sequential( 21 | nn.Linear(config.hidden_size, config.hidden_size*2), 22 | GELU(), 23 | LayerNorm(config.hidden_size*2, eps=config.layer_norm_eps), 24 | nn.Linear(config.hidden_size*2, num_answer) 25 | ) 26 | self.apply(self.init_weights) 27 | 28 | def forward(self, batch, compute_loss=True): 29 | batch = defaultdict(lambda: None, batch) 30 | input_ids = batch['input_ids'] 31 | #position_ids = batch['position_ids'] 32 | position_ids=None 33 | img_feat = batch['img_feat'] 34 | img_pos_feat = batch['img_pos_feat'] 35 | attn_masks = batch['attn_masks'] 36 | gather_index = batch['gather_index'] 37 | sequence_output = self.roberta(input_ids, position_ids, 38 | img_feat, img_pos_feat, 39 | attn_masks, gather_index, 40 | output_all_encoded_layers=False) 41 | pooled_output = self.roberta.pooler(sequence_output) 42 | answer_scores = self.vqa_output(pooled_output) 43 | 44 | if compute_loss: 45 | targets = batch['targets'] 46 | vqa_loss = F.binary_cross_entropy_with_logits( 47 | answer_scores, targets, reduction='none') 48 | return vqa_loss 49 | else: 50 | return answer_scores 51 | 52 | class UniterForVisualQuestionAnswering(UniterPreTrainedModel): 53 | """ Finetune multi-modal BERT for VQA 54 | """ 55 | def __init__(self, config, img_dim, num_answer): 56 | super().__init__(config) 57 | self.bert = UniterModel(config, img_dim) 58 | self.vqa_output = nn.Sequential( 59 | nn.Linear(config.hidden_size, config.hidden_size*2), 60 | GELU(), 61 | LayerNorm(config.hidden_size*2, eps=1e-12), 62 | nn.Linear(config.hidden_size*2, num_answer) 63 | ) 64 | self.apply(self.init_weights) 65 | 66 | def forward(self, batch, compute_loss=True): 67 | batch = defaultdict(lambda: None, batch) 68 | input_ids = batch['input_ids'] 69 | position_ids = batch['position_ids'] 70 | img_feat = batch['img_feat'] 71 | img_pos_feat = batch['img_pos_feat'] 72 | attn_masks = batch['attn_masks'] 73 | gather_index = batch['gather_index'] 74 | sequence_output = self.bert(input_ids, position_ids, 75 | img_feat, img_pos_feat, 76 | attn_masks, gather_index, 77 | output_all_encoded_layers=False) 78 | pooled_output = self.bert.pooler(sequence_output) 79 | answer_scores = self.vqa_output(pooled_output) 80 | 81 | if compute_loss: 82 | targets = batch['targets'] 83 | vqa_loss = F.binary_cross_entropy_with_logits( 84 | answer_scores, targets, reduction='none') 85 | return vqa_loss 86 | else: 87 | return answer_scores 88 | -------------------------------------------------------------------------------- /optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .sched import noam_schedule, warmup_linear, vqa_schedule, get_lr_sched, get_xlmr_lr_sched 2 | from .adamw import AdamW 3 | -------------------------------------------------------------------------------- /optim/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/optim/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /optim/__pycache__/adamw.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/optim/__pycache__/adamw.cpython-36.pyc -------------------------------------------------------------------------------- /optim/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/optim/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /optim/__pycache__/sched.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/optim/__pycache__/sched.cpython-36.pyc -------------------------------------------------------------------------------- /optim/adamw.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamW optimizer (weight decay fix) 3 | copied from hugginface 4 | """ 5 | import math 6 | 7 | import torch 8 | from torch.optim import Optimizer 9 | 10 | 11 | class AdamW(Optimizer): 12 | """ Implements Adam algorithm with weight decay fix. 13 | Parameters: 14 | lr (float): learning rate. Default 1e-3. 15 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). 16 | Default: (0.9, 0.999) 17 | eps (float): Adams epsilon. Default: 1e-6 18 | weight_decay (float): Weight decay. Default: 0.0 19 | correct_bias (bool): can be set to False to avoid correcting bias 20 | in Adam (e.g. like in Bert TF repository). Default True. 21 | """ 22 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 23 | weight_decay=0.0, correct_bias=True): 24 | if lr < 0.0: 25 | raise ValueError( 26 | "Invalid learning rate: {} - should be >= 0.0".format(lr)) 27 | if not 0.0 <= betas[0] < 1.0: 28 | raise ValueError("Invalid beta parameter: {} - " 29 | "should be in [0.0, 1.0[".format(betas[0])) 30 | if not 0.0 <= betas[1] < 1.0: 31 | raise ValueError("Invalid beta parameter: {} - " 32 | "should be in [0.0, 1.0[".format(betas[1])) 33 | if not 0.0 <= eps: 34 | raise ValueError("Invalid epsilon value: {} - " 35 | "should be >= 0.0".format(eps)) 36 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 37 | correct_bias=correct_bias) 38 | super(AdamW, self).__init__(params, defaults) 39 | 40 | def step(self, closure=None): 41 | """Performs a single optimization step. 42 | Arguments: 43 | closure (callable, optional): A closure that reevaluates the model 44 | and returns the loss. 
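Returns the loss evaluated by `closure`, or None if no closure is given.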
45 | """ 46 | loss = None 47 | if closure is not None: 48 | loss = closure() 49 | 50 | for group in self.param_groups: 51 | for p in group['params']: 52 | if p.grad is None: 53 | continue 54 | grad = p.grad.data 55 | if grad.is_sparse: 56 | raise RuntimeError( 57 | 'Adam does not support sparse ' 58 | 'gradients, please consider SparseAdam instead') 59 | 60 | state = self.state[p] 61 | 62 | # State initialization 63 | if len(state) == 0: 64 | state['step'] = 0 65 | # Exponential moving average of gradient values 66 | state['exp_avg'] = torch.zeros_like(p.data) 67 | # Exponential moving average of squared gradient values 68 | state['exp_avg_sq'] = torch.zeros_like(p.data) 69 | 70 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 71 | beta1, beta2 = group['betas'] 72 | 73 | state['step'] += 1 74 | 75 | # Decay the first and second moment running average coefficient 76 | # In-place operations to update the averages at the same time 77 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 78 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 79 | denom = exp_avg_sq.sqrt().add_(group['eps']) 80 | 81 | step_size = group['lr'] 82 | if group['correct_bias']: # No bias correction for Bert 83 | bias_correction1 = 1.0 - beta1 ** state['step'] 84 | bias_correction2 = 1.0 - beta2 ** state['step'] 85 | step_size = (step_size * math.sqrt(bias_correction2) 86 | / bias_correction1) 87 | 88 | p.data.addcdiv_(-step_size, exp_avg, denom) 89 | 90 | # Just adding the square of the weights to the loss function is 91 | # *not* the correct way of using L2 regularization/weight decay 92 | # with Adam, since that will interact with the m and v 93 | # parameters in strange ways. 94 | # 95 | # Instead we want to decay the weights in a manner that doesn't 96 | # interact with the m/v parameters. This is equivalent to 97 | # adding the square of the weights to the loss with plain 98 | # (non-momentum) SGD. 
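# Concretely, the decoupled update below is p <- p - lr * weight_decay * p,
# applied after the Adam step rather than folded into the gradient.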
99 | # Add weight decay at the end (fixed version) 100 | if group['weight_decay'] > 0.0: 101 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 102 | 103 | return loss 104 | -------------------------------------------------------------------------------- /optim/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc lr helper 3 | """ 4 | from torch.optim import Adam, Adamax 5 | 6 | from .adamw import AdamW 7 | 8 | 9 | def build_optimizer(model, opts): 10 | param_optimizer = list(model.named_parameters()) 11 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 12 | optimizer_grouped_parameters = [ 13 | {'params': [p for n, p in param_optimizer 14 | if not any(nd in n for nd in no_decay)], 15 | 'weight_decay': opts.weight_decay}, 16 | {'params': [p for n, p in param_optimizer 17 | if any(nd in n for nd in no_decay)], 18 | 'weight_decay': 0.0} 19 | ] 20 | 21 | # currently Adam only 22 | if opts.optim == 'adam': 23 | OptimCls = Adam 24 | elif opts.optim == 'adamax': 25 | OptimCls = Adamax 26 | elif opts.optim == 'adamw': 27 | OptimCls = AdamW 28 | else: 29 | raise ValueError('invalid optimizer') 30 | optimizer = OptimCls(optimizer_grouped_parameters, 31 | lr=opts.learning_rate, betas=opts.betas) 32 | return optimizer 33 | 34 | def xlmr_pretrained_encoder_layer(n, load_layer): 35 | """ 36 | Verify whether n loads the pretrained weights from xlmr 37 | """ 38 | assert isinstance(load_layer, int) 39 | #print(n) 40 | if 'roberta.encoder' in n: 41 | if int(n.split('.')[3]) <= load_layer: 42 | return True 43 | elif 'roberta.embeddings' in n: 44 | return True 45 | 46 | return False 47 | 48 | def build_xlmr_optimizer(model, opts): 49 | param_optimizer = list(model.named_parameters()) 50 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 51 | #Divide the parameters to four groups, we will apply a relatively small lr to the word embedding 52 | #assert opts.load_embedding_only or isinstance(opts.load_layer, int) 53 | if opts.load_layer: 54 | assert isinstance(opts.load_layer,int) and opts.load_layer > 0 55 | optimizer_grouped_parameters = [ 56 | {'params': [p for n, p in param_optimizer 57 | if not any(nd in n for nd in no_decay) and xlmr_pretrained_encoder_layer(n, opts.load_layer)], 58 | 'weight_decay': opts.weight_decay, 'lr': opts.xlmr_lr}, 59 | {'params': [p for n, p in param_optimizer 60 | if any(nd in n for nd in no_decay) and xlmr_pretrained_encoder_layer(n, opts.load_layer)], 61 | 'weight_decay': 0.0, 'lr': opts.xlmr_lr}, 62 | {'params': [p for n, p in param_optimizer 63 | if not any(nd in n for nd in no_decay) and not xlmr_pretrained_encoder_layer(n, opts.load_layer)], 64 | 'weight_decay': opts.weight_decay, 'lr': opts.learning_rate}, 65 | {'params': [p for n, p in param_optimizer 66 | if any(nd in n for nd in no_decay) and not xlmr_pretrained_encoder_layer(n, opts.load_layer)], 67 | 'weight_decay': 0.0, 'lr': opts.learning_rate}, 68 | ] 69 | else: 70 | #general case 71 | optimizer_grouped_parameters = [ 72 | {'params': [p for n, p in param_optimizer 73 | if not any(nd in n for nd in no_decay) and 'roberta.embeddings' in n], 74 | 'weight_decay': opts.weight_decay, 'lr': opts.xlmr_lr}, 75 | {'params': [p for n, p in param_optimizer 76 | if any(nd in n for nd in no_decay) and 'roberta.embeddings' in n], 77 | 'weight_decay': 0.0, 'lr': opts.xlmr_lr}, 78 | {'params': [p for n, p in param_optimizer 79 | if not any(nd in n for nd in no_decay) and 'roberta.embeddings' not in n], 80 | 'weight_decay': opts.weight_decay, 'lr': 
opts.learning_rate}, 81 | {'params': [p for n, p in param_optimizer 82 | if any(nd in n for nd in no_decay) and 'roberta.embeddings' not in n], 83 | 'weight_decay': 0.0, 'lr': opts.learning_rate}, 84 | ] 85 | 86 | 87 | #print([n for n, p in param_optimizer if not any(nd in n for nd in no_decay) and 'roberta.embeddings' in n]) 88 | # currently Adam only 89 | if opts.optim == 'adam': 90 | OptimCls = Adam 91 | elif opts.optim == 'adamax': 92 | OptimCls = Adamax 93 | elif opts.optim == 'adamw': 94 | OptimCls = AdamW 95 | else: 96 | raise ValueError('invalid optimizer') 97 | 98 | optimizer = OptimCls(optimizer_grouped_parameters, 99 | lr=opts.learning_rate, betas=opts.betas) 100 | return optimizer 101 | 102 | 103 | -------------------------------------------------------------------------------- /optim/sched.py: -------------------------------------------------------------------------------- 1 | """ 2 | optimizer learning rate scheduling helpers 3 | """ 4 | from math import ceil 5 | 6 | 7 | def noam_schedule(step, warmup_step=4000): 8 | if step <= warmup_step: 9 | return step / warmup_step 10 | return (warmup_step ** 0.5) * (step ** -0.5) 11 | 12 | 13 | def warmup_linear(step, warmup_step, tot_step): 14 | if step < warmup_step: 15 | return step / warmup_step 16 | return max(0, (tot_step-step)/(tot_step-warmup_step)) 17 | 18 | 19 | def vqa_schedule(step, warmup_interval, decay_interval, 20 | decay_start, decay_rate): 21 | """ VQA schedule from MCAN """ 22 | if step < warmup_interval: 23 | return 1/4 24 | elif step < 2 * warmup_interval: 25 | return 2/4 26 | elif step < 3 * warmup_interval: 27 | return 3/4 28 | elif step >= decay_start: 29 | num_decay = ceil((step - decay_start) / decay_interval) 30 | return decay_rate ** num_decay 31 | else: 32 | return 1 33 | 34 | 35 | def get_lr_sched(global_step, opts): 36 | # learning rate scheduling 37 | if opts.decay == 'linear': 38 | lr_this_step = opts.learning_rate * warmup_linear( 39 | global_step, opts.warmup_steps, opts.num_train_steps) 40 | elif opts.decay == 'invsqrt': 41 | lr_this_step = opts.learning_rate * noam_schedule( 42 | global_step, opts.warmup_steps) 43 | elif opts.decay == 'constant': 44 | lr_this_step = opts.learning_rate 45 | elif opts.decay == 'vqa': 46 | lr_this_step = opts.learning_rate * vqa_schedule( 47 | global_step, opts.warm_int, opts.decay_int, 48 | opts.decay_st, opts.decay_rate) 49 | if lr_this_step <= 0: 50 | # save guard for possible miscalculation of train steps 51 | lr_this_step = 1e-8 52 | return lr_this_step 53 | 54 | def get_xlmr_lr_sched(global_step, opts): 55 | if opts.decay == 'linear': 56 | lr_this_step = opts.xlmr_lr * warmup_linear( 57 | global_step, opts.warmup_steps, opts.num_train_steps) 58 | elif opts.decay == 'invsqrt': 59 | lr_this_step = opts.xlmr_lr * noam_schedule( 60 | global_step, opts.warmup_steps) 61 | elif opts.decay == 'constant': 62 | lr_this_step = opts.xlmr_lr 63 | elif opts.decay == 'vqa': 64 | lr_this_step = opts.xlmr_lr * vqa_schedule( 65 | global_step, opts.warm_int, opts.decay_int, 66 | opts.decay_st, opts.decay_rate) 67 | if lr_this_step <= 0: 68 | # save guard for possible miscalculation of train steps 69 | lr_this_step = 1e-8 70 | return lr_this_step 71 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__init__.py 
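The schedulers in optim/sched.py return a plain learning-rate value for the current step instead of wrapping a torch scheduler, so a training loop is expected to write that value back into the optimizer's parameter groups before each update. A minimal sketch of that pattern follows; the option values are illustrative assumptions for this sketch, not settings taken from this repo's configs.

from types import SimpleNamespace

from optim.sched import get_lr_sched

# illustrative option values for this sketch, not taken from the repo's configs
opts = SimpleNamespace(decay='linear', learning_rate=8e-5,
                       warmup_steps=1000, num_train_steps=10000)

def set_lr(optimizer, global_step, opts):
    # compute this step's learning rate and push it into every parameter group
    lr_this_step = get_lr_sched(global_step, opts)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_this_step
    return lr_this_step

# warmup then linear decay: ramps up over the first 1000 steps, then decays
# toward zero (floored at 1e-8 by get_lr_sched's safeguard)
for step in (1, 500, 1000, 5000, 10000):
    print(step, get_lr_sched(step, opts))

get_xlmr_lr_sched applies the same decay rules to opts.xlmr_lr, which matches the lower-learning-rate parameter groups that build_xlmr_optimizer creates for the pretrained XLM-R embedding/encoder weights.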
-------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/const.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/const.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/distributed.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/distributed.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/save.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/save.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/visual_entailment.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/visual_entailment.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/vqa.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/vqa.cpython-36.pyc -------------------------------------------------------------------------------- /utils/const.py: -------------------------------------------------------------------------------- 1 | """ constants """ 2 | IMG_DIM = 2048 3 | IMG_LABEL_DIM = 1601 4 | BUCKET_SIZE = 8192 5 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | distributed API using Horovod 3 | """ 4 | import math 5 | import pickle 6 | 7 | import torch 8 | from horovod import torch as hvd 9 | 10 | import msgpack 11 | import msgpack_numpy 12 | msgpack_numpy.patch() 13 | 14 | 15 | def all_reduce_and_rescale_tensors(tensors, rescale_denom): 16 | """All-reduce and rescale tensors at once (as a flattened tensor) 17 | 18 | Args: 19 | tensors: list of Tensors to 
all-reduce 20 | rescale_denom: denominator for rescaling summed Tensors 21 | """ 22 | # buffer size in bytes, determine equiv. # of elements based on data type 23 | sz = sum(t.numel() for t in tensors) 24 | buffer_t = tensors[0].new(sz).zero_() 25 | 26 | # copy tensors into buffer_t 27 | offset = 0 28 | for t in tensors: 29 | numel = t.numel() 30 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 31 | offset += numel 32 | 33 | # all-reduce and rescale 34 | hvd.allreduce_(buffer_t[:offset]) 35 | buffer_t.div_(rescale_denom) 36 | 37 | # copy all-reduced buffer back into tensors 38 | offset = 0 39 | for t in tensors: 40 | numel = t.numel() 41 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 42 | offset += numel 43 | 44 | 45 | def all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom, 46 | buffer_size=10485760): 47 | """All-reduce and rescale tensors in chunks of the specified size. 48 | 49 | Args: 50 | tensors: list of Tensors to all-reduce 51 | rescale_denom: denominator for rescaling summed Tensors 52 | buffer_size: all-reduce chunk size in bytes 53 | """ 54 | # buffer size in bytes, determine equiv. # of elements based on data type 55 | buffer_t = tensors[0].new( 56 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 57 | buffer = [] 58 | 59 | def all_reduce_buffer(): 60 | # copy tensors into buffer_t 61 | offset = 0 62 | for t in buffer: 63 | numel = t.numel() 64 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 65 | offset += numel 66 | 67 | # all-reduce and rescale 68 | hvd.allreduce_(buffer_t[:offset]) 69 | buffer_t.div_(rescale_denom) 70 | 71 | # copy all-reduced buffer back into tensors 72 | offset = 0 73 | for t in buffer: 74 | numel = t.numel() 75 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 76 | offset += numel 77 | 78 | filled = 0 79 | for t in tensors: 80 | sz = t.numel() * t.element_size() 81 | if sz > buffer_size: 82 | # tensor is bigger than buffer, all-reduce and rescale directly 83 | hvd.allreduce_(t) 84 | t.div_(rescale_denom) 85 | elif filled + sz > buffer_size: 86 | # buffer is full, all-reduce and replace buffer with grad 87 | all_reduce_buffer() 88 | buffer = [t] 89 | filled = sz 90 | else: 91 | # add tensor to buffer 92 | buffer.append(t) 93 | filled += sz 94 | 95 | if len(buffer) > 0: 96 | all_reduce_buffer() 97 | 98 | 99 | def broadcast_tensors(tensors, root_rank, buffer_size=10485760): 100 | """broadcast tensors in chunks of the specified size. 101 | 102 | Args: 103 | tensors: list of Tensors to broadcast 104 | root_rank: rank to broadcast 105 | buffer_size: broadcast chunk size in bytes 106 | """ 107 | # buffer size in bytes, determine equiv. 
# of elements based on data type 108 | buffer_t = tensors[0].new( 109 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 110 | buffer = [] 111 | 112 | def broadcast_buffer(): 113 | # copy tensors into buffer_t 114 | offset = 0 115 | for t in buffer: 116 | numel = t.numel() 117 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 118 | offset += numel 119 | 120 | # broadcast 121 | hvd.broadcast_(buffer_t[:offset], root_rank) 122 | 123 | # copy all-reduced buffer back into tensors 124 | offset = 0 125 | for t in buffer: 126 | numel = t.numel() 127 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 128 | offset += numel 129 | 130 | filled = 0 131 | for t in tensors: 132 | sz = t.numel() * t.element_size() 133 | if sz > buffer_size: 134 | # tensor is bigger than buffer, broadcast directly 135 | hvd.broadcast_(t, root_rank) 136 | elif filled + sz > buffer_size: 137 | # buffer is full, broadcast and replace buffer with tensor 138 | broadcast_buffer() 139 | buffer = [t] 140 | filled = sz 141 | else: 142 | # add tensor to buffer 143 | buffer.append(t) 144 | filled += sz 145 | 146 | if len(buffer) > 0: 147 | broadcast_buffer() 148 | 149 | 150 | def _encode(enc, max_size, buffer_=None): 151 | enc_size = len(enc) 152 | enc_byte = max(math.floor(math.log(max_size, 256)+1), 1) 153 | if buffer_ is None or len(buffer_) < enc_size + enc_byte: 154 | buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte) 155 | remainder = enc_size 156 | for i in range(enc_byte): 157 | base = 256 ** (enc_byte-i-1) 158 | buffer_[i] = remainder // base 159 | remainder %= base 160 | buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc)) 161 | return buffer_, enc_byte 162 | 163 | 164 | def _decode(buffer_, enc_byte): 165 | size = sum(256 ** (enc_byte-i-1) * buffer_[i].item() 166 | for i in range(enc_byte)) 167 | bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist()) 168 | shift = size + enc_byte 169 | return bytes_list, shift 170 | 171 | 172 | _BUFFER_SIZE = 4096 173 | 174 | 175 | def all_gather_list(data): 176 | """Gathers arbitrary data from all nodes into a list.""" 177 | if not hasattr(all_gather_list, '_buffer'): 178 | # keeps small buffer to avoid re-allocate every call 179 | all_gather_list._buffer = torch.cuda.ByteTensor(_BUFFER_SIZE) 180 | try: 181 | enc = msgpack.dumps(data, use_bin_type=True) 182 | msgpack_success = True 183 | except TypeError: 184 | enc = pickle.dumps(data) 185 | msgpack_success = False 186 | 187 | enc_size = len(enc) 188 | max_size = hvd.allgather(torch.tensor([enc_size]).cuda()).max().item() 189 | buffer_ = all_gather_list._buffer 190 | in_buffer, enc_byte = _encode(enc, max_size, buffer_) 191 | 192 | out_buffer = hvd.allgather(in_buffer[:enc_byte+enc_size]) 193 | 194 | results = [] 195 | for _ in range(hvd.size()): 196 | bytes_list, shift = _decode(out_buffer, enc_byte) 197 | out_buffer = out_buffer[shift:] 198 | 199 | if msgpack_success: 200 | result = msgpack.loads(bytes_list, raw=False) 201 | else: 202 | result = pickle.loads(bytes_list) 203 | results.append(result) 204 | return results 205 | 206 | 207 | def any_broadcast(data, root_rank): 208 | """broadcast arbitrary data from root_rank to all nodes.""" 209 | if not hasattr(any_broadcast, '_buffer'): 210 | # keeps small buffer to avoid re-allocate every call 211 | any_broadcast._buffer = torch.cuda.ByteTensor(_BUFFER_SIZE) 212 | try: 213 | enc = msgpack.dumps(data, use_bin_type=True) 214 | msgpack_success = True 215 | except TypeError: 216 | enc = pickle.dumps(data) 217 | msgpack_success = False 218 | 219 | max_size = 
hvd.allgather(torch.tensor([len(enc)]).cuda()).max().item() 220 | buffer_ = any_broadcast._buffer 221 | buffer_, enc_byte = _encode(enc, max_size, buffer_) 222 | 223 | hvd.broadcast_(buffer_, root_rank) 224 | 225 | bytes_list, _ = _decode(buffer_, enc_byte) 226 | if msgpack_success: 227 | result = msgpack.loads(bytes_list, raw=False) 228 | else: 229 | result = pickle.loads(bytes_list) 230 | return result 231 | -------------------------------------------------------------------------------- /utils/itm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def i2t(sims, npts=None, return_ranks=False): 5 | """ 6 | Images->Text (Image Annotation) 7 | sims: (N, 5N) matrix of similarity im-cap 8 | """ 9 | npts = sims.shape[0] 10 | ranks = np.zeros(npts) 11 | top1 = np.zeros(npts) 12 | for index in range(npts): 13 | inds = np.argsort(sims[index])[::-1] 14 | # Score 15 | rank = 1e20 16 | for i in range(5 * index, 5 * index + 5, 1): 17 | tmp = np.where(inds == i)[0][0] 18 | if tmp < rank: 19 | rank = tmp 20 | ranks[index] = rank 21 | top1[index] = inds[0] 22 | 23 | # Compute metrics 24 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 25 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 26 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 27 | medr = np.floor(np.median(ranks)) + 1 28 | meanr = ranks.mean() + 1 29 | if return_ranks: 30 | return (r1, r5, r10, medr, meanr), (ranks, top1) 31 | else: 32 | return (r1, r5, r10, medr, meanr) 33 | 34 | 35 | def t2i(sims, npts=None, return_ranks=False): 36 | """ 37 | Text->Images (Image Search) 38 | sims: (N, 5N) matrix of similarity im-cap 39 | """ 40 | npts = sims.shape[0] 41 | ranks = np.zeros(5 * npts) 42 | top1 = np.zeros(5 * npts) 43 | 44 | # --> (5N(caption), N(image)) 45 | sims = sims.T 46 | 47 | for index in range(npts): 48 | for i in range(5): 49 | inds = np.argsort(sims[5 * index + i])[::-1] 50 | ranks[5 * index + i] = np.where(inds == index)[0][0] 51 | top1[5 * index + i] = inds[0] 52 | 53 | # Compute metrics 54 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 55 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 56 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 57 | medr = np.floor(np.median(ranks)) + 1 58 | meanr = ranks.mean() + 1 59 | if return_ranks: 60 | return (r1, r5, r10, medr, meanr), (ranks, top1) 61 | else: 62 | return (r1, r5, r10, medr, meanr) 63 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | helper for logging 3 | NOTE: loggers are global objects use with caution 4 | """ 5 | import logging 6 | import math 7 | 8 | import tensorboardX 9 | 10 | 11 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 12 | _DATE_FMT = '%m/%d/%Y %H:%M:%S' 13 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO) 14 | LOGGER = logging.getLogger('__main__') # this is the global logger 15 | 16 | 17 | def add_log_to_file(log_path): 18 | fh = logging.FileHandler(log_path) 19 | formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT) 20 | fh.setFormatter(formatter) 21 | LOGGER.addHandler(fh) 22 | 23 | 24 | class TensorboardLogger(object): 25 | def __init__(self): 26 | self._logger = None 27 | self._global_step = 0 28 | 29 | def create(self, path): 30 | self._logger = tensorboardX.SummaryWriter(path) 31 | 32 | def noop(self, *args, **kwargs): 33 | return 34 | 35 | def 
step(self): 36 | self._global_step += 1 37 | 38 | @property 39 | def global_step(self): 40 | return self._global_step 41 | 42 | def log_scaler_dict(self, log_dict, prefix=''): 43 | """ log a dictionary of scalar values""" 44 | if self._logger is None: 45 | return 46 | if prefix: 47 | prefix = f'{prefix}_' 48 | for name, value in log_dict.items(): 49 | if isinstance(value, dict): 50 | self.log_scaler_dict(value, self._global_step, 51 | prefix=f'{prefix}{name}') 52 | else: 53 | self._logger.add_scalar(f'{prefix}{name}', value, 54 | self._global_step) 55 | 56 | def __getattr__(self, name): 57 | if self._logger is None: 58 | return self.noop 59 | return self._logger.__getattribute__(name) 60 | 61 | 62 | TB_LOGGER = TensorboardLogger() 63 | 64 | 65 | class RunningMeter(object): 66 | """ running meteor of a scalar value 67 | (useful for monitoring training loss) 68 | """ 69 | def __init__(self, name, val=None, smooth=0.99): 70 | self._name = name 71 | self._sm = smooth 72 | self._val = val 73 | 74 | def __call__(self, value): 75 | val = (value if self._val is None 76 | else value*(1-self._sm) + self._val*self._sm) 77 | if not math.isnan(val) and not math.isinf(val): 78 | self._val = val 79 | else: 80 | print(f'Inf/Nan in {self._name}') 81 | 82 | def __str__(self): 83 | return f'{self._name}: {self._val:.4f}' 84 | 85 | @property 86 | def val(self): 87 | return self._val 88 | 89 | @property 90 | def name(self): 91 | return self._name 92 | -------------------------------------------------------------------------------- /utils/m3p_tokenizer.py: -------------------------------------------------------------------------------- 1 | import sentencepiece as spm 2 | import collections 3 | 4 | SPIECE_UNDERLINE = "▁" 5 | 6 | def load_vocab(vocab_file): 7 | """Loads a vocabulary file into a dictionary.""" 8 | vocab = collections.OrderedDict() 9 | with open(vocab_file, "r", encoding="utf-8") as reader: 10 | tokens = reader.readlines() 11 | for index, token in enumerate(tokens): 12 | token = token.rstrip("\n") 13 | vocab[token] = index 14 | return vocab 15 | 16 | class XLMRTokenizer(object): 17 | def __init__(self, vocab_file,special_token=''): 18 | sp_model = spm.SentencePieceProcessor() 19 | sp_model.Load(str(vocab_file)) 20 | self.sp_model = sp_model 21 | self.bos_token = "" 22 | self.eos_token = "" 23 | self.sep_token = "" 24 | self.cls_token = "" 25 | self.unk_token = "" 26 | self.pad_token = "" 27 | self.mask_token = "" 28 | 29 | self.fairseq_tokens_to_ids = {"": 0, "": 1, "": 2, "": 3} 30 | 31 | # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab 32 | self.fairseq_offset = 1 33 | 34 | self.fairseq_tokens_to_ids[""] = len(sp_model) + self.fairseq_offset 35 | self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} 36 | 37 | self.cls_token_id = self._cls_token_id() 38 | self.sep_token_id = self._sep_token_id() 39 | self.pad_token_id = self._pad_token_id() 40 | self.mask_token_id = self._mask_token_id() 41 | self.eos_token_id = self._eos_token_id() 42 | 43 | 44 | def _cls_token_id(self): 45 | """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """ 46 | return self._convert_token_to_id(self.cls_token) 47 | def _sep_token_id(self): 48 | """ Id of the classification token in the vocabulary. E.g. 
to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """ 49 | return self._convert_token_to_id(self.sep_token) 50 | def _pad_token_id(self): 51 | return self._convert_token_to_id(self.pad_token) 52 | def _eos_token_id(self): 53 | return self._convert_token_to_id(self.eos_token) 54 | def _mask_token_id(self): 55 | return self._convert_token_to_id(self.mask_token) 56 | 57 | def build_inputs_with_special_tokens( 58 | self,token_ids_0, token_ids_1): 59 | """ 60 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks 61 | by concatenating and adding special tokens. 62 | A BERT sequence has the following format: 63 | - single sequence: ``[CLS] X [SEP]`` 64 | - pair of sequences: ``[CLS] A [SEP] B [SEP]`` 65 | Args: 66 | token_ids_0 (:obj:`List[int]`): 67 | List of IDs to which the special tokens will be added 68 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): 69 | Optional second list of IDs for sequence pairs. 70 | Returns: 71 | :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 72 | """ 73 | if token_ids_1 is None: 74 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 75 | cls = [self.cls_token_id] 76 | sep = [self.sep_token_id] 77 | return cls + token_ids_0 + sep + token_ids_1 + sep 78 | 79 | @property 80 | def vocab_size(self): 81 | return len(self.sp_model) + self.fairseq_offset + 1 # Add the token 82 | 83 | def get_vocab(self): 84 | vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} 85 | return vocab 86 | 87 | def _tokenize(self,text): 88 | return self.sp_model.EncodeAsPieces(text) 89 | 90 | def _convert_token_to_id(self,token): 91 | """ Converts a token (str) in an id using the vocab. 
""" 92 | if token in self.fairseq_tokens_to_ids: 93 | return self.fairseq_tokens_to_ids[token] 94 | spm_id = self.sp_model.PieceToId(token) 95 | 96 | # Need to return unknown token if the SP model returned 0 97 | return spm_id + self.fairseq_offset if spm_id else self.fairseq_tokens_to_ids[self.unk_token] 98 | 99 | def encode(self,text,text_b=None): 100 | tokens = self._tokenize(text) 101 | input_ids = [] 102 | for token in tokens: 103 | input_ids.append(self._convert_token_to_id(token)) 104 | 105 | input_ids_b = None 106 | if text_b is not None: 107 | input_ids_b = [] 108 | tokens_b = self._tokenize(text_b) 109 | for token in tokens_b: 110 | input_ids_b.append(self._convert_token_to_id(token)) 111 | 112 | #input_ids = self.build_inputs_with_special_tokens(input_ids,input_ids_b) 113 | 114 | return input_ids 115 | 116 | def decode(self,token_ids): 117 | _out = [] 118 | for token_id in token_ids: 119 | _out.append(self._convert_id_to_token(token_id)) 120 | return self.convert_tokens_to_string(_out) 121 | 122 | def _convert_id_to_token(self,index): 123 | """Converts an index (integer) in a token (str) using the vocab.""" 124 | if index in self.fairseq_ids_to_tokens: 125 | return self.fairseq_ids_to_tokens[index] 126 | return self.sp_model.IdToPiece(index - self.fairseq_offset) 127 | def convert_tokens_to_string(self, tokens): 128 | """Converts a sequence of tokens (strings for sub-words) in a single string.""" 129 | out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() 130 | return out_string -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc utilities 3 | """ 4 | import json 5 | import random 6 | import sys 7 | 8 | import torch 9 | import numpy as np 10 | 11 | from utils.logger import LOGGER 12 | 13 | 14 | class NoOp(object): 15 | """ useful for distributed training No-Ops """ 16 | def __getattr__(self, name): 17 | return self.noop 18 | 19 | def noop(self, *args, **kwargs): 20 | return 21 | 22 | 23 | def parse_with_config(parser): 24 | args = parser.parse_args() 25 | if args.config is not None: 26 | config_args = json.load(open(args.config)) 27 | override_keys = {arg[2:].split('=')[0] for arg in sys.argv[1:] 28 | if arg.startswith('--')} 29 | for k, v in config_args.items(): 30 | if k not in override_keys: 31 | setattr(args, k, v) 32 | del args.config 33 | return args 34 | 35 | 36 | VE_ENT2IDX = { 37 | 'contradiction': 0, 38 | 'entailment': 1, 39 | 'neutral': 2 40 | } 41 | 42 | VE_IDX2ENT = { 43 | 0: 'contradiction', 44 | 1: 'entailment', 45 | 2: 'neutral' 46 | } 47 | 48 | 49 | class Struct(object): 50 | def __init__(self, dict_): 51 | self.__dict__.update(dict_) 52 | 53 | 54 | def set_dropout(model, drop_p): 55 | for name, module in model.named_modules(): 56 | # we might want to tune dropout for smaller dataset 57 | if isinstance(module, torch.nn.Dropout): 58 | if module.p != drop_p: 59 | module.p = drop_p 60 | LOGGER.info(f'{name} set to {drop_p}') 61 | 62 | 63 | def set_random_seed(seed): 64 | random.seed(seed) 65 | np.random.seed(seed) 66 | torch.manual_seed(seed) 67 | torch.cuda.manual_seed_all(seed) 68 | -------------------------------------------------------------------------------- /utils/ms_internal_mt.py: -------------------------------------------------------------------------------- 1 | import os, requests, uuid, json 2 | import csv 3 | import sys 4 | from tqdm import tqdm 5 | from shutil import copyfile 6 | 7 | 
csv.field_size_limit(sys.maxsize) 8 | 9 | key_var_name = 'TRANSLATOR_TEXT_SUBSCRIPTION_KEY' 10 | if not key_var_name in os.environ: 11 | raise Exception('Please set/export the environment variable: {}'.format(key_var_name)) 12 | 13 | subscription_key = os.environ[key_var_name] 14 | endpoint_var_name = 'TRANSLATOR_TEXT_ENDPOINT' 15 | if not endpoint_var_name in os.environ: 16 | raise Exception('Please set/export the environment variable: {}'.format(endpoint_var_name)) 17 | endpoint = os.environ[endpoint_var_name] 18 | 19 | # If you encounter any issues with the base_url or path, make sure 20 | # that you are using the latest endpoint: https://docs.microsoft.com/azure/cognitive-services/translator/reference/v3-0-translate 21 | path = '/translate?api-version=3.0' 22 | constructed_url = endpoint + path 23 | 24 | headers = { 25 | 'Ocp-Apim-Subscription-Key': subscription_key, 26 | 'Content-type': 'application/json', 27 | 'X-ClientTraceId': str(uuid.uuid4()) 28 | } 29 | 30 | def msft_translate(lines, langs, batch_size=10): 31 | global constructed_url, headers 32 | if 'zh' in langs: 33 | langs = langs.replace('zh', 'zh-Hans') 34 | params = "&" + "&".join(["to={}".format(lang) for lang in langs.split(',')]) + "&includeAlignment=true&includeSentenceLength=true" 35 | request_url = constructed_url + params 36 | # Body limitations: 37 | # The array can have at most 100 elements. 38 | # The entire text included in the request cannot exceed 5,000 characters including spaces. 39 | out_translations = [] 40 | 41 | #for idx in tqdm(range(0, len(lines), batch_size)): 42 | for idx in range(0, len(lines), batch_size): 43 | body = [{'text': line} for line in lines[idx:idx+batch_size]] 44 | 45 | request = requests.post(request_url, headers=headers, json=body) 46 | response = request.json() 47 | 48 | out_translations += response 49 | 50 | return out_translations 51 | 52 | # in_file = '' 53 | # out_file = '' 54 | # tgt_langs = 'en' 55 | # batch_size = 10 56 | split = "train" 57 | in_file = "/home/mingyang/multilingual_vl/data/azure_mount/UNITER_DATA/CC/annotations/{split}_imageId2Ann.tsv".format(split=split) 58 | tgt_langs = 'fr' 59 | backup_file = "/home/mingyang/multilingual_vl/data/azure_mount/UNITER_DATA/CC/annotations/{split}_imageId2Ann_{tgt_langs}_backup.tsv".format(split=split, tgt_langs=tgt_langs) 60 | out_file = "/home/mingyang/multilingual_vl/data/azure_mount/UNITER_DATA/CC/annotations/{split}_imageId2Ann_{tgt_langs}.tsv".format(split=split, tgt_langs=tgt_langs) 61 | alignment_out_file = "/home/mingyang/multilingual_vl/data/azure_mount/UNITER_DATA/CC/annotations/{split}_en_{tgt_langs}_word_alignemnt.tsv".format(split=split, tgt_langs=tgt_langs) 62 | line_batch_size=100 63 | #batch_size = 100 64 | save_step = 10000 65 | #with open(in_file, "r") as tsv_src: 66 | tsv_src = open(in_file) 67 | 68 | #get the tsv_tgt_data 69 | tsv_tgt_data = [] 70 | src_captions = [] 71 | #src_tgt_alignments = [] 72 | 73 | if os.path.isfile(backup_file): 74 | #load the tsv_tgt_data 75 | tsv_tgt_backup = open(backup_file) 76 | for tgt_line in tsv_tgt_backup: 77 | img_id, img_url, tgt_caption, success = tgt_line.strip().split('\t') 78 | tsv_tgt_data.append([img_id, img_url, tgt_caption, success]) 79 | #set up the resume point 80 | resume_step = len(tsv_tgt_data) 81 | print("resume_step is: {}".format(resume_step)) 82 | for i, src_line in enumerate(tsv_src): 83 | img_id, img_url, src_caption, success = src_line.strip().split('\t') 84 | src_captions.append(src_caption) 85 | #split the fields 86 | if i >= resume_step: 87 | 
tsv_tgt_data.append([img_id, img_url, "", success]) 88 | #src_tgt_alignments.append([img_id, ""]) 89 | 90 | #print(len(tsv_tgt_data) == len(src_captions)) 91 | for i in tqdm(range(resume_step, len(src_captions), line_batch_size)): 92 | if i + line_batch_size < len(src_captions): 93 | current_src_batch = src_captions[i:i+line_batch_size] 94 | else: 95 | current_src_batch = src_captions[i:] 96 | 97 | #get the translations 98 | batch_size = len(current_src_batch) 99 | current_translation_batch = msft_translate(current_src_batch, langs=tgt_langs, batch_size=batch_size) 100 | #update tsv_tgt_data 101 | for j in range(len(current_translation_batch)): 102 | if current_translation_batch[j]['translations'][0].get('text', None): 103 | translated_text = current_translation_batch[j]['translations'][0]['text'] 104 | if len(translated_text) > 10000: 105 | tsv_tgt_data[i+j][3] = "fail" 106 | else: 107 | tsv_tgt_data[i+j][2] = translated_text 108 | # print(current_translation_batch[j]) 109 | # raise Exception("something wrong happens to the translation: {}".format(current_translation_batch[j])) 110 | # finally: 111 | # print(current_translation_batch[j]) 112 | #retried_translation = msft_translate() 113 | # if current_translation_batch[j]['translations'][0].get('alignment', None): 114 | # src_tgt_alignments[i+j][1] = current_translation_batch[j]['translations'][0]['alignment']['proj'] 115 | if (i+1) % save_step == 1: 116 | saved_tsv_tgt_data = tsv_tgt_data[:i+line_batch_size] 117 | #saved_src_tgt_alignments = src_tgt_alignments[:i+line_batch_size] 118 | with open(out_file, "w") as f_out: 119 | tsv_output = csv.writer(f_out, delimiter='\t') 120 | for saved_line in saved_tsv_tgt_data: 121 | tsv_output.writerow(saved_line) 122 | #copy the file to the backup file 123 | copyfile(out_file, backup_file) 124 | #save the final data to file 125 | with open(out_file, "w") as f_out: 126 | tsv_output = csv.writer(f_out, delimiter='\t') 127 | for saved_line in tsv_tgt_data: 128 | tsv_output.writerow(saved_line) 129 | tsv_src.close() 130 | 131 | 132 | -------------------------------------------------------------------------------- /utils/ms_internal_mt_label.py: -------------------------------------------------------------------------------- 1 | import os, requests, uuid, json 2 | import csv 3 | import sys 4 | from tqdm import tqdm 5 | from shutil import copyfile 6 | 7 | csv.field_size_limit(sys.maxsize) 8 | 9 | key_var_name = 'TRANSLATOR_TEXT_SUBSCRIPTION_KEY' 10 | if not key_var_name in os.environ: 11 | raise Exception('Please set/export the environment variable: {}'.format(key_var_name)) 12 | 13 | subscription_key = os.environ[key_var_name] 14 | endpoint_var_name = 'TRANSLATOR_TEXT_ENDPOINT' 15 | if not endpoint_var_name in os.environ: 16 | raise Exception('Please set/export the environment variable: {}'.format(endpoint_var_name)) 17 | endpoint = os.environ[endpoint_var_name] 18 | 19 | # If you encounter any issues with the base_url or path, make sure 20 | # that you are using the latest endpoint: https://docs.microsoft.com/azure/cognitive-services/translator/reference/v3-0-translate 21 | path = '/translate?api-version=3.0' 22 | constructed_url = endpoint + path 23 | 24 | headers = { 25 | 'Ocp-Apim-Subscription-Key': subscription_key, 26 | 'Content-type': 'application/json', 27 | 'X-ClientTraceId': str(uuid.uuid4()) 28 | } 29 | 30 | def msft_translate(lines, langs, batch_size=10): 31 | global constructed_url, headers 32 | if 'zh' in langs: 33 | langs = langs.replace('zh', 'zh-Hans') 34 | params = "&" + 
"&".join(["to={}".format(lang) for lang in langs.split(',')]) + "&includeAlignment=true&includeSentenceLength=true" 35 | request_url = constructed_url + params 36 | # Body limitations: 37 | # The array can have at most 100 elements. 38 | # The entire text included in the request cannot exceed 5,000 characters including spaces. 39 | out_translations = [] 40 | 41 | #for idx in tqdm(range(0, len(lines), batch_size)): 42 | for idx in range(0, len(lines), batch_size): 43 | body = [{'text': line} for line in lines[idx:idx+batch_size]] 44 | 45 | request = requests.post(request_url, headers=headers, json=body) 46 | response = request.json() 47 | 48 | out_translations += response 49 | 50 | return out_translations 51 | 52 | 53 | source_label_file = "img_label_objects.txt" 54 | tgt_lan = "cs" 55 | tgt_label_file = "img_label_objects_{}.txt".format(tgt_lan) 56 | batch_size = 100 57 | 58 | source_labels = [] 59 | tgt_labels = [] 60 | with open(source_label_file, "r") as f: 61 | for line in f: 62 | source_labels.append(line.strip()) 63 | 64 | for i in tqdm(range(0, len(source_labels), batch_size)): 65 | batch_upper_limit = i+batch_size if i+batch_size < len(source_labels) else len(source_labels) 66 | current_src_batch = source_labels[i:batch_upper_limit] 67 | current_batch_size = len(current_src_batch) 68 | 69 | current_translation_batch = msft_translate(current_src_batch, langs=tgt_lan, batch_size=current_batch_size) 70 | for translated_data in current_translation_batch: 71 | assert translated_data['translations'][0].get('text', None) 72 | tgt_labels.append(translated_data['translations'][0].get('text', None)) 73 | #write the tgt_labels to outputfile 74 | with open(tgt_label_file, 'w') as f_out: 75 | for label in tgt_labels: 76 | f_out.write(label+'\n') -------------------------------------------------------------------------------- /utils/ms_internal_mt_popen.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | filename = "ms_internal_mt.py" 4 | while True: 5 | p = subprocess.Popen('python '+filename, shell=True).wait() 6 | if p!=0: 7 | print("An error occured in the translation process. 
Restaring....") 8 | continue 9 | else: 10 | print("The translation process is Done") 11 | break -------------------------------------------------------------------------------- /utils/save.py: -------------------------------------------------------------------------------- 1 | """ 2 | saving utilities 3 | """ 4 | from collections import OrderedDict 5 | import json 6 | import os 7 | from os.path import abspath, dirname, exists, join 8 | import subprocess 9 | 10 | import torch 11 | from apex.amp._amp_state import _amp_state 12 | 13 | from utils.logger import LOGGER 14 | from horovod import torch as hvd 15 | 16 | 17 | def save_training_meta(args): 18 | if args.rank > 0: 19 | return 20 | 21 | if not exists(args.output_dir): 22 | os.makedirs(join(args.output_dir, 'log')) 23 | os.makedirs(join(args.output_dir, 'ckpt')) 24 | 25 | with open(join(args.output_dir, 'log', 'hps.json'), 'w') as writer: 26 | json.dump(vars(args), writer, indent=4) 27 | model_config = json.load(open(args.model_config)) 28 | with open(join(args.output_dir, 'log', 'model.json'), 'w') as writer: 29 | json.dump(model_config, writer, indent=4) 30 | # git info 31 | #Mingyang: Ignore the git info 32 | # try: 33 | # LOGGER.info("Waiting on git info....") 34 | # c = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"], 35 | # timeout=10, stdout=subprocess.PIPE) 36 | # git_branch_name = c.stdout.decode().strip() 37 | # LOGGER.info("Git branch: %s", git_branch_name) 38 | # c = subprocess.run(["git", "rev-parse", "HEAD"], 39 | # timeout=10, stdout=subprocess.PIPE) 40 | # git_sha = c.stdout.decode().strip() 41 | # LOGGER.info("Git SHA: %s", git_sha) 42 | # git_dir = abspath(dirname(__file__)) 43 | # git_status = subprocess.check_output( 44 | # ['git', 'status', '--short'], 45 | # cwd=git_dir, universal_newlines=True).strip() 46 | # with open(join(args.output_dir, 'log', 'git_info.json'), 47 | # 'w') as writer: 48 | # json.dump({'branch': git_branch_name, 49 | # 'is_dirty': bool(git_status), 50 | # 'status': git_status, 51 | # 'sha': git_sha}, 52 | # writer, indent=4) 53 | # except subprocess.TimeoutExpired as e: 54 | # LOGGER.exception(e) 55 | # LOGGER.warn("Git info not found. 
Moving right along...") 56 | 57 | 58 | class ModelSaver(object): 59 | def __init__(self, output_dir, prefix='model_step', suffix='pt'): 60 | self.output_dir = output_dir 61 | self.prefix = prefix 62 | self.suffix = suffix 63 | 64 | def save(self, model, step, optimizer=None): 65 | output_model_file = join(self.output_dir, 66 | f"{self.prefix}_{step}.{self.suffix}") 67 | state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v 68 | for k, v in model.state_dict().items()} 69 | if hasattr(model, 'vocab_pad') and model.vocab_pad: 70 | # store vocab embeddings before padding 71 | emb_w = state_dict['bert.embeddings.word_embeddings.weight'] 72 | emb_w = emb_w[:-model.vocab_pad, :] 73 | state_dict['bert.embeddings.word_embeddings.weight'] = emb_w 74 | state_dict['cls.predictions.decoder.weight'] = emb_w 75 | torch.save(state_dict, output_model_file) 76 | if optimizer is not None: 77 | dump = {'step': step, 'optimizer': optimizer.state_dict()} 78 | if hasattr(optimizer, '_amp_stash'): 79 | pass # TODO fp16 optimizer 80 | torch.save(dump, f'{self.output_dir}/train_state_{step}.pt') 81 | 82 | 83 | def amp_state_dict(destination=None): 84 | if destination is None: 85 | destination = OrderedDict() 86 | 87 | for idx, loss_scaler in enumerate(_amp_state.loss_scalers): 88 | destination['loss_scaler%d' % idx] = { 89 | 'loss_scale': loss_scaler.loss_scale(), 90 | 'unskipped': loss_scaler._unskipped, 91 | } 92 | return destination 93 | 94 | 95 | def amp_load_state_dict(state_dict): 96 | # Check if state_dict containes the same number of loss_scalers as current 97 | # setup 98 | if len(state_dict) != len(_amp_state.loss_scalers): 99 | print('Warning: state_dict contains {} entries, while {} loss_scalers ' 100 | 'are used'.format(len(state_dict), len(_amp_state.loss_scalers))) 101 | 102 | state_dict = state_dict.copy() 103 | 104 | nb_loss_scalers = len(_amp_state.loss_scalers) 105 | unexpected_keys = [] 106 | # Initialize idx outside, since unexpected_keys will increase it if 107 | # enumerate is used 108 | idx = 0 109 | for key in state_dict: 110 | if 'loss_scaler' not in key: 111 | unexpected_keys.append(key) 112 | else: 113 | if idx > (nb_loss_scalers - 1): 114 | print('Skipping loss_scaler[{}], since num_losses was set to ' 115 | '{}'.format(idx, nb_loss_scalers)) 116 | break 117 | _amp_state.loss_scalers[idx]._loss_scale = ( 118 | state_dict[key]['loss_scale']) 119 | _amp_state.loss_scalers[idx]._unskipped = ( 120 | state_dict[key]['unskipped']) 121 | idx += 1 122 | 123 | if len(unexpected_keys) > 0: 124 | raise RuntimeError( 125 | 'Error(s) in loading state_dict. Unexpected key(s) in state_dict: ' 126 | '{}. 
'.format(', '.join('"{}"'.format(k) 127 | for k in unexpected_keys))) 128 | 129 | def _to_cuda(state): 130 | """ usually load from cpu checkpoint but need to load to cuda """ 131 | if isinstance(state, torch.Tensor): 132 | ret = state.cuda() # assume propoerly set py torch.cuda.set_device 133 | if 'Half' in state.type(): 134 | ret = ret.float() # apex O2 requires it 135 | return ret 136 | elif isinstance(state, list): 137 | new_state = [_to_cuda(t) for t in state] 138 | elif isinstance(state, tuple): 139 | new_state = tuple(_to_cuda(t) for t in state) 140 | elif isinstance(state, dict): 141 | new_state = {n: _to_cuda(t) for n, t in state.items()} 142 | else: 143 | return state 144 | return new_state 145 | 146 | 147 | def _to_cpu(state): 148 | """ store in cpu to avoid GPU0 device, fp16 to save space """ 149 | if isinstance(state, torch.Tensor): 150 | ret = state.cpu() 151 | if 'Float' in state.type(): 152 | ret = ret.half() 153 | return ret 154 | elif isinstance(state, list): 155 | new_state = [_to_cpu(t) for t in state] 156 | elif isinstance(state, tuple): 157 | new_state = tuple(_to_cpu(t) for t in state) 158 | elif isinstance(state, dict): 159 | new_state = {n: _to_cpu(t) for n, t in state.items()} 160 | else: 161 | return state 162 | return new_state 163 | 164 | class TrainingRestorer(object): 165 | def __init__(self, opts, model, optimizer): 166 | if exists(opts.output_dir) and hvd.rank() == 0: 167 | restore_opts = json.load( 168 | open(f'{opts.output_dir}/log/hps.json', 'r')) 169 | with open(join( 170 | opts.output_dir, 'log', 171 | 'restore_hps.json'), 'w') as writer: 172 | json.dump(vars(opts), writer, indent=4) 173 | assert vars(opts) == restore_opts 174 | # keep 2 checkpoints in case of corrupted 175 | self.save_path = f'{opts.output_dir}/restore.pt' 176 | self.backup_path = f'{opts.output_dir}/restore_backup.pt' 177 | self.model = model 178 | self.optimizer = optimizer 179 | self.save_steps = opts.save_steps 180 | self.amp = opts.fp16 181 | if exists(self.save_path) or exists(self.backup_path): 182 | LOGGER.info('found previous checkpoint. try to resume...') 183 | self.restore(opts) 184 | else: 185 | self.global_step = 0 186 | 187 | def step(self): 188 | self.global_step += 1 189 | if self.global_step % self.save_steps == 0: 190 | self.save() 191 | 192 | def save(self): 193 | checkpoint = {'global_step': self.global_step, 194 | 'model_state_dict': _to_cpu(self.model.state_dict()), 195 | 'optim_state_dict': _to_cpu(self.optimizer.state_dict())} 196 | if self.amp: 197 | checkpoint['amp_state_dict'] = amp_state_dict() 198 | if exists(self.save_path): 199 | os.rename(self.save_path, self.backup_path) 200 | torch.save(checkpoint, self.save_path) 201 | 202 | def restore(self, opts): 203 | try: 204 | checkpoint = torch.load(self.save_path) 205 | except Exception: 206 | checkpoint = torch.load(self.backup_path) 207 | self.global_step = checkpoint['global_step'] 208 | self.model.load_state_dict(_to_cuda(checkpoint['model_state_dict'])) 209 | self.optimizer.load_state_dict( 210 | _to_cuda(checkpoint['optim_state_dict'])) 211 | if self.amp: 212 | amp_load_state_dict(checkpoint['amp_state_dict']) 213 | LOGGER.info(f'resume training from step {self.global_step}') 214 | -------------------------------------------------------------------------------- /utils/visual_entailment.py: -------------------------------------------------------------------------------- 1 | """ 2 | NOTE: modified from ban-vqa 3 | This code is slightly modified from Hengyuan Hu's repository. 
4 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 5 | """ 6 | import os 7 | import sys 8 | import pickle 9 | 10 | 11 | def create_ans2label(path): 12 | """ 13 | occurence: dict {answer -> whatever} 14 | name: dir of the output file 15 | """ 16 | ans2label = {"contradiction": 0, "entailment":1 , "neutral": 2} 17 | label2ans = ["contradiction", "entailment", "neutral"] 18 | 19 | output_file = os.path.join(path, 'visual_entailment_ans2label.pkl') 20 | pickle.dump(ans2label, open(output_file, 'wb')) 21 | 22 | 23 | def compute_target(answers, ans2label): 24 | answer_count = {} 25 | for answer in answers: 26 | answer_ = answer 27 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 28 | 29 | labels = [] 30 | scores = [] 31 | for answer in answer_count: 32 | if answer not in ans2label: 33 | continue 34 | labels.append(ans2label[answer]) 35 | score = answer_count[answer]/len(answers) 36 | scores.append(score) 37 | target = {'labels': labels, 'scores': scores} 38 | return target 39 | 40 | 41 | if __name__ == '__main__': 42 | output = sys.argv[1:][0] 43 | print(output) 44 | if os.path.exists(f'{output}/visual_entailment_ans2label.pkl'): 45 | raise ValueError(f'{output} already exists') 46 | create_ans2label(output) 47 | -------------------------------------------------------------------------------- /utils/vqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | NOTE: modified from ban-vqa 3 | This code is slightly modified from Hengyuan Hu's repository. 4 | https://github.com/hengyuan-hu/bottom-up-attention-vqa 5 | """ 6 | import os 7 | import json 8 | import re 9 | import sys 10 | import pickle 11 | 12 | 13 | CONTRACTIONS = { 14 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": 15 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", 16 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": 17 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": 18 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": 19 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": 20 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", 21 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": 22 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": 23 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", 24 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've": 25 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", 26 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", 27 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", 28 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": 29 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": 30 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": 31 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": 32 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": 33 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": 34 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", 35 | "someoned've": "someone'd've", "someone'dve": "someone'd've", 36 | "someonell": "someone'll", "someones": "someone's", "somethingd": 37 | "something'd", "somethingd've": "something'd've", "something'dve": 38 | "something'd've", "somethingll": "something'll", "thats": 39 | "that's", "thered": "there'd", "thered've": "there'd've", 40 | "there'dve": "there'd've", "therere": "there're", "theres": 41 | "there's", "theyd": "they'd", 
"theyd've": "they'd've", "they'dve": 42 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": 43 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": 44 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": 45 | "weren't", "whatll": "what'll", "whatre": "what're", "whats": 46 | "what's", "whatve": "what've", "whens": "when's", "whered": 47 | "where'd", "wheres": "where's", "whereve": "where've", "whod": 48 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": 49 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", 50 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": 51 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", 52 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": 53 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", 54 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": 55 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": 56 | "you'll", "youre": "you're", "youve": "you've" 57 | } 58 | 59 | MANUAL_MAP = {'none': '0', 60 | 'zero': '0', 61 | 'one': '1', 62 | 'two': '2', 63 | 'three': '3', 64 | 'four': '4', 65 | 'five': '5', 66 | 'six': '6', 67 | 'seven': '7', 68 | 'eight': '8', 69 | 'nine': '9', 70 | 'ten': '10'} 71 | ARTICLES = ['a', 'an', 'the'] 72 | PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)") 73 | COMMA_STRIP = re.compile(r"(\d)(\,)(\d)") 74 | PUNCT = [';', r"/", '[', ']', '"', '{', '}', 75 | '(', ')', '=', '+', '\\', '_', '-', 76 | '>', '<', '@', '`', ',', '?', '!'] 77 | 78 | 79 | # Notice that VQA score is the average of 10 choose 9 candidate answers cases 80 | # See http://visualqa.org/evaluation.html 81 | def get_score(occurences): 82 | if occurences == 0: 83 | return .0 84 | elif occurences == 1: 85 | return .3 86 | elif occurences == 2: 87 | return .6 88 | elif occurences == 3: 89 | return .9 90 | else: 91 | return 1. 
92 | 93 | 94 | def process_punctuation(inText): 95 | outText = inText 96 | for p in PUNCT: 97 | if (p + ' ' in inText 98 | or ' ' + p in inText 99 | or re.search(COMMA_STRIP, inText) is not None): 100 | outText = outText.replace(p, '') 101 | else: 102 | outText = outText.replace(p, ' ') 103 | outText = PERIOD_STRIP.sub("", outText, re.UNICODE) 104 | return outText 105 | 106 | 107 | def process_digit_article(inText): 108 | outText = [] 109 | tempText = inText.lower().split() 110 | for word in tempText: 111 | word = MANUAL_MAP.setdefault(word, word) 112 | if word not in ARTICLES: 113 | outText.append(word) 114 | else: 115 | pass 116 | for wordId, word in enumerate(outText): 117 | if word in CONTRACTIONS: 118 | outText[wordId] = CONTRACTIONS[word] 119 | outText = ' '.join(outText) 120 | return outText 121 | 122 | 123 | def preprocess_answer(answer): 124 | answer = process_digit_article(process_punctuation(answer)) 125 | answer = answer.replace(',', '') 126 | return answer 127 | 128 | 129 | def filter_answers(answers_dset, min_occurence): 130 | """This will change the answer to preprocessed version 131 | """ 132 | occurence = {} 133 | 134 | for ans_entry in answers_dset: 135 | gtruth = ans_entry.get('multiple_choice_answer', None) 136 | if gtruth is None: 137 | gtruth = ans_entry['answers'][0]['answer'] # VG, GQA pretraining 138 | gtruth = preprocess_answer(gtruth) 139 | if gtruth not in occurence: 140 | occurence[gtruth] = set() 141 | occurence[gtruth].add(ans_entry['question_id']) 142 | for answer in list(occurence): 143 | if len(occurence[answer]) < min_occurence: 144 | occurence.pop(answer) 145 | 146 | print('Num of answers that appear >= %d times: %d' % ( 147 | min_occurence, len(occurence))) 148 | return occurence 149 | 150 | 151 | def create_ans2label(occurence, path): 152 | """ 153 | occurence: dict {answer -> whatever} 154 | name: dir of the output file 155 | """ 156 | ans2label = {} 157 | label2ans = [] 158 | label = 0 159 | for answer in occurence: 160 | label2ans.append(answer) 161 | ans2label[answer] = label 162 | label += 1 163 | 164 | output_file = os.path.join(path, 'ans2label.pkl') 165 | pickle.dump(ans2label, open(output_file, 'wb')) 166 | 167 | 168 | def compute_target(answers, ans2label): 169 | answer_count = {} 170 | if len(answers) == 1: 171 | # VG VQA, GQA 172 | answer_ = preprocess_answer(answers[0]['answer']) 173 | answer_count[answer_] = 10 174 | else: 175 | # COCO VQA 176 | for answer in answers: 177 | answer_ = preprocess_answer(answer['answer']) 178 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 179 | 180 | labels = [] 181 | scores = [] 182 | for answer in answer_count: 183 | if answer not in ans2label: 184 | continue 185 | labels.append(ans2label[answer]) 186 | score = get_score(answer_count[answer]) 187 | scores.append(score) 188 | target = {'labels': labels, 'scores': scores} 189 | return target 190 | 191 | 192 | if __name__ == '__main__': 193 | *answer_files, output = sys.argv[1:] 194 | answers = [] 195 | for ans_file in answer_files: 196 | ans = json.load(open(ans_file))['annotations'] 197 | answers.extend(ans) 198 | 199 | occurence = filter_answers(answers, 9) 200 | 201 | if os.path.exists(f'{output}/ans2label.pkl'): 202 | raise ValueError(f'{output} already exists') 203 | create_ans2label(occurence, output) 204 | --------------------------------------------------------------------------------
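For quick reference, a minimal usage sketch of the XLMRTokenizer defined in utils/m3p_tokenizer.py above. This is only a sketch: it assumes the sentencepiece package is installed and that an XLM-R sentencepiece model file is available locally; the file name below is illustrative and is not shipped with this repository.

from utils.m3p_tokenizer import XLMRTokenizer

# illustrative path (assumption): any XLM-R sentencepiece model file can be passed here
tokenizer = XLMRTokenizer("xlmr.sentencepiece.bpe.model")

ids = tokenizer.encode("two dogs are playing in the park")    # raw sentencepiece ids, no special tokens
ids = tokenizer.build_inputs_with_special_tokens(ids, None)   # prepend <s> (CLS) and append </s> (SEP)
text = tokenizer.decode(ids[1:-1])                            # drop the special tokens before decoding
print(ids)
print(text)

On the VQA side, the lookup table implemented by get_score in utils/vqa.py matches the official metric it references (min(number of matching human answers / 3, 1), averaged over the ten leave-one-annotator-out subsets): an answer given by exactly one of the ten annotators scores 0.3, by two 0.6, by three 0.9, and by four or more 1.0, and compute_target stores these scores per candidate label.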