├── utils
│   ├── __init__.py
│   ├── const.py
│   ├── training_signal_annealing.py
│   ├── misc.py
│   ├── logger.py
│   ├── save.py
│   ├── itm_eval.py
│   └── distributed.py
├── figures
│   ├── IAIS.png
│   ├── singular_alignment.gif
│   └── distributed_alignment.gif
├── optim
│   ├── __init__.py
│   ├── misc.py
│   ├── sched.py
│   └── adamw.py
├── scripts
│   ├── convert_ckpt.py
│   ├── extract_imgfeat.sh
│   ├── create_imgdb.sh
│   ├── launch_butd_container.sh
│   ├── create_txtdb.sh
│   ├── download_itm.sh
│   └── convert_imgdir.py
├── config
│   ├── uniter-base.json
│   ├── uniter-large.json
│   ├── train-itm-flickr-base-8gpu-hn.json
│   ├── train-itm-flickr-large-8gpu-hn.json
│   ├── train-itm-coco-base-8gpu-hn.json
│   └── train-itm-coco-large-8gpu-hn.json
├── launch_container.sh
├── LICENSE
├── data
│   ├── __init__.py
│   ├── sampler.py
│   ├── mlm.py
│   ├── loader.py
│   ├── mrm.py
│   └── data.py
├── Dockerfile
├── .gitignore
├── model
│   ├── ot.py
│   ├── itm.py
│   ├── pretrain.py
│   ├── layer.py
│   ├── attention.py
│   └── model.py
├── prepro.py
├── inf_itm.py
├── README.md
└── train_itm_hard_negatives.py

/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figures/IAIS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/IAIS/HEAD/figures/IAIS.png
--------------------------------------------------------------------------------
/figures/singular_alignment.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/IAIS/HEAD/figures/singular_alignment.gif
--------------------------------------------------------------------------------
/figures/distributed_alignment.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lancopku/IAIS/HEAD/figures/distributed_alignment.gif
--------------------------------------------------------------------------------
/utils/const.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | constants
6 | """
7 | IMG_DIM = 2048
8 | IMG_LABEL_DIM = 1601
9 | BUCKET_SIZE = 8192
10 |
--------------------------------------------------------------------------------
/optim/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 | 5 | """ 6 | from .sched import noam_schedule, warmup_linear, vqa_schedule, get_lr_sched 7 | from .adamw import AdamW 8 | -------------------------------------------------------------------------------- /scripts/convert_ckpt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | 4 | import torch 5 | 6 | bert_ckpt, output_ckpt = sys.argv[1:] 7 | 8 | bert = torch.load(bert_ckpt) 9 | uniter = OrderedDict() 10 | for k, v in bert.items(): 11 | uniter[k.replace('bert', 'uniter')] = v 12 | 13 | torch.save(uniter, output_ckpt) 14 | -------------------------------------------------------------------------------- /config/uniter-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 28996 13 | } 14 | -------------------------------------------------------------------------------- /config/uniter-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 1024, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 4096, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 16, 10 | "num_hidden_layers": 24, 11 | "type_vocab_size": 2, 12 | "vocab_size": 28996 13 | } 14 | -------------------------------------------------------------------------------- /scripts/extract_imgfeat.sh: -------------------------------------------------------------------------------- 1 | python tools/generate_tsv_gt.py --gpu 0,1,2,3,4,5,6,7 \ 2 | --cfg experiments/cfgs/faster_rcnn_end2end_resnet.yml \ 3 | --def models/vg/ResNet-101/faster_rcnn_end2end_final/test_gt.prototxt \ 4 | --out /src/flickr30k_entities_resnet101_faster_rcnn.tsv \ 5 | --net data/faster_rcnn_models/resnet101_faster_rcnn_final.caffemodel \ 6 | --split flickr30k_entities \ 7 | --prefix flickr30k -------------------------------------------------------------------------------- /scripts/create_imgdb.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | IMG_NPY=$1 5 | OUT_DIR=$2 6 | 7 | set -e 8 | 9 | echo "converting image features ..." 10 | if [ ! -d $OUT_DIR ]; then 11 | mkdir -p $OUT_DIR 12 | fi 13 | NAME=$(basename $IMG_NPY) 14 | docker run --ipc=host --rm -it \ 15 | --mount src=$(pwd),dst=/src,type=bind \ 16 | --mount src=$OUT_DIR,dst=/img_db,type=bind \ 17 | --mount src=$IMG_NPY,dst=/$NAME,type=bind,readonly \ 18 | -w /src chenrocks/uniter \ 19 | python scripts/convert_imgdir.py --img_dir /$NAME --output /img_db --keep_all 20 | 21 | echo "done" -------------------------------------------------------------------------------- /launch_container.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | TXT_DB=$1 5 | IMG_DIR=$2 6 | OUTPUT=$3 7 | PRETRAIN_DIR=$4 8 | 9 | if [ -z $CUDA_VISIBLE_DEVICES ]; then 10 | CUDA_VISIBLE_DEVICES='all' 11 | fi 12 | 13 | 14 | docker run --gpus '"'device=$CUDA_VISIBLE_DEVICES'"' --ipc=host --rm -it \ 15 | --mount src=$(pwd),dst=/src,type=bind \ 16 | --mount src=$OUTPUT,dst=/storage,type=bind \ 17 | --mount src=$PRETRAIN_DIR,dst=/pretrain,type=bind,readonly \ 18 | --mount src=$TXT_DB,dst=/txt,type=bind,readonly \ 19 | --mount src=$IMG_DIR,dst=/img,type=bind,readonly \ 20 | -e NVIDIA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ 21 | -w /src chenrocks/uniter 22 | -------------------------------------------------------------------------------- /scripts/launch_butd_container.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | IMG_DIR=$1 5 | OUT_DIR=$2 6 | ANO_DIR=$3 7 | 8 | set -e 9 | 10 | if [ -z $CUDA_VISIBLE_DEVICES ]; then 11 | CUDA_VISIBLE_DEVICES='all' 12 | fi 13 | 14 | echo "extracting image features..." 15 | if [ ! -d $OUT_DIR ]; then 16 | mkdir -p $OUT_DIR 17 | fi 18 | 19 | docker run --gpus '"'device=$CUDA_VISIBLE_DEVICES'"' --ipc=host --rm -it \ 20 | --mount src=$IMG_DIR,dst=/img,type=bind,readonly \ 21 | --mount src=$OUT_DIR,dst=/output,type=bind \ 22 | --mount src=$ANO_DIR,dst=/ano,type=bind \ 23 | -e NVIDIA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ 24 | -w /src chenrocks/butd-caffe:nlvr2 \ 25 | # bash -c "python tools/generate_npz.py --gpu 0" 26 | 27 | #echo "done" 28 | -------------------------------------------------------------------------------- /utils/training_signal_annealing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_tsa_threshold(schedule, global_step, num_train_steps): 5 | training_progress = torch.tensor(global_step, dtype=torch.float) / torch.tensor(num_train_steps, dtype=torch.float) 6 | if schedule == "linear_schedule": 7 | threshold = training_progress 8 | elif schedule == "exp_schedule": 9 | scale = 5 10 | threshold = torch.exp((training_progress - 1) * scale) 11 | # [exp(-5), exp(0)] = [1e-2, 1] 12 | elif schedule == "log_schedule": 13 | scale = 5 14 | # [1 - exp(0), 1 - exp(-5)] = [0, 0.99] 15 | threshold = 1 - torch.exp((-training_progress) * scale) 16 | else: 17 | raise ValueError('schedule must in [linear_schedule, exp_schedule, log_schedule]') 18 | return threshold -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Microsoft Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/optim/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | Misc lr helper
6 | """
7 | from torch.optim import Adam, Adamax
8 |
9 | from .adamw import AdamW
10 |
11 |
12 | def build_optimizer(model, opts):
13 |     param_optimizer = list(model.named_parameters())
14 |     no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
15 |     optimizer_grouped_parameters = [
16 |         {'params': [p for n, p in param_optimizer
17 |                     if not any(nd in n for nd in no_decay)],
18 |          'weight_decay': opts.weight_decay},
19 |         {'params': [p for n, p in param_optimizer
20 |                     if any(nd in n for nd in no_decay)],
21 |          'weight_decay': 0.0}
22 |     ]
23 |
24 |     # currently Adam variants only
25 |     if opts.optim == 'adam':
26 |         OptimCls = Adam
27 |     elif opts.optim == 'adamax':
28 |         OptimCls = Adamax
29 |     elif opts.optim == 'adamw':
30 |         OptimCls = AdamW
31 |     else:
32 |         raise ValueError('invalid optimizer')
33 |     optimizer = OptimCls(optimizer_grouped_parameters,
34 |                          lr=opts.learning_rate, betas=opts.betas)
35 |     return optimizer
36 |
--------------------------------------------------------------------------------
/config/train-itm-flickr-base-8gpu-hn.json:
--------------------------------------------------------------------------------
1 | {
2 |     "compressed_db": false,
3 |     "checkpoint": "/pretrain/uniter-base.pt",
4 |     "output_dir": "/storage/itm/flickr/hard_neg",
5 |     "max_txt_len": 60,
6 |     "conf_th": 0.2,
7 |     "max_bb": 100,
8 |     "min_bb": 10,
9 |     "num_bb": 36,
10 |     "train_batch_size": 8,
11 |     "negative_size": 399,
12 |     "hard_neg_size": 31,
13 |     "inf_minibatch_size": 400,
14 |     "margin": 0.2,
15 |     "learning_rate": 5e-05,
16 |     "valid_steps": 500,
17 |     "num_train_steps": 5000,
18 |     "optim": "adamw",
19 |     "betas": [
20 |         0.9,
21 |         0.98
22 |     ],
23 |     "dropout": 0.1,
24 |     "weight_decay": 0.01,
25 |     "grad_norm": 2.0,
26 |     "warmup_steps": 500,
27 |     "seed": 42,
28 |     "full_val": true,
29 |     "fp16": true,
30 |     "n_workers": 4,
31 |     "pin_mem": true,
32 |     "train_txt_dbs": [
33 |         "/txt/itm_flickr30k_train.db"
34 |     ],
35 |     "train_img_dbs": [
36 |         "/img/flickr30k/"
37 |     ],
38 |     "val_txt_db": "/txt/itm_flickr30k_val.db",
39 |     "val_img_db": "/img/flickr30k/",
40 |     "test_txt_db": "/txt/itm_flickr30k_test.db",
41 |     "test_img_db": "/img/flickr30k/",
42 |     "model_config": "/src/config/uniter-base.json"
43 | }
44 |
--------------------------------------------------------------------------------
/config/train-itm-flickr-large-8gpu-hn.json:
--------------------------------------------------------------------------------
1 | {
2 |     "compressed_db": false,
3 |     "checkpoint": "/pretrain/uniter-large.pt",
4 |     "output_dir": "/storage/itm/flickr/large",
5 |     "max_txt_len": 60,
6 |     "conf_th": 0.2,
7 |     "max_bb": 100,
8 |     "min_bb": 10,
9 |     "num_bb": 36,
10 |     "train_batch_size": 8,
11 |     "negative_size": 399,
12 |     "hard_neg_size": 31,
13 |     "inf_minibatch_size": 400,
14 |     "margin": 0.2,
15 |     "learning_rate": 5e-05,
16 |     "valid_steps": 500,
17 |     "num_train_steps": 5000,
18 |     "optim": "adamw",
19 |     "betas": [
20 |         0.9,
21 |         0.98
22 |     ],
23 |     "dropout": 0.1,
24 |     "weight_decay": 0.01,
25 |     "grad_norm": 2.0,
26 | "warmup_steps": 500, 27 | "seed": 42, 28 | "full_val": true, 29 | "fp16": true, 30 | "n_workers": 4, 31 | "pin_mem": true, 32 | "train_txt_dbs": [ 33 | "/txt/itm_flickr30k_train.db" 34 | ], 35 | "train_img_dbs": [ 36 | "/img/flickr30k/" 37 | ], 38 | "val_txt_db": "/txt/itm_flickr30k_val.db", 39 | "val_img_db": "/img/flickr30k/", 40 | "test_txt_db": "/txt/itm_flickr30k_test.db", 41 | "test_img_db": "/img/flickr30k/", 42 | "model_config": "/src/config/uniter-large.json" 43 | } 44 | -------------------------------------------------------------------------------- /config/train-itm-coco-base-8gpu-hn.json: -------------------------------------------------------------------------------- 1 | { 2 | "compressed_db": false, 3 | "checkpoint": "/pretrain/uniter-base.pt", 4 | "output_dir": "/storage/itm/coco/hard_neg", 5 | "max_txt_len": 60, 6 | "conf_th": 0.2, 7 | "max_bb": 100, 8 | "min_bb": 10, 9 | "num_bb": 36, 10 | "train_batch_size": 8, 11 | "negative_size": 399, 12 | "hard_neg_size": 31, 13 | "inf_minibatch_size": 400, 14 | "margin": 0.2, 15 | "learning_rate": 5e-05, 16 | "valid_steps": 500, 17 | "num_train_steps": 5000, 18 | "optim": "adamw", 19 | "betas": [ 20 | 0.9, 21 | 0.98 22 | ], 23 | "dropout": 0.1, 24 | "weight_decay": 0.01, 25 | "grad_norm": 2.0, 26 | "warmup_steps": 500, 27 | "seed": 42, 28 | "full_val": true, 29 | "fp16": true, 30 | "n_workers": 4, 31 | "pin_mem": true, 32 | "train_txt_dbs": [ 33 | "/txt/itm_coco_train.db", 34 | "/txt/itm_coco_restval.db" 35 | ], 36 | "train_img_dbs": [ 37 | "/img/coco_train2014/", 38 | "/img/coco_val2014" 39 | ], 40 | "val_txt_db": "/txt/itm_coco_val.db", 41 | "val_img_db": "/img/coco_val2014/", 42 | "test_txt_db": "/txt/itm_coco_test.db", 43 | "test_img_db": "/img/coco_val2014/", 44 | "model_config": "/src/config/uniter-base.json" 45 | } 46 | -------------------------------------------------------------------------------- /config/train-itm-coco-large-8gpu-hn.json: -------------------------------------------------------------------------------- 1 | { 2 | "compressed_db": false, 3 | "checkpoint": "/pretrain/uniter-large.pt", 4 | "output_dir": "/storage/itm/coco/large", 5 | "max_txt_len": 60, 6 | "conf_th": 0.2, 7 | "max_bb": 100, 8 | "min_bb": 10, 9 | "num_bb": 36, 10 | "train_batch_size": 8, 11 | "negative_size": 399, 12 | "hard_neg_size": 31, 13 | "inf_minibatch_size": 400, 14 | "margin": 0.2, 15 | "learning_rate": 3e-05, 16 | "valid_steps": 500, 17 | "num_train_steps": 5000, 18 | "optim": "adamw", 19 | "betas": [ 20 | 0.9, 21 | 0.98 22 | ], 23 | "dropout": 0.1, 24 | "weight_decay": 0.01, 25 | "grad_norm": 2.0, 26 | "warmup_steps": 500, 27 | "seed": 42, 28 | "full_val": false, 29 | "fp16": true, 30 | "n_workers": 4, 31 | "pin_mem": true, 32 | "train_txt_dbs": [ 33 | "/txt/itm_coco_train.db", 34 | "/txt/itm_coco_restval.db" 35 | ], 36 | "train_img_dbs": [ 37 | "/img/coco_train2014/", 38 | "/img/coco_val2014" 39 | ], 40 | "val_txt_db": "/txt/itm_coco_val.db", 41 | "val_img_db": "/img/coco_val2014/", 42 | "test_txt_db": "/txt/itm_coco_test.db", 43 | "test_img_db": "/img/coco_val2014/", 44 | "model_config": "/src/config/uniter-large.json" 45 | } 46 | -------------------------------------------------------------------------------- /scripts/create_txtdb.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | OUT_DIR=$1 5 | ANN_DIR=$2 6 | 7 | set -e 8 | 9 | URL='https://raw.githubusercontent.com/lil-lab/nlvr/master/nlvr2/data' 10 | if [ ! -d $OUT_DIR ]; then 11 | mkdir -p $OUT_DIR 12 | fi 13 | if [ ! -d $ANN_DIR ]; then 14 | mkdir -p $ANN_DIR 15 | fi 16 | 17 | BLOB='https://convaisharables.blob.core.windows.net/uniter' 18 | MISSING=$BLOB/ann/missing_nlvr2_imgs.json 19 | if [ ! -f $ANN_DIR/missing_nlvr2_imgs.json ]; then 20 | wget $MISSING -O $ANN_DIR/missing_nlvr2_imgs.json 21 | fi 22 | 23 | for SPLIT in 'train' 'dev' 'test1'; do 24 | if [ ! -f $ANN_DIR/$SPLIT.json ]; then 25 | echo "downloading ${SPLIT} annotations..." 26 | wget $URL/$SPLIT.json -O $ANN_DIR/$SPLIT.json 27 | fi 28 | 29 | echo "preprocessing ${SPLIT} annotations..." 30 | docker run --ipc=host --rm -it \ 31 | --mount src=$(pwd),dst=/src,type=bind \ 32 | --mount src=$OUT_DIR,dst=/txt_db,type=bind \ 33 | --mount src=$ANN_DIR,dst=/ann,type=bind,readonly \ 34 | -w /src chenrocks/uniter \ 35 | python prepro.py --annotation /ann/$SPLIT.json \ 36 | --missing_imgs /ann/missing_nlvr2_imgs.json \ 37 | --output /txt_db/nlvr2_${SPLIT}.db 38 | done 39 | 40 | echo "done" 41 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | """ 6 | from .data import (TxtTokLmdb, DetectFeatLmdb, 7 | ImageLmdbGroup, ConcatDatasetWithLens) 8 | from .sampler import TokenBucketSampler 9 | from .loader import PrefetchLoader, MetaLoader 10 | from .vqa import VqaDataset, VqaEvalDataset, vqa_collate, vqa_eval_collate 11 | from .ve import VeDataset, VeEvalDataset, ve_collate, ve_eval_collate 12 | from .nlvr2 import (Nlvr2PairedDataset, Nlvr2PairedEvalDataset, 13 | Nlvr2TripletDataset, Nlvr2TripletEvalDataset, 14 | nlvr2_paired_collate, nlvr2_paired_eval_collate, 15 | nlvr2_triplet_collate, nlvr2_triplet_eval_collate) 16 | from .itm import (TokenBucketSamplerForItm, ItmDataset, 17 | itm_collate, itm_ot_collate, 18 | ItmRankDataset, ItmValDataset, ItmEvalDataset, 19 | ItmRankDatasetHardNegFromImage, 20 | ItmRankDatasetHardNegFromText, 21 | itm_rank_collate, itm_val_collate, itm_eval_collate, 22 | itm_rank_hn_collate) 23 | from .mlm import MlmDataset, mlm_collate 24 | from .mrm import MrfrDataset, MrcDataset, mrfr_collate, mrc_collate 25 | from .vcr import (VcrTxtTokLmdb, VcrDataset, VcrEvalDataset, 26 | vcr_collate, vcr_eval_collate) 27 | -------------------------------------------------------------------------------- /optim/sched.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | optimizer learning rate scheduling helpers 6 | """ 7 | from math import ceil 8 | 9 | 10 | def noam_schedule(step, warmup_step=4000): 11 | """ original Transformer schedule""" 12 | if step <= warmup_step: 13 | return step / warmup_step 14 | return (warmup_step ** 0.5) * (step ** -0.5) 15 | 16 | 17 | def warmup_linear(step, warmup_step, tot_step): 18 | """ BERT schedule """ 19 | if step < warmup_step: 20 | return step / warmup_step 21 | return max(0, (tot_step-step)/(tot_step-warmup_step)) 22 | 23 | 24 | def vqa_schedule(step, warmup_interval, decay_interval, 25 | decay_start, decay_rate): 26 | """ VQA schedule from MCAN """ 27 | if step < warmup_interval: 28 | return 1/4 29 | elif step < 2 * warmup_interval: 30 | return 2/4 31 | elif step < 3 * warmup_interval: 32 | return 3/4 33 | elif step >= decay_start: 34 | num_decay = ceil((step - decay_start) / decay_interval) 35 | return decay_rate ** num_decay 36 | else: 37 | return 1 38 | 39 | 40 | def get_lr_sched(global_step, opts): 41 | # learning rate scheduling 42 | lr_this_step = opts.learning_rate * warmup_linear( 43 | global_step, opts.warmup_steps, opts.num_train_steps) 44 | if lr_this_step <= 0: 45 | lr_this_step = 1e-8 46 | return lr_this_step 47 | -------------------------------------------------------------------------------- /scripts/download_itm.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | DOWNLOAD=$1 5 | 6 | for FOLDER in 'img_db' 'txt_db' 'pretrained' 'finetune'; do 7 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 8 | mkdir -p $DOWNLOAD/$FOLDER 9 | fi 10 | done 11 | 12 | BLOB='https://convaisharables.blob.core.windows.net/uniter' 13 | 14 | # image dbs 15 | for SPLIT in 'train2014' 'val2014'; do 16 | if [ ! -d $DOWNLOAD/img_db/coco_$SPLIT ] ; then 17 | wget $BLOB/img_db/coco_$SPLIT.tar -P $DOWNLOAD/img_db/ 18 | tar -xvf $DOWNLOAD/img_db/coco_$SPLIT.tar -C $DOWNLOAD/img_db 19 | fi 20 | done 21 | if [ ! -d $DOWNLOAD/img_db/flickr30k ] ; then 22 | wget $BLOB/img_db/flickr30k.tar -P $DOWNLOAD/img_db/ 23 | tar -xvf $DOWNLOAD/img_db/flickr30k.tar -C $DOWNLOAD/img_db 24 | fi 25 | 26 | # text dbs 27 | for SPLIT in 'train' 'restval' 'val' 'test'; do 28 | wget $BLOB/txt_db/itm_coco_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 29 | tar -xvf $DOWNLOAD/txt_db/itm_coco_$SPLIT.db.tar -C $DOWNLOAD/txt_db 30 | done 31 | for SPLIT in 'train' 'val' 'test'; do 32 | wget $BLOB/txt_db/itm_flickr30k_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 33 | tar -xvf $DOWNLOAD/txt_db/itm_flickr30k_$SPLIT.db.tar -C $DOWNLOAD/txt_db 34 | done 35 | 36 | for MODEL in uniter-base uniter-large; do 37 | if [ ! 
-f $DOWNLOAD/pretrained/$MODEL.pt ] ; then 38 | wget $BLOB/pretrained/$MODEL.pt -P $DOWNLOAD/pretrained/ 39 | fi 40 | done -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:19.05-py3 2 | 3 | # basic python packages 4 | RUN pip install pytorch-pretrained-bert==0.6.2 \ 5 | tensorboardX==1.7 ipdb==0.12 lz4==2.1.9 lmdb==0.97 6 | 7 | ####### horovod for multi-GPU (distributed) training ####### 8 | 9 | # update OpenMPI to avoid horovod bug 10 | RUN rm -r /usr/local/mpi &&\ 11 | wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.4.tar.gz &&\ 12 | gunzip -c openmpi-3.1.4.tar.gz | tar xf - &&\ 13 | cd openmpi-3.1.4 &&\ 14 | ./configure --prefix=/usr/local/mpi --enable-orterun-prefix-by-default \ 15 | --with-verbs --disable-getpwuid &&\ 16 | make -j$(nproc) all && make install &&\ 17 | ldconfig &&\ 18 | cd - && rm -r openmpi-3.1.4 && rm openmpi-3.1.4.tar.gz 19 | 20 | ENV OPENMPI_VERSION=3.1.4 21 | 22 | # horovod 23 | RUN HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITH_PYTORCH=1 \ 24 | pip install --no-cache-dir horovod==0.16.4 &&\ 25 | ldconfig 26 | 27 | # ssh 28 | RUN apt-get update &&\ 29 | apt-get install -y --no-install-recommends openssh-client openssh-server &&\ 30 | mkdir -p /var/run/sshd 31 | 32 | # Allow OpenSSH to talk to containers without asking for confirmation 33 | RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ 34 | echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ 35 | mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config 36 | 37 | 38 | WORKDIR /src 39 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Misc utilities 6 | """ 7 | import json 8 | import random 9 | import sys 10 | 11 | import torch 12 | import numpy as np 13 | 14 | from utils.logger import LOGGER 15 | 16 | 17 | class NoOp(object): 18 | """ useful for distributed training No-Ops """ 19 | def __getattr__(self, name): 20 | return self.noop 21 | 22 | def noop(self, *args, **kwargs): 23 | return 24 | 25 | 26 | def parse_with_config(parser): 27 | args = parser.parse_args() 28 | if args.config is not None: 29 | config_args = json.load(open(args.config)) 30 | override_keys = {arg[2:].split('=')[0] for arg in sys.argv[1:] 31 | if arg.startswith('--')} 32 | for k, v in config_args.items(): 33 | if k not in override_keys: 34 | setattr(args, k, v) 35 | del args.config 36 | return args 37 | 38 | 39 | VE_ENT2IDX = { 40 | 'contradiction': 0, 41 | 'entailment': 1, 42 | 'neutral': 2 43 | } 44 | 45 | VE_IDX2ENT = { 46 | 0: 'contradiction', 47 | 1: 'entailment', 48 | 2: 'neutral' 49 | } 50 | 51 | 52 | class Struct(object): 53 | def __init__(self, dict_): 54 | self.__dict__.update(dict_) 55 | 56 | 57 | def set_dropout(model, drop_p): 58 | for name, module in model.named_modules(): 59 | # we might want to tune dropout for smaller dataset 60 | if isinstance(module, torch.nn.Dropout): 61 | if module.p != drop_p: 62 | module.p = drop_p 63 | LOGGER.info(f'{name} set to {drop_p}') 64 | 65 | 66 | def set_random_seed(seed): 67 | random.seed(seed) 68 | np.random.seed(seed) 69 | torch.manual_seed(seed) 70 | torch.cuda.manual_seed_all(seed) 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ctags 2 | tags 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | # pycharm 110 | .idea 111 | .DS_Store -------------------------------------------------------------------------------- /data/sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | sampler for length bucketing (batch by tokens) 6 | """ 7 | import random 8 | 9 | from torch.utils.data import Sampler 10 | from cytoolz import partition_all 11 | 12 | 13 | class TokenBucketSampler(Sampler): 14 | def __init__(self, lens, bucket_size, batch_size, 15 | droplast=False, size_multiple=8): 16 | self._lens = lens 17 | self._max_tok = batch_size 18 | self._bucket_size = bucket_size 19 | self._droplast = droplast 20 | self._size_mul = size_multiple 21 | 22 | def _create_ids(self): 23 | return list(range(len(self._lens))) 24 | 25 | def _sort_fn(self, i): 26 | return self._lens[i] 27 | 28 | def __iter__(self): 29 | ids = self._create_ids() 30 | random.shuffle(ids) 31 | buckets = [sorted(ids[i:i+self._bucket_size], 32 | key=self._sort_fn, reverse=True) 33 | for i in range(0, len(ids), self._bucket_size)] 34 | # fill batches until max_token (include padding) 35 | batches = [] 36 | for bucket in buckets: 37 | max_len = 0 38 | batch_indices = [] 39 | for indices in partition_all(self._size_mul, bucket): 40 | max_len = max(max_len, max(self._lens[i] for i in indices)) 41 | if (max_len * (len(batch_indices) + self._size_mul) 42 | > self._max_tok): 43 | if not batch_indices: 44 | raise ValueError( 45 | "max_tokens too small / max_seq_len too long") 46 | assert len(batch_indices) % self._size_mul == 0 47 | batches.append(batch_indices) 48 | batch_indices = list(indices) 49 | else: 50 | batch_indices.extend(indices) 51 | if not self._droplast and batch_indices: 52 | batches.append(batch_indices) 53 | random.shuffle(batches) 54 | return iter(batches) 55 | 56 | def __len__(self): 57 | raise ValueError("NOT supported. " 58 | "This has some randomness across epochs") 59 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 |
5 | helper for logging
6 | NOTE: loggers are global objects; use with caution
7 | """
8 | import logging
9 | import math
10 |
11 | import tensorboardX
12 |
13 |
14 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
15 | _DATE_FMT = '%m/%d/%Y %H:%M:%S'
16 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO)
17 | LOGGER = logging.getLogger('__main__')  # this is the global logger
18 |
19 |
20 | def add_log_to_file(log_path):
21 |     fh = logging.FileHandler(log_path)
22 |     formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT)
23 |     fh.setFormatter(formatter)
24 |     LOGGER.addHandler(fh)
25 |
26 |
27 | class TensorboardLogger(object):
28 |     def __init__(self):
29 |         self._logger = None
30 |         self._global_step = 0
31 |
32 |     def create(self, path):
33 |         self._logger = tensorboardX.SummaryWriter(path)
34 |
35 |     def noop(self, *args, **kwargs):
36 |         return
37 |
38 |     def step(self):
39 |         self._global_step += 1
40 |
41 |     @property
42 |     def global_step(self):
43 |         return self._global_step
44 |
45 |     def log_scaler_dict(self, log_dict, prefix=''):
46 |         """ log a dictionary of scalar values"""
47 |         if self._logger is None:
48 |             return
49 |         if prefix:
50 |             prefix = f'{prefix}_'
51 |         for name, value in log_dict.items():
52 |             if isinstance(value, dict):
53 |                 self.log_scaler_dict(value,
54 |                                      prefix=f'{prefix}{name}')
55 |             else:
56 |                 self._logger.add_scalar(f'{prefix}{name}', value,
57 |                                         self._global_step)
58 |
59 |     def __getattr__(self, name):
60 |         if self._logger is None:
61 |             return self.noop
62 |         return self._logger.__getattribute__(name)
63 |
64 |
65 | TB_LOGGER = TensorboardLogger()
66 |
67 |
68 | class RunningMeter(object):
69 |     """ running meter of a scalar value
70 |     (useful for monitoring training loss)
71 |     """
72 |     def __init__(self, name, val=None, smooth=0.99):
73 |         self._name = name
74 |         self._sm = smooth
75 |         self._val = val
76 |
77 |     def __call__(self, value):
78 |         val = (value if self._val is None
79 |                else value*(1-self._sm) + self._val*self._sm)
80 |         if not math.isnan(val):
81 |             self._val = val
82 |
83 |     def __str__(self):
84 |         return f'{self._name}: {self._val:.4f}'
85 |
86 |     @property
87 |     def val(self):
88 |         if self._val is None:
89 |             return 0
90 |         return self._val
91 |
92 |     @property
93 |     def name(self):
94 |         return self._name
95 |
--------------------------------------------------------------------------------
/utils/save.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 | 5 | saving utilities 6 | """ 7 | import json 8 | import os 9 | from os.path import abspath, dirname, exists, join 10 | import subprocess 11 | 12 | import torch 13 | 14 | from utils.logger import LOGGER 15 | 16 | 17 | def save_training_meta(args): 18 | if args.rank > 0: 19 | return 20 | 21 | if not exists(args.output_dir): 22 | os.makedirs(join(args.output_dir, 'log')) 23 | os.makedirs(join(args.output_dir, 'ckpt')) 24 | 25 | with open(join(args.output_dir, 'log', 'hps.json'), 'w') as writer: 26 | json.dump(vars(args), writer, indent=4) 27 | model_config = json.load(open(args.model_config)) 28 | with open(join(args.output_dir, 'log', 'model.json'), 'w') as writer: 29 | json.dump(model_config, writer, indent=4) 30 | # git info 31 | try: 32 | LOGGER.info("Waiting on git info....") 33 | c = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"], 34 | timeout=10, stdout=subprocess.PIPE) 35 | git_branch_name = c.stdout.decode().strip() 36 | LOGGER.info("Git branch: %s", git_branch_name) 37 | c = subprocess.run(["git", "rev-parse", "HEAD"], 38 | timeout=10, stdout=subprocess.PIPE) 39 | git_sha = c.stdout.decode().strip() 40 | LOGGER.info("Git SHA: %s", git_sha) 41 | git_dir = abspath(dirname(__file__)) 42 | git_status = subprocess.check_output( 43 | ['git', 'status', '--short'], 44 | cwd=git_dir, universal_newlines=True).strip() 45 | with open(join(args.output_dir, 'log', 'git_info.json'), 46 | 'w') as writer: 47 | json.dump({'branch': git_branch_name, 48 | 'is_dirty': bool(git_status), 49 | 'status': git_status, 50 | 'sha': git_sha}, 51 | writer, indent=4) 52 | except subprocess.TimeoutExpired as e: 53 | LOGGER.exception(e) 54 | LOGGER.warn("Git info not found. Moving right along...") 55 | 56 | 57 | class ModelSaver(object): 58 | def __init__(self, output_dir, prefix='model_step', suffix='pt'): 59 | self.output_dir = output_dir 60 | self.prefix = prefix 61 | self.suffix = suffix 62 | 63 | def save(self, model, step, optimizer=None): 64 | output_model_file = join(self.output_dir, 65 | f"{self.prefix}_{step}.{self.suffix}") 66 | state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v 67 | for k, v in model.state_dict().items()} 68 | torch.save(state_dict, output_model_file) 69 | if optimizer is not None: 70 | dump = {'step': step, 'optimizer': optimizer.state_dict()} 71 | if hasattr(optimizer, '_amp_stash'): 72 | pass # TODO fp16 optimizer 73 | torch.save(dump, f'{self.output_dir}/train_state_{step}.pt') 74 | -------------------------------------------------------------------------------- /model/ot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Wasserstein Distance (Optimal Transport) 6 | """ 7 | import torch 8 | from torch.nn import functional as F 9 | 10 | 11 | def cost_matrix_cosine(x, y, eps=1e-5): 12 | """ Compute cosine distnace across every pairs of x, y (batched) 13 | [B, L_x, D] [B, L_y, D] -> [B, Lx, Ly]""" 14 | assert x.dim() == y.dim() 15 | assert x.size(0) == y.size(0) 16 | assert x.size(2) == y.size(2) 17 | x_norm = F.normalize(x, p=2, dim=-1, eps=eps) 18 | y_norm = F.normalize(y, p=2, dim=-1, eps=eps) 19 | cosine_sim = x_norm.matmul(y_norm.transpose(1, 2)) 20 | cosine_dist = 1 - cosine_sim 21 | return cosine_dist 22 | 23 | 24 | def trace(x): 25 | """ compute trace of input tensor (batched) """ 26 | b, m, n = x.size() 27 | assert m == n 28 | mask = torch.eye(n, dtype=torch.uint8, device=x.device 29 | ).unsqueeze(0).expand_as(x) 30 | trace = x.masked_select(mask).contiguous().view( 31 | b, n).sum(dim=-1, keepdim=False) 32 | return trace 33 | 34 | 35 | @torch.no_grad() 36 | def ipot(C, x_len, x_pad, y_len, y_pad, joint_pad, beta, iteration, k): 37 | """ [B, M, N], [B], [B, M], [B], [B, N], [B, M, N]""" 38 | b, m, n = C.size() 39 | sigma = torch.ones(b, m, dtype=C.dtype, device=C.device 40 | ) / x_len.unsqueeze(1) 41 | T = torch.ones(b, n, m, dtype=C.dtype, device=C.device) 42 | A = torch.exp(-C.transpose(1, 2)/beta) 43 | 44 | # mask padded positions 45 | sigma.masked_fill_(x_pad, 0) 46 | joint_pad = joint_pad.transpose(1, 2) 47 | T.masked_fill_(joint_pad, 0) 48 | A.masked_fill_(joint_pad, 0) 49 | 50 | # broadcastable lengths 51 | x_len = x_len.unsqueeze(1).unsqueeze(2) 52 | y_len = y_len.unsqueeze(1).unsqueeze(2) 53 | 54 | # mask to zero out padding in delta and sigma 55 | x_mask = (x_pad.to(C.dtype) * 1e4).unsqueeze(1) 56 | y_mask = (y_pad.to(C.dtype) * 1e4).unsqueeze(1) 57 | 58 | for _ in range(iteration): 59 | Q = A * T # bs * n * m 60 | sigma = sigma.view(b, m, 1) 61 | for _ in range(k): 62 | delta = 1 / (y_len * Q.matmul(sigma).view(b, 1, n) + y_mask) 63 | sigma = 1 / (x_len * delta.matmul(Q) + x_mask) 64 | T = delta.view(b, n, 1) * Q * sigma 65 | T.masked_fill_(joint_pad, 0) 66 | return T 67 | 68 | 69 | def optimal_transport_dist(txt_emb, img_emb, txt_pad, img_pad, 70 | beta=0.5, iteration=50, k=1): 71 | """ [B, M, D], [B, N, D], [B, M], [B, N]""" 72 | cost = cost_matrix_cosine(txt_emb, img_emb) 73 | # mask the padded inputs 74 | joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2) 75 | cost.masked_fill_(joint_pad, 0) 76 | 77 | txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False) 78 | ).to(dtype=cost.dtype) 79 | img_len = (img_pad.size(1) - img_pad.sum(dim=1, keepdim=False) 80 | ).to(dtype=cost.dtype) 81 | 82 | T = ipot(cost.detach(), txt_len, txt_pad, img_len, img_pad, joint_pad, 83 | beta, iteration, k) 84 | distance = trace(cost.matmul(T.detach())) 85 | return distance 86 | -------------------------------------------------------------------------------- /prepro.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | preprocess NLVR annotations into LMDB 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | from os.path import exists 11 | 12 | from cytoolz import curry 13 | from tqdm import tqdm 14 | from pytorch_pretrained_bert import BertTokenizer 15 | 16 | from data.data import open_lmdb 17 | 18 | 19 | @curry 20 | def bert_tokenize(tokenizer, text): 21 | ids = [] 22 | for word in text.strip().split(): 23 | ws = tokenizer.tokenize(word) 24 | if not ws: 25 | # some special char 26 | continue 27 | ids.extend(tokenizer.convert_tokens_to_ids(ws)) 28 | return ids 29 | 30 | 31 | def process_nlvr2(jsonl, db, tokenizer, missing=None): 32 | id2len = {} 33 | txt2img = {} # not sure if useful 34 | for line in tqdm(jsonl, desc='processing NLVR2'): 35 | example = json.loads(line) 36 | id_ = example['identifier'] 37 | img_id = '-'.join(id_.split('-')[:-1]) 38 | img_fname = (f'nlvr2_{img_id}-img0.npz', f'nlvr2_{img_id}-img1.npz') 39 | if missing and (img_fname[0] in missing or img_fname[1] in missing): 40 | continue 41 | input_ids = tokenizer(example['sentence']) 42 | if 'label' in example: 43 | target = 1 if example['label'] == 'True' else 0 44 | else: 45 | target = None 46 | txt2img[id_] = img_fname 47 | id2len[id_] = len(input_ids) 48 | example['input_ids'] = input_ids 49 | example['img_fname'] = img_fname 50 | example['target'] = target 51 | db[id_] = example 52 | return id2len, txt2img 53 | 54 | 55 | def main(opts): 56 | if not exists(opts.output): 57 | os.makedirs(opts.output) 58 | else: 59 | raise ValueError('Found existing DB. Please explicitly remove ' 60 | 'for re-processing') 61 | meta = vars(opts) 62 | meta['tokenizer'] = opts.toker 63 | toker = BertTokenizer.from_pretrained( 64 | opts.toker, do_lower_case='uncased' in opts.toker) 65 | tokenizer = bert_tokenize(toker) 66 | meta['UNK'] = toker.convert_tokens_to_ids(['[UNK]'])[0] 67 | meta['CLS'] = toker.convert_tokens_to_ids(['[CLS]'])[0] 68 | meta['SEP'] = toker.convert_tokens_to_ids(['[SEP]'])[0] 69 | meta['MASK'] = toker.convert_tokens_to_ids(['[MASK]'])[0] 70 | meta['v_range'] = (toker.convert_tokens_to_ids('!')[0], 71 | len(toker.vocab)) 72 | with open(f'{opts.output}/meta.json', 'w') as f: 73 | json.dump(vars(opts), f, indent=4) 74 | 75 | open_db = curry(open_lmdb, opts.output, readonly=False) 76 | with open_db() as db: 77 | with open(opts.annotation) as ann: 78 | if opts.missing_imgs is not None: 79 | missing_imgs = set(json.load(open(opts.missing_imgs))) 80 | else: 81 | missing_imgs = None 82 | id2lens, txt2img = process_nlvr2(ann, db, tokenizer, missing_imgs) 83 | 84 | with open(f'{opts.output}/id2len.json', 'w') as f: 85 | json.dump(id2lens, f) 86 | with open(f'{opts.output}/txt2img.json', 'w') as f: 87 | json.dump(txt2img, f) 88 | 89 | 90 | if __name__ == '__main__': 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument('--annotation', required=True, 93 | help='annotation JSON') 94 | parser.add_argument('--missing_imgs', 95 | help='some training image features are corrupted') 96 | parser.add_argument('--output', required=True, 97 | help='output dir of DB') 98 | parser.add_argument('--toker', default='bert-base-cased', 99 | help='which BERT tokenizer to used') 100 | args = parser.parse_args() 101 | main(args) 102 | -------------------------------------------------------------------------------- /utils/itm_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Image Text Retrieval evaluation helper 6 | """ 7 | from time import time 8 | 9 | import torch 10 | from horovod import torch as hvd 11 | from tqdm import tqdm 12 | 13 | from .logger import LOGGER 14 | from .misc import NoOp 15 | from .distributed import all_gather_list 16 | import pdb 17 | 18 | 19 | @torch.no_grad() 20 | def itm_eval(score_matrix, txt_ids, img_ids, txt2img, img2txts): 21 | # image retrieval 22 | img2j = {i: j for j, i in enumerate(img_ids)} 23 | _, rank_txt = score_matrix.topk(10, dim=1) # rank_txt: [txt_len, 10] 24 | gt_img_j = torch.LongTensor([img2j[txt2img[txt_id]] 25 | for txt_id in txt_ids], 26 | ).to(rank_txt.device 27 | ).unsqueeze(1).expand_as(rank_txt) 28 | rank = (rank_txt == gt_img_j).nonzero() 29 | if rank.numel(): 30 | ir_r1 = (rank < 1).sum().item() / len(txt_ids) 31 | ir_r5 = (rank < 5).sum().item() / len(txt_ids) 32 | ir_r10 = (rank < 10).sum().item() / len(txt_ids) 33 | else: 34 | ir_r1, ir_r5, ir_r10 = 0, 0, 0 35 | 36 | # text retrieval 37 | txt2i = {t: i for i, t in enumerate(txt_ids)} 38 | _, rank_img = score_matrix.topk(10, dim=0) # rank_img: [10, img_len] 39 | tr_r1, tr_r5, tr_r10 = 0, 0, 0 40 | for j, img_id in enumerate(img_ids): 41 | gt_is = [txt2i[t] for t in img2txts[img_id]] 42 | ranks = [(rank_img[:, j] == i).nonzero() for i in gt_is] 43 | rank = min([10] + [r.item() for r in ranks if r.numel()]) 44 | if rank < 1: 45 | tr_r1 += 1 46 | if rank < 5: 47 | tr_r5 += 1 48 | if rank < 10: 49 | tr_r10 += 1 50 | tr_r1 /= len(img_ids) 51 | tr_r5 /= len(img_ids) 52 | tr_r10 /= len(img_ids) 53 | 54 | tr_mean = (tr_r1 + tr_r5 + tr_r10) / 3 55 | ir_mean = (ir_r1 + ir_r5 + ir_r10) / 3 56 | r_mean = (tr_mean + ir_mean) / 2 57 | 58 | eval_log = {'txt_r1': tr_r1, 59 | 'txt_r5': tr_r5, 60 | 'txt_r10': tr_r10, 61 | 'txt_r_mean': tr_mean, 62 | 'img_r1': ir_r1, 63 | 'img_r5': ir_r5, 64 | 'img_r10': ir_r10, 65 | 'img_r_mean': ir_mean, 66 | 'r_mean': r_mean} 67 | return eval_log 68 | 69 | 70 | @torch.no_grad() 71 | def evaluate(model, eval_loader, IAIS=False): 72 | st = time() 73 | LOGGER.info("start running Image/Text Retrieval evaluation ...") 74 | score_matrix = inference(model, eval_loader, IAIS=IAIS) 75 | dset = eval_loader.dataset 76 | all_score = hvd.allgather(score_matrix) 77 | all_txt_ids = [i for ids in all_gather_list(dset.ids) 78 | for i in ids] 79 | all_img_ids = dset.all_img_ids 80 | assert all_score.size() == (len(all_txt_ids), len(all_img_ids)) 81 | if hvd.rank() != 0: 82 | return {} 83 | 84 | # NOTE: only use rank0 to compute final scores 85 | eval_log = itm_eval(all_score, all_txt_ids, all_img_ids, dset.txt2img, dset.img2txts) 86 | 87 | tot_time = time()-st 88 | LOGGER.info(f"evaluation finished in {int(tot_time)} seconds") 89 | return eval_log 90 | 91 | 92 | @torch.no_grad() 93 | def inference(model, eval_loader, IAIS): 94 | model.eval() 95 | if hvd.rank() == 0: 96 | pbar = tqdm(total=len(eval_loader)) 97 | else: 98 | pbar = NoOp() 99 | score_matrix = torch.zeros(len(eval_loader.dataset), 100 | len(eval_loader.dataset.all_img_ids), 101 | device=torch.device("cuda"), 102 | dtype=torch.float16) # [txt_len, img_len], note that one img corr.to 5 txt 103 | for i, mini_batches in enumerate(eval_loader): 104 | j = 0 105 | for batch in mini_batches: 106 | scores = model(batch, compute_loss=False, IAIS=IAIS) # the scores indicate the matching extend of the i-th txt and the bs imgs 107 | bs = scores.size(0) 108 | score_matrix.data[i, j:j+bs] = scores.data.squeeze(1).half() 109 | j += bs 110 | assert j == score_matrix.size(1) 111 | pbar.update(1) 112 
| model.train() 113 | pbar.close() 114 | return score_matrix 115 | -------------------------------------------------------------------------------- /optim/adamw.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamW optimizer (weight decay fix) 3 | copied from hugginface (https://github.com/huggingface/transformers). 4 | """ 5 | import math 6 | 7 | import torch 8 | from torch.optim import Optimizer 9 | 10 | 11 | class AdamW(Optimizer): 12 | """ Implements Adam algorithm with weight decay fix. 13 | Parameters: 14 | lr (float): learning rate. Default 1e-3. 15 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). 16 | Default: (0.9, 0.999) 17 | eps (float): Adams epsilon. Default: 1e-6 18 | weight_decay (float): Weight decay. Default: 0.0 19 | correct_bias (bool): can be set to False to avoid correcting bias 20 | in Adam (e.g. like in Bert TF repository). Default True. 21 | """ 22 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 23 | weight_decay=0.0, correct_bias=True): 24 | if lr < 0.0: 25 | raise ValueError( 26 | "Invalid learning rate: {} - should be >= 0.0".format(lr)) 27 | if not 0.0 <= betas[0] < 1.0: 28 | raise ValueError("Invalid beta parameter: {} - " 29 | "should be in [0.0, 1.0[".format(betas[0])) 30 | if not 0.0 <= betas[1] < 1.0: 31 | raise ValueError("Invalid beta parameter: {} - " 32 | "should be in [0.0, 1.0[".format(betas[1])) 33 | if not 0.0 <= eps: 34 | raise ValueError("Invalid epsilon value: {} - " 35 | "should be >= 0.0".format(eps)) 36 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 37 | correct_bias=correct_bias) 38 | super(AdamW, self).__init__(params, defaults) 39 | 40 | def step(self, closure=None): 41 | """Performs a single optimization step. 42 | Arguments: 43 | closure (callable, optional): A closure that reevaluates the model 44 | and returns the loss. 
45 | """ 46 | loss = None 47 | if closure is not None: 48 | loss = closure() 49 | 50 | for group in self.param_groups: 51 | for p in group['params']: 52 | if p.grad is None: 53 | continue 54 | grad = p.grad.data 55 | if grad.is_sparse: 56 | raise RuntimeError( 57 | 'Adam does not support sparse ' 58 | 'gradients, please consider SparseAdam instead') 59 | 60 | state = self.state[p] 61 | 62 | # State initialization 63 | if len(state) == 0: 64 | state['step'] = 0 65 | # Exponential moving average of gradient values 66 | state['exp_avg'] = torch.zeros_like(p.data) 67 | # Exponential moving average of squared gradient values 68 | state['exp_avg_sq'] = torch.zeros_like(p.data) 69 | 70 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 71 | beta1, beta2 = group['betas'] 72 | 73 | state['step'] += 1 74 | 75 | # Decay the first and second moment running average coefficient 76 | # In-place operations to update the averages at the same time 77 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 78 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 79 | denom = exp_avg_sq.sqrt().add_(group['eps']) 80 | 81 | step_size = group['lr'] 82 | if group['correct_bias']: # No bias correction for Bert 83 | bias_correction1 = 1.0 - beta1 ** state['step'] 84 | bias_correction2 = 1.0 - beta2 ** state['step'] 85 | step_size = (step_size * math.sqrt(bias_correction2) 86 | / bias_correction1) 87 | 88 | p.data.addcdiv_(-step_size, exp_avg, denom) 89 | 90 | # Just adding the square of the weights to the loss function is 91 | # *not* the correct way of using L2 regularization/weight decay 92 | # with Adam, since that will interact with the m and v 93 | # parameters in strange ways. 94 | # 95 | # Instead we want to decay the weights in a manner that doesn't 96 | # interact with the m/v parameters. This is equivalent to 97 | # adding the square of the weights to the loss with plain 98 | # (non-momentum) SGD. 99 | # Add weight decay at the end (fixed version) 100 | if group['weight_decay'] > 0.0: 101 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 102 | 103 | return loss 104 | -------------------------------------------------------------------------------- /data/mlm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | MLM datasets 6 | """ 7 | import random 8 | 9 | import torch 10 | from torch.nn.utils.rnn import pad_sequence 11 | from toolz.sandbox import unzip 12 | 13 | from .data import (DetectFeatTxtTokDataset, TxtTokLmdb, 14 | pad_tensors, get_gather_index) 15 | 16 | 17 | def random_word(tokens, vocab_range, mask): 18 | """ 19 | Masking some random tokens for Language Model task with probabilities as in 20 | the original BERT paper. 21 | :param tokens: list of int, tokenized sentence. 
22 | :param vocab_range: for choosing a random word 23 | :return: (list of int, list of int), masked tokens and related labels for 24 | LM prediction 25 | """ 26 | output_label = [] 27 | 28 | for i, token in enumerate(tokens): 29 | prob = random.random() 30 | # mask token with 15% probability 31 | if prob < 0.15: 32 | prob /= 0.15 33 | 34 | # 80% randomly change token to mask token 35 | if prob < 0.8: 36 | tokens[i] = mask 37 | 38 | # 10% randomly change token to random token 39 | elif prob < 0.9: 40 | tokens[i] = random.choice(list(range(*vocab_range))) 41 | 42 | # -> rest 10% randomly keep current token 43 | 44 | # append current token to output (we will predict these later) 45 | output_label.append(token) 46 | else: 47 | # no masking token (will be ignored by loss function later) 48 | output_label.append(-1) 49 | if all(o == -1 for o in output_label): 50 | # at least mask 1 51 | output_label[0] = tokens[0] 52 | tokens[0] = mask 53 | 54 | return tokens, output_label 55 | 56 | 57 | class MlmDataset(DetectFeatTxtTokDataset): 58 | def __init__(self, txt_db, img_db): 59 | assert isinstance(txt_db, TxtTokLmdb) 60 | super().__init__(txt_db, img_db) 61 | 62 | def __getitem__(self, i): 63 | """ 64 | Return: 65 | - input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded 66 | - img_feat : (num_bb, d) 67 | - img_pos_feat : (num_bb, 7) 68 | - attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1] 69 | - txt_labels : (L, ), [-1, -1, wid, -1, -1, -1] 70 | 0's padded so that (L + num_bb) % 8 == 0 71 | """ 72 | example = super().__getitem__(i) 73 | 74 | # text input 75 | input_ids, txt_labels = self.create_mlm_io(example['input_ids']) 76 | 77 | # img input 78 | img_feat, img_pos_feat, num_bb = self._get_img_feat( 79 | example['img_fname']) 80 | 81 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) 82 | 83 | return input_ids, img_feat, img_pos_feat, attn_masks, txt_labels 84 | 85 | def create_mlm_io(self, input_ids): 86 | input_ids, txt_labels = random_word(input_ids, 87 | self.txt_db.v_range, 88 | self.txt_db.mask) 89 | input_ids = torch.tensor([self.txt_db.cls_] 90 | + input_ids 91 | + [self.txt_db.sep]) 92 | txt_labels = torch.tensor([-1] + txt_labels + [-1]) 93 | return input_ids, txt_labels 94 | 95 | 96 | def mlm_collate(inputs): 97 | """ 98 | Return: 99 | :input_ids (n, max_L) padded with 0 100 | :position_ids (n, max_L) padded with 0 101 | :txt_lens list of [txt_len] 102 | :img_feat (n, max_num_bb, feat_dim) 103 | :img_pos_feat (n, max_num_bb, 7) 104 | :num_bbs list of [num_bb] 105 | :attn_masks (n, max_{L + num_bb}) padded with 0 106 | :txt_labels (n, max_L) padded with -1 107 | """ 108 | (input_ids, img_feats, img_pos_feats, attn_masks, txt_labels 109 | ) = map(list, unzip(inputs)) 110 | 111 | # text batches 112 | txt_lens = [i.size(0) for i in input_ids] 113 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 114 | txt_labels = pad_sequence(txt_labels, batch_first=True, padding_value=-1) 115 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 116 | ).unsqueeze(0) 117 | 118 | # image batches 119 | num_bbs = [f.size(0) for f in img_feats] 120 | img_feat = pad_tensors(img_feats, num_bbs) 121 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 122 | 123 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 124 | 125 | bs, max_tl = input_ids.size() 126 | out_size = attn_masks.size(1) 127 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 128 | 129 | batch = {'input_ids': input_ids, 130 
| 'position_ids': position_ids, 131 | 'img_feat': img_feat, 132 | 'img_pos_feat': img_pos_feat, 133 | 'attn_masks': attn_masks, 134 | 'gather_index': gather_index, 135 | 'txt_labels': txt_labels} 136 | return batch 137 | -------------------------------------------------------------------------------- /data/loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | A prefetch loader to speedup data loading 6 | Modified from Nvidia Deep Learning Examples 7 | (https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch). 8 | """ 9 | import random 10 | 11 | import torch 12 | from torch.utils.data import DataLoader 13 | 14 | from utils.distributed import any_broadcast 15 | 16 | 17 | class MetaLoader(object): 18 | """ wraps multiple data loaders """ 19 | def __init__(self, loaders, accum_steps=1, distributed=False): 20 | assert isinstance(loaders, dict) 21 | self.name2loader = {} 22 | self.name2iter = {} 23 | self.sampling_pools = [] 24 | for n, l in loaders.items(): 25 | if isinstance(l, tuple): 26 | l, r = l 27 | elif isinstance(l, DataLoader): 28 | r = 1 29 | else: 30 | raise ValueError() 31 | self.name2loader[n] = l 32 | self.name2iter[n] = iter(l) 33 | self.sampling_pools.extend([n]*r) 34 | 35 | self.accum_steps = accum_steps 36 | self.distributed = distributed 37 | self.step = 0 38 | 39 | def __iter__(self): 40 | """ this iterator will run indefinitely """ 41 | task = self.sampling_pools[0] 42 | while True: 43 | if self.step % self.accum_steps == 0: 44 | task = random.choice(self.sampling_pools) 45 | if self.distributed: 46 | # make sure all process is training same task 47 | task = any_broadcast(task, 0) 48 | self.step += 1 49 | iter_ = self.name2iter[task] 50 | try: 51 | batch = next(iter_) 52 | except StopIteration: 53 | iter_ = iter(self.name2loader[task]) 54 | batch = next(iter_) 55 | self.name2iter[task] = iter_ 56 | 57 | yield task, batch 58 | 59 | 60 | def move_to_cuda(batch): 61 | if isinstance(batch, torch.Tensor): 62 | return batch.cuda(non_blocking=True) 63 | elif isinstance(batch, list): 64 | new_batch = [move_to_cuda(t) for t in batch] 65 | elif isinstance(batch, tuple): 66 | new_batch = tuple(move_to_cuda(t) for t in batch) 67 | elif isinstance(batch, dict): 68 | new_batch = {n: move_to_cuda(t) for n, t in batch.items()} 69 | else: 70 | return batch 71 | return new_batch 72 | 73 | 74 | def record_cuda_stream(batch): 75 | if isinstance(batch, torch.Tensor): 76 | batch.record_stream(torch.cuda.current_stream()) 77 | elif isinstance(batch, list) or isinstance(batch, tuple): 78 | for t in batch: 79 | record_cuda_stream(t) 80 | elif isinstance(batch, dict): 81 | for t in batch.values(): 82 | record_cuda_stream(t) 83 | else: 84 | pass 85 | 86 | 87 | class PrefetchLoader(object): 88 | """ 89 | overlap compute and cuda data transfer 90 | (copied and then modified from nvidia apex) 91 | """ 92 | def __init__(self, loader): 93 | self.loader = loader 94 | self.stream = torch.cuda.Stream() 95 | 96 | def __iter__(self): 97 | loader_it = iter(self.loader) 98 | self.preload(loader_it) 99 | batch = self.next(loader_it) 100 | while batch is not None: 101 | yield batch 102 | batch = self.next(loader_it) 103 | 104 | def __len__(self): 105 | return len(self.loader) 106 | 107 | def preload(self, it): 108 | try: 109 | self.batch = next(it) 110 | except StopIteration: 111 | self.batch = None 112 | return 113 | # if record_stream() doesn't work, another option is to 
make sure 114 | # device inputs are created on the main stream. 115 | # self.next_input_gpu = torch.empty_like(self.next_input, 116 | # device='cuda') 117 | # self.next_target_gpu = torch.empty_like(self.next_target, 118 | # device='cuda') 119 | # Need to make sure the memory allocated for next_* is not still in use 120 | # by the main stream at the time we start copying to next_*: 121 | # self.stream.wait_stream(torch.cuda.current_stream()) 122 | with torch.cuda.stream(self.stream): 123 | self.batch = move_to_cuda(self.batch) 124 | # more code for the alternative if record_stream() doesn't work: 125 | # copy_ will record the use of the pinned source tensor in this 126 | # side stream. 127 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 128 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 129 | # self.next_input = self.next_input_gpu 130 | # self.next_target = self.next_target_gpu 131 | 132 | def next(self, it): 133 | torch.cuda.current_stream().wait_stream(self.stream) 134 | batch = self.batch 135 | if batch is not None: 136 | record_cuda_stream(batch) 137 | self.preload(it) 138 | return batch 139 | 140 | def __getattr__(self, name): 141 | method = self.loader.__getattribute__(name) 142 | return method 143 | -------------------------------------------------------------------------------- /scripts/convert_imgdir.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | convert image npz to LMDB 6 | """ 7 | import argparse 8 | import glob 9 | import io 10 | import json 11 | import multiprocessing as mp 12 | import os 13 | from os.path import basename, exists 14 | 15 | from cytoolz import curry 16 | import numpy as np 17 | from tqdm import tqdm 18 | import lmdb 19 | 20 | import msgpack 21 | import msgpack_numpy 22 | msgpack_numpy.patch() 23 | 24 | 25 | def _compute_nbb(img_dump, conf_th, max_bb, min_bb, num_bb): 26 | num_bb = max(min_bb, (img_dump['conf'] > conf_th).sum()) 27 | num_bb = min(max_bb, num_bb) 28 | return int(num_bb) 29 | 30 | 31 | @curry 32 | def load_npz(conf_th, max_bb, min_bb, num_bb, fname, keep_all=False): 33 | try: 34 | img_dump = np.load(fname, allow_pickle=True) 35 | if keep_all: 36 | nbb = int(img_dump['nbb']) 37 | else: 38 | nbb = _compute_nbb(img_dump, conf_th, max_bb, min_bb, num_bb) 39 | dump = {} 40 | for key, arr in img_dump.items(): 41 | if arr.dtype == np.float32: 42 | arr = arr.astype(np.float16) 43 | if arr.ndim == 2: 44 | dump[key] = arr[:nbb, :] 45 | elif arr.ndim == 1: 46 | dump[key] = arr[:nbb] 47 | elif arr.ndim == 0: 48 | pass 49 | else: 50 | raise ValueError('wrong ndim') 51 | except Exception as e: 52 | # corrupted file 53 | print(f'corrupted file {fname}', e) 54 | dump = {} 55 | nbb = 0 56 | 57 | name = basename(fname) 58 | return name, dump, nbb 59 | 60 | 61 | def dumps_npz(dump, compress=False): 62 | with io.BytesIO() as writer: 63 | if compress: 64 | np.savez_compressed(writer, **dump, allow_pickle=True) 65 | else: 66 | np.savez(writer, **dump, allow_pickle=True) 67 | return writer.getvalue() 68 | 69 | 70 | def dumps_msgpack(dump): 71 | return msgpack.dumps(dump, use_bin_type=True) 72 | 73 | 74 | def main(opts): 75 | if opts.img_dir[-1] == '/': 76 | opts.img_dir = opts.img_dir[:-1] 77 | split = basename(opts.img_dir) 78 | if opts.keep_all: 79 | db_name = 'all' 80 | else: 81 | if opts.conf_th == -1: 82 | db_name = f'feat_numbb{opts.num_bb}' 83 | else: 84 | db_name = 
(f'feat_th{opts.conf_th}_max{opts.max_bb}' 85 | f'_min{opts.min_bb}') 86 | if opts.compress: 87 | db_name += '_compressed' 88 | if not exists(f'{opts.output}/{split}'): 89 | os.makedirs(f'{opts.output}/{split}') 90 | env = lmdb.open(f'{opts.output}/{split}/{db_name}', map_size=1024**4) 91 | txn = env.begin(write=True) 92 | files = glob.glob(f'{opts.img_dir}/*.npz') 93 | load = load_npz(opts.conf_th, opts.max_bb, opts.min_bb, opts.num_bb, 94 | keep_all=opts.keep_all) 95 | name2nbb = {} 96 | with mp.Pool(opts.nproc) as pool, tqdm(total=len(files)) as pbar: 97 | for i, (fname, features, nbb) in enumerate( 98 | pool.imap_unordered(load, files, chunksize=128)): 99 | if not features: 100 | continue # corrupted feature 101 | if opts.compress: 102 | dump = dumps_npz(features, compress=True) 103 | else: 104 | dump = dumps_msgpack(features) 105 | txn.put(key=fname.encode('utf-8'), value=dump) 106 | if i % 1000 == 0: 107 | txn.commit() 108 | txn = env.begin(write=True) 109 | name2nbb[fname] = nbb 110 | pbar.update(1) 111 | txn.put(key=b'__keys__', 112 | value=json.dumps(list(name2nbb.keys())).encode('utf-8')) 113 | txn.commit() 114 | env.close() 115 | if opts.keep_all: 116 | with open(f'{opts.output}/{split}/' 117 | f'nbb.json', 'w') as f: 118 | json.dump(name2nbb, f) 119 | if opts.conf_th != -1 and not opts.keep_all: 120 | with open(f'{opts.output}/{split}/' 121 | f'nbb_th{opts.conf_th}_' 122 | f'max{opts.max_bb}_min{opts.min_bb}.json', 'w') as f: 123 | json.dump(name2nbb, f) 124 | 125 | 126 | if __name__ == '__main__': 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument("--img_dir", default=None, type=str, 129 | help="The input images.") 130 | parser.add_argument("--output", default=None, type=str, 131 | help="output lmdb") 132 | parser.add_argument('--nproc', type=int, default=8, 133 | help='number of cores used') 134 | parser.add_argument('--compress', action='store_true', 135 | help='compress the tensors') 136 | parser.add_argument('--keep_all', action='store_true', 137 | help='keep all features, overrides all following args') 138 | parser.add_argument('--conf_th', type=float, default=0.2, 139 | help='threshold for dynamic bounding boxes ' 140 | '(-1 for fixed)') 141 | parser.add_argument('--max_bb', type=int, default=100, 142 | help='max number of bounding boxes') 143 | parser.add_argument('--min_bb', type=int, default=10, 144 | help='min number of bounding boxes') 145 | parser.add_argument('--num_bb', type=int, default=100, 146 | help='number of bounding boxes (fixed)') 147 | args = parser.parse_args() 148 | main(args) 149 | -------------------------------------------------------------------------------- /inf_itm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | run inference for Image Text Retrieval 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | from os.path import exists 11 | import pickle 12 | from time import time 13 | 14 | import torch 15 | from torch.utils.data import DataLoader 16 | 17 | from apex import amp 18 | from horovod import torch as hvd 19 | 20 | from data import (PrefetchLoader, 21 | DetectFeatLmdb, TxtTokLmdb, ItmEvalDataset, itm_eval_collate) 22 | from model.itm import UniterForImageTextRetrieval 23 | 24 | from utils.logger import LOGGER 25 | from utils.distributed import all_gather_list 26 | from utils.misc import Struct 27 | from utils.const import IMG_DIM 28 | from utils.itm_eval import inference, itm_eval 29 | 30 | 31 | def main(opts): 32 | hvd.init() 33 | n_gpu = hvd.size() 34 | device = torch.device("cuda", hvd.local_rank()) 35 | torch.cuda.set_device(hvd.local_rank()) 36 | rank = hvd.rank() 37 | LOGGER.info("device: {} n_gpu: {}, rank: {}, " 38 | "16-bits training: {}".format( 39 | device, n_gpu, hvd.rank(), opts.fp16)) 40 | 41 | if opts.train_config is not None: 42 | train_opts = Struct(json.load(open(opts.train_config))) 43 | opts.conf_th = train_opts.conf_th 44 | opts.max_bb = train_opts.max_bb 45 | opts.min_bb = train_opts.min_bb 46 | opts.num_bb = train_opts.num_bb 47 | 48 | # load DBs and image dirs 49 | eval_img_db = DetectFeatLmdb(opts.img_db, 50 | opts.conf_th, opts.max_bb, 51 | opts.min_bb, opts.num_bb, 52 | opts.compressed_db) 53 | eval_txt_db = TxtTokLmdb(opts.txt_db, -1) 54 | eval_dataset = ItmEvalDataset(eval_txt_db, eval_img_db, opts.batch_size, IAIS=opts.IAIS) 55 | 56 | # Prepare model 57 | checkpoint = torch.load(opts.checkpoint) 58 | model = UniterForImageTextRetrieval.from_pretrained( 59 | opts.model_config, checkpoint, img_dim=IMG_DIM) 60 | if 'rank_output' not in checkpoint: 61 | model.init_output() # zero shot setting 62 | 63 | model.to(device) 64 | model, _ = amp.initialize(model, enabled=opts.fp16, opt_level='O2') 65 | 66 | eval_dataloader = DataLoader(eval_dataset, batch_size=1, 67 | num_workers=opts.n_workers, 68 | pin_memory=opts.pin_mem, 69 | collate_fn=itm_eval_collate) 70 | eval_dataloader = PrefetchLoader(eval_dataloader) 71 | 72 | eval_log, results = evaluate(model, eval_dataloader, opts.IAIS) 73 | if hvd.rank() == 0: 74 | if not exists(opts.output_dir) and rank == 0: 75 | os.makedirs(opts.output_dir) 76 | with open(f'{opts.output_dir}/config.json', 'w') as f: 77 | json.dump(vars(opts), f) 78 | with open(f'{opts.output_dir}/results.bin', 'wb') as f: 79 | pickle.dump(results, f) 80 | with open(f'{opts.output_dir}/scores.json', 'w') as f: 81 | json.dump(eval_log, f) 82 | LOGGER.info(f'evaluation finished') 83 | LOGGER.info( 84 | f"======================== Results =========================\n" 85 | f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n" 86 | f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n" 87 | f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n" 88 | f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n" 89 | f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n" 90 | f"text retrieval R10: {eval_log['txt_r10']*100:.2f}") 91 | LOGGER.info("========================================================") 92 | 93 | 94 | @torch.no_grad() 95 | def evaluate(model, eval_loader, IAIS): 96 | model.eval() 97 | st = time() 98 | LOGGER.info("start running Image/Text Retrieval evaluation ...") 99 | score_matrix = inference(model, eval_loader, IAIS) 100 | dset = eval_loader.dataset 101 | all_score = hvd.allgather(score_matrix) 102 | all_txt_ids = [i for ids in 
all_gather_list(dset.ids) 103 | for i in ids] 104 | all_img_ids = dset.all_img_ids 105 | assert all_score.size() == (len(all_txt_ids), len(all_img_ids)) 106 | if hvd.rank() != 0: 107 | return {}, tuple() 108 | # NOTE: only use rank0 to compute final scores 109 | eval_log = itm_eval(all_score, all_txt_ids, all_img_ids, 110 | dset.txt2img, dset.img2txts) 111 | 112 | results = (all_score, all_txt_ids, all_img_ids) 113 | tot_time = time()-st 114 | LOGGER.info(f"evaluation finished in {int(tot_time)} seconds, ") 115 | return eval_log, results 116 | 117 | 118 | if __name__ == "__main__": 119 | parser = argparse.ArgumentParser() 120 | 121 | # Required parameters 122 | parser.add_argument("--txt_db", default=None, type=str, 123 | help="The input train corpus. (LMDB)") 124 | parser.add_argument("--img_db", default=None, type=str, 125 | help="The input train images.") 126 | parser.add_argument("--checkpoint", default=None, type=str, 127 | help="model checkpoint binary") 128 | parser.add_argument("--model_config", default=None, type=str, 129 | help="model config json") 130 | parser.add_argument( 131 | "--output_dir", default=None, type=str, 132 | help="The output directory where the inference results will be " 133 | "written.") 134 | 135 | # optional parameters 136 | parser.add_argument("--train_config", default=None, type=str, 137 | help="hps.json from training (for prepro hps)") 138 | parser.add_argument('--compressed_db', action='store_true', 139 | help='use compressed LMDB') 140 | parser.add_argument('--conf_th', type=float, default=0.2, 141 | help='threshold for dynamic bounding boxes ' 142 | '(-1 for fixed)') 143 | parser.add_argument('--max_bb', type=int, default=100, 144 | help='max number of bounding boxes') 145 | parser.add_argument('--min_bb', type=int, default=10, 146 | help='min number of bounding boxes') 147 | parser.add_argument('--num_bb', type=int, default=36, 148 | help='static number of bounding boxes') 149 | parser.add_argument("--batch_size", default=400, type=int, 150 | help="number of tokens in a batch") 151 | 152 | # device parameters 153 | parser.add_argument('--fp16', action='store_true', 154 | help="Whether to use 16-bit float precision instead " 155 | "of 32-bit") 156 | parser.add_argument('--n_workers', type=int, default=4, 157 | help="number of data workers") 158 | parser.add_argument('--pin_mem', action='store_true', 159 | help="pin memory") 160 | 161 | parser.add_argument('--IAIS', action='store_true', help='whether to use IAIS') 162 | 163 | args = parser.parse_args() 164 | 165 | main(args) 166 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | distributed API using Horovod 6 | Modified from OpenNMT's native pytorch distributed utils 7 | (https://github.com/OpenNMT/OpenNMT-py) 8 | """ 9 | import math 10 | import pickle 11 | 12 | import torch 13 | from horovod import torch as hvd 14 | 15 | 16 | def all_reduce_and_rescale_tensors(tensors, rescale_denom): 17 | """All-reduce and rescale tensors at once (as a flattened tensor) 18 | 19 | Args: 20 | tensors: list of Tensors to all-reduce 21 | rescale_denom: denominator for rescaling summed Tensors 22 | """ 23 | # buffer size in bytes, determine equiv. 
# of elements based on data type 24 | sz = sum(t.numel() for t in tensors) 25 | buffer_t = tensors[0].new(sz).zero_() 26 | 27 | # copy tensors into buffer_t 28 | offset = 0 29 | for t in tensors: 30 | numel = t.numel() 31 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 32 | offset += numel 33 | 34 | # all-reduce and rescale 35 | hvd.allreduce_(buffer_t[:offset]) 36 | buffer_t.div_(rescale_denom) 37 | 38 | # copy all-reduced buffer back into tensors 39 | offset = 0 40 | for t in tensors: 41 | numel = t.numel() 42 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 43 | offset += numel 44 | 45 | 46 | def all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom, 47 | buffer_size=10485760): 48 | """All-reduce and rescale tensors in chunks of the specified size. 49 | 50 | Args: 51 | tensors: list of Tensors to all-reduce 52 | rescale_denom: denominator for rescaling summed Tensors 53 | buffer_size: all-reduce chunk size in bytes 54 | """ 55 | # buffer size in bytes, determine equiv. # of elements based on data type 56 | buffer_t = tensors[0].new( 57 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 58 | buffer = [] 59 | 60 | def all_reduce_buffer(): 61 | # copy tensors into buffer_t 62 | offset = 0 63 | for t in buffer: 64 | numel = t.numel() 65 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 66 | offset += numel 67 | 68 | # all-reduce and rescale 69 | hvd.allreduce_(buffer_t[:offset]) 70 | buffer_t.div_(rescale_denom) 71 | 72 | # copy all-reduced buffer back into tensors 73 | offset = 0 74 | for t in buffer: 75 | numel = t.numel() 76 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 77 | offset += numel 78 | 79 | filled = 0 80 | for t in tensors: 81 | sz = t.numel() * t.element_size() 82 | if sz > buffer_size: 83 | # tensor is bigger than buffer, all-reduce and rescale directly 84 | hvd.allreduce_(t) 85 | t.div_(rescale_denom) 86 | elif filled + sz > buffer_size: 87 | # buffer is full, all-reduce and replace buffer with grad 88 | all_reduce_buffer() 89 | buffer = [t] 90 | filled = sz 91 | else: 92 | # add tensor to buffer 93 | buffer.append(t) 94 | filled += sz 95 | 96 | if len(buffer) > 0: 97 | all_reduce_buffer() 98 | 99 | 100 | def broadcast_tensors(tensors, root_rank, buffer_size=10485760): 101 | """broadcast tensors in chunks of the specified size. 102 | 103 | Args: 104 | tensors: list of Tensors to broadcast 105 | root_rank: rank to broadcast 106 | buffer_size: broadcast chunk size in bytes 107 | """ 108 | # buffer size in bytes, determine equiv. 
# of elements based on data type 109 | buffer_t = tensors[0].new( 110 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 111 | buffer = [] 112 | 113 | def broadcast_buffer(): 114 | # copy tensors into buffer_t 115 | offset = 0 116 | for t in buffer: 117 | numel = t.numel() 118 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 119 | offset += numel 120 | 121 | # broadcast 122 | hvd.broadcast_(buffer_t[:offset], root_rank) 123 | 124 | # copy all-reduced buffer back into tensors 125 | offset = 0 126 | for t in buffer: 127 | numel = t.numel() 128 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 129 | offset += numel 130 | 131 | filled = 0 132 | for t in tensors: 133 | sz = t.numel() * t.element_size() 134 | if sz > buffer_size: 135 | # tensor is bigger than buffer, broadcast directly 136 | hvd.broadcast_(t, root_rank) 137 | elif filled + sz > buffer_size: 138 | # buffer is full, broadcast and replace buffer with tensor 139 | broadcast_buffer() 140 | buffer = [t] 141 | filled = sz 142 | else: 143 | # add tensor to buffer 144 | buffer.append(t) 145 | filled += sz 146 | 147 | if len(buffer) > 0: 148 | broadcast_buffer() 149 | 150 | 151 | def _encode(enc, max_size, use_max_size=False): 152 | enc_size = len(enc) 153 | enc_byte = max(math.floor(math.log(max_size, 256)+1), 1) 154 | if use_max_size: 155 | # this is used for broadcasting 156 | buffer_ = torch.cuda.ByteTensor(max_size+enc_byte) 157 | else: 158 | buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte) 159 | remainder = enc_size 160 | for i in range(enc_byte): 161 | base = 256 ** (enc_byte-i-1) 162 | buffer_[i] = remainder // base 163 | remainder %= base 164 | buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc)) 165 | return buffer_, enc_byte 166 | 167 | 168 | def _decode(buffer_, enc_byte): 169 | size = sum(256 ** (enc_byte-i-1) * buffer_[i].item() 170 | for i in range(enc_byte)) 171 | bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist()) 172 | shift = size + enc_byte 173 | return bytes_list, shift 174 | 175 | 176 | _BUFFER_SIZE = 4096 177 | 178 | 179 | def all_gather_list(data): 180 | """Gathers arbitrary data from all nodes into a list.""" 181 | enc = pickle.dumps(data) 182 | 183 | enc_size = len(enc) 184 | max_size = hvd.allgather(torch.tensor([enc_size]).cuda()).max().item() 185 | in_buffer, enc_byte = _encode(enc, max_size) 186 | 187 | out_buffer = hvd.allgather(in_buffer[:enc_byte+enc_size]) 188 | 189 | results = [] 190 | for _ in range(hvd.size()): 191 | bytes_list, shift = _decode(out_buffer, enc_byte) 192 | out_buffer = out_buffer[shift:] 193 | result = pickle.loads(bytes_list) 194 | results.append(result) 195 | return results 196 | 197 | 198 | def any_broadcast(data, root_rank): 199 | """broadcast arbitrary data from root_rank to all nodes.""" 200 | enc = pickle.dumps(data) 201 | 202 | max_size = hvd.allgather(torch.tensor([len(enc)]).cuda()).max().item() 203 | buffer_, enc_byte = _encode(enc, max_size, use_max_size=True) 204 | 205 | hvd.broadcast_(buffer_, root_rank) 206 | 207 | bytes_list, _ = _decode(buffer_, enc_byte) 208 | result = pickle.loads(bytes_list) 209 | return result 210 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IAIS: Inter-modal Alignment for Intra-modal Self-attentions 2 | 3 | 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-relation-alignment-for-calibrated/visual-reasoning-on-winoground)](https://paperswithcode.com/sota/visual-reasoning-on-winoground?p=learning-relation-alignment-for-calibrated) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-relation-alignment-for-calibrated/image-to-text-retrieval-on-coco)](https://paperswithcode.com/sota/image-to-text-retrieval-on-coco?p=learning-relation-alignment-for-calibrated) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-relation-alignment-for-calibrated/cross-modal-retrieval-on-flickr30k)](https://paperswithcode.com/sota/cross-modal-retrieval-on-flickr30k?p=learning-relation-alignment-for-calibrated) 6 | 7 | 8 | This repository contains the code for our paper [Learning Relation Alignment for Calibrated Cross-modal Retrieval](https://arxiv.org/abs/2105.13868) (ACL-IJCNLP 2021 main conference). 9 | 10 | ![Overview of IAIS](figures/IAIS.png) 11 | 12 | 13 | Some code in this repo is copied/modified from [UNITER](https://github.com/ChenRocks/UNITER), and from other open-source implementations made available by 14 | [PyTorch](https://github.com/pytorch/pytorch), 15 | [HuggingFace](https://github.com/huggingface/transformers), 16 | [OpenNMT](https://github.com/OpenNMT/OpenNMT-py), 17 | and [Nvidia](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch). 18 | The image features are extracted using [BUTD](https://github.com/peteanderson80/bottom-up-attention). 19 | 20 | ## Update 21 | [2023.02] Please refer to the [GitHub Release](https://github.com/lancopku/IAIS/releases/tag/v0.1.0) for our fine-tuned checkpoints and logs for MS COCO and Flickr30k. 22 | 23 | [2022.12] According to researchers from CMU, our IAIS algorithm achieves a new SOTA on the **Winoground** dataset, with a 10% improvement over VinVL (Oscar+) and a 52% improvement over UNITER. Thanks for the interesting work. Their paper: [link](https://arxiv.org/abs/2212.10549). 24 | 25 | ## Overview 26 | 27 | 1. We propose a **Relation Consistency Hypothesis**: given a matched image-text pair, the linguistic relations should agree with the visual relations. 28 | 2. We design a novel metric, **Intra-modal Self-attention Distance with annotation (ISDa)**, to measure the consistency between textual and visual relations. 29 | 3. We propose a new **regularized training method** called **Inter-modal Alignment on Intra-modal Self-attentions (IAIS)**, which mutually calibrates the two intra-modal attention distributions via inter-modal alignment and helps learn better contextualized representations for image-text pairs (a simplified sketch of the alignment loss follows this list). 30 | 
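To make the alignment idea concrete, here is a small, self-contained sketch of how a distributed-alignment loss over one layer's self-attention map could be computed. It is an illustration only: the function name, the symmetric KL penalty, and the omission of padding masks are simplifications made for this sketch, and the exact formulation may differ from the paper; the implementation actually used for training lives in the `UniterEncoder` class in [model/model.py](model/model.py).

```python
import torch


def iais_distributed_loss(attn_probs, txt_len, eps=1e-6):
    """Toy distributed-alignment loss for a single self-attention layer.

    attn_probs : (batch, heads, T, T) row-normalized self-attention over the
                 concatenated [text; image] sequence, with T = txt_len + num_bb.
    txt_len    : number of text positions at the front of the sequence.
    """
    # Split the joint attention map into intra- and inter-modal blocks.
    s_ll = attn_probs[..., :txt_len, :txt_len]   # text  -> text
    s_vv = attn_probs[..., txt_len:, txt_len:]   # image -> image
    s_lv = attn_probs[..., :txt_len, txt_len:]   # text  -> image
    s_vl = attn_probs[..., txt_len:, :txt_len]   # image -> text

    # Mirror each intra-modal map through the cross-modal attentions, then
    # renormalize every row so it is a probability distribution again.
    s_ll_hat = s_lv @ s_vv @ s_vl
    s_vv_hat = s_vl @ s_ll @ s_lv
    s_ll_hat = s_ll_hat / s_ll_hat.sum(-1, keepdim=True).clamp_min(eps)
    s_vv_hat = s_vv_hat / s_vv_hat.sum(-1, keepdim=True).clamp_min(eps)

    def row_kl(p, q):
        # KL(p || q) per attention row, averaged over batch, heads and rows.
        p, q = p.clamp_min(eps), q.clamp_min(eps)
        return (p * (p.log() - q.log())).sum(-1).mean()

    # Symmetric penalty: each modality's self-attention should be predictable
    # from the other modality via the cross-modal attention.
    return row_kl(s_ll, s_ll_hat) + row_kl(s_vv, s_vv_hat)


if __name__ == "__main__":
    # 2 image-text pairs, 12 heads, 16 text tokens + 24 region features.
    attn = torch.softmax(torch.randn(2, 12, 40, 40), dim=-1)
    print(iais_distributed_loss(attn, txt_len=16))
```

Roughly speaking, the singular variant replaces the matrix products above with a hard one-to-one mapping: each linguistic position is paired with its most-attended visual region, and the corresponding intra-modal attention entries are aligned directly.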
31 | ## Requirements 32 | We provide a Docker image for easier reproduction. Please install the following: 33 | - [nvidia driver](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#package-manager-installation) (418+), 34 | - [Docker](https://docs.docker.com/install/linux/docker-ce/ubuntu/) (19.03+), 35 | - [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-docker#quickstart). 36 | 37 | Our scripts require the user to have [docker group membership](https://docs.docker.com/install/linux/linux-postinstall/) 38 | so that docker commands can be run without sudo. 39 | We only support Linux with NVIDIA GPUs. We test on Ubuntu 18.04 and V100 cards. 40 | We use mixed-precision training, hence GPUs with Tensor Cores are recommended. 41 | 42 | ## Getting Started 43 | 1. Download the processed data and pretrained models with the following command. 44 | ```bash 45 | bash scripts/download_itm.sh $PATH_TO_STORAGE 46 | ``` 47 | After downloading, you should see the following folder structure: 48 | ``` 49 | ├── img_db 50 | │   ├── coco_train2014 51 | │   ├── coco_train2014.tar 52 | │   ├── coco_val2014 53 | │   ├── coco_val2014.tar 54 | │   ├── flickr30k 55 | │   └── flickr30k.tar 56 | ├── pretrained 57 | │   ├── uniter-base.pt 58 | │   └── uniter-large.pt 59 | └── txt_db 60 |    ├── itm_coco_train.db 61 |    ├── itm_coco_train.db.tar 62 |    ├── itm_coco_val.db 63 |    ├── itm_coco_val.db.tar 64 |    ├── itm_coco_restval.db 65 |    ├── itm_coco_restval.db.tar 66 |    ├── itm_coco_test.db 67 |    ├── itm_coco_test.db.tar 68 |    ├── itm_flickr30k_train.db 69 |    ├── itm_flickr30k_train.db.tar 70 |    ├── itm_flickr30k_val.db 71 |    ├── itm_flickr30k_val.db.tar 72 |    ├── itm_flickr30k_test.db 73 |    └── itm_flickr30k_test.db.tar 74 | ``` 75 | 76 | 2. Launch the Docker container for running the experiments. 77 | ```bash 78 | # docker image should be automatically pulled 79 | source launch_container.sh $PATH_TO_STORAGE/txt_db $PATH_TO_STORAGE/img_db \ 80 | $PATH_TO_STORAGE/finetune $PATH_TO_STORAGE/pretrained 81 | ``` 82 | The launch script respects the $CUDA_VISIBLE_DEVICES environment variable. 83 | Note that the source code is mounted into the container under `/src` instead 84 | of built into the image, so that user modifications are reflected without 85 | re-building the image. (Data folders are mounted into the container separately 86 | for flexibility on folder structures.) 87 | 88 | 3. Run fine-tuning for the ITM task. 89 | 90 | All experiments in the paper are conducted on 8 NVIDIA V100 GPUs. 91 | 92 | - Image-Text Retrieval (Flickr30k) 93 | - finetune with hard negatives 94 | ``` 95 | horovodrun -np 8 python train_itm_hard_negatives.py \ 96 | --config config/train-itm-flickr-base-8gpu-hn.json 97 | ``` 98 | - finetune with hard negatives + **IAIS** 99 | ``` 100 | horovodrun -np 8 python train_itm_hard_negatives.py \ 101 | --config config/train-itm-flickr-base-8gpu-hn.json --IAIS singular 102 | ``` 103 | - Image-Text Retrieval (COCO) 104 | - finetune with hard negatives 105 | ``` 106 | horovodrun -np 8 python train_itm_hard_negatives.py \ 107 | --config config/train-itm-coco-base-8gpu-hn.json 108 | ``` 109 | - finetune with hard negatives + **IAIS** 110 | ``` 111 | horovodrun -np 8 python train_itm_hard_negatives.py \ 112 | --config config/train-itm-coco-base-8gpu-hn.json --IAIS singular 113 | ``` 114 | The `--IAIS` argument adds the auxiliary IAIS loss to the fine-tuning objective. We support 115 | - `singular`: Singular Alignment, which establishes a one-to-one mapping between linguistic and visual attention weights (Section 3.1 in the paper). 116 | ![Overview of IAIS](figures/singular_alignment.gif) 117 | - `distributed`: Distributed Alignment, which establishes a distributed mapping (Section 3.2 in the paper). 118 | ![Overview of IAIS](figures/distributed_alignment.gif) 119 | 120 | The main code for the IAIS method is in the `UniterEncoder` class in [model/model.py](model/model.py). 121 | 122 | ## Contact 123 | 124 | If you have any questions related to the code or the paper, feel free to email Shuhuai (renshuhuai007 [AT] gmail [DOT] com). 
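## Inference

A fine-tuned checkpoint can also be scored with `inf_itm.py`. The invocation below is only a sketch: every path is a placeholder to be adapted to your own storage layout and container mount points.

```bash
# Illustrative only: all paths are placeholders.
horovodrun -np 8 python inf_itm.py \
    --txt_db $PATH_TO_STORAGE/txt_db/itm_flickr30k_test.db \
    --img_db $PATH_TO_STORAGE/img_db/flickr30k \
    --checkpoint $PATH_TO_FINETUNED_CHECKPOINT \
    --model_config config/uniter-base.json \
    --output_dir $PATH_TO_OUTPUT \
    --fp16 --pin_mem
```

The script reports R@1/5/10 for both image and text retrieval and writes `scores.json` and `results.bin` to the output directory.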
125 | 126 | ## Citation 127 | 128 | If you find this code useful for your research, please consider citing: 129 | ``` 130 | @inproceedings{ren2021iais, 131 | title = "Learning Relation Alignment for Calibrated Cross-modal Retrieval", 132 | author = "Ren, Shuhuai and Lin, Junyang and Zhao, Guangxiang and Men, Rui and Yang, An and Zhou, Jingren and Sun, Xu and Yang, Hongxia", 133 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)", 134 | year = "2021", 135 | } 136 | ``` 137 | 138 | ## License 139 | 140 | MIT 141 | -------------------------------------------------------------------------------- /data/mrm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | MRM Datasets 6 | """ 7 | import random 8 | 9 | import torch 10 | from torch.nn.utils.rnn import pad_sequence 11 | from toolz.sandbox import unzip 12 | from .data import DetectFeatTxtTokDataset, pad_tensors, get_gather_index 13 | 14 | 15 | def _get_img_mask(mask_prob, num_bb): 16 | img_mask = [random.random() < mask_prob for _ in range(num_bb)] 17 | if not any(img_mask): 18 | # at least mask 1 19 | img_mask[random.choice(range(num_bb))] = True 20 | img_mask = torch.tensor(img_mask) 21 | return img_mask 22 | 23 | 24 | def _get_img_tgt_mask(img_mask, txt_len): 25 | z = torch.zeros(txt_len, dtype=torch.uint8) 26 | img_mask_tgt = torch.cat([z, img_mask], dim=0) 27 | return img_mask_tgt 28 | 29 | 30 | def _get_feat_target(img_feat, img_masks): 31 | img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) # (n, m, d) 32 | feat_dim = img_feat.size(-1) 33 | feat_targets = img_feat[img_masks_ext].contiguous().view( 34 | -1, feat_dim) # (s, d) 35 | return feat_targets 36 | 37 | 38 | def _mask_img_feat(img_feat, img_masks): 39 | img_masks_ext = img_masks.unsqueeze(-1).expand_as(img_feat) 40 | img_feat_masked = img_feat.data.masked_fill(img_masks_ext, 0) 41 | return img_feat_masked 42 | 43 | 44 | class MrfrDataset(DetectFeatTxtTokDataset): 45 | def __init__(self, mask_prob, *args, **kwargs): 46 | super().__init__(*args, **kwargs) 47 | self.mask_prob = mask_prob 48 | 49 | def __getitem__(self, i): 50 | """ 51 | Return: 52 | - input_ids : (L, ), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded 53 | - img_feat : (num_bb, d) 54 | - img_pos_feat : (num_bb, 7) 55 | - attn_masks : (L + num_bb, ), ie., [1, 1, ..., 0, 0, 1, 1] 56 | - img_mask : (num_bb, ) between {0, 1} 57 | """ 58 | example = super().__getitem__(i) 59 | # text input 60 | input_ids = example['input_ids'] 61 | input_ids = self.txt_db.combine_inputs(input_ids) 62 | 63 | # image input features 64 | img_feat, img_pos_feat, num_bb = self._get_img_feat( 65 | example['img_fname']) 66 | img_mask = _get_img_mask(self.mask_prob, num_bb) 67 | img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids)) 68 | 69 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) 70 | 71 | return (input_ids, img_feat, img_pos_feat, 72 | attn_masks, img_mask, img_mask_tgt) 73 | 74 | 75 | def mrfr_collate(inputs): 76 | """ 77 | Return: 78 | - input_ids : (n, max_L), i.e., [cls, wd, wd, ..., sep, 0, 0], 0s padded 79 | - position_ids : (n, max_L) 80 | - txt_lens : list of [input_len] 81 | - img_feat : (n, max_num_bb, d) 82 | - img_pos_feat : (n, max_num_bb, 7) 83 | - num_bbs : list of [num_bb] 84 | - attn_masks : (n, max_{L + num_bb}), ie., 
[1, 1, ..., 0, 0, 1, 1] 85 | - img_masks : (n, max_num_bb) between {0, 1} 86 | """ 87 | (input_ids, img_feats, img_pos_feats, attn_masks, img_masks, img_mask_tgts, 88 | ) = map(list, unzip(inputs)) 89 | 90 | txt_lens = [i.size(0) for i in input_ids] 91 | 92 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 93 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 94 | ).unsqueeze(0) 95 | 96 | num_bbs = [f.size(0) for f in img_feats] 97 | img_feat = pad_tensors(img_feats, num_bbs) 98 | img_pos_feat = pad_tensors(img_pos_feats, num_bbs) 99 | 100 | # mask features 101 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) 102 | feat_targets = _get_feat_target(img_feat, img_masks) 103 | img_feat = _mask_img_feat(img_feat, img_masks) 104 | img_mask_tgt = pad_sequence(img_mask_tgts, 105 | batch_first=True, padding_value=0) 106 | 107 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 108 | bs, max_tl = input_ids.size() 109 | out_size = attn_masks.size(1) 110 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 111 | 112 | batch = {'input_ids': input_ids, 113 | 'position_ids': position_ids, 114 | 'img_feat': img_feat, 115 | 'img_pos_feat': img_pos_feat, 116 | 'attn_masks': attn_masks, 117 | 'gather_index': gather_index, 118 | 'feat_targets': feat_targets, 119 | 'img_masks': img_masks, 120 | 'img_mask_tgt': img_mask_tgt} 121 | return batch 122 | 123 | 124 | def _get_targets(img_masks, img_soft_label): 125 | soft_label_dim = img_soft_label.size(-1) 126 | img_masks_ext_for_label = img_masks.unsqueeze(-1).expand_as(img_soft_label) 127 | label_targets = img_soft_label[img_masks_ext_for_label].contiguous().view( 128 | -1, soft_label_dim) 129 | return label_targets 130 | 131 | 132 | class MrcDataset(DetectFeatTxtTokDataset): 133 | def __init__(self, mask_prob, *args, **kwargs): 134 | super().__init__(*args, **kwargs) 135 | self.mask_prob = mask_prob 136 | 137 | def _get_img_feat(self, fname): 138 | img_dump = self.img_db.get_dump(fname) 139 | num_bb = self.img_db.name2nbb[fname] 140 | img_feat = torch.tensor(img_dump['features']) 141 | bb = torch.tensor(img_dump['norm_bb']) 142 | img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) 143 | img_soft_label = torch.tensor(img_dump['soft_labels']) 144 | return img_feat, img_bb, img_soft_label, num_bb 145 | 146 | def __getitem__(self, i): 147 | example = super().__getitem__(i) 148 | img_feat, img_pos_feat, img_soft_labels, num_bb = self._get_img_feat( 149 | example['img_fname']) 150 | 151 | # image input features 152 | img_mask = _get_img_mask(self.mask_prob, num_bb) 153 | 154 | # text input 155 | input_ids = example['input_ids'] 156 | input_ids = self.txt_db.combine_inputs(input_ids) 157 | img_mask_tgt = _get_img_tgt_mask(img_mask, len(input_ids)) 158 | 159 | attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) 160 | 161 | return (input_ids, img_feat, img_pos_feat, 162 | img_soft_labels, attn_masks, img_mask, img_mask_tgt) 163 | 164 | 165 | def mrc_collate(inputs): 166 | (input_ids, img_feats, img_pos_feats, img_soft_labels, 167 | attn_masks, img_masks, img_mask_tgts) = map(list, unzip(inputs)) 168 | 169 | txt_lens = [i.size(0) for i in input_ids] 170 | num_bbs = [f.size(0) for f in img_feats] 171 | 172 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) 173 | position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long 174 | ).unsqueeze(0) 175 | 176 | img_feat = pad_tensors(img_feats, num_bbs) 177 | img_pos_feat = 
pad_tensors(img_pos_feats, num_bbs) 178 | img_soft_label = pad_tensors(img_soft_labels, num_bbs) 179 | img_masks = pad_sequence(img_masks, batch_first=True, padding_value=0) 180 | label_targets = _get_targets(img_masks, img_soft_label) 181 | 182 | img_feat = _mask_img_feat(img_feat, img_masks) 183 | img_mask_tgt = pad_sequence(img_mask_tgts, 184 | batch_first=True, padding_value=0) 185 | 186 | attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) 187 | bs, max_tl = input_ids.size() 188 | out_size = attn_masks.size(1) 189 | gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) 190 | 191 | batch = {'input_ids': input_ids, 192 | 'position_ids': position_ids, 193 | 'img_feat': img_feat, 194 | 'img_pos_feat': img_pos_feat, 195 | 'attn_masks': attn_masks, 196 | 'gather_index': gather_index, 197 | 'img_masks': img_masks, 198 | 'img_mask_tgt': img_mask_tgt, 199 | 'label_targets': label_targets} 200 | return batch 201 | -------------------------------------------------------------------------------- /model/itm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | UNITER for ITM model 6 | """ 7 | from collections import defaultdict 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from .model import UniterPreTrainedModel, UniterModel 12 | import pdb 13 | from utils.heatmap import plot_attention_headmap 14 | import numpy as np 15 | 16 | 17 | class UniterForImageTextRetrieval(UniterPreTrainedModel): 18 | """ Finetune UNITER for image text retrieval 19 | """ 20 | def __init__(self, config, img_dim, pairs_num, margin=0.2): 21 | super().__init__(config) 22 | self.uniter = UniterModel(config, img_dim) 23 | self.itm_output = nn.Linear(config.hidden_size, 2) 24 | self.rank_output = nn.Linear(config.hidden_size, 1) 25 | self.margin = margin 26 | self.pairs_num = pairs_num 27 | self.apply(self.init_weights) 28 | 29 | def init_output(self): 30 | """ need to be called after from pretrained """ 31 | self.rank_output.weight.data = self.itm_output.weight.data[1:, :] 32 | self.rank_output.bias.data = self.itm_output.bias.data[1:] 33 | 34 | def forward(self, batch, compute_loss=True, IAIS=False): 35 | batch = defaultdict(lambda: None, batch) 36 | input_ids = batch['input_ids'] 37 | position_ids = batch['position_ids'] 38 | img_feat = batch['img_feat'] 39 | img_pos_feat = batch['img_pos_feat'] 40 | attention_mask = batch['attn_masks'] 41 | if IAIS: 42 | gather_index = None 43 | txt_attn_masks = batch['txt_attn_masks'] # [sample_num, max_tl+max_nbb] 44 | img_attn_masks = batch['img_attn_masks'] # [sample_num, max_tl+max_nbb] 45 | sequence_output, self_attn_loss_per_layer = self.uniter(input_ids, position_ids, 46 | img_feat, img_pos_feat, 47 | attention_mask, gather_index, None, 48 | txt_attn_masks, img_attn_masks, 49 | output_all_encoded_layers=False, 50 | IAIS=IAIS, 51 | pairs_num=self.pairs_num) 52 | else: # evaluation 53 | gather_index = batch['gather_index'] 54 | sequence_output = self.uniter(input_ids, position_ids, 55 | img_feat, img_pos_feat, 56 | attention_mask, gather_index, 57 | output_all_encoded_layers=False) 58 | # sequence_output: [sample_num, max_tl+max_nbb, hidden_size(768)] 59 | pooled_output = self.uniter.pooler(sequence_output) 60 | rank_scores = self.rank_output(pooled_output) 61 | # rank_scores: [sample_num, 1] 62 | 63 | if compute_loss: 64 | # triplet loss 65 | rank_scores_sigmoid = 
torch.sigmoid(rank_scores) 66 | sample_size = batch['sample_size'] 67 | scores = rank_scores_sigmoid.contiguous().view(-1, sample_size) 68 | pos = scores[:, :1] 69 | neg = scores[:, 1:] 70 | rank_loss = torch.clamp(self.margin + neg - pos, 0) 71 | # self-attn agree loss 72 | if IAIS: 73 | return rank_loss, self_attn_loss_per_layer 74 | else: 75 | return rank_loss 76 | else: 77 | return rank_scores 78 | 79 | 80 | class UniterForImageTextRetrievalHardNeg(UniterForImageTextRetrieval): 81 | """ Finetune UNITER for image text retrieval 82 | """ 83 | def __init__(self, config, img_dim, margin=0.2, hard_size=16): 84 | super().__init__(config, img_dim, hard_size + 1, margin) 85 | self.hard_size = hard_size 86 | 87 | def forward(self, batch, sample_from='t', compute_loss=True, IAIS=False): 88 | # expect same input_ids for all pairs 89 | batch_size = batch['attn_masks'].size(0) 90 | input_ids = batch['input_ids'] 91 | img_feat = batch['img_feat'] 92 | img_pos_feat = batch['img_pos_feat'] 93 | if sample_from == 't': 94 | if input_ids.size(0) == 1: 95 | batch['input_ids'] = input_ids.expand(batch_size, -1) 96 | elif sample_from == 'i': 97 | if img_feat.size(0) == 1: 98 | batch['img_feat'] = img_feat.expand(batch_size, -1, -1) # copy img_feat for batch_size times 99 | if img_pos_feat.size(0) == 1: 100 | batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1) 101 | else: 102 | raise ValueError() 103 | 104 | if self.training and compute_loss: 105 | with torch.no_grad(): 106 | self.eval() 107 | scores = super().forward(batch, compute_loss=False) # only evaluate 108 | hard_batch = self._get_hard_batch(batch, scores, sample_from, IAIS) 109 | self.train() 110 | return super().forward(hard_batch, compute_loss=True, IAIS=IAIS) 111 | else: 112 | return super().forward(batch, compute_loss) # only evaluate 113 | 114 | def _get_hard_batch(self, batch, scores, sample_from='t', IAIS=False): 115 | batch = defaultdict(lambda: None, batch) 116 | input_ids = batch['input_ids'] 117 | position_ids = batch['position_ids'] 118 | img_feat = batch['img_feat'] 119 | img_pos_feat = batch['img_pos_feat'] 120 | attention_mask = batch['attn_masks'] 121 | hard_batch = {'sample_size': self.hard_size + 1} 122 | 123 | # NOTE first example is positive 124 | hard_indices = scores.squeeze(-1)[1:].topk(self.hard_size, sorted=False)[1] + 1 125 | indices = torch.cat([torch.zeros(1, dtype=torch.long, 126 | device=hard_indices.device), 127 | hard_indices]) # [32] 128 | 129 | attention_mask = attention_mask.index_select(0, indices) 130 | 131 | if position_ids.size(0) != 1: 132 | position_ids = position_ids[:self.hard_size+1] 133 | 134 | if sample_from == 't': 135 | # cut to minimum padding 136 | max_len = attention_mask.sum(dim=1).max().item() 137 | max_i = max_len - input_ids.size(1) 138 | attention_mask = attention_mask[:, :max_len] 139 | img_feat = img_feat.index_select(0, indices)[:, :max_i, :] 140 | img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :] 141 | # expect same input_ids for all pairs 142 | input_ids = input_ids[:self.hard_size+1] 143 | elif sample_from == 'i': 144 | input_ids = input_ids.index_select(0, indices) 145 | # expect same image features for all pairs 146 | img_feat = img_feat[:self.hard_size+1] 147 | img_pos_feat = img_pos_feat[:self.hard_size+1] 148 | else: 149 | raise ValueError() 150 | 151 | hard_batch['input_ids'] = input_ids 152 | hard_batch['position_ids'] = position_ids 153 | hard_batch['img_feat'] = img_feat 154 | hard_batch['img_pos_feat'] = img_pos_feat 155 | hard_batch['attn_masks'] = 
attention_mask 156 | if IAIS: 157 | txt_attn_masks = batch['txt_attn_masks'] 158 | img_attn_masks = batch['img_attn_masks'] 159 | if sample_from == 't': 160 | max_len = attention_mask.sum(dim=1).max().item() 161 | hard_batch['txt_attn_masks'] = txt_attn_masks[:, :max_len] 162 | hard_batch['img_attn_masks'] = img_attn_masks[:, :max_len] 163 | elif sample_from == 'i': 164 | hard_batch['txt_attn_masks'] = txt_attn_masks 165 | hard_batch['img_attn_masks'] = img_attn_masks 166 | else: 167 | if sample_from == 't': 168 | gather_index = batch['gather_index'] 169 | gather_index = gather_index.index_select(0, indices) 170 | gather_index = gather_index[:, :max_len] 171 | hard_batch['gather_index'] = gather_index 172 | 173 | return hard_batch 174 | -------------------------------------------------------------------------------- /model/pretrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | UNITER for pretraining 6 | """ 7 | from collections import defaultdict 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm 13 | 14 | from .layer import GELU, BertOnlyMLMHead 15 | from .model import UniterModel, UniterPreTrainedModel 16 | from .ot import optimal_transport_dist 17 | 18 | 19 | class RegionFeatureRegression(nn.Module): 20 | " for MRM" 21 | def __init__(self, hidden_size, feat_dim, img_linear_weight): 22 | super().__init__() 23 | self.net = nn.Sequential(nn.Linear(hidden_size, hidden_size), 24 | GELU(), 25 | LayerNorm(hidden_size, eps=1e-12)) 26 | 27 | self.weight = img_linear_weight 28 | self.bias = nn.Parameter(torch.zeros(feat_dim)) 29 | 30 | def forward(self, input_): 31 | hidden = self.net(input_) 32 | output = F.linear(hidden, self.weight.t(), self.bias) 33 | return output 34 | 35 | 36 | class RegionClassification(nn.Module): 37 | " for MRC(-kl)" 38 | def __init__(self, hidden_size, label_dim): 39 | super().__init__() 40 | self.net = nn.Sequential(nn.Linear(hidden_size, hidden_size), 41 | GELU(), 42 | LayerNorm(hidden_size, eps=1e-12), 43 | nn.Linear(hidden_size, label_dim)) 44 | 45 | def forward(self, input_): 46 | output = self.net(input_) 47 | return output 48 | 49 | 50 | class UniterForPretraining(UniterPreTrainedModel): 51 | """ UNITER pretraining """ 52 | def __init__(self, config, img_dim, img_label_dim): 53 | super().__init__(config) 54 | self.uniter = UniterModel(config, img_dim) 55 | self.cls = BertOnlyMLMHead( 56 | config, self.uniter.embeddings.word_embeddings.weight) 57 | self.feat_regress = RegionFeatureRegression( 58 | config.hidden_size, img_dim, 59 | self.uniter.img_embeddings.img_linear.weight) 60 | self.region_classifier = RegionClassification( 61 | config.hidden_size, img_label_dim) 62 | self.itm_output = nn.Linear(config.hidden_size, 2) 63 | self.apply(self.init_weights) 64 | 65 | def forward(self, batch, task, compute_loss=True): 66 | batch = defaultdict(lambda: None, batch) 67 | input_ids = batch['input_ids'] 68 | position_ids = batch['position_ids'] 69 | img_feat = batch['img_feat'] 70 | img_pos_feat = batch['img_pos_feat'] 71 | attention_mask = batch['attn_masks'] 72 | gather_index = batch['gather_index'] 73 | if task == 'mlm': 74 | txt_labels = batch['txt_labels'] 75 | return self.forward_mlm(input_ids, position_ids, 76 | img_feat, img_pos_feat, 77 | attention_mask, gather_index, 78 | txt_labels, compute_loss) 79 | elif task == 
'mrfr': 80 | img_mask_tgt = batch['img_mask_tgt'] 81 | img_masks = batch['img_masks'] 82 | mrfr_feat_target = batch['feat_targets'] 83 | return self.forward_mrfr(input_ids, position_ids, 84 | img_feat, img_pos_feat, 85 | attention_mask, gather_index, 86 | img_masks, img_mask_tgt, 87 | mrfr_feat_target, compute_loss) 88 | elif task == 'itm': 89 | targets = batch['targets'] 90 | ot_inputs = batch['ot_inputs'] 91 | return self.forward_itm(input_ids, position_ids, 92 | img_feat, img_pos_feat, 93 | attention_mask, gather_index, 94 | targets, ot_inputs, compute_loss) 95 | elif task.startswith('mrc'): 96 | img_mask_tgt = batch['img_mask_tgt'] 97 | img_masks = batch['img_masks'] 98 | mrc_label_target = batch['label_targets'] 99 | return self.forward_mrc(input_ids, position_ids, 100 | img_feat, img_pos_feat, 101 | attention_mask, gather_index, 102 | img_masks, img_mask_tgt, 103 | mrc_label_target, task, compute_loss) 104 | else: 105 | raise ValueError('invalid task') 106 | 107 | def forward_mlm(self, input_ids, position_ids, img_feat, img_pos_feat, 108 | attention_mask, gather_index, 109 | txt_labels, compute_loss=True): 110 | sequence_output = self.uniter(input_ids, position_ids, 111 | img_feat, img_pos_feat, 112 | attention_mask, gather_index, 113 | output_all_encoded_layers=False) 114 | # get only the text part 115 | sequence_output = sequence_output[:, :input_ids.size(1), :] 116 | # only compute masked tokens for better efficiency 117 | masked_output = self._compute_masked_hidden(sequence_output, 118 | txt_labels != -1) 119 | prediction_scores = self.cls(masked_output) 120 | 121 | if compute_loss: 122 | masked_lm_loss = F.cross_entropy(prediction_scores, 123 | txt_labels[txt_labels != -1], 124 | reduction='none') 125 | return masked_lm_loss 126 | else: 127 | return prediction_scores 128 | 129 | def _compute_masked_hidden(self, hidden, mask): 130 | """ get only the masked region (don't compute unnecessary hiddens) """ 131 | mask = mask.unsqueeze(-1).expand_as(hidden) 132 | hidden_masked = hidden[mask].contiguous().view(-1, hidden.size(-1)) 133 | return hidden_masked 134 | 135 | def forward_mrfr(self, input_ids, position_ids, img_feat, img_pos_feat, 136 | attention_mask, gather_index, img_masks, img_mask_tgt, 137 | feat_targets, compute_loss=True): 138 | sequence_output = self.uniter(input_ids, position_ids, 139 | img_feat, img_pos_feat, 140 | attention_mask, gather_index, 141 | output_all_encoded_layers=False, 142 | img_masks=img_masks) 143 | 144 | # only compute masked tokens for better efficiency 145 | masked_output = self._compute_masked_hidden(sequence_output, 146 | img_mask_tgt) 147 | prediction_feat = self.feat_regress(masked_output) 148 | 149 | if compute_loss: 150 | mrfr_loss = F.mse_loss(prediction_feat, feat_targets, 151 | reduction='none') 152 | return mrfr_loss 153 | else: 154 | return prediction_feat 155 | 156 | def forward_itm(self, input_ids, position_ids, img_feat, img_pos_feat, 157 | attention_mask, gather_index, targets, ot_inputs, 158 | compute_loss=True): 159 | sequence_output = self.uniter(input_ids, position_ids, 160 | img_feat, img_pos_feat, 161 | attention_mask, gather_index, 162 | output_all_encoded_layers=False) 163 | pooled_output = self.uniter.pooler(sequence_output) 164 | itm_scores = self.itm_output(pooled_output) 165 | 166 | # OT loss 167 | if ot_inputs is not None: 168 | ot_scatter = ot_inputs['ot_scatter'] 169 | 170 | b = sequence_output.size(0) 171 | tl = input_ids.size(1) 172 | il = img_feat.size(1) 173 | max_l = max(ot_inputs['scatter_max'] + 1, tl+il) 174 | 
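# NOTE: the scatter below re-expands the gathered (compacted) sequence output
# into a fixed [text; image] padded layout of length max_l, so that txt_emb and
# img_emb can be sliced out separately for the optimal-transport distance.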
175 | ot_scatter = ot_scatter.unsqueeze(-1).expand_as(sequence_output) 176 | ctx_emb = torch.zeros(b, max_l, self.config.hidden_size, 177 | dtype=sequence_output.dtype, 178 | device=sequence_output.device 179 | ).scatter_(dim=1, index=ot_scatter, 180 | src=sequence_output) 181 | txt_emb = ctx_emb[:, :tl, :] 182 | img_emb = ctx_emb[:, tl:tl+il, :] 183 | 184 | txt_pad = ot_inputs['txt_pad'] 185 | img_pad = ot_inputs['img_pad'] 186 | # NOTE: run in fp32 for stability 187 | ot_dist = optimal_transport_dist(txt_emb.float(), img_emb.float(), 188 | txt_pad, img_pad).to(txt_emb) 189 | ot_pos_dist = ot_dist.masked_select(targets == 1) 190 | ot_neg_dist = ot_dist.masked_select(targets == 0) 191 | ot_loss = (ot_pos_dist, ot_neg_dist) 192 | else: 193 | ot_loss = None 194 | 195 | if compute_loss: 196 | itm_loss = F.cross_entropy(itm_scores, targets, reduction='none') 197 | return itm_loss, ot_loss 198 | else: 199 | return itm_scores, ot_loss 200 | 201 | def forward_mrc(self, input_ids, position_ids, img_feat, img_pos_feat, 202 | attention_mask, gather_index, img_masks, img_mask_tgt, 203 | label_targets, task, compute_loss=True): 204 | sequence_output = self.uniter(input_ids, position_ids, 205 | img_feat, img_pos_feat, 206 | attention_mask, gather_index, 207 | output_all_encoded_layers=False, 208 | img_masks=img_masks) 209 | 210 | # only compute masked regions for better efficiency 211 | masked_output = self._compute_masked_hidden(sequence_output, 212 | img_mask_tgt) 213 | prediction_soft_label = self.region_classifier(masked_output) 214 | 215 | if compute_loss: 216 | if "kl" in task: 217 | prediction_soft_label = F.log_softmax( 218 | prediction_soft_label, dim=-1) 219 | mrc_loss = F.kl_div( 220 | prediction_soft_label, label_targets, reduction='none') 221 | else: 222 | # background class should not be the target 223 | label_targets = torch.max(label_targets[:, 1:], dim=-1)[1] + 1 224 | mrc_loss = F.cross_entropy( 225 | prediction_soft_label, label_targets, 226 | ignore_index=0, reduction='none') 227 | return mrc_loss 228 | else: 229 | return prediction_soft_label 230 | -------------------------------------------------------------------------------- /data/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Dataset interfaces 6 | """ 7 | from collections import defaultdict 8 | from contextlib import contextmanager 9 | import io 10 | import json 11 | from os.path import exists 12 | 13 | import numpy as np 14 | import torch 15 | from torch.utils.data import Dataset, ConcatDataset 16 | import horovod.torch as hvd 17 | from tqdm import tqdm 18 | import lmdb 19 | from lz4.frame import compress, decompress 20 | 21 | import msgpack 22 | import msgpack_numpy 23 | msgpack_numpy.patch() 24 | 25 | 26 | def _fp16_to_fp32(feat_dict): 27 | out = {k: arr.astype(np.float32) 28 | if arr.dtype == np.float16 else arr 29 | for k, arr in feat_dict.items()} 30 | return out 31 | 32 | 33 | def compute_num_bb(confs, conf_th, min_bb, max_bb): 34 | num_bb = max(min_bb, (confs > conf_th).sum()) 35 | num_bb = min(max_bb, num_bb) 36 | return num_bb 37 | 38 | 39 | def _check_distributed(): 40 | try: 41 | dist = hvd.size() != hvd.local_size() 42 | except ValueError: 43 | # not using horovod 44 | dist = False 45 | return dist 46 | 47 | 48 | class DetectFeatLmdb(object): 49 | def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36, 50 | compress=True): 51 | self.img_dir = img_dir 52 | if conf_th == -1: 53 | db_name = f'feat_numbb{num_bb}' 54 | self.name2nbb = defaultdict(lambda: num_bb) 55 | else: 56 | db_name = f'feat_th{conf_th}_max{max_bb}_min{min_bb}' 57 | nbb = f'nbb_th{conf_th}_max{max_bb}_min{min_bb}.json' 58 | if not exists(f'{img_dir}/{nbb}'): 59 | # nbb is not pre-computed 60 | self.name2nbb = None 61 | else: 62 | self.name2nbb = json.load(open(f'{img_dir}/{nbb}')) 63 | self.compress = compress 64 | if compress: 65 | db_name += '_compressed' 66 | 67 | if self.name2nbb is None: 68 | if compress: 69 | db_name = 'all_compressed' 70 | else: 71 | db_name = 'all' 72 | nbb = f'nbb.json' 73 | self.name2nbb = json.load(open(f'{img_dir}/{nbb}')) 74 | # only read ahead on single node training 75 | self.env = lmdb.open(f'{img_dir}/{db_name}', 76 | readonly=True, create=False, 77 | readahead=not _check_distributed()) 78 | self.txn = self.env.begin(buffers=True) 79 | if self.name2nbb is None: 80 | self.name2nbb = self._compute_nbb() 81 | 82 | def _compute_nbb(self): 83 | name2nbb = {} 84 | fnames = json.loads(self.txn.get(key=b'__keys__').decode('utf-8')) 85 | for fname in tqdm(fnames, desc='reading images'): 86 | dump = self.txn.get(fname.encode('utf-8')) 87 | if self.compress: 88 | with io.BytesIO(dump) as reader: 89 | img_dump = np.load(reader, allow_pickle=True) 90 | confs = img_dump['conf'] 91 | else: 92 | img_dump = msgpack.loads(dump, raw=False) 93 | confs = img_dump['conf'] 94 | name2nbb[fname] = compute_num_bb(confs, self.conf_th, 95 | self.min_bb, self.max_bb) 96 | 97 | return name2nbb 98 | 99 | def __del__(self): 100 | self.env.close() 101 | 102 | def get_dump(self, file_name): 103 | # hack for MRC 104 | dump = self.txn.get(file_name.encode('utf-8')) 105 | nbb = self.name2nbb[file_name] 106 | if self.compress: 107 | with io.BytesIO(dump) as reader: 108 | img_dump = np.load(reader, allow_pickle=True) 109 | img_dump = _fp16_to_fp32(img_dump) 110 | else: 111 | img_dump = msgpack.loads(dump, raw=False) 112 | img_dump = _fp16_to_fp32(img_dump) 113 | img_dump = {k: arr[:nbb, ...] 
for k, arr in img_dump.items()} 114 | return img_dump 115 | 116 | def __getitem__(self, file_name): 117 | dump = self.txn.get(file_name.encode('utf-8')) 118 | nbb = self.name2nbb[file_name] 119 | if self.compress: 120 | with io.BytesIO(dump) as reader: 121 | img_dump = np.load(reader, allow_pickle=True) 122 | img_dump = {'features': img_dump['features'], 123 | 'norm_bb': img_dump['norm_bb']} 124 | else: 125 | img_dump = msgpack.loads(dump, raw=False) 126 | img_feat = torch.tensor(img_dump['features'][:nbb, :]).float() 127 | img_bb = torch.tensor(img_dump['norm_bb'][:nbb, :]).float() 128 | return img_feat, img_bb 129 | 130 | 131 | @contextmanager 132 | def open_lmdb(db_dir, readonly=False): 133 | db = TxtLmdb(db_dir, readonly) 134 | try: 135 | yield db 136 | finally: 137 | del db 138 | 139 | 140 | class TxtLmdb(object): 141 | def __init__(self, db_dir, readonly=True): 142 | self.readonly = readonly 143 | if readonly: 144 | # training 145 | self.env = lmdb.open(db_dir, 146 | readonly=True, create=False, 147 | readahead=not _check_distributed()) 148 | self.txn = self.env.begin(buffers=True) 149 | self.write_cnt = None 150 | else: 151 | # prepro 152 | self.env = lmdb.open(db_dir, readonly=False, create=True, 153 | map_size=4 * 1024**4) 154 | self.txn = self.env.begin(write=True) 155 | self.write_cnt = 0 156 | 157 | def __del__(self): 158 | if self.write_cnt: 159 | self.txn.commit() 160 | self.env.close() 161 | 162 | def __getitem__(self, key): 163 | return msgpack.loads(decompress(self.txn.get(key.encode('utf-8'))), 164 | raw=False) 165 | 166 | def __setitem__(self, key, value): 167 | # NOTE: not thread safe 168 | if self.readonly: 169 | raise ValueError('readonly text DB') 170 | ret = self.txn.put(key.encode('utf-8'), 171 | compress(msgpack.dumps(value, use_bin_type=True))) 172 | self.write_cnt += 1 173 | if self.write_cnt % 1000 == 0: 174 | self.txn.commit() 175 | self.txn = self.env.begin(write=True) 176 | self.write_cnt = 0 177 | return ret 178 | 179 | 180 | class TxtTokLmdb(object): 181 | def __init__(self, db_dir, max_txt_len=60): 182 | if max_txt_len == -1: 183 | self.id2len = json.load(open(f'{db_dir}/id2len.json')) 184 | else: 185 | self.id2len = { 186 | id_: len_ 187 | for id_, len_ in json.load(open(f'{db_dir}/id2len.json') 188 | ).items() 189 | if len_ <= max_txt_len 190 | } 191 | self.db_dir = db_dir 192 | self.db = TxtLmdb(db_dir, readonly=True) 193 | meta = json.load(open(f'{db_dir}/meta.json', 'r')) 194 | self.cls_ = meta['CLS'] 195 | self.sep = meta['SEP'] 196 | self.mask = meta['MASK'] 197 | self.v_range = meta['v_range'] 198 | 199 | def __getitem__(self, id_): 200 | txt_dump = self.db[id_] 201 | return txt_dump 202 | 203 | def combine_inputs(self, *inputs): 204 | input_ids = [self.cls_] 205 | for ids in inputs: 206 | input_ids.extend(ids + [self.sep]) 207 | return torch.tensor(input_ids) 208 | 209 | @property 210 | def txt2img(self): 211 | txt2img = json.load(open(f'{self.db_dir}/txt2img.json')) 212 | return txt2img 213 | 214 | @property 215 | def img2txts(self): 216 | img2txts = json.load(open(f'{self.db_dir}/img2txts.json')) 217 | return img2txts 218 | 219 | 220 | def get_ids_and_lens(db): 221 | assert isinstance(db, TxtTokLmdb) 222 | lens = [] 223 | ids = [] 224 | for id_ in list(db.id2len.keys())[hvd.rank()::hvd.size()]: 225 | lens.append(db.id2len[id_]) 226 | ids.append(id_) 227 | return lens, ids 228 | 229 | 230 | class DetectFeatTxtTokDataset(Dataset): 231 | def __init__(self, txt_db, img_db): 232 | assert isinstance(txt_db, TxtTokLmdb) 233 | assert isinstance(img_db, 
DetectFeatLmdb) 234 | self.txt_db = txt_db 235 | self.img_db = img_db 236 | txt_lens, self.ids = get_ids_and_lens(txt_db) 237 | 238 | txt2img = txt_db.txt2img 239 | self.lens = [tl + self.img_db.name2nbb[txt2img[id_]] 240 | for tl, id_ in zip(txt_lens, self.ids)] 241 | 242 | def __len__(self): 243 | return len(self.ids) 244 | 245 | def __getitem__(self, i): 246 | id_ = self.ids[i] 247 | example = self.txt_db[id_] 248 | return example 249 | 250 | def _get_img_feat(self, fname): 251 | img_feat, bb = self.img_db[fname] 252 | img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) 253 | num_bb = img_feat.size(0) 254 | return img_feat, img_bb, num_bb 255 | 256 | 257 | def pad_tensors(tensors, lens=None, pad=0): 258 | """B x [T, ...]""" 259 | if lens is None: 260 | lens = [t.size(0) for t in tensors] 261 | max_len = max(lens) 262 | bs = len(tensors) 263 | hid = tensors[0].size(-1) 264 | dtype = tensors[0].dtype 265 | output = torch.zeros(bs, max_len, hid, dtype=dtype) 266 | if pad: 267 | output.data.fill_(pad) 268 | for i, (t, l) in enumerate(zip(tensors, lens)): 269 | output.data[i, :l, ...] = t.data 270 | return output 271 | 272 | 273 | def get_gather_index(txt_lens, num_bbs, batch_size, max_len, out_size): 274 | assert len(txt_lens) == len(num_bbs) == batch_size 275 | gather_index = torch.arange(0, out_size, dtype=torch.long, 276 | ).unsqueeze(0).repeat(batch_size, 1) 277 | 278 | for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)): 279 | gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb, dtype=torch.long).data 280 | return gather_index 281 | 282 | 283 | class ConcatDatasetWithLens(ConcatDataset): 284 | """ A thin wrapper on pytorch concat dataset for lens batching """ 285 | def __init__(self, datasets): 286 | super().__init__(datasets) 287 | self.lens = [l for dset in datasets for l in dset.lens] 288 | 289 | def __getattr__(self, name): 290 | return self._run_method_on_all_dsets(name) 291 | 292 | def _run_method_on_all_dsets(self, name): 293 | def run_all(*args, **kwargs): 294 | return [dset.__getattribute__(name)(*args, **kwargs) 295 | for dset in self.datasets] 296 | return run_all 297 | 298 | 299 | class ImageLmdbGroup(object): 300 | def __init__(self, conf_th, max_bb, min_bb, num_bb, compress): 301 | self.path2imgdb = {} 302 | self.conf_th = conf_th 303 | self.max_bb = max_bb 304 | self.min_bb = min_bb 305 | self.num_bb = num_bb 306 | self.compress = compress 307 | 308 | def __getitem__(self, path): 309 | img_db = self.path2imgdb.get(path, None) 310 | if img_db is None: 311 | img_db = DetectFeatLmdb(path, self.conf_th, self.max_bb, 312 | self.min_bb, self.num_bb, self.compress) 313 | return img_db 314 | -------------------------------------------------------------------------------- /model/layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | BERT layers from the huggingface implementation 3 | (https://github.com/huggingface/transformers) 4 | """ 5 | # coding=utf-8 6 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 7 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 
11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | import logging 21 | import math 22 | 23 | import torch 24 | from torch import nn 25 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 26 | import pdb 27 | from utils.heatmap import plot_attention_headmap 28 | import numpy as np 29 | 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | def gelu(x): 35 | """Implementation of the gelu activation function. 36 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 37 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 38 | Also see https://arxiv.org/abs/1606.08415 39 | """ 40 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 41 | 42 | 43 | def swish(x): 44 | return x * torch.sigmoid(x) 45 | 46 | 47 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 48 | 49 | 50 | class GELU(nn.Module): 51 | def forward(self, input_): 52 | output = gelu(input_) 53 | return output 54 | 55 | 56 | class BertSelfAttention(nn.Module): 57 | def __init__(self, config): 58 | super(BertSelfAttention, self).__init__() 59 | if config.hidden_size % config.num_attention_heads != 0: 60 | raise ValueError( 61 | "The hidden size (%d) is not a multiple of the number of attention " 62 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 63 | self.num_attention_heads = config.num_attention_heads # 12 64 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64 65 | self.all_head_size = self.num_attention_heads * self.attention_head_size # 768 66 | 67 | self.query = nn.Linear(config.hidden_size, self.all_head_size) # [768, 768] 68 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 69 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 70 | 71 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 72 | 73 | def transpose_for_scores(self, x): 74 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 75 | x = x.view(*new_x_shape) 76 | return x.permute(0, 2, 1, 3) 77 | 78 | def forward(self, hidden_states, attention_mask): 79 | mixed_query_layer = self.query(hidden_states) 80 | mixed_key_layer = self.key(hidden_states) 81 | mixed_value_layer = self.value(hidden_states) 82 | 83 | query_layer = self.transpose_for_scores(mixed_query_layer) 84 | key_layer = self.transpose_for_scores(mixed_key_layer) 85 | value_layer = self.transpose_for_scores(mixed_value_layer) 86 | 87 | # Take the dot product between "query" and "key" to get the raw attention scores. 88 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 89 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 90 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 91 | attention_scores = attention_scores + attention_mask 92 | 93 | # Normalize the attention scores to probabilities. 
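# The `attention_mask` added above is the extended additive mask built in UniterModel.forward
# (roughly (1.0 - mask) * -10000.0), so valid positions contribute 0 and padded positions a large
# negative value; the softmax below therefore assigns them a probability of essentially zero.
# A minimal illustration of the effect (values made up for the example):
#   >>> nn.Softmax(dim=-1)(torch.tensor([2.0, 1.0, 2.0 - 10000.0]))
#   tensor([0.7311, 0.2689, 0.0000])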
94 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 95 | 96 | # This is actually dropping out entire tokens to attend to, which might 97 | # seem a bit unusual, but is taken from the original Transformer paper. 98 | attention_probs = self.dropout(attention_probs) 99 | 100 | context_layer = torch.matmul(attention_probs, value_layer) 101 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 102 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 103 | context_layer = context_layer.view(*new_context_layer_shape) 104 | return context_layer 105 | 106 | def get_attention_probs(self, hidden_states, attention_mask, attention_dir='raw'): 107 | ''' 108 | hidden_states: [sample_num, max_tl+max_nbb, hidden_size(768)] 109 | attention_mask: [sample_num, 1, max_tl+max_nbb, max_tl+max_nbb] 110 | ''' 111 | mixed_query_layer = self.query(hidden_states) 112 | mixed_key_layer = self.key(hidden_states) 113 | 114 | query_layer = self.transpose_for_scores(mixed_query_layer) # [sample_num, attn_head_num, max_tl+max_nbb, attn_head_size] 115 | key_layer = self.transpose_for_scores(mixed_key_layer) 116 | 117 | # Take the dot product between "query" and "key" to get the raw attention scores. 118 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [sample_num, attn_head_num, max_tl+max_nbb, max_tl+max_nbb] 119 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 120 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 121 | attention_mask = (1.0 - attention_mask) * -10000.0 122 | attention_scores = attention_scores + attention_mask 123 | # Normalize the attention scores to probabilities. 124 | if attention_dir == 'raw': 125 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 126 | elif attention_dir == 'col': 127 | attention_probs = nn.Softmax(dim=-2)(attention_scores) 128 | else: 129 | raise ValueError('attention direction must be raw or col') 130 | return attention_probs 131 | 132 | 133 | class BertSelfOutput(nn.Module): 134 | def __init__(self, config): 135 | super(BertSelfOutput, self).__init__() 136 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 137 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 138 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 139 | 140 | def forward(self, hidden_states, input_tensor): 141 | hidden_states = self.dense(hidden_states) 142 | hidden_states = self.dropout(hidden_states) 143 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 144 | return hidden_states 145 | 146 | 147 | class BertAttention(nn.Module): 148 | def __init__(self, config): 149 | super(BertAttention, self).__init__() 150 | self.self = BertSelfAttention(config) 151 | self.output = BertSelfOutput(config) 152 | 153 | def forward(self, input_tensor, attention_mask): 154 | self_output = self.self(input_tensor, attention_mask) 155 | attention_output = self.output(self_output, input_tensor) 156 | return attention_output 157 | 158 | 159 | class BertIntermediate(nn.Module): 160 | def __init__(self, config): 161 | super(BertIntermediate, self).__init__() 162 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 163 | if isinstance(config.hidden_act, str): 164 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 165 | else: 166 | self.intermediate_act_fn = config.hidden_act 167 | 168 | def forward(self, hidden_states): 169 | hidden_states = self.dense(hidden_states) 170 | hidden_states = 
self.intermediate_act_fn(hidden_states) 171 | return hidden_states 172 | 173 | 174 | class BertOutput(nn.Module): 175 | def __init__(self, config): 176 | super(BertOutput, self).__init__() 177 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 178 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 179 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 180 | 181 | def forward(self, hidden_states, input_tensor): 182 | hidden_states = self.dense(hidden_states) 183 | hidden_states = self.dropout(hidden_states) 184 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 185 | return hidden_states 186 | 187 | 188 | class BertLayer(nn.Module): 189 | def __init__(self, config): 190 | super(BertLayer, self).__init__() 191 | self.attention = BertAttention(config) 192 | self.intermediate = BertIntermediate(config) 193 | self.output = BertOutput(config) 194 | 195 | def forward(self, hidden_states, attention_mask): 196 | attention_output = self.attention(hidden_states, attention_mask) 197 | intermediate_output = self.intermediate(attention_output) 198 | layer_output = self.output(intermediate_output, attention_output) 199 | return layer_output 200 | 201 | 202 | class BertPooler(nn.Module): 203 | def __init__(self, config): 204 | super(BertPooler, self).__init__() 205 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 206 | self.activation = nn.Tanh() 207 | 208 | def forward(self, hidden_states): 209 | # We "pool" the model by simply taking the hidden state corresponding 210 | # to the first token. 211 | first_token_tensor = hidden_states[:, 0] 212 | pooled_output = self.dense(first_token_tensor) 213 | pooled_output = self.activation(pooled_output) 214 | return pooled_output 215 | 216 | 217 | class BertPredictionHeadTransform(nn.Module): 218 | def __init__(self, config): 219 | super(BertPredictionHeadTransform, self).__init__() 220 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 221 | if isinstance(config.hidden_act, str): 222 | self.transform_act_fn = ACT2FN[config.hidden_act] 223 | else: 224 | self.transform_act_fn = config.hidden_act 225 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 226 | 227 | def forward(self, hidden_states): 228 | hidden_states = self.dense(hidden_states) 229 | hidden_states = self.transform_act_fn(hidden_states) 230 | hidden_states = self.LayerNorm(hidden_states) 231 | return hidden_states 232 | 233 | 234 | class BertLMPredictionHead(nn.Module): 235 | def __init__(self, config, bert_model_embedding_weights): 236 | super(BertLMPredictionHead, self).__init__() 237 | self.transform = BertPredictionHeadTransform(config) 238 | 239 | # The output weights are the same as the input embeddings, but there is 240 | # an output-only bias for each token. 
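# Concretely, `self.decoder.weight` is assigned the word-embedding matrix below
# (shape [vocab_size, hidden_size]), so MLM logits are scored against the same vectors
# used to embed the input tokens (weight tying), while `self.bias` remains a separately
# learned per-token parameter.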
241 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 242 | bert_model_embedding_weights.size(0), 243 | bias=False) 244 | self.decoder.weight = bert_model_embedding_weights 245 | self.bias = nn.Parameter( 246 | torch.zeros(bert_model_embedding_weights.size(0))) 247 | 248 | def forward(self, hidden_states): 249 | hidden_states = self.transform(hidden_states) 250 | hidden_states = self.decoder(hidden_states) + self.bias 251 | return hidden_states 252 | 253 | 254 | class BertOnlyMLMHead(nn.Module): 255 | def __init__(self, config, bert_model_embedding_weights): 256 | super(BertOnlyMLMHead, self).__init__() 257 | self.predictions = BertLMPredictionHead(config, 258 | bert_model_embedding_weights) 259 | 260 | def forward(self, sequence_output): 261 | prediction_scores = self.predictions(sequence_output) 262 | return prediction_scores 263 | -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | copy multi-head attention code from pytorch 3 | (https://github.com/pytorch/pytorch), 4 | """ 5 | import warnings 6 | 7 | import torch 8 | from torch.nn import Module, Parameter, Linear 9 | from torch.nn.init import xavier_normal_, xavier_uniform_, constant_ 10 | from torch.nn.functional import linear, softmax, dropout 11 | 12 | 13 | def multi_head_attention_forward(query, # type: Tensor 14 | key, # type: Tensor 15 | value, # type: Tensor 16 | embed_dim_to_check, # type: int 17 | num_heads, # type: int 18 | in_proj_weight, # type: Tensor 19 | in_proj_bias, # type: Tensor 20 | bias_k, # type: Optional[Tensor] 21 | bias_v, # type: Optional[Tensor] 22 | add_zero_attn, # type: bool 23 | dropout_p, # type: float 24 | out_proj_weight, # type: Tensor 25 | out_proj_bias, # type: Tensor 26 | training=True, # type: bool 27 | key_padding_mask=None, # type: Optional[Tensor] 28 | need_weights=True, # type: bool 29 | attn_mask=None, # type: Optional[Tensor] 30 | use_separate_proj_weight=False, # type: bool 31 | q_proj_weight=None, # type: Optional[Tensor] 32 | k_proj_weight=None, # type: Optional[Tensor] 33 | v_proj_weight=None, # type: Optional[Tensor] 34 | static_k=None, # type: Optional[Tensor] 35 | static_v=None # type: Optional[Tensor] 36 | ): 37 | # type: (...) -> Tuple[Tensor, Optional[Tensor]] 38 | r""" 39 | Args: 40 | query, key, value: map a query and a set of key-value pairs to an output. 41 | See "Attention Is All You Need" for more details. 42 | embed_dim_to_check: total dimension of the model. 43 | num_heads: parallel attention heads. 44 | in_proj_weight, in_proj_bias: input projection weight and bias. 45 | bias_k, bias_v: bias of the key and value sequences to be added at dim=0. 46 | add_zero_attn: add a new batch of zeros to the key and 47 | value sequences at dim=1. 48 | dropout_p: probability of an element to be zeroed. 49 | out_proj_weight, out_proj_bias: the output projection weight and bias. 50 | training: apply dropout if is ``True``. 51 | key_padding_mask: if provided, specified padding elements in the key will 52 | be ignored by the attention. This is an binary mask. When the value is True, 53 | the corresponding value on the attention layer will be filled with -inf. 54 | need_weights: output attn_output_weights. 55 | attn_mask: mask that prevents attention to certain positions. This is an additive mask 56 | (i.e. the values will be added to the attention layer). 57 | use_separate_proj_weight: the function accept the proj. 
weights for query, key, 58 | and value in differnt forms. If false, in_proj_weight will be used, which is 59 | a combination of q_proj_weight, k_proj_weight, v_proj_weight. 60 | q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. 61 | static_k, static_v: static key and value used for attention operators. 62 | 63 | 64 | Shape: 65 | Inputs: 66 | - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 67 | the embedding dimension. 68 | - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 69 | the embedding dimension. 70 | - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 71 | the embedding dimension. 72 | - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. 73 | - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 74 | - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, 75 | N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. 76 | - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, 77 | N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. 78 | 79 | Outputs: 80 | - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 81 | E is the embedding dimension. 82 | - attn_output_weights: :math:`(N, L, S)` where N is the batch size, 83 | L is the target sequence length, S is the source sequence length. 84 | """ 85 | 86 | qkv_same = torch.equal(query, key) and torch.equal(key, value) 87 | kv_same = torch.equal(key, value) 88 | 89 | tgt_len, bsz, embed_dim = query.size() 90 | assert embed_dim == embed_dim_to_check 91 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 92 | assert key.size() == value.size() 93 | 94 | head_dim = embed_dim // num_heads 95 | assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" 96 | scaling = float(head_dim) ** -0.5 97 | 98 | if use_separate_proj_weight is not True: 99 | if qkv_same: 100 | # self-attention 101 | q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) 102 | 103 | elif kv_same: 104 | # encoder-decoder attention 105 | # This is inline in_proj function with in_proj_weight and in_proj_bias 106 | _b = in_proj_bias 107 | _start = 0 108 | _end = embed_dim 109 | _w = in_proj_weight[_start:_end, :] 110 | if _b is not None: 111 | _b = _b[_start:_end] 112 | q = linear(query, _w, _b) 113 | 114 | if key is None: 115 | assert value is None 116 | k = None 117 | v = None 118 | else: 119 | 120 | # This is inline in_proj function with in_proj_weight and in_proj_bias 121 | _b = in_proj_bias 122 | _start = embed_dim 123 | _end = None 124 | _w = in_proj_weight[_start:, :] 125 | if _b is not None: 126 | _b = _b[_start:] 127 | k, v = linear(key, _w, _b).chunk(2, dim=-1) 128 | 129 | else: 130 | # This is inline in_proj function with in_proj_weight and in_proj_bias 131 | _b = in_proj_bias 132 | _start = 0 133 | _end = embed_dim 134 | _w = in_proj_weight[_start:_end, :] 135 | if _b is not None: 136 | _b = _b[_start:_end] 137 | q = linear(query, _w, _b) 138 | 139 | # This is inline in_proj function with in_proj_weight and in_proj_bias 140 | _b = in_proj_bias 141 | _start = embed_dim 142 | _end = embed_dim * 2 143 | _w = in_proj_weight[_start:_end, :] 144 | if _b is not None: 145 | _b = _b[_start:_end] 146 
| k = linear(key, _w, _b) 147 | 148 | # This is inline in_proj function with in_proj_weight and in_proj_bias 149 | _b = in_proj_bias 150 | _start = embed_dim * 2 151 | _end = None 152 | _w = in_proj_weight[_start:, :] 153 | if _b is not None: 154 | _b = _b[_start:] 155 | v = linear(value, _w, _b) 156 | else: 157 | q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) 158 | len1, len2 = q_proj_weight_non_opt.size() 159 | assert len1 == embed_dim and len2 == query.size(-1) 160 | 161 | k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) 162 | len1, len2 = k_proj_weight_non_opt.size() 163 | assert len1 == embed_dim and len2 == key.size(-1) 164 | 165 | v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) 166 | len1, len2 = v_proj_weight_non_opt.size() 167 | assert len1 == embed_dim and len2 == value.size(-1) 168 | 169 | if in_proj_bias is not None: 170 | q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) 171 | k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) 172 | v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) 173 | else: 174 | q = linear(query, q_proj_weight_non_opt, in_proj_bias) 175 | k = linear(key, k_proj_weight_non_opt, in_proj_bias) 176 | v = linear(value, v_proj_weight_non_opt, in_proj_bias) 177 | q = q * scaling 178 | 179 | if bias_k is not None and bias_v is not None: 180 | if static_k is None and static_v is None: 181 | k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) 182 | v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) 183 | if attn_mask is not None: 184 | attn_mask = torch.cat([attn_mask, 185 | torch.zeros((attn_mask.size(0), 1), 186 | dtype=attn_mask.dtype, 187 | device=attn_mask.device)], dim=1) 188 | if key_padding_mask is not None: 189 | key_padding_mask = torch.cat( 190 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 191 | dtype=key_padding_mask.dtype, 192 | device=key_padding_mask.device)], dim=1) 193 | else: 194 | assert static_k is None, "bias cannot be added to static key." 195 | assert static_v is None, "bias cannot be added to static value." 
196 | else: 197 | assert bias_k is None 198 | assert bias_v is None 199 | 200 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 201 | if k is not None: 202 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 203 | if v is not None: 204 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 205 | 206 | if static_k is not None: 207 | assert static_k.size(0) == bsz * num_heads 208 | assert static_k.size(2) == head_dim 209 | k = static_k 210 | 211 | if static_v is not None: 212 | assert static_v.size(0) == bsz * num_heads 213 | assert static_v.size(2) == head_dim 214 | v = static_v 215 | 216 | src_len = k.size(1) 217 | 218 | if key_padding_mask is not None: 219 | assert key_padding_mask.size(0) == bsz 220 | assert key_padding_mask.size(1) == src_len 221 | 222 | if add_zero_attn: 223 | src_len += 1 224 | k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) 225 | v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) 226 | if attn_mask is not None: 227 | attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1), 228 | dtype=attn_mask.dtype, 229 | device=attn_mask.device)], dim=1) 230 | if key_padding_mask is not None: 231 | key_padding_mask = torch.cat( 232 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 233 | dtype=key_padding_mask.dtype, 234 | device=key_padding_mask.device)], dim=1) 235 | 236 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) 237 | assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] 238 | 239 | if attn_mask is not None: 240 | attn_mask = attn_mask.unsqueeze(0) 241 | attn_output_weights += attn_mask 242 | 243 | if key_padding_mask is not None: 244 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 245 | attn_output_weights = attn_output_weights.masked_fill( 246 | key_padding_mask.unsqueeze(1).unsqueeze(2), 247 | float('-inf'), 248 | ) 249 | attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) 250 | 251 | attn_output_weights = softmax( 252 | attn_output_weights, dim=-1) 253 | attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) 254 | 255 | attn_output = torch.bmm(attn_output_weights, v) 256 | assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] 257 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 258 | attn_output = linear(attn_output, out_proj_weight, out_proj_bias) 259 | 260 | if need_weights: 261 | # average attention weights over heads 262 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 263 | return attn_output, attn_output_weights.sum(dim=1) / num_heads 264 | else: 265 | return attn_output, None 266 | 267 | 268 | class MultiheadAttention(Module): 269 | r"""Allows the model to jointly attend to information 270 | from different representation subspaces. 271 | See reference: Attention Is All You Need 272 | 273 | .. math:: 274 | \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O 275 | \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) 276 | 277 | Args: 278 | embed_dim: total dimension of the model. 279 | num_heads: parallel attention heads. 280 | dropout: a Dropout layer on attn_output_weights. Default: 0.0. 281 | bias: add bias as module parameter. Default: True. 282 | add_bias_kv: add bias to the key and value sequences at dim=0. 
283 | add_zero_attn: add a new batch of zeros to the key and 284 | value sequences at dim=1. 285 | kdim: total number of features in key. Default: None. 286 | vdim: total number of features in value. Default: None. 287 | 288 | Note: if kdim and vdim are None, they will be set to embed_dim such that 289 | query, key, and value have the same number of features. 290 | 291 | Examples:: 292 | 293 | >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) 294 | >>> attn_output, attn_output_weights = multihead_attn(query, key, value) 295 | """ 296 | 297 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): 298 | super(MultiheadAttention, self).__init__() 299 | self.embed_dim = embed_dim 300 | self.kdim = kdim if kdim is not None else embed_dim 301 | self.vdim = vdim if vdim is not None else embed_dim 302 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 303 | 304 | self.num_heads = num_heads 305 | self.dropout = dropout 306 | self.head_dim = embed_dim // num_heads 307 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 308 | 309 | self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) 310 | 311 | if self._qkv_same_embed_dim is False: 312 | self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) 313 | self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) 314 | self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) 315 | 316 | if bias: 317 | self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) 318 | else: 319 | self.register_parameter('in_proj_bias', None) 320 | self.out_proj = Linear(embed_dim, embed_dim, bias=bias) 321 | 322 | if add_bias_kv: 323 | self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) 324 | self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) 325 | else: 326 | self.bias_k = self.bias_v = None 327 | 328 | self.add_zero_attn = add_zero_attn 329 | 330 | self._reset_parameters() 331 | 332 | def _reset_parameters(self): 333 | if self._qkv_same_embed_dim: 334 | xavier_uniform_(self.in_proj_weight) 335 | else: 336 | xavier_uniform_(self.q_proj_weight) 337 | xavier_uniform_(self.k_proj_weight) 338 | xavier_uniform_(self.v_proj_weight) 339 | 340 | if self.in_proj_bias is not None: 341 | constant_(self.in_proj_bias, 0.) 342 | constant_(self.out_proj.bias, 0.) 343 | if self.bias_k is not None: 344 | xavier_normal_(self.bias_k) 345 | if self.bias_v is not None: 346 | xavier_normal_(self.bias_v) 347 | 348 | def forward(self, query, key, value, key_padding_mask=None, 349 | need_weights=True, attn_mask=None): 350 | r""" 351 | Args: 352 | query, key, value: map a query and a set of key-value pairs to an output. 353 | See "Attention Is All You Need" for more details. 354 | key_padding_mask: if provided, specified padding elements in the key will 355 | be ignored by the attention. This is a binary mask. When the value is True, 356 | the corresponding value on the attention layer will be filled with -inf. 357 | need_weights: output attn_output_weights. 358 | attn_mask: mask that prevents attention to certain positions. This is an additive mask 359 | (i.e. the values will be added to the attention layer). 360 | 361 | Shape: 362 | - Inputs: 363 | - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 364 | the embedding dimension.
365 | - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 366 | the embedding dimension. 367 | - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 368 | the embedding dimension. 369 | - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. 370 | - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 371 | 372 | - Outputs: 373 | - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 374 | E is the embedding dimension. 375 | - attn_output_weights: :math:`(N, L, S)` where N is the batch size, 376 | L is the target sequence length, S is the source sequence length. 377 | """ 378 | if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False: 379 | return multi_head_attention_forward( 380 | query, key, value, self.embed_dim, self.num_heads, 381 | self.in_proj_weight, self.in_proj_bias, 382 | self.bias_k, self.bias_v, self.add_zero_attn, 383 | self.dropout, self.out_proj.weight, self.out_proj.bias, 384 | training=self.training, 385 | key_padding_mask=key_padding_mask, need_weights=need_weights, 386 | attn_mask=attn_mask, use_separate_proj_weight=True, 387 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 388 | v_proj_weight=self.v_proj_weight) 389 | else: 390 | if not hasattr(self, '_qkv_same_embed_dim'): 391 | warnings.warn('A new version of MultiheadAttention module has been implemented. \ 392 | Please re-train your model with the new module', 393 | UserWarning) 394 | 395 | return multi_head_attention_forward( 396 | query, key, value, self.embed_dim, self.num_heads, 397 | self.in_proj_weight, self.in_proj_bias, 398 | self.bias_k, self.bias_v, self.add_zero_attn, 399 | self.dropout, self.out_proj.weight, self.out_proj.bias, 400 | training=self.training, 401 | key_padding_mask=key_padding_mask, need_weights=need_weights, 402 | attn_mask=attn_mask) 403 | -------------------------------------------------------------------------------- /train_itm_hard_negatives.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | UNITER finetuning for Image-Text Retrieval with hard negatives 6 | """ 7 | import argparse 8 | import os 9 | from os.path import exists, join 10 | from time import time 11 | 12 | import torch 13 | from torch.nn.utils import clip_grad_norm_ 14 | from torch.utils.data import DataLoader, ConcatDataset 15 | from apex import amp 16 | from horovod import torch as hvd 17 | from tqdm import tqdm 18 | 19 | from data import (PrefetchLoader, TxtTokLmdb, ImageLmdbGroup, 20 | ItmRankDatasetHardNegFromText, 21 | ItmRankDatasetHardNegFromImage, itm_rank_hn_collate, 22 | ItmValDataset, itm_val_collate, 23 | ItmEvalDataset, itm_eval_collate) 24 | from model.itm import UniterForImageTextRetrievalHardNeg 25 | from optim import get_lr_sched 26 | from optim.misc import build_optimizer 27 | 28 | from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file 29 | from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, 30 | broadcast_tensors) 31 | from utils.save import ModelSaver, save_training_meta 32 | from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed 33 | from utils.const import IMG_DIM 34 | from utils.itm_eval import evaluate 35 | from utils.training_signal_annealing import get_tsa_threshold 36 | 37 | 38 | def build_dataloader(dataset, collate_fn, is_train, opts): 39 | dataloader = DataLoader(dataset, batch_size=1, 40 | shuffle=is_train, drop_last=is_train, 41 | num_workers=opts.n_workers, 42 | pin_memory=opts.pin_mem, collate_fn=collate_fn) 43 | dataloader = PrefetchLoader(dataloader) 44 | return dataloader 45 | 46 | 47 | def main(opts): 48 | hvd.init() 49 | n_gpu = hvd.size() 50 | device = torch.device("cuda", hvd.local_rank()) 51 | torch.cuda.set_device(hvd.local_rank()) 52 | rank = hvd.rank() 53 | opts.rank = rank 54 | LOGGER.info("device: {} n_gpu: {}, rank: {}, " 55 | "16-bits training: {}".format( 56 | device, n_gpu, hvd.rank(), opts.fp16)) 57 | 58 | set_random_seed(opts.seed) 59 | 60 | if hvd.rank() == 0: 61 | save_training_meta(opts) 62 | TB_LOGGER.create(join(opts.output_dir, 'log')) 63 | pbar = tqdm(total=opts.num_train_steps) 64 | model_saver = ModelSaver(join(opts.output_dir, 'ckpt')) 65 | add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) 66 | # store ITM predictions 67 | os.makedirs(join(opts.output_dir, 'results_val')) 68 | os.makedirs(join(opts.output_dir, 'results_test')) 69 | os.makedirs(join(opts.output_dir, 'results_train')) 70 | else: 71 | LOGGER.disabled = True 72 | pbar = NoOp() 73 | model_saver = NoOp() 74 | 75 | # train_examples = None 76 | LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, " 77 | f"{opts.train_img_dbs}") 78 | # check multiple DBs 79 | assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \ 80 | "train txt_db and img_db have different length" 81 | 82 | # load DBs and image dirs 83 | all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb, 84 | opts.num_bb, opts.compressed_db) 85 | # train 86 | LOGGER.info(f"Loading Train Dataset " 87 | f"{opts.train_txt_dbs}, {opts.train_img_dbs}") 88 | train_datasets_t = [] 89 | train_datasets_i = [] 90 | for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs): 91 | img_db = all_img_dbs[img_path] 92 | txt_db = TxtTokLmdb(txt_path, opts.max_txt_len) 93 | train_datasets_t.append(ItmRankDatasetHardNegFromText(txt_db, img_db, opts.negative_size, IAIS=opts.IAIS)) 94 | train_datasets_i.append(ItmRankDatasetHardNegFromImage(txt_db, img_db, opts.negative_size, IAIS=opts.IAIS)) 95 | train_dataset_t = 
ConcatDataset(train_datasets_t) 96 | train_dataset_i = ConcatDataset(train_datasets_i) 97 | train_dataloader_t = build_dataloader(train_dataset_t, itm_rank_hn_collate, True, opts) 98 | train_dataloader_i = build_dataloader(train_dataset_i, itm_rank_hn_collate, True, opts) 99 | 100 | # val 101 | LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}") 102 | val_img_db = all_img_dbs[opts.val_img_db] 103 | val_txt_db = TxtTokLmdb(opts.val_txt_db, -1) 104 | val_dataset = ItmValDataset(val_txt_db, val_img_db, opts.inf_minibatch_size) 105 | val_dataloader = build_dataloader(val_dataset, itm_val_collate, False, opts) 106 | # eval 107 | LOGGER.info(f"Loading val, test Dataset for full evaluation: " 108 | f"{opts.val_txt_db}, {opts.val_img_db}" 109 | f"{opts.test_txt_db}, {opts.test_img_db}") 110 | eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db, opts.inf_minibatch_size) 111 | eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate, False, opts) 112 | test_img_db = all_img_dbs[opts.test_img_db] 113 | test_txt_db = TxtTokLmdb(opts.test_txt_db, -1) 114 | eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db, opts.inf_minibatch_size) 115 | eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate, False, opts) 116 | 117 | # Prepare model 118 | if opts.checkpoint: 119 | checkpoint = torch.load(opts.checkpoint) 120 | else: 121 | checkpoint = {} 122 | 123 | model = UniterForImageTextRetrievalHardNeg.from_pretrained( 124 | opts.model_config, state_dict=checkpoint, 125 | img_dim=IMG_DIM, margin=opts.margin, hard_size=opts.hard_neg_size) 126 | model.init_output() # pretrain ITM head is different from ranking head 127 | model.to(device) 128 | # make sure every process has same model parameters in the beginning 129 | broadcast_tensors([p.data for p in model.parameters()], 0) 130 | set_dropout(model, opts.dropout) 131 | 132 | # Prepare optimizer 133 | optimizer = build_optimizer(model, opts) 134 | model, optimizer = amp.initialize(model, optimizer, enabled=opts.fp16, opt_level='O2') 135 | 136 | LOGGER.info(f"***** Running training on {n_gpu} GPUs *****") 137 | LOGGER.info(" Num examples = %d", 138 | sum(all_gather_list(len(train_dataset_t)))) 139 | LOGGER.info(" Batch size = %d", opts.train_batch_size) 140 | LOGGER.info(" Num steps = %d", opts.num_train_steps) 141 | 142 | running_loss = RunningMeter('loss') 143 | if opts.IAIS: 144 | ranking_loss = RunningMeter('rank_loss') 145 | model.train() 146 | 147 | global_step = 0 148 | step = 0 149 | n_examples = 0 150 | n_hard_ex = 0 151 | start = time() 152 | train_iter_i = iter(train_dataloader_i) 153 | # quick hack for amp delay_unscale bug 154 | optimizer.zero_grad() 155 | optimizer.step() 156 | while True: 157 | for batch in train_dataloader_t: 158 | 159 | # hard text from image 160 | try: 161 | batch_i = next(train_iter_i) 162 | except StopIteration: 163 | train_iter_i = iter(train_dataloader_i) 164 | batch_i = next(train_iter_i) 165 | n_examples += batch_i['attn_masks'].size(0) # 400 166 | if opts.IAIS: 167 | rank_loss, self_attn_loss_per_layer = model(batch_i, sample_from='i', compute_loss=True, 168 | IAIS='V-%s' % opts.IAIS) # Interval training for linguistic and visual modality 169 | rank_loss = rank_loss.mean() / opts.train_batch_size 170 | self_attn_tsa_loss = self_attn_loss_per_layer['self_attn_loss'] * get_tsa_threshold('exp_schedule', 171 | global_step, 172 | opts.num_train_steps) 173 | self_attn_loss_per_layer['self_attn_loss_tsa'] = self_attn_tsa_loss 174 | ranking_loss(rank_loss.item()) 175 | 
loss = rank_loss + self_attn_tsa_loss 176 | else: 177 | loss = model(batch_i, sample_from='i', compute_loss=True, IAIS=opts.IAIS) 178 | loss = loss.mean() / opts.train_batch_size 179 | n_hard_ex += loss.numel() # 31 180 | with amp.scale_loss(loss, optimizer, delay_unscale=True) as scaled_loss: 181 | scaled_loss.backward() 182 | 183 | # hard image from text 184 | n_examples += batch['attn_masks'].size(0) # 400 185 | if opts.IAIS: 186 | rank_loss, self_attn_loss_per_layer = model(batch, sample_from='t', compute_loss=True, 187 | IAIS='L-%s' % opts.IAIS) # Interval training for linguistic and visual modality 188 | rank_loss = rank_loss.mean() / opts.train_batch_size 189 | self_attn_tsa_loss = self_attn_loss_per_layer['self_attn_loss'] * get_tsa_threshold('exp_schedule', 190 | global_step, 191 | opts.num_train_steps) 192 | self_attn_loss_per_layer['self_attn_loss_tsa'] = self_attn_tsa_loss 193 | ranking_loss(rank_loss.item()) 194 | loss = rank_loss + self_attn_tsa_loss 195 | else: 196 | loss = model(batch, sample_from='t', compute_loss=True, IAIS=opts.IAIS) 197 | loss = loss.mean() / opts.train_batch_size 198 | n_hard_ex += loss.numel() # 62 199 | # NOTE we use gradient accumulation to implemented train_batch_size 200 | 201 | step += 1 202 | delay_unscale = step % opts.train_batch_size != 0 203 | with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale) as scaled_loss: 204 | scaled_loss.backward() 205 | if not delay_unscale: 206 | # gather gradients from every processes 207 | # do this before unscaling to make sure every process uses 208 | # the same gradient scale 209 | grads = [p.grad.data for p in model.parameters() if p.requires_grad and p.grad is not None] 210 | all_reduce_and_rescale_tensors(grads, float(1)) 211 | 212 | running_loss(loss.item()) 213 | if step % opts.train_batch_size == 0: 214 | global_step += 1 215 | 216 | # learning rate scheduling 217 | lr_this_step = get_lr_sched(global_step, opts) 218 | for param_group in optimizer.param_groups: 219 | param_group['lr'] = lr_this_step 220 | TB_LOGGER.add_scalar('lr', lr_this_step, global_step) 221 | 222 | # log loss 223 | # NOTE: not gathered across GPUs for efficiency 224 | TB_LOGGER.add_scalar('loss', running_loss.val, global_step) 225 | if opts.IAIS: 226 | TB_LOGGER.log_scaler_dict(self_attn_loss_per_layer) 227 | TB_LOGGER.add_scalar('rank_loss', ranking_loss.val, global_step) 228 | TB_LOGGER.step() 229 | 230 | # update model params 231 | if opts.grad_norm != -1: 232 | grad_norm = clip_grad_norm_(amp.master_params(optimizer), opts.grad_norm) 233 | TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) 234 | optimizer.step() 235 | optimizer.zero_grad() 236 | pbar.update(1) 237 | 238 | if global_step % 100 == 0: 239 | # monitor training throughput 240 | LOGGER.info(f'------------Step {global_step}-------------') 241 | tot_ex = sum(all_gather_list(n_examples)) 242 | ex_per_sec = int(tot_ex / (time() - start)) 243 | tot_hn = sum(all_gather_list(n_hard_ex)) 244 | hn_per_sec = int(tot_hn / (time() - start)) 245 | LOGGER.info(f'{tot_ex} ({tot_hn}) examples (hard) ' 246 | f'trained at {ex_per_sec} ({hn_per_sec}) ex/s') 247 | TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec, global_step) 248 | TB_LOGGER.add_scalar('perf/hn_per_s', hn_per_sec, global_step) 249 | LOGGER.info(f'-------------------------------------------') 250 | 251 | if global_step % opts.valid_steps == 0: 252 | if opts.full_val: 253 | LOGGER.info( 254 | f"========================== Step {global_step} " 255 | f"==========================") 256 | val_log = 
evaluate(model, eval_loader_val) 257 | TB_LOGGER.log_scaler_dict( 258 | {f"valid/{k}": v for k, v in val_log.items()}) 259 | if hvd.rank() == 0: 260 | LOGGER.info(f"image retrieval R1: " 261 | f"{val_log['img_r1'] * 100:.2f},\n" 262 | f"image retrieval R5: " 263 | f"{val_log['img_r5'] * 100:.2f},\n" 264 | f"image retrieval R10: " 265 | f"{val_log['img_r10'] * 100:.2f}\n" 266 | f"text retrieval R1: " 267 | f"{val_log['txt_r1'] * 100:.2f},\n" 268 | f"text retrieval R5: " 269 | f"{val_log['txt_r5'] * 100:.2f},\n" 270 | f"text retrieval R10: " 271 | f"{val_log['txt_r10'] * 100:.2f}") 272 | LOGGER.info("=================================" 273 | "=================================") 274 | else: 275 | val_log = validate(model, val_dataloader) 276 | TB_LOGGER.log_scaler_dict(val_log) 277 | model_saver.save(model, global_step) 278 | 279 | if global_step >= opts.num_train_steps: 280 | break 281 | 282 | if global_step >= opts.num_train_steps: 283 | break 284 | 285 | pbar.close() 286 | # final validation 287 | val_log = validate(model, val_dataloader) 288 | TB_LOGGER.log_scaler_dict(val_log) 289 | model_saver.save(model, f'{global_step}_final') 290 | 291 | # evaluation 292 | for split, loader in [('test', eval_loader_test)]: 293 | eval_log = evaluate(model, loader) 294 | TB_LOGGER.log_scaler_dict({f"eval/{split}_{k}": v 295 | for k, v in eval_log.items()}) 296 | if hvd.rank() != 0: 297 | continue 298 | LOGGER.info( 299 | f"========================= {split} ===========================\n" 300 | f"image retrieval R1: {eval_log['img_r1'] * 100:.2f},\n" 301 | f"image retrieval R5: {eval_log['img_r5'] * 100:.2f},\n" 302 | f"image retrieval R10: {eval_log['img_r10'] * 100:.2f}\n" 303 | f"text retrieval R1: {eval_log['txt_r1'] * 100:.2f},\n" 304 | f"text retrieval R5: {eval_log['txt_r5'] * 100:.2f},\n" 305 | f"text retrieval R10: {eval_log['txt_r10'] * 100:.2f}") 306 | LOGGER.info("=========================================================") 307 | 308 | 309 | @torch.no_grad() 310 | def validate(model, val_loader): 311 | if hvd.rank() == 0: 312 | pbar = tqdm(total=len(val_loader)) 313 | else: 314 | pbar = NoOp() 315 | LOGGER.info("start running Image Retrieval validation ...") 316 | model.eval() 317 | n_ex = 0 318 | st = time() 319 | 320 | recall_at_1, recall_at_5, recall_at_10 = 0, 0, 0 321 | for batch in val_loader: 322 | scores = model(batch, compute_loss=False) 323 | _, indices = scores.squeeze(1).topk(10, dim=0) 324 | rank = (indices == 0).nonzero() 325 | if rank.numel(): 326 | rank = rank.item() 327 | if rank < 1: 328 | recall_at_1 += 1 329 | if rank < 5: 330 | recall_at_5 += 1 331 | if rank < 10: 332 | recall_at_10 += 1 333 | n_ex += 1 334 | pbar.update(1) 335 | n_ex = sum(all_gather_list(n_ex)) 336 | recall_at_1 = sum(all_gather_list(recall_at_1)) / n_ex 337 | recall_at_5 = sum(all_gather_list(recall_at_5)) / n_ex 338 | recall_at_10 = sum(all_gather_list(recall_at_10)) / n_ex 339 | tot_time = time() - st 340 | val_log = {'valid/ex_per_s': n_ex / tot_time, 341 | 'valid/recall_1': recall_at_1, 342 | 'valid/recall_5': recall_at_5, 343 | 'valid/recall_10': recall_at_10} 344 | model.train() 345 | LOGGER.info(f"validation finished in {int(tot_time)} seconds, " 346 | f"recall_1: {recall_at_1 * 100:.2f}, " 347 | f"recall_5: {recall_at_5 * 100:.2f}, " 348 | f"recall_10: {recall_at_10 * 100:.2f}") 349 | pbar.close() 350 | return val_log 351 | 352 | 353 | if __name__ == "__main__": 354 | parser = argparse.ArgumentParser() 355 | 356 | # Required parameters 357 | 358 | parser.add_argument('--compressed_db', 
action='store_true', 359 | help='use compressed LMDB') 360 | parser.add_argument("--checkpoint", 361 | default=None, type=str, 362 | help="path to pretrained UNITER (MLM) checkpoint") 363 | 364 | parser.add_argument("--output_dir", default=None, type=str, 365 | help="The output directory where the model " 366 | "checkpoints will be written.") 367 | 368 | # Prepro parameters 369 | parser.add_argument('--max_txt_len', type=int, default=60, 370 | help='max number of tokens in text (BERT BPE)') 371 | parser.add_argument('--conf_th', type=float, default=0.2, 372 | help='threshold for dynamic bounding boxes ' 373 | '(-1 for fixed)') 374 | parser.add_argument('--max_bb', type=int, default=100, 375 | help='max number of bounding boxes') 376 | parser.add_argument('--min_bb', type=int, default=10, 377 | help='min number of bounding boxes') 378 | parser.add_argument('--num_bb', type=int, default=36, 379 | help='static number of bounding boxes') 380 | 381 | # training parameters 382 | parser.add_argument("--train_batch_size", default=32, type=int, 383 | help="batch size (# positive examples) for training. " 384 | "(implemented with gradient accumulation)") 385 | 386 | parser.add_argument("--negative_size", default=511, type=int, 387 | help="Number of negative samples per positive sample " 388 | "(forward only)") 389 | parser.add_argument("--hard_neg_size", default=31, type=int, 390 | help="Number of hard negative samples " 391 | "per positive sample (actually used to train)") 392 | 393 | parser.add_argument("--inf_minibatch_size", default=512, type=int, 394 | help="batch size for running inference. " 395 | "(used for validation and evaluation)") 396 | 397 | parser.add_argument("--margin", default=0.2, type=float, 398 | help="margin of ranking loss") 399 | parser.add_argument("--learning_rate", default=3e-5, type=float, 400 | help="The initial learning rate for Adam.") 401 | parser.add_argument("--valid_steps", default=1000, type=int, 402 | help="Run validation every X steps") 403 | parser.add_argument("--num_train_steps", default=100000, type=int, 404 | help="Total number of training updates to perform.") 405 | parser.add_argument("--optim", default='adam', 406 | choices=['adam', 'adamax', 'adamw'], 407 | help="optimizer") 408 | parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', 409 | help="beta for adam optimizer") 410 | parser.add_argument("--dropout", default=0.1, type=float, 411 | help="tune dropout regularization") 412 | parser.add_argument("--weight_decay", default=0.01, type=float, 413 | help="weight decay (L2) regularization") 414 | parser.add_argument("--grad_norm", default=0.25, type=float, 415 | help="gradient clipping (-1 for no clipping)") 416 | parser.add_argument("--warmup_steps", default=4000, type=int, 417 | help="Number of training steps to perform linear " 418 | "learning rate warmup for.") 419 | 420 | # device parameters 421 | parser.add_argument('--seed', type=int, default=42, 422 | help="random seed for initialization") 423 | parser.add_argument('--full_val', action='store_true', 424 | help="Always run full evaluation during training") 425 | parser.add_argument('--fp16', action='store_true', 426 | help="Whether to use 16-bit float precision instead " 427 | "of 32-bit") 428 | parser.add_argument('--n_workers', type=int, default=4, 429 | help="number of data workers") 430 | parser.add_argument('--pin_mem', action='store_true', 431 | help="pin memory") 432 | 433 | # can use config files 434 | parser.add_argument('--config', help='JSON config files') 435 | 436 | parser.add_argument('--IAIS', default=False, choices=['distributed', 'singular', False], 437 | help='IAIS self-attention regularizer (singular or distributed)') 438 | 439 | args = parse_with_config(parser) 440 | 441 | if exists(args.output_dir) and os.listdir(args.output_dir): 442 | raise ValueError("Output directory ({}) already exists and is not " 443 | "empty.".format(args.output_dir)) 444 | 445 | # options safeguard 446 | if args.conf_th == -1: 447 | assert args.max_bb + args.max_txt_len + 2 <= 512 448 | else: 449 | assert args.num_bb + args.max_txt_len + 2 <= 512 450 | 451 | # for tensor core 452 | assert (args.negative_size + 1) % 8 == (args.hard_neg_size + 1) % 8 == 0 453 | 454 | main(args) 455 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | PyTorch modules 6 | some classes are modified from HuggingFace 7 | (https://github.com/huggingface/transformers) 8 | """ 9 | import copy 10 | import json 11 | import logging 12 | from io import open 13 | import torch 14 | from torch import nn 15 | from apex.normalization.fused_layer_norm import FusedLayerNorm 16 | import torch.nn.functional as F 17 | from .layer import BertLayer, BertPooler 18 | from utils.heatmap import plot_attention_headmap 19 | import pdb 20 | import sys 21 | import numpy as np 22 | import traceback 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class UniterConfig(object): 28 | """Configuration class to store the configuration of a `UniterModel`. 29 | """ 30 | 31 | def __init__(self, 32 | vocab_size_or_config_json_file, 33 | hidden_size=768, 34 | num_hidden_layers=12, 35 | num_attention_heads=12, 36 | intermediate_size=3072, 37 | hidden_act="gelu", 38 | hidden_dropout_prob=0.1, 39 | attention_probs_dropout_prob=0.1, 40 | max_position_embeddings=512, 41 | type_vocab_size=2, 42 | initializer_range=0.02): 43 | """Constructs UniterConfig. 44 | Args: 45 | vocab_size_or_config_json_file: Vocabulary size of `input_ids` in 46 | `UniterModel`. 47 | hidden_size: Size of the encoder layers and the pooler layer. 48 | num_hidden_layers: Number of hidden layers in the Transformer 49 | encoder. 50 | num_attention_heads: Number of attention heads for each attention 51 | layer in the Transformer encoder. 52 | intermediate_size: The size of the "intermediate" (i.e. 53 | feed-forward) layer in the Transformer encoder. 54 | hidden_act: The non-linear activation function (function or string) 55 | in the encoder and pooler. If string, "gelu", "relu" and 56 | "swish" are supported. 57 | hidden_dropout_prob: The dropout probability for all fully 58 | connected layers in the embeddings, encoder, and pooler. 59 | attention_probs_dropout_prob: The dropout ratio for the attention 60 | probabilities. 61 | max_position_embeddings: The maximum sequence length that this 62 | model might ever be used with. Typically set this to something 63 | large just in case (e.g., 512 or 1024 or 2048). 64 | type_vocab_size: The vocabulary size of the `token_type_ids` passed 65 | into `UniterModel`. 66 | initializer_range: The stdev of the truncated_normal_initializer 67 | for initializing all weight matrices.
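        Example (illustrative path; any of the json files under config/ works the same way):
            config = UniterConfig.from_json_file('config/uniter-base.json')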
68 | """ 69 | if isinstance(vocab_size_or_config_json_file, str): 70 | with open(vocab_size_or_config_json_file, 71 | "r", encoding='utf-8') as reader: 72 | json_config = json.loads(reader.read()) 73 | for key, value in json_config.items(): 74 | self.__dict__[key] = value 75 | elif isinstance(vocab_size_or_config_json_file, int): 76 | self.vocab_size = vocab_size_or_config_json_file 77 | self.hidden_size = hidden_size 78 | self.num_hidden_layers = num_hidden_layers 79 | self.num_attention_heads = num_attention_heads 80 | self.hidden_act = hidden_act 81 | self.intermediate_size = intermediate_size 82 | self.hidden_dropout_prob = hidden_dropout_prob 83 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 84 | self.max_position_embeddings = max_position_embeddings 85 | self.type_vocab_size = type_vocab_size 86 | self.initializer_range = initializer_range 87 | else: 88 | raise ValueError("First argument must be either a vocabulary size " 89 | "(int) or the path to a pretrained model config " 90 | "file (str)") 91 | 92 | @classmethod 93 | def from_dict(cls, json_object): 94 | """Constructs a `UniterConfig` from a 95 | Python dictionary of parameters.""" 96 | config = UniterConfig(vocab_size_or_config_json_file=-1) 97 | for key, value in json_object.items(): 98 | config.__dict__[key] = value 99 | return config 100 | 101 | @classmethod 102 | def from_json_file(cls, json_file): 103 | """Constructs a `UniterConfig` from a json file of parameters.""" 104 | with open(json_file, "r", encoding='utf-8') as reader: 105 | text = reader.read() 106 | return cls.from_dict(json.loads(text)) 107 | 108 | def __repr__(self): 109 | return str(self.to_json_string()) 110 | 111 | def to_dict(self): 112 | """Serializes this instance to a Python dictionary.""" 113 | output = copy.deepcopy(self.__dict__) 114 | return output 115 | 116 | def to_json_string(self): 117 | """Serializes this instance to a JSON string.""" 118 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 119 | 120 | 121 | class UniterPreTrainedModel(nn.Module): 122 | """ An abstract class to handle weights initialization and 123 | a simple interface for dowloading and loading pretrained models. 124 | """ 125 | 126 | def __init__(self, config, *inputs, **kwargs): 127 | super().__init__() 128 | if not isinstance(config, UniterConfig): 129 | raise ValueError( 130 | "Parameter config in `{}(config)` should be an instance of " 131 | "class `UniterConfig`. To create a model from a Google " 132 | "pretrained model use " 133 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 134 | self.__class__.__name__, self.__class__.__name__ 135 | )) 136 | self.config = config 137 | 138 | def init_weights(self, module): 139 | """ Initialize the weights. 140 | """ 141 | if isinstance(module, (nn.Linear, nn.Embedding)): 142 | # Slightly different from the TF version which uses 143 | # truncated_normal for initialization 144 | # cf https://github.com/pytorch/pytorch/pull/5617 145 | module.weight.data.normal_(mean=0.0, 146 | std=self.config.initializer_range) 147 | elif isinstance(module, FusedLayerNorm): 148 | module.bias.data.zero_() 149 | module.weight.data.fill_(1.0) 150 | if isinstance(module, nn.Linear) and module.bias is not None: 151 | module.bias.data.zero_() 152 | 153 | @classmethod 154 | def from_pretrained(cls, config_file, state_dict, *inputs, **kwargs): 155 | """ 156 | Instantiate a UniterPreTrainedModel from a pre-trained model file or a 157 | pytorch state dict. 
158 | Params: 159 | config_file: config json file 160 | state_dict: an state dictionnary 161 | *inputs, **kwargs: additional input for the specific Uniter class 162 | """ 163 | # Load config 164 | config = UniterConfig.from_json_file(config_file) 165 | logger.info("Model config {}".format(config)) 166 | # Instantiate model. 167 | model = cls(config, *inputs, **kwargs) 168 | # Load from a PyTorch state_dict 169 | old_keys = [] 170 | new_keys = [] 171 | for key in state_dict.keys(): 172 | new_key = None 173 | if 'gamma' in key: 174 | new_key = key.replace('gamma', 'weight') 175 | if 'beta' in key: 176 | new_key = key.replace('beta', 'bias') 177 | if new_key: 178 | old_keys.append(key) 179 | new_keys.append(new_key) 180 | for old_key, new_key in zip(old_keys, new_keys): 181 | state_dict[new_key] = state_dict.pop(old_key) 182 | 183 | missing_keys = [] 184 | unexpected_keys = [] 185 | error_msgs = [] 186 | # copy state_dict so _load_from_state_dict can modify it 187 | metadata = getattr(state_dict, '_metadata', None) 188 | state_dict = state_dict.copy() 189 | if metadata is not None: 190 | state_dict._metadata = metadata 191 | 192 | def load(module, prefix=''): 193 | local_metadata = ({} if metadata is None 194 | else metadata.get(prefix[:-1], {})) 195 | module._load_from_state_dict( 196 | state_dict, prefix, local_metadata, True, missing_keys, 197 | unexpected_keys, error_msgs) 198 | for name, child in module._modules.items(): 199 | if child is not None: 200 | load(child, prefix + name + '.') 201 | 202 | start_prefix = '' 203 | if not hasattr(model, 'bert') and any(s.startswith('bert.') 204 | for s in state_dict.keys()): 205 | start_prefix = 'bert.' 206 | load(model, prefix=start_prefix) 207 | if len(missing_keys) > 0: 208 | logger.info("Weights of {} not initialized from " 209 | "pretrained model: {}".format( 210 | model.__class__.__name__, missing_keys)) 211 | if len(unexpected_keys) > 0: 212 | logger.info("Weights from pretrained model not used in " 213 | "{}: {}".format( 214 | model.__class__.__name__, unexpected_keys)) 215 | if len(error_msgs) > 0: 216 | raise RuntimeError('Error(s) in loading state_dict for ' 217 | '{}:\n\t{}'.format( 218 | model.__class__.__name__, 219 | "\n\t".join(error_msgs))) 220 | return model 221 | 222 | 223 | class UniterTextEmbeddings(nn.Module): 224 | def __init__(self, config): 225 | super().__init__() 226 | self.word_embeddings = nn.Embedding(config.vocab_size, 227 | config.hidden_size, padding_idx=0) 228 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, 229 | config.hidden_size) 230 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, 231 | config.hidden_size) 232 | 233 | # self.LayerNorm is not snake-cased to stick with TensorFlow model 234 | # variable name and be able to load any TensorFlow checkpoint file 235 | self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12) 236 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 237 | 238 | def forward(self, input_ids, position_ids, token_type_ids=None): 239 | if token_type_ids is None: 240 | token_type_ids = torch.zeros_like(input_ids) 241 | 242 | words_embeddings = self.word_embeddings(input_ids) 243 | position_embeddings = self.position_embeddings(position_ids) 244 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 245 | 246 | embeddings = (words_embeddings 247 | + position_embeddings 248 | + token_type_embeddings) 249 | embeddings = self.LayerNorm(embeddings) 250 | embeddings = self.dropout(embeddings) 251 | return embeddings 252 | 253 | 254 
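# The image-side counterpart of UniterTextEmbeddings above: the 2048-d region features (IMG_DIM)
# and a 7-d normalized box feature are each projected to hidden_size and layer-normalized, summed
# with the token-type embedding, then passed through LayerNorm and dropout (mask_embedding is only
# added when regions are masked, e.g. for masked-region modeling). The 7-d box feature is built in
# data.py's _get_img_feat as torch.cat([bb, bb[:, 4:5] * bb[:, 5:]], dim=-1); assuming norm_bb
# stores (x1, y1, x2, y2, w, h), the appended column is w*h, i.e. the normalized box area
# (a reading of the code, not stated explicitly in the repo).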
| class UniterImageEmbeddings(nn.Module): 255 | def __init__(self, config, img_dim): 256 | super().__init__() 257 | self.img_linear = nn.Linear(img_dim, config.hidden_size) 258 | self.img_layer_norm = FusedLayerNorm(config.hidden_size, eps=1e-12) 259 | self.pos_layer_norm = FusedLayerNorm(config.hidden_size, eps=1e-12) 260 | self.pos_linear = nn.Linear(7, config.hidden_size) 261 | self.mask_embedding = nn.Embedding(2, img_dim, padding_idx=0) 262 | 263 | # tf naming convention for layer norm 264 | self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12) 265 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 266 | 267 | def forward(self, img_feat, img_pos_feat, type_embeddings, img_masks=None): 268 | if img_masks is not None: 269 | self.mask_embedding.weight.data[0, :].fill_(0) 270 | mask = self.mask_embedding(img_masks.long()) 271 | img_feat = img_feat + mask 272 | 273 | transformed_im = self.img_layer_norm(self.img_linear(img_feat)) 274 | transformed_pos = self.pos_layer_norm(self.pos_linear(img_pos_feat)) 275 | embeddings = transformed_im + transformed_pos + type_embeddings 276 | embeddings = self.LayerNorm(embeddings) 277 | embeddings = self.dropout(embeddings) 278 | return embeddings 279 | 280 | 281 | class UniterEncoder(nn.Module): 282 | def __init__(self, config): 283 | super().__init__() 284 | layer = BertLayer(config) 285 | self.layer = nn.ModuleList([copy.deepcopy(layer) 286 | for _ in range(config.num_hidden_layers)]) 287 | self.KLDivLoss = nn.KLDivLoss(reduction='batchmean') 288 | 289 | def get_attention_probs(self, layer_module, hidden_states, attn_mask, row_b, row_l, col_b=None, col_l=None): 290 | attn = layer_module.attention.self.get_attention_probs(hidden_states, 291 | attn_mask) # [sample_num, attn_head_num, ?, ?] 292 | attn = torch.mul(attn, attn_mask) 293 | attn = torch.narrow(attn, 2, row_b, row_l) 294 | if col_b is None and col_l is None: 295 | col_b, col_l = row_b, row_l 296 | attn = torch.narrow(attn, 3, col_b, col_l) 297 | attn = torch.mean(attn, dim=1) 298 | return attn 299 | 300 | def iais_distributed(self, txt_attn, img_attn, t2i_attn, i2t_attn, modal): 301 | if modal == 'L': 302 | pseudo_txt_attn = torch.matmul(t2i_attn, i2t_attn) 303 | iais_loss = self.KLDivLoss(torch.log(txt_attn + 1e-6), pseudo_txt_attn) + self.KLDivLoss( 304 | torch.log(pseudo_txt_attn + 1e-6), txt_attn) 305 | elif modal == 'V': 306 | pseudo_img_attn = torch.matmul(i2t_attn, t2i_attn) 307 | iais_loss = self.KLDivLoss(torch.log(img_attn + 1e-6), pseudo_img_attn) + self.KLDivLoss( 308 | torch.log(pseudo_img_attn + 1e-6), img_attn) 309 | else: 310 | raise ValueError('error modal') 311 | return iais_loss 312 | 313 | def iais_singular(self, txt_attn, img_attn, cross_attn, length, modal): 314 | index = cross_attn.argmax(-1).detach().cpu().numpy().tolist() 315 | rows = [[i] * length for i in index] 316 | cols = [index] * length 317 | if modal == 'L': 318 | pseudo_txt_attn = nn.Softmax(dim=-1)(img_attn[rows, cols]) 319 | iais_loss = self.KLDivLoss(txt_attn.log(), pseudo_txt_attn) + self.KLDivLoss(pseudo_txt_attn.log(), 320 | txt_attn) 321 | elif modal == 'V': 322 | pseudo_img_attn = nn.Softmax(dim=-1)(txt_attn[rows, cols]) 323 | iais_loss = self.KLDivLoss(img_attn.log(), pseudo_img_attn) + self.KLDivLoss(pseudo_img_attn.log(), 324 | img_attn) 325 | else: 326 | raise ValueError('error modal') 327 | return iais_loss 328 | 329 | def forward(self, input_, attention_mask, txt_attn_mask=None, img_attn_mask=None, 330 | t2i_attn_mask=None, i2t_attn_mask=None, max_tl=0, max_nbb=0, 331 | 
331 |                 output_all_encoded_layers=True, IAIS=False, pairs_num=3):
332 |         all_encoder_layers = []
333 |         self_attn_loss_per_layer = {}
334 |         hidden_states = input_
335 |         for i, layer_module in enumerate(self.layer):  # every layer_module is a bert_layer
336 |             if IAIS and i == len(self.layer) - 1:  # the IAIS loss is computed on the last layer only
337 |                 gt_indices = torch.tensor(list(range(0, hidden_states.size(0), pairs_num)),
338 |                                           dtype=torch.long, device=hidden_states.device)
339 |                 hidden_states_gt = hidden_states.index_select(0, gt_indices)  # select the ground-truth (positive) pairs
340 |                 txt_attn = self.get_attention_probs(layer_module, hidden_states_gt, txt_attn_mask, 1,
341 |                                                     max_tl - 2)  # remove [cls] and [sep]
342 |                 img_attn = self.get_attention_probs(layer_module, hidden_states_gt, img_attn_mask, max_tl, max_nbb)
343 |                 t2i_attn = self.get_attention_probs(layer_module, hidden_states_gt, t2i_attn_mask, 1, max_tl - 2,
344 |                                                     max_tl,
345 |                                                     max_nbb)  # [sample_num, max_tl-2, max_nbb]
346 |                 i2t_attn = self.get_attention_probs(layer_module, hidden_states_gt, i2t_attn_mask, max_tl, max_nbb, 1,
347 |                                                     max_tl - 2)  # [sample_num, max_nbb, max_tl-2]
348 | 
349 |                 self_attn_loss_layer_i = torch.tensor(0, dtype=hidden_states.dtype, device=hidden_states.device)
350 |                 for j, (input_len, nbb) in enumerate(
351 |                         zip(txt_attn_mask[:, 0, 1, :].sum(1), img_attn_mask[:, 0, max_tl, :].sum(1))):  # per-sample valid text length and region count
352 |                     input_len, nbb = int(input_len.item()), int(nbb.item())
353 |                     if IAIS == 'L-singular':
354 |                         iais_loss = self.iais_singular(txt_attn[j, :input_len, :input_len], img_attn[j, :nbb, :nbb],
355 |                                                        t2i_attn[j, :input_len], input_len, 'L')
356 |                     elif IAIS == 'V-singular':
357 |                         iais_loss = self.iais_singular(txt_attn[j, :input_len, :input_len], img_attn[j, :nbb, :nbb],
358 |                                                        i2t_attn[j, :nbb], nbb, 'V')
359 |                     elif IAIS == 'L-distributed':
360 |                         iais_loss = self.iais_distributed(txt_attn[j, :input_len, :input_len], img_attn[j, :nbb, :nbb],
361 |                                                           t2i_attn[j, :input_len, :nbb], i2t_attn[j, :nbb, :input_len], 'L')
362 |                     elif IAIS == 'V-distributed':
363 |                         iais_loss = self.iais_distributed(txt_attn[j, :input_len, :input_len], img_attn[j, :nbb, :nbb],
364 |                                                           t2i_attn[j, :input_len, :nbb], i2t_attn[j, :nbb, :input_len], 'V')
365 |                     else:
366 |                         raise ValueError("IAIS must be in ['L-distributed', 'V-distributed', 'L-singular', 'V-singular']")
367 | 
368 |                     self_attn_loss_layer_i += iais_loss
369 |                 self_attn_loss_per_layer['self_attn_loss/layer_%s' % i] = self_attn_loss_layer_i / gt_indices.size(0)
370 |                 self_attn_loss_per_layer['self_attn_loss'] = self_attn_loss_per_layer['self_attn_loss/layer_%s' % i]
371 |             hidden_states = layer_module(hidden_states, attention_mask)
372 |             if output_all_encoded_layers:
373 |                 all_encoder_layers.append(hidden_states)
374 |         if not output_all_encoded_layers:
375 |             all_encoder_layers.append(hidden_states)
376 |         if IAIS:
377 |             return all_encoder_layers, self_attn_loss_per_layer
378 |         else:
379 |             return all_encoder_layers
380 | 
381 | 
382 | class UniterModel(UniterPreTrainedModel):
383 |     """ Modification for Joint Vision-Language Encoding
384 |     """
385 | 
386 |     def __init__(self, config, img_dim):
387 |         super().__init__(config)
388 |         self.embeddings = UniterTextEmbeddings(config)
389 |         self.img_embeddings = UniterImageEmbeddings(config, img_dim)
390 |         self.encoder = UniterEncoder(config)
391 |         self.pooler = BertPooler(config)
392 |         self.apply(self.init_weights)
393 | 
394 |     def _compute_txt_embeddings(self, input_ids, position_ids,
395 |                                 txt_type_ids=None):
396 |         output = self.embeddings(input_ids, position_ids, txt_type_ids)
397 |         return output
398 | 
399 |     def _compute_img_embeddings(self, img_feat, img_pos_feat, img_masks=None,
400 |                                 img_type_ids=None):
401 |         if img_type_ids is None:
402 |             img_type_ids = torch.ones_like(img_feat[:, :, 0].long())
403 |         img_type_embeddings = self.embeddings.token_type_embeddings(
404 |             img_type_ids)
405 |         output = self.img_embeddings(img_feat, img_pos_feat,
406 |                                      img_type_embeddings, img_masks)
407 |         return output
408 | 
409 |     def _compute_img_txt_embeddings(self, input_ids, position_ids,
410 |                                     img_feat, img_pos_feat,
411 |                                     gather_index, img_masks=None,
412 |                                     txt_type_ids=None, img_type_ids=None):
413 |         txt_emb = self._compute_txt_embeddings(  # [sample_num, token_num, 768]
414 |             input_ids, position_ids, txt_type_ids)
415 |         img_emb = self._compute_img_embeddings(  # [sample_num, bb_max_num, 768]
416 |             img_feat, img_pos_feat, img_masks, img_type_ids)
417 |         if gather_index is not None:  # evaluation
418 |             # align back to most compact input
419 |             gather_index = gather_index.unsqueeze(-1).expand(  # [sample_num, ?, 768]
420 |                 -1, -1, self.config.hidden_size)
421 |             embedding_output = torch.gather(torch.cat([txt_emb, img_emb], dim=1),  # [sample_num, ?, 768]
422 |                                             dim=1, index=gather_index)
423 |         else:
424 |             embedding_output = torch.cat([txt_emb, img_emb], dim=1)
425 |         return embedding_output
426 | 
427 |     def extend_self_attn_mask(self, attention_mask):
428 |         '''note: the input mask is 0/1-valued; returns a binary pairwise mask [sample_num, 1, len, len]'''
429 |         attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
430 |         attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)
431 |         attention_mask = torch.matmul(attention_mask.permute(0, 1, 3, 2), attention_mask)
432 |         return attention_mask
433 | 
434 |     def extend_cross_attn_mask(self, txt_attn_mask, img_attn_mask):
435 |         txt_attn_mask = txt_attn_mask.unsqueeze(1).unsqueeze(2)
436 |         txt_attn_mask = txt_attn_mask.to(dtype=next(self.parameters()).dtype)
437 |         img_attn_mask = img_attn_mask.unsqueeze(1).unsqueeze(2)
438 |         img_attn_mask = img_attn_mask.to(dtype=next(self.parameters()).dtype)
439 |         t2i_attn_mask = torch.matmul(txt_attn_mask.permute(0, 1, 3, 2), img_attn_mask)
440 |         i2t_attn_mask = torch.matmul(img_attn_mask.permute(0, 1, 3, 2), txt_attn_mask)
441 |         return t2i_attn_mask, i2t_attn_mask
442 | 
443 |     def forward(self, input_ids, position_ids,
444 |                 img_feat, img_pos_feat,
445 |                 attention_mask, gather_index=None, img_masks=None,
446 |                 txt_attn_mask=None, img_attn_mask=None,
447 |                 output_all_encoded_layers=True,
448 |                 IAIS=False,
449 |                 txt_type_ids=None, img_type_ids=None, pairs_num=3):
450 |         '''
451 |         input_ids: [sample_num, max_tl], position_ids: [1, max_tl]
452 |         img_feat: [sample_num, max_nbb, 2048], img_pos_feat: [sample_num, max_nbb, 7]
453 |         attention_mask: [sample_num, max_attn_len(max_tl+max_nbb)]
454 |         '''
455 |         # compute self-attention mask
456 |         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(
457 |             2)  # [sample_num, 1, 1, max_attn_len(max_tl+max_nbb)]
458 |         extended_attention_mask = extended_attention_mask.to(
459 |             dtype=next(self.parameters()).dtype)  # fp16 compatibility
460 |         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
461 | 
462 |         # embedding layer
463 |         if input_ids is None:
464 |             # image only
465 |             embedding_output = self._compute_img_embeddings(
466 |                 img_feat, img_pos_feat, img_masks, img_type_ids)
467 |         elif img_feat is None:
468 |             # text only
469 |             embedding_output = self._compute_txt_embeddings(
470 |                 input_ids, position_ids, txt_type_ids)
471 |         else:
472 |             embedding_output = self._compute_img_txt_embeddings(
473 |                 input_ids, position_ids,
474 |                 img_feat, img_pos_feat,
475 |                 gather_index, img_masks, txt_type_ids, img_type_ids)
476 | 
477 |         if IAIS:  # train & IAIS
478 |             assert txt_attn_mask is not None and img_attn_mask is not None
479 |             extended_txt_attn_mask = self.extend_self_attn_mask(
480 |                 txt_attn_mask)  # [sample_num, 1, max_attn_len, max_attn_len]
481 |             extended_img_attn_mask = self.extend_self_attn_mask(img_attn_mask)
482 |             extended_t2i_attn_mask, extended_i2t_attn_mask = self.extend_cross_attn_mask(txt_attn_mask, img_attn_mask)
483 | 
484 |             encoded_layers, self_attn_loss_per_layer = self.encoder(
485 |                 embedding_output, extended_attention_mask,
486 |                 extended_txt_attn_mask, extended_img_attn_mask,
487 |                 extended_t2i_attn_mask, extended_i2t_attn_mask,
488 |                 input_ids.size(1), img_feat.size(1),
489 |                 output_all_encoded_layers=output_all_encoded_layers,
490 |                 IAIS=IAIS,
491 |                 pairs_num=pairs_num)
492 |             if not output_all_encoded_layers:
493 |                 encoded_layers = encoded_layers[-1]
494 |             return encoded_layers, self_attn_loss_per_layer
495 |         else:  # evaluation
496 |             encoded_layers = self.encoder(
497 |                 embedding_output, extended_attention_mask,
498 |                 output_all_encoded_layers=output_all_encoded_layers)
499 |             if not output_all_encoded_layers:
500 |                 encoded_layers = encoded_layers[-1]
501 |             return encoded_layers
502 | 
--------------------------------------------------------------------------------
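Usage note (added commentary, not a file from the repository): the snippet below is a minimal sketch of driving the UniterModel defined above through its evaluation path (IAIS=False) on random inputs. It assumes the repository's runtime environment (PyTorch plus apex, whose FusedLayerNorm typically requires a GPU), that the classes are importable as model.model, and it reuses IMG_DIM = 2048 from utils/const.py and the 7-dimensional box geometry expected by UniterImageEmbeddings; batch and sequence sizes are illustrative only.

# hypothetical smoke test, not part of the repo
import torch
from model.model import UniterConfig, UniterModel  # assumed import path (model/model.py)

config = UniterConfig.from_json_file('config/uniter-base.json')
model = UniterModel(config, img_dim=2048).cuda().eval()  # IMG_DIM = 2048 (utils/const.py)

B, max_tl, max_nbb = 2, 12, 20  # illustrative sizes
input_ids = torch.randint(1, config.vocab_size, (B, max_tl)).cuda()
position_ids = torch.arange(max_tl).unsqueeze(0).cuda()            # [1, max_tl]
img_feat = torch.randn(B, max_nbb, 2048).cuda()                    # region features
img_pos_feat = torch.randn(B, max_nbb, 7).cuda()                   # 7-d box geometry
attention_mask = torch.ones(B, max_tl + max_nbb, dtype=torch.long).cuda()

with torch.no_grad():  # evaluation branch: no IAIS attention masks needed
    seq_out = model(input_ids, position_ids, img_feat, img_pos_feat,
                    attention_mask, output_all_encoded_layers=False)
print(seq_out.shape)   # [B, max_tl + max_nbb, hidden_size]

Training with the IAIS loss additionally requires txt_attn_mask/img_attn_mask and batches arranged so that every pairs_num-th sample is a ground-truth pair; the full training setup presumably lives in train_itm_hard_negatives.py in this repository.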