├── Image caption generator.ipynb
├── README.md
├── Tutorial - Show, Attend and Tell.ipynb
├── annotate_webcam_stream.py
├── caption_generator_api
│   ├── BeamSearch.py
│   ├── caption_generator_api.py
│   ├── caption_generator_model.py
│   └── utils.py
├── download_dataset.sh
├── download_images.py
└── prepare_data.ipynb

/README.md:
--------------------------------------------------------------------------------
# Neural Image Captioning

Click [here](https://www.youtube.com/watch?v=i7QnEdn0RZ8) or on the picture below to watch it in action:

IMAGE ALT TEXT HERE
--------------------------------------------------------------------------------
/annotate_webcam_stream.py:
--------------------------------------------------------------------------------
import argparse
import sys
from io import BytesIO

import cv2
import pyttsx3
import requests
from invoke import run
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--port", type=int, default=7777, help="Port of the caption generator API")
parser.add_argument("-e", "--engine", type=int, default=0, help="0 to use the 'say' Speech Synthesis Manager engine (macOS only), 1 to use pyttsx3")
args = parser.parse_args()

url_coco = 'http://localhost:{}/generate_caption_ms_coco'.format(args.port)
url_cc = 'http://localhost:{}/generate_caption_conceptual_captions'.format(args.port)

# fastai special tokens that should not be spoken or printed
special_toks = ('xxunk', 'xxpad', 'xxbos', 'xxfld', 'xxmaj', 'xxup', 'xxrep', 'xxwrep', '-')

if args.engine == 1:
    engine = pyttsx3.init()
    voices = engine.getProperty('voices')

    engine.setProperty('voice', voices[0].id)
    #engine.setProperty('rate', 150)

elif args.engine != 0:
    print("Invalid TTS engine chosen")
    sys.exit()

def remove_special_toks(cap):
    return ' '.join([w for w in cap.split() if w not in special_toks])

def say_caption(caption):
    if args.engine == 1:
        engine.say(caption)
        engine.runAndWait()
    else:
        cmd = "say -v Alex {}".format(caption)
        run(cmd, hide=True, warn=True)

def get_caption(img, url):
    imgByteArr = BytesIO()
    img.save(imgByteArr, format='JPEG')
    imgByteArr.seek(0)

    response = requests.post(url, files={'input_image': imgByteArr})
    if response.ok:
        return response.text
    else:
        return "Could not connect to API"

cam = cv2.VideoCapture(0)
cam.open(0)
cv2.namedWindow("Window_1", cv2.WINDOW_NORMAL)
cv2.resizeWindow("Window_1", 960, 600)

while True:
    ret, frame = cam.read()
    if not ret:
        break
    cv2.imshow("Window_1", frame)
    k = cv2.waitKey(1)

    if k % 256 == 27:
        # ESC pressed
        print("Escape hit, closing...")
        break
    elif k % 256 == 32:
        # SPACE pressed: caption the frame with the MS COCO model
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = Image.fromarray(frame, 'RGB')
        caption = get_caption(frame, url_coco)
        say_caption(remove_special_toks(caption))
        print(caption)
    elif k % 256 == 13:
        # RETURN pressed: caption the frame with the Conceptual Captions model
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = Image.fromarray(frame, 'RGB')
        caption = get_caption(frame, url_cc)
        say_caption(remove_special_toks(caption))
        print(caption)

cam.release()

cv2.destroyAllWindows()
--------------------------------------------------------------------------------
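The script above only needs an API that answers POST requests carrying an image under the `input_image` form field. A minimal sketch of the same round trip for a single image on disk: "photo.jpg" is a placeholder file name, and the server from caption_generator_api.py is assumed to be running locally on port 7777.

import requests

# Placeholder input file; the endpoint and form field name mirror those used
# by get_caption() in annotate_webcam_stream.py.
with open("photo.jpg", "rb") as f:
    response = requests.post(
        "http://localhost:7777/generate_caption_ms_coco",
        files={"input_image": f},
    )
print(response.text)  # the generated caption as plain text
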
/caption_generator_api/BeamSearch.py:
--------------------------------------------------------------------------------
import torch

class HypothesisNode():
    """ Hypothesis Node class for performing Beam Search """
    def __init__(self, sequence, log_prob, hidden_state, alphas):
        """HypothesisNode constructor

        Args:
            sequence: A sequence of tokens
            log_prob: The log of the probability of this sequence
            hidden_state: The hidden state of the Decoder RNN after decoding the last token in the sequence
            alphas: The attention weights produced while decoding the sequence
        """
        self._seq = sequence
        self._alphas = alphas
        self._log_prob = log_prob
        self._h = hidden_state

    @property
    def last_tok(self):
        """
        Returns:
            The last token in the sequence
        """
        return self._seq[-1]

    def update(self, tok, log_prob, new_h, new_alpha):
        """
        Appends a new token to the sequence and returns a new Hypothesis Node

        Args:
            tok: The new token that is appended to the sequence
            log_prob: The log of the probability of this token
            new_h: The new hidden state of the Decoder RNN after this token
            new_alpha: The attention weights used to decode this token

        Returns:
            A Hypothesis Node with the updated sequence, log probability, hidden state and attention weights
        """
        return HypothesisNode(self._seq + [tok], self._log_prob + log_prob, new_h, self._alphas + new_alpha)

    def __str__(self):
        return 'Hyp(log_p = %.4f,\t seq = %s)' % (self._log_prob, [t.item() for t in self._seq])

class BeamSearch():
    """ Performs beam search for seq2seq decoding or image captioning """
    def __init__(self, enc_model, dec_model, beam_width=5, num_results=1, max_len=30, device=torch.device('cuda:0')):
        """BeamSearch object constructor

        Args:
            enc_model: A seq2seq encoder or a CNN for image captioning
            dec_model: An RNN decoder model
            beam_width: int, the number of hypotheses to keep in each iteration
            num_results: int, the number of finished hypotheses to return
            max_len: int, the longest possible sequence
            device: the torch device the tensors live on
        """
        self._device = device
        self._enc_model = enc_model
        self._dec_model = dec_model
        self._beam_width = beam_width
        self._num_results = num_results
        self._max_len = max_len
        self._start_tok = 0
        self._end_tok = 1
        self._annotation_vecs = None

    def __call__(self, img, verbose=False):
        """Performs the beam search

        Args:
            img: the image to be annotated, a torch tensor with 3 color channels
            verbose: bool, print the intermediate hypotheses for better understanding

        Returns:
            The 'num_results' most probable sequences (token ids without start and end token) and their attention weights
        """
        img = img.unsqueeze(0)
        h, annotation_vecs = self._enc_model(img)
        self._annotation_vecs = annotation_vecs

        # start with a single hypothesis containing only the start token
        hyps = [HypothesisNode([torch.zeros(1, requires_grad=False).long().to(self._device)], 0, h, [])]
        results = []

        step = 0
        width = self._beam_width
        while width > 0 and step < self._max_len:
            if verbose: print("\n Step: ", step)
            new_hyps = []
            for h in hyps:
                new_hyps.extend(self.get_next_hypotheses(h, width))

            new_hyps = sorted(new_hyps, key=lambda x: x._log_prob, reverse=True)
            if verbose: self.print_hypotheses(new_hyps, "Before narrowing:")

            hyps = []
            for h in new_hyps[:width]:
                if h.last_tok == self._end_tok:
                    # a finished hypothesis frees up one slot in the beam
                    results.append(h)
                    width = width - 1
                else:
                    hyps.append(h)

            if verbose:
                self.print_hypotheses(hyps, "After narrowing:")
                self.print_hypotheses(results, "Results:")

            step += 1

        results.extend(hyps[:width])
        # rank by length-normalized log probability so longer sequences are not penalized
        results = sorted(results, key=lambda x: x._log_prob / len(x._seq), reverse=True)

        if verbose: self.print_hypotheses(results, "Final:")

        if self._num_results == 1:
            return ([t.item() for t in results[0]._seq[1:-1]], torch.stack(results[0]._alphas))
        else:
            return [([t.item() for t in r._seq[1:-1]], torch.stack(r._alphas)) for r in results[:self._num_results]]

    def get_next_hypotheses(self, hyp, k):
        """Calculates the k most probable next hypotheses given a Hypothesis Node

        Args:
            hyp: a Hypothesis Node containing a sequence, a log probability and a Decoder RNN hidden state
            k: the number of hypotheses to calculate

        Returns:
            A list with the k most probable successor Hypothesis Nodes
        """
        dec_outp, h, alphas = self._dec_model(hyp.last_tok, hyp._h, self._annotation_vecs)

        top_k_log_probs, top_k_toks = dec_outp.topk(k, dim=1)

        return [hyp.update(top_k_toks[0][i].unsqueeze(0), top_k_log_probs[0][i], h, list(alphas)) for i in range(k)]

    def print_hypotheses(self, hyps, description):
        print(description)
        for h in hyps:
            print(h)
--------------------------------------------------------------------------------
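For reference, caption_generator_api.py below builds and calls `BeamSearch` roughly as in this sketch; `model` is assumed to be a loaded `ImageCaptionGenerator` and `img` a normalised 3x224x224 tensor.

# Sketch only: `model`, `img` and `device` are assumed to exist already.
beam_search = BeamSearch(model.encode, model.decode_step,
                         beam_width=5, num_results=3, device=device)
hypotheses = beam_search(img)       # num_results > 1 -> list of (token_ids, alphas) tuples
for token_ids, alphas in hypotheses:
    print(token_ids)                # decode to text with vocab.textify(token_ids)

Completed hypotheses are ranked by length-normalised log probability (`_log_prob / len(_seq)`), so longer captions are not penalised merely for containing more factors in the product.
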
/caption_generator_api/caption_generator_api.py:
--------------------------------------------------------------------------------
import matplotlib
matplotlib.use("Agg")  # so that importing pyplot does not try to pull in a GUI
from flask import Flask, request, send_file
from flasgger import Swagger
#from fastai.text import *
from torchvision import transforms, models
from PIL import Image
from caption_generator_model import *
from utils import *
from BeamSearch import *
import torch
import pickle
import io

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

app = Flask(__name__)
app.config['SWAGGER'] = {
    "swagger_version": "2.0",
    "title": "Image Caption Generator",
    "description": "Generates image captions using a model based on the article 'Show, Attend and Tell: Neural Image Caption Generation with Visual Attention' by Xu et al. (2015)."}

swagger = Swagger(app)

n_layers, emb_sz = 1, 500
beam_width = 5

vocab_ms_coco = pickle.load(open("vocab_ms_coco.pkl", "rb"))
vocab_cc = pickle.load(open("vocab_conceptual_captions.pkl", "rb"))

model_ms_coco = torch.load('model_ms_coco.pth', map_location=device)
model_ms_coco = model_ms_coco.to(device)
model_ms_coco.eval()

model_cc = torch.load('model_conceptual_captions.pth', map_location=device)
model_cc = model_cc.to(device)
model_cc.eval()

model_ms_coco.device = device
model_cc.device = device

beam_search_ms_coco = BeamSearch(model_ms_coco.encode, model_ms_coco.decode_step, beam_width, device=device)
beam_search_cc = BeamSearch(model_cc.encode, model_cc.decode_step, beam_width, device=device)

sz = 224

valid_tfms = transforms.Compose([
    transforms.Resize(sz),
    transforms.CenterCrop(sz),
    transforms.ToTensor(),
    transforms.Normalize([0.5238, 0.5003, 0.4718], [0.3159, 0.3091, 0.3216])
])

inv_normalize = transforms.Normalize(
    mean=[-0.5238/0.3159, -0.5003/0.3091, -0.4718/0.3216],
    std=[1/0.3159, 1/0.3091, 1/0.3216]
)

denorm = transforms.Compose([
    inv_normalize,
    transforms.functional.to_pil_image
])

def generate_caption_helper(img, beam_search_func):
    img_transformed = valid_tfms(img)
    results = beam_search_func(img_transformed)

    return img_transformed, results

def visualize_attention_mechanism_helper(img, beam_search_func, vocab):
    transformed_img, results = generate_caption_helper(img, beam_search_func)
    visualization = visualize_attention(transformed_img, results[0], results[1], denorm, vocab, return_fig_as_PIL_image=True).convert("RGB")

    imgByteArr = io.BytesIO()
    visualization.save(imgByteArr, format='JPEG')

    return send_file(io.BytesIO(imgByteArr.getvalue()),
                     attachment_filename='return.jpeg',
                     mimetype='image/jpeg')

@app.route('/generate_caption_ms_coco', methods=["POST"])
def generate_caption_ms_coco():
    """Generates a caption for the given image using a model trained on the MS COCO dataset
    ---
    tags:
      - Image caption generator
    parameters:
      - name: input_image
        in: formData
        type: file
        required: true
    responses:
      200:
        description: "image"
    """
    try:
        img = Image.open(request.files.get("input_image"))
        _, results = generate_caption_helper(img, beam_search_ms_coco)
        return vocab_ms_coco.textify(results[0])
    except Exception:
        return "Prediction unsuccessful. Please choose a JPEG file."

@app.route('/visualize_attention_mechanism_ms_coco', methods=["POST"])
def visualize_attention_mechanism_ms_coco():
    """Generates a caption and visualizes the attention mechanism for the given image using a model trained on the MS COCO dataset
    ---
    tags:
      - Image caption generator
    parameters:
      - name: input_image
        in: formData
        type: file
        required: true
    responses:
      200:
        description: "image"
    """
    try:
        img = Image.open(request.files.get("input_image"))
        return visualize_attention_mechanism_helper(img, beam_search_ms_coco, vocab_ms_coco)
    except Exception:
        return "Prediction unsuccessful. Please choose a JPEG file."

@app.route('/generate_caption_conceptual_captions', methods=["POST"])
def generate_caption_conceptual_captions():
    """Generates a caption for the given image using a model trained on the Conceptual Captions dataset
    ---
    tags:
      - Image caption generator
    parameters:
      - name: input_image
        in: formData
        type: file
        required: true
    responses:
      200:
        description: "image"
    """
    try:
        img = Image.open(request.files.get("input_image"))
        _, results = generate_caption_helper(img, beam_search_cc)
        return vocab_cc.textify(results[0])
    except Exception:
        return "Prediction unsuccessful. Please choose a JPEG file."

@app.route('/visualize_attention_mechanism_conceptual_captions', methods=["POST"])
def visualize_attention_mechanism_conceptual_captions():
    """Generates a caption and visualizes the attention mechanism for the given image using a model trained on the Conceptual Captions dataset
    ---
    tags:
      - Image caption generator
    parameters:
      - name: input_image
        in: formData
        type: file
        required: true
    responses:
      200:
        description: "image"
    """
    try:
        img = Image.open(request.files.get("input_image"))
        return visualize_attention_mechanism_helper(img, beam_search_cc, vocab_cc)
    except Exception:
        return "Prediction unsuccessful. Please choose a JPEG file."

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7777)
--------------------------------------------------------------------------------
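The two `visualize_attention_mechanism_*` endpoints return a JPEG figure rather than text. A small sketch of fetching and saving one; "photo.jpg" is again a placeholder file name and the server is assumed to be running locally on port 7777 (flasgger also serves interactive documentation, by default under /apidocs).

import requests

with open("photo.jpg", "rb") as f:
    r = requests.post(
        "http://localhost:7777/visualize_attention_mechanism_ms_coco",
        files={"input_image": f},
    )

# The response body is the rendered attention figure encoded as JPEG.
with open("attention.jpg", "wb") as out:
    out.write(r.content)
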
/caption_generator_api/caption_generator_model.py:
--------------------------------------------------------------------------------
from fastai.text import *
from torchvision import transforms, models

def fc_layer(n_in, n_out, p=0.1):
    return nn.Sequential(
        Flatten(),
        nn.Linear(in_features=n_in, out_features=n_out),
        nn.Dropout(p),
    )

class Encoder(nn.Module):
    """Visual encoder: a pretrained ResNet-101 that produces the decoder's initial hidden
    states and one annotation vector per spatial cell of the last feature map."""
    def __init__(self, device, dec_hidden_state_size, dec_layers, filter_width, num_filters):
        super().__init__()
        self.device = device
        self.base_network = nn.Sequential(*list(models.resnet101(pretrained=True).children())[:-2])
        self.freeze_base_network()
        self.concatPool = AdaptiveConcatPool2d(sz=1)
        self.adaptivePool = nn.AdaptiveAvgPool2d((filter_width, filter_width))
        self.filter_width = filter_width

        self.output_layers = nn.ModuleList([
            fc_layer(2*num_filters, dec_hidden_state_size) for _ in range(dec_layers)
        ])

    def forward(self, inp):
        enc_output = self.base_network(inp)
        # (bs, num_filters, filter_width, filter_width) -> (bs, num_filters, filter_width**2)
        annotation_vecs = self.adaptivePool(enc_output).view(enc_output.size(0), enc_output.size(1), -1)
        enc_output = self.concatPool(enc_output)

        # one small MLP per decoder layer predicts that layer's initial hidden state
        dec_init_hidden_states = [MLP_layer(enc_output) for MLP_layer in self.output_layers]

        return torch.stack(dec_init_hidden_states, dim=0), annotation_vecs.transpose(1, 2)

    def freeze_base_network(self):
        for layer in self.base_network:
            requires_grad(layer, False)

    def fine_tune(self, from_block=-1):
        for layer in self.base_network[from_block:]:
            requires_grad(layer, True)

class VisualAttention(nn.Module):
    """Soft attention over the annotation vectors, following Xu et al. (2015)."""
    def __init__(self, num_filters, dec_dim, att_dim):
        super().__init__()
        self.attend_annot_vec = nn.Linear(num_filters, att_dim)
        self.attend_dec_hidden = nn.Linear(dec_dim, att_dim)
        self.f_att = nn.Linear(att_dim, 1)  # Equation (4) in Xu et al. (2015)

    def forward(self, annotation_vecs, dec_hid_state):
        attended_annotation_vecs = self.attend_annot_vec(annotation_vecs)
        attended_dec_hid_state = self.attend_dec_hidden(dec_hid_state)
        e = self.f_att(F.relu(attended_annotation_vecs + attended_dec_hid_state.unsqueeze(1))).squeeze(2)  # Equation (4)
        alphas = F.softmax(e, dim=1)  # Equation (5) in Xu et al. (2015)
        context_vec = (annotation_vecs * alphas.unsqueeze(2)).sum(1)  # Equation (13), soft attention

        return context_vec, alphas

class ImageCaptionGenerator(nn.Module):
    def __init__(self, device, filter_width, num_filters, vocab_size, emb_sz, out_seqlen, n_layers=3, prob_teach_forcing=1, p_drop=0.3):
        super().__init__()
        self.n_layers, self.out_seqlen = n_layers, out_seqlen
        self.filter_width = filter_width
        self.num_filters = num_filters
        self.device = device

        # Encoder
        self.encoder = Encoder(device, emb_sz, n_layers, filter_width, num_filters)

        # Attention
        self.att = VisualAttention(num_filters, emb_sz, 500)

        # Decoder
        self.emb = nn.Embedding(vocab_size, emb_sz)  #create_emb(wordvecs, itos, emb_sz)
        self.rnn_dec = nn.GRU(num_filters + emb_sz, emb_sz, num_layers=n_layers, dropout=0 if n_layers == 1 else p_drop)  # square to enable weight tying
        self.out_drop = nn.Dropout(p_drop)
        self.out = nn.Linear(emb_sz, vocab_size)
        self.out.weight.data = self.emb.weight.data  # tie the output layer to the embedding weights
        self.f_b = nn.Linear(emb_sz, num_filters)  # gating scalar beta, Section 4.2.1 in Xu et al. (2015)

        self.prob_teach_forcing = prob_teach_forcing
        self.initializer()

    def initializer(self):
        self.emb.weight.data.uniform_(-0.1, 0.1)

    def forward(self, x, y=None):
        h, annotation_vecs = self.encode(x)

        # start decoding from the start token (index 0)
        dec_inp = torch.zeros(h.size(1), requires_grad=False).long()
        dec_inp = dec_inp.to(self.device)
        res = []
        alphas = []

        for i in range(self.out_seqlen):
            dec_output, h, alpha = self.decode_step(dec_inp, h, annotation_vecs)
            res.append(dec_output)
            alphas.append(alpha)

            if (dec_inp == 1).all() or (y is not None and i >= len(y)):
                break
            # teacher forcing
            elif y is not None and (self.prob_teach_forcing > 0) and (random.random() < self.prob_teach_forcing):
                dec_inp = y[i].to(self.device)
            else:
                dec_inp = dec_output.data.max(1)[1]  # [1] to get the argmax

        return torch.stack(res), torch.stack(alphas)

    def encode(self, x):
        return self.encoder(x.to(self.device))

    def decode_step(self, dec_inp, h, annotation_vecs):
        context_vec, alpha = self.att(annotation_vecs, h[-1])
        beta = torch.sigmoid(self.f_b(h[-1]))
        context_vec = beta * context_vec  # Section 4.2.1 in Xu et al. (2015)

        emb_inp = self.emb(dec_inp).unsqueeze(0)  # adds a unit axis at the beginning so that the rnn 'loops' once

        output, h = self.rnn_dec(torch.cat([emb_inp, context_vec.unsqueeze(0)], dim=2), h)
        output = self.out(self.out_drop(output[0]))

        return F.log_softmax(output, dim=1), h, alpha
--------------------------------------------------------------------------------
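As a quick sanity check of the shapes above, a sketch that only runs the encoder. It assumes fastai v1 and torchvision are installed (the ImageNet weights are downloaded on first use); `vocab_size=10000` is a placeholder, while `filter_width=7` and `num_filters=2048` match ResNet-101 on a 224x224 input and the `att_size=7` default used in utils.py.

import torch

# Hypothetical hyperparameters; only the encoder output shapes are checked here.
model = ImageCaptionGenerator(device=torch.device('cpu'), filter_width=7,
                              num_filters=2048, vocab_size=10000,
                              emb_sz=500, out_seqlen=30, n_layers=1)
x = torch.randn(1, 3, 224, 224)   # one normalised image
h, annotation_vecs = model.encode(x)
print(h.shape)                    # torch.Size([1, 1, 500]):   (n_layers, bs, emb_sz)
print(annotation_vecs.shape)      # torch.Size([1, 49, 2048]): (bs, filter_width**2, num_filters)
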
/caption_generator_api/utils.py:
--------------------------------------------------------------------------------
from scipy.ndimage import gaussian_filter
from matplotlib.patheffects import Stroke, Normal
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# the functions fig2data and fig2img are taken from
# http://www.icare.univ-lille1.fr/tutorials/convert_a_matplotlib_figure
# Deprecation errors have been fixed

def fig2data(fig):
    """
    @brief Convert a Matplotlib figure to a numpy array with RGBA channels and return it
    @param fig a matplotlib figure
    @return a numpy 3D array of shape (height, width, 4) with RGBA values
    """
    # draw the renderer
    fig.canvas.draw()

    # Get the ARGB buffer from the figure
    w, h = fig.canvas.get_width_height()
    buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
    buf = buf.reshape(h, w, 4)

    # canvas.tostring_argb gives the pixmap in ARGB mode; roll the alpha channel to get RGBA
    buf = np.roll(buf, 3, axis=2)
    return buf

def fig2img(fig):
    """
    @brief Convert a Matplotlib figure to a PIL Image in RGBA format and return it
    @param fig a matplotlib figure
    @return a Python Imaging Library (PIL) image
    """
    # put the figure pixmap into a numpy array
    buf = fig2data(fig)
    h, w, d = buf.shape
    return Image.frombytes("RGBA", (w, h), buf.tobytes())

def draw_text(ax, xy, txt, sz=14):
    text = ax.text(*xy, txt, verticalalignment='top', color='white', fontsize=sz, weight='bold')
    draw_outline(text, 1)

def draw_outline(matplt_plot_obj, lw):
    matplt_plot_obj.set_path_effects([Stroke(linewidth=lw, foreground='black'), Normal()])

def show_img(im, figsize=None, ax=None, alpha=1, cmap=None):
    if not ax:
        fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(im, alpha=alpha, cmap=cmap)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    return ax

def visualize_attention(im, pred, alphas, denorm, vocab, att_size=7, thresh=0., sz=224, return_fig_as_PIL_image=False):
    cap_len = len(pred)
    alphas = alphas.view(-1, 1, att_size, att_size).cpu().data.numpy()
    alphas = np.maximum(thresh, alphas)
    alphas -= alphas.min()
    alphas /= alphas.max()

    figure, axes = plt.subplots(cap_len//5 + 1, 5, figsize=(12, 8))

    for i, ax in enumerate(axes.flat):
        if i <= cap_len:
            # the first axis shows the plain image, each following axis overlays one attention map per token
            ax = show_img(denorm(im), ax=ax)
            if i > 0:
                mask = np.array(Image.fromarray(alphas[i - 1, 0]).resize((sz, sz)))
                blurred_mask = gaussian_filter(mask, sigma=8)
                show_img(blurred_mask, ax=ax, alpha=0.5, cmap='afmhot')
                draw_text(ax, (0, 0), vocab.itos[pred[i - 1]])
        else:
            ax.axis('off')
    plt.tight_layout()

    if return_fig_as_PIL_image:
        return fig2img(figure)
--------------------------------------------------------------------------------
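The same figure the API serves can be produced directly from Python. A sketch assuming the objects from caption_generator_api.py (`beam_search_ms_coco`, `denorm`, `vocab_ms_coco`) and a preprocessed image tensor `img_transformed` are already set up as in that file:

# Sketch only: all names are assumed to be defined as in caption_generator_api.py.
token_ids, alphas = beam_search_ms_coco(img_transformed)
figure_img = visualize_attention(img_transformed, token_ids, alphas,
                                 denorm, vocab_ms_coco, return_fig_as_PIL_image=True)
figure_img.convert("RGB").save("attention.jpg", format="JPEG")
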
/download_dataset.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Usage: ./download_dataset.sh <captions_tsv> [max_image_size]
# Splits the tsv file into one chunk per CPU core and downloads the chunks in parallel.

NUM_LINES=$(wc -l < "$1")
NUM_CORES=$(parallel --number-of-cores)

PICS_PER_CORE=$((NUM_LINES/NUM_CORES))

CURRENT_IDX=0
rm -f download_commands.txt
touch download_commands.txt
for ((i = 1; i <= NUM_CORES; i++ ))
do
    if [ "$i" -eq "$NUM_CORES" ]; then
        # the last core also takes the remainder that does not divide evenly
        END_IDX=$((NUM_LINES-1))
    else
        END_IDX=$((CURRENT_IDX+PICS_PER_CORE-1))
    fi
    echo "python download_images.py $1 $CURRENT_IDX $END_IDX $2 > core_$i.log" >> download_commands.txt
    CURRENT_IDX=$((CURRENT_IDX+PICS_PER_CORE))
done

parallel < download_commands.txt
--------------------------------------------------------------------------------
/download_images.py:
--------------------------------------------------------------------------------
import pandas as pd
from PIL import Image
from PIL.Image import LANCZOS
import requests
from pathlib import Path
import argparse
import pickle

parser = argparse.ArgumentParser()
parser.add_argument("datafile", help='tsv file containing image captions and urls', type=str)
parser.add_argument("startIdx", type=int)
parser.add_argument("endIdx", type=int)
parser.add_argument("maxSize", type=int, nargs='?', default=None)
args = parser.parse_args()

captions_and_links = pd.read_csv(args.datafile, sep="\t", header=None)

def calc_new_size(img, max_sz):
    """Returns the new (width, height) so that the smaller side equals max_sz, or None if no resize is needed."""
    width, height = img.size
    smaller_side, resize_ratio = (width, max_sz/width) if width < height else (height, max_sz/height)
    if smaller_side <= max_sz:
        return None
    else:
        return (int(width * resize_ratio), int(height * resize_ratio))

def get_image_w_caption(df, path, idx):
    """Downloads the image in row idx, optionally resizes it and saves it as '<idx>.png'."""
    try:
        caption, link = df.iloc[idx]
        img = Image.open(requests.get(link, stream=True, timeout=10).raw)
        if args.maxSize is not None:
            new_size = calc_new_size(img, args.maxSize)
            if new_size is not None:
                img = img.resize(new_size, resample=LANCZOS)
        img.save(path/(str(idx)+".png"), format='png')
    except Exception:
        return None
    return str(idx)+".png", caption

PATH = Path('data/')
SUB_PATH = PATH/'downloadedPics'
PATH.mkdir(exist_ok=True)
SUB_PATH.mkdir(exist_ok=True)

images = {}
for i in range(args.startIdx, args.endIdx + 1, 1):
    result = get_image_w_caption(captions_and_links, SUB_PATH, i)
    if result is not None:
        images[i] = result
    if i % 1000 == 1:
        # periodically checkpoint the index -> (file name, caption) mapping
        pickle.dump(images,
                    (PATH/('dict_' + str(args.startIdx)
                           + '-' + str(args.endIdx) + '.pkl')).open('wb'))
        print('saved')
    if i % 100 == 0:
        print(i, '/', args.endIdx - args.startIdx)

pickle.dump(images,
            (PATH/('dict_' + str(args.startIdx) + '-' + str(args.endIdx) + '.pkl')).open('wb'))
--------------------------------------------------------------------------------
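Each worker launched by download_dataset.sh writes a partial data/dict_<startIdx>-<endIdx>.pkl mapping row index to (file name, caption). A sketch for merging those partial dictionaries once all downloads have finished; only the file-name pattern produced above is assumed.

import pickle
from pathlib import Path

merged = {}
for pkl_path in Path('data/').glob('dict_*.pkl'):
    with pkl_path.open('rb') as f:
        merged.update(pickle.load(f))

# merged[idx] == ("<idx>.png", caption) for every row whose image downloaded successfully
print(len(merged), "images downloaded")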