├── .gitignore
├── LICENSE
├── README.md
├── cache
│   └── .gitkeep
├── data
│   └── README.md
├── environment.yml
├── lib
│   ├── predictor.py
│   ├── rank_utils.py
│   ├── refer.py
│   └── vanilla_utils.py
├── output
│   └── .gitkeep
├── scripts
│   └── prepare_data.sh
├── tb
│   └── .gitkeep
├── tools
│   ├── build_ctxdb.py
│   ├── build_refdb.py
│   ├── build_vocab.py
│   ├── eval_proposal_ctx_recall.py
│   ├── eval_proposal_hit_rate.py
│   ├── save_matt_dets.py
│   ├── save_ref_nms_proposals.py
│   ├── train_att_rank.py
│   └── train_att_vanilla.py
└── utils
    ├── constants.py
    ├── hit_rate_utils.py
    └── misc.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
/.vscode
/**/__pycache__
data/*
!data/README.md
tb/*
!tb/.gitkeep
cache/*
!cache/.gitkeep
output/*
!output/.gitkeep

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Meng5th

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Ref-NMS
Official codebase for the AAAI 2021 paper ["Ref-NMS: Breaking Proposal Bottlenecks in Two-Stage Referring Expression Grounding"](https://arxiv.org/abs/2009.01449).

## Prerequisites
The following dependencies should be sufficient. See [environment.yml](environment.yml) for the complete environment settings.
- python 3.7.6
- pytorch 1.1.0
- torchvision 0.3.0
- tensorboard 2.1.0
- spacy 2.2.3

## Data Preparation
Follow the instructions in `data/README.md` to set up the `data` directory.

Run the following script to set up the `cache` directory:
```
sh scripts/prepare_data.sh
```
This should generate the following files under the `cache` directory (where `<dataset>_<split_by>` is e.g. `refcoco_unc`):
- vocabulary file: `std_vocab_<dataset>_<split_by>.txt`
- selected GloVe features: `std_glove_<dataset>_<split_by>.npy`
- referring expression database: `std_refdb_<dataset>_<split_by>.json`
- critical objects database: `std_ctxdb_<dataset>_<split_by>.json`


## Train
**Train with binary XE loss:**
```
PYTHONPATH=$PWD python tools/train_att_vanilla.py --dataset refcoco --split-by unc
```

**Train with ranking loss:**
```
PYTHONPATH=$PWD python tools/train_att_rank.py --dataset refcoco --split-by unc
```

We use tensorboard to monitor the training process. The log files can be found in the `tb` folder.

## Evaluate Recall
**Save Ref-NMS proposals:**
```
PYTHONPATH=$PWD python tools/save_ref_nms_proposals.py --dataset refcoco --split-by unc --tid <tid> --m <m>
```
`<m>` can be either `att_vanilla` for the binary XE loss or `att_rank` for the ranking loss.

**Evaluate recall on the referent object:**
```
PYTHONPATH=$PWD python tools/eval_proposal_hit_rate.py --m <m> --dataset refcoco --split-by unc --tid <tid> --conf <conf>
```
The `conf` parameter is the score threshold used to filter Ref-NMS proposals. It should be chosen so that the recall of the referent is high while the number of proposals per expression stays around 8-10.

**Evaluate recall on critical objects:**
```
PYTHONPATH=$PWD python tools/eval_proposal_ctx_recall.py --m <m> --dataset refcoco --split-by unc --tid <tid> --conf <conf>
```

## Evaluate REG Performance
Save the MAttNet-style detection file:
```
PYTHONPATH=$PWD python tools/save_matt_dets.py --dataset refcoco --split-by unc --m <m> --tid <tid> --conf <conf>
```
This script saves all the detection information needed for downstream REG evaluation to a `matt_dets_*.json` file under `output/`, whose name encodes the dataset, split, model type, `tid` and `conf` arguments.

We provide altered versions of [MAttNet](https://github.com/ChopinSharp/MAttNet) and [CM-A-E](https://github.com/ChopinSharp/CM-Erase-REG) for downstream REG task evaluation.

First, follow the README in each repository to reproduce the originally reported results as baselines (cf. Table 2 in our paper). Then, run the following commands to evaluate on the REC and RES tasks:
```
# Evaluate REC performance
python tools/extract_mrcn_ref_feats.py --dataset refcoco --splitBy unc --tid <tid> --top-N 0 --m <m>
python tools/eval_ref.py --dataset refcoco --splitBy unc --tid <tid> --top-N 0 --m <m>
# Evaluate RES performance
python tools/run_propose_to_mask.py --dataset refcoco --splitBy unc --tid <tid> --top-N 0 --m <m>
python tools/eval_ref_masks.py --dataset refcoco --splitBy unc --tid <tid> --top-N 0 --m <m> --save
```

## Pretrained Models
We provide pre-trained model weights as well as the corresponding **MAttNet-style detection files** (note that the MAttNet-style detection files can be directly used to evaluate downstream REG task performance). With these files, one can easily reproduce our reported results.

[[Google Drive]](https://drive.google.com/drive/folders/1BPqWW0LrAEBFna7b-ORF2TcrY7K_DDvM?usp=sharing) [[Baidu Disk]](https://pan.baidu.com/s/1G4k7APKSUs-_5StXoYaNrA) (extraction code: 5a9r)

## Citation
```
@inproceedings{chen2021ref,
  title={Ref-NMS: Breaking Proposal Bottlenecks in Two-Stage Referring Expression Grounding},
  author={Chen, Long and Ma, Wenbo and Xiao, Jun and Zhang, Hanwang and Chang, Shih-Fu},
  booktitle={AAAI},
  year={2021}
}
```

--------------------------------------------------------------------------------
/cache/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChopinSharp/ref-nms/8f83f350c497d0ef875c778a8ce76725552abb3c/cache/.gitkeep

--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
# Data Preparation
The `data` directory is organized as follows:
```
data
├── head_feats
│   ├── refcoco_unc
│   ├── refcoco+_unc
│   └── refcocog_umd
├── refer
│   ├── images
│   ├── refclef
│   ├── refcoco
│   ├── refcoco+
│   └── refcocog
├── glove.840B.300d.txt
├── res101_mask_rcnn_iter_1250000_cpu.pth
├── rpn_box_scores.pkl
└── rpn_boxes.pkl
```

Download the pretrained GloVe word vectors: [glove.840B.300d.zip](http://nlp.stanford.edu/data/glove.840B.300d.zip). Unzip it to the current directory.

Follow these [descriptions](https://github.com/lichengunc/refer/tree/master/data) to set up the referring expression data, using `data/refer` as `$DATA_PATH`.

Follow the instructions in [MAttNet](https://github.com/lichengunc/MAttNet) up to the 2nd step of the "Training" section. This will extract and save the ResNet features needed for training and evaluation. Then link these features into the current directory:
```
ln -s $MATTNET_ROOT_DIR/cache/feats/<dataset>_<split_by>/mrcn/res101_coco_minus_refer_notime data/head_feats/<dataset>_<split_by>
```

Download from [Google Drive](https://drive.google.com/drive/folders/1BPqWW0LrAEBFna7b-ORF2TcrY7K_DDvM?usp=sharing) or [Baidu Disk](https://pan.baidu.com/s/1G4k7APKSUs-_5StXoYaNrA) (extraction code: 5a9r). Extract the `data` directory and move its contents to the current directory.
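
As a quick sanity check once everything is in place, the short Python sketch below (our illustration, not part of the official tooling; the file formats follow the comments in `lib/rank_utils.py` and the paths used by the training code, and the `image_id` is a made-up example) loads one pre-extracted head-feature file and the RPN detections:
```
import pickle
import h5py

dataset_splitby = 'refcoco_unc'   # one of the head_feats sub-directories above
image_id = 9                      # hypothetical COCO image id -- use any id present in your split

# Head features: {'head': (1, 1024, ih, iw), 'im_info': [[ih, iw, scale]]}
with h5py.File('data/head_feats/{}/{}.h5'.format(dataset_splitby, image_id), 'r') as f:
    head = f['head'][...]                     # conv4 feature map of the whole image
    ih, iw, scale = f['im_info'][0].tolist()  # feature-map size and image scale

# RPN detections: 300 boxes x 81 classes per image (class 0 is the background class)
with open('data/rpn_boxes.pkl', 'rb') as f:
    img_to_det_box = pickle.load(f)
with open('data/rpn_box_scores.pkl', 'rb') as f:
    img_to_det_score = pickle.load(f)

boxes = img_to_det_box[image_id].reshape(-1, 81, 4)   # (300, 81, 4), boxes as (x0, y0, x1, y1)
scores = img_to_det_score[image_id]                   # (300, 81)
print(head.shape, boxes.shape, scores.shape)
```
If these load without errors, the head features and RPN detections are in the layout the training and evaluation code expects.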
31 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ref 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - absl-py=0.9.0=py37_0 8 | - asn1crypto=1.3.0=py37_0 9 | - attrs=19.3.0=py_0 10 | - backcall=0.1.0=py37_0 11 | - blas=1.0=mkl 12 | - bleach=3.1.4=py_0 13 | - blinker=1.4=py37_0 14 | - c-ares=1.15.0=h7b6447c_1001 15 | - ca-certificates=2020.1.1=0 16 | - cachetools=3.1.1=py_0 17 | - certifi=2020.4.5.1=py37_0 18 | - cffi=1.14.0=py37h2e261b9_0 19 | - chardet=3.0.4=py37_1003 20 | - click=7.0=py37_0 21 | - cryptography=2.8=py37h1ba5d50_0 22 | - cudatoolkit=9.0=h13b8566_0 23 | - cycler=0.10.0=py37_0 24 | - dbus=1.13.12=h746ee38_0 25 | - decorator=4.4.2=py_0 26 | - defusedxml=0.6.0=py_0 27 | - entrypoints=0.3=py37_0 28 | - expat=2.2.6=he6710b0_0 29 | - flake8=3.7.9=py37_0 30 | - fontconfig=2.13.0=h9420a91_0 31 | - freetype=2.9.1=h8a8886c_1 32 | - future=0.18.2=py37_0 33 | - glib=2.63.1=h5a9c865_0 34 | - gmp=6.1.2=h6c8ec71_1 35 | - google-auth=1.11.2=py_0 36 | - google-auth-oauthlib=0.4.1=py_2 37 | - grpcio=1.27.2=py37hf8bcb03_0 38 | - gst-plugins-base=1.14.0=hbbd80ab_1 39 | - gstreamer=1.14.0=hb453b48_1 40 | - h5py=2.10.0=py37h7918eee_0 41 | - hdf5=1.10.4=hb1b8bf9_0 42 | - icu=58.2=h9c2bf20_1 43 | - idna=2.8=py37_0 44 | - importlib_metadata=1.5.0=py37_0 45 | - intel-openmp=2020.0=166 46 | - ipykernel=5.1.4=py37h39e3cac_0 47 | - ipython=7.13.0=py37h5ca1d4c_0 48 | - ipython_genutils=0.2.0=py37_0 49 | - ipywidgets=7.5.1=py_0 50 | - jedi=0.17.0=py37_0 51 | - jinja2=2.11.2=py_0 52 | - jpeg=9b=h024ee3a_2 53 | - jsonschema=3.2.0=py37_0 54 | - jupyter=1.0.0=py37_7 55 | - jupyter_client=6.1.3=py_0 56 | - jupyter_console=6.1.0=py_0 57 | - jupyter_core=4.6.3=py37_0 58 | - kiwisolver=1.1.0=py37he6710b0_0 59 | - ld_impl_linux-64=2.33.1=h53a641e_7 60 | - libedit=3.1.20181209=hc058e9b_0 61 | - libffi=3.2.1=hd88cf55_4 62 | - libgcc-ng=9.1.0=hdf63c60_0 63 | - libgfortran-ng=7.3.0=hdf63c60_0 64 | - libpng=1.6.37=hbc83047_0 65 | - libprotobuf=3.11.4=hd408876_0 66 | - libsodium=1.0.16=h1bed415_0 67 | - libstdcxx-ng=9.1.0=hdf63c60_0 68 | - libtiff=4.1.0=h2733197_0 69 | - libuuid=1.0.3=h1bed415_2 70 | - libxcb=1.13=h1bed415_1 71 | - libxml2=2.9.9=hea5a465_1 72 | - markdown=3.1.1=py37_0 73 | - markupsafe=1.1.1=py37h7b6447c_0 74 | - matplotlib=3.1.3=py37_0 75 | - matplotlib-base=3.1.3=py37hef1b27d_0 76 | - mccabe=0.6.1=py37_1 77 | - mistune=0.8.4=py37h7b6447c_0 78 | - mkl=2020.0=166 79 | - mkl-service=2.3.0=py37he904b0f_0 80 | - mkl_fft=1.0.15=py37ha843d7b_0 81 | - mkl_random=1.1.0=py37hd6b4f25_0 82 | - nbconvert=5.6.1=py37_0 83 | - nbformat=5.0.6=py_0 84 | - ncurses=6.2=he6710b0_0 85 | - ninja=1.9.0=py37hfd86e86_0 86 | - notebook=6.0.3=py37_0 87 | - numpy=1.18.1=py37h4f9e942_0 88 | - numpy-base=1.18.1=py37hde5b4d6_1 89 | - oauthlib=3.1.0=py_0 90 | - olefile=0.46=py37_0 91 | - openssl=1.1.1g=h7b6447c_0 92 | - pandoc=2.2.3.2=0 93 | - pandocfilters=1.4.2=py37_1 94 | - parso=0.7.0=py_0 95 | - pcre=8.43=he6710b0_0 96 | - pexpect=4.8.0=py37_0 97 | - pickleshare=0.7.5=py37_0 98 | - pillow=6.2.1=py37h34e0f95_0 99 | - pip=20.0.2=py37_1 100 | - prometheus_client=0.7.1=py_0 101 | - prompt-toolkit=3.0.4=py_0 102 | - prompt_toolkit=3.0.4=0 103 | - protobuf=3.11.4=py37he6710b0_0 104 | - ptyprocess=0.6.0=py37_0 105 | - pyasn1=0.4.8=py_0 106 | - pyasn1-modules=0.2.7=py_0 107 | - pycodestyle=2.5.0=py37_0 108 | - pycparser=2.19=py37_0 109 | - 
pyflakes=2.1.1=py37_0 110 | - pygments=2.6.1=py_0 111 | - pyjwt=1.7.1=py37_0 112 | - pyopenssl=19.1.0=py37_0 113 | - pyparsing=2.4.6=py_0 114 | - pyqt=5.9.2=py37h05f1152_2 115 | - pyrsistent=0.16.0=py37h7b6447c_0 116 | - pysocks=1.7.1=py37_0 117 | - python=3.7.6=h0371630_2 118 | - python-dateutil=2.8.1=py_0 119 | - pytorch=1.1.0=py3.7_cuda9.0.176_cudnn7.5.1_0 120 | - pyzmq=18.1.1=py37he6710b0_0 121 | - qt=5.9.7=h5867ecd_1 122 | - qtconsole=4.7.3=py_0 123 | - qtpy=1.9.0=py_0 124 | - readline=7.0=h7b6447c_5 125 | - requests=2.22.0=py37_1 126 | - requests-oauthlib=1.3.0=py_0 127 | - rsa=4.0=py_0 128 | - scipy=1.4.1=py37h0b6359f_0 129 | - send2trash=1.5.0=py37_0 130 | - setuptools=46.0.0=py37_0 131 | - sip=4.19.8=py37hf484d3e_0 132 | - six=1.14.0=py37_0 133 | - sqlite=3.31.1=h7b6447c_0 134 | - tensorboard=2.1.0=py3_0 135 | - terminado=0.8.3=py37_0 136 | - testpath=0.4.4=py_0 137 | - tk=8.6.8=hbc83047_0 138 | - torchvision=0.3.0=py37_cu9.0.176_1 139 | - tornado=6.0.3=py37h7b6447c_3 140 | - tqdm=4.42.1=py_0 141 | - traitlets=4.3.3=py37_0 142 | - urllib3=1.25.8=py37_0 143 | - wcwidth=0.1.9=py_0 144 | - webencodings=0.5.1=py37_1 145 | - werkzeug=1.0.0=py_0 146 | - wheel=0.34.2=py37_0 147 | - widgetsnbextension=3.5.1=py37_0 148 | - xz=5.2.4=h14c3975_4 149 | - zeromq=4.3.1=he6710b0_3 150 | - zipp=3.1.0=py_0 151 | - zlib=1.2.11=h7b6447c_3 152 | - zstd=1.3.7=h0b5b093_0 153 | - pip: 154 | - blis==0.4.1 155 | - catalogue==1.0.0 156 | - cymem==2.0.3 157 | - en-core-web-sm==2.2.0 158 | - murmurhash==1.0.2 159 | - plac==1.1.3 160 | - preshed==3.0.2 161 | - spacy==2.2.3 162 | - srsly==1.0.2 163 | - thinc==7.3.1 164 | - wasabi==0.6.0 165 | prefix: /home/mwb/miniconda3/envs/ref 166 | -------------------------------------------------------------------------------- /lib/predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import normalize, relu 4 | from torch.nn.utils.rnn import pad_packed_sequence 5 | from torchvision import models 6 | 7 | 8 | __all__ = ['VanillaPredictor', 'LTRPredictor', 'RNNBCEPredictor', 'AttVanillaPredictor', 'RankPredictor', 9 | 'SquashedRankPredictor'] 10 | 11 | 12 | def make_resnet_layer4(): 13 | downsample = nn.Sequential( 14 | nn.Conv2d(1024, 2048, kernel_size=1, stride=1, bias=False), 15 | nn.BatchNorm2d(2048) 16 | ) 17 | layers = [models.resnet.Bottleneck(1024, 512, 1, downsample)] 18 | for _ in range(2): 19 | layers.append(models.resnet.Bottleneck(2048, 512)) 20 | return nn.Sequential(*layers) 21 | 22 | 23 | class VanillaPredictor(nn.Module): 24 | 25 | def __init__(self): 26 | super(VanillaPredictor, self).__init__() 27 | # Head network with pretrained weight 28 | self.head = make_resnet_layer4() 29 | # Prediction branch & bbox regression branch from m-rcnn 30 | self.cls_score_net = nn.Linear(2048, 81, bias=True) 31 | self.bbox_pred_net = nn.Linear(2048, 81*4, bias=True) 32 | # Layers of new branch 33 | self.ref_cls_fc1 = nn.Linear(2048, 300, bias=True) 34 | self.ref_cls_fc2 = nn.Linear(300, 1, bias=True) 35 | 36 | def forward(self, roi_feats, word_feats): 37 | """ 38 | 39 | Args: 40 | roi_feats: [N, R, 1024, 7, 7]. 41 | word_feats: [N, S, 300]. 42 | 43 | Returns: 44 | max_ref_score: [N, R]. 45 | max_idx: [N, R]. 46 | cls_score: [N, R, 81]. 47 | bbox_pred: [N, R, 81*4]. 
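Note: a ref score is computed for every (RoI, word) pair and then max-pooled over the words, so `max_idx` records which word produced each RoI's score; `cls_score` and `bbox_pred` are the original Mask R-CNN heads applied to the pooled RoI feature.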
48 | 49 | """ 50 | N, R, *_ = roi_feats.shape 51 | head_feats = self.head(roi_feats.reshape(N*R, 1024, 7, 7)).reshape(N, R, 2048, 7, 7) 52 | head_pool = head_feats.mean(dim=(3, 4)) # [N, R, 2048] 53 | head_mapped = self.ref_cls_fc1(head_pool) # [N, R, 300] 54 | head_reshaped = head_mapped.unsqueeze(2) # [N, R, 1, 300] 55 | word_reshaped = word_feats.unsqueeze(1) # [N, 1, S, 300] 56 | feat_merged = head_reshaped*word_reshaped # [N, R, S, 300] 57 | feat_merged = normalize(feat_merged, dim=3) # [N, R, S, 300] 58 | ref_score = self.ref_cls_fc2(feat_merged).squeeze(dim=3) # [N, R, S] 59 | max_ref_score, max_idx = ref_score.max(dim=2) # [N, R] 60 | cls_score = self.cls_score_net(head_pool) # [N, R, 81] 61 | bbox_pred = self.bbox_pred_net(head_pool) # [N, R, 81*4] 62 | return max_ref_score, max_idx, cls_score, bbox_pred 63 | 64 | 65 | class LTRSubNetwork(nn.Module): 66 | def __init__(self, sent_feat_dim): 67 | super(LTRSubNetwork, self).__init__() 68 | self.head_fc = nn.Linear(2048, sent_feat_dim, bias=True) 69 | self.rank_fc = nn.Linear(sent_feat_dim, 1, bias=False) 70 | 71 | def forward(self, sent_feat, head_feat): 72 | """ 73 | 74 | Args: 75 | sent_feat: Sentence feature: [sent_num, sent_feat_dim]. 76 | head_feat: Head network output: [roi_num, head_feat]. 77 | 78 | Returns: 79 | rank_score: Rank score of shape [sent_num, roi_num]. 80 | 81 | """ 82 | mapped_head_feat = self.head_fc(head_feat).unsqueeze(0) # [1, roi_num, sent_feat_dim] 83 | mapped_head_feat = relu(mapped_head_feat, inplace=True) # [1, roi_num, sent_feat_dim] 84 | reshaped_sent_feat = sent_feat.unsqueeze(1) # [sent_num, 1, sent_feat_dim] 85 | merged_feat = reshaped_sent_feat * mapped_head_feat # [sent_num, roi_num, sent_feat_dim] 86 | rank_score = self.rank_fc(merged_feat).squeeze(dim=2) # [sent_num, roi_num] 87 | return rank_score 88 | 89 | 90 | class LTRPredictor(nn.Module): 91 | 92 | RNN_H_DIM = 512 93 | 94 | def __init__(self): 95 | super(LTRPredictor, self).__init__() 96 | # Head network with pretrained weight 97 | self.head_net = make_resnet_layer4() 98 | # Prediction branch & Regression branch from M-RCNN 99 | self.cls_score_net = nn.Linear(2048, 81, bias=True) 100 | self.bbox_pred_net = nn.Linear(2048, 81 * 4, bias=True) 101 | # LTR sub-network 102 | self.LTR_net = LTRSubNetwork(self.RNN_H_DIM) 103 | # Sentence processing network 104 | self.rnn = nn.GRU(300, self.RNN_H_DIM, bidirectional=True) 105 | # self.stats = { 106 | # 'head_feat_mean': None, 'head_feat_std': None, 'head_feat_norm': None, 107 | # 'sent_feat_mean': None, 'sent_feat_std': None, 'sent_feat_norm': None 108 | # } 109 | 110 | def forward(self, roi_feat, packed_sent_feat): 111 | """ 112 | 113 | Args: 114 | roi_feat: ROI features, [roi_num, 1024, 7, 7]. 115 | packed_sent_feat: `PackedSequence` object, should be moved to proper device in advance. 116 | 117 | Returns: 118 | rank_score: Rank score of shape [sent_num, roi_num]. 
119 | 120 | """ 121 | # Extract sentence feature with RNN 122 | _, h_n = self.rnn(packed_sent_feat) # [2, sent_num, sent_feat_dim] 123 | sent_feat = h_n.sum(dim=0) # [sent_num, sent_feat_dim] 124 | # Extract head feature 125 | head_feat = self.head_net(roi_feat) # [roi_num, 2048, 7, 7] 126 | head_feat = head_feat.mean(dim=(2, 3)) # [roi_num, 2048] 127 | # Rank ROIs 128 | rank_score = self.LTR_net(sent_feat, head_feat) # [sent_num, roi_num] 129 | # self.stats['head_feat_mean'] = head_feat.mean().item() 130 | # self.stats['head_feat_std'] = head_feat.std().item() 131 | # self.stats['head_feat_norm'] = head_feat.abs().mean().item() 132 | # self.stats['sent_feat_mean'] = sent_feat.mean().item() 133 | # self.stats['sent_feat_std'] = sent_feat.std().item() 134 | # self.stats['sent_feat_norm'] = sent_feat.abs().mean().item() 135 | return rank_score, # NOTE that the return tuple is a one-element tuple 136 | 137 | 138 | class RNNBCEPredictor(nn.Module): 139 | 140 | RNN_H_DIM = 512 141 | 142 | def __init__(self): 143 | super(RNNBCEPredictor, self).__init__() 144 | # Head network with pretrained weight 145 | self.head = make_resnet_layer4() 146 | # Prediction branch & bbox regression branch from m-rcnn 147 | self.cls_score_net = nn.Linear(2048, 81, bias=True) 148 | self.bbox_pred_net = nn.Linear(2048, 81*4, bias=True) 149 | # Layers of new branch 150 | self.ref_cls_fc1 = nn.Linear(2048, self.RNN_H_DIM, bias=True) 151 | self.ref_cls_fc2 = nn.Linear(self.RNN_H_DIM, 1, bias=True) 152 | self.rnn = nn.GRU(300, self.RNN_H_DIM, bidirectional=True) 153 | 154 | def forward(self, roi_feat, packed_sent_feat): 155 | """ 156 | 157 | Args: 158 | roi_feat: [N, R, 1024, 7, 7]. 159 | packed_sent_feat: `PackedSequence` object, should be moved to proper device in advance. 160 | 161 | Returns: 162 | ref_score: [N, R]. 
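Note: the forward pass also returns a small `debug_info` dict (mean/std of `ref_score`) that is only used for logging.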
163 | 164 | """ 165 | # Extract sentence feature with RNN 166 | _, h_n = self.rnn(packed_sent_feat) # [2, N, RNN_H_DIM] 167 | sent_feat = h_n.sum(dim=0) # [N, RNN_H_DIM] 168 | 169 | N, R, *_ = roi_feat.shape 170 | head_feat = self.head(roi_feat.reshape(N * R, 1024, 7, 7)).reshape(N, R, 2048, 7, 7) 171 | head_pool = head_feat.mean(dim=(3, 4)) # [N, R, 2048] 172 | head_mapped = self.ref_cls_fc1(head_pool) # [N, R, RNN_H_DIM] 173 | sent_feat = sent_feat.unsqueeze(1) # [N, 1, RNN_H_DIM] 174 | feat_merged = head_mapped * sent_feat # [N, R, RNN_H_DIM] 175 | feat_merged = normalize(feat_merged, dim=2) # [N, R, RNN_H_DIM] 176 | ref_score = self.ref_cls_fc2(feat_merged) # [N, R, 1] 177 | ref_score = ref_score.squeeze(2) # [N, R] 178 | 179 | debug_info = { 180 | # 'head_feat_norm': head_feat.norm(p=2).item(), 181 | # 'sent_feat_norm': sent_feat.norm(p=2).item(), 182 | 'ref_score_mean': ref_score.mean().item(), 183 | 'ref_score_std': ref_score.std().item() 184 | } 185 | 186 | return ref_score, debug_info 187 | 188 | 189 | class AttVanillaPredictor(nn.Module): 190 | 191 | def __init__(self, att_dropout_p, vis_dropout_p, rank_dropout_p): 192 | super(AttVanillaPredictor, self).__init__() 193 | # Head network with pretrained weight 194 | self.head = make_resnet_layer4() 195 | # Layers of new branch 196 | self.att_fc = nn.Linear(3072, 1, bias=True) 197 | self.vis_fc = nn.Linear(2048, 1024, bias=True) 198 | self.rank_fc = nn.Linear(1024, 1, bias=True) 199 | self.att_dropout = nn.Dropout(p=att_dropout_p, inplace=True) 200 | self.vis_dropout = nn.Dropout(p=vis_dropout_p, inplace=True) 201 | self.rank_dropout = nn.Dropout(p=rank_dropout_p, inplace=True) 202 | self.rnn = nn.GRU(300, 512, bidirectional=True, batch_first=True) 203 | 204 | def forward(self, roi_feat, packed_sent_feat): 205 | """ 206 | 207 | Args: 208 | roi_feat: [N, R, 1024, 7, 7]. 209 | packed_sent_feat: `PackedSequence` object, should be moved to proper device in advance. 210 | 211 | Returns: 212 | ref_score: [N, R]. 
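Note: each RoI attends over the words of the expression; the attention-pooled word features are fused (element-wise product) with the mapped RoI feature to produce one score per RoI.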
213 | 214 | """ 215 | # Extract word feature with RNN 216 | packed_output, _ = self.rnn(packed_sent_feat) 217 | rnn_out, _ = pad_packed_sequence(packed_output, batch_first=True) # [N, S, 1024], [N] 218 | S = rnn_out.size(1) 219 | # Extract visual feature with ResNet Conv-head 220 | N, R, *_ = roi_feat.shape 221 | head_feat = self.head(roi_feat.reshape(N * R, 1024, 7, 7)).reshape(N, R, 2048, 7, 7) 222 | head_pool = head_feat.mean(dim=(3, 4)) # [N, R, 2048] 223 | # Cross-modal attention over words 224 | expanded_rnn_out = rnn_out.unsqueeze(dim=1).expand(-1, R, -1, -1) # [N, R, S, 1024] 225 | expanded_head_pool = head_pool.unsqueeze(dim=2).expand(-1, -1, S, -1) # [N, R, S, 2048] 226 | att_merged_feat = torch.cat((expanded_rnn_out, expanded_head_pool), dim=3) # [N, R, S, 3072] 227 | att_score = self.att_fc(self.att_dropout(att_merged_feat)) # [N, R, S, 1] 228 | att_weight = torch.softmax(att_score, dim=2) # [N, R, S, 1] 229 | sent_feat = torch.sum(att_weight * expanded_rnn_out, dim=2) # [N, R, 1024] 230 | # Compute rank score 231 | head_mapped = self.vis_fc(self.vis_dropout(head_pool)) # [N, R, 1024] 232 | feat_merged = head_mapped * sent_feat # [N, R, 1024] 233 | feat_merged = relu(feat_merged, inplace=True) # [N, R, 1024] 234 | feat_merged = normalize(feat_merged, dim=2) # [N, R, 1024] 235 | ref_score = self.rank_fc(self.rank_dropout(feat_merged)) # [N, R, 1] 236 | ref_score = ref_score.squeeze(2) # [N, R] 237 | 238 | return ref_score, 239 | 240 | 241 | class AttVanillaPredictorV2(nn.Module): 242 | 243 | def __init__(self, att_dropout_p, rank_dropout_p): 244 | super(AttVanillaPredictorV2, self).__init__() 245 | # Head network with pretrained weight 246 | self.head = make_resnet_layer4() 247 | # Layers of new branch 248 | self.vis_a_fc = nn.Sequential( 249 | nn.Linear(2048, 1024, bias=True), 250 | nn.ReLU(inplace=True), 251 | nn.Linear(1024, 512, bias=True) 252 | ) 253 | self.vis_r_fc = nn.Sequential( 254 | nn.Linear(2048, 1024, bias=True), 255 | nn.ReLU(inplace=True), 256 | nn.Linear(1024, 512, bias=True) 257 | ) 258 | self.att_fc = nn.Linear(1024, 1, bias=True) 259 | self.rank_fc = nn.Linear(512, 1, bias=True) 260 | self.att_dropout = nn.Dropout(p=att_dropout_p, inplace=True) 261 | self.rank_dropout = nn.Dropout(p=rank_dropout_p, inplace=True) 262 | self.rnn = nn.GRU(300, 256, bidirectional=True, batch_first=True) 263 | 264 | def forward(self, roi_feat, packed_sent_feat): 265 | """ 266 | 267 | Args: 268 | roi_feat: [N, R, 1024, 7, 7]. 269 | packed_sent_feat: `PackedSequence` object, should be moved to proper device in advance. 270 | 271 | Returns: 272 | ref_score: [N, R]. 
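Note: attention scores at padded word positions are set to -inf before the softmax, so only the real tokens of each expression contribute to the pooled sentence feature.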
273 | 274 | """ 275 | # Extract visual feature with ResNet Conv-head 276 | N, R, *_ = roi_feat.shape 277 | head_feat = self.head(roi_feat.reshape(N * R, 1024, 7, 7)).reshape(N, R, 2048, 7, 7) 278 | head_pool = head_feat.mean(dim=(3, 4)) # [N, R, 2048] 279 | # Extract word feature with RNN 280 | packed_output, _ = self.rnn(packed_sent_feat) 281 | rnn_out, sent_len = pad_packed_sequence(packed_output, batch_first=True) # [N, S, 512], [N] 282 | S = rnn_out.size(1) 283 | sent_mask = (torch.arange(S) + 1).unsqueeze(dim=0).expand(N, -1) > sent_len.unsqueeze(dim=1) 284 | sent_mask = sent_mask[:, None, :, None].expand(-1, R, -1, -1) # [N, R, S, 1] 285 | # Cross-modal attention over words 286 | att_key = self.vis_a_fc(head_pool) # [N, R, 512] 287 | att_key = att_key.unsqueeze(dim=2).expand(-1, -1, S, -1) # [N, R, S, 512] 288 | att_value = rnn_out.unsqueeze(dim=1).expand(-1, R, -1, -1) # [N, R, S, 512] 289 | att_feat = torch.cat((att_key, att_value), dim=3) # [N, R, S, 1024] 290 | att_feat = self.att_dropout(att_feat) # [N, R, S, 1024] 291 | att_score = self.att_fc(att_feat) # [N, R, S, 1] 292 | att_score[sent_mask] = float('-inf') 293 | att_weight = torch.softmax(att_score, dim=2) # [N, R, S, 1] 294 | sent_feat = torch.sum(att_weight * att_value, dim=2) # [N, R, 512] 295 | # Compute rank score 296 | head_mapped = self.vis_r_fc(head_pool) # [N, R, 512] 297 | feat_merged = head_mapped * sent_feat # [N, R, 512] 298 | feat_merged = relu(feat_merged, inplace=True) # [N, R, 512] 299 | feat_merged = normalize(feat_merged, dim=2) # [N, R, 512] 300 | feat_merged = self.rank_dropout(feat_merged) # [N, R, 512] 301 | ref_score = self.rank_fc(feat_merged) # [N, R, 1] 302 | ref_score = ref_score.squeeze(2) # [N, R] 303 | 304 | return ref_score, 305 | 306 | 307 | class RankPredictor(nn.Module): 308 | 309 | def __init__(self): 310 | super(RankPredictor, self).__init__() 311 | self.head = make_resnet_layer4() 312 | self.word_fc = nn.Linear(300, 300, bias=True) 313 | self.head_fc = nn.Linear(2048, 300, bias=True) 314 | self.rank_fc = nn.Linear(300, 1, bias=True) 315 | 316 | def forward(self, roi_feat, word_feat): 317 | """ 318 | 319 | Args: 320 | roi_feat: [N, R, 1024, 7, 7]. 321 | word_feat: [N, S, 300]. 322 | 323 | Returns: 324 | max_rank_score: [N, R]. 325 | max_idx: [N, R]. 
326 | 327 | """ 328 | N, R, *_ = roi_feat.shape 329 | head_feat = self.head(roi_feat.reshape(N*R, 1024, 7, 7)).reshape(N, R, 2048, 7, 7) 330 | head_pool = head_feat.mean(dim=(3, 4)) # [N, R, 2048] 331 | head_mapped = self.head_fc(head_pool) # [N, R, 300] 332 | word_mapped = self.word_fc(word_feat) # [N, S, 300] 333 | head_expanded = head_mapped.unsqueeze(2) # [N, R, 1, 300] 334 | word_expanded = word_mapped.unsqueeze(1) # [N, 1, S, 300] 335 | feat_merged = head_expanded * word_expanded # [N, R, S, 300] 336 | feat_merged = normalize(feat_merged, dim=3) # [N, R, S, 300] 337 | rank_score = self.rank_fc(feat_merged).squeeze(dim=3) # [N, R, S] 338 | max_rank_score, max_idx = rank_score.max(dim=2) # [N, R], [N, R] 339 | return max_rank_score, max_idx 340 | 341 | def rank_parameters(self): 342 | return list(self.head_fc.parameters()) + list(self.word_fc.parameters()) + list(self.rank_fc.parameters()) 343 | 344 | def named_rank_parameters(self): 345 | return list(self.head_fc.named_parameters()) + list(self.word_fc.named_parameters()) \ 346 | + list(self.rank_fc.named_parameters()) 347 | 348 | 349 | class SquashedRankPredictor(nn.Module): 350 | 351 | def __init__(self, dropout_p=0.5): 352 | super(SquashedRankPredictor, self).__init__() 353 | self.head = make_resnet_layer4() 354 | self.word_fc = nn.Linear(300, 300, bias=True) 355 | self.head_fc = nn.Linear(2048, 300, bias=True) 356 | self.rank_fc = nn.Linear(300, 1, bias=True) 357 | self.head_dropout = nn.Dropout(p=dropout_p, inplace=True) 358 | 359 | def forward(self, roi_feat, word_feat): 360 | """ 361 | 362 | Args: 363 | roi_feat: [N, R, 1024, 7, 7]. 364 | word_feat: [N, S, 300]. 365 | 366 | Returns: 367 | max_rank_score: [N, R]. 368 | max_idx: [N, R]. 369 | 370 | """ 371 | N, R, *_ = roi_feat.shape 372 | head_feat = self.head(roi_feat.reshape(N*R, 1024, 7, 7)).reshape(N, R, 2048, 7, 7) 373 | head_pool = head_feat.mean(dim=(3, 4)) # [N, R, 2048] 374 | head_pool = self.head_dropout(head_pool) # [N, R, 2048] 375 | head_mapped = self.head_fc(head_pool) # [N, R, 300] 376 | word_mapped = self.word_fc(word_feat) # [N, S, 300] 377 | head_expanded = head_mapped.unsqueeze(2) # [N, R, 1, 300] 378 | word_expanded = word_mapped.unsqueeze(1) # [N, 1, S, 300] 379 | feat_merged = head_expanded * word_expanded # [N, R, S, 300] 380 | feat_merged = normalize(feat_merged, dim=3) # [N, R, S, 300] 381 | rank_score = self.rank_fc(feat_merged).squeeze(dim=3) # [N, R, S] 382 | max_rank_score, max_idx = rank_score.max(dim=2) # [N, R], [N, R] 383 | sigmoid_rank_score = torch.sigmoid(max_rank_score) 384 | return sigmoid_rank_score, max_idx 385 | # lower_bound = torch.zeros_like(max_rank_score) 386 | # upper_bound = torch.ones_like(max_rank_score) 387 | # scaled_rank_score = torch.min(torch.max(0.25 * max_rank_score + 0.5, lower_bound), upper_bound) 388 | # return scaled_rank_score, max_idx 389 | 390 | def rank_parameters(self): 391 | return list(self.head_fc.parameters()) + list(self.word_fc.parameters()) + list(self.rank_fc.parameters()) 392 | 393 | def named_rank_parameters(self): 394 | return list(self.head_fc.named_parameters()) + list(self.word_fc.named_parameters()) \ 395 | + list(self.rank_fc.named_parameters()) 396 | 397 | 398 | if __name__ == '__main__': 399 | 400 | predictor = AttVanillaPredictorV2(0.5, 0.5) 401 | for k, v in predictor.named_parameters(): 402 | print(k, ':', v.shape) 403 | -------------------------------------------------------------------------------- /lib/rank_utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import pickle 4 | import math 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torchvision.ops import nms 9 | import numpy as np 10 | import h5py 11 | from tqdm import trange 12 | 13 | from utils.misc import mrcn_crop_pool_layer, recursive_jitter_roi, repeat_loader, calculate_iou 14 | 15 | 16 | __all__ = ['RankDataset', 'RankEvalLoader', 'RankEvaluator'] 17 | 18 | 19 | class RankDataset(Dataset): 20 | 21 | # Pre-extracted image feature files: {image_id}.h5 22 | # Format: {'head': (1, 1024, ih, iw), 'im_info': [[ih, iw, scale]]} 23 | # ih == im_height*scale/16.0, iw == im_width*scale/16.0) 24 | HEAD_FEAT_DIR = 'data/head_feats' 25 | BOX_FILE_PATH = 'data/rpn_boxes.pkl' 26 | SCORE_FILE_PATH = 'data/rpn_box_scores.pkl' 27 | CONF_THRESH = 0.05 28 | 29 | def __init__(self, refdb, ctxdb, split, level_num, roi_per_level, negative_num): 30 | Dataset.__init__(self) 31 | self.refs = refdb[split] 32 | self.dataset_splitby = refdb['dataset_splitby'] 33 | self.exp_to_ctx = ctxdb[split] 34 | self.idx_to_glove = np.load('cache/std_glove_{}.npy'.format(refdb['dataset_splitby'])) 35 | self.max_sent_len = 20 if refdb['dataset_splitby'] == 'refcocog_umd' else 10 36 | self.pad_feat = np.zeros(300, dtype=np.float32) 37 | self.level_num = level_num 38 | end_points = np.linspace(0.5, 1.0, num=level_num, endpoint=True).tolist() 39 | self.interval_list = list(zip(end_points[:-1], end_points[1:])) 40 | self.interval_list.insert(0, (0.1, 0.4)) 41 | self.roi_num_list = [negative_num] + (level_num - 1) * [roi_per_level] 42 | with open(self.BOX_FILE_PATH, 'rb') as f: 43 | self.img_to_det_box = pickle.load(f) 44 | with open(self.SCORE_FILE_PATH, 'rb') as f: 45 | self.img_to_det_score = pickle.load(f) 46 | 47 | def __getitem__(self, idx): 48 | """ 49 | 50 | Returns: 51 | roi_feat: [R, 1024, 7, 7] 52 | word_feat: [S, 300] 53 | 54 | """ 55 | # Index refdb entry 56 | ref = self.refs[idx] 57 | image_id = ref['image_id'] 58 | gt_box = ref['bbox'] 59 | exp_id = ref['exp_id'] 60 | ctx_list = self.exp_to_ctx[str(exp_id)]['ctx'] 61 | # Build word features 62 | word_feat, sent_len = self.build_word_feats(ref['tokens']) 63 | # Load image feature 64 | image_h5 = h5py.File(os.path.join(self.HEAD_FEAT_DIR, self.dataset_splitby, '{}.h5'.format(image_id)), 'r') 65 | scaled_h, scaled_w, scale = image_h5['im_info'][0].tolist() 66 | image_feat = torch.tensor(image_h5['head']) # [1, 1024, ih, iw] 67 | # Jitter ROIs 68 | roi_list = self.get_roi_list(image_id, gt_box, ctx_list, scale, scaled_w, scaled_h) 69 | roi_t = torch.tensor(roi_list) 70 | roi_feat = mrcn_crop_pool_layer(image_feat, roi_t) 71 | return roi_feat, word_feat, sent_len 72 | 73 | def __len__(self): 74 | return len(self.refs) 75 | 76 | def get_roi_list(self, image_id, gt_box, ctx_list, scale, scaled_w, scaled_h): 77 | # Bin detection boxes according to IoU and scale them 78 | boxes = self.img_to_det_box[image_id].reshape(-1, 81, 4) 79 | scores = self.img_to_det_score[image_id] 80 | boxes = boxes[:, 1:] # [*, 80, 4] 81 | scores = scores[:, 1:] # [*, 80] 82 | box_list = [[] for _ in range(self.level_num)] 83 | target_list = [gt_box] 84 | for t in ctx_list: 85 | target_list.append(t['box']) 86 | for box in boxes[scores > self.CONF_THRESH]: 87 | iou = max([calculate_iou(t, box) for t in target_list]) 88 | level = math.ceil(max(0, (iou - 0.5)) * (self.level_num - 1) / 0.5) 89 | box_list[level].append(self.scale_roi(box, scale, scaled_w, 
scaled_h)) 90 | # Construct RoI list 91 | scaled_target_list = [self.scale_roi(t, scale, scaled_w, scaled_h) for t in target_list] 92 | roi_list = [] 93 | for (L, R), level_roi_num, level_box_list in zip(self.interval_list, self.roi_num_list, box_list): 94 | sampled_boxes, less_num = self.sample_roi(level_box_list, level_roi_num) 95 | roi_list.extend(sampled_boxes) 96 | for _ in range(less_num): 97 | scaled_t = random.choice(scaled_target_list) 98 | roi_list.append(recursive_jitter_roi(scaled_t, L, R, scaled_w, scaled_h)) 99 | assert len(roi_list) == sum(self.roi_num_list) 100 | return roi_list 101 | 102 | @staticmethod 103 | def scale_roi(roi, scale, scaled_w, scaled_h): 104 | x0, y0, x1, y1 = roi 105 | scaled_x0 = min(x0 * scale, scaled_w) 106 | scaled_y0 = min(y0 * scale, scaled_h) 107 | scaled_x1 = min(x1 * scale, scaled_w) 108 | scaled_y1 = min(y1 * scale, scaled_h) 109 | return scaled_x0, scaled_y0, scaled_x1, scaled_y1 110 | 111 | @staticmethod 112 | def sample_roi(candidate_list, num): 113 | candidate_num = len(candidate_list) 114 | if candidate_num == 0: 115 | return [], num 116 | elif candidate_num <= num: 117 | return candidate_list, num - candidate_num 118 | else: 119 | return random.sample(candidate_list, num), 0 120 | 121 | def build_word_feats(self, tokens): 122 | word_feats = [self.idx_to_glove[wd_idx] for wd_idx in tokens] 123 | word_feats += [self.pad_feat] * max(self.max_sent_len - len(word_feats), 0) 124 | word_feats = torch.tensor(word_feats[:self.max_sent_len]) # [S, 300] 125 | return word_feats, min(len(tokens), self.max_sent_len) 126 | 127 | 128 | class RankEvalLoader: 129 | 130 | BOX_FILE_PATH = 'cache/rpn_boxes.pkl' 131 | SCORE_FILE_PATH = 'cache/rpn_box_scores.pkl' 132 | IMG_FEAT_DIR = 'cache/head_feats/matt-mrcn' 133 | 134 | def __init__(self, refdb, split='val', conf_thresh=.05): 135 | self.refs = refdb[split] 136 | self.img_to_exps = {} 137 | for ref in self.refs: 138 | image_id = ref['image_id'] 139 | if image_id in self.img_to_exps: 140 | self.img_to_exps[image_id].append((ref['exp_id'], ref['tokens'])) 141 | else: 142 | self.img_to_exps[image_id] = [(ref['exp_id'], ref['tokens'])] 143 | with open(self.BOX_FILE_PATH, 'rb') as f: 144 | self.img_to_det_box = pickle.load(f) 145 | with open(self.SCORE_FILE_PATH, 'rb') as f: 146 | self.img_to_det_score = pickle.load(f) 147 | self.idx_to_glove = np.load('cache/glove_{}.npy'.format(refdb['dataset_splitby'])) 148 | self.conf_thresh = conf_thresh 149 | 150 | def __iter__(self): 151 | for image_id, exps in self.img_to_exps.items(): 152 | # Load image feature 153 | image_h5 = h5py.File(os.path.join(self.IMG_FEAT_DIR, '{}.h5'.format(image_id)), 'r') 154 | scale = image_h5['im_info'][0, 2] 155 | image_feat = torch.tensor(image_h5['head']) # [1, 1024, ih, iw] 156 | # RoI-pool positive M-RCNN detections 157 | det_box = self.img_to_det_box[image_id].reshape(-1, 81, 4) # [300, 81, 4] 158 | det_score = self.img_to_det_score[image_id] # [300, 81] 159 | det_box = np.transpose(det_box[:, 1:], axes=[1, 0, 2]) # [80, 300, 4] 160 | det_score = np.transpose(det_score[:, 1:], axes=[1, 0]) # [80, 300] 161 | positive = det_score > self.conf_thresh # [80, 300] 162 | pos_box = torch.tensor(det_box[positive]) # [*, 4] 163 | pos_score = torch.tensor(det_score[positive]) # [*] 164 | cls_num_list = np.sum(positive, axis=1).tolist() # [80] 165 | pos_feat = mrcn_crop_pool_layer(image_feat, pos_box * scale) # [*, 1024, 7, 7] 166 | pos_feat = pos_feat.unsqueeze(0) # [1, *, 1024, 7, 7] 167 | for exp_id, tokens in exps: 168 | # Load word feature 
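# `tokens` holds vocabulary indices, so row-indexing the pre-extracted GloVe matrix below yields one 300-d vector per word of the expression.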
169 | assert isinstance(tokens, list) 170 | sent_feat = torch.tensor(self.idx_to_glove[tokens]) 171 | sent_feat = sent_feat.unsqueeze(0) # [1, *, 300] 172 | yield exp_id, pos_feat, sent_feat, pos_box, pos_score, cls_num_list 173 | 174 | def __len__(self): 175 | return len(self.refs) 176 | 177 | 178 | class RankEvaluator: 179 | 180 | def __init__(self, refdb, split, num_sample, top_N=None, gpu_id=0, alpha=0.15): 181 | """Runtime ref-based hit rate evaluator. 182 | 183 | Args: 184 | refdb: `refdb` dict. 185 | split: Dataset split to evaluate on. 186 | top_N: Select top-N scoring proposals to evaluate. `None` means no selection. Default `None`. 187 | num_sample: Use `num_sample` refs to evaluate hit rate. 188 | 189 | """ 190 | self.dataset_splitby = refdb['dataset_splitby'] 191 | self.exp_to_box = {} 192 | for ref in refdb[split]: 193 | self.exp_to_box[ref['exp_id']] = ref['bbox'] 194 | self.split = split 195 | self.top_N = top_N 196 | loader = RankEvalLoader(refdb, split=split, conf_thresh=0.05) 197 | self.loader = repeat_loader(loader) 198 | self.total = len(loader) 199 | self.num_sample = num_sample 200 | self.device = torch.device('cuda', gpu_id) 201 | self.alpha = alpha 202 | 203 | def eval_hit_rate(self, predictor): 204 | """Estimate hit rate with `num_sample` samples during runtime. 205 | 206 | Args: 207 | predictor: `torch.nn.module` to evaluate. Module should be set to eval mode IN ADVANCE. 208 | All parameters of predictor has to be on the SAME device. 209 | 210 | Returns: 211 | proposal_per_ref: Average proposal number per referring expression. 212 | hit_rate: Estimated hit rate. 213 | 214 | """ 215 | print('{} expressions in {} {} split, using {} samples to evaluate hit rate...' 216 | .format(self.total, self.dataset_splitby, self.split, self.num_sample)) 217 | # Use predictor to score proposals 218 | exp_to_proposals = {} 219 | for _ in trange(self.num_sample, desc='Estimating hit rate', ascii=True): 220 | exp_id, pos_feat, sent_feat, pos_box, pos_score, cls_num_list = next(self.loader) 221 | pos_feat = pos_feat.to(self.device) # [1, R, 1024, 7, 7] 222 | sent_feat = sent_feat.to(self.device) # [1, S, 300] 223 | pos_box = pos_box.to(self.device) 224 | pos_score = pos_score.to(self.device) 225 | with torch.no_grad(): 226 | rank_score, *_ = predictor(pos_feat, sent_feat) # [1, R] 227 | rank_score = rank_score.squeeze(dim=0) 228 | rank_score_list = torch.split(rank_score, cls_num_list, dim=0) 229 | pos_box_list = torch.split(pos_box, cls_num_list, dim=0) 230 | pos_score_list = torch.split(pos_score, cls_num_list, dim=0) 231 | proposals = [] 232 | for cls_rank_score, cls_pos_box, cls_pos_score in zip(rank_score_list, pos_box_list, pos_score_list): 233 | # No positive box under this category 234 | if cls_rank_score.size(0) == 0: 235 | continue 236 | final_score = self.alpha * cls_rank_score + (1 - self.alpha) * cls_pos_score 237 | keep = nms(cls_pos_box, final_score, iou_threshold=0.3) 238 | cls_kept_box = cls_pos_box[keep] 239 | cls_kept_score = final_score[keep] 240 | for box, score in zip(cls_kept_box, cls_kept_score): 241 | proposals.append({'score': score.item(), 'box': box.tolist()}) 242 | exp_to_proposals[exp_id] = proposals 243 | # Estimate hit rate 244 | assert len(exp_to_proposals) == self.num_sample 245 | num_proposal = 0 246 | num_hit = 0 247 | for exp_id, proposals in exp_to_proposals.items(): 248 | ranked_proposals = sorted(proposals, key=lambda p: p['score'], reverse=True)[:self.top_N] 249 | gt_box = self.exp_to_box[exp_id] 250 | num_proposal += len(ranked_proposals) 251 
| for proposal in ranked_proposals: 252 | if calculate_iou(gt_box, proposal['box']) > 0.5: 253 | num_hit += 1 254 | break 255 | proposal_per_ref = num_proposal / self.num_sample 256 | hit_rate = num_hit / self.num_sample 257 | return proposal_per_ref, hit_rate 258 | -------------------------------------------------------------------------------- /lib/refer.py: -------------------------------------------------------------------------------- 1 | __author__ = 'licheng' 2 | 3 | """ 4 | This interface provides access to four datasets: 5 | 1) refclef 6 | 2) refcoco 7 | 3) refcoco+ 8 | 4) refcocog 9 | split by unc and google 10 | The following API functions are defined: 11 | REFER - REFER api class 12 | getRefIds - get ref ids that satisfy given filter conditions. 13 | getAnnIds - get ann ids that satisfy given filter conditions. 14 | getImgIds - get image ids that satisfy given filter conditions. 15 | getCatIds - get category ids that satisfy given filter conditions. 16 | loadRefs - load refs with the specified ref ids. 17 | loadAnns - load anns with the specified ann ids. 18 | loadImgs - load images with the specified image ids. 19 | loadCats - load category names with the specified category ids. 20 | getRefBox - get ref's bounding box [x, y, w, h] given the ref_id 21 | """ 22 | 23 | import sys 24 | import os.path as osp 25 | import json 26 | import pickle 27 | import time 28 | import itertools 29 | 30 | 31 | class REFER: 32 | 33 | def __init__(self, data_root, dataset='refcoco', splitBy='unc'): 34 | self.dataset_split_by = dataset + '_' + splitBy 35 | # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog 36 | # also provide dataset name and splitBy information 37 | # e.g., dataset = 'refcoco', splitBy = 'unc' 38 | print('loading dataset %s into memory...' 
% dataset) 39 | self.ROOT_DIR = osp.abspath(osp.dirname(__file__)) 40 | self.DATA_DIR = osp.join(data_root, dataset) 41 | if dataset in ['refcoco', 'refcoco+', 'refcocog']: 42 | self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014') 43 | elif dataset == 'refclef': 44 | self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12') 45 | else: 46 | print('No refer dataset is called [%s]' % dataset) 47 | sys.exit() 48 | 49 | # load refs from data/dataset/refs(dataset).json 50 | tic = time.time() 51 | ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p') 52 | self.data = {} 53 | self.data['dataset'] = dataset 54 | self.data['refs'] = pickle.load(open(ref_file, 'rb'), encoding='latin1') 55 | 56 | # load annotations from data/dataset/instances.json 57 | instances_file = osp.join(self.DATA_DIR, 'instances.json') 58 | instances = json.load(open(instances_file, 'r')) 59 | self.data['images'] = instances['images'] 60 | self.data['annotations'] = instances['annotations'] 61 | self.data['categories'] = instances['categories'] 62 | 63 | # create index 64 | self.createIndex() 65 | print('DONE (t=%.2fs)' % (time.time()-tic)) 66 | 67 | def createIndex(self): 68 | # create sets of mapping 69 | # 1) Refs: {ref_id: ref} 70 | # 2) Anns: {ann_id: ann} 71 | # 3) Imgs: {image_id: image} 72 | # 4) Cats: {category_id: category_name} 73 | # 5) Sents: {sent_id: sent} 74 | # 6) imgToRefs: {image_id: refs} 75 | # 7) imgToAnns: {image_id: anns} 76 | # 8) refToAnn: {ref_id: ann} 77 | # 9) annToRef: {ann_id: ref} 78 | # 10) catToRefs: {category_id: refs} 79 | # 11) sentToRef: {sent_id: ref} 80 | # 12) sentToTokens: {sent_id: tokens} 81 | print('creating index...') 82 | # fetch info from instances 83 | Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} 84 | for ann in self.data['annotations']: 85 | Anns[ann['id']] = ann 86 | imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann] 87 | for img in self.data['images']: 88 | Imgs[img['id']] = img 89 | for cat in self.data['categories']: 90 | Cats[cat['id']] = cat['name'] 91 | 92 | # fetch info from refs 93 | Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} 94 | Sents, sentToRef, sentToTokens = {}, {}, {} 95 | for ref in self.data['refs']: 96 | # ids 97 | ref_id = ref['ref_id'] 98 | ann_id = ref['ann_id'] 99 | category_id = ref['category_id'] 100 | image_id = ref['image_id'] 101 | 102 | # add mapping related to ref 103 | Refs[ref_id] = ref 104 | imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] 105 | catToRefs[category_id] = catToRefs.get(category_id, []) + [ref] 106 | refToAnn[ref_id] = Anns[ann_id] 107 | annToRef[ann_id] = ref 108 | 109 | # add mapping of sent 110 | for sent in ref['sentences']: 111 | Sents[sent['sent_id']] = sent 112 | sentToRef[sent['sent_id']] = ref 113 | sentToTokens[sent['sent_id']] = sent['tokens'] 114 | 115 | # create class members 116 | self.Refs = Refs 117 | self.Anns = Anns 118 | self.Imgs = Imgs 119 | self.Cats = Cats 120 | self.Sents = Sents 121 | self.imgToRefs = imgToRefs 122 | self.imgToAnns = imgToAnns 123 | self.refToAnn = refToAnn 124 | self.annToRef = annToRef 125 | self.catToRefs = catToRefs 126 | self.sentToRef = sentToRef 127 | self.sentToTokens = sentToTokens 128 | print('index created.') 129 | 130 | def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''): 131 | image_ids = image_ids if type(image_ids) == list else [image_ids] 132 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 133 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 134 | 135 | if 
len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0: 136 | refs = self.data['refs'] 137 | else: 138 | if not len(image_ids) == 0: 139 | refs = [self.imgToRefs[image_id] for image_id in image_ids] 140 | else: 141 | refs = self.data['refs'] 142 | if not len(cat_ids) == 0: 143 | refs = [ref for ref in refs if ref['category_id'] in cat_ids] 144 | if not len(ref_ids) == 0: 145 | refs = [ref for ref in refs if ref['ref_id'] in ref_ids] 146 | if not len(split) == 0: 147 | if split in ['testA', 'testB', 'testC']: 148 | refs = [ref for ref in refs if split[-1] in ref['split']] # we also consider testAB, testBC, ... 149 | elif split in ['testAB', 'testBC', 'testAC']: 150 | refs = [ref for ref in refs if ref['split'] == split] # rarely used I guess... 151 | elif split == 'test': 152 | refs = [ref for ref in refs if 'test' in ref['split']] 153 | elif split == 'train' or split == 'val': 154 | refs = [ref for ref in refs if ref['split'] == split] 155 | else: 156 | print('No such split [%s]' % split) 157 | sys.exit() 158 | ref_ids = [ref['ref_id'] for ref in refs] 159 | return ref_ids 160 | 161 | def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]): 162 | image_ids = image_ids if type(image_ids) == list else [image_ids] 163 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 164 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 165 | 166 | if len(image_ids) == len(cat_ids) == len(ref_ids) == 0: 167 | ann_ids = [ann['id'] for ann in self.data['annotations']] 168 | else: 169 | if not len(image_ids) == 0: 170 | lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns] # list of [anns] 171 | anns = list(itertools.chain.from_iterable(lists)) 172 | else: 173 | anns = self.data['annotations'] 174 | if not len(cat_ids) == 0: 175 | anns = [ann for ann in anns if ann['category_id'] in cat_ids] 176 | ann_ids = [ann['id'] for ann in anns] 177 | if not len(ref_ids) == 0: 178 | # ids = set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids])) 179 | pass 180 | return ann_ids 181 | 182 | def getImgIds(self, ref_ids=[]): 183 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 184 | 185 | if not len(ref_ids) == 0: 186 | image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids])) 187 | else: 188 | image_ids = self.Imgs.keys() 189 | return image_ids 190 | 191 | def getCatIds(self): 192 | return self.Cats.keys() 193 | 194 | def loadRefs(self, ref_ids=[]): 195 | if type(ref_ids) == list: 196 | return [self.Refs[ref_id] for ref_id in ref_ids] 197 | elif type(ref_ids) == int: 198 | return [self.Refs[ref_ids]] 199 | 200 | def loadAnns(self, ann_ids=[]): 201 | if type(ann_ids) == list: 202 | return [self.Anns[ann_id] for ann_id in ann_ids] 203 | elif type(ann_ids) == int or type(ann_ids) == str: # change from `unicode` to `str` 204 | return [self.Anns[ann_ids]] 205 | 206 | def loadImgs(self, image_ids=[]): 207 | if type(image_ids) == list: 208 | return [self.Imgs[image_id] for image_id in image_ids] 209 | elif type(image_ids) == int: 210 | return [self.Imgs[image_ids]] 211 | 212 | def loadCats(self, cat_ids=[]): 213 | if type(cat_ids) == list: 214 | return [self.Cats[cat_id] for cat_id in cat_ids] 215 | elif type(cat_ids) == int: 216 | return [self.Cats[cat_ids]] 217 | 218 | def getRefBox(self, ref_id): 219 | # ref = self.Refs[ref_id] 220 | ann = self.refToAnn[ref_id] 221 | return ann['bbox'] # [x, y, w, h] 222 | -------------------------------------------------------------------------------- 
/lib/vanilla_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import random 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | import h5py 9 | 10 | from utils.misc import calculate_iou, mrcn_crop_pool_layer 11 | 12 | 13 | __all__ = ['DetBoxDataset', 'DetEvalLoader', 'DetEvalTopLoader'] 14 | 15 | 16 | class DetBoxDataset(Dataset): 17 | 18 | HEAD_FEAT_DIR = 'data/head_feats' 19 | BOX_FILE_PATH = 'data/rpn_boxes.pkl' 20 | SCORE_FILE_PATH = 'data/rpn_box_scores.pkl' 21 | CONF_THRESH = 0.05 22 | DELTA_CONF = 0.005 23 | 24 | def __init__(self, refdb, ctxdb, split, roi_per_img): 25 | Dataset.__init__(self) 26 | self.refs = refdb[split] 27 | self.dataset_splitby = refdb['dataset_splitby'] 28 | self.exp_to_ctx = ctxdb[split] 29 | with open(self.BOX_FILE_PATH, 'rb') as f: 30 | self.img_to_det_box = pickle.load(f) 31 | with open(self.SCORE_FILE_PATH, 'rb') as f: 32 | self.img_to_det_score = pickle.load(f) 33 | self.idx_to_glove = np.load('cache/std_glove_{}.npy'.format(refdb['dataset_splitby'])) 34 | self.max_sent_len = 20 if refdb['dataset_splitby'] == 'refcocog_umd' else 10 35 | self.pad_feat = np.zeros(300, dtype=np.float32) 36 | # Number of samples to draw from one image 37 | self.roi_per_img = roi_per_img 38 | 39 | def __getitem__(self, idx): 40 | """ 41 | 42 | Returns: 43 | roi_feats: [R, 1024, 7, 7] 44 | roi_labels: [R] 45 | word_feats: [S, 300] 46 | sent_len: [0] 47 | 48 | """ 49 | # Index refer object 50 | ref = self.refs[idx] 51 | image_id = ref['image_id'] 52 | gt_box = ref['bbox'] 53 | exp_id = ref['exp_id'] 54 | ctx_boxes = [c['box'] for c in self.exp_to_ctx[str(exp_id)]['ctx']] 55 | target_list = [gt_box] + ctx_boxes 56 | pos_rois, neg_rois = self.get_labeled_rois(image_id, target_list) 57 | # Build word features 58 | word_feats, sent_len = self.build_word_feats(ref['tokens']) 59 | # Load image feature 60 | image_h5 = h5py.File(os.path.join(self.HEAD_FEAT_DIR, self.dataset_splitby, '{}.h5'.format(image_id)), 'r') 61 | scale = image_h5['im_info'][0, 2] 62 | image_feat = torch.tensor(image_h5['head']) # [1, 1024, ih, iw] 63 | # Sample ROIs 64 | pos_num = min(len(pos_rois), self.roi_per_img // 2) 65 | neg_num = min(len(neg_rois), self.roi_per_img - pos_num) 66 | pos_num = self.roi_per_img - neg_num 67 | sampled_pos = random.sample(pos_rois, pos_num) 68 | sampled_neg = random.sample(neg_rois, neg_num) 69 | pos_labels = torch.ones(len(sampled_pos), dtype=torch.float) 70 | neg_labels = torch.zeros(len(sampled_neg), dtype=torch.float) 71 | roi_labels = torch.cat([pos_labels, neg_labels], dim=0) # [R] 72 | # Extract head features 73 | sampled_roi = torch.tensor(sampled_pos + sampled_neg) # [R, 4] 74 | sampled_roi.mul_(scale) 75 | roi_feats = mrcn_crop_pool_layer(image_feat, sampled_roi) 76 | return roi_feats, roi_labels, word_feats, sent_len 77 | 78 | def __len__(self): 79 | return len(self.refs) 80 | 81 | def get_labeled_rois(self, image_id, target_list): 82 | boxes = self.img_to_det_box[image_id].reshape(-1, 81, 4) 83 | scores = self.img_to_det_score[image_id] 84 | boxes = boxes[:, 1:] # [*, 80, 4] 85 | scores = scores[:, 1:] # [*, 80] 86 | # boxes = boxes.reshape(-1, 4) 87 | # scores = scores.reshape(-1) 88 | # top_idx = np.argsort(scores)[-self.TOP_N:] 89 | this_thresh = self.CONF_THRESH 90 | positive = scores > this_thresh 91 | while np.sum(positive) < self.roi_per_img: 92 | this_thresh -= self.DELTA_CONF 93 | positive = scores > this_thresh 94 | pos_rois = [] 95 | neg_rois = 
[] 96 | # for box in boxes[top_idx]: 97 | for box in boxes[positive]: 98 | for t in target_list: 99 | if calculate_iou(box, t) >= 0.5: 100 | pos_rois.append(box) 101 | break 102 | else: 103 | neg_rois.append(box) 104 | return pos_rois, neg_rois 105 | 106 | def build_word_feats(self, tokens): 107 | word_feats = [self.idx_to_glove[wd_idx] for wd_idx in tokens] 108 | word_feats += [self.pad_feat] * max(self.max_sent_len - len(word_feats), 0) 109 | word_feats = torch.tensor(word_feats[:self.max_sent_len]) # [S, 300] 110 | return word_feats, min(len(tokens), self.max_sent_len) 111 | 112 | 113 | class DetBoxDatasetNoCtx(Dataset): 114 | 115 | HEAD_FEAT_DIR = 'cache/head_feats/matt-mrcn' 116 | BOX_FILE_PATH = 'cache/rpn_boxes.pkl' 117 | SCORE_FILE_PATH = 'cache/rpn_box_scores.pkl' 118 | CONF_THRESH = 0.05 119 | DELTA_CONF = 0.005 120 | 121 | def __init__(self, refdb, split, roi_per_img): 122 | Dataset.__init__(self) 123 | self.refs = refdb[split] 124 | with open(self.BOX_FILE_PATH, 'rb') as f: 125 | self.img_to_det_box = pickle.load(f) 126 | with open(self.SCORE_FILE_PATH, 'rb') as f: 127 | self.img_to_det_score = pickle.load(f) 128 | self.idx_to_glove = np.load('cache/std_glove_{}.npy'.format(refdb['dataset_splitby'])) 129 | self.max_sent_len = 20 if refdb['dataset_splitby'] == 'refcocog_umd' else 10 130 | self.pad_feat = np.zeros(300, dtype=np.float32) 131 | # Number of samples to draw from one image 132 | self.roi_per_img = roi_per_img 133 | 134 | def __getitem__(self, idx): 135 | """ 136 | 137 | Returns: 138 | roi_feats: [R, 1024, 7, 7] 139 | roi_labels: [R] 140 | word_feats: [S, 300] 141 | sent_len: [0] 142 | 143 | """ 144 | # Index refer object 145 | ref = self.refs[idx] 146 | image_id = ref['image_id'] 147 | gt_box = ref['bbox'] 148 | pos_rois, neg_rois = self.get_labeled_rois(image_id, gt_box) 149 | # Build word features 150 | word_feats, sent_len = self.build_word_feats(ref['tokens']) 151 | # Load image feature 152 | image_h5 = h5py.File(os.path.join(self.HEAD_FEAT_DIR, '{}.h5'.format(image_id)), 'r') 153 | scale = image_h5['im_info'][0, 2] 154 | image_feat = torch.tensor(image_h5['head']) # [1, 1024, ih, iw] 155 | # Sample ROIs 156 | pos_num = min(len(pos_rois), self.roi_per_img // 2) 157 | neg_num = min(len(neg_rois), self.roi_per_img - pos_num) 158 | pos_num = self.roi_per_img - neg_num 159 | sampled_pos = random.sample(pos_rois, pos_num) 160 | sampled_neg = random.sample(neg_rois, neg_num) 161 | pos_labels = torch.ones(len(sampled_pos), dtype=torch.float) 162 | neg_labels = torch.zeros(len(sampled_neg), dtype=torch.float) 163 | roi_labels = torch.cat([pos_labels, neg_labels], dim=0) # [R] 164 | # Extract head features 165 | sampled_roi = torch.tensor(sampled_pos + sampled_neg) # [R, 4] 166 | sampled_roi.mul_(scale) 167 | roi_feats = mrcn_crop_pool_layer(image_feat, sampled_roi) 168 | return roi_feats, roi_labels, word_feats, sent_len 169 | 170 | def __len__(self): 171 | return len(self.refs) 172 | 173 | def get_labeled_rois(self, image_id, gt_box): 174 | boxes = self.img_to_det_box[image_id].reshape(-1, 81, 4) 175 | scores = self.img_to_det_score[image_id] 176 | boxes = boxes[:, 1:] # [*, 80, 4] 177 | scores = scores[:, 1:] # [*, 80] 178 | this_thresh = self.CONF_THRESH 179 | positive = scores > this_thresh 180 | while np.sum(positive) < self.roi_per_img: 181 | this_thresh -= self.DELTA_CONF 182 | positive = scores > this_thresh 183 | pos_rois = [] 184 | neg_rois = [] 185 | # for box in boxes[top_idx]: 186 | for box in boxes[positive]: 187 | if calculate_iou(box, gt_box) >= 0.5: 188 | 
pos_rois.append(box) 189 | else: 190 | neg_rois.append(box) 191 | return pos_rois, neg_rois 192 | 193 | def build_word_feats(self, tokens): 194 | word_feats = [self.idx_to_glove[wd_idx] for wd_idx in tokens] 195 | word_feats += [self.pad_feat] * max(self.max_sent_len - len(word_feats), 0) 196 | word_feats = torch.tensor(word_feats[:self.max_sent_len]) # [S, 300] 197 | return word_feats, min(len(tokens), self.max_sent_len) 198 | 199 | 200 | class DetEvalLoader: 201 | 202 | BOX_FILE_PATH = 'data/rpn_boxes.pkl' 203 | SCORE_FILE_PATH = 'data/rpn_box_scores.pkl' 204 | IMG_FEAT_DIR = 'data/head_feats' 205 | CONF_THRESH = 0.05 206 | DELTA_CONF = 0.005 207 | 208 | def __init__(self, refdb, split='val', gpu_id=0): 209 | self.dataset_splitby = refdb['dataset_splitby'] 210 | self.refs = refdb[split] 211 | self.img_to_exps = {} 212 | for ref in self.refs: 213 | image_id = ref['image_id'] 214 | if image_id in self.img_to_exps: 215 | self.img_to_exps[image_id].append((ref['exp_id'], ref['tokens'])) 216 | else: 217 | self.img_to_exps[image_id] = [(ref['exp_id'], ref['tokens'])] 218 | with open(self.BOX_FILE_PATH, 'rb') as f: 219 | self.img_to_det_box = pickle.load(f) 220 | with open(self.SCORE_FILE_PATH, 'rb') as f: 221 | self.img_to_det_score = pickle.load(f) 222 | self.idx_to_glove = np.load('cache/std_glove_{}.npy'.format(refdb['dataset_splitby'])) 223 | self.device = torch.device('cuda', gpu_id) 224 | 225 | def __iter__(self): 226 | # Fetch ref info 227 | for image_id, exps in self.img_to_exps.items(): 228 | # Load image feature 229 | image_h5 = h5py.File(os.path.join(self.IMG_FEAT_DIR, self.dataset_splitby, '{}.h5'.format(image_id)), 'r') 230 | scale = image_h5['im_info'][0, 2] 231 | image_feat = torch.tensor(image_h5['head']) # [1, 1024, ih, iw] 232 | # RoI-pool positive M-RCNN detections 233 | det_box = self.img_to_det_box[image_id].reshape(-1, 81, 4) # [300, 81, 4] 234 | det_score = self.img_to_det_score[image_id] # [300, 81] 235 | det_box = np.transpose(det_box[:, 1:], axes=[1, 0, 2]) # [80, 300, 4] 236 | det_score = np.transpose(det_score[:, 1:], axes=[1, 0]) # [80, 300] 237 | this_thresh = self.CONF_THRESH 238 | positive = det_score > this_thresh # [80, 300] 239 | while np.sum(positive) == 0: 240 | this_thresh -= self.DELTA_CONF 241 | positive = det_score > this_thresh # [80, 300] 242 | pos_box = torch.tensor(det_box[positive]) # [*, 4] 243 | pos_score = torch.tensor(det_score[positive], device=self.device) # [*] 244 | cls_num_list = np.sum(positive, axis=1).tolist() # [80] 245 | pos_feat = mrcn_crop_pool_layer(image_feat, pos_box * scale) # [*, 1024, 7, 7] 246 | pos_feat = pos_feat.to(self.device).unsqueeze(0) # [1, *, 1024, 7, 7] 247 | pos_box = pos_box.to(self.device) 248 | for exp_id, tokens in exps: 249 | # Load word feature 250 | assert isinstance(tokens, list) 251 | sent_feat = torch.tensor(self.idx_to_glove[tokens], device=self.device) 252 | sent_feat = sent_feat.unsqueeze(0) # [1, *, 300] 253 | yield exp_id, pos_feat, sent_feat, pos_box, pos_score, cls_num_list 254 | 255 | def __len__(self): 256 | return len(self.refs) 257 | 258 | 259 | class DetEvalTopLoader: 260 | 261 | BOX_FILE_PATH = 'cache/rpn_boxes.pkl' 262 | SCORE_FILE_PATH = 'cache/rpn_box_scores.pkl' 263 | IMG_FEAT_DIR = 'cache/head_feats/matt-mrcn' 264 | 265 | def __init__(self, refdb, split='val', gpu_id=0, top_N=200): 266 | self.refs = refdb[split] 267 | self.img_to_exps = {} 268 | for ref in self.refs: 269 | image_id = ref['image_id'] 270 | if image_id in self.img_to_exps: 271 | 
self.img_to_exps[image_id].append((ref['exp_id'], ref['tokens'])) 272 | else: 273 | self.img_to_exps[image_id] = [(ref['exp_id'], ref['tokens'])] 274 | with open(self.BOX_FILE_PATH, 'rb') as f: 275 | self.img_to_det_box = pickle.load(f) 276 | with open(self.SCORE_FILE_PATH, 'rb') as f: 277 | self.img_to_det_score = pickle.load(f) 278 | self.idx_to_glove = np.load('cache/std_glove_{}.npy'.format(refdb['dataset_splitby'])) 279 | self.device = torch.device('cuda', gpu_id) 280 | self.top_N = top_N 281 | 282 | def __iter__(self): 283 | # Fetch ref info 284 | for image_id, exps in self.img_to_exps.items(): 285 | # Load image feature 286 | image_h5 = h5py.File(os.path.join(self.IMG_FEAT_DIR, '{}.h5'.format(image_id)), 'r') 287 | scale = image_h5['im_info'][0, 2] 288 | image_feat = torch.tensor(image_h5['head']) # [1, 1024, ih, iw] 289 | # RoI-pool positive M-RCNN detections 290 | det_box = self.img_to_det_box[image_id].reshape(-1, 81, 4) # [300, 81, 4] 291 | det_score = self.img_to_det_score[image_id] # [300, 81] 292 | det_box = np.transpose(det_box[:, 1:], axes=[1, 0, 2]) # [80, 300, 4] 293 | det_score = np.transpose(det_score[:, 1:], axes=[1, 0]) # [80, 300] 294 | 295 | this_thresh = np.sort(det_score, axis=None)[-self.top_N] 296 | positive = det_score >= this_thresh # [80, 300] 297 | pos_box = torch.tensor(det_box[positive]) # [*, 4] 298 | pos_score = torch.tensor(det_score[positive], device=self.device) # [*] 299 | cls_num_list = np.sum(positive, axis=1).tolist() # [80] 300 | 301 | pos_feat = mrcn_crop_pool_layer(image_feat, pos_box * scale) # [*, 1024, 7, 7] 302 | pos_feat = pos_feat.to(self.device).unsqueeze(0) # [1, *, 1024, 7, 7] 303 | pos_box = pos_box.to(self.device) 304 | for exp_id, tokens in exps: 305 | # Load word feature 306 | assert isinstance(tokens, list) 307 | sent_feat = torch.tensor(self.idx_to_glove[tokens], device=self.device) 308 | sent_feat = sent_feat.unsqueeze(0) # [1, *, 300] 309 | yield exp_id, pos_feat, sent_feat, pos_box, pos_score, cls_num_list 310 | 311 | def __len__(self): 312 | return len(self.refs) 313 | 314 | 315 | def _test(): 316 | import json 317 | from tqdm import tqdm 318 | refdb = json.load(open('cache/refdb_refcoco_unc_nopos.json', 'r')) 319 | ctxdb = json.load(open('cache/ctxdb_refcoco_unc.json', 'r')) 320 | dataset = DetBoxDataset(refdb, ctxdb, 'train') 321 | neg_num, pos_num, total_num = [], [], [] 322 | for pos_rois, neg_rois in tqdm(dataset, ascii=True): 323 | neg_num.append(len(neg_rois)) 324 | pos_num.append(len(pos_rois)) 325 | total_num.append(len(pos_rois) + len(neg_rois)) 326 | print('neg min: {}, neg max: {}, neg mean: {}'.format(min(neg_num), max(neg_num), sum(neg_num) / len(neg_num))) 327 | print('pos min: {}, pos max: {}, pos mean: {}'.format(min(pos_num), max(pos_num), sum(pos_num) / len(pos_num))) 328 | print('total min: {}, total max: {}, total mean: {}'.format(min(total_num), max(total_num), sum(total_num) / len(total_num))) 329 | 330 | 331 | if __name__ == '__main__': _test() 332 | -------------------------------------------------------------------------------- /output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChopinSharp/ref-nms/8f83f350c497d0ef875c778a8ce76725552abb3c/output/.gitkeep -------------------------------------------------------------------------------- /scripts/prepare_data.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PYTHONPATH:$PWD 2 | python tools/build_vocab.py && python 
tools/build_refdb.py && python tools/build_ctxdb.py -------------------------------------------------------------------------------- /tb/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChopinSharp/ref-nms/8f83f350c497d0ef875c778a8ce76725552abb3c/tb/.gitkeep -------------------------------------------------------------------------------- /tools/build_ctxdb.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | import spacy 5 | from tqdm import tqdm 6 | 7 | from lib.refer import REFER 8 | from utils.constants import CAT_ID_TO_NAME, EVAL_SPLITS_DICT 9 | from utils.misc import xywh_to_xyxy, calculate_iou 10 | 11 | 12 | POS_OF_INTEREST = {'NOUN', 'NUM', 'PRON', 'PROPN'} 13 | 14 | 15 | def load_glove_feats(): 16 | glove_path = 'data/glove.840B.300d.txt' 17 | print('loading GloVe feature from {}'.format(glove_path)) 18 | glove_dict = {} 19 | with open(glove_path, 'r') as f: 20 | with tqdm(total=2196017, desc='Loading GloVe', ascii=True) as pbar: 21 | for line in f: 22 | tokens = line.split(' ') 23 | assert len(tokens) == 301 24 | word = tokens[0] 25 | vec = list(map(lambda x: float(x), tokens[1:])) 26 | glove_dict[word] = vec 27 | pbar.update(1) 28 | return glove_dict 29 | 30 | 31 | def cosine_similarity(feat_a, feat_b): 32 | return np.sum(feat_a * feat_b) / np.sqrt(np.sum(feat_a * feat_a) * np.sum(feat_b * feat_b)) 33 | 34 | 35 | def build_ctxdb(dataset, split_by): 36 | dataset_splitby = '{}_{}'.format(dataset, split_by) 37 | # Load refer 38 | refer = REFER('data/refer', dataset, split_by) 39 | # Load GloVe feature 40 | glove_dict = load_glove_feats() 41 | # Construct COCO category GloVe feature 42 | cat_id_to_glove = {} 43 | for cat_id, cat_name in CAT_ID_TO_NAME.items(): 44 | cat_id_to_glove[cat_id] = [np.array(glove_dict[t], dtype=np.float32) for t in cat_name.split(' ')] 45 | # Spacy to extract POS tags 46 | nlp = spacy.load('en_core_web_sm') 47 | # Go through the refdb 48 | ctxdb = {} 49 | for split in (['train'] + EVAL_SPLITS_DICT[dataset_splitby]): 50 | exp_to_ctx = {} 51 | gt_miss_num, empty_num, sent_num = 0, 0, 0 52 | coco_box_num_list, ctx_box_num_list = [], [] 53 | ref_ids = refer.getRefIds(split=split) 54 | for ref_id in tqdm(ref_ids, ascii=True, desc=split): 55 | ref = refer.Refs[ref_id] 56 | image_id = ref['image_id'] 57 | gt_box = xywh_to_xyxy(refer.Anns[ref['ann_id']]['bbox']) 58 | gt_cat = refer.Anns[ref['ann_id']]['category_id'] 59 | for sent in ref['sentences']: 60 | sent_num += 1 61 | sent_id = sent['sent_id'] 62 | doc = nlp(sent['sent']) 63 | noun_tokens = [token.text for token in doc if token.pos_ in POS_OF_INTEREST] 64 | # print('SENT', sent['sent']) 65 | # print('NOUN TOKENS', noun_tokens) 66 | noun_glove_list = [np.array(glove_dict[t], dtype=np.float32) for t in noun_tokens if t in glove_dict] 67 | gt_hit = False 68 | ctx_list = [] 69 | for ann in refer.imgToAnns[image_id]: 70 | ann_glove_list = cat_id_to_glove[ann['category_id']] 71 | cos_sim_list = [cosine_similarity(ann_glove, noun_glove) 72 | for ann_glove in ann_glove_list 73 | for noun_glove in noun_glove_list] 74 | # print(CAT_ID_TO_NAME[ann['category_id']], cos_sim_list) 75 | max_cos_sim = max(cos_sim_list, default=0.) 
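# A COCO annotation in the image counts as mentioned when the best GloVe cosine similarity between its category name and the expression's noun tokens exceeds 0.4; mentioned boxes other than the referent itself (IoU > 0.9 with the ground-truth box) are collected as context objects.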
76 | if max_cos_sim > 0.4: 77 | ann_box = xywh_to_xyxy(ann['bbox']) 78 | if calculate_iou(ann_box, gt_box) > 0.9: 79 | gt_hit = True 80 | else: 81 | ctx_list.append({'box': ann_box, 'cat_id': ann['category_id']}) 82 | if not gt_hit: 83 | gt_miss_num += 1 84 | if not ctx_list: 85 | empty_num += 1 86 | exp_to_ctx[sent_id] = {'gt': {'box': gt_box, 'cat_id': gt_cat}, 'ctx': ctx_list} 87 | coco_box_num_list.append(len(refer.imgToAnns[image_id])) 88 | ctx_box_num_list.append(len(ctx_list) + 1) 89 | print('GT miss: {} out of {}'.format(gt_miss_num, sent_num)) 90 | print('empty ctx: {} out of {}'.format(empty_num, sent_num)) 91 | print('COCO box per sentence: {:.3f}'.format(sum(coco_box_num_list) / len(coco_box_num_list))) 92 | print('ctx box per sentence: {:.3f}'.format(sum(ctx_box_num_list) / len(ctx_box_num_list))) 93 | ctxdb[split] = exp_to_ctx 94 | # Save results 95 | save_path = 'cache/std_ctxdb_{}.json'.format(dataset_splitby) 96 | print('saving ctxdb to {}'.format(save_path)) 97 | with open(save_path, 'w') as f: 98 | json.dump(ctxdb, f) 99 | 100 | 101 | def main(): 102 | print('building ctxdb...') 103 | for dataset, split_by in [('refcoco', 'unc'), ('refcoco+', 'unc'), ('refcocog', 'umd')]: 104 | print('building {}_{}...'.format(dataset, split_by)) 105 | build_ctxdb(dataset, split_by) 106 | print() 107 | 108 | 109 | main() 110 | -------------------------------------------------------------------------------- /tools/build_refdb.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | 5 | from lib.refer import REFER 6 | from utils.misc import xywh_to_xyxy 7 | 8 | 9 | DATASET_SPLITS = { 10 | 'refcoco_unc': ['train', 'val', 'testA', 'testB'], 11 | 'refcoco+_unc': ['train', 'val', 'testA', 'testB'], 12 | 'refcocog_umd': ['train', 'val', 'test'] 13 | } 14 | 15 | 16 | def build_refdb(dataset, split_by): 17 | # Load refer data 18 | refer = REFER('data/refer', dataset, split_by) 19 | # Load vocab 20 | with open('cache/std_vocab_{}_{}.txt'.format(dataset, split_by)) as f: 21 | idx_to_wd = [wd[:-1] for wd in f.readlines()] # trim off newline 22 | wd_to_idx = {} 23 | for idx, wd in enumerate(idx_to_wd): 24 | wd_to_idx[wd] = idx 25 | # Build refdb 26 | dataset_splitby = '{}_{}'.format(dataset, split_by) 27 | data = {'dataset_splitby': dataset_splitby} 28 | for split in DATASET_SPLITS[dataset_splitby]: 29 | split_data = [] 30 | for ref_id in refer.getRefIds(split=split): 31 | ref = refer.Refs[ref_id] 32 | image_id = ref['image_id'] 33 | ann_id = ref['ann_id'] 34 | ann = refer.Anns[ann_id] 35 | bbox = xywh_to_xyxy(ann['bbox']) 36 | # Filter with POS 37 | for sent in ref['sentences']: 38 | sent_id, tokens = sent['sent_id'], sent['tokens'] 39 | # Encode with vocab 40 | encoded_tokens = [wd_to_idx[wd] if wd in wd_to_idx else 0 for wd in tokens] 41 | split_data.append({ 42 | 'exp_id': sent_id, 43 | 'ref_id': ref_id, 44 | 'image_id': image_id, 45 | 'bbox': bbox, 46 | 'tokens': encoded_tokens 47 | }) 48 | data[split] = split_data 49 | # Print out statistics 50 | print('STATS for {}'.format(dataset_splitby)) 51 | for split in DATASET_SPLITS[dataset_splitby]: 52 | ref_num = len({ref['ref_id'] for ref in data[split]}) 53 | sent_num = len(data[split]) 54 | avg_sent_num = sent_num / ref_num 55 | token_len = np.array([len(ref['tokens']) for ref in data[split]], dtype=np.float32) 56 | print('[{}]'.format(split)) 57 | print('ref_num={}, avg_sent_num={:.4f}'.format(ref_num, avg_sent_num)) 58 | print('token_len: mean={:.4f}, 
std={:.4f}'.format(token_len.mean(), token_len.std())) 59 | # Save refdb to json file 60 | refdb_save_path = 'cache/std_refdb_{}.json'.format(dataset_splitby) 61 | print('saving refdb to file: {}'.format(refdb_save_path)) 62 | with open(refdb_save_path, 'w') as f: 63 | json.dump(data, f) 64 | 65 | 66 | def main(): 67 | print('building refdb...') 68 | for dataset, split_by in [('refcoco', 'unc'), ('refcoco+', 'unc'), ('refcocog', 'umd')]: 69 | print('building {}_{}...'.format(dataset, split_by)) 70 | build_refdb(dataset, split_by) 71 | print() 72 | 73 | 74 | main() 75 | -------------------------------------------------------------------------------- /tools/build_vocab.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | 4 | from lib.refer import REFER 5 | 6 | 7 | GLOVE_WORD_NUM = 2196017 8 | GLOVE_FILE = 'data/glove.840B.300d.txt' 9 | VOCAB_THRESHOLD = 2 10 | VOCAB_SAVE_PATH = 'cache/std_vocab_{}_{}.txt' 11 | GLOVE_SAVE_PATH = 'cache/std_glove_{}_{}.npy' 12 | 13 | 14 | def load_glove_feats(): 15 | glove_dict = {} # from word of to vector of 16 | with open(GLOVE_FILE, 'r') as f: 17 | with tqdm(total=GLOVE_WORD_NUM, desc='Loading GloVe', ascii=True) as pbar: 18 | for line in f: 19 | tokens = line.split(' ') 20 | assert len(tokens) == 301 21 | word = tokens[0] 22 | vec = list(map(lambda x: float(x), tokens[1:])) 23 | glove_dict[word] = vec 24 | pbar.update(1) 25 | return glove_dict 26 | 27 | 28 | def build_vocabulary(dataset, split_by, glove_dict): 29 | # load refer 30 | refer = REFER('data/refer', dataset, split_by) 31 | 32 | # filter corpus by frequency and GloVe 33 | word_count = {} 34 | for ref in refer.Refs.values(): 35 | for sent in ref['sentences']: 36 | for word in sent['tokens']: 37 | word_count[word] = word_count.get(word, 0) + 1 38 | vocab, typo, rare = [], [], [] 39 | for wd, n in word_count.items(): 40 | if n < VOCAB_THRESHOLD: 41 | rare.append(wd) 42 | else: 43 | if wd in glove_dict: 44 | vocab.append(wd) 45 | else: 46 | typo.append(wd) 47 | assert len(vocab) + len(typo) + len(rare) == len(word_count) 48 | rare_count = sum([word_count[wd] for wd in rare]) 49 | typo_count = sum([word_count[wd] for wd in typo]) 50 | total_words = sum(word_count.values()) 51 | print('number of good words: {}'.format(len(vocab))) 52 | print('number of rare words: {}/{} = {:.2f}%'.format( 53 | len(rare), len(word_count), len(rare)*100/len(word_count))) 54 | print('number of typo words: {}/{} = {:.2f}%'.format( 55 | len(typo), len(word_count), len(typo)*100/len(word_count))) 56 | print('number of UNKs in sentences: ({}+{})/{} = {:.2f}%'.format( 57 | rare_count, typo_count, total_words, (rare_count+typo_count)*100/total_words)) 58 | 59 | # sort vocab and construct glove feats 60 | vocab = sorted(vocab) 61 | vocab_glove = [] 62 | for wd in vocab: 63 | vocab_glove.append(glove_dict[wd]) 64 | vocab.insert(0, '') 65 | vocab_glove.insert(0, [0.] 
* 300) 66 | vocab_glove = np.array(vocab_glove, dtype=np.float32) 67 | 68 | # save vocab and glove feats 69 | vocab_save_path = VOCAB_SAVE_PATH.format(dataset, split_by) 70 | glove_save_path = GLOVE_SAVE_PATH.format(dataset, split_by) 71 | print('saving vocab in {}'.format(vocab_save_path)) 72 | with open(vocab_save_path, 'w') as f: 73 | for wd in vocab: 74 | f.write(wd + '\n') 75 | print('saving vocab glove in {}'.format(glove_save_path)) 76 | np.save(glove_save_path, vocab_glove) 77 | 78 | 79 | def main(): 80 | print('building vocab...') 81 | glove_feats = load_glove_feats() 82 | for dataset, split_by in [('refcoco', 'unc'), ('refcoco+', 'unc'), ('refcocog', 'umd')]: 83 | print('building {}_{}...'.format(dataset, split_by)) 84 | build_vocabulary(dataset, split_by, glove_feats) 85 | print() 86 | 87 | 88 | main() 89 | -------------------------------------------------------------------------------- /tools/eval_proposal_ctx_recall.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import json 4 | 5 | from utils.hit_rate_utils import CtxHitRateEvaluator 6 | from utils.constants import EVAL_SPLITS_DICT 7 | from lib.refer import REFER 8 | 9 | 10 | def threshold_with_confidence(exp_to_proposals, conf): 11 | results = {} 12 | for exp_id, proposals in exp_to_proposals.items(): 13 | assert len(proposals) >= 1 14 | sorted_proposals = sorted(proposals, key=lambda p: p['score'], reverse=True) 15 | thresh_proposals = [sorted_proposals[0]] 16 | for prop in sorted_proposals[1:]: 17 | if prop['score'] > conf: 18 | thresh_proposals.append(prop) 19 | else: 20 | break 21 | results[exp_id] = thresh_proposals 22 | return results 23 | 24 | 25 | def main(args): 26 | dataset_splitby = '{}_{}'.format(args.dataset, args.split_by) 27 | eval_splits = EVAL_SPLITS_DICT[dataset_splitby] 28 | # Load proposals 29 | used_proposal_path = 'cache/proposals_{}_{}_{}.pkl'.format(args.m, args.dataset, args.tid) 30 | print('loading {} proposals from {}...'.format(args.m, used_proposal_path)) 31 | with open(used_proposal_path, 'rb') as f: 32 | used_proposal_dict = pickle.load(f) 33 | # Load refer 34 | refer = REFER('data/refer', dataset=args.dataset, splitBy=args.split_by) 35 | # Load ctxdb 36 | with open('cache/std_ctxdb_{}.json'.format(dataset_splitby), 'r') as f: 37 | ctxdb = json.load(f) 38 | # Evaluate hit rate 39 | print('Context object recall on', dataset_splitby) 40 | evaluator = CtxHitRateEvaluator(refer, ctxdb, top_N=None, threshold=args.thresh) 41 | for split in eval_splits: 42 | exp_to_proposals = used_proposal_dict[split] 43 | exp_to_proposals = threshold_with_confidence(exp_to_proposals, args.conf) 44 | proposal_per_ref, hit_rate = evaluator.eval_hit_rate(split, exp_to_proposals) 45 | print('[{:5s}] hit rate: {:.2f} @ {:.2f}'.format(split, hit_rate*100, proposal_per_ref)) 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--m', type=str, required=True) 51 | parser.add_argument('--dataset', default='refcoco') 52 | parser.add_argument('--split-by', default='unc') 53 | parser.add_argument('--tid', type=str, required=True) 54 | parser.add_argument('--top-N', type=int, default=None) 55 | parser.add_argument('--thresh', type=float, default=0.5) 56 | parser.add_argument('--conf', type=float, required=True) 57 | main(parser.parse_args()) 58 | -------------------------------------------------------------------------------- /tools/eval_proposal_hit_rate.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | from utils.hit_rate_utils import NewHitRateEvaluator 5 | from utils.constants import EVAL_SPLITS_DICT 6 | from lib.refer import REFER 7 | 8 | 9 | def threshold_with_confidence(exp_to_proposals, conf): 10 | results = {} 11 | for exp_id, proposals in exp_to_proposals.items(): 12 | assert len(proposals) >= 1 13 | sorted_proposals = sorted(proposals, key=lambda p: p['score'], reverse=True) 14 | thresh_proposals = [sorted_proposals[0]] 15 | for prop in sorted_proposals[1:]: 16 | if prop['score'] > conf: 17 | thresh_proposals.append(prop) 18 | else: 19 | break 20 | results[exp_id] = thresh_proposals 21 | return results 22 | 23 | 24 | def main(args): 25 | dataset_splitby = '{}_{}'.format(args.dataset, args.split_by) 26 | eval_splits = EVAL_SPLITS_DICT[dataset_splitby] 27 | # Load proposals 28 | proposal_path = 'cache/proposals_{}_{}_{}.pkl'.format(args.m, args.dataset, args.tid) 29 | print('loading {} proposals from {}...'.format(args.m, proposal_path)) 30 | with open(proposal_path, 'rb') as f: 31 | proposal_dict = pickle.load(f) 32 | # Load refer 33 | refer = REFER('data/refer', dataset=args.dataset, splitBy=args.split_by) 34 | # Evaluate hit rate 35 | print('Hit rate on {}\n'.format(dataset_splitby)) 36 | evaluator = NewHitRateEvaluator(refer, top_N=None, threshold=args.thresh) 37 | print('conf: {:.3f}'.format(args.conf)) 38 | for split in eval_splits: 39 | exp_to_proposals = proposal_dict[split] 40 | exp_to_proposals = threshold_with_confidence(exp_to_proposals, args.conf) 41 | proposal_per_ref, hit_rate = evaluator.eval_hit_rate(split, exp_to_proposals) 42 | print('[{:5s}] hit rate: {:.2f} @ {:.2f}'.format(split, hit_rate*100, proposal_per_ref)) 43 | 44 | 45 | if __name__ == '__main__': 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--m', type=str, required=True) 48 | parser.add_argument('--dataset', default='refcoco') 49 | parser.add_argument('--split-by', default='unc') 50 | parser.add_argument('--tid', type=str, required=True) 51 | parser.add_argument('--thresh', type=float, default=0.5) 52 | parser.add_argument('--conf', type=float, required=True) 53 | main(parser.parse_args()) 54 | -------------------------------------------------------------------------------- /tools/save_matt_dets.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import json 4 | 5 | from utils.constants import EVAL_SPLITS_DICT, COCO_CAT_NAMES, CAT_NAME_TO_ID 6 | from lib.refer import REFER 7 | 8 | 9 | def threshold_with_top_N(exp_to_proposals, top_N): 10 | results = {} 11 | for exp_id, proposals in exp_to_proposals.items(): 12 | assert len(proposals) >= 1 13 | results[exp_id] = sorted(proposals, key=lambda p: p['score'], reverse=True)[:top_N] 14 | return results 15 | 16 | 17 | def threshold_with_confidence(exp_to_proposals, conf): 18 | results = {} 19 | for exp_id, proposals in exp_to_proposals.items(): 20 | assert len(proposals) >= 1 21 | sorted_proposals = sorted(proposals, key=lambda p: p['score'], reverse=True) 22 | thresh_proposals = [sorted_proposals[0]] 23 | for prop in sorted_proposals[1:]: 24 | if prop['score'] > conf: 25 | thresh_proposals.append(prop) 26 | else: 27 | break 28 | results[exp_id] = thresh_proposals 29 | return results 30 | 31 | 32 | def main(args): 33 | # Setup 34 | assert args.top_N is None or args.conf is None 35 | assert args.top_N is not None or args.conf is not None 36 | 
dataset_splitby = '{}_{}'.format(args.dataset, args.split_by) 37 | refer = REFER('data/refer', dataset=args.dataset, splitBy=args.split_by) 38 | det_id = 0 39 | matt_dets = [] 40 | eval_splits = EVAL_SPLITS_DICT[dataset_splitby] 41 | 42 | # Add model detections for valid sentences 43 | proposal_path = 'cache/proposals_{}_{}_{}.pkl'.format(args.m, args.dataset, args.tid) 44 | print('loading proposals from {}...'.format(proposal_path)) 45 | with open(proposal_path, 'rb') as f: 46 | proposal_dict = pickle.load(f) 47 | for split in eval_splits: 48 | exp_to_proposals = proposal_dict[split] 49 | if args.top_N is not None: 50 | exp_to_proposals = threshold_with_top_N(exp_to_proposals, args.top_N) 51 | if args.conf is not None: 52 | exp_to_proposals = threshold_with_confidence(exp_to_proposals, args.conf) 53 | for exp_id, proposals in exp_to_proposals.items(): 54 | ref = refer.sentToRef[exp_id] 55 | ref_id = ref['ref_id'] 56 | image_id = ref['image_id'] 57 | for proposal in proposals: 58 | x1, y1, x2, y2 = proposal['box'] 59 | w, h = x2 - x1, y2 - y1 60 | box = (x1, y1, w, h) 61 | cat_name = COCO_CAT_NAMES[proposal['cls_idx']] 62 | det = { 63 | 'det_id': det_id, 64 | 'h5_id': det_id, 65 | 'ref_id': ref_id, 66 | 'sent_id': exp_id, 67 | 'image_id': image_id, 68 | 'box': box, 69 | 'category_id': CAT_NAME_TO_ID[cat_name], 70 | 'category_name': cat_name, 71 | 'split': split, 72 | # 'cls_score': proposal['det_score'], 73 | # 'rank_score': proposal['rank_score'], 74 | # 'fin_score': proposal['score'] 75 | } 76 | matt_dets.append(det) 77 | det_id += 1 78 | 79 | # Print out stats and save detections 80 | for split in eval_splits: 81 | exp_num = len({det['sent_id'] for det in matt_dets if det['split'] == split}) 82 | det_num = len([det for det in matt_dets if det['split'] == split]) 83 | print('[{:5s}] {} / {} = {:.2f} detections per expression' 84 | .format(split, det_num, exp_num, det_num / exp_num)) 85 | top_N = 0 if args.top_N is None else args.top_N 86 | save_path = 'output/matt_dets_{}_{}_{}_{}.json'.format(args.m, args.tid, dataset_splitby, top_N) 87 | # save_path = 'output/matt_dets_{}_{}_{}_{}_more.json'.format(args.m, args.tid, dataset_splitby, top_N) 88 | print('saving detections to {}...'.format(save_path)) 89 | with open(save_path, 'w') as f: 90 | json.dump(matt_dets, f) 91 | 92 | 93 | if __name__ == '__main__': 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument('--dataset', type=str, default='refcoco') 96 | parser.add_argument('--split-by', type=str, default='unc') 97 | parser.add_argument('--m', type=str, required=True) 98 | parser.add_argument('--top-N', type=int, default=None) 99 | parser.add_argument('--tid', type=str, required=True) 100 | parser.add_argument('--conf', type=float, default=None) 101 | main(parser.parse_args()) 102 | -------------------------------------------------------------------------------- /tools/save_ref_nms_proposals.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import argparse 4 | from multiprocessing import Pool 5 | 6 | import torch 7 | from tqdm import tqdm 8 | from torchvision.ops import nms 9 | from torch.nn.utils.rnn import pack_padded_sequence 10 | 11 | from lib.predictor import AttVanillaPredictorV2 12 | from lib.vanilla_utils import DetEvalLoader 13 | from utils.constants import EVAL_SPLITS_DICT 14 | 15 | 16 | def rank_proposals(position, gpu_id, tid, refdb_path, split, m): 17 | # Load refdb 18 | with open(refdb_path) as f: 19 | refdb = json.load(f) 20 | dataset_ = 
refdb['dataset_splitby'].split('_')[0] 21 | # Load pre-trained model 22 | device = torch.device('cuda', gpu_id) 23 | with open('output/{}_{}_{}.json'.format(m, dataset_, tid), 'r') as f: 24 | model_info = json.load(f) 25 | predictor = AttVanillaPredictorV2(att_dropout_p=model_info['config']['ATT_DROPOUT_P'], 26 | rank_dropout_p=model_info['config']['RANK_DROPOUT_P']) 27 | model_path = 'output/{}_{}_{}_b.pth'.format(m, dataset_, tid) 28 | predictor.load_state_dict(torch.load(model_path)) 29 | predictor.to(device) 30 | predictor.eval() 31 | # Rank proposals 32 | exp_to_proposals = {} 33 | loader = DetEvalLoader(refdb, split, gpu_id) 34 | tqdm_loader = tqdm(loader, desc='scoring {}'.format(split), ascii=True, position=position) 35 | for exp_id, pos_feat, sent_feat, pos_box, pos_score, cls_num_list in tqdm_loader: 36 | # Compute rank score 37 | packed_sent_feats = pack_padded_sequence(sent_feat, torch.tensor([sent_feat.size(1)]), 38 | enforce_sorted=False, batch_first=True) 39 | with torch.no_grad(): 40 | rank_score, *_ = predictor(pos_feat, packed_sent_feats) # [1, *] 41 | # Normalize rank score 42 | rank_score = torch.sigmoid(rank_score[0]) 43 | # Split scores and boxes category-wise 44 | rank_score_list = torch.split(rank_score, cls_num_list, dim=0) 45 | pos_box_list = torch.split(pos_box, cls_num_list, dim=0) 46 | pos_score_list = torch.split(pos_score, cls_num_list, dim=0) 47 | # Combine score and do NMS category-wise 48 | proposals = [] 49 | cls_idx = 0 50 | for cls_rank_score, cls_pos_box, cls_pos_score in zip(rank_score_list, pos_box_list, pos_score_list): 51 | cls_idx += 1 52 | # No positive box under this category 53 | if cls_rank_score.size(0) == 0: 54 | continue 55 | final_score = cls_rank_score * cls_pos_score 56 | keep = nms(cls_pos_box, final_score, iou_threshold=0.3) 57 | cls_kept_box = cls_pos_box[keep] 58 | cls_kept_score = final_score[keep] 59 | for box, score in zip(cls_kept_box, cls_kept_score): 60 | proposals.append({'score': score.item(), 'box': box.tolist(), 'cls_idx': cls_idx}) 61 | assert cls_idx == 80 62 | exp_to_proposals[exp_id] = proposals 63 | return exp_to_proposals 64 | 65 | 66 | def error_callback(e): 67 | print('\n\n\n\nERROR in subprocess:', e, '\n\n\n\n') 68 | 69 | 70 | def main(args): 71 | dataset_splitby = '{}_{}'.format(args.dataset, args.split_by) 72 | eval_splits = EVAL_SPLITS_DICT[dataset_splitby] 73 | refdb_path = 'cache/std_refdb_{}.json'.format(dataset_splitby) 74 | print('about to rank proposals via multiprocessing, good luck ~') 75 | results = {} 76 | with Pool(processes=len(eval_splits)) as pool: 77 | for idx, split in enumerate(eval_splits): 78 | sub_args = (idx, args.gpu_id, args.tid, refdb_path, split, args.m) 79 | results[split] = pool.apply_async(rank_proposals, sub_args, error_callback=error_callback) 80 | pool.close() 81 | pool.join() 82 | proposal_dict = {} 83 | for split in eval_splits: 84 | assert results[split].successful() 85 | print('subprocess for {} split succeeded, fetching results...'.format(split)) 86 | proposal_dict[split] = results[split].get() 87 | save_path = 'cache/proposals_{}_{}_{}.pkl'.format(args.m, args.dataset, args.tid) 88 | print('saving proposals to {}...'.format(save_path)) 89 | with open(save_path, 'wb') as f: 90 | pickle.dump(proposal_dict, f) 91 | print('all done ~') 92 | 93 | 94 | if __name__ == '__main__': 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument('--gpu-id', type=int, default=0) 97 | parser.add_argument('--dataset', default='refcoco') 98 | parser.add_argument('--split-by', 
default='unc') 99 | parser.add_argument('--tid', type=str, required=True) 100 | parser.add_argument('--m', type=str, default='att_vanilla') 101 | main(parser.parse_args()) 102 | -------------------------------------------------------------------------------- /tools/train_att_rank.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import json 3 | import os 4 | from time import time 5 | import copy 6 | import itertools 7 | 8 | from tqdm import tqdm 9 | import torch 10 | from torch import optim 11 | from torch.utils.data import DataLoader 12 | from torch.nn.init import zeros_, xavier_uniform_ 13 | from torch.utils.tensorboard import SummaryWriter 14 | from torch.nn.utils.rnn import pack_padded_sequence 15 | 16 | from lib.rank_utils import RankDataset 17 | from lib.predictor import AttVanillaPredictorV2 18 | from utils.misc import get_time_id 19 | 20 | 21 | PRETRAINED_MRCN = 'data/res101_mask_rcnn_iter_1250000_cpu.pth' 22 | 23 | CONFIG = dict( 24 | HEAD_LR=2e-4, 25 | HEAD_WD=1e-3, 26 | REF_LR=5e-4, 27 | REF_WD=1e-3, 28 | RNN_LR=5e-4, 29 | RNN_WD=1e-3, 30 | BATCH_SIZE=8, 31 | EPOCH=4, 32 | ATT_DROPOUT_P=0.5, 33 | RANK_DROPOUT_P=0.5, 34 | LOSS_MARGIN=0.1, 35 | TOP_H=100, 36 | ROI_PER_LEVEL=10, 37 | NEGATIVE_NUM=150, 38 | LEVEL_NUM=6 39 | ) 40 | LOG_INTERVAL = 50 41 | VAL_INTERVAL = 1000 42 | 43 | 44 | class RankLoss: 45 | 46 | def __init__(self, margin, device): 47 | self.margin = margin 48 | self.zero_scalar = torch.tensor(0., device=device) 49 | 50 | def __call__(self, rank_score, sampled_pairs): 51 | """Compute Hinge loss on sampled pairs. 52 | 53 | Args: 54 | rank_score: Tensor of shape [batch_size, roi_num]. 55 | sampled_pairs: Tensor of shape [batch_size, pair_num, 2] 56 | 57 | Returns: 58 | loss: Computed loss. 59 | 60 | """ 61 | pos_indices = sampled_pairs[:, :, 0] # [batch_size, pair_num] 62 | neg_indices = sampled_pairs[:, :, 1] # [batch_size, pair_num] 63 | pos_scores = torch.gather(rank_score, 1, pos_indices) # [batch_size, pair_num] 64 | neg_scores = torch.gather(rank_score, 1, neg_indices) # [batch_size, pair_num] 65 | loss = torch.max(neg_scores - pos_scores + self.margin, self.zero_scalar).mean() 66 | return loss 67 | 68 | 69 | class PairSampler: 70 | 71 | def __init__(self, level_num, roi_per_level, negative_num, top_h, device): 72 | self.level_num = level_num 73 | self.roi_per_level = roi_per_level 74 | self.negative_num = negative_num 75 | self.top_h = top_h 76 | self.device = device 77 | 78 | def __call__(self, rank_score): 79 | """Sample training pairs with hard negative mining. 80 | 81 | Args: 82 | rank_score: Tensor of shape [batch_size, roi_num]. 83 | 84 | Returns: 85 | batch_sampled_pairs: Tensor of shape [batch_size, pair_num, 2]. 86 | `batch_sampled_pairs[:, :, 0]` are indices of positive ROIs, 87 | `batch_sampled_pairs[:, :, 1]` are indices of negative ROIs. 
88 | 89 | """ 90 | N, R = rank_score.size() 91 | assert R == self.negative_num + (self.level_num - 1) * self.roi_per_level 92 | batch_sorted_idx = rank_score.argsort(dim=1, descending=True) # [batch_size, roi_num] 93 | batch_sampled_pairs = [] 94 | for b in range(N): 95 | sorted_idx_list = batch_sorted_idx[b].tolist() 96 | pair_list = [] 97 | for l in range(self.level_num - 1): 98 | start_idx = self.negative_num + l * self.roi_per_level 99 | pos_idx = [i for i in range(start_idx, start_idx + self.roi_per_level)] 100 | neg_idx = [i for i in sorted_idx_list if i in range(start_idx)] 101 | neg_idx = neg_idx[:self.top_h] 102 | pair_list.extend(itertools.product(pos_idx, neg_idx)) 103 | batch_sampled_pairs.append(pair_list) 104 | batch_sampled_pairs = torch.tensor(batch_sampled_pairs, device=self.device) 105 | return batch_sampled_pairs 106 | 107 | 108 | def init_att_vanilla_predictor(predictor): 109 | # Load pre-trained weights from M-RCNN 110 | mrcn_weights = torch.load(PRETRAINED_MRCN) 111 | c4_weights = { 112 | k[len('resnet.layer4.'):]: v 113 | for k, v in mrcn_weights.items() 114 | if k.startswith('resnet.layer4') 115 | } 116 | assert len(c4_weights) == 50 117 | predictor.head.load_state_dict(c4_weights) 118 | # Initialize new layers 119 | count = 0 120 | for name, param in predictor.named_parameters(): 121 | if 'head' in name: 122 | continue 123 | if 'weight' in name: 124 | xavier_uniform_(param) 125 | count += 1 126 | elif 'bias' in name: 127 | zeros_(param) 128 | count += 1 129 | assert count == 20 130 | 131 | 132 | def compute_loss(predictor, sampler, criterion, device, enable_grad, roi_feats, word_feats, sent_len): 133 | with torch.autograd.set_grad_enabled(enable_grad): 134 | roi_feats = roi_feats.to(device) 135 | word_feats = word_feats.to(device) 136 | packed_sent_feats = pack_padded_sequence(word_feats, sent_len, enforce_sorted=False, batch_first=True) 137 | scores, *_ = predictor.forward(roi_feats, packed_sent_feats) 138 | sigmoid_scores = torch.sigmoid(scores) 139 | pairs = sampler(sigmoid_scores) 140 | loss = criterion(sigmoid_scores, pairs) 141 | return loss, sigmoid_scores 142 | 143 | 144 | def main(args): 145 | if args.resume is None: 146 | tid = get_time_id() 147 | start_epoch = 0 148 | else: 149 | *_, tid, start_epoch = args.resume[:-4].split('_') 150 | tid += '_cont' 151 | start_epoch = int(start_epoch) 152 | if args.epoch is not None: 153 | CONFIG['EPOCH'] = args.epoch 154 | dataset_splitby = '{}_{}'.format(args.dataset, args.split_by) 155 | device = torch.device('cuda', args.gpu_id) 156 | refdb_path = 'cache/std_refdb_{}.json'.format(dataset_splitby) 157 | print('loading refdb from {}...'.format(refdb_path)) 158 | with open(refdb_path, 'r') as f: 159 | refdb = json.load(f) 160 | ctxdb_path = 'cache/std_ctxdb_{}.json'.format(dataset_splitby) 161 | print('loading ctxdb from {}...'.format(ctxdb_path)) 162 | with open(ctxdb_path, 'r') as f: 163 | ctxdb = json.load(f) 164 | # Build dataloaders 165 | dataset_settings = dict(level_num=CONFIG['LEVEL_NUM'], roi_per_level=CONFIG['ROI_PER_LEVEL'], 166 | negative_num=CONFIG['NEGATIVE_NUM']) 167 | trn_dataset = RankDataset(refdb, ctxdb, 'train', **dataset_settings) 168 | val_dataset = RankDataset(refdb, ctxdb, 'val', **dataset_settings) 169 | loader_settings = dict(batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=4) 170 | trn_loader = DataLoader(trn_dataset, **loader_settings) 171 | val_loader = DataLoader(val_dataset, drop_last=True, **loader_settings) 172 | # Tensorboard writer 173 | tb_dir = 
'tb/att_rank/{}'.format(tid) 174 | trn_wrt = SummaryWriter(os.path.join(tb_dir, 'train')) 175 | val_wrt = SummaryWriter(os.path.join(tb_dir, 'val')) 176 | sc_wrts = {l: SummaryWriter(os.path.join(tb_dir, 'level:{}'.format(l))) for l in range(CONFIG['LEVEL_NUM'])} 177 | # Build and init predictor 178 | predictor = AttVanillaPredictorV2(att_dropout_p=CONFIG['ATT_DROPOUT_P'], rank_dropout_p=CONFIG['RANK_DROPOUT_P']) 179 | init_att_vanilla_predictor(predictor) 180 | predictor.to(device) 181 | # Setup pair sampler 182 | sampler = PairSampler(CONFIG['LEVEL_NUM'], CONFIG['ROI_PER_LEVEL'], CONFIG['NEGATIVE_NUM'], CONFIG['TOP_H'], device) 183 | # Setup loss 184 | criterion = RankLoss(CONFIG['LOSS_MARGIN'], device) 185 | # Setup optimizer 186 | ref_params = list(predictor.att_fc.parameters()) + list(predictor.rank_fc.parameters()) \ 187 | + list(predictor.vis_a_fc.parameters()) + list(predictor.vis_r_fc.parameters()) 188 | ref_optimizer = optim.Adam(ref_params, lr=CONFIG['REF_LR'], weight_decay=CONFIG['REF_WD']) 189 | rnn_optimizer = optim.Adam(predictor.rnn.parameters(), lr=CONFIG['RNN_LR'], weight_decay=CONFIG['RNN_WD']) 190 | head_optimizer = optim.Adam(predictor.head.parameters(), lr=CONFIG['HEAD_LR'], weight_decay=CONFIG['HEAD_WD']) 191 | common_args = dict(mode='min', factor=0.6, verbose=True, threshold_mode='rel', patience=3) 192 | ref_scheduler = optim.lr_scheduler.ReduceLROnPlateau(ref_optimizer, min_lr=CONFIG['REF_LR']/100, **common_args) 193 | rnn_scheduler = optim.lr_scheduler.ReduceLROnPlateau(rnn_optimizer, min_lr=CONFIG['RNN_LR']/100, **common_args) 194 | head_scheduler = optim.lr_scheduler.ReduceLROnPlateau(head_optimizer, min_lr=CONFIG['HEAD_LR']/100, **common_args) 195 | # Start training 196 | if args.resume is not None: 197 | resume_ckpt = torch.load(os.path.join('output', args.resume)) 198 | predictor.load_state_dict(resume_ckpt['model']) 199 | ref_optimizer.load_state_dict(resume_ckpt['ref_optimizer']) 200 | rnn_optimizer.load_state_dict(resume_ckpt['rnn_optimizer']) 201 | head_optimizer.load_state_dict(resume_ckpt['head_optimizer']) 202 | ref_scheduler.load_state_dict(resume_ckpt['ref_scheduler']) 203 | rnn_scheduler.load_state_dict(resume_ckpt['rnn_scheduler']) 204 | head_scheduler.load_state_dict(resume_ckpt['head_scheduler']) 205 | step = 0 206 | trn_running_loss = 0. 207 | best_model = {'avg_val_loss': float('inf'), 'epoch': None, 'step': None, 'weights': None} 208 | tic = time() 209 | for epoch in range(start_epoch, CONFIG['EPOCH']): 210 | for trn_batch in trn_loader: 211 | # Train for one step 212 | step += 1 213 | predictor.train() 214 | trn_loss, _ = compute_loss(predictor, sampler, criterion, device, True, *trn_batch) 215 | head_optimizer.zero_grad() 216 | ref_optimizer.zero_grad() 217 | rnn_optimizer.zero_grad() 218 | trn_loss.backward() 219 | head_optimizer.step() 220 | ref_optimizer.step() 221 | rnn_optimizer.step() 222 | trn_running_loss += trn_loss.item() 223 | # Log training loss 224 | if step % LOG_INTERVAL == 0: 225 | avg_trn_loss = trn_running_loss / LOG_INTERVAL 226 | print('[TRN Loss] epoch {} step {}: {:.6f}'.format(epoch + 1, step, avg_trn_loss)) 227 | trn_wrt.add_scalar('loss', avg_trn_loss, step) 228 | trn_running_loss = 0. 
229 | # Eval on whole val split 230 | if step % VAL_INTERVAL == 0: 231 | # Compute and log val loss 232 | predictor.eval() 233 | val_loss_list = [] 234 | level_score_mean_list = {l: [] for l in range(CONFIG['LEVEL_NUM'])} 235 | pbar = tqdm(total=len(val_dataset), ascii=True, desc='computing val loss') 236 | for val_batch in val_loader: 237 | val_loss, val_score = compute_loss(predictor, sampler, criterion, device, False, *val_batch) 238 | val_loss_list.append(val_loss.item()) 239 | neg_rank_score = val_score[:, :CONFIG['NEGATIVE_NUM']] 240 | level_score_mean_list[0].append(neg_rank_score.mean().item()) 241 | pos_rank_score = val_score[:, CONFIG['NEGATIVE_NUM']:] 242 | pos_rank_score = pos_rank_score.reshape(CONFIG['BATCH_SIZE'], CONFIG['LEVEL_NUM'] - 1, -1) 243 | for l in range(CONFIG['LEVEL_NUM'] - 1): 244 | level_score_mean_list[l + 1].append(pos_rank_score[:, l].mean().item()) 245 | pbar.update(val_batch[0].size(0)) 246 | pbar.close() 247 | avg_val_loss = sum(val_loss_list) / len(val_loss_list) 248 | print('[VAL Loss] epoch {} step {}: {:.6f}'.format(epoch + 1, step, avg_val_loss)) 249 | val_wrt.add_scalar('loss', avg_val_loss, step) 250 | for l in range(CONFIG['LEVEL_NUM']): 251 | avg_score_mean = sum(level_score_mean_list[l]) / len(level_score_mean_list[l]) 252 | sc_wrts[l].add_scalar('rank score mean', avg_score_mean, step) 253 | # Update learning rate 254 | head_scheduler.step(avg_val_loss) 255 | ref_scheduler.step(avg_val_loss) 256 | rnn_scheduler.step(avg_val_loss) 257 | # Track model with lowest val loss 258 | if avg_val_loss < best_model['avg_val_loss']: 259 | best_model['avg_val_loss'] = avg_val_loss 260 | best_model['epoch'] = epoch + 1 261 | best_model['step'] = step 262 | best_model['weights'] = copy.deepcopy(predictor.state_dict()) 263 | # Save checkpoint after each epoch 264 | epoch_ckpt = { 265 | 'ref_optimizer': ref_optimizer.state_dict(), 266 | 'rnn_optimizer': rnn_optimizer.state_dict(), 267 | 'head_optimizer': head_optimizer.state_dict(), 268 | 'ref_scheduler': ref_scheduler.state_dict(), 269 | 'rnn_scheduler': rnn_scheduler.state_dict(), 270 | 'head_scheduler': head_scheduler.state_dict(), 271 | 'model': predictor.state_dict() 272 | } 273 | save_path = 'output/att_rank_ckpt_{}_{}.pth'.format(tid, epoch + 1) 274 | torch.save(epoch_ckpt, save_path) 275 | # Save best model 276 | time_spent = int(time() - tic) // 60 277 | print('\nTraining completed in {} h {} m.'.format(time_spent // 60, time_spent % 60)) 278 | print('Found model with lowest val loss at epoch {epoch} step {step}.'.format(**best_model)) 279 | save_path = 'output/att_rank_{}_b.pth'.format(tid) 280 | torch.save(best_model['weights'], save_path) 281 | print('Saved best model weights to {}'.format(save_path)) 282 | # Close summary writer 283 | trn_wrt.close() 284 | val_wrt.close() 285 | for wrt in sc_wrts.values(): 286 | wrt.close() 287 | # Log training procedure 288 | model_info = { 289 | 'type': 'att_rank', 290 | 'dataset': dataset_splitby, 291 | 'config': CONFIG 292 | } 293 | with open('output/att_rank_{}.json'.format(tid), 'w') as f: 294 | json.dump(model_info, f, indent=4, sort_keys=True) 295 | 296 | 297 | if __name__ == '__main__': 298 | parser = ArgumentParser() 299 | parser.add_argument('--dataset', default='refcoco') 300 | parser.add_argument('--split-by', default='unc') 301 | parser.add_argument('--gpu-id', type=int, default=0) 302 | parser.add_argument('--epoch', type=int, default=None) 303 | parser.add_argument('--resume', type=str, default=None) 304 | main(parser.parse_args()) 305 | 
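A minimal, self-contained sketch (not part of the codebase) of the pairwise hinge objective that `RankLoss` above implements; the scores, index pairs, and margin below are toy values for illustration only:
```
# Toy illustration of the ranking objective: for each sampled (positive, negative)
# ROI index pair, the penalty is max(0, neg_score - pos_score + margin).
import torch

margin = 0.1
rank_score = torch.tensor([[0.9, 0.2, 0.5, 0.45]])        # [batch_size=1, roi_num=4]
sampled_pairs = torch.tensor([[[0, 1], [2, 3]]])          # [1, pair_num=2, 2], as (pos_idx, neg_idx)

pos_scores = torch.gather(rank_score, 1, sampled_pairs[:, :, 0])  # [1, 2] positive ROI scores
neg_scores = torch.gather(rank_score, 1, sampled_pairs[:, :, 1])  # [1, 2] negative ROI scores
loss = torch.clamp(neg_scores - pos_scores + margin, min=0.).mean()
print(loss.item())  # pair (0, 1) satisfies the margin, pair (2, 3) violates it by 0.05 -> 0.025
```
In the actual training loop the pairs come from `PairSampler`, which pits each jitter level's ROIs against the highest-scoring ROIs from lower levels, i.e. hard negative mining on the current rank scores.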
-------------------------------------------------------------------------------- /tools/train_att_vanilla.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import json 3 | import os 4 | from time import time 5 | import copy 6 | 7 | from tqdm import tqdm 8 | import torch 9 | from torch import nn, optim 10 | from torch.utils.data import DataLoader 11 | from torch.nn.init import zeros_, xavier_uniform_ 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch.nn.utils.rnn import pack_padded_sequence 14 | 15 | from lib.vanilla_utils import DetBoxDataset 16 | from lib.predictor import AttVanillaPredictorV2 17 | from utils.misc import get_time_id 18 | 19 | 20 | PRETRAINED_MRCN = 'data/res101_mask_rcnn_iter_1250000_cpu.pth' 21 | 22 | CONFIG = dict( 23 | ROI_PER_IMG=32, 24 | HEAD_LR=2e-4, 25 | HEAD_WD=1e-3, 26 | REF_LR=5e-4, 27 | REF_WD=1e-3, 28 | RNN_LR=5e-4, 29 | RNN_WD=1e-3, 30 | BATCH_SIZE=32, 31 | EPOCH=5, 32 | ATT_DROPOUT_P=0.5, 33 | RANK_DROPOUT_P=0.5 34 | ) 35 | LOG_INTERVAL = 50 36 | VAL_INTERVAL = 1000 37 | 38 | 39 | def init_att_vanilla_predictor(predictor): 40 | # Load pre-trained weights from M-RCNN 41 | mrcn_weights = torch.load(PRETRAINED_MRCN) 42 | c4_weights = { 43 | k[len('resnet.layer4.'):]: v 44 | for k, v in mrcn_weights.items() 45 | if k.startswith('resnet.layer4') 46 | } 47 | assert len(c4_weights) == 50 48 | predictor.head.load_state_dict(c4_weights) 49 | # Initialize new layers 50 | count = 0 51 | for name, param in predictor.named_parameters(): 52 | if 'head' in name: 53 | continue 54 | if 'weight' in name: 55 | xavier_uniform_(param) 56 | count += 1 57 | elif 'bias' in name: 58 | zeros_(param) 59 | count += 1 60 | assert count == 20 61 | 62 | 63 | def compute_loss(predictor, criterion, device, enable_grad, roi_feats, roi_labels, word_feats, sent_len): 64 | with torch.autograd.set_grad_enabled(enable_grad): 65 | roi_feats = roi_feats.to(device) 66 | roi_labels = roi_labels.to(device) 67 | word_feats = word_feats.to(device) 68 | packed_sent_feats = pack_padded_sequence(word_feats, sent_len, enforce_sorted=False, batch_first=True) 69 | scores, *_ = predictor.forward(roi_feats, packed_sent_feats) 70 | loss = criterion(scores.flatten(), roi_labels.flatten()) 71 | return loss 72 | 73 | 74 | def main(args): 75 | tid = get_time_id() 76 | dataset_splitby = '{}_{}'.format(args.dataset, args.split_by) 77 | device = torch.device('cuda', args.gpu_id) 78 | refdb_path = 'cache/std_refdb_{}.json'.format(dataset_splitby) 79 | print('loading refdb from {}...'.format(refdb_path)) 80 | with open(refdb_path, 'r') as f: 81 | refdb = json.load(f) 82 | ctxdb_path = 'cache/std_ctxdb_{}.json'.format(dataset_splitby) 83 | print('loading ctxdb from {}...'.format(ctxdb_path)) 84 | with open(ctxdb_path, 'r') as f: 85 | ctxdb = json.load(f) 86 | # Build dataloaders 87 | trn_dataset = DetBoxDataset(refdb, ctxdb, split='train', roi_per_img=CONFIG['ROI_PER_IMG']) 88 | val_dataset = DetBoxDataset(refdb, ctxdb, split='val', roi_per_img=CONFIG['ROI_PER_IMG']) 89 | trn_loader = DataLoader(trn_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=8) 90 | val_loader = DataLoader(val_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=8) 91 | # Tensorboard writer 92 | tb_dir = 'tb/att_vanilla/{}'.format(tid) 93 | trn_wrt = SummaryWriter(os.path.join(tb_dir, 'train')) 94 | val_wrt = SummaryWriter(os.path.join(tb_dir, 'val')) 95 | # Build and init predictor 96 | predictor = 
AttVanillaPredictorV2(att_dropout_p=CONFIG['ATT_DROPOUT_P'], 97 | rank_dropout_p=CONFIG['RANK_DROPOUT_P']) 98 | init_att_vanilla_predictor(predictor) 99 | predictor.to(device) 100 | # Setup loss 101 | criterion = nn.BCEWithLogitsLoss(reduction='mean') 102 | # Setup optimizer 103 | ref_params = list(predictor.att_fc.parameters()) + list(predictor.rank_fc.parameters()) \ 104 | + list(predictor.vis_a_fc.parameters()) + list(predictor.vis_r_fc.parameters()) 105 | ref_optimizer = optim.Adam(ref_params, lr=CONFIG['REF_LR'], weight_decay=CONFIG['REF_WD']) 106 | rnn_optimizer = optim.Adam(predictor.rnn.parameters(), lr=CONFIG['RNN_LR'], weight_decay=CONFIG['RNN_WD']) 107 | head_optimizer = optim.Adam(predictor.head.parameters(), lr=CONFIG['HEAD_LR'], weight_decay=CONFIG['HEAD_WD']) 108 | common_args = dict(mode='min', factor=0.5, verbose=True, threshold_mode='rel', patience=1) 109 | ref_scheduler = optim.lr_scheduler.ReduceLROnPlateau(ref_optimizer, min_lr=CONFIG['REF_LR']/100, **common_args) 110 | rnn_scheduler = optim.lr_scheduler.ReduceLROnPlateau(rnn_optimizer, min_lr=CONFIG['RNN_LR']/100, **common_args) 111 | head_scheduler = optim.lr_scheduler.ReduceLROnPlateau(head_optimizer, min_lr=CONFIG['HEAD_LR']/100, **common_args) 112 | # Start training 113 | step = 0 114 | trn_running_loss = 0. 115 | best_model = {'avg_val_loss': float('inf'), 'epoch': None, 'step': None, 'weights': None} 116 | tic = time() 117 | for epoch in range(CONFIG['EPOCH']): 118 | for trn_batch in trn_loader: 119 | # Train for one step 120 | step += 1 121 | predictor.train() 122 | loss = compute_loss(predictor, criterion, device, True, *trn_batch) 123 | head_optimizer.zero_grad() 124 | ref_optimizer.zero_grad() 125 | rnn_optimizer.zero_grad() 126 | loss.backward() 127 | head_optimizer.step() 128 | ref_optimizer.step() 129 | rnn_optimizer.step() 130 | trn_running_loss += loss.item() 131 | # Log training loss 132 | if step % LOG_INTERVAL == 0: 133 | avg_trn_loss = trn_running_loss / LOG_INTERVAL 134 | print('[TRN Loss] epoch {} step {}: {:.6f}'.format(epoch + 1, step, avg_trn_loss)) 135 | trn_wrt.add_scalar('loss', avg_trn_loss, step) 136 | trn_running_loss = 0. 
137 | # Eval on whole val split 138 | if step % VAL_INTERVAL == 0: 139 | # Compute and log val loss 140 | predictor.eval() 141 | val_loss_list = [] 142 | pbar = tqdm(total=len(val_dataset), ascii=True, desc='computing val loss') 143 | for val_batch in val_loader: 144 | loss = compute_loss(predictor, criterion, device, False, *val_batch) 145 | val_loss_list.append(loss.item()) 146 | pbar.update(val_batch[0].size(0)) 147 | pbar.close() 148 | avg_val_loss = sum(val_loss_list) / len(val_loss_list) 149 | print('[VAL Loss] epoch {} step {}: {:.6f}'.format(epoch + 1, step, avg_val_loss)) 150 | val_wrt.add_scalar('loss', avg_val_loss, step) 151 | # Update learning rate 152 | head_scheduler.step(avg_val_loss) 153 | ref_scheduler.step(avg_val_loss) 154 | rnn_scheduler.step(avg_val_loss) 155 | # Track model with lowest val loss 156 | if avg_val_loss < best_model['avg_val_loss']: 157 | best_model['avg_val_loss'] = avg_val_loss 158 | best_model['epoch'] = epoch + 1 159 | best_model['step'] = step 160 | best_model['weights'] = copy.deepcopy(predictor.state_dict()) 161 | # Save checkpoint after each epoch 162 | epoch_ckpt = { 163 | 'ref_optimizer': ref_optimizer.state_dict(), 164 | 'rnn_optimizer': rnn_optimizer.state_dict(), 165 | 'head_optimizer': head_optimizer.state_dict(), 166 | 'ref_scheduler': ref_scheduler.state_dict(), 167 | 'rnn_scheduler': rnn_scheduler.state_dict(), 168 | 'head_scheduler': head_scheduler.state_dict(), 169 | 'model': predictor.state_dict() 170 | } 171 | save_path = 'output/att_vanilla_ckpt_{}_{}.pth'.format(tid, epoch + 1) 172 | torch.save(epoch_ckpt, save_path) 173 | # Save best model 174 | time_spent = int(time() - tic) // 60 175 | print('\nTraining completed in {} h {} m.'.format(time_spent // 60, time_spent % 60)) 176 | print('Found model with lowest val loss at epoch {epoch} step {step}.'.format(**best_model)) 177 | save_path = 'output/att_vanilla_{}_b.pth'.format(tid) 178 | torch.save(best_model['weights'], save_path) 179 | print('Saved best model weights to {}'.format(save_path)) 180 | # Close summary writer 181 | trn_wrt.close() 182 | val_wrt.close() 183 | # Log training procedure 184 | model_info = { 185 | 'type': 'att_vanilla_v2', 186 | 'dataset': dataset_splitby, 187 | 'config': CONFIG 188 | } 189 | with open('output/att_vanilla_{}.json'.format(tid), 'w') as f: 190 | json.dump(model_info, f, indent=4, sort_keys=True) 191 | 192 | 193 | if __name__ == '__main__': 194 | parser = ArgumentParser() 195 | parser.add_argument('--dataset', default='refcoco') 196 | parser.add_argument('--split-by', default='unc') 197 | parser.add_argument('--gpu-id', type=int, default=0) 198 | main(parser.parse_args()) 199 | -------------------------------------------------------------------------------- /utils/constants.py: -------------------------------------------------------------------------------- 1 | __all__ = ['COCO_CAT_NAMES', 'CAT_ID_TO_NAME', 'CAT_NAME_TO_ID', 'CAT_ID_TO_IDX', 'EVAL_SPLITS_DICT'] 2 | 3 | 4 | COCO_CAT_NAMES = [ 5 | '__background__', 6 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 7 | 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 8 | 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 9 | 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 10 | 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 11 | 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 12 | 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 
'apple', 13 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 14 | 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 15 | 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 16 | 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 17 | 'teddy bear', 'hair drier', 'toothbrush'] 18 | assert len(COCO_CAT_NAMES) == 81 19 | 20 | CAT_ID_TO_NAME = { 21 | 1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus', 22 | 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant', 23 | 13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat', 24 | 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear', 25 | 24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag', 26 | 32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard', 37: 'sports ball', 27 | 38: 'kite', 39: 'baseball bat', 40: 'baseball glove', 41: 'skateboard', 42: 'surfboard', 28 | 43: 'tennis racket', 44: 'bottle', 46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 29 | 50: 'spoon', 51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange', 30 | 56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut', 61: 'cake', 31 | 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed', 67: 'dining table', 32 | 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse', 75: 'remote', 76: 'keyboard', 33 | 77: 'cell phone', 78: 'microwave', 79: 'oven', 80: 'toaster', 81: 'sink', 34 | 82: 'refrigerator', 84: 'book', 85: 'clock', 86: 'vase', 87: 'scissors', 35 | 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush'} 36 | assert len(CAT_ID_TO_NAME) == 80 37 | 38 | CAT_NAME_TO_ID = {} 39 | for k, v in CAT_ID_TO_NAME.items(): 40 | CAT_NAME_TO_ID[v] = k 41 | assert v in COCO_CAT_NAMES 42 | assert len(CAT_NAME_TO_ID) == 80 43 | 44 | CAT_ID_TO_IDX = {} 45 | for k, v in CAT_ID_TO_NAME.items(): 46 | CAT_ID_TO_IDX[k] = COCO_CAT_NAMES.index(v) 47 | assert len(CAT_ID_TO_IDX) == 80 48 | 49 | EVAL_SPLITS_DICT = { 50 | 'refcoco_unc': ['val', 'testA', 'testB'], 51 | 'refcoco+_unc': ['val', 'testA', 'testB'], 52 | 'refcocog_umd': ['val', 'test'] 53 | } 54 | -------------------------------------------------------------------------------- /utils/hit_rate_utils.py: -------------------------------------------------------------------------------- 1 | from utils.misc import calculate_iou, xywh_to_xyxy 2 | 3 | 4 | __all__ = ['NewHitRateEvaluator', 'CtxHitRateEvaluator'] 5 | 6 | 7 | class NewHitRateEvaluator: 8 | 9 | def __init__(self, refer, top_N=None, threshold=0.5): 10 | """Evaluate refexp-based hit rate. 11 | 12 | Args: 13 | refdb: `refdb` dict. 14 | split: Dataset split to evaluate on. 15 | top_N: Select top-N scoring proposals to evaluate. `None` means no selection. Default `None`. 16 | 17 | """ 18 | self.refer = refer 19 | self.top_N = top_N 20 | self.threshold = threshold 21 | 22 | def eval_hit_rate(self, split, proposal_dict, image_as_key=False): 23 | """Evaluate refexp-based hit rate. 24 | 25 | Args: 26 | proposal_dict: {exp_id or image_id: [{box: [4,], score: float}]}. 27 | image_as_key: Use image_id instead of exp_id as key, default `False`. 28 | 29 | Returns: 30 | proposal_per_ref: Number of proposals per refexp. 31 | hit_rate: Refexp-based hit rate of proposals. 
32 | 33 | """ 34 | # Initialize counters 35 | num_hit = 0 36 | num_proposal = 0 37 | num_ref = 0 # NOTE: this is the number of refexp, not ref 38 | for ref_id in self.refer.getRefIds(split=split): 39 | ref = self.refer.Refs[ref_id] 40 | image_id = ref['image_id'] 41 | ann_id = ref['ann_id'] 42 | ann = self.refer.Anns[ann_id] 43 | gt_box = xywh_to_xyxy(ann['bbox']) 44 | for exp_id in ref['sent_ids']: 45 | # Get proposals 46 | if image_as_key: 47 | proposals = proposal_dict[image_id] 48 | else: 49 | proposals = proposal_dict[exp_id] 50 | # Rank and select proposals 51 | ranked_proposals = sorted(proposals, key=lambda p: p['score'], reverse=True)[:self.top_N] 52 | for proposal in ranked_proposals: 53 | if calculate_iou(gt_box, proposal['box']) > self.threshold: 54 | num_hit += 1 55 | break 56 | num_proposal += len(ranked_proposals) 57 | num_ref += 1 58 | proposal_per_ref = num_proposal / num_ref 59 | hit_rate = num_hit / num_ref 60 | return proposal_per_ref, hit_rate 61 | 62 | 63 | class CtxHitRateEvaluator: 64 | 65 | def __init__(self, refer, ctxdb, top_N=None, threshold=0.5): 66 | self.refer = refer 67 | self.ctxdb = ctxdb 68 | self.top_N = top_N 69 | self.threshold = threshold 70 | 71 | def eval_hit_rate(self, split, proposal_dict, image_as_key=False): 72 | """Evaluate refexp-based hit rate. 73 | 74 | Args: 75 | proposal_dict: {exp_id or image_id: [{box: [4,], score: float}]}. 76 | image_as_key: Use image_id instead of exp_id as key, default `False`. 77 | 78 | Returns: 79 | proposal_per_ref: Number of proposals per refexp. 80 | hit_rate: Refexp-based hit rate of proposals. 81 | 82 | """ 83 | # Initialize counters 84 | recall_list = [] 85 | avg_num_list = [] 86 | for exp_id, ctx in self.ctxdb[split].items(): 87 | exp_id = int(exp_id) 88 | if len(ctx['ctx']) == 0: 89 | continue 90 | # Get proposals 91 | if image_as_key: 92 | image_id = self.refer.sentToRef[exp_id]['image_id'] 93 | proposals = proposal_dict[image_id] 94 | else: 95 | proposals = proposal_dict[exp_id] 96 | # Rank and select proposals 97 | ranked_proposals = sorted(proposals, key=lambda p: p['score'], reverse=True)[:self.top_N] 98 | hit_num, ctx_num = 0, 0 99 | for ctx_item in ctx['ctx']: 100 | ctx_num += 1 101 | ctx_box = ctx_item['box'] 102 | for proposal in ranked_proposals: 103 | if calculate_iou(ctx_box, proposal['box']) > self.threshold: 104 | hit_num += 1 105 | break 106 | recall_list.append(hit_num / ctx_num) 107 | avg_num_list.append(len(ranked_proposals)) 108 | return sum(avg_num_list) / len(avg_num_list), sum(recall_list) / len(recall_list) 109 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import random 3 | 4 | import torch 5 | from torch.nn.functional import affine_grid, grid_sample 6 | 7 | 8 | __all__ = ['xywh_to_xyxy', 'calculate_area', 'calculate_iou', 'repeat_loader', 'get_time_id', 9 | 'mrcn_crop_pool_layer', 'jitter_roi', 'recursive_jitter_roi', 'alert_print'] 10 | 11 | 12 | def xywh_to_xyxy(box): 13 | """Convert xywh bbox to xyxy format.""" 14 | return box[0], box[1], box[0]+box[2], box[1]+box[3] 15 | 16 | 17 | def calculate_area(box): 18 | """Calculate area of bbox in xyxy format.""" 19 | return (box[2] - box[0])*(box[3] - box[1]) 20 | 21 | 22 | def calculate_iou(box1, box2): 23 | """Calculate IoU of two bboxes in xyxy format.""" 24 | max_L = max(box1[0], box2[0]) 25 | min_R = min(box1[2], box2[2]) 26 | max_T = max(box1[1], box2[1]) 27 | 
min_B = min(box1[3], box2[3]) 28 | if max_L < min_R and max_T < min_B: 29 | intersection = (min_B - max_T)*(min_R - max_L) 30 | union = calculate_area(box1) + calculate_area(box2) - intersection 31 | return intersection / union 32 | else: 33 | return 0. 34 | 35 | 36 | def repeat_loader(loader): 37 | """Endlessly repeat given loader.""" 38 | while True: 39 | for data in loader: 40 | yield data 41 | 42 | 43 | def get_time_id(): 44 | tt = datetime.now().timetuple() 45 | return '{:02d}{:02d}{:02d}{:02d}{:02d}'.format(tt.tm_mon, tt.tm_mday, tt.tm_hour, tt.tm_min, tt.tm_sec) 46 | 47 | 48 | def mrcn_crop_pool_layer(bottom, rois): 49 | x1 = rois[:, 0::4] / 16.0 # [batch_size, 1] 50 | y1 = rois[:, 1::4] / 16.0 51 | x2 = rois[:, 2::4] / 16.0 52 | y2 = rois[:, 3::4] / 16.0 53 | height = bottom.size(2) 54 | width = bottom.size(3) 55 | # Affine theta 56 | zero = torch.zeros_like(x1) 57 | theta = torch.cat([ 58 | (x2 - x1)/(width - 1), 59 | zero, 60 | (x1 + x2 - width + 1)/(width - 1), 61 | zero, 62 | (y2 - y1)/(height - 1), 63 | (y1 + y2 - height + 1)/(height - 1)], 1).reshape(-1, 2, 3) # [batch_size, 2, 3] 64 | if int(torch.__version__.split('.')[1]) < 3: 65 | grid = affine_grid(theta, torch.Size((rois.size(0), 1, 7, 7))) 66 | crops = grid_sample(bottom.expand(rois.size(0), -1, -1, -1), grid) 67 | else: 68 | grid = affine_grid(theta, torch.Size((rois.size(0), 1, 7, 7)), align_corners=True) 69 | crops = grid_sample(bottom.expand(rois.size(0), -1, -1, -1), grid, align_corners=True) 70 | return crops 71 | 72 | 73 | def jitter_coordinate(x, L, w, image_l, image_r): 74 | r = w * (1 - L) / L 75 | l = w * (L - 1) 76 | 77 | rec = x + random.uniform(l, r) 78 | 79 | rec = max(rec, image_l) 80 | rec = min(rec, image_r) 81 | 82 | return rec 83 | 84 | 85 | def jitter_roi(G, L, R, img_w, img_h): 86 | w = G[2] - G[0] 87 | h = G[3] - G[1] 88 | 89 | while True: 90 | x0 = jitter_coordinate(G[0], L, w, 0, img_w) 91 | x1 = jitter_coordinate(G[2], L, w, 0, img_w) 92 | y0 = jitter_coordinate(G[1], L, h, 0, img_h) 93 | y1 = jitter_coordinate(G[3], L, h, 0, img_h) 94 | jittered_roi = (x0, y0, x1, y1) 95 | if L <= calculate_iou(jittered_roi, G) <= R: 96 | return jittered_roi 97 | 98 | 99 | def recursive_jitter_roi(G, L, R, img_w, img_h, max_interval=0.01): 100 | assert L < R 101 | if R - L > max_interval: 102 | mid = (L + R) / 2 103 | if random.random() < 0.5: 104 | return recursive_jitter_roi(G, L, mid, img_w, img_h, max_interval) 105 | else: 106 | return recursive_jitter_roi(G, mid, R, img_w, img_h, max_interval) 107 | else: 108 | w = G[2] - G[0] 109 | h = G[3] - G[1] 110 | while True: 111 | x0 = jitter_coordinate(G[0], L, w, 0, img_w) 112 | x1 = jitter_coordinate(G[2], L, w, 0, img_w) 113 | y0 = jitter_coordinate(G[1], L, h, 0, img_h) 114 | y1 = jitter_coordinate(G[3], L, h, 0, img_h) 115 | jittered_roi = (x0, y0, x1, y1) 116 | if L <= calculate_iou(jittered_roi, G) <= R: 117 | return jittered_roi 118 | 119 | 120 | def alert_print(msg): 121 | print('\33[31m[ALERT] {}\33[0m'.format(msg)) 122 | --------------------------------------------------------------------------------
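A quick, self-contained sanity check for the box helpers defined in `utils/misc.py` above (illustrative only; it assumes the repository root is on `PYTHONPATH`, as in the commands from the README):
```
# xywh_to_xyxy converts a COCO-style [x, y, w, h] box to corner format;
# calculate_iou then operates on the resulting xyxy boxes.
from utils.misc import xywh_to_xyxy, calculate_iou

box_a = xywh_to_xyxy([0, 0, 10, 10])   # -> (0, 0, 10, 10)
box_b = xywh_to_xyxy([5, 5, 10, 10])   # -> (5, 5, 15, 15)
# Intersection is a 5x5 square (area 25); union is 100 + 100 - 25 = 175.
print(calculate_iou(box_a, box_b))     # 25 / 175 ≈ 0.143
```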