├── .gitignore ├── LICENSE ├── README.md ├── analyze_data.py ├── config.py ├── data_generator.py ├── demo.py ├── eval.py ├── font ├── WenQuanYiMicroHei-01.ttf ├── WenQuanYiMicroHeiMono-02.ttf └── simhei.ttf ├── images ├── dataset.png ├── image_0.jpg ├── image_1.jpg ├── image_2.jpg ├── image_3.jpg ├── image_4.jpg ├── image_5.jpg ├── image_6.jpg ├── image_7.jpg ├── image_8.jpg ├── image_9.jpg ├── net.png ├── out_0.jpg ├── out_1.jpg ├── out_2.jpg ├── out_3.jpg ├── out_4.jpg ├── out_5.jpg ├── out_6.jpg ├── out_7.jpg ├── out_8.jpg └── out_9.jpg ├── models.py ├── pre_process.py ├── requirements.txt ├── sponsor.jpg ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__/ 3 | logs/ 4 | models/ 5 | data/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chinese Image Captioning 2 | 3 | A PyTorch implementation of Chinese image captioning with visual attention. 4 | 5 | [Show, Attend, and Tell](https://arxiv.org/pdf/1502.03044.pdf) is remarkable work; the authors' original implementation is [here](https://github.com/kelvinxu/arctic-captions). 6 | 7 | The model learns "where to look": as it generates a caption word by word, its gaze moves over the image to focus on the region most relevant to the next word. 8 | 9 | ## Dependencies 10 | - Python 3.5 11 | - PyTorch 0.4 12 | 13 | ## Dataset 14 | 15 | We use the AI Challenger 2017 Chinese image-captioning dataset, which contains 300,000 images with 1.5 million Chinese captions. Training set: 210,000 images; validation set: 30,000; test set A: 30,000; test set B: 30,000. 16 | 17 | 18 | ![image](https://github.com/foamliu/Image-Captioning-v2/raw/master/images/dataset.png) 19 | 20 | Download it here: [Chinese image-captioning dataset](https://challenger.ai/datasets/), and place it under the data directory. 21 | 22 | 23 | ## Network Architecture 24 | 25 | ![image](https://github.com/foamliu/Image-Captioning-v2/raw/master/images/net.png) 26 | 27 | ## Usage 28 | 29 | ### Data preprocessing 30 | Extract the 210,000 training images and 30,000 validation images: 31 | ```bash 32 | $ python pre_process.py 33 | ``` 34 | 35 | ### Training 36 | ```bash 37 | $ python train.py 38 | ``` 39 | 40 | To visualize the training process, run: 41 | ```bash 42 | $ tensorboard --logdir path_to_current_dir/logs 43 | ``` 44 | 45 | ### Demo 46 | Download the [pretrained model](https://github.com/foamliu/Image-Captioning-v2/releases/download/v1.0/BEST_checkpoint_.pth.tar), place it in the models directory, then run: 47 | 48 | ```bash 49 | $ python demo.py 50 | ``` 51 | 52 | |Original|Attention| 53 | |---|---| 54 | ||![image](https://github.com/foamliu/Image-Captioning-v2/raw/master/images/out_0.jpg) | 55 | ||![image](https://github.com/foamliu/Image-Captioning-v2/raw/master/images/out_1.jpg) | 56 | ||![image](https://github.com/foamliu/Image-Captioning-v2/raw/master/images/out_2.jpg) | 57 | ||![image](https://github.com/foamliu/Image-Captioning-v2/raw/master/images/out_3.jpg) | 58 | ||![image](https://github.com/foamliu/Image-Captioning-v2/raw/master/images/out_4.jpg) | 59 | ||![image](https://github.com/foamliu/Image-Captioning-v2/raw/master/images/out_5.jpg) | 60 | ||![image](https://github.com/foamliu/Image-Captioning-v2/raw/master/images/out_6.jpg) | 61 | ||![image](https://github.com/foamliu/Image-Captioning-v2/raw/master/images/out_7.jpg) | 62 | ||![image](https://github.com/foamliu/Image-Captioning-v2/raw/master/images/out_8.jpg) | 63 | ||![image](https://github.com/foamliu/Image-Captioning-v2/raw/master/images/out_9.jpg) | 64 |
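If you prefer to call the captioner from Python instead of the command line, here is a minimal sketch that reuses the helpers defined in demo.py. It assumes the pretrained checkpoint is saved as `BEST_checkpoint_.pth.tar` (the default path demo.py expects) and that `data/WORDMAP.json` has already been generated by pre_process.py; adjust the paths to your own layout.

```python
# Programmatic captioning sketch; the checkpoint and image paths are assumptions.
import json

import torch

from demo import caption_image_beam_search, visualize_att

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The checkpoint stores the whole encoder/decoder modules (see demo.py / train.py).
checkpoint = torch.load('BEST_checkpoint_.pth.tar', map_location=device)
encoder = checkpoint['encoder'].to(device).eval()
decoder = checkpoint['decoder'].to(device).eval()

with open('data/WORDMAP.json', 'r') as j:
    word_map = json.load(j)                          # word -> index
rev_word_map = {v: k for k, v in word_map.items()}   # index -> word

img_path = 'images/image_0.jpg'
seq, alphas = caption_image_beam_search(encoder, decoder, img_path, word_map, beam_size=5)
print(''.join(rev_word_map[ind] for ind in seq))     # raw token sequence, including special tokens
visualize_att(img_path, seq, torch.FloatTensor(alphas), rev_word_map, i=0, smooth=True)
```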

65 | ## Sponsorship 66 | 67 | ![Sample](sponsor.jpg) 68 | 69 | If this project helps you, a small sponsorship would be appreciated. 70 | 71 | 72 |
73 | 74 | -------------------------------------------------------------------------------- /analyze_data.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | import json 3 | import os 4 | 5 | import jieba 6 | from tqdm import tqdm 7 | 8 | from config import train_annotations_filename 9 | 10 | if __name__ == '__main__': 11 | print('Calculating the maximum length among all training captions') 12 | annotations_path = train_annotations_filename # config.py already joins train_folder into this path 13 | 14 | with open(annotations_path, 'r') as f: 15 | samples = json.load(f) 16 | 17 | max_len = 0 18 | for sample in tqdm(samples): 19 | caption = sample['caption'] 20 | for c in caption: 21 | seg_list = jieba.cut(c, cut_all=True) 22 | length = sum(1 for item in seg_list) 23 | if length > max_len: 24 | max_len = length 25 | print('max_len: ' + str(max_len)) 26 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.backends.cudnn as cudnn 5 | 6 | image_h = image_w = image_size = 256 7 | channel = 3 8 | # epochs = 10000 # superseded by the training setting below 9 | patience = 10 10 | num_train_samples = 1050000 11 | num_valid_samples = 150000 12 | max_len = 40 13 | captions_per_image = 5 14 | 15 | # Model parameters 16 | emb_dim = 512 # dimension of word embeddings 17 | attention_dim = 512 # dimension of attention linear layers 18 | decoder_dim = 512 # dimension of decoder RNN 19 | dropout = 0.5 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # sets device for model and PyTorch tensors 21 | cudnn.benchmark = True # set to true only if inputs to the model are fixed size; otherwise there is a lot of computational overhead 22 | 23 | # Training parameters 24 | start_epoch = 0 25 | epochs = 120 # number of epochs to train for (if early stopping is not triggered) 26 | epochs_since_improvement = 0 # keeps track of the number of epochs since there's been an improvement in validation BLEU 27 | batch_size = 128 28 | workers = 2 # for data loading 29 | encoder_lr = 1e-4 # learning rate for encoder if fine-tuning 30 | decoder_lr = 4e-4 # learning rate for decoder 31 | grad_clip = 5. # absolute value at which gradients are clipped 32 | alpha_c = 1. # regularization parameter for 'doubly stochastic attention', as in the paper 33 | best_bleu4 = 0. # BLEU-4 score right now 34 | print_freq = 100 # print training/validation stats every print_freq batches 35 | fine_tune_encoder = False # fine-tune encoder?
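# Note: grad_clip and alpha_c above are consumed in train.py; gradients are clipped to the
# absolute value grad_clip, and the doubly stochastic attention penalty
#     alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
# is added to the loss, as in the Show, Attend and Tell paper.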
36 | checkpoint = None # path to checkpoint, None if none 37 | min_word_freq = 3 38 | 39 | # Data parameters 40 | data_folder = 'data' 41 | train_folder = 'data/ai_challenger_caption_train_20170902' 42 | valid_folder = 'data/ai_challenger_caption_validation_20170910' 43 | test_a_folder = 'data/ai_challenger_caption_test_a_20180103' 44 | test_b_folder = 'data/ai_challenger_caption_test_b_20180103' 45 | train_image_folder = os.path.join(train_folder, 'caption_train_images_20170902') 46 | valid_image_folder = os.path.join(valid_folder, 'caption_validation_images_20170910') 47 | test_a_image_folder = os.path.join(test_a_folder, 'caption_test_a_images_20180103') 48 | test_b_image_folder = os.path.join(test_b_folder, 'caption_test_b_images_20180103') 49 | train_annotations_filename = os.path.join(train_folder, 'caption_train_annotations_20170902.json') 50 | valid_annotations_filename = os.path.join(valid_folder, 'caption_validation_annotations_20170910.json') 51 | test_a_annotations_filename = os.path.join(test_a_folder, 'caption_test_a_annotations_20180103.json') 52 | test_b_annotations_filename = os.path.join(test_b_folder, 'caption_test_b_annotations_20180103.json') 53 | -------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | import json 3 | 4 | import jieba 5 | import numpy as np 6 | from scipy.misc import imread, imresize 7 | from torch.utils.data import Dataset 8 | 9 | from config import * 10 | 11 | 12 | def encode_caption(word_map, c): 13 | return [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in c] + [ 14 | word_map['<end>']] + [word_map['<pad>']] * (max_len - len(c)) 15 | 16 | 17 | class CaptionDataset(Dataset): 18 | """ 19 | A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches. 20 | """ 21 | 22 | def __init__(self, split, transform=None): 23 | """ 24 | :param split: split, one of 'train' or 'valid' 25 | :param transform: image transform pipeline 26 | """ 27 | self.split = split 28 | assert self.split in {'train', 'valid'} 29 | 30 | if split == 'train': 31 | json_path = train_annotations_filename 32 | self.image_folder = train_image_folder 33 | else: 34 | json_path = valid_annotations_filename 35 | self.image_folder = valid_image_folder 36 | 37 | # Read JSON 38 | with open(json_path, 'r') as j: 39 | self.samples = json.load(j) 40 | 41 | # Read wordmap 42 | with open(os.path.join(data_folder, 'WORDMAP.json'), 'r') as j: 43 | self.word_map = json.load(j) 44 | 45 | # PyTorch transformation pipeline for the image (normalizing, etc.) 46 | self.transform = transform 47 | 48 | # Total number of datapoints 49 | self.dataset_size = len(self.samples) * captions_per_image 50 | 51 | def __getitem__(self, i): 52 | # Remember, the Nth caption corresponds to the (N // captions_per_image)th image 53 | sample = self.samples[i // captions_per_image] 54 | path = os.path.join(self.image_folder, sample['image_id']) 55 | # Read images 56 | img = imread(path) 57 | if len(img.shape) == 2: 58 | img = img[:, :, np.newaxis] 59 | img = np.concatenate([img, img, img], axis=2) 60 | img = imresize(img, (256, 256)) 61 | img = img.transpose(2, 0, 1) 62 | assert img.shape == (3, 256, 256) 63 | assert np.max(img) <= 255 64 | img = torch.FloatTensor(img / 255.)
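# At this point img is a (3, 256, 256) float tensor scaled to [0, 1]; the transform applied
# below is expected to be the ImageNet-style normalization built in train.py
# (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])).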
65 | if self.transform is not None: 66 | img = self.transform(img) 67 | 68 | # Sample captions 69 | captions = sample['caption'] 70 | # Sanity check 71 | assert len(captions) == captions_per_image 72 | c = captions[i % captions_per_image] 73 | c = list(jieba.cut(c)) 74 | # Encode captions 75 | enc_c = encode_caption(self.word_map, c) 76 | 77 | caption = torch.LongTensor(enc_c) 78 | 79 | caplen = torch.LongTensor([len(c) + 2]) 80 | 81 | if self.split is 'train': 82 | return img, caption, caplen 83 | else: 84 | # For validation of testing, also return all 'captions_per_image' captions to find BLEU-4 score 85 | all_captions = torch.LongTensor([encode_caption(self.word_map, list(jieba.cut(c))) for c in captions]) 86 | return img, caption, caplen, all_captions 87 | 88 | def __len__(self): 89 | return self.dataset_size 90 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding:utf-8 3 | """a demo of matplotlib""" 4 | import matplotlib as mpl 5 | from matplotlib import pyplot as plt 6 | mpl.rcParams[u'font.sans-serif'] = ['simhei'] 7 | mpl.rcParams['axes.unicode_minus'] = False 8 | 9 | import argparse 10 | import json 11 | 12 | import matplotlib.cm as cm 13 | import matplotlib.pyplot as plt 14 | 15 | plt.rcParams['font.sans-serif'] = ['SimHei'] # for Windows 16 | import numpy as np 17 | import skimage.transform 18 | import torch 19 | import torch.nn.functional as F 20 | import torchvision.transforms as transforms 21 | from PIL import Image 22 | from scipy.misc import imread, imresize 23 | 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | 27 | def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3): 28 | """ 29 | Reads an image and captions it with beam search. 30 | :param encoder: encoder model 31 | :param decoder: decoder model 32 | :param image_path: path to image 33 | :param word_map: word map 34 | :param beam_size: number of sequences to consider at each decode-step 35 | :return: caption, weights for visualization 36 | """ 37 | 38 | k = beam_size 39 | vocab_size = len(word_map) 40 | 41 | # Read image and process 42 | img = imread(image_path) 43 | if len(img.shape) == 2: 44 | img = img[:, :, np.newaxis] 45 | img = np.concatenate([img, img, img], axis=2) 46 | img = imresize(img, (256, 256)) 47 | img = img.transpose(2, 0, 1) 48 | img = img / 255. 
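# The preprocessing above mirrors CaptionDataset.__getitem__ in data_generator.py: grayscale
# images are stacked to three channels, resized to 256x256, transposed to CHW and scaled to
# [0, 1]; the same ImageNet normalization used during training is applied next before encoding.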
49 | img = torch.FloatTensor(img).to(device) 50 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 51 | std=[0.229, 0.224, 0.225]) 52 | transform = transforms.Compose([normalize]) 53 | image = transform(img) # (3, 256, 256) 54 | 55 | # Encode 56 | image = image.unsqueeze(0) # (1, 3, 256, 256) 57 | encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim) 58 | enc_image_size = encoder_out.size(1) 59 | encoder_dim = encoder_out.size(3) 60 | 61 | # Flatten encoding 62 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim) 63 | num_pixels = encoder_out.size(1) 64 | 65 | # We'll treat the problem as having a batch size of k 66 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim) 67 | 68 | # Tensor to store top k previous words at each step; now they're just 69 | k_prev_words = torch.LongTensor([[word_map['']]] * k).to(device) # (k, 1) 70 | 71 | # Tensor to store top k sequences; now they're just 72 | seqs = k_prev_words # (k, 1) 73 | 74 | # Tensor to store top k sequences' scores; now they're just 0 75 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1) 76 | 77 | # Tensor to store top k sequences' alphas; now they're just 1s 78 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size) 79 | 80 | # Lists to store completed sequences, their alphas and scores 81 | complete_seqs = list() 82 | complete_seqs_alpha = list() 83 | complete_seqs_scores = list() 84 | 85 | # Start decoding 86 | step = 1 87 | h, c = decoder.init_hidden_state(encoder_out) 88 | 89 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 90 | while True: 91 | 92 | embeddings = decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim) 93 | 94 | awe, alpha = decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels) 95 | 96 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size) 97 | 98 | gate = decoder.sigmoid(decoder.f_beta(h)) # gating scalar, (s, encoder_dim) 99 | awe = gate * awe 100 | 101 | h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim) 102 | 103 | scores = decoder.fc(h) # (s, vocab_size) 104 | scores = F.log_softmax(scores, dim=1) 105 | 106 | # Add 107 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size) 108 | 109 | # For the first step, all k points will have the same scores (since same k previous words, h, c) 110 | if step == 1: 111 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s) 112 | else: 113 | # Unroll and find top scores, and their unrolled indices 114 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 115 | 116 | # Convert unrolled indices to actual indices of scores 117 | prev_word_inds = top_k_words / vocab_size # (s) 118 | next_word_inds = top_k_words % vocab_size # (s) 119 | 120 | # Add new words to sequences, alphas 121 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 122 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)], 123 | dim=1) # (s, step+1, enc_image_size, enc_image_size) 124 | 125 | # Which sequences are incomplete (didn't reach )? 
126 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if 127 | next_word != word_map['']] 128 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds)) 129 | 130 | # Set aside complete sequences 131 | if len(complete_inds) > 0: 132 | complete_seqs.extend(seqs[complete_inds].tolist()) 133 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist()) 134 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 135 | k -= len(complete_inds) # reduce beam length accordingly 136 | 137 | # Proceed with incomplete sequences 138 | if k == 0: 139 | break 140 | seqs = seqs[incomplete_inds] 141 | seqs_alpha = seqs_alpha[incomplete_inds] 142 | h = h[prev_word_inds[incomplete_inds]] 143 | c = c[prev_word_inds[incomplete_inds]] 144 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]] 145 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 146 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) 147 | 148 | # Break if things have been going on too long 149 | if step > 50: 150 | break 151 | step += 1 152 | 153 | i = complete_seqs_scores.index(max(complete_seqs_scores)) 154 | seq = complete_seqs[i] 155 | alphas = complete_seqs_alpha[i] 156 | 157 | return seq, alphas 158 | 159 | 160 | def visualize_att(image_path, seq, alphas, rev_word_map, i, smooth=True): 161 | """ 162 | Visualizes caption with weights at every word. 163 | Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb 164 | :param image_path: path to image that has been captioned 165 | :param seq: caption 166 | :param alphas: weights 167 | :param rev_word_map: reverse word mapping, i.e. ix2word 168 | :param smooth: smooth weights? 169 | """ 170 | image = Image.open(image_path) 171 | image = image.resize([14 * 24, 14 * 24], Image.LANCZOS) 172 | 173 | words = [rev_word_map[ind] for ind in seq] 174 | print(words) 175 | 176 | for t in range(len(words)): 177 | if t > 50: 178 | break 179 | plt.subplot(np.ceil(len(words) / 5.), 5, t + 1) 180 | 181 | plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12) 182 | plt.imshow(image) 183 | current_alpha = alphas[t, :] 184 | if smooth: 185 | alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=24, sigma=8) 186 | else: 187 | alpha = skimage.transform.resize(current_alpha.numpy(), [14 * 24, 14 * 24]) 188 | if t == 0: 189 | plt.imshow(alpha, alpha=0) 190 | else: 191 | plt.imshow(alpha, alpha=0.8) 192 | plt.set_cmap(cm.Greys_r) 193 | plt.axis('off') 194 | 195 | plt.savefig('images/out_{}.jpg'.format(i)) 196 | plt.close() 197 | 198 | 199 | if __name__ == '__main__': 200 | parser = argparse.ArgumentParser(description='Show, Attend, and Tell - Tutorial - Generate Caption') 201 | 202 | parser.add_argument('--img', '-i', help='path to image') 203 | parser.add_argument('--model', '-m', default='BEST_checkpoint_.pth.tar', help='path to model') 204 | parser.add_argument('--word_map', '-wm', default='data/WORDMAP.json', help='path to word map JSON') 205 | parser.add_argument('--beam_size', '-b', default=5, type=int, help='beam size for beam search') 206 | parser.add_argument('--dont_smooth', dest='smooth', action='store_false', help='do not smooth alpha overlay') 207 | 208 | args = parser.parse_args() 209 | 210 | # Load model 211 | checkpoint = torch.load(args.model) 212 | decoder = checkpoint['decoder'] 213 | decoder = decoder.to(device) 214 | decoder.eval() 215 | encoder = checkpoint['encoder'] 216 | encoder = encoder.to(device) 217 | 
encoder.eval() 218 | 219 | # Load word map (word2ix) 220 | with open(args.word_map, 'r') as j: 221 | word_map = json.load(j) 222 | rev_word_map = {v: k for k, v in word_map.items()} # ix2word 223 | 224 | from config import * 225 | from tqdm import tqdm 226 | 227 | for i in tqdm(range(10)): 228 | img = 'images/image_{}.jpg'.format(i) 229 | 230 | # Encode, decode with attention and beam search 231 | seq, alphas = caption_image_beam_search(encoder, decoder, img, word_map, args.beam_size) 232 | alphas = torch.FloatTensor(alphas) 233 | 234 | # Visualize caption and attention of best sequence 235 | visualize_att(img, seq, alphas, rev_word_map, i, args.smooth) 236 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.optim 3 | import torch.utils.data 4 | import torchvision.transforms as transforms 5 | from nltk.translate.bleu_score import corpus_bleu 6 | from tqdm import tqdm 7 | 8 | from data_generator import * 9 | from utils import * 10 | 11 | # Parameters 12 | data_folder = '/media/ssd/caption data' # folder with data files saved by create_input_files.py 13 | data_name = 'coco_5_cap_per_img_5_min_word_freq' # base name shared by data files 14 | checkpoint = '../BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar' # model checkpoint 15 | word_map_file = '/media/ssd/caption data/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json' # word map, ensure it's the same the data was encoded with and the model was trained with 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors 17 | cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise lot of computational overhead 18 | 19 | # Load model 20 | checkpoint = torch.load(checkpoint) 21 | decoder = checkpoint['decoder'] 22 | decoder = decoder.to(device) 23 | decoder.eval() 24 | encoder = checkpoint['encoder'] 25 | encoder = encoder.to(device) 26 | encoder.eval() 27 | 28 | # Load word map (word2ix) 29 | with open(word_map_file, 'r') as j: 30 | word_map = json.load(j) 31 | rev_word_map = {v: k for k, v in word_map.items()} 32 | vocab_size = len(word_map) 33 | 34 | # Normalization transform 35 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 36 | std=[0.229, 0.224, 0.225]) 37 | 38 | 39 | def evaluate(beam_size): 40 | """ 41 | Evaluation 42 | :param beam_size: beam size at which to generate captions for evaluation 43 | :return: BLEU-4 score 44 | """ 45 | # DataLoader 46 | loader = torch.utils.data.DataLoader( 47 | CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])), 48 | batch_size=1, shuffle=True, num_workers=1, pin_memory=True) 49 | 50 | # TODO: Batched Beam Search 51 | 52 | # Lists to store references (true captions), and hypothesis (prediction) for each image 53 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need - 54 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...] 
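# Toy illustration (hypothetical token ids) of the structure corpus_bleu expects:
#   references = [[[3, 7, 9, 2], [3, 8, 9, 2]],   # image 1: two reference captions
#                 [[4, 5, 6, 2]]]                 # image 2: one reference caption
#   hypotheses = [[3, 7, 9, 2], [4, 5, 2]]        # one hypothesis per image, same order
#   corpus_bleu(references, hypotheses)           # corpus-level BLEU-4 with default weights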
55 | references = list() 56 | hypotheses = list() 57 | 58 | # For each image 59 | for i, (image, caps, caplens, allcaps) in enumerate( 60 | tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size))): 61 | 62 | k = beam_size 63 | 64 | # Move to GPU device, if available 65 | image = image.to(device) # (1, 3, 256, 256) 66 | 67 | # Encode 68 | encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim) 69 | enc_image_size = encoder_out.size(1) 70 | encoder_dim = encoder_out.size(3) 71 | 72 | # Flatten encoding 73 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim) 74 | num_pixels = encoder_out.size(1) 75 | 76 | # We'll treat the problem as having a batch size of k 77 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim) 78 | 79 | # Tensor to store top k previous words at each step; now they're just 80 | k_prev_words = torch.LongTensor([[word_map['']]] * k).to(device) # (k, 1) 81 | 82 | # Tensor to store top k sequences; now they're just 83 | seqs = k_prev_words # (k, 1) 84 | 85 | # Tensor to store top k sequences' scores; now they're just 0 86 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1) 87 | 88 | # Lists to store completed sequences and scores 89 | complete_seqs = list() 90 | complete_seqs_scores = list() 91 | 92 | # Start decoding 93 | step = 1 94 | h, c = decoder.init_hidden_state(encoder_out) 95 | 96 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 97 | while True: 98 | 99 | embeddings = decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim) 100 | 101 | awe, _ = decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels) 102 | 103 | gate = decoder.sigmoid(decoder.f_beta(h)) # gating scalar, (s, encoder_dim) 104 | awe = gate * awe 105 | 106 | h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim) 107 | 108 | scores = decoder.fc(h) # (s, vocab_size) 109 | scores = F.log_softmax(scores, dim=1) 110 | 111 | # Add 112 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size) 113 | 114 | # For the first step, all k points will have the same scores (since same k previous words, h, c) 115 | if step == 1: 116 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s) 117 | else: 118 | # Unroll and find top scores, and their unrolled indices 119 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 120 | 121 | # Convert unrolled indices to actual indices of scores 122 | prev_word_inds = top_k_words / vocab_size # (s) 123 | next_word_inds = top_k_words % vocab_size # (s) 124 | 125 | # Add new words to sequences 126 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 127 | 128 | # Which sequences are incomplete (didn't reach )? 
129 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if 130 | next_word != word_map['']] 131 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds)) 132 | 133 | # Set aside complete sequences 134 | if len(complete_inds) > 0: 135 | complete_seqs.extend(seqs[complete_inds].tolist()) 136 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 137 | k -= len(complete_inds) # reduce beam length accordingly 138 | 139 | # Proceed with incomplete sequences 140 | if k == 0: 141 | break 142 | seqs = seqs[incomplete_inds] 143 | h = h[prev_word_inds[incomplete_inds]] 144 | c = c[prev_word_inds[incomplete_inds]] 145 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]] 146 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 147 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) 148 | 149 | # Break if things have been going on too long 150 | if step > 50: 151 | break 152 | step += 1 153 | 154 | i = complete_seqs_scores.index(max(complete_seqs_scores)) 155 | seq = complete_seqs[i] 156 | 157 | # References 158 | img_caps = allcaps[0].tolist() 159 | img_captions = list( 160 | map(lambda c: [w for w in c if w not in {word_map[''], word_map[''], word_map['']}], 161 | img_caps)) # remove and pads 162 | references.append(img_captions) 163 | 164 | # Hypotheses 165 | hypotheses.append([w for w in seq if w not in {word_map[''], word_map[''], word_map['']}]) 166 | 167 | assert len(references) == len(hypotheses) 168 | 169 | # Calculate BLEU-4 scores 170 | bleu4 = corpus_bleu(references, hypotheses, emulate_multibleu=True) 171 | 172 | return bleu4 173 | 174 | 175 | if __name__ == '__main__': 176 | beam_size = 5 177 | print("\nBLEU-4 score @ beam size of %d is %.4f." % (beam_size, evaluate(beam_size))) 178 | -------------------------------------------------------------------------------- /font/WenQuanYiMicroHei-01.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/font/WenQuanYiMicroHei-01.ttf -------------------------------------------------------------------------------- /font/WenQuanYiMicroHeiMono-02.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/font/WenQuanYiMicroHeiMono-02.ttf -------------------------------------------------------------------------------- /font/simhei.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/font/simhei.ttf -------------------------------------------------------------------------------- /images/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/dataset.png -------------------------------------------------------------------------------- /images/image_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_0.jpg -------------------------------------------------------------------------------- /images/image_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_1.jpg -------------------------------------------------------------------------------- /images/image_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_2.jpg -------------------------------------------------------------------------------- /images/image_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_3.jpg -------------------------------------------------------------------------------- /images/image_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_4.jpg -------------------------------------------------------------------------------- /images/image_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_5.jpg -------------------------------------------------------------------------------- /images/image_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_6.jpg -------------------------------------------------------------------------------- /images/image_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_7.jpg -------------------------------------------------------------------------------- /images/image_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_8.jpg -------------------------------------------------------------------------------- /images/image_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_9.jpg -------------------------------------------------------------------------------- /images/net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/net.png -------------------------------------------------------------------------------- /images/out_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_0.jpg -------------------------------------------------------------------------------- /images/out_1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_1.jpg -------------------------------------------------------------------------------- /images/out_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_2.jpg -------------------------------------------------------------------------------- /images/out_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_3.jpg -------------------------------------------------------------------------------- /images/out_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_4.jpg -------------------------------------------------------------------------------- /images/out_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_5.jpg -------------------------------------------------------------------------------- /images/out_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_6.jpg -------------------------------------------------------------------------------- /images/out_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_7.jpg -------------------------------------------------------------------------------- /images/out_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_8.jpg -------------------------------------------------------------------------------- /images/out_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_9.jpg -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class Encoder(nn.Module): 9 | """ 10 | Encoder. 
11 | """ 12 | 13 | def __init__(self, encoded_image_size=14): 14 | super(Encoder, self).__init__() 15 | self.enc_image_size = encoded_image_size 16 | 17 | resnet = torchvision.models.resnet101(pretrained=True) # pretrained ImageNet ResNet-101 18 | 19 | # Remove linear and pool layers (since we're not doing classification) 20 | modules = list(resnet.children())[:-2] 21 | self.resnet = nn.Sequential(*modules) 22 | 23 | # Resize image to fixed size to allow input images of variable size 24 | self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 25 | 26 | self.fine_tune() 27 | 28 | def forward(self, images): 29 | """ 30 | Forward propagation. 31 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size) 32 | :return: encoded images 33 | """ 34 | out = self.resnet(images) # (batch_size, 2048, image_size/32, image_size/32) 35 | out = self.adaptive_pool(out) # (batch_size, 2048, encoded_image_size, encoded_image_size) 36 | out = out.permute(0, 2, 3, 1) # (batch_size, encoded_image_size, encoded_image_size, 2048) 37 | return out 38 | 39 | def fine_tune(self, fine_tune=True): 40 | """ 41 | Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder. 42 | :param fine_tune: Allow? 43 | """ 44 | for p in self.resnet.parameters(): 45 | p.requires_grad = False 46 | # If fine-tuning, only fine-tune convolutional blocks 2 through 4 47 | for c in list(self.resnet.children())[5:]: 48 | for p in c.parameters(): 49 | p.requires_grad = fine_tune 50 | 51 | 52 | class Attention(nn.Module): 53 | """ 54 | Attention Network. 55 | """ 56 | 57 | def __init__(self, encoder_dim, decoder_dim, attention_dim): 58 | """ 59 | :param encoder_dim: feature size of encoded images 60 | :param decoder_dim: size of decoder's RNN 61 | :param attention_dim: size of the attention network 62 | """ 63 | super(Attention, self).__init__() 64 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image 65 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output 66 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed 67 | self.relu = nn.ReLU() 68 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights 69 | 70 | def forward(self, encoder_out, decoder_hidden): 71 | """ 72 | Forward propagation. 73 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 74 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 75 | :return: attention weighted encoding, weights 76 | """ 77 | att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim) 78 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim) 79 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels) 80 | alpha = self.softmax(att) # (batch_size, num_pixels) 81 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_dim) 82 | 83 | return attention_weighted_encoding, alpha 84 | 85 | 86 | class DecoderWithAttention(nn.Module): 87 | """ 88 | Decoder. 
89 | """ 90 | 91 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5): 92 | """ 93 | :param attention_dim: size of attention network 94 | :param embed_dim: embedding size 95 | :param decoder_dim: size of decoder's RNN 96 | :param vocab_size: size of vocabulary 97 | :param encoder_dim: feature size of encoded images 98 | :param dropout: dropout 99 | """ 100 | super(DecoderWithAttention, self).__init__() 101 | 102 | self.encoder_dim = encoder_dim 103 | self.attention_dim = attention_dim 104 | self.embed_dim = embed_dim 105 | self.decoder_dim = decoder_dim 106 | self.vocab_size = vocab_size 107 | self.dropout = dropout 108 | 109 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network 110 | 111 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer 112 | self.dropout = nn.Dropout(p=self.dropout) 113 | self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoding LSTMCell 114 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell 115 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell 116 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate 117 | self.sigmoid = nn.Sigmoid() 118 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary 119 | self.init_weights() # initialize some layers with the uniform distribution 120 | 121 | def init_weights(self): 122 | """ 123 | Initializes some parameters with values from the uniform distribution, for easier convergence. 124 | """ 125 | self.embedding.weight.data.uniform_(-0.1, 0.1) 126 | self.fc.bias.data.fill_(0) 127 | self.fc.weight.data.uniform_(-0.1, 0.1) 128 | 129 | def load_pretrained_embeddings(self, embeddings): 130 | """ 131 | Loads embedding layer with pre-trained embeddings. 132 | :param embeddings: pre-trained embeddings 133 | """ 134 | self.embedding.weight = nn.Parameter(embeddings) 135 | 136 | def fine_tune_embeddings(self, fine_tune=True): 137 | """ 138 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings). 139 | :param fine_tune: Allow? 140 | """ 141 | for p in self.embedding.parameters(): 142 | p.requires_grad = fine_tune 143 | 144 | def init_hidden_state(self, encoder_out): 145 | """ 146 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 147 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 148 | :return: hidden state, cell state 149 | """ 150 | mean_encoder_out = encoder_out.mean(dim=1) 151 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim) 152 | c = self.init_c(mean_encoder_out) 153 | return h, c 154 | 155 | def forward(self, encoder_out, encoded_captions, caption_lengths): 156 | """ 157 | Forward propagation. 
158 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 159 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length) 160 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) 161 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices 162 | """ 163 | 164 | batch_size = encoder_out.size(0) 165 | encoder_dim = encoder_out.size(-1) 166 | vocab_size = self.vocab_size 167 | 168 | # Flatten image 169 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 170 | num_pixels = encoder_out.size(1) 171 | 172 | # Sort input data by decreasing lengths; why? apparent below 173 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True) 174 | encoder_out = encoder_out[sort_ind] 175 | encoded_captions = encoded_captions[sort_ind] 176 | 177 | # Embedding 178 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim) 179 | 180 | # Initialize LSTM state 181 | h, c = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim) 182 | 183 | # We won't decode at the position, since we've finished generating as soon as we generate 184 | # So, decoding lengths are actual lengths - 1 185 | decode_lengths = (caption_lengths - 1).tolist() 186 | 187 | # Create tensors to hold word predicion scores and alphas 188 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device) 189 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device) 190 | 191 | # At each time-step, decode by 192 | # attention-weighing the encoder's output based on the decoder's previous hidden state output 193 | # then generate a new word in the decoder with the previous word and the attention weighted encoding 194 | for t in range(max(decode_lengths)): 195 | batch_size_t = sum([l > t for l in decode_lengths]) 196 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 197 | h[:batch_size_t]) 198 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim) 199 | attention_weighted_encoding = gate * attention_weighted_encoding 200 | h, c = self.decode_step( 201 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1), 202 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim) 203 | preds = self.fc(self.dropout(h)) # (batch_size_t, vocab_size) 204 | predictions[:batch_size_t, t, :] = preds 205 | alphas[:batch_size_t, t, :] = alpha 206 | 207 | return predictions, encoded_captions, decode_lengths, alphas, sort_ind 208 | -------------------------------------------------------------------------------- /pre_process.py: -------------------------------------------------------------------------------- 1 | import json 2 | import zipfile 3 | from collections import Counter 4 | 5 | import jieba 6 | from tqdm import tqdm 7 | 8 | from config import * 9 | from utils import ensure_folder 10 | 11 | 12 | def extract(folder): 13 | filename = '{}.zip'.format(folder) 14 | print('Extracting {}...'.format(filename)) 15 | with zipfile.ZipFile(filename, 'r') as zip_ref: 16 | zip_ref.extractall('data') 17 | 18 | 19 | def create_input_files(): 20 | json_path = train_annotations_filename 21 | 22 | # Read JSON 23 | with open(json_path, 'r') as j: 24 | samples = json.load(j) 25 | 26 | # Read image paths and captions for each image 27 | word_freq = 
Counter() 28 | 29 | for sample in tqdm(samples): 30 | caption = sample['caption'] 31 | for c in caption: 32 | seg_list = jieba.cut(c, cut_all=True) 33 | # Update word frequency 34 | word_freq.update(seg_list) 35 | 36 | # Create word map 37 | words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq] 38 | word_map = {k: v + 1 for v, k in enumerate(words)} 39 | word_map['<unk>'] = len(word_map) + 1 40 | word_map['<start>'] = len(word_map) + 1 41 | word_map['<end>'] = len(word_map) + 1 42 | word_map['<pad>'] = 0 43 | 44 | print(len(word_map)) 45 | print(words[:10]) 46 | 47 | # Save word map to a JSON 48 | with open(os.path.join(data_folder, 'WORDMAP.json'), 'w') as j: 49 | json.dump(word_map, j) 50 | 51 | 52 | if __name__ == '__main__': 53 | # parameters 54 | ensure_folder('data') 55 | 56 | if not os.path.isdir(train_image_folder): 57 | extract(train_folder) 58 | 59 | if not os.path.isdir(valid_image_folder): 60 | extract(valid_folder) 61 | 62 | if not os.path.isdir(test_a_image_folder): 63 | extract(test_a_folder) 64 | 65 | if not os.path.isdir(test_b_image_folder): 66 | extract(test_b_folder) 67 | 68 | create_input_files() 69 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jieba 2 | gensim 3 | tqdm -------------------------------------------------------------------------------- /sponsor.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/sponsor.jpg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | 4 | import torch.optim 5 | import torch.utils.data 6 | import torchvision.transforms as transforms 7 | from nltk.translate.bleu_score import corpus_bleu 8 | from torch import nn 9 | from torch.nn.utils.rnn import pack_padded_sequence 10 | 11 | from config import * 12 | from data_generator import CaptionDataset 13 | from models import Encoder, DecoderWithAttention 14 | from utils import * 15 | 16 | 17 | def main(): 18 | """ 19 | Training and validation.
20 | """ 21 | 22 | global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, word_map 23 | 24 | # Read word map 25 | word_map_file = os.path.join(data_folder, 'WORDMAP.json') 26 | with open(word_map_file, 'r') as j: 27 | word_map = json.load(j) 28 | 29 | # Initialize / load checkpoint 30 | if checkpoint is None: 31 | decoder = DecoderWithAttention(attention_dim=attention_dim, 32 | embed_dim=emb_dim, 33 | decoder_dim=decoder_dim, 34 | vocab_size=len(word_map), 35 | dropout=dropout) 36 | # decoder = nn.DataParallel(decoder) 37 | # decoder = torch.nn.DataParallel(decoder.cuda(), device_ids=[0, 1, 2, 3]) 38 | decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()), 39 | lr=decoder_lr) 40 | encoder = Encoder() 41 | encoder.fine_tune(fine_tune_encoder) 42 | # encoder = nn.DataParallel(encoder) 43 | # encoder = torch.nn.DataParallel(encoder.cuda(), device_ids=[0, 1, 2, 3]) 44 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()), 45 | lr=encoder_lr) if fine_tune_encoder else None 46 | 47 | else: 48 | checkpoint = torch.load(checkpoint) 49 | start_epoch = checkpoint['epoch'] + 1 50 | epochs_since_improvement = checkpoint['epochs_since_improvement'] 51 | best_bleu4 = checkpoint['bleu-4'] 52 | decoder = checkpoint['decoder'] 53 | decoder_optimizer = checkpoint['decoder_optimizer'] 54 | encoder = checkpoint['encoder'] 55 | encoder_optimizer = checkpoint['encoder_optimizer'] 56 | if fine_tune_encoder is True and encoder_optimizer is None: 57 | encoder.fine_tune(fine_tune_encoder) 58 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()), 59 | lr=encoder_lr) 60 | 61 | # Move to GPU, if available 62 | decoder = decoder.to(device) 63 | encoder = encoder.to(device) 64 | 65 | # Loss function 66 | criterion = nn.CrossEntropyLoss().to(device) 67 | 68 | # Custom dataloaders 69 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 70 | std=[0.229, 0.224, 0.225]) 71 | train_loader = torch.utils.data.DataLoader( 72 | CaptionDataset('train', transform=transforms.Compose([normalize])), 73 | batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True) 74 | val_loader = torch.utils.data.DataLoader( 75 | CaptionDataset('valid', transform=transforms.Compose([normalize])), 76 | batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True) 77 | 78 | # Epochs 79 | for epoch in range(start_epoch, epochs): 80 | 81 | # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20 82 | if epochs_since_improvement == 20: 83 | break 84 | if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0: 85 | adjust_learning_rate(decoder_optimizer, 0.8) 86 | if fine_tune_encoder: 87 | adjust_learning_rate(encoder_optimizer, 0.8) 88 | 89 | # One epoch's training 90 | train(train_loader=train_loader, 91 | encoder=encoder, 92 | decoder=decoder, 93 | criterion=criterion, 94 | encoder_optimizer=encoder_optimizer, 95 | decoder_optimizer=decoder_optimizer, 96 | epoch=epoch) 97 | 98 | # One epoch's validation 99 | recent_bleu4 = validate(val_loader=val_loader, 100 | encoder=encoder, 101 | decoder=decoder, 102 | criterion=criterion) 103 | 104 | # Check if there was an improvement 105 | is_best = recent_bleu4 > best_bleu4 106 | best_bleu4 = max(recent_bleu4, best_bleu4) 107 | if not is_best: 108 | epochs_since_improvement += 1 109 | print("\nEpochs since last improvement: %d\n" % 
(epochs_since_improvement,)) 110 | else: 111 | epochs_since_improvement = 0 112 | 113 | # Save checkpoint 114 | save_checkpoint(epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, 115 | decoder_optimizer, recent_bleu4, is_best) 116 | 117 | 118 | def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch): 119 | """ 120 | Performs one epoch's training. 121 | :param train_loader: DataLoader for training data 122 | :param encoder: encoder model 123 | :param decoder: decoder model 124 | :param criterion: loss layer 125 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning) 126 | :param decoder_optimizer: optimizer to update decoder's weights 127 | :param epoch: epoch number 128 | """ 129 | 130 | decoder.train() # train mode (dropout and batchnorm are used) 131 | encoder.train() 132 | 133 | batch_time = AverageMeter() # forward prop. + back prop. time 134 | data_time = AverageMeter() # data loading time 135 | losses = AverageMeter() # loss (per word decoded) 136 | top5accs = AverageMeter() # top5 accuracy 137 | 138 | start = time.time() 139 | 140 | # Batches 141 | for i, (imgs, caps, caplens) in enumerate(train_loader): 142 | data_time.update(time.time() - start) 143 | 144 | # Move to GPU, if available 145 | imgs = imgs.to(device) 146 | caps = caps.to(device) 147 | caplens = caplens.to(device) 148 | 149 | # Forward prop. 150 | imgs = encoder(imgs) 151 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens) 152 | 153 | # Since we decoded starting with <start>, the targets are all words after <start>, up to <end> 154 | targets = caps_sorted[:, 1:] 155 | 156 | # Remove timesteps that we didn't decode at, or are pads 157 | # pack_padded_sequence is an easy trick to do this 158 | scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True) 159 | targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True) 160 | 161 | # Calculate loss 162 | loss = criterion(scores, targets) 163 | 164 | # Add doubly stochastic attention regularization 165 | loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean() 166 | 167 | # Back prop. 168 | decoder_optimizer.zero_grad() 169 | if encoder_optimizer is not None: 170 | encoder_optimizer.zero_grad() 171 | loss.backward() 172 | 173 | # Clip gradients 174 | if grad_clip is not None: 175 | clip_gradient(decoder_optimizer, grad_clip) 176 | if encoder_optimizer is not None: 177 | clip_gradient(encoder_optimizer, grad_clip) 178 | 179 | # Update weights 180 | decoder_optimizer.step() 181 | if encoder_optimizer is not None: 182 | encoder_optimizer.step() 183 | 184 | # Keep track of metrics 185 | top5 = accuracy(scores, targets, 5) 186 | losses.update(loss.item(), sum(decode_lengths)) 187 | top5accs.update(top5, sum(decode_lengths)) 188 | batch_time.update(time.time() - start) 189 | 190 | start = time.time() 191 | 192 | # Print status 193 | if i % print_freq == 0: 194 | print('Epoch: [{0}][{1}/{2}]\t' 195 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 196 | 'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t' 197 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 198 | 'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader), 199 | batch_time=batch_time, 200 | data_time=data_time, loss=losses, 201 | top5=top5accs)) 202 | 203 | 204 | def validate(val_loader, encoder, decoder, criterion): 205 | """ 206 | Performs one epoch's validation. 207 | :param val_loader: DataLoader for validation data. 
208 | :param encoder: encoder model 209 | :param decoder: decoder model 210 | :param criterion: loss layer 211 | :return: BLEU-4 score 212 | """ 213 | decoder.eval() # eval mode (no dropout or batchnorm) 214 | if encoder is not None: 215 | encoder.eval() 216 | 217 | batch_time = AverageMeter() 218 | losses = AverageMeter() 219 | top5accs = AverageMeter() 220 | 221 | start = time.time() 222 | 223 | references = list() # references (true captions) for calculating BLEU-4 score 224 | hypotheses = list() # hypotheses (predictions) 225 | 226 | # Batches 227 | for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader): 228 | 229 | # Move to device, if available 230 | imgs = imgs.to(device) 231 | caps = caps.to(device) 232 | caplens = caplens.to(device) 233 | 234 | # Forward prop. 235 | if encoder is not None: 236 | imgs = encoder(imgs) 237 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens) 238 | 239 | # Since we decoded starting with <start>, the targets are all words after <start>, up to <end> 240 | targets = caps_sorted[:, 1:] 241 | 242 | # Remove timesteps that we didn't decode at, or are pads 243 | # pack_padded_sequence is an easy trick to do this 244 | scores_copy = scores.clone() 245 | scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True) 246 | targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True) 247 | 248 | # Calculate loss 249 | loss = criterion(scores, targets) 250 | 251 | # Add doubly stochastic attention regularization 252 | loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean() 253 | 254 | # Keep track of metrics 255 | losses.update(loss.item(), sum(decode_lengths)) 256 | top5 = accuracy(scores, targets, 5) 257 | top5accs.update(top5, sum(decode_lengths)) 258 | batch_time.update(time.time() - start) 259 | 260 | start = time.time() 261 | 262 | if i % print_freq == 0: 263 | print('Validation: [{0}/{1}]\t' 264 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 265 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 266 | 'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time, 267 | loss=losses, top5=top5accs)) 268 | 269 | # Store references (true captions), and hypothesis (prediction) for each image 270 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need - 271 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...] 
272 | 273 | # References 274 | allcaps = allcaps[sort_ind] # because images were sorted in the decoder 275 | for j in range(allcaps.shape[0]): 276 | img_caps = allcaps[j].tolist() 277 | img_captions = list( 278 | map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}], 279 | img_caps)) # remove <start> and pads 280 | references.append(img_captions) 281 | 282 | # Hypotheses 283 | _, preds = torch.max(scores_copy, dim=2) 284 | preds = preds.tolist() 285 | temp_preds = list() 286 | for j, p in enumerate(preds): 287 | temp_preds.append(preds[j][:decode_lengths[j]]) # remove pads 288 | preds = temp_preds 289 | hypotheses.extend(preds) 290 | 291 | assert len(references) == len(hypotheses) 292 | 293 | # Calculate BLEU-4 scores 294 | bleu4 = corpus_bleu(references, hypotheses) 295 | 296 | print( 297 | '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format( 298 | loss=losses, 299 | top5=top5accs, 300 | bleu=bleu4)) 301 | 302 | return bleu4 303 | 304 | 305 | if __name__ == '__main__': 306 | main() 307 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | 4 | import cv2 as cv 5 | import torch 6 | 7 | 8 | def ensure_folder(folder): 9 | if not os.path.exists(folder): 10 | os.makedirs(folder) 11 | 12 | 13 | # getting the number of CPUs 14 | def get_available_cpus(): 15 | return multiprocessing.cpu_count() 16 | 17 | 18 | def draw_str(dst, target, s): 19 | x, y = target 20 | cv.putText(dst, s, (x + 1, y + 1), cv.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 0), thickness=2, lineType=cv.LINE_AA) 21 | cv.putText(dst, s, (x, y), cv.FONT_HERSHEY_PLAIN, 1.0, (255, 255, 255), lineType=cv.LINE_AA) 22 | 23 | 24 | def clip_gradient(optimizer, grad_clip): 25 | """ 26 | Clips gradients computed during backpropagation to avoid explosion of gradients. 27 | :param optimizer: optimizer with the gradients to be clipped 28 | :param grad_clip: clip value 29 | """ 30 | for group in optimizer.param_groups: 31 | for param in group['params']: 32 | if param.grad is not None: 33 | param.grad.data.clamp_(-grad_clip, grad_clip) 34 | 35 | 36 | def save_checkpoint(epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, decoder_optimizer, 37 | bleu4, is_best): 38 | """ 39 | Saves model checkpoint. 40 | 41 | :param epoch: epoch number 42 | :param epochs_since_improvement: number of epochs since last improvement in BLEU-4 score 43 | :param encoder: encoder model 44 | :param decoder: decoder model 45 | :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning 46 | :param decoder_optimizer: optimizer to update decoder's weights 47 | :param bleu4: validation BLEU-4 score for this epoch 48 | :param is_best: is this checkpoint the best so far? 49 | """ 50 | state = {'epoch': epoch, 51 | 'epochs_since_improvement': epochs_since_improvement, 52 | 'bleu-4': bleu4, 53 | 'encoder': encoder, 54 | 'decoder': decoder, 55 | 'encoder_optimizer': encoder_optimizer, 56 | 'decoder_optimizer': decoder_optimizer} 57 | filename = 'checkpoint_.pth.tar' 58 | torch.save(state, filename) 59 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint 60 | if is_best: 61 | torch.save(state, 'BEST_' + filename) 62 | 63 | 64 | class AverageMeter(object): 65 | """ 66 | Keeps track of most recent, average, sum, and count of a metric. 
67 | """ 68 | 69 | def __init__(self): 70 | self.reset() 71 | 72 | def reset(self): 73 | self.val = 0 74 | self.avg = 0 75 | self.sum = 0 76 | self.count = 0 77 | 78 | def update(self, val, n=1): 79 | self.val = val 80 | self.sum += val * n 81 | self.count += n 82 | self.avg = self.sum / self.count 83 | 84 | 85 | def adjust_learning_rate(optimizer, shrink_factor): 86 | """ 87 | Shrinks learning rate by a specified factor. 88 | :param optimizer: optimizer whose learning rate must be shrunk. 89 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. 90 | """ 91 | 92 | print("\nDECAYING learning rate.") 93 | for param_group in optimizer.param_groups: 94 | param_group['lr'] = param_group['lr'] * shrink_factor 95 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 96 | 97 | 98 | def accuracy(scores, targets, k): 99 | """ 100 | Computes top-k accuracy, from predicted and true labels. 101 | :param scores: scores from the model 102 | :param targets: true labels 103 | :param k: k in top-k accuracy 104 | :return: top-k accuracy 105 | """ 106 | 107 | batch_size = targets.size(0) 108 | _, ind = scores.topk(k, 1, True, True) 109 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 110 | correct_total = correct.view(-1).float().sum() # 0D tensor 111 | return correct_total.item() * (100.0 / batch_size) 112 | --------------------------------------------------------------------------------