├── UNITER ├── utils │ ├── __init__.py │ ├── const.py │ ├── misc.py │ ├── logger.py │ ├── save.py │ ├── itm_eval.py │ └── distributed.py ├── optim │ ├── __init__.py │ ├── misc.py │ ├── sched.py │ └── adamw.py ├── scripts │ ├── convert_ckpt.py │ ├── download_pretrained.sh │ ├── eval_refcoco.sh │ ├── extract_imgfeat.sh │ ├── create_imgdb.sh │ ├── eval_refcoco+.sh │ ├── eval_refcocog.sh │ ├── download_ve.sh │ ├── eval_refgta.sh │ ├── download_re.sh │ ├── download_vqa.sh │ ├── download_vcr.sh │ ├── download_nlvr2.sh │ ├── create_txtdb.sh │ ├── download_itm.sh │ ├── download_indomain.sh │ ├── create_txtdb_re.sh │ └── convert_imgdir.py ├── data │ ├── __init__.py │ ├── sampler.py │ ├── loader.py │ └── data.py ├── config │ ├── uniter-base.json │ ├── uniter-large.json │ ├── train-refcocog-large-1gpu.json │ └── train-refcoco+-large-1gpu.json ├── model │ ├── ve.py │ ├── vqa.py │ ├── ot.py │ ├── vcr.py │ ├── itm.py │ ├── re.py │ ├── pretrain_vcr.py │ └── nlvr2.py ├── launch_container.sh ├── UNITER_LICENSE ├── requirements.txt ├── Dockerfile ├── prepro.py └── inf_re.py ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz └── simple_tokenizer.py ├── clip_mm_explain ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── simple_tokenizer.py └── clip.py ├── methods ├── __init__.py ├── ref_method.py ├── random_method.py └── baseline.py ├── albef ├── config_bert.json ├── utils.py ├── config.yaml ├── albef_license.txt └── vit.py ├── UniDet └── extract_boxes.py ├── requirements.txt ├── lattice.py ├── clevr-dataset-gen ├── bounding_box.py └── gather_simple_clevr.py ├── README.md ├── heuristics.py ├── pytorch_grad_cam └── activations_and_gradients.py ├── py-bottom-up-attention └── extract_features.py ├── generic_clip_pairs.py ├── entity_extraction.py └── interpreter.py /UNITER/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip_mm_explain/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pumpkin805/FALIP/HEAD/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .baseline import Baseline 2 | from .random_method import Random 3 | from .parse import Parse 4 | -------------------------------------------------------------------------------- /clip_mm_explain/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pumpkin805/FALIP/HEAD/clip_mm_explain/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /UNITER/utils/const.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | constants 6 | """ 7 | IMG_DIM = 2048 8 | IMG_LABEL_DIM = 1601 9 | BUCKET_SIZE = 8192 10 | -------------------------------------------------------------------------------- /UNITER/optim/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | """ 6 | from .sched import noam_schedule, warmup_linear, vqa_schedule, get_lr_sched 7 | from .adamw import AdamW 8 | -------------------------------------------------------------------------------- /UNITER/scripts/convert_ckpt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | 4 | import torch 5 | 6 | bert_ckpt, output_ckpt = sys.argv[1:] 7 | 8 | bert = torch.load(bert_ckpt) 9 | uniter = OrderedDict() 10 | for k, v in bert.items(): 11 | uniter[k.replace('bert', 'uniter')] = v 12 | 13 | torch.save(uniter, output_ckpt) 14 | -------------------------------------------------------------------------------- /UNITER/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | """ 6 | from .data import (DetectFeatPt,) 7 | from .sampler import TokenBucketSampler 8 | from .loader import PrefetchLoader, MetaLoader 9 | from .re import (ReTxtTokJson, ReTrainJsonDataset, ReEvalJsonDataset, 10 | re_collate, re_eval_collate) 11 | -------------------------------------------------------------------------------- /UNITER/config/uniter-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 28996 13 | } 14 | -------------------------------------------------------------------------------- /UNITER/config/uniter-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 1024, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 4096, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 16, 10 | "num_hidden_layers": 24, 11 | "type_vocab_size": 2, 12 | "vocab_size": 28996 13 | } 14 | -------------------------------------------------------------------------------- /UNITER/model/ve.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | UNITER for VE model 6 | """ 7 | from .vqa import UniterForVisualQuestionAnswering 8 | 9 | 10 | class UniterForVisualEntailment(UniterForVisualQuestionAnswering): 11 | """ Finetune UNITER for VE 12 | """ 13 | def __init__(self, config, img_dim): 14 | super().__init__(config, img_dim, 3) 15 | -------------------------------------------------------------------------------- /methods/ref_method.py: -------------------------------------------------------------------------------- 1 | """Base class for a method for doing referring expressions.""" 2 | 3 | from typing import Dict, Any 4 | from abc import ABCMeta, abstractmethod 5 | 6 | 7 | class RefMethod(metaclass=ABCMeta): 8 | @abstractmethod 9 | def execute(self, caption: str, env: "Environment") -> Dict[str, Any]: 10 | return NotImplemented 11 | 12 | def get_stats(self) -> Dict[str, Any]: 13 | return {} 14 | -------------------------------------------------------------------------------- /UNITER/scripts/download_pretrained.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | DOWNLOAD=$1 5 | 6 | if [ ! -d $DOWNLOAD/pretrained ] ; then 7 | mkdir -p $DOWNLOAD/pretrained 8 | fi 9 | 10 | BLOB='https://acvrpublicycchen.blob.core.windows.net/uniter' 11 | 12 | for MODEL in uniter-base uniter-large; do 13 | # This will overwrite models 14 | wget $BLOB/pretrained/$MODEL.pt -O $DOWNLOAD/pretrained/$MODEL.pt 15 | done 16 | -------------------------------------------------------------------------------- /UNITER/scripts/eval_refcoco.sh: -------------------------------------------------------------------------------- 1 | OUT_DIR=$1 2 | python inf_re.py \ 3 | --txt_db /txt/refcoco_val.db:/txt/refcoco_testA.db:/txt/refcoco_testB.db \ 4 | --img_db /img/re_coco_gt \ 5 | --output_dir $OUT_DIR \ 6 | --checkpoint best \ 7 | --tmp_file re_exp/tmp_refcoco.txt 8 | 9 | python inf_re.py \ 10 | --txt_db /txt/refcoco_val.db:/txt/refcoco_testA.db:/txt/refcoco_testB.db \ 11 | --img_db /img/re_coco_det \ 12 | --output_dir $OUT_DIR \ 13 | --checkpoint best \ 14 | --tmp_file re_exp/tmp_refcoco.txt 15 | -------------------------------------------------------------------------------- /UNITER/scripts/extract_imgfeat.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | IMG_DIR=$1 5 | OUT_DIR=$2 6 | 7 | set -e 8 | 9 | echo "extracting image features..." 10 | if [ ! 
-d $OUT_DIR ]; then 11 | mkdir -p $OUT_DIR 12 | fi 13 | docker run --gpus '"'device=$CUDA_VISIBLE_DEVICES'"' --ipc=host --rm \ 14 | --mount src=$IMG_DIR,dst=/img,type=bind,readonly \ 15 | --mount src=$OUT_DIR,dst=/output,type=bind \ 16 | -w /src chenrocks/butd-caffe:nlvr2 \ 17 | bash -c "python tools/generate_npz.py --gpu 0" 18 | 19 | echo "done" 20 | -------------------------------------------------------------------------------- /albef/config_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "fusion_layer": 6, 20 | "encoder_width": 768 21 | } 22 | -------------------------------------------------------------------------------- /UNITER/scripts/create_imgdb.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | IMG_NPY=$1 5 | OUT_DIR=$2 6 | 7 | set -e 8 | 9 | echo "converting image features ..." 10 | if [ ! -d $OUT_DIR ]; then 11 | mkdir -p $OUT_DIR 12 | fi 13 | NAME=$(basename $IMG_NPY) 14 | docker run --ipc=host --rm -it \ 15 | --mount src=$(pwd),dst=/src,type=bind \ 16 | --mount src=$OUT_DIR,dst=/img_db,type=bind \ 17 | --mount src=$IMG_NPY,dst=/$NAME,type=bind,readonly \ 18 | -w /src chenrocks/uniter \ 19 | python scripts/convert_imgdir.py --img_dir /$NAME --output /img_db 20 | 21 | echo "done" 22 | -------------------------------------------------------------------------------- /albef/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def pre_caption(caption,max_words): 4 | caption = re.sub( 5 | r"([,.'!?\"()*#:;~])", 6 | '', 7 | caption.lower(), 8 | ).replace('-', ' ').replace('/', ' ').replace('', 'person') 9 | 10 | caption = re.sub( 11 | r"\s{2,}", 12 | ' ', 13 | caption, 14 | ) 15 | caption = caption.rstrip('\n') 16 | caption = caption.strip(' ') 17 | 18 | #truncate caption 19 | caption_words = caption.split(' ') 20 | if len(caption_words)>max_words: 21 | caption = ' '.join(caption_words[:max_words]) 22 | 23 | return caption 24 | -------------------------------------------------------------------------------- /UNITER/scripts/eval_refcoco+.sh: -------------------------------------------------------------------------------- 1 | OUT_DIR=$1 2 | ROOT_DIR=/PATH/TO/DIRECTORY/CONTAINING/FEATURE/FILES 3 | 4 | python inf_re.py \ 5 | --txt_db /PATH/TO/refcoco+_val.jsonl \ 6 | --img_db $ROOT_DIR/refcoco+_val_gt_boxes10100.pt \ 7 | --output_dir $OUT_DIR \ 8 | --checkpoint best \ 9 | --simple_format --n_workers 1 --batch_size 128 --tmp_file re_exp/tmp_refcoco+.txt 10 | 11 | python inf_re.py \ 12 | --txt_db /PATH/TO/refcoco+_val.jsonl \ 13 | --img_db $ROOT_DIR/refcoco+_val_dt_boxes10100.pt \ 14 | --output_dir $OUT_DIR \ 15 | --checkpoint best \ 16 | --simple_format --n_workers 1 --batch_size 128 --tmp_file re_exp/tmp_refcoco+.txt 17 | -------------------------------------------------------------------------------- /UNITER/scripts/eval_refcocog.sh: 
-------------------------------------------------------------------------------- 1 | OUT_DIR=$1 2 | ROOT_DIR=/PATH/TO/DIRECTORY/CONTAINING/FEATURE/FILES 3 | 4 | python inf_re.py \ 5 | --txt_db /PATH/TO/refcocog_val.jsonl \ 6 | --img_db $ROOT_DIR/refcocog_val_gt_boxes10100.pt \ 7 | --output_dir $OUT_DIR \ 8 | --checkpoint best \ 9 | --simple_format --n_workers 1 --batch_size 128 --tmp_file re_exp/tmp_refcocog.txt 10 | 11 | python inf_re.py \ 12 | --txt_db /PATH/TO/refcocog_val.jsonl \ 13 | --img_db $ROOT_DIR/refcocog_val_dt_boxes10100.pt \ 14 | --output_dir $OUT_DIR \ 15 | --checkpoint best \ 16 | --simple_format --n_workers 1 --batch_size 128 --tmp_file re_exp/tmp_refcocog.txt 17 | -------------------------------------------------------------------------------- /UNITER/launch_container.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | TXT_DB=$1 5 | IMG_DIR=$2 6 | OUTPUT=$3 7 | PRETRAIN_DIR=$4 8 | 9 | if [ -z $CUDA_VISIBLE_DEVICES ]; then 10 | CUDA_VISIBLE_DEVICES='all' 11 | fi 12 | 13 | 14 | docker run --gpus '"'device=$CUDA_VISIBLE_DEVICES'"' --ipc=host --rm -it \ 15 | --mount src=$(pwd),dst=/src,type=bind \ 16 | --mount src=$OUTPUT,dst=/storage,type=bind \ 17 | --mount src=$PRETRAIN_DIR,dst=/pretrain,type=bind,readonly \ 18 | --mount src=$TXT_DB,dst=/txt,type=bind,readonly \ 19 | --mount src=$IMG_DIR,dst=/img,type=bind,readonly \ 20 | -e NVIDIA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ 21 | -w /src chenrocks/uniter 22 | -------------------------------------------------------------------------------- /albef/config.yaml: -------------------------------------------------------------------------------- 1 | train_file: ['data/refcoco+_train.json'] 2 | test_file: ['data/refcoco+_val.json','data/refcoco+_test.json'] 3 | 4 | refcoco_data: 'data' 5 | det_file: 'data/refcoco+/dets.json' 6 | coco_file: 'data/refcoco+/cocos.json' 7 | 8 | image_root: '/home/sanjays/refer-zero-shot/' 9 | 10 | bert_config: 'albef/config_bert.json' 11 | 12 | image_res: 384 13 | batch_size: 8 14 | 15 | queue_size: 65536 16 | momentum: 0.995 17 | vision_width: 768 18 | embed_dim: 256 19 | temp: 0.07 20 | 21 | alpha: 0.4 22 | distill: True 23 | warm_up: True 24 | 25 | optimizer: {opt: adamW, lr: 1e-5, weight_decay: 0.02} 26 | schedular: {sched: cosine, lr: 1e-5, epochs: 5, min_lr: 1e-6, decay_rate: 1, warmup_lr: 1e-5, warmup_epochs: 1, cooldown_epochs: 0} 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /UNITER/scripts/download_ve.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | DOWNLOAD=$1 5 | 6 | for FOLDER in 'img_db' 'txt_db' 'pretrained' 'finetune'; do 7 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 8 | mkdir -p $DOWNLOAD/$FOLDER 9 | fi 10 | done 11 | 12 | BLOB='https://acvrpublicycchen.blob.core.windows.net/uniter' 13 | 14 | # image db 15 | if [ ! -d $DOWNLOAD/img_db/flickr30k ] ; then 16 | wget $BLOB/img_db/flickr30k.tar -P $DOWNLOAD/img_db/ 17 | tar -xvf $DOWNLOAD/img_db/flickr30k.tar -C $DOWNLOAD/img_db 18 | fi 19 | 20 | # text dbs 21 | for SPLIT in 'train' 'dev' 'test'; do 22 | wget $BLOB/txt_db/ve_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 23 | tar -xvf $DOWNLOAD/txt_db/ve_$SPLIT.db.tar -C $DOWNLOAD/txt_db 24 | done 25 | 26 | if [ ! 
-f $DOWNLOAD/pretrained/uniter-base.pt ] ; then 27 | wget $BLOB/pretrained/uniter-base.pt -P $DOWNLOAD/pretrained/ 28 | fi 29 | 30 | -------------------------------------------------------------------------------- /UNITER/scripts/eval_refgta.sh: -------------------------------------------------------------------------------- 1 | OUT_DIR=$1 2 | ROOT_DIR=/PATH/TO/DIRECTORY/CONTAINING/FEATURE/FILES/ 3 | python inf_re.py \ 4 | --txt_db /PATH/TO/refgta_val.jsonl \ 5 | --img_db $ROOT_DIR/refgta_val_gt_boxes10100.pt \ 6 | --output_dir $OUT_DIR \ 7 | --checkpoint best \ 8 | --tmp_file re_exp/tmp_refcocog.txt \ 9 | --simple_format --n_workers 1 --batch_size 128 10 | 11 | python inf_re.py \ 12 | --txt_db /PATH/TO/refgta_val.jsonl \ 13 | --img_db $ROOT_DIR/refgta_val_unidet_dt_boxes10100.pt \ 14 | --output_dir $OUT_DIR \ 15 | --checkpoint best \ 16 | --tmp_file re_exp/tmp_refcocog.txt \ 17 | --simple_format --n_workers 1 --batch_size 128 18 | 19 | python inf_re.py \ 20 | --txt_db /PATH/TO/refgta_val.jsonl \ 21 | --img_db $ROOT_DIR/refgta_val_unidet_all_dt_boxes10100.pt \ 22 | --output_dir $OUT_DIR \ 23 | --checkpoint best \ 24 | --tmp_file re_exp/tmp_refcocog.txt \ 25 | --simple_format --n_workers 1 --batch_size 128 26 | -------------------------------------------------------------------------------- /UNITER/scripts/download_re.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | DOWNLOAD=$1 5 | 6 | for FOLDER in 'img_db' 'txt_db' 'pretrained' 'finetune'; do 7 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 8 | mkdir -p $DOWNLOAD/$FOLDER 9 | fi 10 | done 11 | 12 | BLOB='https://acvrpublicycchen.blob.core.windows.net/uniter' 13 | 14 | # image db 15 | if [ ! -d $DOWNLOAD/img_db/re_coco_gt ] ; then 16 | wget $BLOB/img_db/re_coco_gt.tar -P $DOWNLOAD/img_db/ 17 | tar -xvf $DOWNLOAD/img_db/re_coco_gt.tar -C $DOWNLOAD/img_db 18 | fi 19 | if [ ! -d $DOWNLOAD/img_db/re_coco_det ] ; then 20 | wget $BLOB/img_db/re_coco_det.tar -P $DOWNLOAD/img_db/ 21 | tar -xvf $DOWNLOAD/img_db/re_coco_det.tar -C $DOWNLOAD/img_db 22 | fi 23 | 24 | # text dbs 25 | wget $BLOB/txt_db/re_txt_db.tar -P $DOWNLOAD/txt_db/ 26 | tar -xvf $DOWNLOAD/txt_db/re_txt_db.tar -C $DOWNLOAD/txt_db/ 27 | 28 | if [ ! 
-f $DOWNLOAD/pretrained/uniter-base.pt ] ; then 29 | wget $BLOB/pretrained/uniter-base.pt -P $DOWNLOAD/pretrained/ 30 | fi 31 | 32 | -------------------------------------------------------------------------------- /methods/random_method.py: -------------------------------------------------------------------------------- 1 | """A naive baseline method: just pass the full expression to CLIP.""" 2 | 3 | from overrides import overrides 4 | from typing import Dict, Any 5 | import random 6 | from argparse import Namespace 7 | 8 | import numpy as np 9 | 10 | from .ref_method import RefMethod 11 | 12 | 13 | class Random(RefMethod): 14 | """CLIP-only baseline where each box is evaluated with the full expression.""" 15 | 16 | def __init__(self, args: Namespace): 17 | self.box_area_threshold = args.box_area_threshold 18 | 19 | @overrides 20 | def execute(self, caption: str, env: "Environment") -> Dict[str, Any]: 21 | probs = env.filter_area(self.box_area_threshold)*env.uniform() 22 | random_ordering = list(range(len(env.boxes))) 23 | random.shuffle(random_ordering) 24 | random_ordering = np.array(random_ordering) 25 | pred = np.argmax(probs*random_ordering) 26 | return { 27 | "probs": probs.tolist(), 28 | "pred": int(pred), 29 | "text": caption.lower() 30 | } 31 | -------------------------------------------------------------------------------- /UNITER/config/train-refcocog-large-1gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_txt_db": "/PATH/TO/refcocog_train.jsonl", 3 | "train_img_db": "/PATH/TO/refcocog_train_gt_boxes10100.pt", 4 | "val_txt_db": "/PATH/TO/refcocog_val.jsonl", 5 | "val_img_db": "/PATH/TO/refcocog_val_det_boxes10100.pt", 6 | "compressed_db": false, 7 | "model_config": "config/uniter-large.json", 8 | "checkpoint": "downloads/pretrained/uniter-large.pt", 9 | "mask_dir": "/pretrain/philly_output/pretrain/lottery-cont20k-8gpu-fp32-2/ckpt/mask_6.pt", 10 | "output_dir": "/PATH/TO/OUTPUT", 11 | "max_txt_len": 60, 12 | "conf_th": -1, 13 | "num_bb": 100, 14 | "train_batch_size": 128, 15 | "val_batch_size": 128, 16 | "learning_rate": 1e-4, 17 | "lr_mul": 1.0, 18 | "optim": "adamw", 19 | "betas": [0.9, 0.98], 20 | "weight_decay": 0.01, 21 | "dropout": 0.1, 22 | "grad_norm": 2.0, 23 | "decay": "linear", 24 | "num_train_steps": 24000, 25 | "warmup_steps": 1500, 26 | "gradient_accumulation_steps": 1, 27 | "seed": 24, 28 | "fp16": true, 29 | "n_workers": 4, 30 | "pin_mem": true 31 | } 32 | -------------------------------------------------------------------------------- /UNITER/scripts/download_vqa.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | DOWNLOAD=$1 5 | 6 | for FOLDER in 'img_db' 'txt_db' 'pretrained' 'finetune'; do 7 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 8 | mkdir -p $DOWNLOAD/$FOLDER 9 | fi 10 | done 11 | 12 | BLOB='https://acvrpublicycchen.blob.core.windows.net/uniter' 13 | 14 | # image dbs 15 | for SPLIT in 'train2014' 'val2014' 'test2015'; do 16 | if [ ! 
-d $DOWNLOAD/img_db/coco_$SPLIT ] ; then 17 | wget $BLOB/img_db/coco_$SPLIT.tar -P $DOWNLOAD/img_db/ 18 | tar -xvf $DOWNLOAD/img_db/coco_$SPLIT.tar -C $DOWNLOAD/img_db 19 | fi 20 | done 21 | wget $BLOB/img_db/vg.tar -P $DOWNLOAD/img_db/ 22 | tar -xvf $DOWNLOAD/img_db/vg.tar -C $DOWNLOAD/img_db 23 | 24 | # text dbs 25 | for SPLIT in 'train' 'trainval' 'devval' 'test' 'vg'; do 26 | wget $BLOB/txt_db/vqa_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 27 | tar -xvf $DOWNLOAD/txt_db/vqa_$SPLIT.db.tar -C $DOWNLOAD/txt_db 28 | done 29 | 30 | if [ ! -f $DOWNLOAD/pretrained/uniter-base.pt ] ; then 31 | wget $BLOB/pretrained/uniter-base.pt -P $DOWNLOAD/pretrained/ 32 | fi 33 | 34 | -------------------------------------------------------------------------------- /UNITER/config/train-refcoco+-large-1gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_txt_db": "/PATH/TO/refcoco+_train.jsonl", 3 | "train_img_db": "/PATH/TO/refcoco+_train_gt_boxes10100.pt", 4 | "val_txt_db": "/PATH/TO/refcoco+_val.jsonl", 5 | "val_img_db": "/PATH/TO/refcoco+_val_dt_boxes10100.pt", 6 | "compressed_db": false, 7 | "model_config": "config/uniter-large.json", 8 | "checkpoint": "downloads/pretrained/uniter-large.pt", 9 | "mask_dir": "/pretrain/philly_output/pretrain/lottery-cont20k-8gpu-fp32-2/ckpt/mask_6.pt", 10 | "output_dir": "/PATH/TO/OUTPUT", 11 | "max_txt_len": 60, 12 | "conf_th": -1, 13 | "num_bb": 100, 14 | "train_batch_size": 128, 15 | "val_batch_size": 128, 16 | "learning_rate": 6e-5, 17 | "lr_mul": 1.0, 18 | "optim": "adamw", 19 | "betas": [0.9, 0.98], 20 | "weight_decay": 0.01, 21 | "dropout": 0.1, 22 | "grad_norm": 2.0, 23 | "decay": "linear", 24 | "num_train_steps": 10000, 25 | "warmup_steps": 1000, 26 | "gradient_accumulation_steps": 1, 27 | "mlp": 2, 28 | "seed": 24, 29 | "fp16": true, 30 | "n_workers": 4, 31 | "pin_mem": true 32 | } 33 | -------------------------------------------------------------------------------- /UNITER/UNITER_LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Microsoft Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /UNITER/optim/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Misc lr helper 6 | """ 7 | from torch.optim import Adam, Adamax 8 | 9 | from .adamw import AdamW 10 | 11 | 12 | def build_optimizer(model, opts): 13 | param_optimizer = list(model.named_parameters()) 14 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 15 | optimizer_grouped_parameters = [ 16 | {'params': [p for n, p in param_optimizer 17 | if not any(nd in n for nd in no_decay)], 18 | 'weight_decay': opts.weight_decay}, 19 | {'params': [p for n, p in param_optimizer 20 | if any(nd in n for nd in no_decay)], 21 | 'weight_decay': 0.0} 22 | ] 23 | 24 | # currently Adam only 25 | if opts.optim == 'adam': 26 | OptimCls = Adam 27 | elif opts.optim == 'adamax': 28 | OptimCls = Adamax 29 | elif opts.optim == 'adamw': 30 | OptimCls = AdamW 31 | else: 32 | raise ValueError('invalid optimizer') 33 | optimizer = OptimCls(optimizer_grouped_parameters, 34 | lr=opts.learning_rate, betas=opts.betas) 35 | return optimizer 36 | -------------------------------------------------------------------------------- /UNITER/requirements.txt: -------------------------------------------------------------------------------- 1 | # apex==0.1 2 | backcall==0.2.0 3 | boto3 4 | botocore 5 | brotlipy==0.7.0 6 | # certifi==2021.5.30 7 | cffi 8 | charset-normalizer==2.0.10 9 | cloudpickle==2.0.0 10 | cryptography 11 | cytoolz==0.11.2 12 | dataclasses 13 | decorator==5.1.1 14 | future==0.18.2 15 | fvcore==0.1.3.post20210317 16 | # horovod==0.24.2 17 | idna 18 | iopath==0.1.9 19 | ipdb==0.13.9 20 | ipython==7.16.3 21 | ipython-genutils==0.2.0 22 | jedi==0.17.2 23 | jmespath 24 | lmdb==0.97 25 | mkl-fft 26 | mkl-random 27 | # mkl-service 28 | msgpack==1.0.3 29 | msgpack-numpy==0.4.7.1 30 | numpy 31 | olefile 32 | parso==0.7.1 33 | pexpect==4.8.0 34 | pickleshare==0.7.5 35 | Pillow==6.2.1 36 | prompt-toolkit==3.0.26 37 | protobuf==3.19.4 38 | psutil==5.9.0 39 | ptyprocess==0.7.0 40 | pycocotools==2.0.4 41 | pycparser 42 | pydot==1.4.2 43 | Pygments==2.11.2 44 | pyOpenSSL 45 | PySocks 46 | python-dateutil 47 | pytorch-pretrained-bert==0.6.2 48 | PyYAML==6.0 49 | regex 50 | requests 51 | s3transfer 52 | six 53 | tensorboard==2.8.0 54 | tensorboardX==1.7 55 | toolz==0.11.2 56 | # torch>=1.10 57 | # torchvision>=0.11 58 | tqdm 59 | traitlets==4.3.3 60 | typing_extensions 61 | urllib3==1.26.7 62 | wcwidth==0.2.5 63 | -------------------------------------------------------------------------------- /UNITER/scripts/download_vcr.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | DOWNLOAD=$1 5 | 6 | for FOLDER in 'img_db' 'txt_db' 'pretrained' 'finetune'; do 7 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 8 | mkdir -p $DOWNLOAD/$FOLDER 9 | fi 10 | done 11 | 12 | BLOB='https://acvrpublicycchen.blob.core.windows.net/uniter' 13 | 14 | # image dbs 15 | for SPLIT in 'train' 'val' 'test' 'gt_train' 'gt_val' 'gt_test'; do 16 | if [ ! -d $DOWNLOAD/img_db/vcr_$SPLIT ] ; then 17 | wget $BLOB/img_db/vcr_$SPLIT.tar -P $DOWNLOAD/img_db/ 18 | tar -xvf $DOWNLOAD/img_db/vcr_$SPLIT.tar -C $DOWNLOAD/img_db 19 | fi 20 | done 21 | 22 | # text dbs 23 | for SPLIT in 'train' 'val' 'test'; do 24 | wget $BLOB/txt_db/vcr_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 25 | tar -xvf $DOWNLOAD/txt_db/vcr_$SPLIT.db.tar -C $DOWNLOAD/txt_db 26 | done 27 | 28 | if [ ! -f $DOWNLOAD/pretrained/uniter-large-vcr_2nd_stage.pt ] ; then 29 | wget $BLOB/pretrained/uniter-large-vcr_2nd_stage.pt -P $DOWNLOAD/pretrained/ 30 | fi 31 | 32 | if [ ! 
-f $DOWNLOAD/pretrained/uniter-base-vcr_2nd_stage.pt ] ; then 33 | wget $BLOB/pretrained/uniter-base-vcr_2nd_stage.pt -P $DOWNLOAD/pretrained/ 34 | fi 35 | -------------------------------------------------------------------------------- /UniDet/extract_boxes.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | PERSON_LABEL = 134 7 | CITYSCAPES_LABELS = {13, 125, 134, 136, 146, 151, 157, 169, 181, 219, 230, 231, 233, 246, 285, 334, 337, 406, 452, 471, 480, 596, 607, 695} 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--input_path") 11 | parser.add_argument("--output_path") 12 | args = parser.parse_args() 13 | 14 | f = open(args.input_path) 15 | lines = f.readlines() 16 | data = [json.loads(line) for line in lines] 17 | detections = {} 18 | for datum in tqdm(data): 19 | parts = datum['file_name'].split('/') 20 | file_name = parts[0]+'/'+parts[-1].split('.')[0]+'.pt' 21 | print(file_name) 22 | instances = torch.load(file_name, map_location='cpu') 23 | indices = [i for i in range(len(instances['instances'].pred_classes)) if instances['instances'].pred_classes[i].item() == PERSON_LABEL] 24 | # indices = list(range(len(instances['instances'].pred_classes))) 25 | detections[datum['image_id']] = {"boxes": instances['instances'].pred_boxes.tensor[indices,:].tolist(), "scores": instances['instances'].scores[indices].tolist()} 26 | fout = open(args.output_path, 'w') 27 | json.dump(detections, fout) 28 | fout.close() 29 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==21.2.0 2 | blis==0.7.4 3 | catalogue==2.0.4 4 | certifi==2021.5.30 5 | chardet==4.0.0 6 | click==7.1.2 7 | clip @ git+https://github.com/openai/CLIP.git 8 | cymem==2.0.5 9 | en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0-py3-none-any.whl 10 | filelock==3.0.12 11 | ftfy==6.0.3 12 | huggingface-hub==0.0.12 13 | idna==2.10 14 | iniconfig==1.1.1 15 | itsdangerous==2.0.1 16 | joblib==1.0.1 17 | MarkupSafe==2.0.1 18 | murmurhash==1.0.5 19 | numpy==1.21.0 20 | overrides==6.1.0 21 | packaging==21.0 22 | pathy==0.6.0 23 | Pillow==8.2.0 24 | pluggy==0.13.1 25 | preshed==3.0.5 26 | py==1.10.0 27 | pydantic==1.7.4 28 | pyparsing==2.4.7 29 | pytest==6.2.4 30 | PyYAML==5.4.1 31 | regex==2021.7.6 32 | requests==2.25.1 33 | ruamel.yaml==0.17.10 34 | ruamel.yaml.clib==0.2.6 35 | sacremoses==0.0.45 36 | scipy==1.7.0 37 | six==1.16.0 38 | smart-open==5.1.0 39 | spacy==3.0.6 40 | spacy-legacy==3.0.7 41 | srsly==2.4.1 42 | thinc==8.0.7 43 | timm==0.4.12 44 | tokenizers==0.10.3 45 | toml==0.10.2 46 | tqdm==4.61.2 47 | transformers==4.9.0 48 | typer==0.3.2 49 | typing-extensions==3.10.0.0 50 | typing-utils==0.1.0 51 | urllib3==1.26.6 52 | wasabi==0.8.2 53 | wcwidth==0.2.5 54 | Werkzeug==2.0.1 55 | -------------------------------------------------------------------------------- /UNITER/scripts/download_nlvr2.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | DOWNLOAD=$1 5 | 6 | for FOLDER in 'ann' 'img_db' 'txt_db' 'pretrained' 'finetune'; do 7 | if [ ! 
-d $DOWNLOAD/$FOLDER ] ; then 8 | mkdir -p $DOWNLOAD/$FOLDER 9 | fi 10 | done 11 | 12 | BLOB='https://acvrpublicycchen.blob.core.windows.net/uniter' 13 | 14 | # annotations 15 | NLVR='https://raw.githubusercontent.com/lil-lab/nlvr/master/nlvr2/data' 16 | wget $NLVR/dev.json -P $DOWNLOAD/ann/ 17 | wget $NLVR/test1.json -P $DOWNLOAD/ann/ 18 | 19 | # image dbs 20 | for SPLIT in 'train' 'dev' 'test'; do 21 | wget $BLOB/img_db/nlvr2_$SPLIT.tar -P $DOWNLOAD/img_db/ 22 | tar -xvf $DOWNLOAD/img_db/nlvr2_$SPLIT.tar -C $DOWNLOAD/img_db 23 | done 24 | 25 | # text dbs 26 | for SPLIT in 'train' 'dev' 'test1'; do 27 | wget $BLOB/txt_db/nlvr2_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 28 | tar -xvf $DOWNLOAD/txt_db/nlvr2_$SPLIT.db.tar -C $DOWNLOAD/txt_db 29 | done 30 | 31 | if [ ! -f $DOWNLOAD/pretrained/uniter-base.pt ] ; then 32 | wget $BLOB/pretrained/uniter-base.pt -P $DOWNLOAD/pretrained/ 33 | fi 34 | 35 | wget $BLOB/finetune/nlvr-base.tar -P $DOWNLOAD/finetune/ 36 | tar -xvf $DOWNLOAD/finetune/nlvr-base.tar -C $DOWNLOAD/finetune 37 | -------------------------------------------------------------------------------- /UNITER/scripts/create_txtdb.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | OUT_DIR=$1 5 | ANN_DIR=$2 6 | 7 | set -e 8 | 9 | URL='https://raw.githubusercontent.com/lil-lab/nlvr/master/nlvr2/data' 10 | if [ ! -d $OUT_DIR ]; then 11 | mkdir -p $OUT_DIR 12 | fi 13 | if [ ! -d $ANN_DIR ]; then 14 | mkdir -p $ANN_DIR 15 | fi 16 | 17 | BLOB='https://acvrpublicycchen.blob.core.windows.net/uniter' 18 | MISSING=$BLOB/ann/missing_nlvr2_imgs.json 19 | if [ ! -f $ANN_DIR/missing_nlvr2_imgs.json ]; then 20 | wget $MISSING -O $ANN_DIR/missing_nlvr2_imgs.json 21 | fi 22 | 23 | for SPLIT in 'train' 'dev' 'test1'; do 24 | if [ ! -f $ANN_DIR/$SPLIT.json ]; then 25 | echo "downloading ${SPLIT} annotations..." 26 | wget $URL/$SPLIT.json -O $ANN_DIR/$SPLIT.json 27 | fi 28 | 29 | echo "preprocessing ${SPLIT} annotations..." 30 | docker run --ipc=host --rm -it \ 31 | --mount src=$(pwd),dst=/src,type=bind \ 32 | --mount src=$OUT_DIR,dst=/txt_db,type=bind \ 33 | --mount src=$ANN_DIR,dst=/ann,type=bind,readonly \ 34 | -w /src chenrocks/uniter \ 35 | python prepro.py --annotation /ann/$SPLIT.json \ 36 | --missing_imgs /ann/missing_nlvr2_imgs.json \ 37 | --output /txt_db/nlvr2_${SPLIT}.db --task nlvr2 38 | done 39 | 40 | echo "done" 41 | -------------------------------------------------------------------------------- /UNITER/scripts/download_itm.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | DOWNLOAD=$1 5 | 6 | for FOLDER in 'img_db' 'txt_db' 'pretrained' 'finetune'; do 7 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 8 | mkdir -p $DOWNLOAD/$FOLDER 9 | fi 10 | done 11 | 12 | BLOB='https://acvrpublicycchen.blob.core.windows.net/uniter' 13 | 14 | # image dbs 15 | for SPLIT in 'train2014' 'val2014'; do 16 | if [ ! -d $DOWNLOAD/img_db/coco_$SPLIT ] ; then 17 | wget $BLOB/img_db/coco_$SPLIT.tar -P $DOWNLOAD/img_db/ 18 | tar -xvf $DOWNLOAD/img_db/coco_$SPLIT.tar -C $DOWNLOAD/img_db 19 | fi 20 | done 21 | if [ ! 
-d $DOWNLOAD/img_db/flickr30k ] ; then 22 | wget $BLOB/img_db/flickr30k.tar -P $DOWNLOAD/img_db/ 23 | tar -xvf $DOWNLOAD/img_db/flickr30k.tar -C $DOWNLOAD/img_db 24 | fi 25 | 26 | # text dbs 27 | for SPLIT in 'train' 'restval' 'val' 'test'; do 28 | wget $BLOB/txt_db/itm_coco_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 29 | tar -xvf $DOWNLOAD/txt_db/itm_coco_$SPLIT.db.tar -C $DOWNLOAD/txt_db 30 | done 31 | for SPLIT in 'train' 'val' 'test'; do 32 | wget $BLOB/txt_db/itm_flickr30k_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 33 | tar -xvf $DOWNLOAD/txt_db/itm_flickr30k_$SPLIT.db.tar -C $DOWNLOAD/txt_db 34 | done 35 | 36 | if [ ! -f $DOWNLOAD/pretrained/uniter-base.pt ] ; then 37 | wget $BLOB/pretrained/uniter-base.pt -P $DOWNLOAD/pretrained/ 38 | fi 39 | 40 | -------------------------------------------------------------------------------- /UNITER/optim/sched.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | optimizer learning rate scheduling helpers 6 | """ 7 | from math import ceil 8 | 9 | 10 | def noam_schedule(step, warmup_step=4000): 11 | """ original Transformer schedule""" 12 | if step <= warmup_step: 13 | return step / warmup_step 14 | return (warmup_step ** 0.5) * (step ** -0.5) 15 | 16 | 17 | def warmup_linear(step, warmup_step, tot_step): 18 | """ BERT schedule """ 19 | if step < warmup_step: 20 | return step / warmup_step 21 | return max(0, (tot_step-step)/(tot_step-warmup_step)) 22 | 23 | 24 | def vqa_schedule(step, warmup_interval, decay_interval, 25 | decay_start, decay_rate): 26 | """ VQA schedule from MCAN """ 27 | if step < warmup_interval: 28 | return 1/4 29 | elif step < 2 * warmup_interval: 30 | return 2/4 31 | elif step < 3 * warmup_interval: 32 | return 3/4 33 | elif step >= decay_start: 34 | num_decay = ceil((step - decay_start) / decay_interval) 35 | return decay_rate ** num_decay 36 | else: 37 | return 1 38 | 39 | 40 | def get_lr_sched(global_step, opts): 41 | # learning rate scheduling 42 | lr_this_step = opts.learning_rate * warmup_linear( 43 | global_step, opts.warmup_steps, opts.num_train_steps) 44 | if lr_this_step <= 0: 45 | lr_this_step = 1e-8 46 | return lr_this_step 47 | -------------------------------------------------------------------------------- /UNITER/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:19.05-py3 2 | 3 | # basic python packages 4 | RUN pip install pytorch-pretrained-bert==0.6.2 \ 5 | tensorboardX==1.7 ipdb==0.12 lmdb==0.97 6 | 7 | ####### horovod for multi-GPU (distributed) training ####### 8 | 9 | # update OpenMPI to avoid horovod bug 10 | RUN rm -r /usr/local/mpi &&\ 11 | wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.4.tar.gz &&\ 12 | gunzip -c openmpi-3.1.4.tar.gz | tar xf - &&\ 13 | cd openmpi-3.1.4 &&\ 14 | ./configure --prefix=/usr/local/mpi --enable-orterun-prefix-by-default \ 15 | --with-verbs --disable-getpwuid &&\ 16 | make -j$(nproc) all && make install &&\ 17 | ldconfig &&\ 18 | cd - && rm -r openmpi-3.1.4 && rm openmpi-3.1.4.tar.gz 19 | 20 | ENV OPENMPI_VERSION=3.1.4 21 | 22 | # horovod 23 | RUN HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITH_PYTORCH=1 \ 24 | pip install --no-cache-dir horovod==0.16.4 &&\ 25 | ldconfig 26 | 27 | # ssh 28 | RUN apt-get update &&\ 29 | apt-get install -y --no-install-recommends openssh-client openssh-server &&\ 30 | mkdir -p /var/run/sshd 31 | 32 | # Allow OpenSSH to talk 
to containers without asking for confirmation 33 | RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ 34 | echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ 35 | mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config 36 | 37 | 38 | WORKDIR /src 39 | -------------------------------------------------------------------------------- /UNITER/scripts/download_indomain.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | DOWNLOAD=$1 5 | 6 | for FOLDER in 'img_db' 'txt_db' 'pretrained' 'finetune'; do 7 | if [ ! -d $DOWNLOAD/$FOLDER ] ; then 8 | mkdir -p $DOWNLOAD/$FOLDER 9 | fi 10 | done 11 | 12 | BLOB='https://acvrpublicycchen.blob.core.windows.net/uniter' 13 | 14 | # image dbs 15 | for SPLIT in 'train2014' 'val2014'; do 16 | if [ ! -d $DOWNLOAD/img_db/coco_$SPLIT ] ; then 17 | wget $BLOB/img_db/coco_$SPLIT.tar -P $DOWNLOAD/img_db/ 18 | tar -xvf $DOWNLOAD/img_db/coco_$SPLIT.tar -C $DOWNLOAD/img_db 19 | fi 20 | done 21 | if [ ! -d $DOWNLOAD/img_db/vg ] ; then 22 | wget $BLOB/img_db/vg.tar -P $DOWNLOAD/img_db/ 23 | tar -xvf $DOWNLOAD/img_db/vg.tar -C $DOWNLOAD/img_db 24 | fi 25 | 26 | # text dbs 27 | for SPLIT in 'train' 'restval' 'val'; do 28 | wget $BLOB/txt_db/pretrain_coco_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 29 | tar -xvf $DOWNLOAD/txt_db/pretrain_coco_$SPLIT.db.tar -C $DOWNLOAD/txt_db 30 | done 31 | for SPLIT in 'train' 'val'; do 32 | wget $BLOB/txt_db/pretrain_vg_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ 33 | tar -xvf $DOWNLOAD/txt_db/pretrain_vg_$SPLIT.db.tar -C $DOWNLOAD/txt_db 34 | done 35 | 36 | # converted BERT 37 | for MODEL in base large; do 38 | if [ ! -f $DOWNLOAD/pretrained/uniter-$MODEL-init.pt ] ; then 39 | wget $BLOB/pretrained/uniter-$MODEL-init.pt -P $DOWNLOAD/pretrained/ 40 | fi 41 | done 42 | 43 | -------------------------------------------------------------------------------- /albef/albef_license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, Salesforce.com, Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /UNITER/scripts/create_txtdb_re.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | OUT_DIR=$1 5 | ANN_DIR=$2 # where refercoco annotations are saved 6 | 7 | set -e 8 | 9 | if [ ! -f $ANN_DIR/iid2bb_id/iid_to_ann_ids.json ]; then 10 | echo "pre-compute iid_to_ann_ids.json for all RE datasets following https://github.com/lichengunc/MAttNet/blob/butd_feats/tools/map_iid_to_ann_ids.py ..." 11 | exit 12 | fi 13 | 14 | for DATA in 'refcoco' 'refcoco+'; do 15 | for SPLIT in 'train' 'val' 'testA' 'testB'; do 16 | echo "preprocessing ${DATA} ${SPLIT} annotations..." 17 | docker run --ipc=host --rm -it \ 18 | --mount src=$(pwd),dst=/src,type=bind \ 19 | --mount src=$OUT_DIR,dst=/txt_db,type=bind \ 20 | --mount src=$ANN_DIR,dst=/ann,type=bind,readonly \ 21 | -w /src chenrocks/uniter \ 22 | python prepro.py --annotation /ann/$DATA/refs\(unc\).p /ann/$DATA/instances.json /ann/iid2bb_id/iid_to_ann_ids.json \ 23 | --task re \ 24 | --output /txt_db/${DATA}_${SPLIT}.db 25 | done 26 | done 27 | 28 | DATA='refcocog' 29 | for SPLIT in 'train' 'val' 'test'; do 30 | echo "preprocessing ${DATA} ${SPLIT} annotations..." 31 | docker run --ipc=host --rm -it \ 32 | --mount src=$(pwd),dst=/src,type=bind \ 33 | --mount src=$OUT_DIR,dst=/txt_db,type=bind \ 34 | --mount src=$ANN_DIR,dst=/ann,type=bind,readonly \ 35 | -w /src chenrocks/uniter \ 36 | python prepro.py --annotation /ann/$DATA/refs\(umd\).p /ann/$DATA/instances.json /ann/iid2bb_id/iid_to_ann_ids.json \ 37 | --task re \ 38 | --output /txt_db/${DATA}_${SPLIT}.db 39 | done 40 | 41 | echo "done" 42 | -------------------------------------------------------------------------------- /UNITER/utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Misc utilities 6 | """ 7 | import json 8 | import random 9 | import sys 10 | 11 | import torch 12 | import numpy as np 13 | 14 | from utils.logger import LOGGER 15 | 16 | 17 | class NoOp(object): 18 | """ useful for distributed training No-Ops """ 19 | def __getattr__(self, name): 20 | return self.noop 21 | 22 | def noop(self, *args, **kwargs): 23 | return 24 | 25 | 26 | def parse_with_config(parser): 27 | args = parser.parse_args() 28 | if args.config is not None: 29 | config_args = json.load(open(args.config)) 30 | override_keys = {arg[2:].split('=')[0] for arg in sys.argv[1:] 31 | if arg.startswith('--')} 32 | for k, v in config_args.items(): 33 | if k not in override_keys: 34 | setattr(args, k, v) 35 | del args.config 36 | return args 37 | 38 | 39 | VE_ENT2IDX = { 40 | 'contradiction': 0, 41 | 'entailment': 1, 42 | 'neutral': 2 43 | } 44 | 45 | VE_IDX2ENT = { 46 | 0: 'contradiction', 47 | 1: 'entailment', 48 | 2: 'neutral' 49 | } 50 | 51 | 52 | class Struct(object): 53 | def __init__(self, dict_): 54 | self.__dict__.update(dict_) 55 | 56 | 57 | def set_dropout(model, drop_p): 58 | for name, module in model.named_modules(): 59 | # we might want to tune dropout for smaller dataset 60 | if isinstance(module, torch.nn.Dropout): 61 | if module.p != drop_p: 62 | module.p = drop_p 63 | LOGGER.info(f'{name} set to {drop_p}') 64 | 65 | 66 | def set_random_seed(seed): 67 | random.seed(seed) 68 | np.random.seed(seed) 69 | torch.manual_seed(seed) 70 | torch.cuda.manual_seed_all(seed) 71 | -------------------------------------------------------------------------------- /UNITER/model/vqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Uniter for VQA model 6 | """ 7 | from collections import defaultdict 8 | 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm 12 | 13 | from .layer import GELU 14 | from .model import UniterPreTrainedModel, UniterModel 15 | 16 | 17 | class UniterForVisualQuestionAnswering(UniterPreTrainedModel): 18 | """ Finetune UNITER for VQA 19 | """ 20 | def __init__(self, config, img_dim, num_answer): 21 | super().__init__(config) 22 | self.uniter = UniterModel(config, img_dim) 23 | self.vqa_output = nn.Sequential( 24 | nn.Linear(config.hidden_size, config.hidden_size*2), 25 | GELU(), 26 | LayerNorm(config.hidden_size*2, eps=1e-12), 27 | nn.Linear(config.hidden_size*2, num_answer) 28 | ) 29 | self.apply(self.init_weights) 30 | 31 | def forward(self, batch, compute_loss=True): 32 | batch = defaultdict(lambda: None, batch) 33 | input_ids = batch['input_ids'] 34 | position_ids = batch['position_ids'] 35 | img_feat = batch['img_feat'] 36 | img_pos_feat = batch['img_pos_feat'] 37 | attn_masks = batch['attn_masks'] 38 | gather_index = batch['gather_index'] 39 | sequence_output = self.uniter(input_ids, position_ids, 40 | img_feat, img_pos_feat, 41 | attn_masks, gather_index, 42 | output_all_encoded_layers=False) 43 | pooled_output = self.uniter.pooler(sequence_output) 44 | answer_scores = self.vqa_output(pooled_output) 45 | 46 | if compute_loss: 47 | targets = batch['targets'] 48 | vqa_loss = F.binary_cross_entropy_with_logits( 49 | answer_scores, targets, reduction='none') 50 | return vqa_loss 51 | else: 52 | return answer_scores 53 | -------------------------------------------------------------------------------- /lattice.py: -------------------------------------------------------------------------------- 1 | """Implement lattice interface.""" 2 | 3 | from overrides import overrides 4 | import numpy as np 5 | from abc import ABCMeta, abstractmethod 6 | 7 | 8 | class Lattice(metaclass=ABCMeta): 9 | 10 | """Abstract base class representing a complemented lattice.""" 11 | 12 | @classmethod 13 | @abstractmethod 14 | def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: 15 | return NotImplemented 16 | 17 | @classmethod 18 | @abstractmethod 19 | def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: 20 | return NotImplemented 21 | 22 | @classmethod 23 | @abstractmethod 24 | def join_reduce(cls, probs: np.ndarray) -> np.ndarray: 25 | return NotImplemented 26 | 27 | @classmethod 28 | @abstractmethod 29 | def meet_reduce(cls, probs: np.ndarray) -> np.ndarray: 30 | return NotImplemented 31 | 32 | 33 | class Product(Lattice): 34 | """Lattice where meet=prod and sum is defined accordingly. 35 | 36 | Equivalent to assuming independence, more or less. 
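    For example, with probs1 = [0.5, 0.2] and probs2 = [0.4, 0.5],
    meet gives the elementwise product [0.2, 0.1], while join follows
    inclusion-exclusion (p1 + p2 - p1*p2) and gives [0.7, 0.6].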
37 | """ 38 | 39 | eps = 1e-9 40 | 41 | @classmethod 42 | @overrides 43 | def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: 44 | return probs1 + probs2 - cls.meet(probs1, probs2) 45 | 46 | @classmethod 47 | @overrides 48 | def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: 49 | return probs1 * probs2 50 | 51 | @classmethod 52 | @overrides 53 | def join_reduce(cls, probs: np.ndarray) -> np.ndarray: 54 | """Assumes disjoint events.""" 55 | # return cls.comp(cls.meet_reduce(cls.comp(probs))) 56 | return np.sum(probs, axis=-1) 57 | 58 | @classmethod 59 | @overrides 60 | def meet_reduce(cls, probs: np.ndarray) -> np.ndarray: 61 | return np.prod(probs, axis=-1) 62 | 63 | @classmethod 64 | def comp(cls, probs): 65 | return 1 - probs 66 | 67 | @classmethod 68 | def normalize(cls, probs): 69 | """Normalize a distribution by dividing by the total mass.""" 70 | return probs / np.sum(probs + cls.eps, axis=-1) 71 | -------------------------------------------------------------------------------- /methods/baseline.py: -------------------------------------------------------------------------------- 1 | """A naive baseline method: just pass the full expression to CLIP.""" 2 | 3 | from overrides import overrides 4 | from typing import Dict, Any, List 5 | import numpy as np 6 | import torch 7 | import spacy 8 | from argparse import Namespace 9 | 10 | from .ref_method import RefMethod 11 | from lattice import Product as L 12 | 13 | 14 | class Baseline(RefMethod): 15 | """CLIP-only baseline where each box is evaluated with the full expression.""" 16 | 17 | nlp = spacy.load('en_core_web_sm') 18 | 19 | def __init__(self, args: Namespace): 20 | self.args = args 21 | self.box_area_threshold = args.box_area_threshold 22 | self.batch_size = args.batch_size 23 | self.batch = [] 24 | 25 | @overrides 26 | def execute(self, caption: str, env: "Environment",caption_bank: List[str]=[], mask_dino=None) -> Dict[str, Any]: 27 | chunk_texts = self.get_chunk_texts(caption) 28 | probs,attn = env.filter(caption, area_threshold = self.box_area_threshold, softmax=True, caption_bank=caption_bank, mask_dino=mask_dino) 29 | if self.args.baseline_head: 30 | probs2 = env.filter(chunk_texts[0], area_threshold = self.box_area_threshold, softmax=True) 31 | probs = L.meet(probs, probs2) 32 | pred = np.argmax(probs) 33 | return { 34 | "probs": probs, 35 | "pred": pred, 36 | "box": env.boxes[pred], 37 | "attn": attn, 38 | #"textattn":textattn 39 | } 40 | 41 | def get_chunk_texts(self, expression: str) -> List: 42 | doc = self.nlp(expression) 43 | head = None 44 | for token in doc: 45 | if token.head.i == token.i: 46 | head = token 47 | break 48 | head_chunk = None 49 | chunk_texts = [] 50 | for chunk in doc.noun_chunks: 51 | if head.i >= chunk.start and head.i < chunk.end: 52 | head_chunk = chunk.text 53 | chunk_texts.append(chunk.text) 54 | if head_chunk is None: 55 | if len(list(doc.noun_chunks)) > 0: 56 | head_chunk = list(doc.noun_chunks)[0].text 57 | else: 58 | head_chunk = expression 59 | return [head_chunk] + [txt for txt in chunk_texts if txt != head_chunk] 60 | -------------------------------------------------------------------------------- /clevr-dataset-gen/bounding_box.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright 2017 Larry Chen 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | ''' 16 | 17 | import numpy as np 18 | 19 | def extract_bounding_boxes(scene, names): 20 | objs = scene['objects'] 21 | rotation = scene['directions']['right'] 22 | 23 | num_boxes = len(objs) 24 | 25 | boxes = np.zeros((1, num_boxes, 4)) 26 | 27 | xmin = [] 28 | ymin = [] 29 | xmax = [] 30 | ymax = [] 31 | classes = [] 32 | classes_text = [] 33 | 34 | for i, obj in enumerate(objs): 35 | [x, y, z] = obj['pixel_coords'] 36 | 37 | [x1, y1, z1] = obj['3d_coords'] 38 | 39 | cos_theta, sin_theta, _ = rotation 40 | 41 | x1 = x1 * cos_theta + y1* sin_theta 42 | y1 = x1 * -sin_theta + y1 * cos_theta 43 | 44 | 45 | height_d = 6.9 * z1 * (15 - y1) / 2.0 46 | height_u = height_d 47 | width_l = height_d 48 | width_r = height_d 49 | 50 | if obj['shape'] == 'cylinder': 51 | d = 9.4 + y1 52 | h = 6.4 53 | s = z1 54 | 55 | height_u *= (s*(h/d + 1)) / ((s*(h/d + 1)) - (s*(h-s)/d)) 56 | height_d = height_u * (h-s+d)/ (h + s + d) 57 | 58 | width_l *= 11/(10 + y1) 59 | width_r = width_l 60 | 61 | if obj['shape'] == 'cube': 62 | height_u *= 1.3 * 10 / (10 + y1) 63 | height_d = height_u 64 | width_l = height_u 65 | width_r = height_u 66 | 67 | obj_name = obj['size'] + ' ' + obj['color'] + ' ' + obj['material'] + ' ' + obj['shape'] 68 | classes_text.append(obj_name.encode('utf8')) 69 | classes.append(names.index(obj_name) + 1) 70 | ymin.append((y - height_d)/320.0) 71 | ymax.append((y + height_u)/320.0) 72 | xmin.append((x - width_l)/480.0) 73 | xmax.append((x + width_r)/480.0) 74 | 75 | return xmin, ymin, xmax, ymax, classes, classes_text 76 | -------------------------------------------------------------------------------- /UNITER/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | helper for logging 6 | NOTE: loggers are global objects use with caution 7 | """ 8 | import logging 9 | import math 10 | 11 | import tensorboardX 12 | 13 | 14 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 15 | _DATE_FMT = '%m/%d/%Y %H:%M:%S' 16 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO) 17 | LOGGER = logging.getLogger('__main__') # this is the global logger 18 | 19 | 20 | def add_log_to_file(log_path): 21 | fh = logging.FileHandler(log_path) 22 | formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT) 23 | fh.setFormatter(formatter) 24 | LOGGER.addHandler(fh) 25 | 26 | 27 | class TensorboardLogger(object): 28 | def __init__(self): 29 | self._logger = None 30 | self._global_step = 0 31 | 32 | def create(self, path): 33 | self._logger = tensorboardX.SummaryWriter(path) 34 | 35 | def noop(self, *args, **kwargs): 36 | return 37 | 38 | def step(self): 39 | self._global_step += 1 40 | 41 | @property 42 | def global_step(self): 43 | return self._global_step 44 | 45 | def log_scaler_dict(self, log_dict, prefix=''): 46 | """ log a dictionary of scalar values""" 47 | if self._logger is None: 48 | return 49 | if prefix: 50 | prefix = f'{prefix}_' 51 | for name, value in log_dict.items(): 52 | if isinstance(value, dict): 53 | self.log_scaler_dict(value, self._global_step, 54 | prefix=f'{prefix}{name}') 55 | else: 56 | self._logger.add_scalar(f'{prefix}{name}', value, 57 | self._global_step) 58 | 59 | def __getattr__(self, name): 60 | if self._logger is None: 61 | return self.noop 62 | return self._logger.__getattribute__(name) 63 | 64 | 65 | TB_LOGGER = TensorboardLogger() 66 | 67 | 68 | class RunningMeter(object): 69 | """ running meteor of a scalar value 70 | (useful for monitoring training loss) 71 | """ 72 | def __init__(self, name, val=None, smooth=0.99): 73 | self._name = name 74 | self._sm = smooth 75 | self._val = val 76 | 77 | def __call__(self, value): 78 | val = (value if self._val is None 79 | else value*(1-self._sm) + self._val*self._sm) 80 | if not math.isnan(val): 81 | self._val = val 82 | 83 | def __str__(self): 84 | return f'{self._name}: {self._val:.4f}' 85 | 86 | @property 87 | def val(self): 88 | if self._val is None: 89 | return 0 90 | return self._val 91 | 92 | @property 93 | def name(self): 94 | return self._name 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FALIP: Visual Prompt as Foveal Attention Boosts CLIP Zero-Shot Performance 2 | This repository contains the code for the paper [FALIP: Visual Prompt as Foveal Attention Boosts CLIP Zero-Shot Performance](https://arxiv.org/abs/2407.05578) 3 | (ECCV 2024). 4 | 5 | ## Data Download 6 | Download preprocessed data files via `gsutil cp gs://reclip-sanjays/reclip_data.tar.gz`, and extract the data using `tar -xvzf reclip_data.tar.gz`. This data 7 | does not include images. 8 | Download the images for RefCOCO/g/+ from [http://images.cocodataset.org/zips/train2014.zip](http://images.cocodataset.org/zips/train2014.zip). 
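For convenience, the steps above can be scripted as in the sketch below; the download destination (`.`) and the use of `wget`/`unzip` for the COCO images are assumptions rather than steps prescribed by this README:
```
gsutil cp gs://reclip-sanjays/reclip_data.tar.gz .     # preprocessed .jsonl and detection files (no images)
tar -xvzf reclip_data.tar.gz
wget http://images.cocodataset.org/zips/train2014.zip  # RefCOCO/g/+ images
unzip train2014.zip                                    # train2014/ is later passed as --image_root
```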
9 | 10 | ## Results with CLIP 11 | The following format can be used to run experiments: 12 | ``` 13 | pip install -r requirements.txt 14 | python main.py --input_file INPUT_FILE --image_root IMAGE_ROOT --method baseline --box_method_aggregator sum --clip_model ViT-B/16 --box_representation_method box --detector_file PATH_TO_DETECTOR_FILE 15 | ``` 16 | 17 | `--input_file`: should be in `.jsonl` format (we provide these files for the datasets discussed in our paper; see the Data Download information above). 18 | 19 | `--image_root`: the top-level directory containing all images in the dataset. For RefCOCO/g/+, this is the `train2014` directory. 20 | 21 | `--detector_file`: if not specified, ground-truth proposals are used. For RefCOCO/g/+, the detection files are in `reclip_data.tar.gz` and have the format `{refcoco/refcocog/refcoco+}_dets_dict.json`. 22 | 23 | Choices for `method`: "parse" is the full version of ReCLIP that includes isolated proposal scoring and the heuristic-based relation handling system. "baseline" is the version of ReCLIP using only isolated proposal scoring. "gradcam" uses GradCAM, and "random" selects one of the proposals uniformly at random. (default: "baseline") 24 | 25 | Choices for `clip_model`: The choices are the same as the model names used in the CLIP repository except that the model names can be concatenated with a comma between consecutive names. (default: "ViT-B/16") 26 | 27 | Choices for `box_representation_method`: This argument dictates which of the following methods is used to score proposals: CPT-adapted, cropping, blurring, or some combination of these. For CPT-adapted, choose "shade". To use more than one method, concatenate them with a comma between consecutive methods. (default: "box") 28 | 29 | To see explanations of other arguments see the `main.py` file. 30 | 31 | ## Acknowledgements 32 | Thanks to the codebase [ReCLIP](https://github.com/allenai/reclip). 
33 | 34 | ## Citation 35 | If you find this repository useful, please cite our paper: 36 | ``` 37 | @article{zhuang2024falip, 38 | title={FALIP: Visual Prompt as Foveal Attention Boosts CLIP Zero-Shot Performance}, 39 | author={Zhuang, Jiedong and Hu, Jiaqi and Mu, Lianrui and Hu, Rui and Liang, Xiaoyu and Ye, Jiangnan and Hu, Haoji}, 40 | journal={arXiv preprint arXiv:2407.05578}, 41 | year={2024} 42 | } 43 | ``` 44 | -------------------------------------------------------------------------------- /heuristics.py: -------------------------------------------------------------------------------- 1 | """Heuristic rules used to extract and execute entity parses.""" 2 | 3 | from typing import Callable, List, NamedTuple 4 | from argparse import Namespace 5 | import numpy as np 6 | 7 | 8 | class RelHeuristic(NamedTuple): 9 | keywords: List[str] 10 | callback: Callable[["Environment"], np.ndarray] 11 | 12 | 13 | class Heuristics: 14 | """A class defining heuristics that can be enabled/disabled.""" 15 | 16 | RELATIONS = [ 17 | RelHeuristic(["left", "west"], lambda env: env.left_of()), 18 | RelHeuristic(["right", "east"], lambda env: env.right_of()), 19 | RelHeuristic(["above", "north", "top", "back", "behind"], lambda env: env.above()), 20 | RelHeuristic(["below", "south", "under", "front"], lambda env: env.below()), 21 | RelHeuristic(["bigger", "larger", "closer"], lambda env: env.bigger_than()), 22 | RelHeuristic(["smaller", "tinier", "further"], lambda env: env.smaller_than()), 23 | RelHeuristic(["inside", "within", "contained"], lambda env: env.within()), 24 | ] 25 | 26 | TERNARY_RELATIONS = [ 27 | RelHeuristic(["between"], lambda env: env.between()), 28 | ] 29 | 30 | SUPERLATIVES = [ 31 | RelHeuristic(["left", "west", "leftmost", "western"], lambda env: env.left_of()), 32 | RelHeuristic(["right", "rightmost", "east", "eastern"], lambda env: env.right_of()), 33 | RelHeuristic(["above", "north", "top"], lambda env: env.above()), 34 | RelHeuristic(["below", "south", "underneath", "front"], lambda env: env.below()), 35 | RelHeuristic(["bigger", "biggest", "larger", "largest", "closer", "closest"], lambda env: env.bigger_than()), 36 | RelHeuristic(["smaller", "smallest", "tinier", "tiniest", "further", "furthest"], lambda env: env.smaller_than()), 37 | ] 38 | OPPOSITES = {0: 1, 1: 0, 2: 3, 3: 2, 4: 5, 5: 4} 39 | 40 | NULL_KEYWORDS = ["part", "image", "side", "picture", "half", "region", "section"] 41 | 42 | EMPTY = [] 43 | 44 | def __init__(self, args: Namespace = None): 45 | self.enable_relations = not args or not args.no_rel 46 | self.enable_superlatives = not args or not args.no_sup 47 | self.enable_nulls = not args or not args.no_null 48 | self.enable_ternary = not args or args.ternary 49 | 50 | @property 51 | def relations(self) -> List[RelHeuristic]: 52 | return self.RELATIONS if self.enable_relations else self.EMPTY 53 | 54 | @property 55 | def ternary_relations(self) -> List[RelHeuristic]: 56 | return self.TERNARY_RELATIONS if self.enable_ternary else self.EMPTY 57 | 58 | @property 59 | def superlatives(self) -> List[RelHeuristic]: 60 | return self.SUPERLATIVES if self.enable_superlatives else self.EMPTY 61 | 62 | @property 63 | def opposites(self): 64 | return self.OPPOSITES 65 | 66 | @property 67 | def null_keywords(self) -> List[str]: 68 | return self.NULL_KEYWORDS if self.enable_nulls else self.EMPTY 69 | -------------------------------------------------------------------------------- /UNITER/utils/save.py: -------------------------------------------------------------------------------- 
1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | saving utilities 6 | """ 7 | import json 8 | import os 9 | from os.path import abspath, dirname, exists, join 10 | import subprocess 11 | 12 | import torch 13 | 14 | from utils.logger import LOGGER 15 | 16 | 17 | def save_training_meta(args): 18 | if args.rank > 0: 19 | return 20 | 21 | if not exists(args.output_dir): 22 | os.makedirs(join(args.output_dir, 'log')) 23 | os.makedirs(join(args.output_dir, 'ckpt')) 24 | 25 | with open(join(args.output_dir, 'log', 'hps.json'), 'w') as writer: 26 | json.dump(vars(args), writer, indent=4) 27 | model_config = json.load(open(args.model_config)) 28 | with open(join(args.output_dir, 'log', 'model.json'), 'w') as writer: 29 | json.dump(model_config, writer, indent=4) 30 | # git info 31 | try: 32 | LOGGER.info("Waiting on git info....") 33 | c = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"], 34 | timeout=10, stdout=subprocess.PIPE) 35 | git_branch_name = c.stdout.decode().strip() 36 | LOGGER.info("Git branch: %s", git_branch_name) 37 | c = subprocess.run(["git", "rev-parse", "HEAD"], 38 | timeout=10, stdout=subprocess.PIPE) 39 | git_sha = c.stdout.decode().strip() 40 | LOGGER.info("Git SHA: %s", git_sha) 41 | git_dir = abspath(dirname(__file__)) 42 | git_status = subprocess.check_output( 43 | ['git', 'status', '--short'], 44 | cwd=git_dir, universal_newlines=True).strip() 45 | with open(join(args.output_dir, 'log', 'git_info.json'), 46 | 'w') as writer: 47 | json.dump({'branch': git_branch_name, 48 | 'is_dirty': bool(git_status), 49 | 'status': git_status, 50 | 'sha': git_sha}, 51 | writer, indent=4) 52 | except subprocess.TimeoutExpired as e: 53 | LOGGER.exception(e) 54 | LOGGER.warn("Git info not found. Moving right along...") 55 | 56 | 57 | class ModelSaver(object): 58 | def __init__(self, output_dir, prefix='model_step', suffix='pt'): 59 | self.output_dir = output_dir 60 | self.prefix = prefix 61 | self.suffix = suffix 62 | 63 | def save(self, model, step, optimizer=None): 64 | output_model_file = join(self.output_dir, 65 | f"{self.prefix}_{step}.{self.suffix}") 66 | state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v 67 | for k, v in model.state_dict().items()} 68 | torch.save(state_dict, output_model_file) 69 | if optimizer is not None: 70 | dump = {'step': step, 'optimizer': optimizer.state_dict()} 71 | if hasattr(optimizer, '_amp_stash'): 72 | pass # TODO fp16 optimizer 73 | torch.save(dump, f'{self.output_dir}/train_state_{step}.pt') 74 | -------------------------------------------------------------------------------- /UNITER/model/ot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Wasserstein Distance (Optimal Transport) 6 | """ 7 | import torch 8 | from torch.nn import functional as F 9 | 10 | 11 | def cost_matrix_cosine(x, y, eps=1e-5): 12 | """ Compute cosine distnace across every pairs of x, y (batched) 13 | [B, L_x, D] [B, L_y, D] -> [B, Lx, Ly]""" 14 | assert x.dim() == y.dim() 15 | assert x.size(0) == y.size(0) 16 | assert x.size(2) == y.size(2) 17 | x_norm = F.normalize(x, p=2, dim=-1, eps=eps) 18 | y_norm = F.normalize(y, p=2, dim=-1, eps=eps) 19 | cosine_sim = x_norm.matmul(y_norm.transpose(1, 2)) 20 | cosine_dist = 1 - cosine_sim 21 | return cosine_dist 22 | 23 | 24 | def trace(x): 25 | """ compute trace of input tensor (batched) """ 26 | b, m, n = x.size() 27 | assert m == n 28 | mask = torch.eye(n, dtype=torch.uint8, device=x.device 29 | ).unsqueeze(0).expand_as(x) 30 | trace = x.masked_select(mask).contiguous().view( 31 | b, n).sum(dim=-1, keepdim=False) 32 | return trace 33 | 34 | 35 | @torch.no_grad() 36 | def ipot(C, x_len, x_pad, y_len, y_pad, joint_pad, beta, iteration, k): 37 | """ [B, M, N], [B], [B, M], [B], [B, N], [B, M, N]""" 38 | b, m, n = C.size() 39 | sigma = torch.ones(b, m, dtype=C.dtype, device=C.device 40 | ) / x_len.unsqueeze(1) 41 | T = torch.ones(b, n, m, dtype=C.dtype, device=C.device) 42 | A = torch.exp(-C.transpose(1, 2)/beta) 43 | 44 | # mask padded positions 45 | sigma.masked_fill_(x_pad, 0) 46 | joint_pad = joint_pad.transpose(1, 2) 47 | T.masked_fill_(joint_pad, 0) 48 | A.masked_fill_(joint_pad, 0) 49 | 50 | # broadcastable lengths 51 | x_len = x_len.unsqueeze(1).unsqueeze(2) 52 | y_len = y_len.unsqueeze(1).unsqueeze(2) 53 | 54 | # mask to zero out padding in delta and sigma 55 | x_mask = (x_pad.to(C.dtype) * 1e4).unsqueeze(1) 56 | y_mask = (y_pad.to(C.dtype) * 1e4).unsqueeze(1) 57 | 58 | for _ in range(iteration): 59 | Q = A * T # bs * n * m 60 | sigma = sigma.view(b, m, 1) 61 | for _ in range(k): 62 | delta = 1 / (y_len * Q.matmul(sigma).view(b, 1, n) + y_mask) 63 | sigma = 1 / (x_len * delta.matmul(Q) + x_mask) 64 | T = delta.view(b, n, 1) * Q * sigma 65 | T.masked_fill_(joint_pad, 0) 66 | return T 67 | 68 | 69 | def optimal_transport_dist(txt_emb, img_emb, txt_pad, img_pad, 70 | beta=0.5, iteration=50, k=1): 71 | """ [B, M, D], [B, N, D], [B, M], [B, N]""" 72 | cost = cost_matrix_cosine(txt_emb, img_emb) 73 | # mask the padded inputs 74 | joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2) 75 | cost.masked_fill_(joint_pad, 0) 76 | 77 | txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False) 78 | ).to(dtype=cost.dtype) 79 | img_len = (img_pad.size(1) - img_pad.sum(dim=1, keepdim=False) 80 | ).to(dtype=cost.dtype) 81 | 82 | T = ipot(cost.detach(), txt_len, txt_pad, img_len, img_pad, joint_pad, 83 | beta, iteration, k) 84 | distance = trace(cost.matmul(T.detach())) 85 | return distance 86 | -------------------------------------------------------------------------------- /pytorch_grad_cam/activations_and_gradients.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/activations_and_gradients.py 2 | """ 3 | MIT License 4 | 5 | Copyright (c) 2021 Jacob Gildenblat 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | """ 25 | 26 | class ActivationsAndGradients: 27 | """ Class for extracting activations and 28 | registering gradients from targetted intermediate layers """ 29 | 30 | def __init__(self, model, target_layers, reshape_transform): 31 | self.model = model 32 | self.gradients = [] 33 | self.activations = [] 34 | self.reshape_transform = reshape_transform 35 | self.handles = [] 36 | for target_layer in target_layers: 37 | self.handles.append( 38 | target_layer.register_forward_hook(self.save_activation)) 39 | # Because of https://github.com/pytorch/pytorch/issues/61519, 40 | # we don't use backward hook to record gradients. 41 | self.handles.append( 42 | target_layer.register_forward_hook(self.save_gradient)) 43 | 44 | def save_activation(self, module, input, output): 45 | activation = output 46 | 47 | if self.reshape_transform is not None: 48 | activation = self.reshape_transform(activation) 49 | self.activations.append(activation.detach()) 50 | 51 | def save_gradient(self, module, input, output): 52 | if not hasattr(output, "requires_grad") or not output.requires_grad: 53 | # You can only register hooks on tensor requires grad. 54 | return 55 | 56 | # Gradients are computed in reverse order 57 | def _store_grad(grad): 58 | if self.reshape_transform is not None: 59 | grad = self.reshape_transform(grad) 60 | self.gradients = [grad.detach()] + self.gradients 61 | 62 | output.register_hook(_store_grad) 63 | 64 | def __call__(self, *args): 65 | self.gradients = [] 66 | self.activations = [] 67 | return self.model(*args) 68 | 69 | def release(self): 70 | for handle in self.handles: 71 | handle.remove() 72 | -------------------------------------------------------------------------------- /UNITER/model/vcr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Uniter for VCR model 6 | """ 7 | from collections import defaultdict 8 | 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm 12 | 13 | # from .layer import GELU 14 | from .model import ( 15 | UniterPreTrainedModel, UniterModel) 16 | 17 | 18 | class UniterForVisualCommonsenseReasoning(UniterPreTrainedModel): 19 | """ Finetune UNITER for VCR 20 | """ 21 | def __init__(self, config, img_dim): 22 | super().__init__(config, img_dim) 23 | self.uniter = UniterModel(config, img_dim) 24 | self.vcr_output = nn.Sequential( 25 | nn.Linear(config.hidden_size, config.hidden_size*2), 26 | nn.ReLU(), 27 | LayerNorm(config.hidden_size*2, eps=1e-12), 28 | nn.Linear(config.hidden_size*2, 2) 29 | ) 30 | self.apply(self.init_weights) 31 | 32 | def init_type_embedding(self): 33 | new_emb = nn.Embedding(4, self.uniter.config.hidden_size) 34 | new_emb.apply(self.init_weights) 35 | for i in [0, 1]: 36 | emb = self.uniter.embeddings.token_type_embeddings.weight.data[i, :] 37 | new_emb.weight.data[i, :].copy_(emb) 38 | emb = self.uniter.embeddings.token_type_embeddings.weight.data[0, :] 39 | new_emb.weight.data[2, :].copy_(emb) 40 | new_emb.weight.data[3, :].copy_(emb) 41 | self.uniter.embeddings.token_type_embeddings = new_emb 42 | 43 | def init_word_embedding(self, num_special_tokens): 44 | orig_word_num = self.uniter.embeddings.word_embeddings.weight.size(0) 45 | new_emb = nn.Embedding( 46 | orig_word_num + num_special_tokens, self.uniter.config.hidden_size) 47 | new_emb.apply(self.init_weights) 48 | emb = self.uniter.embeddings.word_embeddings.weight.data 49 | new_emb.weight.data[:orig_word_num, :].copy_(emb) 50 | self.uniter.embeddings.word_embeddings = new_emb 51 | 52 | def forward(self, batch, compute_loss=True): 53 | batch = defaultdict(lambda: None, batch) 54 | input_ids = batch['input_ids'] 55 | position_ids = batch['position_ids'] 56 | img_feat = batch['img_feat'] 57 | img_pos_feat = batch['img_pos_feat'] 58 | attn_masks = batch['attn_masks'] 59 | gather_index = batch['gather_index'] 60 | txt_type_ids = batch['txt_type_ids'] 61 | sequence_output = self.uniter(input_ids, position_ids, 62 | img_feat, img_pos_feat, 63 | attn_masks, gather_index, 64 | output_all_encoded_layers=False, 65 | txt_type_ids=txt_type_ids) 66 | pooled_output = self.uniter.pooler(sequence_output) 67 | rank_scores = self.vcr_output(pooled_output) 68 | 69 | if compute_loss: 70 | targets = batch['targets'] 71 | vcr_loss = F.cross_entropy( 72 | rank_scores, targets.squeeze(-1), 73 | reduction='mean') 74 | return vcr_loss 75 | else: 76 | rank_scores = rank_scores[:, 1:] 77 | return rank_scores 78 | -------------------------------------------------------------------------------- /UNITER/utils/itm_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Image Text Retrieval evaluation helper 6 | """ 7 | from time import time 8 | 9 | import torch 10 | from horovod import torch as hvd 11 | from tqdm import tqdm 12 | 13 | from .logger import LOGGER 14 | from .misc import NoOp 15 | from .distributed import all_gather_list 16 | 17 | 18 | @torch.no_grad() 19 | def itm_eval(score_matrix, txt_ids, img_ids, txt2img, img2txts): 20 | # image retrieval 21 | img2j = {i: j for j, i in enumerate(img_ids)} 22 | _, rank_txt = score_matrix.topk(10, dim=1) 23 | gt_img_j = torch.LongTensor([img2j[txt2img[txt_id]] 24 | for txt_id in txt_ids], 25 | ).to(rank_txt.device 26 | ).unsqueeze(1).expand_as(rank_txt) 27 | rank = (rank_txt == gt_img_j).nonzero() 28 | if rank.numel(): 29 | ir_r1 = (rank < 1).sum().item() / len(txt_ids) 30 | ir_r5 = (rank < 5).sum().item() / len(txt_ids) 31 | ir_r10 = (rank < 10).sum().item() / len(txt_ids) 32 | else: 33 | ir_r1, ir_r5, ir_r10 = 0, 0, 0 34 | 35 | # text retrieval 36 | txt2i = {t: i for i, t in enumerate(txt_ids)} 37 | _, rank_img = score_matrix.topk(10, dim=0) 38 | tr_r1, tr_r5, tr_r10 = 0, 0, 0 39 | for j, img_id in enumerate(img_ids): 40 | gt_is = [txt2i[t] for t in img2txts[img_id]] 41 | ranks = [(rank_img[:, j] == i).nonzero() for i in gt_is] 42 | rank = min([10] + [r.item() for r in ranks if r.numel()]) 43 | if rank < 1: 44 | tr_r1 += 1 45 | if rank < 5: 46 | tr_r5 += 1 47 | if rank < 10: 48 | tr_r10 += 1 49 | tr_r1 /= len(img_ids) 50 | tr_r5 /= len(img_ids) 51 | tr_r10 /= len(img_ids) 52 | 53 | tr_mean = (tr_r1 + tr_r5 + tr_r10) / 3 54 | ir_mean = (ir_r1 + ir_r5 + ir_r10) / 3 55 | r_mean = (tr_mean + ir_mean) / 2 56 | 57 | eval_log = {'txt_r1': tr_r1, 58 | 'txt_r5': tr_r5, 59 | 'txt_r10': tr_r10, 60 | 'txt_r_mean': tr_mean, 61 | 'img_r1': ir_r1, 62 | 'img_r5': ir_r5, 63 | 'img_r10': ir_r10, 64 | 'img_r_mean': ir_mean, 65 | 'r_mean': r_mean} 66 | return eval_log 67 | 68 | 69 | @torch.no_grad() 70 | def evaluate(model, eval_loader): 71 | st = time() 72 | LOGGER.info("start running Image/Text Retrieval evaluation ...") 73 | score_matrix = inference(model, eval_loader) 74 | dset = eval_loader.dataset 75 | all_score = hvd.allgather(score_matrix) 76 | all_txt_ids = [i for ids in all_gather_list(dset.ids) 77 | for i in ids] 78 | all_img_ids = dset.all_img_ids 79 | assert all_score.size() == (len(all_txt_ids), len(all_img_ids)) 80 | if hvd.rank() != 0: 81 | return {} 82 | 83 | # NOTE: only use rank0 to compute final scores 84 | eval_log = itm_eval(all_score, all_txt_ids, all_img_ids, 85 | dset.txt2img, dset.img2txts) 86 | 87 | tot_time = time()-st 88 | LOGGER.info(f"evaluation finished in {int(tot_time)} seconds") 89 | return eval_log 90 | 91 | 92 | @torch.no_grad() 93 | def inference(model, eval_loader): 94 | model.eval() 95 | if hvd.rank() == 0: 96 | pbar = tqdm(total=len(eval_loader)) 97 | else: 98 | pbar = NoOp() 99 | score_matrix = torch.zeros(len(eval_loader.dataset), 100 | len(eval_loader.dataset.all_img_ids), 101 | device=torch.device("cuda"), 102 | dtype=torch.float16) 103 | for i, mini_batches in enumerate(eval_loader): 104 | j = 0 105 | for batch in mini_batches: 106 | scores = model(batch, compute_loss=False) 107 | bs = scores.size(0) 108 | score_matrix.data[i, j:j+bs] = scores.data.squeeze(1).half() 109 | j += bs 110 | assert j == score_matrix.size(1) 111 | pbar.update(1) 112 | model.train() 113 | pbar.close() 114 | return score_matrix 115 | -------------------------------------------------------------------------------- /UNITER/data/sampler.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | sampler for length bucketing (batch by tokens) 6 | """ 7 | import math 8 | import random 9 | 10 | import horovod.torch as hvd 11 | import torch 12 | from torch.utils.data import Sampler 13 | from cytoolz import partition_all 14 | 15 | 16 | class TokenBucketSampler(Sampler): 17 | def __init__(self, lens, bucket_size, batch_size, 18 | droplast=False, size_multiple=8): 19 | self._lens = lens 20 | self._max_tok = batch_size 21 | self._bucket_size = bucket_size 22 | self._droplast = droplast 23 | self._size_mul = size_multiple 24 | 25 | def _create_ids(self): 26 | return list(range(len(self._lens))) 27 | 28 | def _sort_fn(self, i): 29 | return self._lens[i] 30 | 31 | def __iter__(self): 32 | ids = self._create_ids() 33 | random.shuffle(ids) 34 | buckets = [sorted(ids[i:i+self._bucket_size], 35 | key=self._sort_fn, reverse=True) 36 | for i in range(0, len(ids), self._bucket_size)] 37 | # fill batches until max_token (include padding) 38 | batches = [] 39 | for bucket in buckets: 40 | max_len = 0 41 | batch_indices = [] 42 | for indices in partition_all(self._size_mul, bucket): 43 | max_len = max(max_len, max(self._lens[i] for i in indices)) 44 | if (max_len * (len(batch_indices) + self._size_mul) 45 | > self._max_tok): 46 | if not batch_indices: 47 | raise ValueError( 48 | "max_tokens too small / max_seq_len too long") 49 | assert len(batch_indices) % self._size_mul == 0 50 | batches.append(batch_indices) 51 | batch_indices = list(indices) 52 | else: 53 | batch_indices.extend(indices) 54 | if not self._droplast and batch_indices: 55 | batches.append(batch_indices) 56 | random.shuffle(batches) 57 | return iter(batches) 58 | 59 | def __len__(self): 60 | raise ValueError("NOT supported. " 61 | "This has some randomness across epochs") 62 | 63 | 64 | class DistributedSampler(Sampler): 65 | """Sampler that restricts data loading to a subset of the dataset. 66 | 67 | It is especially useful in conjunction with 68 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 69 | process can pass a DistributedSampler instance as a DataLoader sampler, 70 | and load a subset of the original dataset that is exclusive to it. 71 | 72 | .. note:: 73 | Dataset is assumed to be of constant size. 74 | 75 | Arguments: 76 | dataset: Dataset used for sampling. 77 | num_replicas (optional): Number of processes participating in 78 | distributed training. 79 | rank (optional): Rank of the current process within num_replicas. 
80 | shuffle (optional): If true (default), sampler will shuffle the indices 81 | """ 82 | 83 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 84 | if num_replicas is None: 85 | num_replicas = hvd.size() 86 | if rank is None: 87 | rank = hvd.rank() 88 | self.dataset = dataset 89 | self.num_replicas = num_replicas 90 | self.rank = rank 91 | self.epoch = 0 92 | self.num_samples = int(math.ceil(len(self.dataset) 93 | * 1.0 / self.num_replicas)) 94 | self.total_size = self.num_samples * self.num_replicas 95 | self.shuffle = shuffle 96 | 97 | def __iter__(self): 98 | # deterministically shuffle based on epoch 99 | g = torch.Generator() 100 | g.manual_seed(self.epoch) 101 | 102 | indices = list(range(len(self.dataset))) 103 | # add extra samples to make it evenly divisible 104 | indices += indices[:(self.total_size - len(indices))] 105 | assert len(indices) == self.total_size 106 | 107 | # subsample 108 | indices = indices[self.rank:self.total_size:self.num_replicas] 109 | 110 | if self.shuffle: 111 | shufle_ind = torch.randperm(len(indices), generator=g).tolist() 112 | indices = [indices[i] for i in shufle_ind] 113 | assert len(indices) == self.num_samples 114 | 115 | return iter(indices) 116 | 117 | def __len__(self): 118 | return self.num_samples 119 | 120 | def set_epoch(self, epoch): 121 | self.epoch = epoch 122 | -------------------------------------------------------------------------------- /UNITER/optim/adamw.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamW optimizer (weight decay fix) 3 | copied from hugginface (https://github.com/huggingface/transformers). 4 | """ 5 | import math 6 | 7 | import torch 8 | from torch.optim import Optimizer 9 | 10 | 11 | class AdamW(Optimizer): 12 | """ Implements Adam algorithm with weight decay fix. 13 | Parameters: 14 | lr (float): learning rate. Default 1e-3. 15 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). 16 | Default: (0.9, 0.999) 17 | eps (float): Adams epsilon. Default: 1e-6 18 | weight_decay (float): Weight decay. Default: 0.0 19 | correct_bias (bool): can be set to False to avoid correcting bias 20 | in Adam (e.g. like in Bert TF repository). Default True. 21 | """ 22 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 23 | weight_decay=0.0, correct_bias=True): 24 | if lr < 0.0: 25 | raise ValueError( 26 | "Invalid learning rate: {} - should be >= 0.0".format(lr)) 27 | if not 0.0 <= betas[0] < 1.0: 28 | raise ValueError("Invalid beta parameter: {} - " 29 | "should be in [0.0, 1.0[".format(betas[0])) 30 | if not 0.0 <= betas[1] < 1.0: 31 | raise ValueError("Invalid beta parameter: {} - " 32 | "should be in [0.0, 1.0[".format(betas[1])) 33 | if not 0.0 <= eps: 34 | raise ValueError("Invalid epsilon value: {} - " 35 | "should be >= 0.0".format(eps)) 36 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 37 | correct_bias=correct_bias) 38 | super(AdamW, self).__init__(params, defaults) 39 | 40 | def step(self, closure=None): 41 | """Performs a single optimization step. 42 | Arguments: 43 | closure (callable, optional): A closure that reevaluates the model 44 | and returns the loss. 
45 | """ 46 | loss = None 47 | if closure is not None: 48 | loss = closure() 49 | 50 | for group in self.param_groups: 51 | for p in group['params']: 52 | if p.grad is None: 53 | continue 54 | grad = p.grad.data 55 | if grad.is_sparse: 56 | raise RuntimeError( 57 | 'Adam does not support sparse ' 58 | 'gradients, please consider SparseAdam instead') 59 | 60 | state = self.state[p] 61 | 62 | # State initialization 63 | if len(state) == 0: 64 | state['step'] = 0 65 | # Exponential moving average of gradient values 66 | state['exp_avg'] = torch.zeros_like(p.data) 67 | # Exponential moving average of squared gradient values 68 | state['exp_avg_sq'] = torch.zeros_like(p.data) 69 | 70 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 71 | beta1, beta2 = group['betas'] 72 | 73 | state['step'] += 1 74 | 75 | # Decay the first and second moment running average coefficient 76 | # In-place operations to update the averages at the same time 77 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 78 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 79 | denom = exp_avg_sq.sqrt().add_(group['eps']) 80 | 81 | step_size = group['lr'] 82 | if group['correct_bias']: # No bias correction for Bert 83 | bias_correction1 = 1.0 - beta1 ** state['step'] 84 | bias_correction2 = 1.0 - beta2 ** state['step'] 85 | step_size = (step_size * math.sqrt(bias_correction2) 86 | / bias_correction1) 87 | 88 | p.data.addcdiv_(-step_size, exp_avg, denom) 89 | 90 | # Just adding the square of the weights to the loss function is 91 | # *not* the correct way of using L2 regularization/weight decay 92 | # with Adam, since that will interact with the m and v 93 | # parameters in strange ways. 94 | # 95 | # Instead we want to decay the weights in a manner that doesn't 96 | # interact with the m/v parameters. This is equivalent to 97 | # adding the square of the weights to the loss with plain 98 | # (non-momentum) SGD. 99 | # Add weight decay at the end (fixed version) 100 | if group['weight_decay'] > 0.0: 101 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 102 | 103 | return loss 104 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- 
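The tokenizer defined in `clip/simple_tokenizer.py` above can be exercised directly. A minimal usage sketch, assuming the environment from `requirements.txt` is installed and the command is run from the repository root (the example sentence is arbitrary):
```
from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()   # loads bpe_simple_vocab_16e6.txt.gz from the package directory
ids = tokenizer.encode("a photo of a red cube")
print(ids)                      # BPE token ids; encode() itself adds no <|startoftext|>/<|endoftext|>
print(tokenizer.decode(ids))    # round-trips to (roughly) the cleaned, lower-cased input text
```
The special start/end tokens are added at a higher level, by the `clip.tokenize` call used in `generic_clip_pairs.py` further below.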
/clip_mm_explain/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first 
and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /UNITER/data/loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | A prefetch loader to speedup data loading 6 | Modified from Nvidia Deep Learning Examples 7 | (https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch). 8 | """ 9 | import random 10 | 11 | import torch 12 | from torch.utils.data import DataLoader 13 | 14 | from utils.distributed import any_broadcast 15 | 16 | 17 | class MetaLoader(object): 18 | """ wraps multiple data loaders """ 19 | def __init__(self, loaders, accum_steps=1, distributed=False): 20 | assert isinstance(loaders, dict) 21 | self.name2loader = {} 22 | self.name2iter = {} 23 | self.sampling_pools = [] 24 | for n, l in loaders.items(): 25 | if isinstance(l, tuple): 26 | l, r = l 27 | elif isinstance(l, DataLoader): 28 | r = 1 29 | else: 30 | raise ValueError() 31 | self.name2loader[n] = l 32 | self.name2iter[n] = iter(l) 33 | self.sampling_pools.extend([n]*r) 34 | 35 | self.accum_steps = accum_steps 36 | self.distributed = distributed 37 | self.step = 0 38 | 39 | def __iter__(self): 40 | """ this iterator will run indefinitely """ 41 | task = self.sampling_pools[0] 42 | while True: 43 | if self.step % self.accum_steps == 0: 44 | task = random.choice(self.sampling_pools) 45 | if self.distributed: 46 | # make sure all process is training same task 47 | task = any_broadcast(task, 0) 48 | self.step += 1 49 | iter_ = self.name2iter[task] 50 | try: 51 | batch = next(iter_) 52 | except StopIteration: 53 | iter_ = iter(self.name2loader[task]) 54 | batch = next(iter_) 55 | self.name2iter[task] = iter_ 56 | 57 | yield task, batch 58 | 59 | 60 | def move_to_cuda(batch): 61 | if isinstance(batch, torch.Tensor): 62 | return batch.cuda(non_blocking=True) 63 | elif isinstance(batch, list): 64 | new_batch = [move_to_cuda(t) for t in batch] 65 | elif isinstance(batch, tuple): 66 | new_batch = tuple(move_to_cuda(t) for t in batch) 67 | elif isinstance(batch, dict): 68 | new_batch = {n: move_to_cuda(t) for n, t in batch.items()} 69 | else: 70 | return batch 71 | return new_batch 72 | 73 | 74 | def record_cuda_stream(batch): 75 | if isinstance(batch, torch.Tensor): 76 | batch.record_stream(torch.cuda.current_stream()) 77 | elif isinstance(batch, list) or isinstance(batch, tuple): 78 | for t in batch: 79 | record_cuda_stream(t) 80 | elif isinstance(batch, dict): 81 | for t in batch.values(): 82 | record_cuda_stream(t) 83 | else: 
84 | pass 85 | 86 | 87 | class PrefetchLoader(object): 88 | """ 89 | overlap compute and cuda data transfer 90 | (copied and then modified from nvidia apex) 91 | """ 92 | def __init__(self, loader): 93 | self.loader = loader 94 | self.stream = torch.cuda.Stream() 95 | 96 | def __iter__(self): 97 | loader_it = iter(self.loader) 98 | self.preload(loader_it) 99 | batch = self.next(loader_it) 100 | while batch is not None: 101 | yield batch 102 | batch = self.next(loader_it) 103 | 104 | def __len__(self): 105 | return len(self.loader) 106 | 107 | def preload(self, it): 108 | try: 109 | self.batch = next(it) 110 | except StopIteration: 111 | self.batch = None 112 | return 113 | # if record_stream() doesn't work, another option is to make sure 114 | # device inputs are created on the main stream. 115 | # self.next_input_gpu = torch.empty_like(self.next_input, 116 | # device='cuda') 117 | # self.next_target_gpu = torch.empty_like(self.next_target, 118 | # device='cuda') 119 | # Need to make sure the memory allocated for next_* is not still in use 120 | # by the main stream at the time we start copying to next_*: 121 | # self.stream.wait_stream(torch.cuda.current_stream()) 122 | with torch.cuda.stream(self.stream): 123 | self.batch = move_to_cuda(self.batch) 124 | # more code for the alternative if record_stream() doesn't work: 125 | # copy_ will record the use of the pinned source tensor in this 126 | # side stream. 127 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 128 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 129 | # self.next_input = self.next_input_gpu 130 | # self.next_target = self.next_target_gpu 131 | 132 | def next(self, it): 133 | torch.cuda.current_stream().wait_stream(self.stream) 134 | batch = self.batch 135 | if batch is not None: 136 | record_cuda_stream(batch) 137 | self.preload(it) 138 | return batch 139 | 140 | def __getattr__(self, name): 141 | method = self.loader.__getattribute__(name) 142 | return method 143 | -------------------------------------------------------------------------------- /UNITER/scripts/convert_imgdir.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | convert image npz to LMDB 6 | """ 7 | import argparse 8 | import glob 9 | import io 10 | import json 11 | import multiprocessing as mp 12 | import os 13 | from os.path import basename, exists 14 | 15 | from cytoolz import curry 16 | import numpy as np 17 | from tqdm import tqdm 18 | import lmdb 19 | 20 | import msgpack 21 | import msgpack_numpy 22 | msgpack_numpy.patch() 23 | 24 | 25 | def _compute_nbb(img_dump, conf_th, max_bb, min_bb, num_bb): 26 | num_bb = max(min_bb, (img_dump['conf'] > conf_th).sum()) 27 | num_bb = min(max_bb, num_bb) 28 | return int(num_bb) 29 | 30 | 31 | @curry 32 | def load_npz(conf_th, max_bb, min_bb, num_bb, fname, keep_all=False): 33 | try: 34 | img_dump = np.load(fname, allow_pickle=True) 35 | if keep_all: 36 | nbb = None 37 | else: 38 | nbb = _compute_nbb(img_dump, conf_th, max_bb, min_bb, num_bb) 39 | dump = {} 40 | for key, arr in img_dump.items(): 41 | if arr.dtype == np.float32: 42 | arr = arr.astype(np.float16) 43 | if arr.ndim == 2: 44 | dump[key] = arr[:nbb, :] 45 | elif arr.ndim == 1: 46 | dump[key] = arr[:nbb] 47 | else: 48 | raise ValueError('wrong ndim') 49 | except Exception as e: 50 | # corrupted file 51 | print(f'corrupted file {fname}', e) 52 | dump = {} 53 | nbb = 0 54 | 55 | name = basename(fname) 56 | return name, dump, nbb 57 | 58 | 59 | def dumps_npz(dump, compress=False): 60 | with io.BytesIO() as writer: 61 | if compress: 62 | np.savez_compressed(writer, **dump, allow_pickle=True) 63 | else: 64 | np.savez(writer, **dump, allow_pickle=True) 65 | return writer.getvalue() 66 | 67 | 68 | def dumps_msgpack(dump): 69 | return msgpack.dumps(dump, use_bin_type=True) 70 | 71 | 72 | def main(opts): 73 | if opts.img_dir[-1] == '/': 74 | opts.img_dir = opts.img_dir[:-1] 75 | split = basename(opts.img_dir) 76 | if opts.keep_all: 77 | db_name = 'all' 78 | else: 79 | if opts.conf_th == -1: 80 | db_name = f'feat_numbb{opts.num_bb}' 81 | else: 82 | db_name = (f'feat_th{opts.conf_th}_max{opts.max_bb}' 83 | f'_min{opts.min_bb}') 84 | if opts.compress: 85 | db_name += '_compressed' 86 | if not exists(f'{opts.output}/{split}'): 87 | os.makedirs(f'{opts.output}/{split}') 88 | env = lmdb.open(f'{opts.output}/{split}/{db_name}', map_size=1024**4) 89 | txn = env.begin(write=True) 90 | files = glob.glob(f'{opts.img_dir}/*.npz') 91 | load = load_npz(opts.conf_th, opts.max_bb, opts.min_bb, opts.num_bb, 92 | keep_all=opts.keep_all) 93 | name2nbb = {} 94 | with mp.Pool(opts.nproc) as pool, tqdm(total=len(files)) as pbar: 95 | for i, (fname, features, nbb) in enumerate( 96 | pool.imap_unordered(load, files, chunksize=128)): 97 | if not features: 98 | continue # corrupted feature 99 | if opts.compress: 100 | dump = dumps_npz(features, compress=True) 101 | else: 102 | dump = dumps_msgpack(features) 103 | txn.put(key=fname.encode('utf-8'), value=dump) 104 | if i % 1000 == 0: 105 | txn.commit() 106 | txn = env.begin(write=True) 107 | name2nbb[fname] = nbb 108 | pbar.update(1) 109 | txn.put(key=b'__keys__', 110 | value=json.dumps(list(name2nbb.keys())).encode('utf-8')) 111 | txn.commit() 112 | env.close() 113 | if opts.conf_th != -1 and not opts.keep_all: 114 | with open(f'{opts.output}/{split}/' 115 | f'nbb_th{opts.conf_th}_' 116 | f'max{opts.max_bb}_min{opts.min_bb}.json', 'w') as f: 117 | json.dump(name2nbb, f) 118 | 119 | 120 | if __name__ == '__main__': 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("--img_dir", default=None, type=str, 123 | help="The input images.") 124 | parser.add_argument("--output", default=None, type=str, 125 | 
help="output lmdb") 126 | parser.add_argument('--nproc', type=int, default=8, 127 | help='number of cores used') 128 | parser.add_argument('--compress', action='store_true', 129 | help='compress the tensors') 130 | parser.add_argument('--keep_all', action='store_true', 131 | help='keep all features, overrides all following args') 132 | parser.add_argument('--conf_th', type=float, default=0.2, 133 | help='threshold for dynamic bounding boxes ' 134 | '(-1 for fixed)') 135 | parser.add_argument('--max_bb', type=int, default=100, 136 | help='max number of bounding boxes') 137 | parser.add_argument('--min_bb', type=int, default=10, 138 | help='min number of bounding boxes') 139 | parser.add_argument('--num_bb', type=int, default=100, 140 | help='number of bounding boxes (fixed)') 141 | args = parser.parse_args() 142 | main(args) 143 | -------------------------------------------------------------------------------- /py-bottom-up-attention/extract_features.py: -------------------------------------------------------------------------------- 1 | # Adapted from code in https://github.com/airsplay/py-bottom-up-attention 2 | # Please see https://github.com/airsplay/py-bottom-up-attention/blob/master/LICENSE for the Apache 2.0 License of that code 3 | 4 | import os 5 | import sys 6 | import io 7 | import json 8 | from tqdm import tqdm 9 | 10 | import detectron2 11 | 12 | # import some common detectron2 utilities 13 | from detectron2.engine import DefaultPredictor 14 | from detectron2.config import get_cfg 15 | from detectron2.utils.visualizer import Visualizer 16 | from detectron2.data import MetadataCatalog 17 | 18 | # import some common libraries 19 | import numpy as np 20 | import cv2 21 | import torch 22 | 23 | NUM_OBJECTS = 36 24 | 25 | from torch import nn 26 | 27 | from detectron2.modeling.postprocessing import detector_postprocess 28 | from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers, FastRCNNOutputs, fast_rcnn_inference_single_image 29 | from detectron2.structures.boxes import Boxes 30 | from detectron2.structures.instances import Instances 31 | 32 | 33 | # Load VG Classes 34 | data_path = 'data/genome/1600-400-20' 35 | 36 | vg_classes = [] 37 | with open(os.path.join(data_path, 'objects_vocab.txt')) as f: 38 | for object in f.readlines(): 39 | vg_classes.append(object.split(',')[0].lower().strip()) 40 | 41 | vg_attrs = [] 42 | with open(os.path.join(data_path, 'attributes_vocab.txt')) as f: 43 | for object in f.readlines(): 44 | vg_attrs.append(object.split(',')[0].lower().strip()) 45 | 46 | 47 | MetadataCatalog.get("vg").thing_classes = vg_classes 48 | MetadataCatalog.get("vg").attr_classes = vg_attrs 49 | 50 | cfg = get_cfg() 51 | cfg.merge_from_file("configs/VG-Detection/faster_rcnn_R_101_C4_caffe.yaml") 52 | cfg.MODEL.RPN.POST_NMS_TOPK_TEST = 300 53 | cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.6 54 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2 55 | # VG Weight 56 | # cfg.MODEL.WEIGHTS = "http://nlp.cs.unc.edu/models/faster_rcnn_from_caffe.pkl" 57 | cfg.MODEL.WEIGHTS = "http://nlp.cs.unc.edu/models/faster_rcnn_from_caffe_attr_original.pkl" 58 | predictor = DefaultPredictor(cfg) 59 | 60 | def doit(raw_image, raw_boxes): 61 | # Process Boxes 62 | raw_boxes = Boxes(torch.from_numpy(raw_boxes).cuda()) 63 | with torch.no_grad(): 64 | raw_height, raw_width = raw_image.shape[:2] 65 | print("Original image size: ", (raw_height, raw_width)) 66 | # Preprocessing 67 | image = predictor.transform_gen.get_transform(raw_image).apply_image(raw_image) 68 | print("Transformed image 
size: ", image.shape[:2]) 69 | # Scale the box 70 | new_height, new_width = image.shape[:2] 71 | scale_x = 1. * new_width / raw_width 72 | scale_y = 1. * new_height / raw_height 73 | #print(scale_x, scale_y) 74 | boxes = raw_boxes.clone() 75 | boxes.scale(scale_x=scale_x, scale_y=scale_y) 76 | image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) 77 | inputs = [{"image": image, "height": raw_height, "width": raw_width}] 78 | images = predictor.model.preprocess_image(inputs) 79 | # Run Backbone Res1-Res4 80 | features = predictor.model.backbone(images.tensor) 81 | # Run RoI head for each proposal (RoI Pooling + Res5) 82 | proposal_boxes = [boxes] 83 | features = [features[f] for f in predictor.model.roi_heads.in_features] 84 | box_features = predictor.model.roi_heads._shared_roi_transform( 85 | features, proposal_boxes 86 | ) 87 | feature_pooled = box_features.mean(dim=[2, 3]) # pooled to 1x1 88 | print('Pooled features size:', feature_pooled.shape) 89 | # Predict classes and boxes for each proposal. 90 | pred_class_logits, pred_proposal_deltas = predictor.model.roi_heads.box_predictor(feature_pooled) 91 | print(pred_class_logits.shape) 92 | pred_class_prob = nn.functional.softmax(pred_class_logits, -1) 93 | pred_scores, pred_classes = pred_class_prob[..., :-1].max(-1) 94 | # Detectron2 Formatting (for visualization only) 95 | roi_features = feature_pooled 96 | instances = Instances( 97 | image_size=(raw_height, raw_width), 98 | pred_boxes=raw_boxes, 99 | scores=pred_scores, 100 | pred_classes=pred_classes 101 | ) 102 | return instances, roi_features 103 | 104 | # Image root 105 | img_root = sys.argv[1] 106 | # jsonl input file 107 | f = open(sys.argv[2]) 108 | lines = f.readlines() 109 | data = [json.loads(line) for line in lines] 110 | if len(sys.argv) > 4: 111 | # Predicted boxes JSON file 112 | f = open(sys.argv[3]) 113 | predicted_boxes = json.load(f) 114 | boxes_dict = {} 115 | for datum in tqdm(data): 116 | if 'coco' in datum['file_name'].lower(): 117 | datum['file_name'] = '_'.join(datum['file_name'].split('_')[:-1])+'.jpg' 118 | im = cv2.imread(img_root+datum['file_name']) 119 | if len(sys.argv) == 4: 120 | boxes = np.array( 121 | [[ann['bbox'][0], ann['bbox'][1], ann['bbox'][0]+ann['bbox'][2], ann['bbox'][1]+ann['bbox'][3]] for ann in datum['anns']] 122 | ) 123 | conf = np.array([1 for _ in datum["anns"]]) 124 | else: 125 | assert len(sys.argv) > 4 126 | conf = np.array(predicted_boxes[str(datum["image_id"])]["scores"]) 127 | print(conf.shape) 128 | if len(conf) > 0: 129 | boxes = np.array(predicted_boxes[str(datum["image_id"])]["boxes"])[(-conf).argsort(),:] 130 | conf = conf[(-conf).argsort()] 131 | else: 132 | boxes = np.array([[0, 0, im.shape[1], im.shape[0]]]) 133 | conf = np.array([1.]) 134 | _, feats = doit(im, boxes) 135 | boxes_dict[datum['image_id']] = {"boxes": torch.from_numpy(boxes), "features": feats, "width": im.shape[1], "height": im.shape[0], "conf": torch.from_numpy(conf)} 136 | # Output .pt file 137 | torch.save(boxes_dict, sys.argv[4]) 138 | -------------------------------------------------------------------------------- /generic_clip_pairs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import clip 3 | import json 4 | import argparse 5 | import ruamel.yaml as yaml 6 | 7 | from PIL import Image 8 | import torch 9 | import torchvision.transforms as transforms 10 | from tqdm import tqdm 11 | 12 | from albef.utils import * 13 | from executor import AlbefExecutor 14 | 15 | parser = 
argparse.ArgumentParser() 16 | parser.add_argument("--input_path", type=str, help="Path to input JSON file") 17 | parser.add_argument("--image_root", type=str, help="Path to directory containing images") 18 | parser.add_argument("--albef_path", type=str, default=None, help="Path to ALBEF model/config/etc. if the goal is to use ALBEF") 19 | parser.add_argument("--albef_itc", action="store_true", help="Use ITC output of ALBEF") 20 | parser.add_argument("--clip_model", type=str, help="CLIP model to use") 21 | parser.add_argument("--gpu", type=int, default=-1, help="Which gpu to use") 22 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size for running CLIP") 23 | 24 | args = parser.parse_args() 25 | 26 | if args.albef_path is not None: 27 | executor = AlbefExecutor(checkpoint_path = os.path.join(args.albef_path, "checkpoint.pth"), config_path = os.path.join(args.albef_path, "config.yaml"), device = "cpu" if args.gpu < 0 else "cuda:"+str(args.gpu)) 28 | model = executor.models[0] 29 | preprocess = executor.preprocesses[0] 30 | model = model.eval() 31 | else: 32 | model, preprocess = clip.load(args.clip_model, jit=False, device="cuda:"+str(args.gpu)) 33 | preprocess.transforms[0] == transforms.Resize((model.visual.input_resolution, model.visual.input_resolution), transforms.InterpolationMode.BICUBIC) 34 | model = model.eval() 35 | input_file = open(args.input_path) 36 | data = json.load(input_file) 37 | input_file.close() 38 | correct = 0 39 | for i in tqdm(range(0, len(data), args.batch_size)): 40 | batch_images = [] 41 | batch_text = [] 42 | for datum in data[i:min(i+args.batch_size, len(data))]: 43 | img = Image.open(os.path.join(args.image_root, datum["image_filename"])).convert('RGB') 44 | batch_images.append(preprocess(img)) 45 | if "text2" in datum: 46 | if args.albef_path is None: 47 | datum["text1"] = "a photo of "+datum["text1"] 48 | datum["text2"] = "a photo of "+datum["text2"] 49 | batch_text.append(datum["text1"]) 50 | batch_text.append(datum["text2"]) 51 | else: 52 | img2 = Image.open(os.path.join(args.image_root, datum["image_filename2"])).convert('RGB') 53 | batch_images.append(preprocess(img2)) 54 | batch_text.append(datum["text1"]) 55 | batch_images = torch.stack(batch_images).to("cuda:"+str(args.gpu)) 56 | if args.albef_path is None: 57 | batch_text = clip.tokenize(batch_text).to("cuda:"+str(args.gpu)) 58 | else: 59 | modified_text = [pre_caption(txt, executor.max_words) for txt in batch_text] 60 | batch_text = executor.tokenizer(modified_text, padding='longest', return_tensors="pt") 61 | for key in batch_text: 62 | batch_text[key] = batch_text[key].to(batch_images.device) 63 | 64 | with torch.no_grad(): 65 | if args.albef_path is None: 66 | logits_per_image, logits_per_text = model(batch_images, batch_text) 67 | else: 68 | if not args.albef_itc: 69 | if batch_images.shape[0]*2 == batch_text.input_ids.shape[0]: 70 | batch_images = batch_images.unsqueeze(1).repeat(1, 2, 1, 1, 1).view(batch_images.shape[0]*2, batch_images.shape[1], batch_images.shape[2], batch_images.shape[3]) 71 | else: 72 | assert batch_images.shape[0] ==2*batch_text.input_ids.shape[0] 73 | batch_text.input_ids = batch_text.input_ids.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1) 74 | batch_text.attention_mask = batch_text.attention_mask.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1) 75 | image_embeds = model.visual_encoder(batch_images) 76 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(batch_images.device) 77 | output = 
model.text_encoder( 78 | batch_text.input_ids, 79 | attention_mask = batch_text.attention_mask, 80 | encoder_hidden_states = image_embeds, 81 | encoder_attention_mask = image_atts, 82 | return_dict = True, 83 | ) 84 | vl_embeddings = output.last_hidden_state[:,0,:] 85 | vl_output = model.itm_head(vl_embeddings) 86 | logits_per_image = vl_output[:,1:2].view(-1, 2) 87 | else: 88 | image_embeds = model.visual_encoder(batch_images) 89 | image_feat = torch.nn.functional.normalize(model.vision_proj(image_embeds[:,0,:]),dim=-1) 90 | text_output = model.text_encoder(batch_text.input_ids, attention_mask = batch_text.attention_mask, 91 | return_dict = True, mode = 'text') 92 | text_embeds = text_output.last_hidden_state 93 | text_feat = torch.nn.functional.normalize(model.text_proj(text_embeds[:,0,:]),dim=-1) 94 | sim = image_feat@text_feat.t()/model.temp 95 | logits_per_image = sim 96 | if args.albef_path is None or args.albef_itc: 97 | if logits_per_image.shape[0]*2 == logits_per_image.shape[1]: 98 | for j in range(logits_per_image.shape[0]): 99 | correct += 1 if logits_per_image[j,2*j].item() > logits_per_image[j,2*j+1].item() else 0 100 | else: 101 | assert logits_per_image.shape[0] == 2*logits_per_image.shape[1] 102 | for j in range(logits_per_image.shape[1]): 103 | correct += 1 if logits_per_image[2*j,j].item() > logits_per_image[2*j+1,j].item() else 0 104 | else: 105 | correct += (logits_per_image[:,0] > logits_per_image[:,1]).long().sum().item() 106 | 107 | print("Accuracy:", correct/len(data)) 108 | -------------------------------------------------------------------------------- /entity_extraction.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, Callable, List, Tuple, NamedTuple, Text, Optional 2 | import numpy as np 3 | from spacy.tokens.token import Token 4 | from spacy.tokens.span import Span 5 | 6 | from lattice import Product as L 7 | 8 | from heuristics import Heuristics 9 | 10 | Rel = Tuple[List[Token], "Entity"] 11 | Sup = List[Token] 12 | 13 | DEFAULT_HEURISTICS = Heuristics() 14 | 15 | 16 | def find_superlatives(tokens, heuristics) -> List[Sup]: 17 | """Modify and return a list of superlative tokens.""" 18 | for heuristic in heuristics.superlatives: 19 | if any(tok.text in heuristic.keywords for tok in tokens): 20 | tokens.sort(key=lambda tok: tok.i) 21 | return [tokens] 22 | return [] 23 | 24 | def expand_chunks(doc, chunks): 25 | expanded = {} 26 | for key in chunks: 27 | chunk = chunks[key] 28 | start = chunk.start 29 | end = chunk.end 30 | for i in range(chunk.start-1, -1, -1): 31 | if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)): 32 | if not any(any(doc[i].is_ancestor(doc[j]) for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2): 33 | start = i 34 | for i in range(chunk.end, len(doc)): 35 | if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)): 36 | if not any(any(doc[i].is_ancestor(doc[j]) or i == j for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2): 37 | end = i+1 38 | else: 39 | break 40 | expanded[key] = Span(doc=doc, start=start, end=end) 41 | return expanded 42 | 43 | class Entity(NamedTuple): 44 | """Represents an entity with locative constraints extracted from the parse.""" 45 | 46 | head: Span 47 | relations: List[Rel] 48 | superlatives: List[Sup] 49 | 50 | @classmethod 51 | def extract(cls, head, chunks, heuristics: Optional[Heuristics] = None) -> "Entity": 52 | """Extract 
entities from a spacy parse. 53 | 54 | Jointly recursive with `_get_rel_sups`.""" 55 | if heuristics is None: 56 | heuristics = DEFAULT_HEURISTICS 57 | 58 | if head.i not in chunks: 59 | # Handles predicative cases. 60 | children = list(head.children) 61 | if children and children[0].i in chunks: 62 | head = children[0] 63 | # TODO: Also extract predicative relations. 64 | else: 65 | return None 66 | hchunk = chunks[head.i] 67 | rels, sups = cls._get_rel_sups(head, head, [], chunks, heuristics) 68 | return cls(hchunk, rels, sups) 69 | 70 | @classmethod 71 | def _get_rel_sups(cls, token, head, tokens, chunks, heuristics) -> Tuple[List[Rel], List[Sup]]: 72 | hchunk = chunks[head.i] 73 | is_keyword = any(token.text in h.keywords for h in heuristics.relations) 74 | is_keyword |= token.text in heuristics.null_keywords 75 | 76 | # Found another entity head. 77 | if token.i in chunks and chunks[token.i] is not hchunk and not is_keyword: 78 | tchunk = chunks[token.i] 79 | tokens.sort(key=lambda tok: tok.i) 80 | subhead = cls.extract(token, chunks, heuristics) 81 | return [(tokens, subhead)], [] 82 | 83 | # End of a chain of modifiers. 84 | n_children = len(list(token.children)) 85 | if n_children == 0: 86 | return [], find_superlatives(tokens + [token], heuristics) 87 | 88 | relations = [] 89 | superlatives = [] 90 | is_keyword |= any(token.text in h.keywords for h in heuristics.superlatives) 91 | for child in token.children: 92 | if token.i in chunks and child.i in chunks and chunks[token.i] is chunks[child.i]: 93 | if not any(child.text in h.keywords for h in heuristics.superlatives): 94 | if n_children == 1: 95 | # Catches "the goat on the left" 96 | sups = find_superlatives(tokens + [token], heuristics) 97 | superlatives.extend(sups) 98 | continue 99 | new_tokens = tokens + [token] if token.i not in chunks or is_keyword else tokens 100 | subrel, subsup = cls._get_rel_sups(child, head, new_tokens, chunks, heuristics) 101 | relations.extend(subrel) 102 | superlatives.extend(subsup) 103 | return relations, superlatives 104 | 105 | def expand(self, span: Span = None): 106 | tokens = [token for token in self.head] 107 | if span is None: 108 | span = [None] 109 | for target_token in span: 110 | include = False 111 | stack = [token for token in self.head] 112 | while len(stack) > 0: 113 | token = stack.pop() 114 | if token == target_token: 115 | token2 = target_token.head 116 | while token2.head != token2: 117 | tokens.append(token2) 118 | token2 = token2.head 119 | tokens.append(token2) 120 | stack = [] 121 | include = True 122 | if target_token is None or include: 123 | tokens.append(token) 124 | for child in token.children: 125 | stack.append(child) 126 | tokens = list(set(tokens)) 127 | tokens = sorted(tokens, key=lambda x: x.i) 128 | return ' '.join([token.text for token in tokens]) 129 | 130 | def __eq__(self, other: "Entity") -> bool: 131 | if self.text != other.text: 132 | return False 133 | if self.relations != other.relations: 134 | return False 135 | if self.superlatives != other.superlatives: 136 | return False 137 | return True 138 | 139 | @property 140 | def text(self) -> Text: 141 | """Get the text predicate associated with this entity.""" 142 | return self.head.text 143 | -------------------------------------------------------------------------------- /UNITER/model/itm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | UNITER for ITM model 6 | """ 7 | from collections import defaultdict 8 | 9 | import torch 10 | from torch import nn 11 | from .model import UniterPreTrainedModel, UniterModel 12 | 13 | 14 | class UniterForImageTextRetrieval(UniterPreTrainedModel): 15 | """ Finetune UNITER for image text retrieval 16 | """ 17 | def __init__(self, config, img_dim, margin=0.2): 18 | super().__init__(config) 19 | self.uniter = UniterModel(config, img_dim) 20 | self.itm_output = nn.Linear(config.hidden_size, 2) 21 | self.rank_output = nn.Linear(config.hidden_size, 1) 22 | self.margin = margin 23 | self.apply(self.init_weights) 24 | 25 | def init_output(self): 26 | """ need to be called after from pretrained """ 27 | self.rank_output.weight.data = self.itm_output.weight.data[1:, :] 28 | self.rank_output.bias.data = self.itm_output.bias.data[1:] 29 | 30 | def forward(self, batch, compute_loss=True): 31 | batch = defaultdict(lambda: None, batch) 32 | input_ids = batch['input_ids'] 33 | position_ids = batch['position_ids'] 34 | img_feat = batch['img_feat'] 35 | img_pos_feat = batch['img_pos_feat'] 36 | attention_mask = batch['attn_masks'] 37 | gather_index = batch['gather_index'] 38 | sequence_output = self.uniter(input_ids, position_ids, 39 | img_feat, img_pos_feat, 40 | attention_mask, gather_index, 41 | output_all_encoded_layers=False) 42 | pooled_output = self.uniter.pooler(sequence_output) 43 | rank_scores = self.rank_output(pooled_output) 44 | 45 | if compute_loss: 46 | # triplet loss 47 | rank_scores_sigmoid = torch.sigmoid(rank_scores) 48 | sample_size = batch['sample_size'] 49 | scores = rank_scores_sigmoid.contiguous().view(-1, sample_size) 50 | pos = scores[:, :1] 51 | neg = scores[:, 1:] 52 | rank_loss = torch.clamp(self.margin + neg - pos, 0) 53 | return rank_loss 54 | else: 55 | return rank_scores 56 | 57 | 58 | class UniterForImageTextRetrievalHardNeg(UniterForImageTextRetrieval): 59 | """ Finetune UNITER for image text retrieval 60 | """ 61 | def __init__(self, config, img_dim, margin=0.2, hard_size=16): 62 | super().__init__(config, img_dim, margin) 63 | self.hard_size = hard_size 64 | 65 | def forward(self, batch, sample_from='t', compute_loss=True): 66 | # expect same input_ids for all pairs 67 | batch_size = batch['attn_masks'].size(0) 68 | input_ids = batch['input_ids'] 69 | img_feat = batch['img_feat'] 70 | img_pos_feat = batch['img_pos_feat'] 71 | if sample_from == 't': 72 | if input_ids.size(0) == 1: 73 | batch['input_ids'] = input_ids.expand(batch_size, -1) 74 | elif sample_from == 'i': 75 | if img_feat.size(0) == 1: 76 | batch['img_feat'] = img_feat.expand(batch_size, -1, -1) 77 | if img_pos_feat.size(0) == 1: 78 | batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1) 79 | else: 80 | raise ValueError() 81 | 82 | if self.training and compute_loss: 83 | with torch.no_grad(): 84 | self.eval() 85 | scores = super().forward(batch, compute_loss=False) 86 | hard_batch = self._get_hard_batch(batch, scores, sample_from) 87 | self.train() 88 | return super().forward(hard_batch, compute_loss=True) 89 | else: 90 | return super().forward(batch, compute_loss) 91 | 92 | def _get_hard_batch(self, batch, scores, sample_from='t'): 93 | batch = defaultdict(lambda: None, batch) 94 | input_ids = batch['input_ids'] 95 | position_ids = batch['position_ids'] 96 | img_feat = batch['img_feat'] 97 | img_pos_feat = batch['img_pos_feat'] 98 | attention_mask = batch['attn_masks'] 99 | gather_index = batch['gather_index'] 100 | hard_batch = {'sample_size': self.hard_size + 1} 101 | 102 | # NOTE first 
example is positive 103 | hard_indices = scores.squeeze(-1)[1:].topk( 104 | self.hard_size, sorted=False)[1] + 1 105 | indices = torch.cat([torch.zeros(1, dtype=torch.long, 106 | device=hard_indices.device), 107 | hard_indices]) 108 | 109 | attention_mask = attention_mask.index_select(0, indices) 110 | gather_index = gather_index.index_select(0, indices) 111 | if position_ids.size(0) != 1: 112 | position_ids = position_ids[:self.hard_size+1] 113 | 114 | if sample_from == 't': 115 | # cut to minimum padding 116 | max_len = attention_mask.sum(dim=1).max().item() 117 | max_i = max_len - input_ids.size(1) 118 | attention_mask = attention_mask[:, :max_len] 119 | gather_index = gather_index[:, :max_len] 120 | img_feat = img_feat.index_select(0, indices)[:, :max_i, :] 121 | img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :] 122 | # expect same input_ids for all pairs 123 | input_ids = input_ids[:self.hard_size+1] 124 | elif sample_from == 'i': 125 | input_ids = input_ids.index_select(0, indices) 126 | # expect same image features for all pairs 127 | img_feat = img_feat[:self.hard_size+1] 128 | img_pos_feat = img_pos_feat[:self.hard_size+1] 129 | else: 130 | raise ValueError() 131 | 132 | hard_batch['input_ids'] = input_ids 133 | hard_batch['position_ids'] = position_ids 134 | hard_batch['img_feat'] = img_feat 135 | hard_batch['img_pos_feat'] = img_pos_feat 136 | hard_batch['attn_masks'] = attention_mask 137 | hard_batch['gather_index'] = gather_index 138 | 139 | return hard_batch 140 | -------------------------------------------------------------------------------- /UNITER/model/re.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Uniter for RE model 6 | """ 7 | from collections import defaultdict 8 | 9 | import torch 10 | from torch import nn 11 | import random 12 | import numpy as np 13 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm 14 | 15 | from .layer import GELU 16 | from .model import UniterPreTrainedModel, UniterModel 17 | 18 | 19 | class UniterForReferringExpressionComprehension(UniterPreTrainedModel): 20 | """ Finetune UNITER for RE 21 | """ 22 | def __init__(self, config, img_dim, loss="cls", 23 | margin=0.2, hard_ratio=0.3, mlp=1): 24 | super().__init__(config) 25 | self.uniter = UniterModel(config, img_dim) 26 | if mlp == 1: 27 | self.re_output = nn.Linear(config.hidden_size, 1) 28 | elif mlp == 2: 29 | self.re_output = nn.Sequential( 30 | nn.Linear(config.hidden_size, config.hidden_size), 31 | GELU(), 32 | LayerNorm(config.hidden_size, eps=1e-12), 33 | nn.Linear(config.hidden_size, 1) 34 | ) 35 | else: 36 | raise ValueError("MLP restricted to be 1 or 2 layers.") 37 | self.loss = loss 38 | assert self.loss in ['cls', 'rank'] 39 | if self.loss == 'rank': 40 | self.margin = margin 41 | self.hard_ratio = hard_ratio 42 | else: 43 | self.crit = nn.CrossEntropyLoss(reduction='none') 44 | 45 | self.apply(self.init_weights) 46 | 47 | def forward(self, batch, compute_loss=True): 48 | batch = defaultdict(lambda: None, batch) 49 | input_ids = batch['input_ids'] 50 | position_ids = batch['position_ids'] 51 | img_feat = batch['img_feat'] 52 | img_pos_feat = batch['img_pos_feat'] 53 | attn_masks = batch['attn_masks'] 54 | gather_index = batch['gather_index'] 55 | obj_masks = batch['obj_masks'] 56 | 57 | sequence_output = self.uniter(input_ids, position_ids, 58 | img_feat, img_pos_feat, 59 | attn_masks, gather_index, 60 | output_all_encoded_layers=False) 61 | # get only the region part 62 | txt_lens, num_bbs = batch["txt_lens"], batch["num_bbs"] 63 | sequence_output = self._get_image_hidden( 64 | sequence_output, txt_lens, num_bbs) 65 | 66 | # re score (n, max_num_bb) 67 | scores = self.re_output(sequence_output).squeeze(2) 68 | scores = scores.masked_fill(obj_masks, -1e4) # mask out non-objects 69 | 70 | if compute_loss: 71 | targets = batch["targets"] 72 | if self.loss == 'cls': 73 | ce_loss = self.crit(scores, targets.squeeze(-1)) # (n, ) as no reduction 74 | return ce_loss 75 | else: 76 | # ranking 77 | _n = len(num_bbs) 78 | # positive (target) 79 | pos_ix = targets 80 | pos_sc = scores.gather(1, pos_ix.view(_n, 1)) # (n, 1) 81 | pos_sc = torch.sigmoid(pos_sc).view(-1) # (n, ) sc[0, 1] 82 | # negative 83 | neg_ix = self.sample_neg_ix(scores, targets, num_bbs) 84 | neg_sc = scores.gather(1, neg_ix.view(_n, 1)) # (n, 1) 85 | neg_sc = torch.sigmoid(neg_sc).view(-1) # (n, ) sc[0, 1] 86 | # ranking 87 | mm_loss = torch.clamp( 88 | self.margin + neg_sc - pos_sc, 0) # (n, ) 89 | return mm_loss 90 | else: 91 | # (n, max_num_bb) 92 | return scores 93 | 94 | def sample_neg_ix(self, scores, targets, num_bbs): 95 | """ 96 | Inputs: 97 | :scores (n, max_num_bb) 98 | :targets (n, ) 99 | :num_bbs list of [num_bb] 100 | return: 101 | :neg_ix (n, ) easy/hard negative (!= target) 102 | """ 103 | neg_ix = [] 104 | cand_ixs = torch.argsort( 105 | scores, dim=-1, descending=True) # (n, num_bb) 106 | for i in range(len(num_bbs)): 107 | num_bb = num_bbs[i] 108 | if np.random.uniform(0, 1, 1) < self.hard_ratio: 109 | # sample hard negative, w/ highest score 110 | for ix in cand_ixs[i].tolist(): 111 | if ix != targets[i]: 112 | assert ix < num_bb, f'ix={ix}, num_bb={num_bb}' 113 | neg_ix.append(ix) 
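# `cand_ixs` is sorted by score in descending order, so the index appended here is the highest-scoring non-target box (the hardest negative); the `break` below keeps only that first pick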
114 | break 115 | else: 116 | # sample easy negative, i.e., random one 117 | ix = random.randint(0, num_bb-1) # [0, num_bb-1] 118 | while ix == targets[i]: 119 | ix = random.randint(0, num_bb-1) 120 | neg_ix.append(ix) 121 | neg_ix = torch.tensor(neg_ix).type(targets.type()) 122 | assert neg_ix.numel() == targets.numel() 123 | return neg_ix 124 | 125 | def _get_image_hidden(self, sequence_output, txt_lens, num_bbs): 126 | """ 127 | Extracting the img_hidden part from sequence_output. 128 | Inputs: 129 | - sequence_output: (n, txt_len+num_bb, hid_size) 130 | - txt_lens : [txt_len] 131 | - num_bbs : [num_bb] 132 | Output: 133 | - img_hidden : (n, max_num_bb, hid_size) 134 | """ 135 | outputs = [] 136 | max_bb = max(num_bbs) 137 | hid_size = sequence_output.size(-1) 138 | for seq_out, len_, nbb in zip(sequence_output.split(1, dim=0), 139 | txt_lens, num_bbs): 140 | img_hid = seq_out[:, len_:len_+nbb, :] 141 | if nbb < max_bb: 142 | img_hid = torch.cat( 143 | [img_hid, self._get_pad( 144 | img_hid, max_bb-nbb, hid_size)], 145 | dim=1) 146 | outputs.append(img_hid) 147 | 148 | img_hidden = torch.cat(outputs, dim=0) 149 | return img_hidden 150 | 151 | def _get_pad(self, t, len_, hidden_size): 152 | pad = torch.zeros(1, len_, hidden_size, dtype=t.dtype, device=t.device) 153 | return pad 154 | -------------------------------------------------------------------------------- /UNITER/data/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Dataset interfaces 6 | """ 7 | from collections import defaultdict 8 | from contextlib import contextmanager 9 | import io 10 | import json 11 | from os.path import exists 12 | 13 | import numpy as np 14 | import torch 15 | from torch.utils.data import Dataset, ConcatDataset 16 | import horovod.torch as hvd 17 | from tqdm import tqdm 18 | import lmdb 19 | 20 | import msgpack 21 | import msgpack_numpy 22 | msgpack_numpy.patch() 23 | 24 | 25 | def _fp16_to_fp32(feat_dict): 26 | out = {k: arr.astype(np.float32) 27 | if arr.dtype == np.float16 else arr 28 | for k, arr in feat_dict.items()} 29 | return out 30 | 31 | 32 | def compute_num_bb(confs, conf_th, min_bb, max_bb): 33 | num_bb = max(min_bb, (confs > conf_th).sum()) 34 | num_bb = min(max_bb, num_bb) 35 | return num_bb 36 | 37 | 38 | def _check_distributed(): 39 | try: 40 | dist = hvd.size() != hvd.local_size() 41 | except ValueError: 42 | # not using horovod 43 | dist = False 44 | return dist 45 | 46 | class DetectFeatPt(object): 47 | def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36, 48 | compress=True): 49 | self.img_dir = img_dir 50 | print("Loading boxes and features from "+self.img_dir) 51 | self.data = torch.load(self.img_dir, map_location='cpu') 52 | self.conf_th = conf_th 53 | self.max_bb = max_bb 54 | def __getitem__(self, image_id): 55 | data = self.data[image_id] 56 | boxes = data["boxes"] 57 | feats = data["features"] 58 | boxes = boxes[:self.max_bb,:] 59 | feats = feats[:self.max_bb,:] 60 | conf = data["conf"][:self.max_bb] 61 | boxes = boxes[conf >= self.conf_th,:] 62 | feats = feats[conf >= self.conf_th,:] 63 | return feats, boxes, data["width"], data["height"] 64 | 65 | class DetectFeatLmdb(object): 66 | def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36, 67 | compress=True): 68 | self.img_dir = img_dir 69 | if conf_th == -1: 70 | db_name = f'feat_numbb{num_bb}' 71 | self.name2nbb = defaultdict(lambda: num_bb) 72 | else: 73 
| db_name = f'feat_th{conf_th}_max{max_bb}_min{min_bb}' 74 | nbb = f'nbb_th{conf_th}_max{max_bb}_min{min_bb}.json' 75 | if not exists(f'{img_dir}/{nbb}'): 76 | # nbb is not pre-computed 77 | self.name2nbb = None 78 | else: 79 | self.name2nbb = json.load(open(f'{img_dir}/{nbb}')) 80 | self.compress = compress 81 | if compress: 82 | db_name += '_compressed' 83 | 84 | if self.name2nbb is None: 85 | if compress: 86 | db_name = 'all_compressed' 87 | else: 88 | db_name = 'all' 89 | # only read ahead on single node training 90 | self.env = lmdb.open(f'{img_dir}/{db_name}', 91 | readonly=True, create=False, 92 | readahead=not _check_distributed()) 93 | self.txn = self.env.begin(buffers=True) 94 | if self.name2nbb is None: 95 | self.name2nbb = self._compute_nbb() 96 | 97 | def _compute_nbb(self): 98 | name2nbb = {} 99 | fnames = json.loads(self.txn.get(key=b'__keys__').decode('utf-8')) 100 | for fname in tqdm(fnames, desc='reading images'): 101 | dump = self.txn.get(fname.encode('utf-8')) 102 | if self.compress: 103 | with io.BytesIO(dump) as reader: 104 | img_dump = np.load(reader, allow_pickle=True) 105 | confs = img_dump['conf'] 106 | else: 107 | img_dump = msgpack.loads(dump, raw=False) 108 | confs = img_dump['conf'] 109 | name2nbb[fname] = compute_num_bb(confs, self.conf_th, 110 | self.min_bb, self.max_bb) 111 | 112 | return name2nbb 113 | 114 | def __del__(self): 115 | self.env.close() 116 | 117 | def get_dump(self, file_name): 118 | # hack for MRC 119 | dump = self.txn.get(file_name.encode('utf-8')) 120 | nbb = self.name2nbb[file_name] 121 | if self.compress: 122 | with io.BytesIO(dump) as reader: 123 | img_dump = np.load(reader, allow_pickle=True) 124 | img_dump = _fp16_to_fp32(img_dump) 125 | else: 126 | img_dump = msgpack.loads(dump, raw=False) 127 | img_dump = _fp16_to_fp32(img_dump) 128 | img_dump = {k: arr[:nbb, ...] for k, arr in img_dump.items()} 129 | return img_dump 130 | 131 | def __getitem__(self, file_name): 132 | dump = self.txn.get(file_name.encode('utf-8')) 133 | nbb = self.name2nbb[file_name] 134 | if self.compress: 135 | with io.BytesIO(dump) as reader: 136 | img_dump = np.load(reader, allow_pickle=True) 137 | img_dump = {'features': img_dump['features'], 138 | 'norm_bb': img_dump['norm_bb']} 139 | else: 140 | img_dump = msgpack.loads(dump, raw=False) 141 | img_feat = torch.tensor(img_dump['features'][:nbb, :]).float() 142 | img_bb = torch.tensor(img_dump['norm_bb'][:nbb, :]).float() 143 | return img_feat, img_bb 144 | 145 | def pad_tensors(tensors, lens=None, pad=0): 146 | """B x [T, ...]""" 147 | if lens is None: 148 | lens = [t.size(0) for t in tensors] 149 | max_len = max(lens) 150 | bs = len(tensors) 151 | hid = tensors[0].size(-1) 152 | dtype = tensors[0].dtype 153 | output = torch.zeros(bs, max_len, hid, dtype=dtype) 154 | if pad: 155 | output.data.fill_(pad) 156 | for i, (t, l) in enumerate(zip(tensors, lens)): 157 | output.data[i, :l, ...] 
= t.data 158 | return output 159 | 160 | 161 | def get_gather_index(txt_lens, num_bbs, batch_size, max_len, out_size): 162 | assert len(txt_lens) == len(num_bbs) == batch_size 163 | gather_index = torch.arange(0, out_size, dtype=torch.long, 164 | ).unsqueeze(0).repeat(batch_size, 1) 165 | 166 | for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)): 167 | gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb, 168 | dtype=torch.long).data 169 | return gather_index 170 | -------------------------------------------------------------------------------- /UNITER/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | distributed API using Horovod 6 | Modified from OpenNMT's native pytorch distributed utils 7 | (https://github.com/OpenNMT/OpenNMT-py) 8 | """ 9 | import math 10 | import pickle 11 | 12 | import torch 13 | from horovod import torch as hvd 14 | 15 | 16 | def all_reduce_and_rescale_tensors(tensors, rescale_denom): 17 | """All-reduce and rescale tensors at once (as a flattened tensor) 18 | 19 | Args: 20 | tensors: list of Tensors to all-reduce 21 | rescale_denom: denominator for rescaling summed Tensors 22 | """ 23 | # buffer size in bytes, determine equiv. # of elements based on data type 24 | sz = sum(t.numel() for t in tensors) 25 | buffer_t = tensors[0].new(sz).zero_() 26 | 27 | # copy tensors into buffer_t 28 | offset = 0 29 | for t in tensors: 30 | numel = t.numel() 31 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 32 | offset += numel 33 | 34 | # all-reduce and rescale 35 | hvd.allreduce_(buffer_t[:offset]) 36 | buffer_t.div_(rescale_denom) 37 | 38 | # copy all-reduced buffer back into tensors 39 | offset = 0 40 | for t in tensors: 41 | numel = t.numel() 42 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 43 | offset += numel 44 | 45 | 46 | def all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom, 47 | buffer_size=10485760): 48 | """All-reduce and rescale tensors in chunks of the specified size. 49 | 50 | Args: 51 | tensors: list of Tensors to all-reduce 52 | rescale_denom: denominator for rescaling summed Tensors 53 | buffer_size: all-reduce chunk size in bytes 54 | """ 55 | # buffer size in bytes, determine equiv. 
# of elements based on data type 56 | buffer_t = tensors[0].new( 57 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 58 | buffer = [] 59 | 60 | def all_reduce_buffer(): 61 | # copy tensors into buffer_t 62 | offset = 0 63 | for t in buffer: 64 | numel = t.numel() 65 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 66 | offset += numel 67 | 68 | # all-reduce and rescale 69 | hvd.allreduce_(buffer_t[:offset]) 70 | buffer_t.div_(rescale_denom) 71 | 72 | # copy all-reduced buffer back into tensors 73 | offset = 0 74 | for t in buffer: 75 | numel = t.numel() 76 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 77 | offset += numel 78 | 79 | filled = 0 80 | for t in tensors: 81 | sz = t.numel() * t.element_size() 82 | if sz > buffer_size: 83 | # tensor is bigger than buffer, all-reduce and rescale directly 84 | hvd.allreduce_(t) 85 | t.div_(rescale_denom) 86 | elif filled + sz > buffer_size: 87 | # buffer is full, all-reduce and replace buffer with grad 88 | all_reduce_buffer() 89 | buffer = [t] 90 | filled = sz 91 | else: 92 | # add tensor to buffer 93 | buffer.append(t) 94 | filled += sz 95 | 96 | if len(buffer) > 0: 97 | all_reduce_buffer() 98 | 99 | 100 | def broadcast_tensors(tensors, root_rank, buffer_size=10485760): 101 | """broadcast tensors in chunks of the specified size. 102 | 103 | Args: 104 | tensors: list of Tensors to broadcast 105 | root_rank: rank to broadcast 106 | buffer_size: broadcast chunk size in bytes 107 | """ 108 | # buffer size in bytes, determine equiv. # of elements based on data type 109 | buffer_t = tensors[0].new( 110 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 111 | buffer = [] 112 | 113 | def broadcast_buffer(): 114 | # copy tensors into buffer_t 115 | offset = 0 116 | for t in buffer: 117 | numel = t.numel() 118 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 119 | offset += numel 120 | 121 | # broadcast 122 | hvd.broadcast_(buffer_t[:offset], root_rank) 123 | 124 | # copy all-reduced buffer back into tensors 125 | offset = 0 126 | for t in buffer: 127 | numel = t.numel() 128 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 129 | offset += numel 130 | 131 | filled = 0 132 | for t in tensors: 133 | sz = t.numel() * t.element_size() 134 | if sz > buffer_size: 135 | # tensor is bigger than buffer, broadcast directly 136 | hvd.broadcast_(t, root_rank) 137 | elif filled + sz > buffer_size: 138 | # buffer is full, broadcast and replace buffer with tensor 139 | broadcast_buffer() 140 | buffer = [t] 141 | filled = sz 142 | else: 143 | # add tensor to buffer 144 | buffer.append(t) 145 | filled += sz 146 | 147 | if len(buffer) > 0: 148 | broadcast_buffer() 149 | 150 | 151 | def _encode(enc, max_size, use_max_size=False): 152 | enc_size = len(enc) 153 | enc_byte = max(math.floor(math.log(max_size, 256)+1), 1) 154 | if use_max_size: 155 | # this is used for broadcasting 156 | buffer_ = torch.cuda.ByteTensor(max_size+enc_byte) 157 | else: 158 | buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte) 159 | remainder = enc_size 160 | for i in range(enc_byte): 161 | base = 256 ** (enc_byte-i-1) 162 | buffer_[i] = remainder // base 163 | remainder %= base 164 | buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc)) 165 | return buffer_, enc_byte 166 | 167 | 168 | def _decode(buffer_, enc_byte): 169 | size = sum(256 ** (enc_byte-i-1) * buffer_[i].item() 170 | for i in range(enc_byte)) 171 | bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist()) 172 | shift = size + enc_byte 173 | return bytes_list, shift 174 | 175 | 176 | 
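# _encode/_decode implement a simple length-prefix framing: _encode writes the payload size big-endian in base 256 across the first `enc_byte` bytes of a CUDA ByteTensor and appends the pickled payload; _decode reads the size back and returns the payload bytes plus the offset to the next message. all_gather_list and any_broadcast below use this framing to exchange variable-length pickled objects via Horovod allgather/broadcast of byte tensors (any_broadcast pads to the gathered max size so every rank receives a same-shaped broadcast buffer).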
_BUFFER_SIZE = 4096 177 | 178 | 179 | def all_gather_list(data): 180 | """Gathers arbitrary data from all nodes into a list.""" 181 | enc = pickle.dumps(data) 182 | 183 | enc_size = len(enc) 184 | max_size = hvd.allgather(torch.tensor([enc_size]).cuda()).max().item() 185 | in_buffer, enc_byte = _encode(enc, max_size) 186 | 187 | out_buffer = hvd.allgather(in_buffer[:enc_byte+enc_size]) 188 | 189 | results = [] 190 | for _ in range(hvd.size()): 191 | bytes_list, shift = _decode(out_buffer, enc_byte) 192 | out_buffer = out_buffer[shift:] 193 | result = pickle.loads(bytes_list) 194 | results.append(result) 195 | return results 196 | 197 | 198 | def any_broadcast(data, root_rank): 199 | """broadcast arbitrary data from root_rank to all nodes.""" 200 | enc = pickle.dumps(data) 201 | 202 | max_size = hvd.allgather(torch.tensor([len(enc)]).cuda()).max().item() 203 | buffer_, enc_byte = _encode(enc, max_size, use_max_size=True) 204 | 205 | hvd.broadcast_(buffer_, root_rank) 206 | 207 | bytes_list, _ = _decode(buffer_, enc_byte) 208 | result = pickle.loads(bytes_list) 209 | return result 210 | -------------------------------------------------------------------------------- /UNITER/model/pretrain_vcr.py: -------------------------------------------------------------------------------- 1 | from .pretrain import UniterForPretraining 2 | from torch import nn 3 | from .layer import BertOnlyMLMHead 4 | from collections import defaultdict 5 | from torch.nn import functional as F 6 | import torch 7 | 8 | 9 | class UniterForPretrainingForVCR(UniterForPretraining): 10 | """ 2nd Stage Pretrain UNITER for VCR 11 | """ 12 | def init_type_embedding(self): 13 | new_emb = nn.Embedding(4, self.uniter.config.hidden_size) 14 | new_emb.apply(self.init_weights) 15 | for i in [0, 1]: 16 | emb = self.uniter.embeddings.token_type_embeddings.weight.data[i, :] 17 | new_emb.weight.data[i, :].copy_(emb) 18 | emb = self.uniter.embeddings.token_type_embeddings.weight.data[0, :] 19 | new_emb.weight.data[2, :].copy_(emb) 20 | new_emb.weight.data[3, :].copy_(emb) 21 | self.uniter.embeddings.token_type_embeddings = new_emb 22 | 23 | def init_word_embedding(self, num_special_tokens): 24 | orig_word_num = self.uniter.embeddings.word_embeddings.weight.size(0) 25 | new_emb = nn.Embedding( 26 | orig_word_num + num_special_tokens, self.uniter.config.hidden_size) 27 | new_emb.apply(self.init_weights) 28 | emb = self.uniter.embeddings.word_embeddings.weight.data 29 | new_emb.weight.data[:orig_word_num, :].copy_(emb) 30 | self.uniter.embeddings.word_embeddings = new_emb 31 | self.cls = BertOnlyMLMHead( 32 | self.uniter.config, self.uniter.embeddings.word_embeddings.weight) 33 | 34 | def forward(self, batch, task, compute_loss=True): 35 | batch = defaultdict(lambda: None, batch) 36 | input_ids = batch['input_ids'] 37 | position_ids = batch['position_ids'] 38 | img_feat = batch['img_feat'] 39 | img_pos_feat = batch['img_pos_feat'] 40 | attention_mask = batch['attn_masks'] 41 | gather_index = batch['gather_index'] 42 | txt_type_ids = batch['txt_type_ids'] 43 | if task == 'mlm': 44 | txt_labels = batch['txt_labels'] 45 | return self.forward_mlm(input_ids, position_ids, 46 | txt_type_ids, img_feat, img_pos_feat, 47 | attention_mask, gather_index, 48 | txt_labels, compute_loss) 49 | elif task == 'mrfr': 50 | img_mask_tgt = batch['img_mask_tgt'] 51 | img_masks = batch['img_masks'] 52 | mrfr_feat_target = batch['feat_targets'] 53 | return self.forward_mrfr(input_ids, position_ids, 54 | txt_type_ids, img_feat, img_pos_feat, 55 | attention_mask, 
gather_index, 56 | img_masks, img_mask_tgt, 57 | mrfr_feat_target, compute_loss) 58 | elif task.startswith('mrc'): 59 | img_mask_tgt = batch['img_mask_tgt'] 60 | img_masks = batch['img_masks'] 61 | mrc_label_target = batch['label_targets'] 62 | return self.forward_mrc(input_ids, position_ids, 63 | txt_type_ids, img_feat, img_pos_feat, 64 | attention_mask, gather_index, 65 | img_masks, img_mask_tgt, 66 | mrc_label_target, task, compute_loss) 67 | else: 68 | raise ValueError('invalid task') 69 | 70 | # MLM 71 | def forward_mlm(self, input_ids, position_ids, txt_type_ids, img_feat, 72 | img_pos_feat, attention_mask, gather_index, 73 | txt_labels, compute_loss=True): 74 | sequence_output = self.uniter(input_ids, position_ids, 75 | img_feat, img_pos_feat, 76 | attention_mask, gather_index, 77 | output_all_encoded_layers=False, 78 | txt_type_ids=txt_type_ids) 79 | # get only the text part 80 | sequence_output = sequence_output[:, :input_ids.size(1), :] 81 | # only compute masked tokens for better efficiency 82 | masked_output = self._compute_masked_hidden(sequence_output, 83 | txt_labels != -1) 84 | prediction_scores = self.cls(masked_output) 85 | 86 | if compute_loss: 87 | masked_lm_loss = F.cross_entropy(prediction_scores, 88 | txt_labels[txt_labels != -1], 89 | reduction='none') 90 | return masked_lm_loss 91 | else: 92 | return prediction_scores 93 | 94 | # MRFR 95 | def forward_mrfr(self, input_ids, position_ids, txt_type_ids, 96 | img_feat, img_pos_feat, 97 | attention_mask, gather_index, img_masks, img_mask_tgt, 98 | feat_targets, compute_loss=True): 99 | sequence_output = self.uniter(input_ids, position_ids, 100 | img_feat, img_pos_feat, 101 | attention_mask, gather_index, 102 | output_all_encoded_layers=False, 103 | img_masks=img_masks, 104 | txt_type_ids=txt_type_ids) 105 | 106 | # only compute masked tokens for better efficiency 107 | masked_output = self._compute_masked_hidden(sequence_output, 108 | img_mask_tgt) 109 | prediction_feat = self.feat_regress(masked_output) 110 | 111 | if compute_loss: 112 | mrfr_loss = F.mse_loss(prediction_feat, feat_targets, 113 | reduction='none') 114 | return mrfr_loss 115 | else: 116 | return prediction_feat 117 | 118 | # MRC 119 | def forward_mrc(self, input_ids, position_ids, txt_type_ids, 120 | img_feat, img_pos_feat, 121 | attention_mask, gather_index, img_masks, img_mask_tgt, 122 | label_targets, task, compute_loss=True): 123 | sequence_output = self.uniter(input_ids, position_ids, 124 | img_feat, img_pos_feat, 125 | attention_mask, gather_index, 126 | output_all_encoded_layers=False, 127 | img_masks=img_masks, 128 | txt_type_ids=txt_type_ids) 129 | 130 | # only compute masked regions for better efficiency 131 | masked_output = self._compute_masked_hidden(sequence_output, 132 | img_mask_tgt) 133 | prediction_soft_label = self.region_classifier(masked_output) 134 | 135 | if compute_loss: 136 | if "kl" in task: 137 | prediction_soft_label = F.log_softmax( 138 | prediction_soft_label, dim=-1) 139 | mrc_loss = F.kl_div( 140 | prediction_soft_label, label_targets, reduction='none') 141 | else: 142 | # background class should not be the target 143 | label_targets = torch.max(label_targets[:, 1:], dim=-1)[1] + 1 144 | mrc_loss = F.cross_entropy( 145 | prediction_soft_label, label_targets, 146 | ignore_index=0, reduction='none') 147 | return mrc_loss 148 | else: 149 | return prediction_soft_label 150 | -------------------------------------------------------------------------------- /UNITER/prepro.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | preprocess NLVR annotations into LMDB 6 | """ 7 | import argparse 8 | import json 9 | import pickle 10 | import os 11 | from os.path import exists 12 | 13 | from cytoolz import curry 14 | from tqdm import tqdm 15 | from pytorch_pretrained_bert import BertTokenizer 16 | 17 | from data.data import open_lmdb 18 | 19 | 20 | @curry 21 | def bert_tokenize(tokenizer, text): 22 | ids = [] 23 | for word in text.strip().split(): 24 | ws = tokenizer.tokenize(word) 25 | if not ws: 26 | # some special char 27 | continue 28 | ids.extend(tokenizer.convert_tokens_to_ids(ws)) 29 | return ids 30 | 31 | 32 | def process_nlvr2(jsonl, db, tokenizer, missing=None): 33 | id2len = {} 34 | txt2img = {} # not sure if useful 35 | for line in tqdm(jsonl, desc='processing NLVR2'): 36 | example = json.loads(line) 37 | id_ = example['identifier'] 38 | img_id = '-'.join(id_.split('-')[:-1]) 39 | img_fname = (f'nlvr2_{img_id}-img0.npz', f'nlvr2_{img_id}-img1.npz') 40 | if missing and (img_fname[0] in missing or img_fname[1] in missing): 41 | continue 42 | input_ids = tokenizer(example['sentence']) 43 | if 'label' in example: 44 | target = 1 if example['label'] == 'True' else 0 45 | else: 46 | target = None 47 | txt2img[id_] = img_fname 48 | id2len[id_] = len(input_ids) 49 | example['input_ids'] = input_ids 50 | example['img_fname'] = img_fname 51 | example['target'] = target 52 | db[id_] = example 53 | return id2len, txt2img 54 | 55 | 56 | def process_referring_expressions(refs, instances, iid_to_ann_ids, 57 | db, tokenizer, split): 58 | """ 59 | Inputs: 60 | - refs: [ref_id, ann_id, image_id, split, sent_ids, sentences] 61 | - instances: {images, annotations, categories} 62 | - iid_to_ann_ids: image_id -> ann_ids ordered by extracted butd features 63 | Return: 64 | - id2len : sent_id -> tokenized question length 65 | - images : [{id, file_name, ann_ids, height, width} ] 66 | - annotations: [{id, area, bbox, image_id, category_id, iscrowd}] 67 | - categories : [{id, name, supercategory}] 68 | """ 69 | # images within split 70 | image_set = set([ref['image_id'] for ref in refs if ref['split'] == split]) 71 | images = [] 72 | for img in instances['images']: 73 | if img['id'] in image_set: 74 | images.append({ 75 | 'id': img['id'], 'file_name': img['file_name'], 76 | 'ann_ids': iid_to_ann_ids[str(img['id'])], 77 | 'height': img['height'], 'width': img['width']}) 78 | # Images = {img['id']: img for img in images} 79 | # anns within split 80 | annotations = [] 81 | for ann in instances['annotations']: 82 | if ann['image_id'] in image_set: 83 | annotations.append({ 84 | 'id': ann['id'], 'area': ann['area'], 'bbox': ann['bbox'], 85 | 'image_id': ann['image_id'], 86 | 'category_id': ann['category_id'], 87 | 'iscrowd': ann['iscrowd'] 88 | }) 89 | Anns = {ann['id']: ann for ann in annotations} 90 | # category info 91 | categories = instances['categories'] 92 | # refs within split 93 | refs = [ref for ref in refs if ref['split'] == split] 94 | print(f"Processing {len(refs)} annotations...") 95 | id2len = {} 96 | for ref in tqdm(refs, desc='processing referring expressions'): 97 | ref_id = ref['ref_id'] 98 | ann_id = ref['ann_id'] 99 | image_id = ref['image_id'] 100 | img_fname = f"visual_grounding_coco_gt_{int(image_id):012}.npz" 101 | for sent in ref['sentences']: 102 | sent_id = sent['sent_id'] 103 | input_ids = tokenizer(sent['sent']) 104 | 
id2len[str(sent_id)] = len(input_ids) 105 | db[str(sent_id)] = { 106 | 'sent_id': sent_id, 'sent': sent['sent'], 107 | 'ref_id': ref_id, 'ann_id': ann_id, 108 | 'image_id': image_id, 'bbox': Anns[ann_id]['bbox'], 109 | 'input_ids': input_ids, 110 | 'img_fname': img_fname 111 | } 112 | return id2len, images, annotations, categories, refs 113 | 114 | 115 | def main(opts): 116 | if not exists(opts.output): 117 | os.makedirs(opts.output) 118 | else: 119 | raise ValueError('Found existing DB. Please explicitly remove ' 120 | 'for re-processing') 121 | meta = vars(opts) 122 | meta['tokenizer'] = opts.toker 123 | toker = BertTokenizer.from_pretrained( 124 | opts.toker, do_lower_case='uncased' in opts.toker) 125 | tokenizer = bert_tokenize(toker) 126 | meta['UNK'] = toker.convert_tokens_to_ids(['[UNK]'])[0] 127 | meta['CLS'] = toker.convert_tokens_to_ids(['[CLS]'])[0] 128 | meta['SEP'] = toker.convert_tokens_to_ids(['[SEP]'])[0] 129 | meta['MASK'] = toker.convert_tokens_to_ids(['[MASK]'])[0] 130 | meta['v_range'] = (toker.convert_tokens_to_ids('!')[0], 131 | len(toker.vocab)) 132 | with open(f'{opts.output}/meta.json', 'w') as f: 133 | json.dump(vars(opts), f, indent=4) 134 | 135 | open_db = curry(open_lmdb, opts.output, readonly=False) 136 | output_field_name = ['id2len', 'txt2img'] 137 | with open_db() as db: 138 | if opts.task == 'nlvr': 139 | with open(opts.annotations[0]) as ann: 140 | if opts.missing_imgs is not None: 141 | missing_imgs = set(json.load(open(opts.missing_imgs))) 142 | else: 143 | missing_imgs = None 144 | jsons = process_nlvr2( 145 | ann, db, tokenizer, missing_imgs) 146 | elif opts.task == 're': 147 | data = pickle.load(open(opts.annotations[0], 'rb')) 148 | instances = json.load(open(opts.annotations[1], 'r')) 149 | iid_to_ann_ids = json.load( 150 | open(opts.annotations[2], 'r'))['iid_to_ann_ids'] 151 | # dirs/refcoco_testA_bert-base-cased.db -> testA 152 | img_split = opts.output.split('/')[-1].split('.')[0].split('_')[1] 153 | jsons = process_referring_expressions( 154 | data, instances, iid_to_ann_ids, 155 | db, tokenizer, img_split) 156 | output_field_name = [ 157 | 'id2len', 'images', 'annotations', 158 | 'categories', 'refs'] 159 | 160 | for dump, name in zip(jsons, output_field_name): 161 | with open(f'{opts.output}/{name}.json', 'w') as f: 162 | json.dump(dump, f) 163 | 164 | 165 | if __name__ == '__main__': 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument('--annotations', required=True, nargs='+', 168 | help='annotation JSON') 169 | parser.add_argument('--missing_imgs', 170 | help='some training image features are corrupted') 171 | parser.add_argument('--output', required=True, 172 | help='output dir of DB') 173 | parser.add_argument('--task', required=True, default='nlvr', 174 | choices=['nlvr', 're']) 175 | parser.add_argument('--toker', default='bert-base-cased', 176 | help='which BERT tokenizer to used') 177 | args = parser.parse_args() 178 | if args.task == 'nlvr': 179 | assert len(args.annotations) == 1 180 | elif args.task == 're': 181 | assert len(args.annotations) == 3 182 | main(args) 183 | -------------------------------------------------------------------------------- /clip_mm_explain/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import 
tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | __all__ = ["available_models", "load", "tokenize"] 16 | _tokenizer = _Tokenizer() 17 | 18 | _MODELS = { 19 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 20 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 21 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 22 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 23 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 24 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt" 25 | } 26 | 27 | 28 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 29 | os.makedirs(root, exist_ok=True) 30 | filename = os.path.basename(url) 31 | 32 | expected_sha256 = url.split("/")[-2] 33 | download_target = os.path.join(root, filename) 34 | 35 | if os.path.exists(download_target) and not os.path.isfile(download_target): 36 | raise RuntimeError(f"{download_target} exists and is not a regular file") 37 | 38 | if os.path.isfile(download_target): 39 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 40 | return download_target 41 | else: 42 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 43 | 44 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 45 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 46 | while True: 47 | buffer = source.read(8192) 48 | if not buffer: 49 | break 50 | 51 | output.write(buffer) 52 | loop.update(len(buffer)) 53 | 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 55 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 56 | 57 | return download_target 58 | 59 | 60 | def _transform(n_px): 61 | return Compose([ 62 | Resize(n_px, interpolation=Image.BICUBIC), 63 | CenterCrop(n_px), 64 | lambda image: image.convert("RGB"), 65 | ToTensor(), 66 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 67 | ]) 68 | 69 | 70 | def available_models() -> List[str]: 71 | """Returns the names of available CLIP models""" 72 | return list(_MODELS.keys()) 73 | 74 | 75 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): 76 | """Load a CLIP model 77 | 78 | Parameters 79 | ---------- 80 | name : str 81 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 82 | 83 | device : Union[str, torch.device] 84 | The device to put the loaded model 85 | 86 | jit : bool 87 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
88 | 89 | Returns 90 | ------- 91 | model : torch.nn.Module 92 | The CLIP model 93 | 94 | preprocess : Callable[[PIL.Image], torch.Tensor] 95 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 96 | """ 97 | if name in _MODELS: 98 | model_path = _download(_MODELS[name]) 99 | elif os.path.isfile(name): 100 | model_path = name 101 | else: 102 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 103 | 104 | try: 105 | # loading JIT archive 106 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 107 | state_dict = None 108 | except RuntimeError: 109 | # loading saved state dict 110 | if jit: 111 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 112 | jit = False 113 | state_dict = torch.load(model_path, map_location="cpu") 114 | 115 | if not jit: 116 | model = build_model(state_dict or model.state_dict()).to(device) 117 | if str(device) == "cpu": 118 | model.float() 119 | return model, _transform(model.visual.input_resolution) 120 | 121 | # patch the device names 122 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 123 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 124 | 125 | def patch_device(module): 126 | graphs = [module.graph] if hasattr(module, "graph") else [] 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("prim::Constant"): 132 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 133 | node.copyAttributes(device_node) 134 | 135 | model.apply(patch_device) 136 | patch_device(model.encode_image) 137 | patch_device(model.encode_text) 138 | 139 | # patch dtype to float32 on CPU 140 | if str(device) == "cpu": 141 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 142 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 143 | float_node = float_input.node() 144 | 145 | def patch_float(module): 146 | graphs = [module.graph] if hasattr(module, "graph") else [] 147 | if hasattr(module, "forward1"): 148 | graphs.append(module.forward1.graph) 149 | 150 | for graph in graphs: 151 | for node in graph.findAllNodes("aten::to"): 152 | inputs = list(node.inputs()) 153 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 154 | if inputs[i].node()["value"] == 5: 155 | inputs[i].node().copyAttributes(float_node) 156 | 157 | model.apply(patch_float) 158 | patch_float(model.encode_image) 159 | patch_float(model.encode_text) 160 | 161 | model.float() 162 | 163 | return model, _transform(model.input_resolution.item()) 164 | 165 | 166 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 167 | """ 168 | Returns the tokenized representation of given input string(s) 169 | 170 | Parameters 171 | ---------- 172 | texts : Union[str, List[str]] 173 | An input string or a list of input strings to tokenize 174 | 175 | context_length : int 176 | The context length to use; all CLIP models use 77 as the context length 177 | 178 | Returns 179 | ------- 180 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 181 | """ 182 | if isinstance(texts, str): 183 | texts = [texts] 184 | 185 | sot_token = _tokenizer.encoder["<|startoftext|>"] 186 | 
eot_token = _tokenizer.encoder["<|endoftext|>"] 187 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 188 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 189 | 190 | for i, tokens in enumerate(all_tokens): 191 | if len(tokens) > context_length: 192 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 193 | result[i, :len(tokens)] = torch.tensor(tokens) 194 | 195 | return result 196 | -------------------------------------------------------------------------------- /interpreter.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, List, Callable 2 | import sys 3 | import re 4 | import numpy as np 5 | import torch 6 | from numpy.linalg import norm 7 | from itertools import product, groupby 8 | from PIL import Image 9 | 10 | 11 | # Do two line segments intersect? Copied from 12 | # https://stackoverflow.com/questions/3838329/how-can-i-check-if-two-segments-intersect 13 | 14 | 15 | def ccw(A, B, C): 16 | return (C.y - A.y) * (B.x - A.x) > (B.y - A.y) * (C.x - A.x) 17 | 18 | 19 | def intersect(A, B, C, D): 20 | """Do line segments AB and CD intersect?""" 21 | return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D) 22 | 23 | 24 | class Box(NamedTuple): 25 | x: int 26 | y: int 27 | w: int = 0 28 | h: int = 0 29 | 30 | @property 31 | def left(self): 32 | return self.x 33 | 34 | @property 35 | def right(self): 36 | return self.x + self.w 37 | 38 | @property 39 | def top(self): 40 | return self.y 41 | 42 | @property 43 | def bottom(self): 44 | return self.y + self.h 45 | 46 | @property 47 | def center(self): 48 | return Box(self.x + self.w // 2, self.y + self.h // 2) 49 | 50 | def corners(self): 51 | yield Box(self.x, self.y) 52 | yield Box(self.x + self.w, self.y) 53 | yield Box(self.x + self.w, self.y + self.h) 54 | yield Box(self.x, self.y + self.h) 55 | 56 | @property 57 | def area(self): 58 | return self.w * self.h 59 | 60 | def intersect(self, other: "Box") -> "Box": 61 | x1 = max(self.x, other.x) 62 | x2 = max(x1, min(self.x+self.w, other.x+other.w)) 63 | y1 = max(self.y, other.y) 64 | y2 = max(y1, min(self.y+self.h, other.y+other.h)) 65 | return Box(x=x1, y=y1, w=x2-x1, h=y2-y1) 66 | 67 | def min_bounding(self, other: "Box") -> "Box": 68 | corners = list(self.corners()) 69 | corners.extend(other.corners()) 70 | min_x = min_y = float("inf") 71 | max_x = max_y = -float("inf") 72 | 73 | for item in corners: 74 | min_x = min(min_x, item.x) 75 | min_y = min(min_y, item.y) 76 | max_x = max(max_x, item.x) 77 | max_y = max(max_y, item.y) 78 | 79 | return Box(min_x, min_y, max_x - min_x, max_y - min_y) 80 | 81 | def expand(self, growth: float = .1) -> "Box": 82 | factor = 1 + growth 83 | w = factor * self.w 84 | h = factor * self.h 85 | return Box(self.x - (w - self.w) / 2, self.y - (h - self.h) / 2, w, h)  # grow the box by `growth` while keeping its center fixed 86 | 87 | 88 | def iou(box1, box2): 89 | x1 = max(box1.x, box2.x) 90 | x2 = max(x1, min(box1.x+box1.w, box2.x+box2.w)) 91 | y1 = max(box1.y, box2.y) 92 | y2 = max(y1, min(box1.y+box1.h, box2.y+box2.h)) 93 | intersection = Box(x=x1, y=y1, w=x2-x1, h=y2-y1) 94 | intersection_area = intersection.area 95 | union_area = box1.area+box2.area-intersection_area 96 | return intersection_area / union_area 97 | 98 | 99 | def all_equal(iterable): 100 | """Are all elements the same?""" 101 | g = groupby(iterable) 102 | return next(g, True) and not next(g, False) 103 | 104 | 105 | class spatial: 106 | """A decorator that converts a 
predicate over boxes to a function that returns a tensor over all boxes.""" 107 | 108 | def __init__(self, arity: int = 2, enforce_antisymmetry: bool = False): 109 | self.arity = arity 110 | self.enforce_antisymmetry = enforce_antisymmetry # Zero out any entries where two boxes are the same. 111 | 112 | def __call__(self, predicate: Callable[[Box], float]) -> Callable[["Environment"], np.ndarray]: 113 | def _rel(env): 114 | n_boxes = len(env.boxes) 115 | tensor = np.empty([n_boxes for _ in range(self.arity)]) 116 | enum_boxes = list(enumerate(env.boxes)) 117 | for pairs in product(*[enum_boxes for _ in range(self.arity)]): 118 | indices, boxes = zip(*pairs) 119 | if self.enforce_antisymmetry and len(set(indices)) < len(indices): 120 | tensor[indices] = 0. 121 | else: 122 | tensor[indices] = predicate(*boxes) 123 | return tensor 124 | return _rel 125 | 126 | 127 | class Environment: 128 | def __init__(self, image: Image, boxes: List[Box], executor: "Executor" = None, freeform_boxes: bool = False, image_name: str = None, mask=None): 129 | self.image = image 130 | self.boxes = boxes 131 | self.executor = executor # An object or callback that can query CLIP with captions/images. 132 | self.freeform_boxes = freeform_boxes 133 | self.image_name = image_name 134 | self.masks = mask 135 | def uniform(self) -> np.ndarray: 136 | n_boxes = len(self.boxes) 137 | return 1 / n_boxes * np.ones(n_boxes) 138 | 139 | def filter(self, 140 | caption: str, 141 | temperature: float = 1., 142 | area_threshold: float = 0.0, 143 | softmax: bool = False, 144 | expand: float = None, 145 | caption_bank = [], 146 | mask_dino = None 147 | ) -> np.ndarray: 148 | """Return a new distribution reflecting the likelihood that `caption` describes the content of each box.""" 149 | area_filtered_dist = torch.from_numpy(self.filter_area(area_threshold)).to(self.executor.device) 150 | 151 | candidate_indices = [i for i in range(len(self.boxes)) if float(area_filtered_dist[i]) > 0.0] 152 | boxes = [self.boxes[i] for i in candidate_indices] 153 | if self.masks is not None: 154 | masks = [self.masks[i] for i in candidate_indices] 155 | if len(boxes) == 0: 156 | boxes = self.boxes 157 | candidate_indices = list(range(len(boxes))) 158 | if expand is not None: 159 | boxes = [box.expand(expand) for box in boxes] 160 | if self.masks is not None: 161 | masks = [self.masks[i] for i in candidate_indices] 162 | result_partial = self.executor(caption, self.image, boxes, image_name=self.image_name,masks=masks) 163 | else: 164 | result_partial, attn = self.executor(caption, self.image, boxes, image_name=self.image_name, caption_bank=caption_bank, mask_dino=mask_dino) 165 | if self.freeform_boxes: 166 | result_partial, boxes = result_partial 167 | self.boxes = [Box(x=boxes[i,0].item(), y=boxes[i,1].item(), w=boxes[i,2].item()-boxes[i,0].item(), h=boxes[i,3].item()-boxes[i,1].item()) for i in range(boxes.shape[0])] 168 | candidate_indices = list(range(len(self.boxes))) 169 | result_partial = result_partial.float() 170 | if not softmax: 171 | result_partial = (result_partial-result_partial.mean()) / (result_partial.std() + 1e-9) 172 | result_partial = (temperature * result_partial).sigmoid() 173 | result = torch.zeros((len(self.boxes))).to(result_partial.device) 174 | result[candidate_indices] = result_partial 175 | else: 176 | result = torch.zeros((len(self.boxes))).to(result_partial.device) 177 | result[candidate_indices] = result_partial.softmax(dim=-1) 178 | return result.cpu().numpy(), attn.cpu().numpy() 179 | 180 | def filter_area(self, 
area_threshold: float) -> np.ndarray: 181 | """Return a new distribution in which all boxes whose area as a fraction of the image is less than the threshold.""" 182 | image_area = self.image.width*self.image.height 183 | return np.array([1 if self.boxes[i].area/image_area > area_threshold else 0 for i in range(len(self.boxes))]) 184 | 185 | @spatial() 186 | def left_of(b1, b2): 187 | return (b1.right+b1.left) / 2 < (b2.right+b2.left) / 2 188 | 189 | @spatial() 190 | def right_of(b1, b2): 191 | return (b1.right+b1.left) / 2 > (b2.right+b2.left) / 2 192 | 193 | @spatial() 194 | def above(b1, b2): 195 | return (b1.bottom+b1.top) < (b2.bottom+b2.top) 196 | 197 | @spatial() 198 | def below(b1, b2): 199 | return (b1.bottom+b1.top) > (b2.bottom+b2.top) 200 | 201 | @spatial() 202 | def bigger_than(b1, b2): 203 | return b1.area > b2.area 204 | 205 | @spatial() 206 | def smaller_than(b1, b2): 207 | return b1.area < b2.area 208 | 209 | @spatial(enforce_antisymmetry=False) 210 | def within(box1, box2): 211 | """Return percent of box1 inside box2.""" 212 | intersection = box1.intersect(box2) 213 | return intersection.area / box1.area 214 | 215 | @spatial(arity=3, enforce_antisymmetry=True) 216 | def between(box1, box2, box3): 217 | """How much of box1 lies in min bounding box over box2 and box3?""" 218 | min_bounding = box2.min_bounding(box3) 219 | intersect = box1.intersect(min_bounding) 220 | return intersect.area / box1.area 221 | -------------------------------------------------------------------------------- /UNITER/inf_re.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | run inference of VQA for submission 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | from os.path import exists 11 | from time import time 12 | from tqdm import tqdm 13 | 14 | import torch 15 | from torch.utils.data import DataLoader 16 | 17 | from apex import amp 18 | from horovod import torch as hvd 19 | from cytoolz import concat 20 | 21 | from data import (PrefetchLoader, 22 | re_eval_collate, DetectFeatPt, ReEvalJsonDataset, ReTxtTokJson) 23 | from data.sampler import DistributedSampler 24 | from model.re import UniterForReferringExpressionComprehension 25 | 26 | from utils.logger import LOGGER 27 | from utils.distributed import all_gather_list 28 | from utils.misc import Struct 29 | from utils.const import IMG_DIM 30 | 31 | 32 | def write_to_tmp(txt, tmp_file): 33 | if tmp_file: 34 | f = open(tmp_file, "a") 35 | f.write(txt) 36 | 37 | 38 | def main(opts): 39 | if opts.cpu: 40 | device = torch.device('cpu') 41 | n_gpu = 1 42 | rank = 0 43 | else: 44 | hvd.init() 45 | n_gpu = hvd.size() 46 | device = torch.device("cuda", hvd.local_rank()) 47 | torch.cuda.set_device(hvd.local_rank()) 48 | rank = hvd.rank() 49 | LOGGER.info("device: {} n_gpu: {}, rank: {}, " 50 | "16-bits training: {}".format( 51 | device, n_gpu, hvd.rank(), opts.fp16)) 52 | 53 | hps_file = f'{opts.output_dir}/log/hps.json' 54 | model_opts = json.load(open(hps_file)) 55 | if 'mlp' not in model_opts: 56 | model_opts['mlp'] = 1 57 | model_opts = Struct(model_opts) 58 | # Prepare model 59 | if exists(opts.checkpoint): 60 | ckpt_file = opts.checkpoint 61 | else: 62 | ckpt_file = f'{opts.output_dir}/ckpt/model_epoch_{opts.checkpoint}.pt' 63 | checkpoint = torch.load(ckpt_file) 64 | model = UniterForReferringExpressionComprehension.from_pretrained( 65 | f'{opts.output_dir}/log/model.json', checkpoint, 66 | img_dim=IMG_DIM, 
mlp=model_opts.mlp) 67 | model.to(device) 68 | if not opts.cpu: 69 | hvd.broadcast_parameters(model.state_dict(), root_rank=0) 70 | if opts.fp16: 71 | model = amp.initialize(model, enabled=True, opt_level='O2') 72 | 73 | # load DBs and image dirs 74 | img_db_type = "gt" if "coco_gt" in opts.img_db else "det" 75 | print(model_opts.conf_th) 76 | conf_th = -1 if img_db_type == "gt" else model_opts.conf_th 77 | num_bb = 100 if img_db_type == "gt" else model_opts.num_bb 78 | if opts.simple_format: 79 | eval_img_db = DetectFeatPt(opts.img_db, conf_th, model_opts.max_bb, 80 | model_opts.min_bb, num_bb, opts.compressed_db) 81 | else: 82 | eval_img_db = DetectFeatLmdb(opts.img_db, 83 | conf_th, model_opts.max_bb, 84 | model_opts.min_bb, num_bb, 85 | opts.compressed_db) 86 | 87 | # Prepro txt_dbs 88 | txt_dbs = opts.txt_db.split(':') 89 | for txt_db in txt_dbs: 90 | print(f'Evaluating {txt_db}') 91 | if opts.simple_format: 92 | eval_txt_db = ReTxtTokJson(txt_db, -1) 93 | eval_dataset = ReEvalJsonDataset( 94 | eval_txt_db, eval_img_db) 95 | else: 96 | assert False, "Original data format not supported" 97 | 98 | sampler = DistributedSampler(eval_dataset, num_replicas=n_gpu, 99 | rank=rank, shuffle=False) 100 | eval_dataloader = DataLoader(eval_dataset, 101 | sampler=sampler, 102 | batch_size=opts.batch_size, 103 | num_workers=opts.n_workers, 104 | pin_memory=opts.pin_mem, 105 | collate_fn=re_eval_collate) 106 | eval_dataloader = PrefetchLoader(eval_dataloader) 107 | 108 | # evaluate 109 | val_log, results = evaluate(model, eval_dataloader) 110 | 111 | result_dir = f'{opts.output_dir}/results_test' 112 | if not exists(result_dir) and rank == 0: 113 | os.makedirs(result_dir) 114 | write_to_tmp( 115 | f"{txt_db.split('_')[1].split('.')[0]}-acc({img_db_type}): {results['acc']*100:.2f}% ", 116 | args.tmp_file) 117 | 118 | # all_results = list(concat(all_gather_list(results))) 119 | all_results = results 120 | 121 | if hvd.rank() == 0: 122 | db_split = txt_db.split('/')[-1].split('.')[0] # refcoco+_val 123 | img_dir = opts.img_db.split('/')[-1] # re_coco_gt 124 | with open(f'{result_dir}/' 125 | f'results_{opts.checkpoint}_{db_split}_on_{img_dir}_all.json', 'w') as f: 126 | json.dump(all_results, f) 127 | # print 128 | print(f'{opts.output_dir}/results_test') 129 | 130 | write_to_tmp(f'\n', args.tmp_file) 131 | 132 | 133 | @torch.no_grad() 134 | def evaluate(model, eval_loader): 135 | LOGGER.info("start running evaluation...") 136 | model.eval() 137 | tot_score = 0 138 | n_ex = 0 139 | st = time() 140 | predictions = [] 141 | for i, batch in tqdm(enumerate(eval_loader)): 142 | (tgt_box_list, obj_boxes_list, sent_ids) = ( 143 | batch['tgt_box'], batch['obj_boxes'], batch['sent_ids']) 144 | # scores (n, max_num_bb) 145 | scores = model(batch, compute_loss=False) 146 | ixs = torch.argmax(scores, 1).cpu().detach().numpy() # (n, ) 147 | 148 | # pred_boxes 149 | for ix, obj_boxes, tgt_box, sent_id in \ 150 | zip(ixs, obj_boxes_list, tgt_box_list, sent_ids): 151 | pred_box = obj_boxes[ix] 152 | predictions.append({'sent_id': int(sent_id), 153 | 'pred_box': pred_box.tolist(), 154 | 'tgt_box': tgt_box.tolist()}) 155 | if eval_loader.loader.dataset.computeIoU(pred_box, tgt_box) > .5: 156 | tot_score += 1 157 | n_ex += 1 158 | if i % 100 == 0 and hvd.rank() == 0: 159 | n_results = len(predictions) 160 | n_results *= hvd.size() # an approximation to avoid hangs 161 | LOGGER.info(f'{n_results}/{len(eval_loader.dataset)} ' 162 | 'answers predicted') 163 | n_ex = sum(all_gather_list(n_ex)) 164 | tot_time = time()-st 165 | 
tot_score = sum(all_gather_list(tot_score)) 166 | val_acc = tot_score / n_ex 167 | val_log = {'valid/acc': val_acc, 'valid/ex_per_s': n_ex/tot_time} 168 | model.train() 169 | LOGGER.info(f"validation ({n_ex} sents) finished in" 170 | f" {int(tot_time)} seconds" 171 | f", accuracy: {val_acc*100:.2f}%") 172 | # summarize 173 | results = {'acc': val_acc, 'predictions': predictions} 174 | return val_log, results 175 | 176 | 177 | if __name__ == "__main__": 178 | parser = argparse.ArgumentParser() 179 | 180 | # Required parameters 181 | parser.add_argument("--txt_db", 182 | default=None, type=str, 183 | help="The input train corpus. (LMDB)") 184 | parser.add_argument("--img_db", 185 | default=None, type=str, 186 | help="The input train images.") 187 | parser.add_argument('--compressed_db', action='store_true', 188 | help='use compressed LMDB') 189 | parser.add_argument("--checkpoint", 190 | default=None, type=str, 191 | help="can be the path to binary or int number (step)") 192 | parser.add_argument("--batch_size", 193 | default=256, type=int, 194 | help="number of sentences per batch") 195 | 196 | parser.add_argument("--output_dir", default=None, type=str, 197 | help="The output directory of the training command") 198 | 199 | # device parameters 200 | parser.add_argument('--fp16', 201 | action='store_true', 202 | help="Whether to use 16-bit float precision instead " 203 | "of 32-bit") 204 | parser.add_argument('--n_workers', type=int, default=4, 205 | help="number of data workers") 206 | parser.add_argument('--pin_mem', action='store_true', 207 | help="pin memory") 208 | 209 | # Write simple results to some tmp file 210 | parser.add_argument('--tmp_file', type=str, default=None, 211 | help="write results to tmp file") 212 | 213 | # JSON/PT format 214 | parser.add_argument('--simple_format', action='store_true') 215 | 216 | parser.add_argument('--cpu', action='store_true') 217 | 218 | args = parser.parse_args() 219 | print(args.simple_format) 220 | 221 | main(args) 222 | -------------------------------------------------------------------------------- /clevr-dataset-gen/gather_simple_clevr.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import argparse 4 | from copy import deepcopy 5 | from collections import defaultdict 6 | 7 | from bounding_box import extract_bounding_boxes 8 | 9 | def construct_non_spatial_text(object_list): 10 | text = "" 11 | for obj in object_list[:-1]: 12 | text += "a "+obj 13 | if len(object_list) > 2: 14 | text += "," 15 | text += " " 16 | text += "and a "+object_list[-1] 17 | return text 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--scenes_path', help="Path to scenes file") 21 | parser.add_argument("--output_path", help="Path to output file") 22 | parser.add_argument('--spatial', action='store_true', help="If true, collect pairs of scenes with same objects in different spatial configuration") 23 | parser.add_argument('--max_num_objects', type=int, default=3, help="Max number of objects in the filtered scenes") 24 | parser.add_argument('--mode', type=str, default='text_pair', choices=['text_pair', 'image_pair']) 25 | args = parser.parse_args() 26 | 27 | with open(args.scenes_path) as scenes_file: 28 | scenes = json.load(scenes_file) 29 | if isinstance(scenes, dict): 30 | scenes = scenes["scenes"] 31 | for scene in scenes: 32 | xmin, ymin, xmax, ymax, _, _ = extract_bounding_boxes(scene) 33 | scene["boxes"] = [[xmin[i], ymin[i], xmax[i], ymax[i]] for i in
range(len(xmin))] 34 | scenes_by_type = defaultdict(list) 35 | all_colors = set() 36 | all_shapes = set() 37 | for scene in scenes: 38 | if len(scene["objects"]) <= args.max_num_objects: 39 | types = sorted([obj["color"]+" "+obj["shape"] for obj in scene["objects"]]) 40 | scenes_by_type[",".join(types)].append(scene) 41 | for obj in scene["objects"]: 42 | all_colors.add(obj["color"]) 43 | all_shapes.add(obj["shape"]) 44 | all_colors = list(all_colors) 45 | all_shapes = list(all_shapes) 46 | examples = [] 47 | if args.spatial: 48 | for scene_type in scenes_by_type: 49 | if args.mode == "image_pair": 50 | # assert len(scenes_by_type[scene_type]) > 1 51 | scene1 = random.choice(scenes_by_type[scene_type]) 52 | print(len(scenes_by_type[scene_type])) 53 | scene2 = scene1 54 | diff_rels = [] 55 | random.shuffle(scenes_by_type[scene_type]) 56 | i = 0 57 | while len(diff_rels) == 0 and i < len(scenes_by_type[scene_type]): 58 | scene2 = scenes_by_type[scene_type][i] 59 | obj_index_map = {0: 0, 1: 1} 60 | if scene2["objects"][0]["color"] != scene1["objects"][0]["color"] or scene2["objects"][0]["shape"] != scene1["objects"][0]["shape"]: 61 | scene2["objects"] = scene2["objects"][::-1] 62 | for rel in scene2["relationships"]: 63 | scene2["relationships"][rel] = scene2["relationships"][rel][::-1] 64 | for j in range(len(scene2['relationships'][rel])): 65 | if len(scene2['relationships'][rel][j]) > 0: 66 | scene2['relationships'][rel][j][0] = 1-scene2['relationships'][rel][j][0] 67 | assert scene2["objects"][obj_index_map[0]]["color"] == scene1["objects"][0]["color"] and scene2["objects"][obj_index_map[0]]["shape"] == scene2["objects"][obj_index_map[0]]["shape"] 68 | for rel in scene1["relationships"]: 69 | if scene1["relationships"][rel] != scene2["relationships"][rel]: 70 | diff_rels.append(rel) 71 | i += 1 72 | if len(diff_rels) == 0: 73 | continue 74 | print(diff_rels) 75 | example = deepcopy(scene1) 76 | example["image_filename2"] = scene2["image_filename"] 77 | rel = random.choice(diff_rels) 78 | if rel in {"left", "right"}: 79 | if len(scene1["relationships"][rel][0]) != 1: 80 | obj_index_map = {1: 0, 0: 1} 81 | example["text1"] = "a "+scene1["objects"][obj_index_map[1]]["color"]+" "+scene1["objects"][obj_index_map[1]]["shape"]+" to the "+rel+" of a "+scene1["objects"][obj_index_map[0]]["color"]+" "+scene1["objects"][obj_index_map[0]]["shape"]+"." 82 | else: 83 | if len(scene1["relationships"][rel][0]) != 1: 84 | obj_index_map = {1: 0, 0: 1} 85 | if rel == "front": 86 | example["text1"] = "a "+scene1["objects"][obj_index_map[1]]["color"]+" "+scene1["objects"][obj_index_map[1]]["shape"]+" in front of a "+scene1["objects"][obj_index_map[0]]["color"]+" "+scene1["objects"][obj_index_map[0]]["shape"]+"." 87 | else: 88 | example["text1"] = "a "+scene1["objects"][obj_index_map[1]]["color"]+" "+scene1["objects"][obj_index_map[1]]["shape"]+" behind a "+scene1["objects"][obj_index_map[0]]["color"]+" "+scene1["objects"][obj_index_map[0]]["shape"]+"." 
89 | examples.append(example) 90 | if args.mode == "text_pair": 91 | scene = random.choice(scenes_by_type[scene_type]) 92 | object_order = list(range(len(scene["objects"]))) 93 | random.shuffle(object_order) 94 | relation_order = list(scene["relationships"].keys()) 95 | random.shuffle(relation_order) 96 | added_example = False 97 | for obj in object_order: 98 | for relation in relation_order: 99 | if len(scene["relationships"][relation][obj]) > 0: 100 | obj2 = random.choice(scene["relationships"][relation][obj]) 101 | example = deepcopy(scene) 102 | left_right = ["left", "right"] 103 | if relation in left_right: 104 | example["text1"] = "a "+scene["objects"][obj2]["color"]+" "+scene["objects"][obj2]["shape"]+" to the "+relation+" of a "+scene["objects"][obj]["color"]+" "+scene["objects"][obj]["shape"]+"." 105 | example["text2"] = "a "+scene["objects"][obj2]["color"]+" "+scene["objects"][obj2]["shape"]+" to the "+left_right[1-left_right.index(relation)]+" of a "+scene["objects"][obj]["color"]+" "+scene["objects"][obj]["shape"]+"." 106 | elif relation == "front": 107 | example["text1"] = "a "+scene["objects"][obj2]["color"]+" "+scene["objects"][obj2]["shape"]+" in front of a "+scene["objects"][obj]["color"]+" "+scene["objects"][obj]["shape"]+"." 108 | example["text2"] = "a "+scene["objects"][obj2]["color"]+" "+scene["objects"][obj2]["shape"]+" behind a "+scene["objects"][obj]["color"]+" "+scene["objects"][obj]["shape"]+"." 109 | elif relation == "behind": 110 | example["text1"] = "a "+scene["objects"][obj2]["color"]+" "+scene["objects"][obj2]["shape"]+" behind a "+scene["objects"][obj]["color"]+" "+scene["objects"][obj]["shape"]+"." 111 | example["text2"] = "a "+scene["objects"][obj2]["color"]+" "+scene["objects"][obj2]["shape"]+" in front of a "+scene["objects"][obj]["color"]+" "+scene["objects"][obj]["shape"]+"." 
112 | examples.append(example) 113 | added_example = True 114 | break 115 | if added_example: 116 | break 117 | else: 118 | if args.mode == "image_pair": 119 | for scene_type in scenes_by_type: 120 | scene = deepcopy(random.choice(scenes_by_type[scene_type])) 121 | scene2_type = scene_type 122 | while scene2_type == scene_type: 123 | scene2_type = random.choice(list(scenes_by_type.keys())) 124 | scene2 = random.choice(scenes_by_type[scene2_type]) 125 | scene['image_filename2'] = scene2['image_filename'] 126 | scene['text1'] = construct_non_spatial_text(scene_type.split(",")) 127 | examples.append(scene) 128 | if args.mode == "text_pair": 129 | for scene_type in scenes_by_type: 130 | scene = random.choice(scenes_by_type[scene_type]) 131 | objects_in_scene = scene_type.split(",") 132 | random.shuffle(all_colors) 133 | random.shuffle(all_shapes) 134 | missing_object = None 135 | for color in all_colors: 136 | for shape in all_shapes: 137 | if color+" "+shape not in objects_in_scene: 138 | missing_object = color+" "+shape 139 | break 140 | if missing_object is not None: 141 | break 142 | example = deepcopy(scene) 143 | random.shuffle(objects_in_scene) 144 | example["text1"] = construct_non_spatial_text(objects_in_scene) 145 | objects_in_scene[random.choice(list(range(len(objects_in_scene))))] = missing_object 146 | example["text2"] = construct_non_spatial_text(objects_in_scene) 147 | examples.append(example) 148 | fout = open(args.output_path, 'w') 149 | json.dump(examples, fout) 150 | fout.close() 151 | print(len(examples)) 152 | -------------------------------------------------------------------------------- /UNITER/model/nlvr2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | Uniter for NLVR2 model 6 | """ 7 | from collections import defaultdict 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from .model import UniterPreTrainedModel, UniterModel 14 | from .attention import MultiheadAttention 15 | 16 | 17 | class UniterForNlvr2Paired(UniterPreTrainedModel): 18 | """ Finetune UNITER for NLVR2 (paired format) 19 | """ 20 | def __init__(self, config, img_dim): 21 | super().__init__(config) 22 | self.uniter = UniterModel(config, img_dim) 23 | self.nlvr2_output = nn.Linear(config.hidden_size*2, 2) 24 | self.apply(self.init_weights) 25 | 26 | def init_type_embedding(self): 27 | new_emb = nn.Embedding(3, self.uniter.config.hidden_size) 28 | new_emb.apply(self.init_weights) 29 | for i in [0, 1]: 30 | emb = self.uniter.embeddings.token_type_embeddings\ 31 | .weight.data[i, :] 32 | new_emb.weight.data[i, :].copy_(emb) 33 | new_emb.weight.data[2, :].copy_(emb) 34 | self.uniter.embeddings.token_type_embeddings = new_emb 35 | 36 | def forward(self, batch, compute_loss=True): 37 | batch = defaultdict(lambda: None, batch) 38 | input_ids = batch['input_ids'] 39 | position_ids = batch['position_ids'] 40 | img_feat = batch['img_feat'] 41 | img_pos_feat = batch['img_pos_feat'] 42 | attn_masks = batch['attn_masks'] 43 | gather_index = batch['gather_index'] 44 | img_type_ids = batch['img_type_ids'] 45 | sequence_output = self.uniter(input_ids, position_ids, 46 | img_feat, img_pos_feat, 47 | attn_masks, gather_index, 48 | output_all_encoded_layers=False, 49 | img_type_ids=img_type_ids) 50 | pooled_output = self.uniter.pooler(sequence_output) 51 | # concat CLS of the pair 52 | n_pair = pooled_output.size(0) // 2 53 | reshaped_output = pooled_output.contiguous().view(n_pair, -1) 54 | answer_scores = self.nlvr2_output(reshaped_output) 55 | 56 | if compute_loss: 57 | targets = batch['targets'] 58 | nlvr2_loss = F.cross_entropy( 59 | answer_scores, targets, reduction='none') 60 | return nlvr2_loss 61 | else: 62 | return answer_scores 63 | 64 | 65 | class UniterForNlvr2Triplet(UniterPreTrainedModel): 66 | """ Finetune UNITER for NLVR2 (triplet format) 67 | """ 68 | def __init__(self, config, img_dim): 69 | super().__init__(config) 70 | self.uniter = UniterModel(config, img_dim) 71 | self.nlvr2_output = nn.Linear(config.hidden_size, 2) 72 | self.apply(self.init_weights) 73 | 74 | def init_type_embedding(self): 75 | new_emb = nn.Embedding(3, self.uniter.config.hidden_size) 76 | new_emb.apply(self.init_weights) 77 | for i in [0, 1]: 78 | emb = self.uniter.embeddings.token_type_embeddings\ 79 | .weight.data[i, :] 80 | new_emb.weight.data[i, :].copy_(emb) 81 | new_emb.weight.data[2, :].copy_(emb) 82 | self.uniter.embeddings.token_type_embeddings = new_emb 83 | 84 | def forward(self, batch, compute_loss=True): 85 | batch = defaultdict(lambda: None, batch) 86 | input_ids = batch['input_ids'] 87 | position_ids = batch['position_ids'] 88 | img_feat = batch['img_feat'] 89 | img_pos_feat = batch['img_pos_feat'] 90 | attn_masks = batch['attn_masks'] 91 | gather_index = batch['gather_index'] 92 | img_type_ids = batch['img_type_ids'] 93 | sequence_output = self.uniter(input_ids, position_ids, 94 | img_feat, img_pos_feat, 95 | attn_masks, gather_index, 96 | output_all_encoded_layers=False, 97 | img_type_ids=img_type_ids) 98 | pooled_output = self.uniter.pooler(sequence_output) 99 | answer_scores = self.nlvr2_output(pooled_output) 100 | 101 | if compute_loss: 102 | targets = batch['targets'] 103 | nlvr2_loss = F.cross_entropy( 104 | answer_scores, 
targets, reduction='none') 105 | return nlvr2_loss 106 | else: 107 | return answer_scores 108 | 109 | 110 | class AttentionPool(nn.Module): 111 | """ attention pooling layer """ 112 | def __init__(self, hidden_size, drop=0.0): 113 | super().__init__() 114 | self.fc = nn.Sequential(nn.Linear(hidden_size, 1), nn.ReLU()) 115 | self.dropout = nn.Dropout(drop) 116 | 117 | def forward(self, input_, mask=None): 118 | """input: [B, T, D], mask = [B, T]""" 119 | score = self.fc(input_).squeeze(-1) 120 | if mask is not None: 121 | mask = mask.to(dtype=input_.dtype) * -1e4 122 | score = score + mask 123 | norm_score = self.dropout(F.softmax(score, dim=1)) 124 | output = norm_score.unsqueeze(1).matmul(input_).squeeze(1) 125 | return output 126 | 127 | 128 | class UniterForNlvr2PairedAttn(UniterPreTrainedModel): 129 | """ Finetune UNITER for NLVR2 130 | (paired format with additional attention layer) 131 | """ 132 | def __init__(self, config, img_dim): 133 | super().__init__(config) 134 | self.uniter = UniterModel(config, img_dim) 135 | self.attn1 = MultiheadAttention(config.hidden_size, 136 | config.num_attention_heads, 137 | config.attention_probs_dropout_prob) 138 | self.attn2 = MultiheadAttention(config.hidden_size, 139 | config.num_attention_heads, 140 | config.attention_probs_dropout_prob) 141 | self.fc = nn.Sequential( 142 | nn.Linear(2*config.hidden_size, config.hidden_size), 143 | nn.ReLU(), 144 | nn.Dropout(config.hidden_dropout_prob)) 145 | self.attn_pool = AttentionPool(config.hidden_size, 146 | config.attention_probs_dropout_prob) 147 | self.nlvr2_output = nn.Linear(2*config.hidden_size, 2) 148 | self.apply(self.init_weights) 149 | 150 | def init_type_embedding(self): 151 | new_emb = nn.Embedding(3, self.uniter.config.hidden_size) 152 | new_emb.apply(self.init_weights) 153 | for i in [0, 1]: 154 | emb = self.uniter.embeddings.token_type_embeddings\ 155 | .weight.data[i, :] 156 | new_emb.weight.data[i, :].copy_(emb) 157 | new_emb.weight.data[2, :].copy_(emb) 158 | self.uniter.embeddings.token_type_embeddings = new_emb 159 | 160 | def forward(self, batch, compute_loss=True): 161 | batch = defaultdict(lambda: None, batch) 162 | input_ids = batch['input_ids'] 163 | position_ids = batch['position_ids'] 164 | img_feat = batch['img_feat'] 165 | img_pos_feat = batch['img_pos_feat'] 166 | attn_masks = batch['attn_masks'] 167 | gather_index = batch['gather_index'] 168 | img_type_ids = batch['img_type_ids'] 169 | sequence_output = self.uniter(input_ids, position_ids, 170 | img_feat, img_pos_feat, 171 | attn_masks, gather_index, 172 | output_all_encoded_layers=False, 173 | img_type_ids=img_type_ids) 174 | # separate left image and right image 175 | bs, tl, d = sequence_output.size() 176 | left_out, right_out = sequence_output.contiguous().view( 177 | bs//2, tl*2, d).chunk(2, dim=1) 178 | # bidirectional attention 179 | mask = attn_masks == 0 180 | left_mask, right_mask = mask.contiguous().view(bs//2, tl*2 181 | ).chunk(2, dim=1) 182 | left_out = left_out.transpose(0, 1) 183 | right_out = right_out.transpose(0, 1) 184 | l2r_attn, _ = self.attn1(left_out, right_out, right_out, 185 | key_padding_mask=right_mask) 186 | r2l_attn, _ = self.attn2(right_out, left_out, left_out, 187 | key_padding_mask=left_mask) 188 | left_out = self.fc(torch.cat([l2r_attn, left_out], dim=-1) 189 | ).transpose(0, 1) 190 | right_out = self.fc(torch.cat([r2l_attn, right_out], dim=-1) 191 | ).transpose(0, 1) 192 | # attention pooling and final prediction 193 | left_out = self.attn_pool(left_out, left_mask) 194 | right_out = 
self.attn_pool(right_out, right_mask) 195 | answer_scores = self.nlvr2_output( 196 | torch.cat([left_out, right_out], dim=-1)) 197 | 198 | if compute_loss: 199 | targets = batch['targets'] 200 | nlvr2_loss = F.cross_entropy( 201 | answer_scores, targets, reduction='none') 202 | return nlvr2_loss 203 | else: 204 | return answer_scores 205 | -------------------------------------------------------------------------------- /albef/vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from functools import partial 7 | 8 | from timm.models.vision_transformer import _cfg, PatchEmbed 9 | from timm.models.registry import register_model 10 | from timm.models.layers import trunc_normal_, DropPath 11 | 12 | 13 | class Mlp(nn.Module): 14 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 15 | """ 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 36 | super().__init__() 37 | self.num_heads = num_heads 38 | head_dim = dim // num_heads 39 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 40 | self.scale = qk_scale or head_dim ** -0.5 41 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 42 | self.attn_drop = nn.Dropout(attn_drop) 43 | self.proj = nn.Linear(dim, dim) 44 | self.proj_drop = nn.Dropout(proj_drop) 45 | self.attn_gradients = None 46 | self.attention_map = None 47 | 48 | def save_attn_gradients(self, attn_gradients): 49 | self.attn_gradients = attn_gradients 50 | 51 | def get_attn_gradients(self): 52 | return self.attn_gradients 53 | 54 | def save_attention_map(self, attention_map): 55 | self.attention_map = attention_map 56 | 57 | def get_attention_map(self): 58 | return self.attention_map 59 | 60 | def forward(self, x, register_hook=False): 61 | B, N, C = x.shape 62 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 63 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 64 | 65 | attn = (q @ k.transpose(-2, -1)) * self.scale 66 | attn = attn.softmax(dim=-1) 67 | attn = self.attn_drop(attn) 68 | 69 | if register_hook: 70 | self.save_attention_map(attn) 71 | attn.register_hook(self.save_attn_gradients) 72 | 73 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 74 | x = self.proj(x) 75 | x = self.proj_drop(x) 76 | return x 77 | 78 | 79 | class Block(nn.Module): 80 | 81 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 82 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 83 | super().__init__() 84 | self.norm1 = norm_layer(dim) 85 | self.attn = Attention( 86 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, 
proj_drop=drop) 87 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 88 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 89 | self.norm2 = norm_layer(dim) 90 | mlp_hidden_dim = int(dim * mlp_ratio) 91 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 92 | 93 | def forward(self, x, register_hook=False): 94 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 95 | x = x + self.drop_path(self.mlp(self.norm2(x))) 96 | return x 97 | 98 | 99 | class VisionTransformer(nn.Module): 100 | """ Vision Transformer 101 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 102 | https://arxiv.org/abs/2010.11929 103 | """ 104 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 105 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 106 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None): 107 | """ 108 | Args: 109 | img_size (int, tuple): input image size 110 | patch_size (int, tuple): patch size 111 | in_chans (int): number of input channels 112 | num_classes (int): number of classes for classification head 113 | embed_dim (int): embedding dimension 114 | depth (int): depth of transformer 115 | num_heads (int): number of attention heads 116 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 117 | qkv_bias (bool): enable bias for qkv if True 118 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 119 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 120 | drop_rate (float): dropout rate 121 | attn_drop_rate (float): attention dropout rate 122 | drop_path_rate (float): stochastic depth rate 123 | norm_layer: (nn.Module): normalization layer 124 | """ 125 | super().__init__() 126 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 127 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 128 | 129 | self.patch_embed = PatchEmbed( 130 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 131 | num_patches = self.patch_embed.num_patches 132 | 133 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 134 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 135 | self.pos_drop = nn.Dropout(p=drop_rate) 136 | 137 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 138 | self.blocks = nn.ModuleList([ 139 | Block( 140 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 141 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 142 | for i in range(depth)]) 143 | self.norm = norm_layer(embed_dim) 144 | 145 | trunc_normal_(self.pos_embed, std=.02) 146 | trunc_normal_(self.cls_token, std=.02) 147 | self.apply(self._init_weights) 148 | 149 | def _init_weights(self, m): 150 | if isinstance(m, nn.Linear): 151 | trunc_normal_(m.weight, std=.02) 152 | if isinstance(m, nn.Linear) and m.bias is not None: 153 | nn.init.constant_(m.bias, 0) 154 | elif isinstance(m, nn.LayerNorm): 155 | nn.init.constant_(m.bias, 0) 156 | nn.init.constant_(m.weight, 1.0) 157 | 158 | @torch.jit.ignore 159 | def no_weight_decay(self): 160 | return {'pos_embed', 'cls_token'} 161 | 162 | def forward(self, x, register_blk=-1): 163 | B = 
x.shape[0] 164 | x = self.patch_embed(x) 165 | 166 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 167 | x = torch.cat((cls_tokens, x), dim=1) 168 | 169 | x = x + self.pos_embed[:,:x.size(1),:] 170 | x = self.pos_drop(x) 171 | 172 | for i,blk in enumerate(self.blocks): 173 | x = blk(x, register_blk==i) 174 | x = self.norm(x) 175 | 176 | return x 177 | 178 | 179 | 180 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 181 | # interpolate position embedding 182 | embedding_size = pos_embed_checkpoint.shape[-1] 183 | num_patches = visual_encoder.patch_embed.num_patches 184 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 185 | # height (== width) for the checkpoint position embedding 186 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 187 | # height (== width) for the new position embedding 188 | new_size = int(num_patches ** 0.5) 189 | 190 | if orig_size!=new_size: 191 | # class_token and dist_token are kept unchanged 192 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 193 | # only the position tokens are interpolated 194 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 195 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 196 | pos_tokens = torch.nn.functional.interpolate( 197 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 198 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 199 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 200 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) 201 | 202 | return new_pos_embed 203 | else: 204 | return pos_embed_checkpoint --------------------------------------------------------------------------------
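Note on interpolate_pos_embed (albef/vit.py above): a minimal usage sketch, assuming the repository root is on PYTHONPATH and a DeiT-style checkpoint whose weights are nested under a "model" key with a "pos_embed" entry; the file name, image sizes, and key names are illustrative assumptions, not part of this repository.

# Usage sketch (assumptions noted above): resize a checkpoint's position embeddings
# so they match a VisionTransformer built for a different input resolution.
import torch
from albef.vit import VisionTransformer, interpolate_pos_embed

# Hypothetical target model: 384x384 inputs; checkpoint assumed trained at 224x224.
model = VisionTransformer(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12)

state_dict = torch.load("deit_base_patch16_224.pth", map_location="cpu")  # assumed file name
if "model" in state_dict:  # DeiT-style checkpoints nest weights under "model"
    state_dict = state_dict["model"]

# Interpolate the patch-grid position embeddings; extra tokens (cls) are kept unchanged.
state_dict["pos_embed"] = interpolate_pos_embed(state_dict["pos_embed"], model)

msg = model.load_state_dict(state_dict, strict=False)
print("missing:", msg.missing_keys, "unexpected:", msg.unexpected_keys)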