├── .idea
│   ├── caption.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── caption.py
├── create_input_files.py
├── datasets.py
├── eval.py
├── img
│   ├── att.png
│   ├── babycake.png
│   ├── beam_search.png
│   ├── bikefence.png
│   ├── biketrain.png
│   ├── birds.png
│   ├── boats.png
│   ├── catbanana.png
│   ├── decoder_att.png
│   ├── decoder_no_att.png
│   ├── dogtie.png
│   ├── doublystochastic.png
│   ├── encoder.png
│   ├── firehydrant.png
│   ├── manbike.png
│   ├── model.png
│   ├── plane.png
│   ├── salad.png
│   ├── sheep.png
│   ├── sorted.jpg
│   ├── sorted2.jpg
│   ├── tommy.png
│   └── weights.png
├── models.py
├── train.py
└── utils.py
/.idea/caption.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Sagar Vinodababu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is a **[PyTorch](https://pytorch.org) Tutorial to Image Captioning**.
2 |
3 | This is the first in [a series of tutorials](https://github.com/sgrvinod/Deep-Tutorials-for-PyTorch) I'm writing about _implementing_ cool models on your own with the amazing PyTorch library.
4 |
5 | Basic knowledge of PyTorch, convolutional and recurrent neural networks is assumed.
6 |
7 | If you're new to PyTorch, first read [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html) and [Learning PyTorch with Examples](https://pytorch.org/tutorials/beginner/pytorch_with_examples.html).
8 |
9 | Questions, suggestions, or corrections can be posted as issues.
10 |
11 | I'm using `PyTorch 0.4` in `Python 3.6`.
12 |
13 | ---
14 |
15 | **27 Jan 2020**: Working code for two new tutorials has been added — [Super-Resolution](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution) and [Machine Translation](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Machine-Translation)
16 |
17 | ---
18 |
19 | # Contents
20 |
21 | [***Objective***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#objective)
22 |
23 | [***Concepts***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#concepts)
24 |
25 | [***Overview***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#overview)
26 |
27 | [***Implementation***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#implementation)
28 |
29 | [***Training***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#training)
30 |
31 | [***Inference***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#inference)
32 |
33 | [***Frequently Asked Questions***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#faqs)
34 |
35 | # Objective
36 |
37 | **To build a model that can generate a descriptive caption for an image we provide it.**
38 |
39 | In the interest of keeping things simple, let's implement the [_Show, Attend, and Tell_](https://arxiv.org/abs/1502.03044) paper. This is by no means the current state-of-the-art, but is still pretty darn amazing. The authors' original implementation can be found [here](https://github.com/kelvinxu/arctic-captions).
40 |
41 | This model learns _where_ to look.
42 |
43 | As you generate a caption, word by word, you can see the model's gaze shifting across the image.
44 |
45 | This is possible because of its _Attention_ mechanism, which allows it to focus on the part of the image most relevant to the word it is going to utter next.
46 |
47 | Here are some captions generated on _test_ images not seen during training or validation:
48 |
49 | ---
50 |
51 | 
52 |
53 | ---
54 |
55 | 
56 |
57 | ---
58 |
59 | 
60 |
61 | ---
62 |
63 | 
64 |
65 | ---
66 |
67 | 
68 |
69 | ---
70 |
71 | 
72 |
73 | ---
74 |
75 | There are more examples at the [end of the tutorial](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#some-more-examples).
76 |
77 | ---
78 |
79 | # Concepts
80 |
81 | * **Image captioning**. duh.
82 |
83 | * **Encoder-Decoder architecture**. Typically, a model that generates sequences will use an Encoder to encode the input into a fixed form and a Decoder to decode it, word by word, into a sequence.
84 |
85 | * **Attention**. The use of Attention networks is widespread in deep learning, and with good reason. This is a way for a model to choose only those parts of the encoding that it thinks are relevant to the task at hand. The same mechanism you see employed here can be used in any model where the Encoder's output has multiple points in space or time. In image captioning, you consider some pixels more important than others. In sequence to sequence tasks like machine translation, you consider some words more important than others.
86 |
87 | * **Transfer Learning**. This is when you borrow from an existing model by using parts of it in a new model. This is almost always better than training a new model from scratch (i.e., knowing nothing). As you will see, you can always fine-tune this second-hand knowledge to the specific task at hand. Using pretrained word embeddings is a dumb but valid example. For our image captioning problem, we will use a pretrained Encoder, and then fine-tune it as needed.
88 |
89 | * **Beam Search**. This is where you don't let your Decoder be lazy and simply choose the words with the _best_ score at each decode-step. Beam Search is useful for any language modeling problem because it looks for the sequence with the best _overall_ score, rather than greedily committing to the highest-scoring word at each step.
90 |
91 | # Overview
92 |
93 | In this section, I will present an overview of this model. If you're already familiar with it, you can skip straight to the [Implementation](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#implementation) section or the commented code.
94 |
95 | ### Encoder
96 |
97 | The Encoder **encodes the input image with 3 color channels into a smaller image with "learned" channels**.
98 |
99 | This smaller encoded image is a summary representation of all that's useful in the original image.
100 |
101 | Since we want to encode images, we use Convolutional Neural Networks (CNNs).
102 |
103 | We don't need to train an encoder from scratch. Why? Because there are already CNNs trained to represent images.
104 |
105 | For years, people have been building models that are extraordinarily good at classifying an image into one of a thousand categories. It stands to reason that these models capture the essence of an image very well.
106 |
107 | I have chosen to use the **101 layered Residual Network trained on the ImageNet classification task**, already available in PyTorch. As stated earlier, this is an example of Transfer Learning. You have the option of fine-tuning it to improve performance.
108 |
109 | 
110 |
111 | These models progressively create smaller and smaller representations of the original image, and each subsequent representation is more "learned", with a greater number of channels. The final encoding produced by our ResNet-101 encoder has a size of 14x14 with 2048 channels, i.e., a `2048, 14, 14` size tensor.
112 |
113 | I encourage you to experiment with other pre-trained architectures. The paper uses a VGGnet, also pretrained on ImageNet, but without fine-tuning. Either way, modifications are necessary. Since the last layer or two of these models are linear layers coupled with softmax activation for classification, we strip them away.
114 |
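A rough sketch of the idea (the `Encoder` class in `models.py` is the place to look for the real details, such as the `fine_tune()` method):

```python
import torch
import torch.nn as nn
import torchvision

# Keep only ResNet-101's convolutional layers and resize the final feature map
# to 14x14 with adaptive average pooling.
resnet = torchvision.models.resnet101(pretrained=True)
conv_layers = list(resnet.children())[:-2]  # drop the final pooling and linear layers
encoder = nn.Sequential(*conv_layers, nn.AdaptiveAvgPool2d((14, 14)))

with torch.no_grad():
    dummy_images = torch.randn(2, 3, 256, 256)  # a (normalized) batch of 2 images
    encoding = encoder(dummy_images)
print(encoding.shape)                           # torch.Size([2, 2048, 14, 14])
```
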
115 | ### Decoder
116 |
117 | The Decoder's job is to **look at the encoded image and generate a caption word by word**.
118 |
119 | Since it's generating a sequence, it would need to be a Recurrent Neural Network (RNN). We will use an LSTM.
120 |
121 | In a typical setting without Attention, you could simply average the encoded image across all pixels. You could then feed this, with or without a linear transformation, into the Decoder as its first hidden state and generate the caption. Each predicted word is used to generate the next word.
122 |
123 | 
124 |
125 | In a setting _with_ Attention, we want the Decoder to be able to **look at different parts of the image at different points in the sequence**. For example, while generating the word `football` in `a man holds a football`, the Decoder would know to focus on – you guessed it – the football!
126 |
127 | 
128 |
129 | Instead of the simple average, we use the _weighted_ average across all pixels, with the weights of the important pixels being greater. This weighted representation of the image can be concatenated with the previously generated word at each step to generate the next word.
130 |
131 | ### Attention
132 |
133 | The Attention network **computes these weights**.
134 |
135 | Intuitively, how would you estimate the importance of a certain part of an image? You would need to be aware of the sequence you have generated _so far_, so you can look at the image and decide what needs describing next. For example, after you mention `a man`, it is logical to declare that he is `holding a football`.
136 |
137 | This is exactly what the Attention mechanism does – it considers the sequence generated thus far, and _attends_ to the part of the image that needs describing next.
138 |
139 | 
140 |
141 | We will use _soft_ Attention, where the weights of the pixels add up to 1. If there are `P` pixels in our encoded image, then at each timestep `t` –
142 |
143 |
144 | `alpha_{1,t} + alpha_{2,t} + ... + alpha_{P,t} = 1`
145 |
146 |
147 | You could interpret this entire process as computing the **probability that a pixel is _the_ place to look to generate the next word**.
148 |
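In tensor terms, this amounts to a softmax over per-pixel scores followed by a weighted sum. A minimal sketch, with random scores standing in for the Attention network described in the Implementation section:

```python
import torch
import torch.nn.functional as F

N, num_pixels, encoder_dim = 4, 14 * 14, 2048           # P = 196 pixels in the encoded image

encoder_out = torch.randn(N, num_pixels, encoder_dim)   # flattened encoded images
scores = torch.randn(N, num_pixels)                     # unnormalized relevance of each pixel

alpha = F.softmax(scores, dim=1)                        # weights sum to 1 over the P pixels
context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # weighted average, (N, encoder_dim)

print(alpha.sum(dim=1))                                 # all (approximately) ones
print(context.shape)                                    # torch.Size([4, 2048])
```
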
149 | ### Putting it all together
150 |
151 | It might be clear by now what our combined network looks like.
152 |
153 | 
154 |
155 | - Once the Encoder generates the encoded image, we transform the encoding to create the initial hidden state `h` (and cell state `C`) for the LSTM Decoder.
156 | - At each decode step,
157 | - the encoded image and the previous hidden state are used to generate weights for each pixel in the Attention network.
158 | - the previously generated word and the weighted average of the encoding are fed to the LSTM Decoder to generate the next word.
159 |
160 | ### Beam Search
161 |
162 | We use a linear layer to transform the Decoder's output into a score for each word in the vocabulary.
163 |
164 | The straightforward – and greedy – option would be to choose the word with the highest score and use it to predict the next word. But this is not optimal because the rest of the sequence hinges on that first word you choose. If that choice isn't the best, everything that follows is sub-optimal. And it's not just the first word – each word in the sequence has consequences for the ones that succeed it.
165 |
166 | It might very well happen that if you'd chosen the _third_ best word at that first step, and the _second_ best word at the second step, and so on... _that_ would be the best sequence you could generate.
167 |
168 | It would be best if we could somehow _not_ decide until we've finished decoding completely, and **choose the sequence that has the highest _overall_ score from a basket of candidate sequences**.
169 |
170 | Beam Search does exactly this.
171 |
172 | - At the first decode step, consider the top `k` candidates.
173 | - Generate `k` second words for each of these `k` first words.
174 | - Choose the top `k` [first word, second word] combinations considering additive scores.
175 | - For each of these `k` second words, choose `k` third words, and then choose the top `k` [first word, second word, third word] combinations.
176 | - Repeat at each decode step.
177 | - After `k` sequences terminate, choose the sequence with the best overall score.
178 |
179 | 
180 |
181 | As you can see, some sequences (struck out) may fail early, as they don't make it to the top `k` at the next step. Once `k` sequences (underlined) generate the `<end>` token, we choose the one with the highest score.
182 |
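Here is a toy sketch of that bookkeeping, with a made-up `step_scores()` function standing in for the Decoder; the real, batched, attention-aware version lives in `caption.py`:

```python
import torch

vocab_size, k, max_len, end_token = 10, 3, 20, 0

def step_scores(seqs):
    # Stand-in for the Decoder: random log-probabilities over a tiny vocabulary
    return torch.log_softmax(torch.randn(len(seqs), vocab_size), dim=1)

candidates = [([1], 0.0)]   # (sequence, additive log-prob score); token 1 plays the role of <start>
completed = []

for _ in range(max_len):
    scores = step_scores([seq for seq, _ in candidates])
    expanded = [(seq + [w], total + row[w].item())
                for (seq, total), row in zip(candidates, scores)
                for w in range(vocab_size)]
    expanded.sort(key=lambda c: c[1], reverse=True)        # rank by overall additive score
    candidates = []
    for seq, total in expanded[:k]:                        # keep the top k combinations
        (completed if seq[-1] == end_token else candidates).append((seq, total))
    if not candidates:                                     # all surviving sequences terminated
        break

best_seq, best_score = max(completed + candidates, key=lambda c: c[1])
print(best_seq, best_score)
```
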
183 | # Implementation
184 |
185 | The sections below briefly describe the implementation.
186 |
187 | They are meant to provide some context, but **details are best understood directly from the code**, which is quite heavily commented.
188 |
189 | ### Dataset
190 |
191 | I'm using the MSCOCO '14 Dataset. You'd need to download the [Training (13GB)](http://images.cocodataset.org/zips/train2014.zip) and [Validation (6GB)](http://images.cocodataset.org/zips/val2014.zip) images.
192 |
193 | We will use [Andrej Karpathy's training, validation, and test splits](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip). This zip file contains the captions. You will also find splits and captions for the Flickr8k and Flickr30k datasets, so feel free to use these instead of MSCOCO if the latter is too large for your computer.
194 |
195 | ### Inputs to model
196 |
197 | We will need three inputs.
198 |
199 | #### Images
200 |
201 | Since we're using a pretrained Encoder, we would need to process the images into the form this pretrained Encoder is accustomed to.
202 |
203 | Pretrained ImageNet models are available as part of PyTorch's `torchvision` module. [This page](https://pytorch.org/docs/master/torchvision/models.html) details the preprocessing or transformation we need to perform – pixel values must be in the range [0,1] and we must then normalize the image by the mean and standard deviation of the ImageNet images' RGB channels.
204 |
205 | ```python
206 | mean = [0.485, 0.456, 0.406]
207 | std = [0.229, 0.224, 0.225]
208 | ```
209 | Also, PyTorch follows the NCHW convention, which means the channels dimension (C) must precede the size dimensions.
210 |
211 | We will resize all MSCOCO images to 256x256 for uniformity.
212 |
213 | Therefore, **images fed to the model must be a `Float` tensor of dimension `N, 3, 256, 256`**, and must be normalized by the aforesaid mean and standard deviation. `N` is the batch size.
214 |
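Put together, the preprocessing could look like this sketch (the image path is a placeholder; in the actual pipeline the resizing and scaling to [0, 1] happen while building the HDF5 files, and only the normalization is applied as a `transform` in the `Dataset`):

```python
import torchvision.transforms as transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Resize((256, 256)),                      # uniform size
    transforms.ToTensor(),                              # HWC, [0, 255] -> CHW float, [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),    # ImageNet statistics
])

img = transform(Image.open('path/to/image.jpg').convert('RGB'))  # (3, 256, 256)
batch = img.unsqueeze(0)                                         # (N, 3, 256, 256), here N = 1
```
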
215 | #### Captions
216 |
217 | Captions are both the targets and the inputs of the Decoder, as each word is used to generate the next word.
218 |
219 | To generate the first word, however, we need a *zeroth* word, `<start>`.
220 |
221 | At the last word, the Decoder must learn to predict `<end>`, the end of a caption. This is necessary because we need to know when to stop decoding during inference.
222 |
223 | `<start> a man holds a football <end>`
224 |
225 | Since we pass the captions around as fixed size Tensors, we need to pad captions (which are naturally of varying length) to the same length with `<pad>` tokens.
226 |
227 | `<start> a man holds a football <end> <pad> <pad> <pad>....`
228 |
229 | Furthermore, we create a `word_map` which is an index mapping for each word in the corpus, including the `<start>`, `<end>`, and `<pad>` tokens. PyTorch, like other libraries, needs words encoded as indices to look up embeddings for them or to identify their place in the predicted word scores.
230 |
231 | `9876 1 5 120 1 5406 9877 9878 9878 9878....`
232 |
233 | Therefore, **captions fed to the model must be an `Int` tensor of dimension `N, L`** where `L` is the padded length.
234 |
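For illustration, here is how a (hypothetical, tiny) `word_map` turns a caption into such a padded index sequence; the real map is built by `create_input_files()`, and the exact padding length there depends on `max_len`:

```python
word_map = {'<start>': 9876, '<end>': 9877, '<pad>': 9878, '<unk>': 9879,
            'a': 1, 'man': 5, 'holds': 120, 'football': 5406}

caption = ['a', 'man', 'holds', 'a', 'football']
max_len = 10  # maximum number of words allowed in a caption

encoded = ([word_map['<start>']] +
           [word_map.get(w, word_map['<unk>']) for w in caption] +
           [word_map['<end>']] +
           [word_map['<pad>']] * (max_len - len(caption)))
caption_length = len(caption) + 2   # actual words + <start> + <end>

print(encoded)          # [9876, 1, 5, 120, 1, 5406, 9877, 9878, 9878, 9878, 9878, 9878]
print(caption_length)   # 7
```
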
235 | #### Caption Lengths
236 |
237 | Since the captions are padded, we would need to keep track of the lengths of each caption. This is the actual length + 2 (for the `<start>` and `<end>` tokens).
238 |
239 | Caption lengths are also important because you can build dynamic graphs with PyTorch. We only process a sequence up to its length and don't waste compute on the `<pad>`s.
240 |
241 | Therefore, **caption lengths fed to the model must be an `Int` tensor of dimension `N`**.
242 |
243 | ### Data pipeline
244 |
245 | See `create_input_files()` in [`utils.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/utils.py).
246 |
247 | This reads the data downloaded and saves the following files –
248 |
249 | - An **HDF5 file containing images for each split in an `I, 3, 256, 256` tensor**, where `I` is the number of images in the split. Pixel values are still in the range [0, 255], and are stored as unsigned 8-bit `Int`s.
250 | - A **JSON file for each split with a list of `N_c` * `I` encoded captions**, where `N_c` is the number of captions sampled per image. These captions are in the same order as the images in the HDF5 file. Therefore, the `i`th caption will correspond to the `i // N_c`th image.
251 | - A **JSON file for each split with a list of `N_c` * `I` caption lengths**. The `i`th value is the length of the `i`th caption, which corresponds to the `i // N_c`th image.
252 | - A **JSON file which contains the `word_map`**, the word-to-index dictionary.
253 |
254 | Before we save these files, we have the option to only use captions that are shorter than a threshold, and to bin less frequent words into an `<unk>` token.
255 |
256 | We use HDF5 files for the images because we will read them directly from disk during training / validation. They're simply too large to fit into RAM all at once. But we do load all captions and their lengths into memory.
257 |
258 | See `CaptionDataset` in [`datasets.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/datasets.py).
259 |
260 | This is a subclass of PyTorch [`Dataset`](https://pytorch.org/docs/master/data.html#torch.utils.data.Dataset). It needs a `__len__` method defined, which returns the size of the dataset, and a `__getitem__` method which returns the `i`th image, caption, and caption length.
261 |
262 | We read images from disk, convert the pixels to [0, 1], and normalize them inside this class.
263 |
264 | The `Dataset` will be used by a PyTorch [`DataLoader`](https://pytorch.org/docs/master/data.html#torch.utils.data.DataLoader) in `train.py` to create and feed batches of data to the model for training or validation.
265 |
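Wiring it up might look like this (the data folder path is a placeholder; the `data_name` string matches the default settings in `create_input_files.py`):

```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets import CaptionDataset

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_loader = DataLoader(
    CaptionDataset('/path/to/data', 'coco_5_cap_per_img_5_min_word_freq', 'TRAIN',
                   transform=transforms.Compose([normalize])),
    batch_size=32, shuffle=True, num_workers=1, pin_memory=True)

for imgs, caps, caplens in train_loader:
    pass  # each batch: (32, 3, 256, 256) images, (32, L) captions, (32, 1) caption lengths
```
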
266 | ### Encoder
267 |
268 | See `Encoder` in [`models.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/models.py).
269 |
270 | We use a pretrained ResNet-101 already available in PyTorch's `torchvision` module. Discard the last two layers (pooling and linear layers), since we only need to encode the image, and not classify it.
271 |
272 | We do add an `AdaptiveAvgPool2d()` layer to **resize the encoding to a fixed size**. This makes it possible to feed images of variable size to the Encoder. (We did, however, resize our input images to `256, 256` because we had to store them together as a single tensor.)
273 |
274 | Since we may want to fine-tune the Encoder, we add a `fine_tune()` method which enables or disables the calculation of gradients for the Encoder's parameters. We **only fine-tune convolutional blocks 2 through 4 in the ResNet**, because the first convolutional block would have usually learned something very fundamental to image processing, such as detecting lines, edges, curves, etc. We don't mess with the foundations.
275 |
276 | ### Attention
277 |
278 | See `Attention` in [`models.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/models.py).
279 |
280 | The Attention network is simple – it's composed of only linear layers and a couple of activations.
281 |
282 | Separate linear layers **transform both the encoded image (flattened to `N, 14 * 14, 2048`) and the hidden state (output) from the Decoder to the same dimension**, viz. the Attention size. They are then added and ReLU activated. A third linear layer **transforms this result to a dimension of 1**, whereupon we **apply the softmax to generate the weights** `alpha`.
283 |
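A minimal module consistent with that description might look like this (the actual `Attention` class in `models.py` is the reference for the exact details):

```python
import torch
import torch.nn as nn


class Attention(nn.Module):
    """Computes attention weights over the encoded image, given the Decoder's hidden state."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # transform encoded image
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # transform hidden state
        self.full_att = nn.Linear(attention_dim, 1)               # one score per pixel
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        # encoder_out: (N, num_pixels, encoder_dim); decoder_hidden: (N, decoder_dim)
        att1 = self.encoder_att(encoder_out)                      # (N, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_hidden)                   # (N, attention_dim)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (N, num_pixels)
        alpha = self.softmax(att)                                 # weights sum to 1 per image
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)   # (N, encoder_dim)
        return context, alpha
```
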
284 | ### Decoder
285 |
286 | See `DecoderWithAttention` in [`models.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/models.py).
287 |
288 | The output of the Encoder is received here and flattened to dimensions `N, 14 * 14, 2048`. This is just convenient and prevents having to reshape the tensor multiple times.
289 |
290 | We **initialize the hidden and cell state of the LSTM** using the encoded image with the `init_hidden_state()` method, which uses two separate linear layers.
291 |
292 | At the very outset, we **sort the `N` images and captions by decreasing caption lengths**. This is so that we can process only _valid_ timesteps, i.e., not process the `<pad>`s.
293 |
294 | 
295 |
296 | We can iterate over each timestep, processing only the colored regions, which are the **_effective_ batch size** `N_t` at that timestep. The sorting allows the top `N_t` at any timestep to align with the outputs from the previous step. At the third timestep, for example, we process only the top 5 images, using the top 5 outputs from the previous step.
297 |
298 | This **iteration is performed _manually_ in a `for` loop** with a PyTorch [`LSTMCell`](https://pytorch.org/docs/master/nn.html#torch.nn.LSTMCell) instead of iterating automatically without a loop with a PyTorch [`LSTM`](https://pytorch.org/docs/master/nn.html#torch.nn.LSTM). This is because we need to execute the Attention mechanism between each decode step. An `LSTMCell` is a single timestep operation, whereas an `LSTM` would iterate over multiple timesteps continuously and provide all outputs at once.
299 |
300 | We **compute the weights and attention-weighted encoding** at each timestep with the Attention network. In section `4.2.1` of the paper, they recommend **passing the attention-weighted encoding through a filter or gate**. This gate is a sigmoid activated linear transform of the Decoder's previous hidden state. The authors state that this helps the Attention network put more emphasis on the objects in the image.
301 |
302 | We **concatenate this filtered attention-weighted encoding with the embedding of the previous word** (`<start>` to begin), and run the `LSTMCell` to **generate the new hidden state (or output)**. A linear layer **transforms this new hidden state into scores for each word in the vocabulary**, which is stored.
303 |
304 | We also store the weights returned by the Attention network at each timestep. You will see why soon enough.
305 |
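Putting one decode timestep into a self-contained toy snippet – dummy tensors and arbitrary dimensions stand in for the real `DecoderWithAttention` internals, and the module names (`embedding`, `f_beta`, `decode_step`, `fc`) mirror the ones used in `caption.py`:

```python
import torch
import torch.nn as nn

N, num_pixels, encoder_dim = 4, 196, 2048
embed_dim, decoder_dim, vocab_size = 512, 512, 10000

encoder_out = torch.randn(N, num_pixels, encoder_dim)     # flattened encoded images
h = torch.zeros(N, decoder_dim)                           # previous hidden state
c = torch.zeros(N, decoder_dim)                           # previous cell state
prev_words = torch.randint(0, vocab_size, (N,))           # previously generated words

embedding = nn.Embedding(vocab_size, embed_dim)
f_beta = nn.Linear(decoder_dim, encoder_dim)              # computes the sigmoid gate
decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)
fc = nn.Linear(decoder_dim, vocab_size)                   # hidden state -> vocabulary scores

# Attention (stand-in): weights over pixels, and the attention-weighted encoding
alpha = torch.softmax(torch.randn(N, num_pixels), dim=1)
awe = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)       # (N, encoder_dim)

gate = torch.sigmoid(f_beta(h))                           # the "filter" from section 4.2.1
awe = gate * awe
h, c = decode_step(torch.cat([embedding(prev_words), awe], dim=1), (h, c))
scores = fc(h)                                            # (N, vocab_size)
print(scores.shape)
```
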
306 | # Training
307 |
308 | Before you begin, make sure to save the required data files for training, validation, and testing. To do this, run the contents of [`create_input_files.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/create_input_files.py) after pointing it to the Karpathy JSON file and the image folder containing the extracted `train2014` and `val2014` folders from your [downloaded data](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#dataset).
309 |
310 | See [`train.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/train.py).
311 |
312 | The parameters for the model (and training it) are at the beginning of the file, so you can easily check or modify them should you wish to.
313 |
314 | To **train your model from scratch**, simply run this file –
315 |
316 | `python train.py`
317 |
318 | To **resume training at a checkpoint**, point to the corresponding file with the `checkpoint` parameter at the beginning of the code.
319 |
320 | Note that we perform validation at the end of every training epoch.
321 |
322 | ### Loss Function
323 |
324 | Since we're generating a sequence of words, we use **[`CrossEntropyLoss`](https://pytorch.org/docs/master/nn.html#torch.nn.CrossEntropyLoss)**. You only need to submit the raw scores from the final layer in the Decoder, and the loss function will perform the softmax and log operations.
325 |
326 | The authors of the paper recommend using a second loss – a "**doubly stochastic regularization**". We know the weights sum to 1 at a given timestep. But we also encourage the weights at a single pixel `p` to sum to 1 across _all_ timesteps `T` –
327 |
328 |
329 | `sum over t = 1..T of alpha_{p,t} ≈ 1`
330 |
331 |
332 | This means we want the model to attend to every pixel over the course of generating the entire sequence. Therefore, we try to **minimize the difference between 1 and the sum of a pixel's weights across all timesteps**.
333 |
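In code, this is a single extra term added to the loss. A sketch, with the regularization strength written here as `alpha_c` (a hyperparameter) and `alphas` being the weights stored by the Decoder, shaped `(N, T, num_pixels)`:

```python
import torch

alpha_c = 1.
alphas = torch.softmax(torch.randn(8, 20, 196), dim=2)   # dummy weights; sum to 1 per timestep

# Penalize each pixel whose weights do not sum to (roughly) 1 across all T timesteps
doubly_stochastic_loss = alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

# loss = cross_entropy_loss + doubly_stochastic_loss
print(doubly_stochastic_loss.item())
```
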
334 | **We do not compute losses over the padded regions**. An easy way to get rid of the pads is to use PyTorch's [`pack_padded_sequence()`](https://pytorch.org/docs/master/nn.html#torch.nn.utils.rnn.pack_padded_sequence), which flattens the tensor by timestep while ignoring the padded regions. You can now aggregate the loss over this flattened tensor.
335 |
336 | 
337 |
338 | **Note** – This function is actually used to perform the same dynamic batching (i.e., processing only the effective batch size at each timestep) we performed in our Decoder, when using an `RNN` or `LSTM` in PyTorch. In this case, PyTorch handles the dynamic variable-length graphs internally. You can see an example in [`dynamic_rnn.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Sequence-Labeling/blob/master/dynamic_rnn.py) in my other tutorial on sequence labeling. We would have used this function along with an `LSTM` in our Decoder if we weren't manually iterating because of the Attention network.
339 |
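A sketch of that flattening (shapes and the `decode_lengths` list are examples; the lengths must be sorted in decreasing order, which the Decoder's sorting already guarantees):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

N, max_len, vocab_size = 4, 12, 10000
scores = torch.randn(N, max_len, vocab_size)              # raw scores from the Decoder (padded)
targets = torch.randint(0, vocab_size, (N, max_len))      # ground-truth words (padded)
decode_lengths = [12, 10, 7, 5]                           # actual lengths, sorted decreasing

scores_flat = pack_padded_sequence(scores, decode_lengths, batch_first=True).data    # (sum(lengths), vocab_size)
targets_flat = pack_padded_sequence(targets, decode_lengths, batch_first=True).data  # (sum(lengths),)

loss = nn.CrossEntropyLoss()(scores_flat, targets_flat)
print(loss.item())
```
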
340 | ### Early stopping with BLEU
341 |
342 | To evaluate the model's performance on the validation set, we will use the automated [BiLingual Evaluation Understudy (BLEU)](http://www.aclweb.org/anthology/P02-1040.pdf) evaluation metric. This evaluates a generated caption against reference caption(s). For each generated caption, we will use all `N_c` captions available for that image as the reference captions.
343 |
344 | The authors of the _Show, Attend and Tell_ paper observe that correlation between the loss and the BLEU score breaks down after a point, so they recommend stopping training early when the BLEU score begins to degrade, even if the loss continues to decrease.
345 |
346 | I used the BLEU tool [available in the NLTK module](https://www.nltk.org/_modules/nltk/translate/bleu_score.html).
347 |
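For reference, the call itself is simple; `references` holds *all* reference captions per image (as lists of tokens) and `hypotheses` holds one generated caption per image. (`eval.py` passes word indices rather than strings, which works just as well.)

```python
from nltk.translate.bleu_score import corpus_bleu

# Toy example: 2 images, each with 2 reference captions and 1 hypothesis
references = [
    [['a', 'man', 'holds', 'a', 'football'], ['a', 'man', 'with', 'a', 'ball']],
    [['a', 'dog', 'runs', 'on', 'the', 'grass'], ['a', 'dog', 'on', 'grass']],
]
hypotheses = [
    ['a', 'man', 'holds', 'a', 'ball'],
    ['a', 'dog', 'runs', 'on', 'the', 'grass'],
]

bleu4 = corpus_bleu(references, hypotheses)   # default weights -> BLEU-4
print(bleu4)
```
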
348 | Note that there is considerable criticism of the BLEU score because it doesn't always correlate well with human judgment. The authors also report the METEOR scores for this reason, but I haven't implemented this metric.
349 |
350 | ### Remarks
351 |
352 | I recommend you train in stages.
353 |
354 | I first trained only the Decoder, i.e. without fine-tuning the Encoder, with a batch size of `80`.
355 | I trained for 20 epochs, and the BLEU-4 score peaked at about `23.25` at the 13th epoch. I used the [`Adam()`](https://pytorch.org/docs/master/optim.html#torch.optim.Adam) optimizer with an initial learning rate of `4e-4`.
356 |
357 | I continued from the 13th epoch checkpoint allowing fine-tuning of the Encoder with a batch size of `32`. The smaller batch size is needed because the training footprint is now larger – it must also hold the Encoder's gradients. With fine-tuning, the score rose to `24.29` in just about 3 epochs. Continuing training would probably have pushed the score slightly higher but I had to commit my GPU elsewhere.
358 |
359 | An important distinction to make here is that I'm still supplying the ground-truth as the input at each decode-step during validation, _regardless of the word last generated_. This is called __Teacher Forcing__. While this is commonly used during training to speed up the process, as we are doing, conditions during validation should mimic real inference conditions as much as possible. I haven't implemented batched inference yet – where each word in the caption is generated from the previously generated word, and which terminates upon hitting the `<end>` token.
360 |
361 | Since I'm teacher-forcing during validation, the BLEU score measured above on the resulting captions _does not_ reflect real performance. In fact, the BLEU score is a metric designed for comparing naturally generated captions to ground-truth captions of differing length. Once batched inference is implemented, i.e. no Teacher Forcing, early-stopping with the BLEU score will be truly 'proper'.
362 |
363 | With this in mind, I used [`eval.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/eval.py) to compute the correct BLEU-4 scores of this model checkpoint on the validation and test sets _without_ Teacher Forcing, at different beam sizes –
364 |
365 | Beam Size | Validation BLEU-4 | Test BLEU-4 |
366 | :---: | :---: | :---: |
367 | 1 | 29.98 | 30.28 |
368 | 3 | 32.95 | 33.06 |
369 | 5 | 33.17 | 33.29 |
370 |
371 | The test score is higher than the result in the paper, and could be because of how our BLEU calculators are parameterized, the fact that I used a ResNet encoder, and actually fine-tuned the encoder – even if just a little.
372 |
373 | Also, remember – when fine-tuning during Transfer Learning, it's always better to use a learning rate considerably smaller than what was originally used to train the borrowed model. This is because the model is already quite optimized, and we don't want to change anything too quickly. I used `Adam()` for the Encoder as well, but with a learning rate of `1e-4`, which is a tenth of the default value for this optimizer.
374 |
375 | On a Titan X (Pascal), it took 55 minutes per epoch without fine-tuning, and 2.5 hours with fine-tuning at the stated batch sizes.
376 |
377 | ### Model Checkpoint
378 |
379 | You can download this pretrained model and the corresponding `word_map` [here](https://drive.google.com/open?id=189VY65I_n4RTpQnmLGj7IzVnOF6dmePC).
380 |
381 | Note that this checkpoint should be [loaded directly with PyTorch](https://pytorch.org/docs/stable/torch.html?#torch.load), or passed to [`caption.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/caption.py) – see below.
382 |
383 | # Inference
384 |
385 | See [`caption.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/caption.py).
386 |
387 | During inference, we _cannot_ directly use the `forward()` method in the Decoder because it uses Teacher Forcing. Rather, we would actually need to **feed the previously generated word to the LSTM at each timestep**.
388 |
389 | `caption_image_beam_search()` reads an image, encodes it, and applies the layers in the Decoder in the correct order, while using the previously generated word as the input to the LSTM at each timestep. It also incorporates Beam Search.
390 |
391 | `visualize_att()` can be used to visualize the generated caption along with the weights at each timestep as seen in the examples.
392 |
393 | To **caption an image** from the command line, point to the image, model checkpoint, word map (and optionally, the beam size) as follows –
394 |
395 | `python caption.py --img='path/to/image.jpeg' --model='path/to/BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar' --word_map='path/to/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json' --beam_size=5`
396 |
397 | Alternatively, use the functions in the file as needed.
398 |
399 | Also see [`eval.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/eval.py), which implements this process for calculating the BLEU score on the validation set, with or without Beam Search.
400 |
401 | ### Some more examples
402 |
403 | ---
404 |
405 | 
406 |
407 | ---
408 |
409 | 
410 |
411 | ---
412 |
413 | 
414 |
415 | ---
416 |
417 | 
418 |
419 | ---
420 |
421 | 
422 |
423 | ---
424 |
425 | **The ~~Turing~~ Tommy Test** – you know AI's not really AI because it hasn't watched _The Room_ and doesn't recognize greatness when it sees it.
426 |
427 | 
428 |
429 | ---
430 |
431 | # FAQs
432 |
433 | __You said__ ___soft___ __attention. Is there, um, a__ ___hard___ __attention?__
434 |
435 | Yes, the _Show, Attend and Tell_ paper uses both variants, and the Decoder with "hard" attention performs marginally better.
436 |
437 | In _soft_ attention, which we use here, you're computing the weights `alpha` and using the weighted average of the features across all pixels. This is a deterministic, differentiable operation.
438 |
439 | In _hard_ attention, you are choosing to just sample some pixels from a distribution defined by `alpha`. Note that any such probabilistic sampling is non-deterministic or _stochastic_, i.e. a specific input will not always produce the same output. But since gradient descent presupposes that the network is deterministic (and therefore differentiable), the sampling is reworked to remove its stochasticity. My knowledge of this is fairly superficial at this point – I will update this answer when I have a more detailed understanding.
440 |
441 | ---
442 |
443 | __How do I use an attention network for an NLP task like a sequence to sequence model?__
444 |
445 | Much like you use a CNN to generate an encoding with features at each pixel, you would use an RNN to generate encoded features at each timestep i.e. word position in the input.
446 |
447 | Without attention, you would use the Encoder's output at the last timestep as the encoding for the entire sentence, since it would also contain information from prior timesteps. The Encoder's last output now bears the burden of having to encode the entire sentence meaningfully, which is not easy, especially for longer sentences.
448 |
449 | With attention, you would attend over the timesteps in the Encoder's output, generating weights for each timestep/word, and take the weighted average to represent the sentence. In a sequence to sequence task like machine translation, you would attend to the relevant words in the input as you generate each word in the output.
450 |
451 | You could also use Attention without a Decoder. For example, if you want to classify text, you can attend to the important words in the input just once to perform the classification.
452 |
453 | ---
454 |
455 | __Can we use Beam Search during training?__
456 |
457 | Not with the current loss function, but [yes](https://arxiv.org/abs/1606.02960). This is not common at all.
458 |
459 | ---
460 |
461 | __What is Teacher Forcing?__
462 |
463 | Teacher Forcing is when we use the ground truth captions as the input to the Decoder at each timestep, and not the word it generated in the previous timestep. It's common to teacher-force during training since it could mean faster convergence of the model. But the model can also learn to depend on being told the correct answer, and exhibit some instability in practice.
464 |
465 | It would be ideal to train using Teacher Forcing [only some of the time](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html#training-the-model), based on a probability. This is called Scheduled Sampling.
466 |
467 | (I plan to add the option).
468 |
469 | ---
470 |
471 | __Can I use pretrained word embeddings (GloVe, CBOW, skipgram, etc.) instead of learning them from scratch?__
472 |
473 | Yes, you could, with the `load_pretrained_embeddings()` method in the `Decoder` class. You could also choose to fine-tune (or not) with the `fine_tune_embeddings()` method.
474 |
475 | After creating the Decoder in [`train.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/train.py), you should provide the pretrained vectors to `load_pretrained_embeddings()` stacked in the same order as in the `word_map`. For words that you don't have pretrained vectors for, like `<unk>`, you can initialize embeddings randomly like we did in `init_weights()`. I recommend fine-tuning to learn more meaningful vectors for these randomly initialized embeddings.
476 |
477 | ```python
478 | decoder = DecoderWithAttention(attention_dim=attention_dim,
479 |                                embed_dim=emb_dim,
480 |                                decoder_dim=decoder_dim,
481 |                                vocab_size=len(word_map),
482 |                                dropout=dropout)
483 | decoder.load_pretrained_embeddings(pretrained_embeddings)  # pretrained_embeddings should be of dimensions (len(word_map), emb_dim)
484 | decoder.fine_tune_embeddings(True)  # or False
485 | ```
486 |
487 | Also make sure to change the `emb_dim` parameter from its current value of `512` to the size of your pre-trained embeddings. This should automatically adjust the input size of the decoder LSTM to accommodate them.
488 |
489 | ---
490 |
491 | __How do I keep track of which tensors allow gradients to be computed?__
492 |
493 | With the release of PyTorch `0.4`, wrapping tensors as `Variable`s is no longer required. Instead, tensors have the `requires_grad` attribute, which decides whether a tensor is tracked by `autograd`, and therefore whether gradients are computed for it during backpropagation.
494 |
495 | - By default, when you create a tensor from scratch, `requires_grad` will be set to `False`.
496 | - When a tensor is created from or modified using another tensor that allows gradients, then `requires_grad` will be set to `True`.
497 | - Tensors which are parameters of `torch.nn` layers will already have `requires_grad` set to `True`.
498 |
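A quick sketch demonstrating these three rules:

```python
import torch
import torch.nn as nn

x = torch.randn(3)                       # created from scratch
print(x.requires_grad)                   # False

w = torch.randn(3, requires_grad=True)   # explicitly tracked
y = (w * x).sum()                        # derived from a tracked tensor
print(y.requires_grad)                   # True

layer = nn.Linear(3, 2)                  # parameters of torch.nn layers
print(layer.weight.requires_grad)        # True

y.backward()                             # gradients flow back to w, but not to x
print(w.grad)                            # a tensor of gradients
print(x.grad)                            # None
```
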
499 | ---
500 |
501 | __How do I compute all BLEU (i.e. BLEU-1 to BLEU-4) scores during evaluation?__
502 |
503 | You'd need to modify the code in [`eval.py`](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/eval.py) to do this. Please see [this excellent answer]() by [kmario23]() for a clear and detailed explanation.
504 |
--------------------------------------------------------------------------------
/caption.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import json
5 | import torchvision.transforms as transforms
6 | import matplotlib.pyplot as plt
7 | import matplotlib.cm as cm
8 | import skimage.transform
9 | import argparse
10 | from scipy.misc import imread, imresize
11 | from PIL import Image
12 |
13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 |
15 |
16 | def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3):
17 | """
18 | Reads an image and captions it with beam search.
19 |
20 | :param encoder: encoder model
21 | :param decoder: decoder model
22 | :param image_path: path to image
23 | :param word_map: word map
24 | :param beam_size: number of sequences to consider at each decode-step
25 | :return: caption, weights for visualization
26 | """
27 |
28 | k = beam_size
29 | vocab_size = len(word_map)
30 |
31 | # Read image and process
32 | img = imread(image_path)
33 | if len(img.shape) == 2:
34 | img = img[:, :, np.newaxis]
35 | img = np.concatenate([img, img, img], axis=2)
36 | img = imresize(img, (256, 256))
37 | img = img.transpose(2, 0, 1)
38 | img = img / 255.
39 | img = torch.FloatTensor(img).to(device)
40 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
41 | std=[0.229, 0.224, 0.225])
42 | transform = transforms.Compose([normalize])
43 | image = transform(img) # (3, 256, 256)
44 |
45 | # Encode
46 | image = image.unsqueeze(0) # (1, 3, 256, 256)
47 | encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim)
48 | enc_image_size = encoder_out.size(1)
49 | encoder_dim = encoder_out.size(3)
50 |
51 | # Flatten encoding
52 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim)
53 | num_pixels = encoder_out.size(1)
54 |
55 | # We'll treat the problem as having a batch size of k
56 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim)
57 |
58 | # Tensor to store top k previous words at each step; now they're just <start>
59 | k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device) # (k, 1)
60 |
61 | # Tensor to store top k sequences; now they're just <start>
62 | seqs = k_prev_words # (k, 1)
63 |
64 | # Tensor to store top k sequences' scores; now they're just 0
65 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1)
66 |
67 | # Tensor to store top k sequences' alphas; now they're just 1s
68 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size)
69 |
70 | # Lists to store completed sequences, their alphas and scores
71 | complete_seqs = list()
72 | complete_seqs_alpha = list()
73 | complete_seqs_scores = list()
74 |
75 | # Start decoding
76 | step = 1
77 | h, c = decoder.init_hidden_state(encoder_out)
78 |
79 | # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
80 | while True:
81 |
82 | embeddings = decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim)
83 |
84 | awe, alpha = decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels)
85 |
86 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size)
87 |
88 | gate = decoder.sigmoid(decoder.f_beta(h)) # gating scalar, (s, encoder_dim)
89 | awe = gate * awe
90 |
91 | h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim)
92 |
93 | scores = decoder.fc(h) # (s, vocab_size)
94 | scores = F.log_softmax(scores, dim=1)
95 |
96 | # Add
97 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)
98 |
99 | # For the first step, all k points will have the same scores (since same k previous words, h, c)
100 | if step == 1:
101 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
102 | else:
103 | # Unroll and find top scores, and their unrolled indices
104 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)
105 |
106 | # Convert unrolled indices to actual indices of scores
107 | prev_word_inds = top_k_words / vocab_size # (s)
108 | next_word_inds = top_k_words % vocab_size # (s)
109 |
110 | # Add new words to sequences, alphas
111 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)
112 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
113 | dim=1) # (s, step+1, enc_image_size, enc_image_size)
114 |
115 | # Which sequences are incomplete (didn't reach <end>)?
116 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
117 | next_word != word_map['<end>']]
118 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))
119 |
120 | # Set aside complete sequences
121 | if len(complete_inds) > 0:
122 | complete_seqs.extend(seqs[complete_inds].tolist())
123 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
124 | complete_seqs_scores.extend(top_k_scores[complete_inds])
125 | k -= len(complete_inds) # reduce beam length accordingly
126 |
127 | # Proceed with incomplete sequences
128 | if k == 0:
129 | break
130 | seqs = seqs[incomplete_inds]
131 | seqs_alpha = seqs_alpha[incomplete_inds]
132 | h = h[prev_word_inds[incomplete_inds]]
133 | c = c[prev_word_inds[incomplete_inds]]
134 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
135 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
136 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
137 |
138 | # Break if things have been going on too long
139 | if step > 50:
140 | break
141 | step += 1
142 |
143 | i = complete_seqs_scores.index(max(complete_seqs_scores))
144 | seq = complete_seqs[i]
145 | alphas = complete_seqs_alpha[i]
146 |
147 | return seq, alphas
148 |
149 |
150 | def visualize_att(image_path, seq, alphas, rev_word_map, smooth=True):
151 | """
152 | Visualizes caption with weights at every word.
153 |
154 | Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb
155 |
156 | :param image_path: path to image that has been captioned
157 | :param seq: caption
158 | :param alphas: weights
159 | :param rev_word_map: reverse word mapping, i.e. ix2word
160 | :param smooth: smooth weights?
161 | """
162 | image = Image.open(image_path)
163 | image = image.resize([14 * 24, 14 * 24], Image.LANCZOS)
164 |
165 | words = [rev_word_map[ind] for ind in seq]
166 |
167 | for t in range(len(words)):
168 | if t > 50:
169 | break
170 | plt.subplot(int(np.ceil(len(words) / 5.)), 5, t + 1)
171 |
172 | plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12)
173 | plt.imshow(image)
174 | current_alpha = alphas[t, :]
175 | if smooth:
176 | alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=24, sigma=8)
177 | else:
178 | alpha = skimage.transform.resize(current_alpha.numpy(), [14 * 24, 14 * 24])
179 | if t == 0:
180 | plt.imshow(alpha, alpha=0)
181 | else:
182 | plt.imshow(alpha, alpha=0.8)
183 | plt.set_cmap(cm.Greys_r)
184 | plt.axis('off')
185 | plt.show()
186 |
187 |
188 | if __name__ == '__main__':
189 | parser = argparse.ArgumentParser(description='Show, Attend, and Tell - Tutorial - Generate Caption')
190 |
191 | parser.add_argument('--img', '-i', help='path to image')
192 | parser.add_argument('--model', '-m', help='path to model')
193 | parser.add_argument('--word_map', '-wm', help='path to word map JSON')
194 | parser.add_argument('--beam_size', '-b', default=5, type=int, help='beam size for beam search')
195 | parser.add_argument('--dont_smooth', dest='smooth', action='store_false', help='do not smooth alpha overlay')
196 |
197 | args = parser.parse_args()
198 |
199 | # Load model
200 | checkpoint = torch.load(args.model, map_location=str(device))
201 | decoder = checkpoint['decoder']
202 | decoder = decoder.to(device)
203 | decoder.eval()
204 | encoder = checkpoint['encoder']
205 | encoder = encoder.to(device)
206 | encoder.eval()
207 |
208 | # Load word map (word2ix)
209 | with open(args.word_map, 'r') as j:
210 | word_map = json.load(j)
211 | rev_word_map = {v: k for k, v in word_map.items()} # ix2word
212 |
213 | # Encode, decode with attention and beam search
214 | seq, alphas = caption_image_beam_search(encoder, decoder, args.img, word_map, args.beam_size)
215 | alphas = torch.FloatTensor(alphas)
216 |
217 | # Visualize caption and attention of best sequence
218 | visualize_att(args.img, seq, alphas, rev_word_map, args.smooth)
219 |
--------------------------------------------------------------------------------
/create_input_files.py:
--------------------------------------------------------------------------------
1 | from utils import create_input_files
2 |
3 | if __name__ == '__main__':
4 | # Create input files (along with word map)
5 | create_input_files(dataset='coco',
6 | karpathy_json_path='../caption data/dataset_coco.json',
7 | image_folder='/media/ssd/caption data/',
8 | captions_per_image=5,
9 | min_word_freq=5,
10 | output_folder='/media/ssd/caption data/',
11 | max_len=50)
12 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import h5py
4 | import json
5 | import os
6 |
7 |
8 | class CaptionDataset(Dataset):
9 | """
10 | A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.
11 | """
12 |
13 | def __init__(self, data_folder, data_name, split, transform=None):
14 | """
15 | :param data_folder: folder where data files are stored
16 | :param data_name: base name of processed datasets
17 | :param split: split, one of 'TRAIN', 'VAL', or 'TEST'
18 | :param transform: image transform pipeline
19 | """
20 | self.split = split
21 | assert self.split in {'TRAIN', 'VAL', 'TEST'}
22 |
23 | # Open hdf5 file where images are stored
24 | self.h = h5py.File(os.path.join(data_folder, self.split + '_IMAGES_' + data_name + '.hdf5'), 'r')
25 | self.imgs = self.h['images']
26 |
27 | # Captions per image
28 | self.cpi = self.h.attrs['captions_per_image']
29 |
30 | # Load encoded captions (completely into memory)
31 | with open(os.path.join(data_folder, self.split + '_CAPTIONS_' + data_name + '.json'), 'r') as j:
32 | self.captions = json.load(j)
33 |
34 | # Load caption lengths (completely into memory)
35 | with open(os.path.join(data_folder, self.split + '_CAPLENS_' + data_name + '.json'), 'r') as j:
36 | self.caplens = json.load(j)
37 |
38 | # PyTorch transformation pipeline for the image (normalizing, etc.)
39 | self.transform = transform
40 |
41 | # Total number of datapoints
42 | self.dataset_size = len(self.captions)
43 |
44 | def __getitem__(self, i):
45 | # Remember, the Nth caption corresponds to the (N // captions_per_image)th image
46 | img = torch.FloatTensor(self.imgs[i // self.cpi] / 255.)
47 | if self.transform is not None:
48 | img = self.transform(img)
49 |
50 | caption = torch.LongTensor(self.captions[i])
51 |
52 | caplen = torch.LongTensor([self.caplens[i]])
53 |
54 | if self.split == 'TRAIN':
55 | return img, caption, caplen
56 | else:
57 | # For validation or testing, also return all 'captions_per_image' captions to find BLEU-4 score
58 | all_captions = torch.LongTensor(
59 | self.captions[((i // self.cpi) * self.cpi):(((i // self.cpi) * self.cpi) + self.cpi)])
60 | return img, caption, caplen, all_captions
61 |
62 | def __len__(self):
63 | return self.dataset_size
64 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch.backends.cudnn as cudnn
2 | import torch.optim
3 | import torch.utils.data
4 | import torchvision.transforms as transforms
5 | from datasets import *
6 | from utils import *
7 | from nltk.translate.bleu_score import corpus_bleu
8 | import torch.nn.functional as F
9 | from tqdm import tqdm
10 |
11 | # Parameters
12 | data_folder = '/media/ssd/caption data' # folder with data files saved by create_input_files.py
13 | data_name = 'coco_5_cap_per_img_5_min_word_freq' # base name shared by data files
14 | checkpoint = '../BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar' # model checkpoint
15 | word_map_file = '/media/ssd/caption data/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json' # word map, ensure it's the same one the data was encoded with and the model was trained with
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors
17 | cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise there is a lot of computational overhead
18 |
19 | # Load model
20 | checkpoint = torch.load(checkpoint)
21 | decoder = checkpoint['decoder']
22 | decoder = decoder.to(device)
23 | decoder.eval()
24 | encoder = checkpoint['encoder']
25 | encoder = encoder.to(device)
26 | encoder.eval()
27 |
28 | # Load word map (word2ix)
29 | with open(word_map_file, 'r') as j:
30 | word_map = json.load(j)
31 | rev_word_map = {v: k for k, v in word_map.items()}
32 | vocab_size = len(word_map)
33 |
34 | # Normalization transform
35 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
36 | std=[0.229, 0.224, 0.225])
37 |
38 |
39 | def evaluate(beam_size):
40 | """
41 | Evaluation
42 |
43 | :param beam_size: beam size at which to generate captions for evaluation
44 | :return: BLEU-4 score
45 | """
46 | # DataLoader
47 | loader = torch.utils.data.DataLoader(
48 | CaptionDataset(data_folder, data_name, 'TEST', transform=transforms.Compose([normalize])),
49 | batch_size=1, shuffle=True, num_workers=1, pin_memory=True)
50 |
51 | # TODO: Batched Beam Search
52 | # Therefore, do not use a batch_size greater than 1 - IMPORTANT!
53 |
54 | # Lists to store references (true captions), and hypothesis (prediction) for each image
55 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
56 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
57 | references = list()
58 | hypotheses = list()
59 |
60 | # For each image
61 | for i, (image, caps, caplens, allcaps) in enumerate(
62 | tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size))):
63 |
64 | k = beam_size
65 |
66 | # Move to GPU device, if available
67 | image = image.to(device) # (1, 3, 256, 256)
68 |
69 | # Encode
70 | encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim)
71 | enc_image_size = encoder_out.size(1)
72 | encoder_dim = encoder_out.size(3)
73 |
74 | # Flatten encoding
75 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim)
76 | num_pixels = encoder_out.size(1)
77 |
78 | # We'll treat the problem as having a batch size of k
79 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim)
80 |
81 | # Tensor to store top k previous words at each step; now they're just <start>
82 | k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device) # (k, 1)
83 |
84 | # Tensor to store top k sequences; now they're just <start>
85 | seqs = k_prev_words # (k, 1)
86 |
87 | # Tensor to store top k sequences' scores; now they're just 0
88 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1)
89 |
90 | # Lists to store completed sequences and scores
91 | complete_seqs = list()
92 | complete_seqs_scores = list()
93 |
94 | # Start decoding
95 | step = 1
96 | h, c = decoder.init_hidden_state(encoder_out)
97 |
98 | # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
99 | while True:
100 |
101 | embeddings = decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim)
102 |
103 | awe, _ = decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels)
104 |
105 | gate = decoder.sigmoid(decoder.f_beta(h)) # gating scalar, (s, encoder_dim)
106 | awe = gate * awe
107 |
108 | h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim)
109 |
110 | scores = decoder.fc(h) # (s, vocab_size)
111 | scores = F.log_softmax(scores, dim=1)
112 |
113 | # Add
114 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)
115 |
116 | # For the first step, all k points will have the same scores (since same k previous words, h, c)
117 | if step == 1:
118 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
119 | else:
120 | # Unroll and find top scores, and their unrolled indices
121 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)
122 |
123 | # Convert unrolled indices to actual indices of scores
124 | prev_word_inds = top_k_words / vocab_size # (s)
125 | next_word_inds = top_k_words % vocab_size # (s)
126 |
127 | # Add new words to sequences
128 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)
129 |
130 | # Which sequences are incomplete (didn't reach <end>)?
131 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
132 | next_word != word_map['<end>']]
133 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))
134 |
135 | # Set aside complete sequences
136 | if len(complete_inds) > 0:
137 | complete_seqs.extend(seqs[complete_inds].tolist())
138 | complete_seqs_scores.extend(top_k_scores[complete_inds])
139 | k -= len(complete_inds) # reduce beam length accordingly
140 |
141 | # Proceed with incomplete sequences
142 | if k == 0:
143 | break
144 | seqs = seqs[incomplete_inds]
145 | h = h[prev_word_inds[incomplete_inds]]
146 | c = c[prev_word_inds[incomplete_inds]]
147 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
148 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
149 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
150 |
151 | # Break if things have been going on too long
152 | if step > 50:
153 | break
154 | step += 1
155 |
156 | i = complete_seqs_scores.index(max(complete_seqs_scores))
157 | seq = complete_seqs[i]
158 |
159 | # References
160 | img_caps = allcaps[0].tolist()
161 | img_captions = list(
162 | map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}],
163 | img_caps))  # remove <start>, <end>, and pads
164 | references.append(img_captions)
165 |
166 | # Hypotheses
167 | hypotheses.append([w for w in seq if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}])
168 |
169 | assert len(references) == len(hypotheses)
170 |
171 | # Calculate BLEU-4 scores
172 | bleu4 = corpus_bleu(references, hypotheses)
173 |
174 | return bleu4
175 |
176 |
177 | if __name__ == '__main__':
178 | beam_size = 1
179 | print("\nBLEU-4 score @ beam size of %d is %.4f." % (beam_size, evaluate(beam_size)))
180 |
--------------------------------------------------------------------------------
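A minimal, self-contained sketch (not one of the repository files) of two pieces of machinery in eval.py above: the unrolled top-k bookkeeping of beam search, where the flattened (k, vocab_size) score matrix is searched once and integer division / modulo by vocab_size recover which beam each candidate extends and which word it appends, and the references/hypotheses layout that corpus_bleu expects. All sizes and values below are made up for illustration.

import torch
from nltk.translate.bleu_score import corpus_bleu

# Unrolled top-k: scores is (k, vocab_size); flatten it, take a single top-k,
# then recover (beam index, word index) pairs with // and %.
k, vocab_size = 3, 5
scores = torch.tensor([[0.10, 0.90, 0.20, 0.00, 0.30],
                       [0.80, 0.10, 0.70, 0.20, 0.00],
                       [0.00, 0.20, 0.10, 0.95, 0.40]])

top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)
prev_word_inds = top_k_words // vocab_size  # which beam each candidate extends
next_word_inds = top_k_words % vocab_size   # which word each candidate appends

print(top_k_scores)    # tensor([0.9500, 0.9000, 0.8000])
print(prev_word_inds)  # tensor([2, 0, 1])
print(next_word_inds)  # tensor([3, 1, 0])

# corpus_bleu expects one list of reference captions per image and one
# hypothesis per image, each caption being a list of tokens (word indices are fine).
references = [[[7, 2, 5, 9], [7, 2, 6, 9]]]  # one image, two reference captions
hypotheses = [[7, 2, 5, 9]]                  # one predicted caption
print(corpus_bleu(references, hypotheses))   # 1.0 for an exact match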
/img/att.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/att.png
--------------------------------------------------------------------------------
/img/babycake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/babycake.png
--------------------------------------------------------------------------------
/img/beam_search.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/beam_search.png
--------------------------------------------------------------------------------
/img/bikefence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/bikefence.png
--------------------------------------------------------------------------------
/img/biketrain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/biketrain.png
--------------------------------------------------------------------------------
/img/birds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/birds.png
--------------------------------------------------------------------------------
/img/boats.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/boats.png
--------------------------------------------------------------------------------
/img/catbanana.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/catbanana.png
--------------------------------------------------------------------------------
/img/decoder_att.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/decoder_att.png
--------------------------------------------------------------------------------
/img/decoder_no_att.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/decoder_no_att.png
--------------------------------------------------------------------------------
/img/dogtie.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/dogtie.png
--------------------------------------------------------------------------------
/img/doublystochastic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/doublystochastic.png
--------------------------------------------------------------------------------
/img/encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/encoder.png
--------------------------------------------------------------------------------
/img/firehydrant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/firehydrant.png
--------------------------------------------------------------------------------
/img/manbike.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/manbike.png
--------------------------------------------------------------------------------
/img/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/model.png
--------------------------------------------------------------------------------
/img/plane.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/plane.png
--------------------------------------------------------------------------------
/img/salad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/salad.png
--------------------------------------------------------------------------------
/img/sheep.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/sheep.png
--------------------------------------------------------------------------------
/img/sorted.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/sorted.jpg
--------------------------------------------------------------------------------
/img/sorted2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/sorted2.jpg
--------------------------------------------------------------------------------
/img/tommy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/tommy.png
--------------------------------------------------------------------------------
/img/weights.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/cc9c7e2f4017938d414178d3781fed8dbe442852/img/weights.png
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torchvision
4 |
5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6 |
7 |
8 | class Encoder(nn.Module):
9 | """
10 | Encoder.
11 | """
12 |
13 | def __init__(self, encoded_image_size=14):
14 | super(Encoder, self).__init__()
15 | self.enc_image_size = encoded_image_size
16 |
17 | resnet = torchvision.models.resnet101(pretrained=True) # pretrained ImageNet ResNet-101
18 |
19 | # Remove linear and pool layers (since we're not doing classification)
20 | modules = list(resnet.children())[:-2]
21 | self.resnet = nn.Sequential(*modules)
22 |
23 | # Resize image to fixed size to allow input images of variable size
24 | self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
25 |
26 | self.fine_tune()
27 |
28 | def forward(self, images):
29 | """
30 | Forward propagation.
31 |
32 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
33 | :return: encoded images
34 | """
35 | out = self.resnet(images) # (batch_size, 2048, image_size/32, image_size/32)
36 | out = self.adaptive_pool(out) # (batch_size, 2048, encoded_image_size, encoded_image_size)
37 | out = out.permute(0, 2, 3, 1) # (batch_size, encoded_image_size, encoded_image_size, 2048)
38 | return out
39 |
40 | def fine_tune(self, fine_tune=True):
41 | """
42 | Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder.
43 |
44 | :param fine_tune: Allow?
45 | """
46 | for p in self.resnet.parameters():
47 | p.requires_grad = False
48 | # If fine-tuning, only fine-tune convolutional blocks 2 through 4
49 | for c in list(self.resnet.children())[5:]:
50 | for p in c.parameters():
51 | p.requires_grad = fine_tune
52 |
53 |
54 | class Attention(nn.Module):
55 | """
56 | Attention Network.
57 | """
58 |
59 | def __init__(self, encoder_dim, decoder_dim, attention_dim):
60 | """
61 | :param encoder_dim: feature size of encoded images
62 | :param decoder_dim: size of decoder's RNN
63 | :param attention_dim: size of the attention network
64 | """
65 | super(Attention, self).__init__()
66 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image
67 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output
68 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed
69 | self.relu = nn.ReLU()
70 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights
71 |
72 | def forward(self, encoder_out, decoder_hidden):
73 | """
74 | Forward propagation.
75 |
76 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
77 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
78 | :return: attention weighted encoding, weights
79 | """
80 | att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim)
81 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim)
82 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels)
83 | alpha = self.softmax(att) # (batch_size, num_pixels)
84 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_dim)
85 |
86 | return attention_weighted_encoding, alpha
87 |
88 |
89 | class DecoderWithAttention(nn.Module):
90 | """
91 | Decoder.
92 | """
93 |
94 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
95 | """
96 | :param attention_dim: size of attention network
97 | :param embed_dim: embedding size
98 | :param decoder_dim: size of decoder's RNN
99 | :param vocab_size: size of vocabulary
100 | :param encoder_dim: feature size of encoded images
101 | :param dropout: dropout
102 | """
103 | super(DecoderWithAttention, self).__init__()
104 |
105 | self.encoder_dim = encoder_dim
106 | self.attention_dim = attention_dim
107 | self.embed_dim = embed_dim
108 | self.decoder_dim = decoder_dim
109 | self.vocab_size = vocab_size
110 | self.dropout = dropout
111 |
112 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network
113 |
114 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer
115 | self.dropout = nn.Dropout(p=self.dropout)
116 | self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoding LSTMCell
117 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell
118 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell
119 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate
120 | self.sigmoid = nn.Sigmoid()
121 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary
122 | self.init_weights() # initialize some layers with the uniform distribution
123 |
124 | def init_weights(self):
125 | """
126 | Initializes some parameters with values from the uniform distribution, for easier convergence.
127 | """
128 | self.embedding.weight.data.uniform_(-0.1, 0.1)
129 | self.fc.bias.data.fill_(0)
130 | self.fc.weight.data.uniform_(-0.1, 0.1)
131 |
132 | def load_pretrained_embeddings(self, embeddings):
133 | """
134 | Loads embedding layer with pre-trained embeddings.
135 |
136 | :param embeddings: pre-trained embeddings
137 | """
138 | self.embedding.weight = nn.Parameter(embeddings)
139 |
140 | def fine_tune_embeddings(self, fine_tune=True):
141 | """
142 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).
143 |
144 | :param fine_tune: Allow?
145 | """
146 | for p in self.embedding.parameters():
147 | p.requires_grad = fine_tune
148 |
149 | def init_hidden_state(self, encoder_out):
150 | """
151 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
152 |
153 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
154 | :return: hidden state, cell state
155 | """
156 | mean_encoder_out = encoder_out.mean(dim=1)
157 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim)
158 | c = self.init_c(mean_encoder_out)
159 | return h, c
160 |
161 | def forward(self, encoder_out, encoded_captions, caption_lengths):
162 | """
163 | Forward propagation.
164 |
165 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
166 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
167 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
168 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
169 | """
170 |
171 | batch_size = encoder_out.size(0)
172 | encoder_dim = encoder_out.size(-1)
173 | vocab_size = self.vocab_size
174 |
175 | # Flatten image
176 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim)
177 | num_pixels = encoder_out.size(1)
178 |
179 | # Sort input data by decreasing lengths; the reason becomes apparent below
180 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
181 | encoder_out = encoder_out[sort_ind]
182 | encoded_captions = encoded_captions[sort_ind]
183 |
184 | # Embedding
185 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim)
186 |
187 | # Initialize LSTM state
188 | h, c = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim)
189 |
190 | # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
191 | # So, decoding lengths are actual lengths - 1
192 | decode_lengths = (caption_lengths - 1).tolist()
193 |
194 | # Create tensors to hold word prediction scores and alphas
195 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)
196 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device)
197 |
198 | # At each time-step, decode by
199 | # attention-weighing the encoder's output based on the decoder's previous hidden state output
200 | # then generate a new word in the decoder with the previous word and the attention weighted encoding
201 | for t in range(max(decode_lengths)):
202 | batch_size_t = sum([l > t for l in decode_lengths])
203 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
204 | h[:batch_size_t])
205 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim)
206 | attention_weighted_encoding = gate * attention_weighted_encoding
207 | h, c = self.decode_step(
208 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
209 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim)
210 | preds = self.fc(self.dropout(h)) # (batch_size_t, vocab_size)
211 | predictions[:batch_size_t, t, :] = preds
212 | alphas[:batch_size_t, t, :] = alpha
213 |
214 | return predictions, encoded_captions, decode_lengths, alphas, sort_ind
215 |
--------------------------------------------------------------------------------
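A quick shape check for the decoder in models.py, offered as an illustrative sketch rather than a repository file: a randomly initialized DecoderWithAttention is fed a fake encoder output so the tensor dimensions noted in the comments above can be verified without downloading the pretrained ResNet-101. All sizes here are arbitrary.

import torch
from models import DecoderWithAttention

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size, enc_image_size, encoder_dim = 2, 14, 2048
vocab_size, max_len = 100, 20

decoder = DecoderWithAttention(attention_dim=64, embed_dim=32, decoder_dim=64,
                               vocab_size=vocab_size, encoder_dim=encoder_dim).to(device)

# Fake encoder output, shaped like the return value of Encoder.forward()
encoder_out = torch.randn(batch_size, enc_image_size, enc_image_size, encoder_dim).to(device)
captions = torch.randint(0, vocab_size, (batch_size, max_len)).to(device)  # (batch_size, max_caption_length)
caption_lengths = torch.tensor([[max_len], [max_len - 5]]).to(device)      # (batch_size, 1)

predictions, caps_sorted, decode_lengths, alphas, sort_ind = decoder(encoder_out, captions, caption_lengths)

print(predictions.shape)  # (2, 19, 100) = (batch_size, max(decode_lengths), vocab_size)
print(alphas.shape)       # (2, 19, 196) = (batch_size, max(decode_lengths), num_pixels)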
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch.backends.cudnn as cudnn
3 | import torch.optim
4 | import torch.utils.data
5 | import torchvision.transforms as transforms
6 | from torch import nn
7 | from torch.nn.utils.rnn import pack_padded_sequence
8 | from models import Encoder, DecoderWithAttention
9 | from datasets import *
10 | from utils import *
11 | from nltk.translate.bleu_score import corpus_bleu
12 |
13 | # Data parameters
14 | data_folder = '/media/ssd/caption data' # folder with data files saved by create_input_files.py
15 | data_name = 'coco_5_cap_per_img_5_min_word_freq' # base name shared by data files
16 |
17 | # Model parameters
18 | emb_dim = 512 # dimension of word embeddings
19 | attention_dim = 512 # dimension of attention linear layers
20 | decoder_dim = 512 # dimension of decoder RNN
21 | dropout = 0.5
22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors
23 | cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise there is a lot of computational overhead
24 |
25 | # Training parameters
26 | start_epoch = 0
27 | epochs = 120 # number of epochs to train for (if early stopping is not triggered)
28 | epochs_since_improvement = 0 # keeps track of number of epochs since there's been an improvement in validation BLEU
29 | batch_size = 32
30 | workers = 1 # for data-loading; right now, only 1 works with h5py
31 | encoder_lr = 1e-4 # learning rate for encoder if fine-tuning
32 | decoder_lr = 4e-4 # learning rate for decoder
33 | grad_clip = 5.  # clip gradients at an absolute value of 5.
34 | alpha_c = 1. # regularization parameter for 'doubly stochastic attention', as in the paper
35 | best_bleu4 = 0. # BLEU-4 score right now
36 | print_freq = 100 # print training/validation stats every __ batches
37 | fine_tune_encoder = False # fine-tune encoder?
38 | checkpoint = None # path to checkpoint, None if none
39 |
40 |
41 | def main():
42 | """
43 | Training and validation.
44 | """
45 |
46 | global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map
47 |
48 | # Read word map
49 | word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
50 | with open(word_map_file, 'r') as j:
51 | word_map = json.load(j)
52 |
53 | # Initialize / load checkpoint
54 | if checkpoint is None:
55 | decoder = DecoderWithAttention(attention_dim=attention_dim,
56 | embed_dim=emb_dim,
57 | decoder_dim=decoder_dim,
58 | vocab_size=len(word_map),
59 | dropout=dropout)
60 | decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
61 | lr=decoder_lr)
62 | encoder = Encoder()
63 | encoder.fine_tune(fine_tune_encoder)
64 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
65 | lr=encoder_lr) if fine_tune_encoder else None
66 |
67 | else:
68 | checkpoint = torch.load(checkpoint)
69 | start_epoch = checkpoint['epoch'] + 1
70 | epochs_since_improvement = checkpoint['epochs_since_improvement']
71 | best_bleu4 = checkpoint['bleu-4']
72 | decoder = checkpoint['decoder']
73 | decoder_optimizer = checkpoint['decoder_optimizer']
74 | encoder = checkpoint['encoder']
75 | encoder_optimizer = checkpoint['encoder_optimizer']
76 | if fine_tune_encoder is True and encoder_optimizer is None:
77 | encoder.fine_tune(fine_tune_encoder)
78 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
79 | lr=encoder_lr)
80 |
81 | # Move to GPU, if available
82 | decoder = decoder.to(device)
83 | encoder = encoder.to(device)
84 |
85 | # Loss function
86 | criterion = nn.CrossEntropyLoss().to(device)
87 |
88 | # Custom dataloaders
89 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
90 | std=[0.229, 0.224, 0.225])
91 | train_loader = torch.utils.data.DataLoader(
92 | CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
93 | batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
94 | val_loader = torch.utils.data.DataLoader(
95 | CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
96 | batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
97 |
98 | # Epochs
99 | for epoch in range(start_epoch, epochs):
100 |
101 | # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
102 | if epochs_since_improvement == 20:
103 | break
104 | if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
105 | adjust_learning_rate(decoder_optimizer, 0.8)
106 | if fine_tune_encoder:
107 | adjust_learning_rate(encoder_optimizer, 0.8)
108 |
109 | # One epoch's training
110 | train(train_loader=train_loader,
111 | encoder=encoder,
112 | decoder=decoder,
113 | criterion=criterion,
114 | encoder_optimizer=encoder_optimizer,
115 | decoder_optimizer=decoder_optimizer,
116 | epoch=epoch)
117 |
118 | # One epoch's validation
119 | recent_bleu4 = validate(val_loader=val_loader,
120 | encoder=encoder,
121 | decoder=decoder,
122 | criterion=criterion)
123 |
124 | # Check if there was an improvement
125 | is_best = recent_bleu4 > best_bleu4
126 | best_bleu4 = max(recent_bleu4, best_bleu4)
127 | if not is_best:
128 | epochs_since_improvement += 1
129 | print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
130 | else:
131 | epochs_since_improvement = 0
132 |
133 | # Save checkpoint
134 | save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
135 | decoder_optimizer, recent_bleu4, is_best)
136 |
137 |
138 | def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
139 | """
140 | Performs one epoch's training.
141 |
142 | :param train_loader: DataLoader for training data
143 | :param encoder: encoder model
144 | :param decoder: decoder model
145 | :param criterion: loss layer
146 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning)
147 | :param decoder_optimizer: optimizer to update decoder's weights
148 | :param epoch: epoch number
149 | """
150 |
151 | decoder.train() # train mode (dropout and batchnorm is used)
152 | encoder.train()
153 |
154 | batch_time = AverageMeter() # forward prop. + back prop. time
155 | data_time = AverageMeter() # data loading time
156 | losses = AverageMeter() # loss (per word decoded)
157 | top5accs = AverageMeter() # top5 accuracy
158 |
159 | start = time.time()
160 |
161 | # Batches
162 | for i, (imgs, caps, caplens) in enumerate(train_loader):
163 | data_time.update(time.time() - start)
164 |
165 | # Move to GPU, if available
166 | imgs = imgs.to(device)
167 | caps = caps.to(device)
168 | caplens = caplens.to(device)
169 |
170 | # Forward prop.
171 | imgs = encoder(imgs)
172 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
173 |
174 | # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
175 | targets = caps_sorted[:, 1:]
176 |
177 | # Remove timesteps that we didn't decode at, or are pads
178 | # pack_padded_sequence is an easy trick to do this
179 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
180 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
181 |
182 | # Calculate loss
183 | loss = criterion(scores, targets)
184 |
185 | # Add doubly stochastic attention regularization
186 | loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
187 |
188 | # Back prop.
189 | decoder_optimizer.zero_grad()
190 | if encoder_optimizer is not None:
191 | encoder_optimizer.zero_grad()
192 | loss.backward()
193 |
194 | # Clip gradients
195 | if grad_clip is not None:
196 | clip_gradient(decoder_optimizer, grad_clip)
197 | if encoder_optimizer is not None:
198 | clip_gradient(encoder_optimizer, grad_clip)
199 |
200 | # Update weights
201 | decoder_optimizer.step()
202 | if encoder_optimizer is not None:
203 | encoder_optimizer.step()
204 |
205 | # Keep track of metrics
206 | top5 = accuracy(scores, targets, 5)
207 | losses.update(loss.item(), sum(decode_lengths))
208 | top5accs.update(top5, sum(decode_lengths))
209 | batch_time.update(time.time() - start)
210 |
211 | start = time.time()
212 |
213 | # Print status
214 | if i % print_freq == 0:
215 | print('Epoch: [{0}][{1}/{2}]\t'
216 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
217 | 'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
218 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
219 | 'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
220 | batch_time=batch_time,
221 | data_time=data_time, loss=losses,
222 | top5=top5accs))
223 |
224 |
225 | def validate(val_loader, encoder, decoder, criterion):
226 | """
227 | Performs one epoch's validation.
228 |
229 | :param val_loader: DataLoader for validation data.
230 | :param encoder: encoder model
231 | :param decoder: decoder model
232 | :param criterion: loss layer
233 | :return: BLEU-4 score
234 | """
235 | decoder.eval() # eval mode (no dropout or batchnorm)
236 | if encoder is not None:
237 | encoder.eval()
238 |
239 | batch_time = AverageMeter()
240 | losses = AverageMeter()
241 | top5accs = AverageMeter()
242 |
243 | start = time.time()
244 |
245 | references = list() # references (true captions) for calculating BLEU-4 score
246 | hypotheses = list() # hypotheses (predictions)
247 |
248 | # explicitly disable gradient calculation to avoid CUDA memory error
249 | # solves the issue #57
250 | with torch.no_grad():
251 | # Batches
252 | for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):
253 |
254 | # Move to device, if available
255 | imgs = imgs.to(device)
256 | caps = caps.to(device)
257 | caplens = caplens.to(device)
258 |
259 | # Forward prop.
260 | if encoder is not None:
261 | imgs = encoder(imgs)
262 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
263 |
264 | # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
265 | targets = caps_sorted[:, 1:]
266 |
267 | # Remove timesteps that we didn't decode at, or are pads
268 | # pack_padded_sequence is an easy trick to do this
269 | scores_copy = scores.clone()
270 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
271 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
272 |
273 | # Calculate loss
274 | loss = criterion(scores, targets)
275 |
276 | # Add doubly stochastic attention regularization
277 | loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
278 |
279 | # Keep track of metrics
280 | losses.update(loss.item(), sum(decode_lengths))
281 | top5 = accuracy(scores, targets, 5)
282 | top5accs.update(top5, sum(decode_lengths))
283 | batch_time.update(time.time() - start)
284 |
285 | start = time.time()
286 |
287 | if i % print_freq == 0:
288 | print('Validation: [{0}/{1}]\t'
289 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
290 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
291 | 'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
292 | loss=losses, top5=top5accs))
293 |
294 | # Store references (true captions), and hypothesis (prediction) for each image
295 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
296 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
297 |
298 | # References
299 | allcaps = allcaps[sort_ind] # because images were sorted in the decoder
300 | for j in range(allcaps.shape[0]):
301 | img_caps = allcaps[j].tolist()
302 | img_captions = list(
303 | map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
304 | img_caps))  # remove <start> and pads
305 | references.append(img_captions)
306 |
307 | # Hypotheses
308 | _, preds = torch.max(scores_copy, dim=2)
309 | preds = preds.tolist()
310 | temp_preds = list()
311 | for j, p in enumerate(preds):
312 | temp_preds.append(preds[j][:decode_lengths[j]]) # remove pads
313 | preds = temp_preds
314 | hypotheses.extend(preds)
315 |
316 | assert len(references) == len(hypotheses)
317 |
318 | # Calculate BLEU-4 scores
319 | bleu4 = corpus_bleu(references, hypotheses)
320 |
321 | print(
322 | '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
323 | loss=losses,
324 | top5=top5accs,
325 | bleu=bleu4))
326 |
327 | return bleu4
328 |
329 |
330 | if __name__ == '__main__':
331 | main()
332 |
--------------------------------------------------------------------------------
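A small sketch with dummy shapes (not a repository file) of the pack_padded_sequence trick used in train.py: padded or undecoded timesteps are dropped so the cross-entropy loss only sees positions that were actually decoded, and the doubly stochastic attention term from the paper is then added on top.

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

batch_size, max_decode_len, vocab_size, num_pixels = 3, 5, 10, 196
decode_lengths = [5, 4, 2]  # already sorted in decreasing order, as in the decoder

scores = torch.randn(batch_size, max_decode_len, vocab_size)           # (N, T, vocab_size)
targets = torch.randint(0, vocab_size, (batch_size, max_decode_len))   # (N, T)

# Packing keeps only the first decode_lengths[i] timesteps of row i and
# concatenates them (timestep-major) into one flat batch -- pads never reach the loss.
packed_scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
packed_targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
print(packed_scores.shape, packed_targets.shape)  # (11, 10) and (11,), since 5 + 4 + 2 = 11

loss = nn.CrossEntropyLoss()(packed_scores, packed_targets)

# Doubly stochastic attention regularization: encourage the attention weights
# at each pixel to sum to roughly 1 over the decoded timesteps.
alpha_c = 1.
alphas = torch.softmax(torch.randn(batch_size, max_decode_len, num_pixels), dim=2)
loss = loss + alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
print(loss.item())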
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import h5py
4 | import json
5 | import torch
6 | from scipy.misc import imread, imresize
7 | from tqdm import tqdm
8 | from collections import Counter
9 | from random import seed, choice, sample
10 |
11 |
12 | def create_input_files(dataset, karpathy_json_path, image_folder, captions_per_image, min_word_freq, output_folder,
13 | max_len=100):
14 | """
15 | Creates input files for training, validation, and test data.
16 |
17 | :param dataset: name of dataset, one of 'coco', 'flickr8k', 'flickr30k'
18 | :param karpathy_json_path: path of Karpathy JSON file with splits and captions
19 | :param image_folder: folder with downloaded images
20 | :param captions_per_image: number of captions to sample per image
21 | :param min_word_freq: words occurring less frequently than this threshold are binned as <unk>s
22 | :param output_folder: folder to save files
23 | :param max_len: don't sample captions longer than this length
24 | """
25 |
26 | assert dataset in {'coco', 'flickr8k', 'flickr30k'}
27 |
28 | # Read Karpathy JSON
29 | with open(karpathy_json_path, 'r') as j:
30 | data = json.load(j)
31 |
32 | # Read image paths and captions for each image
33 | train_image_paths = []
34 | train_image_captions = []
35 | val_image_paths = []
36 | val_image_captions = []
37 | test_image_paths = []
38 | test_image_captions = []
39 | word_freq = Counter()
40 |
41 | for img in data['images']:
42 | captions = []
43 | for c in img['sentences']:
44 | # Update word frequency
45 | word_freq.update(c['tokens'])
46 | if len(c['tokens']) <= max_len:
47 | captions.append(c['tokens'])
48 |
49 | if len(captions) == 0:
50 | continue
51 |
52 | path = os.path.join(image_folder, img['filepath'], img['filename']) if dataset == 'coco' else os.path.join(
53 | image_folder, img['filename'])
54 |
55 | if img['split'] in {'train', 'restval'}:
56 | train_image_paths.append(path)
57 | train_image_captions.append(captions)
58 | elif img['split'] in {'val'}:
59 | val_image_paths.append(path)
60 | val_image_captions.append(captions)
61 | elif img['split'] in {'test'}:
62 | test_image_paths.append(path)
63 | test_image_captions.append(captions)
64 |
65 | # Sanity check
66 | assert len(train_image_paths) == len(train_image_captions)
67 | assert len(val_image_paths) == len(val_image_captions)
68 | assert len(test_image_paths) == len(test_image_captions)
69 |
70 | # Create word map
71 | words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]
72 | word_map = {k: v + 1 for v, k in enumerate(words)}
73 | word_map['<unk>'] = len(word_map) + 1
74 | word_map['<start>'] = len(word_map) + 1
75 | word_map['<end>'] = len(word_map) + 1
76 | word_map['<pad>'] = 0
77 |
78 | # Create a base/root name for all output files
79 | base_filename = dataset + '_' + str(captions_per_image) + '_cap_per_img_' + str(min_word_freq) + '_min_word_freq'
80 |
81 | # Save word map to a JSON
82 | with open(os.path.join(output_folder, 'WORDMAP_' + base_filename + '.json'), 'w') as j:
83 | json.dump(word_map, j)
84 |
85 | # Sample captions for each image, save images to HDF5 file, and captions and their lengths to JSON files
86 | seed(123)
87 | for impaths, imcaps, split in [(train_image_paths, train_image_captions, 'TRAIN'),
88 | (val_image_paths, val_image_captions, 'VAL'),
89 | (test_image_paths, test_image_captions, 'TEST')]:
90 |
91 | with h5py.File(os.path.join(output_folder, split + '_IMAGES_' + base_filename + '.hdf5'), 'a') as h:
92 | # Make a note of the number of captions we are sampling per image
93 | h.attrs['captions_per_image'] = captions_per_image
94 |
95 | # Create dataset inside HDF5 file to store images
96 | images = h.create_dataset('images', (len(impaths), 3, 256, 256), dtype='uint8')
97 |
98 | print("\nReading %s images and captions, storing to file...\n" % split)
99 |
100 | enc_captions = []
101 | caplens = []
102 |
103 | for i, path in enumerate(tqdm(impaths)):
104 |
105 | # Sample captions
106 | if len(imcaps[i]) < captions_per_image:
107 | captions = imcaps[i] + [choice(imcaps[i]) for _ in range(captions_per_image - len(imcaps[i]))]
108 | else:
109 | captions = sample(imcaps[i], k=captions_per_image)
110 |
111 | # Sanity check
112 | assert len(captions) == captions_per_image
113 |
114 | # Read images
115 | img = imread(impaths[i])
116 | if len(img.shape) == 2:
117 | img = img[:, :, np.newaxis]
118 | img = np.concatenate([img, img, img], axis=2)
119 | img = imresize(img, (256, 256))
120 | img = img.transpose(2, 0, 1)
121 | assert img.shape == (3, 256, 256)
122 | assert np.max(img) <= 255
123 |
124 | # Save image to HDF5 file
125 | images[i] = img
126 |
127 | for j, c in enumerate(captions):
128 | # Encode captions
129 | enc_c = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in c] + [
130 | word_map['<end>']] + [word_map['<pad>']] * (max_len - len(c))
131 |
132 | # Find caption lengths
133 | c_len = len(c) + 2
134 |
135 | enc_captions.append(enc_c)
136 | caplens.append(c_len)
137 |
138 | # Sanity check
139 | assert images.shape[0] * captions_per_image == len(enc_captions) == len(caplens)
140 |
141 | # Save encoded captions and their lengths to JSON files
142 | with open(os.path.join(output_folder, split + '_CAPTIONS_' + base_filename + '.json'), 'w') as j:
143 | json.dump(enc_captions, j)
144 |
145 | with open(os.path.join(output_folder, split + '_CAPLENS_' + base_filename + '.json'), 'w') as j:
146 | json.dump(caplens, j)
147 |
148 |
149 | def init_embedding(embeddings):
150 | """
151 | Fills embedding tensor with values from the uniform distribution.
152 |
153 | :param embeddings: embedding tensor
154 | """
155 | bias = np.sqrt(3.0 / embeddings.size(1))
156 | torch.nn.init.uniform_(embeddings, -bias, bias)
157 |
158 |
159 | def load_embeddings(emb_file, word_map):
160 | """
161 | Creates an embedding tensor for the specified word map, for loading into the model.
162 |
163 | :param emb_file: file containing embeddings (stored in GloVe format)
164 | :param word_map: word map
165 | :return: embeddings in the same order as the words in the word map, dimension of embeddings
166 | """
167 |
168 | # Find embedding dimension
169 | with open(emb_file, 'r') as f:
170 | emb_dim = len(f.readline().split(' ')) - 1
171 |
172 | vocab = set(word_map.keys())
173 |
174 | # Create tensor to hold embeddings, initialize
175 | embeddings = torch.FloatTensor(len(vocab), emb_dim)
176 | init_embedding(embeddings)
177 |
178 | # Read embedding file
179 | print("\nLoading embeddings...")
180 | for line in open(emb_file, 'r'):
181 | line = line.split(' ')
182 |
183 | emb_word = line[0]
184 | embedding = list(map(lambda t: float(t), filter(lambda n: n and not n.isspace(), line[1:])))
185 |
186 | # Ignore word if not in train_vocab
187 | if emb_word not in vocab:
188 | continue
189 |
190 | embeddings[word_map[emb_word]] = torch.FloatTensor(embedding)
191 |
192 | return embeddings, emb_dim
193 |
194 |
195 | def clip_gradient(optimizer, grad_clip):
196 | """
197 | Clips gradients computed during backpropagation to avoid explosion of gradients.
198 |
199 | :param optimizer: optimizer with the gradients to be clipped
200 | :param grad_clip: clip value
201 | """
202 | for group in optimizer.param_groups:
203 | for param in group['params']:
204 | if param.grad is not None:
205 | param.grad.data.clamp_(-grad_clip, grad_clip)
206 |
207 |
208 | def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, decoder_optimizer,
209 | bleu4, is_best):
210 | """
211 | Saves model checkpoint.
212 |
213 | :param data_name: base name of processed dataset
214 | :param epoch: epoch number
215 | :param epochs_since_improvement: number of epochs since last improvement in BLEU-4 score
216 | :param encoder: encoder model
217 | :param decoder: decoder model
218 | :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning
219 | :param decoder_optimizer: optimizer to update decoder's weights
220 | :param bleu4: validation BLEU-4 score for this epoch
221 | :param is_best: is this checkpoint the best so far?
222 | """
223 | state = {'epoch': epoch,
224 | 'epochs_since_improvement': epochs_since_improvement,
225 | 'bleu-4': bleu4,
226 | 'encoder': encoder,
227 | 'decoder': decoder,
228 | 'encoder_optimizer': encoder_optimizer,
229 | 'decoder_optimizer': decoder_optimizer}
230 | filename = 'checkpoint_' + data_name + '.pth.tar'
231 | torch.save(state, filename)
232 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint
233 | if is_best:
234 | torch.save(state, 'BEST_' + filename)
235 |
236 |
237 | class AverageMeter(object):
238 | """
239 | Keeps track of most recent, average, sum, and count of a metric.
240 | """
241 |
242 | def __init__(self):
243 | self.reset()
244 |
245 | def reset(self):
246 | self.val = 0
247 | self.avg = 0
248 | self.sum = 0
249 | self.count = 0
250 |
251 | def update(self, val, n=1):
252 | self.val = val
253 | self.sum += val * n
254 | self.count += n
255 | self.avg = self.sum / self.count
256 |
257 |
258 | def adjust_learning_rate(optimizer, shrink_factor):
259 | """
260 | Shrinks learning rate by a specified factor.
261 |
262 | :param optimizer: optimizer whose learning rate must be shrunk.
263 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with.
264 | """
265 |
266 | print("\nDECAYING learning rate.")
267 | for param_group in optimizer.param_groups:
268 | param_group['lr'] = param_group['lr'] * shrink_factor
269 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))
270 |
271 |
272 | def accuracy(scores, targets, k):
273 | """
274 | Computes top-k accuracy, from predicted and true labels.
275 |
276 | :param scores: scores from the model
277 | :param targets: true labels
278 | :param k: k in top-k accuracy
279 | :return: top-k accuracy
280 | """
281 |
282 | batch_size = targets.size(0)
283 | _, ind = scores.topk(k, 1, True, True)
284 | correct = ind.eq(targets.view(-1, 1).expand_as(ind))
285 | correct_total = correct.view(-1).float().sum() # 0D tensor
286 | return correct_total.item() * (100.0 / batch_size)
287 |
--------------------------------------------------------------------------------
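To close, a compact sketch (illustrative toy values, not a repository file) of how create_input_files() in utils.py builds the word map and encodes one caption with the <start>/<end>/<unk>/<pad> tokens; the tiny corpus and thresholds below are made up.

from collections import Counter

min_word_freq, max_len = 1, 8
word_freq = Counter("a man rides a bike a dog runs".split())

words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]
word_map = {k: v + 1 for v, k in enumerate(words)}  # real words start at index 1
word_map['<unk>'] = len(word_map) + 1
word_map['<start>'] = len(word_map) + 1
word_map['<end>'] = len(word_map) + 1
word_map['<pad>'] = 0

caption = "a man rides a bike".split()
enc_c = [word_map['<start>']] + \
        [word_map.get(word, word_map['<unk>']) for word in caption] + \
        [word_map['<end>']] + [word_map['<pad>']] * (max_len - len(caption))
c_len = len(caption) + 2  # the caption plus <start> and <end>

print(word_map)  # {'a': 1, '<unk>': 2, '<start>': 3, '<end>': 4, '<pad>': 0}
print(enc_c)     # [3, 1, 2, 2, 1, 2, 4, 0, 0, 0]
print(c_len)     # 7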