├── LICENSE ├── README.md ├── ViLT_LICENSE ├── azure_distributed_run.py ├── meter ├── __init__.py ├── config.py ├── datamodules │ ├── __init__.py │ ├── coco_caption_karpathy_datamodule.py │ ├── conceptual_caption_datamodule.py │ ├── datamodule_base.py │ ├── f30k_caption_karpathy_datamodule.py │ ├── multitask_datamodule.py │ ├── nlvr2_datamodule.py │ ├── sbu_datamodule.py │ ├── snli_datamodule.py │ ├── vg_caption_datamodule.py │ └── vqav2_datamodule.py ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ ├── coco_caption_karpathy_dataset.py │ ├── conceptual_caption_dataset.py │ ├── f30k_caption_karpathy_dataset.py │ ├── nlvr2_dataset.py │ ├── sbu_caption_dataset.py │ ├── snli_dataset.py │ ├── vg_caption_dataset.py │ └── vqav2_dataset.py ├── gadgets │ ├── __init__.py │ └── my_metrics.py ├── modules │ ├── __init__.py │ ├── bert_model.py │ ├── clip_model.py │ ├── dist_utils.py │ ├── heads.py │ ├── meter_module.py │ ├── meter_utils.py │ ├── objectives.py │ ├── swin_helpers.py │ └── swin_transformer.py ├── transforms │ ├── __init__.py │ ├── randaug.py │ ├── transform.py │ └── utils.py └── utils │ ├── __init__.py │ ├── glossary.py │ ├── write_coco_karpathy.py │ ├── write_conceptual_caption.py │ ├── write_f30k_karpathy.py │ ├── write_nlvr2.py │ ├── write_sbu.py │ ├── write_snli.py │ ├── write_vg.py │ └── write_vqa.py ├── requirements.txt ├── run.py └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Microsoft Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # METER: A Multimodal End-to-end TransformER Framework 2 | 3 | ## Install 4 | 5 | ```bash 6 | pip install -r requirements.txt 7 | pip install -e . 8 | ``` 9 | 10 | ## Pre-trained Checkpoints 11 | 12 | Here are the pre-trained models: 13 | 1. METER-CLIP16-RoBERTa (resolution: 288^2) pre-trained on GCC+SBU+COCO+VG [link](https://github.com/zdou0830/METER/releases/download/checkpoint/meter_clip16_288_roberta_pretrain.ckpt) 14 | 2. METER-CLIP16-RoBERTa (resolution: 224^2) pre-trained on GCC+SBU+COCO+VG [link](https://github.com/zdou0830/METER/releases/download/checkpoint2/meter_clip16_224_roberta_pretrain.ckpt) 15 | 3. 
METER-SwinBase-RoBERTa (resolution: 384^2) pre-trained on GCC+SBU+COCO+VG [link](https://github.com/zdou0830/METER/releases/download/checkpoint2/meter_swinbase_384_roberta_pretrain.ckpt) 16 | 4. METER-CLIP16-RoBERTa fine-tuned on VQAv2 (resolution: 576^2) [link](https://github.com/zdou0830/METER/releases/download/checkpoint/meter_clip16_288_roberta_vqa.ckpt) 17 | 5. METER-CLIP16-RoBERTa fine-tuned on NLVR2 (resolution: 288^2) [link](https://github.com/zdou0830/METER/releases/download/checkpoint/meter_clip16_288_roberta_nlvr2.ckpt) 18 | 6. METER-CLIP16-RoBERTa fine-tuned on SNLI-VE (resolution: 384^2) [link](https://github.com/zdou0830/METER/releases/download/checkpoint/meter_clip16_288_roberta_snli.ckpt) 19 | 7. METER-CLIP16-RoBERTa fine-tuned on Flickr30k IR/TR (resolution: 384^2) [link](https://github.com/zdou0830/METER/releases/download/checkpoint/meter_clip16_288_roberta_flickr.ckpt) 20 | 8. METER-CLIP16-RoBERTa fine-tuned on COCO IR/TR (resolution: 384^2) [link](https://github.com/zdou0830/METER/releases/download/checkpoint/meter_clip16_288_roberta_coco.ckpt) 21 | 22 | 23 | ## Dataset Preparation 24 | 25 | We follow [ViLT](https://github.com/dandelin/ViLT) and use `pyarrow` to serialize the datasets. See [this link](https://github.com/dandelin/ViLT/blob/master/DATA.md) for details. 26 | 27 | ## Pre-training 28 | 29 | ```bash 30 | export MASTER_ADDR=$DIST_0_IP 31 | export MASTER_PORT=$DIST_0_PORT 32 | export NODE_RANK=$DIST_RANK 33 | python run.py with data_root= num_gpus= num_nodes= task_mlm_itm_clip_bert per_gpu_batchsize= image_size= 34 | ``` 35 | 36 | Here is an example: 37 | ```bash 38 | python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_mlm_itm_clip_bert per_gpu_batchsize=32 clip16 text_roberta image_size=288 39 | ``` 40 | 41 | ## Fine-tuning on Downstream Tasks 42 | 43 | ### VQAv2 44 | 45 | ```bash 46 | export MASTER_ADDR=$DIST_0_IP 47 | export MASTER_PORT=$DIST_0_PORT 48 | export NODE_RANK=$DIST_RANK 49 | python run.py with data_root= num_gpus= num_nodes= task_finetune_vqa_clip_bert per_gpu_batchsize= load_path= image_size= 50 | ``` 51 | 52 | Here is an example: 53 | ```bash 54 | python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_vqa_clip_bert per_gpu_batchsize=32 load_path=meter_pretrain.ckpt clip16 text_roberta image_size=576 clip_randaug 55 | ``` 56 | 57 | ### Flickr30k IR/TR 58 | 59 | ```bash 60 | export MASTER_ADDR=$DIST_0_IP 61 | export MASTER_PORT=$DIST_0_PORT 62 | export NODE_RANK=$DIST_RANK 63 | python run.py with data_root= num_gpus= num_nodes= task_finetune_irtr_f30k_clip_bert get_recall_metric=False per_gpu_batchsize= load_path= image_size= 64 | ``` 65 | 66 | Here is an example: 67 | ```bash 68 | python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_irtr_f30k_clip_bert get_recall_metric=False per_gpu_batchsize=32 load_path=meter_pretrain.ckpt clip16 text_roberta image_size=384 clip_randaug 69 | ``` 70 | 71 | ### COCO IR/TR 72 | 73 | ```bash 74 | export MASTER_ADDR=$DIST_0_IP 75 | export MASTER_PORT=$DIST_0_PORT 76 | export NODE_RANK=$DIST_RANK 77 | python run.py with data_root= num_gpus= num_nodes= task_finetune_irtr_coco_clip_bert get_recall_metric=False per_gpu_batchsize= load_path= image_size= 78 | ``` 79 | 80 | Here is an example: 81 | ```bash 82 | python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_irtr_coco_clip_bert get_recall_metric=False per_gpu_batchsize=32 load_path=meter_pretrain.ckpt clip16 text_roberta image_size=384 
clip_randaug 83 | ``` 84 | 85 | ### NLVR2 86 | 87 | ```bash 88 | export MASTER_ADDR=$DIST_0_IP 89 | export MASTER_PORT=$DIST_0_PORT 90 | export NODE_RANK=$DIST_RANK 91 | python run.py with data_root= num_gpus= num_nodes= task_finetune_nlvr2_clip_bert per_gpu_batchsize= load_path= image_size= 92 | ``` 93 | 94 | Here is an example: 95 | ```bash 96 | python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_nlvr2_clip_bert per_gpu_batchsize=32 load_path=meter_pretrain.ckpt clip16 text_roberta image_size=288 clip_randaug 97 | ``` 98 | 99 | ### SNLI-VE 100 | 101 | ```bash 102 | export MASTER_ADDR=$DIST_0_IP 103 | export MASTER_PORT=$DIST_0_PORT 104 | export NODE_RANK=$DIST_RANK 105 | python run.py with data_root= num_gpus= num_nodes= task_finetune_snli_clip_bert per_gpu_batchsize= load_path= image_size= 106 | ``` 107 | 108 | Here is an example: 109 | ```bash 110 | python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_snli_clip_bert per_gpu_batchsize=8 load_path=meter_pretrain.ckpt clip16 text_roberta image_size=384 clip_randaug 111 | ``` 112 | 113 | ## Evaluation on Downstream Tasks 114 | 115 | ### VQAv2 116 | 117 | ```bash 118 | export MASTER_ADDR=$DIST_0_IP 119 | export MASTER_PORT=$DIST_0_PORT 120 | export NODE_RANK=$DIST_RANK 121 | python run.py with data_root= num_gpus= num_nodes= task_finetune_vqa_clip_bert per_gpu_batchsize= load_path= image_size= test_only=True 122 | ``` 123 | 124 | Here is an example: 125 | ```bash 126 | python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_vqa_clip_bert per_gpu_batchsize=32 load_path=meter_vqa.ckpt clip16 text_roberta image_size=576 test_only=True 127 | ``` 128 | 129 | Then, submit the json file in the `result` directory to [eval.ai](https://eval.ai/web/challenges/challenge-page/830/overview) evaluation server to get the test-dev and/or test-std scores. 130 | 131 | 132 | ### Flickr30k IR/TR 133 | 134 | ```bash 135 | export MASTER_ADDR=$DIST_0_IP 136 | export MASTER_PORT=$DIST_0_PORT 137 | export NODE_RANK=$DIST_RANK 138 | python run.py with data_root= num_gpus= num_nodes= task_finetune_irtr_f30k_clip_bert get_recall_metric=True per_gpu_batchsize= load_path= image_size= test_only=True 139 | ``` 140 | 141 | Here is an example: 142 | ```bash 143 | python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_irtr_f30k_clip_bert get_recall_metric=True per_gpu_batchsize=32 load_path=meter_f30k.ckpt clip16 text_roberta image_size=384 test_only=True 144 | ``` 145 | 146 | The returned values are IR R@1, R@5, R@10 and TR R@1, R@5, R@10. 147 | 148 | ### COCO IR/TR 149 | 150 | ```bash 151 | export MASTER_ADDR=$DIST_0_IP 152 | export MASTER_PORT=$DIST_0_PORT 153 | export NODE_RANK=$DIST_RANK 154 | python run.py with data_root= num_gpus= num_nodes= task_finetune_irtr_coco_clip_bert get_recall_metric=True per_gpu_batchsize= load_path= image_size= test_only=True 155 | ``` 156 | 157 | Here is an example: 158 | ```bash 159 | python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_irtr_coco_clip_bert get_recall_metric=True per_gpu_batchsize=32 load_path=meter_coco.ckpt clip16 text_roberta image_size=384 test_only=True 160 | ``` 161 | 162 | The returned values are IR R@1, R@5, R@10 and TR R@1, R@5, R@10. 
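
For reference, R@K is the fraction of queries whose ground-truth match is ranked within the top K retrieved items. The sketch below is illustrative only — it is not the repository's implementation, and it assumes a single ground-truth caption per image — the `get_recall_metric=True` runs above report the six numbers for you:

```python
import torch

def recall_at_k(sims: torch.Tensor, ks=(1, 5, 10)):
    """Illustrative R@K from a square image-text similarity matrix.

    sims[i, j] scores image i against text j; text i is assumed to be the
    ground-truth caption of image i (one caption per image, for simplicity).
    """
    gt = torch.arange(sims.size(0))
    # Text retrieval: rank all texts for each image; image retrieval: transpose.
    tr_rank = (sims.argsort(dim=1, descending=True) == gt[:, None]).float().argmax(dim=1)
    ir_rank = (sims.t().argsort(dim=1, descending=True) == gt[:, None]).float().argmax(dim=1)
    ir = {f"IR R@{k}": (ir_rank < k).float().mean().item() for k in ks}
    tr = {f"TR R@{k}": (tr_rank < k).float().mean().item() for k in ks}
    return ir, tr
```

This sketch is only meant to clarify what the six reported numbers measure; the actual evaluation is handled inside the repository when `get_recall_metric=True` is set.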
163 | 164 | ### NLVR2 165 | 166 | ```bash 167 | export MASTER_ADDR=$DIST_0_IP 168 | export MASTER_PORT=$DIST_0_PORT 169 | export NODE_RANK=$DIST_RANK 170 | python run.py with data_root= num_gpus= num_nodes= task_finetune_nlvr2_clip_bert per_gpu_batchsize= load_path= image_size= test_only=True 171 | ``` 172 | 173 | Here is an example: 174 | ```bash 175 | python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_nlvr2_clip_bert per_gpu_batchsize=32 load_path=meter_nlvr2.ckpt clip16 text_roberta image_size=288 test_only=True 176 | ``` 177 | 178 | ### SNLI-VE 179 | 180 | ```bash 181 | export MASTER_ADDR=$DIST_0_IP 182 | export MASTER_PORT=$DIST_0_PORT 183 | export NODE_RANK=$DIST_RANK 184 | python run.py with data_root= num_gpus= num_nodes= task_finetune_snli_clip_bert per_gpu_batchsize= load_path= image_size= test_only=True 185 | ``` 186 | 187 | Here is an example: 188 | ```bash 189 | python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_snli_clip_bert per_gpu_batchsize=8 load_path=meter_snli.ckpt clip16 text_roberta image_size=384 test_only=True 190 | ``` 191 | 192 | 193 | 194 | ## Citation 195 | 196 | ``` 197 | @inproceedings{dou2022meter, 198 | title={An Empirical Study of Training End-to-End Vision-and-Language Transformers}, 199 | author={Dou, Zi-Yi and Xu, Yichong and Gan, Zhe and Wang, Jianfeng and Wang, Shuohang and Wang, Lijuan and Zhu, Chenguang and Zhang, Pengchuan and Yuan, Lu and Peng, Nanyun and Liu, Zicheng and Zeng, Michael}, 200 | booktitle={Conference on Computer Vision and Pattern Recognition (CVPR)}, 201 | year={2022}, 202 | url={https://arxiv.org/abs/2111.02387}, 203 | } 204 | ``` 205 | 206 | ## Acknowledgements 207 | 208 | The code is based on [ViLT](https://github.com/dandelin/ViLT) licensed under [Apache 2.0](https://github.com/dandelin/ViLT/blob/master/LICENSE) and some of the code is borrowed from [CLIP](https://github.com/openai/CLIP) and [Swin-Transformer](https://github.com/microsoft/Swin-Transformer). 209 | -------------------------------------------------------------------------------- /ViLT_LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021-present NAVER Corp. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /azure_distributed_run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import pytorch_lightning as pl 4 | import os 5 | os.environ["NCCL_DEBUG"] = "INFO" 6 | 7 | from meter.config import ex 8 | from meter.modules import METERTransformerSS 9 | from meter.datamodules.multitask_datamodule import MTDataModule 10 | 11 | import resource 12 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 13 | resource.setrlimit(resource.RLIMIT_NOFILE, (20480, rlimit[1])) 14 | 15 | from pytorch_lightning.plugins.environments import ClusterEnvironment 16 | from pytorch_lightning.plugins.training_type import DDPPlugin 17 | import torch.distributed as dist 18 | class MyCluster(ClusterEnvironment): 19 | 20 | def creates_children(self) -> bool: 21 | # return True if the cluster is managed (you don't launch processes yourself) 22 | return True 23 | 24 | def master_address(self): 25 | return os.environ['MASTER_ADDR'] 26 | 27 | def master_port(self) -> int: 28 | return int(os.environ["MASTER_PORT"]) 29 | 30 | def world_size(self): 31 | return int(os.environ['OMPI_COMM_WORLD_SIZE']) 32 | 33 | def global_rank(self) -> int: 34 | return int(os.environ['OMPI_COMM_WORLD_RANK']) 35 | 36 | def local_rank(self) -> int: 37 | return int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 38 | 39 | def node_rank(self) -> int: 40 | return int(os.environ["OMPI_COMM_WORLD_NODE_RANK"]) 41 | 42 | def set_global_rank(self, rank: int) -> None: 43 | pass 44 | 45 | def set_world_size(self, size: int) -> None: 46 | pass 47 | 48 | class MyDDPPlugin(DDPPlugin): 49 | 50 | def init_ddp_connection(self, global_rank = None, world_size = None) -> None: 51 | master_uri = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 52 | dist.init_process_group( 53 | backend=self.torch_distributed_backend, 54 | init_method=master_uri, 55 | world_size=int(os.environ['OMPI_COMM_WORLD_SIZE']), 56 | rank=int(os.environ['OMPI_COMM_WORLD_RANK']), 57 | ) 58 | 59 | @ex.automain 60 | def main(_config): 61 | os.environ["NCCL_DEBUG"] = "INFO" 62 | world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 63 | local_size = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE']) 64 | global_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 65 | local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 66 | master_addr = os.environ['MASTER_ADDR'] 67 | master_port = os.environ['MASTER_PORT'] 68 | 69 | # set environment variables for 'env://' 70 | os.environ['WORLD_SIZE'] = str(world_size) 71 | os.environ['NODE_RANK'] = str(os.environ["OMPI_COMM_WORLD_NODE_RANK"]) 72 | 73 | _config = copy.deepcopy(_config) 74 | pl.seed_everything(_config["seed"]) 75 | 76 | dm = MTDataModule(_config, dist=True) 77 | 78 | model = METERTransformerSS(_config) 79 | exp_name = f'{_config["exp_name"]}' 80 | 81 | os.makedirs(_config["log_dir"], exist_ok=True) 82 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 83 | save_top_k=1, 84 | verbose=True, 85 | monitor="val/the_metric", 86 | mode="max", 87 | save_last=True, 88 | ) 89 | logger = pl.loggers.TensorBoardLogger( 90 | _config["log_dir"], 91 | name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}', 92 | ) 93 | 94 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step") 95 | callbacks = [checkpoint_callback, lr_callback] 96 | 97 | num_gpus = ( 98 | _config["num_gpus"] 99 | if isinstance(_config["num_gpus"], int) 100 | else 
len(_config["num_gpus"]) 101 | ) 102 | 103 | grad_steps = max(_config["batch_size"] // ( 104 | _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"] 105 | ), 1) 106 | 107 | max_steps = _config["max_steps"] if _config["max_steps"] is not None else None 108 | 109 | trainer = pl.Trainer( 110 | plugins=[MyCluster(), MyDDPPlugin()], 111 | gpus=_config["num_gpus"], 112 | num_nodes=_config["num_nodes"], 113 | precision=_config["precision"], 114 | accelerator="ddp", 115 | benchmark=True, 116 | deterministic=True, 117 | max_epochs=_config["max_epoch"] if max_steps is None else 1000, 118 | max_steps=max_steps, 119 | callbacks=callbacks, 120 | logger=logger, 121 | prepare_data_per_node=False, 122 | replace_sampler_ddp=False, 123 | accumulate_grad_batches=grad_steps, 124 | log_every_n_steps=10, 125 | flush_logs_every_n_steps=10, 126 | resume_from_checkpoint=_config["resume_from"], 127 | weights_summary="top", 128 | fast_dev_run=_config["fast_dev_run"], 129 | val_check_interval=_config["val_check_interval"], 130 | ) 131 | 132 | if not _config["test_only"]: 133 | trainer.fit(model, datamodule=dm) 134 | else: 135 | trainer.test(model, datamodule=dm) 136 | -------------------------------------------------------------------------------- /meter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zdou0830/METER/f4f09345b26ee21add0a756d06598e3c04726345/meter/__init__.py -------------------------------------------------------------------------------- /meter/config.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | 3 | ex = Experiment("METER") 4 | 5 | 6 | def _loss_names(d): 7 | ret = { 8 | "itm": 0, 9 | "mlm": 0, 10 | "mpp": 0, 11 | "vqa": 0, 12 | "vcr": 0, 13 | "vcr_qar": 0, 14 | "nlvr2": 0, 15 | "irtr": 0, 16 | "contras": 0, 17 | "snli": 0, 18 | } 19 | ret.update(d) 20 | return ret 21 | 22 | 23 | @ex.config 24 | def config(): 25 | exp_name = "meter" 26 | seed = 0 27 | datasets = ["coco", "vg", "sbu", "gcc"] 28 | loss_names = _loss_names({"itm": 1, "mlm": 1}) 29 | batch_size = 4096 # this is a desired batch size; pl trainer will accumulate gradients when per step batch is smaller. 
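
    # Worked example (illustration only): the launcher scripts (see azure_distributed_run.py) compute
    #   accumulate_grad_batches = max(batch_size // (per_gpu_batchsize * num_gpus * num_nodes), 1)
    # so batch_size=4096 with per_gpu_batchsize=32, num_gpus=8, num_nodes=1 accumulates over
    # 4096 // (32 * 8 * 1) = 16 steps to reach the desired effective batch size.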
30 | 31 | # Image setting 32 | train_transform_keys = ["clip"] 33 | val_transform_keys = ["clip"] 34 | image_size = 224 35 | patch_size = 32 36 | draw_false_image = 1 37 | image_only = False 38 | resolution_before = 224 39 | 40 | # Text Setting 41 | vqav2_label_size = 3129 42 | max_text_len = 40 43 | tokenizer = "bert-base-uncased" 44 | vocab_size = 30522 45 | whole_word_masking = False # note that whole_word_masking does not work for RoBERTa 46 | mlm_prob = 0.15 47 | draw_false_text = 0 48 | 49 | # Transformer Setting 50 | num_top_layer = 6 51 | input_image_embed_size = 768 52 | input_text_embed_size = 768 53 | vit = 'ViT-B/32' 54 | hidden_size = 768 55 | num_heads = 12 56 | num_layers = 6 57 | mlp_ratio = 4 58 | drop_rate = 0.1 59 | 60 | # Optimizer Setting 61 | optim_type = "adamw" 62 | learning_rate = 1e-5 63 | weight_decay = 0.01 64 | decay_power = 1 65 | max_epoch = 100 66 | max_steps = 100000 67 | warmup_steps = 10000 68 | end_lr = 0 69 | lr_mult_head = 5 # multiply lr for downstream heads 70 | lr_mult_cross_modal = 5 # multiply lr for the cross-modal module 71 | 72 | # Downstream Setting 73 | get_recall_metric = False 74 | 75 | # PL Trainer Setting 76 | resume_from = None 77 | fast_dev_run = False 78 | val_check_interval = 1.0 79 | test_only = False 80 | 81 | # below params varies with the environment 82 | data_root = "" 83 | log_dir = "result" 84 | per_gpu_batchsize = 0 # you should define this manually with per_gpu_batch_size=# 85 | num_gpus = 8 86 | num_nodes = 1 87 | load_path = "" 88 | num_workers = 8 89 | precision = 32 90 | 91 | 92 | @ex.named_config 93 | def task_mlm_itm_clip_bert(): 94 | exp_name = "mlm_itm" 95 | datasets = ["coco", "vg", "sbu", "gcc"] 96 | loss_names = _loss_names({"itm": 1, "mlm": 1}) 97 | batch_size = 4096 98 | max_epoch = 10 99 | max_steps = 100000 100 | warmup_steps = 0.1 101 | whole_word_masking = True 102 | 103 | vocab_size = 30522 104 | max_text_len = 50 105 | image_size = 224 106 | tokenizer = "bert-base-uncased" 107 | train_transform_keys = ["clip"] 108 | val_transform_keys = ["clip"] 109 | learning_rate = 1e-5 110 | val_check_interval = 1.0 111 | lr_mult_head = 5 112 | lr_mult_cross_modal = 5 113 | num_top_layer = 6 114 | hidden_size = 768 115 | num_heads = 12 116 | 117 | @ex.named_config 118 | def task_finetune_nlvr2_clip_bert(): 119 | exp_name = "finetune_nlvr2" 120 | datasets = ["nlvr2"] 121 | loss_names = _loss_names({"nlvr2": 1}) 122 | batch_size = 256 123 | max_epoch = 10 124 | max_steps = None 125 | warmup_steps = 0.1 126 | draw_false_image = 0 127 | learning_rate = 1e-5 128 | lr_mult_head = 10 129 | lr_mult_cross_modal = 5 130 | tokenizer = "bert-base-uncased" 131 | max_text_len = 50 132 | input_text_embed_size = 768 133 | vit = 'ViT-B/32' 134 | train_transform_keys = ["clip"] 135 | val_transform_keys = ["clip"] 136 | input_image_embed_size = 768 137 | image_size = 288 138 | 139 | @ex.named_config 140 | def task_finetune_vqa_clip_bert(): 141 | exp_name = "finetune_vqa" 142 | datasets = ["vqa"] 143 | loss_names = _loss_names({"vqa": 1}) 144 | batch_size = 512 145 | max_epoch = 10 146 | max_steps = None 147 | warmup_steps = 0.1 148 | draw_false_image = 0 149 | learning_rate = 5e-6 150 | val_check_interval = 0.1 151 | lr_mult_head = 50 152 | lr_mult_cross_modal = 5 153 | tokenizer = "bert-base-uncased" 154 | max_text_len = 50 155 | input_text_embed_size = 768 156 | vit = 'ViT-B/32' 157 | train_transform_keys = ["clip"] 158 | val_transform_keys = ["clip"] 159 | input_image_embed_size = 768 160 | image_size = 576 161 | 162 | @ex.named_config 
163 | def task_finetune_irtr_coco_clip_bert(): 164 | exp_name = "finetune_irtr_coco" 165 | datasets = ["coco"] 166 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 167 | batch_size = 512 168 | max_epoch = 10 169 | max_steps = None 170 | warmup_steps = 0.1 171 | get_recall_metric = True 172 | draw_false_text = 15 173 | learning_rate = 5e-6 174 | lr_mult_head = 5 175 | lr_mult_cross_modal = 5 176 | tokenizer = "bert-base-uncased" 177 | input_text_embed_size = 768 178 | vit = 'ViT-B/32' 179 | train_transform_keys = ["clip"] 180 | val_transform_keys = ["clip"] 181 | input_image_embed_size = 768 182 | image_size = 384 183 | 184 | @ex.named_config 185 | def task_finetune_irtr_f30k_clip_bert(): 186 | exp_name = "finetune_irtr_f30k" 187 | datasets = ["f30k"] 188 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 189 | batch_size = 512 190 | max_epoch = 10 191 | max_steps = None 192 | warmup_steps = 0.1 193 | get_recall_metric = True 194 | draw_false_text = 15 195 | learning_rate = 5e-6 196 | lr_mult_head = 5 197 | lr_mult_cross_modal = 5 198 | tokenizer = "bert-base-uncased" 199 | input_text_embed_size = 768 200 | vit = 'ViT-B/32' 201 | train_transform_keys = ["clip"] 202 | val_transform_keys = ["clip"] 203 | input_image_embed_size = 768 204 | image_size = 384 205 | 206 | @ex.named_config 207 | def task_finetune_snli_clip_bert(): 208 | exp_name = "finetune_snli" 209 | datasets = ["snli"] 210 | loss_names = _loss_names({"snli": 1}) 211 | batch_size = 64 212 | max_epoch = 5 213 | max_steps = None 214 | warmup_steps = 0.1 215 | draw_false_image = 0 216 | learning_rate = 2e-6 217 | lr_mult_head = 10 218 | lr_mult_cross_modal = 5 219 | tokenizer = "bert-base-uncased" 220 | max_text_len = 50 221 | input_text_embed_size = 768 222 | vit = 'ViT-B/32' 223 | train_transform_keys = ["clip"] 224 | val_transform_keys = ["clip"] 225 | input_image_embed_size = 768 226 | image_size = 384 227 | 228 | 229 | # Named configs for "etc" which are orthogonal to "env" and "task", need to be added at the end 230 | 231 | # vision encoder 232 | @ex.named_config 233 | def swin32_base224(): 234 | vit = "swin_base_patch4_window7_224_in22k" 235 | patch_size = 32 236 | image_size = 224 237 | train_transform_keys = ["imagenet"] 238 | val_transform_keys = ["imagenet"] 239 | input_image_embed_size = 1024 240 | resolution_before = 224 241 | 242 | @ex.named_config 243 | def swin32_base384(): 244 | vit = "swin_base_patch4_window12_384_in22k" 245 | patch_size = 32 246 | image_size = 384 247 | train_transform_keys = ["imagenet"] 248 | val_transform_keys = ["imagenet"] 249 | input_image_embed_size = 1024 250 | resolution_before = 384 251 | 252 | @ex.named_config 253 | def swin32_large384(): 254 | vit = "swin_large_patch4_window12_384_in22k" 255 | patch_size = 32 256 | image_size = 384 257 | train_transform_keys = ["imagenet"] 258 | val_transform_keys = ["imagenet"] 259 | input_image_embed_size = 1536 260 | resolution_before = 384 261 | 262 | @ex.named_config 263 | def clip32(): 264 | vit = 'ViT-B/32' 265 | image_size = 224 266 | patch_size = 32 267 | train_transform_keys = ["clip"] 268 | val_transform_keys = ["clip"] 269 | input_image_embed_size = 768 270 | 271 | @ex.named_config 272 | def clip16(): 273 | vit = 'ViT-B/16' 274 | image_size = 224 275 | patch_size = 16 276 | train_transform_keys = ["clip"] 277 | val_transform_keys = ["clip"] 278 | input_image_embed_size = 768 279 | 280 | # text encoder 281 | @ex.named_config 282 | def text_roberta(): 283 | tokenizer = "roberta-base" 284 | vocab_size = 50265 285 | input_text_embed_size = 
768 286 | 287 | @ex.named_config 288 | def text_roberta_large(): 289 | tokenizer = "roberta-large" 290 | vocab_size = 50265 291 | input_text_embed_size = 1024 292 | 293 | # random augmentation 294 | @ex.named_config 295 | def imagenet_randaug(): 296 | train_transform_keys = ["imagenet_randaug"] 297 | 298 | @ex.named_config 299 | def clip_randaug(): 300 | train_transform_keys = ["clip_randaug"] 301 | -------------------------------------------------------------------------------- /meter/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .vg_caption_datamodule import VisualGenomeCaptionDataModule 2 | from .f30k_caption_karpathy_datamodule import F30KCaptionKarpathyDataModule 3 | from .coco_caption_karpathy_datamodule import CocoCaptionKarpathyDataModule 4 | from .conceptual_caption_datamodule import ConceptualCaptionDataModule 5 | from .sbu_datamodule import SBUCaptionDataModule 6 | from .vqav2_datamodule import VQAv2DataModule 7 | from .nlvr2_datamodule import NLVR2DataModule 8 | from .snli_datamodule import SNLIDataModule 9 | 10 | _datamodules = { 11 | "vg": VisualGenomeCaptionDataModule, 12 | "f30k": F30KCaptionKarpathyDataModule, 13 | "coco": CocoCaptionKarpathyDataModule, 14 | "gcc": ConceptualCaptionDataModule, 15 | "sbu": SBUCaptionDataModule, 16 | "vqa": VQAv2DataModule, 17 | "nlvr2": NLVR2DataModule, 18 | "snli": SNLIDataModule, 19 | } 20 | -------------------------------------------------------------------------------- /meter/datamodules/coco_caption_karpathy_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import CocoCaptionKarpathyDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class CocoCaptionKarpathyDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return CocoCaptionKarpathyDataset 12 | 13 | @property 14 | def dataset_cls_no_false(self): 15 | return CocoCaptionKarpathyDataset 16 | 17 | @property 18 | def dataset_name(self): 19 | return "coco" 20 | -------------------------------------------------------------------------------- /meter/datamodules/conceptual_caption_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import ConceptualCaptionDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class ConceptualCaptionDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return ConceptualCaptionDataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "gcc" 16 | -------------------------------------------------------------------------------- /meter/datamodules/datamodule_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from transformers import ( 6 | DataCollatorForLanguageModeling, 7 | DataCollatorForWholeWordMask, 8 | BertTokenizer, 9 | RobertaTokenizer, 10 | ) 11 | 12 | 13 | def get_pretrained_tokenizer(from_pretrained): 14 | if torch.distributed.is_initialized(): 15 | if torch.distributed.get_rank() == 0: 16 | if 'roberta' in from_pretrained: 17 | RobertaTokenizer.from_pretrained(from_pretrained) 18 | else: 19 | BertTokenizer.from_pretrained( 20 | from_pretrained, 
do_lower_case="uncased" in from_pretrained 21 | ) 22 | torch.distributed.barrier() 23 | 24 | if 'roberta' in from_pretrained: 25 | return RobertaTokenizer.from_pretrained(from_pretrained) 26 | return BertTokenizer.from_pretrained( 27 | from_pretrained, do_lower_case="uncased" in from_pretrained 28 | ) 29 | 30 | 31 | class BaseDataModule(LightningDataModule): 32 | def __init__(self, _config): 33 | super().__init__() 34 | 35 | self.data_dir = _config["data_root"] 36 | 37 | self.num_workers = _config["num_workers"] 38 | self.batch_size = _config["per_gpu_batchsize"] 39 | self.eval_batch_size = self.batch_size 40 | 41 | self.image_size = _config["image_size"] 42 | self.max_text_len = _config["max_text_len"] 43 | self.draw_false_image = _config["draw_false_image"] 44 | self.draw_false_text = _config["draw_false_text"] 45 | self.image_only = _config["image_only"] 46 | 47 | self.train_transform_keys = ( 48 | ["default_train"] 49 | if len(_config["train_transform_keys"]) == 0 50 | else _config["train_transform_keys"] 51 | ) 52 | 53 | self.val_transform_keys = ( 54 | ["default_val"] 55 | if len(_config["val_transform_keys"]) == 0 56 | else _config["val_transform_keys"] 57 | ) 58 | 59 | tokenizer = _config["tokenizer"] 60 | self.tokenizer = get_pretrained_tokenizer(tokenizer) 61 | self.vocab_size = self.tokenizer.vocab_size 62 | 63 | collator = ( 64 | DataCollatorForWholeWordMask 65 | if _config["whole_word_masking"] 66 | else DataCollatorForLanguageModeling 67 | ) 68 | 69 | self.mlm_collator = collator( 70 | tokenizer=self.tokenizer, mlm=True, mlm_probability=_config["mlm_prob"] 71 | ) 72 | self.setup_flag = False 73 | 74 | @property 75 | def dataset_cls(self): 76 | raise NotImplementedError("return tuple of dataset class") 77 | 78 | @property 79 | def dataset_name(self): 80 | raise NotImplementedError("return name of dataset") 81 | 82 | def set_train_dataset(self): 83 | self.train_dataset = self.dataset_cls( 84 | self.data_dir, 85 | self.train_transform_keys, 86 | split="train", 87 | image_size=self.image_size, 88 | max_text_len=self.max_text_len, 89 | draw_false_image=self.draw_false_image, 90 | draw_false_text=self.draw_false_text, 91 | image_only=self.image_only, 92 | tokenizer=self.tokenizer, 93 | ) 94 | 95 | def set_val_dataset(self): 96 | self.val_dataset = self.dataset_cls( 97 | self.data_dir, 98 | self.val_transform_keys, 99 | split="val", 100 | image_size=self.image_size, 101 | max_text_len=self.max_text_len, 102 | draw_false_image=self.draw_false_image, 103 | draw_false_text=self.draw_false_text, 104 | image_only=self.image_only, 105 | tokenizer=self.tokenizer, 106 | ) 107 | 108 | if hasattr(self, "dataset_cls_no_false"): 109 | self.val_dataset_no_false = self.dataset_cls_no_false( 110 | self.data_dir, 111 | self.val_transform_keys, 112 | split="val", 113 | image_size=self.image_size, 114 | max_text_len=self.max_text_len, 115 | draw_false_image=0, 116 | draw_false_text=0, 117 | image_only=self.image_only, 118 | tokenizer=self.tokenizer, 119 | ) 120 | 121 | def make_no_false_val_dset(self, image_only=False): 122 | return self.dataset_cls_no_false( 123 | self.data_dir, 124 | self.val_transform_keys, 125 | split="val", 126 | image_size=self.image_size, 127 | max_text_len=self.max_text_len, 128 | draw_false_image=0, 129 | draw_false_text=0, 130 | image_only=image_only, 131 | tokenizer=self.tokenizer, 132 | ) 133 | 134 | def set_test_dataset(self): 135 | self.test_dataset = self.dataset_cls( 136 | self.data_dir, 137 | self.val_transform_keys, 138 | split="test", 139 | 
image_size=self.image_size, 140 | max_text_len=self.max_text_len, 141 | draw_false_image=self.draw_false_image, 142 | draw_false_text=self.draw_false_text, 143 | image_only=self.image_only, 144 | tokenizer=self.tokenizer, 145 | ) 146 | 147 | def setup(self, stage): 148 | if not self.setup_flag: 149 | self.set_train_dataset() 150 | self.set_val_dataset() 151 | self.set_test_dataset() 152 | 153 | self.train_dataset.tokenizer = self.tokenizer 154 | self.val_dataset.tokenizer = self.tokenizer 155 | self.test_dataset.tokenizer = self.tokenizer 156 | 157 | self.setup_flag = True 158 | 159 | def train_dataloader(self): 160 | loader = DataLoader( 161 | self.train_dataset, 162 | batch_size=self.batch_size, 163 | shuffle=True, 164 | num_workers=self.num_workers, 165 | pin_memory=True, 166 | collate_fn=self.train_dataset.collate, 167 | ) 168 | return loader 169 | 170 | def val_dataloader(self): 171 | loader = DataLoader( 172 | self.val_dataset, 173 | batch_size=self.eval_batch_size, 174 | shuffle=False, 175 | num_workers=self.num_workers, 176 | pin_memory=True, 177 | collate_fn=self.val_dataset.collate, 178 | ) 179 | return loader 180 | 181 | def test_dataloader(self): 182 | loader = DataLoader( 183 | self.test_dataset, 184 | batch_size=self.eval_batch_size, 185 | shuffle=False, 186 | num_workers=self.num_workers, 187 | pin_memory=True, 188 | collate_fn=self.test_dataset.collate, 189 | ) 190 | return loader 191 | -------------------------------------------------------------------------------- /meter/datamodules/f30k_caption_karpathy_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import F30KCaptionKarpathyDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class F30KCaptionKarpathyDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return F30KCaptionKarpathyDataset 12 | 13 | @property 14 | def dataset_cls_no_false(self): 15 | return F30KCaptionKarpathyDataset 16 | 17 | @property 18 | def dataset_name(self): 19 | return "f30k" 20 | 21 | def train_dataloader(self): 22 | loader = DataLoader( 23 | self.train_dataset, 24 | batch_size=self.batch_size, 25 | shuffle=True, 26 | num_workers=0, 27 | pin_memory=True, 28 | collate_fn=self.train_dataset.collate, 29 | ) 30 | return loader 31 | 32 | def val_dataloader(self): 33 | loader = DataLoader( 34 | self.val_dataset, 35 | batch_size=self.eval_batch_size, 36 | shuffle=False, 37 | num_workers=0, 38 | pin_memory=True, 39 | collate_fn=self.val_dataset.collate, 40 | ) 41 | return loader 42 | 43 | def test_dataloader(self): 44 | loader = DataLoader( 45 | self.test_dataset, 46 | batch_size=self.eval_batch_size, 47 | shuffle=False, 48 | num_workers=0, 49 | pin_memory=True, 50 | collate_fn=self.test_dataset.collate, 51 | ) 52 | return loader 53 | -------------------------------------------------------------------------------- /meter/datamodules/multitask_datamodule.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.dataset import ConcatDataset 6 | from torch.utils.data.distributed import DistributedSampler 7 | 8 | from . 
import _datamodules 9 | 10 | 11 | class MTDataModule(LightningDataModule): 12 | def __init__(self, _config, dist=False): 13 | datamodule_keys = _config["datasets"] 14 | assert len(datamodule_keys) > 0 15 | 16 | super().__init__() 17 | 18 | self.dm_keys = datamodule_keys 19 | self.dm_dicts = {key: _datamodules[key](_config) for key in datamodule_keys} 20 | self.dms = [v for k, v in self.dm_dicts.items()] 21 | 22 | self.batch_size = self.dms[0].batch_size 23 | self.vocab_size = self.dms[0].vocab_size 24 | self.num_workers = self.dms[0].num_workers 25 | 26 | self.dist = dist 27 | 28 | def prepare_data(self): 29 | for dm in self.dms: 30 | dm.prepare_data() 31 | 32 | def setup(self, stage): 33 | for dm in self.dms: 34 | dm.setup(stage) 35 | 36 | self.train_dataset = ConcatDataset([dm.train_dataset for dm in self.dms]) 37 | self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.dms]) 38 | self.test_dataset = ConcatDataset([dm.test_dataset for dm in self.dms]) 39 | self.tokenizer = self.dms[0].tokenizer 40 | 41 | self.collate = functools.partial( 42 | self.dms[0].train_dataset.collate, mlm_collator=self.dms[0].mlm_collator, 43 | ) 44 | 45 | if self.dist: 46 | self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True) 47 | self.val_sampler = DistributedSampler(self.val_dataset, shuffle=True) 48 | self.test_sampler = DistributedSampler(self.test_dataset, shuffle=False) 49 | else: 50 | self.train_sampler = None 51 | self.val_sampler = None 52 | self.test_sampler = None 53 | 54 | def train_dataloader(self): 55 | loader = DataLoader( 56 | self.train_dataset, 57 | batch_size=self.batch_size, 58 | sampler=self.train_sampler, 59 | num_workers=self.num_workers, 60 | collate_fn=self.collate, 61 | ) 62 | return loader 63 | 64 | def val_dataloader(self, batch_size=None): 65 | loader = DataLoader( 66 | self.val_dataset, 67 | batch_size=batch_size if batch_size is not None else self.batch_size, 68 | sampler=self.val_sampler, 69 | num_workers=self.num_workers, 70 | collate_fn=self.collate, 71 | ) 72 | return loader 73 | 74 | def test_dataloader(self): 75 | loader = DataLoader( 76 | self.test_dataset, 77 | batch_size=self.batch_size, 78 | sampler=self.test_sampler, 79 | num_workers=self.num_workers, 80 | collate_fn=self.collate, 81 | ) 82 | return loader 83 | -------------------------------------------------------------------------------- /meter/datamodules/nlvr2_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import NLVR2Dataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class NLVR2DataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return NLVR2Dataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "nlvr2" 16 | -------------------------------------------------------------------------------- /meter/datamodules/sbu_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import SBUCaptionDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class SBUCaptionDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return SBUCaptionDataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "sbu" 16 | -------------------------------------------------------------------------------- 
/meter/datamodules/snli_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import SNLIDataset 2 | from .datamodule_base import BaseDataModule 3 | from collections import defaultdict 4 | 5 | 6 | class SNLIDataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return SNLIDataset 13 | 14 | @property 15 | def dataset_name(self): 16 | return "snli" 17 | -------------------------------------------------------------------------------- /meter/datamodules/vg_caption_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import VisualGenomeCaptionDataset 2 | from .datamodule_base import BaseDataModule 3 | 4 | 5 | class VisualGenomeCaptionDataModule(BaseDataModule): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | @property 10 | def dataset_cls(self): 11 | return VisualGenomeCaptionDataset 12 | 13 | @property 14 | def dataset_name(self): 15 | return "vg" 16 | -------------------------------------------------------------------------------- /meter/datamodules/vqav2_datamodule.py: -------------------------------------------------------------------------------- 1 | from ..datasets import VQAv2Dataset 2 | from .datamodule_base import BaseDataModule 3 | from collections import defaultdict 4 | 5 | 6 | class VQAv2DataModule(BaseDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | @property 11 | def dataset_cls(self): 12 | return VQAv2Dataset 13 | 14 | @property 15 | def dataset_name(self): 16 | return "vqa" 17 | 18 | def setup(self, stage): 19 | super().setup(stage) 20 | 21 | train_answers = self.train_dataset.table["answers"].to_pandas().tolist() 22 | val_answers = self.val_dataset.table["answers"].to_pandas().tolist() 23 | train_labels = self.train_dataset.table["answer_labels"].to_pandas().tolist() 24 | val_labels = self.val_dataset.table["answer_labels"].to_pandas().tolist() 25 | 26 | all_answers = [c for c in train_answers + val_answers if c is not None] 27 | all_answers = [l for lll in all_answers for ll in lll for l in ll] 28 | all_labels = [c for c in train_labels + val_labels if c is not None] 29 | all_labels = [l for lll in all_labels for ll in lll for l in ll] 30 | 31 | self.answer2id = {k: v for k, v in zip(all_answers, all_labels)} 32 | sorted_a2i = sorted(self.answer2id.items(), key=lambda x: x[1]) 33 | self.num_class = max(self.answer2id.values()) + 1 34 | 35 | self.id2answer = defaultdict(lambda: "unknown") 36 | for k, v in sorted_a2i: 37 | self.id2answer[v] = k 38 | -------------------------------------------------------------------------------- /meter/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .vg_caption_dataset import VisualGenomeCaptionDataset 2 | from .coco_caption_karpathy_dataset import CocoCaptionKarpathyDataset 3 | from .f30k_caption_karpathy_dataset import F30KCaptionKarpathyDataset 4 | from .conceptual_caption_dataset import ConceptualCaptionDataset 5 | from .sbu_caption_dataset import SBUCaptionDataset 6 | from .vqav2_dataset import VQAv2Dataset 7 | from .nlvr2_dataset import NLVR2Dataset 8 | from .snli_dataset import SNLIDataset 9 | -------------------------------------------------------------------------------- /meter/datasets/base_dataset.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import io 4 | import pyarrow as pa 5 | import os 6 | 7 | from PIL import Image 8 | from ..transforms import keys_to_transforms 9 | 10 | 11 | class BaseDataset(torch.utils.data.Dataset): 12 | def __init__( 13 | self, 14 | data_dir: str, 15 | transform_keys: list, 16 | image_size: int, 17 | names: list, 18 | text_column_name: str = "", 19 | remove_duplicate=True, 20 | max_text_len=40, 21 | draw_false_image=0, 22 | draw_false_text=0, 23 | image_only=False, 24 | tokenizer=None, 25 | ): 26 | """ 27 | data_dir : where dataset file *.arrow lives; existence should be guaranteed via DataModule.prepare_data 28 | transform_keys : keys for generating augmented views of images 29 | text_column_name : pyarrow table column name that has list of strings as elements 30 | """ 31 | assert len(transform_keys) >= 1 32 | super().__init__() 33 | 34 | self.transforms = keys_to_transforms(transform_keys, size=image_size) 35 | self.clip_transform = False 36 | for transform_key in transform_keys: 37 | if 'clip' in transform_key: 38 | self.clip_transform = True 39 | break 40 | self.text_column_name = text_column_name 41 | self.names = names 42 | self.max_text_len = max_text_len 43 | self.draw_false_image = draw_false_image 44 | self.draw_false_text = draw_false_text 45 | self.image_only = image_only 46 | self.data_dir = data_dir 47 | 48 | if len(names) != 0: 49 | tables = [ 50 | pa.ipc.RecordBatchFileReader( 51 | pa.memory_map(f"{data_dir}/{name}.arrow", "r") 52 | ).read_all() 53 | for name in names 54 | if os.path.isfile(f"{data_dir}/{name}.arrow") 55 | ] 56 | 57 | self.table_names = list() 58 | for i, name in enumerate(names): 59 | self.table_names += [name] * len(tables[i]) 60 | 61 | self.table = pa.concat_tables(tables, promote=True) 62 | if text_column_name != "": 63 | self.text_column_name = text_column_name 64 | self.all_texts = self.table[text_column_name].to_pandas().tolist() 65 | if type(self.all_texts[0][0]) == str: 66 | self.all_texts = ( 67 | [list(set(texts)) for texts in self.all_texts] 68 | if remove_duplicate 69 | else self.all_texts 70 | ) 71 | else: #snli 72 | self.all_texts = ( 73 | [[t[1].strip() for t in texts] for texts in self.all_texts] 74 | ) 75 | else: 76 | self.all_texts = list() 77 | else: 78 | self.all_texts = list() 79 | 80 | self.index_mapper = dict() 81 | 82 | if text_column_name != "" and not self.image_only: 83 | j = 0 84 | for i, texts in enumerate(self.all_texts): 85 | for _j in range(len(texts)): 86 | self.index_mapper[j] = (i, _j) 87 | j += 1 88 | else: 89 | for i in range(len(self.table)): 90 | self.index_mapper[i] = (i, None) 91 | 92 | @property 93 | def corpus(self): 94 | return [text for texts in self.all_texts for text in texts] 95 | 96 | def __len__(self): 97 | return len(self.index_mapper) 98 | 99 | def get_raw_image(self, index, image_key="image"): 100 | index, caption_index = self.index_mapper[index] 101 | image_bytes = io.BytesIO(self.table[image_key][index].as_py()) 102 | image_bytes.seek(0) 103 | if self.clip_transform: 104 | return Image.open(image_bytes).convert("RGBA") 105 | else: 106 | return Image.open(image_bytes).convert("RGB") 107 | 108 | def get_image(self, index, image_key="image"): 109 | image = self.get_raw_image(index, image_key=image_key) 110 | image_tensor = [tr(image) for tr in self.transforms] 111 | return { 112 | "image": image_tensor, 113 | "img_index": self.index_mapper[index][0], 114 | "cap_index": self.index_mapper[index][1], 
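
            # index_mapper maps the flat dataset index to (row in the arrow table,
            # caption position within that row); raw_index keeps the flat index itself.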
115 | "raw_index": index, 116 | } 117 | 118 | def get_false_image(self, rep, image_key="image"): 119 | random_index = random.randint(0, len(self.index_mapper) - 1) 120 | image = self.get_raw_image(random_index, image_key=image_key) 121 | image_tensor = [tr(image) for tr in self.transforms] 122 | return {f"false_image_{rep}": image_tensor} 123 | 124 | def get_text(self, raw_index): 125 | index, caption_index = self.index_mapper[raw_index] 126 | 127 | text = self.all_texts[index][caption_index] 128 | encoding = self.tokenizer( 129 | text, 130 | padding="max_length", 131 | truncation=True, 132 | max_length=self.max_text_len, 133 | return_special_tokens_mask=True, 134 | ) 135 | return { 136 | "text": (text, encoding), 137 | "img_index": index, 138 | "cap_index": caption_index, 139 | "raw_index": raw_index, 140 | } 141 | 142 | def get_false_text(self, rep): 143 | random_index = random.randint(0, len(self.index_mapper) - 1) 144 | 145 | index, caption_index = self.index_mapper[random_index] 146 | text = self.all_texts[index][caption_index] 147 | encoding = self.tokenizer( 148 | text, 149 | truncation=True, 150 | max_length=self.max_text_len, 151 | return_special_tokens_mask=True, 152 | ) 153 | return {f"false_text_{rep}": (text, encoding)} 154 | 155 | def get_suite(self, index): 156 | result = None 157 | while result is None: 158 | try: 159 | ret = dict() 160 | ret.update(self.get_image(index)) 161 | if not self.image_only: 162 | txt = self.get_text(index) 163 | ret.update({"replica": True if txt["cap_index"] > 0 else False}) 164 | ret.update(txt) 165 | 166 | for i in range(self.draw_false_image): 167 | ret.update(self.get_false_image(i)) 168 | for i in range(self.draw_false_text): 169 | ret.update(self.get_false_text(i)) 170 | result = True 171 | except Exception as e: 172 | print(f"Error while read file idx {index} in {self.names[0]} -> {e}") 173 | index = random.randint(0, len(self.index_mapper) - 1) 174 | return ret 175 | 176 | def collate(self, batch, mlm_collator): 177 | batch_size = len(batch) 178 | keys = set([key for b in batch for key in b.keys()]) 179 | dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys} 180 | 181 | img_keys = [k for k in list(dict_batch.keys()) if "image" in k] 182 | img_sizes = list() 183 | 184 | for img_key in img_keys: 185 | img = dict_batch[img_key] 186 | img_sizes += [ii.shape for i in img if i is not None for ii in i] 187 | 188 | for size in img_sizes: 189 | assert ( 190 | len(size) == 3 191 | ), f"Collate error, an image should be in shape of (3, H, W), instead of given {size}" 192 | 193 | if len(img_keys) != 0: 194 | max_height = max([i[1] for i in img_sizes]) 195 | max_width = max([i[2] for i in img_sizes]) 196 | 197 | for img_key in img_keys: 198 | img = dict_batch[img_key] 199 | view_size = len(img[0]) 200 | 201 | new_images = [ 202 | torch.zeros(batch_size, 3, max_height, max_width) 203 | for _ in range(view_size) 204 | ] 205 | 206 | for bi in range(batch_size): 207 | orig_batch = img[bi] 208 | for vi in range(view_size): 209 | if orig_batch is None: 210 | new_images[vi][bi] = None 211 | else: 212 | orig = img[bi][vi] 213 | new_images[vi][bi, :, : orig.shape[1], : orig.shape[2]] = orig 214 | 215 | dict_batch[img_key] = new_images 216 | 217 | txt_keys = [k for k in list(dict_batch.keys()) if "text" in k] 218 | 219 | if len(txt_keys) != 0: 220 | texts = [[d[0] for d in dict_batch[txt_key]] for txt_key in txt_keys] 221 | encodings = [[d[1] for d in dict_batch[txt_key]] for txt_key in txt_keys] 222 | draw_text_len = len(encodings) 
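# Note: the per-sample encodings for every text key are flattened into one
# list so that mlm_collator can mask them in a single call; the masked
# input_ids / labels are then sliced back out per text key, batch_size rows
# at a time, in the loop below.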
223 | flatten_encodings = [e for encoding in encodings for e in encoding] 224 | flatten_mlms = mlm_collator(flatten_encodings) 225 | 226 | for i, txt_key in enumerate(txt_keys): 227 | texts, encodings = ( 228 | [d[0] for d in dict_batch[txt_key]], 229 | [d[1] for d in dict_batch[txt_key]], 230 | ) 231 | 232 | mlm_ids, mlm_labels = ( 233 | flatten_mlms["input_ids"][batch_size * (i) : batch_size * (i + 1)], 234 | flatten_mlms["labels"][batch_size * (i) : batch_size * (i + 1)], 235 | ) 236 | 237 | input_ids = torch.zeros_like(mlm_ids) 238 | attention_mask = torch.zeros_like(mlm_ids) 239 | for _i, encoding in enumerate(encodings): 240 | _input_ids, _attention_mask = ( 241 | torch.tensor(encoding["input_ids"]), 242 | torch.tensor(encoding["attention_mask"]), 243 | ) 244 | input_ids[_i, : len(_input_ids)] = _input_ids 245 | attention_mask[_i, : len(_attention_mask)] = _attention_mask 246 | 247 | dict_batch[txt_key] = texts 248 | dict_batch[f"{txt_key}_ids"] = input_ids 249 | dict_batch[f"{txt_key}_labels"] = torch.full_like(input_ids, -100) 250 | dict_batch[f"{txt_key}_ids_mlm"] = mlm_ids 251 | dict_batch[f"{txt_key}_labels_mlm"] = mlm_labels 252 | dict_batch[f"{txt_key}_masks"] = attention_mask 253 | 254 | return dict_batch 255 | -------------------------------------------------------------------------------- /meter/datasets/coco_caption_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import io 3 | from PIL import Image 4 | 5 | class CocoCaptionKarpathyDataset(BaseDataset): 6 | def __init__(self, *args, split="", **kwargs): 7 | assert split in ["train", "val", "test"] 8 | self.split = split 9 | 10 | if split == "train": 11 | names = ["coco_caption_karpathy_train", "coco_caption_karpathy_restval"] 12 | elif split == "val": 13 | # names = ["coco_caption_karpathy_val"] 14 | names = ["coco_caption_karpathy_test"] 15 | elif split == "test": 16 | names = ["coco_caption_karpathy_test"] 17 | 18 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 19 | 20 | 21 | def __getitem__(self, index): 22 | suite = self.get_suite(index) 23 | 24 | if "test" in self.split: 25 | _index, _question_index = self.index_mapper[index] 26 | iid = self.table["image_id"][_index].as_py() 27 | iid = int(iid.split(".")[0].split("_")[-1]) 28 | suite.update({"iid": iid}) 29 | 30 | return suite 31 | -------------------------------------------------------------------------------- /meter/datasets/conceptual_caption_dataset.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from .base_dataset import BaseDataset 3 | import io 4 | from PIL import Image 5 | 6 | 7 | class ConceptualCaptionDataset(BaseDataset): 8 | def __init__(self, *args, split="", **kwargs): 9 | assert split in ["train", "val", "test"] 10 | if split == "test": 11 | split = "val" 12 | 13 | if split == "train": 14 | names = [f"conceptual_caption_train_{i}" for i in range(31)] 15 | elif split == "val": 16 | names = [] 17 | 18 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 19 | 20 | 21 | def __getitem__(self, index): 22 | return self.get_suite(index) 23 | 24 | def get_text(self, raw_index): 25 | index, caption_index = self.index_mapper[raw_index] 26 | 27 | text = self.all_texts[index][caption_index] 28 | encoding = self.tokenizer( 29 | text, 30 | padding="max_length", 31 | truncation=True, 32 | max_length=self.max_text_len, 33 | 
return_special_tokens_mask=True, 34 | ) 35 | return { 36 | "text": (text, encoding), 37 | "img_index": index, 38 | "cap_index": caption_index, 39 | "raw_index": raw_index, 40 | } 41 | -------------------------------------------------------------------------------- /meter/datasets/f30k_caption_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | 3 | 4 | class F30KCaptionKarpathyDataset(BaseDataset): 5 | def __init__(self, *args, split="", **kwargs): 6 | assert split in ["train", "val", "test"] 7 | 8 | if split == "train": 9 | names = ["f30k_caption_karpathy_train", "f30k_caption_karpathy_val"] 10 | elif split == "val": 11 | names = ["f30k_caption_karpathy_test"] 12 | elif split == "test": 13 | names = ["f30k_caption_karpathy_test"] 14 | 15 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 16 | 17 | def __getitem__(self, index): 18 | return self.get_suite(index) 19 | -------------------------------------------------------------------------------- /meter/datasets/nlvr2_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import sys 3 | import random 4 | 5 | 6 | class NLVR2Dataset(BaseDataset): 7 | def __init__(self, *args, split="", **kwargs): 8 | assert split in ["train", "val", "test"] 9 | self.split = split 10 | 11 | if split == "train": 12 | names = ["nlvr2_train"] 13 | elif split == "val": 14 | names = ["nlvr2_dev", "nlvr2_test1"] 15 | elif split == "test": 16 | names = ["nlvr2_dev", "nlvr2_test1"] 17 | 18 | super().__init__( 19 | *args, 20 | **kwargs, 21 | names=names, 22 | text_column_name="questions", 23 | remove_duplicate=False, 24 | ) 25 | 26 | def __getitem__(self, index): 27 | result = None 28 | while result is None: 29 | try: 30 | image_tensor_0 = self.get_image(index, image_key="image_0")["image"] 31 | image_tensor_1 = self.get_image(index, image_key="image_1")["image"] 32 | text = self.get_text(index)["text"] 33 | result = True 34 | except: 35 | print( 36 | f"error while read file idx {index} in {self.names[0]}", 37 | file=sys.stderr, 38 | ) 39 | index = random.randint(0, len(self.index_mapper) - 1) 40 | 41 | index, question_index = self.index_mapper[index] 42 | answers = self.table["answers"][index][question_index].as_py() 43 | answers = answers == "True" 44 | 45 | return { 46 | "image_0": image_tensor_0, 47 | "image_1": image_tensor_1, 48 | "text": text, 49 | "answers": answers, 50 | "table_name": self.table_names[index], 51 | } 52 | -------------------------------------------------------------------------------- /meter/datasets/sbu_caption_dataset.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from .base_dataset import BaseDataset 3 | import io 4 | from PIL import Image 5 | 6 | 7 | class SBUCaptionDataset(BaseDataset): 8 | def __init__(self, *args, split="", **kwargs): 9 | assert split in ["train", "val", "test"] 10 | if split == "test": 11 | split = "val" 12 | 13 | if split == "train": 14 | names = [f"sbu_{i}" for i in range(9)] 15 | elif split == "val": 16 | names = [] 17 | 18 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 19 | 20 | def __getitem__(self, index): 21 | return self.get_suite(index) 22 | 23 | def get_text(self, raw_index): 24 | index, caption_index = self.index_mapper[raw_index] 25 | 26 | text = self.all_texts[index][caption_index] 27 | encoding = 
self.tokenizer( 28 | text, 29 | padding="max_length", 30 | truncation=True, 31 | max_length=self.max_text_len, 32 | return_special_tokens_mask=True, 33 | ) 34 | return { 35 | "text": (text, encoding), 36 | "img_index": index, 37 | "cap_index": caption_index, 38 | "raw_index": raw_index, 39 | } 40 | -------------------------------------------------------------------------------- /meter/datasets/snli_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | 3 | 4 | class SNLIDataset(BaseDataset): 5 | def __init__(self, *args, split="", **kwargs): 6 | assert split in ["train", "val", "test"] 7 | self.split = split 8 | 9 | if split == "train": 10 | names = ["snli_train"] 11 | elif split == "val": 12 | names = ["snli_dev", "snli_test"] 13 | elif split == "test": 14 | names = ["snli_dev", "snli_test"] 15 | 16 | super().__init__( 17 | *args, 18 | **kwargs, 19 | names=names, 20 | text_column_name="sentences", 21 | remove_duplicate=False, 22 | ) 23 | 24 | def __getitem__(self, index): 25 | image_tensor = self.get_image(index)["image"] 26 | text = self.get_text(index)["text"] 27 | 28 | index, question_index = self.index_mapper[index] 29 | 30 | labels = self.table["labels"][index][question_index].as_py() 31 | 32 | return { 33 | "image": image_tensor, 34 | "text": text, 35 | "labels": labels, 36 | "table_name": self.table_names[index], 37 | } 38 | -------------------------------------------------------------------------------- /meter/datasets/vg_caption_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import io 3 | from PIL import Image 4 | 5 | 6 | class VisualGenomeCaptionDataset(BaseDataset): 7 | def __init__(self, *args, split="", **kwargs): 8 | assert split in ["train", "val", "test"] 9 | if split == "test": 10 | split = "val" 11 | 12 | if split == "train": 13 | names = ["vg"] 14 | elif split == "val": 15 | names = [] 16 | 17 | super().__init__(*args, **kwargs, names=names, text_column_name="caption") 18 | 19 | def __getitem__(self, index): 20 | return self.get_suite(index) 21 | 22 | def get_text(self, raw_index): 23 | index, caption_index = self.index_mapper[raw_index] 24 | 25 | text = self.all_texts[index][caption_index] 26 | encoding = self.tokenizer( 27 | text, 28 | padding="max_length", 29 | truncation=True, 30 | max_length=self.max_text_len, 31 | return_special_tokens_mask=True, 32 | ) 33 | return { 34 | "text": (text, encoding), 35 | "img_index": index, 36 | "cap_index": caption_index, 37 | "raw_index": raw_index, 38 | } 39 | -------------------------------------------------------------------------------- /meter/datasets/vqav2_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | 3 | 4 | class VQAv2Dataset(BaseDataset): 5 | def __init__(self, *args, split="", **kwargs): 6 | assert split in ["train", "val", "test"] 7 | self.split = split 8 | 9 | if split == "train": 10 | names = ["vqav2_train", "vqav2_val"] 11 | elif split == "val": 12 | names = ["vqav2_val"] 13 | elif split == "test": 14 | names = ["vqav2_test"] 15 | 16 | super().__init__( 17 | *args, 18 | **kwargs, 19 | names=names, 20 | text_column_name="questions", 21 | remove_duplicate=False, 22 | ) 23 | 24 | def __getitem__(self, index): 25 | image_tensor = self.get_image(index)["image"] 26 | text = self.get_text(index)["text"] 27 | 28 | index, question_index = 
self.index_mapper[index] 29 | qid = self.table["question_id"][index][question_index].as_py() 30 | 31 | if self.split != "test": 32 | answers = self.table["answers"][index][question_index].as_py() 33 | labels = self.table["answer_labels"][index][question_index].as_py() 34 | scores = self.table["answer_scores"][index][question_index].as_py() 35 | else: 36 | answers = list() 37 | labels = list() 38 | scores = list() 39 | 40 | return { 41 | "image": image_tensor, 42 | "text": text, 43 | "vqa_answer": answers, 44 | "vqa_labels": labels, 45 | "vqa_scores": scores, 46 | "qid": qid, 47 | } 48 | -------------------------------------------------------------------------------- /meter/gadgets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zdou0830/METER/f4f09345b26ee21add0a756d06598e3c04726345/meter/gadgets/__init__.py -------------------------------------------------------------------------------- /meter/gadgets/my_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning.metrics import Metric 3 | 4 | 5 | class Accuracy(Metric): 6 | def __init__(self, dist_sync_on_step=False): 7 | super().__init__(dist_sync_on_step=dist_sync_on_step) 8 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 9 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 10 | 11 | def update(self, logits, target): 12 | logits, target = ( 13 | logits.detach().to(self.correct.device), 14 | target.detach().to(self.correct.device), 15 | ) 16 | preds = logits.argmax(dim=-1) 17 | preds = preds[target != -100] 18 | target = target[target != -100] 19 | if target.numel() == 0: 20 | return 1 21 | 22 | assert preds.shape == target.shape 23 | 24 | self.correct += torch.sum(preds == target) 25 | self.total += target.numel() 26 | 27 | def compute(self): 28 | return self.correct / self.total 29 | 30 | 31 | class Scalar(Metric): 32 | def __init__(self, dist_sync_on_step=False): 33 | super().__init__(dist_sync_on_step=dist_sync_on_step) 34 | self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum") 35 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 36 | 37 | def update(self, scalar): 38 | if isinstance(scalar, torch.Tensor): 39 | scalar = scalar.detach().to(self.scalar.device) 40 | else: 41 | scalar = torch.tensor(scalar).float().to(self.scalar.device) 42 | self.scalar += scalar 43 | self.total += 1 44 | 45 | def compute(self): 46 | return self.scalar / self.total 47 | 48 | 49 | class VQAScore(Metric): 50 | def __init__(self, dist_sync_on_step=False): 51 | super().__init__(dist_sync_on_step=dist_sync_on_step) 52 | self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") 53 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 54 | 55 | def update(self, logits, target): 56 | logits, target = ( 57 | logits.detach().float().to(self.score.device), 58 | target.detach().float().to(self.score.device), 59 | ) 60 | logits = torch.max(logits, 1)[1] 61 | one_hots = torch.zeros(*target.size()).to(target) 62 | one_hots.scatter_(1, logits.view(-1, 1), 1) 63 | scores = one_hots * target 64 | 65 | self.score += scores.sum() 66 | self.total += len(logits) 67 | 68 | def compute(self): 69 | return self.score / self.total 70 | -------------------------------------------------------------------------------- /meter/modules/__init__.py: 
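Before moving on to the modules, a note on the metrics defined in `my_metrics.py` above: `Accuracy`, `Scalar`, and `VQAScore` all follow the same pattern of accumulating a running sum and count (synced across processes via `dist_reduce_fx="sum"`) and dividing in `compute()`, which is how `epoch_wrapup` in `meter_utils.py` consumes them (one `compute()` plus `reset()` per epoch). A minimal usage sketch with illustrative values, assuming the `pytorch_lightning.metrics` import above still resolves in your environment (newer Lightning releases moved `Metric` into `torchmetrics`):

```python
# Illustrative values only; shapes follow what the update() methods above expect.
import torch
from meter.gadgets.my_metrics import Accuracy, VQAScore

acc = Accuracy()
acc.update(torch.tensor([[2.0, 0.1], [0.3, 1.5]]),  # logits
           torch.tensor([0, 1]))                    # hard labels (-100 entries are ignored)
print(acc.compute())   # -> 1.0, both argmax predictions match the labels

vqa = VQAScore()
vqa.update(torch.tensor([[0.2, 2.0, 0.1]]),         # answer logits
           torch.tensor([[0.0, 0.6, 0.3]]))         # VQA-style soft target scores
print(vqa.compute())   # -> 0.6, the soft score of the predicted answer
```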
-------------------------------------------------------------------------------- 1 | from .meter_module import METERTransformerSS 2 | -------------------------------------------------------------------------------- /meter/modules/clip_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class LayerNorm(nn.LayerNorm): 11 | """Subclass torch's LayerNorm to handle fp16.""" 12 | 13 | def forward(self, x: torch.Tensor): 14 | orig_type = x.dtype 15 | ret = super().forward(x.type(torch.float32)) 16 | return ret.type(orig_type) 17 | 18 | 19 | class QuickGELU(nn.Module): 20 | def forward(self, x: torch.Tensor): 21 | return x * torch.sigmoid(1.702 * x) 22 | 23 | 24 | class ResidualAttentionBlock(nn.Module): 25 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 26 | super().__init__() 27 | 28 | self.attn = nn.MultiheadAttention(d_model, n_head) 29 | self.ln_1 = LayerNorm(d_model) 30 | self.mlp = nn.Sequential(OrderedDict([ 31 | ("c_fc", nn.Linear(d_model, d_model * 4)), 32 | ("gelu", QuickGELU()), 33 | ("c_proj", nn.Linear(d_model * 4, d_model)) 34 | ])) 35 | self.ln_2 = LayerNorm(d_model) 36 | self.attn_mask = attn_mask 37 | 38 | def attention(self, x: torch.Tensor, x_mask:torch.Tensor): 39 | if x_mask is not None: 40 | x_mask = x_mask.to(dtype=torch.bool, device=x.device) 41 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 42 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, key_padding_mask=x_mask)[0] 43 | 44 | def forward(self, x: torch.Tensor, x_mask:torch.Tensor=None): 45 | x = x + self.attention(self.ln_1(x), x_mask) 46 | x = x + self.mlp(self.ln_2(x)) 47 | return x 48 | 49 | 50 | class Transformer(nn.Module): 51 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 52 | super().__init__() 53 | self.width = width 54 | self.layers = layers 55 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers-1)]) 56 | 57 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor=None): 58 | for block in self.resblocks: 59 | x = block(x, x_mask) 60 | return x 61 | 62 | 63 | class VisualTransformer(nn.Module): 64 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, resolution_after: int): 65 | super().__init__() 66 | self.input_resolution = input_resolution 67 | self.output_dim = output_dim 68 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 69 | 70 | scale = width ** -0.5 71 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 72 | self.positional_embedding = nn.Parameter(scale * torch.randn((resolution_after // patch_size) ** 2 + 1, width)) 73 | self.ln_pre = LayerNorm(width) 74 | 75 | self.transformer = Transformer(width, layers, heads) 76 | self.ln_post = LayerNorm(width) 77 | 78 | def forward(self, x: torch.Tensor, x_mask): 79 | x = self.conv1(x) # shape = [*, width, grid, grid] 80 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 81 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 82 | t=self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) 83 | x = 
torch.cat([t, x], dim=1) # shape = [*, grid ** 2 + 1, width] 84 | x = x + self.positional_embedding.to(x.dtype) 85 | x = self.ln_pre(x) 86 | 87 | x = x.permute(1, 0, 2) # NLD -> LND 88 | x = self.transformer(x, x_mask) 89 | x = x.permute(1, 0, 2) # LND -> NLD 90 | 91 | x = self.ln_post(x) 92 | 93 | return x 94 | 95 | 96 | class CLIP(nn.Module): 97 | def __init__(self, 98 | embed_dim: int, 99 | # vision 100 | image_resolution: int, 101 | vision_layers: Union[Tuple[int, int, int, int], int], 102 | vision_width: int, 103 | vision_patch_size: int, 104 | # text 105 | context_length: int, 106 | vocab_size: int, 107 | transformer_width: int, 108 | transformer_heads: int, 109 | transformer_layers: int, 110 | resolution_after=224, 111 | ): 112 | super().__init__() 113 | 114 | self.context_length = context_length 115 | 116 | vision_heads = vision_width // 64 117 | self.visual = VisualTransformer( 118 | input_resolution=image_resolution, 119 | patch_size=vision_patch_size, 120 | width=vision_width, 121 | layers=vision_layers, 122 | heads=vision_heads, 123 | output_dim=embed_dim, 124 | resolution_after=resolution_after, 125 | ) 126 | 127 | self.vocab_size = vocab_size 128 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 129 | self.ln_final = LayerNorm(transformer_width) 130 | 131 | self.initialize_parameters() 132 | 133 | def initialize_parameters(self): 134 | nn.init.normal_(self.positional_embedding, std=0.01) 135 | 136 | proj_std = (self.visual.transformer.width ** -0.5) * ((2 * self.visual.transformer.layers) ** -0.5) 137 | attn_std = self.visual.transformer.width ** -0.5 138 | fc_std = (2 * self.visual.transformer.width) ** -0.5 139 | for block in self.visual.transformer.resblocks: 140 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 141 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 142 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 143 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 144 | 145 | @property 146 | def dtype(self): 147 | return self.visual.conv1.weight.dtype 148 | 149 | def forward(self, image, image_mask=None): 150 | return self.visual(image.type(self.dtype), image_mask) 151 | 152 | 153 | _MODELS = { 154 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 155 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 156 | } 157 | import os 158 | import hashlib 159 | import urllib 160 | from tqdm import tqdm 161 | import warnings 162 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 163 | os.makedirs(root, exist_ok=True) 164 | filename = os.path.basename(url) 165 | 166 | expected_sha256 = url.split("/")[-2] 167 | download_target = os.path.join(root, filename) 168 | 169 | if os.path.exists(download_target) and not os.path.isfile(download_target): 170 | raise RuntimeError(f"{download_target} exists and is not a regular file") 171 | 172 | if os.path.isfile(download_target): 173 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 174 | return download_target 175 | else: 176 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 177 | 178 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 179 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', 
unit_scale=True) as loop: 180 | while True: 181 | buffer = source.read(8192) 182 | if not buffer: 183 | break 184 | 185 | output.write(buffer) 186 | loop.update(len(buffer)) 187 | 188 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 189 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 190 | 191 | return download_target 192 | 193 | def adapt_position_encoding(model, patch_size=32, after=384, 194 | suffix='visual.positional_embedding'): 195 | keys = [k for k in model if k.endswith(suffix)] 196 | assert len(keys) == 1 197 | key = keys[0] 198 | origin_pos_embed = model[key] 199 | origin_dim2 = False 200 | if len(origin_pos_embed.shape) == 2: 201 | origin_dim2 = True 202 | origin_pos_embed = origin_pos_embed.unsqueeze(0) 203 | grid_before = int(np.sqrt(origin_pos_embed.shape[1] - 1)) 204 | before = int(grid_before*patch_size) 205 | assert (before % patch_size) == 0 206 | grid_after = after // patch_size 207 | assert (after % patch_size) == 0 208 | embed_dim = origin_pos_embed.shape[-1] 209 | 210 | pos_embed = origin_pos_embed[0, 1:, :].reshape((grid_before, grid_before, embed_dim)) 211 | new_size = (grid_after, grid_after) 212 | pos_embed = torch.nn.functional.interpolate(pos_embed.permute((2, 0, 1)).unsqueeze(0), size=new_size, mode='bicubic') 213 | pos_embed = pos_embed.squeeze(0).permute((1, 2, 0)).reshape((-1, embed_dim)) 214 | pos_embed = torch.cat((origin_pos_embed[0, 0:1, :], pos_embed), dim=0).unsqueeze(0) 215 | assert pos_embed.shape == (1, grid_after * grid_after + 1, embed_dim) 216 | if origin_dim2: 217 | assert pos_embed.shape[0] == 1 218 | pos_embed = pos_embed.squeeze(0) 219 | model[key] = pos_embed 220 | return model 221 | 222 | 223 | def build_model(name, resolution_after=224): 224 | if name in _MODELS: 225 | model_path = _download(_MODELS[name]) 226 | elif os.path.isfile(name): 227 | model_path = name 228 | else: 229 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}" 230 | ) 231 | try: 232 | model = torch.jit.load(model_path, map_location="cpu") 233 | state_dict = None 234 | except RuntimeError: 235 | if jit: 236 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 237 | jit = False 238 | state_dict = torch.load(model_path, map_location="cpu") 239 | state_dict = state_dict or model.state_dict() 240 | vit = "visual.proj" in state_dict 241 | 242 | vision_width = state_dict["visual.conv1.weight"].shape[0] 243 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 244 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 245 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 246 | image_resolution = vision_patch_size * grid_size 247 | 248 | embed_dim = state_dict["text_projection"].shape[1] 249 | context_length = state_dict["positional_embedding"].shape[0] 250 | vocab_size = state_dict["token_embedding.weight"].shape[0] 251 | transformer_width = state_dict["ln_final.weight"].shape[0] 252 | transformer_heads = transformer_width // 64 253 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 254 | 255 | model = CLIP( 256 | embed_dim, 257 | image_resolution, vision_layers, vision_width, vision_patch_size, 258 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, 259 | resolution_after, 260 | ) 261 | 262 | for key in ["input_resolution", "context_length", "vocab_size"]: 263 | if key in state_dict: 264 | del state_dict[key] 265 | 266 | model_dict = model.state_dict() 267 | pretrained_dict = state_dict 268 | if resolution_after != image_resolution: 269 | pretrained_dict = adapt_position_encoding(pretrained_dict, after=resolution_after, patch_size=vision_patch_size) 270 | # 1. filter out unnecessary keys 271 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 272 | # 2. overwrite entries in the existing state dict 273 | model_dict.update(pretrained_dict) 274 | # 3. load the new state dict 275 | model.load_state_dict(model_dict) 276 | return model 277 | -------------------------------------------------------------------------------- /meter/modules/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | This file contains primitives for multi-gpu communication. 4 | This is useful when doing distributed training. 5 | """ 6 | 7 | import functools 8 | import logging 9 | import numpy as np 10 | import pickle 11 | import torch 12 | import torch.distributed as dist 13 | 14 | import torch 15 | 16 | _LOCAL_PROCESS_GROUP = None 17 | """ 18 | A torch process group which only includes processes that on the same machine as the current process. 19 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 20 | """ 21 | 22 | 23 | def get_world_size() -> int: 24 | if not dist.is_available(): 25 | return 1 26 | if not dist.is_initialized(): 27 | return 1 28 | return dist.get_world_size() 29 | 30 | 31 | def get_rank() -> int: 32 | if not dist.is_available(): 33 | return 0 34 | if not dist.is_initialized(): 35 | return 0 36 | return dist.get_rank() 37 | 38 | 39 | def get_local_rank() -> int: 40 | """ 41 | Returns: 42 | The rank of the current process within the local (per-machine) process group. 
43 | """ 44 | if not dist.is_available(): 45 | return 0 46 | if not dist.is_initialized(): 47 | return 0 48 | assert _LOCAL_PROCESS_GROUP is not None 49 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 50 | 51 | 52 | def get_local_size() -> int: 53 | """ 54 | Returns: 55 | The size of the per-machine process group, 56 | i.e. the number of processes per machine. 57 | """ 58 | if not dist.is_available(): 59 | return 1 60 | if not dist.is_initialized(): 61 | return 1 62 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 63 | 64 | 65 | def is_main_process() -> bool: 66 | return get_rank() == 0 67 | 68 | 69 | def synchronize(): 70 | """ 71 | Helper function to synchronize (barrier) among all processes when 72 | using distributed training 73 | """ 74 | if not dist.is_available(): 75 | return 76 | if not dist.is_initialized(): 77 | return 78 | world_size = dist.get_world_size() 79 | if world_size == 1: 80 | return 81 | dist.barrier() 82 | 83 | 84 | @functools.lru_cache() 85 | def _get_global_gloo_group(): 86 | """ 87 | Return a process group based on gloo backend, containing all the ranks 88 | The result is cached. 89 | """ 90 | if dist.get_backend() == "nccl": 91 | return dist.new_group(backend="gloo") 92 | else: 93 | return dist.group.WORLD 94 | 95 | 96 | def _serialize_to_tensor(data, group): 97 | backend = dist.get_backend(group) 98 | assert backend in ["gloo", "nccl"] 99 | device = torch.device("cpu" if backend == "gloo" else "cuda") 100 | 101 | buffer = pickle.dumps(data) 102 | if len(buffer) > 1024 ** 3: 103 | logger = logging.getLogger(__name__) 104 | logger.warning( 105 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 106 | get_rank(), len(buffer) / (1024 ** 3), device 107 | ) 108 | ) 109 | storage = torch.ByteStorage.from_buffer(buffer) 110 | tensor = torch.ByteTensor(storage).to(device=device) 111 | return tensor 112 | 113 | 114 | def _pad_to_largest_tensor(tensor, group): 115 | """ 116 | Returns: 117 | list[int]: size of the tensor, on each rank 118 | Tensor: padded tensor that has the max size 119 | """ 120 | world_size = dist.get_world_size(group=group) 121 | assert ( 122 | world_size >= 1 123 | ), "comm.gather/all_gather must be called from ranks within the given group!" 124 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 125 | size_list = [ 126 | torch.zeros([1], dtype=torch.int64, device=tensor.device) 127 | for _ in range(world_size) 128 | ] 129 | dist.all_gather(size_list, local_size, group=group) 130 | size_list = [int(size.item()) for size in size_list] 131 | 132 | max_size = max(size_list) 133 | 134 | # we pad the tensor because torch all_gather does not support 135 | # gathering tensors of different shapes 136 | if local_size != max_size: 137 | padding = torch.zeros( 138 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device 139 | ) 140 | tensor = torch.cat((tensor, padding), dim=0) 141 | return size_list, tensor 142 | 143 | 144 | def all_gather(data, group=None): 145 | """ 146 | Run all_gather on arbitrary picklable data (not necessarily tensors). 147 | 148 | Args: 149 | data: any picklable object 150 | group: a torch process group. By default, will use a group which 151 | contains all ranks on gloo backend. 
152 | 153 | Returns: 154 | list[data]: list of data gathered from each rank 155 | """ 156 | if get_world_size() == 1: 157 | return [data] 158 | if group is None: 159 | group = _get_global_gloo_group() 160 | if dist.get_world_size(group) == 1: 161 | return [data] 162 | 163 | tensor = _serialize_to_tensor(data, group) 164 | 165 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 166 | max_size = max(size_list) 167 | 168 | # receiving Tensor from all ranks 169 | tensor_list = [ 170 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 171 | for _ in size_list 172 | ] 173 | dist.all_gather(tensor_list, tensor, group=group) 174 | 175 | data_list = [] 176 | for size, tensor in zip(size_list, tensor_list): 177 | buffer = tensor.cpu().numpy().tobytes()[:size] 178 | data_list.append(pickle.loads(buffer)) 179 | 180 | return data_list 181 | 182 | 183 | def gather(data, dst=0, group=None): 184 | """ 185 | Run gather on arbitrary picklable data (not necessarily tensors). 186 | 187 | Args: 188 | data: any picklable object 189 | dst (int): destination rank 190 | group: a torch process group. By default, will use a group which 191 | contains all ranks on gloo backend. 192 | 193 | Returns: 194 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 195 | an empty list. 196 | """ 197 | if get_world_size() == 1: 198 | return [data] 199 | if group is None: 200 | group = _get_global_gloo_group() 201 | if dist.get_world_size(group=group) == 1: 202 | return [data] 203 | rank = dist.get_rank(group=group) 204 | 205 | tensor = _serialize_to_tensor(data, group) 206 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 207 | 208 | # receiving Tensor from all ranks 209 | if rank == dst: 210 | max_size = max(size_list) 211 | tensor_list = [ 212 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) 213 | for _ in size_list 214 | ] 215 | dist.gather(tensor, tensor_list, dst=dst, group=group) 216 | 217 | data_list = [] 218 | for size, tensor in zip(size_list, tensor_list): 219 | buffer = tensor.cpu().numpy().tobytes()[:size] 220 | data_list.append(pickle.loads(buffer)) 221 | return data_list 222 | else: 223 | dist.gather(tensor, [], dst=dst, group=group) 224 | return [] 225 | 226 | 227 | def shared_random_seed(): 228 | """ 229 | Returns: 230 | int: a random number that is the same across all workers. 231 | If workers need a shared RNG, they can use this shared seed to 232 | create one. 233 | 234 | All workers must call this function, otherwise it will deadlock. 235 | """ 236 | ints = np.random.randint(2 ** 31) 237 | all_ints = all_gather(ints) 238 | return all_ints[0] 239 | 240 | 241 | def reduce_dict(input_dict, average=True): 242 | """ 243 | Reduce the values in the dictionary from all processes so that process with rank 244 | 0 has the reduced results. 245 | 246 | Args: 247 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 248 | average (bool): whether to do average or sum 249 | 250 | Returns: 251 | a dict with the same keys as input_dict, after reduction. 
252 | """ 253 | world_size = get_world_size() 254 | if world_size < 2: 255 | return input_dict 256 | with torch.no_grad(): 257 | names = [] 258 | values = [] 259 | # sort the keys so that they are consistent across processes 260 | for k in sorted(input_dict.keys()): 261 | names.append(k) 262 | values.append(input_dict[k]) 263 | values = torch.stack(values, dim=0) 264 | dist.reduce(values, dst=0) 265 | if dist.get_rank() == 0 and average: 266 | # only main process gets accumulated, so only divide by 267 | # world_size in this case 268 | values /= world_size 269 | reduced_dict = {k: v for k, v in zip(names, values)} 270 | return reduced_dict 271 | -------------------------------------------------------------------------------- /meter/modules/heads.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from transformers.models.bert.modeling_bert import BertPredictionHeadTransform 6 | 7 | 8 | class Pooler(nn.Module): 9 | def __init__(self, hidden_size): 10 | super().__init__() 11 | self.dense = nn.Linear(hidden_size, hidden_size) 12 | self.activation = nn.Tanh() 13 | 14 | def forward(self, hidden_states): 15 | first_token_tensor = hidden_states[:, 0] 16 | pooled_output = self.dense(first_token_tensor) 17 | pooled_output = self.activation(pooled_output) 18 | return pooled_output 19 | 20 | 21 | class ITMHead(nn.Module): 22 | def __init__(self, hidden_size): 23 | super().__init__() 24 | self.fc = nn.Linear(hidden_size, 2) 25 | 26 | def forward(self, x): 27 | x = self.fc(x) 28 | return x 29 | 30 | 31 | class MLMHead(nn.Module): 32 | def __init__(self, config, weight=None): 33 | super().__init__() 34 | self.transform = BertPredictionHeadTransform(config) 35 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 36 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 37 | if weight is not None: 38 | self.decoder.weight = weight 39 | 40 | def forward(self, x): 41 | x = self.transform(x) 42 | x = self.decoder(x) + self.bias 43 | return x 44 | -------------------------------------------------------------------------------- /meter/modules/meter_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | import numpy as np 5 | 6 | from transformers.models.bert.modeling_bert import BertConfig, BertEmbeddings, BertModel, BertEncoder, BertLayer 7 | from .bert_model import BertCrossLayer, BertAttention 8 | from . import swin_transformer as swin 9 | from . 
import heads, objectives, meter_utils 10 | from .clip_model import build_model, adapt_position_encoding 11 | from .swin_helpers import swin_adapt_position_encoding 12 | from transformers import RobertaConfig, RobertaModel 13 | 14 | class METERTransformerSS(pl.LightningModule): 15 | def __init__(self, config): 16 | super().__init__() 17 | self.save_hyperparameters() 18 | 19 | self.is_clip= (not 'swin' in config['vit']) 20 | 21 | if 'roberta' in config['tokenizer']: 22 | bert_config = RobertaConfig( 23 | vocab_size=config["vocab_size"], 24 | hidden_size=config["hidden_size"], 25 | num_hidden_layers=config["num_layers"], 26 | num_attention_heads=config["num_heads"], 27 | intermediate_size=config["hidden_size"] * config["mlp_ratio"], 28 | max_position_embeddings=config["max_text_len"], 29 | hidden_dropout_prob=config["drop_rate"], 30 | attention_probs_dropout_prob=config["drop_rate"], 31 | ) 32 | else: 33 | bert_config = BertConfig( 34 | vocab_size=config["vocab_size"], 35 | hidden_size=config["hidden_size"], 36 | num_hidden_layers=config["num_layers"], 37 | num_attention_heads=config["num_heads"], 38 | intermediate_size=config["hidden_size"] * config["mlp_ratio"], 39 | max_position_embeddings=config["max_text_len"], 40 | hidden_dropout_prob=config["drop_rate"], 41 | attention_probs_dropout_prob=config["drop_rate"], 42 | ) 43 | 44 | resolution_after=config['image_size'] 45 | 46 | self.cross_modal_text_transform = nn.Linear(config['input_text_embed_size'], config['hidden_size']) 47 | self.cross_modal_text_transform.apply(objectives.init_weights) 48 | self.cross_modal_image_transform = nn.Linear(config['input_image_embed_size'], config['hidden_size']) 49 | self.cross_modal_image_transform.apply(objectives.init_weights) 50 | 51 | self.token_type_embeddings = nn.Embedding(2, config["hidden_size"]) 52 | self.token_type_embeddings.apply(objectives.init_weights) 53 | 54 | if torch.distributed.is_initialized(): 55 | if torch.distributed.get_rank() == 0: 56 | if self.is_clip: 57 | build_model(config['vit'], resolution_after=resolution_after) 58 | else: 59 | getattr(swin, self.hparams.config["vit"])( 60 | pretrained=True, config=self.hparams.config, 61 | ) 62 | 63 | if 'roberta' in config['tokenizer']: 64 | RobertaModel.from_pretrained(config['tokenizer']) 65 | else: 66 | BertModel.from_pretrained(config['tokenizer']) 67 | 68 | torch.distributed.barrier() 69 | 70 | if self.is_clip: 71 | self.vit_model = build_model(config['vit'], resolution_after=resolution_after) 72 | else: 73 | self.vit_model = getattr(swin, self.hparams.config["vit"])( 74 | pretrained=True, config=self.hparams.config, 75 | ) 76 | self.avgpool = nn.AdaptiveAvgPool1d(1) 77 | 78 | if 'roberta' in config['tokenizer']: 79 | self.text_transformer = RobertaModel.from_pretrained(config['tokenizer']) 80 | else: 81 | self.text_transformer = BertModel.from_pretrained(config['tokenizer']) 82 | 83 | self.cross_modal_image_layers = nn.ModuleList([BertCrossLayer(bert_config) for _ in range(config['num_top_layer'])]) 84 | self.cross_modal_image_layers.apply(objectives.init_weights) 85 | self.cross_modal_text_layers = nn.ModuleList([BertCrossLayer(bert_config) for _ in range(config['num_top_layer'])]) 86 | self.cross_modal_text_layers.apply(objectives.init_weights) 87 | 88 | self.cross_modal_image_pooler = heads.Pooler(config["hidden_size"]) 89 | self.cross_modal_image_pooler.apply(objectives.init_weights) 90 | self.cross_modal_text_pooler = heads.Pooler(config["hidden_size"]) 91 | self.cross_modal_text_pooler.apply(objectives.init_weights) 92 | 93 | 
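# The blocks below add one head per objective whose entry in
# config["loss_names"] is > 0 (mlm, itm, vqa, and further down nlvr2, snli,
# irtr), so a saved checkpoint only contains the heads that were actually
# trained for that run.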
if config["loss_names"]["mlm"] > 0: 94 | self.mlm_score = heads.MLMHead(bert_config) 95 | self.mlm_score.apply(objectives.init_weights) 96 | 97 | if config["loss_names"]["itm"] > 0: 98 | self.itm_score = heads.ITMHead(config["hidden_size"]*2) 99 | self.itm_score.apply(objectives.init_weights) 100 | 101 | hs = self.hparams.config["hidden_size"] 102 | 103 | if self.hparams.config["loss_names"]["vqa"] > 0: 104 | vs = self.hparams.config["vqav2_label_size"] 105 | self.vqa_classifier = nn.Sequential( 106 | nn.Linear(hs * 2, hs * 2), 107 | nn.LayerNorm(hs * 2), 108 | nn.GELU(), 109 | nn.Linear(hs * 2, vs), 110 | ) 111 | self.vqa_classifier.apply(objectives.init_weights) 112 | 113 | # ===================== Downstream ===================== # 114 | if ( 115 | self.hparams.config["load_path"] != "" 116 | and not self.hparams.config["test_only"] 117 | ): 118 | ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu") 119 | state_dict = ckpt["state_dict"] 120 | if self.is_clip: 121 | state_dict = adapt_position_encoding(state_dict, after=resolution_after, patch_size=self.hparams.config['patch_size']) 122 | else: 123 | state_dict = swin_adapt_position_encoding(state_dict, after=resolution_after, before=config['resolution_before']) 124 | self.load_state_dict(state_dict, strict=False) 125 | 126 | 127 | if self.hparams.config["loss_names"]["nlvr2"] > 0: 128 | self.nlvr2_classifier = nn.Sequential( 129 | nn.Linear(hs * 4, hs * 2), 130 | nn.LayerNorm(hs * 2), 131 | nn.GELU(), 132 | nn.Linear(hs * 2, 2), 133 | ) 134 | self.nlvr2_classifier.apply(objectives.init_weights) 135 | emb_data = self.token_type_embeddings.weight.data 136 | self.token_type_embeddings = nn.Embedding(3, hs) 137 | self.token_type_embeddings.apply(objectives.init_weights) 138 | self.token_type_embeddings.weight.data[0, :] = emb_data[0, :] 139 | self.token_type_embeddings.weight.data[1, :] = emb_data[1, :] 140 | self.token_type_embeddings.weight.data[2, :] = emb_data[1, :] 141 | 142 | if self.hparams.config["loss_names"]["snli"] > 0: 143 | self.snli_classifier = nn.Sequential( 144 | nn.Linear(hs * 2, hs * 2), 145 | nn.LayerNorm(hs * 2), 146 | nn.GELU(), 147 | nn.Linear(hs * 2, 3), 148 | ) 149 | self.snli_classifier.apply(objectives.init_weights) 150 | 151 | if self.hparams.config["loss_names"]["irtr"] > 0: 152 | self.rank_output = nn.Linear(hs, 1) 153 | self.rank_output.weight.data = self.itm_score.fc.weight.data[1:, :] 154 | self.rank_output.bias.data = self.itm_score.fc.bias.data[1:] 155 | self.margin = 0.2 156 | for p in self.itm_score.parameters(): 157 | p.requires_grad = False 158 | 159 | meter_utils.set_metrics(self) 160 | self.current_tasks = list() 161 | 162 | # ===================== load downstream (test_only) ====================== 163 | 164 | if self.hparams.config["load_path"] != "" and self.hparams.config["test_only"]: 165 | ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu") 166 | state_dict = ckpt["state_dict"] 167 | if self.is_clip: 168 | state_dict = adapt_position_encoding(state_dict, after=resolution_after, patch_size=self.hparams.config['patch_size']) 169 | else: 170 | state_dict = swin_adapt_position_encoding(state_dict, after=resolution_after, before=config['resolution_before']) 171 | self.load_state_dict(state_dict, strict=False) 172 | 173 | def infer( 174 | self, 175 | batch, 176 | mask_text=False, 177 | mask_image=False, 178 | image_token_type_idx=1, 179 | img=None, 180 | ): 181 | if img is None: 182 | if f"image_{image_token_type_idx - 1}" in batch: 183 | imgkey = 
f"image_{image_token_type_idx - 1}" 184 | else: 185 | imgkey = "image" 186 | img = batch[imgkey][0] 187 | 188 | do_mlm = "_mlm" if mask_text else "" 189 | text_ids = batch[f"text_ids{do_mlm}"] 190 | text_labels = batch[f"text_labels{do_mlm}"] 191 | text_masks = batch[f"text_masks"] 192 | 193 | text_embeds = self.text_transformer.embeddings(input_ids=text_ids) 194 | device = text_embeds.device 195 | input_shape = text_masks.size() 196 | extend_text_masks = self.text_transformer.get_extended_attention_mask(text_masks, input_shape, device) 197 | for layer in self.text_transformer.encoder.layer: 198 | text_embeds = layer(text_embeds, extend_text_masks)[0] 199 | text_embeds = self.cross_modal_text_transform(text_embeds) 200 | 201 | image_embeds = self.vit_model(img) 202 | image_embeds = self.cross_modal_image_transform(image_embeds) 203 | image_masks = torch.ones((image_embeds.size(0), image_embeds.size(1)), dtype=torch.long, device=device) 204 | extend_image_masks = self.text_transformer.get_extended_attention_mask(image_masks, image_masks.size(), device) 205 | 206 | text_embeds, image_embeds = ( 207 | text_embeds + self.token_type_embeddings(torch.zeros_like(text_masks)), 208 | image_embeds 209 | + self.token_type_embeddings( 210 | torch.full_like(image_masks, image_token_type_idx) 211 | ), 212 | ) 213 | 214 | x, y = text_embeds, image_embeds 215 | for text_layer, image_layer in zip(self.cross_modal_text_layers, self.cross_modal_image_layers): 216 | x1 = text_layer(x, y, extend_text_masks, extend_image_masks) 217 | y1 = image_layer(y, x, extend_image_masks, extend_text_masks) 218 | x, y = x1[0], y1[0] 219 | 220 | text_feats, image_feats = x, y 221 | cls_feats_text = self.cross_modal_text_pooler(x) 222 | if self.is_clip: 223 | cls_feats_image = self.cross_modal_image_pooler(y) 224 | else: 225 | avg_image_feats = self.avgpool(image_feats.transpose(1, 2)).view(image_feats.size(0), 1, -1) 226 | cls_feats_image = self.cross_modal_image_pooler(avg_image_feats) 227 | cls_feats = torch.cat([cls_feats_text, cls_feats_image], dim=-1) 228 | 229 | ret = { 230 | "text_feats": text_feats, 231 | "image_feats": image_feats, 232 | "cls_feats": cls_feats, 233 | "text_labels": text_labels, 234 | "text_ids": text_ids, 235 | "text_masks": text_masks, 236 | } 237 | 238 | 239 | return ret 240 | 241 | def forward(self, batch): 242 | ret = dict() 243 | if len(self.current_tasks) == 0: 244 | ret.update(self.infer(batch)) 245 | return ret 246 | 247 | # Masked Language Modeling 248 | if "mlm" in self.current_tasks: 249 | ret.update(objectives.compute_mlm(self, batch)) 250 | 251 | # Image Text Matching 252 | if "itm" in self.current_tasks: 253 | ret.update(objectives.compute_itm(self, batch)) 254 | 255 | # Visual Question Answering 256 | if "vqa" in self.current_tasks: 257 | ret.update(objectives.compute_vqa(self, batch)) 258 | 259 | # Natural Language for Visual Reasoning 2 260 | if "nlvr2" in self.current_tasks: 261 | ret.update(objectives.compute_nlvr2(self, batch)) 262 | 263 | # SNLI Visual Entailment 264 | if "snli" in self.current_tasks: 265 | ret.update(objectives.compute_snli(self, batch)) 266 | 267 | # Image Retrieval and Text Retrieval 268 | if "irtr" in self.current_tasks: 269 | ret.update(objectives.compute_irtr(self, batch)) 270 | 271 | return ret 272 | 273 | def training_step(self, batch, batch_idx): 274 | meter_utils.set_task(self) 275 | output = self(batch) 276 | total_loss = sum([v for k, v in output.items() if "loss" in k]) 277 | 278 | return total_loss 279 | 280 | def training_epoch_end(self, outs): 
281 | meter_utils.epoch_wrapup(self) 282 | 283 | def validation_step(self, batch, batch_idx): 284 | meter_utils.set_task(self) 285 | output = self(batch) 286 | 287 | def validation_epoch_end(self, outs): 288 | meter_utils.epoch_wrapup(self) 289 | 290 | def test_step(self, batch, batch_idx): 291 | meter_utils.set_task(self) 292 | output = self(batch) 293 | ret = dict() 294 | 295 | if self.hparams.config["loss_names"]["vqa"] > 0: 296 | ret.update(objectives.vqa_test_step(self, batch, output)) 297 | 298 | return ret 299 | 300 | def test_epoch_end(self, outs): 301 | model_name = self.hparams.config["load_path"].split("/")[-1][:-5] 302 | 303 | if self.hparams.config["loss_names"]["vqa"] > 0: 304 | objectives.vqa_test_wrapup(outs, model_name) 305 | meter_utils.epoch_wrapup(self) 306 | 307 | def configure_optimizers(self): 308 | return meter_utils.set_schedule(self) 309 | -------------------------------------------------------------------------------- /meter/modules/meter_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | from transformers.optimization import AdamW 5 | from transformers import ( 6 | get_polynomial_decay_schedule_with_warmup, 7 | get_cosine_schedule_with_warmup, 8 | ) 9 | from .dist_utils import all_gather 10 | from .objectives import compute_irtr_recall 11 | from ..gadgets.my_metrics import Accuracy, VQAScore, Scalar 12 | 13 | 14 | def set_metrics(pl_module): 15 | for split in ["train", "val"]: 16 | for k, v in pl_module.hparams.config["loss_names"].items(): 17 | if v <= 0: 18 | continue 19 | if k == "vqa": 20 | setattr(pl_module, f"{split}_vqa_score", VQAScore()) 21 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 22 | elif k == "nlvr2": 23 | if split == "train": 24 | setattr(pl_module, f"train_{k}_accuracy", Accuracy()) 25 | setattr(pl_module, f"train_{k}_loss", Scalar()) 26 | else: 27 | setattr(pl_module, f"dev_{k}_accuracy", Accuracy()) 28 | setattr(pl_module, f"dev_{k}_loss", Scalar()) 29 | setattr(pl_module, f"test_{k}_accuracy", Accuracy()) 30 | setattr(pl_module, f"test_{k}_loss", Scalar()) 31 | elif k == "snli": 32 | if split == "train": 33 | setattr(pl_module, f"train_{k}_accuracy", Accuracy()) 34 | setattr(pl_module, f"train_{k}_loss", Scalar()) 35 | else: 36 | setattr(pl_module, f"dev_{k}_accuracy", Accuracy()) 37 | setattr(pl_module, f"dev_{k}_loss", Scalar()) 38 | setattr(pl_module, f"test_{k}_accuracy", Accuracy()) 39 | setattr(pl_module, f"test_{k}_loss", Scalar()) 40 | elif k == "irtr": 41 | setattr(pl_module, f"{split}_irtr_loss", Scalar()) 42 | elif k == "mppd" or k == "mpfr": 43 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 44 | elif k == "itm": 45 | setattr(pl_module, f"{split}_{k}_accuracy", Accuracy()) 46 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 47 | else: 48 | setattr(pl_module, f"{split}_{k}_accuracy", Accuracy()) 49 | setattr(pl_module, f"{split}_{k}_loss", Scalar()) 50 | 51 | 52 | def epoch_wrapup(pl_module): 53 | phase = "train" if pl_module.training else "val" 54 | the_metric = 0 55 | 56 | if pl_module.hparams.config["get_recall_metric"] and not pl_module.training: 57 | (ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10) = compute_irtr_recall(pl_module) 58 | print((ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10), pl_module.global_step) 59 | pl_module.logger.experiment.add_scalar( 60 | "recalls/ir_r1", ir_r1, pl_module.global_step 61 | ) 62 | pl_module.logger.experiment.add_scalar( 63 | "recalls/ir_r5", ir_r5, pl_module.global_step 64 | ) 65 | 
pl_module.logger.experiment.add_scalar( 66 | "recalls/ir_r10", ir_r10, pl_module.global_step 67 | ) 68 | pl_module.logger.experiment.add_scalar( 69 | "recalls/tr_r1", tr_r1, pl_module.global_step 70 | ) 71 | pl_module.logger.experiment.add_scalar( 72 | "recalls/tr_r5", tr_r5, pl_module.global_step 73 | ) 74 | pl_module.logger.experiment.add_scalar( 75 | "recalls/tr_r10", tr_r10, pl_module.global_step 76 | ) 77 | the_metric += ir_r1.item() + tr_r1.item() 78 | 79 | for loss_name, v in pl_module.hparams.config["loss_names"].items(): 80 | if v <= 0: 81 | continue 82 | 83 | value = 0 84 | 85 | if loss_name == "vqa": 86 | value = getattr(pl_module, f"{phase}_{loss_name}_score").compute() 87 | pl_module.log(f"{loss_name}/{phase}/score_epoch", value) 88 | getattr(pl_module, f"{phase}_{loss_name}_score").reset() 89 | pl_module.log( 90 | f"{loss_name}/{phase}/loss_epoch", 91 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(), 92 | ) 93 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 94 | elif loss_name == "nlvr2" or loss_name == 'snli': 95 | if phase == "train": 96 | value = getattr(pl_module, f"train_{loss_name}_accuracy").compute() 97 | pl_module.log(f"{loss_name}/train/accuracy_epoch", value) 98 | getattr(pl_module, f"train_{loss_name}_accuracy").reset() 99 | pl_module.log( 100 | f"{loss_name}/train/loss_epoch", 101 | getattr(pl_module, f"train_{loss_name}_loss").compute(), 102 | ) 103 | getattr(pl_module, f"train_{loss_name}_loss").reset() 104 | else: 105 | value = getattr(pl_module, f"test_{loss_name}_accuracy").compute() 106 | pl_module.log(f"{loss_name}/test/accuracy_epoch", value) 107 | getattr(pl_module, f"test_{loss_name}_accuracy").reset() 108 | pl_module.log( 109 | f"{loss_name}/test/loss_epoch", 110 | getattr(pl_module, f"test_{loss_name}_loss").compute(), 111 | ) 112 | getattr(pl_module, f"test_{loss_name}_loss").reset() 113 | 114 | value = getattr(pl_module, f"dev_{loss_name}_accuracy").compute() 115 | pl_module.log(f"{loss_name}/dev/accuracy_epoch", value) 116 | getattr(pl_module, f"dev_{loss_name}_accuracy").reset() 117 | pl_module.log( 118 | f"{loss_name}/dev/loss_epoch", 119 | getattr(pl_module, f"dev_{loss_name}_loss").compute(), 120 | ) 121 | getattr(pl_module, f"dev_{loss_name}_loss").reset() 122 | elif loss_name == "irtr": 123 | pl_module.log( 124 | f"{loss_name}/{phase}/irtr_loss_epoch", 125 | getattr(pl_module, f"{phase}_irtr_loss").compute(), 126 | ) 127 | getattr(pl_module, f"{phase}_irtr_loss").reset() 128 | elif loss_name == "mppd" or loss_name == "mpfr": 129 | pl_module.log( 130 | f"{loss_name}/{phase}/loss_epoch", 131 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(), 132 | ) 133 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 134 | elif loss_name == "itm": 135 | value = getattr(pl_module, f"{phase}_{loss_name}_accuracy").compute() 136 | pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value) 137 | getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset() 138 | pl_module.log( 139 | f"{loss_name}/{phase}/loss_epoch", 140 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(), 141 | ) 142 | getattr(pl_module, f"{phase}_{loss_name}_loss").reset() 143 | else: 144 | value = getattr(pl_module, f"{phase}_{loss_name}_accuracy").compute() 145 | pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value) 146 | getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset() 147 | pl_module.log( 148 | f"{loss_name}/{phase}/loss_epoch", 149 | getattr(pl_module, f"{phase}_{loss_name}_loss").compute(), 150 | ) 151 | getattr(pl_module, 
f"{phase}_{loss_name}_loss").reset() 152 | 153 | the_metric += value 154 | 155 | pl_module.log(f"{phase}/the_metric", the_metric) 156 | 157 | 158 | def check_non_acc_grad(pl_module): 159 | if pl_module.token_type_embeddings.weight.grad is None: 160 | return True 161 | else: 162 | grad = pl_module.token_type_embeddings.weight.grad 163 | return (grad.sum() == 0).item() 164 | 165 | 166 | def set_task(pl_module): 167 | pl_module.current_tasks = [ 168 | k for k, v in pl_module.hparams.config["loss_names"].items() if v > 0 169 | ] 170 | return 171 | 172 | def set_schedule(pl_module): 173 | lr = pl_module.hparams.config["learning_rate"] 174 | wd = pl_module.hparams.config["weight_decay"] 175 | 176 | no_decay = [ 177 | "bias", 178 | "LayerNorm.bias", 179 | "LayerNorm.weight", 180 | "norm.bias", 181 | "norm.weight", 182 | "norm1.bias", 183 | "norm1.weight", 184 | "norm2.bias", 185 | "norm2.weight", 186 | ] 187 | head_names = ["vqa_classifier", "nlvr2_classifier", "mlm_score", "itm_score", "snli_classifier"] 188 | cross_modal_names = ['cross_modal'] 189 | lr_mult_head = pl_module.hparams.config["lr_mult_head"] 190 | lr_mult_cross_modal = pl_module.hparams.config["lr_mult_cross_modal"] 191 | end_lr = pl_module.hparams.config["end_lr"] 192 | decay_power = pl_module.hparams.config["decay_power"] 193 | optim_type = pl_module.hparams.config["optim_type"] 194 | optimizer_grouped_parameters = [ 195 | { 196 | "params": [ 197 | p 198 | for n, p in pl_module.named_parameters() 199 | if not any(nd in n for nd in no_decay) 200 | and not any(bb in n for bb in head_names) 201 | and not any(ht in n for ht in cross_modal_names) 202 | ], 203 | "weight_decay": wd, 204 | "lr": lr, 205 | }, 206 | { 207 | "params": [ 208 | p 209 | for n, p in pl_module.named_parameters() 210 | if any(nd in n for nd in no_decay) 211 | and not any(bb in n for bb in head_names) 212 | and not any(ht in n for ht in cross_modal_names) 213 | ], 214 | "weight_decay": 0.0, 215 | "lr": lr, 216 | }, 217 | { 218 | "params": [ 219 | p 220 | for n, p in pl_module.named_parameters() 221 | if not any(nd in n for nd in no_decay) 222 | and any(bb in n for bb in head_names) 223 | and not any(ht in n for ht in cross_modal_names) 224 | ], 225 | "weight_decay": wd, 226 | "lr": lr * lr_mult_head, 227 | }, 228 | { 229 | "params": [ 230 | p 231 | for n, p in pl_module.named_parameters() 232 | if any(nd in n for nd in no_decay) and any(bb in n for bb in head_names) 233 | and not any(ht in n for ht in cross_modal_names) 234 | ], 235 | "weight_decay": 0.0, 236 | "lr": lr * lr_mult_head, 237 | }, 238 | { 239 | "params": [ 240 | p 241 | for n, p in pl_module.named_parameters() 242 | if not any(nd in n for nd in no_decay) 243 | and not any(bb in n for bb in head_names) 244 | and any(ht in n for ht in cross_modal_names) 245 | ], 246 | "weight_decay": wd, 247 | "lr": lr * lr_mult_cross_modal, 248 | }, 249 | { 250 | "params": [ 251 | p 252 | for n, p in pl_module.named_parameters() 253 | if any(nd in n for nd in no_decay) 254 | and not any(bb in n for bb in head_names) 255 | and any(ht in n for ht in cross_modal_names) 256 | ], 257 | "weight_decay": 0.0, 258 | "lr": lr * lr_mult_cross_modal, 259 | }, 260 | ] 261 | 262 | if optim_type == "adamw": 263 | optimizer = AdamW( 264 | optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98) 265 | ) 266 | elif optim_type == "adam": 267 | optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr) 268 | elif optim_type == "sgd": 269 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr, momentum=0.9) 270 
| 271 | if pl_module.trainer.max_steps is None: 272 | max_steps = ( 273 | len(pl_module.trainer.datamodule.train_dataloader()) 274 | * pl_module.trainer.max_epochs 275 | // pl_module.trainer.accumulate_grad_batches 276 | ) 277 | else: 278 | max_steps = pl_module.trainer.max_steps 279 | 280 | warmup_steps = pl_module.hparams.config["warmup_steps"] 281 | if isinstance(pl_module.hparams.config["warmup_steps"], float): 282 | warmup_steps = int(max_steps * warmup_steps) 283 | 284 | if decay_power == "cosine": 285 | scheduler = get_cosine_schedule_with_warmup( 286 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps, 287 | ) 288 | else: 289 | scheduler = get_polynomial_decay_schedule_with_warmup( 290 | optimizer, 291 | num_warmup_steps=warmup_steps, 292 | num_training_steps=max_steps, 293 | lr_end=end_lr, 294 | power=decay_power, 295 | ) 296 | 297 | sched = {"scheduler": scheduler, "interval": "step"} 298 | 299 | return ( 300 | [optimizer], 301 | [sched], 302 | ) 303 | -------------------------------------------------------------------------------- /meter/modules/objectives.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import glob 6 | import json 7 | import tqdm 8 | import functools 9 | 10 | from torch.utils.data.distributed import DistributedSampler 11 | from einops import rearrange 12 | 13 | from .dist_utils import all_gather 14 | 15 | 16 | def compute_mlm(pl_module, batch): 17 | infer = pl_module.infer(batch, mask_text=True, mask_image=False) 18 | mlm_logits = pl_module.mlm_score(infer["text_feats"]) 19 | mlm_labels = infer["text_labels"] 20 | 21 | mlm_loss = F.cross_entropy( 22 | mlm_logits.view(-1, pl_module.hparams.config["vocab_size"]), 23 | mlm_labels.view(-1), 24 | ignore_index=-100, 25 | ) 26 | 27 | ret = { 28 | "mlm_loss": mlm_loss, 29 | "mlm_logits": mlm_logits, 30 | "mlm_labels": mlm_labels, 31 | "mlm_ids": infer["text_ids"], 32 | } 33 | 34 | phase = "train" if pl_module.training else "val" 35 | loss = getattr(pl_module, f"{phase}_mlm_loss")(ret["mlm_loss"]) 36 | acc = getattr(pl_module, f"{phase}_mlm_accuracy")( 37 | ret["mlm_logits"], ret["mlm_labels"] 38 | ) 39 | pl_module.log(f"mlm/{phase}/loss", loss) 40 | pl_module.log(f"mlm/{phase}/accuracy", acc) 41 | 42 | return ret 43 | 44 | def compute_itm(pl_module, batch): 45 | pos_len = len(batch["text"]) // 2 46 | neg_len = len(batch["text"]) - pos_len 47 | itm_labels = torch.cat([torch.ones(pos_len), torch.zeros(neg_len)]).to( 48 | pl_module.device 49 | ) 50 | itm_labels = itm_labels[torch.randperm(itm_labels.size(0))] 51 | 52 | itm_images = [ 53 | torch.stack( 54 | [ 55 | ti if itm_labels[i] == 1 else fi 56 | for i, (ti, fi) in enumerate(zip(bti, bfi)) 57 | ] 58 | ) 59 | for bti, bfi in zip(batch["image"], batch["false_image_0"]) 60 | ] 61 | 62 | batch = {k: v for k, v in batch.items()} 63 | batch["image"] = itm_images 64 | 65 | infer = pl_module.infer(batch, mask_text=False, mask_image=False) 66 | 67 | itm_logits = pl_module.itm_score(infer["cls_feats"]) 68 | itm_loss = F.cross_entropy(itm_logits, itm_labels.long()) 69 | 70 | ret = { 71 | "itm_loss": itm_loss, 72 | "itm_logits": itm_logits, 73 | "itm_labels": itm_labels, 74 | } 75 | 76 | phase = "train" if pl_module.training else "val" 77 | loss = getattr(pl_module, f"{phase}_itm_loss")(ret["itm_loss"]) 78 | acc = getattr(pl_module, f"{phase}_itm_accuracy")( 79 | ret["itm_logits"], ret["itm_labels"] 80 | ) 81 | 
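# The getattr(...) calls above fetch the running metric objects registered in
# meter_utils.set_metrics (f"{phase}_itm_loss" / f"{phase}_itm_accuracy");
# invoking them updates their accumulated state and returns the per-batch value
# logged below, while epoch_wrapup() later calls .compute() and .reset() on the
# same objects to produce the *_epoch summaries.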
pl_module.log(f"itm/{phase}/loss", loss) 82 | pl_module.log(f"itm/{phase}/accuracy", acc) 83 | 84 | return ret 85 | 86 | def compute_snli(pl_module, batch): 87 | infer = pl_module.infer( 88 | batch, mask_text=False, mask_image=False, 89 | ) 90 | snli_logits = pl_module.snli_classifier(infer["cls_feats"]) 91 | 92 | snli_labels = batch["labels"] 93 | snli_labels = torch.tensor(snli_labels).to(pl_module.device).long() 94 | snli_loss = F.cross_entropy(snli_logits, snli_labels.view(-1)) 95 | 96 | ret = { 97 | "snli_loss": snli_loss, 98 | "snli_logits": snli_logits, 99 | "snli_labels": snli_labels, 100 | } 101 | 102 | phase = "train" if pl_module.training else "val" 103 | 104 | if phase == "train": 105 | loss = getattr(pl_module, f"{phase}_snli_loss")(ret["snli_loss"]) 106 | acc = getattr(pl_module, f"{phase}_snli_accuracy")( 107 | ret["snli_logits"], ret["snli_labels"] 108 | ) 109 | pl_module.log(f"snli/{phase}/loss", loss) 110 | pl_module.log(f"snli/{phase}/accuracy", acc) 111 | else: 112 | dev_batches = [i for i, n in enumerate(batch["table_name"]) if "dev" in n] 113 | test_batches = [i for i, n in enumerate(batch["table_name"]) if "test" in n] 114 | 115 | if dev_batches: 116 | dev_loss = getattr(pl_module, f"dev_snli_loss")( 117 | F.cross_entropy( 118 | ret["snli_logits"][dev_batches], ret["snli_labels"][dev_batches] 119 | ) 120 | ) 121 | dev_acc = getattr(pl_module, f"dev_snli_accuracy")( 122 | ret["snli_logits"][dev_batches], ret["snli_labels"][dev_batches] 123 | ) 124 | pl_module.log(f"snli/dev/loss", dev_loss) 125 | pl_module.log(f"snli/dev/accuracy", dev_acc) 126 | if test_batches: 127 | test_loss = getattr(pl_module, f"test_snli_loss")( 128 | F.cross_entropy( 129 | ret["snli_logits"][test_batches], ret["snli_labels"][test_batches] 130 | ) 131 | ) 132 | test_acc = getattr(pl_module, f"test_snli_accuracy")( 133 | ret["snli_logits"][test_batches], ret["snli_labels"][test_batches] 134 | ) 135 | pl_module.log(f"snli/test/loss", test_loss) 136 | pl_module.log(f"snli/test/accuracy", test_acc) 137 | 138 | return ret 139 | 140 | def compute_vqa(pl_module, batch): 141 | infer = pl_module.infer(batch, mask_text=False, mask_image=False) 142 | vqa_logits = pl_module.vqa_classifier(infer["cls_feats"]) 143 | vqa_targets = torch.zeros( 144 | len(vqa_logits), pl_module.hparams.config["vqav2_label_size"] 145 | ).to(pl_module.device) 146 | 147 | vqa_labels = batch["vqa_labels"] 148 | vqa_scores = batch["vqa_scores"] 149 | 150 | for i, (_label, _score) in enumerate(zip(vqa_labels, vqa_scores)): 151 | for l, s in zip(_label, _score): 152 | vqa_targets[i, l] = s 153 | 154 | vqa_loss = ( 155 | F.binary_cross_entropy_with_logits(vqa_logits, vqa_targets) 156 | * vqa_targets.shape[1] 157 | ) # https://github.com/jnhwkim/ban-vqa/blob/master/train.py#L19 158 | 159 | ret = { 160 | "vqa_loss": vqa_loss, 161 | "vqa_logits": vqa_logits, 162 | "vqa_targets": vqa_targets, 163 | "vqa_labels": vqa_labels, 164 | "vqa_scores": vqa_scores, 165 | } 166 | 167 | phase = "train" if pl_module.training else "val" 168 | loss = getattr(pl_module, f"{phase}_vqa_loss")(ret["vqa_loss"]) 169 | score = getattr(pl_module, f"{phase}_vqa_score")( 170 | ret["vqa_logits"], ret["vqa_targets"] 171 | ) 172 | pl_module.log(f"vqa/{phase}/loss", loss) 173 | pl_module.log(f"vqa/{phase}/score", score) 174 | 175 | return ret 176 | 177 | 178 | def compute_nlvr2(pl_module, batch): 179 | infer1 = pl_module.infer( 180 | batch, mask_text=False, mask_image=False, image_token_type_idx=1 181 | ) 182 | infer2 = pl_module.infer( 183 | batch, mask_text=False, 
mask_image=False, image_token_type_idx=2 184 | ) 185 | 186 | cls_feats = torch.cat([infer1["cls_feats"], infer2["cls_feats"]], dim=-1) 187 | nlvr2_logits = pl_module.nlvr2_classifier(cls_feats) 188 | 189 | nlvr2_labels = batch["answers"] 190 | nlvr2_labels = torch.tensor(nlvr2_labels).to(pl_module.device).long() 191 | nlvr2_loss = F.cross_entropy(nlvr2_logits, nlvr2_labels.view(-1)) 192 | 193 | ret = { 194 | "nlvr2_loss": nlvr2_loss, 195 | "nlvr2_logits": nlvr2_logits, 196 | "nlvr2_labels": nlvr2_labels, 197 | } 198 | 199 | phase = "train" if pl_module.training else "val" 200 | 201 | if phase == "train": 202 | loss = getattr(pl_module, f"{phase}_nlvr2_loss")(ret["nlvr2_loss"]) 203 | acc = getattr(pl_module, f"{phase}_nlvr2_accuracy")( 204 | ret["nlvr2_logits"], ret["nlvr2_labels"] 205 | ) 206 | pl_module.log(f"nlvr2/{phase}/loss", loss) 207 | pl_module.log(f"nlvr2/{phase}/accuracy", acc) 208 | else: 209 | dev_batches = [i for i, n in enumerate(batch["table_name"]) if "dev" in n] 210 | test_batches = [i for i, n in enumerate(batch["table_name"]) if "test" in n] 211 | 212 | if dev_batches: 213 | dev_loss = getattr(pl_module, f"dev_nlvr2_loss")( 214 | F.cross_entropy( 215 | ret["nlvr2_logits"][dev_batches], ret["nlvr2_labels"][dev_batches] 216 | ) 217 | ) 218 | dev_acc = getattr(pl_module, f"dev_nlvr2_accuracy")( 219 | ret["nlvr2_logits"][dev_batches], ret["nlvr2_labels"][dev_batches] 220 | ) 221 | pl_module.log(f"nlvr2/dev/loss", dev_loss) 222 | pl_module.log(f"nlvr2/dev/accuracy", dev_acc) 223 | if test_batches: 224 | test_loss = getattr(pl_module, f"test_nlvr2_loss")( 225 | F.cross_entropy( 226 | ret["nlvr2_logits"][test_batches], ret["nlvr2_labels"][test_batches] 227 | ) 228 | ) 229 | test_acc = getattr(pl_module, f"test_nlvr2_accuracy")( 230 | ret["nlvr2_logits"][test_batches], ret["nlvr2_labels"][test_batches] 231 | ) 232 | pl_module.log(f"nlvr2/test/loss", test_loss) 233 | pl_module.log(f"nlvr2/test/accuracy", test_acc) 234 | 235 | return ret 236 | 237 | 238 | def compute_irtr(pl_module, batch): 239 | is_training_phase = pl_module.training 240 | 241 | _bs, _c, _h, _w = batch["image"][0].shape 242 | false_len = pl_module.hparams.config["draw_false_text"] 243 | text_ids = torch.stack( 244 | [batch[f"false_text_{i}_ids"] for i in range(false_len)], dim=1 245 | ) 246 | text_masks = torch.stack( 247 | [batch[f"false_text_{i}_masks"] for i in range(false_len)], dim=1 248 | ) 249 | text_labels = torch.stack( 250 | [batch[f"false_text_{i}_labels"] for i in range(false_len)], dim=1 251 | ) 252 | 253 | text_ids = torch.cat([batch["text_ids"].unsqueeze(1), text_ids], dim=1) 254 | text_masks = torch.cat([batch["text_masks"].unsqueeze(1), text_masks], dim=1) 255 | text_labels = torch.cat([batch["text_labels"].unsqueeze(1), text_labels], dim=1) 256 | images = batch["image"][0].unsqueeze(1).expand(_bs, false_len + 1, _c, _h, _w) 257 | 258 | infer = pl_module.infer( 259 | { 260 | "image": [rearrange(images, "bs fs c h w -> (bs fs) c h w")], 261 | "text_ids": rearrange(text_ids, "bs fs tl -> (bs fs) tl"), 262 | "text_masks": rearrange(text_masks, "bs fs tl -> (bs fs) tl"), 263 | "text_labels": rearrange(text_labels, "bs fs tl -> (bs fs) tl"), 264 | } 265 | ) 266 | score = pl_module.rank_output(infer["cls_feats"])[:, 0] 267 | score = rearrange(score, "(bs fs) -> bs fs", bs=_bs, fs=false_len + 1) 268 | answer = torch.zeros(_bs).to(score).long() 269 | irtr_loss = F.cross_entropy(score, answer) 270 | 271 | ret = { 272 | "irtr_loss": irtr_loss, 273 | } 274 | 275 | phase = "train" if pl_module.training else 
"val" 276 | irtr_loss = getattr(pl_module, f"{phase}_irtr_loss")(ret["irtr_loss"]) 277 | 278 | pl_module.log(f"irtr/{phase}/irtr_loss", irtr_loss) 279 | 280 | return ret 281 | 282 | 283 | @torch.no_grad() 284 | def compute_irtr_recall(pl_module): 285 | text_dset = pl_module.trainer.datamodule.dms[0].make_no_false_val_dset() 286 | text_dset.tokenizer = pl_module.trainer.datamodule.dms[0].tokenizer 287 | text_loader = torch.utils.data.DataLoader( 288 | text_dset, 289 | batch_size=64, 290 | num_workers=pl_module.hparams.config["num_workers"], 291 | pin_memory=True, 292 | collate_fn=functools.partial( 293 | text_dset.collate, 294 | mlm_collator=pl_module.trainer.datamodule.dms[0].mlm_collator, 295 | ), 296 | ) 297 | 298 | image_dset = pl_module.trainer.datamodule.dms[0].make_no_false_val_dset( 299 | image_only=True 300 | ) 301 | image_dset.tokenizer = pl_module.trainer.datamodule.dms[0].tokenizer 302 | dist_sampler = DistributedSampler(image_dset, shuffle=False) 303 | image_loader = torch.utils.data.DataLoader( 304 | image_dset, 305 | batch_size=1, 306 | num_workers=pl_module.hparams.config["num_workers"], 307 | sampler=dist_sampler, 308 | pin_memory=True, 309 | collate_fn=functools.partial( 310 | image_dset.collate, 311 | mlm_collator=pl_module.trainer.datamodule.dms[0].mlm_collator, 312 | ), 313 | ) 314 | 315 | #TODO: speed up the process by caching text/image features 316 | text_preload = list() 317 | for _b in tqdm.tqdm(text_loader, desc="text prefetch loop"): 318 | text_preload.append( 319 | { 320 | "text_ids": _b["text_ids"].to(pl_module.device), 321 | "text_masks": _b["text_masks"].to(pl_module.device), 322 | "text_labels": _b["text_labels"].to(pl_module.device), 323 | "img_index": _b["img_index"], 324 | } 325 | ) 326 | 327 | tiids = list() 328 | for pre in text_preload: 329 | tiids += pre["img_index"] 330 | tiids = torch.tensor(tiids) 331 | 332 | image_preload = list() 333 | for _b in tqdm.tqdm(image_loader, desc="image prefetch loop"): 334 | image_preload.append((_b['image'][0], _b["img_index"][0])) 335 | 336 | rank_scores = list() 337 | rank_iids = list() 338 | 339 | for img_batch in tqdm.tqdm(image_preload, desc="rank loop"): 340 | _im, _iid = img_batch 341 | 342 | img_batch_score = list() 343 | for txt_batch in text_preload: 344 | fblen = len(txt_batch["text_ids"]) 345 | im = _im.repeat(fblen, 1, 1, 1).to(device=txt_batch['text_ids'].device) 346 | 347 | with torch.cuda.amp.autocast(): 348 | score = pl_module.rank_output( 349 | pl_module.infer( 350 | { 351 | "text_ids": txt_batch["text_ids"], 352 | "text_masks": txt_batch["text_masks"], 353 | "text_labels": txt_batch["text_labels"], 354 | }, 355 | img=im, 356 | )["cls_feats"] 357 | )[:, 0] 358 | 359 | img_batch_score.append(score) 360 | 361 | img_batch_score = torch.cat(img_batch_score) 362 | rank_scores.append(img_batch_score.cpu().tolist()) 363 | rank_iids.append(_iid) 364 | 365 | torch.distributed.barrier() 366 | gather_rank_scores = all_gather(rank_scores) 367 | gather_rank_iids = all_gather(rank_iids) 368 | 369 | iids = torch.tensor(gather_rank_iids) 370 | iids = iids.view(-1) 371 | scores = torch.tensor(gather_rank_scores) 372 | scores = scores.view(len(iids), -1) 373 | 374 | topk10 = scores.topk(10, dim=1) 375 | topk5 = scores.topk(5, dim=1) 376 | topk1 = scores.topk(1, dim=1) 377 | topk10_iids = tiids[topk10.indices] 378 | topk5_iids = tiids[topk5.indices] 379 | topk1_iids = tiids[topk1.indices] 380 | 381 | tr_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean() 382 | tr_r5 = (iids.unsqueeze(1) == 
topk5_iids).float().max(dim=1)[0].mean() 383 | tr_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean() 384 | 385 | topk10 = scores.topk(10, dim=0) 386 | topk5 = scores.topk(5, dim=0) 387 | topk1 = scores.topk(1, dim=0) 388 | topk10_iids = iids[topk10.indices] 389 | topk5_iids = iids[topk5.indices] 390 | topk1_iids = iids[topk1.indices] 391 | 392 | ir_r10 = (tiids.unsqueeze(0) == topk10_iids).float().max(dim=0)[0].mean() 393 | ir_r5 = (tiids.unsqueeze(0) == topk5_iids).float().max(dim=0)[0].mean() 394 | ir_r1 = (tiids.unsqueeze(0) == topk1_iids).float().max(dim=0)[0].mean() 395 | 396 | return (ir_r1, ir_r5, ir_r10, tr_r1, tr_r5, tr_r10) 397 | 398 | 399 | def init_weights(module): 400 | if isinstance(module, (nn.Linear, nn.Embedding)): 401 | module.weight.data.normal_(mean=0.0, std=0.02) 402 | elif isinstance(module, nn.LayerNorm): 403 | module.bias.data.zero_() 404 | module.weight.data.fill_(1.0) 405 | 406 | if isinstance(module, nn.Linear) and module.bias is not None: 407 | module.bias.data.zero_() 408 | 409 | 410 | def vqa_test_step(pl_module, batch, output): 411 | try: 412 | id2answer = ( 413 | pl_module.trainer.datamodule.dm_dicts["vqa_trainval"].id2answer 414 | if "vqa_trainval" in pl_module.trainer.datamodule.dm_dicts 415 | else pl_module.trainer.datamodule.dm_dicts["vqa"].id2answer 416 | ) 417 | except: 418 | id2answer = ( 419 | pl_module.trainer.datamodule.dm_dicts["gqa_test"].id2answer 420 | if "gqa_test" in pl_module.trainer.datamodule.dm_dicts 421 | else pl_module.trainer.datamodule.dm_dicts["gqa"].id2answer 422 | ) 423 | vqa_logits = output["vqa_logits"] 424 | vqa_preds = vqa_logits.argmax(dim=-1) 425 | vqa_preds = [id2answer[pred.item()] for pred in vqa_preds] 426 | questions = batch["text"] 427 | qids = batch["qid"] 428 | return {"qids": qids, "preds": vqa_preds, "gqa": True} 429 | vqa_logits = output["vqa_logits"] 430 | vqa_preds = vqa_logits.argmax(dim=-1) 431 | vqa_preds = [id2answer[pred.item()] for pred in vqa_preds] 432 | questions = batch["text"] 433 | qids = batch["qid"] 434 | return {"qids": qids, "preds": vqa_preds, "gqa": False} 435 | 436 | 437 | def arc_test_step(pl_module, batch, output): 438 | return output 439 | 440 | 441 | def vqa_test_wrapup(outs, model_name): 442 | rank = torch.distributed.get_rank() 443 | qids, preds = list(), list() 444 | gqa = False 445 | for out in outs: 446 | qids += out["qids"] 447 | preds += out["preds"] 448 | gqa = out['gqa'] 449 | 450 | rets = list() 451 | for qid, pred in zip(qids, preds): 452 | if gqa: 453 | rets.append({"questionId": qid, "prediction": pred}) 454 | else: 455 | rets.append({"question_id": qid, "answer": pred}) 456 | with open(f"vqa_submit_{rank}.json", "w") as fp: 457 | json.dump(rets, fp, indent=4) 458 | 459 | torch.distributed.barrier() 460 | 461 | if rank == 0: 462 | jsons = list() 463 | paths = list(glob.glob("vqa_submit_*.json")) 464 | for path in paths: 465 | with open(path, "r") as fp: 466 | jsons += json.load(fp) 467 | os.makedirs("result", exist_ok=True) 468 | with open(f"result/vqa_submit_{model_name}.json", "w") as fp: 469 | json.dump(jsons, fp, indent=4) 470 | 471 | torch.distributed.barrier() 472 | os.remove(f"vqa_submit_{rank}.json") 473 | 474 | 475 | def arc_test_wrapup(outs, caplen, model_name): 476 | rank = torch.distributed.get_rank() 477 | iids, captions = list(), list() 478 | for out in outs: 479 | iids += out["iid"] 480 | captions += out["captions"] 481 | 482 | rets = list() 483 | for iid, caption in zip(iids, captions): 484 | rets.append({"image_id": iid, "caption": caption}) 
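# Same file-based gather as vqa_test_wrapup above: each rank writes its own
# coco_cap_len{caplen}_{rank}.json shard, a barrier synchronizes the ranks,
# rank 0 globs and merges the shards (sorted by image_id) into result/arc/,
# and a final barrier guards the per-rank cleanup.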
485 | with open(f"coco_cap_len{caplen}_{rank}.json", "w") as fp: 486 | json.dump(rets, fp, indent=4) 487 | 488 | torch.distributed.barrier() 489 | 490 | if rank == 0: 491 | jsons = list() 492 | paths = list(glob.glob(f"coco_cap_len{caplen}_*.json")) 493 | for path in paths: 494 | with open(path, "r") as fp: 495 | jsons += json.load(fp) 496 | os.makedirs("result/arc", exist_ok=True) 497 | jsons = sorted(jsons, key=lambda x: x["image_id"]) 498 | with open(f"result/arc/coco_cap_{model_name}_len{caplen}.json", "w") as fp: 499 | json.dump(jsons, fp, indent=4) 500 | 501 | torch.distributed.barrier() 502 | os.remove(f"coco_cap_len{caplen}_{rank}.json") 503 | -------------------------------------------------------------------------------- /meter/modules/swin_helpers.py: -------------------------------------------------------------------------------- 1 | """ Model creation / weight loading / state_dict helpers 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import logging 5 | import os 6 | import math 7 | from collections import OrderedDict 8 | from copy import deepcopy 9 | from typing import Any, Callable, Optional, Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | from timm.models.features import FeatureListNet, FeatureDictNet, FeatureHookNet 16 | from timm.models.hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url 17 | from timm.models.layers import Conv2dSame, Linear 18 | 19 | def swin_adapt_position_encoding(model, before=384, patch_size=32, after=384, 20 | suffix='relative_position_bias_table'): 21 | if after == before: 22 | return model 23 | grid_before = int(before/32) 24 | grid_after = int(after/32) #after // patch_size 25 | before = (2*grid_before-1) 26 | import math 27 | after = (2*grid_after-1) 28 | keys = [k for k in model if k.endswith(suffix)] 29 | assert len(keys) > 0 30 | for key in keys: 31 | pos_embed = model[key] 32 | pos_embed = pos_embed.transpose(0, 1).view(-1, before, before) 33 | pos_embed = torch.nn.functional.interpolate(pos_embed.unsqueeze(0), size=(after, after), mode='bicubic') 34 | pos_embed = pos_embed.squeeze(0).permute((1, 2, 0)) 35 | pos_embed = pos_embed.contiguous().view(-1, pos_embed.size(-1)) 36 | model[key] = pos_embed 37 | keys = [k for k in model if k.endswith('attn_mask')] 38 | for key in keys: 39 | model.pop(key) 40 | keys = [k for k in model if k.endswith('relative_position_index')] 41 | for key in keys: 42 | model.pop(key) 43 | 44 | return model 45 | 46 | 47 | _logger = logging.getLogger(__name__) 48 | 49 | 50 | def load_state_dict(checkpoint_path, use_ema=False): 51 | if checkpoint_path and os.path.isfile(checkpoint_path): 52 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 53 | state_dict_key = 'state_dict' 54 | if isinstance(checkpoint, dict): 55 | if use_ema and 'state_dict_ema' in checkpoint: 56 | state_dict_key = 'state_dict_ema' 57 | if state_dict_key and state_dict_key in checkpoint: 58 | new_state_dict = OrderedDict() 59 | for k, v in checkpoint[state_dict_key].items(): 60 | # strip `module.` prefix 61 | name = k[7:] if k.startswith('module') else k 62 | new_state_dict[name] = v 63 | state_dict = new_state_dict 64 | else: 65 | state_dict = checkpoint 66 | _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) 67 | return state_dict 68 | else: 69 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) 70 | raise FileNotFoundError() 71 | 72 | 73 | def load_checkpoint(model, checkpoint_path, use_ema=False, 
strict=True): 74 | if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): 75 | # numpy checkpoint, try to load via model specific load_pretrained fn 76 | if hasattr(model, 'load_pretrained'): 77 | model.load_pretrained(checkpoint_path) 78 | else: 79 | raise NotImplementedError('Model cannot load numpy checkpoint') 80 | return 81 | state_dict = load_state_dict(checkpoint_path, use_ema) 82 | model.load_state_dict(state_dict, strict=strict) 83 | 84 | 85 | def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): 86 | resume_epoch = None 87 | if os.path.isfile(checkpoint_path): 88 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 89 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 90 | if log_info: 91 | _logger.info('Restoring model state from checkpoint...') 92 | new_state_dict = OrderedDict() 93 | for k, v in checkpoint['state_dict'].items(): 94 | name = k[7:] if k.startswith('module') else k 95 | new_state_dict[name] = v 96 | model.load_state_dict(new_state_dict) 97 | 98 | if optimizer is not None and 'optimizer' in checkpoint: 99 | if log_info: 100 | _logger.info('Restoring optimizer state from checkpoint...') 101 | optimizer.load_state_dict(checkpoint['optimizer']) 102 | 103 | if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: 104 | if log_info: 105 | _logger.info('Restoring AMP loss scaler state from checkpoint...') 106 | loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) 107 | 108 | if 'epoch' in checkpoint: 109 | resume_epoch = checkpoint['epoch'] 110 | if 'version' in checkpoint and checkpoint['version'] > 1: 111 | resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save 112 | 113 | if log_info: 114 | _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) 115 | else: 116 | model.load_state_dict(checkpoint) 117 | if log_info: 118 | _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) 119 | return resume_epoch 120 | else: 121 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) 122 | raise FileNotFoundError() 123 | 124 | 125 | def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False): 126 | r"""Loads a custom (read non .pth) weight file 127 | Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls 128 | a passed in custom load fun, or the `load_pretrained` model member fn. 129 | If the object is already present in `model_dir`, it's deserialized and returned. 130 | The default value of `model_dir` is ``/checkpoints`` where 131 | `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. 132 | Args: 133 | model: The instantiated model to load weights into 134 | default_cfg (dict): Default pretrained model cfg 135 | load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named 136 | 'laod_pretrained' on the model will be called if it exists 137 | progress (bool, optional): whether or not to display a progress bar to stderr. Default: False 138 | check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention 139 | ``filename-.ext`` where ```` is the first eight or more 140 | digits of the SHA256 hash of the contents of the file. The hash is used to 141 | ensure unique names and to verify the contents of the file. 
Default: False 142 | """ 143 | default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {} 144 | pretrained_url = default_cfg.get('url', None) 145 | if not pretrained_url: 146 | _logger.warning("No pretrained weights exist for this model. Using random initialization.") 147 | return 148 | cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress) 149 | 150 | if load_fn is not None: 151 | load_fn(model, cached_file) 152 | elif hasattr(model, 'load_pretrained'): 153 | model.load_pretrained(cached_file) 154 | else: 155 | _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") 156 | 157 | 158 | def adapt_input_conv(in_chans, conv_weight): 159 | conv_type = conv_weight.dtype 160 | conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU 161 | O, I, J, K = conv_weight.shape 162 | if in_chans == 1: 163 | if I > 3: 164 | assert conv_weight.shape[1] % 3 == 0 165 | # For models with space2depth stems 166 | conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) 167 | conv_weight = conv_weight.sum(dim=2, keepdim=False) 168 | else: 169 | conv_weight = conv_weight.sum(dim=1, keepdim=True) 170 | elif in_chans != 3: 171 | if I != 3: 172 | raise NotImplementedError('Weight format not supported by conversion.') 173 | else: 174 | # NOTE this strategy should be better than random init, but there could be other combinations of 175 | # the original RGB input layer weights that'd work better for specific cases. 176 | repeat = int(math.ceil(in_chans / 3)) 177 | conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] 178 | conv_weight *= (3 / float(in_chans)) 179 | conv_weight = conv_weight.to(conv_type) 180 | return conv_weight 181 | 182 | 183 | 184 | def load_pretrained(model, img_size, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False, resolution_before=384): 185 | """ Load pretrained checkpoint 186 | Args: 187 | model (nn.Module) : PyTorch model module 188 | default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset 189 | num_classes (int): num_classes for model 190 | in_chans (int): in_chans for model 191 | filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) 192 | strict (bool): strict load of checkpoint 193 | progress (bool): enable progress bar for weight download 194 | """ 195 | default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {} 196 | pretrained_url = default_cfg.get('url', None) 197 | hf_hub_id = default_cfg.get('hf_hub', None) 198 | if not pretrained_url and not hf_hub_id: 199 | _logger.warning("No pretrained weights exist for this model. 
Using random initialization.") 200 | return 201 | if hf_hub_id and has_hf_hub(necessary=not pretrained_url): 202 | _logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})') 203 | state_dict = load_state_dict_from_hf(hf_hub_id) 204 | else: 205 | _logger.info(f'Loading pretrained weights from url ({pretrained_url})') 206 | state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu') 207 | swin_adapt_position_encoding(state_dict['model'], before=resolution_before, after=img_size) 208 | if filter_fn is not None: 209 | # for backwards compat with filter fn that take one arg, try one first, the two 210 | try: 211 | state_dict = filter_fn(state_dict) 212 | except TypeError: 213 | state_dict = filter_fn(state_dict, model) 214 | 215 | input_convs = default_cfg.get('first_conv', None) 216 | if input_convs is not None and in_chans != 3: 217 | if isinstance(input_convs, str): 218 | input_convs = (input_convs,) 219 | for input_conv_name in input_convs: 220 | weight_name = input_conv_name + '.weight' 221 | try: 222 | state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) 223 | _logger.info( 224 | f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') 225 | except NotImplementedError as e: 226 | del state_dict[weight_name] 227 | strict = False 228 | _logger.warning( 229 | f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') 230 | 231 | classifiers = default_cfg.get('classifier', None) 232 | label_offset = default_cfg.get('label_offset', 0) 233 | if classifiers is not None: 234 | if isinstance(classifiers, str): 235 | classifiers = (classifiers,) 236 | if num_classes != default_cfg['num_classes']: 237 | for classifier_name in classifiers: 238 | # completely discard fully connected if model num_classes doesn't match pretrained weights 239 | del state_dict[classifier_name + '.weight'] 240 | del state_dict[classifier_name + '.bias'] 241 | strict = False 242 | elif label_offset > 0: 243 | for classifier_name in classifiers: 244 | # special case for pretrained weights with an extra background class in pretrained weights 245 | classifier_weight = state_dict[classifier_name + '.weight'] 246 | state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] 247 | classifier_bias = state_dict[classifier_name + '.bias'] 248 | state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] 249 | 250 | model.load_state_dict(state_dict, strict=strict) 251 | 252 | 253 | def extract_layer(model, layer): 254 | layer = layer.split('.') 255 | module = model 256 | if hasattr(model, 'module') and layer[0] != 'module': 257 | module = model.module 258 | if not hasattr(model, 'module') and layer[0] == 'module': 259 | layer = layer[1:] 260 | for l in layer: 261 | if hasattr(module, l): 262 | if not l.isdigit(): 263 | module = getattr(module, l) 264 | else: 265 | module = module[int(l)] 266 | else: 267 | return module 268 | return module 269 | 270 | 271 | def set_layer(model, layer, val): 272 | layer = layer.split('.') 273 | module = model 274 | if hasattr(model, 'module') and layer[0] != 'module': 275 | module = model.module 276 | lst_index = 0 277 | module2 = module 278 | for l in layer: 279 | if hasattr(module2, l): 280 | if not l.isdigit(): 281 | module2 = getattr(module2, l) 282 | else: 283 | module2 = module2[int(l)] 284 | lst_index += 1 285 | lst_index -= 1 286 | for l in layer[:lst_index]: 287 | if not l.isdigit(): 288 | module = getattr(module, l) 
289 | else: 290 | module = module[int(l)] 291 | l = layer[lst_index] 292 | setattr(module, l, val) 293 | 294 | 295 | def adapt_model_from_string(parent_module, model_string): 296 | separator = '***' 297 | state_dict = {} 298 | lst_shape = model_string.split(separator) 299 | for k in lst_shape: 300 | k = k.split(':') 301 | key = k[0] 302 | shape = k[1][1:-1].split(',') 303 | if shape[0] != '': 304 | state_dict[key] = [int(i) for i in shape] 305 | 306 | new_module = deepcopy(parent_module) 307 | for n, m in parent_module.named_modules(): 308 | old_module = extract_layer(parent_module, n) 309 | if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): 310 | if isinstance(old_module, Conv2dSame): 311 | conv = Conv2dSame 312 | else: 313 | conv = nn.Conv2d 314 | s = state_dict[n + '.weight'] 315 | in_channels = s[1] 316 | out_channels = s[0] 317 | g = 1 318 | if old_module.groups > 1: 319 | in_channels = out_channels 320 | g = in_channels 321 | new_conv = conv( 322 | in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, 323 | bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, 324 | groups=g, stride=old_module.stride) 325 | set_layer(new_module, n, new_conv) 326 | if isinstance(old_module, nn.BatchNorm2d): 327 | new_bn = nn.BatchNorm2d( 328 | num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, 329 | affine=old_module.affine, track_running_stats=True) 330 | set_layer(new_module, n, new_bn) 331 | if isinstance(old_module, nn.Linear): 332 | # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? 333 | num_features = state_dict[n + '.weight'][1] 334 | new_fc = Linear( 335 | in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) 336 | set_layer(new_module, n, new_fc) 337 | if hasattr(new_module, 'num_features'): 338 | new_module.num_features = num_features 339 | new_module.eval() 340 | parent_module.eval() 341 | 342 | return new_module 343 | 344 | 345 | def adapt_model_from_file(parent_module, model_variant): 346 | adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') 347 | with open(adapt_file, 'r') as f: 348 | return adapt_model_from_string(parent_module, f.read().strip()) 349 | 350 | 351 | def default_cfg_for_features(default_cfg): 352 | default_cfg = deepcopy(default_cfg) 353 | # remove default pretrained cfg fields that don't have much relevance for feature backbone 354 | to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool') # add default final pool size? 355 | for tr in to_remove: 356 | default_cfg.pop(tr, None) 357 | return default_cfg 358 | 359 | 360 | def overlay_external_default_cfg(default_cfg, kwargs): 361 | """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg. 
362 | """ 363 | external_default_cfg = kwargs.pop('external_default_cfg', None) 364 | if external_default_cfg: 365 | default_cfg.pop('url', None) # url should come from external cfg 366 | default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg 367 | default_cfg.update(external_default_cfg) 368 | 369 | 370 | def set_default_kwargs(kwargs, names, default_cfg): 371 | for n in names: 372 | # for legacy reasons, model __init__args uses img_size + in_chans as separate args while 373 | # default_cfg has one input_size=(C, H ,W) entry 374 | if n == 'img_size': 375 | input_size = default_cfg.get('input_size', None) 376 | if input_size is not None: 377 | assert len(input_size) == 3 378 | kwargs.setdefault(n, input_size[-2:]) 379 | elif n == 'in_chans': 380 | input_size = default_cfg.get('input_size', None) 381 | if input_size is not None: 382 | assert len(input_size) == 3 383 | kwargs.setdefault(n, input_size[0]) 384 | else: 385 | default_val = default_cfg.get(n, None) 386 | if default_val is not None: 387 | kwargs.setdefault(n, default_cfg[n]) 388 | 389 | 390 | def filter_kwargs(kwargs, names): 391 | if not kwargs or not names: 392 | return 393 | for n in names: 394 | kwargs.pop(n, None) 395 | 396 | 397 | def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter): 398 | """ Update the default_cfg and kwargs before passing to model 399 | FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs 400 | could/should be replaced by an improved configuration mechanism 401 | Args: 402 | default_cfg: input default_cfg (updated in-place) 403 | kwargs: keyword args passed to model build fn (updated in-place) 404 | kwargs_filter: keyword arg keys that must be removed before model __init__ 405 | """ 406 | # Overlay default cfg values from `external_default_cfg` if it exists in kwargs 407 | overlay_external_default_cfg(default_cfg, kwargs) 408 | # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) 409 | default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') 410 | if default_cfg.get('fixed_input_size', False): 411 | # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size 412 | default_kwarg_names += ('img_size',) 413 | set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg) 414 | # Filter keyword args for task specific model variants (some 'features only' models, etc.) 
415 | filter_kwargs(kwargs, names=kwargs_filter) 416 | 417 | 418 | def swin_build_model_with_cfg( 419 | model_cls: Callable, 420 | variant: str, 421 | pretrained: bool, 422 | default_cfg: dict, 423 | model_cfg: Optional[Any] = None, 424 | feature_cfg: Optional[dict] = None, 425 | pretrained_strict: bool = True, 426 | pretrained_filter_fn: Optional[Callable] = None, 427 | pretrained_custom_load: bool = False, 428 | kwargs_filter: Optional[Tuple[str]] = None, 429 | **kwargs): 430 | """ Build model with specified default_cfg and optional model_cfg 431 | This helper fn aids in the construction of a model including: 432 | * handling default_cfg and associated pretained weight loading 433 | * passing through optional model_cfg for models with config based arch spec 434 | * features_only model adaptation 435 | * pruning config / model adaptation 436 | Args: 437 | model_cls (nn.Module): model class 438 | variant (str): model variant name 439 | pretrained (bool): load pretrained weights 440 | default_cfg (dict): model's default pretrained/task config 441 | model_cfg (Optional[Dict]): model's architecture config 442 | feature_cfg (Optional[Dict]: feature extraction adapter config 443 | pretrained_strict (bool): load pretrained weights strictly 444 | pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights 445 | pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights 446 | kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model 447 | **kwargs: model args passed through to model __init__ 448 | """ 449 | pruned = kwargs.pop('pruned', False) 450 | features = False 451 | feature_cfg = feature_cfg or {} 452 | default_cfg = deepcopy(default_cfg) if default_cfg else {} 453 | update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter) 454 | default_cfg.setdefault('architecture', variant) 455 | 456 | # Setup for feature extraction wrapper done at end of this fn 457 | if kwargs.pop('features_only', False): 458 | features = True 459 | feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) 460 | if 'out_indices' in kwargs: 461 | feature_cfg['out_indices'] = kwargs.pop('out_indices') 462 | 463 | # Build the model 464 | model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs) 465 | model.default_cfg = default_cfg 466 | 467 | if pruned: 468 | model = adapt_model_from_file(model, variant) 469 | 470 | # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats 471 | num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) 472 | if pretrained: 473 | if pretrained_custom_load: 474 | load_custom_pretrained(model) 475 | else: 476 | load_pretrained( 477 | model, 478 | num_classes=num_classes_pretrained, 479 | in_chans=kwargs.get('in_chans', 3), 480 | filter_fn=pretrained_filter_fn, 481 | img_size=kwargs['img_size'], 482 | strict=pretrained_strict, 483 | resolution_before=kwargs['config']['resolution_before']) 484 | # Wrap the model in a feature extraction module if enabled 485 | if features: 486 | feature_cls = FeatureListNet 487 | if 'feature_cls' in feature_cfg: 488 | feature_cls = feature_cfg.pop('feature_cls') 489 | if isinstance(feature_cls, str): 490 | feature_cls = feature_cls.lower() 491 | if 'hook' in feature_cls: 492 | feature_cls = FeatureHookNet 493 | else: 494 | assert False, f'Unknown feature class {feature_cls}' 495 | model = feature_cls(model, **feature_cfg) 496 | model.default_cfg = 
default_cfg_for_features(default_cfg) # add back default_cfg 497 | 498 | return model 499 | 500 | 501 | def model_parameters(model, exclude_head=False): 502 | if exclude_head: 503 | # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering 504 | return [p for p in model.parameters()][:-2] 505 | else: 506 | return model.parameters() 507 | 508 | 509 | def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: 510 | if not depth_first and include_root: 511 | fn(module=module, name=name) 512 | for child_name, child_module in module.named_children(): 513 | child_name = '.'.join((name, child_name)) if name else child_name 514 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) 515 | if depth_first and include_root: 516 | fn(module=module, name=name) 517 | return module 518 | 519 | 520 | def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): 521 | if not depth_first and include_root: 522 | yield name, module 523 | for child_name, child_module in module.named_children(): 524 | child_name = '.'.join((name, child_name)) if name else child_name 525 | yield from named_modules( 526 | module=child_module, name=child_name, depth_first=depth_first, include_root=True) 527 | if depth_first and include_root: 528 | yield name, module 529 | -------------------------------------------------------------------------------- /meter/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import ( 2 | pixelbert_transform, 3 | pixelbert_transform_randaug, 4 | vit_transform, 5 | vit_transform_randaug, 6 | imagenet_transform, 7 | imagenet_transform_randaug, 8 | clip_transform, 9 | clip_transform_randaug, 10 | ) 11 | 12 | _transforms = { 13 | "pixelbert": pixelbert_transform, 14 | "pixelbert_randaug": pixelbert_transform_randaug, 15 | "vit": vit_transform, 16 | "vit_randaug": vit_transform_randaug, 17 | "imagenet": imagenet_transform, 18 | "imagenet_randaug": imagenet_transform_randaug, 19 | "clip": clip_transform, 20 | "clip_randaug": clip_transform_randaug, 21 | } 22 | 23 | def keys_to_transforms(keys: list, size=224): 24 | return [_transforms[key](size=size) for key in keys] 25 | -------------------------------------------------------------------------------- /meter/transforms/randaug.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | 11 | def ShearX(img, v): # [-0.3, 0.3] 12 | assert -0.3 <= v <= 0.3 13 | if random.random() > 0.5: 14 | v = -v 15 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 16 | 17 | 18 | def ShearY(img, v): # [-0.3, 0.3] 19 | assert -0.3 <= v <= 0.3 20 | if random.random() > 0.5: 21 | v = -v 22 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 23 | 24 | 25 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 26 | assert -0.45 <= v <= 0.45 27 | if random.random() > 0.5: 28 | v = -v 29 | v = v * img.size[0] 30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 31 | 32 | 33 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 34 | assert 0 <= v 35 
| if random.random() > 0.5: 36 | v = -v 37 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 38 | 39 | 40 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 41 | assert -0.45 <= v <= 0.45 42 | if random.random() > 0.5: 43 | v = -v 44 | v = v * img.size[1] 45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 46 | 47 | 48 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 49 | assert 0 <= v 50 | if random.random() > 0.5: 51 | v = -v 52 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 53 | 54 | 55 | def Rotate(img, v): # [-30, 30] 56 | assert -30 <= v <= 30 57 | if random.random() > 0.5: 58 | v = -v 59 | return img.rotate(v) 60 | 61 | 62 | def AutoContrast(img, _): 63 | return PIL.ImageOps.autocontrast(img) 64 | 65 | 66 | def Invert(img, _): 67 | return PIL.ImageOps.invert(img) 68 | 69 | 70 | def Equalize(img, _): 71 | return PIL.ImageOps.equalize(img) 72 | 73 | 74 | def Flip(img, _): # not from the paper 75 | return PIL.ImageOps.mirror(img) 76 | 77 | 78 | def Solarize(img, v): # [0, 256] 79 | assert 0 <= v <= 256 80 | return PIL.ImageOps.solarize(img, v) 81 | 82 | 83 | def SolarizeAdd(img, addition=0, threshold=128): 84 | img_np = np.array(img).astype(np.int) 85 | img_np = img_np + addition 86 | img_np = np.clip(img_np, 0, 255) 87 | img_np = img_np.astype(np.uint8) 88 | img = Image.fromarray(img_np) 89 | return PIL.ImageOps.solarize(img, threshold) 90 | 91 | 92 | def Posterize(img, v): # [4, 8] 93 | v = int(v) 94 | v = max(1, v) 95 | return PIL.ImageOps.posterize(img, v) 96 | 97 | 98 | def Contrast(img, v): # [0.1,1.9] 99 | assert 0.1 <= v <= 1.9 100 | return PIL.ImageEnhance.Contrast(img).enhance(v) 101 | 102 | 103 | def Color(img, v): # [0.1,1.9] 104 | assert 0.1 <= v <= 1.9 105 | return PIL.ImageEnhance.Color(img).enhance(v) 106 | 107 | 108 | def Brightness(img, v): # [0.1,1.9] 109 | assert 0.1 <= v <= 1.9 110 | return PIL.ImageEnhance.Brightness(img).enhance(v) 111 | 112 | 113 | def Sharpness(img, v): # [0.1,1.9] 114 | assert 0.1 <= v <= 1.9 115 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 116 | 117 | 118 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 119 | assert 0.0 <= v <= 0.2 120 | if v <= 0.0: 121 | return img 122 | 123 | v = v * img.size[0] 124 | return CutoutAbs(img, v) 125 | 126 | 127 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 128 | # assert 0 <= v <= 20 129 | if v < 0: 130 | return img 131 | w, h = img.size 132 | x0 = np.random.uniform(w) 133 | y0 = np.random.uniform(h) 134 | 135 | x0 = int(max(0, x0 - v / 2.0)) 136 | y0 = int(max(0, y0 - v / 2.0)) 137 | x1 = min(w, x0 + v) 138 | y1 = min(h, y0 + v) 139 | 140 | xy = (x0, y0, x1, y1) 141 | color = (125, 123, 114) 142 | # color = (0, 0, 0) 143 | img = img.copy() 144 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 145 | return img 146 | 147 | 148 | def SamplePairing(imgs): # [0, 0.4] 149 | def f(img1, v): 150 | i = np.random.choice(len(imgs)) 151 | img2 = PIL.Image.fromarray(imgs[i]) 152 | return PIL.Image.blend(img1, img2, v) 153 | 154 | return f 155 | 156 | 157 | def Identity(img, v): 158 | return img 159 | 160 | 161 | def augment_list(): # 16 oeprations and their ranges 162 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 163 | # l = [ 164 | # (Identity, 0., 1.0), 165 | # (ShearX, 0., 0.3), # 0 166 | # (ShearY, 0., 0.3), # 1 167 | # (TranslateX, 0., 0.33), # 2 168 | # (TranslateY, 0., 0.33), # 3 169 | # (Rotate, 0, 30), # 4 170 | # (AutoContrast, 0, 1), # 5 
171 | # (Invert, 0, 1), # 6 172 | # (Equalize, 0, 1), # 7 173 | # (Solarize, 0, 110), # 8 174 | # (Posterize, 4, 8), # 9 175 | # # (Contrast, 0.1, 1.9), # 10 176 | # (Color, 0.1, 1.9), # 11 177 | # (Brightness, 0.1, 1.9), # 12 178 | # (Sharpness, 0.1, 1.9), # 13 179 | # # (Cutout, 0, 0.2), # 14 180 | # # (SamplePairing(imgs), 0, 0.4), # 15 181 | # ] 182 | 183 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 184 | l = [ 185 | (AutoContrast, 0, 1), 186 | (Equalize, 0, 1), 187 | # (Invert, 0, 1), 188 | (Rotate, 0, 30), 189 | (Posterize, 0, 4), 190 | (Solarize, 0, 256), 191 | (SolarizeAdd, 0, 110), 192 | (Color, 0.1, 1.9), 193 | (Contrast, 0.1, 1.9), 194 | (Brightness, 0.1, 1.9), 195 | (Sharpness, 0.1, 1.9), 196 | (ShearX, 0.0, 0.3), 197 | (ShearY, 0.0, 0.3), 198 | # (CutoutAbs, 0, 40), 199 | (TranslateXabs, 0.0, 100), 200 | (TranslateYabs, 0.0, 100), 201 | ] 202 | 203 | return l 204 | 205 | 206 | class Lighting(object): 207 | """Lighting noise(AlexNet - style PCA - based noise)""" 208 | 209 | def __init__(self, alphastd, eigval, eigvec): 210 | self.alphastd = alphastd 211 | self.eigval = torch.Tensor(eigval) 212 | self.eigvec = torch.Tensor(eigvec) 213 | 214 | def __call__(self, img): 215 | if self.alphastd == 0: 216 | return img 217 | 218 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 219 | rgb = ( 220 | self.eigvec.type_as(img) 221 | .clone() 222 | .mul(alpha.view(1, 3).expand(3, 3)) 223 | .mul(self.eigval.view(1, 3).expand(3, 3)) 224 | .sum(1) 225 | .squeeze() 226 | ) 227 | 228 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 229 | 230 | 231 | class CutoutDefault(object): 232 | """ 233 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 234 | """ 235 | 236 | def __init__(self, length): 237 | self.length = length 238 | 239 | def __call__(self, img): 240 | h, w = img.size(1), img.size(2) 241 | mask = np.ones((h, w), np.float32) 242 | y = np.random.randint(h) 243 | x = np.random.randint(w) 244 | 245 | y1 = np.clip(y - self.length // 2, 0, h) 246 | y2 = np.clip(y + self.length // 2, 0, h) 247 | x1 = np.clip(x - self.length // 2, 0, w) 248 | x2 = np.clip(x + self.length // 2, 0, w) 249 | 250 | mask[y1:y2, x1:x2] = 0.0 251 | mask = torch.from_numpy(mask) 252 | mask = mask.expand_as(img) 253 | img *= mask 254 | return img 255 | 256 | 257 | class RandAugment: 258 | def __init__(self, n, m): 259 | self.n = n 260 | self.m = m # [0, 30] 261 | self.augment_list = augment_list() 262 | 263 | def __call__(self, img): 264 | ops = random.choices(self.augment_list, k=self.n) 265 | for op, minval, maxval in ops: 266 | val = (float(self.m) / 30) * float(maxval - minval) + minval 267 | img = op(img, val) 268 | 269 | return img 270 | -------------------------------------------------------------------------------- /meter/transforms/transform.py: -------------------------------------------------------------------------------- 1 | from .utils import ( 2 | inception_normalize, 3 | imagenet_normalize, 4 | MinMaxResize, 5 | ) 6 | from PIL import Image 7 | from torchvision import transforms 8 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 9 | from .randaug import RandAugment 10 | 11 | 12 | def pixelbert_transform(size=800): 13 | longer = int((1333 / 800) * size) 14 | return transforms.Compose( 15 | [ 16 | MinMaxResize(shorter=size, longer=longer), 17 | transforms.ToTensor(), 18 | inception_normalize, 19 | ] 20 | ) 21 | 22 | def pixelbert_transform_randaug(size=800): 
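# Every *_randaug variant in this file follows the same pattern: build the plain
# pipeline, then trs.transforms.insert(0, RandAugment(2, 9)) so that two
# RandAugment ops at magnitude 9 (of 30) run on the PIL image before
# resizing/ToTensor/normalization. Callers select a pipeline by key, e.g.
# keys_to_transforms(["clip_randaug"], size=288) via meter/transforms/__init__.py.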
23 | longer = int((1333 / 800) * size) 24 | trs = transforms.Compose( 25 | [ 26 | MinMaxResize(shorter=size, longer=longer), 27 | transforms.ToTensor(), 28 | inception_normalize, 29 | ] 30 | ) 31 | trs.transforms.insert(0, RandAugment(2, 9)) 32 | return trs 33 | 34 | def imagenet_transform(size=800): 35 | return transforms.Compose( 36 | [ 37 | Resize(size, interpolation=Image.BICUBIC), 38 | CenterCrop(size), 39 | transforms.ToTensor(), 40 | imagenet_normalize, 41 | ] 42 | ) 43 | 44 | def imagenet_transform_randaug(size=800): 45 | trs = transforms.Compose( 46 | [ 47 | Resize(size, interpolation=Image.BICUBIC), 48 | CenterCrop(size), 49 | transforms.ToTensor(), 50 | imagenet_normalize, 51 | ] 52 | ) 53 | trs.transforms.insert(0, RandAugment(2, 9)) 54 | return trs 55 | 56 | def vit_transform(size=800): 57 | return transforms.Compose( 58 | [ 59 | Resize(size, interpolation=Image.BICUBIC), 60 | CenterCrop(size), 61 | transforms.ToTensor(), 62 | inception_normalize, 63 | ] 64 | ) 65 | 66 | def vit_transform_randaug(size=800): 67 | trs = transforms.Compose( 68 | [ 69 | Resize(size, interpolation=Image.BICUBIC), 70 | CenterCrop(size), 71 | transforms.ToTensor(), 72 | inception_normalize, 73 | ] 74 | ) 75 | trs.transforms.insert(0, RandAugment(2, 9)) 76 | return trs 77 | 78 | def clip_transform(size): 79 | return Compose([ 80 | Resize(size, interpolation=Image.BICUBIC), 81 | CenterCrop(size), 82 | lambda image: image.convert("RGB"), 83 | ToTensor(), 84 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 85 | ]) 86 | 87 | def clip_transform_randaug(size): 88 | trs = Compose([ 89 | Resize(size, interpolation=Image.BICUBIC), 90 | CenterCrop(size), 91 | lambda image: image.convert("RGB"), 92 | ToTensor(), 93 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 94 | ]) 95 | trs.transforms.insert(0, lambda image: image.convert('RGBA')) 96 | trs.transforms.insert(0, RandAugment(2, 9)) 97 | trs.transforms.insert(0, lambda image: image.convert('RGB')) 98 | return trs 99 | 100 | -------------------------------------------------------------------------------- /meter/transforms/utils.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | 4 | 5 | class MinMaxResize: 6 | def __init__(self, shorter=800, longer=1333): 7 | self.min = shorter 8 | self.max = longer 9 | 10 | def __call__(self, x): 11 | w, h = x.size 12 | scale = self.min / min(w, h) 13 | if h < w: 14 | newh, neww = self.min, scale * w 15 | else: 16 | newh, neww = scale * h, self.min 17 | 18 | if max(newh, neww) > self.max: 19 | scale = self.max / max(newh, neww) 20 | newh = newh * scale 21 | neww = neww * scale 22 | 23 | newh, neww = int(newh + 0.5), int(neww + 0.5) 24 | newh, neww = newh // 32 * 32, neww // 32 * 32 25 | 26 | return x.resize((neww, newh), resample=Image.BICUBIC) 27 | 28 | 29 | class UnNormalize(object): 30 | def __init__(self, mean, std): 31 | self.mean = mean 32 | self.std = std 33 | 34 | def __call__(self, tensor): 35 | """ 36 | Args: 37 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 38 | Returns: 39 | Tensor: Normalized image. 
40 | """ 41 | for t, m, s in zip(tensor, self.mean, self.std): 42 | t.mul_(s).add_(m) 43 | # The normalize code -> t.sub_(m).div_(s) 44 | return tensor 45 | 46 | 47 | # This is simple maximum entropy normalization performed in Inception paper 48 | inception_normalize = transforms.Compose( 49 | [transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 50 | ) 51 | 52 | # ViT uses simple non-biased inception normalization 53 | # https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py#L132 54 | inception_unnormalize = transforms.Compose( 55 | [UnNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])] 56 | ) 57 | 58 | # ImageNet normalize 59 | imagenet_normalize = transforms.Compose( 60 | [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] 61 | ) 62 | -------------------------------------------------------------------------------- /meter/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zdou0830/METER/f4f09345b26ee21add0a756d06598e3c04726345/meter/utils/__init__.py -------------------------------------------------------------------------------- /meter/utils/glossary.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | contractions = { 4 | "aint": "ain't", 5 | "arent": "aren't", 6 | "cant": "can't", 7 | "couldve": "could've", 8 | "couldnt": "couldn't", 9 | "couldn'tve": "couldn't've", 10 | "couldnt've": "couldn't've", 11 | "didnt": "didn't", 12 | "doesnt": "doesn't", 13 | "dont": "don't", 14 | "hadnt": "hadn't", 15 | "hadnt've": "hadn't've", 16 | "hadn'tve": "hadn't've", 17 | "hasnt": "hasn't", 18 | "havent": "haven't", 19 | "hed": "he'd", 20 | "hed've": "he'd've", 21 | "he'dve": "he'd've", 22 | "hes": "he's", 23 | "howd": "how'd", 24 | "howll": "how'll", 25 | "hows": "how's", 26 | "Id've": "I'd've", 27 | "I'dve": "I'd've", 28 | "Im": "I'm", 29 | "Ive": "I've", 30 | "isnt": "isn't", 31 | "itd": "it'd", 32 | "itd've": "it'd've", 33 | "it'dve": "it'd've", 34 | "itll": "it'll", 35 | "let's": "let's", 36 | "maam": "ma'am", 37 | "mightnt": "mightn't", 38 | "mightnt've": "mightn't've", 39 | "mightn'tve": "mightn't've", 40 | "mightve": "might've", 41 | "mustnt": "mustn't", 42 | "mustve": "must've", 43 | "neednt": "needn't", 44 | "notve": "not've", 45 | "oclock": "o'clock", 46 | "oughtnt": "oughtn't", 47 | "ow's'at": "'ow's'at", 48 | "'ows'at": "'ow's'at", 49 | "'ow'sat": "'ow's'at", 50 | "shant": "shan't", 51 | "shed've": "she'd've", 52 | "she'dve": "she'd've", 53 | "she's": "she's", 54 | "shouldve": "should've", 55 | "shouldnt": "shouldn't", 56 | "shouldnt've": "shouldn't've", 57 | "shouldn'tve": "shouldn't've", 58 | "somebody'd": "somebodyd", 59 | "somebodyd've": "somebody'd've", 60 | "somebody'dve": "somebody'd've", 61 | "somebodyll": "somebody'll", 62 | "somebodys": "somebody's", 63 | "someoned": "someone'd", 64 | "someoned've": "someone'd've", 65 | "someone'dve": "someone'd've", 66 | "someonell": "someone'll", 67 | "someones": "someone's", 68 | "somethingd": "something'd", 69 | "somethingd've": "something'd've", 70 | "something'dve": "something'd've", 71 | "somethingll": "something'll", 72 | "thats": "that's", 73 | "thered": "there'd", 74 | "thered've": "there'd've", 75 | "there'dve": "there'd've", 76 | "therere": "there're", 77 | "theres": "there's", 78 | "theyd": "they'd", 79 | "theyd've": "they'd've", 80 | "they'dve": "they'd've", 81 | "theyll": "they'll", 82 | "theyre": "they're", 83 | 
"theyve": "they've", 84 | "twas": "'twas", 85 | "wasnt": "wasn't", 86 | "wed've": "we'd've", 87 | "we'dve": "we'd've", 88 | "weve": "we've", 89 | "werent": "weren't", 90 | "whatll": "what'll", 91 | "whatre": "what're", 92 | "whats": "what's", 93 | "whatve": "what've", 94 | "whens": "when's", 95 | "whered": "where'd", 96 | "wheres": "where's", 97 | "whereve": "where've", 98 | "whod": "who'd", 99 | "whod've": "who'd've", 100 | "who'dve": "who'd've", 101 | "wholl": "who'll", 102 | "whos": "who's", 103 | "whove": "who've", 104 | "whyll": "why'll", 105 | "whyre": "why're", 106 | "whys": "why's", 107 | "wont": "won't", 108 | "wouldve": "would've", 109 | "wouldnt": "wouldn't", 110 | "wouldnt've": "wouldn't've", 111 | "wouldn'tve": "wouldn't've", 112 | "yall": "y'all", 113 | "yall'll": "y'all'll", 114 | "y'allll": "y'all'll", 115 | "yall'd've": "y'all'd've", 116 | "y'alld've": "y'all'd've", 117 | "y'all'dve": "y'all'd've", 118 | "youd": "you'd", 119 | "youd've": "you'd've", 120 | "you'dve": "you'd've", 121 | "youll": "you'll", 122 | "youre": "you're", 123 | "youve": "you've", 124 | } 125 | 126 | manual_map = { 127 | "none": "0", 128 | "zero": "0", 129 | "one": "1", 130 | "two": "2", 131 | "three": "3", 132 | "four": "4", 133 | "five": "5", 134 | "six": "6", 135 | "seven": "7", 136 | "eight": "8", 137 | "nine": "9", 138 | "ten": "10", 139 | } 140 | articles = ["a", "an", "the"] 141 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 142 | comma_strip = re.compile("(\d)(\,)(\d)") 143 | punct = [ 144 | ";", 145 | r"/", 146 | "[", 147 | "]", 148 | '"', 149 | "{", 150 | "}", 151 | "(", 152 | ")", 153 | "=", 154 | "+", 155 | "\\", 156 | "_", 157 | "-", 158 | ">", 159 | "<", 160 | "@", 161 | "`", 162 | ",", 163 | "?", 164 | "!", 165 | ] 166 | 167 | 168 | def normalize_word(token): 169 | _token = token 170 | for p in punct: 171 | if (p + " " in token or " " + p in token) or ( 172 | re.search(comma_strip, token) != None 173 | ): 174 | _token = _token.replace(p, "") 175 | else: 176 | _token = _token.replace(p, " ") 177 | token = period_strip.sub("", _token, re.UNICODE) 178 | 179 | _token = [] 180 | temp = token.lower().split() 181 | for word in temp: 182 | word = manual_map.setdefault(word, word) 183 | if word not in articles: 184 | _token.append(word) 185 | for i, word in enumerate(_token): 186 | if word in contractions: 187 | _token[i] = contractions[word] 188 | token = " ".join(_token) 189 | token = token.replace(",", "") 190 | return token 191 | -------------------------------------------------------------------------------- /meter/utils/write_coco_karpathy.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pandas as pd 4 | import pyarrow as pa 5 | import random 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict 10 | 11 | 12 | def path2rest(path, iid2captions, iid2split): 13 | name = path.split("/")[-1] 14 | with open(path, "rb") as fp: 15 | binary = fp.read() 16 | captions = iid2captions[name] 17 | split = iid2split[name] 18 | return [binary, captions, name, split] 19 | 20 | 21 | def make_arrow(root, dataset_root): 22 | with open(f"{root}/karpathy/dataset_coco.json", "r") as fp: 23 | captions = json.load(fp) 24 | 25 | captions = captions["images"] 26 | 27 | iid2captions = defaultdict(list) 28 | iid2split = dict() 29 | 30 | for cap in tqdm(captions): 31 | filename = cap["filename"] 32 | iid2split[filename] = cap["split"] 33 | for c in cap["sentences"]: 34 | 
iid2captions[filename].append(c["raw"]) 35 | 36 | paths = list(glob(f"{root}/train2014/*.jpg")) + list(glob(f"{root}/val2014/*.jpg")) 37 | random.shuffle(paths) 38 | caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions] 39 | 40 | if len(paths) == len(caption_paths): 41 | print("all images have caption annotations") 42 | else: 43 | print("not all images have caption annotations") 44 | print( 45 | len(paths), len(caption_paths), len(iid2captions), 46 | ) 47 | 48 | bs = [path2rest(path, iid2captions, iid2split) for path in tqdm(caption_paths)] 49 | 50 | for split in ["train", "val", "restval", "test"]: 51 | batches = [b for b in bs if b[-1] == split] 52 | 53 | dataframe = pd.DataFrame( 54 | batches, columns=["image", "caption", "image_id", "split"], 55 | ) 56 | 57 | table = pa.Table.from_pandas(dataframe) 58 | os.makedirs(dataset_root, exist_ok=True) 59 | with pa.OSFile( 60 | f"{dataset_root}/coco_caption_karpathy_{split}.arrow", "wb" 61 | ) as sink: 62 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 63 | writer.write_table(table) 64 | -------------------------------------------------------------------------------- /meter/utils/write_conceptual_caption.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import gc 5 | import random 6 | import os 7 | 8 | from tqdm import tqdm 9 | from glob import glob 10 | 11 | 12 | def path2rest(path, iid2captions): 13 | split, _, name = path.split("/")[-3:] 14 | split = split.split("_")[-1] 15 | iid = name 16 | 17 | with open(path, "rb") as fp: 18 | binary = fp.read() 19 | 20 | captions = iid2captions[iid] 21 | 22 | return [ 23 | binary, 24 | captions, 25 | iid, 26 | split, 27 | ] 28 | 29 | 30 | def make_arrow(root, dataset_root): 31 | for split in ["val", "train"]: 32 | with open(f"{root}/{split}_annot.json", "r") as fp: 33 | captions = json.load(fp) 34 | 35 | iid2captions = dict() 36 | for cap in tqdm(captions): 37 | iid = cap[0].split("/")[-1] 38 | iid2captions[iid] = [cap[1]] 39 | 40 | paths = list(glob(f"{root}/images_{split}/*/*")) 41 | random.shuffle(paths) 42 | caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions] 43 | if len(paths) == len(caption_paths): 44 | print("all images have caption annotations") 45 | else: 46 | print("not all images have caption annotations") 47 | print( 48 | len(paths), len(caption_paths), len(iid2captions), 49 | ) 50 | 51 | sub_len = int(len(caption_paths) // 100000) 52 | subs = list(range(sub_len + 1)) 53 | for sub in subs: 54 | sub_paths = caption_paths[sub * 100000 : (sub + 1) * 100000] 55 | bs = [path2rest(path, iid2captions) for path in tqdm(sub_paths)] 56 | dataframe = pd.DataFrame( 57 | bs, columns=["image", "caption", "image_id", "split"], 58 | ) 59 | 60 | table = pa.Table.from_pandas(dataframe) 61 | 62 | os.makedirs(dataset_root, exist_ok=True) 63 | with pa.OSFile( 64 | f"{dataset_root}/conceptual_caption_{split}_{sub}.arrow", "wb" 65 | ) as sink: 66 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 67 | writer.write_table(table) 68 | del dataframe 69 | del table 70 | del bs 71 | gc.collect() 72 | -------------------------------------------------------------------------------- /meter/utils/write_f30k_karpathy.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob 
import glob 9 | from collections import defaultdict 10 | 11 | 12 | def path2rest(path, iid2captions, iid2split): 13 | name = path.split("/")[-1] 14 | 15 | with open(path, "rb") as fp: 16 | binary = fp.read() 17 | 18 | captions = iid2captions[name] 19 | split = iid2split[name] 20 | 21 | return [binary, captions, name, split] 22 | 23 | 24 | def make_arrow(root, dataset_root): 25 | with open(f"{root}/karpathy/dataset_flickr30k.json", "r") as fp: 26 | captions = json.load(fp) 27 | 28 | captions = captions["images"] 29 | 30 | iid2captions = defaultdict(list) 31 | iid2split = dict() 32 | 33 | for cap in tqdm(captions): 34 | filename = cap["filename"] 35 | iid2split[filename] = cap["split"] 36 | for c in cap["sentences"]: 37 | iid2captions[filename].append(c["raw"]) 38 | 39 | paths = list(glob(f"{root}/flickr30k-images/*.jpg")) 40 | random.shuffle(paths) 41 | caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions] 42 | 43 | if len(paths) == len(caption_paths): 44 | print("all images have caption annotations") 45 | else: 46 | print("not all images have caption annotations") 47 | print( 48 | len(paths), len(caption_paths), len(iid2captions), 49 | ) 50 | 51 | bs = [path2rest(path, iid2captions, iid2split) for path in tqdm(caption_paths)] 52 | 53 | for split in ["train", "val", "test"]: 54 | batches = [b for b in bs if b[-1] == split] 55 | 56 | dataframe = pd.DataFrame( 57 | batches, columns=["image", "caption", "image_id", "split"], 58 | ) 59 | 60 | table = pa.Table.from_pandas(dataframe) 61 | 62 | os.makedirs(dataset_root, exist_ok=True) 63 | with pa.OSFile( 64 | f"{dataset_root}/f30k_caption_karpathy_{split}.arrow", "wb" 65 | ) as sink: 66 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 67 | writer.write_table(table) 68 | -------------------------------------------------------------------------------- /meter/utils/write_nlvr2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import os 5 | 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | 10 | def process(root, iden, row): 11 | texts = [r["sentence"] for r in row] 12 | labels = [r["label"] for r in row] 13 | 14 | split = iden.split("-")[0] 15 | 16 | if iden.startswith("train"): 17 | directory = row[0]["directory"] 18 | path = f"{root}/images/train/{directory}/{iden}" 19 | else: 20 | path = f"{root}/{split}/{iden}" 21 | 22 | with open(f"{path}-img0.png", "rb") as fp: 23 | img0 = fp.read() 24 | with open(f"{path}-img1.png", "rb") as fp: 25 | img1 = fp.read() 26 | 27 | return [img0, img1, texts, labels, iden] 28 | 29 | 30 | def make_arrow(root, dataset_root): 31 | train_data = list( 32 | map(json.loads, open(f"{root}/nlvr2/data/train.json").readlines()) 33 | ) 34 | test1_data = list( 35 | map(json.loads, open(f"{root}/nlvr2/data/test1.json").readlines()) 36 | ) 37 | dev_data = list(map(json.loads, open(f"{root}/nlvr2/data/dev.json").readlines())) 38 | 39 | balanced_test1_data = list( 40 | map( 41 | json.loads, 42 | open(f"{root}/nlvr2/data/balanced/balanced_test1.json").readlines(), 43 | ) 44 | ) 45 | balanced_dev_data = list( 46 | map( 47 | json.loads, 48 | open(f"{root}/nlvr2/data/balanced/balanced_dev.json").readlines(), 49 | ) 50 | ) 51 | 52 | unbalanced_test1_data = list( 53 | map( 54 | json.loads, 55 | open(f"{root}/nlvr2/data/unbalanced/unbalanced_test1.json").readlines(), 56 | ) 57 | ) 58 | unbalanced_dev_data = list( 59 | map( 60 | json.loads, 61 | 
open(f"{root}/nlvr2/data/unbalanced/unbalanced_dev.json").readlines(), 62 | ) 63 | ) 64 | 65 | splits = [ 66 | "train", 67 | "dev", 68 | "test1", 69 | "balanced_dev", 70 | "balanced_test1", 71 | "unbalanced_dev", 72 | "unbalanced_test1", 73 | ] 74 | 75 | datas = [ 76 | train_data, 77 | dev_data, 78 | test1_data, 79 | balanced_dev_data, 80 | balanced_test1_data, 81 | unbalanced_dev_data, 82 | unbalanced_test1_data, 83 | ] 84 | 85 | annotations = dict() 86 | 87 | for split, data in zip(splits, datas): 88 | _annot = defaultdict(list) 89 | for row in tqdm(data): 90 | _annot["-".join(row["identifier"].split("-")[:-1])].append(row) 91 | annotations[split] = _annot 92 | 93 | for split in splits: 94 | bs = [ 95 | process(root, iden, row) for iden, row in tqdm(annotations[split].items()) 96 | ] 97 | 98 | dataframe = pd.DataFrame( 99 | bs, columns=["image_0", "image_1", "questions", "answers", "identifier"], 100 | ) 101 | 102 | table = pa.Table.from_pandas(dataframe) 103 | 104 | os.makedirs(dataset_root, exist_ok=True) 105 | with pa.OSFile(f"{dataset_root}/nlvr2_{split}.arrow", "wb") as sink: 106 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 107 | writer.write_table(table) 108 | -------------------------------------------------------------------------------- /meter/utils/write_sbu.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import gc 5 | import random 6 | import os 7 | 8 | from tqdm import tqdm 9 | from glob import glob 10 | 11 | 12 | def path2rest(path, iid2captions): 13 | split, _, name = path.split("/")[-3:] 14 | split = split.split("_")[-1] 15 | iid = name 16 | 17 | with open(path, "rb") as fp: 18 | binary = fp.read() 19 | 20 | captions = iid2captions[iid] 21 | 22 | return [ 23 | binary, 24 | captions, 25 | iid, 26 | split, 27 | ] 28 | 29 | 30 | def make_arrow(root, dataset_root): 31 | with open(f"{root}/annot.json", "r") as fp: 32 | captions = json.load(fp) 33 | 34 | iid2captions = dict() 35 | for cap in tqdm(captions): 36 | iid = cap[0].split("/")[-1] 37 | iid2captions[iid] = [cap[1]] 38 | 39 | paths = list(glob(f"{root}/images_train/*/*")) 40 | random.shuffle(paths) 41 | caption_paths = [path for path in paths if path.split("/")[-1] in iid2captions] 42 | if len(paths) == len(caption_paths): 43 | print("all images have caption annotations") 44 | else: 45 | print("not all images have caption annotations") 46 | print( 47 | len(paths), len(caption_paths), len(iid2captions), 48 | ) 49 | 50 | sub_len = int(len(caption_paths) // 100000) 51 | subs = list(range(sub_len + 1)) 52 | for sub in subs: 53 | sub_paths = caption_paths[sub * 100000 : (sub + 1) * 100000] 54 | bs = [path2rest(path, iid2captions) for path in tqdm(sub_paths)] 55 | dataframe = pd.DataFrame(bs, columns=["image", "caption", "image_id", "split"],) 56 | 57 | table = pa.Table.from_pandas(dataframe) 58 | 59 | os.makedirs(dataset_root, exist_ok=True) 60 | with pa.OSFile(f"{dataset_root}/sbu_{sub}.arrow", "wb") as sink: 61 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 62 | writer.write_table(table) 63 | del dataframe 64 | del table 65 | del bs 66 | gc.collect() 67 | -------------------------------------------------------------------------------- /meter/utils/write_snli.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import os 5 | 6 | from tqdm import tqdm 7 | from collections import defaultdict 
8 | 9 | 10 | label2id = {'contradiction': 0, 'neutral': 1, 'entailment': 2} 11 | def process(root, imgid, ann): 12 | with open(f"{root}/Flickr30K/images/{imgid}.jpg", "rb") as fp: 13 | img = fp.read() 14 | 15 | sentences = ann['sentences'] 16 | 17 | labels = ann['labels'] 18 | 19 | return [img, sentences, labels] 20 | 21 | 22 | 23 | def make_arrow(root, dataset_root): 24 | train_data = list( 25 | map(json.loads, open(f"{root}/snli_ve_train.jsonl").readlines()) 26 | ) 27 | test_data = list( 28 | map(json.loads, open(f"{root}/snli_ve_test.jsonl").readlines()) 29 | ) 30 | dev_data = list( 31 | map(json.loads, open(f"{root}/snli_ve_dev.jsonl").readlines()) 32 | ) 33 | 34 | 35 | splits = [ 36 | "train", 37 | "dev", 38 | "test", 39 | ] 40 | 41 | 42 | annotations = dict() 43 | annotations['train'] = train_data 44 | annotations['dev'] = dev_data 45 | annotations['test'] = test_data 46 | annots = dict() 47 | for split in splits: 48 | annots[split] = {} 49 | for line in annotations[split]: 50 | imgid = line['Flickr30K_ID'] 51 | if not imgid in annots[split]: 52 | annots[split][imgid] = {} 53 | annots[split][imgid]['sentences'] = [] 54 | annots[split][imgid]['labels'] = [] 55 | annots[split][imgid]['sentences'].append( [line['sentence1'], line['sentence2']] ) 56 | annots[split][imgid]['labels'].append( label2id[line['gold_label']] ) 57 | 58 | 59 | 60 | for split in splits: 61 | bs = [process(root, imgid, annots[split][imgid]) for imgid in tqdm(annots[split])] 62 | 63 | dataframe = pd.DataFrame( 64 | bs, columns=["image", "sentences", "labels"] 65 | ) 66 | 67 | table = pa.Table.from_pandas(dataframe) 68 | 69 | os.makedirs(dataset_root, exist_ok=True) 70 | with pa.OSFile(f"{dataset_root}/snli_{split}.arrow", "wb") as sink: 71 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 72 | writer.write_table(table) 73 | -------------------------------------------------------------------------------- /meter/utils/write_vg.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict 10 | 11 | 12 | def path2rest(path, iid2captions): 13 | name = path.split("/")[-1] 14 | iid = int(name[:-4]) 15 | 16 | with open(path, "rb") as fp: 17 | binary = fp.read() 18 | 19 | cdicts = iid2captions[iid] 20 | captions = [c["phrase"] for c in cdicts] 21 | widths = [c["width"] for c in cdicts] 22 | heights = [c["height"] for c in cdicts] 23 | xs = [c["x"] for c in cdicts] 24 | ys = [c["y"] for c in cdicts] 25 | 26 | return [ 27 | binary, 28 | captions, 29 | widths, 30 | heights, 31 | xs, 32 | ys, 33 | str(iid), 34 | ] 35 | 36 | 37 | def make_arrow(root, dataset_root): 38 | with open(f"{root}/annotations/region_descriptions.json", "r") as fp: 39 | captions = json.load(fp) 40 | 41 | iid2captions = defaultdict(list) 42 | for cap in tqdm(captions): 43 | cap = cap["regions"] 44 | for c in cap: 45 | iid2captions[c["image_id"]].append(c) 46 | 47 | paths = list(glob(f"{root}/images/VG_100K/*.jpg")) + list( 48 | glob(f"{root}/images/VG_100K_2/*.jpg") 49 | ) 50 | random.shuffle(paths) 51 | caption_paths = [ 52 | path for path in paths if int(path.split("/")[-1][:-4]) in iid2captions 53 | ] 54 | 55 | if len(paths) == len(caption_paths): 56 | print("all images have caption annotations") 57 | else: 58 | print("not all images have caption annotations") 59 | print( 60 | len(paths), len(caption_paths), len(iid2captions), 61 
| ) 62 | 63 | bs = [path2rest(path, iid2captions) for path in tqdm(caption_paths)] 64 | dataframe = pd.DataFrame( 65 | bs, columns=["image", "caption", "width", "height", "x", "y", "image_id"], 66 | ) 67 | table = pa.Table.from_pandas(dataframe) 68 | 69 | os.makedirs(dataset_root, exist_ok=True) 70 | with pa.OSFile(f"{dataset_root}/vg.arrow", "wb") as sink: 71 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 72 | writer.write_table(table) 73 | -------------------------------------------------------------------------------- /meter/utils/write_vqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import pyarrow as pa 4 | import random 5 | import os 6 | 7 | from tqdm import tqdm 8 | from glob import glob 9 | from collections import defaultdict, Counter 10 | from .glossary import normalize_word 11 | 12 | 13 | def get_score(occurences): 14 | if occurences == 0: 15 | return 0.0 16 | elif occurences == 1: 17 | return 0.3 18 | elif occurences == 2: 19 | return 0.6 20 | elif occurences == 3: 21 | return 0.9 22 | else: 23 | return 1.0 24 | 25 | 26 | def path2rest(path, split, annotations, label2ans): 27 | iid = int(path.split("/")[-1].split("_")[-1][:-4]) 28 | 29 | with open(path, "rb") as fp: 30 | binary = fp.read() 31 | 32 | _annot = annotations[split][iid] 33 | _annot = list(_annot.items()) 34 | qids, qas = [a[0] for a in _annot], [a[1] for a in _annot] 35 | questions = [qa[0] for qa in qas] 36 | answers = [qa[1] for qa in qas] if "test" not in split else list(list()) 37 | answer_labels = ( 38 | [a["labels"] for a in answers] if "test" not in split else list(list()) 39 | ) 40 | answer_scores = ( 41 | [a["scores"] for a in answers] if "test" not in split else list(list()) 42 | ) 43 | answers = ( 44 | [[label2ans[l] for l in al] for al in answer_labels] 45 | if "test" not in split 46 | else list(list()) 47 | ) 48 | 49 | return [binary, questions, answers, answer_labels, answer_scores, iid, qids, split] 50 | 51 | 52 | def make_arrow(root, dataset_root): 53 | with open(f"{root}/v2_OpenEnded_mscoco_train2014_questions.json", "r") as fp: 54 | questions_train2014 = json.load(fp)["questions"] 55 | with open(f"{root}/v2_OpenEnded_mscoco_val2014_questions.json", "r") as fp: 56 | questions_val2014 = json.load(fp)["questions"] 57 | with open(f"{root}/v2_OpenEnded_mscoco_test2015_questions.json", "r") as fp: 58 | questions_test2015 = json.load(fp)["questions"] 59 | with open(f"{root}/v2_OpenEnded_mscoco_test-dev2015_questions.json", "r") as fp: 60 | questions_test_dev2015 = json.load(fp)["questions"] 61 | 62 | with open(f"{root}/v2_mscoco_train2014_annotations.json", "r") as fp: 63 | annotations_train2014 = json.load(fp)["annotations"] 64 | with open(f"{root}/v2_mscoco_val2014_annotations.json", "r") as fp: 65 | annotations_val2014 = json.load(fp)["annotations"] 66 | 67 | annotations = dict() 68 | 69 | for split, questions in zip( 70 | ["train", "val", "test", "test-dev"], 71 | [ 72 | questions_train2014, 73 | questions_val2014, 74 | questions_test2015, 75 | questions_test_dev2015, 76 | ], 77 | ): 78 | _annot = defaultdict(dict) 79 | for q in tqdm(questions): 80 | _annot[q["image_id"]][q["question_id"]] = [q["question"]] 81 | 82 | annotations[split] = _annot 83 | 84 | all_major_answers = list() 85 | 86 | for split, annots in zip( 87 | ["train", "val"], [annotations_train2014, annotations_val2014], 88 | ): 89 | _annot = annotations[split] 90 | for q in tqdm(annots): 91 | 
all_major_answers.append(q["multiple_choice_answer"]) 92 | 93 | all_major_answers = [normalize_word(word) for word in tqdm(all_major_answers)] 94 | counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9} 95 | ans2label = {k: i for i, k in enumerate(counter.keys())} 96 | label2ans = list(counter.keys()) 97 | 98 | for split, annots in zip( 99 | ["train", "val"], [annotations_train2014, annotations_val2014], 100 | ): 101 | _annot = annotations[split] 102 | for q in tqdm(annots): 103 | answers = q["answers"] 104 | answer_count = {} 105 | for answer in answers: 106 | answer_ = answer["answer"] 107 | answer_count[answer_] = answer_count.get(answer_, 0) + 1 108 | 109 | labels = [] 110 | scores = [] 111 | for answer in answer_count: 112 | if answer not in ans2label: 113 | continue 114 | labels.append(ans2label[answer]) 115 | score = get_score(answer_count[answer]) 116 | scores.append(score) 117 | 118 | _annot[q["image_id"]][q["question_id"]].append( 119 | {"labels": labels, "scores": scores,} 120 | ) 121 | 122 | for split in ["train", "val"]: 123 | filtered_annot = dict() 124 | for ik, iv in annotations[split].items(): 125 | new_q = dict() 126 | for qk, qv in iv.items(): 127 | if len(qv[1]["labels"]) != 0: 128 | new_q[qk] = qv 129 | if len(new_q) != 0: 130 | filtered_annot[ik] = new_q 131 | annotations[split] = filtered_annot 132 | 133 | for split in [ 134 | "train", 135 | "val", 136 | "test", 137 | "test-dev", 138 | ]: 139 | annot = annotations[split] 140 | split_name = { 141 | "train": "train2014", 142 | "val": "val2014", 143 | "test": "test2015", 144 | "test-dev": "test2015", 145 | }[split] 146 | paths = list(glob(f"{root}/{split_name}/*.jpg")) 147 | random.shuffle(paths) 148 | annot_paths = [ 149 | path 150 | for path in paths 151 | if int(path.split("/")[-1].split("_")[-1][:-4]) in annot 152 | ] 153 | 154 | if len(paths) == len(annot_paths): 155 | print("all images have caption annotations") 156 | else: 157 | print("not all images have caption annotations") 158 | print( 159 | len(paths), len(annot_paths), len(annot), 160 | ) 161 | 162 | bs = [ 163 | path2rest(path, split, annotations, label2ans) for path in tqdm(annot_paths) 164 | ] 165 | 166 | dataframe = pd.DataFrame( 167 | bs, 168 | columns=[ 169 | "image", 170 | "questions", 171 | "answers", 172 | "answer_labels", 173 | "answer_scores", 174 | "image_id", 175 | "question_id", 176 | "split", 177 | ], 178 | ) 179 | 180 | table = pa.Table.from_pandas(dataframe) 181 | 182 | os.makedirs(dataset_root, exist_ok=True) 183 | with pa.OSFile(f"{dataset_root}/vqav2_{split}.arrow", "wb") as sink: 184 | with pa.RecordBatchFileWriter(sink, table.schema) as writer: 185 | writer.write_table(table) 186 | 187 | table = pa.ipc.RecordBatchFileReader( 188 | pa.memory_map(f"{dataset_root}/vqav2_val.arrow", "r") 189 | ).read_all() 190 | 191 | pdtable = table.to_pandas() 192 | 193 | df1 = pdtable[:-1000] 194 | df2 = pdtable[-1000:] 195 | 196 | df1 = pa.Table.from_pandas(df1) 197 | df2 = pa.Table.from_pandas(df2) 198 | 199 | with pa.OSFile(f"{dataset_root}/vqav2_trainable_val.arrow", "wb") as sink: 200 | with pa.RecordBatchFileWriter(sink, df1.schema) as writer: 201 | writer.write_table(df1) 202 | 203 | with pa.OSFile(f"{dataset_root}/vqav2_rest_val.arrow", "wb") as sink: 204 | with pa.RecordBatchFileWriter(sink, df2.schema) as writer: 205 | writer.write_table(df2) 206 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | pytorch_lightning==1.3.2 2 | transformers==4.6.0 3 | Pillow==8.1.0 4 | tqdm==4.56.0 5 | ipdb==0.13.4 6 | numpy==1.19.5 7 | einops==0.3.0 8 | pyarrow==2.0.0 9 | sacred==0.8.2 10 | pandas==1.1.5 11 | timm==0.4.12 12 | ftfy 13 | torchvision~=0.8.2 14 | torch~=1.7.1 15 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import pytorch_lightning as pl 4 | import os 5 | os.environ["NCCL_DEBUG"] = "INFO" 6 | 7 | from meter.config import ex 8 | from meter.modules import METERTransformerSS 9 | from meter.datamodules.multitask_datamodule import MTDataModule 10 | 11 | import resource 12 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 13 | resource.setrlimit(resource.RLIMIT_NOFILE, (20480, rlimit[1])) 14 | 15 | @ex.automain 16 | def main(_config): 17 | _config = copy.deepcopy(_config) 18 | pl.seed_everything(_config["seed"]) 19 | 20 | dm = MTDataModule(_config, dist=True) 21 | 22 | model = METERTransformerSS(_config) 23 | exp_name = f'{_config["exp_name"]}' 24 | 25 | os.makedirs(_config["log_dir"], exist_ok=True) 26 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 27 | save_top_k=1, 28 | verbose=True, 29 | monitor="val/the_metric", 30 | mode="max", 31 | save_last=True, 32 | ) 33 | logger = pl.loggers.TensorBoardLogger( 34 | _config["log_dir"], 35 | name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}', 36 | ) 37 | 38 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step") 39 | callbacks = [checkpoint_callback, lr_callback] 40 | 41 | num_gpus = ( 42 | _config["num_gpus"] 43 | if isinstance(_config["num_gpus"], int) 44 | else len(_config["num_gpus"]) 45 | ) 46 | 47 | grad_steps = max(_config["batch_size"] // ( 48 | _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"] 49 | ), 1) 50 | 51 | max_steps = _config["max_steps"] if _config["max_steps"] is not None else None 52 | 53 | trainer = pl.Trainer( 54 | gpus=_config["num_gpus"], 55 | num_nodes=_config["num_nodes"], 56 | precision=_config["precision"], 57 | accelerator="ddp", 58 | benchmark=True, 59 | deterministic=True, 60 | max_epochs=_config["max_epoch"] if max_steps is None else 1000, 61 | max_steps=max_steps, 62 | callbacks=callbacks, 63 | logger=logger, 64 | prepare_data_per_node=False, 65 | replace_sampler_ddp=False, 66 | accumulate_grad_batches=grad_steps, 67 | log_every_n_steps=10, 68 | flush_logs_every_n_steps=10, 69 | resume_from_checkpoint=_config["resume_from"], 70 | weights_summary="top", 71 | fast_dev_run=_config["fast_dev_run"], 72 | val_check_interval=_config["val_check_interval"], 73 | ) 74 | 75 | if not _config["test_only"]: 76 | trainer.fit(model, datamodule=dm) 77 | else: 78 | trainer.test(model, datamodule=dm) 79 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="meter", 5 | packages=find_packages( 6 | exclude=[".dfc", ".vscode", "dataset", "notebooks", "result", "scripts"] 7 | ), 8 | version="0.1.0", 9 | license="MIT", 10 | description="METER: Multimodal End-to-end TransformER", 11 | author="Microsoft Corporation", 12 | author_email="zdou0830@gmail.com", 13 | url="https://github.com/zdou0830/METER", 14 | keywords=["vision and 
language pretraining"], 15 | install_requires=["torch", "pytorch_lightning"], 16 | ) 17 | --------------------------------------------------------------------------------
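Below is a minimal, self-contained usage sketch (not part of the repository) showing how the image transforms defined in `meter/transforms/transform.py` might be applied once the package is installed with `pip install -e .`. The 288-pixel resolution and the synthetic test image are illustrative assumptions, not values taken from the codebase.

```python
# Hypothetical usage sketch for the transforms in meter/transforms/transform.py.
# The random test image and the 288x288 resolution are placeholders.
import numpy as np
from PIL import Image

from meter.transforms.transform import clip_transform, clip_transform_randaug

# Stand-in for a real dataset sample: an HxWx3 uint8 RGB image.
img = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))

plain = clip_transform(288)(img)              # resize, center-crop, ToTensor, CLIP normalization
augmented = clip_transform_randaug(288)(img)  # same pipeline with RandAugment(2, 9) prepended

print(plain.shape, augmented.shape)           # both: torch.Size([3, 288, 288])
```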