├── .gitignore
├── LICENSE
├── README.md
├── analyze_data.py
├── config.py
├── data_generator.py
├── demo.py
├── eval.py
├── font
│   ├── WenQuanYiMicroHei-01.ttf
│   ├── WenQuanYiMicroHeiMono-02.ttf
│   └── simhei.ttf
├── images
│   ├── dataset.png
│   ├── image_0.jpg
│   ├── image_1.jpg
│   ├── image_2.jpg
│   ├── image_3.jpg
│   ├── image_4.jpg
│   ├── image_5.jpg
│   ├── image_6.jpg
│   ├── image_7.jpg
│   ├── image_8.jpg
│   ├── image_9.jpg
│   ├── net.png
│   ├── out_0.jpg
│   ├── out_1.jpg
│   ├── out_2.jpg
│   ├── out_3.jpg
│   ├── out_4.jpg
│   ├── out_5.jpg
│   ├── out_6.jpg
│   ├── out_7.jpg
│   ├── out_8.jpg
│   └── out_9.jpg
├── models.py
├── pre_process.py
├── requirements.txt
├── sponsor.jpg
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__/
3 | logs/
4 | models/
5 | data/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Chinese Image Captioning
2 |
3 | A PyTorch implementation of Chinese image captioning with visual attention.
4 |
5 | [Show, Attend, and Tell](https://arxiv.org/pdf/1502.03044.pdf) is an impressive piece of work; the authors' original implementation is [here](https://github.com/kelvinxu/arctic-captions).
6 |
7 | The model learns "where to look": as it generates the caption word by word, its gaze shifts across the image to focus on the part most relevant to the next word.
8 |
9 | ## Dependencies
10 | - Python 3.5
11 | - PyTorch 0.4
12 |
13 | ## Dataset
14 |
15 | This project uses the AI Challenger 2017 Chinese image-captioning dataset: 300,000 images with 1,500,000 Chinese captions. Training set: 210,000 images; validation set: 30,000 images; test set A: 30,000 images; test set B: 30,000 images.
16 |
17 |
18 | ![dataset](images/dataset.png)
19 |
20 | Download it here: [Chinese image-captioning dataset](https://challenger.ai/datasets/), and place it under the data directory.
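
Each annotation entry pairs an `image_id` with a list of five Chinese captions, which is the structure `analyze_data.py` and `data_generator.py` rely on. A minimal sketch for inspecting one sample, assuming the training zip has already been extracted under `data/`:

```python
import json

# Annotation path as defined in config.py
from config import train_annotations_filename

with open(train_annotations_filename, 'r') as f:
    samples = json.load(f)

print(len(samples))            # 210,000 entries, one per training image
print(samples[0]['image_id'])  # file name inside the training image folder
print(samples[0]['caption'])   # list of 5 Chinese captions for this image
```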
21 |
22 |
23 | ## Network Architecture
24 |
25 | ![net](images/net.png)
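
The encoder is a pretrained ResNet-101 (final pooling and classification layers removed) that produces a 14×14×2048 feature map; the decoder is an LSTM with soft attention over those features. A minimal construction sketch using the dimensions from `config.py` (the vocabulary size below is a placeholder; `train.py` uses `len(word_map)` from `data/WORDMAP.json`):

```python
import torch

from config import attention_dim, emb_dim, decoder_dim, dropout, device
from models import Encoder, DecoderWithAttention

vocab_size = 10000  # placeholder; train.py uses len(word_map)

encoder = Encoder().to(device)  # ResNet-101 backbone + adaptive pooling
decoder = DecoderWithAttention(attention_dim=attention_dim,
                               embed_dim=emb_dim,
                               decoder_dim=decoder_dim,
                               vocab_size=vocab_size,
                               dropout=dropout).to(device)

images = torch.randn(2, 3, 256, 256).to(device)
print(encoder(images).shape)  # torch.Size([2, 14, 14, 2048])
```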
26 |
27 | ## Usage
28 |
29 | ### Data Pre-processing
30 | Extract the 210,000 training images and 30,000 validation images:
31 | ```bash
32 | $ python pre_process.py
33 | ```
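
Besides extracting the images, `pre_process.py` builds the vocabulary (words occurring more than `min_word_freq` times, plus the `<start>`, `<end>`, `<unk>` and `<pad>` tokens) and saves it to `data/WORDMAP.json`. A quick sanity check once pre-processing has finished:

```python
import json

with open('data/WORDMAP.json', 'r') as j:
    word_map = json.load(j)

print(len(word_map))      # vocabulary size, including special tokens
print(word_map['<pad>'])  # 0 -- padding index
print(word_map['<start>'], word_map['<end>'], word_map['<unk>'])
```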
34 |
35 | ### Training
36 | ```bash
37 | $ python train.py
38 | ```
39 |
40 | To visualize the training process, run:
41 | ```bash
42 | $ tensorboard --logdir path_to_current_dir/logs
43 | ```
44 |
45 | ### Demo
46 | Download the [pre-trained model](https://github.com/foamliu/Image-Captioning-v2/releases/download/v1.0/BEST_checkpoint_.pth.tar) and place it under the models directory, then run:
47 |
48 | ```bash
49 | $ python demo.py
50 | ```
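
`demo.py` captions `images/image_0.jpg` … `images/image_9.jpg` and writes the attention visualizations to `images/out_*.jpg`. To caption a single image programmatically, a sketch that follows the same steps as `demo.py` (the checkpoint path is `demo.py`'s default; adjust it if you keep the model under `models/`):

```python
import json

import torch

from demo import caption_image_beam_search, visualize_att

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The checkpoint stores the trained encoder/decoder modules, as loaded in demo.py
checkpoint = torch.load('BEST_checkpoint_.pth.tar', map_location=device)
encoder = checkpoint['encoder'].to(device).eval()
decoder = checkpoint['decoder'].to(device).eval()

# word2ix and its inverse ix2word
with open('data/WORDMAP.json', 'r') as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}

seq, alphas = caption_image_beam_search(encoder, decoder, 'images/image_0.jpg', word_map, beam_size=5)
print([rev_word_map[ind] for ind in seq])  # generated caption, token by token
visualize_att('images/image_0.jpg', seq, torch.FloatTensor(alphas), rev_word_map, 0)
```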
51 |
52 | |Image|Attention|
53 | |---|---|
54 | |![image](images/image_0.jpg)|![attention](images/out_0.jpg)|
55 | |![image](images/image_1.jpg)|![attention](images/out_1.jpg)|
56 | |![image](images/image_2.jpg)|![attention](images/out_2.jpg)|
57 | |![image](images/image_3.jpg)|![attention](images/out_3.jpg)|
58 | |![image](images/image_4.jpg)|![attention](images/out_4.jpg)|
59 | |![image](images/image_5.jpg)|![attention](images/out_5.jpg)|
60 | |![image](images/image_6.jpg)|![attention](images/out_6.jpg)|
61 | |![image](images/image_7.jpg)|![attention](images/out_7.jpg)|
62 | |![image](images/image_8.jpg)|![attention](images/out_8.jpg)|
63 | |![image](images/image_9.jpg)|![attention](images/out_9.jpg)|
64 |
65 | ## Sponsor
66 |
67 | ![sponsor](sponsor.jpg)
68 |
69 | If this project helps you, a small donation would be much appreciated~
70 |
71 |
72 |
73 |
74 |
--------------------------------------------------------------------------------
/analyze_data.py:
--------------------------------------------------------------------------------
1 | # encoding=utf-8
2 | import json
3 | import os
4 |
5 | import jieba
6 | from tqdm import tqdm
7 |
8 | from config import train_annotations_filename
9 |
10 | if __name__ == '__main__':
11 | print('Calculating the maximum length among all train captions')
12 |     annotations_path = train_annotations_filename
13 |
14 | with open(annotations_path, 'r') as f:
15 | samples = json.load(f)
16 |
17 | max_len = 0
18 | for sample in tqdm(samples):
19 | caption = sample['caption']
20 | for c in caption:
21 | seg_list = jieba.cut(c, cut_all=True)
22 | length = sum(1 for item in seg_list)
23 | if length > max_len:
24 | max_len = length
25 | print('max_len: ' + str(max_len))
26 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.backends.cudnn as cudnn
5 |
6 | image_h = image_w = image_size = 256
7 | channel = 3
9 | patience = 10
10 | num_train_samples = 1050000
11 | num_valid_samples = 150000
12 | max_len = 40
13 | captions_per_image = 5
14 |
15 | # Model parameters
16 | emb_dim = 512 # dimension of word embeddings
17 | attention_dim = 512 # dimension of attention linear layers
18 | decoder_dim = 512 # dimension of decoder RNN
19 | dropout = 0.5
20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # sets device for model and PyTorch tensors
21 | cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise lot of computational overhead
22 |
23 | # Training parameters
24 | start_epoch = 0
25 | epochs = 120 # number of epochs to train for (if early stopping is not triggered)
26 | epochs_since_improvement = 0 # keeps track of number of epochs since there's been an improvement in validation BLEU
27 | batch_size = 128
28 | workers = 2  # number of worker processes for data loading
29 | encoder_lr = 1e-4 # learning rate for encoder if fine-tuning
30 | decoder_lr = 4e-4 # learning rate for decoder
31 | grad_clip = 5. # clip gradients at an absolute value of
32 | alpha_c = 1. # regularization parameter for 'doubly stochastic attention', as in the paper
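# (the doubly stochastic term is typically added to the loss as
#  alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean(),
#  encouraging the attention weights at each pixel to sum to ~1 over time)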
33 | best_bleu4 = 0. # BLEU-4 score right now
34 | print_freq = 100 # print training/validation stats every __ batches
35 | fine_tune_encoder = False # fine-tune encoder?
36 | checkpoint = None # path to checkpoint, None if none
37 | min_word_freq = 3
38 |
39 | # Data parameters
40 | data_folder = 'data'
41 | train_folder = 'data/ai_challenger_caption_train_20170902'
42 | valid_folder = 'data/ai_challenger_caption_validation_20170910'
43 | test_a_folder = 'data/ai_challenger_caption_test_a_20180103'
44 | test_b_folder = 'data/ai_challenger_caption_test_b_20180103'
45 | train_image_folder = os.path.join(train_folder, 'caption_train_images_20170902')
46 | valid_image_folder = os.path.join(valid_folder, 'caption_validation_images_20170910')
47 | test_a_image_folder = os.path.join(test_a_folder, 'caption_test_a_images_20180103')
48 | test_b_image_folder = os.path.join(test_b_folder, 'caption_test_b_images_20180103')
49 | train_annotations_filename = os.path.join(train_folder, 'caption_train_annotations_20170902.json')
50 | valid_annotations_filename = os.path.join(valid_folder, 'caption_validation_annotations_20170910.json')
51 | test_a_annotations_filename = os.path.join(test_a_folder, 'caption_test_a_annotations_20180103.json')
52 | test_b_annotations_filename = os.path.join(test_b_folder, 'caption_test_b_annotations_20180103.json')
53 |
--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
1 | # encoding=utf-8
2 | import json
3 |
4 | import jieba
5 | import numpy as np
6 | from scipy.misc import imread, imresize
7 | from torch.utils.data import Dataset
8 |
9 | from config import *
10 |
11 |
12 | def encode_caption(word_map, c):
13 |     return [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in c] + [
14 |         word_map['<end>']] + [word_map['<pad>']] * (max_len - len(c))
15 |
16 |
17 | class CaptionDataset(Dataset):
18 | """
19 | A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.
20 | """
21 |
22 | def __init__(self, split, transform=None):
23 | """
24 |         :param split: split, one of 'train' or 'valid'
25 | :param transform: image transform pipeline
26 | """
27 | self.split = split
28 | assert self.split in {'train', 'valid'}
29 |
30 | if split == 'train':
31 | json_path = train_annotations_filename
32 | self.image_folder = train_image_folder
33 | else:
34 | json_path = valid_annotations_filename
35 | self.image_folder = valid_image_folder
36 |
37 | # Read JSON
38 | with open(json_path, 'r') as j:
39 | self.samples = json.load(j)
40 |
41 | # Read wordmap
42 | with open(os.path.join(data_folder, 'WORDMAP.json'), 'r') as j:
43 | self.word_map = json.load(j)
44 |
45 | # PyTorch transformation pipeline for the image (normalizing, etc.)
46 | self.transform = transform
47 |
48 | # Total number of datapoints
49 |         self.dataset_size = len(self.samples) * captions_per_image
50 |
51 | def __getitem__(self, i):
52 | # Remember, the Nth caption corresponds to the (N // captions_per_image)th image
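        # e.g. with captions_per_image = 5, items 0..4 all load image 0 and pick its captions 0..4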
53 | sample = self.samples[i // captions_per_image]
54 | path = os.path.join(self.image_folder, sample['image_id'])
55 | # Read images
56 | img = imread(path)
57 | if len(img.shape) == 2:
58 | img = img[:, :, np.newaxis]
59 | img = np.concatenate([img, img, img], axis=2)
60 | img = imresize(img, (256, 256))
61 | img = img.transpose(2, 0, 1)
62 | assert img.shape == (3, 256, 256)
63 | assert np.max(img) <= 255
64 | img = torch.FloatTensor(img / 255.)
65 | if self.transform is not None:
66 | img = self.transform(img)
67 |
68 | # Sample captions
69 | captions = sample['caption']
70 | # Sanity check
71 | assert len(captions) == captions_per_image
72 | c = captions[i % captions_per_image]
73 | c = list(jieba.cut(c))
74 | # Encode captions
75 | enc_c = encode_caption(self.word_map, c)
76 |
77 | caption = torch.LongTensor(enc_c)
78 |
79 | caplen = torch.LongTensor([len(c) + 2])
80 |
81 |         if self.split == 'train':
82 |             return img, caption, caplen
83 |         else:
84 |             # For validation or testing, also return all 'captions_per_image' captions to find BLEU-4 score
85 | all_captions = torch.LongTensor([encode_caption(self.word_map, list(jieba.cut(c))) for c in captions])
86 | return img, caption, caplen, all_captions
87 |
88 | def __len__(self):
89 | return self.dataset_size
90 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding:utf-8
3 | """Demo: caption images with beam search and visualize the attention maps."""
4 | import argparse
5 | import json
6 |
7 | import matplotlib as mpl
8 | import matplotlib.cm as cm
9 | import matplotlib.pyplot as plt
10 |
11 | # Use a font with Chinese glyphs so caption words render correctly
12 | mpl.rcParams['font.sans-serif'] = ['SimHei']
13 | mpl.rcParams['axes.unicode_minus'] = False
14 |
15 |
16 | import numpy as np
17 | import skimage.transform
18 | import torch
19 | import torch.nn.functional as F
20 | import torchvision.transforms as transforms
21 | from PIL import Image
22 | from scipy.misc import imread, imresize
23 |
24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25 |
26 |
27 | def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_size=3):
28 | """
29 | Reads an image and captions it with beam search.
30 | :param encoder: encoder model
31 | :param decoder: decoder model
32 | :param image_path: path to image
33 | :param word_map: word map
34 | :param beam_size: number of sequences to consider at each decode-step
35 | :return: caption, weights for visualization
36 | """
37 |
38 | k = beam_size
39 | vocab_size = len(word_map)
40 |
41 | # Read image and process
42 | img = imread(image_path)
43 | if len(img.shape) == 2:
44 | img = img[:, :, np.newaxis]
45 | img = np.concatenate([img, img, img], axis=2)
46 | img = imresize(img, (256, 256))
47 | img = img.transpose(2, 0, 1)
48 | img = img / 255.
49 | img = torch.FloatTensor(img).to(device)
50 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
51 | std=[0.229, 0.224, 0.225])
52 | transform = transforms.Compose([normalize])
53 | image = transform(img) # (3, 256, 256)
54 |
55 | # Encode
56 | image = image.unsqueeze(0) # (1, 3, 256, 256)
57 | encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim)
58 | enc_image_size = encoder_out.size(1)
59 | encoder_dim = encoder_out.size(3)
60 |
61 | # Flatten encoding
62 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim)
63 | num_pixels = encoder_out.size(1)
64 |
65 | # We'll treat the problem as having a batch size of k
66 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim)
67 |
68 |     # Tensor to store top k previous words at each step; now they're just <start>
69 |     k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)
70 |
71 |     # Tensor to store top k sequences; now they're just <start>
72 | seqs = k_prev_words # (k, 1)
73 |
74 | # Tensor to store top k sequences' scores; now they're just 0
75 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1)
76 |
77 | # Tensor to store top k sequences' alphas; now they're just 1s
78 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size)
79 |
80 | # Lists to store completed sequences, their alphas and scores
81 | complete_seqs = list()
82 | complete_seqs_alpha = list()
83 | complete_seqs_scores = list()
84 |
85 | # Start decoding
86 | step = 1
87 | h, c = decoder.init_hidden_state(encoder_out)
88 |
89 |     # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
90 | while True:
91 |
92 | embeddings = decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim)
93 |
94 | awe, alpha = decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels)
95 |
96 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size)
97 |
98 | gate = decoder.sigmoid(decoder.f_beta(h)) # gating scalar, (s, encoder_dim)
99 | awe = gate * awe
100 |
101 | h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim)
102 |
103 | scores = decoder.fc(h) # (s, vocab_size)
104 | scores = F.log_softmax(scores, dim=1)
105 |
106 | # Add
107 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)
108 |
109 | # For the first step, all k points will have the same scores (since same k previous words, h, c)
110 | if step == 1:
111 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
112 | else:
113 | # Unroll and find top scores, and their unrolled indices
114 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)
115 |
116 | # Convert unrolled indices to actual indices of scores
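        # NOTE: relies on integer (floor) division of LongTensors, as in PyTorch 0.4;
        # newer PyTorch versions would need // (or torch.div(..., rounding_mode='floor')) here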
117 | prev_word_inds = top_k_words / vocab_size # (s)
118 | next_word_inds = top_k_words % vocab_size # (s)
119 |
120 | # Add new words to sequences, alphas
121 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)
122 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
123 | dim=1) # (s, step+1, enc_image_size, enc_image_size)
124 |
125 |         # Which sequences are incomplete (didn't reach <end>)?
126 |         incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
127 |                            next_word != word_map['<end>']]
128 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))
129 |
130 | # Set aside complete sequences
131 | if len(complete_inds) > 0:
132 | complete_seqs.extend(seqs[complete_inds].tolist())
133 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
134 | complete_seqs_scores.extend(top_k_scores[complete_inds])
135 | k -= len(complete_inds) # reduce beam length accordingly
136 |
137 | # Proceed with incomplete sequences
138 | if k == 0:
139 | break
140 | seqs = seqs[incomplete_inds]
141 | seqs_alpha = seqs_alpha[incomplete_inds]
142 | h = h[prev_word_inds[incomplete_inds]]
143 | c = c[prev_word_inds[incomplete_inds]]
144 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
145 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
146 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
147 |
148 | # Break if things have been going on too long
149 | if step > 50:
150 | break
151 | step += 1
152 |
153 | i = complete_seqs_scores.index(max(complete_seqs_scores))
154 | seq = complete_seqs[i]
155 | alphas = complete_seqs_alpha[i]
156 |
157 | return seq, alphas
158 |
159 |
160 | def visualize_att(image_path, seq, alphas, rev_word_map, i, smooth=True):
161 | """
162 | Visualizes caption with weights at every word.
163 | Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb
164 | :param image_path: path to image that has been captioned
165 | :param seq: caption
166 | :param alphas: weights
167 | :param rev_word_map: reverse word mapping, i.e. ix2word
168 | :param smooth: smooth weights?
169 | """
170 | image = Image.open(image_path)
171 | image = image.resize([14 * 24, 14 * 24], Image.LANCZOS)
172 |
173 | words = [rev_word_map[ind] for ind in seq]
174 | print(words)
175 |
176 | for t in range(len(words)):
177 | if t > 50:
178 | break
179 |         plt.subplot(int(np.ceil(len(words) / 5.)), 5, t + 1)
180 |
181 | plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12)
182 | plt.imshow(image)
183 | current_alpha = alphas[t, :]
184 | if smooth:
185 | alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=24, sigma=8)
186 | else:
187 | alpha = skimage.transform.resize(current_alpha.numpy(), [14 * 24, 14 * 24])
188 | if t == 0:
189 | plt.imshow(alpha, alpha=0)
190 | else:
191 | plt.imshow(alpha, alpha=0.8)
192 | plt.set_cmap(cm.Greys_r)
193 | plt.axis('off')
194 |
195 | plt.savefig('images/out_{}.jpg'.format(i))
196 | plt.close()
197 |
198 |
199 | if __name__ == '__main__':
200 | parser = argparse.ArgumentParser(description='Show, Attend, and Tell - Tutorial - Generate Caption')
201 |
202 | parser.add_argument('--img', '-i', help='path to image')
203 | parser.add_argument('--model', '-m', default='BEST_checkpoint_.pth.tar', help='path to model')
204 | parser.add_argument('--word_map', '-wm', default='data/WORDMAP.json', help='path to word map JSON')
205 | parser.add_argument('--beam_size', '-b', default=5, type=int, help='beam size for beam search')
206 | parser.add_argument('--dont_smooth', dest='smooth', action='store_false', help='do not smooth alpha overlay')
207 |
208 | args = parser.parse_args()
209 |
210 | # Load model
211 | checkpoint = torch.load(args.model)
212 | decoder = checkpoint['decoder']
213 | decoder = decoder.to(device)
214 | decoder.eval()
215 | encoder = checkpoint['encoder']
216 | encoder = encoder.to(device)
217 | encoder.eval()
218 |
219 | # Load word map (word2ix)
220 | with open(args.word_map, 'r') as j:
221 | word_map = json.load(j)
222 | rev_word_map = {v: k for k, v in word_map.items()} # ix2word
223 |
224 | from config import *
225 | from tqdm import tqdm
226 |
227 | for i in tqdm(range(10)):
228 | img = 'images/image_{}.jpg'.format(i)
229 |
230 | # Encode, decode with attention and beam search
231 | seq, alphas = caption_image_beam_search(encoder, decoder, img, word_map, args.beam_size)
232 | alphas = torch.FloatTensor(alphas)
233 |
234 | # Visualize caption and attention of best sequence
235 | visualize_att(img, seq, alphas, rev_word_map, i, args.smooth)
236 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.optim
3 | import torch.utils.data
4 | import torchvision.transforms as transforms
5 | from nltk.translate.bleu_score import corpus_bleu
6 | from tqdm import tqdm
7 |
8 | from data_generator import *
9 | from utils import *
10 |
11 | # Parameters
12 | checkpoint = 'BEST_checkpoint_.pth.tar'  # model checkpoint (see README / demo.py)
13 | word_map_file = 'data/WORDMAP.json'  # word map, ensure it's the same the data was encoded with and the model was trained with
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors
17 | cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise lot of computational overhead
18 |
19 | # Load model
20 | checkpoint = torch.load(checkpoint)
21 | decoder = checkpoint['decoder']
22 | decoder = decoder.to(device)
23 | decoder.eval()
24 | encoder = checkpoint['encoder']
25 | encoder = encoder.to(device)
26 | encoder.eval()
27 |
28 | # Load word map (word2ix)
29 | with open(word_map_file, 'r') as j:
30 | word_map = json.load(j)
31 | rev_word_map = {v: k for k, v in word_map.items()}
32 | vocab_size = len(word_map)
33 |
34 | # Normalization transform
35 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
36 | std=[0.229, 0.224, 0.225])
37 |
38 |
39 | def evaluate(beam_size):
40 | """
41 | Evaluation
42 | :param beam_size: beam size at which to generate captions for evaluation
43 | :return: BLEU-4 score
44 | """
45 | # DataLoader
46 | loader = torch.utils.data.DataLoader(
47 |         CaptionDataset('valid', transform=transforms.Compose([normalize])),
48 | batch_size=1, shuffle=True, num_workers=1, pin_memory=True)
49 |
50 | # TODO: Batched Beam Search
51 |
52 | # Lists to store references (true captions), and hypothesis (prediction) for each image
53 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
54 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
55 | references = list()
56 | hypotheses = list()
57 |
58 | # For each image
59 | for i, (image, caps, caplens, allcaps) in enumerate(
60 | tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size))):
61 |
62 | k = beam_size
63 |
64 | # Move to GPU device, if available
65 | image = image.to(device) # (1, 3, 256, 256)
66 |
67 | # Encode
68 | encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim)
69 | enc_image_size = encoder_out.size(1)
70 | encoder_dim = encoder_out.size(3)
71 |
72 | # Flatten encoding
73 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim)
74 | num_pixels = encoder_out.size(1)
75 |
76 | # We'll treat the problem as having a batch size of k
77 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim)
78 |
79 |         # Tensor to store top k previous words at each step; now they're just <start>
80 |         k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)
81 |
82 |         # Tensor to store top k sequences; now they're just <start>
83 | seqs = k_prev_words # (k, 1)
84 |
85 | # Tensor to store top k sequences' scores; now they're just 0
86 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1)
87 |
88 | # Lists to store completed sequences and scores
89 | complete_seqs = list()
90 | complete_seqs_scores = list()
91 |
92 | # Start decoding
93 | step = 1
94 | h, c = decoder.init_hidden_state(encoder_out)
95 |
96 |         # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
97 | while True:
98 |
99 | embeddings = decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim)
100 |
101 | awe, _ = decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels)
102 |
103 | gate = decoder.sigmoid(decoder.f_beta(h)) # gating scalar, (s, encoder_dim)
104 | awe = gate * awe
105 |
106 | h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim)
107 |
108 | scores = decoder.fc(h) # (s, vocab_size)
109 | scores = F.log_softmax(scores, dim=1)
110 |
111 | # Add
112 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)
113 |
114 | # For the first step, all k points will have the same scores (since same k previous words, h, c)
115 | if step == 1:
116 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
117 | else:
118 | # Unroll and find top scores, and their unrolled indices
119 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)
120 |
121 | # Convert unrolled indices to actual indices of scores
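            # NOTE: relies on integer (floor) division of LongTensors, as in PyTorch 0.4;
            # newer PyTorch versions would need // (or torch.div(..., rounding_mode='floor')) here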
122 | prev_word_inds = top_k_words / vocab_size # (s)
123 | next_word_inds = top_k_words % vocab_size # (s)
124 |
125 | # Add new words to sequences
126 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)
127 |
128 |             # Which sequences are incomplete (didn't reach <end>)?
129 |             incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
130 |                                next_word != word_map['<end>']]
131 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))
132 |
133 | # Set aside complete sequences
134 | if len(complete_inds) > 0:
135 | complete_seqs.extend(seqs[complete_inds].tolist())
136 | complete_seqs_scores.extend(top_k_scores[complete_inds])
137 | k -= len(complete_inds) # reduce beam length accordingly
138 |
139 | # Proceed with incomplete sequences
140 | if k == 0:
141 | break
142 | seqs = seqs[incomplete_inds]
143 | h = h[prev_word_inds[incomplete_inds]]
144 | c = c[prev_word_inds[incomplete_inds]]
145 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
146 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
147 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
148 |
149 | # Break if things have been going on too long
150 | if step > 50:
151 | break
152 | step += 1
153 |
154 | i = complete_seqs_scores.index(max(complete_seqs_scores))
155 | seq = complete_seqs[i]
156 |
157 | # References
158 | img_caps = allcaps[0].tolist()
159 |         img_captions = list(
160 |             map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}],
161 |                 img_caps))  # remove <start>, <end> and pads
162 | references.append(img_captions)
163 |
164 | # Hypotheses
165 |         hypotheses.append([w for w in seq if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}])
166 |
167 | assert len(references) == len(hypotheses)
168 |
169 | # Calculate BLEU-4 scores
170 | bleu4 = corpus_bleu(references, hypotheses, emulate_multibleu=True)
171 |
172 | return bleu4
173 |
174 |
175 | if __name__ == '__main__':
176 | beam_size = 5
177 | print("\nBLEU-4 score @ beam size of %d is %.4f." % (beam_size, evaluate(beam_size)))
178 |
--------------------------------------------------------------------------------
/font/WenQuanYiMicroHei-01.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/font/WenQuanYiMicroHei-01.ttf
--------------------------------------------------------------------------------
/font/WenQuanYiMicroHeiMono-02.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/font/WenQuanYiMicroHeiMono-02.ttf
--------------------------------------------------------------------------------
/font/simhei.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/font/simhei.ttf
--------------------------------------------------------------------------------
/images/dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/dataset.png
--------------------------------------------------------------------------------
/images/image_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_0.jpg
--------------------------------------------------------------------------------
/images/image_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_1.jpg
--------------------------------------------------------------------------------
/images/image_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_2.jpg
--------------------------------------------------------------------------------
/images/image_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_3.jpg
--------------------------------------------------------------------------------
/images/image_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_4.jpg
--------------------------------------------------------------------------------
/images/image_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_5.jpg
--------------------------------------------------------------------------------
/images/image_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_6.jpg
--------------------------------------------------------------------------------
/images/image_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_7.jpg
--------------------------------------------------------------------------------
/images/image_8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_8.jpg
--------------------------------------------------------------------------------
/images/image_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/image_9.jpg
--------------------------------------------------------------------------------
/images/net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/net.png
--------------------------------------------------------------------------------
/images/out_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_0.jpg
--------------------------------------------------------------------------------
/images/out_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_1.jpg
--------------------------------------------------------------------------------
/images/out_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_2.jpg
--------------------------------------------------------------------------------
/images/out_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_3.jpg
--------------------------------------------------------------------------------
/images/out_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_4.jpg
--------------------------------------------------------------------------------
/images/out_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_5.jpg
--------------------------------------------------------------------------------
/images/out_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_6.jpg
--------------------------------------------------------------------------------
/images/out_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_7.jpg
--------------------------------------------------------------------------------
/images/out_8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_8.jpg
--------------------------------------------------------------------------------
/images/out_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/images/out_9.jpg
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch import nn
4 |
5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6 |
7 |
8 | class Encoder(nn.Module):
9 | """
10 | Encoder.
11 | """
12 |
13 | def __init__(self, encoded_image_size=14):
14 | super(Encoder, self).__init__()
15 | self.enc_image_size = encoded_image_size
16 |
17 | resnet = torchvision.models.resnet101(pretrained=True) # pretrained ImageNet ResNet-101
18 |
19 | # Remove linear and pool layers (since we're not doing classification)
20 | modules = list(resnet.children())[:-2]
21 | self.resnet = nn.Sequential(*modules)
22 |
23 | # Resize image to fixed size to allow input images of variable size
24 | self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
25 |
26 | self.fine_tune()
27 |
28 | def forward(self, images):
29 | """
30 | Forward propagation.
31 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
32 | :return: encoded images
33 | """
34 | out = self.resnet(images) # (batch_size, 2048, image_size/32, image_size/32)
35 | out = self.adaptive_pool(out) # (batch_size, 2048, encoded_image_size, encoded_image_size)
36 | out = out.permute(0, 2, 3, 1) # (batch_size, encoded_image_size, encoded_image_size, 2048)
37 | return out
38 |
39 | def fine_tune(self, fine_tune=True):
40 | """
41 | Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder.
42 | :param fine_tune: Allow?
43 | """
44 | for p in self.resnet.parameters():
45 | p.requires_grad = False
46 | # If fine-tuning, only fine-tune convolutional blocks 2 through 4
47 | for c in list(self.resnet.children())[5:]:
48 | for p in c.parameters():
49 | p.requires_grad = fine_tune
50 |
51 |
52 | class Attention(nn.Module):
53 | """
54 | Attention Network.
55 | """
56 |
57 | def __init__(self, encoder_dim, decoder_dim, attention_dim):
58 | """
59 | :param encoder_dim: feature size of encoded images
60 | :param decoder_dim: size of decoder's RNN
61 | :param attention_dim: size of the attention network
62 | """
63 | super(Attention, self).__init__()
64 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image
65 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output
66 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed
67 | self.relu = nn.ReLU()
68 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights
69 |
70 | def forward(self, encoder_out, decoder_hidden):
71 | """
72 | Forward propagation.
73 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
74 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
75 | :return: attention weighted encoding, weights
76 | """
77 | att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim)
78 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim)
79 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels)
80 | alpha = self.softmax(att) # (batch_size, num_pixels)
81 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_dim)
82 |
83 | return attention_weighted_encoding, alpha
84 |
85 |
86 | class DecoderWithAttention(nn.Module):
87 | """
88 | Decoder.
89 | """
90 |
91 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
92 | """
93 | :param attention_dim: size of attention network
94 | :param embed_dim: embedding size
95 | :param decoder_dim: size of decoder's RNN
96 | :param vocab_size: size of vocabulary
97 | :param encoder_dim: feature size of encoded images
98 | :param dropout: dropout
99 | """
100 | super(DecoderWithAttention, self).__init__()
101 |
102 | self.encoder_dim = encoder_dim
103 | self.attention_dim = attention_dim
104 | self.embed_dim = embed_dim
105 | self.decoder_dim = decoder_dim
106 | self.vocab_size = vocab_size
107 | self.dropout = dropout
108 |
109 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network
110 |
111 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer
112 | self.dropout = nn.Dropout(p=self.dropout)
113 | self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoding LSTMCell
114 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell
115 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell
116 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate
117 | self.sigmoid = nn.Sigmoid()
118 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary
119 | self.init_weights() # initialize some layers with the uniform distribution
120 |
121 | def init_weights(self):
122 | """
123 | Initializes some parameters with values from the uniform distribution, for easier convergence.
124 | """
125 | self.embedding.weight.data.uniform_(-0.1, 0.1)
126 | self.fc.bias.data.fill_(0)
127 | self.fc.weight.data.uniform_(-0.1, 0.1)
128 |
129 | def load_pretrained_embeddings(self, embeddings):
130 | """
131 | Loads embedding layer with pre-trained embeddings.
132 | :param embeddings: pre-trained embeddings
133 | """
134 | self.embedding.weight = nn.Parameter(embeddings)
135 |
136 | def fine_tune_embeddings(self, fine_tune=True):
137 | """
138 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).
139 | :param fine_tune: Allow?
140 | """
141 | for p in self.embedding.parameters():
142 | p.requires_grad = fine_tune
143 |
144 | def init_hidden_state(self, encoder_out):
145 | """
146 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
147 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
148 | :return: hidden state, cell state
149 | """
150 | mean_encoder_out = encoder_out.mean(dim=1)
151 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim)
152 | c = self.init_c(mean_encoder_out)
153 | return h, c
154 |
155 | def forward(self, encoder_out, encoded_captions, caption_lengths):
156 | """
157 | Forward propagation.
158 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
159 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
160 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
161 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
162 | """
163 |
164 | batch_size = encoder_out.size(0)
165 | encoder_dim = encoder_out.size(-1)
166 | vocab_size = self.vocab_size
167 |
168 | # Flatten image
169 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim)
170 | num_pixels = encoder_out.size(1)
171 |
172 | # Sort input data by decreasing lengths; why? apparent below
173 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
174 | encoder_out = encoder_out[sort_ind]
175 | encoded_captions = encoded_captions[sort_ind]
176 |
177 | # Embedding
178 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim)
179 |
180 | # Initialize LSTM state
181 | h, c = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim)
182 |
183 |         # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
184 | # So, decoding lengths are actual lengths - 1
185 | decode_lengths = (caption_lengths - 1).tolist()
186 |
187 |         # Create tensors to hold word prediction scores and alphas
188 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)
189 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device)
190 |
191 | # At each time-step, decode by
192 | # attention-weighing the encoder's output based on the decoder's previous hidden state output
193 | # then generate a new word in the decoder with the previous word and the attention weighted encoding
194 | for t in range(max(decode_lengths)):
195 | batch_size_t = sum([l > t for l in decode_lengths])
196 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
197 | h[:batch_size_t])
198 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim)
199 | attention_weighted_encoding = gate * attention_weighted_encoding
200 | h, c = self.decode_step(
201 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
202 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim)
203 | preds = self.fc(self.dropout(h)) # (batch_size_t, vocab_size)
204 | predictions[:batch_size_t, t, :] = preds
205 | alphas[:batch_size_t, t, :] = alpha
206 |
207 | return predictions, encoded_captions, decode_lengths, alphas, sort_ind
208 |
--------------------------------------------------------------------------------
/pre_process.py:
--------------------------------------------------------------------------------
1 | import json
2 | import zipfile
3 | from collections import Counter
4 |
5 | import jieba
6 | from tqdm import tqdm
7 |
8 | from config import *
9 | from utils import ensure_folder
10 |
11 |
12 | def extract(folder):
13 | filename = '{}.zip'.format(folder)
14 | print('Extracting {}...'.format(filename))
15 | with zipfile.ZipFile(filename, 'r') as zip_ref:
16 | zip_ref.extractall('data')
17 |
18 |
19 | def create_input_files():
20 | json_path = train_annotations_filename
21 |
22 | # Read JSON
23 | with open(json_path, 'r') as j:
24 | samples = json.load(j)
25 |
26 | # Read image paths and captions for each image
27 | word_freq = Counter()
28 |
29 | for sample in tqdm(samples):
30 | caption = sample['caption']
31 | for c in caption:
32 | seg_list = jieba.cut(c, cut_all=True)
33 | # Update word frequency
34 | word_freq.update(seg_list)
35 |
36 | # Create word map
37 | words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]
38 | word_map = {k: v + 1 for v, k in enumerate(words)}
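    # word indices start at 1; index 0 is reserved for the <pad> token assigned below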
39 |     word_map['<unk>'] = len(word_map) + 1
40 |     word_map['<start>'] = len(word_map) + 1
41 |     word_map['<end>'] = len(word_map) + 1
42 |     word_map['<pad>'] = 0
43 |
44 | print(len(word_map))
45 | print(words[:10])
46 |
47 | # Save word map to a JSON
48 | with open(os.path.join(data_folder, 'WORDMAP.json'), 'w') as j:
49 | json.dump(word_map, j)
50 |
51 |
52 | if __name__ == '__main__':
53 | # parameters
54 | ensure_folder('data')
55 |
56 | if not os.path.isdir(train_image_folder):
57 | extract(train_folder)
58 |
59 | if not os.path.isdir(valid_image_folder):
60 | extract(valid_folder)
61 |
62 | if not os.path.isdir(test_a_image_folder):
63 | extract(test_a_folder)
64 |
65 | if not os.path.isdir(test_b_image_folder):
66 | extract(test_b_folder)
67 |
68 | create_input_files()
69 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jieba
2 | gensim
3 | tqdm
4 | torch
5 | torchvision
6 | numpy
7 | scipy
8 | nltk
9 | scikit-image
10 | matplotlib
11 | Pillow
--------------------------------------------------------------------------------
/sponsor.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Image-Captioning-PyTorch/89276fa520e85fa25b603900a8f24a2d926b55bb/sponsor.jpg
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 |
4 | import torch.optim
5 | import torch.utils.data
6 | import torchvision.transforms as transforms
7 | from nltk.translate.bleu_score import corpus_bleu
8 | from torch import nn
9 | from torch.nn.utils.rnn import pack_padded_sequence
10 |
11 | from config import *
12 | from data_generator import CaptionDataset
13 | from models import Encoder, DecoderWithAttention
14 | from utils import *
15 |
16 |
17 | def main():
18 | """
19 | Training and validation.
20 | """
21 |
22 | global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, word_map
23 |
24 | # Read word map
25 | word_map_file = os.path.join(data_folder, 'WORDMAP.json')
26 | with open(word_map_file, 'r') as j:
27 | word_map = json.load(j)
28 |
29 | # Initialize / load checkpoint
30 | if checkpoint is None:
31 | decoder = DecoderWithAttention(attention_dim=attention_dim,
32 | embed_dim=emb_dim,
33 | decoder_dim=decoder_dim,
34 | vocab_size=len(word_map),
35 | dropout=dropout)
36 | # decoder = nn.DataParallel(decoder)
37 | # decoder = torch.nn.DataParallel(decoder.cuda(), device_ids=[0, 1, 2, 3])
38 | decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
39 | lr=decoder_lr)
40 | encoder = Encoder()
41 | encoder.fine_tune(fine_tune_encoder)
42 | # encoder = nn.DataParallel(encoder)
43 | # encoder = torch.nn.DataParallel(encoder.cuda(), device_ids=[0, 1, 2, 3])
44 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
45 | lr=encoder_lr) if fine_tune_encoder else None
46 |
47 | else:
48 | checkpoint = torch.load(checkpoint)
49 | start_epoch = checkpoint['epoch'] + 1
50 | epochs_since_improvement = checkpoint['epochs_since_improvement']
51 | best_bleu4 = checkpoint['bleu-4']
52 | decoder = checkpoint['decoder']
53 | decoder_optimizer = checkpoint['decoder_optimizer']
54 | encoder = checkpoint['encoder']
55 | encoder_optimizer = checkpoint['encoder_optimizer']
56 | if fine_tune_encoder is True and encoder_optimizer is None:
57 | encoder.fine_tune(fine_tune_encoder)
58 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
59 | lr=encoder_lr)
60 |
61 | # Move to GPU, if available
62 | decoder = decoder.to(device)
63 | encoder = encoder.to(device)
64 |
65 | # Loss function
66 | criterion = nn.CrossEntropyLoss().to(device)
67 |
68 | # Custom dataloaders
69 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
70 | std=[0.229, 0.224, 0.225])
71 | train_loader = torch.utils.data.DataLoader(
72 | CaptionDataset('train', transform=transforms.Compose([normalize])),
73 | batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
74 | val_loader = torch.utils.data.DataLoader(
75 | CaptionDataset('valid', transform=transforms.Compose([normalize])),
76 | batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)
77 |
78 | # Epochs
79 | for epoch in range(start_epoch, epochs):
80 |
81 | # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
82 | if epochs_since_improvement == 20:
83 | break
84 | if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
85 | adjust_learning_rate(decoder_optimizer, 0.8)
86 | if fine_tune_encoder:
87 | adjust_learning_rate(encoder_optimizer, 0.8)
88 |
89 | # One epoch's training
90 | train(train_loader=train_loader,
91 | encoder=encoder,
92 | decoder=decoder,
93 | criterion=criterion,
94 | encoder_optimizer=encoder_optimizer,
95 | decoder_optimizer=decoder_optimizer,
96 | epoch=epoch)
97 |
98 | # One epoch's validation
99 | recent_bleu4 = validate(val_loader=val_loader,
100 | encoder=encoder,
101 | decoder=decoder,
102 | criterion=criterion)
103 |
104 | # Check if there was an improvement
105 | is_best = recent_bleu4 > best_bleu4
106 | best_bleu4 = max(recent_bleu4, best_bleu4)
107 | if not is_best:
108 | epochs_since_improvement += 1
109 | print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
110 | else:
111 | epochs_since_improvement = 0
112 |
113 | # Save checkpoint
114 | save_checkpoint(epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
115 | decoder_optimizer, recent_bleu4, is_best)
116 |
117 |
118 | def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
119 | """
120 | Performs one epoch's training.
121 | :param train_loader: DataLoader for training data
122 | :param encoder: encoder model
123 | :param decoder: decoder model
124 | :param criterion: loss layer
125 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning)
126 | :param decoder_optimizer: optimizer to update decoder's weights
127 | :param epoch: epoch number
128 | """
129 |
130 | decoder.train() # train mode (dropout and batchnorm are used)
131 | encoder.train()
132 |
133 | batch_time = AverageMeter() # forward prop. + back prop. time
134 | data_time = AverageMeter() # data loading time
135 | losses = AverageMeter() # loss (per word decoded)
136 | top5accs = AverageMeter() # top5 accuracy
137 |
138 | start = time.time()
139 |
140 | # Batches
141 | for i, (imgs, caps, caplens) in enumerate(train_loader):
142 | data_time.update(time.time() - start)
143 |
144 | # Move to GPU, if available
145 | imgs = imgs.to(device)
146 | caps = caps.to(device)
147 | caplens = caplens.to(device)
148 |
149 | # Forward prop.
150 | imgs = encoder(imgs)
151 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
152 |
153 | # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
154 | targets = caps_sorted[:, 1:]
155 |
156 | # Remove timesteps that we didn't decode at, or are pads
157 | # pack_padded_sequence is an easy trick to do this
158 | scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
159 | targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True)
160 |
161 | # Calculate loss
162 | loss = criterion(scores, targets)
163 |
164 | # Add doubly stochastic attention regularization
165 | loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
166 |
167 | # Back prop.
168 | decoder_optimizer.zero_grad()
169 | if encoder_optimizer is not None:
170 | encoder_optimizer.zero_grad()
171 | loss.backward()
172 |
173 | # Clip gradients
174 | if grad_clip is not None:
175 | clip_gradient(decoder_optimizer, grad_clip)
176 | if encoder_optimizer is not None:
177 | clip_gradient(encoder_optimizer, grad_clip)
178 |
179 | # Update weights
180 | decoder_optimizer.step()
181 | if encoder_optimizer is not None:
182 | encoder_optimizer.step()
183 |
184 | # Keep track of metrics
185 | top5 = accuracy(scores, targets, 5)
186 | losses.update(loss.item(), sum(decode_lengths))
187 | top5accs.update(top5, sum(decode_lengths))
188 | batch_time.update(time.time() - start)
189 |
190 | start = time.time()
191 |
192 | # Print status
193 | if i % print_freq == 0:
194 | print('Epoch: [{0}][{1}/{2}]\t'
195 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
196 | 'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
197 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
198 | 'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
199 | batch_time=batch_time,
200 | data_time=data_time, loss=losses,
201 | top5=top5accs))
202 |
203 |
204 | def validate(val_loader, encoder, decoder, criterion):
205 | """
206 | Performs one epoch's validation.
207 | :param val_loader: DataLoader for validation data.
208 | :param encoder: encoder model
209 | :param decoder: decoder model
210 | :param criterion: loss layer
211 | :return: BLEU-4 score
212 | """
213 | decoder.eval() # eval mode (no dropout or batchnorm)
214 | if encoder is not None:
215 | encoder.eval()
216 |
217 | batch_time = AverageMeter()
218 | losses = AverageMeter()
219 | top5accs = AverageMeter()
220 |
221 | start = time.time()
222 |
223 | references = list() # references (true captions) for calculating BLEU-4 score
224 | hypotheses = list() # hypotheses (predictions)
225 |
226 | # Batches
227 | for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):
228 |
229 | # Move to device, if available
230 | imgs = imgs.to(device)
231 | caps = caps.to(device)
232 | caplens = caplens.to(device)
233 |
234 | # Forward prop.
235 | if encoder is not None:
236 | imgs = encoder(imgs)
237 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
238 |
239 | # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
240 | targets = caps_sorted[:, 1:]
241 |
242 | # Remove timesteps that we didn't decode at, or are pads
243 | # pack_padded_sequence is an easy trick to do this
244 | scores_copy = scores.clone()
245 | scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True)
246 | targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True)
247 |
248 | # Calculate loss
249 | loss = criterion(scores, targets)
250 |
251 | # Add doubly stochastic attention regularization
252 | loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
253 |
254 | # Keep track of metrics
255 | losses.update(loss.item(), sum(decode_lengths))
256 | top5 = accuracy(scores, targets, 5)
257 | top5accs.update(top5, sum(decode_lengths))
258 | batch_time.update(time.time() - start)
259 |
260 | start = time.time()
261 |
262 | if i % print_freq == 0:
263 | print('Validation: [{0}/{1}]\t'
264 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
265 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
266 | 'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
267 | loss=losses, top5=top5accs))
268 |
269 | # Store references (true captions), and hypothesis (prediction) for each image
270 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
271 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
272 |
273 | # References
274 | allcaps = allcaps[sort_ind] # because images were sorted in the decoder
275 | for j in range(allcaps.shape[0]):
276 | img_caps = allcaps[j].tolist()
277 | img_captions = list(
278 | map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
279 | img_caps)) # remove <start> and pads
280 | references.append(img_captions)
281 |
282 | # Hypotheses
283 | _, preds = torch.max(scores_copy, dim=2)
284 | preds = preds.tolist()
285 | temp_preds = list()
286 | for j, p in enumerate(preds):
287 | temp_preds.append(preds[j][:decode_lengths[j]]) # remove pads
288 | preds = temp_preds
289 | hypotheses.extend(preds)
290 |
291 | assert len(references) == len(hypotheses)
292 |
293 | # Calculate BLEU-4 scores
294 | bleu4 = corpus_bleu(references, hypotheses)
295 |
296 | print(
297 | '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
298 | loss=losses,
299 | top5=top5accs,
300 | bleu=bleu4))
301 |
302 | return bleu4
303 |
304 |
305 | if __name__ == '__main__':
306 | main()
307 |
--------------------------------------------------------------------------------
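train() and validate() both lean on the pack_padded_sequence "trick" commented above: scores and targets are flattened so that the cross-entropy covers only the first decode_lengths[i] timesteps of each caption and never the padding. Below is a toy demonstration with invented sizes; .data is used here to pull out the flattened tensor, which yields the same result as the tuple-unpacking in train.py on the PyTorch version that code targets.

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

vocab_size = 7
scores = torch.randn(3, 5, vocab_size)            # (batch, max_decode_steps, vocab_size)
targets = torch.randint(0, vocab_size, (3, 5))    # (batch, max_decode_steps)
decode_lengths = [5, 3, 2]                        # valid steps per caption, sorted descending

flat_scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
flat_targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

print(flat_scores.shape)    # torch.Size([10, 7]) -- 5 + 3 + 2 valid timesteps survive
print(flat_targets.shape)   # torch.Size([10])
loss = F.cross_entropy(flat_scores, flat_targets)  # padded positions never contribute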
/utils.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 |
4 | import cv2 as cv
5 | import torch
6 |
7 |
8 | def ensure_folder(folder):
9 | if not os.path.exists(folder):
10 | os.makedirs(folder)
11 |
12 |
13 | # getting the number of CPUs
14 | def get_available_cpus():
15 | return multiprocessing.cpu_count()
16 |
17 |
18 | def draw_str(dst, target, s):
19 | x, y = target
20 | cv.putText(dst, s, (x + 1, y + 1), cv.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 0), thickness=2, lineType=cv.LINE_AA)
21 | cv.putText(dst, s, (x, y), cv.FONT_HERSHEY_PLAIN, 1.0, (255, 255, 255), lineType=cv.LINE_AA)
22 |
23 |
24 | def clip_gradient(optimizer, grad_clip):
25 | """
26 | Clips gradients computed during backpropagation to avoid explosion of gradients.
27 | :param optimizer: optimizer with the gradients to be clipped
28 | :param grad_clip: clip value
29 | """
30 | for group in optimizer.param_groups:
31 | for param in group['params']:
32 | if param.grad is not None:
33 | param.grad.data.clamp_(-grad_clip, grad_clip)
34 |
35 |
36 | def save_checkpoint(epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, decoder_optimizer,
37 | bleu4, is_best):
38 | """
39 | Saves model checkpoint.
41 | :param epoch: epoch number
42 | :param epochs_since_improvement: number of epochs since last improvement in BLEU-4 score
43 | :param encoder: encoder model
44 | :param decoder: decoder model
45 | :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning
46 | :param decoder_optimizer: optimizer to update decoder's weights
47 | :param bleu4: validation BLEU-4 score for this epoch
48 | :param is_best: is this checkpoint the best so far?
49 | """
50 | state = {'epoch': epoch,
51 | 'epochs_since_improvement': epochs_since_improvement,
52 | 'bleu-4': bleu4,
53 | 'encoder': encoder,
54 | 'decoder': decoder,
55 | 'encoder_optimizer': encoder_optimizer,
56 | 'decoder_optimizer': decoder_optimizer}
57 | filename = 'checkpoint_' + '.pth.tar'
58 | torch.save(state, filename)
59 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint
60 | if is_best:
61 | torch.save(state, 'BEST_' + filename)
62 |
63 |
64 | class AverageMeter(object):
65 | """
66 | Keeps track of most recent, average, sum, and count of a metric.
67 | """
68 |
69 | def __init__(self):
70 | self.reset()
71 |
72 | def reset(self):
73 | self.val = 0
74 | self.avg = 0
75 | self.sum = 0
76 | self.count = 0
77 |
78 | def update(self, val, n=1):
79 | self.val = val
80 | self.sum += val * n
81 | self.count += n
82 | self.avg = self.sum / self.count
83 |
84 |
85 | def adjust_learning_rate(optimizer, shrink_factor):
86 | """
87 | Shrinks learning rate by a specified factor.
88 | :param optimizer: optimizer whose learning rate must be shrunk.
89 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with.
90 | """
91 |
92 | print("\nDECAYING learning rate.")
93 | for param_group in optimizer.param_groups:
94 | param_group['lr'] = param_group['lr'] * shrink_factor
95 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))
96 |
97 |
98 | def accuracy(scores, targets, k):
99 | """
100 | Computes top-k accuracy, from predicted and true labels.
101 | :param scores: scores from the model
102 | :param targets: true labels
103 | :param k: k in top-k accuracy
104 | :return: top-k accuracy
105 | """
106 |
107 | batch_size = targets.size(0)
108 | _, ind = scores.topk(k, 1, True, True)
109 | correct = ind.eq(targets.view(-1, 1).expand_as(ind))
110 | correct_total = correct.view(-1).float().sum() # 0D tensor
111 | return correct_total.item() * (100.0 / batch_size)
112 |
--------------------------------------------------------------------------------
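As a quick sanity check of the two metric helpers above, here is a toy run of accuracy() and AverageMeter with hand-picked numbers; the weighting by n mirrors how train() weights each update by the number of decoded words.

import torch
from utils import AverageMeter, accuracy

scores = torch.tensor([[0.1, 0.7, 0.2],     # argmax = class 1
                       [0.5, 0.1, 0.4]])    # argmax = class 0, runner-up = class 2
targets = torch.tensor([1, 2])

top1 = accuracy(scores, targets, 1)   # 50.0  -- only the first row's top-1 matches
top2 = accuracy(scores, targets, 2)   # 100.0 -- class 2 is inside the second row's top-2

meter = AverageMeter()
meter.update(top1, n=2)               # n would be sum(decode_lengths) in train()
meter.update(top2, n=2)
print(meter.val, meter.avg)           # 100.0 75.0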