├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── uc2-base.json
│   ├── uc2_mscoco_itm.json
│   ├── uc2_pretrain.json
│   ├── uniter-base-cased.json
│   └── uniter-base.json
├── data
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── data.cpython-36.pyc
│   │   ├── itm.cpython-36.pyc
│   │   ├── loader.cpython-36.pyc
│   │   ├── mlm.cpython-36.pyc
│   │   ├── mrm.cpython-36.pyc
│   │   ├── mrm_nce.cpython-36.pyc
│   │   ├── nlvr2.cpython-36.pyc
│   │   ├── sampler.cpython-36.pyc
│   │   ├── ve.cpython-36.pyc
│   │   └── vqa.cpython-36.pyc
│   ├── data.py
│   ├── img_label_objects.txt
│   ├── itm.py
│   ├── loader.py
│   ├── mlm.py
│   ├── mrm.py
│   ├── mrm_nce.py
│   ├── nlvr2.py
│   ├── re.py
│   ├── sampler.py
│   ├── test_data
│   │   ├── input0.txt
│   │   ├── input1.txt
│   │   ├── input2.txt
│   │   ├── input3.txt
│   │   ├── input4.txt
│   │   ├── input5.txt
│   │   ├── input6.txt
│   │   └── input7.txt
│   ├── vcr.py
│   ├── ve.py
│   └── vqa.py
├── eval
│   ├── __pycache__
│   │   └── itm.cpython-36.pyc
│   ├── itm.py
│   └── nlvr2.py
├── itm.py
├── launch_container.sh
├── misc
│   ├── ans2label.json
│   ├── ans2label_en_trans2_ja.json
│   ├── ans2label_ja.json
│   ├── ans2label_ja_trans2_en.json
│   └── ans2label_vg.json
├── model
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── attention.cpython-36.pyc
│   │   ├── const_variable.cpython-36.pyc
│   │   ├── itm.cpython-36.pyc
│   │   ├── layer.cpython-36.pyc
│   │   ├── model.cpython-36.pyc
│   │   ├── nlvr2.cpython-36.pyc
│   │   ├── ot.cpython-36.pyc
│   │   ├── ve.cpython-36.pyc
│   │   └── vqa.cpython-36.pyc
│   ├── attention.py
│   ├── const_variable.py
│   ├── gqa.py
│   ├── img_label_objects.txt
│   ├── itm.py
│   ├── layer.py
│   ├── model.py
│   ├── nlvr2.py
│   ├── ot.py
│   ├── re.py
│   ├── vcr.py
│   ├── ve.py
│   └── vqa.py
├── object_labels
│   ├── img_label_objects.txt
│   ├── img_label_objects_cs.txt
│   ├── img_label_objects_de.txt
│   ├── img_label_objects_fr.txt
│   ├── img_label_objects_ja.txt
│   └── img_label_objects_zh.txt
├── optim
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── adamw.cpython-36.pyc
│   │   ├── misc.cpython-36.pyc
│   │   └── sched.cpython-36.pyc
│   ├── adamw.py
│   ├── misc.py
│   └── sched.py
├── pretrain.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   ├── const.cpython-36.pyc
    │   ├── distributed.cpython-36.pyc
    │   ├── logger.cpython-36.pyc
    │   ├── misc.cpython-36.pyc
    │   ├── save.cpython-36.pyc
    │   ├── visual_entailment.cpython-36.pyc
    │   └── vqa.cpython-36.pyc
    ├── const.py
    ├── distributed.py
    ├── itm.py
    ├── logger.py
    ├── m3p_tokenizer.py
    ├── misc.py
    ├── ms_internal_mt.py
    ├── ms_internal_mt_label.py
    ├── ms_internal_mt_popen.py
    ├── save.py
    ├── visual_entailment.py
    └── vqa.py
/.gitignore:
--------------------------------------------------------------------------------
1 | config/*_test.json
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Mingyang Zhou
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## UC2
2 | [UC2: Universal Cross-lingual Cross-modal Vision-and-Language Pre-training](https://arxiv.org/abs/2104.00332)
3 |
4 | [Mingyang Zhou](https://github.com/zmykevin), [Luowei Zhou](https://luoweizhou.github.io/), [Shuohang Wang](https://www.microsoft.com/en-us/research/people/shuowa/), [Yu Cheng](https://sites.google.com/site/chengyu05), [Linjie Li](https://www.microsoft.com/en-us/research/people/linjli/), [Zhou Yu](http://www.cs.columbia.edu/~zhouyu/), [Jingjing Liu](https://www.linkedin.com/in/jingjing-liu-65703431/)
5 |
6 | This is the official repository of UC2, a multilingual multi-modal pre-training framework. This repository supports end-to-end pretraining and finetuning for image-text retrieval on COCO.
7 |
8 | ### Requirements
9 | We provide a Docker image to run our code. Please install the following:
10 | - [nvidia driver (418+)](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#package-manager-installation),
11 | - [Docker (19.03+)](https://docs.docker.com/engine/install/ubuntu/),
12 | - [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-docker#quickstart).
13 |
14 | To run the docker command without sudo, the user needs to have [docker group membership](https://docs.docker.com/engine/install/linux-postinstall/). Our code only supports Linux with NVIDIA GPUs. We tested our code on Ubuntu 18.04 with V100 cards.
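For example, group membership can typically be granted with the standard Docker post-install step (shown here for convenience; see the linked guide for details):
```
sudo usermod -aG docker $USER
# log out and back in for the group change to take effect
```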
15 |
16 | ## Data and Pretrained Checkpoints
17 | Download the pre-processed text features and pretrained checkpoints with the following command:
18 | ```
19 | wget https://mmaisharables.blob.core.windows.net/uc2/UC2_DATA.tar.gz
20 |
21 | ```
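After downloading, extract the archive into your storage directory. This is a minimal sketch; the archive layout is not documented here, so adjust the target path to match the `/PATH_TO_STORAGE` layout used below:
```
tar -xzf UC2_DATA.tar.gz -C /PATH_TO_STORAGE
```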
22 | The image features for MSCOCO can be obtained from [UNITER](https://github.com/ChenRocks/UNITER) via this [download script](https://github.com/ChenRocks/UNITER/blob/master/scripts/download_itm.sh). As CC's image features are large and inconvenient to download directly, please contact UNITER's authors to obtain them if you are interested in pretraining.
23 |
24 | The text translations of CC captions for pre-training can also be found here:
25 | https://drive.google.com/drive/folders/1K6dGXMz8egutKk-UTXbFrp3HQ9QqPfmM?usp=sharing
26 |
27 | ## Launch the Docker Container for Experiments
28 | Once the data and checkpoints are set up properly, run the following command to launch a Docker container; the pretraining and finetuning commands below are then run inside it.
29 | ```
30 | source launch_container_pretrain.sh /PATH_TO_STORAGE/txt_db /PATH_TO_STORAGE/img_db /PATH_TO_STORAGE/finetune /PATH_TO_STORAGE/pretrain
31 | ```
32 | ## Pretraining
33 | (Inside the Docker container) To run pretraining, please use the following command:
34 | ```
35 | horovodrun -np $N_GPU python pretrain.py --config config/uc2_pretrain.json
36 | ```
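For example, with 8 GPUs on a single node (a hypothetical value; set `N_GPU` to match your hardware):
```
N_GPU=8
horovodrun -np $N_GPU python pretrain.py --config config/uc2_pretrain.json
```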
37 | ## Downstream Task Finetuning
38 | **Text-to-Image Retrieval**
39 | To run the finetuning experiment for the text-to-image retrieval task, please use the following command:
40 | ```
41 | horovodrun -np $N_GPU python itm.py --config config/uc2_mscoco_itm.json
42 | ```
43 |
44 | ## Citation
45 | If you find this code useful for your research, please consider citing:
46 | ```
47 | @InProceedings{zhou2021uc,
48 | author = {Zhou, Mingyang and Zhou, Luowei and Wang, Shuohang and Cheng, Yu and Li, Linjie and Yu, Zhou and Liu, Jingjing},
49 | title = {UC2: Universal Cross-lingual Cross-modal Vision-and-Language Pre-training},
50 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR 2021)},
51 | year = {2021},
52 | month = {June},
53 | abstract = {Vision-and-language pre-training has achieved impressive success in learning multimodal representations between vision and language. To generalize this success to non-English languages, we introduce UC2 , the first machine translation-augmented framework for cross-lingual cross-modal representation learning. To tackle the scarcity problem of multilingual captions for image datasets, we first augment existing English-only datasets with other languages via machine translation (MT). Then we extend the standard Masked Language Modeling and Image-Text Matching training objectives to multilingual setting, where alignment between different languages is captured through shared visual context (i.e., using image as pivot). To facilitate the learning of a joint embedding space of images and all languages of interest, we further propose two novel pre-training tasks, namely Masked Region-to-Token Modeling (MRTM) and Visual Translation Language Modeling (VTLM), leveraging MT-enhanced translated data. Evaluation on multilingual image-text retrieval and multilingual visual question answering benchmarks demonstrates that our proposed framework achieves new state of the art on diverse non-English benchmarks while maintaining comparable performance to monolingual pre-trained models on English tasks.},
54 | url = {https://www.microsoft.com/en-us/research/publication/uc2-universal-cross-lingual-cross-modal-vision-and-language-pre-training/},
55 | }
56 |
57 | ```
58 |
59 | ## Acknowledgement
60 | Our code is mainly based on [Linjie Li](https://www.microsoft.com/en-us/research/people/linjli/) and [Yen-Chun Chen](https://www.microsoft.com/en-us/research/people/yenche/)'s project [UNITER](https://github.com/ChenRocks/UNITER). We thank the authors for open-sourcing their code and for helpful discussions on the implementation. Portions of the code also use resources from [transformers](https://github.com/huggingface/transformers).
61 |
62 | ## License
63 | MIT
64 |
--------------------------------------------------------------------------------
/config/uc2-base.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "hidden_dropout_prob": 0.1,
5 | "hidden_size": 768,
6 | "initializer_range": 0.02,
7 | "intermediate_size": 3072,
8 | "max_position_embeddings": 514,
9 | "num_attention_heads": 12,
10 | "num_hidden_layers": 12,
11 | "model_type": "xlm-roberta",
12 | "output_past": true,
13 | "type_vocab_size": 2,
14 | "layer_norm_eps": 1e-5,
15 | "pad_token_id": 1,
16 | "vocab_size": 250002
17 | }
18 |
--------------------------------------------------------------------------------
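This model config mirrors the XLM-RoBERTa base architecture (`model_type`, `vocab_size` 250002, `pad_token_id` 1). A minimal sketch of loading it, assuming the huggingface `transformers` package that the repository already builds on:

```python
# Sketch only: read the UC2 model config as an XLM-R style config object.
from transformers import XLMRobertaConfig

config = XLMRobertaConfig.from_json_file("config/uc2-base.json")
print(config.hidden_size, config.num_hidden_layers, config.vocab_size)  # 768 12 250002
```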
/config/uc2_mscoco_itm.json:
--------------------------------------------------------------------------------
1 | {
2 | "compressed_db": false,
3 | "checkpoint": "/pretrain/uc2/ckpt/best.pt",
4 | "output_dir": "/storage/uc2_mscoco_all_lan",
5 | "max_txt_len": 60,
6 | "conf_th": 0.2,
7 | "max_bb": 100,
8 | "min_bb": 10,
9 | "num_bb": 36,
10 | "train_batch_size": 40,
11 | "negative_size": 1,
12 | "hard_neg_size": 0,
13 | "hard_neg_pool_size": 20,
14 | "steps_per_hard_neg": -1,
15 | "inf_minibatch_size": 300,
16 | "margin": 0.2,
17 | "gradient_accumulation_steps": 8,
18 | "learning_rate": 0.0001,
19 | "xlmr_lr": 1e-07,
20 | "valid_steps": 10000,
21 | "num_train_steps": 50000,
22 | "optim": "adamw",
23 | "betas": [
24 | 0.9,
25 | 0.98
26 | ],
27 | "decay": "linear",
28 | "dropout": 0.1,
29 | "weight_decay": 0,
30 | "grad_norm": 2.0,
31 | "warmup_steps": 5000,
32 | "seed": 42,
33 | "full_val": false,
34 | "fp16": true,
35 | "n_workers": 4,
36 | "pin_mem": false,
37 | "rename_checkpoints": false,
38 | "separate_lr": false,
39 | "load_embedding_only": false,
40 | "load_layer": 0,
41 | "save_steps": 200,
42 | "train_txt_dbs": [
43 | "/db/uc2_mscoco/itm_coco_en_train.db/",
44 | "/db/uc2_mscoco/itm_coco_en_restval.db/",
45 | "/db/uc2_mscoco/itm_coco_ja_train.db/",
46 | "/db/uc2_mscoco/itm_coco_zh_train.db/"
47 |
48 | ],
49 | "train_img_dbs": [
50 | "/img/coco_train2014/",
51 | "/img/coco_val2014/",
52 | "/img/coco_train2014/",
53 | [
54 | "/img/coco_train2014/",
55 | "/img/coco_val2014/"
56 | ]
57 |
58 | ],
59 | "val_txt_db": [
60 | "/db/uc2_mscoco/itm_coco_en_val.db/"
61 | ],
62 | "val_img_db": [
63 | "/img/coco_val2014/"
64 | ],
65 | "test_txt_db": [
66 | "/db/uc2_mscoco/itm_coco_en_test_1k_0.db/",
67 | "/db/uc2_mscoco/itm_coco_en_test_1k_1.db/",
68 | "/db/uc2_mscoco/itm_coco_en_test_1k_2.db/",
69 | "/db/uc2_mscoco/itm_coco_en_test_1k_3.db/",
70 | "/db/uc2_mscoco/itm_coco_en_test_1k_4.db/",
71 | "/db/uc2_mscoco/itm_coco_ja_test_1k_0.db/",
72 | "/db/uc2_mscoco/itm_coco_ja_test_1k_1.db/",
73 | "/db/uc2_mscoco/itm_coco_ja_test_1k_2.db/",
74 | "/db/uc2_mscoco/itm_coco_ja_test_1k_3.db/",
75 | "/db/uc2_mscoco/itm_coco_ja_test_1k_4.db/",
76 | "/db/uc2_mscoco/itm_coco_zh_test.db/"
77 | ],
78 | "test_img_db": [
79 | "/img/coco_val2014/",
80 | "/img/coco_val2014/",
81 | "/img/coco_val2014/",
82 | "/img/coco_val2014/",
83 | "/img/coco_val2014/",
84 | "/img/coco_val2014/",
85 | "/img/coco_val2014/",
86 | "/img/coco_val2014/",
87 | "/img/coco_val2014/",
88 | "/img/coco_val2014/",
89 | [
90 | "/img/coco_train2014/",
91 | "/img/coco_val2014/"
92 | ]
93 |
94 | ],
95 | "model_config": "config/uc2-base.json",
96 | "mcan": false,
97 | "cut_bert": -1,
98 | "neg_sample_from": "i",
99 | "train_neg_sample_p": 0.67,
100 | "train_neg_sample_size": 1,
101 | "val_minibatch_size": 400,
102 | "test_minibatch_size": 400,
103 | "train_loss": "rank",
104 | "img_rank_weight": 1.0,
105 | "txt_rank_weight": 1.0,
106 | "no_cuda": false,
107 | "local_rank": 0,
108 | "wave_length": 1000,
109 | "val_batch_size": 4096,
110 | "eval_method": "rank",
111 | "rank": 0
112 | }
--------------------------------------------------------------------------------
/config/uc2_pretrain.json:
--------------------------------------------------------------------------------
1 | {
2 | "compressed_db": false,
3 | "model_config": "config/uc2-base.json",
4 | "checkpoint": "/pretrain/xlm-roberta-base/pytorch_model.bin",
5 | "output_dir": "/storage/uc2_pretrain",
6 | "mrm_prob": 0.15,
7 | "neg_size": 128,
8 | "nce_temp": 1.0,
9 | "itm_neg_prob": 0.5,
10 | "itm_ot_lambda": 0.0,
11 | "ot_pos_only": false,
12 | "max_txt_len": 60,
13 | "conf_th": 0.2,
14 | "max_bb": 100,
15 | "min_bb": 10,
16 | "num_bb": 36,
17 | "train_batch_size": 10240,
18 | "val_batch_size": 10240,
19 | "gradient_accumulation_steps": 3,
20 | "learning_rate": 4e-05,
21 | "valid_steps": 5000,
22 | "num_train_steps": 200000,
23 | "optim": "adamw",
24 | "betas": [
25 | 0.9,
26 | 0.98
27 | ],
28 | "decay": "linear",
29 | "dropout": 0.1,
30 | "weight_decay": 0.01,
31 | "grad_norm": 5.0,
32 | "warmup_steps": 10000,
33 | "seed": 42,
34 | "fp16": true,
35 | "n_workers": 4,
36 | "pin_mem": true,
37 | "rename_checkpoints": false,
38 | "itm_hard_neg": false,
39 | "co_masking": true,
40 | "co_masking_mode": "mix",
41 | "save_steps": 200,
42 | "early_adaptation": true,
43 | "early_adaptation_checkpoint": "/pretrain/early_adaptation_multilingual/ckpt/best.pt",
44 | "multilingual_vmlm": true,
45 | "train_datasets": [
46 | {
47 | "name": "cc",
48 | "db": [
49 | "/db/uc2_pretrain/conceptual_caption_en_train.db",
50 | "/db/uc2_pretrain/conceptual_caption_de_train.db",
51 | "/db/uc2_pretrain/conceptual_caption_zh_trian.db",
52 | "/db/uc2_pretrain/conceptual_caption_ja_train.db",
53 | "/db/uc2_pretrain/conceptual_caption_fr_train.db",
54 | "/db/uc2_pretrain/conceptual_caption_cs_train.db"
55 | ],
56 | "img": [
57 | "/img/gcc_train",
58 | "/img/gcc_train",
59 | "/img/gcc_train",
60 | "/img/gcc_train",
61 | "/img/gcc_train",
62 | "/img/gcc_train"
63 | ],
64 | "tasks": [
65 | "itm",
66 | "mlm",
67 | "vmlm"
68 | ],
69 | "img_token_soft_label": [
70 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_train/train_img_token_soft_label_batch.db"
71 | ],
72 | "mix_ratio": [
73 | 9,
74 | 12,
75 | 9
76 | ]
77 | },
78 | {
79 | "name": "cc_paired",
80 | "db": [
81 | "/db/uc2_pretrain/conceptual_caption_ende_train.db/",
82 | "/db/uc2_pretrain/conceptual_caption_enzh_train.db/",
83 | "/db/uc2_pretrain/conceptual_caption_enja_train.db/",
84 | "/db/uc2_pretrain/conceptual_caption_encs_train.db/",
85 | "/db/uc2_pretrain/conceptual_caption_enfr_train.db/"
86 | ],
87 | "img": [
88 | "/img/gcc_train",
89 | "/img/gcc_train",
90 | "/img/gcc_train",
91 | "/img/gcc_train",
92 | "/img/gcc_train"
93 | ],
94 | "tasks": [
95 | "tlm"
96 | ],
97 | "img_token_soft_label": [
98 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_train/train_img_token_soft_label_batch.db"
99 | ],
100 | "mix_ratio": [
101 | 3
102 | ]
103 | }
104 | ],
105 | "val_datasets": [
106 | {
107 | "name": "cc_en",
108 | "db": [
109 | "/db/uc2_pretrain/conceptual_caption_en_val.db/"
110 | ],
111 | "img": [
112 | "/img/gcc_val"
113 | ],
114 | "img_token_soft_label": [
115 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_val/val_img_token_soft_label_batch.db"
116 | ],
117 | "tasks": [
118 | "itm",
119 | "mlm",
120 | "vmlm"
121 | ]
122 | },
123 | {
124 | "name": "cc_de",
125 | "db": [
126 | "/db/uc2_pretrain/conceptual_caption_de_val.db/"
127 | ],
128 | "img": [
129 | "/img/gcc_val"
130 | ],
131 | "img_token_soft_label": [
132 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_val/val_img_token_soft_label_batch.db"
133 | ],
134 | "tasks": [
135 | "itm",
136 | "mlm",
137 | "vmlm"
138 | ]
139 | },
140 | {
141 | "name": "cc_zh",
142 | "db": [
143 | "/db/uc2_pretrain/conceptual_caption_zh_val.db/"
144 | ],
145 | "img": [
146 | "/img/gcc_val"
147 | ],
148 | "img_token_soft_label": [
149 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_val/val_img_token_soft_label_batch.db"
150 | ],
151 | "tasks": [
152 | "itm",
153 | "mlm",
154 | "vmlm"
155 | ]
156 | },
157 | {
158 | "name": "cc_ja",
159 | "db": [
160 | "/db/uc2_pretrain/conceptual_caption_ja_val.db/"
161 | ],
162 | "img": [
163 | "/img/gcc_val"
164 | ],
165 | "img_token_soft_label": [
166 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_val/val_img_token_soft_label_batch.db"
167 | ],
168 | "tasks": [
169 | "itm",
170 | "mlm",
171 | "vmlm"
172 | ]
173 | },
174 | {
175 | "name": "cc_fr",
176 | "db": [
177 | "/db/uc2_pretrain/conceptual_caption_fr_val.db/"
178 | ],
179 | "img": [
180 | "/img/gcc_val"
181 | ],
182 | "img_token_soft_label": [
183 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_val/val_img_token_soft_label_batch.db"
184 | ],
185 | "tasks": [
186 | "itm",
187 | "mlm",
188 | "vmlm"
189 | ]
190 | },
191 | {
192 | "name": "cc_cs",
193 | "db": [
194 | "/db/uc2_pretrain/conceptual_caption_cs_val.db/"
195 | ],
196 | "img": [
197 | "/img/gcc_val"
198 | ],
199 | "img_token_soft_label": [
200 | "/data/UNITER_DATA/UNITER_CC_Dataset/IMG_DB/gcc_val/val_img_token_soft_label_batch.db"
201 | ],
202 | "tasks": [
203 | "itm",
204 | "mlm",
205 | "vmlm"
206 | ]
207 | }
208 | ],
209 | "ans2label": "",
210 | "rank": 0
211 | }
--------------------------------------------------------------------------------
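The README runs `python pretrain.py --config config/uc2_pretrain.json`, and this config in turn points at `config/uc2-base.json` via `model_config`. A minimal sketch of inspecting that chain with the standard library (paths taken from the config above):

```python
# Sketch: follow the pretraining config to the model config and list the training tasks.
import json

with open("config/uc2_pretrain.json") as f:
    pretrain_cfg = json.load(f)

with open(pretrain_cfg["model_config"]) as f:  # -> config/uc2-base.json
    model_cfg = json.load(f)

print(pretrain_cfg["train_batch_size"], pretrain_cfg["num_train_steps"])
print(model_cfg["hidden_size"], model_cfg["vocab_size"])
for d in pretrain_cfg["train_datasets"]:
    print(d["name"], d["tasks"], d["mix_ratio"])
```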
/config/uniter-base-cased.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "hidden_dropout_prob": 0.1,
5 | "hidden_size": 768,
6 | "initializer_range": 0.02,
7 | "intermediate_size": 3072,
8 | "max_position_embeddings": 512,
9 | "num_attention_heads": 12,
10 | "num_hidden_layers": 12,
11 | "type_vocab_size": 2,
12 | "vocab_size": 28996
13 | }
14 |
--------------------------------------------------------------------------------
/config/uniter-base.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "hidden_dropout_prob": 0.1,
5 | "hidden_size": 768,
6 | "initializer_range": 0.02,
7 | "intermediate_size": 3072,
8 | "max_position_embeddings": 512,
9 | "num_attention_heads": 12,
10 | "num_hidden_layers": 12,
11 | "type_vocab_size": 2,
12 | "vocab_size": 30522
13 | }
14 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import (TxtTokLmdb, DetectFeatLmdb, Img_SoftLabel_Lmdb,
2 | ConcatDatasetWithLens, ImageLmdbGroup, DetectFeatTxtTokDataset, get_ids_and_lens, pad_tensors, get_gather_index)
3 | from .mlm import (MlmDataset, MlmDataset_VLXLMR, VmlmDataset, Vmlm_Softlabel_Dataset, MmxlmDataset, Mmxlm_Softlabel_Dataset, MlmDataset_Dmasking, MlmEvalDataset,
4 | BlindMlmDataset, BlindMlmEvalDataset,
5 | mlm_collate, xlmr_mlm_collate, xlmr_mlm_dmasking_collate, xlmr_tlm_ni_dmasking_collate, xlmr_mmxlm_collate, xlmr_mmxlm_softlabel_collate, mlm_eval_collate,
6 | mlm_blind_collate, mlm_blind_eval_collate)
7 | from .mrm import (MrfrDataset, OnlyImgMrfrDataset,
8 | MrcDataset, OnlyImgMrcDataset,
9 | mrfr_collate, xlmr_mrfr_collate, mrfr_only_img_collate,
10 | mrc_collate, xlmr_mrc_collate, mrc_only_img_collate, _mask_img_feat, _get_targets)
11 | from .itm import (TokenBucketSamplerForItm,
12 | ItmDataset, ItmDataset_HardNeg, itm_collate, itm_ot_collate, xlmr_itm_collate, xlmr_itm_ot_collate,
13 | ItmRankDataset, ItmRankDataset_COCO_CN, ItmRankDatasetHardNeg, itm_rank_collate,
14 | ItmRankDatasetHardNegFromText,
15 | ItmRankDatasetHardNegFromImage, itm_rank_hnv2_collate,
16 | ItmHardNegDataset, itm_hn_collate,
17 | ItmValDataset, ItmValDataset_COCO_CN, itm_val_collate,
18 | ItmEvalDataset, ItmEvalDataset_COCO_CN, itm_eval_collate, xlmr_itm_rank_collate)
19 | from .sampler import TokenBucketSampler, DistributedSampler
20 | from .loader import MetaLoader, PrefetchLoader
21 |
22 | from .vqa import VqaDataset, vqa_collate, VqaEvalDataset, vqa_eval_collate, xlmr_vqa_collate, xlmr_vqa_eval_collate
23 | from .nlvr2 import (Nlvr2PairedDataset, nlvr2_paired_collate,
24 | Nlvr2PairedEvalDataset, nlvr2_paired_eval_collate,
25 | Nlvr2TripletDataset, nlvr2_triplet_collate,
26 | Nlvr2TripletEvalDataset, nlvr2_triplet_eval_collate)
27 | from .ve import VeDataset, ve_collate, VeEvalDataset, ve_eval_collate
28 |
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/data.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/data.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/itm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/itm.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/loader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/mlm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/mlm.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/mrm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/mrm.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/mrm_nce.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/mrm_nce.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/nlvr2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/nlvr2.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/sampler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/sampler.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/ve.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/ve.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/vqa.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/__pycache__/vqa.cpython-36.pyc
--------------------------------------------------------------------------------
/data/loader.py:
--------------------------------------------------------------------------------
1 | """
2 | A meta data loader for sampling from different datasets / training tasks
3 | A prefetch loader to speedup data loading
4 | """
5 | import random
6 |
7 | import torch
8 | from torch.utils.data import DataLoader
9 |
10 | from utils.distributed import any_broadcast
11 |
12 |
13 | class MetaLoader(object):
14 | """ wraps multiple data loader """
15 | def __init__(self, loaders, accum_steps=1, distributed=False):
16 | assert isinstance(loaders, dict)
17 | self.name2loader = {}
18 | self.name2iter = {}
19 | self.sampling_pools = []
20 | for n, l in loaders.items():
21 | if isinstance(l, tuple):
22 | l, r = l
23 | elif isinstance(l, DataLoader):
24 | r = 1
25 | else:
26 | raise ValueError()
27 | # print(n)
28 | # print(l)
29 | self.name2loader[n] = l
30 | self.name2iter[n] = iter(l)
31 | self.sampling_pools.extend([n]*r)
32 |
33 | self.accum_steps = accum_steps
34 | self.distributed = distributed
35 | self.step = 0
36 |
37 | def __iter__(self):
38 | """ this iterator will run indefinitely """
39 | task = self.sampling_pools[0]
40 | while True:
41 | if self.step % self.accum_steps == 0:
42 | task = random.choice(self.sampling_pools)
43 | if self.distributed:
44 |                     # make sure all processes are training the same task
45 | task = any_broadcast(task, 0)
46 | self.step += 1
47 | iter_ = self.name2iter[task]
48 | try:
49 | batch = next(iter_)
50 | except StopIteration:
51 | iter_ = iter(self.name2loader[task])
52 | batch = next(iter_)
53 | self.name2iter[task] = iter_
54 |
55 | yield task, batch
56 |
57 |
58 | def move_to_cuda(batch):
59 | if isinstance(batch, torch.Tensor):
60 | return batch.cuda(non_blocking=True)
61 | elif isinstance(batch, list):
62 | new_batch = [move_to_cuda(t) for t in batch]
63 | elif isinstance(batch, tuple):
64 | new_batch = tuple(move_to_cuda(t) for t in batch)
65 | elif isinstance(batch, dict):
66 | new_batch = {n: move_to_cuda(t) for n, t in batch.items()}
67 | else:
68 | return batch
69 | return new_batch
70 |
71 |
72 | def record_cuda_stream(batch):
73 | if isinstance(batch, torch.Tensor):
74 | batch.record_stream(torch.cuda.current_stream())
75 | elif isinstance(batch, list) or isinstance(batch, tuple):
76 | for t in batch:
77 | record_cuda_stream(t)
78 | elif isinstance(batch, dict):
79 | for t in batch.values():
80 | record_cuda_stream(t)
81 | else:
82 | pass
83 |
84 |
85 | class PrefetchLoader(object):
86 | """
87 | overlap compute and cuda data transfer
88 | (copied and then modified from nvidia apex)
89 | """
90 | def __init__(self, loader):
91 | self.loader = loader
92 | self.stream = torch.cuda.Stream()
93 |
94 | def __iter__(self):
95 | loader_it = iter(self.loader)
96 | self.preload(loader_it)
97 | batch = self.next(loader_it)
98 | while batch is not None:
99 | yield batch
100 | batch = self.next(loader_it)
101 |
102 | def __len__(self):
103 | return len(self.loader)
104 |
105 | def preload(self, it):
106 | try:
107 | self.batch = next(it)
108 | except StopIteration:
109 | self.batch = None
110 | return
111 | # if record_stream() doesn't work, another option is to make sure
112 | # device inputs are created on the main stream.
113 | # self.next_input_gpu = torch.empty_like(self.next_input,
114 | # device='cuda')
115 | # self.next_target_gpu = torch.empty_like(self.next_target,
116 | # device='cuda')
117 | # Need to make sure the memory allocated for next_* is not still in use
118 | # by the main stream at the time we start copying to next_*:
119 | # self.stream.wait_stream(torch.cuda.current_stream())
120 | with torch.cuda.stream(self.stream):
121 | self.batch = move_to_cuda(self.batch)
122 | # more code for the alternative if record_stream() doesn't work:
123 | # copy_ will record the use of the pinned source tensor in this
124 | # side stream.
125 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
126 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
127 | # self.next_input = self.next_input_gpu
128 | # self.next_target = self.next_target_gpu
129 |
130 | def next(self, it):
131 | torch.cuda.current_stream().wait_stream(self.stream)
132 | batch = self.batch
133 | if batch is not None:
134 | record_cuda_stream(batch)
135 | self.preload(it)
136 | return batch
137 |
138 | def __getattr__(self, name):
139 | method = self.loader.__getattribute__(name)
140 | return method
141 |
--------------------------------------------------------------------------------
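A minimal sketch of how `MetaLoader` mixes tasks, using hypothetical toy loaders (inside the provided Docker image the real task loaders from `data/` would be passed in instead; `PrefetchLoader` additionally needs a CUDA device):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from data.loader import MetaLoader

# two toy "task" loaders standing in for e.g. the mlm / itm dataloaders
mlm_loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)
itm_loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)

# name -> (loader, mix_ratio); a new task is sampled every `accum_steps` steps
meta = MetaLoader({'mlm': (mlm_loader, 2), 'itm': (itm_loader, 1)},
                  accum_steps=2, distributed=False)

for step, (task, batch) in enumerate(meta):
    print(step, task)   # 'mlm' is sampled roughly twice as often as 'itm'
    if step == 7:       # the iterator runs indefinitely, so break manually
        break
```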
/data/mrm.py:
--------------------------------------------------------------------------------
1 | """
2 | MRM Datasets
3 | """
4 | import random
5 |
6 | import torch
7 | from torch.utils.data import Dataset
8 | from torch.nn.utils.rnn import pad_sequence
9 | from toolz.sandbox import unzip
10 | from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index
11 |
12 |
13 | def _get_img_mask(mask_prob, num_bb):
14 | img_mask = [random.random() < mask_prob for _ in range(num_bb)]
15 | if not any(img_mask):
16 | # at least mask 1
17 | img_mask[random.choice(range(num_bb))] = True
18 | img_mask = torch.tensor(img_mask)
19 | return img_mask
20 |
21 |
22 | def _get_img_tgt_mask(img_mask, txt_len):
23 | z = torch.zeros(txt_len, dtype=torch.uint8)
24 | img_mask_tgt = torch.cat([z, img_mask], dim=0)
25 | return img_mask_tgt
26 |
27 |
28 | def _get_feat_target(img_feat, img_masks):
29 | img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) # (n, m, d)
30 | feat_dim = img_feat.size(-1)
31 | feat_targets = img_feat[img_masks_ext].contiguous().view(
32 | -1, feat_dim) # (s, d)
33 | return feat_targets
34 |
35 |
36 | def _mask_img_feat(img_feat, img_masks):
37 | img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat)
38 | img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0)
39 | return img_feat_masked
40 |
41 |
42 | class MrfrDataset(DetectFeatTxtTokDataset):
43 | def __init__(self, mask_prob, *args, **kwargs):
44 | super().__init__(*args, **kwargs)
45 | self.mask_prob = mask_prob
46 |
47 | def __getitem__(self, i):
48 | """
49 | Return:
50 | - input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
51 | - img_feat : (num_bb, d)
52 | - img_pos_feat : (num_bb, 7)
53 | - attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1]
54 | - img_mask : (num_bb, ) between {0, 1}
55 | """
56 | example = super().__getitem__(i)
57 | # text input
58 | input_ids = example['input_ids']
59 | input_ids = self.txt_db.combine_inputs(input_ids)
60 |
61 | # image input features
62 | img_feat, img_pos_feat, num_bb = self._get_img_feat(
63 | example['img_fname'])
64 | img_mask = _get_img_mask(self.mask_prob, num_bb)
65 | img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids))
66 |
67 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
68 |
69 | return (input_ids, img_feat, img_pos_feat,
70 | attn_masks, img_mask, img_mask_tgt)
71 |
72 | #Added by Mingyang Zhou
73 | def xlmr_mrfr_collate(inputs):
74 | """
75 | Return:
76 | - input_ids : (n, max_L), i.e., [cls, wd, wd, ..., sep, 1, 1], 1s padded
77 | - position_ids : (n, max_L)
78 | - txt_lens : list of [input_len]
79 | - img_feat : (n, max_num_bb, d)
80 | - img_pos_feat : (n, max_num_bb, 7)
81 | - num_bbs : list of [num_bb]
82 | - attn_masks : (n, max_{L + num_bb}), ie., [1, 1, ..., 0, 0, 1, 1]
83 | - img_masks : (n, max_num_bb) between {0, 1}
84 | """
85 | (input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts,
86 | ) = map(list, unzip(inputs))
87 |
88 | txt_lens = [i.size(0) for i in input_ids]
89 |
90 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=1)
91 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
92 | ).unsqueeze(0)
93 |
94 | num_bbs = [f.size(0) for f in img_feats]
95 | img_feat = pad_tensors(img_feats, num_bbs)
96 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
97 |
98 | # mask features
99 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
100 | feat_targets = _get_feat_target(img_feat, img_masks)
101 | img_feat = _mask_img_feat(img_feat, img_masks)
102 | img_mask_tgt = pad_sequence(img_mask_tgts,
103 | batch_first=True, padding_value=0)
104 |
105 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
106 | bs, max_tl = input_ids.size()
107 | out_size = attn_masks.size(1)
108 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
109 |
110 | batch = {'input_ids': input_ids,
111 | 'position_ids': position_ids,
112 | 'img_feat': img_feat,
113 | 'img_pos_feat': img_pos_feat,
114 | 'attn_masks': attn_masks,
115 | 'gather_index': gather_index,
116 | 'feat_targets': feat_targets,
117 | 'img_masks': img_masks,
118 | 'img_mask_tgt': img_mask_tgt}
119 | return batch
120 |
121 | def mrfr_collate(inputs):
122 | """
123 | Return:
124 | - input_ids : (n, max_L), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded
125 | - position_ids : (n, max_L)
126 | - txt_lens : list of [input_len]
127 | - img_feat : (n, max_num_bb, d)
128 | - img_pos_feat : (n, max_num_bb, 7)
129 | - num_bbs : list of [num_bb]
130 | - attn_masks : (n, max_{L + num_bb}), ie., [1, 1, ..., 0, 0, 1, 1]
131 | - img_masks : (n, max_num_bb) between {0, 1}
132 | """
133 | (input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts,
134 | ) = map(list, unzip(inputs))
135 |
136 | txt_lens = [i.size(0) for i in input_ids]
137 |
138 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
139 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
140 | ).unsqueeze(0)
141 |
142 | num_bbs = [f.size(0) for f in img_feats]
143 | img_feat = pad_tensors(img_feats, num_bbs)
144 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
145 |
146 | # mask features
147 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
148 | feat_targets = _get_feat_target(img_feat, img_masks)
149 | img_feat = _mask_img_feat(img_feat, img_masks)
150 | img_mask_tgt = pad_sequence(img_mask_tgts,
151 | batch_first=True, padding_value=0)
152 |
153 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
154 | bs, max_tl = input_ids.size()
155 | out_size = attn_masks.size(1)
156 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
157 |
158 | batch = {'input_ids': input_ids,
159 | 'position_ids': position_ids,
160 | 'img_feat': img_feat,
161 | 'img_pos_feat': img_pos_feat,
162 | 'attn_masks': attn_masks,
163 | 'gather_index': gather_index,
164 | 'feat_targets': feat_targets,
165 | 'img_masks': img_masks,
166 | 'img_mask_tgt': img_mask_tgt}
167 | return batch
168 |
169 |
170 | class OnlyImgMrfrDataset(Dataset):
171 | """ an image-only MRM """
172 |     def __init__(self, mask_prob, img_db):
173 |         # keep the mask probability and image feature db; both are used in __getitem__
174 |         self.mask_prob = mask_prob
175 |         self.img_db = img_db
176 |         self.ids, self.lens = map(list, unzip(img_db.name2nbb.items()))
174 |
175 | def __getitem__(self, i):
176 | id_ = self.ids[i]
177 | img_feat, img_pos_feat, num_bb = self._get_img_feat(id_)
178 | attn_masks = torch.ones(num_bb, dtype=torch.long)
179 | img_mask = _get_img_mask(self.mask_prob, num_bb)
180 |
181 | return img_feat, img_pos_feat, attn_masks, img_mask
182 |
183 | def _get_img_feat(self, fname):
184 | img_feat, bb = self.img_db[fname]
185 | img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
186 | num_bb = img_feat.size(0)
187 | return img_feat, img_bb, num_bb
188 |
189 |
190 | def mrfr_only_img_collate(inputs):
191 | img_feats, img_pos_feats, attn_masks, img_masks = map(list, unzip(inputs))
192 |
193 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
194 |
195 | num_bbs = [f.size(0) for f in img_feats]
196 | img_feat = pad_tensors(img_feats, num_bbs)
197 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
198 |
199 | # mask features
200 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
201 | feat_targets = _get_feat_target(img_feat, img_masks)
202 | img_feat = _mask_img_feat(img_feat, img_masks)
203 |
204 | batch = {'img_feat': img_feat,
205 | 'img_pos_feat': img_pos_feat,
206 | 'attn_masks': attn_masks,
207 | 'feat_targets': feat_targets,
208 | 'img_masks': img_masks,
209 | 'img_mask_tgt': img_masks}
210 | return batch
211 |
212 |
213 | def _get_targets(img_masks, img_soft_label):
214 | soft_label_dim = img_soft_label.size(-1)
215 | img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label)
216 | label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view(
217 | -1, soft_label_dim)
218 | return label_targets
219 |
220 |
221 | class MrcDataset(DetectFeatTxtTokDataset):
222 | def __init__(self, mask_prob, *args, **kwargs):
223 | super().__init__(*args, **kwargs)
224 | self.mask_prob = mask_prob
225 |
226 | def _get_img_feat(self, fname):
227 | img_dump = self.img_db.get_dump(fname)
228 | num_bb = self.img_db.name2nbb[fname]
229 | img_feat = torch.tensor(img_dump['features'])
230 | bb = torch.tensor(img_dump['norm_bb'])
231 | img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
232 | img_soft_label = torch.tensor(img_dump['soft_labels'])
233 | return img_feat, img_bb, img_soft_label, num_bb
234 |
235 | def __getitem__(self, i):
236 | example = super().__getitem__(i)
237 | img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat(
238 | example['img_fname'])
239 |
240 | # image input features
241 | img_mask = _get_img_mask(self.mask_prob, num_bb)
242 |
243 | # text input
244 | input_ids = example['input_ids']
245 | input_ids = self.txt_db.combine_inputs(input_ids)
246 | img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids))
247 |
248 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
249 |
250 | return (input_ids, img_feat, img_pos_feat,
251 | img_soft_labels, attn_masks, img_mask, img_mask_tgt)
252 |
253 | def xlmr_mrc_collate(inputs):
254 | (input_ids, img_feats, img_pos_feats, img_soft_labels,
255 | attn_masks, img_masks, img_mask_tgts) = map(list, unzip(inputs))
256 |
257 | txt_lens = [i.size(0) for i in input_ids]
258 | num_bbs = [f.size(0) for f in img_feats]
259 |
260 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=1)
261 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
262 | ).unsqueeze(0)
263 |
264 | img_feat = pad_tensors(img_feats, num_bbs)
265 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
266 | img_soft_label = pad_tensors(img_soft_labels, num_bbs)
267 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
268 | label_targets = _get_targets(img_masks, img_soft_label)
269 |
270 | img_feat = _mask_img_feat(img_feat, img_masks)
271 | img_mask_tgt = pad_sequence(img_mask_tgts,
272 | batch_first=True, padding_value=0)
273 |
274 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
275 | bs, max_tl = input_ids.size()
276 | out_size = attn_masks.size(1)
277 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
278 |
279 | batch = {'input_ids': input_ids,
280 | 'position_ids': position_ids,
281 | 'img_feat': img_feat,
282 | 'img_pos_feat': img_pos_feat,
283 | 'attn_masks': attn_masks,
284 | 'gather_index': gather_index,
285 | 'img_masks': img_masks,
286 | 'img_mask_tgt': img_mask_tgt,
287 | 'label_targets': label_targets}
288 | return batch
289 |
290 | def mrc_collate(inputs):
291 | (input_ids, img_feats, img_pos_feats, img_soft_labels,
292 | attn_masks, img_masks, img_mask_tgts) = map(list, unzip(inputs))
293 |
294 | txt_lens = [i.size(0) for i in input_ids]
295 | num_bbs = [f.size(0) for f in img_feats]
296 |
297 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
298 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
299 | ).unsqueeze(0)
300 |
301 | img_feat = pad_tensors(img_feats, num_bbs)
302 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
303 | img_soft_label = pad_tensors(img_soft_labels, num_bbs)
304 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
305 | label_targets = _get_targets(img_masks, img_soft_label)
306 |
307 | img_feat = _mask_img_feat(img_feat, img_masks)
308 | img_mask_tgt = pad_sequence(img_mask_tgts,
309 | batch_first=True, padding_value=0)
310 |
311 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
312 | bs, max_tl = input_ids.size()
313 | out_size = attn_masks.size(1)
314 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
315 |
316 | batch = {'input_ids': input_ids,
317 | 'position_ids': position_ids,
318 | 'img_feat': img_feat,
319 | 'img_pos_feat': img_pos_feat,
320 | 'attn_masks': attn_masks,
321 | 'gather_index': gather_index,
322 | 'img_masks': img_masks,
323 | 'img_mask_tgt': img_mask_tgt,
324 | 'label_targets': label_targets}
325 | return batch
326 |
327 |
328 | class OnlyImgMrcDataset(OnlyImgMrfrDataset):
329 | """ an image-only MRC """
330 | def __getitem__(self, i):
331 | id_ = self.ids[i]
332 | (img_feat, img_pos_feat, img_soft_labels, num_bb
333 | ) = self._get_img_feat(id_)
334 | attn_masks = torch.ones(num_bb, dtype=torch.long)
335 | img_mask = _get_img_mask(self.mask_prob, num_bb)
336 |
337 | return img_feat, img_pos_feat, img_soft_labels, attn_masks, img_mask
338 |
339 | def _get_img_feat(self, fname):
340 | img_dump = self.img_db.get_dump(fname)
341 | num_bb = self.img_db.name2nbb[fname]
342 | img_feat = torch.tensor(img_dump['features'])
343 | bb = torch.tensor(img_dump['norm_bb'])
344 | img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
345 | img_soft_labels = torch.tensor(img_dump['soft_labels'])
346 | return img_feat, img_bb, img_soft_labels, num_bb
347 |
348 |
349 | def mrc_only_img_collate(inputs):
350 | (img_feats, img_pos_feats, img_soft_labels, attn_masks, img_masks
351 | ) = map(list, unzip(inputs))
352 |
353 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
354 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
355 | num_bbs = [f.size(0) for f in img_feats]
356 |
357 | img_feat = pad_tensors(img_feats, num_bbs)
358 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
359 | img_soft_label = pad_tensors(img_soft_labels, num_bbs)
360 | label_targets = _get_targets(img_masks, img_soft_label)
361 |
362 | # mask features
363 | img_feat = _mask_img_feat(img_feat, img_masks)
364 |
365 | batch = {'img_feat': img_feat,
366 | 'img_pos_feat': img_pos_feat,
367 | 'attn_masks': attn_masks,
368 | 'img_masks': img_masks,
369 | 'img_mask_tgt': img_masks,
370 | 'label_targets': label_targets}
371 | return batch
372 |
--------------------------------------------------------------------------------
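A minimal sketch of the region-masking helpers above on toy tensors (hypothetical sizes; in training, the collate functions feed padded batches built from the LMDB datasets):

```python
import torch
from data.mrm import _get_img_mask, _get_feat_target, _mask_img_feat

num_bb, feat_dim = 4, 2048                               # hypothetical region count / feature size
img_feat = torch.randn(1, num_bb, feat_dim)              # (n, max_num_bb, d)
img_masks = _get_img_mask(0.15, num_bb).unsqueeze(0)     # (n, max_num_bb), at least one region masked

feat_targets = _get_feat_target(img_feat, img_masks)     # (s, d): features of the masked regions
img_feat_in = _mask_img_feat(img_feat, img_masks)        # masked regions zeroed out in the model input

assert img_feat_in[img_masks].abs().sum() == 0
assert feat_targets.size(0) == int(img_masks.sum())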
/data/mrm_nce.py:
--------------------------------------------------------------------------------
1 | """
2 | MRM Datasets (contrastive learning version)
3 | """
4 | import torch
5 | from torch.nn.utils.rnn import pad_sequence
6 | from toolz.sandbox import unzip
7 | from cytoolz import curry
8 |
9 | from .data import (DetectFeatLmdb, DetectFeatTxtTokDataset,
10 | pad_tensors, get_gather_index)
11 | from .mrm import _get_img_mask, _get_img_tgt_mask, _get_feat_target
12 | from .itm import sample_negative
13 |
14 |
15 | # FIXME diff implementation from mrfr, mrc
16 | def _mask_img_feat(img_feat, img_masks, neg_feats,
17 | noop_prob=0.1, change_prob=0.1):
18 | rand = torch.rand(*img_masks.size())
19 | noop_mask = rand < noop_prob
20 | change_mask = ~noop_mask & (rand < (noop_prob+change_prob)) & img_masks
21 | img_masks_in = img_masks & ~noop_mask & ~change_mask
22 |
23 | img_masks_ext = img_masks_in.unsqueeze(-1).expand_as(img_feat)
24 | img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0)
25 |
26 | n_neg = change_mask.sum().item()
27 | feat_dim = neg_feats.size(-1)
28 | index = torch.arange(0, change_mask.numel(), dtype=torch.long
29 | ).masked_select(change_mask.view(-1))
30 | index = index.unsqueeze(-1).expand(-1, feat_dim)
31 | img_feat_out = img_feat_masked.view(-1, feat_dim).scatter(
32 | dim=0, index=index, src=neg_feats[:n_neg]).view(*img_feat.size())
33 |
34 | return img_feat_out, img_masks_in
35 |
36 |
37 | class MrmNceDataset(DetectFeatTxtTokDataset):
38 | def __init__(self, mask_prob, *args, **kwargs):
39 | super().__init__(*args, **kwargs)
40 | self.mask_prob = mask_prob
41 |
42 | def __getitem__(self, i):
43 | example = super().__getitem__(i)
44 | # text input
45 | input_ids = example['input_ids']
46 | input_ids = self.txt_db.combine_inputs(input_ids)
47 |
48 | # image input features
49 | img_feat, img_pos_feat, num_bb = self._get_img_feat(
50 | example['img_fname'])
51 | img_mask = _get_img_mask(self.mask_prob, num_bb)
52 | img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids))
53 |
54 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
55 |
56 | return (input_ids, img_feat, img_pos_feat,
57 | attn_masks, img_mask, img_mask_tgt,
58 | example['img_fname'])
59 |
60 |
61 | class NegativeImageSampler(object):
62 | def __init__(self, img_dbs, neg_size, size_mul=8):
63 | if not isinstance(img_dbs, list):
64 | assert isinstance(img_dbs, DetectFeatLmdb)
65 | img_dbs = [img_dbs]
66 | self.neg_size = neg_size
67 | self.img_db = JoinedDetectFeatLmdb(img_dbs)
68 | all_imgs = []
69 | for db in img_dbs:
70 | all_imgs.extend(db.name2nbb.keys())
71 | self.all_imgs = all_imgs
72 |
73 | def sample_negative_feats(self, pos_imgs):
74 | neg_img_ids = sample_negative(self.all_imgs, pos_imgs, self.neg_size)
75 | all_neg_feats = torch.cat([self.img_db[img][0] for img in neg_img_ids],
76 | dim=0)
77 | # only use multiples of 8 for tensorcores
78 | n_cut = all_neg_feats.size(0) % 8
79 | if n_cut != 0:
80 | return all_neg_feats[:-n_cut]
81 | else:
82 | return all_neg_feats
83 |
84 |
85 | class JoinedDetectFeatLmdb(object):
86 | def __init__(self, img_dbs):
87 | assert all(isinstance(db, DetectFeatLmdb) for db in img_dbs)
88 | self.img_dbs = img_dbs
89 |
90 | def __getitem__(self, file_name):
91 | for db in self.img_dbs:
92 | if file_name in db:
93 | return db[file_name]
94 |         raise ValueError("image does not exist")
95 |
96 |
97 | @curry
98 | def mrm_nce_collate(neg_sampler, inputs):
99 | (input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts,
100 | positive_imgs) = map(list, unzip(inputs))
101 |
102 | txt_lens = [i.size(0) for i in input_ids]
103 |
104 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
105 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
106 | ).unsqueeze(0)
107 |
108 | num_bbs = [f.size(0) for f in img_feats]
109 | img_feat = pad_tensors(img_feats, num_bbs)
110 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
111 | neg_feats = neg_sampler.sample_negative_feats(positive_imgs)
112 |
113 | # mask features
114 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0)
115 | feat_targets = _get_feat_target(img_feat, img_masks)
116 | img_feat, img_masks_in = _mask_img_feat(img_feat, img_masks, neg_feats)
117 | img_mask_tgt = pad_sequence(img_mask_tgts,
118 | batch_first=True, padding_value=0)
119 |
120 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
121 | bs, max_tl = input_ids.size()
122 | out_size = attn_masks.size(1)
123 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
124 |
125 | batch = {'input_ids': input_ids,
126 | 'position_ids': position_ids,
127 | 'img_feat': img_feat,
128 | 'img_pos_feat': img_pos_feat,
129 | 'attn_masks': attn_masks,
130 | 'gather_index': gather_index,
131 | 'feat_targets': feat_targets,
132 | 'img_masks': img_masks,
133 | 'img_masks_in': img_masks_in,
134 | 'img_mask_tgt': img_mask_tgt,
135 | 'neg_feats': neg_feats}
136 | return batch
137 |
--------------------------------------------------------------------------------
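A minimal sketch of the contrastive variant of `_mask_img_feat` (toy tensors, hypothetical sizes): most masked regions are zeroed out, while a small fraction is left untouched or swapped with a feature drawn from the negative pool:

```python
import torch
from data.mrm import _get_img_mask
from data.mrm_nce import _mask_img_feat

num_bb, feat_dim = 6, 2048
img_feat = torch.randn(1, num_bb, feat_dim)
img_masks = _get_img_mask(0.3, num_bb).unsqueeze(0)   # (1, num_bb), bool
neg_feats = torch.randn(16, feat_dim)                 # pool of negative region features

img_feat_in, img_masks_in = _mask_img_feat(img_feat, img_masks, neg_feats)
# img_masks_in marks only the regions that were actually zeroed out
assert img_feat_in[img_masks_in].abs().sum() == 0
```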
/data/nlvr2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | NLVR2 dataset
6 | """
7 | import copy
8 |
9 | import torch
10 | from torch.nn.utils.rnn import pad_sequence
11 | from toolz.sandbox import unzip
12 | from cytoolz import concat
13 |
14 | from .data import (DetectFeatTxtTokDataset, TxtTokLmdb, DetectFeatLmdb,
15 | get_ids_and_lens, pad_tensors, get_gather_index)
16 |
17 |
18 | class Nlvr2PairedDataset(DetectFeatTxtTokDataset):
19 | def __init__(self, txt_db, img_db, use_img_type=True):
20 | assert isinstance(txt_db, TxtTokLmdb)
21 | assert isinstance(img_db, DetectFeatLmdb)
22 | self.txt_db = txt_db
23 | self.img_db = img_db
24 | txt_lens, self.ids = get_ids_and_lens(txt_db)
25 |
26 | txt2img = txt_db.txt2img
27 | self.lens = [2*tl + sum(self.img_db.name2nbb[img]
28 | for img in txt2img[id_])
29 | for tl, id_ in zip(txt_lens, self.ids)]
30 |
31 | self.use_img_type = use_img_type
32 |
33 | def __getitem__(self, i):
34 | """
35 | [[txt, img1],
36 | [txt, img2]]
37 | """
38 | example = super().__getitem__(i)
39 | target = example['target']
40 | outs = []
41 | for i, img in enumerate(example['img_fname']):
42 | img_feat, img_pos_feat, num_bb = self._get_img_feat(img)
43 |
44 | # text input
45 | input_ids = copy.deepcopy(example['input_ids'])
46 |
47 | input_ids = [self.txt_db.cls_] + input_ids + [self.txt_db.sep]
48 | attn_masks = [1] * (len(input_ids) + num_bb)
49 | input_ids = torch.tensor(input_ids)
50 | attn_masks = torch.tensor(attn_masks)
51 | if self.use_img_type:
52 | img_type_ids = torch.tensor([i+1]*num_bb)
53 | else:
54 | img_type_ids = None
55 |
56 | outs.append((input_ids, img_feat, img_pos_feat,
57 | attn_masks, img_type_ids))
58 | return tuple(outs), target
59 |
60 |
61 | def nlvr2_paired_collate(inputs):
62 | (input_ids, img_feats, img_pos_feats, attn_masks,
63 | img_type_ids) = map(list, unzip(concat(outs for outs, _ in inputs)))
64 |
65 | txt_lens = [i.size(0) for i in input_ids]
66 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
67 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
68 | ).unsqueeze(0)
69 |
70 | # image batches
71 | num_bbs = [f.size(0) for f in img_feats]
72 | img_feat = pad_tensors(img_feats, num_bbs)
73 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
74 | if img_type_ids[0] is None:
75 | img_type_ids = None
76 | else:
77 | img_type_ids = pad_sequence(img_type_ids,
78 | batch_first=True, padding_value=0)
79 |
80 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
81 | targets = torch.Tensor([t for _, t in inputs]).long()
82 |
83 | bs, max_tl = input_ids.size()
84 | out_size = attn_masks.size(1)
85 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
86 |
87 | batch = {'input_ids': input_ids,
88 | 'position_ids': position_ids,
89 | 'img_feat': img_feat,
90 | 'img_pos_feat': img_pos_feat,
91 | 'attn_masks': attn_masks,
92 | 'gather_index': gather_index,
93 | 'img_type_ids': img_type_ids,
94 | 'targets': targets}
95 | return batch
96 |
97 |
98 | class Nlvr2PairedEvalDataset(Nlvr2PairedDataset):
99 | def __getitem__(self, i):
100 | qid = self.ids[i]
101 | outs, targets = super().__getitem__(i)
102 | return qid, outs, targets
103 |
104 |
105 | def nlvr2_paired_eval_collate(inputs):
106 | qids, batch = [], []
107 | for id_, *tensors in inputs:
108 | qids.append(id_)
109 | batch.append(tensors)
110 | batch = nlvr2_paired_collate(batch)
111 | batch['qids'] = qids
112 | return batch
113 |
114 |
115 | class Nlvr2TripletDataset(DetectFeatTxtTokDataset):
116 | def __init__(self, txt_db, img_db, use_img_type=True):
117 | assert isinstance(txt_db, TxtTokLmdb)
118 | assert isinstance(img_db, DetectFeatLmdb)
119 | self.txt_db = txt_db
120 | self.img_db = img_db
121 | txt_lens, self.ids = get_ids_and_lens(txt_db)
122 |
123 | txt2img = txt_db.txt2img
124 | self.lens = [tl + sum(self.img_db.name2nbb[img]
125 | for img in txt2img[id_])
126 | for tl, id_ in zip(txt_lens, self.ids)]
127 |
128 | self.use_img_type = use_img_type
129 |
130 | def __getitem__(self, i):
131 | """
132 | [[txt, img1],
133 | [txt, img2]]
134 | """
135 | example = super().__getitem__(i)
136 | target = example['target']
137 | img_feats = []
138 | img_pos_feats = []
139 | num_bb = 0
140 | img_type_ids = []
141 | for i, img in enumerate(example['img_fname']):
142 | feat, pos, nbb = self._get_img_feat(img)
143 | img_feats.append(feat)
144 | img_pos_feats.append(pos)
145 | num_bb += nbb
146 | if self.use_img_type:
147 | img_type_ids.extend([i+1]*nbb)
148 | img_feat = torch.cat(img_feats, dim=0)
149 | img_pos_feat = torch.cat(img_pos_feats, dim=0)
150 | if self.use_img_type:
151 | img_type_ids = torch.tensor(img_type_ids)
152 | else:
153 | img_type_ids = None
154 |
155 | # text input
156 | input_ids = copy.deepcopy(example['input_ids'])
157 |
158 | input_ids = [self.txt_db.cls_] + input_ids + [self.txt_db.sep]
159 | attn_masks = [1] * (len(input_ids) + num_bb)
160 | input_ids = torch.tensor(input_ids)
161 | attn_masks = torch.tensor(attn_masks)
162 |
163 | return (input_ids, img_feat, img_pos_feat, attn_masks,
164 | img_type_ids, target)
165 |
166 |
167 | def nlvr2_triplet_collate(inputs):
168 | (input_ids, img_feats, img_pos_feats,
169 | attn_masks, img_type_ids, targets) = map(list, unzip(inputs))
170 |
171 | txt_lens = [i.size(0) for i in input_ids]
172 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
173 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
174 | ).unsqueeze(0)
175 |
176 | # image batches
177 | num_bbs = [f.size(0) for f in img_feats]
178 | img_feat = pad_tensors(img_feats, num_bbs)
179 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
180 | if img_type_ids[0] is None:
181 | img_type_ids = None
182 | else:
183 | img_type_ids = pad_sequence(img_type_ids,
184 | batch_first=True, padding_value=0)
185 |
186 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
187 | targets = torch.Tensor(targets).long()
188 |
189 | bs, max_tl = input_ids.size()
190 | out_size = attn_masks.size(1)
191 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
192 |
193 | batch = {'input_ids': input_ids,
194 | 'position_ids': position_ids,
195 | 'img_feat': img_feat,
196 | 'img_pos_feat': img_pos_feat,
197 | 'attn_masks': attn_masks,
198 | 'gather_index': gather_index,
199 | 'img_type_ids': img_type_ids,
200 | 'targets': targets}
201 | return batch
202 |
203 |
204 | class Nlvr2TripletEvalDataset(Nlvr2TripletDataset):
205 | def __getitem__(self, i):
206 | qid = self.ids[i]
207 | tensors = super().__getitem__(i)
208 | return (qid, *tensors)
209 |
210 |
211 | def nlvr2_triplet_eval_collate(inputs):
212 | qids, batch = [], []
213 | for id_, *tensors in inputs:
214 | qids.append(id_)
215 | batch.append(tensors)
216 | batch = nlvr2_triplet_collate(batch)
217 | batch['qids'] = qids
218 | return batch
219 |
--------------------------------------------------------------------------------
/data/re.py:
--------------------------------------------------------------------------------
1 | """
2 | Referring Expression Comprehension dataset
3 | """
4 | import sys
5 | import json
6 | import random
7 | import numpy as np
8 |
9 | import torch
10 | from torch.utils.data import Dataset
11 | from torch.nn.utils.rnn import pad_sequence
12 | from toolz.sandbox import unzip
13 |
14 | from .data import TxtLmdb
15 |
16 |
17 | class ReImageFeatDir(object):
18 | def __init__(self, img_dir):
19 | self.img_dir = img_dir
20 |
21 | def __getitem__(self, file_name):
22 | img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True)
23 | img_feat = torch.tensor(img_dump['features'])
24 | img_bb = torch.tensor(img_dump['norm_bb'])
25 | return img_feat, img_bb
26 |
27 |
28 | class ReDetectFeatDir(object):
29 | def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36,
30 | format_='npz'):
31 |         assert format_ == 'npz', 'only npz is supported for now.'
32 |         assert isinstance(img_dir, str), 'img_dir should be a path, not a db.'
33 | self.img_dir = img_dir
34 | self.conf_th = conf_th
35 | self.max_bb = max_bb
36 | self.min_bb = min_bb
37 | self.num_bb = num_bb
38 |
39 | def _compute_num_bb(self, img_dump):
40 | num_bb = max(self.min_bb, (img_dump['conf'] > self.conf_th).sum())
41 | num_bb = min(self.max_bb, num_bb)
42 | return num_bb
43 |
44 | def __getitem__(self, file_name):
45 | # image input features
46 | img_dump = np.load(f'{self.img_dir}/{file_name}', allow_pickle=True)
47 | num_bb = self._compute_num_bb(img_dump)
48 | img_feat = torch.tensor(img_dump['features'][:num_bb, :])
49 | img_bb = torch.tensor(img_dump['norm_bb'][:num_bb, :])
50 | return img_feat, img_bb
51 |
52 |
53 | class ReferringExpressionDataset(Dataset):
54 | def __init__(self, db_dir, img_dir, max_txt_len=60):
55 | assert isinstance(img_dir, ReImageFeatDir) or \
56 | isinstance(img_dir, ReDetectFeatDir)
57 | self.img_dir = img_dir
58 |
59 | # load refs = [{ref_id, sent_ids, ann_id, image_id, sentences, split}]
60 | refs = json.load(open(f'{db_dir}/refs.json', 'r'))
61 | self.ref_ids = [ref['ref_id'] for ref in refs]
62 | self.Refs = {ref['ref_id']: ref for ref in refs}
63 |
64 | # load annotations = [{id, area, bbox, image_id, category_id}]
65 | anns = json.load(open(f'{db_dir}/annotations.json', 'r'))
66 | self.Anns = {ann['id']: ann for ann in anns}
67 |
68 | # load categories = [{id, name, supercategory}]
69 | categories = json.load(open(f'{db_dir}/categories.json', 'r'))
70 | self.Cats = {cat['id']: cat['name'] for cat in categories}
71 |
72 | # load images = [{id, file_name, ann_ids, height, width}]
73 | images = json.load(open(f'{db_dir}/images.json', 'r'))
74 | self.Images = {img['id']: img for img in images}
75 |
76 | # id2len: sent_id -> sent_len
77 | id2len = json.load(open(f'{db_dir}/id2len.json', 'r'))
78 | self.id2len = {int(_id): _len for _id, _len in id2len.items()}
79 | self.max_txt_len = max_txt_len
80 | self.sent_ids = self._get_sent_ids()
81 |
82 | # db[str(sent_id)] =
83 | # {sent_id, sent, ref_id, ann_id, image_id,
84 | # bbox, input_ids, toked_sent}
85 | self.db = TxtLmdb(db_dir, readonly=True)
86 |
87 | # meta
88 | meta = json.load(open(f'{db_dir}/meta.json', 'r'))
89 | self.cls_ = meta['CLS']
90 | self.sep = meta['SEP']
91 | self.mask = meta['MASK']
92 | self.v_range = meta['v_range']
93 |
94 | def shuffle(self):
95 | # we shuffle ref_ids and make sent_ids according to ref_ids
96 | random.shuffle(self.ref_ids)
97 | self.sent_ids = self._get_sent_ids()
98 |
99 | def _get_sent_ids(self):
100 | sent_ids = []
101 | for ref_id in self.ref_ids:
102 | for sent_id in self.Refs[ref_id]['sent_ids']:
103 | sent_len = self.id2len[sent_id]
104 | if self.max_txt_len == -1 or sent_len < self.max_txt_len:
105 | sent_ids.append(sent_id)
106 | return sent_ids
107 |
108 | def _get_img_feat(self, fname):
109 | img_feat, bb = self.img_dir[fname]
110 | img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
111 | num_bb = img_feat.size(0)
112 | return img_feat, img_bb, num_bb
113 |
114 | def __len__(self):
115 | return len(self.sent_ids)
116 |
117 | def __getitem__(self, i):
118 | """
119 | Return:
120 | :input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0]
121 | :position_ids : range(L)
122 | :img_feat : (num_bb, d)
123 | :img_pos_feat : (num_bb, 7)
124 | :attn_masks : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
125 | :obj_masks : (num_bb, ) all 0's
126 | :target : (1, )
127 | """
128 | # {sent_id, sent, ref_id, ann_id, image_id,
129 | # bbox, input_ids, toked_sent}
130 | sent_id = self.sent_ids[i]
131 | txt_dump = self.db[str(sent_id)]
132 | image_id = txt_dump['image_id']
133 | fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz'
134 | img_feat, img_pos_feat, num_bb = self._get_img_feat(fname)
135 |
136 | # text input
137 | input_ids = txt_dump['input_ids']
138 | input_ids = [self.cls_] + input_ids + [self.sep]
139 | attn_masks = [1] * len(input_ids)
140 | position_ids = list(range(len(input_ids)))
141 | attn_masks += [1] * num_bb
142 |
143 | input_ids = torch.tensor(input_ids)
144 | position_ids = torch.tensor(position_ids)
145 | attn_masks = torch.tensor(attn_masks)
146 |
147 | # target bbox
148 | img = self.Images[image_id]
149 | assert len(img['ann_ids']) == num_bb, \
150 | 'Please use visual_grounding_coco_gt'
151 | target = img['ann_ids'].index(txt_dump['ann_id'])
152 | target = torch.tensor([target])
153 |
154 | # obj_masks, to be padded with 1, for masking out non-object prob.
155 | obj_masks = torch.tensor([0]*len(img['ann_ids'])).byte()
156 |
157 | return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks,
158 | obj_masks, target)
159 |
160 |
161 | def re_collate(inputs):
162 | """
163 | Return:
164 | :input_ids : (n, max_L) padded with 0
165 | :position_ids : (n, max_L) padded with 0
166 | :txt_lens : list of [txt_len]
167 | :img_feat : (n, max_num_bb, feat_dim)
168 | :img_pos_feat : (n, max_num_bb, 7)
169 | :num_bbs : list of [num_bb]
170 | :attn_masks : (n, max_{L+num_bb}) padded with 0
171 | :obj_masks : (n, max_num_bb) padded with 1
172 | :targets : (n, )
173 | """
174 | (input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks,
175 | targets) = map(list, unzip(inputs))
176 |
177 | txt_lens = [i.size(0) for i in input_ids]
178 | num_bbs = [f.size(0) for f in img_feats]
179 |
180 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
181 | position_ids = pad_sequence(position_ids,
182 | batch_first=True, padding_value=0)
183 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
184 | targets = torch.cat(targets, dim=0)
185 | obj_masks = pad_sequence(obj_masks,
186 | batch_first=True, padding_value=1).byte()
187 |
188 | batch_size = len(img_feats)
189 | num_bb = max(num_bbs)
190 | feat_dim = img_feats[0].size(1)
191 | pos_dim = img_pos_feats[0].size(1)
192 | img_feat = torch.zeros(batch_size, num_bb, feat_dim)
193 | img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
194 | for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
195 | len_ = im.size(0)
196 | img_feat.data[i, :len_, :] = im.data
197 | img_pos_feat.data[i, :len_, :] = pos.data
198 |
199 | return (input_ids, position_ids, txt_lens,
200 | img_feat, img_pos_feat, num_bbs,
201 | attn_masks, obj_masks, targets)
202 |
203 |
204 | class ReferringExpressionEvalDataset(ReferringExpressionDataset):
205 | def __getitem__(self, i):
206 | """
207 | Return:
208 | :input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0]
209 | :position_ids : range(L)
210 | :img_feat : (num_bb, d)
211 | :img_pos_feat : (num_bb, 7)
212 | :attn_masks : (L+num_bb, ), i.e., [1, 1, ..., 0, 0, 1, 1]
213 | :obj_masks : (num_bb, ) all 0's
214 | :tgt_box : ndarray (4, ) xywh
215 | :obj_boxes : ndarray (num_bb, 4) xywh
216 | :sent_id
217 | """
218 | # {sent_id, sent, ref_id, ann_id, image_id,
219 | # bbox, input_ids, toked_sent}
220 | sent_id = self.sent_ids[i]
221 | txt_dump = self.db[str(sent_id)]
222 | image_id = txt_dump['image_id']
223 | if isinstance(self.img_dir, ReImageFeatDir):
224 | if '_gt' in self.img_dir.img_dir:
225 | fname = f'visual_grounding_coco_gt_{int(image_id):012}.npz'
226 | elif '_det' in self.img_dir.img_dir:
227 | fname = f'visual_grounding_det_coco_{int(image_id):012}.npz'
228 | elif isinstance(self.img_dir, ReDetectFeatDir):
229 | fname = f'coco_train2014_{int(image_id):012}.npz'
230 | else:
231 | sys.exit('%s not supported.' % self.img_dir)
232 | img_feat, img_pos_feat, num_bb = self._get_img_feat(fname)
233 |
234 | # image info
235 | img = self.Images[image_id]
236 | im_width, im_height = img['width'], img['height']
237 |
238 | # object boxes, img_pos_feat (xyxywha) -> xywh
239 | obj_boxes = np.stack([img_pos_feat[:, 0]*im_width,
240 | img_pos_feat[:, 1]*im_height,
241 | img_pos_feat[:, 4]*im_width,
242 | img_pos_feat[:, 5]*im_height], axis=1)
243 | obj_masks = torch.tensor([0]*num_bb).byte()
244 |
245 | # target box
246 | tgt_box = np.array(txt_dump['bbox']) # xywh
247 |
248 | # text input
249 | input_ids = txt_dump['input_ids']
250 | input_ids = [self.cls_] + input_ids + [self.sep]
251 | attn_masks = [1] * len(input_ids)
252 | position_ids = list(range(len(input_ids)))
253 | attn_masks += [1] * num_bb
254 |
255 | input_ids = torch.tensor(input_ids)
256 | position_ids = torch.tensor(position_ids)
257 | attn_masks = torch.tensor(attn_masks)
258 |
259 | return (input_ids, position_ids, img_feat, img_pos_feat, attn_masks,
260 | obj_masks, tgt_box, obj_boxes, sent_id)
261 |
262 | # IoU function
263 | def computeIoU(self, box1, box2):
264 | # each box is of [x1, y1, w, h]
265 | inter_x1 = max(box1[0], box2[0])
266 | inter_y1 = max(box1[1], box2[1])
267 | inter_x2 = min(box1[0]+box1[2]-1, box2[0]+box2[2]-1)
268 | inter_y2 = min(box1[1]+box1[3]-1, box2[1]+box2[3]-1)
269 |
270 | if inter_x1 < inter_x2 and inter_y1 < inter_y2:
271 | inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
272 | else:
273 | inter = 0
274 | union = box1[2]*box1[3] + box2[2]*box2[3] - inter
275 | return float(inter)/union
276 |
277 |
278 | def re_eval_collate(inputs):
279 | """
280 | Return:
281 | :input_ids : (n, max_L)
282 | :position_ids : (n, max_L)
283 | :txt_lens : list of [txt_len]
284 | :img_feat : (n, max_num_bb, d)
285 | :img_pos_feat : (n, max_num_bb, 7)
286 | :num_bbs : list of [num_bb]
287 | :attn_masks : (n, max{L+num_bb})
288 | :obj_masks : (n, max_num_bb)
289 | :tgt_box : list of n [xywh]
290 | :obj_boxes : list of n [[xywh, xywh, ...]]
291 | :sent_ids : list of n [sent_id]
292 | """
293 | (input_ids, position_ids, img_feats, img_pos_feats, attn_masks, obj_masks,
294 | tgt_box, obj_boxes, sent_ids) = map(list, unzip(inputs))
295 |
296 | txt_lens = [i.size(0) for i in input_ids]
297 | num_bbs = [f.size(0) for f in img_feats]
298 |
299 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
300 | position_ids = pad_sequence(position_ids,
301 | batch_first=True, padding_value=0)
302 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
303 | obj_masks = pad_sequence(obj_masks,
304 | batch_first=True, padding_value=1).byte()
305 |
306 | batch_size = len(img_feats)
307 | num_bb = max(num_bbs)
308 | feat_dim = img_feats[0].size(1)
309 | pos_dim = img_pos_feats[0].size(1)
310 | img_feat = torch.zeros(batch_size, num_bb, feat_dim)
311 | img_pos_feat = torch.zeros(batch_size, num_bb, pos_dim)
312 | for i, (im, pos) in enumerate(zip(img_feats, img_pos_feats)):
313 | len_ = im.size(0)
314 | img_feat.data[i, :len_, :] = im.data
315 | img_pos_feat.data[i, :len_, :] = pos.data
316 |
317 | return (input_ids, position_ids, txt_lens,
318 | img_feat, img_pos_feat, num_bbs,
319 | attn_masks, obj_masks, tgt_box, obj_boxes, sent_ids)
320 |
--------------------------------------------------------------------------------
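Note on /data/re.py above: for orientation, here is a minimal sketch of how the referring-expression pieces could be wired together. The directory paths and batch size below are hypothetical placeholders, not values taken from this repo.

    from torch.utils.data import DataLoader
    from data.re import ReImageFeatDir, ReferringExpressionDataset, re_collate

    # hypothetical locations of the precomputed .npz features and the text db
    img_dir = ReImageFeatDir('/img/visual_grounding_coco_gt')
    dataset = ReferringExpressionDataset('/db/refcoco_train', img_dir, max_txt_len=60)
    loader = DataLoader(dataset, batch_size=32, collate_fn=re_collate)

    for (input_ids, position_ids, txt_lens, img_feat, img_pos_feat,
         num_bbs, attn_masks, obj_masks, targets) in loader:
        # img_feat: (n, max_num_bb, feat_dim), targets: (n,), per the re_collate docstring
        break
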
/data/sampler.py:
--------------------------------------------------------------------------------
1 | """ sampler for length bucketing (batch by tokens) """
2 | import math
3 | import random
4 |
5 | import torch
6 | import horovod.torch as hvd
7 | from torch.utils.data import Sampler
8 | from cytoolz import partition_all
9 |
10 |
11 | class TokenBucketSampler(Sampler):
12 | def __init__(self, lens, bucket_size, batch_size,
13 | droplast=False, size_multiple=8):
14 | self._lens = lens
15 | self._max_tok = batch_size
16 | self._bucket_size = bucket_size
17 | self._droplast = droplast
18 | self._size_mul = size_multiple
19 |
20 | def _create_ids(self):
21 | return list(range(len(self._lens)))
22 |
23 | def _sort_fn(self, i):
24 | return self._lens[i]
25 |
26 | def __iter__(self):
27 | ids = self._create_ids()
28 | random.shuffle(ids)
29 | buckets = [sorted(ids[i:i+self._bucket_size],
30 | key=self._sort_fn, reverse=True)
31 | for i in range(0, len(ids), self._bucket_size)]
32 | # fill batches until max_token (include padding)
33 | batches = []
34 | for bucket in buckets:
35 | max_len = 0
36 | batch_indices = []
37 | for indices in partition_all(self._size_mul, bucket):
38 | max_len = max(max_len, max(self._lens[i] for i in indices))
39 | # print(max_len * (len(batch_indices) + self._size_mul))
40 | # print(self._max_tok)
41 | # print(len(batch_indices))
42 | if (max_len * (len(batch_indices) + self._size_mul)
43 | > self._max_tok):
44 | if not batch_indices:
45 | raise ValueError(
46 | "max_tokens too small / max_seq_len too long")
47 | assert len(batch_indices) % self._size_mul == 0
48 | batches.append(batch_indices)
49 | batch_indices = list(indices)
50 | else:
51 | batch_indices.extend(indices)
52 | if not self._droplast and batch_indices:
53 | batches.append(batch_indices)
54 | random.shuffle(batches)
55 | return iter(batches)
56 |
57 | def __len__(self):
58 | raise ValueError("NOT supported. "
59 | "This has some randomness across epochs")
60 |
61 |
62 | class DistributedSampler(Sampler):
63 | """Sampler that restricts data loading to a subset of the dataset.
64 |
65 | It is especially useful in conjunction with
66 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
67 | process can pass a DistributedSampler instance as a DataLoader sampler,
68 | and load a subset of the original dataset that is exclusive to it.
69 |
70 | .. note::
71 | Dataset is assumed to be of constant size.
72 |
73 | Arguments:
74 | dataset: Dataset used for sampling.
75 | num_replicas (optional): Number of processes participating in
76 | distributed training.
77 | rank (optional): Rank of the current process within num_replicas.
78 | shuffle (optional): If true (default), sampler will shuffle the indices
79 | """
80 |
81 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
82 | if num_replicas is None:
83 | num_replicas = hvd.size()
84 | if rank is None:
85 | rank = hvd.rank()
86 | self.dataset = dataset
87 | self.num_replicas = num_replicas
88 | self.rank = rank
89 | self.epoch = 0
90 | self.num_samples = int(math.ceil(len(self.dataset)
91 | * 1.0 / self.num_replicas))
92 | self.total_size = self.num_samples * self.num_replicas
93 | self.shuffle = shuffle
94 |
95 | def __iter__(self):
96 | # deterministically shuffle based on epoch
97 | g = torch.Generator()
98 | g.manual_seed(self.epoch)
99 |
100 | indices = list(range(len(self.dataset)))
101 | # add extra samples to make it evenly divisible
102 | indices += indices[:(self.total_size - len(indices))]
103 | assert len(indices) == self.total_size
104 |
105 | # subsample
106 | indices = indices[self.rank:self.total_size:self.num_replicas]
107 |
108 | if self.shuffle:
109 |             shuffle_ind = torch.randperm(len(indices), generator=g).tolist()
110 |             indices = [indices[i] for i in shuffle_ind]
111 | assert len(indices) == self.num_samples
112 |
113 | return iter(indices)
114 |
115 | def __len__(self):
116 | return self.num_samples
117 |
118 | def set_epoch(self, epoch):
119 | self.epoch = epoch
120 |
--------------------------------------------------------------------------------
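Note on /data/sampler.py above: TokenBucketSampler yields whole batches of indices sized by a token budget rather than a fixed example count, so it plugs into a DataLoader as a batch_sampler (len(loader) is intentionally unsupported, since __len__ raises). A self-contained sketch with a made-up toy dataset and illustrative budget values; importing data.sampler pulls in horovod, as in the file above.

    import random
    import torch
    from torch.utils.data import DataLoader, Dataset
    from torch.nn.utils.rnn import pad_sequence
    from data.sampler import TokenBucketSampler

    class ToyTextDataset(Dataset):  # hypothetical stand-in for a real token dataset
        def __init__(self, n=1000):
            self.seqs = [torch.randint(0, 100, (random.randint(5, 60),))
                         for _ in range(n)]
        def __len__(self):
            return len(self.seqs)
        def __getitem__(self, i):
            return self.seqs[i]

    ds = ToyTextDataset()
    lens = [s.size(0) for s in ds.seqs]
    sampler = TokenBucketSampler(lens, bucket_size=8192,
                                 batch_size=4096,   # interpreted as max tokens per batch
                                 droplast=True, size_multiple=8)
    loader = DataLoader(ds, batch_sampler=sampler,
                        collate_fn=lambda seqs: pad_sequence(seqs, batch_first=True))
    batch = next(iter(loader))  # (num_examples, max_len), roughly bounded by the 4096-token budget
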
/data/test_data/input0.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input0.txt
--------------------------------------------------------------------------------
/data/test_data/input1.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input1.txt
--------------------------------------------------------------------------------
/data/test_data/input2.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input2.txt
--------------------------------------------------------------------------------
/data/test_data/input3.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input3.txt
--------------------------------------------------------------------------------
/data/test_data/input4.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input4.txt
--------------------------------------------------------------------------------
/data/test_data/input5.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input5.txt
--------------------------------------------------------------------------------
/data/test_data/input6.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input6.txt
--------------------------------------------------------------------------------
/data/test_data/input7.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/data/test_data/input7.txt
--------------------------------------------------------------------------------
/data/ve.py:
--------------------------------------------------------------------------------
1 | """
2 | Visual entailment dataset
3 | # NOTE: basically reuse VQA dataset
4 | """
5 | from .vqa import VqaDataset, VqaEvalDataset, vqa_collate, vqa_eval_collate
6 |
7 |
8 | class VeDataset(VqaDataset):
9 | def __init__(self, *args, **kwargs):
10 | super().__init__(3, *args, **kwargs)
11 |
12 |
13 | class VeEvalDataset(VqaEvalDataset):
14 | def __init__(self, *args, **kwargs):
15 | super().__init__(3, *args, **kwargs)
16 |
17 |
18 | ve_collate = vqa_collate
19 | ve_eval_collate = vqa_eval_collate
20 |
--------------------------------------------------------------------------------
/data/vqa.py:
--------------------------------------------------------------------------------
1 | """
2 | VQA dataset
3 | """
4 | import torch
5 | from torch.nn.utils.rnn import pad_sequence
6 | from toolz.sandbox import unzip
7 |
8 | from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index
9 |
10 |
11 | def _get_vqa_target(example, num_answers):
12 | target = torch.zeros(num_answers)
13 | labels = example['target']['labels']
14 | scores = example['target']['scores']
15 | if labels and scores:
16 | target.scatter_(0, torch.tensor(labels), torch.tensor(scores))
17 | return target
18 |
19 |
20 | class VqaDataset(DetectFeatTxtTokDataset):
21 |     """ NOTE: This handles the distributed data split internally """
22 | def __init__(self, num_answers, *args, **kwargs):
23 | super().__init__(*args, **kwargs)
24 | self.num_answers = num_answers
25 |
26 | def __getitem__(self, i):
27 | example = super().__getitem__(i)
28 | img_feat, img_pos_feat, num_bb = self._get_img_feat(
29 | example['img_fname'])
30 |
31 | # text input
32 | input_ids = example['input_ids']
33 | input_ids = self.txt_db.combine_inputs(input_ids)
34 |
35 | target = _get_vqa_target(example, self.num_answers)
36 |
37 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
38 |
39 | return input_ids, img_feat, img_pos_feat, attn_masks, target
40 |
41 |
42 | def xlmr_vqa_collate(inputs):
43 | (input_ids, img_feats, img_pos_feats, attn_masks, targets
44 | ) = map(list, unzip(inputs))
45 |
46 | txt_lens = [i.size(0) for i in input_ids]
47 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=1)
48 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
49 | ).unsqueeze(0)
50 |
51 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
52 | targets = torch.stack(targets, dim=0)
53 |
54 | num_bbs = [f.size(0) for f in img_feats]
55 | img_feat = pad_tensors(img_feats, num_bbs)
56 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
57 |
58 | bs, max_tl = input_ids.size()
59 | out_size = attn_masks.size(1)
60 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
61 |
62 | batch = {'input_ids': input_ids,
63 | 'position_ids': position_ids,
64 | 'img_feat': img_feat,
65 | 'img_pos_feat': img_pos_feat,
66 | 'attn_masks': attn_masks,
67 | 'gather_index': gather_index,
68 | 'targets': targets}
69 | return batch
70 |
71 | def vqa_collate(inputs):
72 | (input_ids, img_feats, img_pos_feats, attn_masks, targets
73 | ) = map(list, unzip(inputs))
74 |
75 | txt_lens = [i.size(0) for i in input_ids]
76 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
77 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
78 | ).unsqueeze(0)
79 |
80 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
81 | targets = torch.stack(targets, dim=0)
82 |
83 | num_bbs = [f.size(0) for f in img_feats]
84 | img_feat = pad_tensors(img_feats, num_bbs)
85 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
86 |
87 | bs, max_tl = input_ids.size()
88 | out_size = attn_masks.size(1)
89 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
90 |
91 | batch = {'input_ids': input_ids,
92 | 'position_ids': position_ids,
93 | 'img_feat': img_feat,
94 | 'img_pos_feat': img_pos_feat,
95 | 'attn_masks': attn_masks,
96 | 'gather_index': gather_index,
97 | 'targets': targets}
98 | return batch
99 |
100 |
101 | class VqaEvalDataset(VqaDataset):
102 | def __getitem__(self, i):
103 | qid = self.ids[i]
104 | example = DetectFeatTxtTokDataset.__getitem__(self, i)
105 | img_feat, img_pos_feat, num_bb = self._get_img_feat(
106 | example['img_fname'])
107 |
108 | # text input
109 | input_ids = example['input_ids']
110 | input_ids = self.txt_db.combine_inputs(input_ids)
111 |
112 | if 'target' in example:
113 | target = _get_vqa_target(example, self.num_answers)
114 | else:
115 | target = None
116 |
117 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
118 |
119 | return qid, input_ids, img_feat, img_pos_feat, attn_masks, target
120 |
121 |
122 | def xlmr_vqa_eval_collate(inputs):
123 | (qids, input_ids, img_feats, img_pos_feats, attn_masks, targets
124 | ) = map(list, unzip(inputs))
125 |
126 | txt_lens = [i.size(0) for i in input_ids]
127 |
128 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=1)
129 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
130 | ).unsqueeze(0)
131 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
132 | if targets[0] is None:
133 | targets = None
134 | else:
135 | targets = torch.stack(targets, dim=0)
136 |
137 | num_bbs = [f.size(0) for f in img_feats]
138 | img_feat = pad_tensors(img_feats, num_bbs)
139 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
140 |
141 | bs, max_tl = input_ids.size()
142 | out_size = attn_masks.size(1)
143 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
144 |
145 | batch = {'qids': qids,
146 | 'input_ids': input_ids,
147 | 'position_ids': position_ids,
148 | 'img_feat': img_feat,
149 | 'img_pos_feat': img_pos_feat,
150 | 'attn_masks': attn_masks,
151 | 'gather_index': gather_index,
152 | 'targets': targets}
153 | return batch
154 |
155 | def vqa_eval_collate(inputs):
156 | (qids, input_ids, img_feats, img_pos_feats, attn_masks, targets
157 | ) = map(list, unzip(inputs))
158 |
159 | txt_lens = [i.size(0) for i in input_ids]
160 |
161 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
162 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
163 | ).unsqueeze(0)
164 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
165 | if targets[0] is None:
166 | targets = None
167 | else:
168 | targets = torch.stack(targets, dim=0)
169 |
170 | num_bbs = [f.size(0) for f in img_feats]
171 | img_feat = pad_tensors(img_feats, num_bbs)
172 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
173 |
174 | bs, max_tl = input_ids.size()
175 | out_size = attn_masks.size(1)
176 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
177 |
178 | batch = {'qids': qids,
179 | 'input_ids': input_ids,
180 | 'position_ids': position_ids,
181 | 'img_feat': img_feat,
182 | 'img_pos_feat': img_pos_feat,
183 | 'attn_masks': attn_masks,
184 | 'gather_index': gather_index,
185 | 'targets': targets}
186 | return batch
187 |
--------------------------------------------------------------------------------
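Note on /data/vqa.py above: the soft VQA target built by _get_vqa_target is just a sparse score vector over the answer vocabulary. A tiny worked example; the label indices and scores below are made up.

    import torch
    from data.vqa import _get_vqa_target

    example = {'target': {'labels': [3, 7], 'scores': [0.9, 0.3]}}  # hypothetical record
    target = _get_vqa_target(example, num_answers=10)
    print(target)
    # indices 3 and 7 carry the annotated scores 0.9 and 0.3; every other entry stays 0
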
/eval/__pycache__/itm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/eval/__pycache__/itm.cpython-36.pyc
--------------------------------------------------------------------------------
/eval/itm.py:
--------------------------------------------------------------------------------
1 | """ Image Text Retrieval evaluation helper """
2 | import torch
3 |
4 |
5 | @torch.no_grad()
6 | def itm_eval(score_matrix, txt_ids, img_ids, txt2img, img2txts):
7 | # image retrieval
8 | img2j = {i: j for j, i in enumerate(img_ids)}
9 | _, rank_txt = score_matrix.topk(10, dim=1)
10 | gt_img_j = torch.LongTensor([img2j[txt2img[txt_id]]
11 | for txt_id in txt_ids],
12 | ).to(rank_txt.device
13 | ).unsqueeze(1).expand_as(rank_txt)
14 | rank = (rank_txt == gt_img_j).nonzero()
15 | if rank.numel():
16 | ir_r1 = (rank < 1).sum().item() / len(txt_ids)
17 | ir_r5 = (rank < 5).sum().item() / len(txt_ids)
18 | ir_r10 = (rank < 10).sum().item() / len(txt_ids)
19 | else:
20 | ir_r1, ir_r5, ir_r10 = 0, 0, 0
21 |
22 | # text retrieval
23 | txt2i = {t: i for i, t in enumerate(txt_ids)}
24 | _, rank_img = score_matrix.topk(10, dim=0)
25 | tr_r1, tr_r5, tr_r10 = 0, 0, 0
26 | for j, img_id in enumerate(img_ids):
27 | gt_is = [txt2i[t] for t in img2txts[img_id]]
28 | ranks = [(rank_img[:, j] == i).nonzero() for i in gt_is]
29 | rank = min([10] + [r.item() for r in ranks if r.numel()])
30 | if rank < 1:
31 | tr_r1 += 1
32 | if rank < 5:
33 | tr_r5 += 1
34 | if rank < 10:
35 | tr_r10 += 1
36 | tr_r1 /= len(img_ids)
37 | tr_r5 /= len(img_ids)
38 | tr_r10 /= len(img_ids)
39 |
40 | tr_mean = (tr_r1 + tr_r5 + tr_r10) / 3
41 | ir_mean = (ir_r1 + ir_r5 + ir_r10) / 3
42 | r_mean = (tr_mean + ir_mean) / 2
43 |
44 | eval_log = {'txt_r1': tr_r1,
45 | 'txt_r5': tr_r5,
46 | 'txt_r10': tr_r10,
47 | 'txt_r_mean': tr_mean,
48 | 'img_r1': ir_r1,
49 | 'img_r5': ir_r5,
50 | 'img_r10': ir_r10,
51 | 'img_r_mean': ir_mean,
52 | 'r_mean': r_mean}
53 | return eval_log
54 |
--------------------------------------------------------------------------------
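Note on /eval/itm.py above: itm_eval expects a (num_texts, num_images) score matrix plus the id mappings in both directions. A self-contained toy call; all ids below are synthetic, and at least 10 items per side are needed because of the top-10 ranking.

    import torch
    from eval.itm import itm_eval

    img_ids = [f'img{i}' for i in range(10)]
    txt_ids = [f'txt{i}' for i in range(20)]               # two captions per image
    txt2img = {t: f'img{i // 2}' for i, t in enumerate(txt_ids)}
    img2txts = {f'img{i}': [f'txt{2*i}', f'txt{2*i+1}'] for i in range(10)}

    # a "perfect" scorer: each caption scores highest on its own image
    score_matrix = torch.zeros(len(txt_ids), len(img_ids))
    for t, txt_id in enumerate(txt_ids):
        score_matrix[t, img_ids.index(txt2img[txt_id])] = 1.0

    print(itm_eval(score_matrix, txt_ids, img_ids, txt2img, img2txts))
    # prints recall@1/5/10 for both retrieval directions plus their means
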
/eval/nlvr2.py:
--------------------------------------------------------------------------------
1 | """
2 | copied from the official NLVR2 GitHub repository
3 | usage: python eval/nlvr2.py <predictions_csv> <labeled_examples_jsonl>
4 | """
5 | import json
6 | import sys
7 |
8 | # Load the predictions file. Assume it is a CSV.
9 | predictions = { }
10 | for line in open(sys.argv[1]).readlines():
11 | if line:
12 | splits = line.strip().split(",")
13 | # We assume identifiers are in the format "split-####-#-#.png".
14 | identifier = splits[0]
15 | prediction = splits[1]
16 | predictions[identifier] = prediction
17 |
18 | # Load the labeled examples.
19 | labeled_examples = [json.loads(line) for line in open(sys.argv[2]).readlines() if line]
20 |
21 | # Make sure there is a prediction for every labeled example; if not, report the missing ones and exit.
22 | total_num = len(labeled_examples)
23 | if len(predictions) < total_num:
24 | print("Some predictions are missing!")
25 | print("Got " + str(len(predictions)) + " predictions but expected " + str(total_num))
26 |
27 | for example in labeled_examples:
28 | lookup = example["identifier"]
29 | if not lookup in predictions:
30 | print("Missing prediction for item " + str(lookup))
31 | exit()
32 |
33 | # Get the precision by iterating through the examples and checking the value
34 | # that was predicted.
35 | # Also update the "consistency" dictionary that keeps track of whether all
36 | # predictions for a given sentence were correct.
37 | num_correct = 0.
38 | consistency_dict = { }
39 |
40 | for example in labeled_examples:
41 | anon_label = example["identifier"].split("-")
42 | anon_label[2] = ''
43 | anon_label = '-'.join(anon_label)
44 | if not anon_label in consistency_dict:
45 | consistency_dict[anon_label] = True
46 | lookup = example["identifier"]
47 | prediction = predictions[lookup]
48 | if prediction.lower() == example["label"].lower():
49 | num_correct += 1.
50 | else:
51 | consistency_dict[anon_label] = False
52 |
53 | # Calculate consistency.
54 | num_consistent = 0.
55 | unique_sentence = len(consistency_dict)
56 | for identifier, consistent in consistency_dict.items():
57 | if consistent:
58 | num_consistent += 1
59 |
60 | # Report values.
61 | print("accuracy=" + str(num_correct / total_num))
62 | print("consistency=" + str(num_consistent / unique_sentence))
63 |
--------------------------------------------------------------------------------
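Note on /eval/nlvr2.py above: the script only needs a plain "identifier,prediction" CSV. A minimal sketch of producing one; the identifiers, labels, and file names below are made-up examples.

    results = {'dev-850-0-0': 'True', 'dev-850-1-0': 'False'}   # hypothetical predictions
    with open('nlvr2_predictions.csv', 'w') as f:
        for identifier, pred in results.items():
            f.write(f'{identifier},{pred}\n')
    # then: python eval/nlvr2.py nlvr2_predictions.csv <labeled_examples_jsonl>
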
/launch_container.sh:
--------------------------------------------------------------------------------
1 | TXT_DB=$1
2 | IMG_DIR=$2
3 | OUTPUT=$3
4 | PRETRAIN_DIR=$4
5 | # TXT_RAW=$6
6 |
7 | if [ -z $CUDA_VISIBLE_DEVICES ]; then
8 | CUDA_VISIBLE_DEVICES='all'
9 | fi
10 |
11 | # if [ "$5" = "--prepro" ]; then
12 | # RO=""
13 | # else
14 | # RO=",readonly"
15 | # fi
16 |
17 |
18 | # docker run --gpus "device=$CUDA_VISIBLE_DEVICES" --ipc=host --rm -it \
19 | # -e NCCL_IB_CUDA_SUPPORT=0 \
20 |
21 |
22 |
23 | docker run -p 8892:8892 --runtime=nvidia --ipc=host --rm -it \
24 | --mount src=$(pwd),dst=/src,type=bind \
25 | --mount src=$OUTPUT,dst=/storage,type=bind \
26 | --mount src=$PRETRAIN_DIR,dst=/pretrain,type=bind,readonly \
27 | --mount src=$TXT_DB,dst=/db,type=bind$RO \
28 | --mount src=$IMG_DIR,dst=/img,type=bind \
29 | -e NVIDIA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \
30 | -e NCCL_IB_CUDA_SUPPORT=0 \
31 | -w /src zmykevin/uc2:latest_version
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import UniterForPretraining, UniterConfig, VLXLMRForPretraining
2 | from .vqa import UniterForVisualQuestionAnswering, VLXLMRForVisualQuestionAnswering
3 | from .nlvr2 import (UniterForNlvr2PairedAttn, UniterForNlvr2Paired,
4 | UniterForNlvr2Triplet)
5 | from .itm import (UniterForImageTextRetrieval,
6 | UniterForImageTextRetrievalHardNeg,
7 | VLXLMRForImageTextRetrieval)
8 | from .ve import UniterForVisualEntailment
9 |
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/attention.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/attention.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/const_variable.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/const_variable.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/itm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/itm.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/layer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/layer.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/nlvr2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/nlvr2.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/ot.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/ot.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/ve.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/ve.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/vqa.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/model/__pycache__/vqa.cpython-36.pyc
--------------------------------------------------------------------------------
/model/const_variable.py:
--------------------------------------------------------------------------------
1 | from transformers import XLMRobertaTokenizer
2 | import numpy as np
3 | import torch
4 |
5 | # Load the XLM-R tokenizer (XLMR_TOKER)
6 | XLMR_TOKER = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
7 | with open('object_labels/img_label_objects.txt', "r") as f:
8 | IMG_LABEL_OBJECTS = f.readlines()
9 | IMG_LABEL_OBJECTS = ['background'] + [x.strip() for x in IMG_LABEL_OBJECTS]
10 |
11 | # Initialize LABEL2TOKEN_MATRIX: row i marks the XLM-R sub-token ids of object label i
12 | VALID_XLMR_TOKEN_IDS = []
13 | #LABEL2TOKEN = {}
14 | LABEL2TOKEN_MATRIX = np.zeros((1601, XLMR_TOKER.vocab_size))
15 | for i, label_word in enumerate(IMG_LABEL_OBJECTS):
16 | label_tokens = XLMR_TOKER.tokenize(label_word)
17 | label_tokens_ids = [XLMR_TOKER._convert_token_to_id(w) for w in label_tokens]
18 | #LABEL2TOKEN[label_word] = label_tokens_ids
19 | LABEL2TOKEN_MATRIX[i][label_tokens_ids] = 1
20 | VALID_XLMR_TOKEN_IDS.extend(label_tokens_ids)
21 | #torch.tensor(LABEL2TOKEN_MATRIX)
22 | VALID_XLMR_TOKEN_IDS = list(set(VALID_XLMR_TOKEN_IDS))
23 | VALID_XLMR_TOKEN_IDS.sort()
24 |
--------------------------------------------------------------------------------
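Note on /model/const_variable.py above: a small sketch of what the module-level constants encode. Run it from the repo root so the relative object_labels path resolves; the label index below is arbitrary.

    import numpy as np
    from model.const_variable import (XLMR_TOKER, IMG_LABEL_OBJECTS,
                                      LABEL2TOKEN_MATRIX, VALID_XLMR_TOKEN_IDS)

    label_idx = 5                                    # arbitrary object-label row
    token_ids = np.nonzero(LABEL2TOKEN_MATRIX[label_idx])[0]
    print(IMG_LABEL_OBJECTS[label_idx],
          XLMR_TOKER.convert_ids_to_tokens(token_ids.tolist()))
    # each row marks the XLM-R sub-token ids of one object label;
    # VALID_XLMR_TOKEN_IDS is the sorted union of those ids over all 1601 rows
    print(len(VALID_XLMR_TOKEN_IDS))
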
/model/gqa.py:
--------------------------------------------------------------------------------
1 | """
2 | Bert for GQA model
3 | """
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from pytorch_pretrained_bert.modeling import (
7 | BertOnlyMLMHead)
8 | from .model import (BertForImageTextPretraining,
9 | _get_image_hidden,
10 | mask_img_feat,
11 | RegionFeatureRegression,
12 | mask_img_feat_for_mrc,
13 | RegionClassification)
14 | import torch
15 | import random
16 |
17 |
18 | class BertForImageTextPretrainingForGQA(BertForImageTextPretraining):
19 | def init_type_embedding(self):
20 | new_emb = nn.Embedding(3, self.bert.config.hidden_size)
21 | new_emb.apply(self.init_bert_weights)
22 | for i in [0, 1]:
23 | emb = self.bert.embeddings.token_type_embeddings.weight.data[i, :]
24 | new_emb.weight.data[i, :].copy_(emb)
25 | emb = self.bert.embeddings.token_type_embeddings.weight.data[0, :]
26 | new_emb.weight.data[2, :].copy_(emb)
27 | self.bert.embeddings.token_type_embeddings = new_emb
28 |
29 | def forward(self, input_ids, position_ids, txt_type_ids, txt_lens,
30 | img_feat, img_pos_feat, num_bbs,
31 | attention_mask, labels, task, compute_loss=True):
32 | if task == 'mlm':
33 | txt_labels = labels
34 | return self.forward_mlm(input_ids, position_ids, txt_type_ids,
35 | txt_lens,
36 | img_feat, img_pos_feat, num_bbs,
37 | attention_mask, txt_labels, compute_loss)
38 | elif task == 'mrm':
39 | img_mask = labels
40 | return self.forward_mrm(input_ids, position_ids, txt_type_ids,
41 | txt_lens,
42 | img_feat, img_pos_feat, num_bbs,
43 | attention_mask, img_mask, compute_loss)
44 | elif task.startswith('mrc'):
45 | img_mask, mrc_label_target = labels
46 | return self.forward_mrc(input_ids, position_ids, txt_type_ids,
47 | txt_lens,
48 | img_feat, img_pos_feat, num_bbs,
49 | attention_mask, img_mask,
50 | mrc_label_target, task, compute_loss)
51 | else:
52 | raise ValueError('invalid task')
53 |
54 | # MLM
55 | def forward_mlm(self, input_ids, position_ids, txt_type_ids, txt_lens,
56 | img_feat, img_pos_feat, num_bbs,
57 | attention_mask, txt_labels, compute_loss=True):
58 | sequence_output = self.bert(input_ids, position_ids, txt_lens,
59 | img_feat, img_pos_feat, num_bbs,
60 | attention_mask,
61 | output_all_encoded_layers=False,
62 | txt_type_ids=txt_type_ids)
63 | # get only the text part
64 | sequence_output = sequence_output[:, :input_ids.size(1), :]
65 | # only compute masked tokens for better efficiency
66 | prediction_scores = self.masked_compute_scores(
67 | sequence_output, txt_labels != -1)
68 | if self.vocab_pad:
69 | prediction_scores = prediction_scores[:, :-self.vocab_pad]
70 |
71 | if compute_loss:
72 | masked_lm_loss = F.cross_entropy(prediction_scores,
73 | txt_labels[txt_labels != -1],
74 | reduction='none')
75 | return masked_lm_loss
76 | else:
77 | return prediction_scores
78 |
79 | # MRM
80 | def forward_mrm(self, input_ids, position_ids, txt_type_ids, txt_lens,
81 | img_feat, img_pos_feat, num_bbs,
82 | attention_mask, img_masks, compute_loss=True):
83 | img_feat, feat_targets = mask_img_feat(img_feat, img_masks)
84 | sequence_output = self.bert(input_ids, position_ids, txt_lens,
85 | img_feat, img_pos_feat, num_bbs,
86 | attention_mask,
87 | output_all_encoded_layers=False,
88 | txt_type_ids=txt_type_ids)
89 |         # get only the image part
90 |         sequence_output = _get_image_hidden(sequence_output, txt_lens, num_bbs)
91 |         # only compute masked regions for better efficiency
92 | prediction_feat = self.masked_compute_feat(
93 | sequence_output, img_masks)
94 |
95 | if compute_loss:
96 | mrm_loss = F.mse_loss(prediction_feat, feat_targets,
97 | reduction='none')
98 | return mrm_loss
99 | else:
100 | return prediction_feat
101 |
102 | # MRC
103 | def forward_mrc(self, input_ids, position_ids, txt_type_ids, txt_lens,
104 | img_feat, img_pos_feat, num_bbs,
105 | attention_mask, img_masks,
106 | label_targets, task, compute_loss=True):
107 | img_feat = mask_img_feat_for_mrc(img_feat, img_masks)
108 | sequence_output = self.bert(input_ids, position_ids, txt_lens,
109 | img_feat, img_pos_feat, num_bbs,
110 | attention_mask,
111 | output_all_encoded_layers=False,
112 | txt_type_ids=txt_type_ids)
113 | # get only the image part
114 | sequence_output = _get_image_hidden(sequence_output, txt_lens, num_bbs)
115 |         # only compute masked regions for better efficiency
116 | prediction_soft_label = self.masked_predict_labels(
117 | sequence_output, img_masks)
118 |
119 | if compute_loss:
120 | if "kl" in task:
121 | prediction_soft_label = F.log_softmax(
122 | prediction_soft_label, dim=-1)
123 | mrc_loss = F.kl_div(
124 | prediction_soft_label, label_targets, reduction='none')
125 | else:
126 | label_targets = torch.max(
127 | label_targets, -1)[1] # argmax
128 | mrc_loss = F.cross_entropy(
129 | prediction_soft_label, label_targets,
130 | ignore_index=0, reduction='none')
131 | return mrc_loss
132 | else:
133 | return prediction_soft_label
134 |
--------------------------------------------------------------------------------
/model/itm.py:
--------------------------------------------------------------------------------
1 | """
2 | UNITER for ITM model
3 | """
4 | from collections import defaultdict
5 |
6 | import torch
7 | from torch import nn
8 | from .model import UniterPreTrainedModel, UniterModel, VLXLMRPreTrainedModel, VLXLMRModel
9 |
10 | #########################vl-xlmr########################
11 |
12 | class VLXLMRForImageTextRetrieval(VLXLMRPreTrainedModel):
13 |     """ Finetune VL-XLMR for image text retrieval
14 | """
15 | def __init__(self, config, img_dim, margin=0.2):
16 | super().__init__(config)
17 | self.roberta = VLXLMRModel(config, img_dim)
18 | self.itm_output = nn.Linear(config.hidden_size, 2)
19 | self.rank_output = nn.Linear(config.hidden_size, 1)
20 | self.margin = margin
21 | self.apply(self.init_weights)
22 |
23 | def init_output(self):
24 |         """ needs to be called after loading the pretrained weights """
25 | self.rank_output.weight.data = self.itm_output.weight.data[1:, :]
26 | self.rank_output.bias.data = self.itm_output.bias.data[1:]
27 |
28 | def forward(self, batch, compute_loss=True):
29 | batch = defaultdict(lambda: None, batch)
30 | input_ids = batch['input_ids']
31 |         # position_ids = batch['position_ids']
32 |         position_ids = None
33 | img_feat = batch['img_feat']
34 | img_pos_feat = batch['img_pos_feat']
35 | attention_mask = batch['attn_masks']
36 | gather_index = batch['gather_index']
37 |
38 | sequence_output = self.roberta(input_ids, position_ids,
39 | img_feat, img_pos_feat,
40 | attention_mask, gather_index,
41 | output_all_encoded_layers=False)
42 | pooled_output = self.roberta.pooler(sequence_output)
43 | rank_scores = self.rank_output(pooled_output)
44 |
45 | if compute_loss:
46 | # triplet loss
47 | rank_scores_sigmoid = torch.sigmoid(rank_scores)
48 | sample_size = batch['sample_size']
49 | scores = rank_scores_sigmoid.contiguous().view(-1, sample_size)
50 | pos = scores[:, :1]
51 | neg = scores[:, 1:]
52 | rank_loss = torch.clamp(self.margin + neg - pos, 0)
53 | return rank_loss
54 | else:
55 | return rank_scores
56 | #######################################################
57 | class UniterForImageTextRetrieval(UniterPreTrainedModel):
58 | """ Finetune UNITER for image text retrieval
59 | """
60 | def __init__(self, config, img_dim, margin=0.2):
61 | super().__init__(config)
62 | self.bert = UniterModel(config, img_dim)
63 | self.itm_output = nn.Linear(config.hidden_size, 2)
64 | #self.itm_output = nn.Linear(config.hidden_size, 1)
65 | self.rank_output = nn.Linear(config.hidden_size, 1)
66 | self.margin = margin
67 | self.apply(self.init_weights)
68 |
69 | def init_output(self):
70 |         """ needs to be called after loading the pretrained weights """
71 | self.rank_output.weight.data = self.itm_output.weight.data[1:, :]
72 | self.rank_output.bias.data = self.itm_output.bias.data[1:]
73 | # self.rank_output.weight.data = self.itm_output.weight.data
74 | # self.rank_output.bias.data = self.itm_output.bias.data
75 |
76 | def forward(self, batch, compute_loss=True):
77 | batch = defaultdict(lambda: None, batch)
78 | input_ids = batch['input_ids']
79 | position_ids = batch['position_ids']
80 | img_feat = batch['img_feat']
81 | img_pos_feat = batch['img_pos_feat']
82 | attention_mask = batch['attn_masks']
83 | gather_index = batch['gather_index']
84 | sequence_output = self.bert(input_ids, position_ids,
85 | img_feat, img_pos_feat,
86 | attention_mask, gather_index,
87 | output_all_encoded_layers=False)
88 | pooled_output = self.bert.pooler(sequence_output)
89 | #print(pooled_output)
90 | rank_scores = self.rank_output(pooled_output)
91 |
92 | if compute_loss:
93 | # triplet loss
94 | rank_scores_sigmoid = torch.sigmoid(rank_scores)
95 | sample_size = batch['sample_size']
96 | scores = rank_scores_sigmoid.contiguous().view(-1, sample_size)
97 | pos = scores[:, :1]
98 | neg = scores[:, 1:]
99 | rank_loss = torch.clamp(self.margin + neg - pos, 0)
100 | return rank_loss
101 | else:
102 | return rank_scores
103 |
104 |
105 | class UniterForImageTextRetrievalHardNeg(UniterForImageTextRetrieval):
106 |     """ Finetune UNITER for image text retrieval with hard negative mining
107 | """
108 | def __init__(self, config, img_dim, margin=0.2, hard_size=16):
109 | super().__init__(config, img_dim, margin)
110 | self.hard_size = hard_size
111 |
112 | def forward(self, batch, sample_from='t', compute_loss=True):
113 | # expect same input_ids for all pairs
114 | batch_size = batch['attn_masks'].size(0)
115 | input_ids = batch['input_ids']
116 | img_feat = batch['img_feat']
117 | img_pos_feat = batch['img_pos_feat']
118 | if sample_from == 't':
119 | if input_ids.size(0) == 1:
120 | batch['input_ids'] = input_ids.expand(batch_size, -1)
121 | elif sample_from == 'i':
122 | if img_feat.size(0) == 1:
123 | batch['img_feat'] = img_feat.expand(batch_size, -1, -1)
124 | if img_pos_feat.size(0) == 1:
125 | batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1)
126 | else:
127 | raise ValueError()
128 |
129 | if self.training and compute_loss:
130 | with torch.no_grad():
131 | self.eval()
132 | scores = super().forward(batch, compute_loss=False)
133 | hard_batch = self._get_hard_batch(batch, scores, sample_from)
134 | self.train()
135 | return super().forward(hard_batch, compute_loss=True)
136 | else:
137 | return super().forward(batch, compute_loss)
138 |
139 | def _get_hard_batch(self, batch, scores, sample_from='t'):
140 | batch = defaultdict(lambda: None, batch)
141 | input_ids = batch['input_ids']
142 | position_ids = batch['position_ids']
143 | img_feat = batch['img_feat']
144 | img_pos_feat = batch['img_pos_feat']
145 | attention_mask = batch['attn_masks']
146 | gather_index = batch['gather_index']
147 | hard_batch = {'sample_size': self.hard_size + 1}
148 |
149 | # NOTE first example is positive
150 | hard_indices = scores.squeeze(-1)[1:].topk(
151 | self.hard_size, sorted=False)[1] + 1
152 | indices = torch.cat([torch.zeros(1, dtype=torch.long,
153 | device=hard_indices.device),
154 | hard_indices])
155 |
156 | attention_mask = attention_mask.index_select(0, indices)
157 | gather_index = gather_index.index_select(0, indices)
158 | if position_ids.size(0) != 1:
159 | position_ids = position_ids[:self.hard_size+1]
160 |
161 | if sample_from == 't':
162 | # cut to minimum padding
163 | max_len = attention_mask.sum(dim=1).max().item()
164 | max_i = max_len - input_ids.size(1)
165 | attention_mask = attention_mask[:, :max_len]
166 | gather_index = gather_index[:, :max_len]
167 | img_feat = img_feat.index_select(0, indices)[:, :max_i, :]
168 | img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :]
169 | # expect same input_ids for all pairs
170 | input_ids = input_ids[:self.hard_size+1]
171 | elif sample_from == 'i':
172 | input_ids = input_ids.index_select(0, indices)
173 | # expect same image features for all pairs
174 | img_feat = img_feat[:self.hard_size+1]
175 | img_pos_feat = img_pos_feat[:self.hard_size+1]
176 | else:
177 | raise ValueError()
178 |
179 | hard_batch['input_ids'] = input_ids
180 | hard_batch['position_ids'] = position_ids
181 | hard_batch['img_feat'] = img_feat
182 | hard_batch['img_pos_feat'] = img_pos_feat
183 | hard_batch['attn_masks'] = attention_mask
184 | hard_batch['gather_index'] = gather_index
185 |
186 | return hard_batch
187 |
--------------------------------------------------------------------------------
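Note on /model/itm.py above: the triplet-style loss used by both retrieval heads reduces to a clamped margin over sigmoid scores. A tiny numeric sketch; the sample size, margin, and raw scores below are illustrative.

    import torch

    # 2 anchors, each with one positive pair followed by 2 negatives (sample_size = 3)
    rank_scores = torch.tensor([[2.0], [0.5], [0.1],
                                [1.5], [1.4], [-1.0]])
    scores = torch.sigmoid(rank_scores).view(-1, 3)   # (2, sample_size)
    pos, neg = scores[:, :1], scores[:, 1:]
    rank_loss = torch.clamp(0.2 + neg - pos, 0)       # zero once pos beats each neg by the margin
    print(rank_loss.mean())
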
/model/layer.py:
--------------------------------------------------------------------------------
1 | """
2 | BERT layers from the huggingface implementation
3 | (https://github.com/huggingface/transformers)
4 | """
5 | # coding=utf-8
6 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
7 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | import logging
21 | import math
22 |
23 | import torch
24 | from torch import nn
25 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
26 |
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 |
31 | def gelu(x):
32 | """Implementation of the gelu activation function.
33 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
34 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
35 | Also see https://arxiv.org/abs/1606.08415
36 | """
37 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
38 |
39 |
40 | def swish(x):
41 | return x * torch.sigmoid(x)
42 |
43 |
44 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
45 |
46 |
47 | class GELU(nn.Module):
48 | def forward(self, input_):
49 | output = gelu(input_)
50 | return output
51 |
52 |
53 | class BertSelfAttention(nn.Module):
54 | def __init__(self, config):
55 | super(BertSelfAttention, self).__init__()
56 | if config.hidden_size % config.num_attention_heads != 0:
57 | raise ValueError(
58 | "The hidden size (%d) is not a multiple of the number of attention "
59 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
60 | self.num_attention_heads = config.num_attention_heads
61 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
62 | self.all_head_size = self.num_attention_heads * self.attention_head_size
63 |
64 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
65 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
66 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
67 |
68 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
69 |
70 | def transpose_for_scores(self, x):
71 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
72 | x = x.view(*new_x_shape)
73 | return x.permute(0, 2, 1, 3)
74 |
75 | def forward(self, hidden_states, attention_mask):
76 | mixed_query_layer = self.query(hidden_states)
77 | mixed_key_layer = self.key(hidden_states)
78 | mixed_value_layer = self.value(hidden_states)
79 |
80 | query_layer = self.transpose_for_scores(mixed_query_layer)
81 | key_layer = self.transpose_for_scores(mixed_key_layer)
82 | value_layer = self.transpose_for_scores(mixed_value_layer)
83 |
84 | # Take the dot product between "query" and "key" to get the raw attention scores.
85 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
86 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
87 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
88 | attention_scores = attention_scores + attention_mask
89 |
90 | # Normalize the attention scores to probabilities.
91 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
92 |
93 | # This is actually dropping out entire tokens to attend to, which might
94 | # seem a bit unusual, but is taken from the original Transformer paper.
95 | attention_probs = self.dropout(attention_probs)
96 |
97 | context_layer = torch.matmul(attention_probs, value_layer)
98 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
99 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
100 | context_layer = context_layer.view(*new_context_layer_shape)
101 | return context_layer
102 |
103 |
104 | class BertSelfOutput(nn.Module):
105 | def __init__(self, config):
106 | super(BertSelfOutput, self).__init__()
107 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
108 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
109 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
110 |
111 | def forward(self, hidden_states, input_tensor):
112 | hidden_states = self.dense(hidden_states)
113 | hidden_states = self.dropout(hidden_states)
114 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
115 | return hidden_states
116 |
117 |
118 | class BertAttention(nn.Module):
119 | def __init__(self, config):
120 | super(BertAttention, self).__init__()
121 | self.self = BertSelfAttention(config)
122 | self.output = BertSelfOutput(config)
123 |
124 | def forward(self, input_tensor, attention_mask):
125 | self_output = self.self(input_tensor, attention_mask)
126 | attention_output = self.output(self_output, input_tensor)
127 | return attention_output
128 |
129 |
130 | class BertIntermediate(nn.Module):
131 | def __init__(self, config):
132 | super(BertIntermediate, self).__init__()
133 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
134 | if isinstance(config.hidden_act, str):
135 | self.intermediate_act_fn = ACT2FN[config.hidden_act]
136 | else:
137 | self.intermediate_act_fn = config.hidden_act
138 |
139 | def forward(self, hidden_states):
140 | hidden_states = self.dense(hidden_states)
141 | hidden_states = self.intermediate_act_fn(hidden_states)
142 | return hidden_states
143 |
144 |
145 | class BertOutput(nn.Module):
146 | def __init__(self, config):
147 | super(BertOutput, self).__init__()
148 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
149 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
150 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
151 |
152 | def forward(self, hidden_states, input_tensor):
153 | hidden_states = self.dense(hidden_states)
154 | hidden_states = self.dropout(hidden_states)
155 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
156 | return hidden_states
157 |
158 |
159 | class BertLayer(nn.Module):
160 | def __init__(self, config):
161 | super(BertLayer, self).__init__()
162 | self.attention = BertAttention(config)
163 | self.intermediate = BertIntermediate(config)
164 | self.output = BertOutput(config)
165 |
166 | def forward(self, hidden_states, attention_mask):
167 | attention_output = self.attention(hidden_states, attention_mask)
168 | intermediate_output = self.intermediate(attention_output)
169 | layer_output = self.output(intermediate_output, attention_output)
170 | return layer_output
171 |
172 |
173 | class BertPooler(nn.Module):
174 | def __init__(self, config):
175 | super(BertPooler, self).__init__()
176 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
177 | self.activation = nn.Tanh()
178 |
179 | def forward(self, hidden_states):
180 | # We "pool" the model by simply taking the hidden state corresponding
181 | # to the first token.
182 | first_token_tensor = hidden_states[:, 0]
183 | pooled_output = self.dense(first_token_tensor)
184 | pooled_output = self.activation(pooled_output)
185 | return pooled_output
186 |
187 |
188 | class BertPredictionHeadTransform(nn.Module):
189 | def __init__(self, config):
190 | super(BertPredictionHeadTransform, self).__init__()
191 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
192 | if isinstance(config.hidden_act, str):
193 | self.transform_act_fn = ACT2FN[config.hidden_act]
194 | else:
195 | self.transform_act_fn = config.hidden_act
196 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
197 |
198 | def forward(self, hidden_states):
199 | hidden_states = self.dense(hidden_states)
200 | hidden_states = self.transform_act_fn(hidden_states)
201 | hidden_states = self.LayerNorm(hidden_states)
202 | return hidden_states
203 |
204 |
205 | class BertLMPredictionHead(nn.Module):
206 | def __init__(self, config, bert_model_embedding_weights):
207 | super(BertLMPredictionHead, self).__init__()
208 | self.transform = BertPredictionHeadTransform(config)
209 |
210 | # The output weights are the same as the input embeddings, but there is
211 | # an output-only bias for each token.
212 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
213 | bert_model_embedding_weights.size(0),
214 | bias=False)
215 | self.decoder.weight = bert_model_embedding_weights
216 | self.bias = nn.Parameter(
217 | torch.zeros(bert_model_embedding_weights.size(0)))
218 |
219 | def forward(self, hidden_states):
220 | hidden_states = self.transform(hidden_states)
221 | hidden_states = self.decoder(hidden_states) + self.bias
222 | return hidden_states
223 |
224 |
225 | class BertOnlyMLMHead(nn.Module):
226 | def __init__(self, config, bert_model_embedding_weights):
227 | super(BertOnlyMLMHead, self).__init__()
228 | self.predictions = BertLMPredictionHead(config,
229 | bert_model_embedding_weights)
230 |
231 | def forward(self, sequence_output):
232 | prediction_scores = self.predictions(sequence_output)
233 | return prediction_scores
234 |
235 | ####################Mingyang Zhou##############################
236 | class RobertaLMHead(nn.Module):
237 | """Roberta Head for masked language modeling."""
238 |
239 | def __init__(self, config, roberta_model_embedding_weights):
240 | super().__init__()
241 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
242 | self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
243 |
244 | #self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
245 | self.decoder = nn.Linear(roberta_model_embedding_weights.size(1),
246 | roberta_model_embedding_weights.size(0),
247 | bias=False
248 | )
249 | self.decoder.weight = roberta_model_embedding_weights
250 | self.bias = nn.Parameter(
251 | torch.zeros(roberta_model_embedding_weights.size(0)))
252 | #self.bias = nn.Parameter(torch.zeros(config.vocab_size))
253 |
254 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
255 | self.decoder.bias = self.bias
256 |
257 | def forward(self, features, **kwargs):
258 | x = self.dense(features)
259 | x = gelu(x)
260 | x = self.layer_norm(x)
261 |
262 | # project back to size of vocabulary with bias
263 | x = self.decoder(x)
264 |
265 | return x
266 |
267 | class VisualRobertaLMHead(nn.Module):
268 | def __init__(self, config, roberta_model_embedding_weights, valid_roberta_model_embedding_index):
269 | super().__init__()
270 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
271 | self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
272 |
273 | #self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
274 | self.decoder = nn.Linear(roberta_model_embedding_weights.size(1),
275 | len(valid_roberta_model_embedding_index),
276 | bias=False
277 | )
278 | self.decoder.weight = nn.Parameter(roberta_model_embedding_weights[valid_roberta_model_embedding_index])
279 | self.bias = nn.Parameter(
280 | torch.zeros(len(valid_roberta_model_embedding_index)))
281 | #self.bias = nn.Parameter(torch.zeros(config.vocab_size))
282 |
283 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
284 | self.decoder.bias = self.bias
285 |
286 | def forward(self, features, **kwargs):
287 | x = self.dense(features)
288 | x = gelu(x)
289 | x = self.layer_norm(x)
290 |
291 | # project back to size of vocabulary with bias
292 | x = self.decoder(x)
293 |
294 | return x
--------------------------------------------------------------------------------
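
The prediction heads in layer.py all tie their output projection to the word-embedding matrix, and VisualRobertaLMHead additionally restricts the projection to the sub-vocabulary picked out by valid_roberta_model_embedding_index. A minimal, self-contained sketch of that weight-tying pattern, using made-up sizes rather than the repo's config objects:

import torch
from torch import nn

# toy sizes standing in for config.vocab_size / config.hidden_size (assumed values)
vocab_size, hidden_size = 1000, 16
embeddings = nn.Embedding(vocab_size, hidden_size)

# the decoder shares its weight tensor with the embedding table and adds an
# output-only bias, as in BertLMPredictionHead / RobertaLMHead above
decoder = nn.Linear(hidden_size, vocab_size, bias=False)
decoder.weight = embeddings.weight
bias = nn.Parameter(torch.zeros(vocab_size))

hidden = torch.randn(2, 5, hidden_size)      # [batch, seq, hidden]
logits = decoder(hidden) + bias              # [batch, seq, vocab]

# VisualRobertaLMHead does the analogous thing for a sub-vocabulary: here a
# hypothetical index list selects a copy of three embedding rows
valid_index = torch.tensor([0, 3, 7])
sub_decoder = nn.Linear(hidden_size, len(valid_index), bias=False)
sub_decoder.weight = nn.Parameter(embeddings.weight.data[valid_index].clone())
sub_logits = sub_decoder(hidden)             # [batch, seq, 3]
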
/model/nlvr2.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | Uniter for NLVR2 model
6 | """
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from .layer import GELU
12 | from .model import UniterPreTrainedModel, UniterModel
13 | from .attention import MultiheadAttention
14 |
15 |
16 | class UniterForNlvr2Paired(UniterPreTrainedModel):
17 | """ Finetune UNITER for NLVR2 (paired format)
18 | """
19 | def __init__(self, config, img_dim):
20 | super().__init__(config)
21 | self.bert = UniterModel(config, img_dim)
22 | self.nlvr2_output = nn.Linear(config.hidden_size*2, 2)
23 | self.apply(self.init_weights)
24 |
25 | def init_type_embedding(self):
26 | new_emb = nn.Embedding(3, self.bert.config.hidden_size)
27 | new_emb.apply(self.init_weights)
28 | for i in [0, 1]:
29 | emb = self.bert.embeddings.token_type_embeddings\
30 | .weight.data[i, :]
31 | new_emb.weight.data[i, :].copy_(emb)
32 | new_emb.weight.data[2, :].copy_(emb)
33 | self.bert.embeddings.token_type_embeddings = new_emb
34 |
35 | def forward(self, input_ids, position_ids, img_feat, img_pos_feat,
36 | attn_masks, gather_index,
37 | img_type_ids, targets, compute_loss=True):
38 | sequence_output = self.bert(input_ids, position_ids,
39 | img_feat, img_pos_feat,
40 | attn_masks, gather_index,
41 | output_all_encoded_layers=False,
42 | img_type_ids=img_type_ids)
43 | pooled_output = self.bert.pooler(sequence_output)
44 | # concat CLS of the pair
45 | n_pair = pooled_output.size(0) // 2
46 | reshaped_output = pooled_output.contiguous().view(n_pair, -1)
47 | answer_scores = self.nlvr2_output(reshaped_output)
48 |
49 | if compute_loss:
50 | nlvr2_loss = F.cross_entropy(
51 | answer_scores, targets, reduction='none')
52 | return nlvr2_loss
53 | else:
54 | return answer_scores
55 |
56 |
57 | class UniterForNlvr2Triplet(UniterPreTrainedModel):
58 | """ Finetune UNITER for NLVR2 (triplet format)
59 | """
60 | def __init__(self, config, img_dim):
61 | super().__init__(config)
62 | self.bert = UniterModel(config, img_dim)
63 | self.nlvr2_output = nn.Linear(config.hidden_size, 2)
64 | self.apply(self.init_weights)
65 |
66 | def init_type_embedding(self):
67 | new_emb = nn.Embedding(3, self.bert.config.hidden_size)
68 | new_emb.apply(self.init_weights)
69 | for i in [0, 1]:
70 | emb = self.bert.embeddings.token_type_embeddings\
71 | .weight.data[i, :]
72 | new_emb.weight.data[i, :].copy_(emb)
73 | new_emb.weight.data[2, :].copy_(emb)
74 | self.bert.embeddings.token_type_embeddings = new_emb
75 |
76 | def forward(self, input_ids, position_ids, img_feat, img_pos_feat,
77 | attn_masks, gather_index,
78 | img_type_ids, targets, compute_loss=True):
79 | sequence_output = self.bert(input_ids, position_ids,
80 | img_feat, img_pos_feat,
81 | attn_masks, gather_index,
82 | output_all_encoded_layers=False,
83 | img_type_ids=img_type_ids)
84 | pooled_output = self.bert.pooler(sequence_output)
85 | answer_scores = self.nlvr2_output(pooled_output)
86 |
87 | if compute_loss:
88 | nlvr2_loss = F.cross_entropy(
89 | answer_scores, targets, reduction='none')
90 | return nlvr2_loss
91 | else:
92 | return answer_scores
93 |
94 |
95 | class AttentionPool(nn.Module):
96 | """ attention pooling layer """
97 | def __init__(self, hidden_size, drop=0.0):
98 | super().__init__()
99 | self.fc = nn.Sequential(nn.Linear(hidden_size, 1), GELU())
100 | self.dropout = nn.Dropout(drop)
101 |
102 | def forward(self, input_, mask=None):
103 | """input: [B, T, D], mask = [B, T]"""
104 | score = self.fc(input_).squeeze(-1)
105 | if mask is not None:
106 | mask = mask.to(dtype=input_.dtype) * -1e4
107 | score = score + mask
108 | norm_score = self.dropout(F.softmax(score, dim=1))
109 | output = norm_score.unsqueeze(1).matmul(input_).squeeze(1)
110 | return output
111 |
112 |
113 | class UniterForNlvr2PairedAttn(UniterPreTrainedModel):
114 | """ Finetune UNITER for NLVR2
115 | (paired format with additional attention layer)
116 | """
117 | def __init__(self, config, img_dim):
118 | super().__init__(config)
119 | self.bert = UniterModel(config, img_dim)
120 | self.attn1 = MultiheadAttention(config.hidden_size,
121 | config.num_attention_heads,
122 | config.attention_probs_dropout_prob)
123 | self.attn2 = MultiheadAttention(config.hidden_size,
124 | config.num_attention_heads,
125 | config.attention_probs_dropout_prob)
126 | self.fc = nn.Sequential(
127 | nn.Linear(2*config.hidden_size, config.hidden_size),
128 | GELU(),
129 | nn.Dropout(config.hidden_dropout_prob))
130 | self.attn_pool = AttentionPool(config.hidden_size,
131 | config.attention_probs_dropout_prob)
132 | self.nlvr2_output = nn.Linear(2*config.hidden_size, 2)
133 | self.apply(self.init_weights)
134 |
135 | def init_type_embedding(self):
136 | new_emb = nn.Embedding(3, self.bert.config.hidden_size)
137 | new_emb.apply(self.init_weights)
138 | for i in [0, 1]:
139 | emb = self.bert.embeddings.token_type_embeddings\
140 | .weight.data[i, :]
141 | new_emb.weight.data[i, :].copy_(emb)
142 | new_emb.weight.data[2, :].copy_(emb)
143 | self.bert.embeddings.token_type_embeddings = new_emb
144 |
145 | def forward(self, input_ids, position_ids, img_feat, img_pos_feat,
146 | attn_masks, gather_index,
147 | img_type_ids, targets, compute_loss=True):
148 | sequence_output = self.bert(input_ids, position_ids,
149 | img_feat, img_pos_feat,
150 | attn_masks, gather_index,
151 | output_all_encoded_layers=False,
152 | img_type_ids=img_type_ids)
153 | # separate left image and right image
154 | bs, tl, d = sequence_output.size()
155 | left_out, right_out = sequence_output.contiguous().view(
156 | bs//2, tl*2, d).chunk(2, dim=1)
157 | # bidirectional attention
158 | mask = attn_masks == 0
159 | left_mask, right_mask = mask.contiguous().view(bs//2, tl*2
160 | ).chunk(2, dim=1)
161 | left_out = left_out.transpose(0, 1)
162 | right_out = right_out.transpose(0, 1)
163 | l2r_attn, _ = self.attn1(left_out, right_out, right_out,
164 | key_padding_mask=right_mask)
165 | r2l_attn, _ = self.attn2(right_out, left_out, left_out,
166 | key_padding_mask=left_mask)
167 | left_out = self.fc(torch.cat([l2r_attn, left_out], dim=-1)
168 | ).transpose(0, 1)
169 | right_out = self.fc(torch.cat([r2l_attn, right_out], dim=-1)
170 | ).transpose(0, 1)
171 | # attention pooling and final prediction
172 | left_out = self.attn_pool(left_out, left_mask)
173 | right_out = self.attn_pool(right_out, right_mask)
174 | answer_scores = self.nlvr2_output(
175 | torch.cat([left_out, right_out], dim=-1))
176 |
177 | if compute_loss:
178 | nlvr2_loss = F.cross_entropy(
179 | answer_scores, targets, reduction='none')
180 | return nlvr2_loss
181 | else:
182 | return answer_scores
183 |
--------------------------------------------------------------------------------
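
In the paired setup above, the two (image, statement) rows of each NLVR2 example sit next to each other in the batch, and their pooled [CLS] vectors are fused by the view(n_pair, -1) reshape before the 2-way classifier. A tiny shape-only sketch of that fusion, with random tensors and an untrained linear head:

import torch
from torch import nn

hidden_size, n_pair = 768, 3
# rows 2k and 2k+1 are the two images of pair k
pooled_output = torch.randn(2 * n_pair, hidden_size)

# same fusion as UniterForNlvr2Paired.forward
reshaped_output = pooled_output.contiguous().view(n_pair, -1)   # [n_pair, 2*hidden]
nlvr2_output = nn.Linear(hidden_size * 2, 2)
answer_scores = nlvr2_output(reshaped_output)                   # [n_pair, 2]
print(answer_scores.shape)
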
/model/ot.py:
--------------------------------------------------------------------------------
1 | """
2 | Wasserstein Distance (Optimal Transport)
3 | """
4 | import torch
5 | from torch.nn import functional as F
6 |
7 |
8 | def cost_matrix_cosine(x, y, eps=1e-5):
9 |     """ Compute cosine distance across every pair of x, y (batched)
10 |     [B, L_x, D], [B, L_y, D] -> [B, L_x, L_y]"""
11 | assert x.dim() == y.dim()
12 | assert x.size(0) == y.size(0)
13 | assert x.size(2) == y.size(2)
14 | x_norm = F.normalize(x, p=2, dim=-1, eps=eps)
15 | y_norm = F.normalize(y, p=2, dim=-1, eps=eps)
16 | cosine_sim = x_norm.matmul(y_norm.transpose(1, 2))
17 | cosine_dist = 1 - cosine_sim
18 | return cosine_dist
19 |
20 |
21 | def trace(x):
22 | """ compute trace of input tensor (batched) """
23 | b, m, n = x.size()
24 | assert m == n
25 | mask = torch.eye(n, dtype=torch.uint8, device=x.device
26 | ).unsqueeze(0).expand_as(x)
27 | trace = x.masked_select(mask).contiguous().view(
28 | b, n).sum(dim=-1, keepdim=False)
29 | return trace
30 |
31 |
32 | @torch.no_grad()
33 | def ipot(C, x_len, x_pad, y_len, y_pad, joint_pad, beta, iteration, k):
34 | """ [B, M, N], [B], [B, M], [B], [B, N], [B, M, N]"""
35 | b, m, n = C.size()
36 | sigma = torch.ones(b, m, dtype=C.dtype, device=C.device
37 | ) / x_len.unsqueeze(1)
38 | T = torch.ones(b, n, m, dtype=C.dtype, device=C.device)
39 | A = torch.exp(-C.transpose(1, 2)/beta)
40 |
41 | # mask padded positions
42 | sigma.masked_fill_(x_pad, 0)
43 | joint_pad = joint_pad.transpose(1, 2)
44 | T.masked_fill_(joint_pad, 0)
45 | A.masked_fill_(joint_pad, 0)
46 |
47 | # broadcastable lengths
48 | x_len = x_len.unsqueeze(1).unsqueeze(2)
49 | y_len = y_len.unsqueeze(1).unsqueeze(2)
50 |
51 | # mask to zero out padding in delta and sigma
52 | x_mask = (x_pad.to(C.dtype) * 1e4).unsqueeze(1)
53 | y_mask = (y_pad.to(C.dtype) * 1e4).unsqueeze(1)
54 |
55 | for _ in range(iteration):
56 | Q = A * T # bs * n * m
57 | sigma = sigma.view(b, m, 1)
58 | for _ in range(k):
59 | delta = 1 / (y_len * Q.matmul(sigma).view(b, 1, n) + y_mask)
60 | sigma = 1 / (x_len * delta.matmul(Q) + x_mask)
61 | T = delta.view(b, n, 1) * Q * sigma
62 | T.masked_fill_(joint_pad, 0)
63 | return T
64 |
65 |
66 | def optimal_transport_dist(txt_emb, img_emb, txt_pad, img_pad,
67 | beta=0.5, iteration=50, k=1):
68 | """ [B, M, D], [B, N, D], [B, M], [B, N]"""
69 | cost = cost_matrix_cosine(txt_emb, img_emb)
70 | # mask the padded inputs
71 | joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
72 | cost.masked_fill_(joint_pad, 0)
73 |
74 | txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False)
75 | ).to(dtype=cost.dtype)
76 | img_len = (img_pad.size(1) - img_pad.sum(dim=1, keepdim=False)
77 | ).to(dtype=cost.dtype)
78 |
79 | T = ipot(cost.detach(), txt_len, txt_pad, img_len, img_pad, joint_pad,
80 | beta, iteration, k)
81 | distance = trace(cost.matmul(T.detach()))
82 | return distance
83 |
--------------------------------------------------------------------------------
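
A quick usage sketch of optimal_transport_dist; it assumes the repo root is on PYTHONPATH (importing model.ot pulls in the rest of the model package and its dependencies) and the older PyTorch this code targets, since trace() masks with a uint8 tensor:

import torch
from model.ot import optimal_transport_dist

B, M, N, D = 2, 4, 6, 32
txt_emb = torch.randn(B, M, D)
img_emb = torch.randn(B, N, D)
# padding masks: False = real token/region, True = padding
txt_pad = torch.zeros(B, M, dtype=torch.bool)
img_pad = torch.zeros(B, N, dtype=torch.bool)
txt_pad[1, -1] = True        # pretend sample 1 has one padded token
img_pad[1, -2:] = True       # ... and two padded regions

dist = optimal_transport_dist(txt_emb, img_emb, txt_pad, img_pad)
print(dist.shape)            # torch.Size([2]): one OT distance per pair
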
/model/re.py:
--------------------------------------------------------------------------------
1 | """
2 | Bert for Referring Expression Comprehension
3 | """
4 | import sys
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 | from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertLayerNorm
9 |
10 | from .model import BertVisionLanguageEncoder
11 |
12 | import numpy as np
13 | import random
14 |
15 |
16 | class BertForReferringExpressionComprehension(BertPreTrainedModel):
17 |     """Finetune multi-modal BERT for Referring Expression Comprehension
18 | """
19 | def __init__(self, config, img_dim, loss="cls",
20 | margin=0.2, hard_ratio=0.3, mlp=1):
21 | super().__init__(config)
22 | self.bert = BertVisionLanguageEncoder(config, img_dim)
23 | if mlp == 1:
24 | self.re_output = nn.Linear(config.hidden_size, 1)
25 | elif mlp == 2:
26 | self.re_output = nn.Sequential(
27 | nn.Linear(config.hidden_size, config.hidden_size),
28 | nn.ReLU(),
29 | BertLayerNorm(config.hidden_size, eps=1e-12),
30 | nn.Linear(config.hidden_size, 1)
31 | )
32 | else:
33 | sys.exit("MLP restricted to be 1 or 2 layers.")
34 | self.loss = loss
35 | assert self.loss in ['cls', 'rank']
36 | if self.loss == 'rank':
37 | self.margin = margin
38 | self.hard_ratio = hard_ratio
39 | else:
40 | self.crit = nn.CrossEntropyLoss(reduction='none')
41 | # initialize
42 | self.apply(self.init_bert_weights)
43 |
44 | def forward(self, input_ids, position_ids, txt_lens, img_feat,
45 | img_pos_feat, num_bbs, attn_masks, obj_masks,
46 | targets, compute_loss=True):
47 | sequence_output = self.bert(input_ids, position_ids, txt_lens,
48 | img_feat, img_pos_feat, num_bbs,
49 | attn_masks,
50 | output_all_encoded_layers=False)
51 | # get only the region part
52 | sequence_output = self._get_image_hidden(sequence_output, txt_lens,
53 | num_bbs)
54 |
55 | # re score (n, max_num_bb)
56 | scores = self.re_output(sequence_output).squeeze(2)
57 | scores = scores.masked_fill(obj_masks, -1e4) # mask out non-objects
58 |
59 | # loss
60 | if compute_loss:
61 | if self.loss == 'cls':
62 | ce_loss = self.crit(scores, targets) # (n, ) as no reduction
63 | return ce_loss
64 | else:
65 | # ranking
66 | _n = len(num_bbs)
67 | # positive (target)
68 | pos_ix = targets
69 | pos_sc = scores.gather(1, pos_ix.view(_n, 1)) # (n, 1)
70 | pos_sc = torch.sigmoid(pos_sc).view(-1) # (n, ) sc[0, 1]
71 | # negative
72 | neg_ix = self.sample_neg_ix(scores, targets, num_bbs)
73 | neg_sc = scores.gather(1, neg_ix.view(_n, 1)) # (n, 1)
74 | neg_sc = torch.sigmoid(neg_sc).view(-1) # (n, ) sc[0, 1]
75 | # ranking
76 | mm_loss = torch.clamp(self.margin + neg_sc - pos_sc, 0) # (n, )
77 | return mm_loss
78 | else:
79 | # (n, max_num_bb)
80 | return scores
81 |
82 | def sample_neg_ix(self, scores, targets, num_bbs):
83 | """
84 | Inputs:
85 | :scores (n, max_num_bb)
86 | :targets (n, )
87 | :num_bbs list of [num_bb]
88 | return:
89 | :neg_ix (n, ) easy/hard negative (!= target)
90 | """
91 | neg_ix = []
92 | cand_ixs = torch.argsort(scores, dim=-1, descending=True) # (n, num_bb)
93 | for i in range(len(num_bbs)):
94 | num_bb = num_bbs[i]
95 | if np.random.uniform(0, 1, 1) < self.hard_ratio:
96 | # sample hard negative, w/ highest score
97 | for ix in cand_ixs[i].tolist():
98 | if ix != targets[i]:
99 | assert ix < num_bb, f'ix={ix}, num_bb={num_bb}'
100 | neg_ix.append(ix)
101 | break
102 | else:
103 | # sample easy negative, i.e., random one
104 | ix = random.randint(0, num_bb-1) # [0, num_bb-1]
105 | while ix == targets[i]:
106 | ix = random.randint(0, num_bb-1)
107 | neg_ix.append(ix)
108 | neg_ix = torch.tensor(neg_ix).type(targets.type())
109 | assert neg_ix.numel() == targets.numel()
110 | return neg_ix
111 |
112 | def _get_image_hidden(self, sequence_output, txt_lens, num_bbs):
113 | """
114 | Extracting the img_hidden part from sequence_output.
115 | Inputs:
116 | - sequence_output: (n, txt_len+num_bb, hid_size)
117 | - txt_lens : [txt_len]
118 | - num_bbs : [num_bb]
119 | Output:
120 | - img_hidden : (n, max_num_bb, hid_size)
121 | """
122 | outputs = []
123 | max_bb = max(num_bbs)
124 | hid_size = sequence_output.size(-1)
125 | for seq_out, len_, nbb in zip(sequence_output.split(1, dim=0),
126 | txt_lens, num_bbs):
127 | img_hid = seq_out[:, len_:len_+nbb, :]
128 | if nbb < max_bb:
129 | img_hid = torch.cat(
130 | [img_hid, self._get_pad(img_hid, max_bb-nbb, hid_size)],
131 | dim=1)
132 | outputs.append(img_hid)
133 |
134 | img_hidden = torch.cat(outputs, dim=0)
135 | return img_hidden
136 |
137 | def _get_pad(self, t, len_, hidden_size):
138 | pad = torch.zeros(1, len_, hidden_size, dtype=t.dtype, device=t.device)
139 | return pad
140 |
141 |
--------------------------------------------------------------------------------
/model/vcr.py:
--------------------------------------------------------------------------------
1 | """
2 | Bert for VCR model
3 | """
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from pytorch_pretrained_bert.modeling import (
7 | BertPreTrainedModel, BertEmbeddings, BertEncoder, BertLayerNorm,
8 | BertPooler, BertOnlyMLMHead)
9 | from .model import (BertTextEmbeddings, BertImageEmbeddings,
10 | BertForImageTextMaskedLM,
11 | BertVisionLanguageEncoder,
12 | BertForImageTextPretraining,
13 | _get_image_hidden,
14 | mask_img_feat,
15 | RegionFeatureRegression,
16 | mask_img_feat_for_mrc,
17 | RegionClassification)
18 | import torch
19 | import random
20 |
21 |
22 | class BertVisionLanguageEncoderForVCR(BertVisionLanguageEncoder):
23 | """ Modification for Joint Vision-Language Encoding
24 | """
25 | def __init__(self, config, img_dim, num_region_toks):
26 | BertPreTrainedModel.__init__(self, config)
27 | self.embeddings = BertTextEmbeddings(config)
28 | self.img_embeddings = BertImageEmbeddings(config, img_dim)
29 | self.num_region_toks = num_region_toks
30 | self.region_token_embeddings = nn.Embedding(
31 | num_region_toks,
32 | config.hidden_size)
33 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
34 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
35 | self.encoder = BertEncoder(config)
36 | self.pooler = BertPooler(config)
37 | self.apply(self.init_bert_weights)
38 |
39 | def forward(self, input_ids, position_ids, txt_lens,
40 | img_feat, img_pos_feat, num_bbs,
41 | attention_mask, output_all_encoded_layers=True,
42 | txt_type_ids=None, img_type_ids=None, region_tok_ids=None):
43 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
44 | extended_attention_mask = extended_attention_mask.to(
45 | dtype=next(self.parameters()).dtype) # fp16 compatibility
46 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
47 |
48 | embedding_output = self._compute_img_txt_embeddings(
49 | input_ids, position_ids, txt_lens,
50 | img_feat, img_pos_feat, num_bbs, attention_mask.size(1),
51 | txt_type_ids, img_type_ids)
52 | if region_tok_ids is not None:
53 | region_tok_embeddings = self.region_token_embeddings(
54 | region_tok_ids)
55 | embedding_output += region_tok_embeddings
56 | embedding_output = self.LayerNorm(embedding_output)
57 | embedding_output = self.dropout(embedding_output)
58 | encoded_layers = self.encoder(
59 | embedding_output, extended_attention_mask,
60 | output_all_encoded_layers=output_all_encoded_layers)
61 | if not output_all_encoded_layers:
62 | encoded_layers = encoded_layers[-1]
63 | return encoded_layers
64 |
65 |
66 | class BertForVisualCommonsenseReasoning(BertPreTrainedModel):
67 |     """ Finetune multi-modal BERT for VCR
68 | """
69 | def __init__(self, config, img_dim, obj_cls=True, img_label_dim=81):
70 | super().__init__(config, img_dim)
71 | self.bert = BertVisionLanguageEncoder(
72 | config, img_dim)
73 | # self.vcr_output = nn.Linear(config.hidden_size, 1)
74 | # self.vcr_output = nn.Linear(config.hidden_size, 2)
75 | self.vcr_output = nn.Sequential(
76 | nn.Linear(config.hidden_size, config.hidden_size*2),
77 | nn.ReLU(),
78 | BertLayerNorm(config.hidden_size*2, eps=1e-12),
79 | nn.Linear(config.hidden_size*2, 2)
80 | )
81 | self.apply(self.init_bert_weights)
82 | self.obj_cls = obj_cls
83 | if self.obj_cls:
84 | self.region_classifier = RegionClassification(
85 | config.hidden_size, img_label_dim)
86 |
87 | def init_type_embedding(self):
88 | new_emb = nn.Embedding(4, self.bert.config.hidden_size)
89 | new_emb.apply(self.init_bert_weights)
90 | for i in [0, 1]:
91 | emb = self.bert.embeddings.token_type_embeddings.weight.data[i, :]
92 | new_emb.weight.data[i, :].copy_(emb)
93 | emb = self.bert.embeddings.token_type_embeddings.weight.data[0, :]
94 | new_emb.weight.data[2, :].copy_(emb)
95 | new_emb.weight.data[3, :].copy_(emb)
96 | self.bert.embeddings.token_type_embeddings = new_emb
97 |
98 | def init_word_embedding(self, num_special_tokens):
99 | orig_word_num = self.bert.embeddings.word_embeddings.weight.size(0)
100 | new_emb = nn.Embedding(
101 | orig_word_num + num_special_tokens, self.bert.config.hidden_size)
102 | new_emb.apply(self.init_bert_weights)
103 | emb = self.bert.embeddings.word_embeddings.weight.data
104 | new_emb.weight.data[:orig_word_num, :].copy_(emb)
105 | self.bert.embeddings.word_embeddings = new_emb
106 |
107 | def masked_predict_labels(self, sequence_output, mask):
108 | # only compute masked outputs
109 | mask = mask.unsqueeze(-1).expand_as(sequence_output)
110 | sequence_output_masked = sequence_output[mask].contiguous().view(
111 | -1, self.config.hidden_size)
112 | prediction_soft_label = self.region_classifier(sequence_output_masked)
113 |
114 | return prediction_soft_label
115 |
116 | def forward(self, input_ids, position_ids, txt_lens, txt_type_ids,
117 | img_feat, img_pos_feat, num_bbs,
118 | attention_mask, targets, obj_targets=None, img_masks=None,
119 | region_tok_ids=None, compute_loss=True):
120 | sequence_output = self.bert(input_ids, position_ids, txt_lens,
121 | img_feat, img_pos_feat, num_bbs,
122 | attention_mask,
123 | output_all_encoded_layers=False,
124 | txt_type_ids=txt_type_ids)
125 | pooled_output = self.bert.pooler(sequence_output)
126 | rank_scores = self.vcr_output(pooled_output)
127 | # rank_scores = rank_scores.reshape((-1, 4))
128 |
129 | if self.obj_cls and img_masks is not None:
130 | img_feat = mask_img_feat_for_mrc(img_feat, img_masks)
131 | masked_sequence_output = self.bert(
132 | input_ids, position_ids, txt_lens,
133 | img_feat, img_pos_feat, num_bbs,
134 | attention_mask,
135 | output_all_encoded_layers=False,
136 | txt_type_ids=txt_type_ids)
137 | # get only the image part
138 | img_sequence_output = _get_image_hidden(
139 | masked_sequence_output, txt_lens, num_bbs)
140 | # only compute masked tokens for better efficiency
141 | predicted_obj_label = self.masked_predict_labels(
142 | img_sequence_output, img_masks)
143 |
144 | if compute_loss:
145 | vcr_loss = F.cross_entropy(
146 | rank_scores, targets.squeeze(-1),
147 | reduction='mean')
148 | if self.obj_cls:
149 | obj_cls_loss = F.cross_entropy(
150 | predicted_obj_label, obj_targets.long(),
151 | ignore_index=0, reduction='mean')
152 | else:
153 | obj_cls_loss = torch.tensor([0.], device=vcr_loss.device)
154 | return vcr_loss, obj_cls_loss
155 | else:
156 | rank_scores = rank_scores[:, 1:]
157 | return rank_scores
158 |
159 |
160 | class BertForImageTextPretrainingForVCR(BertForImageTextPretraining):
161 | def init_type_embedding(self):
162 | new_emb = nn.Embedding(4, self.bert.config.hidden_size)
163 | new_emb.apply(self.init_bert_weights)
164 | for i in [0, 1]:
165 | emb = self.bert.embeddings.token_type_embeddings.weight.data[i, :]
166 | new_emb.weight.data[i, :].copy_(emb)
167 | emb = self.bert.embeddings.token_type_embeddings.weight.data[0, :]
168 | new_emb.weight.data[2, :].copy_(emb)
169 | new_emb.weight.data[3, :].copy_(emb)
170 | self.bert.embeddings.token_type_embeddings = new_emb
171 |
172 | def init_word_embedding(self, num_special_tokens):
173 | orig_word_num = self.bert.embeddings.word_embeddings.weight.size(0)
174 | new_emb = nn.Embedding(
175 | orig_word_num + num_special_tokens, self.bert.config.hidden_size)
176 | new_emb.apply(self.init_bert_weights)
177 | emb = self.bert.embeddings.word_embeddings.weight.data
178 | new_emb.weight.data[:orig_word_num, :].copy_(emb)
179 | self.bert.embeddings.word_embeddings = new_emb
180 | self.cls = BertOnlyMLMHead(
181 | self.bert.config, self.bert.embeddings.word_embeddings.weight)
182 |
183 | def forward(self, input_ids, position_ids, txt_type_ids, txt_lens,
184 | img_feat, img_pos_feat, num_bbs,
185 | attention_mask, labels, task, compute_loss=True):
186 | if task == 'mlm':
187 | txt_labels = labels
188 | return self.forward_mlm(input_ids, position_ids, txt_type_ids,
189 | txt_lens,
190 | img_feat, img_pos_feat, num_bbs,
191 | attention_mask, txt_labels, compute_loss)
192 | elif task == 'mrm':
193 | img_mask = labels
194 | return self.forward_mrm(input_ids, position_ids, txt_type_ids,
195 | txt_lens,
196 | img_feat, img_pos_feat, num_bbs,
197 | attention_mask, img_mask, compute_loss)
198 | elif task.startswith('mrc'):
199 | img_mask, mrc_label_target = labels
200 | return self.forward_mrc(input_ids, position_ids, txt_type_ids,
201 | txt_lens,
202 | img_feat, img_pos_feat, num_bbs,
203 | attention_mask, img_mask,
204 | mrc_label_target, task, compute_loss)
205 | else:
206 | raise ValueError('invalid task')
207 |
208 | # MLM
209 | def forward_mlm(self, input_ids, position_ids, txt_type_ids, txt_lens,
210 | img_feat, img_pos_feat, num_bbs,
211 | attention_mask, txt_labels, compute_loss=True):
212 | sequence_output = self.bert(input_ids, position_ids, txt_lens,
213 | img_feat, img_pos_feat, num_bbs,
214 | attention_mask,
215 | output_all_encoded_layers=False,
216 | txt_type_ids=txt_type_ids)
217 | # get only the text part
218 | sequence_output = sequence_output[:, :input_ids.size(1), :]
219 | # only compute masked tokens for better efficiency
220 | prediction_scores = self.masked_compute_scores(
221 | sequence_output, txt_labels != -1)
222 | if self.vocab_pad:
223 | prediction_scores = prediction_scores[:, :-self.vocab_pad]
224 |
225 | if compute_loss:
226 | masked_lm_loss = F.cross_entropy(prediction_scores,
227 | txt_labels[txt_labels != -1],
228 | reduction='none')
229 | return masked_lm_loss
230 | else:
231 | return prediction_scores
232 |
233 | # MRM
234 | def forward_mrm(self, input_ids, position_ids, txt_type_ids, txt_lens,
235 | img_feat, img_pos_feat, num_bbs,
236 | attention_mask, img_masks, compute_loss=True):
237 | img_feat, feat_targets = mask_img_feat(img_feat, img_masks)
238 | sequence_output = self.bert(input_ids, position_ids, txt_lens,
239 | img_feat, img_pos_feat, num_bbs,
240 | attention_mask,
241 | output_all_encoded_layers=False,
242 | txt_type_ids=txt_type_ids)
243 |         # get only the image part
244 | sequence_output = _get_image_hidden(sequence_output, txt_lens, num_bbs)
245 | # only compute masked tokens for better efficiency
246 | prediction_feat = self.masked_compute_feat(
247 | sequence_output, img_masks)
248 |
249 | if compute_loss:
250 | mrm_loss = F.mse_loss(prediction_feat, feat_targets,
251 | reduction='none')
252 | return mrm_loss
253 | else:
254 | return prediction_feat
255 |
256 | # MRC
257 | def forward_mrc(self, input_ids, position_ids, txt_type_ids, txt_lens,
258 | img_feat, img_pos_feat, num_bbs,
259 | attention_mask, img_masks,
260 | label_targets, task, compute_loss=True):
261 | img_feat = mask_img_feat_for_mrc(img_feat, img_masks)
262 | sequence_output = self.bert(input_ids, position_ids, txt_lens,
263 | img_feat, img_pos_feat, num_bbs,
264 | attention_mask,
265 | output_all_encoded_layers=False,
266 | txt_type_ids=txt_type_ids)
267 | # get only the image part
268 | sequence_output = _get_image_hidden(sequence_output, txt_lens, num_bbs)
269 | # only compute masked tokens for better efficiency
270 | prediction_soft_label = self.masked_predict_labels(
271 | sequence_output, img_masks)
272 |
273 | if compute_loss:
274 | if "kl" in task:
275 | prediction_soft_label = F.log_softmax(
276 | prediction_soft_label, dim=-1)
277 | mrc_loss = F.kl_div(
278 | prediction_soft_label, label_targets, reduction='none')
279 | else:
280 | label_targets = torch.max(
281 | label_targets, -1)[1] # argmax
282 | mrc_loss = F.cross_entropy(
283 | prediction_soft_label, label_targets,
284 | ignore_index=0, reduction='none')
285 | return mrc_loss
286 | else:
287 | return prediction_soft_label
288 |
--------------------------------------------------------------------------------
/model/ve.py:
--------------------------------------------------------------------------------
1 | """
2 | UNITER for VE model
3 | """
4 | from .vqa import UniterForVisualQuestionAnswering
5 |
6 |
7 | class UniterForVisualEntailment(UniterForVisualQuestionAnswering):
8 | """ Finetune multi-modal BERT for VE
9 | """
10 | def __init__(self, config, img_dim):
11 | super().__init__(config, img_dim, 3)
12 |
--------------------------------------------------------------------------------
/model/vqa.py:
--------------------------------------------------------------------------------
1 | """
2 | Bert for VQA model
3 | """
4 | from collections import defaultdict
5 |
6 | from torch import nn
7 | from torch.nn import functional as F
8 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
9 |
10 | from .layer import GELU
11 | from .model import UniterPreTrainedModel, UniterModel, VLXLMRPreTrainedModel, VLXLMRModel
12 |
13 |
14 | class VLXLMRForVisualQuestionAnswering(VLXLMRPreTrainedModel):
15 | """ Finetune multi-modal BERT for VQA
16 | """
17 | def __init__(self, config, img_dim, num_answer):
18 | super().__init__(config)
19 | self.roberta = VLXLMRModel(config, img_dim)
20 | self.vqa_output = nn.Sequential(
21 | nn.Linear(config.hidden_size, config.hidden_size*2),
22 | GELU(),
23 | LayerNorm(config.hidden_size*2, eps=config.layer_norm_eps),
24 | nn.Linear(config.hidden_size*2, num_answer)
25 | )
26 | self.apply(self.init_weights)
27 |
28 | def forward(self, batch, compute_loss=True):
29 | batch = defaultdict(lambda: None, batch)
30 | input_ids = batch['input_ids']
31 | #position_ids = batch['position_ids']
32 |         position_ids = None
33 | img_feat = batch['img_feat']
34 | img_pos_feat = batch['img_pos_feat']
35 | attn_masks = batch['attn_masks']
36 | gather_index = batch['gather_index']
37 | sequence_output = self.roberta(input_ids, position_ids,
38 | img_feat, img_pos_feat,
39 | attn_masks, gather_index,
40 | output_all_encoded_layers=False)
41 | pooled_output = self.roberta.pooler(sequence_output)
42 | answer_scores = self.vqa_output(pooled_output)
43 |
44 | if compute_loss:
45 | targets = batch['targets']
46 | vqa_loss = F.binary_cross_entropy_with_logits(
47 | answer_scores, targets, reduction='none')
48 | return vqa_loss
49 | else:
50 | return answer_scores
51 |
52 | class UniterForVisualQuestionAnswering(UniterPreTrainedModel):
53 | """ Finetune multi-modal BERT for VQA
54 | """
55 | def __init__(self, config, img_dim, num_answer):
56 | super().__init__(config)
57 | self.bert = UniterModel(config, img_dim)
58 | self.vqa_output = nn.Sequential(
59 | nn.Linear(config.hidden_size, config.hidden_size*2),
60 | GELU(),
61 | LayerNorm(config.hidden_size*2, eps=1e-12),
62 | nn.Linear(config.hidden_size*2, num_answer)
63 | )
64 | self.apply(self.init_weights)
65 |
66 | def forward(self, batch, compute_loss=True):
67 | batch = defaultdict(lambda: None, batch)
68 | input_ids = batch['input_ids']
69 | position_ids = batch['position_ids']
70 | img_feat = batch['img_feat']
71 | img_pos_feat = batch['img_pos_feat']
72 | attn_masks = batch['attn_masks']
73 | gather_index = batch['gather_index']
74 | sequence_output = self.bert(input_ids, position_ids,
75 | img_feat, img_pos_feat,
76 | attn_masks, gather_index,
77 | output_all_encoded_layers=False)
78 | pooled_output = self.bert.pooler(sequence_output)
79 | answer_scores = self.vqa_output(pooled_output)
80 |
81 | if compute_loss:
82 | targets = batch['targets']
83 | vqa_loss = F.binary_cross_entropy_with_logits(
84 | answer_scores, targets, reduction='none')
85 | return vqa_loss
86 | else:
87 | return answer_scores
88 |
--------------------------------------------------------------------------------
/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .sched import noam_schedule, warmup_linear, vqa_schedule, get_lr_sched, get_xlmr_lr_sched
2 | from .adamw import AdamW
3 |
--------------------------------------------------------------------------------
/optim/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/optim/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/optim/__pycache__/adamw.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/optim/__pycache__/adamw.cpython-36.pyc
--------------------------------------------------------------------------------
/optim/__pycache__/misc.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/optim/__pycache__/misc.cpython-36.pyc
--------------------------------------------------------------------------------
/optim/__pycache__/sched.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/optim/__pycache__/sched.cpython-36.pyc
--------------------------------------------------------------------------------
/optim/adamw.py:
--------------------------------------------------------------------------------
1 | """
2 | AdamW optimizer (weight decay fix)
3 | copied from huggingface
4 | """
5 | import math
6 |
7 | import torch
8 | from torch.optim import Optimizer
9 |
10 |
11 | class AdamW(Optimizer):
12 | """ Implements Adam algorithm with weight decay fix.
13 | Parameters:
14 | lr (float): learning rate. Default 1e-3.
15 |         betas (tuple of 2 floats): Adam's beta parameters (b1, b2).
16 | Default: (0.9, 0.999)
17 |         eps (float): Adam's epsilon. Default: 1e-6
18 | weight_decay (float): Weight decay. Default: 0.0
19 | correct_bias (bool): can be set to False to avoid correcting bias
20 | in Adam (e.g. like in Bert TF repository). Default True.
21 | """
22 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
23 | weight_decay=0.0, correct_bias=True):
24 | if lr < 0.0:
25 | raise ValueError(
26 | "Invalid learning rate: {} - should be >= 0.0".format(lr))
27 | if not 0.0 <= betas[0] < 1.0:
28 | raise ValueError("Invalid beta parameter: {} - "
29 |                              "should be in [0.0, 1.0)".format(betas[0]))
30 | if not 0.0 <= betas[1] < 1.0:
31 | raise ValueError("Invalid beta parameter: {} - "
32 |                              "should be in [0.0, 1.0)".format(betas[1]))
33 | if not 0.0 <= eps:
34 | raise ValueError("Invalid epsilon value: {} - "
35 | "should be >= 0.0".format(eps))
36 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
37 | correct_bias=correct_bias)
38 | super(AdamW, self).__init__(params, defaults)
39 |
40 | def step(self, closure=None):
41 | """Performs a single optimization step.
42 | Arguments:
43 | closure (callable, optional): A closure that reevaluates the model
44 | and returns the loss.
45 | """
46 | loss = None
47 | if closure is not None:
48 | loss = closure()
49 |
50 | for group in self.param_groups:
51 | for p in group['params']:
52 | if p.grad is None:
53 | continue
54 | grad = p.grad.data
55 | if grad.is_sparse:
56 | raise RuntimeError(
57 | 'Adam does not support sparse '
58 | 'gradients, please consider SparseAdam instead')
59 |
60 | state = self.state[p]
61 |
62 | # State initialization
63 | if len(state) == 0:
64 | state['step'] = 0
65 | # Exponential moving average of gradient values
66 | state['exp_avg'] = torch.zeros_like(p.data)
67 | # Exponential moving average of squared gradient values
68 | state['exp_avg_sq'] = torch.zeros_like(p.data)
69 |
70 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
71 | beta1, beta2 = group['betas']
72 |
73 | state['step'] += 1
74 |
75 | # Decay the first and second moment running average coefficient
76 | # In-place operations to update the averages at the same time
77 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
78 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
79 | denom = exp_avg_sq.sqrt().add_(group['eps'])
80 |
81 | step_size = group['lr']
82 | if group['correct_bias']: # No bias correction for Bert
83 | bias_correction1 = 1.0 - beta1 ** state['step']
84 | bias_correction2 = 1.0 - beta2 ** state['step']
85 | step_size = (step_size * math.sqrt(bias_correction2)
86 | / bias_correction1)
87 |
88 | p.data.addcdiv_(-step_size, exp_avg, denom)
89 |
90 | # Just adding the square of the weights to the loss function is
91 | # *not* the correct way of using L2 regularization/weight decay
92 | # with Adam, since that will interact with the m and v
93 | # parameters in strange ways.
94 | #
95 | # Instead we want to decay the weights in a manner that doesn't
96 | # interact with the m/v parameters. This is equivalent to
97 | # adding the square of the weights to the loss with plain
98 | # (non-momentum) SGD.
99 | # Add weight decay at the end (fixed version)
100 | if group['weight_decay'] > 0.0:
101 | p.data.add_(-group['lr'] * group['weight_decay'], p.data)
102 |
103 | return loss
104 |
--------------------------------------------------------------------------------
/optim/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Misc lr helper
3 | """
4 | from torch.optim import Adam, Adamax
5 |
6 | from .adamw import AdamW
7 |
8 |
9 | def build_optimizer(model, opts):
10 | param_optimizer = list(model.named_parameters())
11 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
12 | optimizer_grouped_parameters = [
13 | {'params': [p for n, p in param_optimizer
14 | if not any(nd in n for nd in no_decay)],
15 | 'weight_decay': opts.weight_decay},
16 | {'params': [p for n, p in param_optimizer
17 | if any(nd in n for nd in no_decay)],
18 | 'weight_decay': 0.0}
19 | ]
20 |
21 | # currently Adam only
22 | if opts.optim == 'adam':
23 | OptimCls = Adam
24 | elif opts.optim == 'adamax':
25 | OptimCls = Adamax
26 | elif opts.optim == 'adamw':
27 | OptimCls = AdamW
28 | else:
29 | raise ValueError('invalid optimizer')
30 | optimizer = OptimCls(optimizer_grouped_parameters,
31 | lr=opts.learning_rate, betas=opts.betas)
32 | return optimizer
33 |
34 | def xlmr_pretrained_encoder_layer(n, load_layer):
35 | """
36 |     Check whether parameter n is initialized from the pretrained XLM-R weights
37 | """
38 | assert isinstance(load_layer, int)
39 | #print(n)
40 | if 'roberta.encoder' in n:
41 | if int(n.split('.')[3]) <= load_layer:
42 | return True
43 | elif 'roberta.embeddings' in n:
44 | return True
45 |
46 | return False
47 |
48 | def build_xlmr_optimizer(model, opts):
49 | param_optimizer = list(model.named_parameters())
50 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
51 |     # Divide the parameters into four groups; a relatively small lr (xlmr_lr) is applied to the parts loaded from pretrained XLM-R
52 | #assert opts.load_embedding_only or isinstance(opts.load_layer, int)
53 | if opts.load_layer:
54 | assert isinstance(opts.load_layer,int) and opts.load_layer > 0
55 | optimizer_grouped_parameters = [
56 | {'params': [p for n, p in param_optimizer
57 | if not any(nd in n for nd in no_decay) and xlmr_pretrained_encoder_layer(n, opts.load_layer)],
58 | 'weight_decay': opts.weight_decay, 'lr': opts.xlmr_lr},
59 | {'params': [p for n, p in param_optimizer
60 | if any(nd in n for nd in no_decay) and xlmr_pretrained_encoder_layer(n, opts.load_layer)],
61 | 'weight_decay': 0.0, 'lr': opts.xlmr_lr},
62 | {'params': [p for n, p in param_optimizer
63 | if not any(nd in n for nd in no_decay) and not xlmr_pretrained_encoder_layer(n, opts.load_layer)],
64 | 'weight_decay': opts.weight_decay, 'lr': opts.learning_rate},
65 | {'params': [p for n, p in param_optimizer
66 | if any(nd in n for nd in no_decay) and not xlmr_pretrained_encoder_layer(n, opts.load_layer)],
67 | 'weight_decay': 0.0, 'lr': opts.learning_rate},
68 | ]
69 | else:
70 | #general case
71 | optimizer_grouped_parameters = [
72 | {'params': [p for n, p in param_optimizer
73 | if not any(nd in n for nd in no_decay) and 'roberta.embeddings' in n],
74 | 'weight_decay': opts.weight_decay, 'lr': opts.xlmr_lr},
75 | {'params': [p for n, p in param_optimizer
76 | if any(nd in n for nd in no_decay) and 'roberta.embeddings' in n],
77 | 'weight_decay': 0.0, 'lr': opts.xlmr_lr},
78 | {'params': [p for n, p in param_optimizer
79 | if not any(nd in n for nd in no_decay) and 'roberta.embeddings' not in n],
80 | 'weight_decay': opts.weight_decay, 'lr': opts.learning_rate},
81 | {'params': [p for n, p in param_optimizer
82 | if any(nd in n for nd in no_decay) and 'roberta.embeddings' not in n],
83 | 'weight_decay': 0.0, 'lr': opts.learning_rate},
84 | ]
85 |
86 |
87 | #print([n for n, p in param_optimizer if not any(nd in n for nd in no_decay) and 'roberta.embeddings' in n])
88 | # currently Adam only
89 | if opts.optim == 'adam':
90 | OptimCls = Adam
91 | elif opts.optim == 'adamax':
92 | OptimCls = Adamax
93 | elif opts.optim == 'adamw':
94 | OptimCls = AdamW
95 | else:
96 | raise ValueError('invalid optimizer')
97 |
98 | optimizer = OptimCls(optimizer_grouped_parameters,
99 | lr=opts.learning_rate, betas=opts.betas)
100 | return optimizer
101 |
102 |
103 |
--------------------------------------------------------------------------------
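
A small sketch of how build_optimizer is typically driven; the opts namespace here is hypothetical and only carries the attributes the function reads (weight_decay, optim, learning_rate, betas), with a stand-in model:

from types import SimpleNamespace
from torch import nn
from optim.misc import build_optimizer   # assumes the repo root is on PYTHONPATH

model = nn.Linear(10, 2)                 # stand-in for the real UNITER/VL-XLMR model
opts = SimpleNamespace(weight_decay=0.01, optim='adamw',
                       learning_rate=5e-5, betas=(0.9, 0.98))
optimizer = build_optimizer(model, opts)

# the bias parameter lands in the no-weight-decay group
for group in optimizer.param_groups:
    print(len(group['params']), group['weight_decay'])
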
/optim/sched.py:
--------------------------------------------------------------------------------
1 | """
2 | optimizer learning rate scheduling helpers
3 | """
4 | from math import ceil
5 |
6 |
7 | def noam_schedule(step, warmup_step=4000):
8 | if step <= warmup_step:
9 | return step / warmup_step
10 | return (warmup_step ** 0.5) * (step ** -0.5)
11 |
12 |
13 | def warmup_linear(step, warmup_step, tot_step):
14 | if step < warmup_step:
15 | return step / warmup_step
16 | return max(0, (tot_step-step)/(tot_step-warmup_step))
17 |
18 |
19 | def vqa_schedule(step, warmup_interval, decay_interval,
20 | decay_start, decay_rate):
21 | """ VQA schedule from MCAN """
22 | if step < warmup_interval:
23 | return 1/4
24 | elif step < 2 * warmup_interval:
25 | return 2/4
26 | elif step < 3 * warmup_interval:
27 | return 3/4
28 | elif step >= decay_start:
29 | num_decay = ceil((step - decay_start) / decay_interval)
30 | return decay_rate ** num_decay
31 | else:
32 | return 1
33 |
34 |
35 | def get_lr_sched(global_step, opts):
36 | # learning rate scheduling
37 | if opts.decay == 'linear':
38 | lr_this_step = opts.learning_rate * warmup_linear(
39 | global_step, opts.warmup_steps, opts.num_train_steps)
40 | elif opts.decay == 'invsqrt':
41 | lr_this_step = opts.learning_rate * noam_schedule(
42 | global_step, opts.warmup_steps)
43 | elif opts.decay == 'constant':
44 | lr_this_step = opts.learning_rate
45 | elif opts.decay == 'vqa':
46 | lr_this_step = opts.learning_rate * vqa_schedule(
47 | global_step, opts.warm_int, opts.decay_int,
48 | opts.decay_st, opts.decay_rate)
49 | if lr_this_step <= 0:
50 |         # safeguard against possible miscalculation of train steps
51 | lr_this_step = 1e-8
52 | return lr_this_step
53 |
54 | def get_xlmr_lr_sched(global_step, opts):
55 | if opts.decay == 'linear':
56 | lr_this_step = opts.xlmr_lr * warmup_linear(
57 | global_step, opts.warmup_steps, opts.num_train_steps)
58 | elif opts.decay == 'invsqrt':
59 | lr_this_step = opts.xlmr_lr * noam_schedule(
60 | global_step, opts.warmup_steps)
61 | elif opts.decay == 'constant':
62 | lr_this_step = opts.xlmr_lr
63 | elif opts.decay == 'vqa':
64 | lr_this_step = opts.xlmr_lr * vqa_schedule(
65 | global_step, opts.warm_int, opts.decay_int,
66 | opts.decay_st, opts.decay_rate)
67 | if lr_this_step <= 0:
68 |         # safeguard against possible miscalculation of train steps
69 | lr_this_step = 1e-8
70 | return lr_this_step
71 |
--------------------------------------------------------------------------------
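
For reference, a sketch of the linear warmup/decay curve produced by get_lr_sched; opts is again a hypothetical namespace holding only the fields the scheduler reads:

from types import SimpleNamespace
from optim.sched import get_lr_sched

opts = SimpleNamespace(decay='linear', learning_rate=5e-5,
                       warmup_steps=100, num_train_steps=1000)

for step in [0, 50, 100, 500, 1000]:
    print(step, get_lr_sched(step, opts))
# ramps from the 1e-8 floor at step 0 up to 5e-5 at step 100,
# then decays linearly back to the 1e-8 floor at step 1000
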
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/const.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/const.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/distributed.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/distributed.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/misc.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/misc.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/save.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/save.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/visual_entailment.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/visual_entailment.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/vqa.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmykevin/UC2/50ef0acbac48f073483908d28f2e94a81dbd3642/utils/__pycache__/vqa.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/const.py:
--------------------------------------------------------------------------------
1 | """ constants """
2 | IMG_DIM = 2048
3 | IMG_LABEL_DIM = 1601
4 | BUCKET_SIZE = 8192
5 |
--------------------------------------------------------------------------------
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | """
2 | distributed API using Horovod
3 | """
4 | import math
5 | import pickle
6 |
7 | import torch
8 | from horovod import torch as hvd
9 |
10 | import msgpack
11 | import msgpack_numpy
12 | msgpack_numpy.patch()
13 |
14 |
15 | def all_reduce_and_rescale_tensors(tensors, rescale_denom):
16 | """All-reduce and rescale tensors at once (as a flattened tensor)
17 |
18 | Args:
19 | tensors: list of Tensors to all-reduce
20 | rescale_denom: denominator for rescaling summed Tensors
21 | """
22 |     # total number of elements; allocate one flat buffer to hold all tensors
23 | sz = sum(t.numel() for t in tensors)
24 | buffer_t = tensors[0].new(sz).zero_()
25 |
26 | # copy tensors into buffer_t
27 | offset = 0
28 | for t in tensors:
29 | numel = t.numel()
30 | buffer_t[offset:offset+numel].copy_(t.view(-1))
31 | offset += numel
32 |
33 | # all-reduce and rescale
34 | hvd.allreduce_(buffer_t[:offset])
35 | buffer_t.div_(rescale_denom)
36 |
37 | # copy all-reduced buffer back into tensors
38 | offset = 0
39 | for t in tensors:
40 | numel = t.numel()
41 | t.view(-1).copy_(buffer_t[offset:offset+numel])
42 | offset += numel
43 |
44 |
45 | def all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom,
46 | buffer_size=10485760):
47 | """All-reduce and rescale tensors in chunks of the specified size.
48 |
49 | Args:
50 | tensors: list of Tensors to all-reduce
51 | rescale_denom: denominator for rescaling summed Tensors
52 | buffer_size: all-reduce chunk size in bytes
53 | """
54 | # buffer size in bytes, determine equiv. # of elements based on data type
55 | buffer_t = tensors[0].new(
56 | math.ceil(buffer_size / tensors[0].element_size())).zero_()
57 | buffer = []
58 |
59 | def all_reduce_buffer():
60 | # copy tensors into buffer_t
61 | offset = 0
62 | for t in buffer:
63 | numel = t.numel()
64 | buffer_t[offset:offset+numel].copy_(t.view(-1))
65 | offset += numel
66 |
67 | # all-reduce and rescale
68 | hvd.allreduce_(buffer_t[:offset])
69 | buffer_t.div_(rescale_denom)
70 |
71 | # copy all-reduced buffer back into tensors
72 | offset = 0
73 | for t in buffer:
74 | numel = t.numel()
75 | t.view(-1).copy_(buffer_t[offset:offset+numel])
76 | offset += numel
77 |
78 | filled = 0
79 | for t in tensors:
80 | sz = t.numel() * t.element_size()
81 | if sz > buffer_size:
82 | # tensor is bigger than buffer, all-reduce and rescale directly
83 | hvd.allreduce_(t)
84 | t.div_(rescale_denom)
85 | elif filled + sz > buffer_size:
86 | # buffer is full, all-reduce and replace buffer with grad
87 | all_reduce_buffer()
88 | buffer = [t]
89 | filled = sz
90 | else:
91 | # add tensor to buffer
92 | buffer.append(t)
93 | filled += sz
94 |
95 | if len(buffer) > 0:
96 | all_reduce_buffer()
97 |
98 |
99 | def broadcast_tensors(tensors, root_rank, buffer_size=10485760):
100 | """broadcast tensors in chunks of the specified size.
101 |
102 | Args:
103 | tensors: list of Tensors to broadcast
104 |         root_rank: rank to broadcast from
105 | buffer_size: broadcast chunk size in bytes
106 | """
107 | # buffer size in bytes, determine equiv. # of elements based on data type
108 | buffer_t = tensors[0].new(
109 | math.ceil(buffer_size / tensors[0].element_size())).zero_()
110 | buffer = []
111 |
112 | def broadcast_buffer():
113 | # copy tensors into buffer_t
114 | offset = 0
115 | for t in buffer:
116 | numel = t.numel()
117 | buffer_t[offset:offset+numel].copy_(t.view(-1))
118 | offset += numel
119 |
120 | # broadcast
121 | hvd.broadcast_(buffer_t[:offset], root_rank)
122 |
123 |         # copy broadcast buffer back into tensors
124 | offset = 0
125 | for t in buffer:
126 | numel = t.numel()
127 | t.view(-1).copy_(buffer_t[offset:offset+numel])
128 | offset += numel
129 |
130 | filled = 0
131 | for t in tensors:
132 | sz = t.numel() * t.element_size()
133 | if sz > buffer_size:
134 | # tensor is bigger than buffer, broadcast directly
135 | hvd.broadcast_(t, root_rank)
136 | elif filled + sz > buffer_size:
137 | # buffer is full, broadcast and replace buffer with tensor
138 | broadcast_buffer()
139 | buffer = [t]
140 | filled = sz
141 | else:
142 | # add tensor to buffer
143 | buffer.append(t)
144 | filled += sz
145 |
146 | if len(buffer) > 0:
147 | broadcast_buffer()
148 |
149 |
150 | def _encode(enc, max_size, buffer_=None):
151 | enc_size = len(enc)
152 | enc_byte = max(math.floor(math.log(max_size, 256)+1), 1)
153 | if buffer_ is None or len(buffer_) < enc_size + enc_byte:
154 | buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte)
155 | remainder = enc_size
156 | for i in range(enc_byte):
157 | base = 256 ** (enc_byte-i-1)
158 | buffer_[i] = remainder // base
159 | remainder %= base
160 | buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc))
161 | return buffer_, enc_byte
162 |
163 |
164 | def _decode(buffer_, enc_byte):
165 | size = sum(256 ** (enc_byte-i-1) * buffer_[i].item()
166 | for i in range(enc_byte))
167 | bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist())
168 | shift = size + enc_byte
169 | return bytes_list, shift
170 |
171 |
172 | _BUFFER_SIZE = 4096
173 |
174 |
175 | def all_gather_list(data):
176 | """Gathers arbitrary data from all nodes into a list."""
177 | if not hasattr(all_gather_list, '_buffer'):
178 | # keeps small buffer to avoid re-allocate every call
179 | all_gather_list._buffer = torch.cuda.ByteTensor(_BUFFER_SIZE)
180 | try:
181 | enc = msgpack.dumps(data, use_bin_type=True)
182 | msgpack_success = True
183 | except TypeError:
184 | enc = pickle.dumps(data)
185 | msgpack_success = False
186 |
187 | enc_size = len(enc)
188 | max_size = hvd.allgather(torch.tensor([enc_size]).cuda()).max().item()
189 | buffer_ = all_gather_list._buffer
190 | in_buffer, enc_byte = _encode(enc, max_size, buffer_)
191 |
192 | out_buffer = hvd.allgather(in_buffer[:enc_byte+enc_size])
193 |
194 | results = []
195 | for _ in range(hvd.size()):
196 | bytes_list, shift = _decode(out_buffer, enc_byte)
197 | out_buffer = out_buffer[shift:]
198 |
199 | if msgpack_success:
200 | result = msgpack.loads(bytes_list, raw=False)
201 | else:
202 | result = pickle.loads(bytes_list)
203 | results.append(result)
204 | return results
205 |
206 |
207 | def any_broadcast(data, root_rank):
208 | """broadcast arbitrary data from root_rank to all nodes."""
209 | if not hasattr(any_broadcast, '_buffer'):
210 | # keeps small buffer to avoid re-allocate every call
211 | any_broadcast._buffer = torch.cuda.ByteTensor(_BUFFER_SIZE)
212 | try:
213 | enc = msgpack.dumps(data, use_bin_type=True)
214 | msgpack_success = True
215 | except TypeError:
216 | enc = pickle.dumps(data)
217 | msgpack_success = False
218 |
219 | max_size = hvd.allgather(torch.tensor([len(enc)]).cuda()).max().item()
220 | buffer_ = any_broadcast._buffer
221 | buffer_, enc_byte = _encode(enc, max_size, buffer_)
222 |
223 | hvd.broadcast_(buffer_, root_rank)
224 |
225 | bytes_list, _ = _decode(buffer_, enc_byte)
226 | if msgpack_success:
227 | result = msgpack.loads(bytes_list, raw=False)
228 | else:
229 | result = pickle.loads(bytes_list)
230 | return result
231 |
--------------------------------------------------------------------------------
/utils/itm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def i2t(sims, npts=None, return_ranks=False):
5 | """
6 | Images->Text (Image Annotation)
7 | sims: (N, 5N) matrix of similarity im-cap
8 | """
9 | npts = sims.shape[0]
10 | ranks = np.zeros(npts)
11 | top1 = np.zeros(npts)
12 | for index in range(npts):
13 | inds = np.argsort(sims[index])[::-1]
14 | # Score
15 | rank = 1e20
16 | for i in range(5 * index, 5 * index + 5, 1):
17 | tmp = np.where(inds == i)[0][0]
18 | if tmp < rank:
19 | rank = tmp
20 | ranks[index] = rank
21 | top1[index] = inds[0]
22 |
23 | # Compute metrics
24 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
25 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
26 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
27 | medr = np.floor(np.median(ranks)) + 1
28 | meanr = ranks.mean() + 1
29 | if return_ranks:
30 | return (r1, r5, r10, medr, meanr), (ranks, top1)
31 | else:
32 | return (r1, r5, r10, medr, meanr)
33 |
34 |
35 | def t2i(sims, npts=None, return_ranks=False):
36 | """
37 | Text->Images (Image Search)
38 | sims: (N, 5N) matrix of similarity im-cap
39 | """
40 | npts = sims.shape[0]
41 | ranks = np.zeros(5 * npts)
42 | top1 = np.zeros(5 * npts)
43 |
44 | # --> (5N(caption), N(image))
45 | sims = sims.T
46 |
47 | for index in range(npts):
48 | for i in range(5):
49 | inds = np.argsort(sims[5 * index + i])[::-1]
50 | ranks[5 * index + i] = np.where(inds == index)[0][0]
51 | top1[5 * index + i] = inds[0]
52 |
53 | # Compute metrics
54 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
55 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
56 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
57 | medr = np.floor(np.median(ranks)) + 1
58 | meanr = ranks.mean() + 1
59 | if return_ranks:
60 | return (r1, r5, r10, medr, meanr), (ranks, top1)
61 | else:
62 | return (r1, r5, r10, medr, meanr)
63 |
--------------------------------------------------------------------------------
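
As a quick illustration of the expected input shape for i2t/t2i (one row per image, five candidate captions per image), a toy run could look like this; the random similarity matrix is purely illustrative:

    # Hedged sketch: recall@K over a random (N, 5N) image-caption similarity matrix.
    import numpy as np
    from utils.itm import i2t, t2i

    N = 8
    sims = np.random.rand(N, 5 * N)       # sims[i, 5*i:5*i+5] score the gold captions of image i

    r1, r5, r10, medr, meanr = i2t(sims)  # image -> text retrieval
    (tr1, tr5, tr10, tmedr, tmeanr), (ranks, top1) = t2i(sims, return_ranks=True)
    print(f'i2t R@1/5/10: {r1:.1f}/{r5:.1f}/{r10:.1f}   t2i R@1: {tr1:.1f}')
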
/utils/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | helper for logging
3 | NOTE: loggers are global objects use with caution
4 | """
5 | import logging
6 | import math
7 |
8 | import tensorboardX
9 |
10 |
11 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
12 | _DATE_FMT = '%m/%d/%Y %H:%M:%S'
13 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO)
14 | LOGGER = logging.getLogger('__main__') # this is the global logger
15 |
16 |
17 | def add_log_to_file(log_path):
18 | fh = logging.FileHandler(log_path)
19 | formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT)
20 | fh.setFormatter(formatter)
21 | LOGGER.addHandler(fh)
22 |
23 |
24 | class TensorboardLogger(object):
25 | def __init__(self):
26 | self._logger = None
27 | self._global_step = 0
28 |
29 | def create(self, path):
30 | self._logger = tensorboardX.SummaryWriter(path)
31 |
32 | def noop(self, *args, **kwargs):
33 | return
34 |
35 | def step(self):
36 | self._global_step += 1
37 |
38 | @property
39 | def global_step(self):
40 | return self._global_step
41 |
42 | def log_scaler_dict(self, log_dict, prefix=''):
43 | """ log a dictionary of scalar values"""
44 | if self._logger is None:
45 | return
46 | if prefix:
47 | prefix = f'{prefix}_'
48 | for name, value in log_dict.items():
49 | if isinstance(value, dict):
50 |                 self.log_scaler_dict(
51 |                     value, prefix=f'{prefix}{name}')
52 | else:
53 | self._logger.add_scalar(f'{prefix}{name}', value,
54 | self._global_step)
55 |
56 | def __getattr__(self, name):
57 | if self._logger is None:
58 | return self.noop
59 | return self._logger.__getattribute__(name)
60 |
61 |
62 | TB_LOGGER = TensorboardLogger()
63 |
64 |
65 | class RunningMeter(object):
66 |     """ running meter of a scalar value
67 | (useful for monitoring training loss)
68 | """
69 | def __init__(self, name, val=None, smooth=0.99):
70 | self._name = name
71 | self._sm = smooth
72 | self._val = val
73 |
74 | def __call__(self, value):
75 | val = (value if self._val is None
76 | else value*(1-self._sm) + self._val*self._sm)
77 | if not math.isnan(val) and not math.isinf(val):
78 | self._val = val
79 | else:
80 | print(f'Inf/Nan in {self._name}')
81 |
82 | def __str__(self):
83 | return f'{self._name}: {self._val:.4f}'
84 |
85 | @property
86 | def val(self):
87 | return self._val
88 |
89 | @property
90 | def name(self):
91 | return self._name
92 |
--------------------------------------------------------------------------------
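
A minimal sketch of how these globals are typically wired into a training loop; the log paths and loss values below are placeholders:

    # Hedged sketch: file + tensorboard logging with a smoothed loss meter.
    from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file

    add_log_to_file('/tmp/uc2_train.log')   # hypothetical log file
    TB_LOGGER.create('/tmp/uc2_tb')         # TB_LOGGER is a no-op until create() is called
    loss_meter = RunningMeter('loss')

    for loss in [0.90, 0.72, 0.65]:         # stand-in for real training losses
        loss_meter(loss)
        TB_LOGGER.log_scaler_dict({'loss': loss}, prefix='train')
        TB_LOGGER.step()
    LOGGER.info(str(loss_meter))
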
/utils/m3p_tokenizer.py:
--------------------------------------------------------------------------------
1 | import sentencepiece as spm
2 | import collections
3 |
4 | SPIECE_UNDERLINE = "▁"
5 |
6 | def load_vocab(vocab_file):
7 | """Loads a vocabulary file into a dictionary."""
8 | vocab = collections.OrderedDict()
9 | with open(vocab_file, "r", encoding="utf-8") as reader:
10 | tokens = reader.readlines()
11 | for index, token in enumerate(tokens):
12 | token = token.rstrip("\n")
13 | vocab[token] = index
14 | return vocab
15 |
16 | class XLMRTokenizer(object):
17 | def __init__(self, vocab_file,special_token=''):
18 | sp_model = spm.SentencePieceProcessor()
19 | sp_model.Load(str(vocab_file))
20 | self.sp_model = sp_model
21 |         self.bos_token = "<s>"
22 |         self.eos_token = "</s>"
23 |         self.sep_token = "</s>"
24 |         self.cls_token = "<s>"
25 |         self.unk_token = "<unk>"
26 |         self.pad_token = "<pad>"
27 |         self.mask_token = "<mask>"
28 |
29 |         self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
30 |
31 | # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
32 | self.fairseq_offset = 1
33 |
34 |         self.fairseq_tokens_to_ids["<mask>"] = len(sp_model) + self.fairseq_offset
35 | self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
36 |
37 | self.cls_token_id = self._cls_token_id()
38 | self.sep_token_id = self._sep_token_id()
39 | self.pad_token_id = self._pad_token_id()
40 | self.mask_token_id = self._mask_token_id()
41 | self.eos_token_id = self._eos_token_id()
42 |
43 |
44 | def _cls_token_id(self):
45 |         """ Id of the classification (CLS) token in the vocabulary. """
46 |         return self._convert_token_to_id(self.cls_token)
47 |     def _sep_token_id(self):
48 |         """ Id of the separator (SEP) token in the vocabulary. """
49 |         return self._convert_token_to_id(self.sep_token)
50 | def _pad_token_id(self):
51 | return self._convert_token_to_id(self.pad_token)
52 | def _eos_token_id(self):
53 | return self._convert_token_to_id(self.eos_token)
54 | def _mask_token_id(self):
55 | return self._convert_token_to_id(self.mask_token)
56 |
57 | def build_inputs_with_special_tokens(
58 | self,token_ids_0, token_ids_1):
59 | """
60 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks
61 | by concatenating and adding special tokens.
62 |         A sequence built by this tokenizer has the following format:
63 |         - single sequence: ``<s> X </s>``
64 |         - pair of sequences: ``<s> A </s> B </s>``
65 | Args:
66 | token_ids_0 (:obj:`List[int]`):
67 | List of IDs to which the special tokens will be added
68 | token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
69 | Optional second list of IDs for sequence pairs.
70 | Returns:
71 | :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
72 | """
73 | if token_ids_1 is None:
74 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
75 | cls = [self.cls_token_id]
76 | sep = [self.sep_token_id]
77 | return cls + token_ids_0 + sep + token_ids_1 + sep
78 |
79 | @property
80 | def vocab_size(self):
81 |         return len(self.sp_model) + self.fairseq_offset + 1 # Add the <mask> token
82 |
83 | def get_vocab(self):
84 | vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
85 | return vocab
86 |
87 | def _tokenize(self,text):
88 | return self.sp_model.EncodeAsPieces(text)
89 |
90 | def _convert_token_to_id(self,token):
91 | """ Converts a token (str) in an id using the vocab. """
92 | if token in self.fairseq_tokens_to_ids:
93 | return self.fairseq_tokens_to_ids[token]
94 | spm_id = self.sp_model.PieceToId(token)
95 |
96 | # Need to return unknown token if the SP model returned 0
97 | return spm_id + self.fairseq_offset if spm_id else self.fairseq_tokens_to_ids[self.unk_token]
98 |
99 | def encode(self,text,text_b=None):
100 | tokens = self._tokenize(text)
101 | input_ids = []
102 | for token in tokens:
103 | input_ids.append(self._convert_token_to_id(token))
104 |
105 | input_ids_b = None
106 | if text_b is not None:
107 | input_ids_b = []
108 | tokens_b = self._tokenize(text_b)
109 | for token in tokens_b:
110 | input_ids_b.append(self._convert_token_to_id(token))
111 |
112 | #input_ids = self.build_inputs_with_special_tokens(input_ids,input_ids_b)
113 |
114 | return input_ids
115 |
116 | def decode(self,token_ids):
117 | _out = []
118 | for token_id in token_ids:
119 | _out.append(self._convert_id_to_token(token_id))
120 | return self.convert_tokens_to_string(_out)
121 |
122 | def _convert_id_to_token(self,index):
123 | """Converts an index (integer) in a token (str) using the vocab."""
124 | if index in self.fairseq_ids_to_tokens:
125 | return self.fairseq_ids_to_tokens[index]
126 | return self.sp_model.IdToPiece(index - self.fairseq_offset)
127 | def convert_tokens_to_string(self, tokens):
128 | """Converts a sequence of tokens (strings for sub-words) in a single string."""
129 | out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
130 | return out_string
--------------------------------------------------------------------------------
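
Assuming an XLM-R style SentencePiece model file is available (the path below is a placeholder), the tokenizer above is used roughly as follows:

    # Hedged sketch: round-trip a caption through the SentencePiece tokenizer.
    from utils.m3p_tokenizer import XLMRTokenizer

    tok = XLMRTokenizer('sentencepiece.bpe.model')   # hypothetical local model path
    ids = tok.encode('a dog chasing a ball')         # raw subword ids, no special tokens added
    full = tok.build_inputs_with_special_tokens(ids, None)   # <s> ... </s>
    print(tok.decode(ids))                           # back to detokenized text
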
/utils/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Misc utilities
3 | """
4 | import json
5 | import random
6 | import sys
7 |
8 | import torch
9 | import numpy as np
10 |
11 | from utils.logger import LOGGER
12 |
13 |
14 | class NoOp(object):
15 | """ useful for distributed training No-Ops """
16 | def __getattr__(self, name):
17 | return self.noop
18 |
19 | def noop(self, *args, **kwargs):
20 | return
21 |
22 |
23 | def parse_with_config(parser):
24 | args = parser.parse_args()
25 | if args.config is not None:
26 | config_args = json.load(open(args.config))
27 | override_keys = {arg[2:].split('=')[0] for arg in sys.argv[1:]
28 | if arg.startswith('--')}
29 | for k, v in config_args.items():
30 | if k not in override_keys:
31 | setattr(args, k, v)
32 | del args.config
33 | return args
34 |
35 |
36 | VE_ENT2IDX = {
37 | 'contradiction': 0,
38 | 'entailment': 1,
39 | 'neutral': 2
40 | }
41 |
42 | VE_IDX2ENT = {
43 | 0: 'contradiction',
44 | 1: 'entailment',
45 | 2: 'neutral'
46 | }
47 |
48 |
49 | class Struct(object):
50 | def __init__(self, dict_):
51 | self.__dict__.update(dict_)
52 |
53 |
54 | def set_dropout(model, drop_p):
55 | for name, module in model.named_modules():
56 | # we might want to tune dropout for smaller dataset
57 | if isinstance(module, torch.nn.Dropout):
58 | if module.p != drop_p:
59 | module.p = drop_p
60 | LOGGER.info(f'{name} set to {drop_p}')
61 |
62 |
63 | def set_random_seed(seed):
64 | random.seed(seed)
65 | np.random.seed(seed)
66 | torch.manual_seed(seed)
67 | torch.cuda.manual_seed_all(seed)
68 |
--------------------------------------------------------------------------------
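
For reference, parse_with_config expects the parser to define a --config option; JSON keys only fill in arguments that were not explicitly passed on the command line. A hedged sketch (the extra flags are illustrative, not part of the repo's actual CLI):

    # Hedged sketch: command-line flags take precedence over the JSON config.
    import argparse
    from utils.misc import parse_with_config, set_random_seed

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default=None, help='path to a JSON config file')
    parser.add_argument('--learning_rate', type=float, default=5e-5)   # illustrative flag
    parser.add_argument('--seed', type=int, default=42)

    args = parse_with_config(parser)   # e.g. python pretrain.py --config cfg.json --seed 7
    set_random_seed(args.seed)
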
/utils/ms_internal_mt.py:
--------------------------------------------------------------------------------
1 | import os, requests, uuid, json
2 | import csv
3 | import sys
4 | from tqdm import tqdm
5 | from shutil import copyfile
6 |
7 | csv.field_size_limit(sys.maxsize)
8 |
9 | key_var_name = 'TRANSLATOR_TEXT_SUBSCRIPTION_KEY'
10 | if not key_var_name in os.environ:
11 | raise Exception('Please set/export the environment variable: {}'.format(key_var_name))
12 |
13 | subscription_key = os.environ[key_var_name]
14 | endpoint_var_name = 'TRANSLATOR_TEXT_ENDPOINT'
15 | if not endpoint_var_name in os.environ:
16 | raise Exception('Please set/export the environment variable: {}'.format(endpoint_var_name))
17 | endpoint = os.environ[endpoint_var_name]
18 |
19 | # If you encounter any issues with the base_url or path, make sure
20 | # that you are using the latest endpoint: https://docs.microsoft.com/azure/cognitive-services/translator/reference/v3-0-translate
21 | path = '/translate?api-version=3.0'
22 | constructed_url = endpoint + path
23 |
24 | headers = {
25 | 'Ocp-Apim-Subscription-Key': subscription_key,
26 | 'Content-type': 'application/json',
27 | 'X-ClientTraceId': str(uuid.uuid4())
28 | }
29 |
30 | def msft_translate(lines, langs, batch_size=10):
31 | global constructed_url, headers
32 | if 'zh' in langs:
33 | langs = langs.replace('zh', 'zh-Hans')
34 | params = "&" + "&".join(["to={}".format(lang) for lang in langs.split(',')]) + "&includeAlignment=true&includeSentenceLength=true"
35 | request_url = constructed_url + params
36 | # Body limitations:
37 | # The array can have at most 100 elements.
38 | # The entire text included in the request cannot exceed 5,000 characters including spaces.
39 | out_translations = []
40 |
41 | #for idx in tqdm(range(0, len(lines), batch_size)):
42 | for idx in range(0, len(lines), batch_size):
43 | body = [{'text': line} for line in lines[idx:idx+batch_size]]
44 |
45 | request = requests.post(request_url, headers=headers, json=body)
46 | response = request.json()
47 |
48 | out_translations += response
49 |
50 | return out_translations
51 |
52 | # in_file = ''
53 | # out_file = ''
54 | # tgt_langs = 'en'
55 | # batch_size = 10
56 | split = "train"
57 | in_file = "/home/mingyang/multilingual_vl/data/azure_mount/UNITER_DATA/CC/annotations/{split}_imageId2Ann.tsv".format(split=split)
58 | tgt_langs = 'fr'
59 | backup_file = "/home/mingyang/multilingual_vl/data/azure_mount/UNITER_DATA/CC/annotations/{split}_imageId2Ann_{tgt_langs}_backup.tsv".format(split=split, tgt_langs=tgt_langs)
60 | out_file = "/home/mingyang/multilingual_vl/data/azure_mount/UNITER_DATA/CC/annotations/{split}_imageId2Ann_{tgt_langs}.tsv".format(split=split, tgt_langs=tgt_langs)
61 | alignment_out_file = "/home/mingyang/multilingual_vl/data/azure_mount/UNITER_DATA/CC/annotations/{split}_en_{tgt_langs}_word_alignment.tsv".format(split=split, tgt_langs=tgt_langs)
62 | line_batch_size=100
63 | #batch_size = 100
64 | save_step = 10000
65 | #with open(in_file, "r") as tsv_src:
66 | tsv_src = open(in_file)
67 |
68 | #get the tsv_tgt_data
69 | tsv_tgt_data = []
70 | src_captions = []
71 | #src_tgt_alignments = []
72 | resume_step = 0  # default when there is no backup file to resume from
73 | if os.path.isfile(backup_file):
74 | #load the tsv_tgt_data
75 | tsv_tgt_backup = open(backup_file)
76 | for tgt_line in tsv_tgt_backup:
77 | img_id, img_url, tgt_caption, success = tgt_line.strip().split('\t')
78 | tsv_tgt_data.append([img_id, img_url, tgt_caption, success])
79 | #set up the resume point
80 | resume_step = len(tsv_tgt_data)
81 | print("resume_step is: {}".format(resume_step))
82 | for i, src_line in enumerate(tsv_src):
83 | img_id, img_url, src_caption, success = src_line.strip().split('\t')
84 | src_captions.append(src_caption)
85 | #split the fields
86 | if i >= resume_step:
87 | tsv_tgt_data.append([img_id, img_url, "", success])
88 | #src_tgt_alignments.append([img_id, ""])
89 |
90 | #print(len(tsv_tgt_data) == len(src_captions))
91 | for i in tqdm(range(resume_step, len(src_captions), line_batch_size)):
92 | if i + line_batch_size < len(src_captions):
93 | current_src_batch = src_captions[i:i+line_batch_size]
94 | else:
95 | current_src_batch = src_captions[i:]
96 |
97 | #get the translations
98 | batch_size = len(current_src_batch)
99 | current_translation_batch = msft_translate(current_src_batch, langs=tgt_langs, batch_size=batch_size)
100 | #update tsv_tgt_data
101 | for j in range(len(current_translation_batch)):
102 | if current_translation_batch[j]['translations'][0].get('text', None):
103 | translated_text = current_translation_batch[j]['translations'][0]['text']
104 | if len(translated_text) > 10000:
105 | tsv_tgt_data[i+j][3] = "fail"
106 | else:
107 | tsv_tgt_data[i+j][2] = translated_text
108 | # print(current_translation_batch[j])
109 | # raise Exception("something wrong happens to the translation: {}".format(current_translation_batch[j]))
110 | # finally:
111 | # print(current_translation_batch[j])
112 | #retried_translation = msft_translate()
113 | # if current_translation_batch[j]['translations'][0].get('alignment', None):
114 | # src_tgt_alignments[i+j][1] = current_translation_batch[j]['translations'][0]['alignment']['proj']
115 | if (i+1) % save_step == 1:
116 | saved_tsv_tgt_data = tsv_tgt_data[:i+line_batch_size]
117 | #saved_src_tgt_alignments = src_tgt_alignments[:i+line_batch_size]
118 | with open(out_file, "w") as f_out:
119 | tsv_output = csv.writer(f_out, delimiter='\t')
120 | for saved_line in saved_tsv_tgt_data:
121 | tsv_output.writerow(saved_line)
122 | #copy the file to the backup file
123 | copyfile(out_file, backup_file)
124 | #save the final data to file
125 | with open(out_file, "w") as f_out:
126 | tsv_output = csv.writer(f_out, delimiter='\t')
127 | for saved_line in tsv_tgt_data:
128 | tsv_output.writerow(saved_line)
129 | tsv_src.close()
130 |
131 |
132 |
--------------------------------------------------------------------------------
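
The script above runs its translation loop at import time, so msft_translate is not cleanly importable as-is; if it were factored out into its own module (utils/azure_mt.py is a hypothetical name), a caller would use it roughly like this, with the TRANSLATOR_TEXT_* environment variables set:

    # Hedged sketch: batch a few captions through the Translator v3.0 wrapper.
    from utils.azure_mt import msft_translate   # hypothetical module split out of the script above

    captions = ['a man riding a horse', 'two dogs playing in the snow']
    responses = msft_translate(captions, langs='fr', batch_size=100)

    # each response item mirrors what the script consumes: translations[0]['text']
    # (an 'alignment' field may also be present, since includeAlignment=true is requested)
    translated = [item['translations'][0].get('text', '') for item in responses]
    print(translated)
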
/utils/ms_internal_mt_label.py:
--------------------------------------------------------------------------------
1 | import os, requests, uuid, json
2 | import csv
3 | import sys
4 | from tqdm import tqdm
5 | from shutil import copyfile
6 |
7 | csv.field_size_limit(sys.maxsize)
8 |
9 | key_var_name = 'TRANSLATOR_TEXT_SUBSCRIPTION_KEY'
10 | if not key_var_name in os.environ:
11 | raise Exception('Please set/export the environment variable: {}'.format(key_var_name))
12 |
13 | subscription_key = os.environ[key_var_name]
14 | endpoint_var_name = 'TRANSLATOR_TEXT_ENDPOINT'
15 | if not endpoint_var_name in os.environ:
16 | raise Exception('Please set/export the environment variable: {}'.format(endpoint_var_name))
17 | endpoint = os.environ[endpoint_var_name]
18 |
19 | # If you encounter any issues with the base_url or path, make sure
20 | # that you are using the latest endpoint: https://docs.microsoft.com/azure/cognitive-services/translator/reference/v3-0-translate
21 | path = '/translate?api-version=3.0'
22 | constructed_url = endpoint + path
23 |
24 | headers = {
25 | 'Ocp-Apim-Subscription-Key': subscription_key,
26 | 'Content-type': 'application/json',
27 | 'X-ClientTraceId': str(uuid.uuid4())
28 | }
29 |
30 | def msft_translate(lines, langs, batch_size=10):
31 | global constructed_url, headers
32 | if 'zh' in langs:
33 | langs = langs.replace('zh', 'zh-Hans')
34 | params = "&" + "&".join(["to={}".format(lang) for lang in langs.split(',')]) + "&includeAlignment=true&includeSentenceLength=true"
35 | request_url = constructed_url + params
36 | # Body limitations:
37 | # The array can have at most 100 elements.
38 | # The entire text included in the request cannot exceed 5,000 characters including spaces.
39 | out_translations = []
40 |
41 | #for idx in tqdm(range(0, len(lines), batch_size)):
42 | for idx in range(0, len(lines), batch_size):
43 | body = [{'text': line} for line in lines[idx:idx+batch_size]]
44 |
45 | request = requests.post(request_url, headers=headers, json=body)
46 | response = request.json()
47 |
48 | out_translations += response
49 |
50 | return out_translations
51 |
52 |
53 | source_label_file = "img_label_objects.txt"
54 | tgt_lan = "cs"
55 | tgt_label_file = "img_label_objects_{}.txt".format(tgt_lan)
56 | batch_size = 100
57 |
58 | source_labels = []
59 | tgt_labels = []
60 | with open(source_label_file, "r") as f:
61 | for line in f:
62 | source_labels.append(line.strip())
63 |
64 | for i in tqdm(range(0, len(source_labels), batch_size)):
65 | batch_upper_limit = i+batch_size if i+batch_size < len(source_labels) else len(source_labels)
66 | current_src_batch = source_labels[i:batch_upper_limit]
67 | current_batch_size = len(current_src_batch)
68 |
69 | current_translation_batch = msft_translate(current_src_batch, langs=tgt_lan, batch_size=current_batch_size)
70 | for translated_data in current_translation_batch:
71 | assert translated_data['translations'][0].get('text', None)
72 | tgt_labels.append(translated_data['translations'][0].get('text', None))
73 | #write the tgt_labels to outputfile
74 | with open(tgt_label_file, 'w') as f_out:
75 | for label in tgt_labels:
76 | f_out.write(label+'\n')
--------------------------------------------------------------------------------
/utils/ms_internal_mt_popen.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | filename = "ms_internal_mt.py"
4 | while True:
5 | p = subprocess.Popen('python '+filename, shell=True).wait()
6 |     if p != 0:
7 |         print("An error occurred in the translation process. Restarting....")
8 |         continue
9 |     else:
10 |         print("The translation process is done.")
11 | break
--------------------------------------------------------------------------------
/utils/save.py:
--------------------------------------------------------------------------------
1 | """
2 | saving utilities
3 | """
4 | from collections import OrderedDict
5 | import json
6 | import os
7 | from os.path import abspath, dirname, exists, join
8 | import subprocess
9 |
10 | import torch
11 | from apex.amp._amp_state import _amp_state
12 |
13 | from utils.logger import LOGGER
14 | from horovod import torch as hvd
15 |
16 |
17 | def save_training_meta(args):
18 | if args.rank > 0:
19 | return
20 |
21 | if not exists(args.output_dir):
22 | os.makedirs(join(args.output_dir, 'log'))
23 | os.makedirs(join(args.output_dir, 'ckpt'))
24 |
25 | with open(join(args.output_dir, 'log', 'hps.json'), 'w') as writer:
26 | json.dump(vars(args), writer, indent=4)
27 | model_config = json.load(open(args.model_config))
28 | with open(join(args.output_dir, 'log', 'model.json'), 'w') as writer:
29 | json.dump(model_config, writer, indent=4)
30 | # git info
31 | #Mingyang: Ignore the git info
32 | # try:
33 | # LOGGER.info("Waiting on git info....")
34 | # c = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"],
35 | # timeout=10, stdout=subprocess.PIPE)
36 | # git_branch_name = c.stdout.decode().strip()
37 | # LOGGER.info("Git branch: %s", git_branch_name)
38 | # c = subprocess.run(["git", "rev-parse", "HEAD"],
39 | # timeout=10, stdout=subprocess.PIPE)
40 | # git_sha = c.stdout.decode().strip()
41 | # LOGGER.info("Git SHA: %s", git_sha)
42 | # git_dir = abspath(dirname(__file__))
43 | # git_status = subprocess.check_output(
44 | # ['git', 'status', '--short'],
45 | # cwd=git_dir, universal_newlines=True).strip()
46 | # with open(join(args.output_dir, 'log', 'git_info.json'),
47 | # 'w') as writer:
48 | # json.dump({'branch': git_branch_name,
49 | # 'is_dirty': bool(git_status),
50 | # 'status': git_status,
51 | # 'sha': git_sha},
52 | # writer, indent=4)
53 | # except subprocess.TimeoutExpired as e:
54 | # LOGGER.exception(e)
55 | # LOGGER.warn("Git info not found. Moving right along...")
56 |
57 |
58 | class ModelSaver(object):
59 | def __init__(self, output_dir, prefix='model_step', suffix='pt'):
60 | self.output_dir = output_dir
61 | self.prefix = prefix
62 | self.suffix = suffix
63 |
64 | def save(self, model, step, optimizer=None):
65 | output_model_file = join(self.output_dir,
66 | f"{self.prefix}_{step}.{self.suffix}")
67 | state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v
68 | for k, v in model.state_dict().items()}
69 | if hasattr(model, 'vocab_pad') and model.vocab_pad:
70 | # store vocab embeddings before padding
71 | emb_w = state_dict['bert.embeddings.word_embeddings.weight']
72 | emb_w = emb_w[:-model.vocab_pad, :]
73 | state_dict['bert.embeddings.word_embeddings.weight'] = emb_w
74 | state_dict['cls.predictions.decoder.weight'] = emb_w
75 | torch.save(state_dict, output_model_file)
76 | if optimizer is not None:
77 | dump = {'step': step, 'optimizer': optimizer.state_dict()}
78 | if hasattr(optimizer, '_amp_stash'):
79 | pass # TODO fp16 optimizer
80 | torch.save(dump, f'{self.output_dir}/train_state_{step}.pt')
81 |
82 |
83 | def amp_state_dict(destination=None):
84 | if destination is None:
85 | destination = OrderedDict()
86 |
87 | for idx, loss_scaler in enumerate(_amp_state.loss_scalers):
88 | destination['loss_scaler%d' % idx] = {
89 | 'loss_scale': loss_scaler.loss_scale(),
90 | 'unskipped': loss_scaler._unskipped,
91 | }
92 | return destination
93 |
94 |
95 | def amp_load_state_dict(state_dict):
96 |     # Check if state_dict contains the same number of loss_scalers as the
97 |     # current setup
98 | if len(state_dict) != len(_amp_state.loss_scalers):
99 | print('Warning: state_dict contains {} entries, while {} loss_scalers '
100 | 'are used'.format(len(state_dict), len(_amp_state.loss_scalers)))
101 |
102 | state_dict = state_dict.copy()
103 |
104 | nb_loss_scalers = len(_amp_state.loss_scalers)
105 | unexpected_keys = []
106 | # Initialize idx outside, since unexpected_keys will increase it if
107 | # enumerate is used
108 | idx = 0
109 | for key in state_dict:
110 | if 'loss_scaler' not in key:
111 | unexpected_keys.append(key)
112 | else:
113 | if idx > (nb_loss_scalers - 1):
114 | print('Skipping loss_scaler[{}], since num_losses was set to '
115 | '{}'.format(idx, nb_loss_scalers))
116 | break
117 | _amp_state.loss_scalers[idx]._loss_scale = (
118 | state_dict[key]['loss_scale'])
119 | _amp_state.loss_scalers[idx]._unskipped = (
120 | state_dict[key]['unskipped'])
121 | idx += 1
122 |
123 | if len(unexpected_keys) > 0:
124 | raise RuntimeError(
125 | 'Error(s) in loading state_dict. Unexpected key(s) in state_dict: '
126 | '{}. '.format(', '.join('"{}"'.format(k)
127 | for k in unexpected_keys)))
128 |
129 | def _to_cuda(state):
130 |     """ usually loaded from a cpu checkpoint but needs to be moved to cuda """
131 |     if isinstance(state, torch.Tensor):
132 |         ret = state.cuda()  # assumes the device was properly set by torch.cuda.set_device
133 | if 'Half' in state.type():
134 | ret = ret.float() # apex O2 requires it
135 | return ret
136 | elif isinstance(state, list):
137 | new_state = [_to_cuda(t) for t in state]
138 | elif isinstance(state, tuple):
139 | new_state = tuple(_to_cuda(t) for t in state)
140 | elif isinstance(state, dict):
141 | new_state = {n: _to_cuda(t) for n, t in state.items()}
142 | else:
143 | return state
144 | return new_state
145 |
146 |
147 | def _to_cpu(state):
148 |     """ store on cpu to avoid crowding GPU0; cast to fp16 to save space """
149 | if isinstance(state, torch.Tensor):
150 | ret = state.cpu()
151 | if 'Float' in state.type():
152 | ret = ret.half()
153 | return ret
154 | elif isinstance(state, list):
155 | new_state = [_to_cpu(t) for t in state]
156 | elif isinstance(state, tuple):
157 | new_state = tuple(_to_cpu(t) for t in state)
158 | elif isinstance(state, dict):
159 | new_state = {n: _to_cpu(t) for n, t in state.items()}
160 | else:
161 | return state
162 | return new_state
163 |
164 | class TrainingRestorer(object):
165 | def __init__(self, opts, model, optimizer):
166 | if exists(opts.output_dir) and hvd.rank() == 0:
167 | restore_opts = json.load(
168 | open(f'{opts.output_dir}/log/hps.json', 'r'))
169 | with open(join(
170 | opts.output_dir, 'log',
171 | 'restore_hps.json'), 'w') as writer:
172 | json.dump(vars(opts), writer, indent=4)
173 | assert vars(opts) == restore_opts
174 |         # keep 2 checkpoints in case one is corrupted
175 | self.save_path = f'{opts.output_dir}/restore.pt'
176 | self.backup_path = f'{opts.output_dir}/restore_backup.pt'
177 | self.model = model
178 | self.optimizer = optimizer
179 | self.save_steps = opts.save_steps
180 | self.amp = opts.fp16
181 | if exists(self.save_path) or exists(self.backup_path):
182 | LOGGER.info('found previous checkpoint. try to resume...')
183 | self.restore(opts)
184 | else:
185 | self.global_step = 0
186 |
187 | def step(self):
188 | self.global_step += 1
189 | if self.global_step % self.save_steps == 0:
190 | self.save()
191 |
192 | def save(self):
193 | checkpoint = {'global_step': self.global_step,
194 | 'model_state_dict': _to_cpu(self.model.state_dict()),
195 | 'optim_state_dict': _to_cpu(self.optimizer.state_dict())}
196 | if self.amp:
197 | checkpoint['amp_state_dict'] = amp_state_dict()
198 | if exists(self.save_path):
199 | os.rename(self.save_path, self.backup_path)
200 | torch.save(checkpoint, self.save_path)
201 |
202 | def restore(self, opts):
203 | try:
204 | checkpoint = torch.load(self.save_path)
205 | except Exception:
206 | checkpoint = torch.load(self.backup_path)
207 | self.global_step = checkpoint['global_step']
208 | self.model.load_state_dict(_to_cuda(checkpoint['model_state_dict']))
209 | self.optimizer.load_state_dict(
210 | _to_cuda(checkpoint['optim_state_dict']))
211 | if self.amp:
212 | amp_load_state_dict(checkpoint['amp_state_dict'])
213 | LOGGER.info(f'resume training from step {self.global_step}')
214 |
--------------------------------------------------------------------------------
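
A minimal sketch of ModelSaver on a toy module (TrainingRestorer is omitted here because it additionally needs an initialized Horovod context and an existing output_dir with log/hps.json); the checkpoint directory is a placeholder and must already exist:

    # Hedged sketch: write a step checkpoint the way the training loop would.
    import torch
    from utils.save import ModelSaver   # importing utils.save assumes apex + horovod are installed

    model = torch.nn.Linear(4, 2)       # stand-in for the real UNITER/UC2 model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    saver = ModelSaver('/tmp/ckpt')     # hypothetical output dir
    saver.save(model, step=1000, optimizer=optimizer)
    # -> /tmp/ckpt/model_step_1000.pt and /tmp/ckpt/train_state_1000.pt
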
/utils/visual_entailment.py:
--------------------------------------------------------------------------------
1 | """
2 | NOTE: modified from ban-vqa
3 | This code is slightly modified from Hengyuan Hu's repository.
4 | https://github.com/hengyuan-hu/bottom-up-attention-vqa
5 | """
6 | import os
7 | import sys
8 | import pickle
9 |
10 |
11 | def create_ans2label(path):
12 | """
13 |     path: dir of the output file
14 |     (writes visual_entailment_ans2label.pkl under path)
15 | """
16 | ans2label = {"contradiction": 0, "entailment":1 , "neutral": 2}
17 | label2ans = ["contradiction", "entailment", "neutral"]
18 |
19 | output_file = os.path.join(path, 'visual_entailment_ans2label.pkl')
20 | pickle.dump(ans2label, open(output_file, 'wb'))
21 |
22 |
23 | def compute_target(answers, ans2label):
24 | answer_count = {}
25 | for answer in answers:
26 | answer_ = answer
27 | answer_count[answer_] = answer_count.get(answer_, 0) + 1
28 |
29 | labels = []
30 | scores = []
31 | for answer in answer_count:
32 | if answer not in ans2label:
33 | continue
34 | labels.append(ans2label[answer])
35 | score = answer_count[answer]/len(answers)
36 | scores.append(score)
37 | target = {'labels': labels, 'scores': scores}
38 | return target
39 |
40 |
41 | if __name__ == '__main__':
42 |     output = sys.argv[1]
43 | print(output)
44 | if os.path.exists(f'{output}/visual_entailment_ans2label.pkl'):
45 | raise ValueError(f'{output} already exists')
46 | create_ans2label(output)
47 |
--------------------------------------------------------------------------------
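
A worked example of compute_target on the three SNLI-VE labels; the annotator votes are made up:

    # Hedged sketch: soft targets from three hypothetical annotator votes.
    from utils.visual_entailment import compute_target

    ans2label = {'contradiction': 0, 'entailment': 1, 'neutral': 2}
    answers = ['entailment', 'entailment', 'neutral']
    target = compute_target(answers, ans2label)
    # -> {'labels': [1, 2], 'scores': [0.666..., 0.333...]}
    print(target)
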
/utils/vqa.py:
--------------------------------------------------------------------------------
1 | """
2 | NOTE: modified from ban-vqa
3 | This code is slightly modified from Hengyuan Hu's repository.
4 | https://github.com/hengyuan-hu/bottom-up-attention-vqa
5 | """
6 | import os
7 | import json
8 | import re
9 | import sys
10 | import pickle
11 |
12 |
13 | CONTRACTIONS = {
14 | "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve":
15 | "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've",
16 | "couldnt've": "couldn't've", "didnt": "didn't", "doesnt":
17 | "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've":
18 | "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent":
19 | "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve":
20 | "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll",
21 | "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im":
22 | "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've":
23 | "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
24 | "maam": "ma'am", "mightnt": "mightn't", "mightnt've":
25 | "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
26 | "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't",
27 | "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
28 | "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat":
29 | "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve":
30 | "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt":
31 | "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve":
32 | "shouldn't've", "somebody'd": "somebodyd", "somebodyd've":
33 | "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll":
34 | "somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
35 | "someoned've": "someone'd've", "someone'dve": "someone'd've",
36 | "someonell": "someone'll", "someones": "someone's", "somethingd":
37 | "something'd", "somethingd've": "something'd've", "something'dve":
38 | "something'd've", "somethingll": "something'll", "thats":
39 | "that's", "thered": "there'd", "thered've": "there'd've",
40 | "there'dve": "there'd've", "therere": "there're", "theres":
41 | "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve":
42 | "they'd've", "theyll": "they'll", "theyre": "they're", "theyve":
43 | "they've", "twas": "'twas", "wasnt": "wasn't", "wed've":
44 | "we'd've", "we'dve": "we'd've", "weve": "we've", "werent":
45 | "weren't", "whatll": "what'll", "whatre": "what're", "whats":
46 | "what's", "whatve": "what've", "whens": "when's", "whered":
47 | "where'd", "wheres": "where's", "whereve": "where've", "whod":
48 | "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl":
49 | "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
50 | "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve":
51 | "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
52 | "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll":
53 | "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
54 | "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd":
55 | "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll":
56 | "you'll", "youre": "you're", "youve": "you've"
57 | }
58 |
59 | MANUAL_MAP = {'none': '0',
60 | 'zero': '0',
61 | 'one': '1',
62 | 'two': '2',
63 | 'three': '3',
64 | 'four': '4',
65 | 'five': '5',
66 | 'six': '6',
67 | 'seven': '7',
68 | 'eight': '8',
69 | 'nine': '9',
70 | 'ten': '10'}
71 | ARTICLES = ['a', 'an', 'the']
72 | PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
73 | COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
74 | PUNCT = [';', r"/", '[', ']', '"', '{', '}',
75 | '(', ')', '=', '+', '\\', '_', '-',
76 | '>', '<', '@', '`', ',', '?', '!']
77 |
78 |
79 | # Note: the official VQA score of an answer is averaged over the 10-choose-9
80 | # subsets of human answers. See http://visualqa.org/evaluation.html
81 | def get_score(occurences):
82 | if occurences == 0:
83 | return .0
84 | elif occurences == 1:
85 | return .3
86 | elif occurences == 2:
87 | return .6
88 | elif occurences == 3:
89 | return .9
90 | else:
91 | return 1.
92 |
93 |
94 | def process_punctuation(inText):
95 | outText = inText
96 | for p in PUNCT:
97 | if (p + ' ' in inText
98 | or ' ' + p in inText
99 | or re.search(COMMA_STRIP, inText) is not None):
100 | outText = outText.replace(p, '')
101 | else:
102 | outText = outText.replace(p, ' ')
103 | outText = PERIOD_STRIP.sub("", outText, re.UNICODE)
104 | return outText
105 |
106 |
107 | def process_digit_article(inText):
108 | outText = []
109 | tempText = inText.lower().split()
110 | for word in tempText:
111 | word = MANUAL_MAP.setdefault(word, word)
112 | if word not in ARTICLES:
113 | outText.append(word)
114 | else:
115 | pass
116 | for wordId, word in enumerate(outText):
117 | if word in CONTRACTIONS:
118 | outText[wordId] = CONTRACTIONS[word]
119 | outText = ' '.join(outText)
120 | return outText
121 |
122 |
123 | def preprocess_answer(answer):
124 | answer = process_digit_article(process_punctuation(answer))
125 | answer = answer.replace(',', '')
126 | return answer
127 |
128 |
129 | def filter_answers(answers_dset, min_occurence):
130 |     """Collect answers (after preprocessing) that occur at least min_occurence times
131 | """
132 | occurence = {}
133 |
134 | for ans_entry in answers_dset:
135 | gtruth = ans_entry.get('multiple_choice_answer', None)
136 | if gtruth is None:
137 | gtruth = ans_entry['answers'][0]['answer'] # VG, GQA pretraining
138 | gtruth = preprocess_answer(gtruth)
139 | if gtruth not in occurence:
140 | occurence[gtruth] = set()
141 | occurence[gtruth].add(ans_entry['question_id'])
142 | for answer in list(occurence):
143 | if len(occurence[answer]) < min_occurence:
144 | occurence.pop(answer)
145 |
146 | print('Num of answers that appear >= %d times: %d' % (
147 | min_occurence, len(occurence)))
148 | return occurence
149 |
150 |
151 | def create_ans2label(occurence, path):
152 | """
153 | occurence: dict {answer -> whatever}
154 |     path: dir of the output file
155 | """
156 | ans2label = {}
157 | label2ans = []
158 | label = 0
159 | for answer in occurence:
160 | label2ans.append(answer)
161 | ans2label[answer] = label
162 | label += 1
163 |
164 | output_file = os.path.join(path, 'ans2label.pkl')
165 | pickle.dump(ans2label, open(output_file, 'wb'))
166 |
167 |
168 | def compute_target(answers, ans2label):
169 | answer_count = {}
170 | if len(answers) == 1:
171 | # VG VQA, GQA
172 | answer_ = preprocess_answer(answers[0]['answer'])
173 | answer_count[answer_] = 10
174 | else:
175 | # COCO VQA
176 | for answer in answers:
177 | answer_ = preprocess_answer(answer['answer'])
178 | answer_count[answer_] = answer_count.get(answer_, 0) + 1
179 |
180 | labels = []
181 | scores = []
182 | for answer in answer_count:
183 | if answer not in ans2label:
184 | continue
185 | labels.append(ans2label[answer])
186 | score = get_score(answer_count[answer])
187 | scores.append(score)
188 | target = {'labels': labels, 'scores': scores}
189 | return target
190 |
191 |
192 | if __name__ == '__main__':
193 | *answer_files, output = sys.argv[1:]
194 | answers = []
195 | for ans_file in answer_files:
196 | ans = json.load(open(ans_file))['annotations']
197 | answers.extend(ans)
198 |
199 | occurence = filter_answers(answers, 9)
200 |
201 | if os.path.exists(f'{output}/ans2label.pkl'):
202 | raise ValueError(f'{output} already exists')
203 | create_ans2label(occurence, output)
204 |
--------------------------------------------------------------------------------
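
A small worked example tying the helpers above together; the annotations and label space are made up, but the scoring follows the code (3 matching human answers -> 0.9, 7 -> 1.0):

    # Hedged sketch: VQA soft-score computation on made-up annotations.
    from utils.vqa import preprocess_answer, compute_target

    print(preprocess_answer('Three dogs.'))   # 'three' -> '3', period stripped -> '3 dogs'

    ans2label = {'3 dogs': 0, '2 dogs': 1}    # hypothetical label space
    answers = [{'answer': 'three dogs'}] * 3 + [{'answer': 'two dogs'}] * 7
    target = compute_target(answers, ans2label)
    # -> {'labels': [0, 1], 'scores': [0.9, 1.0]}
    print(target)
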