├── BLIP.gif
├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── LICENSE.txt
├── README.md
├── SECURITY.md
├── cog.yaml
├── configs
│   ├── bert_config.json
│   ├── caption_coco.yaml
│   ├── med_config.json
│   ├── nlvr.yaml
│   ├── nocaps.yaml
│   ├── pretrain.yaml
│   ├── retrieval_coco.yaml
│   ├── retrieval_flickr.yaml
│   ├── retrieval_msrvtt.yaml
│   └── vqa.yaml
├── data
│   ├── __init__.py
│   ├── coco_karpathy_dataset.py
│   ├── flickr30k_dataset.py
│   ├── nlvr_dataset.py
│   ├── nocaps_dataset.py
│   ├── pretrain_dataset.py
│   ├── utils.py
│   ├── video_dataset.py
│   └── vqa_dataset.py
├── demo.ipynb
├── eval_nocaps.py
├── eval_retrieval_video.py
├── models
│   ├── __init__.py
│   ├── blip.py
│   ├── blip_itm.py
│   ├── blip_nlvr.py
│   ├── blip_pretrain.py
│   ├── blip_retrieval.py
│   ├── blip_vqa.py
│   ├── med.py
│   ├── nlvr_encoder.py
│   └── vit.py
├── predict.py
├── pretrain.py
├── requirements.txt
├── train_caption.py
├── train_nlvr.py
├── train_retrieval.py
├── train_vqa.py
├── transform
│   └── randaugment.py
└── utils.py


/BLIP.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/BLIP/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/BLIP.gif

--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
#ECCN:Open Source

--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Salesforce Open Source Community Code of Conduct

## About the Code of Conduct

Equality is a core value at Salesforce. We believe a diverse and inclusive
community fosters innovation and creativity, and are committed to building a
culture where everyone feels included.

Salesforce open-source projects are committed to providing a friendly, safe, and
welcoming environment for all, regardless of gender identity and expression,
sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
race, age, religion, level of experience, education, socioeconomic status, or
other similar personal characteristics.

The goal of this code of conduct is to specify a baseline standard of behavior so
that people with different social values and communication styles can work
together effectively, productively, and respectfully in our open source community.
It also establishes a mechanism for reporting issues and resolving conflicts.

All questions and reports of abusive, harassing, or otherwise unacceptable behavior
in a Salesforce open-source project may be reported by contacting the Salesforce
Open Source Conduct Committee at ossconduct@salesforce.com.

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of gender
identity and expression, sexual orientation, disability, physical appearance,
body size, ethnicity, nationality, race, age, religion, level of experience, education,
socioeconomic status, or other similar personal characteristics.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy toward other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Personal attacks, insulting/derogatory comments, or trolling
* Public or private harassment
* Publishing, or threatening to publish, others' private information—such as
  a physical or electronic address—without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting
* Advocating for or encouraging any of the above behaviors

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned with this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project email
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the Salesforce Open Source Conduct Committee
at ossconduct@salesforce.com. All complaints will be reviewed and investigated
and will result in a response that is deemed necessary and appropriate to the
circumstances. The committee is obligated to maintain confidentiality with
regard to the reporter of an incident. Further details of specific enforcement
policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership and the Salesforce Open Source Conduct
Committee.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
It includes adaptions and additions from [Go Community Code of Conduct][golang-coc],
[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].

This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].

[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
[golang-coc]: https://golang.org/conduct
[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/

--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
Copyright (c) 2022, Salesforce.com, Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation

## Announcement: BLIP is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS) - a one-stop library for language-and-vision research and applications!

This is the PyTorch code of the BLIP paper [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10.
To install the dependencies, run
pip install -r requirements.txt

Catalog:
- [x] Inference demo
- [x] Pre-trained and finetuned checkpoints
- [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2
- [x] Pre-training code
- [x] Zero-shot video-text retrieval
- [x] Download of bootstrapped pre-training datasets


### Inference demo:
Run our interactive demo using [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed).
The demo includes code for:
1. Image captioning
2. Open-ended visual question answering
3. Multimodal / unimodal feature extraction
4. Image-text matching

Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio).

A Replicate web demo and Docker image are also available at [![Replicate](https://replicate.com/salesforce/blip/badge)](https://replicate.com/salesforce/blip)

### Pre-trained checkpoints:
Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
--- | :---: | :---: | :---:
14M | Download | - | -
129M | Download | Download | Download

### Finetuned checkpoints:
Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
--- | :---: | :---: | :---:
Image-Text Retrieval (COCO) | Download | - | Download
Image-Text Retrieval (Flickr30k) | Download | - | Download
Image Captioning (COCO) | - | Download | Download
VQA | Download | Download | -
NLVR2 | Download | - | -


### Image-Text Retrieval:
1. Download COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly.
2. To evaluate the finetuned BLIP model on COCO, run:
python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
--config ./configs/retrieval_coco.yaml \
--output_dir output/retrieval_coco \
--evaluate
3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
--config ./configs/retrieval_coco.yaml \
--output_dir output/retrieval_coco
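
When launching on a different number of GPUs than the 8 used above, the global batch size changes with `--nproc_per_node`. The sketch below is not part of the repository. It assumes that `batch_size_train` in the config is a per-process value, that samples are sharded evenly across processes with padding (the default behaviour of PyTorch's `DistributedSampler`), and that the training loader drops the last incomplete batch. The function names are hypothetical.

```python
import math


def global_batch_size(per_process_batch: int, num_processes: int) -> int:
    """Samples consumed per optimizer step across all processes."""
    return per_process_batch * num_processes


def steps_per_epoch(num_samples: int, per_process_batch: int,
                    num_processes: int, drop_last: bool = True) -> int:
    """Iterations each process runs per epoch.

    Assumes every process receives ceil(num_samples / num_processes)
    samples, i.e. the sampler pads the dataset to a multiple of the
    process count instead of dropping the tail.
    """
    per_rank = math.ceil(num_samples / num_processes)
    if drop_last:
        # Incomplete final batch is discarded.
        return per_rank // per_process_batch
    return math.ceil(per_rank / per_process_batch)
```

With `batch_size_train: 32` and 8 processes, `global_batch_size(32, 8)` gives 256 under these assumptions, so halving the GPU count halves the global batch unless the per-process value is raised to compensate.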

### Image-Text Captioning:
1. Download COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly.
2. To evaluate the finetuned BLIP model on COCO, run:
python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate
3. To evaluate the finetuned BLIP model on NoCaps, generate results with the command below (the evaluation itself needs to be performed on the official server):
python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py 
4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
python -m torch.distributed.run --nproc_per_node=8 train_caption.py 

### VQA:
1. Download VQA v2 dataset and Visual Genome dataset from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml.
2. To evaluate the finetuned BLIP model, generate results with the command below (the evaluation itself needs to be performed on the official server):
python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate
3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
python -m torch.distributed.run --nproc_per_node=16 train_vqa.py 

### NLVR2:
1. Download the NLVR2 dataset from the original website, and set 'image_root' in configs/nlvr.yaml.
2. To evaluate the finetuned BLIP model, run:
python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate
3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py 

### Finetune with ViT-L:
To finetune a model with ViT-L, change the config file to set 'vit' to large. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). Gradient checkpointing can also be enabled in the config file to reduce GPU memory usage.

### Pre-train:
1. Prepare training JSON files where each JSON file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}.
2. In configs/pretrain.yaml, set 'train_file' as the paths for the JSON files.
3. Pre-train the model using 8 A100 GPUs:
python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/pretrain.yaml --output_dir output/Pretrain
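
The annotation format from step 1 can be produced with a few lines of Python. The following is a minimal sketch, not code from the repository: `write_pretrain_annotations` is a hypothetical helper, and it only assumes the two keys named above (`image` and `caption`).

```python
import json
from pathlib import Path


def write_pretrain_annotations(pairs, out_path):
    """Write (image_path, caption) pairs as a JSON list of dictionaries.

    Each item has exactly the two keys described in step 1:
    {'image': path_of_image, 'caption': text_of_image}.
    Returns the number of items written.
    """
    items = [{"image": str(image), "caption": caption} for image, caption in pairs]
    out_path = Path(out_path)
    # Create the parent directory if it does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(items, f, ensure_ascii=False)
    return len(items)
```

The resulting file path is what would go into the 'train_file' list in step 2. Whether image paths should be absolute or relative depends on how the pre-training dataset class resolves them, which this sketch does not assume.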

### Zero-shot video-text retrieval:
1. Download MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml.
2. Install [decord](https://github.com/dmlc/decord) with
pip install decord
3. To perform zero-shot evaluation, run
python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py

### Pre-training datasets download:
We provide bootstrapped pre-training datasets as JSON files. Each JSON file contains a list. Each item in the list is a dictionary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image}.

Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L
--- | :---: | :---: | :---:
CC3M+CC12M+SBU | Download | Download | Download
LAION115M | Download | Download | Download

### Citation
If you find this code to be useful for your research, please consider citing.
@inproceedings{li2022blip,
      title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
      author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
      year={2022},
      booktitle={ICML},
}

### Acknowledgement
The implementation of BLIP relies on resources from ALBEF, Huggingface Transformers, and timm. We thank the original authors for their open-sourcing.

--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
## Security

Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
as soon as it is discovered. This library limits its runtime dependencies in
order to reduce the total cost of ownership as much as can be, but all consumers
should remain vigilant and have their security stakeholders review all third-party
products (3PP) like this one and their dependencies.

--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
build:
  gpu: true
  cuda: "11.1"
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "ipython==7.30.1"
    - "torchvision==0.11.1"
    - "torch==1.10.0"
    - "timm==0.4.12"
    - "transformers==4.15.0"
    - "fairscale==0.4.4"
    - "pycocoevalcap==1.2"

predict: "predict.py:Predictor"

--------------------------------------------------------------------------------
/configs/bert_config.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522,
  "encoder_width": 768,
  "add_cross_attention": true
}

--------------------------------------------------------------------------------
/configs/caption_coco.yaml:
--------------------------------------------------------------------------------
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
coco_gt_root: 'annotation/coco_gt'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'

# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0
batch_size: 32
init_lr: 1e-5

# vit: 'large'
# vit_grad_ckpt: True
# vit_ckpt_layer: 5
# batch_size: 16
# init_lr: 2e-6

image_size: 384

# generation configs
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 5

--------------------------------------------------------------------------------
/configs/med_config.json:
--------------------------------------------------------------------------------
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30524,
  "encoder_width": 768,
  "add_cross_attention": true
}

--------------------------------------------------------------------------------
/configs/nlvr.yaml:
--------------------------------------------------------------------------------
image_root: '/export/share/datasets/vision/NLVR2/'
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'

#size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 64
vit_grad_ckpt: False
vit_ckpt_layer: 0
max_epoch: 15

image_size: 384

# optimizer
weight_decay: 0.05
init_lr: 3e-5
min_lr: 0

--------------------------------------------------------------------------------
/configs/nocaps.yaml:
--------------------------------------------------------------------------------
image_root: '/export/share/datasets/vision/nocaps/'
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'

vit: 'base'
batch_size: 32

image_size: 384

max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '
--------------------------------------------------------------------------------
/configs/pretrain.yaml:
--------------------------------------------------------------------------------
train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
             '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
            ]
laion_path: ''

# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0

image_size: 224
batch_size: 75

queue_size: 57600
alpha: 0.4

# optimizer
weight_decay: 0.05
init_lr: 3e-4
min_lr: 1e-6
warmup_lr: 1e-6
lr_decay_rate: 0.9
max_epoch: 20
warmup_steps: 3000

--------------------------------------------------------------------------------
/configs/retrieval_coco.yaml:
--------------------------------------------------------------------------------
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
dataset: 'coco'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'

# size of vit model; base or large

vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5

# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 12
# init_lr: 5e-6

image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 256
negative_all_rank: True

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6

--------------------------------------------------------------------------------
/configs/retrieval_flickr.yaml:
--------------------------------------------------------------------------------
image_root: '/export/share/datasets/vision/flickr30k/'
ann_root: 'annotation'
dataset: 'flickr'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'

# size of vit model; base or large

vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5

# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 10
# init_lr: 5e-6

image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 128
negative_all_rank: False

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6

-------------------------------------------------------------------------------- /configs/retrieval_msrvtt.yaml: -------------------------------------------------------------------------------- 1 | video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth' 6 | 7 | # size of vit model; base or large 8 | vit: 'base' 9 | batch_size: 64 10 | k_test: 128 11 | image_size: 384 12 | num_frm_test: 8 -------------------------------------------------------------------------------- /configs/vqa.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/ 2 | vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/ 3 | train_files: ['vqa_train','vqa_val','vg_qa'] 4 | ann_root: 'annotation' 5 | 6 | # set pretrained as a file path or an url 7 | pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth' 8 | 9 | # size of vit model; base or large 10 | vit: 'base' 11 | batch_size_train: 16 12 | batch_size_test: 32 13 | vit_grad_ckpt: False 14 | vit_ckpt_layer: 0 15 | init_lr: 2e-5 16 | 17 | image_size: 480 18 | 19 | k_test: 128 20 | inference: 'rank' 21 | 22 | # optimizer 23 | weight_decay: 0.05 24 | min_lr: 0 25 | max_epoch: 10 -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | from torchvision.transforms.functional import InterpolationMode 5 | 6 | from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, 
coco_karpathy_retrieval_eval 7 | from data.nocaps_dataset import nocaps_eval 8 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval 9 | from data.vqa_dataset import vqa_dataset 10 | from data.nlvr_dataset import nlvr_dataset 11 | from data.pretrain_dataset import pretrain_dataset 12 | from transform.randaugment import RandomAugment 13 | 14 | def create_dataset(dataset, config, min_scale=0.5): 15 | 16 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 17 | 18 | transform_train = transforms.Compose([ 19 | transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC), 20 | transforms.RandomHorizontalFlip(), 21 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 22 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 23 | transforms.ToTensor(), 24 | normalize, 25 | ]) 26 | transform_test = transforms.Compose([ 27 | transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC), 28 | transforms.ToTensor(), 29 | normalize, 30 | ]) 31 | 32 | if dataset=='pretrain': 33 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train) 34 | return dataset 35 | 36 | elif dataset=='caption_coco': 37 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt']) 38 | val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val') 39 | test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test') 40 | return train_dataset, val_dataset, test_dataset 41 | 42 | elif dataset=='nocaps': 43 | val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val') 44 | test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test') 45 | 
return val_dataset, test_dataset 46 | 47 | elif dataset=='retrieval_coco': 48 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root']) 49 | val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 50 | test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 51 | return train_dataset, val_dataset, test_dataset 52 | 53 | elif dataset=='retrieval_flickr': 54 | train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root']) 55 | val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 56 | test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 57 | return train_dataset, val_dataset, test_dataset 58 | 59 | elif dataset=='vqa': 60 | train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'], 61 | train_files = config['train_files'], split='train') 62 | test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test') 63 | return train_dataset, test_dataset 64 | 65 | elif dataset=='nlvr': 66 | train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train') 67 | val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val') 68 | test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test') 69 | return train_dataset, val_dataset, test_dataset 70 | 71 | 72 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 73 | samplers = [] 74 | for dataset,shuffle in zip(datasets,shuffles): 75 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 76 | samplers.append(sampler) 77 | return samplers 78 | 79 | 80 | def create_loader(datasets, samplers, batch_size, 
num_workers, is_trains, collate_fns): 81 | loaders = [] 82 | for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 83 | if is_train: 84 | shuffle = (sampler is None) 85 | drop_last = True 86 | else: 87 | shuffle = False 88 | drop_last = False 89 | loader = DataLoader( 90 | dataset, 91 | batch_size=bs, 92 | num_workers=n_worker, 93 | pin_memory=True, 94 | sampler=sampler, 95 | shuffle=shuffle, 96 | collate_fn=collate_fn, 97 | drop_last=drop_last, 98 | ) 99 | loaders.append(loader) 100 | return loaders 101 | 102 | -------------------------------------------------------------------------------- /data/coco_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | from data.utils import pre_caption 10 | 11 | class coco_karpathy_train(Dataset): 12 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 13 | ''' 14 | image_root (string): Root directory of images (e.g. 
coco/images/) 15 | ann_root (string): directory to store the annotation file 16 | ''' 17 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json' 18 | filename = 'coco_karpathy_train.json' 19 | 20 | download_url(url,ann_root) 21 | 22 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 23 | self.transform = transform 24 | self.image_root = image_root 25 | self.max_words = max_words 26 | self.prompt = prompt 27 | 28 | self.img_ids = {} 29 | n = 0 30 | for ann in self.annotation: 31 | img_id = ann['image_id'] 32 | if img_id not in self.img_ids.keys(): 33 | self.img_ids[img_id] = n 34 | n += 1 35 | 36 | def __len__(self): 37 | return len(self.annotation) 38 | 39 | def __getitem__(self, index): 40 | 41 | ann = self.annotation[index] 42 | 43 | image_path = os.path.join(self.image_root,ann['image']) 44 | image = Image.open(image_path).convert('RGB') 45 | image = self.transform(image) 46 | 47 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 48 | 49 | return image, caption, self.img_ids[ann['image_id']] 50 | 51 | 52 | class coco_karpathy_caption_eval(Dataset): 53 | def __init__(self, transform, image_root, ann_root, split): 54 | ''' 55 | image_root (string): Root directory of images (e.g. 
coco/images/) 56 | ann_root (string): directory to store the annotation file 57 | split (string): val or test 58 | ''' 59 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 60 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 61 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 62 | 63 | download_url(urls[split],ann_root) 64 | 65 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 66 | self.transform = transform 67 | self.image_root = image_root 68 | 69 | def __len__(self): 70 | return len(self.annotation) 71 | 72 | def __getitem__(self, index): 73 | 74 | ann = self.annotation[index] 75 | 76 | image_path = os.path.join(self.image_root,ann['image']) 77 | image = Image.open(image_path).convert('RGB') 78 | image = self.transform(image) 79 | 80 | img_id = os.path.splitext(ann['image'].split('/')[-1])[0].split('_')[-1] # splitext drops the extension; str.strip('.jpg') strips characters, not a suffix 81 | 82 | return image, int(img_id) 83 | 84 | 85 | class coco_karpathy_retrieval_eval(Dataset): 86 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 87 | ''' 88 | image_root (string): Root directory of images (e.g.
coco/images/) 89 | ann_root (string): directory to store the annotation file 90 | split (string): val or test 91 | ''' 92 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 93 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 94 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 95 | 96 | download_url(urls[split],ann_root) 97 | 98 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 99 | self.transform = transform 100 | self.image_root = image_root 101 | 102 | self.text = [] 103 | self.image = [] 104 | self.txt2img = {} 105 | self.img2txt = {} 106 | 107 | txt_id = 0 108 | for img_id, ann in enumerate(self.annotation): 109 | self.image.append(ann['image']) 110 | self.img2txt[img_id] = [] 111 | for i, caption in enumerate(ann['caption']): 112 | self.text.append(pre_caption(caption,max_words)) 113 | self.img2txt[img_id].append(txt_id) 114 | self.txt2img[txt_id] = img_id 115 | txt_id += 1 116 | 117 | def __len__(self): 118 | return len(self.annotation) 119 | 120 | def __getitem__(self, index): 121 | 122 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 123 | image = Image.open(image_path).convert('RGB') 124 | image = self.transform(image) 125 | 126 | return image, index -------------------------------------------------------------------------------- /data/flickr30k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | from data.utils import pre_caption 10 | 11 | class flickr30k_train(Dataset): 12 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 13 | ''' 14 | image_root (string): Root directory of images (e.g. 
flickr30k/) 15 | ann_root (string): directory to store the annotation file 16 | ''' 17 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json' 18 | filename = 'flickr30k_train.json' 19 | 20 | download_url(url,ann_root) 21 | 22 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 23 | self.transform = transform 24 | self.image_root = image_root 25 | self.max_words = max_words 26 | self.prompt = prompt 27 | 28 | self.img_ids = {} 29 | n = 0 30 | for ann in self.annotation: 31 | img_id = ann['image_id'] 32 | if img_id not in self.img_ids.keys(): 33 | self.img_ids[img_id] = n 34 | n += 1 35 | 36 | def __len__(self): 37 | return len(self.annotation) 38 | 39 | def __getitem__(self, index): 40 | 41 | ann = self.annotation[index] 42 | 43 | image_path = os.path.join(self.image_root,ann['image']) 44 | image = Image.open(image_path).convert('RGB') 45 | image = self.transform(image) 46 | 47 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 48 | 49 | return image, caption, self.img_ids[ann['image_id']] 50 | 51 | 52 | class flickr30k_retrieval_eval(Dataset): 53 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 54 | ''' 55 | image_root (string): Root directory of images (e.g. 
flickr30k/) 56 | ann_root (string): directory to store the annotation file 57 | split (string): val or test 58 | ''' 59 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json', 60 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'} 61 | filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'} 62 | 63 | download_url(urls[split],ann_root) 64 | 65 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 66 | self.transform = transform 67 | self.image_root = image_root 68 | 69 | self.text = [] 70 | self.image = [] 71 | self.txt2img = {} 72 | self.img2txt = {} 73 | 74 | txt_id = 0 75 | for img_id, ann in enumerate(self.annotation): 76 | self.image.append(ann['image']) 77 | self.img2txt[img_id] = [] 78 | for i, caption in enumerate(ann['caption']): 79 | self.text.append(pre_caption(caption,max_words)) 80 | self.img2txt[img_id].append(txt_id) 81 | self.txt2img[txt_id] = img_id 82 | txt_id += 1 83 | 84 | def __len__(self): 85 | return len(self.annotation) 86 | 87 | def __getitem__(self, index): 88 | 89 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 90 | image = Image.open(image_path).convert('RGB') 91 | image = self.transform(image) 92 | 93 | return image, index -------------------------------------------------------------------------------- /data/nlvr_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets.utils import download_url 7 | 8 | from PIL import Image 9 | 10 | from data.utils import pre_caption 11 | 12 | class nlvr_dataset(Dataset): 13 | def __init__(self, transform, image_root, ann_root, split): 14 | ''' 15 | image_root (string): Root directory of images 16 | ann_root (string): directory to store the annotation file 17 | split (string): 
train, val or test 18 | ''' 19 | urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json', 20 | 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json', 21 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'} 22 | filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'} 23 | 24 | download_url(urls[split],ann_root) 25 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 26 | 27 | self.transform = transform 28 | self.image_root = image_root 29 | 30 | 31 | def __len__(self): 32 | return len(self.annotation) 33 | 34 | 35 | def __getitem__(self, index): 36 | 37 | ann = self.annotation[index] 38 | 39 | image0_path = os.path.join(self.image_root,ann['images'][0]) 40 | image0 = Image.open(image0_path).convert('RGB') 41 | image0 = self.transform(image0) 42 | 43 | image1_path = os.path.join(self.image_root,ann['images'][1]) 44 | image1 = Image.open(image1_path).convert('RGB') 45 | image1 = self.transform(image1) 46 | 47 | sentence = pre_caption(ann['sentence'], 40) 48 | 49 | if ann['label']=='True': 50 | label = 1 51 | else: 52 | label = 0 53 | 54 | words = sentence.split(' ') 55 | 56 | if 'left' not in words and 'right' not in words: 57 | if random.random()<0.5: 58 | return image0, image1, sentence, label 59 | else: 60 | return image1, image0, sentence, label 61 | else: 62 | if random.random()<0.5: 63 | return image0, image1, sentence, label 64 | else: 65 | new_words = [] 66 | for word in words: 67 | if word=='left': 68 | new_words.append('right') 69 | elif word=='right': 70 | new_words.append('left') 71 | else: 72 | new_words.append(word) 73 | 74 | sentence = ' '.join(new_words) 75 | return image1, image0, sentence, label 76 | 77 | 78 | -------------------------------------------------------------------------------- /data/nocaps_dataset.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | class nocaps_eval(Dataset): 10 | def __init__(self, transform, image_root, ann_root, split): 11 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json', 12 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'} 13 | filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'} 14 | 15 | download_url(urls[split],ann_root) 16 | 17 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 18 | self.transform = transform 19 | self.image_root = image_root 20 | 21 | def __len__(self): 22 | return len(self.annotation) 23 | 24 | def __getitem__(self, index): 25 | 26 | ann = self.annotation[index] 27 | 28 | image_path = os.path.join(self.image_root,ann['image']) 29 | image = Image.open(image_path).convert('RGB') 30 | image = self.transform(image) 31 | 32 | return image, int(ann['img_id']) -------------------------------------------------------------------------------- /data/pretrain_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | 7 | from PIL import Image 8 | from PIL import ImageFile 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | Image.MAX_IMAGE_PIXELS = None 11 | 12 | from data.utils import pre_caption 13 | import os,glob 14 | 15 | class pretrain_dataset(Dataset): 16 | def __init__(self, ann_file, laion_path, transform): 17 | 18 | self.ann_pretrain = [] 19 | for f in ann_file: 20 | print('loading '+f) 21 | ann = json.load(open(f,'r')) 22 | self.ann_pretrain += ann 23 | 24 | self.laion_path = laion_path 25 | if self.laion_path: 26 | self.laion_files = 
glob.glob(os.path.join(laion_path,'*.json')) 27 | 28 | print('loading '+self.laion_files[0]) 29 | with open(self.laion_files[0],'r') as f: 30 | self.ann_laion = json.load(f) 31 | 32 | self.annotation = self.ann_pretrain + self.ann_laion 33 | else: 34 | self.annotation = self.ann_pretrain 35 | 36 | self.transform = transform 37 | 38 | 39 | def reload_laion(self, epoch): 40 | n = epoch%len(self.laion_files) 41 | print('loading '+self.laion_files[n]) 42 | with open(self.laion_files[n],'r') as f: 43 | self.ann_laion = json.load(f) 44 | 45 | self.annotation = self.ann_pretrain + self.ann_laion 46 | 47 | 48 | def __len__(self): 49 | return len(self.annotation) 50 | 51 | def __getitem__(self, index): 52 | 53 | ann = self.annotation[index] 54 | 55 | image = Image.open(ann['image']).convert('RGB') 56 | image = self.transform(image) 57 | caption = pre_caption(ann['caption'],30) 58 | 59 | return image, caption -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import utils 9 | 10 | def pre_caption(caption,max_words=50): 11 | caption = re.sub( 12 | r"([.!\"()*#:;~])", 13 | ' ', 14 | caption.lower(), 15 | ) 16 | caption = re.sub( 17 | r"\s{2,}", 18 | ' ', 19 | caption, 20 | ) 21 | caption = caption.rstrip('\n') 22 | caption = caption.strip(' ') 23 | 24 | #truncate caption 25 | caption_words = caption.split(' ') 26 | if len(caption_words)>max_words: 27 | caption = ' '.join(caption_words[:max_words]) 28 | 29 | return caption 30 | 31 | def pre_question(question,max_ques_words=50): 32 | question = re.sub( 33 | r"([.!\"()*#:;~])", 34 | '', 35 | question.lower(), 36 | ) 37 | question = question.rstrip(' ') 38 | 39 | #truncate question 40 | question_words = question.split(' ') 41 | if len(question_words)>max_ques_words: 42 | question = ' 
'.join(question_words[:max_ques_words]) 43 | 44 | return question 45 | 46 | 47 | def save_result(result, result_dir, filename, remove_duplicate=''): 48 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 49 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 50 | 51 | json.dump(result,open(result_file,'w')) 52 | 53 | dist.barrier() 54 | 55 | if utils.is_main_process(): 56 | # combine results from all processes 57 | result = [] 58 | 59 | for rank in range(utils.get_world_size()): 60 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 61 | res = json.load(open(result_file,'r')) 62 | result += res 63 | 64 | if remove_duplicate: 65 | result_new = [] 66 | id_list = [] 67 | for res in result: 68 | if res[remove_duplicate] not in id_list: 69 | id_list.append(res[remove_duplicate]) 70 | result_new.append(res) 71 | result = result_new 72 | 73 | json.dump(result,open(final_result_file,'w')) 74 | print('result file saved to %s'%final_result_file) 75 | 76 | return final_result_file 77 | 78 | 79 | 80 | from pycocotools.coco import COCO 81 | from pycocoevalcap.eval import COCOEvalCap 82 | from torchvision.datasets.utils import download_url 83 | 84 | def coco_caption_eval(coco_gt_root, results_file, split): 85 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json', 86 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'} 87 | filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'} 88 | 89 | download_url(urls[split],coco_gt_root) 90 | annotation_file = os.path.join(coco_gt_root,filenames[split]) 91 | 92 | # create coco object and coco_result object 93 | coco = COCO(annotation_file) 94 | coco_result = coco.loadRes(results_file) 95 | 96 | # create coco_eval object by taking coco and coco_result 97 | coco_eval = COCOEvalCap(coco, coco_result) 98 | 99 | # evaluate on a subset 
of images by setting 100 | # coco_eval.params['image_id'] = coco_result.getImgIds() 101 | # please remove this line when evaluating the full validation set 102 | # coco_eval.params['image_id'] = coco_result.getImgIds() 103 | 104 | # evaluate results 105 | # SPICE will take a few minutes the first time, but speeds up due to caching 106 | coco_eval.evaluate() 107 | 108 | # print output evaluation scores 109 | for metric, score in coco_eval.eval.items(): 110 | print(f'{metric}: {score:.3f}') 111 | 112 | return coco_eval -------------------------------------------------------------------------------- /data/video_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torchvision.datasets.utils import download_url 3 | 4 | from PIL import Image 5 | import torch 6 | import numpy as np 7 | import random 8 | import decord 9 | from decord import VideoReader 10 | import json 11 | import os 12 | from data.utils import pre_caption 13 | 14 | decord.bridge.set_bridge("torch") 15 | 16 | class ImageNorm(object): 17 | """Apply Normalization to Image Pixels on GPU 18 | """ 19 | def __init__(self, mean, std): 20 | self.mean = torch.tensor(mean).view(1, 3, 1, 1) 21 | self.std = torch.tensor(std).view(1, 3, 1, 1) 22 | 23 | def __call__(self, img): 24 | 25 | if torch.max(img) > 1 and self.mean.max() <= 1: 26 | img.div_(255.) 
27 | return img.sub_(self.mean).div_(self.std) 28 | 29 | def load_jsonl(filename): 30 | with open(filename, "r") as f: 31 | return [json.loads(l.strip("\n")) for l in f.readlines()] 32 | 33 | 34 | class VideoDataset(Dataset): 35 | 36 | def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'): 37 | ''' 38 | video_root (string): Root directory of videos 39 | ann_root (string): directory to store the annotation file 40 | ''' 41 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl' 42 | filename = 'msrvtt_test.jsonl' 43 | 44 | download_url(url,ann_root) 45 | self.annotation = load_jsonl(os.path.join(ann_root,filename)) 46 | 47 | self.num_frm = num_frm 48 | self.frm_sampling_strategy = frm_sampling_strategy 49 | self.max_img_size = max_img_size 50 | self.video_root = video_root 51 | self.video_fmt = video_fmt 52 | self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 53 | 54 | self.text = [pre_caption(ann['caption'],40) for ann in self.annotation] 55 | self.txt2video = [i for i in range(len(self.annotation))] 56 | self.video2txt = self.txt2video 57 | 58 | 59 | def __len__(self): 60 | return len(self.annotation) 61 | 62 | def __getitem__(self, index): 63 | 64 | ann = self.annotation[index] 65 | 66 | video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt) 67 | 68 | vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size) 69 | 70 | video = self.img_norm(vid_frm_array.float()) 71 | 72 | return video, ann['clip_name'] 73 | 74 | 75 | 76 | def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1): 77 | try: 78 | if not height or not width: 79 | vr = VideoReader(video_path) 80 | else: 81 | vr = VideoReader(video_path, width=width, height=height) 82 | 83 | vlen = len(vr) 84 | 85 | if
start_time or end_time: 86 | assert fps > 0, 'must provide video fps if specifying start and end time.' 87 | 88 | start_idx = min(int(start_time * fps), vlen) 89 | end_idx = min(int(end_time * fps), vlen) 90 | else: 91 | start_idx, end_idx = 0, vlen 92 | 93 | if self.frm_sampling_strategy == 'uniform': 94 | frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int) 95 | elif self.frm_sampling_strategy == 'rand': 96 | frame_indices = sorted(random.sample(range(vlen), self.num_frm)) 97 | elif self.frm_sampling_strategy == 'headtail': 98 | frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2)) 99 | frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2)) 100 | frame_indices = frame_indices_head + frame_indices_tail 101 | else: 102 | raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy)) 103 | 104 | raw_sample_frms = vr.get_batch(frame_indices) 105 | except Exception as e: 106 | raise RuntimeError('Failed to load video {}'.format(video_path)) from e # returning None here would only fail later in __getitem__ with an unrelated AttributeError 107 | 108 | raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2) 109 | 110 | return raw_sample_frms 111 | -------------------------------------------------------------------------------- /data/vqa_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from data.utils import pre_question 9 | 10 | from torchvision.datasets.utils import download_url 11 | 12 | class vqa_dataset(Dataset): 13 | def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"): 14 | self.split = split 15 | 16 | self.transform = transform 17 | self.vqa_root = vqa_root 18 | self.vg_root = vg_root 19 | 20 | if split=='train': 21 | urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json', 22 |
'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json', 23 | 'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'} 24 | 25 | self.annotation = [] 26 | for f in train_files: 27 | download_url(urls[f],ann_root) 28 | self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r')) 29 | else: 30 | download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root) 31 | self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r')) 32 | 33 | download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root) 34 | self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r')) 35 | 36 | 37 | def __len__(self): 38 | return len(self.annotation) 39 | 40 | def __getitem__(self, index): 41 | 42 | ann = self.annotation[index] 43 | 44 | if ann['dataset']=='vqa': 45 | image_path = os.path.join(self.vqa_root,ann['image']) 46 | elif ann['dataset']=='vg': 47 | image_path = os.path.join(self.vg_root,ann['image']) 48 | 49 | image = Image.open(image_path).convert('RGB') 50 | image = self.transform(image) 51 | 52 | if self.split == 'test': 53 | question = pre_question(ann['question']) 54 | question_id = ann['question_id'] 55 | return image, question, question_id 56 | 57 | 58 | elif self.split=='train': 59 | 60 | question = pre_question(ann['question']) 61 | 62 | if ann['dataset']=='vqa': 63 | answer_weight = {} 64 | for answer in ann['answer']: 65 | if answer in answer_weight.keys(): 66 | answer_weight[answer] += 1/len(ann['answer']) 67 | else: 68 | answer_weight[answer] = 1/len(ann['answer']) 69 | 70 | answers = list(answer_weight.keys()) 71 | weights = list(answer_weight.values()) 72 | 73 | elif ann['dataset']=='vg': 74 | answers = [ann['answer']] 75 | weights = [0.2] 76 | 77 | return image, question, answers, weights 78 | 79 | 80 | def vqa_collate_fn(batch): 81 | image_list, 
question_list, answer_list, weight_list, n = [], [], [], [], [] 82 | for image, question, answer, weights in batch: 83 | image_list.append(image) 84 | question_list.append(question) 85 | weight_list += weights 86 | answer_list += answer 87 | n.append(len(answer)) 88 | return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n -------------------------------------------------------------------------------- /eval_nocaps.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip import blip_decoder 26 | import utils 27 | from data import create_dataset, create_sampler, create_loader 28 | from data.utils import save_result 29 | 30 | @torch.no_grad() 31 | def evaluate(model, data_loader, device, config): 32 | # evaluate 33 | model.eval() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | header = 'Evaluation:' 37 | print_freq = 10 38 | 39 | result = [] 40 | for image, image_id in metric_logger.log_every(data_loader, print_freq, header): 41 | 42 | image = image.to(device) 43 | 44 | captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'], 45 | min_length=config['min_length'], repetition_penalty=1.1) 46 | 47 | for caption, img_id in zip(captions, image_id): 
48 | result.append({"image_id": img_id.item(), "caption": caption}) 49 | 50 | return result 51 | 52 | 53 | def main(args, config): 54 | utils.init_distributed_mode(args) 55 | 56 | device = torch.device(args.device) 57 | 58 | # fix the seed for reproducibility 59 | seed = args.seed + utils.get_rank() 60 | torch.manual_seed(seed) 61 | np.random.seed(seed) 62 | random.seed(seed) 63 | cudnn.benchmark = True 64 | 65 | #### Dataset #### 66 | print("Creating captioning dataset") 67 | val_dataset, test_dataset = create_dataset('nocaps', config) 68 | 69 | if args.distributed: 70 | num_tasks = utils.get_world_size() 71 | global_rank = utils.get_rank() 72 | samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank) 73 | else: 74 | samplers = [None,None] 75 | 76 | val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers, 77 | batch_size=[config['batch_size']]*2,num_workers=[4,4], 78 | is_trains=[False, False], collate_fns=[None,None]) 79 | 80 | #### Model #### 81 | print("Creating model") 82 | model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], 83 | prompt=config['prompt']) 84 | 85 | model = model.to(device) 86 | 87 | model_without_ddp = model 88 | if args.distributed: 89 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 90 | model_without_ddp = model.module 91 | 92 | val_result = evaluate(model_without_ddp, val_loader, device, config) 93 | val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id') 94 | test_result = evaluate(model_without_ddp, test_loader, device, config) 95 | test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id') 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--config', default='./configs/nocaps.yaml') 101 | parser.add_argument('--output_dir', default='output/NoCaps') 102 | 
parser.add_argument('--device', default='cuda') 103 | parser.add_argument('--seed', default=42, type=int) 104 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 105 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 106 | parser.add_argument('--distributed', default=True, type=lambda s: s.lower() in ('true', '1', 'yes')) # type=bool would parse any non-empty string, including 'False', as True 107 | args = parser.parse_args() 108 | 109 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 110 | 111 | args.result_dir = os.path.join(args.output_dir, 'result') 112 | 113 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 114 | Path(args.result_dir).mkdir(parents=True, exist_ok=True) 115 | 116 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 117 | 118 | main(args, config) -------------------------------------------------------------------------------- /eval_retrieval_video.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip_retrieval import blip_retrieval 26 | import utils 27 | from data.video_dataset import VideoDataset 28 | 29 | 30 | @torch.no_grad() 31 | def evaluation(model, data_loader, tokenizer, device, config): 32 | # test 33 | model.eval() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | header = 'Evaluation:' 37 | 38 | print('Computing features for evaluation...') 39 | start_time = time.time() 40 | 41 | texts = data_loader.dataset.text 42 | num_text = len(texts) 43 | text_bs = 256 44 | text_ids = [] 45 | text_embeds = [] 46 | text_atts = [] 47 | for i in range(0, num_text, text_bs): 48 | text = texts[i: min(num_text, i+text_bs)] 49 | text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device) 50 | text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text') 51 | text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:])) 52 | text_embeds.append(text_embed) 53 | text_ids.append(text_input.input_ids) 54 | text_atts.append(text_input.attention_mask) 55 | 56 | text_embeds = torch.cat(text_embeds,dim=0) 57 | text_ids = torch.cat(text_ids,dim=0) 58 | text_atts = torch.cat(text_atts,dim=0) 59 | text_ids[:,0] = tokenizer.additional_special_tokens_ids[0] 60 | 61 | video_feats = [] 62 | video_embeds = [] 63 | for video, video_id in 
data_loader: 64 | 65 | B,N,C,H,W = video.size() 66 | video = video.view(-1,C,H,W) 67 | video = video.to(device,non_blocking=True) 68 | video_feat = model.visual_encoder(video) 69 | video_embed = model.vision_proj(video_feat[:,0,:]) 70 | video_embed = video_embed.view(B,N,-1).mean(dim=1) 71 | video_embed = F.normalize(video_embed,dim=-1) 72 | 73 | video_feat = video_feat.view(B,-1,video_feat.shape[-1]) 74 | video_feats.append(video_feat.cpu()) 75 | video_embeds.append(video_embed) 76 | 77 | video_feats = torch.cat(video_feats,dim=0) 78 | video_embeds = torch.cat(video_embeds,dim=0) 79 | 80 | sims_matrix = video_embeds @ text_embeds.t() 81 | score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device) 82 | 83 | num_tasks = utils.get_world_size() 84 | rank = utils.get_rank() 85 | step = sims_matrix.size(0)//num_tasks + 1 86 | start = rank*step 87 | end = min(sims_matrix.size(0),start+step) 88 | 89 | for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): 90 | topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0) 91 | 92 | encoder_output = video_feats[start+i].repeat(config['k_test'],1,1).to(device,non_blocking=True) 93 | encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True) 94 | output = model.text_encoder(text_ids[topk_idx], 95 | attention_mask = text_atts[topk_idx], 96 | encoder_hidden_states = encoder_output, 97 | encoder_attention_mask = encoder_att, 98 | return_dict = True, 99 | ) 100 | score = model.itm_head(output.last_hidden_state[:,0,:])[:,1] 101 | score_matrix_v2t[start+i,topk_idx] = score + topk_sim 102 | 103 | sims_matrix = sims_matrix.t() 104 | score_matrix_t2v = torch.full((len(texts),len(texts)),-100.0).to(device) 105 | 106 | step = sims_matrix.size(0)//num_tasks + 1 107 | start = rank*step 108 | end = min(sims_matrix.size(0),start+step) 109 | 110 | for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): 111 | 112 | topk_sim, topk_idx
= sims.topk(k=config['k_test'], dim=0) 113 | encoder_output = video_feats[topk_idx].to(device,non_blocking=True) 114 | encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True) 115 | output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1), 116 | attention_mask = text_atts[start+i].repeat(config['k_test'],1), 117 | encoder_hidden_states = encoder_output, 118 | encoder_attention_mask = encoder_att, 119 | return_dict = True, 120 | ) 121 | score = model.itm_head(output.last_hidden_state[:,0,:])[:,1] 122 | score_matrix_t2v[start+i,topk_idx] = score + topk_sim 123 | 124 | if args.distributed: # relies on the module-level `args` created under __main__ 125 | dist.barrier() 126 | torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM) 127 | torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM) 128 | 129 | total_time = time.time() - start_time 130 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 131 | print('Evaluation time {}'.format(total_time_str)) 132 | 133 | return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy() 134 | 135 | 136 | 137 | @torch.no_grad() 138 | def itm_eval(scores_v2t, scores_t2v, txt2vid, vid2txt): 139 | 140 | #Video->Text 141 | ranks = np.zeros(scores_v2t.shape[0]) 142 | for index,score in enumerate(scores_v2t): 143 | inds = np.argsort(score)[::-1] 144 | ranks[index] = np.where(inds == vid2txt[index])[0][0] 145 | 146 | # Compute metrics 147 | tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 148 | tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 149 | tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 150 | 151 | #Text->Video 152 | ranks = np.zeros(scores_t2v.shape[0]) 153 | 154 | for index,score in enumerate(scores_t2v): 155 | inds = np.argsort(score)[::-1] 156 | ranks[index] = np.where(inds == txt2vid[index])[0][0] 157 | 158 | mdR = np.median(ranks+1) 159 | 160 | # Compute metrics 161 | vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 162 | vr5
= 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 163 | vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 164 | 165 | tr_mean = (tr1 + tr5 + tr10) / 3 166 | vr_mean = (vr1 + vr5 + vr10) / 3 167 | r_mean = (tr_mean + vr_mean) / 2 168 | 169 | eval_result = {'txt_r1': tr1, 170 | 'txt_r5': tr5, 171 | 'txt_r10': tr10, 172 | 'txt_r_mean': tr_mean, 173 | 'vid_r1': vr1, 174 | 'vid_r5': vr5, 175 | 'vid_r10': vr10, 176 | 'vid_r_mean': vr_mean, 177 | 'vid_mdR': mdR, 178 | 'r_mean': r_mean} 179 | return eval_result 180 | 181 | 182 | 183 | 184 | def main(args, config): 185 | utils.init_distributed_mode(args) 186 | 187 | device = torch.device(args.device) 188 | 189 | # fix the seed for reproducibility 190 | seed = args.seed + utils.get_rank() 191 | torch.manual_seed(seed) 192 | np.random.seed(seed) 193 | random.seed(seed) 194 | cudnn.benchmark = True 195 | 196 | #### Dataset #### 197 | print("Creating retrieval dataset") 198 | test_dataset = VideoDataset(config['video_root'],config['ann_root'],num_frm=config['num_frm_test'], 199 | max_img_size=config['image_size'], frm_sampling_strategy='uniform') 200 | 201 | test_loader = DataLoader( 202 | test_dataset, 203 | batch_size=config['batch_size'], 204 | num_workers=4, 205 | pin_memory=True, 206 | drop_last=False, 207 | shuffle=False, 208 | ) 209 | 210 | #### Model #### 211 | print("Creating model") 212 | model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit']) 213 | 214 | model = model.to(device) 215 | 216 | model_without_ddp = model 217 | if args.distributed: 218 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 219 | model_without_ddp = model.module 220 | 221 | score_v2t, score_t2v, = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config) 222 | 223 | if utils.is_main_process(): 224 | 225 | test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt) 226 | 
print(test_result) 227 | 228 | log_stats = {**{f'{k}': v for k, v in test_result.items()},} 229 | with open(os.path.join(args.output_dir, "test_result.txt"),"a") as f: 230 | f.write(json.dumps(log_stats) + "\n") 231 | 232 | 233 | if __name__ == '__main__': 234 | parser = argparse.ArgumentParser() 235 | parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml') 236 | parser.add_argument('--output_dir', default='output/Retrieval_msrvtt') 237 | parser.add_argument('--device', default='cuda') 238 | parser.add_argument('--seed', default=42, type=int) 239 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 240 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 241 | parser.add_argument('--distributed', default=True, type=lambda s: s.lower() in ('true', '1', 'yes')) # type=bool would parse any non-empty string, including 'False', as True 242 | args = parser.parse_args() 243 | 244 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 245 | 246 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 247 | 248 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 249 | 250 | main(args, config) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/BLIP/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/__init__.py -------------------------------------------------------------------------------- /models/blip.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | 11 | from models.vit import VisionTransformer, interpolate_pos_embed 12 | from models.med import BertConfig, BertModel, BertLMHeadModel 13 | from transformers import BertTokenizer 14 | 15 | import torch 16 | from torch import nn 17 | import torch.nn.functional as F 18 | 19 | import os 20 | from urllib.parse import urlparse 21 | from timm.models.hub import download_cached_file 22 | 23 | class BLIP_Base(nn.Module): 24 | def __init__(self, 25 | med_config = 'configs/med_config.json', 26 | image_size = 224, 27 | vit = 'base', 28 | vit_grad_ckpt = False, 29 | vit_ckpt_layer = 0, 30 | ): 31 | """ 32 | Args: 33 | med_config (str): path for the mixture of encoder-decoder model's configuration file 34 | image_size (int): input image size 35 | vit (str): model size of vision transformer 36 | """ 37 | super().__init__() 38 | 39 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 40 | self.tokenizer = init_tokenizer() 41 | med_config = BertConfig.from_json_file(med_config) 42 | med_config.encoder_width = vision_width 43 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 44 | 45 | 46 | def forward(self, image, caption, mode): 47 | 48 | assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal" 49 | text = self.tokenizer(caption, return_tensors="pt").to(image.device) 50 | 51 | if mode=='image': 52 | # return image features 53 | image_embeds = self.visual_encoder(image) 54 | return image_embeds 55 | 56 | elif mode=='text': 57 | # return text features 58 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 59 | return_dict = True, mode = 'text') 60 | return text_output.last_hidden_state 61 | 62 
| elif mode=='multimodal': 63 | # return multimodal features 64 | image_embeds = self.visual_encoder(image) 65 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 66 | 67 | text.input_ids[:,0] = self.tokenizer.enc_token_id 68 | output = self.text_encoder(text.input_ids, 69 | attention_mask = text.attention_mask, 70 | encoder_hidden_states = image_embeds, 71 | encoder_attention_mask = image_atts, 72 | return_dict = True, 73 | ) 74 | return output.last_hidden_state 75 | 76 | 77 | 78 | class BLIP_Decoder(nn.Module): 79 | def __init__(self, 80 | med_config = 'configs/med_config.json', 81 | image_size = 384, 82 | vit = 'base', 83 | vit_grad_ckpt = False, 84 | vit_ckpt_layer = 0, 85 | prompt = 'a picture of ', 86 | ): 87 | """ 88 | Args: 89 | med_config (str): path for the mixture of encoder-decoder model's configuration file 90 | image_size (int): input image size 91 | vit (str): model size of vision transformer 92 | """ 93 | super().__init__() 94 | 95 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 96 | self.tokenizer = init_tokenizer() 97 | med_config = BertConfig.from_json_file(med_config) 98 | med_config.encoder_width = vision_width 99 | self.text_decoder = BertLMHeadModel(config=med_config) 100 | 101 | self.prompt = prompt 102 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1 103 | 104 | 105 | def forward(self, image, caption): 106 | 107 | image_embeds = self.visual_encoder(image) 108 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 109 | 110 | text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) 111 | 112 | text.input_ids[:,0] = self.tokenizer.bos_token_id 113 | 114 | decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100) 115 | decoder_targets[:,:self.prompt_length] = -100 116 | 117 | decoder_output = 
self.text_decoder(text.input_ids, 118 | attention_mask = text.attention_mask, 119 | encoder_hidden_states = image_embeds, 120 | encoder_attention_mask = image_atts, 121 | labels = decoder_targets, 122 | return_dict = True, 123 | ) 124 | loss_lm = decoder_output.loss 125 | 126 | return loss_lm 127 | 128 | def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): 129 | image_embeds = self.visual_encoder(image) 130 | 131 | if not sample: 132 | image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) 133 | 134 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 135 | model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} 136 | 137 | prompt = [self.prompt] * image.size(0) 138 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) 139 | input_ids[:,0] = self.tokenizer.bos_token_id 140 | input_ids = input_ids[:, :-1] 141 | 142 | if sample: 143 | #nucleus sampling 144 | outputs = self.text_decoder.generate(input_ids=input_ids, 145 | max_length=max_length, 146 | min_length=min_length, 147 | do_sample=True, 148 | top_p=top_p, 149 | num_return_sequences=1, 150 | eos_token_id=self.tokenizer.sep_token_id, 151 | pad_token_id=self.tokenizer.pad_token_id, 152 | repetition_penalty=1.1, 153 | **model_kwargs) 154 | else: 155 | #beam search 156 | outputs = self.text_decoder.generate(input_ids=input_ids, 157 | max_length=max_length, 158 | min_length=min_length, 159 | num_beams=num_beams, 160 | eos_token_id=self.tokenizer.sep_token_id, 161 | pad_token_id=self.tokenizer.pad_token_id, 162 | repetition_penalty=repetition_penalty, 163 | **model_kwargs) 164 | 165 | captions = [] 166 | for output in outputs: 167 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 168 | captions.append(caption[len(self.prompt):]) 169 | return captions 170 | 171 | 172 | def blip_decoder(pretrained='',**kwargs): 173 | model = 
BLIP_Decoder(**kwargs) 174 | if pretrained: 175 | model,msg = load_checkpoint(model,pretrained) 176 | assert(len(msg.missing_keys)==0) 177 | return model 178 | 179 | def blip_feature_extractor(pretrained='',**kwargs): 180 | model = BLIP_Base(**kwargs) 181 | if pretrained: 182 | model,msg = load_checkpoint(model,pretrained) 183 | assert(len(msg.missing_keys)==0) 184 | return model 185 | 186 | def init_tokenizer(): 187 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 188 | tokenizer.add_special_tokens({'bos_token':'[DEC]'}) 189 | tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) 190 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 191 | return tokenizer 192 | 193 | 194 | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): 195 | 196 | assert vit in ['base', 'large'], "vit parameter must be base or large" 197 | if vit=='base': 198 | vision_width = 768 199 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, 200 | num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 201 | drop_path_rate=drop_path_rate 202 | ) 203 | elif vit=='large': 204 | vision_width = 1024 205 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, 206 | num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 207 | drop_path_rate=drop_path_rate or 0.1 # default to 0.1 only when no rate is passed; `0.1 or drop_path_rate` always evaluated to 0.1 208 | ) 209 | return visual_encoder, vision_width 210 | 211 | def is_url(url_or_filename): 212 | parsed = urlparse(url_or_filename) 213 | return parsed.scheme in ("http", "https") 214 | 215 | def load_checkpoint(model,url_or_filename): 216 | if is_url(url_or_filename): 217 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 218 | checkpoint = torch.load(cached_file, map_location='cpu') 219 | elif os.path.isfile(url_or_filename): 220 | checkpoint = 
torch.load(url_or_filename, map_location='cpu') 221 | else: 222 | raise RuntimeError('checkpoint url or path is invalid') 223 | 224 | state_dict = checkpoint['model'] 225 | 226 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 227 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): 228 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], 229 | model.visual_encoder_m) 230 | for key in model.state_dict().keys(): 231 | if key in state_dict.keys(): 232 | if state_dict[key].shape!=model.state_dict()[key].shape: 233 | del state_dict[key] 234 | 235 | msg = model.load_state_dict(state_dict,strict=False) 236 | print('load checkpoint from %s'%url_or_filename) 237 | return model,msg 238 | 239 | -------------------------------------------------------------------------------- /models/blip_itm.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig, BertModel 2 | from transformers import BertTokenizer 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from models.blip import create_vit, init_tokenizer, load_checkpoint 9 | 10 | class BLIP_ITM(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 384, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | embed_dim = 256, 18 | ): 19 | """ 20 | Args: 21 | med_config (str): path for the mixture of encoder-decoder model's configuration file 22 | image_size (int): input image size 23 | vit (str): model size of vision transformer 24 | """ 25 | super().__init__() 26 | 27 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 28 | self.tokenizer = init_tokenizer() 29 | med_config = BertConfig.from_json_file(med_config) 30 | med_config.encoder_width = vision_width 31 | self.text_encoder 
= BertModel(config=med_config, add_pooling_layer=False) 32 | 33 | text_width = self.text_encoder.config.hidden_size 34 | 35 | self.vision_proj = nn.Linear(vision_width, embed_dim) 36 | self.text_proj = nn.Linear(text_width, embed_dim) 37 | 38 | self.itm_head = nn.Linear(text_width, 2) 39 | 40 | 41 | def forward(self, image, caption, match_head='itm'): 42 | 43 | image_embeds = self.visual_encoder(image) 44 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 45 | 46 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 47 | return_tensors="pt").to(image.device) 48 | 49 | 50 | if match_head=='itm': 51 | output = self.text_encoder(text.input_ids, 52 | attention_mask = text.attention_mask, 53 | encoder_hidden_states = image_embeds, 54 | encoder_attention_mask = image_atts, 55 | return_dict = True, 56 | ) 57 | itm_output = self.itm_head(output.last_hidden_state[:,0,:]) 58 | return itm_output 59 | 60 | elif match_head=='itc': 61 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 62 | return_dict = True, mode = 'text') 63 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 64 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 65 | 66 | sim = image_feat @ text_feat.t() 67 | return sim 68 | 69 | 70 | def blip_itm(pretrained='',**kwargs): 71 | model = BLIP_ITM(**kwargs) 72 | if pretrained: 73 | model,msg = load_checkpoint(model,pretrained) 74 | assert(len(msg.missing_keys)==0) 75 | return model 76 | -------------------------------------------------------------------------------- /models/blip_nlvr.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig 2 | from models.nlvr_encoder import BertModel 3 | from models.vit import interpolate_pos_embed 4 | from models.blip import create_vit, init_tokenizer, is_url 5 | 6 | from timm.models.hub import 
download_cached_file 7 | import os 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from transformers import BertTokenizer 12 | import numpy as np 13 | 14 | class BLIP_NLVR(nn.Module): 15 | def __init__(self, 16 | med_config = 'configs/med_config.json', 17 | image_size = 480, 18 | vit = 'base', 19 | vit_grad_ckpt = False, 20 | vit_ckpt_layer = 0, 21 | ): 22 | """ 23 | Args: 24 | med_config (str): path for the mixture of encoder-decoder model's configuration file 25 | image_size (int): input image size 26 | vit (str): model size of vision transformer 27 | """ 28 | super().__init__() 29 | 30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 31 | self.tokenizer = init_tokenizer() 32 | med_config = BertConfig.from_json_file(med_config) 33 | med_config.encoder_width = vision_width 34 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 35 | 36 | self.cls_head = nn.Sequential( 37 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), 38 | nn.ReLU(), 39 | nn.Linear(self.text_encoder.config.hidden_size, 2) 40 | ) 41 | 42 | def forward(self, image, text, targets, train=True): 43 | 44 | image_embeds = self.visual_encoder(image) 45 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 46 | image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0)) 47 | 48 | text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device) 49 | text.input_ids[:,0] = self.tokenizer.enc_token_id 50 | 51 | output = self.text_encoder(text.input_ids, 52 | attention_mask = text.attention_mask, 53 | encoder_hidden_states = [image0_embeds,image1_embeds], 54 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)], 55 | image_atts[image0_embeds.size(0):]], 56 | return_dict = True, 57 | ) 58 | hidden_state = output.last_hidden_state[:,0,:] 59 | prediction = self.cls_head(hidden_state) 60 
| 61 | if train: 62 | loss = F.cross_entropy(prediction, targets) 63 | return loss 64 | else: 65 | return prediction 66 | 67 | def blip_nlvr(pretrained='',**kwargs): 68 | model = BLIP_NLVR(**kwargs) 69 | if pretrained: 70 | model,msg = load_checkpoint(model,pretrained) 71 | print("missing keys:") 72 | print(msg.missing_keys) 73 | return model 74 | 75 | 76 | def load_checkpoint(model,url_or_filename): 77 | if is_url(url_or_filename): 78 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 79 | checkpoint = torch.load(cached_file, map_location='cpu') 80 | elif os.path.isfile(url_or_filename): 81 | checkpoint = torch.load(url_or_filename, map_location='cpu') 82 | else: 83 | raise RuntimeError('checkpoint url or path is invalid') 84 | state_dict = checkpoint['model'] 85 | 86 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 87 | 88 | for key in list(state_dict.keys()): 89 | if 'crossattention.self.' in key: 90 | new_key0 = key.replace('self','self0') 91 | new_key1 = key.replace('self','self1') 92 | state_dict[new_key0] = state_dict[key] 93 | state_dict[new_key1] = state_dict[key] 94 | elif 'crossattention.output.dense.' in key: 95 | new_key0 = key.replace('dense','dense0') 96 | new_key1 = key.replace('dense','dense1') 97 | state_dict[new_key0] = state_dict[key] 98 | state_dict[new_key1] = state_dict[key] 99 | 100 | msg = model.load_state_dict(state_dict,strict=False) 101 | print('load checkpoint from %s'%url_or_filename) 102 | return model,msg 103 | -------------------------------------------------------------------------------- /models/blip_pretrain.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | from models.med import BertConfig, BertModel, BertLMHeadModel 9 | from transformers import BertTokenizer 10 | import transformers 11 | transformers.logging.set_verbosity_error() 12 | logger = transformers.logging.get_logger(__name__) # used by tie_encoder_decoder_weights below 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F 16 | 17 | from models.blip import create_vit, init_tokenizer, load_checkpoint 18 | 19 | class BLIP_Pretrain(nn.Module): 20 | def __init__(self, 21 | med_config = 'configs/bert_config.json', 22 | image_size = 224, 23 | vit = 'base', 24 | vit_grad_ckpt = False, 25 | vit_ckpt_layer = 0, 26 | embed_dim = 256, 27 | queue_size = 57600, 28 | momentum = 0.995, 29 | ): 30 | """ 31 | Args: 32 | med_config (str): path for the mixture of encoder-decoder model's configuration file 33 | image_size (int): input image size 34 | vit (str): model size of vision transformer 35 | """ 36 | super().__init__() 37 | 38 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0) 39 | 40 | if vit=='base': 41 | checkpoint = torch.hub.load_state_dict_from_url( 42 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 43 | map_location="cpu", check_hash=True) 44 | state_dict = checkpoint["model"] 45 | msg = self.visual_encoder.load_state_dict(state_dict,strict=False) 46 | elif vit=='large': 47 | from timm.models.helpers import load_custom_pretrained 48 | from timm.models.vision_transformer import default_cfgs 49 | load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k']) 50 | 51 | self.tokenizer = init_tokenizer() 52 | encoder_config = BertConfig.from_json_file(med_config) 53 | encoder_config.encoder_width = vision_width 54 | self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False) 55 | 
self.text_encoder.resize_token_embeddings(len(self.tokenizer)) 56 | 57 | text_width = self.text_encoder.config.hidden_size 58 | 59 | self.vision_proj = nn.Linear(vision_width, embed_dim) 60 | self.text_proj = nn.Linear(text_width, embed_dim) 61 | 62 | self.itm_head = nn.Linear(text_width, 2) 63 | 64 | # create momentum encoders 65 | self.visual_encoder_m, vision_width = create_vit(vit,image_size) 66 | self.vision_proj_m = nn.Linear(vision_width, embed_dim) 67 | self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False) 68 | self.text_proj_m = nn.Linear(text_width, embed_dim) 69 | 70 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], 71 | [self.vision_proj,self.vision_proj_m], 72 | [self.text_encoder,self.text_encoder_m], 73 | [self.text_proj,self.text_proj_m], 74 | ] 75 | self.copy_params() 76 | 77 | # create the queue 78 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) 79 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) 80 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 81 | 82 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0) 83 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0) 84 | 85 | self.queue_size = queue_size 86 | self.momentum = momentum 87 | self.temp = nn.Parameter(0.07*torch.ones([])) 88 | 89 | # create the decoder 90 | decoder_config = BertConfig.from_json_file(med_config) 91 | decoder_config.encoder_width = vision_width 92 | self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config) 93 | self.text_decoder.resize_token_embeddings(len(self.tokenizer)) 94 | tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention') 95 | 96 | 97 | def forward(self, image, caption, alpha): 98 | with torch.no_grad(): 99 | self.temp.clamp_(0.001,0.5) 100 | 101 | image_embeds = self.visual_encoder(image) 102 | image_atts = 
torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 103 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 104 | 105 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30, 106 | return_tensors="pt").to(image.device) 107 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 108 | return_dict = True, mode = 'text') 109 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 110 | 111 | # get momentum features 112 | with torch.no_grad(): 113 | self._momentum_update() 114 | image_embeds_m = self.visual_encoder_m(image) 115 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) 116 | image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) 117 | 118 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, 119 | return_dict = True, mode = 'text') 120 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 121 | text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) 122 | 123 | sim_i2t_m = image_feat_m @ text_feat_all / self.temp 124 | sim_t2i_m = text_feat_m @ image_feat_all / self.temp 125 | 126 | sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device) 127 | sim_targets.fill_diagonal_(1) 128 | 129 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets 130 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 131 | 132 | sim_i2t = image_feat @ text_feat_all / self.temp 133 | sim_t2i = text_feat @ image_feat_all / self.temp 134 | 135 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() 136 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 137 | 138 | loss_ita = (loss_i2t+loss_t2i)/2 139 | 140 | self._dequeue_and_enqueue(image_feat_m, text_feat_m) 141 | 142 | 
###============== Image-text Matching ===================### 143 | encoder_input_ids = text.input_ids.clone() 144 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id 145 | 146 | # forward the positive image-text pair 147 | bs = image.size(0) 148 | output_pos = self.text_encoder(encoder_input_ids, 149 | attention_mask = text.attention_mask, 150 | encoder_hidden_states = image_embeds, 151 | encoder_attention_mask = image_atts, 152 | return_dict = True, 153 | ) 154 | with torch.no_grad(): 155 | weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4 156 | weights_t2i.fill_diagonal_(0) 157 | weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4 158 | weights_i2t.fill_diagonal_(0) 159 | 160 | # select a negative image for each text 161 | image_embeds_neg = [] 162 | for b in range(bs): 163 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 164 | image_embeds_neg.append(image_embeds[neg_idx]) 165 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 166 | 167 | # select a negative text for each image 168 | text_ids_neg = [] 169 | text_atts_neg = [] 170 | for b in range(bs): 171 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 172 | text_ids_neg.append(encoder_input_ids[neg_idx]) 173 | text_atts_neg.append(text.attention_mask[neg_idx]) 174 | 175 | text_ids_neg = torch.stack(text_ids_neg,dim=0) 176 | text_atts_neg = torch.stack(text_atts_neg,dim=0) 177 | 178 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) 179 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) 180 | 181 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) 182 | image_atts_all = torch.cat([image_atts,image_atts],dim=0) 183 | 184 | output_neg = self.text_encoder(text_ids_all, 185 | attention_mask = text_atts_all, 186 | encoder_hidden_states = image_embeds_all, 187 | encoder_attention_mask = image_atts_all, 188 | return_dict = True, 189 | ) 190 | 191 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], 
output_neg.last_hidden_state[:,0,:]],dim=0) 192 | vl_output = self.itm_head(vl_embeddings) 193 | 194 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], 195 | dim=0).to(image.device) 196 | loss_itm = F.cross_entropy(vl_output, itm_labels) 197 | 198 | ##================= LM ========================## 199 | decoder_input_ids = text.input_ids.clone() 200 | decoder_input_ids[:,0] = self.tokenizer.bos_token_id 201 | decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100) 202 | 203 | decoder_output = self.text_decoder(decoder_input_ids, 204 | attention_mask = text.attention_mask, 205 | encoder_hidden_states = image_embeds, 206 | encoder_attention_mask = image_atts, 207 | labels = decoder_targets, 208 | return_dict = True, 209 | ) 210 | 211 | loss_lm = decoder_output.loss 212 | return loss_ita, loss_itm, loss_lm 213 | 214 | 215 | 216 | @torch.no_grad() 217 | def copy_params(self): 218 | for model_pair in self.model_pairs: 219 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 220 | param_m.data.copy_(param.data) # initialize 221 | param_m.requires_grad = False # not update by gradient 222 | 223 | 224 | @torch.no_grad() 225 | def _momentum_update(self): 226 | for model_pair in self.model_pairs: 227 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 228 | param_m.data = param_m.data * self.momentum + param.data * (1. 
- self.momentum) 229 | 230 | 231 | @torch.no_grad() 232 | def _dequeue_and_enqueue(self, image_feat, text_feat): 233 | # gather keys before updating queue 234 | image_feats = concat_all_gather(image_feat) 235 | text_feats = concat_all_gather(text_feat) 236 | 237 | batch_size = image_feats.shape[0] 238 | 239 | ptr = int(self.queue_ptr) 240 | assert self.queue_size % batch_size == 0 # for simplicity 241 | 242 | # replace the keys at ptr (dequeue and enqueue) 243 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T 244 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T 245 | ptr = (ptr + batch_size) % self.queue_size # move pointer 246 | 247 | self.queue_ptr[0] = ptr 248 | 249 | 250 | def blip_pretrain(**kwargs): 251 | model = BLIP_Pretrain(**kwargs) 252 | return model 253 | 254 | 255 | @torch.no_grad() 256 | def concat_all_gather(tensor): 257 | """ 258 | Performs all_gather operation on the provided tensors. 259 | *** Warning ***: torch.distributed.all_gather has no gradient. 260 | """ 261 | tensors_gather = [torch.ones_like(tensor) 262 | for _ in range(torch.distributed.get_world_size())] 263 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 264 | 265 | output = torch.cat(tensors_gather, dim=0) 266 | return output 267 | 268 | 269 | from typing import List 270 | def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str): 271 | uninitialized_encoder_weights: List[str] = [] 272 | if decoder.__class__ != encoder.__class__: 273 | logger.info( 274 | f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized." 
275 | ) 276 | 277 | def tie_encoder_to_decoder_recursively( 278 | decoder_pointer: nn.Module, 279 | encoder_pointer: nn.Module, 280 | module_name: str, 281 | uninitialized_encoder_weights: List[str], 282 | skip_key: str, 283 | depth=0, 284 | ): 285 | assert isinstance(decoder_pointer, nn.Module) and isinstance( 286 | encoder_pointer, nn.Module 287 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" 288 | if hasattr(decoder_pointer, "weight") and skip_key not in module_name: 289 | assert hasattr(encoder_pointer, "weight") 290 | encoder_pointer.weight = decoder_pointer.weight 291 | if hasattr(decoder_pointer, "bias"): 292 | assert hasattr(encoder_pointer, "bias") 293 | encoder_pointer.bias = decoder_pointer.bias 294 | print(module_name+' is tied') 295 | return 296 | 297 | encoder_modules = encoder_pointer._modules 298 | decoder_modules = decoder_pointer._modules 299 | if len(decoder_modules) > 0: 300 | assert ( 301 | len(encoder_modules) > 0 302 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" 303 | 304 | all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) 305 | encoder_layer_pos = 0 306 | for name, module in decoder_modules.items(): 307 | if name.isdigit(): 308 | encoder_name = str(int(name) + encoder_layer_pos) 309 | decoder_name = name 310 | if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len( 311 | encoder_modules 312 | ) != len(decoder_modules): 313 | # this can happen if the name corresponds to the position in a list module list of layers 314 | # in this case the decoder has added a cross-attention that the encoder does not have 315 | # thus skip this step and subtract one layer pos from encoder 316 | encoder_layer_pos -= 1 317 | continue 318 | elif name not in encoder_modules: 319 | continue 320 | elif depth > 500: 321 | raise ValueError( 322 | "Max depth of recursive function `tie_encoder_to_decoder` 
reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." 323 | ) 324 | else: 325 | decoder_name = encoder_name = name 326 | tie_encoder_to_decoder_recursively( 327 | decoder_modules[decoder_name], 328 | encoder_modules[encoder_name], 329 | module_name + "/" + name, 330 | uninitialized_encoder_weights, 331 | skip_key, 332 | depth=depth + 1, 333 | ) 334 | all_encoder_weights.remove(module_name + "/" + encoder_name) 335 | 336 | uninitialized_encoder_weights += list(all_encoder_weights) 337 | 338 | # tie weights recursively 339 | tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key) 340 | -------------------------------------------------------------------------------- /models/blip_retrieval.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig, BertModel 2 | from transformers import BertTokenizer 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from models.blip import create_vit, init_tokenizer, load_checkpoint 9 | 10 | class BLIP_Retrieval(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 384, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | embed_dim = 256, 18 | queue_size = 57600, 19 | momentum = 0.995, 20 | negative_all_rank = False, 21 | ): 22 | """ 23 | Args: 24 | med_config (str): path for the mixture of encoder-decoder model's configuration file 25 | image_size (int): input image size 26 | vit (str): model size of vision transformer 27 | """ 28 | super().__init__() 29 | 30 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 31 | self.tokenizer = init_tokenizer() 32 | med_config = BertConfig.from_json_file(med_config) 33 | med_config.encoder_width = vision_width 34 | self.text_encoder = BertModel(config=med_config, 
add_pooling_layer=False) 35 | 36 | text_width = self.text_encoder.config.hidden_size 37 | 38 | self.vision_proj = nn.Linear(vision_width, embed_dim) 39 | self.text_proj = nn.Linear(text_width, embed_dim) 40 | 41 | self.itm_head = nn.Linear(text_width, 2) 42 | 43 | # create momentum encoders 44 | self.visual_encoder_m, vision_width = create_vit(vit,image_size) 45 | self.vision_proj_m = nn.Linear(vision_width, embed_dim) 46 | self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False) 47 | self.text_proj_m = nn.Linear(text_width, embed_dim) 48 | 49 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], 50 | [self.vision_proj,self.vision_proj_m], 51 | [self.text_encoder,self.text_encoder_m], 52 | [self.text_proj,self.text_proj_m], 53 | ] 54 | self.copy_params() 55 | 56 | # create the queue 57 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) 58 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) 59 | self.register_buffer("idx_queue", torch.full((1,queue_size),-100)) 60 | self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long)) 61 | 62 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0) 63 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0) 64 | 65 | self.queue_size = queue_size 66 | self.momentum = momentum 67 | self.temp = nn.Parameter(0.07*torch.ones([])) 68 | 69 | self.negative_all_rank = negative_all_rank 70 | 71 | 72 | def forward(self, image, caption, alpha, idx): 73 | with torch.no_grad(): 74 | self.temp.clamp_(0.001,0.5) 75 | 76 | image_embeds = self.visual_encoder(image) 77 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 78 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 79 | 80 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 81 | return_tensors="pt").to(image.device) 82 | 83 | text_output = self.text_encoder(text.input_ids, attention_mask = 
text.attention_mask, 84 | return_dict = True, mode = 'text') 85 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 86 | 87 | ###============== Image-text Contrastive Learning ===================### 88 | idx = idx.view(-1,1) 89 | idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1) 90 | pos_idx = torch.eq(idx, idx_all).float() 91 | sim_targets = pos_idx / pos_idx.sum(1,keepdim=True) 92 | 93 | # get momentum features 94 | with torch.no_grad(): 95 | self._momentum_update() 96 | image_embeds_m = self.visual_encoder_m(image) 97 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) 98 | image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) 99 | 100 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, 101 | return_dict = True, mode = 'text') 102 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 103 | text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) 104 | 105 | sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp 106 | sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp 107 | 108 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets 109 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 110 | 111 | sim_i2t = image_feat @ text_feat_m_all / self.temp 112 | sim_t2i = text_feat @ image_feat_m_all / self.temp 113 | 114 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() 115 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 116 | 117 | loss_ita = (loss_i2t+loss_t2i)/2 118 | 119 | idxs = concat_all_gather(idx) 120 | self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs) 121 | 122 | ###============== Image-text Matching ===================### 123 | encoder_input_ids = text.input_ids.clone() 124 | 
encoder_input_ids[:,0] = self.tokenizer.enc_token_id 125 | 126 | # forward the positive image-text pair 127 | bs = image.size(0) 128 | output_pos = self.text_encoder(encoder_input_ids, 129 | attention_mask = text.attention_mask, 130 | encoder_hidden_states = image_embeds, 131 | encoder_attention_mask = image_atts, 132 | return_dict = True, 133 | ) 134 | 135 | 136 | if self.negative_all_rank: 137 | # compute sample similarity 138 | with torch.no_grad(): 139 | mask = torch.eq(idx, idxs.t()) 140 | 141 | image_feat_world = concat_all_gather(image_feat) 142 | text_feat_world = concat_all_gather(text_feat) 143 | 144 | sim_i2t = image_feat @ text_feat_world.t() / self.temp 145 | sim_t2i = text_feat @ image_feat_world.t() / self.temp 146 | 147 | weights_i2t = F.softmax(sim_i2t,dim=1) 148 | weights_i2t.masked_fill_(mask, 0) 149 | 150 | weights_t2i = F.softmax(sim_t2i,dim=1) 151 | weights_t2i.masked_fill_(mask, 0) 152 | 153 | image_embeds_world = all_gather_with_grad(image_embeds) 154 | 155 | # select a negative image (from all ranks) for each text 156 | image_embeds_neg = [] 157 | for b in range(bs): 158 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 159 | image_embeds_neg.append(image_embeds_world[neg_idx]) 160 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 161 | 162 | # select a negative text (from all ranks) for each image 163 | input_ids_world = concat_all_gather(encoder_input_ids) 164 | att_mask_world = concat_all_gather(text.attention_mask) 165 | 166 | text_ids_neg = [] 167 | text_atts_neg = [] 168 | for b in range(bs): 169 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 170 | text_ids_neg.append(input_ids_world[neg_idx]) 171 | text_atts_neg.append(att_mask_world[neg_idx]) 172 | 173 | else: 174 | with torch.no_grad(): 175 | mask = torch.eq(idx, idx.t()) 176 | 177 | sim_i2t = image_feat @ text_feat.t() / self.temp 178 | sim_t2i = text_feat @ image_feat.t() / self.temp 179 | 180 | weights_i2t = F.softmax(sim_i2t,dim=1) 181 |
weights_i2t.masked_fill_(mask, 0) 182 | 183 | weights_t2i = F.softmax(sim_t2i,dim=1) 184 | weights_t2i.masked_fill_(mask, 0) 185 | 186 | # select a negative image (from same rank) for each text 187 | image_embeds_neg = [] 188 | for b in range(bs): 189 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 190 | image_embeds_neg.append(image_embeds[neg_idx]) 191 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 192 | 193 | # select a negative text (from same rank) for each image 194 | text_ids_neg = [] 195 | text_atts_neg = [] 196 | for b in range(bs): 197 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 198 | text_ids_neg.append(encoder_input_ids[neg_idx]) 199 | text_atts_neg.append(text.attention_mask[neg_idx]) 200 | 201 | text_ids_neg = torch.stack(text_ids_neg,dim=0) 202 | text_atts_neg = torch.stack(text_atts_neg,dim=0) 203 | 204 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) 205 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) 206 | 207 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) 208 | image_atts_all = torch.cat([image_atts,image_atts],dim=0) 209 | 210 | output_neg = self.text_encoder(text_ids_all, 211 | attention_mask = text_atts_all, 212 | encoder_hidden_states = image_embeds_all, 213 | encoder_attention_mask = image_atts_all, 214 | return_dict = True, 215 | ) 216 | 217 | 218 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) 219 | vl_output = self.itm_head(vl_embeddings) 220 | 221 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], 222 | dim=0).to(image.device) 223 | loss_itm = F.cross_entropy(vl_output, itm_labels) 224 | 225 | return loss_ita, loss_itm 226 | 227 | 228 | @torch.no_grad() 229 | def copy_params(self): 230 | for model_pair in self.model_pairs: 231 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 232 | 
param_m.data.copy_(param.data) # initialize 233 | param_m.requires_grad = False # not update by gradient 234 | 235 | 236 | @torch.no_grad() 237 | def _momentum_update(self): 238 | for model_pair in self.model_pairs: 239 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 240 | param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) 241 | 242 | 243 | @torch.no_grad() 244 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs): 245 | # gather keys before updating queue 246 | image_feats = concat_all_gather(image_feat) 247 | text_feats = concat_all_gather(text_feat) 248 | 249 | 250 | batch_size = image_feats.shape[0] 251 | 252 | ptr = int(self.ptr_queue) 253 | assert self.queue_size % batch_size == 0 # for simplicity 254 | 255 | # replace the keys at ptr (dequeue and enqueue) 256 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T 257 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T 258 | self.idx_queue[:, ptr:ptr + batch_size] = idxs.T 259 | ptr = (ptr + batch_size) % self.queue_size # move pointer 260 | 261 | self.ptr_queue[0] = ptr 262 | 263 | 264 | def blip_retrieval(pretrained='',**kwargs): 265 | model = BLIP_Retrieval(**kwargs) 266 | if pretrained: 267 | model,msg = load_checkpoint(model,pretrained) 268 | print("missing keys:") 269 | print(msg.missing_keys) 270 | return model 271 | 272 | 273 | @torch.no_grad() 274 | def concat_all_gather(tensor): 275 | """ 276 | Performs all_gather operation on the provided tensors. 277 | *** Warning ***: torch.distributed.all_gather has no gradient. 
278 | """ 279 | tensors_gather = [torch.ones_like(tensor) 280 | for _ in range(torch.distributed.get_world_size())] 281 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 282 | 283 | output = torch.cat(tensors_gather, dim=0) 284 | return output 285 | 286 | 287 | class GatherLayer(torch.autograd.Function): 288 | """ 289 | Gather tensors from all workers with support for backward propagation: 290 | This implementation does not cut the gradients as torch.distributed.all_gather does. 291 | """ 292 | 293 | @staticmethod 294 | def forward(ctx, x): 295 | output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())] 296 | torch.distributed.all_gather(output, x) 297 | return tuple(output) 298 | 299 | @staticmethod 300 | def backward(ctx, *grads): 301 | all_gradients = torch.stack(grads) 302 | torch.distributed.all_reduce(all_gradients) 303 | return all_gradients[torch.distributed.get_rank()] 304 | 305 | 306 | def all_gather_with_grad(tensors): 307 | """ 308 | Performs all_gather operation on the provided tensors. 309 | Graph remains connected for backward grad computation. 
310 | """ 311 | # Queue the gathered tensors 312 | world_size = torch.distributed.get_world_size() 313 | # There is no need for reduction in the single-proc case 314 | if world_size == 1: 315 | return tensors 316 | 317 | tensor_all = GatherLayer.apply(tensors) 318 | 319 | return torch.cat(tensor_all, dim=0) 320 | -------------------------------------------------------------------------------- /models/blip_vqa.py: -------------------------------------------------------------------------------- 1 | from models.med import BertConfig, BertModel, BertLMHeadModel 2 | from models.blip import create_vit, init_tokenizer, load_checkpoint 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from transformers import BertTokenizer 8 | import numpy as np 9 | 10 | class BLIP_VQA(nn.Module): 11 | def __init__(self, 12 | med_config = 'configs/med_config.json', 13 | image_size = 480, 14 | vit = 'base', 15 | vit_grad_ckpt = False, 16 | vit_ckpt_layer = 0, 17 | ): 18 | """ 19 | Args: 20 | med_config (str): path for the mixture of encoder-decoder model's configuration file 21 | image_size (int): input image size 22 | vit (str): model size of vision transformer 23 | """ 24 | super().__init__() 25 | 26 | self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 27 | self.tokenizer = init_tokenizer() 28 | 29 | encoder_config = BertConfig.from_json_file(med_config) 30 | encoder_config.encoder_width = vision_width 31 | self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) 32 | 33 | decoder_config = BertConfig.from_json_file(med_config) 34 | self.text_decoder = BertLMHeadModel(config=decoder_config) 35 | 36 | 37 | def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128): 38 | 39 | image_embeds = self.visual_encoder(image) 40 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 41 | 42 | 
question = self.tokenizer(question, padding='longest', truncation=True, max_length=35, 43 | return_tensors="pt").to(image.device) 44 | question.input_ids[:,0] = self.tokenizer.enc_token_id 45 | 46 | if train: 47 | ''' 48 | n: number of answers for each question 49 | weights: weight for each answer 50 | ''' 51 | answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device) 52 | answer.input_ids[:,0] = self.tokenizer.bos_token_id 53 | answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100) 54 | 55 | question_output = self.text_encoder(question.input_ids, 56 | attention_mask = question.attention_mask, 57 | encoder_hidden_states = image_embeds, 58 | encoder_attention_mask = image_atts, 59 | return_dict = True) 60 | 61 | question_states = [] 62 | question_atts = [] 63 | for b, n in enumerate(n): 64 | question_states += [question_output.last_hidden_state[b]]*n 65 | question_atts += [question.attention_mask[b]]*n 66 | question_states = torch.stack(question_states,0) 67 | question_atts = torch.stack(question_atts,0) 68 | 69 | answer_output = self.text_decoder(answer.input_ids, 70 | attention_mask = answer.attention_mask, 71 | encoder_hidden_states = question_states, 72 | encoder_attention_mask = question_atts, 73 | labels = answer_targets, 74 | return_dict = True, 75 | reduction = 'none', 76 | ) 77 | 78 | loss = weights * answer_output.loss 79 | loss = loss.sum()/image.size(0) 80 | 81 | return loss 82 | 83 | 84 | else: 85 | question_output = self.text_encoder(question.input_ids, 86 | attention_mask = question.attention_mask, 87 | encoder_hidden_states = image_embeds, 88 | encoder_attention_mask = image_atts, 89 | return_dict = True) 90 | 91 | if inference=='generate': 92 | num_beams = 3 93 | question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0) 94 | question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device) 95 | model_kwargs = 
{"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts} 96 | 97 | bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device) 98 | 99 | outputs = self.text_decoder.generate(input_ids=bos_ids, 100 | max_length=10, 101 | min_length=1, 102 | num_beams=num_beams, 103 | eos_token_id=self.tokenizer.sep_token_id, 104 | pad_token_id=self.tokenizer.pad_token_id, 105 | **model_kwargs) 106 | 107 | answers = [] 108 | for output in outputs: 109 | answer = self.tokenizer.decode(output, skip_special_tokens=True) 110 | answers.append(answer) 111 | return answers 112 | 113 | elif inference=='rank': 114 | max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask, 115 | answer.input_ids, answer.attention_mask, k_test) 116 | return max_ids 117 | 118 | 119 | 120 | def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k): 121 | 122 | num_ques = question_states.size(0) 123 | start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token 124 | 125 | start_output = self.text_decoder(start_ids, 126 | encoder_hidden_states = question_states, 127 | encoder_attention_mask = question_atts, 128 | return_dict = True, 129 | reduction = 'none') 130 | logits = start_output.logits[:,0,:] # first token's logit 131 | 132 | # topk_probs: top-k probability 133 | # topk_ids: [num_question, k] 134 | answer_first_token = answer_ids[:,1] 135 | prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token) 136 | topk_probs, topk_ids = prob_first_token.topk(k,dim=1) 137 | 138 | # answer input: [num_question*k, answer_len] 139 | input_ids = [] 140 | input_atts = [] 141 | for b, topk_id in enumerate(topk_ids): 142 | input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) 143 | input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) 144 | input_ids = torch.cat(input_ids,dim=0) 145 | input_atts = torch.cat(input_atts,dim=0) 146 | 147 | targets_ids = 
input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100) 148 | 149 | # repeat encoder's output for top-k answers 150 | question_states = tile(question_states, 0, k) 151 | question_atts = tile(question_atts, 0, k) 152 | 153 | output = self.text_decoder(input_ids, 154 | attention_mask = input_atts, 155 | encoder_hidden_states = question_states, 156 | encoder_attention_mask = question_atts, 157 | labels = targets_ids, 158 | return_dict = True, 159 | reduction = 'none') 160 | 161 | log_probs_sum = -output.loss 162 | log_probs_sum = log_probs_sum.view(num_ques,k) 163 | 164 | max_topk_ids = log_probs_sum.argmax(dim=1) 165 | max_ids = topk_ids[max_topk_ids>=0,max_topk_ids] 166 | 167 | return max_ids 168 | 169 | 170 | def blip_vqa(pretrained='',**kwargs): 171 | model = BLIP_VQA(**kwargs) 172 | if pretrained: 173 | model,msg = load_checkpoint(model,pretrained) 174 | # assert(len(msg.missing_keys)==0) 175 | return model 176 | 177 | 178 | def tile(x, dim, n_tile): 179 | init_dim = x.size(dim) 180 | repeat_idx = [1] * x.dim() 181 | repeat_idx[dim] = n_tile 182 | x = x.repeat(*(repeat_idx)) 183 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) 184 | return torch.index_select(x, dim, order_index.to(x.device)) 185 | 186 | -------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on timm code base 8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from functools import partial 15 | 16 | from timm.models.vision_transformer import _cfg, PatchEmbed, resize_pos_embed # resize_pos_embed is used by _load_weights 17 | from timm.models.registry import register_model 18 | from timm.models.layers import trunc_normal_, DropPath 19 | from timm.models.helpers import named_apply, adapt_input_conv 20 | 21 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 22 | 23 | class Mlp(nn.Module): 24 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 25 | """ 26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x): 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 50 | self.scale = qk_scale or head_dim ** -0.5 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 |
self.attn_gradients = None 56 | self.attention_map = None 57 | 58 | def save_attn_gradients(self, attn_gradients): 59 | self.attn_gradients = attn_gradients 60 | 61 | def get_attn_gradients(self): 62 | return self.attn_gradients 63 | 64 | def save_attention_map(self, attention_map): 65 | self.attention_map = attention_map 66 | 67 | def get_attention_map(self): 68 | return self.attention_map 69 | 70 | def forward(self, x, register_hook=False): 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 73 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 74 | 75 | attn = (q @ k.transpose(-2, -1)) * self.scale 76 | attn = attn.softmax(dim=-1) 77 | attn = self.attn_drop(attn) 78 | 79 | if register_hook: 80 | self.save_attention_map(attn) 81 | attn.register_hook(self.save_attn_gradients) 82 | 83 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 84 | x = self.proj(x) 85 | x = self.proj_drop(x) 86 | return x 87 | 88 | 89 | class Block(nn.Module): 90 | 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): 93 | super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 97 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 98 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 99 | self.norm2 = norm_layer(dim) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | if use_grad_checkpointing: 104 | self.attn = checkpoint_wrapper(self.attn) 105 | self.mlp = checkpoint_wrapper(self.mlp) 106 | 107 | def forward(self, x, register_hook=False): 108 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 109 | x = x + self.drop_path(self.mlp(self.norm2(x))) 110 | return x 111 | 112 | 113 | class VisionTransformer(nn.Module): 114 | """ Vision Transformer 115 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 116 | https://arxiv.org/abs/2010.11929 117 | """ 118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 119 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 120 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 121 | use_grad_checkpointing=False, ckpt_layer=0): 122 | """ 123 | Args: 124 | img_size (int, tuple): input image size 125 | patch_size (int, tuple): patch size 126 | in_chans (int): number of input channels 127 | num_classes (int): number of classes for classification head 128 | embed_dim (int): embedding dimension 129 | depth (int): depth of transformer 130 | num_heads (int): number of attention heads 131 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 132 | qkv_bias (bool): enable bias for qkv if True 133 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 134 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 135 | drop_rate (float): dropout rate 136 | attn_drop_rate (float): attention dropout rate 137 | drop_path_rate (float): stochastic depth rate 138 | norm_layer: (nn.Module): normalization layer 139 | """ 140 | super().__init__() 141 | 
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 142 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 143 | 144 | self.patch_embed = PatchEmbed( 145 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 146 | 147 | num_patches = self.patch_embed.num_patches 148 | 149 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 150 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 151 | self.pos_drop = nn.Dropout(p=drop_rate) 152 | 153 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 154 | self.blocks = nn.ModuleList([ 155 | Block( 156 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 157 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 158 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) 159 | ) 160 | for i in range(depth)]) 161 | self.norm = norm_layer(embed_dim) 162 | 163 | trunc_normal_(self.pos_embed, std=.02) 164 | trunc_normal_(self.cls_token, std=.02) 165 | self.apply(self._init_weights) 166 | 167 | def _init_weights(self, m): 168 | if isinstance(m, nn.Linear): 169 | trunc_normal_(m.weight, std=.02) 170 | if isinstance(m, nn.Linear) and m.bias is not None: 171 | nn.init.constant_(m.bias, 0) 172 | elif isinstance(m, nn.LayerNorm): 173 | nn.init.constant_(m.bias, 0) 174 | nn.init.constant_(m.weight, 1.0) 175 | 176 | @torch.jit.ignore 177 | def no_weight_decay(self): 178 | return {'pos_embed', 'cls_token'} 179 | 180 | def forward(self, x, register_blk=-1): 181 | B = x.shape[0] 182 | x = self.patch_embed(x) 183 | 184 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 185 | x = torch.cat((cls_tokens, x), dim=1) 186 | 187 | x = x + self.pos_embed[:,:x.size(1),:] 188 | x = self.pos_drop(x) 189 | 190 | for i,blk in enumerate(self.blocks): 191 | x = 
blk(x, register_blk==i) 192 | x = self.norm(x) 193 | 194 | return x 195 | 196 | @torch.jit.ignore() 197 | def load_pretrained(self, checkpoint_path, prefix=''): 198 | _load_weights(self, checkpoint_path, prefix) 199 | 200 | 201 | @torch.no_grad() 202 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 203 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 204 | """ 205 | import numpy as np 206 | 207 | def _n2p(w, t=True): 208 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 209 | w = w.flatten() 210 | if t: 211 | if w.ndim == 4: 212 | w = w.transpose([3, 2, 0, 1]) 213 | elif w.ndim == 3: 214 | w = w.transpose([2, 0, 1]) 215 | elif w.ndim == 2: 216 | w = w.transpose([1, 0]) 217 | return torch.from_numpy(w) 218 | 219 | w = np.load(checkpoint_path) 220 | if not prefix and 'opt/target/embedding/kernel' in w: 221 | prefix = 'opt/target/' 222 | 223 | if hasattr(model.patch_embed, 'backbone'): 224 | # hybrid 225 | backbone = model.patch_embed.backbone 226 | stem_only = not hasattr(backbone, 'stem') 227 | stem = backbone if stem_only else backbone.stem 228 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 229 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 230 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 231 | if not stem_only: 232 | for i, stage in enumerate(backbone.stages): 233 | for j, block in enumerate(stage.blocks): 234 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 235 | for r in range(3): 236 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 237 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 238 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 239 | if block.downsample is not None: 240 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 241 | 
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 242 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 243 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 244 | else: 245 | embed_conv_w = adapt_input_conv( 246 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 247 | model.patch_embed.proj.weight.copy_(embed_conv_w) 248 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 249 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 250 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 251 | if pos_embed_w.shape != model.pos_embed.shape: 252 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 253 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 254 | model.pos_embed.copy_(pos_embed_w) 255 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 256 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 257 | # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 258 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 259 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 260 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 261 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 262 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 263 | for i, block in enumerate(model.blocks.children()): 264 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 265 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 266 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 267 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 268 | block.attn.qkv.weight.copy_(torch.cat([ 269 | _n2p(w[f'{mha_prefix}{n}/kernel'], 
t=False).flatten(1).T for n in ('query', 'key', 'value')])) 270 | block.attn.qkv.bias.copy_(torch.cat([ 271 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 272 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 273 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 274 | for r in range(2): 275 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 276 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 277 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 278 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 279 | 280 | 281 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 282 | # interpolate position embedding 283 | embedding_size = pos_embed_checkpoint.shape[-1] 284 | num_patches = visual_encoder.patch_embed.num_patches 285 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 286 | # height (== width) for the checkpoint position embedding 287 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 288 | # height (== width) for the new position embedding 289 | new_size = int(num_patches ** 0.5) 290 | 291 | if orig_size!=new_size: 292 | # class_token and dist_token are kept unchanged 293 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 294 | # only the position tokens are interpolated 295 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 296 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 297 | pos_tokens = torch.nn.functional.interpolate( 298 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 299 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 300 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 301 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size 
** 2)) 302 | 303 | return new_pos_embed 304 | else: 305 | return pos_embed_checkpoint -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download the weights in ./checkpoints beforehand for fast inference 3 | wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth 4 | wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth 5 | wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth 6 | """ 7 | 8 | from pathlib import Path 9 | 10 | from PIL import Image 11 | import torch 12 | from torchvision import transforms 13 | from torchvision.transforms.functional import InterpolationMode 14 | import cog 15 | 16 | from models.blip import blip_decoder 17 | from models.blip_vqa import blip_vqa 18 | from models.blip_itm import blip_itm 19 | 20 | 21 | class Predictor(cog.Predictor): 22 | def setup(self): 23 | self.device = "cuda:0" 24 | 25 | self.models = { 26 | 'image_captioning': blip_decoder(pretrained='checkpoints/model*_base_caption.pth', 27 | image_size=384, vit='base'), 28 | 'visual_question_answering': blip_vqa(pretrained='checkpoints/model*_vqa.pth', 29 | image_size=480, vit='base'), 30 | 'image_text_matching': blip_itm(pretrained='checkpoints/model_base_retrieval_coco.pth', 31 | image_size=384, vit='base') 32 | } 33 | 34 | @cog.input( 35 | "image", 36 | type=Path, 37 | help="input image", 38 | ) 39 | @cog.input( 40 | "task", 41 | type=str, 42 | default='image_captioning', 43 | options=['image_captioning', 'visual_question_answering', 'image_text_matching'], 44 | help="Choose a task.", 45 | ) 46 | @cog.input( 47 | "question", 48 | type=str, 49 | default=None, 50 | help="Type question for the input image for visual question answering task.", 51 | ) 52 | @cog.input( 53 | "caption", 54 | type=str, 55 | 
default=None, 56 | help="Type caption for the input image for image text matching task.", 57 | ) 58 | def predict(self, image, task, question, caption): 59 | if task == 'visual_question_answering': 60 | assert question is not None, 'Please type a question for visual question answering task.' 61 | if task == 'image_text_matching': 62 | assert caption is not None, 'Please type a caption for image text matching task.' 63 | 64 | im = load_image(image, image_size=480 if task == 'visual_question_answering' else 384, device=self.device) 65 | model = self.models[task] 66 | model.eval() 67 | model = model.to(self.device) 68 | 69 | if task == 'image_captioning': 70 | with torch.no_grad(): 71 | caption = model.generate(im, sample=False, num_beams=3, max_length=20, min_length=5) 72 | return 'Caption: ' + caption[0] 73 | 74 | if task == 'visual_question_answering': 75 | with torch.no_grad(): 76 | answer = model(im, question, train=False, inference='generate') 77 | return 'Answer: ' + answer[0] 78 | 79 | # image_text_matching 80 | itm_output = model(im, caption, match_head='itm') 81 | itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1] 82 | itc_score = model(im, caption, match_head='itc') 83 | return f'The image and text are matched with a probability of {itm_score.item():.4f}.\n' \ 84 | f'The image feature and text feature have a cosine similarity of {itc_score.item():.4f}.'
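The `image_text_matching` branch of `predict` reports two numbers: a softmax over the two ITM logits, read at index 1, and a cosine similarity from the ITC head. The sketch below reproduces only that arithmetic in plain Python so it can be checked without a checkpoint. The helper names `match_probability` and `cosine_similarity` and the sample values are hypothetical and are not part of the repository. That index 1 is the "match" class is an assumption taken from how `predict` words its output.

```python
import math


def match_probability(logits):
    """Softmax over two-class logits; return the probability at index 1.

    Index 1 is assumed to be the 'match' class, mirroring
    softmax(itm_output, dim=1)[:, 1] in predict().
    """
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    return exps[1] / sum(exps)


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


if __name__ == "__main__":
    itm_logits = [-1.0, 2.0]  # made-up logits for illustration
    print(f"match probability: {match_probability(itm_logits):.4f}")
    print(f"cosine similarity: {cosine_similarity([1.0, 0.0], [1.0, 1.0]):.4f}")
```

The two scores are on different scales: the ITM value is a probability in [0, 1], while the cosine similarity lies in [-1, 1], so they should not be compared directly.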
85 | 86 | 87 | def load_image(image, image_size, device): 88 | raw_image = Image.open(str(image)).convert('RGB') 89 | 90 | w, h = raw_image.size 91 | 92 | transform = transforms.Compose([ 93 | transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), 94 | transforms.ToTensor(), 95 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 96 | ]) 97 | image = transform(raw_image).unsqueeze(0).to(device) 98 | return image 99 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip_pretrain import blip_pretrain 26 | import utils 27 | from utils import warmup_lr_schedule, step_lr_schedule 28 | from data import create_dataset, create_sampler, create_loader 29 | 30 | def train(model, data_loader, optimizer, epoch, device, config): 31 | # train 32 | model.train() 33 | 34 | metric_logger = utils.MetricLogger(delimiter=" ") 35 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 36 | metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 37 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, 
fmt='{value:.4f}')) 38 | metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 39 | 40 | header = 'Train Epoch: [{}]'.format(epoch) 41 | print_freq = 50 42 | 43 | if config['laion_path']: 44 | data_loader.dataset.reload_laion(epoch) 45 | 46 | data_loader.sampler.set_epoch(epoch) 47 | 48 | for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 49 | 50 | if epoch==0: 51 | warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr']) 52 | 53 | optimizer.zero_grad() 54 | 55 | image = image.to(device,non_blocking=True) 56 | 57 | # ramp up alpha in the first 2 epochs 58 | alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader))) 59 | 60 | loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha) 61 | loss = loss_ita + loss_itm + loss_lm 62 | 63 | loss.backward() 64 | optimizer.step() 65 | 66 | metric_logger.update(loss_ita=loss_ita.item()) 67 | metric_logger.update(loss_itm=loss_itm.item()) 68 | metric_logger.update(loss_lm=loss_lm.item()) 69 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 70 | 71 | 72 | # gather the stats from all processes 73 | metric_logger.synchronize_between_processes() 74 | print("Averaged stats:", metric_logger.global_avg()) 75 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 76 | 77 | 78 | def main(args, config): 79 | utils.init_distributed_mode(args) 80 | 81 | device = torch.device(args.device) 82 | 83 | # fix the seed for reproducibility 84 | seed = args.seed + utils.get_rank() 85 | torch.manual_seed(seed) 86 | np.random.seed(seed) 87 | random.seed(seed) 88 | cudnn.benchmark = True 89 | 90 | #### Dataset #### 91 | print("Creating dataset") 92 | datasets = [create_dataset('pretrain', config, min_scale=0.2)] 93 | print('number of training samples: %d'%len(datasets[0])) 94 | 95 | num_tasks = utils.get_world_size() 96 | global_rank = utils.get_rank() 97 
| samplers = create_sampler(datasets, [True], num_tasks, global_rank) 98 | 99 | data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0] 100 | 101 | #### Model #### 102 | print("Creating model") 103 | model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], 104 | vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size']) 105 | 106 | model = model.to(device) 107 | 108 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 109 | 110 | start_epoch = 0 111 | if args.checkpoint: 112 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 113 | state_dict = checkpoint['model'] 114 | model.load_state_dict(state_dict) 115 | 116 | optimizer.load_state_dict(checkpoint['optimizer']) 117 | start_epoch = checkpoint['epoch']+1 118 | print('resume checkpoint from %s'%args.checkpoint) 119 | 120 | model_without_ddp = model 121 | if args.distributed: 122 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 123 | model_without_ddp = model.module 124 | 125 | print("Start training") 126 | start_time = time.time() 127 | for epoch in range(start_epoch, config['max_epoch']): 128 | 129 | step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate']) 130 | 131 | train_stats = train(model, data_loader, optimizer, epoch, device, config) 132 | if utils.is_main_process(): 133 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 134 | 'epoch': epoch, 135 | } 136 | save_obj = { 137 | 'model': model_without_ddp.state_dict(), 138 | 'optimizer': optimizer.state_dict(), 139 | 'config': config, 140 | 'epoch': epoch, 141 | } 142 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) 143 | 144 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 145 | 
f.write(json.dumps(log_stats) + "\n") 146 | 147 | if args.distributed: dist.barrier()  # a barrier needs an initialized process group 148 | 149 | total_time = time.time() - start_time 150 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 151 | print('Training time {}'.format(total_time_str)) 152 | 153 | 154 | if __name__ == '__main__': 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument('--config', default='./configs/pretrain.yaml') 157 | parser.add_argument('--output_dir', default='output/Pretrain') 158 | parser.add_argument('--checkpoint', default='') 159 | parser.add_argument('--evaluate', action='store_true') 160 | parser.add_argument('--device', default='cuda') 161 | parser.add_argument('--seed', default=42, type=int) 162 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 163 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 164 | parser.add_argument('--distributed', default=True, type=bool) 165 | args = parser.parse_args() 166 | 167 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 168 | 169 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 170 | 171 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 172 | 173 | main(args, config) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.4.12 2 | transformers==4.15.0 3 | fairscale==0.4.4 4 | pycocoevalcap 5 | -------------------------------------------------------------------------------- /train_caption.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip import blip_decoder 26 | import utils 27 | from utils import cosine_lr_schedule 28 | from data import create_dataset, create_sampler, create_loader 29 | from data.utils import save_result, coco_caption_eval 30 | 31 | def train(model, data_loader, optimizer, epoch, device): 32 | # train 33 | model.train() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 37 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 38 | header = 'Train Caption Epoch: [{}]'.format(epoch) 39 | print_freq = 50 40 | 41 | for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 42 | image = image.to(device) 43 | 44 | loss = model(image, caption) 45 | 46 | optimizer.zero_grad() 47 | loss.backward() 48 | optimizer.step() 49 | 50 | metric_logger.update(loss=loss.item()) 51 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 52 | 53 | # gather the stats from all processes 54 | metric_logger.synchronize_between_processes() 55 | print("Averaged stats:", metric_logger.global_avg()) 56 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 57 | 58 | 59 | @torch.no_grad() 60 | def evaluate(model, data_loader, device, config): 61 | # evaluate 62 | model.eval() 63 | 64 | 
metric_logger = utils.MetricLogger(delimiter=" ") 65 | header = 'Caption generation:' 66 | print_freq = 10 67 | 68 | result = [] 69 | for image, image_id in metric_logger.log_every(data_loader, print_freq, header): 70 | 71 | image = image.to(device) 72 | 73 | captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'], 74 | min_length=config['min_length']) 75 | 76 | for caption, img_id in zip(captions, image_id): 77 | result.append({"image_id": img_id.item(), "caption": caption}) 78 | 79 | return result 80 | 81 | 82 | def main(args, config): 83 | utils.init_distributed_mode(args) 84 | 85 | device = torch.device(args.device) 86 | 87 | # fix the seed for reproducibility 88 | seed = args.seed + utils.get_rank() 89 | torch.manual_seed(seed) 90 | np.random.seed(seed) 91 | random.seed(seed) 92 | cudnn.benchmark = True 93 | 94 | #### Dataset #### 95 | print("Creating captioning dataset") 96 | train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config) 97 | 98 | if args.distributed: 99 | num_tasks = utils.get_world_size() 100 | global_rank = utils.get_rank() 101 | samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank) 102 | else: 103 | samplers = [None, None, None] 104 | 105 | train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers, 106 | batch_size=[config['batch_size']]*3,num_workers=[4,4,4], 107 | is_trains=[True, False, False], collate_fns=[None,None,None]) 108 | 109 | #### Model #### 110 | print("Creating model") 111 | model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], 112 | vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'], 113 | prompt=config['prompt']) 114 | 115 | model = model.to(device) 116 | 117 | model_without_ddp = model 118 | if args.distributed: 119 | model = 
torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 120 | model_without_ddp = model.module 121 | 122 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 123 | 124 | best = 0 125 | best_epoch = 0 126 | 127 | print("Start training") 128 | start_time = time.time() 129 | for epoch in range(0, config['max_epoch']): 130 | if not args.evaluate: 131 | if args.distributed: 132 | train_loader.sampler.set_epoch(epoch) 133 | 134 | cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) 135 | 136 | train_stats = train(model, train_loader, optimizer, epoch, device) 137 | 138 | val_result = evaluate(model_without_ddp, val_loader, device, config) 139 | val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id') 140 | 141 | test_result = evaluate(model_without_ddp, test_loader, device, config) 142 | test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id') 143 | 144 | if utils.is_main_process(): 145 | coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val') 146 | coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test') 147 | 148 | if args.evaluate: 149 | log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()}, 150 | **{f'test_{k}': v for k, v in coco_test.eval.items()}, 151 | } 152 | with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f: 153 | f.write(json.dumps(log_stats) + "\n") 154 | else: 155 | save_obj = { 156 | 'model': model_without_ddp.state_dict(), 157 | 'optimizer': optimizer.state_dict(), 158 | 'config': config, 159 | 'epoch': epoch, 160 | } 161 | 162 | if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best: 163 | best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] 164 | best_epoch = epoch 165 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) 166 | 167 | 
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 168 | **{f'val_{k}': v for k, v in coco_val.eval.items()}, 169 | **{f'test_{k}': v for k, v in coco_test.eval.items()}, 170 | 'epoch': epoch, 171 | 'best_epoch': best_epoch, 172 | } 173 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 174 | f.write(json.dumps(log_stats) + "\n") 175 | 176 | if args.evaluate: 177 | break 178 | if args.distributed: dist.barrier()  # a barrier needs an initialized process group 179 | 180 | total_time = time.time() - start_time 181 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 182 | print('Training time {}'.format(total_time_str)) 183 | 184 | 185 | if __name__ == '__main__': 186 | parser = argparse.ArgumentParser() 187 | parser.add_argument('--config', default='./configs/caption_coco.yaml') 188 | parser.add_argument('--output_dir', default='output/Caption_coco') 189 | parser.add_argument('--evaluate', action='store_true') 190 | parser.add_argument('--device', default='cuda') 191 | parser.add_argument('--seed', default=42, type=int) 192 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 193 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 194 | parser.add_argument('--distributed', default=True, type=bool) 195 | args = parser.parse_args() 196 | 197 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 198 | 199 | args.result_dir = os.path.join(args.output_dir, 'result') 200 | 201 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 202 | Path(args.result_dir).mkdir(parents=True, exist_ok=True) 203 | 204 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 205 | 206 | main(args, config) -------------------------------------------------------------------------------- /train_nlvr.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | import json 18 | import pickle 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | from torch.utils.data import DataLoader 24 | import torch.backends.cudnn as cudnn 25 | import torch.distributed as dist 26 | 27 | from models.blip_nlvr import blip_nlvr 28 | 29 | import utils 30 | from utils import cosine_lr_schedule, warmup_lr_schedule 31 | from data import create_dataset, create_sampler, create_loader 32 | 33 | def train(model, data_loader, optimizer, epoch, device, config): 34 | # train 35 | model.train() 36 | 37 | metric_logger = utils.MetricLogger(delimiter=" ") 38 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 39 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 40 | 41 | header = 'Train Epoch: [{}]'.format(epoch) 42 | print_freq = 50 43 | step_size = 10 44 | 45 | for i,(image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 46 | 47 | images = torch.cat([image0, image1], dim=0) 48 | images, targets = images.to(device), targets.to(device) 49 | 50 | loss = model(images, text, targets=targets, train=True) 51 | 52 | optimizer.zero_grad() 53 | loss.backward() 54 | optimizer.step() 55 | 56 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 57 | metric_logger.update(loss=loss.item()) 58 | 59 | # gather the stats from all processes 60 | metric_logger.synchronize_between_processes() 61 | print("Averaged stats:", metric_logger.global_avg()) 62 | return {k: "{:.4f}".format(meter.global_avg) for k, meter in 
metric_logger.meters.items()} 63 | 64 | 65 | @torch.no_grad() 66 | def evaluate(model, data_loader, device, config): 67 | # test 68 | model.eval() 69 | 70 | metric_logger = utils.MetricLogger(delimiter=" ") 71 | 72 | header = 'Evaluation:' 73 | print_freq = 50 74 | 75 | for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header): 76 | images = torch.cat([image0, image1], dim=0) 77 | images, targets = images.to(device), targets.to(device) 78 | 79 | prediction = model(images, text, targets=targets, train=False) 80 | 81 | _, pred_class = prediction.max(1) 82 | accuracy = (targets==pred_class).sum() / targets.size(0) 83 | 84 | metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0)) 85 | 86 | # gather the stats from all processes 87 | metric_logger.synchronize_between_processes() 88 | 89 | print("Averaged stats:", metric_logger.global_avg()) 90 | return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 91 | 92 | 93 | 94 | def main(args, config): 95 | utils.init_distributed_mode(args) 96 | 97 | device = torch.device(args.device) 98 | 99 | # fix the seed for reproducibility 100 | seed = args.seed + utils.get_rank() 101 | torch.manual_seed(seed) 102 | np.random.seed(seed) 103 | random.seed(seed) 104 | cudnn.benchmark = True 105 | 106 | #### Dataset #### 107 | print("Creating dataset") 108 | datasets = create_dataset('nlvr', config) 109 | 110 | if args.distributed: 111 | num_tasks = utils.get_world_size() 112 | global_rank = utils.get_rank() 113 | samplers = create_sampler(datasets, [True,False,False], num_tasks, global_rank) 114 | else: 115 | samplers = [None, None, None] 116 | 117 | batch_size=[config['batch_size_train'],config['batch_size_test'],config['batch_size_test']] 118 | train_loader, val_loader, test_loader = create_loader(datasets,samplers,batch_size=batch_size, 119 | num_workers=[4,4,4],is_trains=[True,False,False], 120 | collate_fns=[None,None,None]) 121 | 122 | #### Model #### 
123 | print("Creating model") 124 | model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'], 125 | vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer']) 126 | 127 | model = model.to(device) 128 | 129 | model_without_ddp = model 130 | if args.distributed: 131 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 132 | model_without_ddp = model.module 133 | 134 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 135 | 136 | print("Start training") 137 | start_time = time.time() 138 | best = 0 139 | best_epoch = 0 140 | 141 | for epoch in range(0, config['max_epoch']): 142 | if not args.evaluate: 143 | if args.distributed: 144 | train_loader.sampler.set_epoch(epoch) 145 | 146 | cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) 147 | 148 | train_stats = train(model, train_loader, optimizer, epoch, device, config) 149 | 150 | val_stats = evaluate(model, val_loader, device, config) 151 | test_stats = evaluate(model, test_loader, device, config) 152 | 153 | if utils.is_main_process(): 154 | if args.evaluate: 155 | log_stats = {**{f'val_{k}': v for k, v in val_stats.items()}, 156 | **{f'test_{k}': v for k, v in test_stats.items()}, 157 | } 158 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 159 | f.write(json.dumps(log_stats) + "\n") 160 | 161 | else: 162 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 163 | **{f'val_{k}': v for k, v in val_stats.items()}, 164 | **{f'test_{k}': v for k, v in test_stats.items()}, 165 | 'epoch': epoch, 166 | } 167 | 168 | if float(val_stats['acc'])>best: 169 | save_obj = { 170 | 'model': model_without_ddp.state_dict(), 171 | 'optimizer': optimizer.state_dict(), 172 | 'config': config, 173 | 'epoch': epoch, 174 | } 175 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) 176 | 
best = float(val_stats['acc']) 177 | best_epoch = epoch 178 | 179 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 180 | f.write(json.dumps(log_stats) + "\n") 181 | if args.evaluate: 182 | break 183 | 184 | if args.distributed: dist.barrier()  # a barrier needs an initialized process group 185 | 186 | if utils.is_main_process(): 187 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 188 | f.write("best epoch: %d"%best_epoch) 189 | 190 | total_time = time.time() - start_time 191 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 192 | print('Training time {}'.format(total_time_str)) 193 | 194 | 195 | if __name__ == '__main__': 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument('--config', default='./configs/nlvr.yaml') 198 | parser.add_argument('--output_dir', default='output/NLVR') 199 | parser.add_argument('--evaluate', action='store_true') 200 | parser.add_argument('--device', default='cuda') 201 | parser.add_argument('--seed', default=42, type=int) 202 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 203 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 204 | parser.add_argument('--distributed', default=True, type=bool) 205 | args = parser.parse_args() 206 | 207 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 208 | 209 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 210 | 211 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 212 | 213 | main(args, config) -------------------------------------------------------------------------------- /train_retrieval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip_retrieval import blip_retrieval 26 | import utils 27 | from utils import cosine_lr_schedule 28 | from data import create_dataset, create_sampler, create_loader 29 | 30 | 31 | def train(model, data_loader, optimizer, epoch, device, config): 32 | # train 33 | model.train() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 37 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 38 | metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 39 | header = 'Train Epoch: [{}]'.format(epoch) 40 | print_freq = 50 41 | 42 | for i,(image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 43 | image = image.to(device,non_blocking=True) 44 | idx = idx.to(device,non_blocking=True) 45 | 46 | if epoch>0: 47 | alpha = config['alpha'] 48 | else: 49 | alpha = config['alpha']*min(1,i/len(data_loader)) 50 | 51 | loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx) 52 | loss = loss_ita + loss_itm 53 | 54 | optimizer.zero_grad() 55 | loss.backward() 56 | optimizer.step() 57 | 58 | metric_logger.update(loss_itm=loss_itm.item()) 59 | metric_logger.update(loss_ita=loss_ita.item()) 60 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 61 | 62 | # gather the 
stats from all processes 63 | metric_logger.synchronize_between_processes() 64 | print("Averaged stats:", metric_logger.global_avg()) 65 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 66 | 67 | 68 | @torch.no_grad() 69 | def evaluation(model, data_loader, device, config): 70 | # test 71 | model.eval() 72 | 73 | metric_logger = utils.MetricLogger(delimiter=" ") 74 | header = 'Evaluation:' 75 | 76 | print('Computing features for evaluation...') 77 | start_time = time.time() 78 | 79 | texts = data_loader.dataset.text 80 | num_text = len(texts) 81 | text_bs = 256 82 | text_ids = [] 83 | text_embeds = [] 84 | text_atts = [] 85 | for i in range(0, num_text, text_bs): 86 | text = texts[i: min(num_text, i+text_bs)] 87 | text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device) 88 | text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text') 89 | text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:])) 90 | text_embeds.append(text_embed) 91 | text_ids.append(text_input.input_ids) 92 | text_atts.append(text_input.attention_mask) 93 | 94 | text_embeds = torch.cat(text_embeds,dim=0) 95 | text_ids = torch.cat(text_ids,dim=0) 96 | text_atts = torch.cat(text_atts,dim=0) 97 | text_ids[:,0] = model.tokenizer.enc_token_id 98 | 99 | image_feats = [] 100 | image_embeds = [] 101 | for image, img_id in data_loader: 102 | image = image.to(device) 103 | image_feat = model.visual_encoder(image) 104 | image_embed = model.vision_proj(image_feat[:,0,:]) 105 | image_embed = F.normalize(image_embed,dim=-1) 106 | 107 | image_feats.append(image_feat.cpu()) 108 | image_embeds.append(image_embed) 109 | 110 | image_feats = torch.cat(image_feats,dim=0) 111 | image_embeds = torch.cat(image_embeds,dim=0) 112 | 113 | sims_matrix = image_embeds @ text_embeds.t() 114 | score_matrix_i2t = 
torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device) 115 | 116 | num_tasks = utils.get_world_size() 117 | rank = utils.get_rank() 118 | step = sims_matrix.size(0)//num_tasks + 1 119 | start = rank*step 120 | end = min(sims_matrix.size(0),start+step) 121 | 122 | for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): 123 | topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0) 124 | 125 | encoder_output = image_feats[start+i].repeat(config['k_test'],1,1).to(device) 126 | encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device) 127 | output = model.text_encoder(text_ids[topk_idx], 128 | attention_mask = text_atts[topk_idx], 129 | encoder_hidden_states = encoder_output, 130 | encoder_attention_mask = encoder_att, 131 | return_dict = True, 132 | ) 133 | score = model.itm_head(output.last_hidden_state[:,0,:])[:,1] 134 | score_matrix_i2t[start+i,topk_idx] = score + topk_sim 135 | 136 | sims_matrix = sims_matrix.t() 137 | score_matrix_t2i = torch.full((len(texts),len(data_loader.dataset.image)),-100.0).to(device) 138 | 139 | step = sims_matrix.size(0)//num_tasks + 1 140 | start = rank*step 141 | end = min(sims_matrix.size(0),start+step) 142 | 143 | for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): 144 | 145 | topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0) 146 | encoder_output = image_feats[topk_idx].to(device) 147 | encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device) 148 | output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1), 149 | attention_mask = text_atts[start+i].repeat(config['k_test'],1), 150 | encoder_hidden_states = encoder_output, 151 | encoder_attention_mask = encoder_att, 152 | return_dict = True, 153 | ) 154 | score = model.itm_head(output.last_hidden_state[:,0,:])[:,1] 155 | score_matrix_t2i[start+i,topk_idx] = score + topk_sim 156 | 157 | if args.distributed: 158 | dist.barrier() 159 | 
torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM) 160 | torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM) 161 | 162 | total_time = time.time() - start_time 163 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 164 | print('Evaluation time {}'.format(total_time_str)) 165 | 166 | return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() 167 | 168 | 169 | 170 | @torch.no_grad() 171 | def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt): 172 | 173 | #Images->Text 174 | ranks = np.zeros(scores_i2t.shape[0]) 175 | for index,score in enumerate(scores_i2t): 176 | inds = np.argsort(score)[::-1] 177 | # Score 178 | rank = 1e20 179 | for i in img2txt[index]: 180 | tmp = np.where(inds == i)[0][0] 181 | if tmp < rank: 182 | rank = tmp 183 | ranks[index] = rank 184 | 185 | # Compute metrics 186 | tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 187 | tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 188 | tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 189 | 190 | #Text->Images 191 | ranks = np.zeros(scores_t2i.shape[0]) 192 | 193 | for index,score in enumerate(scores_t2i): 194 | inds = np.argsort(score)[::-1] 195 | ranks[index] = np.where(inds == txt2img[index])[0][0] 196 | 197 | # Compute metrics 198 | ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 199 | ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 200 | ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 201 | 202 | tr_mean = (tr1 + tr5 + tr10) / 3 203 | ir_mean = (ir1 + ir5 + ir10) / 3 204 | r_mean = (tr_mean + ir_mean) / 2 205 | 206 | eval_result = {'txt_r1': tr1, 207 | 'txt_r5': tr5, 208 | 'txt_r10': tr10, 209 | 'txt_r_mean': tr_mean, 210 | 'img_r1': ir1, 211 | 'img_r5': ir5, 212 | 'img_r10': ir10, 213 | 'img_r_mean': ir_mean, 214 | 'r_mean': r_mean} 215 | return eval_result 216 | 217 | 218 | def main(args, config): 219 | utils.init_distributed_mode(args) 220 | 221 | device = 
torch.device(args.device) 222 | 223 | # fix the seed for reproducibility 224 | seed = args.seed + utils.get_rank() 225 | torch.manual_seed(seed) 226 | np.random.seed(seed) 227 | random.seed(seed) 228 | cudnn.benchmark = True 229 | 230 | #### Dataset #### 231 | print("Creating retrieval dataset") 232 | train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s'%config['dataset'], config) 233 | 234 | if args.distributed: 235 | num_tasks = utils.get_world_size() 236 | global_rank = utils.get_rank() 237 | samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None] 238 | else: 239 | samplers = [None, None, None] 240 | 241 | train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers, 242 | batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2, 243 | num_workers=[4,4,4], 244 | is_trains=[True, False, False], 245 | collate_fns=[None,None,None]) 246 | 247 | 248 | #### Model #### 249 | print("Creating model") 250 | model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], 251 | vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'], 252 | queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank']) 253 | 254 | model = model.to(device) 255 | 256 | model_without_ddp = model 257 | if args.distributed: 258 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 259 | model_without_ddp = model.module 260 | 261 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 262 | 263 | best = 0 264 | best_epoch = 0 265 | 266 | print("Start training") 267 | start_time = time.time() 268 | 269 | for epoch in range(0, config['max_epoch']): 270 | if not args.evaluate: 271 | if args.distributed: 272 | train_loader.sampler.set_epoch(epoch) 273 | 274 | cosine_lr_schedule(optimizer, epoch, config['max_epoch'], 
config['init_lr'], config['min_lr']) 275 | 276 | train_stats = train(model, train_loader, optimizer, epoch, device, config) 277 | 278 | score_val_i2t, score_val_t2i = evaluation(model_without_ddp, val_loader, device, config) 279 | score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config) 280 | 281 | if utils.is_main_process(): 282 | 283 | val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt) 284 | print(val_result) 285 | 286 | if val_result['r_mean']>best: 287 | save_obj = { 288 | 'model': model_without_ddp.state_dict(), 289 | 'optimizer': optimizer.state_dict(), 290 | 'config': config, 291 | 'epoch': epoch, 292 | } 293 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) 294 | best = val_result['r_mean'] 295 | best_epoch = epoch 296 | 297 | test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt) 298 | print(test_result) 299 | 300 | if args.evaluate: 301 | log_stats = {**{f'val_{k}': v for k, v in val_result.items()}, 302 | **{f'test_{k}': v for k, v in test_result.items()}, 303 | } 304 | with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f: 305 | f.write(json.dumps(log_stats) + "\n") 306 | else: 307 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 308 | **{f'val_{k}': v for k, v in val_result.items()}, 309 | **{f'test_{k}': v for k, v in test_result.items()}, 310 | 'epoch': epoch, 311 | 'best_epoch': best_epoch, 312 | } 313 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 314 | f.write(json.dumps(log_stats) + "\n") 315 | 316 | if args.evaluate: 317 | break 318 | 319 | if args.distributed: dist.barrier() # barrier needs an initialized process group 320 | torch.cuda.empty_cache() 321 | 322 | total_time = time.time() - start_time 323 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 324 | print('Training time {}'.format(total_time_str)) 325 | 326 | 327 | if __name__ == '__main__': 328 | parser = 
argparse.ArgumentParser() 329 | parser.add_argument('--config', default='./configs/retrieval_flickr.yaml') 330 | parser.add_argument('--output_dir', default='output/Retrieval_flickr') 331 | parser.add_argument('--evaluate', action='store_true') 332 | parser.add_argument('--device', default='cuda') 333 | parser.add_argument('--seed', default=42, type=int) 334 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 335 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 336 | parser.add_argument('--distributed', default=True, type=bool) 337 | args = parser.parse_args() 338 | 339 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 340 | 341 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 342 | 343 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 344 | 345 | main(args, config) -------------------------------------------------------------------------------- /train_vqa.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel_yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.utils.data import DataLoader 22 | import torch.backends.cudnn as cudnn 23 | import torch.distributed as dist 24 | 25 | from models.blip_vqa import blip_vqa 26 | import utils 27 | from utils import cosine_lr_schedule 28 | from data import create_dataset, create_sampler, create_loader 29 | from data.vqa_dataset import vqa_collate_fn 30 | from data.utils import save_result 31 | 32 | 33 | def train(model, data_loader, optimizer, epoch, device): 34 | # train 35 | model.train() 36 | 37 | metric_logger = utils.MetricLogger(delimiter=" ") 38 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 39 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 40 | 41 | header = 'Train Epoch: [{}]'.format(epoch) 42 | print_freq = 50 43 | 44 | for i,(image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 45 | image, weights = image.to(device,non_blocking=True), weights.to(device,non_blocking=True) 46 | 47 | loss = model(image, question, answer, train=True, n=n, weights=weights) 48 | 49 | optimizer.zero_grad() 50 | loss.backward() 51 | optimizer.step() 52 | 53 | metric_logger.update(loss=loss.item()) 54 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 55 | 56 | # gather the stats from all processes 57 | metric_logger.synchronize_between_processes() 58 | print("Averaged stats:", metric_logger.global_avg()) 59 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in 
metric_logger.meters.items()} 60 | 61 | 62 | @torch.no_grad() 63 | def evaluation(model, data_loader, device, config) : 64 | # test 65 | model.eval() 66 | 67 | metric_logger = utils.MetricLogger(delimiter=" ") 68 | header = 'Generate VQA test result:' 69 | print_freq = 50 70 | 71 | result = [] 72 | 73 | if config['inference']=='rank': 74 | answer_list = data_loader.dataset.answer_list 75 | answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device) 76 | answer_candidates.input_ids[:,0] = model.tokenizer.bos_token_id 77 | 78 | for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 79 | image = image.to(device,non_blocking=True) 80 | 81 | if config['inference']=='generate': 82 | answers = model(image, question, train=False, inference='generate') 83 | 84 | for answer, ques_id in zip(answers, question_id): 85 | ques_id = int(ques_id.item()) 86 | result.append({"question_id":ques_id, "answer":answer}) 87 | 88 | elif config['inference']=='rank': 89 | answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test']) 90 | 91 | for ques_id, answer_id in zip(question_id, answer_ids): 92 | result.append({"question_id":int(ques_id.item()), "answer":answer_list[answer_id]}) 93 | 94 | return result 95 | 96 | 97 | def main(args, config): 98 | utils.init_distributed_mode(args) 99 | 100 | device = torch.device(args.device) 101 | 102 | # fix the seed for reproducibility 103 | seed = args.seed + utils.get_rank() 104 | torch.manual_seed(seed) 105 | np.random.seed(seed) 106 | random.seed(seed) 107 | cudnn.benchmark = True 108 | 109 | #### Dataset #### 110 | print("Creating vqa datasets") 111 | datasets = create_dataset('vqa', config) 112 | 113 | if args.distributed: 114 | num_tasks = utils.get_world_size() 115 | global_rank = utils.get_rank() 116 | samplers = create_sampler(datasets, [True, False], num_tasks, global_rank) 117 | else: 118 | samplers 
= [None, None] 119 | 120 | train_loader, test_loader = create_loader(datasets,samplers, 121 | batch_size=[config['batch_size_train'],config['batch_size_test']], 122 | num_workers=[4,4],is_trains=[True, False], 123 | collate_fns=[vqa_collate_fn,None]) 124 | #### Model #### 125 | print("Creating model") 126 | model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'], 127 | vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer']) 128 | 129 | model = model.to(device) 130 | 131 | model_without_ddp = model 132 | if args.distributed: 133 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 134 | model_without_ddp = model.module 135 | 136 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 137 | 138 | best = 0 139 | best_epoch = 0 140 | 141 | print("Start training") 142 | start_time = time.time() 143 | for epoch in range(0, config['max_epoch']): 144 | if not args.evaluate: 145 | if args.distributed: 146 | train_loader.sampler.set_epoch(epoch) 147 | 148 | cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) 149 | 150 | train_stats = train(model, train_loader, optimizer, epoch, device) 151 | 152 | else: 153 | break 154 | 155 | if utils.is_main_process(): 156 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 157 | 'epoch': epoch, 158 | } 159 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 160 | f.write(json.dumps(log_stats) + "\n") 161 | 162 | save_obj = { 163 | 'model': model_without_ddp.state_dict(), 164 | 'optimizer': optimizer.state_dict(), 165 | 'config': config, 166 | 'epoch': epoch, 167 | } 168 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) 169 | 170 | if args.distributed: dist.barrier() # barrier needs an initialized process group 171 | 172 | vqa_result = evaluation(model_without_ddp, test_loader, device, config) 173 | result_file = save_result(vqa_result, 
args.result_dir, 'vqa_result') 174 | 175 | total_time = time.time() - start_time 176 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 177 | print('Training time {}'.format(total_time_str)) 178 | 179 | 180 | 181 | if __name__ == '__main__': 182 | parser = argparse.ArgumentParser() 183 | parser.add_argument('--config', default='./configs/vqa.yaml') 184 | parser.add_argument('--output_dir', default='output/VQA') 185 | parser.add_argument('--evaluate', action='store_true') 186 | parser.add_argument('--device', default='cuda') 187 | parser.add_argument('--seed', default=42, type=int) 188 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 189 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 190 | parser.add_argument('--distributed', default=True, type=bool) 191 | args = parser.parse_args() 192 | 193 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 194 | 195 | args.result_dir = os.path.join(args.output_dir, 'result') 196 | 197 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 198 | Path(args.result_dir).mkdir(parents=True, exist_ok=True) 199 | 200 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 201 | 202 | main(args, config) -------------------------------------------------------------------------------- /transform/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | ## aug functions 6 | def identity_func(img): 7 | return img 8 | 9 | 10 | def autocontrast_func(img, cutoff=0): 11 | ''' 12 | same output as PIL.ImageOps.autocontrast 13 | ''' 14 | n_bins = 256 15 | 16 | def tune_channel(ch): 17 | n = ch.size 18 | cut = cutoff * n // 100 19 | if cut == 0: 20 | high, low = ch.max(), ch.min() 21 | else: 22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 23 | low = np.argwhere(np.cumsum(hist) > cut) 24 
| low = 0 if low.shape[0] == 0 else low[0] 25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 27 | if high <= low: 28 | table = np.arange(n_bins) 29 | else: 30 | scale = (n_bins - 1) / (high - low) 31 | offset = -low * scale 32 | table = np.arange(n_bins) * scale + offset 33 | table[table < 0] = 0 34 | table[table > n_bins - 1] = n_bins - 1 35 | table = table.clip(0, 255).astype(np.uint8) 36 | return table[ch] 37 | 38 | channels = [tune_channel(ch) for ch in cv2.split(img)] 39 | out = cv2.merge(channels) 40 | return out 41 | 42 | 43 | def equalize_func(img): 44 | ''' 45 | same output as PIL.ImageOps.equalize 46 | PIL's implementation is different from cv2.equalizeHist 47 | ''' 48 | n_bins = 256 49 | 50 | def tune_channel(ch): 51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 52 | non_zero_hist = hist[hist != 0].reshape(-1) 53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 54 | if step == 0: return ch 55 | n = np.empty_like(hist) 56 | n[0] = step // 2 57 | n[1:] = hist[:-1] 58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 59 | return table[ch] 60 | 61 | channels = [tune_channel(ch) for ch in cv2.split(img)] 62 | out = cv2.merge(channels) 63 | return out 64 | 65 | 66 | def rotate_func(img, degree, fill=(0, 0, 0)): 67 | ''' 68 | like PIL, rotate by degrees, not radians 69 | ''' 70 | H, W = img.shape[0], img.shape[1] 71 | center = W / 2, H / 2 72 | M = cv2.getRotationMatrix2D(center, degree, 1) 73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 74 | return out 75 | 76 | 77 | def solarize_func(img, thresh=128): 78 | ''' 79 | same output as PIL.ImageOps.solarize 80 | ''' 81 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 82 | table = table.clip(0, 255).astype(np.uint8) 83 | out = table[img] 84 | return out 85 | 86 | 87 | def color_func(img, factor): 88 | ''' 89 | same output as PIL.ImageEnhance.Color 90 | ''' 91 | ## implementation 
according to PIL definition, quite slow 92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 93 | # out = blend(degenerate, img, factor) 94 | # M = ( 95 | # np.eye(3) * factor 96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) 97 | # )[np.newaxis, np.newaxis, :] 98 | M = ( 99 | np.float32([ 100 | [0.886, -0.114, -0.114], 101 | [-0.587, 0.413, -0.587], 102 | [-0.299, -0.299, 0.701]]) * factor 103 | + np.float32([[0.114], [0.587], [0.299]]) 104 | ) 105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 106 | return out 107 | 108 | 109 | def contrast_func(img, factor): 110 | """ 111 | same output as PIL.ImageEnhance.Contrast 112 | """ 113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 114 | table = np.array([( 115 | el - mean) * factor + mean 116 | for el in range(256) 117 | ]).clip(0, 255).astype(np.uint8) 118 | out = table[img] 119 | return out 120 | 121 | 122 | def brightness_func(img, factor): 123 | ''' 124 | same output as PIL.ImageEnhance.Brightness 125 | ''' 126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 127 | out = table[img] 128 | return out 129 | 130 | 131 | def sharpness_func(img, factor): 132 | ''' 133 | The differences between this result and PIL's are all on the 4 boundaries; the center 134 | areas are the same 135 | ''' 136 | kernel = np.ones((3, 3), dtype=np.float32) 137 | kernel[1][1] = 5 138 | kernel /= 13 139 | degenerate = cv2.filter2D(img, -1, kernel) 140 | if factor == 0.0: 141 | out = degenerate 142 | elif factor == 1.0: 143 | out = img 144 | else: 145 | out = img.astype(np.float32) 146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 148 | out = out.astype(np.uint8) 149 | return out 150 | 151 | 152 | def shear_x_func(img, factor, fill=(0, 0, 0)): 153 | H, W = img.shape[0], img.shape[1] 154 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 155 | out 
= cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 156 | return out 157 | 158 | 159 | def translate_x_func(img, offset, fill=(0, 0, 0)): 160 | ''' 161 | same output as PIL.Image.transform 162 | ''' 163 | H, W = img.shape[0], img.shape[1] 164 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 166 | return out 167 | 168 | 169 | def translate_y_func(img, offset, fill=(0, 0, 0)): 170 | ''' 171 | same output as PIL.Image.transform 172 | ''' 173 | H, W = img.shape[0], img.shape[1] 174 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 176 | return out 177 | 178 | 179 | def posterize_func(img, bits): 180 | ''' 181 | same output as PIL.ImageOps.posterize 182 | ''' 183 | out = np.bitwise_and(img, np.uint8((255 << (8 - bits)) & 0xFF)) # mask to 8 bits so the value fits in uint8 184 | return out 185 | 186 | 187 | def shear_y_func(img, factor, fill=(0, 0, 0)): 188 | H, W = img.shape[0], img.shape[1] 189 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 191 | return out 192 | 193 | 194 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 195 | replace = np.array(replace, dtype=np.uint8) 196 | H, W = img.shape[0], img.shape[1] 197 | rh, rw = np.random.random(2) 198 | pad_size = pad_size // 2 199 | ch, cw = int(rh * H), int(rw * W) 200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 202 | out = img.copy() 203 | out[x1:x2, y1:y2, :] = replace 204 | return out 205 | 206 | 207 | ### level to args 208 | def enhance_level_to_args(MAX_LEVEL): 209 | def level_to_args(level): 210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 211 | return level_to_args 212 | 213 | 214 | def shear_level_to_args(MAX_LEVEL, replace_value): 215 | def 
level_to_args(level): 216 | level = (level / MAX_LEVEL) * 0.3 217 | if np.random.random() > 0.5: level = -level 218 | return (level, replace_value) 219 | 220 | return level_to_args 221 | 222 | 223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 224 | def level_to_args(level): 225 | level = (level / MAX_LEVEL) * float(translate_const) 226 | if np.random.random() > 0.5: level = -level 227 | return (level, replace_value) 228 | 229 | return level_to_args 230 | 231 | 232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 233 | def level_to_args(level): 234 | level = int((level / MAX_LEVEL) * cutout_const) 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def solarize_level_to_args(MAX_LEVEL): 241 | def level_to_args(level): 242 | level = int((level / MAX_LEVEL) * 256) 243 | return (level, ) 244 | return level_to_args 245 | 246 | 247 | def none_level_to_args(level): 248 | return () 249 | 250 | 251 | def posterize_level_to_args(MAX_LEVEL): 252 | def level_to_args(level): 253 | level = int((level / MAX_LEVEL) * 4) 254 | return (level, ) 255 | return level_to_args 256 | 257 | 258 | def rotate_level_to_args(MAX_LEVEL, replace_value): 259 | def level_to_args(level): 260 | level = (level / MAX_LEVEL) * 30 261 | if np.random.random() < 0.5: 262 | level = -level 263 | return (level, replace_value) 264 | 265 | return level_to_args 266 | 267 | 268 | func_dict = { 269 | 'Identity': identity_func, 270 | 'AutoContrast': autocontrast_func, 271 | 'Equalize': equalize_func, 272 | 'Rotate': rotate_func, 273 | 'Solarize': solarize_func, 274 | 'Color': color_func, 275 | 'Contrast': contrast_func, 276 | 'Brightness': brightness_func, 277 | 'Sharpness': sharpness_func, 278 | 'ShearX': shear_x_func, 279 | 'TranslateX': translate_x_func, 280 | 'TranslateY': translate_y_func, 281 | 'Posterize': posterize_func, 282 | 'ShearY': shear_y_func, 283 | } 284 | 285 | translate_const = 10 286 | MAX_LEVEL = 10 287 | 
replace_value = (128, 128, 128) 288 | arg_dict = { 289 | 'Identity': none_level_to_args, 290 | 'AutoContrast': none_level_to_args, 291 | 'Equalize': none_level_to_args, 292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 293 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 294 | 'Color': enhance_level_to_args(MAX_LEVEL), 295 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 296 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 299 | 'TranslateX': translate_level_to_args( 300 | translate_const, MAX_LEVEL, replace_value 301 | ), 302 | 'TranslateY': translate_level_to_args( 303 | translate_const, MAX_LEVEL, replace_value 304 | ), 305 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 307 | } 308 | 309 | 310 | class RandomAugment(object): 311 | 312 | def __init__(self, N=2, M=10, isPIL=False, augs=None): 313 | self.N = N 314 | self.M = M 315 | self.isPIL = isPIL 316 | if augs: 317 | self.augs = augs 318 | else: 319 | self.augs = list(arg_dict.keys()) 320 | 321 | def get_random_ops(self): 322 | sampled_ops = np.random.choice(self.augs, self.N) 323 | return [(op, 0.5, self.M) for op in sampled_ops] 324 | 325 | def __call__(self, img): 326 | if self.isPIL: 327 | img = np.array(img) 328 | ops = self.get_random_ops() 329 | for name, prob, level in ops: 330 | if np.random.random() > prob: 331 | continue 332 | args = arg_dict[name](level) 333 | img = func_dict[name](img, *args) 334 | return img 335 | 336 | 337 | if __name__ == '__main__': 338 | a = RandomAugment() 339 | img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) # lookup-table ops index with img, so it must be uint8 340 | a(img) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 3 | """Decay the 
learning rate""" 4 | lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr 5 | for param_group in optimizer.param_groups: 6 | param_group['lr'] = lr 7 | 8 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 9 | """Warmup the learning rate""" 10 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step) 11 | for param_group in optimizer.param_groups: 12 | param_group['lr'] = lr 13 | 14 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 15 | """Decay the learning rate""" 16 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 17 | for param_group in optimizer.param_groups: 18 | param_group['lr'] = lr 19 | 20 | import numpy as np 21 | import io 22 | import os 23 | import time 24 | from collections import defaultdict, deque 25 | import datetime 26 | 27 | import torch 28 | import torch.distributed as dist 29 | 30 | class SmoothedValue(object): 31 | """Track a series of values and provide access to smoothed values over a 32 | window or the global series average. 33 | """ 34 | 35 | def __init__(self, window_size=20, fmt=None): 36 | if fmt is None: 37 | fmt = "{median:.4f} ({global_avg:.4f})" 38 | self.deque = deque(maxlen=window_size) 39 | self.total = 0.0 40 | self.count = 0 41 | self.fmt = fmt 42 | 43 | def update(self, value, n=1): 44 | self.deque.append(value) 45 | self.count += n 46 | self.total += value * n 47 | 48 | def synchronize_between_processes(self): 49 | """ 50 | Warning: does not synchronize the deque! 
51 | """ 52 | if not is_dist_avail_and_initialized(): 53 | return 54 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 55 | dist.barrier() 56 | dist.all_reduce(t) 57 | t = t.tolist() 58 | self.count = int(t[0]) 59 | self.total = t[1] 60 | 61 | @property 62 | def median(self): 63 | d = torch.tensor(list(self.deque)) 64 | return d.median().item() 65 | 66 | @property 67 | def avg(self): 68 | d = torch.tensor(list(self.deque), dtype=torch.float32) 69 | return d.mean().item() 70 | 71 | @property 72 | def global_avg(self): 73 | return self.total / self.count 74 | 75 | @property 76 | def max(self): 77 | return max(self.deque) 78 | 79 | @property 80 | def value(self): 81 | return self.deque[-1] 82 | 83 | def __str__(self): 84 | return self.fmt.format( 85 | median=self.median, 86 | avg=self.avg, 87 | global_avg=self.global_avg, 88 | max=self.max, 89 | value=self.value) 90 | 91 | 92 | class MetricLogger(object): 93 | def __init__(self, delimiter="\t"): 94 | self.meters = defaultdict(SmoothedValue) 95 | self.delimiter = delimiter 96 | 97 | def update(self, **kwargs): 98 | for k, v in kwargs.items(): 99 | if isinstance(v, torch.Tensor): 100 | v = v.item() 101 | assert isinstance(v, (float, int)) 102 | self.meters[k].update(v) 103 | 104 | def __getattr__(self, attr): 105 | if attr in self.meters: 106 | return self.meters[attr] 107 | if attr in self.__dict__: 108 | return self.__dict__[attr] 109 | raise AttributeError("'{}' object has no attribute '{}'".format( 110 | type(self).__name__, attr)) 111 | 112 | def __str__(self): 113 | loss_str = [] 114 | for name, meter in self.meters.items(): 115 | loss_str.append( 116 | "{}: {}".format(name, str(meter)) 117 | ) 118 | return self.delimiter.join(loss_str) 119 | 120 | def global_avg(self): 121 | loss_str = [] 122 | for name, meter in self.meters.items(): 123 | loss_str.append( 124 | "{}: {:.4f}".format(name, meter.global_avg) 125 | ) 126 | return self.delimiter.join(loss_str) 127 | 128 | def 
synchronize_between_processes(self): 129 | for meter in self.meters.values(): 130 | meter.synchronize_between_processes() 131 | 132 | def add_meter(self, name, meter): 133 | self.meters[name] = meter 134 | 135 | def log_every(self, iterable, print_freq, header=None): 136 | i = 0 137 | if not header: 138 | header = '' 139 | start_time = time.time() 140 | end = time.time() 141 | iter_time = SmoothedValue(fmt='{avg:.4f}') 142 | data_time = SmoothedValue(fmt='{avg:.4f}') 143 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 144 | log_msg = [ 145 | header, 146 | '[{0' + space_fmt + '}/{1}]', 147 | 'eta: {eta}', 148 | '{meters}', 149 | 'time: {time}', 150 | 'data: {data}' 151 | ] 152 | if torch.cuda.is_available(): 153 | log_msg.append('max mem: {memory:.0f}') 154 | log_msg = self.delimiter.join(log_msg) 155 | MB = 1024.0 * 1024.0 156 | for obj in iterable: 157 | data_time.update(time.time() - end) 158 | yield obj 159 | iter_time.update(time.time() - end) 160 | if i % print_freq == 0 or i == len(iterable) - 1: 161 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 162 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 163 | if torch.cuda.is_available(): 164 | print(log_msg.format( 165 | i, len(iterable), eta=eta_string, 166 | meters=str(self), 167 | time=str(iter_time), data=str(data_time), 168 | memory=torch.cuda.max_memory_allocated() / MB)) 169 | else: 170 | print(log_msg.format( 171 | i, len(iterable), eta=eta_string, 172 | meters=str(self), 173 | time=str(iter_time), data=str(data_time))) 174 | i += 1 175 | end = time.time() 176 | total_time = time.time() - start_time 177 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 178 | print('{} Total time: {} ({:.4f} s / it)'.format( 179 | header, total_time_str, total_time / len(iterable))) 180 | 181 | 182 | class AttrDict(dict): 183 | def __init__(self, *args, **kwargs): 184 | super(AttrDict, self).__init__(*args, **kwargs) 185 | self.__dict__ = self 186 | 187 | 188 | def 
compute_acc(logits, label, reduction='mean'): 189 | ret = (torch.argmax(logits, dim=1) == label).float() 190 | if reduction == 'none': 191 | return ret.detach() 192 | elif reduction == 'mean': 193 | return ret.mean().item() 194 | 195 | def compute_n_params(model, return_str=True): 196 | tot = 0 197 | for p in model.parameters(): 198 | w = 1 199 | for x in p.shape: 200 | w *= x 201 | tot += w 202 | if return_str: 203 | if tot >= 1e6: 204 | return '{:.1f}M'.format(tot / 1e6) 205 | else: 206 | return '{:.1f}K'.format(tot / 1e3) 207 | else: 208 | return tot 209 | 210 | def setup_for_distributed(is_master): 211 | """ 212 | This function disables printing when not in master process 213 | """ 214 | import builtins as __builtin__ 215 | builtin_print = __builtin__.print 216 | 217 | def print(*args, **kwargs): 218 | force = kwargs.pop('force', False) 219 | if is_master or force: 220 | builtin_print(*args, **kwargs) 221 | 222 | __builtin__.print = print 223 | 224 | 225 | def is_dist_avail_and_initialized(): 226 | if not dist.is_available(): 227 | return False 228 | if not dist.is_initialized(): 229 | return False 230 | return True 231 | 232 | 233 | def get_world_size(): 234 | if not is_dist_avail_and_initialized(): 235 | return 1 236 | return dist.get_world_size() 237 | 238 | 239 | def get_rank(): 240 | if not is_dist_avail_and_initialized(): 241 | return 0 242 | return dist.get_rank() 243 | 244 | 245 | def is_main_process(): 246 | return get_rank() == 0 247 | 248 | 249 | def save_on_master(*args, **kwargs): 250 | if is_main_process(): 251 | torch.save(*args, **kwargs) 252 | 253 | 254 | def init_distributed_mode(args): 255 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 256 | args.rank = int(os.environ["RANK"]) 257 | args.world_size = int(os.environ['WORLD_SIZE']) 258 | args.gpu = int(os.environ['LOCAL_RANK']) 259 | elif 'SLURM_PROCID' in os.environ: 260 | args.rank = int(os.environ['SLURM_PROCID']) 261 | args.gpu = args.rank % torch.cuda.device_count() 262 | else: 
263 | print('Not using distributed mode') 264 | args.distributed = False 265 | return 266 | 267 | args.distributed = True 268 | 269 | torch.cuda.set_device(args.gpu) 270 | args.dist_backend = 'nccl' 271 | print('| distributed init (rank {}, world {}): {}'.format( 272 | args.rank, args.world_size, args.dist_url), flush=True) 273 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 274 | world_size=args.world_size, rank=args.rank) 275 | torch.distributed.barrier() 276 | setup_for_distributed(args.rank == 0) 277 | 278 | --------------------------------------------------------------------------------
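The schedule helpers at the top of `utils.py` only assign `param_group['lr']`, so their behaviour can be checked without PyTorch. The following is a minimal sketch, not part of the repository: `FakeOptimizer` and `lr_trace` are hypothetical names introduced here, the formula is copied from `cosine_lr_schedule` so the block is self-contained, and the values passed for `init_lr`, `min_lr` and `max_epoch` are illustrative rather than taken from the configs.

```python
import math


def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Same formula as utils.cosine_lr_schedule, copied so the sketch stands alone."""
    lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


class FakeOptimizer:
    """Hypothetical stand-in that only mimics the `param_groups` attribute of a torch optimizer."""

    def __init__(self, lr):
        self.param_groups = [{'lr': lr}]


def lr_trace(init_lr, min_lr, max_epoch):
    """Return the learning rate assigned at the start of each epoch."""
    opt = FakeOptimizer(init_lr)
    trace = []
    for epoch in range(max_epoch):  # same range as the training loops: 0 .. max_epoch - 1
        cosine_lr_schedule(opt, epoch, max_epoch, init_lr, min_lr)
        trace.append(opt.param_groups[0]['lr'])
    return trace


if __name__ == '__main__':
    # Illustrative values only.
    for epoch, lr in enumerate(lr_trace(init_lr=1e-5, min_lr=0.0, max_epoch=6)):
        print(epoch, lr)
```

Because the training loops run `epoch` from 0 to `max_epoch - 1`, the rate starts at `init_lr` and decreases every epoch, but the last epoch still trains slightly above `min_lr`.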