├── GAT.py
├── LICENSE
├── README.md
├── coco_sims
│   └── coco_sims.txt
├── data.py
├── data
│   ├── coco
│   │   ├── annotations.zip
│   │   ├── download.sh
│   │   └── images
│   └── f30k
│       ├── dataset_flickr30k.zip
│       └── images
├── data_bert.py
├── evaluation.py
├── evaluation_bert.py
├── figures
│   └── model.jpg
├── flickr_sims
│   └── flickr_sims.txt
├── model.py
├── model_bert.py
├── pytorch_pretrained_bert
│   ├── .DS_Store
│   ├── file_utils.py
│   ├── modeling.py
│   ├── optimization.py
│   └── tokenization.py
├── rerank.py
├── resnet.py
├── runs
│   ├── BERT
│   │   └── bert_models
│   └── GRU
│       └── gru_models
├── test_bert_cc.sh
├── test_bert_f.sh
├── test_gru_cc.sh
├── test_gru_f.sh
├── train.py
├── train_bert.py
├── uncased_L-12_H-768_A-12
│   └── bert_pretrained_model
├── vocab.py
└── vocab
    ├── 111
    ├── 10crop_precomp_vocab.pkl
    ├── coco_precomp_vocab.pkl
    ├── coco_resnet_precomp_vocab.pkl
    ├── coco_vgg_precomp_vocab.pkl
    ├── coco_vocab.pkl
    ├── f30k_precomp_vocab.pkl
    ├── f30k_vocab.pkl
    ├── f8k_precomp_vocab.pkl
    └── f8k_vocab.pkl
/GAT.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Dual Semantic Relations Attention Network (DSRAN) implementation
3 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
4 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
5 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
6 | # Written by Keyu Wen, 2020
7 | # ------------------------------------------------------------
8 | 
9 | import math
10 | import torch
11 | from torch import nn
12 | import torch.nn.functional as F
13 | 
14 | 
15 | class MultiHeadAttention(nn.Module):
16 |     def __init__(self, config):
17 |         super(MultiHeadAttention, self).__init__()
18 | 
19 |         self.num_attention_heads = config.num_attention_heads
20 |         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
21 |         self.all_head_size = self.num_attention_heads * self.attention_head_size
22 | 
23 |         self.query = nn.Linear(config.hidden_size, self.all_head_size)
24 |         self.key = nn.Linear(config.hidden_size, self.all_head_size)
25 |         self.value = nn.Linear(config.hidden_size, self.all_head_size)
26 | 
27 |         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
28 | 
29 |     def transpose_for_scores(self, x):
30 |         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
31 |         x = x.view(*new_x_shape)
32 |         return x.permute(0, 2, 1, 3)
33 | 
34 |     def forward(self, input_graph):
35 |         nodes_q = self.query(input_graph)
36 |         nodes_k = self.key(input_graph)
37 |         nodes_v = self.value(input_graph)
38 | 
39 |         nodes_q_t = self.transpose_for_scores(nodes_q)
40 |         nodes_k_t = self.transpose_for_scores(nodes_k)
41 |         nodes_v_t = self.transpose_for_scores(nodes_v)
42 | 
43 |         # Take the dot product between "query" and "key" to get the raw attention scores.
44 |         attention_scores = torch.matmul(nodes_q_t, nodes_k_t.transpose(-1, -2))
45 |         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
46 |         # No attention mask is needed here: every graph node attends to all others
47 |         attention_scores = attention_scores  # (a mask, as in BERT, would be applied at this point)
48 | 
49 |         # Normalize the attention scores to probabilities.
50 |         attention_probs = nn.Softmax(dim=-1)(attention_scores)
51 | 
52 |         # This is actually dropping out entire tokens to attend to, which might
53 |         # seem a bit unusual, but is taken from the original Transformer paper.
54 |         attention_probs = self.dropout(attention_probs)
55 | 
56 |         nodes_new = torch.matmul(attention_probs, nodes_v_t)
57 |         nodes_new = nodes_new.permute(0, 2, 1, 3).contiguous()
58 |         new_nodes_shape = nodes_new.size()[:-2] + (self.all_head_size,)
59 |         nodes_new = nodes_new.view(*new_nodes_shape)
60 |         return nodes_new
61 | 
62 | 
63 | class GATLayer(nn.Module):
64 |     def __init__(self, config):
65 |         super(GATLayer, self).__init__()
66 |         self.mha = MultiHeadAttention(config)
67 | 
68 |         self.fc_in = nn.Linear(config.hidden_size, config.hidden_size)
69 |         self.bn_in = nn.BatchNorm1d(config.hidden_size)
70 |         self.dropout_in = nn.Dropout(config.hidden_dropout_prob)
71 | 
72 |         self.fc_int = nn.Linear(config.hidden_size, config.hidden_size)
73 | 
74 |         self.fc_out = nn.Linear(config.hidden_size, config.hidden_size)
75 |         self.bn_out = nn.BatchNorm1d(config.hidden_size)
76 |         self.dropout_out = nn.Dropout(config.hidden_dropout_prob)
77 | 
78 |     def forward(self, input_graph):
79 |         attention_output = self.mha(input_graph)  # multi-head attention
80 |         attention_output = self.fc_in(attention_output)
81 |         attention_output = self.dropout_in(attention_output)
82 |         attention_output = self.bn_in((attention_output + input_graph).permute(0, 2, 1)).permute(0, 2, 1)
83 |         intermediate_output = self.fc_int(attention_output)
84 |         intermediate_output = F.relu(intermediate_output)
85 |         intermediate_output = self.fc_out(intermediate_output)
86 |         intermediate_output = self.dropout_out(intermediate_output)
87 |         graph_output = self.bn_out((intermediate_output + attention_output).permute(0, 2, 1)).permute(0, 2, 1)
88 |         return graph_output
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | This is the official source code for **Dual Semantic Relations Attention Network(DSRAN)** proposed in our journal paper [Learning Dual Semantic Relations with Graph Attention for Image-Text Matching (TCSVT 2020)](https://arxiv.org/abs/2010.11550). It is built on top of the [VSE++](https://github.com/fartashf/vsepp) in PyTorch. 
3 | 
4 | 
5 | **The framework of DSRAN:**
6 | 
7 | *(figure: `figures/model.jpg`)*
8 | 
9 | **The results on the MSCOCO and Flickr30K datasets (with BERT or GRU):**

**GRU**

| Dataset | Image-to-Text R@1 | R@5 | R@10 | Text-to-Image R@1 | R@5 | R@10 | Rsum |
| --- | --- | --- | --- | --- | --- | --- | --- |
| MSCOCO-1K | 80.4 | 96.7 | 98.7 | 64.2 | 90.4 | 95.8 | 526.2 |
| MSCOCO-5K | 57.6 | 85.6 | 91.9 | 41.5 | 71.9 | 82.1 | 430.6 |
| Flickr30k | 79.6 | 95.6 | 97.5 | 58.6 | 85.8 | 91.3 | 508.4 |
**BERT**

| Dataset | Image-to-Text R@1 | R@5 | R@10 | Text-to-Image R@1 | R@5 | R@10 | Rsum |
| --- | --- | --- | --- | --- | --- | --- | --- |
| MSCOCO-1K | 80.6 | 96.7 | 98.7 | 64.5 | 90.8 | 95.8 | 527.1 |
| MSCOCO-5K | 57.9 | 85.3 | 92.0 | 41.7 | 72.7 | 82.8 | 432.4 |
| Flickr30k | 80.5 | 95.5 | 97.9 | 59.2 | 86.0 | 91.9 | 511.0 |
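R@K is the percentage of queries whose ground-truth match is ranked within the top K retrieved items, and Rsum is the sum of the six recall scores. A minimal sketch of how these numbers fall out of an image-to-caption similarity matrix, mirroring `simrank` in `evaluation.py` (names here are illustrative):

```python
import numpy as np

def recall_at_k(sims, ks=(1, 5, 10)):
    # sims: (n_images, 5 * n_images) image-to-caption similarity matrix;
    # captions 5*i .. 5*i+4 are the ground truth for image i (dataset convention).
    n_images = sims.shape[0]
    ranks = np.zeros(n_images)
    for i in range(n_images):
        order = np.argsort(sims[i])[::-1]  # caption indices, best match first
        # rank of the best-ranked ground-truth caption for image i
        ranks[i] = min(np.where(order == j)[0][0] for j in range(5 * i, 5 * i + 5))
    return [100.0 * np.mean(ranks < k) for k in ks]
```

Text-to-image recalls are computed analogously from the transposed matrix, and Rsum adds up all six numbers.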
107 | 
108 | ## Requirements and Installation
109 | We recommend the following dependencies.
110 | * Python 3.6
111 | * PyTorch 1.1.0
112 | * NumPy (>1.12.1)
113 | * torchtext
114 | * pycocotools
115 | * nltk
116 | 
117 | ## Download data
118 | 
119 | Download the raw images, pre-computed image features, pre-trained BERT models, the pre-trained ResNet152 model, and the pre-trained DSRAN models. The raw images can be downloaded from [VSE++](https://github.com/fartashf/vsepp).
120 | 
121 | ```
122 | wget http://www.cs.toronto.edu/~faghri/vsepp/data.tar
123 | wget http://www.cs.toronto.edu/~faghri/vsepp/vocab.tar
124 | ```
125 | We refer to the path of the files extracted from `data.tar` as `$DATA_PATH`; only the raw images (the `coco` and `f30k` folders) are used.
126 | 
127 | The pre-computed image features can be obtained from [VLP](https://github.com/LuoweiZhou/VLP). These zip files should be extracted into the folder `data/joint-pretrain`. We refer to the path of the extracted `region_bbox_file(.h5)` as `$REGION_BBOX_FILE`, and to the regional feature paths (`feat_cls_1000/` for COCO and `trainval/` for Flickr30K) as `$FEATURE_PATH`.
128 | 
129 | The pre-trained ResNet152 model can be downloaded from [torchvision](https://download.pytorch.org/models/resnet152-b121ed2d.pth) and put in the root directory.
130 | ```
131 | wget https://download.pytorch.org/models/resnet152-b121ed2d.pth
132 | ```
133 | For our trained DSRAN models, you can download `runs.zip` from [Google Drive](https://drive.google.com/drive/folders/1SQiRpO3L8d9QxFSRdk31PZrxRUi3eXyW?usp=sharing), or `GRU.zip` together with `BERT.zip` from [BaiduNetDisk](https://pan.baidu.com/s/1H_iMH-QZETAdHLk03dBREA) (extract code: 1119). There are 8 models in total (4 for each dataset).
134 | 
135 | The pre-trained BERT models are obtained from an old version of [transformers](https://github.com/huggingface/transformers). Note that current [transformers](https://github.com/huggingface/transformers) offers a simpler way of using BERT; we will update the code in the future. The pre-trained models we use can be downloaded from the same [Google Drive](https://drive.google.com/drive/folders/1SQiRpO3L8d9QxFSRdk31PZrxRUi3eXyW?usp=sharing) and [BaiduNetDisk](https://pan.baidu.com/s/1H_iMH-QZETAdHLk03dBREA) (extract code: 1119) links. We refer to the path of the files extracted from `uncased_L-12_H-768_A-12.zip` as `$BERT_PATH`.
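After downloading, a quick sanity check of the extracted features can save a failed training run. A small sketch, assuming `h5py` is installed; the paths are placeholders, and the key layout follows what `data.py`/`data_bert.py` read:

```python
import h5py
import numpy as np

# $REGION_BBOX_FILE is an HDF5 file keyed by image id (see get_rcnn in data_bert.py)
with h5py.File('path/to/region_bbox_file.h5', 'r') as f:
    img_id = next(iter(f.keys()))
    print(img_id, f[img_id].shape)  # per-region bounding-box data for one image

# $FEATURE_PATH for Flickr30K holds one .npy file of region features per image
feat = np.load('path/to/trainval/SOME_IMAGE_ID.npy')
print(feat.shape)  # expected (100, 2048): 100 regions, 2048-d features
```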
136 | 
137 | 
138 | ### Data Structure
139 | ```
140 | ├── data/
141 | |   ├── coco/                /* MSCOCO raw images
142 | |   |   ├── images/
143 | |   |   |   ├── train2014/
144 | |   |   |   ├── val2014/
145 | |   |   ├── annotations/
146 | |   ├── f30k/                /* Flickr30K raw images
147 | |   |   ├── images/
148 | |   |   ├── dataset_flickr30k.json
149 | |   ├── joint-pretrain/      /* pre-computed image features
150 | |   |   ├── COCO/
151 | |   |   |   ├── region_feat_gvd_wo_bgd/
152 | |   |   |   |   ├── feat_cls_1000/                                                    /* $FEATURE_PATH
153 | |   |   |   |   ├── coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5   /* $REGION_BBOX_FILE
154 | |   |   |   ├── annotations/
155 | |   |   ├── flickr30k/
156 | |   |   |   ├── region_feat_gvd_wo_bgd/
157 | |   |   |   |   ├── trainval/                                                              /* $FEATURE_PATH
158 | |   |   |   |   ├── flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5   /* $REGION_BBOX_FILE
159 | |   |   |   ├── annotations/
160 | ```
161 | 
162 | ## Evaluate trained models
163 | 
164 | ### Test on a single model:
165 | 
166 | + Test on the MSCOCO dataset (1K and 5K simultaneously):
167 | 
168 |   + Test on BERT-based models:
169 | 
170 |   ```bash
171 |   python evaluation_bert.py --model BERT/cc_model1 --fold --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
172 |   ```
173 | 
174 |   + Test on GRU-based models:
175 | 
176 |   ```bash
177 |   python evaluation.py --model GRU/cc_model1 --fold --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
178 |   ```
179 | 
180 | + Test on the Flickr30K dataset:
181 | 
182 |   + Test on BERT-based models:
183 | 
184 |   ```bash
185 |   python evaluation_bert.py --model BERT/f_model1 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
186 |   ```
187 | 
188 |   + Test on GRU-based models:
189 | 
190 |   ```bash
191 |   python evaluation.py --model GRU/f_model1 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
192 |   ```
193 | 
194 | ### Test on a two-model ensemble with re-ranking:
195 | 
196 | /* Remember to modify the "$DATA_PATH", "$REGION_BBOX_FILE" and "$FEATURE_PATH" in the .sh files.
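The ensemble scripts run two trained models, combine the similarity matrices they save (the `.npy` files written under `coco_sims/` or `flickr_sims/` by `evaluation.py`/`evaluation_bert.py`), and then re-rank. A minimal sketch of the score-combination idea, assuming simple averaging before ranking (file names are illustrative; see `rerank.py` for the actual re-ranking step):

```python
import numpy as np

# Each evaluation run saves an image-to-caption similarity matrix (see evaluation.py).
sims_a = np.load('coco_sims/sims_model1.npy')  # model 1 scores
sims_b = np.load('coco_sims/sims_model2.npy')  # model 2 scores
sims = (sims_a + sims_b) / 2.0                 # ensemble by score averaging
```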
197 | 
198 | + Test on the MSCOCO dataset (1K and 5K simultaneously):
199 | 
200 |   + Test on BERT-based models:
201 | 
202 |   ```bash
203 |   sh test_bert_cc.sh
204 |   ```
205 | 
206 |   + Test on GRU-based models:
207 | 
208 |   ```bash
209 |   sh test_gru_cc.sh
210 |   ```
211 | 
212 | + Test on the Flickr30K dataset:
213 | 
214 |   + Test on BERT-based models:
215 | 
216 |   ```bash
217 |   sh test_bert_f.sh
218 |   ```
219 | 
220 |   + Test on GRU-based models:
221 | 
222 |   ```bash
223 |   sh test_gru_f.sh
224 |   ```
225 | 
226 | ## Train new models
227 | 
228 | Train a model with BERT on MSCOCO:
229 | 
230 | ```bash
231 | python train_bert.py --data_path "$DATA_PATH" --data_name coco --num_epochs 18 --batch_size 320 --lr_update 9 --logger_name runs/cc_bert --bert_path "$BERT_PATH" --ft_bert --warmup 0.1 --K 4 --feature_path "$FEATURE_PATH" --region_bbox_file "$REGION_BBOX_FILE"
232 | ```
233 | 
234 | Train a model with BERT on Flickr30K:
235 | 
236 | ```bash
237 | python train_bert.py --data_path "$DATA_PATH" --data_name f30k --num_epochs 12 --batch_size 128 --lr_update 6 --logger_name runs/f_bert --bert_path "$BERT_PATH" --ft_bert --warmup 0.1 --K 2 --feature_path "$FEATURE_PATH" --region_bbox_file "$REGION_BBOX_FILE"
238 | ```
239 | 
240 | Train a model with GRU on MSCOCO:
241 | 
242 | ```bash
243 | python train.py --data_path "$DATA_PATH" --data_name coco --num_epochs 18 --batch_size 300 --lr_update 9 --logger_name runs/cc_gru --use_restval --K 2 --feature_path "$FEATURE_PATH" --region_bbox_file "$REGION_BBOX_FILE"
244 | ```
245 | 
246 | Train a model with GRU on Flickr30K:
247 | 
248 | ```bash
249 | python train.py --data_path "$DATA_PATH" --data_name f30k --num_epochs 16 --batch_size 128 --lr_update 8 --logger_name runs/f_gru --use_restval --K 2 --feature_path "$FEATURE_PATH" --region_bbox_file "$REGION_BBOX_FILE"
250 | ```
251 | 
252 | ## Acknowledgement
253 | We thank [Linyang Li](https://github.com/LinyangLee) for his help with the code and for providing some of the computing resources.
254 | ## Reference
255 | 
256 | If DSRAN is useful for your research, please cite our paper:
257 | 
258 | ```
259 | @ARTICLE{9222079,
260 |   author={Wen, Keyu and Gu, Xiaodong and Cheng, Qingrong},
261 |   journal={IEEE Transactions on Circuits and Systems for Video Technology},
262 |   title={Learning Dual Semantic Relations With Graph Attention for Image-Text Matching},
263 |   year={2021},
264 |   volume={31},
265 |   number={7},
266 |   pages={2866-2879},
267 |   doi={10.1109/TCSVT.2020.3030656}}
268 | ```
269 | 
270 | ## License
271 | 
272 | [Apache License 2.0](http://www.apache.org/licenses/LICENSE-2.0)
273 | 
--------------------------------------------------------------------------------
/coco_sims/coco_sims.txt:
--------------------------------------------------------------------------------
1 | Path to save the similarity matrices produced during the inference stage on MSCOCO.
2 | 
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on
3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives"
4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
7 | # Written by Keyu Wen & Linyang Li, 2020
8 | # ------------------------------------------------------------
9 | 
10 | import torch
11 | import torch.utils.data as data
12 | import torchvision.transforms as transforms
13 | import os
14 | import nltk
15 | from PIL import Image
16 | from pycocotools.coco import COCO
17 | import numpy as np
18 | import json as jsonmod
19 | import time
20 | import copy
21 | import h5py
22 | import torch.nn.functional as F
23 | 
24 | 
25 | def get_paths(path, name='coco', use_restval=False):
26 | 
27 |     roots = {}
28 |     ids = {}
29 |     if 'coco' == name:
30 |         imgdir = os.path.join(path, 'images')
31 |         capdir = os.path.join(path, 'annotations')
32 |         roots['train'] = {
33 |             'img': os.path.join(imgdir, 'train2014'),
34 |             'cap': os.path.join(capdir, 'captions_train2014.json')
35 |         }
36 |         roots['val'] = {
37 |             'img': os.path.join(imgdir, 'val2014'),
38 |             'cap': os.path.join(capdir, 'captions_val2014.json')
39 |         }
40 |         roots['test'] = {
41 |             'img': os.path.join(imgdir, 'val2014'),
42 |             'cap': os.path.join(capdir, 'captions_val2014.json')
43 |         }
44 |         roots['trainrestval'] = {
45 |             'img': (roots['train']['img'], roots['val']['img']),
46 |             'cap': (roots['train']['cap'], roots['val']['cap'])
47 |         }
48 |         ids['train'] = np.load(os.path.join(capdir, 'coco_train_ids.npy'))
49 |         ids['val'] = np.load(os.path.join(capdir, 'coco_dev_ids.npy'))[:5000]
50 |         ids['test'] = np.load(os.path.join(capdir, 'coco_test_ids.npy'))
51 |         ids['trainrestval'] = (
52 |             ids['train'],
53 |             np.load(os.path.join(capdir, 'coco_restval_ids.npy')))
54 |         if use_restval:
55 |             roots['train'] = roots['trainrestval']
56 |             ids['train'] = ids['trainrestval']
57 |     elif 'f8k' == name:
58 |         imgdir = os.path.join(path, 'images')
59 |         cap = os.path.join(path, 'dataset_flickr8k.json')
60 |         roots['train'] = {'img': imgdir, 'cap': cap}
61 |         roots['val'] = {'img': imgdir, 'cap': cap}
62 |         roots['test'] = {'img': imgdir, 'cap': cap}
63 |         ids = {'train': None, 'val': None, 'test': None}
64 |     elif 'f30k' == name:
65 |         imgdir = os.path.join(path, '')
66 |         cap = os.path.join(path, 'dataset_flickr30k.json')
67 |         roots['train'] = {'img': imgdir, 'cap': cap}
68 |         roots['val'] = {'img': imgdir, 'cap': cap}
69 |         roots['test'] = {'img': imgdir, 'cap': cap}
70 |         ids = {'train': None, 'val': None, 'test': None}
71 | 
72 |     return roots, ids
73 | 
74 | 
75 | class CocoDataset(data.Dataset):
76 |     """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
77 | 
78 |     def __init__(self, root, json, vocab, region_bbox_file, region_det_file_prefix, transform=None, ids=None):
79 |         """
80 |         Args:
81 |             root: image directory.
82 |             json: coco annotation file path.
83 |             vocab: vocabulary wrapper.
84 |             transform: image transformer.
85 | """ 86 | self.root = root 87 | # when using `restval`, two json files are needed 88 | if isinstance(json, tuple): 89 | self.coco = (COCO(json[0]), COCO(json[1])) 90 | else: 91 | self.coco = (COCO(json),) 92 | self.root = (root,) 93 | # if ids provided by get_paths, use split-specific ids 94 | if ids is None: 95 | self.ids = list(self.coco.anns.keys()) 96 | else: 97 | self.ids = ids 98 | 99 | # if `restval` data is to be used, record the break point for ids 100 | if isinstance(self.ids, tuple): 101 | self.bp = len(self.ids[0]) 102 | self.ids = list(self.ids[0]) + list(self.ids[1]) 103 | else: 104 | self.bp = len(self.ids) 105 | self.vocab = vocab 106 | self.transform = transform 107 | self.region_bbox_file = region_bbox_file#'/remote-home/lyli/Workspace/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5' 108 | self.region_det_file_prefix = region_det_file_prefix#'/remote-home/lyli/Workspace/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval' 109 | 110 | def __getitem__(self, index): 111 | """This function returns a tuple that is further passed to collate_fn 112 | """ 113 | vocab = self.vocab 114 | root, caption, img_id, path, image, img_rcnn, img_pe = self.get_raw_item(index) 115 | 116 | if self.transform is not None: 117 | image = self.transform(image) 118 | 119 | # Convert caption (string) to word ids. 120 | tokens = nltk.tokenize.word_tokenize( 121 | str(caption).lower().encode('utf-8').decode('utf-8')) 122 | caption = [] 123 | caption.append(vocab('')) 124 | caption.extend([vocab(token) for token in tokens]) 125 | caption.append(vocab('')) 126 | target = torch.Tensor(caption) 127 | 128 | return image, target, img_rcnn, img_pe, index, img_id 129 | 130 | def get_raw_item(self, index): 131 | if index < self.bp: 132 | coco = self.coco[0] 133 | root = self.root[0] 134 | else: 135 | coco = self.coco[1] 136 | root = self.root[1] 137 | ann_id = self.ids[index] 138 | caption = coco.anns[ann_id]['caption'] 139 | img_id = coco.anns[ann_id]['image_id'] 140 | path = coco.loadImgs(img_id)[0]['file_name'] 141 | image = Image.open(os.path.join(root, path)).convert('RGB') 142 | img_rcnn, img_pe = self.get_rcnn(path) 143 | 144 | return root, caption, img_id, path, image, img_rcnn, img_pe 145 | 146 | def get_rcnn(self, path): 147 | img_id = path.split('/')[-1].split('.')[0] 148 | with h5py.File(self.region_det_file_prefix + '_feat' + img_id[-3:] + '.h5', 'r') as region_feat_f: 149 | img = torch.from_numpy(region_feat_f[img_id][:]).float() 150 | 151 | vis_pe = torch.randn(100,1601 + 6) # no position information 152 | return img, vis_pe 153 | 154 | def __len__(self): 155 | return len(self.ids) 156 | 157 | 158 | class FlickrDataset(data.Dataset): 159 | """ 160 | Dataset loader for Flickr30k and Flickr8k full datasets. 
161 | """ 162 | 163 | def __init__(self, root, json, split, vocab, region_bbox_file, feature_path, transform=None): 164 | self.root = root 165 | self.vocab = vocab 166 | self.split = split 167 | self.transform = transform 168 | self.dataset = jsonmod.load(open(json, 'r'))['images'] 169 | self.ids = [] 170 | for i, d in enumerate(self.dataset): 171 | if d['split'] == split: 172 | self.ids += [(i, x) for x in range(len(d['sentences']))] 173 | self.region_bbox_file = region_bbox_file#'/home/wenkeyu/wky/projects/pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5' 174 | self.feature_path = feature_path#'/home/wenkeyu/wky/projects/pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/' 175 | 176 | def __getitem__(self, index): 177 | """This function returns a tuple that is further passed to collate_fn 178 | """ 179 | vocab = self.vocab 180 | root = self.root + '/images' 181 | ann_id = self.ids[index] 182 | img_id = ann_id[0] 183 | caption = self.dataset[img_id]['sentences'][ann_id[1]]['raw'] 184 | path = self.dataset[img_id]['filename'] 185 | 186 | image = Image.open(os.path.join(root, path)).convert('RGB') 187 | if self.transform is not None: 188 | image = self.transform(image) 189 | 190 | path_orig = copy.deepcopy(path) 191 | # print(path) 192 | path = path.replace('.jpg', '.npy') 193 | feature_path = self.feature_path 194 | 195 | image_rcnn, img_pos = self.get_rcnn(os.path.join(feature_path, path)) # return img-feature 100 2048 & pos-feature 196 | 197 | # Convert caption (string) to word ids. 198 | tokens = nltk.tokenize.word_tokenize( 199 | str(caption).lower().encode('utf-8').decode('utf-8')) 200 | caption = [] 201 | caption.append(vocab('')) 202 | caption.extend([vocab(token) for token in tokens]) 203 | caption.append(vocab('')) 204 | target = torch.Tensor(caption) 205 | return image, target, image_rcnn, img_pos, index, img_id 206 | 207 | def get_rcnn(self, img_path): 208 | if os.path.exists(img_path) and os.path.exists(img_path.replace('.npy', '_cls_prob.npy')): 209 | # time1 = time.time() 210 | img = torch.from_numpy(np.load(img_path)) 211 | vis_pe = torch.randn(100,1601 + 6) # no position information 212 | else: 213 | img = torch.randn(100, 2048) 214 | vis_pe = torch.randn(100, 1601 + 6) 215 | return img, vis_pe 216 | 217 | 218 | def __len__(self): 219 | return len(self.ids) 220 | 221 | 222 | def collate_fn(data): 223 | """Build mini-batch tensors from a list of (image, caption) tuples. 224 | Args: 225 | data: list of (image, caption) tuple. 226 | - image: torch tensor of shape (3, 256, 256). 227 | - caption: torch tensor of shape (?); variable length. 228 | 229 | Returns: 230 | images: torch tensor of shape (batch_size, 3, 256, 256). 231 | targets: torch tensor of shape (batch_size, padded_length). 232 | lengths: list; valid length for each padded caption. 
233 | """ 234 | # Sort a data list by caption length 235 | data.sort(key=lambda x: len(x[1]), reverse=True) 236 | images, captions, image_rcnn, img_pos, ids, img_ids = zip(*data) 237 | 238 | # Merge images (convert tuple of 3D tensor to 4D tensor) 239 | images = torch.stack(images, 0) 240 | image_rcnn = torch.stack(image_rcnn, 0) 241 | img_pos = torch.stack(img_pos, 0) 242 | # Merget captions (convert tuple of 1D tensor to 2D tensor) 243 | lengths = [len(cap) for cap in captions] 244 | targets = torch.zeros(len(captions), max(lengths)).long() 245 | for i, cap in enumerate(captions): 246 | end = lengths[i] 247 | targets[i, :end] = cap[:end] 248 | 249 | return images, targets, image_rcnn, img_pos, lengths, ids 250 | 251 | 252 | def get_loader_single(data_name, split, root, json, vocab, transform, batch_size=100, shuffle=True, 253 | num_workers=2, ids=None, collate_fn=collate_fn, region_bbox_file=None, feature_path=None): 254 | """Returns torch.utils.data.DataLoader for custom coco dataset.""" 255 | if 'coco' in data_name: 256 | # COCO custom dataset 257 | dataset = CocoDataset(root=root, 258 | json=json, 259 | vocab=vocab, 260 | region_bbox_file=region_bbox_file, 261 | region_det_file_prefix=feature_path, 262 | transform=transform, ids=ids) 263 | elif 'f8k' in data_name or 'f30k' in data_name: 264 | dataset = FlickrDataset(root=root, 265 | split=split, 266 | json=json, 267 | vocab=vocab, 268 | region_bbox_file=region_bbox_file, 269 | feature_path=feature_path, 270 | transform=transform) 271 | 272 | # Data loader 273 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 274 | batch_size=batch_size, 275 | shuffle=shuffle, 276 | pin_memory=True, 277 | num_workers=num_workers, 278 | collate_fn=collate_fn) 279 | return data_loader 280 | 281 | 282 | def get_transform(data_name, split_name, opt): 283 | normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], 284 | std=[0.229, 0.224, 0.225]) 285 | t_list = [] 286 | if split_name == 'train': 287 | t_list = [transforms.RandomResizedCrop(opt.crop_size), 288 | transforms.RandomHorizontalFlip()] 289 | elif split_name == 'val': 290 | t_list = [transforms.Resize(256), transforms.CenterCrop(224)] 291 | elif split_name == 'test': 292 | t_list = [transforms.Resize(256), transforms.CenterCrop(224)] 293 | 294 | t_end = [transforms.ToTensor(), normalizer] 295 | transform = transforms.Compose(t_list + t_end) 296 | return transform 297 | 298 | 299 | def get_loaders(data_name, vocab, crop_size, batch_size, workers, opt): 300 | dpath = os.path.join(opt.data_path, data_name) 301 | 302 | roots, ids = get_paths(dpath, data_name, opt.use_restval) 303 | 304 | transform = get_transform(data_name, 'train', opt) 305 | train_loader = get_loader_single(opt.data_name, 'train', 306 | roots['train']['img'], 307 | roots['train']['cap'], 308 | vocab, transform, ids=ids['train'], 309 | batch_size=batch_size, shuffle=True, 310 | num_workers=workers, 311 | collate_fn=collate_fn, region_bbox_file=opt.region_bbox_file, 312 | feature_path=opt.feature_path) 313 | 314 | transform = get_transform(data_name, 'val', opt) 315 | val_loader = get_loader_single(opt.data_name, 'val', 316 | roots['val']['img'], 317 | roots['val']['cap'], 318 | vocab, transform, ids=ids['val'], 319 | batch_size=batch_size, shuffle=False, 320 | num_workers=workers, 321 | collate_fn=collate_fn, region_bbox_file=opt.region_bbox_file, 322 | feature_path=opt.feature_path) 323 | 324 | return train_loader, val_loader 325 | 326 | 327 | def get_test_loader(split_name, data_name, vocab, crop_size, batch_size, 328 
| workers, opt): 329 | dpath = os.path.join(opt.data_path, data_name) 330 | 331 | roots, ids = get_paths(dpath, data_name, opt.use_restval) 332 | 333 | transform = get_transform(data_name, split_name, opt) 334 | test_loader = get_loader_single(opt.data_name, split_name, 335 | roots[split_name]['img'], 336 | roots[split_name]['cap'], 337 | vocab, transform, ids=ids[split_name], 338 | batch_size=batch_size, shuffle=False, 339 | num_workers=workers, 340 | collate_fn=collate_fn, region_bbox_file=opt.region_bbox_file, 341 | feature_path=opt.feature_path) 342 | 343 | return test_loader 344 | -------------------------------------------------------------------------------- /data/coco/annotations.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/data/coco/annotations.zip -------------------------------------------------------------------------------- /data/coco/download.sh: -------------------------------------------------------------------------------- 1 | wget http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip -P data/ 2 | wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip -P data/ 3 | wget http://msvocds.blob.core.windows.net/coco2014/val2014.zip -P data/ 4 | 5 | unzip data/captions_train-val2014.zip -d ./ 6 | unzip data/train2014.zip -d images/ 7 | rm data/train2014.zip 8 | unzip data/val2014.zip -d images/ 9 | rm data/val2014.zip 10 | -------------------------------------------------------------------------------- /data/coco/images: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/f30k/dataset_flickr30k.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/data/f30k/dataset_flickr30k.zip -------------------------------------------------------------------------------- /data/f30k/images: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data_bert.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on 3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives" 4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching" 5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng 6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020 7 | # Writen by Keyu Wen & Linyang Li, 2020 8 | # ------------------------------------------------------------ 9 | 10 | import torch 11 | import torch.utils.data as data 12 | import torchvision.transforms as transforms 13 | import os 14 | from PIL import Image 15 | from pycocotools.coco import COCO 16 | import numpy as np 17 | import json as jsonmod 18 | from collections import OrderedDict 19 | import copy 20 | from pytorch_pretrained_bert.tokenization import BertTokenizer 21 | import torch.nn.functional as F 22 | import h5py 23 | 24 | 25 | def get_paths(path, name='coco'): 26 | roots = {} 27 | ids = {} 28 | if 'coco' == name: 29 | imgdir = os.path.join(path, 'images') 30 | capdir = os.path.join(path, 'annotations') 31 
| roots['train'] = { 32 | 'img': os.path.join(imgdir, 'train2014'), 33 | 'cap': os.path.join(capdir, 'captions_train2014.json') 34 | } 35 | roots['val'] = { 36 | 'img': os.path.join(imgdir, 'val2014'), 37 | 'cap': os.path.join(capdir, 'captions_val2014.json') 38 | } 39 | roots['test'] = { 40 | 'img': os.path.join(imgdir, 'val2014'), 41 | 'cap': os.path.join(capdir, 'captions_val2014.json') 42 | } 43 | roots['trainrestval'] = { 44 | 'img': (roots['train']['img'], roots['val']['img']), 45 | 'cap': (roots['train']['cap'], roots['val']['cap']) 46 | } 47 | ids['train'] = np.load(os.path.join(capdir, 'coco_train_ids.npy')) 48 | ids['val'] = np.load(os.path.join(capdir, 'coco_dev_ids.npy'))[:5000] 49 | ids['test'] = np.load(os.path.join(capdir, 'coco_test_ids.npy')) 50 | ids['trainrestval'] = ( 51 | ids['train'], 52 | np.load(os.path.join(capdir, 'coco_restval_ids.npy'))) 53 | 54 | roots['train'] = roots['trainrestval'] 55 | ids['train'] = ids['trainrestval'] 56 | elif 'f30k' == name: 57 | imgdir = os.path.join(path, 'images') 58 | cap = os.path.join(path, 'dataset_flickr30k.json') 59 | roots['train'] = {'img': imgdir, 'cap': cap} 60 | roots['val'] = {'img': imgdir, 'cap': cap} 61 | roots['test'] = {'img': imgdir, 'cap': cap} 62 | ids = {'train': None, 'val': None, 'test': None} 63 | 64 | return roots, ids 65 | 66 | 67 | class CocoDataset(data.Dataset): 68 | 69 | def __init__(self, root, json, tokenizer, feature_path=None, region_bbox_file=None, max_seq_len=32, transform=None, ids=None): 70 | self.root = root 71 | if isinstance(json, tuple): 72 | self.coco = (COCO(json[0]), COCO(json[1])) 73 | else: 74 | self.coco = (COCO(json),) 75 | self.root = (root,) 76 | if ids is None: 77 | self.ids = list(self.coco.anns.keys()) 78 | else: 79 | self.ids = ids 80 | if isinstance(self.ids, tuple): 81 | self.bp = len(self.ids[0]) 82 | self.ids = list(self.ids[0]) + list(self.ids[1]) 83 | else: 84 | self.bp = len(self.ids) 85 | self.transform = transform 86 | self.tokenizer = tokenizer 87 | self.max_seq_len = max_seq_len 88 | self.region_bbox_file = region_bbox_file 89 | self.region_det_file_prefix = feature_path 90 | 91 | def __getitem__(self, index): 92 | root, caption, img_id, path, image, img_rcnn, img_pe = self.get_raw_item(index) 93 | 94 | if self.transform is not None: 95 | image = self.transform(image) 96 | 97 | target = self.get_text_input(caption) 98 | return img_rcnn, img_pe, target, index, img_id, image 99 | 100 | def get_raw_item(self, index): 101 | if index < self.bp: 102 | coco = self.coco[0] 103 | root = self.root[0] 104 | else: 105 | coco = self.coco[1] 106 | root = self.root[1] 107 | ann_id = self.ids[index] 108 | caption = coco.anns[ann_id]['caption'] 109 | img_id = coco.anns[ann_id]['image_id'] 110 | path = coco.loadImgs(img_id)[0]['file_name'] 111 | image = Image.open(os.path.join(root, path)).convert('RGB') 112 | img_rcnn, img_pe = self.get_rcnn(path) 113 | return root, caption, img_id, path, image, img_rcnn, img_pe 114 | 115 | def __len__(self): 116 | return len(self.ids) 117 | 118 | def get_text_input(self, caption): 119 | caption_tokens = self.tokenizer.tokenize(caption) 120 | caption_tokens = ['[CLS]'] + caption_tokens + ['[SEP]'] 121 | caption_ids = self.tokenizer.convert_tokens_to_ids(caption_tokens) 122 | if len(caption_ids) >= self.max_seq_len: 123 | caption_ids = caption_ids[:self.max_seq_len] 124 | else: 125 | caption_ids = caption_ids + [0] * (self.max_seq_len - len(caption_ids)) 126 | caption = torch.tensor(caption_ids) 127 | return caption 128 | 129 | def get_rcnn(self, path): 
130 | img_id = path.split('/')[-1].split('.')[0] 131 | with h5py.File(self.region_det_file_prefix + '_feat' + img_id[-3:] + '.h5', 'r') as region_feat_f, \ 132 | h5py.File(self.region_det_file_prefix + '_cls' + img_id[-3:] + '.h5', 'r') as region_cls_f, \ 133 | h5py.File(self.region_bbox_file, 'r') as region_bbox_f: 134 | 135 | img = torch.from_numpy(region_feat_f[img_id][:]).float() 136 | cls_label = torch.from_numpy(region_cls_f[img_id][:]).float() 137 | vis_pe = torch.from_numpy(region_bbox_f[img_id][:]) 138 | 139 | # lazy normalization of the coordinates... 140 | 141 | w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5 142 | h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5 143 | vis_pe[:, [0, 2]] /= w_est 144 | vis_pe[:, [1, 3]] /= h_est 145 | rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] - vis_pe[:, 0]) 146 | rel_area.clamp_(0) 147 | 148 | vis_pe = torch.cat((vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]), -1) # confident score 149 | normalized_coord = F.normalize(vis_pe.data[:, :5] - 0.5, dim=-1) 150 | vis_pe = torch.cat((F.layer_norm(vis_pe, [6]), \ 151 | F.layer_norm(cls_label, [1601])), dim=-1) # 1601 hard coded... 152 | 153 | return img, vis_pe 154 | 155 | 156 | class FlickrDataset(data.Dataset): 157 | 158 | def __init__(self, root, json, split, tokenizer, feature_path=None, region_bbox_file=None, max_seq_len=32, 159 | transform=None): 160 | self.root = root 161 | self.split = split 162 | self.transform = transform 163 | self.dataset = jsonmod.load(open(json, 'r'))['images'] 164 | self.ids = [] 165 | self.tokenizer = tokenizer 166 | self.max_seq_len = max_seq_len 167 | for i, d in enumerate(self.dataset): 168 | if d['split'] == split: 169 | self.ids += [(i, x) for x in range(len(d['sentences']))] 170 | self.region_bbox_file = region_bbox_file 171 | self.feature_path = feature_path 172 | 173 | def __getitem__(self, index): 174 | root = self.root 175 | ann_id = self.ids[index] 176 | img_id = ann_id[0] 177 | caption = self.dataset[img_id]['sentences'][ann_id[1]]['raw'] 178 | path = self.dataset[img_id]['filename'] 179 | path_orig = copy.deepcopy(path) 180 | path = path.replace('.jpg', '.npy') 181 | feature_path = self.feature_path 182 | # orig image 183 | image_orig = Image.open(os.path.join(root, path_orig)).convert('RGB') 184 | if self.transform is not None: 185 | image_orig = self.transform(image_orig) 186 | target = self.get_text_input(caption) 187 | image, img_pos = self.get_rcnn(os.path.join(feature_path, path)) # return img-feature 100 2048 & pos-feature 188 | 189 | return image, img_pos, target, index, img_id, image_orig 190 | 191 | def get_rcnn(self, img_path): 192 | if os.path.exists(img_path) and os.path.exists(img_path.replace('.npy', '_cls_prob.npy')): 193 | img = torch.from_numpy(np.load(img_path)) 194 | img_id = img_path.split('/')[-1].split('.')[0] 195 | cls_label = torch.from_numpy(np.load(img_path.replace('.npy', '_cls_prob.npy'))) 196 | with h5py.File(self.region_bbox_file, 'r') as region_bbox_f: 197 | vis_pe = torch.from_numpy(region_bbox_f[img_id][:]) 198 | 199 | # lazy normalization of the coordinates... 200 | 201 | w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5 202 | h_est = torch.max(vis_pe[:, [1, 3]]) * 1. 
+ 1e-5 203 | vis_pe[:, [0, 2]] /= w_est 204 | vis_pe[:, [1, 3]] /= h_est 205 | rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] - vis_pe[:, 0]) 206 | rel_area.clamp_(0) 207 | 208 | vis_pe = torch.cat((vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]), -1) # confident score 209 | normalized_coord = F.normalize(vis_pe.data[:, :5] - 0.5, dim=-1) 210 | vis_pe = torch.cat((F.layer_norm(vis_pe, [6]), \ 211 | F.layer_norm(cls_label, [1601])), dim=-1) # 1601 hard coded... 212 | else: 213 | img = torch.randn(100, 2048) 214 | vis_pe = torch.randn(100, 1601 + 6) 215 | return img, vis_pe 216 | 217 | def get_text_input(self, caption): 218 | caption_tokens = self.tokenizer.tokenize(caption) 219 | caption_tokens = ['[CLS]'] + caption_tokens + ['[SEP]'] 220 | caption_ids = self.tokenizer.convert_tokens_to_ids(caption_tokens) 221 | if len(caption_ids) >= self.max_seq_len: 222 | caption_ids = caption_ids[:self.max_seq_len] 223 | else: 224 | caption_ids = caption_ids + [0] * (self.max_seq_len - len(caption_ids)) 225 | caption = torch.tensor(caption_ids) 226 | return caption 227 | 228 | def __len__(self): 229 | return len(self.ids) 230 | 231 | 232 | def collate_fn(data): 233 | images, img_pos, captions, ids, img_ids, image_orig = zip(*data) 234 | images = torch.stack(images, 0) 235 | img_pos = torch.stack(img_pos, 0) 236 | captions = torch.stack(captions, 0) 237 | images_orig = torch.stack(image_orig, 0) 238 | return images, images_orig, img_pos, captions, ids 239 | 240 | 241 | def get_tokenizer(bert_path): 242 | tokenizer = BertTokenizer(bert_path + 'vocab.txt') 243 | return tokenizer 244 | 245 | 246 | def get_loader_single(data_name, split, root, json, transform, 247 | batch_size=128, shuffle=True, 248 | num_workers=10, ids=None, collate_fn=collate_fn, 249 | feature_path=None, 250 | region_bbox_file=None, 251 | bert_path=None 252 | ): 253 | if 'coco' in data_name: 254 | dataset = CocoDataset(root=root, json=json, 255 | feature_path=feature_path, 256 | region_bbox_file=region_bbox_file, 257 | tokenizer=get_tokenizer(bert_path), 258 | max_seq_len=32, transform=transform, ids=ids) 259 | elif 'f30k' in data_name: 260 | dataset = FlickrDataset(root=root, split=split, json=json, 261 | feature_path=feature_path, 262 | region_bbox_file=region_bbox_file, 263 | tokenizer=get_tokenizer(bert_path), 264 | max_seq_len=32, transform=transform) 265 | 266 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 267 | batch_size=batch_size, 268 | shuffle=shuffle, 269 | pin_memory=True, 270 | num_workers=num_workers, 271 | collate_fn=collate_fn) 272 | return data_loader 273 | 274 | 275 | def get_transform(data_name, split_name, opt): 276 | normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], 277 | std=[0.229, 0.224, 0.225]) 278 | t_list = [] 279 | if split_name == 'train': 280 | t_list = [transforms.RandomResizedCrop(opt.crop_size), 281 | transforms.RandomHorizontalFlip()] 282 | # t_list = [transforms.Resize(256), transforms.CenterCrop(224)] 283 | elif split_name == 'val': 284 | t_list = [transforms.Resize(256), transforms.CenterCrop(224)] 285 | elif split_name == 'test': 286 | t_list = [transforms.Resize(256), transforms.CenterCrop(224)] 287 | 288 | t_end = [transforms.ToTensor(), normalizer] 289 | transform = transforms.Compose(t_list + t_end) 290 | return transform 291 | 292 | 293 | def get_loaders(data_name, batch_size, workers, opt): 294 | dpath = os.path.join(opt.data_path, data_name) 295 | roots, ids = get_paths(dpath, data_name) 296 | 297 | transform = get_transform(data_name, 'train', opt) 298 
| train_loader = get_loader_single(opt.data_name, 'train', 299 | roots['train']['img'], 300 | roots['train']['cap'], 301 | transform, ids=ids['train'], 302 | batch_size=batch_size, shuffle=True, 303 | num_workers=workers, 304 | collate_fn=collate_fn, 305 | feature_path=opt.feature_path, 306 | region_bbox_file=opt.region_bbox_file, 307 | bert_path=opt.bert_path 308 | ) 309 | 310 | transform = get_transform(data_name, 'val', opt) 311 | 312 | val_loader = get_loader_single(opt.data_name, 'val', 313 | roots['val']['img'], 314 | roots['val']['cap'], 315 | transform, ids=ids['val'], 316 | batch_size=batch_size, shuffle=False, 317 | num_workers=workers, 318 | collate_fn=collate_fn, 319 | feature_path=opt.feature_path, 320 | region_bbox_file=opt.region_bbox_file, 321 | bert_path=opt.bert_path 322 | ) 323 | 324 | return train_loader, val_loader 325 | 326 | 327 | def get_test_loader(split_name, data_name, batch_size, workers, opt): 328 | dpath = os.path.join(opt.data_path, data_name) 329 | 330 | roots, ids = get_paths(dpath, data_name) 331 | 332 | transform = get_transform(data_name, split_name, opt) 333 | test_loader = get_loader_single(opt.data_name, split_name, 334 | roots[split_name]['img'], 335 | roots[split_name]['cap'], 336 | transform, ids=ids[split_name], 337 | batch_size=batch_size, shuffle=False, 338 | num_workers=workers, 339 | collate_fn=collate_fn, 340 | feature_path=opt.feature_path, 341 | region_bbox_file=opt.region_bbox_file, 342 | bert_path=opt.bert_path 343 | ) 344 | 345 | return test_loader 346 | 347 | 348 | class AverageMeter(object): 349 | 350 | def __init__(self): 351 | self.reset() 352 | 353 | def reset(self): 354 | self.val = 0 355 | self.avg = 0 356 | self.sum = 0 357 | self.count = 0 358 | 359 | def update(self, val, n=0): 360 | self.val = val 361 | self.sum += val * n 362 | self.count += n 363 | self.avg = self.sum / (.0001 + self.count) 364 | 365 | def __str__(self): 366 | if self.count == 0: 367 | return str(self.val) 368 | return '%.4f (%.4f)' % (self.val, self.avg) 369 | 370 | 371 | class LogCollector(object): 372 | def __init__(self): 373 | self.meters = OrderedDict() 374 | 375 | def update(self, k, v, n=0): 376 | if k not in self.meters: 377 | self.meters[k] = AverageMeter() 378 | self.meters[k].update(v, n) 379 | 380 | def __str__(self): 381 | s = '' 382 | for i, (k, v) in enumerate(self.meters.items()): 383 | if i > 0: 384 | s += ' ' 385 | s += k + ' ' + str(v) 386 | return s 387 | 388 | def tb_log(self, tb_logger, prefix='', step=None): 389 | for k, v in self.meters.items(): 390 | tb_logger.log_value(prefix + k, v.val, step=step) 391 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on 3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives" 4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching" 5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng 6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020 7 | # Writen by Keyu Wen, 2020 8 | # ------------------------------------------------------------ 9 | 10 | from __future__ import print_function 11 | import os 12 | import pickle 13 | import numpy 14 | from data import get_test_loader 15 | import time 16 | import numpy as np 17 | from vocab import Vocabulary # NOQA 18 | import torch 19 | 
from model import VSE 20 | from collections import OrderedDict 21 | import argparse 22 | 23 | 24 | class AverageMeter(object): 25 | """Computes and stores the average and current value""" 26 | 27 | def __init__(self): 28 | self.reset() 29 | 30 | def reset(self): 31 | self.val = 0 32 | self.avg = 0 33 | self.sum = 0 34 | self.count = 0 35 | 36 | def update(self, val, n=0): 37 | self.val = val 38 | self.sum += val * n 39 | self.count += n 40 | self.avg = self.sum / (.0001 + self.count) 41 | 42 | def __str__(self): 43 | """String representation for logging 44 | """ 45 | # for values that should be recorded exactly e.g. iteration number 46 | if self.count == 0: 47 | return str(self.val) 48 | # for stats 49 | return '%.4f (%.4f)' % (self.val, self.avg) 50 | 51 | 52 | class LogCollector(object): 53 | """A collection of logging objects that can change from train to val""" 54 | 55 | def __init__(self): 56 | # to keep the order of logged variables deterministic 57 | self.meters = OrderedDict() 58 | 59 | def update(self, k, v, n=0): 60 | # create a new meter if previously not recorded 61 | if k not in self.meters: 62 | self.meters[k] = AverageMeter() 63 | self.meters[k].update(v, n) 64 | 65 | def __str__(self): 66 | """Concatenate the meters in one log line 67 | """ 68 | s = '' 69 | for i, (k, v) in enumerate(self.meters.items()): 70 | if i > 0: 71 | s += ' ' 72 | s += k + ' ' + str(v) 73 | return s 74 | 75 | def tb_log(self, tb_logger, prefix='', step=None): 76 | """Log using tensorboard 77 | """ 78 | for k, v in self.meters.items(): 79 | tb_logger.log_value(prefix + k, v.val, step=step) 80 | 81 | 82 | def encode_data(model, data_loader, log_step=10, logging=print): 83 | """Encode all images and captions loadable by `data_loader` 84 | """ 85 | batch_time = AverageMeter() 86 | val_logger = LogCollector() 87 | 88 | # switch to evaluate mode 89 | model.val_start() 90 | 91 | end = time.time() 92 | 93 | # numpy array to keep all the embeddings 94 | img_embs = None 95 | cap_embs = None 96 | with torch.no_grad(): 97 | for i, (images, captions, img_rcnn, img_pos, lengths, ids) in enumerate(data_loader): 98 | # make sure val logger is used 99 | model.logger = val_logger 100 | 101 | # compute the embeddings 102 | img_emb, cap_emb = model.forward_emb(images, captions, img_rcnn, img_pos, lengths) 103 | 104 | # initialize the numpy arrays given the size of the embeddings 105 | if img_embs is None: 106 | img_embs = torch.zeros(len(data_loader.dataset), img_emb.size(1)).cuda() 107 | cap_embs = torch.zeros(len(data_loader.dataset), cap_emb.size(1)).cuda() 108 | 109 | img_embs[ids] = img_emb 110 | cap_embs[ids] = cap_emb 111 | 112 | # measure accuracy and record loss 113 | model.forward_loss(img_emb, cap_emb) 114 | 115 | # measure elapsed time 116 | batch_time.update(time.time() - end) 117 | end = time.time() 118 | 119 | if i % log_step == 0: 120 | logging('Test: [{0}/{1}]\t' 121 | '{e_log}\t' 122 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 123 | .format( 124 | i, len(data_loader), batch_time=batch_time, 125 | e_log=str(model.logger))) 126 | del images, captions 127 | 128 | return img_embs, cap_embs 129 | 130 | 131 | def evalrank(model_path, data_path=None, split='dev', fold5=False, region_bbox_file=None, feature_path=None): 132 | """ 133 | Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold 134 | cross-validation is done (only for MSCOCO). Otherwise, the full data is 135 | used for evaluation. 
136 |     """
137 |     # load model and options
138 |     checkpoint = torch.load(model_path)
139 |     opt = checkpoint['opt']
140 |     if data_path is not None:
141 |         opt.data_path = data_path
142 |     if region_bbox_file is not None:
143 |         opt.region_bbox_file = region_bbox_file
144 |     if feature_path is not None:
145 |         opt.feature_path = feature_path
146 | 
147 |     # load vocabulary used by the model
148 |     with open(os.path.join(opt.vocab_path,
149 |                            '%s_vocab.pkl' % opt.data_name), 'rb') as f:
150 |         vocab = pickle.load(f)
151 |     opt.vocab_size = len(vocab)
152 |     print(opt)
153 | 
154 |     # construct model
155 |     model = VSE(opt)
156 |     # load model state
157 |     model.load_state_dict(checkpoint['model'])
158 | 
159 |     print('Loading dataset')
160 |     data_loader = get_test_loader(split, opt.data_name,
161 |                                   opt.batch_size, opt.workers, opt)
162 |     print('Computing results...')
163 |     img_embs, cap_embs = encode_data(model, data_loader)
164 |     time_sim_start = time.time()
165 | 
166 |     if not fold5:
167 |         img_emb_new = img_embs[0:img_embs.size(0):5]
168 |         print(img_emb_new.size())
169 | 
170 |         sims = torch.mm(img_emb_new, cap_embs.t())
171 |         sims_T = torch.mm(cap_embs, cap_embs.t())
172 |         sims_T = sims_T.cpu().numpy()
173 | 
174 |         sims = sims.cpu().numpy()
175 |         np.save('sims_f.npy', sims)
176 |         np.save('sims_f_T.npy', sims_T)
177 | 
178 |         print('Images: %d, Captions: %d' %
179 |               (img_embs.shape[0] / 5, cap_embs.shape[0]))
180 | 
181 |         r = simrank(sims)
182 | 
183 |         time_sim_end = time.time()
184 |         print('sims_time:%f' % (time_sim_end - time_sim_start))
185 |         del sims
186 |     else:  # fold5: especially for COCO
187 |         print('5k---------------')
188 |         img_emb_new = img_embs[0:img_embs.size(0):5]
189 |         print(img_emb_new.size())
190 | 
191 |         sims = torch.mm(img_emb_new, cap_embs.t())
192 |         sims_T = torch.mm(cap_embs, cap_embs.t())
193 | 
194 |         sims = sims.cpu().numpy()
195 |         sims_T = sims_T.cpu().numpy()
196 | 
197 |         np.save('sims_full_5k.npy', sims)
198 |         np.save('sims_full_T_5k.npy', sims_T)
199 |         print('Images: %d, Captions: %d' %
200 |               (img_embs.shape[0] / 5, cap_embs.shape[0]))
201 | 
202 |         r = simrank(sims)
203 | 
204 |         time_sim_end = time.time()
205 |         print('sims_time:%f' % (time_sim_end - time_sim_start))
206 |         del sims, sims_T
207 |         print('1k---------------')
208 |         r_ = [0, 0, 0, 0, 0, 0, 0]
209 |         for i in range(5):
210 |             print(i)
211 |             img_emb_new = img_embs[i * 5000:int(i * 5000 + img_embs.size(0) / 5):5]
212 |             cap_emb_new = cap_embs[i * 5000:int(i * 5000 + cap_embs.size(0) / 5)]
213 | 
214 |             sims = torch.mm(img_emb_new, cap_emb_new.t())
215 |             sims_T = torch.mm(cap_emb_new, cap_emb_new.t())
216 |             sims_T = sims_T.cpu().numpy()
217 |             sims = sims.cpu().numpy()
218 |             np.save('sims_full_%d.npy' % i, sims)
219 |             np.save('sims_full_T_%d.npy' % i, sims_T)
220 | 
221 |             print('Images: %d, Captions: %d' %
222 |                   (img_emb_new.size(0), cap_emb_new.size(0)))
223 | 
224 |             r = simrank(sims)
225 |             r_ = np.array(r_) + np.array(r)
226 | 
227 |             del sims
228 |         print('--------------------')
229 |         r_ = tuple(r_ / 5)
230 |         print('I2T:%.1f %.1f %.1f' % r_[0:3])
231 |         print('T2I:%.1f %.1f %.1f' % r_[3:6])
232 |         print('Rsum:%.1f' % r_[-1])
233 | 
234 | 
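# --- Editor's annotation (illustration only, not part of the original source).
# simrank() below scores an (images x captions) similarity matrix in which the
# five ground-truth captions of image i occupy columns 5*i .. 5*i+4. The toy
# matrix here is made up; it shows how an i2t rank and R@1 fall out of argsort.
def _simrank_toy_example():
    sims = numpy.array([
        [.9, .8, .7, .6, .5, .1, .1, .1, .1, .1],  # image 0: a ground truth ranks 1st
        [.2, .3, .1, .1, .1, .9, .1, .1, .1, .1],  # image 1: caption 5 ranks 1st
    ])
    ranks = []
    for i in range(sims.shape[0]):
        order = numpy.argsort(sims[i])[::-1]  # best-scoring caption first
        best = min(numpy.where(order == c)[0][0] for c in range(5 * i, 5 * i + 5))
        ranks.append(best)  # rank of the best ground-truth caption
    r1 = 100.0 * sum(r < 1 for r in ranks) / len(ranks)
    assert r1 == 100.0  # both toy queries retrieve a ground-truth caption at rank 0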
235 | def simrank(similarity):
236 |     sims = similarity
237 |     img_size, cap_size = sims.shape
238 |     print("imgs: %d, caps: %d" % (img_size, cap_size))
239 |     # i2t
240 |     index_list = []
241 |     ranks = numpy.zeros(img_size)
242 |     top1 = numpy.zeros(img_size)
243 |     for index in range(img_size):
244 |         d = sims[index]
245 |         inds = numpy.argsort(d)[::-1]
246 |         # print(inds)
247 |         index_list.append(inds[0])
248 |         rank = 1e20
249 |         for i in range(5 * index, 5 * index + 5, 1):
250 |             tmp = numpy.where(inds == i)[0][0]
251 |             # print(tmp)
252 |             if tmp < rank:
253 |                 rank = tmp
254 |         ranks[index] = rank
255 |         top1[index] = inds[0]
256 | 
257 |     r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
258 |     r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
259 |     r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
260 |     medr = numpy.floor(numpy.median(ranks)) + 1
261 |     meanr = ranks.mean() + 1
262 |     print('i2t:r1: %.1f, r5: %.1f, r10: %.1f' % (r1, r5, r10))  # , medr, meanr)
263 |     rs = r1 + r5 + r10
264 |     # t2i
265 |     sims_t2i = sims.T
266 |     ranks = numpy.zeros(cap_size)
267 |     top1 = numpy.zeros(cap_size)
268 |     for index in range(img_size):
269 | 
270 |         d = sims_t2i[5 * index:5 * index + 5]  # 5*1000
271 |         inds = numpy.zeros(d.shape)
272 |         for i in range(len(inds)):
273 |             inds[i] = numpy.argsort(d[i])[::-1]
274 |             ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0]
275 |             top1[5 * index + i] = inds[i][0]
276 | 
277 |     r1_ = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
278 |     r5_ = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
279 |     r10_ = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
280 |     medr_ = numpy.floor(numpy.median(ranks)) + 1
281 |     meanr_ = ranks.mean() + 1
282 |     rs_ = r1_ + r5_ + r10_
283 |     print('t2i:r1: %.1f, r5: %.1f, r10: %.1f' % (r1_, r5_, r10_))
284 |     rsum = rs + rs_
285 |     print('rsum=%.1f' % rsum)
286 |     return [r1, r5, r10, r1_, r5_, r10_, rsum]
287 | 
288 | 
289 | def i2t(images, captions, npts=None, return_ranks=False):
290 |     """
291 |     Images->Text (Image Annotation)
292 |     Images: (5N, K) matrix of images
293 |     Captions: (5N, K) matrix of captions
294 |     """
295 |     images = images.cpu().numpy()
296 |     captions = captions.cpu().numpy()
297 |     if npts is None:
298 |         npts = int(images.shape[0] / 5)
299 |         print(npts)
300 |     index_list = []
301 | 
302 |     ranks = numpy.zeros(npts)
303 |     top1 = numpy.zeros(npts)
304 |     for index in range(npts):
305 | 
306 |         # Get query image
307 |         im = images[5 * index].reshape(1, images.shape[1])
308 | 
309 |         # Compute scores
310 | 
311 |         d = numpy.dot(im, captions.T).flatten()
312 |         inds = numpy.argsort(d)[::-1]
313 |         index_list.append(inds[0])
314 | 
315 |         # Score
316 |         rank = 1e20
317 |         for i in range(5 * index, 5 * index + 5, 1):
318 |             tmp = numpy.where(inds == i)[0][0]
319 |             if tmp < rank:
320 |                 rank = tmp
321 |         ranks[index] = rank
322 |         top1[index] = inds[0]
323 | 
324 |     # Compute metrics
325 |     r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
326 |     r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
327 |     r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
328 |     medr = numpy.floor(numpy.median(ranks)) + 1
329 |     meanr = ranks.mean() + 1
330 |     if return_ranks:
331 |         return (r1, r5, r10, medr, meanr), (ranks, top1)
332 |     else:
333 |         return (r1, r5, r10, medr, meanr)
334 | 
335 | 
336 | def t2i(images, captions, npts=None, return_ranks=False):
337 |     """
338 |     Text->Images (Image Search)
339 |     Images: (5N, K) matrix of images
340 |     Captions: (5N, K) matrix of captions
341 |     """
342 |     images = images.cpu().numpy()
343 |     captions = captions.cpu().numpy()
344 |     if npts is None:
345 |         npts = int(images.shape[0] / 5)
346 |         print(npts)
347 |     ims = numpy.array([images[i] for i in range(0, len(images), 5)])
348 | 
349 |     ranks = numpy.zeros(5 * npts)
350 |     top1 = numpy.zeros(5 * npts)
351 |     for index in range(npts):
352 | 
353 |         # Get query captions
354 |         queries = captions[5 * index:5 * index + 5]
355 | 
356 |         # Compute scores
357 |         d = 
numpy.dot(queries, ims.T) 358 | inds = numpy.zeros(d.shape) 359 | for i in range(len(inds)): 360 | inds[i] = numpy.argsort(d[i])[::-1] 361 | ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0] 362 | top1[5 * index + i] = inds[i][0] 363 | 364 | # Compute metrics 365 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks) 366 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks) 367 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks) 368 | medr = numpy.floor(numpy.median(ranks)) + 1 369 | meanr = ranks.mean() + 1 370 | if return_ranks: 371 | return (r1, r5, r10, medr, meanr), (ranks, top1) 372 | else: 373 | return (r1, r5, r10, medr, meanr) 374 | 375 | 376 | def main(): 377 | parser = argparse.ArgumentParser() 378 | parser.add_argument('--model', default='single_model', help='model name') 379 | parser.add_argument('--fold', action='store_true', help='fold5') 380 | parser.add_argument('--name', default='model_best', help='checkpoint name') 381 | parser.add_argument('--data_path', default='data', help='data path') 382 | parser.add_argument('--region_bbox_file', default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5', type=str, metavar='PATH', 383 | help='path to region features bbox file') 384 | parser.add_argument('--feature_path', default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/', type=str, metavar='PATH', 385 | help='path to region features') 386 | opt = parser.parse_args() 387 | 388 | evalrank('runs/' + opt.model + '/' + opt.name + ".pth.tar", data_path = opt.data_path, split="test", fold5=opt.fold, region_bbox_file=opt.region_bbox_file, feature_path=opt.feature_path) 389 | 390 | if __name__ == '__main__': 391 | main() -------------------------------------------------------------------------------- /evaluation_bert.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | # ----------------------------------------------------------- 3 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on 4 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives" 5 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching" 6 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng 7 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020 8 | # Writen by Keyu Wen, 2020 9 | # ------------------------------------------------------------ 10 | 11 | from __future__ import print_function 12 | import numpy 13 | from data_bert import get_test_loader 14 | import time 15 | import numpy as np 16 | import torch 17 | import argparse 18 | from model_bert import VSE 19 | from collections import OrderedDict 20 | 21 | 22 | class AverageMeter(object): 23 | 24 | def __init__(self): 25 | self.reset() 26 | 27 | def reset(self): 28 | self.val = 0 29 | self.avg = 0 30 | self.sum = 0 31 | self.count = 0 32 | 33 | def update(self, val, n=0): 34 | self.val = val 35 | self.sum += val * n 36 | self.count += n 37 | self.avg = self.sum / (.0001 + self.count) 38 | 39 | def __str__(self): 40 | if self.count == 0: 41 | return str(self.val) 42 | return '%.4f (%.4f)' % (self.val, self.avg) 43 | 44 | 45 | class LogCollector(object): 46 | def __init__(self): 47 | self.meters = OrderedDict() 48 | 49 | def update(self, k, v, n=0): 50 | if k not in self.meters: 51 | self.meters[k] = AverageMeter() 52 | self.meters[k].update(v, n) 53 | 54 | def __str__(self): 55 | s = '' 56 | for i, (k, v) in 
enumerate(self.meters.items()):
 57 |             if i > 0:
 58 |                 s += '  '
 59 |             s += k + ' ' + str(v)
 60 |         return s
 61 | 
 62 |     def tb_log(self, tb_logger, prefix='', step=None):
 63 |         for k, v in self.meters.items():
 64 |             tb_logger.log_value(prefix + k, v.val, step=step)
 65 | 
 66 | 
 67 | def encode_data(model, data_loader, log_step=10, logging=print):
 68 |     batch_time = AverageMeter()
 69 |     val_logger = LogCollector()
 70 |     model.val_start()
 71 | 
 72 |     end = time.time()
 73 | 
 74 |     img_embs = None
 75 |     cap_embs = None
 76 |     time_encode_start = time.time()
 77 |     # device = torch.device("cuda:0")
 78 |     with torch.no_grad():
 79 |         for i, (images, images_orig, img_pos, captions, ids) in enumerate(data_loader):
 80 |             model.logger = val_logger
 81 | 
 82 |             img_emb, cap_emb = model.forward_emb(images_orig, images, img_pos, captions)
 83 | 
 84 |             if img_embs is None:
 85 |                 img_embs = torch.zeros(len(data_loader.dataset), img_emb.size(1)).cuda()
 86 |                 cap_embs = torch.zeros(len(data_loader.dataset), cap_emb.size(1)).cuda()
 87 | 
 88 |             img_embs[ids] = img_emb
 89 |             cap_embs[ids] = cap_emb
 90 | 
 91 |             model.forward_loss(img_emb, cap_emb)
 92 | 
 93 |             batch_time.update(time.time() - end)
 94 |             end = time.time()
 95 | 
 96 |             if i % log_step == 0:
 97 |                 logging('Test: [{0}/{1}]\t'
 98 |                         '{e_log}\t'
 99 |                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
100 |                         .format(
101 |                             i, len(data_loader), batch_time=batch_time,
102 |                             e_log=str(model.logger)))
103 |             del images, captions
104 |     time_encode_end = time.time()
105 |     print('encode_time:%f' % (time_encode_end - time_encode_start))
106 |     img_emb_new = img_embs[0:img_embs.size(0):5]
107 |     sims = torch.mm(img_emb_new, cap_embs.t())
108 |     sims = sims.cpu().numpy()
109 | 
110 |     return img_embs, cap_embs, sims
111 | 
112 | 
113 | def evalrank(model_path, data_path=None, split='dev', fold5=False, region_bbox_file=None, feature_path=None):
114 |     checkpoint = torch.load(model_path)
115 |     opt = checkpoint['opt']
116 | 
117 |     if data_path is not None:
118 |         opt.data_path = data_path
119 |     if region_bbox_file is not None:
120 |         opt.region_bbox_file = region_bbox_file
121 |     if feature_path is not None:
122 |         opt.feature_path = feature_path
123 | 
124 |     print(opt)
125 |     model = VSE(opt)
126 | 
127 |     model.load_state_dict(checkpoint['model'])
128 | 
129 |     print('Loading dataset')
130 |     data_loader = get_test_loader(split, opt.data_name, opt.batch_size, opt.workers, opt)
131 | 
132 |     print('Computing results...')
133 |     img_embs, cap_embs, sims = encode_data(model, data_loader)
134 | 
135 |     time_sim_start = time.time()
136 | 
137 |     if not fold5:
138 |         img_emb_new = img_embs[0:img_embs.size(0):5]
139 |         print(img_emb_new.size())
140 |         sims = torch.mm(img_emb_new, cap_embs.t())
141 |         sims_T = torch.mm(cap_embs, cap_embs.t())
142 |         sims_T = sims_T.cpu().numpy()
143 | 
144 |         sims = sims.cpu().numpy()
145 |         np.save('sims_f.npy', sims)
146 |         np.save('sims_f_T.npy', sims_T)
147 | 
148 |         print('Images: %d, Captions: %d' %
149 |               (img_embs.shape[0] / 5, cap_embs.shape[0]))
150 | 
151 |         r = simrank(sims)
152 | 
153 |         time_sim_end = time.time()
154 |         print('sims_time:%f' % (time_sim_end - time_sim_start))
155 |         del sims
156 |     else:  # fold5: especially for COCO
157 |         print('5k---------------')
158 |         img_emb_new = img_embs[0:img_embs.size(0):5]
159 |         print(img_emb_new.size())
160 | 
161 |         sims = torch.mm(img_emb_new, cap_embs.t())
162 |         sims_T = torch.mm(cap_embs, cap_embs.t())
163 | 
164 |         sims = sims.cpu().numpy()
165 |         sims_T = sims_T.cpu().numpy()
166 | 
167 |         np.save('sims_full_5k.npy', sims)
168 |         np.save('sims_full_T_5k.npy', sims_T)
169 |         print('Images: %d, Captions: %d' %
170 |               (img_embs.shape[0] / 5, cap_embs.shape[0]))
171 | 
172 |         r = simrank(sims)
173 | 
174 |         time_sim_end = time.time()
175 |         print('sims_time:%f' % (time_sim_end - time_sim_start))
176 |         del sims, sims_T
177 |         print('1k---------------')
178 |         r_ = [0, 0, 0, 0, 0, 0, 0]
179 |         for i in range(5):
180 |             print(i)
181 |             img_emb_new = img_embs[i * 5000:int(i * 5000 + img_embs.size(0) / 5):5]
182 |             cap_emb_new = cap_embs[i * 5000:int(i * 5000 + cap_embs.size(0) / 5)]
183 | 
184 |             sims = torch.mm(img_emb_new, cap_emb_new.t())
185 |             sims_T = torch.mm(cap_emb_new, cap_emb_new.t())
186 |             sims_T = sims_T.cpu().numpy()
187 |             sims = sims.cpu().numpy()
188 |             np.save('sims_full_%d.npy' % i, sims)
189 |             np.save('sims_full_T_%d.npy' % i, sims_T)
190 | 
191 |             print('Images: %d, Captions: %d' %
192 |                   (img_emb_new.size(0), cap_emb_new.size(0)))
193 | 
194 |             r = simrank(sims)
195 |             r_ = np.array(r_) + np.array(r)
196 | 
197 |             del sims
198 |         print('--------------------')
199 |         r_ = tuple(r_ / 5)
200 |         print('I2T:%.1f %.1f %.1f' % r_[0:3])
201 |         print('T2I:%.1f %.1f %.1f' % r_[3:6])
202 |         print('Rsum:%.1f' % r_[-1])
203 | 
204 | 
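# --- Editor's annotation (shape sketch only, not part of the original source).
# The fold5 branch of evalrank() above carves the 25000 stored rows (5000 COCO
# test images x 5 captions, with image j's rows at 5j..5j+4) into five 1k folds:
# a stride-5 slice keeps one row per image, the plain slice keeps its captions.
# The embedding size of 8 and the random tensors below are arbitrary.
def _fold5_slicing_sketch():
    img_embs = torch.randn(25000, 8)
    cap_embs = torch.randn(25000, 8)
    for i in range(5):
        img_fold = img_embs[i * 5000:i * 5000 + 5000:5]  # 1000 unique images
        cap_fold = cap_embs[i * 5000:i * 5000 + 5000]    # their 5000 captions
        sims = torch.mm(img_fold, cap_fold.t())
        assert sims.shape == (1000, 5000)                # one 1k-test fold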
205 | def i2t(images, captions, npts=None, return_ranks=False):
206 |     if npts is None:
207 |         npts = int(images.shape[0] / 5)
208 |         print(npts)
209 |     index_list = []
210 | 
211 |     ranks = numpy.zeros(npts)
212 |     top1 = numpy.zeros(npts)
213 | 
214 |     for index in range(npts):
215 | 
216 |         im = images[5 * index].reshape(1, images.shape[1])
217 | 
218 |         d = numpy.dot(im, captions.T).flatten()
219 |         inds = numpy.argsort(d)[::-1]
220 |         index_list.append(inds[0])
221 | 
222 |         rank = 1e20
223 |         for i in range(5 * index, 5 * index + 5, 1):
224 |             tmp = numpy.where(inds == i)[0][0]
225 |             if tmp < rank:
226 |                 rank = tmp
227 |         ranks[index] = rank
228 |         top1[index] = inds[0]
229 | 
230 |     r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
231 |     r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
232 |     r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
233 |     medr = numpy.floor(numpy.median(ranks)) + 1
234 |     meanr = ranks.mean() + 1
235 |     if return_ranks:
236 |         return (r1, r5, r10, medr, meanr), (ranks, top1)
237 |     else:
238 |         return (r1, r5, r10, medr, meanr)
239 | 
240 | 
241 | def t2i(images, captions, npts=None, return_ranks=False):
242 |     if npts is None:
243 |         npts = int(images.shape[0] / 5)
244 |         print(npts)
245 |     ims = numpy.array([images[i] for i in range(0, len(images), 5)])
246 | 
247 |     ranks = numpy.zeros(5 * npts)
248 |     top1 = numpy.zeros(5 * npts)
249 | 
250 |     for index in range(npts):
251 |         queries = captions[5 * index:5 * index + 5]
252 |         # print('3')
253 | 
254 |         d = np.dot(queries, ims.T)
255 | 
256 |         inds = numpy.zeros(d.shape)
257 |         # print('5')
258 |         for i in range(len(inds)):
259 |             inds[i] = numpy.argsort(d[i])[::-1]
260 |             ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0]
261 |             top1[5 * index + i] = inds[i][0]
262 | 
263 |     r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
264 |     r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
265 |     r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
266 |     medr = numpy.floor(numpy.median(ranks)) + 1
267 |     meanr = ranks.mean() + 1
268 |     if return_ranks:
269 |         return (r1, r5, r10, medr, meanr), (ranks, top1)
270 |     else:
271 |         return (r1, r5, r10, medr, meanr)
272 | 
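# --- Editor's note (not part of the original source): i2t()/t2i() above rank
# directly from embedding matrices via numpy.dot, while simrank() below consumes
# a precomputed image-by-caption similarity matrix (the arrays evalrank saves as
# .npy files). Because the model's embeddings are L2-normalized (see l2norm in
# model_bert.py), the dot product is the cosine similarity, so both paths
# produce identical rankings.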
273 | 
274 | def simrank(similarity):
275 |     sims = similarity  # similarity matrix 1k*5k
276 |     # print(sims)
277 |     img_size, cap_size = sims.shape
278 |     print("imgs: %d, caps: %d" % (img_size, cap_size))
279 |     # time.sleep(10)
280 |     # i2t
281 |     index_list = []
282 |     ranks = numpy.zeros(img_size)
283 |     top1 = numpy.zeros(img_size)
284 |     for index in range(img_size):
285 |         d = sims[index]
286 |         inds = numpy.argsort(d)[::-1]
287 |         index_list.append(inds[0])
288 |         rank = 1e20
289 |         for i in range(5 * index, 5 * index + 5, 1):
290 |             tmp = numpy.where(inds == i)[0][0]
291 |             if tmp < rank:
292 |                 rank = tmp
293 |         ranks[index] = rank
294 |         top1[index] = inds[0]
295 | 
296 |     r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
297 |     r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
298 |     r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
299 |     medr = numpy.floor(numpy.median(ranks)) + 1
300 |     meanr = ranks.mean() + 1
301 |     print('i2t:r1: %.1f, r5: %.1f, r10: %.1f' % (r1, r5, r10))  # , medr, meanr)
302 |     rs = r1 + r5 + r10
303 |     # t2i
304 |     sims_t2i = sims.T
305 |     ranks = numpy.zeros(cap_size)
306 |     top1 = numpy.zeros(cap_size)
307 |     for index in range(img_size):
308 | 
309 |         d = sims_t2i[5 * index:5 * index + 5]  # 5*1000
310 |         inds = numpy.zeros(d.shape)
311 |         for i in range(len(inds)):
312 |             inds[i] = numpy.argsort(d[i])[::-1]
313 |             ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0]
314 |             top1[5 * index + i] = inds[i][0]
315 | 
316 |     r1_ = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
317 |     r5_ = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
318 |     r10_ = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
319 |     medr_ = numpy.floor(numpy.median(ranks)) + 1
320 |     meanr_ = ranks.mean() + 1
321 |     rs_ = r1_ + r5_ + r10_
322 |     print('t2i:r1: %.1f, r5: %.1f, r10: %.1f' % (r1_, r5_, r10_))
323 |     rsum = rs + rs_
324 |     print('rsum=%.1f' % rsum)
325 |     return [r1, r5, r10, r1_, r5_, r10_, rsum]
326 | 
327 | 
328 | def main():
329 |     parser = argparse.ArgumentParser()
330 |     parser.add_argument('--model', default='single_model', help='model name')
331 |     parser.add_argument('--fold', action='store_true', help='fold5')
332 |     parser.add_argument('--name', default='model_best', help='checkpoint name')
333 |     parser.add_argument('--data_path', default='data', help='data path')
334 |     parser.add_argument('--region_bbox_file', default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5', type=str, metavar='PATH',
335 |                         help='path to region features bbox file')
336 |     parser.add_argument('--feature_path', default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/', type=str, metavar='PATH',
337 |                         help='path to region features')
338 |     opt = parser.parse_args()
339 | 
340 |     evalrank('runs/' + opt.model + '/' + opt.name + ".pth.tar", data_path=opt.data_path, split="test", fold5=opt.fold, region_bbox_file=opt.region_bbox_file, feature_path=opt.feature_path)
341 | 
342 | if __name__ == '__main__':
343 |     main()
344 | 
--------------------------------------------------------------------------------
/figures/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/figures/model.jpg
--------------------------------------------------------------------------------
/flickr_sims/flickr_sims.txt:
--------------------------------------------------------------------------------
1 | Path to save similarity matrices during the inference stage of Flickr30K.
2 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
 1 | # -----------------------------------------------------------
 2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on
 3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives"
 4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
 5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
 6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
 7 | # Written by Keyu Wen, 2020
 8 | # ------------------------------------------------------------
 9 | 
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.init
13 | import torchvision.models as models
14 | from torch.autograd import Variable
15 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
16 | import torch.backends.cudnn as cudnn
17 | from torch.nn.utils.clip_grad import clip_grad_norm_
18 | import numpy as np
19 | from collections import OrderedDict
20 | import time
21 | from GAT import GATLayer
22 | import copy
23 | from resnet import resnet152
24 | import torchtext
25 | import pickle
26 | import os
27 | 
28 | 
29 | def l2norm(X, dim=-1, eps=1e-12):
30 |     """L2-normalize columns of X
31 |     """
32 |     norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
33 |     X = torch.div(X, norm)
34 |     return X
35 | 
36 | 
37 | class GATopt(object):
38 |     def __init__(self, hidden_size, num_layers):
39 |         self.hidden_size = hidden_size
40 |         self.num_layers = num_layers
41 |         self.num_attention_heads = 8
42 |         self.hidden_dropout_prob = 0.2
43 |         self.attention_probs_dropout_prob = 0.2
44 | 
45 | 
46 | class GAT(nn.Module):
47 |     def __init__(self, config_gat):
48 |         super(GAT, self).__init__()
49 |         layer = GATLayer(config_gat)
50 |         self.encoder = nn.ModuleList([copy.deepcopy(layer) for _ in range(config_gat.num_layers)])
51 | 
52 |     def forward(self, input_graph):
53 |         hidden_states = input_graph
54 |         for layer_module in self.encoder:
55 |             hidden_states = layer_module(hidden_states)
56 |         return hidden_states  # B, seq_len, D
57 | 
58 | 
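# --- Editor's annotation (smoke test only, not part of the original model).
# The GAT wrapper above just stacks num_layers relation-attention blocks and
# preserves the node-feature shape; hidden_size must divide evenly by the 8
# attention heads hard-coded in GATopt. The toy batch/node sizes are arbitrary.
def _gat_smoke_test():
    config = GATopt(hidden_size=768, num_layers=2)
    gat = GAT(config)
    nodes = torch.randn(4, 10, 768)  # (batch, graph nodes, feature dim)
    out = gat(nodes)
    assert out.shape == nodes.shape  # (B, seq_len, D) preserved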
 59 | class RcnnEncoder(nn.Module):
 60 |     def __init__(self, opt):
 61 |         super(RcnnEncoder, self).__init__()
 62 |         self.embed_size = opt.embed_size
 63 |         self.fc_image = nn.Linear(opt.img_dim, self.embed_size)
 64 |         self.init_weights()
 65 | 
 66 |     def init_weights(self):
 67 |         """Xavier initialization for the fully connected layer
 68 |         """
 69 |         r = np.sqrt(6.) / np.sqrt(self.fc_image.in_features +
 70 |                                   self.fc_image.out_features)
 71 |         self.fc_image.weight.data.uniform_(-r, r)
 72 |         self.fc_image.bias.data.fill_(0)
 73 | 
 74 |     def forward(self, images, img_pos):  # (b, 100, 2048) (b,100,1601+6)
 75 |         img_f = self.fc_image(images)
 76 |         return img_f  # (b,100,768)
 77 | 
 78 | 
 79 | # tutorials/09 - Image Captioning
 80 | class EncoderImageFull(nn.Module):
 81 | 
 82 |     def __init__(self, opt):
 83 |         """Load a pretrained ResNet-152 and replace the top fc layer."""
 84 |         super(EncoderImageFull, self).__init__()
 85 |         self.embed_size = opt.embed_size
 86 | 
 87 |         self.cnn = resnet152(pretrained=True)
 88 |         # self.fc = nn.Sequential(nn.Linear(2048, self.embed_size), nn.ReLU(), nn.Dropout(0.1))
 89 |         self.fc = nn.Linear(opt.img_dim, self.embed_size)
 90 |         if not opt.finetune:
 91 |             print('image-encoder-resnet no grad!')
 92 |             for param in self.cnn.parameters():
 93 |                 param.requires_grad = False
 94 |         else:
 95 |             print('image-encoder-resnet fine-tuning !')
 96 | 
 97 |         self.init_weights()
 98 | 
 99 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
100 | 
101 |     def load_state_dict(self, state_dict):
102 |         """
103 |         Handle the models saved before commit pytorch/vision@989d52a
104 |         """
105 |         if 'cnn.classifier.1.weight' in state_dict:
106 |             state_dict['cnn.classifier.0.weight'] = state_dict[
107 |                 'cnn.classifier.1.weight']
108 |             del state_dict['cnn.classifier.1.weight']
109 |             state_dict['cnn.classifier.0.bias'] = state_dict[
110 |                 'cnn.classifier.1.bias']
111 |             del state_dict['cnn.classifier.1.bias']
112 |             state_dict['cnn.classifier.3.weight'] = state_dict[
113 |                 'cnn.classifier.4.weight']
114 |             del state_dict['cnn.classifier.4.weight']
115 |             state_dict['cnn.classifier.3.bias'] = state_dict[
116 |                 'cnn.classifier.4.bias']
117 |             del state_dict['cnn.classifier.4.bias']
118 | 
119 |         super(EncoderImageFull, self).load_state_dict(state_dict)
120 | 
121 |     def init_weights(self):
122 |         """Xavier initialization for the fully connected layer
123 |         """
124 |         r = np.sqrt(6.) 
/ np.sqrt(self.fc.in_features + 125 | self.fc.out_features) 126 | self.fc.weight.data.uniform_(-r, r) 127 | self.fc.bias.data.fill_(0) 128 | 129 | def forward(self, images): 130 | features_orig = self.cnn(images) 131 | features_top = features_orig[-1] 132 | features = features_top.view(features_top.size(0), features_top.size(1), -1).transpose(2, 1) # b, 49, 2048 133 | features = self.fc(features) 134 | 135 | return features 136 | 137 | 138 | # tutorials/08 - Language Model 139 | # RNN Based Language Model 140 | class EncoderText(nn.Module): 141 | 142 | def __init__(self, opt): 143 | super(EncoderText, self).__init__() 144 | self.embed_size = opt.embed_size 145 | # word embedding 146 | self.embed = nn.Embedding(opt.vocab_size, opt.word_dim) 147 | # caption embedding 148 | self.rnn = nn.GRU(opt.word_dim, opt.embed_size, opt.num_layers, batch_first=True) 149 | vocab = pickle.load(open('vocab/'+opt.data_name+'_vocab.pkl', 'rb')) 150 | word2idx = vocab.word2idx 151 | # self.init_weights() 152 | self.init_weights('glove', word2idx, opt.word_dim) 153 | self.dropout = nn.Dropout(0.1) 154 | 155 | def init_weights(self, wemb_type, word2idx, word_dim): 156 | if wemb_type.lower() == 'random_init': 157 | nn.init.xavier_uniform_(self.embed.weight) 158 | else: 159 | # Load pretrained word embedding 160 | if 'fasttext' == wemb_type.lower(): 161 | wemb = torchtext.vocab.FastText() 162 | elif 'glove' == wemb_type.lower(): 163 | wemb = torchtext.vocab.GloVe() 164 | else: 165 | raise Exception('Unknown word embedding type: {}'.format(wemb_type)) 166 | assert wemb.vectors.shape[1] == word_dim 167 | 168 | # quick-and-dirty trick to improve word-hit rate 169 | missing_words = [] 170 | for word, idx in word2idx.items(): 171 | if word not in wemb.stoi: 172 | word = word.replace('-', '').replace('.', '').replace("'", '') 173 | if '/' in word: 174 | word = word.split('/')[0] 175 | if word in wemb.stoi: 176 | self.embed.weight.data[idx] = wemb.vectors[wemb.stoi[word]] 177 | else: 178 | missing_words.append(word) 179 | print('Words: {}/{} found in vocabulary; {} words missing'.format( 180 | len(word2idx) - len(missing_words), len(word2idx), len(missing_words))) 181 | 182 | def forward(self, x, lengths): 183 | # return out 184 | x = self.embed(x) 185 | x = self.dropout(x) 186 | 187 | packed = pack_padded_sequence(x, lengths, batch_first=True) 188 | 189 | # Forward propagate RNN 190 | out, _ = self.rnn(packed) 191 | 192 | # Reshape *final* output to (batch_size, hidden_size) 193 | padded = pad_packed_sequence(out, batch_first=True) 194 | cap_emb, cap_len = padded 195 | 196 | cap_emb = l2norm(cap_emb, dim=-1) 197 | cap_emb_mean = torch.mean(cap_emb, 1) 198 | cap_emb_mean = l2norm(cap_emb_mean) 199 | 200 | return cap_emb, cap_emb_mean 201 | 202 | 203 | class Fusion(nn.Module): 204 | def __init__(self, opt): 205 | super(Fusion, self).__init__() 206 | self.f_size = opt.embed_size 207 | self.gate0 = nn.Linear(self.f_size, self.f_size) 208 | self.gate1 = nn.Linear(self.f_size, self.f_size) 209 | 210 | self.fusion0 = nn.Linear(self.f_size, self.f_size) 211 | self.fusion1 = nn.Linear(self.f_size, self.f_size) 212 | 213 | def forward(self, vec1, vec2): 214 | features_1 = self.gate0(vec1) 215 | features_2 = self.gate1(vec2) 216 | t = torch.sigmoid(self.fusion0(features_1) + self.fusion1(features_2)) 217 | f = t * features_1 + (1 - t) * features_2 218 | return f 219 | 220 | 221 | class DSRAN(nn.Module): 222 | 223 | def __init__(self, opt): 224 | super(DSRAN, self).__init__() 225 | self.K = opt.K 226 | self.img_enc = 
EncoderImageFull(opt) 227 | self.rcnn_enc = RcnnEncoder(opt) 228 | self.txt_enc = EncoderText(opt) 229 | config_rcnn = GATopt(opt.embed_size, 1) 230 | config_img= GATopt(opt.embed_size, 1) 231 | config_cap= GATopt(opt.embed_size, 1) 232 | config_joint= GATopt(opt.embed_size, 1) 233 | # SSR 234 | self.gat_1 = GAT(config_rcnn) 235 | self.gat_2 = GAT(config_img) 236 | self.gat_cap = GAT(config_cap) 237 | # JSR 238 | self.gat_cat_1 = GAT(config_joint) 239 | if self.K == 2: 240 | self.gat_cat_2 = GAT(config_joint) 241 | self.fusion = Fusion(opt) 242 | elif self.K == 4: 243 | self.gat_cat_2 = GAT(config_joint) 244 | self.gat_cat_3 = GAT(config_joint) 245 | self.gat_cat_4 = GAT(config_joint) 246 | self.fusion = Fusion(opt) 247 | self.fusion2 = Fusion(opt) 248 | self.fusion3 = Fusion(opt) 249 | 250 | def forward(self, images, img_rcnn, img_pos, captions, lengths): 251 | img_emb_orig = self.gat_2(self.img_enc(images)) 252 | rcnn_emb = self.rcnn_enc(img_rcnn, img_pos) 253 | rcnn_emb = self.gat_1(rcnn_emb) 254 | img_cat = torch.cat((img_emb_orig, rcnn_emb), 1) 255 | img_cat_1 = self.gat_cat_1(img_cat) 256 | img_cat_1 = torch.mean(img_cat_1, dim=1) 257 | if self.K == 1: 258 | img_cat = img_cat_1 259 | elif self.K == 2: 260 | img_cat_2 = self.gat_cat_2(img_cat) 261 | img_cat_2 = torch.mean(img_cat_2, dim=1) 262 | img_cat = self.fusion(img_cat_1, img_cat_2) 263 | elif self.K == 4: 264 | img_cat_2 = self.gat_cat_2(img_cat) 265 | img_cat_2 = torch.mean(img_cat_2, dim=1) 266 | img_cat_3 = self.gat_cat_3(img_cat) 267 | img_cat_3 = torch.mean(img_cat_3, dim=1) 268 | img_cat_4 = self.gat_cat_4(img_cat) 269 | img_cat_4 = torch.mean(img_cat_4, dim=1) 270 | img_cat_1_1 = self.fusion(img_cat_1, img_cat_2) 271 | img_cat_1_2 = self.fusion2(img_cat_3, img_cat_4) 272 | img_cat = self.fusion3(img_cat_1_1, img_cat_1_2) 273 | img_emb = l2norm(img_cat) 274 | cap_emb, cap_emb_mean = self.txt_enc(captions, lengths) 275 | cap_gat = self.gat_cap(cap_emb) 276 | cap_embs = l2norm(torch.mean(cap_gat, dim=1)) 277 | 278 | return img_emb, cap_embs 279 | 280 | 281 | def cosine_sim(im, s): 282 | """Cosine similarity between all the image and sentence pairs 283 | """ 284 | return im.mm(s.t()) 285 | 286 | class ContrastiveLoss(nn.Module): 287 | """ 288 | Compute contrastive loss 289 | """ 290 | 291 | def __init__(self, margin=0): 292 | super(ContrastiveLoss, self).__init__() 293 | self.margin = margin 294 | self.sim = cosine_sim 295 | 296 | def forward(self, im, s): 297 | # compute image-sentence score matrix 298 | scores = self.sim(im, s) 299 | diagonal = scores.diag().view(im.size(0), 1) 300 | 301 | d1 = diagonal.expand_as(scores) 302 | d2 = diagonal.t().expand_as(scores) 303 | im_sn = scores - d1 304 | c_sn = scores - d2 305 | # compare every diagonal score to scores in its column 306 | # caption retrieval 307 | cost_s = (self.margin + scores - d1).clamp(min=0) 308 | # compare every diagonal score to scores in its row 309 | # image retrieval 310 | cost_im = (self.margin + scores - d2).clamp(min=0) 311 | # clear diagonals 312 | mask = torch.eye(scores.size(0)) > .5 313 | I = Variable(mask) 314 | if torch.cuda.is_available(): 315 | I = I.cuda() 316 | cost_s = cost_s.masked_fill_(I, 0) 317 | cost_im = cost_im.masked_fill_(I, 0) 318 | 319 | # keep the maximum violating negative for each query 320 | 321 | cost_s = cost_s.max(1)[0] 322 | cost_im = cost_im.max(0)[0] 323 | 324 | return cost_s.sum() + cost_im.sum() 325 | 326 | 327 | class VSE(object): 328 | """ 329 | rkiros/uvs model 330 | """ 331 | def __init__(self, opt): 332 | # 
tutorials/09 - Image Captioning 333 | # Build Models 334 | self.grad_clip = opt.grad_clip 335 | 336 | self.DSRAN = DSRAN(opt) 337 | if torch.cuda.is_available(): 338 | self.DSRAN.cuda() 339 | cudnn.benchmark = True 340 | # Loss and Optimizer 341 | self.criterion = ContrastiveLoss(margin=opt.margin) 342 | params = list(self.DSRAN.parameters()) 343 | 344 | self.params = params 345 | 346 | self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate) 347 | 348 | self.Eiters = 0 349 | 350 | def state_dict(self): 351 | state_dict = [self.DSRAN.state_dict()] 352 | return state_dict 353 | 354 | def load_state_dict(self, state_dict): 355 | self.DSRAN.load_state_dict(state_dict[0]) 356 | 357 | def train_start(self): 358 | """switch to train mode 359 | """ 360 | self.DSRAN.train() 361 | 362 | def val_start(self): 363 | """switch to evaluate mode 364 | """ 365 | self.DSRAN.eval() 366 | 367 | def forward_emb(self, images, captions, img_rcnn, img_pos, lengths, volatile=False): 368 | """Compute the image and caption embeddings 369 | """ 370 | # Set mini-batch dataset 371 | 372 | if torch.cuda.is_available(): 373 | images = images.cuda() 374 | captions = captions.cuda() 375 | img_rcnn = img_rcnn.cuda() 376 | img_pos = img_pos.cuda() 377 | 378 | img_emb, cap_emb = self.DSRAN(images, img_rcnn, img_pos, captions, lengths) 379 | return img_emb, cap_emb 380 | 381 | def forward_loss(self, img_emb, cap_emb, **kwargs): 382 | """Compute the loss given pairs of image and caption embeddings 383 | """ 384 | loss = self.criterion(img_emb, cap_emb) 385 | self.logger.update('Le', loss.data, img_emb.size(0)) 386 | return loss 387 | 388 | def train_emb(self, images, captions, img_rcnn, img_pos, lengths, ids=None, *args): 389 | """One training step given images and captions. 390 | """ 391 | self.Eiters += 1 392 | self.logger.update('Eit', self.Eiters) 393 | self.logger.update('lr', self.optimizer.param_groups[0]['lr']) 394 | 395 | # compute the embeddings 396 | img_emb, cap_emb = self.forward_emb(images, captions, img_rcnn, img_pos, lengths) 397 | # measure accuracy and record loss 398 | self.optimizer.zero_grad() 399 | loss = self.forward_loss(img_emb, cap_emb) 400 | 401 | # compute gradient and do SGD step 402 | loss.backward() 403 | if self.grad_clip > 0: 404 | clip_grad_norm_(self.params, self.grad_clip) 405 | self.optimizer.step() 406 | 407 | -------------------------------------------------------------------------------- /model_bert.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on 3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives" 4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching" 5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng 6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020 7 | # Writen by Keyu Wen & Linyang Li, 2020 8 | # ------------------------------------------------------------ 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.init 13 | from torch.autograd import Variable 14 | import torch.backends.cudnn as cudnn 15 | import torch.nn.functional as F 16 | import numpy as np 17 | from collections import OrderedDict 18 | import copy 19 | from resnet import resnet152 20 | from pytorch_pretrained_bert.modeling import BertModel 21 | from pytorch_pretrained_bert.optimization import BertAdam 22 | import time 23 | from GAT import 
GATLayer 24 | 25 | 26 | def l2norm(X): 27 | norm = torch.pow(X, 2).sum(dim=-1, keepdim=True).sqrt() 28 | X = torch.div(X, norm) 29 | return X 30 | 31 | 32 | class RcnnEncoder(nn.Module): 33 | def __init__(self, opt): 34 | super(RcnnEncoder, self).__init__() 35 | self.embed_size = opt.embed_size 36 | self.fc_image = nn.Sequential(nn.Linear(opt.img_dim, opt.img_dim), 37 | nn.ReLU(), 38 | nn.Linear(opt.img_dim, self.embed_size), 39 | nn.ReLU(), 40 | nn.Dropout(0.1)) 41 | self.fc_pos = nn.Sequential(nn.Linear(6 + 1601, self.embed_size), 42 | nn.ReLU(), 43 | nn.Dropout(0.1)) 44 | self.fc = nn.Linear(self.embed_size * 2, self.embed_size) 45 | 46 | def forward(self, images, img_pos): # (b, 100, 2048) (b,100,1601+6) 47 | img_f = self.fc_image(images) 48 | img_pe = self.fc_pos(img_pos) 49 | img_embs = img_f + img_pe 50 | return img_embs # (b,100,768) 51 | 52 | 53 | class ImageEncoder(nn.Module): 54 | 55 | def __init__(self, opt): 56 | super(ImageEncoder, self).__init__() 57 | self.embed_size = opt.embed_size 58 | self.cnn = resnet152(pretrained=True) 59 | self.fc = nn.Sequential(nn.Linear(opt.img_dim, opt.embed_size), nn.ReLU(), nn.Dropout(0.1)) 60 | if not opt.ft_res: 61 | print('image-encoder-resnet no grad!') 62 | for param in self.cnn.parameters(): 63 | param.requires_grad = False 64 | else: 65 | print('image-encoder-resnet fine-tuning !') 66 | 67 | # def load_state_dict(self, state_dict): 68 | # if 'cnn.classifier.1.weight' in state_dict: 69 | # state_dict['cnn.classifier.0.weight'] = state_dict[ 70 | # 'cnn.classifier.1.weight'] 71 | # del state_dict['cnn.classifier.1.weight'] 72 | # state_dict['cnn.classifier.0.bias'] = state_dict[ 73 | # 'cnn.classifier.1.bias'] 74 | # del state_dict['cnn.classifier.1.bias'] 75 | # state_dict['cnn.classifier.3.weight'] = state_dict[ 76 | # 'cnn.classifier.4.weight'] 77 | # del state_dict['cnn.classifier.4.weight'] 78 | # state_dict['cnn.classifier.3.bias'] = state_dict[ 79 | # 'cnn.classifier.4.bias'] 80 | # del state_dict['cnn.classifier.4.bias'] 81 | 82 | # super(ImageEncoder, self).load_state_dict(state_dict) 83 | 84 | def forward(self, images): 85 | features_orig = self.cnn(images) 86 | features_top = features_orig[-1] 87 | features = features_top.view(features_top.size(0), features_top.size(1), -1).transpose(2, 1) # b, 49, 2048 88 | features = self.fc(features) 89 | return features 90 | 91 | 92 | class TextEncoder(nn.Module): 93 | def __init__(self, opt): 94 | super(TextEncoder, self).__init__() 95 | self.bert = BertModel.from_pretrained(opt.bert_path) 96 | if not opt.ft_bert: 97 | for param in self.bert.parameters(): 98 | param.requires_grad = False 99 | print('text-encoder-bert no grad') 100 | else: 101 | print('text-encoder-bert fine-tuning !') 102 | self.embed_size = opt.embed_size 103 | self.fc = nn.Sequential(nn.Linear(opt.bert_size, opt.embed_size), nn.ReLU(), nn.Dropout(0.1)) 104 | 105 | def forward(self, captions): 106 | all_encoders, pooled = self.bert(captions) 107 | out = all_encoders[-1] 108 | out = self.fc(out) 109 | return out 110 | 111 | 112 | class GATopt(object): 113 | def __init__(self, hidden_size, num_layers): 114 | self.hidden_size = hidden_size 115 | self.num_layers = num_layers 116 | self.num_attention_heads = 8 117 | self.hidden_dropout_prob = 0.2 118 | self.attention_probs_dropout_prob = 0.2 119 | 120 | 121 | class GAT(nn.Module): 122 | def __init__(self, config_gat): 123 | super(GAT, self).__init__() 124 | layer = GATLayer(config_gat) 125 | self.encoder = nn.ModuleList([copy.deepcopy(layer) for _ in 
range(config_gat.num_layers)]) 126 | 127 | def forward(self, input_graph): 128 | hidden_states = input_graph 129 | for layer_module in self.encoder: 130 | hidden_states = layer_module(hidden_states) 131 | return hidden_states # B, seq_len, D 132 | 133 | 134 | def cosine_sim(im, s): 135 | return im.mm(s.t()) 136 | 137 | 138 | class ContrastiveLoss(nn.Module): 139 | def __init__(self, margin=0): 140 | super(ContrastiveLoss, self).__init__() 141 | self.margin = margin 142 | self.sim = cosine_sim 143 | 144 | def forward(self, im, s): 145 | scores = self.sim(im, s) 146 | diagonal = scores.diag().view(im.size(0), 1) 147 | 148 | d1 = diagonal.expand_as(scores) 149 | d2 = diagonal.t().expand_as(scores) 150 | im_sn = scores - d1 151 | c_sn = scores - d2 152 | cost_s = (self.margin + scores - d1).clamp(min=0) 153 | 154 | cost_im = (self.margin + scores - d2).clamp(min=0) 155 | 156 | mask = torch.eye(scores.size(0)) > .5 157 | I = Variable(mask) 158 | if torch.cuda.is_available(): 159 | I = I.cuda() 160 | cost_s = cost_s.masked_fill_(I, 0) 161 | cost_im = cost_im.masked_fill_(I, 0) 162 | 163 | cost_s = cost_s.max(1)[0] 164 | cost_im = cost_im.max(0)[0] 165 | return cost_s.sum() + cost_im.sum() 166 | 167 | 168 | def get_optimizer(params, opt, t_total=-1): 169 | bertadam = BertAdam(params, lr=opt.learning_rate, warmup=opt.warmup, t_total=t_total) 170 | return bertadam 171 | 172 | 173 | class Fusion(nn.Module): 174 | def __init__(self, opt): 175 | super(Fusion, self).__init__() 176 | self.f_size = opt.embed_size 177 | self.gate0 = nn.Linear(self.f_size, self.f_size) 178 | self.gate1 = nn.Linear(self.f_size, self.f_size) 179 | 180 | self.fusion0 = nn.Linear(self.f_size, self.f_size) 181 | self.fusion1 = nn.Linear(self.f_size, self.f_size) 182 | 183 | def forward(self, vec1, vec2): 184 | features_1 = self.gate0(vec1) 185 | features_2 = self.gate1(vec2) 186 | t = torch.sigmoid(self.fusion0(features_1) + self.fusion1(features_2)) 187 | f = t * features_1 + (1 - t) * features_2 188 | return f 189 | 190 | 191 | class DSRAN(nn.Module): 192 | def __init__(self, opt): 193 | super(DSRAN, self).__init__() 194 | self.img_enc = ImageEncoder(opt) 195 | self.txt_enc = TextEncoder(opt) 196 | self.rcnn_enc = RcnnEncoder(opt) 197 | 198 | config_img = GATopt(opt.embed_size, 1) 199 | config_cap = GATopt(opt.embed_size, 1) 200 | config_rcnn = GATopt(opt.embed_size, 1) 201 | config_joint = GATopt(opt.embed_size, 1) 202 | 203 | self.K = opt.K 204 | # SSR 205 | self.gat_1 = GAT(config_img) 206 | self.gat_2 = GAT(config_rcnn) 207 | self.gat_cap = GAT(config_cap) 208 | # JSR 209 | self.gat_cat = GAT(config_joint) 210 | if self.K == 2: 211 | self.gat_cat_1 = GAT(config_joint) 212 | self.fusion = Fusion(opt) 213 | elif self.K == 4: 214 | self.gat_cat_1 = GAT(config_joint) 215 | self.gat_cat_2 = GAT(config_joint) 216 | self.gat_cat_3 = GAT(config_joint) 217 | 218 | self.fusion = Fusion(opt) 219 | self.fusion_1 = Fusion(opt) 220 | self.fusion_2 = Fusion(opt) 221 | 222 | def forward(self, images_orig, rcnn_fe, img_pos, captions): 223 | 224 | img_emb_orig = self.gat_1(self.img_enc(images_orig)) 225 | rcnn_emb = self.rcnn_enc(rcnn_fe, img_pos) 226 | rcnn_emb = self.gat_2(rcnn_emb) 227 | img_cat = torch.cat((img_emb_orig, rcnn_emb), 1) 228 | img_cat_1 = self.gat_cat(img_cat) 229 | img_cat_1 = torch.mean(img_cat_1, dim=1) 230 | if self.K == 1: 231 | img_cat = img_cat_1 232 | elif self.K == 2: 233 | img_cat_2 = self.gat_cat_1(img_cat) 234 | img_cat_2 = torch.mean(img_cat_2, dim=1) 235 | img_cat = self.fusion(img_cat_1, img_cat_2) 236 | 
elif self.K == 4: 237 | img_cat_2 = self.gat_cat_1(img_cat) 238 | img_cat_2 = torch.mean(img_cat_2, dim=1) 239 | img_cat_3 = self.gat_cat_2(img_cat) 240 | img_cat_3 = torch.mean(img_cat_3, dim=1) 241 | img_cat_4 = self.gat_cat_3(img_cat) 242 | img_cat_4 = torch.mean(img_cat_4, dim=1) 243 | img_cat_1_1 = self.fusion_1(img_cat_1, img_cat_2) 244 | img_cat_1_2 = self.fusion_2(img_cat_3, img_cat_4) 245 | img_cat = self.fusion(img_cat_1_1, img_cat_1_2) 246 | img_emb = l2norm(img_cat) 247 | cap_emb = self.txt_enc(captions) 248 | cap_gat = self.gat_cap(cap_emb) 249 | cap_embs = l2norm(torch.mean(cap_gat, dim=1)) 250 | 251 | return img_emb, cap_embs 252 | 253 | 254 | class VSE(object): 255 | 256 | def __init__(self, opt): 257 | self.DSRAN = DSRAN(opt) 258 | self.DSRAN = nn.DataParallel(self.DSRAN) 259 | if torch.cuda.is_available(): 260 | self.DSRAN.cuda() 261 | cudnn.benchmark = True 262 | self.criterion = ContrastiveLoss(margin=opt.margin) 263 | params = list(self.DSRAN.named_parameters()) 264 | param_optimizer = params 265 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 266 | optimizer_grouped_parameters = [ 267 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0}, 268 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 269 | ] 270 | t_total = opt.l_train * opt.num_epochs 271 | if opt.warmup == -1: 272 | t_total = -1 273 | self.optimizer = get_optimizer(params=optimizer_grouped_parameters, opt=opt, t_total=t_total) 274 | self.Eiters = 0 275 | 276 | def state_dict(self): 277 | state_dict = self.DSRAN.state_dict() 278 | return state_dict 279 | 280 | def load_state_dict(self, state_dict): 281 | self.DSRAN.load_state_dict(state_dict) 282 | 283 | def train_start(self): 284 | self.DSRAN.train() 285 | 286 | def val_start(self): 287 | self.DSRAN.eval() 288 | 289 | def forward_emb(self, images_orig, rcnn_fe, img_pos, captions): 290 | if torch.cuda.is_available(): 291 | images_orig = images_orig.cuda() 292 | rcnn_fe = rcnn_fe.cuda() 293 | img_pos = img_pos.cuda() 294 | captions = captions.cuda() 295 | 296 | img_emb, cap_emb = self.DSRAN(images_orig, rcnn_fe, img_pos, captions) 297 | 298 | return img_emb, cap_emb 299 | 300 | def forward_loss(self, img_emb, cap_emb, **kwargs): 301 | loss = self.criterion(img_emb, cap_emb) 302 | self.logger.update('Le', loss.data, img_emb.size(0)) 303 | return loss 304 | 305 | def train_emb(self, images, images_orig, img_pos, captions, ids=None, *args): 306 | self.Eiters += 1 307 | self.logger.update('Eit', self.Eiters) 308 | self.logger.update('lr', self.optimizer.param_groups[0]['lr']) 309 | 310 | img_emb, cap_emb = self.forward_emb(images_orig, images, img_pos, captions) 311 | 312 | self.optimizer.zero_grad() 313 | loss = self.forward_loss(img_emb, cap_emb) 314 | 315 | loss.backward() 316 | self.optimizer.step() 317 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/pytorch_pretrained_bert/.DS_Store -------------------------------------------------------------------------------- /pytorch_pretrained_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import tempfile 13 | from functools import wraps 14 | from hashlib import sha256 15 | import sys 16 | from io import open 17 | 18 | import boto3 19 | import requests 20 | from botocore.exceptions import ClientError 21 | from tqdm import tqdm 22 | 23 | try: 24 | from urllib.parse import urlparse 25 | except ImportError: 26 | from urlparse import urlparse 27 | 28 | try: 29 | from pathlib import Path 30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 31 | Path.home() / '.pytorch_pretrained_bert')) 32 | except (AttributeError, ImportError): 33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 35 | 36 | CONFIG_NAME = "config.json" 37 | WEIGHTS_NAME = "pytorch_model.bin" 38 | 39 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 40 | 41 | 42 | def url_to_filename(url, etag=None): 43 | """ 44 | Convert `url` into a hashed filename in a repeatable way. 45 | If `etag` is specified, append its hash to the url's, delimited 46 | by a period. 47 | """ 48 | url_bytes = url.encode('utf-8') 49 | url_hash = sha256(url_bytes) 50 | filename = url_hash.hexdigest() 51 | 52 | if etag: 53 | etag_bytes = etag.encode('utf-8') 54 | etag_hash = sha256(etag_bytes) 55 | filename += '.' + etag_hash.hexdigest() 56 | 57 | return filename 58 | 59 | 60 | def filename_to_url(filename, cache_dir=None): 61 | """ 62 | Return the url and etag (which may be ``None``) stored for `filename`. 63 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 64 | """ 65 | if cache_dir is None: 66 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 67 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 68 | cache_dir = str(cache_dir) 69 | 70 | cache_path = os.path.join(cache_dir, filename) 71 | if not os.path.exists(cache_path): 72 | raise EnvironmentError("file {} not found".format(cache_path)) 73 | 74 | meta_path = cache_path + '.json' 75 | if not os.path.exists(meta_path): 76 | raise EnvironmentError("file {} not found".format(meta_path)) 77 | 78 | with open(meta_path, encoding="utf-8") as meta_file: 79 | metadata = json.load(meta_file) 80 | url = metadata['url'] 81 | etag = metadata['etag'] 82 | 83 | return url, etag 84 | 85 | 86 | def cached_path(url_or_filename, cache_dir=None): 87 | """ 88 | Given something that might be a URL (or might be a local path), 89 | determine which. If it's a URL, download the file and cache it, and 90 | return the path to the cached file. If it's already a local path, 91 | make sure the file exists and then return the path. 92 | """ 93 | if cache_dir is None: 94 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 95 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 96 | url_or_filename = str(url_or_filename) 97 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 98 | cache_dir = str(cache_dir) 99 | 100 | parsed = urlparse(url_or_filename) 101 | 102 | if parsed.scheme in ('http', 'https', 's3'): 103 | # URL, so get it from the cache (downloading if necessary) 104 | return get_from_cache(url_or_filename, cache_dir) 105 | elif os.path.exists(url_or_filename): 106 | # File, and it exists. 
107 | return url_or_filename 108 | elif parsed.scheme == '': 109 | # File, but it doesn't exist. 110 | raise EnvironmentError("file {} not found".format(url_or_filename)) 111 | else: 112 | # Something unknown 113 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 114 | 115 | 116 | def split_s3_path(url): 117 | """Split a full s3 path into the bucket name and path.""" 118 | parsed = urlparse(url) 119 | if not parsed.netloc or not parsed.path: 120 | raise ValueError("bad s3 path {}".format(url)) 121 | bucket_name = parsed.netloc 122 | s3_path = parsed.path 123 | # Remove '/' at beginning of path. 124 | if s3_path.startswith("/"): 125 | s3_path = s3_path[1:] 126 | return bucket_name, s3_path 127 | 128 | 129 | def s3_request(func): 130 | """ 131 | Wrapper function for s3 requests in order to create more helpful error 132 | messages. 133 | """ 134 | 135 | @wraps(func) 136 | def wrapper(url, *args, **kwargs): 137 | try: 138 | return func(url, *args, **kwargs) 139 | except ClientError as exc: 140 | if int(exc.response["Error"]["Code"]) == 404: 141 | raise EnvironmentError("file {} not found".format(url)) 142 | else: 143 | raise 144 | 145 | return wrapper 146 | 147 | 148 | @s3_request 149 | def s3_etag(url): 150 | """Check ETag on S3 object.""" 151 | s3_resource = boto3.resource("s3") 152 | bucket_name, s3_path = split_s3_path(url) 153 | s3_object = s3_resource.Object(bucket_name, s3_path) 154 | return s3_object.e_tag 155 | 156 | 157 | @s3_request 158 | def s3_get(url, temp_file): 159 | """Pull a file directly from S3.""" 160 | s3_resource = boto3.resource("s3") 161 | bucket_name, s3_path = split_s3_path(url) 162 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 163 | 164 | 165 | def http_get(url, temp_file): 166 | req = requests.get(url, stream=True) 167 | content_length = req.headers.get('Content-Length') 168 | total = int(content_length) if content_length is not None else None 169 | progress = tqdm(unit="B", total=total) 170 | for chunk in req.iter_content(chunk_size=1024): 171 | if chunk: # filter out keep-alive new chunks 172 | progress.update(len(chunk)) 173 | temp_file.write(chunk) 174 | progress.close() 175 | 176 | 177 | def get_from_cache(url, cache_dir=None): 178 | """ 179 | Given a URL, look for the corresponding dataset in the local cache. 180 | If it's not there, download it. Then return the path to the cached file. 181 | """ 182 | if cache_dir is None: 183 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 184 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 185 | cache_dir = str(cache_dir) 186 | 187 | if not os.path.exists(cache_dir): 188 | os.makedirs(cache_dir) 189 | 190 | # Get eTag to add to filename, if it exists. 191 | if url.startswith("s3://"): 192 | etag = s3_etag(url) 193 | else: 194 | response = requests.head(url, allow_redirects=True) 195 | if response.status_code != 200: 196 | raise IOError("HEAD request failed for url {} with status code {}" 197 | .format(url, response.status_code)) 198 | etag = response.headers.get("ETag") 199 | 200 | filename = url_to_filename(url, etag) 201 | 202 | # get cache path to put the file 203 | cache_path = os.path.join(cache_dir, filename) 204 | 205 | if not os.path.exists(cache_path): 206 | # Download to temporary file, then copy to cache dir once finished. 207 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
208 | with tempfile.NamedTemporaryFile() as temp_file: 209 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 210 | 211 | # GET file object 212 | if url.startswith("s3://"): 213 | s3_get(url, temp_file) 214 | else: 215 | http_get(url, temp_file) 216 | 217 | # we are copying the file before closing it, so flush to avoid truncation 218 | temp_file.flush() 219 | # shutil.copyfileobj() starts at the current position, so go to the start 220 | temp_file.seek(0) 221 | 222 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 223 | with open(cache_path, 'wb') as cache_file: 224 | shutil.copyfileobj(temp_file, cache_file) 225 | 226 | logger.info("creating metadata file for %s", cache_path) 227 | meta = {'url': url, 'etag': etag} 228 | meta_path = cache_path + '.json' 229 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 230 | json.dump(meta, meta_file) 231 | 232 | logger.info("removing temp file %s", temp_file.name) 233 | 234 | return cache_path 235 | 236 | 237 | def read_set_from_file(filename): 238 | ''' 239 | Extract a de-duped collection (set) of text from a file. 240 | Expected file format is one item per line. 241 | ''' 242 | collection = set() 243 | with open(filename, 'r', encoding='utf-8') as file_: 244 | for line in file_: 245 | collection.add(line.rstrip()) 246 | return collection 247 | 248 | 249 | def get_file_extension(path, dot=True, lower=True): 250 | ext = os.path.splitext(path)[1] 251 | ext = ext if dot else ext[1:] 252 | return ext.lower() if lower else ext 253 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | x_ = (x - warmup) / (1 - warmup) # progress after warmup - 30 | return 0.5 * (1. + math.cos(math.pi * x_)) 31 | 32 | def warmup_constant(x, warmup=0.002): 33 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 34 | Learning rate is 1. afterwards. """ 35 | if x < warmup: 36 | return x/warmup 37 | return 1.0 38 | 39 | def warmup_linear(x, warmup=0.002): 40 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 41 | After `t_total`-th training step, learning rate is zero. 
""" 42 | if x < warmup: 43 | return x/warmup 44 | return max((x-1.)/(warmup-1.), 0) 45 | 46 | SCHEDULES = { 47 | 'warmup_cosine': warmup_cosine, 48 | 'warmup_constant': warmup_constant, 49 | 'warmup_linear': warmup_linear, 50 | } 51 | 52 | 53 | class BertAdam(Optimizer): 54 | """Implements BERT version of Adam algorithm with weight decay fix. 55 | Params: 56 | lr: learning rate 57 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 58 | t_total: total number of training steps for the learning 59 | rate schedule, -1 means constant learning rate. Default: -1 60 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 61 | b1: Adams b1. Default: 0.9 62 | b2: Adams b2. Default: 0.999 63 | e: Adams epsilon. Default: 1e-6 64 | weight_decay: Weight decay. Default: 0.01 65 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 66 | """ 67 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 68 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 69 | max_grad_norm=1.0): 70 | if lr is not required and lr < 0.0: 71 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 72 | if schedule not in SCHEDULES: 73 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 74 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 75 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 76 | if not 0.0 <= b1 < 1.0: 77 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 78 | if not 0.0 <= b2 < 1.0: 79 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 80 | if not e >= 0.0: 81 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 82 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 83 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 84 | max_grad_norm=max_grad_norm) 85 | super(BertAdam, self).__init__(params, defaults) 86 | 87 | def get_lr(self): 88 | lr = [] 89 | for group in self.param_groups: 90 | for p in group['params']: 91 | state = self.state[p] 92 | if len(state) == 0: 93 | return [0] 94 | if group['t_total'] != -1: 95 | schedule_fct = SCHEDULES[group['schedule']] 96 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 97 | else: 98 | lr_scheduled = group['lr'] 99 | lr.append(lr_scheduled) 100 | return lr 101 | 102 | def step(self, closure=None): 103 | """Performs a single optimization step. 104 | 105 | Arguments: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 
108 | """ 109 | loss = None 110 | if closure is not None: 111 | loss = closure() 112 | 113 | warned_for_t_total = False 114 | 115 | for group in self.param_groups: 116 | for p in group['params']: 117 | if p.grad is None: 118 | continue 119 | grad = p.grad.data 120 | if grad.is_sparse: 121 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 122 | 123 | state = self.state[p] 124 | 125 | # State initialization 126 | if len(state) == 0: 127 | state['step'] = 0 128 | # Exponential moving average of gradient values 129 | state['next_m'] = torch.zeros_like(p.data) 130 | # Exponential moving average of squared gradient values 131 | state['next_v'] = torch.zeros_like(p.data) 132 | 133 | next_m, next_v = state['next_m'], state['next_v'] 134 | beta1, beta2 = group['b1'], group['b2'] 135 | 136 | # Add grad clipping 137 | if group['max_grad_norm'] > 0: 138 | clip_grad_norm_(p, group['max_grad_norm']) 139 | 140 | # Decay the first and second moment running average coefficient 141 | # In-place operations to update the averages at the same time 142 | next_m.mul_(beta1).add_(1 - beta1, grad) 143 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 144 | update = next_m / (next_v.sqrt() + group['e']) 145 | 146 | # Just adding the square of the weights to the loss function is *not* 147 | # the correct way of using L2 regularization/weight decay with Adam, 148 | # since that will interact with the m and v parameters in strange ways. 149 | # 150 | # Instead we want to decay the weights in a manner that doesn't interact 151 | # with the m/v parameters. This is equivalent to adding the square 152 | # of the weights to the loss with plain (non-momentum) SGD. 153 | if group['weight_decay'] > 0.0: 154 | update += group['weight_decay'] * p.data 155 | 156 | if group['t_total'] != -1: 157 | schedule_fct = SCHEDULES[group['schedule']] 158 | progress = state['step']/group['t_total'] 159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 160 | # warning for exceeding t_total (only active with warmup_linear 161 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: 162 | logger.warning( 163 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. " 164 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__)) 165 | warned_for_t_total = True 166 | # end warning 167 | else: 168 | lr_scheduled = group['lr'] 169 | 170 | update_with_lr = lr_scheduled * update 171 | p.data.add_(-update_with_lr) 172 | 173 | state['step'] += 1 174 | 175 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 176 | # No bias correction 177 | # bias_correction1 = 1 - beta1 ** state['step'] 178 | # bias_correction2 = 1 - beta2 ** state['step'] 179 | 180 | return loss 181 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .file_utils import cached_path 26 | import json 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 37 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'bert-base-uncased': 512, 41 | 'bert-large-uncased': 512, 42 | 'bert-base-cased': 512, 43 | 'bert-large-cased': 512, 44 | 'bert-base-multilingual-uncased': 512, 45 | 'bert-base-multilingual-cased': 512, 46 | 'bert-base-chinese': 512, 47 | } 48 | VOCAB_NAME = 'vocab.txt' 49 | 50 | 51 | def load_vocab(vocab_file): 52 | """Loads a vocabulary file into a dictionary.""" 53 | vocab = collections.OrderedDict() 54 | index = 0 55 | with open(vocab_file, "r", encoding="utf-8") as reader: 56 | while True: 57 | token = reader.readline() 58 | if not token: 59 | break 60 | token = token.strip() 61 | vocab[token] = index 62 | index += 1 63 | return vocab 64 | 65 | 66 | def whitespace_tokenize(text): 67 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 68 | text = text.strip() 69 | if not text: 70 | return [] 71 | tokens = text.split() 72 | return tokens 73 | 74 | 75 | class BertTokenizer(object): 76 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 77 | 78 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, 79 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 80 | """Constructs a BertTokenizer. 81 | 82 | Args: 83 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 84 | do_lower_case: Whether to lower case the input 85 | Only has an effect when do_wordpiece_only=False 86 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 87 | max_len: An artificial maximum length to truncate tokenized sequences to; 88 | Effective maximum length is always the minimum of this 89 | value (if specified) and the underlying BERT model's 90 | sequence length. 91 | never_split: List of tokens which will never be split during tokenization. 
92 | Only has an effect when do_wordpiece_only=False 93 | """ 94 | if not os.path.isfile(vocab_file): 95 | raise ValueError( 96 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 97 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 98 | self.vocab = load_vocab(vocab_file) 99 | 100 | self.ids_to_tokens = collections.OrderedDict( 101 | [(ids, tok) for tok, ids in self.vocab.items()]) 102 | self.do_basic_tokenize = do_basic_tokenize 103 | if do_basic_tokenize: 104 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 105 | never_split=never_split) 106 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 107 | # print('core_vocab loaded_ (norm for squad gen)') 108 | self.max_len = max_len if max_len is not None else int(1e12) 109 | 110 | def tokenize(self, text): 111 | split_tokens = [] 112 | if self.do_basic_tokenize: 113 | for token in self.basic_tokenizer.tokenize(text): 114 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 115 | split_tokens.append(sub_token) 116 | else: 117 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 118 | return split_tokens 119 | 120 | def convert_tokens_to_ids(self, tokens): 121 | """Converts a sequence of tokens into ids using the vocab.""" 122 | ids = [] 123 | for token in tokens: 124 | ids.append(self.vocab[token]) 125 | if len(ids) > self.max_len: 126 | logger.warning( 127 | "Token indices sequence length is longer than the specified maximum " 128 | " sequence length for this BERT model ({} > {}). Running this" 129 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 130 | ) 131 | return ids 132 | 133 | def convert_ids_to_tokens(self, ids): 134 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 135 | tokens = [] 136 | for i in ids: 137 | tokens.append(self.ids_to_tokens[i]) 138 | return tokens 139 | 140 | def save_vocabulary(self, vocab_path): 141 | """Save the tokenizer vocabulary to a directory or file.""" 142 | index = 0 143 | if os.path.isdir(vocab_path): 144 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 145 | with open(vocab_file, "w", encoding="utf-8") as writer: 146 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 147 | if index != token_index: 148 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 149 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 150 | index = token_index 151 | writer.write(token + u'\n') 152 | index += 1 153 | return vocab_file 154 | 155 | @classmethod 156 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 157 | """ 158 | Instantiate a PreTrainedBertModel from a pre-trained model file. 159 | Download and cache the pre-trained model file if needed. 160 | """ 161 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 162 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 163 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): 164 | logger.warning("The pre-trained model you are loading is a cased model but you have not set " 165 | "`do_lower_case` to False. 
We are setting `do_lower_case=False` for you but " 166 | "you may want to check this behavior.") 167 | kwargs['do_lower_case'] = False 168 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): 169 | logger.warning("The pre-trained model you are loading is an uncased model but you have set " 170 | "`do_lower_case` to False. We are setting `do_lower_case=True` for you " 171 | "but you may want to check this behavior.") 172 | kwargs['do_lower_case'] = True 173 | else: 174 | vocab_file = pretrained_model_name_or_path 175 | if os.path.isdir(vocab_file): 176 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 177 | # redirect to the cache, if necessary 178 | try: 179 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 180 | except EnvironmentError: 181 | logger.error( 182 | "Model name '{}' was not found in model name list ({}). " 183 | "We assumed '{}' was a path or url but couldn't find any file " 184 | "associated to this path or url.".format( 185 | pretrained_model_name_or_path, 186 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 187 | vocab_file)) 188 | return None 189 | if resolved_vocab_file == vocab_file: 190 | logger.info("loading vocabulary file {}".format(vocab_file)) 191 | else: 192 | logger.info("loading vocabulary file {} from cache at {}".format( 193 | vocab_file, resolved_vocab_file)) 194 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 195 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 196 | # than the number of positional embeddings 197 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 198 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 199 | # Instantiate tokenizer. 200 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 201 | return tokenizer 202 | 203 | 204 | class BasicTokenizer(object): 205 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 206 | 207 | def __init__(self, 208 | do_lower_case=True, 209 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 210 | """Constructs a BasicTokenizer. 211 | 212 | Args: 213 | do_lower_case: Whether to lower case the input. 214 | """ 215 | self.do_lower_case = do_lower_case 216 | self.never_split = never_split 217 | 218 | def tokenize(self, text): 219 | """Tokenizes a piece of text.""" 220 | text = self._clean_text(text) 221 | # This was added on November 1st, 2018 for the multilingual and Chinese 222 | # models. This is also applied to the English models now, but it doesn't 223 | # matter since the English models were not trained on any Chinese data 224 | # and generally don't have any Chinese data in them (there are Chinese 225 | # characters in the vocabulary because Wikipedia does have some Chinese 226 | # words in the English Wikipedia.). 
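        # Editor's note (illustrative, not in the original source): on input
        # "Hello, WORLD!" with do_lower_case=True this method returns
        # ["hello", ",", "world", "!"] -- cleaning, case folding, accent
        # stripping and punctuation splitting all happen here, before
        # WordPiece runs on each resulting token.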
227 | text = self._tokenize_chinese_chars(text) 228 | orig_tokens = whitespace_tokenize(text) 229 | split_tokens = [] 230 | for token in orig_tokens: 231 | if self.do_lower_case and token not in self.never_split: 232 | token = token.lower() 233 | token = self._run_strip_accents(token) 234 | split_tokens.extend(self._run_split_on_punc(token)) 235 | 236 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 237 | return output_tokens 238 | 239 | def _run_strip_accents(self, text): 240 | """Strips accents from a piece of text.""" 241 | text = unicodedata.normalize("NFD", text) 242 | output = [] 243 | for char in text: 244 | cat = unicodedata.category(char) 245 | if cat == "Mn": 246 | continue 247 | output.append(char) 248 | return "".join(output) 249 | 250 | def _run_split_on_punc(self, text): 251 | """Splits punctuation on a piece of text.""" 252 | if text in self.never_split: 253 | return [text] 254 | chars = list(text) 255 | i = 0 256 | start_new_word = True 257 | output = [] 258 | while i < len(chars): 259 | char = chars[i] 260 | if _is_punctuation(char): 261 | output.append([char]) 262 | start_new_word = True 263 | else: 264 | if start_new_word: 265 | output.append([]) 266 | start_new_word = False 267 | output[-1].append(char) 268 | i += 1 269 | 270 | return ["".join(x) for x in output] 271 | 272 | def _tokenize_chinese_chars(self, text): 273 | """Adds whitespace around any CJK character.""" 274 | output = [] 275 | for char in text: 276 | cp = ord(char) 277 | if self._is_chinese_char(cp): 278 | output.append(" ") 279 | output.append(char) 280 | output.append(" ") 281 | else: 282 | output.append(char) 283 | return "".join(output) 284 | 285 | def _is_chinese_char(self, cp): 286 | """Checks whether CP is the codepoint of a CJK character.""" 287 | # This defines a "chinese character" as anything in the CJK Unicode block: 288 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 289 | # 290 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 291 | # despite its name. The modern Korean Hangul alphabet is a different block, 292 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 293 | # space-separated words, so they are not treated specially and handled 294 | # like the all of the other languages. 295 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 296 | (cp >= 0x3400 and cp <= 0x4DBF) or # 297 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 298 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 299 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 300 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 301 | (cp >= 0xF900 and cp <= 0xFAFF) or # 302 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 303 | return True 304 | 305 | return False 306 | 307 | def _clean_text(self, text): 308 | """Performs invalid character removal and whitespace cleanup on text.""" 309 | output = [] 310 | for char in text: 311 | cp = ord(char) 312 | if cp == 0 or cp == 0xfffd or _is_control(char): 313 | continue 314 | if _is_whitespace(char): 315 | output.append(" ") 316 | else: 317 | output.append(char) 318 | return "".join(output) 319 | 320 | 321 | class WordpieceTokenizer(object): 322 | """Runs WordPiece tokenization.""" 323 | 324 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 325 | self.vocab = vocab 326 | self.unk_token = unk_token 327 | self.max_input_chars_per_word = max_input_chars_per_word 328 | 329 | def tokenize(self, text): 330 | """Tokenizes a piece of text into its word pieces. 
331 | 332 | This uses a greedy longest-match-first algorithm to perform tokenization 333 | using the given vocabulary. 334 | 335 | For example: 336 | input = "unaffable" 337 | output = ["un", "##aff", "##able"] 338 | 339 | Args: 340 | text: A single token or whitespace separated tokens. This should have 341 | already been passed through `BasicTokenizer`. 342 | 343 | Returns: 344 | A list of wordpiece tokens. 345 | """ 346 | 347 | output_tokens = [] 348 | for token in whitespace_tokenize(text): 349 | chars = list(token) 350 | if len(chars) > self.max_input_chars_per_word: 351 | output_tokens.append(self.unk_token) 352 | continue 353 | 354 | is_bad = False 355 | start = 0 356 | sub_tokens = [] 357 | while start < len(chars): 358 | end = len(chars) 359 | cur_substr = None 360 | while start < end: 361 | substr = "".join(chars[start:end]) 362 | if start > 0: 363 | substr = "##" + substr 364 | if substr in self.vocab: 365 | cur_substr = substr 366 | break 367 | end -= 1 368 | if cur_substr is None: 369 | is_bad = True 370 | break 371 | sub_tokens.append(cur_substr) 372 | start = end 373 | 374 | if is_bad: 375 | output_tokens.append(self.unk_token) 376 | else: 377 | output_tokens.extend(sub_tokens) 378 | return output_tokens 379 | 380 | 381 | def _is_whitespace(char): 382 | """Checks whether `chars` is a whitespace character.""" 383 | # \t, \n, and \r are technically contorl characters but we treat them 384 | # as whitespace since they are generally considered as such. 385 | if char == " " or char == "\t" or char == "\n" or char == "\r": 386 | return True 387 | cat = unicodedata.category(char) 388 | if cat == "Zs": 389 | return True 390 | return False 391 | 392 | 393 | def _is_control(char): 394 | """Checks whether `chars` is a control character.""" 395 | # These are technically control characters but we count them as whitespace 396 | # characters. 397 | if char == "\t" or char == "\n" or char == "\r": 398 | return False 399 | cat = unicodedata.category(char) 400 | if cat.startswith("C"): 401 | return True 402 | return False 403 | 404 | 405 | def _is_punctuation(char): 406 | """Checks whether `chars` is a punctuation character.""" 407 | cp = ord(char) 408 | # We treat all non-letter/number ASCII as punctuation. 409 | # Characters such as "^", "$", and "`" are not in the Unicode 410 | # Punctuation class but we treat them as punctuation anyways, for 411 | # consistency. 
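    # Editor's note: ord("$") == 36 falls in the 33-47 range below, so "$"
    # counts as punctuation here even though its Unicode category is "Sc"
    # (currency symbol) rather than a "P*" punctuation class.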
412 |     if ((33 <= cp <= 47) or (58 <= cp <= 64) or
413 |             (91 <= cp <= 96) or (123 <= cp <= 126)):
414 |         return True
415 |     cat = unicodedata.category(char)
416 |     if cat.startswith("P"):
417 |         return True
418 |     return False
419 | 
420 | 
--------------------------------------------------------------------------------
/rerank.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Re-ranking and ensemble implementation based on
3 | # "Matching Images and Text with Multi-modal Tensor Fusion and Re-ranking"
4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
7 | # Written by Keyu Wen, 2020
8 | # ------------------------------------------------------------
9 | 
10 | import numpy as np
11 | import time
12 | import argparse
13 | 
14 | 
15 | def i2t_rerank(sim, K1, K2):  # sim is a (d, 5d) image-caption similarity matrix; typically called as i2t_rerank(sim, 15, 1)
16 | 
17 |     size_i = sim.shape[0]  # d images
18 |     size_t = sim.shape[1]  # 5d captions
19 |     sort_i2t = np.argsort(-sim, 1)
20 |     sort_t2i = np.argsort(-sim, 0)
21 |     sort_i2t_re = np.copy(sort_i2t)[:, :K1]
22 |     address = np.array([])
23 | 
24 |     for i in range(size_i):
25 |         for j in range(K1):
26 |             result_t = sort_i2t[i][j]
27 |             query = sort_t2i[:, result_t]
28 |             # query = sort_t2i[:K2, result_t]
29 |             address = np.append(address, np.where(query == i)[0][0])
30 | 
31 |         sort = np.argsort(address)
32 |         sort_i2t_re[i] = sort_i2t_re[i][sort]
33 |         address = np.array([])
34 | 
35 |     sort_i2t[:, :K1] = sort_i2t_re
36 | 
37 |     return sort_i2t
38 | 
39 | 
40 | def t2i_rerank(sim, K1, K2):
41 | 
42 |     size_i = sim.shape[0]
43 |     size_t = sim.shape[1]
44 |     sort_i2t = np.argsort(-sim, 1)
45 |     sort_t2i = np.argsort(-sim, 0)
46 |     sort_t2i_re = np.copy(sort_t2i)[:K1, :]
47 |     address = np.array([])
48 | 
49 |     for i in range(size_t):
50 |         for j in range(K1):
51 |             result_i = sort_t2i[j][i]
52 |             query = sort_i2t[result_i, :]
53 |             # print(query)
54 |             # query = sort_t2i[:K2, result_t]
55 |             ranks = 1e20
56 |             # for k in range(5):
57 |             # qewfo = i//5 * 5 + k
58 |             # print(np.where(query == i))
59 |             tmp = np.where(query == i)[0][0]
60 |             if tmp < ranks:
61 |                 ranks = tmp
62 |             address = np.append(address, ranks)
63 | 
64 |         sort = np.argsort(address)
65 |         sort_t2i_re[:, i] = sort_t2i_re[:, i][sort]
66 |         address = np.array([])
67 | 
68 |     sort_t2i[:K1, :] = sort_t2i_re
69 | 
70 |     return sort_t2i
71 | 
72 | 
73 | def t2i_rerank_new(sim, sim_T, K1, K2):
74 | 
75 |     size_i = sim.shape[0]
76 |     size_t = sim.shape[1]
77 |     sort_i2t = np.argsort(-sim, 1)
78 |     sort_t2i = np.argsort(-sim, 0)
79 |     sort_t2i_re = np.copy(sort_t2i)[:K1, :]
80 | 
81 |     sort_t2t = np.argsort(-sim_T, 1)  # sort each row in descending order
82 |     # print(sort_t2t.shape)
83 |     sort_t2t_re = np.copy(sort_t2t)[:, :K2]
84 |     address = np.array([])
85 | 
86 |     for i in range(size_t):
87 |         for j in range(K1):
88 |             result_i = sort_t2i[j][i]  # Ij
89 |             query = sort_i2t[result_i, :]  # the caption ranking produced by image Ij
90 |             # query = sort_t2i[:K2, result_t]
91 |             ranks = 1e20
92 |             G = sort_t2t_re[i]
93 |             for k in range(K2):
94 |                 # qewfo = i//5 * 5 + k
95 |                 # print(qewfo)
96 |                 tmp = np.where(query == G[k])[0][0]
97 |                 if tmp < ranks:
98 |                     ranks = tmp
99 |             address = np.append(address, ranks)
100 | 
101 |         sort = np.argsort(address)
102 |         sort_t2i_re[:, i] = sort_t2i_re[:, i][sort]
103 |         address = np.array([])
104 | 
105 |     sort_t2i[:K1, :] = sort_t2i_re
106 | 
107 |     return sort_t2i
108 | 
109 | 
110 | def acc_i2t2(input):
111 |     """Computes the precision@k for
the specified values of k of i2t""" 112 | #input = collect_match(input).numpy() 113 | image_size = input.shape[0] 114 | ranks = np.zeros(image_size) 115 | top1 = np.zeros(image_size) 116 | 117 | for index in range(image_size): 118 | inds = input[index] 119 | # Score 120 | # if index == 197: 121 | # print('s') 122 | rank = 1e20 123 | for i in range(5 * index, min(5 * index + 5, image_size*5), 1): 124 | tmp = np.where(inds == i)[0][0] 125 | if tmp < rank: 126 | rank = tmp 127 | ranks[index] = rank 128 | top1[index] = inds[0] 129 | 130 | 131 | # Compute metrics 132 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 133 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 134 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 135 | medr = np.floor(np.median(ranks)) + 1 136 | meanr = ranks.mean() + 1 137 | 138 | return (r1, r5, r10, medr, meanr), (ranks, top1) 139 | 140 | 141 | def acc_t2i2(input): 142 | """Computes the precision@k for the specified values of k of t2i""" 143 | #input = collect_match(input).numpy() 144 | image_size = input.shape[0] 145 | ranks = np.zeros(5*image_size) 146 | top1 = np.zeros(5*image_size) 147 | 148 | # --> (5N(caption), N(image)) 149 | input = input.T 150 | 151 | for index in range(image_size): 152 | for i in range(5): 153 | inds = input[5 * index + i] 154 | ranks[5 * index + i] = np.where(inds == index)[0][0] 155 | top1[5 * index + i] = inds[0] 156 | 157 | # Compute metrics 158 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 159 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 160 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 161 | medr = np.floor(np.median(ranks)) + 1 162 | meanr = ranks.mean() + 1 163 | 164 | return (r1, r5, r10, medr, meanr), (ranks, top1) 165 | 166 | 167 | def main(): 168 | parser = argparse.ArgumentParser() 169 | parser.add_argument('--data_name', default='coco', help='data name') 170 | parser.add_argument('--fold', action='store_true', help='fold5') 171 | opt = parser.parse_args() 172 | data = opt.data_name 173 | fold = opt.fold 174 | # The accuracy computing 175 | # Input the prediction similarity score matrix (d * 5d) 176 | if data == 'coco': 177 | if fold == True: 178 | path1 = '' 179 | path = 'coco_sims/' 180 | r1 = np.array((0,0,0)) 181 | r1_t = np.array((0,0,0)) 182 | r2 = np.array((0,0,0)) # rerank 183 | r2_t = np.array((0,0,0)) 184 | for i in range(5): 185 | d1 = np.load(path1+'sims_full_%d.npy' % i) 186 | d2 = np.load(path+'sims_full_%d.npy' % i) 187 | 188 | # d1T = np.load(path1+'sims_full_T_%d.npy' % i) 189 | # d2T = np.load(path+'sims_full_T_%d.npy' % i) 190 | 191 | d = d1+d2 192 | # d_T = d1T+d2T 193 | 194 | t1 = time.time() 195 | # calculate the i2t score after rerank 196 | sort_rerank = i2t_rerank(d, 15, 1) 197 | (r1i, r5i, r10i, medri, meanri), _ = acc_i2t2(np.argsort(-d, 1)) 198 | (r1i2, r5i2, r10i2, medri2, meanri2), _ = acc_i2t2(sort_rerank) 199 | 200 | print(r1i, r5i, r10i, medri, meanri) 201 | print(r1i2, r5i2, r10i2, medri2, meanri2) 202 | r1 = r1 + np.array((r1i, r5i, r10i)) 203 | r2 = r2 + np.array((r1i2, r5i2, r10i2)) 204 | 205 | # calculate the t2i score after rerank 206 | # sort_rerank = t2i_rerank(d, 20, 1) 207 | # sort_rerank = t2i_rerank_new(d, d_T, 20, 1) 208 | (r1t, r5t, r10t, medrt, meanrt), _ = acc_t2i2(np.argsort(-d, 0)) 209 | # (r1t2, r5t2, r10t2, medrt2, meanrt2), _ = acc_t2i2(sort_rerank) 210 | 211 | print(r1t, r5t, r10t, medrt, meanrt) 212 | # print(r1t2, r5t2, r10t2, medrt2, meanrt2) 213 | # print((r1t, r5t, r10t)) 214 | r1_t = r1_t + np.array((r1t, r5t, 
r10t)) 215 | # r2_t = r2_t + np.array((r1t2, r5t2, r10t2)) 216 | t2 = time.time() 217 | print(t2-t1) 218 | print('--------------------') 219 | print('5-cross test') 220 | print(r1/5) 221 | print(r1_t/5) 222 | print('rerank!') 223 | print(r2/5) 224 | # print(r2_t/5) 225 | else: 226 | path = 'coco_sims/' 227 | path1 = '' 228 | d1 = np.load(path+'sims_full_5k.npy') 229 | d2 = np.load(path1+'sims_full_5k.npy') 230 | d = d1+ d2 231 | t1 = time.time() 232 | # calculate the i2t score after rerank 233 | sort_rerank = i2t_rerank(d, 15, 1) 234 | (r1i, r5i, r10i, medri, meanri), _ = acc_i2t2(np.argsort(-d, 1)) 235 | (r1i2, r5i2, r10i2, medri2, meanri2), _ = acc_i2t2(sort_rerank) 236 | 237 | print(r1i, r5i, r10i, medri, meanri) 238 | print(r1i2, r5i2, r10i2, medri2, meanri2) 239 | 240 | # calculate the t2i score after rerank 241 | # sort_rerank = t2i_rerank(d, 20, 1) 242 | # sort_rerank = t2i_rerank_new(d, d_T, 12, 1) 243 | (r1t, r5t, r10t, medrt, meanrt), _ = acc_t2i2(np.argsort(-d, 0)) 244 | # (r1t2, r5t2, r10t2, medrt2, meanrt2), _ = acc_t2i2(sort_rerank) 245 | 246 | print(r1t, r5t, r10t, medrt, meanrt) 247 | # print(r1t2, r5t2, r10t2, medrt2, meanrt2) 248 | t2 = time.time() 249 | print(t2-t1) 250 | 251 | else: 252 | 253 | d1 = np.load('flickr_sims/sims_f.npy') 254 | d2 = np.load('sims_f.npy') 255 | d = d1+d2 256 | # d1T = np.load('flickr_sims/sims_f_T.npy') 257 | # d2T = np.load('sims_f_T.npy') 258 | 259 | # d_T = d1T+d2T 260 | 261 | t1 = time.time() 262 | # calculate the i2t score after rerank 263 | sort_rerank = i2t_rerank(d, 15, 1) 264 | (r1i, r5i, r10i, medri, meanri), _ = acc_i2t2(np.argsort(-d, 1)) 265 | (r1i2, r5i2, r10i2, medri2, meanri2), _ = acc_i2t2(sort_rerank) 266 | 267 | print(r1i, r5i, r10i, medri, meanri) 268 | print(r1i2, r5i2, r10i2, medri2, meanri2) 269 | 270 | 271 | # calculate the t2i score after rerank 272 | 273 | # sort_rerank = t2i_rerank_new(d, d_T, 20, 4) 274 | (r1t, r5t, r10t, medrt, meanrt), _ = acc_t2i2(np.argsort(-d, 0)) 275 | # (r1t2, r5t2, r10t2, medrt2, meanrt2), _ = acc_t2i2(sort_rerank) 276 | 277 | print(r1t, r5t, r10t, medrt, meanrt) 278 | # print(r1t2, r5t2, r10t2, medrt2, meanrt2) 279 | rsum = r1i+r5i+r10i+r1t+r5t+r10t 280 | print('rsum:%f' % rsum) 281 | rsum_rr = r1i2+r5i2+r10i2+r1t+r5t+r10t 282 | print('rsum_rr:%f' % rsum_rr) 283 | t2 = time.time() 284 | print(t2-t1) 285 | 286 | if __name__ == '__main__': 287 | main() 288 | 289 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | 5 | 6 | RES_NET_file_path = 'resnet152-b121ed2d.pth' 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=dilation, groups=groups, bias=False, dilation=dilation) 13 | 14 | 15 | def conv1x1(in_planes, out_planes, stride=1): 16 | """1x1 convolution""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 24 | base_width=64, dilation=1, norm_layer=None): 25 | super(BasicBlock, self).__init__() 26 | if norm_layer is None: 27 | norm_layer = nn.BatchNorm2d 28 | if groups != 1 or base_width != 64: 29 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 30 | 
if dilation > 1: 31 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 32 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = norm_layer(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes) 37 | self.bn2 = norm_layer(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | identity = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | identity = self.downsample(x) 53 | 54 | out += identity 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 64 | base_width=64, dilation=1, norm_layer=None): 65 | super(Bottleneck, self).__init__() 66 | if norm_layer is None: 67 | norm_layer = nn.BatchNorm2d 68 | width = int(planes * (base_width / 64.)) * groups 69 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 70 | self.conv1 = conv1x1(inplanes, width) 71 | self.bn1 = norm_layer(width) 72 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 73 | self.bn2 = norm_layer(width) 74 | self.conv3 = conv1x1(width, planes * self.expansion) 75 | self.bn3 = norm_layer(planes * self.expansion) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | identity = x 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv3(out) 92 | out = self.bn3(out) 93 | 94 | if self.downsample is not None: 95 | identity = self.downsample(x) 96 | 97 | out += identity 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | 103 | class ResNet(nn.Module): 104 | 105 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 106 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 107 | norm_layer=None): 108 | super(ResNet, self).__init__() 109 | if norm_layer is None: 110 | norm_layer = nn.BatchNorm2d 111 | self._norm_layer = norm_layer 112 | 113 | self.inplanes = 64 114 | self.dilation = 1 115 | if replace_stride_with_dilation is None: 116 | # each element in the tuple indicates if we should replace 117 | # the 2x2 stride with a dilated convolution instead 118 | replace_stride_with_dilation = [False, False, False] 119 | if len(replace_stride_with_dilation) != 3: 120 | raise ValueError("replace_stride_with_dilation should be None " 121 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 122 | self.groups = groups 123 | self.base_width = width_per_group 124 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 125 | bias=False) 126 | self.bn1 = norm_layer(self.inplanes) 127 | self.relu = nn.ReLU(inplace=True) 128 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 129 | self.layer1 = self._make_layer(block, 64, layers[0]) 130 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 131 | dilate=replace_stride_with_dilation[0]) 132 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 133 | dilate=replace_stride_with_dilation[1]) 134 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 
135 |                                        dilate=replace_stride_with_dilation[2])
136 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
137 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
138 | 
139 |         for m in self.modules():
140 |             if isinstance(m, nn.Conv2d):
141 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
142 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
143 |                 nn.init.constant_(m.weight, 1)
144 |                 nn.init.constant_(m.bias, 0)
145 | 
146 |         # Zero-initialize the last BN in each residual branch,
147 |         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
148 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
149 |         if zero_init_residual:
150 |             for m in self.modules():
151 |                 if isinstance(m, Bottleneck):
152 |                     nn.init.constant_(m.bn3.weight, 0)
153 |                 elif isinstance(m, BasicBlock):
154 |                     nn.init.constant_(m.bn2.weight, 0)
155 | 
156 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
157 |         norm_layer = self._norm_layer
158 |         downsample = None
159 |         previous_dilation = self.dilation
160 |         if dilate:
161 |             self.dilation *= stride
162 |             stride = 1
163 |         if stride != 1 or self.inplanes != planes * block.expansion:
164 |             downsample = nn.Sequential(
165 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
166 |                 norm_layer(planes * block.expansion),
167 |             )
168 | 
169 |         layers = []
170 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
171 |                             self.base_width, previous_dilation, norm_layer))
172 |         self.inplanes = planes * block.expansion
173 |         for _ in range(1, blocks):
174 |             layers.append(block(self.inplanes, planes, groups=self.groups,
175 |                                 base_width=self.base_width, dilation=self.dilation,
176 |                                 norm_layer=norm_layer))
177 | 
178 |         return nn.Sequential(*layers)
179 | 
180 |     def forward(self, x):
181 |         x = self.conv1(x)
182 |         x = self.bn1(x)
183 |         x = self.relu(x)
184 |         x = self.maxpool(x)
185 | 
186 |         x1 = self.layer1(x)
187 |         x2 = self.layer2(x1)
188 |         x3 = self.layer3(x2)
189 |         x4 = self.layer4(x3)  # extract the output before avg pooling
190 |         # print(x4.size())
191 | 
192 |         return x1, x2, x3, x4
193 | 
194 | 
195 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
196 |     model = ResNet(block, layers, **kwargs)
197 |     if pretrained:
198 |         state_dict = torch.load(RES_NET_file_path)
199 |         model.load_state_dict(state_dict)
200 |     return model
201 | 
202 | 
203 | def resnet152(pretrained=False, progress=True, **kwargs):
204 |     r"""ResNet-152 model from
205 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
206 |     Args:
207 |         pretrained (bool): If True, loads ImageNet weights from the local file RES_NET_file_path
208 |         progress (bool): unused here (no download is performed); kept for torchvision API compatibility
209 |     """
210 |     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
211 |                    **kwargs)
--------------------------------------------------------------------------------
/runs/BERT/bert_models:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/runs/GRU/gru_models:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/test_bert_cc.sh:
--------------------------------------------------------------------------------
1 | echo "BERT"
2 | echo "MSCOCO"
3 | echo "evaluate cc_model1"
4 | # python evaluation_bert.py --model BERT/cc_model1 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
"$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH" 5 | python evaluation_bert.py --model BERT/cc_model1_ --fold --data_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/data --region_bbox_file /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 --feature_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval 6 | mv sims_full_0.npy sims_full_1.npy sims_full_2.npy sims_full_3.npy sims_full_4.npy sims_full_5k.npy ./coco_sims 7 | # mv sims_f_T.npy ./flickr_sims 8 | echo "evalaute cc_model2" 9 | # python evaluation_bert.py --model BERT/cc_model2 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH" 10 | python evaluation_bert.py --model BERT/cc_model2_ --fold --data_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/data --region_bbox_file /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 --feature_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval 11 | echo "ensemble and rerank!" 12 | echo "fold5-1K" 13 | python rerank.py --data_name coco --fold 14 | echo "5K" 15 | python rerank.py --data_name coco 16 | -------------------------------------------------------------------------------- /test_bert_f.sh: -------------------------------------------------------------------------------- 1 | echo "BERT" 2 | echo "Flickr30K" 3 | echo "evalaute f_model1" 4 | # python evaluation_bert.py --model BERT/f_model1 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH" 5 | python evaluation_bert.py --model BERT/f_model1_ --data_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/data --region_bbox_file /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 --feature_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/ 6 | mv sims_f.npy ./flickr_sims 7 | mv sims_f_T.npy ./flickr_sims 8 | echo "evalaute f_model2" 9 | # python evaluation_bert.py --model BERT/f_model2 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH" 10 | python evaluation_bert.py --model BERT/f_model2_ --data_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/data --region_bbox_file /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 --feature_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/ 11 | echo "ensemble and rerank!" 
12 | python rerank.py --data_name f30k
13 | 
--------------------------------------------------------------------------------
/test_gru_cc.sh:
--------------------------------------------------------------------------------
1 | echo "GRU"
2 | echo "MSCOCO"
3 | echo "evaluate cc_model1"
4 | # python evaluation_bert.py --model GRU/cc_model1 --fold --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
5 | python evaluation.py --model GRU/cc_model1 --fold --data_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/data --region_bbox_file /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 --feature_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval
6 | mv sims_full_0.npy sims_full_1.npy sims_full_2.npy sims_full_3.npy sims_full_4.npy sims_full_5k.npy ./coco_sims
7 | echo "evaluate cc_model2"
8 | # python evaluation_bert.py --model GRU/cc_model2 --fold --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
9 | python evaluation.py --model GRU/cc_model2 --fold --data_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/data --region_bbox_file /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 --feature_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval
10 | echo "ensemble and rerank!"
11 | echo "fold5-1K"
12 | python rerank.py --data_name coco --fold
13 | echo "5K"
14 | python rerank.py --data_name coco
15 | 
--------------------------------------------------------------------------------
/test_gru_f.sh:
--------------------------------------------------------------------------------
1 | 
2 | echo "GRU"
3 | echo "Flickr30K"
4 | echo "evaluate f_model1"
5 | # python evaluation_bert.py --model GRU/f_model1 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
6 | python evaluation.py --model fmodel_1_
7 | mv sims_f.npy ./flickr_sims
8 | mv sims_f_T.npy ./flickr_sims
9 | echo "evaluate f_model2"
10 | # python evaluation_bert.py --model GRU/f_model2 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
11 | python evaluation.py --model fmodel_2_
12 | echo "ensemble and rerank!"
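# Editor's note: for f30k, rerank.py loads flickr_sims/sims_f.npy (model 1,
# moved above) and ./sims_f.npy (model 2), sums the two similarity matrices,
# then reports R@1/5/10 before and after i2t re-ranking, plus rsum.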
13 | python rerank.py --data_name f30k 14 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on 3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives" 4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching" 5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng 6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020 7 | # Writen by Keyu Wen, 2020 8 | # ------------------------------------------------------------ 9 | 10 | import pickle 11 | import os 12 | import time 13 | import shutil 14 | 15 | import torch 16 | 17 | import data 18 | from vocab import Vocabulary # NOQA 19 | from model import VSE 20 | from evaluation import i2t, t2i, AverageMeter, LogCollector, encode_data 21 | import numpy as np 22 | import logging 23 | import tensorboard_logger as tb_logger 24 | 25 | import argparse 26 | 27 | 28 | def main(): 29 | # Hyper Parameters 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--data_path', default='data', 32 | help='path to datasets') 33 | parser.add_argument('--data_name', default='f30k', 34 | help='{coco,f8k,f30k,10crop}_precomp|coco|f8k|f30k') 35 | parser.add_argument('--vocab_path', default='vocab', 36 | help='Path to saved vocabulary pickle files.') 37 | parser.add_argument('--margin', default=0.2, type=float, 38 | help='Rank loss margin.') 39 | parser.add_argument('--num_epochs', default=30, type=int, 40 | help='Number of training epochs.') 41 | parser.add_argument('--batch_size', default=128, type=int, 42 | help='Size of a training mini-batch.') 43 | parser.add_argument('--word_dim', default=300, type=int, 44 | help='Dimensionality of the word embedding.') 45 | parser.add_argument('--embed_size', default=1024, type=int, 46 | help='Dimensionality of the joint embedding.') 47 | parser.add_argument('--grad_clip', default=2., type=float, 48 | help='Gradient clipping threshold.') 49 | parser.add_argument('--crop_size', default=224, type=int, 50 | help='Size of an image crop as the CNN input.') 51 | parser.add_argument('--num_layers', default=1, type=int, 52 | help='Number of GRU layers.') 53 | parser.add_argument('--learning_rate', default=2e-4, type=float, 54 | help='Initial learning rate.') 55 | parser.add_argument('--lr_update', default=15, type=int, 56 | help='Number of epochs to update the learning rate.') 57 | parser.add_argument('--workers', default=10, type=int, 58 | help='Number of data loader workers.') 59 | parser.add_argument('--log_step', default=100, type=int, 60 | help='Number of steps to print and record the log.') 61 | parser.add_argument('--val_step', default=500, type=int, 62 | help='Number of steps to run validation.') 63 | parser.add_argument('--logger_name', default='runs/test', 64 | help='Path to save the model and Tensorboard log.') 65 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 66 | help='path to latest checkpoint (default: none)') 67 | parser.add_argument('--img_dim', default=2048, type=int, 68 | help='Dimensionality of the image embedding.') 69 | parser.add_argument('--finetune', action='store_true', 70 | help='Fine-tune the image encoder.') 71 | parser.add_argument('--use_restval', action='store_true', 72 | help='Use the restval data for training on MSCOCO.') 73 | 
parser.add_argument('--reset_train', action='store_true', 74 | help='Ensure the training is always done in ' 75 | 'train mode (Not recommended).') 76 | parser.add_argument('--K', default=2, type=int,help='num of JSR.') 77 | parser.add_argument('--feature_path', default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/', 78 | type=str, help='path to the pre-computed image features') 79 | parser.add_argument('--region_bbox_file', 80 | default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5', 81 | type=str, help='path to the region_bbox_file(.h5)') 82 | opt = parser.parse_args() 83 | print(opt) 84 | 85 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) 86 | tb_logger.configure(opt.logger_name, flush_secs=5) 87 | 88 | # Load Vocabulary Wrapper 89 | vocab = pickle.load(open(os.path.join( 90 | opt.vocab_path, '%s_vocab.pkl' % opt.data_name), 'rb')) 91 | opt.vocab_size = len(vocab) 92 | 93 | # Load data loaders 94 | train_loader, val_loader = data.get_loaders( 95 | opt.data_name, vocab, opt.crop_size, opt.batch_size, opt.workers, opt) 96 | 97 | # Construct the model 98 | model = VSE(opt) 99 | best_rsum = 0 100 | # optionally resume from a checkpoint 101 | if opt.resume: 102 | if os.path.isfile(opt.resume): 103 | print("=> loading checkpoint '{}'".format(opt.resume)) 104 | checkpoint = torch.load(opt.resume) 105 | start_epoch = checkpoint['epoch'] 106 | best_rsum = checkpoint['best_rsum'] 107 | model.load_state_dict(checkpoint['model']) 108 | # Eiters is used to show logs as the continuation of another 109 | # training 110 | model.Eiters = checkpoint['Eiters'] 111 | print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})" 112 | .format(opt.resume, start_epoch, best_rsum)) 113 | validate(opt, val_loader, model) 114 | else: 115 | print("=> no checkpoint found at '{}'".format(opt.resume)) 116 | del checkpoint 117 | # Train the Model 118 | 119 | for epoch in range(opt.num_epochs): 120 | adjust_learning_rate(opt, model.optimizer, epoch) 121 | 122 | # train for one epoch 123 | train(opt, train_loader, model, epoch, val_loader) 124 | 125 | # evaluate on validation set 126 | rsum = validate(opt, val_loader, model) 127 | 128 | # remember best R@ sum and save checkpoint 129 | is_best = rsum > best_rsum 130 | best_rsum = max(rsum, best_rsum) 131 | save_checkpoint({ 132 | 'epoch': epoch + 1, 133 | 'model': model.state_dict(), 134 | 'best_rsum': best_rsum, 135 | 'opt': opt, 136 | 'Eiters': model.Eiters, 137 | }, is_best, epoch, prefix=opt.logger_name + '/') 138 | 139 | 140 | def train(opt, train_loader, model, epoch, val_loader): 141 | # average meters to record the training statistics 142 | batch_time = AverageMeter() 143 | data_time = AverageMeter() 144 | train_logger = LogCollector() 145 | 146 | # switch to train mode 147 | model.train_start() 148 | 149 | end = time.time() 150 | for i, train_data in enumerate(train_loader): 151 | if opt.reset_train: 152 | # Always reset to train mode, this is not the default behavior 153 | model.train_start() 154 | 155 | # measure data loading time 156 | data_time.update(time.time() - end) 157 | 158 | # make sure train logger is used 159 | model.logger = train_logger 160 | 161 | # Update the model 162 | model.train_emb(*train_data) 163 | 164 | 165 | # measure elapsed time 166 | batch_time.update(time.time() - end) 167 | end = time.time() 168 | 169 | # Print log info 170 | if model.Eiters % opt.log_step == 0: 171 | logging.info( 172 | 'Epoch: [{0}][{1}/{2}]\t' 
173 | '{e_log}\t' 174 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 175 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 176 | .format( 177 | epoch, i, len(train_loader), batch_time=batch_time, 178 | data_time=data_time, e_log=str(model.logger))) 179 | 180 | # Record logs in tensorboard 181 | tb_logger.log_value('epoch', epoch, step=model.Eiters) 182 | tb_logger.log_value('step', i, step=model.Eiters) 183 | tb_logger.log_value('batch_time', batch_time.val, step=model.Eiters) 184 | tb_logger.log_value('data_time', data_time.val, step=model.Eiters) 185 | model.logger.tb_log(tb_logger, step=model.Eiters) 186 | 187 | # validate at every val_step 188 | # if model.Eiters % opt.val_step == 0: 189 | # validate(opt, val_loader, model) 190 | 191 | 192 | def validate(opt, val_loader, model): 193 | # compute the encoding for all the validation images and captions 194 | img_embs, cap_embs = encode_data( 195 | model, val_loader, opt.log_step, logging.info) 196 | 197 | # caption retrieval 198 | (r1, r5, r10, medr, meanr) = i2t(img_embs, cap_embs) 199 | logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % 200 | (r1, r5, r10, medr, meanr)) 201 | # image retrieval 202 | (r1i, r5i, r10i, medri, meanr) = t2i( 203 | img_embs, cap_embs) 204 | logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % 205 | (r1i, r5i, r10i, medri, meanr)) 206 | # sum of recalls to be used for early stopping 207 | currscore = r1 + r5 + r10 + r1i + r5i + r10i 208 | 209 | # record metrics in tensorboard 210 | tb_logger.log_value('r1', r1, step=model.Eiters) 211 | tb_logger.log_value('r5', r5, step=model.Eiters) 212 | tb_logger.log_value('r10', r10, step=model.Eiters) 213 | tb_logger.log_value('medr', medr, step=model.Eiters) 214 | tb_logger.log_value('meanr', meanr, step=model.Eiters) 215 | tb_logger.log_value('r1i', r1i, step=model.Eiters) 216 | tb_logger.log_value('r5i', r5i, step=model.Eiters) 217 | tb_logger.log_value('r10i', r10i, step=model.Eiters) 218 | tb_logger.log_value('medri', medri, step=model.Eiters) 219 | tb_logger.log_value('meanr', meanr, step=model.Eiters) 220 | tb_logger.log_value('rsum', currscore, step=model.Eiters) 221 | 222 | return currscore 223 | 224 | 225 | def save_checkpoint(state, is_best, epoch, filename='checkpoint.pth.tar', prefix=''): 226 | torch.save(state, prefix + filename) 227 | if is_best: 228 | shutil.copyfile(prefix + filename, prefix + 'model_best.pth.tar') 229 | # shutil.copyfile(prefix + filename, prefix + 'checkpoint'+str(epoch)+'.pth.tar') 230 | 231 | 232 | def adjust_learning_rate(opt, optimizer, epoch): 233 | """Sets the learning rate to the initial LR 234 | decayed by 10 every 30 epochs""" 235 | lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update)) 236 | for param_group in optimizer.param_groups: 237 | param_group['lr'] = lr 238 | 239 | 240 | def accuracy(output, target, topk=(1,)): 241 | """Computes the precision@k for the specified values of k""" 242 | maxk = max(topk) 243 | batch_size = target.size(0) 244 | 245 | _, pred = output.topk(maxk, 1, True, True) 246 | pred = pred.t() 247 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 248 | 249 | res = [] 250 | for k in topk: 251 | correct_k = correct[:k].view(-1).float().sum(0) 252 | res.append(correct_k.mul_(100.0 / batch_size)) 253 | return res 254 | 255 | 256 | if __name__ == '__main__': 257 | main() 258 | -------------------------------------------------------------------------------- /train_bert.py: -------------------------------------------------------------------------------- 1 | # 
----------------------------------------------------------- 2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on 3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives" 4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching" 5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng 6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020 7 | # Writen by Keyu Wen, 2020 8 | # ------------------------------------------------------------ 9 | 10 | import pickle 11 | import os 12 | import time 13 | import shutil 14 | import torch 15 | import data_bert as data 16 | from model_bert import VSE 17 | from evaluation_bert import i2t, t2i, AverageMeter, LogCollector, encode_data, simrank 18 | import numpy as np 19 | import logging 20 | import tensorboard_logger as tb_logger 21 | import argparse 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--data_path', default='data', 27 | help='path to datasets') 28 | parser.add_argument('--data_name', default='coco', 29 | help='{coco,f30k}') 30 | parser.add_argument('--margin', default=0.2, type=float, 31 | help='Rank loss margin.') 32 | parser.add_argument('--num_epochs', default=12, type=int, 33 | help='Number of training epochs.') 34 | parser.add_argument('--batch_size', default=128, type=int, 35 | help='Size of a training mini-batch.') 36 | parser.add_argument('--embed_size', default=1024, type=int, 37 | help='Dimensionality of the joint embedding.') 38 | parser.add_argument('--crop_size', default=224, type=int, 39 | help='Size of an image crop as the CNN input.') 40 | parser.add_argument('--learning_rate', default=2e-5, type=float, 41 | help='Initial learning rate.') 42 | parser.add_argument('--lr_update', default=6, type=int, 43 | help='Number of epochs to update the learning rate.') 44 | parser.add_argument('--workers', default=10, type=int, 45 | help='Number of data loader workers.') 46 | parser.add_argument('--log_step', default=100, type=int, 47 | help='Number of steps to print and record the log.') 48 | parser.add_argument('--logger_name', default='runs/grg', 49 | help='Path to save the model and Tensorboard log.') 50 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 51 | help='path to latest checkpoint (default: none)') 52 | parser.add_argument('--img_dim', default=2048, type=int, 53 | help='Dimensionality of the image embedding.') 54 | parser.add_argument('--ft_res', action='store_true', 55 | help='Fine-tune the image encoder.') 56 | parser.add_argument('--bert_path', default='uncased_L-12_H-768_A-12/', 57 | help='path of pre-trained BERT.') 58 | parser.add_argument('--ft_bert', action='store_true', 59 | help='Fine-tune the text encoder.') 60 | parser.add_argument('--bert_size', default=768, type=int, 61 | help='Dimensionality of the text embedding') 62 | parser.add_argument('--warmup', default=-1, type=float) 63 | parser.add_argument('--K', default=2, type=int,help='num of JSR.') 64 | parser.add_argument('--feature_path', default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/', 65 | type=str, help='path to the pre-computed image features') 66 | parser.add_argument('--region_bbox_file', 67 | default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5', 68 | type=str, help='path to the region_bbox_file(.h5)') 69 | opt = parser.parse_args() 70 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) 71 
    tb_logger.configure(opt.logger_name, flush_secs=5)

    train_loader, val_loader = data.get_loaders(opt.data_name, opt.batch_size, opt.workers, opt)
    opt.l_train = len(train_loader)
    print(opt)
    model = VSE(opt)
    best_rsum = 0

    # Optionally resume from a checkpoint.
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to report the correct global step on Tensorboard.
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
                  .format(opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    for epoch in range(opt.num_epochs):

        adjust_learning_rate(opt, model.optimizer, epoch)

        # Train for one epoch, then evaluate on the validation set.
        train(opt, train_loader, model, epoch, val_loader)

        rsum = validate(opt, val_loader, model)[-1]

        # Remember the best rsum and save the checkpoint.
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'best_rsum': best_rsum,
            'opt': opt,
            'Eiters': model.Eiters,
        }, is_best, epoch, prefix=opt.logger_name + '/')


def train(opt, train_loader, model, epoch, val_loader):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    train_logger = LogCollector()

    model.train_start()

    end = time.time()
    for i, train_data in enumerate(train_loader):

        data_time.update(time.time() - end)
        model.logger = train_logger
        model.train_emb(*train_data)
        batch_time.update(time.time() - end)
        end = time.time()

        if model.Eiters % opt.log_step == 0:
            logging.info(
                'Epoch: [{0}][{1}/{2}]\t'
                '{e_log}\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                .format(
                    epoch, i, len(train_loader), batch_time=batch_time,
                    data_time=data_time, e_log=str(model.logger)))

        tb_logger.log_value('epoch', epoch, step=model.Eiters)
        tb_logger.log_value('step', i, step=model.Eiters)
        tb_logger.log_value('batch_time', batch_time.val, step=model.Eiters)
        tb_logger.log_value('data_time', data_time.val, step=model.Eiters)
        model.logger.tb_log(tb_logger, step=model.Eiters)


def validate(opt, val_loader, model):
    # Encode all images and captions, then rank them; the last element of the
    # returned metrics is the rsum used for model selection.
    _, _, sims = encode_data(
        model, val_loader, opt.log_step, logging.info)
    rs = simrank(sims)
    del sims
    return rs


def save_checkpoint(state, is_best, epoch, filename='checkpoint.pth.tar', prefix=''):
    torch.save(state, prefix + filename)
    if is_best:
        shutil.copyfile(prefix + filename, prefix + 'model_best.pth.tar')


def adjust_learning_rate(opt, optimizer, epoch):
    """Decay the learning rate by a factor of 10 every opt.lr_update epochs."""
    lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


if __name__ == '__main__':
    main()
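Note: adjust_learning_rate above is a plain step decay, multiplying the learning
rate by 0.1 once every lr_update epochs. A minimal standalone sketch of the
resulting schedule, assuming the script's defaults (learning_rate=2e-5,
lr_update=6, num_epochs=12):

    # Step decay: epochs 0-5 run at 2e-05, epochs 6-11 at 2e-06.
    learning_rate, lr_update = 2e-5, 6
    for epoch in range(12):
        lr = learning_rate * (0.1 ** (epoch // lr_update))
        print(epoch, lr)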
--------------------------------------------------------------------------------
/uncased_L-12_H-768_A-12/bert_pretrained_model:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/vocab.py:
--------------------------------------------------------------------------------
# Create a vocabulary wrapper
import nltk
import pickle
from collections import Counter
import json
import argparse
import os


annotations = {
    'coco_precomp': ['train_caps.txt', 'dev_caps.txt'],
    'coco': ['annotations/captions_train2014.json',
             'annotations/captions_val2014.json'],
    'f8k_precomp': ['train_caps.txt', 'dev_caps.txt'],
    '10crop_precomp': ['train_caps.txt', 'dev_caps.txt'],
    'f30k_precomp': ['train_caps.txt', 'dev_caps.txt'],
    'f8k': ['dataset_flickr8k.json'],
    'f30k': ['dataset_flickr30k.json'],
}


class Vocabulary(object):
    """Simple vocabulary wrapper."""

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        # Out-of-vocabulary words map to the '<unk>' token.
        if word not in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)


def from_coco_json(path):
    # pycocotools is only needed for the raw 'coco' annotations, so import it lazily.
    from pycocotools.coco import COCO
    coco = COCO(path)
    ids = coco.anns.keys()
    captions = []
    for i, idx in enumerate(ids):
        captions.append(str(coco.anns[idx]['caption']))
    return captions


def from_flickr_json(path):
    with open(path, 'r') as f:
        dataset = json.load(f)['images']
    captions = []
    for d in dataset:
        captions += [str(x['raw']) for x in d['sentences']]
    return captions


def from_txt(txt):
    captions = []
    with open(txt, 'r', encoding='utf-8') as f:
        for line in f:
            captions.append(line.strip())
    return captions


def build_vocab(data_path, data_name, jsons, threshold):
    """Build a simple vocabulary wrapper."""
    counter = Counter()
    for path in jsons[data_name]:
        full_path = os.path.join(data_path, data_name, path)
        if data_name == 'coco':
            captions = from_coco_json(full_path)
        elif data_name in ('f8k', 'f30k'):
            captions = from_flickr_json(full_path)
        else:
            captions = from_txt(full_path)
        for i, caption in enumerate(captions):
            tokens = nltk.tokenize.word_tokenize(caption.lower())
            counter.update(tokens)

            if i % 1000 == 0:
                print("\r[%d/%d] tokenized the captions." % (i, len(captions)), end='')

    # Discard words whose occurrence count is below the threshold and sort the
    # survivors by frequency, most frequent first.
    counts = [(word, cnt) for word, cnt in counter.items() if cnt >= threshold]
    counts_new = sorted(counts, key=lambda x: x[1], reverse=True)

    # Create a vocab wrapper and add the special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add at most chosen_nums of the most frequent words to the vocabulary.
    chosen_nums = 256
    for word, count in counts_new[:chosen_nums]:
        vocab.add_word(word)
    return vocab
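# A quick illustration of the wrapper's lookup behaviour (hypothetical words;
# indices follow the insertion order used in build_vocab above):
#
#   v = Vocabulary()
#   for w in ('<pad>', '<start>', '<end>', '<unk>', 'dog'):
#       v.add_word(w)
#   v('dog')    # -> 4
#   v('zebra')  # -> 3, i.e. the index of '<unk>' (out-of-vocabulary fallback)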
def main(data_path, data_name):
    vocab = build_vocab(data_path, data_name, jsons=annotations, threshold=300)
    with open('%s_vocab.pkl' % data_name, 'wb') as f:
        pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
    print("Saved vocabulary file to %s_vocab.pkl" % data_name)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default='data')
    parser.add_argument('--data_name', default='f30k',
                        help='{coco,f8k,f30k,10crop}_precomp|coco|f8k|f30k')
    opt = parser.parse_args()
    main(opt.data_path, opt.data_name)
--------------------------------------------------------------------------------
/vocab/10crop_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/10crop_precomp_vocab.pkl
--------------------------------------------------------------------------------
/vocab/111:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/vocab/coco_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/coco_precomp_vocab.pkl
--------------------------------------------------------------------------------
/vocab/coco_resnet_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/coco_resnet_precomp_vocab.pkl
--------------------------------------------------------------------------------
/vocab/coco_vgg_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/coco_vgg_precomp_vocab.pkl
--------------------------------------------------------------------------------
/vocab/coco_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/coco_vocab.pkl
--------------------------------------------------------------------------------
/vocab/f30k_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/f30k_precomp_vocab.pkl
--------------------------------------------------------------------------------
/vocab/f30k_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/f30k_vocab.pkl
--------------------------------------------------------------------------------
/vocab/f8k_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/f8k_precomp_vocab.pkl
--------------------------------------------------------------------------------
/vocab/f8k_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/f8k_vocab.pkl
--------------------------------------------------------------------------------
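The .pkl files above are pickled Vocabulary objects produced by vocab.py. A
minimal loading sketch (the file name is illustrative; note that pickle must be
able to resolve the Vocabulary class at load time, so import it first -- if a
vocabulary was pickled while vocab.py ran as __main__, the class may instead
need to be defined or aliased in the loading script):

    import pickle

    from vocab import Vocabulary  # pickle needs the class to rebuild the object

    with open('vocab/f30k_precomp_vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)

    print(len(vocab))      # number of words, including the special tokens
    print(vocab('<unk>'))  # index of the unknown-word token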