├── .DS_Store ├── imgs ├── 983.jpg ├── main.jpg ├── block_mask.png └── motivation.jpg ├── requirements.txt ├── src ├── blip_src │ ├── multiple_scripts │ │ ├── pretrain.sh │ │ ├── multiple_exp_all_single_8u_ft.sh │ │ ├── ft │ │ │ ├── exp_4.sh │ │ │ ├── exp_3.sh │ │ │ ├── exp_2.sh │ │ │ ├── exp_5.sh │ │ │ └── exp_1.sh │ │ └── exp_all_ft_single.sh │ ├── configs │ │ ├── nocaps.yaml │ │ ├── nlvr.yaml │ │ ├── pretrain_concated_pred_4M.yaml │ │ ├── bert_config.json │ │ ├── med_config.json │ │ ├── vqa.yaml │ │ ├── caption_coco.yaml │ │ ├── retrieval_coco.yaml │ │ └── retrieval_flickr.yaml │ ├── move_pretrained_weights.py │ ├── data │ │ ├── nlvr_dataset.py │ │ ├── flickr30k_dataset.py │ │ ├── vqa_dataset.py │ │ ├── pretrain_dataset_concated_pred_tsv.py │ │ ├── coco_karpathy_dataset.py │ │ ├── utils.py │ │ ├── init_data_concated_pred_tsv.py │ │ ├── init_data_concated_pred_refined.py │ │ ├── __init__.py │ │ └── pretrain_dataset_concated_pred_refined.py │ ├── models │ │ ├── blip_itm.py │ │ ├── blip_nlvr.py │ │ ├── blip_vqa.py │ │ ├── blip.py │ │ ├── blip_retrieval.py │ │ └── vit.py │ ├── eval_nocaps.py │ ├── pretrain_concated_pred_refined.py │ ├── pretrain_concated_pred_tsv.py │ ├── train_vqa.py │ ├── train_nlvr.py │ └── utils.py ├── data_preprocess │ ├── download_cc3m_predictions.sh │ ├── bbox_visualization.py │ └── generate_sample_with_bbox_and_classes.py └── vilt_src │ └── config.py ├── INSTALL.md ├── MODEL_ZOO.md ├── DATASET.md ├── GETTING_STARTED.md ├── README.md └── LICENSE /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ptp/HEAD/.DS_Store -------------------------------------------------------------------------------- /imgs/983.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ptp/HEAD/imgs/983.jpg -------------------------------------------------------------------------------- /imgs/main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ptp/HEAD/imgs/main.jpg -------------------------------------------------------------------------------- /imgs/block_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ptp/HEAD/imgs/block_mask.png -------------------------------------------------------------------------------- /imgs/motivation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ptp/HEAD/imgs/motivation.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | ruamel_yaml 3 | pycocoevalcap 4 | transformers 5 | timm 6 | tabulate 7 | numpy -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | python move_pretrained_weights.py; 2 | 3 | # 4M \ 4 | python -m torch.distributed.launch --nproc_per_node=8 pretrain_concated_pred_tsv.py \ 5 | --config ./configs/pretrain_concated_pred_4M.yaml --output_dir output/Pretrain_concated_pred_4M 6 | 7 | echo "output dir is: output/Pretrain_concated_pred_4M" -------------------------------------------------------------------------------- /src/blip_src/configs/nocaps.yaml: 
-------------------------------------------------------------------------------- 1 | image_root: '/dataset/nocaps/' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'output/Pretrain_concated_pred_4M/checkpoint_19.pth' 6 | vit: 'base' 7 | batch_size: 32 8 | 9 | image_size: 384 10 | 11 | max_length: 20 12 | min_length: 5 13 | num_beams: 3 14 | prompt: 'a picture of ' -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Install PyTorch and the other dependencies for PTP 2 | 3 | The code has been test on '1.9.0+cu102' and python3.8. 4 | 5 | 6 | 7 | ```bash 8 | conda create -n ptp python==3.8 9 | conda activate ptp 10 | # CUDA 10.2 11 | conda install pytorch==1.9.1 torchvision==0.10.1 torchaudio==0.9.1 cudatoolkit=10.2 -c pytorch 12 | cd [PATH_TO_PTP] 13 | pip install -r requirements.txt 14 | ``` 15 | -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/multiple_exp_all_single_8u_ft.sh: -------------------------------------------------------------------------------- 1 | 2 | declare -a PTMethodArray=( "Pretrain_concated_pred_4M" ) 3 | 4 | for pt_method in "${PTMethodArray[@]}" 5 | do 6 | echo "==== start evaluate model $pt_method ====" 7 | 8 | echo "==== utilize pretrained model output/$pt_method/checkpoint_19.pth ====" 9 | 10 | bash ./multiple_scripts/exp_all_ft_single.sh $pt_method; 11 | done -------------------------------------------------------------------------------- /src/blip_src/configs/nlvr.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/dataset/NLVR2' 2 | ann_root: 'annotation' 3 | 4 | pretrained: 'output/Pretrain_concated_pred_4M/checkpoint_19.pth' 5 | 6 | #size of vit model; base or large 7 | vit: 'base' 8 | batch_size_train: 16 9 | batch_size_test: 64 10 | vit_grad_ckpt: False 11 | vit_ckpt_layer: 0 12 | max_epoch: 15 13 | 14 | image_size: 384 15 | 16 | # optimizer 17 | weight_decay: 0.05 18 | init_lr: 3e-5 19 | min_lr: 0 20 | 21 | -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/ft/exp_4.sh: -------------------------------------------------------------------------------- 1 | # NLVR2 2 | 3 | function rand(){ 4 | min=$1 5 | max=$(($2-$min+1)) 6 | num=$(date +%s%N) 7 | echo $(($num%$max+$min)) 8 | } 9 | 10 | echo '(NLVR:) load pretrained model from: '$3; 11 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/nlvr.yaml; 12 | 13 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_nlvr.py \ 14 | --config ./configs/nlvr.yaml \ 15 | --output_dir output/NLVR_$2 # --evaluate -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/ft/exp_3.sh: -------------------------------------------------------------------------------- 1 | # vqa 2 | 3 | function rand(){ 4 | min=$1 5 | max=$(($2-$min+1)) 6 | num=$(date +%s%N) 7 | echo $(($num%$max+$min)) 8 | } 9 | 10 | 11 | echo '(vqa:) load pretrained model from: '$3; 12 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/vqa.yaml; 13 | 14 | 15 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) \ 16 | train_vqa.py --config ./configs/vqa.yaml \ 17 | --output_dir output/vqa_v2_$2 # --evaluate -------------------------------------------------------------------------------- 
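Note on the fine-tuning scripts above: each `ft/exp_*.sh` script patches the `pretrained:` field of the corresponding YAML config with `sed -i` before launching distributed training. Below is a minimal, hypothetical Python alternative (not part of the repo; the file name `set_pretrained.py` is made up) that does the same thing with `ruamel_yaml`, which is already listed in requirements.txt:

```python
# set_pretrained.py -- hypothetical helper mirroring the
# `sed -i "/^\(pretrained: \).*/s//\1'$3'/"` calls in multiple_scripts/ft/exp_*.sh:
# it rewrites the `pretrained:` entry of a task config in place.
import sys
import ruamel.yaml


def set_pretrained(config_path: str, checkpoint_path: str) -> None:
    yaml = ruamel.yaml.YAML()  # round-trip mode keeps comments and key order
    with open(config_path) as f:
        config = yaml.load(f)
    config['pretrained'] = checkpoint_path
    with open(config_path, 'w') as f:
        yaml.dump(config, f)


if __name__ == '__main__':
    # usage: python set_pretrained.py ./configs/vqa.yaml output/Pretrain_concated_pred_4M/checkpoint_19.pth
    set_pretrained(sys.argv[1], sys.argv[2])
```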
/src/blip_src/configs/pretrain_concated_pred_4M.yaml: -------------------------------------------------------------------------------- 1 | train_file: [ 2 | 'metadata/4M_corpus.tsv', 3 | ] 4 | laion_path: '' 5 | 6 | # size of vit model; base or large 7 | vit: 'base' 8 | vit_grad_ckpt: False 9 | vit_ckpt_layer: 0 10 | 11 | image_size: 224 12 | batch_size: 75 # 75 13 | 14 | queue_size: 57600 15 | alpha: 0.4 16 | 17 | # optimizer 18 | weight_decay: 0.05 19 | init_lr: 3e-4 20 | min_lr: 1e-6 21 | warmup_lr: 1e-6 22 | lr_decay_rate: 0.9 23 | max_epoch: 20 24 | warmup_steps: 3000 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /src/blip_src/configs/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /src/blip_src/configs/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /src/blip_src/configs/vqa.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/dataset/coco2014/COCO2014' #followed by train2014/ 2 | vg_root: '/dataset/VisualGenome/' #followed by image/ 3 | train_files: ['vqa_train','vqa_val','vg_qa'] 4 | ann_root: 'annotation' 5 | 6 | # set pretrained as a file path or an url 7 | 8 | pretrained: 'output/Pretrain_concated_pred_4M/checkpoint_19.pth' 9 | # size of vit model; base or large 10 | vit: 'base' 11 | batch_size_train: 12 # 16 12 | batch_size_test: 24 # 32 13 | vit_grad_ckpt: False 14 | vit_ckpt_layer: 0 15 | init_lr: 2e-5 16 | 17 | image_size: 480 18 | 19 | k_test: 128 20 | inference: 'rank' 21 | 22 | # optimizer 23 | weight_decay: 0.05 24 | min_lr: 0 25 | max_epoch: 10 -------------------------------------------------------------------------------- /src/blip_src/configs/caption_coco.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/dataset/coco2014/COCO2014' 2 | ann_root: 'annotation' 3 | coco_gt_root: 'annotation/coco_gt' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'output/Pretrain_concated_pred_4M/checkpoint_19.pth' 7 | 8 | # size of vit model; base or large 9 | vit: 'base' 10 | vit_grad_ckpt: False 11 | vit_ckpt_layer: 0 12 | batch_size: 32 # 32 13 | init_lr: 1e-5 14 | 15 | # vit: 'large' 
16 | # vit_grad_ckpt: True 17 | # vit_ckpt_layer: 5 18 | # batch_size: 16 19 | # init_lr: 2e-6 20 | 21 | image_size: 384 22 | 23 | # generation configs 24 | max_length: 20 25 | min_length: 5 26 | num_beams: 3 27 | prompt: 'a picture of ' 28 | 29 | # optimizer 30 | weight_decay: 0.05 31 | min_lr: 0 32 | max_epoch: 5 33 | 34 | -------------------------------------------------------------------------------- /src/blip_src/configs/retrieval_coco.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/dataset/coco2014/COCO2014' 2 | ann_root: 'annotation' 3 | dataset: 'coco' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'output/Pretrain_concated_pred_4M/checkpoint_19.pth' 7 | # size of vit model; base or large 8 | 9 | vit: 'base' 10 | batch_size_train: 24 # 32 11 | batch_size_test: 48 # 64 12 | vit_grad_ckpt: True 13 | vit_ckpt_layer: 4 14 | init_lr: 1e-5 15 | 16 | # vit: 'large' 17 | # batch_size_train: 16 18 | # batch_size_test: 32 19 | # vit_grad_ckpt: True 20 | # vit_ckpt_layer: 12 21 | # init_lr: 5e-6 22 | 23 | image_size: 384 24 | queue_size: 57600 25 | alpha: 0.4 26 | k_test: 256 27 | negative_all_rank: True 28 | 29 | # optimizer 30 | weight_decay: 0.05 31 | min_lr: 0 32 | max_epoch: 6 33 | 34 | -------------------------------------------------------------------------------- /src/blip_src/configs/retrieval_flickr.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/dataset/Flickr30k/' 2 | ann_root: 'annotation' 3 | dataset: 'flickr' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'output/Pretrain_concated_pred_4M/checkpoint_19.pth' 7 | # size of vit model; base or large 8 | 9 | vit: 'base' 10 | batch_size_train: 24 # 32 11 | batch_size_test: 48 # 64 12 | vit_grad_ckpt: True 13 | vit_ckpt_layer: 4 14 | init_lr: 1e-5 15 | 16 | # vit: 'large' 17 | # batch_size_train: 16 18 | # batch_size_test: 32 19 | # vit_grad_ckpt: True 20 | # vit_ckpt_layer: 10 21 | # init_lr: 5e-6 22 | 23 | image_size: 384 24 | queue_size: 57600 25 | alpha: 0.4 26 | k_test: 128 27 | negative_all_rank: False 28 | 29 | # optimizer 30 | weight_decay: 0.05 31 | min_lr: 0 32 | max_epoch: 6 33 | 34 | -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/exp_all_ft_single.sh: -------------------------------------------------------------------------------- 1 | python move_pretrained_weights.py; 2 | 3 | gpu_num=8; 4 | time=$(date "+%Y-%m-%d-%H:%M:%S"); 5 | suffix=$1${time}; # the suffix to distingush different experiment, e.g. 
$1='generation_mix' 6 | 7 | PRETRAINED_MODEL="output\/$1\/checkpoint_19.pth" 8 | 9 | echo "${suffix}"; 10 | 11 | 12 | bash multiple_scripts/ft/exp_2.sh $gpu_num $suffix $PRETRAINED_MODEL; # captioning, ~1h 13 | 14 | bash multiple_scripts/ft/exp_5.sh $gpu_num $suffix $PRETRAINED_MODEL; # flickr30 retrieval, ~1h 15 | 16 | bash multiple_scripts/ft/exp_4.sh $gpu_num $suffix $PRETRAINED_MODEL; # NLVR2, ~2h 17 | 18 | bash multiple_scripts/ft/exp_1.sh $gpu_num $suffix $PRETRAINED_MODEL; # coco retrieval, ~12h 19 | 20 | bash multiple_scripts/ft/exp_3.sh $gpu_num $suffix $PRETRAINED_MODEL; # vqa, very slow, ~35 h 21 | -------------------------------------------------------------------------------- /src/data_preprocess/download_cc3m_predictions.sh: -------------------------------------------------------------------------------- 1 | for VARIABLE in {0..11..1} 2 | do 3 | mkdir CC3M/$VARIABLE 4 | ./azcopy copy https://biglmdiag.blob.core.windows.net/vinvl/image_features/googlecc_X152C4_frcnnbig2_exp168model_0060000model.roi_heads.nm_filter_2_model.roi_heads.score_thresh_0.2/model_0060000/$VARIABLE/predictions.tsv \ 5 | CC3M/$VARIABLE --recursive 6 | ./azcopy copy https://biglmdiag.blob.core.windows.net/vinvl/image_features/googlecc_X152C4_frcnnbig2_exp168model_0060000model.roi_heads.nm_filter_2_model.roi_heads.score_thresh_0.2/model_0060000/$VARIABLE/predictions.lineidx \ 7 | CC3M/$VARIABLE --recursive 8 | ./azcopy copy https://biglmdiag.blob.core.windows.net/vinvl/image_features/googlecc_X152C4_frcnnbig2_exp168model_0060000model.roi_heads.nm_filter_2_model.roi_heads.score_thresh_0.2/model_0060000/$VARIABLE/annotations \ 9 | CC3M/$VARIABLE --recursive 10 | done -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/ft/exp_2.sh: -------------------------------------------------------------------------------- 1 | # image captioning 2 | 3 | function rand(){ 4 | min=$1 5 | max=$(($2-$min+1)) 6 | num=$(date +%s%N) 7 | echo $(($num%$max+$min)) 8 | } 9 | 10 | 11 | echo '(coco captioning:) load pretrained model from: '$3; 12 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/caption_coco.yaml; 13 | 14 | # step1: zero-shot coco captioning 15 | echo 'step1: zero-shot coco captioning' 16 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_caption.py \ 17 | --config ./configs/caption_coco.yaml \ 18 | --output_dir output/captioning_coco_$2 --evaluate 19 | 20 | # step2: fine-tune 21 | 22 | echo 'step2: fine-tune coco captioning' 23 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_caption.py \ 24 | --config ./configs/caption_coco.yaml \ 25 | --output_dir output/captioning_coco_$2 # --evaluate -------------------------------------------------------------------------------- /src/blip_src/move_pretrained_weights.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | print('move pretrained weights...') 17 | try: 18 | # a100 machines 19 | if not os.path.exists('/root/.cache/torch/hub/checkpoints/'): 20 | os.makedirs('/root/.cache/torch/hub/checkpoints/') 21 | os.system( 22 | 'cp pretrained_models/*.pth /root/.cache/torch/hub/checkpoints/.') 23 | print('move finished...') 24 | except Exception as e: 25 | print(e) -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/ft/exp_5.sh: -------------------------------------------------------------------------------- 1 | # image to text retrieval 2 | 3 | function rand(){ 4 | min=$1 5 | max=$(($2-$min+1)) 6 | num=$(date +%s%N) 7 | echo $(($num%$max+$min)) 8 | } 9 | 10 | # ===================== step1: zero-shot evaluation================ 11 | echo '(step1: zero-shot f30k retrieval:) load pretrained model from: '$3; 12 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/retrieval_flickr.yaml; 13 | 14 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_retrieval.py \ 15 | --config ./configs/retrieval_flickr.yaml \ 16 | --output_dir output/retrieval_flickr_$2 \ 17 | --evaluate 18 | 19 | 20 | # ===================== step2: ft and evaluate ================ 21 | echo '(step2: fine-tune f30k retrieval:) load pretrained model from: '$3; 22 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/retrieval_flickr.yaml; 23 | 24 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_retrieval.py \ 25 | --config ./configs/retrieval_flickr.yaml \ 26 | --output_dir output/retrieval_flickr_$2 -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/ft/exp_1.sh: -------------------------------------------------------------------------------- 1 | # image to text retrieval 2 | 3 | function rand(){ 4 | min=$1 5 | max=$(($2-$min+1)) 6 | num=$(date +%s%N) 7 | echo $(($num%$max+$min)) 8 | } 9 | 10 | 11 | # ============================== step1: zero-shot retrieval evaluation ========= 12 | # evaluate on test 13 | echo '(coco step1: zero-shot evaluation) load pretrained model from: '$3; 14 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/retrieval_coco.yaml; 15 | 16 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_retrieval.py \ 17 | --config ./configs/retrieval_coco.yaml \ 18 | --output_dir output/retrieval_coco_$2 \ 19 | --evaluate 20 | 21 | # print val than test set 22 | 23 | # ================= step2: train and evaluate val set ================= 24 | echo '(coco step2: train on retrieval) load pretrained model from: '$3; 25 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/retrieval_coco.yaml; 26 | 27 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_retrieval.py \ 28 | --config ./configs/retrieval_coco.yaml \ 29 | --output_dir output/retrieval_coco_$2 30 | 31 | # # ===========================step3: evaluate on val/test split ========= 32 | # # evaluate on test 33 | 34 | # TRAINED_MODEL="output\/retrieval_coco_${2}\/checkpoint_best.pth" 35 | # echo '(coco step3: test/val eval) load trained retrieval model from: '${TRAINED_MODEL}; 36 | 37 | # sed -i "/^\(pretrained: \).*/s//\1'$TRAINED_MODEL'/" ./configs/retrieval_coco.yaml; 38 | 39 | 40 | # python -m torch.distributed.launch --nproc_per_node=$1 
--master_port=$(rand 2000 4000) train_retrieval.py \ 41 | # --config ./configs/retrieval_coco.yaml \ 42 | # --output_dir output/retrieval_coco_$2 \ 43 | # --evaluate -------------------------------------------------------------------------------- /src/data_preprocess/bbox_visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import pandas as pd 17 | import cv2 18 | import numpy as np 19 | from PIL import Image 20 | import ast 21 | import json 22 | 23 | ann = pd.read_csv('/vg_train_val_success_compressed.tsv', sep='\t', header=None) 24 | # 364, 190 25 | 26 | sample = ann.iloc[-30] 27 | 28 | im = Image.open('VG/VG_100K_2/' + sample[0].split('/')[-1]) 29 | 30 | w, h = im.size 31 | print(w, h) 32 | 33 | bboxs = json.loads(sample[2]) 34 | classes = ast.literal_eval(sample[3]) 35 | img = np.asarray(im) 36 | max_char_per_line = 30 37 | y0, dy = 10, 20 38 | for i in range(len(bboxs)): 39 | cv2.rectangle(img, (bboxs[i][0], bboxs[i][1]), (bboxs[i][0]+bboxs[i][2], bboxs[i][1]+bboxs[i][3]), (0, 255, 0), 2) 40 | text_img = cv2.putText(img, classes[i], (bboxs[i][0], bboxs[i][1]), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0)) 41 | bboxs[i][2] = bboxs[i][0] + bboxs[i][2] 42 | bboxs[i][3] = bboxs[i][1] + bboxs[i][3] 43 | w_1 = min(int((bboxs[i][0]/2 + bboxs[i][2]/2)/w * 3), 2) 44 | h_1 = min(int((bboxs[i][1]/2 + bboxs[i][3]/2)/h * 3), 2) 45 | # print(w_1, h_1, w, h) 46 | block = 3*h_1 + w_1 47 | prompt_text = '. 
The block ' + str(block) + ' has a ' + classes[i] + ' .'; 48 | print(prompt_text) 49 | 50 | cv2.imwrite('vg_example_1.jpg', img) 51 | -------------------------------------------------------------------------------- /src/blip_src/data/nlvr_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets.utils import download_url 7 | 8 | from PIL import Image 9 | 10 | from data.utils import pre_caption 11 | 12 | class nlvr_dataset(Dataset): 13 | def __init__(self, transform, image_root, ann_root, split): 14 | ''' 15 | image_root (string): Root directory of images 16 | ann_root (string): directory to store the annotation file 17 | split (string): train, val or test 18 | ''' 19 | urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json', 20 | 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json', 21 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'} 22 | filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'} 23 | 24 | download_url(urls[split],ann_root) 25 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 26 | 27 | self.transform = transform 28 | self.image_root = image_root 29 | 30 | 31 | def __len__(self): 32 | return len(self.annotation) 33 | 34 | 35 | def __getitem__(self, index): 36 | 37 | ann = self.annotation[index] 38 | 39 | image0_path = os.path.join(self.image_root,ann['images'][0]) 40 | image0 = Image.open(image0_path).convert('RGB') 41 | image0 = self.transform(image0) 42 | 43 | image1_path = os.path.join(self.image_root,ann['images'][1]) 44 | image1 = Image.open(image1_path).convert('RGB') 45 | image1 = self.transform(image1) 46 | 47 | sentence = pre_caption(ann['sentence'], 40) 48 | 49 | if ann['label']=='True': 50 | label = 1 51 | else: 52 | label = 0 53 | 54 | words = sentence.split(' ') 55 | 56 | if 'left' not in words and 'right' not in words: 57 | if random.random()<0.5: 58 | return image0, image1, sentence, label 59 | else: 60 | return image1, image0, sentence, label 61 | else: 62 | if random.random()<0.5: 63 | return image0, image1, sentence, label 64 | else: 65 | new_words = [] 66 | for word in words: 67 | if word=='left': 68 | new_words.append('right') 69 | elif word=='right': 70 | new_words.append('left') 71 | else: 72 | new_words.append(word) 73 | 74 | sentence = ' '.join(new_words) 75 | return image1, image0, sentence, label 76 | 77 | 78 | -------------------------------------------------------------------------------- /MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | # MODEL ZOO 2 | 3 | 4 | ## 1. Pre-trained Models 5 | 6 | | Method | Vision Encoder | #Images | Dataset | Pretrained Weights | Training Logs | 7 | | :--- | :--- | :--- | :--- | :----: | :---: | 8 | | PTP-BLIP| ViT-B(DeiT) | 4M | CC3M+COCO+VG+SBU | [link](https://huggingface.co/sail/PTP/blob/main/Pretrain_concated_pred_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_pretrain.txt) | 9 | 10 | ## 2. 
Downstream Model 11 | 12 | 13 | ### 2.1 Captioning 14 | | Method | B@4 | CIDEr | | Config | 15 | | :--- | :--- | :--- | ---: | 16 | | PTP-BLIP| 40.1 | 135.0 | configs/caption_coco.yaml | 17 | 18 | 19 | ### 2.2 Zero-shot Retrieval 20 | 21 | 25 | 26 | 27 | #### 2.2.2 Flickr30K 28 | 29 | | Method | I2T@1 | T2I@1 | Model Weight | Training Logs | Config | 30 | | :--- | :--- | :--- | :--- | :--- | :---: | 31 | | PTP-BLIP| 86.4 | 67.0 | [link](https://huggingface.co/sail/PTP/blob/main/zero_shot_coco_checkpoint_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_ptp_flickr30k_zero_shot.txt) | configs/retrieval_flickr.yaml | 32 | 33 | 34 | ### 2.3 Retrieval (Fine-tune) 35 | 36 | Tip: Please use as large batch size as possible, we experimentally find that the larger batch size leads to better result for this task. Due to memory limiation, we use batch size 24 rather than 28 in original implmentation. 37 | 38 | 39 | #### 2.3.1 COCO 40 | | Method |I2T@1 | T2I@1 | | Config | 41 | | :--- | :--- | :--- | :---: | 42 | | PTP-BLIP| 77.6 | 59.4 | configs/retrieval_coco.yaml | 43 | 44 | 45 | #### 2.3.2 Flickr30K 46 | | Method |I2T@1 | T2I@1 | Model Weight | Training Logs | Config | 47 | | :--- | :--- | :--- | :--- | :--- | :---: | 48 | | PTP-BLIP| 96.1 | 84.2 | [link](https://huggingface.co/sail/PTP/blob/main/flickr30k_ft_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_ptp_flickr30k_ft.txt) | configs/retrieval_flickr.yaml | 49 | 50 | ### 2.4 VQA V2 51 | 52 | | Method | Test-dev|Test-std |Model Weight | Training Logs | Config | 53 | | :--- | :--- | :--- | :--- | :--- | :---: | 54 | | PTP-BLIP| 76.02 | 76.18 | [link](https://huggingface.co/sail/PTP/blob/main/vqa_ft_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_ptp_vqa_v2.txt) | configs/vqa.yaml | 55 | 56 | ### 2.5 NLVR 57 | 58 | | Method | Dev| Test-P | Model Weight | Training Logs | Config | 59 | | :--- | :--- | :--- | :--- | :--- | :---: | 60 | | PTP-BLIP| 80.45 | 80.70 | [link](https://huggingface.co/sail/PTP/blob/main/nlvr_ft_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_ptp_nlvr.txt) | configs/nlvr.yaml | -------------------------------------------------------------------------------- /src/data_preprocess/generate_sample_with_bbox_and_classes.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | import numpy as np 4 | import zlib 5 | import os 6 | 7 | 8 | out_json = "cc3m_269_w_bbox.json" # 2681187 samples preserved 9 | 10 | object_dict_path = "VG-SGG-dicts-vgoi6-clipped.json" 11 | object_dict = json.load(open(object_dict_path,'r'))['label_to_idx'] 12 | # step1: encode each sample into a vector 13 | print("{} object classes in total".format(len(object_dict))) 14 | # print(object_dict) 15 | 16 | objects_class_count = np.zeros(len(object_dict)) 17 | 18 | 19 | def generate_object_bbox(objects): 20 | object_bboxs = [] 21 | for index, object in enumerate(objects): 22 | if index < 10: 23 | object_bboxs.append([int(coord) for coord in object['rect']]) 24 | # print(object_bboxs) 25 | return object_bboxs 26 | 27 | # step1: generate object tags for each sample 28 | print("===step1: begin to generate object tags caption====") 29 | sample_index = [] 30 | object_bboxs_dict = {} 31 | 32 | for i in range(12): 33 | # if i > 0: 34 | # break 35 | src_tsv = "/Data/CC3M/{}/predictions.tsv".format(i) 36 | metadata = pd.read_csv(src_tsv, sep='\t', header=None) 37 | # append boxes and indexs 38 | for j in range(len(metadata)): 39 | num_boxes = 
json.loads(metadata.iloc[j][1])['num_boxes'] 40 | index = metadata.iloc[j][0] 41 | sample_index.append(index) 42 | objects = json.loads(metadata.iloc[j][1])['objects'] 43 | object_bboxs_dict[index] = generate_object_bbox(objects) 44 | print("subdir {}/{} finished".format(i+1, 12)) 45 | 46 | # step2: align cc3m with own list 47 | print("===step2: begin to align with previous caption====") 48 | train_set = pd.read_csv('/CC3M/Train-GCC-training.tsv', sep='\t', header=None) 49 | val_set = pd.read_csv('/CC3M/Validation-GCC-1.1.0-Validation.tsv', sep='\t', header=None) 50 | all_set = pd.concat([train_set, val_set]) 51 | success_data = [] 52 | # count = 0 53 | for sample in sample_index: 54 | # count += 1 55 | # if count > 10000: 56 | # break 57 | file_name = str(zlib.crc32(all_set.iloc(0)[sample][1].encode('utf-8')) & 0xffffffff) + '.jpg' 58 | if sample >= len(train_set): 59 | img_root = "validation" 60 | sub_dir = str((sample - len(train_set)) // 1000) 61 | else: 62 | img_root = "train" 63 | sub_dir = str(sample // 1000) 64 | img_path = os.path.join(img_root, sub_dir, file_name) 65 | rel_img_path = os.path.join(img_root, sub_dir, file_name) 66 | success_data.append({'image': rel_img_path, 'caption': all_set.iloc(0)[sample][0], 'object': object_bboxs_dict[sample]}) 67 | 68 | print("{} samples preserved".format(len(success_data))) 69 | # 1484208 samples preserved 70 | 71 | # step3: 72 | print("===step3: merge with caption cc====") 73 | ann = json.load(open('../metadata/cc3m/train_success_align_269.json', 'r')) 74 | 75 | object_caption_dict = {} 76 | 77 | success_data_preserved = [] 78 | img_paths = dict() 79 | for i in range(len(success_data)): 80 | img_paths[success_data[i]['image']] = 0 81 | object_caption_dict[success_data[i]['image']] = success_data[i]['object'] 82 | 83 | # find the joint part 84 | success_data_preserved = [] 85 | for i in range(len(ann)): 86 | if ann[i]['image'] in img_paths.keys(): 87 | ann[i]['object'] = object_caption_dict[ann[i]['image']] 88 | success_data_preserved.append(ann[i]) 89 | if i % 1000 == 0: 90 | print("{}/{} finished".format(i, len(ann))) 91 | 92 | print("{} samples preserved".format(len(success_data_preserved))) 93 | 94 | 95 | with open(out_json, 'w') as outfile: 96 | json.dump(success_data_preserved, outfile) -------------------------------------------------------------------------------- /src/blip_src/data/flickr30k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | from data.utils import pre_caption 10 | 11 | class flickr30k_train(Dataset): 12 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 13 | ''' 14 | image_root (string): Root directory of images (e.g. 
flickr30k/) 15 | ann_root (string): directory to store the annotation file 16 | ''' 17 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json' 18 | filename = 'flickr30k_train.json' 19 | 20 | download_url(url,ann_root) 21 | 22 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 23 | self.transform = transform 24 | self.image_root = image_root 25 | self.max_words = max_words 26 | self.prompt = prompt 27 | 28 | self.img_ids = {} 29 | n = 0 30 | for ann in self.annotation: 31 | img_id = ann['image_id'] 32 | if img_id not in self.img_ids.keys(): 33 | self.img_ids[img_id] = n 34 | n += 1 35 | 36 | def __len__(self): 37 | return len(self.annotation) 38 | 39 | def __getitem__(self, index): 40 | 41 | ann = self.annotation[index] 42 | 43 | image_path = os.path.join(self.image_root,ann['image']) 44 | image = Image.open(image_path).convert('RGB') 45 | image = self.transform(image) 46 | 47 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 48 | 49 | return image, caption, self.img_ids[ann['image_id']] 50 | 51 | 52 | class flickr30k_retrieval_eval(Dataset): 53 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 54 | ''' 55 | image_root (string): Root directory of images (e.g. flickr30k/) 56 | ann_root (string): directory to store the annotation file 57 | split (string): val or test 58 | ''' 59 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json', 60 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'} 61 | filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'} 62 | 63 | download_url(urls[split],ann_root) 64 | 65 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 66 | self.transform = transform 67 | self.image_root = image_root 68 | 69 | self.text = [] 70 | self.image = [] 71 | self.txt2img = {} 72 | self.img2txt = {} 73 | 74 | txt_id = 0 75 | for img_id, ann in enumerate(self.annotation): 76 | self.image.append(ann['image']) 77 | self.img2txt[img_id] = [] 78 | for i, caption in enumerate(ann['caption']): 79 | self.text.append(pre_caption(caption,max_words)) 80 | self.img2txt[img_id].append(txt_id) 81 | self.txt2img[txt_id] = img_id 82 | txt_id += 1 83 | 84 | def __len__(self): 85 | return len(self.annotation) 86 | 87 | def __getitem__(self, index): 88 | 89 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 90 | image = Image.open(image_path).convert('RGB') 91 | image = self.transform(image) 92 | 93 | return image, index -------------------------------------------------------------------------------- /src/blip_src/models/blip_itm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | from models.med import BertConfig, BertModel 9 | from transformers import BertTokenizer 10 | 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | 15 | from models.blip import create_vit, init_tokenizer, load_checkpoint 16 | 17 | class BLIP_ITM(nn.Module): 18 | def __init__(self, 19 | med_config = 'configs/med_config.json', 20 | image_size = 384, 21 | vit = 'base', 22 | vit_grad_ckpt = False, 23 | vit_ckpt_layer = 0, 24 | embed_dim = 256, 25 | ): 26 | """ 27 | Args: 28 | med_config (str): path for the mixture of encoder-decoder model's configuration file 29 | image_size (int): input image size 30 | vit (str): model size of vision transformer 31 | """ 32 | super().__init__() 33 | 34 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 35 | self.tokenizer = init_tokenizer() 36 | med_config = BertConfig.from_json_file(med_config) 37 | med_config.encoder_width = vision_width 38 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 39 | 40 | text_width = self.text_encoder.config.hidden_size 41 | 42 | self.vision_proj = nn.Linear(vision_width, embed_dim) 43 | self.text_proj = nn.Linear(text_width, embed_dim) 44 | 45 | self.itm_head = nn.Linear(text_width, 2) 46 | 47 | 48 | def forward(self, image, caption, match_head='itm'): 49 | 50 | image_embeds = self.visual_encoder(image) 51 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 52 | 53 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 54 | return_tensors="pt").to(image.device) 55 | 56 | 57 | if match_head=='itm': 58 | output = self.text_encoder(text.input_ids, 59 | attention_mask = text.attention_mask, 60 | encoder_hidden_states = image_embeds, 61 | encoder_attention_mask = image_atts, 62 | return_dict = True, 63 | ) 64 | itm_output = self.itm_head(output.last_hidden_state[:,0,:]) 65 | return itm_output 66 | 67 | elif match_head=='itc': 68 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 69 | return_dict = True, mode = 'text') 70 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 71 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 72 | 73 | sim = image_feat @ text_feat.t() 74 | return sim 75 | 76 | 77 | def blip_itm(pretrained='',**kwargs): 78 | model = BLIP_ITM(**kwargs) 79 | if pretrained: 80 | model,msg = load_checkpoint(model,pretrained) 81 | assert(len(msg.missing_keys)==0) 82 | return model 83 | -------------------------------------------------------------------------------- /src/blip_src/data/vqa_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from data.utils import pre_question 9 | 10 | from torchvision.datasets.utils import download_url 11 | 12 | class vqa_dataset(Dataset): 13 | def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"): 14 | self.split = split 15 | 16 | self.transform = transform 17 | self.vqa_root = vqa_root 18 | self.vg_root = vg_root 19 | 20 | if split=='train': 21 | urls = 
{'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json', 22 | 'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json', 23 | 'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'} 24 | 25 | self.annotation = [] 26 | for f in train_files: 27 | download_url(urls[f],ann_root) 28 | self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r')) 29 | else: 30 | download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root) 31 | self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r')) 32 | 33 | download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root) 34 | self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r')) 35 | 36 | 37 | def __len__(self): 38 | return len(self.annotation) 39 | 40 | def __getitem__(self, index): 41 | 42 | ann = self.annotation[index] 43 | 44 | if ann['dataset']=='vqa': 45 | image_path = os.path.join(self.vqa_root,ann['image']) 46 | elif ann['dataset']=='vg': 47 | image_path = os.path.join(self.vg_root,ann['image']) 48 | 49 | image = Image.open(image_path).convert('RGB') 50 | image = self.transform(image) 51 | 52 | if self.split == 'test': 53 | question = pre_question(ann['question']) 54 | question_id = ann['question_id'] 55 | return image, question, question_id 56 | 57 | 58 | elif self.split=='train': 59 | 60 | question = pre_question(ann['question']) 61 | 62 | if ann['dataset']=='vqa': 63 | answer_weight = {} 64 | for answer in ann['answer']: 65 | if answer in answer_weight.keys(): 66 | answer_weight[answer] += 1/len(ann['answer']) 67 | else: 68 | answer_weight[answer] = 1/len(ann['answer']) 69 | 70 | answers = list(answer_weight.keys()) 71 | weights = list(answer_weight.values()) 72 | 73 | elif ann['dataset']=='vg': 74 | answers = [ann['answer']] 75 | weights = [0.2] 76 | 77 | return image, question, answers, weights 78 | 79 | 80 | def vqa_collate_fn(batch): 81 | image_list, question_list, answer_list, weight_list, n = [], [], [], [], [] 82 | for image, question, answer, weights in batch: 83 | image_list.append(image) 84 | question_list.append(question) 85 | weight_list += weights 86 | answer_list += answer 87 | n.append(len(answer)) 88 | return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n -------------------------------------------------------------------------------- /src/blip_src/data/pretrain_dataset_concated_pred_tsv.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import pandas as pd 5 | from torch.utils.data import Dataset 6 | 7 | from PIL import Image 8 | import numpy as np 9 | from PIL import ImageFile 10 | from PIL.Image import blend as blend 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | Image.MAX_IMAGE_PIXELS = None 13 | import ast 14 | from data.utils import pre_caption 15 | import os,glob 16 | 17 | class pretrain_dataset(Dataset): 18 | def __init__(self, ann_file, laion_path, transform): 19 | self.img_root = "/dataset" # server 20 | self.ann_pretrain = None 21 | for f in ann_file: 22 | ann_temp = pd.read_csv(f, sep='\t', header=None) 23 | if self.ann_pretrain is None: 24 | self.ann_pretrain = ann_temp 25 | else: 26 | self.ann_pretrain = pd.concat([self.ann_pretrain, ann_temp], ignore_index=True, sort=False) 27 | 28 | self.annotation = self.ann_pretrain 29 | 
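        # Note (added comment): __getitem__ below appends a position-guided text prompt that
        # generate_bbox_annotation() builds from one randomly sampled object: the midpoint of the
        # object's box coordinates is mapped onto a 3x3 grid over the image (block ids 0-8,
        # row-major), and the string "The block <id> has a <object class> ." is concatenated
        # to the caption before pre_caption() is applied.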
self.transform = transform 30 | 31 | 32 | def generate_bbox_annotation(self, img, ann): 33 | prompt_text = '.' 34 | if len(ann) > 2: 35 | ann[2] = json.loads(ann[2]) 36 | ann[3] = ast.literal_eval(ann[3]) 37 | object_num = len(ann[3]) 38 | if object_num > 0: 39 | sample_index = random.randint(0, object_num-1) 40 | w, h = img.size 41 | bbox_loc = ann[2][sample_index] 42 | # print(bbox_loc) 43 | w_1 = min(int((bbox_loc[0]/2 + bbox_loc[2]/2)/w * 3), 2) 44 | h_1 = min(int((bbox_loc[1]/2 + bbox_loc[3]/2)/h * 3), 2) 45 | # print(w_1, h_1, w, h) 46 | block = 3*h_1 + w_1 47 | prompt_text = '. The block ' + str(block) + ' has a ' + ann[3][sample_index] + ' .'; 48 | else: 49 | prompt_text = "." 50 | return prompt_text 51 | 52 | def __len__(self): 53 | return len(self.annotation) 54 | 55 | def __getitem__(self, index): 56 | ann = self.annotation.iloc[index] 57 | try: 58 | image = Image.open(os.path.join(self.img_root, ann[0])).convert('RGB') 59 | caption_str = ann[1] 60 | if caption_str[0] == '[': 61 | captions = ast.literal_eval(caption_str) 62 | temp_caption = captions[-1] + captions[random.randint(0, len(captions) - 2)] 63 | else: 64 | temp_caption = caption_str 65 | temp_caption = temp_caption + self.generate_bbox_annotation(image, ann) 66 | except Exception as e: 67 | print(e) 68 | return self.__getitem__(random.randint(0, self.__len__()-1)) 69 | image = self.transform(image) 70 | caption = pre_caption(temp_caption, 30) 71 | # caption = ann['caption'] 72 | return image, caption 73 | 74 | # def __getitem__(self, index): 75 | # ann = self.annotation.iloc[index] 76 | # try: 77 | # image = Image.open(os.path.join(self.img_root, ann[0])).convert('RGB') 78 | # caption_str = ann[1] 79 | # if caption_str[0] == '[': 80 | # captions = ast.literal_eval(caption_str) 81 | # if len(captions) < 6: 82 | # temp_caption = captions[random.randint(0, len(captions) - 1)] 83 | # else: 84 | # temp_caption = captions[random.randint(0, len(captions)//3 - 1)] + ', ' + \ 85 | # captions[random.randint(len(captions)//3, len(captions)//3 * 2)] + ', ' + \ 86 | # captions[random.randint(len(captions)//3*2, len(captions)-1)] 87 | # else: 88 | # temp_caption = caption_str 89 | # temp_caption = temp_caption + self.generate_bbox_annotation(image, ann) 90 | # except Exception as e: 91 | # print(e) 92 | # return self.__getitem__(random.randint(0, self.__len__()-1)) 93 | # image = self.transform(image) 94 | # caption = pre_caption(temp_caption, 30) 95 | # # caption = ann['caption'] 96 | # return image, caption -------------------------------------------------------------------------------- /src/blip_src/eval_nocaps.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel.yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip import blip_decoder 26 | import utils 27 | from data import create_dataset, create_sampler, create_loader 28 | from data.utils import save_result 29 | 30 | @torch.no_grad() 31 | def evaluate(model, data_loader, device, config): 32 | # evaluate 33 | model.eval() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | header = 'Evaluation:' 37 | print_freq = 10 38 | 39 | result = [] 40 | for image, image_id in metric_logger.log_every(data_loader, print_freq, header): 41 | 42 | image = image.to(device) 43 | 44 | captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'], 45 | min_length=config['min_length'], repetition_penalty=1.1) 46 | 47 | for caption, img_id in zip(captions, image_id): 48 | result.append({"image_id": img_id.item(), "caption": caption}) 49 | 50 | return result 51 | 52 | 53 | def main(args, config): 54 | utils.init_distributed_mode(args) 55 | 56 | device = torch.device(args.device) 57 | 58 | # fix the seed for reproducibility 59 | seed = args.seed + utils.get_rank() 60 | torch.manual_seed(seed) 61 | np.random.seed(seed) 62 | random.seed(seed) 63 | cudnn.benchmark = True 64 | 65 | #### Dataset #### 66 | print("Creating captioning dataset") 67 | val_dataset, test_dataset = create_dataset('nocaps', config) 68 | 69 | if args.distributed: 70 | num_tasks = utils.get_world_size() 71 | global_rank = utils.get_rank() 72 | samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank) 73 | else: 74 | samplers = [None,None] 75 | 76 | val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers, 77 | batch_size=[config['batch_size']]*2,num_workers=[4,4], 78 | is_trains=[False, False], collate_fns=[None,None]) 79 | 80 | #### Model #### 81 | print("Creating model") 82 | model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], 83 | prompt=config['prompt']) 84 | 85 | model = model.to(device) 86 | 87 | model_without_ddp = model 88 | if args.distributed: 89 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 90 | model_without_ddp = model.module 91 | 92 | val_result = evaluate(model_without_ddp, val_loader, device, config) 93 | val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id') 94 | test_result = evaluate(model_without_ddp, test_loader, device, config) 95 | test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id') 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--config', default='./configs/nocaps.yaml') 101 | parser.add_argument('--output_dir', default='output/NoCaps') 102 | parser.add_argument('--device', default='cuda') 103 | parser.add_argument('--seed', default=42, type=int) 104 | parser.add_argument('--world_size', default=1, type=int, help='number of 
distributed processes')
105 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
106 | parser.add_argument('--distributed', default=True, type=bool)
107 | parser.add_argument(
108 | "--local_rank",
109 | default=0,
110 | type=int,
111 | help="""local rank for distributed training.""",
112 | )
113 | args = parser.parse_args()
114 |
115 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
116 |
117 | args.result_dir = os.path.join(args.output_dir, 'result')
118 |
119 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
120 | Path(args.result_dir).mkdir(parents=True, exist_ok=True)
121 |
122 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
123 |
124 | main(args, config) -------------------------------------------------------------------------------- /DATASET.md: --------------------------------------------------------------------------------
1 | # Datasets
2 | We prepare the pre-training corpus following OSCAR and BLIP.
3 | __As the data preparation is very time-consuming, we provide our experience for reference.__
4 |
5 | ## 1. Download Datasets (images)
6 | ### Pre-train Datasets:
7 |
8 | ### CC3M
9 | Step 1: Download the train/val/test annotation files (including image URLs) from [google-research-datasets](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md).
10 |
11 | Step 2: We provide our script for downloading CC3M and splitting it into subsplits in [cc3m_download.py](https://huggingface.co/sail/PTP/blob/main/download_cc3m.py).
12 | **It is better to use our script for downloading, as the filenames may differ with a different preprocessing pipeline.**
13 |
14 | Note that we only download 2.8M images, as some URLs are invalid.
15 |
16 | ### SBU
17 | First download the annotation files (including image URLs) from [huggingface](https://huggingface.co/datasets/sbu_captions).
18 |
19 | Tip: We provide our script for downloading SBU:
20 | [download_sbu.py](https://huggingface.co/sail/PTP/blob/main/download_sbu.py)
21 |
22 | ### Visual Genome
23 |
24 | Download images (version 1.2) from [visualgenome](https://visualgenome.org/api/v0/api_home.html).
25 |
26 | The download dirs will be VG_100K and VG_100K_2.
27 | ```bash
28 | mkdir image
29 | mv VG_100K/* image/
30 | mv VG_100K_2/* image/
31 | ```
32 |
33 | ### COCO
34 |
35 | Download images (COCO2014) from [coco](https://cocodataset.org/#download).
36 | Download the 2014 Train, 2014 Val and 2015 Test images.
37 |
38 | ### CC12M
39 | Step 1: Download the annotation files (including image URLs) from [google-research-datasets](https://github.com/google-research-datasets/conceptual-12m).
40 |
41 | Step 2: Modify the source tsv file and image path in cc3m_download.py, then download the data in the same way as CC3M.
42 |
43 | Note that we only download 10M images, as some URLs are invalid.
44 |
45 | ### Fine-tune Datasets:
46 |
47 | ### COCO
48 | Download images (COCO2014) from [coco](https://cocodataset.org/#download).
49 | Download the 2014 Train, 2014 Val, 2014 Test and 2015 Test images.
50 |
51 |
52 | ### Flickr30K
53 | Download images from [kaggle](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset).
54 |
55 | ### VQA V2
56 |
57 | Download images from [VQA](https://visualqa.org/download.html).
58 |
59 | ### NLVR
60 | Download images from [NLVR](https://lil.nlp.cornell.edu/nlvr/).
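For reference, the CC3M filename convention assumed by [generate_sample_with_bbox_and_classes.py](src/data_preprocess/generate_sample_with_bbox_and_classes.py) derives each image filename from its URL via CRC32 and groups images into sub-folders of 1000 by row index. A minimal sketch of that convention is shown below; the example URL and the `num_train_rows` value are placeholders, not real data:

```python
# CC3M image path convention used in src/data_preprocess/generate_sample_with_bbox_and_classes.py:
# <split>/<row_index // 1000>/<crc32(image_url)>.jpg
import zlib


def cc3m_image_path(row_index: int, image_url: str, num_train_rows: int) -> str:
    split = 'train' if row_index < num_train_rows else 'validation'
    offset = row_index if split == 'train' else row_index - num_train_rows
    file_name = str(zlib.crc32(image_url.encode('utf-8')) & 0xffffffff) + '.jpg'
    return f'{split}/{offset // 1000}/{file_name}'


# example with a made-up URL; real URLs come from the GCC annotation tsv files,
# and num_train_rows should be the length of Train-GCC-training.tsv (placeholder here)
print(cc3m_image_path(12345, 'http://example.com/photo.jpg', num_train_rows=3000000))
```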
61 |
62 | ## Organize Datasets
63 |
64 | Prepare the datasets as follows:
65 | ```
66 | Dataset/
67 |     CC3M/
68 |         images/
69 |             train/x/*.jpg
70 |             val/x/*.jpg
71 |     SBU/
72 |         dataset/
73 |             train/x/*.png
74 |     coco2014/
75 |         COCO2014/
76 |             train2014/*.jpg
77 |             val2014/*.jpg
78 |             test2015/*.jpg
79 |
80 |     VisualGenome/
81 |         image/*.jpg
82 | ```
83 |
84 | Use soft links to map directories, for example:
85 | ```bash
86 | ln -s [PATH_TO_COCO2014] Dataset/coco2014/COCO2014
87 | ```
88 |
89 | ## 2. Download/Prepare Corpus (image-text pairs)
90 | We provide two kinds of shuffled image-text pairs. We use object information from [OSCAR](https://github.com/microsoft/Oscar/blob/master/VinVL_DOWNLOAD.md) and follow [BLIP](https://github.com/salesforce/BLIP) for caption refinement.
91 | 1. Specifically, we first download the corpus and object features from the OSCAR codebase. Follow [download_cc3m_predictions.sh](src/data_preprocess/download_cc3m_predictions.sh) for details. Download COCO Train, CC Train, SBU (all) and VG.
92 | 2. Then generate object_bbox and object_classes from the object features. Follow [generate_sample_with_bbox_and_classes.py](src/data_preprocess/generate_sample_with_bbox_and_classes.py) for details.
93 | 3. Finally, pad the original captions with the generated captions, following BLIP.
94 |
95 | **Notice that each COCO image has 5 texts in the [oscar corpus](https://biglmdiag.blob.core.windows.net/vinvl/pretrain_corpus/coco_flickr30k_gqa.tsv). As the COCO captions are high quality, they strongly affect the final downstream results.**
96 |
97 | Make sure each line in the corpus is
98 | ```
99 | [image, refined_caption, object_bbox, object_classes]
100 | ```
101 | An example is given below:
102 |
103 | ```bash
104 | CC3M/images/train/1597/3250687125.jpg i shall be bringing this white chair and table to the shoot; a white table with two white chairs and a couch [[340, 226, 417, 323], [16, 364, 348, 810], [256, 206, 380, 325], [195, 322, 627, 899], [0, 0, 192, 288], [568, 198, 730, 335], [95, 107, 202, 141], [531, 0, 732, 191], [666, 244, 734, 369], [378, 208, 677, 341]] ['pillow', 'chair', 'pillow', 'table', 'window', 'pillow', 'box', 'window', 'pillow', 'pillow']
105 | ```
106 | - 2.8M Image (2G): [CC3M](https://drive.google.com/file/d/1iO-d5e7mOvWEreDrlNyEc_RU_gP7FNBk/view?usp=sharing)
107 |
108 | The filtered corpus file is:
109 |
110 | - 4M Image (2.38G): [CC3M+COCO+VG+SBU](https://drive.google.com/file/d/1NnI-_ha4oqeZeHVOv1GBcvV1txgO9R68/view?usp=sharing)
111 |
112 | Thanks to Jaeseok Byun for helping correct this corpus.
113 |
114 | As we have used up our Hugging Face and Google Drive storage, please follow the steps above to prepare larger corpora yourself. -------------------------------------------------------------------------------- /GETTING_STARTED.md: --------------------------------------------------------------------------------
1 | # Getting Started with PTP Model
2 |
3 | ## PTP-BLIP
4 |
5 | ### 1. Download the ViT-Base Model.
6 |
7 | ```bash
8 | cd ptp/src/blip_src
9 | mkdir pretrained_models && cd pretrained_models;
10 | wget -c https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth;
11 | ```
12 |
13 | ### 2.
Pre-train 14 | 15 | ```bash 16 | python move_pretrained_weights.py; 17 | 18 | python -m torch.distributed.launch --nproc_per_node=8 pretrain_concated_pred_tsv.py \ 19 | --config ./configs/mt_pt/tsv/pretrain_concated_pred_4M.yaml --output_dir output/Pretrain_concated_pred_4M 20 | 21 | echo "output dir is: output/Pretrain_concated_pred_4M" 22 | 23 | ``` 24 | 25 | Alternatively, download our pretrained model from [MODEL_ZOO.md](MODEL_ZOO.md). 26 | 27 | ### 3. Downstream Task Evaluation 28 | After pre-trained, replace the **pretrained:** in yaml of each task with pre-trained model or downloaded model. 29 | Then we provide run scripts for these tasks: 30 | 31 | #### Captioning 32 | 33 | ```bash 34 | # image captioning 35 | 36 | function rand(){ 37 | min=$1 38 | max=$(($2-$min+1)) 39 | num=$(date +%s%N) 40 | echo $(($num%$max+$min)) 41 | } 42 | 43 | 44 | echo '(coco captioning:) load pretrained model from: '; 45 | 46 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_caption.py \ 47 | --config ./configs/caption_coco.yaml \ 48 | --output_dir output/captioning_coco 49 | ``` 50 | 51 | #### Retrieval 52 | 53 | ```bash 54 | function rand(){ 55 | min=$1 56 | max=$(($2-$min+1)) 57 | num=$(date +%s%N) 58 | echo $(($num%$max+$min)) 59 | } 60 | 61 | 62 | # ============================== step1: zero-shot retrieval evaluation ========= 63 | # evaluate on test 64 | echo '(coco step1: zero-shot evaluation) load pretrained model from: '; 65 | 66 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_retrieval.py \ 67 | --config ./configs/retrieval_coco.yaml \ 68 | --output_dir output/retrieval_coco_zs \ 69 | --evaluate 70 | 71 | # ================= step2: train and evaluate val & test set ================= 72 | echo '(coco step2: train on retrieval) load pretrained model from: '; 73 | 74 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_retrieval.py \ 75 | --config ./configs/retrieval_coco.yaml \ 76 | --output_dir output/retrieval_coco_ft 77 | 78 | ``` 79 | 80 | #### VQA 81 | 82 | 83 | ```bash 84 | # vqa 85 | 86 | function rand(){ 87 | min=$1 88 | max=$(($2-$min+1)) 89 | num=$(date +%s%N) 90 | echo $(($num%$max+$min)) 91 | } 92 | 93 | 94 | echo '(vqa:) load pretrained model from: '; 95 | 96 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) \ 97 | train_vqa.py --config ./configs/vqa.yaml \ 98 | --output_dir output/vqa_v2_vqa 99 | ``` 100 | 101 | After generate result files, submitted in [eval_ai](https://eval.ai/web/challenges/challenge-page/830) for final results. 102 | 103 | 104 | #### NLVR 105 | 106 | ```bash 107 | # NLVR2 108 | 109 | function rand(){ 110 | min=$1 111 | max=$(($2-$min+1)) 112 | num=$(date +%s%N) 113 | echo $(($num%$max+$min)) 114 | } 115 | 116 | echo '(NLVR:) load pretrained model from: '; 117 | 118 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_nlvr.py \ 119 | --config ./configs/nlvr.yaml \ 120 | --output_dir output/NLVR_NLVR 121 | ``` 122 | 123 | 124 | #### Run All Downstream Task At Once 125 | We also provide a shell script for all downstream task as below: 126 | 127 | ```bash 128 | bash multiple_scripts/multiple_exp_all_single_8u_ft.sh Pretrain_concated_pred_4M 129 | ``` 130 | 131 | where _Pretrain_concated_pred_4M_ is the pretrained output directory. 
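Every fine-tuning script loads `output/<PRETRAIN_DIR>/checkpoint_19.pth` (see the shell listing that follows), so it is worth confirming that the checkpoint exists and can be deserialized before launching the roughly 50 hours of downstream runs. A minimal, hypothetical pre-flight check (not part of the repo; the file name is made up):

```python
# preflight_check.py -- hypothetical helper, not part of this repo: verify the pre-trained
# checkpoint expected by multiple_scripts/exp_all_ft_single.sh before starting fine-tuning.
import os
import sys

import torch

pretrain_dir = sys.argv[1] if len(sys.argv) > 1 else 'Pretrain_concated_pred_4M'
ckpt_path = os.path.join('output', pretrain_dir, 'checkpoint_19.pth')

assert os.path.isfile(ckpt_path), f'missing checkpoint: {ckpt_path}'
state = torch.load(ckpt_path, map_location='cpu')
# BLIP-style checkpoints usually keep the weights under a 'model' key; fall back to the raw dict.
weights = state.get('model', state) if isinstance(state, dict) else state
print(f'{ckpt_path}: ok, {len(weights)} entries')
```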
132 | 133 | The 134 | ```bash 135 | python move_pretrained_weights.py; 136 | 137 | gpu_num=8; 138 | time=$(date "+%Y-%m-%d-%H:%M:%S"); 139 | suffix=$1${time}; # the suffix to distingush different experiment, e.g. $1='generation_mix' 140 | 141 | PRETRAINED_MODEL="output\/$1\/checkpoint_19.pth" 142 | 143 | echo "${suffix}"; 144 | 145 | bash multiple_scripts/ft/exp_2.sh $gpu_num $suffix $PRETRAINED_MODEL; # captioning, ~1h 146 | 147 | bash multiple_scripts/ft/exp_5.sh $gpu_num $suffix $PRETRAINED_MODEL; # flickr30 retrieval, ~1h 148 | 149 | bash multiple_scripts/ft/exp_4.sh $gpu_num $suffix $PRETRAINED_MODEL; # NLVR2, ~2h 150 | 151 | bash multiple_scripts/ft/exp_1.sh $gpu_num $suffix $PRETRAINED_MODEL; # coco retrieval, ~12h 152 | 153 | bash multiple_scripts/ft/exp_3.sh $gpu_num $suffix $PRETRAINED_MODEL; # vqa, very slow, ~35 h 154 | 155 | 156 | ``` 157 | 158 | **Tip** 159 | The simplest way to evaluate model on all tasks is: 160 | ```bash 161 | bash multiple_scripts/multiple_exp_all_single_8u_ft.sh 162 | ``` 163 | 164 | 165 | ## PTP-ViLT 166 | 167 | ### 1. Pre-train 168 | ```bash 169 | python run.py with data_root=/dataset num_gpus=8 num_nodes=1 task_mlm_itm whole_word_masking=True step200k per_gpu_batchsize=64 170 | ``` 171 | 172 | ### 2. Downstream Tasks Evaluation 173 | 174 | #### vqa 175 | ```bash 176 | python run.py with data_root=/dataset num_gpus=8 num_nodes=1 per_gpu_batchsize=64 task_finetune_vqa_randaug test_only=True precision=32 load_path="weights/vilt_vqa.ckpt" 177 | 178 | ``` -------------------------------------------------------------------------------- /src/blip_src/data/coco_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | from data.utils import pre_caption 10 | 11 | class coco_karpathy_train(Dataset): 12 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 13 | ''' 14 | image_root (string): Root directory of images (e.g. coco/images/) 15 | ann_root (string): directory to store the annotation file 16 | ''' 17 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json' 18 | filename = 'coco_karpathy_train.json' 19 | 20 | download_url(url,ann_root) 21 | 22 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 23 | self.transform = transform 24 | self.image_root = image_root # self.img_root = "/dataset/CC3M/images" # server 25 | self.max_words = max_words 26 | self.prompt = prompt 27 | 28 | self.img_ids = {} 29 | n = 0 30 | for ann in self.annotation: 31 | img_id = ann['image_id'] 32 | if img_id not in self.img_ids.keys(): 33 | self.img_ids[img_id] = n 34 | n += 1 35 | 36 | def __len__(self): 37 | return len(self.annotation) 38 | 39 | def __getitem__(self, index): 40 | 41 | ann = self.annotation[index] 42 | 43 | image_path = os.path.join(self.image_root,ann['image']) 44 | image = Image.open(image_path).convert('RGB') 45 | image = self.transform(image) 46 | 47 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 48 | 49 | return image, caption, self.img_ids[ann['image_id']] 50 | 51 | 52 | class coco_karpathy_caption_eval(Dataset): 53 | def __init__(self, transform, image_root, ann_root, split): 54 | ''' 55 | image_root (string): Root directory of images (e.g. 
coco/images/) 56 | ann_root (string): directory to store the annotation file 57 | split (string): val or test 58 | ''' 59 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 60 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 61 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 62 | 63 | download_url(urls[split],ann_root) 64 | 65 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 66 | self.transform = transform 67 | self.image_root = image_root 68 | 69 | def __len__(self): 70 | return len(self.annotation) 71 | 72 | def __getitem__(self, index): 73 | 74 | ann = self.annotation[index] 75 | 76 | image_path = os.path.join(self.image_root,ann['image']) 77 | image = Image.open(image_path).convert('RGB') 78 | image = self.transform(image) 79 | 80 | img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1] 81 | 82 | return image, int(img_id) 83 | 84 | 85 | class coco_karpathy_retrieval_eval(Dataset): 86 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 87 | ''' 88 | image_root (string): Root directory of images (e.g. coco/images/) 89 | ann_root (string): directory to store the annotation file 90 | split (string): val or test 91 | ''' 92 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 93 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 94 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 95 | 96 | download_url(urls[split],ann_root) 97 | 98 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 99 | self.transform = transform 100 | self.image_root = image_root 101 | 102 | self.text = [] 103 | self.image = [] 104 | self.txt2img = {} 105 | self.img2txt = {} 106 | 107 | txt_id = 0 108 | for img_id, ann in enumerate(self.annotation): 109 | self.image.append(ann['image']) 110 | self.img2txt[img_id] = [] 111 | for i, caption in enumerate(ann['caption']): 112 | self.text.append(pre_caption(caption,max_words)) 113 | self.img2txt[img_id].append(txt_id) 114 | self.txt2img[txt_id] = img_id 115 | txt_id += 1 116 | 117 | def __len__(self): 118 | return len(self.annotation) 119 | 120 | def __getitem__(self, index): 121 | 122 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 123 | image = Image.open(image_path).convert('RGB') 124 | image = self.transform(image) 125 | 126 | return image, index -------------------------------------------------------------------------------- /src/blip_src/models/blip_nlvr.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | from models.med import BertConfig 9 | from models.nlvr_encoder import BertModel 10 | from models.vit import interpolate_pos_embed 11 | from models.blip import create_vit, init_tokenizer, is_url 12 | 13 | # from timm.models.hub import download_cached_file 14 | import os 15 | import torch 16 | from torch import nn 17 | import torch.nn.functional as F 18 | from transformers import BertTokenizer 19 | import numpy as np 20 | 21 | class BLIP_NLVR(nn.Module): 22 | def __init__(self, 23 | med_config = 'configs/med_config.json', 24 | image_size = 480, 25 | vit = 'base', 26 | vit_grad_ckpt = False, 27 | vit_ckpt_layer = 0, 28 | ): 29 | """ 30 | Args: 31 | med_config (str): path for the mixture of encoder-decoder model's configuration file 32 | image_size (int): input image size 33 | vit (str): model size of vision transformer 34 | """ 35 | super().__init__() 36 | 37 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 38 | self.tokenizer = init_tokenizer() 39 | med_config = BertConfig.from_json_file(med_config) 40 | med_config.encoder_width = vision_width 41 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 42 | 43 | self.cls_head = nn.Sequential( 44 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), 45 | nn.ReLU(), 46 | nn.Linear(self.text_encoder.config.hidden_size, 2) 47 | ) 48 | 49 | def forward(self, image, text, targets, train=True): 50 | 51 | image_embeds = self.visual_encoder(image) 52 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 53 | image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0)) 54 | 55 | text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device) 56 | text.input_ids[:,0] = self.tokenizer.enc_token_id 57 | 58 | output = self.text_encoder(text.input_ids, 59 | attention_mask = text.attention_mask, 60 | encoder_hidden_states = [image0_embeds,image1_embeds], 61 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)], 62 | image_atts[image0_embeds.size(0):]], 63 | return_dict = True, 64 | ) 65 | hidden_state = output.last_hidden_state[:,0,:] 66 | prediction = self.cls_head(hidden_state) 67 | 68 | if train: 69 | loss = F.cross_entropy(prediction, targets) 70 | return loss 71 | else: 72 | return prediction 73 | 74 | def blip_nlvr(pretrained='',**kwargs): 75 | model = BLIP_NLVR(**kwargs) 76 | if pretrained: 77 | model,msg = load_checkpoint(model,pretrained) 78 | print("missing keys:") 79 | print(msg.missing_keys) 80 | return model 81 | 82 | 83 | def load_checkpoint(model,url_or_filename): 84 | if os.path.isfile(url_or_filename): 85 | checkpoint = torch.load(url_or_filename, map_location='cpu') 86 | else: 87 | raise RuntimeError('checkpoint url or path is invalid') 88 | # if is_url(url_or_filename): 89 | # cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 90 | # checkpoint = torch.load(cached_file, map_location='cpu') 91 | # elif os.path.isfile(url_or_filename): 92 | # checkpoint = torch.load(url_or_filename, map_location='cpu') 93 | # else: 94 | # raise RuntimeError('checkpoint url or path is invalid') 95 | state_dict = checkpoint['model'] 96 | 97 | state_dict['visual_encoder.pos_embed'] = 
interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 98 | 99 | for key in list(state_dict.keys()): 100 | if 'crossattention.self.' in key: 101 | new_key0 = key.replace('self','self0') 102 | new_key1 = key.replace('self','self1') 103 | state_dict[new_key0] = state_dict[key] 104 | state_dict[new_key1] = state_dict[key] 105 | elif 'crossattention.output.dense.' in key: 106 | new_key0 = key.replace('dense','dense0') 107 | new_key1 = key.replace('dense','dense1') 108 | state_dict[new_key0] = state_dict[key] 109 | state_dict[new_key1] = state_dict[key] 110 | 111 | msg = model.load_state_dict(state_dict,strict=False) 112 | print('load checkpoint from %s'%url_or_filename) 113 | return model,msg 114 | -------------------------------------------------------------------------------- /src/blip_src/data/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import random 8 | import utils 9 | import matplotlib.pyplot as plt 10 | from PIL import Image, ImageDraw 11 | import numpy as np 12 | import cv2 13 | 14 | colormaps = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (61, 145, 64), (127, 255, 212), (0, 201, 87), 15 | (218, 112, 214), (255, 0, 255), (112, 128, 105), (250, 235, 215), 16 | (240, 255, 255), (252, 230, 201), (255, 255, 0), (235, 142, 85), 17 | (255, 97, 0), (176, 224, 230), (65, 106, 225,), (0, 255, 255), 18 | (56, 94, 15), (8, 46, 84), (255, 192, 203)] 19 | 20 | def pre_caption(caption,max_words=50): 21 | caption = re.sub( 22 | r"([.!\"()*#:;~])", 23 | ' ', 24 | caption.lower(), 25 | ) 26 | caption = re.sub( 27 | r"\s{2,}", 28 | ' ', 29 | caption, 30 | ) 31 | caption = caption.rstrip('\n') 32 | caption = caption.strip(' ') 33 | 34 | #truncate caption 35 | caption_words = caption.split(' ') 36 | if len(caption_words)>max_words: 37 | caption = ' '.join(caption_words[:max_words]) 38 | 39 | return caption 40 | 41 | def pre_question(question,max_ques_words=50): 42 | question = re.sub( 43 | r"([.!\"()*#:;~])", 44 | '', 45 | question.lower(), 46 | ) 47 | question = question.rstrip(' ') 48 | 49 | #truncate question 50 | question_words = question.split(' ') 51 | if len(question_words)>max_ques_words: 52 | question = ' '.join(question_words[:max_ques_words]) 53 | 54 | return question 55 | 56 | def draw_bboxs(image, bboxs): 57 | image_w_box = ImageDraw.Draw(image) 58 | for bbox in bboxs: 59 | image_w_box.rectangle([(bbox[0], bbox[1]), (bbox[2], bbox[3])], fill=False, outline="red", width=4) 60 | return image 61 | 62 | def draw_bboxs_color_prompt(image, bboxs): 63 | img = np.asarray(image) 64 | for index, bbox in enumerate(bboxs): 65 | sub_img = img[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] 66 | white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255 67 | white_rect[:, :, 0] = colormaps[index][0] 68 | white_rect[:, :, 1] = colormaps[index][1] 69 | white_rect[:, :, 2] = colormaps[index][2] 70 | res = cv2.addWeighted(sub_img, 0.7, white_rect, 0.3, 1.0) 71 | img[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] = res 72 | cv2.rectangle(img, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), colormaps[index], 3) 73 | img = Image.fromarray(img.astype('uint8')) 74 | return image 75 | 76 | def save_result(result, result_dir, filename, remove_duplicate=''): 77 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 78 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 79 | 
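# each rank writes its own partial '<filename>_rank<k>.json' below; after dist.barrier(), the main process merges all rank files (optionally de-duplicating on the remove_duplicate key) into the final '<filename>.json'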
80 | json.dump(result,open(result_file,'w')) 81 | 82 | dist.barrier() 83 | 84 | if utils.is_main_process(): 85 | # combine results from all processes 86 | result = [] 87 | 88 | for rank in range(utils.get_world_size()): 89 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 90 | res = json.load(open(result_file,'r')) 91 | result += res 92 | 93 | if remove_duplicate: 94 | result_new = [] 95 | id_list = [] 96 | for res in result: 97 | if res[remove_duplicate] not in id_list: 98 | id_list.append(res[remove_duplicate]) 99 | result_new.append(res) 100 | result = result_new 101 | 102 | json.dump(result,open(final_result_file,'w')) 103 | print('result file saved to %s'%final_result_file) 104 | 105 | return final_result_file 106 | 107 | 108 | 109 | from pycocotools.coco import COCO 110 | from pycocoevalcap.eval import COCOEvalCap 111 | from torchvision.datasets.utils import download_url 112 | 113 | def coco_caption_eval(coco_gt_root, results_file, split): 114 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json', 115 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'} 116 | filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'} 117 | 118 | download_url(urls[split],coco_gt_root) 119 | annotation_file = os.path.join(coco_gt_root,filenames[split]) 120 | 121 | # create coco object and coco_result object 122 | coco = COCO(annotation_file) 123 | coco_result = coco.loadRes(results_file) 124 | 125 | # create coco_eval object by taking coco and coco_result 126 | coco_eval = COCOEvalCap(coco, coco_result) 127 | 128 | # evaluate on a subset of images by setting 129 | # coco_eval.params['image_id'] = coco_result.getImgIds() 130 | # please remove this line when evaluating the full validation set 131 | # coco_eval.params['image_id'] = coco_result.getImgIds() 132 | 133 | # evaluate results 134 | # SPICE will take a few minutes the first time, but speeds up due to caching 135 | coco_eval.evaluate() 136 | 137 | # print output evaluation scores 138 | for metric, score in coco_eval.eval.items(): 139 | print(f'{metric}: {score:.3f}') 140 | 141 | return coco_eval -------------------------------------------------------------------------------- /src/blip_src/data/init_data_concated_pred_tsv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | # from torchvision.transforms.functional import InterpolationMode 5 | 6 | from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval 7 | from data.nocaps_dataset import nocaps_eval 8 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval 9 | from data.vqa_dataset import vqa_dataset 10 | from data.nlvr_dataset import nlvr_dataset 11 | from data.pretrain_dataset_concated_pred_tsv import pretrain_dataset 12 | from transform.randaugment import RandomAugment 13 | 14 | def create_dataset(dataset, config, min_scale=0.5): 15 | print(config) 16 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 17 | 18 | transform_train = transforms.Compose([ 19 | transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0)), #,interpolation=InterpolationMode.BICUBIC), 20 | # transforms.RandomHorizontalFlip(), 21 | 
RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 22 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 23 | transforms.ToTensor(), 24 | normalize, 25 | ]) 26 | transform_test = transforms.Compose([ 27 | transforms.Resize((config['image_size'],config['image_size'])), #,interpolation=InterpolationMode.BICUBIC), 28 | transforms.ToTensor(), 29 | normalize, 30 | ]) 31 | 32 | if dataset=='pretrain': 33 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train) 34 | return dataset 35 | 36 | elif dataset=='caption_cc3m': 37 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_test) 38 | return dataset 39 | 40 | elif dataset=='caption_coco': 41 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt']) 42 | val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val') 43 | test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test') 44 | return train_dataset, val_dataset, test_dataset 45 | 46 | elif dataset=='nocaps': 47 | val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val') 48 | test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test') 49 | return val_dataset, test_dataset 50 | 51 | elif dataset=='retrieval_coco': 52 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root']) 53 | val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 54 | test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 55 | return train_dataset, val_dataset, test_dataset 56 | 57 | elif dataset=='retrieval_flickr': 58 | train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root']) 59 | val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 60 | test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 61 | return train_dataset, val_dataset, test_dataset 62 | 63 | elif dataset=='vqa': 64 | train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'], 65 | train_files = config['train_files'], split='train') 66 | test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test') 67 | return train_dataset, test_dataset 68 | 69 | elif dataset=='nlvr': 70 | train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train') 71 | val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val') 72 | test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test') 73 | return train_dataset, val_dataset, test_dataset 74 | 75 | 76 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 77 | samplers = [] 78 | for dataset,shuffle in zip(datasets,shuffles): 79 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 80 | samplers.append(sampler) 81 | return samplers 82 | 83 | 84 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 85 | loaders = [] 86 | for dataset,sampler,bs,n_worker,is_train,collate_fn in 
zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 87 | if is_train: 88 | shuffle = (sampler is None) 89 | drop_last = True 90 | else: 91 | shuffle = False 92 | drop_last = False 93 | loader = DataLoader( 94 | dataset, 95 | batch_size=bs, 96 | num_workers=n_worker, 97 | pin_memory=True, 98 | sampler=sampler, 99 | shuffle=shuffle, 100 | collate_fn=collate_fn, 101 | drop_last=drop_last, 102 | ) 103 | loaders.append(loader) 104 | return loaders 105 | 106 | -------------------------------------------------------------------------------- /src/blip_src/data/init_data_concated_pred_refined.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | # from torchvision.transforms.functional import InterpolationMode 5 | 6 | from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval 7 | from data.nocaps_dataset import nocaps_eval 8 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval 9 | from data.vqa_dataset import vqa_dataset 10 | from data.nlvr_dataset import nlvr_dataset 11 | from data.pretrain_dataset_concated_pred_refined import pretrain_dataset 12 | from transform.randaugment import RandomAugment 13 | 14 | def create_dataset(dataset, config, min_scale=0.5): 15 | print(config) 16 | # normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 17 | 18 | transform_train = transforms.Compose([ 19 | transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0)), #,interpolation=InterpolationMode.BICUBIC), 20 | transforms.RandomHorizontalFlip(), 21 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 22 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 23 | transforms.ToTensor(), 24 | # normalize, 25 | ]) 26 | transform_test = transforms.Compose([ 27 | transforms.Resize((config['image_size'],config['image_size'])), #,interpolation=InterpolationMode.BICUBIC), 28 | transforms.ToTensor(), 29 | # normalize, 30 | ]) 31 | 32 | if dataset=='pretrain': 33 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train) 34 | return dataset 35 | 36 | elif dataset=='caption_cc3m': 37 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_test) 38 | return dataset 39 | 40 | elif dataset=='caption_coco': 41 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt']) 42 | val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val') 43 | test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test') 44 | return train_dataset, val_dataset, test_dataset 45 | 46 | elif dataset=='nocaps': 47 | val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val') 48 | test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test') 49 | return val_dataset, test_dataset 50 | 51 | elif dataset=='retrieval_coco': 52 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root']) 53 | val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 54 | test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 
'test') 55 | return train_dataset, val_dataset, test_dataset 56 | 57 | elif dataset=='retrieval_flickr': 58 | train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root']) 59 | val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 60 | test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 61 | return train_dataset, val_dataset, test_dataset 62 | 63 | elif dataset=='vqa': 64 | train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'], 65 | train_files = config['train_files'], split='train') 66 | test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test') 67 | return train_dataset, test_dataset 68 | 69 | elif dataset=='nlvr': 70 | train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train') 71 | val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val') 72 | test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test') 73 | return train_dataset, val_dataset, test_dataset 74 | 75 | 76 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 77 | samplers = [] 78 | for dataset,shuffle in zip(datasets,shuffles): 79 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 80 | samplers.append(sampler) 81 | return samplers 82 | 83 | 84 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 85 | loaders = [] 86 | for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 87 | if is_train: 88 | shuffle = (sampler is None) 89 | drop_last = True 90 | else: 91 | shuffle = False 92 | drop_last = False 93 | loader = DataLoader( 94 | dataset, 95 | batch_size=bs, 96 | num_workers=n_worker, 97 | pin_memory=True, 98 | sampler=sampler, 99 | shuffle=shuffle, 100 | collate_fn=collate_fn, 101 | drop_last=drop_last, 102 | ) 103 | loaders.append(loader) 104 | return loaders 105 | 106 | -------------------------------------------------------------------------------- /src/blip_src/data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | # from torchvision.transforms.functional import InterpolationMode 5 | 6 | from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval 7 | from data.nocaps_dataset import nocaps_eval 8 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval 9 | from data.vqa_dataset import vqa_dataset 10 | from data.nlvr_dataset import nlvr_dataset 11 | from data.pretrain_dataset import pretrain_dataset 12 | from transform.randaugment import RandomAugment 13 | 14 | def create_dataset(dataset, config, min_scale=0.5): 15 | print(config) 16 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 17 | 18 | transform_train = transforms.Compose([ 19 | transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0)), #,interpolation=InterpolationMode.BICUBIC), 20 | transforms.RandomHorizontalFlip(), 21 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 22 | 'ShearX', 'ShearY', 
'TranslateX', 'TranslateY', 'Rotate']), 23 | transforms.ToTensor(), 24 | normalize, 25 | ]) 26 | transform_test = transforms.Compose([ 27 | transforms.Resize((config['image_size'],config['image_size'])), #,interpolation=InterpolationMode.BICUBIC), 28 | transforms.ToTensor(), 29 | normalize, 30 | ]) 31 | 32 | if dataset=='pretrain': 33 | if 'mixup' in config: 34 | if config['mixup']: 35 | mixup = True 36 | else: 37 | mixup = False 38 | else: 39 | mixup = False 40 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train, img_mixup=mixup) 41 | return dataset 42 | elif dataset=='caption_cc3m': 43 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_test, img_mixup=False) 44 | return dataset 45 | elif dataset=='caption_coco': 46 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt']) 47 | val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val') 48 | test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test') 49 | return train_dataset, val_dataset, test_dataset 50 | 51 | elif dataset=='nocaps': 52 | val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val') 53 | test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test') 54 | return val_dataset, test_dataset 55 | 56 | elif dataset=='retrieval_coco': 57 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root']) 58 | val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 59 | test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 60 | return train_dataset, val_dataset, test_dataset 61 | 62 | elif dataset=='retrieval_flickr': 63 | train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root']) 64 | val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 65 | test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 66 | return train_dataset, val_dataset, test_dataset 67 | 68 | elif dataset=='vqa': 69 | train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'], 70 | train_files = config['train_files'], split='train') 71 | test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test') 72 | return train_dataset, test_dataset 73 | 74 | elif dataset=='nlvr': 75 | train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train') 76 | val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val') 77 | test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test') 78 | return train_dataset, val_dataset, test_dataset 79 | 80 | 81 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 82 | samplers = [] 83 | for dataset,shuffle in zip(datasets,shuffles): 84 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 85 | samplers.append(sampler) 86 | return samplers 87 | 88 | 89 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 90 | loaders = [] 91 | for dataset,sampler,bs,n_worker,is_train,collate_fn in 
zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 92 | if is_train: 93 | shuffle = (sampler is None) 94 | drop_last = True 95 | else: 96 | shuffle = False 97 | drop_last = False 98 | loader = DataLoader( 99 | dataset, 100 | batch_size=bs, 101 | num_workers=n_worker, 102 | pin_memory=True, 103 | sampler=sampler, 104 | shuffle=shuffle, 105 | collate_fn=collate_fn, 106 | drop_last=drop_last, 107 | ) 108 | loaders.append(loader) 109 | return loaders 110 | 111 | -------------------------------------------------------------------------------- /src/blip_src/data/pretrain_dataset_concated_pred_refined.py: -------------------------------------------------------------------------------- 1 | # augment the boundng box with the original images 2 | import json 3 | import os 4 | import random 5 | import pandas as pd 6 | import torch 7 | import torchvision 8 | from torch.utils.data import Dataset 9 | 10 | from PIL import Image 11 | import numpy as np 12 | from PIL import ImageFile 13 | from PIL.Image import blend as blend 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | Image.MAX_IMAGE_PIXELS = None 16 | 17 | import PIL 18 | from matplotlib.patches import Rectangle 19 | import matplotlib.pyplot as plt 20 | 21 | from data.utils import pre_caption 22 | import os,glob 23 | import math 24 | import cv2 25 | import ast 26 | import gc 27 | 28 | class pretrain_dataset(Dataset): 29 | def __init__(self, ann_file, laion_path, transform): 30 | self.img_root = "/dataset" # server 31 | self.ann_pretrain = None 32 | for f in ann_file: 33 | ann_temp = pd.read_csv(f, sep='\t', header=None) 34 | if self.ann_pretrain is None: 35 | self.ann_pretrain = ann_temp 36 | else: 37 | self.ann_pretrain = pd.concat([self.ann_pretrain, ann_temp], ignore_index=True, sort=False) 38 | 39 | self.annotation = self.ann_pretrain 40 | 41 | self.transform = transform 42 | self.normalize = torchvision.transforms.Compose([torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))]) 43 | 44 | 45 | def reload_laion(self, epoch): 46 | n = epoch%len(self.laion_files) 47 | print('loading '+self.laion_files[n]) 48 | with open(self.laion_files[n],'r') as f: 49 | self.ann_laion = json.load(f) 50 | 51 | self.annotation = self.ann_pretrain + self.ann_laion 52 | 53 | def __len__(self): 54 | return len(self.annotation) 55 | 56 | def generate_bbox_img(self, img, ann): 57 | object_num = len(ann[3]) 58 | sample_index = random.randint(0, object_num-1) 59 | object_tag = ann[3][sample_index] 60 | w, h = img.size 61 | bbox_loc = ann[2][sample_index] 62 | im = PIL.Image.new(mode="RGB", size=(w, h)) 63 | # im = np.asarray(im) 64 | im = np.array(im) 65 | im[bbox_loc[1]:bbox_loc[3], bbox_loc[0]:bbox_loc[2]] = 255 66 | im = Image.fromarray(im) 67 | return im, object_tag 68 | 69 | def find_squares(self, image): 70 | square_exist = False 71 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 72 | ret, binary = cv2.threshold(gray, 155, 255, cv2.THRESH_BINARY) 73 | contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 74 | # print(len(contours)) 75 | if len(contours) > 0: 76 | x, y, w, h = cv2.boundingRect(contours[0]) 77 | square_exist = True 78 | else: 79 | x, y, w, h = 0, 0, image.shape[0], image.shape[1] 80 | return [x, y, w, h], square_exist 81 | 82 | def generate_coord_info(self, bbox_img): 83 | bbox_img = torchvision.transforms.ToPILImage()(bbox_img) 84 | bbox_img = np.asarray(bbox_img) 85 | h, w = bbox_img.shape[:2] 86 | [x_b, y_b, w_b, h_b], square_exist = 
self.find_squares(bbox_img) 87 | cv2.rectangle(bbox_img, (x_b, y_b), (x_b+w_b, y_b+h_b), (0, 255, 0), 2) 88 | # x_b = int(x_b/h*100) 89 | # y_b = int(y_b/w*100) 90 | # w_b = int(w_b/h*100) 91 | # h_b = int(h_b/w*100) 92 | return [x_b, y_b, w_b, h_b], [h, w], square_exist 93 | 94 | def __getitem__(self, index): 95 | ann = self.annotation.iloc[index] 96 | try: 97 | image = Image.open(os.path.join(self.img_root, ann[0])).convert('RGB') 98 | if len(ann.keys()) > 2: 99 | ann[2] = json.loads(ann[2]) 100 | ann[3] = ast.literal_eval(ann[3]) 101 | if len(ann[2]) > 0: 102 | # w, h = img.size # original img size 103 | # original_image = np.asarray(image) 104 | bbox_img, object_tag = self.generate_bbox_img(image, ann) 105 | seed = np.random.randint(2147483647) # make a seed with numpy generator 106 | random.seed(seed) # apply this seed to img tranfsorms 107 | torch.manual_seed(seed) # needed for torchvision 0.7 108 | np.random.seed(seed) 109 | image = self.transform(image) 110 | # original_bbox_image = np.asarray(bbox_img) 111 | # step3: transform bbox image and getting block information 112 | random.seed(seed) # apply this seed to target tranfsorms 113 | torch.manual_seed(seed) # needed for torchvision 0.7 114 | np.random.seed(seed) 115 | bbox_img = self.transform(bbox_img) 116 | [x_b, y_b, w_b, h_b], [h, w], square_exist = self.generate_coord_info(bbox_img) 117 | 118 | # prevent memory leverage 119 | del bbox_img 120 | gc.collect() 121 | 122 | # step4: get the block_index, use x to mean there is no box or all the figure 123 | if square_exist == False: 124 | block = 'x' 125 | else: 126 | w_1 = min(int((x_b + w_b/2)/w * 3), 2) 127 | h_1 = min(int((y_b + h_b/2)/h * 3), 2) 128 | # print(w_1, h_1, w, h) 129 | block = 3*h_1 + w_1 130 | prompt_text = '. The block ' + str(block) + ' has a ' + object_tag + '.' 131 | ann[1] = pre_caption(ann[1], 22) + '. ' + pre_caption(prompt_text, 8) 132 | else: 133 | image = self.transform(image) 134 | ann[1] = pre_caption(ann[1], 30) 135 | else: 136 | image = self.transform(image) 137 | ann[1] = pre_caption(ann[1], 30) 138 | except Exception as e: 139 | print(e) 140 | return self.__getitem__(random.randint(0, self.__len__()-1)) 141 | caption = ann[1] 142 | # print(caption) 143 | image = self.normalize(image) 144 | return image, caption -------------------------------------------------------------------------------- /src/vilt_src/config.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | 3 | ex = Experiment("ViLT") 4 | 5 | 6 | def _loss_names(d): 7 | ret = { 8 | "itm": 0, 9 | "mlm": 0, 10 | "mpp": 0, 11 | "vqa": 0, 12 | "nlvr2": 0, 13 | "irtr": 0, 14 | } 15 | ret.update(d) 16 | return ret 17 | 18 | 19 | @ex.config 20 | def config(): 21 | exp_name = "vilt" 22 | seed = 0 23 | datasets = ["coco", "vg", "sbu", "gcc"] 24 | loss_names = _loss_names({"itm": 1, "mlm": 1}) 25 | batch_size = 4096 # this is a desired batch size; pl trainer will accumulate gradients when per step batch is smaller. 
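# e.g. with the pre-training command in GETTING_STARTED.md (num_gpus=8, num_nodes=1, per_gpu_batchsize=64) the per-step batch is 8 * 64 = 512, so roughly 4096 / 512 = 8 gradient-accumulation steps are needed to reach this target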
26 | 27 | # Image setting 28 | train_transform_keys = ["pixelbert"] 29 | val_transform_keys = ["pixelbert"] 30 | image_size = 384 31 | max_image_len = -1 32 | patch_size = 32 33 | draw_false_image = 1 34 | image_only = False 35 | 36 | # Text Setting 37 | vqav2_label_size = 3129 38 | max_text_len = 40 39 | tokenizer = "bert-base-uncased" 40 | vocab_size = 30522 41 | whole_word_masking = False 42 | mlm_prob = 0.15 43 | draw_false_text = 0 44 | 45 | # Transformer Setting 46 | vit = "vit_base_patch32_384" 47 | hidden_size = 768 48 | num_heads = 12 49 | num_layers = 12 50 | mlp_ratio = 4 51 | drop_rate = 0.1 52 | 53 | # Optimizer Setting 54 | optim_type = "adamw" 55 | learning_rate = 1e-4 56 | weight_decay = 0.01 57 | decay_power = 1 58 | max_epoch = 100 59 | max_steps = 25000 60 | warmup_steps = 2500 61 | end_lr = 0 62 | lr_mult = 1 # multiply lr for downstream heads 63 | 64 | # Downstream Setting 65 | get_recall_metric = False 66 | 67 | # PL Trainer Setting 68 | resume_from = None 69 | fast_dev_run = False 70 | val_check_interval = 1.0 71 | test_only = False 72 | 73 | # below params varies with the environment 74 | data_root = "" 75 | log_dir = "result" 76 | per_gpu_batchsize = 0 # you should define this manually with per_gpu_batch_size=# 77 | num_gpus = 1 78 | num_nodes = 1 79 | load_path = "" 80 | num_workers = 8 81 | precision = 16 82 | 83 | 84 | # Named configs for "environment" which define gpus and nodes, and paths 85 | @ex.named_config 86 | def env_dandelin(): 87 | data_root = "/data2/dsets/dataset" 88 | log_dir = "/data2/vilt/result" 89 | num_gpus = 8 90 | num_nodes = 1 91 | 92 | 93 | # Named configs for "task" which define datasets, loss_names and desired batch_size, warmup_steps, epochs, and exp_name 94 | @ex.named_config 95 | def task_mlm_itm(): 96 | exp_name = "mlm_itm" 97 | datasets = ["cc3m"] 98 | loss_names = _loss_names({"itm": 1, "mlm": 1}) 99 | batch_size = 4096 100 | max_epoch = 10 101 | max_image_len = 200 102 | 103 | 104 | @ex.named_config 105 | def task_mlm_itm_randaug(): 106 | exp_name = "mlm_itm_randaug" 107 | datasets = ["coco", "vg", "sbu", "gcc"] 108 | train_transform_keys = ["pixelbert_randaug"] 109 | loss_names = _loss_names({"itm": 1, "mlm": 1}) 110 | batch_size = 4096 111 | max_epoch = 10 112 | max_image_len = 200 113 | 114 | @ex.named_config 115 | def task_mlm_itm_mpp(): 116 | exp_name = "mlm_itm_mpp" 117 | datasets = ["coco", "vg", "sbu", "gcc"] 118 | loss_names = _loss_names({"itm": 1, "mlm": 1, "mpp": 1}) 119 | batch_size = 4096 120 | max_epoch = 10 121 | max_image_len = 200 122 | 123 | 124 | @ex.named_config 125 | def task_finetune_nlvr2(): 126 | exp_name = "finetune_nlvr2" 127 | datasets = ["nlvr2"] 128 | loss_names = _loss_names({"nlvr2": 1}) 129 | batch_size = 128 130 | max_epoch = 10 131 | max_steps = None 132 | warmup_steps = 0.1 133 | draw_false_image = 0 134 | learning_rate = 1e-4 135 | 136 | 137 | @ex.named_config 138 | def task_finetune_nlvr2_randaug(): 139 | exp_name = "finetune_nlvr2_randaug" 140 | datasets = ["nlvr2"] 141 | train_transform_keys = ["pixelbert_randaug"] 142 | loss_names = _loss_names({"nlvr2": 1}) 143 | batch_size = 128 144 | max_epoch = 10 145 | max_steps = None 146 | warmup_steps = 0.1 147 | draw_false_image = 0 148 | learning_rate = 1e-4 149 | 150 | 151 | @ex.named_config 152 | def task_finetune_vqa(): 153 | exp_name = "finetune_vqa" 154 | datasets = ["vqa"] 155 | loss_names = _loss_names({"vqa": 1}) 156 | batch_size = 256 157 | max_epoch = 10 158 | max_steps = None 159 | warmup_steps = 0.1 160 | draw_false_image = 0 161 | 
learning_rate = 1e-4 162 | val_check_interval = 0.1 163 | lr_mult = 10 164 | 165 | 166 | @ex.named_config 167 | def task_finetune_vqa_randaug(): 168 | exp_name = "finetune_vqa_randaug" 169 | datasets = ["vqa"] 170 | train_transform_keys = ["pixelbert_randaug"] 171 | loss_names = _loss_names({"vqa": 1}) 172 | batch_size = 256 173 | max_epoch = 10 174 | max_steps = None 175 | warmup_steps = 0.1 176 | draw_false_image = 0 177 | learning_rate = 1e-4 178 | val_check_interval = 0.1 179 | lr_mult = 10 180 | 181 | 182 | @ex.named_config 183 | def task_finetune_irtr_coco(): 184 | exp_name = "finetune_irtr_coco" 185 | datasets = ["coco"] 186 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 187 | batch_size = 256 188 | max_epoch = 10 189 | max_steps = None 190 | warmup_steps = 0.1 191 | get_recall_metric = True 192 | draw_false_text = 15 193 | learning_rate = 1e-4 194 | 195 | 196 | @ex.named_config 197 | def task_finetune_irtr_coco_randaug(): 198 | exp_name = "finetune_irtr_coco_randaug" 199 | datasets = ["coco"] 200 | train_transform_keys = ["pixelbert_randaug"] 201 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 202 | batch_size = 256 203 | max_epoch = 10 204 | max_steps = None 205 | warmup_steps = 0.1 206 | get_recall_metric = True 207 | draw_false_text = 15 208 | learning_rate = 1e-4 209 | 210 | 211 | @ex.named_config 212 | def task_finetune_irtr_f30k(): 213 | exp_name = "finetune_irtr_f30k" 214 | datasets = ["f30k"] 215 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 216 | batch_size = 256 217 | max_epoch = 10 218 | max_steps = None 219 | warmup_steps = 0.1 220 | get_recall_metric = True 221 | draw_false_text = 15 222 | learning_rate = 1e-4 223 | 224 | 225 | @ex.named_config 226 | def task_finetune_irtr_f30k_randaug(): 227 | exp_name = "finetune_irtr_f30k_randaug" 228 | datasets = ["f30k"] 229 | train_transform_keys = ["pixelbert_randaug"] 230 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 231 | batch_size = 256 232 | max_epoch = 10 233 | max_steps = None 234 | warmup_steps = 0.1 235 | get_recall_metric = True 236 | draw_false_text = 15 237 | learning_rate = 1e-4 238 | 239 | 240 | # Named configs for "etc" which are orthogonal to "env" and "task", need to be added at the end 241 | 242 | 243 | @ex.named_config 244 | def step25k(): 245 | max_epoch = 100 246 | max_steps = 25000 247 | 248 | 249 | @ex.named_config 250 | def step50k(): 251 | max_epoch = 100 252 | max_steps = 50000 253 | 254 | 255 | @ex.named_config 256 | def step100k(): 257 | max_epoch = 100 258 | max_steps = 100000 259 | 260 | 261 | @ex.named_config 262 | def step200k(): 263 | max_epoch = 200 264 | max_steps = 200000 265 | 266 | 267 | @ex.named_config 268 | def vit32_base(): 269 | vit = "vit_base_patch32_384" 270 | patch_size = 32 271 | hidden_size = 768 272 | num_heads = 12 273 | num_layers = 12 274 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PTP 2 | 3 | 8 | 9 | 10 | 11 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/position-guided-text-prompt-for-vision/cross-modal-retrieval-on-coco-2014)]( 12 | https://paperswithcode.com/sota/cross-modal-retrieval-on-coco-2014?p=position-guided-text-prompt-for-vision) 13 | 14 | 15 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/position-guided-text-prompt-for-vision/image-captioning-on-coco-captions)]( 16 | 
https://paperswithcode.com/sota/image-captioning-on-coco-captions?p=position-guided-text-prompt-for-vision) 17 | 18 | 19 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/position-guided-text-prompt-for-vision/zero-shot-cross-modal-retrieval-on-flickr30k)]( 20 | https://paperswithcode.com/sota/zero-shot-cross-modal-retrieval-on-flickr30k?p=position-guided-text-prompt-for-vision) 21 | 22 | 23 | This repository includes implementations of the following method: 24 | 25 | - [Position-guided Text Prompt for Vision Language Pre-training](https://arxiv.org/abs/2212.09737) 26 | 27 | ## Introduction 28 | The goal of Position-guided Text Prompt (PTP) is to bring position information into conventional Vision-Language Pre-training (VLP) models, as current mainstream end-to-end VLP models ignore these important cues. 29 | 30 | 31 |

32 | 33 |

34 | 35 | We observe that **position information is missing in a well-trained ViLT model.** 36 | 37 | 38 |

39 | 40 |

41 | 42 | **Our method provides a good alternative to existing object-feature-based methods (BUTD and its follow-up works).** 43 | 44 | Some examples of _PTP_ are shown below: 45 |
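For concreteness, the sketch below mirrors (in simplified form) the prompt construction used by the pre-training data loader in `src/blip_src/data/pretrain_dataset_concated_pred_refined.py`: the image is divided into a 3x3 grid, the grid block containing the center of an object's bounding box is located, and a prompt of the form "The block N has a {object}" is appended to the caption. The image size and box below are illustrative values taken from the corpus example in [DATASET.md](DATASET.md); the real loader recovers the box from a transformed mask image, so the block index is computed after data augmentation.

```python
# Simplified sketch of PTP prompt generation: which of the 3x3 blocks holds the object's box center?
def ptp_prompt(bbox, object_tag, img_w, img_h):
    x1, y1, x2, y2 = bbox
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2  # box center
    col = min(int(cx / img_w * 3), 2)      # 0..2, left to right
    row = min(int(cy / img_h * 3), 2)      # 0..2, top to bottom
    block = 3 * row + col                  # block index 0..8
    return '. The block ' + str(block) + ' has a ' + object_tag + '.'

# Illustrative values only (a 'pillow' box from the corpus example, assuming a 734x899 image):
print(ptp_prompt([340, 226, 417, 323], 'pillow', img_w=734, img_h=899))
# -> ". The block 1 has a pillow."  (the center falls in the top-middle block)
```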

46 | 47 |

48 | 49 | ## Updates 50 | 51 | - 2023.5 Modified the pre-training corpus to prevent confusion. 52 | - 2023.3 The pre-training code is released. 53 | - 2023.1 We have put the pretrained and fine-tuned weights on Hugging Face for fast download. 54 | - 2022.12 The first version of the downstream evaluation code based on BLIP and the pretrained/downstream weights is released! The pre-training code is being cleaned up now. 55 | 56 | 57 | 58 | ## Installation 59 | 60 | Please find installation instructions for PyTorch in [INSTALL.md](INSTALL.md). 61 | 62 | 63 | ## Dataset Preparation 64 | 65 | You may follow the instructions in [DATASET.md](DATASET.md) to prepare the datasets. 66 | Since dataset preparation is very time consuming, we provide detailed guidance and our prepared corpus. 67 | 68 | 69 | ## Pretrained & Fine-tuned Models 70 | ### 1. Pre-trained Model 71 | 72 | | Method | Vision Encoder | #Images | Dataset | Pretrained Weights | Training Logs | 73 | | :--- | :--- | :--- | :--- | :----: | :---: | 74 | | PTP-BLIP| ViT-B(DeiT) | 4M | CC3M+COCO+VG+SBU | [link](https://huggingface.co/sail/PTP/blob/main/Pretrain_concated_pred_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_pretrain.txt) | 75 | 76 | ### 2. Zero-shot & Fine-tuned Downstream Models 77 | 78 | 79 | #### 2.1 Captioning 80 | | Method | B@4 | CIDEr | Config | 81 | | :--- | :--- | :--- | ---: | 82 | | PTP-BLIP| 40.1 | 135.0 | configs/caption_coco.yaml | 83 | 84 | 85 | #### 2.2 Zero-shot Retrieval 86 | 87 | 92 | 93 | ##### 2.2.2 Flickr30K 94 | 95 | | Method | I2T@1 | T2I@1 | Model Weight | Training Logs | Config | 96 | | :--- | :--- | :--- | :--- | :--- | :---: | 97 | | PTP-BLIP| 86.4 | 67.0 | [link](https://huggingface.co/sail/PTP/blob/main/zero_shot_coco_checkpoint_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_ptp_flickr30k_zero_shot.txt) | configs/retrieval_flickr.yaml | 98 | 99 | 100 | #### 2.3 Retrieval (Fine-tune) 101 | 102 | Tip: Please use as large a batch size as possible; we experimentally find that a larger batch size leads to better results for this task. Due to memory limitations, we use batch size 24 rather than the 28 used in the original implementation.
103 | 104 | 105 | ##### 2.3.1 COCO 106 | | Method |I2T@1 | T2I@1 | Config | 107 | | :--- | :--- | :--- | :---: | 108 | | PTP-BLIP| 77.6 | 59.4 | configs/retrieval_coco.yaml | 109 | 110 | 111 | ##### 2.3.2 Flickr30K 112 | | Method |I2T@1 | T2I@1 | Model Weight | Training Logs | Config | 113 | | :--- | :--- | :--- | :--- | :--- | :---: | 114 | | PTP-BLIP| 96.1 | 84.2 | [link](https://huggingface.co/sail/PTP/blob/main/flickr30k_ft_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_ptp_flickr30k_ft.txt) | configs/retrieval_flickr.yaml | 115 | 116 | #### 2.4 VQA V2 117 | 118 | | Method | Test-dev|Test-std |Model Weight | Training Logs | Config | 119 | | :--- | :--- | :--- | :--- | :--- | :---: | 120 | | PTP-BLIP| 76.02 | 76.18 | [link](https://huggingface.co/sail/PTP/blob/main/vqa_ft_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_ptp_vqa_v2.txt) | configs/vqa.yaml | 121 | 122 | #### 2.5 NLVR 123 | 124 | | Method | Dev| Test-P | Model Weight | Training Logs | Config | 125 | | :--- | :--- | :--- | :--- | :--- | :---: | 126 | | PTP-BLIP| 80.45 | 80.70 | [link](https://huggingface.co/sail/PTP/blob/main/nlvr_ft_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_ptp_nlvr.txt) | configs/nlvr.yaml | 127 | 128 | 129 | ## Quick Start 130 | 131 | Follow the example in [GETTING_STARTED.md](GETTING_STARTED.md) to start playing with VLP models using PTP. 132 | 133 | ## Transfer To Other Architectures 134 | 135 | PTP can be easily transferred to other architectures without much effort. 136 | Specifically, change your base code with the following two steps: 137 | 138 | - Download or generate a corpus in the same format as ours. 139 | - Modify the dataset.py accordingly. 140 | 141 | Then train the model with the original objectives. 142 | 143 | ## Acknowledgement 144 | This work is mainly based on [BLIP](https://github.com/salesforce/BLIP) and [ViLT](https://github.com/dandelin/ViLT); thanks to the authors for these good baselines. 145 | We also refer to [OSCAR](https://github.com/microsoft/Oscar) for the ablation study and dataset preparation. 146 | 147 | ## License 148 | PTP is released under the Apache 2.0 license. 149 | 150 | ## Contact 151 | 152 | Email: awinyimgprocess at gmail dot com 153 | 154 | If you have any questions, please email me or open a new issue. 155 | 156 | ## Citation 157 | If you find our work helpful, please use the following BibTeX entry for citation.
158 | 159 | ``` 160 | @article{wang2022ptp, 161 | title={Position-guided Text Prompt for Vision Language Pre-training}, 162 | author={Wang, Alex Jinpeng and Zhou, Pan and Shou, Mike Zheng and Yan, Shui Cheng}, 163 | journal={https://arxiv.org/abs/2212.09737}, 164 | year={2022} 165 | } 166 | ``` -------------------------------------------------------------------------------- /src/blip_src/pretrain_concated_pred_refined.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import ruamel.yaml as yaml 4 | import numpy as np 5 | import random 6 | import time 7 | import datetime 8 | import json 9 | from pathlib import Path 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.backends.cudnn as cudnn 15 | import torch.distributed as dist 16 | from torch.utils.data import DataLoader 17 | 18 | from models.blip_pretrain import blip_pretrain 19 | import utils 20 | from utils import warmup_lr_schedule, step_lr_schedule 21 | from data.init_data_concated_pred_refined import create_dataset, create_sampler, create_loader 22 | 23 | 24 | def train(model, data_loader, optimizer, epoch, device, config): 25 | # train 26 | model.train() 27 | 28 | metric_logger = utils.MetricLogger(delimiter=" ") 29 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 30 | metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 31 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 32 | metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 33 | 34 | header = 'Train Epoch: [{}]'.format(epoch) 35 | print_freq = 200 36 | 37 | if config['laion_path']: 38 | data_loader.dataset.reload_laion(epoch) 39 | 40 | data_loader.sampler.set_epoch(epoch) 41 | 42 | for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 43 | 44 | if epoch==0: 45 | warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr']) 46 | 47 | optimizer.zero_grad() 48 | 49 | image = image.to(device,non_blocking=True) 50 | 51 | # ramp up alpha in the first 2 epochs 52 | alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader))) 53 | 54 | loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha) 55 | loss = loss_ita + loss_itm + loss_lm 56 | 57 | loss.backward() 58 | optimizer.step() 59 | 60 | metric_logger.update(loss_ita=loss_ita.item()) 61 | metric_logger.update(loss_itm=loss_itm.item()) 62 | metric_logger.update(loss_lm=loss_lm.item()) 63 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 64 | 65 | 66 | # gather the stats from all processes 67 | metric_logger.synchronize_between_processes() 68 | print("Averaged stats:", metric_logger.global_avg()) 69 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 70 | 71 | 72 | def main(args, config): 73 | utils.init_distributed_mode(args) 74 | 75 | device = torch.device(args.gpu) 76 | 77 | # fix the seed for reproducibility 78 | seed = args.seed + utils.get_rank() 79 | torch.manual_seed(seed) 80 | np.random.seed(seed) 81 | random.seed(seed) 82 | cudnn.benchmark = True 83 | cudnn.deterministic = True 84 | #### Dataset #### 85 | print("Creating dataset") 86 | datasets = [create_dataset('pretrain', config, min_scale=0.2)] 87 | print('number of training samples: %d'%len(datasets[0])) 88 | 89 | num_tasks = 
utils.get_world_size() 90 | global_rank = utils.get_rank() 91 | samplers = create_sampler(datasets, [True], num_tasks, global_rank) 92 | 93 | data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[args.num_workers], is_trains=[True], collate_fns=[None])[0] 94 | 95 | print("="*50) 96 | print("time now is: ") 97 | print(time.strftime('%Y/%m/%d %H:%M:%S')) 98 | print("="*50) 99 | #### Model #### 100 | print("Creating model") 101 | model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], 102 | vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size']) 103 | 104 | # model = model.to(device) 105 | model = model.cuda() 106 | 107 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 108 | 109 | start_epoch = 0 110 | if args.checkpoint: 111 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 112 | state_dict = checkpoint['model'] 113 | model.load_state_dict(state_dict) 114 | 115 | optimizer.load_state_dict(checkpoint['optimizer']) 116 | start_epoch = checkpoint['epoch']+1 117 | print('resume checkpoint from %s'%args.checkpoint) 118 | 119 | model_without_ddp = model 120 | if args.distributed: 121 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 122 | model_without_ddp = model.module 123 | 124 | print("Start training") 125 | start_time = time.time() 126 | for epoch in range(start_epoch, config['max_epoch']): 127 | 128 | step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate']) 129 | 130 | train_stats = train(model, data_loader, optimizer, epoch, device, config) 131 | if utils.is_main_process(): 132 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 133 | 'epoch': epoch, 134 | } 135 | save_obj = { 136 | 'model': model_without_ddp.state_dict(), 137 | 'optimizer': optimizer.state_dict(), 138 | 'config': config, 139 | 'epoch': epoch, 140 | } 141 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) 142 | 143 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 144 | f.write(json.dumps(log_stats) + "\n") 145 | 146 | dist.barrier() 147 | 148 | total_time = time.time() - start_time 149 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 150 | print('Training time {}'.format(total_time_str)) 151 | 152 | 153 | if __name__ == '__main__': 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument('--config', default='./configs/pretrain.yaml') 156 | parser.add_argument('--output_dir', default='output/Pretrain') 157 | parser.add_argument('--checkpoint', default='') 158 | parser.add_argument('--evaluate', action='store_true') 159 | parser.add_argument('--device', default='cuda') 160 | parser.add_argument('--seed', default=0, type=int) 161 | parser.add_argument("--num_workers", default=12, type=int, help="""Number of data loading workers per GPU.""", ) 162 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 163 | parser.add_argument("--rank", default=0, type=int, help="""rank for distrbuted training.""") 164 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 165 | parser.add_argument('--distributed', default=True, type=bool) 166 | parser.add_argument("--local_rank", default=0, type=int, help="""local rank for distrbuted training.""",) 167 | args = parser.parse_args() 168 | 169 | config = 
yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 170 | 171 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 172 | 173 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 174 | 175 | # print("job beginning!") 176 | 177 | main(args, config) -------------------------------------------------------------------------------- /src/blip_src/pretrain_concated_pred_tsv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import ruamel.yaml as yaml 4 | import numpy as np 5 | import random 6 | import time 7 | import datetime 8 | import json 9 | from pathlib import Path 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.backends.cudnn as cudnn 15 | import torch.distributed as dist 16 | from torch.utils.data import DataLoader 17 | 18 | from models.blip_pretrain import blip_pretrain 19 | import utils 20 | from utils import warmup_lr_schedule, step_lr_schedule 21 | # from data import create_dataset, create_sampler, create_loader 22 | from data.init_data_concated_pred_tsv import create_dataset, create_sampler, create_loader 23 | 24 | 25 | def train(model, data_loader, optimizer, epoch, device, config): 26 | # train 27 | model.train() 28 | 29 | metric_logger = utils.MetricLogger(delimiter=" ") 30 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 31 | metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 32 | metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 33 | metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 34 | 35 | header = 'Train Epoch: [{}]'.format(epoch) 36 | print_freq = 200 37 | 38 | if config['laion_path']: 39 | data_loader.dataset.reload_laion(epoch) 40 | 41 | data_loader.sampler.set_epoch(epoch) 42 | 43 | for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 44 | 45 | if epoch==0: 46 | warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr']) 47 | 48 | optimizer.zero_grad() 49 | 50 | image = image.to(device,non_blocking=True) 51 | 52 | # ramp up alpha in the first 2 epochs 53 | alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader))) 54 | 55 | loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha) 56 | loss = loss_ita + loss_itm + loss_lm 57 | 58 | loss.backward() 59 | optimizer.step() 60 | 61 | metric_logger.update(loss_ita=loss_ita.item()) 62 | metric_logger.update(loss_itm=loss_itm.item()) 63 | metric_logger.update(loss_lm=loss_lm.item()) 64 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 65 | 66 | 67 | # gather the stats from all processes 68 | metric_logger.synchronize_between_processes() 69 | print("Averaged stats:", metric_logger.global_avg()) 70 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 71 | 72 | 73 | def main(args, config): 74 | utils.init_distributed_mode(args) 75 | 76 | device = torch.device(args.gpu) 77 | 78 | # fix the seed for reproducibility 79 | seed = args.seed + utils.get_rank() 80 | torch.manual_seed(seed) 81 | np.random.seed(seed) 82 | random.seed(seed) 83 | cudnn.benchmark = True 84 | cudnn.deterministic = True 85 | #### Dataset #### 86 | print("Creating dataset") 87 | datasets = [create_dataset('pretrain', config, min_scale=0.2)] 88 | print('number of training samples: 
%d'%len(datasets[0])) 89 | 90 | num_tasks = utils.get_world_size() 91 | global_rank = utils.get_rank() 92 | samplers = create_sampler(datasets, [True], num_tasks, global_rank) 93 | 94 | data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[args.num_workers], is_trains=[True], collate_fns=[None])[0] 95 | 96 | print("="*50) 97 | print("time now is: ") 98 | print(time.strftime('%Y/%m/%d %H:%M:%S')) 99 | print("="*50) 100 | #### Model #### 101 | print("Creating model") 102 | model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], 103 | vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size']) 104 | 105 | # model = model.to(device) 106 | model = model.cuda() 107 | 108 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 109 | 110 | start_epoch = 0 111 | if args.checkpoint: 112 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 113 | state_dict = checkpoint['model'] 114 | model.load_state_dict(state_dict) 115 | 116 | optimizer.load_state_dict(checkpoint['optimizer']) 117 | start_epoch = checkpoint['epoch']+1 118 | print('resume checkpoint from %s'%args.checkpoint) 119 | 120 | model_without_ddp = model 121 | if args.distributed: 122 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 123 | model_without_ddp = model.module 124 | 125 | print("Start training") 126 | start_time = time.time() 127 | for epoch in range(start_epoch, config['max_epoch']): 128 | 129 | step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate']) 130 | 131 | train_stats = train(model, data_loader, optimizer, epoch, device, config) 132 | if utils.is_main_process(): 133 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 134 | 'epoch': epoch, 135 | } 136 | save_obj = { 137 | 'model': model_without_ddp.state_dict(), 138 | 'optimizer': optimizer.state_dict(), 139 | 'config': config, 140 | 'epoch': epoch, 141 | } 142 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) 143 | 144 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 145 | f.write(json.dumps(log_stats) + "\n") 146 | 147 | dist.barrier() 148 | 149 | total_time = time.time() - start_time 150 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 151 | print('Training time {}'.format(total_time_str)) 152 | 153 | 154 | if __name__ == '__main__': 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument('--config', default='./configs/pretrain.yaml') 157 | parser.add_argument('--output_dir', default='output/Pretrain') 158 | parser.add_argument('--checkpoint', default='') 159 | parser.add_argument('--evaluate', action='store_true') 160 | parser.add_argument('--device', default='cuda') 161 | parser.add_argument('--seed', default=0, type=int) 162 | parser.add_argument("--num_workers", default=10, type=int, help="""Number of data loading workers per GPU.""", ) 163 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 164 | parser.add_argument("--rank", default=0, type=int, help="""rank for distrbuted training.""") 165 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 166 | parser.add_argument('--distributed', default=True, type=bool) 167 | parser.add_argument("--local_rank", default=0, type=int, help="""local rank for distrbuted training.""",) 168 | 
args = parser.parse_args() 169 | 170 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 171 | 172 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 173 | 174 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 175 | 176 | # print("job beginning!") 177 | 178 | main(args, config) -------------------------------------------------------------------------------- /src/blip_src/train_vqa.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel.yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.utils.data import DataLoader 22 | import torch.backends.cudnn as cudnn 23 | import torch.distributed as dist 24 | 25 | from models.blip_vqa import blip_vqa 26 | import utils 27 | from utils import cosine_lr_schedule 28 | from data import create_dataset, create_sampler, create_loader 29 | from data.vqa_dataset import vqa_collate_fn 30 | from data.utils import save_result 31 | 32 | 33 | def train(model, data_loader, optimizer, epoch, device): 34 | # train 35 | model.train() 36 | 37 | metric_logger = utils.MetricLogger(delimiter=" ") 38 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 39 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) 40 | 41 | header = 'Train Epoch: [{}]'.format(epoch) 42 | print_freq = 200 43 | 44 | for i,(image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 45 | image, weights = image.to(device,non_blocking=True), weights.to(device,non_blocking=True) 46 | 47 | loss = model(image, question, answer, train=True, n=n, weights=weights) 48 | 49 | optimizer.zero_grad() 50 | loss.backward() 51 | optimizer.step() 52 | 53 | metric_logger.update(loss=loss.item()) 54 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 55 | 56 | # gather the stats from all processes 57 | metric_logger.synchronize_between_processes() 58 | print("Averaged stats:", metric_logger.global_avg()) 59 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 60 | 61 | 62 | @torch.no_grad() 63 | def evaluation(model, data_loader, device, config) : 64 | # test 65 | model.eval() 66 | 67 | metric_logger = utils.MetricLogger(delimiter=" ") 68 | header = 'Generate VQA test result:' 69 | print_freq = 200 70 | 71 | result = [] 72 | 73 | if config['inference']=='rank': 74 | answer_list = data_loader.dataset.answer_list 75 | answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device) 76 | answer_candidates.input_ids[:,0] = model.tokenizer.bos_token_id 77 | 78 | for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 79 | image = image.to(device,non_blocking=True) 80 | 81 | if config['inference']=='generate': 82 | answers = model(image, question, train=False, inference='generate') 83 | 84 | for answer, ques_id in zip(answers, question_id): 85 | ques_id = int(ques_id.item()) 86 | 
result.append({"question_id":ques_id, "answer":answer}) 87 | 88 | elif config['inference']=='rank': 89 | answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test']) 90 | 91 | for ques_id, answer_id in zip(question_id, answer_ids): 92 | result.append({"question_id":int(ques_id.item()), "answer":answer_list[answer_id]}) 93 | 94 | return result 95 | 96 | 97 | def main(args, config): 98 | utils.init_distributed_mode(args) 99 | 100 | device = torch.device(args.device) 101 | 102 | # fix the seed for reproducibility 103 | seed = args.seed + utils.get_rank() 104 | torch.manual_seed(seed) 105 | np.random.seed(seed) 106 | random.seed(seed) 107 | cudnn.benchmark = True 108 | 109 | #### Dataset #### 110 | print("Creating vqa datasets") 111 | datasets = create_dataset('vqa', config) 112 | 113 | if args.distributed: 114 | num_tasks = utils.get_world_size() 115 | global_rank = utils.get_rank() 116 | samplers = create_sampler(datasets, [True, False], num_tasks, global_rank) 117 | else: 118 | samplers = [None, None] 119 | 120 | train_loader, test_loader = create_loader(datasets,samplers, 121 | batch_size=[config['batch_size_train'],config['batch_size_test']], 122 | num_workers=[4,4],is_trains=[True, False], 123 | collate_fns=[vqa_collate_fn,None]) 124 | #### Model #### 125 | print("Creating model") 126 | model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'], 127 | vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer']) 128 | 129 | model = model.to(device) 130 | 131 | model_without_ddp = model 132 | if args.distributed: 133 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 134 | model_without_ddp = model.module 135 | 136 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 137 | 138 | best = 0 139 | best_epoch = 0 140 | 141 | print("Start training") 142 | start_time = time.time() 143 | for epoch in range(0, config['max_epoch']): 144 | if not args.evaluate: 145 | if args.distributed: 146 | train_loader.sampler.set_epoch(epoch) 147 | 148 | cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) 149 | 150 | train_stats = train(model, train_loader, optimizer, epoch, device) 151 | 152 | else: 153 | break 154 | 155 | if utils.is_main_process(): 156 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 157 | 'epoch': epoch, 158 | } 159 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 160 | f.write(json.dumps(log_stats) + "\n") 161 | 162 | save_obj = { 163 | 'model': model_without_ddp.state_dict(), 164 | 'optimizer': optimizer.state_dict(), 165 | 'config': config, 166 | 'epoch': epoch, 167 | } 168 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) 169 | 170 | dist.barrier() 171 | 172 | vqa_result = evaluation(model_without_ddp, test_loader, device, config) 173 | result_file = save_result(vqa_result, args.result_dir, 'vqa_result') 174 | 175 | total_time = time.time() - start_time 176 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 177 | print('Training time {}'.format(total_time_str)) 178 | 179 | 180 | 181 | if __name__ == '__main__': 182 | parser = argparse.ArgumentParser() 183 | parser.add_argument('--config', default='./configs/vqa.yaml') 184 | parser.add_argument('--output_dir', default='output/VQA') 185 | parser.add_argument('--evaluate', action='store_true') 186 | 
parser.add_argument('--device', default='cuda') 187 | parser.add_argument('--seed', default=42, type=int) 188 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 189 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 190 | parser.add_argument('--distributed', default=True, type=bool) 191 | parser.add_argument( 192 | "--local_rank", 193 | default=0, 194 | type=int, 195 | help="""local rank for distrbuted training.""", 196 | ) 197 | args = parser.parse_args() 198 | 199 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 200 | 201 | args.result_dir = os.path.join(args.output_dir, 'result') 202 | 203 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 204 | Path(args.result_dir).mkdir(parents=True, exist_ok=True) 205 | 206 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 207 | 208 | main(args, config) -------------------------------------------------------------------------------- /src/blip_src/train_nlvr.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | 9 | import argparse 10 | import os 11 | import ruamel.yaml as yaml 12 | import numpy as np 13 | import random 14 | import time 15 | import datetime 16 | import json 17 | from pathlib import Path 18 | import json 19 | import pickle 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | from torch.utils.data import DataLoader 25 | import torch.backends.cudnn as cudnn 26 | import torch.distributed as dist 27 | 28 | from models.blip_nlvr import blip_nlvr 29 | 30 | import utils 31 | from utils import cosine_lr_schedule, warmup_lr_schedule 32 | from data import create_dataset, create_sampler, create_loader 33 | 34 | def train(model, data_loader, optimizer, epoch, device, config): 35 | # train 36 | model.train() 37 | 38 | metric_logger = utils.MetricLogger(delimiter=" ") 39 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 40 | metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 41 | 42 | header = 'Train Epoch: [{}]'.format(epoch) 43 | print_freq = 200 44 | step_size = 10 45 | 46 | for i,(image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 47 | 48 | images = torch.cat([image0, image1], dim=0) 49 | images, targets = images.to(device), targets.to(device) 50 | 51 | loss = model(images, text, targets=targets, train=True) 52 | 53 | optimizer.zero_grad() 54 | loss.backward() 55 | optimizer.step() 56 | 57 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 58 | metric_logger.update(loss=loss.item()) 59 | 60 | # gather the stats from all processes 61 | metric_logger.synchronize_between_processes() 62 | print("Averaged stats:", metric_logger.global_avg()) 63 | return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 64 | 65 | 66 | @torch.no_grad() 67 | def evaluate(model, data_loader, device, config): 68 | # test 69 | model.eval() 70 | 71 | metric_logger = utils.MetricLogger(delimiter=" ") 72 | 73 | header = 'Evaluation:' 74 | print_freq = 200 75 | 76 | for image0, image1, text, targets in 
metric_logger.log_every(data_loader, print_freq, header): 77 | images = torch.cat([image0, image1], dim=0) 78 | images, targets = images.to(device), targets.to(device) 79 | 80 | prediction = model(images, text, targets=targets, train=False) 81 | 82 | _, pred_class = prediction.max(1) 83 | accuracy = (targets==pred_class).sum() / targets.size(0) 84 | 85 | metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0)) 86 | 87 | # gather the stats from all processes 88 | metric_logger.synchronize_between_processes() 89 | 90 | print("Averaged stats:", metric_logger.global_avg()) 91 | return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 92 | 93 | 94 | 95 | def main(args, config): 96 | utils.init_distributed_mode(args) 97 | 98 | device = torch.device(args.device) 99 | 100 | # fix the seed for reproducibility 101 | seed = args.seed + utils.get_rank() 102 | torch.manual_seed(seed) 103 | np.random.seed(seed) 104 | random.seed(seed) 105 | cudnn.benchmark = True 106 | 107 | #### Dataset #### 108 | print("Creating dataset") 109 | datasets = create_dataset('nlvr', config) 110 | 111 | if args.distributed: 112 | num_tasks = utils.get_world_size() 113 | global_rank = utils.get_rank() 114 | samplers = create_sampler(datasets, [True,False,False], num_tasks, global_rank) 115 | else: 116 | samplers = [None, None, None] 117 | 118 | batch_size=[config['batch_size_train'],config['batch_size_test'],config['batch_size_test']] 119 | train_loader, val_loader, test_loader = create_loader(datasets,samplers,batch_size=batch_size, 120 | num_workers=[4,4,4],is_trains=[True,False,False], 121 | collate_fns=[None,None,None]) 122 | 123 | #### Model #### 124 | print("Creating model") 125 | model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'], 126 | vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer']) 127 | 128 | model = model.to(device) 129 | 130 | model_without_ddp = model 131 | if args.distributed: 132 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 133 | model_without_ddp = model.module 134 | 135 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) 136 | 137 | print("Start training") 138 | start_time = time.time() 139 | best = 0 140 | best_epoch = 0 141 | 142 | for epoch in range(0, config['max_epoch']): 143 | if not args.evaluate: 144 | if args.distributed: 145 | train_loader.sampler.set_epoch(epoch) 146 | 147 | cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) 148 | 149 | train_stats = train(model, train_loader, optimizer, epoch, device, config) 150 | 151 | val_stats = evaluate(model, val_loader, device, config) 152 | test_stats = evaluate(model, test_loader, device, config) 153 | 154 | if utils.is_main_process(): 155 | if args.evaluate: 156 | log_stats = {**{f'val_{k}': v for k, v in val_stats.items()}, 157 | **{f'test_{k}': v for k, v in test_stats.items()}, 158 | } 159 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 160 | f.write(json.dumps(log_stats) + "\n") 161 | 162 | else: 163 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 164 | **{f'val_{k}': v for k, v in val_stats.items()}, 165 | **{f'test_{k}': v for k, v in test_stats.items()}, 166 | 'epoch': epoch, 167 | } 168 | 169 | if float(val_stats['acc'])>best: 170 | save_obj = { 171 | 'model': model_without_ddp.state_dict(), 172 | 'optimizer': optimizer.state_dict(), 173 
| 'config': config, 174 | 'epoch': epoch, 175 | } 176 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) 177 | best = float(val_stats['acc']) 178 | best_epoch = epoch 179 | 180 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 181 | f.write(json.dumps(log_stats) + "\n") 182 | if args.evaluate: 183 | break 184 | 185 | dist.barrier() 186 | 187 | if utils.is_main_process(): 188 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 189 | f.write("best epoch: %d"%best_epoch) 190 | 191 | total_time = time.time() - start_time 192 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 193 | print('Training time {}'.format(total_time_str)) 194 | 195 | 196 | if __name__ == '__main__': 197 | parser = argparse.ArgumentParser() 198 | parser.add_argument('--config', default='./configs/nlvr.yaml') 199 | parser.add_argument('--output_dir', default='output/NLVR') 200 | parser.add_argument('--evaluate', action='store_true') 201 | parser.add_argument('--device', default='cuda') 202 | parser.add_argument('--seed', default=42, type=int) 203 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 204 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 205 | parser.add_argument('--distributed', default=True, type=bool) 206 | parser.add_argument( 207 | "--local_rank", 208 | default=0, 209 | type=int, 210 | help="""local rank for distrbuted training.""", 211 | ) 212 | args = parser.parse_args() 213 | 214 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 215 | 216 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 217 | 218 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 219 | 220 | main(args, config) -------------------------------------------------------------------------------- /src/blip_src/models/blip_vqa.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | from models.med import BertConfig, BertModel, BertLMHeadModel 9 | from models.blip import create_vit, init_tokenizer, load_checkpoint 10 | 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | from transformers import BertTokenizer 15 | import numpy as np 16 | 17 | class BLIP_VQA(nn.Module): 18 | def __init__(self, 19 | med_config = 'configs/med_config.json', 20 | image_size = 480, 21 | vit = 'base', 22 | vit_grad_ckpt = False, 23 | vit_ckpt_layer = 0, 24 | ): 25 | """ 26 | Args: 27 | med_config (str): path for the mixture of encoder-decoder model's configuration file 28 | image_size (int): input image size 29 | vit (str): model size of vision transformer 30 | """ 31 | super().__init__() 32 | 33 | self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 34 | self.tokenizer = init_tokenizer() 35 | 36 | encoder_config = BertConfig.from_json_file(med_config) 37 | encoder_config.encoder_width = vision_width 38 | self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) 39 | 40 | decoder_config = BertConfig.from_json_file(med_config) 41 | self.text_decoder = BertLMHeadModel(config=decoder_config) 42 | 43 | 44 | def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128): 45 | 46 | image_embeds = self.visual_encoder(image) 47 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 48 | 49 | question = self.tokenizer(question, padding='longest', truncation=True, max_length=35, 50 | return_tensors="pt").to(image.device) 51 | question.input_ids[:,0] = self.tokenizer.enc_token_id 52 | 53 | if train: 54 | ''' 55 | n: number of answers for each question 56 | weights: weight for each answer 57 | ''' 58 | answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device) 59 | answer.input_ids[:,0] = self.tokenizer.bos_token_id 60 | answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100) 61 | 62 | question_output = self.text_encoder(question.input_ids, 63 | attention_mask = question.attention_mask, 64 | encoder_hidden_states = image_embeds, 65 | encoder_attention_mask = image_atts, 66 | return_dict = True) 67 | 68 | question_states = [] 69 | question_atts = [] 70 | for b, n in enumerate(n): 71 | question_states += [question_output.last_hidden_state[b]]*n 72 | question_atts += [question.attention_mask[b]]*n 73 | question_states = torch.stack(question_states,0) 74 | question_atts = torch.stack(question_atts,0) 75 | 76 | answer_output = self.text_decoder(answer.input_ids, 77 | attention_mask = answer.attention_mask, 78 | encoder_hidden_states = question_states, 79 | encoder_attention_mask = question_atts, 80 | labels = answer_targets, 81 | return_dict = True, 82 | reduction = 'none', 83 | ) 84 | 85 | loss = weights * answer_output.loss 86 | loss = loss.sum()/image.size(0) 87 | 88 | return loss 89 | 90 | 91 | else: 92 | question_output = self.text_encoder(question.input_ids, 93 | attention_mask = question.attention_mask, 94 | encoder_hidden_states = image_embeds, 95 | encoder_attention_mask = image_atts, 96 | return_dict = True) 97 | 98 | if inference=='generate': 99 | num_beams = 3 100 | question_states = 
question_output.last_hidden_state.repeat_interleave(num_beams,dim=0) 101 | question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device) 102 | model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts} 103 | 104 | bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device) 105 | 106 | outputs = self.text_decoder.generate(input_ids=bos_ids, 107 | max_length=10, 108 | min_length=1, 109 | num_beams=num_beams, 110 | eos_token_id=self.tokenizer.sep_token_id, 111 | pad_token_id=self.tokenizer.pad_token_id, 112 | **model_kwargs) 113 | 114 | answers = [] 115 | for output in outputs: 116 | answer = self.tokenizer.decode(output, skip_special_tokens=True) 117 | answers.append(answer) 118 | return answers 119 | 120 | elif inference=='rank': 121 | max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask, 122 | answer.input_ids, answer.attention_mask, k_test) 123 | return max_ids 124 | 125 | 126 | 127 | def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k): 128 | 129 | num_ques = question_states.size(0) 130 | start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token 131 | 132 | start_output = self.text_decoder(start_ids, 133 | encoder_hidden_states = question_states, 134 | encoder_attention_mask = question_atts, 135 | return_dict = True, 136 | reduction = 'none') 137 | logits = start_output.logits[:,0,:] # first token's logit 138 | 139 | # topk_probs: top-k probability 140 | # topk_ids: [num_question, k] 141 | answer_first_token = answer_ids[:,1] 142 | prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token) 143 | topk_probs, topk_ids = prob_first_token.topk(k,dim=1) 144 | 145 | # answer input: [num_question*k, answer_len] 146 | input_ids = [] 147 | input_atts = [] 148 | for b, topk_id in enumerate(topk_ids): 149 | input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) 150 | input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) 151 | input_ids = torch.cat(input_ids,dim=0) 152 | input_atts = torch.cat(input_atts,dim=0) 153 | 154 | targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100) 155 | 156 | # repeat encoder's output for top-k answers 157 | question_states = tile(question_states, 0, k) 158 | question_atts = tile(question_atts, 0, k) 159 | 160 | output = self.text_decoder(input_ids, 161 | attention_mask = input_atts, 162 | encoder_hidden_states = question_states, 163 | encoder_attention_mask = question_atts, 164 | labels = targets_ids, 165 | return_dict = True, 166 | reduction = 'none') 167 | 168 | log_probs_sum = -output.loss 169 | log_probs_sum = log_probs_sum.view(num_ques,k) 170 | 171 | max_topk_ids = log_probs_sum.argmax(dim=1) 172 | max_ids = topk_ids[max_topk_ids>=0,max_topk_ids] 173 | 174 | return max_ids 175 | 176 | 177 | def blip_vqa(pretrained='',**kwargs): 178 | model = BLIP_VQA(**kwargs) 179 | if pretrained: 180 | model,msg = load_checkpoint(model,pretrained) 181 | # assert(len(msg.missing_keys)==0) 182 | return model 183 | 184 | 185 | def tile(x, dim, n_tile): 186 | init_dim = x.size(dim) 187 | repeat_idx = [1] * x.dim() 188 | repeat_idx[dim] = n_tile 189 | x = x.repeat(*(repeat_idx)) 190 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) 191 | return torch.index_select(x, dim, order_index.to(x.device)) 192 | 193 | 
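A minimal usage sketch for the rank-mode path above (`forward(..., inference='rank')` -> `rank_answer` -> `tile`), which is normally driven from `evaluation()` in `train_vqa.py`. This is an illustrative addition rather than a script from this repository, and it relies on stated assumptions: it is run from `src/blip_src/` so that `configs/med_config.json` resolves, the answer list is a toy placeholder, a random tensor stands in for a preprocessed 480x480 image, and `pretrained=''` skips checkpoint loading (so the predicted answer is meaningless, but the full ranking path is exercised).

```python
import torch
from models.blip_vqa import blip_vqa

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# pretrained='' skips load_checkpoint; pass a real .pth path to load weights.
model = blip_vqa(pretrained='', image_size=480, vit='base').to(device)
model.eval()

# Candidate answers are tokenized once and reused for every question,
# mirroring the rank-mode branch of evaluation() in train_vqa.py.
answer_list = ['yes', 'no', '2', 'blue']            # toy candidates (assumption)
answer_candidates = model.tokenizer(answer_list, padding='longest',
                                    return_tensors='pt').to(device)
answer_candidates.input_ids[:, 0] = model.tokenizer.bos_token_id

image = torch.randn(1, 3, 480, 480, device=device)  # stand-in for a transformed image
question = ['what color is the sky?']

with torch.no_grad():
    # Returns, for each question, the index of the highest-scoring candidate
    # in answer_list (first-token top-k filtering, then full decoder scoring).
    answer_ids = model(image, question, answer_candidates,
                       train=False, inference='rank', k_test=len(answer_list))

print(answer_list[int(answer_ids[0])])
```

In the real evaluation loop, `answer_list` comes from the VQA dataset and `k_test` is taken from the config (128 by default), so the first-token softmax prunes the candidate set before the decoder rescores only the top-k answers per question.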
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object forms. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/blip_src/models/blip.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | 11 | from models.vit import VisionTransformer, interpolate_pos_embed 12 | from models.med import BertConfig, BertModel, BertLMHeadModel 13 | from transformers import BertTokenizer 14 | 15 | import torch 16 | from torch import nn 17 | import torch.nn.functional as F 18 | 19 | import os 20 | from urllib.parse import urlparse 21 | # from timm.models.hub import download_cached_file 22 | 23 | class BLIP_Base(nn.Module): 24 | def __init__(self, 25 | med_config = 'configs/med_config.json', 26 | image_size = 224, 27 | vit = 'base', 28 | vit_grad_ckpt = False, 29 | vit_ckpt_layer = 0, 30 | ): 31 | """ 32 | Args: 33 | med_config (str): path for the mixture of encoder-decoder model's configuration file 34 | image_size (int): input image size 35 | vit (str): model size of vision transformer 36 | """ 37 | super().__init__() 38 | 39 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 40 | self.tokenizer = init_tokenizer() 41 | med_config = BertConfig.from_json_file(med_config) 42 | med_config.encoder_width = vision_width 43 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 44 | 45 | 46 | def forward(self, image, caption, mode): 47 | 48 | assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal" 49 | text = self.tokenizer(caption, return_tensors="pt").to(image.device) 50 | 51 | if mode=='image': 52 | # return image features 53 | image_embeds = self.visual_encoder(image) 54 | return image_embeds 55 | 56 | elif mode=='text': 57 | # return text features 58 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 59 | return_dict = True, mode = 'text') 60 | return text_output.last_hidden_state 61 | 62 | elif mode=='multimodal': 63 | # return multimodel features 64 | image_embeds = 
self.visual_encoder(image) 65 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 66 | 67 | text.input_ids[:,0] = self.tokenizer.enc_token_id 68 | output = self.text_encoder(text.input_ids, 69 | attention_mask = text.attention_mask, 70 | encoder_hidden_states = image_embeds, 71 | encoder_attention_mask = image_atts, 72 | return_dict = True, 73 | ) 74 | return output.last_hidden_state 75 | 76 | 77 | 78 | class BLIP_Decoder(nn.Module): 79 | def __init__(self, 80 | med_config = 'configs/med_config.json', 81 | image_size = 384, 82 | vit = 'base', 83 | vit_grad_ckpt = False, 84 | vit_ckpt_layer = 0, 85 | prompt = 'a picture of ', 86 | ): 87 | """ 88 | Args: 89 | med_config (str): path for the mixture of encoder-decoder model's configuration file 90 | image_size (int): input image size 91 | vit (str): model size of vision transformer 92 | """ 93 | super().__init__() 94 | 95 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 96 | self.tokenizer = init_tokenizer() 97 | med_config = BertConfig.from_json_file(med_config) 98 | med_config.encoder_width = vision_width 99 | self.text_decoder = BertLMHeadModel(config=med_config) 100 | 101 | self.prompt = prompt 102 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1 103 | 104 | 105 | def forward(self, image, caption): 106 | 107 | image_embeds = self.visual_encoder(image) 108 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 109 | 110 | text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) 111 | 112 | text.input_ids[:,0] = self.tokenizer.bos_token_id 113 | 114 | decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100) 115 | decoder_targets[:,:self.prompt_length] = -100 116 | 117 | decoder_output = self.text_decoder(text.input_ids, 118 | attention_mask = text.attention_mask, 119 | encoder_hidden_states = image_embeds, 120 | encoder_attention_mask = image_atts, 121 | labels = decoder_targets, 122 | return_dict = True, 123 | ) 124 | loss_lm = decoder_output.loss 125 | 126 | return loss_lm 127 | 128 | def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): 129 | image_embeds = self.visual_encoder(image) 130 | 131 | if not sample: 132 | image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) 133 | 134 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 135 | model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} 136 | 137 | prompt = [self.prompt] * image.size(0) 138 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) 139 | input_ids[:,0] = self.tokenizer.bos_token_id 140 | input_ids = input_ids[:, :-1] 141 | 142 | if sample: 143 | #nucleus sampling 144 | outputs = self.text_decoder.generate(input_ids=input_ids, 145 | max_length=max_length, 146 | min_length=min_length, 147 | do_sample=True, 148 | top_p=top_p, 149 | num_return_sequences=1, 150 | eos_token_id=self.tokenizer.sep_token_id, 151 | pad_token_id=self.tokenizer.pad_token_id, 152 | repetition_penalty=1.1, 153 | **model_kwargs) 154 | else: 155 | #beam search 156 | outputs = self.text_decoder.generate(input_ids=input_ids, 157 | max_length=max_length, 158 | min_length=min_length, 159 | num_beams=num_beams, 160 | eos_token_id=self.tokenizer.sep_token_id, 161 | pad_token_id=self.tokenizer.pad_token_id, 
162 | repetition_penalty=repetition_penalty, 163 | **model_kwargs) 164 | 165 | captions = [] 166 | for output in outputs: 167 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 168 | captions.append(caption[len(self.prompt):]) 169 | return captions 170 | 171 | 172 | def blip_decoder(pretrained='',**kwargs): 173 | model = BLIP_Decoder(**kwargs) 174 | if pretrained: 175 | model,msg = load_checkpoint(model,pretrained) 176 | assert(len(msg.missing_keys)==0) 177 | return model 178 | 179 | def blip_feature_extractor(pretrained='',**kwargs): 180 | model = BLIP_Base(**kwargs) 181 | if pretrained: 182 | model,msg = load_checkpoint(model,pretrained) 183 | assert(len(msg.missing_keys)==0) 184 | return model 185 | 186 | def init_tokenizer(): 187 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 188 | tokenizer.add_special_tokens({'bos_token':'[DEC]'}) 189 | tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) 190 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 191 | return tokenizer 192 | 193 | 194 | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): 195 | 196 | assert vit in ['base', 'large'], "vit parameter must be base or large" 197 | if vit=='base': 198 | vision_width = 768 199 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, 200 | num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 201 | drop_path_rate=0 or drop_path_rate 202 | ) 203 | elif vit=='large': 204 | vision_width = 1024 205 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, 206 | num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 207 | drop_path_rate=0.1 or drop_path_rate 208 | ) 209 | return visual_encoder, vision_width 210 | 211 | def is_url(url_or_filename): 212 | parsed = urlparse(url_or_filename) 213 | return parsed.scheme in ("http", "https") 214 | 215 | def load_checkpoint(model,url_or_filename): 216 | print("load checkpoint for {}".format(url_or_filename)) 217 | if url_or_filename == "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth": 218 | url_or_filename = "pretrained_models/model_base_retrieval_coco.pth" 219 | if os.path.isfile(url_or_filename): 220 | checkpoint = torch.load(url_or_filename, map_location='cpu') 221 | else: 222 | raise RuntimeError('checkpoint url or path is invalid') 223 | # if is_url(url_or_filename): 224 | # cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 225 | # checkpoint = torch.load(cached_file, map_location='cpu') 226 | # elif os.path.isfile(url_or_filename): 227 | # checkpoint = torch.load(url_or_filename, map_location='cpu') 228 | # else: 229 | # raise RuntimeError('checkpoint url or path is invalid') 230 | 231 | state_dict = checkpoint['model'] 232 | 233 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 234 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): 235 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], 236 | model.visual_encoder_m) 237 | for key in model.state_dict().keys(): 238 | if key in state_dict.keys(): 239 | if state_dict[key].shape!=model.state_dict()[key].shape: 240 | del state_dict[key] 241 | 242 | msg = model.load_state_dict(state_dict,strict=False) 243 | 
print('load checkpoint from %s'%url_or_filename) 244 | return model,msg 245 | 246 | -------------------------------------------------------------------------------- /src/blip_src/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 3 | """Decay the learning rate""" 4 | lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr 5 | for param_group in optimizer.param_groups: 6 | param_group['lr'] = lr 7 | 8 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 9 | """Warmup the learning rate""" 10 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step) 11 | for param_group in optimizer.param_groups: 12 | param_group['lr'] = lr 13 | 14 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 15 | """Decay the learning rate""" 16 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 17 | for param_group in optimizer.param_groups: 18 | param_group['lr'] = lr 19 | 20 | import numpy as np 21 | import io 22 | import os 23 | import time 24 | from collections import defaultdict, deque 25 | import datetime 26 | 27 | import torch 28 | import torch.distributed as dist 29 | 30 | class SmoothedValue(object): 31 | """Track a series of values and provide access to smoothed values over a 32 | window or the global series average. 33 | """ 34 | 35 | def __init__(self, window_size=20, fmt=None): 36 | if fmt is None: 37 | fmt = "{median:.4f} ({global_avg:.4f})" 38 | self.deque = deque(maxlen=window_size) 39 | self.total = 0.0 40 | self.count = 0 41 | self.fmt = fmt 42 | 43 | def update(self, value, n=1): 44 | self.deque.append(value) 45 | self.count += n 46 | self.total += value * n 47 | 48 | def synchronize_between_processes(self): 49 | """ 50 | Warning: does not synchronize the deque! 
51 | """ 52 | if not is_dist_avail_and_initialized(): 53 | return 54 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 55 | dist.barrier() 56 | dist.all_reduce(t) 57 | t = t.tolist() 58 | self.count = int(t[0]) 59 | self.total = t[1] 60 | 61 | @property 62 | def median(self): 63 | d = torch.tensor(list(self.deque)) 64 | return d.median().item() 65 | 66 | @property 67 | def avg(self): 68 | d = torch.tensor(list(self.deque), dtype=torch.float32) 69 | return d.mean().item() 70 | 71 | @property 72 | def global_avg(self): 73 | return self.total / self.count 74 | 75 | @property 76 | def max(self): 77 | return max(self.deque) 78 | 79 | @property 80 | def value(self): 81 | return self.deque[-1] 82 | 83 | def __str__(self): 84 | return self.fmt.format( 85 | median=self.median, 86 | avg=self.avg, 87 | global_avg=self.global_avg, 88 | max=self.max, 89 | value=self.value) 90 | 91 | 92 | class MetricLogger(object): 93 | def __init__(self, delimiter="\t"): 94 | self.meters = defaultdict(SmoothedValue) 95 | self.delimiter = delimiter 96 | 97 | def update(self, **kwargs): 98 | for k, v in kwargs.items(): 99 | if isinstance(v, torch.Tensor): 100 | v = v.item() 101 | assert isinstance(v, (float, int)) 102 | self.meters[k].update(v) 103 | 104 | def __getattr__(self, attr): 105 | if attr in self.meters: 106 | return self.meters[attr] 107 | if attr in self.__dict__: 108 | return self.__dict__[attr] 109 | raise AttributeError("'{}' object has no attribute '{}'".format( 110 | type(self).__name__, attr)) 111 | 112 | def __str__(self): 113 | loss_str = [] 114 | for name, meter in self.meters.items(): 115 | loss_str.append( 116 | "{}: {}".format(name, str(meter)) 117 | ) 118 | return self.delimiter.join(loss_str) 119 | 120 | def global_avg(self): 121 | loss_str = [] 122 | for name, meter in self.meters.items(): 123 | loss_str.append( 124 | "{}: {:.4f}".format(name, meter.global_avg) 125 | ) 126 | return self.delimiter.join(loss_str) 127 | 128 | def synchronize_between_processes(self): 129 | for meter in self.meters.values(): 130 | meter.synchronize_between_processes() 131 | 132 | def add_meter(self, name, meter): 133 | self.meters[name] = meter 134 | 135 | def log_every(self, iterable, print_freq, header=None): 136 | i = 0 137 | if not header: 138 | header = '' 139 | start_time = time.time() 140 | end = time.time() 141 | iter_time = SmoothedValue(fmt='{avg:.4f}') 142 | data_time = SmoothedValue(fmt='{avg:.4f}') 143 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 144 | log_msg = [ 145 | header, 146 | '[{0' + space_fmt + '}/{1}]', 147 | 'eta: {eta}', 148 | '{meters}', 149 | 'time: {time}', 150 | 'data: {data}' 151 | ] 152 | if torch.cuda.is_available(): 153 | log_msg.append('max mem: {memory:.0f}') 154 | log_msg = self.delimiter.join(log_msg) 155 | MB = 1024.0 * 1024.0 156 | for obj in iterable: 157 | data_time.update(time.time() - end) 158 | yield obj 159 | iter_time.update(time.time() - end) 160 | if i % print_freq == 0 or i == len(iterable) - 1: 161 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 162 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 163 | if torch.cuda.is_available(): 164 | print(log_msg.format( 165 | i, len(iterable), eta=eta_string, 166 | meters=str(self), 167 | time=str(iter_time), data=str(data_time), 168 | memory=torch.cuda.max_memory_allocated() / MB)) 169 | else: 170 | print(log_msg.format( 171 | i, len(iterable), eta=eta_string, 172 | meters=str(self), 173 | time=str(iter_time), data=str(data_time))) 174 | i += 1 175 | end = 
time.time() 176 | total_time = time.time() - start_time 177 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 178 | print('{} Total time: {} ({:.4f} s / it)'.format( 179 | header, total_time_str, total_time / len(iterable))) 180 | 181 | 182 | class AttrDict(dict): 183 | def __init__(self, *args, **kwargs): 184 | super(AttrDict, self).__init__(*args, **kwargs) 185 | self.__dict__ = self 186 | 187 | 188 | def compute_acc(logits, label, reduction='mean'): 189 | ret = (torch.argmax(logits, dim=1) == label).float() 190 | if reduction == 'none': 191 | return ret.detach() 192 | elif reduction == 'mean': 193 | return ret.mean().item() 194 | 195 | def compute_n_params(model, return_str=True): 196 | tot = 0 197 | for p in model.parameters(): 198 | w = 1 199 | for x in p.shape: 200 | w *= x 201 | tot += w 202 | if return_str: 203 | if tot >= 1e6: 204 | return '{:.1f}M'.format(tot / 1e6) 205 | else: 206 | return '{:.1f}K'.format(tot / 1e3) 207 | else: 208 | return tot 209 | 210 | def setup_for_distributed(is_master): 211 | """ 212 | This function disables printing when not in master process 213 | """ 214 | import builtins as __builtin__ 215 | builtin_print = __builtin__.print 216 | 217 | def print(*args, **kwargs): 218 | force = kwargs.pop('force', False) 219 | if is_master or force: 220 | builtin_print(*args, **kwargs) 221 | 222 | __builtin__.print = print 223 | 224 | 225 | def is_dist_avail_and_initialized(): 226 | if not dist.is_available(): 227 | return False 228 | if not dist.is_initialized(): 229 | return False 230 | return True 231 | 232 | 233 | def get_world_size(): 234 | if not is_dist_avail_and_initialized(): 235 | return 1 236 | return dist.get_world_size() 237 | 238 | 239 | def get_rank(): 240 | if not is_dist_avail_and_initialized(): 241 | return 0 242 | return dist.get_rank() 243 | 244 | 245 | def is_main_process(): 246 | return get_rank() == 0 247 | 248 | 249 | def save_on_master(*args, **kwargs): 250 | if is_main_process(): 251 | torch.save(*args, **kwargs) 252 | 253 | 254 | # def init_distributed_mode(args): 255 | # """ 256 | # initialize the normal job 257 | # """ 258 | # # launched with torch.distributed.launch 259 | # if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 260 | # args.rank = int(os.environ["RANK"]) 261 | # args.world_size = int(os.environ["WORLD_SIZE"]) 262 | # args.gpu = int(os.environ.get("LOCAL_RANK", 0)) 263 | # print( 264 | # "args.rank", 265 | # args.rank, 266 | # "args.world_size", 267 | # args.world_size, 268 | # "args.gpu", 269 | # args.gpu, 270 | # ) 271 | # print("get_rank()", get_rank()) 272 | # # launched with submitit on a slurm cluster 273 | # elif "SLURM_PROCID" in os.environ: 274 | # args.rank = int(os.environ["SLURM_PROCID"]) 275 | # args.gpu = args.rank % torch.cuda.device_count() 276 | # # launched naively with `python main_dino.py` 277 | # # we manually add MASTER_ADDR and MASTER_PORT to env variables 278 | # elif torch.cuda.is_available(): 279 | # print("Will run the code on one GPU.") 280 | # args.rank, args.gpu, args.world_size = 0, 0, 1 281 | # os.environ["MASTER_ADDR"] = "127.0.0.1" 282 | # os.environ["MASTER_PORT"] = "2950" 283 | # else: 284 | # print("Does not support training without GPU.") 285 | # sys.exit(1) 286 | 287 | # os.environ["MASTER_PORT"] = "6542" 288 | 289 | # dist.init_process_group( 290 | # backend="nccl", 291 | # init_method=args.dist_url, 292 | # world_size=args.world_size, 293 | # rank=args.rank, 294 | # ) 295 | 296 | # torch.cuda.set_device(args.gpu) 297 | # print( 298 | # "| distributed init (rank 
{}): {}".format(args.rank, args.dist_url), flush=True 299 | # ) 300 | # dist.barrier() 301 | # setup_for_distributed(args.rank == 0) 302 | 303 | 304 | 305 | 306 | def init_distributed_mode(args): 307 | """ 308 | initialize the normal job 309 | """ 310 | # launched with torch.distributed.launch 311 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 312 | args.rank = int(os.environ["RANK"]) 313 | args.world_size = int(os.environ["WORLD_SIZE"]) 314 | args.gpu = int(os.environ.get("LOCAL_RANK", 0)) 315 | print( 316 | "args.rank", 317 | args.rank, 318 | "args.world_size", 319 | args.world_size, 320 | "args.gpu", 321 | args.gpu, 322 | ) 323 | print("get_rank()", get_rank()) 324 | # launched with submitit on a slurm cluster 325 | elif "SLURM_PROCID" in os.environ: 326 | args.rank = int(os.environ["SLURM_PROCID"]) 327 | args.gpu = args.rank % torch.cuda.device_count() 328 | # launched naively with `python main_dino.py` 329 | # we manually add MASTER_ADDR and MASTER_PORT to env variables 330 | elif torch.cuda.is_available(): 331 | print("Will run the code on one GPU.") 332 | args.rank, args.gpu, args.world_size = 0, 0, 1 333 | os.environ["MASTER_ADDR"] = "127.0.0.1" 334 | os.environ["MASTER_PORT"] = "2950" 335 | else: 336 | print("Does not support training without GPU.") 337 | sys.exit(1) 338 | 339 | # os.environ["MASTER_PORT"] = "6542" 340 | 341 | dist.init_process_group( 342 | backend="nccl", 343 | init_method=args.dist_url, 344 | world_size=args.world_size, 345 | rank=args.rank, 346 | ) 347 | 348 | torch.cuda.set_device(args.gpu) 349 | print( 350 | "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True 351 | ) 352 | dist.barrier() 353 | setup_for_distributed(args.rank == 0) 354 | 355 | 356 | def init_distributed_ddpjob(args=None): 357 | """ 358 | initialize the ddp job 359 | """ 360 | if dist.is_available() and dist.is_initialized(): 361 | return dist.get_world_size(), dist.get_rank() 362 | 363 | try: 364 | os.environ["MASTER_PORT"] = "40101" 365 | torch.distributed.init_process_group(backend="nccl") 366 | except Exception: 367 | world_size, rank = 1, 0 368 | print("distributed training not available") 369 | 370 | world_size = dist.get_world_size() 371 | rank = dist.get_rank() 372 | args.gpu = args.rank 373 | args.world_size, args.rank = world_size, rank 374 | return world_size, -------------------------------------------------------------------------------- /src/blip_src/models/blip_retrieval.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | from models.med import BertConfig, BertModel 9 | from transformers import BertTokenizer 10 | 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | 15 | from models.blip import create_vit, init_tokenizer, load_checkpoint 16 | 17 | class BLIP_Retrieval(nn.Module): 18 | def __init__(self, 19 | med_config = 'configs/med_config.json', 20 | image_size = 384, 21 | vit = 'base', 22 | vit_grad_ckpt = False, 23 | vit_ckpt_layer = 0, 24 | embed_dim = 256, 25 | queue_size = 57600, 26 | momentum = 0.995, 27 | negative_all_rank = False, 28 | ): 29 | """ 30 | Args: 31 | med_config (str): path for the mixture of encoder-decoder model's configuration file 32 | image_size (int): input image size 33 | vit (str): model size of vision transformer 34 | """ 35 | super().__init__() 36 | 37 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 38 | self.tokenizer = init_tokenizer() 39 | med_config = BertConfig.from_json_file(med_config) 40 | med_config.encoder_width = vision_width 41 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 42 | 43 | text_width = self.text_encoder.config.hidden_size 44 | 45 | self.vision_proj = nn.Linear(vision_width, embed_dim) 46 | self.text_proj = nn.Linear(text_width, embed_dim) 47 | 48 | self.itm_head = nn.Linear(text_width, 2) 49 | 50 | # create momentum encoders 51 | self.visual_encoder_m, vision_width = create_vit(vit,image_size) 52 | self.vision_proj_m = nn.Linear(vision_width, embed_dim) 53 | self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False) 54 | self.text_proj_m = nn.Linear(text_width, embed_dim) 55 | 56 | self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], 57 | [self.vision_proj,self.vision_proj_m], 58 | [self.text_encoder,self.text_encoder_m], 59 | [self.text_proj,self.text_proj_m], 60 | ] 61 | self.copy_params() 62 | 63 | # create the queue 64 | self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) 65 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) 66 | self.register_buffer("idx_queue", torch.full((1,queue_size),-100)) 67 | self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long)) 68 | 69 | self.image_queue = nn.functional.normalize(self.image_queue, dim=0) 70 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0) 71 | 72 | self.queue_size = queue_size 73 | self.momentum = momentum 74 | self.temp = nn.Parameter(0.07*torch.ones([])) 75 | 76 | self.negative_all_rank = negative_all_rank 77 | 78 | 79 | def forward(self, image, caption, alpha, idx): 80 | with torch.no_grad(): 81 | self.temp.clamp_(0.001,0.5) 82 | 83 | image_embeds = self.visual_encoder(image) 84 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 85 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 86 | 87 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 88 | return_tensors="pt").to(image.device) 89 | 90 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 91 | return_dict = True, mode = 'text') 92 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 93 | 94 | ###============== Image-text Contrastive Learning ===================### 95 | idx = idx.view(-1,1) 96 | 
idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1) 97 | pos_idx = torch.eq(idx, idx_all).float() 98 | sim_targets = pos_idx / pos_idx.sum(1,keepdim=True) 99 | 100 | # get momentum features 101 | with torch.no_grad(): 102 | self._momentum_update() 103 | image_embeds_m = self.visual_encoder_m(image) 104 | image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) 105 | image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) 106 | 107 | text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, 108 | return_dict = True, mode = 'text') 109 | text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) 110 | text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) 111 | 112 | sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp 113 | sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp 114 | 115 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets 116 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 117 | 118 | sim_i2t = image_feat @ text_feat_m_all / self.temp 119 | sim_t2i = text_feat @ image_feat_m_all / self.temp 120 | 121 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() 122 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 123 | 124 | loss_ita = (loss_i2t+loss_t2i)/2 125 | 126 | idxs = concat_all_gather(idx) 127 | self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs) 128 | 129 | ###============== Image-text Matching ===================### 130 | encoder_input_ids = text.input_ids.clone() 131 | encoder_input_ids[:,0] = self.tokenizer.enc_token_id 132 | 133 | # forward the positve image-text pair 134 | bs = image.size(0) 135 | output_pos = self.text_encoder(encoder_input_ids, 136 | attention_mask = text.attention_mask, 137 | encoder_hidden_states = image_embeds, 138 | encoder_attention_mask = image_atts, 139 | return_dict = True, 140 | ) 141 | 142 | 143 | if self.negative_all_rank: 144 | # compute sample similarity 145 | with torch.no_grad(): 146 | mask = torch.eq(idx, idxs.t()) 147 | 148 | image_feat_world = concat_all_gather(image_feat) 149 | text_feat_world = concat_all_gather(text_feat) 150 | 151 | sim_i2t = image_feat @ text_feat_world.t() / self.temp 152 | sim_t2i = text_feat @ image_feat_world.t() / self.temp 153 | 154 | weights_i2t = F.softmax(sim_i2t,dim=1) 155 | weights_i2t.masked_fill_(mask, 0) 156 | 157 | weights_t2i = F.softmax(sim_t2i,dim=1) 158 | weights_t2i.masked_fill_(mask, 0) 159 | 160 | image_embeds_world = all_gather_with_grad(image_embeds) 161 | 162 | # select a negative image (from all ranks) for each text 163 | image_embeds_neg = [] 164 | for b in range(bs): 165 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 166 | image_embeds_neg.append(image_embeds_world[neg_idx]) 167 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 168 | 169 | # select a negative text (from all ranks) for each image 170 | input_ids_world = concat_all_gather(encoder_input_ids) 171 | att_mask_world = concat_all_gather(text.attention_mask) 172 | 173 | text_ids_neg = [] 174 | text_atts_neg = [] 175 | for b in range(bs): 176 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 177 | text_ids_neg.append(input_ids_world[neg_idx]) 178 | text_atts_neg.append(att_mask_world[neg_idx]) 179 | 180 | else: 181 | with torch.no_grad(): 182 | mask = torch.eq(idx, idx.t()) 183 | 184 | sim_i2t = 
image_feat @ text_feat.t() / self.temp 185 | sim_t2i = text_feat @ image_feat.t() / self.temp 186 | 187 | weights_i2t = F.softmax(sim_i2t,dim=1) 188 | weights_i2t.masked_fill_(mask, 0) 189 | 190 | weights_t2i = F.softmax(sim_t2i,dim=1) 191 | weights_t2i.masked_fill_(mask, 0) 192 | 193 | # select a negative image (from same rank) for each text 194 | image_embeds_neg = [] 195 | for b in range(bs): 196 | neg_idx = torch.multinomial(weights_t2i[b], 1).item() 197 | image_embeds_neg.append(image_embeds[neg_idx]) 198 | image_embeds_neg = torch.stack(image_embeds_neg,dim=0) 199 | 200 | # select a negative text (from same rank) for each image 201 | text_ids_neg = [] 202 | text_atts_neg = [] 203 | for b in range(bs): 204 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 205 | text_ids_neg.append(encoder_input_ids[neg_idx]) 206 | text_atts_neg.append(text.attention_mask[neg_idx]) 207 | 208 | text_ids_neg = torch.stack(text_ids_neg,dim=0) 209 | text_atts_neg = torch.stack(text_atts_neg,dim=0) 210 | 211 | text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) 212 | text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) 213 | 214 | image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) 215 | image_atts_all = torch.cat([image_atts,image_atts],dim=0) 216 | 217 | output_neg = self.text_encoder(text_ids_all, 218 | attention_mask = text_atts_all, 219 | encoder_hidden_states = image_embeds_all, 220 | encoder_attention_mask = image_atts_all, 221 | return_dict = True, 222 | ) 223 | 224 | 225 | vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) 226 | vl_output = self.itm_head(vl_embeddings) 227 | 228 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], 229 | dim=0).to(image.device) 230 | loss_itm = F.cross_entropy(vl_output, itm_labels) 231 | 232 | return loss_ita, loss_itm 233 | 234 | 235 | @torch.no_grad() 236 | def copy_params(self): 237 | for model_pair in self.model_pairs: 238 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 239 | param_m.data.copy_(param.data) # initialize 240 | param_m.requires_grad = False # not update by gradient 241 | 242 | 243 | @torch.no_grad() 244 | def _momentum_update(self): 245 | for model_pair in self.model_pairs: 246 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 247 | param_m.data = param_m.data * self.momentum + param.data * (1. 
- self.momentum) 248 | 249 | 250 | @torch.no_grad() 251 | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs): 252 | # gather keys before updating queue 253 | image_feats = concat_all_gather(image_feat) 254 | text_feats = concat_all_gather(text_feat) 255 | 256 | 257 | batch_size = image_feats.shape[0] 258 | 259 | ptr = int(self.ptr_queue) 260 | assert self.queue_size % batch_size == 0 # for simplicity 261 | 262 | # replace the keys at ptr (dequeue and enqueue) 263 | self.image_queue[:, ptr:ptr + batch_size] = image_feats.T 264 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T 265 | self.idx_queue[:, ptr:ptr + batch_size] = idxs.T 266 | ptr = (ptr + batch_size) % self.queue_size # move pointer 267 | 268 | self.ptr_queue[0] = ptr 269 | 270 | 271 | def blip_retrieval(pretrained='',**kwargs): 272 | model = BLIP_Retrieval(**kwargs) 273 | if pretrained: 274 | model,msg = load_checkpoint(model,pretrained) 275 | print("missing keys:") 276 | print(msg.missing_keys) 277 | return model 278 | 279 | 280 | @torch.no_grad() 281 | def concat_all_gather(tensor): 282 | """ 283 | Performs all_gather operation on the provided tensors. 284 | *** Warning ***: torch.distributed.all_gather has no gradient. 285 | """ 286 | tensors_gather = [torch.ones_like(tensor) 287 | for _ in range(torch.distributed.get_world_size())] 288 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 289 | 290 | output = torch.cat(tensors_gather, dim=0) 291 | return output 292 | 293 | 294 | class GatherLayer(torch.autograd.Function): 295 | """ 296 | Gather tensors from all workers with support for backward propagation: 297 | This implementation does not cut the gradients as torch.distributed.all_gather does. 298 | """ 299 | 300 | @staticmethod 301 | def forward(ctx, x): 302 | output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())] 303 | torch.distributed.all_gather(output, x) 304 | return tuple(output) 305 | 306 | @staticmethod 307 | def backward(ctx, *grads): 308 | all_gradients = torch.stack(grads) 309 | torch.distributed.all_reduce(all_gradients) 310 | return all_gradients[torch.distributed.get_rank()] 311 | 312 | 313 | def all_gather_with_grad(tensors): 314 | """ 315 | Performs all_gather operation on the provided tensors. 316 | Graph remains connected for backward grad computation. 317 | """ 318 | # Queue the gathered tensors 319 | world_size = torch.distributed.get_world_size() 320 | # There is no need for reduction in the single-proc case 321 | if world_size == 1: 322 | return tensors 323 | 324 | tensor_all = GatherLayer.apply(tensors) 325 | 326 | return torch.cat(tensor_all, dim=0) 327 | -------------------------------------------------------------------------------- /src/blip_src/models/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on timm code base 8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | ''' 10 | 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from functools import partial 16 | 17 | from timm.models.vision_transformer import _cfg, PatchEmbed 18 | from timm.models.registry import register_model 19 | from timm.models.layers import trunc_normal_, DropPath 20 | from timm.models.helpers import adapt_input_conv 21 | 22 | # from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 23 | 24 | class Mlp(nn.Module): 25 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 26 | """ 27 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 28 | super().__init__() 29 | out_features = out_features or in_features 30 | hidden_features = hidden_features or in_features 31 | self.fc1 = nn.Linear(in_features, hidden_features) 32 | self.act = act_layer() 33 | self.fc2 = nn.Linear(hidden_features, out_features) 34 | self.drop = nn.Dropout(drop) 35 | 36 | def forward(self, x): 37 | x = self.fc1(x) 38 | x = self.act(x) 39 | x = self.drop(x) 40 | x = self.fc2(x) 41 | x = self.drop(x) 42 | return x 43 | 44 | 45 | class Attention(nn.Module): 46 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 47 | super().__init__() 48 | self.num_heads = num_heads 49 | head_dim = dim // num_heads 50 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 51 | self.scale = qk_scale or head_dim ** -0.5 52 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 53 | self.attn_drop = nn.Dropout(attn_drop) 54 | self.proj = nn.Linear(dim, dim) 55 | self.proj_drop = nn.Dropout(proj_drop) 56 | self.attn_gradients = None 57 | self.attention_map = None 58 | 59 | def save_attn_gradients(self, attn_gradients): 60 | self.attn_gradients = attn_gradients 61 | 62 | def get_attn_gradients(self): 63 | return self.attn_gradients 64 | 65 | def save_attention_map(self, attention_map): 66 | self.attention_map = attention_map 67 | 68 | def get_attention_map(self): 69 | return self.attention_map 70 | 71 | def forward(self, x, register_hook=False): 72 | B, N, C = x.shape 73 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 74 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 75 | 76 | attn = (q @ k.transpose(-2, -1)) * self.scale 77 | attn = attn.softmax(dim=-1) 78 | attn = self.attn_drop(attn) 79 | 80 | if register_hook: 81 | self.save_attention_map(attn) 82 | attn.register_hook(self.save_attn_gradients) 83 | 84 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 85 | x = self.proj(x) 86 | x = self.proj_drop(x) 87 | return x 88 | 89 | 90 | class Block(nn.Module): 91 | 92 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 93 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): 94 | super().__init__() 95 | self.norm1 = norm_layer(dim) 96 | self.attn = Attention( 97 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 98 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 99 | 
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 100 | self.norm2 = norm_layer(dim) 101 | mlp_hidden_dim = int(dim * mlp_ratio) 102 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 103 | 104 | # if use_grad_checkpointing: 105 | # self.attn = checkpoint_wrapper(self.attn) 106 | # self.mlp = checkpoint_wrapper(self.mlp) 107 | 108 | def forward(self, x, register_hook=False): 109 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 110 | x = x + self.drop_path(self.mlp(self.norm2(x))) 111 | return x 112 | 113 | 114 | class VisionTransformer(nn.Module): 115 | """ Vision Transformer 116 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 117 | https://arxiv.org/abs/2010.11929 118 | """ 119 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 120 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 121 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 122 | use_grad_checkpointing=False, ckpt_layer=0): 123 | """ 124 | Args: 125 | img_size (int, tuple): input image size 126 | patch_size (int, tuple): patch size 127 | in_chans (int): number of input channels 128 | num_classes (int): number of classes for classification head 129 | embed_dim (int): embedding dimension 130 | depth (int): depth of transformer 131 | num_heads (int): number of attention heads 132 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 133 | qkv_bias (bool): enable bias for qkv if True 134 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 135 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 136 | drop_rate (float): dropout rate 137 | attn_drop_rate (float): attention dropout rate 138 | drop_path_rate (float): stochastic depth rate 139 | norm_layer: (nn.Module): normalization layer 140 | """ 141 | super().__init__() 142 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 143 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 144 | 145 | self.patch_embed = PatchEmbed( 146 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 147 | 148 | num_patches = self.patch_embed.num_patches 149 | 150 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 151 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 152 | self.pos_drop = nn.Dropout(p=drop_rate) 153 | 154 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 155 | self.blocks = nn.ModuleList([ 156 | Block( 157 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 158 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 159 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) 160 | ) 161 | for i in range(depth)]) 162 | self.norm = norm_layer(embed_dim) 163 | 164 | trunc_normal_(self.pos_embed, std=.02) 165 | trunc_normal_(self.cls_token, std=.02) 166 | self.apply(self._init_weights) 167 | 168 | def _init_weights(self, m): 169 | if isinstance(m, nn.Linear): 170 | trunc_normal_(m.weight, std=.02) 171 | if isinstance(m, nn.Linear) and m.bias is not None: 172 | nn.init.constant_(m.bias, 0) 173 | elif isinstance(m, nn.LayerNorm): 174 | nn.init.constant_(m.bias, 0) 175 | 
nn.init.constant_(m.weight, 1.0) 176 | 177 | @torch.jit.ignore 178 | def no_weight_decay(self): 179 | return {'pos_embed', 'cls_token'} 180 | 181 | def forward(self, x, register_blk=-1): 182 | B = x.shape[0] 183 | x = self.patch_embed(x) 184 | 185 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 186 | x = torch.cat((cls_tokens, x), dim=1) 187 | 188 | x = x + self.pos_embed[:,:x.size(1),:] 189 | x = self.pos_drop(x) 190 | 191 | for i,blk in enumerate(self.blocks): 192 | x = blk(x, register_blk==i) 193 | x = self.norm(x) 194 | 195 | return x 196 | 197 | @torch.jit.ignore() 198 | def load_pretrained(self, checkpoint_path, prefix=''): 199 | _load_weights(self, checkpoint_path, prefix) 200 | 201 | 202 | @torch.no_grad() 203 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 204 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 205 | """ 206 | import numpy as np 207 | 208 | def _n2p(w, t=True): 209 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 210 | w = w.flatten() 211 | if t: 212 | if w.ndim == 4: 213 | w = w.transpose([3, 2, 0, 1]) 214 | elif w.ndim == 3: 215 | w = w.transpose([2, 0, 1]) 216 | elif w.ndim == 2: 217 | w = w.transpose([1, 0]) 218 | return torch.from_numpy(w) 219 | 220 | w = np.load(checkpoint_path) 221 | if not prefix and 'opt/target/embedding/kernel' in w: 222 | prefix = 'opt/target/' 223 | 224 | if hasattr(model.patch_embed, 'backbone'): 225 | # hybrid 226 | backbone = model.patch_embed.backbone 227 | stem_only = not hasattr(backbone, 'stem') 228 | stem = backbone if stem_only else backbone.stem 229 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 230 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 231 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 232 | if not stem_only: 233 | for i, stage in enumerate(backbone.stages): 234 | for j, block in enumerate(stage.blocks): 235 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 236 | for r in range(3): 237 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 238 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 239 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 240 | if block.downsample is not None: 241 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 242 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 243 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 244 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 245 | else: 246 | embed_conv_w = adapt_input_conv( 247 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 248 | model.patch_embed.proj.weight.copy_(embed_conv_w) 249 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 250 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 251 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 252 | if pos_embed_w.shape != model.pos_embed.shape: 253 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 254 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 255 | model.pos_embed.copy_(pos_embed_w) 256 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 257 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 258 | # if 
isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 259 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 260 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 261 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 262 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 263 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 264 | for i, block in enumerate(model.blocks.children()): 265 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 266 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 267 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 268 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 269 | block.attn.qkv.weight.copy_(torch.cat([ 270 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 271 | block.attn.qkv.bias.copy_(torch.cat([ 272 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 273 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 274 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 275 | for r in range(2): 276 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 277 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 278 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 279 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 280 | 281 | 282 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 283 | # interpolate position embedding 284 | embedding_size = pos_embed_checkpoint.shape[-1] 285 | num_patches = visual_encoder.patch_embed.num_patches 286 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 287 | # height (== width) for the checkpoint position embedding 288 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 289 | # height (== width) for the new position embedding 290 | new_size = int(num_patches ** 0.5) 291 | 292 | if orig_size!=new_size: 293 | # class_token and dist_token are kept unchanged 294 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 295 | # only the position tokens are interpolated 296 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 297 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 298 | pos_tokens = torch.nn.functional.interpolate( 299 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 300 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 301 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 302 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) 303 | 304 | return new_pos_embed 305 | else: 306 | return pos_embed_checkpoint --------------------------------------------------------------------------------
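
Below is a minimal, self-contained sketch (not part of the repo) of how `interpolate_pos_embed` at the end of `vit.py` is typically used: a checkpoint trained at one input resolution has its position embedding resized before being loaded into a `VisionTransformer` built for another resolution. The checkpoint filename and the plain-state-dict assumption are illustrative only.

```python
# Sketch only: resize a 224px checkpoint's position embedding for a 384px ViT-B/16.
import torch
from models.vit import VisionTransformer, interpolate_pos_embed

model = VisionTransformer(img_size=384, patch_size=16, embed_dim=768,
                          depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True)

# hypothetical file assumed to hold a plain state_dict for this ViT
state_dict = torch.load('vit_base_224.pth', map_location='cpu')

# keeps the class token and bicubically resizes the patch position tokens
state_dict['pos_embed'] = interpolate_pos_embed(state_dict['pos_embed'], model)
msg = model.load_state_dict(state_dict, strict=False)
print('missing keys:', msg.missing_keys)
```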
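For reference, a toy sketch (not from the repo, toy sizes) of the circular feature queue updated by `_dequeue_and_enqueue` in `blip_retrieval.py` earlier: each step, the globally gathered batch of momentum features overwrites the slice starting at `ptr`, which is why that method asserts `queue_size % batch_size == 0`.

```python
# Toy illustration of the circular dequeue/enqueue used for the ITC feature queue.
import torch

embed_dim, queue_size, batch_size = 4, 8, 2
queue = torch.randn(embed_dim, queue_size)
ptr = 0
for step in range(5):
    feats = torch.randn(batch_size, embed_dim)  # stands in for concat_all_gather(image_feat)
    assert queue_size % batch_size == 0         # slices never straddle the wrap-around point
    queue[:, ptr:ptr + batch_size] = feats.T    # overwrite oldest keys with newest
    ptr = (ptr + batch_size) % queue_size       # advance the circular pointer
```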