├── GAT.py
├── LICENSE
├── README.md
├── coco_sims
│   └── coco_sims.txt
├── data.py
├── data
│   ├── coco
│   │   ├── annotations.zip
│   │   ├── download.sh
│   │   └── images
│   └── f30k
│       ├── dataset_flickr30k.zip
│       └── images
├── data_bert.py
├── evaluation.py
├── evaluation_bert.py
├── figures
│   └── model.jpg
├── flickr_sims
│   └── flickr_sims.txt
├── model.py
├── model_bert.py
├── pytorch_pretrained_bert
│   ├── .DS_Store
│   ├── file_utils.py
│   ├── modeling.py
│   ├── optimization.py
│   └── tokenization.py
├── rerank.py
├── resnet.py
├── runs
│   ├── BERT
│   │   └── bert_models
│   └── GRU
│       └── gru_models
├── test_bert_cc.sh
├── test_bert_f.sh
├── test_gru_cc.sh
├── test_gru_f.sh
├── train.py
├── train_bert.py
├── uncased_L-12_H-768_A-12
│   └── bert_pretrained_model
├── vocab.py
└── vocab
    ├── 111
    ├── 10crop_precomp_vocab.pkl
    ├── coco_precomp_vocab.pkl
    ├── coco_resnet_precomp_vocab.pkl
    ├── coco_vgg_precomp_vocab.pkl
    ├── coco_vocab.pkl
    ├── f30k_precomp_vocab.pkl
    ├── f30k_vocab.pkl
    ├── f8k_precomp_vocab.pkl
    └── f8k_vocab.pkl
/GAT.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Dual Semantic Relations Attention Network (DSRAN) implementation
3 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
4 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
5 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
6 | # Written by Keyu Wen, 2020
7 | # ------------------------------------------------------------
8 |
9 | import math
10 | import torch
11 | from torch import nn
12 | import torch.nn.functional as F
13 |
14 |
15 | class MultiHeadAttention(nn.Module):
16 | def __init__(self, config):
17 | super(MultiHeadAttention, self).__init__()
18 |
19 | self.num_attention_heads = config.num_attention_heads
20 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
21 | self.all_head_size = self.num_attention_heads * self.attention_head_size
22 |
23 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
24 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
25 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
26 |
27 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
28 |
29 | def transpose_for_scores(self, x):
30 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
31 | x = x.view(*new_x_shape)
32 | return x.permute(0, 2, 1, 3)
33 |
34 | def forward(self, input_graph):
35 | nodes_q = self.query(input_graph)
36 | nodes_k = self.key(input_graph)
37 | nodes_v = self.value(input_graph)
38 |
39 | nodes_q_t = self.transpose_for_scores(nodes_q)
40 | nodes_k_t = self.transpose_for_scores(nodes_k)
41 | nodes_v_t = self.transpose_for_scores(nodes_v)
42 |
43 | # Take the dot product between "query" and "key" to get the raw attention scores.
44 | attention_scores = torch.matmul(nodes_q_t, nodes_k_t.transpose(-1, -2))
45 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
46 |         # No attention mask is applied here: every graph node attends to all
47 |         # the other nodes, so the raw attention scores are used directly.
48 |
49 | # Normalize the attention scores to probabilities.
50 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
51 |
52 | # This is actually dropping out entire tokens to attend to, which might
53 | # seem a bit unusual, but is taken from the original Transformer paper.
54 | attention_probs = self.dropout(attention_probs)
55 |
56 | nodes_new = torch.matmul(attention_probs, nodes_v_t)
57 | nodes_new = nodes_new.permute(0, 2, 1, 3).contiguous()
58 | new_nodes_shape = nodes_new.size()[:-2] + (self.all_head_size,)
59 | nodes_new = nodes_new.view(*new_nodes_shape)
60 | return nodes_new
61 |
62 |
63 | class GATLayer(nn.Module):
64 | def __init__(self, config):
65 | super(GATLayer, self).__init__()
66 | self.mha = MultiHeadAttention(config)
67 |
68 | self.fc_in = nn.Linear(config.hidden_size, config.hidden_size)
69 | self.bn_in = nn.BatchNorm1d(config.hidden_size)
70 | self.dropout_in = nn.Dropout(config.hidden_dropout_prob)
71 |
72 | self.fc_int = nn.Linear(config.hidden_size, config.hidden_size)
73 |
74 | self.fc_out = nn.Linear(config.hidden_size, config.hidden_size)
75 | self.bn_out = nn.BatchNorm1d(config.hidden_size)
76 | self.dropout_out = nn.Dropout(config.hidden_dropout_prob)
77 |
78 | def forward(self, input_graph):
79 | attention_output = self.mha(input_graph) # multi-head attention
80 | attention_output = self.fc_in(attention_output)
81 | attention_output = self.dropout_in(attention_output)
82 | attention_output = self.bn_in((attention_output + input_graph).permute(0, 2, 1)).permute(0, 2, 1)
83 | intermediate_output = self.fc_int(attention_output)
84 | intermediate_output = F.relu(intermediate_output)
85 | intermediate_output = self.fc_out(intermediate_output)
86 | intermediate_output = self.dropout_out(intermediate_output)
87 | graph_output = self.bn_out((intermediate_output + attention_output).permute(0, 2, 1)).permute(0, 2, 1)
88 | return graph_output
--------------------------------------------------------------------------------
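A minimal usage sketch for `GAT.py` above (not part of the repository): `GATLayer` only needs a config object exposing the four fields read in its `__init__`, and it maps a batch of region-node embeddings to an output of the same shape. The config values and node count below are illustrative assumptions, not the settings used for training.

```python
# Illustrative only: drive GAT.GATLayer with a hypothetical config object.
from types import SimpleNamespace

import torch

from GAT import GATLayer

config = SimpleNamespace(
    hidden_size=768,                    # node embedding size
    num_attention_heads=8,              # must divide hidden_size evenly
    attention_probs_dropout_prob=0.1,
    hidden_dropout_prob=0.1,
)

layer = GATLayer(config).eval()         # eval() so BatchNorm1d uses running statistics
nodes = torch.randn(2, 100, 768)        # (batch, num_region_nodes, hidden_size)
with torch.no_grad():
    out = layer(nodes)                  # graph-attended nodes, same shape as the input

print(out.shape)                        # torch.Size([2, 100, 768])
```
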
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | This is the official source code for the **Dual Semantic Relations Attention Network (DSRAN)** proposed in our journal paper [Learning Dual Semantic Relations with Graph Attention for Image-Text Matching (TCSVT 2020)](https://arxiv.org/abs/2010.11550). It is built on top of [VSE++](https://github.com/fartashf/vsepp) in PyTorch.
3 |
4 |
5 | **The framework of DSRAN:**
6 |
7 | ![The framework of DSRAN](figures/model.jpg)
8 |
9 | **The results on the MSCOCO and Flickr30K datasets (with GRU or BERT as the text encoder):**
10 |
11 | **GRU:**
12 |
13 | | Dataset | Image-to-Text R@1 | R@5 | R@10 | Text-to-Image R@1 | R@5 | R@10 | Rsum |
14 | | --- | --- | --- | --- | --- | --- | --- | --- |
15 | | MSCOCO-1K | 80.4 | 96.7 | 98.7 | 64.2 | 90.4 | 95.8 | 526.2 |
16 | | MSCOCO-5K | 57.6 | 85.6 | 91.9 | 41.5 | 71.9 | 82.1 | 430.6 |
17 | | Flickr30k | 79.6 | 95.6 | 97.5 | 58.6 | 85.8 | 91.3 | 508.4 |
18 |
19 | **BERT:**
20 |
21 | | Dataset | Image-to-Text R@1 | R@5 | R@10 | Text-to-Image R@1 | R@5 | R@10 | Rsum |
22 | | --- | --- | --- | --- | --- | --- | --- | --- |
23 | | MSCOCO-1K | 80.6 | 96.7 | 98.7 | 64.5 | 90.8 | 95.8 | 527.1 |
24 | | MSCOCO-5K | 57.9 | 85.3 | 92.0 | 41.7 | 72.7 | 82.8 | 432.4 |
25 | | Flickr30k | 80.5 | 95.5 | 97.9 | 59.2 | 86.0 | 91.9 | 511.0 |
26 |
108 | ## Requirements and Installation
109 | We recommend the following dependencies.
110 | * Python 3.6
111 | * PyTorch 1.1.0
112 | * NumPy (>1.12.1)
113 | * torchtext
114 | * pycocotools
115 | * nltk
116 |
117 | ## Download data
118 |
119 | Download the raw images, pre-computed image features, pre-trained BERT models, the pre-trained ResNet152 model, and the pre-trained DSRAN models. The raw images can be downloaded from [VSE++](https://github.com/fartashf/vsepp).
120 |
121 | ```
122 | wget http://www.cs.toronto.edu/~faghri/vsepp/data.tar
123 | wget http://www.cs.toronto.edu/~faghri/vsepp/vocab.tar
124 | ```
125 | We refer to the directory where `data.tar` is extracted as `$DATA_PATH`; only the raw image folders `coco` and `f30k` are used.
126 |
127 | The pre-computed image features can be obtained from [VLP](https://github.com/LuoweiZhou/VLP). Extract these zip files into the folder `data/joint-pretrain`. We refer to the path of the extracted region bbox file (`.h5`) as `$REGION_BBOX_FILE`, and to the regional feature directories (`feat_cls_1000/` for COCO and `trainval/` for Flickr30K) as `$FEATURE_PATH`.
128 |
129 | The pre-trained ResNet152 model can be downloaded from [torchvision](https://download.pytorch.org/models/resnet152-b121ed2d.pth) and should be placed in the root directory.
130 | ```
131 | wget https://download.pytorch.org/models/resnet152-b121ed2d.pth
132 | ```
133 | For our trained DSRAN models, you can download `runs.zip` from [Google Drive](https://drive.google.com/drive/folders/1SQiRpO3L8d9QxFSRdk31PZrxRUi3eXyW?usp=sharing), or `GRU.zip` together with `BERT.zip` from [BaiduNetDisk](https://pan.baidu.com/s/1H_iMH-QZETAdHLk03dBREA) (extract code: 1119). There are 8 models in total (4 for each dataset).
134 |
135 | The pre-trained BERT models are obtained from an old version of [transformers](https://github.com/huggingface/transformers). Note that current [transformers](https://github.com/huggingface/transformers) releases offer a simpler way of using BERT; we will update the code in the future. The pre-trained models we use can be downloaded from the same [Google Drive](https://drive.google.com/drive/folders/1SQiRpO3L8d9QxFSRdk31PZrxRUi3eXyW?usp=sharing) and [BaiduNetDisk](https://pan.baidu.com/s/1H_iMH-QZETAdHLk03dBREA) (extract code: 1119) links. We refer to the path of the files extracted from `uncased_L-12_H-768_A-12.zip` as `$BERT_PATH`.
136 |
137 |
138 | ### Data Structure
139 | ```
140 | ├── data/
141 | | ├── coco/ /* MSCOCO raw images
142 | | | ├── images/
143 | | | | ├── train2014/
144 | | | | ├── val2014/
145 | | | ├── annotations/
146 | | ├── f30k/ /* Flickr30K raw images
147 | | | ├── images/
148 | | | ├── dataset_flickr30k.json
149 | | ├── joint-pretrain/ /* pre-computed image features
150 | | | ├── COCO/
151 | | | | ├── region_feat_gvd_wo_bgd/
152 | | | | | ├── feat_cls_1000/ /* $FEATURE_PATH
153 | | | | | ├── coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 /* $REGION_BBOX_FILE
154 | | | | ├── annotations/
155 | | | ├── flickr30k/
156 | | | | ├── region_feat_gvd_wo_bgd/
157 | | | | | ├── trainval/ /* $FEATURE_PATH
158 | | | | | ├── flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 /* $REGION_BBOX_FILE
159 | | | | ├── annotations/
160 | ```
161 |
162 | ## Evaluate trained models
163 |
164 | ### Test on single model:
165 |
166 | + Test on MSCOCO dataset (1K and 5K simultaneously):
167 |
168 | + Test on BERT-based models:
169 |
170 | ```bash
171 | python evaluation_bert.py --model BERT/cc_model1 --fold --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
172 | ```
173 |
174 | + Test on GRU-based models:
175 |
176 | ```bash
177 | python evaluation.py --model GRU/cc_model1 --fold --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
178 | ```
179 |
180 | + Test on Flickr30K dataset:
181 |
182 | + Test on BERT-based models:
183 |
184 | ```bash
185 | python evaluation_bert.py --model BERT/f_model1 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
186 | ```
187 |
188 | + Test on GRU-based models:
189 |
190 | ```bash
191 | python evaluation.py --model GRU/f_model1 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
192 | ```
193 |
194 | ### Test on two-models ensemble and re-rank:
195 |
196 | Remember to modify `$DATA_PATH`, `$REGION_BBOX_FILE` and `$FEATURE_PATH` in the `.sh` files.
197 |
198 | + Test on MSCOCO dataset (1K and 5K simultaneously):
199 |
200 | + Test on BERT-based models:
201 |
202 | ```bash
203 | sh test_bert_cc.sh
204 | ```
205 |
206 | + Test on GRU-based models:
207 |
208 | ```bash
209 | sh test_gru_cc.sh
210 | ```
211 |
212 | + Test on Flickr30K dataset:
213 |
214 | + Test on BERT-based models:
215 |
216 | ```bash
217 | sh test_bert_f.sh
218 | ```
219 |
220 | + Test on GRU-based models:
221 |
222 | ```bash
223 | sh test_gru_f.sh
224 | ```
225 |
226 | ## Train new models
227 |
228 | Train a model with BERT on MSCOCO:
229 |
230 | ```bash
231 | python train_bert.py --data_path "$DATA_PATH" --data_name coco --num_epochs 18 --batch_size 320 --lr_update 9 --logger_name runs/cc_bert --bert_path "$BERT_PATH" --ft_bert --warmup 0.1 --K 4 --feature_path "$FEATURE_PATH" --region_bbox_file "$REGION_BBOX_FILE"
232 | ```
233 |
234 | Train a model with BERT on Flickr30K:
235 |
236 | ```bash
237 | python train_bert.py --data_path "$DATA_PATH" --data_name f30k --num_epochs 12 --batch_size 128 --lr_update 6 --logger_name runs/f_bert --bert_path "$BERT_PATH" --ft_bert --warmup 0.1 --K 2 --feature_path "$FEATURE_PATH" --region_bbox_file "$REGION_BBOX_FILE"
238 | ```
239 |
240 | Train a model with GRU on MSCOCO:
241 |
242 | ```bash
243 | python train.py --data_path "$DATA_PATH" --data_name coco --num_epochs 18 --batch_size 300 --lr_update 9 --logger_name runs/cc_gru --use_restval --K 2 --feature_path "$FEATURE_PATH" --region_bbox_file "$REGION_BBOX_FILE"
244 | ```
245 |
246 | Train a model with GRU on Flickr30K:
247 |
248 | ```bash
249 | python train.py --data_path "$DATA_PATH" --data_name f30k --num_epochs 16 --batch_size 128 --lr_update 8 --logger_name runs/f_gru --use_restval --K 2 --feature_path "$FEATURE_PATH" --region_bbox_file "$REGION_BBOX_FILE"
250 | ```
251 |
252 | ## Acknowledgement
253 | We thank [Linyang Li](https://github.com/LinyangLee) for his help with the code and for providing some of the computing resources.
254 | ## Reference
255 |
256 | If DSRAN is useful for your research, please cite our paper:
257 |
258 | ```
259 | @ARTICLE{9222079,
260 | author={Wen, Keyu and Gu, Xiaodong and Cheng, Qingrong},
261 | journal={IEEE Transactions on Circuits and Systems for Video Technology},
262 | title={Learning Dual Semantic Relations With Graph Attention for Image-Text Matching},
263 | year={2021},
264 | volume={31},
265 | number={7},
266 | pages={2866-2879},
267 | doi={10.1109/TCSVT.2020.3030656}}
268 | ```
269 |
270 | ## License
271 |
272 | [Apache License 2.0](http://www.apache.org/licenses/LICENSE-2.0)
273 |
--------------------------------------------------------------------------------
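The evaluation commands in the README above are thin wrappers around `evalrank` (defined in `evaluation.py` / `evaluation_bert.py` later in this dump). A rough programmatic equivalent is sketched below; the checkpoint filename and split name are hypothetical placeholders, and the two data paths simply follow the Data Structure section, so adapt them to your setup.

```python
# Hypothetical sketch: evaluate a trained GRU-based model on Flickr30K from Python.
# The checkpoint path and split are assumptions, not values taken from the repository.
from evaluation import evalrank

evalrank(
    'runs/GRU/f_model1/model_best.pth.tar',   # hypothetical checkpoint file name
    data_path='data',
    split='test',
    fold5=False,
    region_bbox_file='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/'
                     'flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5',
    feature_path='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/',
)
```
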
/coco_sims/coco_sims.txt:
--------------------------------------------------------------------------------
1 | Path to save the similarity matrices produced during the inference stage on MSCOCO.
2 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on
3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives"
4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
7 | # Written by Keyu Wen & Linyang Li, 2020
8 | # ------------------------------------------------------------
9 |
10 | import torch
11 | import torch.utils.data as data
12 | import torchvision.transforms as transforms
13 | import os
14 | import nltk
15 | from PIL import Image
16 | from pycocotools.coco import COCO
17 | import numpy as np
18 | import json as jsonmod
19 | import time
20 | import copy
21 | import h5py
22 | import torch.nn.functional as F
23 |
24 |
25 | def get_paths(path, name='coco', use_restval=False):
26 |
27 | roots = {}
28 | ids = {}
29 | if 'coco' == name:
30 | imgdir = os.path.join(path, 'images')
31 | capdir = os.path.join(path, 'annotations')
32 | roots['train'] = {
33 | 'img': os.path.join(imgdir, 'train2014'),
34 | 'cap': os.path.join(capdir, 'captions_train2014.json')
35 | }
36 | roots['val'] = {
37 | 'img': os.path.join(imgdir, 'val2014'),
38 | 'cap': os.path.join(capdir, 'captions_val2014.json')
39 | }
40 | roots['test'] = {
41 | 'img': os.path.join(imgdir, 'val2014'),
42 | 'cap': os.path.join(capdir, 'captions_val2014.json')
43 | }
44 | roots['trainrestval'] = {
45 | 'img': (roots['train']['img'], roots['val']['img']),
46 | 'cap': (roots['train']['cap'], roots['val']['cap'])
47 | }
48 | ids['train'] = np.load(os.path.join(capdir, 'coco_train_ids.npy'))
49 | ids['val'] = np.load(os.path.join(capdir, 'coco_dev_ids.npy'))[:5000]
50 | ids['test'] = np.load(os.path.join(capdir, 'coco_test_ids.npy'))
51 | ids['trainrestval'] = (
52 | ids['train'],
53 | np.load(os.path.join(capdir, 'coco_restval_ids.npy')))
54 | if use_restval:
55 | roots['train'] = roots['trainrestval']
56 | ids['train'] = ids['trainrestval']
57 | elif 'f8k' == name:
58 | imgdir = os.path.join(path, 'images')
59 | cap = os.path.join(path, 'dataset_flickr8k.json')
60 | roots['train'] = {'img': imgdir, 'cap': cap}
61 | roots['val'] = {'img': imgdir, 'cap': cap}
62 | roots['test'] = {'img': imgdir, 'cap': cap}
63 | ids = {'train': None, 'val': None, 'test': None}
64 | elif 'f30k' == name:
65 | imgdir = os.path.join(path, '')
66 | cap = os.path.join(path, 'dataset_flickr30k.json')
67 | roots['train'] = {'img': imgdir, 'cap': cap}
68 | roots['val'] = {'img': imgdir, 'cap': cap}
69 | roots['test'] = {'img': imgdir, 'cap': cap}
70 | ids = {'train': None, 'val': None, 'test': None}
71 |
72 | return roots, ids
73 |
74 |
75 | class CocoDataset(data.Dataset):
76 | """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
77 |
78 | def __init__(self, root, json, vocab, region_bbox_file, region_det_file_prefix, transform=None, ids=None):
79 | """
80 | Args:
81 | root: image directory.
82 | json: coco annotation file path.
83 | vocab: vocabulary wrapper.
84 | transform: transformer for image.
85 | """
86 | self.root = root
87 | # when using `restval`, two json files are needed
88 | if isinstance(json, tuple):
89 | self.coco = (COCO(json[0]), COCO(json[1]))
90 | else:
91 | self.coco = (COCO(json),)
92 | self.root = (root,)
93 | # if ids provided by get_paths, use split-specific ids
94 | if ids is None:
95 | self.ids = list(self.coco.anns.keys())
96 | else:
97 | self.ids = ids
98 |
99 | # if `restval` data is to be used, record the break point for ids
100 | if isinstance(self.ids, tuple):
101 | self.bp = len(self.ids[0])
102 | self.ids = list(self.ids[0]) + list(self.ids[1])
103 | else:
104 | self.bp = len(self.ids)
105 | self.vocab = vocab
106 | self.transform = transform
107 | self.region_bbox_file = region_bbox_file#'/remote-home/lyli/Workspace/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5'
108 | self.region_det_file_prefix = region_det_file_prefix#'/remote-home/lyli/Workspace/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval'
109 |
110 | def __getitem__(self, index):
111 | """This function returns a tuple that is further passed to collate_fn
112 | """
113 | vocab = self.vocab
114 | root, caption, img_id, path, image, img_rcnn, img_pe = self.get_raw_item(index)
115 |
116 | if self.transform is not None:
117 | image = self.transform(image)
118 |
119 | # Convert caption (string) to word ids.
120 | tokens = nltk.tokenize.word_tokenize(
121 | str(caption).lower().encode('utf-8').decode('utf-8'))
122 | caption = []
123 |         caption.append(vocab('<start>'))
124 |         caption.extend([vocab(token) for token in tokens])
125 |         caption.append(vocab('<end>'))
126 | target = torch.Tensor(caption)
127 |
128 | return image, target, img_rcnn, img_pe, index, img_id
129 |
130 | def get_raw_item(self, index):
131 | if index < self.bp:
132 | coco = self.coco[0]
133 | root = self.root[0]
134 | else:
135 | coco = self.coco[1]
136 | root = self.root[1]
137 | ann_id = self.ids[index]
138 | caption = coco.anns[ann_id]['caption']
139 | img_id = coco.anns[ann_id]['image_id']
140 | path = coco.loadImgs(img_id)[0]['file_name']
141 | image = Image.open(os.path.join(root, path)).convert('RGB')
142 | img_rcnn, img_pe = self.get_rcnn(path)
143 |
144 | return root, caption, img_id, path, image, img_rcnn, img_pe
145 |
146 | def get_rcnn(self, path):
147 | img_id = path.split('/')[-1].split('.')[0]
148 | with h5py.File(self.region_det_file_prefix + '_feat' + img_id[-3:] + '.h5', 'r') as region_feat_f:
149 | img = torch.from_numpy(region_feat_f[img_id][:]).float()
150 |
151 | vis_pe = torch.randn(100,1601 + 6) # no position information
152 | return img, vis_pe
153 |
154 | def __len__(self):
155 | return len(self.ids)
156 |
157 |
158 | class FlickrDataset(data.Dataset):
159 | """
160 | Dataset loader for Flickr30k and Flickr8k full datasets.
161 | """
162 |
163 | def __init__(self, root, json, split, vocab, region_bbox_file, feature_path, transform=None):
164 | self.root = root
165 | self.vocab = vocab
166 | self.split = split
167 | self.transform = transform
168 | self.dataset = jsonmod.load(open(json, 'r'))['images']
169 | self.ids = []
170 | for i, d in enumerate(self.dataset):
171 | if d['split'] == split:
172 | self.ids += [(i, x) for x in range(len(d['sentences']))]
173 | self.region_bbox_file = region_bbox_file#'/home/wenkeyu/wky/projects/pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5'
174 | self.feature_path = feature_path#'/home/wenkeyu/wky/projects/pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/'
175 |
176 | def __getitem__(self, index):
177 | """This function returns a tuple that is further passed to collate_fn
178 | """
179 | vocab = self.vocab
180 | root = self.root + '/images'
181 | ann_id = self.ids[index]
182 | img_id = ann_id[0]
183 | caption = self.dataset[img_id]['sentences'][ann_id[1]]['raw']
184 | path = self.dataset[img_id]['filename']
185 |
186 | image = Image.open(os.path.join(root, path)).convert('RGB')
187 | if self.transform is not None:
188 | image = self.transform(image)
189 |
190 | path_orig = copy.deepcopy(path)
191 | # print(path)
192 | path = path.replace('.jpg', '.npy')
193 | feature_path = self.feature_path
194 |
195 | image_rcnn, img_pos = self.get_rcnn(os.path.join(feature_path, path)) # return img-feature 100 2048 & pos-feature
196 |
197 | # Convert caption (string) to word ids.
198 | tokens = nltk.tokenize.word_tokenize(
199 | str(caption).lower().encode('utf-8').decode('utf-8'))
200 | caption = []
201 |         caption.append(vocab('<start>'))
202 |         caption.extend([vocab(token) for token in tokens])
203 |         caption.append(vocab('<end>'))
204 | target = torch.Tensor(caption)
205 | return image, target, image_rcnn, img_pos, index, img_id
206 |
207 | def get_rcnn(self, img_path):
208 | if os.path.exists(img_path) and os.path.exists(img_path.replace('.npy', '_cls_prob.npy')):
209 | # time1 = time.time()
210 | img = torch.from_numpy(np.load(img_path))
211 | vis_pe = torch.randn(100,1601 + 6) # no position information
212 | else:
213 | img = torch.randn(100, 2048)
214 | vis_pe = torch.randn(100, 1601 + 6)
215 | return img, vis_pe
216 |
217 |
218 | def __len__(self):
219 | return len(self.ids)
220 |
221 |
222 | def collate_fn(data):
223 | """Build mini-batch tensors from a list of (image, caption) tuples.
224 | Args:
225 | data: list of (image, caption) tuple.
226 | - image: torch tensor of shape (3, 256, 256).
227 | - caption: torch tensor of shape (?); variable length.
228 |
229 | Returns:
230 | images: torch tensor of shape (batch_size, 3, 256, 256).
231 | targets: torch tensor of shape (batch_size, padded_length).
232 | lengths: list; valid length for each padded caption.
233 | """
234 | # Sort a data list by caption length
235 | data.sort(key=lambda x: len(x[1]), reverse=True)
236 | images, captions, image_rcnn, img_pos, ids, img_ids = zip(*data)
237 |
238 | # Merge images (convert tuple of 3D tensor to 4D tensor)
239 | images = torch.stack(images, 0)
240 | image_rcnn = torch.stack(image_rcnn, 0)
241 | img_pos = torch.stack(img_pos, 0)
242 |     # Merge captions (convert tuple of 1D tensor to 2D tensor)
243 | lengths = [len(cap) for cap in captions]
244 | targets = torch.zeros(len(captions), max(lengths)).long()
245 | for i, cap in enumerate(captions):
246 | end = lengths[i]
247 | targets[i, :end] = cap[:end]
248 |
249 | return images, targets, image_rcnn, img_pos, lengths, ids
250 |
251 |
252 | def get_loader_single(data_name, split, root, json, vocab, transform, batch_size=100, shuffle=True,
253 | num_workers=2, ids=None, collate_fn=collate_fn, region_bbox_file=None, feature_path=None):
254 | """Returns torch.utils.data.DataLoader for custom coco dataset."""
255 | if 'coco' in data_name:
256 | # COCO custom dataset
257 | dataset = CocoDataset(root=root,
258 | json=json,
259 | vocab=vocab,
260 | region_bbox_file=region_bbox_file,
261 | region_det_file_prefix=feature_path,
262 | transform=transform, ids=ids)
263 | elif 'f8k' in data_name or 'f30k' in data_name:
264 | dataset = FlickrDataset(root=root,
265 | split=split,
266 | json=json,
267 | vocab=vocab,
268 | region_bbox_file=region_bbox_file,
269 | feature_path=feature_path,
270 | transform=transform)
271 |
272 | # Data loader
273 | data_loader = torch.utils.data.DataLoader(dataset=dataset,
274 | batch_size=batch_size,
275 | shuffle=shuffle,
276 | pin_memory=True,
277 | num_workers=num_workers,
278 | collate_fn=collate_fn)
279 | return data_loader
280 |
281 |
282 | def get_transform(data_name, split_name, opt):
283 | normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406],
284 | std=[0.229, 0.224, 0.225])
285 | t_list = []
286 | if split_name == 'train':
287 | t_list = [transforms.RandomResizedCrop(opt.crop_size),
288 | transforms.RandomHorizontalFlip()]
289 | elif split_name == 'val':
290 | t_list = [transforms.Resize(256), transforms.CenterCrop(224)]
291 | elif split_name == 'test':
292 | t_list = [transforms.Resize(256), transforms.CenterCrop(224)]
293 |
294 | t_end = [transforms.ToTensor(), normalizer]
295 | transform = transforms.Compose(t_list + t_end)
296 | return transform
297 |
298 |
299 | def get_loaders(data_name, vocab, crop_size, batch_size, workers, opt):
300 | dpath = os.path.join(opt.data_path, data_name)
301 |
302 | roots, ids = get_paths(dpath, data_name, opt.use_restval)
303 |
304 | transform = get_transform(data_name, 'train', opt)
305 | train_loader = get_loader_single(opt.data_name, 'train',
306 | roots['train']['img'],
307 | roots['train']['cap'],
308 | vocab, transform, ids=ids['train'],
309 | batch_size=batch_size, shuffle=True,
310 | num_workers=workers,
311 | collate_fn=collate_fn, region_bbox_file=opt.region_bbox_file,
312 | feature_path=opt.feature_path)
313 |
314 | transform = get_transform(data_name, 'val', opt)
315 | val_loader = get_loader_single(opt.data_name, 'val',
316 | roots['val']['img'],
317 | roots['val']['cap'],
318 | vocab, transform, ids=ids['val'],
319 | batch_size=batch_size, shuffle=False,
320 | num_workers=workers,
321 | collate_fn=collate_fn, region_bbox_file=opt.region_bbox_file,
322 | feature_path=opt.feature_path)
323 |
324 | return train_loader, val_loader
325 |
326 |
327 | def get_test_loader(split_name, data_name, vocab, crop_size, batch_size,
328 | workers, opt):
329 | dpath = os.path.join(opt.data_path, data_name)
330 |
331 | roots, ids = get_paths(dpath, data_name, opt.use_restval)
332 |
333 | transform = get_transform(data_name, split_name, opt)
334 | test_loader = get_loader_single(opt.data_name, split_name,
335 | roots[split_name]['img'],
336 | roots[split_name]['cap'],
337 | vocab, transform, ids=ids[split_name],
338 | batch_size=batch_size, shuffle=False,
339 | num_workers=workers,
340 | collate_fn=collate_fn, region_bbox_file=opt.region_bbox_file,
341 | feature_path=opt.feature_path)
342 |
343 | return test_loader
344 |
--------------------------------------------------------------------------------
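A small sketch (not part of the repository, assuming the dependencies listed in the README such as nltk, pycocotools and h5py are installed) of what `collate_fn` in `data.py` above does with a variable-length caption batch: it sorts the items by caption length, stacks the image and region tensors, and zero-pads the word-id targets. The tensors are random stand-ins with the shapes the loaders produce.

```python
# Illustrative only: feed dummy items through data.collate_fn to show the padding.
import torch

from data import collate_fn

def fake_item(cap_len, idx):
    image = torch.randn(3, 224, 224)            # transformed raw image
    target = torch.randint(1, 100, (cap_len,))  # caption word ids, variable length
    img_rcnn = torch.randn(100, 2048)           # 100 pre-computed region features
    img_pe = torch.randn(100, 1601 + 6)         # class probabilities + box geometry
    return image, target, img_rcnn, img_pe, idx, idx  # (..., index, img_id)

batch = [fake_item(7, 0), fake_item(12, 1), fake_item(5, 2)]
images, targets, image_rcnn, img_pos, lengths, ids = collate_fn(batch)

print(targets.shape)   # torch.Size([3, 12]) -- zero-padded to the longest caption
print(lengths)         # [12, 7, 5]          -- batch re-sorted by caption length
```
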
/data/coco/annotations.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/data/coco/annotations.zip
--------------------------------------------------------------------------------
/data/coco/download.sh:
--------------------------------------------------------------------------------
1 | wget http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip -P data/
2 | wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip -P data/
3 | wget http://msvocds.blob.core.windows.net/coco2014/val2014.zip -P data/
4 |
5 | unzip data/captions_train-val2014.zip -d ./
6 | unzip data/train2014.zip -d images/
7 | rm data/train2014.zip
8 | unzip data/val2014.zip -d images/
9 | rm data/val2014.zip
10 |
--------------------------------------------------------------------------------
/data/coco/images:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/f30k/dataset_flickr30k.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/data/f30k/dataset_flickr30k.zip
--------------------------------------------------------------------------------
/data/f30k/images:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data_bert.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on
3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives"
4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
7 | # Written by Keyu Wen & Linyang Li, 2020
8 | # ------------------------------------------------------------
9 |
10 | import torch
11 | import torch.utils.data as data
12 | import torchvision.transforms as transforms
13 | import os
14 | from PIL import Image
15 | from pycocotools.coco import COCO
16 | import numpy as np
17 | import json as jsonmod
18 | from collections import OrderedDict
19 | import copy
20 | from pytorch_pretrained_bert.tokenization import BertTokenizer
21 | import torch.nn.functional as F
22 | import h5py
23 |
24 |
25 | def get_paths(path, name='coco'):
26 | roots = {}
27 | ids = {}
28 | if 'coco' == name:
29 | imgdir = os.path.join(path, 'images')
30 | capdir = os.path.join(path, 'annotations')
31 | roots['train'] = {
32 | 'img': os.path.join(imgdir, 'train2014'),
33 | 'cap': os.path.join(capdir, 'captions_train2014.json')
34 | }
35 | roots['val'] = {
36 | 'img': os.path.join(imgdir, 'val2014'),
37 | 'cap': os.path.join(capdir, 'captions_val2014.json')
38 | }
39 | roots['test'] = {
40 | 'img': os.path.join(imgdir, 'val2014'),
41 | 'cap': os.path.join(capdir, 'captions_val2014.json')
42 | }
43 | roots['trainrestval'] = {
44 | 'img': (roots['train']['img'], roots['val']['img']),
45 | 'cap': (roots['train']['cap'], roots['val']['cap'])
46 | }
47 | ids['train'] = np.load(os.path.join(capdir, 'coco_train_ids.npy'))
48 | ids['val'] = np.load(os.path.join(capdir, 'coco_dev_ids.npy'))[:5000]
49 | ids['test'] = np.load(os.path.join(capdir, 'coco_test_ids.npy'))
50 | ids['trainrestval'] = (
51 | ids['train'],
52 | np.load(os.path.join(capdir, 'coco_restval_ids.npy')))
53 |
54 | roots['train'] = roots['trainrestval']
55 | ids['train'] = ids['trainrestval']
56 | elif 'f30k' == name:
57 | imgdir = os.path.join(path, 'images')
58 | cap = os.path.join(path, 'dataset_flickr30k.json')
59 | roots['train'] = {'img': imgdir, 'cap': cap}
60 | roots['val'] = {'img': imgdir, 'cap': cap}
61 | roots['test'] = {'img': imgdir, 'cap': cap}
62 | ids = {'train': None, 'val': None, 'test': None}
63 |
64 | return roots, ids
65 |
66 |
67 | class CocoDataset(data.Dataset):
68 |
69 | def __init__(self, root, json, tokenizer, feature_path=None, region_bbox_file=None, max_seq_len=32, transform=None, ids=None):
70 | self.root = root
71 | if isinstance(json, tuple):
72 | self.coco = (COCO(json[0]), COCO(json[1]))
73 | else:
74 | self.coco = (COCO(json),)
75 | self.root = (root,)
76 | if ids is None:
77 | self.ids = list(self.coco.anns.keys())
78 | else:
79 | self.ids = ids
80 | if isinstance(self.ids, tuple):
81 | self.bp = len(self.ids[0])
82 | self.ids = list(self.ids[0]) + list(self.ids[1])
83 | else:
84 | self.bp = len(self.ids)
85 | self.transform = transform
86 | self.tokenizer = tokenizer
87 | self.max_seq_len = max_seq_len
88 | self.region_bbox_file = region_bbox_file
89 | self.region_det_file_prefix = feature_path
90 |
91 | def __getitem__(self, index):
92 | root, caption, img_id, path, image, img_rcnn, img_pe = self.get_raw_item(index)
93 |
94 | if self.transform is not None:
95 | image = self.transform(image)
96 |
97 | target = self.get_text_input(caption)
98 | return img_rcnn, img_pe, target, index, img_id, image
99 |
100 | def get_raw_item(self, index):
101 | if index < self.bp:
102 | coco = self.coco[0]
103 | root = self.root[0]
104 | else:
105 | coco = self.coco[1]
106 | root = self.root[1]
107 | ann_id = self.ids[index]
108 | caption = coco.anns[ann_id]['caption']
109 | img_id = coco.anns[ann_id]['image_id']
110 | path = coco.loadImgs(img_id)[0]['file_name']
111 | image = Image.open(os.path.join(root, path)).convert('RGB')
112 | img_rcnn, img_pe = self.get_rcnn(path)
113 | return root, caption, img_id, path, image, img_rcnn, img_pe
114 |
115 | def __len__(self):
116 | return len(self.ids)
117 |
118 | def get_text_input(self, caption):
119 | caption_tokens = self.tokenizer.tokenize(caption)
120 | caption_tokens = ['[CLS]'] + caption_tokens + ['[SEP]']
121 | caption_ids = self.tokenizer.convert_tokens_to_ids(caption_tokens)
122 | if len(caption_ids) >= self.max_seq_len:
123 | caption_ids = caption_ids[:self.max_seq_len]
124 | else:
125 | caption_ids = caption_ids + [0] * (self.max_seq_len - len(caption_ids))
126 | caption = torch.tensor(caption_ids)
127 | return caption
128 |
129 | def get_rcnn(self, path):
130 | img_id = path.split('/')[-1].split('.')[0]
131 | with h5py.File(self.region_det_file_prefix + '_feat' + img_id[-3:] + '.h5', 'r') as region_feat_f, \
132 | h5py.File(self.region_det_file_prefix + '_cls' + img_id[-3:] + '.h5', 'r') as region_cls_f, \
133 | h5py.File(self.region_bbox_file, 'r') as region_bbox_f:
134 |
135 | img = torch.from_numpy(region_feat_f[img_id][:]).float()
136 | cls_label = torch.from_numpy(region_cls_f[img_id][:]).float()
137 | vis_pe = torch.from_numpy(region_bbox_f[img_id][:])
138 |
139 | # lazy normalization of the coordinates...
140 |
141 | w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5
142 | h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5
143 | vis_pe[:, [0, 2]] /= w_est
144 | vis_pe[:, [1, 3]] /= h_est
145 | rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] - vis_pe[:, 0])
146 | rel_area.clamp_(0)
147 |
148 |             vis_pe = torch.cat((vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]), -1)  # confidence score
149 | normalized_coord = F.normalize(vis_pe.data[:, :5] - 0.5, dim=-1)
150 | vis_pe = torch.cat((F.layer_norm(vis_pe, [6]), \
151 | F.layer_norm(cls_label, [1601])), dim=-1) # 1601 hard coded...
152 |
153 | return img, vis_pe
154 |
155 |
156 | class FlickrDataset(data.Dataset):
157 |
158 | def __init__(self, root, json, split, tokenizer, feature_path=None, region_bbox_file=None, max_seq_len=32,
159 | transform=None):
160 | self.root = root
161 | self.split = split
162 | self.transform = transform
163 | self.dataset = jsonmod.load(open(json, 'r'))['images']
164 | self.ids = []
165 | self.tokenizer = tokenizer
166 | self.max_seq_len = max_seq_len
167 | for i, d in enumerate(self.dataset):
168 | if d['split'] == split:
169 | self.ids += [(i, x) for x in range(len(d['sentences']))]
170 | self.region_bbox_file = region_bbox_file
171 | self.feature_path = feature_path
172 |
173 | def __getitem__(self, index):
174 | root = self.root
175 | ann_id = self.ids[index]
176 | img_id = ann_id[0]
177 | caption = self.dataset[img_id]['sentences'][ann_id[1]]['raw']
178 | path = self.dataset[img_id]['filename']
179 | path_orig = copy.deepcopy(path)
180 | path = path.replace('.jpg', '.npy')
181 | feature_path = self.feature_path
182 | # orig image
183 | image_orig = Image.open(os.path.join(root, path_orig)).convert('RGB')
184 | if self.transform is not None:
185 | image_orig = self.transform(image_orig)
186 | target = self.get_text_input(caption)
187 | image, img_pos = self.get_rcnn(os.path.join(feature_path, path)) # return img-feature 100 2048 & pos-feature
188 |
189 | return image, img_pos, target, index, img_id, image_orig
190 |
191 | def get_rcnn(self, img_path):
192 | if os.path.exists(img_path) and os.path.exists(img_path.replace('.npy', '_cls_prob.npy')):
193 | img = torch.from_numpy(np.load(img_path))
194 | img_id = img_path.split('/')[-1].split('.')[0]
195 | cls_label = torch.from_numpy(np.load(img_path.replace('.npy', '_cls_prob.npy')))
196 | with h5py.File(self.region_bbox_file, 'r') as region_bbox_f:
197 | vis_pe = torch.from_numpy(region_bbox_f[img_id][:])
198 |
199 | # lazy normalization of the coordinates...
200 |
201 | w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5
202 | h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5
203 | vis_pe[:, [0, 2]] /= w_est
204 | vis_pe[:, [1, 3]] /= h_est
205 | rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] - vis_pe[:, 0])
206 | rel_area.clamp_(0)
207 |
208 |             vis_pe = torch.cat((vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]), -1)  # confidence score
209 | normalized_coord = F.normalize(vis_pe.data[:, :5] - 0.5, dim=-1)
210 | vis_pe = torch.cat((F.layer_norm(vis_pe, [6]), \
211 | F.layer_norm(cls_label, [1601])), dim=-1) # 1601 hard coded...
212 | else:
213 | img = torch.randn(100, 2048)
214 | vis_pe = torch.randn(100, 1601 + 6)
215 | return img, vis_pe
216 |
217 | def get_text_input(self, caption):
218 | caption_tokens = self.tokenizer.tokenize(caption)
219 | caption_tokens = ['[CLS]'] + caption_tokens + ['[SEP]']
220 | caption_ids = self.tokenizer.convert_tokens_to_ids(caption_tokens)
221 | if len(caption_ids) >= self.max_seq_len:
222 | caption_ids = caption_ids[:self.max_seq_len]
223 | else:
224 | caption_ids = caption_ids + [0] * (self.max_seq_len - len(caption_ids))
225 | caption = torch.tensor(caption_ids)
226 | return caption
227 |
228 | def __len__(self):
229 | return len(self.ids)
230 |
231 |
232 | def collate_fn(data):
233 | images, img_pos, captions, ids, img_ids, image_orig = zip(*data)
234 | images = torch.stack(images, 0)
235 | img_pos = torch.stack(img_pos, 0)
236 | captions = torch.stack(captions, 0)
237 | images_orig = torch.stack(image_orig, 0)
238 | return images, images_orig, img_pos, captions, ids
239 |
240 |
241 | def get_tokenizer(bert_path):
242 | tokenizer = BertTokenizer(bert_path + 'vocab.txt')
243 | return tokenizer
244 |
245 |
246 | def get_loader_single(data_name, split, root, json, transform,
247 | batch_size=128, shuffle=True,
248 | num_workers=10, ids=None, collate_fn=collate_fn,
249 | feature_path=None,
250 | region_bbox_file=None,
251 | bert_path=None
252 | ):
253 | if 'coco' in data_name:
254 | dataset = CocoDataset(root=root, json=json,
255 | feature_path=feature_path,
256 | region_bbox_file=region_bbox_file,
257 | tokenizer=get_tokenizer(bert_path),
258 | max_seq_len=32, transform=transform, ids=ids)
259 | elif 'f30k' in data_name:
260 | dataset = FlickrDataset(root=root, split=split, json=json,
261 | feature_path=feature_path,
262 | region_bbox_file=region_bbox_file,
263 | tokenizer=get_tokenizer(bert_path),
264 | max_seq_len=32, transform=transform)
265 |
266 | data_loader = torch.utils.data.DataLoader(dataset=dataset,
267 | batch_size=batch_size,
268 | shuffle=shuffle,
269 | pin_memory=True,
270 | num_workers=num_workers,
271 | collate_fn=collate_fn)
272 | return data_loader
273 |
274 |
275 | def get_transform(data_name, split_name, opt):
276 | normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406],
277 | std=[0.229, 0.224, 0.225])
278 | t_list = []
279 | if split_name == 'train':
280 | t_list = [transforms.RandomResizedCrop(opt.crop_size),
281 | transforms.RandomHorizontalFlip()]
282 | # t_list = [transforms.Resize(256), transforms.CenterCrop(224)]
283 | elif split_name == 'val':
284 | t_list = [transforms.Resize(256), transforms.CenterCrop(224)]
285 | elif split_name == 'test':
286 | t_list = [transforms.Resize(256), transforms.CenterCrop(224)]
287 |
288 | t_end = [transforms.ToTensor(), normalizer]
289 | transform = transforms.Compose(t_list + t_end)
290 | return transform
291 |
292 |
293 | def get_loaders(data_name, batch_size, workers, opt):
294 | dpath = os.path.join(opt.data_path, data_name)
295 | roots, ids = get_paths(dpath, data_name)
296 |
297 | transform = get_transform(data_name, 'train', opt)
298 | train_loader = get_loader_single(opt.data_name, 'train',
299 | roots['train']['img'],
300 | roots['train']['cap'],
301 | transform, ids=ids['train'],
302 | batch_size=batch_size, shuffle=True,
303 | num_workers=workers,
304 | collate_fn=collate_fn,
305 | feature_path=opt.feature_path,
306 | region_bbox_file=opt.region_bbox_file,
307 | bert_path=opt.bert_path
308 | )
309 |
310 | transform = get_transform(data_name, 'val', opt)
311 |
312 | val_loader = get_loader_single(opt.data_name, 'val',
313 | roots['val']['img'],
314 | roots['val']['cap'],
315 | transform, ids=ids['val'],
316 | batch_size=batch_size, shuffle=False,
317 | num_workers=workers,
318 | collate_fn=collate_fn,
319 | feature_path=opt.feature_path,
320 | region_bbox_file=opt.region_bbox_file,
321 | bert_path=opt.bert_path
322 | )
323 |
324 | return train_loader, val_loader
325 |
326 |
327 | def get_test_loader(split_name, data_name, batch_size, workers, opt):
328 | dpath = os.path.join(opt.data_path, data_name)
329 |
330 | roots, ids = get_paths(dpath, data_name)
331 |
332 | transform = get_transform(data_name, split_name, opt)
333 | test_loader = get_loader_single(opt.data_name, split_name,
334 | roots[split_name]['img'],
335 | roots[split_name]['cap'],
336 | transform, ids=ids[split_name],
337 | batch_size=batch_size, shuffle=False,
338 | num_workers=workers,
339 | collate_fn=collate_fn,
340 | feature_path=opt.feature_path,
341 | region_bbox_file=opt.region_bbox_file,
342 | bert_path=opt.bert_path
343 | )
344 |
345 | return test_loader
346 |
347 |
348 | class AverageMeter(object):
349 |
350 | def __init__(self):
351 | self.reset()
352 |
353 | def reset(self):
354 | self.val = 0
355 | self.avg = 0
356 | self.sum = 0
357 | self.count = 0
358 |
359 | def update(self, val, n=0):
360 | self.val = val
361 | self.sum += val * n
362 | self.count += n
363 | self.avg = self.sum / (.0001 + self.count)
364 |
365 | def __str__(self):
366 | if self.count == 0:
367 | return str(self.val)
368 | return '%.4f (%.4f)' % (self.val, self.avg)
369 |
370 |
371 | class LogCollector(object):
372 | def __init__(self):
373 | self.meters = OrderedDict()
374 |
375 | def update(self, k, v, n=0):
376 | if k not in self.meters:
377 | self.meters[k] = AverageMeter()
378 | self.meters[k].update(v, n)
379 |
380 | def __str__(self):
381 | s = ''
382 | for i, (k, v) in enumerate(self.meters.items()):
383 | if i > 0:
384 | s += ' '
385 | s += k + ' ' + str(v)
386 | return s
387 |
388 | def tb_log(self, tb_logger, prefix='', step=None):
389 | for k, v in self.meters.items():
390 | tb_logger.log_value(prefix + k, v.val, step=step)
391 |
--------------------------------------------------------------------------------
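The `get_rcnn` methods in `data_bert.py` above build a 1601 + 6 dimensional class/position feature per region. The standalone sketch below (random inputs, assumed box-column layout) reproduces that "lazy normalization" step so the resulting 1607-d shape is easy to trace.

```python
# Illustrative only: mirrors the coordinate normalization used in data_bert.get_rcnn.
import torch
import torch.nn.functional as F

vis_pe = torch.rand(100, 6) * 640      # per-region boxes: x1, y1, x2, y2, <replaced below>, confidence (assumed layout)
cls_label = torch.rand(100, 1601)      # detector class probabilities for each region

# Scale box coordinates by the estimated image width/height.
w_est = torch.max(vis_pe[:, [0, 2]]) * 1. + 1e-5
h_est = torch.max(vis_pe[:, [1, 3]]) * 1. + 1e-5
vis_pe[:, [0, 2]] /= w_est
vis_pe[:, [1, 3]] /= h_est

# Replace the fifth column with the (clamped) relative box area.
rel_area = (vis_pe[:, 3] - vis_pe[:, 1]) * (vis_pe[:, 2] - vis_pe[:, 0])
rel_area.clamp_(0)
vis_pe = torch.cat((vis_pe[:, :4], rel_area.view(-1, 1), vis_pe[:, 5:]), -1)

# Layer-normalize geometry and class scores separately, then concatenate.
vis_pe = torch.cat((F.layer_norm(vis_pe, [6]), F.layer_norm(cls_label, [1601])), dim=-1)
print(vis_pe.shape)                    # torch.Size([100, 1607]) == 1601 + 6
```
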
/evaluation.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on
3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives"
4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
7 | # Written by Keyu Wen, 2020
8 | # ------------------------------------------------------------
9 |
10 | from __future__ import print_function
11 | import os
12 | import pickle
13 | import numpy
14 | from data import get_test_loader
15 | import time
16 | import numpy as np
17 | from vocab import Vocabulary # NOQA
18 | import torch
19 | from model import VSE
20 | from collections import OrderedDict
21 | import argparse
22 |
23 |
24 | class AverageMeter(object):
25 | """Computes and stores the average and current value"""
26 |
27 | def __init__(self):
28 | self.reset()
29 |
30 | def reset(self):
31 | self.val = 0
32 | self.avg = 0
33 | self.sum = 0
34 | self.count = 0
35 |
36 | def update(self, val, n=0):
37 | self.val = val
38 | self.sum += val * n
39 | self.count += n
40 | self.avg = self.sum / (.0001 + self.count)
41 |
42 | def __str__(self):
43 | """String representation for logging
44 | """
45 | # for values that should be recorded exactly e.g. iteration number
46 | if self.count == 0:
47 | return str(self.val)
48 | # for stats
49 | return '%.4f (%.4f)' % (self.val, self.avg)
50 |
51 |
52 | class LogCollector(object):
53 | """A collection of logging objects that can change from train to val"""
54 |
55 | def __init__(self):
56 | # to keep the order of logged variables deterministic
57 | self.meters = OrderedDict()
58 |
59 | def update(self, k, v, n=0):
60 | # create a new meter if previously not recorded
61 | if k not in self.meters:
62 | self.meters[k] = AverageMeter()
63 | self.meters[k].update(v, n)
64 |
65 | def __str__(self):
66 | """Concatenate the meters in one log line
67 | """
68 | s = ''
69 | for i, (k, v) in enumerate(self.meters.items()):
70 | if i > 0:
71 | s += ' '
72 | s += k + ' ' + str(v)
73 | return s
74 |
75 | def tb_log(self, tb_logger, prefix='', step=None):
76 | """Log using tensorboard
77 | """
78 | for k, v in self.meters.items():
79 | tb_logger.log_value(prefix + k, v.val, step=step)
80 |
81 |
82 | def encode_data(model, data_loader, log_step=10, logging=print):
83 | """Encode all images and captions loadable by `data_loader`
84 | """
85 | batch_time = AverageMeter()
86 | val_logger = LogCollector()
87 |
88 | # switch to evaluate mode
89 | model.val_start()
90 |
91 | end = time.time()
92 |
93 | # numpy array to keep all the embeddings
94 | img_embs = None
95 | cap_embs = None
96 | with torch.no_grad():
97 | for i, (images, captions, img_rcnn, img_pos, lengths, ids) in enumerate(data_loader):
98 | # make sure val logger is used
99 | model.logger = val_logger
100 |
101 | # compute the embeddings
102 | img_emb, cap_emb = model.forward_emb(images, captions, img_rcnn, img_pos, lengths)
103 |
104 | # initialize the numpy arrays given the size of the embeddings
105 | if img_embs is None:
106 | img_embs = torch.zeros(len(data_loader.dataset), img_emb.size(1)).cuda()
107 | cap_embs = torch.zeros(len(data_loader.dataset), cap_emb.size(1)).cuda()
108 |
109 | img_embs[ids] = img_emb
110 | cap_embs[ids] = cap_emb
111 |
112 | # measure accuracy and record loss
113 | model.forward_loss(img_emb, cap_emb)
114 |
115 | # measure elapsed time
116 | batch_time.update(time.time() - end)
117 | end = time.time()
118 |
119 | if i % log_step == 0:
120 | logging('Test: [{0}/{1}]\t'
121 | '{e_log}\t'
122 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
123 | .format(
124 | i, len(data_loader), batch_time=batch_time,
125 | e_log=str(model.logger)))
126 | del images, captions
127 |
128 | return img_embs, cap_embs
129 |
130 |
131 | def evalrank(model_path, data_path=None, split='dev', fold5=False, region_bbox_file=None, feature_path=None):
132 | """
133 | Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
134 | cross-validation is done (only for MSCOCO). Otherwise, the full data is
135 | used for evaluation.
136 | """
137 | # load model and options
138 | checkpoint = torch.load(model_path)
139 | opt = checkpoint['opt']
140 | if data_path is not None:
141 | opt.data_path = data_path
142 | if region_bbox_file is not None:
143 | opt.region_bbox_file = region_bbox_file
144 | if feature_path is not None:
145 | opt.feature_path = feature_path
146 |
147 | # load vocabulary used by the model
148 | with open(os.path.join(opt.vocab_path,
149 | '%s_vocab.pkl' % opt.data_name), 'rb') as f:
150 | vocab = pickle.load(f)
151 | opt.vocab_size = len(vocab)
152 | print(opt)
153 |
154 | # construct model
155 | model = VSE(opt)
156 | # load model state
157 | model.load_state_dict(checkpoint['model'])
158 |
159 | print('Loading dataset')
160 | data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
161 | opt.batch_size, opt.workers, opt)
162 | print('Computing results...')
163 |     img_embs, cap_embs = encode_data(model, data_loader)
164 | time_sim_start = time.time()
165 |
166 | if not fold5:
167 | img_emb_new = img_embs[0:img_embs.size(0):5]
168 | print(img_emb_new.size())
169 |
170 | sims = torch.mm(img_emb_new, cap_embs.t())
171 | sims_T = torch.mm(cap_embs, cap_embs.t())
172 | sims_T = sims_T.cpu().numpy()
173 |
174 | sims = sims.cpu().numpy()
175 | np.save('sims_f.npy',sims)
176 | np.save('sims_f_T.npy',sims_T)
177 |
178 | print('Images: %d, Captions: %d' %
179 | (img_embs.shape[0] / 5, cap_embs.shape[0]))
180 |
181 | r = simrank(sims)
182 |
183 | time_sim_end = time.time()
184 | print('sims_time:%f' % (time_sim_end - time_sim_start))
185 | del sims
186 | else: # fold5-especially for coco
187 | print('5k---------------')
188 | img_emb_new = img_embs[0:img_embs.size(0):5]
189 | print(img_emb_new.size())
190 |
191 | sims = torch.mm(img_emb_new, cap_embs.t())
192 | sims_T = torch.mm(cap_embs, cap_embs.t())
193 |
194 | sims = sims.cpu().numpy()
195 | sims_T = sims_T.cpu().numpy()
196 |
197 | np.save('sims_full_5k.npy',sims)
198 | np.save('sims_full_T_5k.npy',sims_T)
199 | print('Images: %d, Captions: %d' %
200 | (img_embs.shape[0] / 5, cap_embs.shape[0]))
201 |
202 | r = simrank(sims)
203 |
204 | time_sim_end = time.time()
205 | print('sims_time:%f' % (time_sim_end - time_sim_start))
206 | del sims, sims_T
207 | print('1k---------------')
208 | r_ = [0, 0, 0, 0, 0, 0, 0]
209 | for i in range(5):
210 | print(i)
211 | img_emb_new = img_embs[i * 5000 : int(i * 5000 + img_embs.size(0)/5):5]
212 | cap_emb_new = cap_embs[i * 5000 : int(i * 5000 + cap_embs.size(0)/5)]
213 |
214 | sims = torch.mm(img_emb_new, cap_emb_new.t())
215 | sims_T = torch.mm(cap_emb_new, cap_emb_new.t())
216 | sims_T = sims_T.cpu().numpy()
217 | sims = sims.cpu().numpy()
218 | np.save('sims_full_%d.npy'%i,sims)
219 | np.save('sims_full_T_%d'%i,sims_T)
220 |
221 | print('Images: %d, Captions: %d' %
222 | (img_emb_new.size(0), cap_emb_new.size(0)))
223 |
224 | r = simrank(sims)
225 | r_ = np.array(r_) + np.array(r)
226 |
227 | del sims
228 | print('--------------------')
229 | r_ = tuple(r_/5)
230 | print('I2T:%.1f %.1f %.1f' % r_[0:3])
231 | print('T2I:%.1f %.1f %.1f' % r_[3:6])
232 | print('Rsum:%.1f' % r_[-1])
233 |
234 |
235 | def simrank(similarity):
236 | sims = similarity
237 | img_size, cap_size = sims.shape
238 | print("imgs: %d, caps: %d" % (img_size, cap_size))
239 | # i2t
240 | index_list = []
241 | ranks = numpy.zeros(img_size)
242 | top1 = numpy.zeros(img_size)
243 | for index in range(img_size):
244 | d = sims[index]
245 | inds = numpy.argsort(d)[::-1]
246 | # print(inds)
247 | index_list.append(inds[0])
248 | rank = 1e20
249 | for i in range(5 * index, 5 * index + 5, 1):
250 |             tmp = numpy.where(inds == i)[0][0]
251 | # print(tmp)
252 | if tmp < rank:
253 | rank = tmp
254 | ranks[index] = rank
255 | top1[index] = inds[0]
256 |
257 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
258 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
259 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
260 | medr = numpy.floor(numpy.median(ranks)) + 1
261 | meanr = ranks.mean() + 1
262 | print('i2t:r1: %.1f, r5: %.1f, r10: %.1f' % (r1, r5, r10)) # , medr, meanr)
263 | rs = r1 + r5 + r10
264 | # t2i
265 | sims_t2i = sims.T
266 | ranks = numpy.zeros(cap_size)
267 | top1 = numpy.zeros(cap_size)
268 | for index in range(img_size):
269 |
270 | d = sims_t2i[5 * index:5 * index + 5] # 5*1000
271 | inds = numpy.zeros(d.shape)
272 | for i in range(len(inds)):
273 | inds[i] = numpy.argsort(d[i])[::-1]
274 | ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0]
275 | top1[5 * index + i] = inds[i][0]
276 |
277 | r1_ = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
278 | r5_ = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
279 | r10_ = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
280 | medr_ = numpy.floor(numpy.median(ranks)) + 1
281 | meanr_ = ranks.mean() + 1
282 | rs_ = r1_ + r5_ + r10_
283 | print('t2i:r1: %.1f, r5: %.1f, r10: %.1f' % (r1_, r5_, r10_))
284 | rsum = rs + rs_
285 | print('rsum=%.1f' % rsum)
286 | return [r1, r5, r10, r1_, r5_, r10_, rsum]
287 |
288 |
289 | def i2t(images, captions, npts=None, return_ranks=False):
290 | """
291 | Images->Text (Image Annotation)
292 | Images: (5N, K) matrix of images
293 | Captions: (5N, K) matrix of captions
294 | """
295 | images = images.cpu().numpy()
296 | captions = captions.cpu().numpy()
297 | if npts is None:
298 | npts = int(images.shape[0] / 5)
299 | print(npts)
300 | index_list = []
301 |
302 | ranks = numpy.zeros(npts)
303 | top1 = numpy.zeros(npts)
304 | for index in range(npts):
305 |
306 | # Get query image
307 | im = images[5 * index].reshape(1, images.shape[1])
308 |
309 | # Compute scores
310 |
311 | d = numpy.dot(im, captions.T).flatten()
312 | inds = numpy.argsort(d)[::-1]
313 | index_list.append(inds[0])
314 |
315 | # Score
316 | rank = 1e20
317 | for i in range(5 * index, 5 * index + 5, 1):
318 | tmp = numpy.where(inds == i)[0][0]
319 | if tmp < rank:
320 | rank = tmp
321 | ranks[index] = rank
322 | top1[index] = inds[0]
323 |
324 | # Compute metrics
325 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
326 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
327 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
328 | medr = numpy.floor(numpy.median(ranks)) + 1
329 | meanr = ranks.mean() + 1
330 | if return_ranks:
331 | return (r1, r5, r10, medr, meanr), (ranks, top1)
332 | else:
333 | return (r1, r5, r10, medr, meanr)
334 |
335 |
336 | def t2i(images, captions, npts=None, return_ranks=False):
337 | """
338 | Text->Images (Image Search)
339 | Images: (5N, K) matrix of images
340 | Captions: (5N, K) matrix of captions
341 | """
342 | images = images.cpu().numpy()
343 | captions = captions.cpu().numpy()
344 | if npts is None:
345 | npts = int(images.shape[0] / 5)
346 | print(npts)
347 | ims = numpy.array([images[i] for i in range(0, len(images), 5)])
348 |
349 | ranks = numpy.zeros(5 * npts)
350 | top1 = numpy.zeros(5 * npts)
351 | for index in range(npts):
352 |
353 | # Get query captions
354 | queries = captions[5 * index:5 * index + 5]
355 |
356 | # Compute scores
357 | d = numpy.dot(queries, ims.T)
358 | inds = numpy.zeros(d.shape)
359 | for i in range(len(inds)):
360 | inds[i] = numpy.argsort(d[i])[::-1]
361 | ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0]
362 | top1[5 * index + i] = inds[i][0]
363 |
364 | # Compute metrics
365 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
366 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
367 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
368 | medr = numpy.floor(numpy.median(ranks)) + 1
369 | meanr = ranks.mean() + 1
370 | if return_ranks:
371 | return (r1, r5, r10, medr, meanr), (ranks, top1)
372 | else:
373 | return (r1, r5, r10, medr, meanr)
374 |
375 |
376 | def main():
377 | parser = argparse.ArgumentParser()
378 | parser.add_argument('--model', default='single_model', help='model name')
379 | parser.add_argument('--fold', action='store_true', help='fold5')
380 | parser.add_argument('--name', default='model_best', help='checkpoint name')
381 | parser.add_argument('--data_path', default='data', help='data path')
382 | parser.add_argument('--region_bbox_file', default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5', type=str, metavar='PATH',
383 | help='path to region features bbox file')
384 | parser.add_argument('--feature_path', default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/', type=str, metavar='PATH',
385 | help='path to region features')
386 | opt = parser.parse_args()
387 |
388 | evalrank('runs/' + opt.model + '/' + opt.name + ".pth.tar", data_path = opt.data_path, split="test", fold5=opt.fold, region_bbox_file=opt.region_bbox_file, feature_path=opt.feature_path)
389 |
390 | if __name__ == '__main__':
391 | main()
--------------------------------------------------------------------------------
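
Note: a minimal, self-contained sketch (toy sizes and random scores, not repository code) of the ranking protocol that simrank/i2t in evaluation.py implement: each image has 5 ground-truth captions stored contiguously, and Recall@K is the percentage of queries whose best-ranked ground-truth caption lands in the top K.

    import numpy as np

    def recall_at_k(sims, ks=(1, 5, 10)):
        # sims: (n_images, 5 * n_images) image-to-caption similarities;
        # captions 5*i .. 5*i+4 are the ground truth for image i.
        n_images = sims.shape[0]
        ranks = np.zeros(n_images)
        for i in range(n_images):
            order = np.argsort(sims[i])[::-1]          # caption indices, best first
            gt = np.arange(5 * i, 5 * i + 5)           # ground-truth caption indices
            ranks[i] = min(np.where(order == g)[0][0] for g in gt)
        return {k: 100.0 * np.mean(ranks < k) for k in ks}

    # toy usage with random scores for 4 images and their 20 captions
    rng = np.random.default_rng(0)
    print(recall_at_k(rng.standard_normal((4, 20))))
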
/evaluation_bert.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | # -----------------------------------------------------------
3 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on
4 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives"
5 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
6 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
7 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
8 | # Written by Keyu Wen, 2020
9 | # ------------------------------------------------------------
10 |
11 | from __future__ import print_function
12 | import numpy
13 | from data_bert import get_test_loader
14 | import time
15 | import numpy as np
16 | import torch
17 | import argparse
18 | from model_bert import VSE
19 | from collections import OrderedDict
20 |
21 |
22 | class AverageMeter(object):
23 |
24 | def __init__(self):
25 | self.reset()
26 |
27 | def reset(self):
28 | self.val = 0
29 | self.avg = 0
30 | self.sum = 0
31 | self.count = 0
32 |
33 | def update(self, val, n=0):
34 | self.val = val
35 | self.sum += val * n
36 | self.count += n
37 | self.avg = self.sum / (.0001 + self.count)
38 |
39 | def __str__(self):
40 | if self.count == 0:
41 | return str(self.val)
42 | return '%.4f (%.4f)' % (self.val, self.avg)
43 |
44 |
45 | class LogCollector(object):
46 | def __init__(self):
47 | self.meters = OrderedDict()
48 |
49 | def update(self, k, v, n=0):
50 | if k not in self.meters:
51 | self.meters[k] = AverageMeter()
52 | self.meters[k].update(v, n)
53 |
54 | def __str__(self):
55 | s = ''
56 | for i, (k, v) in enumerate(self.meters.items()):
57 | if i > 0:
58 | s += ' '
59 | s += k + ' ' + str(v)
60 | return s
61 |
62 | def tb_log(self, tb_logger, prefix='', step=None):
63 | for k, v in self.meters.items():
64 | tb_logger.log_value(prefix + k, v.val, step=step)
65 |
66 |
67 | def encode_data(model, data_loader, log_step=10, logging=print):
68 | batch_time = AverageMeter()
69 | val_logger = LogCollector()
70 | model.val_start()
71 |
72 | end = time.time()
73 |
74 | img_embs = None
75 | cap_embs = None
76 | time_encode_start = time.time()
77 | # device = torch.device("cuda:0")
78 | with torch.no_grad():
79 | for i, (images, images_orig, img_pos, captions, ids) in enumerate(data_loader):
80 | model.logger = val_logger
81 |
82 | img_emb, cap_emb = model.forward_emb(images_orig, images, img_pos, captions)
83 |
84 | if img_embs is None:
85 | img_embs = torch.zeros(len(data_loader.dataset), img_emb.size(1)).cuda()
86 | cap_embs = torch.zeros(len(data_loader.dataset), cap_emb.size(1)).cuda()
87 |
88 | img_embs[ids] = img_emb
89 | cap_embs[ids] = cap_emb
90 |
91 | model.forward_loss(img_emb, cap_emb)
92 |
93 | batch_time.update(time.time() - end)
94 | end = time.time()
95 |
96 | if i % log_step == 0:
97 | logging('Test: [{0}/{1}]\t'
98 | '{e_log}\t'
99 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
100 | .format(
101 | i, len(data_loader), batch_time=batch_time,
102 | e_log=str(model.logger)))
103 | del images, captions
104 | time_encode_end = time.time()
105 | print('encode_time:%f' % (time_encode_end - time_encode_start))
106 | img_emb_new = img_embs[0:img_embs.size(0):5]
107 | sims = torch.mm(img_emb_new, cap_embs.t())
108 | sims = sims.cpu().numpy()
109 |
110 | return img_embs, cap_embs, sims
111 |
112 |
113 | def evalrank(model_path, data_path=None, split='dev', fold5=False, region_bbox_file=None, feature_path=None):
114 | checkpoint = torch.load(model_path)
115 | opt = checkpoint['opt']
116 |
117 | if data_path is not None:
118 | opt.data_path = data_path
119 |     if region_bbox_file is not None:
120 |         opt.region_bbox_file = region_bbox_file
121 |     if feature_path is not None:
122 |         opt.feature_path = feature_path
123 |
124 | print(opt)
125 | model = VSE(opt)
126 |
127 | model.load_state_dict(checkpoint['model']) #
128 |
129 | print('Loading dataset')
130 | data_loader = get_test_loader(split, opt.data_name, opt.batch_size, opt.workers, opt)
131 |
132 | print('Computing results...')
133 | img_embs, cap_embs, sims = encode_data(model, data_loader)
134 |
135 | time_sim_start = time.time()
136 |
137 | if not fold5:
138 | img_emb_new = img_embs[0:img_embs.size(0):5]
139 | print(img_emb_new.size())
140 | sims = torch.mm(img_emb_new, cap_embs.t())
141 | sims_T = torch.mm(cap_embs, cap_embs.t())
142 | sims_T = sims_T.cpu().numpy()
143 |
144 | sims = sims.cpu().numpy()
145 | np.save('sims_f.npy',sims)
146 | np.save('sims_f_T.npy',sims_T)
147 |
148 | print('Images: %d, Captions: %d' %
149 | (img_embs.shape[0] / 5, cap_embs.shape[0]))
150 |
151 | r = simrank(sims)
152 |
153 | time_sim_end = time.time()
154 | print('sims_time:%f' % (time_sim_end - time_sim_start))
155 | del sims
156 | else: # fold5-especially for coco
157 | print('5k---------------')
158 | img_emb_new = img_embs[0:img_embs.size(0):5]
159 | print(img_emb_new.size())
160 |
161 | sims = torch.mm(img_emb_new, cap_embs.t())
162 | sims_T = torch.mm(cap_embs, cap_embs.t())
163 |
164 | sims = sims.cpu().numpy()
165 | sims_T = sims_T.cpu().numpy()
166 |
167 | np.save('sims_full_5k.npy',sims)
168 | np.save('sims_full_T_5k.npy',sims_T)
169 | print('Images: %d, Captions: %d' %
170 | (img_embs.shape[0] / 5, cap_embs.shape[0]))
171 |
172 | r = simrank(sims)
173 |
174 | time_sim_end = time.time()
175 | print('sims_time:%f' % (time_sim_end - time_sim_start))
176 | del sims, sims_T
177 | print('1k---------------')
178 | r_ = [0, 0, 0, 0, 0, 0, 0]
179 | for i in range(5):
180 | print(i)
181 | img_emb_new = img_embs[i * 5000 : int(i * 5000 + img_embs.size(0)/5):5]
182 | cap_emb_new = cap_embs[i * 5000 : int(i * 5000 + cap_embs.size(0)/5)]
183 |
184 | sims = torch.mm(img_emb_new, cap_emb_new.t())
185 | sims_T = torch.mm(cap_emb_new, cap_emb_new.t())
186 | sims_T = sims_T.cpu().numpy()
187 | sims = sims.cpu().numpy()
188 | np.save('sims_full_%d.npy'%i,sims)
189 | np.save('sims_full_T_%d'%i,sims_T)
190 |
191 | print('Images: %d, Captions: %d' %
192 | (img_emb_new.size(0), cap_emb_new.size(0)))
193 |
194 | r = simrank(sims)
195 | r_ = np.array(r_) + np.array(r)
196 |
197 | del sims
198 | print('--------------------')
199 | r_ = tuple(r_/5)
200 | print('I2T:%.1f %.1f %.1f' % r_[0:3])
201 | print('T2I:%.1f %.1f %.1f' % r_[3:6])
202 | print('Rsum:%.1f' % r_[-1])
203 |
204 |
205 | def i2t(images, captions, npts=None, return_ranks=False):
206 | if npts is None:
207 | npts = int(images.shape[0] / 5)
208 | print(npts)
209 | index_list = []
210 |
211 | ranks = numpy.zeros(npts)
212 | top1 = numpy.zeros(npts)
213 |
214 | for index in range(npts):
215 |
216 | im = images[5 * index].reshape(1, images.shape[1])
217 |
218 | d = numpy.dot(im, captions.T).flatten()
219 | inds = numpy.argsort(d)[::-1]
220 | index_list.append(inds[0])
221 |
222 | rank = 1e20
223 | for i in range(5 * index, 5 * index + 5, 1):
224 | tmp = numpy.where(inds == i)[0][0]
225 | if tmp < rank:
226 | rank = tmp
227 | ranks[index] = rank
228 | top1[index] = inds[0]
229 |
230 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
231 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
232 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
233 | medr = numpy.floor(numpy.median(ranks)) + 1
234 | meanr = ranks.mean() + 1
235 | if return_ranks:
236 | return (r1, r5, r10, medr, meanr), (ranks, top1)
237 | else:
238 | return (r1, r5, r10, medr, meanr)
239 |
240 |
241 | def t2i(images, captions, npts=None, return_ranks=False):
242 | if npts is None:
243 | npts = int(images.shape[0] / 5)
244 | print(npts)
245 | ims = numpy.array([images[i] for i in range(0, len(images), 5)])
246 |
247 | ranks = numpy.zeros(5 * npts)
248 | top1 = numpy.zeros(5 * npts)
249 |
250 | for index in range(npts):
251 | queries = captions[5 * index:5 * index + 5]
252 | 
253 |         # Compute scores
254 | d = np.dot(queries, ims.T)
255 |
256 | inds = numpy.zeros(d.shape)
257 |         # rank the ground-truth image for each of the 5 query captions
258 | for i in range(len(inds)):
259 | inds[i] = numpy.argsort(d[i])[::-1]
260 | ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0]
261 | top1[5 * index + i] = inds[i][0]
262 |
263 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
264 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
265 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
266 | medr = numpy.floor(numpy.median(ranks)) + 1
267 | meanr = ranks.mean() + 1
268 | if return_ranks:
269 | return (r1, r5, r10, medr, meanr), (ranks, top1)
270 | else:
271 | return (r1, r5, r10, medr, meanr)
272 |
273 |
274 | def simrank(similarity):
275 | sims = similarity # similarity matrix 1k*5k
276 | # print(sims)
277 | img_size, cap_size = sims.shape
278 | print("imgs: %d, caps: %d" % (img_size, cap_size))
279 | # time.sleep(10)
280 | # i2t
281 | index_list = []
282 | ranks = numpy.zeros(img_size)
283 | top1 = numpy.zeros(img_size)
284 | for index in range(img_size):
285 | d = sims[index]
286 | inds = numpy.argsort(d)[::-1]
287 | index_list.append(inds[0])
288 | rank = 1e20
289 | for i in range(5 * index, 5 * index + 5, 1):
290 |             tmp = numpy.where(inds == i)[0][0]
291 | if tmp < rank:
292 | rank = tmp
293 | ranks[index] = rank
294 | top1[index] = inds[0]
295 |
296 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
297 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
298 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
299 | medr = numpy.floor(numpy.median(ranks)) + 1
300 | meanr = ranks.mean() + 1
301 | print('i2t:r1: %.1f, r5: %.1f, r10: %.1f' % (r1, r5, r10)) # , medr, meanr)
302 | rs = r1 + r5 + r10
303 | # t2i
304 | sims_t2i = sims.T
305 | ranks = numpy.zeros(cap_size)
306 | top1 = numpy.zeros(cap_size)
307 | for index in range(img_size):
308 |
309 | d = sims_t2i[5 * index:5 * index + 5] # 5*1000
310 | inds = numpy.zeros(d.shape)
311 | for i in range(len(inds)):
312 | inds[i] = numpy.argsort(d[i])[::-1]
313 | ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0]
314 | top1[5 * index + i] = inds[i][0]
315 |
316 | r1_ = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
317 | r5_ = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
318 | r10_ = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
319 | medr_ = numpy.floor(numpy.median(ranks)) + 1
320 | meanr_ = ranks.mean() + 1
321 | rs_ = r1_ + r5_ + r10_
322 | print('t2i:r1: %.1f, r5: %.1f, r10: %.1f' % (r1_, r5_, r10_))
323 | rsum = rs + rs_
324 | print('rsum=%.1f' % rsum)
325 | return [r1, r5, r10, r1_, r5_, r10_, rsum]
326 |
327 |
328 | def main():
329 | parser = argparse.ArgumentParser()
330 | parser.add_argument('--model', default='single_model', help='model name')
331 | parser.add_argument('--fold', action='store_true', help='fold5')
332 | parser.add_argument('--name', default='model_best', help='checkpoint name')
333 | parser.add_argument('--data_path', default='data', help='data path')
334 | parser.add_argument('--region_bbox_file', default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5', type=str, metavar='PATH',
335 | help='path to region features bbox file')
336 | parser.add_argument('--feature_path', default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/', type=str, metavar='PATH',
337 | help='path to region features')
338 | opt = parser.parse_args()
339 |
340 | evalrank('runs/' + opt.model + '/' + opt.name + ".pth.tar", data_path = opt.data_path, split="test", fold5=opt.fold, region_bbox_file=opt.region_bbox_file, feature_path=opt.feature_path)
341 |
342 | if __name__ == '__main__':
343 | main()
344 |
--------------------------------------------------------------------------------
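
Note: a small sketch (illustrative tensor sizes only, not repository code) of how the fold5 branch above carves the 25 000 COCO test rows (5 000 images x 5 captions, with image embeddings repeated once per caption) into five 1K evaluation folds:

    import torch

    img_embs = torch.randn(25000, 8)   # stand-in for the encoded test set
    cap_embs = torch.randn(25000, 8)

    fold = 5000                         # 1 000 images -> 5 000 caption rows per fold
    for i in range(5):
        img_fold = img_embs[i * fold:(i + 1) * fold:5]   # 1 000 unique image rows
        cap_fold = cap_embs[i * fold:(i + 1) * fold]     # their 5 000 captions
        sims = img_fold @ cap_fold.t()                   # (1000, 5000) similarity matrix
        print(i, sims.shape)
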
/figures/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/figures/model.jpg
--------------------------------------------------------------------------------
/flickr_sims/flickr_sims.txt:
--------------------------------------------------------------------------------
1 | Path to save similarity matrices during the inference stage of Flickr30K.
2 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on
3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives"
4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
7 | # Written by Keyu Wen, 2020
8 | # ------------------------------------------------------------
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.init
13 | import torchvision.models as models
14 | from torch.autograd import Variable
15 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
16 | import torch.backends.cudnn as cudnn
17 | from torch.nn.utils.clip_grad import clip_grad_norm_
18 | import numpy as np
19 | from collections import OrderedDict
20 | import time
21 | from GAT import GATLayer
22 | import copy
23 | from resnet import resnet152
24 | import torchtext
25 | import pickle
26 | import os
27 |
28 |
29 | def l2norm(X, dim=-1, eps=1e-12):
30 | """L2-normalize columns of X
31 | """
32 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
33 | X = torch.div(X, norm)
34 | return X
35 |
36 |
37 | class GATopt(object):
38 | def __init__(self, hidden_size, num_layers):
39 | self.hidden_size = hidden_size
40 | self.num_layers = num_layers
41 | self.num_attention_heads = 8
42 | self.hidden_dropout_prob = 0.2
43 | self.attention_probs_dropout_prob = 0.2
44 |
45 |
46 | class GAT(nn.Module):
47 | def __init__(self, config_gat):
48 | super(GAT, self).__init__()
49 | layer = GATLayer(config_gat)
50 | self.encoder = nn.ModuleList([copy.deepcopy(layer) for _ in range(config_gat.num_layers)])
51 |
52 | def forward(self, input_graph):
53 | hidden_states = input_graph
54 | for layer_module in self.encoder:
55 | hidden_states = layer_module(hidden_states)
56 | return hidden_states # B, seq_len, D
57 |
58 |
59 | class RcnnEncoder(nn.Module):
60 | def __init__(self, opt):
61 | super(RcnnEncoder, self).__init__()
62 | self.embed_size = opt.embed_size
63 | self.fc_image = nn.Linear(opt.img_dim, self.embed_size)
64 | self.init_weights()
65 |
66 | def init_weights(self):
67 | """Xavier initialization for the fully connected layer
68 | """
69 | r = np.sqrt(6.) / np.sqrt(self.fc_image.in_features +
70 | self.fc_image.out_features)
71 | self.fc_image.weight.data.uniform_(-r, r)
72 | self.fc_image.bias.data.fill_(0)
73 |
74 | def forward(self, images, img_pos): # (b, 100, 2048) (b,100,1601+6)
75 | img_f = self.fc_image(images)
76 | return img_f # (b,100,768)
77 |
78 |
79 | # tutorials/09 - Image Captioning
80 | class EncoderImageFull(nn.Module):
81 |
82 | def __init__(self, opt):
83 | """Load pretrained VGG19 and replace top fc layer."""
84 | super(EncoderImageFull, self).__init__()
85 | self.embed_size = opt.embed_size
86 |
87 | self.cnn = resnet152(pretrained=True)
88 | # self.fc = nn.Sequential(nn.Linear(2048, self.embed_size), nn.ReLU(), nn.Dropout(0.1))
89 | self.fc = nn.Linear(opt.img_dim, self.embed_size)
90 | if not opt.finetune:
91 | print('image-encoder-resnet no grad!')
92 | for param in self.cnn.parameters():
93 | param.requires_grad = False
94 | else:
95 | print('image-encoder-resnet fine-tuning !')
96 |
97 | self.init_weights()
98 |
99 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
100 |
101 | def load_state_dict(self, state_dict):
102 | """
103 | Handle the models saved before commit pytorch/vision@989d52a
104 | """
105 | if 'cnn.classifier.1.weight' in state_dict:
106 | state_dict['cnn.classifier.0.weight'] = state_dict[
107 | 'cnn.classifier.1.weight']
108 | del state_dict['cnn.classifier.1.weight']
109 | state_dict['cnn.classifier.0.bias'] = state_dict[
110 | 'cnn.classifier.1.bias']
111 | del state_dict['cnn.classifier.1.bias']
112 | state_dict['cnn.classifier.3.weight'] = state_dict[
113 | 'cnn.classifier.4.weight']
114 | del state_dict['cnn.classifier.4.weight']
115 | state_dict['cnn.classifier.3.bias'] = state_dict[
116 | 'cnn.classifier.4.bias']
117 | del state_dict['cnn.classifier.4.bias']
118 |
119 | super(EncoderImageFull, self).load_state_dict(state_dict)
120 |
121 | def init_weights(self):
122 | """Xavier initialization for the fully connected layer
123 | """
124 | r = np.sqrt(6.) / np.sqrt(self.fc.in_features +
125 | self.fc.out_features)
126 | self.fc.weight.data.uniform_(-r, r)
127 | self.fc.bias.data.fill_(0)
128 |
129 | def forward(self, images):
130 | features_orig = self.cnn(images)
131 | features_top = features_orig[-1]
132 | features = features_top.view(features_top.size(0), features_top.size(1), -1).transpose(2, 1) # b, 49, 2048
133 | features = self.fc(features)
134 |
135 | return features
136 |
137 |
138 | # tutorials/08 - Language Model
139 | # RNN Based Language Model
140 | class EncoderText(nn.Module):
141 |
142 | def __init__(self, opt):
143 | super(EncoderText, self).__init__()
144 | self.embed_size = opt.embed_size
145 | # word embedding
146 | self.embed = nn.Embedding(opt.vocab_size, opt.word_dim)
147 | # caption embedding
148 | self.rnn = nn.GRU(opt.word_dim, opt.embed_size, opt.num_layers, batch_first=True)
149 | vocab = pickle.load(open('vocab/'+opt.data_name+'_vocab.pkl', 'rb'))
150 | word2idx = vocab.word2idx
151 | # self.init_weights()
152 | self.init_weights('glove', word2idx, opt.word_dim)
153 | self.dropout = nn.Dropout(0.1)
154 |
155 | def init_weights(self, wemb_type, word2idx, word_dim):
156 | if wemb_type.lower() == 'random_init':
157 | nn.init.xavier_uniform_(self.embed.weight)
158 | else:
159 | # Load pretrained word embedding
160 | if 'fasttext' == wemb_type.lower():
161 | wemb = torchtext.vocab.FastText()
162 | elif 'glove' == wemb_type.lower():
163 | wemb = torchtext.vocab.GloVe()
164 | else:
165 | raise Exception('Unknown word embedding type: {}'.format(wemb_type))
166 | assert wemb.vectors.shape[1] == word_dim
167 |
168 | # quick-and-dirty trick to improve word-hit rate
169 | missing_words = []
170 | for word, idx in word2idx.items():
171 | if word not in wemb.stoi:
172 | word = word.replace('-', '').replace('.', '').replace("'", '')
173 | if '/' in word:
174 | word = word.split('/')[0]
175 | if word in wemb.stoi:
176 | self.embed.weight.data[idx] = wemb.vectors[wemb.stoi[word]]
177 | else:
178 | missing_words.append(word)
179 | print('Words: {}/{} found in vocabulary; {} words missing'.format(
180 | len(word2idx) - len(missing_words), len(word2idx), len(missing_words)))
181 |
182 | def forward(self, x, lengths):
183 | # return out
184 | x = self.embed(x)
185 | x = self.dropout(x)
186 |
187 | packed = pack_padded_sequence(x, lengths, batch_first=True)
188 |
189 | # Forward propagate RNN
190 | out, _ = self.rnn(packed)
191 |
192 | # Reshape *final* output to (batch_size, hidden_size)
193 | padded = pad_packed_sequence(out, batch_first=True)
194 | cap_emb, cap_len = padded
195 |
196 | cap_emb = l2norm(cap_emb, dim=-1)
197 | cap_emb_mean = torch.mean(cap_emb, 1)
198 | cap_emb_mean = l2norm(cap_emb_mean)
199 |
200 | return cap_emb, cap_emb_mean
201 |
202 |
203 | class Fusion(nn.Module):
204 | def __init__(self, opt):
205 | super(Fusion, self).__init__()
206 | self.f_size = opt.embed_size
207 | self.gate0 = nn.Linear(self.f_size, self.f_size)
208 | self.gate1 = nn.Linear(self.f_size, self.f_size)
209 |
210 | self.fusion0 = nn.Linear(self.f_size, self.f_size)
211 | self.fusion1 = nn.Linear(self.f_size, self.f_size)
212 |
213 | def forward(self, vec1, vec2):
214 | features_1 = self.gate0(vec1)
215 | features_2 = self.gate1(vec2)
216 | t = torch.sigmoid(self.fusion0(features_1) + self.fusion1(features_2))
217 | f = t * features_1 + (1 - t) * features_2
218 | return f
219 |
220 |
221 | class DSRAN(nn.Module):
222 |
223 | def __init__(self, opt):
224 | super(DSRAN, self).__init__()
225 | self.K = opt.K
226 | self.img_enc = EncoderImageFull(opt)
227 | self.rcnn_enc = RcnnEncoder(opt)
228 | self.txt_enc = EncoderText(opt)
229 | config_rcnn = GATopt(opt.embed_size, 1)
230 |         config_img = GATopt(opt.embed_size, 1)
231 |         config_cap = GATopt(opt.embed_size, 1)
232 |         config_joint = GATopt(opt.embed_size, 1)
233 | # SSR
234 | self.gat_1 = GAT(config_rcnn)
235 | self.gat_2 = GAT(config_img)
236 | self.gat_cap = GAT(config_cap)
237 | # JSR
238 | self.gat_cat_1 = GAT(config_joint)
239 | if self.K == 2:
240 | self.gat_cat_2 = GAT(config_joint)
241 | self.fusion = Fusion(opt)
242 | elif self.K == 4:
243 | self.gat_cat_2 = GAT(config_joint)
244 | self.gat_cat_3 = GAT(config_joint)
245 | self.gat_cat_4 = GAT(config_joint)
246 | self.fusion = Fusion(opt)
247 | self.fusion2 = Fusion(opt)
248 | self.fusion3 = Fusion(opt)
249 |
250 | def forward(self, images, img_rcnn, img_pos, captions, lengths):
251 | img_emb_orig = self.gat_2(self.img_enc(images))
252 | rcnn_emb = self.rcnn_enc(img_rcnn, img_pos)
253 | rcnn_emb = self.gat_1(rcnn_emb)
254 | img_cat = torch.cat((img_emb_orig, rcnn_emb), 1)
255 | img_cat_1 = self.gat_cat_1(img_cat)
256 | img_cat_1 = torch.mean(img_cat_1, dim=1)
257 | if self.K == 1:
258 | img_cat = img_cat_1
259 | elif self.K == 2:
260 | img_cat_2 = self.gat_cat_2(img_cat)
261 | img_cat_2 = torch.mean(img_cat_2, dim=1)
262 | img_cat = self.fusion(img_cat_1, img_cat_2)
263 | elif self.K == 4:
264 | img_cat_2 = self.gat_cat_2(img_cat)
265 | img_cat_2 = torch.mean(img_cat_2, dim=1)
266 | img_cat_3 = self.gat_cat_3(img_cat)
267 | img_cat_3 = torch.mean(img_cat_3, dim=1)
268 | img_cat_4 = self.gat_cat_4(img_cat)
269 | img_cat_4 = torch.mean(img_cat_4, dim=1)
270 | img_cat_1_1 = self.fusion(img_cat_1, img_cat_2)
271 | img_cat_1_2 = self.fusion2(img_cat_3, img_cat_4)
272 | img_cat = self.fusion3(img_cat_1_1, img_cat_1_2)
273 | img_emb = l2norm(img_cat)
274 | cap_emb, cap_emb_mean = self.txt_enc(captions, lengths)
275 | cap_gat = self.gat_cap(cap_emb)
276 | cap_embs = l2norm(torch.mean(cap_gat, dim=1))
277 |
278 | return img_emb, cap_embs
279 |
280 |
281 | def cosine_sim(im, s):
282 | """Cosine similarity between all the image and sentence pairs
283 | """
284 | return im.mm(s.t())
285 |
286 | class ContrastiveLoss(nn.Module):
287 | """
288 | Compute contrastive loss
289 | """
290 |
291 | def __init__(self, margin=0):
292 | super(ContrastiveLoss, self).__init__()
293 | self.margin = margin
294 | self.sim = cosine_sim
295 |
296 | def forward(self, im, s):
297 | # compute image-sentence score matrix
298 | scores = self.sim(im, s)
299 | diagonal = scores.diag().view(im.size(0), 1)
300 |
301 | d1 = diagonal.expand_as(scores)
302 | d2 = diagonal.t().expand_as(scores)
303 | im_sn = scores - d1
304 | c_sn = scores - d2
305 | # compare every diagonal score to scores in its column
306 | # caption retrieval
307 | cost_s = (self.margin + scores - d1).clamp(min=0)
308 | # compare every diagonal score to scores in its row
309 | # image retrieval
310 | cost_im = (self.margin + scores - d2).clamp(min=0)
311 | # clear diagonals
312 | mask = torch.eye(scores.size(0)) > .5
313 | I = Variable(mask)
314 | if torch.cuda.is_available():
315 | I = I.cuda()
316 | cost_s = cost_s.masked_fill_(I, 0)
317 | cost_im = cost_im.masked_fill_(I, 0)
318 |
319 | # keep the maximum violating negative for each query
320 |
321 | cost_s = cost_s.max(1)[0]
322 | cost_im = cost_im.max(0)[0]
323 |
324 | return cost_s.sum() + cost_im.sum()
325 |
326 |
327 | class VSE(object):
328 | """
329 | rkiros/uvs model
330 | """
331 | def __init__(self, opt):
332 | # tutorials/09 - Image Captioning
333 | # Build Models
334 | self.grad_clip = opt.grad_clip
335 |
336 | self.DSRAN = DSRAN(opt)
337 | if torch.cuda.is_available():
338 | self.DSRAN.cuda()
339 | cudnn.benchmark = True
340 | # Loss and Optimizer
341 | self.criterion = ContrastiveLoss(margin=opt.margin)
342 | params = list(self.DSRAN.parameters())
343 |
344 | self.params = params
345 |
346 | self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate)
347 |
348 | self.Eiters = 0
349 |
350 | def state_dict(self):
351 | state_dict = [self.DSRAN.state_dict()]
352 | return state_dict
353 |
354 | def load_state_dict(self, state_dict):
355 | self.DSRAN.load_state_dict(state_dict[0])
356 |
357 | def train_start(self):
358 | """switch to train mode
359 | """
360 | self.DSRAN.train()
361 |
362 | def val_start(self):
363 | """switch to evaluate mode
364 | """
365 | self.DSRAN.eval()
366 |
367 | def forward_emb(self, images, captions, img_rcnn, img_pos, lengths, volatile=False):
368 | """Compute the image and caption embeddings
369 | """
370 | # Set mini-batch dataset
371 |
372 | if torch.cuda.is_available():
373 | images = images.cuda()
374 | captions = captions.cuda()
375 | img_rcnn = img_rcnn.cuda()
376 | img_pos = img_pos.cuda()
377 |
378 | img_emb, cap_emb = self.DSRAN(images, img_rcnn, img_pos, captions, lengths)
379 | return img_emb, cap_emb
380 |
381 | def forward_loss(self, img_emb, cap_emb, **kwargs):
382 | """Compute the loss given pairs of image and caption embeddings
383 | """
384 | loss = self.criterion(img_emb, cap_emb)
385 | self.logger.update('Le', loss.data, img_emb.size(0))
386 | return loss
387 |
388 | def train_emb(self, images, captions, img_rcnn, img_pos, lengths, ids=None, *args):
389 | """One training step given images and captions.
390 | """
391 | self.Eiters += 1
392 | self.logger.update('Eit', self.Eiters)
393 | self.logger.update('lr', self.optimizer.param_groups[0]['lr'])
394 |
395 | # compute the embeddings
396 | img_emb, cap_emb = self.forward_emb(images, captions, img_rcnn, img_pos, lengths)
397 | # measure accuracy and record loss
398 | self.optimizer.zero_grad()
399 | loss = self.forward_loss(img_emb, cap_emb)
400 |
401 | # compute gradient and do SGD step
402 | loss.backward()
403 | if self.grad_clip > 0:
404 | clip_grad_norm_(self.params, self.grad_clip)
405 | self.optimizer.step()
406 |
407 |
--------------------------------------------------------------------------------
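
Note: a compact toy illustration (random unit-norm embeddings and an arbitrary margin value; not repository code) of the max-of-hinges objective that ContrastiveLoss in model.py implements, where only the hardest violating negative in each row and column contributes:

    import torch
    import torch.nn.functional as F

    def hardest_negative_loss(im, s, margin=0.2):
        scores = im @ s.t()                                   # (B, B) image-caption scores
        pos = scores.diag().view(-1, 1)                       # matched pairs on the diagonal
        cost_s = (margin + scores - pos).clamp(min=0)         # caption-retrieval hinge
        cost_im = (margin + scores - pos.t()).clamp(min=0)    # image-retrieval hinge
        mask = torch.eye(scores.size(0), dtype=torch.bool)
        cost_s = cost_s.masked_fill(mask, 0)                  # ignore the positives themselves
        cost_im = cost_im.masked_fill(mask, 0)
        return cost_s.max(1)[0].sum() + cost_im.max(0)[0].sum()

    im = F.normalize(torch.randn(4, 16), dim=1)
    s = F.normalize(torch.randn(4, 16), dim=1)
    print(hardest_negative_loss(im, s))
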
/model_bert.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on
3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives"
4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
7 | # Written by Keyu Wen & Linyang Li, 2020
8 | # ------------------------------------------------------------
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.init
13 | from torch.autograd import Variable
14 | import torch.backends.cudnn as cudnn
15 | import torch.nn.functional as F
16 | import numpy as np
17 | from collections import OrderedDict
18 | import copy
19 | from resnet import resnet152
20 | from pytorch_pretrained_bert.modeling import BertModel
21 | from pytorch_pretrained_bert.optimization import BertAdam
22 | import time
23 | from GAT import GATLayer
24 |
25 |
26 | def l2norm(X):
27 | norm = torch.pow(X, 2).sum(dim=-1, keepdim=True).sqrt()
28 | X = torch.div(X, norm)
29 | return X
30 |
31 |
32 | class RcnnEncoder(nn.Module):
33 | def __init__(self, opt):
34 | super(RcnnEncoder, self).__init__()
35 | self.embed_size = opt.embed_size
36 | self.fc_image = nn.Sequential(nn.Linear(opt.img_dim, opt.img_dim),
37 | nn.ReLU(),
38 | nn.Linear(opt.img_dim, self.embed_size),
39 | nn.ReLU(),
40 | nn.Dropout(0.1))
41 | self.fc_pos = nn.Sequential(nn.Linear(6 + 1601, self.embed_size),
42 | nn.ReLU(),
43 | nn.Dropout(0.1))
44 | self.fc = nn.Linear(self.embed_size * 2, self.embed_size)
45 |
46 | def forward(self, images, img_pos): # (b, 100, 2048) (b,100,1601+6)
47 | img_f = self.fc_image(images)
48 | img_pe = self.fc_pos(img_pos)
49 | img_embs = img_f + img_pe
50 | return img_embs # (b,100,768)
51 |
52 |
53 | class ImageEncoder(nn.Module):
54 |
55 | def __init__(self, opt):
56 | super(ImageEncoder, self).__init__()
57 | self.embed_size = opt.embed_size
58 | self.cnn = resnet152(pretrained=True)
59 | self.fc = nn.Sequential(nn.Linear(opt.img_dim, opt.embed_size), nn.ReLU(), nn.Dropout(0.1))
60 | if not opt.ft_res:
61 | print('image-encoder-resnet no grad!')
62 | for param in self.cnn.parameters():
63 | param.requires_grad = False
64 | else:
65 | print('image-encoder-resnet fine-tuning !')
66 |
67 | # def load_state_dict(self, state_dict):
68 | # if 'cnn.classifier.1.weight' in state_dict:
69 | # state_dict['cnn.classifier.0.weight'] = state_dict[
70 | # 'cnn.classifier.1.weight']
71 | # del state_dict['cnn.classifier.1.weight']
72 | # state_dict['cnn.classifier.0.bias'] = state_dict[
73 | # 'cnn.classifier.1.bias']
74 | # del state_dict['cnn.classifier.1.bias']
75 | # state_dict['cnn.classifier.3.weight'] = state_dict[
76 | # 'cnn.classifier.4.weight']
77 | # del state_dict['cnn.classifier.4.weight']
78 | # state_dict['cnn.classifier.3.bias'] = state_dict[
79 | # 'cnn.classifier.4.bias']
80 | # del state_dict['cnn.classifier.4.bias']
81 |
82 | # super(ImageEncoder, self).load_state_dict(state_dict)
83 |
84 | def forward(self, images):
85 | features_orig = self.cnn(images)
86 | features_top = features_orig[-1]
87 | features = features_top.view(features_top.size(0), features_top.size(1), -1).transpose(2, 1) # b, 49, 2048
88 | features = self.fc(features)
89 | return features
90 |
91 |
92 | class TextEncoder(nn.Module):
93 | def __init__(self, opt):
94 | super(TextEncoder, self).__init__()
95 | self.bert = BertModel.from_pretrained(opt.bert_path)
96 | if not opt.ft_bert:
97 | for param in self.bert.parameters():
98 | param.requires_grad = False
99 | print('text-encoder-bert no grad')
100 | else:
101 | print('text-encoder-bert fine-tuning !')
102 | self.embed_size = opt.embed_size
103 | self.fc = nn.Sequential(nn.Linear(opt.bert_size, opt.embed_size), nn.ReLU(), nn.Dropout(0.1))
104 |
105 | def forward(self, captions):
106 | all_encoders, pooled = self.bert(captions)
107 | out = all_encoders[-1]
108 | out = self.fc(out)
109 | return out
110 |
111 |
112 | class GATopt(object):
113 | def __init__(self, hidden_size, num_layers):
114 | self.hidden_size = hidden_size
115 | self.num_layers = num_layers
116 | self.num_attention_heads = 8
117 | self.hidden_dropout_prob = 0.2
118 | self.attention_probs_dropout_prob = 0.2
119 |
120 |
121 | class GAT(nn.Module):
122 | def __init__(self, config_gat):
123 | super(GAT, self).__init__()
124 | layer = GATLayer(config_gat)
125 | self.encoder = nn.ModuleList([copy.deepcopy(layer) for _ in range(config_gat.num_layers)])
126 |
127 | def forward(self, input_graph):
128 | hidden_states = input_graph
129 | for layer_module in self.encoder:
130 | hidden_states = layer_module(hidden_states)
131 | return hidden_states # B, seq_len, D
132 |
133 |
134 | def cosine_sim(im, s):
135 | return im.mm(s.t())
136 |
137 |
138 | class ContrastiveLoss(nn.Module):
139 | def __init__(self, margin=0):
140 | super(ContrastiveLoss, self).__init__()
141 | self.margin = margin
142 | self.sim = cosine_sim
143 |
144 | def forward(self, im, s):
145 | scores = self.sim(im, s)
146 | diagonal = scores.diag().view(im.size(0), 1)
147 |
148 | d1 = diagonal.expand_as(scores)
149 | d2 = diagonal.t().expand_as(scores)
150 | im_sn = scores - d1
151 | c_sn = scores - d2
152 | cost_s = (self.margin + scores - d1).clamp(min=0)
153 |
154 | cost_im = (self.margin + scores - d2).clamp(min=0)
155 |
156 | mask = torch.eye(scores.size(0)) > .5
157 | I = Variable(mask)
158 | if torch.cuda.is_available():
159 | I = I.cuda()
160 | cost_s = cost_s.masked_fill_(I, 0)
161 | cost_im = cost_im.masked_fill_(I, 0)
162 |
163 | cost_s = cost_s.max(1)[0]
164 | cost_im = cost_im.max(0)[0]
165 | return cost_s.sum() + cost_im.sum()
166 |
167 |
168 | def get_optimizer(params, opt, t_total=-1):
169 | bertadam = BertAdam(params, lr=opt.learning_rate, warmup=opt.warmup, t_total=t_total)
170 | return bertadam
171 |
172 |
173 | class Fusion(nn.Module):
174 | def __init__(self, opt):
175 | super(Fusion, self).__init__()
176 | self.f_size = opt.embed_size
177 | self.gate0 = nn.Linear(self.f_size, self.f_size)
178 | self.gate1 = nn.Linear(self.f_size, self.f_size)
179 |
180 | self.fusion0 = nn.Linear(self.f_size, self.f_size)
181 | self.fusion1 = nn.Linear(self.f_size, self.f_size)
182 |
183 | def forward(self, vec1, vec2):
184 | features_1 = self.gate0(vec1)
185 | features_2 = self.gate1(vec2)
186 | t = torch.sigmoid(self.fusion0(features_1) + self.fusion1(features_2))
187 | f = t * features_1 + (1 - t) * features_2
188 | return f
189 |
190 |
191 | class DSRAN(nn.Module):
192 | def __init__(self, opt):
193 | super(DSRAN, self).__init__()
194 | self.img_enc = ImageEncoder(opt)
195 | self.txt_enc = TextEncoder(opt)
196 | self.rcnn_enc = RcnnEncoder(opt)
197 |
198 | config_img = GATopt(opt.embed_size, 1)
199 | config_cap = GATopt(opt.embed_size, 1)
200 | config_rcnn = GATopt(opt.embed_size, 1)
201 | config_joint = GATopt(opt.embed_size, 1)
202 |
203 | self.K = opt.K
204 | # SSR
205 | self.gat_1 = GAT(config_img)
206 | self.gat_2 = GAT(config_rcnn)
207 | self.gat_cap = GAT(config_cap)
208 | # JSR
209 | self.gat_cat = GAT(config_joint)
210 | if self.K == 2:
211 | self.gat_cat_1 = GAT(config_joint)
212 | self.fusion = Fusion(opt)
213 | elif self.K == 4:
214 | self.gat_cat_1 = GAT(config_joint)
215 | self.gat_cat_2 = GAT(config_joint)
216 | self.gat_cat_3 = GAT(config_joint)
217 |
218 | self.fusion = Fusion(opt)
219 | self.fusion_1 = Fusion(opt)
220 | self.fusion_2 = Fusion(opt)
221 |
222 | def forward(self, images_orig, rcnn_fe, img_pos, captions):
223 |
224 | img_emb_orig = self.gat_1(self.img_enc(images_orig))
225 | rcnn_emb = self.rcnn_enc(rcnn_fe, img_pos)
226 | rcnn_emb = self.gat_2(rcnn_emb)
227 | img_cat = torch.cat((img_emb_orig, rcnn_emb), 1)
228 | img_cat_1 = self.gat_cat(img_cat)
229 | img_cat_1 = torch.mean(img_cat_1, dim=1)
230 | if self.K == 1:
231 | img_cat = img_cat_1
232 | elif self.K == 2:
233 | img_cat_2 = self.gat_cat_1(img_cat)
234 | img_cat_2 = torch.mean(img_cat_2, dim=1)
235 | img_cat = self.fusion(img_cat_1, img_cat_2)
236 | elif self.K == 4:
237 | img_cat_2 = self.gat_cat_1(img_cat)
238 | img_cat_2 = torch.mean(img_cat_2, dim=1)
239 | img_cat_3 = self.gat_cat_2(img_cat)
240 | img_cat_3 = torch.mean(img_cat_3, dim=1)
241 | img_cat_4 = self.gat_cat_3(img_cat)
242 | img_cat_4 = torch.mean(img_cat_4, dim=1)
243 | img_cat_1_1 = self.fusion_1(img_cat_1, img_cat_2)
244 | img_cat_1_2 = self.fusion_2(img_cat_3, img_cat_4)
245 | img_cat = self.fusion(img_cat_1_1, img_cat_1_2)
246 | img_emb = l2norm(img_cat)
247 | cap_emb = self.txt_enc(captions)
248 | cap_gat = self.gat_cap(cap_emb)
249 | cap_embs = l2norm(torch.mean(cap_gat, dim=1))
250 |
251 | return img_emb, cap_embs
252 |
253 |
254 | class VSE(object):
255 |
256 | def __init__(self, opt):
257 | self.DSRAN = DSRAN(opt)
258 | self.DSRAN = nn.DataParallel(self.DSRAN)
259 | if torch.cuda.is_available():
260 | self.DSRAN.cuda()
261 | cudnn.benchmark = True
262 | self.criterion = ContrastiveLoss(margin=opt.margin)
263 | params = list(self.DSRAN.named_parameters())
264 | param_optimizer = params
265 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
266 | optimizer_grouped_parameters = [
267 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
268 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
269 | ]
270 | t_total = opt.l_train * opt.num_epochs
271 | if opt.warmup == -1:
272 | t_total = -1
273 | self.optimizer = get_optimizer(params=optimizer_grouped_parameters, opt=opt, t_total=t_total)
274 | self.Eiters = 0
275 |
276 | def state_dict(self):
277 | state_dict = self.DSRAN.state_dict()
278 | return state_dict
279 |
280 | def load_state_dict(self, state_dict):
281 | self.DSRAN.load_state_dict(state_dict)
282 |
283 | def train_start(self):
284 | self.DSRAN.train()
285 |
286 | def val_start(self):
287 | self.DSRAN.eval()
288 |
289 | def forward_emb(self, images_orig, rcnn_fe, img_pos, captions):
290 | if torch.cuda.is_available():
291 | images_orig = images_orig.cuda()
292 | rcnn_fe = rcnn_fe.cuda()
293 | img_pos = img_pos.cuda()
294 | captions = captions.cuda()
295 |
296 | img_emb, cap_emb = self.DSRAN(images_orig, rcnn_fe, img_pos, captions)
297 |
298 | return img_emb, cap_emb
299 |
300 | def forward_loss(self, img_emb, cap_emb, **kwargs):
301 | loss = self.criterion(img_emb, cap_emb)
302 | self.logger.update('Le', loss.data, img_emb.size(0))
303 | return loss
304 |
305 | def train_emb(self, images, images_orig, img_pos, captions, ids=None, *args):
306 | self.Eiters += 1
307 | self.logger.update('Eit', self.Eiters)
308 | self.logger.update('lr', self.optimizer.param_groups[0]['lr'])
309 |
310 | img_emb, cap_emb = self.forward_emb(images_orig, images, img_pos, captions)
311 |
312 | self.optimizer.zero_grad()
313 | loss = self.forward_loss(img_emb, cap_emb)
314 |
315 | loss.backward()
316 | self.optimizer.step()
317 |
--------------------------------------------------------------------------------
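
Note: the optimizer setup in VSE.__init__ above groups parameters into decay / no-decay sets before handing them to BertAdam; below is a minimal sketch of that pattern on a toy module (the module names and the 0.01 decay value are illustrative, not repository settings):

    import torch.nn as nn
    from pytorch_pretrained_bert.optimization import BertAdam

    class Toy(nn.Module):
        def __init__(self):
            super(Toy, self).__init__()
            self.fc = nn.Linear(8, 8)
            self.LayerNorm = nn.LayerNorm(8)   # matches BERT's attribute naming

    model = Toy()
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    optimizer = BertAdam(grouped, lr=2e-5, warmup=0.1, t_total=1000)
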
/pytorch_pretrained_bert/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/pytorch_pretrained_bert/.DS_Store
--------------------------------------------------------------------------------
/pytorch_pretrained_bert/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 |
8 | import json
9 | import logging
10 | import os
11 | import shutil
12 | import tempfile
13 | from functools import wraps
14 | from hashlib import sha256
15 | import sys
16 | from io import open
17 |
18 | import boto3
19 | import requests
20 | from botocore.exceptions import ClientError
21 | from tqdm import tqdm
22 |
23 | try:
24 | from urllib.parse import urlparse
25 | except ImportError:
26 | from urlparse import urlparse
27 |
28 | try:
29 | from pathlib import Path
30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
31 | Path.home() / '.pytorch_pretrained_bert'))
32 | except (AttributeError, ImportError):
33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
35 |
36 | CONFIG_NAME = "config.json"
37 | WEIGHTS_NAME = "pytorch_model.bin"
38 |
39 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
40 |
41 |
42 | def url_to_filename(url, etag=None):
43 | """
44 | Convert `url` into a hashed filename in a repeatable way.
45 | If `etag` is specified, append its hash to the url's, delimited
46 | by a period.
47 | """
48 | url_bytes = url.encode('utf-8')
49 | url_hash = sha256(url_bytes)
50 | filename = url_hash.hexdigest()
51 |
52 | if etag:
53 | etag_bytes = etag.encode('utf-8')
54 | etag_hash = sha256(etag_bytes)
55 | filename += '.' + etag_hash.hexdigest()
56 |
57 | return filename
58 |
59 |
60 | def filename_to_url(filename, cache_dir=None):
61 | """
62 | Return the url and etag (which may be ``None``) stored for `filename`.
63 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
64 | """
65 | if cache_dir is None:
66 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
67 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
68 | cache_dir = str(cache_dir)
69 |
70 | cache_path = os.path.join(cache_dir, filename)
71 | if not os.path.exists(cache_path):
72 | raise EnvironmentError("file {} not found".format(cache_path))
73 |
74 | meta_path = cache_path + '.json'
75 | if not os.path.exists(meta_path):
76 | raise EnvironmentError("file {} not found".format(meta_path))
77 |
78 | with open(meta_path, encoding="utf-8") as meta_file:
79 | metadata = json.load(meta_file)
80 | url = metadata['url']
81 | etag = metadata['etag']
82 |
83 | return url, etag
84 |
85 |
86 | def cached_path(url_or_filename, cache_dir=None):
87 | """
88 | Given something that might be a URL (or might be a local path),
89 | determine which. If it's a URL, download the file and cache it, and
90 | return the path to the cached file. If it's already a local path,
91 | make sure the file exists and then return the path.
92 | """
93 | if cache_dir is None:
94 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
95 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
96 | url_or_filename = str(url_or_filename)
97 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
98 | cache_dir = str(cache_dir)
99 |
100 | parsed = urlparse(url_or_filename)
101 |
102 | if parsed.scheme in ('http', 'https', 's3'):
103 | # URL, so get it from the cache (downloading if necessary)
104 | return get_from_cache(url_or_filename, cache_dir)
105 | elif os.path.exists(url_or_filename):
106 | # File, and it exists.
107 | return url_or_filename
108 | elif parsed.scheme == '':
109 | # File, but it doesn't exist.
110 | raise EnvironmentError("file {} not found".format(url_or_filename))
111 | else:
112 | # Something unknown
113 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
114 |
115 |
116 | def split_s3_path(url):
117 | """Split a full s3 path into the bucket name and path."""
118 | parsed = urlparse(url)
119 | if not parsed.netloc or not parsed.path:
120 | raise ValueError("bad s3 path {}".format(url))
121 | bucket_name = parsed.netloc
122 | s3_path = parsed.path
123 | # Remove '/' at beginning of path.
124 | if s3_path.startswith("/"):
125 | s3_path = s3_path[1:]
126 | return bucket_name, s3_path
127 |
128 |
129 | def s3_request(func):
130 | """
131 | Wrapper function for s3 requests in order to create more helpful error
132 | messages.
133 | """
134 |
135 | @wraps(func)
136 | def wrapper(url, *args, **kwargs):
137 | try:
138 | return func(url, *args, **kwargs)
139 | except ClientError as exc:
140 | if int(exc.response["Error"]["Code"]) == 404:
141 | raise EnvironmentError("file {} not found".format(url))
142 | else:
143 | raise
144 |
145 | return wrapper
146 |
147 |
148 | @s3_request
149 | def s3_etag(url):
150 | """Check ETag on S3 object."""
151 | s3_resource = boto3.resource("s3")
152 | bucket_name, s3_path = split_s3_path(url)
153 | s3_object = s3_resource.Object(bucket_name, s3_path)
154 | return s3_object.e_tag
155 |
156 |
157 | @s3_request
158 | def s3_get(url, temp_file):
159 | """Pull a file directly from S3."""
160 | s3_resource = boto3.resource("s3")
161 | bucket_name, s3_path = split_s3_path(url)
162 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
163 |
164 |
165 | def http_get(url, temp_file):
166 | req = requests.get(url, stream=True)
167 | content_length = req.headers.get('Content-Length')
168 | total = int(content_length) if content_length is not None else None
169 | progress = tqdm(unit="B", total=total)
170 | for chunk in req.iter_content(chunk_size=1024):
171 | if chunk: # filter out keep-alive new chunks
172 | progress.update(len(chunk))
173 | temp_file.write(chunk)
174 | progress.close()
175 |
176 |
177 | def get_from_cache(url, cache_dir=None):
178 | """
179 | Given a URL, look for the corresponding dataset in the local cache.
180 | If it's not there, download it. Then return the path to the cached file.
181 | """
182 | if cache_dir is None:
183 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
184 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
185 | cache_dir = str(cache_dir)
186 |
187 | if not os.path.exists(cache_dir):
188 | os.makedirs(cache_dir)
189 |
190 | # Get eTag to add to filename, if it exists.
191 | if url.startswith("s3://"):
192 | etag = s3_etag(url)
193 | else:
194 | response = requests.head(url, allow_redirects=True)
195 | if response.status_code != 200:
196 | raise IOError("HEAD request failed for url {} with status code {}"
197 | .format(url, response.status_code))
198 | etag = response.headers.get("ETag")
199 |
200 | filename = url_to_filename(url, etag)
201 |
202 | # get cache path to put the file
203 | cache_path = os.path.join(cache_dir, filename)
204 |
205 | if not os.path.exists(cache_path):
206 | # Download to temporary file, then copy to cache dir once finished.
207 | # Otherwise you get corrupt cache entries if the download gets interrupted.
208 | with tempfile.NamedTemporaryFile() as temp_file:
209 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
210 |
211 | # GET file object
212 | if url.startswith("s3://"):
213 | s3_get(url, temp_file)
214 | else:
215 | http_get(url, temp_file)
216 |
217 | # we are copying the file before closing it, so flush to avoid truncation
218 | temp_file.flush()
219 | # shutil.copyfileobj() starts at the current position, so go to the start
220 | temp_file.seek(0)
221 |
222 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
223 | with open(cache_path, 'wb') as cache_file:
224 | shutil.copyfileobj(temp_file, cache_file)
225 |
226 | logger.info("creating metadata file for %s", cache_path)
227 | meta = {'url': url, 'etag': etag}
228 | meta_path = cache_path + '.json'
229 | with open(meta_path, 'w', encoding="utf-8") as meta_file:
230 | json.dump(meta, meta_file)
231 |
232 | logger.info("removing temp file %s", temp_file.name)
233 |
234 | return cache_path
235 |
236 |
237 | def read_set_from_file(filename):
238 | '''
239 | Extract a de-duped collection (set) of text from a file.
240 | Expected file format is one item per line.
241 | '''
242 | collection = set()
243 | with open(filename, 'r', encoding='utf-8') as file_:
244 | for line in file_:
245 | collection.add(line.rstrip())
246 | return collection
247 |
248 |
249 | def get_file_extension(path, dot=True, lower=True):
250 | ext = os.path.splitext(path)[1]
251 | ext = ext if dot else ext[1:]
252 | return ext.lower() if lower else ext
253 |
--------------------------------------------------------------------------------
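The `get_from_cache` routine above derives the cache filename from the URL and ETag via `url_to_filename` and writes a `.json` sidecar holding that metadata next to the cached file. A minimal sketch of inspecting such a cache directory; the helper and the example path are illustrative assumptions, not part of this repository:

import json
import os

def list_cache_entries(cache_dir):
    # Each cached file has a companion '<name>.json' containing {'url': ..., 'etag': ...}.
    for name in sorted(os.listdir(cache_dir)):
        if not name.endswith('.json'):
            continue
        with open(os.path.join(cache_dir, name), 'r', encoding='utf-8') as meta_file:
            meta = json.load(meta_file)
        print(name[:-len('.json')], meta['url'], meta['etag'])

# Example call with a hypothetical cache location:
# list_cache_entries(os.path.expanduser('~/.pytorch_pretrained_bert'))
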
/pytorch_pretrained_bert/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 | import logging
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | def warmup_cosine(x, warmup=0.002):
27 | if x < warmup:
28 | return x/warmup
29 | x_ = (x - warmup) / (1 - warmup)  # progress after warmup
30 | return 0.5 * (1. + math.cos(math.pi * x_))
31 |
32 | def warmup_constant(x, warmup=0.002):
33 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
34 | Learning rate is 1. afterwards. """
35 | if x < warmup:
36 | return x/warmup
37 | return 1.0
38 |
39 | def warmup_linear(x, warmup=0.002):
40 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
41 | After `t_total`-th training step, learning rate is zero. """
42 | if x < warmup:
43 | return x/warmup
44 | return max((x-1.)/(warmup-1.), 0)
45 |
46 | SCHEDULES = {
47 | 'warmup_cosine': warmup_cosine,
48 | 'warmup_constant': warmup_constant,
49 | 'warmup_linear': warmup_linear,
50 | }
51 |
52 |
53 | class BertAdam(Optimizer):
54 | """Implements BERT version of Adam algorithm with weight decay fix.
55 | Params:
56 | lr: learning rate
57 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
58 | t_total: total number of training steps for the learning
59 | rate schedule, -1 means constant learning rate. Default: -1
60 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
61 | b1: Adams b1. Default: 0.9
62 | b2: Adams b2. Default: 0.999
63 | e: Adams epsilon. Default: 1e-6
64 | weight_decay: Weight decay. Default: 0.01
65 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
66 | """
67 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
68 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
69 | max_grad_norm=1.0):
70 | if lr is not required and lr < 0.0:
71 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
72 | if schedule not in SCHEDULES:
73 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
74 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
75 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
76 | if not 0.0 <= b1 < 1.0:
77 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
78 | if not 0.0 <= b2 < 1.0:
79 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
80 | if not e >= 0.0:
81 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
82 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
83 | b1=b1, b2=b2, e=e, weight_decay=weight_decay,
84 | max_grad_norm=max_grad_norm)
85 | super(BertAdam, self).__init__(params, defaults)
86 |
87 | def get_lr(self):
88 | lr = []
89 | for group in self.param_groups:
90 | for p in group['params']:
91 | state = self.state[p]
92 | if len(state) == 0:
93 | return [0]
94 | if group['t_total'] != -1:
95 | schedule_fct = SCHEDULES[group['schedule']]
96 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
97 | else:
98 | lr_scheduled = group['lr']
99 | lr.append(lr_scheduled)
100 | return lr
101 |
102 | def step(self, closure=None):
103 | """Performs a single optimization step.
104 |
105 | Arguments:
106 | closure (callable, optional): A closure that reevaluates the model
107 | and returns the loss.
108 | """
109 | loss = None
110 | if closure is not None:
111 | loss = closure()
112 |
113 | warned_for_t_total = False
114 |
115 | for group in self.param_groups:
116 | for p in group['params']:
117 | if p.grad is None:
118 | continue
119 | grad = p.grad.data
120 | if grad.is_sparse:
121 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
122 |
123 | state = self.state[p]
124 |
125 | # State initialization
126 | if len(state) == 0:
127 | state['step'] = 0
128 | # Exponential moving average of gradient values
129 | state['next_m'] = torch.zeros_like(p.data)
130 | # Exponential moving average of squared gradient values
131 | state['next_v'] = torch.zeros_like(p.data)
132 |
133 | next_m, next_v = state['next_m'], state['next_v']
134 | beta1, beta2 = group['b1'], group['b2']
135 |
136 | # Add grad clipping
137 | if group['max_grad_norm'] > 0:
138 | clip_grad_norm_(p, group['max_grad_norm'])
139 |
140 | # Decay the first and second moment running average coefficient
141 | # In-place operations to update the averages at the same time
142 | next_m.mul_(beta1).add_(1 - beta1, grad)
143 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
144 | update = next_m / (next_v.sqrt() + group['e'])
145 |
146 | # Just adding the square of the weights to the loss function is *not*
147 | # the correct way of using L2 regularization/weight decay with Adam,
148 | # since that will interact with the m and v parameters in strange ways.
149 | #
150 | # Instead we want to decay the weights in a manner that doesn't interact
151 | # with the m/v parameters. This is equivalent to adding the square
152 | # of the weights to the loss with plain (non-momentum) SGD.
153 | if group['weight_decay'] > 0.0:
154 | update += group['weight_decay'] * p.data
155 |
156 | if group['t_total'] != -1:
157 | schedule_fct = SCHEDULES[group['schedule']]
158 | progress = state['step']/group['t_total']
159 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
160 | # warning for exceeding t_total (only active with warmup_linear)
161 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
162 | logger.warning(
163 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
164 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
165 | warned_for_t_total = True
166 | # end warning
167 | else:
168 | lr_scheduled = group['lr']
169 |
170 | update_with_lr = lr_scheduled * update
171 | p.data.add_(-update_with_lr)
172 |
173 | state['step'] += 1
174 |
175 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
176 | # No bias correction
177 | # bias_correction1 = 1 - beta1 ** state['step']
178 | # bias_correction2 = 1 - beta2 ** state['step']
179 |
180 | return loss
181 |
--------------------------------------------------------------------------------
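The warmup schedules defined at the top of this file map training progress x = step / t_total to a multiplier on the base learning rate, and `BertAdam.step()` applies that multiplier together with the decoupled weight decay. A minimal usage sketch, assuming the `pytorch_pretrained_bert` directory is importable as a package; the no-decay parameter grouping is only a common convention used here for illustration:

import torch
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

model = torch.nn.Linear(10, 2)  # stand-in for a real encoder
no_decay = ['bias', 'LayerNorm.weight']
grouped_params = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
optimizer = BertAdam(grouped_params, lr=2e-5, warmup=0.1, t_total=1000,
                     schedule='warmup_linear')

# warmup_linear rises to 1.0 at 10% of training, then decays linearly to 0.
for x in (0.0, 0.05, 0.1, 0.5, 1.0):
    print(x, warmup_linear(x, warmup=0.1))
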
/pytorch_pretrained_bert/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import collections
20 | import logging
21 | import os
22 | import unicodedata
23 | from io import open
24 |
25 | from .file_utils import cached_path
26 | import json
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 | PRETRAINED_VOCAB_ARCHIVE_MAP = {
31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
37 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
38 | }
39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
40 | 'bert-base-uncased': 512,
41 | 'bert-large-uncased': 512,
42 | 'bert-base-cased': 512,
43 | 'bert-large-cased': 512,
44 | 'bert-base-multilingual-uncased': 512,
45 | 'bert-base-multilingual-cased': 512,
46 | 'bert-base-chinese': 512,
47 | }
48 | VOCAB_NAME = 'vocab.txt'
49 |
50 |
51 | def load_vocab(vocab_file):
52 | """Loads a vocabulary file into a dictionary."""
53 | vocab = collections.OrderedDict()
54 | index = 0
55 | with open(vocab_file, "r", encoding="utf-8") as reader:
56 | while True:
57 | token = reader.readline()
58 | if not token:
59 | break
60 | token = token.strip()
61 | vocab[token] = index
62 | index += 1
63 | return vocab
64 |
65 |
66 | def whitespace_tokenize(text):
67 | """Runs basic whitespace cleaning and splitting on a piece of text."""
68 | text = text.strip()
69 | if not text:
70 | return []
71 | tokens = text.split()
72 | return tokens
73 |
74 |
75 | class BertTokenizer(object):
76 | """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
77 |
78 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
79 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
80 | """Constructs a BertTokenizer.
81 |
82 | Args:
83 | vocab_file: Path to a one-wordpiece-per-line vocabulary file
84 | do_lower_case: Whether to lower case the input
85 | Only has an effect when do_wordpiece_only=False
86 | do_basic_tokenize: Whether to do basic tokenization before wordpiece.
87 | max_len: An artificial maximum length to truncate tokenized sequences to;
88 | Effective maximum length is always the minimum of this
89 | value (if specified) and the underlying BERT model's
90 | sequence length.
91 | never_split: List of tokens which will never be split during tokenization.
92 | Only has an effect when do_wordpiece_only=False
93 | """
94 | if not os.path.isfile(vocab_file):
95 | raise ValueError(
96 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
97 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
98 | self.vocab = load_vocab(vocab_file)
99 |
100 | self.ids_to_tokens = collections.OrderedDict(
101 | [(ids, tok) for tok, ids in self.vocab.items()])
102 | self.do_basic_tokenize = do_basic_tokenize
103 | if do_basic_tokenize:
104 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
105 | never_split=never_split)
106 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
107 | # print('core_vocab loaded_ (norm for squad gen)')
108 | self.max_len = max_len if max_len is not None else int(1e12)
109 |
110 | def tokenize(self, text):
111 | split_tokens = []
112 | if self.do_basic_tokenize:
113 | for token in self.basic_tokenizer.tokenize(text):
114 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
115 | split_tokens.append(sub_token)
116 | else:
117 | split_tokens = self.wordpiece_tokenizer.tokenize(text)
118 | return split_tokens
119 |
120 | def convert_tokens_to_ids(self, tokens):
121 | """Converts a sequence of tokens into ids using the vocab."""
122 | ids = []
123 | for token in tokens:
124 | ids.append(self.vocab[token])
125 | if len(ids) > self.max_len:
126 | logger.warning(
127 | "Token indices sequence length is longer than the specified maximum "
128 | " sequence length for this BERT model ({} > {}). Running this"
129 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
130 | )
131 | return ids
132 |
133 | def convert_ids_to_tokens(self, ids):
134 | """Converts a sequence of ids in wordpiece tokens using the vocab."""
135 | tokens = []
136 | for i in ids:
137 | tokens.append(self.ids_to_tokens[i])
138 | return tokens
139 |
140 | def save_vocabulary(self, vocab_path):
141 | """Save the tokenizer vocabulary to a directory or file."""
142 | index = 0
143 | if os.path.isdir(vocab_path):
144 | vocab_file = os.path.join(vocab_path, VOCAB_NAME)
145 | with open(vocab_file, "w", encoding="utf-8") as writer:
146 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
147 | if index != token_index:
148 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
149 | " Please check that the vocabulary is not corrupted!".format(vocab_file))
150 | index = token_index
151 | writer.write(token + u'\n')
152 | index += 1
153 | return vocab_file
154 |
155 | @classmethod
156 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
157 | """
158 | Instantiate a PreTrainedBertModel from a pre-trained model file.
159 | Download and cache the pre-trained model file if needed.
160 | """
161 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
162 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
163 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
164 | logger.warning("The pre-trained model you are loading is a cased model but you have not set "
165 | "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
166 | "you may want to check this behavior.")
167 | kwargs['do_lower_case'] = False
168 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
169 | logger.warning("The pre-trained model you are loading is an uncased model but you have set "
170 | "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
171 | "but you may want to check this behavior.")
172 | kwargs['do_lower_case'] = True
173 | else:
174 | vocab_file = pretrained_model_name_or_path
175 | if os.path.isdir(vocab_file):
176 | vocab_file = os.path.join(vocab_file, VOCAB_NAME)
177 | # redirect to the cache, if necessary
178 | try:
179 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
180 | except EnvironmentError:
181 | logger.error(
182 | "Model name '{}' was not found in model name list ({}). "
183 | "We assumed '{}' was a path or url but couldn't find any file "
184 | "associated to this path or url.".format(
185 | pretrained_model_name_or_path,
186 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
187 | vocab_file))
188 | return None
189 | if resolved_vocab_file == vocab_file:
190 | logger.info("loading vocabulary file {}".format(vocab_file))
191 | else:
192 | logger.info("loading vocabulary file {} from cache at {}".format(
193 | vocab_file, resolved_vocab_file))
194 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
195 | # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
196 | # than the number of positional embeddings
197 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
198 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
199 | # Instantiate tokenizer.
200 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
201 | return tokenizer
202 |
203 |
204 | class BasicTokenizer(object):
205 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
206 |
207 | def __init__(self,
208 | do_lower_case=True,
209 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
210 | """Constructs a BasicTokenizer.
211 |
212 | Args:
213 | do_lower_case: Whether to lower case the input.
214 | """
215 | self.do_lower_case = do_lower_case
216 | self.never_split = never_split
217 |
218 | def tokenize(self, text):
219 | """Tokenizes a piece of text."""
220 | text = self._clean_text(text)
221 | # This was added on November 1st, 2018 for the multilingual and Chinese
222 | # models. This is also applied to the English models now, but it doesn't
223 | # matter since the English models were not trained on any Chinese data
224 | # and generally don't have any Chinese data in them (there are Chinese
225 | # characters in the vocabulary because Wikipedia does have some Chinese
226 | # words in the English Wikipedia).
227 | text = self._tokenize_chinese_chars(text)
228 | orig_tokens = whitespace_tokenize(text)
229 | split_tokens = []
230 | for token in orig_tokens:
231 | if self.do_lower_case and token not in self.never_split:
232 | token = token.lower()
233 | token = self._run_strip_accents(token)
234 | split_tokens.extend(self._run_split_on_punc(token))
235 |
236 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
237 | return output_tokens
238 |
239 | def _run_strip_accents(self, text):
240 | """Strips accents from a piece of text."""
241 | text = unicodedata.normalize("NFD", text)
242 | output = []
243 | for char in text:
244 | cat = unicodedata.category(char)
245 | if cat == "Mn":
246 | continue
247 | output.append(char)
248 | return "".join(output)
249 |
250 | def _run_split_on_punc(self, text):
251 | """Splits punctuation on a piece of text."""
252 | if text in self.never_split:
253 | return [text]
254 | chars = list(text)
255 | i = 0
256 | start_new_word = True
257 | output = []
258 | while i < len(chars):
259 | char = chars[i]
260 | if _is_punctuation(char):
261 | output.append([char])
262 | start_new_word = True
263 | else:
264 | if start_new_word:
265 | output.append([])
266 | start_new_word = False
267 | output[-1].append(char)
268 | i += 1
269 |
270 | return ["".join(x) for x in output]
271 |
272 | def _tokenize_chinese_chars(self, text):
273 | """Adds whitespace around any CJK character."""
274 | output = []
275 | for char in text:
276 | cp = ord(char)
277 | if self._is_chinese_char(cp):
278 | output.append(" ")
279 | output.append(char)
280 | output.append(" ")
281 | else:
282 | output.append(char)
283 | return "".join(output)
284 |
285 | def _is_chinese_char(self, cp):
286 | """Checks whether CP is the codepoint of a CJK character."""
287 | # This defines a "chinese character" as anything in the CJK Unicode block:
288 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
289 | #
290 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
291 | # despite its name. The modern Korean Hangul alphabet is a different block,
292 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
293 | # space-separated words, so they are not treated specially and handled
294 | # like all of the other languages.
295 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
296 | (cp >= 0x3400 and cp <= 0x4DBF) or #
297 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
298 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
299 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
300 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
301 | (cp >= 0xF900 and cp <= 0xFAFF) or #
302 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
303 | return True
304 |
305 | return False
306 |
307 | def _clean_text(self, text):
308 | """Performs invalid character removal and whitespace cleanup on text."""
309 | output = []
310 | for char in text:
311 | cp = ord(char)
312 | if cp == 0 or cp == 0xfffd or _is_control(char):
313 | continue
314 | if _is_whitespace(char):
315 | output.append(" ")
316 | else:
317 | output.append(char)
318 | return "".join(output)
319 |
320 |
321 | class WordpieceTokenizer(object):
322 | """Runs WordPiece tokenization."""
323 |
324 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
325 | self.vocab = vocab
326 | self.unk_token = unk_token
327 | self.max_input_chars_per_word = max_input_chars_per_word
328 |
329 | def tokenize(self, text):
330 | """Tokenizes a piece of text into its word pieces.
331 |
332 | This uses a greedy longest-match-first algorithm to perform tokenization
333 | using the given vocabulary.
334 |
335 | For example:
336 | input = "unaffable"
337 | output = ["un", "##aff", "##able"]
338 |
339 | Args:
340 | text: A single token or whitespace separated tokens. This should have
341 | already been passed through `BasicTokenizer`.
342 |
343 | Returns:
344 | A list of wordpiece tokens.
345 | """
346 |
347 | output_tokens = []
348 | for token in whitespace_tokenize(text):
349 | chars = list(token)
350 | if len(chars) > self.max_input_chars_per_word:
351 | output_tokens.append(self.unk_token)
352 | continue
353 |
354 | is_bad = False
355 | start = 0
356 | sub_tokens = []
357 | while start < len(chars):
358 | end = len(chars)
359 | cur_substr = None
360 | while start < end:
361 | substr = "".join(chars[start:end])
362 | if start > 0:
363 | substr = "##" + substr
364 | if substr in self.vocab:
365 | cur_substr = substr
366 | break
367 | end -= 1
368 | if cur_substr is None:
369 | is_bad = True
370 | break
371 | sub_tokens.append(cur_substr)
372 | start = end
373 |
374 | if is_bad:
375 | output_tokens.append(self.unk_token)
376 | else:
377 | output_tokens.extend(sub_tokens)
378 | return output_tokens
379 |
380 |
381 | def _is_whitespace(char):
382 | """Checks whether `chars` is a whitespace character."""
383 | # \t, \n, and \r are technically control characters but we treat them
384 | # as whitespace since they are generally considered as such.
385 | if char == " " or char == "\t" or char == "\n" or char == "\r":
386 | return True
387 | cat = unicodedata.category(char)
388 | if cat == "Zs":
389 | return True
390 | return False
391 |
392 |
393 | def _is_control(char):
394 | """Checks whether `chars` is a control character."""
395 | # These are technically control characters but we count them as whitespace
396 | # characters.
397 | if char == "\t" or char == "\n" or char == "\r":
398 | return False
399 | cat = unicodedata.category(char)
400 | if cat.startswith("C"):
401 | return True
402 | return False
403 |
404 |
405 | def _is_punctuation(char):
406 | """Checks whether `chars` is a punctuation character."""
407 | cp = ord(char)
408 | # We treat all non-letter/number ASCII as punctuation.
409 | # Characters such as "^", "$", and "`" are not in the Unicode
410 | # Punctuation class but we treat them as punctuation anyways, for
411 | # consistency.
412 | if ((33 <= cp <= 47) or (58 <= cp <= 64) or
413 | (91 <= cp <= 96) or (123 <= cp <= 126)):
414 | return True
415 | cat = unicodedata.category(char)
416 | if cat.startswith("P"):
417 | return True
418 | return False
419 |
420 |
--------------------------------------------------------------------------------
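End to end, `BertTokenizer` chains `BasicTokenizer` (cleanup, CJK spacing, punctuation splitting, optional lower casing) with `WordpieceTokenizer`, whose greedy longest-match-first pass produces pieces such as "unaffable" -> ["un", "##aff", "##able"]. A minimal usage sketch, assuming the package is importable and the `bert-base-uncased` vocabulary can be downloaded or is already cached:

from pytorch_pretrained_bert.tokenization import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
tokens = tokenizer.tokenize('Two dogs are playing in the snow.')
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)  # wordpiece tokens, lower cased
print(ids)     # vocabulary indices to feed into the BERT encoder
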
/rerank.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Re-ranking and ensemble implementation based on
3 | # "Matching Images and Text with Multi-modal Tensor Fusion and Re-ranking"
4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
7 | # Written by Keyu Wen, 2020
8 | # ------------------------------------------------------------
9 |
10 | import numpy as np
11 | import time
12 | import argparse
13 |
14 |
15 | def i2t_rerank(sim, K1, K2):  # sim: (d, 5d) similarity matrix; typically called as i2t_rerank(sim, 15, 1)
16 |
17 | size_i = sim.shape[0] # d
18 | size_t = sim.shape[1] # 5d
19 | sort_i2t = np.argsort(-sim, 1)
20 | sort_t2i = np.argsort(-sim, 0)
21 | sort_i2t_re = np.copy(sort_i2t)[:, :K1]
22 | address = np.array([])
23 |
24 | for i in range(size_i):
25 | for j in range(K1):
26 | result_t = sort_i2t[i][j]
27 | query = sort_t2i[:, result_t]
28 | # query = sort_t2i[:K2, result_t]
29 | address = np.append(address, np.where(query == i)[0][0])
30 |
31 | sort = np.argsort(address)
32 | sort_i2t_re[i] = sort_i2t_re[i][sort]
33 | address = np.array([])
34 |
35 | sort_i2t[:,:K1] = sort_i2t_re
36 |
37 | return sort_i2t
38 |
39 |
40 | def t2i_rerank(sim, K1, K2):
41 |
42 | size_i = sim.shape[0]
43 | size_t = sim.shape[1]
44 | sort_i2t = np.argsort(-sim, 1)
45 | sort_t2i = np.argsort(-sim, 0)
46 | sort_t2i_re = np.copy(sort_t2i)[:K1, :]
47 | address = np.array([])
48 |
49 | for i in range(size_t):
50 | for j in range(K1):
51 | result_i = sort_t2i[j][i]
52 | query = sort_i2t[result_i, :]
53 | # print(query)
54 | # query = sort_t2i[:K2, result_t]
55 | ranks = 1e20
56 | # for k in range(5):
57 | # qewfo = i//5 * 5 + k
58 | # print(np.where(query == i))
59 | tmp = np.where(query == i)[0][0]
60 | if tmp < ranks:
61 | ranks = tmp
62 | address = np.append(address, ranks)
63 |
64 | sort = np.argsort(address)
65 | sort_t2i_re[:, i] = sort_t2i_re[:, i][sort]
66 | address = np.array([])
67 |
68 | sort_t2i[:K1, :] = sort_t2i_re
69 |
70 | return sort_t2i
71 |
72 |
73 | def t2i_rerank_new(sim, sim_T, K1, K2):
74 |
75 | size_i = sim.shape[0]
76 | size_t = sim.shape[1]
77 | sort_i2t = np.argsort(-sim, 1)
78 | sort_t2i = np.argsort(-sim, 0)
79 | sort_t2i_re = np.copy(sort_t2i)[:K1, :]
80 |
81 | sort_t2t = np.argsort(-sim_T, 1)  # sort each row in descending order
82 | # print(sort_t2t.shape)
83 | sort_t2t_re = np.copy(sort_t2t)[:, :K2]
84 | address = np.array([])
85 |
86 | for i in range(size_t):
87 | for j in range(K1):
88 | result_i = sort_t2i[j][i] # Ij
89 | query = sort_i2t[result_i, :]  # text ranking for the j-th retrieved image
90 | # query = sort_t2i[:K2, result_t]
91 | ranks = 1e20
92 | G = sort_t2t_re[i]
93 | for k in range(K2):
94 | # qewfo = i//5 * 5 + k
95 | # print(qewfo)
96 | tmp = np.where(query == G[k])[0][0]
97 | if tmp < ranks:
98 | ranks = tmp
99 | address = np.append(address, ranks)
100 |
101 | sort = np.argsort(address)
102 | sort_t2i_re[:, i] = sort_t2i_re[:, i][sort]
103 | address = np.array([])
104 |
105 | sort_t2i[:K1, :] = sort_t2i_re
106 |
107 | return sort_t2i
108 |
109 |
110 | def acc_i2t2(input):
111 | """Computes the precision@k for the specified values of k of i2t"""
112 | #input = collect_match(input).numpy()
113 | image_size = input.shape[0]
114 | ranks = np.zeros(image_size)
115 | top1 = np.zeros(image_size)
116 |
117 | for index in range(image_size):
118 | inds = input[index]
119 | # Score
120 | # if index == 197:
121 | # print('s')
122 | rank = 1e20
123 | for i in range(5 * index, min(5 * index + 5, image_size*5), 1):
124 | tmp = np.where(inds == i)[0][0]
125 | if tmp < rank:
126 | rank = tmp
127 | ranks[index] = rank
128 | top1[index] = inds[0]
129 |
130 |
131 | # Compute metrics
132 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
133 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
134 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
135 | medr = np.floor(np.median(ranks)) + 1
136 | meanr = ranks.mean() + 1
137 |
138 | return (r1, r5, r10, medr, meanr), (ranks, top1)
139 |
140 |
141 | def acc_t2i2(input):
142 | """Computes the precision@k for the specified values of k of t2i"""
143 | #input = collect_match(input).numpy()
144 | image_size = input.shape[0]
145 | ranks = np.zeros(5*image_size)
146 | top1 = np.zeros(5*image_size)
147 |
148 | # --> (5N(caption), N(image))
149 | input = input.T
150 |
151 | for index in range(image_size):
152 | for i in range(5):
153 | inds = input[5 * index + i]
154 | ranks[5 * index + i] = np.where(inds == index)[0][0]
155 | top1[5 * index + i] = inds[0]
156 |
157 | # Compute metrics
158 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
159 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
160 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
161 | medr = np.floor(np.median(ranks)) + 1
162 | meanr = ranks.mean() + 1
163 |
164 | return (r1, r5, r10, medr, meanr), (ranks, top1)
165 |
166 |
167 | def main():
168 | parser = argparse.ArgumentParser()
169 | parser.add_argument('--data_name', default='coco', help='data name')
170 | parser.add_argument('--fold', action='store_true', help='fold5')
171 | opt = parser.parse_args()
172 | data = opt.data_name
173 | fold = opt.fold
174 | # The accuracy computing
175 | # Input the prediction similarity score matrix (d * 5d)
176 | if data == 'coco':
177 | if fold:
178 | path1 = ''
179 | path = 'coco_sims/'
180 | r1 = np.array((0,0,0))
181 | r1_t = np.array((0,0,0))
182 | r2 = np.array((0,0,0)) # rerank
183 | r2_t = np.array((0,0,0))
184 | for i in range(5):
185 | d1 = np.load(path1+'sims_full_%d.npy' % i)
186 | d2 = np.load(path+'sims_full_%d.npy' % i)
187 |
188 | # d1T = np.load(path1+'sims_full_T_%d.npy' % i)
189 | # d2T = np.load(path+'sims_full_T_%d.npy' % i)
190 |
191 | d = d1+d2
192 | # d_T = d1T+d2T
193 |
194 | t1 = time.time()
195 | # calculate the i2t score after rerank
196 | sort_rerank = i2t_rerank(d, 15, 1)
197 | (r1i, r5i, r10i, medri, meanri), _ = acc_i2t2(np.argsort(-d, 1))
198 | (r1i2, r5i2, r10i2, medri2, meanri2), _ = acc_i2t2(sort_rerank)
199 |
200 | print(r1i, r5i, r10i, medri, meanri)
201 | print(r1i2, r5i2, r10i2, medri2, meanri2)
202 | r1 = r1 + np.array((r1i, r5i, r10i))
203 | r2 = r2 + np.array((r1i2, r5i2, r10i2))
204 |
205 | # calculate the t2i score after rerank
206 | # sort_rerank = t2i_rerank(d, 20, 1)
207 | # sort_rerank = t2i_rerank_new(d, d_T, 20, 1)
208 | (r1t, r5t, r10t, medrt, meanrt), _ = acc_t2i2(np.argsort(-d, 0))
209 | # (r1t2, r5t2, r10t2, medrt2, meanrt2), _ = acc_t2i2(sort_rerank)
210 |
211 | print(r1t, r5t, r10t, medrt, meanrt)
212 | # print(r1t2, r5t2, r10t2, medrt2, meanrt2)
213 | # print((r1t, r5t, r10t))
214 | r1_t = r1_t + np.array((r1t, r5t, r10t))
215 | # r2_t = r2_t + np.array((r1t2, r5t2, r10t2))
216 | t2 = time.time()
217 | print(t2-t1)
218 | print('--------------------')
219 | print('5-cross test')
220 | print(r1/5)
221 | print(r1_t/5)
222 | print('rerank!')
223 | print(r2/5)
224 | # print(r2_t/5)
225 | else:
226 | path = 'coco_sims/'
227 | path1 = ''
228 | d1 = np.load(path+'sims_full_5k.npy')
229 | d2 = np.load(path1+'sims_full_5k.npy')
230 | d = d1 + d2
231 | t1 = time.time()
232 | # calculate the i2t score after rerank
233 | sort_rerank = i2t_rerank(d, 15, 1)
234 | (r1i, r5i, r10i, medri, meanri), _ = acc_i2t2(np.argsort(-d, 1))
235 | (r1i2, r5i2, r10i2, medri2, meanri2), _ = acc_i2t2(sort_rerank)
236 |
237 | print(r1i, r5i, r10i, medri, meanri)
238 | print(r1i2, r5i2, r10i2, medri2, meanri2)
239 |
240 | # calculate the t2i score after rerank
241 | # sort_rerank = t2i_rerank(d, 20, 1)
242 | # sort_rerank = t2i_rerank_new(d, d_T, 12, 1)
243 | (r1t, r5t, r10t, medrt, meanrt), _ = acc_t2i2(np.argsort(-d, 0))
244 | # (r1t2, r5t2, r10t2, medrt2, meanrt2), _ = acc_t2i2(sort_rerank)
245 |
246 | print(r1t, r5t, r10t, medrt, meanrt)
247 | # print(r1t2, r5t2, r10t2, medrt2, meanrt2)
248 | t2 = time.time()
249 | print(t2-t1)
250 |
251 | else:
252 |
253 | d1 = np.load('flickr_sims/sims_f.npy')
254 | d2 = np.load('sims_f.npy')
255 | d = d1+d2
256 | # d1T = np.load('flickr_sims/sims_f_T.npy')
257 | # d2T = np.load('sims_f_T.npy')
258 |
259 | # d_T = d1T+d2T
260 |
261 | t1 = time.time()
262 | # calculate the i2t score after rerank
263 | sort_rerank = i2t_rerank(d, 15, 1)
264 | (r1i, r5i, r10i, medri, meanri), _ = acc_i2t2(np.argsort(-d, 1))
265 | (r1i2, r5i2, r10i2, medri2, meanri2), _ = acc_i2t2(sort_rerank)
266 |
267 | print(r1i, r5i, r10i, medri, meanri)
268 | print(r1i2, r5i2, r10i2, medri2, meanri2)
269 |
270 |
271 | # calculate the t2i score after rerank
272 |
273 | # sort_rerank = t2i_rerank_new(d, d_T, 20, 4)
274 | (r1t, r5t, r10t, medrt, meanrt), _ = acc_t2i2(np.argsort(-d, 0))
275 | # (r1t2, r5t2, r10t2, medrt2, meanrt2), _ = acc_t2i2(sort_rerank)
276 |
277 | print(r1t, r5t, r10t, medrt, meanrt)
278 | # print(r1t2, r5t2, r10t2, medrt2, meanrt2)
279 | rsum = r1i+r5i+r10i+r1t+r5t+r10t
280 | print('rsum:%f' % rsum)
281 | rsum_rr = r1i2+r5i2+r10i2+r1t+r5t+r10t
282 | print('rsum_rr:%f' % rsum_rr)
283 | t2 = time.time()
284 | print(t2-t1)
285 |
286 | if __name__ == '__main__':
287 | main()
288 |
289 |
--------------------------------------------------------------------------------
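`i2t_rerank` takes a d x 5d image-to-text similarity matrix, keeps the top K1 captions for each image, and re-orders them by how highly each of those captions ranks the image in the reverse text-to-image direction; `acc_i2t2` and `acc_t2i2` then score the rankings under the convention that captions 5i..5i+4 belong to image i. A minimal sketch on a random matrix (shapes only, not real model scores), assuming it is run from the repository root:

import numpy as np
from rerank import i2t_rerank, acc_i2t2, acc_t2i2

d = 8                                    # number of images
sim = np.random.rand(d, 5 * d)           # image-to-text similarity scores

baseline = np.argsort(-sim, 1)           # plain i2t ranking
reranked = i2t_rerank(sim, 5, 1)         # re-order the top-5 captions per image

print(acc_i2t2(baseline)[0])             # (r1, r5, r10, medr, meanr)
print(acc_i2t2(reranked)[0])
print(acc_t2i2(np.argsort(-sim, 0))[0])
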
/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch
4 |
5 |
6 | RES_NET_file_path = 'resnet152-b121ed2d.pth'
7 |
8 |
9 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
10 | """3x3 convolution with padding"""
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
12 | padding=dilation, groups=groups, bias=False, dilation=dilation)
13 |
14 |
15 | def conv1x1(in_planes, out_planes, stride=1):
16 | """1x1 convolution"""
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | expansion = 1
22 |
23 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
24 | base_width=64, dilation=1, norm_layer=None):
25 | super(BasicBlock, self).__init__()
26 | if norm_layer is None:
27 | norm_layer = nn.BatchNorm2d
28 | if groups != 1 or base_width != 64:
29 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
30 | if dilation > 1:
31 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
32 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
33 | self.conv1 = conv3x3(inplanes, planes, stride)
34 | self.bn1 = norm_layer(planes)
35 | self.relu = nn.ReLU(inplace=True)
36 | self.conv2 = conv3x3(planes, planes)
37 | self.bn2 = norm_layer(planes)
38 | self.downsample = downsample
39 | self.stride = stride
40 |
41 | def forward(self, x):
42 | identity = x
43 |
44 | out = self.conv1(x)
45 | out = self.bn1(out)
46 | out = self.relu(out)
47 |
48 | out = self.conv2(out)
49 | out = self.bn2(out)
50 |
51 | if self.downsample is not None:
52 | identity = self.downsample(x)
53 |
54 | out += identity
55 | out = self.relu(out)
56 |
57 | return out
58 |
59 |
60 | class Bottleneck(nn.Module):
61 | expansion = 4
62 |
63 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
64 | base_width=64, dilation=1, norm_layer=None):
65 | super(Bottleneck, self).__init__()
66 | if norm_layer is None:
67 | norm_layer = nn.BatchNorm2d
68 | width = int(planes * (base_width / 64.)) * groups
69 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
70 | self.conv1 = conv1x1(inplanes, width)
71 | self.bn1 = norm_layer(width)
72 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
73 | self.bn2 = norm_layer(width)
74 | self.conv3 = conv1x1(width, planes * self.expansion)
75 | self.bn3 = norm_layer(planes * self.expansion)
76 | self.relu = nn.ReLU(inplace=True)
77 | self.downsample = downsample
78 | self.stride = stride
79 |
80 | def forward(self, x):
81 | identity = x
82 |
83 | out = self.conv1(x)
84 | out = self.bn1(out)
85 | out = self.relu(out)
86 |
87 | out = self.conv2(out)
88 | out = self.bn2(out)
89 | out = self.relu(out)
90 |
91 | out = self.conv3(out)
92 | out = self.bn3(out)
93 |
94 | if self.downsample is not None:
95 | identity = self.downsample(x)
96 |
97 | out += identity
98 | out = self.relu(out)
99 |
100 | return out
101 |
102 |
103 | class ResNet(nn.Module):
104 |
105 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
106 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
107 | norm_layer=None):
108 | super(ResNet, self).__init__()
109 | if norm_layer is None:
110 | norm_layer = nn.BatchNorm2d
111 | self._norm_layer = norm_layer
112 |
113 | self.inplanes = 64
114 | self.dilation = 1
115 | if replace_stride_with_dilation is None:
116 | # each element in the tuple indicates if we should replace
117 | # the 2x2 stride with a dilated convolution instead
118 | replace_stride_with_dilation = [False, False, False]
119 | if len(replace_stride_with_dilation) != 3:
120 | raise ValueError("replace_stride_with_dilation should be None "
121 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
122 | self.groups = groups
123 | self.base_width = width_per_group
124 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
125 | bias=False)
126 | self.bn1 = norm_layer(self.inplanes)
127 | self.relu = nn.ReLU(inplace=True)
128 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
129 | self.layer1 = self._make_layer(block, 64, layers[0])
130 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
131 | dilate=replace_stride_with_dilation[0])
132 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
133 | dilate=replace_stride_with_dilation[1])
134 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
135 | dilate=replace_stride_with_dilation[2])
136 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
137 | self.fc = nn.Linear(512 * block.expansion, num_classes)
138 |
139 | for m in self.modules():
140 | if isinstance(m, nn.Conv2d):
141 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
142 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
143 | nn.init.constant_(m.weight, 1)
144 | nn.init.constant_(m.bias, 0)
145 |
146 | # Zero-initialize the last BN in each residual branch,
147 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
148 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
149 | if zero_init_residual:
150 | for m in self.modules():
151 | if isinstance(m, Bottleneck):
152 | nn.init.constant_(m.bn3.weight, 0)
153 | elif isinstance(m, BasicBlock):
154 | nn.init.constant_(m.bn2.weight, 0)
155 |
156 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
157 | norm_layer = self._norm_layer
158 | downsample = None
159 | previous_dilation = self.dilation
160 | if dilate:
161 | self.dilation *= stride
162 | stride = 1
163 | if stride != 1 or self.inplanes != planes * block.expansion:
164 | downsample = nn.Sequential(
165 | conv1x1(self.inplanes, planes * block.expansion, stride),
166 | norm_layer(planes * block.expansion),
167 | )
168 |
169 | layers = []
170 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
171 | self.base_width, previous_dilation, norm_layer))
172 | self.inplanes = planes * block.expansion
173 | for _ in range(1, blocks):
174 | layers.append(block(self.inplanes, planes, groups=self.groups,
175 | base_width=self.base_width, dilation=self.dilation,
176 | norm_layer=norm_layer))
177 |
178 | return nn.Sequential(*layers)
179 |
180 | def forward(self, x):
181 | x = self.conv1(x)
182 | x = self.bn1(x)
183 | x = self.relu(x)
184 | x = self.maxpool(x)
185 |
186 | x1 = self.layer1(x)
187 | x2 = self.layer2(x1)
188 | x3 = self.layer3(x2)
189 | x4 = self.layer4(x3) # extract the output before avg pooling
190 | # print(x4.size())
191 |
192 | return x1,x2,x3,x4
193 |
194 |
195 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
196 | model = ResNet(block, layers, **kwargs)
197 | if pretrained:
198 | state_dict = torch.load(RES_NET_file_path)
199 | model.load_state_dict(state_dict)
200 | return model
201 |
202 |
203 | def resnet152(pretrained=False, progress=True, **kwargs):
204 | r"""ResNet-50 model from
205 | `"Deep Residual Learning for Image Recognition" '_
206 | Args:
207 | pretrained (bool): If True, returns a model pre-trained on ImageNet
208 | progress (bool): If True, displays a progress bar of the download to stderr
209 | """
210 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
211 | **kwargs)
--------------------------------------------------------------------------------
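Unlike the torchvision classifier, this ResNet's `forward` returns the outputs of all four stages instead of logits, so the pre-pooling `layer4` feature map is available as an image feature. A minimal shape check with randomly initialised weights (no `resnet152-b121ed2d.pth` checkpoint required):

import torch
from resnet import resnet152

model = resnet152(pretrained=False).eval()
with torch.no_grad():
    x1, x2, x3, x4 = model(torch.randn(1, 3, 224, 224))
print(x4.shape)  # expected torch.Size([1, 2048, 7, 7]) for a 224x224 input
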
/runs/BERT/bert_models:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/runs/GRU/gru_models:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/test_bert_cc.sh:
--------------------------------------------------------------------------------
1 | echo "BERT"
2 | echo "MSCOCO"
3 | echo "evalaute cc_model1"
4 | # python evaluation_bert.py --model BERT/cc_model1 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
5 | python evaluation_bert.py --model BERT/cc_model1_ --fold --data_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/data --region_bbox_file /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 --feature_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval
6 | mv sims_full_0.npy sims_full_1.npy sims_full_2.npy sims_full_3.npy sims_full_4.npy sims_full_5k.npy ./coco_sims
7 | # mv sims_f_T.npy ./flickr_sims
8 | echo "evalaute cc_model2"
9 | # python evaluation_bert.py --model BERT/cc_model2 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
10 | python evaluation_bert.py --model BERT/cc_model2_ --fold --data_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/data --region_bbox_file /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 --feature_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval
11 | echo "ensemble and rerank!"
12 | echo "fold5-1K"
13 | python rerank.py --data_name coco --fold
14 | echo "5K"
15 | python rerank.py --data_name coco
16 |
--------------------------------------------------------------------------------
/test_bert_f.sh:
--------------------------------------------------------------------------------
1 | echo "BERT"
2 | echo "Flickr30K"
3 | echo "evalaute f_model1"
4 | # python evaluation_bert.py --model BERT/f_model1 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
5 | python evaluation_bert.py --model BERT/f_model1_ --data_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/data --region_bbox_file /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 --feature_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/
6 | mv sims_f.npy ./flickr_sims
7 | mv sims_f_T.npy ./flickr_sims
8 | echo "evalaute f_model2"
9 | # python evaluation_bert.py --model BERT/f_model2 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
10 | python evaluation_bert.py --model BERT/f_model2_ --data_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/data --region_bbox_file /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 --feature_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/
11 | echo "ensemble and rerank!"
12 | python rerank.py --data_name f30k
13 |
--------------------------------------------------------------------------------
/test_gru_cc.sh:
--------------------------------------------------------------------------------
1 | echo "GRU"
2 | echo "MSCOCO"
3 | echo "evalaute cc_model1"
4 | # python evaluation_bert.py --model GRU/cc_model1 --fold --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
5 | python evaluation.py --model GRU/cc_model1 --fold --data_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/data --region_bbox_file /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 --feature_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval
6 | mv sims_full_0.npy sims_full_1.npy sims_full_2.npy sims_full_3.npy sims_full_4.npy sims_full_5k.npy ./coco_sims
7 | echo "evalaute cc_model2"
8 | # python evaluation_bert.py --model GRU/cc_model2 --fold --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
9 | python evaluation.py --model GRU/cc_model2 --fold --data_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/data --region_bbox_file /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/coco_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5 --feature_path /remote-home/lyli/Workspace/meltdown/burneddown/ECCV/joint-pretrain/COCO/region_feat_gvd_wo_bgd/feat_cls_1000/coco_detection_vg_100dets_gvd_checkpoint_trainval
10 | echo "ensemble and rerank!"
11 | echo "fold5-1K"
12 | python rerank.py --data_name coco --fold
13 | echo "5K"
14 | python rerank.py --data_name coco
15 |
--------------------------------------------------------------------------------
/test_gru_f.sh:
--------------------------------------------------------------------------------
1 |
2 | echo "GRU"
3 | echo "Flickr30K"
4 | echo "evalaute f_model1"
5 | # python evaluation_bert.py --model GRU/f_model1 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
6 | python evaluation.py --model fmodel_1_
7 | mv sims_f.npy ./flickr_sims
8 | mv sims_f_T.npy ./flickr_sims
9 | echo "evalaute f_model2"
10 | # python evaluation_bert.py --model GRU/f_model2 --data_path "$DATA_PATH" --region_bbox_file "$REGION_BBOX_FILE" --feature_path "$FEATURE_PATH"
11 | python evaluation.py --model fmodel_2_
12 | echo "ensemble and rerank!"
13 | python rerank.py --data_name f30k
14 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on
3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives"
4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
7 | # Written by Keyu Wen, 2020
8 | # ------------------------------------------------------------
9 |
10 | import pickle
11 | import os
12 | import time
13 | import shutil
14 |
15 | import torch
16 |
17 | import data
18 | from vocab import Vocabulary # NOQA
19 | from model import VSE
20 | from evaluation import i2t, t2i, AverageMeter, LogCollector, encode_data
21 | import numpy as np
22 | import logging
23 | import tensorboard_logger as tb_logger
24 |
25 | import argparse
26 |
27 |
28 | def main():
29 | # Hyper Parameters
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('--data_path', default='data',
32 | help='path to datasets')
33 | parser.add_argument('--data_name', default='f30k',
34 | help='{coco,f8k,f30k,10crop}_precomp|coco|f8k|f30k')
35 | parser.add_argument('--vocab_path', default='vocab',
36 | help='Path to saved vocabulary pickle files.')
37 | parser.add_argument('--margin', default=0.2, type=float,
38 | help='Rank loss margin.')
39 | parser.add_argument('--num_epochs', default=30, type=int,
40 | help='Number of training epochs.')
41 | parser.add_argument('--batch_size', default=128, type=int,
42 | help='Size of a training mini-batch.')
43 | parser.add_argument('--word_dim', default=300, type=int,
44 | help='Dimensionality of the word embedding.')
45 | parser.add_argument('--embed_size', default=1024, type=int,
46 | help='Dimensionality of the joint embedding.')
47 | parser.add_argument('--grad_clip', default=2., type=float,
48 | help='Gradient clipping threshold.')
49 | parser.add_argument('--crop_size', default=224, type=int,
50 | help='Size of an image crop as the CNN input.')
51 | parser.add_argument('--num_layers', default=1, type=int,
52 | help='Number of GRU layers.')
53 | parser.add_argument('--learning_rate', default=2e-4, type=float,
54 | help='Initial learning rate.')
55 | parser.add_argument('--lr_update', default=15, type=int,
56 | help='Number of epochs to update the learning rate.')
57 | parser.add_argument('--workers', default=10, type=int,
58 | help='Number of data loader workers.')
59 | parser.add_argument('--log_step', default=100, type=int,
60 | help='Number of steps to print and record the log.')
61 | parser.add_argument('--val_step', default=500, type=int,
62 | help='Number of steps to run validation.')
63 | parser.add_argument('--logger_name', default='runs/test',
64 | help='Path to save the model and Tensorboard log.')
65 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
66 | help='path to latest checkpoint (default: none)')
67 | parser.add_argument('--img_dim', default=2048, type=int,
68 | help='Dimensionality of the image embedding.')
69 | parser.add_argument('--finetune', action='store_true',
70 | help='Fine-tune the image encoder.')
71 | parser.add_argument('--use_restval', action='store_true',
72 | help='Use the restval data for training on MSCOCO.')
73 | parser.add_argument('--reset_train', action='store_true',
74 | help='Ensure the training is always done in '
75 | 'train mode (Not recommended).')
76 | parser.add_argument('--K', default=2, type=int,help='num of JSR.')
77 | parser.add_argument('--feature_path', default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/',
78 | type=str, help='path to the pre-computed image features')
79 | parser.add_argument('--region_bbox_file',
80 | default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5',
81 | type=str, help='path to the region_bbox_file(.h5)')
82 | opt = parser.parse_args()
83 | print(opt)
84 |
85 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
86 | tb_logger.configure(opt.logger_name, flush_secs=5)
87 |
88 | # Load Vocabulary Wrapper
89 | vocab = pickle.load(open(os.path.join(
90 | opt.vocab_path, '%s_vocab.pkl' % opt.data_name), 'rb'))
91 | opt.vocab_size = len(vocab)
92 |
93 | # Load data loaders
94 | train_loader, val_loader = data.get_loaders(
95 | opt.data_name, vocab, opt.crop_size, opt.batch_size, opt.workers, opt)
96 |
97 | # Construct the model
98 | model = VSE(opt)
99 | best_rsum = 0
100 | # optionally resume from a checkpoint
101 | if opt.resume:
102 | if os.path.isfile(opt.resume):
103 | print("=> loading checkpoint '{}'".format(opt.resume))
104 | checkpoint = torch.load(opt.resume)
105 | start_epoch = checkpoint['epoch']
106 | best_rsum = checkpoint['best_rsum']
107 | model.load_state_dict(checkpoint['model'])
108 | # Eiters is used to show logs as the continuation of another
109 | # training
110 | model.Eiters = checkpoint['Eiters']
111 | print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
112 | .format(opt.resume, start_epoch, best_rsum))
113 | validate(opt, val_loader, model)
114 | else:
115 | print("=> no checkpoint found at '{}'".format(opt.resume))
116 | del checkpoint
117 | # Train the Model
118 |
119 | for epoch in range(opt.num_epochs):
120 | adjust_learning_rate(opt, model.optimizer, epoch)
121 |
122 | # train for one epoch
123 | train(opt, train_loader, model, epoch, val_loader)
124 |
125 | # evaluate on validation set
126 | rsum = validate(opt, val_loader, model)
127 |
128 | # remember best R@ sum and save checkpoint
129 | is_best = rsum > best_rsum
130 | best_rsum = max(rsum, best_rsum)
131 | save_checkpoint({
132 | 'epoch': epoch + 1,
133 | 'model': model.state_dict(),
134 | 'best_rsum': best_rsum,
135 | 'opt': opt,
136 | 'Eiters': model.Eiters,
137 | }, is_best, epoch, prefix=opt.logger_name + '/')
138 |
139 |
140 | def train(opt, train_loader, model, epoch, val_loader):
141 | # average meters to record the training statistics
142 | batch_time = AverageMeter()
143 | data_time = AverageMeter()
144 | train_logger = LogCollector()
145 |
146 | # switch to train mode
147 | model.train_start()
148 |
149 | end = time.time()
150 | for i, train_data in enumerate(train_loader):
151 | if opt.reset_train:
152 | # Always reset to train mode, this is not the default behavior
153 | model.train_start()
154 |
155 | # measure data loading time
156 | data_time.update(time.time() - end)
157 |
158 | # make sure train logger is used
159 | model.logger = train_logger
160 |
161 | # Update the model
162 | model.train_emb(*train_data)
163 |
164 |
165 | # measure elapsed time
166 | batch_time.update(time.time() - end)
167 | end = time.time()
168 |
169 | # Print log info
170 | if model.Eiters % opt.log_step == 0:
171 | logging.info(
172 | 'Epoch: [{0}][{1}/{2}]\t'
173 | '{e_log}\t'
174 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
175 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
176 | .format(
177 | epoch, i, len(train_loader), batch_time=batch_time,
178 | data_time=data_time, e_log=str(model.logger)))
179 |
180 | # Record logs in tensorboard
181 | tb_logger.log_value('epoch', epoch, step=model.Eiters)
182 | tb_logger.log_value('step', i, step=model.Eiters)
183 | tb_logger.log_value('batch_time', batch_time.val, step=model.Eiters)
184 | tb_logger.log_value('data_time', data_time.val, step=model.Eiters)
185 | model.logger.tb_log(tb_logger, step=model.Eiters)
186 |
187 | # validate at every val_step
188 | # if model.Eiters % opt.val_step == 0:
189 | # validate(opt, val_loader, model)
190 |
191 |
192 | def validate(opt, val_loader, model):
193 | # compute the encoding for all the validation images and captions
194 | img_embs, cap_embs = encode_data(
195 | model, val_loader, opt.log_step, logging.info)
196 |
197 | # caption retrieval
198 | (r1, r5, r10, medr, meanr) = i2t(img_embs, cap_embs)
199 | logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" %
200 | (r1, r5, r10, medr, meanr))
201 | # image retrieval
202 | (r1i, r5i, r10i, medri, meanr) = t2i(
203 | img_embs, cap_embs)
204 | logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" %
205 | (r1i, r5i, r10i, medri, meanr))
206 | # sum of recalls to be used for early stopping
207 | currscore = r1 + r5 + r10 + r1i + r5i + r10i
208 |
209 | # record metrics in tensorboard
210 | tb_logger.log_value('r1', r1, step=model.Eiters)
211 | tb_logger.log_value('r5', r5, step=model.Eiters)
212 | tb_logger.log_value('r10', r10, step=model.Eiters)
213 | tb_logger.log_value('medr', medr, step=model.Eiters)
214 | tb_logger.log_value('meanr', meanr, step=model.Eiters)
215 | tb_logger.log_value('r1i', r1i, step=model.Eiters)
216 | tb_logger.log_value('r5i', r5i, step=model.Eiters)
217 | tb_logger.log_value('r10i', r10i, step=model.Eiters)
218 | tb_logger.log_value('medri', medri, step=model.Eiters)
219 | tb_logger.log_value('meanr', meanr, step=model.Eiters)
220 | tb_logger.log_value('rsum', currscore, step=model.Eiters)
221 |
222 | return currscore
223 |
224 |
225 | def save_checkpoint(state, is_best, epoch, filename='checkpoint.pth.tar', prefix=''):
226 | torch.save(state, prefix + filename)
227 | if is_best:
228 | shutil.copyfile(prefix + filename, prefix + 'model_best.pth.tar')
229 | # shutil.copyfile(prefix + filename, prefix + 'checkpoint'+str(epoch)+'.pth.tar')
230 |
231 |
232 | def adjust_learning_rate(opt, optimizer, epoch):
233 |     """Sets the learning rate to the initial LR
234 |        decayed by a factor of 10 every opt.lr_update epochs."""
235 | lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update))
236 | for param_group in optimizer.param_groups:
237 | param_group['lr'] = lr
238 |
239 |
240 | def accuracy(output, target, topk=(1,)):
241 | """Computes the precision@k for the specified values of k"""
242 | maxk = max(topk)
243 | batch_size = target.size(0)
244 |
245 | _, pred = output.topk(maxk, 1, True, True)
246 | pred = pred.t()
247 | correct = pred.eq(target.view(1, -1).expand_as(pred))
248 |
249 | res = []
250 | for k in topk:
251 | correct_k = correct[:k].view(-1).float().sum(0)
252 | res.append(correct_k.mul_(100.0 / batch_size))
253 | return res
254 |
255 |
256 | if __name__ == '__main__':
257 | main()
258 |
--------------------------------------------------------------------------------
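
Editor's note: the checkpoint dictionary written by save_checkpoint() in train.py above can be reloaded to resume training or to evaluate the best model. The sketch below mirrors the --resume branch of main(); the checkpoint path and the "from model import VSE" import are assumptions for illustration, not a verbatim excerpt of the repository.

# Minimal resume sketch (assumed path and import).
import torch
from model import VSE  # the GRU-based model trained by train.py (assumed import)

ckpt = torch.load('runs/GRU/gru_models/model_best.pth.tar')  # hypothetical --logger_name
opt = ckpt['opt']                      # the argparse namespace saved alongside the weights
model = VSE(opt)
model.load_state_dict(ckpt['model'])   # restore the trained weights
model.Eiters = ckpt['Eiters']          # restore the global iteration counter
start_epoch = ckpt['epoch']            # epoch to continue from
best_rsum = ckpt['best_rsum']          # best validation recall sum seen so far
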
/train_bert.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------
2 | # Dual Semantic Relations Attention Network (DSRAN) implementation based on
3 | # "VSE++: Improving Visual-Semantic Embeddings with Hard Negatives"
4 | # "Learning Dual Semantic Relations with Graph Attention for Image-Text Matching"
5 | # Keyu Wen, Xiaodong Gu, and Qingrong Cheng
6 | # IEEE Transactions on Circuits and Systems for Video Technology, 2020
7 | # Written by Keyu Wen, 2020
8 | # ------------------------------------------------------------
9 |
10 | import pickle
11 | import os
12 | import time
13 | import shutil
14 | import torch
15 | import data_bert as data
16 | from model_bert import VSE
17 | from evaluation_bert import i2t, t2i, AverageMeter, LogCollector, encode_data, simrank
18 | import numpy as np
19 | import logging
20 | import tensorboard_logger as tb_logger
21 | import argparse
22 |
23 |
24 | def main():
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('--data_path', default='data',
27 | help='path to datasets')
28 | parser.add_argument('--data_name', default='coco',
29 | help='{coco,f30k}')
30 | parser.add_argument('--margin', default=0.2, type=float,
31 | help='Rank loss margin.')
32 | parser.add_argument('--num_epochs', default=12, type=int,
33 | help='Number of training epochs.')
34 | parser.add_argument('--batch_size', default=128, type=int,
35 | help='Size of a training mini-batch.')
36 | parser.add_argument('--embed_size', default=1024, type=int,
37 | help='Dimensionality of the joint embedding.')
38 | parser.add_argument('--crop_size', default=224, type=int,
39 | help='Size of an image crop as the CNN input.')
40 | parser.add_argument('--learning_rate', default=2e-5, type=float,
41 | help='Initial learning rate.')
42 | parser.add_argument('--lr_update', default=6, type=int,
43 | help='Number of epochs to update the learning rate.')
44 | parser.add_argument('--workers', default=10, type=int,
45 | help='Number of data loader workers.')
46 | parser.add_argument('--log_step', default=100, type=int,
47 | help='Number of steps to print and record the log.')
48 | parser.add_argument('--logger_name', default='runs/grg',
49 | help='Path to save the model and Tensorboard log.')
50 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
51 | help='path to latest checkpoint (default: none)')
52 | parser.add_argument('--img_dim', default=2048, type=int,
53 | help='Dimensionality of the image embedding.')
54 | parser.add_argument('--ft_res', action='store_true',
55 | help='Fine-tune the image encoder.')
56 | parser.add_argument('--bert_path', default='uncased_L-12_H-768_A-12/',
57 | help='path of pre-trained BERT.')
58 | parser.add_argument('--ft_bert', action='store_true',
59 | help='Fine-tune the text encoder.')
60 | parser.add_argument('--bert_size', default=768, type=int,
61 | help='Dimensionality of the text embedding')
62 | parser.add_argument('--warmup', default=-1, type=float)
63 |     parser.add_argument('--K', default=2, type=int, help='num of JSR.')
64 | parser.add_argument('--feature_path', default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/trainval/',
65 | type=str, help='path to the pre-computed image features')
66 | parser.add_argument('--region_bbox_file',
67 | default='data/joint-pretrain/flickr30k/region_feat_gvd_wo_bgd/flickr30k_detection_vg_thresh0.2_feat_gvd_checkpoint_trainvaltest.h5',
68 | type=str, help='path to the region_bbox_file(.h5)')
69 | opt = parser.parse_args()
70 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
71 | tb_logger.configure(opt.logger_name, flush_secs=5)
72 |
73 | train_loader, val_loader = data.get_loaders(opt.data_name, opt.batch_size, opt.workers, opt)
74 | opt.l_train = len(train_loader)
75 | print(opt)
76 | model = VSE(opt)
77 | best_rsum = 0
78 | if opt.resume:
79 | if os.path.isfile(opt.resume):
80 | print("=> loading checkpoint '{}'".format(opt.resume))
81 | checkpoint = torch.load(opt.resume)
82 | start_epoch = checkpoint['epoch']
83 | best_rsum = checkpoint['best_rsum']
84 | model.load_state_dict(checkpoint['model'])
85 | model.Eiters = checkpoint['Eiters']
86 | print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
87 | .format(opt.resume, start_epoch, best_rsum))
88 | validate(opt, val_loader, model)[-1]
89 | else:
90 | print("=> no checkpoint found at '{}'".format(opt.resume))
91 |
92 | for epoch in range(opt.num_epochs):
93 |
94 | adjust_learning_rate(opt, model.optimizer, epoch)
95 |
96 | train(opt, train_loader, model, epoch, val_loader)
97 |
98 | rsum = validate(opt, val_loader, model)[-1]
99 |
100 | is_best = rsum > best_rsum
101 | best_rsum = max(rsum, best_rsum)
102 | save_checkpoint({
103 | 'epoch': epoch + 1,
104 | 'model': model.state_dict(),
105 | 'best_rsum': best_rsum,
106 | 'opt': opt,
107 | 'Eiters': model.Eiters,
108 | }, is_best, epoch, prefix=opt.logger_name + '/')
109 |
110 |
111 | def train(opt, train_loader, model, epoch, val_loader):
112 |
113 | batch_time = AverageMeter()
114 | data_time = AverageMeter()
115 | train_logger = LogCollector()
116 |
117 | model.train_start()
118 |
119 | end = time.time()
120 | for i, train_data in enumerate(train_loader):
121 |
122 | data_time.update(time.time() - end)
123 | model.logger = train_logger
124 | model.train_emb(*train_data)
125 | batch_time.update(time.time() - end)
126 | end = time.time()
127 |
128 | if model.Eiters % opt.log_step == 0:
129 | logging.info(
130 | 'Epoch: [{0}][{1}/{2}]\t'
131 | '{e_log}\t'
132 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
133 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
134 | .format(
135 | epoch, i, len(train_loader), batch_time=batch_time,
136 | data_time=data_time, e_log=str(model.logger)))
137 |
138 | tb_logger.log_value('epoch', epoch, step=model.Eiters)
139 | tb_logger.log_value('step', i, step=model.Eiters)
140 | tb_logger.log_value('batch_time', batch_time.val, step=model.Eiters)
141 | tb_logger.log_value('data_time', data_time.val, step=model.Eiters)
142 | model.logger.tb_log(tb_logger, step=model.Eiters)
143 |
144 |
145 | def validate(opt, val_loader, model):
146 | _, _, sims = encode_data(
147 | model, val_loader, opt.log_step, logging.info)
148 | rs = simrank(sims)
149 | del sims
150 | return rs
151 |
152 |
153 | def save_checkpoint(state, is_best, epoch, filename='checkpoint.pth.tar', prefix=''):
154 | torch.save(state, prefix + filename)
155 | if is_best:
156 | shutil.copyfile(prefix + filename, prefix + 'model_best.pth.tar')
157 |
158 |
159 | def adjust_learning_rate(opt, optimizer, epoch):
160 | lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update))
161 | for param_group in optimizer.param_groups:
162 | param_group['lr'] = lr
163 |
164 |
165 | if __name__ == '__main__':
166 | main()
167 |
--------------------------------------------------------------------------------
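
Editor's note: a small worked example (not part of the repository) of the step decay implemented by adjust_learning_rate() in train_bert.py, using that script's defaults (learning_rate=2e-5, lr_update=6, num_epochs=12). The helper name decayed_lr is hypothetical.

def decayed_lr(base_lr, epoch, lr_update):
    # the learning rate is multiplied by 0.1 once every lr_update epochs
    return base_lr * (0.1 ** (epoch // lr_update))

for epoch in range(12):
    print(epoch, decayed_lr(2e-5, epoch, 6))
# epochs 0-5  -> 2e-05 (base rate)
# epochs 6-11 -> 2e-06 (one decay step)

With the default 12-epoch schedule the learning rate therefore decays exactly once, halfway through training.
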
/uncased_L-12_H-768_A-12/bert_pretrained_model:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/vocab.py:
--------------------------------------------------------------------------------
1 | # Create a vocabulary wrapper
2 | import nltk
3 | import pickle
4 | from collections import Counter
5 | # from pycocotools.coco import COCO
6 | import json
7 | import argparse
8 | import os
9 | from nltk.stem import WordNetLemmatizer
10 |
11 |
12 | annotations = {
13 | 'coco_precomp': ['train_caps.txt', 'dev_caps.txt'],
14 | 'coco': ['annotations/captions_train2014.json',
15 | 'annotations/captions_val2014.json'],
16 | 'f8k_precomp': ['train_caps.txt', 'dev_caps.txt'],
17 | '10crop_precomp': ['train_caps.txt', 'dev_caps.txt'],
18 | 'f30k_precomp': ['train_caps.txt', 'dev_caps.txt'],
19 | 'f8k': ['dataset_flickr8k.json'],
20 | 'f30k': ['dataset_flickr30k.json'],
21 | }
22 |
23 |
24 | class Vocabulary(object):
25 | """Simple vocabulary wrapper."""
26 |
27 | def __init__(self):
28 | self.word2idx = {}
29 | self.idx2word = {}
30 | self.idx = 0
31 |
32 | def add_word(self, word):
33 | if word not in self.word2idx:
34 | self.word2idx[word] = self.idx
35 | self.idx2word[self.idx] = word
36 | self.idx += 1
37 |
38 | def __call__(self, word):
39 | if word not in self.word2idx:
40 |             return self.word2idx['<unk>']
41 | return self.word2idx[word]
42 |
43 | def __len__(self):
44 | return len(self.word2idx)
45 |
46 |
47 | def from_coco_json(path):
48 |     coco = COCO(path)  # requires pycocotools; uncomment the COCO import at the top of this file
49 | ids = coco.anns.keys()
50 | captions = []
51 | for i, idx in enumerate(ids):
52 | captions.append(str(coco.anns[idx]['caption']))
53 |
54 | return captions
55 |
56 |
57 | def from_flickr_json(path):
58 | dataset = json.load(open(path, 'r'))['images']
59 | captions = []
60 | for i, d in enumerate(dataset):
61 | captions += [str(x['raw']) for x in d['sentences']]
62 |
63 | return captions
64 |
65 |
66 | def from_txt(txt):
67 | captions = []
68 |     with open(txt, 'r', encoding='utf-8') as f:
69 | for line in f:
70 | captions.append(line.strip())
71 | return captions
72 |
73 |
74 | def build_vocab(data_path, data_name, jsons, threshold):
75 | """Build a simple vocabulary wrapper."""
76 | counter = Counter()
77 | for path in jsons[data_name]:
78 | full_path = os.path.join(os.path.join(data_path, data_name), path)
79 | if data_name == 'coco':
80 | captions = from_coco_json(full_path)
81 | elif data_name == 'f8k' or data_name == 'f30k':
82 | captions = from_flickr_json(full_path)
83 | else:
84 | captions = from_txt(full_path)
85 | for i, caption in enumerate(captions):
86 | tokens = nltk.tokenize.word_tokenize(
87 | caption.lower().encode('utf-8').decode('utf-8'))
88 | counter.update(tokens)
89 |
90 | if i % 1000 == 0:
91 |             print("\r[%d/%d] tokenized the captions." % (i, len(captions)), end='')
92 |     # Discard words whose occurrence count is less than the threshold.
93 | words = []
94 | counts = []
95 | for word, cnt in counter.items():
96 | if cnt >= threshold:
97 | words.append(word)
98 | counts.append((word, cnt))
99 | # words = [word for word, cnt in counter.items() if cnt >= threshold]
100 | counts_new = sorted(counts, key=lambda x:x[1], reverse=True)
101 | print(counts_new)
102 | # Create a vocab wrapper and add some special tokens.
103 | vocab = Vocabulary()
104 |     vocab.add_word('<pad>')
105 |     vocab.add_word('<start>')
106 |     vocab.add_word('<end>')
107 |     vocab.add_word('<unk>')
108 | print(len(counts_new))
109 | # Add words to the vocabulary.
110 |     chosen_nums = 256  # keep only the 256 most frequent of the words that passed the threshold
111 | for i, word_cnt in enumerate(counts_new):
112 | word, count = word_cnt
113 | # print(word)
114 | if i < chosen_nums:
115 | vocab.add_word(word)
116 | print(word)
117 | return vocab
118 |
119 |
120 | def main(data_path, data_name):
121 | vocab = build_vocab(data_path, data_name, jsons=annotations, threshold=300)
122 | with open('%s_vocab.pkl' % data_name, 'wb') as f:
123 | pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
124 | # print("Saved vocabulary file to ", '%s_vocab.pkl' % data_name)
125 |
126 |
127 | if __name__ == '__main__':
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('--data_path', default='data')
130 | parser.add_argument('--data_name', default='f30k',
131 | help='{coco,f8k,f30k,10crop}_precomp|coco|f8k|f30k')
132 | opt = parser.parse_args()
133 | main(opt.data_path, opt.data_name)
134 |
--------------------------------------------------------------------------------
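
Editor's note: a minimal usage sketch (not part of the repository) of the Vocabulary wrapper defined in vocab.py above. It assumes the special-token names restored in build_vocab() (<pad>, <start>, <end>, <unk>); the toy words are arbitrary.

from vocab import Vocabulary  # assumes vocab.py is importable from the working directory

v = Vocabulary()
for tok in ('<pad>', '<start>', '<end>', '<unk>'):
    v.add_word(tok)
v.add_word('dog')

print(len(v))        # 5: four special tokens plus 'dog'
print(v('dog'))      # 4, indices are assigned in insertion order
print(v('unseen'))   # 3, out-of-vocabulary words fall back to the <unk> index

Vocabularies built by main() are pickled as <data_name>_vocab.pkl and are presumably the files shipped under vocab/ below.
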
/vocab/10crop_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/10crop_precomp_vocab.pkl
--------------------------------------------------------------------------------
/vocab/111:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/vocab/coco_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/coco_precomp_vocab.pkl
--------------------------------------------------------------------------------
/vocab/coco_resnet_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/coco_resnet_precomp_vocab.pkl
--------------------------------------------------------------------------------
/vocab/coco_vgg_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/coco_vgg_precomp_vocab.pkl
--------------------------------------------------------------------------------
/vocab/coco_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/coco_vocab.pkl
--------------------------------------------------------------------------------
/vocab/f30k_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/f30k_precomp_vocab.pkl
--------------------------------------------------------------------------------
/vocab/f30k_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/f30k_vocab.pkl
--------------------------------------------------------------------------------
/vocab/f8k_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/f8k_precomp_vocab.pkl
--------------------------------------------------------------------------------
/vocab/f8k_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kywen1119/DSRAN/80c7bc058a8ca327bb7995c4c21765d1a100a3ee/vocab/f8k_vocab.pkl
--------------------------------------------------------------------------------