├── .DS_Store ├── imgs ├── 983.jpg ├── main.jpg ├── block_mask.png └── motivation.jpg ├── requirements.txt ├── src ├── blip_src │ ├── multiple_scripts │ │ ├── pretrain.sh │ │ ├── multiple_exp_all_single_8u_ft.sh │ │ ├── ft │ │ │ ├── exp_4.sh │ │ │ ├── exp_3.sh │ │ │ ├── exp_2.sh │ │ │ ├── exp_5.sh │ │ │ └── exp_1.sh │ │ └── exp_all_ft_single.sh │ ├── configs │ │ ├── nocaps.yaml │ │ ├── nlvr.yaml │ │ ├── pretrain_concated_pred_4M.yaml │ │ ├── bert_config.json │ │ ├── med_config.json │ │ ├── vqa.yaml │ │ ├── caption_coco.yaml │ │ ├── retrieval_coco.yaml │ │ └── retrieval_flickr.yaml │ ├── move_pretrained_weights.py │ ├── data │ │ ├── nlvr_dataset.py │ │ ├── flickr30k_dataset.py │ │ ├── vqa_dataset.py │ │ ├── pretrain_dataset_concated_pred_tsv.py │ │ ├── coco_karpathy_dataset.py │ │ ├── utils.py │ │ ├── init_data_concated_pred_tsv.py │ │ ├── init_data_concated_pred_refined.py │ │ ├── __init__.py │ │ └── pretrain_dataset_concated_pred_refined.py │ ├── models │ │ ├── blip_itm.py │ │ ├── blip_nlvr.py │ │ ├── blip_vqa.py │ │ ├── blip.py │ │ ├── blip_retrieval.py │ │ └── vit.py │ ├── eval_nocaps.py │ ├── pretrain_concated_pred_refined.py │ ├── pretrain_concated_pred_tsv.py │ ├── train_vqa.py │ ├── train_nlvr.py │ └── utils.py ├── data_preprocess │ ├── download_cc3m_predictions.sh │ ├── bbox_visualization.py │ └── generate_sample_with_bbox_and_classes.py └── vilt_src │ └── config.py ├── INSTALL.md ├── MODEL_ZOO.md ├── DATASET.md ├── GETTING_STARTED.md ├── README.md └── LICENSE /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ptp/HEAD/.DS_Store -------------------------------------------------------------------------------- /imgs/983.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ptp/HEAD/imgs/983.jpg -------------------------------------------------------------------------------- /imgs/main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ptp/HEAD/imgs/main.jpg -------------------------------------------------------------------------------- /imgs/block_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ptp/HEAD/imgs/block_mask.png -------------------------------------------------------------------------------- /imgs/motivation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ptp/HEAD/imgs/motivation.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | ruamel_yaml 3 | pycocoevalcap 4 | transformers 5 | timm 6 | tabulate 7 | numpy -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | python move_pretrained_weights.py; 2 | 3 | # 4M \ 4 | python -m torch.distributed.launch --nproc_per_node=8 pretrain_concated_pred_tsv.py \ 5 | --config ./configs/pretrain_concated_pred_4M.yaml --output_dir output/Pretrain_concated_pred_4M 6 | 7 | echo "output dir is: output/Pretrain_concated_pred_4M" -------------------------------------------------------------------------------- /src/blip_src/configs/nocaps.yaml: 
-------------------------------------------------------------------------------- 1 | image_root: '/dataset/nocaps/' 2 | ann_root: 'annotation' 3 | 4 | # set pretrained as a file path or an url 5 | pretrained: 'output/Pretrain_concated_pred_4M/checkpoint_19.pth' 6 | vit: 'base' 7 | batch_size: 32 8 | 9 | image_size: 384 10 | 11 | max_length: 20 12 | min_length: 5 13 | num_beams: 3 14 | prompt: 'a picture of ' -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Install PyTorch and the other dependencies for PTP 2 | 3 | The code has been tested with PyTorch '1.9.0+cu102' and Python 3.8. 4 | 5 | 6 | 7 | ```bash 8 | conda create -n ptp python==3.8 9 | conda activate ptp 10 | # CUDA 10.2 11 | conda install pytorch==1.9.1 torchvision==0.10.1 torchaudio==0.9.1 cudatoolkit=10.2 -c pytorch 12 | cd [PATH_TO_PTP] 13 | pip install -r requirements.txt 14 | ``` 15 | -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/multiple_exp_all_single_8u_ft.sh: -------------------------------------------------------------------------------- 1 | 2 | declare -a PTMethodArray=( "Pretrain_concated_pred_4M" ) 3 | 4 | for pt_method in "${PTMethodArray[@]}" 5 | do 6 | echo "==== start evaluate model $pt_method ====" 7 | 8 | echo "==== utilize pretrained model output/$pt_method/checkpoint_19.pth ====" 9 | 10 | bash ./multiple_scripts/exp_all_ft_single.sh $pt_method; 11 | done -------------------------------------------------------------------------------- /src/blip_src/configs/nlvr.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/dataset/NLVR2' 2 | ann_root: 'annotation' 3 | 4 | pretrained: 'output/Pretrain_concated_pred_4M/checkpoint_19.pth' 5 | 6 | #size of vit model; base or large 7 | vit: 'base' 8 | batch_size_train: 16 9 | batch_size_test: 64 10 | vit_grad_ckpt: False 11 | vit_ckpt_layer: 0 12 | max_epoch: 15 13 | 14 | image_size: 384 15 | 16 | # optimizer 17 | weight_decay: 0.05 18 | init_lr: 3e-5 19 | min_lr: 0 20 | 21 | -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/ft/exp_4.sh: -------------------------------------------------------------------------------- 1 | # NLVR2 2 | 3 | function rand(){ 4 | min=$1 5 | max=$(($2-$min+1)) 6 | num=$(date +%s%N) 7 | echo $(($num%$max+$min)) 8 | } 9 | 10 | echo '(NLVR:) load pretrained model from: '$3; 11 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/nlvr.yaml; 12 | 13 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_nlvr.py \ 14 | --config ./configs/nlvr.yaml \ 15 | --output_dir output/NLVR_$2 # --evaluate -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/ft/exp_3.sh: -------------------------------------------------------------------------------- 1 | # vqa 2 | 3 | function rand(){ 4 | min=$1 5 | max=$(($2-$min+1)) 6 | num=$(date +%s%N) 7 | echo $(($num%$max+$min)) 8 | } 9 | 10 | 11 | echo '(vqa:) load pretrained model from: '$3; 12 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/vqa.yaml; 13 | 14 | 15 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) \ 16 | train_vqa.py --config ./configs/vqa.yaml \ 17 | --output_dir output/vqa_v2_$2 # --evaluate --------------------------------------------------------------------------------
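Note on the fine-tuning scripts above: the `sed -i "/^\(pretrained: \).*/s//\1'$3'/"` one-liner simply rewrites the `pretrained:` field of the task config before launching `torch.distributed.launch`. For readers who prefer not to patch YAML with `sed`, a minimal, hypothetical Python equivalent (using `ruamel_yaml`, which is already listed in `requirements.txt`; the helper name and example paths below are illustrative, not part of the repo) might look like this:

```python
# Hypothetical helper: equivalent to the sed one-liner in multiple_scripts/ft/exp_*.sh,
# i.e. point the `pretrained:` field of a task config at a chosen checkpoint.
import ruamel.yaml


def set_pretrained(config_path: str, checkpoint: str) -> None:
    yaml = ruamel.yaml.YAML()
    with open(config_path) as f:
        config = yaml.load(f)          # round-trip load keeps comments/ordering of the YAML
    config['pretrained'] = checkpoint  # the field the fine-tuning scripts patch with sed
    with open(config_path, 'w') as f:
        yaml.dump(config, f)


if __name__ == '__main__':
    # Example: reuse the 4M pre-trained checkpoint for NLVR2 fine-tuning.
    set_pretrained('./configs/nlvr.yaml',
                   'output/Pretrain_concated_pred_4M/checkpoint_19.pth')
```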
/src/blip_src/configs/pretrain_concated_pred_4M.yaml: -------------------------------------------------------------------------------- 1 | train_file: [ 2 | 'metadata/4M_corpus.tsv', 3 | ] 4 | laion_path: '' 5 | 6 | # size of vit model; base or large 7 | vit: 'base' 8 | vit_grad_ckpt: False 9 | vit_ckpt_layer: 0 10 | 11 | image_size: 224 12 | batch_size: 75 # 75 13 | 14 | queue_size: 57600 15 | alpha: 0.4 16 | 17 | # optimizer 18 | weight_decay: 0.05 19 | init_lr: 3e-4 20 | min_lr: 1e-6 21 | warmup_lr: 1e-6 22 | lr_decay_rate: 0.9 23 | max_epoch: 20 24 | warmup_steps: 3000 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /src/blip_src/configs/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /src/blip_src/configs/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /src/blip_src/configs/vqa.yaml: -------------------------------------------------------------------------------- 1 | vqa_root: '/dataset/coco2014/COCO2014' #followed by train2014/ 2 | vg_root: '/dataset/VisualGenome/' #followed by image/ 3 | train_files: ['vqa_train','vqa_val','vg_qa'] 4 | ann_root: 'annotation' 5 | 6 | # set pretrained as a file path or an url 7 | 8 | pretrained: 'output/Pretrain_concated_pred_4M/checkpoint_19.pth' 9 | # size of vit model; base or large 10 | vit: 'base' 11 | batch_size_train: 12 # 16 12 | batch_size_test: 24 # 32 13 | vit_grad_ckpt: False 14 | vit_ckpt_layer: 0 15 | init_lr: 2e-5 16 | 17 | image_size: 480 18 | 19 | k_test: 128 20 | inference: 'rank' 21 | 22 | # optimizer 23 | weight_decay: 0.05 24 | min_lr: 0 25 | max_epoch: 10 -------------------------------------------------------------------------------- /src/blip_src/configs/caption_coco.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/dataset/coco2014/COCO2014' 2 | ann_root: 'annotation' 3 | coco_gt_root: 'annotation/coco_gt' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'output/Pretrain_concated_pred_4M/checkpoint_19.pth' 7 | 8 | # size of vit model; base or large 9 | vit: 'base' 10 | vit_grad_ckpt: False 11 | vit_ckpt_layer: 0 12 | batch_size: 32 # 32 13 | init_lr: 1e-5 14 | 15 | # vit: 'large' 
16 | # vit_grad_ckpt: True 17 | # vit_ckpt_layer: 5 18 | # batch_size: 16 19 | # init_lr: 2e-6 20 | 21 | image_size: 384 22 | 23 | # generation configs 24 | max_length: 20 25 | min_length: 5 26 | num_beams: 3 27 | prompt: 'a picture of ' 28 | 29 | # optimizer 30 | weight_decay: 0.05 31 | min_lr: 0 32 | max_epoch: 5 33 | 34 | -------------------------------------------------------------------------------- /src/blip_src/configs/retrieval_coco.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/dataset/coco2014/COCO2014' 2 | ann_root: 'annotation' 3 | dataset: 'coco' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'output/Pretrain_concated_pred_4M/checkpoint_19.pth' 7 | # size of vit model; base or large 8 | 9 | vit: 'base' 10 | batch_size_train: 24 # 32 11 | batch_size_test: 48 # 64 12 | vit_grad_ckpt: True 13 | vit_ckpt_layer: 4 14 | init_lr: 1e-5 15 | 16 | # vit: 'large' 17 | # batch_size_train: 16 18 | # batch_size_test: 32 19 | # vit_grad_ckpt: True 20 | # vit_ckpt_layer: 12 21 | # init_lr: 5e-6 22 | 23 | image_size: 384 24 | queue_size: 57600 25 | alpha: 0.4 26 | k_test: 256 27 | negative_all_rank: True 28 | 29 | # optimizer 30 | weight_decay: 0.05 31 | min_lr: 0 32 | max_epoch: 6 33 | 34 | -------------------------------------------------------------------------------- /src/blip_src/configs/retrieval_flickr.yaml: -------------------------------------------------------------------------------- 1 | image_root: '/dataset/Flickr30k/' 2 | ann_root: 'annotation' 3 | dataset: 'flickr' 4 | 5 | # set pretrained as a file path or an url 6 | pretrained: 'output/Pretrain_concated_pred_4M/checkpoint_19.pth' 7 | # size of vit model; base or large 8 | 9 | vit: 'base' 10 | batch_size_train: 24 # 32 11 | batch_size_test: 48 # 64 12 | vit_grad_ckpt: True 13 | vit_ckpt_layer: 4 14 | init_lr: 1e-5 15 | 16 | # vit: 'large' 17 | # batch_size_train: 16 18 | # batch_size_test: 32 19 | # vit_grad_ckpt: True 20 | # vit_ckpt_layer: 10 21 | # init_lr: 5e-6 22 | 23 | image_size: 384 24 | queue_size: 57600 25 | alpha: 0.4 26 | k_test: 128 27 | negative_all_rank: False 28 | 29 | # optimizer 30 | weight_decay: 0.05 31 | min_lr: 0 32 | max_epoch: 6 33 | 34 | -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/exp_all_ft_single.sh: -------------------------------------------------------------------------------- 1 | python move_pretrained_weights.py; 2 | 3 | gpu_num=8; 4 | time=$(date "+%Y-%m-%d-%H:%M:%S"); 5 | suffix=$1${time}; # the suffix to distingush different experiment, e.g. 
$1='generation_mix' 6 | 7 | PRETRAINED_MODEL="output\/$1\/checkpoint_19.pth" 8 | 9 | echo "${suffix}"; 10 | 11 | 12 | bash multiple_scripts/ft/exp_2.sh $gpu_num $suffix $PRETRAINED_MODEL; # captioning, ~1h 13 | 14 | bash multiple_scripts/ft/exp_5.sh $gpu_num $suffix $PRETRAINED_MODEL; # flickr30 retrieval, ~1h 15 | 16 | bash multiple_scripts/ft/exp_4.sh $gpu_num $suffix $PRETRAINED_MODEL; # NLVR2, ~2h 17 | 18 | bash multiple_scripts/ft/exp_1.sh $gpu_num $suffix $PRETRAINED_MODEL; # coco retrieval, ~12h 19 | 20 | bash multiple_scripts/ft/exp_3.sh $gpu_num $suffix $PRETRAINED_MODEL; # vqa, very slow, ~35 h 21 | -------------------------------------------------------------------------------- /src/data_preprocess/download_cc3m_predictions.sh: -------------------------------------------------------------------------------- 1 | for VARIABLE in {0..11..1} 2 | do 3 | mkdir CC3M/$VARIABLE 4 | ./azcopy copy https://biglmdiag.blob.core.windows.net/vinvl/image_features/googlecc_X152C4_frcnnbig2_exp168model_0060000model.roi_heads.nm_filter_2_model.roi_heads.score_thresh_0.2/model_0060000/$VARIABLE/predictions.tsv \ 5 | CC3M/$VARIABLE --recursive 6 | ./azcopy copy https://biglmdiag.blob.core.windows.net/vinvl/image_features/googlecc_X152C4_frcnnbig2_exp168model_0060000model.roi_heads.nm_filter_2_model.roi_heads.score_thresh_0.2/model_0060000/$VARIABLE/predictions.lineidx \ 7 | CC3M/$VARIABLE --recursive 8 | ./azcopy copy https://biglmdiag.blob.core.windows.net/vinvl/image_features/googlecc_X152C4_frcnnbig2_exp168model_0060000model.roi_heads.nm_filter_2_model.roi_heads.score_thresh_0.2/model_0060000/$VARIABLE/annotations \ 9 | CC3M/$VARIABLE --recursive 10 | done -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/ft/exp_2.sh: -------------------------------------------------------------------------------- 1 | # image captioning 2 | 3 | function rand(){ 4 | min=$1 5 | max=$(($2-$min+1)) 6 | num=$(date +%s%N) 7 | echo $(($num%$max+$min)) 8 | } 9 | 10 | 11 | echo '(coco captioning:) load pretrained model from: '$3; 12 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/caption_coco.yaml; 13 | 14 | # step1: zero-shot coco captioning 15 | echo 'step1: zero-shot coco captioning' 16 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_caption.py \ 17 | --config ./configs/caption_coco.yaml \ 18 | --output_dir output/captioning_coco_$2 --evaluate 19 | 20 | # step2: fine-tune 21 | 22 | echo 'step2: fine-tune coco captioning' 23 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_caption.py \ 24 | --config ./configs/caption_coco.yaml \ 25 | --output_dir output/captioning_coco_$2 # --evaluate -------------------------------------------------------------------------------- /src/blip_src/move_pretrained_weights.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | print('move pretrained weights...') 17 | try: 18 | # a100 machines 19 | if not os.path.exists('/root/.cache/torch/hub/checkpoints/'): 20 | os.makedirs('/root/.cache/torch/hub/checkpoints/') 21 | os.system( 22 | 'cp pretrained_models/*.pth /root/.cache/torch/hub/checkpoints/.') 23 | print('move finished...') 24 | except Exception as e: 25 | print(e) -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/ft/exp_5.sh: -------------------------------------------------------------------------------- 1 | # image to text retrieval 2 | 3 | function rand(){ 4 | min=$1 5 | max=$(($2-$min+1)) 6 | num=$(date +%s%N) 7 | echo $(($num%$max+$min)) 8 | } 9 | 10 | # ===================== step1: zero-shot evaluation================ 11 | echo '(step1: zero-shot f30k retrieval:) load pretrained model from: '$3; 12 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/retrieval_flickr.yaml; 13 | 14 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_retrieval.py \ 15 | --config ./configs/retrieval_flickr.yaml \ 16 | --output_dir output/retrieval_flickr_$2 \ 17 | --evaluate 18 | 19 | 20 | # ===================== step2: ft and evaluate ================ 21 | echo '(step2: fine-tune f30k retrieval:) load pretrained model from: '$3; 22 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/retrieval_flickr.yaml; 23 | 24 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_retrieval.py \ 25 | --config ./configs/retrieval_flickr.yaml \ 26 | --output_dir output/retrieval_flickr_$2 -------------------------------------------------------------------------------- /src/blip_src/multiple_scripts/ft/exp_1.sh: -------------------------------------------------------------------------------- 1 | # image to text retrieval 2 | 3 | function rand(){ 4 | min=$1 5 | max=$(($2-$min+1)) 6 | num=$(date +%s%N) 7 | echo $(($num%$max+$min)) 8 | } 9 | 10 | 11 | # ============================== step1: zero-shot retrieval evaluation ========= 12 | # evaluate on test 13 | echo '(coco step1: zero-shot evaluation) load pretrained model from: '$3; 14 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/retrieval_coco.yaml; 15 | 16 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_retrieval.py \ 17 | --config ./configs/retrieval_coco.yaml \ 18 | --output_dir output/retrieval_coco_$2 \ 19 | --evaluate 20 | 21 | # print val than test set 22 | 23 | # ================= step2: train and evaluate val set ================= 24 | echo '(coco step2: train on retrieval) load pretrained model from: '$3; 25 | sed -i "/^\(pretrained: \).*/s//\1'$3'/" ./configs/retrieval_coco.yaml; 26 | 27 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_retrieval.py \ 28 | --config ./configs/retrieval_coco.yaml \ 29 | --output_dir output/retrieval_coco_$2 30 | 31 | # # ===========================step3: evaluate on val/test split ========= 32 | # # evaluate on test 33 | 34 | # TRAINED_MODEL="output\/retrieval_coco_${2}\/checkpoint_best.pth" 35 | # echo '(coco step3: test/val eval) load trained retrieval model from: '${TRAINED_MODEL}; 36 | 37 | # sed -i "/^\(pretrained: \).*/s//\1'$TRAINED_MODEL'/" ./configs/retrieval_coco.yaml; 38 | 39 | 40 | # python -m torch.distributed.launch --nproc_per_node=$1 
--master_port=$(rand 2000 4000) train_retrieval.py \ 41 | # --config ./configs/retrieval_coco.yaml \ 42 | # --output_dir output/retrieval_coco_$2 \ 43 | # --evaluate -------------------------------------------------------------------------------- /src/data_preprocess/bbox_visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import pandas as pd 17 | import cv2 18 | import numpy as np 19 | from PIL import Image 20 | import ast 21 | import json 22 | 23 | ann = pd.read_csv('/vg_train_val_success_compressed.tsv', sep='\t', header=None) 24 | # 364, 190 25 | 26 | sample = ann.iloc[-30] 27 | 28 | im = Image.open('VG/VG_100K_2/' + sample[0].split('/')[-1]) 29 | 30 | w, h = im.size 31 | print(w, h) 32 | 33 | bboxs = json.loads(sample[2]) 34 | classes = ast.literal_eval(sample[3]) 35 | img = np.asarray(im) 36 | max_char_per_line = 30 37 | y0, dy = 10, 20 38 | for i in range(len(bboxs)): 39 | cv2.rectangle(img, (bboxs[i][0], bboxs[i][1]), (bboxs[i][0]+bboxs[i][2], bboxs[i][1]+bboxs[i][3]), (0, 255, 0), 2) 40 | text_img = cv2.putText(img, classes[i], (bboxs[i][0], bboxs[i][1]), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0)) 41 | bboxs[i][2] = bboxs[i][0] + bboxs[i][2] 42 | bboxs[i][3] = bboxs[i][1] + bboxs[i][3] 43 | w_1 = min(int((bboxs[i][0]/2 + bboxs[i][2]/2)/w * 3), 2) 44 | h_1 = min(int((bboxs[i][1]/2 + bboxs[i][3]/2)/h * 3), 2) 45 | # print(w_1, h_1, w, h) 46 | block = 3*h_1 + w_1 47 | prompt_text = '. 
The block ' + str(block) + ' has a ' + classes[i] + ' .'; 48 | print(prompt_text) 49 | 50 | cv2.imwrite('vg_example_1.jpg', img) 51 | -------------------------------------------------------------------------------- /src/blip_src/data/nlvr_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets.utils import download_url 7 | 8 | from PIL import Image 9 | 10 | from data.utils import pre_caption 11 | 12 | class nlvr_dataset(Dataset): 13 | def __init__(self, transform, image_root, ann_root, split): 14 | ''' 15 | image_root (string): Root directory of images 16 | ann_root (string): directory to store the annotation file 17 | split (string): train, val or test 18 | ''' 19 | urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json', 20 | 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json', 21 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'} 22 | filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'} 23 | 24 | download_url(urls[split],ann_root) 25 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 26 | 27 | self.transform = transform 28 | self.image_root = image_root 29 | 30 | 31 | def __len__(self): 32 | return len(self.annotation) 33 | 34 | 35 | def __getitem__(self, index): 36 | 37 | ann = self.annotation[index] 38 | 39 | image0_path = os.path.join(self.image_root,ann['images'][0]) 40 | image0 = Image.open(image0_path).convert('RGB') 41 | image0 = self.transform(image0) 42 | 43 | image1_path = os.path.join(self.image_root,ann['images'][1]) 44 | image1 = Image.open(image1_path).convert('RGB') 45 | image1 = self.transform(image1) 46 | 47 | sentence = pre_caption(ann['sentence'], 40) 48 | 49 | if ann['label']=='True': 50 | label = 1 51 | else: 52 | label = 0 53 | 54 | words = sentence.split(' ') 55 | 56 | if 'left' not in words and 'right' not in words: 57 | if random.random()<0.5: 58 | return image0, image1, sentence, label 59 | else: 60 | return image1, image0, sentence, label 61 | else: 62 | if random.random()<0.5: 63 | return image0, image1, sentence, label 64 | else: 65 | new_words = [] 66 | for word in words: 67 | if word=='left': 68 | new_words.append('right') 69 | elif word=='right': 70 | new_words.append('left') 71 | else: 72 | new_words.append(word) 73 | 74 | sentence = ' '.join(new_words) 75 | return image1, image0, sentence, label 76 | 77 | 78 | -------------------------------------------------------------------------------- /MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | # MODEL ZOO 2 | 3 | 4 | ## 1. Pre-trained Models 5 | 6 | | Method | Vision Encoder | #Images | Dataset | Pretrained Weights | Training Logs | 7 | | :--- | :--- | :--- | :--- | :----: | :---: | 8 | | PTP-BLIP| ViT-B(DeiT) | 4M | CC3M+COCO+VG+SBU | [link](https://huggingface.co/sail/PTP/blob/main/Pretrain_concated_pred_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_pretrain.txt) | 9 | 10 | ## 2. 
Downstream Models 11 | 12 | 13 | ### 2.1 Captioning 14 | | Method | B@4 | CIDEr | Config | 15 | | :--- | :--- | :--- | ---: | 16 | | PTP-BLIP| 40.1 | 135.0 | configs/caption_coco.yaml | 17 | 18 | 19 | ### 2.2 Zero-shot Retrieval 20 | 21 | 25 | 26 | 27 | #### 2.2.2 Flickr30K 28 | 29 | | Method | I2T@1 | T2I@1 | Model Weight | Training Logs | Config | 30 | | :--- | :--- | :--- | :--- | :--- | :---: | 31 | | PTP-BLIP| 86.4 | 67.0 | [link](https://huggingface.co/sail/PTP/blob/main/zero_shot_coco_checkpoint_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_ptp_flickr30k_zero_shot.txt) | configs/retrieval_flickr.yaml | 32 | 33 | 34 | ### 2.3 Retrieval (Fine-tune) 35 | 36 | Tip: Please use as large a batch size as possible; we find experimentally that a larger batch size gives better results on this task. Due to memory limitations, we use batch size 24 rather than the 28 of the original implementation. 37 | 38 | 39 | #### 2.3.1 COCO 40 | | Method | I2T@1 | T2I@1 | Config | 41 | | :--- | :--- | :--- | :---: | 42 | | PTP-BLIP| 77.6 | 59.4 | configs/retrieval_coco.yaml | 43 | 44 | 45 | #### 2.3.2 Flickr30K 46 | | Method | I2T@1 | T2I@1 | Model Weight | Training Logs | Config | 47 | | :--- | :--- | :--- | :--- | :--- | :---: | 48 | | PTP-BLIP| 96.1 | 84.2 | [link](https://huggingface.co/sail/PTP/blob/main/flickr30k_ft_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_ptp_flickr30k_ft.txt) | configs/retrieval_flickr.yaml | 49 | 50 | ### 2.4 VQA V2 51 | 52 | | Method | Test-dev | Test-std | Model Weight | Training Logs | Config | 53 | | :--- | :--- | :--- | :--- | :--- | :---: | 54 | | PTP-BLIP| 76.02 | 76.18 | [link](https://huggingface.co/sail/PTP/blob/main/vqa_ft_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_ptp_vqa_v2.txt) | configs/vqa.yaml | 55 | 56 | ### 2.5 NLVR 57 | 58 | | Method | Dev | Test-P | Model Weight | Training Logs | Config | 59 | | :--- | :--- | :--- | :--- | :--- | :---: | 60 | | PTP-BLIP| 80.45 | 80.70 | [link](https://huggingface.co/sail/PTP/blob/main/nlvr_ft_4m.pth) | [link](https://huggingface.co/sail/PTP/blob/main/4M_ptp_nlvr.txt) | configs/nlvr.yaml | -------------------------------------------------------------------------------- /src/data_preprocess/generate_sample_with_bbox_and_classes.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | import numpy as np 4 | import zlib 5 | import os 6 | 7 | 8 | out_json = "cc3m_269_w_bbox.json" # 2681187 samples preserved 9 | 10 | object_dict_path = "VG-SGG-dicts-vgoi6-clipped.json" 11 | object_dict = json.load(open(object_dict_path,'r'))['label_to_idx'] 12 | # step1: encode each sample into a vector 13 | print("{} object classes in total".format(len(object_dict))) 14 | # print(object_dict) 15 | 16 | objects_class_count = np.zeros(len(object_dict)) 17 | 18 | 19 | def generate_object_bbox(objects): 20 | object_bboxs = [] 21 | for index, object in enumerate(objects): 22 | if index < 10: 23 | object_bboxs.append([int(coord) for coord in object['rect']]) 24 | # print(object_bboxs) 25 | return object_bboxs 26 | 27 | # step1: generate object tags for each sample 28 | print("===step1: begin to generate object tags caption====") 29 | sample_index = [] 30 | object_bboxs_dict = {} 31 | 32 | for i in range(12): 33 | # if i > 0: 34 | # break 35 | src_tsv = "/Data/CC3M/{}/predictions.tsv".format(i) 36 | metadata = pd.read_csv(src_tsv, sep='\t', header=None) 37 | # append boxes and indexs 38 | for j in range(len(metadata)): 39 | num_boxes =
json.loads(metadata.iloc[j][1])['num_boxes'] 40 | index = metadata.iloc[j][0] 41 | sample_index.append(index) 42 | objects = json.loads(metadata.iloc[j][1])['objects'] 43 | object_bboxs_dict[index] = generate_object_bbox(objects) 44 | print("subdir {}/{} finished".format(i+1, 12)) 45 | 46 | # step2: align cc3m with own list 47 | print("===step2: begin to align with previous caption====") 48 | train_set = pd.read_csv('/CC3M/Train-GCC-training.tsv', sep='\t', header=None) 49 | val_set = pd.read_csv('/CC3M/Validation-GCC-1.1.0-Validation.tsv', sep='\t', header=None) 50 | all_set = pd.concat([train_set, val_set]) 51 | success_data = [] 52 | # count = 0 53 | for sample in sample_index: 54 | # count += 1 55 | # if count > 10000: 56 | # break 57 | file_name = str(zlib.crc32(all_set.iloc(0)[sample][1].encode('utf-8')) & 0xffffffff) + '.jpg' 58 | if sample >= len(train_set): 59 | img_root = "validation" 60 | sub_dir = str((sample - len(train_set)) // 1000) 61 | else: 62 | img_root = "train" 63 | sub_dir = str(sample // 1000) 64 | img_path = os.path.join(img_root, sub_dir, file_name) 65 | rel_img_path = os.path.join(img_root, sub_dir, file_name) 66 | success_data.append({'image': rel_img_path, 'caption': all_set.iloc(0)[sample][0], 'object': object_bboxs_dict[sample]}) 67 | 68 | print("{} samples preserved".format(len(success_data))) 69 | # 1484208 samples preserved 70 | 71 | # step3: 72 | print("===step3: merge with caption cc====") 73 | ann = json.load(open('../metadata/cc3m/train_success_align_269.json', 'r')) 74 | 75 | object_caption_dict = {} 76 | 77 | success_data_preserved = [] 78 | img_paths = dict() 79 | for i in range(len(success_data)): 80 | img_paths[success_data[i]['image']] = 0 81 | object_caption_dict[success_data[i]['image']] = success_data[i]['object'] 82 | 83 | # find the joint part 84 | success_data_preserved = [] 85 | for i in range(len(ann)): 86 | if ann[i]['image'] in img_paths.keys(): 87 | ann[i]['object'] = object_caption_dict[ann[i]['image']] 88 | success_data_preserved.append(ann[i]) 89 | if i % 1000 == 0: 90 | print("{}/{} finished".format(i, len(ann))) 91 | 92 | print("{} samples preserved".format(len(success_data_preserved))) 93 | 94 | 95 | with open(out_json, 'w') as outfile: 96 | json.dump(success_data_preserved, outfile) -------------------------------------------------------------------------------- /src/blip_src/data/flickr30k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | from data.utils import pre_caption 10 | 11 | class flickr30k_train(Dataset): 12 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 13 | ''' 14 | image_root (string): Root directory of images (e.g. 
flickr30k/) 15 | ann_root (string): directory to store the annotation file 16 | ''' 17 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json' 18 | filename = 'flickr30k_train.json' 19 | 20 | download_url(url,ann_root) 21 | 22 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 23 | self.transform = transform 24 | self.image_root = image_root 25 | self.max_words = max_words 26 | self.prompt = prompt 27 | 28 | self.img_ids = {} 29 | n = 0 30 | for ann in self.annotation: 31 | img_id = ann['image_id'] 32 | if img_id not in self.img_ids.keys(): 33 | self.img_ids[img_id] = n 34 | n += 1 35 | 36 | def __len__(self): 37 | return len(self.annotation) 38 | 39 | def __getitem__(self, index): 40 | 41 | ann = self.annotation[index] 42 | 43 | image_path = os.path.join(self.image_root,ann['image']) 44 | image = Image.open(image_path).convert('RGB') 45 | image = self.transform(image) 46 | 47 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 48 | 49 | return image, caption, self.img_ids[ann['image_id']] 50 | 51 | 52 | class flickr30k_retrieval_eval(Dataset): 53 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 54 | ''' 55 | image_root (string): Root directory of images (e.g. flickr30k/) 56 | ann_root (string): directory to store the annotation file 57 | split (string): val or test 58 | ''' 59 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json', 60 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'} 61 | filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'} 62 | 63 | download_url(urls[split],ann_root) 64 | 65 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 66 | self.transform = transform 67 | self.image_root = image_root 68 | 69 | self.text = [] 70 | self.image = [] 71 | self.txt2img = {} 72 | self.img2txt = {} 73 | 74 | txt_id = 0 75 | for img_id, ann in enumerate(self.annotation): 76 | self.image.append(ann['image']) 77 | self.img2txt[img_id] = [] 78 | for i, caption in enumerate(ann['caption']): 79 | self.text.append(pre_caption(caption,max_words)) 80 | self.img2txt[img_id].append(txt_id) 81 | self.txt2img[txt_id] = img_id 82 | txt_id += 1 83 | 84 | def __len__(self): 85 | return len(self.annotation) 86 | 87 | def __getitem__(self, index): 88 | 89 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 90 | image = Image.open(image_path).convert('RGB') 91 | image = self.transform(image) 92 | 93 | return image, index -------------------------------------------------------------------------------- /src/blip_src/models/blip_itm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | from models.med import BertConfig, BertModel 9 | from transformers import BertTokenizer 10 | 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | 15 | from models.blip import create_vit, init_tokenizer, load_checkpoint 16 | 17 | class BLIP_ITM(nn.Module): 18 | def __init__(self, 19 | med_config = 'configs/med_config.json', 20 | image_size = 384, 21 | vit = 'base', 22 | vit_grad_ckpt = False, 23 | vit_ckpt_layer = 0, 24 | embed_dim = 256, 25 | ): 26 | """ 27 | Args: 28 | med_config (str): path for the mixture of encoder-decoder model's configuration file 29 | image_size (int): input image size 30 | vit (str): model size of vision transformer 31 | """ 32 | super().__init__() 33 | 34 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 35 | self.tokenizer = init_tokenizer() 36 | med_config = BertConfig.from_json_file(med_config) 37 | med_config.encoder_width = vision_width 38 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 39 | 40 | text_width = self.text_encoder.config.hidden_size 41 | 42 | self.vision_proj = nn.Linear(vision_width, embed_dim) 43 | self.text_proj = nn.Linear(text_width, embed_dim) 44 | 45 | self.itm_head = nn.Linear(text_width, 2) 46 | 47 | 48 | def forward(self, image, caption, match_head='itm'): 49 | 50 | image_embeds = self.visual_encoder(image) 51 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 52 | 53 | text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, 54 | return_tensors="pt").to(image.device) 55 | 56 | 57 | if match_head=='itm': 58 | output = self.text_encoder(text.input_ids, 59 | attention_mask = text.attention_mask, 60 | encoder_hidden_states = image_embeds, 61 | encoder_attention_mask = image_atts, 62 | return_dict = True, 63 | ) 64 | itm_output = self.itm_head(output.last_hidden_state[:,0,:]) 65 | return itm_output 66 | 67 | elif match_head=='itc': 68 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 69 | return_dict = True, mode = 'text') 70 | image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) 71 | text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) 72 | 73 | sim = image_feat @ text_feat.t() 74 | return sim 75 | 76 | 77 | def blip_itm(pretrained='',**kwargs): 78 | model = BLIP_ITM(**kwargs) 79 | if pretrained: 80 | model,msg = load_checkpoint(model,pretrained) 81 | assert(len(msg.missing_keys)==0) 82 | return model 83 | -------------------------------------------------------------------------------- /src/blip_src/data/vqa_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from data.utils import pre_question 9 | 10 | from torchvision.datasets.utils import download_url 11 | 12 | class vqa_dataset(Dataset): 13 | def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"): 14 | self.split = split 15 | 16 | self.transform = transform 17 | self.vqa_root = vqa_root 18 | self.vg_root = vg_root 19 | 20 | if split=='train': 21 | urls = 
{'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json', 22 | 'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json', 23 | 'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'} 24 | 25 | self.annotation = [] 26 | for f in train_files: 27 | download_url(urls[f],ann_root) 28 | self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r')) 29 | else: 30 | download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root) 31 | self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r')) 32 | 33 | download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root) 34 | self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r')) 35 | 36 | 37 | def __len__(self): 38 | return len(self.annotation) 39 | 40 | def __getitem__(self, index): 41 | 42 | ann = self.annotation[index] 43 | 44 | if ann['dataset']=='vqa': 45 | image_path = os.path.join(self.vqa_root,ann['image']) 46 | elif ann['dataset']=='vg': 47 | image_path = os.path.join(self.vg_root,ann['image']) 48 | 49 | image = Image.open(image_path).convert('RGB') 50 | image = self.transform(image) 51 | 52 | if self.split == 'test': 53 | question = pre_question(ann['question']) 54 | question_id = ann['question_id'] 55 | return image, question, question_id 56 | 57 | 58 | elif self.split=='train': 59 | 60 | question = pre_question(ann['question']) 61 | 62 | if ann['dataset']=='vqa': 63 | answer_weight = {} 64 | for answer in ann['answer']: 65 | if answer in answer_weight.keys(): 66 | answer_weight[answer] += 1/len(ann['answer']) 67 | else: 68 | answer_weight[answer] = 1/len(ann['answer']) 69 | 70 | answers = list(answer_weight.keys()) 71 | weights = list(answer_weight.values()) 72 | 73 | elif ann['dataset']=='vg': 74 | answers = [ann['answer']] 75 | weights = [0.2] 76 | 77 | return image, question, answers, weights 78 | 79 | 80 | def vqa_collate_fn(batch): 81 | image_list, question_list, answer_list, weight_list, n = [], [], [], [], [] 82 | for image, question, answer, weights in batch: 83 | image_list.append(image) 84 | question_list.append(question) 85 | weight_list += weights 86 | answer_list += answer 87 | n.append(len(answer)) 88 | return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n -------------------------------------------------------------------------------- /src/blip_src/data/pretrain_dataset_concated_pred_tsv.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import pandas as pd 5 | from torch.utils.data import Dataset 6 | 7 | from PIL import Image 8 | import numpy as np 9 | from PIL import ImageFile 10 | from PIL.Image import blend as blend 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | Image.MAX_IMAGE_PIXELS = None 13 | import ast 14 | from data.utils import pre_caption 15 | import os,glob 16 | 17 | class pretrain_dataset(Dataset): 18 | def __init__(self, ann_file, laion_path, transform): 19 | self.img_root = "/dataset" # server 20 | self.ann_pretrain = None 21 | for f in ann_file: 22 | ann_temp = pd.read_csv(f, sep='\t', header=None) 23 | if self.ann_pretrain is None: 24 | self.ann_pretrain = ann_temp 25 | else: 26 | self.ann_pretrain = pd.concat([self.ann_pretrain, ann_temp], ignore_index=True, sort=False) 27 | 28 | self.annotation = self.ann_pretrain 29 | 
self.transform = transform 30 | 31 | 32 | def generate_bbox_annotation(self, img, ann): 33 | prompt_text = '.' 34 | if len(ann) > 2: 35 | ann[2] = json.loads(ann[2]) 36 | ann[3] = ast.literal_eval(ann[3]) 37 | object_num = len(ann[3]) 38 | if object_num > 0: 39 | sample_index = random.randint(0, object_num-1) 40 | w, h = img.size 41 | bbox_loc = ann[2][sample_index] 42 | # print(bbox_loc) 43 | w_1 = min(int((bbox_loc[0]/2 + bbox_loc[2]/2)/w * 3), 2) 44 | h_1 = min(int((bbox_loc[1]/2 + bbox_loc[3]/2)/h * 3), 2) 45 | # print(w_1, h_1, w, h) 46 | block = 3*h_1 + w_1 47 | prompt_text = '. The block ' + str(block) + ' has a ' + ann[3][sample_index] + ' .'; 48 | else: 49 | prompt_text = "." 50 | return prompt_text 51 | 52 | def __len__(self): 53 | return len(self.annotation) 54 | 55 | def __getitem__(self, index): 56 | ann = self.annotation.iloc[index] 57 | try: 58 | image = Image.open(os.path.join(self.img_root, ann[0])).convert('RGB') 59 | caption_str = ann[1] 60 | if caption_str[0] == '[': 61 | captions = ast.literal_eval(caption_str) 62 | temp_caption = captions[-1] + captions[random.randint(0, len(captions) - 2)] 63 | else: 64 | temp_caption = caption_str 65 | temp_caption = temp_caption + self.generate_bbox_annotation(image, ann) 66 | except Exception as e: 67 | print(e) 68 | return self.__getitem__(random.randint(0, self.__len__()-1)) 69 | image = self.transform(image) 70 | caption = pre_caption(temp_caption, 30) 71 | # caption = ann['caption'] 72 | return image, caption 73 | 74 | # def __getitem__(self, index): 75 | # ann = self.annotation.iloc[index] 76 | # try: 77 | # image = Image.open(os.path.join(self.img_root, ann[0])).convert('RGB') 78 | # caption_str = ann[1] 79 | # if caption_str[0] == '[': 80 | # captions = ast.literal_eval(caption_str) 81 | # if len(captions) < 6: 82 | # temp_caption = captions[random.randint(0, len(captions) - 1)] 83 | # else: 84 | # temp_caption = captions[random.randint(0, len(captions)//3 - 1)] + ', ' + \ 85 | # captions[random.randint(len(captions)//3, len(captions)//3 * 2)] + ', ' + \ 86 | # captions[random.randint(len(captions)//3*2, len(captions)-1)] 87 | # else: 88 | # temp_caption = caption_str 89 | # temp_caption = temp_caption + self.generate_bbox_annotation(image, ann) 90 | # except Exception as e: 91 | # print(e) 92 | # return self.__getitem__(random.randint(0, self.__len__()-1)) 93 | # image = self.transform(image) 94 | # caption = pre_caption(temp_caption, 30) 95 | # # caption = ann['caption'] 96 | # return image, caption -------------------------------------------------------------------------------- /src/blip_src/eval_nocaps.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import argparse 9 | import os 10 | import ruamel.yaml as yaml 11 | import numpy as np 12 | import random 13 | import time 14 | import datetime 15 | import json 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.utils.data import DataLoader 24 | 25 | from models.blip import blip_decoder 26 | import utils 27 | from data import create_dataset, create_sampler, create_loader 28 | from data.utils import save_result 29 | 30 | @torch.no_grad() 31 | def evaluate(model, data_loader, device, config): 32 | # evaluate 33 | model.eval() 34 | 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | header = 'Evaluation:' 37 | print_freq = 10 38 | 39 | result = [] 40 | for image, image_id in metric_logger.log_every(data_loader, print_freq, header): 41 | 42 | image = image.to(device) 43 | 44 | captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'], 45 | min_length=config['min_length'], repetition_penalty=1.1) 46 | 47 | for caption, img_id in zip(captions, image_id): 48 | result.append({"image_id": img_id.item(), "caption": caption}) 49 | 50 | return result 51 | 52 | 53 | def main(args, config): 54 | utils.init_distributed_mode(args) 55 | 56 | device = torch.device(args.device) 57 | 58 | # fix the seed for reproducibility 59 | seed = args.seed + utils.get_rank() 60 | torch.manual_seed(seed) 61 | np.random.seed(seed) 62 | random.seed(seed) 63 | cudnn.benchmark = True 64 | 65 | #### Dataset #### 66 | print("Creating captioning dataset") 67 | val_dataset, test_dataset = create_dataset('nocaps', config) 68 | 69 | if args.distributed: 70 | num_tasks = utils.get_world_size() 71 | global_rank = utils.get_rank() 72 | samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank) 73 | else: 74 | samplers = [None,None] 75 | 76 | val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers, 77 | batch_size=[config['batch_size']]*2,num_workers=[4,4], 78 | is_trains=[False, False], collate_fns=[None,None]) 79 | 80 | #### Model #### 81 | print("Creating model") 82 | model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], 83 | prompt=config['prompt']) 84 | 85 | model = model.to(device) 86 | 87 | model_without_ddp = model 88 | if args.distributed: 89 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 90 | model_without_ddp = model.module 91 | 92 | val_result = evaluate(model_without_ddp, val_loader, device, config) 93 | val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id') 94 | test_result = evaluate(model_without_ddp, test_loader, device, config) 95 | test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id') 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--config', default='./configs/nocaps.yaml') 101 | parser.add_argument('--output_dir', default='output/NoCaps') 102 | parser.add_argument('--device', default='cuda') 103 | parser.add_argument('--seed', default=42, type=int) 104 | parser.add_argument('--world_size', default=1, type=int, help='number of 
distributed processes') 105 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 106 | parser.add_argument('--distributed', default=True, type=bool) 107 | parser.add_argument( 108 | "--local_rank", 109 | default=0, 110 | type=int, 111 | help="""local rank for distributed training.""", 112 | ) 113 | args = parser.parse_args() 114 | 115 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 116 | 117 | args.result_dir = os.path.join(args.output_dir, 'result') 118 | 119 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 120 | Path(args.result_dir).mkdir(parents=True, exist_ok=True) 121 | 122 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 123 | 124 | main(args, config) -------------------------------------------------------------------------------- /DATASET.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | We prepare the pre-training corpus following OSCAR and BLIP. 3 | __As the data preparation is very time-consuming, we share our experience here for reference.__ 4 | 5 | ## 1. Download Datasets (images) 6 | ### Pre-train Datasets: 7 | 8 | ### CC3M 9 | Step 1: First download the train/val/test annotation files (which include the image URLs) from [google-research-datasets](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md). 10 | 11 | Step 2: We provide our script for downloading CC3M and splitting it into sub-splits: [cc3m_download.py](https://huggingface.co/sail/PTP/blob/main/download_cc3m.py). 12 | **It is better to use our script for downloading, as the filenames may differ under different preprocessing.** 13 | 14 | Note that we only download about 2.8M images, as some URLs are no longer valid. 15 | 16 | ### SBU 17 | First download the annotation files (which include the image URLs) from [huggingface](https://huggingface.co/datasets/sbu_captions). 18 | 19 | Tip: We provide our script for downloading SBU: 20 | [download_sbu.py](https://huggingface.co/sail/PTP/blob/main/download_sbu.py) 21 | 22 | ### Visual Genome 23 | 24 | Download the images (version 1.2) from [visualgenome](https://visualgenome.org/api/v0/api_home.html). 25 | 26 | The downloaded directories will be VG_100K and VG_100K_2. 27 | ```bash 28 | mkdir image 29 | mv VG_100K/* image/ 30 | mv VG_100K_2/* image/ 31 | ``` 32 | 33 | ### COCO 34 | 35 | Download the images (COCO2014) from [coco](https://cocodataset.org/#download): 36 | the 2014 Train, 2014 Val and 2015 Test images. 37 | 38 | ### CC12M 39 | Step 1: Download the annotation files (which include the image URLs) from [google-research-datasets](https://github.com/google-research-datasets/conceptual-12m). 40 | 41 | Step 2: Modify the source TSV file and image path in cc3m_download.py, then download the data in the same way as CC3M. 42 | 43 | Note that we only download about 10M images, as some URLs are no longer valid. 44 | 45 | ### Fine-tune Datasets: 46 | 47 | ### COCO 48 | Download the images (COCO2014) from [coco](https://cocodataset.org/#download): 49 | the 2014 Train, 2014 Val, 2014 Test and 2015 Test images. 50 | 51 | 52 | ### Flickr30K 53 | Download the images from [kaggle](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset). 54 | 55 | ### VQA V2 56 | 57 | Download images from [VQA](https://visualqa.org/download.html). 58 | 59 | ### NLVR 60 | Download images from [NLVR](https://lil.nlp.cornell.edu/nlvr/).
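Before organizing the data, it may help to know what the `object_bbox` / `object_classes` columns prepared in Section 2 below are used for: during pre-training they are turned into PTP position prompts of the form `. The block K has a C .`, where K indexes a 3x3 grid over the image. A minimal sketch of that mapping, mirroring `generate_bbox_annotation` in `src/blip_src/data/pretrain_dataset_concated_pred_tsv.py` (the box and class below are taken from the corpus example in Section 2; the image size is assumed for illustration):

```python
# Minimal sketch: map one object box to the 3x3 "block" index used in PTP prompts.
# Mirrors generate_bbox_annotation() in src/blip_src/data/pretrain_dataset_concated_pred_tsv.py;
# the box is assumed to be [x1, y1, x2, y2] in pixels, as stored in the corpus TSV.
def block_prompt(bbox, image_w, image_h, class_name):
    cx = (bbox[0] + bbox[2]) / 2           # box centre, x
    cy = (bbox[1] + bbox[3]) / 2           # box centre, y
    col = min(int(cx / image_w * 3), 2)    # 0..2, clamped at the right edge
    row = min(int(cy / image_h * 3), 2)    # 0..2, clamped at the bottom edge
    block = 3 * row + col                  # row-major block id, 0..8
    return '. The block ' + str(block) + ' has a ' + class_name + ' .'


# Box and class from the corpus example in Section 2; the 736x981 image size is assumed.
print(block_prompt([340, 226, 417, 323], 736, 981, 'pillow'))
```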
61 | 62 | ## Organize Datasets 63 | 64 | Prepare the datasets as follows: 65 | ``` 66 | Dataset/ 67 | CC3M/ 68 | images/ 69 | train/x/*.jpg 70 | val/x/*.jpg 71 | SBU/ 72 | dataset/ 73 | train/x/*.png 74 | coco2014/ 75 | COCO2014/ 76 | train2014/*.jpg 77 | val2014/*.jpg 78 | test2015/*.jpg 79 | 80 | VisualGenome/ 81 | image/*.jpg 82 | ``` 83 | 84 | Use soft links to map the directories, for example: 85 | ```bash 86 | ln -s [PATH_TO_COCO2014] Dataset/coco2014/COCO2014 87 | ``` 88 | 89 | ## 2. Download/Prepare Corpus (image-text pairs) 90 | We provide two kinds of shuffled image-text pairs. We use object information from [OSCAR](https://github.com/microsoft/Oscar/blob/master/VinVL_DOWNLOAD.md) and follow [BLIP](https://github.com/salesforce/BLIP) for caption refinement. 91 | 1. Specifically, we first download the corpus and object features from the OSCAR codebase. Follow [download_cc3m_predictions.sh](src/data_preprocess/download_cc3m_predictions.sh) for details. Download COCO Train, CC Train, SBU (all) and VG. 92 | 2. Then generate object_bbox and object_classes from the object features. Follow [generate_sample_with_bbox_and_classes.py](src/data_preprocess/generate_sample_with_bbox_and_classes.py) for details. 93 | 3. Finally, pad the original caption with the generated caption, following BLIP. 94 | 95 | **Notice that each COCO image has 5 captions in the [oscar corpus](https://biglmdiag.blob.core.windows.net/vinvl/pretrain_corpus/coco_flickr30k_gqa.tsv). As the COCO captions are high quality, this strongly affects the final downstream results.** 96 | 97 | Make sure each line in the corpus has the form 98 | ``` 99 | [image, refined_caption, object_bbox, object_classes] 100 | ``` 101 | An example is given below: 102 | 103 | ```bash 104 | CC3M/images/train/1597/3250687125.jpg i shall be bringing this white chair and table to the shoot; a white table with two white chairs and a couch [[340, 226, 417, 323], [16, 364, 348, 810], [256, 206, 380, 325], [195, 322, 627, 899], [0, 0, 192, 288], [568, 198, 730, 335], [95, 107, 202, 141], [531, 0, 732, 191], [666, 244, 734, 369], [378, 208, 677, 341]] ['pillow', 'chair', 'pillow', 'table', 'window', 'pillow', 'box', 'window', 'pillow', 'pillow'] 105 | ``` 106 | - 2.8M Image (2G): [CC3M](https://drive.google.com/file/d/1iO-d5e7mOvWEreDrlNyEc_RU_gP7FNBk/view?usp=sharing) 107 | 108 | The filtered corpus file is: 109 | 110 | - 4M Image (2.38G): [CC3M+COCO+VG+SBU](https://drive.google.com/file/d/1NnI-_ha4oqeZeHVOv1GBcvV1txgO9R68/view?usp=sharing) 111 | 112 | Thanks to Jaeseok Byun for helping correct this corpus. 113 | 114 | As our Hugging Face and Google Drive storage is now full, please follow the steps described above to prepare larger corpora. -------------------------------------------------------------------------------- /GETTING_STARTED.md: -------------------------------------------------------------------------------- 1 | # Getting Started with PTP Model 2 | 3 | ## PTP-BLIP 4 | 5 | ### 1. Download ViT-Base Models 6 | 7 | ```bash 8 | cd ptp/src/blip_src 9 | mkdir pretrained_models && cd pretrained_models; 10 | wget -c https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth; 11 | ``` 12 | 13 | ### 2.
Pre-train 14 | 15 | ```bash 16 | python move_pretrained_weights.py; 17 | 18 | python -m torch.distributed.launch --nproc_per_node=8 pretrain_concated_pred_tsv.py \ 19 | --config ./configs/pretrain_concated_pred_4M.yaml --output_dir output/Pretrain_concated_pred_4M 20 | 21 | echo "output dir is: output/Pretrain_concated_pred_4M" 22 | 23 | ``` 24 | 25 | Alternatively, download our pretrained model from [MODEL_ZOO.md](MODEL_ZOO.md). 26 | 27 | ### 3. Downstream Task Evaluation 28 | After pre-training, replace the **pretrained:** entry in the YAML of each task with your pre-trained checkpoint or the downloaded model. 29 | Then we provide run scripts for these tasks: 30 | 31 | #### Captioning 32 | 33 | ```bash 34 | # image captioning 35 | 36 | function rand(){ 37 | min=$1 38 | max=$(($2-$min+1)) 39 | num=$(date +%s%N) 40 | echo $(($num%$max+$min)) 41 | } 42 | 43 | 44 | echo '(coco captioning:) load pretrained model from: '; 45 | 46 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_caption.py \ 47 | --config ./configs/caption_coco.yaml \ 48 | --output_dir output/captioning_coco 49 | ``` 50 | 51 | #### Retrieval 52 | 53 | ```bash 54 | function rand(){ 55 | min=$1 56 | max=$(($2-$min+1)) 57 | num=$(date +%s%N) 58 | echo $(($num%$max+$min)) 59 | } 60 | 61 | 62 | # ============================== step1: zero-shot retrieval evaluation ========= 63 | # evaluate on test 64 | echo '(coco step1: zero-shot evaluation) load pretrained model from: '; 65 | 66 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_retrieval.py \ 67 | --config ./configs/retrieval_coco.yaml \ 68 | --output_dir output/retrieval_coco_zs \ 69 | --evaluate 70 | 71 | # ================= step2: train and evaluate val & test set ================= 72 | echo '(coco step2: train on retrieval) load pretrained model from: '; 73 | 74 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_retrieval.py \ 75 | --config ./configs/retrieval_coco.yaml \ 76 | --output_dir output/retrieval_coco_ft 77 | 78 | ``` 79 | 80 | #### VQA 81 | 82 | 83 | ```bash 84 | # vqa 85 | 86 | function rand(){ 87 | min=$1 88 | max=$(($2-$min+1)) 89 | num=$(date +%s%N) 90 | echo $(($num%$max+$min)) 91 | } 92 | 93 | 94 | echo '(vqa:) load pretrained model from: '; 95 | 96 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) \ 97 | train_vqa.py --config ./configs/vqa.yaml \ 98 | --output_dir output/vqa_v2_vqa 99 | ``` 100 | 101 | After generating the result files, submit them to [eval_ai](https://eval.ai/web/challenges/challenge-page/830) for the final results. 102 | 103 | 104 | #### NLVR 105 | 106 | ```bash 107 | # NLVR2 108 | 109 | function rand(){ 110 | min=$1 111 | max=$(($2-$min+1)) 112 | num=$(date +%s%N) 113 | echo $(($num%$max+$min)) 114 | } 115 | 116 | echo '(NLVR:) load pretrained model from: '; 117 | 118 | python -m torch.distributed.launch --nproc_per_node=$1 --master_port=$(rand 2000 4000) train_nlvr.py \ 119 | --config ./configs/nlvr.yaml \ 120 | --output_dir output/NLVR_NLVR 121 | ``` 122 | 123 | 124 | #### Run All Downstream Tasks At Once 125 | We also provide a shell script that runs all downstream tasks, as below: 126 | 127 | ```bash 128 | bash multiple_scripts/multiple_exp_all_single_8u_ft.sh Pretrain_concated_pred_4M 129 | ``` 130 | 131 | where _Pretrain_concated_pred_4M_ is the pretrained output directory.
132 | 133 | The underlying script ([exp_all_ft_single.sh](src/blip_src/multiple_scripts/exp_all_ft_single.sh)) is: 134 | ```bash 135 | python move_pretrained_weights.py; 136 | 137 | gpu_num=8; 138 | time=$(date "+%Y-%m-%d-%H:%M:%S"); 139 | suffix=$1${time}; # the suffix to distinguish different experiments, e.g. $1='generation_mix' 140 | 141 | PRETRAINED_MODEL="output\/$1\/checkpoint_19.pth" 142 | 143 | echo "${suffix}"; 144 | 145 | bash multiple_scripts/ft/exp_2.sh $gpu_num $suffix $PRETRAINED_MODEL; # captioning, ~1h 146 | 147 | bash multiple_scripts/ft/exp_5.sh $gpu_num $suffix $PRETRAINED_MODEL; # flickr30k retrieval, ~1h 148 | 149 | bash multiple_scripts/ft/exp_4.sh $gpu_num $suffix $PRETRAINED_MODEL; # NLVR2, ~2h 150 | 151 | bash multiple_scripts/ft/exp_1.sh $gpu_num $suffix $PRETRAINED_MODEL; # coco retrieval, ~12h 152 | 153 | bash multiple_scripts/ft/exp_3.sh $gpu_num $suffix $PRETRAINED_MODEL; # vqa, very slow, ~35 h 154 | 155 | 156 | ``` 157 | 158 | **Tip** 159 | The simplest way to evaluate the model on all tasks is: 160 | ```bash 161 | bash multiple_scripts/multiple_exp_all_single_8u_ft.sh 162 | ``` 163 | 164 | 165 | ## PTP-ViLT 166 | 167 | ### 1. Pre-train 168 | ```bash 169 | python run.py with data_root=/dataset num_gpus=8 num_nodes=1 task_mlm_itm whole_word_masking=True step200k per_gpu_batchsize=64 170 | ``` 171 | 172 | ### 2. Downstream Task Evaluation 173 | 174 | #### VQA 175 | ```bash 176 | python run.py with data_root=/dataset num_gpus=8 num_nodes=1 per_gpu_batchsize=64 task_finetune_vqa_randaug test_only=True precision=32 load_path="weights/vilt_vqa.ckpt" 177 | 178 | ``` -------------------------------------------------------------------------------- /src/blip_src/data/coco_karpathy_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets.utils import download_url 6 | 7 | from PIL import Image 8 | 9 | from data.utils import pre_caption 10 | 11 | class coco_karpathy_train(Dataset): 12 | def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): 13 | ''' 14 | image_root (string): Root directory of images (e.g. coco/images/) 15 | ann_root (string): directory to store the annotation file 16 | ''' 17 | url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json' 18 | filename = 'coco_karpathy_train.json' 19 | 20 | download_url(url,ann_root) 21 | 22 | self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) 23 | self.transform = transform 24 | self.image_root = image_root # self.img_root = "/dataset/CC3M/images" # server 25 | self.max_words = max_words 26 | self.prompt = prompt 27 | 28 | self.img_ids = {} 29 | n = 0 30 | for ann in self.annotation: 31 | img_id = ann['image_id'] 32 | if img_id not in self.img_ids.keys(): 33 | self.img_ids[img_id] = n 34 | n += 1 35 | 36 | def __len__(self): 37 | return len(self.annotation) 38 | 39 | def __getitem__(self, index): 40 | 41 | ann = self.annotation[index] 42 | 43 | image_path = os.path.join(self.image_root,ann['image']) 44 | image = Image.open(image_path).convert('RGB') 45 | image = self.transform(image) 46 | 47 | caption = self.prompt+pre_caption(ann['caption'], self.max_words) 48 | 49 | return image, caption, self.img_ids[ann['image_id']] 50 | 51 | 52 | class coco_karpathy_caption_eval(Dataset): 53 | def __init__(self, transform, image_root, ann_root, split): 54 | ''' 55 | image_root (string): Root directory of images (e.g.
coco/images/) 56 | ann_root (string): directory to store the annotation file 57 | split (string): val or test 58 | ''' 59 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 60 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 61 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 62 | 63 | download_url(urls[split],ann_root) 64 | 65 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 66 | self.transform = transform 67 | self.image_root = image_root 68 | 69 | def __len__(self): 70 | return len(self.annotation) 71 | 72 | def __getitem__(self, index): 73 | 74 | ann = self.annotation[index] 75 | 76 | image_path = os.path.join(self.image_root,ann['image']) 77 | image = Image.open(image_path).convert('RGB') 78 | image = self.transform(image) 79 | 80 | img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1] 81 | 82 | return image, int(img_id) 83 | 84 | 85 | class coco_karpathy_retrieval_eval(Dataset): 86 | def __init__(self, transform, image_root, ann_root, split, max_words=30): 87 | ''' 88 | image_root (string): Root directory of images (e.g. coco/images/) 89 | ann_root (string): directory to store the annotation file 90 | split (string): val or test 91 | ''' 92 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', 93 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} 94 | filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} 95 | 96 | download_url(urls[split],ann_root) 97 | 98 | self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) 99 | self.transform = transform 100 | self.image_root = image_root 101 | 102 | self.text = [] 103 | self.image = [] 104 | self.txt2img = {} 105 | self.img2txt = {} 106 | 107 | txt_id = 0 108 | for img_id, ann in enumerate(self.annotation): 109 | self.image.append(ann['image']) 110 | self.img2txt[img_id] = [] 111 | for i, caption in enumerate(ann['caption']): 112 | self.text.append(pre_caption(caption,max_words)) 113 | self.img2txt[img_id].append(txt_id) 114 | self.txt2img[txt_id] = img_id 115 | txt_id += 1 116 | 117 | def __len__(self): 118 | return len(self.annotation) 119 | 120 | def __getitem__(self, index): 121 | 122 | image_path = os.path.join(self.image_root, self.annotation[index]['image']) 123 | image = Image.open(image_path).convert('RGB') 124 | image = self.transform(image) 125 | 126 | return image, index -------------------------------------------------------------------------------- /src/blip_src/models/blip_nlvr.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | from models.med import BertConfig 9 | from models.nlvr_encoder import BertModel 10 | from models.vit import interpolate_pos_embed 11 | from models.blip import create_vit, init_tokenizer, is_url 12 | 13 | # from timm.models.hub import download_cached_file 14 | import os 15 | import torch 16 | from torch import nn 17 | import torch.nn.functional as F 18 | from transformers import BertTokenizer 19 | import numpy as np 20 | 21 | class BLIP_NLVR(nn.Module): 22 | def __init__(self, 23 | med_config = 'configs/med_config.json', 24 | image_size = 480, 25 | vit = 'base', 26 | vit_grad_ckpt = False, 27 | vit_ckpt_layer = 0, 28 | ): 29 | """ 30 | Args: 31 | med_config (str): path for the mixture of encoder-decoder model's configuration file 32 | image_size (int): input image size 33 | vit (str): model size of vision transformer 34 | """ 35 | super().__init__() 36 | 37 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) 38 | self.tokenizer = init_tokenizer() 39 | med_config = BertConfig.from_json_file(med_config) 40 | med_config.encoder_width = vision_width 41 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 42 | 43 | self.cls_head = nn.Sequential( 44 | nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), 45 | nn.ReLU(), 46 | nn.Linear(self.text_encoder.config.hidden_size, 2) 47 | ) 48 | 49 | def forward(self, image, text, targets, train=True): 50 | 51 | image_embeds = self.visual_encoder(image) 52 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 53 | image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0)) 54 | 55 | text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device) 56 | text.input_ids[:,0] = self.tokenizer.enc_token_id 57 | 58 | output = self.text_encoder(text.input_ids, 59 | attention_mask = text.attention_mask, 60 | encoder_hidden_states = [image0_embeds,image1_embeds], 61 | encoder_attention_mask = [image_atts[:image0_embeds.size(0)], 62 | image_atts[image0_embeds.size(0):]], 63 | return_dict = True, 64 | ) 65 | hidden_state = output.last_hidden_state[:,0,:] 66 | prediction = self.cls_head(hidden_state) 67 | 68 | if train: 69 | loss = F.cross_entropy(prediction, targets) 70 | return loss 71 | else: 72 | return prediction 73 | 74 | def blip_nlvr(pretrained='',**kwargs): 75 | model = BLIP_NLVR(**kwargs) 76 | if pretrained: 77 | model,msg = load_checkpoint(model,pretrained) 78 | print("missing keys:") 79 | print(msg.missing_keys) 80 | return model 81 | 82 | 83 | def load_checkpoint(model,url_or_filename): 84 | if os.path.isfile(url_or_filename): 85 | checkpoint = torch.load(url_or_filename, map_location='cpu') 86 | else: 87 | raise RuntimeError('checkpoint url or path is invalid') 88 | # if is_url(url_or_filename): 89 | # cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 90 | # checkpoint = torch.load(cached_file, map_location='cpu') 91 | # elif os.path.isfile(url_or_filename): 92 | # checkpoint = torch.load(url_or_filename, map_location='cpu') 93 | # else: 94 | # raise RuntimeError('checkpoint url or path is invalid') 95 | state_dict = checkpoint['model'] 96 | 97 | state_dict['visual_encoder.pos_embed'] = 
interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 98 | 99 | for key in list(state_dict.keys()): 100 | if 'crossattention.self.' in key: 101 | new_key0 = key.replace('self','self0') 102 | new_key1 = key.replace('self','self1') 103 | state_dict[new_key0] = state_dict[key] 104 | state_dict[new_key1] = state_dict[key] 105 | elif 'crossattention.output.dense.' in key: 106 | new_key0 = key.replace('dense','dense0') 107 | new_key1 = key.replace('dense','dense1') 108 | state_dict[new_key0] = state_dict[key] 109 | state_dict[new_key1] = state_dict[key] 110 | 111 | msg = model.load_state_dict(state_dict,strict=False) 112 | print('load checkpoint from %s'%url_or_filename) 113 | return model,msg 114 | -------------------------------------------------------------------------------- /src/blip_src/data/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import random 8 | import utils 9 | import matplotlib.pyplot as plt 10 | from PIL import Image, ImageDraw 11 | import numpy as np 12 | import cv2 13 | 14 | colormaps = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (61, 145, 64), (127, 255, 212), (0, 201, 87), 15 | (218, 112, 214), (255, 0, 255), (112, 128, 105), (250, 235, 215), 16 | (240, 255, 255), (252, 230, 201), (255, 255, 0), (235, 142, 85), 17 | (255, 97, 0), (176, 224, 230), (65, 106, 225,), (0, 255, 255), 18 | (56, 94, 15), (8, 46, 84), (255, 192, 203)] 19 | 20 | def pre_caption(caption,max_words=50): 21 | caption = re.sub( 22 | r"([.!\"()*#:;~])", 23 | ' ', 24 | caption.lower(), 25 | ) 26 | caption = re.sub( 27 | r"\s{2,}", 28 | ' ', 29 | caption, 30 | ) 31 | caption = caption.rstrip('\n') 32 | caption = caption.strip(' ') 33 | 34 | #truncate caption 35 | caption_words = caption.split(' ') 36 | if len(caption_words)>max_words: 37 | caption = ' '.join(caption_words[:max_words]) 38 | 39 | return caption 40 | 41 | def pre_question(question,max_ques_words=50): 42 | question = re.sub( 43 | r"([.!\"()*#:;~])", 44 | '', 45 | question.lower(), 46 | ) 47 | question = question.rstrip(' ') 48 | 49 | #truncate question 50 | question_words = question.split(' ') 51 | if len(question_words)>max_ques_words: 52 | question = ' '.join(question_words[:max_ques_words]) 53 | 54 | return question 55 | 56 | def draw_bboxs(image, bboxs): 57 | image_w_box = ImageDraw.Draw(image) 58 | for bbox in bboxs: 59 | image_w_box.rectangle([(bbox[0], bbox[1]), (bbox[2], bbox[3])], fill=False, outline="red", width=4) 60 | return image 61 | 62 | def draw_bboxs_color_prompt(image, bboxs): 63 | img = np.asarray(image) 64 | for index, bbox in enumerate(bboxs): 65 | sub_img = img[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] 66 | white_rect = np.ones(sub_img.shape, dtype=np.uint8) * 255 67 | white_rect[:, :, 0] = colormaps[index][0] 68 | white_rect[:, :, 1] = colormaps[index][1] 69 | white_rect[:, :, 2] = colormaps[index][2] 70 | res = cv2.addWeighted(sub_img, 0.7, white_rect, 0.3, 1.0) 71 | img[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] = res 72 | cv2.rectangle(img, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), colormaps[index], 3) 73 | img = Image.fromarray(img.astype('uint8')) 74 | return image 75 | 76 | def save_result(result, result_dir, filename, remove_duplicate=''): 77 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 78 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 79 | 
80 | json.dump(result,open(result_file,'w')) 81 | 82 | dist.barrier() 83 | 84 | if utils.is_main_process(): 85 | # combine results from all processes 86 | result = [] 87 | 88 | for rank in range(utils.get_world_size()): 89 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 90 | res = json.load(open(result_file,'r')) 91 | result += res 92 | 93 | if remove_duplicate: 94 | result_new = [] 95 | id_list = [] 96 | for res in result: 97 | if res[remove_duplicate] not in id_list: 98 | id_list.append(res[remove_duplicate]) 99 | result_new.append(res) 100 | result = result_new 101 | 102 | json.dump(result,open(final_result_file,'w')) 103 | print('result file saved to %s'%final_result_file) 104 | 105 | return final_result_file 106 | 107 | 108 | 109 | from pycocotools.coco import COCO 110 | from pycocoevalcap.eval import COCOEvalCap 111 | from torchvision.datasets.utils import download_url 112 | 113 | def coco_caption_eval(coco_gt_root, results_file, split): 114 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json', 115 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'} 116 | filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'} 117 | 118 | download_url(urls[split],coco_gt_root) 119 | annotation_file = os.path.join(coco_gt_root,filenames[split]) 120 | 121 | # create coco object and coco_result object 122 | coco = COCO(annotation_file) 123 | coco_result = coco.loadRes(results_file) 124 | 125 | # create coco_eval object by taking coco and coco_result 126 | coco_eval = COCOEvalCap(coco, coco_result) 127 | 128 | # evaluate on a subset of images by setting 129 | # coco_eval.params['image_id'] = coco_result.getImgIds() 130 | # please remove this line when evaluating the full validation set 131 | # coco_eval.params['image_id'] = coco_result.getImgIds() 132 | 133 | # evaluate results 134 | # SPICE will take a few minutes the first time, but speeds up due to caching 135 | coco_eval.evaluate() 136 | 137 | # print output evaluation scores 138 | for metric, score in coco_eval.eval.items(): 139 | print(f'{metric}: {score:.3f}') 140 | 141 | return coco_eval -------------------------------------------------------------------------------- /src/blip_src/data/init_data_concated_pred_tsv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | # from torchvision.transforms.functional import InterpolationMode 5 | 6 | from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval 7 | from data.nocaps_dataset import nocaps_eval 8 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval 9 | from data.vqa_dataset import vqa_dataset 10 | from data.nlvr_dataset import nlvr_dataset 11 | from data.pretrain_dataset_concated_pred_tsv import pretrain_dataset 12 | from transform.randaugment import RandomAugment 13 | 14 | def create_dataset(dataset, config, min_scale=0.5): 15 | print(config) 16 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 17 | 18 | transform_train = transforms.Compose([ 19 | transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0)), #,interpolation=InterpolationMode.BICUBIC), 20 | # transforms.RandomHorizontalFlip(), 21 | 
RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 22 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 23 | transforms.ToTensor(), 24 | normalize, 25 | ]) 26 | transform_test = transforms.Compose([ 27 | transforms.Resize((config['image_size'],config['image_size'])), #,interpolation=InterpolationMode.BICUBIC), 28 | transforms.ToTensor(), 29 | normalize, 30 | ]) 31 | 32 | if dataset=='pretrain': 33 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train) 34 | return dataset 35 | 36 | elif dataset=='caption_cc3m': 37 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_test) 38 | return dataset 39 | 40 | elif dataset=='caption_coco': 41 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt']) 42 | val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val') 43 | test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test') 44 | return train_dataset, val_dataset, test_dataset 45 | 46 | elif dataset=='nocaps': 47 | val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val') 48 | test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test') 49 | return val_dataset, test_dataset 50 | 51 | elif dataset=='retrieval_coco': 52 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root']) 53 | val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 54 | test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 55 | return train_dataset, val_dataset, test_dataset 56 | 57 | elif dataset=='retrieval_flickr': 58 | train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root']) 59 | val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 60 | test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 61 | return train_dataset, val_dataset, test_dataset 62 | 63 | elif dataset=='vqa': 64 | train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'], 65 | train_files = config['train_files'], split='train') 66 | test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test') 67 | return train_dataset, test_dataset 68 | 69 | elif dataset=='nlvr': 70 | train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train') 71 | val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val') 72 | test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test') 73 | return train_dataset, val_dataset, test_dataset 74 | 75 | 76 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 77 | samplers = [] 78 | for dataset,shuffle in zip(datasets,shuffles): 79 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 80 | samplers.append(sampler) 81 | return samplers 82 | 83 | 84 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 85 | loaders = [] 86 | for dataset,sampler,bs,n_worker,is_train,collate_fn in 
zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 87 | if is_train: 88 | shuffle = (sampler is None) 89 | drop_last = True 90 | else: 91 | shuffle = False 92 | drop_last = False 93 | loader = DataLoader( 94 | dataset, 95 | batch_size=bs, 96 | num_workers=n_worker, 97 | pin_memory=True, 98 | sampler=sampler, 99 | shuffle=shuffle, 100 | collate_fn=collate_fn, 101 | drop_last=drop_last, 102 | ) 103 | loaders.append(loader) 104 | return loaders 105 | 106 | -------------------------------------------------------------------------------- /src/blip_src/data/init_data_concated_pred_refined.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | # from torchvision.transforms.functional import InterpolationMode 5 | 6 | from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval 7 | from data.nocaps_dataset import nocaps_eval 8 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval 9 | from data.vqa_dataset import vqa_dataset 10 | from data.nlvr_dataset import nlvr_dataset 11 | from data.pretrain_dataset_concated_pred_refined import pretrain_dataset 12 | from transform.randaugment import RandomAugment 13 | 14 | def create_dataset(dataset, config, min_scale=0.5): 15 | print(config) 16 | # normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 17 | 18 | transform_train = transforms.Compose([ 19 | transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0)), #,interpolation=InterpolationMode.BICUBIC), 20 | transforms.RandomHorizontalFlip(), 21 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 22 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 23 | transforms.ToTensor(), 24 | # normalize, 25 | ]) 26 | transform_test = transforms.Compose([ 27 | transforms.Resize((config['image_size'],config['image_size'])), #,interpolation=InterpolationMode.BICUBIC), 28 | transforms.ToTensor(), 29 | # normalize, 30 | ]) 31 | 32 | if dataset=='pretrain': 33 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train) 34 | return dataset 35 | 36 | elif dataset=='caption_cc3m': 37 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_test) 38 | return dataset 39 | 40 | elif dataset=='caption_coco': 41 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt']) 42 | val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val') 43 | test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test') 44 | return train_dataset, val_dataset, test_dataset 45 | 46 | elif dataset=='nocaps': 47 | val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val') 48 | test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test') 49 | return val_dataset, test_dataset 50 | 51 | elif dataset=='retrieval_coco': 52 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root']) 53 | val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 54 | test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 
'test') 55 | return train_dataset, val_dataset, test_dataset 56 | 57 | elif dataset=='retrieval_flickr': 58 | train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root']) 59 | val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 60 | test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 61 | return train_dataset, val_dataset, test_dataset 62 | 63 | elif dataset=='vqa': 64 | train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'], 65 | train_files = config['train_files'], split='train') 66 | test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test') 67 | return train_dataset, test_dataset 68 | 69 | elif dataset=='nlvr': 70 | train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train') 71 | val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val') 72 | test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test') 73 | return train_dataset, val_dataset, test_dataset 74 | 75 | 76 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 77 | samplers = [] 78 | for dataset,shuffle in zip(datasets,shuffles): 79 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 80 | samplers.append(sampler) 81 | return samplers 82 | 83 | 84 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 85 | loaders = [] 86 | for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 87 | if is_train: 88 | shuffle = (sampler is None) 89 | drop_last = True 90 | else: 91 | shuffle = False 92 | drop_last = False 93 | loader = DataLoader( 94 | dataset, 95 | batch_size=bs, 96 | num_workers=n_worker, 97 | pin_memory=True, 98 | sampler=sampler, 99 | shuffle=shuffle, 100 | collate_fn=collate_fn, 101 | drop_last=drop_last, 102 | ) 103 | loaders.append(loader) 104 | return loaders 105 | 106 | -------------------------------------------------------------------------------- /src/blip_src/data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | # from torchvision.transforms.functional import InterpolationMode 5 | 6 | from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval 7 | from data.nocaps_dataset import nocaps_eval 8 | from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval 9 | from data.vqa_dataset import vqa_dataset 10 | from data.nlvr_dataset import nlvr_dataset 11 | from data.pretrain_dataset import pretrain_dataset 12 | from transform.randaugment import RandomAugment 13 | 14 | def create_dataset(dataset, config, min_scale=0.5): 15 | print(config) 16 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 17 | 18 | transform_train = transforms.Compose([ 19 | transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0)), #,interpolation=InterpolationMode.BICUBIC), 20 | transforms.RandomHorizontalFlip(), 21 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 22 | 'ShearX', 'ShearY', 
'TranslateX', 'TranslateY', 'Rotate']), 23 | transforms.ToTensor(), 24 | normalize, 25 | ]) 26 | transform_test = transforms.Compose([ 27 | transforms.Resize((config['image_size'],config['image_size'])), #,interpolation=InterpolationMode.BICUBIC), 28 | transforms.ToTensor(), 29 | normalize, 30 | ]) 31 | 32 | if dataset=='pretrain': 33 | if 'mixup' in config: 34 | if config['mixup']: 35 | mixup = True 36 | else: 37 | mixup = False 38 | else: 39 | mixup = False 40 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train, img_mixup=mixup) 41 | return dataset 42 | elif dataset=='caption_cc3m': 43 | dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_test, img_mixup=False) 44 | return dataset 45 | elif dataset=='caption_coco': 46 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt']) 47 | val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val') 48 | test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test') 49 | return train_dataset, val_dataset, test_dataset 50 | 51 | elif dataset=='nocaps': 52 | val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val') 53 | test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test') 54 | return val_dataset, test_dataset 55 | 56 | elif dataset=='retrieval_coco': 57 | train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root']) 58 | val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 59 | test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 60 | return train_dataset, val_dataset, test_dataset 61 | 62 | elif dataset=='retrieval_flickr': 63 | train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root']) 64 | val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') 65 | test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') 66 | return train_dataset, val_dataset, test_dataset 67 | 68 | elif dataset=='vqa': 69 | train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'], 70 | train_files = config['train_files'], split='train') 71 | test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test') 72 | return train_dataset, test_dataset 73 | 74 | elif dataset=='nlvr': 75 | train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train') 76 | val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val') 77 | test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test') 78 | return train_dataset, val_dataset, test_dataset 79 | 80 | 81 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 82 | samplers = [] 83 | for dataset,shuffle in zip(datasets,shuffles): 84 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 85 | samplers.append(sampler) 86 | return samplers 87 | 88 | 89 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 90 | loaders = [] 91 | for dataset,sampler,bs,n_worker,is_train,collate_fn in 
zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 92 | if is_train: 93 | shuffle = (sampler is None) 94 | drop_last = True 95 | else: 96 | shuffle = False 97 | drop_last = False 98 | loader = DataLoader( 99 | dataset, 100 | batch_size=bs, 101 | num_workers=n_worker, 102 | pin_memory=True, 103 | sampler=sampler, 104 | shuffle=shuffle, 105 | collate_fn=collate_fn, 106 | drop_last=drop_last, 107 | ) 108 | loaders.append(loader) 109 | return loaders 110 | 111 | -------------------------------------------------------------------------------- /src/blip_src/data/pretrain_dataset_concated_pred_refined.py: -------------------------------------------------------------------------------- 1 | # augment the boundng box with the original images 2 | import json 3 | import os 4 | import random 5 | import pandas as pd 6 | import torch 7 | import torchvision 8 | from torch.utils.data import Dataset 9 | 10 | from PIL import Image 11 | import numpy as np 12 | from PIL import ImageFile 13 | from PIL.Image import blend as blend 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | Image.MAX_IMAGE_PIXELS = None 16 | 17 | import PIL 18 | from matplotlib.patches import Rectangle 19 | import matplotlib.pyplot as plt 20 | 21 | from data.utils import pre_caption 22 | import os,glob 23 | import math 24 | import cv2 25 | import ast 26 | import gc 27 | 28 | class pretrain_dataset(Dataset): 29 | def __init__(self, ann_file, laion_path, transform): 30 | self.img_root = "/dataset" # server 31 | self.ann_pretrain = None 32 | for f in ann_file: 33 | ann_temp = pd.read_csv(f, sep='\t', header=None) 34 | if self.ann_pretrain is None: 35 | self.ann_pretrain = ann_temp 36 | else: 37 | self.ann_pretrain = pd.concat([self.ann_pretrain, ann_temp], ignore_index=True, sort=False) 38 | 39 | self.annotation = self.ann_pretrain 40 | 41 | self.transform = transform 42 | self.normalize = torchvision.transforms.Compose([torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))]) 43 | 44 | 45 | def reload_laion(self, epoch): 46 | n = epoch%len(self.laion_files) 47 | print('loading '+self.laion_files[n]) 48 | with open(self.laion_files[n],'r') as f: 49 | self.ann_laion = json.load(f) 50 | 51 | self.annotation = self.ann_pretrain + self.ann_laion 52 | 53 | def __len__(self): 54 | return len(self.annotation) 55 | 56 | def generate_bbox_img(self, img, ann): 57 | object_num = len(ann[3]) 58 | sample_index = random.randint(0, object_num-1) 59 | object_tag = ann[3][sample_index] 60 | w, h = img.size 61 | bbox_loc = ann[2][sample_index] 62 | im = PIL.Image.new(mode="RGB", size=(w, h)) 63 | # im = np.asarray(im) 64 | im = np.array(im) 65 | im[bbox_loc[1]:bbox_loc[3], bbox_loc[0]:bbox_loc[2]] = 255 66 | im = Image.fromarray(im) 67 | return im, object_tag 68 | 69 | def find_squares(self, image): 70 | square_exist = False 71 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 72 | ret, binary = cv2.threshold(gray, 155, 255, cv2.THRESH_BINARY) 73 | contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 74 | # print(len(contours)) 75 | if len(contours) > 0: 76 | x, y, w, h = cv2.boundingRect(contours[0]) 77 | square_exist = True 78 | else: 79 | x, y, w, h = 0, 0, image.shape[0], image.shape[1] 80 | return [x, y, w, h], square_exist 81 | 82 | def generate_coord_info(self, bbox_img): 83 | bbox_img = torchvision.transforms.ToPILImage()(bbox_img) 84 | bbox_img = np.asarray(bbox_img) 85 | h, w = bbox_img.shape[:2] 86 | [x_b, y_b, w_b, h_b], square_exist = 
self.find_squares(bbox_img) 87 | cv2.rectangle(bbox_img, (x_b, y_b), (x_b+w_b, y_b+h_b), (0, 255, 0), 2) 88 | # x_b = int(x_b/h*100) 89 | # y_b = int(y_b/w*100) 90 | # w_b = int(w_b/h*100) 91 | # h_b = int(h_b/w*100) 92 | return [x_b, y_b, w_b, h_b], [h, w], square_exist 93 | 94 | def __getitem__(self, index): 95 | ann = self.annotation.iloc[index] 96 | try: 97 | image = Image.open(os.path.join(self.img_root, ann[0])).convert('RGB') 98 | if len(ann.keys()) > 2: 99 | ann[2] = json.loads(ann[2]) 100 | ann[3] = ast.literal_eval(ann[3]) 101 | if len(ann[2]) > 0: 102 | # w, h = img.size # original img size 103 | # original_image = np.asarray(image) 104 | bbox_img, object_tag = self.generate_bbox_img(image, ann) 105 | seed = np.random.randint(2147483647) # make a seed with numpy generator 106 | random.seed(seed) # apply this seed to img tranfsorms 107 | torch.manual_seed(seed) # needed for torchvision 0.7 108 | np.random.seed(seed) 109 | image = self.transform(image) 110 | # original_bbox_image = np.asarray(bbox_img) 111 | # step3: transform bbox image and getting block information 112 | random.seed(seed) # apply this seed to target tranfsorms 113 | torch.manual_seed(seed) # needed for torchvision 0.7 114 | np.random.seed(seed) 115 | bbox_img = self.transform(bbox_img) 116 | [x_b, y_b, w_b, h_b], [h, w], square_exist = self.generate_coord_info(bbox_img) 117 | 118 | # prevent memory leverage 119 | del bbox_img 120 | gc.collect() 121 | 122 | # step4: get the block_index, use x to mean there is no box or all the figure 123 | if square_exist == False: 124 | block = 'x' 125 | else: 126 | w_1 = min(int((x_b + w_b/2)/w * 3), 2) 127 | h_1 = min(int((y_b + h_b/2)/h * 3), 2) 128 | # print(w_1, h_1, w, h) 129 | block = 3*h_1 + w_1 130 | prompt_text = '. The block ' + str(block) + ' has a ' + object_tag + '.' 131 | ann[1] = pre_caption(ann[1], 22) + '. ' + pre_caption(prompt_text, 8) 132 | else: 133 | image = self.transform(image) 134 | ann[1] = pre_caption(ann[1], 30) 135 | else: 136 | image = self.transform(image) 137 | ann[1] = pre_caption(ann[1], 30) 138 | except Exception as e: 139 | print(e) 140 | return self.__getitem__(random.randint(0, self.__len__()-1)) 141 | caption = ann[1] 142 | # print(caption) 143 | image = self.normalize(image) 144 | return image, caption -------------------------------------------------------------------------------- /src/vilt_src/config.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | 3 | ex = Experiment("ViLT") 4 | 5 | 6 | def _loss_names(d): 7 | ret = { 8 | "itm": 0, 9 | "mlm": 0, 10 | "mpp": 0, 11 | "vqa": 0, 12 | "nlvr2": 0, 13 | "irtr": 0, 14 | } 15 | ret.update(d) 16 | return ret 17 | 18 | 19 | @ex.config 20 | def config(): 21 | exp_name = "vilt" 22 | seed = 0 23 | datasets = ["coco", "vg", "sbu", "gcc"] 24 | loss_names = _loss_names({"itm": 1, "mlm": 1}) 25 | batch_size = 4096 # this is a desired batch size; pl trainer will accumulate gradients when per step batch is smaller. 
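# For example, with the ViLT pre-training command in GETTING_STARTED.md
# (num_gpus=8, num_nodes=1, per_gpu_batchsize=64), the per-step batch is
# 64 * 8 * 1 = 512, so the trainer would need to accumulate gradients over
# roughly 4096 / 512 = 8 steps to reach the desired batch_size of 4096
# (illustrative arithmetic only, not part of the original file).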
26 | 27 | # Image setting 28 | train_transform_keys = ["pixelbert"] 29 | val_transform_keys = ["pixelbert"] 30 | image_size = 384 31 | max_image_len = -1 32 | patch_size = 32 33 | draw_false_image = 1 34 | image_only = False 35 | 36 | # Text Setting 37 | vqav2_label_size = 3129 38 | max_text_len = 40 39 | tokenizer = "bert-base-uncased" 40 | vocab_size = 30522 41 | whole_word_masking = False 42 | mlm_prob = 0.15 43 | draw_false_text = 0 44 | 45 | # Transformer Setting 46 | vit = "vit_base_patch32_384" 47 | hidden_size = 768 48 | num_heads = 12 49 | num_layers = 12 50 | mlp_ratio = 4 51 | drop_rate = 0.1 52 | 53 | # Optimizer Setting 54 | optim_type = "adamw" 55 | learning_rate = 1e-4 56 | weight_decay = 0.01 57 | decay_power = 1 58 | max_epoch = 100 59 | max_steps = 25000 60 | warmup_steps = 2500 61 | end_lr = 0 62 | lr_mult = 1 # multiply lr for downstream heads 63 | 64 | # Downstream Setting 65 | get_recall_metric = False 66 | 67 | # PL Trainer Setting 68 | resume_from = None 69 | fast_dev_run = False 70 | val_check_interval = 1.0 71 | test_only = False 72 | 73 | # below params varies with the environment 74 | data_root = "" 75 | log_dir = "result" 76 | per_gpu_batchsize = 0 # you should define this manually with per_gpu_batch_size=# 77 | num_gpus = 1 78 | num_nodes = 1 79 | load_path = "" 80 | num_workers = 8 81 | precision = 16 82 | 83 | 84 | # Named configs for "environment" which define gpus and nodes, and paths 85 | @ex.named_config 86 | def env_dandelin(): 87 | data_root = "/data2/dsets/dataset" 88 | log_dir = "/data2/vilt/result" 89 | num_gpus = 8 90 | num_nodes = 1 91 | 92 | 93 | # Named configs for "task" which define datasets, loss_names and desired batch_size, warmup_steps, epochs, and exp_name 94 | @ex.named_config 95 | def task_mlm_itm(): 96 | exp_name = "mlm_itm" 97 | datasets = ["cc3m"] 98 | loss_names = _loss_names({"itm": 1, "mlm": 1}) 99 | batch_size = 4096 100 | max_epoch = 10 101 | max_image_len = 200 102 | 103 | 104 | @ex.named_config 105 | def task_mlm_itm_randaug(): 106 | exp_name = "mlm_itm_randaug" 107 | datasets = ["coco", "vg", "sbu", "gcc"] 108 | train_transform_keys = ["pixelbert_randaug"] 109 | loss_names = _loss_names({"itm": 1, "mlm": 1}) 110 | batch_size = 4096 111 | max_epoch = 10 112 | max_image_len = 200 113 | 114 | @ex.named_config 115 | def task_mlm_itm_mpp(): 116 | exp_name = "mlm_itm_mpp" 117 | datasets = ["coco", "vg", "sbu", "gcc"] 118 | loss_names = _loss_names({"itm": 1, "mlm": 1, "mpp": 1}) 119 | batch_size = 4096 120 | max_epoch = 10 121 | max_image_len = 200 122 | 123 | 124 | @ex.named_config 125 | def task_finetune_nlvr2(): 126 | exp_name = "finetune_nlvr2" 127 | datasets = ["nlvr2"] 128 | loss_names = _loss_names({"nlvr2": 1}) 129 | batch_size = 128 130 | max_epoch = 10 131 | max_steps = None 132 | warmup_steps = 0.1 133 | draw_false_image = 0 134 | learning_rate = 1e-4 135 | 136 | 137 | @ex.named_config 138 | def task_finetune_nlvr2_randaug(): 139 | exp_name = "finetune_nlvr2_randaug" 140 | datasets = ["nlvr2"] 141 | train_transform_keys = ["pixelbert_randaug"] 142 | loss_names = _loss_names({"nlvr2": 1}) 143 | batch_size = 128 144 | max_epoch = 10 145 | max_steps = None 146 | warmup_steps = 0.1 147 | draw_false_image = 0 148 | learning_rate = 1e-4 149 | 150 | 151 | @ex.named_config 152 | def task_finetune_vqa(): 153 | exp_name = "finetune_vqa" 154 | datasets = ["vqa"] 155 | loss_names = _loss_names({"vqa": 1}) 156 | batch_size = 256 157 | max_epoch = 10 158 | max_steps = None 159 | warmup_steps = 0.1 160 | draw_false_image = 0 161 | 
learning_rate = 1e-4 162 | val_check_interval = 0.1 163 | lr_mult = 10 164 | 165 | 166 | @ex.named_config 167 | def task_finetune_vqa_randaug(): 168 | exp_name = "finetune_vqa_randaug" 169 | datasets = ["vqa"] 170 | train_transform_keys = ["pixelbert_randaug"] 171 | loss_names = _loss_names({"vqa": 1}) 172 | batch_size = 256 173 | max_epoch = 10 174 | max_steps = None 175 | warmup_steps = 0.1 176 | draw_false_image = 0 177 | learning_rate = 1e-4 178 | val_check_interval = 0.1 179 | lr_mult = 10 180 | 181 | 182 | @ex.named_config 183 | def task_finetune_irtr_coco(): 184 | exp_name = "finetune_irtr_coco" 185 | datasets = ["coco"] 186 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 187 | batch_size = 256 188 | max_epoch = 10 189 | max_steps = None 190 | warmup_steps = 0.1 191 | get_recall_metric = True 192 | draw_false_text = 15 193 | learning_rate = 1e-4 194 | 195 | 196 | @ex.named_config 197 | def task_finetune_irtr_coco_randaug(): 198 | exp_name = "finetune_irtr_coco_randaug" 199 | datasets = ["coco"] 200 | train_transform_keys = ["pixelbert_randaug"] 201 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 202 | batch_size = 256 203 | max_epoch = 10 204 | max_steps = None 205 | warmup_steps = 0.1 206 | get_recall_metric = True 207 | draw_false_text = 15 208 | learning_rate = 1e-4 209 | 210 | 211 | @ex.named_config 212 | def task_finetune_irtr_f30k(): 213 | exp_name = "finetune_irtr_f30k" 214 | datasets = ["f30k"] 215 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 216 | batch_size = 256 217 | max_epoch = 10 218 | max_steps = None 219 | warmup_steps = 0.1 220 | get_recall_metric = True 221 | draw_false_text = 15 222 | learning_rate = 1e-4 223 | 224 | 225 | @ex.named_config 226 | def task_finetune_irtr_f30k_randaug(): 227 | exp_name = "finetune_irtr_f30k_randaug" 228 | datasets = ["f30k"] 229 | train_transform_keys = ["pixelbert_randaug"] 230 | loss_names = _loss_names({"itm": 0.5, "irtr": 1}) 231 | batch_size = 256 232 | max_epoch = 10 233 | max_steps = None 234 | warmup_steps = 0.1 235 | get_recall_metric = True 236 | draw_false_text = 15 237 | learning_rate = 1e-4 238 | 239 | 240 | # Named configs for "etc" which are orthogonal to "env" and "task", need to be added at the end 241 | 242 | 243 | @ex.named_config 244 | def step25k(): 245 | max_epoch = 100 246 | max_steps = 25000 247 | 248 | 249 | @ex.named_config 250 | def step50k(): 251 | max_epoch = 100 252 | max_steps = 50000 253 | 254 | 255 | @ex.named_config 256 | def step100k(): 257 | max_epoch = 100 258 | max_steps = 100000 259 | 260 | 261 | @ex.named_config 262 | def step200k(): 263 | max_epoch = 200 264 | max_steps = 200000 265 | 266 | 267 | @ex.named_config 268 | def vit32_base(): 269 | vit = "vit_base_patch32_384" 270 | patch_size = 32 271 | hidden_size = 768 272 | num_heads = 12 273 | num_layers = 12 274 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PTP 2 | 3 | 8 | 9 | 10 | 11 | []( 12 | https://paperswithcode.com/sota/cross-modal-retrieval-on-coco-2014?p=position-guided-text-prompt-for-vision) 13 | 14 | 15 | []( 16 | https://paperswithcode.com/sota/image-captioning-on-coco-captions?p=position-guided-text-prompt-for-vision) 17 | 18 | 19 | []( 20 | https://paperswithcode.com/sota/zero-shot-cross-modal-retrieval-on-flickr30k?p=position-guided-text-prompt-for-vision) 21 | 22 | 23 | This repository includes implementations of the following method: 24 | 25 | - [Position-guided Text Prompt for 
Vision Language Pre-training](https://arxiv.org/abs/2212.09737) 26 | 27 | ## Introduction 28 | The goal of Position-guided Text Prompt (PTP) is to bring position information into conventional Vision-Language Pre-training (VLP) models, since current mainstream end-to-end VLP models ignore this important cue. 29 | 30 |
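To make the idea concrete, the sketch below shows how a PTP-style prompt can be appended to a caption, mirroring the block logic in `src/blip_src/data/pretrain_dataset_concated_pred_refined.py`: the image is divided into a 3x3 grid and the prompt names the block that contains a sampled object. The image size, box, and tag are made-up example values, and the real pipeline derives the block from an augmented box mask rather than from the raw coordinates, so treat this only as an illustration.

```python
def ptp_prompt(caption, image_w, image_h, bbox, object_tag):
    """Append a position-guided prompt of the form '. The block <i> has a <object>.'"""
    x1, y1, x2, y2 = bbox                   # box in pixel coordinates
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2   # box centre
    col = min(int(cx / image_w * 3), 2)     # 3x3 grid: columns 0..2
    row = min(int(cy / image_h * 3), 2)     # rows 0..2
    block = 3 * row + col                   # block ids 0..8, row-major
    return caption + '. The block ' + str(block) + ' has a ' + object_tag + '.'

# Example values only (the image size is assumed for illustration):
print(ptp_prompt('a white table with two white chairs and a couch',
                 736, 900, [195, 322, 627, 899], 'table'))
# -> 'a white table with two white chairs and a couch. The block 7 has a table.'
```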